diff --git a/sml/ensemble/BUILD.bazel b/sml/ensemble/BUILD.bazel index 687dd03a..50bfc8c4 100644 --- a/sml/ensemble/BUILD.bazel +++ b/sml/ensemble/BUILD.bazel @@ -1,4 +1,4 @@ -# Copyright 2023 Ant Group Co., Ltd. +# Copyright 2024 Ant Group Co., Ltd. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,4 +19,7 @@ package(default_visibility = ["//visibility:public"]) py_library( name = "adaboost", srcs = ["adaboost.py"], + deps = [ + "//sml/tree:tree", + ], ) diff --git a/sml/ensemble/adaboost.py b/sml/ensemble/adaboost.py index 9b2c9abd..f6d1c77c 100644 --- a/sml/ensemble/adaboost.py +++ b/sml/ensemble/adaboost.py @@ -1,4 +1,4 @@ -# Copyright 2023 Ant Group Co., Ltd. +# Copyright 2024 Ant Group Co., Ltd. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +# 不支持early_stop + +import copy import jax.numpy as jnp from jax import lax import warnings @@ -29,43 +32,41 @@ class AdaBoostClassifier: n_estimators : int The number of estimators. Must specify an integer > 0. - max_depth : int - The maximum depth of the tree. Must specify an integer > 0. - learning_rate : float The step size used to update the model weights during training. It's an float, must learning_rate > 0. - n_classes: int - The max number of classes. - """ def __init__( self, estimator, - # 默认estimator为决策树,criterion == "gini" splitter == "best" n_estimators, - max_depth, learning_rate, - n_classes, + algorithm, ): - assert estimator == "dtc", "estimator other than dtc is not supported." + assert isinstance(estimator, sml_dtc), "Estimator other than sml_dtc is not supported." assert ( n_estimators is not None and n_estimators > 0 ), "n_estimators should not be None and must > 0." - assert( - max_depth is not None and max_depth > 0 - ), "max_depth should not be None and must > 0." - + assert algorithm == "discrete", ( + "Only support SAMME discrete algorithm. " + "In scikit-learn, the Real Boosting Algorithm (SAMME.R) will be deprecated. " + "You can refer to the official documentation for more details: " + "https://github.com/scikit-learn/scikit-learn/issues/26784" + ) + self.estimator = estimator self.n_estimators = n_estimators - self.max_depth = max_depth self.learning_rate = learning_rate - self.n_classes = n_classes + self.algorithm = algorithm + + self.n_classes = estimator.n_labels self.estimators_ = [] self.estimator_weight = jnp.zeros(self.n_estimators, dtype=jnp.float32) self.estimator_errors = jnp.ones(self.n_estimators, dtype=jnp.float32) + self.estimator_flags_ = [] + self.early_stop = False # 添加 early_stop 标志 def _num_samples(self, x): """返回x中的样本数量.""" @@ -82,23 +83,37 @@ def _num_samples(self, x): else: return len(x) - def _check_sample_weight(self, sample_weight, X, dtype=None, copy=False, only_non_negative=False): + def _check_sample_weight(self, sample_weight, X): ''' - description: 验证样本权重. - return {*} - ''' - # jax默认只支持float32, - # 如果需要启用 float64 类型,可以设置 jax_enable_x64 配置选项或 JAX_ENABLE_X64 环境变量。 + Description: Validate and process sample weights. + + Parameters: + - sample_weight: Can be None, a scalar (int or float), or a 1D array-like. + - X: Input data from which to determine the number of samples. + + Returns: + - sample_weight: A 1D array of sample weights, one for each sample in X. + + Sample weight scenarios: + 1. None: + - If sample_weight is None, it will be initialized to an array of ones, + meaning all samples are equally weighted. + 2. Scalar (int or float): + - If sample_weight is a scalar, it will be converted to an array where + each sample's weight is equal to the scalar value. + 3. Array-like: + - If sample_weight is an array or array-like, it will be converted to a JAX array. + - The array must be 1D and its length must match the number of samples. + - If these conditions are not met, an error will be raised. + ''' n_samples = self._num_samples(X) - if dtype is not None and dtype not in [jnp.float32, jnp.float64]: - dtype = jnp.float32 if sample_weight is None: - sample_weight = jnp.ones(n_samples, dtype=dtype) - elif isinstance(sample_weight, numbers.Number): - sample_weight = jnp.full(n_samples, sample_weight, dtype=dtype) + sample_weight = jnp.ones(n_samples, dtype=jnp.float32) + elif isinstance(sample_weight, (jnp.int32, jnp.float32)): + sample_weight = jnp.full(n_samples, sample_weight, dtype=jnp.float32) else: - sample_weight = jnp.asarray(sample_weight, dtype=dtype) + sample_weight = jnp.asarray(sample_weight, dtype=jnp.float32) if sample_weight.ndim != 1: raise ValueError("Sample weight must be 1D array or scalar") @@ -109,59 +124,42 @@ def _check_sample_weight(self, sample_weight, X, dtype=None, copy=False, only_no ) ) - if copy: - sample_weight = jnp.copy(sample_weight) - return sample_weight - def cond_fun(self, iboost, sample_weight, estimator_weight, estimator_error): - status1 = jnp.logical_and(iboost < self.n_estimators, jnp.all(jnp.isfinite(sample_weight))) - status2 = jnp.logical_and(estimator_error > 0, jnp.sum(sample_weight) > 0) - status = jnp.logical_and(status1, status2) - return status - - def fit(self, X, y, sample_weight=None): sample_weight = self._check_sample_weight( - sample_weight, X, copy=True, only_non_negative=True + sample_weight, X, ) sample_weight /= sample_weight.sum() self.classes = y - epsilon = jnp.finfo(sample_weight.dtype).eps - + self.estimator_weight_ = jnp.zeros(self.n_estimators, dtype=jnp.float32) self.estimator_errors_ = jnp.ones(self.n_estimators, dtype=jnp.float32) for iboost in range(self.n_estimators): sample_weight = jnp.clip(sample_weight, a_min=epsilon, a_max=None) - + + estimator = copy.deepcopy(self.estimator) sample_weight, estimator_weight, estimator_error = self._boost_discrete( - iboost, X, y, sample_weight + iboost, X, y, sample_weight, estimator, ) self.estimator_weight_ = self.estimator_weight_.at[iboost].set(estimator_weight) self.estimator_errors_ = self.estimator_errors_.at[iboost].set(estimator_error) sample_weight_sum = jnp.sum(sample_weight) - def not_last_iboost(sample_weight, sample_weight_sum): + if iboost < self.n_estimators - 1: sample_weight /= sample_weight_sum - return sample_weight - def last_iboost(sample_weight, sample_weight_sum): - return sample_weight - sample_weight = lax.cond(iboost 0.0: + # self.early_stop = True + self.early_stop = lax.cond( + estimator_error > 0.0, + lambda _: jnp.array(True, dtype=jnp.bool_), + lambda _: jnp.array(False, dtype=jnp.bool_), + operand=None + ) + def true_0_fun(sample_weight): - return sample_weight, 1.0, 0.0 + return sample_weight, 1.0, 0.0, jnp.array(False, dtype=jnp.bool_) def false_0_fun(sample_weight): estimator_weight = self.learning_rate * ( @@ -180,10 +188,7 @@ def false_0_fun(sample_weight): ) def not_last_iboost(sample_weight): # Only boost positive weights - sample_weight = jnp.exp( - jnp.log(sample_weight) - + estimator_weight * incorrect * (sample_weight > 0) - ) + sample_weight *= jnp.exp(estimator_weight * incorrect) return sample_weight def last_iboost(sample_weight): @@ -192,13 +197,22 @@ def last_iboost(sample_weight): sample_weight = lax.cond(iboost != self.n_estimators - 1, not_last_iboost, last_iboost, sample_weight) + flag = estimator_error < 1.0 - (1.0 / n_classes) + flag = lax.cond( + self.early_stop, + lambda _: jnp.array(False, dtype=jnp.bool_), + lambda _: flag, + operand=None + ) + + return sample_weight, estimator_weight, estimator_error, flag - return sample_weight, estimator_weight, estimator_error - - sample_weight, estimator_weight, estimator_error = lax.cond( + sample_weight, estimator_weight, estimator_error, flag = lax.cond( estimator_error <= 0.0, true_0_fun, false_0_fun, sample_weight ) + self.estimator_flags_.append(flag) # 维护 flag 属性 + return sample_weight, estimator_weight, estimator_error @@ -215,15 +229,63 @@ def decision_function(self, X): n_classes = self.n_classes classes = self.classes[:, jnp.newaxis] + # pred = sum( + # jnp.where( + # (estimator.predict(X) == classes).T, + # w, + # -1 / (n_classes - 1) * w, + # ) + # for estimator, w in zip(self.estimators_, self.estimator_weight_) + # ) + # pred /= self.estimator_weight_.sum() + pred = sum( jnp.where( (estimator.predict(X) == classes).T, w, -1 / (n_classes - 1) * w, - ) - for estimator, w in zip(self.estimators_, self.estimator_weight_) - ) - pred /= self.estimator_weight_.sum() + ) * flag # 使用 flag + for estimator, w, flag in zip(self.estimators_, self.estimator_weight_, self.estimator_flags_) + ) + + # pred = sum( + # jnp.where( + # (estimator.predict(X) == classes).T, + # w, + # -1 / (n_classes - 1) * w, + # ) * flag + # for estimator, w, flag in zip(self.estimators_, self.estimator_weight_, self.estimator_flags_) + # if not self.early_stop or flag # 使用 early_stop 进行过滤 + # ) + + # 将列表转换为 JAX 数组,并进行求和 + weights_flags = jnp.array([w * flag for w, flag in zip(self.estimator_weight_, self.estimator_flags_)]) + pred /= jnp.sum(weights_flags) + + # # 计算每个估计器的预测结果 + # predictions = [ + # jnp.where( + # (estimator.predict(X) == classes).T, + # w, + # -1 / (n_classes - 1) * w, + # ) + # for estimator, w, flag in zip(self.estimators_, self.estimator_weight_, self.estimator_flags_) + # ] + + # # 使用 lax.cond 处理 early_stop 逻辑 + # def apply_flags(predictions, weights_flags): + # return sum(p * f for p, f in zip(predictions, weights_flags)) + + # weights_flags = jnp.array([w * flag for w, flag in zip(self.estimator_weight_, self.estimator_flags_)]) + + # pred = lax.cond( + # self.early_stop, + # lambda _: apply_flags(predictions, jnp.array([0] * len(predictions))), + # lambda _: apply_flags(predictions, weights_flags), + # operand=None + # ) + + # pred /= jnp.sum(weights_flags) if n_classes == 2: pred[:, 0] *= -1 diff --git a/sml/ensemble/emulations/BUILD.bazel b/sml/ensemble/emulations/BUILD.bazel index a45634e2..17ef2d53 100644 --- a/sml/ensemble/emulations/BUILD.bazel +++ b/sml/ensemble/emulations/BUILD.bazel @@ -1,4 +1,4 @@ -# Copyright 2023 Ant Group Co., Ltd. +# Copyright 2024 Ant Group Co., Ltd. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,7 +21,6 @@ py_binary( srcs = ["adaboost_emul.py"], deps = [ "//sml/ensemble:adaboost", - "//sml/tree:tree", "//sml/utils:emulation", ], ) diff --git a/sml/ensemble/emulations/adaboost_emul.py b/sml/ensemble/emulations/adaboost_emul.py index 75dbba89..e4ff9987 100644 --- a/sml/ensemble/emulations/adaboost_emul.py +++ b/sml/ensemble/emulations/adaboost_emul.py @@ -1,4 +1,4 @@ -# Copyright 2023 Ant Group Co., Ltd. +# Copyright 2024 Ant Group Co., Ltd. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/sml/ensemble/tests/BUILD.bazel b/sml/ensemble/tests/BUILD.bazel index d8637da2..b0d1d676 100644 --- a/sml/ensemble/tests/BUILD.bazel +++ b/sml/ensemble/tests/BUILD.bazel @@ -1,4 +1,4 @@ -# Copyright 2023 Ant Group Co., Ltd. +# Copyright 2024 Ant Group Co., Ltd. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,23 +16,11 @@ load("@rules_python//python:defs.bzl", "py_test") package(default_visibility = ["//visibility:public"]) -py_test( - name = "forest_test", - srcs = ["forest_test.py"], - deps = [ - "//sml/ensemble:forest", - "//sml/tree:tree", - "//spu:init", - "//spu/utils:simulation", - ], -) - py_test( name = "adaboost_test", srcs = ["adaboost_test.py"], deps = [ "//sml/ensemble:adaboost", - "//sml/tree:tree", "//spu:init", "//spu/utils:simulation", ], diff --git a/sml/ensemble/tests/adaboost_test.py b/sml/ensemble/tests/adaboost_test.py index f0b73b70..2692d6a8 100644 --- a/sml/ensemble/tests/adaboost_test.py +++ b/sml/ensemble/tests/adaboost_test.py @@ -1,4 +1,4 @@ -# Copyright 2023 Ant Group Co., Ltd. +# Copyright 2024 Ant Group Co., Ltd. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,6 +21,7 @@ import spu.spu_pb2 as spu_pb2 # type: ignore import spu.utils.simulation as spsim +from sml.tree.tree import DecisionTreeClassifier as sml_dtc from sml.ensemble.adaboost import AdaBoostClassifier as sml_Adaboost MAX_DEPTH = 3 @@ -28,18 +29,16 @@ class UnitTests(unittest.TestCase): def test_Ada(self): def proc_wrapper( - estimator = "dtc", - n_estimators = 10, - max_depth = MAX_DEPTH, - learning_rate = 1.0, - n_classes = 3, - ): + estimator, + n_estimators, + learning_rate, + algorithm, + ): ada_custom = sml_Adaboost( - estimator = "dtc", - n_estimators = 10, - max_depth = MAX_DEPTH, - learning_rate = 1.0, - n_classes = 3, + estimator = estimator, + n_estimators = n_estimators, + learning_rate = learning_rate, + algorithm=algorithm, ) def proc(X, y): @@ -81,12 +80,12 @@ def load_data(): score_plain = ada.score(X, y) #run + dtc = sml_dtc("gini", "best", 3, 3) proc = proc_wrapper( - estimator = "dtc", + estimator = dtc, n_estimators = 3, - max_depth = 3, learning_rate = 1.0, - n_classes = 3, + algorithm="discrete", ) result = spsim.sim_jax(sim, proc)(X, y) diff --git a/sml/tree/tree.py b/sml/tree/tree.py index c055ca43..233c38fd 100644 --- a/sml/tree/tree.py +++ b/sml/tree/tree.py @@ -52,7 +52,6 @@ def __init__(self, criterion, splitter, max_depth, n_labels): ), "max_depth should not be None and must > 0." self.max_depth = max_depth self.n_labels = n_labels - # self.sample_weight = sample_weight def fit(self, X, y, sample_weight=None): self.T, self.F = odtt(X, y, self.max_depth, self.n_labels, sample_weight)