From fc2f88d596db064328d9c7d4864d50460cb78962 Mon Sep 17 00:00:00 2001 From: Saksham Gupta Date: Thu, 21 Sep 2023 12:40:28 +0530 Subject: [PATCH] adds node truncate, leakyRelu, ConvTranspose2d and bug fixes (#188) * added transpose for 4D,5D tensors with perm * leakyrelu added in ct_sytorch and OnnxBridge * leaky relu added to LLAMA and ct * Adds Cleartext float via sytorch (#186) (#187) from master * float added from kanav-gpt in sytorch * cleartext_fp added in Onnxbridge * i64->T in tensor.h * debug statement * typecast corrected in module.h * removing debug statement * added testing for sytorch float ct * rename tests * type_cast modified and leakyrelu added to float * convTranspose2d added in cleartext * added custom test script using pytest * adding convTranspose2d-LLAMA correctness not tested * bug fix size=1 input_prng.cpp * sytorch conv3d bug fix- dim d != 1 * addition in test file * printing bug resolved --- OnnxBridge/LLAMA/sytorchBackendRep.py | 11 +- OnnxBridge/LLAMA/sytorch_func_calls.py | 37 +++- OnnxBridge/backend.py | 2 + OnnxBridge/tests/conftest.py | 32 ++- OnnxBridge/tests/custom_model_test.py | 43 ++++ OnnxBridge/tests/utils.py | 3 + OnnxBridge/utils/onnx_nodes.py | 4 +- sytorch/ext/llama/api.cpp | 88 +++++++- sytorch/ext/llama/conv.cpp | 119 +++++++++++ sytorch/ext/llama/conv.h | 43 +++- sytorch/ext/llama/include/llama/api.h | 19 ++ sytorch/ext/llama/include/llama/utils.h | 20 ++ sytorch/ext/llama/src/llama/input_prng.cpp | 9 +- sytorch/ext/llama/src/llama/utils.cpp | 73 ++++++- sytorch/include/sytorch/backend/backend.h | 4 + sytorch/include/sytorch/backend/cleartext.h | 2 + sytorch/include/sytorch/backend/default.h | 20 -- sytorch/include/sytorch/backend/float.h | 3 + sytorch/include/sytorch/backend/llama_base.h | 16 ++ .../include/sytorch/backend/llama_extended.h | 36 ++++ sytorch/include/sytorch/layers/layers.h | 194 +++++++++++++++++- sytorch/include/sytorch/tensor.h | 32 ++- sytorch/include/sytorch/utils.h | 106 ++++++++++ sytorch/src/sytorch/backend/cleartext.cpp | 90 ++++++-- sytorch/src/sytorch/backend/float.cpp | 49 +++++ 25 files changed, 987 insertions(+), 68 deletions(-) create mode 100644 OnnxBridge/tests/custom_model_test.py diff --git a/OnnxBridge/LLAMA/sytorchBackendRep.py b/OnnxBridge/LLAMA/sytorchBackendRep.py index a2cd0093..f5335a6a 100644 --- a/OnnxBridge/LLAMA/sytorchBackendRep.py +++ b/OnnxBridge/LLAMA/sytorchBackendRep.py @@ -12,6 +12,7 @@ def func_call(node, value_info): """ func_map = { "Relu": "ReLU", + "LeakyRelu": "LeakyReLU", "Conv": f"{'Conv3D' if len(value_info[node.inputs[0]][1]) == 5 else 'Conv2D'}", "MaxPool": "MaxPool2D", "Flatten": "Flatten", @@ -21,7 +22,8 @@ def func_call(node, value_info): "AveragePool": "AvgPool2D", "GlobalAveragePool": "GlobalAvgPool2D", "Add": "add", - "ConvTranspose": "ConvTranspose3D", + "ConvTranspose": f"{'ConvTranspose3D' if len(value_info[node.inputs[0]][1]) == 5 else 'ConvTranspose2D'}", + "Transpose": "Transpose", } return func_map[node.op_type] @@ -56,6 +58,7 @@ def inputs_to_take(node): tmp_dict = { "Conv": 1, "Relu": 1, + "LeakyRelu": 1, "MaxPool": 1, "Gemm": 1, "Flatten": 1, @@ -65,6 +68,7 @@ def inputs_to_take(node): "BatchNormalization": 1, "GlobalAveragePool": 1, "ConvTranspose": 1, + "Transpose": 1, } return tmp_dict[node] @@ -130,7 +134,8 @@ def cleartext_post(code_list, program, scale, mode, indent): input.input_nchw(scale); print_dot_graph(net.root); net.forward(input); - print(net.activation, scale, 64); + net.activation.printshape(); + print_nchw(net.activation, scale, 64); return 0; {'}'} @@ -300,7 +305,7 @@ def llama_post(code_list, program, scale, mode, bitlength, indent): auto &output = net.activation; llama->outputA(output); if (party == CLIENT) {'{'} - print(output, scale, LlamaConfig::bitlength); + print_nchw(output, scale, LlamaConfig::bitlength); {'}'} llama->finalize(); {'}'} diff --git a/OnnxBridge/LLAMA/sytorch_func_calls.py b/OnnxBridge/LLAMA/sytorch_func_calls.py index e71d20e6..9ebe4f32 100644 --- a/OnnxBridge/LLAMA/sytorch_func_calls.py +++ b/OnnxBridge/LLAMA/sytorch_func_calls.py @@ -93,6 +93,15 @@ def Relu(cls, attributes, inputs, outputs, value_info, var_dict, mode, indent): logger.debug("Inside Relu function call.") return str(f"{' ' * indent}new ReLU();") + @classmethod + def LeakyRelu(cls, attributes, inputs, outputs, value_info, var_dict, mode, indent): + logger.debug("Inside LeakyRelu function call.") + if "alpha" in attributes.keys(): + alpha = attributes["alpha"] + else: + alpha = 0.01 + return str(f"{' ' * indent}new LeakyReLU({alpha});") + @classmethod def BatchNormalization( cls, attributes, inputs, outputs, value_info, var_dict, mode, indent @@ -174,10 +183,27 @@ def ConvTranspose( cls, attributes, inputs, outputs, value_info, var_dict, mode, indent ): logger.debug("Inside ConvTranspose function call.") - pads = get_padding_3d(attributes, inputs, outputs, value_info, var_dict) spatial_size = len(value_info[inputs[0]][1]) - 2 - if spatial_size == 3: + if spatial_size == 2: + assert len(inputs) == 2 or len(inputs) == 3 + pads = get_padding(attributes, inputs, outputs, value_info, var_dict) + assert len(attributes["strides"]) == 2 + assert value_info[inputs[1]][1][2:] == tuple(attributes["kernel_shape"]) + CI = value_info[inputs[0]][1][1] + CO = value_info[outputs[0]][1][1] + filterShape = value_info[inputs[1]][1] + pad = pads[0] + strides = attributes["strides"] + dilations = get_dilation(attributes, inputs, outputs, value_info, var_dict) + isBias = ", true" if len(inputs) == 3 else "" + return str( + f"{' ' * indent}new ConvTranspose2D(" + f"{CI}, {CO}, {'{'}{iterate_list(filterShape[2:])}{'}'}, {'{'}{iterate_list(pads)}{'}'}, {'{'}{iterate_list(strides)}{'}'},{'{'}{iterate_list(dilations)}{'}'}{isBias}" + f");" + ) + elif spatial_size == 3: assert len(inputs) == 2 or len(inputs) == 3 + pads = get_padding_3d(attributes, inputs, outputs, value_info, var_dict) assert len(attributes["strides"]) == 3 assert value_info[inputs[1]][1][2:] == tuple(attributes["kernel_shape"]) CI = value_info[inputs[0]][1][1] @@ -232,6 +258,13 @@ def Reshape(cls, attributes, inputs, outputs, value_info, var_dict, mode, indent return str(f"{' ' * indent}new Reshape();") # todo : check format + @classmethod + def Transpose(cls, attributes, inputs, outputs, value_info, var_dict, mode, indent): + logger.debug("Inside Transpose function call.") + return str( + f"{' ' * indent}new Transpose( {'{'}{iterate_list(attributes['perm'])}{'}'});" + ) + @classmethod def Gemm(cls, attributes, inputs, outputs, value_info, var_dict, mode, indent): logger.debug("Inside Gemm function call.") diff --git a/OnnxBridge/backend.py b/OnnxBridge/backend.py index 07628523..9b602042 100644 --- a/OnnxBridge/backend.py +++ b/OnnxBridge/backend.py @@ -122,6 +122,7 @@ def is_compatible(cls, model, backend, device: str = "2PC", **kwargs): not_supported = [] implemented_sytorch = [ "Relu", + "LeakyRelu", "Softmax", "Conv", "MaxPool", @@ -133,6 +134,7 @@ def is_compatible(cls, model, backend, device: str = "2PC", **kwargs): "GlobalAveragePool", "Add", "ConvTranspose", + "Transpose", ] implemented_secfloat = [ "Relu", diff --git a/OnnxBridge/tests/conftest.py b/OnnxBridge/tests/conftest.py index 5067b98c..0d74279c 100644 --- a/OnnxBridge/tests/conftest.py +++ b/OnnxBridge/tests/conftest.py @@ -26,6 +26,18 @@ def pytest_addoption(parser): help="batch size", required=False, ) + parser.addoption( + "--model", + action="store", + help="absolute mdel path", + required=False, + ) + parser.addoption( + "--input_name", + action="store", + help="absolute input_name path", + required=False, + ) @pytest.fixture(scope="session") @@ -34,6 +46,18 @@ def backend(request): return opt +@pytest.fixture(scope="session") +def model(request): + opt = request.config.getoption("--model") + return opt + + +@pytest.fixture(scope="session") +def input_name(request): + opt = request.config.getoption("--input_name") + return opt + + @pytest.fixture(scope="session") def batch_size(request): opt = request.config.getoption("--batch_size") @@ -79,8 +103,12 @@ def pytest_runtest_makereport(item, call): @pytest.fixture def test_dir(request, test_env): print("\nRequest node: ", request.node.name) - test_name_list = request.node.name.split("[") - parameter_name = test_name_list[1].split("]")[0] + if "[" in request.node.name: + test_name_list = request.node.name.split("[") + parameter_name = test_name_list[1].split("]")[0] + else: + test_name_list = request.node.name.split("_") + parameter_name = "custom" full_test_name = test_name_list[0] + "_" + parameter_name test_name = full_test_name[len("test_") :] main_test_dir = test_env["test_dir"] diff --git a/OnnxBridge/tests/custom_model_test.py b/OnnxBridge/tests/custom_model_test.py new file mode 100644 index 00000000..b01d57b1 --- /dev/null +++ b/OnnxBridge/tests/custom_model_test.py @@ -0,0 +1,43 @@ +import pytest +import os +from utils import ( + run_onnx, + compile_model, + run_backend, + compare_output, +) + +# Get the directory path where the current script is located +script_directory = os.path.dirname(os.path.abspath(__file__)) +ezpc_dir = os.path.join(script_directory, "..", "..") + + +def test_custom_model(test_dir, backend, model, input_name): + """ + Usage: + pytest path/custom_model_test.py -s --backend CLEARTEXT_LLAMA --model /home/saksham/EzPC/OnnxBridge/nnUnet/optimized_fabiansPreActUnet.onnx --input_name /home/saksham/EzPC/OnnxBridge/nnUnet/inputs/2d_input + """ + os.chdir(test_dir) + + # model is absolute path to model.onnx + # input is absolute path to input1.j + # download the model & data & preprocessing_file + os.system(f"cp {model} model.onnx") + os.system(f"cp {input_name}.inp input1.inp") + os.system(f"cp {input_name}.npy input1.npy") + + # preprocessed input is directly copied + + # run the model with OnnxRuntime + run_onnx("input1.npy") + + # compile the model with backend + compile_model(backend) + + # run the model with backend + run_backend(backend, "input1.inp") + + # compare the output + compare_output() + + os.chdir("../..") diff --git a/OnnxBridge/tests/utils.py b/OnnxBridge/tests/utils.py index edfefe35..beebe7a5 100644 --- a/OnnxBridge/tests/utils.py +++ b/OnnxBridge/tests/utils.py @@ -122,6 +122,9 @@ def compare_output(): arr1 = np.load("output.npy", allow_pickle=True).flatten() arr2 = np.load("onnx_output/expected.npy", allow_pickle=True).flatten() + print("Secure Output Shape: " + str(arr1.shape)) + print("Expected Output Shape: " + str(arr2.shape)) + matching_prec = -1 for prec in range(1, 10): try: diff --git a/OnnxBridge/utils/onnx_nodes.py b/OnnxBridge/utils/onnx_nodes.py index 5842482f..31754faf 100644 --- a/OnnxBridge/utils/onnx_nodes.py +++ b/OnnxBridge/utils/onnx_nodes.py @@ -173,8 +173,8 @@ def ConvTranspose(cls, node): @classmethod def LeakyRelu(cls, node): - if "alpha" not in node.attributes: - node.attributes["alpha"] = 0.01 + if "alpha" not in node.attrs: + node.attrs["alpha"] = 0.01 @classmethod def Tanh(cls, node): diff --git a/sytorch/ext/llama/api.cpp b/sytorch/ext/llama/api.cpp index 1e4fa2cf..97c9d8b9 100644 --- a/sytorch/ext/llama/api.cpp +++ b/sytorch/ext/llama/api.cpp @@ -1510,4 +1510,90 @@ void ConvTranspose3DWrapper(int64_t N, std::cerr << ">> ConvTranspose3D - End" << "\n"; -} \ No newline at end of file +} + +void ConvTranspose2DWrapper(int64_t N, + int64_t H, + int64_t W, + int64_t CI, + int64_t FH, + int64_t FW, + int64_t CO, + int64_t zPadHLeft, + int64_t zPadHRight, + int64_t zPadWLeft, + int64_t zPadWRight, + int64_t strideH, + int64_t strideW, + int64_t outH, + int64_t outW, + GroupElement *inputArr, + GroupElement *filterArr, + GroupElement *outArr) +{ + std::cerr << ">> ConvTranspose2D - Start" << std::endl; + always_assert(outH == (H - 1) * strideH - zPadHLeft - zPadHRight + FH); + always_assert(outW == (W - 1) * strideW - zPadWLeft - zPadWRight + FW); + + if (party == DEALER) + { + auto local_start = std::chrono::high_resolution_clock::now(); + + // not good for in place operations + for (int i = 0; i < N * outH * outW * CO; ++i) + { + outArr[i] = random_ge(bitlength); + } + + auto keys = KeyGenConvTranspose2D(bitlength, N, H, W, CI, FH, FW, CO, zPadHLeft, zPadHRight, zPadWLeft, zPadWRight, strideH, strideW, outH, outW, inputArr, filterArr, outArr); + + auto local_end = std::chrono::high_resolution_clock::now(); + + client->send_triple_key(keys.second); + freeTripleKey(keys.second); + auto local_time_taken = std::chrono::duration_cast(local_end - + local_start) + .count(); + dealerMicroseconds += local_time_taken; + std::cerr << " Dealer Time = " << local_time_taken / 1000.0 << " milliseconds\n"; + } + else + { + + auto keyread_start = std::chrono::high_resolution_clock::now(); + auto key = dealer->recv_triple_key(bitlength, N * H * W * CI, CI * FH * FW * CO, N * outH * outW * CO); + auto keyread_end = std::chrono::high_resolution_clock::now(); + auto keyread_time_taken = std::chrono::duration_cast(keyread_end - keyread_start).count(); + + peer->sync(); + + auto local_start = std::chrono::high_resolution_clock::now(); + + EvalConvTranspose2D(party, key, N, H, W, CI, FH, FW, CO, + zPadHLeft, zPadHRight, zPadWLeft, zPadWRight, strideH, strideW, outH, outW, inputArr, filterArr, outArr); + + auto t1 = std::chrono::high_resolution_clock::now(); + uint64_t onlineComm0 = peer->bytesReceived + peer->bytesSent; + + reconstruct(N * outH * outW * CO, outArr, bitlength); + + uint64_t onlineComm1 = peer->bytesReceived + peer->bytesSent; + convOnlineComm += (onlineComm1 - onlineComm0); + auto local_end = std::chrono::high_resolution_clock::now(); + + freeTripleKey(key); + auto compute_time = std::chrono::duration_cast(t1 - local_start).count(); + auto reconstruct_time = std::chrono::duration_cast(local_end - t1).count(); + + convEvalMicroseconds += (reconstruct_time + compute_time); + evalMicroseconds += (reconstruct_time + compute_time); + std::cerr << " Key Read Time = " << keyread_time_taken << " milliseconds\n"; + std::cerr << " Compute Time = " << compute_time / 1000.0 << " milliseconds\n"; + std::cerr << " Reconstruct Time = " << reconstruct_time / 1000.0 << " milliseconds\n"; + std::cerr << " Online Time = " << (reconstruct_time + compute_time) / 1000.0 << " milliseconds\n"; + std::cerr << " Online Comm = " << (onlineComm1 - onlineComm0) << " bytes\n"; + } + + std::cerr << ">> ConvTranspose2D - End" << std::endl; +} + diff --git a/sytorch/ext/llama/conv.cpp b/sytorch/ext/llama/conv.cpp index a356559f..6c43f283 100644 --- a/sytorch/ext/llama/conv.cpp +++ b/sytorch/ext/llama/conv.cpp @@ -448,3 +448,122 @@ void EvalConvTranspose3D(int party, const TripleKeyPack &key, delete[] temp; } + +std::pair KeyGenConvTranspose2D( + int bw, + int64_t N, + int64_t H, + int64_t W, + int64_t CI, + int64_t FH, + int64_t FW, + int64_t CO, + int64_t zPadHLeft, + int64_t zPadHRight, + int64_t zPadWLeft, + int64_t zPadWRight, + int64_t strideH, + int64_t strideW, + int64_t outH, + int64_t outW, + GroupElement *inputArr, + GroupElement *filterArr, + GroupElement *outArr) +{ + TripleKeyPack k0; + TripleKeyPack k1; + + k1.a = make_array(N, H, W, CI); + k1.b = make_array(FH, FW, CI, CO); + k1.c = make_array(N, outH, outW, CO); + + k1.bw = bw; + k1.na = N * H * W * CI; + k1.nb = FH * FW * CI * CO; + k1.nc = N * outH * outW * CO; + + // Need temp array - matmul cant be done inplace and hence conv3d is not inplace + GroupElement *c = make_array(N, outH, outW, CO); + + ConvTranspose2DLoopInnerClear(N, H, W, CI, FH, FW, CO, + zPadHLeft, zPadHRight, zPadWLeft, zPadWRight, + strideH, strideW, outH, outW, inputArr, filterArr, c); + + MatAdd4(N, outH, outW, CO, c, outArr, c); + + for (int i = 0; i < N * H * W * CI; ++i) + { + auto rin1_split = splitShareCommonPRNG(inputArr[i], bw); + k1.a[i] = rin1_split.second; + } + + for (int i = 0; i < FH * FW * CI * CO; ++i) + { + auto rin2_split = splitShareCommonPRNG(filterArr[i], bw); + k1.b[i] = rin2_split.second; + } + + for (int i = 0; i < N * outH * outW * CO; ++i) + { + auto c_split = splitShareCommonPRNG(c[i], bw); + k1.c[i] = c_split.second; + } + + delete[] c; + + return std::make_pair(k0, k1); +} + +void EvalConvTranspose2D(int party, const TripleKeyPack &key, + int64_t N, + int64_t H, + int64_t W, + int64_t CI, + int64_t FH, + int64_t FW, + int64_t CO, + int64_t zPadHLeft, + int64_t zPadHRight, + int64_t zPadWLeft, + int64_t zPadWRight, + int64_t strideH, + int64_t strideW, + int64_t outH, + int64_t outW, + GroupElement *inputArr, + GroupElement *filterArr, + GroupElement *outArr) +{ + + MatCopy4(N, outH, outW, CO, key.c, outArr); + GroupElement *temp = make_array(N, outH, outW, CO); + + if (party == SERVER) + { + GroupElement *tempFilter = make_array(FH, FW, CI, CO); + + MatSub4(FH, FW, CI, CO, filterArr, key.b, tempFilter); + ConvTranspose2DLoopInnerClear(N, H, W, CI, FH, FW, CO, + zPadHLeft, zPadHRight, zPadWLeft, zPadWRight, + strideH, strideW, outH, outW, + inputArr, tempFilter, temp); + MatAdd4(N, outH, outW, CO, temp, outArr, outArr); + delete[] tempFilter; + } + else + { + ConvTranspose2DLoopInnerClear(N, H, W, CI, FH, FW, CO, + zPadHLeft, zPadHRight, zPadWLeft, zPadWRight, + strideH, strideW, outH, outW, + inputArr, key.b, temp); + MatSub4(N, outH, outW, CO, outArr, temp, outArr); + } + + ConvTranspose2DLoopInnerClear(N, H, W, CI, FH, FW, CO, + zPadHLeft, zPadHRight, zPadWLeft, zPadWRight, + strideH, strideW, outH, outW, + key.a, filterArr, temp); + MatSub4(N, outH, outW, CO, outArr, temp, outArr); + + delete[] temp; +} \ No newline at end of file diff --git a/sytorch/ext/llama/conv.h b/sytorch/ext/llama/conv.h index efa1cc9c..7da72ba2 100644 --- a/sytorch/ext/llama/conv.h +++ b/sytorch/ext/llama/conv.h @@ -105,4 +105,45 @@ void EvalConvTranspose3D(int party, const TripleKeyPack &key, int64_t outW, GroupElement* inputArr, GroupElement* filterArr, - GroupElement* outArr); \ No newline at end of file + GroupElement* outArr); + +std::pair KeyGenConvTranspose2D( + int bw, + int64_t N, + int64_t H, + int64_t W, + int64_t CI, + int64_t FH, + int64_t FW, + int64_t CO, + int64_t zPadHLeft, + int64_t zPadHRight, + int64_t zPadWLeft, + int64_t zPadWRight, + int64_t strideH, + int64_t strideW, + int64_t outH, + int64_t outW, + GroupElement *inputArr, + GroupElement *filterArr, + GroupElement *outArr); + +void EvalConvTranspose2D(int party, const TripleKeyPack &key, + int64_t N, + int64_t H, + int64_t W, + int64_t CI, + int64_t FH, + int64_t FW, + int64_t CO, + int64_t zPadHLeft, + int64_t zPadHRight, + int64_t zPadWLeft, + int64_t zPadWRight, + int64_t strideH, + int64_t strideW, + int64_t outH, + int64_t outW, + GroupElement *inputArr, + GroupElement *filterArr, + GroupElement *outArr); \ No newline at end of file diff --git a/sytorch/ext/llama/include/llama/api.h b/sytorch/ext/llama/include/llama/api.h index e182f162..029e58c4 100644 --- a/sytorch/ext/llama/include/llama/api.h +++ b/sytorch/ext/llama/include/llama/api.h @@ -143,4 +143,23 @@ void ConvTranspose3DWrapper(int64_t N, GroupElement* filterArr, GroupElement* outArr); +void ConvTranspose2DWrapper(int64_t N, + int64_t H, + int64_t W, + int64_t CI, + int64_t FH, + int64_t FW, + int64_t CO, + int64_t zPadHLeft, + int64_t zPadHRight, + int64_t zPadWLeft, + int64_t zPadWRight, + int64_t strideH, + int64_t strideW, + int64_t outH, + int64_t outW, + GroupElement *inputArr, + GroupElement *filterArr, + GroupElement *outArr); + void reconstruct(int32_t size, GroupElement *arr, int bw); diff --git a/sytorch/ext/llama/include/llama/utils.h b/sytorch/ext/llama/include/llama/utils.h index b33d5443..51c9e6e4 100644 --- a/sytorch/ext/llama/include/llama/utils.h +++ b/sytorch/ext/llama/include/llama/utils.h @@ -147,3 +147,23 @@ void ConvTranspose3DLoopInnerClear( GroupElement* inputArr, GroupElement* filterArr, GroupElement* outArr); + +void ConvTranspose2DLoopInnerClear( + int64_t N, + int64_t H, + int64_t W, + int64_t CI, + int64_t FH, + int64_t FW, + int64_t CO, + int64_t zPadHLeft, + int64_t zPadHRight, + int64_t zPadWLeft, + int64_t zPadWRight, + int64_t strideH, + int64_t strideW, + int64_t outH, + int64_t outW, + GroupElement *inputArr, + GroupElement *filterArr, + GroupElement *outArr); \ No newline at end of file diff --git a/sytorch/ext/llama/src/llama/input_prng.cpp b/sytorch/ext/llama/src/llama/input_prng.cpp index 83131adb..3bca16a3 100644 --- a/sytorch/ext/llama/src/llama/input_prng.cpp +++ b/sytorch/ext/llama/src/llama/input_prng.cpp @@ -140,9 +140,12 @@ void input_layer(GroupElement *x, GroupElement *x_mask, int size, int owner) else { uint64_t *tmp = new uint64_t[size]; peer->recv_batched_input(tmp, 1, bitlength); - TIME_THIS_BLOCK_FOR_INPUT_IF( - peer->recv_batched_input(tmp+1, size-1, bitlength); - , true, (owner == SERVER ? accumulatedInputTimeOffline : accumulatedInputTimeOnline)) + if (size > 1) + { + TIME_THIS_BLOCK_FOR_INPUT_IF( + peer->recv_batched_input(tmp + 1, size - 1, bitlength); + , true, (owner == SERVER ? accumulatedInputTimeOffline : accumulatedInputTimeOnline)) + } // todo: parallelize this maybe? for(int i = 0; i < size; ++i) { x[i] = tmp[i]; diff --git a/sytorch/ext/llama/src/llama/utils.cpp b/sytorch/ext/llama/src/llama/utils.cpp index 96ef0ddd..33e2cd13 100644 --- a/sytorch/ext/llama/src/llama/utils.cpp +++ b/sytorch/ext/llama/src/llama/utils.cpp @@ -526,7 +526,7 @@ Conv3DCache allocateConv3DCache(int N, int D, int H, int W, int CI, cache.reshapedFilter = eigenMatrix(reshapedFilterRows, reshapedFilterCols); cache.reshapedInput = eigenMatrix(reshapedIPRows, reshapedIPCols); cache.matmulResult = eigenMatrix(reshapedFilterRows, reshapedIPCols); - cache.temp = make_array(N, newH, newW, CO); + cache.temp = make_array(N, newD, newH, newW, CO); return cache; } @@ -639,4 +639,73 @@ void ConvTranspose3DLoopInnerClear( } } } -} \ No newline at end of file +} + +void ConvTranspose2DLoopInnerClear( + int64_t N, + int64_t H, + int64_t W, + int64_t CI, + int64_t FH, + int64_t FW, + int64_t CO, + int64_t zPadHLeft, + int64_t zPadHRight, + int64_t zPadWLeft, + int64_t zPadWRight, + int64_t strideH, + int64_t strideW, + int64_t outH, + int64_t outW, + GroupElement *inputArr, + GroupElement *filterArr, + GroupElement *outArr) +{ + zPadHLeft = FH - 1 - zPadHLeft; + zPadHRight = FH - 1 - zPadHRight; + zPadWLeft = FW - 1 - zPadWLeft; + zPadWRight = FW - 1 - zPadWRight; + +#pragma omp parallel for collapse(4) + for (int64_t n = 0; n < N; n++) + { + for (int64_t h = 0; h < outH; h++) + { + for (int64_t w = 0; w < outW; w++) + { + for (int64_t co = 0; co < CO; co++) + { + + GroupElement val = 0; + for (int64_t ci = 0; ci < CI; ci++) + { + for (int64_t fh = h; fh < (h + FH); fh++) + { + for (int64_t fw = w; fw < (w + FW); fw++) + { + + int64_t curPosH = ((fh - zPadHLeft) / strideH); + int64_t curPosW = ((fw - zPadWLeft) / strideW); + + if ((curPosH >= 0) && + (curPosW >= 0) && + (curPosH < H) && + (curPosW < W) && + (((fh - zPadHLeft) % strideH) == 0) && + (((fw - zPadWLeft) % strideW) == 0)) + { + int32_t curFilterPosH = FH + h - fh - 1; + int32_t curFilterPosW = FW + w - fw - 1; + val += (Arr4DIdx(inputArr, N, H, W, CI, n, curPosH, curPosW, ci) * Arr4DIdx(filterArr, CO, FH, FW, CI, co, curFilterPosH, curFilterPosW, ci)); + } + } + } + } + Arr4DIdx(outArr, N, outH, outW, CO, n, h, w, co) = val; + // std::cout << "setting element at (" << n << " " << h << " " << w << " " << co << ")" << std::endl; + } + } + } + } +} + diff --git a/sytorch/include/sytorch/backend/backend.h b/sytorch/include/sytorch/backend/backend.h index 97df7e73..abc8ee7b 100644 --- a/sytorch/include/sytorch/backend/backend.h +++ b/sytorch/include/sytorch/backend/backend.h @@ -50,10 +50,14 @@ class Backend { virtual void conv2D(u64 fh, u64 fw, u64 padding, u64 stride, u64 ci, u64 co, const Tensor4D &input, const Tensor2D &filter, Tensor4D &output) NOT_IMPLEMENTED; virtual void conv3D(u64 fd, u64 fh, u64 fw, u64 pd, u64 ph, u64 pw, u64 sd, u64 sh, u64 sw, u64 dd, u64 dh, u64 dw, u64 ci, u64 co, const Tensor5D &input, const Tensor2D &filter, Tensor5D &output) NOT_IMPLEMENTED; virtual void convTranspose3D(u64 fd, u64 fh, u64 fw, u64 pd, u64 ph, u64 pw, u64 sd, u64 sh, u64 sw, u64 ci, u64 co, const Tensor5D &input, const Tensor2D &filter, Tensor5D &output) NOT_IMPLEMENTED; + virtual void convTranspose2D(u64 fh, u64 fw, u64 ph, u64 pw, u64 sh, u64 sw, u64 ci, u64 co, const Tensor4D &input, const Tensor2D &filter, Tensor4D &output) NOT_IMPLEMENTED; // relu API virtual void relu(const Tensor &in, const Tensor &out, const Tensor &drelu, u64 scale, int mode) NOT_IMPLEMENTED; + // leakyrelu API + virtual void leakyRelu(const Tensor &in, const Tensor &out, const Tensor &drelu, u64 scale, int mode, T alpha) NOT_IMPLEMENTED; + // avgpool API virtual void div(Tensor &in, T divisor, u64 scale) NOT_IMPLEMENTED; virtual void sumPool2D(u64 ks, u64 padding, u64 stride, const Tensor4D &in, Tensor4D &out) NOT_IMPLEMENTED; diff --git a/sytorch/include/sytorch/backend/cleartext.h b/sytorch/include/sytorch/backend/cleartext.h index 7133a47b..73a03cb3 100644 --- a/sytorch/include/sytorch/backend/cleartext.h +++ b/sytorch/include/sytorch/backend/cleartext.h @@ -68,8 +68,10 @@ class ClearText : public Backend { void conv2D(u64 fh, u64 fw, u64 padding, u64 stride, u64 ci, u64 co, const Tensor4D &input, const Tensor2D &filter, Tensor4D &output); void conv3D(u64 fd, u64 fh, u64 fw, u64 pd, u64 ph, u64 pw, u64 sd, u64 sh, u64 sw, u64 dd, u64 dh, u64 dw, u64 ci, u64 co, const Tensor5D &input, const Tensor2D &filter, Tensor5D &output); void convTranspose3D(u64 fd, u64 fh, u64 fw, u64 pd, u64 ph, u64 pw, u64 sd, u64 sh, u64 sw, u64 ci, u64 co, const Tensor5D &input, const Tensor2D &filter, Tensor5D &output); + void convTranspose2D(u64 fh, u64 fw, u64 ph, u64 pw, u64 sh, u64 sw, u64 ci, u64 co, const Tensor4D &input, const Tensor2D &filter, Tensor4D &output); void relu(const Tensor &in, const Tensor &out, const Tensor &drelu, u64 scale, int mode); + void leakyRelu(const Tensor &in, const Tensor &out, const Tensor &drelu, u64 scale, int mode, T alpha); // void truncate(const Tensor4D &in, const Tensor4D &out, u64 shift); // void truncate(const Tensor4D &in, u64 shift); // void truncate(const Tensor2D &in, u64 shift); diff --git a/sytorch/include/sytorch/backend/default.h b/sytorch/include/sytorch/backend/default.h index 66bbc93c..0cf3cddd 100644 --- a/sytorch/include/sytorch/backend/default.h +++ b/sytorch/include/sytorch/backend/default.h @@ -16,23 +16,3 @@ Backend *defaultBackend() } } -template -inline T type_cast(float val); - -template <> -float type_cast(float val) -{ - return val; -} - -template <> -i64 type_cast(float val) -{ - return (i64)val; -} - -template <> -u64 type_cast(float val) -{ - return (u64(i64(val))); -} \ No newline at end of file diff --git a/sytorch/include/sytorch/backend/float.h b/sytorch/include/sytorch/backend/float.h index 0673ac8d..06f1a08b 100644 --- a/sytorch/include/sytorch/backend/float.h +++ b/sytorch/include/sytorch/backend/float.h @@ -28,8 +28,11 @@ class FloatClearText : public Backend void conv2D(u64 fh, u64 fw, u64 padding, u64 stride, u64 ci, u64 co, const Tensor4D &input, const Tensor2D &filter, Tensor4D &output); void conv3D(u64 fd, u64 fh, u64 fw, u64 pd, u64 ph, u64 pw, u64 sd, u64 sh, u64 sw, u64 dd, u64 dh, u64 dw, u64 ci, u64 co, const Tensor5D &input, const Tensor2D &filter, Tensor5D &output); void convTranspose3D(u64 fd, u64 fh, u64 fw, u64 pd, u64 ph, u64 pw, u64 sd, u64 sh, u64 sw, u64 ci, u64 co, const Tensor5D &input, const Tensor2D &filter, Tensor5D &output); + void convTranspose2D(u64 fh, u64 fw, u64 ph, u64 pw, u64 sh, u64 sw, u64 ci, u64 co, const Tensor4D &input, const Tensor2D &filter, Tensor4D &output); void relu(const Tensor &in, const Tensor &out, const Tensor &drelu, u64 scale, int mode); + void leakyRelu(const Tensor &in, const Tensor &out, const Tensor &drelu, u64 scale, int mode, T alpha); + void truncate(T &in, u64 shift); void div(Tensor &in, T divisor, u64 scale); void div(T &in, T divisor, u64 scale); diff --git a/sytorch/include/sytorch/backend/llama_base.h b/sytorch/include/sytorch/backend/llama_base.h index 69ef5cb5..10f9332f 100644 --- a/sytorch/include/sytorch/backend/llama_base.h +++ b/sytorch/include/sytorch/backend/llama_base.h @@ -305,6 +305,22 @@ class LlamaBase : public Backend { output.d2, output.d3, output.d4, input.data, filter.data, output.data); } + void convTranspose2D(u64 fh, u64 fw, u64 ph, u64 pw, u64 sh, u64 sw, u64 ci, u64 co, const Tensor4D &input, const Tensor2D &filter, Tensor4D &output) + { + assert(input.d4 == ci); + assert(filter.d1 == co); + assert(filter.d2 == fh * fw * ci); + u64 newH = (((input.d2 - 1) * sh + fh - 2 * ph)); + u64 newW = (((input.d3 - 1) * sw + fw - 2 * pw)); + assert(output.d1 == input.d1); + assert(output.d2 == newH); + assert(output.d3 == newW); + assert(output.d4 == co); + + ConvTranspose2DWrapper(input.d1, input.d2, input.d3, input.d4, fh, fw, co, + ph, ph, pw, pw, sh, sw, output.d2, output.d3, input.data, filter.data, output.data); + } + void sumPool2D(u64 ks, u64 padding, u64 stride, const Tensor4D &in, Tensor4D &out) { assert(in.d1 == out.d1); assert(in.d4 == out.d4); diff --git a/sytorch/include/sytorch/backend/llama_extended.h b/sytorch/include/sytorch/backend/llama_extended.h index a078899c..e0800977 100644 --- a/sytorch/include/sytorch/backend/llama_extended.h +++ b/sytorch/include/sytorch/backend/llama_extended.h @@ -15,6 +15,42 @@ class LlamaExtended : public LlamaBase { Relu(sz, in.data, in.data, out.data, out.data, drelu.data); } + void leakyRelu(const Tensor &in, const Tensor &out, const Tensor &drelu, u64 scale, int mode, T alpha) + { + assert(in.is_same_shape(out)); + assert(in.is_same_shape(drelu)); + int sz = in.size(); + std::vector shape = in.shape; + + T minus_one = type_cast(-1 * (1LL << scale)); + auto ct = new ClearText; + // leakyrelu = relu(x) - alpha * relu(-x) + + // relu(x) + Tensor relu_x(shape); + Relu(sz, in.data, in.data, relu_x.data, relu_x.data, drelu.data); + + // -x + Tensor minus_x(shape); + ct->fastfor(sz, [&](u64 i) + { minus_x.data[i] = minus_one * in.data[i]; }); + Backend::truncate(minus_x, scale); + + // relu(-x) + Tensor relu_minus_x(shape); + Relu(sz, minus_x.data, minus_x.data, relu_minus_x.data, relu_minus_x.data, drelu.data); + + // alpha * relu(-x) + Tensor alpha_relu_minus_x(shape); + ct->fastfor(sz, [&](u64 i) + { alpha_relu_minus_x.data[i] = alpha * relu_minus_x.data[i]; }); + Backend::truncate(alpha_relu_minus_x, scale); + + // relu(x) - alpha * relu(-x) + ct->fastfor(sz, [&](u64 i) + { out.data[i] = relu_x.data[i] - alpha_relu_minus_x.data[i]; }); + } + void truncate(T *in, T *out, u64 shift, u64 size, u8 mode) { if (this->useLocalTruncation) { for(u64 i = 0; i < size; i++) { diff --git a/sytorch/include/sytorch/layers/layers.h b/sytorch/include/sytorch/layers/layers.h index 54ec8600..45d9b1b3 100644 --- a/sytorch/include/sytorch/layers/layers.h +++ b/sytorch/include/sytorch/layers/layers.h @@ -510,6 +510,36 @@ class ReLU: public Layer { } }; +template +class LeakyReLU : public Layer +{ +public: + Tensor drelu; + double alpha; + LeakyReLU(double alpha) : Layer("LeakyReLU"), drelu({0}), alpha(alpha) {} + + void _resize(const std::vector> &shapes) + { + always_assert(shapes.size() == 1); + auto &shape = shapes[0]; + this->drelu.resize(shape); + always_assert(this->alpha >= 0.0); + } + + void _forward(Tensor &a) + { + T alphaFix = type_cast(alpha * (1LL << this->scale)); + this->backend->leakyRelu(a, this->activation, this->drelu, this->scale, this->mode, alphaFix); + } + + std::vector get_output_dims(const std::vector> &inShapes) + { + always_assert(inShapes.size() == 1); + auto &inShape = inShapes[0]; + return inShape; + } +}; + template class BatchNormInference : public Layer { public: @@ -672,6 +702,83 @@ class ConvTranspose3D : public Layer { } }; +template +class ConvTranspose2D : public Layer +{ +public: + Tensor2D filter; + Tensor1D bias; + u64 ci, co; + u64 fh, fw; + u64 ph, pw; + u64 sh, sw; + + ConvTranspose2D(u64 ci, u64 co, u64 f, u64 padding = 0, u64 stride = 1, bool useBias = false) : Layer("ConvTranspose2D"), ci(ci), co(co), fh(f), fw(f), + ph(padding), pw(padding), sh(stride), sw(stride), filter(co, f * f * ci), bias(co) + { + this->doTruncationForward = true; + this->useBias = useBias; + } + + ConvTranspose2D(u64 ci, u64 co, const std::array f, u64 padding = 0, u64 stride = 1, bool useBias = false) : Layer("ConvTranspose2D"), ci(ci), co(co), fh(f[0]), fw(f[1]), + ph(padding), pw(padding), sh(stride), sw(stride), filter(co, f[0] * f[1] * ci), bias(co) + { + this->doTruncationForward = true; + this->useBias = useBias; + } + + ConvTranspose2D(u64 ci, u64 co, const std::array f, const std::array padding = {0, 0, 0, 0}, const std::array stride = {1, 1}, const std::array dialation = {1, 1}, bool useBias = false) : Layer("ConvTranspose2D"), ci(ci), co(co), fh(f[0]), fw(f[1]), + ph(padding[0]), pw(padding[1]), sh(stride[0]), sw(stride[1]), filter(co, f[0] * f[1] * ci), bias(co) + { + always_assert(dialation[0] == 1); + always_assert(dialation[1] == 1); + always_assert(padding[2] == padding[0]); + always_assert(padding[3] == padding[1]); + this->doTruncationForward = true; + this->useBias = useBias; + } + + void _initScale(u64 scale) + { + double xavier = 1.0 / sqrt(ci * fh * fw); + filter.randomize(xavier * (1ULL << scale)); + if (this->useBias) + bias.randomize(xavier * (1ULL << (2 * scale))); + } + + void _resize(const std::vector> &shapes) + { + always_assert(shapes.size() == 1); + auto &shape = shapes[0]; + always_assert(shape.size() == 4); + always_assert(shape[3] == ci); + } + + void _forward(Tensor &a) + { + always_assert(a.shape.size() == 4); + assert(a.shape[3] == ci); + auto act_4d = this->activation.as_4d(); + this->backend->convTranspose2D(fh, fw, ph, pw, sh, sw, ci, co, a.as_4d(), filter, act_4d); + if (this->useBias) + this->backend->addbias(this->activation, bias); + } + + TensorRef getweights() { return filter.ref(); } + TensorRef getbias() { return bias.ref(); } + + std::vector get_output_dims(const std::vector> &inShapes) + { + always_assert(inShapes.size() == 1); + auto &inShape = inShapes[0]; + always_assert(inShape.size() == 4); + always_assert(inShape[3] == ci); + u64 newH = (((inShape[1] - 1) * sh + fh - 2 * ph)); + u64 newW = (((inShape[2] - 1) * sw + fw - 2 * pw)); + return {inShape[0], newH, newW, co}; + } +}; + template class PlaceHolderLayer : public Layer { public: @@ -938,29 +1045,98 @@ class View: public Layer { template class Transpose: public Layer { public: - Transpose() : Layer("Transpose") {} + std::vector perm; + Transpose(const std::vector &perm) : Layer("Transpose"), perm(perm) {} void _resize(const std::vector> &shapes) { always_assert(shapes.size() == 1); auto &shape = shapes[0]; - always_assert(shape.size() == 2); + always_assert(shape.size() >= 2); } void _forward(Tensor &a) { - always_assert(a.shape.size() == 2); - //#pragma omp parallel for collapse(2) - for (u64 i = 0; i < a.shape[0]; ++i) { - for (u64 j = 0; j < a.shape[1]; ++j) { - this->activation.data[j * a.shape[0] + i] = a.data[i * a.shape[1] + j]; + if (a.shape.size() == 2) + { +#pragma omp parallel for collapse(2) + for (u64 i = 0; i < a.shape[0]; ++i) + { + for (u64 j = 0; j < a.shape[1]; ++j) + { + this->activation.data[j * a.shape[perm[1]] + i] = a.data[i * a.shape[1] + j]; + } } } + else if (a.shape.size() == 4) + { + auto a_4d = a.as_4d(); + auto out_4d = this->activation.as_4d(); +#pragma omp parallel for collapse(4) + for (int n = 0; n < a.shape[0]; ++n) + { + for (int h = 0; h < a.shape[1]; ++h) + { + for (int w = 0; w < a.shape[2]; ++w) + { + for (int c = 0; c < a.shape[3]; ++c) + { + out_4d(perm[0] == 0 ? n : (perm[0] == 1 ? h : (perm[0] == 2 ? w : c)), + perm[1] == 0 ? n : (perm[1] == 1 ? h : (perm[1] == 2 ? w : c)), + perm[2] == 0 ? n : (perm[2] == 1 ? h : (perm[2] == 2 ? w : c)), + perm[3] == 0 ? n : (perm[3] == 1 ? h : (perm[3] == 2 ? w : c))) = a_4d(n, h, w, c); + } + } + } + } + } + else if (a.shape.size() == 5) + { + auto a_5d = a.as_5d(); + auto out_5d = this->activation.as_5d(); +#pragma omp parallel for collapse(5) + for (int n = 0; n < a.shape[0]; ++n) + { + for (int h = 0; h < a.shape[1]; ++h) + { + for (int w = 0; w < a.shape[2]; ++w) + { + for (int d = 0; d < a.shape[3]; ++d) + { + for (int c = 0; c < a.shape[4]; ++c) + { + out_5d(perm[0] == 0 ? n : (perm[0] == 1 ? h : (perm[0] == 2 ? w : (perm[0] == 3 ? d : c))), + perm[1] == 0 ? n : (perm[1] == 1 ? h : (perm[1] == 2 ? w : (perm[1] == 3 ? d : c))), + perm[2] == 0 ? n : (perm[2] == 1 ? h : (perm[2] == 2 ? w : (perm[2] == 3 ? d : c))), + perm[3] == 0 ? n : (perm[3] == 1 ? h : (perm[3] == 2 ? w : (perm[3] == 3 ? d : c))), + perm[4] == 0 ? n : (perm[4] == 1 ? h : (perm[4] == 2 ? w : (perm[4] == 3 ? d : c)))) = a_5d(n, h, w, d, c); + } + } + } + } + } + } + else + { + throw std::runtime_error("supported only 2d, 4d, 5d tensors in transpose"); + } } std::vector get_output_dims(const std::vector> &inShapes) { always_assert(inShapes.size() == 1); auto shape = inShapes[0]; - always_assert(shape.size() == 2); - return {shape[1], shape[0]}; + always_assert(perm.size() == shape.size()); + for (auto &p : perm) + { + if (p == 1) + p = shape.size() - 1; + else if (p > 1) + p -= 1; + } + std::vector outShape; + for (auto &p : perm) + { + outShape.push_back(shape[p]); + } + return outShape; } }; diff --git a/sytorch/include/sytorch/tensor.h b/sytorch/include/sytorch/tensor.h index 87bca667..87dfd024 100644 --- a/sytorch/include/sytorch/tensor.h +++ b/sytorch/include/sytorch/tensor.h @@ -18,6 +18,27 @@ typedef uint8_t u8; typedef int64_t i64; typedef int32_t i32; +template +inline T type_cast(float val); + +template <> +inline float type_cast(float val) +{ + return val; +} + +template <> +inline i64 type_cast(float val) +{ + return (i64)val; +} + +template <> +inline u64 type_cast(float val) +{ + return (u64(i64(val))); +} + template class TensorRef { public: @@ -192,7 +213,8 @@ class Tensor { { double d; std::cin >> d; - data[i] = (T)(d * (1LL << scale)); + data[i] = type_cast(d * (1LL << scale)); + } } @@ -213,9 +235,9 @@ class Tensor { u64 curr_rest = i % rest_size; u64 new_idx = curr_batch * (num_channel * rest_size) + curr_rest * num_channel + curr_channel; #ifdef Do_Masking - data[new_idx] = (T)d; + data[new_idx] = type_cast d; #else - data[new_idx] = (T)(d * (1LL << scale)); + data[new_idx] = type_cast(d * (1LL << scale)); #endif } } @@ -287,7 +309,7 @@ class Tensor { { for (int m = 0; m < d5; m++) { - this->data[i * d2 * d3 * d4 * d5 + j * d3 * d4 * d5 + k * d4 * d5 + l * d5 + m] = (T)(arr[i][j][k][l][m] * scale); + this->data[i * d2 * d3 * d4 * d5 + j * d3 * d4 * d5 + k * d4 * d5 + l * d5 + m] = type_cast(arr[i][j][k][l][m] * (1LL << scale)); } } } @@ -312,7 +334,7 @@ class Tensor { floatInput= (float*)mmap(NULL, sb.st_size, PROT_READ, MAP_PRIVATE, fd2, 0); for(u64 i = 0; i < size(); ++i) { - data[i] = (T)(floatInput[i] * (1LL << scale)); + data[i] = type_cast(floatInput[i] * (1LL << scale)); } ::close(fd2); //delete[] floatInput; diff --git a/sytorch/include/sytorch/utils.h b/sytorch/include/sytorch/utils.h index 2c0d3b75..3d66ca8e 100644 --- a/sytorch/include/sytorch/utils.h +++ b/sytorch/include/sytorch/utils.h @@ -278,6 +278,75 @@ void convTranspose3dLoop( } } +template +void convTranspose2dLoop( + int64_t N, + int64_t H, + int64_t W, + int64_t CI, + int64_t FH, + int64_t FW, + int64_t CO, + int64_t zPadHLeft, + int64_t zPadHRight, + int64_t zPadWLeft, + int64_t zPadWRight, + int64_t strideH, + int64_t strideW, + int64_t outH, + int64_t outW, + T *inputArr, + T *filterArr, + T *outArr) +{ + zPadHLeft = FH - 1 - zPadHLeft; + zPadHRight = FH - 1 - zPadHRight; + zPadWLeft = FW - 1 - zPadWLeft; + zPadWRight = FW - 1 - zPadWRight; + +#pragma omp parallel for collapse(4) + for (int64_t n = 0; n < N; n++) + { + for (int64_t h = 0; h < outH; h++) + { + for (int64_t w = 0; w < outW; w++) + { + for (int64_t co = 0; co < CO; co++) + { + + T val = 0; + for (int64_t ci = 0; ci < CI; ci++) + { + for (int64_t fh = h; fh < (h + FH); fh++) + { + for (int64_t fw = w; fw < (w + FW); fw++) + { + + int64_t curPosH = ((fh - zPadHLeft) / strideH); + int64_t curPosW = ((fw - zPadWLeft) / strideW); + + if ((curPosH >= 0) && + (curPosW >= 0) && + (curPosH < H) && + (curPosW < W) && + (((fh - zPadHLeft) % strideH) == 0) && + (((fw - zPadWLeft) % strideW) == 0)) + { + int32_t curFilterPosH = FH + h - fh - 1; + int32_t curFilterPosW = FW + w - fw - 1; + val += (Arr4DIdx(inputArr, N, H, W, CI, n, curPosH, curPosW, ci) * Arr4DIdx(filterArr, CO, FH, FW, CI, co, curFilterPosH, curFilterPosW, ci)); + } + } + } + } + Arr4DIdx(outArr, N, outH, outW, CO, n, h, w, co) = val; + // std::cout << "setting element at (" << n << " " << d << " " << h << " " << w << " " << co << ")" << std::endl; + } + } + } + } +} + template std::vector collect(T &first, Args & ... args) { @@ -371,6 +440,43 @@ void print(const Tensor &p, u64 scale, u64 bw) } } +template +void print_nchw(const Tensor &p, u64 scale, u64 bw) +{ + u64 batch_size = p.shape[0]; + u64 num_channel = p.shape.back(); + u64 rest_size = p.size() / (batch_size * num_channel); + + for (u64 i = 0; i < p.size(); i++) + { + u64 curr_batch = i / (num_channel * rest_size); + u64 curr_channel = (i / rest_size) % num_channel; + u64 curr_rest = i % rest_size; + u64 new_idx = curr_batch * (num_channel * rest_size) + curr_rest * num_channel + curr_channel; + + i64 val; + if (bw == sizeof(T) * 8) + { + val = p.data[new_idx]; + } + else + { + val = (p.data[new_idx] + (1LL << (bw - 1))) % (1LL << bw); + val -= (1LL << (bw - 1)); + } + std::cout << (double)val / (1LL << scale); + if ((i + 1) % num_channel == 0) + { + std::cout << std::endl; + } + else + { + std::cout << " "; + } + } + std::cout << std::endl; +} + template void print(const Tensor &p, u64 scale) { diff --git a/sytorch/src/sytorch/backend/cleartext.cpp b/sytorch/src/sytorch/backend/cleartext.cpp index cd5a9820..d858abc4 100644 --- a/sytorch/src/sytorch/backend/cleartext.cpp +++ b/sytorch/src/sytorch/backend/cleartext.cpp @@ -86,24 +86,43 @@ void ClearText::conv3D(u64 fd, u64 fh, u64 fw, u64 pd, u64 ph, u64 pw, u64 sd template void ClearText::convTranspose3D(u64 fd, u64 fh, u64 fw, u64 pd, u64 ph, u64 pw, u64 sd, u64 sh, u64 sw, u64 ci, u64 co, const Tensor5D &input, const Tensor2D &filter, Tensor5D &output) - { - assert(input.d5 == ci); - assert(filter.d1 == co); - assert(filter.d2 == fd * fh * fw * ci); - u64 newD = (((input.d2 - 1)*sd + fd - 2*pd)); - u64 newH = (((input.d3 - 1)*sh + fh - 2*ph)); - u64 newW = (((input.d4 - 1)*sw + fw - 2*pw)); - assert(output.d1 == input.d1); - assert(output.d2 == newD); - assert(output.d3 == newH); - assert(output.d4 == newW); - assert(output.d5 == co); - - convTranspose3dLoop(input.d1, input.d2, input.d3, input.d4, input.d5, fd, fh, fw, co, - pd, pd, ph, ph, pw, pw, sd, sh, sw, - output.d2, output.d3, output.d4, input.data, filter.data, output.data); - modbw(output); - } +{ + assert(input.d5 == ci); + assert(filter.d1 == co); + assert(filter.d2 == fd * fh * fw * ci); + u64 newD = (((input.d2 - 1) * sd + fd - 2 * pd)); + u64 newH = (((input.d3 - 1) * sh + fh - 2 * ph)); + u64 newW = (((input.d4 - 1) * sw + fw - 2 * pw)); + assert(output.d1 == input.d1); + assert(output.d2 == newD); + assert(output.d3 == newH); + assert(output.d4 == newW); + assert(output.d5 == co); + + convTranspose3dLoop(input.d1, input.d2, input.d3, input.d4, input.d5, fd, fh, fw, co, + pd, pd, ph, ph, pw, pw, sd, sh, sw, + output.d2, output.d3, output.d4, input.data, filter.data, output.data); + modbw(output); +} + +template +void ClearText::convTranspose2D(u64 fh, u64 fw, u64 ph, u64 pw, u64 sh, u64 sw, u64 ci, u64 co, const Tensor4D &input, const Tensor2D &filter, Tensor4D &output) +{ + assert(input.d4 == ci); + assert(filter.d1 == co); + assert(filter.d2 == fh * fw * ci); + u64 newH = (((input.d2 - 1) * sh + fh - 2 * ph)); + u64 newW = (((input.d3 - 1) * sw + fw - 2 * pw)); + assert(output.d1 == input.d1); + assert(output.d2 == newH); + assert(output.d3 == newW); + assert(output.d4 == co); + + convTranspose2dLoop(input.d1, input.d2, input.d3, input.d4, fh, fw, co, + ph, ph, pw, pw, sh, sw, + output.d2, output.d3, input.data, filter.data, output.data); + modbw(output); +} template void ClearText::relu(const Tensor &in, const Tensor &out, const Tensor &drelu, u64 scale, int mode) { @@ -116,6 +135,41 @@ void ClearText::relu(const Tensor &in, const Tensor &out, const Tensor< }); } +template +void ClearText::leakyRelu(const Tensor &in, const Tensor &out, const Tensor &drelu, u64 scale, int mode, T alpha) +{ + assert(in.is_same_shape(out)); + assert(in.is_same_shape(drelu)); + std::vector shape = in.shape; + T minus_one = (T)(-1 * (1LL << scale)); + // leakyrelu = relu(x) - alpha * relu(-x) + Tensor relu_x(shape); + // relu(x) + relu(in, relu_x, drelu, scale, mode); + + // -x + Tensor minus_x(shape); + fastfor(in.size(), [&](u64 i) + { minus_x.data[i] = minus_one * in.data[i]; + modbw(minus_x.data[i]); + truncate(minus_x.data[i], scale); }); + + // relu(-x) + Tensor relu_minus_x(shape); + relu(minus_x, relu_minus_x, drelu, scale, mode); + + // alpha * relu(-x) + Tensor alpha_relu_minus_x(shape); + fastfor(in.size(), [&](u64 i) + { alpha_relu_minus_x.data[i] = (T)(alpha * relu_minus_x.data[i]); + modbw(alpha_relu_minus_x.data[i]); + truncate(alpha_relu_minus_x.data[i], scale); }); + + // relu(x) - alpha * relu(-x) + fastfor(in.size(), [&](u64 i) + { out.data[i] = relu_x.data[i] - alpha_relu_minus_x.data[i]; }); +} + template void ClearText::truncate(T *in, T *out, u64 shift, u64 size, u8 mode) { fastfor(size, [&] (u64 i) { diff --git a/sytorch/src/sytorch/backend/float.cpp b/sytorch/src/sytorch/backend/float.cpp index 24ab68ce..2ab61ff1 100644 --- a/sytorch/src/sytorch/backend/float.cpp +++ b/sytorch/src/sytorch/backend/float.cpp @@ -112,6 +112,24 @@ void FloatClearText::convTranspose3D(u64 fd, u64 fh, u64 fw, u64 pd, u64 ph, output.d2, output.d3, output.d4, input.data, filter.data, output.data); } +template +void FloatClearText::convTranspose2D(u64 fh, u64 fw, u64 ph, u64 pw, u64 sh, u64 sw, u64 ci, u64 co, const Tensor4D &input, const Tensor2D &filter, Tensor4D &output) +{ + assert(input.d4 == ci); + assert(filter.d1 == co); + assert(filter.d2 == fh * fw * ci); + u64 newH = (((input.d2 - 1) * sh + fh - 2 * ph)); + u64 newW = (((input.d3 - 1) * sw + fw - 2 * pw)); + assert(output.d1 == input.d1); + assert(output.d2 == newH); + assert(output.d3 == newW); + assert(output.d4 == co); + + convTranspose2dLoop(input.d1, input.d2, input.d3, input.d4, fh, fw, co, + ph, ph, pw, pw, sh, sw, + output.d2, output.d3, input.data, filter.data, output.data); +} + template void FloatClearText::relu(const Tensor &in, const Tensor &out, const Tensor &drelu, u64 scale, int mode) { @@ -123,6 +141,37 @@ void FloatClearText::relu(const Tensor &in, const Tensor &out, const Te out.data[i] = (in.data[i] > 0) ? in.data[i] : 0; }); } +template +void FloatClearText::leakyRelu(const Tensor &in, const Tensor &out, const Tensor &drelu, u64 scale, int mode, T alpha) +{ + assert(in.is_same_shape(out)); + assert(in.is_same_shape(drelu)); + std::vector shape = in.shape; + + // leakyrelu = relu(x) - alpha * relu(-x) + Tensor relu_x(shape); + // relu(x) + relu(in, relu_x, drelu, scale, mode); + + // -x + Tensor minus_x(shape); + fastfor(in.size(), [&](u64 i) + { minus_x.data[i] = -1.0 * in.data[i]; }); + + // relu(-x) + Tensor relu_minus_x(shape); + relu(minus_x, relu_minus_x, drelu, scale, mode); + + // alpha * relu(-x) + Tensor alpha_relu_minus_x(shape); + fastfor(in.size(), [&](u64 i) + { alpha_relu_minus_x.data[i] = (T)(alpha * relu_minus_x.data[i]); }); + + // relu(x) - alpha * relu(-x) + fastfor(in.size(), [&](u64 i) + { out.data[i] = (relu_x.data[i] - alpha_relu_minus_x.data[i]); }); +} + template void FloatClearText::truncate(T *in, T *out, u64 shift, u64 size, u8 mode) {