diff --git a/.github/workflows/scorecard.yml b/.github/workflows/scorecard.yml index 830fa74c..cdaae1de 100644 --- a/.github/workflows/scorecard.yml +++ b/.github/workflows/scorecard.yml @@ -32,7 +32,7 @@ jobs: steps: - name: "Checkout code" - uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: persist-credentials: false @@ -67,6 +67,6 @@ jobs: # Upload the results to GitHub's code scanning dashboard. - name: "Upload to code-scanning" - uses: github/codeql-action/upload-sarif@c36620d31ac7c881962c3d9dd939c40ec9434f2b # v3.26.12 + uses: github/codeql-action/upload-sarif@48ab28a6f5dbc2a99bf1e0131198dd8f1df78169 # v3.28.0 with: sarif_file: results.sarif diff --git a/CHANGELOG.md b/CHANGELOG.md index 27502fd6..97ff36b0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,10 @@ > > please add your unreleased change here. +## 20241219 + +- [SPU] 0.9.3b0 release +- [Improvement] Optimize exponential computation for semi2k (**experimental**) - [Feature] Add more send/recv actions profiling ## 20240716 diff --git a/bazel/repositories.bzl b/bazel/repositories.bzl index 6ba0fd47..c032474c 100644 --- a/bazel/repositories.bzl +++ b/bazel/repositories.bzl @@ -39,10 +39,10 @@ def _yacl(): http_archive, name = "yacl", urls = [ - "https://github.com/secretflow/yacl/archive/refs/tags/0.4.5b7_nightly_20240930.tar.gz", + "https://github.com/secretflow/yacl/archive/refs/tags/0.4.5b8_nightly_20241014.tar.gz", ], - strip_prefix = "yacl-0.4.5b7_nightly_20240930", - sha256 = "cf8dc7cceb9c5d05df00f1c086feec99d554db3e3cbe101253cf2a5a1adb9072", + strip_prefix = "yacl-0.4.5b8_nightly_20241014", + sha256 = "9141792f07eba507ffd21c57ec3df2ad5fdf90ce605ffb7bc1b7b4e84a9c34fa", ) def _libpsi(): @@ -50,10 +50,10 @@ def _libpsi(): http_archive, name = "psi", urls = [ - "https://github.com/secretflow/psi/archive/refs/tags/v0.4.3.dev240919.tar.gz", + 
"https://github.com/secretflow/psi/archive/refs/tags/v0.5.0.dev241115.tar.gz", ], - strip_prefix = "psi-0.4.3.dev240919", - sha256 = "1ee34fbbd9a8f36dea8f7c45588a858e8c31f3a38e60e1fc67cb428ea79334e3", + strip_prefix = "psi-0.5.0.dev241115", + sha256 = "4d5ccc61282c4f887cee2c12fe3f414dfd7e916952849e92ffb1f6835d657a35", ) def _rules_proto_grpc(): @@ -242,10 +242,10 @@ def _com_github_nvidia_cutlass(): maybe( http_archive, name = "cutlass_archive", - strip_prefix = "cutlass-3.5.1", + strip_prefix = "cutlass-3.6.0", urls = [ - "https://github.com/NVIDIA/cutlass/archive/refs/tags/v3.5.1.tar.gz", + "https://github.com/NVIDIA/cutlass/archive/refs/tags/v3.6.0.tar.gz", ], - sha256 = "20b7247cda2d257cbf8ba59ba3ca40a9211c4da61a9c9913e32b33a2c5883a36", + sha256 = "7576f3437b90d0de5923560ccecebaa1357e5d72f36c0a59ad77c959c9790010", build_file = "@spulib//bazel:nvidia_cutlass.BUILD", ) diff --git a/docs/requirements.txt b/docs/requirements.txt index c4b69bd7..113570aa 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,15 +1,15 @@ myst-parser==4.0.0 rstcheck==6.2.4 -sphinx==8.0.2 -nbsphinx==0.9.5 +sphinx==8.1.3 +nbsphinx==0.9.6 sphinx-autobuild==2024.10.3 sphinx-markdown-parser==0.2.4 sphinxcontrib-actdiag==3.0.0 sphinxcontrib-blockdiag==3.0.0 -sphinxcontrib-mermaid==0.9.2 +sphinxcontrib-mermaid==1.0.0 sphinxcontrib-nwdiag==2.0.0 sphinxcontrib-seqdiag==3.0.0 -pytablewriter==1.2.0 +pytablewriter==1.2.1 linkify-it-py==2.0.3 mdutils==1.6.0 spu>=0.3.1b0 diff --git a/libspu/compiler/front_end/hlo_importer.h b/libspu/compiler/front_end/hlo_importer.h index 8e57b7f0..551f2755 100644 --- a/libspu/compiler/front_end/hlo_importer.h +++ b/libspu/compiler/front_end/hlo_importer.h @@ -28,7 +28,9 @@ class CompilationContext; class HloImporter final { public: + // clang-format off explicit HloImporter(CompilationContext *context) : context_(context) {}; + // clang-format on /// Load a xla module and returns a mlir-hlo module mlir::OwningOpRef diff --git a/libspu/core/config.cc 
b/libspu/core/config.cc index 94648a11..17113f28 100644 --- a/libspu/core/config.cc +++ b/libspu/core/config.cc @@ -62,6 +62,17 @@ void populateRuntimeConfig(RuntimeConfig& cfg) { if (cfg.fxp_exp_mode() == RuntimeConfig::EXP_DEFAULT) { cfg.set_fxp_exp_mode(RuntimeConfig::EXP_TAYLOR); } + if (cfg.fxp_exp_mode() == RuntimeConfig::EXP_PRIME) { + // 0 offset is not supported + if (cfg.experimental_exp_prime_offset() == 0) { + // For FM128 default offset is 13 + if (cfg.field() == FieldType::FM128) { + cfg.set_experimental_exp_prime_offset(13); + } + // TODO: set defaults for other fields, currently only FM128 is + // supported + } + } if (cfg.fxp_exp_iters() == 0) { cfg.set_fxp_exp_iters(8); diff --git a/libspu/dialect/pphlo/IR/fold.cc b/libspu/dialect/pphlo/IR/fold.cc index 7946c07b..f4f42e45 100644 --- a/libspu/dialect/pphlo/IR/fold.cc +++ b/libspu/dialect/pphlo/IR/fold.cc @@ -49,6 +49,14 @@ OpFoldResult ReverseOp::fold(FoldAdaptor) { dims, [&](int64_t dim) { return shapedType.getDimSize(dim) == 1; })) { return input; } + + // reverse(reverse(x, dims), dims) = x + if (auto prev = input.getDefiningOp()) { + if (prev.getDimensions() == dims) { + return prev.getOperand(); + } + } + return {}; } diff --git a/libspu/kernel/hal/BUILD.bazel b/libspu/kernel/hal/BUILD.bazel index ef8a61b9..b5aeef41 100644 --- a/libspu/kernel/hal/BUILD.bazel +++ b/libspu/kernel/hal/BUILD.bazel @@ -127,6 +127,7 @@ spu_cc_test( deps = [ ":fxp_approx", "//libspu/kernel:test_util", + "//libspu/mpc/utils:simulate", ], ) diff --git a/libspu/kernel/hal/fxp_approx.cc b/libspu/kernel/hal/fxp_approx.cc index 9f3105c3..34667e84 100644 --- a/libspu/kernel/hal/fxp_approx.cc +++ b/libspu/kernel/hal/fxp_approx.cc @@ -201,6 +201,31 @@ Value exp_taylor(SPUContext* ctx, const Value& x) { return res; } +Value exp_prime(SPUContext* ctx, const Value& x) { + auto clamped_x = x; + auto offset = ctx->config().experimental_exp_prime_offset(); + auto fxp = ctx->getFxpBits(); + if 
(!ctx->config().experimental_exp_prime_disable_lower_bound()) { + // currently the bound is tied to FM128 + SPU_ENFORCE_EQ(ctx->getField(), FieldType::FM128); + auto lower_bound = (48.0 - offset - 2.0 * fxp) / M_LOG2E; + clamped_x = _clamp_lower(ctx, clamped_x, + constant(ctx, lower_bound, x.dtype(), x.shape())) + .setDtype(x.dtype()); + } + if (ctx->config().experimental_exp_prime_enable_upper_bound()) { + // currently the bound is tied to FM128 + SPU_ENFORCE_EQ(ctx->getField(), FieldType::FM128); + auto upper_bound = (124.0 - 2.0 * fxp - offset) / M_LOG2E; + clamped_x = _clamp_upper(ctx, clamped_x, + constant(ctx, upper_bound, x.dtype(), x.shape())) + .setDtype(x.dtype()); + } + + auto ret = dynDispatch(ctx, "exp_a", clamped_x); + return ret.setDtype(x.dtype()); +} + namespace { // Pade approximation of exp2(x), x is in [0, 1]. @@ -439,13 +464,22 @@ Value f_exp(SPUContext* ctx, const Value& x) { case RuntimeConfig::EXP_PADE: { // The valid input for exp_pade is [-kInputLimit, kInputLimit]. // TODO(junfeng): should merge clamp into exp_pade to save msb ops. 
- const float kInputLimit = 32 / std::log2(std::exp(1)); + const float kInputLimit = 32.0 / std::log2(std::exp(1)); const auto clamped_x = _clamp(ctx, x, constant(ctx, -kInputLimit, x.dtype(), x.shape()), constant(ctx, kInputLimit, x.dtype(), x.shape())) .setDtype(x.dtype()); return detail::exp_pade(ctx, clamped_x); } + case RuntimeConfig::EXP_PRIME: + if (ctx->hasKernel("exp_a")) { + return detail::exp_prime(ctx, x); + } else { + SPU_THROW( + "exp_a is not implemented for this protocol, currently only " + "2pc " + "semi2k is supported."); + } default: SPU_THROW("unexpected exp approximation method {}", ctx->config().fxp_exp_mode()); diff --git a/libspu/kernel/hal/fxp_approx.h b/libspu/kernel/hal/fxp_approx.h index fa401887..44c724f3 100644 --- a/libspu/kernel/hal/fxp_approx.h +++ b/libspu/kernel/hal/fxp_approx.h @@ -38,6 +38,8 @@ Value exp2_pade(SPUContext* ctx, const Value& x); // Works for range [-12.0, 18.0] Value exp_pade(SPUContext* ctx, const Value& x); +Value exp_prime(SPUContext* ctx, const Value& x); + Value tanh_chebyshev(SPUContext* ctx, const Value& x); } // namespace detail diff --git a/libspu/kernel/hal/fxp_approx_test.cc b/libspu/kernel/hal/fxp_approx_test.cc index c79dc434..d540eb2b 100644 --- a/libspu/kernel/hal/fxp_approx_test.cc +++ b/libspu/kernel/hal/fxp_approx_test.cc @@ -20,6 +20,7 @@ #include "libspu/kernel/hal/constants.h" #include "libspu/kernel/hal/type_cast.h" #include "libspu/kernel/test_util.h" +#include "libspu/mpc/utils/simulate.h" namespace spu::kernel::hal { @@ -78,10 +79,35 @@ TEST(FxpTest, ExponentialPade) { << y; } +TEST(FxpTest, ExponentialPrime) { + std::cout << "test exp_prime" << std::endl; + spu::mpc::utils::simulate(2, [&](std::shared_ptr lctx) { + RuntimeConfig conf; + conf.set_protocol(ProtocolKind::SEMI2K); + conf.set_field(FieldType::FM128); + conf.set_fxp_fraction_bits(40); + conf.set_experimental_enable_exp_prime(true); + SPUContext ctx = test::makeSPUContext(conf, lctx); + + auto offset = 
ctx.config().experimental_exp_prime_offset(); + auto fxp = ctx.getFxpBits(); + auto lower_bound = (48.0 - offset - 2.0 * fxp) / M_LOG2E; + auto upper_bound = (124.0 - 2.0 * fxp - offset) / M_LOG2E; + + xt::xarray x = xt::linspace(lower_bound, upper_bound, 4000); + + Value a = test::makeValue(&ctx, x, VIS_SECRET); + Value c = detail::exp_prime(&ctx, a); + auto y = dump_public_as(&ctx, reveal(&ctx, c)); + EXPECT_TRUE(xt::allclose(xt::exp(x), y, 0.01, 0.001)) + << xt::exp(x) << std::endl + << y; + }); +} + TEST(FxpTest, Log) { // GIVEN SPUContext ctx = test::makeSPUContext(); - xt::xarray x = {{0.05, 0.5}, {5, 50}}; // public log { diff --git a/libspu/kernel/hal/ring.cc b/libspu/kernel/hal/ring.cc index 6844a1b8..725fd498 100644 --- a/libspu/kernel/hal/ring.cc +++ b/libspu/kernel/hal/ring.cc @@ -472,14 +472,25 @@ Value _mux(SPUContext* ctx, const Value& pred, const Value& a, const Value& b) { Value _clamp(SPUContext* ctx, const Value& x, const Value& minv, const Value& maxv) { SPU_TRACE_HAL_LEAF(ctx, x, minv, maxv); - // clamp lower bound, res = x < minv ? minv : x auto res = _mux(ctx, _less(ctx, x, minv), minv, x); - // clamp upper bound, res = res < maxv ? res, maxv return _mux(ctx, _less(ctx, res, maxv), res, maxv); } +// TODO: refactor polymorphic, and may use select functions in polymorphic +Value _clamp_lower(SPUContext* ctx, const Value& x, const Value& minv) { + SPU_TRACE_HAL_LEAF(ctx, x, minv); + // clamp lower bound, res = x < minv ? minv : x + return _mux(ctx, _less(ctx, x, minv), minv, x); +} + +Value _clamp_upper(SPUContext* ctx, const Value& x, const Value& maxv) { + SPU_TRACE_HAL_LEAF(ctx, x, maxv); + // clamp upper bound, x = x < maxv ? 
x, maxv + return _mux(ctx, _less(ctx, x, maxv), x, maxv); +} + Value _constant(SPUContext* ctx, uint128_t init, const Shape& shape) { return _make_p(ctx, init, shape); } diff --git a/libspu/kernel/hal/ring.h b/libspu/kernel/hal/ring.h index 0dd7234a..f0bbb01b 100644 --- a/libspu/kernel/hal/ring.h +++ b/libspu/kernel/hal/ring.h @@ -88,6 +88,11 @@ Value _mux(SPUContext* ctx, const Value& pred, const Value& a, const Value& b); // TODO: test me Value _clamp(SPUContext* ctx, const Value& x, const Value& minv, const Value& maxv); + +Value _clamp_lower(SPUContext* ctx, const Value& x, const Value& minv); + +Value _clamp_upper(SPUContext* ctx, const Value& x, const Value& maxv); + // Make a public value from uint128_t init value. // // If the current working field has less than 128bit, the lower sizeof(field) diff --git a/libspu/kernel/hal/shape_ops.cc b/libspu/kernel/hal/shape_ops.cc index 6337719f..ffcfed60 100644 --- a/libspu/kernel/hal/shape_ops.cc +++ b/libspu/kernel/hal/shape_ops.cc @@ -43,10 +43,10 @@ Value update_slice(SPUContext* ctx, const Value& in, const Value& update, SPU_TRACE_HAL_DISP(ctx, in, start_indices); if (in.storage_type() != update.storage_type()) { - auto u = - _cast_type(ctx, update, in.storage_type()).setDtype(update.dtype()); - - return update_slice(ctx, in, u, start_indices); + auto ct = _common_type(ctx, update.storage_type(), in.storage_type()); + auto u = _cast_type(ctx, update, ct).setDtype(update.dtype()); + auto i = _cast_type(ctx, in, ct).setDtype(in.dtype()); + return update_slice(ctx, i, u, start_indices); } return _update_slice(ctx, in, update, start_indices).setDtype(in.dtype()); diff --git a/libspu/kernel/hal/type_cast.cc b/libspu/kernel/hal/type_cast.cc index 9adb77a8..89ad0142 100644 --- a/libspu/kernel/hal/type_cast.cc +++ b/libspu/kernel/hal/type_cast.cc @@ -78,6 +78,12 @@ Value reveal(SPUContext* ctx, const Value& x) { return _s2p(ctx, x).setDtype(x.dtype()); } +Value reveal_to(SPUContext* ctx, const Value& x, size_t rank) { + 
SPU_TRACE_HAL_LEAF(ctx, x, rank); + SPU_ENFORCE(x.isSecret()); + return _s2v(ctx, x, rank).setDtype(x.dtype()); +} + Value dtype_cast(SPUContext* ctx, const Value& in, DataType to_type) { SPU_TRACE_HAL_DISP(ctx, in, to_type); diff --git a/libspu/kernel/hal/type_cast.h b/libspu/kernel/hal/type_cast.h index cfcc05fb..43ee72bb 100644 --- a/libspu/kernel/hal/type_cast.h +++ b/libspu/kernel/hal/type_cast.h @@ -35,4 +35,9 @@ Value seal(SPUContext* ctx, const Value& x); // @param in, the input value Value reveal(SPUContext* ctx, const Value& x); +/// reveal a secret to a specific party +// @param in, the input value +// @param rank, the rank of the party to reveal to +Value reveal_to(SPUContext* ctx, const Value& x, size_t rank); + } // namespace spu::kernel::hal diff --git a/libspu/kernel/hlo/BUILD.bazel b/libspu/kernel/hlo/BUILD.bazel index e798fa5e..d9ee64ca 100644 --- a/libspu/kernel/hlo/BUILD.bazel +++ b/libspu/kernel/hlo/BUILD.bazel @@ -110,6 +110,7 @@ spu_cc_test( ":casting", ":const", "//libspu/kernel:test_util", + "//libspu/mpc/utils:simulate", ], ) diff --git a/libspu/kernel/hlo/casting.cc b/libspu/kernel/hlo/casting.cc index 1b87c653..b3c0993c 100644 --- a/libspu/kernel/hlo/casting.cc +++ b/libspu/kernel/hlo/casting.cc @@ -52,6 +52,10 @@ spu::Value Reveal(SPUContext *ctx, const spu::Value &in) { return hal::reveal(ctx, in); } +spu::Value RevealTo(SPUContext *ctx, const spu::Value &in, size_t rank) { + return hal::reveal_to(ctx, in, rank); +} + spu::Value Seal(SPUContext *ctx, const spu::Value &in) { return hal::seal(ctx, in); } diff --git a/libspu/kernel/hlo/casting.h b/libspu/kernel/hlo/casting.h index 469ba42d..cf5a5be9 100644 --- a/libspu/kernel/hlo/casting.h +++ b/libspu/kernel/hlo/casting.h @@ -29,6 +29,8 @@ spu::Value Bitcast(SPUContext *ctx, const spu::Value &in, DataType dst_dtype); spu::Value Reveal(SPUContext *ctx, const spu::Value &in); +spu::Value RevealTo(SPUContext *ctx, const spu::Value &in, size_t rank); + spu::Value Seal(SPUContext *ctx, const 
spu::Value &in); } // namespace spu::kernel::hlo diff --git a/libspu/kernel/hlo/casting_test.cc b/libspu/kernel/hlo/casting_test.cc index f103a155..6e59b4b3 100644 --- a/libspu/kernel/hlo/casting_test.cc +++ b/libspu/kernel/hlo/casting_test.cc @@ -20,23 +20,48 @@ #include "libspu/core/value.h" #include "libspu/kernel/hlo/const.h" #include "libspu/kernel/test_util.h" +#include "libspu/mpc/utils/simulate.h" namespace spu::kernel::hlo { -TEST(ConstTest, Empty) { - SPUContext sctx = test::makeSPUContext(); +class CastingTest + : public ::testing::TestWithParam> {}; - auto empty_c = Constant(&sctx, true, {0}); +TEST_P(CastingTest, Empty) { + FieldType field = std::get<0>(GetParam()); + ProtocolKind prot = std::get<1>(GetParam()); - // Seal - auto s_empty = Seal(&sctx, empty_c); + mpc::utils::simulate( + 3, [&](const std::shared_ptr &lctx) { + SPUContext sctx = test::makeSPUContext(prot, field, lctx); + auto empty_c = Constant(&sctx, true, {0}); - // Reveal - auto p_empty = Reveal(&sctx, s_empty); + // Seal + auto s_empty = Seal(&sctx, empty_c); - EXPECT_EQ(p_empty.numel(), 0); - EXPECT_EQ(p_empty.shape().size(), 1); - EXPECT_EQ(p_empty.shape()[0], 0); + // Reveal + auto p_empty = Reveal(&sctx, s_empty); + + // RevealTo + auto v_empty = RevealTo(&sctx, s_empty, 0); + + EXPECT_EQ(p_empty.numel(), 0); + EXPECT_EQ(p_empty.shape().size(), 1); + EXPECT_EQ(p_empty.shape()[0], 0); + + EXPECT_EQ(v_empty.numel(), 0); + EXPECT_EQ(v_empty.shape().size(), 1); + EXPECT_EQ(v_empty.shape()[0], 0); + }); } +INSTANTIATE_TEST_SUITE_P( + CastingTestInstances, CastingTest, + testing::Combine(testing::Values(FieldType::FM64, FieldType::FM128), + testing::Values(ProtocolKind::REF2K, ProtocolKind::SEMI2K, + ProtocolKind::ABY3)), + [](const testing::TestParamInfo &p) { + return fmt::format("{}x{}", std::get<0>(p.param), std::get<1>(p.param)); + }); + } // namespace spu::kernel::hlo diff --git a/libspu/mpc/ab_api_test.cc b/libspu/mpc/ab_api_test.cc index cb36e44f..80ffedee 100644 --- 
a/libspu/mpc/ab_api_test.cc +++ b/libspu/mpc/ab_api_test.cc @@ -203,12 +203,11 @@ TEST_P(ArithmeticTest, MulA1B) { auto p1 = rand_p(obj.get(), conf.protocol() == ProtocolKind::CHEETAH ? Shape({200, 26}) : kShape); - p1 = rshift_p(obj.get(), p1, {K - 1}); auto a0 = p2a(obj.get(), p0); auto a1 = p2b(obj.get(), p1); // hint runtime this is a 1bit value. - a1 = lshift_b(obj.get(), a1, {K - 1}); - a1 = rshift_b(obj.get(), a1, {K - 1}); + // Sometimes, the underlying value is not strictly 1bit + a1.storage_type().as()->setNbits(1); /* WHEN */ auto prev = obj->prot()->getState()->getStats(); @@ -216,7 +215,9 @@ TEST_P(ArithmeticTest, MulA1B) { auto cost = obj->prot()->getState()->getStats() - prev; auto r_aa = a2p(obj.get(), tmp); - auto r_pp = mul_pp(obj.get(), p0, p1); + auto r_pp = + mul_pp(obj.get(), p0, + rshift_p(obj.get(), lshift_p(obj.get(), p1, {K - 1}), {K - 1})); /* THEN */ EXPECT_VALUE_EQ(r_aa, r_pp); diff --git a/libspu/mpc/api.cc b/libspu/mpc/api.cc index ca5842bb..11aa34fd 100644 --- a/libspu/mpc/api.cc +++ b/libspu/mpc/api.cc @@ -173,7 +173,11 @@ Value s2v(SPUContext* ctx, const Value& x, size_t owner) { return a2v(ctx, x, owner); } else { SPU_ENFORCE(IsB(x)); - return b2v(ctx, x, owner); + if (ctx->hasKernel("b2v")) { + return b2v(ctx, x, owner); + } else { + return a2v(ctx, _2a(ctx, x), owner); + } } } diff --git a/libspu/mpc/cheetah/nonlinear/compare_prot_test.cc b/libspu/mpc/cheetah/nonlinear/compare_prot_test.cc index 26f9fbd1..5abf60dc 100644 --- a/libspu/mpc/cheetah/nonlinear/compare_prot_test.cc +++ b/libspu/mpc/cheetah/nonlinear/compare_prot_test.cc @@ -57,7 +57,11 @@ TEST_P(CompareProtTest, Compare) { xinp = NdArrayView(inp[1]); xinp[0] = 1; xinp[1] = 9; - xinp[2] = 1000; + if constexpr (std::is_same_v) { + xinp[2] = 100; + } else { + xinp[2] = 1000; + } }); NdArrayRef cmp_oup[2]; @@ -108,7 +112,11 @@ TEST_P(CompareProtTest, CompareBitWidth) { xinp = NdArrayView(inp[1]); xinp[0] = 1; xinp[1] = 9; - xinp[2] = 1000; + if constexpr (std::is_same_v) 
{ + xinp[2] = 100; + } else { + xinp[2] = 1000; + } pforeach(0, inp[0].numel(), [&](int64_t i) { xinp[i] &= mask; }); }); @@ -178,7 +186,11 @@ TEST_P(CompareProtTest, WithEq) { xinp = NdArrayView(inp[1]); xinp[0] = 1; xinp[1] = 9; - xinp[2] = 1000; + if constexpr (std::is_same_v) { + xinp[2] = 100; + } else { + xinp[2] = 1000; + } }); NdArrayRef cmp_oup[2]; @@ -237,7 +249,11 @@ TEST_P(CompareProtTest, WithEqBitWidth) { xinp = NdArrayView(inp[1]); xinp[0] = 1; xinp[1] = 9; - xinp[2] = 1000; + if constexpr (std::is_same_v) { + xinp[2] = 100; + } else { + xinp[2] = 1000; + } pforeach(0, inp[0].numel(), [&](int64_t i) { xinp[i] &= mask; }); }); diff --git a/libspu/mpc/common/BUILD.bazel b/libspu/mpc/common/BUILD.bazel index 49c5bb54..76511955 100644 --- a/libspu/mpc/common/BUILD.bazel +++ b/libspu/mpc/common/BUILD.bazel @@ -64,6 +64,7 @@ spu_cc_library( hdrs = ["communicator.h"], deps = [ "//libspu/core:object", + "//libspu/mpc/utils:gfmp_ops", "//libspu/mpc/utils:ring_ops", "@yacl//yacl/link:context", "@yacl//yacl/link/algorithm:allgather", @@ -110,6 +111,7 @@ spu_cc_library( hdrs = ["prg_tensor.h"], deps = [ "//libspu/core:ndarray_ref", + "//libspu/mpc/utils:gfmp_ops", "//libspu/mpc/utils:ring_ops", "@yacl//yacl/crypto/tools:prg", ], diff --git a/libspu/mpc/common/communicator.cc b/libspu/mpc/common/communicator.cc index c7cbc427..b7dc4089 100644 --- a/libspu/mpc/common/communicator.cc +++ b/libspu/mpc/common/communicator.cc @@ -14,6 +14,7 @@ #include "libspu/mpc/common/communicator.h" +#include "libspu/mpc/utils/gfmp_ops.h" #include "libspu/mpc/utils/ring_ops.h" namespace spu::mpc { @@ -53,7 +54,11 @@ NdArrayRef Communicator::allReduce(ReduceOp op, const NdArrayRef& in, auto arr = NdArrayRef(stealBuffer(std::move(bufs[idx])), in.eltype(), in.shape(), makeCompactStrides(in.shape()), kOffset); if (op == ReduceOp::ADD) { - ring_add_(res, arr); + if (in.eltype().isa()) { + gfmp_add_mod_(res, arr); + } else { + ring_add_(res, arr); + } } else if (op == ReduceOp::XOR) { 
ring_xor_(res, arr); } else { @@ -86,7 +91,11 @@ NdArrayRef Communicator::reduce(ReduceOp op, const NdArrayRef& in, size_t root, NdArrayRef(stealBuffer(std::move(bufs[idx])), in.eltype(), in.shape(), makeCompactStrides(in.shape()), kOffset); if (op == ReduceOp::ADD) { - ring_add_(res, arr); + if (in.eltype().isa()) { + gfmp_add_mod_(res, arr); + } else { + ring_add_(res, arr); + } } else if (op == ReduceOp::XOR) { ring_xor_(res, arr); } else { @@ -94,7 +103,6 @@ NdArrayRef Communicator::reduce(ReduceOp op, const NdArrayRef& in, size_t root, } } } - stats_.latency += 1; stats_.comm += in.numel() * in.elsize(); diff --git a/libspu/mpc/common/prg_tensor.h b/libspu/mpc/common/prg_tensor.h index 54de9171..d3b3e704 100644 --- a/libspu/mpc/common/prg_tensor.h +++ b/libspu/mpc/common/prg_tensor.h @@ -15,6 +15,7 @@ #pragma once #include "libspu/core/ndarray_ref.h" +#include "libspu/mpc/utils/gfmp_ops.h" #include "libspu/mpc/utils/ring_ops.h" namespace spu::mpc { @@ -22,28 +23,46 @@ namespace spu::mpc { using PrgSeed = uint128_t; using PrgCounter = uint64_t; +// Gfmp is regarded as word +// standing for Galois Field with Mersenne Prime. 
+enum class ElementType { kRing, kGfmp }; + struct PrgArrayDesc { Shape shape; FieldType field; PrgCounter prg_counter; + ElementType eltype; }; inline NdArrayRef prgCreateArray(FieldType field, const Shape& shape, PrgSeed seed, PrgCounter* counter, - PrgArrayDesc* desc) { + PrgArrayDesc* desc, + ElementType eltype = ElementType::kRing) { if (desc != nullptr) { - *desc = {Shape(shape.begin(), shape.end()), field, *counter}; + *desc = {Shape(shape.begin(), shape.end()), field, *counter, eltype}; + } + if (eltype == ElementType::kGfmp) { + return gfmp_rand(field, shape, seed, counter); + } else { + return ring_rand(field, shape, seed, counter); } - return ring_rand(field, shape, seed, counter); } inline NdArrayRef prgReplayArray(PrgSeed seed, const PrgArrayDesc& desc) { PrgCounter counter = desc.prg_counter; - return ring_rand(desc.field, desc.shape, seed, &counter); + if (desc.eltype == ElementType::kGfmp) { + return gfmp_rand(desc.field, desc.shape, seed, &counter); + } else { + return ring_rand(desc.field, desc.shape, seed, &counter); + } } inline NdArrayRef prgReplayArrayMutable(PrgSeed seed, PrgArrayDesc& desc) { - return ring_rand(desc.field, desc.shape, seed, &desc.prg_counter); + if (desc.eltype == ElementType::kGfmp) { + return gfmp_rand(desc.field, desc.shape, seed, &desc.prg_counter); + } else { + return ring_rand(desc.field, desc.shape, seed, &desc.prg_counter); + } } } // namespace spu::mpc diff --git a/libspu/mpc/semi2k/BUILD.bazel b/libspu/mpc/semi2k/BUILD.bazel index e845ee73..dfef1ec7 100644 --- a/libspu/mpc/semi2k/BUILD.bazel +++ b/libspu/mpc/semi2k/BUILD.bazel @@ -46,6 +46,34 @@ spu_cc_library( ], ) +spu_cc_library( + name = "prime_utils", + srcs = ["prime_utils.cc"], + hdrs = ["prime_utils.h"], + deps = [ + ":state", + ":type", + "//libspu/mpc:kernel", + "//libspu/mpc/common:communicator", + "//libspu/mpc/utils:gfmp", + "//libspu/mpc/utils:ring_ops", + ], +) + +spu_cc_library( + name = "exp", + srcs = ["exp.cc"], + hdrs = ["exp.h"], + deps = [ + 
":prime_utils", + ":state", + ":type", + "//libspu/mpc:kernel", + "//libspu/mpc/utils:gfmp", + "//libspu/mpc/utils:ring_ops", + ], +) + spu_cc_library( name = "conversion", srcs = ["conversion.cc"], @@ -83,6 +111,7 @@ spu_cc_library( ":arithmetic", ":boolean", ":conversion", + ":exp", ":permute", ":state", "//libspu/mpc/common:prg_state", @@ -94,7 +123,10 @@ spu_cc_test( name = "protocol_test", srcs = ["protocol_test.cc"], deps = [ + ":exp", + ":prime_utils", ":protocol", + ":type", "//libspu/mpc:ab_api_test", "//libspu/mpc:api_test", "//libspu/mpc/semi2k/beaver/beaver_impl/ttp_server:beaver_server", diff --git a/libspu/mpc/semi2k/arithmetic.cc b/libspu/mpc/semi2k/arithmetic.cc index 7fa9473c..25c31934 100644 --- a/libspu/mpc/semi2k/arithmetic.cc +++ b/libspu/mpc/semi2k/arithmetic.cc @@ -267,13 +267,17 @@ NdArrayRef MulAA::proc(KernelEvalContext* ctx, const NdArrayRef& x, auto [a, b, c, x_a, y_b] = MulOpen(ctx, x, y, false); // Zi = Ci + (X - A) * Bi + (Y - B) * Ai + <(X - A) * (Y - B)> - auto z = ring_add( - ring_add(ring_mul(std::move(b), x_a), ring_mul(std::move(a), y_b)), c); + ring_mul_(b, x_a); + ring_mul_(a, y_b); + ring_add_(b, a); + ring_add_(b, c); + if (comm->getRank() == 0) { // z += (X-A) * (Y-B); - ring_add_(z, ring_mul(std::move(x_a), y_b)); + ring_mul_(x_a, y_b); + ring_add_(b, x_a); } - return z.as(x.eltype()); + return b.as(x.eltype()); } NdArrayRef SquareA::proc(KernelEvalContext* ctx, const NdArrayRef& x) const { @@ -318,6 +322,51 @@ NdArrayRef SquareA::proc(KernelEvalContext* ctx, const NdArrayRef& x) const { return z.as(x.eltype()); } +// Let x be AShrTy, y be BShrTy, nbits(y) == 1 +// (x0+x1) * (y0^y1) = (x0+x1) * (y0+y1-2y0y1) +// we define xx0 = (1-2y0)x0, xx1 = (1-2y1)x1 +// yy0 = y0, yy1 = y1 +// if we can compute z0+z1 = xx0*yy1 + xx1*yy0 (which can be easily got from Mul +// Beaver), then (x0+x1) * (y0^y1) = (z0 + z1) + (x0y0 + x1y1) +NdArrayRef MulA1B::proc(KernelEvalContext* ctx, const NdArrayRef& x, + const NdArrayRef& y) const { + 
SPU_ENFORCE(x.eltype().as()->field() == + y.eltype().as()->field()); + + const auto field = x.eltype().as()->field(); + auto* comm = ctx->getState(); + + // IMPORTANT: the underlying value of y is not exactly 0 or 1, so we must mask + // it explicitly. + auto yy = ring_bitmask(y, 0, 1).as(makeType(field)); + // To optimize memory usage, re-use xx buffer + auto xx = ring_ones(field, x.shape()); + ring_sub_(xx, ring_lshift(yy, {1})); + ring_mul_(xx, x); + + auto [a, b, c, xx_a, yy_b] = MulOpen(ctx, xx, yy, false); + + // Zi = Ci + (XX - A) * Bi + (YY - B) * Ai + <(XX - A) * (YY - B)> - XXi * YYi + // We re-use b to compute z + ring_mul_(b, xx_a); + ring_mul_(a, yy_b); + ring_add_(b, a); + ring_add_(b, c); + + ring_mul_(xx, yy); + ring_sub_(b, xx); + if (comm->getRank() == 0) { + // z += (XX-A) * (YY-B); + ring_mul_(xx_a, yy_b); + ring_add_(b, xx_a); + } + + // zi += xi * yi + ring_add_(b, ring_mul(x, yy)); + + return b.as(x.eltype()); +} + //////////////////////////////////////////////////////////////////// // matmul family //////////////////////////////////////////////////////////////////// diff --git a/libspu/mpc/semi2k/arithmetic.h b/libspu/mpc/semi2k/arithmetic.h index 8225e74f..8b887ed5 100644 --- a/libspu/mpc/semi2k/arithmetic.h +++ b/libspu/mpc/semi2k/arithmetic.h @@ -162,6 +162,22 @@ class SquareA : public UnaryKernel { NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& x) const override; }; +// Note: only for 2PC. 
+class MulA1B : public BinaryKernel { + public: + static constexpr const char* kBindName() { return "mul_a1b"; } + + ce::CExpr latency() const override { + // TODO: consider beaver + return ce::Const(1); + } + + ce::CExpr comm() const override { return ce::K() * 2; } + + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& x, + const NdArrayRef& y) const override; +}; + //////////////////////////////////////////////////////////////////// // matmul family //////////////////////////////////////////////////////////////////// diff --git a/libspu/mpc/semi2k/beaver/beaver_cache.h b/libspu/mpc/semi2k/beaver/beaver_cache.h index 3c9d9ade..5de345ff 100644 --- a/libspu/mpc/semi2k/beaver/beaver_cache.h +++ b/libspu/mpc/semi2k/beaver/beaver_cache.h @@ -32,9 +32,11 @@ namespace spu::mpc::semi2k { class BeaverCache { public: + // clang-format off BeaverCache() : cache_db_(fmt::format("BeaverCache.{}.{}.{}", getpid(), fmt::ptr(this), std::random_device()())) {}; + // clang-format on ~BeaverCache() { db_.reset(); try { diff --git a/libspu/mpc/semi2k/beaver/beaver_impl/BUILD.bazel b/libspu/mpc/semi2k/beaver/beaver_impl/BUILD.bazel index 05e589f3..5f0bd1e1 100644 --- a/libspu/mpc/semi2k/beaver/beaver_impl/BUILD.bazel +++ b/libspu/mpc/semi2k/beaver/beaver_impl/BUILD.bazel @@ -25,6 +25,7 @@ spu_cc_library( "//libspu/mpc/semi2k/beaver:beaver_interface", "//libspu/mpc/semi2k/beaver/beaver_impl/trusted_party", "//libspu/mpc/semi2k/beaver/beaver_impl/ttp_server:beaver_stream", + "//libspu/mpc/utils:gfmp_ops", "//libspu/mpc/utils:ring_ops", "@com_github_microsoft_seal//:seal", "@yacl//yacl/link", @@ -40,6 +41,7 @@ spu_cc_test( ":beaver_ttp", "//libspu/core:xt_helper", "//libspu/mpc/semi2k/beaver/beaver_impl/ttp_server:beaver_server", + "//libspu/mpc/utils:gfmp", "//libspu/mpc/utils:permute", "//libspu/mpc/utils:simulate", "@com_google_googletest//:gtest", @@ -55,6 +57,7 @@ spu_cc_library( "//libspu/mpc/semi2k/beaver:beaver_interface", 
"//libspu/mpc/semi2k/beaver/beaver_impl/ttp_server:beaver_stream", "//libspu/mpc/semi2k/beaver/beaver_impl/ttp_server:service_cc_proto", + "//libspu/mpc/utils:gfmp_ops", "//libspu/mpc/utils:ring_ops", "@yacl//yacl/crypto/pke:sm2_enc", "@yacl//yacl/link", diff --git a/libspu/mpc/semi2k/beaver/beaver_impl/beaver_test.cc b/libspu/mpc/semi2k/beaver/beaver_impl/beaver_test.cc index 78cc71d2..300a2f6e 100644 --- a/libspu/mpc/semi2k/beaver/beaver_impl/beaver_test.cc +++ b/libspu/mpc/semi2k/beaver/beaver_impl/beaver_test.cc @@ -24,6 +24,7 @@ #include "libspu/mpc/semi2k/beaver/beaver_impl/beaver_tfp.h" #include "libspu/mpc/semi2k/beaver/beaver_impl/beaver_ttp.h" #include "libspu/mpc/semi2k/beaver/beaver_impl/ttp_server/beaver_server.h" +#include "libspu/mpc/utils/gfmp.h" #include "libspu/mpc/utils/permute.h" #include "libspu/mpc/utils/ring_ops.h" #include "libspu/mpc/utils/simulate.h" @@ -158,6 +159,60 @@ std::vector open_buffer(std::vector& in_buffers, } return ret; } + +template +std::vector open_buffer_gfmp(std::vector& in_buffers, + FieldType k_field, + const std::vector& shapes, + size_t k_world_size, bool add_open) { + std::vector ret; + + auto reduce = [&](NdArrayRef& r, yacl::Buffer& b) { + if (b.size() == 0) { + return; + } + EXPECT_EQ(b.size(), r.shape().numel() * SizeOf(k_field)); + NdArrayRef a(std::make_shared(std::move(b)), ret[0].eltype(), + r.shape()); + auto Ta = r.eltype(); + gfmp_add_mod_(r, a.as(Ta)); + }; + if constexpr (std::is_same_v) { + ret.resize(3); + SPU_ENFORCE(shapes.size() == 3); + for (size_t i = 0; i < shapes.size(); i++) { + ret[i] = gfmp_zeros(k_field, shapes[i]); + } + for (Rank r = 0; r < k_world_size; r++) { + auto& [a_buf, b_buf, c_buf] = in_buffers[r]; + reduce(ret[0], a_buf); + reduce(ret[1], b_buf); + reduce(ret[2], c_buf); + } + } else if constexpr (std::is_same_v) { + ret.resize(2); + SPU_ENFORCE(shapes.size() == 2); + for (size_t i = 0; i < shapes.size(); i++) { + ret[i] = gfmp_zeros(k_field, shapes[i]); + } + for (Rank r = 0; r 
< k_world_size; r++) { + auto& [a_buf, b_buf] = in_buffers[r]; + reduce(ret[0], a_buf); + reduce(ret[1], b_buf); + } + } else if constexpr (std::is_same_v) { + ret.resize(1); + SPU_ENFORCE(shapes.size() == 1); + for (size_t i = 0; i < shapes.size(); i++) { + ret[i] = gfmp_zeros(k_field, shapes[i]); + } + for (Rank r = 0; r < k_world_size; r++) { + auto& a_buf = in_buffers[r]; + reduce(ret[0], a_buf); + } + } + return ret; +} } // namespace TEST_P(BeaverTest, Mul_large) { @@ -215,11 +270,11 @@ TEST_P(BeaverTest, Mul_large) { DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _a(open[0]); - NdArrayView _cache_a(x_cache); + NdArrayView _a_cache(x_cache); NdArrayView _b(open[1]); NdArrayView _c(open[2]); for (auto idx = 0; idx < _a.numel(); idx++) { - EXPECT_EQ(_cache_a[idx], _a[idx]); + EXPECT_EQ(_a_cache[idx], _a[idx]); auto t = _a[idx] * _b[idx]; auto err = t > _c[idx] ? t - _c[idx] : _c[idx] - t; EXPECT_LE(err, kMaxDiff); @@ -242,10 +297,10 @@ TEST_P(BeaverTest, Mul_large) { DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _a(open[0]); NdArrayView _b(open[1]); - NdArrayView _cache_b(y_cache); + NdArrayView _b_cache(y_cache); NdArrayView _c(open[2]); for (auto idx = 0; idx < _a.numel(); idx++) { - EXPECT_EQ(_cache_b[idx], _b[idx]); + EXPECT_EQ(_b_cache[idx], _b[idx]); auto t = _a[idx] * _b[idx]; auto err = t > _c[idx] ? t - _c[idx] : _c[idx] - t; EXPECT_LE(err, kMaxDiff); @@ -269,12 +324,12 @@ TEST_P(BeaverTest, Mul_large) { DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _a(open[0]); NdArrayView _b(open[1]); - NdArrayView _cache_a(x_cache); - NdArrayView _cache_b(y_cache); + NdArrayView _a_cache(x_cache); + NdArrayView _b_cache(y_cache); NdArrayView _c(open[2]); for (auto idx = 0; idx < _a.numel(); idx++) { - EXPECT_EQ(_cache_a[idx], _a[idx]); - EXPECT_EQ(_cache_b[idx], _b[idx]); + EXPECT_EQ(_a_cache[idx], _a[idx]); + EXPECT_EQ(_b_cache[idx], _b[idx]); auto t = _a[idx] * _b[idx]; auto err = t > _c[idx] ? 
t - _c[idx] : _c[idx] - t; EXPECT_LE(err, kMaxDiff); @@ -299,14 +354,14 @@ TEST_P(BeaverTest, Mul_large) { DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _a(open[0]); NdArrayView _b(open[1]); - NdArrayView _cache_a(x_cache); - NdArrayView _cache_b(y_cache); + NdArrayView _a_cache(x_cache); + NdArrayView _b_cache(y_cache); NdArrayView _c(open[2]); for (auto idx = 0; idx < _a.numel(); idx++) { // mul not support transpose. // enforce ne - EXPECT_NE(_cache_a[idx], _a[idx]); - EXPECT_NE(_cache_b[idx], _b[idx]); + EXPECT_NE(_a_cache[idx], _a[idx]); + EXPECT_NE(_b_cache[idx], _b[idx]); auto t = _a[idx] * _b[idx]; auto err = t > _c[idx] ? t - _c[idx] : _c[idx] - t; EXPECT_LE(err, kMaxDiff); @@ -370,11 +425,11 @@ TEST_P(BeaverTest, Mul) { DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _a(open[0]); - NdArrayView _cache_a(x_cache); + NdArrayView _a_cache(x_cache); NdArrayView _b(open[1]); NdArrayView _c(open[2]); for (auto idx = 0; idx < _a.numel(); idx++) { - EXPECT_EQ(_cache_a[idx], _a[idx]); + EXPECT_EQ(_a_cache[idx], _a[idx]); auto t = _a[idx] * _b[idx]; auto err = t > _c[idx] ? t - _c[idx] : _c[idx] - t; EXPECT_LE(err, kMaxDiff); @@ -397,10 +452,10 @@ TEST_P(BeaverTest, Mul) { DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _a(open[0]); NdArrayView _b(open[1]); - NdArrayView _cache_b(y_cache); + NdArrayView _b_cache(y_cache); NdArrayView _c(open[2]); for (auto idx = 0; idx < _a.numel(); idx++) { - EXPECT_EQ(_cache_b[idx], _b[idx]); + EXPECT_EQ(_b_cache[idx], _b[idx]); auto t = _a[idx] * _b[idx]; auto err = t > _c[idx] ? 
t - _c[idx] : _c[idx] - t; EXPECT_LE(err, kMaxDiff); @@ -424,12 +479,12 @@ TEST_P(BeaverTest, Mul) { DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _a(open[0]); NdArrayView _b(open[1]); - NdArrayView _cache_a(x_cache); - NdArrayView _cache_b(y_cache); + NdArrayView _a_cache(x_cache); + NdArrayView _b_cache(y_cache); NdArrayView _c(open[2]); for (auto idx = 0; idx < _a.numel(); idx++) { - EXPECT_EQ(_cache_a[idx], _a[idx]); - EXPECT_EQ(_cache_b[idx], _b[idx]); + EXPECT_EQ(_a_cache[idx], _a[idx]); + EXPECT_EQ(_b_cache[idx], _b[idx]); auto t = _a[idx] * _b[idx]; auto err = t > _c[idx] ? t - _c[idx] : _c[idx] - t; EXPECT_LE(err, kMaxDiff); @@ -454,14 +509,14 @@ TEST_P(BeaverTest, Mul) { DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _a(open[0]); NdArrayView _b(open[1]); - NdArrayView _cache_a(x_cache); - NdArrayView _cache_b(y_cache); + NdArrayView _a_cache(x_cache); + NdArrayView _b_cache(y_cache); NdArrayView _c(open[2]); for (auto idx = 0; idx < _a.numel(); idx++) { // mul not support transpose. // enforce ne - EXPECT_NE(_cache_a[idx], _a[idx]); - EXPECT_NE(_cache_b[idx], _b[idx]); + EXPECT_NE(_a_cache[idx], _a[idx]); + EXPECT_NE(_b_cache[idx], _b[idx]); auto t = _a[idx] * _b[idx]; auto err = t > _c[idx] ? 
t - _c[idx] : _c[idx] - t; EXPECT_LE(err, kMaxDiff); @@ -470,6 +525,176 @@ TEST_P(BeaverTest, Mul) { } } +TEST_P(BeaverTest, MulGfmp) { + const auto factory = std::get<0>(GetParam()).first; + const size_t kWorldSize = std::get<1>(GetParam()); + const FieldType kField = std::get<2>(GetParam()); + const int64_t kMaxDiff = std::get<3>(GetParam()); + const size_t adjust_rank = std::get<4>(GetParam()); + const int64_t kNumel = 7; + + std::vector triples(kWorldSize); + + std::vector x_desc(kWorldSize); + std::vector y_desc(kWorldSize); + NdArrayRef x_cache; + NdArrayRef y_cache; + { + utils::simulate( + kWorldSize, [&](const std::shared_ptr& lctx) { + auto beaver = factory(lctx, ttp_options_, adjust_rank); + triples[lctx->Rank()] = + beaver->Mul(kField, kNumel, &x_desc[lctx->Rank()], + &y_desc[lctx->Rank()], ElementType::kGfmp); + yacl::link::Barrier(lctx, "BeaverUT"); + }); + + auto open = open_buffer_gfmp( + triples, kField, std::vector(3, {kNumel}), kWorldSize, true); + + DISPATCH_ALL_FIELDS(kField, [&]() { + NdArrayView _a(open[0]); + NdArrayView _b(open[1]); + NdArrayView _c(open[2]); + for (auto idx = 0; idx < _a.numel(); idx++) { + auto prime = ScalarTypeToPrime::prime; + auto t = mul_mod(_a[idx], _b[idx]); + auto err = t > _c[idx] ? 
t - _c[idx] : _c[idx] - t; + auto error_mod_p = static_cast(err) % prime; + EXPECT_LE(error_mod_p, kMaxDiff); + } + }); + + x_cache = open[0]; + y_cache = open[1]; + } + { + utils::simulate(kWorldSize, + [&](const std::shared_ptr& lctx) { + auto beaver = factory(lctx, ttp_options_, adjust_rank); + x_desc[lctx->Rank()].status = Beaver::Replay; + triples[lctx->Rank()] = + beaver->Mul(kField, kNumel, &x_desc[lctx->Rank()], + nullptr, ElementType::kGfmp); + yacl::link::Barrier(lctx, "BeaverUT"); + }); + + auto open = open_buffer_gfmp( + triples, kField, std::vector(3, {kNumel}), kWorldSize, true); + + DISPATCH_ALL_FIELDS(kField, [&]() { + NdArrayView _a(open[0]); + NdArrayView _a_cache(x_cache); + NdArrayView _b(open[1]); + NdArrayView _c(open[2]); + for (auto idx = 0; idx < _a.numel(); idx++) { + auto prime = ScalarTypeToPrime::prime; + EXPECT_EQ(_a_cache[idx], _a[idx]); + auto t = mul_mod(_a[idx], _b[idx]) % prime; + auto err = t > _c[idx] ? t - _c[idx] : _c[idx] - t; + auto error_mod_p = static_cast(err) % prime; + EXPECT_LE(error_mod_p, kMaxDiff); + } + }); + } + { + utils::simulate( + kWorldSize, [&](const std::shared_ptr& lctx) { + auto beaver = factory(lctx, ttp_options_, adjust_rank); + y_desc[lctx->Rank()].status = Beaver::Replay; + triples[lctx->Rank()] = + beaver->Mul(kField, kNumel, nullptr, &y_desc[lctx->Rank()], + ElementType::kGfmp); + yacl::link::Barrier(lctx, "BeaverUT"); + }); + + auto open = open_buffer_gfmp( + triples, kField, std::vector(3, {kNumel}), kWorldSize, true); + + DISPATCH_ALL_FIELDS(kField, [&]() { + NdArrayView _a(open[0]); + NdArrayView _b(open[1]); + NdArrayView _b_cache(y_cache); + NdArrayView _c(open[2]); + for (auto idx = 0; idx < _a.numel(); idx++) { + EXPECT_EQ(_b_cache[idx], _b[idx]); + auto t = mul_mod(_a[idx], _b[idx]); + auto err = t > _c[idx] ? 
t - _c[idx] : _c[idx] - t; + auto prime = ScalarTypeToPrime::prime; + auto error_mod_p = static_cast(err) % prime; + EXPECT_LE(error_mod_p, kMaxDiff); + } + }); + } + { + utils::simulate( + kWorldSize, [&](const std::shared_ptr& lctx) { + auto beaver = factory(lctx, ttp_options_, adjust_rank); + x_desc[lctx->Rank()].status = Beaver::Replay; + y_desc[lctx->Rank()].status = Beaver::Replay; + triples[lctx->Rank()] = + beaver->Mul(kField, kNumel, &x_desc[lctx->Rank()], + &y_desc[lctx->Rank()], ElementType::kGfmp); + yacl::link::Barrier(lctx, "BeaverUT"); + }); + + auto open = open_buffer_gfmp( + triples, kField, std::vector(3, {kNumel}), kWorldSize, true); + + DISPATCH_ALL_FIELDS(kField, [&]() { + NdArrayView _a(open[0]); + NdArrayView _b(open[1]); + NdArrayView _a_cache(x_cache); + NdArrayView _b_cache(y_cache); + NdArrayView _c(open[2]); + for (auto idx = 0; idx < _a.numel(); idx++) { + EXPECT_EQ(_a_cache[idx], _a[idx]); + EXPECT_EQ(_b_cache[idx], _b[idx]); + auto t = mul_mod(_a[idx], _b[idx]); + auto err = t > _c[idx] ? t - _c[idx] : _c[idx] - t; + auto prime = ScalarTypeToPrime::prime; + auto error_mod_p = static_cast(err) % prime; + EXPECT_LE(error_mod_p, kMaxDiff); + } + }); + } + { + utils::simulate( + kWorldSize, [&](const std::shared_ptr& lctx) { + auto beaver = factory(lctx, ttp_options_, adjust_rank); + x_desc[lctx->Rank()].status = Beaver::TransposeReplay; + y_desc[lctx->Rank()].status = Beaver::TransposeReplay; + // mul not support transpose. 
+ triples[lctx->Rank()] = + beaver->Mul(kField, kNumel, &x_desc[lctx->Rank()], + &y_desc[lctx->Rank()], ElementType::kGfmp); + yacl::link::Barrier(lctx, "BeaverUT"); + }); + + auto open = open_buffer_gfmp( + triples, kField, std::vector(3, {kNumel}), kWorldSize, true); + + DISPATCH_ALL_FIELDS(kField, [&]() { + NdArrayView _a(open[0]); + NdArrayView _b(open[1]); + NdArrayView _a_cache(x_cache); + NdArrayView _b_cache(y_cache); + NdArrayView _c(open[2]); + for (auto idx = 0; idx < _a.numel(); idx++) { + // mul not support transpose. + // enforce ne + EXPECT_NE(_a_cache[idx], _a[idx]); + EXPECT_NE(_b_cache[idx], _b[idx]); + auto t = mul_mod(_a[idx], _b[idx]); + auto err = t > _c[idx] ? t - _c[idx] : _c[idx] - t; + auto prime = ScalarTypeToPrime::prime; + auto error_mod_p = static_cast(err) % prime; + EXPECT_LE(error_mod_p, kMaxDiff); + } + }); + } +} + TEST_P(BeaverTest, And) { const auto factory = std::get<0>(GetParam()).first; const size_t kWorldSize = std::get<1>(GetParam()); @@ -566,11 +791,11 @@ TEST_P(BeaverTest, Dot) { auto res = ring_mmul(x_cache, open[1]); DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _a(open[0]); - NdArrayView _cache_a(x_cache); + NdArrayView _a_cache(x_cache); NdArrayView _r(res); NdArrayView _c(open[2]); for (auto idx = 0; idx < res.numel(); idx++) { - EXPECT_EQ(_cache_a[idx], _a[idx]); + EXPECT_EQ(_a_cache[idx], _a[idx]); auto err = _r[idx] > _c[idx] ? _r[idx] - _c[idx] : _c[idx] - _r[idx]; EXPECT_LE(err, kMaxDiff); } @@ -593,11 +818,11 @@ TEST_P(BeaverTest, Dot) { auto res = ring_mmul(open[0], y_cache); DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _b(open[1]); - NdArrayView _cache_b(y_cache); + NdArrayView _b_cache(y_cache); NdArrayView _r(res); NdArrayView _c(open[2]); for (auto idx = 0; idx < res.numel(); idx++) { - EXPECT_EQ(_cache_b[idx], _b[idx]); + EXPECT_EQ(_b_cache[idx], _b[idx]); auto err = _r[idx] > _c[idx] ? 
_r[idx] - _c[idx] : _c[idx] - _r[idx]; EXPECT_LE(err, kMaxDiff); } @@ -621,14 +846,14 @@ TEST_P(BeaverTest, Dot) { auto res = ring_mmul(x_cache, y_cache); DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _a(open[0]); - NdArrayView _cache_a(x_cache); + NdArrayView _a_cache(x_cache); NdArrayView _b(open[1]); - NdArrayView _cache_b(y_cache); + NdArrayView _b_cache(y_cache); NdArrayView _r(res); NdArrayView _c(open[2]); for (auto idx = 0; idx < res.numel(); idx++) { - EXPECT_EQ(_cache_a[idx], _a[idx]); - EXPECT_EQ(_cache_b[idx], _b[idx]); + EXPECT_EQ(_a_cache[idx], _a[idx]); + EXPECT_EQ(_b_cache[idx], _b[idx]); auto err = _r[idx] > _c[idx] ? _r[idx] - _c[idx] : _c[idx] - _r[idx]; EXPECT_LE(err, kMaxDiff); } @@ -653,16 +878,16 @@ TEST_P(BeaverTest, Dot) { DISPATCH_ALL_FIELDS(kField, [&]() { auto transpose_a = open[0].transpose(); NdArrayView _a(transpose_a); - NdArrayView _cache_a(y_cache); + NdArrayView _a_cache(y_cache); auto transpose_b = open[1].transpose(); NdArrayView _b(transpose_b); - NdArrayView _cache_b(x_cache); + NdArrayView _b_cache(x_cache); auto transpose_r = res.transpose(); NdArrayView _r(transpose_r); NdArrayView _c(open[2]); for (auto idx = 0; idx < res.numel(); idx++) { - EXPECT_EQ(_cache_a[idx], _a[idx]); - EXPECT_EQ(_cache_b[idx], _b[idx]); + EXPECT_EQ(_a_cache[idx], _a[idx]); + EXPECT_EQ(_b_cache[idx], _b[idx]); auto err = _r[idx] > _c[idx] ? _r[idx] - _c[idx] : _c[idx] - _r[idx]; EXPECT_LE(err, kMaxDiff); } @@ -685,11 +910,11 @@ TEST_P(BeaverTest, Dot) { DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _a(open[0]); - NdArrayView _cache_a(x_cache); + NdArrayView _a_cache(x_cache); NdArrayView _b(open[1]); NdArrayView _c(open[2]); for (auto idx = 0; idx < _a.numel(); idx++) { - EXPECT_EQ(_cache_a[idx], _a[idx]); + EXPECT_EQ(_a_cache[idx], _a[idx]); auto t = _a[idx] * _b[idx]; auto err = t > _c[idx] ? 
t - _c[idx] : _c[idx] - t; EXPECT_LE(err, kMaxDiff); diff --git a/libspu/mpc/semi2k/beaver/beaver_impl/beaver_tfp.cc b/libspu/mpc/semi2k/beaver/beaver_impl/beaver_tfp.cc index 402d1c0d..f876209d 100644 --- a/libspu/mpc/semi2k/beaver/beaver_impl/beaver_tfp.cc +++ b/libspu/mpc/semi2k/beaver/beaver_impl/beaver_tfp.cc @@ -23,6 +23,7 @@ #include "libspu/mpc/common/prg_tensor.h" #include "libspu/mpc/semi2k/beaver/beaver_impl/trusted_party/trusted_party.h" +#include "libspu/mpc/utils/gfmp_ops.h" #include "libspu/mpc/utils/ring_ops.h" namespace spu::mpc::semi2k { @@ -32,9 +33,9 @@ namespace { inline size_t CeilDiv(size_t a, size_t b) { return (a + b - 1) / b; } void FillReplayDesc(Beaver::ReplayDesc* desc, FieldType field, int64_t size, - const std::vector& encrypted_seeds, - PrgCounter counter, PrgSeed self_seed) { + PrgCounter counter, PrgSeed self_seed, + ElementType eltype = ElementType::kRing) { if (desc == nullptr || desc->status != Beaver::Init) { return; } @@ -43,6 +44,7 @@ void FillReplayDesc(Beaver::ReplayDesc* desc, FieldType field, int64_t size, desc->prg_counter = counter; desc->encrypted_seeds = encrypted_seeds; desc->seed = self_seed; + desc->eltype = eltype; } } // namespace @@ -67,7 +69,8 @@ BeaverTfpUnsafe::BeaverTfpUnsafe(std::shared_ptr lctx) BeaverTfpUnsafe::Triple BeaverTfpUnsafe::Mul(FieldType field, int64_t size, ReplayDesc* x_desc, - ReplayDesc* y_desc) { + ReplayDesc* y_desc, + ElementType eltype) { std::vector ops(3); Shape shape({size, 1}); std::vector> replay_seeds(3); @@ -75,9 +78,13 @@ BeaverTfpUnsafe::Triple BeaverTfpUnsafe::Mul(FieldType field, int64_t size, auto if_replay = [&](const ReplayDesc* replay_desc, size_t idx) { if (replay_desc == nullptr || replay_desc->status != Beaver::Replay) { ops[idx].seeds = seeds_; - return prgCreateArray(field, shape, seed_, &counter_, &ops[idx].desc); + // enforce the eltypes in ops + ops[idx].desc.eltype = eltype; + return prgCreateArray(field, shape, seed_, &counter_, &ops[idx].desc, + eltype); } else 
{ SPU_ENFORCE(replay_desc->field == field); + SPU_ENFORCE(replay_desc->eltype == eltype); SPU_ENFORCE(replay_desc->size == size); if (lctx_->Rank() == 0) { SPU_ENFORCE(replay_desc->encrypted_seeds.size() == lctx_->WorldSize()); @@ -90,25 +97,31 @@ BeaverTfpUnsafe::Triple BeaverTfpUnsafe::Mul(FieldType field, int64_t size, } ops[idx].seeds = replay_seeds[idx]; ops[idx].desc.field = field; + ops[idx].desc.eltype = eltype; ops[idx].desc.shape = shape; ops[idx].desc.prg_counter = replay_desc->prg_counter; } PrgCounter tmp_counter = replay_desc->prg_counter; return prgCreateArray(field, shape, replay_desc->seed, &tmp_counter, - nullptr); + nullptr, eltype); } }; - FillReplayDesc(x_desc, field, size, seeds_buff_, counter_, seed_); + FillReplayDesc(x_desc, field, size, seeds_buff_, counter_, seed_, eltype); auto a = if_replay(x_desc, 0); - FillReplayDesc(y_desc, field, size, seeds_buff_, counter_, seed_); + FillReplayDesc(y_desc, field, size, seeds_buff_, counter_, seed_, eltype); auto b = if_replay(y_desc, 1); - auto c = prgCreateArray(field, shape, seed_, &counter_, &ops[2].desc); + auto c = prgCreateArray(field, shape, seed_, &counter_, &ops[2].desc, eltype); if (lctx_->Rank() == 0) { ops[2].seeds = seeds_; auto adjust = TrustedParty::adjustMul(absl::MakeSpan(ops)); - ring_add_(c, adjust); + if (eltype == ElementType::kGfmp) { + auto T = c.eltype(); + gfmp_add_mod_(c, adjust.as(T)); + } else { + ring_add_(c, adjust); + } } Triple ret; @@ -119,6 +132,37 @@ BeaverTfpUnsafe::Triple BeaverTfpUnsafe::Mul(FieldType field, int64_t size, return ret; } +BeaverTfpUnsafe::Pair BeaverTfpUnsafe::MulPriv(FieldType field, int64_t size, + ElementType eltype) { + std::vector ops(2); + Shape shape({size, 1}); + + ops[0].seeds = seeds_; + // enforce the eltypes in ops + ops[0].desc.eltype = eltype; + ops[1].desc.eltype = eltype; + auto a_or_b = + prgCreateArray(field, shape, seed_, &counter_, &ops[0].desc, eltype); + auto c = prgCreateArray(field, shape, seed_, &counter_, &ops[1].desc, 
eltype); + + if (lctx_->Rank() == 0) { + ops[1].seeds = seeds_; + auto adjust = TrustedParty::adjustMulPriv(absl::MakeSpan(ops)); + if (eltype == ElementType::kGfmp) { + auto T = c.eltype(); + gfmp_add_mod_(c, adjust.as(T)); + } else { + ring_add_(c, adjust); + } + } + + Pair ret; + std::get<0>(ret) = std::move(*a_or_b.buf()); + std::get<1>(ret) = std::move(*c.buf()); + + return ret; +} + BeaverTfpUnsafe::Pair BeaverTfpUnsafe::Square(FieldType field, int64_t size, ReplayDesc* x_desc) { std::vector ops(2); diff --git a/libspu/mpc/semi2k/beaver/beaver_impl/beaver_tfp.h b/libspu/mpc/semi2k/beaver/beaver_impl/beaver_tfp.h index 9ca11bca..2f26a716 100644 --- a/libspu/mpc/semi2k/beaver/beaver_impl/beaver_tfp.h +++ b/libspu/mpc/semi2k/beaver/beaver_impl/beaver_tfp.h @@ -45,7 +45,11 @@ class BeaverTfpUnsafe final : public Beaver { explicit BeaverTfpUnsafe(std::shared_ptr lctx); Triple Mul(FieldType field, int64_t size, ReplayDesc* x_desc = nullptr, - ReplayDesc* y_desc = nullptr) override; + ReplayDesc* y_desc = nullptr, + ElementType eltype = ElementType::kRing) override; + + Pair MulPriv(FieldType field, int64_t size, + ElementType eltype = ElementType::kRing) override; Pair Square(FieldType field, int64_t size, ReplayDesc* x_desc = nullptr) override; diff --git a/libspu/mpc/semi2k/beaver/beaver_impl/beaver_ttp.cc b/libspu/mpc/semi2k/beaver/beaver_impl/beaver_ttp.cc index c29fa32b..be2e9e86 100644 --- a/libspu/mpc/semi2k/beaver/beaver_impl/beaver_ttp.cc +++ b/libspu/mpc/semi2k/beaver/beaver_impl/beaver_ttp.cc @@ -24,6 +24,7 @@ #include "libspu/mpc/common/prg_tensor.h" #include "libspu/mpc/semi2k/beaver/beaver_impl/ttp_server/beaver_stream.h" +#include "libspu/mpc/utils/gfmp_ops.h" #include "libspu/mpc/utils/ring_ops.h" namespace brpc { @@ -41,7 +42,8 @@ inline size_t CeilDiv(size_t a, size_t b) { return (a + b - 1) / b; } void FillReplayDesc(Beaver::ReplayDesc* desc, FieldType field, int64_t size, const std::vector& encrypted_seeds, - PrgCounter counter, PrgSeed 
self_seed) { + PrgCounter counter, PrgSeed self_seed, + ElementType eltype = ElementType::kRing) { if (desc == nullptr || desc->status != Beaver::Init) { return; } @@ -50,6 +52,7 @@ void FillReplayDesc(Beaver::ReplayDesc* desc, FieldType field, int64_t size, desc->prg_counter = counter; desc->encrypted_seeds = encrypted_seeds; desc->seed = self_seed; + desc->eltype = eltype; } template @@ -61,11 +64,15 @@ AdjustRequest BuildAdjustRequest( SPU_ENFORCE(!descs.empty()); uint32_t field_size; + ElementType eltype = ElementType::kRing; + for (size_t i = 0; i < descs.size(); i++) { const auto& desc = descs[i]; auto* input = ret.add_prg_inputs(); input->set_prg_count(desc.prg_counter); field_size = SizeOf(desc.field); + eltype = desc.eltype; + input->set_buffer_len(desc.shape.numel() * SizeOf(desc.field)); absl::Span seeds; @@ -83,6 +90,14 @@ AdjustRequest BuildAdjustRequest( beaver::ttp_server::AdjustAndRequest>) { ret.set_field_size(field_size); } + if constexpr (std::is_same_v || + std::is_same_v) { + if (eltype == ElementType::kGfmp) + ret.set_element_type(beaver::ttp_server::ElType::GFMP); + } + return ret; } @@ -223,6 +238,10 @@ std::vector RpcCall( if constexpr (std::is_same_v) { stub.AdjustMul(&cntl, &req, &rsp, nullptr); + } else if constexpr (std::is_same_v< + AdjustRequest, + beaver::ttp_server::AdjustMulPrivRequest>) { + stub.AdjustMulPriv(&cntl, &req, &rsp, nullptr); } else if constexpr (std::is_same_v< AdjustRequest, beaver::ttp_server::AdjustSquareRequest>) { @@ -340,15 +359,18 @@ BeaverTtp::BeaverTtp(std::shared_ptr lctx, Options ops) "BEAVER_TTP:SYNC_ENCRYPTED_SEEDS"); } +// TODO: kGfmp supports more operations BeaverTtp::Triple BeaverTtp::Mul(FieldType field, int64_t size, - ReplayDesc* x_desc, ReplayDesc* y_desc) { + ReplayDesc* x_desc, ReplayDesc* y_desc, + ElementType eltype) { std::vector descs(3); std::vector> descs_seed(3, encrypted_seeds_); Shape shape({size, 1}); auto if_replay = [&](const ReplayDesc* replay_desc, size_t idx) { if (replay_desc == 
nullptr || replay_desc->status != Beaver::Replay) { - return prgCreateArray(field, shape, seed_, &counter_, &descs[idx]); + return prgCreateArray(field, shape, seed_, &counter_, &descs[idx], + eltype); } else { SPU_ENFORCE(replay_desc->field == field); SPU_ENFORCE(replay_desc->size == size); @@ -356,27 +378,35 @@ BeaverTtp::Triple BeaverTtp::Mul(FieldType field, int64_t size, if (lctx_->Rank() == options_.adjust_rank) { descs_seed[idx] = replay_desc->encrypted_seeds; descs[idx].field = field; + descs[idx].eltype = eltype; descs[idx].shape = shape; descs[idx].prg_counter = replay_desc->prg_counter; } PrgCounter tmp_counter = replay_desc->prg_counter; return prgCreateArray(field, shape, replay_desc->seed, &tmp_counter, - &descs[idx]); + &descs[idx], eltype); } }; - FillReplayDesc(x_desc, field, size, encrypted_seeds_, counter_, seed_); + FillReplayDesc(x_desc, field, size, encrypted_seeds_, counter_, seed_, + eltype); auto a = if_replay(x_desc, 0); - FillReplayDesc(y_desc, field, size, encrypted_seeds_, counter_, seed_); + FillReplayDesc(y_desc, field, size, encrypted_seeds_, counter_, seed_, + eltype); auto b = if_replay(y_desc, 1); - auto c = prgCreateArray(field, shape, seed_, &counter_, &descs[2]); + auto c = prgCreateArray(field, shape, seed_, &counter_, &descs[2], eltype); if (lctx_->Rank() == options_.adjust_rank) { auto req = BuildAdjustRequest( descs, descs_seed); auto adjusts = RpcCall(channel_, req, field); SPU_ENFORCE_EQ(adjusts.size(), 1U); - ring_add_(c, adjusts[0].reshape(shape)); + if (eltype == ElementType::kGfmp) { + auto T = c.eltype(); + gfmp_add_mod_(c, adjusts[0].reshape(shape).as(T)); + } else { + ring_add_(c, adjusts[0].reshape(shape)); + } } Triple ret; @@ -387,6 +417,34 @@ BeaverTtp::Triple BeaverTtp::Mul(FieldType field, int64_t size, return ret; } +BeaverTtp::Pair BeaverTtp::MulPriv(FieldType field, int64_t size, + ElementType eltype) { + std::vector descs(2); + std::vector> descs_seed(2, encrypted_seeds_); + Shape shape({size, 1}); + auto 
a_or_b = + prgCreateArray(field, shape, seed_, &counter_, &descs[0], eltype); + auto c = prgCreateArray(field, shape, seed_, &counter_, &descs[1], eltype); + if (lctx_->Rank() == options_.adjust_rank) { + auto req = BuildAdjustRequest( + descs, descs_seed); + auto adjusts = RpcCall(channel_, req, field); + SPU_ENFORCE_EQ(adjusts.size(), 1U); + if (eltype == ElementType::kGfmp) { + auto T = c.eltype(); + gfmp_add_mod_(c, adjusts[0].reshape(shape).as(T)); + } else { + ring_add_(c, adjusts[0].reshape(shape)); + } + } + + Pair ret; + std::get<0>(ret) = std::move(*a_or_b.buf()); + std::get<1>(ret) = std::move(*c.buf()); + + return ret; +} + BeaverTtp::Pair BeaverTtp::Square(FieldType field, int64_t size, ReplayDesc* x_desc) { std::vector descs(2); diff --git a/libspu/mpc/semi2k/beaver/beaver_impl/beaver_ttp.h b/libspu/mpc/semi2k/beaver/beaver_impl/beaver_ttp.h index ecb39237..501d5eac 100644 --- a/libspu/mpc/semi2k/beaver/beaver_impl/beaver_ttp.h +++ b/libspu/mpc/semi2k/beaver/beaver_impl/beaver_ttp.h @@ -66,7 +66,11 @@ class BeaverTtp final : public Beaver { ~BeaverTtp() override = default; Triple Mul(FieldType field, int64_t size, ReplayDesc* x_desc = nullptr, - ReplayDesc* y_desc = nullptr) override; + ReplayDesc* y_desc = nullptr, + ElementType eltype = ElementType::kRing) override; + + Pair MulPriv(FieldType field, int64_t size, + ElementType eltype = ElementType::kRing) override; Pair Square(FieldType field, int64_t size, ReplayDesc* x_desc = nullptr) override; diff --git a/libspu/mpc/semi2k/beaver/beaver_impl/trusted_party/BUILD.bazel b/libspu/mpc/semi2k/beaver/beaver_impl/trusted_party/BUILD.bazel index 5a503213..2613516a 100644 --- a/libspu/mpc/semi2k/beaver/beaver_impl/trusted_party/BUILD.bazel +++ b/libspu/mpc/semi2k/beaver/beaver_impl/trusted_party/BUILD.bazel @@ -21,7 +21,9 @@ spu_cc_library( srcs = ["trusted_party.cc"], hdrs = ["trusted_party.h"], deps = [ + "//libspu/core:type_util", "//libspu/mpc/common:prg_tensor", + "//libspu/mpc/utils:gfmp_ops", 
"//libspu/mpc/utils:permute", "//libspu/mpc/utils:ring_ops", ], diff --git a/libspu/mpc/semi2k/beaver/beaver_impl/trusted_party/trusted_party.cc b/libspu/mpc/semi2k/beaver/beaver_impl/trusted_party/trusted_party.cc index 1ff405ad..75c528d6 100644 --- a/libspu/mpc/semi2k/beaver/beaver_impl/trusted_party/trusted_party.cc +++ b/libspu/mpc/semi2k/beaver/beaver_impl/trusted_party/trusted_party.cc @@ -14,22 +14,32 @@ #include "libspu/mpc/semi2k/beaver/beaver_impl/trusted_party/trusted_party.h" +#include "libspu/core/type_util.h" +#include "libspu/mpc/common/prg_tensor.h" +#include "libspu/mpc/utils/gfmp_ops.h" #include "libspu/mpc/utils/permute.h" #include "libspu/mpc/utils/ring_ops.h" namespace spu::mpc::semi2k { namespace { +enum class ReduceOp : uint8_t { + ADD = 0, + XOR = 1, + MUL = 2, +}; + enum class RecOp : uint8_t { ADD = 0, XOR = 1, }; -std::vector reconstruct(RecOp op, - absl::Span ops) { +std::vector reduce(ReduceOp op, + absl::Span ops) { std::vector rs(ops.size()); const auto world_size = ops[0].seeds.size(); + for (size_t rank = 0; rank < world_size; rank++) { for (size_t idx = 0; idx < ops.size(); idx++) { // FIXME: TTP adjuster server and client MUST have same endianness. 
@@ -43,12 +53,25 @@ std::vector reconstruct(RecOp op, if (rank == 0) { rs[idx] = t; } else { - if (op == RecOp::ADD) { - ring_add_(rs[idx], t); - } else if (op == RecOp::XOR) { + if (op == ReduceOp::ADD) { + if (ops[idx].desc.eltype == ElementType::kGfmp) { + // TODO: generalize the reduction + gfmp_add_mod_(rs[idx], t); + } else { + ring_add_(rs[idx], t); + } + } else if (op == ReduceOp::XOR) { + // gfmp has no xor implementation ring_xor_(rs[idx], t); + } else if (op == ReduceOp::MUL) { + if (ops[idx].desc.eltype == ElementType::kGfmp) { + // TODO: generalize the reduction + gfmp_mul_mod_(rs[idx], t); + } else { + ring_mul_(rs[idx], t); + } } else { - SPU_ENFORCE("not supported reconstruct op"); + SPU_THROW("not supported reduction op"); } } } @@ -57,11 +80,17 @@ std::vector reconstruct(RecOp op, return rs; } +std::vector reconstruct(RecOp op, + absl::Span ops) { + return reduce(ReduceOp(op), ops); +} + void checkOperands(absl::Span ops, bool skip_shape = false, bool allow_transpose = false) { for (size_t idx = 1; idx < ops.size(); idx++) { SPU_ENFORCE(skip_shape || ops[0].desc.shape == ops[idx].desc.shape); SPU_ENFORCE(allow_transpose || ops[0].transpose == false); + SPU_ENFORCE(ops[0].desc.eltype == ops[idx].desc.eltype); SPU_ENFORCE(ops[0].desc.field == ops[idx].desc.field); SPU_ENFORCE(ops[0].seeds.size() == ops[idx].seeds.size(), "{} <> {}", ops[0].seeds.size(), ops[idx].seeds.size()); @@ -70,13 +99,43 @@ void checkOperands(absl::Span ops, } // namespace +// TODO: gfmp support more operations NdArrayRef TrustedParty::adjustMul(absl::Span ops) { SPU_ENFORCE_EQ(ops.size(), 3U); checkOperands(ops); auto rs = reconstruct(RecOp::ADD, ops); // adjust = rs[0] * rs[1] - rs[2]; - return ring_sub(ring_mul(rs[0], rs[1]), rs[2]); + if (ops[0].desc.eltype == ElementType::kGfmp) { + return gfmp_sub_mod(gfmp_mul_mod(rs[0], rs[1]), rs[2]); + } else { + ring_mul_(rs[0], rs[1]); + ring_sub_(rs[0], rs[2]); + return rs[0]; + } +} + +// ops are [a_or_b, c] +// P0 generate a, c0 
+// P1 generate b, c1 +// The adjustment is ab - (c0 + c1), +// which only needs to be sent to adjust party, e.g. P0. +// P0 with adjust is ab - c1 = ab - (c0 + c1) + c0 +// Therefore, +// P0 holds: a, ab - c1 +// P1 holds: b, c1 +NdArrayRef TrustedParty::adjustMulPriv(absl::Span ops) { + SPU_ENFORCE_EQ(ops.size(), 2U); + checkOperands(ops); + + auto ab = reduce(ReduceOp::MUL, ops.subspan(0, 1))[0]; + auto c = reconstruct(RecOp::ADD, ops.subspan(1, 1))[0]; + // adjust = ab - c; + if (ops[0].desc.eltype == ElementType::kGfmp) { + return gfmp_sub_mod(ab, c); + } else { + return ring_sub(ab, c); + } } NdArrayRef TrustedParty::adjustSquare(absl::Span ops) { @@ -84,7 +143,9 @@ NdArrayRef TrustedParty::adjustSquare(absl::Span ops) { auto rs = reconstruct(RecOp::ADD, ops); // adjust = rs[0] * rs[0] - rs[1]; - return ring_sub(ring_mul(rs[0], rs[0]), rs[1]); + ring_mul_(rs[0], rs[0]); + ring_sub_(rs[0], rs[1]); + return rs[0]; } NdArrayRef TrustedParty::adjustDot(absl::Span ops) { @@ -101,7 +162,9 @@ NdArrayRef TrustedParty::adjustDot(absl::Span ops) { } // adjust = rs[0] dot rs[1] - rs[2]; - return ring_sub(ring_mmul(rs[0], rs[1]), rs[2]); + auto dot = ring_mmul(rs[0], rs[1]); + ring_sub_(dot, rs[2]); + return dot; } NdArrayRef TrustedParty::adjustAnd(absl::Span ops) { @@ -110,7 +173,9 @@ NdArrayRef TrustedParty::adjustAnd(absl::Span ops) { auto rs = reconstruct(RecOp::XOR, ops); // adjust = (rs[0] & rs[1]) ^ rs[2]; - return ring_xor(ring_and(rs[0], rs[1]), rs[2]); + ring_and_(rs[0], rs[1]); + ring_xor_(rs[0], rs[2]); + return rs[0]; } NdArrayRef TrustedParty::adjustTrunc(absl::Span ops, size_t bits) { @@ -119,7 +184,9 @@ NdArrayRef TrustedParty::adjustTrunc(absl::Span ops, size_t bits) { auto rs = reconstruct(RecOp::ADD, ops); // adjust = (rs[0] >> bits) - rs[1]; - return ring_sub(ring_arshift(rs[0], {static_cast(bits)}), rs[1]); + ring_arshift_(rs[0], {static_cast(bits)}); + ring_sub_(rs[0], rs[1]); + return rs[0]; } std::pair TrustedParty::adjustTruncPr( @@ -131,15 
+198,14 @@ std::pair TrustedParty::adjustTruncPr( auto rs = reconstruct(RecOp::ADD, ops); // adjust1 = ((rs[0] << 1) >> (bits + 1)) - rs[1]; - auto adjust1 = ring_sub( - ring_rshift(ring_lshift(rs[0], {1}), {static_cast(bits + 1)}), - rs[1]); + auto adjust1 = ring_lshift(rs[0], {1}); + ring_rshift_(adjust1, {static_cast(bits + 1)}); + ring_sub_(adjust1, rs[1]); // adjust2 = (rs[0] >> (k - 1)) - rs[2]; const size_t k = SizeOf(ops[0].desc.field) * 8; - auto adjust2 = - ring_sub(ring_rshift(rs[0], {static_cast(k - 1)}), rs[2]); - + auto adjust2 = ring_rshift(rs[0], {static_cast(k - 1)}); + ring_sub_(adjust2, rs[2]); return {adjust1, adjust2}; } @@ -148,7 +214,9 @@ NdArrayRef TrustedParty::adjustRandBit(absl::Span ops) { auto rs = reconstruct(RecOp::ADD, ops); // adjust = bitrev - rs[0]; - return ring_sub(ring_randbit(ops[0].desc.field, ops[0].desc.shape), rs[0]); + auto randbits = ring_randbit(ops[0].desc.field, ops[0].desc.shape); + ring_sub_(randbits, rs[0]); + return randbits; } NdArrayRef TrustedParty::adjustEqz(absl::Span ops) { diff --git a/libspu/mpc/semi2k/beaver/beaver_impl/trusted_party/trusted_party.h b/libspu/mpc/semi2k/beaver/beaver_impl/trusted_party/trusted_party.h index 55a412e9..60098256 100644 --- a/libspu/mpc/semi2k/beaver/beaver_impl/trusted_party/trusted_party.h +++ b/libspu/mpc/semi2k/beaver/beaver_impl/trusted_party/trusted_party.h @@ -33,6 +33,8 @@ class TrustedParty { static NdArrayRef adjustMul(absl::Span); + static NdArrayRef adjustMulPriv(absl::Span); + static NdArrayRef adjustSquare(absl::Span); static NdArrayRef adjustDot(absl::Span); diff --git a/libspu/mpc/semi2k/beaver/beaver_impl/ttp_server/beaver_server.cc b/libspu/mpc/semi2k/beaver/beaver_impl/ttp_server/beaver_server.cc index 30a72277..2a5136ec 100644 --- a/libspu/mpc/semi2k/beaver/beaver_impl/ttp_server/beaver_server.cc +++ b/libspu/mpc/semi2k/beaver/beaver_impl/ttp_server/beaver_server.cc @@ -52,7 +52,8 @@ template std::tuple, std::vector>, size_t> BuildOperand(const 
AdjustRequest& req, uint32_t field_size, - const std::unique_ptr& decryptor) { + const std::unique_ptr& decryptor, + ElementType eltype) { std::vector ops; std::vector> seeds; size_t pad_length = 0; @@ -140,7 +141,7 @@ BuildOperand(const AdjustRequest& req, uint32_t field_size, } seeds.emplace_back(std::move(seed)); ops.push_back( - TrustedParty::Operand{{shape, type, prg_count}, seeds.back()}); + TrustedParty::Operand{{shape, type, prg_count, eltype}, seeds.back()}); } if constexpr (std::is_same_v) { @@ -305,6 +306,9 @@ std::vector AdjustImpl(const AdjustRequest& req, if constexpr (std::is_same_v) { auto adjust = TrustedParty::adjustMul(ops); ret.push_back(std::move(adjust)); + } else if constexpr (std::is_same_v) { + auto adjust = TrustedParty::adjustMulPriv(ops); + ret.push_back(std::move(adjust)); } else if constexpr (std::is_same_v) { auto adjust = TrustedParty::adjustSquare(ops); ret.push_back(std::move(adjust)); @@ -357,7 +361,17 @@ void AdjustAndSend( } else { field_size = req.field_size(); } - auto [ops, seeds, pad_length] = BuildOperand(req, field_size, decryptor); + ElementType eltype = ElementType::kRing; + // enable eltype for selected requests here + // later all requests may support gfmp + if constexpr (std::is_same_v || + std::is_same_v) { + if (req.element_type() == ElType::GFMP) { + eltype = ElementType::kGfmp; + } + } + auto [ops, seeds, pad_length] = + BuildOperand(req, field_size, decryptor, eltype); if constexpr (std::is_same_v || std::is_same_v) { @@ -475,6 +489,12 @@ class ServiceImpl final : public BeaverService { Adjust(controller, req, rsp, done); } + void AdjustMulPriv(::google::protobuf::RpcController* controller, + const AdjustMulPrivRequest* req, AdjustResponse* rsp, + ::google::protobuf::Closure* done) override { + Adjust(controller, req, rsp, done); + } + void AdjustSquare(::google::protobuf::RpcController* controller, const AdjustSquareRequest* req, AdjustResponse* rsp, ::google::protobuf::Closure* done) override { diff --git 
a/libspu/mpc/semi2k/beaver/beaver_impl/ttp_server/service.proto b/libspu/mpc/semi2k/beaver/beaver_impl/ttp_server/service.proto index 6b1b3675..23fd3025 100644 --- a/libspu/mpc/semi2k/beaver/beaver_impl/ttp_server/service.proto +++ b/libspu/mpc/semi2k/beaver/beaver_impl/ttp_server/service.proto @@ -25,6 +25,13 @@ enum ErrorCode { StreamAcceptError = 3; } +// The type of element in the field. +// Match the enum in libspu/mpc/common/prg_tensor.h +enum ElType { + UNSPECIFIED = 0; + RING = 1; + GFMP = 2; +} // PRG generated buffer metainfo. // BeaverService replay PRG to generate same buffer using each party's prg_seed // encrypted by server's public key. PrgBufferMeta represent {world_size} @@ -42,6 +49,8 @@ service BeaverService { // V1 adjust ops rpc AdjustMul(AdjustMulRequest) returns (AdjustResponse); + rpc AdjustMulPriv(AdjustMulPrivRequest) returns (AdjustResponse); + rpc AdjustSquare(AdjustSquareRequest) returns (AdjustResponse); rpc AdjustDot(AdjustDotRequest) returns (AdjustResponse); @@ -69,6 +78,27 @@ message AdjustMulRequest { // adjust_c = ra * rb - rc // make // ra * rb = (adjust_c + rc) + + // element type supported: "GFMP", "RING" + ElType element_type = 3; + // if element type is "GFMP" then all ring ops will be changed to gfmp +} + +message AdjustMulPrivRequest { + // input 2 prg buffer + // first is a or b [one party holds a slice, another b slice] + // second is c + repeated PrgBufferMeta prg_inputs = 1; + // What field size should be used to interpret buffer content + uint32 field_size = 2; + // output + // adjust_c = a * b - rc + // make + // a * b = (adjust_c + rc) + + // element type supported: "GFMP", "RING" + ElType element_type = 3; + // if element type is "GFMP" then all ring ops will be changed to gfmp } message AdjustSquareRequest { diff --git a/libspu/mpc/semi2k/beaver/beaver_interface.h b/libspu/mpc/semi2k/beaver/beaver_interface.h index d610f380..89c58267 100644 --- a/libspu/mpc/semi2k/beaver/beaver_interface.h +++ 
b/libspu/mpc/semi2k/beaver/beaver_interface.h @@ -41,6 +41,7 @@ class Beaver { std::vector encrypted_seeds; int64_t size; FieldType field; + ElementType eltype; }; using Array = yacl::Buffer; @@ -50,8 +51,11 @@ class Beaver { virtual ~Beaver() = default; virtual Triple Mul(FieldType field, int64_t size, - ReplayDesc* x_desc = nullptr, - ReplayDesc* y_desc = nullptr) = 0; + ReplayDesc* x_desc = nullptr, ReplayDesc* y_desc = nullptr, + ElementType eltype = ElementType::kRing) = 0; + + virtual Pair MulPriv(FieldType field, int64_t size, + ElementType eltype = ElementType::kRing) = 0; virtual Pair Square(FieldType field, int64_t size, ReplayDesc* x_desc = nullptr) = 0; diff --git a/libspu/mpc/semi2k/exp.cc b/libspu/mpc/semi2k/exp.cc new file mode 100644 index 00000000..34dba15f --- /dev/null +++ b/libspu/mpc/semi2k/exp.cc @@ -0,0 +1,97 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "exp.h" + +#include "prime_utils.h" +#include "type.h" + +#include "libspu/mpc/utils/gfmp.h" +#include "libspu/mpc/utils/ring_ops.h" + +namespace spu::mpc::semi2k { + +// Given [x*2^fxp] mod 2k for x +// compute [exp(x) * 2^fxp] mod 2^k + +// Assume x is in valid range, otherwise the error may be too large to +// use this method. 
+ +NdArrayRef ExpA::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { + const size_t fxp = ctx->sctx()->getFxpBits(); + SPU_ENFORCE( + fxp < 64, + "fxp must be less than 64 for this method, or shift bit overflow ", + "may occur"); + auto field = in.eltype().as()->field(); + NdArrayRef x = in.clone(); + NdArrayRef out; + + // TODO: set different values for FM64 FM32 + const size_t kExpFxp = (field == FieldType::FM128) ? 24 : 13; + + const int rank = ctx->sctx()->lctx()->Rank(); + DISPATCH_ALL_FIELDS(field, [&]() { + auto total_fxp = kExpFxp + fxp; + // note that x is already encoded with fxp + // this conv scale further converts x int fixed point numbers with + // total_fxp + const ring2k_t exp_conv_scale = std::roundf(M_LOG2E * (1L << kExpFxp)); + + // offset scale should directly encoded to a fixed point with total_fxp + const ring2k_t offset = + ctx->sctx()->config().experimental_exp_prime_offset(); + const ring2k_t offset_scaled = offset << total_fxp; + + NdArrayView _x(x); + if (rank == 0) { + pforeach(0, x.numel(), [&](ring2k_t i) { + _x[i] *= exp_conv_scale; + _x[i] += offset_scaled; + }); + } else { + pforeach(0, x.numel(), [&](ring2k_t i) { _x[i] *= exp_conv_scale; }); + } + size_t shr_width = SizeOf(field) * 8 - fxp; + + const ring2k_t kBit = 1; + auto shifted_bit = kBit << total_fxp; + const ring2k_t frac_mask = shifted_bit - 1; + + auto int_part = ring_arshift(x, {static_cast(total_fxp)}); + + // convert from ring-share (int-part) to a prime share over p - 1 + int_part = ProbConvRing2k(int_part, rank, shr_width); + NdArrayView int_part_view(int_part); + + pforeach(0, x.numel(), [&](int64_t i) { + // y = 2^int_part mod p + ring2k_t y = exp_mod(2, int_part_view[i]); + // z = 2^fract_part in RR + double frac_part = static_cast(_x[i] & frac_mask) / shifted_bit; + frac_part = std::pow(2., frac_part); + + // Multiply the 2^{int_part} * 2^{frac_part} mod p + // note that mul_mod uses mersenne prime as modulus according to field + int_part_view[i] = 
mul_mod( + y, static_cast(std::roundf(frac_part * (kBit << fxp)))); + }); + + NdArrayRef muled = MulPrivModMP(ctx, int_part.as(makeType(field))); + + out = ConvMP(ctx, muled, offset + fxp); + }); + return out.as(in.eltype()); +} + +} // namespace spu::mpc::semi2k \ No newline at end of file diff --git a/libspu/mpc/semi2k/exp.h b/libspu/mpc/semi2k/exp.h new file mode 100644 index 00000000..fcc4711e --- /dev/null +++ b/libspu/mpc/semi2k/exp.h @@ -0,0 +1,37 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "libspu/mpc/kernel.h" + +namespace spu::mpc::semi2k { + +// Given [x*2^fxp] mod 2k for x +// compute [exp(x) * 2^fxp] mod 2^k +// Example: +// spu::mpc::semi2k::ExpA exp; +// outp = exp.proc(&kcontext, ring2k_shr); +class ExpA : public UnaryKernel { + public: + static constexpr const char* kBindName() { return "exp_a"; } + + ce::CExpr latency() const override { return ce::Const(2); } + + ce::CExpr comm() const override { return 2 * ce::K(); } + + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in) const override; +}; + +} // namespace spu::mpc::semi2k diff --git a/libspu/mpc/semi2k/prime_utils.cc b/libspu/mpc/semi2k/prime_utils.cc new file mode 100644 index 00000000..3d406f24 --- /dev/null +++ b/libspu/mpc/semi2k/prime_utils.cc @@ -0,0 +1,202 @@ +// Copyright 2024 Ant Group Co., Ltd. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "prime_utils.h" + +#include "type.h" + +#include "libspu/mpc/common/communicator.h" +#include "libspu/mpc/semi2k/state.h" +#include "libspu/mpc/utils/gfmp.h" +#include "libspu/mpc/utils/ring_ops.h" + +namespace spu::mpc::semi2k { + +NdArrayRef ProbConvRing2k(const NdArrayRef& inp_share, int rank, + size_t shr_width) { + SPU_ENFORCE(inp_share.eltype().isa()); + SPU_ENFORCE(rank >= 0 && rank <= 1); + + auto eltype = inp_share.eltype(); + NdArrayRef output_share(eltype, inp_share.shape()); + + auto ring_ty = eltype.as()->field(); + uint128_t shifted_bit = 1; + shifted_bit <<= shr_width; + auto mask = shifted_bit - 1; + // x mod p - 1 + // in our case p > 2^shr_width + + DISPATCH_ALL_FIELDS(ring_ty, [&]() { + const auto prime = ScalarTypeToPrime::prime; + ring2k_t prime_minus_one = (prime - 1); + NdArrayView inp(inp_share); + NdArrayView output_share_view(output_share); + pforeach(0, output_share.numel(), [&](int64_t i) { + output_share_view[i] = + rank == 0 ? 
((inp[i] & mask) % prime_minus_one) + // numerical considerations here + // we wanted to work on ring 2k or field p - 1 + // however, if we do not add p -1 + // then the computation will resort to int128 + // due to the way computer works + : ((inp[i] & mask) + prime_minus_one - shifted_bit) % + prime_minus_one; + }); + }); + return output_share; +} + +NdArrayRef UnflattenBuffer(yacl::Buffer&& buf, const NdArrayRef& x) { + return NdArrayRef(std::make_shared(std::move(buf)), x.eltype(), + x.shape()); +} + +// P0 holds x,P1 holds y +// Beaver generates ab = c_0 + c_1 +// Give (a, c_0) to P0 +// Give (b, c_1) to P1 +std::tuple MulPrivPrep(KernelEvalContext* ctx, + const NdArrayRef& x) { + const auto field = x.eltype().as()->field(); + auto* beaver = ctx->getState()->beaver(); + + // generate beaver multiple triple. + NdArrayRef a_or_b; + NdArrayRef c; + + const size_t numel = x.shape().numel(); + auto [a_or_b_buf, c_buf] = beaver->MulPriv( + field, numel, // + x.eltype().isa() ? ElementType::kGfmp : ElementType::kRing); + SPU_ENFORCE(static_cast(a_or_b_buf.size()) == numel * SizeOf(field)); + SPU_ENFORCE(static_cast(c_buf.size()) == numel * SizeOf(field)); + + a_or_b = UnflattenBuffer(std::move(a_or_b_buf), x); + c = UnflattenBuffer(std::move(c_buf), x); + + return {std::move(a_or_b), std::move(c)}; +} + +// P0 holds x,P1 holds y +// Beaver generates ab = c_0 + c_1 +// Give (a, c_0) to P0 +// Give (b, c_1) to P1 +// +// - P0 sends (x+a) to P1 ; P1 sends (y+b) to P0 +// - P0 calculates z0 = x(y+b) + c0 ; P1 calculates z1 = -b(x+a) + c1 +NdArrayRef MulPriv(KernelEvalContext* ctx, const NdArrayRef& x) { + SPU_ENFORCE(x.eltype().isa()); + auto* comm = ctx->getState(); + + NdArrayRef a_or_b, c, xa_or_yb; + + std::tie(a_or_b, c) = MulPrivPrep(ctx, x); + + // P0 sends (x+a) to P1 ; P1 sends (y+b) to P0 + comm->sendAsync(comm->nextRank(), ring_add(a_or_b, x), "(x + a) or (y + b)"); + xa_or_yb = comm->recv(comm->prevRank(), x.eltype(), "(x + a) or (y + b)") + 
.reshape(x.shape()); + // note that our rings are commutative. + if (comm->getRank() == 0) { + ring_add_(c, ring_mul(std::move(xa_or_yb), x)); + } + if (comm->getRank() == 1) { + ring_sub_(c, ring_mul(std::move(xa_or_yb), a_or_b)); + } + return c; +} + +NdArrayRef MulPrivModMP(KernelEvalContext* ctx, const NdArrayRef& x) { + SPU_ENFORCE(x.eltype().isa()); + auto* comm = ctx->getState(); + + NdArrayRef a_or_b, c, xa_or_yb; + std::tie(a_or_b, c) = MulPrivPrep(ctx, x); + + comm->sendAsync(comm->nextRank(), gfmp_add_mod(a_or_b, x), "xa_or_yb"); + xa_or_yb = + comm->recv(comm->prevRank(), x.eltype(), "xa_or_yb").reshape(x.shape()); + + // note that our rings are commutative. + if (comm->getRank() == 0) { + gfmp_add_mod_(c, gfmp_mul_mod(std::move(xa_or_yb), x)); + } + if (comm->getRank() == 1) { + gfmp_sub_mod_(c, gfmp_mul_mod(std::move(xa_or_yb), a_or_b)); + } + return c; +} + +// We assume the input is ``positive'' +// Given h0 + h1 = h mod p and h < p / 2 +// Define b0 = 1{h0 >= p/2} +// b1 = 1{h1 >= p/2} +// Compute w = 1{h0 + h1 >= p} +// It can be proved that w = (b0 or b1) = not (not b0 and not b1) +NdArrayRef WrapBitModMP(KernelEvalContext* ctx, const NdArrayRef& x) { + // create a wrap bit NdArrayRef of the same shape as in + NdArrayRef b(x.eltype(), x.shape()); + + // for each element, we compute b = 1{h < p/2} for each private share piece + const auto numel = x.numel(); + const auto field = x.eltype().as()->field(); + + DISPATCH_ALL_FIELDS(field, [&]() { + ring2k_t prime = ScalarTypeToPrime::prime; + ring2k_t phalf = prime >> 1; + NdArrayView _x(x); + NdArrayView _b(b); + pforeach(0, numel, [&](int64_t idx) { + _b[idx] = static_cast(_x[idx] < phalf); + }); + + // do private mul + b = MulPriv(ctx, b.as(makeType(field))); + + // map 1 to 0 and 0 to 1, use 1 - x + if (ctx->getState()->getRank() == 0) { + pforeach(0, numel, [&](int64_t idx) { _b[idx] = 1 - _b[idx]; }); + } else { + pforeach(0, numel, [&](int64_t idx) { _b[idx] = -_b[idx]; }); + } + }); + + return 
b; +} +// Mersenne Prime share -> Ring2k share + +NdArrayRef ConvMP(KernelEvalContext* ctx, const NdArrayRef& h, + uint truncate_nbits) { + // calculate wrap bit + NdArrayRef w = WrapBitModMP(ctx, h); + const auto field = h.eltype().as()->field(); + const auto numel = h.numel(); + + // x = (h - p * w) mod 2^k + + NdArrayRef x(makeType(field), h.shape()); + DISPATCH_ALL_FIELDS(field, [&]() { + auto prime = ScalarTypeToPrime::prime; + NdArrayView h_view(h); + NdArrayView _x(x); + NdArrayView w_view(w); + pforeach(0, numel, [&](int64_t idx) { + _x[idx] = static_cast(h_view[idx] >> truncate_nbits) - + static_cast(prime >> truncate_nbits) * w_view[idx]; + }); + }); + return x; +} + +} // namespace spu::mpc::semi2k diff --git a/libspu/mpc/semi2k/prime_utils.h b/libspu/mpc/semi2k/prime_utils.h new file mode 100644 index 00000000..a04acf3a --- /dev/null +++ b/libspu/mpc/semi2k/prime_utils.h @@ -0,0 +1,46 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#pragma once + +#include "libspu/core/context.h" +#include "libspu/mpc/kernel.h" + +namespace spu::mpc::semi2k { +// Ring2k share -> Mersenne Prime - 1 share +// Given x0 + x1 = x mod 2^k +// Compute h0 + h1 = x mod p with probability > 1 - |x|/2^k +NdArrayRef ProbConvRing2k(const NdArrayRef& inp_share, int rank, + size_t shr_width); + +// Mul open private share +std::tuple MulPrivPrep(KernelEvalContext* ctx, + const NdArrayRef& x); + +// Note that [x] = (x_alice, x_bob) and x_alice + x_bob = x +// Note that we actually want to find the muliplication of x_alice and x_bob +// this function is currently achieved by doing (x_alice, 0) * (0, x_bob) +// optimization is possible. +NdArrayRef MulPrivModMP(KernelEvalContext* ctx, const NdArrayRef& x); +// We assume the input is ``positive'' +// Given h0 + h1 = h mod p and h < p / 2 +// Define b0 = 1{h0 >= p/2} +// b1 = 1{h1 >= p/2} +// Compute w = 1{h0 + h1 >= p} +// It can be proved that w = (b0 or b1) +NdArrayRef WrapBitModMP(KernelEvalContext* ctx, const NdArrayRef& x); + +// Mersenne Prime share -> Ring2k share +NdArrayRef ConvMP(KernelEvalContext* ctx, const NdArrayRef& h, + uint truncate_nbits); +} // namespace spu::mpc::semi2k \ No newline at end of file diff --git a/libspu/mpc/semi2k/protocol.cc b/libspu/mpc/semi2k/protocol.cc index 35cd436c..33d6226b 100644 --- a/libspu/mpc/semi2k/protocol.cc +++ b/libspu/mpc/semi2k/protocol.cc @@ -20,6 +20,7 @@ #include "libspu/mpc/semi2k/arithmetic.h" #include "libspu/mpc/semi2k/boolean.h" #include "libspu/mpc/semi2k/conversion.h" +#include "libspu/mpc/semi2k/exp.h" #include "libspu/mpc/semi2k/permute.h" #include "libspu/mpc/semi2k/state.h" #include "libspu/mpc/semi2k/type.h" @@ -76,6 +77,13 @@ void regSemi2kProtocol(SPUContext* ctx, if (lctx->WorldSize() == 2) { ctx->prot()->regKernel(); + ctx->prot()->regKernel(); + + // only supports 2pc fm128 for now + if (ctx->getField() == FieldType::FM128 && + ctx->config().experimental_enable_exp_prime()) { + ctx->prot()->regKernel(); + 
} } // ctx->prot()->regKernel(); } diff --git a/libspu/mpc/semi2k/protocol_test.cc b/libspu/mpc/semi2k/protocol_test.cc index 66911344..eb1a6c60 100644 --- a/libspu/mpc/semi2k/protocol_test.cc +++ b/libspu/mpc/semi2k/protocol_test.cc @@ -25,7 +25,11 @@ #include "libspu/mpc/api_test.h" #include "libspu/mpc/common/communicator.h" #include "libspu/mpc/semi2k/beaver/beaver_impl/ttp_server/beaver_server.h" +#include "libspu/mpc/semi2k/exp.h" +#include "libspu/mpc/semi2k/prime_utils.h" #include "libspu/mpc/semi2k/state.h" +#include "libspu/mpc/semi2k/type.h" +#include "libspu/mpc/utils/gfmp.h" #include "libspu/mpc/utils/ring_ops.h" #include "libspu/mpc/utils/simulate.h" @@ -36,6 +40,12 @@ RuntimeConfig makeConfig(FieldType field) { RuntimeConfig conf; conf.set_protocol(ProtocolKind::SEMI2K); conf.set_field(field); + if (field == FieldType::FM64) { + conf.set_fxp_fraction_bits(17); + } else if (field == FieldType::FM128) { + conf.set_fxp_fraction_bits(40); + } + conf.set_experimental_enable_exp_prime(true); return conf; } @@ -404,4 +414,173 @@ TEST_P(BeaverCacheTest, SquareA) { }); } +TEST_P(BeaverCacheTest, priv_mul_test) { + const auto factory = std::get<0>(GetParam()); + const RuntimeConfig& conf = std::get<1>(GetParam()); + const size_t npc = std::get<2>(GetParam()); + // only supports 2 party (not counting beaver) + if (npc != 2) { + return; + } + NdArrayRef ring2k_shr[2]; + + int64_t numel = 1; + FieldType field = conf.field(); + + std::vector real_vec(numel); + for (int64_t i = 0; i < numel; ++i) { + real_vec[i] = 2; + } + + auto rnd_msg = gfmp_zeros(field, {numel}); + + DISPATCH_ALL_FIELDS(field, [&]() { + using sT = std::make_signed::type; + NdArrayView xmsg(rnd_msg); + pforeach(0, numel, [&](int64_t i) { xmsg[i] = std::round(real_vec[i]); }); + }); + + ring2k_shr[0] = rnd_msg; + ring2k_shr[1] = rnd_msg; + + NdArrayRef input, outp_pub; + NdArrayRef outp[2]; + + utils::simulate(npc, [&](const std::shared_ptr& lctx) { + auto obj = factory(conf, lctx); + + 
KernelEvalContext kcontext(obj.get()); + + int rank = lctx->Rank(); + + outp[rank] = spu::mpc::semi2k::MulPrivModMP(&kcontext, ring2k_shr[rank]); + }); + auto got = gfmp_add_mod(outp[0], outp[1]); + DISPATCH_ALL_FIELDS(field, [&]() { + using sT = std::make_signed::type; + NdArrayView got_view(got); + + double max_err = 0.0; + double min_err = 99.0; + for (int64_t i = 0; i < numel; ++i) { + double expected = real_vec[i] * real_vec[i]; + double got = static_cast(got_view[i]); + max_err = std::max(max_err, std::abs(expected - got)); + min_err = std::min(min_err, std::abs(expected - got)); + } + ASSERT_LE(min_err, 1e-3); + ASSERT_LE(max_err, 1e-3); + }); +} + +TEST_P(BeaverCacheTest, exp_mod_test) { + const RuntimeConfig& conf = std::get<1>(GetParam()); + FieldType field = conf.field(); + + DISPATCH_ALL_FIELDS(field, [&]() { + // exponents < 32 + ring2k_t exponents[5] = {10, 21, 27}; + + for (ring2k_t exponent : exponents) { + ring2k_t y = exp_mod(2, exponent); + ring2k_t prime = ScalarTypeToPrime::prime; + ring2k_t prime_minus_one = (prime - 1); + ring2k_t shifted_bit = 1; + shifted_bit <<= exponent; + EXPECT_EQ(y, shifted_bit % prime_minus_one); + } + }); +} + +TEST_P(BeaverCacheTest, ExpA) { + const auto factory = std::get<0>(GetParam()); + const RuntimeConfig& conf = std::get<1>(GetParam()); + const size_t npc = std::get<2>(GetParam()); + // exp only supports 2 party (not counting beaver) + // only supports FM128 for now + // note not using ctx->hasKernel("exp_a") because we are testing kernel + // registration as well. 
+ if (npc != 2 || conf.field() != FieldType::FM128) { + return; + } + auto fxp = conf.fxp_fraction_bits(); + + NdArrayRef ring2k_shr[2]; + + int64_t numel = 100; + FieldType field = conf.field(); + + // how to define and achieve high pricision for e^20 + std::uniform_real_distribution dist(-18.0, 15.0); + std::default_random_engine rd; + std::vector real_vec(numel); + for (int64_t i = 0; i < numel; ++i) { + // make the input a fixed point number, eliminate the fixed point encoding + // error + real_vec[i] = + static_cast(std::round((dist(rd) * (1L << fxp)))) / (1L << fxp); + } + + auto rnd_msg = ring_zeros(field, {numel}); + + DISPATCH_ALL_FIELDS(field, [&]() { + using sT = std::make_signed::type; + NdArrayView xmsg(rnd_msg); + pforeach(0, numel, [&](int64_t i) { + xmsg[i] = std::round(real_vec[i] * (1L << fxp)); + }); + }); + + ring2k_shr[0] = ring_rand(field, rnd_msg.shape()) + .as(makeType(field)); + ring2k_shr[1] = ring_sub(rnd_msg, ring2k_shr[0]) + .as(makeType(field)); + + NdArrayRef outp_pub; + NdArrayRef outp[2]; + + utils::simulate(npc, [&](const std::shared_ptr& lctx) { + auto obj = factory(conf, lctx); + + KernelEvalContext kcontext(obj.get()); + + int rank = lctx->Rank(); + + size_t bytes = lctx->GetStats()->sent_bytes; + size_t action = lctx->GetStats()->sent_actions; + + spu::mpc::semi2k::ExpA exp; + outp[rank] = exp.proc(&kcontext, ring2k_shr[rank]); + + bytes = lctx->GetStats()->sent_bytes - bytes; + action = lctx->GetStats()->sent_actions - action; + SPDLOG_INFO("ExpA ({}) for n = {}, sent {} MiB ({} B per), actions {}", + field, numel, bytes * 1. / 1024. / 1024., bytes * 1. 
/ numel, + action); + }); + assert(outp[0].eltype() == ring2k_shr[0].eltype()); + auto got = ring_add(outp[0], outp[1]); + ring_print(got, "exp result"); + DISPATCH_ALL_FIELDS(field, [&]() { + using sT = std::make_signed::type; + NdArrayView got_view(got); + + double max_err = 0.0; + for (int64_t i = 0; i < numel; ++i) { + double expected = std::exp(real_vec[i]); + expected = static_cast(std::round((expected * (1L << fxp)))) / + (1L << fxp); + double got = static_cast(got_view[i]) / (1L << fxp); + // cout left here for future improvement + std::cout << "expected: " << fmt::format("{0:f}", expected) + << ", got: " << fmt::format("{0:f}", got) << std::endl; + std::cout << "expected: " + << fmt::format("{0:b}", + static_cast(expected * (1L << fxp))) + << ", got: " << fmt::format("{0:b}", got_view[i]) << std::endl; + max_err = std::max(max_err, std::abs(expected - got)); + } + ASSERT_LE(max_err, 1e-0); + }); +} } // namespace spu::mpc::test diff --git a/libspu/mpc/utils/BUILD.bazel b/libspu/mpc/utils/BUILD.bazel index eec0773b..6d094e09 100644 --- a/libspu/mpc/utils/BUILD.bazel +++ b/libspu/mpc/utils/BUILD.bazel @@ -76,6 +76,7 @@ spu_cc_library( deps = [ ":linalg", "//libspu/core:ndarray_ref", + "//libspu/core:type_util", "@yacl//yacl/crypto/rand", "@yacl//yacl/crypto/tools:prg", "@yacl//yacl/utils:parallel", diff --git a/libspu/mpc/utils/gfmp_ops.cc b/libspu/mpc/utils/gfmp_ops.cc index d73a360d..996e8afb 100644 --- a/libspu/mpc/utils/gfmp_ops.cc +++ b/libspu/mpc/utils/gfmp_ops.cc @@ -162,6 +162,7 @@ NdArrayRef gfmp_rand(FieldType field, const Shape& shape) { // FIXME: this function is not strictly correct as the probability among the // range [0, p-1] is not uniform. 
+ NdArrayRef gfmp_rand(FieldType field, const Shape& shape, uint128_t prg_seed, uint64_t* prg_counter) { constexpr yacl::crypto::SymmetricCrypto::CryptoType kCryptoType = diff --git a/libspu/mpc/utils/ring_ops.h b/libspu/mpc/utils/ring_ops.h index 4fc04a28..035cd955 100644 --- a/libspu/mpc/utils/ring_ops.h +++ b/libspu/mpc/utils/ring_ops.h @@ -18,31 +18,6 @@ namespace spu::mpc { -#define DEF_RVALUE_BINARY_RING_OP(op_name, commutative) \ - template \ - typename std::enable_if< \ - std::is_same_v>> && \ - std::is_same_v>>, \ - NdArrayRef>::type \ - op_name(X&& x, Y&& y) { \ - if constexpr (std::is_rvalue_reference_v) { \ - op_name##_(x, y); \ - if constexpr (std::is_rvalue_reference_v) { \ - NdArrayRef dummy = std::move(y); \ - } \ - return std::move(x); \ - } else if constexpr (std::is_rvalue_reference_v && \ - COMMUTATIVE) { \ - op_name##_(y, x); \ - return std::move(y); \ - } else { \ - return op_name(static_cast(x), \ - static_cast(y)); \ - } \ - } - void ring_print(const NdArrayRef& x, std::string_view name = "_"); NdArrayRef ring_rand(FieldType field, const Shape& shape); @@ -65,15 +40,12 @@ void ring_neg_(NdArrayRef& x); NdArrayRef ring_add(const NdArrayRef& x, const NdArrayRef& y); void ring_add_(NdArrayRef& x, const NdArrayRef& y); -DEF_RVALUE_BINARY_RING_OP(ring_add, true); NdArrayRef ring_sub(const NdArrayRef& x, const NdArrayRef& y); void ring_sub_(NdArrayRef& x, const NdArrayRef& y); -DEF_RVALUE_BINARY_RING_OP(ring_sub, false); NdArrayRef ring_mul(const NdArrayRef& x, const NdArrayRef& y); void ring_mul_(NdArrayRef& x, const NdArrayRef& y); -DEF_RVALUE_BINARY_RING_OP(ring_mul, true); NdArrayRef ring_mul(const NdArrayRef& x, uint128_t y); void ring_mul_(NdArrayRef& x, uint128_t y); @@ -87,15 +59,12 @@ void ring_not_(NdArrayRef& x); NdArrayRef ring_and(const NdArrayRef& x, const NdArrayRef& y); void ring_and_(NdArrayRef& x, const NdArrayRef& y); -DEF_RVALUE_BINARY_RING_OP(ring_and, true); NdArrayRef ring_xor(const NdArrayRef& x, const NdArrayRef& y); void 
ring_xor_(NdArrayRef& x, const NdArrayRef& y); -DEF_RVALUE_BINARY_RING_OP(ring_xor, true); NdArrayRef ring_equal(const NdArrayRef& x, const NdArrayRef& y); void ring_equal_(NdArrayRef& x, const NdArrayRef& y); -DEF_RVALUE_BINARY_RING_OP(ring_equal, true); NdArrayRef ring_arshift(const NdArrayRef& x, const Sizes& bits); void ring_arshift_(NdArrayRef& x, const Sizes& bits); @@ -138,6 +107,4 @@ void ring_set_value(NdArrayRef& in, const T& value) { pforeach(0, in.numel(), [&](int64_t idx) { _in[idx] = value; }); }; -#undef DEF_RVALUE_BINARY_RING_OP - } // namespace spu::mpc diff --git a/libspu/spu.proto b/libspu/spu.proto index 2d9ebd5b..e7341004 100644 --- a/libspu/spu.proto +++ b/libspu/spu.proto @@ -244,6 +244,7 @@ message RuntimeConfig { EXP_DEFAULT = 0; // Implementation defined. EXP_PADE = 1; // The pade approximation. EXP_TAYLOR = 2; // Taylor series approximation. + EXP_PRIME = 3; // exp prime only available for some implementations } // The exponent approximation method. @@ -338,6 +339,25 @@ message RuntimeConfig { uint64 experimental_inter_op_concurrency = 104; // Enable use of private type bool experimental_enable_colocated_optimization = 105; + + // enable experimental exp prime method + bool experimental_enable_exp_prime = 106; + + // The offset parameter for exp prime methods. + // control the valid range of exp prime method. + // valid range is: + // ((47 - offset - 2fxp)/log_2(e), (125 - 2fxp - offset)/log_2(e)) + // clamp to value would be + // lower bound: (48 - offset - 2fxp)/log_2(e) + // higher bound: (124 - 2fxp - offset)/log_2(e) + // default offset is 13, 0 offset is not supported. 
+ uint32 experimental_exp_prime_offset = 107; + // whether to apply the clamping lower bound + // default to enable it + bool experimental_exp_prime_disable_lower_bound = 108; + // whether to apply the clamping upper bound + // default to disable it + bool experimental_exp_prime_enable_upper_bound = 109; } message TTPBeaverConfig { diff --git a/libspu/version.h b/libspu/version.h index a7f569ae..8ab54eea 100644 --- a/libspu/version.h +++ b/libspu/version.h @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#define SPU_VERSION "0.9.3.dev$$DATE$$" +#define SPU_VERSION "0.9.4.dev$$DATE$$" #include diff --git a/pyrightconfig.json b/pyrightconfig.json new file mode 100644 index 00000000..fbcb1153 --- /dev/null +++ b/pyrightconfig.json @@ -0,0 +1,5 @@ +{ + "executionEnvironments": [ + {"root": "."} + ] +} diff --git a/requirements-dev.txt b/requirements-dev.txt index 6ee814f8..33bcd4da 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,6 +1,6 @@ pandas>=1.4.2 -flax -scikit-learn +flax<0.10.0 +scikit-learn<1.6.0 # for tests absl-py>=1.1.0 tensorflow-cpu>=2.12.0; sys_platform == "linux" and platform_machine == 'x86_64' diff --git a/requirements.txt b/requirements.txt index ac8fd6cc..2c52e827 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,8 @@ grpcio>=1.42.0,!=1.48.0 -numpy>=1.22.0 +numpy>=1.22.0, <2 # FIXME: for SF compatibility protobuf>=4, <5 cloudpickle>=2.0.0 multiprocess>=0.70.12.2 cachetools>=5.0.0 -jax[cpu]>=0.4.16, <=0.4.26 # FIXME: Jax 0.4.26+ select perf issue +jax[cpu]>=0.4.16, <=0.4.34 # FIXME: Jax 0.4.26+ select perf issue termcolor>=2.0.0 diff --git a/sml/linear_model/BUILD.bazel b/sml/linear_model/BUILD.bazel index 6fa078a8..fa4fdd15 100644 --- a/sml/linear_model/BUILD.bazel +++ b/sml/linear_model/BUILD.bazel @@ -54,3 +54,11 @@ py_binary( "//sml/linear_model/utils:solver", ], ) + +py_library( + name = "quantile", + srcs = ["quantile.py"], + deps = [ + 
"//sml/linear_model/utils:_linprog_simplex", + ], +) diff --git a/sml/linear_model/emulations/BUILD.bazel b/sml/linear_model/emulations/BUILD.bazel index 6778cd0c..46df0f92 100644 --- a/sml/linear_model/emulations/BUILD.bazel +++ b/sml/linear_model/emulations/BUILD.bazel @@ -62,3 +62,12 @@ py_binary( "//sml/utils:emulation", ], ) + +py_binary( + name = "quantile_emul", + srcs = ["quantile_emul.py"], + deps = [ + "//sml/linear_model:quantile", + "//sml/utils:emulation", + ], +) diff --git a/sml/linear_model/emulations/quantile_emul.py b/sml/linear_model/emulations/quantile_emul.py new file mode 100644 index 00000000..ed81b4c3 --- /dev/null +++ b/sml/linear_model/emulations/quantile_emul.py @@ -0,0 +1,104 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import time

import jax.numpy as jnp
from sklearn.linear_model import QuantileRegressor as SklearnQuantileRegressor

import sml.utils.emulation as emulation
from sml.linear_model.quantile import QuantileRegressor as SmlQuantileRegressor

CONFIG_FILE = emulation.CLUSTER_ABY3_3PC


def emul_quantile(mode=emulation.Mode.MULTIPROCESS):
    """Emulate SPU quantile regression and compare it against scikit-learn.

    Trains a quantile regressor on mock linear data twice -- once in
    plaintext with sklearn and once inside an SPU emulator -- and prints
    the running time, fitted parameters and RMSE of both runs.

    Parameters
    ----------
    mode : emulation.Mode
        Emulation mode (multiprocess or docker).
    """

    def proc_wrapper(quantile, alpha, fit_intercept, lr, max_iter):
        # Build the SPU-side closure: fit on (X, y), then predict on X.
        quantile_custom = SmlQuantileRegressor(
            quantile=quantile,
            alpha=alpha,
            fit_intercept=fit_intercept,
            lr=lr,
            max_iter=max_iter,
        )

        def proc(X, y):
            quantile_custom_fit = quantile_custom.fit(X, y)
            result = quantile_custom_fit.predict(X)
            return result, quantile_custom_fit.coef_, quantile_custom_fit.intercept_

        return proc

    def generate_data():
        # Mock linear data: y = 5*x0 + 2*x1 + small Gaussian noise.
        from jax import random

        key = random.PRNGKey(42)
        key, subkey = random.split(key)
        X = random.normal(subkey, (100, 2))
        y = 5 * X[:, 0] + 2 * X[:, 1] + random.normal(key, (100,)) * 0.1
        return X, y

    # Construct the emulator *before* the try block: the original built it
    # inside `try`, so a constructor failure raised NameError in `finally`
    # (unbound `emulator`) and masked the real error.
    # bandwidth and latency only work for docker mode
    emulator = emulation.Emulator(CONFIG_FILE, mode, bandwidth=300, latency=20)
    try:
        emulator.up()

        # load mock data
        X, y = generate_data()

        # compare with sklearn (plaintext baseline)
        quantile_sklearn = SklearnQuantileRegressor(
            quantile=0.2, alpha=0.1, fit_intercept=True, solver='highs'
        )
        start = time.time()
        quantile_sklearn_fit = quantile_sklearn.fit(X, y)
        y_pred_plain = quantile_sklearn_fit.predict(X)
        rmse_plain = jnp.sqrt(jnp.mean((y - y_pred_plain) ** 2))
        end = time.time()
        print(f"Running time in SKlearn: {end - start:.2f}s")
        print(quantile_sklearn_fit.coef_)
        print(quantile_sklearn_fit.intercept_)

        # mark these data to be protected in SPU
        X_spu, y_spu = emulator.seal(X, y)

        # run
        # Larger max_iter can give higher accuracy, but it will take more time to run
        proc = proc_wrapper(
            quantile=0.2, alpha=0.1, fit_intercept=True, lr=0.01, max_iter=200
        )
        start = time.time()
        result, coef, intercept = emulator.run(proc)(X_spu, y_spu)
        end = time.time()
        # (typo fixed: was `rmse_encrpted`)
        rmse_encrypted = jnp.sqrt(jnp.mean((y - result) ** 2))
        print(f"Running time in SPU: {end - start:.2f}s")
        print(coef)
        print(intercept)

        # print RMSE
        print(f"RMSE in SKlearn: {rmse_plain:.2f}")
        print(f"RMSE in SPU: {rmse_encrypted:.2f}")

    finally:
        emulator.down()


if __name__ == "__main__":
    emul_quantile(emulation.Mode.MULTIPROCESS)
class QuantileRegressor:
    """Linear quantile regression fitted by linear programming.

    Minimizes the pinball (quantile) loss plus an L1 penalty on the
    coefficients by casting the problem as a linear program and solving
    it with a simplex solver.

    Parameters
    ----------
    quantile : float, default=0.5
        The quantile to be predicted, strictly between 0 and 1.
        0.5 corresponds to the median.
    alpha : float, default=1.0
        L1 regularization strength; larger values shrink the coefficients
        more aggressively.
    fit_intercept : bool, default=True
        Whether to learn an intercept term. If False the data is assumed
        to be already centered.
    lr : float, default=0.01
        Learning rate. NOTE(review): accepted for API compatibility but
        not referenced by the simplex-based `fit` below.
    max_iter : int, default=1000
        Maximum number of simplex iterations.
    max_val : float, default=1e10
        Upper bound forwarded to the LP solver's ratio-test guard.

    Attributes
    ----------
    coef_ : array-like of shape (n_features,)
        Learned feature weights (None until `fit` is called).
    intercept_ : float
        Learned bias term; 0.0 when `fit_intercept=False`.
    """

    def __init__(
        self,
        quantile=0.5,
        alpha=1.0,
        fit_intercept=True,
        lr=0.01,
        max_iter=1000,
        max_val=1e10,
    ):
        self.quantile = quantile
        self.alpha = alpha
        self.fit_intercept = fit_intercept
        self.lr = lr
        self.max_iter = max_iter
        self.max_val = max_val

        # Populated by fit().
        self.coef_ = None
        self.intercept_ = None

    def fit(self, X, y, sample_weight=None):
        """Fit the model by solving the quantile-regression LP.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data.
        y : array-like of shape (n_samples,)
            Target values.
        sample_weight : array-like of shape (n_samples,), optional
            Per-sample weights; uniform weights are used when omitted.

        Returns
        -------
        self : object
            The fitted estimator.
        """
        num_samples, num_features = X.shape
        # One extra decision variable pair for the intercept, if requested.
        num_params = num_features + (1 if self.fit_intercept else 0)

        weights = (
            jnp.ones((num_samples,)) if sample_weight is None else sample_weight
        )

        # Rescale the penalty by the total sample weight so the trade-off
        # between loss and regularization is sample-size invariant.
        penalty = jnp.sum(weights) * self.alpha

        # LP formulation (see https://stats.stackexchange.com/questions/384909/):
        #   min_x c x  s.t.  A_eq x = b_eq,  x >= 0
        # with x = (s0, s, t0, t, u, v) all non-negative, where
        #   intercept = s0 - t0,  coef = s - t,  residual = u - v,
        #   c = (0, penalty*1_p, 0, penalty*1_p, q*w, (1-q)*w),
        #   A_eq = (1_n, X, -1_n, -X, I_n, -I_n),  b_eq = y.
        cost = jnp.concatenate(
            [
                jnp.full(2 * num_params, fill_value=penalty),
                weights * self.quantile,
                weights * (1 - self.quantile),
            ]
        )

        identity = jnp.eye(num_samples)
        if self.fit_intercept:
            # The intercept slack variables are not penalized.
            cost = cost.at[0].set(0)
            cost = cost.at[num_params].set(0)
            col_ones = jnp.ones((num_samples, 1))
            constraints = jnp.concatenate(
                [col_ones, X, -col_ones, -X, identity, -identity], axis=1
            )
        else:
            constraints = jnp.concatenate([X, -X, identity, -identity], axis=1)

        solution = _linprog_simplex(
            cost, constraints, y, maxiter=self.max_iter, tol=1e-3, max_val=self.max_val
        )

        # Recover signed parameters from the positive/negative split.
        params = solution[:num_params] - solution[num_params : 2 * num_params]

        if self.fit_intercept:
            self.intercept_ = params[0]
            self.coef_ = params[1:]
        else:
            self.intercept_ = 0.0
            self.coef_ = params
        return self

    def predict(self, X):
        """Predict target values for `X` with the fitted model.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Input data.

        Returns
        -------
        y_pred : array-like of shape (n_samples,)
            Linear predictions `X @ coef_ + intercept_`.
        """
        assert (
            self.coef_ is not None and self.intercept_ is not None
        ), "Model has not been fitted yet. Please fit the model before predicting."

        n_features = len(self.coef_)
        assert X.shape[1] == n_features, (
            f"Input X must have {n_features} features, "
            f"but got {X.shape[1]} features instead."
        )

        return jnp.dot(X, self.coef_) + self.intercept_
import unittest

import jax.numpy as jnp
from sklearn.linear_model import QuantileRegressor as SklearnQuantileRegressor

import spu.spu_pb2 as spu_pb2  # type: ignore
import spu.utils.simulation as spsim
from sml.linear_model.quantile import QuantileRegressor as SmlQuantileRegressor


class UnitTests(unittest.TestCase):
    def test_quantile(self):
        """Compare SPU-simulated quantile regression with sklearn's.

        Fits on mock linear data with sklearn (plaintext) and with the SML
        implementation under a 3-party ABY3 simulator, printing both RMSEs.
        """

        def proc_wrapper(quantile, alpha, fit_intercept, lr, max_iter):
            # Build the SPU-side closure: fit on (X, y), then predict on X.
            quantile_custom = SmlQuantileRegressor(
                quantile=quantile,
                alpha=alpha,
                fit_intercept=fit_intercept,
                lr=lr,
                max_iter=max_iter,
            )

            def proc(X, y):
                quantile_custom_fit = quantile_custom.fit(X, y)
                result = quantile_custom_fit.predict(X)
                return result, quantile_custom_fit.coef_, quantile_custom_fit.intercept_

            return proc

        def generate_data():
            # Mock linear data: y = 5*x0 + 2*x1 + small Gaussian noise.
            from jax import random

            key = random.PRNGKey(42)
            key, subkey = random.split(key)
            X = random.normal(subkey, (100, 2))
            y = 5 * X[:, 0] + 2 * X[:, 1] + random.normal(key, (100,)) * 0.1
            return X, y

        # 3-party ABY3 simulator over the 64-bit ring.
        sim = spsim.Simulator.simple(
            3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64
        )

        X, y = generate_data()

        # compare with sklearn; 'highs' matches the emulation script and,
        # unlike 'revised simplex', is still supported by SciPy >= 1.11.
        quantile_sklearn = SklearnQuantileRegressor(
            quantile=0.2, alpha=0.1, fit_intercept=True, solver='highs'
        )
        quantile_sklearn_fit = quantile_sklearn.fit(X, y)
        y_pred_plain = quantile_sklearn_fit.predict(X)
        rmse_plain = jnp.sqrt(jnp.mean((y - y_pred_plain) ** 2))
        print(f"RMSE in SKlearn: {rmse_plain:.2f}")
        print(quantile_sklearn_fit.coef_)
        print(quantile_sklearn_fit.intercept_)

        # run
        # Larger max_iter can give higher accuracy, but it will take more time to run
        proc = proc_wrapper(
            quantile=0.2, alpha=0.1, fit_intercept=True, lr=0.01, max_iter=20
        )
        result, coef, intercept = spsim.sim_jax(sim, proc)(X, y)
        # (typo fixed: was `rmse_encrpted`)
        rmse_encrypted = jnp.sqrt(jnp.mean((y - result) ** 2))

        # print RMSE
        print(f"RMSE in SPU: {rmse_encrypted:.2f}")
        print(coef)
        print(intercept)


if __name__ == "__main__":
    unittest.main()
+ +import jax +import jax.numpy as jnp +from jax import jit, lax + + +def _pivot_col(T, tol=1e-5): + mask = T[-1, :-1] >= -tol + + all_masked = jnp.all(mask) + + ma = jnp.where(mask, jnp.inf, T[-1, :-1]) + min_col = jnp.argmin(ma) + + valid = ~all_masked + result = jnp.where(all_masked, 0, min_col) + + return valid, result + + +def _pivot_row(T, pivcol, phase, tol=1e-5, max_val=1e10): + if phase == 1: + k = 2 + else: + k = 1 + + mask = T[:-k, pivcol] <= tol + ma = jnp.where(mask, jnp.inf, T[:-k, pivcol]) + mb = jnp.where(mask, jnp.inf, T[:-k, -1]) + + q = jnp.where(ma >= max_val, jnp.inf, mb / ma) + + min_rows = jnp.nanargmin(q) + all_masked = jnp.all(mask) + + row = min_rows + row = jnp.where(all_masked, 0, row) + + return ~all_masked, row + + +def _apply_pivot(T, basis, pivrow, pivcol): + pivrow = jnp.int32(pivrow) + pivcol = jnp.int32(pivcol) + + basis = basis.at[pivrow].set(pivcol) + + pivrow_one_hot = jax.nn.one_hot(pivrow, T.shape[0]) + pivcol_one_hot = jax.nn.one_hot(pivcol, T.shape[1]) + + pivval = jnp.dot(pivrow_one_hot, jnp.dot(T, pivcol_one_hot)) + + updated_row = T[pivrow] / pivval + T = pivrow_one_hot[:, None] * updated_row + T * (1 - pivrow_one_hot[:, None]) + + scalar = jnp.dot(T, pivcol_one_hot).reshape(-1, 1) + + updated_T = T - scalar * T[pivrow] + + row_restore_matrix = pivrow_one_hot[:, None] * T[pivrow] + updated_T = row_restore_matrix + updated_T * (1 - pivrow_one_hot[:, None]) + + return updated_T, basis + + +def _solve_simplex( + T, + n, + basis, + maxiter=100, + tol=1e-5, + max_val=1e10, + phase=2, +): + complete = False + + num = 0 + pivcol = 0 + pivrow = 0 + while num < maxiter: + pivcol_found, pivcol = _pivot_col(T, tol) + + def cal_pivcol_found_True(T, pivcol, phase, tol, complete): + pivrow_found, pivrow = _pivot_row(T, pivcol, phase, tol, max_val) + + pivrow_isnot_found = pivrow_found == False + complete = jnp.where(pivrow_isnot_found, True, complete) + + return pivrow, complete + + pivcol_is_found = pivcol_found == True + 
pivrow_True, complete_True = cal_pivcol_found_True( + T, pivcol, phase, tol, complete + ) + + pivrow = jnp.where(pivcol_is_found, pivrow_True, 0) + + complete = jnp.where(pivcol_is_found, complete_True, complete) + + complete_is_False = complete == False + apply_T, apply_basis = _apply_pivot(T, basis, pivrow, pivcol) + T = jnp.where(complete_is_False, apply_T, T) + basis = jnp.where(complete_is_False, apply_basis, basis) + num = num + 1 + + return T, basis + + +def _linprog_simplex(c, A, b, c0=0, maxiter=300, tol=1e-5, max_val=1e10): + n, m = A.shape + + # All constraints must have b >= 0. + is_negative_constraint = jnp.less(b, 0) + A = jnp.where(is_negative_constraint[:, None], A * -1, A) + b = jnp.where(is_negative_constraint, b * -1, b) + + av = jnp.arange(n) + m + basis = av.copy() + + row_constraints = jnp.hstack((A, jnp.eye(n), b[:, jnp.newaxis])) + row_objective = jnp.hstack((c, jnp.zeros(n), c0)) + row_pseudo_objective = -row_constraints.sum(axis=0) + row_pseudo_objective = row_pseudo_objective.at[av].set(0) + T = jnp.vstack((row_constraints, row_objective, row_pseudo_objective)) + + # phase 1 + T, basis = _solve_simplex( + T, n, basis, maxiter=maxiter, tol=tol, max_val=max_val, phase=1 + ) + + T_new = T[:-1, :] + T = jnp.delete(T_new, av, 1, assume_unique_indices=True) + + # phase 2 + T, basis = _solve_simplex( + T, n, basis, maxiter=maxiter, tol=tol, max_val=max_val, phase=2 + ) + + solution = jnp.zeros(n + m) + solution = solution.at[basis[:n]].set(T[:n, -1]) + x = solution[:m] + + return x diff --git a/spu/tests/BUILD.bazel b/spu/tests/BUILD.bazel index fd037279..d9e348cb 100644 --- a/spu/tests/BUILD.bazel +++ b/spu/tests/BUILD.bazel @@ -178,7 +178,6 @@ py_binary( srcs = ["jnp_debug.py"], deps = [ "//spu:api", - "//spu/intrinsic:all_intrinsics", "//spu/utils:simulation", ], ) diff --git a/spu/tests/jnp_debug.py b/spu/tests/jnp_debug.py index ff970dbf..4757c555 100644 --- a/spu/tests/jnp_debug.py +++ b/spu/tests/jnp_debug.py @@ -15,7 +15,6 @@ import 
class JnpTestSemi2kFM128TwoParty(JnpTests.JnpTestBase):
    # Runs the shared jnp test suite under 2-party SEMI2K on the 128-bit
    # ring, with the experimental prime-based exponentiation enabled.
    def setUp(self):
        config = spu_pb2.RuntimeConfig(
            protocol=spu_pb2.ProtocolKind.SEMI2K, field=spu_pb2.FieldType.FM128
        )
        # Experimental exp-prime knobs; semantics defined by the SPU runtime
        # config (offset 13 presumably tunes the input shift -- TODO confirm
        # against the RuntimeConfig proto documentation).
        config.experimental_enable_exp_prime = True
        config.experimental_exp_prime_enable_upper_bound = True
        config.experimental_exp_prime_offset = 13
        config.experimental_exp_prime_disable_lower_bound = False
        config.fxp_exp_mode = spu_pb2.RuntimeConfig.ExpMode.EXP_PRIME
        self._sim = ppsim.Simulator(2, config)
        self._rng = np.random.RandomState()
["id", "idx"] - - with open(secret_key_path[1], 'wb') as f: - f.write( - bytes.fromhex( - "000102030405060708090a0b0c0d0e0ff0e0d0c0b0a090807060504030201000" - ) - ) - - time_stamp = time.time() - lctx_desc = link.Desc() - lctx_desc.id = str(round(time_stamp * 1000)) - - for rank in range(2): - port = get_free_port() - lctx_desc.add_party(f"id_{rank}", f"127.0.0.1:{port}") - - receiver_rank = 0 - server_rank = 1 - client_rank = 0 - # one-way PSI, just one party get result - broadcast_result = False - - precheck_input = False - server_cache_path = "server_cache.bin" - - def wrap( - rank, - offline_path, - online_path, - out_path, - preprocess_path, - ub_secret_key_path, - ): - link_ctx = link.create_brpc(lctx_desc, rank) - - if receiver_rank != link_ctx.rank: - print("===== gen cache phase =====") - print(f"{offline_path}, {server_cache_path}") - - gen_cache_config = psi.BucketPsiConfig( - psi_type=psi.PsiType.Value('ECDH_OPRF_UB_PSI_2PC_GEN_CACHE'), - input_params=psi.InputParams( - path=offline_path, - select_fields=selected_fields, - precheck=False, - ), - output_params=psi.OutputParams( - path=server_cache_path, need_sort=False - ), - bucket_size=1000000, - curve_type=psi.CurveType.CURVE_FOURQ, - ecdh_secret_key_path=ub_secret_key_path, - ) - - start = time.time() - - gen_cache_report = psi.gen_cache_for_2pc_ub_psi(gen_cache_config) - - server_source_count = wc_count(offline_path) - self.assertEqual( - gen_cache_report.original_count, server_source_count - 1 - ) - - print(f"offline cost time: {time.time() - start}") - print( - f"offline: rank: {rank} original_count: {gen_cache_report.original_count}" - ) - - print("===== transfer cache phase =====") - transfer_cache_config = psi.BucketPsiConfig( - psi_type=psi.PsiType.Value('ECDH_OPRF_UB_PSI_2PC_TRANSFER_CACHE'), - broadcast_result=broadcast_result, - receiver_rank=receiver_rank, - input_params=psi.InputParams( - path=offline_path, - select_fields=selected_fields, - precheck=precheck_input, - ), - 
bucket_size=1000000, - curve_type=psi.CurveType.CURVE_FOURQ, - ) - - if receiver_rank == link_ctx.rank: - transfer_cache_config.preprocess_path = preprocess_path - else: - transfer_cache_config.input_params.path = server_cache_path - - print( - f"rank:{link_ctx.rank} file:{transfer_cache_config.input_params.path}" - ) - - start = time.time() - transfer_cache_report = psi.bucket_psi(link_ctx, transfer_cache_config) - - if receiver_rank != link_ctx.rank: - server_source_count = wc_count(offline_path) - self.assertEqual( - transfer_cache_report.original_count, server_source_count - 1 - ) - - print(f"transfer cache cost time: {time.time() - start}") - print( - f"transfer cache: rank: {rank} original_count: {transfer_cache_report.original_count}" - ) - - print("===== shuffle online phase =====") - shuffle_online_config = psi.BucketPsiConfig( - psi_type=psi.PsiType.Value('ECDH_OPRF_UB_PSI_2PC_SHUFFLE_ONLINE'), - broadcast_result=False, - receiver_rank=server_rank, - input_params=psi.InputParams( - path=online_path, - select_fields=selected_fields, - precheck=precheck_input, - ), - output_params=psi.OutputParams(path=out_path, need_sort=False), - bucket_size=10000000, - curve_type=psi.CurveType.CURVE_FOURQ, - ) - - if client_rank == link_ctx.rank: - shuffle_online_config.preprocess_path = preprocess_path - else: - shuffle_online_config.preprocess_path = server_cache_path - shuffle_online_config.ecdh_secret_key_path = ub_secret_key_path - - print( - f"rank:{link_ctx.rank} file:{shuffle_online_config.input_params.path}" - ) - - start = time.time() - shuffle_online_report = psi.bucket_psi(link_ctx, shuffle_online_config) - - if server_rank == link_ctx.rank: - server_source_count = wc_count(offline_path) - self.assertEqual( - shuffle_online_report.original_count, server_source_count - 1 - ) - - print(f"shuffle online cost time: {time.time() - start}") - print( - f"shuffle online: rank: {rank} original_count: {shuffle_online_report.original_count}" - ) - print( - f"shuffle 
online: rank: {rank} intersection: {shuffle_online_report.intersection_count}" - ) - - print("===== offline phase =====") - offline_config = psi.BucketPsiConfig( - psi_type=psi.PsiType.Value('ECDH_OPRF_UB_PSI_2PC_OFFLINE'), - broadcast_result=broadcast_result, - receiver_rank=client_rank, - input_params=psi.InputParams( - path=offline_path, - select_fields=selected_fields, - precheck=precheck_input, - ), - output_params=psi.OutputParams(path="fake.out", need_sort=False), - bucket_size=1000000, - curve_type=psi.CurveType.CURVE_FOURQ, - ) - - if client_rank == link_ctx.rank: - offline_config.preprocess_path = preprocess_path - offline_config.input_params.path = "dummy.csv" - else: - offline_config.ecdh_secret_key_path = ub_secret_key_path - - start = time.time() - offline_report = psi.bucket_psi(link_ctx, offline_config) - - if receiver_rank != link_ctx.rank: - server_source_count = wc_count(offline_path) - self.assertEqual(offline_report.original_count, server_source_count - 1) - - print(f"offline cost time: {time.time() - start}") - print( - f"offline: rank: {rank} original_count: {offline_report.original_count}" - ) - print( - f"offline: rank: {rank} intersection_count: {offline_report.intersection_count}" - ) - - print("===== online phase =====") - online_config = psi.BucketPsiConfig( - psi_type=psi.PsiType.Value('ECDH_OPRF_UB_PSI_2PC_ONLINE'), - broadcast_result=broadcast_result, - receiver_rank=client_rank, - input_params=psi.InputParams( - path=online_path, - select_fields=selected_fields, - precheck=precheck_input, - ), - output_params=psi.OutputParams(path=out_path, need_sort=False), - bucket_size=300000, - curve_type=psi.CurveType.CURVE_FOURQ, - ) - - if receiver_rank == link_ctx.rank: - online_config.preprocess_path = preprocess_path - else: - online_config.ecdh_secret_key_path = ub_secret_key_path - online_config.input_params.path = "dummy.csv" - - start = time.time() - report_online = psi.bucket_psi(link_ctx, online_config) - - if receiver_rank == 
link_ctx.rank: - client_source_count = wc_count(online_path) - self.assertEqual(report_online.original_count, client_source_count - 1) - - print(f"online cost time: {time.time() - start}") - print(f"online: rank:{rank} original_count: {report_online.original_count}") - print(f"intersection_count: {report_online.intersection_count}") - - link_ctx.stop_link() - - # launch with multiprocess - jobs = [ - multiprocess.Process( - target=wrap, - args=( - rank, - offline_path[rank], - online_path[rank], - outputs[rank], - preprocess_path[rank], - secret_key_path[rank], - ), - ) - for rank in range(2) - ] - [job.start() for job in jobs] - for job in jobs: - job.join() - self.assertEqual(job.exitcode, 0) - if __name__ == '__main__': unittest.main() diff --git a/spu/tests/ub_psi_test.py b/spu/tests/ub_psi_test.py index 2728e1b6..2a5323e1 100644 --- a/spu/tests/ub_psi_test.py +++ b/spu/tests/ub_psi_test.py @@ -43,12 +43,12 @@ def test_ub_psi(self): "role": "ROLE_SERVER", "cache_path": "{self.tempdir_.name}/spu_test_ub_psi_server_cache", "input_config": {{ + "type" : "IO_TYPE_FILE_CSV", "path": "spu/tests/data/alice.csv" }}, "keys": [ "id" - ], - "server_secret_key_path": "{self.tempdir_.name}/spu_test_ub_psi_server_secret_key.key" + ] }} ''' @@ -60,15 +60,6 @@ def test_ub_psi(self): }} ''' - with open( - f"{self.tempdir_.name}/spu_test_ub_psi_server_secret_key.key", 'wb' - ) as f: - f.write( - bytes.fromhex( - "000102030405060708090a0b0c0d0e0ff0e0d0c0b0a090807060504030201000" - ) - ) - configs = [ json_format.ParseDict(json.loads(server_offline_config), psi.UbPsiConfig()), json_format.ParseDict(json.loads(client_offline_config), psi.UbPsiConfig()), @@ -95,8 +86,10 @@ def wrap(rank, link_desc, configs): {{ "mode": "MODE_ONLINE", "role": "ROLE_SERVER", - "server_secret_key_path": "{self.tempdir_.name}/spu_test_ub_psi_server_secret_key.key", - "cache_path": "{self.tempdir_.name}/spu_test_ub_psi_server_cache" + "cache_path": "{self.tempdir_.name}/spu_test_ub_psi_server_cache", + 
"output_config": {{ + "type" : "IO_TYPE_FILE_CSV" + }} }} ''' @@ -105,9 +98,11 @@ def wrap(rank, link_desc, configs): "mode": "MODE_ONLINE", "role": "ROLE_CLIENT", "input_config": {{ + "type" : "IO_TYPE_FILE_CSV", "path": "spu/tests/data/bob.csv" }}, "output_config": {{ + "type" : "IO_TYPE_FILE_CSV", "path": "{self.tempdir_.name}/spu_test_ubpsi_bob_psi_ouput.csv" }}, "keys": [ diff --git a/spu/utils/simulation.py b/spu/utils/simulation.py index 592ccaff..6a6c220a 100644 --- a/spu/utils/simulation.py +++ b/spu/utils/simulation.py @@ -22,6 +22,7 @@ import jax.extend.linear_util as jax_lu except ImportError: import jax.linear_util as jax_lu # fallback + import jax.numpy as jnp import numpy as np from jax._src import api_util as japi_util @@ -69,6 +70,7 @@ def simple(cls, wsize: int, prot: spu_pb2.ProtocolKind, field: spu_pb2.FieldType A SPU Simulator """ config = spu_pb2.RuntimeConfig(protocol=prot, field=field) + if prot == spu_pb2.ProtocolKind.CHEETAH: # config.cheetah_2pc_config.enable_mul_lsb_error = True # config.cheetah_2pc_config.ot_kind = spu_pb2.CheetahOtKind.YACL_Softspoken