Repo sync (#705)

anakinxc authored May 24, 2024
1 parent 7b45d88 commit aef1c67
Showing 71 changed files with 3,775 additions and 1,952 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -14,6 +14,7 @@

- [Feature] Add ORAM based dynamic_slice for ABY3
- [Feature] Add Atan2Op support
- [API] Add beaver cache support for semi2k (**experimental**)

## 20240415

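The beaver cache announced in this entry is exercised by the updated jax_lr example later in this commit. As a rough sketch (not part of the diff), the pattern is to wrap a value that feeds many multiplications in spu.experimental.make_cached_var and release it with spu.experimental.drop_cached_var once its last consumer is computed; the function below is illustrative only and assumes ppd.init has already been called against a semi2k cluster.

    import jax.numpy as jnp

    import spu
    import spu.utils.distributed as ppd


    @ppd.device("SPU")
    def dot_twice(x, w1, w2):
        # Cache x so the Beaver material tied to it can be reused across
        # the two products below (semi2k beaver cache, experimental).
        x = spu.experimental.make_cached_var(x)
        y1 = jnp.matmul(x, w1)
        y2 = jnp.matmul(x, w2)
        # Drop the cache; the trailing arguments keep the drop ordered
        # after the last uses of x, mirroring jax_lr.py in this commit.
        x = spu.experimental.drop_cached_var(x, y1, y2)
        return y1, y2

The jax_lr example below uses the same pair of calls around the model weights inside loss and around the feature matrix inside the fit functions.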
16 changes: 8 additions & 8 deletions bazel/repositories.bzl
@@ -39,21 +39,21 @@ def _yacl():
http_archive,
name = "yacl",
urls = [
"https://github.com/secretflow/yacl/archive/refs/tags/0.4.5b0.tar.gz",
"https://github.com/secretflow/yacl/archive/refs/tags/0.4.5b1.tar.gz",
],
strip_prefix = "yacl-0.4.5b0",
sha256 = "68d1dbeb255d404606d3ba9380b915fbbe3886cde575bbe89795657286742bd2",
strip_prefix = "yacl-0.4.5b1",
sha256 = "28064053b9add0db8e1e8e648421a0579f1d3e7ee8a4bbd7bd5959cb59598088",
)

def _libpsi():
maybe(
http_archive,
name = "psi",
urls = [
"https://github.com/secretflow/psi/archive/refs/tags/v0.4.0.dev240517.tar.gz",
"https://github.com/secretflow/psi/archive/refs/tags/v0.4.0.dev240524.tar.gz",
],
strip_prefix = "psi-0.4.0.dev240517",
sha256 = "43a475d44798d0a634f9cff2d2bd3a2c2c5f0f0dee34f01ac5de803f2a0de328",
strip_prefix = "psi-0.4.0.dev240524",
sha256 = "c2868fa6a9d804e6bbed9922dab6dc819ec6e180e15eafe7eb1b661302508c88",
)

def _rules_proto_grpc():
@@ -124,8 +124,8 @@ def _com_github_xtensor_xtl():
)

def _com_github_openxla_xla():
OPENXLA_COMMIT = "1acf05ef0d41181caaf0cd691aa9d453ffc41a73"
OPENXLA_SHA256 = "04a1cd0d530398419393f0db32f62a1b3b2f221b0dea52d7db75978109343558"
OPENXLA_COMMIT = "5f70248ff0e9702544c8eeea0ab9b03e1ef144b0"
OPENXLA_SHA256 = "e2db58c41b7160259e0ec109ecbfc9c4b07c0889312719a19796ab30a970ba9e"

# We need openxla to handle xla/mhlo/stablehlo
maybe(
44 changes: 44 additions & 0 deletions examples/python/conf/2pc_semi2k.json
@@ -0,0 +1,44 @@
{
"id": "colocated.2pc",
"nodes": {
"node:0": "127.0.0.1:64320",
"node:1": "127.0.0.1:64321"
},
"devices": {
"SPU": {
"kind": "SPU",
"config": {
"node_ids": [
"node:0",
"node:1"
],
"experimental_data_folder": [
"/tmp/spu_data_0/",
"/tmp/spu_data_1/"
],
"spu_internal_addrs": [
"127.0.0.1:64330",
"127.0.0.1:64331"
],
"runtime_config": {
"protocol": "SEMI2K",
"field": "FM64",
"enable_pphlo_profile": true,
"enable_hal_profile": true
}
}
},
"P1": {
"kind": "PYU",
"config": {
"node_id": "node:0"
}
},
"P2": {
"kind": "PYU",
"config": {
"node_id": "node:1"
}
}
}
}
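As the updated jax_lr.py in this commit shows, this config is consumed by loading the JSON and handing the node and device maps to ppd.init; a minimal sketch (file paths as in this commit):

    # Bring the two nodes up first:
    #   bazel run -c opt //examples/python/utils:nodectl -- \
    #       -c examples/python/conf/2pc_semi2k.json up
    import json

    import spu.utils.distributed as ppd

    with open("examples/python/conf/2pc_semi2k.json", "r") as f:
        conf = json.load(f)

    ppd.init(conf["nodes"], conf["devices"])

Each node also gets its own experimental_data_folder under /tmp, as listed above.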
1 change: 1 addition & 0 deletions examples/python/conf/BUILD.bazel
@@ -18,6 +18,7 @@ filegroup(
name = "conf",
srcs = [
"2pc.json",
"2pc_semi2k.json",
"3pc.json",
"3pc_colocated.json",
"ds_breast_cancer_basic.json",
1 change: 1 addition & 0 deletions examples/python/ml/jax_lr/BUILD.bazel
@@ -24,6 +24,7 @@ py_binary(
],
deps = [
"//examples/python/utils:dataset_utils",
"//spu:init",
"//spu/utils:distributed",
],
)
118 changes: 78 additions & 40 deletions examples/python/ml/jax_lr/jax_lr.py
@@ -13,7 +13,7 @@
# limitations under the License.

# Start nodes.
# > bazel run -c opt //examples/python/utils:nodectl -- up
# > bazel run -c opt //examples/python/utils:nodectl -- -c examples/python/conf/2pc_semi2k.json up
#
# Run this example script.
# > bazel run -c opt //examples/python/ml/jax_lr:jax_lr
@@ -24,10 +24,10 @@

import jax
import jax.numpy as jnp
import numpy as np
from sklearn import metrics

import examples.python.utils.dataset_utils as dsutil
import spu
import spu.utils.distributed as ppd


@@ -42,9 +42,17 @@ def predict(x, w, b):
return sigmoid(jnp.matmul(x, w) + b)


def loss(x, y, w, b):
def loss(x, y, w, b, use_cache):
if use_cache:
w = spu.experimental.make_cached_var(w)
b = spu.experimental.make_cached_var(b)
pred = predict(x, w, b)
label_prob = pred * y + (1 - pred) * (1 - y)

if use_cache:
w = spu.experimental.drop_cached_var(w, label_prob)
b = spu.experimental.drop_cached_var(b, label_prob)

return -jnp.mean(jnp.log(label_prob))


@@ -54,28 +62,39 @@ def __init__(self, n_epochs=10, n_iters=10, step_size=0.1):
self.n_iters = n_iters
self.step_size = step_size

def fit_auto_grad(self, feature, label):
def fit_auto_grad(self, feature, label, use_cache=False):
w = jnp.zeros(feature.shape[1])
b = 0.0

if use_cache:
feature = spu.experimental.make_cached_var(feature)

xs = jnp.array_split(feature, self.n_iters, axis=0)
ys = jnp.array_split(label, self.n_iters, axis=0)

def body_fun(_, loop_carry):
w_, b_ = loop_carry
for x, y in zip(xs, ys):
grad = jax.grad(loss, argnums=(2, 3))(x, y, w_, b_)
grad = jax.grad(loss, argnums=(2, 3))(x, y, w_, b_, use_cache)
w_ -= grad[0] * self.step_size
b_ -= grad[1] * self.step_size

return w_, b_

return jax.lax.fori_loop(0, self.n_epochs, body_fun, (w, b))
ret = jax.lax.fori_loop(0, self.n_epochs, body_fun, (w, b))

if use_cache:
feature = spu.experimental.drop_cached_var(feature, *ret)

return ret

def fit_manual_grad(self, feature, label):
def fit_manual_grad(self, feature, label, use_cache=False):
w = jnp.zeros(feature.shape[1])
b = 0.0

if use_cache:
feature = spu.experimental.make_cached_var(feature)

xs = jnp.array_split(feature, self.n_iters, axis=0)
ys = jnp.array_split(label, self.n_iters, axis=0)

@@ -89,22 +108,15 @@ def body_fun(_, loop_carry):

return w_, b_

return jax.lax.fori_loop(0, self.n_epochs, body_fun, (w, b))


parser = argparse.ArgumentParser(description='distributed driver.')
parser.add_argument("-c", "--config", default="examples/python/conf/3pc.json")
args = parser.parse_args()
ret = jax.lax.fori_loop(0, self.n_epochs, body_fun, (w, b))

with open(args.config, 'r') as file:
conf = json.load(file)
if use_cache:
feature = spu.experimental.drop_cached_var(feature, *ret)

ppd.init(conf["nodes"], conf["devices"])
return ret


def run_on_cpu():
x_train, y_train = dsutil.breast_cancer(slice(None, None, None), True)

def run_on_cpu(x_train, y_train):
lr = LogitRegression()

w0, b0 = jax.jit(lr.fit_auto_grad)(x_train, y_train)
@@ -120,59 +132,85 @@ def run_on_cpu():
import cloudpickle as pickle


def save_and_load_model():
# 1. run with spu
W, b = run_on_spu()

# 2. save metadata and spu objects.
def save_and_load_model(x_test, y_test, W, b):
# 1. save metadata and spu objects.
meta = ppd.save((W, b))
with open(SPU_OBJECT_META_PATH, "wb") as f:
pickle.dump(meta, f)

# 3. load metadata and spu objects.
# 2. load metadata and spu objects.
with open(SPU_OBJECT_META_PATH, "rb") as f:
meta_ = pickle.load(f)
W_, b_ = ppd.load(meta_)

W_r, b_r = ppd.get(W_), ppd.get(b_)
print(W_r, b_r)

x_test, y_test = dsutil.breast_cancer(slice(None, None, None), False)

score = metrics.roc_auc_score(y_test, predict(x_test, W_r, b_r))
print("AUC(save_and_load_model)={}".format(score))

return score


def compute_score(W_r, b_r, type):
x_test, y_test = dsutil.breast_cancer(slice(None, None, None), False)
def compute_score(x_test, y_test, W_r, b_r, type):
score = metrics.roc_auc_score(y_test, predict(x_test, W_r, b_r))
print(f"AUC({type})={score}")
return score


def run_on_spu():
def run_on_spu(x, y, use_cache=False, auto_grad=False):
@ppd.device("SPU")
def train(x1, x2, y):
x = jnp.concatenate((x1, x2), axis=1)
lr = LogitRegression()
return lr.fit_auto_grad(x, y)

x1, y = ppd.device("P1")(dsutil.breast_cancer)(slice(None, 15), True)
x2, _ = ppd.device("P2")(dsutil.breast_cancer)(slice(15, None), True)
if auto_grad:
return lr.fit_auto_grad(x, y, use_cache)
else:
return lr.fit_manual_grad(x, y, use_cache)

x1 = ppd.device("P1")(lambda x: x[:, :50])(x)
x2 = ppd.device("P2")(lambda x: x[:, 50:])(x)
y = ppd.device("P1")(lambda x: x)(y)
W, b = train(x1, x2, y)

return W, b


if __name__ == "__main__":
parser = argparse.ArgumentParser(description='distributed driver.')
parser.add_argument(
"-c", "--config", default="examples/python/conf/2pc_semi2k.json"
)
args = parser.parse_args()

with open(args.config, 'r') as file:
conf = json.load(file)

ppd.init(conf["nodes"], conf["devices"])

x, y = dsutil.mock_classification(10 * 10000, 100, 0.0, 42)

print('Run on CPU\n------\n')
w, b = run_on_cpu()
compute_score(w[0], b[0], 'cpu, auto_grad')
compute_score(w[1], b[1], 'cpu, manual_grad')
w, b = run_on_cpu(x, y)
compute_score(x, y, w[0], b[0], 'cpu, auto_grad')
compute_score(x, y, w[1], b[1], 'cpu, manual_grad')
print('Run on SPU\n------\n')
w, b = run_on_spu()
# without cache
# total send bytes 2376240800, recv bytes 2376240800
w, b = run_on_spu(x, y)
w_r, b_r = ppd.get(w), ppd.get(b)
compute_score(x, y, w_r, b_r, 'spu')
save_and_load_model(x, y, w, b)

print('Run on SPU with cache\n------\n')
# with semi2k beaver cache
# total send bytes 856240800, recv bytes 856240800
# Reduced communication bytes by 64%
w, b = run_on_spu(x, y, True)
w_r, b_r = ppd.get(w), ppd.get(b)
compute_score(x, y, w_r, b_r, 'spu_cached')

print('Run on SPU auto_grad\n------\n')
w, b = run_on_spu(x, y, True, True)
w_r, b_r = ppd.get(w), ppd.get(b)
compute_score(w_r, b_r, 'spu')
save_and_load_model()
compute_score(x, y, w_r, b_r, 'spu, auto_grad')
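As a quick check on the figures in the comments above: (2376240800 - 856240800) / 2376240800 ≈ 0.64, so the cached semi2k run sends and receives roughly 64% fewer bytes than the uncached run for this workload.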
18 changes: 8 additions & 10 deletions examples/python/ml/ml_test.py
@@ -131,12 +131,17 @@ def test_jax_kmeans(self):
npt.assert_array_equal(cpu_labels, spu_labels)

def test_jax_lr(self):
from examples.python.utils import dataset_utils as dsutil
from examples.python.ml.jax_lr import jax_lr

w, b = profile_test_point(jax_lr.run_on_spu)
score = jax_lr.compute_score(ppd.get(w), ppd.get(b), 'spu')
x, y = dsutil.mock_classification(10000, 100, 0.0, 42)
w, b = profile_test_point(jax_lr.run_on_spu, x, y)

self.assertGreater(score, 0.95)
score = jax_lr.compute_score(x, y, ppd.get(w), ppd.get(b), 'spu')
self.assertGreater(score, 0.85)

score = jax_lr.save_and_load_model(x, y, w, b)
self.assertGreater(score, 0.85)

def test_jax_svm(self):
from examples.python.ml.jax_svm import jax_svm
@@ -219,12 +224,6 @@ def test_torch_resnet_experiment(self):
label = torch_resnet_experiment.run_inference_on_spu(model, image)
self.assertEqual(label, 258)

def test_save_and_load_model(self):
from examples.python.ml.jax_lr import jax_lr

score = jax_lr.save_and_load_model()
self.assertGreater(score, 0.9)


def suite():
suite = unittest.TestSuite()
Expand All @@ -239,7 +238,6 @@ def suite():
suite.addTest(UnitTests('test_ss_xgb'))
suite.addTest(UnitTests('test_stax_mnist_classifier'))
suite.addTest(UnitTests('test_stax_nn'))
suite.addTest(UnitTests('test_save_and_load_model'))
# should put JAX tests above
suite.addTest(UnitTests('test_tf_experiment'))
suite.addTest(UnitTests('test_torch_lr_experiment'))
7 changes: 3 additions & 4 deletions libspu/compiler/passes/convert_push_down.cc
@@ -47,11 +47,10 @@ struct TypeAgnosticOpConverter : public OpRewritePattern<OpT> {
return failure();
}

const auto &from_type = parentConvert.getOperand()
.getType()
.template dyn_cast<RankedTensorType>();
const auto &from_type =
mlir::dyn_cast<RankedTensorType>(parentConvert.getOperand().getType());
const auto &to_type =
op.getResult().getType().template dyn_cast<RankedTensorType>();
mlir::dyn_cast<RankedTensorType>(op.getResult().getType());

OpBuilder builder(op);

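The convert_push_down.cc change follows upstream MLIR, which has been moving from the member-template casts (e.g. type.dyn_cast<RankedTensorType>()) to free functions such as mlir::dyn_cast<RankedTensorType>(...); switching here likely keeps the pass building cleanly against the newer OpenXLA pin picked up earlier in this commit.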