From eb22e10b898739df5d66db43eeea2d0b3793a35a Mon Sep 17 00:00:00 2001 From: xbw886 <1042740841@qq.com> Date: Mon, 1 Jul 2024 17:20:12 +0800 Subject: [PATCH 1/3] =?UTF-8?q?random=E4=B8=8D=E6=AD=A3=E7=A1=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sml/forest/BUILD.bazel | 22 +++ sml/forest/emulations/BUILD.bazel | 27 +++ sml/forest/emulations/forest_emul.py | 139 +++++++++++++++ sml/forest/forest.py | 241 +++++++++++++++++++++++++++ sml/forest/tests/BUILD.bazel | 22 +++ sml/forest/tests/forest_test.py | 138 +++++++++++++++ 6 files changed, 589 insertions(+) create mode 100644 sml/forest/BUILD.bazel create mode 100644 sml/forest/emulations/BUILD.bazel create mode 100644 sml/forest/emulations/forest_emul.py create mode 100644 sml/forest/forest.py create mode 100644 sml/forest/tests/BUILD.bazel create mode 100644 sml/forest/tests/forest_test.py diff --git a/sml/forest/BUILD.bazel b/sml/forest/BUILD.bazel new file mode 100644 index 00000000..9d1d0ade --- /dev/null +++ b/sml/forest/BUILD.bazel @@ -0,0 +1,22 @@ +# Copyright 2023 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_python//python:defs.bzl", "py_library") + +package(default_visibility = ["//visibility:public"]) + +py_library( + name = "forest", + srcs = ["forest.py"], +) diff --git a/sml/forest/emulations/BUILD.bazel b/sml/forest/emulations/BUILD.bazel new file mode 100644 index 00000000..7323a510 --- /dev/null +++ b/sml/forest/emulations/BUILD.bazel @@ -0,0 +1,27 @@ +# Copyright 2023 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_python//python:defs.bzl", "py_binary") + +package(default_visibility = ["//visibility:public"]) + +py_binary( + name = "forest_emul", + srcs = ["forest_emul.py"], + deps = [ + "//sml/forest", + "//sml/tree", + "//sml/utils:emulation", + ], +) diff --git a/sml/forest/emulations/forest_emul.py b/sml/forest/emulations/forest_emul.py new file mode 100644 index 00000000..babc9b45 --- /dev/null +++ b/sml/forest/emulations/forest_emul.py @@ -0,0 +1,139 @@ +# Copyright 2023 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +''' +Author: Li Zhihang +Date: 2024-06-16 12:02:32 +LastEditTime: 2024-06-22 16:52:19 +FilePath: /klaus/spu-klaus/sml/forest/emulations/forest_emul.py +Description: +''' +import time + +import jax.numpy as jnp +from sklearn.datasets import load_iris +from sklearn.ensemble import RandomForestClassifier + +import sml.utils.emulation as emulation +from sml.forest.forest import RandomForestClassifier as sml_rfc + +MAX_DEPTH = 3 +CONFIG_FILE = emulation.CLUSTER_ABY3_3PC + + +def emul_forest(mode=emulation.Mode.MULTIPROCESS): + def proc_wrapper( + n_estimators=100, + max_features=None, + n_features=200, + criterion='gini', + splitter='best', + max_depth=3, + bootstrap=True, + max_samples=None, + n_labels=3, + seed=0, + ): + rf_custom = sml_rfc( + n_estimators, + max_features, + n_features, + criterion, + splitter, + max_depth, + bootstrap, + max_samples, + n_labels, + seed, + ) + + def proc(X, y): + rf_custom_fit = rf_custom.fit(X, y) + result = rf_custom_fit.predict(X) + return result + + return proc + + def load_data(): + iris = load_iris() + iris_data, iris_label = jnp.array(iris.data), jnp.array(iris.target) + # sorted_features: n_samples * n_features_in + n_samples, n_features_in = iris_data.shape + n_labels = len(jnp.unique(iris_label)) + sorted_features = jnp.sort(iris_data, axis=0) + new_threshold = (sorted_features[:-1, :] + sorted_features[1:, :]) / 2 + new_features = jnp.greater_equal( + iris_data[:, :], new_threshold[:, jnp.newaxis, :] + ) + new_features = new_features.transpose([1, 0, 2]).reshape(n_samples, -1) + + X, y = new_features[:, ::3], iris_label[:] + return X, y + + try: + # bandwidth and latency only work for docker mode + emulator = emulation.Emulator(CONFIG_FILE, mode, bandwidth=300, latency=20) + emulator.up() + + # load mock data + X, y = load_data() + n_samples, n_features = X.shape + n_labels = jnp.unique(y).shape[0] + + # compare with sklearn + rf = RandomForestClassifier( + n_estimators=3, + max_features=0.7, + criterion='gini', + max_depth=MAX_DEPTH, + bootstrap=True, + ) + start = time.time() + rf = rf.fit(X, y) + score_plain = rf.score(X, y) + end = time.time() + print(f"Running time in SKlearn: {end - start:.2f}s") + + # mark these data to be protected in SPU + X_spu, y_spu = emulator.seal(X, y) + + # run + proc = proc_wrapper( + n_estimators=3, + max_features=0.7, + n_features=n_features, + criterion='gini', + splitter='best', + max_depth=3, + bootstrap=True, + max_samples=0.7, + n_labels=n_labels, + seed=0, + ) + start = time.time() + # 不可以使用bootstrap,否则在spu运行的正确率很低 + result = emulator.run(proc)(X_spu, y_spu) + end = time.time() + score_encrpted = jnp.sum((result == y)) / n_samples + print(f"Running time in SPU: {end - start:.2f}s") + + # print acc + print(f"Accuracy in SKlearn: {score_plain:.2f}") + print(f"Accuracy in SPU: {score_encrpted:.2f}") + + finally: + emulator.down() + + +if __name__ == "__main__": + emul_forest(emulation.Mode.MULTIPROCESS) diff --git a/sml/forest/forest.py b/sml/forest/forest.py new file mode 100644 index 00000000..f081e644 --- /dev/null +++ b/sml/forest/forest.py @@ -0,0 +1,241 @@ +# Copyright 2023 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +''' +Author : error: git config user.email & please set dead value or install git +Date : 2024-06-15 20:01 +LastEditors: Please set LastEditors +LastEditTime: 2024-06-18 19:09:17 +FilePath: /klaus/spu-klaus/sml/forest/forest.py +Description : 基本完成函数编写的工作,目前测试结果基本正确,后面需要完成emul和test +bootstrap有问题,bootstrap后predict不输出1,bootstrap无1(因为不支持jax.random的api) + +!最终:bootstrap这个参数,不可用:在明文下bootstrap取样正确,但在forest_test.py时,无法取到标签1, +因为bootstrap不能用,目前没有开发max_samples这个超参数 + +基本完成,因为不支持jax.random的api,所以bootstrap和select_features都没有使用随机算法, +导致max_features为sqrt和log2的时候误差较大,而sklearn的select_features比较随机,因此误差大 +如果不增加n_features这个参数的话,会导致创建动态数组,并对动态数组切分,导致程序报错 +''' +import jax.numpy as jnp +from jax import lax + +from sml.tree.tree import DecisionTreeClassifier as sml_dtc + +# from functools import partial +# from jax import jit +# import jax.random as jdm + + +# key = jdm.PRNGKey(42) + + +class RandomForestClassifier: + """A random forest classifier.""" + + def __init__( + self, + n_estimators, + max_features, + n_features, + criterion, + splitter, + max_depth, + bootstrap, + max_samples, + n_labels, + seed, + ): + assert criterion == "gini", "criteria other than gini is not supported." + assert splitter == "best", "splitter other than best is not supported." + assert ( + n_estimators is not None and n_estimators > 0 + ), "n_estimators should not be None and must > 0." + assert ( + max_depth is not None and max_depth > 0 + ), "max_depth should not be None and must > 0." + assert isinstance( + bootstrap, bool + ), "bootstrap should be a boolean value (True or False)" + assert isinstance(n_features, int), "n_features should be an integer." + if isinstance(max_features, int): + assert ( + max_features <= n_features + ), "max_features should not exceed n_features when it's an integer" + max_features = jnp.array(n_features, dtype=int) + elif isinstance(max_features, float): + assert ( + 0 < max_features <= 1 + ), "max_features should be in the range (0, 1] when it's a float" + max_features = jnp.array((max_features * n_features), dtype=int) + elif isinstance(max_features, str): + if max_features == 'sqrt': + max_features = jnp.array(jnp.sqrt(n_features), dtype=int) + elif max_features == 'log2': + max_features = jnp.array(jnp.log2(n_features), dtype=int) + else: + max_features = n_features + else: + max_features = n_features + + self.seed = seed + # self.key = key + + self.n_estimators = n_estimators + self.max_features = max_features + self.n_features = n_features + self.criterion = criterion + self.splitter = splitter + self.max_depth = max_depth + self.bootstrap = bootstrap + self.max_samples = max_samples + # max_samples 这个参数在bootstrap为ture才有效,可以为float(0,1],int<总数,默认为none + self.n_labels = n_labels + + self.trees = [] + self.features_indices = [] + + def _bootstrap_sample(self, X, y): + n_samples = X.shape[0] + if isinstance(self.max_samples, int): + assert ( + self.max_samples <= n_samples + ), "max_samples should not exceed n_samples when it's an integer" + max_samples = self.max_samples + elif isinstance(self.max_samples, float): + assert ( + 0 < self.max_samples <= 1 + ), "max_samples should be in the range (0, 1] when it's a float" + max_samples = (int)(self.max_samples * n_samples) + else: + max_samples = n_samples + + if not self.bootstrap: + return X, y + # 使用斐波那契数列变体生成伪随机索引 + indices = jnp.zeros(max_samples, dtype=int) + a, b = self.seed % n_samples, (self.seed + 1) % n_samples + for i in range(max_samples): + indices = indices.at[i].set(b % n_samples) + a, b = b, (a + b) % n_samples # 生成斐波那契数列并取模 + # 更新种子值以保证每次调用生成不同的序列 + self.seed += 1 + return X[indices], y[indices] + + # 可以用,但没n选k + def _select_features(self): + # if isinstance(self.max_features, int): + # assert self.max_features <= n_features, "max_features should not exceed n_features when it's an integer" + # max_features = jnp.array(n_features, dtype=int) + # elif isinstance(self.max_features, float): + # assert 0 < self.max_features <= 1, "max_features should be in the range (0, 1] when it's a float" + # max_features = jnp.array((self.max_features * n_features), dtype=int) + # elif isinstance(self.max_features, str): + # if self.max_features == 'sqrt': + # max_features = jnp.array(jnp.sqrt(n_features), dtype=int) + # elif self.max_features == 'log2': + # max_features = jnp.array(jnp.log2(n_features), dtype=int) + # else: + # max_features = n_features + # else: + # max_features = n_features + + selected_indices = self._shuffle_indices(self.n_features, self.max_features)[ + : self.max_features + ] + self.seed += 1 + return selected_indices + + def _shuffle_indices(self, n, k): + # 基于种子的循环洗牌算法,确保不出现重复的索引 + rng = self.seed + indices = jnp.arange(n) + # indices = range(n) + + def cond_fun(state): + i, _, _ = state + return i < k + + def body_fun(state): + i, rng, indices = state + rng = (rng * 48271 + 1) % (2**31 - 1) + j = i + rng % (n - i) + # 交换元素以进行洗牌 + indices = self._swap(indices, i, j) + return (i + 1, rng, indices) + + _, _, shuffled_indices = lax.while_loop(cond_fun, body_fun, (0, rng, indices)) + # selected_indices = shuffled_indices[:k] + selected_indices = shuffled_indices + return selected_indices + + def _swap(self, array, i, j): + # 辅助函数:交换数组中的两个元素 + array = array.at[i].set(array[j]) + array = array.at[j].set(array[i]) + return array + + def fit(self, X, y): + n_samples, n_features = X.shape + self.trees = [] + self.features_indices = [] + + for _ in range(self.n_estimators): + X_sample, y_sample = self._bootstrap_sample(X, y) + features = self._select_features() + # selected_indices = self._shuffle_indices(n_features) + print(y_sample) + tree = sml_dtc(self.criterion, self.splitter, self.max_depth, self.n_labels) + tree.fit(X_sample[:, features], y_sample) + self.trees.append(tree) + self.features_indices.append(features) + + return self + + def predict(self, X): + tree_predictions = jnp.zeros((X.shape[0], self.n_estimators)) + + for i, tree in enumerate(self.trees): + features = self.features_indices[i] + print(features) + tree_predictions = tree_predictions.at[:, i].set( + tree.predict(X[:, features]) + ) + # print(tree_predictions[:, i]) + # Use majority vote to determine final prediction + y_pred, _ = jax_mode_row(tree_predictions) + # return tree_predictions + return y_pred.ravel() + + +def jax_mode_row(data): + # 获取每行的众数 + + # 获取数据的形状 + num_rows, num_cols = data.shape + + # 初始化众数和计数的数组 + modes = jnp.zeros(num_rows, dtype=data.dtype) + counts = jnp.zeros(num_rows, dtype=jnp.int32) + + # 计算每行的众数及其计数 + for row in range(num_rows): + row_data = data[row, :] + unique_values, value_counts = jnp.unique( + row_data, return_counts=True, size=row_data.shape[0] + ) + max_count_idx = jnp.argmax(value_counts) + modes = modes.at[row].set(unique_values[max_count_idx]) + counts = counts.at[row].set(value_counts[max_count_idx]) + + return modes, counts diff --git a/sml/forest/tests/BUILD.bazel b/sml/forest/tests/BUILD.bazel new file mode 100644 index 00000000..9d1d0ade --- /dev/null +++ b/sml/forest/tests/BUILD.bazel @@ -0,0 +1,22 @@ +# Copyright 2023 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_python//python:defs.bzl", "py_library") + +package(default_visibility = ["//visibility:public"]) + +py_library( + name = "forest", + srcs = ["forest.py"], +) diff --git a/sml/forest/tests/forest_test.py b/sml/forest/tests/forest_test.py new file mode 100644 index 00000000..7e5fafcf --- /dev/null +++ b/sml/forest/tests/forest_test.py @@ -0,0 +1,138 @@ +''' +Author: Li Zhihang +Date: 2024-06-16 12:03:08 +LastEditTime: 2024-06-22 16:45:43 +FilePath: /klaus/spu-klaus/sml/forest/tests/forest_test.py +Description:正确率相差太大:Accuracy in SKlearn: 0.95;Accuracy in SPU: 0.67 +''' + +# Copyright 2023 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +import jax.numpy as jnp +from sklearn.datasets import load_iris +from sklearn.ensemble import RandomForestClassifier + +import spu.spu_pb2 as spu_pb2 # type: ignore +import spu.utils.simulation as spsim +from sml.forest.forest import RandomForestClassifier as sml_rfc + +MAX_DEPTH = 3 + + +class UnitTests(unittest.TestCase): + def test_forest(self): + def proc_wrapper( + n_estimators=100, + max_features=None, + n_features=200, + criterion='gini', + splitter='best', + max_depth=3, + bootstrap=True, + max_samples=None, + n_labels=3, + seed=0, + ): + rf_custom = sml_rfc( + n_estimators, + max_features, + n_features, + criterion, + splitter, + max_depth, + bootstrap, + max_samples, + n_labels, + seed, + ) + + def proc(X, y): + rf_custom_fit = rf_custom.fit(X, y) + + result = rf_custom_fit.predict(X) + return result + + return proc + + def load_data(): + iris = load_iris() + iris_data, iris_label = jnp.array(iris.data), jnp.array(iris.target) + # sorted_features: n_samples * n_features_in + n_samples, n_features_in = iris_data.shape + n_labels = len(jnp.unique(iris_label)) + sorted_features = jnp.sort(iris_data, axis=0) + new_threshold = (sorted_features[:-1, :] + sorted_features[1:, :]) / 2 + new_features = jnp.greater_equal( + iris_data[:, :], new_threshold[:, jnp.newaxis, :] + ) + new_features = new_features.transpose([1, 0, 2]).reshape(n_samples, -1) + + X, y = new_features[:, ::3], iris_label[:] + return X, y + + # bandwidth and latency only work for docker mode + sim = spsim.Simulator.simple( + 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 + ) + + # load mock data + X, y = load_data() + n_samples, n_features = X.shape + n_labels = jnp.unique(y).shape[0] + print(y) + + # compare with sklearn + rf = RandomForestClassifier( + n_estimators=3, + max_features='log2', + criterion='gini', + max_depth=MAX_DEPTH, + bootstrap=True, + max_samples=0.7, + ) + rf = rf.fit(X, y) + score_plain = rf.score(X, y) + # 获取每棵树的预测值 + tree_predictions = jnp.array([tree.predict(X) for tree in rf.estimators_]) + # print("sklearn:") + # print(tree_predictions) + print(n_features) + # run + proc = proc_wrapper( + n_estimators=3, + max_features='log2', + n_features=n_features, + criterion='gini', + splitter='best', + max_depth=3, + bootstrap=True, + max_samples=0.7, + n_labels=n_labels, + seed=0, + ) + # 不可以使用bootstrap,否则在spu运行的正确率很低 + result = spsim.sim_jax(sim, proc)(X, y) + + # print(y_sample) + score_encrpted = jnp.sum((result == y)) / n_samples + + # print acc + print(f"Accuracy in SKlearn: {score_plain}") + print(f"Accuracy in SPU: {score_encrpted}") + + +if __name__ == "__main__": + unittest.main() From a37ae00e10b6fb825f8dda26bb7953aadb96acdc Mon Sep 17 00:00:00 2001 From: xbw886 <1042740841@qq.com> Date: Tue, 2 Jul 2024 11:54:54 +0800 Subject: [PATCH 2/3] =?UTF-8?q?=E5=87=86=E7=A1=AE=E7=8E=87=E6=8F=90?= =?UTF-8?q?=E9=AB=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sml/forest/emulations/forest_emul.py | 9 +++++---- sml/forest/forest.py | 13 ++++++++----- sml/forest/tests/BUILD.bazel | 14 ++++++++++---- sml/forest/tests/forest_test.py | 23 +++++++++++------------ 4 files changed, 34 insertions(+), 25 deletions(-) diff --git a/sml/forest/emulations/forest_emul.py b/sml/forest/emulations/forest_emul.py index babc9b45..b4505951 100644 --- a/sml/forest/emulations/forest_emul.py +++ b/sml/forest/emulations/forest_emul.py @@ -93,10 +93,11 @@ def load_data(): # compare with sklearn rf = RandomForestClassifier( n_estimators=3, - max_features=0.7, + max_features=None, criterion='gini', max_depth=MAX_DEPTH, - bootstrap=True, + bootstrap=False, + max_samples=None, ) start = time.time() rf = rf.fit(X, y) @@ -115,8 +116,8 @@ def load_data(): criterion='gini', splitter='best', max_depth=3, - bootstrap=True, - max_samples=0.7, + bootstrap=False, + max_samples=None, n_labels=n_labels, seed=0, ) diff --git a/sml/forest/forest.py b/sml/forest/forest.py index f081e644..9a7ba03a 100644 --- a/sml/forest/forest.py +++ b/sml/forest/forest.py @@ -124,12 +124,15 @@ def _bootstrap_sample(self, X, y): return X, y # 使用斐波那契数列变体生成伪随机索引 indices = jnp.zeros(max_samples, dtype=int) - a, b = self.seed % n_samples, (self.seed + 1) % n_samples + rng = self.seed + # a, b = self.seed % n_samples, (self.seed + 1) % n_samples for i in range(max_samples): - indices = indices.at[i].set(b % n_samples) - a, b = b, (a + b) % n_samples # 生成斐波那契数列并取模 - # 更新种子值以保证每次调用生成不同的序列 + rng = (rng * 48271 + 1) % (2**31 - 1) + indices = indices.at[i].set(rng % n_samples) + # indices = indices.at[i].set(b % n_samples) + # a, b = b, (a + b) % n_samples # 生成斐波那契数列并取模 self.seed += 1 + # 更新种子值以保证每次调用生成不同的序列 return X[indices], y[indices] # 可以用,但没n选k @@ -153,7 +156,6 @@ def _select_features(self): selected_indices = self._shuffle_indices(self.n_features, self.max_features)[ : self.max_features ] - self.seed += 1 return selected_indices def _shuffle_indices(self, n, k): @@ -172,6 +174,7 @@ def body_fun(state): j = i + rng % (n - i) # 交换元素以进行洗牌 indices = self._swap(indices, i, j) + self.seed += 1 return (i + 1, rng, indices) _, _, shuffled_indices = lax.while_loop(cond_fun, body_fun, (0, rng, indices)) diff --git a/sml/forest/tests/BUILD.bazel b/sml/forest/tests/BUILD.bazel index 9d1d0ade..84133b03 100644 --- a/sml/forest/tests/BUILD.bazel +++ b/sml/forest/tests/BUILD.bazel @@ -12,11 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -load("@rules_python//python:defs.bzl", "py_library") +load("@rules_python//python:defs.bzl", "py_test") package(default_visibility = ["//visibility:public"]) -py_library( - name = "forest", - srcs = ["forest.py"], +py_test( + name = "forest_test", + srcs = ["forest_test.py"], + deps = [ + "//sml/forest", + "//sml/tree", + "//spu:init", + "//spu/utils:simulation", + ], ) diff --git a/sml/forest/tests/forest_test.py b/sml/forest/tests/forest_test.py index 7e5fafcf..93eb9930 100644 --- a/sml/forest/tests/forest_test.py +++ b/sml/forest/tests/forest_test.py @@ -1,11 +1,3 @@ -''' -Author: Li Zhihang -Date: 2024-06-16 12:03:08 -LastEditTime: 2024-06-22 16:45:43 -FilePath: /klaus/spu-klaus/sml/forest/tests/forest_test.py -Description:正确率相差太大:Accuracy in SKlearn: 0.95;Accuracy in SPU: 0.67 -''' - # Copyright 2023 Ant Group Co., Ltd. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -19,6 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +''' +Author: Li Zhihang +Date: 2024-06-16 12:03:08 +LastEditTime: 2024-07-01 18:25:23 +FilePath: /klaus/spu/sml/forest/tests/forest_test.py +Description: +''' import unittest import jax.numpy as jnp @@ -37,7 +36,7 @@ def test_forest(self): def proc_wrapper( n_estimators=100, max_features=None, - n_features=200, + n_features=199, criterion='gini', splitter='best', max_depth=3, @@ -97,7 +96,7 @@ def load_data(): # compare with sklearn rf = RandomForestClassifier( n_estimators=3, - max_features='log2', + max_features=None, criterion='gini', max_depth=MAX_DEPTH, bootstrap=True, @@ -113,7 +112,7 @@ def load_data(): # run proc = proc_wrapper( n_estimators=3, - max_features='log2', + max_features=None, n_features=n_features, criterion='gini', splitter='best', @@ -123,7 +122,7 @@ def load_data(): n_labels=n_labels, seed=0, ) - # 不可以使用bootstrap,否则在spu运行的正确率很低 + result = spsim.sim_jax(sim, proc)(X, y) # print(y_sample) From d4ec1018291a112cc60b2053a08fb559f303cab8 Mon Sep 17 00:00:00 2001 From: xbw886 <1042740841@qq.com> Date: Sat, 27 Jul 2024 19:57:22 +0800 Subject: [PATCH 3/3] rf&quantile&ada --- sml/adaboost/BUILD.bazel | 22 + sml/adaboost/adaboost.py | 285 +++++++++++ .../emulations/BUILD.bazel | 8 +- sml/adaboost/emulations/adaboost_emul.py | 106 ++++ sml/adaboost/tests/BUILD.bazel | 28 + sml/adaboost/tests/adaboost_test.py | 98 ++++ sml/ensemble/BUILD.bazel | 9 + sml/ensemble/emulations/BUILD.bazel | 14 + .../emulations/forest_emul.py | 19 +- sml/ensemble/forest.py | 221 ++++++++ sml/ensemble/tests/BUILD.bazel | 15 + sml/{forest => ensemble}/tests/forest_test.py | 45 +- sml/forest/forest.py | 244 --------- sml/quantile/BUILD.bazel | 25 + sml/quantile/quantile.py | 299 +++++++++++ sml/{forest => quantile}/tests/BUILD.bazel | 7 +- sml/quantile/tests/quantile_test.py | 102 ++++ sml/{forest => quantile/utils}/BUILD.bazel | 5 +- sml/quantile/utils/linprog.py | 478 ++++++++++++++++++ sml/tree/BUILD.bazel | 5 + sml/tree/tree_w.py | 304 +++++++++++ 21 files changed, 2037 insertions(+), 302 deletions(-) create mode 100644 sml/adaboost/BUILD.bazel create mode 100644 sml/adaboost/adaboost.py rename sml/{forest => adaboost}/emulations/BUILD.bazel (86%) create mode 100644 sml/adaboost/emulations/adaboost_emul.py create mode 100644 sml/adaboost/tests/BUILD.bazel create mode 100644 sml/adaboost/tests/adaboost_test.py rename sml/{forest => ensemble}/emulations/forest_emul.py (88%) create mode 100644 sml/ensemble/forest.py rename sml/{forest => ensemble}/tests/forest_test.py (77%) delete mode 100644 sml/forest/forest.py create mode 100644 sml/quantile/BUILD.bazel create mode 100644 sml/quantile/quantile.py rename sml/{forest => quantile}/tests/BUILD.bazel (88%) create mode 100644 sml/quantile/tests/quantile_test.py rename sml/{forest => quantile/utils}/BUILD.bazel (93%) create mode 100644 sml/quantile/utils/linprog.py create mode 100644 sml/tree/tree_w.py diff --git a/sml/adaboost/BUILD.bazel b/sml/adaboost/BUILD.bazel new file mode 100644 index 00000000..687dd03a --- /dev/null +++ b/sml/adaboost/BUILD.bazel @@ -0,0 +1,22 @@ +# Copyright 2023 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_python//python:defs.bzl", "py_library") + +package(default_visibility = ["//visibility:public"]) + +py_library( + name = "adaboost", + srcs = ["adaboost.py"], +) diff --git a/sml/adaboost/adaboost.py b/sml/adaboost/adaboost.py new file mode 100644 index 00000000..276226c0 --- /dev/null +++ b/sml/adaboost/adaboost.py @@ -0,0 +1,285 @@ +''' +Author: Li Zhihang +Date: 2024-07-11 12:56:40 +LastEditTime: 2024-07-27 19:53:38 +FilePath: /klaus/spu/sml/adaboost/adaboost.py +Description: + +7.16 基本实现adaboost_test,但是没有早停等算法等, +如果要实现早停算法考虑使用while_loop,但是循环内更新self的变量被视为数据泄露。导致错误 +''' +import jax.numpy as jnp +from jax import lax +import warnings +from sml.tree.tree_w import DecisionTreeClassifier as sml_dtc + +class AdaBoostClassifier: + """A adaboost classifier based on DecisionTreeClassifier. + + Parameters + ---------- + estimator : {"dtc"}, default="dtc" + Specifies the type of model or algorithm to be used for training. + Supported estimators are "dtc". + + n_estimators : int + The number of estimators. Must specify an integer > 0. + + max_depth : int + The maximum depth of the tree. Must specify an integer > 0. + + learning_rate : float + The step size used to update the model weights during training. + It's an float, must learning_rate > 0. + + n_classes: int + The max number of classes. + + """ + def __init__( + self, + estimator, + # 默认estimator为决策树,criterion == "gini" splitter == "best" + n_estimators, + max_depth, + learning_rate, + n_classes, + ): + assert estimator == "dtc", "estimator other than dtc is not supported." + assert ( + n_estimators is not None and n_estimators > 0 + ), "n_estimators should not be None and must > 0." + assert( + max_depth is not None and max_depth > 0 + ), "max_depth should not be None and must > 0." + + self.estimator = estimator + self.n_estimators = n_estimators + self.max_depth = max_depth + self.learning_rate = learning_rate + self.n_classes = n_classes + + self.estimators_ = [] + self.estimator_weight = jnp.zeros(self.n_estimators, dtype=jnp.float32) + self.estimator_errors = jnp.ones(self.n_estimators, dtype=jnp.float32) + + def _num_samples(self, x): + """返回x中的样本数量.""" + if hasattr(x, 'fit'): + # 检查是否是一个estimator + raise TypeError('Expected sequence or array-like, got estimator') + if not hasattr(x, '__len__') and not hasattr(x, 'shape') and not hasattr(x, '__array__'): + raise TypeError("Expected sequence or array-like, got %s" % type(x)) + + if hasattr(x, 'shape'): + if len(x.shape) == 0: # scalar + raise TypeError("Singleton array %r cannot be considered a valid collection." % x) + return x.shape[0] + else: + return len(x) + + def _check_sample_weight(self, sample_weight, X, dtype=None, copy=False, only_non_negative=False): + ''' + description: 验证样本权重. + return {*} + ''' + # jax默认只支持float32, + # 如果需要启用 float64 类型,可以设置 jax_enable_x64 配置选项或 JAX_ENABLE_X64 环境变量。 + n_samples = self._num_samples(X) + if dtype is not None and dtype not in [jnp.float32, jnp.float64]: + dtype = jnp.float32 + + if sample_weight is None: + sample_weight = jnp.ones(n_samples, dtype=dtype) + elif isinstance(sample_weight, numbers.Number): + sample_weight = jnp.full(n_samples, sample_weight, dtype=dtype) + else: + sample_weight = jnp.asarray(sample_weight, dtype=dtype) + if sample_weight.ndim != 1: + raise ValueError("Sample weight must be 1D array or scalar") + + if sample_weight.shape[0] != n_samples: + raise ValueError( + "sample_weight.shape == {}, expected {}!".format( + sample_weight.shape, (n_samples,) + ) + ) + + if copy: + sample_weight = jnp.copy(sample_weight) + + return sample_weight + + def cond_fun(self, iboost, sample_weight, estimator_weight, estimator_error): + status1 = jnp.logical_and(iboost < self.n_estimators, jnp.all(jnp.isfinite(sample_weight))) + status2 = jnp.logical_and(estimator_error > 0, jnp.sum(sample_weight) > 0) + status = jnp.logical_and(status1, status2) + return status + + + def fit(self, X, y, sample_weight=None): + sample_weight = self._check_sample_weight( + sample_weight, X, copy=True, only_non_negative=True + ) + sample_weight /= sample_weight.sum() + + self.classes = y + + + epsilon = jnp.finfo(sample_weight.dtype).eps + + self.estimator_weight_ = jnp.zeros(self.n_estimators, dtype=jnp.float32) + self.estimator_errors_ = jnp.ones(self.n_estimators, dtype=jnp.float32) + + for iboost in range(self.n_estimators): + sample_weight = jnp.clip(sample_weight, a_min=epsilon, a_max=None) + + sample_weight, estimator_weight, estimator_error = self._boost_discrete( + iboost, X, y, sample_weight + ) + + self.estimator_weight_ = self.estimator_weight_.at[iboost].set(estimator_weight) + self.estimator_errors_ = self.estimator_errors_.at[iboost].set(estimator_error) + + sample_weight_sum = jnp.sum(sample_weight) + def not_last_iboost(sample_weight, sample_weight_sum): + sample_weight /= sample_weight_sum + return sample_weight + def last_iboost(sample_weight, sample_weight_sum): + return sample_weight + sample_weight = lax.cond(iboost 0) + ) + return sample_weight + + def last_iboost(sample_weight): + return sample_weight + + sample_weight = lax.cond(iboost != self.n_estimators - 1, + not_last_iboost, last_iboost, sample_weight) + + + return sample_weight, estimator_weight, estimator_error + + sample_weight, estimator_weight, estimator_error = lax.cond( + estimator_error <= 0.0, true_0_fun, false_0_fun, sample_weight + ) + + return sample_weight, estimator_weight, estimator_error + + + def predict(self, X): + pred = self.decision_function(X) + + if self.n_classes == 2: + return self.classes.take(pred > 0, axis=0) + + return self.classes.take(jnp.argmax(pred, axis=1), axis=0) + + + def decision_function(self, X): + n_classes = self.n_classes + classes = self.classes[:, jnp.newaxis] + + pred = sum( + jnp.where( + (estimator.predict(X) == classes).T, + w, + -1 / (n_classes - 1) * w, + ) + for estimator, w in zip(self.estimators_, self.estimator_weight_) + ) + pred /= self.estimator_weight_.sum() + + if n_classes == 2: + pred[:, 0] *= -1 + return pred.sum(axis=1) + return pred + + +# import jax.numpy as jnp +# from sklearn.datasets import load_iris +# from sklearn.metrics import accuracy_score, classification_report +# def load_data(): +# iris = load_iris() +# iris_data, iris_label = jnp.array(iris.data), jnp.array(iris.target) +# # sorted_features: n_samples * n_features_in +# n_samples, n_features_in = iris_data.shape +# n_labels = len(jnp.unique(iris_label)) +# sorted_features = jnp.sort(iris_data, axis=0) +# new_threshold = (sorted_features[:-1, :] + sorted_features[1:, :]) / 2 +# new_features = jnp.greater_equal( +# iris_data[:, :], new_threshold[:, jnp.newaxis, :] +# ) +# new_features = new_features.transpose([1, 0, 2]).reshape(n_samples, -1) + +# X, y = new_features[:, ::3], iris_label[:] +# return X, y + +# X,y = load_data() +# n_labels = len(jnp.unique(y)) +# model = AdaBoostClassifier(estimator='dtc', n_estimators=50,max_depth=2,learning_rate=0.5, n_classes=n_labels) +# # 训练AdaBoost模型 +# model =model.fit(X, y, sample_weight=None) +# # print(model.estimator_weight_) +# print(model.estimator_errors_) +# # 预测测试集 +# y_pred = model.predict(X) +# # print(y_pred) + +# n_samples, n_features = X.shape +# score_encrypted = jnp.mean(y_pred == y) +# print(y_pred) +# print(y) +# print(f"Accuracy in SPU: {score_encrypted}") + +# # 输出预测结果的准确率和分类报告 +# print(f"Accuracy: {accuracy_score(y, y_pred)}") + +# from sklearn.ensemble import AdaBoostClassifier +# from sklearn.tree import DecisionTreeClassifier + +# base_estimator = DecisionTreeClassifier(max_depth=2) # 基分类器 +# model = AdaBoostClassifier(estimator=base_estimator, n_estimators=50,learning_rate=0.5,algorithm="SAMME") + +# # 训练AdaBoost模型 +# model.fit(X, y, sample_weight=None) +# print(model.estimator_errors_) +# # 预测测试集 +# y_pred = model.predict(X) +# print(y_pred) +# score_plain = model.score(X, y) +# print(score_plain) +# # 输出预测结果的准确率和分类报告 +# print(f"Accuracy: {accuracy_score(y, y_pred)}") + \ No newline at end of file diff --git a/sml/forest/emulations/BUILD.bazel b/sml/adaboost/emulations/BUILD.bazel similarity index 86% rename from sml/forest/emulations/BUILD.bazel rename to sml/adaboost/emulations/BUILD.bazel index 7323a510..000fbad3 100644 --- a/sml/forest/emulations/BUILD.bazel +++ b/sml/adaboost/emulations/BUILD.bazel @@ -17,11 +17,11 @@ load("@rules_python//python:defs.bzl", "py_binary") package(default_visibility = ["//visibility:public"]) py_binary( - name = "forest_emul", - srcs = ["forest_emul.py"], + name = "adaboost_emul", + srcs = ["adaboost_emul.py"], deps = [ - "//sml/forest", - "//sml/tree", + "//sml/adaboost", + "//sml/tree:tree_w", "//sml/utils:emulation", ], ) diff --git a/sml/adaboost/emulations/adaboost_emul.py b/sml/adaboost/emulations/adaboost_emul.py new file mode 100644 index 00000000..38ae8e8f --- /dev/null +++ b/sml/adaboost/emulations/adaboost_emul.py @@ -0,0 +1,106 @@ +''' +Author: Li Zhihang +Date: 2024-07-16 11:21:02 +LastEditTime: 2024-07-17 15:26:04 +FilePath: /klaus/spu/sml/adaboost/emulations/adaboost_emul.py +Description: +''' +import time + +import jax.numpy as jnp +from sklearn.datasets import load_iris +from sklearn.ensemble import AdaBoostClassifier +from sklearn.tree import DecisionTreeClassifier + +import sml.utils.emulation as emulation +from sml.adaboost.adaboost import AdaBoostClassifier as sml_Adaboost + +MAX_DEPTH = 3 +CONFIG_FILE = emulation.CLUSTER_ABY3_3PC + +def emul_ada(mode=emulation.Mode.MULTIPROCESS): + def proc_wrapper( + estimator = "dtc", + n_estimators = 50, + max_depth = MAX_DEPTH, + learning_rate = 1.0, + n_classes = 3, + ): + ada_custom = sml_Adaboost( + estimator = "dtc", + n_estimators = 50, + max_depth = MAX_DEPTH, + learning_rate = 1.0, + n_classes = 3, + ) + + def proc(X, y): + ada_custom_fit = ada_custom.fit(X, y, sample_weight=None) + result = ada_custom_fit.predict(X) + return result + + return proc + + def load_data(): + iris = load_iris() + iris_data, iris_label = jnp.array(iris.data), jnp.array(iris.target) + # sorted_features: n_samples * n_features_in + n_samples, n_features_in = iris_data.shape + n_labels = len(jnp.unique(iris_label)) + sorted_features = jnp.sort(iris_data, axis=0) + new_threshold = (sorted_features[:-1, :] + sorted_features[1:, :]) / 2 + new_features = jnp.greater_equal( + iris_data[:, :], new_threshold[:, jnp.newaxis, :] + ) + new_features = new_features.transpose([1, 0, 2]).reshape(n_samples, -1) + + X, y = new_features[:, ::3], iris_label[:] + return X, y + + try: + # bandwidth and latency only work for docker mode + emulator = emulation.Emulator(CONFIG_FILE, mode, bandwidth=300, latency=20) + emulator.up() + + # load mock data + X, y = load_data() + n_labels = jnp.unique(y).shape[0] + + # compare with sklearn + base_estimator = DecisionTreeClassifier(max_depth=3) # 基分类器 + ada = AdaBoostClassifier(estimator=base_estimator, n_estimators=3, learning_rate=1.0, algorithm="SAMME") + + start = time.time() + ada = ada.fit(X, y) + score_plain = ada.score(X, y) + end = time.time() + print(f"Running time in SKlearn: {end - start:.2f}s") + + # mark these data to be protected in SPU + X_spu, y_spu = emulator.seal(X, y) + + # run + proc = proc_wrapper( + estimator = "dtc", + n_estimators = 3, + max_depth = MAX_DEPTH, + learning_rate = 1.0, + n_classes = 3, + ) + start = time.time() + # 不可以使用bootstrap,否则在spu运行的正确率很低 + result = emulator.run(proc)(X_spu, y_spu) + end = time.time() + score_encrpted = jnp.mean((result == y)) + print(f"Running time in SPU: {end - start:.2f}s") + + # print acc + print(f"Accuracy in SKlearn: {score_plain:.2f}") + print(f"Accuracy in SPU: {score_encrpted:.2f}") + + finally: + emulator.down() + + +if __name__ == "__main__": + emul_ada(emulation.Mode.MULTIPROCESS) diff --git a/sml/adaboost/tests/BUILD.bazel b/sml/adaboost/tests/BUILD.bazel new file mode 100644 index 00000000..19985277 --- /dev/null +++ b/sml/adaboost/tests/BUILD.bazel @@ -0,0 +1,28 @@ +# Copyright 2023 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_python//python:defs.bzl", "py_test") + +package(default_visibility = ["//visibility:public"]) + +py_test( + name = "adaboost_test", + srcs = ["adaboost_test.py"], + deps = [ + "//sml/adaboost", + "//sml/tree:tree_w", + "//spu:init", + "//spu/utils:simulation", + ], +) diff --git a/sml/adaboost/tests/adaboost_test.py b/sml/adaboost/tests/adaboost_test.py new file mode 100644 index 00000000..2c4c7d85 --- /dev/null +++ b/sml/adaboost/tests/adaboost_test.py @@ -0,0 +1,98 @@ +''' +Author: Li Zhihang +Date: 2024-07-14 21:19:57 +LastEditTime: 2024-07-17 14:15:30 +FilePath: /klaus/spu/sml/adaboost/tests/adaboost_test.py +Description: +''' +import unittest + +import jax.numpy as jnp +from sklearn.datasets import load_iris +from sklearn.ensemble import AdaBoostClassifier +from sklearn.tree import DecisionTreeClassifier + + +import spu.spu_pb2 as spu_pb2 # type: ignore +import spu.utils.simulation as spsim +from sml.adaboost.adaboost import AdaBoostClassifier as sml_Adaboost + +MAX_DEPTH = 3 + +class UnitTests(unittest.TestCase): + def test_Ada(self): + def proc_wrapper( + estimator = "dtc", + n_estimators = 10, + max_depth = MAX_DEPTH, + learning_rate = 1.0, + n_classes = 3, + ): + ada_custom = sml_Adaboost( + estimator = "dtc", + n_estimators = 10, + max_depth = MAX_DEPTH, + learning_rate = 1.0, + n_classes = 3, + ) + + def proc(X, y): + ada_custom_fit = ada_custom.fit(X, y, sample_weight=None) + result = ada_custom_fit.predict(X) + return result + + return proc + + def load_data(): + iris = load_iris() + iris_data, iris_label = jnp.array(iris.data), jnp.array(iris.target) + # sorted_features: n_samples * n_features_in + n_samples, n_features_in = iris_data.shape + n_labels = len(jnp.unique(iris_label)) + sorted_features = jnp.sort(iris_data, axis=0) + new_threshold = (sorted_features[:-1, :] + sorted_features[1:, :]) / 2 + new_features = jnp.greater_equal( + iris_data[:, :], new_threshold[:, jnp.newaxis, :] + ) + new_features = new_features.transpose([1, 0, 2]).reshape(n_samples, -1) + + X, y = new_features[:, ::3], iris_label[:] + return X, y + + sim = spsim.Simulator.simple( + 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 + ) + + + X, y = load_data() + n_samples, n_features = X.shape + n_labels = jnp.unique(y).shape[0] + + # compare with sklearn + base_estimator = DecisionTreeClassifier(max_depth=3) # 基分类器 + ada = AdaBoostClassifier(estimator=base_estimator, n_estimators=3, learning_rate=1.0, algorithm="SAMME") + ada = ada.fit(X, y) + score_plain = ada.score(X, y) + + #run + proc = proc_wrapper( + estimator = "dtc", + n_estimators = 3, + max_depth = 3, + learning_rate = 1.0, + n_classes = 3, + ) + + result = spsim.sim_jax(sim, proc)(X, y) + print(result) + score_encrypted = jnp.mean(result == y) + + # print acc + print(f"Accuracy in SKlearn: {score_plain}") + print(f"Accuracy in SPU: {score_encrypted}") + +if __name__ == '__main__': + unittest.main() + + + diff --git a/sml/ensemble/BUILD.bazel b/sml/ensemble/BUILD.bazel index 7832e732..9d1d0ade 100644 --- a/sml/ensemble/BUILD.bazel +++ b/sml/ensemble/BUILD.bazel @@ -11,3 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +load("@rules_python//python:defs.bzl", "py_library") + +package(default_visibility = ["//visibility:public"]) + +py_library( + name = "forest", + srcs = ["forest.py"], +) diff --git a/sml/ensemble/emulations/BUILD.bazel b/sml/ensemble/emulations/BUILD.bazel index 7832e732..264b4bc2 100644 --- a/sml/ensemble/emulations/BUILD.bazel +++ b/sml/ensemble/emulations/BUILD.bazel @@ -11,3 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +load("@rules_python//python:defs.bzl", "py_binary") + +package(default_visibility = ["//visibility:public"]) + +py_binary( + name = "forest_emul", + srcs = ["forest_emul.py"], + deps = [ + "//sml/ensemble:forest", + "//sml/tree:tree", + "//sml/utils:emulation", + ], +) diff --git a/sml/forest/emulations/forest_emul.py b/sml/ensemble/emulations/forest_emul.py similarity index 88% rename from sml/forest/emulations/forest_emul.py rename to sml/ensemble/emulations/forest_emul.py index b4505951..32869e22 100644 --- a/sml/forest/emulations/forest_emul.py +++ b/sml/ensemble/emulations/forest_emul.py @@ -11,13 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -''' -Author: Li Zhihang -Date: 2024-06-16 12:02:32 -LastEditTime: 2024-06-22 16:52:19 -FilePath: /klaus/spu-klaus/sml/forest/emulations/forest_emul.py -Description: -''' import time import jax.numpy as jnp @@ -25,7 +18,7 @@ from sklearn.ensemble import RandomForestClassifier import sml.utils.emulation as emulation -from sml.forest.forest import RandomForestClassifier as sml_rfc +from sml.ensemble.forest import RandomForestClassifier as sml_rfc MAX_DEPTH = 3 CONFIG_FILE = emulation.CLUSTER_ABY3_3PC @@ -35,26 +28,22 @@ def emul_forest(mode=emulation.Mode.MULTIPROCESS): def proc_wrapper( n_estimators=100, max_features=None, - n_features=200, criterion='gini', splitter='best', max_depth=3, bootstrap=True, max_samples=None, n_labels=3, - seed=0, ): rf_custom = sml_rfc( n_estimators, max_features, - n_features, criterion, splitter, max_depth, bootstrap, max_samples, n_labels, - seed, ) def proc(X, y): @@ -87,7 +76,7 @@ def load_data(): # load mock data X, y = load_data() - n_samples, n_features = X.shape + # n_samples, n_features = X.shape n_labels = jnp.unique(y).shape[0] # compare with sklearn @@ -112,20 +101,18 @@ def load_data(): proc = proc_wrapper( n_estimators=3, max_features=0.7, - n_features=n_features, criterion='gini', splitter='best', max_depth=3, bootstrap=False, max_samples=None, n_labels=n_labels, - seed=0, ) start = time.time() # 不可以使用bootstrap,否则在spu运行的正确率很低 result = emulator.run(proc)(X_spu, y_spu) end = time.time() - score_encrpted = jnp.sum((result == y)) / n_samples + score_encrpted = jnp.mean((result == y)) print(f"Running time in SPU: {end - start:.2f}s") # print acc diff --git a/sml/ensemble/forest.py b/sml/ensemble/forest.py new file mode 100644 index 00000000..909e978a --- /dev/null +++ b/sml/ensemble/forest.py @@ -0,0 +1,221 @@ +# Copyright 2023 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import jax.numpy as jnp +from jax import lax +import jax +import math +import random + +from sml.tree.tree import DecisionTreeClassifier as sml_dtc + +class RandomForestClassifier: + """A random forest classifier based on DecisionTreeClassifier. + + Parameters + ---------- + n_estimators : int + The number of trees in the forest. Must specify an integer > 0. + + max_features : int, float, "auto", "sqrt", "log2", or None. + The number of features to consider when looking for the best split. + If it's an integer, must 0 < integer < n_features. + If it's an float, must 0 < float <= 1. + + criterion : {"gini"}, default="gini" + The function to measure the quality of a split. Supported criteria are + "gini" for the Gini impurity. + + splitter : {"best"}, default="best" + The strategy used to choose the split at each node. Supported + strategies are "best" to choose the best split. + + max_depth : int + The maximum depth of the tree. Must specify an integer > 0. + + bootstrap : bool + Whether bootstrap samples are used when building trees. + + max_samples : int, float ,None, default=None + The number of samples to draw from X to train each base estimator. + This parameter is only valid if bootstrap is ture. + If it's an integer, must 0 < integer < n_samples. + If it's an float, must 0 < float <= 1. + + n_labels: int + The max number of labels. + + """ + + def __init__( + self, + n_estimators, + max_features, + criterion, + splitter, + max_depth, + bootstrap, + max_samples, + n_labels, + ): + assert criterion == "gini", "criteria other than gini is not supported." + assert splitter == "best", "splitter other than best is not supported." + assert ( + n_estimators is not None and n_estimators > 0 + ), "n_estimators should not be None and must > 0." + assert ( + max_depth is not None and max_depth > 0 + ), "max_depth should not be None and must > 0." + assert isinstance( + bootstrap, bool + ), "bootstrap should be a boolean value (True or False)" + + self.n_estimators = n_estimators + self.max_features = max_features + self.criterion = criterion + self.splitter = splitter + self.max_depth = max_depth + self.bootstrap = bootstrap + self.max_samples = max_samples + self.n_labels = n_labels + + self.trees = [] + self.features_indices = [] + + def _bootstrap_sample(self, X, y): + n_samples = X.shape[0] + if isinstance(self.max_samples, int): + assert ( + self.max_samples <= n_samples + ), "max_samples should not exceed n_samples when it's an integer" + max_samples = self.max_samples + elif isinstance(self.max_samples, float): + assert ( + 0 < self.max_samples <= 1 + ), "max_samples should be in the range (0, 1] when it's a float" + max_samples = (int)(self.max_samples * n_samples) + else: + max_samples = n_samples + + if not self.bootstrap: + return X, y + + # 实现bootstrap + population = range(n_samples) + indices = random.sample(population, max_samples) + + indices = jnp.array(indices) + return X[indices], y[indices] + + def _select_features(self, n, k): + indices = range(n) + selected_elements = random.sample(indices, k) + return selected_elements + + def fit(self, X, y): + n_samples, n_features = X.shape + self.n_features = n_features + x = random.random() + + # 取消__init__中n_features参数,这里计算获得max_features + if isinstance(self.max_features, int): + assert ( + 0 < self.max_features <= self.n_features + ), "0 < max_features <= n_features when it's an integer" + # self.max_features = jnp.array(self.max_features, dtype=int) + elif isinstance(self.max_features, float): + assert ( + 0 < self.max_features <= 1 + ), "max_features should be in the range (0, 1] when it's a float" + self.max_features = (int)(self.max_features * self.n_features) + # self.max_features = jnp.array((self.max_features * n_features), dtype=int) + elif isinstance(self.max_features, str): + if self.max_features == 'sqrt': + self.max_features = (int)(math.sqrt(self.n_features)) + # self.max_features = jnp.array(jnp.sqrt(n_features), dtype=int) + elif self.max_features == 'log2': + self.max_features = (int)(math.log2(self.n_features)) + # self.max_features = jnp.array(jnp.log2(n_features), dtype=int) + else: + self.max_features = self.n_features + else: + self.max_features = self.n_features + + self.trees = [] + self.features_indices = [] + + for _ in range(self.n_estimators): + X_sample, y_sample = self._bootstrap_sample(X, y) + features = self._select_features(self.n_features, self.max_features) + + tree = sml_dtc(self.criterion, self.splitter, self.max_depth, self.n_labels) + tree.fit(X_sample[:, features], y_sample) + self.trees.append(tree) + self.features_indices.append(features) + + return self + + def predict(self, X): + predictions_list = [] + for i, tree in enumerate(self.trees): + features = self.features_indices[i] + predictions = tree.predict(X[:, features]) + predictions_list.append(predictions) + + tree_predictions = jnp.array(predictions_list).T + + + # 目前jit函数single_tree_predict的时候,会将features_indices当成tracer导致报错 + # features_indices = jnp.array(self.features_indices) + # features_indices = self.features_indices + # print(features_indices[3]) + # # Define a function that predicts using a single tree and the corresponding features + # def single_tree_predict(i, X): + # features = self.features_indices[i] + # print(features) + # return self.trees[i].predict(X[:, features]) + + # # Vectorize the single_tree_predict function + # tree_predictions = jax.vmap(single_tree_predict, in_axes=(0, None))(jnp.arange(self.n_estimators), X) + + y_pred, _ = jax_mode_row(tree_predictions) + + return y_pred.ravel() + + +def jax_mode_row(data): + # 获取每行的众数 + + # 获取数据的形状 + num_rows, num_cols = data.shape + + # 初始化众数和计数的数组 + modes_list = [] + counts_list = [] + + # 计算每行的众数及其计数 + for row in range(num_rows): + row_data = data[row, :] + unique_values, value_counts = jnp.unique( + row_data, return_counts=True, size=row_data.shape[0] + ) + max_count_idx = jnp.argmax(value_counts) + modes_list.append(unique_values[max_count_idx]) + counts_list.append(value_counts[max_count_idx]) + + # 将列表转换为 jnp.array + modes = jnp.array(modes_list, dtype=data.dtype) + counts = jnp.array(counts_list, dtype=jnp.int32) + + return modes, counts diff --git a/sml/ensemble/tests/BUILD.bazel b/sml/ensemble/tests/BUILD.bazel index 7832e732..28a29227 100644 --- a/sml/ensemble/tests/BUILD.bazel +++ b/sml/ensemble/tests/BUILD.bazel @@ -11,3 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +load("@rules_python//python:defs.bzl", "py_test") + +package(default_visibility = ["//visibility:public"]) + +py_test( + name = "forest_test", + srcs = ["forest_test.py"], + deps = [ + "//sml/ensemble:forest", + "//sml/tree:tree", + "//spu:init", + "//spu/utils:simulation", + ], +) diff --git a/sml/forest/tests/forest_test.py b/sml/ensemble/tests/forest_test.py similarity index 77% rename from sml/forest/tests/forest_test.py rename to sml/ensemble/tests/forest_test.py index 93eb9930..a90c4b68 100644 --- a/sml/forest/tests/forest_test.py +++ b/sml/ensemble/tests/forest_test.py @@ -11,13 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -''' -Author: Li Zhihang -Date: 2024-06-16 12:03:08 -LastEditTime: 2024-07-01 18:25:23 -FilePath: /klaus/spu/sml/forest/tests/forest_test.py -Description: -''' import unittest import jax.numpy as jnp @@ -26,7 +19,7 @@ import spu.spu_pb2 as spu_pb2 # type: ignore import spu.utils.simulation as spsim -from sml.forest.forest import RandomForestClassifier as sml_rfc +from sml.ensemble.forest import RandomForestClassifier as sml_rfc MAX_DEPTH = 3 @@ -34,28 +27,24 @@ class UnitTests(unittest.TestCase): def test_forest(self): def proc_wrapper( - n_estimators=100, - max_features=None, - n_features=199, - criterion='gini', - splitter='best', - max_depth=3, - bootstrap=True, - max_samples=None, - n_labels=3, - seed=0, + n_estimators, + max_features, + criterion, + splitter, + max_depth, + bootstrap, + max_samples, + n_labels, ): rf_custom = sml_rfc( n_estimators, max_features, - n_features, criterion, splitter, max_depth, bootstrap, max_samples, n_labels, - seed, ) def proc(X, y): @@ -69,7 +58,6 @@ def proc(X, y): def load_data(): iris = load_iris() iris_data, iris_label = jnp.array(iris.data), jnp.array(iris.target) - # sorted_features: n_samples * n_features_in n_samples, n_features_in = iris_data.shape n_labels = len(jnp.unique(iris_label)) sorted_features = jnp.sort(iris_data, axis=0) @@ -89,14 +77,12 @@ def load_data(): # load mock data X, y = load_data() - n_samples, n_features = X.shape n_labels = jnp.unique(y).shape[0] - print(y) # compare with sklearn rf = RandomForestClassifier( n_estimators=3, - max_features=None, + max_features="log2", criterion='gini', max_depth=MAX_DEPTH, bootstrap=True, @@ -106,27 +92,22 @@ def load_data(): score_plain = rf.score(X, y) # 获取每棵树的预测值 tree_predictions = jnp.array([tree.predict(X) for tree in rf.estimators_]) - # print("sklearn:") - # print(tree_predictions) - print(n_features) + # run proc = proc_wrapper( n_estimators=3, - max_features=None, - n_features=n_features, + max_features="log2", criterion='gini', splitter='best', max_depth=3, bootstrap=True, max_samples=0.7, n_labels=n_labels, - seed=0, ) result = spsim.sim_jax(sim, proc)(X, y) - # print(y_sample) - score_encrpted = jnp.sum((result == y)) / n_samples + score_encrpted = jnp.mean((result == y)) # print acc print(f"Accuracy in SKlearn: {score_plain}") diff --git a/sml/forest/forest.py b/sml/forest/forest.py deleted file mode 100644 index 9a7ba03a..00000000 --- a/sml/forest/forest.py +++ /dev/null @@ -1,244 +0,0 @@ -# Copyright 2023 Ant Group Co., Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -''' -Author : error: git config user.email & please set dead value or install git -Date : 2024-06-15 20:01 -LastEditors: Please set LastEditors -LastEditTime: 2024-06-18 19:09:17 -FilePath: /klaus/spu-klaus/sml/forest/forest.py -Description : 基本完成函数编写的工作,目前测试结果基本正确,后面需要完成emul和test -bootstrap有问题,bootstrap后predict不输出1,bootstrap无1(因为不支持jax.random的api) - -!最终:bootstrap这个参数,不可用:在明文下bootstrap取样正确,但在forest_test.py时,无法取到标签1, -因为bootstrap不能用,目前没有开发max_samples这个超参数 - -基本完成,因为不支持jax.random的api,所以bootstrap和select_features都没有使用随机算法, -导致max_features为sqrt和log2的时候误差较大,而sklearn的select_features比较随机,因此误差大 -如果不增加n_features这个参数的话,会导致创建动态数组,并对动态数组切分,导致程序报错 -''' -import jax.numpy as jnp -from jax import lax - -from sml.tree.tree import DecisionTreeClassifier as sml_dtc - -# from functools import partial -# from jax import jit -# import jax.random as jdm - - -# key = jdm.PRNGKey(42) - - -class RandomForestClassifier: - """A random forest classifier.""" - - def __init__( - self, - n_estimators, - max_features, - n_features, - criterion, - splitter, - max_depth, - bootstrap, - max_samples, - n_labels, - seed, - ): - assert criterion == "gini", "criteria other than gini is not supported." - assert splitter == "best", "splitter other than best is not supported." - assert ( - n_estimators is not None and n_estimators > 0 - ), "n_estimators should not be None and must > 0." - assert ( - max_depth is not None and max_depth > 0 - ), "max_depth should not be None and must > 0." - assert isinstance( - bootstrap, bool - ), "bootstrap should be a boolean value (True or False)" - assert isinstance(n_features, int), "n_features should be an integer." - if isinstance(max_features, int): - assert ( - max_features <= n_features - ), "max_features should not exceed n_features when it's an integer" - max_features = jnp.array(n_features, dtype=int) - elif isinstance(max_features, float): - assert ( - 0 < max_features <= 1 - ), "max_features should be in the range (0, 1] when it's a float" - max_features = jnp.array((max_features * n_features), dtype=int) - elif isinstance(max_features, str): - if max_features == 'sqrt': - max_features = jnp.array(jnp.sqrt(n_features), dtype=int) - elif max_features == 'log2': - max_features = jnp.array(jnp.log2(n_features), dtype=int) - else: - max_features = n_features - else: - max_features = n_features - - self.seed = seed - # self.key = key - - self.n_estimators = n_estimators - self.max_features = max_features - self.n_features = n_features - self.criterion = criterion - self.splitter = splitter - self.max_depth = max_depth - self.bootstrap = bootstrap - self.max_samples = max_samples - # max_samples 这个参数在bootstrap为ture才有效,可以为float(0,1],int<总数,默认为none - self.n_labels = n_labels - - self.trees = [] - self.features_indices = [] - - def _bootstrap_sample(self, X, y): - n_samples = X.shape[0] - if isinstance(self.max_samples, int): - assert ( - self.max_samples <= n_samples - ), "max_samples should not exceed n_samples when it's an integer" - max_samples = self.max_samples - elif isinstance(self.max_samples, float): - assert ( - 0 < self.max_samples <= 1 - ), "max_samples should be in the range (0, 1] when it's a float" - max_samples = (int)(self.max_samples * n_samples) - else: - max_samples = n_samples - - if not self.bootstrap: - return X, y - # 使用斐波那契数列变体生成伪随机索引 - indices = jnp.zeros(max_samples, dtype=int) - rng = self.seed - # a, b = self.seed % n_samples, (self.seed + 1) % n_samples - for i in range(max_samples): - rng = (rng * 48271 + 1) % (2**31 - 1) - indices = indices.at[i].set(rng % n_samples) - # indices = indices.at[i].set(b % n_samples) - # a, b = b, (a + b) % n_samples # 生成斐波那契数列并取模 - self.seed += 1 - # 更新种子值以保证每次调用生成不同的序列 - return X[indices], y[indices] - - # 可以用,但没n选k - def _select_features(self): - # if isinstance(self.max_features, int): - # assert self.max_features <= n_features, "max_features should not exceed n_features when it's an integer" - # max_features = jnp.array(n_features, dtype=int) - # elif isinstance(self.max_features, float): - # assert 0 < self.max_features <= 1, "max_features should be in the range (0, 1] when it's a float" - # max_features = jnp.array((self.max_features * n_features), dtype=int) - # elif isinstance(self.max_features, str): - # if self.max_features == 'sqrt': - # max_features = jnp.array(jnp.sqrt(n_features), dtype=int) - # elif self.max_features == 'log2': - # max_features = jnp.array(jnp.log2(n_features), dtype=int) - # else: - # max_features = n_features - # else: - # max_features = n_features - - selected_indices = self._shuffle_indices(self.n_features, self.max_features)[ - : self.max_features - ] - return selected_indices - - def _shuffle_indices(self, n, k): - # 基于种子的循环洗牌算法,确保不出现重复的索引 - rng = self.seed - indices = jnp.arange(n) - # indices = range(n) - - def cond_fun(state): - i, _, _ = state - return i < k - - def body_fun(state): - i, rng, indices = state - rng = (rng * 48271 + 1) % (2**31 - 1) - j = i + rng % (n - i) - # 交换元素以进行洗牌 - indices = self._swap(indices, i, j) - self.seed += 1 - return (i + 1, rng, indices) - - _, _, shuffled_indices = lax.while_loop(cond_fun, body_fun, (0, rng, indices)) - # selected_indices = shuffled_indices[:k] - selected_indices = shuffled_indices - return selected_indices - - def _swap(self, array, i, j): - # 辅助函数:交换数组中的两个元素 - array = array.at[i].set(array[j]) - array = array.at[j].set(array[i]) - return array - - def fit(self, X, y): - n_samples, n_features = X.shape - self.trees = [] - self.features_indices = [] - - for _ in range(self.n_estimators): - X_sample, y_sample = self._bootstrap_sample(X, y) - features = self._select_features() - # selected_indices = self._shuffle_indices(n_features) - print(y_sample) - tree = sml_dtc(self.criterion, self.splitter, self.max_depth, self.n_labels) - tree.fit(X_sample[:, features], y_sample) - self.trees.append(tree) - self.features_indices.append(features) - - return self - - def predict(self, X): - tree_predictions = jnp.zeros((X.shape[0], self.n_estimators)) - - for i, tree in enumerate(self.trees): - features = self.features_indices[i] - print(features) - tree_predictions = tree_predictions.at[:, i].set( - tree.predict(X[:, features]) - ) - # print(tree_predictions[:, i]) - # Use majority vote to determine final prediction - y_pred, _ = jax_mode_row(tree_predictions) - # return tree_predictions - return y_pred.ravel() - - -def jax_mode_row(data): - # 获取每行的众数 - - # 获取数据的形状 - num_rows, num_cols = data.shape - - # 初始化众数和计数的数组 - modes = jnp.zeros(num_rows, dtype=data.dtype) - counts = jnp.zeros(num_rows, dtype=jnp.int32) - - # 计算每行的众数及其计数 - for row in range(num_rows): - row_data = data[row, :] - unique_values, value_counts = jnp.unique( - row_data, return_counts=True, size=row_data.shape[0] - ) - max_count_idx = jnp.argmax(value_counts) - modes = modes.at[row].set(unique_values[max_count_idx]) - counts = counts.at[row].set(value_counts[max_count_idx]) - - return modes, counts diff --git a/sml/quantile/BUILD.bazel b/sml/quantile/BUILD.bazel new file mode 100644 index 00000000..abafa2c9 --- /dev/null +++ b/sml/quantile/BUILD.bazel @@ -0,0 +1,25 @@ +# Copyright 2023 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_python//python:defs.bzl", "py_library") + +package(default_visibility = ["//visibility:public"]) + +py_library( + name = "quantile", + srcs = ["quantile.py"], + deps = [ + "//sml/quantile/utils:linprog", + ], +) diff --git a/sml/quantile/quantile.py b/sml/quantile/quantile.py new file mode 100644 index 00000000..46e6c7a9 --- /dev/null +++ b/sml/quantile/quantile.py @@ -0,0 +1,299 @@ +''' +Author: Li Zhihang +Date: 2024-07-03 11:29:34 +LastEditTime: 2024-07-27 19:24:52 +FilePath: /klaus/spu/sml/quantile/quantile.py +Description: 报错status=1,原因是b矩阵有负值,需要将其转为非负值 +''' +import jax.numpy as jnp +from jax import grad +import jax + +import numbers +import pandas as pd +from warnings import warn +import warnings + +# from scipy.optimize import linprog + +# from _linprog import _linprog_simplex +from sml.quantile.utils.linprog import _linprog_simplex + +def _num_samples(x): + """返回x中的样本数量.""" + if hasattr(x, 'fit'): + # 检查是否是一个estimator + raise TypeError('Expected sequence or array-like, got estimator') + if not hasattr(x, '__len__') and not hasattr(x, 'shape') and not hasattr(x, '__array__'): + raise TypeError("Expected sequence or array-like, got %s" % type(x)) + + if hasattr(x, 'shape'): + if len(x.shape) == 0: # scalar + raise TypeError("Singleton array %r cannot be considered a valid collection." % x) + return x.shape[0] + else: + return len(x) + +def _check_sample_weight(sample_weight, X, dtype=None, copy=False, only_non_negative=False): + ''' + description: 验证样本权重. + return {*} + ''' + # jax默认只支持float32, + # 如果需要启用 float64 类型,可以设置 jax_enable_x64 配置选项或 JAX_ENABLE_X64 环境变量。 + n_samples = _num_samples(X) + if dtype is not None and dtype not in [jnp.float32, jnp.float64]: + dtype = jnp.float32 + + if sample_weight is None: + sample_weight = jnp.ones(n_samples, dtype=dtype) + elif isinstance(sample_weight, numbers.Number): + sample_weight = jnp.full(n_samples, sample_weight, dtype=dtype) + else: + sample_weight = jnp.asarray(sample_weight, dtype=dtype) + if sample_weight.ndim != 1: + raise ValueError("Sample weights must be 1D array or scalar") + + if sample_weight.shape[0] != n_samples: + raise ValueError( + "sample_weight.shape == {}, expected {}!".format( + sample_weight.shape, (n_samples,) + ) + ) + + if only_non_negative and not jnp.all(sample_weight >= 0): + raise ValueError("`sample_weight` cannot contain negative weights") + + if copy: + sample_weight = jnp.copy(sample_weight) + + return sample_weight + +def _safe_indexing(X, indices, *, axis=0): + if indices is None: + return X + + if axis not in (0, 1): + raise ValueError( + "'axis' should be either 0 (to index rows) or 1 (to index " + " column). Got {} instead.".format(axis) + ) + + if axis == 0 and isinstance(indices, str): + raise ValueError("String indexing is not supported with 'axis=0'") + + if axis == 1 and isinstance(X, list): + raise ValueError("axis=1 is not supported for lists") + + if axis == 1 and hasattr(X, "shape") and len(X.shape) != 2: + raise ValueError( + "'X' should be a 2D JAXNumPy array, " + "dataframe when indexing the columns (i.e. 'axis=1'). " + "Got {} instead with {} dimension(s).".format(type(X), len(X.shape)) + ) + + if axis == 1 and isinstance(indices, str) and not isinstance(X, pd.DataFrame): + raise ValueError( + "Specifying the columns using strings is only supported for dataframes." + ) + + if isinstance(X, pd.DataFrame): + return pandas_indexing(X, indices, axis=axis) + elif isinstance(X, jnp.ndarray): + return numpy_indexing(X, indices, axis=axis) + elif isinstance(X, list): + return list_indexing(X, indices, axis=axis) + else: + raise ValueError("Unsupported input type for X: {}".format(type(X))) + +def pandas_indexing(X, indices, axis=0): + if axis == 0: + return X.iloc[indices] + elif axis == 1: + return X[indices] + +def numpy_indexing(X, indices, axis=0): + if axis == 0: + return X[indices] + elif axis == 1: + return X[:, indices] + +def list_indexing(X, indices, axis=0): + if axis == 0: + return [X[idx] for idx in indices] + else: + raise ValueError("axis=1 is not supported for lists") + +class QuantileRegressor: + + def __init__(self, quantile=0.5, alpha=1.0, fit_intercept=True, lr=0.01, max_iter=1000, n_samples=100): + self.quantile = quantile + self.alpha = alpha + self.fit_intercept = fit_intercept + self.lr = lr + self.max_iter = max_iter + self.n_samples = n_samples + + self.coef_ = None + self.intercept_ = None + + def fit(self, X, y, sample_weight=None): + n_samples, n_features = X.shape + n_params = n_features + + # sample_weight = _check_sample_weight(sample_weight, X) + sample_weight = jnp.ones((self.n_samples,)) + + if self.fit_intercept: + n_params += 1 + + alpha = jnp.sum(sample_weight) * self.alpha + + # indices = jnp.nonzero(sample_weight)[0] + # n_indices = len(indices) + # if n_indices < len(sample_weight): + # sample_weight = sample_weight[indices] + # X = _safe_indexing(X, indices) + # y = _safe_indexing(y, indices) + + c = jnp.concatenate( + [ + jnp.full(2 * n_params, fill_value=alpha), + sample_weight * self.quantile, + sample_weight * (1 - self.quantile), + ] + ) + + if self.fit_intercept: + c = c.at[0].set(0) + c = c.at[n_params].set(0) + + # eye = jnp.eye(n_indices) + eye = jnp.eye(self.n_samples) + if self.fit_intercept: + # ones = jnp.ones((n_indices,1)) + ones = jnp.ones((self.n_samples,1)) + A = jnp.concatenate([ones, X, -ones, -X, eye, -eye], axis=1) + else: + A = jnp.concatenate([X, -X, eye, -eye], axis=1) + + b = y + + n,m = A.shape + av = jnp.arange(n) + m + + result = _linprog_simplex(c, A, b, maxiter=self.max_iter,tol=1e-3) + + # result = linprog(c,A_eq=A,b_eq=b,method='simplex') + # print("result",result) + # print("Optimal solution:", result['x']) + # print("Optimal solution:", result[0]) + # print("Optimal value:", result[0]@c) + solution = result[0] + # solution = result['x'] + # 取消了1: "Iteration limit reached."因为这个方法就是达到迭代次数停止的 + # if not result[1]: + # # if not result['success']: + # failure = { + # 1: "Iteration limit reached.", + # 2: "Optimization failed. Unable to find a feasible" + # " starting point.", + # 3: "Optimization failed. The problem appears to be unbounded.", + # 4: "Optimization failed. Singular matrix encountered." + # } + # warnings.warn( + # "Linear programming for QuantileRegressor did not succeed.\n" + # f"Status is {result[1]}: " + # # + failure.setdefault(result[1], "unknown reason") + # + "\n" + # + "Result message of linprog:\n" + # # + result[2], + # # ConvergenceWarning, + # ) + + params = solution[:n_params] - solution[n_params : 2 * n_params] + # print("params:",params) + self.n_iter_ = result[2] + # self.n_iter_ = result['nit'] + + if self.fit_intercept: + self.coef_ = params[1:] + self.intercept_ = params[0] + else: + self.coef_ = params + self.intercept_ = 0.0 + return self + + def predict(self, X): + if self.fit_intercept: + X = jnp.column_stack((jnp.ones(X.shape[0]), X)) + # print("X:", X) + # print("intercept_:", self.intercept_) + # print("coef_:", self.coef_) + return jnp.dot(X, jnp.hstack([self.intercept_, self.coef_])) + else: + return jnp.dot(X, self.coef_) + # return jnp.dot(X, jnp.hstack([self.intercept_, self.coef_])) + +from jax import random +from sklearn.linear_model import QuantileRegressor as SklearnQuantileRegressor +@jax.jit +def compare_quantile_regressors(X, y, quantile=0.2, alpha=0.1, lr=0.01, max_iter=1000): + # 训练和预测自定义模型 + custom_model = QuantileRegressor(quantile=quantile, alpha=alpha, fit_intercept=True, lr=lr, max_iter=max_iter) + custom_model.fit(X, y) + custom_y_pred = custom_model.predict(X) + + print("Custom Model:") + print("Mean of y <= Custom Predictions:", jnp.mean(y <= custom_y_pred)) + print("Custom Coefficients:", custom_model.coef_) + print("Custom Intercept:", custom_model.intercept_) + +if __name__ == "__main__": + key = random.PRNGKey(42) + key, subkey = random.split(key) + X = random.normal(subkey, (100, 2)) + y = 5 * X[:, 0] + 2 * X[:, 1] + random.normal(key, (100,)) * 0.1 + + compare_quantile_regressors(X,y) + + + + + +# # 设置随机种子 +# key = random.PRNGKey(42) +# # 生成 X 数据 +# key, subkey = random.split(key) +# X = random.normal(subkey, (100, 2)) +# # 生成 y 数据 +# y = 5 * X[:, 0] + 2 * X[:, 1] + random.normal(key, (100,)) * 0.1 # 高相关性,带有小噪声 + +# # print + +# custom_model = QuantileRegressor(quantile=0.2, alpha=0.1, fit_intercept=True, lr=0.01, max_iter=1000) +# custom_model.fit(X, y) +# custom_y_pred = custom_model.predict(X) + +# print(jnp.mean(y <= custom_model.predict(X))) +# print("Custom Coefficients:", custom_model.coef_) +# print("Custom Intercept:", custom_model.intercept_) + + +# from sklearn.linear_model import QuantileRegressor as SklearnQuantileRegressor + +# sklearn_model = SklearnQuantileRegressor(quantile=0.2, alpha=0.1, fit_intercept=True, solver='highs') +# sklearn_model.fit(X, y) +# sklearn_y_pred = sklearn_model.predict(X) + +# print(jnp.mean(y <= sklearn_model.predict(X))) +# print("Sklearn Coefficients:", sklearn_model.coef_) +# print("Sklearn Intercept:", sklearn_model.intercept_) + + + + +# # Print first 10 predictions +# print("Sklearn Predictions:", sklearn_y_pred[:10]) +# print("Custom Predictions:", custom_y_pred[:10]) + diff --git a/sml/forest/tests/BUILD.bazel b/sml/quantile/tests/BUILD.bazel similarity index 88% rename from sml/forest/tests/BUILD.bazel rename to sml/quantile/tests/BUILD.bazel index 84133b03..ac8dad19 100644 --- a/sml/forest/tests/BUILD.bazel +++ b/sml/quantile/tests/BUILD.bazel @@ -17,11 +17,10 @@ load("@rules_python//python:defs.bzl", "py_test") package(default_visibility = ["//visibility:public"]) py_test( - name = "forest_test", - srcs = ["forest_test.py"], + name = "quantile_test", + srcs = ["quantile_test.py"], deps = [ - "//sml/forest", - "//sml/tree", + "//sml/quantile", "//spu:init", "//spu/utils:simulation", ], diff --git a/sml/quantile/tests/quantile_test.py b/sml/quantile/tests/quantile_test.py new file mode 100644 index 00000000..8f0f796f --- /dev/null +++ b/sml/quantile/tests/quantile_test.py @@ -0,0 +1,102 @@ +''' +Author: Li Zhihang +Date: 2024-07-07 20:27:16 +LastEditTime: 2024-07-27 11:36:49 +FilePath: /klaus/spu/sml/quantile/tests/quantile_test.py +Description: +''' +import unittest + +import jax.numpy as jnp +import numpy as np +from sklearn.linear_model import QuantileRegressor as SklearnQuantileRegressor + +import spu.spu_pb2 as spu_pb2 # type: ignore +import spu.utils.simulation as spsim +from sml.quantile.quantile import QuantileRegressor as SmlQuantileRegressor + + +class UnitTests(unittest.TestCase): + def test_forest(self): + def proc_wrapper( + quantile=0.5, + alpha=1.0, + fit_intercept=True, + lr=0.01, + max_iter=1000, + ): + quantile_custom = SmlQuantileRegressor( + quantile, + alpha, + fit_intercept, + lr, + max_iter, + ) + + def proc(X, y): + quantile_custom_fit = quantile_custom.fit(X, y) + # acc = jnp.mean(y <= quantile_custom_fit.predict(X)) + result = quantile_custom_fit.predict(X) + return result + + return proc + + n_samples, n_features = 100, 2 + # def generate_data(): + # """ + # Generate random data for testing. + + # Returns: + # ------- + # X : array-like, shape (n_samples, n_features) + # Feature data. + # y : array-like, shape (n_samples,) + # Target data. + # coef : array-like, shape (n_features + 1,) + # True coefficients, including the intercept term and feature weights. + + # """ + # np.random.seed(42) + # X = np.random.rand(n_samples, n_features) + # coef = np.random.rand(n_features + 1) # +1 for the intercept term + # y = X @ coef[1:] + coef[0] + # sample_weight = np.random.rand(n_samples) + # return X, y, coef, sample_weight + + def generate_data(): + from jax import random + # 设置随机种子 + key = random.PRNGKey(42) + # 生成 X 数据 + key, subkey = random.split(key) + X = random.normal(subkey, (100, 2)) + # 生成 y 数据 + y = 5 * X[:, 0] + 2 * X[:, 1] + random.normal(key, (100,)) * 0.1 # 高相关性,带有小噪声 + return X, y + + # bandwidth and latency only work for docker mode + sim = spsim.Simulator.simple( + 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 + ) + + # X, y, coef, sample_weight = generate_data() + X, y = generate_data() + + # compare with sklearn + quantile_sklearn = SklearnQuantileRegressor(quantile=0.7, alpha=0.1, fit_intercept=True, solver='highs') + quantile_sklearn_fit = quantile_sklearn.fit(X, y) + acc_sklearn = jnp.mean(y <= quantile_sklearn_fit.predict(X)) + print(f"Accuracy in SKlearn: {acc_sklearn:.2f}") + + # run + proc = proc_wrapper(quantile=0.7, alpha=0.1, fit_intercept=True, lr=0.01, max_iter=1000) + result = spsim.sim_jax(sim, proc)(X, y) + acc_custom = jnp.mean(y <= result) + + # print acc + + print(f"Accuracy in SPU: {acc_custom:.2f}") + +if __name__ == "__main__": + unittest.main() + diff --git a/sml/forest/BUILD.bazel b/sml/quantile/utils/BUILD.bazel similarity index 93% rename from sml/forest/BUILD.bazel rename to sml/quantile/utils/BUILD.bazel index 9d1d0ade..04ca5c96 100644 --- a/sml/forest/BUILD.bazel +++ b/sml/quantile/utils/BUILD.bazel @@ -17,6 +17,7 @@ load("@rules_python//python:defs.bzl", "py_library") package(default_visibility = ["//visibility:public"]) py_library( - name = "forest", - srcs = ["forest.py"], + name = "linprog", + srcs = ["linprog.py"], ) + diff --git a/sml/quantile/utils/linprog.py b/sml/quantile/utils/linprog.py new file mode 100644 index 00000000..d3faa024 --- /dev/null +++ b/sml/quantile/utils/linprog.py @@ -0,0 +1,478 @@ +import jax.numpy as jnp +from warnings import warn +import warnings +from jax import lax +from jax import jit +import jax + +# 这个是使用if-else实现的linprog_simplex函数,但是 +# 由于jax中if-else和while-loop的限制,需使用lax.cond函数,所以 +# 还写了一版使用lax.cond实现的linprog_simplex函数,见test.py +# 但由于test.py中的_pivot_row函数中获得min_rows的时候,数据为tracer +# 报错:jax.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[]. +# 所以暂时无法运行 +# 目前完成linprog后可以替换quantile.py的代码,之后完成tests和emulations +# 先去看adaboost 2024.7.11 + +# 现在的问题是if语句导致的修改矩阵T大小不知道咋处理,jit编译后报错因为都是tracer 7.14 +# 解决办法是删除if判断 +# 但是出现新的错误,不会处理 + + +# from collections import namedtuple + +# _LPProblem = namedtuple('_LPProblem', +# 'c A_ub b_ub A_eq b_eq bounds x0 integrality') +# _LPProblem.__new__.__defaults__ = (None,) * 7 # make c the only required arg + + +def _pivot_col(T, tol=1e-9, bland=False): + # 创建掩码数组 + mask = T[-1, :-1] >= -tol + + # 定义根据 Bland's 规则选择第一个未被掩盖元素的函数 + def bland_func(_): + return jnp.argmin(jnp.where(mask, jnp.inf, jnp.arange(T.shape[1] - 1))) + + # 定义根据最小值选择列的函数 + def min_func(_): + ma = jnp.where(mask, jnp.inf, T[-1, :-1]) + return jnp.argmin(ma) + + # 检查掩码数组是否全被掩盖 + all_masked = jnp.all(mask) + + # 使用 jax.lax.cond 根据条件选择函数 + result = lax.cond(bland, bland_func, min_func, operand=None) + + # 将返回值转换为浮点数类型,以匹配 NaN 的类型 + result = result + + # 使用 jax.lax.cond 处理全被掩盖的情况 + valid = ~all_masked + + result = lax.cond(all_masked, lambda _: 0, lambda _: result, operand=None) + + return valid, result + +def _pivot_row(T, basis, pivcol, phase, tol=1e-9, bland=False): + # pivcol = lax.cond(jnp.isnan(pivcol), lambda _: jnp.nan, lambda _: jnp.float32(pivcol), operand=None) + # pivcol = jnp.floor(pivcol).astype(jnp.int32) + # if jnp.isnan(pivcol): + # pivcol = jnp.nan + # else: + # pivcol = pivcol.astype(jnp.int32) + # k = lax.cond(phase == 1, lambda _: 2, lambda _: 1, operand=None) + # k = 1 + def true_mask_func(): + mask = T[:-2, pivcol] <= tol + ma = jnp.where(mask, jnp.inf, T[:-2, pivcol]) + mb = jnp.where(mask, jnp.inf, T[:-2, -1]) + + q = mb / ma + # 选择最小比值的行 + min_rows = jnp.nanargmin(q) + all_masked = jnp.all(mask) + return min_rows, all_masked + + def false_mask_func(): + mask = T[:-1, pivcol] <= tol + ma = jnp.where(mask, jnp.inf, T[:-1, pivcol]) + mb = jnp.where(mask, jnp.inf, T[:-1, -1]) + + q = mb / ma + # 选择最小比值的行 + min_rows = jnp.nanargmin(q) + all_masked = jnp.all(mask) + return min_rows, all_masked + + min_rows, all_masked = lax.cond(phase==1, true_mask_func, false_mask_func) + + # 定义处理全被掩盖情况的函数 + def all_masked_func(_): + return 0 + + # 定义选择最小比值行的函数 + def bland_func(_): + return min_rows + # return min_rows[jnp.argmin(jnp.take(basis, min_rows))] + + def min_row_func(_): + return min_rows + + # 检查掩码数组是否全被掩盖 + # all_masked = jnp.all(mask) + has_valid_row = min_rows.size > 0 + + row = lax.cond(bland, bland_func, min_row_func, operand=None) + + row = row + + # 使用 jax.lax.cond 处理全被掩盖的情况 + row = lax.cond(all_masked, all_masked_func, lambda _: row, operand=None) + + # 使用 jax.lax.cond 处理没有满足条件的行的情况 + row = lax.cond(has_valid_row, lambda r: r, lambda _: 0, row) + + return ~all_masked & has_valid_row, row + +def _apply_pivot(T, basis, pivrow, pivcol, tol=1e-9): + + basis = basis.at[pivrow].set(pivcol) + pivval = T[pivrow, pivcol] + T = T.at[pivrow].set(T[pivrow] / pivval) + + def update_row(irow, T, pivrow, pivcol): + pivrow_vector = T[pivrow] # shape (n,) + scalar = T[irow, pivcol] # shape () + # print(f"pivrow_vector shape: {pivrow_vector.shape}, scalar shape: {scalar.shape}") + updated_row = T[irow] - pivrow_vector * scalar + T = T.at[irow].set(updated_row) + return T + + def not_update_row(irow, T, pivrow, pivcol): + return T + + # Update all other rows + def condition(state): + irow, T, pivrow, pivcol= state + return irow < T.shape[0] + + def body(state): + irow, T, pivrow, pivcol = state + T = lax.cond(irow != pivrow, + # lambda _: update_row(irow, T, pivrow, pivcol), + # lambda _: T, + lambda _: update_row(irow, T, pivrow, pivcol), + lambda _: not_update_row(irow, T, pivrow, pivcol), + operand = None) + return irow + 1, T, pivrow, pivcol + + state = 0, T, pivrow, pivcol + irow, T, pivrow, pivcol = lax.while_loop(cond_fun=condition, body_fun=body, init_val=state) + + # if jnp.isclose(pivval, tol, atol=0, rtol=1e4): + # message = ( + # f"The pivot operation produces a pivot value of:{pivval: .1e}, " + # "which is only slightly greater than the specified " + # f"tolerance{tol: .1e}. This may lead to issues regarding the " + # "numerical stability of the simplex method. " + # "Removing redundant constraints, changing the pivot strategy " + # "via Bland's rule or increasing the tolerance may " + # "help reduce the issue.") + # warn(message, stacklevel=5) + + return T, basis + +def _solve_simplex(T, n, basis, + maxiter=1000, tol=1e-9, phase=2, bland=False, nit0=0, + ): + # 删除callback参数,删除postsolve_args参数 + nit = nit0 + status = 0 + message ='' + complete = False + # a,b = T_new_shape + # assert phase in [1, 2],"Argument 'phase' to _solve_simplex must be 1 or 2" + m = lax.cond(phase == 1, lambda _: T.shape[1]-2, lambda _: T.shape[1]-1, operand=None) + # print(m) + def func_col(T, pivrow, basis, tol, nit): + def body_fun(carry): + col, pivrow, T, basis, nit, found = carry + + def apply_pivot_and_update(T, basis, pivrow, col, nit, tol): + pivcol = col + T, basis = _apply_pivot(T, basis, pivrow, pivcol, tol) + nit = nit + 1 + return T, basis, nit + + T, basis, nit = lax.cond(abs(T[pivrow, col]) > tol, + lambda _: apply_pivot_and_update(T, basis, pivrow, col, nit, tol), + lambda _: (T, basis, nit), + (T, basis, nit)) + + found = found | (jnp.abs(T[pivrow, col]) > tol) + + return col + 1, pivrow, T, basis, nit, found + + def cond_fun(carry): + col, pivrow, T, basis, nit, found = carry + return jnp.logical_and(col < T.shape[1] - 1, ~found) + + _, _, T, basis, nit, _ = lax.while_loop(cond_fun, body_fun, (0, pivrow, T, basis, nit, False)) + + return T, basis, nit + + def func_row(T, pivrow, basis, tol, nit): + def body_fun(carry): + T, pivrow, basis, tol, nit = carry + T, basis, nit = lax.cond(basis[pivrow] > T.shape[1] - 2, + lambda _: func_col(T, pivrow, basis, tol, nit), + lambda _: (T, basis, nit), + (T, basis, nit)) + return T, pivrow + 1, basis, tol, nit + + def cond_fun(carry): + T, pivrow, basis, tol, nit = carry + return pivrow < basis.size + + T, pivrow, basis, tol, nit = lax.while_loop(cond_fun, body_fun, (T, pivrow, basis, tol, nit)) + return T, pivrow, basis, tol, nit + + T, pivrow, basis, tol, nit = lax.cond(phase == 2, + lambda _: func_row(T, 0, basis, tol, nit), + lambda _: (T, 0, basis, tol, nit), + (T, 0, basis, tol, nit)) + + def cond_ifnot_complete(carry): + T, basis, pivcol, pivrow, phase, tol, bland, status, complete, nit, maxiter = carry + return ~complete + # 这个complete不会更新,一直为False + + def body_ifnot_complete(carry): + T, basis, pivcol, pivrow, phase, tol, bland, status, complete, nit, maxiter = carry + pivcol_found, pivcol = _pivot_col(T, tol, bland) + pivrow_found = False + def cal_pivcol_found_True(T, basis, pivcol, phase, tol, bland, status, complete): + pivrow_found, pivrow = _pivot_row(T, basis, pivcol, phase, tol, bland) + status, complete = lax.cond(pivrow_found==False, + lambda _: (3, True), + lambda _: (status, complete), + None) + return pivrow_found, pivrow, status, complete + + pivcol, pivrow, status, complete = lax.cond(pivcol_found==False, + lambda _: (0, 0, 0, True), + lambda _: (pivcol, pivrow, status, complete), + None) + pivrow_found, pivrow, status, complete = lax.cond(pivcol_found==True, + lambda _: cal_pivcol_found_True(T, basis, pivcol, phase, tol, bland, status, complete), + lambda _: (pivrow_found, pivrow, status, complete), + None) + def cal_ifnot_complete(T, basis, nit, status, complete, maxiter): + status, complete = lax.cond(nit >= maxiter, lambda _: (1, True), lambda _: (status, complete), None) + T, basis = lax.cond(nit < maxiter, lambda _: _apply_pivot(T, basis, pivrow, pivcol, tol), lambda _: (T, basis), None) + nit = lax.cond(nit < maxiter, lambda _: nit+1, lambda _: nit, None) + return T, basis, nit, status, complete + + T, basis, nit, status, complete = lax.cond(complete==False, + lambda _: cal_ifnot_complete(T, basis, nit, status, complete, maxiter), + lambda _: (T, basis, nit, status, complete), + None) + return T, basis, pivcol, pivrow, phase, tol, bland, status, complete, nit, maxiter + + T, basis, pivcol, pivrow, phase, tol, bland, status, complete, nit, maxiter = lax.while_loop(cond_ifnot_complete, + body_ifnot_complete, + (T, basis, 0, 0, phase, tol, bland, status, complete, nit, maxiter)) + return T, basis, nit, status + # while not complete: + # pivcol_found, pivcol = _pivot_col(T, tol, bland) + # print(pivcol_found) + # pivrow_found = False + # def cal_pivcol_found_True(T, basis, pivcol, phase, tol, bland, status, complete): + # pivrow_found, pivrow = _pivot_row(T, basis, pivcol, phase, tol, bland) + # status, complete = lax.cond(pivrow_found==False, + # lambda _: (3, True), + # lambda _: (status, complete), + # None) + # return pivrow_found, pivrow, status, complete + + # pivcol, pivrow, status, complete = lax.cond(pivcol_found==False, + # lambda _: (0, 0, 0, True), + # lambda _: (pivcol, pivrow, status, complete), + # None) + # pivrow_found, pivrow, status, complete = lax.cond(pivcol_found==True, + # lambda _: cal_pivcol_found_True(T, basis, pivcol, phase, tol, bland, status, complete), + # lambda _: (pivrow_found, pivrow, status, complete), + # None) + # def cal_ifnot_complete(T, basis, nit, status, complete, maxiter): + # status, complete = lax.cond(nit >= maxiter, lambda _: (1, True), lambda _: (status, complete), None) + # T, basis = lax.cond(nit < maxiter, lambda _: _apply_pivot(T, basis, pivrow, pivcol, tol), lambda _: (T, basis), None) + # nit = lax.cond(nit < maxiter, lambda _: nit+1, lambda _: nit, None) + # return T, basis, nit, status, complete + # print(T) + # print("---------------------------------------------------") + # T, basis, nit, status, complete = lax.cond(complete==False, + # lambda _: cal_ifnot_complete(T, basis, nit, status, complete, maxiter), + # lambda _: (T, basis, nit, status, complete), + # None) + + # return T, basis, nit, status + + +def _linprog_simplex(c, A, b, c0=0, + maxiter=1000, tol=1, disp=False, bland=False, + **unknown_options): + # 删除参数callback, postsolve_args + status = 0 + messages = {0: "Optimization terminated successfully.", + 1: "Iteration limit reached.", + 2: "Optimization failed. Unable to find a feasible" + " starting point.", + 3: "Optimization failed. The problem appears to be unbounded.", + 4: "Optimization failed. Singular matrix encountered."} + + n, m = A.shape + + # All constraints must have b >= 0. + is_negative_constraint = jnp.less(b, 0) + A = jnp.where(is_negative_constraint[:, None], A * -1, A) + b = jnp.where(is_negative_constraint, b * -1, b) + + av = jnp.arange(n) + m + # print(av) + basis = av.copy() + + row_constraints = jnp.hstack((A, jnp.eye(n), b[:, jnp.newaxis])) + row_objective = jnp.hstack((c, jnp.zeros(n), c0)) + row_pseudo_objective = -row_constraints.sum(axis=0) + row_pseudo_objective = row_pseudo_objective.at[av].set(0) + T = jnp.vstack((row_constraints, row_objective, row_pseudo_objective)) + # print(T) + # print(n) + # print(basis) + # phase 1 + T, basis, nit1, status = _solve_simplex(T, n, basis, maxiter=maxiter, tol=tol, phase=1, bland=bland) + # print(T) + nit2 = nit1 + + def if_abs_true(status): + return status + def if_abs_false(status): + status = 2 + return status + status = lax.cond(jnp.abs(T[-1, -1]) < tol,if_abs_true,if_abs_false,(status)) + # messages[2] = ( + # "Phase 1 of the simplex method failed to find a feasible " + # "solution. The pseudo-objective function evaluates to {0:.1e} " + # "which exceeds the required tolerance of {1} for a solution to be " + # "considered 'close enough' to zero to be a basic solution. " + # "Consider increasing the tolerance to be greater than {0:.1e}. " + # "If this tolerance is unacceptably large the problem may be " + # "infeasible.".format(abs(T[-1, -1]), tol) + # ) + # def modify_tensor(T, av, tol): + # if (abs(T[-1, -1]) < tol): + # T = T[:-1, :] + # T = jnp.delete(T, av, 1) + # return T + + # jit_modify_tensor = jit(modify_tensor, static_argnums=(0,)) + # print(jnp.any(T>0)) + # av = tuple(av.tolist()) + # print(av) + # T = jit_modify_tensor(T, av, tol) + + # av = av.item() + # print(av) + # print("av",av.shape) + + original_shape = T.shape + # print(original_shape) + T_new = T[:-1, :] + jit_delete = jit(jnp.delete, static_argnames=['assume_unique_indices']) + T = jnp.delete(T_new, av, 1, assume_unique_indices=True) + # ndicator = jnp.ones(T.shape[1], dtype=int) + # print("avdsa",av) + # def true_fn(T): + # T_new = T[:-1, :] + # jit_delete = jit(jnp.delete, static_argnames=['assume_unique_indices']) + # T_new = jnp.delete(T_new, av, 1, assume_unique_indices=True) + # # 保存有效部分的形状信息 + # # T_new_shape = jnp.array(T_new.shape) + # T_new_shape = jnp.array([original_shape[0]-1, original_shape[1]-len(av)]) + # padding = [(0, original_shape[0] - T_new.shape[0]), (0, original_shape[1] - T_new.shape[1])] + # T_padded = jnp.pad(T_new, padding, mode='constant') + # return T_padded, jnp.array(original_shape), T_new_shape + # # indicator = indicator.at[av].set(0) + + # def false_fn(T): + # return T, jnp.array(original_shape), jnp.array(original_shape) + + # T, T_shape, T_new_shape = lax.cond(abs(T[-1, -1]) < tol, true_fn, false_fn, T) + + # return T_modified, T_shape, T_new_shape + # print(T_new_shape) + def recover_tensor(T_recovered, T_shape, T_new_shape): + # 根据保存的形状信息恢复原形状 + # rows_to_keep = abs(T_new_shape[0] - T_shape[0]) + # cols_to_keep = abs(T_new_shape[1] - T_shape[1]) + + # 使用lax.dynamic_slice来保留左上角部分 + # print(T_recovered) + T_recovered = lax.dynamic_slice(T_recovered, (0, 0), (T_new_shape[0], T_new_shape[1])) + # print("T_recovered",T_recovered.shape) + return T_recovered + + # T, T_shape, T_new_shape = modify_tensor(T, av, tol) + # print(T_new_shape) + # T = recover_tensor(T, T_shape, T_new_shape) + # # phase 2 + T, basis, nit2, status = lax.cond(status == 0, + lambda _: _solve_simplex(T, n, basis, maxiter, tol, 2, bland, nit1), + lambda _: (T, basis, nit2, status), + None) + + solution = jnp.zeros(n+m) + solution = solution.at[basis[:n]].set(T[:n, -1]) + # status = status.astype(int) + x = solution[:m] + return x, status, nit2 + + + +# if __name__ == "__main__": + # T = jnp.array([ + # [ 1., 1., 0., 1., 0., 0., 4.], + # [ 2., 1., -1., 0., 1., 0., 3.], + # [-1., 2., 1., 0., 0., 1., 2.], + # [-2., -1., 2., 0., 0., 0., 0.], + # [-2., -4., -0., 0., 0., 0., -9.], + # ]) + # n = 3 + # basis = jnp.array([3,4,5]) # 假设初始的基变量索引 + + # # pivcol_found, pivcol = _pivot_col(T) + # # print(pivcol_found) + # # print(pivcol) + # pivcol = 1 + # phase = 1 + # pivrow = 2 + + # # pivrow_found, pivrow = _pivot_row(T, basis, pivcol, phase) + # # print(pivrow_found) + # # print(pivrow) + + # # T, basis = _apply_pivot(T, basis, pivrow, pivcol) + # # print(T) + # # print(basis) + + + + # T, basis, nit1, status = _solve_simplex(T, n, basis,phase=1) + # print(T) + # print(nit1) + +# # 处理结果 +# print("Final T matrix:") +# print(T_final) +# print("Final basis:", basis_final) +# print("Number of iterations:", nit_final) +# print("Status:", status_final) + + # T, basis = _apply_pivot(T, basis, 1, 0) + # print(T) + # print(basis) + + # A_eq = jnp.eye(3) + # b_eq = jnp.ones(3) + # c = jnp.array([1, 2, 3]) + # # n,m = A_eq.shape + # result = _linprog_simplex(c, A_eq, b_eq) + + # print(result) + # print(result[0]@c) + + # from scipy.optimize import linprog + # result = linprog(c, A_eq=A_eq, b_eq=b_eq, method='simplex') + # print(result) diff --git a/sml/tree/BUILD.bazel b/sml/tree/BUILD.bazel index 439ac5ea..b7857901 100644 --- a/sml/tree/BUILD.bazel +++ b/sml/tree/BUILD.bazel @@ -20,3 +20,8 @@ py_library( name = "tree", srcs = ["tree.py"], ) + +py_library( + name = "tree_w", + srcs = ["tree_w.py"], +) diff --git a/sml/tree/tree_w.py b/sml/tree/tree_w.py new file mode 100644 index 00000000..2015c4d8 --- /dev/null +++ b/sml/tree/tree_w.py @@ -0,0 +1,304 @@ +# Copyright 2023 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import jax.numpy as jnp + + +class DecisionTreeClassifier: + """A decision tree classifier based on [GTree](https://arxiv.org/abs/2305.00645). + + Adopting a MPC-based linear scan method (i.e. oblivious_array_access), GTree + designs a new GPU-friendly oblivious decision tree training protocol, which is + more efficient than the prior works. The current implementation supports the training + of decision tree with binary features (i.e. {0, 1}) and multi-class labels (i.e. {0, 1, 2, \dots}). + + We provide a simple example to show how to use GTree to train a decision tree classifier + in sml/tree/emulations/tree_emul.py. For training, the memory and time complexity is around + O(n_samples * n_labels * n_features * 2 ** max_depth). + + Parameters + ---------- + criterion : {"gini"}, default="gini" + The function to measure the quality of a split. Supported criteria are + "gini" for the Gini impurity. + + splitter : {"best"}, default="best" + The strategy used to choose the split at each node. Supported + strategies are "best" to choose the best split. + + max_depth : int + The maximum depth of the tree. Must specify an integer > 0. + + n_labels: int, the max number of labels. + """ + + def __init__(self, criterion, splitter, max_depth, n_labels): + assert criterion == "gini", "criteria other than gini is not supported." + assert splitter == "best", "splitter other than best is not supported." + assert ( + max_depth is not None and max_depth > 0 + ), "max_depth should not be None and must > 0." + self.max_depth = max_depth + self.n_labels = n_labels + # self.sample_weight = sample_weight + + def fit(self, X, y, sample_weight): + self.T, self.F = odtt(X, y, self.max_depth, self.n_labels, sample_weight) + return self + + def predict(self, X): + assert self.T != None, "the model has not been trained yet." + return odti(X, self.T, self.max_depth) + + +''' +The protocols of GTree. +''' + + +def oblivious_array_access(array, index): + ''' + Extract elements from array according to index. + + If array is 1D, then output [array[i] for i in index]. + e.g.: array = [1, 2, 3, 4, 5], index = [0, 2, 4], output = [1, 3, 5]. + + If array is 2D, then output [[array[j, i] for i in index] for j in range(array.shape[0])]. + e.g. array = [[1, 2, 3], [4, 5, 6]], index_array = [0, 2], output = [[1, 3], [4, 6]]. + ''' + # (n_array) + count_array = jnp.arange(0, array.shape[-1]) + # (n_array, n_index) + E = jnp.equal(index, count_array[:, jnp.newaxis]) + + assert len(array.shape) <= 2, "OAA protocol only supports 1D or 2D array." + + # OAA basic case + if len(array.shape) == 1: + # (n_array, n_index) + O = array[:, jnp.newaxis] * E # select shares + zu = jnp.sum(O, axis=0) + # OAA vectorization variant + elif len(array.shape) == 2: + # (n_arrays, n_array, n_index) + O = array[:, :, jnp.newaxis] * E[jnp.newaxis, :, :] # select shares + zu = jnp.sum(O, axis=1) + return zu + + +def oaa_elementwise(array, index_array): + ''' + Given index_array, output [array[i, index[i]] for i in range(len(array))]. + + e.g.: array = [[1, 2, 3], [4, 5, 6]], index = [0, 2], output = [1, 6]. + ''' + assert index_array.shape[0] == array.shape[0], "n_arrays must be equal to n_index." + assert len(array.shape) == 2, "OAAE protocol only supports 2D array." + count_array = jnp.arange(0, array.shape[-1]) + # (n_array, n_index) + E = jnp.equal(index_array[:, jnp.newaxis], count_array) + if len(array.shape) == 2: + O = array * E + zu = jnp.sum(O, axis=1) + return zu + + +# def oblivious_learning(X, y, T, F, M, Cn, h): +def oblivious_learning(X, y, T, F, M, h, Cn, n_labels, sample_weight=None): + '''partition the data and count the number of data samples. + + params: + D: data samples, which is splitted into X, y. X: (n_samples, n_features), y: (n_samples, 1). + T: tree structure reprensenting split features. (total_nodes) + F: tree structure reprensenting node types. (total_nodes) + 0 for internal, 1 for leaf, 2 for dummy. + M: which leave node does D[i] belongs to (for level h-1). (n_samples) + Cn: statical information of the data samples. (n_leaves, n_labels+1, 2*n_features) + h: int, current depth of the tree. + ''' + # line 1-5, partition the datas into new leaves. + n_d, n_f = X.shape + n_h = 2**h + if h != 0: + Tval = oaa(T, M) + Dval = oaae(X, Tval) + M = 2 * M + Dval + 1 + + LCidx = jnp.arange(0, n_h) + isLeaf = jnp.equal(F[n_h - 1 : 2 * n_h - 1], jnp.ones(n_h)) + LCF = jnp.equal(M[:, jnp.newaxis] - n_h + 1, LCidx) + LCF = LCF * isLeaf + + Cd = jnp.zeros((n_d, n_h, n_labels + 1, 2 * n_f)) + if sample_weight is not None: + Cd = Cd.at[:, :, 0, 0::2].set( + jnp.tile((1 - X)[:, jnp.newaxis, :] * sample_weight[:, jnp.newaxis, jnp.newaxis], (1, n_h, 1)) + ) + Cd = Cd.at[:, :, 0, 1::2].set( + jnp.tile((X)[:, jnp.newaxis, :] * sample_weight[:, jnp.newaxis, jnp.newaxis], (1, n_h, 1)) + ) + else: + Cd = Cd.at[:, :, 0, 0::2].set(jnp.tile((1 - X)[:, jnp.newaxis, :], (1, n_h, 1))) + Cd = Cd.at[:, :, 0, 1::2].set(jnp.tile((X)[:, jnp.newaxis, :], (1, n_h, 1))) + + for i in range(n_labels): + if sample_weight is not None: + # sample_weight = sample_weight.reshape(-1, 1, 1) + Cd = Cd.at[:, :, i + 1, 0::2].set( + jnp.tile( + ((1 - X)[:, jnp.newaxis, :] * (i == y)[:, jnp.newaxis, jnp.newaxis] * sample_weight[:, jnp.newaxis, jnp.newaxis]), + (1, n_h, 1) + ) + ) + Cd = Cd.at[:, :, i + 1, 1::2].set( + jnp.tile( + ((X)[:, jnp.newaxis, :] * (i == y)[:, jnp.newaxis, jnp.newaxis] * sample_weight[:, jnp.newaxis, jnp.newaxis]), + (1, n_h, 1) + ) + ) + else: + Cd = Cd.at[:, :, i + 1, 0::2].set( + jnp.tile( + ((1 - X) * (i == y)[:, jnp.newaxis]), + (1, n_h, 1) + ) + ) + Cd = Cd.at[:, :, i + 1, 1::2].set( + jnp.tile( + ((X) * (i == y)[:, jnp.newaxis]), + (1, n_h, 1) + ) + ) + + Cd = Cd * LCF[:, :, jnp.newaxis, jnp.newaxis] + + new_Cn = jnp.sum(Cd, axis=0) + + if h != 0: + Cn = Cn.repeat(2, axis=0) + new_Cn = new_Cn[:, :, :] + Cn[:, :, :] * (1 - isLeaf[:, jnp.newaxis, jnp.newaxis]) + + return new_Cn, M + + + +def oblivious_heuristic_computation(Cn, gamma, F, h, n_labels): + '''Compute gini index, find the best feature, and update F. + + params: + Cn: statical information of the data samples. (n_leaves, n_labels+1, 2*n_features) + gamma: gamma[n][i] indicates if feature si has been assigned at node n. (n_leaves, n_features) + F: tree structure reprensenting node types. (total_nodes) + 0 for internal, 1 for leaf, 2 for dummy. + h: int, current depth of the tree. + n_labels: int, number of labels. + ''' + n_leaves = Cn.shape[0] + n_features = gamma.shape[1] + Ds0 = Cn[:, 0, 0::2] + Ds1 = Cn[:, 0, 1::2] + D = Ds0 + Ds1 + Q = D * Ds0 * Ds1 + P = jnp.zeros(gamma.shape) + for i in range(n_labels): + P = P - Ds1 * (Cn[:, i + 1, 0::2] ** 2) - Ds0 * (Cn[:, i + 1, 1::2] ** 2) + gini = Q / (Q + P + 1) + gini = gini * gamma + # (n_leaves) + SD = jnp.argmax(gini, axis=1) + index = jnp.arange(0, n_features) + gamma = gamma * jnp.not_equal(index[jnp.newaxis, :], SD[:, jnp.newaxis]) + new_gamma = jnp.zeros((n_leaves * 2, n_features)) + new_gamma = new_gamma.at[0::2, :].set(gamma) + new_gamma = new_gamma.at[1::2, :].set(gamma) + + # # modification. + psi = jnp.zeros((n_leaves, n_labels)) + for i in range(n_labels): + psi = psi.at[:, i].set(Cn[:, i + 1, 0] + Cn[:, i + 1, 1]) + total = jnp.sum(psi, axis=1) + psi = total[:, jnp.newaxis] - psi + psi = jnp.prod(psi, axis=1) + F = F.at[2**h - 1 : 2 ** (h + 1) - 1].set( + jnp.equal(psi * F[2**h - 1 : 2 ** (h + 1) - 1], 0) + ) + F = F.at[2 ** (h + 1) - 1 : 2 ** (h + 2) - 1 : 2].set( + 2 - jnp.equal(F[2**h - 1 : 2 ** (h + 1) - 1], 0) + ) + F = F.at[2 ** (h + 1) : 2 ** (h + 2) - 1 : 2].set( + F[2 ** (h + 1) - 1 : 2 ** (h + 2) - 1 : 2] + ) + return SD, new_gamma, F + + +def oblivious_node_split(SD, T, F, Cn, h, max_depth): + '''Convert each node into its internal node and generates new leaves at the next level.''' + + T = T.at[2**h - 1 : 2 ** (h + 1) - 1].set(SD) + return T, Cn + + +def oblivious_DT_training(X, y, max_depth, n_labels, sample_weight=None): + n_samples, n_features = X.shape + T = jnp.zeros((2 ** (max_depth + 1) - 1)) + F = jnp.ones((2**max_depth - 1)) + M = jnp.zeros(n_samples) + gamma = jnp.ones((1, n_features)) + Cn = jnp.zeros((1, n_labels + 1, 2 * n_features)) + + h = 0 + while h < max_depth: + if sample_weight is not None: + Cn, M = ol(X, y, T, F, M, h, Cn, n_labels, sample_weight) + else: + Cn, M = ol(X, y, T, F, M, h, Cn, n_labels) + + SD, gamma, F = ohc(Cn, gamma, F, h, n_labels) + + T, Cn = ons(SD, T, F, Cn, h, max_depth) + + h += 1 + + n_leaves = 2**h + psi = jnp.zeros((n_leaves, n_labels)) + for i in range(2 ** (h - 1)): + t1 = oaa(Cn[i, 1:], 2 * SD[i : i + 1]).squeeze() + t2 = oaa(Cn[i, 1:], 2 * SD[i : i + 1] + 1).squeeze() + psi = psi.at[2 * i, :].set(t1) + psi = psi.at[2 * i + 1, :].set(t2) + T = T.at[n_leaves - 1 :].set(jnp.argmax(psi, axis=1)) + return T, F + + +def oblivious_DT_inference(X, T, max_height): + n_samples, n_features = X.shape + Tidx = jnp.zeros((n_samples)) + i = 0 + while i < max_height: + Tval = oaa(T, Tidx) + Dval = oaae(X, Tval) + Tidx = Tidx * 2 + Dval + 1 + i += 1 + Tval = oaa(T, Tidx) + return Tval + + +oaa = oblivious_array_access +oaae = oaa_elementwise +ol = oblivious_learning +ohc = oblivious_heuristic_computation +ons = oblivious_node_split +odtt = oblivious_DT_training +odti = oblivious_DT_inference