diff --git a/sml/ensemble/BUILD.bazel b/sml/ensemble/BUILD.bazel
index 32e3eb55..8d4a19b9 100644
--- a/sml/ensemble/BUILD.bazel
+++ b/sml/ensemble/BUILD.bazel
@@ -22,7 +22,9 @@ py_library(
     deps = [
         "//sml/tree:tree",
     ]
-
+)
+
+py_library(
     name = "forest",
     srcs = ["forest.py"],
     deps = [
diff --git a/sml/ensemble/adaboost.py b/sml/ensemble/adaboost.py
index 9dea1a56..7f81fe47 100644
--- a/sml/ensemble/adaboost.py
+++ b/sml/ensemble/adaboost.py
@@ -35,6 +35,14 @@ class AdaBoostClassifier:
     learning_rate : float
         The step size used to update the model weights during training.
         It is a float and must satisfy learning_rate > 0.
+
+    algorithm : str (default='discrete')
+        The boosting algorithm to use. Only the discrete SAMME algorithm is supported in this implementation.
+        In scikit-learn, the real-valued variant (SAMME.R) is deprecated.
+
+    epsilon : float (default=1e-5)
+        A small positive value used to avoid division by zero and other numerical issues.
+        Must be greater than 0 and less than 0.1.
     """

     def __init__(
@@ -43,6 +51,7 @@ def __init__(
         n_estimators,
         learning_rate,
         algorithm,
+        epsilon=1e-5,
     ):
         assert isinstance(estimator, sml_dtc), "Estimator other than sml_dtc is not supported."
         assert (
@@ -54,19 +63,23 @@ def __init__(
            "You can refer to the official documentation for more details: "
            "https://github.com/scikit-learn/scikit-learn/issues/26784"
        )
-
+        assert (
+            0 < epsilon < 0.1
+        ), "epsilon must be > 0 and < 0.1."
+
        self.estimator = estimator
        self.n_estimators = n_estimators
        self.learning_rate = learning_rate
        self.algorithm = algorithm
+        self.epsilon = epsilon
        self.n_classes = estimator.n_labels
        self.estimators_ = []
-        self.estimator_weight = jnp.zeros(self.n_estimators, dtype=jnp.float32)
-        self.estimator_errors = jnp.ones(self.n_estimators, dtype=jnp.float32)
-        self.estimator_flags_ = []
-        # self.early_stop = False # 添加 early_stop 标志
+        self.estimator_weight_ = jnp.zeros(self.n_estimators, dtype=jnp.float32)
+        self.estimator_errors_ = jnp.ones(self.n_estimators, dtype=jnp.float32)
+        self.estimator_flags_ = jnp.zeros(self.n_estimators, dtype=jnp.bool_)
+        self.early_stop = False  # add early_stop flag

    def _num_samples(self, x):
        """Return the number of samples in x."""
@@ -133,22 +146,20 @@ def fit(self, X, y, sample_weight=None):
            sample_weight /= sample_weight.sum()

        self.classes = y
-
-        epsilon = jnp.finfo(sample_weight.dtype).eps
-
-        self.estimator_weight_ = jnp.zeros(self.n_estimators, dtype=jnp.float32)
-        self.estimator_errors_ = jnp.ones(self.n_estimators, dtype=jnp.float32)
+
+        epsilon = self.epsilon

        for iboost in range(self.n_estimators):
            sample_weight = jnp.clip(sample_weight, a_min=epsilon, a_max=None)
            estimator = copy.deepcopy(self.estimator)
-            sample_weight, estimator_weight, estimator_error = self._boost_discrete(
+            sample_weight, estimator_weight, estimator_error, flag = self._boost_discrete(
                iboost,
                X,
                y,
                sample_weight,
                estimator,
            )
            self.estimator_weight_ = self.estimator_weight_.at[iboost].set(estimator_weight)
            self.estimator_errors_ = self.estimator_errors_.at[iboost].set(estimator_error)
+            self.estimator_flags_ = self.estimator_flags_.at[iboost].set(flag)

            sample_weight_sum = jnp.sum(sample_weight)
            if iboost < self.n_estimators - 1:
@@ -161,6 +172,7 @@ def _boost_discrete(self, iboost, X, y, sample_weight, estimator):
        self.estimators_.append(estimator)

        n_classes = self.n_classes
+        epsilon = self.epsilon

        estimator.fit(X, y, sample_weight=sample_weight)
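Reviewer note: the switch from `jnp.finfo(dtype).eps` to a tunable `epsilon` is likely because machine epsilon is far below the precision SPU's fixed-point arithmetic can represent. For reference, the per-round clamping step in `fit` can be sketched in plain JAX; this is illustrative only, `clamp_and_normalize` is a hypothetical name, and a fixed `epsilon` stands in for `self.epsilon`:

import jax.numpy as jnp

def clamp_and_normalize(sample_weight, epsilon=1e-5):
    # Keep every weight strictly positive so the divisions and logs
    # performed later by the boosting step stay well-defined.
    sample_weight = jnp.clip(sample_weight, epsilon, None)
    return sample_weight / sample_weight.sum()

print(clamp_and_normalize(jnp.array([0.0, 0.25, 0.75])))  # no exact zeros survive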
@@ -169,55 +181,56 @@ def _boost_discrete(self, iboost, X, y, sample_weight, estimator):
        incorrect = y_predict != y
        estimator_error = jnp.mean(jnp.average(incorrect, weights=sample_weight, axis=0))

-        # 判断是否需要提前停止
-        # if estimator_error > 0.0:
-        #     self.early_stop = True
-        # self.early_stop = lax.cond(
-        #     estimator_error > 0.0,
-        #     lambda _: jnp.array(True, dtype=jnp.bool_),
-        #     lambda _: jnp.array(False, dtype=jnp.bool_),
-        #     operand=None
-        # )
+        self.early_stop = lax.cond(
+            estimator_error <= epsilon,
+            lambda _: jnp.array(True, dtype=jnp.bool_),
+            lambda _: self.early_stop,
+            operand=None,
+        )

        def true_0_fun(sample_weight):
            return sample_weight, 1.0, 0.0, jnp.array(False, dtype=jnp.bool_)

        def false_0_fun(sample_weight):
-            estimator_weight = self.learning_rate * (
-                jnp.log((1.0 - estimator_error) / estimator_error) + jnp.log(n_classes - 1.0)
+            flag = estimator_error < 1.0 - (1.0 / n_classes)
+            flag = lax.cond(
+                self.early_stop,
+                lambda _: jnp.array(False, dtype=jnp.bool_),
+                lambda _: flag,
+                operand=None,
            )

-            def not_last_iboost(sample_weight):
-                # Only boost positive weights
-                sample_weight *= jnp.exp(estimator_weight * incorrect)
-                return sample_weight
-
-            def last_iboost(sample_weight):
-                return sample_weight
-
-            sample_weight = lax.cond(iboost != self.n_estimators - 1,
-                                     not_last_iboost, last_iboost, sample_weight)
-
-            flag = estimator_error < 1.0 - (1.0 / n_classes)
-            # flag = lax.cond(
-            #     self.early_stop,
-            #     lambda _: jnp.array(False, dtype=jnp.bool_),
-            #     lambda _: flag,
-            #     operand=None
-            # )
+            # Update weights only if flag is True
+            def update_weights(params):
+                estimator_error, incorrect, sample_weight = params
+                estimator_weight = self.learning_rate * (
+                    jnp.log((1.0 - estimator_error) / estimator_error) + jnp.log(n_classes - 1.0)
+                )
+                sample_weight *= jnp.exp(estimator_weight * incorrect)
+                return sample_weight, estimator_weight
+
+            def skip_update(params):
+                estimator_error, incorrect, sample_weight = params
+                return sample_weight, 0.0  # return zero for estimator_weight
+
+            sample_weight, estimator_weight = lax.cond(
+                flag,
+                update_weights,
+                skip_update,
+                operand=(estimator_error, incorrect, sample_weight),
+            )

            return sample_weight, estimator_weight, estimator_error, flag

        sample_weight, estimator_weight, estimator_error, flag = lax.cond(
-            estimator_error <= 0.0, true_0_fun, false_0_fun, sample_weight
+            estimator_error <= epsilon, true_0_fun, false_0_fun, sample_weight
        )

-        self.estimator_flags_.append(flag)  # 维护 flag 属性
-
-        return sample_weight, estimator_weight, estimator_error
+        return sample_weight, estimator_weight, estimator_error, flag

    def predict(self, X):
        pred = self.decision_function(X)

        if self.n_classes == 2:
            return self.classes.take(pred > 0, axis=0)
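Reviewer note: the `lax.cond` flow is hard to follow in diff form. A self-contained sketch of one discrete SAMME round, structured the way `_boost_discrete` is; the names (`samme_round`, `err`, `alpha`) are hypothetical and this is not the patch's exact code:

import jax.numpy as jnp
from jax import lax

def samme_round(sample_weight, incorrect, n_classes, learning_rate=1.0, epsilon=1e-5):
    # incorrect: 1.0 where the estimator mispredicted, 0.0 elsewhere.
    err = jnp.average(incorrect, weights=sample_weight)
    # The round only counts if the estimator beats random guessing.
    flag = err < 1.0 - 1.0 / n_classes

    def update(sw):
        alpha = learning_rate * (
            jnp.log((1.0 - err) / jnp.maximum(err, epsilon)) + jnp.log(n_classes - 1.0)
        )
        return sw * jnp.exp(alpha * incorrect), alpha

    def skip(sw):
        return sw, jnp.zeros_like(err)

    # Both branches are traced, so no data-dependent Python branching leaks
    # under SPU, unlike sklearn's early return on a perfect fit.
    sw, alpha = lax.cond(flag, update, skip, sample_weight)
    return sw, alpha, err, flag

w = jnp.full(4, 0.25)
miss = jnp.array([1.0, 0.0, 0.0, 0.0])
print(samme_round(w, miss, n_classes=3))

This mirrors why the patch returns a per-round `flag` instead of breaking out of the loop: the number of fitted estimators must stay static for tracing.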
@@ -228,65 +242,21 @@ def decision_function(self, X):
        n_classes = self.n_classes
        classes = self.classes[:, jnp.newaxis]
-
-        # pred = sum(
-        #     jnp.where(
-        #         (estimator.predict(X) == classes).T,
-        #         w,
-        #         -1 / (n_classes - 1) * w,
-        #     )
-        #     for estimator, w in zip(self.estimators_, self.estimator_weight_)
-        # )
-        # pred /= self.estimator_weight_.sum()
+
        pred = sum(
            jnp.where(
                (estimator.predict(X) == classes).T,
                w,
                -1 / (n_classes - 1) * w,
-            ) * flag # 使用 flag
+            ) * flag
            for estimator, w, flag in zip(self.estimators_, self.estimator_weight_, self.estimator_flags_)
        )

-        # pred = sum(
-        #     jnp.where(
-        #         (estimator.predict(X) == classes).T,
-        #         w,
-        #         -1 / (n_classes - 1) * w,
-        #     ) * flag
-        #     for estimator, w, flag in zip(self.estimators_, self.estimator_weight_, self.estimator_flags_)
-        #     if not self.early_stop or flag # 使用 early_stop 进行过滤
-        # )
-
        # Convert the list to a JAX array and sum it
        weights_flags = jnp.array([w * flag for w, flag in zip(self.estimator_weight_, self.estimator_flags_)])
        pred /= jnp.sum(weights_flags)

-        # # 计算每个估计器的预测结果
-        # predictions = [
-        #     jnp.where(
-        #         (estimator.predict(X) == classes).T,
-        #         w,
-        #         -1 / (n_classes - 1) * w,
-        #     )
-        #     for estimator, w, flag in zip(self.estimators_, self.estimator_weight_, self.estimator_flags_)
-        # ]
-
-        # # 使用 lax.cond 处理 early_stop 逻辑
-        # def apply_flags(predictions, weights_flags):
-        #     return sum(p * f for p, f in zip(predictions, weights_flags))
-
-        # weights_flags = jnp.array([w * flag for w, flag in zip(self.estimator_weight_, self.estimator_flags_)])
-
-        # pred = lax.cond(
-        #     self.early_stop,
-        #     lambda _: apply_flags(predictions, jnp.array([0] * len(predictions))),
-        #     lambda _: apply_flags(predictions, weights_flags),
-        #     operand=None
-        # )
-
-        # pred /= jnp.sum(weights_flags)
-
        if n_classes == 2:
            pred[:, 0] *= -1
            return pred.sum(axis=1)
diff --git a/sml/ensemble/emulations/BUILD.bazel b/sml/ensemble/emulations/BUILD.bazel
index 16c1a939..a2d1da51 100644
--- a/sml/ensemble/emulations/BUILD.bazel
+++ b/sml/ensemble/emulations/BUILD.bazel
@@ -17,18 +17,19 @@ load("@rules_python//python:defs.bzl", "py_binary")

package(default_visibility = ["//visibility:public"])

py_binary(
-<<<<<<< HEAD
    name = "adaboost_emul",
    srcs = ["adaboost_emul.py"],
    deps = [
        "//sml/ensemble:adaboost",
+        "//sml/utils:emulation",
    ]
-=======
+)
+
+py_binary(
    name = "forest_emul",
    srcs = ["forest_emul.py"],
    deps = [
        "//sml/ensemble:forest",
->>>>>>> 854f3ef0925bc419ae804ae77a4a9843ec15db7f
        "//sml/utils:emulation",
    ],
)
diff --git a/sml/ensemble/emulations/adaboost_emul.py b/sml/ensemble/emulations/adaboost_emul.py
index e4ff9987..2424165b 100644
--- a/sml/ensemble/emulations/adaboost_emul.py
+++ b/sml/ensemble/emulations/adaboost_emul.py
@@ -19,6 +19,7 @@ from sklearn.tree import DecisionTreeClassifier

 import sml.utils.emulation as emulation
+from sml.tree.tree import DecisionTreeClassifier as sml_dtc
 from sml.ensemble.adaboost import AdaBoostClassifier as sml_Adaboost

 MAX_DEPTH = 3
@@ -26,18 +27,18 @@ def emul_ada(mode=emulation.Mode.MULTIPROCESS):
    def proc_wrapper(
-        estimator = "dtc",
-        n_estimators = 50,
-        max_depth = MAX_DEPTH,
-        learning_rate = 1.0,
-        n_classes = 3,
-    ):
+        estimator,
+        n_estimators,
+        learning_rate,
+        algorithm,
+        epsilon,
+    ):
        ada_custom = sml_Adaboost(
-            estimator = "dtc",
-            n_estimators = 50,
-            max_depth = MAX_DEPTH,
-            learning_rate = 1.0,
-            n_classes = 3,
+            estimator=estimator,
+            n_estimators=n_estimators,
+            learning_rate=learning_rate,
+            algorithm=algorithm,
+            epsilon=epsilon,
        )

        def proc(X, y):
@@ -86,12 +87,13 @@ def load_data():
    X_spu, y_spu = emulator.seal(X, y)

    # run
+    dtc = sml_dtc("gini", "best", 3, 3)
    proc = proc_wrapper(
-        estimator = "dtc",
+        estimator=dtc,
        n_estimators = 3,
-        max_depth = MAX_DEPTH,
        learning_rate = 1.0,
-        n_classes = 3,
+        algorithm="discrete",
+        epsilon=1e-5,
    )

    start = time.time()
diff --git a/sml/ensemble/tests/BUILD.bazel b/sml/ensemble/tests/BUILD.bazel
index a0b7756b..6815cf85 100644
--- a/sml/ensemble/tests/BUILD.bazel
+++ b/sml/ensemble/tests/BUILD.bazel
@@ -24,6 +24,9 @@ py_test(
        "//spu:init",
        "//spu/utils:simulation",
    ],
+)
+
+py_test(
    name = "forest_test",
    srcs = ["forest_test.py"],
    deps = [
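Reviewer note on `decision_function`: multiplying each vote by `flag` keeps every estimator in the computation and zeroes out disabled ones, rather than filtering the Python list, so array shapes stay static under tracing. A minimal sketch of the same masking idea (`masked_vote` and the precomputed score tensor are hypothetical):

import jax.numpy as jnp

def masked_vote(scores, weights, flags):
    # scores: (n_estimators, n_samples, n_classes) votes per estimator.
    # Disabled estimators contribute zero weight instead of being removed.
    w = weights * flags
    pred = jnp.einsum("e,esc->sc", w, scores)
    return pred / jnp.sum(w)

scores = jnp.ones((3, 5, 4))
print(masked_vote(scores, jnp.array([0.5, 1.0, 0.2]), jnp.array([1.0, 1.0, 0.0])).shape)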
diff --git a/sml/ensemble/tests/adaboost_test.py b/sml/ensemble/tests/adaboost_test.py
index 2692d6a8..21ee7347 100644
--- a/sml/ensemble/tests/adaboost_test.py
+++ b/sml/ensemble/tests/adaboost_test.py
@@ -33,12 +33,14 @@ def proc_wrapper(
        n_estimators,
        learning_rate,
        algorithm,
+        epsilon,
    ):
        ada_custom = sml_Adaboost(
            estimator = estimator,
            n_estimators = n_estimators,
            learning_rate = learning_rate,
            algorithm=algorithm,
+            epsilon=epsilon,
        )

        def proc(X, y):
@@ -86,6 +88,7 @@ def load_data():
        n_estimators = 3,
        learning_rate = 1.0,
        algorithm="discrete",
+        epsilon=1e-5,
    )

    result = spsim.sim_jax(sim, proc)(X, y)
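Reviewer note: the test harness presumably checks the SPU result against plaintext scikit-learn, since the files import sklearn's DecisionTreeClassifier. A rough sketch of such a baseline; the dataset, depth, and scoring here are illustrative assumptions, not the test's exact setup:

from sklearn.datasets import load_iris
from sklearn.ensemble import AdaBoostClassifier
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
base = DecisionTreeClassifier(max_depth=3)
clf = AdaBoostClassifier(base, n_estimators=3, learning_rate=1.0, algorithm="SAMME")
print(clf.fit(X, y).score(X, y))  # plaintext accuracy to compare against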