add sml/svm (#362)
Added SVM following the latest version.
lwxxxxxxx authored Oct 27, 2023
1 parent 04328c0 commit d386d3f
Showing 7 changed files with 575 additions and 0 deletions.
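For context, a minimal plaintext usage sketch of the new SVM class, mirroring the emulation script added in this commit (sml/svm/emulations/svm_emul.py). Only the constructor arguments exercised there (kernel, max_iter) are assumed; labels are mapped to {-1, 1} as the script does, and no SPU emulator is involved:

import jax.numpy as jnp
import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split

from sml.svm.svm import SVM

# Load and scale the data the same way svm_emul.py does.
X, y = datasets.load_breast_cancer(return_X_y=True)
X = X / (X.max() - X.min())
y = np.where(y == 1, 1, -1)  # binary labels in {-1, 1}

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1)

# Same arguments as in svm_emul.py; runs in plaintext here.
model = SVM(kernel="rbf", max_iter=102)
model.fit(jnp.array(X_train), jnp.array(y_train))
pred = model.predict(jnp.array(X_test))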
30 changes: 30 additions & 0 deletions sml/svm/BUILD.bazel
@@ -0,0 +1,30 @@
# Copyright 2023 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

load("@rules_python//python:defs.bzl", "py_library")

package(default_visibility = ["//visibility:public"])

py_library(
name = "smo",
srcs = ["smo.py"],
)

py_library(
name = "svm",
srcs = ["svm.py"],
deps = [
":smo",
],
)
26 changes: 26 additions & 0 deletions sml/svm/emulations/BUILD.bazel
@@ -0,0 +1,26 @@
# Copyright 2023 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

load("@rules_python//python:defs.bzl", "py_binary")

package(default_visibility = ["//visibility:public"])

py_binary(
name = "svm_emul",
srcs = ["svm_emul.py"],
deps = [
"//sml/svm",
"//sml/utils:emulation",
],
)
84 changes: 84 additions & 0 deletions sml/svm/emulations/svm_emul.py
@@ -0,0 +1,84 @@
# Copyright 2023 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time

import jax.numpy as jnp
from sklearn import datasets
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC

import sml.utils.emulation as emulation
import spu.spu_pb2 as spu_pb2 # type: ignore
from sml.svm.svm import SVM


def emul_SVM(mode: emulation.Mode.MULTIPROCESS):
def proc(x0, x1, y0):
rbf_svm = SVM(kernel="rbf", max_iter=102)
rbf_svm.fit(x0, y0)

return rbf_svm.predict(x1)

def load_data():
breast_cancer = datasets.load_breast_cancer()
data = breast_cancer.data
data = data / (jnp.max(data) - jnp.min(data))
target = breast_cancer.target
X_train, X_test, y_train, y_test = train_test_split(
data, target, test_size=0.2, random_state=1
)

y_train[y_train != 1] = -1
X_train, X_test, y_train, y_test = (
jnp.array(X_train),
jnp.array(X_test),
jnp.array(y_train),
jnp.array(y_test),
)

return X_train, X_test, y_train, y_test

try:
# bandwidth and latency only work for docker mode
emulator = emulation.Emulator(
"examples/python/conf/3pc.json", mode, bandwidth=300, latency=20
)
emulator.up()

time0 = time.time()
# load data
X_train, X_test, y_train, y_test = load_data()

# mark these data as secret (protected) in SPU
X_train, X_test, y_train = emulator.seal(X_train, X_test, y_train)
result1 = emulator.run(proc)(X_train, X_test, y_train)
print("result\n", result1)
print("accuracy score", accuracy_score(result1, y_test))
print("cost time ", time.time() - time0)

# Compare with sklearn
print("sklearn")
X_train, X_test, y_train, y_test = load_data()
clf_svc = SVC(C=1.0, kernel="rbf", gamma='scale', tol=1e-3)
result2 = clf_svc.fit(X_train, y_train).predict(X_test)
print("result\n", (result2 > 0).astype(int))
print("accuracy score", accuracy_score((result2 > 0).astype(int), y_test))
finally:
emulator.down()


if __name__ == "__main__":
emul_SVM(emulation.Mode.MULTIPROCESS)
191 changes: 191 additions & 0 deletions sml/svm/smo.py
@@ -0,0 +1,191 @@
# Copyright 2023 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import jax.numpy as jnp


class SMO:
"""
Reference: [FCLJ05]
Fan R.-E., Chen P.-H., Lin C.-J. Working set selection using second order information
for training support vector machines. Journal of Machine Learning Research, 2005, 6(12).

Parameters
----------
size : int
Number of training samples.
C : float
Penalty parameter of the error term.
tol : float, default=1e-3
Tolerance within which two values are considered equal.
"""

def __init__(self, size, C: float, tol: float = 1e-3) -> None:
self.size = size
self.C = C
self.tol = tol
self.tau = 1e-6
self.Cs = jnp.array([self.C] * size)
self.zeros = jnp.array([0] * size)

def working_set_select_i(self, alpha, y, neg_y_grad):
"""
Select the first element, i, of the working set.
"""
alpha_lower_C, alpha_upper_0, y_lower_0, y_upper_0 = jnp.array(
[self.Cs, alpha, self.zeros, y]
) > jnp.array([alpha, self.zeros, y, self.zeros])
Zup = (alpha_lower_C & y_upper_0) | (alpha_upper_0 & y_lower_0)
Mup = (1 - Zup) * jnp.min(neg_y_grad)
i = jnp.argmax(neg_y_grad * Zup + Mup)
return i

def working_set_select_j(self, i, alpha, y, neg_y_grad, Q):
"""
Select the second element, j, of the working set.
"""
alpha_lower_C, alpha_upper_0, y_lower_0, y_upper_0 = jnp.array(
[self.Cs, alpha, self.zeros, y]
) > jnp.array([alpha, self.zeros, y, self.zeros])
Zlow = (alpha_lower_C & y_lower_0) | (alpha_upper_0 & y_upper_0)

m = neg_y_grad[i]

Zlow_m = Zlow & (neg_y_grad < m)
Qi = Q[i]
Qj = Q.diagonal()
quad_coef = Qi[i] + Qj - 2 * Q[i]
quad_coef = (quad_coef > 0) * quad_coef + (1 - (quad_coef > 0)) * self.tau
Ft = -((m - neg_y_grad) ** 2) / (quad_coef)
Mlow_m = (1 - Zlow_m) * jnp.max(Ft)
j = jnp.argmin(Ft * Zlow_m + Mlow_m)
return j

def update(self, i, j, Q, y, alpha, neg_y_grad):
"""
Update `alpha[i]` and `alpha[j]`, expressing each branch `z = x if t else y` in the arithmetic form `z = t*x + (1-t)*y`.
"""

Qi, Qj = Q[i], Q[j]
yi, yj = y[i], y[j]
alpha_i, alpha_j = alpha[i] + 0, alpha[j] + 0
alpha_i0, alpha_j0 = alpha_i + 0, alpha_j + 0

quad_coef = Qi[i] + Qj[j] - 2 * yi * yj * Qi[j]
quad_coef = (quad_coef > 0) * quad_coef + (1 - (quad_coef > 0)) * self.tau

yi_mul_yj = yi * yj
yi_neq_yj = yi != yj

delta = (-yi_mul_yj * neg_y_grad[i] * yi + neg_y_grad[j] * yj) / quad_coef
diff_sum = alpha_i + yi_mul_yj * alpha_j
alpha_i = alpha_i + (-1 * yi_mul_yj * delta)
alpha_j = alpha_j + delta

# first clipping step: project the updated alphas back toward the feasible region
(
diff_sum_upper_0,
diff_sum_upper_C,
alpha_i_lower_0,
alpha_j_lower_0,
alpha_i_upper_C,
) = jnp.array([diff_sum, diff_sum, 0, 0, alpha_i]) > jnp.array(
[0, self.C, alpha_i, alpha_j, self.C]
)
outer = jnp.array(
[yi_neq_yj, yi_neq_yj, 1 - yi_neq_yj, 1 - yi_neq_yj]
) * jnp.array(
[
diff_sum_upper_0,
1 - diff_sum_upper_0,
diff_sum_upper_C,
1 - diff_sum_upper_C,
]
)
update_condition = jnp.array(
[alpha_j_lower_0, alpha_i_lower_0, alpha_i_upper_C, alpha_j_lower_0] * 2
)
update_from = jnp.array(
[alpha_i, alpha_i, alpha_i, alpha_i, alpha_j, alpha_j, alpha_j, alpha_j]
)
update_to = jnp.array(
[diff_sum, 0, self.C, diff_sum, 0, -diff_sum, diff_sum - self.C, 0]
)
inner = (update_from + update_condition * (update_to - update_from)).reshape(
2, -1
)
alpha_i, alpha_j = jnp.dot(inner, outer.T)

# second clipping step
alpha_i_lower_0, alpha_i_upper_C, alpha_j_upper_C = jnp.array(
[0, alpha_i, alpha_j]
) > jnp.array([alpha_i, self.C, self.C])
update_condition = jnp.array(
[alpha_i_upper_C, alpha_j_upper_C, alpha_j_upper_C, alpha_i_lower_0] * 2
)
update_from = jnp.array(
[alpha_i, alpha_i, alpha_i, alpha_i, alpha_j, alpha_j, alpha_j, alpha_j]
)
update_to = jnp.array(
[
self.C,
self.C + diff_sum,
diff_sum - self.C,
0,
self.C - diff_sum,
self.C,
self.C,
diff_sum,
]
)
inner = (update_from + update_condition * (update_to - update_from)).reshape(
2, -1
)
alpha_i, alpha_j = jnp.dot(inner, outer.T)

delta_i = alpha_i - alpha_i0
delta_j = alpha_j - alpha_j0

neg_y_grad = neg_y_grad - y * (
jnp.dot(jnp.array([delta_i, delta_j]), jnp.array([Q[i], Q[j]]))
)
alpha = alpha.at[jnp.array([i, j])].set(jnp.array([alpha_i, alpha_j]))

return neg_y_grad, alpha

def cal_b(self, alpha, neg_y_grad, y) -> float:
"""Calculate bias."""

alpha_lower_C = alpha < self.C - self.tol
alpha_equal_C = jnp.abs(alpha - self.C) < self.tol
alpha_equal_0 = jnp.abs(alpha) < self.tol
alpha_upper_0 = alpha > 0
y_lower_0 = y < 0
y_upper_0 = y > 0

alpha_upper_0_and_lower_C = alpha_upper_0 & alpha_lower_C
sv_sum = jnp.sum(alpha_upper_0_and_lower_C)

rho_0 = -1 * (neg_y_grad * alpha_upper_0_and_lower_C).sum() / sv_sum

Zub = (alpha_equal_0 & y_lower_0) | (alpha_equal_C & y_upper_0)
Zlb = (alpha_equal_0 & y_upper_0) | (alpha_equal_C & y_lower_0)
rho_1 = -((neg_y_grad * Zub).min() + (neg_y_grad * Zlb).max()) / 2

rho = (sv_sum > 0) * rho_0 + (1 - (sv_sum > 0)) * rho_1

b = -1 * rho
return b