adds node truncate, leakyRelu, ConvTranspose2d and bug fixes (#188)
* added transpose for 4D,5D tensors with perm

* leakyrelu added in ct_sytorch and OnnxBridge

* leaky relu added to LLAMA and ct

* Adds Cleartext float via sytorch (#186) (#187) from master

* float added from kanav-gpt in sytorch

* cleartext_fp added in Onnxbridge

* i64->T in tensor.h

* debug statement

* typecast corrected in module.h

* removing debug statement

* added testing for sytorch float ct

* rename tests

* type_cast modified and leakyrelu added to float

* convTranspose2d added in cleartext

* added custom test script using pytest

* adding convTranspose2d to LLAMA (correctness not tested)

* bug fix for size=1 in input_prng.cpp

* sytorch conv3d bug fix for dim d != 1

* additions to test file

* printing bug resolved
drunkenlegend authored Sep 21, 2023
1 parent f6bee92 commit fc2f88d
Showing 25 changed files with 987 additions and 68 deletions.
11 changes: 8 additions & 3 deletions OnnxBridge/LLAMA/sytorchBackendRep.py
@@ -12,6 +12,7 @@ def func_call(node, value_info):
"""
func_map = {
"Relu": "ReLU",
"LeakyRelu": "LeakyReLU",
"Conv": f"{'Conv3D' if len(value_info[node.inputs[0]][1]) == 5 else 'Conv2D'}",
"MaxPool": "MaxPool2D",
"Flatten": "Flatten",
@@ -21,7 +22,8 @@ def func_call(node, value_info):
"AveragePool": "AvgPool2D",
"GlobalAveragePool": "GlobalAvgPool2D",
"Add": "add",
"ConvTranspose": "ConvTranspose3D",
"ConvTranspose": f"{'ConvTranspose3D' if len(value_info[node.inputs[0]][1]) == 5 else 'ConvTranspose2D'}",
"Transpose": "Transpose",
}
return func_map[node.op_type]
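The two rank-dispatched entries above pick the 2-D or 3-D layer from the rank of the input shape stored in value_info (tensor name -> (dtype, shape)). A minimal standalone sketch of that dispatch, with hypothetical shapes:

    def dispatch(op, input_shape):
        # 5-D (N, C, D, H, W) input -> 3-D layer; 4-D (N, C, H, W) input -> 2-D layer
        return op + ("3D" if len(input_shape) == 5 else "2D")

    assert dispatch("ConvTranspose", (1, 8, 16, 64, 64)) == "ConvTranspose3D"
    assert dispatch("ConvTranspose", (1, 8, 64, 64)) == "ConvTranspose2D"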

@@ -56,6 +58,7 @@ def inputs_to_take(node):
tmp_dict = {
"Conv": 1,
"Relu": 1,
"LeakyRelu": 1,
"MaxPool": 1,
"Gemm": 1,
"Flatten": 1,
@@ -65,6 +68,7 @@ def inputs_to_take(node):
"BatchNormalization": 1,
"GlobalAveragePool": 1,
"ConvTranspose": 1,
"Transpose": 1,
}
return tmp_dict[node]
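Each value here appears to be the number of ONNX inputs forwarded as activations to the layer call; the remaining inputs (weights, biases) are presumably bound as layer parameters instead.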

@@ -130,7 +134,8 @@ def cleartext_post(code_list, program, scale, mode, indent):
input.input_nchw(scale);
print_dot_graph(net.root);
net.forward(input);
print(net.activation, scale, 64);
net.activation.printshape();
print_nchw(net.activation, scale, 64);
return 0;
{'}'}
@@ -300,7 +305,7 @@ def llama_post(code_list, program, scale, mode, bitlength, indent):
auto &output = net.activation;
llama->outputA(output);
if (party == CLIENT) {'{'}
print(output, scale, LlamaConfig::bitlength);
print_nchw(output, scale, LlamaConfig::bitlength);
{'}'}
llama->finalize();
{'}'}
37 changes: 35 additions & 2 deletions OnnxBridge/LLAMA/sytorch_func_calls.py
@@ -93,6 +93,15 @@ def Relu(cls, attributes, inputs, outputs, value_info, var_dict, mode, indent):
logger.debug("Inside Relu function call.")
return str(f"{' ' * indent}new ReLU<T>();")

@classmethod
def LeakyRelu(cls, attributes, inputs, outputs, value_info, var_dict, mode, indent):
logger.debug("Inside LeakyRelu function call.")
if "alpha" in attributes.keys():
alpha = attributes["alpha"]
else:
alpha = 0.01
return str(f"{' ' * indent}new LeakyReLU<T>({alpha});")
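With ONNX's default alpha of 0.01, this emits, for example, new LeakyReLU&lt;T&gt;(0.01);.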

@classmethod
def BatchNormalization(
cls, attributes, inputs, outputs, value_info, var_dict, mode, indent
@@ -174,10 +183,27 @@ def ConvTranspose(
cls, attributes, inputs, outputs, value_info, var_dict, mode, indent
):
logger.debug("Inside ConvTranspose function call.")
pads = get_padding_3d(attributes, inputs, outputs, value_info, var_dict)
spatial_size = len(value_info[inputs[0]][1]) - 2
if spatial_size == 3:
if spatial_size == 2:
assert len(inputs) == 2 or len(inputs) == 3
pads = get_padding(attributes, inputs, outputs, value_info, var_dict)
assert len(attributes["strides"]) == 2
assert value_info[inputs[1]][1][2:] == tuple(attributes["kernel_shape"])
CI = value_info[inputs[0]][1][1]
CO = value_info[outputs[0]][1][1]
filterShape = value_info[inputs[1]][1]
pad = pads[0]
strides = attributes["strides"]
dilations = get_dilation(attributes, inputs, outputs, value_info, var_dict)
isBias = ", true" if len(inputs) == 3 else ""
return str(
f"{' ' * indent}new ConvTranspose2D<T>("
f"{CI}, {CO}, {'{'}{iterate_list(filterShape[2:])}{'}'}, {'{'}{iterate_list(pads)}{'}'}, {'{'}{iterate_list(strides)}{'}'},{'{'}{iterate_list(dilations)}{'}'}{isBias}"
f");"
)
elif spatial_size == 3:
assert len(inputs) == 2 or len(inputs) == 3
pads = get_padding_3d(attributes, inputs, outputs, value_info, var_dict)
assert len(attributes["strides"]) == 3
assert value_info[inputs[1]][1][2:] == tuple(attributes["kernel_shape"])
CI = value_info[inputs[0]][1][1]
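As a hypothetical example of what the 2-D branch above generates: for CI=16, CO=8, a 4x4 kernel, pads {1, 1, 1, 1}, strides {2, 2}, and dilations {1, 1}, the emitted line would be roughly new ConvTranspose2D&lt;T&gt;(16, 8, {4, 4}, {1, 1, 1, 1}, {2, 2},{1, 1}); (assuming iterate_list joins with ", ").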
@@ -232,6 +258,13 @@ def Reshape(cls, attributes, inputs, outputs, value_info, var_dict, mode, indent
return str(f"{' ' * indent}new Reshape<T>();")
# todo : check format

@classmethod
def Transpose(cls, attributes, inputs, outputs, value_info, var_dict, mode, indent):
logger.debug("Inside Transpose function call.")
return str(
f"{' ' * indent}new Transpose<T>( {'{'}{iterate_list(attributes['perm'])}{'}'});"
)
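For perm = [0, 2, 3, 1], for instance, this emits new Transpose&lt;T&gt;( {0, 2, 3, 1}); under the same iterate_list assumption.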

@classmethod
def Gemm(cls, attributes, inputs, outputs, value_info, var_dict, mode, indent):
logger.debug("Inside Gemm function call.")
2 changes: 2 additions & 0 deletions OnnxBridge/backend.py
@@ -122,6 +122,7 @@ def is_compatible(cls, model, backend, device: str = "2PC", **kwargs):
not_supported = []
implemented_sytorch = [
"Relu",
"LeakyRelu",
"Softmax",
"Conv",
"MaxPool",
@@ -133,6 +134,7 @@ def is_compatible(cls, model, backend, device: str = "2PC", **kwargs):
"GlobalAveragePool",
"Add",
"ConvTranspose",
"Transpose",
]
implemented_secfloat = [
"Relu",
32 changes: 30 additions & 2 deletions OnnxBridge/tests/conftest.py
@@ -26,6 +26,18 @@ def pytest_addoption(parser):
help="batch size",
required=False,
)
parser.addoption(
"--model",
action="store",
help="absolute mdel path",
required=False,
)
parser.addoption(
"--input_name",
action="store",
help="absolute input_name path",
required=False,
)


@pytest.fixture(scope="session")
@@ -34,6 +46,18 @@ def backend(request):
return opt


@pytest.fixture(scope="session")
def model(request):
opt = request.config.getoption("--model")
return opt


@pytest.fixture(scope="session")
def input_name(request):
opt = request.config.getoption("--input_name")
return opt


@pytest.fixture(scope="session")
def batch_size(request):
opt = request.config.getoption("--batch_size")
@@ -79,8 +103,12 @@ def pytest_runtest_makereport(item, call):
@pytest.fixture
def test_dir(request, test_env):
print("\nRequest node: ", request.node.name)
test_name_list = request.node.name.split("[")
parameter_name = test_name_list[1].split("]")[0]
if "[" in request.node.name:
test_name_list = request.node.name.split("[")
parameter_name = test_name_list[1].split("]")[0]
else:
test_name_list = request.node.name.split("_")
parameter_name = "custom"
full_test_name = test_name_list[0] + "_" + parameter_name
test_name = full_test_name[len("test_") :]
main_test_dir = test_env["test_dir"]
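A standalone sketch of the new name handling above (hypothetical helper, not part of the commit): parametrized nodes like test_model[lenet] keep their parameter name, while the unparametrized custom test falls back to "custom".

    def derive_test_name(node_name):
        # Mirrors the fixture logic: split on "[" for parametrized tests,
        # otherwise fall back to the fixed parameter name "custom".
        if "[" in node_name:
            base, param = node_name.split("[")
            parameter_name = param.split("]")[0]
        else:
            base = node_name.split("_")[0]
            parameter_name = "custom"
        return (base + "_" + parameter_name)[len("test_"):]

    assert derive_test_name("test_model[lenet]") == "model_lenet"
    assert derive_test_name("test_custom_model") == "custom"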
43 changes: 43 additions & 0 deletions OnnxBridge/tests/custom_model_test.py
@@ -0,0 +1,43 @@
import pytest
import os
from utils import (
run_onnx,
compile_model,
run_backend,
compare_output,
)

# Get the directory path where the current script is located
script_directory = os.path.dirname(os.path.abspath(__file__))
ezpc_dir = os.path.join(script_directory, "..", "..")


def test_custom_model(test_dir, backend, model, input_name):
"""
Usage:
pytest path/custom_model_test.py -s --backend CLEARTEXT_LLAMA --model /home/saksham/EzPC/OnnxBridge/nnUnet/optimized_fabiansPreActUnet.onnx --input_name /home/saksham/EzPC/OnnxBridge/nnUnet/inputs/2d_input
"""
os.chdir(test_dir)

# model is absolute path to model.onnx
# input_name is the absolute path prefix of the input files (without extension)
# copy the model and the preprocessed input data into the test directory
os.system(f"cp {model} model.onnx")
os.system(f"cp {input_name}.inp input1.inp")
os.system(f"cp {input_name}.npy input1.npy")

# preprocessed input is directly copied

# run the model with OnnxRuntime
run_onnx("input1.npy")

# compile the model with backend
compile_model(backend)

# run the model with backend
run_backend(backend, "input1.inp")

# compare the output
compare_output()

os.chdir("../..")
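Note that --input_name is a path prefix, not a file: both &lt;prefix&gt;.inp (the input consumed by the compiled backend) and &lt;prefix&gt;.npy (the float input fed to OnnxRuntime) must exist, since both are copied above.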
3 changes: 3 additions & 0 deletions OnnxBridge/tests/utils.py
@@ -122,6 +122,9 @@ def compare_output():
arr1 = np.load("output.npy", allow_pickle=True).flatten()
arr2 = np.load("onnx_output/expected.npy", allow_pickle=True).flatten()

print("Secure Output Shape: " + str(arr1.shape))
print("Expected Output Shape: " + str(arr2.shape))

matching_prec = -1
for prec in range(1, 10):
try:
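The hunk is truncated at the precision search; a minimal standalone sketch of the idea, assuming the try block calls numpy's assert_array_almost_equal:

    import numpy as np

    def matching_precision(a, b, max_decimals=9):
        # Return the highest decimal precision at which a and b agree, or -1.
        matching = -1
        for prec in range(1, max_decimals + 1):
            try:
                np.testing.assert_array_almost_equal(a, b, decimal=prec)
                matching = prec
            except AssertionError:
                break
        return matching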
4 changes: 2 additions & 2 deletions OnnxBridge/utils/onnx_nodes.py
@@ -173,8 +173,8 @@ def ConvTranspose(cls, node):

@classmethod
def LeakyRelu(cls, node):
if "alpha" not in node.attributes:
node.attributes["alpha"] = 0.01
if "alpha" not in node.attrs:
node.attrs["alpha"] = 0.01

@classmethod
def Tanh(cls, node):
88 changes: 87 additions & 1 deletion sytorch/ext/llama/api.cpp
@@ -1510,4 +1510,90 @@ void ConvTranspose3DWrapper(int64_t N,

std::cerr << ">> ConvTranspose3D - End" << "\n";

}
}

void ConvTranspose2DWrapper(int64_t N,
int64_t H,
int64_t W,
int64_t CI,
int64_t FH,
int64_t FW,
int64_t CO,
int64_t zPadHLeft,
int64_t zPadHRight,
int64_t zPadWLeft,
int64_t zPadWRight,
int64_t strideH,
int64_t strideW,
int64_t outH,
int64_t outW,
GroupElement *inputArr,
GroupElement *filterArr,
GroupElement *outArr)
{
std::cerr << ">> ConvTranspose2D - Start" << std::endl;
always_assert(outH == (H - 1) * strideH - zPadHLeft - zPadHRight + FH);
always_assert(outW == (W - 1) * strideW - zPadWLeft - zPadWRight + FW);

if (party == DEALER)
{
auto local_start = std::chrono::high_resolution_clock::now();

// not good for in place operations
for (int i = 0; i < N * outH * outW * CO; ++i)
{
outArr[i] = random_ge(bitlength);
}

auto keys = KeyGenConvTranspose2D(bitlength, N, H, W, CI, FH, FW, CO, zPadHLeft, zPadHRight, zPadWLeft, zPadWRight, strideH, strideW, outH, outW, inputArr, filterArr, outArr);

auto local_end = std::chrono::high_resolution_clock::now();

client->send_triple_key(keys.second);
freeTripleKey(keys.second);
auto local_time_taken = std::chrono::duration_cast<std::chrono::microseconds>(local_end -
local_start)
.count();
dealerMicroseconds += local_time_taken;
std::cerr << " Dealer Time = " << local_time_taken / 1000.0 << " milliseconds\n";
}
else
{

auto keyread_start = std::chrono::high_resolution_clock::now();
auto key = dealer->recv_triple_key(bitlength, N * H * W * CI, CI * FH * FW * CO, N * outH * outW * CO);
auto keyread_end = std::chrono::high_resolution_clock::now();
auto keyread_time_taken = std::chrono::duration_cast<std::chrono::milliseconds>(keyread_end - keyread_start).count();

peer->sync();

auto local_start = std::chrono::high_resolution_clock::now();

EvalConvTranspose2D(party, key, N, H, W, CI, FH, FW, CO,
zPadHLeft, zPadHRight, zPadWLeft, zPadWRight, strideH, strideW, outH, outW, inputArr, filterArr, outArr);

auto t1 = std::chrono::high_resolution_clock::now();
uint64_t onlineComm0 = peer->bytesReceived + peer->bytesSent;

reconstruct(N * outH * outW * CO, outArr, bitlength);

uint64_t onlineComm1 = peer->bytesReceived + peer->bytesSent;
convOnlineComm += (onlineComm1 - onlineComm0);
auto local_end = std::chrono::high_resolution_clock::now();

freeTripleKey(key);
auto compute_time = std::chrono::duration_cast<std::chrono::microseconds>(t1 - local_start).count();
auto reconstruct_time = std::chrono::duration_cast<std::chrono::microseconds>(local_end - t1).count();

convEvalMicroseconds += (reconstruct_time + compute_time);
evalMicroseconds += (reconstruct_time + compute_time);
std::cerr << " Key Read Time = " << keyread_time_taken << " milliseconds\n";
std::cerr << " Compute Time = " << compute_time / 1000.0 << " milliseconds\n";
std::cerr << " Reconstruct Time = " << reconstruct_time / 1000.0 << " milliseconds\n";
std::cerr << " Online Time = " << (reconstruct_time + compute_time) / 1000.0 << " milliseconds\n";
std::cerr << " Online Comm = " << (onlineComm1 - onlineComm0) << " bytes\n";
}

std::cerr << ">> ConvTranspose2D - End" << std::endl;
}
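The two always_assert checks at the top of the wrapper encode the transposed-convolution output size for dilation 1 and no output padding. A quick sketch with hypothetical dimensions:

    def conv_transpose_out(in_dim, stride, pad_left, pad_right, kernel):
        # outDim = (inDim - 1) * stride - padLeft - padRight + kernel
        return (in_dim - 1) * stride - pad_left - pad_right + kernel

    assert conv_transpose_out(16, 2, 1, 1, 4) == 32  # e.g. H = 16 -> outH = 32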

