diff --git a/sml/preprocessing/BUILD.bazel b/sml/preprocessing/BUILD.bazel
new file mode 100755
index 00000000..ad322743
--- /dev/null
+++ b/sml/preprocessing/BUILD.bazel
@@ -0,0 +1,22 @@
+# Copyright 2023 Ant Group Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("@rules_python//python:defs.bzl", "py_library")
+
+package(default_visibility = ["//visibility:public"])
+
+py_library(
+    name = "preprocessing",
+    srcs = ["preprocessing.py"],  # LabelBinarizer, Binarizer and Normalizer
+)
diff --git a/sml/preprocessing/emulations/BUILD.bazel b/sml/preprocessing/emulations/BUILD.bazel
new file mode 100755
index 00000000..74820410
--- /dev/null
+++ b/sml/preprocessing/emulations/BUILD.bazel
@@ -0,0 +1,26 @@
+# Copyright 2023 Ant Group Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("@rules_python//python:defs.bzl", "py_binary")
+
+package(default_visibility = ["//visibility:public"])
+
+py_binary(
+    name = "preprocessing_emul",
+    srcs = ["preprocessing_emul.py"],
+    deps = [
+        "//sml/preprocessing",  # the script imports sml.preprocessing.preprocessing
+        "//sml/utils:emulation",  # the script imports sml.utils.emulation
+    ],
+)
diff --git a/sml/preprocessing/emulations/preprocessing_emul.py b/sml/preprocessing/emulations/preprocessing_emul.py
new file mode 100755
index 00000000..1b326243
--- /dev/null
+++ b/sml/preprocessing/emulations/preprocessing_emul.py
@@ -0,0 +1,166 @@
+# Copyright 2023 Ant Group Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import jax.numpy as jnp
+import numpy as np
+from sklearn import preprocessing
+
+import sml.utils.emulation as emulation
+from sml.preprocessing.preprocessing import Binarizer, LabelBinarizer, Normalizer
+
+
+def emul_labelbinarizer():
+    """Check multi-class LabelBinarizer (custom labels) against scikit-learn."""
+
+    def labelbinarize(X, Y):
+        transformer = LabelBinarizer(neg_label=-2, pos_label=3)
+        transformer.fit(X, n_classes=4)
+        transformed = transformer.transform(Y)
+        inv_transformed = transformer.inverse_transform(transformed)
+        return transformed, inv_transformed
+
+    X = jnp.array([1, 2, 4, 6])
+    Y = jnp.array([1, 6])
+
+    # Plaintext reference computed with scikit-learn.
+    transformer = preprocessing.LabelBinarizer(neg_label=-2, pos_label=3)
+    transformer.fit(X)
+    sk_transformed = transformer.transform(Y)
+    sk_inv_transformed = transformer.inverse_transform(sk_transformed)
+
+    # Seal the inputs and run the same steps through the emulator.
+    X, Y = emulator.seal(X, Y)
+    spu_transformed, spu_inv_transformed = emulator.run(labelbinarize)(X, Y)
+
+    np.testing.assert_allclose(sk_transformed, spu_transformed, rtol=0, atol=0)
+    np.testing.assert_allclose(sk_inv_transformed, spu_inv_transformed, rtol=0, atol=0)
+
+
+def emul_labelbinarizer_binary():
+    """Check binary LabelBinarizer with duplicated input against scikit-learn."""
+
+    def labelbinarize(X):
+        transformer = LabelBinarizer()
+        transformed = transformer.fit_transform(X, n_classes=2, unique=False)
+        inv_transformed = transformer.inverse_transform(transformed)
+        return transformed, inv_transformed
+
+    X = jnp.array([1, -1, -1, 1])
+    # Plaintext reference computed with scikit-learn.
+    transformer = preprocessing.LabelBinarizer()
+    sk_transformed = transformer.fit_transform(X)
+    sk_inv_transformed = transformer.inverse_transform(sk_transformed)
+
+    # Seal the input and run the same steps through the emulator.
+    X = emulator.seal(X)
+    spu_transformed, spu_inv_transformed = emulator.run(labelbinarize)(X)
+
+    np.testing.assert_allclose(sk_transformed, spu_transformed, rtol=0, atol=0)
+    np.testing.assert_allclose(sk_inv_transformed, spu_inv_transformed, rtol=0, atol=0)
+
+
+def emul_labelbinarizer_unseen():
+    """Check transform() on labels that were not seen during fit()."""
+
+    def labelbinarize(X, Y):
+        transformer = LabelBinarizer()
+        transformer.fit(X, n_classes=3)
+        return transformer.transform(Y)
+
+    X = jnp.array([2, 4, 5])
+    Y = jnp.array([1, 2, 3, 4, 5, 6])
+
+    transformer = preprocessing.LabelBinarizer()
+    transformer.fit(X)
+    sk_result = transformer.transform(Y)
+
+    X, Y = emulator.seal(X, Y)
+    spu_result = emulator.run(labelbinarize)(X, Y)
+
+    np.testing.assert_allclose(sk_result, spu_result, rtol=0, atol=0)
+
+
+def emul_binarizer():
+    """Check Binarizer with the default threshold against scikit-learn."""
+
+    def binarize(X):
+        transformer = Binarizer()
+        return transformer.transform(X)
+
+    X = jnp.array([[1.0, -1.0, 2.0], [2.0, 0.0, 0.0], [0.0, 1.0, -1.0]])
+
+    transformer = preprocessing.Binarizer()
+    sk_result = transformer.transform(X)
+
+    X = emulator.seal(X)
+    spu_result = emulator.run(binarize)(X)
+
+    np.testing.assert_allclose(sk_result, spu_result, rtol=0, atol=0)
+
+
+def emul_normalizer():
+    """Check Normalizer with l1, l2 and max norms against scikit-learn."""
+
+    def normalize_l1(X):
+        transformer = Normalizer(norm="l1")
+        return transformer.transform(X)
+
+    def normalize_l2(X):
+        transformer = Normalizer()
+        return transformer.transform(X)
+
+    def normalize_max(X):
+        transformer = Normalizer(norm="max")
+        return transformer.transform(X)
+
+    X = jnp.array([[4, 1, 2, 2], [1, 3, 9, 3], [5, 7, 5, 1]])
+
+    # Plaintext references computed with scikit-learn.
+    transformer_l1 = preprocessing.Normalizer(norm="l1")
+    sk_result_l1 = transformer_l1.transform(X)
+    transformer_l2 = preprocessing.Normalizer()
+    sk_result_l2 = transformer_l2.transform(X)
+    transformer_max = preprocessing.Normalizer(norm="max")
+    sk_result_max = transformer_max.transform(X)
+
+    # Seal the input and run the same steps through the emulator.
+    X = emulator.seal(X)
+    spu_result_l1 = emulator.run(normalize_l1)(X)
+    spu_result_l2 = emulator.run(normalize_l2)(X)
+    spu_result_max = emulator.run(normalize_max)(X)
+
+    np.testing.assert_allclose(sk_result_l1, spu_result_l1, rtol=0, atol=1e-4)
+    np.testing.assert_allclose(sk_result_l2, spu_result_l2, rtol=0, atol=1e-4)
+    np.testing.assert_allclose(sk_result_max, spu_result_max, rtol=0, atol=1e-4)
+
+
+if __name__ == "__main__":
+    # bandwidth and latency only work for docker mode
+    emulator = emulation.Emulator(
+        emulation.CLUSTER_ABY3_3PC,
+        emulation.Mode.MULTIPROCESS,
+        bandwidth=300,
+        latency=20,
+    )
+    # Construct the emulator before entering the try block so that the
+    # cleanup in finally never refers to an unbound name.
+    try:
+        emulator.up()
+        emul_labelbinarizer()
+        emul_labelbinarizer_binary()
+        emul_labelbinarizer_unseen()
+        emul_binarizer()
+        emul_normalizer()
+    finally:
+        emulator.down()
diff --git a/sml/preprocessing/preprocessing.py b/sml/preprocessing/preprocessing.py
new file mode 100755
index 00000000..a06d0968
--- /dev/null
+++ b/sml/preprocessing/preprocessing.py
@@ -0,0 +1,328 @@
+# Copyright 2023 Ant Group Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import jax
+import jax.numpy as jnp
+
+
+def label_binarize(y, *, classes, n_classes, neg_label=0, pos_label=1):
+    """Binarize labels in a one-vs-all fashion.
+
+    Parameters
+    ----------
+    y : {array-like}, shape (n_samples,)
+        Input data.
+
+    classes : {array-like}, shape (n_classes,)
+        Uniquely holds the label for each class.
+
+    n_classes : int
+        Number of classes. SPU cannot support dynamic shape,
+        so this parameter needs to be designated.
+
+    neg_label : int, default=0
+        Value with which negative labels must be encoded.
+
+    pos_label : int, default=1
+        Value with which positive labels must be encoded.
+
+    Returns
+    -------
+    ndarray of shape (n_samples, n_classes)
+        Shape will be (n_samples, 1) for binary problems.
+    """
+
+    def one_hot_row(label):
+        # 1 where the label equals a class, 0 elsewhere; a label that matches
+        # no class yields an all-zero row.
+        return jnp.where(classes == label, 1, 0)
+
+    result = jax.vmap(one_hot_row)(y)
+
+    if neg_label != 0 or pos_label != 1:
+        result = jnp.where(result, pos_label, neg_label)
+
+    if n_classes == 2:
+        # Binary problems keep only the column of the last class.
+        result = result[:, -1].reshape((-1, 1))
+    return result
+
+
+def _inverse_binarize_multiclass(y, classes):
+    """Inverse label binarization transformation for multiclass.
+
+    Multiclass uses the maximal score instead of a threshold.
+    """
+    return jnp.take(classes, y.argmax(axis=1), mode="clip")
+
+
+def _inverse_binarize_thresholding(y, classes, threshold):
+    """Inverse label binarization transformation using thresholding.
+
+    The last column is used so that both the (n_samples, 1) output of
+    label_binarize for binary problems and a two-column input are handled
+    without indexing past the end of the array.
+    """
+    y = jnp.array(y > threshold, dtype=int)
+    return classes[y[:, -1]]
+
+
+class LabelBinarizer:
+    """Binarize labels in a one-vs-all fashion.
+
+    Firstly, use fit() to use an array to set the classes.
+    The number of classes needs to be designated through parameter n_classes since SPU cannot support dynamic shape.
+    The dynamic shape problem occurs when there are duplicated elements in input of fit function.
+    The deduplication operation will cause complex computation, so it is not used by default.
+    Note that if unique==True, the order of the classes will be kept instead of sorted.
+
+    Secondly, use transform() to convert the value to a one-hot label for classes.
+    The input array needs to be 1d. Users can directly use the transformation method like jnp.ravel to transform
+    the input array into 1d then use LabelBinarizer to do further transformation.
+
+    Parameters
+    ----------
+    neg_label : int, default=0
+        Value with which negative labels must be encoded.
+
+    pos_label : int, default=1
+        Value with which positive labels must be encoded.
+
+    """
+
+    def __init__(self, *, neg_label=0, pos_label=1):
+        self.neg_label = neg_label
+        self.pos_label = pos_label
+
+    def fit(self, y, n_classes, unique=True):
+        """Fit label binarizer.
+
+        Parameters
+        ----------
+        y : {array-like}, shape (n_samples,)
+            Input data.
+
+        n_classes : int
+            Number of classes. SPU cannot support dynamic shape,
+            so this parameter needs to be designated.
+
+        unique : bool, default=True
+            Whether the values in y are already unique. If True, y is used
+            as the classes as-is; if False, y is deduplicated with
+            jnp.unique into exactly n_classes entries.
+
+        Returns
+        -------
+        self : object
+            Returns the instance itself.
+
+        Raises
+        ------
+        ValueError
+            If neg_label is not strictly less than pos_label.
+        """
+        if self.neg_label >= self.pos_label:
+            raise ValueError(
+                f"neg_label={self.neg_label} must be strictly less than "
+                f"pos_label={self.pos_label}."
+            )
+        if unique:
+            self.classes_ = y
+        else:
+            # The output of jax needs to be tensor with known size.
+            self.classes_ = jnp.unique(y, size=n_classes)
+        self.n_classes_ = n_classes
+        return self
+
+    def fit_transform(self, y, n_classes, unique=True):
+        """Fit label binarizer/transform multi-class labels to binary labels.
+
+        Parameters
+        ----------
+        y : {array-like}, shape (n_samples,)
+            Input data.
+
+        n_classes : int
+            Number of classes. SPU cannot support dynamic shape,
+            so this parameter needs to be designated.
+
+        unique : bool, default=True
+            Set to False if y contains duplicates that must be removed.
+
+        Returns
+        -------
+        ndarray of shape (n_samples, n_classes)
+            Shape will be (n_samples, 1) for binary problems.
+        """
+        return self.fit(y, n_classes, unique=unique).transform(y)
+
+    def transform(self, y):
+        """Transform multi-class labels to binary labels.
+
+        Parameters
+        ----------
+        y : {array-like}, shape (n_samples,)
+            Input data.
+
+        Returns
+        -------
+        ndarray of shape (n_samples, n_classes)
+            Shape will be (n_samples, 1) for binary problems.
+        """
+        return label_binarize(
+            y,
+            classes=self.classes_,
+            n_classes=self.n_classes_,
+            neg_label=self.neg_label,
+            pos_label=self.pos_label,
+        )
+
+    def inverse_transform(self, Y, threshold=None):
+        """Transform binary labels back to multi-class labels.
+
+        Parameters
+        ----------
+        Y : {array-like}, shape (n_samples, n_classes)
+            Input data.
+
+        threshold : float, default=None
+            Threshold used in the binary cases. If None, the midpoint of
+            neg_label and pos_label is used.
+
+        Returns
+        -------
+        ndarray of shape (n_samples,)
+
+        """
+        if threshold is None:
+            threshold = (self.pos_label + self.neg_label) / 2.0
+        if self.n_classes_ == 2:
+            y_inv = _inverse_binarize_thresholding(Y, self.classes_, threshold)
+        else:
+            y_inv = _inverse_binarize_multiclass(Y, self.classes_)
+        return y_inv
+
+
+def binarize(X, *, threshold=0.0):
+    """Binarize data (set feature values to 0 or 1) according to a threshold.
+
+    Parameters
+    ----------
+    X : {array-like}
+        The data to binarize, element by element.
+
+    threshold : float, default=0.0
+        Feature values below or equal to this are replaced by 0, above it by 1.
+
+    """
+    return jnp.where(X > threshold, 1, 0)
+
+
+class Binarizer:
+    """Binarize data (set feature values to 0 or 1) according to a threshold.
+
+    Parameters
+    ----------
+    threshold : float, default=0.0
+        Feature values below or equal to this are replaced by 0, above it by 1.
+
+    """
+
+    def __init__(self, *, threshold=0.0):
+        self.threshold = threshold
+
+    def transform(self, X):
+        """Binarize each element of X.
+
+        Parameters
+        ----------
+        X : {array-like} of shape (n_samples, n_features)
+            The data to binarize, element by element.
+
+        Returns
+        -------
+        ndarray of shape (n_samples, n_features)
+            Transformed array.
+        """
+        return binarize(X, threshold=self.threshold)
+
+
+def normalize(X, norm="l2"):
+    """Scale input vectors individually to unit norm (vector length).
+
+    Parameters
+    ----------
+    X : {array-like} of shape (n_samples, n_features)
+        The data to normalize, element by element.
+
+    norm : {'l1', 'l2', 'max'}, default='l2'
+        The norm to use to normalize each sample (row of X).
+
+    Returns
+    -------
+    ndarray of shape (n_samples, n_features)
+        Transformed array.
+
+    Raises
+    ------
+    ValueError
+        If norm is not one of 'l1', 'l2' or 'max'.
+
+    Notes
+    -----
+    Rows whose norm is zero are not special-cased, so callers should avoid
+    all-zero samples.
+    """
+    if norm == "l1":
+        norms = jnp.abs(X).sum(axis=1)
+        return X / norms[:, jnp.newaxis]
+    elif norm == "l2":
+        norms = jnp.einsum("ij,ij->i", X, X)
+        norms = norms.astype(jnp.float32)
+        # Use rsqrt instead of a separate square root and reciprocal for optimization
+        return X * jax.lax.rsqrt(norms)[:, jnp.newaxis]
+    elif norm == "max":
+        norms = jnp.max(jnp.abs(X), axis=1)
+        return X / norms[:, jnp.newaxis]
+    raise ValueError(f"'{norm}' is not a supported norm")
+
+
+class Normalizer:
+    """Normalize samples individually to unit norm.
+
+    Parameters
+    ----------
+    norm : {'l1', 'l2', 'max'}, default='l2'
+        The norm to use to normalize each non zero sample. If norm='max'
+        is used, values will be rescaled by the maximum of the absolute
+        values.
+    """
+
+    def __init__(self, norm="l2"):
+        self.norm = norm
+
+    def transform(self, X):
+        """Scale each non zero row of X to unit norm.
+
+        Parameters
+        ----------
+        X : {array-like} of shape (n_samples, n_features)
+            The data to normalize, row by row.
+
+        Returns
+        -------
+        ndarray of shape (n_samples, n_features)
+            Transformed array.
+        """
+        return normalize(X, norm=self.norm)
diff --git a/sml/preprocessing/tests/BUILD.bazel b/sml/preprocessing/tests/BUILD.bazel
new file mode 100755
index 00000000..994ed985
--- /dev/null
+++ b/sml/preprocessing/tests/BUILD.bazel
@@ -0,0 +1,27 @@
+# Copyright 2023 Ant Group Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("@rules_python//python:defs.bzl", "py_test")
+
+package(default_visibility = ["//visibility:public"])
+
+py_test(
+    name = "preprocessing_test",
+    srcs = ["preprocessing_test.py"],
+    deps = [
+        "//sml/preprocessing",  # module under test
+        "//spu:init",
+        "//spu/utils:simulation",  # the test imports spu.utils.simulation
+    ],
+)
diff --git a/sml/preprocessing/tests/preprocessing_test.py b/sml/preprocessing/tests/preprocessing_test.py
new file mode 100755
index 00000000..b7e4bdf7
--- /dev/null
+++ b/sml/preprocessing/tests/preprocessing_test.py
@@ -0,0 +1,141 @@
+# Copyright 2023 Ant Group Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import jax.numpy as jnp
+import numpy as np
+from sklearn import preprocessing
+
+import spu.spu_pb2 as spu_pb2
+import spu.utils.simulation as spsim
+from sml.preprocessing.preprocessing import Binarizer, LabelBinarizer, Normalizer
+
+
+class UnitTests(unittest.TestCase):
+    """Compare the SPU preprocessing transformers with scikit-learn."""
+
+    def setUp(self):
+        # Every test gets its own 3-party ABY3 simulator over the FM64 field.
+        self.sim = spsim.Simulator.simple(
+            3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64
+        )
+
+    def test_labelbinarizer(self):
+        """Multi-class binarization with custom labels, plus the inverse."""
+
+        def spu_fn(fit_labels, labels):
+            binarizer = LabelBinarizer(neg_label=-2, pos_label=3)
+            binarizer.fit(fit_labels, n_classes=4)
+            encoded = binarizer.transform(labels)
+            return encoded, binarizer.inverse_transform(encoded)
+
+        X = jnp.array([1, 2, 4, 6])
+        Y = jnp.array([1, 6])
+
+        sk_binarizer = preprocessing.LabelBinarizer(neg_label=-2, pos_label=3)
+        sk_transformed = sk_binarizer.fit(X).transform(Y)
+        sk_inv_transformed = sk_binarizer.inverse_transform(sk_transformed)
+
+        spu_transformed, spu_inv_transformed = spsim.sim_jax(self.sim, spu_fn)(X, Y)
+
+        np.testing.assert_allclose(sk_transformed, spu_transformed, rtol=0, atol=0)
+        np.testing.assert_allclose(
+            sk_inv_transformed, spu_inv_transformed, rtol=0, atol=0
+        )
+
+    def test_labelbinarizer_binary(self):
+        """Binary labels with duplicates (unique=False), plus the inverse."""
+
+        def spu_fn(labels):
+            binarizer = LabelBinarizer()
+            encoded = binarizer.fit_transform(labels, n_classes=2, unique=False)
+            return encoded, binarizer.inverse_transform(encoded)
+
+        X = jnp.array([1, -1, -1, 1])
+
+        sk_binarizer = preprocessing.LabelBinarizer()
+        sk_transformed = sk_binarizer.fit_transform(X)
+        sk_inv_transformed = sk_binarizer.inverse_transform(sk_transformed)
+
+        spu_transformed, spu_inv_transformed = spsim.sim_jax(self.sim, spu_fn)(X)
+
+        np.testing.assert_allclose(sk_transformed, spu_transformed, rtol=0, atol=0)
+        np.testing.assert_allclose(
+            sk_inv_transformed, spu_inv_transformed, rtol=0, atol=0
+        )
+
+    def test_labelbinarizer_unseen(self):
+        """Labels not seen during fit are transformed like scikit-learn."""
+
+        def spu_fn(fit_labels, labels):
+            binarizer = LabelBinarizer()
+            binarizer.fit(fit_labels, n_classes=3)
+            return binarizer.transform(labels)
+
+        X = jnp.array([2, 4, 5])
+        Y = jnp.array([1, 2, 3, 4, 5, 6])
+
+        sk_result = preprocessing.LabelBinarizer().fit(X).transform(Y)
+        spu_result = spsim.sim_jax(self.sim, spu_fn)(X, Y)
+
+        np.testing.assert_allclose(sk_result, spu_result, rtol=0, atol=0)
+
+    def test_binarizer(self):
+        """Binarizer with the default threshold matches scikit-learn exactly."""
+
+        def spu_fn(data):
+            return Binarizer().transform(data)
+
+        X = jnp.array([[1.0, -1.0, 2.0], [2.0, 0.0, 0.0], [0.0, 1.0, -1.0]])
+
+        sk_result = preprocessing.Binarizer().transform(X)
+        spu_result = spsim.sim_jax(self.sim, spu_fn)(X)
+
+        np.testing.assert_allclose(sk_result, spu_result, rtol=0, atol=0)
+
+    def test_normalizer(self):
+        """l1, l2 and max normalization match scikit-learn within 1e-4."""
+
+        def spu_l1(data):
+            return Normalizer(norm="l1").transform(data)
+
+        def spu_l2(data):
+            # Deliberately relies on the default norm.
+            return Normalizer().transform(data)
+
+        def spu_max(data):
+            return Normalizer(norm="max").transform(data)
+
+        X = jnp.array([[4, 1, 2, 2], [1, 3, 9, 3], [5, 7, 5, 1]])
+
+        spu_result_l1 = spsim.sim_jax(self.sim, spu_l1)(X)
+        spu_result_l2 = spsim.sim_jax(self.sim, spu_l2)(X)
+        spu_result_max = spsim.sim_jax(self.sim, spu_max)(X)
+
+        sk_result_l1 = preprocessing.Normalizer(norm="l1").transform(X)
+        sk_result_l2 = preprocessing.Normalizer().transform(X)
+        sk_result_max = preprocessing.Normalizer(norm="max").transform(X)
+
+        result_pairs = (
+            (sk_result_l1, spu_result_l1),
+            (sk_result_l2, spu_result_l2),
+            (sk_result_max, spu_result_max),
+        )
+        for sk_result, spu_result in result_pairs:
+            np.testing.assert_allclose(sk_result, spu_result, rtol=0, atol=1e-4)
+
+
+if __name__ == "__main__":
+    unittest.main()