From 89db276b7973276aa6387f81d162dbdd714a919a Mon Sep 17 00:00:00 2001 From: winnylyc Date: Thu, 11 Jan 2024 08:50:00 +0000 Subject: [PATCH 01/10] preprocessing --- sml/preprocessing/BUILD.bazel | 22 ++ sml/preprocessing/emulations/BUILD.bazel | 26 ++ .../emulations/preprocessing_emul.py | 159 ++++++++++ sml/preprocessing/preprocessing.py | 296 ++++++++++++++++++ sml/preprocessing/tests/BUILD.bazel | 27 ++ sml/preprocessing/tests/preprocessing_test.py | 167 ++++++++++ 6 files changed, 697 insertions(+) create mode 100755 sml/preprocessing/BUILD.bazel create mode 100755 sml/preprocessing/emulations/BUILD.bazel create mode 100755 sml/preprocessing/emulations/preprocessing_emul.py create mode 100755 sml/preprocessing/preprocessing.py create mode 100755 sml/preprocessing/tests/BUILD.bazel create mode 100755 sml/preprocessing/tests/preprocessing_test.py diff --git a/sml/preprocessing/BUILD.bazel b/sml/preprocessing/BUILD.bazel new file mode 100755 index 00000000..ad322743 --- /dev/null +++ b/sml/preprocessing/BUILD.bazel @@ -0,0 +1,22 @@ +# Copyright 2023 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_python//python:defs.bzl", "py_library") + +package(default_visibility = ["//visibility:public"]) + +py_library( + name = "preprocessing", + srcs = ["preprocessing.py"], +) diff --git a/sml/preprocessing/emulations/BUILD.bazel b/sml/preprocessing/emulations/BUILD.bazel new file mode 100755 index 00000000..99c09f09 --- /dev/null +++ b/sml/preprocessing/emulations/BUILD.bazel @@ -0,0 +1,26 @@ +# Copyright 2023 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_python//python:defs.bzl", "py_binary") + +package(default_visibility = ["//visibility:public"]) + +py_binary( + name = "preprocessing_emul", + srcs = ["preprocessing_emul.py"], + deps = [ + "//sml/preprocessing:preprocessing", + "//sml/utils:emulation", + ], +) \ No newline at end of file diff --git a/sml/preprocessing/emulations/preprocessing_emul.py b/sml/preprocessing/emulations/preprocessing_emul.py new file mode 100755 index 00000000..82feff16 --- /dev/null +++ b/sml/preprocessing/emulations/preprocessing_emul.py @@ -0,0 +1,159 @@ +# Copyright 2023 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import jax.numpy as jnp +from sklearn import preprocessing +import numpy as np + +import sml.utils.emulation as emulation +from sml.preprocessing.preprocessing import LabelBinarizer, Binarizer, Normalizer + +def emul_labelbinarizer(): + def labelbinarize(X, Y): + transformer = LabelBinarizer(neg_label=-2, pos_label=3) + transformer.fit(X, n_classes=4) + transformed = transformer.transform(Y) + inv_transformed = transformer.inverse_transform(transformed) + return transformed, inv_transformed + + X = jnp.array([1, 2, 6, 4, 2]) + Y = jnp.array([1, 6]) + + spu_transformed, spu_inv_transformed = emulator.run(labelbinarize)(X, Y) + # print("result\n", spu_transformed) + # print("result\n", spu_inv_transformed) + + transformer = preprocessing.LabelBinarizer(neg_label=-2, pos_label=3) + transformer.fit(X) + sk_transformed = transformer.transform(Y) + sk_inv_transformed = transformer.inverse_transform(sk_transformed) + # print("sklearn:\n", sk_transformed) + # print("sklearn:\n", sk_inv_transformed) + np.testing.assert_allclose(sk_transformed, spu_transformed, rtol=0, atol=0) + np.testing.assert_allclose(sk_inv_transformed, spu_inv_transformed, rtol=0, atol=0) + +def emul_labelbinarizer_binary(): + def labelbinarize(X, Y): + transformer = LabelBinarizer() + # transformer.fit(X, n_classes=4) + transformed = transformer.fit_transform(X, n_classes=2) + inv_transformed = transformer.inverse_transform(transformed) + return transformed, inv_transformed + + X = jnp.array([1, -1, -1, 1]) + Y = jnp.array([1, 6]) + + spu_transformed, spu_inv_transformed = emulator.run(labelbinarize)(X, Y) + # print("result\n", spu_transformed) + # print("result\n", spu_inv_transformed) + + transformer = preprocessing.LabelBinarizer() + sk_transformed = transformer.fit_transform(X) + sk_inv_transformed = transformer.inverse_transform(sk_transformed) + # print("sklearn:\n", sk_transformed) + # print("sklearn:\n", sk_inv_transformed) + np.testing.assert_allclose(sk_transformed, spu_transformed, rtol=0, atol=0) + np.testing.assert_allclose(sk_inv_transformed, spu_inv_transformed, rtol=0, atol=0) + +def emul_labelbinarizer_unseen(): + def labelbinarize(X, Y): + transformer = LabelBinarizer() + transformer.fit(X, n_classes=3) + return transformer.transform(Y, unseen = True) + + X = jnp.array([2, 4, 5]) + Y = jnp.array([1, 2, 3, 4, 5, 6]) + + spu_result = emulator.run(labelbinarize)(X, Y) + # print("result\n", spu_result) + + transformer = preprocessing.LabelBinarizer() + transformer.fit(X) + sk_result = transformer.transform(Y) + # print("sklearn:\n", sk_result) + np.testing.assert_allclose(sk_result, spu_result, rtol=0, atol=0) + +def emul_binarizer(): + def binarize(X): + transformer = Binarizer() + return transformer.transform(X) + + X = jnp.array([[ 1., -1., 2.], + [ 2., 0., 0.], + [ 0., 1., -1.]]) + + spu_result = emulator.run(binarize)(X) + # print("result\n", spu_result) + + transformer = preprocessing.Binarizer() + sk_result = transformer.transform(X) + # print("sklearn:\n", sk_result) + np.testing.assert_allclose(sk_result, spu_result, rtol=0, atol=0) + +def emul_normalizer(): + def normalize_l1(X): + transformer = Normalizer(norm="l1") + return transformer.transform(X) + + def normalize_l2(X): + transformer = Normalizer() + return transformer.transform(X) + + def normalize_max(X): + transformer = Normalizer(norm="max") + return transformer.transform(X) + + X = jnp.array([[4, 1, 2, 2], + [1, 3, 9, 3], + [5, 7, 5, 1]]) + + spu_result_l1 = emulator.run(normalize_l1)(X) + spu_result_l2 = emulator.run(normalize_l2)(X) + spu_result_max = emulator.run(normalize_max)(X) + # print("result\n", spu_result_l1) + # print("result\n", spu_result_l2) + # print("result\n", spu_result_max) + + transformer_l1 = preprocessing.Normalizer(norm="l1") + sk_result_l1 = transformer_l1.transform(X) + transformer_l2 = preprocessing.Normalizer() + sk_result_l2 = transformer_l2.transform(X) + transformer_max = preprocessing.Normalizer(norm="max") + sk_result_max = transformer_max.transform(X) + # print("sklearn:\n", sk_result_l1) + # print("sklearn:\n", sk_result_l2) + # print("sklearn:\n", sk_result_max) + np.testing.assert_allclose(sk_result_l1, spu_result_l1, rtol=0, atol=1e-4) + np.testing.assert_allclose(sk_result_l2, spu_result_l2, rtol=0, atol=1e-4) + np.testing.assert_allclose(sk_result_max, spu_result_max, rtol=0, atol=1e-4) + +if __name__ == "__main__": + try: + # bandwidth and latency only work for docker mode + emulator = emulation.Emulator( + emulation.CLUSTER_ABY3_3PC, + emulation.Mode.MULTIPROCESS, + bandwidth=300, + latency=20, + ) + emulator.up() + emul_labelbinarizer() + emul_labelbinarizer_binary() + emul_labelbinarizer_unseen() + emul_binarizer() + emul_normalizer() + finally: + emulator.down() \ No newline at end of file diff --git a/sml/preprocessing/preprocessing.py b/sml/preprocessing/preprocessing.py new file mode 100755 index 00000000..37d9918d --- /dev/null +++ b/sml/preprocessing/preprocessing.py @@ -0,0 +1,296 @@ +# Copyright 2023 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from posixpath import normcase +import jax.numpy as jnp +import jax + +def label_binarize(y, *, classes, n_classes, neg_label=0, pos_label=1, unseen = False): + """Binarize labels in a one-vs-all fashion. + + Parameters + ---------- + y : {array-like}, shape (n_samples,) + Input data. + + classes : {array-like}, shape (n_classes,) + Uniquely holds the label for each class. + + n_classes : int + Number of classes. SPU cannot support dynamic shape, + so this parameter needs to be designated. + + neg_label : int, default=0 + Value with which negative labels must be encoded. + + pos_label : int, default=1 + Value with which positive labels must be encoded. + + unseen : bool, default=False + True if the input array contains the classes that are unseen in the fit phase. + + Returns + ------- + ndarray of shape (n_samples, n_classes) + Shape will be (n_samples, 1) for binary problems. + """ + n_samples = y.shape[0] + indices = jnp.searchsorted(classes, y) + result = jax.nn.one_hot(indices, n_classes, dtype = jnp.int_) + if unseen == True: + indices = jnp.searchsorted(classes, y) + emptylike = jnp.full((n_samples, n_classes), 0) + boolean = jnp.tile(jnp.isin(y, classes)[:, jnp.newaxis], (1, 3)) + result = jax.lax.select(boolean, result, emptylike) + + if neg_label != 0 or pos_label != 1: + result = jnp.where(result, pos_label, neg_label) + + if n_classes == 2: + result = result[:, -1].reshape((-1, 1)) + return result + +def _inverse_binarize_multiclass(y, classes): + """Inverse label binarization transformation for multiclass. + + Multiclass uses the maximal score instead of a threshold. + """ + return jnp.take(classes, y.argmax(axis=1), mode="clip") + +def _inverse_binarize_thresholding(y, classes, threshold): + """Inverse label binarization transformation using thresholding.""" + y = jnp.array(y > threshold, dtype=int) + return classes[y[:, 1]] + +class LabelBinarizer(): + """Binarize labels in a one-vs-all fashion. + + Firstly, use fit() to use an array to set the classes. + The number of classes needs to be designated through parameter n_classes since SPU cannot support dynamic shape. + + Secondly, use transform() to convert the value to a one-hot label for classes. + The input array needs to be 1d. In sklearn, the input array is automatically transformed into 1d, + so more dimension seems to be meaningless. To avoid redundant operations used in MPC implementation, + the automated transformation is canceled. Users can directly use the transformation method like jax.ravel to + transform the input array into 1d then use LabelBinarizer to do further transformation. + In sklearn, transform can accept an array containing the classes not seen in the fit phase. + This function is not supported by default, since many additional operations will be used. + Set the parameter unseen to be true to activate the function. + + Parameters + ---------- + neg_label : int, default=0 + Value with which negative labels must be encoded. + + pos_label : int, default=1 + Value with which positive labels must be encoded. + + """ + + def __init__(self, *, neg_label=0, pos_label=1): + self.neg_label = neg_label + self.pos_label = pos_label + + def fit(self, y, n_classes): + """Fit label binarizer. + + Parameters + ---------- + y : {array-like}, shape (n_samples,) + Input data. + + n_classes : int + Number of classes. SPU cannot support dynamic shape, + so this parameter needs to be designated. + + Returns + ------- + self : object + Returns the instance itself. + """ + if self.neg_label >= self.pos_label: + raise ValueError( + f"neg_label={self.neg_label} must be strictly less than " + f"pos_label={self.pos_label}." + ) + # The output of jax needs to be tensor with known size. + self.classes_ = jnp.unique(y, size = n_classes) + self.n_classes_ = n_classes + return self + + def fit_transform(self, y, n_classes, *, unseen = False): + """Fit label binarizer/transform multi-class labels to binary labels. + + Parameters + ---------- + y : {array-like}, shape (n_samples,) + Input data. + + n_classes : int + Number of classes. SPU cannot support dynamic shape, + so this parameter needs to be designated. + + unseen : bool, default=False + True if the input array contains the classes that are unseen in the fit phase. + + Returns + ------- + ndarray of shape (n_samples, n_classes) + Shape will be (n_samples, 1) for binary problems. + """ + return self.fit(y, n_classes).transform(y, unseen = unseen) + + def transform(self, y, *, unseen = False): + """Transform multi-class labels to binary labels. + Parameters + ---------- + y : {array-like}, shape (n_samples,) + Input data. + + unseen : bool, default=False + True if the input array contains the classes that are unseen in the fit phase. + + Returns + ------- + ndarray of shape (n_samples, n_classes) + Shape will be (n_samples, 1) for binary problems. + """ + return label_binarize( + y, + classes=self.classes_, + n_classes=self.n_classes_, + neg_label=self.neg_label, + pos_label=self.pos_label, + unseen = unseen + ) + + def inverse_transform(self, Y, threshold=None): + """Transform binary labels back to multi-class labels. + + Parameters + ---------- + Y : {array-like}, shape (n_samples, n_classes) + Input data. + + threshold : float, default=None + Threshold used in the binary cases. + + Returns + ------- + ndarray of shape (n_samples,) + + """ + if threshold is None: + threshold = (self.pos_label + self.neg_label) / 2.0 + if self.n_classes_ == 2: + y_inv = _inverse_binarize_thresholding(Y, self.classes_, threshold) + else: + y_inv = _inverse_binarize_multiclass(Y, self.classes_) + return y_inv + + +def binarize(X, *, threshold=0.0): + """Binarize data (set feature values to 0 or 1) according to a threshold. + + Parameters + ---------- + threshold : float, default=0.0 + Feature values below or equal to this are replaced by 0, above it by 1. + + """ + return jnp.where(X > threshold, 1, 0) + +class Binarizer(): + """Binarize data (set feature values to 0 or 1) according to a threshold. + + Parameters + ---------- + threshold : float, default=0.0 + Feature values below or equal to this are replaced by 0, above it by 1. + + """ + def __init__(self, *, threshold=0.0): + self.threshold = threshold + + def transform(self, X, copy=None): + """Binarize each element of X. + + Parameters + ---------- + X : {array-like} of shape (n_samples, n_features) + The data to binarize, element by element. + + Returns + ------- + ndarray of shape (n_samples, n_features) + Transformed array. + """ + return binarize(X, threshold=self.threshold) + +def normalize(X, norm="l2"): + """Scale input vectors individually to unit norm (vector length). + + Parameters + ---------- + X : {array-like} of shape (n_samples, n_features) + The data to normalize, element by element. + + norm : {'l1', 'l2', 'max'}, default='l2' + The norm to use to normalize each non zero sample (or each non-zero + feature if axis is 0). + + Returns + ------- + ndarray of shape (n_samples, n_features) + Transformed array. + """ + if norm == "l1": + norms = jnp.abs(X).sum(axis=1) + return X / norms[:, jnp.newaxis] + elif norm == "l2": + norms = jnp.einsum("ij,ij->i", X, X) + norms = norms.astype(jnp.float32) + # Use rsqrt instead of using combination of reciprocal and square for optimization + return X * jax.lax.rsqrt(norms)[:, jnp.newaxis] + elif norm == "max": + norms = jnp.max(abs(X), axis=1) + return X / norms[:, jnp.newaxis] + +class Normalizer(): + """Normalize samples individually to unit norm. + + Parameters + ---------- + norm : {'l1', 'l2', 'max'}, default='l2' + The norm to use to normalize each non zero sample. If norm='max' + is used, values will be rescaled by the maximum of the absolute + values. + """ + def __init__(self, norm="l2"): + self.norm = norm + + def transform(self, X): + """Scale each non zero row of X to unit norm. + + Parameters + ---------- + X : {array-like} of shape (n_samples, n_features) + The data to normalize, row by row. + + Returns + ------- + ndarray of shape (n_samples, n_features) + Transformed array. + """ + return normalize(X, norm=self.norm) diff --git a/sml/preprocessing/tests/BUILD.bazel b/sml/preprocessing/tests/BUILD.bazel new file mode 100755 index 00000000..1864b4fa --- /dev/null +++ b/sml/preprocessing/tests/BUILD.bazel @@ -0,0 +1,27 @@ +# Copyright 2023 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_python//python:defs.bzl", "py_test") + +package(default_visibility = ["//visibility:public"]) + +py_test( + name = "preprocessing_test", + srcs = ["preprocessing_test.py"], + deps = [ + "//sml/preprocessing:preprocessing", + "//spu:init", + "//spu/utils:simulation", + ], +) diff --git a/sml/preprocessing/tests/preprocessing_test.py b/sml/preprocessing/tests/preprocessing_test.py new file mode 100755 index 00000000..4d6d3ae1 --- /dev/null +++ b/sml/preprocessing/tests/preprocessing_test.py @@ -0,0 +1,167 @@ +# Copyright 2023 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import jax.numpy as jnp +from sklearn import preprocessing +import numpy as np + +import spu.spu_pb2 as spu_pb2 +import spu.utils.simulation as spsim + +from sml.preprocessing.preprocessing import LabelBinarizer, Binarizer, Normalizer + +class UnitTests(unittest.TestCase): + def test_labelbinarizer(self): + sim = spsim.Simulator.simple( + 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 + ) + + def labelbinarize(X, Y): + transformer = LabelBinarizer(neg_label=-2, pos_label=3) + transformer.fit(X, n_classes=4) + transformed = transformer.transform(Y) + inv_transformed = transformer.inverse_transform(transformed) + return transformed, inv_transformed + + X = jnp.array([1, 2, 6, 4, 2]) + Y = jnp.array([1, 6]) + + spu_transformed, spu_inv_transformed = spsim.sim_jax(sim, labelbinarize)(X, Y) + # print("result\n", spu_transformed) + # print("result\n", spu_inv_transformed) + + transformer = preprocessing.LabelBinarizer(neg_label=-2, pos_label=3) + transformer.fit(X) + sk_transformed = transformer.transform(Y) + sk_inv_transformed = transformer.inverse_transform(sk_transformed) + # print("sklearn:\n", sk_transformed) + # print("sklearn:\n", sk_inv_transformed) + np.testing.assert_allclose(sk_transformed, spu_transformed, rtol=0, atol=0) + np.testing.assert_allclose(sk_inv_transformed, spu_inv_transformed, rtol=0, atol=0) + + def test_labelbinarizer_binary(self): + sim = spsim.Simulator.simple( + 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 + ) + + def labelbinarize(X, Y): + transformer = LabelBinarizer() + # transformer.fit(X, n_classes=4) + transformed = transformer.fit_transform(X, n_classes=2) + inv_transformed = transformer.inverse_transform(transformed) + return transformed, inv_transformed + + X = jnp.array([1, -1, -1, 1]) + Y = jnp.array([1, 6]) + + spu_transformed, spu_inv_transformed = spsim.sim_jax(sim, labelbinarize)(X, Y) + # print("result\n", spu_transformed) + # print("result\n", spu_inv_transformed) + + transformer = preprocessing.LabelBinarizer() + sk_transformed = transformer.fit_transform(X) + sk_inv_transformed = transformer.inverse_transform(sk_transformed) + # print("sklearn:\n", sk_transformed) + # print("sklearn:\n", sk_inv_transformed) + np.testing.assert_allclose(sk_transformed, spu_transformed, rtol=0, atol=0) + np.testing.assert_allclose(sk_inv_transformed, spu_inv_transformed, rtol=0, atol=0) + + def test_labelbinarizer_unseen(self): + sim = spsim.Simulator.simple( + 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 + ) + + def labelbinarize(X, Y): + transformer = LabelBinarizer() + transformer.fit(X, n_classes=3) + return transformer.transform(Y, unseen = True) + + X = jnp.array([2, 4, 5]) + Y = jnp.array([1, 2, 3, 4, 5, 6]) + + spu_result = spsim.sim_jax(sim, labelbinarize)(X, Y) + # print("result\n", spu_result) + + transformer = preprocessing.LabelBinarizer() + transformer.fit(X) + sk_result = transformer.transform(Y) + # print("sklearn:\n", sk_result) + np.testing.assert_allclose(sk_result, spu_result, rtol=0, atol=0) + + def test_binarizer(self): + sim = spsim.Simulator.simple( + 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 + ) + + def binarize(X): + transformer = Binarizer() + return transformer.transform(X) + + X = jnp.array([[ 1., -1., 2.], + [ 2., 0., 0.], + [ 0., 1., -1.]]) + + spu_result = spsim.sim_jax(sim, binarize)(X) + # print("result\n", spu_result) + + transformer = preprocessing.Binarizer() + sk_result = transformer.transform(X) + # print("sklearn:\n", sk_result) + np.testing.assert_allclose(sk_result, spu_result, rtol=0, atol=0) + + def test_normalizer(self): + sim = spsim.Simulator.simple( + 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 + ) + + def normalize_l1(X): + transformer = Normalizer(norm="l1") + return transformer.transform(X) + + def normalize_l2(X): + transformer = Normalizer() + return transformer.transform(X) + + def normalize_max(X): + transformer = Normalizer(norm="max") + return transformer.transform(X) + + X = jnp.array([[4, 1, 2, 2], + [1, 3, 9, 3], + [5, 7, 5, 1]]) + + spu_result_l1 = spsim.sim_jax(sim, normalize_l1)(X) + spu_result_l2 = spsim.sim_jax(sim, normalize_l2)(X) + spu_result_max = spsim.sim_jax(sim, normalize_max)(X) + # print("result\n", spu_result_l1) + # print("result\n", spu_result_l2) + # print("result\n", spu_result_max) + + transformer_l1 = preprocessing.Normalizer(norm="l1") + sk_result_l1 = transformer_l1.transform(X) + transformer_l2 = preprocessing.Normalizer() + sk_result_l2 = transformer_l2.transform(X) + transformer_max = preprocessing.Normalizer(norm="max") + sk_result_max = transformer_max.transform(X) + # print("sklearn:\n", sk_result_l1) + # print("sklearn:\n", sk_result_l2) + # print("sklearn:\n", sk_result_max) + np.testing.assert_allclose(sk_result_l1, spu_result_l1, rtol=0, atol=1e-4) + np.testing.assert_allclose(sk_result_l2, spu_result_l2, rtol=0, atol=1e-4) + np.testing.assert_allclose(sk_result_max, spu_result_max, rtol=0, atol=1e-4) + +if __name__ == "__main__": + unittest.main() \ No newline at end of file From 4c6ac69c8e198d71482970739c0e700d0c2ec86d Mon Sep 17 00:00:00 2001 From: winnylyc Date: Thu, 11 Jan 2024 10:09:11 +0000 Subject: [PATCH 02/10] format --- .../emulations/preprocessing_emul.py | 33 ++++----- sml/preprocessing/preprocessing.py | 74 +++++++++++-------- sml/preprocessing/tests/preprocessing_test.py | 35 ++++----- 3 files changed, 76 insertions(+), 66 deletions(-) diff --git a/sml/preprocessing/emulations/preprocessing_emul.py b/sml/preprocessing/emulations/preprocessing_emul.py index 82feff16..d72524ee 100755 --- a/sml/preprocessing/emulations/preprocessing_emul.py +++ b/sml/preprocessing/emulations/preprocessing_emul.py @@ -12,14 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest - import jax.numpy as jnp -from sklearn import preprocessing import numpy as np +from sklearn import preprocessing import sml.utils.emulation as emulation -from sml.preprocessing.preprocessing import LabelBinarizer, Binarizer, Normalizer +from sml.preprocessing.preprocessing import Binarizer, LabelBinarizer, Normalizer + def emul_labelbinarizer(): def labelbinarize(X, Y): @@ -45,10 +44,10 @@ def labelbinarize(X, Y): np.testing.assert_allclose(sk_transformed, spu_transformed, rtol=0, atol=0) np.testing.assert_allclose(sk_inv_transformed, spu_inv_transformed, rtol=0, atol=0) + def emul_labelbinarizer_binary(): def labelbinarize(X, Y): transformer = LabelBinarizer() - # transformer.fit(X, n_classes=4) transformed = transformer.fit_transform(X, n_classes=2) inv_transformed = transformer.inverse_transform(transformed) return transformed, inv_transformed @@ -68,11 +67,12 @@ def labelbinarize(X, Y): np.testing.assert_allclose(sk_transformed, spu_transformed, rtol=0, atol=0) np.testing.assert_allclose(sk_inv_transformed, spu_inv_transformed, rtol=0, atol=0) + def emul_labelbinarizer_unseen(): def labelbinarize(X, Y): - transformer = LabelBinarizer() - transformer.fit(X, n_classes=3) - return transformer.transform(Y, unseen = True) + transformer = LabelBinarizer() + transformer.fit(X, n_classes=3) + return transformer.transform(Y, unseen=True) X = jnp.array([2, 4, 5]) Y = jnp.array([1, 2, 3, 4, 5, 6]) @@ -86,14 +86,13 @@ def labelbinarize(X, Y): # print("sklearn:\n", sk_result) np.testing.assert_allclose(sk_result, spu_result, rtol=0, atol=0) + def emul_binarizer(): def binarize(X): transformer = Binarizer() return transformer.transform(X) - X = jnp.array([[ 1., -1., 2.], - [ 2., 0., 0.], - [ 0., 1., -1.]]) + X = jnp.array([[1.0, -1.0, 2.0], [2.0, 0.0, 0.0], [0.0, 1.0, -1.0]]) spu_result = emulator.run(binarize)(X) # print("result\n", spu_result) @@ -103,22 +102,21 @@ def binarize(X): # print("sklearn:\n", sk_result) np.testing.assert_allclose(sk_result, spu_result, rtol=0, atol=0) + def emul_normalizer(): def normalize_l1(X): transformer = Normalizer(norm="l1") return transformer.transform(X) - + def normalize_l2(X): transformer = Normalizer() return transformer.transform(X) - + def normalize_max(X): transformer = Normalizer(norm="max") return transformer.transform(X) - X = jnp.array([[4, 1, 2, 2], - [1, 3, 9, 3], - [5, 7, 5, 1]]) + X = jnp.array([[4, 1, 2, 2], [1, 3, 9, 3], [5, 7, 5, 1]]) spu_result_l1 = emulator.run(normalize_l1)(X) spu_result_l2 = emulator.run(normalize_l2)(X) @@ -140,6 +138,7 @@ def normalize_max(X): np.testing.assert_allclose(sk_result_l2, spu_result_l2, rtol=0, atol=1e-4) np.testing.assert_allclose(sk_result_max, spu_result_max, rtol=0, atol=1e-4) + if __name__ == "__main__": try: # bandwidth and latency only work for docker mode @@ -156,4 +155,4 @@ def normalize_max(X): emul_binarizer() emul_normalizer() finally: - emulator.down() \ No newline at end of file + emulator.down() diff --git a/sml/preprocessing/preprocessing.py b/sml/preprocessing/preprocessing.py index 37d9918d..feaf9ab0 100755 --- a/sml/preprocessing/preprocessing.py +++ b/sml/preprocessing/preprocessing.py @@ -13,10 +13,12 @@ # limitations under the License. from posixpath import normcase -import jax.numpy as jnp + import jax +import jax.numpy as jnp + -def label_binarize(y, *, classes, n_classes, neg_label=0, pos_label=1, unseen = False): +def label_binarize(y, *, classes, n_classes, neg_label=0, pos_label=1, unseen=False): """Binarize labels in a one-vs-all fashion. Parameters @@ -26,9 +28,9 @@ def label_binarize(y, *, classes, n_classes, neg_label=0, pos_label=1, unseen = classes : {array-like}, shape (n_classes,) Uniquely holds the label for each class. - + n_classes : int - Number of classes. SPU cannot support dynamic shape, + Number of classes. SPU cannot support dynamic shape, so this parameter needs to be designated. neg_label : int, default=0 @@ -47,13 +49,13 @@ def label_binarize(y, *, classes, n_classes, neg_label=0, pos_label=1, unseen = """ n_samples = y.shape[0] indices = jnp.searchsorted(classes, y) - result = jax.nn.one_hot(indices, n_classes, dtype = jnp.int_) + result = jax.nn.one_hot(indices, n_classes, dtype=jnp.int_) if unseen == True: indices = jnp.searchsorted(classes, y) emptylike = jnp.full((n_samples, n_classes), 0) boolean = jnp.tile(jnp.isin(y, classes)[:, jnp.newaxis], (1, 3)) result = jax.lax.select(boolean, result, emptylike) - + if neg_label != 0 or pos_label != 1: result = jnp.where(result, pos_label, neg_label) @@ -61,6 +63,7 @@ def label_binarize(y, *, classes, n_classes, neg_label=0, pos_label=1, unseen = result = result[:, -1].reshape((-1, 1)) return result + def _inverse_binarize_multiclass(y, classes): """Inverse label binarization transformation for multiclass. @@ -68,19 +71,21 @@ def _inverse_binarize_multiclass(y, classes): """ return jnp.take(classes, y.argmax(axis=1), mode="clip") + def _inverse_binarize_thresholding(y, classes, threshold): """Inverse label binarization transformation using thresholding.""" y = jnp.array(y > threshold, dtype=int) return classes[y[:, 1]] -class LabelBinarizer(): + +class LabelBinarizer: """Binarize labels in a one-vs-all fashion. Firstly, use fit() to use an array to set the classes. The number of classes needs to be designated through parameter n_classes since SPU cannot support dynamic shape. Secondly, use transform() to convert the value to a one-hot label for classes. - The input array needs to be 1d. In sklearn, the input array is automatically transformed into 1d, + The input array needs to be 1d. In sklearn, the input array is automatically transformed into 1d, so more dimension seems to be meaningless. To avoid redundant operations used in MPC implementation, the automated transformation is canceled. Users can directly use the transformation method like jax.ravel to transform the input array into 1d then use LabelBinarizer to do further transformation. @@ -101,7 +106,7 @@ class LabelBinarizer(): def __init__(self, *, neg_label=0, pos_label=1): self.neg_label = neg_label self.pos_label = pos_label - + def fit(self, y, n_classes): """Fit label binarizer. @@ -109,9 +114,9 @@ def fit(self, y, n_classes): ---------- y : {array-like}, shape (n_samples,) Input data. - + n_classes : int - Number of classes. SPU cannot support dynamic shape, + Number of classes. SPU cannot support dynamic shape, so this parameter needs to be designated. Returns @@ -124,23 +129,23 @@ def fit(self, y, n_classes): f"neg_label={self.neg_label} must be strictly less than " f"pos_label={self.pos_label}." ) - # The output of jax needs to be tensor with known size. - self.classes_ = jnp.unique(y, size = n_classes) + # The output of jax needs to be tensor with known size. + self.classes_ = jnp.unique(y, size=n_classes) self.n_classes_ = n_classes return self - - def fit_transform(self, y, n_classes, *, unseen = False): + + def fit_transform(self, y, n_classes, *, unseen=False): """Fit label binarizer/transform multi-class labels to binary labels. Parameters ---------- y : {array-like}, shape (n_samples,) Input data. - + n_classes : int - Number of classes. SPU cannot support dynamic shape, + Number of classes. SPU cannot support dynamic shape, so this parameter needs to be designated. - + unseen : bool, default=False True if the input array contains the classes that are unseen in the fit phase. @@ -149,15 +154,15 @@ def fit_transform(self, y, n_classes, *, unseen = False): ndarray of shape (n_samples, n_classes) Shape will be (n_samples, 1) for binary problems. """ - return self.fit(y, n_classes).transform(y, unseen = unseen) - - def transform(self, y, *, unseen = False): + return self.fit(y, n_classes).transform(y, unseen=unseen) + + def transform(self, y, *, unseen=False): """Transform multi-class labels to binary labels. Parameters ---------- y : {array-like}, shape (n_samples,) Input data. - + unseen : bool, default=False True if the input array contains the classes that are unseen in the fit phase. @@ -172,9 +177,9 @@ def transform(self, y, *, unseen = False): n_classes=self.n_classes_, neg_label=self.neg_label, pos_label=self.pos_label, - unseen = unseen + unseen=unseen, ) - + def inverse_transform(self, Y, threshold=None): """Transform binary labels back to multi-class labels. @@ -185,7 +190,7 @@ def inverse_transform(self, Y, threshold=None): threshold : float, default=None Threshold used in the binary cases. - + Returns ------- ndarray of shape (n_samples,) @@ -211,7 +216,8 @@ def binarize(X, *, threshold=0.0): """ return jnp.where(X > threshold, 1, 0) -class Binarizer(): + +class Binarizer: """Binarize data (set feature values to 0 or 1) according to a threshold. Parameters @@ -220,9 +226,10 @@ class Binarizer(): Feature values below or equal to this are replaced by 0, above it by 1. """ + def __init__(self, *, threshold=0.0): self.threshold = threshold - + def transform(self, X, copy=None): """Binarize each element of X. @@ -238,6 +245,7 @@ def transform(self, X, copy=None): """ return binarize(X, threshold=self.threshold) + def normalize(X, norm="l2"): """Scale input vectors individually to unit norm (vector length). @@ -257,7 +265,7 @@ def normalize(X, norm="l2"): """ if norm == "l1": norms = jnp.abs(X).sum(axis=1) - return X / norms[:, jnp.newaxis] + return X / norms[:, jnp.newaxis] elif norm == "l2": norms = jnp.einsum("ij,ij->i", X, X) norms = norms.astype(jnp.float32) @@ -265,9 +273,10 @@ def normalize(X, norm="l2"): return X * jax.lax.rsqrt(norms)[:, jnp.newaxis] elif norm == "max": norms = jnp.max(abs(X), axis=1) - return X / norms[:, jnp.newaxis] + return X / norms[:, jnp.newaxis] -class Normalizer(): + +class Normalizer: """Normalize samples individually to unit norm. Parameters @@ -277,16 +286,17 @@ class Normalizer(): is used, values will be rescaled by the maximum of the absolute values. """ + def __init__(self, norm="l2"): self.norm = norm - + def transform(self, X): """Scale each non zero row of X to unit norm. Parameters ---------- X : {array-like} of shape (n_samples, n_features) - The data to normalize, row by row. + The data to normalize, row by row. Returns ------- diff --git a/sml/preprocessing/tests/preprocessing_test.py b/sml/preprocessing/tests/preprocessing_test.py index 4d6d3ae1..d1ef1439 100755 --- a/sml/preprocessing/tests/preprocessing_test.py +++ b/sml/preprocessing/tests/preprocessing_test.py @@ -15,13 +15,13 @@ import unittest import jax.numpy as jnp -from sklearn import preprocessing import numpy as np +from sklearn import preprocessing import spu.spu_pb2 as spu_pb2 import spu.utils.simulation as spsim +from sml.preprocessing.preprocessing import Binarizer, LabelBinarizer, Normalizer -from sml.preprocessing.preprocessing import LabelBinarizer, Binarizer, Normalizer class UnitTests(unittest.TestCase): def test_labelbinarizer(self): @@ -50,8 +50,10 @@ def labelbinarize(X, Y): # print("sklearn:\n", sk_transformed) # print("sklearn:\n", sk_inv_transformed) np.testing.assert_allclose(sk_transformed, spu_transformed, rtol=0, atol=0) - np.testing.assert_allclose(sk_inv_transformed, spu_inv_transformed, rtol=0, atol=0) - + np.testing.assert_allclose( + sk_inv_transformed, spu_inv_transformed, rtol=0, atol=0 + ) + def test_labelbinarizer_binary(self): sim = spsim.Simulator.simple( 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 @@ -77,8 +79,10 @@ def labelbinarize(X, Y): # print("sklearn:\n", sk_transformed) # print("sklearn:\n", sk_inv_transformed) np.testing.assert_allclose(sk_transformed, spu_transformed, rtol=0, atol=0) - np.testing.assert_allclose(sk_inv_transformed, spu_inv_transformed, rtol=0, atol=0) - + np.testing.assert_allclose( + sk_inv_transformed, spu_inv_transformed, rtol=0, atol=0 + ) + def test_labelbinarizer_unseen(self): sim = spsim.Simulator.simple( 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 @@ -87,7 +91,7 @@ def test_labelbinarizer_unseen(self): def labelbinarize(X, Y): transformer = LabelBinarizer() transformer.fit(X, n_classes=3) - return transformer.transform(Y, unseen = True) + return transformer.transform(Y, unseen=True) X = jnp.array([2, 4, 5]) Y = jnp.array([1, 2, 3, 4, 5, 6]) @@ -110,9 +114,7 @@ def binarize(X): transformer = Binarizer() return transformer.transform(X) - X = jnp.array([[ 1., -1., 2.], - [ 2., 0., 0.], - [ 0., 1., -1.]]) + X = jnp.array([[1.0, -1.0, 2.0], [2.0, 0.0, 0.0], [0.0, 1.0, -1.0]]) spu_result = spsim.sim_jax(sim, binarize)(X) # print("result\n", spu_result) @@ -121,7 +123,7 @@ def binarize(X): sk_result = transformer.transform(X) # print("sklearn:\n", sk_result) np.testing.assert_allclose(sk_result, spu_result, rtol=0, atol=0) - + def test_normalizer(self): sim = spsim.Simulator.simple( 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 @@ -130,18 +132,16 @@ def test_normalizer(self): def normalize_l1(X): transformer = Normalizer(norm="l1") return transformer.transform(X) - + def normalize_l2(X): transformer = Normalizer() return transformer.transform(X) - + def normalize_max(X): transformer = Normalizer(norm="max") return transformer.transform(X) - X = jnp.array([[4, 1, 2, 2], - [1, 3, 9, 3], - [5, 7, 5, 1]]) + X = jnp.array([[4, 1, 2, 2], [1, 3, 9, 3], [5, 7, 5, 1]]) spu_result_l1 = spsim.sim_jax(sim, normalize_l1)(X) spu_result_l2 = spsim.sim_jax(sim, normalize_l2)(X) @@ -163,5 +163,6 @@ def normalize_max(X): np.testing.assert_allclose(sk_result_l2, spu_result_l2, rtol=0, atol=1e-4) np.testing.assert_allclose(sk_result_max, spu_result_max, rtol=0, atol=1e-4) + if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() From 235ab832837cbae4336b438c6fbb8c6b8a251287 Mon Sep 17 00:00:00 2001 From: winnylyc Date: Fri, 12 Jan 2024 00:24:26 +0000 Subject: [PATCH 03/10] buildifier --- sml/preprocessing/emulations/BUILD.bazel | 4 ++-- sml/preprocessing/tests/BUILD.bazel | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sml/preprocessing/emulations/BUILD.bazel b/sml/preprocessing/emulations/BUILD.bazel index 99c09f09..74820410 100755 --- a/sml/preprocessing/emulations/BUILD.bazel +++ b/sml/preprocessing/emulations/BUILD.bazel @@ -20,7 +20,7 @@ py_binary( name = "preprocessing_emul", srcs = ["preprocessing_emul.py"], deps = [ - "//sml/preprocessing:preprocessing", + "//sml/preprocessing", "//sml/utils:emulation", ], -) \ No newline at end of file +) diff --git a/sml/preprocessing/tests/BUILD.bazel b/sml/preprocessing/tests/BUILD.bazel index 1864b4fa..994ed985 100755 --- a/sml/preprocessing/tests/BUILD.bazel +++ b/sml/preprocessing/tests/BUILD.bazel @@ -20,7 +20,7 @@ py_test( name = "preprocessing_test", srcs = ["preprocessing_test.py"], deps = [ - "//sml/preprocessing:preprocessing", + "//sml/preprocessing", "//spu:init", "//spu/utils:simulation", ], From 96633ce59f6bfcf96017c2a0c243b10dd2b9e013 Mon Sep 17 00:00:00 2001 From: winnylyc Date: Sun, 14 Jan 2024 06:12:17 +0000 Subject: [PATCH 04/10] nocopy --- sml/preprocessing/preprocessing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sml/preprocessing/preprocessing.py b/sml/preprocessing/preprocessing.py index feaf9ab0..beb018d5 100755 --- a/sml/preprocessing/preprocessing.py +++ b/sml/preprocessing/preprocessing.py @@ -230,7 +230,7 @@ class Binarizer: def __init__(self, *, threshold=0.0): self.threshold = threshold - def transform(self, X, copy=None): + def transform(self, X): """Binarize each element of X. Parameters From c935474fc06f9dae68cdf3b79f1a3c5af25feb86 Mon Sep 17 00:00:00 2001 From: winnylyc Date: Sun, 14 Jan 2024 06:43:36 +0000 Subject: [PATCH 05/10] no searchsorted --- sml/preprocessing/preprocessing.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/sml/preprocessing/preprocessing.py b/sml/preprocessing/preprocessing.py index beb018d5..89344f95 100755 --- a/sml/preprocessing/preprocessing.py +++ b/sml/preprocessing/preprocessing.py @@ -47,14 +47,8 @@ def label_binarize(y, *, classes, n_classes, neg_label=0, pos_label=1, unseen=Fa ndarray of shape (n_samples, n_classes) Shape will be (n_samples, 1) for binary problems. """ - n_samples = y.shape[0] - indices = jnp.searchsorted(classes, y) - result = jax.nn.one_hot(indices, n_classes, dtype=jnp.int_) - if unseen == True: - indices = jnp.searchsorted(classes, y) - emptylike = jnp.full((n_samples, n_classes), 0) - boolean = jnp.tile(jnp.isin(y, classes)[:, jnp.newaxis], (1, 3)) - result = jax.lax.select(boolean, result, emptylike) + eq_func = lambda x: jnp.where(classes == x, 1, 0) + result = jax.vmap(eq_func)(y) if neg_label != 0 or pos_label != 1: result = jnp.where(result, pos_label, neg_label) From 121416a3ab8225e850bedd16333f3ba76258b8b1 Mon Sep 17 00:00:00 2001 From: winnylyc Date: Sun, 14 Jan 2024 07:07:45 +0000 Subject: [PATCH 06/10] no unseen and add seal to emul --- .../emulations/preprocessing_emul.py | 63 +++++++++++-------- sml/preprocessing/preprocessing.py | 28 +++------ sml/preprocessing/tests/preprocessing_test.py | 42 +++++++------ 3 files changed, 69 insertions(+), 64 deletions(-) diff --git a/sml/preprocessing/emulations/preprocessing_emul.py b/sml/preprocessing/emulations/preprocessing_emul.py index d72524ee..3b4d0916 100755 --- a/sml/preprocessing/emulations/preprocessing_emul.py +++ b/sml/preprocessing/emulations/preprocessing_emul.py @@ -31,16 +31,18 @@ def labelbinarize(X, Y): X = jnp.array([1, 2, 6, 4, 2]) Y = jnp.array([1, 6]) - spu_transformed, spu_inv_transformed = emulator.run(labelbinarize)(X, Y) - # print("result\n", spu_transformed) - # print("result\n", spu_inv_transformed) - transformer = preprocessing.LabelBinarizer(neg_label=-2, pos_label=3) transformer.fit(X) sk_transformed = transformer.transform(Y) sk_inv_transformed = transformer.inverse_transform(sk_transformed) # print("sklearn:\n", sk_transformed) # print("sklearn:\n", sk_inv_transformed) + + X, Y = emulator.seal(X, Y) + spu_transformed, spu_inv_transformed = emulator.run(labelbinarize)(X, Y) + # print("spu:\n", spu_transformed) + # print("spu:\n", spu_inv_transformed) + np.testing.assert_allclose(sk_transformed, spu_transformed, rtol=0, atol=0) np.testing.assert_allclose(sk_inv_transformed, spu_inv_transformed, rtol=0, atol=0) @@ -54,16 +56,17 @@ def labelbinarize(X, Y): X = jnp.array([1, -1, -1, 1]) Y = jnp.array([1, 6]) - - spu_transformed, spu_inv_transformed = emulator.run(labelbinarize)(X, Y) - # print("result\n", spu_transformed) - # print("result\n", spu_inv_transformed) - transformer = preprocessing.LabelBinarizer() sk_transformed = transformer.fit_transform(X) sk_inv_transformed = transformer.inverse_transform(sk_transformed) # print("sklearn:\n", sk_transformed) # print("sklearn:\n", sk_inv_transformed) + + X, Y = emulator.seal(X, Y) + spu_transformed, spu_inv_transformed = emulator.run(labelbinarize)(X, Y) + # print("spu:\n", spu_transformed) + # print("spu:\n", spu_inv_transformed) + np.testing.assert_allclose(sk_transformed, spu_transformed, rtol=0, atol=0) np.testing.assert_allclose(sk_inv_transformed, spu_inv_transformed, rtol=0, atol=0) @@ -72,18 +75,20 @@ def emul_labelbinarizer_unseen(): def labelbinarize(X, Y): transformer = LabelBinarizer() transformer.fit(X, n_classes=3) - return transformer.transform(Y, unseen=True) + return transformer.transform(Y) X = jnp.array([2, 4, 5]) Y = jnp.array([1, 2, 3, 4, 5, 6]) - spu_result = emulator.run(labelbinarize)(X, Y) - # print("result\n", spu_result) - transformer = preprocessing.LabelBinarizer() transformer.fit(X) sk_result = transformer.transform(Y) # print("sklearn:\n", sk_result) + + X, Y = emulator.seal(X, Y) + spu_result = emulator.run(labelbinarize)(X, Y) + # print("spu:\n", spu_result) + np.testing.assert_allclose(sk_result, spu_result, rtol=0, atol=0) @@ -94,12 +99,14 @@ def binarize(X): X = jnp.array([[1.0, -1.0, 2.0], [2.0, 0.0, 0.0], [0.0, 1.0, -1.0]]) - spu_result = emulator.run(binarize)(X) - # print("result\n", spu_result) - transformer = preprocessing.Binarizer() sk_result = transformer.transform(X) # print("sklearn:\n", sk_result) + + X = emulator.seal(X) + spu_result = emulator.run(binarize)(X) + # print("spu:\n", spu_result) + np.testing.assert_allclose(sk_result, spu_result, rtol=0, atol=0) @@ -118,13 +125,6 @@ def normalize_max(X): X = jnp.array([[4, 1, 2, 2], [1, 3, 9, 3], [5, 7, 5, 1]]) - spu_result_l1 = emulator.run(normalize_l1)(X) - spu_result_l2 = emulator.run(normalize_l2)(X) - spu_result_max = emulator.run(normalize_max)(X) - # print("result\n", spu_result_l1) - # print("result\n", spu_result_l2) - # print("result\n", spu_result_max) - transformer_l1 = preprocessing.Normalizer(norm="l1") sk_result_l1 = transformer_l1.transform(X) transformer_l2 = preprocessing.Normalizer() @@ -134,6 +134,15 @@ def normalize_max(X): # print("sklearn:\n", sk_result_l1) # print("sklearn:\n", sk_result_l2) # print("sklearn:\n", sk_result_max) + + X = emulator.seal(X) + spu_result_l1 = emulator.run(normalize_l1)(X) + spu_result_l2 = emulator.run(normalize_l2)(X) + spu_result_max = emulator.run(normalize_max)(X) + # print("spu:\n", spu_result_l1) + # print("spu:\n", spu_result_l2) + # print("spu:\n", spu_result_max) + np.testing.assert_allclose(sk_result_l1, spu_result_l1, rtol=0, atol=1e-4) np.testing.assert_allclose(sk_result_l2, spu_result_l2, rtol=0, atol=1e-4) np.testing.assert_allclose(sk_result_max, spu_result_max, rtol=0, atol=1e-4) @@ -150,9 +159,9 @@ def normalize_max(X): ) emulator.up() emul_labelbinarizer() - emul_labelbinarizer_binary() - emul_labelbinarizer_unseen() - emul_binarizer() - emul_normalizer() + # emul_labelbinarizer_binary() + # emul_labelbinarizer_unseen() + # emul_binarizer() + # emul_normalizer() finally: emulator.down() diff --git a/sml/preprocessing/preprocessing.py b/sml/preprocessing/preprocessing.py index 89344f95..e2deeac7 100755 --- a/sml/preprocessing/preprocessing.py +++ b/sml/preprocessing/preprocessing.py @@ -18,7 +18,7 @@ import jax.numpy as jnp -def label_binarize(y, *, classes, n_classes, neg_label=0, pos_label=1, unseen=False): +def label_binarize(y, *, classes, n_classes, neg_label=0, pos_label=1): """Binarize labels in a one-vs-all fashion. Parameters @@ -39,16 +39,16 @@ def label_binarize(y, *, classes, n_classes, neg_label=0, pos_label=1, unseen=Fa pos_label : int, default=1 Value with which positive labels must be encoded. - unseen : bool, default=False - True if the input array contains the classes that are unseen in the fit phase. - Returns ------- ndarray of shape (n_samples, n_classes) Shape will be (n_samples, 1) for binary problems. """ - eq_func = lambda x: jnp.where(classes == x, 1, 0) - result = jax.vmap(eq_func)(y) + n_samples = y.shape[0] + indices = jnp.searchsorted(classes, y) + result = jax.nn.one_hot(indices, n_classes, dtype=jnp.int_) + # eq_func = lambda x: jnp.where(classes == x, 1, 0) + # result = jax.vmap(eq_func)(y) if neg_label != 0 or pos_label != 1: result = jnp.where(result, pos_label, neg_label) @@ -83,9 +83,6 @@ class LabelBinarizer: so more dimension seems to be meaningless. To avoid redundant operations used in MPC implementation, the automated transformation is canceled. Users can directly use the transformation method like jax.ravel to transform the input array into 1d then use LabelBinarizer to do further transformation. - In sklearn, transform can accept an array containing the classes not seen in the fit phase. - This function is not supported by default, since many additional operations will be used. - Set the parameter unseen to be true to activate the function. Parameters ---------- @@ -128,7 +125,7 @@ def fit(self, y, n_classes): self.n_classes_ = n_classes return self - def fit_transform(self, y, n_classes, *, unseen=False): + def fit_transform(self, y, n_classes): """Fit label binarizer/transform multi-class labels to binary labels. Parameters @@ -140,26 +137,20 @@ def fit_transform(self, y, n_classes, *, unseen=False): Number of classes. SPU cannot support dynamic shape, so this parameter needs to be designated. - unseen : bool, default=False - True if the input array contains the classes that are unseen in the fit phase. - Returns ------- ndarray of shape (n_samples, n_classes) Shape will be (n_samples, 1) for binary problems. """ - return self.fit(y, n_classes).transform(y, unseen=unseen) + return self.fit(y, n_classes).transform(y) - def transform(self, y, *, unseen=False): + def transform(self, y): """Transform multi-class labels to binary labels. Parameters ---------- y : {array-like}, shape (n_samples,) Input data. - unseen : bool, default=False - True if the input array contains the classes that are unseen in the fit phase. - Returns ------- ndarray of shape (n_samples, n_classes) @@ -171,7 +162,6 @@ def transform(self, y, *, unseen=False): n_classes=self.n_classes_, neg_label=self.neg_label, pos_label=self.pos_label, - unseen=unseen, ) def inverse_transform(self, Y, threshold=None): diff --git a/sml/preprocessing/tests/preprocessing_test.py b/sml/preprocessing/tests/preprocessing_test.py index d1ef1439..e1197561 100755 --- a/sml/preprocessing/tests/preprocessing_test.py +++ b/sml/preprocessing/tests/preprocessing_test.py @@ -39,16 +39,17 @@ def labelbinarize(X, Y): X = jnp.array([1, 2, 6, 4, 2]) Y = jnp.array([1, 6]) - spu_transformed, spu_inv_transformed = spsim.sim_jax(sim, labelbinarize)(X, Y) - # print("result\n", spu_transformed) - # print("result\n", spu_inv_transformed) - transformer = preprocessing.LabelBinarizer(neg_label=-2, pos_label=3) transformer.fit(X) sk_transformed = transformer.transform(Y) sk_inv_transformed = transformer.inverse_transform(sk_transformed) # print("sklearn:\n", sk_transformed) # print("sklearn:\n", sk_inv_transformed) + + spu_transformed, spu_inv_transformed = spsim.sim_jax(sim, labelbinarize)(X, Y) + # print("result\n", spu_transformed) + # print("result\n", spu_inv_transformed) + np.testing.assert_allclose(sk_transformed, spu_transformed, rtol=0, atol=0) np.testing.assert_allclose( sk_inv_transformed, spu_inv_transformed, rtol=0, atol=0 @@ -69,15 +70,16 @@ def labelbinarize(X, Y): X = jnp.array([1, -1, -1, 1]) Y = jnp.array([1, 6]) - spu_transformed, spu_inv_transformed = spsim.sim_jax(sim, labelbinarize)(X, Y) - # print("result\n", spu_transformed) - # print("result\n", spu_inv_transformed) - transformer = preprocessing.LabelBinarizer() sk_transformed = transformer.fit_transform(X) sk_inv_transformed = transformer.inverse_transform(sk_transformed) # print("sklearn:\n", sk_transformed) # print("sklearn:\n", sk_inv_transformed) + + spu_transformed, spu_inv_transformed = spsim.sim_jax(sim, labelbinarize)(X, Y) + # print("result\n", spu_transformed) + # print("result\n", spu_inv_transformed) + np.testing.assert_allclose(sk_transformed, spu_transformed, rtol=0, atol=0) np.testing.assert_allclose( sk_inv_transformed, spu_inv_transformed, rtol=0, atol=0 @@ -91,18 +93,19 @@ def test_labelbinarizer_unseen(self): def labelbinarize(X, Y): transformer = LabelBinarizer() transformer.fit(X, n_classes=3) - return transformer.transform(Y, unseen=True) + return transformer.transform(Y) X = jnp.array([2, 4, 5]) Y = jnp.array([1, 2, 3, 4, 5, 6]) - spu_result = spsim.sim_jax(sim, labelbinarize)(X, Y) - # print("result\n", spu_result) - transformer = preprocessing.LabelBinarizer() transformer.fit(X) sk_result = transformer.transform(Y) # print("sklearn:\n", sk_result) + + spu_result = spsim.sim_jax(sim, labelbinarize)(X, Y) + # print("result\n", spu_result) + np.testing.assert_allclose(sk_result, spu_result, rtol=0, atol=0) def test_binarizer(self): @@ -116,12 +119,13 @@ def binarize(X): X = jnp.array([[1.0, -1.0, 2.0], [2.0, 0.0, 0.0], [0.0, 1.0, -1.0]]) - spu_result = spsim.sim_jax(sim, binarize)(X) - # print("result\n", spu_result) - transformer = preprocessing.Binarizer() sk_result = transformer.transform(X) # print("sklearn:\n", sk_result) + + spu_result = spsim.sim_jax(sim, binarize)(X) + # print("result\n", spu_result) + np.testing.assert_allclose(sk_result, spu_result, rtol=0, atol=0) def test_normalizer(self): @@ -146,9 +150,6 @@ def normalize_max(X): spu_result_l1 = spsim.sim_jax(sim, normalize_l1)(X) spu_result_l2 = spsim.sim_jax(sim, normalize_l2)(X) spu_result_max = spsim.sim_jax(sim, normalize_max)(X) - # print("result\n", spu_result_l1) - # print("result\n", spu_result_l2) - # print("result\n", spu_result_max) transformer_l1 = preprocessing.Normalizer(norm="l1") sk_result_l1 = transformer_l1.transform(X) @@ -159,6 +160,11 @@ def normalize_max(X): # print("sklearn:\n", sk_result_l1) # print("sklearn:\n", sk_result_l2) # print("sklearn:\n", sk_result_max) + + # print("result\n", spu_result_l1) + # print("result\n", spu_result_l2) + # print("result\n", spu_result_max) + np.testing.assert_allclose(sk_result_l1, spu_result_l1, rtol=0, atol=1e-4) np.testing.assert_allclose(sk_result_l2, spu_result_l2, rtol=0, atol=1e-4) np.testing.assert_allclose(sk_result_max, spu_result_max, rtol=0, atol=1e-4) From 0b373cd9d871d994c087d05f71be8cdfb47dea39 Mon Sep 17 00:00:00 2001 From: winnylyc Date: Sun, 14 Jan 2024 07:12:13 +0000 Subject: [PATCH 07/10] small fix --- sml/preprocessing/preprocessing.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/sml/preprocessing/preprocessing.py b/sml/preprocessing/preprocessing.py index e2deeac7..844c6b45 100755 --- a/sml/preprocessing/preprocessing.py +++ b/sml/preprocessing/preprocessing.py @@ -44,11 +44,8 @@ def label_binarize(y, *, classes, n_classes, neg_label=0, pos_label=1): ndarray of shape (n_samples, n_classes) Shape will be (n_samples, 1) for binary problems. """ - n_samples = y.shape[0] - indices = jnp.searchsorted(classes, y) - result = jax.nn.one_hot(indices, n_classes, dtype=jnp.int_) - # eq_func = lambda x: jnp.where(classes == x, 1, 0) - # result = jax.vmap(eq_func)(y) + eq_func = lambda x: jnp.where(classes == x, 1, 0) + result = jax.vmap(eq_func)(y) if neg_label != 0 or pos_label != 1: result = jnp.where(result, pos_label, neg_label) From 55ff9b7d4f0c6e0494d2fdaa7870bbf6068111ca Mon Sep 17 00:00:00 2001 From: winnylyc Date: Sun, 14 Jan 2024 08:14:37 +0000 Subject: [PATCH 08/10] default no unique --- .../emulations/preprocessing_emul.py | 19 ++++++++-------- sml/preprocessing/preprocessing.py | 22 ++++++++++++++----- sml/preprocessing/tests/preprocessing_test.py | 10 ++++----- 3 files changed, 30 insertions(+), 21 deletions(-) diff --git a/sml/preprocessing/emulations/preprocessing_emul.py b/sml/preprocessing/emulations/preprocessing_emul.py index 3b4d0916..f0996abf 100755 --- a/sml/preprocessing/emulations/preprocessing_emul.py +++ b/sml/preprocessing/emulations/preprocessing_emul.py @@ -28,7 +28,7 @@ def labelbinarize(X, Y): inv_transformed = transformer.inverse_transform(transformed) return transformed, inv_transformed - X = jnp.array([1, 2, 6, 4, 2]) + X = jnp.array([1, 2, 4, 6]) Y = jnp.array([1, 6]) transformer = preprocessing.LabelBinarizer(neg_label=-2, pos_label=3) @@ -48,22 +48,21 @@ def labelbinarize(X, Y): def emul_labelbinarizer_binary(): - def labelbinarize(X, Y): + def labelbinarize(X): transformer = LabelBinarizer() - transformed = transformer.fit_transform(X, n_classes=2) + transformed = transformer.fit_transform(X, n_classes=2, unique=False) inv_transformed = transformer.inverse_transform(transformed) return transformed, inv_transformed X = jnp.array([1, -1, -1, 1]) - Y = jnp.array([1, 6]) transformer = preprocessing.LabelBinarizer() sk_transformed = transformer.fit_transform(X) sk_inv_transformed = transformer.inverse_transform(sk_transformed) # print("sklearn:\n", sk_transformed) # print("sklearn:\n", sk_inv_transformed) - X, Y = emulator.seal(X, Y) - spu_transformed, spu_inv_transformed = emulator.run(labelbinarize)(X, Y) + X = emulator.seal(X) + spu_transformed, spu_inv_transformed = emulator.run(labelbinarize)(X) # print("spu:\n", spu_transformed) # print("spu:\n", spu_inv_transformed) @@ -159,9 +158,9 @@ def normalize_max(X): ) emulator.up() emul_labelbinarizer() - # emul_labelbinarizer_binary() - # emul_labelbinarizer_unseen() - # emul_binarizer() - # emul_normalizer() + emul_labelbinarizer_binary() + emul_labelbinarizer_unseen() + emul_binarizer() + emul_normalizer() finally: emulator.down() diff --git a/sml/preprocessing/preprocessing.py b/sml/preprocessing/preprocessing.py index 844c6b45..89f936a8 100755 --- a/sml/preprocessing/preprocessing.py +++ b/sml/preprocessing/preprocessing.py @@ -74,6 +74,9 @@ class LabelBinarizer: Firstly, use fit() to use an array to set the classes. The number of classes needs to be designated through parameter n_classes since SPU cannot support dynamic shape. + The dynamic shape problem occurs when there are duplicated elements in input of fit function. + The deduplication operation will cause complex computation, so it is not used by default. + Noted that if unique==True, the order of the classes will be kept instead of sorted. Secondly, use transform() to convert the value to a one-hot label for classes. The input array needs to be 1d. In sklearn, the input array is automatically transformed into 1d, @@ -95,7 +98,7 @@ def __init__(self, *, neg_label=0, pos_label=1): self.neg_label = neg_label self.pos_label = pos_label - def fit(self, y, n_classes): + def fit(self, y, n_classes, unique=True): """Fit label binarizer. Parameters @@ -106,6 +109,9 @@ def fit(self, y, n_classes): n_classes : int Number of classes. SPU cannot support dynamic shape, so this parameter needs to be designated. + + unique : bool + Set to False to do deduplication on classes Returns ------- @@ -117,12 +123,15 @@ def fit(self, y, n_classes): f"neg_label={self.neg_label} must be strictly less than " f"pos_label={self.pos_label}." ) - # The output of jax needs to be tensor with known size. - self.classes_ = jnp.unique(y, size=n_classes) + if unique==True: + self.classes_ = y + else: + # The output of jax needs to be tensor with known size. + self.classes_ = jnp.unique(y, size=n_classes) self.n_classes_ = n_classes return self - def fit_transform(self, y, n_classes): + def fit_transform(self, y, n_classes, unique=True): """Fit label binarizer/transform multi-class labels to binary labels. Parameters @@ -133,13 +142,16 @@ def fit_transform(self, y, n_classes): n_classes : int Number of classes. SPU cannot support dynamic shape, so this parameter needs to be designated. + + unique : bool + Set to False to do deduplication on classes Returns ------- ndarray of shape (n_samples, n_classes) Shape will be (n_samples, 1) for binary problems. """ - return self.fit(y, n_classes).transform(y) + return self.fit(y, n_classes, unique=unique).transform(y) def transform(self, y): """Transform multi-class labels to binary labels. diff --git a/sml/preprocessing/tests/preprocessing_test.py b/sml/preprocessing/tests/preprocessing_test.py index e1197561..b7e4bdf7 100755 --- a/sml/preprocessing/tests/preprocessing_test.py +++ b/sml/preprocessing/tests/preprocessing_test.py @@ -36,7 +36,7 @@ def labelbinarize(X, Y): inv_transformed = transformer.inverse_transform(transformed) return transformed, inv_transformed - X = jnp.array([1, 2, 6, 4, 2]) + X = jnp.array([1, 2, 4, 6]) Y = jnp.array([1, 6]) transformer = preprocessing.LabelBinarizer(neg_label=-2, pos_label=3) @@ -60,15 +60,13 @@ def test_labelbinarizer_binary(self): 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 ) - def labelbinarize(X, Y): + def labelbinarize(X): transformer = LabelBinarizer() - # transformer.fit(X, n_classes=4) - transformed = transformer.fit_transform(X, n_classes=2) + transformed = transformer.fit_transform(X, n_classes=2, unique=False) inv_transformed = transformer.inverse_transform(transformed) return transformed, inv_transformed X = jnp.array([1, -1, -1, 1]) - Y = jnp.array([1, 6]) transformer = preprocessing.LabelBinarizer() sk_transformed = transformer.fit_transform(X) @@ -76,7 +74,7 @@ def labelbinarize(X, Y): # print("sklearn:\n", sk_transformed) # print("sklearn:\n", sk_inv_transformed) - spu_transformed, spu_inv_transformed = spsim.sim_jax(sim, labelbinarize)(X, Y) + spu_transformed, spu_inv_transformed = spsim.sim_jax(sim, labelbinarize)(X) # print("result\n", spu_transformed) # print("result\n", spu_inv_transformed) From 97a0d4e0536ee50dd90d11031f0f3fde39e137ad Mon Sep 17 00:00:00 2001 From: winnylyc Date: Sun, 14 Jan 2024 08:32:07 +0000 Subject: [PATCH 09/10] change doc and format --- sml/preprocessing/emulations/preprocessing_emul.py | 2 +- sml/preprocessing/preprocessing.py | 12 +++++------- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/sml/preprocessing/emulations/preprocessing_emul.py b/sml/preprocessing/emulations/preprocessing_emul.py index f0996abf..1b326243 100755 --- a/sml/preprocessing/emulations/preprocessing_emul.py +++ b/sml/preprocessing/emulations/preprocessing_emul.py @@ -83,7 +83,7 @@ def labelbinarize(X, Y): transformer.fit(X) sk_result = transformer.transform(Y) # print("sklearn:\n", sk_result) - + X, Y = emulator.seal(X, Y) spu_result = emulator.run(labelbinarize)(X, Y) # print("spu:\n", spu_result) diff --git a/sml/preprocessing/preprocessing.py b/sml/preprocessing/preprocessing.py index 89f936a8..80b44b50 100755 --- a/sml/preprocessing/preprocessing.py +++ b/sml/preprocessing/preprocessing.py @@ -79,10 +79,8 @@ class LabelBinarizer: Noted that if unique==True, the order of the classes will be kept instead of sorted. Secondly, use transform() to convert the value to a one-hot label for classes. - The input array needs to be 1d. In sklearn, the input array is automatically transformed into 1d, - so more dimension seems to be meaningless. To avoid redundant operations used in MPC implementation, - the automated transformation is canceled. Users can directly use the transformation method like jax.ravel to - transform the input array into 1d then use LabelBinarizer to do further transformation. + The input array needs to be 1d. Users can directly use the transformation method like jax.ravel to transform + the input array into 1d then use LabelBinarizer to do further transformation. Parameters ---------- @@ -109,7 +107,7 @@ def fit(self, y, n_classes, unique=True): n_classes : int Number of classes. SPU cannot support dynamic shape, so this parameter needs to be designated. - + unique : bool Set to False to do deduplication on classes @@ -123,7 +121,7 @@ def fit(self, y, n_classes, unique=True): f"neg_label={self.neg_label} must be strictly less than " f"pos_label={self.pos_label}." ) - if unique==True: + if unique == True: self.classes_ = y else: # The output of jax needs to be tensor with known size. @@ -142,7 +140,7 @@ def fit_transform(self, y, n_classes, unique=True): n_classes : int Number of classes. SPU cannot support dynamic shape, so this parameter needs to be designated. - + unique : bool Set to False to do deduplication on classes From 437ca0c9a8169c4c2e120f19976fad6c8d3dd598 Mon Sep 17 00:00:00 2001 From: winnylyc Date: Mon, 15 Jan 2024 05:28:39 +0000 Subject: [PATCH 10/10] remove redundant import --- sml/preprocessing/preprocessing.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/sml/preprocessing/preprocessing.py b/sml/preprocessing/preprocessing.py index 80b44b50..a06d0968 100755 --- a/sml/preprocessing/preprocessing.py +++ b/sml/preprocessing/preprocessing.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from posixpath import normcase - import jax import jax.numpy as jnp