Functional enhancement of the SPU logistic regression algorithm [add L1 & elasticnet penalty] (#363)
# Pull Request

## What problem does this PR solve?
Add L1 and elasticnet penalties to the SPU logistic regression implementation.

Issue Number: Fixed #252
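
For reference, the regularization (sub)gradient the new code adds to the batch gradient (my notation, not from the PR text: `C` is the inverse regularization strength and `ρ` is `l1_ratio`, matching the parameters in the diffs below):

```math
\frac{\partial R(w)}{\partial w} =
\begin{cases}
w / C & \text{penalty} = \texttt{l2} \\
\operatorname{sign}(w) / C & \text{penalty} = \texttt{l1} \\
\left( \rho \operatorname{sign}(w) + (1 - \rho)\, w \right) / C & \text{penalty} = \texttt{elasticnet}
\end{cases}
```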

## Possible side effects?

- Performance:
Unit test output:
```
Ran 1 test in 30.356s

OK
penalty_list=['l1', 'l2', 'elasticnet']
l1 ROC Score: 0.9923365572644153
l2 ROC Score: 0.9919665979599387
elasticnet ROC Score: 0.9921912161090851
```
Emulation output:
```
l1 ROC Score: 0.9923233444321125
l2 ROC Score: 0.9919665979599387
elasticnet ROC Score: 0.9921780032767824
[2023-10-10 03:11:33,316] Shutdown multiprocess cluster...
```

- Backward compatibility:

---------

Signed-off-by: magic-hya <[email protected]>
magic-hya authored Oct 11, 2023
1 parent 79615b7 commit 815c15e
Showing 3 changed files with 63 additions and 37 deletions.
25 changes: 15 additions & 10 deletions sml/linear_model/emulations/logistic_emul.py
```diff
@@ -27,16 +27,20 @@
 
 
 def emul_LogisticRegression(mode: emulation.Mode.MULTIPROCESS):
+    penalty_list = ["l1", "l2", "elasticnet"]
+    print(f"penalty_list={penalty_list}")
+
     # Test SGDClassifier
-    def proc(x, y):
+    def proc(x, y, penalty):
         model = LogisticRegression(
-            epochs=3,
+            epochs=1,
             learning_rate=0.1,
             batch_size=8,
             solver="sgd",
-            penalty="l2",
+            penalty=penalty,
             sig_type="sr",
-            l2_norm=1.0,
+            C=1.0,
+            l1_ratio=0.5,
             class_weight=None,
             multi_class="binary",
         )
@@ -66,12 +70,13 @@ def proc(x, y):
             X.values, y.values.reshape(-1, 1)
         )  # X, y should be two-dimension array
 
-        # Run
-        result = emulator.run(proc)(X_spu, y_spu)
-        print("Predict result prob: ", result[0])
-        print("Predict result label: ", result[1])
-
-        print("ROC Score: ", roc_auc_score(y.values, result[0]))
+        for i in range(len(penalty_list)):
+            penalty = penalty_list[i]
+            # Run
+            result = emulator.run(proc, static_argnums=(2,))(X_spu, y_spu, penalty)
+            # print("Predict result prob: ", result[0])
+            # print("Predict result label: ", result[1])
+            print(f"{penalty} ROC Score: {roc_auc_score(y.values, result[0])}")
 
     finally:
         emulator.down()
```
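
A note on `static_argnums=(2,)` above: `penalty` selects different traced operations, so it has to be passed as a compile-time (static) argument rather than traced data; JAX then retraces once per distinct value. A minimal sketch with plain `jax.jit` (my illustration, not part of the diff):

```python
import jax
import jax.numpy as jnp


def reg_grad(w, penalty):
    # Python-level branching on a string only works when JAX treats
    # `penalty` as a static (hashable, compile-time) argument.
    if penalty == "l1":
        return jnp.sign(w)
    return w


# Without static_argnums, jit would try to trace the string and fail.
reg_grad_jit = jax.jit(reg_grad, static_argnums=(1,))

w = jnp.array([0.5, -1.0])
print(reg_grad_jit(w, "l1"))  # traces once for "l1"
print(reg_grad_jit(w, "l2"))  # retraces for "l2"
```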
46 changes: 31 additions & 15 deletions sml/linear_model/logistic.py
```diff
@@ -22,9 +22,9 @@
 
 class Penalty(Enum):
     NONE = 'None'
-    L1 = 'l1'  # not supported
+    L1 = 'l1'
     L2 = 'l2'
-    Elastic = 'elasticnet'  # not supported
+    Elastic = 'elasticnet'
 
 
 class MultiClass(Enum):
@@ -49,7 +49,7 @@ class LogisticRegression:
     Parameters
     ----------
     penalty: Specify the norm of the penalty:
-        {'l1', 'l2', 'elasticnet', 'None'}, default='l2' (current only support l2)
+        {'l1', 'l2', 'elasticnet', 'None'}, default='l2'
 
     solver: Algorithm to use in the optimization problem, default='sgd'.
@@ -64,7 +64,14 @@ class LogisticRegression:
     sig_type: the approximation method for sigmoid function, default='sr'
         for all choices, refer to `SigType`
-    l2_norm: the strength of L2 norm, must be a positive float, default=0.01
+    C: float, default=1.0
+        Inverse of regularization strength; must be a positive float. Like in support vector machines,
+        smaller values specify stronger regularization.
+    l1_ratio: float, default=0.5
+        The Elastic-Net mixing parameter, with 0 <= l1_ratio <= 1. Only used if penalty='elasticnet'.
+        Setting l1_ratio=0 is equivalent to using penalty='l2', while setting l1_ratio=1 is equivalent to using penalty='l1'.
+        For 0 < l1_ratio < 1, the penalty is a combination of L1 and L2.
 
     epochs, learning_rate, batch_size: hyper-parameters for sgd solver
         epochs: default=20
@@ -80,7 +87,8 @@ def __init__(
         multi_class: str = 'binary',
         class_weight=None,
         sig_type: str = 'sr',
-        l2_norm: float = 0.01,
+        C: float = 1.0,
+        l1_ratio: float = 0.5,
         epochs: int = 20,
         learning_rate: float = 0.1,
         batch_size: int = 512,
@@ -89,10 +97,12 @@ def __init__(
         assert epochs > 0, f"epochs should >0"
         assert learning_rate > 0, f"learning_rate should >0"
         assert batch_size > 0, f"batch_size should >0"
-        assert penalty == 'l2', "only support L2 penalty for now"
         assert solver == 'sgd', "only support sgd solver for now"
-        if penalty == Penalty.L2:
-            assert l2_norm > 0, f"l2_norm should >0 if use L2 penalty"
+        assert C > 0, f"C should >0"
+        if penalty == Penalty.Elastic:
+            assert (
+                0 <= l1_ratio <= 1
+            ), f"l1_ratio should in `[0, 1]` if use Elastic penalty"
         assert penalty in [
             e.value for e in Penalty
         ], f"penalty should in {[e.value for e in Penalty]}, but got {penalty}"
@@ -105,7 +115,8 @@ def __init__(
         self._epochs = epochs
         self._learning_rate = learning_rate
         self._batch_size = batch_size
-        self._l2_norm = l2_norm
+        self._C = C
+        self._l1_ratio = l1_ratio
         self._penalty = Penalty(penalty)
         self._sig_type = SigType(sig_type)
         self._class_weight = class_weight
@@ -142,20 +153,25 @@ def _update_weights(
             err = pred - y_slice
             grad = jnp.matmul(jnp.transpose(x_slice), err)
 
-            if self._penalty == Penalty.L2:
+            if self._penalty != Penalty.NONE:
                 w_with_zero_bias = jnp.resize(w, (num_feat, 1))
                 w_with_zero_bias = jnp.concatenate(
                     (w_with_zero_bias, jnp.zeros((1, 1))),
                     axis=0,
                 )
-                grad = grad + w_with_zero_bias * self._l2_norm
+            if self._penalty == Penalty.L2:
+                reg = w_with_zero_bias * 1.0 / self._C
             elif self._penalty == Penalty.L1:
-                raise NotImplementedError
+                reg = jnp.sign(w_with_zero_bias) * 1.0 / self._C
             elif self._penalty == Penalty.Elastic:
-                raise NotImplementedError
+                reg = (
+                    jnp.sign(w_with_zero_bias) * self._l1_ratio * 1.0 / self._C
+                    + w_with_zero_bias * (1 - self._l1_ratio) * 1.0 / self._C
+                )
             else:
                 # None penalty
-                raise NotImplementedError
+                reg = 0
 
+            grad = grad + reg
+
             step = (self._learning_rate * grad) / batch_size
```
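
A quick plaintext sanity check (mine, not part of the PR) that the elasticnet term above matches the docstring: `l1_ratio=0` reduces to the L2 gradient and `l1_ratio=1` to the L1 subgradient:

```python
import numpy as np


def reg(w, l1_ratio, C=1.0):
    # Mirrors the elasticnet branch of _update_weights above.
    return (np.sign(w) * l1_ratio + w * (1 - l1_ratio)) / C


w = np.array([0.3, -0.7, 1.2])
assert np.allclose(reg(w, l1_ratio=0.0), w)           # pure L2 term (C=1.0)
assert np.allclose(reg(w, l1_ratio=1.0), np.sign(w))  # pure L1 term
```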
29 changes: 17 additions & 12 deletions sml/linear_model/tests/logistic_test.py
```diff
@@ -35,16 +35,20 @@ def test_logistic(self):
             3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64
         )
 
+        penalty_list = ["l1", "l2", "elasticnet"]
+        print(f"penalty_list={penalty_list}")
+
         # Test SGDClassifier
-        def proc(x, y):
+        def proc(x, y, penalty):
             model = LogisticRegression(
-                epochs=3,
+                epochs=1,
                 learning_rate=0.1,
                 batch_size=8,
                 solver="sgd",
-                penalty="l2",
+                penalty=penalty,
                 sig_type="sr",
-                l2_norm=1.0,
+                C=1.0,
+                l1_ratio=0.5,
                 class_weight=None,
                 multi_class="binary",
             )
@@ -62,14 +66,15 @@ def proc(x, y):
         X = scalar.fit_transform(X)
         X = pd.DataFrame(X, columns=cols)
 
-        # Run
-        result = spsim.sim_jax(sim, proc)(
-            X.values, y.values.reshape(-1, 1)
-        )  # X, y should be two-dimension array
-        print("Predict result prob: ", result[0])
-        print("Predict result label: ", result[1])
-
-        print("ROC Score: ", roc_auc_score(y.values, result[0]))
+        for i in range(len(penalty_list)):
+            penalty = penalty_list[i]
+            # Run
+            result = spsim.sim_jax(sim, proc, static_argnums=(2,))(
+                X.values, y.values.reshape(-1, 1), penalty
+            )  # X, y should be two-dimension array
+            # print("Predict result prob: ", result[0])
+            # print("Predict result label: ", result[1])
+            print(f"{penalty} ROC Score: {roc_auc_score(y.values, result[0])}")
 
 
 if __name__ == "__main__":
```
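
For a rough plaintext baseline of the ROC numbers above, sklearn's `saga` solver supports the same three penalties with the same `C`/`l1_ratio` parameterization. A hedged sketch (not in this PR; the dataset is a stand-in, the actual test loads its own data):

```python
from sklearn.datasets import load_breast_cancer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from sklearn.preprocessing import StandardScaler

X, y = load_breast_cancer(return_X_y=True)
X = StandardScaler().fit_transform(X)

for penalty in ["l1", "l2", "elasticnet"]:
    clf = LogisticRegression(
        penalty=penalty,
        solver="saga",  # saga supports l1, l2 and elasticnet
        C=1.0,
        l1_ratio=0.5 if penalty == "elasticnet" else None,
        max_iter=1000,
    )
    clf.fit(X, y)
    print(penalty, "ROC Score:", roc_auc_score(y, clf.predict_proba(X)[:, 1]))
```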
