Skip to content

Commit

Permalink
ada
Browse files Browse the repository at this point in the history
  • Loading branch information
xbw886 committed Aug 3, 2024
1 parent 37faf39 commit 51de571
Showing 1 changed file with 13 additions and 13 deletions.
26 changes: 13 additions & 13 deletions sml/ensemble/adaboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __init__(
self.estimator_weight = jnp.zeros(self.n_estimators, dtype=jnp.float32)
self.estimator_errors = jnp.ones(self.n_estimators, dtype=jnp.float32)
self.estimator_flags_ = []
self.early_stop = False # 添加 early_stop 标志
# self.early_stop = False # 添加 early_stop 标志

def _num_samples(self, x):
"""返回x中的样本数量."""
Expand Down Expand Up @@ -172,12 +172,12 @@ def _boost_discrete(self, iboost, X, y, sample_weight, estimator):
# 判断是否需要提前停止
# if estimator_error > 0.0:
# self.early_stop = True
self.early_stop = lax.cond(
estimator_error > 0.0,
lambda _: jnp.array(True, dtype=jnp.bool_),
lambda _: jnp.array(False, dtype=jnp.bool_),
operand=None
)
# self.early_stop = lax.cond(
# estimator_error > 0.0,
# lambda _: jnp.array(True, dtype=jnp.bool_),
# lambda _: jnp.array(False, dtype=jnp.bool_),
# operand=None
# )

def true_0_fun(sample_weight):
return sample_weight, 1.0, 0.0, jnp.array(False, dtype=jnp.bool_)
Expand All @@ -198,12 +198,12 @@ def last_iboost(sample_weight):
not_last_iboost, last_iboost, sample_weight)

flag = estimator_error < 1.0 - (1.0 / n_classes)
flag = lax.cond(
self.early_stop,
lambda _: jnp.array(False, dtype=jnp.bool_),
lambda _: flag,
operand=None
)
# flag = lax.cond(
# self.early_stop,
# lambda _: jnp.array(False, dtype=jnp.bool_),
# lambda _: flag,
# operand=None
# )

return sample_weight, estimator_weight, estimator_error, flag

Expand Down

0 comments on commit 51de571

Please sign in to comment.