From 51de571875e78913531dabe939e10e4fcd89ddc9 Mon Sep 17 00:00:00 2001 From: xbw886 <1042740841@qq.com> Date: Sat, 3 Aug 2024 10:40:28 +0800 Subject: [PATCH] ada --- sml/ensemble/adaboost.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/sml/ensemble/adaboost.py b/sml/ensemble/adaboost.py index f6d1c77c..9dea1a56 100644 --- a/sml/ensemble/adaboost.py +++ b/sml/ensemble/adaboost.py @@ -66,7 +66,7 @@ def __init__( self.estimator_weight = jnp.zeros(self.n_estimators, dtype=jnp.float32) self.estimator_errors = jnp.ones(self.n_estimators, dtype=jnp.float32) self.estimator_flags_ = [] - self.early_stop = False # 添加 early_stop 标志 + # self.early_stop = False # 添加 early_stop 标志 def _num_samples(self, x): """返回x中的样本数量.""" @@ -172,12 +172,12 @@ def _boost_discrete(self, iboost, X, y, sample_weight, estimator): # 判断是否需要提前停止 # if estimator_error > 0.0: # self.early_stop = True - self.early_stop = lax.cond( - estimator_error > 0.0, - lambda _: jnp.array(True, dtype=jnp.bool_), - lambda _: jnp.array(False, dtype=jnp.bool_), - operand=None - ) + # self.early_stop = lax.cond( + # estimator_error > 0.0, + # lambda _: jnp.array(True, dtype=jnp.bool_), + # lambda _: jnp.array(False, dtype=jnp.bool_), + # operand=None + # ) def true_0_fun(sample_weight): return sample_weight, 1.0, 0.0, jnp.array(False, dtype=jnp.bool_) @@ -198,12 +198,12 @@ def last_iboost(sample_weight): not_last_iboost, last_iboost, sample_weight) flag = estimator_error < 1.0 - (1.0 / n_classes) - flag = lax.cond( - self.early_stop, - lambda _: jnp.array(False, dtype=jnp.bool_), - lambda _: flag, - operand=None - ) + # flag = lax.cond( + # self.early_stop, + # lambda _: jnp.array(False, dtype=jnp.bool_), + # lambda _: flag, + # operand=None + # ) return sample_weight, estimator_weight, estimator_error, flag