From 665e0739f58f1e7ff422c96e7d89659d1da88b23 Mon Sep 17 00:00:00 2001
From: fionser
Date: Sun, 10 Dec 2023 21:37:26 +0800
Subject: [PATCH] [2PC] Add dispatch from DotGeneral to BatchMatmul and add
 some documentation

Currently, DotGeneral is decomposed into separate matmul calls, but for some
back-ends (e.g., Cheetah) we can do better.
---
 libspu/kernel/hal/fxp_base.cc | 12 ++
 libspu/kernel/hal/fxp_base.h | 3 +
 libspu/kernel/hal/integer.cc | 11 ++
 libspu/kernel/hal/integer.h | 3 +
 libspu/kernel/hal/polymorphic.cc | 22 ++++
 libspu/kernel/hal/polymorphic.h | 6 +
 libspu/kernel/hal/prot_wrapper.cc | 9 ++
 libspu/kernel/hal/prot_wrapper.h | 4 +
 libspu/kernel/hal/ring.cc | 15 +++
 libspu/kernel/hal/ring.h | 3 +
 libspu/kernel/hlo/basic_binary.cc | 5 +
 libspu/mpc/ab_api.cc | 12 ++
 libspu/mpc/ab_api.h | 4 +
 libspu/mpc/ab_api_test.cc | 44 +++++++
 libspu/mpc/api.cc | 24 ++++
 libspu/mpc/api.h | 7 ++
 libspu/mpc/cheetah/arith/cheetah_dot.h | 9 +-
 libspu/mpc/cheetah/arith/cheetah_mul.h | 6 +-
 libspu/mpc/cheetah/arithmetic.cc | 165 +++++++++++++++++++++++++
 libspu/mpc/cheetah/arithmetic.h | 38 ++++++
 libspu/mpc/cheetah/protocol.cc | 4 +-
 libspu/mpc/cheetah/rlwe/packlwes.h | 3 +
 libspu/mpc/common/pv2k.cc | 56 +++++++
 23 files changed, 459 insertions(+), 6 deletions(-)

diff --git a/libspu/kernel/hal/fxp_base.cc index 8b45c125..a5d753c5 100644 --- a/libspu/kernel/hal/fxp_base.cc +++ b/libspu/kernel/hal/fxp_base.cc @@ -270,6 +270,18 @@ Value f_mmul(SPUContext* ctx, const Value& x, const Value& y) { return _trunc(ctx, _mmul(ctx, x, y)).setDtype(x.dtype()); } +std::optional<Value> f_batch_mmul(SPUContext* ctx, const Value& x, + const Value& y) { + SPU_TRACE_HAL_LEAF(ctx, x, y); + + SPU_ENFORCE(x.isFxp() && y.isFxp() && x.dtype() == y.dtype()); + auto ret = _batch_mmul(ctx, x, y); + if (not ret.has_value()) { + return NotAvailable; + } + return _trunc(ctx, *ret).setDtype(x.dtype()); +} + Value f_conv2d(SPUContext* ctx, const Value& x, const Value& y, const Strides& window_strides) { SPU_TRACE_HAL_LEAF(ctx, x, y, window_strides);
diff --git a/libspu/kernel/hal/fxp_base.h index df92c2d1..1b61f0c8 100644 --- a/libspu/kernel/hal/fxp_base.h +++ b/libspu/kernel/hal/fxp_base.h @@ -59,6 +59,9 @@ Value f_mul(SPUContext* ctx, const Value& x, const Value& y, Value f_mmul(SPUContext* ctx, const Value& x, const Value& y); +std::optional<Value> f_batch_mmul(SPUContext* ctx, const Value& x, + const Value& y); + Value f_conv2d(SPUContext* ctx, const Value& x, const Value& y, const Strides& window_strides);
diff --git a/libspu/kernel/hal/integer.cc index 696f1aec..fcf825e6 100644 --- a/libspu/kernel/hal/integer.cc +++ b/libspu/kernel/hal/integer.cc @@ -98,4 +98,15 @@ Value i_tensordot(SPUContext* ctx, const Value& x, const Value& y, return _tensordot(ctx, x, y, ix, iy).setDtype(x.dtype()); } +std::optional<Value> i_batch_mmul(SPUContext* ctx, const Value& x, + const Value& y) { + SPU_TRACE_HAL_LEAF(ctx, x, y); + ENSURE_INT_AND_DTYPE_MATCH(x, y); + auto ret = _batch_mmul(ctx, x, y); + if (ret.has_value()) { + ret->setDtype(x.dtype()); + } + return ret; +} + } // namespace spu::kernel::hal
diff --git a/libspu/kernel/hal/integer.h index 12461365..7267c5b1 100644 --- a/libspu/kernel/hal/integer.h +++ b/libspu/kernel/hal/integer.h @@ -40,6 +40,9 @@ Value i_mul(SPUContext* ctx, const Value& x, const Value& y); Value i_mmul(SPUContext* ctx, const Value& x, const Value& y); +std::optional<Value> i_batch_mmul(SPUContext* ctx, const Value& x,
const Value& y); + Value i_tensordot(SPUContext* ctx, const Value& x, const Value& y, const Index& ix, const Index& iy); diff --git a/libspu/kernel/hal/polymorphic.cc b/libspu/kernel/hal/polymorphic.cc index 2542eb59..7b00e7f8 100644 --- a/libspu/kernel/hal/polymorphic.cc +++ b/libspu/kernel/hal/polymorphic.cc @@ -126,6 +126,28 @@ Value matmul(SPUContext* ctx, const Value& x, const Value& y) { return dtypeBinaryDispatch("mmul", f_mmul, i_mmul, ctx, x, y); } +std::optional batch_matmul(SPUContext* ctx, const Value& x, + const Value& y) { + SPU_TRACE_HAL_DISP(ctx, x, y); + if (isCrossIntFxp(x, y)) { + auto ret = _batch_mmul(ctx, x, y); + if (ret.has_value()) { + auto new_dtype = x.isFxp() ? x.dtype() : y.dtype(); + ret->setDtype(new_dtype); + } + return ret; + } + + if (x.isInt() && y.isInt()) { + return i_batch_mmul(ctx, x, y); + } + + if (x.isFxp() && y.isFxp()) { + return f_batch_mmul(ctx, x, y); + } + return NotAvailable; +} + Value tensordot(SPUContext* ctx, const Value& x, const Value& y, const Index& ix, const Index& iy) { SPU_TRACE_HAL_DISP(ctx, x, y, ix, iy); diff --git a/libspu/kernel/hal/polymorphic.h b/libspu/kernel/hal/polymorphic.h index e9e82d0b..10cca7d9 100644 --- a/libspu/kernel/hal/polymorphic.h +++ b/libspu/kernel/hal/polymorphic.h @@ -55,6 +55,12 @@ Value bitwise_not(SPUContext* ctx, const Value& in); // @param y, the second parameter Value matmul(SPUContext* ctx, const Value& x, const Value& y); +/// batch matrix production operator +// @param x, the first parameter +// @param y, the second parameter +std::optional batch_matmul(SPUContext* ctx, const Value& x, + const Value& y); + /// matrix production operator // @param x, the first parameter // @param y, the second parameter diff --git a/libspu/kernel/hal/prot_wrapper.cc b/libspu/kernel/hal/prot_wrapper.cc index 45942b88..369d0df4 100644 --- a/libspu/kernel/hal/prot_wrapper.cc +++ b/libspu/kernel/hal/prot_wrapper.cc @@ -60,6 +60,13 @@ namespace spu::kernel::hal { return ret; \ } +#define MAP_OPTIONAL_MMUL_OP(NAME) \ + std::optional _##NAME(SPUContext* ctx, const Value& x, \ + const Value& y) { \ + SPU_TRACE_HAL_DISP(ctx, x, y); \ + return mpc::NAME(ctx, x, y); \ + } + Type _common_type_s(SPUContext* ctx, const Type& a, const Type& b) { SPU_TRACE_HAL_DISP(ctx, a, b); return mpc::common_type_s(ctx, a, b); @@ -194,6 +201,8 @@ MAP_MMUL_OP(mmul_ss) MAP_MMUL_OP(mmul_sv) MAP_MMUL_OP(mmul_vp) MAP_MMUL_OP(mmul_vv) +MAP_OPTIONAL_MMUL_OP(batch_mmul_ss) +MAP_OPTIONAL_MMUL_OP(batch_mmul_sv) #define MAP_OPTIONAL_BINARY_OP(NAME) \ std::optional _##NAME(SPUContext* ctx, const Value& x, \ diff --git a/libspu/kernel/hal/prot_wrapper.h b/libspu/kernel/hal/prot_wrapper.h index 2d0a62f4..b1eafb7c 100644 --- a/libspu/kernel/hal/prot_wrapper.h +++ b/libspu/kernel/hal/prot_wrapper.h @@ -88,6 +88,10 @@ Value _mmul_ss(SPUContext* ctx, const Value& x, const Value& y); Value _mmul_vv(SPUContext* ctx, const Value& x, const Value& y); Value _mmul_vp(SPUContext* ctx, const Value& x, const Value& y); Value _mmul_sv(SPUContext* ctx, const Value& x, const Value& y); +std::optional _batch_mmul_ss(SPUContext* ctx, const Value& x, + const Value& y); +std::optional _batch_mmul_sv(SPUContext* ctx, const Value& x, + const Value& y); Value _conv2d_ss(SPUContext* ctx, const Value& input, const Value& kernel, const Strides& strides); diff --git a/libspu/kernel/hal/ring.cc b/libspu/kernel/hal/ring.cc index 528d8719..99c1189e 100644 --- a/libspu/kernel/hal/ring.cc +++ b/libspu/kernel/hal/ring.cc @@ -177,6 +177,16 @@ static Value _mmul_impl(SPUContext* ctx, const 
Value& x, const Value& y) { } }; +static OptionalAPI _batch_mmul_impl(SPUContext* ctx, const Value& x, + const Value& y) { + if (x.isSecret() && y.isSecret()) { // SS + return _batch_mmul_ss(ctx, x, y); + } else if (x.isSecret() && y.isPrivate()) { // SV + return _batch_mmul_sv(ctx, x, y); + } + return NotAvailable; +} + Value _trunc(SPUContext* ctx, const Value& x, size_t bits, SignType sign) { SPU_TRACE_HAL_LEAF(ctx, x, bits); bits = (bits == 0) ? ctx->getFxpBits() : bits; @@ -277,6 +287,11 @@ Value _sub(SPUContext* ctx, const Value& x, const Value& y) { return res; } +OptionalAPI _batch_mmul(SPUContext* ctx, const Value& x, + const Value& y) { + return _batch_mmul_impl(ctx, x, y); +} + Value _mmul(SPUContext* ctx, const Value& x, const Value& y) { auto [m, n, k] = deduceMmulArgs(x.shape(), y.shape()); diff --git a/libspu/kernel/hal/ring.h b/libspu/kernel/hal/ring.h index da901655..9d742982 100644 --- a/libspu/kernel/hal/ring.h +++ b/libspu/kernel/hal/ring.h @@ -51,6 +51,9 @@ Value _mul(SPUContext* ctx, const Value& x, const Value& y); Value _mmul(SPUContext* ctx, const Value& x, const Value& y); +std::optional _batch_mmul(SPUContext* ctx, const Value& x, + const Value& y); + Value _conv2d(SPUContext* ctx, const Value& x, const Value& y, const Strides& strides); diff --git a/libspu/kernel/hlo/basic_binary.cc b/libspu/kernel/hlo/basic_binary.cc index ae59810b..306c08f2 100644 --- a/libspu/kernel/hlo/basic_binary.cc +++ b/libspu/kernel/hlo/basic_binary.cc @@ -80,6 +80,11 @@ spu::Value DotGeneral(SPUContext *ctx, const spu::Value &lhs, const spu::Value &rhs) { int64_t num_batch = lhs.shape()[0]; + auto ret = kernel::hal::batch_matmul(ctx, lhs, rhs); + if (ret.has_value()) { + return *ret; + } + std::vector results(num_batch); Index lhs_slice_begin(3, 0); Index lhs_slice_end(lhs.shape().begin(), lhs.shape().end()); diff --git a/libspu/mpc/ab_api.cc b/libspu/mpc/ab_api.cc index 432b82e6..a439fc1b 100644 --- a/libspu/mpc/ab_api.cc +++ b/libspu/mpc/ab_api.cc @@ -147,6 +147,18 @@ OptionalAPI mmul_av(SPUContext* ctx, const Value& x, const Value& y) { return NotAvailable; } +OptionalAPI batch_mmul_aa(SPUContext* ctx, const Value& x, + const Value& y) { + TRY_DISPATCH(ctx, x, y); + return NotAvailable; +} + +OptionalAPI batch_mmul_av(SPUContext* ctx, const Value& x, + const Value& y) { + TRY_DISPATCH(ctx, x, y); + return NotAvailable; +} + Type common_type_b(SPUContext* ctx, const Type& a, const Type& b) { SPU_TRACE_MPC_LEAF(ctx, a, b); return dynDispatch(ctx, __func__, a, b); diff --git a/libspu/mpc/ab_api.h b/libspu/mpc/ab_api.h index e0957a79..3244366b 100644 --- a/libspu/mpc/ab_api.h +++ b/libspu/mpc/ab_api.h @@ -50,6 +50,10 @@ Value trunc_a(SPUContext* ctx, const Value& x, size_t nbits, SignType sign); Value mmul_ap(SPUContext* ctx, const Value& x, const Value& y); Value mmul_aa(SPUContext* ctx, const Value& x, const Value& y); OptionalAPI mmul_av(SPUContext* ctx, const Value& x, const Value& y); +OptionalAPI batch_mmul_aa(SPUContext* ctx, const Value& x, + const Value& y); +OptionalAPI batch_mmul_av(SPUContext* ctx, const Value& x, + const Value& y); Type common_type_b(SPUContext* ctx, const Type& a, const Type& b); Value cast_type_b(SPUContext* ctx, const Value& a, const Type& to_type); diff --git a/libspu/mpc/ab_api_test.cc b/libspu/mpc/ab_api_test.cc index 6f58eda8..71d7321f 100644 --- a/libspu/mpc/ab_api_test.cc +++ b/libspu/mpc/ab_api_test.cc @@ -274,6 +274,50 @@ TEST_P(ArithmeticTest, MatMulAA) { }); } +TEST_P(ArithmeticTest, BatchMatMulAA) { + const auto factory = 
std::get<0>(GetParam()); + const RuntimeConfig& conf = std::get<1>(GetParam()); + const size_t npc = std::get<2>(GetParam()); + + const int64_t B = 3; + const int64_t M = 4; + const int64_t K = 5; + const int64_t N = 6; + const Shape shape_A = {B, M, K}; + const Shape shape_B = {B, K, N}; + const Shape shape_C = {B, M, N}; + + utils::simulate(npc, [&](const std::shared_ptr& lctx) { + auto obj = factory(conf, lctx); + + /* GIVEN */ + auto p0 = rand_p(obj.get(), shape_A); + auto p1 = rand_p(obj.get(), shape_B); + auto a0 = p2a(obj.get(), p0); + auto a1 = p2a(obj.get(), p1); + + /* WHEN */ + auto prev = obj->prot()->getState()->getStats(); + auto tmp = batch_mmul_aa(obj.get(), a0, a1); + if (not tmp.has_value()) { + return; + } + auto cost = obj->prot()->getState()->getStats() - prev; + + auto r_aa = a2p(obj.get(), *tmp); + auto r_pp = batch_mmul_pp(obj.get(), p0, p1); + + /* THEN */ + EXPECT_VALUE_EQ(r_aa, r_pp); + ce::Params params = {{"K", SizeOf(conf.field()) * 8}, + {"N", npc}, + {"m", M}, + {"n", N}, + {"k", K}}; + EXPECT_TRUE(verifyCost(obj->prot()->getKernel("mmul_aa"), "mmul_aa", params, + cost, 1)); + }); +} TEST_P(ArithmeticTest, NotA) { const auto factory = std::get<0>(GetParam()); const RuntimeConfig& conf = std::get<1>(GetParam()); diff --git a/libspu/mpc/api.cc b/libspu/mpc/api.cc index 77d33c46..098b2fcc 100644 --- a/libspu/mpc/api.cc +++ b/libspu/mpc/api.cc @@ -495,6 +495,30 @@ Value mmul_pp(SPUContext* ctx, const Value& x, const Value& y) { FORCE_DISPATCH(ctx, x, y); } +// NOTE(lwj): LHS.shape: B x m x k, and RHS.shape: B x k x n +// Out shape is B x m x n +Value batch_mmul_pp(SPUContext* ctx, const Value& x, const Value& y) { + FORCE_DISPATCH(ctx, x, y); +} + +OptionalAPI batch_mmul_ss(SPUContext* ctx, const Value& x, + const Value& y) { + SPU_TRACE_MPC_DISP(ctx, x, y); + if (auto ret = batch_mmul_aa(ctx, x, y)) { + return ret.value(); + } + return NotAvailable; +} + +OptionalAPI batch_mmul_sv(SPUContext* ctx, const Value& x, + const Value& y) { + SPU_TRACE_MPC_DISP(ctx, x, y); + if (auto ret = batch_mmul_av(ctx, x, y)) { + return ret.value(); + } + return NotAvailable; +} + ////////////////////////////////////////////////////////////////////////////// Value and_ss(SPUContext* ctx, const Value& x, const Value& y) { diff --git a/libspu/mpc/api.h b/libspu/mpc/api.h index b8842c89..29845897 100644 --- a/libspu/mpc/api.h +++ b/libspu/mpc/api.h @@ -125,6 +125,13 @@ Value mmul_sp(SPUContext* ctx, const Value& x, const Value& y); Value mmul_vv(SPUContext* ctx, const Value& x, const Value& y); Value mmul_vp(SPUContext* ctx, const Value& x, const Value& y); Value mmul_pp(SPUContext* ctx, const Value& x, const Value& y); +// NOTE(lwj): LHS.shape: B x m x k, and RHS.shape: B x k x n +// Out shape is B x m x n +Value batch_mmul_pp(SPUContext* ctx, const Value& x, const Value& y); +OptionalAPI batch_mmul_ss(SPUContext* ctx, const Value& x, + const Value& y); +OptionalAPI batch_mmul_sv(SPUContext* ctx, const Value& x, + const Value& y); Value and_ss(SPUContext* ctx, const Value& x, const Value& y); Value and_sv(SPUContext* ctx, const Value& x, const Value& y); diff --git a/libspu/mpc/cheetah/arith/cheetah_dot.h b/libspu/mpc/cheetah/arith/cheetah_dot.h index cf9bf36d..30b874cd 100644 --- a/libspu/mpc/cheetah/arith/cheetah_dot.h +++ b/libspu/mpc/cheetah/arith/cheetah_dot.h @@ -23,10 +23,13 @@ namespace spu::mpc::cheetah { +// clang-format off // Implementation for Dot. -// Ref: Huang et al. 
"Cheetah: Lean and Fast Secure Two-Party Deep Neural -// Network Inference" -// https://eprint.iacr.org/2022/207.pdf +// Ref: Lu et al. "BumbleBee: Secure Two-party Inference Framework for Large Transformers" +// https://eprint.iacr.org/2023/1678 +// Ref: Huang et al. "Cheetah: Lean and Fast Secure Two-Party Deep Neural Network Inference" +// https://eprint.iacr.org/2022/207.pdf +// clang-format on class CheetahDot { public: explicit CheetahDot(const std::shared_ptr& lctx, diff --git a/libspu/mpc/cheetah/arith/cheetah_mul.h b/libspu/mpc/cheetah/arith/cheetah_mul.h index 03ccd65d..7f400485 100644 --- a/libspu/mpc/cheetah/arith/cheetah_mul.h +++ b/libspu/mpc/cheetah/arith/cheetah_mul.h @@ -22,9 +22,11 @@ namespace spu::mpc::cheetah { +// clang-format off // Implementation for Mul -// Ref: Rathee et al. "Improved Multiplication Triple Generation over Rings -// via RLWE-based AHE" +// Ref: Lu et al. "BumbleBee: Secure Two-party Inference Framework for Large Transformers" +// https://eprint.iacr.org/2023/1678 +// Ref: Rathee et al. "Improved Multiplication Triple Generation over Rings via RLWE-based AHE" // https://eprint.iacr.org/2019/577.pdf class CheetahMul { public: diff --git a/libspu/mpc/cheetah/arithmetic.cc b/libspu/mpc/cheetah/arithmetic.cc index 9dac9189..9a9bd800 100644 --- a/libspu/mpc/cheetah/arithmetic.cc +++ b/libspu/mpc/cheetah/arithmetic.cc @@ -340,4 +340,169 @@ NdArrayRef MatMulAA::proc(KernelEvalContext* ctx, const NdArrayRef& x, return ring_add(ret, task.get()).as(x.eltype()); } +// LHS is a share type (A); RHS is a private type (V) +NdArrayRef MatMulAV::proc(KernelEvalContext* ctx, const NdArrayRef& x, + const NdArrayRef& y) const { + if (0 == x.numel() || 0 == y.numel()) { + return NdArrayRef(x.eltype(), {x.shape()[0], y.shape()[1]}); + } + + auto* comm = ctx->getState(); + auto* dot_prot = ctx->getState()->get(); + const auto* priv_type = y.eltype().as(); + SPU_ENFORCE(priv_type != nullptr, "RHS should be a private type"); + const int rank = comm->getRank(); + const int owner = priv_type->owner(); + + const Shape3D dim3 = {x.shape()[0], x.shape()[1], y.shape()[1]}; + NdArrayRef out; + if (rank == owner) { + out = dot_prot->DotOLE(y, dim3, false); + auto tmp = ring_mmul(x, y); + ring_add_(out, tmp); + } else { + out = dot_prot->DotOLE(x, dim3, true); + } + return out.as(x.eltype()); +} + +void BatchMatMulAA::evaluate(KernelEvalContext* ctx) const { + // NOTE(lwj): overwrite the shape check in the MatmulKernel + const auto& lhs = ctx->getParam(0); + const auto& rhs = ctx->getParam(1); + const auto& lhs_shape = lhs.shape(); + const auto& rhs_shape = rhs.shape(); + SPU_ENFORCE(lhs_shape.ndim() == rhs_shape.ndim(), + "ndim mismatch: lhs={}, rhs={}", lhs_shape, rhs_shape); + SPU_ENFORCE(lhs_shape[0] == rhs_shape[0], "batch mismatch: lhs={}, rhs={}", + lhs_shape, rhs_shape); + SPU_ENFORCE(lhs_shape[2] == rhs_shape[1], "shape mismatch: lhs={}, rhs={}", + lhs_shape, rhs_shape); + ctx->setOutput(WrapValue(proc(ctx, lhs.data(), rhs.data()))); +} + +NdArrayRef BatchMatMulAA::proc(KernelEvalContext* ctx, const NdArrayRef& x, + const NdArrayRef& y) const { + if (0 == x.numel() || 0 == y.numel()) { + return NdArrayRef(x.eltype(), {x.shape()[0], y.shape()[1]}); + } + + auto* comm = ctx->getState(); + auto* dot_prot = ctx->getState()->get(); + const int rank = comm->getRank(); + + // (x0 + x1) * (y0 + y1) + // Compute the cross terms homomorphically + const Shape4D dim4 = {x.shape()[0], x.shape()[1], x.shape()[2], y.shape()[2]}; + + auto* conn = comm->lctx().get(); + auto dupx = 
ctx->getState()->duplx(); + std::future task = std::async(std::launch::async, [&] { + // Compute x0*y1 + if (rank == 0) { + return dot_prot->BatchDotOLE(x, dupx.get(), dim4, true); + } else { + return dot_prot->BatchDotOLE(y, dupx.get(), dim4, false); + } + }); + + NdArrayRef x1y0; + if (rank == 0) { + x1y0 = dot_prot->BatchDotOLE(y, conn, dim4, false); + } else { + x1y0 = dot_prot->BatchDotOLE(x, conn, dim4, true); + } + + // local batch mmul + const Strides strides(x.shape().size(), 1); + Index lhs_slice_end(x.shape().begin(), x.shape().end()); + Index rhs_slice_end(y.shape().begin(), y.shape().end()); + Index lhs_slice_begin(3, 0); + Index rhs_slice_begin(3, 0); + NdArrayRef out(x.eltype(), {dim4[0], dim4[1], dim4[3]}); + for (int64_t batch = 0; batch < dim4[0]; ++batch) { + lhs_slice_begin[0] = batch; + lhs_slice_end[0] = batch + 1; + rhs_slice_begin[0] = batch; + rhs_slice_end[0] = batch + 1; + auto lhs_slice = x.slice(lhs_slice_begin, lhs_slice_end, strides) + .reshape({dim4[1], dim4[2]}); + auto rhs_slice = y.slice(rhs_slice_begin, rhs_slice_end, strides) + .reshape({dim4[2], dim4[3]}); + + auto out_slice = + out.slice({batch, 0, 0}, {batch + 1, dim4[1], dim4[3]}, strides); + out_slice = out_slice.reshape({dim4[1], dim4[3]}); + ring_mmul_(out_slice, lhs_slice, rhs_slice); + } + + ring_add_(out, x1y0); + ring_add_(out, task.get()); + return out; +} + +void BatchMatMulAV::evaluate(KernelEvalContext* ctx) const { + // NOTE(lwj): overwrite the shape check in the MatmulKernel + const auto& lhs = ctx->getParam(0); + const auto& rhs = ctx->getParam(1); + const auto& lhs_shape = lhs.shape(); + const auto& rhs_shape = rhs.shape(); + SPU_ENFORCE(lhs_shape.ndim() == rhs_shape.ndim(), + "ndim mismatch: lhs={}, rhs={}", lhs_shape, rhs_shape); + SPU_ENFORCE(lhs_shape[0] == rhs_shape[0], "batch mismatch: lhs={}, rhs={}", + lhs_shape, rhs_shape); + SPU_ENFORCE(lhs_shape[2] == rhs_shape[1], "shape mismatch: lhs={}, rhs={}", + lhs_shape, rhs_shape); + ctx->setOutput(WrapValue(proc(ctx, lhs.data(), rhs.data()))); +} + +NdArrayRef BatchMatMulAV::proc(KernelEvalContext* ctx, const NdArrayRef& x, + const NdArrayRef& y) const { + if (0 == x.numel() || 0 == y.numel()) { + return NdArrayRef(x.eltype(), {x.shape()[0], y.shape()[1]}); + } + + auto* comm = ctx->getState(); + auto* dot_prot = ctx->getState()->get(); + const auto* priv_type = y.eltype().as(); + SPU_ENFORCE(priv_type != nullptr, "RHS should be a private type"); + const int rank = comm->getRank(); + const int owner = priv_type->owner(); + + // (x0 + x1) * (y0 + y1) + // Compute the cross terms homomorphically + const Shape4D dim4 = {x.shape()[0], x.shape()[1], x.shape()[2], y.shape()[2]}; + + NdArrayRef out; + if (rank != owner) { + out = dot_prot->BatchDotOLE(x, comm->lctx().get(), dim4, true); + } else { + out = dot_prot->BatchDotOLE(y, comm->lctx().get(), dim4, false); + // local batch mmul + const Strides strides(x.shape().size(), 1); + Index lhs_slice_end(x.shape().begin(), x.shape().end()); + Index rhs_slice_end(y.shape().begin(), y.shape().end()); + Index lhs_slice_begin(3, 0); + Index rhs_slice_begin(3, 0); + NdArrayRef out(x.eltype(), {dim4[0], dim4[1], dim4[3]}); + for (int64_t batch = 0; batch < dim4[0]; ++batch) { + lhs_slice_begin[0] = batch; + lhs_slice_end[0] = batch + 1; + rhs_slice_begin[0] = batch; + rhs_slice_end[0] = batch + 1; + auto lhs_slice = x.slice(lhs_slice_begin, lhs_slice_end, strides) + .reshape({dim4[1], dim4[2]}); + auto rhs_slice = y.slice(rhs_slice_begin, rhs_slice_end, strides) + .reshape({dim4[2], dim4[3]}); + 
auto local = ring_mmul(lhs_slice, rhs_slice); + + auto out_slice = + out.slice({batch, 0, 0}, {batch + 1, dim4[1], dim4[3]}, strides); + out_slice = out_slice.reshape({dim4[1], dim4[3]}); + ring_add_(out_slice, local); + } + } + return out; +} + } // namespace spu::mpc::cheetah diff --git a/libspu/mpc/cheetah/arithmetic.h b/libspu/mpc/cheetah/arithmetic.h index b66d38b5..f38aa0f3 100644 --- a/libspu/mpc/cheetah/arithmetic.h +++ b/libspu/mpc/cheetah/arithmetic.h @@ -166,6 +166,44 @@ class MatMulAA : public MatmulKernel { const NdArrayRef& y) const override; }; +class MatMulAV : public MatmulKernel { + public: + static constexpr char kBindName[] = "mmul_av"; + + Kind kind() const override { return Kind::Dynamic; } + // LHS: m x k + // RHS: k x n + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& x, + const NdArrayRef& y) const override; +}; + +class BatchMatMulAA : public MatmulKernel { + public: + static constexpr char kBindName[] = "batch_mmul_aa"; + + Kind kind() const override { return Kind::Dynamic; } + + void evaluate(KernelEvalContext* ctx) const override; + + // LHS: b x m x k + // RHS: b x k x n + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& x, + const NdArrayRef& y) const override; +}; + +class BatchMatMulAV : public MatmulKernel { + public: + static constexpr char kBindName[] = "batch_mmul_av"; + + Kind kind() const override { return Kind::Dynamic; } + + void evaluate(KernelEvalContext* ctx) const override; + // LHS: b x m x k + // RHS: b x k x n + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& x, + const NdArrayRef& y) const override; +}; + class Conv2DAA : public Conv2DKernel { public: static constexpr char kBindName[] = "conv2d_aa"; diff --git a/libspu/mpc/cheetah/protocol.cc b/libspu/mpc/cheetah/protocol.cc index 79edf22b..15f4869c 100644 --- a/libspu/mpc/cheetah/protocol.cc +++ b/libspu/mpc/cheetah/protocol.cc @@ -63,7 +63,9 @@ void regCheetahProtocol(SPUContext* ctx, ctx->prot()->regKernel(); ctx->prot()->regKernel(); ctx->prot()->regKernel(); - // ctx->prot()->regKernel(); + ctx->prot()->regKernel(); + ctx->prot()->regKernel(); + ctx->prot()->regKernel(); ctx->prot()->regKernel(); ctx->prot()->regKernel(); ctx->prot()->regKernel(); diff --git a/libspu/mpc/cheetah/rlwe/packlwes.h b/libspu/mpc/cheetah/rlwe/packlwes.h index c9337cfb..e6b5bf94 100644 --- a/libspu/mpc/cheetah/rlwe/packlwes.h +++ b/libspu/mpc/cheetah/rlwe/packlwes.h @@ -24,6 +24,9 @@ namespace spu::mpc::cheetah { void GenerateGaloisKeyForPacking(const seal::SEALContext &context, const RLWESecretKey &key, bool save_seed, GaloisKeys *out); + +// REF: BumbleBee: Secure Two-party Inference Framework for Large Transformers +// https://eprint.iacr.org/2023/1678 class PackingHelper { public: PackingHelper(size_t gap, const seal::GaloisKeys &galois_keys, diff --git a/libspu/mpc/common/pv2k.cc b/libspu/mpc/common/pv2k.cc index cd020dd6..99da793e 100644 --- a/libspu/mpc/common/pv2k.cc +++ b/libspu/mpc/common/pv2k.cc @@ -395,6 +395,61 @@ class MatMulVP : public MatmulKernel { } }; +class BatchMatMulPP : public MatmulKernel { + public: + static constexpr char kBindName[] = "batch_mmul_pp"; + + ce::CExpr latency() const override { return ce::Const(0); } + + ce::CExpr comm() const override { return ce::Const(0); } + + void evaluate(KernelEvalContext* ctx) const override { + // NOTE(lwj): overwrite the shape check in the MatmulKernel + const auto& lhs = ctx->getParam(0); + const auto& rhs = ctx->getParam(1); + const auto& lhs_shape = lhs.shape(); + const auto& rhs_shape = rhs.shape(); + 
SPU_ENFORCE(lhs_shape.ndim() == rhs_shape.ndim(), + "ndim mismatch: lhs={}, rhs={}", lhs_shape, rhs_shape); + SPU_ENFORCE(lhs_shape[0] == rhs_shape[0], "batch mismatch: lhs={}, rhs={}", + lhs_shape, rhs_shape); + SPU_ENFORCE(lhs_shape[2] == rhs_shape[1], "shape mismatch: lhs={}, rhs={}", + lhs_shape, rhs_shape); + ctx->setOutput(WrapValue(proc(ctx, lhs.data(), rhs.data()))); + } + + NdArrayRef proc(KernelEvalContext*, const NdArrayRef& lhs, + const NdArrayRef& rhs) const override { + SPU_ENFORCE(lhs.eltype() == rhs.eltype()); + const int64_t dim4[4] = {lhs.shape()[0], lhs.shape()[1], lhs.shape()[2], + rhs.shape()[2]}; + const Strides strides(lhs.shape().size(), 1); + Index lhs_slice_end(lhs.shape().begin(), lhs.shape().end()); + Index rhs_slice_end(rhs.shape().begin(), rhs.shape().end()); + Index lhs_slice_begin(3, 0); + Index rhs_slice_begin(3, 0); + + NdArrayRef out(lhs.eltype(), {dim4[0], dim4[1], dim4[3]}); + for (int64_t batch = 0; batch < dim4[0]; ++batch) { + lhs_slice_begin[0] = batch; + lhs_slice_end[0] = batch + 1; + rhs_slice_begin[0] = batch; + rhs_slice_end[0] = batch + 1; + auto lhs_slice = lhs.slice(lhs_slice_begin, lhs_slice_end, strides) + .reshape({dim4[1], dim4[2]}); + auto rhs_slice = rhs.slice(rhs_slice_begin, rhs_slice_end, strides) + .reshape({dim4[2], dim4[3]}); + + auto out_slice = + out.slice({batch, 0, 0}, {batch + 1, dim4[1], dim4[3]}, strides); + out_slice = out_slice.reshape({dim4[1], dim4[3]}); + ring_mmul_(out_slice, lhs_slice, rhs_slice); + } + + return out; + } +}; + class AndPP : public BinaryKernel { public: static constexpr char kBindName[] = "and_pp"; @@ -709,6 +764,7 @@ void regPV2kKernels(Object* obj) { obj->regKernel(); obj->regKernel(); obj->regKernel(); + obj->regKernel(); obj->regKernel(); obj->regKernel(); obj->regKernel();
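
The basic_binary.cc hunk makes DotGeneral try the fused batch path first and fall back to per-batch slicing only when the backend reports the batched kernel is not available. Below is a minimal, self-contained sketch of that dispatch-with-fallback pattern in plain C++; Mat, matmul, batch_matmul_fast, dot_general and the flat row-major layout are illustrative assumptions, not SPU APIs.

#include <cstddef>
#include <optional>
#include <vector>

using Mat = std::vector<double>;  // row-major m x n matrix

// Plain 2-D matmul: (m x k) * (k x n) -> (m x n).
Mat matmul(const Mat& a, const Mat& b, size_t m, size_t k, size_t n) {
  Mat c(m * n, 0.0);
  for (size_t i = 0; i < m; ++i)
    for (size_t p = 0; p < k; ++p)
      for (size_t j = 0; j < n; ++j)
        c[i * n + j] += a[i * k + p] * b[p * n + j];
  return c;
}

// Hypothetical fused kernel: returns std::nullopt when the backend has no
// batched implementation, mirroring the optional-returning convention in
// the patch (NotAvailable).
std::optional<std::vector<Mat>> batch_matmul_fast(const std::vector<Mat>&,
                                                  const std::vector<Mat>&,
                                                  size_t, size_t, size_t) {
  return std::nullopt;  // pretend this backend cannot fuse the batch
}

// DotGeneral-style driver: try the fused path first, then decompose the
// batch into separate 2-D matmul calls (the pre-patch behaviour).
std::vector<Mat> dot_general(const std::vector<Mat>& lhs,
                             const std::vector<Mat>& rhs, size_t m, size_t k,
                             size_t n) {
  if (auto fused = batch_matmul_fast(lhs, rhs, m, k, n)) {
    return *fused;  // backend (e.g. Cheetah) handled the whole batch at once
  }
  std::vector<Mat> out;
  out.reserve(lhs.size());
  for (size_t b = 0; b < lhs.size(); ++b) {
    out.push_back(matmul(lhs[b], rhs[b], m, k, n));  // per-batch fallback
  }
  return out;
}

This is the same shape of fallback DotGeneral keeps for backends that do not register batch_mmul_aa / batch_mmul_av kernels.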
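
f_batch_mmul wraps the ring-level result in _trunc before restoring the fixed-point dtype, just as f_mmul does: multiplying two encodings that each carry f fractional bits yields a product with 2f fractional bits, which must be shifted back down by f. A small self-contained illustration of that bookkeeping follows; kFxpBits, fxp_encode, fxp_decode and fxp_mul are illustrative stand-ins rather than SPU functions, and 18 fractional bits is only an assumed configuration.

#include <cstdint>
#include <cstdio>

constexpr int kFxpBits = 18;  // assumed number of fractional bits

int64_t fxp_encode(double x) {
  return static_cast<int64_t>(x * (1LL << kFxpBits));
}
double fxp_decode(int64_t x) {
  return static_cast<double>(x) / (1LL << kFxpBits);
}

// The raw product of two fixed-point values carries 2*kFxpBits fractional
// bits, so an arithmetic right shift by kFxpBits is required -- this is the
// role _trunc(ctx, *ret) plays after _batch_mmul in f_batch_mmul.
// (Inputs kept small here so the intermediate product fits in int64_t.)
int64_t fxp_mul(int64_t a, int64_t b) { return (a * b) >> kFxpBits; }

int main() {
  int64_t a = fxp_encode(3.25);
  int64_t b = fxp_encode(-1.5);
  std::printf("%f\n", fxp_decode(fxp_mul(a, b)));  // approximately -4.875
  return 0;
}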
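
BatchMatMulAA follows the comment in arithmetic.cc: with additive shares x = x0 + x1 and y = y0 + y1 over the ring, each party computes its local term xi * yi, and the two cross terms x0 * y1 and x1 * y0 are produced by BatchDotOLE so neither share is revealed. The identity itself is plain ring arithmetic; the sketch below checks it over Z_2^64 with the cross terms computed in the clear purely for illustration (in the protocol they arrive as fresh additive shares from the RLWE-based OLE).

#include <cstdint>
#include <cstdio>
#include <random>

// Additive secret sharing over the ring Z_{2^64}: x = x0 + x1 (mod 2^64).
int main() {
  std::mt19937_64 prg(7);
  uint64_t x = 123456789u;
  uint64_t y = 987654321u;

  uint64_t x0 = prg(), x1 = x - x0;  // shares of x
  uint64_t y0 = prg(), y1 = y - y0;  // shares of y

  // Local terms each party can compute on its own shares.
  uint64_t local0 = x0 * y0;
  uint64_t local1 = x1 * y1;

  // Cross terms x0*y1 and x1*y0: in Cheetah these are evaluated with
  // BatchDotOLE, which hands each party an additive share of the product
  // without revealing x0, x1, y0 or y1. Computed in the clear here only to
  // verify the identity.
  uint64_t cross = x0 * y1 + x1 * y0;

  uint64_t z = local0 + local1 + cross;  // reconstructed x * y (mod 2^64)
  std::printf("%d\n", z == x * y);       // prints 1
  return 0;
}

MatMulAV and BatchMatMulAV need only one OLE per output because the V operand is held in the clear by its owner: the owner adds its local x * y (or the per-batch loop result) to the OLE output, while the other party keeps just the OLE share.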