Add logistic regression multi classification ovr (#368)
# Pull Request

## What problem does this PR solve?
Add multi-class classification via one-vs-rest (OVR).

Issue Number: Fixed #252 

## Possible side effects?

- Performance:

- Backward compatibility:

---------

Signed-off-by: magic-hya <[email protected]>
magic-hya authored Oct 30, 2023
1 parent 77e90e2 commit 4ab16e2
Showing 3 changed files with 179 additions and 106 deletions.
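
For orientation before the diff: a minimal sketch of how the new OVR mode is exercised, pieced together from the emulation script below. All constructor parameters and the `sml.linear_model.logistic` import come from the diff; running the model directly on plaintext JAX arrays (instead of SPU-sealed values via `emulator.run`) is an assumption made purely for illustration.

```python
import jax.numpy as jnp
from sklearn.datasets import load_wine
from sklearn.preprocessing import MinMaxScaler

from sml.linear_model.logistic import LogisticRegression

# Three-class dataset, scaled the same way as in the emulation script.
X, y = load_wine(return_X_y=True)
X = MinMaxScaler(feature_range=(-2, 2)).fit_transform(X)

# class_labels must be given explicitly: the data cannot be inspected
# once it is sealed inside SPU.
model = LogisticRegression(
    epochs=1,
    learning_rate=0.1,
    batch_size=8,
    solver="sgd",
    penalty="l2",
    sig_type="sr",
    C=1.0,
    l1_ratio=0.5,
    class_weight=None,
    multi_class="ovr",
    class_labels=[0, 1, 2],
)

model = model.fit(jnp.array(X), jnp.array(y).reshape(-1, 1))
prob = model.predict_proba(jnp.array(X))  # shape (n_samples, 3), rows sum to 1
pred = model.predict(jnp.array(X))        # argmax over the per-class scores
```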
113 changes: 68 additions & 45 deletions sml/linear_model/emulations/logistic_emul.py
@@ -16,7 +16,7 @@
import sys

import pandas as pd
from sklearn.datasets import load_breast_cancer
from sklearn.datasets import load_breast_cancer, load_wine
from sklearn.metrics import roc_auc_score
from sklearn.preprocessing import MinMaxScaler

@@ -26,61 +26,84 @@
from sml.linear_model.logistic import LogisticRegression


def emul_LogisticRegression(mode: emulation.Mode.MULTIPROCESS):
def load_data(multi_class="binary"):
# Create dataset
if multi_class == "binary":
X, y = load_breast_cancer(return_X_y=True, as_frame=True)
else:
X, y = load_wine(return_X_y=True, as_frame=True)
scalar = MinMaxScaler(feature_range=(-2, 2))
cols = X.columns
X = scalar.fit_transform(X)
X = pd.DataFrame(X, columns=cols)

# mark these data to be protected in SPU
X_spu, y_spu = emulator.seal(
X.values, y.values.reshape(-1, 1)
) # X, y should be two-dimension array
return X, y, X_spu, y_spu


def proc(x, y, penalty, multi_class="binary"):
class_labels = [0, 1] if multi_class == "binary" else [0, 1, 2]
model = LogisticRegression(
epochs=1,
learning_rate=0.1,
batch_size=8,
solver="sgd",
penalty=penalty,
sig_type="sr",
C=1.0,
l1_ratio=0.5,
class_weight=None,
multi_class=multi_class,
class_labels=class_labels,
)

model = model.fit(x, y)
prob = model.predict_proba(x)
pred = model.predict(x)
return prob, pred


# Test Binary classification
def emul_LogisticRegression(emulator):
penalty_list = ["l1", "l2", "elasticnet"]
print(f"penalty_list={penalty_list}")

# Test SGDClassifier
def proc(x, y, penalty):
model = LogisticRegression(
epochs=1,
learning_rate=0.1,
batch_size=8,
solver="sgd",
penalty=penalty,
sig_type="sr",
C=1.0,
l1_ratio=0.5,
class_weight=None,
multi_class="binary",
X, y, X_spu, y_spu = load_data(multi_class="binary")
for i in range(len(penalty_list)):
penalty = penalty_list[i]
# Run
result = emulator.run(proc, static_argnums=(2, 3))(
X_spu, y_spu, penalty, "binary"
)
# print("Predict result prob: ", result[0])
# print("Predict result label: ", result[1])
print(f"{penalty} ROC Score: {roc_auc_score(y.values, result[0])}")


model = model.fit(x, y)
# Test Multi classification
def emul_LogisticRegression_multi_classification(emulator):
X, y, X_spu, y_spu = load_data(multi_class="ovr")
# Run
result = emulator.run(proc, static_argnums=(2, 3))(X_spu, y_spu, "l2", "ovr")
print(
f"Multi classification OVR ROC Score: {roc_auc_score(y.values, result[0], multi_class='ovr')}"
)

prob = model.predict_proba(x)
pred = model.predict(x)
return prob, pred

if __name__ == "__main__":
try:
# bandwidth and latency only work for docker mode
emulator = emulation.Emulator(
emulation.CLUSTER_ABY3_3PC, mode, bandwidth=300, latency=20
emulation.CLUSTER_ABY3_3PC,
emulation.Mode.MULTIPROCESS,
bandwidth=300,
latency=20,
)
emulator.up()

# Create dataset
X, y = load_breast_cancer(return_X_y=True, as_frame=True)
scalar = MinMaxScaler(feature_range=(-2, 2))
cols = X.columns
X = scalar.fit_transform(X)
X = pd.DataFrame(X, columns=cols)

# mark these data to be protected in SPU
X_spu, y_spu = emulator.seal(
X.values, y.values.reshape(-1, 1)
) # X, y should be two-dimension array

for i in range(len(penalty_list)):
penalty = penalty_list[i]
# Run
result = emulator.run(proc, static_argnums=(2,))(X_spu, y_spu, penalty)
# print("Predict result prob: ", result[0])
# print("Predict result label: ", result[1])
print(f"{penalty} ROC Score: {roc_auc_score(y.values, result[0])}")

emul_LogisticRegression(emulator)
        emul_LogisticRegression_multi_classification(emulator)
finally:
emulator.down()


if __name__ == "__main__":
emul_LogisticRegression(emulation.Mode.MULTIPROCESS)
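
The logistic.py diff below implements OVR by fitting one binary problem per entry of `class_labels` (samples with that label form the positive class) and predicting with an argmax over the per-class logits. A minimal plaintext NumPy sketch of that scheme, using ordinary full-batch gradient descent and the exact sigmoid instead of SPU's fixed-point SGD and sigmoid approximation; the helper names `fit_binary`, `fit_ovr`, and `predict_ovr` are hypothetical:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def fit_binary(X, y01, epochs=100, lr=0.1):
    # Plain logistic regression via full-batch gradient descent.
    # Returns a weight vector of length n_features + 1, bias last.
    Xb = np.hstack([X, np.ones((X.shape[0], 1))])  # append bias column
    w = np.zeros(Xb.shape[1])
    for _ in range(epochs):
        grad = Xb.T @ (sigmoid(Xb @ w) - y01) / Xb.shape[0]
        w -= lr * grad
    return w

def fit_ovr(X, y, class_labels):
    # One binary problem per label: label c is "positive", the rest "negative".
    return [fit_binary(X, (y == c).astype(float)) for c in class_labels]

def predict_ovr(X, coefs):
    # Each column is one binary model's decision function (logit).
    scores = np.column_stack([X @ w[:-1] + w[-1] for w in coefs])
    prob = sigmoid(scores)
    prob /= prob.sum(axis=1, keepdims=True)  # normalize rows to sum to 1
    return prob, scores.argmax(axis=1)
```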
92 changes: 61 additions & 31 deletions sml/linear_model/logistic.py
@@ -29,7 +29,7 @@ class Penalty(Enum):

class MultiClass(Enum):
Binary = 'binary'
Ovr = 'ovr' # not supported yet
Ovr = 'ovr'
Multy = 'multinomial' # not supported yet


@@ -44,7 +44,8 @@ class LogisticRegression:
IMPORTANT: Some differences from `LogisticRegression` in sklearn:
1. sigmoid will be computed with approximation
2. you must define multi_class because we cannot inspect y to decide the problem type
3. for now, only 0-1 binary classification is supported; so if your label is {-1,1}, you must change it first!
3. since the data cannot be inspected in the encrypted setting, the class labels must be specified explicitly,
e.g. [0, 1] for binary classification and [0, 1, 2] for a three-class problem.
Parameters
----------
@@ -59,6 +60,10 @@ class LogisticRegression:
- ovr: for each label, will fit a binary problem
- multinomial: the loss minimised is the multinomial loss that fit across the entire probability distribution
class_labels: classification labels, default=[0, 1].
Binary classification labels default to [0, 1].
Since the data cannot be inspected in the encrypted setting, labels for multi-class problems must be specified explicitly.
class_weight: not supported yet, for multi-class tasks, default=None
sig_type: the approximation method for sigmoid function, default='sr'
@@ -85,6 +90,7 @@ def __init__(
penalty: str = 'l2',
solver: str = 'sgd',
multi_class: str = 'binary',
class_labels: list = [0, 1],
class_weight=None,
sig_type: str = 'sr',
C: float = 1.0,
@@ -110,7 +116,10 @@ def __init__(
e.value for e in SigType
], f"sig_type should in {[e.value for e in SigType]}, but got {sig_type}"
assert class_weight == None, f"not support class_weight for now"
assert multi_class == 'binary', f"only support binary problem for now"
assert multi_class in [
'binary',
'ovr',
], f"only support [binary,ovr] problem for now"

self._epochs = epochs
self._learning_rate = learning_rate
@@ -121,6 +130,7 @@ def __init__(
self._sig_type = SigType(sig_type)
self._class_weight = class_weight
self._multi_class = MultiClass(multi_class)
self._class_labels = class_labels

self._weights = jnp.zeros(())

@@ -131,13 +141,17 @@ def _update_weights(
w, # array-like
total_batch: int,
batch_size: int,
pos_class: int,
) -> np.ndarray:
assert x.shape[0] >= total_batch * batch_size, "total batch is too large"
num_feat = x.shape[1]
assert w.shape[0] == num_feat + 1, "w shape is mismatch to x"
assert len(w.shape) == 1 or w.shape[1] == 1, "w should be list or 1D array"
w = w.reshape((w.shape[0], 1))

mask = y == pos_class
y = mask

for idx in range(total_batch):
begin = idx * batch_size
end = (idx + 1) * batch_size
@@ -201,25 +215,29 @@ def fit(self, x, y):
num_feat = x.shape[1]
batch_size = min(self._batch_size, num_sample)
total_batch = int(num_sample / batch_size)
weights = jnp.zeros((num_feat + 1, 1))
n_classes = len(self._class_labels)
_classes = self._class_labels
if n_classes == 2:
n_classes = 1
_classes = _classes[1:]
_coefs = [None] * n_classes

# not support class_weight for now
if isinstance(self._class_weight, dict):
raise NotImplementedError
elif self._class_weight == 'balanced':
raise NotImplementedError

# do train
for _ in range(self._epochs):
weights = self._update_weights(
x,
y,
weights,
total_batch,
batch_size,
)
for i in range(n_classes):
weights = jnp.zeros((num_feat + 1, 1))
# do train
for _ in range(self._epochs):
weights = self._update_weights(
x, y, weights, total_batch, batch_size, _classes[i]
)
_coefs[i] = weights

self._weights = weights
self._weights = jnp.array(_coefs)
return self

def predict_proba(self, x):
@@ -238,7 +256,16 @@ def predict_proba(self, x):
pred = self.decision_function(x)

if self._multi_class == MultiClass.Binary:
prob = sigmoid(pred, self._sig_type)
prob = sigmoid(pred[0], self._sig_type)
elif self._multi_class == MultiClass.Ovr:
preds = [None] * len(pred)
for i in range(len(pred)):
prob = sigmoid(pred[i], self._sig_type)
preds[i] = prob.ravel()
preds = jnp.transpose(jnp.array(preds))
prob = preds / preds.sum(axis=1).reshape((preds.shape[0], -1))
            # sklearn's roc_auc_score() checks that each row of prob sums to 1 via
            # np.allclose(1, prob.sum(axis=1)); the adjustment below absorbs the
            # fixed-point rounding error so that check passes.
prob = prob.at[:, 0].set(1 - prob.sum(axis=1) + prob[:, 0])
else:
raise NotImplementedError

@@ -262,26 +289,29 @@ def predict(self, x):

if self._multi_class == MultiClass.Binary:
# for binary task, only check whether logit > 0 (prob > 0.5)
label = jnp.select([pred > 0], [1], 0)
label = jnp.select([pred[0] > 0], [1], 0)
elif self._multi_class == MultiClass.Ovr:
label = jnp.argmax(jnp.array(pred), axis=0)
else:
raise NotImplementedError

return label
return label.reshape((-1,))

def decision_function(self, x):
if self._multi_class == MultiClass.Binary:
num_feat = x.shape[1]
w = self._weights
assert w.shape[0] == num_feat + 1, f"w shape is mismatch to x={x.shape}"
assert len(w.shape) == 1 or w.shape[1] == 1, "w should be list or 1D array"
w.reshape((w.shape[0], 1))

bias = w[-1, 0]
w = jnp.resize(w, (num_feat, 1))
pred = jnp.matmul(x, w) + bias
return pred
elif self._multi_class == MultiClass.Ovr:
raise NotImplementedError
if self._multi_class in [MultiClass.Binary, MultiClass.Ovr]:
n_classes = len(self._class_labels)
preds = [None] * n_classes
for i in range(n_classes):
num_feat = x.shape[1]
w = self._weights[i]
assert w.shape[0] == num_feat + 1, f"w shape is mismatch to x={x.shape}"
assert (
len(w.shape) == 1 or w.shape[1] == 1
), "w should be list or 1D array"
bias = w[-1, 0]
w = jnp.resize(w, (num_feat, 1))
pred = jnp.matmul(x, w) + bias
preds[i] = pred
return preds
else:
# Multy model here
raise NotImplementedError
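
One detail worth spelling out: `roc_auc_score(..., multi_class='ovr')` rejects probability rows that do not sum exactly to 1, and fixed-point arithmetic inside SPU can leave a small residual after the row normalization. The diff folds that residual into column 0. A toy NumPy illustration of the same fix-up; the example values are made up:

```python
import numpy as np

# Suppose fixed-point rounding left each row slightly off from 1.
prob = np.array([
    [0.499, 0.300, 0.200],  # sums to 0.999
    [0.101, 0.500, 0.400],  # sums to 1.001
])

# Fold the residual into the first column, mirroring the diff's
# prob = prob.at[:, 0].set(1 - prob.sum(axis=1) + prob[:, 0]).
prob[:, 0] = 1 - prob.sum(axis=1) + prob[:, 0]

print(np.allclose(prob.sum(axis=1), 1.0))  # True
```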
