diff --git a/benchmark/binary_op_bench.py b/benchmark/binary_op_bench.py
index b048fae8..ded577b1 100644
--- a/benchmark/binary_op_bench.py
+++ b/benchmark/binary_op_bench.py
@@ -14,10 +14,10 @@

 import argparse
 import json
+import time

 import jax.numpy as jnp
 import numpy as np
-import time

 import spu.utils.distributed as ppd

diff --git a/benchmark/unary_op_bench.py b/benchmark/unary_op_bench.py
index dd66462b..9c4e4d79 100644
--- a/benchmark/unary_op_bench.py
+++ b/benchmark/unary_op_bench.py
@@ -14,10 +14,10 @@

 import argparse
 import json
+import time

 import jax.numpy as jnp
 import numpy as np
-import time

 import spu.utils.distributed as ppd

diff --git a/docs/reference/gen_benchmark_report.py b/docs/reference/gen_benchmark_report.py
index 04fb397e..11bbb5ae 100644
--- a/docs/reference/gen_benchmark_report.py
+++ b/docs/reference/gen_benchmark_report.py
@@ -15,12 +15,13 @@
 # limitations under the License.

+import argparse
 import json
-import pandas as pd
-import numpy as np
 import os
 from enum import Enum
-import argparse
+
+import numpy as np
+import pandas as pd

 g_time_list = [
     ('ns', 1000),
diff --git a/docs/reference/gen_complexity_md.py b/docs/reference/gen_complexity_md.py
index a4ea1ccc..a7527178 100644
--- a/docs/reference/gen_complexity_md.py
+++ b/docs/reference/gen_complexity_md.py
@@ -16,9 +16,10 @@

 import argparse
-from pytablewriter import MarkdownTableWriter
 import json

+from pytablewriter import MarkdownTableWriter
+

 def main():
     parser = argparse.ArgumentParser(
diff --git a/docs/reference/gen_np_op_status_doc.py b/docs/reference/gen_np_op_status_doc.py
index 3a3808a3..d13c1213 100644
--- a/docs/reference/gen_np_op_status_doc.py
+++ b/docs/reference/gen_np_op_status_doc.py
@@ -16,9 +16,10 @@

 import argparse
-from pytablewriter import MarkdownTableWriter
 import json
+
 from mdutils.mdutils import MdUtils
+from pytablewriter import MarkdownTableWriter


 def main():
diff --git a/docs/tutorials/cpp_lr_example.rst b/docs/tutorials/cpp_lr_example.rst
index 6d83a4ed..87bec61d 100644
--- a/docs/tutorials/cpp_lr_example.rst
+++ b/docs/tutorials/cpp_lr_example.rst
@@ -20,11 +20,11 @@
 In the first terminal.

 .. code-block:: bash

-    bazel run //examples/cpp:simple_lr -- -rank 0 -dataset examples/cpp/data/perfect_logit_a.csv -has_label=true
+    bazel run //examples/cpp:simple_lr -- -rank 0 -dataset examples/cpp/perfect_logit_a.csv -has_label=true

 In the second terminal.

 .. code-block:: bash

-    bazel run //examples/cpp:simple_lr -- -rank 1 -dataset examples/cpp/data/perfect_logit_b.csv
+    bazel run //examples/cpp:simple_lr -- -rank 1 -dataset examples/cpp/perfect_logit_b.csv
diff --git a/examples/cpp/BUILD.bazel b/examples/cpp/BUILD.bazel
index 5eabbf99..93c5f79c 100644
--- a/examples/cpp/BUILD.bazel
+++ b/examples/cpp/BUILD.bazel
@@ -25,7 +25,12 @@ spu_cc_binary(
     deps = [
         ":utils",
         "//libspu/device:io",
-        "//libspu/kernel/hal",
+        "//libspu/kernel/hal:public_helper",
+        "//libspu/kernel/hlo:basic_binary",
+        "//libspu/kernel/hlo:basic_unary",
+        "//libspu/kernel/hlo:casting",
+        "//libspu/kernel/hlo:const",
+        "//libspu/kernel/hlo:geometrical",
         "@com_google_absl//absl/strings",
         "@llvm-project//llvm:Support",
         "@yacl//yacl/link:factory",
diff --git a/examples/cpp/simple_lr.cc b/examples/cpp/simple_lr.cc
index fa2f4fba..77a076d7 100644
--- a/examples/cpp/simple_lr.cc
+++ b/examples/cpp/simple_lr.cc
@@ -24,11 +24,17 @@

 #include "examples/cpp/utils.h"
 #include "spdlog/spdlog.h"
+#include "xtensor/xarray.hpp"
 #include "xtensor/xcsv.hpp"
+#include "xtensor/xview.hpp"

 #include "libspu/device/io.h"
-#include "libspu/kernel/hal/hal.h"
-#include "libspu/kernel/hal/type_cast.h"
+#include "libspu/kernel/hal/public_helper.h"
+#include "libspu/kernel/hlo/basic_binary.h"
+#include "libspu/kernel/hlo/basic_unary.h"
+#include "libspu/kernel/hlo/casting.h"
+#include "libspu/kernel/hlo/const.h"
+#include "libspu/kernel/hlo/geometrical.h"
 #include "libspu/mpc/factory.h"

 using namespace spu::kernel;
@@ -36,25 +42,25 @@ using namespace spu::kernel;
 spu::Value train_step(spu::SPUContext* ctx, const spu::Value& x,
                       const spu::Value& y, const spu::Value& w) {
   // Padding x
-  auto padding = hal::constant(ctx, 1.0F, spu::DT_F32, {x.shape()[0], 1});
-  auto padded_x = hal::concatenate(ctx, {x, hal::seal(ctx, padding)}, 1);
-  auto pred = hal::logistic(ctx, hal::matmul(ctx, padded_x, w));
+  auto padding = hlo::Constant(ctx, 1.0F, {x.shape()[0], 1});
+  auto padded_x = hlo::Concatenate(
+      ctx, {x, hlo::Cast(ctx, padding, spu::VIS_SECRET, padding.dtype())}, 1);
+  auto pred = hlo::Logistic(ctx, hlo::Dot(ctx, padded_x, w));

   SPDLOG_DEBUG("[SSLR] Err = Pred - Y");
-  auto err = hal::sub(ctx, pred, y);
+  auto err = hlo::Sub(ctx, pred, y);

   SPDLOG_DEBUG("[SSLR] Grad = X.t * Err");
-  auto grad = hal::matmul(ctx, hal::transpose(ctx, padded_x), err);
+  auto grad = hlo::Dot(ctx, hlo::Transpose(ctx, padded_x, {}), err);

   SPDLOG_DEBUG("[SSLR] Step = LR / B * Grad");
-  auto lr = hal::constant(ctx, 0.0001F, spu::DT_F32);
-  auto msize =
-      hal::constant(ctx, static_cast<float>(y.shape()[0]), spu::DT_F32);
-  auto p1 = hal::mul(ctx, lr, hal::reciprocal(ctx, msize));
-  auto step = hal::mul(ctx, hal::broadcast_to(ctx, p1, grad.shape()), grad);
+  auto lr = hlo::Constant(ctx, 0.0001F, {});
+  auto msize = hlo::Constant(ctx, static_cast<float>(y.shape()[0]), {});
+  auto p1 = hlo::Mul(ctx, lr, hlo::Reciprocal(ctx, msize));
+  auto step = hlo::Mul(ctx, hlo::Broadcast(ctx, p1, grad.shape(), {}), grad);

   SPDLOG_DEBUG("[SSLR] W = W - Step");
-  auto new_w = hal::sub(ctx, w, step);
+  auto new_w = hlo::Sub(ctx, w, step);

   return new_w;
 }
@@ -62,7 +68,7 @@ spu::Value train_step(spu::SPUContext* ctx, const spu::Value& x,
 spu::Value train(spu::SPUContext* ctx, const spu::Value& x, const spu::Value& y,
                  size_t num_epoch, size_t bsize) {
   const size_t num_iter = x.shape()[0] / bsize;
-  auto w = hal::constant(ctx, 0.0F, spu::DT_F32, {x.shape()[1] + 1, 1});
+  auto w = hlo::Constant(ctx, 0.0F, {x.shape()[1] + 1, 1});

   // Run train loop
   for (size_t epoch = 0; epoch < num_epoch; ++epoch) {
@@ -73,10 +79,10 @@ spu::Value train(spu::SPUContext* ctx, const spu::Value& x, const spu::Value& y,
       const int64_t rows_end = rows_beg + bsize;

       const auto x_slice =
-          hal::slice(ctx, x, {rows_beg, 0}, {rows_end, x.shape()[1]}, {});
+          hlo::Slice(ctx, x, {rows_beg, 0}, {rows_end, x.shape()[1]}, {});
       const auto y_slice =
-          hal::slice(ctx, y, {rows_beg, 0}, {rows_end, y.shape()[1]}, {});
+          hlo::Slice(ctx, y, {rows_beg, 0}, {rows_end, y.shape()[1]}, {});

       w = train_step(ctx, x_slice, y_slice, w);
     }
@@ -87,9 +93,10 @@ spu::Value train(spu::SPUContext* ctx, const spu::Value& x, const spu::Value& y,

 spu::Value inference(spu::SPUContext* ctx, const spu::Value& x,
                      const spu::Value& weight) {
-  auto padding = hal::constant(ctx, 1.0F, spu::DT_F32, {x.shape()[0], 1});
-  auto padded_x = hal::concatenate(ctx, {x, hal::seal(ctx, padding)}, 1);
-  return hal::matmul(ctx, padded_x, weight);
+  auto padding = hlo::Constant(ctx, 1.0F, {x.shape()[0], 1});
+  auto padded_x = hlo::Concatenate(
+      ctx, {x, hlo::Cast(ctx, padding, spu::VIS_SECRET, padding.dtype())}, 1);
+  return hlo::Dot(ctx, padded_x, weight);
 }

 float SSE(const xt::xarray<float>& y_true, const xt::xarray<float>& y_pred) {
@@ -143,7 +150,7 @@ std::pair<spu::Value, spu::Value> infeed(spu::SPUContext* sctx,
   auto x = cio.deviceGetVar("x-0");
   // Concatenate all slices
   for (size_t idx = 1; idx < cio.getWorldSize(); ++idx) {
-    x = hal::concatenate(sctx, {x, cio.deviceGetVar(fmt::format("x-{}", idx))},
+    x = hlo::Concatenate(sctx, {x, cio.deviceGetVar(fmt::format("x-{}", idx))},
                          1);
   }
   auto y = cio.deviceGetVar("label");
@@ -175,10 +182,11 @@ int main(int argc, char** argv) {

   const auto scores = inference(sctx.get(), x, w);

-  xt::xarray<float> revealed_labels =
-      hal::dump_public_as<float>(sctx.get(), hal::reveal(sctx.get(), y));
-  xt::xarray<float> revealed_scores =
-      hal::dump_public_as<float>(sctx.get(), hal::reveal(sctx.get(), scores));
+  xt::xarray<float> revealed_labels = hal::dump_public_as<float>(
+      sctx.get(), hlo::Cast(sctx.get(), y, spu::VIS_PUBLIC, y.dtype()));
+  xt::xarray<float> revealed_scores = hal::dump_public_as<float>(
+      sctx.get(),
+      hlo::Cast(sctx.get(), scores, spu::VIS_PUBLIC, scores.dtype()));

   auto mse = MSE(revealed_labels, revealed_scores);
   std::cout << "MSE = " << mse << "\n";
diff --git a/examples/cpp/simple_pphlo.cc b/examples/cpp/simple_pphlo.cc
index 71f36dbc..b95a4c95 100644
--- a/examples/cpp/simple_pphlo.cc
+++ b/examples/cpp/simple_pphlo.cc
@@ -19,7 +19,6 @@
 // clang-format on

 #include "examples/cpp/utils.h"
-#include "spdlog/spdlog.h"

 #include "libspu/device/api.h"
 #include "libspu/device/io.h"
diff --git a/examples/cpp/utils.cc b/examples/cpp/utils.cc
index 0bca8629..368d4a3f 100644
--- a/examples/cpp/utils.cc
+++ b/examples/cpp/utils.cc
@@ -14,8 +14,8 @@

 #include "examples/cpp/utils.h"

-#include "absl/strings/match.h"
 #include "absl/strings/str_split.h"
+#include "yacl/link/factory.h"

 #include "libspu/core/config.h"

@@ -41,7 +41,7 @@ std::shared_ptr<yacl::link::Context> MakeLink(const std::string& parties,
   std::vector<std::string> hosts = absl::StrSplit(parties, ',');
   for (size_t rank = 0; rank < hosts.size(); rank++) {
     const auto id = fmt::format("party{}", rank);
-    lctx_desc.parties.push_back({id, hosts[rank]});
+    lctx_desc.parties.emplace_back(id, hosts[rank]);
   }
   auto lctx = yacl::link::FactoryBrpc().CreateContext(lctx_desc, rank);
   lctx->ConnectToMesh();
diff --git a/examples/python/ir_dump/ir_dump.py b/examples/python/ir_dump/ir_dump.py
index fd081657..4e85eace 100644
--- a/examples/python/ir_dump/ir_dump.py
+++ b/examples/python/ir_dump/ir_dump.py
@@ -28,8 +28,8 @@

 import jax.numpy as jnp
 import numpy as np

-import 
spu.utils.distributed as ppd import spu.spu_pb2 as spu_pb2 +import spu.utils.distributed as ppd logging.basicConfig(level=logging.INFO) diff --git a/examples/python/ml/flax_llama7b/flax_llama7b.py b/examples/python/ml/flax_llama7b/flax_llama7b.py index 80ff47b8..88754fbd 100644 --- a/examples/python/ml/flax_llama7b/flax_llama7b.py +++ b/examples/python/ml/flax_llama7b/flax_llama7b.py @@ -19,18 +19,20 @@ import argparse import json +from contextlib import contextmanager +from typing import Any, Optional, Tuple, Union + +import flax.linen as nn import jax -import jax.numpy as jnp import jax.nn as jnn -import flax.linen as nn +import jax.numpy as jnp +from EasyLM.models.llama.llama_model import FlaxLLaMAForCausalLM, LLaMAConfig from flax.linen.linear import Array -from typing import Any, Optional, Tuple, Union from transformers import LlamaTokenizer -from EasyLM.models.llama.llama_model import LLaMAConfig, FlaxLLaMAForCausalLM -import spu.utils.distributed as ppd -from contextlib import contextmanager + import spu.intrinsic as intrinsic import spu.spu_pb2 as spu_pb2 +import spu.utils.distributed as ppd parser = argparse.ArgumentParser(description='distributed driver.') parser.add_argument("-c", "--config", default="examples/python/ml/flax_llama/3pc.json") diff --git a/examples/python/ml/flax_mlp/flax_mlp.py b/examples/python/ml/flax_mlp/flax_mlp.py index 0576b95f..c146a8f1 100644 --- a/examples/python/ml/flax_mlp/flax_mlp.py +++ b/examples/python/ml/flax_mlp/flax_mlp.py @@ -106,9 +106,10 @@ def run_on_cpu(): ppd.init(conf["nodes"], conf["devices"]) -import cloudpickle as pickle import tempfile +import cloudpickle as pickle + def compute_score(param, type): x_test, y_test = dsutil.breast_cancer(slice(None, None, None), False) diff --git a/examples/python/ml/flax_resnet/flax_resnet.py b/examples/python/ml/flax_resnet/flax_resnet.py index fcce8606..2946ade5 100644 --- a/examples/python/ml/flax_resnet/flax_resnet.py +++ b/examples/python/ml/flax_resnet/flax_resnet.py @@ -17,19 +17,17 @@ # See issue #620. 
# pytype: disable=wrong-arg-count -from typing import Any import argparse import time +from typing import Any +import jax +import jax.numpy as jnp +import optax import tensorflow as tf import tensorflow_datasets as tfds - -import optax from flax.training import train_state -import jax.numpy as jnp -import jax from jax import random - from models import ResNet18 NUM_CLASSES = 10 diff --git a/examples/python/ml/flax_resnet/models.py b/examples/python/ml/flax_resnet/models.py index 667dc79a..4dc2982b 100644 --- a/examples/python/ml/flax_resnet/models.py +++ b/examples/python/ml/flax_resnet/models.py @@ -20,8 +20,8 @@ from functools import partial from typing import Any, Callable, Sequence, Tuple -from flax import linen as nn import jax.numpy as jnp +from flax import linen as nn ModuleDef = Any diff --git a/examples/python/ml/flax_vae/flax_vae.py b/examples/python/ml/flax_vae/flax_vae.py index beeb07fa..7251fd2a 100644 --- a/examples/python/ml/flax_vae/flax_vae.py +++ b/examples/python/ml/flax_vae/flax_vae.py @@ -21,11 +21,11 @@ import optax import tensorflow as tf import tensorflow_datasets as tfds +from flax import linen as nn +from flax.training import train_state from jax import random import examples.python.ml.flax_vae.utils as vae_utils -from flax import linen as nn -from flax.training import train_state # Replace absl.flags used by original authors with argparse for unittest parser = argparse.ArgumentParser(description='distributed driver.') diff --git a/examples/python/ml/jraph_gnn/jraph_gnn.py b/examples/python/ml/jraph_gnn/jraph_gnn.py index 85fe499c..d82dfa3b 100644 --- a/examples/python/ml/jraph_gnn/jraph_gnn.py +++ b/examples/python/ml/jraph_gnn/jraph_gnn.py @@ -24,12 +24,12 @@ import logging -from absl import app import haiku as hk import jax import jax.numpy as jnp import jraph import optax +from absl import app def get_zacharys_karate_club() -> jraph.GraphsTuple: @@ -252,9 +252,10 @@ def predict(params): import argparse -import spu.utils.distributed as ppd import json +import spu.utils.distributed as ppd + parser = argparse.ArgumentParser(description="distributed driver.") parser.add_argument("-c", "--config", default="examples/python/conf/3pc.json") args = parser.parse_args() diff --git a/examples/python/ml/ml_test.py b/examples/python/ml/ml_test.py index 3c2b8d18..7526b2b9 100644 --- a/examples/python/ml/ml_test.py +++ b/examples/python/ml/ml_test.py @@ -13,21 +13,20 @@ # limitations under the License. 
+import inspect import json import logging +import os import sys import unittest from time import perf_counter -import os import multiprocess import numpy.testing as npt import pandas as pd -import inspect import spu.utils.distributed as ppd - with open("examples/python/conf/3pc.json", 'r') as file: conf = json.load(file) diff --git a/examples/python/ml/ss_lr/ss_lr.py b/examples/python/ml/ss_lr/ss_lr.py index bee4c33d..eb5c1f2f 100644 --- a/examples/python/ml/ss_lr/ss_lr.py +++ b/examples/python/ml/ss_lr/ss_lr.py @@ -20,17 +20,17 @@ import argparse import json -import time import logging +import time from enum import Enum -from sklearn.metrics import roc_auc_score, explained_variance_score from typing import Dict, List -import examples.python.utils.dataset_utils as dsutil - import jax.numpy as jnp import numpy as np +from sklearn.metrics import explained_variance_score, roc_auc_score + import examples.python.utils.appr_sigmoid as Sigmoid +import examples.python.utils.dataset_utils as dsutil import spu.utils.distributed as ppd diff --git a/examples/python/ml/ss_xgb/ss_xgb.py b/examples/python/ml/ss_xgb/ss_xgb.py index c1279544..836d6980 100644 --- a/examples/python/ml/ss_xgb/ss_xgb.py +++ b/examples/python/ml/ss_xgb/ss_xgb.py @@ -20,21 +20,19 @@ import argparse import json -from statistics import mode import time - -from sklearn.metrics import roc_auc_score - +from functools import reduce +from statistics import mode from typing import Any, Dict, List, Tuple + import jax.numpy as jnp import numpy as np import pandas as pd -from functools import reduce +from sklearn.metrics import roc_auc_score -import examples.python.utils.dataset_utils as dsutil import examples.python.utils.appr_sigmoid as Sigmoid +import examples.python.utils.dataset_utils as dsutil import spu.utils.distributed as ppd - from spu.utils.distributed import PYU, SPU parser = argparse.ArgumentParser(description='distributed driver.') diff --git a/examples/python/ml/stax_mnist_classifier/stax_mnist_classifier.py b/examples/python/ml/stax_mnist_classifier/stax_mnist_classifier.py index 98990398..b6af9fbe 100644 --- a/examples/python/ml/stax_mnist_classifier/stax_mnist_classifier.py +++ b/examples/python/ml/stax_mnist_classifier/stax_mnist_classifier.py @@ -25,16 +25,15 @@ # Run this example script. 
# > bazel run -c opt //examples/python/ml/stax_mnist_classifier:stax_mnist_classifier -import time import itertools +import time +import jax.numpy as jnp import numpy.random as npr +from jax import grad, jit, random +from jax.example_libraries import optimizers, stax +from jax.example_libraries.stax import Conv, Dense, Flatten, LogSoftmax, MaxPool, Relu -import jax.numpy as jnp -from jax import jit, grad, random -from jax.example_libraries import optimizers -from jax.example_libraries import stax -from jax.example_libraries.stax import Dense, Relu, LogSoftmax, Flatten, Conv, MaxPool import examples.python.utils.dataset_utils as datasets @@ -129,6 +128,7 @@ def run_cpu(): def run_spu(): import argparse import json + import spu.utils.distributed as ppd parser = argparse.ArgumentParser(description='distributed driver.') diff --git a/examples/python/ml/stax_nn/models.py b/examples/python/ml/stax_nn/models.py index b56fd15a..23d2a45d 100644 --- a/examples/python/ml/stax_nn/models.py +++ b/examples/python/ml/stax_nn/models.py @@ -15,16 +15,16 @@ from jax.example_libraries import stax from jax.example_libraries.stax import ( - Conv, - MaxPool, AvgPool, - Flatten, + BatchNorm, + Conv, Dense, + Flatten, + LogSoftmax, + MaxPool, Relu, Sigmoid, - LogSoftmax, Softmax, - BatchNorm, ) diff --git a/examples/python/ml/tf_experiment/tf_experiment.py b/examples/python/ml/tf_experiment/tf_experiment.py index 20331590..2eb058fd 100644 --- a/examples/python/ml/tf_experiment/tf_experiment.py +++ b/examples/python/ml/tf_experiment/tf_experiment.py @@ -22,6 +22,7 @@ import numpy as np import tensorflow as tf from sklearn import metrics + import spu.utils.distributed as ppd # This example is to show tf program could be converted to XLA IR and run by SPU. diff --git a/examples/python/pir/pir_client.py b/examples/python/pir/pir_client.py index 48a924bc..274133f4 100644 --- a/examples/python/pir/pir_client.py +++ b/examples/python/pir/pir_client.py @@ -19,9 +19,9 @@ from absl import app, flags -import spu.pir as pir import spu.libspu.link as link import spu.libspu.logging as logging +import spu.pir as pir flags.DEFINE_integer("rank", 0, "rank: 0/1/2...") flags.DEFINE_string("party_ips", "127.0.0.1:9307,127.0.0.1:9308", "party addresses") diff --git a/examples/python/pir/pir_mem_server.py b/examples/python/pir/pir_mem_server.py index 1429648c..fd5f7ea9 100644 --- a/examples/python/pir/pir_mem_server.py +++ b/examples/python/pir/pir_mem_server.py @@ -18,13 +18,13 @@ # > --count_per_query 1 -max_label_length 256 \ # > --oprf_key_path oprf_key.bin --setup_path setup_path +import time + from absl import app, flags -import spu.pir as pir import spu.libspu.link as link import spu.libspu.logging as logging - -import time +import spu.pir as pir flags.DEFINE_integer("rank", 0, "rank: 0/1/2...") flags.DEFINE_string("party_ips", "127.0.0.1:9307,127.0.0.1:9308", "party addresses") diff --git a/examples/python/pir/pir_server.py b/examples/python/pir/pir_server.py index 538c286b..6ed219f2 100644 --- a/examples/python/pir/pir_server.py +++ b/examples/python/pir/pir_server.py @@ -19,9 +19,9 @@ from absl import app, flags -import spu.pir as pir import spu.libspu.link as link import spu.libspu.logging as logging +import spu.pir as pir flags.DEFINE_integer("rank", 0, "rank: 0/1/2...") flags.DEFINE_string("party_ips", "127.0.0.1:9307,127.0.0.1:9308", "party addresses") diff --git a/examples/python/pir/pir_setup.py b/examples/python/pir/pir_setup.py index 848a5b09..bed2f599 100644 --- a/examples/python/pir/pir_setup.py +++ 
b/examples/python/pir/pir_setup.py @@ -20,9 +20,9 @@ from absl import app, flags -import spu.pir as pir import spu.libspu.link as link import spu.libspu.logging as logging +import spu.pir as pir flags.DEFINE_string("in_path", "data.csv", "data input path") flags.DEFINE_string("key_columns", "id", "csv file key filed name") diff --git a/examples/python/psi/mem_psi.py b/examples/python/psi/mem_psi.py index 38b94d3b..f56ee944 100644 --- a/examples/python/psi/mem_psi.py +++ b/examples/python/psi/mem_psi.py @@ -17,12 +17,12 @@ # > bazel run //examples/python/psi:mem_psi -- --rank 0 --protocol ECDH_PSI_2PC --in_path examples/data/psi_1.csv --field_name id --out_path /tmp/p1.out # > bazel run //examples/python/psi:mem_psi -- --rank 1 --protocol ECDH_PSI_2PC --in_path examples/data/psi_2.csv --field_name id --out_path /tmp/p2.out +import pandas as pd from absl import app, flags -import pandas as pd -import spu.psi as psi import spu.libspu.link as link import spu.libspu.logging as logging +import spu.psi as psi flags.DEFINE_string("protocol", "ECDH_PSI_2PC", "psi protocol, see `spu/psi/psi.proto`") flags.DEFINE_integer("rank", 0, "rank: 0/1/2...") diff --git a/examples/python/psi/simple_psi.py b/examples/python/psi/simple_psi.py index d593df45..5ef20a05 100644 --- a/examples/python/psi/simple_psi.py +++ b/examples/python/psi/simple_psi.py @@ -19,9 +19,9 @@ from absl import app, flags -import spu.psi as psi import spu.libspu.link as link import spu.libspu.logging as logging +import spu.psi as psi flags.DEFINE_string("protocol", "ECDH_PSI_2PC", "psi protocol, see `spu/psi/psi.proto`") flags.DEFINE_integer("rank", 0, "rank: 0/1/2...") diff --git a/examples/python/psi/unbalanced_psi.py b/examples/python/psi/unbalanced_psi.py index fb03f322..dbe2bcdb 100644 --- a/examples/python/psi/unbalanced_psi.py +++ b/examples/python/psi/unbalanced_psi.py @@ -17,11 +17,12 @@ # > bazel run //examples/python/psi:unbalanced_psi -- --rank 0 --in_path examples/data/psi_1.csv --field_names id --out_path /tmp/p1.out # > bazel run //examples/python/psi:unbalanced_psi -- --rank 1 --in_path examples/data/psi_2.csv --field_names id --out_path /tmp/p2.out +import time + from absl import app, flags -import spu.psi as psi import spu.libspu.link as link -import time +import spu.psi as psi flags.DEFINE_integer("rank", 0, "rank: 0/1/2...") flags.DEFINE_string("in_path", "data.csv", "data input path") diff --git a/examples/python/utils/dataset_utils.py b/examples/python/utils/dataset_utils.py index f706bfa2..4bb4e5c1 100644 --- a/examples/python/utils/dataset_utils.py +++ b/examples/python/utils/dataset_utils.py @@ -12,13 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import numpy as np import array import gzip import os -from os import path import struct import urllib.request +from os import path + +import numpy as np def standardize(data): diff --git a/examples/python/utils/nodectl.py b/examples/python/utils/nodectl.py index d1661428..52f6c497 100644 --- a/examples/python/utils/nodectl.py +++ b/examples/python/utils/nodectl.py @@ -14,14 +14,10 @@ import argparse import json -import multiprocess -import sys -import grpc -import jax +import multiprocess import spu.utils.distributed as ppd -from spu.utils import distributed_pb2_grpc parser = argparse.ArgumentParser(description='SPU node service.') parser.add_argument( diff --git a/libspu/compiler/codegen/codegen.cc b/libspu/compiler/codegen/codegen.cc index 03e316fc..1c628b9c 100644 --- a/libspu/compiler/codegen/codegen.cc +++ b/libspu/compiler/codegen/codegen.cc @@ -14,9 +14,7 @@ #include "libspu/compiler/codegen/codegen.h" -#include - -#include "llvm/Support/raw_os_ostream.h" +#include "llvm/Support/raw_ostream.h" namespace spu::compiler { diff --git a/libspu/compiler/codegen/codegen.h b/libspu/compiler/codegen/codegen.h index f63a773a..649bc7f4 100644 --- a/libspu/compiler/codegen/codegen.h +++ b/libspu/compiler/codegen/codegen.h @@ -15,8 +15,6 @@ #pragma once #include -#include -#include #include "mlir/IR/BuiltinOps.h" diff --git a/libspu/compiler/common/compilation_context.h b/libspu/compiler/common/compilation_context.h index 25e1b59e..dec345fd 100644 --- a/libspu/compiler/common/compilation_context.h +++ b/libspu/compiler/common/compilation_context.h @@ -17,7 +17,6 @@ #include #include #include -#include #include "mlir/IR/MLIRContext.h" #include "mlir/Pass/PassManager.h" diff --git a/libspu/compiler/common/ir_printer_config.cc b/libspu/compiler/common/ir_printer_config.cc index 02432a16..47dff64a 100644 --- a/libspu/compiler/common/ir_printer_config.cc +++ b/libspu/compiler/common/ir_printer_config.cc @@ -16,8 +16,7 @@ #include -#include "fmt/chrono.h" -#include "fmt/format.h" +#include "fmt/chrono.h" // IWYU pragma: keep, format chrono needs this header #include "llvm/Support/FileSystem.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Pass/Pass.h" diff --git a/libspu/compiler/common/ir_printer_config.h b/libspu/compiler/common/ir_printer_config.h index 69626e61..63b69d73 100644 --- a/libspu/compiler/common/ir_printer_config.h +++ b/libspu/compiler/common/ir_printer_config.h @@ -16,7 +16,6 @@ #include #include -#include #include "mlir/Pass/PassManager.h" diff --git a/libspu/compiler/compile.cc b/libspu/compiler/compile.cc index 0d93b1b0..9e6cb728 100644 --- a/libspu/compiler/compile.cc +++ b/libspu/compiler/compile.cc @@ -14,18 +14,11 @@ #include "libspu/compiler/compile.h" -#include -#include -#include - #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/MLIRContext.h" -#include "spdlog/spdlog.h" #include "libspu/compiler/codegen/codegen.h" #include "libspu/compiler/core/core.h" #include "libspu/compiler/front_end/fe.h" -#include "libspu/core/prelude.h" namespace spu::compiler { @@ -40,9 +33,7 @@ std::string compile(CompilationContext *ctx, core.doit(mlir_module.get()); // Run codegen - CodeGen codegen; - - return codegen.doit(mlir_module.get()); + return spu::compiler::CodeGen::doit(mlir_module.get()); } } // namespace spu::compiler diff --git a/libspu/compiler/front_end/hlo_importer.cc b/libspu/compiler/front_end/hlo_importer.cc index 977d9c98..2a9f08e2 100644 --- a/libspu/compiler/front_end/hlo_importer.cc +++ b/libspu/compiler/front_end/hlo_importer.cc @@ -14,7 +14,6 @@ #include 
"libspu/compiler/front_end/hlo_importer.h" -#include "spdlog/spdlog.h" #include "xla/service/algebraic_simplifier.h" #include "xla/service/batch_dot_simplification.h" #include "xla/service/batchnorm_expander.h" diff --git a/libspu/compiler/passes/decompose_comparison.cc b/libspu/compiler/passes/decompose_comparison.cc index 4f9550c3..f15337d5 100644 --- a/libspu/compiler/passes/decompose_comparison.cc +++ b/libspu/compiler/passes/decompose_comparison.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include - #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" diff --git a/libspu/compiler/passes/decompose_minmax.cc b/libspu/compiler/passes/decompose_minmax.cc index 2cab1f7e..490a7509 100644 --- a/libspu/compiler/passes/decompose_minmax.cc +++ b/libspu/compiler/passes/decompose_minmax.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include - #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" diff --git a/libspu/compiler/passes/hlo_legalize_to_pphlo.cc b/libspu/compiler/passes/hlo_legalize_to_pphlo.cc index 501e2090..aada80e1 100644 --- a/libspu/compiler/passes/hlo_legalize_to_pphlo.cc +++ b/libspu/compiler/passes/hlo_legalize_to_pphlo.cc @@ -28,7 +28,6 @@ #include "mlir/Transforms/DialectConversion.h" #include "stablehlo/dialect/StablehloOps.h" -#include "libspu/compiler/common/compilation_context.h" #include "libspu/compiler/passes/map_stablehlo_to_pphlo_op.h" #include "libspu/compiler/passes/pass_details.h" #include "libspu/compiler/passes/value_visibility_map.h" diff --git a/libspu/compiler/passes/optimize_select.cc b/libspu/compiler/passes/optimize_select.cc index 67f797dd..44765619 100644 --- a/libspu/compiler/passes/optimize_select.cc +++ b/libspu/compiler/passes/optimize_select.cc @@ -19,7 +19,6 @@ #include "libspu/compiler/passes/pass_details.h" #include "libspu/compiler/passes/passes.h" #include "libspu/dialect/pphlo_ops.h" -#include "libspu/dialect/pphlo_types.h" namespace mlir::pphlo { diff --git a/libspu/compiler/passes/optimize_sqrt_plus_eps.cc b/libspu/compiler/passes/optimize_sqrt_plus_eps.cc index 86c874e8..63930aae 100644 --- a/libspu/compiler/passes/optimize_sqrt_plus_eps.cc +++ b/libspu/compiler/passes/optimize_sqrt_plus_eps.cc @@ -12,13 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include #include #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "spdlog/spdlog.h" #include "libspu/compiler/passes/pass_details.h" #include "libspu/dialect/pphlo_ops.h" diff --git a/libspu/compiler/passes/pass_details.h b/libspu/compiler/passes/pass_details.h index 3960ebf8..6dccf126 100644 --- a/libspu/compiler/passes/pass_details.h +++ b/libspu/compiler/passes/pass_details.h @@ -17,7 +17,6 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "libspu/compiler/passes/passes.h" -#include "libspu/dialect/pphlo_base_enums.h" #include "libspu/dialect/pphlo_dialect.h" namespace mlir::pphlo { diff --git a/libspu/compiler/passes/rewrite_div_sqrt_patterns.cc b/libspu/compiler/passes/rewrite_div_sqrt_patterns.cc index 3c732384..a975799e 100644 --- a/libspu/compiler/passes/rewrite_div_sqrt_patterns.cc +++ b/libspu/compiler/passes/rewrite_div_sqrt_patterns.cc @@ -18,7 +18,6 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "spdlog/spdlog.h" #include "libspu/compiler/passes/pass_details.h" #include "libspu/dialect/pphlo_ops.h" diff --git a/libspu/compiler/passes/sort_lowering.cc b/libspu/compiler/passes/sort_lowering.cc index c6a51828..4ffb3a84 100644 --- a/libspu/compiler/passes/sort_lowering.cc +++ b/libspu/compiler/passes/sort_lowering.cc @@ -15,12 +15,10 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "spdlog/spdlog.h" #include "libspu/compiler/passes/pass_details.h" #include "libspu/compiler/passes/passes.h" #include "libspu/dialect/pphlo_ops.h" -#include "libspu/dialect/pphlo_types.h" namespace mlir::pphlo { diff --git a/libspu/compiler/passes/value_visibility_map.cc b/libspu/compiler/passes/value_visibility_map.cc index 733f35a6..6efd0d83 100644 --- a/libspu/compiler/passes/value_visibility_map.cc +++ b/libspu/compiler/passes/value_visibility_map.cc @@ -32,18 +32,18 @@ Visibility ComputePromotedVisibility(Visibility v1, Visibility v2) { } // namespace Visibility ValueVisibilityMap::getValueVisibility(const Value &v) const { - const auto &iter = storage.find(v); - SPU_ENFORCE(iter != storage.end()); + const auto &iter = storage_.find(v); + SPU_ENFORCE(iter != storage_.end()); return iter->second; } void ValueVisibilityMap::setValueVisibility(const Value &val, Visibility vis) { - const auto &iter = storage.find(val); - if (iter != storage.end()) { + const auto &iter = storage_.find(val); + if (iter != storage_.end()) { // Merge - storage[val] = ComputePromotedVisibility(iter->second, vis); + storage_[val] = ComputePromotedVisibility(iter->second, vis); } else { - storage[val] = vis; + storage_[val] = vis; } } diff --git a/libspu/compiler/passes/value_visibility_map.h b/libspu/compiler/passes/value_visibility_map.h index b6fda4e1..c15ec383 100644 --- a/libspu/compiler/passes/value_visibility_map.h +++ b/libspu/compiler/passes/value_visibility_map.h @@ -13,6 +13,7 @@ // limitations under the License. 
#pragma once + #include "llvm/ADT/DenseMap.h" #include "mlir/IR/Value.h" @@ -22,7 +23,7 @@ namespace mlir::pphlo { class ValueVisibilityMap { private: - llvm::DenseMap storage; + llvm::DenseMap storage_; public: Visibility getValueVisibility(const Value &v) const; diff --git a/libspu/compiler/passes/visibility_inference.cc b/libspu/compiler/passes/visibility_inference.cc index 0b9a0df1..3b6eae6c 100644 --- a/libspu/compiler/passes/visibility_inference.cc +++ b/libspu/compiler/passes/visibility_inference.cc @@ -47,18 +47,18 @@ void VisibilityInference::inferIf(Operation &op) { llvm::SmallVector input_vis; for (const auto &operand : op.getOperands()) { - input_vis.emplace_back(ValueVis_.getValueVisibility(operand)); + input_vis.emplace_back(value_vis_.getValueVisibility(operand)); } // Infer true branch for (const auto &blkarg : ifOp.getTrueBranch().getArguments()) { - ValueVis_.setValueVisibility(blkarg, input_vis[blkarg.getArgNumber()]); + value_vis_.setValueVisibility(blkarg, input_vis[blkarg.getArgNumber()]); } inferRegion(ifOp.getTrueBranch()); // Infer false branch for (const auto &blkarg : ifOp.getFalseBranch().getArguments()) { - ValueVis_.setValueVisibility(blkarg, input_vis[blkarg.getArgNumber()]); + value_vis_.setValueVisibility(blkarg, input_vis[blkarg.getArgNumber()]); } inferRegion(ifOp.getFalseBranch()); @@ -69,7 +69,7 @@ void VisibilityInference::inferIf(Operation &op) { SPU_ENFORCE(llvm::isa(false_return)); // Cond vis - auto cond_vis = ValueVis_.getValueVisibility(ifOp.getPred()); + auto cond_vis = value_vis_.getValueVisibility(ifOp.getPred()); for (const auto &ret : llvm::enumerate(ifOp->getResults())) { SmallVector vis; @@ -79,13 +79,13 @@ void VisibilityInference::inferIf(Operation &op) { // Get true branch result vis vis.emplace_back( - ValueVis_.getValueVisibility(true_return.getOperand(ret.index()))); + value_vis_.getValueVisibility(true_return.getOperand(ret.index()))); // Get false branch result vis vis.emplace_back( - ValueVis_.getValueVisibility(false_return.getOperand(ret.index()))); + value_vis_.getValueVisibility(false_return.getOperand(ret.index()))); - ValueVis_.setValueVisibility(ret.value(), - TypeTools::inferResultVisibility(vis)); + value_vis_.setValueVisibility(ret.value(), + TypeTools::inferResultVisibility(vis)); } } @@ -96,13 +96,13 @@ void VisibilityInference::inferCase(Operation &op) { llvm::SmallVector input_vis; llvm::SmallVector returns; for (const auto &operand : caseOp->getOperands()) { - input_vis.emplace_back(ValueVis_.getValueVisibility(operand)); + input_vis.emplace_back(value_vis_.getValueVisibility(operand)); } // Infer each branch for (auto ®ion : caseOp.getBranches()) { for (const auto &blkarg : region.getArguments()) { - ValueVis_.setValueVisibility(blkarg, input_vis[blkarg.getArgNumber()]); + value_vis_.setValueVisibility(blkarg, input_vis[blkarg.getArgNumber()]); } inferRegion(region); auto *ret = ®ion.back().back(); @@ -111,7 +111,7 @@ void VisibilityInference::inferCase(Operation &op) { } // Index vis - auto index_vis = ValueVis_.getValueVisibility(caseOp.getIndex()); + auto index_vis = value_vis_.getValueVisibility(caseOp.getIndex()); // Infer result visibility for (const auto &ret_enu : llvm::enumerate(caseOp->getResults())) { @@ -121,11 +121,11 @@ void VisibilityInference::inferCase(Operation &op) { for (auto *ret : returns) { vis.emplace_back( - ValueVis_.getValueVisibility(ret->getOperand(ret_enu.index()))); + value_vis_.getValueVisibility(ret->getOperand(ret_enu.index()))); } - ValueVis_.setValueVisibility(ret_enu.value(), - 
TypeTools::inferResultVisibility(vis)); + value_vis_.setValueVisibility(ret_enu.value(), + TypeTools::inferResultVisibility(vis)); } } @@ -137,14 +137,14 @@ void VisibilityInference::inferWhile(Operation &op) { SmallVector result_vis(op.getNumOperands()); for (int64_t idx = 0; idx < op.getNumOperands(); ++idx) { - input_vis[idx] = ValueVis_.getValueVisibility(whileOp->getOperand(idx)); + input_vis[idx] = value_vis_.getValueVisibility(whileOp->getOperand(idx)); } bool converge = false; do { // Push visibility to block args for (const auto &blkarg : whileOp.getBody().getArguments()) { - ValueVis_.setValueVisibility(blkarg, input_vis[blkarg.getArgNumber()]); + value_vis_.setValueVisibility(blkarg, input_vis[blkarg.getArgNumber()]); } // Infer body region @@ -157,7 +157,7 @@ void VisibilityInference::inferWhile(Operation &op) { // Update visibility for (int64_t idx = 0; idx < body_return.getNumOperands(); ++idx) { result_vis[idx] = - ValueVis_.getValueVisibility(body_return.getOperand(idx)); + value_vis_.getValueVisibility(body_return.getOperand(idx)); } converge = (input_vis == result_vis); @@ -165,17 +165,17 @@ void VisibilityInference::inferWhile(Operation &op) { } while (!converge); for (int64_t idx = 0; idx < op.getNumOperands(); ++idx) { - ValueVis_.setValueVisibility(whileOp.getBody().getArgument(idx), - input_vis[idx]); - ValueVis_.setValueVisibility(whileOp.getCond().getArgument(idx), - input_vis[idx]); + value_vis_.setValueVisibility(whileOp.getBody().getArgument(idx), + input_vis[idx]); + value_vis_.setValueVisibility(whileOp.getCond().getArgument(idx), + input_vis[idx]); } inferRegion(whileOp.getCond()); // Update result visibility for (int64_t idx = 0; idx < op.getNumResults(); ++idx) { - ValueVis_.setValueVisibility(op.getResult(idx), input_vis[idx]); + value_vis_.setValueVisibility(op.getResult(idx), input_vis[idx]); } } @@ -184,14 +184,14 @@ void VisibilityInference::inferSort(Operation &op) { // Push inputs to body region for (const auto &in : llvm::enumerate(op.getOperands())) { - auto inputVis = ValueVis_.getValueVisibility(in.value()); - ValueVis_.setValueVisibility( + auto inputVis = value_vis_.getValueVisibility(in.value()); + value_vis_.setValueVisibility( sortOp.getComparator().getArgument(2 * in.index()), inputVis); - ValueVis_.setValueVisibility( + value_vis_.setValueVisibility( sortOp.getComparator().getArgument(2 * in.index() + 1), inputVis); // Sort does not change result vis - ValueVis_.setValueVisibility(op.getResult(in.index()), inputVis); + value_vis_.setValueVisibility(op.getResult(in.index()), inputVis); } inferRegion(sortOp.getComparator()); @@ -200,20 +200,20 @@ void VisibilityInference::inferSort(Operation &op) { auto &comp_ret = *sortOp.getComparator().front().getTerminator(); SPU_ENFORCE(llvm::isa(comp_ret)); - if (ValueVis_.getValueVisibility(comp_ret.getOperand(0)) == + if (value_vis_.getValueVisibility(comp_ret.getOperand(0)) == Visibility::VIS_SECRET) { // If comparator result is secret, all results are secrets for (const auto &in : llvm::enumerate(op.getOperands())) { - ValueVis_.setValueVisibility( + value_vis_.setValueVisibility( sortOp.getComparator().getArgument(2 * in.index()), Visibility::VIS_SECRET); - ValueVis_.setValueVisibility( + value_vis_.setValueVisibility( sortOp.getComparator().getArgument(2 * in.index() + 1), Visibility::VIS_SECRET); // Sort does not change result vis - ValueVis_.setValueVisibility(op.getResult(in.index()), - Visibility::VIS_SECRET); + value_vis_.setValueVisibility(op.getResult(in.index()), + 
Visibility::VIS_SECRET); } inferRegion(sortOp.getComparator()); @@ -223,11 +223,11 @@ void VisibilityInference::inferSort(Operation &op) { void VisibilityInference::inferSelectAndScatter(Operation &op) { auto selectAndScatterOp = llvm::dyn_cast(op); - auto op_vis = ValueVis_.getValueVisibility(selectAndScatterOp.getOperand()); + auto op_vis = value_vis_.getValueVisibility(selectAndScatterOp.getOperand()); auto source_vis = - ValueVis_.getValueVisibility(selectAndScatterOp.getSource()); + value_vis_.getValueVisibility(selectAndScatterOp.getSource()); auto init_vis = - ValueVis_.getValueVisibility(selectAndScatterOp.getInitValue()); + value_vis_.getValueVisibility(selectAndScatterOp.getInitValue()); // init and operand must have the same visibility auto promoted_init_op_vis = @@ -235,18 +235,18 @@ void VisibilityInference::inferSelectAndScatter(Operation &op) { // Select region { - ValueVis_.setValueVisibility(selectAndScatterOp.getSelect().getArgument(0), - promoted_init_op_vis); - ValueVis_.setValueVisibility(selectAndScatterOp.getSelect().getArgument(1), - promoted_init_op_vis); + value_vis_.setValueVisibility(selectAndScatterOp.getSelect().getArgument(0), + promoted_init_op_vis); + value_vis_.setValueVisibility(selectAndScatterOp.getSelect().getArgument(1), + promoted_init_op_vis); inferRegion(selectAndScatterOp.getSelect()); } // Scatter region { - ValueVis_.setValueVisibility(selectAndScatterOp.getScatter().getArgument(0), - source_vis); - ValueVis_.setValueVisibility(selectAndScatterOp.getScatter().getArgument(1), - promoted_init_op_vis); + value_vis_.setValueVisibility( + selectAndScatterOp.getScatter().getArgument(0), source_vis); + value_vis_.setValueVisibility( + selectAndScatterOp.getScatter().getArgument(1), promoted_init_op_vis); inferRegion(selectAndScatterOp.getScatter()); } @@ -258,9 +258,9 @@ void VisibilityInference::inferSelectAndScatter(Operation &op) { llvm::dyn_cast(scatter_return)->getNumOperands() == 1); - ValueVis_.setValueVisibility( + value_vis_.setValueVisibility( selectAndScatterOp.getResult(), - ValueVis_.getValueVisibility(scatter_return.getOperand(0))); + value_vis_.getValueVisibility(scatter_return.getOperand(0))); } void VisibilityInference::inferIntrinsic(Operation &op) { @@ -274,18 +274,18 @@ void VisibilityInference::inferIntrinsic(Operation &op) { if (op.getNumResults() == 1) { SmallVector operand_vis; for (auto operand : op.getOperands()) { - operand_vis.emplace_back(ValueVis_.getValueVisibility(operand)); + operand_vis.emplace_back(value_vis_.getValueVisibility(operand)); } auto ret_vis = TypeTools::inferResultVisibility(operand_vis); - ValueVis_.setValueVisibility(op.getResult(0), ret_vis); + value_vis_.setValueVisibility(op.getResult(0), ret_vis); } else { SPU_ENFORCE(op.getNumResults() == op.getNumOperands(), "Default intrinsic inference can only handle single output or " "#output matches #input"); for (int64_t idx = 0; idx < op.getNumResults(); ++idx) { - ValueVis_.setValueVisibility( - op.getResult(idx), ValueVis_.getValueVisibility(op.getOperand(idx))); + value_vis_.setValueVisibility( + op.getResult(idx), value_vis_.getValueVisibility(op.getOperand(idx))); } } } @@ -303,15 +303,15 @@ void VisibilityInference::inferOperation(Operation &op) { inferCase(op); } else if (llvm::isa(op)) { // Constant always returns public - ValueVis_.setValueVisibility(op.getResult(0), Visibility::VIS_PUBLIC); + value_vis_.setValueVisibility(op.getResult(0), Visibility::VIS_PUBLIC); } else if (llvm::isa(op)) { inferSort(op); } else if (llvm::isa(op)) { // For gather 
op, if either operand or indices is a secret, result is a // secret - auto operand_vis = ValueVis_.getValueVisibility(op.getOperand(0)); - auto indices_vis = ValueVis_.getValueVisibility(op.getOperand(1)); - ValueVis_.setValueVisibility( + auto operand_vis = value_vis_.getValueVisibility(op.getOperand(0)); + auto indices_vis = value_vis_.getValueVisibility(op.getOperand(1)); + value_vis_.setValueVisibility( op.getResult(0), TypeTools::inferResultVisibility({operand_vis, indices_vis})); } else if (llvm::isa(op)) { @@ -321,10 +321,10 @@ void VisibilityInference::inferOperation(Operation &op) { } else if (op.getNumResults() == 1) { SmallVector operand_vis; for (auto operand : op.getOperands()) { - operand_vis.emplace_back(ValueVis_.getValueVisibility(operand)); + operand_vis.emplace_back(value_vis_.getValueVisibility(operand)); } auto ret_vis = TypeTools::inferResultVisibility(operand_vis); - ValueVis_.setValueVisibility(op.getResult(0), ret_vis); + value_vis_.setValueVisibility(op.getResult(0), ret_vis); } else if (llvm::isa(op) || llvm::isa(op)) { // Do nothing diff --git a/libspu/compiler/passes/visibility_inference.h b/libspu/compiler/passes/visibility_inference.h index c641981a..3a86f61d 100644 --- a/libspu/compiler/passes/visibility_inference.h +++ b/libspu/compiler/passes/visibility_inference.h @@ -15,8 +15,6 @@ #pragma once #include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/Value.h" #include "libspu/compiler/passes/value_visibility_map.h" #include "libspu/dialect/pphlo_types.h" @@ -25,8 +23,8 @@ namespace mlir::pphlo { class VisibilityInference { public: - explicit VisibilityInference(ValueVisibilityMap &ValueVis) - : ValueVis_(ValueVis) {} + explicit VisibilityInference(ValueVisibilityMap &value_vis) + : value_vis_(value_vis) {} void inferFunc(func::FuncOp &func); void inferRegion(Region ®ion); @@ -48,16 +46,17 @@ class VisibilityInference { size_t num_results = op.getNumResults(); std::vector input_vis; for (size_t idx = 0; idx < num_results; ++idx) { - auto inputVis = ValueVis_.getValueVisibility(reduceOp.getOperands()[idx]); + auto inputVis = + value_vis_.getValueVisibility(reduceOp.getOperands()[idx]); auto initVis = - ValueVis_.getValueVisibility(reduceOp.getInitValues()[idx]); + value_vis_.getValueVisibility(reduceOp.getInitValues()[idx]); auto promoted_vis = TypeTools::inferResultVisibility({inputVis, initVis}); input_vis.emplace_back(promoted_vis); - ValueVis_.setValueVisibility(reduceOp.getBody().getArgument(idx), - promoted_vis); - ValueVis_.setValueVisibility( + value_vis_.setValueVisibility(reduceOp.getBody().getArgument(idx), + promoted_vis); + value_vis_.setValueVisibility( reduceOp.getBody().getArgument(num_results + idx), promoted_vis); } @@ -73,8 +72,8 @@ class VisibilityInference { std::vector ret_vis; for (size_t idx = 0; idx < reduceOp->getNumResults(); ++idx) { auto resultVis = - ValueVis_.getValueVisibility(terminator->getOperand(idx)); - ValueVis_.setValueVisibility(reduceOp->getResult(idx), resultVis); + value_vis_.getValueVisibility(terminator->getOperand(idx)); + value_vis_.setValueVisibility(reduceOp->getResult(idx), resultVis); ret_vis.emplace_back(resultVis); if (resultVis != input_vis[idx]) { reinfer = true; @@ -83,9 +82,9 @@ class VisibilityInference { if (reinfer) { for (size_t idx = 0; idx < num_results; ++idx) { - ValueVis_.setValueVisibility(reduceOp.getBody().getArgument(idx), - ret_vis[idx]); - ValueVis_.setValueVisibility( + value_vis_.setValueVisibility(reduceOp.getBody().getArgument(idx), + 
ret_vis[idx]); + value_vis_.setValueVisibility( reduceOp.getBody().getArgument(num_results + idx), ret_vis[idx]); } @@ -95,7 +94,7 @@ class VisibilityInference { } } - ValueVisibilityMap &ValueVis_; + ValueVisibilityMap &value_vis_; }; } // namespace mlir::pphlo diff --git a/libspu/core/BUILD.bazel b/libspu/core/BUILD.bazel index 9af03c0d..71a22b1b 100644 --- a/libspu/core/BUILD.bazel +++ b/libspu/core/BUILD.bazel @@ -16,29 +16,13 @@ load("//bazel:spu.bzl", "spu_cc_binary", "spu_cc_library", "spu_cc_test") package(default_visibility = ["//visibility:public"]) -spu_cc_library( - name = "core", - deps = [ - ":config", - ":encoding", - ":ndarray_ref", - ":shape", - ":type", - ":type_util", - ":xt_helper", - "@yacl//yacl/base:buffer", - ], -) - spu_cc_library( name = "trace", srcs = ["trace.cc"], hdrs = ["trace.h"], deps = [ "//libspu/core:prelude", - "@com_google_absl//absl/types:span", "@yacl//yacl/link", - "@yacl//yacl/utils:scope_guard", ], ) @@ -70,7 +54,6 @@ spu_cc_library( ":half", "//libspu:spu_cc_proto", "//libspu/core:prelude", - "@com_google_absl//absl/types:span", "@yacl//yacl/base:int128", ], ) @@ -82,9 +65,8 @@ spu_cc_library( deps = [ ":ndarray_ref", ":parallel_utils", - ":xt_helper", + ":pt_buffer_view", "@yacl//yacl/crypto/tools:prg", - "@yacl//yacl/utils:parallel", ], ) @@ -187,7 +169,6 @@ spu_cc_test( spu_cc_library( name = "xt_helper", - srcs = ["xt_helper.cc"], hdrs = ["xt_helper.h"], deps = [ ":ndarray_ref", @@ -210,7 +191,6 @@ spu_cc_library( hdrs = ["vectorize.h"], deps = [ "//libspu/core:prelude", - "@com_google_absl//absl/types:span", ], ) diff --git a/libspu/core/bit_utils_test.cc b/libspu/core/bit_utils_test.cc index cf657944..333cb18a 100644 --- a/libspu/core/bit_utils_test.cc +++ b/libspu/core/bit_utils_test.cc @@ -14,7 +14,6 @@ #include "libspu/core/bit_utils.h" -#include "gmock/gmock.h" #include "gtest/gtest.h" namespace spu { @@ -49,13 +48,13 @@ TEST(BitUtilsTest, Log2Ceil) { } TEST(BitUtilsTest, BitWidth) { - EXPECT_EQ(BitWidth(0u), 0); - EXPECT_EQ(BitWidth(1u), 1); - EXPECT_EQ(BitWidth(1u << 3), 3 + 1); - EXPECT_EQ(BitWidth(1ull << 3), 3 + 1); - EXPECT_EQ(BitWidth(1ull << 40), 40 + 1); - EXPECT_EQ(BitWidth(yacl::MakeInt128(0, 1ull << 3)), 3 + 1); - EXPECT_EQ(BitWidth(yacl::MakeInt128(1ull << 3, 0)), 3 + 1 + 64); + EXPECT_EQ(BitWidth(0U), 0); + EXPECT_EQ(BitWidth(1U), 1); + EXPECT_EQ(BitWidth(1U << 3), 3 + 1); + EXPECT_EQ(BitWidth(1ULL << 3), 3 + 1); + EXPECT_EQ(BitWidth(1ULL << 40), 40 + 1); + EXPECT_EQ(BitWidth(yacl::MakeInt128(0, 1ULL << 3)), 3 + 1); + EXPECT_EQ(BitWidth(yacl::MakeInt128(1ULL << 3, 0)), 3 + 1 + 64); } TEST(BitUtilsTest, BitDeintl32) { diff --git a/libspu/core/cexpr.cc b/libspu/core/cexpr.cc index c1b65e36..8476d558 100644 --- a/libspu/core/cexpr.cc +++ b/libspu/core/cexpr.cc @@ -19,8 +19,6 @@ #include #include -#include "fmt/format.h" - #include "libspu/core/prelude.h" namespace spu::ce { diff --git a/libspu/core/context.cc b/libspu/core/context.cc index 2fa7626a..7555c74a 100644 --- a/libspu/core/context.cc +++ b/libspu/core/context.cc @@ -14,6 +14,8 @@ #include "libspu/core/context.h" +#include "libspu/core/trace.h" + namespace spu { namespace { diff --git a/libspu/core/context.h b/libspu/core/context.h index 7aacac0f..c27d23c4 100644 --- a/libspu/core/context.h +++ b/libspu/core/context.h @@ -18,11 +18,10 @@ #include #include -#include "yacl/link/link.h" +#include "yacl/link/context.h" #include "libspu/core/object.h" #include "libspu/core/prelude.h" -#include "libspu/core/trace.h" // TODO: bad reference, but implicitly include too much. 
#include "libspu/core/value.h" #include "libspu/spu.pb.h" @@ -95,8 +94,7 @@ class KernelEvalContext final { uint128_t, // ring constant int64_t, // SignType, // - std::vector, // for sort - absl::Span // for sort + std::vector // for sort >; SPUContext* sctx_; diff --git a/libspu/core/encoding.cc b/libspu/core/encoding.cc index cf1dd4e4..9e869e63 100644 --- a/libspu/core/encoding.cc +++ b/libspu/core/encoding.cc @@ -49,12 +49,11 @@ PtType getDecodeType(DataType dtype) { #undef CASE } -NdArrayRef encodeToRing(const NdArrayRef& src, FieldType field, size_t fxp_bits, - DataType* out_dtype) { - SPU_ENFORCE(src.eltype().isa(), "expect PtType, got={}", src.eltype()); - const PtType pt_type = src.eltype().as()->pt_type(); - const size_t numel = src.numel(); - NdArrayRef dst(makeType(field), src.shape()); +NdArrayRef encodeToRing(const PtBufferView& bv, FieldType field, + size_t fxp_bits, DataType* out_dtype) { + const PtType pt_type = bv.pt_type; + const size_t numel = bv.shape.numel(); + NdArrayRef dst(makeType(field), bv.shape); if (out_dtype != nullptr) { *out_dtype = getEncodeType(pt_type); @@ -73,16 +72,16 @@ NdArrayRef encodeToRing(const NdArrayRef& src, FieldType field, size_t fxp_bits, const T kScale = T(1) << fxp_bits; const T kFxpLower = -(T)std::pow(2, k - 2); const T kFxpUpper = (T)std::pow(2, k - 2) - 1; - const Float kFlpUpper = + const auto kFlpUpper = static_cast(static_cast(kFxpUpper) / kScale); - const Float kFlpLower = + const auto kFlpLower = static_cast(static_cast(kFxpLower) / kScale); - auto _src = NdArrayView(src); auto _dst = NdArrayView(dst); pforeach(0, numel, [&](int64_t idx) { - const auto src_value = _src[idx]; + const auto indices = unflattenIndex(idx, bv.shape); + auto src_value = bv.get(indices); if (std::isnan(src_value)) { // see numpy.nan_to_num // note(jint) I dont know why nan could be @@ -110,11 +109,13 @@ NdArrayRef encodeToRing(const NdArrayRef& src, FieldType field, size_t fxp_bits, field, pt_type); using T = std::make_signed_t; - auto _src = NdArrayView(src); + auto _dst = NdArrayView(dst); // TODO: encoding integer in range [-2^(k-2),2^(k-2)) pforeach(0, numel, [&](int64_t idx) { - _dst[idx] = static_cast(_src[idx]); // NOLINT + const auto indices = unflattenIndex(idx, bv.shape); + auto src_value = bv.get(indices); + _dst[idx] = static_cast(src_value); // NOLINT }); }); }); @@ -125,8 +126,8 @@ NdArrayRef encodeToRing(const NdArrayRef& src, FieldType field, size_t fxp_bits, SPU_THROW("should not be here"); } -NdArrayRef decodeFromRing(const NdArrayRef& src, DataType in_dtype, - size_t fxp_bits, PtType* out_pt_type) { +void decodeFromRing(const NdArrayRef& src, DataType in_dtype, size_t fxp_bits, + PtBufferView* out_bv, PtType* out_pt_type) { const Type& src_type = src.eltype(); const FieldType field = src_type.as()->field(); const PtType pt_type = getDecodeType(in_dtype); @@ -139,36 +140,35 @@ NdArrayRef decodeFromRing(const NdArrayRef& src, DataType in_dtype, *out_pt_type = pt_type; } - NdArrayRef dst(makePtType(pt_type), src.shape()); - DISPATCH_ALL_FIELDS(field, "field", [&]() { DISPATCH_ALL_PT_TYPES(pt_type, "pt_type", [&]() { using T = std::make_signed_t; auto _src = NdArrayView(src); - auto _dst = NdArrayView(dst); if (in_dtype == DT_I1) { constexpr bool kSanity = std::is_same_v; SPU_ENFORCE(kSanity); - pforeach(0, numel, - [&](int64_t idx) { _dst[idx] = !((_src[idx] & 0x1) == 0); }); + pforeach(0, numel, [&](int64_t idx) { + bool value = !((_src[idx] & 0x1) == 0); + out_bv->set(idx, value); + }); } else if (in_dtype == DT_F32 || in_dtype == DT_F64 
|| in_dtype == DT_F16) { const T kScale = T(1) << fxp_bits; pforeach(0, numel, [&](int64_t idx) { - _dst[idx] = + auto value = static_cast(static_cast(_src[idx]) / kScale); + out_bv->set(idx, value); }); } else { pforeach(0, numel, [&](int64_t idx) { - _dst[idx] = static_cast(_src[idx]); + auto value = static_cast(_src[idx]); + out_bv->set(idx, value); }); } }); }); - - return dst; } } // namespace spu diff --git a/libspu/core/encoding.h b/libspu/core/encoding.h index d6ceee1a..e611127c 100644 --- a/libspu/core/encoding.h +++ b/libspu/core/encoding.h @@ -15,7 +15,7 @@ #pragma once #include "libspu/core/ndarray_ref.h" -#include "libspu/core/type.h" +#include "libspu/core/pt_buffer_view.h" namespace spu { @@ -84,12 +84,10 @@ DataType getEncodeType(PtType pt_type); PtType getDecodeType(DataType dtype); -// TODO: document me, verbosely -NdArrayRef encodeToRing(const NdArrayRef& src, FieldType field, size_t fxp_bits, - DataType* out_dtype = nullptr); +NdArrayRef encodeToRing(const PtBufferView& src, FieldType field, + size_t fxp_bits, DataType* out_dtype = nullptr); -// TODO: document me, verbosely -NdArrayRef decodeFromRing(const NdArrayRef& src, DataType in_dtype, - size_t fxp_bits, PtType* out_pt_type = nullptr); +void decodeFromRing(const NdArrayRef& src, DataType in_dtype, size_t fxp_bits, + PtBufferView* out_bv, PtType* out_pt_type = nullptr); } // namespace spu diff --git a/libspu/core/encoding_test.cc b/libspu/core/encoding_test.cc index 8a3628bb..b5534bbe 100644 --- a/libspu/core/encoding_test.cc +++ b/libspu/core/encoding_test.cc @@ -104,31 +104,44 @@ TYPED_TEST(FloatEncodingTest, Works) { NdArrayRef frm(makePtType(PtTypeToEnum::value), {samples.size()}); std::copy(samples.begin(), samples.end(), &frm.at({0})); - DataType encoded_dtype; - auto encoded = encodeToRing(frm, kField, kFxpBits, &encoded_dtype); + PtBufferView frm_pv(static_cast(frm.data()), + PtTypeToEnum::value, frm.shape(), frm.strides()); + + DataType encoded_dtype_by_pv; + auto encoded_by_pv = + encodeToRing(frm_pv, kField, kFxpBits, &encoded_dtype_by_pv); if constexpr (std::is_same_v) { - EXPECT_EQ(encoded_dtype, DT_F32); + EXPECT_EQ(encoded_dtype_by_pv, DT_F32); } else { - EXPECT_EQ(encoded_dtype, DT_F64); + EXPECT_EQ(encoded_dtype_by_pv, DT_F64); } - PtType out_pt_type; - auto decoded = decodeFromRing(encoded, encoded_dtype, kFxpBits, &out_pt_type); + + PtType out_pt_type_by_pv; + NdArrayRef decoded_by_pv(makePtType(PtTypeToEnum::value), + {samples.size()}); + PtBufferView decoded_pv(static_cast(decoded_by_pv.data()), + PtTypeToEnum::value, decoded_by_pv.shape(), + decoded_by_pv.strides()); + decodeFromRing(encoded_by_pv, encoded_dtype_by_pv, kFxpBits, &decoded_pv, + &out_pt_type_by_pv); if constexpr (std::is_same_v) { - EXPECT_EQ(encoded_dtype, DT_F32); + EXPECT_EQ(out_pt_type_by_pv, PT_F32); } else { - EXPECT_EQ(encoded_dtype, DT_F64); + EXPECT_EQ(out_pt_type_by_pv, PT_F64); } - auto* out_ptr = &decoded.at({0}); + auto* out_ptr_by_pv = &decoded_by_pv.at({0}); const int64_t kReprBits = SizeOf(kField) * 8 - 2; const int64_t kScale = 1LL << kFxpBits; - EXPECT_EQ(out_ptr[0], -static_cast((1LL << kReprBits)) / kScale); - EXPECT_EQ(out_ptr[1], static_cast((1LL << kReprBits) - 1) / kScale); - EXPECT_EQ(out_ptr[2], -1.0); - EXPECT_EQ(out_ptr[3], 0.0); - EXPECT_EQ(out_ptr[4], 1.0); - EXPECT_NEAR(out_ptr[5], 3.1415926, 0.00001F); + EXPECT_EQ(out_ptr_by_pv[0], + -static_cast((1LL << kReprBits)) / kScale); + EXPECT_EQ(out_ptr_by_pv[1], + static_cast((1LL << kReprBits) - 1) / kScale); + EXPECT_EQ(out_ptr_by_pv[2], -1.0); + 
EXPECT_EQ(out_ptr_by_pv[3], 0.0); + EXPECT_EQ(out_ptr_by_pv[4], 1.0); + EXPECT_NEAR(out_ptr_by_pv[5], 3.1415926, 0.00001F); } template @@ -155,19 +168,28 @@ TYPED_TEST(IntEncodingTest, Works) { std::copy(samples.begin(), samples.end(), &frm.at({0})); DataType encoded_dtype; - auto encoded = encodeToRing(frm, kField, kFxpBits, &encoded_dtype); + PtBufferView frm_pv(static_cast(frm.data()), + PtTypeToEnum::value, frm.shape(), frm.strides()); + auto encoded_by_pv = encodeToRing(frm_pv, kField, kFxpBits, &encoded_dtype); EXPECT_EQ(encoded_dtype, getEncodeType(frm_pt_type)); PtType out_pt_type; - auto decoded = decodeFromRing(encoded, encoded_dtype, kFxpBits, &out_pt_type); + + NdArrayRef decoded_by_pv(makePtType(PtTypeToEnum::value), + {samples.size()}); + PtBufferView decoded_pv(static_cast(decoded_by_pv.data()), + PtTypeToEnum::value, decoded_by_pv.shape(), + decoded_by_pv.strides()); + decodeFromRing(encoded_by_pv, encoded_dtype, kFxpBits, &decoded_pv, + &out_pt_type); EXPECT_EQ(out_pt_type, frm_pt_type); - IntT* out_ptr = &decoded.at({0}); - EXPECT_EQ(out_ptr[0], samples[0]); - EXPECT_EQ(out_ptr[1], samples[1]); - EXPECT_EQ(out_ptr[2], static_cast(-1)); - EXPECT_EQ(out_ptr[3], 0); - EXPECT_EQ(out_ptr[4], 1); + IntT* out_ptr_by_pv = &decoded_by_pv.at({0}); + EXPECT_EQ(out_ptr_by_pv[0], samples[0]); + EXPECT_EQ(out_ptr_by_pv[1], samples[1]); + EXPECT_EQ(out_ptr_by_pv[2], static_cast(-1)); + EXPECT_EQ(out_ptr_by_pv[3], 0); + EXPECT_EQ(out_ptr_by_pv[4], 1); } } // namespace spu diff --git a/libspu/core/logging.h b/libspu/core/logging.h index d72ff3f1..8662b028 100644 --- a/libspu/core/logging.h +++ b/libspu/core/logging.h @@ -16,7 +16,6 @@ #include -#include "spdlog/spdlog.h" #include "yacl/link/trace.h" namespace spu::logging { diff --git a/libspu/core/ndarray_ref.cc b/libspu/core/ndarray_ref.cc index bbb10cc0..512f68f3 100644 --- a/libspu/core/ndarray_ref.cc +++ b/libspu/core/ndarray_ref.cc @@ -19,11 +19,6 @@ #include #include -#include "fmt/format.h" -#include "fmt/ostream.h" - -#include "libspu/core/parallel_utils.h" - namespace spu { namespace { diff --git a/libspu/core/ndarray_ref.h b/libspu/core/ndarray_ref.h index 126ad1ea..b2595048 100644 --- a/libspu/core/ndarray_ref.h +++ b/libspu/core/ndarray_ref.h @@ -19,7 +19,6 @@ #include "absl/types/span.h" #include "fmt/ostream.h" -#include "spdlog/spdlog.h" #include "yacl/base/buffer.h" #include "libspu/core/bit_utils.h" diff --git a/libspu/core/ndarray_ref_test.cc b/libspu/core/ndarray_ref_test.cc index af85b323..73f69a69 100644 --- a/libspu/core/ndarray_ref_test.cc +++ b/libspu/core/ndarray_ref_test.cc @@ -16,7 +16,6 @@ #include -#include "gmock/gmock.h" #include "gtest/gtest.h" namespace spu { diff --git a/libspu/core/prelude.h b/libspu/core/prelude.h index b094b712..07691501 100644 --- a/libspu/core/prelude.h +++ b/libspu/core/prelude.h @@ -47,8 +47,11 @@ #define SPU_DEBUG_ONLY_THROW YACL_THROW #endif +// Force compiler to inline something regardless of optimization level. 
+#define SPU_ALWAYS_INLINE inline __attribute__((always_inline)) + // forward scope guard related macros -#include "yacl/utils/scope_guard.h" +#include "yacl/utils/scope_guard.h" // IWYU pragma: keep // Format #include "fmt/ostream.h" diff --git a/libspu/core/pt_buffer_view.cc b/libspu/core/pt_buffer_view.cc index de778e02..5d3fe6f0 100644 --- a/libspu/core/pt_buffer_view.cc +++ b/libspu/core/pt_buffer_view.cc @@ -15,6 +15,7 @@ #include "libspu/core/pt_buffer_view.h" #include "libspu/core/shape.h" +#include "libspu/core/type_util.h" namespace spu { @@ -28,24 +29,23 @@ std::ostream& operator<<(std::ostream& out, PtBufferView v) { NdArrayRef convertToNdArray(PtBufferView bv) { const auto type = makePtType(bv.pt_type); auto out = NdArrayRef(type, bv.shape); - - if (bv.shape.numel() > 0) { - auto* out_ptr = out.data(); - - size_t elsize = SizeOf(bv.pt_type); - - Index indices(bv.shape.size(), 0); - if (bv.isCompact()) { - std::memcpy(out_ptr, bv.get(indices), elsize * bv.shape.numel()); - } else { - do { - std::memcpy(out_ptr, bv.get(indices), elsize); - out_ptr += elsize; - } while (bumpIndices(bv.shape, absl::MakeSpan(indices))); + return DISPATCH_ALL_PT_TYPES(bv.pt_type, "pt_type", [&]() { + using T = ScalarT; + if (bv.shape.numel() > 0) { + auto* out_ptr = out.data(); + + Index indices(bv.shape.size(), 0); + if (bv.isCompact()) { + std::memcpy(out_ptr, &bv.get(indices), sizeof(T) * bv.shape.numel()); + } else { + do { + *out_ptr = bv.get(indices); + out_ptr += 1; + } while (bumpIndices(bv.shape, absl::MakeSpan(indices))); + } } - } - - return out; + return out; + }); } } // namespace spu diff --git a/libspu/core/pt_buffer_view.h b/libspu/core/pt_buffer_view.h index 1b69a9d5..3310fc80 100644 --- a/libspu/core/pt_buffer_view.h +++ b/libspu/core/pt_buffer_view.h @@ -16,8 +16,6 @@ #include -#include "absl/types/span.h" - #include "libspu/core/ndarray_ref.h" #include "libspu/core/prelude.h" #include "libspu/core/shape.h" @@ -41,65 +39,115 @@ constexpr bool } // namespace detail // A view of a plaintext buffer. -// -// Please do not direct use this class if possible. struct PtBufferView { - void const* const ptr; // Pointer to the underlying storage - PtType const pt_type; // Plaintext data type. - Shape const shape; // Shape of the tensor. - Strides const strides; // Strides in number of elements. + void* const ptr; // Pointer to the underlying storage + PtType const pt_type; // Plaintext data type. + Shape const shape; // Shape of the tensor. + Strides const strides; // Strides in number of elements. + bool const write_able{false}; // Whether this is a writable buffer + bool const compacted{false}; // Whether this is a compacted buffer // We have to take a concrete buffer as a view. 
PtBufferView() = delete; // full constructor - explicit PtBufferView(void const* ptr, PtType pt_type, Shape shape, - Strides strides) - : ptr(ptr), + template + explicit PtBufferView(Pointer ptr, PtType pt_type, Shape in_shape, + Strides in_strides) + : ptr(const_cast(static_cast(ptr))), pt_type(pt_type), - shape(std::move(shape)), - strides(std::move(strides)) {} + shape(std::move(in_shape)), + strides(std::move(in_strides)), + write_able(!std::is_const_v>), + compacted(strides == makeCompactStrides(shape)) { + static_assert(std::is_pointer_v); + } // View c++ builtin scalar type as a buffer template , bool> = true> /* implicit */ PtBufferView(T const& s) // NOLINT - : ptr(static_cast(&s)), + : ptr(const_cast(static_cast(&s))), pt_type(PtTypeToEnum::value), shape(), - strides() {} + strides(), + compacted(true) {} // FIXME(jint): make it work when T = bool - template , bool> = true> + template , bool> = true> /* implicit */ PtBufferView(const T& c) // NOLINT - : ptr(static_cast(c.data())), + : ptr(const_cast(static_cast(c.data()))), pt_type(PtTypeToEnum::value), shape({static_cast(c.size())}), - strides({1}) {} + strides({1}), + compacted(true) {} // View a tensor-like type (i.e. xt::xarray) as a buffer. template , bool> = true> + std::enable_if_t, bool> = true> /* implicit */ PtBufferView(const T& t) // NOLINT - : ptr(static_cast(t.data())), + : ptr(const_cast(static_cast(t.data()))), + pt_type(PtTypeToEnum::value), + shape(t.shape().begin(), t.shape().end()), + strides(t.strides().begin(), t.strides().end()), + compacted(strides == makeCompactStrides(shape)) {} + + template , bool> = true> + /* implicit */ PtBufferView(T& t) // NOLINT + : ptr(const_cast(static_cast(t.data()))), pt_type(PtTypeToEnum::value), shape(t.shape().begin(), t.shape().end()), - strides(t.strides().begin(), t.strides().end()) {} + strides(t.strides().begin(), t.strides().end()), + write_able(true), + compacted(strides == makeCompactStrides(shape)) {} - template - const T* get(const Index& indices) const { + template + const S& get(const Index& indices) const { + SPU_ENFORCE(PtTypeToEnum::value == pt_type); auto fi = calcFlattenOffset(indices, shape, strides); const auto* addr = static_cast(ptr) + SizeOf(pt_type) * fi; - return reinterpret_cast(addr); + return *reinterpret_cast(addr); + } + + template + const S& get(size_t idx) const { + if (isCompact()) { + const auto* addr = + static_cast(ptr) + SizeOf(pt_type) * idx; + return *reinterpret_cast(addr); + } else { + const auto& indices = unflattenIndex(idx, shape); + return get(indices); + } + } + + template + void set(const Index& indices, S v) { + SPU_ENFORCE(write_able); + SPU_ENFORCE(PtTypeToEnum::value == pt_type); + auto fi = calcFlattenOffset(indices, shape, strides); + auto* addr = static_cast(ptr) + SizeOf(pt_type) * fi; + *reinterpret_cast(addr) = v; + } + + template + void set(size_t idx, S v) { + if (isCompact()) { + auto* addr = static_cast(ptr) + SizeOf(pt_type) * idx; + *reinterpret_cast(addr) = v; + } else { + const auto& indices = unflattenIndex(idx, shape); + set(indices, v); + } } - bool isCompact() const { return strides == makeCompactStrides(shape); } + bool isCompact() const { return compacted; } }; std::ostream& operator<<(std::ostream& out, PtBufferView v); -// Make a ndarray from a plaintext buffer. 
NdArrayRef convertToNdArray(PtBufferView bv); } // namespace spu diff --git a/libspu/core/pt_buffer_view_test.cc b/libspu/core/pt_buffer_view_test.cc index 841af2b5..e1fa3e59 100644 --- a/libspu/core/pt_buffer_view_test.cc +++ b/libspu/core/pt_buffer_view_test.cc @@ -53,9 +53,9 @@ TEST(PtBufferView, Vector) { EXPECT_EQ(bv_f32.pt_type, PT_F32); EXPECT_THAT(bv_f32.shape, testing::ElementsAre(3)); EXPECT_THAT(bv_f32.strides, testing::ElementsAre(1)); - EXPECT_FLOAT_EQ((*bv_f32.get({0})), 1.0); - EXPECT_FLOAT_EQ((*bv_f32.get({1})), 2.0); - EXPECT_FLOAT_EQ((*bv_f32.get({2})), 3.0); + EXPECT_FLOAT_EQ((bv_f32.get({0})), 1.0); + EXPECT_FLOAT_EQ((bv_f32.get({1})), 2.0); + EXPECT_FLOAT_EQ((bv_f32.get({2})), 3.0); } TEST(PtBufferView, ConvertToNdArray) { diff --git a/libspu/core/trace.cc b/libspu/core/trace.cc index 4e48e2ee..bcbcf5f4 100644 --- a/libspu/core/trace.cc +++ b/libspu/core/trace.cc @@ -26,8 +26,6 @@ #include "absl/strings/numbers.h" #include "absl/strings/str_split.h" -#include "libspu/core/prelude.h" - #ifdef __APPLE__ #include #endif diff --git a/libspu/core/trace.h b/libspu/core/trace.h index 47bec4ea..c62c0d3c 100644 --- a/libspu/core/trace.h +++ b/libspu/core/trace.h @@ -23,9 +23,8 @@ #include "absl/types/span.h" #include "fmt/format.h" -#include "fmt/ostream.h" #include "spdlog/spdlog.h" -#include "yacl/link/link.h" +#include "yacl/link/context.h" namespace std { diff --git a/libspu/core/type.cc b/libspu/core/type.cc index ab19ce22..0e33bd27 100644 --- a/libspu/core/type.cc +++ b/libspu/core/type.cc @@ -16,9 +16,6 @@ #include -#include "absl/strings/match.h" -#include "absl/strings/str_split.h" - namespace spu { Type::Type() diff --git a/libspu/core/type_util.cc b/libspu/core/type_util.cc index 33113868..896de751 100644 --- a/libspu/core/type_util.cc +++ b/libspu/core/type_util.cc @@ -14,8 +14,6 @@ #include "libspu/core/type_util.h" -#include "absl/strings/str_join.h" - namespace spu { ////////////////////////////////////////////////////////////// diff --git a/libspu/core/type_util.h b/libspu/core/type_util.h index 7a3efae3..beffee22 100644 --- a/libspu/core/type_util.h +++ b/libspu/core/type_util.h @@ -19,8 +19,6 @@ #include #include -#include "fmt/format.h" -#include "fmt/ostream.h" #include "yacl/base/int128.h" #include "libspu/core/half.h" diff --git a/libspu/core/value.cc b/libspu/core/value.cc index c6e2d4a0..7c30372f 100644 --- a/libspu/core/value.cc +++ b/libspu/core/value.cc @@ -18,7 +18,6 @@ #include #include "fmt/format.h" -#include "fmt/ostream.h" #include "libspu/core/ndarray_ref.h" #include "libspu/core/prelude.h" diff --git a/libspu/core/value.h b/libspu/core/value.h index 162ae32e..759e393b 100644 --- a/libspu/core/value.h +++ b/libspu/core/value.h @@ -16,7 +16,6 @@ #include -#include "absl/types/span.h" #include "fmt/ostream.h" #include "libspu/core/ndarray_ref.h" diff --git a/libspu/core/vectorize.h b/libspu/core/vectorize.h index 38e07d1f..a3f91679 100644 --- a/libspu/core/vectorize.h +++ b/libspu/core/vectorize.h @@ -18,8 +18,6 @@ #include #include -#include "absl/types/span.h" - #include "libspu/core/prelude.h" namespace spu { diff --git a/libspu/core/xt_helper.cc b/libspu/core/xt_helper.cc deleted file mode 100644 index 821b9918..00000000 --- a/libspu/core/xt_helper.cc +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright 2021 Ant Group Co., Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "libspu/core/xt_helper.h" - -namespace spu { - -// - -} // namespace spu diff --git a/libspu/core/xt_helper.h b/libspu/core/xt_helper.h index 9040ae92..3230eada 100644 --- a/libspu/core/xt_helper.h +++ b/libspu/core/xt_helper.h @@ -18,12 +18,9 @@ #include "xtensor/xeval.hpp" #include "xtensor/xexpression.hpp" #include "xtensor/xio.hpp" -#include "xtensor/xrandom.hpp" #include "libspu/core/ndarray_ref.h" #include "libspu/core/prelude.h" -#include "libspu/core/pt_buffer_view.h" -#include "libspu/core/type_util.h" namespace spu { diff --git a/libspu/device/BUILD.bazel b/libspu/device/BUILD.bazel index f3dc3705..bd60b75b 100644 --- a/libspu/device/BUILD.bazel +++ b/libspu/device/BUILD.bazel @@ -34,7 +34,6 @@ spu_cc_library( "//libspu/core:context", "//libspu/core:pt_buffer_view", "//libspu/core:value", - "//libspu/kernel/hal:constants", "//libspu/kernel/hal:public_helper", "//libspu/mpc:factory", ], diff --git a/libspu/device/api.cc b/libspu/device/api.cc index b2eed2ad..88850406 100644 --- a/libspu/device/api.cc +++ b/libspu/device/api.cc @@ -21,11 +21,12 @@ #include "llvm/Support/ErrorHandling.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/Parser/Parser.h" #include "spdlog/spdlog.h" +#include "libspu/core/trace.h" #include "libspu/device/debug_dump_constant.h" -#include "libspu/device/pphlo/pphlo_executor.h" #include "libspu/dialect/pphlo_dialect.h" namespace spu::device { diff --git a/libspu/device/debug_dump_constant.cc b/libspu/device/debug_dump_constant.cc index 2aae4e5a..6502ed2b 100644 --- a/libspu/device/debug_dump_constant.cc +++ b/libspu/device/debug_dump_constant.cc @@ -14,7 +14,7 @@ #include "libspu/device/debug_dump_constant.h" -#include "fmt/format.h" +#include "fmt/format.h" // IWYU pragma: keep namespace spu::device { diff --git a/libspu/device/executor.cc b/libspu/device/executor.cc index bc176bf8..5d0c400a 100644 --- a/libspu/device/executor.cc +++ b/libspu/device/executor.cc @@ -20,9 +20,8 @@ #include #include -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/Value.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Region.h" #include "libspu/core/context.h" #include "libspu/core/prelude.h" @@ -237,7 +236,7 @@ class BlockParallelRunner final { threads_.reserve(opts_.concurrency); for (uint64_t i = 0; i < opts_.concurrency; i++) { - threads_.emplace_back(std::thread(&BlockParallelRunner::run_task, this)); + threads_.emplace_back(&BlockParallelRunner::run_task, this); } for (uint64_t i = 0; i < opts_.concurrency; i++) { @@ -260,7 +259,7 @@ class BlockParallelRunner final { SPU_THROW("Should not be here"); } - void run_task(void) { + void run_task() { std::unique_lock queue_lock(queue_mtx_); while (!task_queue_.empty()) { diff --git a/libspu/device/executor.h b/libspu/device/executor.h index 14d93eae..38a0c488 100644 --- a/libspu/device/executor.h +++ b/libspu/device/executor.h @@ -18,8 +18,8 @@ #include #include "llvm/ADT/DenseMap.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" 
#include "libspu/core/context.h" #include "libspu/core/value.h" diff --git a/libspu/device/io.cc b/libspu/device/io.cc index 68dcc601..117fb333 100644 --- a/libspu/device/io.cc +++ b/libspu/device/io.cc @@ -16,10 +16,11 @@ #include +#include "yacl/link/algorithm/allgather.h" + #include "libspu/core/config.h" #include "libspu/core/encoding.h" #include "libspu/core/pt_buffer_view.h" -#include "libspu/kernel/hal/constants.h" #include "libspu/kernel/hal/public_helper.h" #include "libspu/mpc/factory.h" @@ -47,10 +48,7 @@ std::vector IoClient::makeShares(const PtBufferView &bv, if (bv.pt_type == PT_BOOL && vtype == VIS_SECRET && base_io_->hasBitSecretSupport()) { - // handle boolean type encoding. - NdArrayRef arr = convertToNdArray(bv); - - auto shares = base_io_->makeBitSecret(arr); + auto shares = base_io_->makeBitSecret(bv); SPU_ENFORCE(shares.size() == world_size_); std::vector result; @@ -64,7 +62,6 @@ std::vector IoClient::makeShares(const PtBufferView &bv, if (bv.pt_type == PT_CF32 || bv.pt_type == PT_CF64) { auto s_type = bv.pt_type == PT_CF32 ? PT_F32 : PT_F64; auto offset = bv.pt_type == PT_CF32 ? sizeof(float) : sizeof(double); - // SPDLOG_INFO("bv.strides = {}", bv.strides); Strides ds = bv.strides; for (auto &s : ds) { @@ -72,7 +69,8 @@ std::vector IoClient::makeShares(const PtBufferView &bv, } PtBufferView real_view(bv.ptr, s_type, bv.shape, ds); - PtBufferView imag_view((std::byte *)bv.ptr + offset, s_type, bv.shape, ds); + PtBufferView imag_view(static_cast(bv.ptr) + offset, + s_type, bv.shape, ds); auto r_shares = makeShares(real_view, vtype, owner_rank); auto i_shares = makeShares(imag_view, vtype, owner_rank); @@ -88,8 +86,7 @@ std::vector IoClient::makeShares(const PtBufferView &bv, // encode to ring. DataType dtype; - NdArrayRef encoded = - encodeToRing(convertToNdArray(bv), config_.field(), fxp_bits, &dtype); + NdArrayRef encoded = encodeToRing(bv, config_.field(), fxp_bits, &dtype); // make shares. std::vector shares = base_io_->toShares(encoded, vtype); @@ -103,48 +100,54 @@ std::vector IoClient::makeShares(const PtBufferView &bv, return result; } -template -NdArrayRef combineComplex(const NdArrayRef &real, const NdArrayRef &imag, - const Type &complex_type) { - NdArrayRef ret(complex_type, real.shape()); - NdArrayView> ret_v(ret); - NdArrayView rv(real); - NdArrayView iv(imag); - for (int64_t idx = 0; idx < real.numel(); ++idx) { - ret_v[idx].real(rv[idx]); - ret_v[idx].imag(iv[idx]); +PtType IoClient::getPtType(absl::Span values) { + const DataType dtype = values.front().dtype(); + if (values.front().isComplex()) { + if (dtype == DT_F32) { + return PT_CF32; + } else { + SPU_ENFORCE(dtype == DT_F64); + return PT_CF64; + } + } else { + return getDecodeType(dtype); } - return ret; } -NdArrayRef IoClient::combineShares(absl::Span values) { +void IoClient::combineShares(absl::Span values, + PtBufferView *out) { SPU_ENFORCE(values.size() == world_size_, "wrong number of shares, got={}, expect={}", values.size(), world_size_); if (values.front().isComplex()) { - NdArrayRef real; - NdArrayRef imag; + Strides ds = out->strides; + for (auto &s : ds) { + s *= 2; + } + + auto s_type = values.front().dtype() == DT_F32 ? PT_F32 : PT_F64; + auto offset = + values.front().dtype() == DT_F32 ? 
sizeof(float) : sizeof(double); + + PtBufferView real_pv(out->ptr, s_type, out->shape, ds); + PtBufferView imag_pv(static_cast(out->ptr) + offset, s_type, + out->shape, ds); { - std::vector reals(values.size()); + std::vector reals(values.size()); for (size_t idx = 0; idx < values.size(); ++idx) { reals[idx] = Value(values[idx].data(), values[idx].dtype()); } - real = combineShares(reals); + combineShares(reals, &real_pv); } { std::vector imags(values.size()); for (size_t idx = 0; idx < values.size(); ++idx) { - imags[idx] = Value(*values[idx].imag(), values[idx].dtype()); + imags[idx] = Value(values[idx].imag().value(), values[idx].dtype()); } - imag = combineShares(imags); - } - if (values.front().dtype() == DT_F32) { - return combineComplex(real, imag, CF32); - } else { - SPU_ENFORCE(values.front().dtype() == DT_F64); - return combineComplex(real, imag, CF64); + combineShares(imags, &imag_pv); } + return; } const size_t fxp_bits = config_.fxp_fraction_bits(); @@ -164,7 +167,7 @@ NdArrayRef IoClient::combineShares(absl::Span values) { // decode from ring. const DataType dtype = values.front().dtype(); - return decodeFromRing(encoded, dtype, fxp_bits); + decodeFromRing(encoded, dtype, fxp_bits, out); } ColocatedIo::ColocatedIo(SPUContext *sctx) : sctx_(sctx) {} diff --git a/libspu/device/io.h b/libspu/device/io.h index 630e78a5..b1caa898 100644 --- a/libspu/device/io.h +++ b/libspu/device/io.h @@ -112,8 +112,10 @@ class IoClient { size_t getShareSize(const PtBufferView &bv, Visibility vtype, int owner_rank = -1); - // Combine shares to a plaintext ndarray. - NdArrayRef combineShares(absl::Span values); + // Combine shares to a plaintext buffer. + void combineShares(absl::Span values, PtBufferView *out); + + PtType getPtType(absl::Span values); }; class ColocatedIo { diff --git a/libspu/device/io_test.cc b/libspu/device/io_test.cc index bf146992..8d606da3 100644 --- a/libspu/device/io_test.cc +++ b/libspu/device/io_test.cc @@ -39,10 +39,11 @@ TEST_P(IoClientTest, Float) { auto shares = io.makeShares(in_data, kVisibility); EXPECT_EQ(shares.size(), kWorldSize); - auto out = io.combineShares(shares); - EXPECT_EQ(out.eltype().as()->pt_type(), PT_F32); + EXPECT_EQ(io.getPtType(shares), PT_F32); + xt::xarray out_data(in_data.shape()); + PtBufferView out_pv(out_data); + io.combineShares(shares, &out_pv); - auto out_data = xt_adapt(out); EXPECT_EQ(in_data, out_data); } @@ -60,10 +61,11 @@ TEST_P(IoClientTest, Int) { auto shares = io.makeShares(in_data, kVisibility); EXPECT_EQ(shares.size(), kWorldSize); - auto out = io.combineShares(shares); - EXPECT_EQ(out.eltype().as()->pt_type(), PT_I32); + EXPECT_EQ(io.getPtType(shares), PT_I32); + xt::xarray out_data(in_data.shape()); + PtBufferView out_pv(out_data); + io.combineShares(shares, &out_pv); - auto out_data = xt_adapt(out); EXPECT_EQ(in_data, out_data); } diff --git a/libspu/device/pphlo/BUILD.bazel b/libspu/device/pphlo/BUILD.bazel index 09617fc6..73e427d9 100644 --- a/libspu/device/pphlo/BUILD.bazel +++ b/libspu/device/pphlo/BUILD.bazel @@ -25,7 +25,21 @@ spu_cc_library( ":pphlo_verifier", "//libspu/device:executor", "//libspu/dialect:pphlo_dialect", - "//libspu/kernel/hlo", + "//libspu/kernel/hal:debug", + "//libspu/kernel/hlo:basic_binary", + "//libspu/kernel/hlo:basic_ternary", + "//libspu/kernel/hlo:basic_unary", + "//libspu/kernel/hlo:casting", + "//libspu/kernel/hlo:const", + "//libspu/kernel/hlo:control_flow", + "//libspu/kernel/hlo:convolution", + "//libspu/kernel/hlo:geometrical", + "//libspu/kernel/hlo:indexing", + 
"//libspu/kernel/hlo:rand", + "//libspu/kernel/hlo:reduce", + "//libspu/kernel/hlo:select_and_scatter", + "//libspu/kernel/hlo:shift", + "//libspu/kernel/hlo:sort", ], ) @@ -34,8 +48,8 @@ spu_cc_library( srcs = ["pphlo_intrinsic_executor.cc"], hdrs = ["pphlo_intrinsic_executor.h"], deps = [ - "//libspu/core:context", - "//libspu/kernel/hlo", + "//libspu/kernel/hlo:casting", + "//libspu/kernel/hlo:const", "@llvm-project//llvm:Support", ], ) @@ -68,10 +82,10 @@ spu_cc_library( srcs = ["pphlo_verifier.cc"], hdrs = ["pphlo_verifier.h"], deps = [ - "//libspu/core:context", "//libspu/core:value", "//libspu/dialect:pphlo_dialect", - "//libspu/kernel/hlo", + "//libspu/kernel/hal:public_helper", + "//libspu/kernel/hal:type_cast", "@stablehlo//:reference_ops", ], ) diff --git a/libspu/device/pphlo/pphlo_executor.cc b/libspu/device/pphlo/pphlo_executor.cc index 88e73d39..cb422e2a 100644 --- a/libspu/device/pphlo/pphlo_executor.cc +++ b/libspu/device/pphlo/pphlo_executor.cc @@ -15,13 +15,15 @@ #include "libspu/device/pphlo/pphlo_executor.h" #include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/Location.h" #include "libspu/core/encoding.h" +#include "libspu/core/trace.h" #include "libspu/device/pphlo/pphlo_intrinsic_executor.h" #include "libspu/device/pphlo/pphlo_verifier.h" #include "libspu/dialect/pphlo_base_enums.h" #include "libspu/dialect/pphlo_ops.h" +#include "libspu/kernel/hal/debug.h" +#include "libspu/kernel/hal/public_helper.h" #include "libspu/kernel/hal/ring.h" #include "libspu/kernel/hlo/basic_binary.h" #include "libspu/kernel/hlo/basic_ternary.h" @@ -611,7 +613,8 @@ void execute(OpExecutor *executor, SPUContext *sctx, SymbolScope *sscope, // build strides Strides window_strides(window_shape.size(), 1); if (op.getWindowStrides().has_value()) { - convertDenseIntElementAttr(*op.getWindowStrides(), window_strides); + convertDenseIntElementAttr(*op.getWindowStrides(), // NOLINT + window_strides); } // window padding @@ -646,7 +649,8 @@ void execute(OpExecutor *, SPUContext *sctx, SymbolScope *sscope, // build strides Strides window_strides(window_shape.size(), 1); if (op.getWindowStrides().has_value()) { - convertDenseIntElementAttr(*op.getWindowStrides(), window_strides); + convertDenseIntElementAttr(*op.getWindowStrides(), // NOLINT + window_strides); } // window padding @@ -916,15 +920,15 @@ void execute(OpExecutor *executor, SPUContext *sctx, SymbolScope *sscope, // build strides Strides window_strides(window_shape.size(), 1); if (op.getWindowStrides().has_value()) { - convertDenseIntElementAttr(*op.getWindowStrides(), - window_strides); // NOLINT + convertDenseIntElementAttr(*op.getWindowStrides(), // NOLINT + window_strides); } // window dilation Sizes window_dilations(window_shape.size(), 1); if (op.getWindowDilations().has_value()) { - convertDenseIntElementAttr(*op.getWindowDilations(), - window_dilations); // NOLINT + convertDenseIntElementAttr(*op.getWindowDilations(), // NOLINT + window_dilations); } std::vector> window_padding(window_shape.size(), @@ -963,15 +967,15 @@ void execute(OpExecutor *, SPUContext *sctx, SymbolScope *sscope, // build strides Strides window_strides(window_shape.size(), 1); if (op.getWindowStrides().has_value()) { - convertDenseIntElementAttr(*op.getWindowStrides(), - window_strides); // NOLINT + convertDenseIntElementAttr(*op.getWindowStrides(), // NOLINT + window_strides); } // window dilation Sizes window_dilations(window_shape.size(), 1); if (op.getWindowDilations().has_value()) { - convertDenseIntElementAttr(*op.getWindowDilations(), - 
window_dilations); // NOLINT + convertDenseIntElementAttr(*op.getWindowDilations(), // NOLINT + window_dilations); } auto ret_shape = op->getResults()[0] diff --git a/libspu/device/pphlo/pphlo_executor.h b/libspu/device/pphlo/pphlo_executor.h index 6e939ec9..e21ddae9 100644 --- a/libspu/device/pphlo/pphlo_executor.h +++ b/libspu/device/pphlo/pphlo_executor.h @@ -14,9 +14,12 @@ #pragma once -#include "libspu/core/context.h" #include "libspu/device/executor.h" +namespace spu { +class SPUContext; +} + namespace spu::device::pphlo { class PPHloExecutor : public OpExecutor { diff --git a/libspu/device/pphlo/pphlo_executor_debug_runner.cc b/libspu/device/pphlo/pphlo_executor_debug_runner.cc index 51f76768..4e95a86b 100644 --- a/libspu/device/pphlo/pphlo_executor_debug_runner.cc +++ b/libspu/device/pphlo/pphlo_executor_debug_runner.cc @@ -26,8 +26,6 @@ #include "libspu/device/debug_dump_constant.h" #include "libspu/device/pphlo/pphlo_executor.h" #include "libspu/device/symbol_table.h" -#include "libspu/device/test_utils.h" -#include "libspu/kernel/hal/debug.h" #include "libspu/mpc/factory.h" #include "libspu/mpc/utils/simulate.h" @@ -57,9 +55,9 @@ std::shared_ptr MakeLink(const std::string &parties, size_t rank) { yacl::link::ContextDesc lctx_desc; std::vector hosts = absl::StrSplit(parties, ','); - for (size_t rank = 0; rank < hosts.size(); rank++) { + for (auto &host : hosts) { const auto id = fmt::format("party{}", rank); - lctx_desc.parties.push_back({id, hosts[rank]}); + lctx_desc.parties.push_back({id, host}); } auto lctx = yacl::link::FactoryBrpc().CreateContext(lctx_desc, rank); lctx->ConnectToMesh(); diff --git a/libspu/device/pphlo/pphlo_executor_test_runner.h b/libspu/device/pphlo/pphlo_executor_test_runner.h index 1cf32ee6..0fb976f1 100644 --- a/libspu/device/pphlo/pphlo_executor_test_runner.h +++ b/libspu/device/pphlo/pphlo_executor_test_runner.h @@ -14,7 +14,7 @@ #pragma once -#include "fmt/format.h" +#include "fmt/format.h" // IWYU pragma: keep #include "libspu/device/test_utils.h" diff --git a/libspu/device/pphlo/pphlo_intrinsic_executor.cc b/libspu/device/pphlo/pphlo_intrinsic_executor.cc index 8c97cc52..8b73478e 100644 --- a/libspu/device/pphlo/pphlo_intrinsic_executor.cc +++ b/libspu/device/pphlo/pphlo_intrinsic_executor.cc @@ -14,6 +14,9 @@ #include "libspu/device/pphlo/pphlo_intrinsic_executor.h" +#include "spdlog/spdlog.h" + +#include "libspu/kernel/hlo/casting.h" #include "libspu/kernel/hlo/const.h" namespace spu::device::pphlo { diff --git a/libspu/device/pphlo/pphlo_intrinsic_executor.h b/libspu/device/pphlo/pphlo_intrinsic_executor.h index e5ff5f93..466ea8b4 100644 --- a/libspu/device/pphlo/pphlo_intrinsic_executor.h +++ b/libspu/device/pphlo/pphlo_intrinsic_executor.h @@ -16,7 +16,11 @@ #include "llvm/ADT/StringRef.h" -#include "libspu/core/context.h" +#include "libspu/core/value.h" + +namespace spu { +class SPUContext; +} namespace spu::device::pphlo { diff --git a/libspu/device/pphlo/pphlo_verifier.cc b/libspu/device/pphlo/pphlo_verifier.cc index 963c390b..ce684730 100644 --- a/libspu/device/pphlo/pphlo_verifier.cc +++ b/libspu/device/pphlo/pphlo_verifier.cc @@ -21,10 +21,8 @@ #include "stablehlo/reference/Tensor.h" #include "libspu/dialect/pphlo_ops.h" -#include "libspu/kernel/hal/constants.h" #include "libspu/kernel/hal/public_helper.h" #include "libspu/kernel/hal/type_cast.h" -#include "libspu/kernel/hlo/utils.h" namespace spu::device::pphlo { namespace { diff --git a/libspu/device/pphlo/pphlo_verifier.h b/libspu/device/pphlo/pphlo_verifier.h index 
d988ef78..50697e83 100644 --- a/libspu/device/pphlo/pphlo_verifier.h +++ b/libspu/device/pphlo/pphlo_verifier.h @@ -14,11 +14,12 @@ #pragma once -#include "libspu/core/context.h" #include "libspu/core/value.h" -#include "libspu/dialect/pphlo_dialect.h" #include "libspu/dialect/pphlo_ops.h" -#include "libspu/dialect/pphlo_types.h" + +namespace spu { +class SPUContext; +} namespace spu::device::pphlo { diff --git a/libspu/device/pphlo/pphlo_verifier_test.cc b/libspu/device/pphlo/pphlo_verifier_test.cc index d82e7e3c..dcf0df75 100644 --- a/libspu/device/pphlo/pphlo_verifier_test.cc +++ b/libspu/device/pphlo/pphlo_verifier_test.cc @@ -21,6 +21,7 @@ #include "xtensor/xarray.hpp" #include "libspu/device/test_utils.h" +#include "libspu/dialect/pphlo_dialect.h" #include "libspu/kernel/test_util.h" #include "libspu/mpc/utils/simulate.h" diff --git a/libspu/device/test_utils.h b/libspu/device/test_utils.h index 549dabaa..f52ed909 100644 --- a/libspu/device/test_utils.h +++ b/libspu/device/test_utils.h @@ -48,7 +48,13 @@ class LocalIo { shares.push_back(st.getVar(name)); } - return io_client_.combineShares(shares); + auto pt_type = io_client_.getPtType(shares); + NdArrayRef ret(makePtType(pt_type), shares.front().shape()); + PtBufferView pv(ret.data(), pt_type, ret.shape(), ret.strides()); + + io_client_.combineShares(shares, &pv); + + return ret; } SymbolTable *GetSymbolTable(size_t idx) { return &symbol_tables_[idx]; } diff --git a/libspu/kernel/BUILD.bazel b/libspu/kernel/BUILD.bazel index c65da3a0..c1853e34 100644 --- a/libspu/kernel/BUILD.bazel +++ b/libspu/kernel/BUILD.bazel @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -load("//bazel:spu.bzl", "spu_cc_library", "spu_cc_test") +load("//bazel:spu.bzl", "spu_cc_library") package(default_visibility = ["//visibility:public"]) diff --git a/libspu/kernel/hal/BUILD.bazel b/libspu/kernel/hal/BUILD.bazel index bb0dc259..87926a9c 100644 --- a/libspu/kernel/hal/BUILD.bazel +++ b/libspu/kernel/hal/BUILD.bazel @@ -16,25 +16,6 @@ load("//bazel:spu.bzl", "spu_cc_library", "spu_cc_test") package(default_visibility = ["//visibility:public"]) -spu_cc_library( - name = "hal", - hdrs = ["hal.h"], - deps = [ - ":complex", - ":constants", - ":debug", - ":fxp", - ":integer", - ":polymorphic", - ":public_helper", - ":random", - ":shape_ops", - ":sort", - ":type_cast", - "//libspu/core:value", - ], -) - spu_cc_library( name = "prot_wrapper", srcs = ["prot_wrapper.cc"], @@ -150,15 +131,6 @@ spu_cc_test( ], ) -spu_cc_library( - name = "fxp", - hdrs = ["fxp.h"], - deps = [ - ":fxp_approx", - ":fxp_base", - ], -) - spu_cc_library( name = "constants", srcs = ["constants.cc"], @@ -191,6 +163,7 @@ spu_cc_library( "//libspu/core:context", "//libspu/core:encoding", "//libspu/core:value", + "//libspu/core:xt_helper", "//libspu/mpc/common:pv2k", # TODO: this is a bad reference ], ) @@ -212,7 +185,8 @@ spu_cc_library( srcs = ["polymorphic.cc"], hdrs = ["polymorphic.h"], deps = [ - ":fxp", + ":fxp_approx", + ":fxp_base", ":integer", ":shape_ops", ":type_cast", diff --git a/libspu/kernel/hal/constants.cc b/libspu/kernel/hal/constants.cc index c5c8bf04..1b44cbd0 100644 --- a/libspu/kernel/hal/constants.cc +++ b/libspu/kernel/hal/constants.cc @@ -17,8 +17,8 @@ #include "libspu/core/encoding.h" #include "libspu/core/ndarray_ref.h" #include "libspu/core/pt_buffer_view.h" +#include "libspu/core/trace.h" #include "libspu/core/type_util.h" -#include "libspu/kernel/hal/prot_wrapper.h" #include 
"libspu/kernel/hal/ring.h" #include "libspu/mpc/common/pv2k.h" @@ -31,13 +31,11 @@ namespace { Value make_pub2k(SPUContext* ctx, const PtBufferView& bv) { SPU_TRACE_HAL_DISP(ctx, bv); - NdArrayRef raw = convertToNdArray(bv); - const auto field = ctx->getField(); const auto fxp_bits = ctx->getFxpBits(); DataType dtype; - NdArrayRef encoded = encodeToRing(raw, field, fxp_bits, &dtype); + NdArrayRef encoded = encodeToRing(bv, field, fxp_bits, &dtype); return Value(encoded.as(makeType(field)), dtype); } diff --git a/libspu/kernel/hal/fxp.h b/libspu/kernel/hal/fxp.h deleted file mode 100644 index 1e0d78c3..00000000 --- a/libspu/kernel/hal/fxp.h +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright 2021 Ant Group Co., Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include "libspu/kernel/hal/fxp_approx.h" -#include "libspu/kernel/hal/fxp_base.h" diff --git a/libspu/kernel/hal/fxp_approx.cc b/libspu/kernel/hal/fxp_approx.cc index 8b309642..3e4ef849 100644 --- a/libspu/kernel/hal/fxp_approx.cc +++ b/libspu/kernel/hal/fxp_approx.cc @@ -19,6 +19,7 @@ #include #include +#include "libspu/core/trace.h" #include "libspu/kernel/hal/constants.h" #include "libspu/kernel/hal/fxp_base.h" #include "libspu/kernel/hal/fxp_cleartext.h" diff --git a/libspu/kernel/hal/fxp_approx.h b/libspu/kernel/hal/fxp_approx.h index 9a7648f5..56073810 100644 --- a/libspu/kernel/hal/fxp_approx.h +++ b/libspu/kernel/hal/fxp_approx.h @@ -14,9 +14,12 @@ #pragma once -#include "libspu/core/context.h" #include "libspu/core/value.h" +namespace spu { +class SPUContext; +} + // !!please read [README.md] for api naming conventions. namespace spu::kernel::hal { namespace detail { diff --git a/libspu/kernel/hal/fxp_base.cc b/libspu/kernel/hal/fxp_base.cc index 0a2ace18..a1e3feda 100644 --- a/libspu/kernel/hal/fxp_base.cc +++ b/libspu/kernel/hal/fxp_base.cc @@ -17,6 +17,7 @@ #include #include "libspu/core/prelude.h" +#include "libspu/core/trace.h" #include "libspu/kernel/hal/constants.h" #include "libspu/kernel/hal/fxp_cleartext.h" #include "libspu/kernel/hal/ring.h" diff --git a/libspu/kernel/hal/fxp_base.h b/libspu/kernel/hal/fxp_base.h index ba3c056b..df92c2d1 100644 --- a/libspu/kernel/hal/fxp_base.h +++ b/libspu/kernel/hal/fxp_base.h @@ -14,10 +14,12 @@ #pragma once -#include "libspu/core/context.h" -#include "libspu/core/pt_buffer_view.h" #include "libspu/core/value.h" +namespace spu { +class SPUContext; +} + // !!please read [README.md] for api naming conventions. 
namespace spu::kernel::hal { namespace detail { diff --git a/libspu/kernel/hal/fxp_cleartext.cc b/libspu/kernel/hal/fxp_cleartext.cc index 7dc5e936..af8cd97b 100644 --- a/libspu/kernel/hal/fxp_cleartext.cc +++ b/libspu/kernel/hal/fxp_cleartext.cc @@ -16,11 +16,32 @@ #include +#include "libspu/core/context.h" #include "libspu/core/encoding.h" +#include "libspu/core/trace.h" namespace spu::kernel::hal { namespace { +NdArrayRef encodeToRing(const NdArrayRef& src, FieldType field, size_t fxp_bits, + DataType* out_type) { + SPU_ENFORCE(src.eltype().isa(), "expect PtType, got={}", src.eltype()); + const PtType pt_type = src.eltype().as()->pt_type(); + PtBufferView pv(static_cast(src.data()), pt_type, src.shape(), + src.strides()); + return encodeToRing(pv, field, fxp_bits, out_type); +} + +NdArrayRef decodeFromRing(const NdArrayRef& src, DataType in_dtype, + size_t fxp_bits) { + const PtType pt_type = getDecodeType(in_dtype); + NdArrayRef dst(makePtType(pt_type), src.shape()); + PtBufferView pv(static_cast(dst.data()), pt_type, dst.shape(), + dst.strides()); + decodeFromRing(src, in_dtype, fxp_bits, &pv, nullptr); + return dst; +} + template Value applyFloatingPointFn(SPUContext* ctx, const Value& in, FN&& fn) { SPU_TRACE_HAL_DISP(ctx, in); @@ -33,7 +54,6 @@ Value applyFloatingPointFn(SPUContext* ctx, const Value& in, FN&& fn) { // decode to floating point auto f32_arr = decodeFromRing(in.data().as(ring_ty), in.dtype(), fxp_bits); - for (auto iter = f32_arr.begin(); iter != f32_arr.end(); ++iter) { auto* ptr = reinterpret_cast(&*iter); *ptr = fn(*ptr); @@ -41,7 +61,6 @@ Value applyFloatingPointFn(SPUContext* ctx, const Value& in, FN&& fn) { DataType dtype; const auto out = encodeToRing(f32_arr, field, fxp_bits, &dtype); - SPU_ENFORCE(dtype == DT_F32 || dtype == DT_F64, "sanity failed"); return Value(out.as(in.storage_type()), dtype); } diff --git a/libspu/kernel/hal/fxp_cleartext.h b/libspu/kernel/hal/fxp_cleartext.h index e9a75430..894921a6 100644 --- a/libspu/kernel/hal/fxp_cleartext.h +++ b/libspu/kernel/hal/fxp_cleartext.h @@ -14,9 +14,12 @@ #pragma once -#include "libspu/core/context.h" #include "libspu/core/value.h" +namespace spu { +class SPUContext; +} + // !!please read [README.md] for api naming conventions. namespace spu::kernel::hal { @@ -31,8 +34,8 @@ Value f_exp_p(SPUContext* ctx, const Value& in); Value f_div_p(SPUContext* ctx, const Value& x, const Value& y); -Value f_sine_p(SPUContext* ctx, const Value& x); +Value f_sine_p(SPUContext* ctx, const Value& in); -Value f_cosine_p(SPUContext* ctx, const Value& x); +Value f_cosine_p(SPUContext* ctx, const Value& in); } // namespace spu::kernel::hal diff --git a/libspu/kernel/hal/hal.h b/libspu/kernel/hal/hal.h deleted file mode 100644 index a635cf0a..00000000 --- a/libspu/kernel/hal/hal.h +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright 2021 Ant Group Co., Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#pragma once - -#include "libspu/kernel/hal/constants.h" -#include "libspu/kernel/hal/debug.h" -#include "libspu/kernel/hal/polymorphic.h" -#include "libspu/kernel/hal/public_helper.h" -#include "libspu/kernel/hal/random.h" -#include "libspu/kernel/hal/shape_ops.h" -#include "libspu/kernel/hal/type_cast.h" diff --git a/libspu/kernel/hal/integer.cc b/libspu/kernel/hal/integer.cc index e95dc444..696f1aec 100644 --- a/libspu/kernel/hal/integer.cc +++ b/libspu/kernel/hal/integer.cc @@ -14,7 +14,8 @@ #include "libspu/kernel/hal/integer.h" -#include "libspu/kernel/hal/prot_wrapper.h" +#include "libspu/core/context.h" +#include "libspu/core/trace.h" #include "libspu/kernel/hal/ring.h" namespace spu::kernel::hal { diff --git a/libspu/kernel/hal/integer.h b/libspu/kernel/hal/integer.h index 1c831104..12461365 100644 --- a/libspu/kernel/hal/integer.h +++ b/libspu/kernel/hal/integer.h @@ -14,9 +14,12 @@ #pragma once -#include "libspu/core/context.h" #include "libspu/core/value.h" +namespace spu { +class SPUContext; +} + namespace spu::kernel::hal { // !!please read [README.md] for api naming conventions. diff --git a/libspu/kernel/hal/polymorphic.cc b/libspu/kernel/hal/polymorphic.cc index 6d688c71..2542eb59 100644 --- a/libspu/kernel/hal/polymorphic.cc +++ b/libspu/kernel/hal/polymorphic.cc @@ -14,18 +14,13 @@ #include "libspu/kernel/hal/polymorphic.h" -#include "fmt/format.h" -#include "fmt/ostream.h" - #include "libspu/core/context.h" -#include "libspu/core/encoding.h" // for bitcast #include "libspu/core/prelude.h" #include "libspu/core/trace.h" -#include "libspu/kernel/hal/constants.h" -#include "libspu/kernel/hal/fxp.h" +#include "libspu/kernel/hal/fxp_approx.h" +#include "libspu/kernel/hal/fxp_base.h" #include "libspu/kernel/hal/integer.h" #include "libspu/kernel/hal/ring.h" // for fast fxp x int -#include "libspu/kernel/hal/shape_ops.h" #include "libspu/kernel/hal/type_cast.h" // TODO: handle dtype promotion inside integer dtypes. diff --git a/libspu/kernel/hal/polymorphic.h b/libspu/kernel/hal/polymorphic.h index 35673a51..5563eb29 100644 --- a/libspu/kernel/hal/polymorphic.h +++ b/libspu/kernel/hal/polymorphic.h @@ -14,9 +14,12 @@ #pragma once -#include "libspu/core/context.h" #include "libspu/core/value.h" +namespace spu { +class SPUContext; +} + namespace spu::kernel::hal { /// the element-wise absolute value function diff --git a/libspu/kernel/hal/prot_wrapper.cc b/libspu/kernel/hal/prot_wrapper.cc index d850703d..90bf27dc 100644 --- a/libspu/kernel/hal/prot_wrapper.cc +++ b/libspu/kernel/hal/prot_wrapper.cc @@ -15,11 +15,10 @@ #include "libspu/kernel/hal/prot_wrapper.h" #include -#include #include -#include "libspu/core/ndarray_ref.h" #include "libspu/core/prelude.h" +#include "libspu/core/trace.h" #include "libspu/core/type_util.h" #include "libspu/mpc/api.h" @@ -111,7 +110,11 @@ Value _trunc_s(SPUContext* ctx, const Value& in, size_t bits, SignType sign) { std::vector _sort_s(SPUContext* ctx, absl::Span x) { SPU_TRACE_HAL_DISP(ctx, x.size()); // FIXME(jimi): formalize mpc sort api - return dynDispatch>(ctx, "sort_a", x); + + // As pass absl::Span in dynDispatch is dangerous, we initialize a new vector + // here. And the copy of value is cheap, so it's ok. 
+ std::vector x_val(x.begin(), x.end()); + return dynDispatch>(ctx, "sort_a", x_val); } MAP_UNARY_OP(p2s) diff --git a/libspu/kernel/hal/prot_wrapper.h b/libspu/kernel/hal/prot_wrapper.h index 131f533e..0d769ce7 100644 --- a/libspu/kernel/hal/prot_wrapper.h +++ b/libspu/kernel/hal/prot_wrapper.h @@ -16,9 +16,12 @@ #include -#include "libspu/core/context.h" #include "libspu/core/value.h" +namespace spu { +class SPUContext; +} + namespace spu::kernel::hal { // NOLINTBEGIN(readability-identifier-naming) diff --git a/libspu/kernel/hal/public_helper.cc b/libspu/kernel/hal/public_helper.cc index 9a3fd0a2..25fc749f 100644 --- a/libspu/kernel/hal/public_helper.cc +++ b/libspu/kernel/hal/public_helper.cc @@ -14,8 +14,10 @@ #include "libspu/kernel/hal/public_helper.h" +#include "libspu/core/context.h" #include "libspu/core/encoding.h" #include "libspu/core/ndarray_ref.h" +#include "libspu/core/trace.h" #include "libspu/mpc/common/pv2k.h" namespace spu::kernel::hal { @@ -26,7 +28,14 @@ NdArrayRef dump_public(SPUContext* ctx, const Value& v) { const auto field = v.storage_type().as()->field(); auto encoded = v.data().as(makeType(field)); - return decodeFromRing(encoded, v.dtype(), ctx->getFxpBits()); + const PtType pt_type = getDecodeType(v.dtype()); + NdArrayRef dst(makePtType(pt_type), v.shape()); + PtBufferView pv(static_cast(dst.data()), pt_type, dst.shape(), + dst.strides()); + + decodeFromRing(encoded, v.dtype(), ctx->getFxpBits(), &pv); + + return dst; } bool getBooleanValue(SPUContext* ctx, const spu::Value& value) { diff --git a/libspu/kernel/hal/public_helper.h b/libspu/kernel/hal/public_helper.h index 163fc2e2..2973d44e 100644 --- a/libspu/kernel/hal/public_helper.h +++ b/libspu/kernel/hal/public_helper.h @@ -14,10 +14,13 @@ #pragma once -#include "libspu/core/context.h" #include "libspu/core/value.h" #include "libspu/core/xt_helper.h" +namespace spu { +class SPUContext; +} + namespace spu::kernel::hal { // Export a value to a buffer. 
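For reference (not part of the patch): a minimal sketch of the writable PtBufferView API that the encoding and public_helper hunks above now rely on, where decodeFromRing fills a caller-supplied view instead of returning a fresh NdArrayRef. The explicit <float> template arguments below are assumptions, since the angle-bracketed template parameters are stripped in the patch text shown here, and the buffer contents are illustrative only.

// Sketch: write into an xt::xarray through the new writable view.
#include "xtensor/xarray.hpp"

#include "libspu/core/pt_buffer_view.h"

void writable_view_sketch() {
  xt::xarray<float> buf = {{0.0F, 0.0F}, {0.0F, 0.0F}};

  // Binding a non-const tensor selects the new constructor that sets
  // write_able == true, so set() is permitted on this view.
  spu::PtBufferView pv(buf);

  pv.set({0, 1}, 3.14F);            // write through the view
  float v = pv.get<float>({0, 1});  // read the same element back
  (void)v;                          // buf(0, 1) now holds 3.14F
}

This is presumably what lets decodeFromRing and IoClient::combineShares decode directly into externally owned storage (see the io_test.cc hunk earlier in the patch, which decodes into an xt::xarray) without an intermediate copy.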
diff --git a/libspu/kernel/hal/random.cc b/libspu/kernel/hal/random.cc index 6c41f1fd..c614d191 100644 --- a/libspu/kernel/hal/random.cc +++ b/libspu/kernel/hal/random.cc @@ -14,8 +14,10 @@ #include "libspu/kernel/hal/random.h" +#include + #include "libspu/core/prelude.h" -#include "libspu/core/xt_helper.h" +#include "libspu/core/trace.h" #include "libspu/kernel/hal/constants.h" #include "libspu/kernel/hal/prot_wrapper.h" #include "libspu/kernel/hal/public_helper.h" diff --git a/libspu/kernel/hal/random.h b/libspu/kernel/hal/random.h index bd87d9a8..d355653c 100644 --- a/libspu/kernel/hal/random.h +++ b/libspu/kernel/hal/random.h @@ -14,9 +14,12 @@ #pragma once -#include "libspu/core/context.h" #include "libspu/core/value.h" +namespace spu { +class SPUContext; +} + namespace spu::kernel::hal { /// Uniform rand diff --git a/libspu/kernel/hal/ring.cc b/libspu/kernel/hal/ring.cc index 4b333d8f..d3cc3cf2 100644 --- a/libspu/kernel/hal/ring.cc +++ b/libspu/kernel/hal/ring.cc @@ -14,11 +14,13 @@ #include "libspu/kernel/hal/ring.h" -#include #include +#include #include "libspu/core/bit_utils.h" +#include "libspu/core/context.h" #include "libspu/core/prelude.h" +#include "libspu/core/trace.h" #include "libspu/kernel/hal/prot_wrapper.h" #include "libspu/kernel/hal/shape_ops.h" diff --git a/libspu/kernel/hal/ring.h b/libspu/kernel/hal/ring.h index ef18a578..da901655 100644 --- a/libspu/kernel/hal/ring.h +++ b/libspu/kernel/hal/ring.h @@ -14,9 +14,12 @@ #pragma once -#include "libspu/core/context.h" #include "libspu/core/value.h" +namespace spu { +class SPUContext; +} + // !!please read [README.md] for api naming conventions. // this module implements ops x ring 2k space WITHOUT dtype check. // diff --git a/libspu/kernel/hal/shape_ops.cc b/libspu/kernel/hal/shape_ops.cc index fd0adb01..7bf96065 100644 --- a/libspu/kernel/hal/shape_ops.cc +++ b/libspu/kernel/hal/shape_ops.cc @@ -16,7 +16,9 @@ #include +#include "libspu/core/context.h" #include "libspu/core/ndarray_ref.h" +#include "libspu/core/trace.h" #include "libspu/kernel/hal/prot_wrapper.h" namespace spu::kernel::hal { @@ -53,6 +55,18 @@ Value _cast_type(SPUContext* ctx, const Value& x, const Type& to) { } } +// Compact threshold heuristic, try to make it same as L1 cache size +#define COMPACT_THRESHOLD (32 * 1024) // 32K + +SPU_ALWAYS_INLINE NdArrayRef _try_compact(const NdArrayRef& in) { + // If in data is not compact after some shape ops and small enough, make it + // compact + if (in.numel() * in.elsize() <= COMPACT_THRESHOLD && !in.isCompact()) { + return in.clone(); + } + return in; +} + } // namespace Value transpose(SPUContext* ctx, const Value& in, const Axes& permutation) { @@ -77,15 +91,16 @@ Value transpose(SPUContext* ctx, const Value& in, const Axes& permutation) { return in; } - return Value(in.data().transpose(perm), in.dtype()); + return Value(_try_compact(in.data().transpose(perm)), in.dtype()); } Value slice(SPUContext* ctx, const Value& in, const Index& start_indices, const Index& end_indices, const Strides& strides) { SPU_TRACE_HAL_DISP(ctx, in, start_indices, end_indices, strides); - return Value(in.data().slice(start_indices, end_indices, strides), - in.dtype()); + return Value( + _try_compact(in.data().slice(start_indices, end_indices, strides)), + in.dtype()); } Value slice_scalar_at(SPUContext*, const Value& input, const Index& indices) { @@ -109,7 +124,7 @@ Value update_slice(SPUContext* ctx, const Value& in, const Value& update, Value reshape(SPUContext* ctx, const Value& in, const Shape& to_shape) { 
SPU_TRACE_HAL_DISP(ctx, in, to_shape); - return Value(in.data().reshape(to_shape), in.dtype()); + return Value(_try_compact(in.data().reshape(to_shape)), in.dtype()); } Value broadcast_to(SPUContext* ctx, const Value& in, const Shape& to_shape, diff --git a/libspu/kernel/hal/shape_ops.h b/libspu/kernel/hal/shape_ops.h index ab9ad6e6..61114cd2 100644 --- a/libspu/kernel/hal/shape_ops.h +++ b/libspu/kernel/hal/shape_ops.h @@ -16,9 +16,12 @@ #include -#include "libspu/core/context.h" #include "libspu/core/value.h" +namespace spu { +class SPUContext; +} + namespace spu::kernel::hal { /// the broadcast function diff --git a/libspu/kernel/hal/sort.cc b/libspu/kernel/hal/sort.cc index de38aa1c..586f252f 100644 --- a/libspu/kernel/hal/sort.cc +++ b/libspu/kernel/hal/sort.cc @@ -2,6 +2,7 @@ #include +#include "libspu/core/context.h" #include "libspu/kernel/hal/polymorphic.h" #include "libspu/kernel/hal/prot_wrapper.h" #include "libspu/kernel/hal/public_helper.h" @@ -136,7 +137,7 @@ std::vector sort1d(SPUContext *ctx, "Inputs should be 1-d but actually have {} dimensions", inputs[0].shape().ndim()); SPU_ENFORCE(std::all_of(inputs.begin(), inputs.end(), - [&inputs](const spu::Value v) { + [&inputs](const spu::Value &v) { return v.shape() == inputs[0].shape(); }), "Inputs shape mismatched"); @@ -148,9 +149,9 @@ std::vector sort1d(SPUContext *ctx, auto comparator = [&cmp, &inputs, &ctx](int64_t a, int64_t b) { std::vector values; values.reserve(2 * inputs.size()); - for (int64_t i = 0; i < static_cast(inputs.size()); ++i) { - values.push_back(hal::slice(ctx, inputs[i], {a}, {a + 1})); - values.push_back(hal::slice(ctx, inputs[i], {b}, {b + 1})); + for (const auto &input : inputs) { + values.push_back(hal::slice(ctx, input, {a}, {a + 1})); + values.push_back(hal::slice(ctx, input, {b}, {b + 1})); } spu::Value cmp_ret = cmp(values); return getBooleanValue(ctx, cmp_ret); @@ -164,8 +165,8 @@ std::vector sort1d(SPUContext *ctx, } ret.reserve(inputs.size()); - for (int64_t i = 0; i < static_cast(inputs.size()); ++i) { - ret.push_back(Permute1D(ctx, inputs[i], indices_to_sort)); + for (const auto &input : inputs) { + ret.push_back(Permute1D(ctx, input, indices_to_sort)); } } else { SPU_ENFORCE(!is_stable, diff --git a/libspu/kernel/hal/sort.h b/libspu/kernel/hal/sort.h index b0000905..dac88581 100644 --- a/libspu/kernel/hal/sort.h +++ b/libspu/kernel/hal/sort.h @@ -2,9 +2,12 @@ #include "absl/types/span.h" -#include "libspu/core/context.h" #include "libspu/core/value.h" +namespace spu { +class SPUContext; +} + namespace spu::kernel::hal { using CompFn = std::function)>; diff --git a/libspu/kernel/hal/type_cast.cc b/libspu/kernel/hal/type_cast.cc index 582035c4..8ff37ac7 100644 --- a/libspu/kernel/hal/type_cast.cc +++ b/libspu/kernel/hal/type_cast.cc @@ -14,8 +14,9 @@ #include "libspu/kernel/hal/type_cast.h" +#include "libspu/core/context.h" +#include "libspu/core/trace.h" #include "libspu/core/type_util.h" -#include "libspu/kernel/hal/constants.h" #include "libspu/kernel/hal/prot_wrapper.h" // vtype_cast #include "libspu/kernel/hal/ring.h" diff --git a/libspu/kernel/hal/type_cast.h b/libspu/kernel/hal/type_cast.h index e6e10bbf..cfcc05fb 100644 --- a/libspu/kernel/hal/type_cast.h +++ b/libspu/kernel/hal/type_cast.h @@ -14,9 +14,12 @@ #pragma once -#include "libspu/core/context.h" #include "libspu/core/value.h" +namespace spu { +class SPUContext; +} + namespace spu::kernel::hal { /// cast dtype diff --git a/libspu/kernel/hlo/BUILD.bazel b/libspu/kernel/hlo/BUILD.bazel index 4a8192e7..139a5c9f 100644 --- 
a/libspu/kernel/hlo/BUILD.bazel +++ b/libspu/kernel/hlo/BUILD.bazel @@ -16,33 +16,15 @@ load("//bazel:spu.bzl", "spu_cc_library", "spu_cc_test") package(default_visibility = ["//visibility:public"]) -spu_cc_library( - name = "hlo", - deps = [ - ":basic_binary", - ":basic_ternary", - ":basic_unary", - ":casting", - ":const", - ":control_flow", - ":convolution", - ":geometrical", - ":indexing", - ":rand", - ":reduce", - ":select_and_scatter", - ":shift", - ":sort", - ], -) - spu_cc_library( name = "basic_binary", srcs = ["basic_binary.cc"], hdrs = ["basic_binary.h"], deps = [ ":utils", - "//libspu/kernel/hal", + "//libspu/kernel/hal:complex", + "//libspu/kernel/hal:constants", + "//libspu/kernel/hal:polymorphic", ], ) @@ -51,9 +33,9 @@ spu_cc_test( srcs = ["basic_binary_test.cc"], deps = [ ":basic_binary", + ":casting", ":const", "//libspu/kernel:test_util", - "//libspu/kernel/hal", "//libspu/mpc/utils:simulate", ], ) @@ -63,8 +45,12 @@ spu_cc_library( srcs = ["basic_ternary.cc"], hdrs = ["basic_ternary.h"], deps = [ + ":casting", + ":const", ":utils", - "//libspu/kernel/hal", + "//libspu/kernel/hal:complex", + "//libspu/kernel/hal:polymorphic", + "//libspu/kernel/hal:ring", ], ) @@ -75,7 +61,6 @@ spu_cc_test( ":basic_ternary", ":const", "//libspu/kernel:test_util", - "//libspu/kernel/hal", "//libspu/mpc/utils:simulate", ], ) @@ -86,7 +71,10 @@ spu_cc_library( hdrs = ["basic_unary.h"], deps = [ ":utils", - "//libspu/kernel/hal", + "//libspu/kernel/hal:complex", + "//libspu/kernel/hal:constants", + "//libspu/kernel/hal:polymorphic", + "//libspu/kernel/hal:type_cast", ], ) @@ -95,9 +83,9 @@ spu_cc_test( srcs = ["basic_unary_test.cc"], deps = [ ":basic_unary", + ":casting", ":const", "//libspu/kernel:test_util", - "//libspu/kernel/hal", "//libspu/mpc/utils:simulate", ], ) @@ -108,7 +96,9 @@ spu_cc_library( hdrs = ["casting.h"], deps = [ ":utils", - "//libspu/kernel/hal", + "//libspu/kernel/hal:complex", + "//libspu/kernel/hal:polymorphic", + "//libspu/kernel/hal:type_cast", ], ) @@ -127,9 +117,10 @@ spu_cc_library( srcs = ["const.cc"], hdrs = ["const.h"], deps = [ - ":casting", ":utils", - "//libspu/kernel/hal", + "//libspu/core:encoding", + "//libspu/kernel/hal:constants", + "//libspu/kernel/hal:shape_ops", ], ) @@ -149,7 +140,11 @@ spu_cc_library( deps = [ ":const", ":utils", - "//libspu/kernel/hal", + "//libspu/kernel/hal:constants", + "//libspu/kernel/hal:polymorphic", + "//libspu/kernel/hal:public_helper", + "//libspu/kernel/hal:shape_ops", + "//libspu/kernel/hal:type_cast", ], ) @@ -158,8 +153,8 @@ spu_cc_library( srcs = ["convolution.cc"], hdrs = ["convolution.h"], deps = [ - ":utils", - "//libspu/kernel/hal", + "//libspu/kernel/hal:polymorphic", + "//libspu/kernel/hal:shape_ops", ], ) @@ -168,14 +163,14 @@ spu_cc_library( srcs = ["indexing.cc"], hdrs = ["indexing.h"], deps = [ - ":basic_binary", - ":basic_ternary", ":basic_unary", ":const", - ":geometrical", - ":reduce", ":utils", - "//libspu/kernel/hal", + "//libspu/kernel/hal:constants", + "//libspu/kernel/hal:polymorphic", + "//libspu/kernel/hal:ring", + "//libspu/kernel/hal:shape_ops", + "//libspu/kernel/hal:type_cast", ], ) @@ -183,6 +178,7 @@ spu_cc_test( name = "indexing_test", srcs = ["indexing_test.cc"], deps = [ + ":casting", ":indexing", "//libspu/kernel:test_util", ], @@ -193,9 +189,8 @@ spu_cc_library( srcs = ["geometrical.cc"], hdrs = ["geometrical.h"], deps = [ - ":basic_binary", - ":utils", - "//libspu/kernel/hal", + "//libspu/kernel/hal:complex", + "//libspu/kernel/hal:shape_ops", ], ) @@ -203,6 +198,8 @@ spu_cc_test( 
name = "geometrical_test", srcs = ["geometrical_test.cc"], deps = [ + ":basic_binary", + ":casting", ":const", ":geometrical", "//libspu/kernel:test_util", @@ -215,8 +212,7 @@ spu_cc_library( srcs = ["rand.cc"], hdrs = ["rand.h"], deps = [ - ":utils", - "//libspu/kernel/hal", + "//libspu/kernel/hal:random", ], ) @@ -225,9 +221,11 @@ spu_cc_library( srcs = ["reduce.cc"], hdrs = ["reduce.h"], deps = [ - ":geometrical", ":utils", - "//libspu/kernel/hal", + "//libspu/kernel/hal:constants", + "//libspu/kernel/hal:polymorphic", + "//libspu/kernel/hal:ring", + "//libspu/kernel/hal:shape_ops", ], ) @@ -239,7 +237,9 @@ spu_cc_library( ":const", ":reduce", ":utils", - "//libspu/kernel/hal", + "//libspu/kernel/hal:constants", + "//libspu/kernel/hal:polymorphic", + "//libspu/kernel/hal:shape_ops", ], ) @@ -258,7 +258,9 @@ spu_cc_library( hdrs = ["shift.h"], deps = [ ":utils", - "//libspu/kernel/hal", + "//libspu/kernel/hal:constants", + "//libspu/kernel/hal:polymorphic", + "//libspu/kernel/hal:public_helper", ], ) @@ -270,7 +272,8 @@ spu_cc_library( ":basic_binary", ":casting", ":utils", - "//libspu/kernel/hal", + "//libspu/kernel/hal:shape_ops", + "//libspu/kernel/hal:sort", ], ) @@ -289,7 +292,8 @@ spu_cc_library( srcs = ["utils.cc"], hdrs = ["utils.h"], deps = [ - "//libspu/kernel/hal", + "//libspu/kernel/hal:public_helper", + "//libspu/kernel/hal:shape_ops", ], ) @@ -307,7 +311,8 @@ spu_cc_library( hdrs = ["shuffle.h"], deps = [ ":sort", - "//libspu/kernel/hal", + "//libspu/kernel/hal:polymorphic", + "//libspu/kernel/hal:random", ], ) diff --git a/libspu/kernel/hlo/basic_binary.cc b/libspu/kernel/hlo/basic_binary.cc index 7e0d165c..1ca5bfc2 100644 --- a/libspu/kernel/hlo/basic_binary.cc +++ b/libspu/kernel/hlo/basic_binary.cc @@ -16,9 +16,7 @@ #include "libspu/kernel/hal/complex.h" #include "libspu/kernel/hal/constants.h" -#include "libspu/kernel/hal/debug.h" #include "libspu/kernel/hal/polymorphic.h" -#include "libspu/kernel/hal/type_cast.h" namespace spu::kernel::hlo { diff --git a/libspu/kernel/hlo/basic_binary_test.cc b/libspu/kernel/hlo/basic_binary_test.cc index f90c1c01..c55c6256 100644 --- a/libspu/kernel/hlo/basic_binary_test.cc +++ b/libspu/kernel/hlo/basic_binary_test.cc @@ -17,12 +17,10 @@ #include "gtest/gtest.h" #include "libspu/core/context.h" -#include "libspu/core/ndarray_ref.h" #include "libspu/core/value.h" #include "libspu/kernel/hlo/casting.h" #include "libspu/kernel/hlo/const.h" #include "libspu/kernel/test_util.h" -#include "libspu/mpc/factory.h" #include "libspu/mpc/utils/simulate.h" namespace spu::kernel::hlo { diff --git a/libspu/kernel/hlo/basic_ternary_test.cc b/libspu/kernel/hlo/basic_ternary_test.cc index 64972edd..5b8c7879 100644 --- a/libspu/kernel/hlo/basic_ternary_test.cc +++ b/libspu/kernel/hlo/basic_ternary_test.cc @@ -17,7 +17,6 @@ #include "gtest/gtest.h" #include "libspu/core/context.h" -#include "libspu/core/ndarray_ref.h" #include "libspu/core/value.h" #include "libspu/kernel/hlo/casting.h" #include "libspu/kernel/hlo/const.h" diff --git a/libspu/kernel/hlo/basic_unary.cc b/libspu/kernel/hlo/basic_unary.cc index f6f694ca..f819cf0c 100644 --- a/libspu/kernel/hlo/basic_unary.cc +++ b/libspu/kernel/hlo/basic_unary.cc @@ -14,8 +14,6 @@ #include "libspu/kernel/hlo/basic_unary.h" -#include "libspu/core/context.h" -#include "libspu/core/value.h" #include "libspu/kernel/hal/complex.h" #include "libspu/kernel/hal/constants.h" #include "libspu/kernel/hal/polymorphic.h" diff --git a/libspu/kernel/hlo/basic_unary.h b/libspu/kernel/hlo/basic_unary.h index 
3be84357..46578ee7 100644 --- a/libspu/kernel/hlo/basic_unary.h +++ b/libspu/kernel/hlo/basic_unary.h @@ -14,7 +14,11 @@ #pragma once -#include "libspu/kernel/hal/hal.h" +#include "libspu/core/value.h" + +namespace spu { +class SPUContext; +} namespace spu::kernel::hlo { diff --git a/libspu/kernel/hlo/basic_unary_test.cc b/libspu/kernel/hlo/basic_unary_test.cc index 677cb204..7e4d2e0a 100644 --- a/libspu/kernel/hlo/basic_unary_test.cc +++ b/libspu/kernel/hlo/basic_unary_test.cc @@ -17,9 +17,7 @@ #include "gtest/gtest.h" #include "libspu/core/context.h" -#include "libspu/core/ndarray_ref.h" #include "libspu/core/value.h" -#include "libspu/kernel/hlo/casting.h" #include "libspu/kernel/hlo/const.h" #include "libspu/kernel/test_util.h" #include "libspu/mpc/utils/simulate.h" diff --git a/libspu/kernel/hlo/casting.h b/libspu/kernel/hlo/casting.h index 899512ae..469ba42d 100644 --- a/libspu/kernel/hlo/casting.h +++ b/libspu/kernel/hlo/casting.h @@ -14,7 +14,11 @@ #pragma once -#include "libspu/kernel/hlo/utils.h" +#include "libspu/core/value.h" + +namespace spu { +class SPUContext; +} namespace spu::kernel::hlo { diff --git a/libspu/kernel/hlo/const.h b/libspu/kernel/hlo/const.h index 0035e528..f9b458e7 100644 --- a/libspu/kernel/hlo/const.h +++ b/libspu/kernel/hlo/const.h @@ -14,11 +14,12 @@ #pragma once -#include - #include "libspu/core/pt_buffer_view.h" -#include "libspu/kernel/hlo/casting.h" -#include "libspu/kernel/hlo/utils.h" +#include "libspu/core/value.h" + +namespace spu { +class SPUContext; +} namespace spu::kernel::hlo { diff --git a/libspu/kernel/hlo/control_flow.cc b/libspu/kernel/hlo/control_flow.cc index 6fe472cb..c7c4a6df 100644 --- a/libspu/kernel/hlo/control_flow.cc +++ b/libspu/kernel/hlo/control_flow.cc @@ -15,12 +15,12 @@ #include "libspu/kernel/hlo/control_flow.h" #include "libspu/kernel/hal/constants.h" -#include "libspu/kernel/hal/debug.h" #include "libspu/kernel/hal/polymorphic.h" #include "libspu/kernel/hal/public_helper.h" #include "libspu/kernel/hal/shape_ops.h" #include "libspu/kernel/hal/type_cast.h" #include "libspu/kernel/hlo/const.h" +#include "libspu/kernel/hlo/utils.h" // Allow runtime to reveal `secret variable` use as while // condition result, debug purpose only. 
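The BUILD and header hunks above (and throughout this patch) apply one consistent pattern: coarse umbrella targets such as "//libspu/kernel/hal" and "//libspu/kernel/hlo" are replaced by fine-grained dependencies, and public headers drop #include "libspu/core/context.h" in favor of a forward declaration, pulling the full definition in only from the .cc file (the "// IWYU pragma: keep" annotations elsewhere in the patch point to the same include-what-you-use cleanup). A schematic sketch of the header/source split is below; my_kernel.h and my_kernel() are hypothetical names used purely for illustration.

// my_kernel.h (hypothetical): the interface only needs Value and a
// forward-declared SPUContext, not the full context header.
#pragma once

#include "libspu/core/value.h"

namespace spu {
class SPUContext;
}

namespace spu::kernel::hal {
Value my_kernel(SPUContext* ctx, const Value& in);
}  // namespace spu::kernel::hal

// my_kernel.cc (hypothetical): the definition includes the full headers.
#include "my_kernel.h"  // hypothetical path

#include "libspu/core/context.h"
#include "libspu/core/trace.h"

namespace spu::kernel::hal {
Value my_kernel(SPUContext* ctx, const Value& in) {
  SPU_TRACE_HAL_DISP(ctx, in);  // tracing needs trace.h, hence the .cc include
  return in;                    // placeholder body
}
}  // namespace spu::kernel::hal

The practical effect should be smaller transitive include sets and finer-grained Bazel rebuilds when core headers change.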
diff --git a/libspu/kernel/hlo/control_flow.h b/libspu/kernel/hlo/control_flow.h index c829beab..476d7f48 100644 --- a/libspu/kernel/hlo/control_flow.h +++ b/libspu/kernel/hlo/control_flow.h @@ -14,9 +14,12 @@ #pragma once -#include "libspu/core/context.h" #include "libspu/core/value.h" +namespace spu { +class SPUContext; +} + namespace spu::kernel::hlo { using BranchFcnT = std::function()>; diff --git a/libspu/kernel/hlo/convolution.cc b/libspu/kernel/hlo/convolution.cc index 4d6801f7..994c6381 100644 --- a/libspu/kernel/hlo/convolution.cc +++ b/libspu/kernel/hlo/convolution.cc @@ -14,14 +14,9 @@ #include "libspu/kernel/hlo/convolution.h" -#include "libspu/core/context.h" #include "libspu/core/value.h" -#include "libspu/kernel/hal/constants.h" #include "libspu/kernel/hal/polymorphic.h" -#include "libspu/kernel/hal/ring.h" #include "libspu/kernel/hal/shape_ops.h" -#include "libspu/kernel/hal/type_cast.h" -#include "libspu/kernel/hlo/utils.h" namespace spu::kernel::hlo { diff --git a/libspu/kernel/hlo/convolution.h b/libspu/kernel/hlo/convolution.h index ce9e743e..14fa4638 100644 --- a/libspu/kernel/hlo/convolution.h +++ b/libspu/kernel/hlo/convolution.h @@ -14,7 +14,11 @@ #pragma once -#include "libspu/kernel/hlo/utils.h" +#include "libspu/core/value.h" + +namespace spu { +class SPUContext; +} namespace spu::kernel::hlo { diff --git a/libspu/kernel/hlo/geometrical.h b/libspu/kernel/hlo/geometrical.h index 533515cf..677bc4f0 100644 --- a/libspu/kernel/hlo/geometrical.h +++ b/libspu/kernel/hlo/geometrical.h @@ -14,7 +14,12 @@ #pragma once -#include "libspu/kernel/hlo/utils.h" +#include "libspu/core/value.h" + +namespace spu { +class SPUContext; +} + namespace spu::kernel::hlo { spu::Value Transpose(SPUContext *ctx, const spu::Value &in, diff --git a/libspu/kernel/hlo/geometrical_test.cc b/libspu/kernel/hlo/geometrical_test.cc index be1f5f3c..15fe4548 100644 --- a/libspu/kernel/hlo/geometrical_test.cc +++ b/libspu/kernel/hlo/geometrical_test.cc @@ -17,9 +17,9 @@ #include "gtest/gtest.h" #include "libspu/core/context.h" -#include "libspu/core/ndarray_ref.h" #include "libspu/core/value.h" #include "libspu/kernel/hlo/basic_binary.h" +#include "libspu/kernel/hlo/casting.h" #include "libspu/kernel/hlo/const.h" #include "libspu/kernel/test_util.h" #include "libspu/mpc/utils/simulate.h" diff --git a/libspu/kernel/hlo/indexing.cc b/libspu/kernel/hlo/indexing.cc index 149ffaa9..d3ef4388 100644 --- a/libspu/kernel/hlo/indexing.cc +++ b/libspu/kernel/hlo/indexing.cc @@ -20,14 +20,13 @@ #include "libspu/core/ndarray_ref.h" #include "libspu/core/value.h" -#include "libspu/kernel/hal/hal.h" +#include "libspu/kernel/hal/constants.h" +#include "libspu/kernel/hal/polymorphic.h" #include "libspu/kernel/hal/ring.h" -#include "libspu/kernel/hlo/basic_binary.h" -#include "libspu/kernel/hlo/basic_ternary.h" +#include "libspu/kernel/hal/shape_ops.h" +#include "libspu/kernel/hal/type_cast.h" #include "libspu/kernel/hlo/basic_unary.h" #include "libspu/kernel/hlo/const.h" -#include "libspu/kernel/hlo/geometrical.h" -#include "libspu/kernel/hlo/reduce.h" #include "libspu/kernel/hlo/utils.h" // forward @@ -323,22 +322,22 @@ spu::Value SecretLinearUpdateIndexing(spu::SPUContext *ctx, // Basic idea here: // eq(iota, idx) * update + !eq(iota, idx) * operand auto linear_idx_broadcasted = - spu::kernel::hlo::Broadcast(ctx, linear_idx, {operand.numel()}, {}); + spu::kernel::hal::broadcast_to(ctx, linear_idx, {operand.numel()}, {}); spu::Value idx_iota = - spu::kernel::hlo::Iota(ctx, spu::DT_I64, operand.numel()); - auto mask 
= spu::kernel::hlo::Equal(ctx, linear_idx_broadcasted, idx_iota); + spu::kernel::hal::iota(ctx, spu::DT_I64, operand.numel()); + auto mask = spu::kernel::hal::equal(ctx, linear_idx_broadcasted, idx_iota); auto c0 = spu::kernel::hlo::Constant(ctx, static_cast(0), {}); - auto i0 = spu::kernel::hlo::Cast(ctx, c0, c0.vtype(), operand.dtype()); + auto i0 = spu::kernel::hal::dtype_cast(ctx, c0, operand.dtype()); auto reverse_mask = spu::kernel::hlo::Not(ctx, mask); auto broadcast_update = - spu::kernel::hlo::Broadcast(ctx, update, operand.shape(), {0}); + spu::kernel::hal::broadcast_to(ctx, update, operand.shape(), {0}); - return spu::kernel::hlo::Add( - ctx, spu::kernel::hlo::Mul(ctx, operand, reverse_mask), - spu::kernel::hlo::Mul(ctx, broadcast_update, mask)); + return spu::kernel::hal::add( + ctx, spu::kernel::hal::mul(ctx, operand, reverse_mask), + spu::kernel::hal::mul(ctx, broadcast_update, mask)); } std::vector ClampAndFlattenIndex( @@ -354,15 +353,15 @@ std::vector ClampAndFlattenIndex( std::transform(start_indices.cbegin(), start_indices.cend(), std::back_inserter(reshaped_start_indices), [&](const spu::Value &x) { - return spu::kernel::hlo::Reshape(ctx, x, {1}); + return spu::kernel::hal::reshape(ctx, x, {1}); }); auto concat_idx = - spu::kernel::hlo::Concatenate(ctx, reshaped_start_indices, 0); + spu::kernel::hal::concatenate(ctx, reshaped_start_indices, 0); auto lower_bound = spu::kernel::hlo::Constant(ctx, static_cast(0), concat_idx.shape()); - lower_bound = spu::kernel::hlo::Cast(ctx, lower_bound, lower_bound.vtype(), - concat_idx.dtype()); + lower_bound = + spu::kernel::hal::dtype_cast(ctx, lower_bound, concat_idx.dtype()); std::vector upper_bound_pt(start_indices.size()); for (size_t idx = 0; idx < upper_bound_pt.size(); ++idx) { @@ -370,14 +369,14 @@ std::vector ClampAndFlattenIndex( } auto upper_bound = spu::kernel::hlo::Constant(ctx, upper_bound_pt, concat_idx.shape()); - upper_bound = spu::kernel::hlo::Cast(ctx, upper_bound, upper_bound.vtype(), - concat_idx.dtype()); + upper_bound = + spu::kernel::hal::dtype_cast(ctx, upper_bound, concat_idx.dtype()); - auto c = spu::kernel::hlo::Clamp(ctx, concat_idx, lower_bound, upper_bound); + auto c = spu::kernel::hal::clamp(ctx, concat_idx, lower_bound, upper_bound); for (int64_t idx = 0; idx < static_cast(clamped_start.size()); ++idx) { - clamped_start[idx] = spu::kernel::hlo::Reshape( - ctx, spu::kernel::hlo::Slice(ctx, c, {idx}, {idx + 1}, {1}), {}); + clamped_start[idx] = spu::kernel::hal::reshape( + ctx, spu::kernel::hal::slice(ctx, c, {idx}, {idx + 1}, {1}), {}); } } @@ -386,9 +385,9 @@ std::vector ClampAndFlattenIndex( spu::kernel::hlo::Constant(ctx, static_cast(0), {}); int64_t stride = 1; for (int64_t idx = iterate_shape.size() - 1; idx >= 0; --idx) { - linear_idx = spu::kernel::hlo::Add( + linear_idx = spu::kernel::hal::add( ctx, linear_idx, - spu::kernel::hlo::Mul(ctx, clamped_start[idx], + spu::kernel::hal::mul(ctx, clamped_start[idx], spu::kernel::hlo::Constant(ctx, stride, {}))); stride *= limit_shape[idx]; } @@ -407,15 +406,15 @@ std::vector ClampAndFlattenIndex( auto num_index = iterate_shape.numel(); std::vector linear_indices; linear_indices.reserve(num_index); - auto added = spu::kernel::hlo::Add( + auto added = spu::kernel::hal::add( ctx, - spu::kernel::hlo::Broadcast( - ctx, spu::kernel::hlo::Reshape(ctx, linear_idx, {1}), {num_index}, + spu::kernel::hal::broadcast_to( + ctx, spu::kernel::hal::reshape(ctx, linear_idx, {1}), {num_index}, {0}), spu::kernel::hlo::Constant(ctx, flatten_idx, {num_index})); for (int64_t 
idx = 0; idx < num_index; ++idx) { - linear_indices.emplace_back(spu::kernel::hlo::Reshape( - ctx, spu::kernel::hlo::Slice(ctx, added, {idx}, {idx + 1}, {1}), {})); + linear_indices.emplace_back(spu::kernel::hal::reshape( + ctx, spu::kernel::hal::slice(ctx, added, {idx}, {idx + 1}, {1}), {})); } return linear_indices; } @@ -546,7 +545,7 @@ spu::Value DynamicUpdateSlice(SPUContext *ctx, const spu::Value &operand, spu::Value flattened_operand = hal::reshape(ctx, operand, {operand.numel()}); - spu::Value flattened_update = Reshape(ctx, update, {update.numel()}); + spu::Value flattened_update = hal::reshape(ctx, update, {update.numel()}); auto flattened_indices = ClampAndFlattenIndex( ctx, start_indices, update.shape(), operand.shape()); @@ -555,12 +554,12 @@ spu::Value DynamicUpdateSlice(SPUContext *ctx, const spu::Value &operand, for (int64_t n = 0; n < static_cast(flattened_indices.size()); ++n) { - auto update_slice = Slice(ctx, flattened_update, {n}, {n + 1}, {1}); + auto update_slice = hal::slice(ctx, flattened_update, {n}, {n + 1}, {1}); ret = SecretLinearUpdateIndexing(ctx, ret, update_slice, flattened_indices[n]); } - return Reshape(ctx, ret, operand.shape()); + return hal::reshape(ctx, ret, operand.shape()); } else { // Start indices @@ -715,10 +714,8 @@ spu::Value SecretDynamicSlice(SPUContext *ctx, const spu::Value &operand, hlo::Constant(ctx, limit, {static_cast(slice_size.size())}); // Cast to proper type - lower_bound = hlo::Cast(ctx, lower_bound, lower_bound.vtype(), - start_indices[0].dtype()); - upper_bound = hlo::Cast(ctx, upper_bound, upper_bound.vtype(), - start_indices[0].dtype()); + lower_bound = hal::dtype_cast(ctx, lower_bound, start_indices[0].dtype()); + upper_bound = hal::dtype_cast(ctx, upper_bound, start_indices[0].dtype()); // Reshape from scalar to {1} to make concat happy std::vector adjusted_start_indices; diff --git a/libspu/kernel/hlo/indexing.h b/libspu/kernel/hlo/indexing.h index 310a6bfc..91f38fdf 100644 --- a/libspu/kernel/hlo/indexing.h +++ b/libspu/kernel/hlo/indexing.h @@ -14,8 +14,11 @@ #pragma once -#include "libspu/core/context.h" -#include "libspu/kernel/hlo/utils.h" +#include "libspu/core/value.h" + +namespace spu { +class SPUContext; +} namespace spu::kernel::hlo { diff --git a/libspu/kernel/hlo/indexing_test.cc b/libspu/kernel/hlo/indexing_test.cc index 4529cfd7..da7c9c25 100644 --- a/libspu/kernel/hlo/indexing_test.cc +++ b/libspu/kernel/hlo/indexing_test.cc @@ -20,6 +20,7 @@ #include "libspu/core/ndarray_ref.h" #include "libspu/core/type.h" #include "libspu/core/value.h" +#include "libspu/kernel/hlo/casting.h" #include "libspu/kernel/hlo/const.h" #include "libspu/kernel/test_util.h" diff --git a/libspu/kernel/hlo/rand.h b/libspu/kernel/hlo/rand.h index d3312376..e09befe0 100644 --- a/libspu/kernel/hlo/rand.h +++ b/libspu/kernel/hlo/rand.h @@ -14,7 +14,11 @@ #pragma once -#include "libspu/kernel/hlo/utils.h" +#include "libspu/core/value.h" + +namespace spu { +class SPUContext; +} namespace spu::kernel::hlo { diff --git a/libspu/kernel/hlo/reduce.cc b/libspu/kernel/hlo/reduce.cc index b3280b35..5752d573 100644 --- a/libspu/kernel/hlo/reduce.cc +++ b/libspu/kernel/hlo/reduce.cc @@ -16,19 +16,13 @@ #include #include -#include -#include #include #include -#include "libspu/core/parallel_utils.h" -#include "libspu/core/xt_helper.h" #include "libspu/kernel/hal/constants.h" #include "libspu/kernel/hal/polymorphic.h" #include "libspu/kernel/hal/ring.h" #include "libspu/kernel/hal/shape_ops.h" -#include "libspu/kernel/hal/type_cast.h" -#include 
"libspu/kernel/hlo/geometrical.h" #include "libspu/kernel/hlo/utils.h" namespace spu::kernel::hlo { @@ -254,7 +248,7 @@ std::vector ReduceWindowImpl( std::vector windows(nargs); for (int64_t idx = 0; idx < nargs; ++idx) { windows[idx] = hal::slice(ctx, padded_inputs[idx], start, end, - (Strides)config.window_dilations); + Strides(config.window_dilations)); } reduced_rets.emplace_back( Reduce(ctx, windows, init_values, reduce_dims, reducer)); @@ -266,12 +260,12 @@ std::vector ReduceWindowImpl( for (int64_t input_idx = 0; input_idx < nargs; ++input_idx) { std::vector reduced_values; - for (size_t widx = 0; widx < reduced_rets.size(); ++widx) { - Shape new_shape = reduced_rets[widx][input_idx].shape(); + for (auto &reduced_ret : reduced_rets) { + Shape new_shape = reduced_ret[input_idx].shape(); new_shape.insert(new_shape.begin(), 1); reduced_values.emplace_back( - hal::reshape(ctx, reduced_rets[widx][input_idx], new_shape)); + hal::reshape(ctx, reduced_ret[input_idx], new_shape)); } rets.emplace_back( diff --git a/libspu/kernel/hlo/reduce.h b/libspu/kernel/hlo/reduce.h index 5f025015..5a3c5e12 100644 --- a/libspu/kernel/hlo/reduce.h +++ b/libspu/kernel/hlo/reduce.h @@ -19,9 +19,12 @@ #include "absl/types/span.h" -#include "libspu/core/context.h" #include "libspu/core/value.h" +namespace spu { +class SPUContext; +} + namespace spu::kernel::hlo { using BatchedValueBinaryFn = std::function( diff --git a/libspu/kernel/hlo/select_and_scatter.cc b/libspu/kernel/hlo/select_and_scatter.cc index 66ea1b38..87b32998 100644 --- a/libspu/kernel/hlo/select_and_scatter.cc +++ b/libspu/kernel/hlo/select_and_scatter.cc @@ -16,7 +16,6 @@ #include "libspu/kernel/hlo/select_and_scatter.h" #include "libspu/kernel/hal/constants.h" -#include "libspu/kernel/hal/debug.h" #include "libspu/kernel/hal/polymorphic.h" #include "libspu/kernel/hal/shape_ops.h" #include "libspu/kernel/hlo/const.h" // iota diff --git a/libspu/kernel/hlo/select_and_scatter.h b/libspu/kernel/hlo/select_and_scatter.h index cb35a908..dc4ee566 100644 --- a/libspu/kernel/hlo/select_and_scatter.h +++ b/libspu/kernel/hlo/select_and_scatter.h @@ -15,13 +15,15 @@ #pragma once #include -#include #include "absl/types/span.h" -#include "libspu/core/context.h" #include "libspu/core/value.h" +namespace spu { +class SPUContext; +} + namespace spu::kernel::hlo { using ValueBinaryFn = diff --git a/libspu/kernel/hlo/shift.cc b/libspu/kernel/hlo/shift.cc index bc3c76e5..08fe7140 100644 --- a/libspu/kernel/hlo/shift.cc +++ b/libspu/kernel/hlo/shift.cc @@ -20,7 +20,6 @@ #include "libspu/kernel/hal/constants.h" #include "libspu/kernel/hal/polymorphic.h" #include "libspu/kernel/hal/public_helper.h" -#include "libspu/kernel/hal/shape_ops.h" namespace spu::kernel::hlo { diff --git a/libspu/kernel/hlo/shift.h b/libspu/kernel/hlo/shift.h index ea2e6bcb..8195d99d 100644 --- a/libspu/kernel/hlo/shift.h +++ b/libspu/kernel/hlo/shift.h @@ -14,7 +14,11 @@ #pragma once -#include "libspu/kernel/hlo/utils.h" +#include "libspu/core/value.h" + +namespace spu { +class SPUContext; +} namespace spu::kernel::hlo { diff --git a/libspu/kernel/hlo/shuffle.cc b/libspu/kernel/hlo/shuffle.cc index 9555e1e9..2a10797f 100644 --- a/libspu/kernel/hlo/shuffle.cc +++ b/libspu/kernel/hlo/shuffle.cc @@ -14,10 +14,8 @@ #include "libspu/kernel/hlo/shuffle.h" -#include "libspu/kernel/hal/constants.h" #include "libspu/kernel/hal/polymorphic.h" #include "libspu/kernel/hal/random.h" -#include "libspu/kernel/hal/shape_ops.h" #include "libspu/kernel/hlo/sort.h" namespace spu::kernel::hlo { diff --git 
a/libspu/kernel/hlo/shuffle.h b/libspu/kernel/hlo/shuffle.h index 71a8aa47..034d5158 100644 --- a/libspu/kernel/hlo/shuffle.h +++ b/libspu/kernel/hlo/shuffle.h @@ -16,9 +16,12 @@ #include "absl/types/span.h" -#include "libspu/core/context.h" #include "libspu/core/value.h" +namespace spu { +class SPUContext; +} + namespace spu::kernel::hlo { // secret shuffle which means the order is kept secret between parties. diff --git a/libspu/kernel/hlo/sort.cc b/libspu/kernel/hlo/sort.cc index ea448482..ad4e3017 100644 --- a/libspu/kernel/hlo/sort.cc +++ b/libspu/kernel/hlo/sort.cc @@ -92,6 +92,7 @@ std::vector Sort(SPUContext *ctx, for (int64_t ni = 0; ni < N; ni++) { // TODO: all these small sort could be done in parallel. std::vector input_i; + input_i.reserve(inputs2d.size()); for (auto const &input : inputs2d) { // we need 1-d tensor here input_i.push_back( diff --git a/libspu/kernel/hlo/utils.cc b/libspu/kernel/hlo/utils.cc index a26d2707..24c72432 100644 --- a/libspu/kernel/hlo/utils.cc +++ b/libspu/kernel/hlo/utils.cc @@ -14,8 +14,8 @@ #include "libspu/kernel/hlo/utils.h" -#include "libspu/kernel/hal/hal.h" #include "libspu/kernel/hal/public_helper.h" +#include "libspu/kernel/hal/shape_ops.h" namespace spu::kernel { diff --git a/libspu/kernel/hlo/utils.h b/libspu/kernel/hlo/utils.h index 1d1ebc6d..20e4d180 100644 --- a/libspu/kernel/hlo/utils.h +++ b/libspu/kernel/hlo/utils.h @@ -17,7 +17,7 @@ #include #include -#include "xtensor/xarray.hpp" +#include "xtensor/xarray.hpp" // IWYU pragma: keep #include "libspu/core/context.h" #include "libspu/core/value.h" diff --git a/libspu/kernel/test_util.h b/libspu/kernel/test_util.h index b180b7d7..62baa3a3 100644 --- a/libspu/kernel/test_util.h +++ b/libspu/kernel/test_util.h @@ -15,8 +15,8 @@ #include "xtensor/xrandom.hpp" #include "libspu/core/context.h" +#include "libspu/core/pt_buffer_view.h" #include "libspu/core/value.h" -#include "libspu/core/xt_helper.h" #include "libspu/kernel/hal/prot_wrapper.h" // bad reference #include "libspu/kernel/hal/public_helper.h" // bad reference diff --git a/libspu/mpc/BUILD.bazel b/libspu/mpc/BUILD.bazel index 569aedfc..ffced4cb 100644 --- a/libspu/mpc/BUILD.bazel +++ b/libspu/mpc/BUILD.bazel @@ -21,6 +21,7 @@ spu_cc_library( hdrs = ["io_interface.h"], deps = [ "//libspu/core:ndarray_ref", + "//libspu/core:pt_buffer_view", "//libspu/core:type", ], ) diff --git a/libspu/mpc/aby3/BUILD.bazel b/libspu/mpc/aby3/BUILD.bazel index 6456ebc1..6d9fd664 100644 --- a/libspu/mpc/aby3/BUILD.bazel +++ b/libspu/mpc/aby3/BUILD.bazel @@ -70,7 +70,6 @@ spu_cc_library( ":value", "//libspu/mpc/common:communicator", "//libspu/mpc/common:prg_state", - "@yacl//yacl/utils:platform_utils", ], ) @@ -79,7 +78,6 @@ spu_cc_library( srcs = ["conversion.cc"], hdrs = ["conversion.h"], deps = [ - ":ot", ":value", "//libspu/mpc:ab_api", "//libspu/mpc/common:communicator", @@ -95,7 +93,7 @@ spu_cc_library( hdrs = ["value.h"], deps = [ ":type", - "//libspu/core", + "//libspu/core:ndarray_ref", "//libspu/mpc/utils:ring_ops", ], ) @@ -105,7 +103,7 @@ spu_cc_library( srcs = ["ot.cc"], hdrs = ["ot.h"], deps = [ - "//libspu/core", + "//libspu/core:ndarray_ref", "//libspu/mpc/common:communicator", "//libspu/mpc/common:prg_state", "//libspu/mpc/utils:ring_ops", diff --git a/libspu/mpc/aby3/arithmetic.cc b/libspu/mpc/aby3/arithmetic.cc index 14d7a410..90c85bbc 100644 --- a/libspu/mpc/aby3/arithmetic.cc +++ b/libspu/mpc/aby3/arithmetic.cc @@ -17,15 +17,12 @@ #include #include -#include "spdlog/spdlog.h" - #include "libspu/mpc/aby3/ot.h" #include 
"libspu/mpc/aby3/type.h" #include "libspu/mpc/aby3/value.h" #include "libspu/mpc/common/communicator.h" #include "libspu/mpc/common/prg_state.h" #include "libspu/mpc/common/pv2k.h" -#include "libspu/mpc/utils/linalg.h" #include "libspu/mpc/utils/ring_ops.h" namespace spu::mpc::aby3 { diff --git a/libspu/mpc/aby3/boolean.cc b/libspu/mpc/aby3/boolean.cc index b5343163..a95a42ee 100644 --- a/libspu/mpc/aby3/boolean.cc +++ b/libspu/mpc/aby3/boolean.cc @@ -16,8 +16,6 @@ #include -#include "yacl/utils/platform_utils.h" - #include "libspu/core/bit_utils.h" #include "libspu/core/parallel_utils.h" #include "libspu/mpc/aby3/type.h" diff --git a/libspu/mpc/aby3/conversion.h b/libspu/mpc/aby3/conversion.h index 98ec7788..59248867 100644 --- a/libspu/mpc/aby3/conversion.h +++ b/libspu/mpc/aby3/conversion.h @@ -15,7 +15,6 @@ #pragma once #include "libspu/core/ndarray_ref.h" -#include "libspu/mpc/aby3/value.h" #include "libspu/mpc/kernel.h" namespace spu::mpc::aby3 { diff --git a/libspu/mpc/aby3/io.cc b/libspu/mpc/aby3/io.cc index 87c96309..6cc46c91 100644 --- a/libspu/mpc/aby3/io.cc +++ b/libspu/mpc/aby3/io.cc @@ -93,9 +93,8 @@ size_t Aby3Io::getBitSecretShareSize(size_t numel) const { return numel * type.size(); } -std::vector Aby3Io::makeBitSecret(const NdArrayRef& in) const { - SPU_ENFORCE(in.eltype().isa(), "expected PtType, got {}", in.eltype()); - PtType in_pt_type = in.eltype().as()->pt_type(); +std::vector Aby3Io::makeBitSecret(const PtBufferView& in) const { + PtType in_pt_type = in.pt_type; SPU_ENFORCE(in_pt_type == PT_BOOL); if (in_pt_type == PT_BOOL) { @@ -104,43 +103,39 @@ std::vector Aby3Io::makeBitSecret(const NdArrayRef& in) const { } const auto out_type = makeType(PT_U8, /* out_nbits */ 1); - const size_t numel = in.numel(); + const size_t numel = in.shape.numel(); - std::vector shares = {NdArrayRef(out_type, in.shape()), - NdArrayRef(out_type, in.shape()), - NdArrayRef(out_type, in.shape())}; + std::vector shares = {NdArrayRef(out_type, in.shape), + NdArrayRef(out_type, in.shape), + NdArrayRef(out_type, in.shape)}; - return DISPATCH_UINT_PT_TYPES(in_pt_type, "_", [&]() { - using in_el_t = ScalarT; - using bshr_el_t = uint8_t; - using bshr_t = std::array; + using bshr_el_t = uint8_t; + using bshr_t = std::array; - NdArrayView _in(in); + std::vector r0(numel); + std::vector r1(numel); - std::vector r0(numel); - std::vector r1(numel); + yacl::crypto::PrgAesCtr(yacl::crypto::RandSeed(), absl::MakeSpan(r0)); + yacl::crypto::PrgAesCtr(yacl::crypto::RandSeed(), absl::MakeSpan(r1)); - yacl::crypto::PrgAesCtr(yacl::crypto::RandSeed(), absl::MakeSpan(r0)); - yacl::crypto::PrgAesCtr(yacl::crypto::RandSeed(), absl::MakeSpan(r1)); + NdArrayView _s0(shares[0]); + NdArrayView _s1(shares[1]); + NdArrayView _s2(shares[2]); - NdArrayView _s0(shares[0]); - NdArrayView _s1(shares[1]); - NdArrayView _s2(shares[2]); + for (size_t idx = 0; idx < numel; idx++) { + const bshr_el_t r2 = + static_cast(in.get(idx)) - r0[idx] - r1[idx]; - for (int64_t idx = 0; idx < in.numel(); idx++) { - const bshr_el_t r2 = static_cast(_in[idx]) - r0[idx] - r1[idx]; + _s0[idx][0] = r0[idx] & 0x1; + _s0[idx][1] = r1[idx] & 0x1; - _s0[idx][0] = r0[idx] & 0x1; - _s0[idx][1] = r1[idx] & 0x1; + _s1[idx][0] = r1[idx] & 0x1; + _s1[idx][1] = r2 & 0x1; - _s1[idx][0] = r1[idx] & 0x1; - _s1[idx][1] = r2 & 0x1; - - _s2[idx][0] = r2 & 0x1; - _s2[idx][1] = r0[idx] & 0x1; - } - return shares; - }); + _s2[idx][0] = r2 & 0x1; + _s2[idx][1] = r0[idx] & 0x1; + } + return shares; } NdArrayRef Aby3Io::fromShares(const std::vector& shares) const { 
diff --git a/libspu/mpc/aby3/io.h b/libspu/mpc/aby3/io.h index f213adea..adf86b2e 100644 --- a/libspu/mpc/aby3/io.h +++ b/libspu/mpc/aby3/io.h @@ -29,7 +29,7 @@ class Aby3Io final : public BaseIo { NdArrayRef fromShares(const std::vector& shares) const override; - std::vector makeBitSecret(const NdArrayRef& in) const override; + std::vector makeBitSecret(const PtBufferView& in) const override; size_t getBitSecretShareSize(size_t numel) const override; diff --git a/libspu/mpc/aby3/ot.h b/libspu/mpc/aby3/ot.h index 771c8ad0..af6ad2d8 100644 --- a/libspu/mpc/aby3/ot.h +++ b/libspu/mpc/aby3/ot.h @@ -16,10 +16,7 @@ #include -#include "yacl/link/link.h" - #include "libspu/core/ndarray_ref.h" -#include "libspu/core/type_util.h" #include "libspu/mpc/common/communicator.h" #include "libspu/mpc/common/prg_state.h" diff --git a/libspu/mpc/aby3/value.h b/libspu/mpc/aby3/value.h index 1af40b62..fbb413a0 100644 --- a/libspu/mpc/aby3/value.h +++ b/libspu/mpc/aby3/value.h @@ -15,7 +15,6 @@ #pragma once #include "libspu/core/ndarray_ref.h" -#include "libspu/core/parallel_utils.h" #include "libspu/core/type_util.h" namespace spu::mpc::aby3 { diff --git a/libspu/mpc/cheetah/arith/BUILD.bazel b/libspu/mpc/cheetah/arith/BUILD.bazel index 7e6fc334..9842fad0 100644 --- a/libspu/mpc/cheetah/arith/BUILD.bazel +++ b/libspu/mpc/cheetah/arith/BUILD.bazel @@ -64,6 +64,7 @@ spu_cc_library( ], deps = [ "//libspu/core:prelude", + "//libspu/core:xt_helper", "//libspu/mpc/cheetah/rlwe:cheetah_rlwe", "@yacl//yacl/link", ], diff --git a/libspu/mpc/cheetah/arithmetic.cc b/libspu/mpc/cheetah/arithmetic.cc index 5810d06b..4778f92d 100644 --- a/libspu/mpc/cheetah/arithmetic.cc +++ b/libspu/mpc/cheetah/arithmetic.cc @@ -33,7 +33,6 @@ namespace spu::mpc::cheetah { NdArrayRef TruncA::proc(KernelEvalContext* ctx, const NdArrayRef& x, size_t bits, SignType sign) const { - SPU_TRACE_MPC_LEAF(ctx, x); size_t n = x.numel(); NdArrayRef out(x.eltype(), x.shape()); if (n == 0) { @@ -76,7 +75,6 @@ NdArrayRef TruncA::proc(KernelEvalContext* ctx, const NdArrayRef& x, // 1{(x0 + x1) > 2^{k - 1} - 1} = 1{x0 > 2^{k - 1} - 1 - x1} // is computed using a Millionare protocol. NdArrayRef MsbA2B::proc(KernelEvalContext* ctx, const NdArrayRef& x) const { - SPU_TRACE_MPC_LEAF(ctx, x); const int64_t numel = x.numel(); const auto field = ctx->getState()->getDefaultField(); const size_t nbits = nbits_ == 0 ? SizeOf(field) * 8 : nbits_; @@ -139,7 +137,6 @@ NdArrayRef MsbA2B::proc(KernelEvalContext* ctx, const NdArrayRef& x) const { NdArrayRef EqualAP::proc(KernelEvalContext* ctx, const NdArrayRef& x, const NdArrayRef& y) const { - SPU_TRACE_MPC_LEAF(ctx, x, y); EqualAA equal_aa; const auto field = ctx->getState()->getDefaultField(); // TODO(juhou): Can we use any place holder to indicate the dummy 0s. 
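The MsbA2B comment above reduces the carry into the sign position to a two-party comparison: 1{(x0 + x1) > 2^{k-1} - 1} equals 1{x0 > 2^{k-1} - 1 - x1}, where one party holds x0 and the other can compute 2^{k-1} - 1 - x1 locally, which is exactly a Millionaire instance. A small exhaustive check that the two indicator expressions agree; k and the harness are toy values for illustration, not SPU code:

#include <cassert>
#include <cstdint>

int main() {
  constexpr int64_t k = 4;
  constexpr int64_t half = int64_t{1} << (k - 1);  // 2^{k-1}
  for (int64_t x0 = 0; x0 < (int64_t{1} << k); ++x0) {
    for (int64_t x1 = 0; x1 < (int64_t{1} << k); ++x1) {
      const bool lhs = (x0 + x1) > half - 1;  // 1{(x0 + x1) > 2^{k-1} - 1}
      const bool rhs = x0 > half - 1 - x1;    // 1{x0 > 2^{k-1} - 1 - x1}
      assert(lhs == rhs);
    }
  }
  return 0;
}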
@@ -152,7 +149,6 @@ NdArrayRef EqualAP::proc(KernelEvalContext* ctx, const NdArrayRef& x, NdArrayRef EqualAA::proc(KernelEvalContext* ctx, const NdArrayRef& x, const NdArrayRef& y) const { - SPU_TRACE_MPC_LEAF(ctx, x, y); SPU_ENFORCE_EQ(x.shape(), y.shape()); const int64_t numel = x.numel(); @@ -203,7 +199,6 @@ NdArrayRef EqualAA::proc(KernelEvalContext* ctx, const NdArrayRef& x, NdArrayRef MulA1B::proc(KernelEvalContext* ctx, const NdArrayRef& ashr, const NdArrayRef& bshr) const { - SPU_TRACE_MPC_LEAF(ctx, ashr, bshr); SPU_ENFORCE_EQ(ashr.shape(), bshr.shape()); const int64_t numel = ashr.numel(); NdArrayRef out(ashr.eltype(), ashr.shape()); @@ -241,7 +236,6 @@ NdArrayRef MulA1B::proc(KernelEvalContext* ctx, const NdArrayRef& ashr, NdArrayRef MulAA::proc(KernelEvalContext* ctx, const NdArrayRef& x, const NdArrayRef& y) const { - SPU_TRACE_MPC_LEAF(ctx, x, y); SPU_ENFORCE_EQ(x.shape(), y.shape()); int64_t batch_sze = ctx->getState()->get()->OLEBatchSize(); @@ -357,4 +351,4 @@ NdArrayRef MatMulAA::proc(KernelEvalContext* ctx, const NdArrayRef& x, return ring_add(ret, task.get()).as(x.eltype()); } -} // namespace spu::mpc::cheetah \ No newline at end of file +} // namespace spu::mpc::cheetah diff --git a/libspu/mpc/cheetah/rlwe/BUILD.bazel b/libspu/mpc/cheetah/rlwe/BUILD.bazel index f2871f86..5b08672f 100644 --- a/libspu/mpc/cheetah/rlwe/BUILD.bazel +++ b/libspu/mpc/cheetah/rlwe/BUILD.bazel @@ -71,7 +71,7 @@ spu_cc_test( srcs = ["modswitch_helper_test.cc"], deps = [ ":modswitch_helper", - "@com_github_xtensor_xtensor//:xtensor", + "//libspu/core:xt_helper", ], ) diff --git a/libspu/mpc/common/BUILD.bazel b/libspu/mpc/common/BUILD.bazel index 2a55faaf..13bc2c9a 100644 --- a/libspu/mpc/common/BUILD.bazel +++ b/libspu/mpc/common/BUILD.bazel @@ -21,8 +21,6 @@ spu_cc_library( srcs = ["pv2k.cc"], hdrs = ["pv2k.h"], deps = [ - "//libspu/core", - "//libspu/core:trace", "//libspu/mpc:kernel", "//libspu/mpc/common:communicator", "//libspu/mpc/common:prg_state", @@ -45,7 +43,10 @@ spu_cc_library( deps = [ "//libspu/core:object", "//libspu/mpc/utils:ring_ops", - "@yacl//yacl/link", + "@yacl//yacl/link:context", + "@yacl//yacl/link/algorithm:allgather", + "@yacl//yacl/link/algorithm:broadcast", + "@yacl//yacl/link/algorithm:gather", ], ) @@ -63,11 +64,11 @@ spu_cc_library( srcs = ["prg_state.cc"], hdrs = ["prg_state.h"], deps = [ - "//libspu/core", "//libspu/core:object", "@yacl//yacl/crypto/tools:prg", "@yacl//yacl/crypto/utils:rand", - "@yacl//yacl/link", + "@yacl//yacl/link:context", + "@yacl//yacl/link/algorithm:allgather", ], ) @@ -77,6 +78,7 @@ spu_cc_test( deps = [ ":prg_state", "//libspu/mpc/utils:simulate", + "@yacl//yacl/link/algorithm:barrier", ], ) diff --git a/libspu/mpc/common/communicator.h b/libspu/mpc/common/communicator.h index 32f38e58..d566450e 100644 --- a/libspu/mpc/common/communicator.h +++ b/libspu/mpc/common/communicator.h @@ -21,7 +21,10 @@ #include #include "yacl/base/buffer.h" -#include "yacl/link/link.h" +#include "yacl/link/algorithm/allgather.h" +#include "yacl/link/algorithm/broadcast.h" +#include "yacl/link/algorithm/gather.h" +#include "yacl/link/context.h" #include "libspu/core/ndarray_ref.h" #include "libspu/core/object.h" diff --git a/libspu/mpc/common/prg_state.cc b/libspu/mpc/common/prg_state.cc index ba333118..c7f0458a 100644 --- a/libspu/mpc/common/prg_state.cc +++ b/libspu/mpc/common/prg_state.cc @@ -16,6 +16,7 @@ #include "yacl/crypto/tools/prg.h" #include "yacl/crypto/utils/rand.h" +#include "yacl/link/algorithm/allgather.h" #include "yacl/utils/serialize.h" 
namespace spu::mpc { diff --git a/libspu/mpc/common/prg_state.h b/libspu/mpc/common/prg_state.h index 8d525cbb..7e8938e6 100644 --- a/libspu/mpc/common/prg_state.h +++ b/libspu/mpc/common/prg_state.h @@ -16,7 +16,7 @@ #include "absl/types/span.h" #include "yacl/crypto/tools/prg.h" -#include "yacl/link/link.h" +#include "yacl/link/context.h" #include "libspu/core/ndarray_ref.h" #include "libspu/core/object.h" diff --git a/libspu/mpc/common/prg_state_test.cc b/libspu/mpc/common/prg_state_test.cc index 3d7f011f..fd1fe2ef 100644 --- a/libspu/mpc/common/prg_state_test.cc +++ b/libspu/mpc/common/prg_state_test.cc @@ -15,7 +15,8 @@ #include "libspu/mpc/common/prg_state.h" #include "gtest/gtest.h" -#include "yacl/link/link.h" +#include "yacl/link/algorithm/barrier.h" +#include "yacl/link/context.h" #include "libspu/mpc/utils/simulate.h" diff --git a/libspu/mpc/common/pv2k.cc b/libspu/mpc/common/pv2k.cc index 06acfac4..71641d78 100644 --- a/libspu/mpc/common/pv2k.cc +++ b/libspu/mpc/common/pv2k.cc @@ -17,7 +17,6 @@ #include #include "libspu/core/ndarray_ref.h" -#include "libspu/core/trace.h" #include "libspu/mpc/common/communicator.h" #include "libspu/mpc/common/prg_state.h" #include "libspu/mpc/kernel.h" diff --git a/libspu/mpc/factory.h b/libspu/mpc/factory.h index 5583faa8..110a35d4 100644 --- a/libspu/mpc/factory.h +++ b/libspu/mpc/factory.h @@ -16,7 +16,7 @@ #include -#include "yacl/link/link.h" +#include "yacl/link/context.h" #include "libspu/core/context.h" #include "libspu/mpc/io_interface.h" diff --git a/libspu/mpc/io_interface.h b/libspu/mpc/io_interface.h index d1035756..f29054e2 100644 --- a/libspu/mpc/io_interface.h +++ b/libspu/mpc/io_interface.h @@ -18,6 +18,7 @@ #include #include "libspu/core/ndarray_ref.h" +#include "libspu/core/pt_buffer_view.h" namespace spu::mpc { @@ -49,7 +50,7 @@ class IoInterface { // // @param raw, with type as PtType. 
virtual std::vector makeBitSecret( - const NdArrayRef& raw) const = 0; + const PtBufferView& raw) const = 0; virtual size_t getBitSecretShareSize(size_t numel) const = 0; @@ -73,7 +74,7 @@ class BaseIo : public IoInterface { explicit BaseIo(FieldType field, size_t world_size) : field_(field), world_size_(world_size) {} - std::vector makeBitSecret(const NdArrayRef&) const override { + std::vector makeBitSecret(const PtBufferView&) const override { SPU_THROW("should not be here"); } diff --git a/libspu/mpc/kernel.h b/libspu/mpc/kernel.h index 456ebfe4..a1ce4795 100644 --- a/libspu/mpc/kernel.h +++ b/libspu/mpc/kernel.h @@ -15,7 +15,6 @@ #pragma once #include "libspu/core/context.h" -#include "libspu/core/prelude.h" namespace spu::mpc { diff --git a/libspu/mpc/ref2k/ref2k.cc b/libspu/mpc/ref2k/ref2k.cc index 1bb9b819..7a600a40 100644 --- a/libspu/mpc/ref2k/ref2k.cc +++ b/libspu/mpc/ref2k/ref2k.cc @@ -186,7 +186,6 @@ class Ref2kNotS : public UnaryKernel { ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in) const override { - SPU_TRACE_MPC_LEAF(ctx, in); const auto field = in.eltype().as()->field(); return ring_not(in).as(makeType(field)); } @@ -202,7 +201,6 @@ class Ref2kAddSS : public BinaryKernel { NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& lhs, const NdArrayRef& rhs) const override { - SPU_TRACE_MPC_LEAF(ctx, lhs, rhs); SPU_ENFORCE(lhs.eltype() == rhs.eltype()); return ring_add(lhs, rhs).as(lhs.eltype()); } @@ -218,7 +216,6 @@ class Ref2kAddSP : public BinaryKernel { NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& lhs, const NdArrayRef& rhs) const override { - SPU_TRACE_MPC_LEAF(ctx, lhs, rhs); return ring_add(lhs, rhs).as(lhs.eltype()); } }; @@ -233,7 +230,6 @@ class Ref2kMulSS : public BinaryKernel { NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& lhs, const NdArrayRef& rhs) const override { - SPU_TRACE_MPC_LEAF(ctx, lhs, rhs); SPU_ENFORCE(lhs.eltype() == rhs.eltype()); return ring_mul(lhs, rhs).as(lhs.eltype()); } @@ -249,7 +245,6 @@ class Ref2kMulSP : public BinaryKernel { NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& lhs, const NdArrayRef& rhs) const override { - SPU_TRACE_MPC_LEAF(ctx, lhs, rhs); return ring_mul(lhs, rhs).as(lhs.eltype()); } }; @@ -264,7 +259,6 @@ class Ref2kMatMulSS : public MatmulKernel { NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& lhs, const NdArrayRef& rhs) const override { - SPU_TRACE_MPC_LEAF(ctx, lhs, rhs); SPU_ENFORCE(lhs.eltype() == rhs.eltype()); return ring_mmul(lhs, rhs).as(lhs.eltype()); } @@ -280,7 +274,6 @@ class Ref2kMatMulSP : public MatmulKernel { NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& lhs, const NdArrayRef& rhs) const override { - SPU_TRACE_MPC_LEAF(ctx, lhs, rhs); return ring_mmul(lhs, rhs).as(lhs.eltype()); } }; @@ -295,7 +288,6 @@ class Ref2kAndSS : public BinaryKernel { NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& lhs, const NdArrayRef& rhs) const override { - SPU_TRACE_MPC_LEAF(ctx, lhs, rhs); SPU_ENFORCE(lhs.eltype() == rhs.eltype()); return ring_and(lhs, rhs).as(lhs.eltype()); } @@ -311,7 +303,6 @@ class Ref2kAndSP : public BinaryKernel { NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& lhs, const NdArrayRef& rhs) const override { - SPU_TRACE_MPC_LEAF(ctx, lhs, rhs); return ring_and(lhs, rhs).as(lhs.eltype()); } }; @@ -326,7 +317,6 @@ class Ref2kXorSS : public BinaryKernel { NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& lhs, const NdArrayRef& rhs) const override { 
- SPU_TRACE_MPC_LEAF(ctx, lhs, rhs); SPU_ENFORCE(lhs.eltype() == rhs.eltype()); return ring_xor(lhs, rhs).as(lhs.eltype()); } @@ -342,7 +332,6 @@ class Ref2kXorSP : public BinaryKernel { NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& lhs, const NdArrayRef& rhs) const override { - SPU_TRACE_MPC_LEAF(ctx, lhs, rhs); return ring_xor(lhs, rhs).as(lhs.eltype()); } }; @@ -357,7 +346,6 @@ class Ref2kLShiftS : public ShiftKernel { NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, size_t bits) const override { - SPU_TRACE_MPC_LEAF(ctx, in, bits); return ring_lshift(in, bits).as(in.eltype()); } }; @@ -372,7 +360,6 @@ class Ref2kRShiftS : public ShiftKernel { NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, size_t bits) const override { - SPU_TRACE_MPC_LEAF(ctx, in, bits); return ring_rshift(in, bits).as(in.eltype()); } }; @@ -391,7 +378,6 @@ class Ref2kBitrevS : public BitrevKernel { SPU_ENFORCE(start <= end); SPU_ENFORCE(end <= SizeOf(field) * 8); - SPU_TRACE_MPC_LEAF(ctx, in, start, end); return ring_bitrev(in, start, end).as(in.eltype()); } }; @@ -406,7 +392,6 @@ class Ref2kARShiftS : public ShiftKernel { NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, size_t bits) const override { - SPU_TRACE_MPC_LEAF(ctx, in, bits); return ring_arshift(in, bits).as(in.eltype()); } }; @@ -427,8 +412,6 @@ class Ref2kTruncS : public TruncAKernel { NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, size_t bits, SignType) const override { - SPU_TRACE_MPC_LEAF(ctx, in, bits); - // Rounding // AxB = (AxB >> 14) + ((AxB >> 13) & 1); // See @@ -452,7 +435,6 @@ class Ref2kMsbS : public UnaryKernel { ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in) const override { - SPU_TRACE_MPC_LEAF(ctx, in); return ring_rshift(in, in.elsize() * 8 - 1).as(in.eltype()); } }; diff --git a/libspu/mpc/semi2k/BUILD.bazel b/libspu/mpc/semi2k/BUILD.bazel index ebba5361..9c2f4163 100644 --- a/libspu/mpc/semi2k/BUILD.bazel +++ b/libspu/mpc/semi2k/BUILD.bazel @@ -40,7 +40,6 @@ spu_cc_library( deps = [ ":state", ":type", - "//libspu/mpc:ab_api", "//libspu/mpc:kernel", "//libspu/mpc/common:communicator", ], @@ -68,7 +67,6 @@ spu_cc_library( ":state", ":type", "//libspu/core:vectorize", - "//libspu/mpc:ab_api", "//libspu/mpc:kernel", "//libspu/mpc/common:communicator", "//libspu/mpc/utils:circuits", diff --git a/libspu/mpc/semi2k/arithmetic.cc b/libspu/mpc/semi2k/arithmetic.cc index e6d5b8c6..358039c4 100644 --- a/libspu/mpc/semi2k/arithmetic.cc +++ b/libspu/mpc/semi2k/arithmetic.cc @@ -18,7 +18,6 @@ #include "libspu/core/type_util.h" #include "libspu/core/vectorize.h" -#include "libspu/mpc/ab_api.h" #include "libspu/mpc/common/communicator.h" #include "libspu/mpc/common/prg_state.h" #include "libspu/mpc/common/pv2k.h" diff --git a/libspu/mpc/semi2k/beaver/BUILD.bazel b/libspu/mpc/semi2k/beaver/BUILD.bazel index dc861dd2..7929b599 100644 --- a/libspu/mpc/semi2k/beaver/BUILD.bazel +++ b/libspu/mpc/semi2k/beaver/BUILD.bazel @@ -37,6 +37,7 @@ spu_cc_test( deps = [ ":beaver_tfp", ":beaver_ttp", + "//libspu/core:xt_helper", "//libspu/mpc/semi2k/beaver/ttp_server:beaver_server", "//libspu/mpc/utils:permute", "//libspu/mpc/utils:simulate", @@ -49,7 +50,6 @@ spu_cc_library( srcs = ["trusted_party.cc"], hdrs = ["trusted_party.h"], deps = [ - "//libspu/core:type_util", "//libspu/mpc/common:prg_tensor", "//libspu/mpc/utils:permute", "//libspu/mpc/utils:ring_ops", @@ -60,7 +60,7 @@ spu_cc_library( name = "beaver_interface", hdrs = 
["beaver_interface.h"], deps = [ - "//libspu/core", + "//libspu/core:ndarray_ref", ], ) @@ -73,8 +73,9 @@ spu_cc_library( "//libspu/mpc/common:prg_tensor", "//libspu/mpc/semi2k/beaver/ttp_server:service_cc_proto", "//libspu/mpc/utils:ring_ops", - "@com_github_microsoft_seal//:seal", - "@yacl//yacl/link", + "@yacl//yacl/link:context", + "@yacl//yacl/link/algorithm:barrier", "@yacl//yacl/utils:parallel", + "@yacl//yacl/utils:serialize", ], ) diff --git a/libspu/mpc/semi2k/beaver/beaver_interface.h b/libspu/mpc/semi2k/beaver/beaver_interface.h index 31097fa9..1f5b89e2 100644 --- a/libspu/mpc/semi2k/beaver/beaver_interface.h +++ b/libspu/mpc/semi2k/beaver/beaver_interface.h @@ -62,7 +62,7 @@ class Beaver { │ │ └───────────────────────┘ - Perm(A) = B + InversePerm(A) = B if perm_rank == lctx->Rank(); perm not empty. */ diff --git a/libspu/mpc/semi2k/beaver/beaver_test.cc b/libspu/mpc/semi2k/beaver/beaver_test.cc index be27751c..2ee171b7 100644 --- a/libspu/mpc/semi2k/beaver/beaver_test.cc +++ b/libspu/mpc/semi2k/beaver/beaver_test.cc @@ -12,10 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "fmt/format.h" #include "gtest/gtest.h" #include "xtensor/xarray.hpp" -#include "yacl/link/link.h" +#include "yacl/link/algorithm/barrier.h" +#include "yacl/link/context.h" #include "libspu/core/type_util.h" #include "libspu/core/xt_helper.h" diff --git a/libspu/mpc/semi2k/beaver/beaver_tfp.cc b/libspu/mpc/semi2k/beaver/beaver_tfp.cc index 3e279fbc..b695d34d 100644 --- a/libspu/mpc/semi2k/beaver/beaver_tfp.cc +++ b/libspu/mpc/semi2k/beaver/beaver_tfp.cc @@ -14,16 +14,14 @@ #include "libspu/mpc/semi2k/beaver/beaver_tfp.h" -#include #include #include "yacl/crypto/utils/rand.h" -#include "yacl/link/link.h" +#include "yacl/link/algorithm/gather.h" #include "yacl/utils/serialize.h" #include "libspu/mpc/common/prg_tensor.h" #include "libspu/mpc/semi2k/beaver/trusted_party.h" -#include "libspu/mpc/utils/permute.h" #include "libspu/mpc/utils/ring_ops.h" namespace spu::mpc::semi2k { @@ -125,7 +123,7 @@ BeaverTfpUnsafe::Triple BeaverTfpUnsafe::TruncPr(FieldType field, NdArrayRef BeaverTfpUnsafe::RandBit(FieldType field, const Shape& shape) { std::vector descs(1); - auto a = prgCreateArray(field, shape, seed_, &counter_, &descs[0]); + auto a = prgCreateArray(field, shape, seed_, &counter_, descs.data()); if (lctx_->Rank() == 0) { auto adjust = TrustedParty::adjustRandBit(descs, seeds_); diff --git a/libspu/mpc/semi2k/beaver/beaver_ttp.cc b/libspu/mpc/semi2k/beaver/beaver_ttp.cc index c1e970a8..213534c5 100644 --- a/libspu/mpc/semi2k/beaver/beaver_ttp.cc +++ b/libspu/mpc/semi2k/beaver/beaver_ttp.cc @@ -19,7 +19,7 @@ #include #include "yacl/crypto/utils/rand.h" -#include "yacl/link/link.h" +#include "yacl/link/algorithm/barrier.h" #include "yacl/utils/serialize.h" #include "libspu/mpc/common/prg_tensor.h" diff --git a/libspu/mpc/semi2k/beaver/beaver_ttp.h b/libspu/mpc/semi2k/beaver/beaver_ttp.h index 9585da3f..c5ea9677 100644 --- a/libspu/mpc/semi2k/beaver/beaver_ttp.h +++ b/libspu/mpc/semi2k/beaver/beaver_ttp.h @@ -15,7 +15,6 @@ #pragma once #include -#include #include "brpc/channel.h" #include "yacl/link/context.h" diff --git a/libspu/mpc/semi2k/beaver/trusted_party.h b/libspu/mpc/semi2k/beaver/trusted_party.h index 8533f435..0c4de181 100644 --- a/libspu/mpc/semi2k/beaver/trusted_party.h +++ b/libspu/mpc/semi2k/beaver/trusted_party.h @@ -14,9 +14,6 @@ #pragma once -#include -#include -#include #include #include "absl/types/span.h" diff --git 
a/libspu/mpc/semi2k/boolean.cc b/libspu/mpc/semi2k/boolean.cc index b5afb13e..23be034d 100644 --- a/libspu/mpc/semi2k/boolean.cc +++ b/libspu/mpc/semi2k/boolean.cc @@ -17,11 +17,9 @@ #include #include "libspu/core/bit_utils.h" -#include "libspu/mpc/ab_api.h" #include "libspu/mpc/common/communicator.h" #include "libspu/mpc/common/prg_state.h" #include "libspu/mpc/common/pv2k.h" -#include "libspu/mpc/kernel.h" #include "libspu/mpc/semi2k/state.h" #include "libspu/mpc/semi2k/type.h" #include "libspu/mpc/utils/ring_ops.h" diff --git a/libspu/mpc/semi2k/protocol.cc b/libspu/mpc/semi2k/protocol.cc index ebc4880f..d8d8f43c 100644 --- a/libspu/mpc/semi2k/protocol.cc +++ b/libspu/mpc/semi2k/protocol.cc @@ -14,6 +14,7 @@ #include "libspu/mpc/semi2k/protocol.h" +#include "libspu/mpc/common/communicator.h" #include "libspu/mpc/common/prg_state.h" #include "libspu/mpc/common/pv2k.h" #include "libspu/mpc/semi2k/arithmetic.h" diff --git a/libspu/mpc/semi2k/protocol.h b/libspu/mpc/semi2k/protocol.h index c27917d4..b708acfa 100644 --- a/libspu/mpc/semi2k/protocol.h +++ b/libspu/mpc/semi2k/protocol.h @@ -14,7 +14,7 @@ #pragma once -#include "yacl/link/link.h" +#include "yacl/link/context.h" #include "libspu/core/context.h" diff --git a/libspu/mpc/semi2k/state.h b/libspu/mpc/semi2k/state.h index d6573ad5..450ba524 100644 --- a/libspu/mpc/semi2k/state.h +++ b/libspu/mpc/semi2k/state.h @@ -16,7 +16,7 @@ #include -#include "libspu/mpc/common/communicator.h" +#include "libspu/core/object.h" #include "libspu/mpc/semi2k/beaver/beaver_interface.h" #include "libspu/mpc/semi2k/beaver/beaver_tfp.h" #include "libspu/mpc/semi2k/beaver/beaver_ttp.h" diff --git a/libspu/mpc/spdz2k/BUILD.bazel b/libspu/mpc/spdz2k/BUILD.bazel index cad7c4ed..8be7c3b3 100644 --- a/libspu/mpc/spdz2k/BUILD.bazel +++ b/libspu/mpc/spdz2k/BUILD.bazel @@ -177,7 +177,7 @@ spu_cc_library( hdrs = ["value.h"], deps = [ ":type", - "//libspu/core", + "//libspu/core:ndarray_ref", "//libspu/mpc/common:pv2k", "//libspu/mpc/utils:ring_ops", ], diff --git a/libspu/mpc/spdz2k/arithmetic.cc b/libspu/mpc/spdz2k/arithmetic.cc index 69c40571..ee3586c8 100644 --- a/libspu/mpc/spdz2k/arithmetic.cc +++ b/libspu/mpc/spdz2k/arithmetic.cc @@ -93,8 +93,6 @@ NdArrayRef GetMacShare(KernelEvalContext* ctx, const NdArrayRef& in) { } NdArrayRef RandA::proc(KernelEvalContext* ctx, const Shape& shape) const { - SPU_TRACE_MPC_LEAF(ctx, shape); - const auto field = ctx->getState()->getDefaultField(); auto* prg_state = ctx->getState(); auto* beaver = ctx->getState()->beaver(); @@ -115,8 +113,6 @@ NdArrayRef RandA::proc(KernelEvalContext* ctx, const Shape& shape) const { } NdArrayRef P2A::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { - SPU_TRACE_MPC_LEAF(ctx, in); - const auto field = ctx->getState()->getDefaultField(); auto* comm = ctx->getState(); const auto key = ctx->getState()->key(); @@ -136,8 +132,6 @@ NdArrayRef P2A::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { } NdArrayRef A2P::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { - SPU_TRACE_MPC_LEAF(ctx, in); - const auto out_field = ctx->getState()->getDefaultField(); auto* beaver = ctx->getState()->beaver(); const auto k = ctx->getState()->k(); @@ -158,8 +152,6 @@ NdArrayRef A2P::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { NdArrayRef A2V::proc(KernelEvalContext* ctx, const NdArrayRef& in, size_t rank) const { - SPU_TRACE_MPC_LEAF(ctx, in); - const auto field = ctx->getState()->getDefaultField(); const auto out_field = ctx->getState()->getDefaultField(); auto* comm = 
ctx->getState(); @@ -221,8 +213,6 @@ NdArrayRef V2A::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { } NdArrayRef NotA::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { - SPU_TRACE_MPC_LEAF(ctx, in); - const auto field = in.eltype().as()->field(); const auto key = ctx->getState()->key(); auto* comm = ctx->getState(); @@ -251,8 +241,6 @@ NdArrayRef NotA::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { //////////////////////////////////////////////////////////////////// NdArrayRef AddAP::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, const NdArrayRef& rhs) const { - SPU_TRACE_MPC_LEAF(ctx, lhs, rhs); - const auto field = lhs.eltype().as()->field(); auto* comm = ctx->getState(); const auto key = ctx->getState()->key(); @@ -275,8 +263,6 @@ NdArrayRef AddAP::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, NdArrayRef AddAA::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, const NdArrayRef& rhs) const { - SPU_TRACE_MPC_LEAF(ctx, lhs, rhs); - const auto field = lhs.eltype().as()->field(); // lhs @@ -456,8 +442,6 @@ bool BatchCheck(KernelEvalContext* ctx, const std::vector& ins) { NdArrayRef MulAP::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, const NdArrayRef& rhs) const { - SPU_TRACE_MPC_LEAF(ctx, lhs, rhs); - const auto field = lhs.eltype().as()->field(); // lhs @@ -480,8 +464,6 @@ NdArrayRef MulAP::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, // TODO: use DISPATCH_ALL_FIELDS instead of ring ops to improve performance NdArrayRef MulAA::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, const NdArrayRef& rhs) const { - SPU_TRACE_MPC_LEAF(ctx, lhs, rhs); - const auto field = lhs.eltype().as()->field(); auto* comm = ctx->getState(); auto* beaver = ctx->getState()->beaver(); @@ -542,8 +524,6 @@ NdArrayRef MulAA::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, //////////////////////////////////////////////////////////////////// NdArrayRef MatMulAP::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, const NdArrayRef& rhs) const { - SPU_TRACE_MPC_LEAF(ctx, lhs, rhs); - const auto field = lhs.eltype().as()->field(); // in @@ -559,8 +539,6 @@ NdArrayRef MatMulAP::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, NdArrayRef MatMulAA::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, const NdArrayRef& rhs) const { - SPU_TRACE_MPC_LEAF(ctx, lhs, rhs); - const auto field = lhs.eltype().as()->field(); auto* comm = ctx->getState(); auto* beaver = ctx->getState()->beaver(); @@ -602,8 +580,6 @@ NdArrayRef MatMulAA::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, NdArrayRef LShiftA::proc(KernelEvalContext* ctx, const NdArrayRef& in, size_t bits) const { - SPU_TRACE_MPC_LEAF(ctx, in, bits); - const auto field = in.eltype().as()->field(); bits %= SizeOf(field) * 8; @@ -621,8 +597,6 @@ NdArrayRef LShiftA::proc(KernelEvalContext* ctx, const NdArrayRef& in, // Ref: Section 5.1.2 https://eprint.iacr.org/2018/403.pdf NdArrayRef TruncA::proc(KernelEvalContext* ctx, const NdArrayRef& in, size_t bits, SignType sign) const { - SPU_TRACE_MPC_LEAF(ctx, in, bits); - (void)sign; // TODO: optimize me. 
const auto key = ctx->getState()->key(); diff --git a/libspu/mpc/spdz2k/beaver/BUILD.bazel b/libspu/mpc/spdz2k/beaver/BUILD.bazel index 43b80bb6..fcf0f258 100644 --- a/libspu/mpc/spdz2k/beaver/BUILD.bazel +++ b/libspu/mpc/spdz2k/beaver/BUILD.bazel @@ -20,7 +20,7 @@ spu_cc_library( name = "beaver_interface", hdrs = ["beaver_interface.h"], deps = [ - "//libspu/core", + "//libspu/core:ndarray_ref", ], ) @@ -76,7 +76,7 @@ spu_cc_library( "//libspu/mpc/spdz2k/ot:kos_ote", "//libspu/mpc/spdz2k/ot:tiny_ot", "//libspu/mpc/utils:ring_ops", - "@yacl//yacl/crypto/primitives/ot:base_ot", + "@yacl//yacl/crypto/primitives/ot:ot_store", "@yacl//yacl/crypto/tools:prg", "@yacl//yacl/link", "@yacl//yacl/utils:matrix_utils", diff --git a/libspu/mpc/spdz2k/beaver/beaver_tfp.cc b/libspu/mpc/spdz2k/beaver/beaver_tfp.cc index b93363e8..f794c2a5 100644 --- a/libspu/mpc/spdz2k/beaver/beaver_tfp.cc +++ b/libspu/mpc/spdz2k/beaver/beaver_tfp.cc @@ -18,7 +18,6 @@ #include #include "yacl/crypto/utils/rand.h" -#include "yacl/link/link.h" #include "yacl/utils/serialize.h" #include "libspu/mpc/common/prg_tensor.h" diff --git a/libspu/mpc/spdz2k/beaver/beaver_tinyot.cc b/libspu/mpc/spdz2k/beaver/beaver_tinyot.cc index 4c7d4421..6f1d16bd 100644 --- a/libspu/mpc/spdz2k/beaver/beaver_tinyot.cc +++ b/libspu/mpc/spdz2k/beaver/beaver_tinyot.cc @@ -20,16 +20,14 @@ #include #include "yacl/base/dynamic_bitset.h" +#include "yacl/crypto/primitives/ot/base_ot.h" #include "yacl/crypto/primitives/ot/ot_store.h" #include "yacl/crypto/tools/prg.h" #include "yacl/crypto/utils/rand.h" -#include "yacl/link/link.h" -#include "yacl/utils/matrix_utils.h" #include "yacl/utils/serialize.h" #include "libspu/mpc/common/prg_tensor.h" #include "libspu/mpc/spdz2k/commitment.h" -#include "libspu/mpc/spdz2k/ot/kos_ote.h" #include "libspu/mpc/spdz2k/ot/tiny_ot.h" #include "libspu/mpc/utils/ring_ops.h" diff --git a/libspu/mpc/spdz2k/beaver/beaver_tinyot.h b/libspu/mpc/spdz2k/beaver/beaver_tinyot.h index 8c00cca7..57c766d4 100644 --- a/libspu/mpc/spdz2k/beaver/beaver_tinyot.h +++ b/libspu/mpc/spdz2k/beaver/beaver_tinyot.h @@ -14,7 +14,7 @@ #pragma once -#include "yacl/crypto/primitives/ot/base_ot.h" +#include "yacl/crypto/primitives/ot/ot_store.h" #include "yacl/link/context.h" #include "libspu/mpc/common/prg_state.h" diff --git a/libspu/mpc/spdz2k/boolean.cc b/libspu/mpc/spdz2k/boolean.cc index e7b8de46..ed4603d9 100644 --- a/libspu/mpc/spdz2k/boolean.cc +++ b/libspu/mpc/spdz2k/boolean.cc @@ -265,8 +265,6 @@ NdArrayRef CastTypeB::proc(KernelEvalContext*, const NdArrayRef& in, } NdArrayRef B2P::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { - SPU_TRACE_MPC_LEAF(ctx, in); - auto* beaver_ptr = ctx->getState()->beaver(); const auto s = ctx->getState()->s(); const auto field = in.eltype().as()->field(); @@ -305,8 +303,6 @@ NdArrayRef B2P::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { } NdArrayRef P2B::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { - SPU_TRACE_MPC_LEAF(ctx, in); - auto* comm = ctx->getState(); const auto k = ctx->getState()->k(); const auto key = ctx->getState()->key(); @@ -337,8 +333,6 @@ NdArrayRef P2B::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { } NdArrayRef NotB::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { - SPU_TRACE_MPC_LEAF(ctx, in); - const auto field = in.eltype().as()->field(); const auto nbits = in.eltype().as()->nbits(); const auto key = ctx->getState()->key(); @@ -368,8 +362,6 @@ NdArrayRef NotB::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { NdArrayRef 
BitrevB::proc(KernelEvalContext* ctx, const NdArrayRef& in, size_t start, size_t end) const { - SPU_TRACE_MPC_LEAF(ctx, in); - const auto field = in.eltype().as()->field(); const auto nbits = in.eltype().as()->nbits(); const auto numel = in.numel(); @@ -401,7 +393,6 @@ NdArrayRef BitrevB::proc(KernelEvalContext* ctx, const NdArrayRef& in, NdArrayRef XorBB::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, const NdArrayRef& rhs) const { - SPU_TRACE_MPC_LEAF(ctx, lhs, rhs); const auto field = lhs.eltype().as()->field(); const auto nbits = maxNumBits(lhs, rhs); @@ -419,8 +410,6 @@ NdArrayRef XorBB::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, NdArrayRef XorBP::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, const NdArrayRef& rhs) const { - SPU_TRACE_MPC_LEAF(ctx, lhs, rhs); - const auto field = lhs.eltype().as()->field(); const auto nbits = maxNumBits(lhs, rhs); const auto k = ctx->getState()->k(); @@ -450,8 +439,6 @@ NdArrayRef XorBP::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, NdArrayRef AndBB::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, const NdArrayRef& rhs) const { - SPU_TRACE_MPC_LEAF(ctx, lhs, rhs); - const auto field = lhs.eltype().as()->field(); auto* comm = ctx->getState(); auto* beaver_ptr = ctx->getState()->beaver(); @@ -506,7 +493,6 @@ NdArrayRef AndBB::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, NdArrayRef AndBP::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, const NdArrayRef& rhs) const { - SPU_TRACE_MPC_LEAF(ctx, lhs, rhs); SPU_ENFORCE(lhs.shape() == rhs.shape(), "lhs shape {}, rhs shape {}", lhs.shape(), rhs.shape()); @@ -546,8 +532,6 @@ NdArrayRef LShiftB::proc(KernelEvalContext* ctx, const NdArrayRef& in, NdArrayRef RShiftB::proc(KernelEvalContext* ctx, const NdArrayRef& in, size_t bits) const { - SPU_TRACE_MPC_LEAF(ctx, in, bits); - const auto field = in.eltype().as()->field(); const auto nbits = in.eltype().as()->nbits(); size_t new_nbis = nbits > bits ? 
nbits - bits : 1; @@ -562,8 +546,6 @@ static NdArrayRef wrap_rshift_b(SPUContext* ctx, const NdArrayRef& x, NdArrayRef ARShiftB::proc(KernelEvalContext* ctx, const NdArrayRef& in, size_t bits) const { - SPU_TRACE_MPC_LEAF(ctx, in, bits); - const auto field = in.eltype().as()->field(); const auto k = ctx->getState()->k(); const auto nbits = in.eltype().as()->nbits(); diff --git a/libspu/mpc/spdz2k/conversion.cc b/libspu/mpc/spdz2k/conversion.cc index 6c67c129..8f8ded08 100644 --- a/libspu/mpc/spdz2k/conversion.cc +++ b/libspu/mpc/spdz2k/conversion.cc @@ -141,8 +141,6 @@ CircuitBasicBlock MakeSPDZBasicBlock(SPUContext* ctx) { }; // namespace NdArrayRef A2Bit::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { - SPU_TRACE_MPC_LEAF(ctx, in); - const auto field = in.eltype().as()->field(); const size_t s = ctx->getState()->s(); const size_t nbits = 1; @@ -159,8 +157,6 @@ NdArrayRef A2Bit::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { } NdArrayRef Bit2A::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { - SPU_TRACE_MPC_LEAF(ctx, in); - const auto field = in.eltype().as()->field(); auto* comm = ctx->getState(); auto* beaver = ctx->getState()->beaver(); @@ -217,8 +213,6 @@ NdArrayRef Bit2A::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { } NdArrayRef A2B::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { - SPU_TRACE_MPC_LEAF(ctx, in); - const auto field = in.eltype().as()->field(); auto* beaver = ctx->getState()->beaver(); const size_t k = ctx->getState()->k(); @@ -261,8 +255,6 @@ NdArrayRef A2B::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { } NdArrayRef B2A::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { - SPU_TRACE_MPC_LEAF(ctx, in); - const auto field = in.eltype().as()->field(); const auto nbits = in.eltype().as()->nbits(); auto* comm = ctx->getState(); @@ -334,7 +326,6 @@ NdArrayRef B2A::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { } NdArrayRef MSB::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { - SPU_TRACE_MPC_LEAF(ctx, in); const auto field = in.eltype().as()->field(); const int64_t k = ctx->getState()->k(); const size_t s = ctx->getState()->s(); @@ -458,7 +449,6 @@ NdArrayRef MSB::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { NdArrayRef AddBB::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, const NdArrayRef& rhs) const { - SPU_TRACE_MPC_LEAF(ctx, lhs, rhs); const size_t nbits = maxNumBits(lhs, rhs); const auto field = lhs.eltype().as()->field(); const auto [x_val, x_mac] = BShareSwitch2Nbits(lhs, nbits); @@ -475,8 +465,6 @@ NdArrayRef AddBB::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, NdArrayRef AddBP::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, const NdArrayRef& rhs) const { - SPU_TRACE_MPC_LEAF(ctx, lhs, rhs); - const size_t nbits = maxNumBits(lhs, rhs); const auto field = lhs.eltype().as()->field(); const auto [x_val, x_mac] = BShareSwitch2Nbits(lhs, nbits); @@ -491,7 +479,6 @@ NdArrayRef AddBP::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, #if 0 ArrayRef BitLTBB::proc(KernelEvalContext* ctx, const ArrayRef& lhs, const ArrayRef& rhs) const { - SPU_TRACE_MPC_LEAF(ctx, lhs, rhs); const auto nbits = maxNumBits(lhs, rhs); const auto field = lhs.eltype().as()->field(); const auto numel = lhs.numel(); @@ -517,8 +504,6 @@ ArrayRef BitLTBB::proc(KernelEvalContext* ctx, const ArrayRef& lhs, #else NdArrayRef BitLTBB::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, const NdArrayRef& rhs) const { - SPU_TRACE_MPC_LEAF(ctx, lhs, rhs); - auto res0 = 
wrap_bitle_bb(ctx->sctx(), lhs, rhs); auto res1 = wrap_bitle_bb(ctx->sctx(), rhs, lhs); auto eq = wrap_and_bb(ctx->sctx(), res0, res1); @@ -531,8 +516,6 @@ NdArrayRef BitLTBB::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, NdArrayRef BitLEBB::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, const NdArrayRef& rhs) const { - SPU_TRACE_MPC_LEAF(ctx, lhs, rhs); - const auto nbits = maxNumBits(lhs, rhs); const auto field = lhs.eltype().as()->field(); diff --git a/libspu/mpc/spdz2k/value.h b/libspu/mpc/spdz2k/value.h index e4708687..fc5e52bb 100644 --- a/libspu/mpc/spdz2k/value.h +++ b/libspu/mpc/spdz2k/value.h @@ -15,7 +15,6 @@ #pragma once #include "libspu/core/ndarray_ref.h" -#include "libspu/core/type_util.h" namespace spu::mpc::spdz2k { diff --git a/libspu/mpc/utils/BUILD.bazel b/libspu/mpc/utils/BUILD.bazel index 6f4560a7..837333f6 100644 --- a/libspu/mpc/utils/BUILD.bazel +++ b/libspu/mpc/utils/BUILD.bazel @@ -38,7 +38,7 @@ spu_cc_library( name = "simulate", hdrs = ["simulate.h"], deps = [ - "@yacl//yacl/link", + "@yacl//yacl/link:test_util", ], ) @@ -47,7 +47,7 @@ spu_cc_library( srcs = ["permute.cc"], hdrs = ["permute.h"], deps = [ - "//libspu/core", + "//libspu/core:ndarray_ref", ], ) @@ -72,7 +72,7 @@ spu_cc_library( }), deps = [ ":linalg", - "//libspu/core", + "//libspu/core:ndarray_ref", "@yacl//yacl/crypto/tools:prg", "@yacl//yacl/crypto/utils:rand", "@yacl//yacl/utils:parallel", @@ -102,7 +102,6 @@ spu_cc_library( hdrs = ["linalg.h"], deps = [ "//libspu/core:parallel_utils", - "//libspu/core:prelude", "@com_github_eigenteam_eigen//:eigen3", ] + select({ "@bazel_tools//src/conditions:darwin_x86_64": ["@local_homebrew_x64//:openmp"], diff --git a/libspu/mpc/utils/linalg.cc b/libspu/mpc/utils/linalg.cc index 254c90ac..478c44cc 100644 --- a/libspu/mpc/utils/linalg.cc +++ b/libspu/mpc/utils/linalg.cc @@ -14,6 +14,8 @@ #include "libspu/mpc/utils/linalg.h" +#include "libspu/core/parallel_utils.h" + namespace spu::mpc::linalg::detail { void setEigenParallelLevel(int64_t expected_threads) { diff --git a/libspu/mpc/utils/linalg.h b/libspu/mpc/utils/linalg.h index b36115da..ea50b101 100644 --- a/libspu/mpc/utils/linalg.h +++ b/libspu/mpc/utils/linalg.h @@ -16,11 +16,6 @@ #include -#include "spdlog/spdlog.h" - -#include "libspu/core/parallel_utils.h" -#include "libspu/core/prelude.h" - #define EIGEN_HAS_OPENMP #include "Eigen/Core" diff --git a/libspu/mpc/utils/linalg_test.cc b/libspu/mpc/utils/linalg_test.cc index 5e8b0226..189b5821 100644 --- a/libspu/mpc/utils/linalg_test.cc +++ b/libspu/mpc/utils/linalg_test.cc @@ -27,8 +27,8 @@ TEST(LinalgTest, MatMulBasic) { matmul(4, 2, 3, A.data(), 3, 1, B.data(), 2, 1, C.data(), 2, 1); - std::vector expected = {22.f, 28.f, 58.f, 76.f, - 94.f, 124.f, 130.f, 172.f}; + std::vector expected = {22.F, 28.F, 58.F, 76.F, + 94.F, 124.F, 130.F, 172.F}; EXPECT_EQ(C, expected); } @@ -46,8 +46,8 @@ TEST(LinalgTest, MatMulStrides) { matmul(4, 2, 3, A.data(), 6, 2, B.data(), 4, 2, C.data(), 2, 1); - std::vector expected = {22.f, 28.f, 58.f, 76.f, - 94.f, 124.f, 130.f, 172.f}; + std::vector expected = {22.F, 28.F, 58.F, 76.F, + 94.F, 124.F, 130.F, 172.F}; EXPECT_EQ(C, expected); } diff --git a/libspu/mpc/utils/ring_ops.cc b/libspu/mpc/utils/ring_ops.cc index d480bda6..57b7f0dd 100644 --- a/libspu/mpc/utils/ring_ops.cc +++ b/libspu/mpc/utils/ring_ops.cc @@ -22,7 +22,6 @@ #include "absl/types/span.h" #include "yacl/crypto/tools/prg.h" #include "yacl/crypto/utils/rand.h" -#include "yacl/utils/parallel.h" #include "libspu/mpc/utils/linalg.h" diff --git 
a/libspu/mpc/utils/ring_ops.h b/libspu/mpc/utils/ring_ops.h index c5b86ed1..5565ff1c 100644 --- a/libspu/mpc/utils/ring_ops.h +++ b/libspu/mpc/utils/ring_ops.h @@ -15,7 +15,6 @@ #pragma once #include "libspu/core/ndarray_ref.h" -#include "libspu/core/type.h" namespace spu::mpc { diff --git a/libspu/mpc/utils/simulate.h b/libspu/mpc/utils/simulate.h index 64108854..f1b76bee 100644 --- a/libspu/mpc/utils/simulate.h +++ b/libspu/mpc/utils/simulate.h @@ -17,7 +17,6 @@ #include #include -#include "yacl/link/link.h" #include "yacl/link/test_util.h" namespace spu::mpc::utils { diff --git a/libspu/psi/tools/generate_psi.py b/libspu/psi/tools/generate_psi.py index 06263db7..f5d9951b 100644 --- a/libspu/psi/tools/generate_psi.py +++ b/libspu/psi/tools/generate_psi.py @@ -13,10 +13,9 @@ # limitations under the License. -from random import randint -from random import sample import csv import sys +from random import randint, sample def random_with_N_digits(n): diff --git a/pyproject.toml b/pyproject.toml index d613f1f8..7f4a9905 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,3 +18,16 @@ ignore_roles = [ ignore_languages = [ "cpp" ] + +[tool.pyright] +include = [ + "spu", + "sml", + "examples", +] + +reportMissingImports = true +reportMissingTypeStubs = false + +pythonVersion = "3.9" +pythonPlatform = "Linux" diff --git a/setup.py b/setup.py index fd658748..8ddcf1d5 100644 --- a/setup.py +++ b/setup.py @@ -17,14 +17,15 @@ import io import logging import os +import platform import re import shutil import subprocess import sys +from datetime import datetime, timedelta + import setuptools import setuptools.command.build_ext -import platform -from datetime import datetime, timedelta logger = logging.getLogger(__name__) diff --git a/sml/decomposition/nmf.py b/sml/decomposition/nmf.py index 6776fe78..fff3c610 100644 --- a/sml/decomposition/nmf.py +++ b/sml/decomposition/nmf.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -import numpy as np import jax.numpy as jnp +import numpy as np def update_w(X, W, H, H_sum, HHt, XHt, l1_reg_W, l2_reg_W): diff --git a/sml/decomposition/pca.py b/sml/decomposition/pca.py index 1b06b3eb..6c942ed8 100644 --- a/sml/decomposition/pca.py +++ b/sml/decomposition/pca.py @@ -12,10 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -import jax.numpy as jnp -from enum import Enum import os import sys +from enum import Enum + +import jax.numpy as jnp sys.path.append(os.path.join(os.path.dirname(__file__), '../../')) diff --git a/sml/decomposition/tests/pca_test.py b/sml/decomposition/tests/pca_test.py index 2d2c0b43..2f20c8a2 100644 --- a/sml/decomposition/tests/pca_test.py +++ b/sml/decomposition/tests/pca_test.py @@ -20,6 +20,7 @@ import numpy as np from jax import random from sklearn.decomposition import PCA as SklearnPCA + import spu.spu_pb2 as spu_pb2 import spu.utils.simulation as spsim diff --git a/sml/linear_model/emulations/glm_emul.py b/sml/linear_model/emulations/glm_emul.py index 2f971da8..74421c5e 100644 --- a/sml/linear_model/emulations/glm_emul.py +++ b/sml/linear_model/emulations/glm_emul.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import numpy as np -from sml.linear_model.glm import _GeneralizedLinearRegressor import sml.utils.emulation as emulation import spu.utils.distributed as ppd +from sml.linear_model.glm import _GeneralizedLinearRegressor n_samples, n_features = 100, 5 diff --git a/sml/linear_model/glm.py b/sml/linear_model/glm.py index 66103698..5482af84 100644 --- a/sml/linear_model/glm.py +++ b/sml/linear_model/glm.py @@ -15,6 +15,7 @@ import warnings import jax.numpy as jnp + from sml.linear_model.utils.link import * from sml.linear_model.utils.loss import * from sml.linear_model.utils.solver import * diff --git a/sml/linear_model/tests/glm_test.py b/sml/linear_model/tests/glm_test.py index 35cbf1fa..a9e7c734 100644 --- a/sml/linear_model/tests/glm_test.py +++ b/sml/linear_model/tests/glm_test.py @@ -16,12 +16,6 @@ import jax.numpy as jnp import numpy as np -from sml.linear_model.glm import ( - GammaRegressor, - PoissonRegressor, - TweedieRegressor, - _GeneralizedLinearRegressor, -) from sklearn.linear_model._glm import GammaRegressor as std_GammaRegressor from sklearn.linear_model._glm import PoissonRegressor as std_PoissonRegressor from sklearn.linear_model._glm import TweedieRegressor as std_TweedieRegressor @@ -31,6 +25,12 @@ import spu.spu_pb2 as spu_pb2 import spu.utils.simulation as spsim +from sml.linear_model.glm import ( + GammaRegressor, + PoissonRegressor, + TweedieRegressor, + _GeneralizedLinearRegressor, +) verbose = 0 n_samples, n_features = 100, 5 diff --git a/sml/metrics/classification/auc.py b/sml/metrics/classification/auc.py index 4e86ce01..1072499a 100644 --- a/sml/metrics/classification/auc.py +++ b/sml/metrics/classification/auc.py @@ -13,9 +13,11 @@ # limitations under the License. from typing import Tuple -import jax.numpy as jnp + import jax -from spu.ops.groupby.groupby import groupby_sorted +import jax.numpy as jnp + +from spu.ops.groupby import groupby_sorted def binary_clf_curve(sorted_pairs: jnp.array) -> Tuple[jnp.array, jnp.array, jnp.array]: diff --git a/sml/metrics/classification/classification.py b/sml/metrics/classification/classification.py index 4df70a05..6b8c7fe2 100644 --- a/sml/metrics/classification/classification.py +++ b/sml/metrics/classification/classification.py @@ -13,12 +13,13 @@ # limitations under the License. 
from typing import Tuple + import jax import jax.numpy as jnp - -from spu.ops.groupby.groupby import groupby, groupby_sum_no_shuffle from auc import binary_roc_auc +from spu.ops.groupby import groupby, groupby_sum + def roc_auc_score(y_true, y_pred): sorted_arr = create_sorted_label_score_pair(y_true, y_pred) @@ -53,7 +54,7 @@ def bin_counts( bin_sorted, bin_count_cols, _, effective_rows = groupby( [-bins], [y_true, y_true_negate] ) - bin_count_matrix = groupby_sum_no_shuffle(bin_count_cols, effective_rows) + bin_count_matrix = groupby_sum(bin_count_cols, effective_rows) return ( -bin_sorted[0], bin_count_matrix[:, 0], diff --git a/sml/metrics/classification/classification_emul.py b/sml/metrics/classification/classification_emul.py index 01ba0b50..fbddcf8d 100644 --- a/sml/metrics/classification/classification_emul.py +++ b/sml/metrics/classification/classification_emul.py @@ -19,10 +19,8 @@ # add ops dir to the path sys.path.append(os.path.join(os.path.dirname(__file__), '../../')) -from sml.metrics.classification.classification import roc_auc_score - - import sml.utils.emulation as emulation +from sml.metrics.classification.classification import roc_auc_score # TODO: design the enumation framework, just like py.unittest diff --git a/sml/metrics/classification/classification_test.py b/sml/metrics/classification/classification_test.py index 43c7b7a5..8ff59db5 100644 --- a/sml/metrics/classification/classification_test.py +++ b/sml/metrics/classification/classification_test.py @@ -16,8 +16,8 @@ import time import unittest -import numpy as np import jax.numpy as jnp +import numpy as np import spu.spu_pb2 as spu_pb2 import spu.utils.simulation as spsim @@ -25,14 +25,14 @@ # add ops dir to the path sys.path.append(os.path.join(os.path.dirname(__file__), '../../')) +from sklearn.metrics import roc_auc_score as sk_roc_auc_score + from sml.metrics.classification.classification import ( - roc_auc_score, - equal_obs, bin_counts, + equal_obs, + roc_auc_score, ) -from sklearn.metrics import roc_auc_score as sk_roc_auc_score - class UnitTests(unittest.TestCase): def test_simple(self): diff --git a/sml/utils/fxp_approx.py b/sml/utils/fxp_approx.py index 52dcf111..932bc0fa 100644 --- a/sml/utils/fxp_approx.py +++ b/sml/utils/fxp_approx.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import jax.numpy as jnp from enum import Enum +import jax.numpy as jnp + class SigType(Enum): T1 = 't1' diff --git a/spu/__init__.py b/spu/__init__.py index 8142c5f0..61c1d9cf 100644 --- a/spu/__init__.py +++ b/spu/__init__.py @@ -13,26 +13,22 @@ # limitations under the License. -from .version import __version__ # type: ignore - +from . import pir, psi +from .api import Io, Runtime, check_cpu_feature, compile +from .intrinsic import * from .spu_pb2 import ( # type: ignore + CompilerOptions, DataType, - Visibility, - PtType, - ProtocolKind, + ExecutableProto, FieldType, - ShapeProto, + ProtocolKind, + PtType, RuntimeConfig, - ExecutableProto, - CompilerOptions, + ShapeProto, + Visibility, ) - -from .api import Io, Runtime, compile, check_cpu_feature from .utils import simulation -from .intrinsic import * - -from . import pir -from . 
import psi +from .version import __version__ # type: ignore __all__ = [ "__version__", diff --git a/spu/intrinsic/__init__.py b/spu/intrinsic/__init__.py index ffff546e..3559fa8a 100644 --- a/spu/intrinsic/__init__.py +++ b/spu/intrinsic/__init__.py @@ -12,9 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .example_impl import example - from .example_binary_impl import example_binary +from .example_impl import example # DO-NOT-EDIT:ADD_IMPORT diff --git a/spu/intrinsic/example_binary_impl.py b/spu/intrinsic/example_binary_impl.py index 29a9593b..7a7f672d 100644 --- a/spu/intrinsic/example_binary_impl.py +++ b/spu/intrinsic/example_binary_impl.py @@ -2,6 +2,7 @@ from functools import partial +import numpy as np from jax import core, dtypes from jax.core import ShapedArray from jax.interpreters import ad, batching, mlir, xla @@ -9,8 +10,6 @@ # from jax.lib import xla_client from jaxlib.hlo_helpers import custom_call -import numpy as np - # Public facing interface def example_binary(in1, in2): diff --git a/spu/libspu.cc b/spu/libspu.cc index c2062471..2e68f184 100644 --- a/spu/libspu.cc +++ b/spu/libspu.cc @@ -515,13 +515,20 @@ class IoWrapper { cur.dtype(), prev.dtype()); } - auto ndarr = ptr_->combineShares(shares); - SPU_ENFORCE(ndarr.eltype().isa(), "expect decode to pt_type, got {}", - ndarr.eltype()); + const PtType pt_type = ptr_->getPtType(shares); + std::vector shape = {shares.front().shape().begin(), + shares.front().shape().end()}; - const auto pt_type = ndarr.eltype().as()->pt_type(); - std::vector shape = {ndarr.shape().begin(), ndarr.shape().end()}; - return py::array(py::dtype(PtTypeToPyFormat(pt_type)), shape, ndarr.data()); + py::array ret(py::dtype(PtTypeToPyFormat(pt_type)), shape); + const py::buffer_info& binfo = ret.request(); + + spu::PtBufferView ret_view( + binfo.ptr, pt_type, Shape(binfo.shape.begin(), binfo.shape.end()), + ByteToElementStrides(binfo.strides.begin(), binfo.strides.end(), + binfo.itemsize)); + + ptr_->combineShares(shares, &ret_view); + return ret; } }; diff --git a/spu/ops/__init__.py b/spu/ops/__init__.py new file mode 100644 index 00000000..602546f7 --- /dev/null +++ b/spu/ops/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2021 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/spu/ops/groupby/BUILD.bazel b/spu/ops/groupby/BUILD.bazel index f44ad836..2b6747a0 100644 --- a/spu/ops/groupby/BUILD.bazel +++ b/spu/ops/groupby/BUILD.bazel @@ -12,20 +12,82 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-load("@rules_python//python:defs.bzl", "py_binary", "py_library", "py_test") +load("@rules_python//python:defs.bzl", "py_library", "py_test") package(default_visibility = ["//visibility:public"]) py_library( name = "groupby", - srcs = ["groupby.py"], + srcs = [ + "__init__.py", + ], + deps = [ + ":aggregation", + ":groupby_via_shuffle", + ":postprocess", + ":segmentation", + ":shuffle", + ":utils", + ], +) + +py_library( + name = "segmentation", + srcs = [ + "segmentation.py", + ], + deps = [":utils"], +) + +py_library( + name = "aggregation", + srcs = [ + "aggregation.py", + ], + deps = [":utils"], +) + +py_library( + name = "utils", + srcs = [ + "utils.py", + ], +) + +py_library( + name = "groupby_via_shuffle", + srcs = [ + "groupby_via_shuffle.py", + ], + deps = [ + ":aggregation", + ":shuffle", + ], +) + +py_library( + name = "shuffle", + srcs = [ + "shuffle.py", + ], + deps = [":utils"], +) + +py_library( + name = "postprocess", + srcs = [ + "postprocess.py", + ], ) py_test( name = "groupby_test", srcs = ["groupby_test.py"], deps = [ - ":groupby", + ":aggregation", + ":groupby_via_shuffle", + ":postprocess", + ":segmentation", "//spu:init", "//spu/utils:simulation", ], diff --git a/spu/ops/groupby/README.md b/spu/ops/groupby/README.md new file mode 100644 index 00000000..f5f4d265 --- /dev/null +++ b/spu/ops/groupby/README.md @@ -0,0 +1,79 @@ +# The Groupby Operation + +## What is groupby operation? + +The groupby operation in pandas is used to split data into groups based on a set of specific attributes or columns. +It allows us to apply functions or calculations to each group, facilitating analysis and summarization of data. +The results obtained provide insights on patterns and characteristics within each group. + +## How to do groupby operation using pure jax numpy (secure and fast)? + +our groupby operation is achived in two major steps and one optional postprocess step: + +1. segmentation: segmentation is a process of splitting data into groups based on a set of specific attributes or columns. +2. aggregation: aggregation is a process of combining data within each group to obtain a single value. +3. (optional) shuffling: shuffle the results, necessary if count statistics is not revealed. + +## What does the segmentation do? + +The segmentation consists of 3 steps: + +1. sort the table according to the key_columns. +2. the group ends are found by finding the differences in the key columns. +3. the group numbers are assigned. + +Step 1 will produce: + +* key_columns_sorted +* target_columns_sorted + +Step 2 will produce: + +* segment_end_marks + +Step 3 will produce: + +* segment_ids + +We call these 4 results the "segmentation information", because they entail which group each element is assigned into. + +## What does the aggregation do? + +Based on the segmentation information, we are now ready to aggregate the data. +For each group, we will perform sum, or min, or max, etc operation. + +### How to do this fast? + +The main idea is a classfical prefix sum with some modifications: + +1. We use one column to indicate the membership of each sample to its group. +2. We devise a modified version of operation which has 3 properties: + + * correctness: The binary operation scanning through the list produces the final statistics we care. + (e.g. the addition of all elements is indeed the sum) + * indicator awareness: The binary operation is aware of the membership of each sample to its group. + * associativity: The binary operation is associative. + +3. 
We use associative scan to do the prefix sum in a parallel manner. + +## Why shuffle? + +By the definition of MPC, we cannot reveal any information except what can be inferred from the result. +When we compute statistics other than count, the number of elements in each group cannot be inferred from the result. +Hence the count of each group must be protected. + +Our aggregation algorithms produce the results in the form of a sparse matrix, with group statistics at the end of each group and zeros elsewhere. +Without shuffling, the count statistics can be inferred from the positions of the non-zero elements. + +So we need to shuffle the results to make sure that the count statistics are not revealed. + +As a side effect, we also need some (cleartext) postprocessing steps to clean up the shuffled results. + +## Organization of files + +1. segmentation.py: contains the implementation of the segmentation algorithm. +2. aggregation.py: contains the implementation of the aggregation algorithm (without shuffling, suitable for internal use). +3. shuffle.py: contains the implementation of the shuffle algorithm. +4. utils.py: contains utility functions for the algorithms. +5. groupby_via_shuffle.py: contains the aggregation operations with shuffling (suitable for application use). +6. postprocess.py: contains the (cleartext) postprocessing operations to clean up the results. diff --git a/spu/ops/groupby/__init__.py b/spu/ops/groupby/__init__.py index 602546f7..14b07481 100644 --- a/spu/ops/groupby/__init__.py +++ b/spu/ops/groupby/__init__.py @@ -11,3 +11,43 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + + +from spu.ops.groupby.aggregation import ( + groupby_count, + groupby_count_cleartext, + groupby_max, + groupby_mean, + groupby_sum, + groupby_var, +) +from spu.ops.groupby.groupby_via_shuffle import ( + groupby_max_via_shuffle, + groupby_mean_via_shuffle, + groupby_min_via_shuffle, + groupby_sum_via_shuffle, + groupby_var_via_shuffle, +) +from spu.ops.groupby.postprocess import groupby_agg_postprocess, view_key_postprocessing +from spu.ops.groupby.segmentation import groupby, groupby_sorted +from spu.ops.groupby.shuffle import shuffle_cols, shuffle_matrix + +__all__ = [ + "groupby", + "groupby_sorted", + "groupby_count", + "groupby_max", + "groupby_sum", + "groupby_var", + "groupby_mean", + "shuffle_cols", + "shuffle_matrix", + "groupby_max_via_shuffle", + "groupby_mean_via_shuffle", + "groupby_var_via_shuffle", + "groupby_min_via_shuffle", + "groupby_sum_via_shuffle", + "groupby_count_cleartext", + "groupby_agg_postprocess", + "view_key_postprocessing", +] diff --git a/spu/ops/groupby/aggregation.py b/spu/ops/groupby/aggregation.py new file mode 100644 index 00000000..ee2d6811 --- /dev/null +++ b/spu/ops/groupby/aggregation.py @@ -0,0 +1,151 @@ +# Copyright 2023 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
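The segment-aware prefix sum described in the README (and implemented in aggregation.py just below with jax.lax.associative_scan) is easiest to see with a small sequential reference. The following is an illustrative sketch only, with made-up toy arrays: it uses plain NumPy and an explicit loop, whereas the real kernel expresses the same recurrence as an associative binary operation (segment_aware_addition) so the scan can be evaluated in logarithmic depth.

```python
import numpy as np

# Toy sorted data: three groups of sizes 2, 3 and 1 (keys already sorted).
values        = np.array([10, 20, 1, 2, 3, 7])
seg_end_marks = np.array([0, 1, 0, 0, 1, 1])  # 1 marks the last row of each group
# Membership indicator: 1 means "this row continues the previous row's group".
group_mask = 1 - np.roll(seg_end_marks, 1)

# Sequential reference of the segment-aware prefix sum: the running value
# resets whenever a new group starts (indicator == 0).
running, prefix = 0, []
for m, v in zip(group_mask, values):
    running = v if m == 0 else running + v
    prefix.append(running)

# Keep only the group-end positions; everything else is zeroed out,
# which is exactly the sparse matrix the README mentions.
sparse_sums = seg_end_marks * np.array(prefix)
print(sparse_sums)  # [ 0 30  0  0  6  7] -> group sums 30, 6, 7
```

Swapping the addition for max or min in the recurrence gives the corresponding groupby max/min, which is why the same groupby_agg driver below is reused for all of them.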
+ + +import jax +import jax.numpy as jnp + +from spu.ops.groupby.utils import cols_to_matrix, matrix_to_cols + + +def groupby_agg( + cols, + seg_end_marks, + segment_aware_ops, +) -> jnp.ndarray: + """Performs groupby aggregation operation. + + The returns of this function are NOT safe to open (count of group elements revealed) + However, if user already asked about count statistics, then we can safely open it. + + return: + group_agg_matrix: + shape = (n_samples, n_cols) + group aggregations + padded with zeros. + + """ + group_mask = jnp.ones(seg_end_marks.shape) - jnp.roll(seg_end_marks, 1) + + X = jnp.vstack([group_mask] + list(cols)).T + + X_prefix_sum = jax.lax.associative_scan(segment_aware_ops, X, axis=0) + X_prefix_sum_masked = seg_end_marks.reshape(-1, 1) * X_prefix_sum + + return X_prefix_sum_masked[:, 1:] + + +def groupby_transform(seg_end_marks, group_agg_matrix): + """broadcast the result of groupby_agg in a group wise manner + [[0,0,0,b1,0,0,0,b2], + [0,0,0,c1, 0,0,0,c2]] -> [[b1, b1, b1, b1, b2,b2,b2,b2],[c1,c1,c1,c1,c2,c2,c2,c2]] + """ + + group_agg_matrix_offseted = group_agg_matrix + # perform groupby sum + group_mask = jnp.ones(seg_end_marks.shape) - seg_end_marks + X = jnp.hstack([group_mask.reshape(-1, 1), group_agg_matrix_offseted]) + X_prefix_sum = jax.lax.associative_scan( + segment_aware_addition, X, axis=0, reverse=True + ) + # restore the offset + return X_prefix_sum[:, 1:] + + +def groupby_sum(cols, seg_end_marks) -> jnp.ndarray: + return groupby_agg(cols, seg_end_marks, segment_aware_addition) + + +def groupby_max(cols, seg_end_marks) -> jnp.ndarray: + return groupby_agg(cols, seg_end_marks, segment_aware_max) + + +def groupby_min(cols, seg_end_marks) -> jnp.ndarray: + return groupby_agg(cols, seg_end_marks, segment_aware_min) + + +def groupby_count(cols, seg_end_marks): + """groupby count, it does not require cleartext data""" + ones = jnp.ones(cols[0].shape) + group_count_matrix = groupby_agg( + [ones], seg_end_marks, segment_aware_ops=segment_aware_addition + ) + return group_count_matrix + + +# the method is simple: open segment_ids and do count in cleartext +# It is supposed to be faster than the groupby_count which outputs ciphertext. +# in SPU all NaN values are encoded as 0, so count becomes trivial. +# cleartext function: +# further if a query includes count, the shuffle ops in groupby sum can be skipped +def groupby_count_cleartext(opened_segment_ids): + """Count the number of elements in each group.""" + _, counts = jnp.unique(opened_segment_ids, return_counts=True) + return counts + + +def groupby_mean(cols, seg_end_marks): + assert len(cols) > 0, "at least one col is required" + + # note nan are zeros in SPU, count is the same for all columns + group_count_matrix = groupby_count(cols, seg_end_marks) + return grouby_mean_given_count(cols, seg_end_marks, group_count_matrix) + + +def grouby_mean_given_count(cols, seg_end_marks, group_count_matrix): + group_sum_matrix = groupby_agg( + cols, + seg_end_marks, + segment_aware_ops=segment_aware_addition, + ) + return group_sum_matrix / group_count_matrix + + +def groupby_var(cols, seg_end_marks): + """Perform groupby var operation and shuffle the results. + The inputs are groupby operation's output. 
+ + recall that var(X) = (x_i - x_mean)^2 / (N - 1) + """ + + group_count_matrix = groupby_count(cols, seg_end_marks) + group_mean_matrix = grouby_mean_given_count(cols, seg_end_marks, group_count_matrix) + mean_matrix = groupby_transform(seg_end_marks, group_mean_matrix) + raw_matrix = cols_to_matrix(cols) + residual_square_matrix = (raw_matrix - mean_matrix) ** 2 + group_rrs_matrix = groupby_sum( + matrix_to_cols(residual_square_matrix), seg_end_marks + ) + group_var_matrix = ( + seg_end_marks.reshape(-1, 1) * group_rrs_matrix / (group_count_matrix - 1) + ) + return group_var_matrix + + +def segment_aware_addition(row1, row2): + return segment_aware_ops(row1, row2, jnp.add) + + +def segment_aware_max(row1, row2): + return segment_aware_ops(row1, row2, jnp.maximum) + + +def segment_aware_min(row1, row2): + return segment_aware_ops(row1, row2, jnp.minimum) + + +def segment_aware_ops(row1, row2, ops): + cum_part = jnp.where((row2[:, 0] == 1).reshape(-1, 1), ops(row1, row2), row2)[:, 1:] + lead_part = (row1[:, 0] * row2[:, 0]).reshape(-1, 1) + return jnp.c_[lead_part, cum_part] diff --git a/spu/ops/groupby/groupby.py b/spu/ops/groupby/groupby.py deleted file mode 100644 index 5bd9b36c..00000000 --- a/spu/ops/groupby/groupby.py +++ /dev/null @@ -1,292 +0,0 @@ -# Copyright 2023 Ant Group Co., Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import functools -from typing import List, Tuple - -import jax -import jax.numpy as jnp -import numpy as np - -# Conceptually we want to do the following -# However, due to limitations of design, we cannot do this -# we will use other groupby accumulators to do groupby some -# and treat (target_columns_sorted: List[jnp.ndarray], segment_ids: List[jnp.ndarray]) as the groupby object. -# class GroupByObject: -# def __init__( -# self, target_columns_sorted: List[jnp.ndarray], segment_ids: List[jnp.ndarray] -# ): -# self.target_columns_sorted = target_columns_sorted -# self.segment_ids = segment_ids - -# def sum(self, group_num: int): -# """ -# group num should be revealed and accessed from segment ids and taken as a static int. -# """ -# segment_ids = self.segment_ids -# x = self.target_columns_sorted -# return jax.ops.segment_sum(x, segment_ids, num_segments=group_num) - - -def segment_aware_addition(row1, row2): - return segment_aware_ops(row1, row2, jnp.add) - - -def segment_aware_max(row1, row2): - return segment_aware_ops(row1, row2, jnp.maximum) - - -def segment_aware_ops(row1, row2, ops): - cum_part = jnp.where((row2[:, 0] == 1).reshape(-1, 1), ops(row1, row2), row2)[:, 1:] - lead_part = (row1[:, 0] * row2[:, 0]).reshape(-1, 1) - return jnp.c_[lead_part, cum_part] - - -def groupby_agg( - cols, - seg_end_marks, - segment_aware_ops, -) -> jnp.ndarray: - """Groupby Aggregation with no shuffle. Usually used as internal tool in MPC. - - The returns of this function are NOT safe to open (count of group elements revealed) - return: - group_agg_matrix: - shape = (n_samples, n_cols) - group aggregations - padded with zeros. 
- - """ - group_mask = jnp.ones(seg_end_marks.shape) - jnp.roll(seg_end_marks, 1) - - X = jnp.vstack([group_mask] + list(cols)).T - - X_prefix_sum = jax.lax.associative_scan(segment_aware_ops, X, axis=0) - X_prefix_sum_masked = seg_end_marks.reshape(-1, 1) * X_prefix_sum - - return X_prefix_sum_masked[:, 1:] - - -def groupby_sum_no_shuffle(cols, seg_end_marks) -> jnp.ndarray: - return groupby_agg(cols, seg_end_marks, segment_aware_addition) - - -def groupby_agg_via_shuffle( - cols, - seg_end_marks, - segment_ids, - secret_random_order: jnp.ndarray, - segment_aware_ops, -): - """Groupby Aggregation with shuffled outputs. - - trick: output segment_end_marks and group aggregations in shuffled state, - filter to get the group aggregations in cleartext. - - shuffle to protect number of group elements. - - The returns of this function are supposed to be ok to be opened. - return: - segment_ids_shuffled: - shuffled segment ids - shuffled_group_end_masks: - shuffled group end masks - shuffled_group_agg_matrix: - shape = (n_samples, n_cols) - group aggregations shuffled - padded with zeros. - - """ - group_mask = jnp.ones(seg_end_marks.shape) - jnp.roll(seg_end_marks, 1) - - X = jnp.vstack([group_mask] + list(cols)).T - - X_prefix_sum = jax.lax.associative_scan(segment_aware_ops, X, axis=0) - X_prefix_sum_masked = seg_end_marks.reshape(-1, 1) * X_prefix_sum - segment_ids_masked = seg_end_marks * segment_ids - shuffled_cols = jax.lax.sort( - [secret_random_order] - + [segment_ids_masked] - + [seg_end_marks] - + [X_prefix_sum_masked[:, i] for i in range(1, X_prefix_sum_masked.shape[1])], - num_keys=1, - ) - return [ - shuffled_cols[1], - shuffled_cols[2], - jnp.vstack(shuffled_cols[3:]).T, - ] - - -def groupby_sum_via_shuffle( - cols, seg_end_marks, segment_ids, secret_random_order: jnp.ndarray -): - return groupby_agg_via_shuffle( - cols, - seg_end_marks, - segment_ids, - secret_random_order, - segment_aware_ops=segment_aware_addition, - ) - - -def groupby_max_via_shuffle( - cols, seg_end_marks, segment_ids, secret_random_order: jnp.ndarray -): - return groupby_agg_via_shuffle( - cols, - seg_end_marks, - segment_ids, - secret_random_order, - segment_aware_ops=segment_aware_max, - ) - - -# cleartext function -def groupby_agg_postprocess( - segment_ids, seg_end_marks, group_agg_matrix, group_num: int -): - assert ( - isinstance(group_num, int) and group_num > 0 - ), f"group num must be a positive integer. got {group_num}, {type(group_num)}" - if group_num > 1: - filter_mask = seg_end_marks == 1 - segment_ids = segment_ids[filter_mask] - group_agg_matrix = group_agg_matrix[filter_mask] - sorted_results = jax.lax.sort( - [segment_ids] - + [group_agg_matrix[:, i] for i in range(group_agg_matrix.shape[1])], - num_keys=1, - )[1:] - return jnp.vstack(sorted_results).T - else: - return group_agg_matrix[-1] - - -def batch_product(list_of_cols, multiplier_col): - return list( - map( - lambda x: x * multiplier_col, - list_of_cols, - ) - ) - - -def view_key( - key_columns_sorted: List[jnp.ndarray], - seg_end_marks: jnp.ndarray, - secret_random_order: jnp.ndarray, -): - """The secret_random_order must be secret to all parties - trick: open a shuffled array and unique in cleartext - """ - assert len(key_columns_sorted) > 0, "number of keys must be non-empty" - keys = batch_product(key_columns_sorted, seg_end_marks) - assert ( - secret_random_order.shape == key_columns_sorted[0].shape - ), "the secret_random_order should be the same shape as each of the key columns." 
- keys_shuffled = jax.lax.sort([secret_random_order] + keys, num_keys=1)[1:] - return keys_shuffled - - -# function operating on cleartext, used to postprocess opened results. -def view_key_postprocessing(keys, group_num: int): - keys = np.unique(np.vstack(keys).T, axis=0) - if keys.shape[0] > group_num: - keys = keys[1:, :] - return keys - - -def groupby( - key_columns: List[jnp.ndarray], - target_columns: List[jnp.ndarray], -) -> Tuple[List[jnp.ndarray], jnp.ndarray]: - """GroupBy - Given a matrix X, it has multiple columns. - We want to calculate some statistics of target columns grouped by some columns as keys. - This operator completes the first step of GroupBy statistics: transform the matrix x into a form, - that is suitable for subsequent statistics. - - Parameters - ---------- - - key_columns : List[jnp.ndarray] - List of columns that are used as keys, these should be arrays of the same shape. - - target_columns : List[jnp.ndarray] - List of columns that are used as keys, these should be arrays of the same shape as the shape in key columns. - - - Returns - ------- - key_columns_sorted : List[jnp.ndarray] - target_columns_sorted : List[jnp.ndarray] - segment_ids : jnp.ndarray - seg_end_marks : jnp.ndarray - """ - # parameter check. - assert isinstance(key_columns, List) - assert isinstance(target_columns, List) - assert len(key_columns) > 0, "There should be at least one key_column." - assert len(target_columns) > 0, "There should be at least one target_column." - assert ( - len(set(map(lambda x: x.shape, key_columns + target_columns))) == 1 - ), f"Columns' shape should be consistent. {set(map(lambda x: x.shape, key_columns + target_columns))}" - key_columns = key_columns - target_columns = target_columns - sorted_columns = jax.lax.sort( - key_columns + target_columns, num_keys=len(key_columns) - ) - key_columns_sorted = sorted_columns[: len(key_columns)] - target_columns_sorted = sorted_columns[len(key_columns) :] - return groupby_sorted(key_columns_sorted, target_columns_sorted) - - -def groupby_sorted( - key_columns_sorted: List[jnp.ndarray], - target_columns_sorted: List[jnp.ndarray], -) -> Tuple[List[jnp.ndarray], jnp.ndarray]: - key_columns_sorted_rolled = rotate_cols(key_columns_sorted) - seg_end_marks = get_segment_marks(key_columns_sorted, key_columns_sorted_rolled) - mark_accumulated = associative_scan(seg_end_marks) - segment_ids = mark_accumulated - seg_end_marks - return key_columns_sorted, target_columns_sorted, segment_ids, seg_end_marks - - -# the method is simple: open segment_ids and do count in cleartext -# in SPU all NaN values are encoded as 0, so count becomes trivial. 
-# cleartext function: -# further if a query includes count, the shuffle ops in groupby sum can be skipped -def groupby_count(opened_segment_ids): - _, counts = jnp.unique(opened_segment_ids, return_counts=True) - return counts - - -def rotate_cols(key_columns_sorted) -> List[jnp.ndarray]: - return list(map(lambda x: jnp.roll(x, -1), key_columns_sorted)) - - -def get_segment_marks(key_columns_sorted, key_columns_sorted_rolled): - tuple_list = list(zip(key_columns_sorted, key_columns_sorted_rolled)) - equal = [a - b == 0 for (a, b) in tuple_list] - c = ~functools.reduce(lambda x, y: x & y, equal) - c = c.astype(int) - result = jnp.r_[c[: c.size - 1], [1]] - # try - # result = c.at[c.size - 1].set(1) - return result - - -def associative_scan(seg_end_marks): - return jax.lax.associative_scan(jnp.add, seg_end_marks) diff --git a/spu/ops/groupby/groupby_test.py b/spu/ops/groupby/groupby_test.py index c392135e..e1994f08 100644 --- a/spu/ops/groupby/groupby_test.py +++ b/spu/ops/groupby/groupby_test.py @@ -11,8 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os -import sys + import time import unittest @@ -21,113 +20,137 @@ import spu.spu_pb2 as spu_pb2 import spu.utils.simulation as spsim - -# Add the ops directory to the path -sys.path.append(os.path.join(os.path.dirname(__file__), '../../')) - -from groupby import ( - groupby, - groupby_agg_postprocess, - groupby_sum_via_shuffle, +from spu.ops.groupby.aggregation import groupby_count, groupby_count_cleartext +from spu.ops.groupby.groupby_via_shuffle import ( groupby_max_via_shuffle, - groupby_count, - view_key, - view_key_postprocessing, + groupby_mean_via_shuffle, + groupby_min_via_shuffle, + groupby_sum_via_shuffle, + groupby_var_via_shuffle, ) +from spu.ops.groupby.postprocess import groupby_agg_postprocess, view_key_postprocessing +from spu.ops.groupby.segmentation import groupby +from spu.ops.groupby.shuffle import shuffle_cols + + +def groupby_agg_fun(agg): + if agg == 'sum': + return groupby_sum_via_shuffle + elif agg == 'max': + return groupby_max_via_shuffle + elif agg == 'min': + return groupby_min_via_shuffle + elif agg == 'mean': + return groupby_mean_via_shuffle + elif agg == 'count': + return groupby_count + elif agg == 'var': + return groupby_var_via_shuffle + else: + raise ValueError(f'Unknown agg {agg}') + + +def test_fn(agg): + sim = spsim.Simulator.simple(3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64) + + def proc(x1, x2, y): + return groupby([x1[:, 2], x2[:, 3]], [y]) + + def proc_view_key(key_cols, segment_end_marks, key): + return shuffle_cols(key_cols, segment_end_marks, key) + + np.random.seed(1234) + n_rows = 3000 + n_cols = 10 + x1 = np.random.random((n_rows, n_cols)) + x2 = np.random.random((n_rows, n_cols)) + y = np.random.random((n_rows,)) + # groupby category only supports discrete values + # here we are taking the shortcut, in reality, only the key need to be discrete. + # (their difference should be large enough so that their fxp repr are difference) + # we shrink the data value in order to reduce the size of groups. 
+ # in our test data, we will get a group size about 15 + x1 = (x1 * 10 / 3).astype(int) + x2 = (x2 * 10 / 3).astype(int) + y = (y * 10 / 3).astype(int) + start = time.perf_counter() + keys, target_cols, segment_ids, segment_end_marks = spsim.sim_jax(sim, proc)( + x1, x2, y + ) + end = time.perf_counter() + print("groupby takes time", end - start) + X = np.zeros((x1.shape[0], 3)) + X[:, 0] = x1[:, 2] + X[:, 1] = x2[:, 3] + X[:, 2] = y + df = pd.DataFrame( + X, + columns=[f'col{i}' for i in range(3)], + ) + # Perform group by agg using pandas + pandas_groupby_agg = getattr( + df.groupby([df.columns[0], df.columns[1]])[df.columns[2]], agg + )() + num_groups = pandas_groupby_agg.shape[0] + # num_groups can also be obtained by revealing segment_ids[-1] + start = time.perf_counter() + # how to produce a secret random array of shape (row_num, ) is another question (not addressed here). + p1_random_order = np.random.random((X.shape[0],)) + p2_random_order = np.random.random((X.shape[0],)) + secret_random_order = p1_random_order + p2_random_order + proc_agg_shuffle = groupby_agg_fun(agg) + agg_result = spsim.sim_jax(sim, proc_agg_shuffle)( + target_cols, segment_end_marks, segment_ids, secret_random_order + ) + agg_result = groupby_agg_postprocess( + agg_result[0], agg_result[1], agg_result[2], num_groups + ) + end = time.perf_counter() + print("agg takes take", end - start) + assert ( + np.max(abs(pandas_groupby_agg.values.reshape(agg_result.shape) - agg_result)) + < 0.001 + ), f"{pandas_groupby_agg}, ours: \n {agg_result}" + + correct_keys = list(pandas_groupby_agg.index.to_numpy()) + correct_keys = np.array([[*a] for a in correct_keys]) + + # we open shuffled keys and take set(keys) + start = time.perf_counter() + # how to produce a secret random array of shape (row_num, ) is another question (not addressed here). + p1_random_order = np.random.random((X.shape[0],)) + p2_random_order = np.random.random((X.shape[0],)) + secret_random_order = p1_random_order + p2_random_order + keys = spsim.sim_jax(sim, proc_view_key)( + keys, + segment_end_marks, + secret_random_order, + ) + + keys = view_key_postprocessing(keys, num_groups) + + end = time.perf_counter() + print("view key takes take", end - start) + assert ( + np.max(abs(correct_keys - keys)) < 0.001 + ), f"value{ max(abs(correct_keys - keys))}, correct_keys, {correct_keys}, keys{keys}" class UnitTests(unittest.TestCase): def test_sum(self): - sim = spsim.Simulator.simple( - 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 - ) - - def proc(x1, x2, y): - return groupby([x1[:, 2], x2[:, 3]], [y]) - - def proc_sum_shuffle(cols, seg_end_marks, segment_ids, secret_random_order): - return groupby_sum_via_shuffle( - cols, seg_end_marks, segment_ids, secret_random_order - ) - - def proc_view_key(key_cols, segment_end_marks, key): - return view_key(key_cols, segment_end_marks, key) - - n_rows = 3000 - n_cols = 10 - x1 = np.random.random((n_rows, n_cols)) - x2 = np.random.random((n_rows, n_cols)) - y = np.random.random((n_rows,)) - # groupby category only supports discrete values - # here we are taking the shortcut, in reality, only the key need to be discrete. - # (their difference should be large enough so that their fxp repr are difference) - # we shrink the data value in order to reduce the size of groups. 
- # in our test data, we will get a group size about 15 - x1 = (x1 * 10 / 3).astype(int) - x2 = (x2 * 10 / 3).astype(int) - y = (y * 10 / 3).astype(int) - start = time.perf_counter() - keys, target_cols, segment_ids, segment_end_marks = spsim.sim_jax(sim, proc)( - x1, x2, y - ) - end = time.perf_counter() - print("groupby takes time", end - start) - X = np.zeros((x1.shape[0], 3)) - X[:, 0] = x1[:, 2] - X[:, 1] = x2[:, 3] - X[:, 2] = y - df = pd.DataFrame( - X, - columns=[f'col{i}' for i in range(3)], - ) - # Perform group by sum using pandas - pandas_groupby_sum = df.groupby([df.columns[0], df.columns[1]])[ - df.columns[2] - ].sum() - num_groups = pandas_groupby_sum.shape[0] - # num_groups can also be obtained by revealing segment_ids[-1] - start = time.perf_counter() - # how to produce a secret random array of shape (row_num, ) is another question (not addressed here). - p1_random_order = np.random.random((X.shape[0],)) - p2_random_order = np.random.random((X.shape[0],)) - secret_random_order = p1_random_order + p2_random_order - agg_result = spsim.sim_jax(sim, proc_sum_shuffle)( - target_cols, segment_end_marks, segment_ids, secret_random_order - ) - agg_result = groupby_agg_postprocess( - agg_result[0], agg_result[1], agg_result[2], num_groups - ) - end = time.perf_counter() - print("sum takes take", end - start) - assert ( - np.max( - abs(pandas_groupby_sum.values.reshape(agg_result.shape) - agg_result) - ) - < 0.001 - ), f"{pandas_groupby_sum}, ours: \n {agg_result}" + test_fn('sum') - correct_keys = list(pandas_groupby_sum.index.to_numpy()) - correct_keys = np.array([[*a] for a in correct_keys]) + def test_max(self): + test_fn('max') - # we open shuffled keys and take set(keys) - start = time.perf_counter() - # how to produce a secret random array of shape (row_num, ) is another question (not addressed here). 
- p1_random_order = np.random.random((X.shape[0],)) - p2_random_order = np.random.random((X.shape[0],)) - secret_random_order = p1_random_order + p2_random_order - keys = spsim.sim_jax(sim, proc_view_key)( - keys, - segment_end_marks, - secret_random_order, - ) + def test_min(self): + test_fn('min') - keys = view_key_postprocessing(keys, num_groups) + def test_mean(self): + test_fn('mean') - end = time.perf_counter() - print("view key takes take", end - start) - assert ( - np.max(abs(correct_keys - keys)) < 0.001 - ), f"value{ max(abs(correct_keys - keys))}, correct_keys, {correct_keys}, keys{keys}" + def test_var(self): + test_fn('var') def test_count(self): sim = spsim.Simulator.simple( @@ -170,7 +193,7 @@ def proc(x1, x2, y): ].count() start = time.perf_counter() - count_result = groupby_count(segment_ids) + count_result = groupby_count_cleartext(segment_ids) end = time.perf_counter() print("count takes take", end - start) assert ( @@ -197,98 +220,6 @@ def proc(x1, x2, y): np.max(abs(correct_keys - keys)) < 0.001 ), f"value{ max(abs(correct_keys - keys))}, correct_keys, {correct_keys}, keys{keys}" - def test_max(self): - sim = spsim.Simulator.simple( - 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 - ) - - def proc(x1, x2, y): - return groupby([x1[:, 2], x2[:, 3]], [y]) - - def proc_max_shuffle(cols, seg_end_marks, segment_ids, secret_random_order): - return groupby_max_via_shuffle( - cols, seg_end_marks, segment_ids, secret_random_order - ) - - def proc_view_key(key_cols, segment_end_marks, key): - return view_key(key_cols, segment_end_marks, key) - - n_rows = 30000 - n_cols = 10 - x1 = np.random.random((n_rows, n_cols)) - x2 = np.random.random((n_rows, n_cols)) - y = np.random.random((n_rows,)) - # groupby category only supports discrete values - # here we are taking the shortcut, in reality, only the key need to be discrete. - # (their difference should be large enough so that their fxp repr are difference) - # we shrink the data value in order to reduce the size of groups. - # in our test data, we will get a group size about 15 - x1 = (x1 * 10 / 3).astype(int) - x2 = (x2 * 10 / 3).astype(int) - y = (y * 10 / 3).astype(int) - start = time.perf_counter() - keys, target_cols, segment_ids, segment_end_marks = spsim.sim_jax(sim, proc)( - x1, x2, y - ) - end = time.perf_counter() - print("groupby takes time", end - start) - X = np.zeros((x1.shape[0], 3)) - X[:, 0] = x1[:, 2] - X[:, 1] = x2[:, 3] - X[:, 2] = y - df = pd.DataFrame( - X, - columns=[f'col{i}' for i in range(3)], - ) - # Perform group by sum using pandas - pandas_groupby_sum = df.groupby([df.columns[0], df.columns[1]])[ - df.columns[2] - ].max() - num_groups = pandas_groupby_sum.shape[0] - # num_groups can also be obtained by revealing segment_ids[-1] - start = time.perf_counter() - # how to produce a secret random array of shape (row_num, ) is another question (not addressed here). 
- p1_random_order = np.random.random((X.shape[0],)) - p2_random_order = np.random.random((X.shape[0],)) - secret_random_order = p1_random_order + p2_random_order - agg_result = spsim.sim_jax(sim, proc_max_shuffle)( - target_cols, segment_end_marks, segment_ids, secret_random_order - ) - agg_result = groupby_agg_postprocess( - agg_result[0], agg_result[1], agg_result[2], num_groups - ) - end = time.perf_counter() - print("sum takes take", end - start) - assert ( - np.max( - abs(pandas_groupby_sum.values.reshape(agg_result.shape) - agg_result) - ) - < 0.001 - ), f"{pandas_groupby_sum}, ours: \n {agg_result}" - - correct_keys = list(pandas_groupby_sum.index.to_numpy()) - correct_keys = np.array([[*a] for a in correct_keys]) - - # we open shuffled keys and take set(keys) - start = time.perf_counter() - # how to produce a secret random array of shape (row_num, ) is another question (not addressed here). - p1_random_order = np.random.random((X.shape[0],)) - p2_random_order = np.random.random((X.shape[0],)) - secret_random_order = p1_random_order + p2_random_order - keys = spsim.sim_jax(sim, proc_view_key)( - keys, - segment_end_marks, - secret_random_order, - ) - - keys = view_key_postprocessing(keys, num_groups) - - end = time.perf_counter() - print("view key takes take", end - start) - assert ( - np.max(abs(correct_keys - keys)) < 0.001 - ), f"value{ max(abs(correct_keys - keys))}, correct_keys, {correct_keys}, keys{keys}" - if __name__ == "__main__": unittest.main() diff --git a/spu/ops/groupby/groupby_via_shuffle.py b/spu/ops/groupby/groupby_via_shuffle.py new file mode 100644 index 00000000..341bce84 --- /dev/null +++ b/spu/ops/groupby/groupby_via_shuffle.py @@ -0,0 +1,88 @@ +# Copyright 2023 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import jax.numpy as jnp + +from spu.ops.groupby.aggregation import ( + groupby_max, + groupby_mean, + groupby_min, + groupby_sum, + groupby_var, +) +from spu.ops.groupby.shuffle import shuffle_matrix + + +def groupby_sum_via_shuffle( + cols, seg_end_marks, segment_ids, secret_random_order: jnp.ndarray +): + """Perform groupby sum operation and shuffle the results. + The inputs are groupby operation's output. + """ + group_agg_matrix = groupby_sum(cols, seg_end_marks) + return shuffle_matrix( + group_agg_matrix, seg_end_marks, segment_ids, secret_random_order + ) + + +def groupby_max_via_shuffle( + cols, seg_end_marks, segment_ids, secret_random_order: jnp.ndarray +): + """Perform groupby max operation and shuffle the results. + The inputs are groupby operation's output. + """ + group_agg_matrix = groupby_max(cols, seg_end_marks) + return shuffle_matrix( + group_agg_matrix, seg_end_marks, segment_ids, secret_random_order + ) + + +def groupby_min_via_shuffle( + cols, seg_end_marks, segment_ids, secret_random_order: jnp.ndarray +): + """Perform groupby min operation and shuffle the results. + The inputs are groupby operation's output. 
+ """ + group_agg_matrix = groupby_min(cols, seg_end_marks) + return shuffle_matrix( + group_agg_matrix, seg_end_marks, segment_ids, secret_random_order + ) + + +def groupby_mean_via_shuffle( + cols, seg_end_marks, segment_ids, secret_random_order: jnp.ndarray +): + """Perform groupby mean operation and shuffle the results. + The inputs are groupby operation's output. + """ + group_mean_matrix = groupby_mean(cols, seg_end_marks) + + return shuffle_matrix( + group_mean_matrix, seg_end_marks, segment_ids, secret_random_order + ) + + +def groupby_var_via_shuffle( + cols, seg_end_marks, segment_ids, secret_random_order: jnp.ndarray +): + """Perform groupby var operation and shuffle the results. + The inputs are groupby operation's output. + + recall that var(X) = (x_i - x_mean)^2 / (N - 1) + """ + + group_var_matrix = groupby_var(cols, seg_end_marks) + return shuffle_matrix( + group_var_matrix, seg_end_marks, segment_ids, secret_random_order + ) diff --git a/spu/ops/groupby/postprocess.py b/spu/ops/groupby/postprocess.py new file mode 100644 index 00000000..a60b9d09 --- /dev/null +++ b/spu/ops/groupby/postprocess.py @@ -0,0 +1,48 @@ +# Copyright 2023 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import jax +import jax.numpy as jnp +import numpy as np + + +# cleartext function +def groupby_agg_postprocess( + segment_ids, seg_end_marks, group_agg_matrix, group_num: int +): + assert ( + isinstance(group_num, int) and group_num > 0 + ), f"group num must be a positive integer. got {group_num}, {type(group_num)}" + if group_num > 1: + filter_mask = seg_end_marks == 1 + segment_ids = segment_ids[filter_mask] + group_agg_matrix = group_agg_matrix[filter_mask] + sorted_results = jax.lax.sort( + [segment_ids] + + [group_agg_matrix[:, i] for i in range(group_agg_matrix.shape[1])], + num_keys=1, + )[1:] + return jnp.vstack(sorted_results).T + else: + return group_agg_matrix[-1] + + +# function operating on cleartext, used to postprocess opened results. +def view_key_postprocessing(keys, group_num: int): + """We want to view the key in order.""" + keys = np.unique(np.vstack(keys).T, axis=0) + if keys.shape[0] > group_num: + keys = keys[1:, :] + return keys diff --git a/spu/ops/groupby/segmentation.py b/spu/ops/groupby/segmentation.py new file mode 100644 index 00000000..a328d0b3 --- /dev/null +++ b/spu/ops/groupby/segmentation.py @@ -0,0 +1,91 @@ +# Copyright 2023 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import functools +from typing import List, Tuple + +import jax +import jax.numpy as jnp + +from spu.ops.groupby.utils import rotate_cols + + +def groupby( + key_columns: List[jnp.ndarray], + target_columns: List[jnp.ndarray], +) -> Tuple[List[jnp.ndarray], jnp.ndarray]: + """GroupBy + Given a matrix X, it has multiple columns. + We want to calculate some statistics of target columns grouped by some columns as keys. + This operator completes the first step of GroupBy statistics: transform the matrix x into a form, + that is suitable for subsequent statistics. + + Parameters + ---------- + + key_columns : List[jnp.ndarray] + List of columns that are used as keys, these should be arrays of the same shape. + + target_columns : List[jnp.ndarray] + List of columns that are used as keys, these should be arrays of the same shape as the shape in key columns. + + + Returns + ------- + key_columns_sorted : List[jnp.ndarray] + target_columns_sorted : List[jnp.ndarray] + segment_ids : jnp.ndarray + seg_end_marks : jnp.ndarray + """ + # parameter check. + assert isinstance(key_columns, List) + assert isinstance(target_columns, List) + assert len(key_columns) > 0, "There should be at least one key_column." + assert len(target_columns) > 0, "There should be at least one target_column." + assert ( + len(set(map(lambda x: x.shape, key_columns + target_columns))) == 1 + ), f"Columns' shape should be consistent. {set(map(lambda x: x.shape, key_columns + target_columns))}" + key_columns = key_columns + target_columns = target_columns + sorted_columns = jax.lax.sort( + key_columns + target_columns, num_keys=len(key_columns) + ) + key_columns_sorted = sorted_columns[: len(key_columns)] + target_columns_sorted = sorted_columns[len(key_columns) :] + return groupby_sorted(key_columns_sorted, target_columns_sorted) + + +def groupby_sorted( + key_columns_sorted: List[jnp.ndarray], + target_columns_sorted: List[jnp.ndarray], +) -> Tuple[List[jnp.ndarray], jnp.ndarray]: + """Groupby on sorted data.""" + key_columns_sorted_rolled = rotate_cols(key_columns_sorted) + seg_end_marks = get_segment_marks(key_columns_sorted, key_columns_sorted_rolled) + mark_accumulated = associative_scan(seg_end_marks) + segment_ids = mark_accumulated - seg_end_marks + return key_columns_sorted, target_columns_sorted, segment_ids, seg_end_marks + + +def get_segment_marks(key_columns_sorted, key_columns_sorted_rolled): + tuple_list = list(zip(key_columns_sorted, key_columns_sorted_rolled)) + equal = [a - b == 0 for (a, b) in tuple_list] + c = ~functools.reduce(lambda x, y: x & y, equal) + c = c.astype(int) + result = jnp.r_[c[: c.size - 1], [1]] + return result + + +def associative_scan(seg_end_marks): + return jax.lax.associative_scan(jnp.add, seg_end_marks) diff --git a/spu/ops/groupby/shuffle.py b/spu/ops/groupby/shuffle.py new file mode 100644 index 00000000..4f32db0e --- /dev/null +++ b/spu/ops/groupby/shuffle.py @@ -0,0 +1,63 @@ +# Copyright 2023 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
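The shuffle step added below (shuffle.py) boils down to sorting every column by one shared random key, so that a single common permutation hides where the group ends sit. A minimal sketch with made-up values follows; note that in a real deployment the permutation key must itself be secret (the tests build it as the sum of per-party random vectors), which is not modeled here.

```python
import jax
import jax.numpy as jnp

# Toy sparse aggregation output: group sums sit at the group ends.
seg_end_marks = jnp.array([0.0, 1.0, 0.0, 0.0, 1.0, 1.0])
group_sums    = jnp.array([0.0, 30.0, 0.0, 0.0, 6.0, 7.0])

# Stand-in for a secret-shared random order; plaintext here for illustration.
secret_order = jnp.array([0.42, 0.11, 0.93, 0.27, 0.66, 0.05])

# Sorting by the first operand only (num_keys=1) applies the same random
# permutation to every column, so the positions of the non-zero entries no
# longer leak how many rows each group had.
_, marks_shuffled, sums_shuffled = jax.lax.sort(
    [secret_order, seg_end_marks, group_sums], num_keys=1
)
print(marks_shuffled)
print(sums_shuffled)
```

Filtering the revealed, shuffled rows down to those whose mark equals 1 then recovers the per-group statistics without exposing the group sizes, which is what postprocess.py does in cleartext.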
+ +from typing import List + +import jax +import jax.numpy as jnp + +from spu.ops.groupby.utils import batch_product + + +def shuffle_matrix( + group_agg_matrix, + seg_end_marks, + segment_ids, + secret_random_order: jnp.ndarray, +): + """ + Shuffle the groupby matrix results for security + """ + segment_ids_masked = seg_end_marks * segment_ids + shuffled_cols = jax.lax.sort( + [secret_random_order] + + [segment_ids_masked] + + [seg_end_marks] + + [group_agg_matrix[:, i] for i in range(group_agg_matrix.shape[1])], + num_keys=1, + ) + return [ + shuffled_cols[1], + shuffled_cols[2], + jnp.vstack(shuffled_cols[3:]).T, + ] + + +def shuffle_cols( + cols_sorted: List[jnp.ndarray], + seg_end_marks: jnp.ndarray, + secret_random_order: jnp.ndarray, +): + """Shuffle the cols sorted based on the secret random order. + Often used to shuffle the key cols before revealing. + We want to view the key without leaking the number of elements in each key. + So we shuffle before revealing the key. + """ + assert len(cols_sorted) > 0, "number of keys must be non-empty" + keys = batch_product(cols_sorted, seg_end_marks) + assert ( + secret_random_order.shape == cols_sorted[0].shape + ), "the secret_random_order should be the same shape as each of the key columns." + cols_shuffled = jax.lax.sort([secret_random_order] + keys, num_keys=1)[1:] + return cols_shuffled diff --git a/spu/ops/groupby/utils.py b/spu/ops/groupby/utils.py new file mode 100644 index 00000000..7617ade3 --- /dev/null +++ b/spu/ops/groupby/utils.py @@ -0,0 +1,39 @@ +# Copyright 2023 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List + +import jax.numpy as jnp + + +def cols_to_matrix(cols): + return jnp.vstack(cols).T + + +def matrix_to_cols(matrix): + return [matrix[:, i] for i in range(matrix.shape[1])] + + +def batch_product(list_of_cols, multiplier_col): + """apply multiplication of multiplier col to each column of list_of_cols""" + return list( + map( + lambda x: x * multiplier_col, + list_of_cols, + ) + ) + + +def rotate_cols(key_columns_sorted) -> List[jnp.ndarray]: + return list(map(lambda x: jnp.roll(x, -1), key_columns_sorted)) diff --git a/spu/pir.py b/spu/pir.py index fcbf9f3c..a62b1c3d 100644 --- a/spu/pir.py +++ b/spu/pir.py @@ -16,17 +16,16 @@ from typing import List +from . import libspu # type: ignore from .pir_pb2 import ( # type: ignore KvStoreType, + PirClientConfig, PirProtocol, PirResultReport, - PirClientConfig, PirServerConfig, PirSetupConfig, ) -from . import libspu # type: ignore - def pir_setup(config: PirSetupConfig) -> List[str]: report_str = libspu.libs.pir_setup(config.SerializeToString()) diff --git a/spu/psi.py b/spu/psi.py index f80ac386..73b14f01 100644 --- a/spu/psi.py +++ b/spu/psi.py @@ -16,6 +16,8 @@ from typing import List +from . import libspu # type: ignore +from .libspu.libs import ProgressData from .psi_pb2 import ( # type: ignore BucketPsiConfig, CurveType, @@ -26,9 +28,6 @@ PsiType, ) -from . 
import libspu # type: ignore -from .libspu.libs import ProgressData - def mem_psi( link: libspu.link.Context, config: MemoryPsiConfig, input_items: List[str] diff --git a/spu/tests/distributed_test.py b/spu/tests/distributed_test.py index f6cc64a8..89454892 100644 --- a/spu/tests/distributed_test.py +++ b/spu/tests/distributed_test.py @@ -87,6 +87,10 @@ def no_in_dict_out(): return {"first": np.array([1, 2]), "second": np.array([3.0, 4.0])} +def tf_fun(x, y): + return tf.add(x, y) + + class UnitTests(unittest.TestCase): @classmethod def setUpClass(cls): @@ -263,13 +267,13 @@ def test_basic_spu_tf(self): npt.assert_equal(ppd.get(d["second"]), np.array([3.0, 4.0])) # immediate input from driver - e = ppd.device("SPU")(tf.add)(np.array([1, 2]), np.array([3, 4])) + e = ppd.device("SPU")(tf_fun)(np.array([1, 2]), np.array([3, 4])) self.assertTrue(isinstance(e, ppd.SPU.Object)) self.assertEqual(e.vtype, spu_pb2.VIS_PUBLIC) npt.assert_equal(ppd.get(e), np.array([4, 6])) # reuse inputs from SPU - c = ppd.device("SPU")(tf.add)(a, a) + c = ppd.device("SPU")(tf_fun)(a, a) self.assertTrue(isinstance(c, ppd.SPU.Object)) self.assertEqual(c.vtype, spu_pb2.VIS_PUBLIC) self.assertTrue(c.device is ppd.current().devices["SPU"]) @@ -277,7 +281,7 @@ def test_basic_spu_tf(self): # reuse a from SPU, x from pyu x = ppd.device("P1")(no_in_one_out)() - c = ppd.device("SPU")(tf.add)(a, x) + c = ppd.device("SPU")(tf_fun)(a, x) self.assertTrue(isinstance(c, ppd.SPU.Object)) self.assertEqual(c.vtype, spu_pb2.VIS_SECRET) self.assertTrue(c.device is ppd.current().devices["SPU"]) diff --git a/spu/tests/frontend_test.py b/spu/tests/frontend_test.py index 6e27ca07..46d4a2ca 100644 --- a/spu/tests/frontend_test.py +++ b/spu/tests/frontend_test.py @@ -94,16 +94,19 @@ def test_jax_compile(self): self.assertEqual(output.dtype, np.dtype("int32")) def test_tf_compile(self): + def foo(x, y): + return tf.add(x, y) + executable, output = spu_fe.compile( spu_fe.Kind.Tensorflow, - tf.add, + foo, (np.array([1, 2]), np.array([2, 4])), {}, ["in1", "in2"], [spu_pb2.VIS_PUBLIC, spu_pb2.VIS_PUBLIC], lambda out_flat: [f'test-out{idx}' for idx in range(len(out_flat))], ) - self.assertEqual(executable.name, "add") + self.assertEqual(executable.name, "foo") self.assertEqual(executable.input_names, ["in1", "in2"]) self.assertEqual(executable.output_names, ["test-out0"]) self.assertTrue( diff --git a/spu/tests/jax_sanity_test.py b/spu/tests/jax_sanity_test.py index ca62ddab..c07f2ebf 100644 --- a/spu/tests/jax_sanity_test.py +++ b/spu/tests/jax_sanity_test.py @@ -23,8 +23,8 @@ from sklearn import metrics from sklearn.datasets import load_breast_cancer -import spu.utils.simulation as ppsim import spu.spu_pb2 as spu_pb2 +import spu.utils.simulation as ppsim # Note: for un-normalized data, grad(sigmoid) is likely to overflow, either with exp/tanh or taylor series diff --git a/spu/tests/jnp_aby3_r128_test.py b/spu/tests/jnp_aby3_r128_test.py index 40ebf707..f9383905 100644 --- a/spu/tests/jnp_aby3_r128_test.py +++ b/spu/tests/jnp_aby3_r128_test.py @@ -17,8 +17,8 @@ import numpy as np -import spu.utils.simulation as ppsim import spu.spu_pb2 as spu_pb2 +import spu.utils.simulation as ppsim from spu.tests.jnp_testbase import JnpTests diff --git a/spu/tests/jnp_aby3_r64_test.py b/spu/tests/jnp_aby3_r64_test.py index 0599dfdd..6595ea3c 100644 --- a/spu/tests/jnp_aby3_r64_test.py +++ b/spu/tests/jnp_aby3_r64_test.py @@ -17,8 +17,8 @@ import numpy as np -import spu.utils.simulation as ppsim import spu.spu_pb2 as spu_pb2 +import spu.utils.simulation as ppsim 
from spu.tests.jnp_testbase import JnpTests diff --git a/spu/tests/jnp_cheetah_r64_test.py b/spu/tests/jnp_cheetah_r64_test.py index 0fd38c9a..b113dbcd 100644 --- a/spu/tests/jnp_cheetah_r64_test.py +++ b/spu/tests/jnp_cheetah_r64_test.py @@ -17,8 +17,8 @@ import numpy as np -import spu.utils.simulation as ppsim import spu.spu_pb2 as spu_pb2 +import spu.utils.simulation as ppsim from spu.tests.jnp_testbase import JnpTests diff --git a/spu/tests/jnp_debug.py b/spu/tests/jnp_debug.py index 956d19cb..54b7003a 100644 --- a/spu/tests/jnp_debug.py +++ b/spu/tests/jnp_debug.py @@ -18,9 +18,9 @@ import jax.numpy as jnp import numpy as np -import spu.utils.simulation as ppsim -import spu.spu_pb2 as spu_pb2 import spu.intrinsic as si +import spu.spu_pb2 as spu_pb2 +import spu.utils.simulation as ppsim if __name__ == "__main__": """ diff --git a/spu/tests/jnp_ref2k_r64_test.py b/spu/tests/jnp_ref2k_r64_test.py index 336333dc..3d536a14 100644 --- a/spu/tests/jnp_ref2k_r64_test.py +++ b/spu/tests/jnp_ref2k_r64_test.py @@ -17,8 +17,8 @@ import numpy as np -import spu.utils.simulation as ppsim import spu.spu_pb2 as spu_pb2 +import spu.utils.simulation as ppsim from spu.tests.jnp_testbase import JnpTests diff --git a/spu/tests/jnp_semi2k_r128_test.py b/spu/tests/jnp_semi2k_r128_test.py index ca0076a3..10295fed 100644 --- a/spu/tests/jnp_semi2k_r128_test.py +++ b/spu/tests/jnp_semi2k_r128_test.py @@ -17,8 +17,8 @@ import numpy as np -import spu.utils.simulation as ppsim import spu.spu_pb2 as spu_pb2 +import spu.utils.simulation as ppsim from spu.tests.jnp_testbase import JnpTests diff --git a/spu/tests/jnp_semi2k_r64_test.py b/spu/tests/jnp_semi2k_r64_test.py index 1a1baa17..dac8c371 100644 --- a/spu/tests/jnp_semi2k_r64_test.py +++ b/spu/tests/jnp_semi2k_r64_test.py @@ -17,8 +17,8 @@ import numpy as np -import spu.utils.simulation as ppsim import spu.spu_pb2 as spu_pb2 +import spu.utils.simulation as ppsim from spu.tests.jnp_testbase import JnpTests diff --git a/spu/tests/jnp_testbase.py b/spu/tests/jnp_testbase.py index 7ce841eb..02e66792 100644 --- a/spu/tests/jnp_testbase.py +++ b/spu/tests/jnp_testbase.py @@ -15,9 +15,9 @@ import collections import itertools -from os import getenv from enum import Enum from functools import partial +from os import getenv import jax.numpy as jnp import numpy as np diff --git a/spu/tests/spu_compiler_test.py b/spu/tests/spu_compiler_test.py index 4bfbbe7b..4e65befa 100644 --- a/spu/tests/spu_compiler_test.py +++ b/spu/tests/spu_compiler_test.py @@ -19,8 +19,8 @@ import numpy as np import numpy.testing as npt -import spu.utils.frontend as spu_fe import spu.spu_pb2 as spu_pb2 +import spu.utils.frontend as spu_fe class UnitTests(unittest.TestCase): diff --git a/spu/utils/distributed.py b/spu/utils/distributed.py index cdd2d065..f5089f36 100644 --- a/spu/utils/distributed.py +++ b/spu/utils/distributed.py @@ -49,12 +49,12 @@ from jax.tree_util import tree_map, tree_unflatten from termcolor import colored -from . import frontend as spu_fe -from . import distributed_pb2 # type: ignore -from . import distributed_pb2_grpc # type: ignore -from .. import libspu # type: ignore from .. import api as spu_api +from .. import libspu # type: ignore from .. import spu_pb2 +from . import distributed_pb2 # type: ignore +from . import distributed_pb2_grpc # type: ignore +from . import frontend as spu_fe """ This module is used as a simple scheduler to demonstrate SPU usage. 
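The spu/ops/groupby kernels introduced above are ordinary JAX functions, so their plaintext semantics can be checked outside of SPU before running on secret data. A minimal sketch, assuming the `groupby` function shown earlier in this diff is importable as `spu.ops.groupby.groupby` (the exact module path is not visible here):

```python
import jax.numpy as jnp

from spu.ops.groupby.groupby import groupby  # hypothetical import path, see note above

# Two key columns and one value column; rows with equal keys form a group.
k1 = jnp.array([1, 1, 2, 2, 2])
k2 = jnp.array([0, 0, 0, 1, 1])
v = jnp.array([10.0, 20.0, 30.0, 40.0, 50.0])

keys_sorted, values_sorted, segment_ids, seg_end_marks = groupby([k1, k2], [v])

# segment_ids labels each sorted row with its group index: [0, 0, 1, 2, 2].
# seg_end_marks is 1 exactly on the last row of each group: [0, 1, 1, 0, 1].
print(segment_ids, seg_end_marks)
```

Because everything is expressed with `jax.lax.sort`, `jnp.roll` and an associative scan, the same function can be fed to SPU (for example through `spu.utils.simulation`) without modification.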
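The shuffle helpers serve a different purpose: before revealing the sorted keys, the rows are masked and permuted so that the observer learns the distinct keys but not how many rows fall into each group. A sketch of the plaintext behaviour, again with an assumed import path and with a public PRNG standing in for the secret permutation that a real deployment would use:

```python
import jax
import jax.numpy as jnp

from spu.ops.groupby.shuffle import shuffle_cols  # hypothetical import path

# A sorted key column and its segment-end marks (last row of each group is marked).
key_sorted = jnp.array([1, 1, 2, 2, 2])
seg_end_marks = jnp.array([0, 1, 0, 0, 1])

# In the MPC setting this ordering must itself be secret; a public PRNG key is used
# here only to illustrate the plaintext semantics.
order = jax.random.uniform(jax.random.PRNGKey(0), shape=key_sorted.shape)

# Non-boundary rows are zeroed out and all rows are emitted in a random order,
# so only one row per distinct key carries its value and group sizes are hidden.
shuffled = shuffle_cols([key_sorted], seg_end_marks, order)
print(shuffled[0])
```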
diff --git a/spu/utils/frontend.py b/spu/utils/frontend.py index 125e4de3..5e527950 100644 --- a/spu/utils/frontend.py +++ b/spu/utils/frontend.py @@ -14,6 +14,7 @@ import functools import warnings from enum import Enum +from threading import Lock from typing import Callable, Dict, Iterable, List from cachetools import LRUCache, cached @@ -21,8 +22,6 @@ from .. import api as spu_api from .. import spu_pb2 -from threading import Lock - _jax_lock = Lock() @@ -73,23 +72,28 @@ def _jax_compilation( fn: Callable, static_argnums, static_argnames, args: List, kwargs: Dict ): import jax - - from jax._src.xla_bridge import register_backend_factory, _backend_lock, _backends - from jax._src.lib import xla_client + from jax._src.xla_bridge import _backend_lock, _backends, register_backend_factory + from jax._src.lib import xla_client, xla_extension_version # Register interpreter backend since we don't want any cpu/gpu/tpu specific optimization - try: + if xla_extension_version < 164: + # interpreter is registerd by default before jaxlib 0.4.13 + pass + else: has_interpreter_backend = False with _backend_lock: if 'interpreter' in _backends: has_interpreter_backend = True - if not has_interpreter_backend: - register_backend_factory( - 'interpreter', xla_client.make_interpreter_client, priority=-100 - ) - finally: - pass # Silent re-register error.... + if xla_extension_version < 194: + # make_interpreter_client has been removed after jaxlib 0.4.16 + register_backend_factory( + 'interpreter', xla_client.make_interpreter_client, priority=-100 + ) + else: + from jax.interpreters.xla import Backend as xla_back + + register_backend_factory('interpreter', xla_back, priority=-100) fn, kwargs = _argnames_partial_except(fn, static_argnames, kwargs) diff --git a/spu/utils/simulation.py b/spu/utils/simulation.py index ac71a003..b67a6e9b 100644 --- a/spu/utils/simulation.py +++ b/spu/utils/simulation.py @@ -22,10 +22,10 @@ from jax import linear_util as jax_lu from jax._src import api_util as japi_util -from . import frontend as spu_fe -from .. import libspu # type: ignore from .. import api as spu_api +from .. import libspu # type: ignore from .. import spu_pb2 +from . import frontend as spu_fe # https://stackoverflow.com/questions/2829329/catch-a-threads-exception-in-the-caller-thread-in-python
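The frontend.py change above replaces the unconditional try/except registration of the `interpreter` backend with a gate on jaxlib's `xla_extension_version`: below 164 the backend is already registered by jaxlib, up to 194 it is registered via `xla_client.make_interpreter_client`, and afterwards via `jax.interpreters.xla.Backend`. A small sketch of how one might check which branch applies for the installed jaxlib (the thresholds are copied from the patch; the private `jax._src.lib` import may move between releases):

```python
# Thresholds 164 and 194 are taken from the patch above.
from jax._src.lib import xla_extension_version

if xla_extension_version < 164:
    branch = "interpreter backend is registered by jaxlib itself"
elif xla_extension_version < 194:
    branch = "register xla_client.make_interpreter_client"
else:
    branch = "register jax.interpreters.xla.Backend"
print(f"xla_extension_version={xla_extension_version}: {branch}")
```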
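As a quick end-to-end check of the pieces touched above (frontend compilation, simulation, and the jnp test modules), the simulation utilities can run a JAX function under an MPC protocol and compare against the plaintext result. A hedged sketch; the ABY3/FM64 choice below mirrors jnp_aby3_r64_test.py and is only an example:

```python
import jax.numpy as jnp
import numpy as np

import spu.spu_pb2 as spu_pb2
import spu.utils.simulation as ppsim

# 3-party ABY3 simulator over the 64-bit ring.
sim = ppsim.Simulator.simple(3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64)


def fn(x, y):
    return jnp.dot(x, y)


x = np.random.rand(3, 4).astype(np.float32)
y = np.random.rand(4, 2).astype(np.float32)

# sim_jax compiles fn through the SPU frontend and runs it on secret-shared inputs.
spu_out = ppsim.sim_jax(sim, fn)(x, y)
np.testing.assert_allclose(spu_out, fn(x, y), rtol=1e-2, atol=1e-2)
```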