From c6e60b1bd2125b169ed3c17ee37fbe0e56c4b8ad Mon Sep 17 00:00:00 2001 From: Saksham Gupta Date: Wed, 2 Aug 2023 12:34:08 +0530 Subject: [PATCH] Adds Cleartext float via sytorch (#186) * float added from kanav-gpt in sytorch * cleartext_fp added in Onnxbridge * i64->T in tensor.h * debug statement * typecast corrected in module.h * removing debug statement * added testing for sytorch float ct * rename tests --- .github/workflows/onnx_bridge.yml | 82 ++++- OnnxBridge/LLAMA/compile_llama.sh | 2 +- OnnxBridge/LLAMA/sytorchBackendRep.py | 43 ++- OnnxBridge/backend.py | 6 +- OnnxBridge/main.py | 14 +- OnnxBridge/tests/conftest.py | 10 +- OnnxBridge/tests/utils.py | 12 + sytorch/include/sytorch/backend/default.h | 38 +++ sytorch/include/sytorch/backend/float.h | 51 +++ sytorch/include/sytorch/layers/layers.h | 4 +- sytorch/include/sytorch/module.h | 9 +- sytorch/include/sytorch/tensor.h | 10 +- sytorch/src/sytorch/backend/float.cpp | 399 ++++++++++++++++++++++ 13 files changed, 650 insertions(+), 30 deletions(-) create mode 100644 sytorch/include/sytorch/backend/default.h create mode 100644 sytorch/include/sytorch/backend/float.h create mode 100644 sytorch/src/sytorch/backend/float.cpp diff --git a/.github/workflows/onnx_bridge.yml b/.github/workflows/onnx_bridge.yml index 5bd94409..03fd0854 100644 --- a/.github/workflows/onnx_bridge.yml +++ b/.github/workflows/onnx_bridge.yml @@ -13,7 +13,7 @@ on: # A workflow run is made up of one or more jobs that can run sequentially or in parallel jobs: - LLAMA: + Sytorch-LLAMA: # The type of runner that the job will run on runs-on: ubuntu-latest container: @@ -37,7 +37,7 @@ jobs: run: | apt-get update -y - - name: LLAMA + - name: Sytorch LLAMA if: always() run: | cd OnnxBridge/tests/ @@ -45,7 +45,7 @@ jobs: pytest --backend LLAMA -v -k 'hinet and not batch' shell: bash - LLAMA-batch: + Sytorch-LLAMA-batch: # The type of runner that the job will run on runs-on: ubuntu-latest container: @@ -69,7 +69,7 @@ jobs: run: | apt-get update -y - - name: LLAMA Batch + - name: Sytorch LLAMA Batch if: always() run: | cd OnnxBridge/tests/ @@ -77,7 +77,7 @@ jobs: pytest --backend LLAMA -v -k 'hinet and batch' --batch_size 5 shell: bash - LLAMA-ct: + Sytorch-LLAMA-ct: # The type of runner that the job will run on runs-on: ubuntu-latest container: @@ -101,7 +101,7 @@ jobs: run: | apt-get update -y - - name: LLAMA Cleartext + - name: Sytorch LLAMA Cleartext if: always() run: | cd OnnxBridge/tests/ @@ -110,7 +110,7 @@ jobs: pytest --backend CLEARTEXT_LLAMA -v -k 'chexpert and not batch' shell: bash - LLAMA-ct-batch: + Sytorch-LLAMA-ct-batch: # The type of runner that the job will run on runs-on: ubuntu-latest container: @@ -134,7 +134,7 @@ jobs: run: | apt-get update -y - - name: LLAMA Cleartext Batch + - name: Sytorch LLAMA Cleartext Batch if: always() run: | cd OnnxBridge/tests/ @@ -142,6 +142,72 @@ jobs: pytest --backend CLEARTEXT_LLAMA -v -k 'hinet and batch' --batch_size 5 shell: bash + Sytorch-ct-fp: + # The type of runner that the job will run on + runs-on: ubuntu-latest + container: + image: drunkenlegend/onnxbridge:latest + options: --user root + + # Steps represent a sequence of tasks that will be executed as part of the job + steps: + - name: Update Git + run: | + add-apt-repository ppa:git-core/ppa -y + apt-get update + apt-get install git -y + + - name: Checkout repository + uses: actions/checkout@v3 + with: + submodules: 'true' + + - name: Install dependencies + run: | + apt-get update -y + + - name: Sytorch Cleartext Floating Point + if: always() + run: | + cd OnnxBridge/tests/ + pytest --backend CLEARTEXT_fp -v -k 'lenet and not batch' + pytest --backend CLEARTEXT_fp -v -k 'hinet and not batch' + pytest --backend CLEARTEXT_fp -v -k 'chexpert and not batch' + shell: bash + + Sytorch-ct-fp-batch: + # The type of runner that the job will run on + runs-on: ubuntu-latest + container: + image: drunkenlegend/onnxbridge:latest + options: --user root + + # Steps represent a sequence of tasks that will be executed as part of the job + steps: + - name: Update Git + run: | + add-apt-repository ppa:git-core/ppa -y + apt-get update + apt-get install git -y + + - name: Checkout repository + uses: actions/checkout@v3 + with: + submodules: 'true' + + - name: Install dependencies + run: | + apt-get update -y + + - name: Sytorch Cleartext Floating Point Batch + if: always() + run: | + cd OnnxBridge/tests/ + pytest --backend CLEARTEXT_fp -v -k 'lenet and batch' --batch_size 2 + pytest --backend CLEARTEXT_fp -v -k 'hinet and batch' --batch_size 5 + shell: bash + + Secfloat: # The type of runner that the job will run on runs-on: ubuntu-latest diff --git a/OnnxBridge/LLAMA/compile_llama.sh b/OnnxBridge/LLAMA/compile_llama.sh index 38fb6271..c2d5d032 100755 --- a/OnnxBridge/LLAMA/compile_llama.sh +++ b/OnnxBridge/LLAMA/compile_llama.sh @@ -44,7 +44,7 @@ find_package(Threads REQUIRED) add_subdirectory($sytorch_dir/ext/cryptoTools $pd/cryptoTools) add_subdirectory($sytorch_dir/ext/llama $pd/llama) add_executable($BINARY_NAME - ../$FSS_CPP_FILE $sytorch_dir/src/sytorch/random.cpp $sytorch_dir/src/sytorch/backend/cleartext.cpp + ../$FSS_CPP_FILE $sytorch_dir/src/sytorch/random.cpp $sytorch_dir/src/sytorch/backend/cleartext.cpp $sytorch_dir/src/sytorch/backend/float.cpp ) target_include_directories($BINARY_NAME PUBLIC diff --git a/OnnxBridge/LLAMA/sytorchBackendRep.py b/OnnxBridge/LLAMA/sytorchBackendRep.py index 36ae78e9..6c5a2ed8 100644 --- a/OnnxBridge/LLAMA/sytorchBackendRep.py +++ b/OnnxBridge/LLAMA/sytorchBackendRep.py @@ -137,6 +137,45 @@ def cleartext_post(code_list, program, scale, mode, indent): ) +def cleartext_fp_post(code_list, program, scale, mode, indent): + # Input + n = program[0].shape[0] + c = program[0].shape[1] + dims = program[0].shape[2:] + # n, c, h, w = program[0].shape + code_list.append( + f""" + +int main(int argc, char**__argv){'{'} + + prngWeights.SetSeed(osuCrypto::toBlock(0, 0)); + prngStr.SetSeed(osuCrypto::toBlock(time(NULL))); + + int party = atoi(__argv[1]); + std::string ip = "127.0.0.1"; + + srand(time(NULL)); + + const u64 scale = 0; + + if (party == 0) {'{'} + Net net; + net.init(scale); + std::string weights_file = __argv[2]; + net.load(weights_file); + Tensor input({'{'}{iterate_list([n]+ dims +[c])}{'}'}); + input.input_nchw(scale); + print_dot_graph(net.root); + net.forward(input); + net.activation.print(); + return 0; + {'}'} + +{'}'} + """ + ) + + def llama_pre(code_list, program, scale, mode, bitlength, indent): code_list.append("#include ") code_list.append("#include ") @@ -265,7 +304,7 @@ def prepare_export(program, var_dict, value_info, mode, scale, bitlength, backen # Start CPP program number_of_nodes = 0 - if backend == "CLEARTEXT_LLAMA": + if backend == "CLEARTEXT_LLAMA" or backend == "CLEARTEXT_fp": cleartext_pre(code_list, program, scale, mode, indent) elif backend == "LLAMA": llama_pre(code_list, program, scale, mode, bitlength, indent) @@ -320,6 +359,8 @@ def prepare_export(program, var_dict, value_info, mode, scale, bitlength, backen if backend == "CLEARTEXT_LLAMA": cleartext_post(code_list, program, scale, mode, indent) + elif backend == "CLEARTEXT_fp": + cleartext_fp_post(code_list, program, scale, mode, indent) elif backend == "LLAMA": llama_post(code_list, program, scale, mode, bitlength, indent) diff --git a/OnnxBridge/backend.py b/OnnxBridge/backend.py index 3a6876f4..07628523 100644 --- a/OnnxBridge/backend.py +++ b/OnnxBridge/backend.py @@ -89,7 +89,7 @@ def preprocess_model(cls, model_fname, logging_level, backend): logger.error("Model Not Supported") sys.exit() - if backend in ["CLEARTEXT_LLAMA", "LLAMA"]: + if backend in ["CLEARTEXT_LLAMA", "LLAMA", "CLEARTEXT_fp"]: weights_path = optimizations.dump_model_weights_as_dat( model, model_abs_dir, model_name ) @@ -151,7 +151,7 @@ def is_compatible(cls, model, backend, device: str = "2PC", **kwargs): ] if backend in ["SECFLOAT", "SECFLOAT_CLEARTEXT"]: implemented = implemented_secfloat - elif backend in ["CLEARTEXT_LLAMA", "LLAMA"]: + elif backend in ["CLEARTEXT_LLAMA", "LLAMA", "CLEARTEXT_fp"]: implemented = implemented_sytorch for node in model.graph.node: if node.op_type not in implemented: @@ -217,7 +217,7 @@ def prepare( backend_rep = FzpcBackendRep( program, value_info, var_dict, path, file_name[:-5], backend ) - elif backend in ["CLEARTEXT_LLAMA", "LLAMA"]: + elif backend in ["CLEARTEXT_LLAMA", "LLAMA", "CLEARTEXT_fp"]: backend_rep = SytorchBackendRep( program, value_info, var_dict, path, file_name[:-5] ) diff --git a/OnnxBridge/main.py b/OnnxBridge/main.py index fd92c8bb..ab318620 100644 --- a/OnnxBridge/main.py +++ b/OnnxBridge/main.py @@ -6,7 +6,13 @@ def parse_args(): - backend = ["CLEARTEXT_LLAMA", "LLAMA", "SECFLOAT", "SECFLOAT_CLEARTEXT"] + backend = [ + "CLEARTEXT_LLAMA", + "LLAMA", + "SECFLOAT", + "SECFLOAT_CLEARTEXT", + "CLEARTEXT_fp", + ] parser = argparse.ArgumentParser() parser.add_argument("--path", required=True, type=str, help="Path to the Model.") parser.add_argument( @@ -32,7 +38,7 @@ def parse_args(): ) parser.add_argument( "--generate", - required=any(b in argv for b in [backend[2], backend[3]]), + required=True, type=str, choices=["code", "executable"], default="code", @@ -49,7 +55,7 @@ def main(): mode = "u64" if args.backend == "LLAMA" else "i64" # Export the Model as Secfloat and writes to a cpp file - if args.backend in ["CLEARTEXT_LLAMA", "LLAMA"]: + if args.backend in ["CLEARTEXT_LLAMA", "LLAMA", "CLEARTEXT_fp"]: main_path = os.path.dirname(os.path.abspath(__file__)) file_path = os.path.join(main_path, "LLAMA") backendrep.export_model(mode, args.scale, args.bitlength, args.backend) @@ -65,7 +71,7 @@ def main(): os.system( f"{file_path}/compile_secfloat.sh {args.path[:-5]}_secfloat{ct}.cpp" ) - elif args.backend in ["CLEARTEXT_LLAMA", "LLAMA"]: + elif args.backend in ["CLEARTEXT_LLAMA", "LLAMA", "CLEARTEXT_fp"]: os.system( f"{file_path}/compile_llama.sh {args.path[:-5]}_{args.backend}_{args.scale}.cpp" ) diff --git a/OnnxBridge/tests/conftest.py b/OnnxBridge/tests/conftest.py index 7bffbb9f..5067b98c 100644 --- a/OnnxBridge/tests/conftest.py +++ b/OnnxBridge/tests/conftest.py @@ -9,8 +9,14 @@ def pytest_addoption(parser): parser.addoption( "--backend", action="store", - choices=["CLEARTEXT_LLAMA", "LLAMA", "SECFLOAT", "SECFLOAT_CLEARTEXT"], - help="backend : CLEARTEXT_LLAMA | LLAMA | SECFLOAT | SECFLOAT_CLEARTEXT", + choices=[ + "CLEARTEXT_LLAMA", + "LLAMA", + "SECFLOAT", + "SECFLOAT_CLEARTEXT", + "CLEARTEXT_fp", + ], + help="backend : CLEARTEXT_LLAMA | CLEARTEXT_fp | LLAMA | SECFLOAT | SECFLOAT_CLEARTEXT", required=True, ) parser.addoption( diff --git a/OnnxBridge/tests/utils.py b/OnnxBridge/tests/utils.py index c27a8cbf..edfefe35 100644 --- a/OnnxBridge/tests/utils.py +++ b/OnnxBridge/tests/utils.py @@ -41,6 +41,10 @@ def compile_model(backend): os.system( f"python3 {ezpc_dir}/OnnxBridge/main.py --path model.onnx --generate executable --backend {backend} --scale 15 --bitlength 40 " ) + elif backend == "CLEARTEXT_fp": + os.system( + f"python3 {ezpc_dir}/OnnxBridge/main.py --path model.onnx --generate executable --backend {backend} " + ) elif backend == "SECFLOAT" or backend == "SECFLOAT_CLEARTEXT": os.system( f"python3 {ezpc_dir}/OnnxBridge/main.py --path model.onnx --generate executable --backend {backend} " @@ -62,6 +66,14 @@ def run_backend(backend, input): os.system( f"./model_CLEARTEXT_LLAMA_15 0 model_input_weights.dat < {input} > {raw_output}" ) + elif backend == "CLEARTEXT_fp": + # check if model compiled + assert os.path.exists("model_CLEARTEXT_fp_0") + assert os.path.exists("model_input_weights.dat") + + os.system( + f"./model_CLEARTEXT_fp_0 0 model_input_weights.dat < {input} > {raw_output}" + ) elif backend == "LLAMA": # check if model compiled assert os.path.exists("model_LLAMA_15") diff --git a/sytorch/include/sytorch/backend/default.h b/sytorch/include/sytorch/backend/default.h new file mode 100644 index 00000000..66bbc93c --- /dev/null +++ b/sytorch/include/sytorch/backend/default.h @@ -0,0 +1,38 @@ +#pragma once + +#include +#include + +template +Backend *defaultBackend() +{ + if constexpr (std::is_floating_point::value) + { + return new FloatClearText(); + } + else + { + return new ClearText(); + } +} + +template +inline T type_cast(float val); + +template <> +float type_cast(float val) +{ + return val; +} + +template <> +i64 type_cast(float val) +{ + return (i64)val; +} + +template <> +u64 type_cast(float val) +{ + return (u64(i64(val))); +} \ No newline at end of file diff --git a/sytorch/include/sytorch/backend/float.h b/sytorch/include/sytorch/backend/float.h new file mode 100644 index 00000000..9779bb10 --- /dev/null +++ b/sytorch/include/sytorch/backend/float.h @@ -0,0 +1,51 @@ +#pragma once +#include "backend.h" +#include +#include + +template +class FloatClearText : public Backend +{ +private: +public: + void truncate(T *in, T *out, u64 shift, u64 size, u8 mode); + + template + void fastfor(u64 size, Functor f) + { +#pragma omp parallel for + for (u64 i = 0; i < size; i++) + { + f(i); + } + } + + void matmul(const Tensor2D &a, const Tensor2D &b, Tensor2D &c); + void matmul_triangular(const Tensor2D &a, const Tensor2D &b, Tensor2D &c); + void matmulTransposeA(const Tensor2D &a, const Tensor2D &b, Tensor2D &c); + void matmulTransposeB(const Tensor2D &a, const Tensor2D &b, Tensor2D &c); + + void conv2D(u64 fh, u64 fw, u64 padding, u64 stride, u64 ci, u64 co, const Tensor4D &input, const Tensor2D &filter, Tensor4D &output); + void conv3D(u64 fd, u64 fh, u64 fw, u64 pd, u64 ph, u64 pw, u64 sd, u64 sh, u64 sw, u64 dd, u64 dh, u64 dw, u64 ci, u64 co, const Tensor5D &input, const Tensor2D &filter, Tensor5D &output); + void convTranspose3D(u64 fd, u64 fh, u64 fw, u64 pd, u64 ph, u64 pw, u64 sd, u64 sh, u64 sw, u64 ci, u64 co, const Tensor5D &input, const Tensor2D &filter, Tensor5D &output); + + void relu(const Tensor &in, const Tensor &out, const Tensor &drelu, u64 scale, int mode); + void truncate(T &in, u64 shift); + void div(Tensor &in, T divisor, u64 scale); + void div(T &in, T divisor, u64 scale); + void sumPool2D(u64 ks, u64 padding, u64 stride, const Tensor4D &in, Tensor4D &out); + void avgPool2D(u64 ks, u64 padding, u64 stride, const Tensor4D &in, Tensor4D &out, u64 scale); + void maxPool2D(u64 ks, u64 padding, u64 stride, const Tensor4D &in, Tensor4D &out, Tensor4D &maxIdx, u64 scale, u8 mode); + + void batchNormInference(const Tensor1D &A, const Tensor1D &B, const Tensor &x, Tensor &y, u64 scale); + void add(const std::vector *> &in, Tensor &out); + void gelu(const Tensor &in, const Tensor &out, u64 scale, u64 mode = 0); + void tanh(const Tensor &in, const Tensor &out, u64 scale); + void softmax(Tensor &in, Tensor &out, u64 scale, u64 mode = 0); + void layernorm(const Tensor1D &A, const Tensor1D &B, const Tensor &x, Tensor &y, u64 scale); + void addbias(Tensor &x, const Tensor1D &bias); + void scalarmul(Tensor &x, T scalar, Tensor &y); + void scalardiv(Tensor &x, double scalar, Tensor &y, u64 scale, u64 mode); + void attention_mask(Tensor &x, T scalar, Tensor &y); + void softmax_triangular(Tensor &in, Tensor &out, u64 scale, u64 mode = 0); +}; \ No newline at end of file diff --git a/sytorch/include/sytorch/layers/layers.h b/sytorch/include/sytorch/layers/layers.h index 72f9a181..cff45662 100644 --- a/sytorch/include/sytorch/layers/layers.h +++ b/sytorch/include/sytorch/layers/layers.h @@ -1,7 +1,7 @@ #pragma once #include #include -#include +#include #include template @@ -26,7 +26,7 @@ class Layer { LayerGraphNode *node = nullptr; Layer(const std::string &name) : activation({0}), name(name) { - backend = new ClearText(); + backend = defaultBackend(); } virtual void _initScale(u64 scale) {}; diff --git a/sytorch/include/sytorch/module.h b/sytorch/include/sytorch/module.h index fe26a08a..bce4091d 100644 --- a/sytorch/include/sytorch/module.h +++ b/sytorch/include/sytorch/module.h @@ -3,6 +3,7 @@ #include #include #include +#include template class SytorchModule { @@ -164,15 +165,15 @@ class SytorchModule { auto meanPtr = floatWeights + wIdx + 2 * channel; auto varPtr = floatWeights + wIdx + 3 * channel; for (int j = 0; j < channel; ++j) { - bn->A(j) = i64((gammaPtr[j] / std::sqrt(varPtr[j])) * (1LL << scale)); - bn->B(j) = i64((betaPtr[j] - gammaPtr[j] * meanPtr[j] / std::sqrt(varPtr[j])) * (1LL << (2 * scale))); + bn->A(j) = type_cast((gammaPtr[j] / std::sqrt(varPtr[j])) * (1LL << scale)); + bn->B(j) = type_cast((betaPtr[j] - gammaPtr[j] * meanPtr[j] / std::sqrt(varPtr[j])) * (1LL << (2 * scale))); } wIdx += 4 * channel; } else { auto weights = layer->getweights(); for (u64 j = 0; j < weights.size; j++) { - weights.data[j] = i64(floatWeights[wIdx + j] * (1LL << scale)); + weights.data[j] = type_cast(floatWeights[wIdx + j] * (1LL << scale)); } wIdx += weights.size; @@ -181,7 +182,7 @@ class SytorchModule { if (layer->useBias) { for (u64 j = 0; j < bias.size; ++j) { - bias.data[j] = i64(floatWeights[wIdx + j] * (1LL << (2*scale))); + bias.data[j] = type_cast(floatWeights[wIdx + j] * (float)(1LL << (2 * scale))); } wIdx += bias.size; diff --git a/sytorch/include/sytorch/tensor.h b/sytorch/include/sytorch/tensor.h index fc2ac7bd..389651d1 100644 --- a/sytorch/include/sytorch/tensor.h +++ b/sytorch/include/sytorch/tensor.h @@ -188,7 +188,7 @@ class Tensor { { double d; std::cin >> d; - data[i] = (i64)(d * (1LL << scale)); + data[i] = (T)(d * (1LL << scale)); } } @@ -209,9 +209,9 @@ class Tensor { u64 curr_rest = i % rest_size; u64 new_idx = curr_batch * (num_channel * rest_size) + curr_rest * num_channel + curr_channel; #ifdef Do_Masking - data[new_idx] = (i64)d; + data[new_idx] = (T)d; #else - data[new_idx] = (i64)(d * (1LL << scale)); + data[new_idx] = (T)(d * (1LL << scale)); #endif } } @@ -283,7 +283,7 @@ class Tensor { { for (int m = 0; m < d5; m++) { - this->data[i * d2 * d3 * d4 * d5 + j * d3 * d4 * d5 + k * d4 * d5 + l * d5 + m] = (i64)(arr[i][j][k][l][m] * scale); + this->data[i * d2 * d3 * d4 * d5 + j * d3 * d4 * d5 + k * d4 * d5 + l * d5 + m] = (T)(arr[i][j][k][l][m] * scale); } } } @@ -301,7 +301,7 @@ class Tensor { file.close(); for(u64 i = 0; i < size(); ++i) { - data[i] = (i64)(floatInput[i] * (1LL << scale)); + data[i] = (T)(floatInput[i] * (1LL << scale)); } delete[] floatInput; } diff --git a/sytorch/src/sytorch/backend/float.cpp b/sytorch/src/sytorch/backend/float.cpp new file mode 100644 index 00000000..24ab68ce --- /dev/null +++ b/sytorch/src/sytorch/backend/float.cpp @@ -0,0 +1,399 @@ + +#include +#include + +template +void FloatClearText::matmul(const Tensor2D &a, const Tensor2D &b, Tensor2D &c) +{ + assert(a.d2 == b.d1); + assert(c.d1 == a.d1); + assert(c.d2 == b.d2); + Eigen::Map> eA(a.data, a.d1, a.d2); + Eigen::Map> eB(b.data, b.d1, b.d2); + Eigen::Map> eC(c.data, c.d1, c.d2); + eC = eA * eB; +} + +template +void FloatClearText::matmul_triangular(const Tensor2D &a, const Tensor2D &b, Tensor2D &c) +{ + assert(a.d2 == b.d1); + assert(c.d1 == a.d1); + assert(c.d2 == b.d2); + Eigen::Map> eA(a.data, a.d1, a.d2); + Eigen::Map> eB(b.data, b.d1, b.d2); + Eigen::Map> eC(c.data, c.d1, c.d2); + eC = (eA * eB).template triangularView(); +} + +template +void FloatClearText::matmulTransposeA(const Tensor2D &a, const Tensor2D &b, Tensor2D &c) +{ + assert(a.d1 == b.d1); + assert(c.d1 == a.d2); + assert(c.d2 == b.d2); + Eigen::Map> eA(a.data, a.d2, a.d1); + Eigen::Map> eB(b.data, b.d1, b.d2); + Eigen::Map> eC(c.data, c.d1, c.d2); + eC = eA * eB; +} + +template +void FloatClearText::matmulTransposeB(const Tensor2D &a, const Tensor2D &b, Tensor2D &c) +{ + assert(a.d2 == b.d2); + assert(c.d1 == a.d1); + assert(c.d2 == b.d1); + Eigen::Map> eA(a.data, a.d1, a.d2); + Eigen::Map> eB(b.data, b.d2, b.d1); + Eigen::Map> eC(c.data, c.d1, c.d2); + eC = eA * eB; +} + +template +void FloatClearText::conv2D(u64 fh, u64 fw, u64 padding, u64 stride, u64 ci, u64 co, const Tensor4D &input, const Tensor2D &filter, Tensor4D &output) +{ + assert(input.d4 == ci); + assert(filter.d1 == co); + assert(filter.d2 == fh * fw * ci); + u64 newH = (((input.d2 + 2 * padding - fh) / stride) + 1); + u64 newW = (((input.d3 + 2 * padding - fw) / stride) + 1); + assert(output.d1 == input.d1); + assert(output.d2 == newH); + assert(output.d3 == newW); + assert(output.d4 == co); + Tensor2D reshapedInput = reshapeInputTransposed(input, padding, stride, fh, fw); + Tensor2D tempOutput(filter.d1, reshapedInput.d1); + matmulTransposeB(filter, reshapedInput, tempOutput); + reshapeOutput(tempOutput, input.d1, (((input.d2 + 2 * padding - fh) / stride) + 1), (((input.d3 + 2 * padding - fw) / stride) + 1), co, output); +} + +template +void FloatClearText::conv3D(u64 fd, u64 fh, u64 fw, u64 pd, u64 ph, u64 pw, u64 sd, u64 sh, u64 sw, u64 dd, u64 dh, u64 dw, u64 ci, u64 co, const Tensor5D &input, const Tensor2D &filter, Tensor5D &output) +{ + assert(input.d5 == ci); + assert(filter.d1 == co); + assert(filter.d2 == fd * fh * fw * ci); + always_assert(dd == 1); + always_assert(dh == 1); + always_assert(dw == 1); + u64 newD = (((input.d2 + 2 * pd - fd - (fd - 1) * (dd - 1)) / sd) + 1); + u64 newH = (((input.d3 + 2 * ph - fh - (fh - 1) * (dh - 1)) / sh) + 1); + u64 newW = (((input.d4 + 2 * pw - fw - (fw - 1) * (dw - 1)) / sw) + 1); + assert(output.d1 == input.d1); + assert(output.d2 == newD); + assert(output.d3 == newH); + assert(output.d4 == newW); + assert(output.d5 == co); + + Tensor2D reshapedInput = reshapeInputTransposed3d(input, pd, ph, pw, sd, sh, sw, fd, fh, fw); + Tensor2D tempOutput(filter.d1, reshapedInput.d1); + matmulTransposeB(filter, reshapedInput, tempOutput); + reshapeOutput3d(tempOutput, input.d1, newD, newH, newW, co, output); +} + +template +void FloatClearText::convTranspose3D(u64 fd, u64 fh, u64 fw, u64 pd, u64 ph, u64 pw, u64 sd, u64 sh, u64 sw, u64 ci, u64 co, const Tensor5D &input, const Tensor2D &filter, Tensor5D &output) +{ + assert(input.d5 == ci); + assert(filter.d1 == co); + assert(filter.d2 == fd * fh * fw * ci); + u64 newD = (((input.d2 - 1) * sd + fd - 2 * pd)); + u64 newH = (((input.d3 - 1) * sh + fh - 2 * ph)); + u64 newW = (((input.d4 - 1) * sw + fw - 2 * pw)); + assert(output.d1 == input.d1); + assert(output.d2 == newD); + assert(output.d3 == newH); + assert(output.d4 == newW); + assert(output.d5 == co); + + convTranspose3dLoop(input.d1, input.d2, input.d3, input.d4, input.d5, fd, fh, fw, co, + pd, pd, ph, ph, pw, pw, sd, sh, sw, + output.d2, output.d3, output.d4, input.data, filter.data, output.data); +} + +template +void FloatClearText::relu(const Tensor &in, const Tensor &out, const Tensor &drelu, u64 scale, int mode) +{ + assert(in.is_same_shape(out)); + assert(in.is_same_shape(drelu)); + fastfor(in.size(), [&](u64 i) + { + drelu.data[i] = (T)(in.data[i] > 0); + out.data[i] = (in.data[i] > 0) ? in.data[i] : 0; }); +} + +template +void FloatClearText::truncate(T *in, T *out, u64 shift, u64 size, u8 mode) +{ + always_assert(shift == 0); + fastfor(size, [&](u64 i) + { out[i] = in[i]; }); +} + +template +void FloatClearText::div(Tensor &in, T divisor, u64 scale) +{ + always_assert(scale == 0); + + fastfor(in.size(), [&](u64 i) + { in.data[i] /= divisor; }); +} + +template +void FloatClearText::div(T &in, T divisor, u64 scale) +{ + in = in / divisor; +} + +template +void FloatClearText::sumPool2D(u64 ks, u64 padding, u64 stride, const Tensor4D &in, Tensor4D &out) +{ + assert(in.d1 == out.d1); + assert(in.d4 == out.d4); + u64 newH = (in.d2 + 2 * padding - ks) / stride + 1; + u64 newW = (in.d3 + 2 * padding - ks) / stride + 1; + assert(out.d2 == newH); + assert(out.d3 == newW); + fastfor(in.d1, [&](int i) + { + for(int j = 0; j < newH; j++) { + for(int k = 0; k < newW; k++) { + for(int l = 0; l < in.d4; l++) { + T sum = 0; + for(int m = 0; m < ks; m++) { + for(int n = 0; n < ks; n++) { + sum += in(i, j*stride+m, k*stride+n, l); + } + } + out(i, j, k, l) = sum; + } + } + } }); +} + +template +void FloatClearText::avgPool2D(u64 ks, u64 padding, u64 stride, const Tensor4D &in, Tensor4D &out, u64 scale) +{ + sumPool2D(ks, padding, stride, in, out); + auto out_nd = out.as_nd(); + div(out_nd, (T)(ks * ks), scale); +} + +template +void FloatClearText::maxPool2D(u64 ks, u64 padding, u64 stride, const Tensor4D &in, Tensor4D &out, Tensor4D &maxIdx, u64 scale, u8 mode) +{ + assert(in.d1 == out.d1); + assert(in.d4 == out.d4); + u64 newH = (in.d2 + 2 * padding - ks) / stride + 1; + u64 newW = (in.d3 + 2 * padding - ks) / stride + 1; + assert(out.d2 == newH); + assert(out.d3 == newW); + fastfor(in.d1, [&](int i) + { + for(int j = 0; j < newH; j++) { + for(int k = 0; k < newW; k++) { + for(int l = 0; l < in.d4; l++) { + T max = std::numeric_limits::lowest(); + u64 maxIdxI = 0; + u64 maxIdxJ = 0; + for(int m = 0; m < ks; m++) { + for(int n = 0; n < ks; n++) { + auto h2 = j*stride+m-padding; + auto w2 = k*stride+n-padding; + T val = 0; + if (h2 < in.d2 && w2 < in.d3 && h2 >= 0 && w2 >= 0) + val = in(i, h2, w2, l); + if(val > max) { + max = val; + maxIdxI = m; + maxIdxJ = n; + } + } + } + out(i, j, k, l) = max; + maxIdx(i, j, k, l) = maxIdxI * ks + maxIdxJ; + } + } + } }); +} + +template +void FloatClearText::batchNormInference(const Tensor1D &A, const Tensor1D &B, const Tensor &x, Tensor &y, u64 scale) +{ + assert(A.d1 == B.d1); + assert(A.d1 == x.shape.back()); + assert(x.is_same_shape(y)); + u64 channels = x.shape.back(); + + fastfor(x.size(), [&](u64 i) + { y.data[i] = x.data[i] * A(i % channels) + B(i % channels); }); +} + +template +void FloatClearText::add(const std::vector *> &in, Tensor &out) +{ + always_assert(in.size() > 0); + always_assert(out.size() == in[0]->size()); + for (int i = 0; i < in.size(); i++) + { + always_assert(out.size() == in[i]->size()); + } + fastfor(out.size(), [&](int i) + { + T sum = 0; + for (int j = 0; j < in.size(); j++) { + sum += in[j]->data[i]; + } + out.data[i] = sum; }); +} + +template +void FloatClearText::gelu(const Tensor &in, const Tensor &out, u64 scale, u64 mode) +{ + always_assert(scale == 0); + fastfor(in.size(), [&](u64 i) + { + T x = in.data[i]; + out.data[i] = 0.5 * x * (1 + erf(x / sqrt(2.0))); }); +} + +template +void FloatClearText::softmax(Tensor &_in, Tensor &_out, u64 scale, u64 mode) +{ + always_assert(_in.shape.size() == 2); + always_assert(_out.shape.size() == 2); + always_assert(_in.shape[0] == _out.shape[0]); + always_assert(_in.shape[1] == _out.shape[1]); + always_assert((scale == 0)); + + auto in = _in.as_2d(); + auto out = _out.as_2d(); + + auto batchSize = in.d1; + auto numClasses = in.d2; + for (int b = 0; b < batchSize; ++b) + { + T max = in(b, 0); + for (u64 j = 1; j < numClasses; ++j) + { + if (in(b, j) > max) + { + max = in(b, j); + } + } + + double den = 0.0; + double exps[numClasses]; + for (u64 j = 0; j < numClasses; ++j) + { + double x = in(b, j) - max; + exps[j] = std::exp(x); + den += exps[j]; + } + + for (u64 j = 0; j < numClasses; ++j) + { + out(b, j) = exps[j] / den; + } + } +} + +template +void FloatClearText::softmax_triangular(Tensor &_in, Tensor &_out, u64 scale, u64 mode) +{ + Tensor y(_in.shape); + T scalar = 10000.0 * (1LL << scale); + attention_mask(_in, scalar, y); + softmax(y, _out, scale); +} + +template +void FloatClearText::layernorm(const Tensor1D &A, const Tensor1D &B, const Tensor &x, Tensor &y, u64 scale) +{ + always_assert(A.d1 == B.d1); + always_assert(A.d1 == x.shape.back()); + always_assert(x.is_same_shape(y)); + + u64 channels = x.shape.back(); + + fastfor(x.size() / channels, [&](u64 i) + { + T mean = 0; + T var = 0; + for (u64 j = 0; j < channels; j++) { + mean += x.data[i * channels + j]; + } + mean = mean / T(channels); + + for (u64 j = 0; j < channels; j++) { + var += (x.data[i * channels + j] - mean) * (x.data[i * channels + j] - mean); + } + + var = var / T(channels); + + for (u64 j = 0; j < channels; j++) { + y.data[i * channels + j] = (x.data[i * channels + j] - mean) / std::sqrt(var); + } }); + + fastfor(x.size(), [&](u64 i) + { y.data[i] = y.data[i] * A(i % channels) + B(i % channels); }); +} + +template +void FloatClearText::addbias(Tensor &x, const Tensor1D &bias) +{ + always_assert(x.shape.back() == bias.d1); + fastfor(x.size(), [&](u64 i) + { x.data[i] += bias(i % bias.d1); }); +} + +template +void FloatClearText::scalarmul(Tensor &x, T scalar, Tensor &y) +{ + always_assert(x.is_same_shape(y)); + fastfor(x.size(), [&](u64 i) + { y.data[i] = x.data[i] * scalar; }); +} + +template +void FloatClearText::attention_mask(Tensor &x, T scalar, Tensor &y) +{ + always_assert(x.is_same_shape(y)); + always_assert(x.shape.size() == 2); + always_assert(x.shape[0] == x.shape[1]); + + u64 n_seq = x.shape[0]; + auto y_2d = y.as_2d(); + auto x_2d = x.as_2d(); + + for (u64 j = 0; j < n_seq; ++j) + { + for (u64 k = 0; k < j + 1; ++k) + { + y_2d(j, k) = x_2d(j, k); + } + for (u64 k = j + 1; k < n_seq; ++k) + { + y_2d(j, k) = x_2d(j, k) - scalar; + } + } +} + +template +void FloatClearText::tanh(const Tensor &in, const Tensor &out, u64 scale) +{ + fastfor(in.size(), [&](u64 i) + { out.data[i] = std::tanh(in.data[i]); }); +} + +template +void FloatClearText::scalardiv(Tensor &in, double scalar, Tensor &out, u64 scale, u64 mode) +{ + always_assert(scale == 0); + fastfor(in.size(), [&](u64 i) + { out.data[i] = in.data[i] / scalar; }); +} + +template class FloatClearText; +template class FloatClearText; \ No newline at end of file