diff --git a/libspu/mpc/ab_api.cc b/libspu/mpc/ab_api.cc index f298b36f..8dcacff4 100644 --- a/libspu/mpc/ab_api.cc +++ b/libspu/mpc/ab_api.cc @@ -118,6 +118,10 @@ Value mul_aa(SPUContext* ctx, const Value& x, const Value& y) { TILED_DISPATCH(ctx, x, y); } +Value mul_aaa(SPUContext* ctx, const Value& x, const Value& y, const Value& z) { + TILED_DISPATCH(ctx, x, y, z); +} + Value mul_aa_p(SPUContext* ctx, const Value& x, const Value& y) { TILED_DISPATCH(ctx, x, y); } diff --git a/libspu/mpc/ab_api.h b/libspu/mpc/ab_api.h index 8bb954eb..3a39dab2 100644 --- a/libspu/mpc/ab_api.h +++ b/libspu/mpc/ab_api.h @@ -41,6 +41,7 @@ OptionalAPI add_av(SPUContext* ctx, const Value& x, const Value& y); Value mul_ap(SPUContext* ctx, const Value& x, const Value& y); Value mul_aa(SPUContext* ctx, const Value& x, const Value& y); +Value mul_aaa(SPUContext* ctx, const Value& x, const Value& y, const Value& z); Value mul_aa_p(SPUContext* ctx, const Value& x, const Value& y); Value square_a(SPUContext* ctx, const Value& x); diff --git a/libspu/mpc/ab_api_test.cc b/libspu/mpc/ab_api_test.cc index aaee85d3..cb36e44f 100644 --- a/libspu/mpc/ab_api_test.cc +++ b/libspu/mpc/ab_api_test.cc @@ -329,6 +329,37 @@ TEST_P(ArithmeticTest, MulAA) { }); } +TEST_P(ArithmeticTest, MulAAA) { + const auto factory = std::get<0>(GetParam()); + const RuntimeConfig& conf = std::get<1>(GetParam()); + const size_t npc = std::get<2>(GetParam()); + + utils::simulate(npc, [&](const std::shared_ptr& lctx) { + auto sctx = factory(conf, lctx); + + auto p0 = rand_p(sctx.get(), kShape); + auto p1 = rand_p(sctx.get(), kShape); + auto p2 = rand_p(sctx.get(), kShape); + + auto v0 = p2v(sctx.get(), p0, 0); + auto v1 = p2v(sctx.get(), p1, 1); + auto v2 = p2v(sctx.get(), p2, 2); + + auto a0 = v2a(sctx.get(), v0); + auto a1 = v2a(sctx.get(), v1); + auto a2 = v2a(sctx.get(), v2); + + auto prod = mul_aaa(sctx.get(), a0, a1, a2); + auto p_prod = a2p(sctx.get(), prod); + + auto s = mul_pp(sctx.get(), p0, p1); + auto s_prime = mul_pp(sctx.get(), s, p2); + + /* THEN */ + EXPECT_VALUE_EQ(s_prime, p_prod); + }); +} + TEST_P(ArithmeticTest, MulAAP) { const auto factory = std::get<0>(GetParam()); const RuntimeConfig& conf = std::get<1>(GetParam()); diff --git a/libspu/mpc/kernel.cc b/libspu/mpc/kernel.cc index c43a7d17..1e9e8d96 100644 --- a/libspu/mpc/kernel.cc +++ b/libspu/mpc/kernel.cc @@ -66,6 +66,21 @@ void BinaryKernel::evaluate(KernelEvalContext* ctx) const { ctx->pushOutput(WrapValue(z)); } +void TernaryKnernel::evaluate(KernelEvalContext* ctx) const { + const auto& x = ctx->getParam(0); + const auto& y = ctx->getParam(1); + const auto& z = ctx->getParam(2); + + SPU_ENFORCE(x.shape() == y.shape(), "shape mismatch {} {}", x.shape(), + y.shape()); + SPU_ENFORCE(x.shape() == z.shape(), "shape mismatch {} {}", x.shape(), + z.shape()); + + auto out = proc(ctx, UnwrapValue(x), UnwrapValue(y), UnwrapValue(z)); + + ctx->pushOutput(WrapValue(out)); +} + void MatmulKernel::evaluate(KernelEvalContext* ctx) const { const auto& lhs = ctx->getParam(0); const auto& rhs = ctx->getParam(1); diff --git a/libspu/mpc/kernel.h b/libspu/mpc/kernel.h index 8db9b546..a9010029 100644 --- a/libspu/mpc/kernel.h +++ b/libspu/mpc/kernel.h @@ -52,6 +52,13 @@ class BinaryKernel : public Kernel { const NdArrayRef& rhs) const = 0; }; +class TernaryKnernel : public Kernel { + public: + void evaluate(KernelEvalContext* ctx) const override; + virtual NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& x, + const NdArrayRef& y, const NdArrayRef& z) const = 0; +}; + class MatmulKernel : public Kernel { public: void evaluate(KernelEvalContext* ctx) const override; diff --git a/libspu/mpc/shamir/arithmetic.cc b/libspu/mpc/shamir/arithmetic.cc index fb137584..2a6e3de7 100644 --- a/libspu/mpc/shamir/arithmetic.cc +++ b/libspu/mpc/shamir/arithmetic.cc @@ -36,6 +36,10 @@ NdArrayRef wrap_a2p(SPUContext* ctx, const NdArrayRef& x) { return UnwrapValue(a2p(ctx, WrapValue(x))); } +NdArrayRef wrap_negate_a(SPUContext* ctx, const NdArrayRef& x) { + return UnwrapValue(negate_a(ctx, WrapValue(x))); +} + // Generate zero sharings of degree = threshold NdArrayRef gen_zero_shares(KernelEvalContext* ctx, int64_t numel, int64_t threshold) { diff --git a/libspu/mpc/shamir/arithmetic.h b/libspu/mpc/shamir/arithmetic.h index 608f810d..25113242 100644 --- a/libspu/mpc/shamir/arithmetic.h +++ b/libspu/mpc/shamir/arithmetic.h @@ -147,6 +147,18 @@ class MulAA : public BinaryKernel { const NdArrayRef& rhs) const override; }; +class MulAAA : public TernaryKnernel { + public: + static constexpr const char* kBindName() { return "mul_aaa"; } + + ce::CExpr latency() const override { return ce::Const(2); } + + ce::CExpr comm() const override { return ce::K() * 4; } + + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& x, + const NdArrayRef& y, const NdArrayRef& z) const override; +}; + class MulAAP : public BinaryKernel { public: static constexpr const char* kBindName() { return "mul_aa_p"; }