Repo sync (#366)
# Pull Request

## What problem does this PR solve?

Issue Number: Fixed #

## Possible side effects?

- Performance:

- Backward compatibility:
anakinxc authored Oct 13, 2023
1 parent 815c15e commit 6228366
Showing 254 changed files with 1,783 additions and 1,496 deletions.
2 changes: 1 addition & 1 deletion benchmark/binary_op_bench.py
@@ -14,10 +14,10 @@

import argparse
import json
+import time

import jax.numpy as jnp
import numpy as np
-import time

import spu.utils.distributed as ppd

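Note: the import moves in this file, and in the Python files throughout this commit, follow the conventional three-group ordering (standard library, then third-party, then first-party) that a formatter such as isort enforces: `time` is standard library, so it joins `argparse` and `json` in the first group, while the third-party `jax`/`numpy` imports and the first-party `spu` imports keep their own blank-line-separated groups.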
2 changes: 1 addition & 1 deletion benchmark/unary_op_bench.py
@@ -14,10 +14,10 @@

import argparse
import json
+import time

import jax.numpy as jnp
import numpy as np
-import time

import spu.utils.distributed as ppd

7 changes: 4 additions & 3 deletions docs/reference/gen_benchmark_report.py
@@ -15,12 +15,13 @@
# limitations under the License.


+import argparse
import json
-import pandas as pd
-import numpy as np
import os
from enum import Enum
-import argparse
+
+import numpy as np
+import pandas as pd

g_time_list = [
('ns', 1000),
3 changes: 2 additions & 1 deletion docs/reference/gen_complexity_md.py
@@ -16,9 +16,10 @@


import argparse
-from pytablewriter import MarkdownTableWriter
import json
+
+from pytablewriter import MarkdownTableWriter


def main():
parser = argparse.ArgumentParser(
3 changes: 2 additions & 1 deletion docs/reference/gen_np_op_status_doc.py
@@ -16,9 +16,10 @@


import argparse
-from pytablewriter import MarkdownTableWriter
import json
+
from mdutils.mdutils import MdUtils
+from pytablewriter import MarkdownTableWriter


def main():
4 changes: 2 additions & 2 deletions docs/tutorials/cpp_lr_example.rst
@@ -20,11 +20,11 @@ In the first terminal.

.. code-block:: bash
-bazel run //examples/cpp:simple_lr -- -rank 0 -dataset examples/cpp/data/perfect_logit_a.csv -has_label=true
+bazel run //examples/cpp:simple_lr -- -rank 0 -dataset examples/cpp/perfect_logit_a.csv -has_label=true
In the second terminal.

.. code-block:: bash
-bazel run //examples/cpp:simple_lr -- -rank 1 -dataset examples/cpp/data/perfect_logit_b.csv
+bazel run //examples/cpp:simple_lr -- -rank 1 -dataset examples/cpp/perfect_logit_b.csv
7 changes: 6 additions & 1 deletion examples/cpp/BUILD.bazel
@@ -25,7 +25,12 @@ spu_cc_binary(
deps = [
":utils",
"//libspu/device:io",
"//libspu/kernel/hal",
"//libspu/kernel/hal:public_helper",
"//libspu/kernel/hlo:basic_binary",
"//libspu/kernel/hlo:basic_unary",
"//libspu/kernel/hlo:casting",
"//libspu/kernel/hlo:const",
"//libspu/kernel/hlo:geometrical",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:Support",
"@yacl//yacl/link:factory",
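The dependency change above tracks the source change in simple_lr.cc below: the broad `//libspu/kernel/hal` target is dropped in favor of the specific `hlo` kernel targets the binary actually calls, plus `hal:public_helper`, which is still needed to dump public values back to the host.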
56 changes: 32 additions & 24 deletions examples/cpp/simple_lr.cc
@@ -24,45 +24,51 @@

#include "examples/cpp/utils.h"
#include "spdlog/spdlog.h"
#include "xtensor/xarray.hpp"
#include "xtensor/xcsv.hpp"
#include "xtensor/xview.hpp"

#include "libspu/device/io.h"
#include "libspu/kernel/hal/hal.h"
#include "libspu/kernel/hal/type_cast.h"
#include "libspu/kernel/hal/public_helper.h"
#include "libspu/kernel/hlo/basic_binary.h"
#include "libspu/kernel/hlo/basic_unary.h"
#include "libspu/kernel/hlo/casting.h"
#include "libspu/kernel/hlo/const.h"
#include "libspu/kernel/hlo/geometrical.h"
#include "libspu/mpc/factory.h"

using namespace spu::kernel;

spu::Value train_step(spu::SPUContext* ctx, const spu::Value& x,
const spu::Value& y, const spu::Value& w) {
// Padding x
-auto padding = hal::constant(ctx, 1.0F, spu::DT_F32, {x.shape()[0], 1});
-auto padded_x = hal::concatenate(ctx, {x, hal::seal(ctx, padding)}, 1);
-auto pred = hal::logistic(ctx, hal::matmul(ctx, padded_x, w));
+auto padding = hlo::Constant(ctx, 1.0F, {x.shape()[0], 1});
+auto padded_x = hlo::Concatenate(
+ctx, {x, hlo::Cast(ctx, padding, spu::VIS_SECRET, padding.dtype())}, 1);
+auto pred = hlo::Logistic(ctx, hlo::Dot(ctx, padded_x, w));

SPDLOG_DEBUG("[SSLR] Err = Pred - Y");
-auto err = hal::sub(ctx, pred, y);
+auto err = hlo::Sub(ctx, pred, y);

SPDLOG_DEBUG("[SSLR] Grad = X.t * Err");
-auto grad = hal::matmul(ctx, hal::transpose(ctx, padded_x), err);
+auto grad = hlo::Dot(ctx, hlo::Transpose(ctx, padded_x, {}), err);

SPDLOG_DEBUG("[SSLR] Step = LR / B * Grad");
-auto lr = hal::constant(ctx, 0.0001F, spu::DT_F32);
-auto msize =
-hal::constant(ctx, static_cast<float>(y.shape()[0]), spu::DT_F32);
-auto p1 = hal::mul(ctx, lr, hal::reciprocal(ctx, msize));
-auto step = hal::mul(ctx, hal::broadcast_to(ctx, p1, grad.shape()), grad);
+auto lr = hlo::Constant(ctx, 0.0001F, {});
+auto msize = hlo::Constant(ctx, static_cast<float>(y.shape()[0]), {});
+auto p1 = hlo::Mul(ctx, lr, hlo::Reciprocal(ctx, msize));
+auto step = hlo::Mul(ctx, hlo::Broadcast(ctx, p1, grad.shape(), {}), grad);

SPDLOG_DEBUG("[SSLR] W = W - Step");
-auto new_w = hal::sub(ctx, w, step);
+auto new_w = hlo::Sub(ctx, w, step);

return new_w;
}

spu::Value train(spu::SPUContext* ctx, const spu::Value& x, const spu::Value& y,
size_t num_epoch, size_t bsize) {
const size_t num_iter = x.shape()[0] / bsize;
-auto w = hal::constant(ctx, 0.0F, spu::DT_F32, {x.shape()[1] + 1, 1});
+auto w = hlo::Constant(ctx, 0.0F, {x.shape()[1] + 1, 1});

// Run train loop
for (size_t epoch = 0; epoch < num_epoch; ++epoch) {
@@ -73,10 +79,10 @@ spu::Value train(spu::SPUContext* ctx, const spu::Value& x, const spu::Value& y,
const int64_t rows_end = rows_beg + bsize;

const auto x_slice =
-hal::slice(ctx, x, {rows_beg, 0}, {rows_end, x.shape()[1]}, {});
+hlo::Slice(ctx, x, {rows_beg, 0}, {rows_end, x.shape()[1]}, {});

const auto y_slice =
-hal::slice(ctx, y, {rows_beg, 0}, {rows_end, y.shape()[1]}, {});
+hlo::Slice(ctx, y, {rows_beg, 0}, {rows_end, y.shape()[1]}, {});

w = train_step(ctx, x_slice, y_slice, w);
}
@@ -87,9 +93,10 @@ spu::Value train(spu::SPUContext* ctx, const spu::Value& x, const spu::Value& y,

spu::Value inference(spu::SPUContext* ctx, const spu::Value& x,
const spu::Value& weight) {
-auto padding = hal::constant(ctx, 1.0F, spu::DT_F32, {x.shape()[0], 1});
-auto padded_x = hal::concatenate(ctx, {x, hal::seal(ctx, padding)}, 1);
-return hal::matmul(ctx, padded_x, weight);
+auto padding = hlo::Constant(ctx, 1.0F, {x.shape()[0], 1});
+auto padded_x = hlo::Concatenate(
+ctx, {x, hlo::Cast(ctx, padding, spu::VIS_SECRET, padding.dtype())}, 1);
+return hlo::Dot(ctx, padded_x, weight);
}

float SSE(const xt::xarray<float>& y_true, const xt::xarray<float>& y_pred) {
@@ -143,7 +150,7 @@ std::pair<spu::Value, spu::Value> infeed(spu::SPUContext* sctx,
auto x = cio.deviceGetVar("x-0");
// Concatenate all slices
for (size_t idx = 1; idx < cio.getWorldSize(); ++idx) {
-x = hal::concatenate(sctx, {x, cio.deviceGetVar(fmt::format("x-{}", idx))},
+x = hlo::Concatenate(sctx, {x, cio.deviceGetVar(fmt::format("x-{}", idx))},
1);
}
auto y = cio.deviceGetVar("label");
@@ -175,10 +182,11 @@ int main(int argc, char** argv) {

const auto scores = inference(sctx.get(), x, w);

-xt::xarray<float> revealed_labels =
-hal::dump_public_as<float>(sctx.get(), hal::reveal(sctx.get(), y));
-xt::xarray<float> revealed_scores =
-hal::dump_public_as<float>(sctx.get(), hal::reveal(sctx.get(), scores));
+xt::xarray<float> revealed_labels = hal::dump_public_as<float>(
+sctx.get(), hlo::Cast(sctx.get(), y, spu::VIS_PUBLIC, y.dtype()));
+xt::xarray<float> revealed_scores = hal::dump_public_as<float>(
+sctx.get(),
+hlo::Cast(sctx.get(), scores, spu::VIS_PUBLIC, scores.dtype()));

auto mse = MSE(revealed_labels, revealed_scores);
std::cout << "MSE = " << mse << "\n";
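Taken together, the simple_lr.cc edits are a mechanical port from the low-level `hal::` kernels to the `hlo::` kernel layer. Two patterns recur: `hlo::Constant` takes only a value and a shape, inferring the dtype from the C++ literal instead of an explicit `spu::DT_F32`, and visibility changes go through `hlo::Cast` with an explicit `spu::VIS_SECRET` / `spu::VIS_PUBLIC` target rather than `hal::seal` / `hal::reveal`. A minimal sketch of the mapping, using only calls that appear in this diff (the helper function itself is illustrative, not part of the commit):

```cpp
#include "libspu/kernel/hlo/basic_binary.h"  // hlo::Sub
#include "libspu/kernel/hlo/casting.h"       // hlo::Cast
#include "libspu/kernel/hlo/const.h"         // hlo::Constant

using namespace spu::kernel;

// Illustrative only: seal a public all-ones tensor and subtract it from x,
// written against the hlo:: layer the way this diff rewrites train_step.
spu::Value SubtractSealedOnes(spu::SPUContext* ctx, const spu::Value& x) {
  // hal::constant(ctx, 1.0F, spu::DT_F32, shape) -> hlo::Constant(ctx, 1.0F, shape)
  auto ones = hlo::Constant(ctx, 1.0F, x.shape());
  // hal::seal(ctx, v) -> hlo::Cast(ctx, v, spu::VIS_SECRET, v.dtype())
  auto sealed = hlo::Cast(ctx, ones, spu::VIS_SECRET, ones.dtype());
  // hal::sub(ctx, a, b) -> hlo::Sub(ctx, a, b)
  return hlo::Sub(ctx, x, sealed);
}
```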
1 change: 0 additions & 1 deletion examples/cpp/simple_pphlo.cc
@@ -19,7 +19,6 @@
// clang-format on

#include "examples/cpp/utils.h"
#include "spdlog/spdlog.h"

#include "libspu/device/api.h"
#include "libspu/device/io.h"
4 changes: 2 additions & 2 deletions examples/cpp/utils.cc
@@ -14,8 +14,8 @@

#include "examples/cpp/utils.h"

#include "absl/strings/match.h"
#include "absl/strings/str_split.h"
#include "yacl/link/factory.h"

#include "libspu/core/config.h"

@@ -41,7 +41,7 @@ std::shared_ptr<yacl::link::Context> MakeLink(const std::string& parties,
std::vector<std::string> hosts = absl::StrSplit(parties, ',');
for (size_t rank = 0; rank < hosts.size(); rank++) {
const auto id = fmt::format("party{}", rank);
-lctx_desc.parties.push_back({id, hosts[rank]});
+lctx_desc.parties.emplace_back(id, hosts[rank]);
}
auto lctx = yacl::link::FactoryBrpc().CreateContext(lctx_desc, rank);
lctx->ConnectToMesh();
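The change in MakeLink above swaps `push_back({id, hosts[rank]})` for `emplace_back(id, hosts[rank])`: `emplace_back` forwards its arguments to the element's constructor and builds the object in place, whereas the `push_back` form first constructs a temporary from the braced initializer and then moves it into the vector. A standalone sketch of the difference (using `std::pair` as a stand-in for yacl's party descriptor, which is an assumption about its shape):

```cpp
#include <string>
#include <utility>
#include <vector>

int main() {
  std::vector<std::pair<std::string, std::string>> parties;
  const std::string id = "party0";
  const std::string host = "127.0.0.1:9920";  // illustrative address

  // push_back: the braced initializer builds a temporary pair, which is
  // then moved into the vector's storage.
  parties.push_back({id, host});

  // emplace_back: the arguments are forwarded to the pair constructor and
  // the element is constructed directly inside the vector.
  parties.emplace_back(id, host);
}
```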
2 changes: 1 addition & 1 deletion examples/python/ir_dump/ir_dump.py
@@ -28,8 +28,8 @@
import jax.numpy as jnp
import numpy as np

-import spu.utils.distributed as ppd
import spu.spu_pb2 as spu_pb2
+import spu.utils.distributed as ppd

logging.basicConfig(level=logging.INFO)

14 changes: 8 additions & 6 deletions examples/python/ml/flax_llama7b/flax_llama7b.py
@@ -19,18 +19,20 @@

import argparse
import json
+from contextlib import contextmanager
+from typing import Any, Optional, Tuple, Union

+import flax.linen as nn
import jax
-import jax.numpy as jnp
import jax.nn as jnn
-import flax.linen as nn
+import jax.numpy as jnp
+from EasyLM.models.llama.llama_model import FlaxLLaMAForCausalLM, LLaMAConfig
from flax.linen.linear import Array
-from typing import Any, Optional, Tuple, Union
from transformers import LlamaTokenizer
-from EasyLM.models.llama.llama_model import LLaMAConfig, FlaxLLaMAForCausalLM
-import spu.utils.distributed as ppd
-from contextlib import contextmanager

+import spu.intrinsic as intrinsic
+import spu.spu_pb2 as spu_pb2
+import spu.utils.distributed as ppd

parser = argparse.ArgumentParser(description='distributed driver.')
parser.add_argument("-c", "--config", default="examples/python/ml/flax_llama/3pc.json")
3 changes: 2 additions & 1 deletion examples/python/ml/flax_mlp/flax_mlp.py
@@ -106,9 +106,10 @@ def run_on_cpu():

ppd.init(conf["nodes"], conf["devices"])

-import cloudpickle as pickle
import tempfile
+
+import cloudpickle as pickle


def compute_score(param, type):
x_test, y_test = dsutil.breast_cancer(slice(None, None, None), False)
10 changes: 4 additions & 6 deletions examples/python/ml/flax_resnet/flax_resnet.py
@@ -17,19 +17,17 @@
# See issue #620.
# pytype: disable=wrong-arg-count

-from typing import Any
import argparse
import time
+from typing import Any

+import jax
+import jax.numpy as jnp
+import optax
import tensorflow as tf
import tensorflow_datasets as tfds
-
-import optax
from flax.training import train_state
-import jax.numpy as jnp
-import jax
from jax import random

from models import ResNet18

NUM_CLASSES = 10
2 changes: 1 addition & 1 deletion examples/python/ml/flax_resnet/models.py
@@ -20,8 +20,8 @@
from functools import partial
from typing import Any, Callable, Sequence, Tuple

-from flax import linen as nn
import jax.numpy as jnp
+from flax import linen as nn

ModuleDef = Any

4 changes: 2 additions & 2 deletions examples/python/ml/flax_vae/flax_vae.py
@@ -21,11 +21,11 @@
import optax
import tensorflow as tf
import tensorflow_datasets as tfds
+from flax import linen as nn
+from flax.training import train_state
from jax import random

import examples.python.ml.flax_vae.utils as vae_utils
-from flax import linen as nn
-from flax.training import train_state

# Replace absl.flags used by original authors with argparse for unittest
parser = argparse.ArgumentParser(description='distributed driver.')
5 changes: 3 additions & 2 deletions examples/python/ml/jraph_gnn/jraph_gnn.py
@@ -24,12 +24,12 @@

import logging

-from absl import app
import haiku as hk
import jax
import jax.numpy as jnp
import jraph
import optax
+from absl import app


def get_zacharys_karate_club() -> jraph.GraphsTuple:
@@ -252,9 +252,10 @@ def predict(params):


import argparse
-import spu.utils.distributed as ppd
import json
+
+import spu.utils.distributed as ppd

parser = argparse.ArgumentParser(description="distributed driver.")
parser.add_argument("-c", "--config", default="examples/python/conf/3pc.json")
args = parser.parse_args()
5 changes: 2 additions & 3 deletions examples/python/ml/ml_test.py
@@ -13,21 +13,20 @@
# limitations under the License.


+import inspect
import json
import logging
+import os
import sys
import unittest
from time import perf_counter
-import os

import multiprocess
import numpy.testing as npt
import pandas as pd
-import inspect

import spu.utils.distributed as ppd


with open("examples/python/conf/3pc.json", 'r') as file:
conf = json.load(file)
