Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add sml/svm #362

Merged
merged 16 commits into from
Oct 27, 2023
30 changes: 30 additions & 0 deletions sml/svm/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Copyright 2023 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

load("@rules_python//python:defs.bzl", "py_library")

package(default_visibility = ["//visibility:public"])

py_library(
name = "smo",
srcs = ["smo.py"],
)

py_library(
name = "svm",
srcs = ["svm.py"],
deps = [
":smo",
],
)
26 changes: 26 additions & 0 deletions sml/svm/emulations/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Copyright 2023 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

load("@rules_python//python:defs.bzl", "py_binary")

package(default_visibility = ["//visibility:public"])

py_binary(
name = "svm_emul",
srcs = ["svm_emul.py"],
deps = [
"//sml/svm",
"//sml/utils:emulation",
],
)
86 changes: 86 additions & 0 deletions sml/svm/emulations/svm_emul.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Copyright 2023 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import jax.numpy as jnp
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
from sklearn.svm import SVC

import spu.spu_pb2 as spu_pb2 # type: ignore
from sml.svm.svm import SVM
lwxxxxxxx marked this conversation as resolved.
Show resolved Hide resolved
import sml.utils.emulation as emulation

import time


def emul_SVM(mode: emulation.Mode.MULTIPROCESS):
def proc(x0, x1, y0):
rbf_svm = SVM(kernel="rbf", max_iter=102)
rbf_svm.fit(x0, y0)

return rbf_svm.predict(x0, y0, x1)

def load_data():
breast_cancer = datasets.load_breast_cancer()
data = breast_cancer.data
data = data / (jnp.max(data) - jnp.min(data))
target = breast_cancer.target
X_train, X_test, y_train, y_test = train_test_split(
data, target, test_size=0.2, random_state=1
)

party_split_num = len(X_train) // 2
lwxxxxxxx marked this conversation as resolved.
Show resolved Hide resolved

y_train[y_train != 1] = -1
X_train, X_test, y_train, y_test = (
jnp.array(X_train),
jnp.array(X_test),
jnp.array(y_train),
jnp.array(y_test),
)

return X_train, X_test, y_train, y_test

try:
# bandwidth and latency only work for docker mode
emulator = emulation.Emulator(
"examples/python/conf/3pc.json", mode, bandwidth=300, latency=20
)
emulator.up()

time0 = time.time()
# load data
X_train, X_test, y_train, y_test = load_data()

# mark these data to be protected in SPU
X_train, X_test, y_train = emulator.seal(X_train, X_test, y_train)
result1 = emulator.run(proc)(X_train, X_test, y_train)
print("result\n", result1)
print("accuracy score", accuracy_score(result1, y_test))
print("cost time ", time.time() - time0)

# Compare with sklearn
print("sklearn")
X_train, X_test, y_train, y_test = load_data()
clf_svc = SVC(C=1.0, kernel="rbf", gamma='scale', tol=1e-3)
result2 = clf_svc.fit(X_train, y_train).predict(X_test)
print("result\n", (result2 > 0).astype(int))
print("accuracy score", accuracy_score((result2 > 0).astype(int), y_test))
finally:
emulator.down()


if __name__ == "__main__":
emul_SVM(emulation.Mode.MULTIPROCESS)
177 changes: 177 additions & 0 deletions sml/svm/smo.py
lwxxxxxxx marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
import jax
import jax.numpy as jnp
import time
lwxxxxxxx marked this conversation as resolved.
Show resolved Hide resolved

INF = float('inf')


class SMO:
"""
lwxxxxxxx marked this conversation as resolved.
Show resolved Hide resolved
Parameters
----------
size : int
Size of data.

C : float
Error penalty coefficient.

tol : float, default=1e-3
Acceptable error to consider the two to be equal.
"""

def __init__(self, size, C: float, tol: float = 1e-3) -> None:
self.size = size
self.C = C
self.tol = tol
self.tau = 1e-6
self.Cs = jnp.array([self.C] * size)
self.zeros = jnp.array([0] * size)

def working_set_select_i(self, alpha, y, neg_y_grad):
"""
Select the first working set.
"""
alpha_lower_C, alpha_upper_0, y_lower_0, y_upper_0 = jnp.array(
[self.Cs, alpha, self.zeros, y]
) > jnp.array([alpha, self.zeros, y, self.zeros])
Zup = (alpha_lower_C & y_upper_0) | (alpha_upper_0 & y_lower_0)
Mup = (1 - Zup) * jnp.min(neg_y_grad)
i = jnp.argmax(neg_y_grad * Zup + Mup)
return i

def working_set_select_j(self, i, alpha, y, neg_y_grad, Q):
"""
Select the second working set.
"""
alpha_lower_C, alpha_upper_0, y_lower_0, y_upper_0 = jnp.array(
[self.Cs, alpha, self.zeros, y]
) > jnp.array([alpha, self.zeros, y, self.zeros])
Zlow = (alpha_lower_C & y_lower_0) | (alpha_upper_0 & y_upper_0)

m = neg_y_grad[i]

Zlow_m = Zlow & (neg_y_grad < m)
Qi = Q[i]
Qj = Q.diagonal()
quad_coef = Qi[i] + Qj - 2 * Q[i]
quad_coef = (quad_coef > 0) * quad_coef + (1 - (quad_coef > 0)) * self.tau
Ft = -((m - neg_y_grad) ** 2) / (quad_coef)
Mlow_m = (1 - Zlow_m) * jnp.max(Ft)
j = jnp.argmin(Ft * Zlow_m + Mlow_m)
return j

def update(self, i, j, Q, y, alpha, neg_y_grad):
"""
Update `alpha[i]` and `alpha[j]` by adjusting the way of `z = x if t else y` to `z = t*x + (1-t)*y`.
"""

Qi, Qj = Q[i], Q[j]
yi, yj = y[i], y[j]
alpha_i, alpha_j = alpha[i] + 0, alpha[j] + 0
alpha_i0, alpha_j0 = alpha_i + 0, alpha_j + 0

quad_coef = Qi[i] + Qj[j] - 2 * yi * yj * Qi[j]
quad_coef = (quad_coef > 0) * quad_coef + (1 - (quad_coef > 0)) * self.tau

yi_mul_yj = yi * yj
yi_neq_yj = yi != yj

delta = (-yi_mul_yj * neg_y_grad[i] * yi + neg_y_grad[j] * yj) / quad_coef
diff_sum = alpha_i + yi_mul_yj * alpha_j
alpha_i = alpha_i + (-1 * yi_mul_yj * delta)
alpha_j = alpha_j + delta

# first cal
(
diff_sum_upper_0,
diff_sum_upper_C,
alpha_i_lower_0,
alpha_j_lower_0,
alpha_i_upper_C,
) = jnp.array([diff_sum, diff_sum, 0, 0, alpha_i]) > jnp.array(
[0, self.C, alpha_i, alpha_j, self.C]
)
outer = jnp.array(
[yi_neq_yj, yi_neq_yj, 1 - yi_neq_yj, 1 - yi_neq_yj]
) * jnp.array(
[
diff_sum_upper_0,
1 - diff_sum_upper_0,
diff_sum_upper_C,
1 - diff_sum_upper_C,
]
)
update_condition = jnp.array(
[alpha_j_lower_0, alpha_i_lower_0, alpha_i_upper_C, alpha_j_lower_0] * 2
)
update_from = jnp.array(
[alpha_i, alpha_i, alpha_i, alpha_i, alpha_j, alpha_j, alpha_j, alpha_j]
)
update_to = jnp.array(
[diff_sum, 0, self.C, diff_sum, 0, -diff_sum, diff_sum - self.C, 0]
)
inner = (update_from + update_condition * (update_to - update_from)).reshape(
2, -1
)
alpha_i, alpha_j = jnp.dot(inner, outer.T)

# second cal
alpha_i_lower_0, alpha_i_upper_C, alpha_j_upper_C = jnp.array(
[0, alpha_i, alpha_j]
) > jnp.array([alpha_i, self.C, self.C])
update_condition = jnp.array(
[alpha_i_upper_C, alpha_j_upper_C, alpha_j_upper_C, alpha_i_lower_0] * 2
)
update_from = jnp.array(
[alpha_i, alpha_i, alpha_i, alpha_i, alpha_j, alpha_j, alpha_j, alpha_j]
)
update_to = jnp.array(
[
self.C,
self.C + diff_sum,
diff_sum - self.C,
0,
self.C - diff_sum,
self.C,
self.C,
diff_sum,
]
)
inner = (update_from + update_condition * (update_to - update_from)).reshape(
2, -1
)
alpha_i, alpha_j = jnp.dot(inner, outer.T)

delta_i = alpha_i - alpha_i0
delta_j = alpha_j - alpha_j0

neg_y_grad = neg_y_grad - y * (
jnp.dot(jnp.array([delta_i, delta_j]), jnp.array([Q[i], Q[j]]))
)
alpha = alpha.at[jnp.array([i, j])].set(jnp.array([alpha_i, alpha_j]))

return neg_y_grad, alpha

def cal_b(self, alpha, neg_y_grad, y) -> float:
"""Calculate bias."""

alpha_lower_C = alpha < self.C - self.tol
alpha_equal_C = jnp.abs(alpha - self.C) < self.tol
alpha_equal_0 = jnp.abs(alpha) < self.tol
alpha_upper_0 = alpha > 0
y_lower_0 = y < 0
y_upper_0 = y > 0

alpha_upper_0_and_lower_C = alpha_upper_0 & alpha_lower_C
sv_sum = jnp.sum(alpha_upper_0_and_lower_C)

rho_0 = -1 * (neg_y_grad * alpha_upper_0_and_lower_C).sum() / sv_sum

Zub = (alpha_equal_0 & y_lower_0) | (alpha_equal_C & y_upper_0)
Zlb = (alpha_equal_0 & y_upper_0) | (alpha_equal_C & y_lower_0)
rho_1 = -((neg_y_grad * Zub).min() + (neg_y_grad * Zlb).max()) / 2

rho = (sv_sum > 0) * rho_0 + (1 - (sv_sum > 0)) * rho_1

b = -1 * rho
return b
Loading