From d386d3fa3fea7be04691c3d847034500e6958729 Mon Sep 17 00:00:00 2001 From: lwxxxxxxx <100955060+lwxxxxxxx@users.noreply.github.com> Date: Fri, 27 Oct 2023 13:54:46 +0800 Subject: [PATCH] add sml/svm (#362) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 按照最新的版本增加了svm --- sml/svm/BUILD.bazel | 30 ++++++ sml/svm/emulations/BUILD.bazel | 26 +++++ sml/svm/emulations/svm_emul.py | 84 +++++++++++++++ sml/svm/smo.py | 191 +++++++++++++++++++++++++++++++++ sml/svm/svm.py | 141 ++++++++++++++++++++++++ sml/svm/tests/BUILD.bazel | 27 +++++ sml/svm/tests/svm_test.py | 76 +++++++++++++ 7 files changed, 575 insertions(+) create mode 100644 sml/svm/BUILD.bazel create mode 100644 sml/svm/emulations/BUILD.bazel create mode 100644 sml/svm/emulations/svm_emul.py create mode 100644 sml/svm/smo.py create mode 100644 sml/svm/svm.py create mode 100644 sml/svm/tests/BUILD.bazel create mode 100644 sml/svm/tests/svm_test.py diff --git a/sml/svm/BUILD.bazel b/sml/svm/BUILD.bazel new file mode 100644 index 00000000..fccd63ff --- /dev/null +++ b/sml/svm/BUILD.bazel @@ -0,0 +1,30 @@ +# Copyright 2023 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_python//python:defs.bzl", "py_library") + +package(default_visibility = ["//visibility:public"]) + +py_library( + name = "smo", + srcs = ["smo.py"], +) + +py_library( + name = "svm", + srcs = ["svm.py"], + deps = [ + ":smo", + ], +) diff --git a/sml/svm/emulations/BUILD.bazel b/sml/svm/emulations/BUILD.bazel new file mode 100644 index 00000000..1a499939 --- /dev/null +++ b/sml/svm/emulations/BUILD.bazel @@ -0,0 +1,26 @@ +# Copyright 2023 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_python//python:defs.bzl", "py_binary") + +package(default_visibility = ["//visibility:public"]) + +py_binary( + name = "svm_emul", + srcs = ["svm_emul.py"], + deps = [ + "//sml/svm", + "//sml/utils:emulation", + ], +) diff --git a/sml/svm/emulations/svm_emul.py b/sml/svm/emulations/svm_emul.py new file mode 100644 index 00000000..465b67bf --- /dev/null +++ b/sml/svm/emulations/svm_emul.py @@ -0,0 +1,84 @@ +# Copyright 2023 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import jax.numpy as jnp +from sklearn import datasets +from sklearn.metrics import accuracy_score, classification_report +from sklearn.model_selection import train_test_split +from sklearn.svm import SVC + +import sml.utils.emulation as emulation +import spu.spu_pb2 as spu_pb2 # type: ignore +from sml.svm.svm import SVM + + +def emul_SVM(mode: emulation.Mode.MULTIPROCESS): + def proc(x0, x1, y0): + rbf_svm = SVM(kernel="rbf", max_iter=102) + rbf_svm.fit(x0, y0) + + return rbf_svm.predict(x1) + + def load_data(): + breast_cancer = datasets.load_breast_cancer() + data = breast_cancer.data + data = data / (jnp.max(data) - jnp.min(data)) + target = breast_cancer.target + X_train, X_test, y_train, y_test = train_test_split( + data, target, test_size=0.2, random_state=1 + ) + + y_train[y_train != 1] = -1 + X_train, X_test, y_train, y_test = ( + jnp.array(X_train), + jnp.array(X_test), + jnp.array(y_train), + jnp.array(y_test), + ) + + return X_train, X_test, y_train, y_test + + try: + # bandwidth and latency only work for docker mode + emulator = emulation.Emulator( + "examples/python/conf/3pc.json", mode, bandwidth=300, latency=20 + ) + emulator.up() + + time0 = time.time() + # load data + X_train, X_test, y_train, y_test = load_data() + + # mark these data to be protected in SPU + X_train, X_test, y_train = emulator.seal(X_train, X_test, y_train) + result1 = emulator.run(proc)(X_train, X_test, y_train) + print("result\n", result1) + print("accuracy score", accuracy_score(result1, y_test)) + print("cost time ", time.time() - time0) + + # Compare with sklearn + print("sklearn") + X_train, X_test, y_train, y_test = load_data() + clf_svc = SVC(C=1.0, kernel="rbf", gamma='scale', tol=1e-3) + result2 = clf_svc.fit(X_train, y_train).predict(X_test) + print("result\n", (result2 > 0).astype(int)) + print("accuracy score", accuracy_score((result2 > 0).astype(int), y_test)) + finally: + emulator.down() + + +if __name__ == "__main__": + emul_SVM(emulation.Mode.MULTIPROCESS) diff --git a/sml/svm/smo.py b/sml/svm/smo.py new file mode 100644 index 00000000..ac3f01e6 --- /dev/null +++ b/sml/svm/smo.py @@ -0,0 +1,191 @@ +# Copyright 2023 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import jax.numpy as jnp + + +class SMO: + """ + Reference: [FCLJ05] + Fan R E, Chen P H, Lin C J, et al. Working set selection using second order information for + training support vector machines[J]. Journal of machine learning research, 2005, 6(12). + + Parameters + ---------- + size : int + Size of data. + + C : float + Error penalty coefficient. + + tol : float, default=1e-3 + Acceptable error to consider the two to be equal. + """ + + def __init__(self, size, C: float, tol: float = 1e-3) -> None: + self.size = size + self.C = C + self.tol = tol + self.tau = 1e-6 + self.Cs = jnp.array([self.C] * size) + self.zeros = jnp.array([0] * size) + + def working_set_select_i(self, alpha, y, neg_y_grad): + """ + Select the first working set. + """ + alpha_lower_C, alpha_upper_0, y_lower_0, y_upper_0 = jnp.array( + [self.Cs, alpha, self.zeros, y] + ) > jnp.array([alpha, self.zeros, y, self.zeros]) + Zup = (alpha_lower_C & y_upper_0) | (alpha_upper_0 & y_lower_0) + Mup = (1 - Zup) * jnp.min(neg_y_grad) + i = jnp.argmax(neg_y_grad * Zup + Mup) + return i + + def working_set_select_j(self, i, alpha, y, neg_y_grad, Q): + """ + Select the second working set. + """ + alpha_lower_C, alpha_upper_0, y_lower_0, y_upper_0 = jnp.array( + [self.Cs, alpha, self.zeros, y] + ) > jnp.array([alpha, self.zeros, y, self.zeros]) + Zlow = (alpha_lower_C & y_lower_0) | (alpha_upper_0 & y_upper_0) + + m = neg_y_grad[i] + + Zlow_m = Zlow & (neg_y_grad < m) + Qi = Q[i] + Qj = Q.diagonal() + quad_coef = Qi[i] + Qj - 2 * Q[i] + quad_coef = (quad_coef > 0) * quad_coef + (1 - (quad_coef > 0)) * self.tau + Ft = -((m - neg_y_grad) ** 2) / (quad_coef) + Mlow_m = (1 - Zlow_m) * jnp.max(Ft) + j = jnp.argmin(Ft * Zlow_m + Mlow_m) + return j + + def update(self, i, j, Q, y, alpha, neg_y_grad): + """ + Update `alpha[i]` and `alpha[j]` by adjusting the way of `z = x if t else y` to `z = t*x + (1-t)*y`. + """ + + Qi, Qj = Q[i], Q[j] + yi, yj = y[i], y[j] + alpha_i, alpha_j = alpha[i] + 0, alpha[j] + 0 + alpha_i0, alpha_j0 = alpha_i + 0, alpha_j + 0 + + quad_coef = Qi[i] + Qj[j] - 2 * yi * yj * Qi[j] + quad_coef = (quad_coef > 0) * quad_coef + (1 - (quad_coef > 0)) * self.tau + + yi_mul_yj = yi * yj + yi_neq_yj = yi != yj + + delta = (-yi_mul_yj * neg_y_grad[i] * yi + neg_y_grad[j] * yj) / quad_coef + diff_sum = alpha_i + yi_mul_yj * alpha_j + alpha_i = alpha_i + (-1 * yi_mul_yj * delta) + alpha_j = alpha_j + delta + + # first cal + ( + diff_sum_upper_0, + diff_sum_upper_C, + alpha_i_lower_0, + alpha_j_lower_0, + alpha_i_upper_C, + ) = jnp.array([diff_sum, diff_sum, 0, 0, alpha_i]) > jnp.array( + [0, self.C, alpha_i, alpha_j, self.C] + ) + outer = jnp.array( + [yi_neq_yj, yi_neq_yj, 1 - yi_neq_yj, 1 - yi_neq_yj] + ) * jnp.array( + [ + diff_sum_upper_0, + 1 - diff_sum_upper_0, + diff_sum_upper_C, + 1 - diff_sum_upper_C, + ] + ) + update_condition = jnp.array( + [alpha_j_lower_0, alpha_i_lower_0, alpha_i_upper_C, alpha_j_lower_0] * 2 + ) + update_from = jnp.array( + [alpha_i, alpha_i, alpha_i, alpha_i, alpha_j, alpha_j, alpha_j, alpha_j] + ) + update_to = jnp.array( + [diff_sum, 0, self.C, diff_sum, 0, -diff_sum, diff_sum - self.C, 0] + ) + inner = (update_from + update_condition * (update_to - update_from)).reshape( + 2, -1 + ) + alpha_i, alpha_j = jnp.dot(inner, outer.T) + + # second cal + alpha_i_lower_0, alpha_i_upper_C, alpha_j_upper_C = jnp.array( + [0, alpha_i, alpha_j] + ) > jnp.array([alpha_i, self.C, self.C]) + update_condition = jnp.array( + [alpha_i_upper_C, alpha_j_upper_C, alpha_j_upper_C, alpha_i_lower_0] * 2 + ) + update_from = jnp.array( + [alpha_i, alpha_i, alpha_i, alpha_i, alpha_j, alpha_j, alpha_j, alpha_j] + ) + update_to = jnp.array( + [ + self.C, + self.C + diff_sum, + diff_sum - self.C, + 0, + self.C - diff_sum, + self.C, + self.C, + diff_sum, + ] + ) + inner = (update_from + update_condition * (update_to - update_from)).reshape( + 2, -1 + ) + alpha_i, alpha_j = jnp.dot(inner, outer.T) + + delta_i = alpha_i - alpha_i0 + delta_j = alpha_j - alpha_j0 + + neg_y_grad = neg_y_grad - y * ( + jnp.dot(jnp.array([delta_i, delta_j]), jnp.array([Q[i], Q[j]])) + ) + alpha = alpha.at[jnp.array([i, j])].set(jnp.array([alpha_i, alpha_j])) + + return neg_y_grad, alpha + + def cal_b(self, alpha, neg_y_grad, y) -> float: + """Calculate bias.""" + + alpha_lower_C = alpha < self.C - self.tol + alpha_equal_C = jnp.abs(alpha - self.C) < self.tol + alpha_equal_0 = jnp.abs(alpha) < self.tol + alpha_upper_0 = alpha > 0 + y_lower_0 = y < 0 + y_upper_0 = y > 0 + + alpha_upper_0_and_lower_C = alpha_upper_0 & alpha_lower_C + sv_sum = jnp.sum(alpha_upper_0_and_lower_C) + + rho_0 = -1 * (neg_y_grad * alpha_upper_0_and_lower_C).sum() / sv_sum + + Zub = (alpha_equal_0 & y_lower_0) | (alpha_equal_C & y_upper_0) + Zlb = (alpha_equal_0 & y_upper_0) | (alpha_equal_C & y_lower_0) + rho_1 = -((neg_y_grad * Zub).min() + (neg_y_grad * Zlb).max()) / 2 + + rho = (sv_sum > 0) * rho_0 + (1 - (sv_sum > 0)) * rho_1 + + b = -1 * rho + return b diff --git a/sml/svm/svm.py b/sml/svm/svm.py new file mode 100644 index 00000000..114fd6e6 --- /dev/null +++ b/sml/svm/svm.py @@ -0,0 +1,141 @@ +# Copyright 2023 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import jax.numpy as jnp + +from sml.svm.smo import SMO + + +class SVM: + """ + Parameters + ---------- + kernel : str, default="rbf" + The kernel function used in the svm algorithm, maps samples + to a higher dimensional feature space. + + C : float, default=1.0 + Regularization parameter. The strength of the regularization + is inversely proportional to C. Must be strictly positive. The penalty is a squared l2 penalty. + + gamma : {'scale', 'auto'}, default="scale" + Kernel coefficient for 'rbf', 'poly' and 'sigmoid'. + if gamma='scale' (default) is passed then it uses 1 / (n_features * X.var()) as value of gamma, + if 'auto', uses 1 / n_features + + max_iter : int, default=300 + Maximum number of iterations of the svm algorithm for a + single run. + + tol : float, default=1e-3 + Acceptable error to consider the two to be equal. + """ + + def __init__(self, kernel="rbf", C=1.0, gamma='scale', max_iter=300, tol=1e-3): + self.kernel = kernel + self.C = C + self.gamma = gamma + self.max_iter = max_iter + self.tol = tol + self.n_features = None + + self.alpha_y = None + self.b = None + self.X = None + + assert self.gamma in {'scale', 'auto'}, "Gamma only support 'scale' and 'auto'" + assert self.kernel == "rbf", "Kernel function only support 'rbf'" + + def cal_kernel(self, x, x_): + """Calculate kernel.""" + gamma = { + 'scale': 1 / (self.n_features * x.var()), + 'auto': 1 / self.n_features, + }[self.gamma] + + kernel_res = jnp.exp( + -gamma + * ( + (x**2).sum(1, keepdims=True) + + (x_**2).sum(1) + - 2 * jnp.matmul(x, x_.T) + ) + ) + + return kernel_res + + def cal_Q(self, x, y): + """Calculate Q.""" + kernel_res = self.cal_kernel(x, x) + Q = y.reshape(-1, 1) * y * kernel_res + return Q + + def fit(self, X, y): + """Fit SVM. + + Using the Sequential Minimal Optimization(SMO) algorithm to solve the Quadratic programming problem in + the SVM, which decomposes the large optimization problem to several small optimization problems. Firstly, + the SMO algorithm selects alpha_i and alpha_j by 'smo.working_set_select_i()' and 'smo.working_set_select_j'. + Secondly, the SMO algorithm update the parameter by 'smo.update()'. Last, calculate the bias. + + Parameters + ---------- + X : {array-like}, shape (n_samples, n_features) + Input data. + + y : {array-like}, shape (n_samples) + Lable of the input data. + + """ + + l, self.n_features = X.shape + p = -jnp.ones(l) + smo = SMO(l, self.C, self.tol) + Q = self.cal_Q(X, y) + alpha = 0.0 * y + neg_y_grad = -p * y + for _ in range(self.max_iter): + i = smo.working_set_select_i(alpha, y, neg_y_grad) + j = smo.working_set_select_j(i, alpha, y, neg_y_grad, Q) + neg_y_grad, alpha = smo.update(i, j, Q, y, alpha, neg_y_grad) + + self.b = smo.cal_b(alpha, neg_y_grad, y) + self.alpha_y = alpha * y + + self.X = X + + def predict(self, x): + """Result estimates. + + Calculate the classification result of the input data. + + Parameters + ---------- + x : {array-like}, shape (n_samples, n_features) + Input data for prediction. + + Returns + ------- + ndarray of shape (n_samples) + Returns the classification result of the input data for prediction. + """ + + pred = ( + jnp.matmul( + self.alpha_y, + self.cal_kernel(self.X, x), + ) + + self.b + ) + return (pred >= 0).astype(int) diff --git a/sml/svm/tests/BUILD.bazel b/sml/svm/tests/BUILD.bazel new file mode 100644 index 00000000..91f54ba0 --- /dev/null +++ b/sml/svm/tests/BUILD.bazel @@ -0,0 +1,27 @@ +# Copyright 2023 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_python//python:defs.bzl", "py_test") + +package(default_visibility = ["//visibility:public"]) + +py_test( + name = "svm_test", + srcs = ["svm_test.py"], + deps = [ + "//sml/svm", + "//spu:init", + "//spu/utils:simulation", + ], +) diff --git a/sml/svm/tests/svm_test.py b/sml/svm/tests/svm_test.py new file mode 100644 index 00000000..c2b7a580 --- /dev/null +++ b/sml/svm/tests/svm_test.py @@ -0,0 +1,76 @@ +# Copyright 2023 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +import unittest + +import jax.numpy as jnp +from sklearn import datasets +from sklearn.metrics import accuracy_score, classification_report +from sklearn.model_selection import train_test_split +from sklearn.svm import SVC + +import spu.spu_pb2 as spu_pb2 # type: ignore +import spu.utils.simulation as spsim +from sml.svm.svm import SVM + + +class UnitTests(unittest.TestCase): + def test_svm(self): + sim = spsim.Simulator.simple( + 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 + ) + + def proc(x0, x1, y0): + rbf_svm = SVM(kernel="rbf", max_iter=102) + rbf_svm.fit(x0, y0) + + return rbf_svm.predict(x1) + + def load_data(): + breast_cancer = datasets.load_breast_cancer() + data = breast_cancer.data + data = data / (jnp.max(data) - jnp.min(data)) + target = breast_cancer.target + X_train, X_test, y_train, y_test = train_test_split( + data, target, test_size=0.2, random_state=1 + ) + + y_train[y_train != 1] = -1 + X_train, X_test, y_train, y_test = ( + jnp.array(X_train), + jnp.array(X_test), + jnp.array(y_train), + jnp.array(y_test), + ) + + return X_train, X_test, y_train, y_test + + time0 = time.time() + X_train, X_test, y_train, y_test = load_data() + result1 = spsim.sim_jax(sim, proc)(X_train, X_test, y_train) + print("result\n", result1) + print("accuracy score", accuracy_score(result1, y_test)) + print("cost time ", time.time() - time0) + + # Compare with sklearn + print("sklearn") + clf_svc = SVC(C=1.0, kernel="rbf", gamma='scale', tol=1e-3) + result2 = clf_svc.fit(X_train, y_train).predict(X_test) + print("result\n", (result2 > 0).astype(int)) + print("accuracy score", accuracy_score((result2 > 0).astype(int), y_test)) + + +if __name__ == "__main__": + unittest.main()