Skip to content

Commit

Permalink
fixed rf
Browse files Browse the repository at this point in the history
  • Loading branch information
xbw886 committed Jul 29, 2024
1 parent 62e505d commit d9bd661
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 344 deletions.
27 changes: 5 additions & 22 deletions sml/ensemble/forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import jax
import math
import random
random.seed(20)

from sml.tree.tree import DecisionTreeClassifier as sml_dtc

Expand Down Expand Up @@ -133,20 +134,20 @@ def fit(self, X, y):
assert (
0 < self.max_features <= self.n_features
), "0 < max_features <= n_features when it's an integer"
# self.max_features = jnp.array(self.max_features, dtype=int)

elif isinstance(self.max_features, float):
assert (
0 < self.max_features <= 1
), "max_features should be in the range (0, 1] when it's a float"
self.max_features = (int)(self.max_features * self.n_features)
# self.max_features = jnp.array((self.max_features * n_features), dtype=int)

elif isinstance(self.max_features, str):
if self.max_features == 'sqrt':
self.max_features = (int)(math.sqrt(self.n_features))
# self.max_features = jnp.array(jnp.sqrt(n_features), dtype=int)

elif self.max_features == 'log2':
self.max_features = (int)(math.log2(self.n_features))
# self.max_features = jnp.array(jnp.log2(n_features), dtype=int)

else:
self.max_features = self.n_features
else:
Expand Down Expand Up @@ -175,20 +176,6 @@ def predict(self, X):

tree_predictions = jnp.array(predictions_list).T


# Currently, jit-compiling single_tree_predict treats features_indices as a tracer, which raises an error
# features_indices = jnp.array(self.features_indices)
# features_indices = self.features_indices
# print(features_indices[3])
# # Define a function that predicts using a single tree and the corresponding features
# def single_tree_predict(i, X):
# features = self.features_indices[i]
# print(features)
# return self.trees[i].predict(X[:, features])

# # Vectorize the single_tree_predict function
# tree_predictions = jax.vmap(single_tree_predict, in_axes=(0, None))(jnp.arange(self.n_estimators), X)

y_pred, _ = jax_mode_row(tree_predictions)

return y_pred.ravel()
Expand All @@ -197,14 +184,11 @@ def predict(self, X):
def jax_mode_row(data):
# Get the mode of each row

# Get the shape of the data
num_rows, num_cols = data.shape

# Initialize the lists of modes and counts
modes_list = []
counts_list = []

# Compute the mode of each row and its count
for row in range(num_rows):
row_data = data[row, :]
unique_values, value_counts = jnp.unique(
Expand All @@ -214,7 +198,6 @@ def jax_mode_row(data):
modes_list.append(unique_values[max_count_idx])
counts_list.append(value_counts[max_count_idx])

# Convert the lists to jnp.array
modes = jnp.array(modes_list, dtype=data.dtype)
counts = jnp.array(counts_list, dtype=jnp.int32)

Expand Down
68 changes: 50 additions & 18 deletions sml/tree/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,9 @@ def __init__(self, criterion, splitter, max_depth, n_labels):
self.max_depth = max_depth
self.n_labels = n_labels

def fit(self, X, y):
self.T, self.F = odtt(X, y, self.max_depth, self.n_labels)

def fit(self, X, y, sample_weight=None):
self.T, self.F = odtt(X, y, self.max_depth, self.n_labels, sample_weight)
return self

def predict(self, X):
Expand Down Expand Up @@ -115,7 +116,7 @@ def oaa_elementwise(array, index_array):


# def oblivious_learning(X, y, T, F, M, Cn, h):
def oblivious_learning(X, y, T, F, M, h, Cn, n_labels):
def oblivious_learning(X, y, T, F, M, h, Cn, n_labels, sample_weight=None):
'''partition the data and count the number of data samples.
params:
Expand All @@ -135,27 +136,54 @@ def oblivious_learning(X, y, T, F, M, h, Cn, n_labels):
Dval = oaae(X, Tval)
M = 2 * M + Dval + 1

# (n_leaves)
LCidx = jnp.arange(0, n_h)
isLeaf = jnp.equal(F[n_h - 1 : 2 * n_h - 1], jnp.ones(n_h))
# (n_samples, n_leaves)
LCF = jnp.equal(M[:, jnp.newaxis] - n_h + 1, LCidx)
LCF = LCF * isLeaf
# (n_samples, n_leaves, n_labels, 2 * n_features)

Cd = jnp.zeros((n_d, n_h, n_labels + 1, 2 * n_f))
Cd = Cd.at[:, :, 0, 0::2].set(jnp.tile((1 - X)[:, jnp.newaxis, :], (1, n_h, 1)))
Cd = Cd.at[:, :, 0, 1::2].set(jnp.tile((X)[:, jnp.newaxis, :], (1, n_h, 1)))
for i in range(n_labels):
Cd = Cd.at[:, :, i + 1, 0::2].set(
jnp.tile(
((1 - X) * (i == y)[:, jnp.newaxis])[:, jnp.newaxis, :], (1, n_h, 1)
)
if sample_weight is not None:
Cd = Cd.at[:, :, 0, 0::2].set(
jnp.tile((1 - X)[:, jnp.newaxis, :] * sample_weight[:, jnp.newaxis, jnp.newaxis], (1, n_h, 1))
)
Cd = Cd.at[:, :, i + 1, 1::2].set(
jnp.tile(((X) * (i == y)[:, jnp.newaxis])[:, jnp.newaxis, :], (1, n_h, 1))
Cd = Cd.at[:, :, 0, 1::2].set(
jnp.tile((X)[:, jnp.newaxis, :] * sample_weight[:, jnp.newaxis, jnp.newaxis], (1, n_h, 1))
)
else:
Cd = Cd.at[:, :, 0, 0::2].set(jnp.tile((1 - X)[:, jnp.newaxis, :], (1, n_h, 1)))
Cd = Cd.at[:, :, 0, 1::2].set(jnp.tile((X)[:, jnp.newaxis, :], (1, n_h, 1)))

for i in range(n_labels):
if sample_weight is not None:
# sample_weight = sample_weight.reshape(-1, 1, 1)
Cd = Cd.at[:, :, i + 1, 0::2].set(
jnp.tile(
((1 - X)[:, jnp.newaxis, :] * (i == y)[:, jnp.newaxis, jnp.newaxis] * sample_weight[:, jnp.newaxis, jnp.newaxis]),
(1, n_h, 1)
)
)
Cd = Cd.at[:, :, i + 1, 1::2].set(
jnp.tile(
((X)[:, jnp.newaxis, :] * (i == y)[:, jnp.newaxis, jnp.newaxis] * sample_weight[:, jnp.newaxis, jnp.newaxis]),
(1, n_h, 1)
)
)
else:
Cd = Cd.at[:, :, i + 1, 0::2].set(
jnp.tile(
((1 - X) * (i == y)[:, jnp.newaxis])[:, jnp.newaxis, :],
(1, n_h, 1)
)
)
Cd = Cd.at[:, :, i + 1, 1::2].set(
jnp.tile(
((X) * (i == y)[:, jnp.newaxis])[:, jnp.newaxis, :],
(1, n_h, 1)
)
)

Cd = Cd * LCF[:, :, jnp.newaxis, jnp.newaxis]
# (n_leaves, n_labels+1, 2*n_features)

new_Cn = jnp.sum(Cd, axis=0)

if h != 0:
Expand All @@ -165,6 +193,7 @@ def oblivious_learning(X, y, T, F, M, h, Cn, n_labels):
return new_Cn, M



def oblivious_heuristic_computation(Cn, gamma, F, h, n_labels):
'''Compute gini index, find the best feature, and update F.
Expand Down Expand Up @@ -221,7 +250,7 @@ def oblivious_node_split(SD, T, F, Cn, h, max_depth):
return T, Cn


def oblivious_DT_training(X, y, max_depth, n_labels):
def oblivious_DT_training(X, y, max_depth, n_labels, sample_weight=None):
n_samples, n_features = X.shape
T = jnp.zeros((2 ** (max_depth + 1) - 1))
F = jnp.ones((2**max_depth - 1))
Expand All @@ -231,7 +260,10 @@ def oblivious_DT_training(X, y, max_depth, n_labels):

h = 0
while h < max_depth:
Cn, M = ol(X, y, T, F, M, h, Cn, n_labels)
if sample_weight is not None:
Cn, M = ol(X, y, T, F, M, h, Cn, n_labels, sample_weight)
else:
Cn, M = ol(X, y, T, F, M, h, Cn, n_labels)

SD, gamma, F = ohc(Cn, gamma, F, h, n_labels)

Expand Down
Loading

0 comments on commit d9bd661

Please sign in to comment.