Skip to content

Commit

Permalink
resolve problems
Browse files Browse the repository at this point in the history
  • Loading branch information
xbw886 committed Aug 3, 2024
1 parent 597a0b6 commit 7dda9f2
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 106 deletions.
4 changes: 3 additions & 1 deletion sml/ensemble/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ py_library(
deps = [
"//sml/tree:tree",
]

)

py_library(
name = "forest",
srcs = ["forest.py"],
deps = [
Expand Down
150 changes: 62 additions & 88 deletions sml/ensemble/adaboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,14 @@ class AdaBoostClassifier:
learning_rate : float
The step size used to update the model weights during training.
It's a float; learning_rate must be > 0.
algorithm : str (default='discrete')
The boosting algorithm to use. Only the discrete SAMME algorithm is supported by this implementation.
scikit-learn is deprecating the real boosting algorithm (SAMME.R).
epsilon : float (default=1e-5)
A small positive value used in calculations to avoid division by zero and other numerical issues.
Must be greater than 0 and less than 0.1.
"""
def __init__(
Expand All @@ -43,6 +51,7 @@ def __init__(
n_estimators,
learning_rate,
algorithm,
epsilon = 1e-5,
):
assert isinstance(estimator, sml_dtc), "Estimator other than sml_dtc is not supported."
assert (
Expand All @@ -54,19 +63,23 @@ def __init__(
"You can refer to the official documentation for more details: "
"https://github.com/scikit-learn/scikit-learn/issues/26784"
)

assert (
epsilon > 0 and epsilon < 0.1
), "epsilon must be > 0 and < 0.1."

self.estimator = estimator
self.n_estimators = n_estimators
self.learning_rate = learning_rate
self.algorithm = algorithm
self.epsilon = epsilon

self.n_classes = estimator.n_labels

self.estimators_ = []
self.estimator_weight = jnp.zeros(self.n_estimators, dtype=jnp.float32)
self.estimator_errors = jnp.ones(self.n_estimators, dtype=jnp.float32)
self.estimator_flags_ = []
# self.early_stop = False # 添加 early_stop 标志
self.estimator_weight_ = jnp.zeros(self.n_estimators, dtype=jnp.float32)
self.estimator_errors_ = jnp.ones(self.n_estimators, dtype=jnp.float32)
self.estimator_flags_ = jnp.zeros(self.n_estimators, dtype=jnp.bool_)
self.early_stop = False # 添加 early_stop 标志

def _num_samples(self, x):
"""返回x中的样本数量."""
Expand Down Expand Up @@ -133,22 +146,20 @@ def fit(self, X, y, sample_weight=None):
sample_weight /= sample_weight.sum()

self.classes = y

epsilon = jnp.finfo(sample_weight.dtype).eps

self.estimator_weight_ = jnp.zeros(self.n_estimators, dtype=jnp.float32)
self.estimator_errors_ = jnp.ones(self.n_estimators, dtype=jnp.float32)

epsilon = self.epsilon

for iboost in range(self.n_estimators):
sample_weight = jnp.clip(sample_weight, a_min=epsilon, a_max=None)

estimator = copy.deepcopy(self.estimator)
sample_weight, estimator_weight, estimator_error = self._boost_discrete(
sample_weight, estimator_weight, estimator_error, flag = self._boost_discrete(
iboost, X, y, sample_weight, estimator,
)

self.estimator_weight_ = self.estimator_weight_.at[iboost].set(estimator_weight)
self.estimator_errors_ = self.estimator_errors_.at[iboost].set(estimator_error)
self.estimator_flags_ = self.estimator_flags_.at[iboost].set(flag)

sample_weight_sum = jnp.sum(sample_weight)
if iboost < self.n_estimators - 1:
Expand All @@ -161,6 +172,7 @@ def _boost_discrete(self, iboost, X, y, sample_weight, estimator):
self.estimators_.append(estimator)

n_classes = self.n_classes
epsilon = self.epsilon

estimator.fit(X, y, sample_weight=sample_weight)

Expand All @@ -169,55 +181,57 @@ def _boost_discrete(self, iboost, X, y, sample_weight, estimator):
incorrect = y_predict != y
estimator_error = jnp.mean(jnp.average(incorrect, weights=sample_weight, axis=0))

# Decide whether boosting needs to stop early
# if estimator_error > 0.0:
# self.early_stop = True
# self.early_stop = lax.cond(
# estimator_error > 0.0,
# lambda _: jnp.array(True, dtype=jnp.bool_),
# lambda _: jnp.array(False, dtype=jnp.bool_),
# operand=None
# )
self.early_stop = lax.cond(
estimator_error <= epsilon,
lambda _: jnp.array(True, dtype=jnp.bool_),
lambda _: self.early_stop,
operand=None
)

def true_0_fun(sample_weight):
return sample_weight, 1.0, 0.0, jnp.array(False, dtype=jnp.bool_)

def false_0_fun(sample_weight):
estimator_weight = self.learning_rate * (
jnp.log((1.0 - estimator_error) / estimator_error) + jnp.log(n_classes - 1.0)
flag = estimator_error < 1.0 - (1.0 / n_classes)
flag = lax.cond(
self.early_stop,
lambda _: jnp.array(False, dtype=jnp.bool_),
lambda _: flag,
operand=None
)
def not_last_iboost(sample_weight):
# Only boost positive weights
sample_weight *= jnp.exp(estimator_weight * incorrect)
return sample_weight

def last_iboost(sample_weight):
return sample_weight

sample_weight = lax.cond(iboost != self.n_estimators - 1,
not_last_iboost, last_iboost, sample_weight)
# Update weights only if flag is True
def update_weights(params):
estimator_error, incorrect, sample_weight = params
estimator_weight = self.learning_rate * (
jnp.log((1.0 - estimator_error) / estimator_error) + jnp.log(n_classes - 1.0)
)
sample_weight *= jnp.exp(estimator_weight * incorrect)
return sample_weight, estimator_weight

flag = estimator_error < 1.0 - (1.0 / n_classes)
# flag = lax.cond(
# self.early_stop,
# lambda _: jnp.array(False, dtype=jnp.bool_),
# lambda _: flag,
# operand=None
# )
def skip_update(params):
estimator_error, incorrect, sample_weight = params
return sample_weight, 0.0 # Return zero for estimator_weight

sample_weight, estimator_weight = lax.cond(
flag,
update_weights,
skip_update,
operand=(estimator_error, incorrect, sample_weight)
)

return sample_weight, estimator_weight, estimator_error, flag

sample_weight, estimator_weight, estimator_error, flag = lax.cond(
estimator_error <= 0.0, true_0_fun, false_0_fun, sample_weight
estimator_error <= epsilon, true_0_fun, false_0_fun, sample_weight
)

self.estimator_flags_.append(flag) # 维护 flag 属性

return sample_weight, estimator_weight, estimator_error
return sample_weight, estimator_weight, estimator_error, flag


def predict(self, X):
pred = self.decision_function(X)
print(self.early_stop)

if self.n_classes == 2:
return self.classes.take(pred > 0, axis=0)
Expand All @@ -228,65 +242,25 @@ def predict(self, X):
def decision_function(self, X):
n_classes = self.n_classes
classes = self.classes[:, jnp.newaxis]

# pred = sum(
# jnp.where(
# (estimator.predict(X) == classes).T,
# w,
# -1 / (n_classes - 1) * w,
# )
# for estimator, w in zip(self.estimators_, self.estimator_weight_)
# )
# pred /= self.estimator_weight_.sum()

print(self.estimators_)
print(self.estimator_weight_)
print('--------')
print(self.estimator_flags_)

pred = sum(
jnp.where(
(estimator.predict(X) == classes).T,
w,
-1 / (n_classes - 1) * w,
) * flag # 使用 flag
) * flag
for estimator, w, flag in zip(self.estimators_, self.estimator_weight_, self.estimator_flags_)
)

# pred = sum(
# jnp.where(
# (estimator.predict(X) == classes).T,
# w,
# -1 / (n_classes - 1) * w,
# ) * flag
# for estimator, w, flag in zip(self.estimators_, self.estimator_weight_, self.estimator_flags_)
# if not self.early_stop or flag # 使用 early_stop 进行过滤
# )

# Convert the list to a JAX array and sum it
weights_flags = jnp.array([w * flag for w, flag in zip(self.estimator_weight_, self.estimator_flags_)])
pred /= jnp.sum(weights_flags)

# # 计算每个估计器的预测结果
# predictions = [
# jnp.where(
# (estimator.predict(X) == classes).T,
# w,
# -1 / (n_classes - 1) * w,
# )
# for estimator, w, flag in zip(self.estimators_, self.estimator_weight_, self.estimator_flags_)
# ]

# # 使用 lax.cond 处理 early_stop 逻辑
# def apply_flags(predictions, weights_flags):
# return sum(p * f for p, f in zip(predictions, weights_flags))

# weights_flags = jnp.array([w * flag for w, flag in zip(self.estimator_weight_, self.estimator_flags_)])

# pred = lax.cond(
# self.early_stop,
# lambda _: apply_flags(predictions, jnp.array([0] * len(predictions))),
# lambda _: apply_flags(predictions, weights_flags),
# operand=None
# )

# pred /= jnp.sum(weights_flags)

if n_classes == 2:
pred[:, 0] *= -1
return pred.sum(axis=1)
Expand Down
7 changes: 4 additions & 3 deletions sml/ensemble/emulations/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,19 @@ load("@rules_python//python:defs.bzl", "py_binary")
package(default_visibility = ["//visibility:public"])

py_binary(
<<<<<<< HEAD
name = "adaboost_emul",
srcs = ["adaboost_emul.py"],
deps = [
"//sml/ensemble:adaboost",
"//sml/utils:emulation",
]
=======
)

py_binary(
name = "forest_emul",
srcs = ["forest_emul.py"],
deps = [
"//sml/ensemble:forest",
>>>>>>> 854f3ef0925bc419ae804ae77a4a9843ec15db7f
"//sml/utils:emulation",
],
)
30 changes: 16 additions & 14 deletions sml/ensemble/emulations/adaboost_emul.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,25 +19,26 @@
from sklearn.tree import DecisionTreeClassifier

import sml.utils.emulation as emulation
from sml.tree.tree import DecisionTreeClassifier as sml_dtc
from sml.ensemble.adaboost import AdaBoostClassifier as sml_Adaboost

MAX_DEPTH = 3
CONFIG_FILE = emulation.CLUSTER_ABY3_3PC

def emul_ada(mode=emulation.Mode.MULTIPROCESS):
def proc_wrapper(
estimator = "dtc",
n_estimators = 50,
max_depth = MAX_DEPTH,
learning_rate = 1.0,
n_classes = 3,
):
estimator,
n_estimators,
learning_rate,
algorithm,
epsilon,
):
ada_custom = sml_Adaboost(
estimator = "dtc",
n_estimators = 50,
max_depth = MAX_DEPTH,
learning_rate = 1.0,
n_classes = 3,
estimator = estimator,
n_estimators = n_estimators,
learning_rate = learning_rate,
algorithm=algorithm,
epsilon=epsilon,
)

def proc(X, y):
Expand Down Expand Up @@ -86,12 +87,13 @@ def load_data():
X_spu, y_spu = emulator.seal(X, y)

# run
dtc = sml_dtc("gini", "best", 3, 3)
proc = proc_wrapper(
estimator = "dtc",
estimator = dtc,
n_estimators = 3,
max_depth = MAX_DEPTH,
learning_rate = 1.0,
n_classes = 3,
algorithm="discrete",
epsilon = 1e-5,
)
start = time.time()

Expand Down
3 changes: 3 additions & 0 deletions sml/ensemble/tests/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ py_test(
"//spu:init",
"//spu/utils:simulation",
],
)

py_test(
name = "forest_test",
srcs = ["forest_test.py"],
deps = [
Expand Down
3 changes: 3 additions & 0 deletions sml/ensemble/tests/adaboost_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,14 @@ def proc_wrapper(
n_estimators,
learning_rate,
algorithm,
epsilon,
):
ada_custom = sml_Adaboost(
estimator = estimator,
n_estimators = n_estimators,
learning_rate = learning_rate,
algorithm=algorithm,
epsilon=epsilon,
)

def proc(X, y):
Expand Down Expand Up @@ -86,6 +88,7 @@ def load_data():
n_estimators = 3,
learning_rate = 1.0,
algorithm="discrete",
epsilon = 1e-5,
)

result = spsim.sim_jax(sim, proc)(X, y)
Expand Down

0 comments on commit 7dda9f2

Please sign in to comment.