[2PC] Add dispatch from DotGeneral to BatchMatmul #433

Closed · wants to merge 1 commit
12 changes: 12 additions & 0 deletions libspu/kernel/hal/fxp_base.cc
@@ -270,6 +270,18 @@ Value f_mmul(SPUContext* ctx, const Value& x, const Value& y) {
return _trunc(ctx, _mmul(ctx, x, y)).setDtype(x.dtype());
}

std::optional<Value> f_batch_mmul(SPUContext* ctx, const Value& x,
const Value& y) {
SPU_TRACE_HAL_LEAF(ctx, x, y);

SPU_ENFORCE(x.isFxp() && y.isFxp() && x.dtype() == y.dtype());
auto ret = _batch_mmul(ctx, x, y);
if (not ret.has_value()) {
return NotAvailable;
}
return _trunc(ctx, *ret).setDtype(x.dtype());
}

Value f_conv2d(SPUContext* ctx, const Value& x, const Value& y,
const Strides& window_strides) {
SPU_TRACE_HAL_LEAF(ctx, x, y, window_strides);
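A note on the _trunc call in f_batch_mmul above: with fixed-point encoding, each operand carries fxp_bits fraction bits, so the raw product carries twice as many and must be truncated back, exactly as f_mmul does one hunk up. A standalone plain-integer sketch of that scale bookkeeping (illustrative only, not SPU API):

#include <cassert>
#include <cstdint>

int main() {
  const int fxp_bits = 18;  // fraction bits, as in a typical SPU config
  auto encode = [&](double v) { return static_cast<int64_t>(v * (1LL << fxp_bits)); };
  auto decode = [&](int64_t v) { return static_cast<double>(v) / (1LL << fxp_bits); };

  int64_t x = encode(1.5);
  int64_t y = encode(2.25);
  int64_t prod = x * y;                // scale is now 2^(2 * fxp_bits)
  int64_t trunced = prod >> fxp_bits;  // drop fxp_bits to restore 2^fxp_bits
  assert(decode(trunced) == 1.5 * 2.25);
  return 0;
}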
3 changes: 3 additions & 0 deletions libspu/kernel/hal/fxp_base.h
@@ -59,6 +59,9 @@ Value f_mul(SPUContext* ctx, const Value& x, const Value& y,

Value f_mmul(SPUContext* ctx, const Value& x, const Value& y);

std::optional<Value> f_batch_mmul(SPUContext* ctx, const Value& x,
const Value& y);

Value f_conv2d(SPUContext* ctx, const Value& x, const Value& y,
const Strides& window_strides);

11 changes: 11 additions & 0 deletions libspu/kernel/hal/integer.cc
@@ -98,4 +98,15 @@ Value i_tensordot(SPUContext* ctx, const Value& x, const Value& y,
return _tensordot(ctx, x, y, ix, iy).setDtype(x.dtype());
}

std::optional<Value> i_batch_mmul(SPUContext* ctx, const Value& x,
const Value& y) {
SPU_TRACE_HAL_LEAF(ctx, x, y);
ENSURE_INT_AND_DTYPE_MATCH(x, y);
auto ret = _batch_mmul(ctx, x, y);
if (ret.has_value()) {
ret->setDtype(x.dtype());
}
return ret;
}

} // namespace spu::kernel::hal
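ENSURE_INT_AND_DTYPE_MATCH is defined earlier in integer.cc, outside this hunk; it is assumed to be the same guard the other i_* kernels use, roughly this sketch (not the verbatim macro):

#define ENSURE_INT_AND_DTYPE_MATCH(x, y) \
  SPU_ENFORCE((x).isInt() && (y).isInt() && (x).dtype() == (y).dtype())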
3 changes: 3 additions & 0 deletions libspu/kernel/hal/integer.h
@@ -40,6 +40,9 @@ Value i_mul(SPUContext* ctx, const Value& x, const Value& y);

Value i_mmul(SPUContext* ctx, const Value& x, const Value& y);

std::optional<Value> i_batch_mmul(SPUContext* ctx, const Value& x,
const Value& y);

Value i_tensordot(SPUContext* ctx, const Value& x, const Value& y,
const Index& ix, const Index& iy);

22 changes: 22 additions & 0 deletions libspu/kernel/hal/polymorphic.cc
@@ -126,6 +126,28 @@ Value matmul(SPUContext* ctx, const Value& x, const Value& y) {
return dtypeBinaryDispatch("mmul", f_mmul, i_mmul, ctx, x, y);
}

std::optional<Value> batch_matmul(SPUContext* ctx, const Value& x,
const Value& y) {
SPU_TRACE_HAL_DISP(ctx, x, y);
if (isCrossIntFxp(x, y)) {
auto ret = _batch_mmul(ctx, x, y);
if (ret.has_value()) {
auto new_dtype = x.isFxp() ? x.dtype() : y.dtype();
ret->setDtype(new_dtype);
}
return ret;
}

if (x.isInt() && y.isInt()) {
return i_batch_mmul(ctx, x, y);
}

if (x.isFxp() && y.isFxp()) {
return f_batch_mmul(ctx, x, y);
}
return NotAvailable;
}

Value tensordot(SPUContext* ctx, const Value& x, const Value& y,
const Index& ix, const Index& iy) {
SPU_TRACE_HAL_DISP(ctx, x, y, ix, iy);
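Because batch_matmul returns std::optional<Value>, a call site can probe the fused path and keep its existing fallback. A hypothetical caller, mirroring the DotGeneral change further down:

// x: B x m x k, y: B x k x n, dtypes as checked by the dispatcher.
if (auto fused = kernel::hal::batch_matmul(ctx, x, y)) {
  return *fused;  // a protocol supplied a batched kernel
}
// ...otherwise fall back to per-batch kernel::hal::matmul calls.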
6 changes: 6 additions & 0 deletions libspu/kernel/hal/polymorphic.h
@@ -55,6 +55,12 @@ Value bitwise_not(SPUContext* ctx, const Value& in);
// @param y, the second parameter
Value matmul(SPUContext* ctx, const Value& x, const Value& y);

/// batch matrix production operator
// @param x, the first parameter
// @param y, the second parameter
std::optional<Value> batch_matmul(SPUContext* ctx, const Value& x,
const Value& y);

/// matrix production operator
// @param x, the first parameter
// @param y, the second parameter
9 changes: 9 additions & 0 deletions libspu/kernel/hal/prot_wrapper.cc
@@ -60,6 +60,13 @@ namespace spu::kernel::hal {
return ret; \
}

#define MAP_OPTIONAL_MMUL_OP(NAME) \
std::optional<Value> _##NAME(SPUContext* ctx, const Value& x, \
const Value& y) { \
SPU_TRACE_HAL_DISP(ctx, x, y); \
return mpc::NAME(ctx, x, y); \
}

Type _common_type_s(SPUContext* ctx, const Type& a, const Type& b) {
SPU_TRACE_HAL_DISP(ctx, a, b);
return mpc::common_type_s(ctx, a, b);
@@ -194,6 +201,8 @@ MAP_MMUL_OP(mmul_ss)
MAP_MMUL_OP(mmul_sv)
MAP_MMUL_OP(mmul_vp)
MAP_MMUL_OP(mmul_vv)
MAP_OPTIONAL_MMUL_OP(batch_mmul_ss)
MAP_OPTIONAL_MMUL_OP(batch_mmul_sv)

#define MAP_OPTIONAL_BINARY_OP(NAME) \
std::optional<Value> _##NAME(SPUContext* ctx, const Value& x, \
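For reference, MAP_OPTIONAL_MMUL_OP(batch_mmul_ss) expands mechanically to the wrapper below:

std::optional<Value> _batch_mmul_ss(SPUContext* ctx, const Value& x,
                                    const Value& y) {
  SPU_TRACE_HAL_DISP(ctx, x, y);
  return mpc::batch_mmul_ss(ctx, x, y);
}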
4 changes: 4 additions & 0 deletions libspu/kernel/hal/prot_wrapper.h
@@ -88,6 +88,10 @@ Value _mmul_ss(SPUContext* ctx, const Value& x, const Value& y);
Value _mmul_vv(SPUContext* ctx, const Value& x, const Value& y);
Value _mmul_vp(SPUContext* ctx, const Value& x, const Value& y);
Value _mmul_sv(SPUContext* ctx, const Value& x, const Value& y);
std::optional<Value> _batch_mmul_ss(SPUContext* ctx, const Value& x,
const Value& y);
std::optional<Value> _batch_mmul_sv(SPUContext* ctx, const Value& x,
const Value& y);

Value _conv2d_ss(SPUContext* ctx, const Value& input, const Value& kernel,
const Strides& strides);
15 changes: 15 additions & 0 deletions libspu/kernel/hal/ring.cc
@@ -177,6 +177,16 @@ static Value _mmul_impl(SPUContext* ctx, const Value& x, const Value& y) {
}
};

static OptionalAPI<Value> _batch_mmul_impl(SPUContext* ctx, const Value& x,
const Value& y) {
if (x.isSecret() && y.isSecret()) { // SS
return _batch_mmul_ss(ctx, x, y);
} else if (x.isSecret() && y.isPrivate()) { // SV
return _batch_mmul_sv(ctx, x, y);
}
return NotAvailable;
}

Value _trunc(SPUContext* ctx, const Value& x, size_t bits, SignType sign) {
SPU_TRACE_HAL_LEAF(ctx, x, bits);
bits = (bits == 0) ? ctx->getFxpBits() : bits;
@@ -277,6 +287,11 @@ Value _sub(SPUContext* ctx, const Value& x, const Value& y) {
return res;
}

OptionalAPI<Value> _batch_mmul(SPUContext* ctx, const Value& x,
const Value& y) {
return _batch_mmul_impl(ctx, x, y);
}

Value _mmul(SPUContext* ctx, const Value& x, const Value& y) {
auto [m, n, k] = deduceMmulArgs(x.shape(), y.shape());

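Only the SS and SV visibility pairs are routed here; any other combination (VS, or a public operand) yields NotAvailable, and the caller falls back to the slicing loop. Given the B x m x k / B x k x n contract noted in libspu/mpc/api.cc below, a shape guard along these lines would make the contract explicit (a sketch, not part of this diff):

// Hypothetical guard for the batched shape contract.
SPU_ENFORCE(x.shape().size() == 3 && y.shape().size() == 3,
            "batch_mmul expects 3-D operands");
SPU_ENFORCE(x.shape()[0] == y.shape()[0], "batch dims must match");
SPU_ENFORCE(x.shape()[2] == y.shape()[1], "inner dims must match");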
3 changes: 3 additions & 0 deletions libspu/kernel/hal/ring.h
@@ -51,6 +51,9 @@ Value _mul(SPUContext* ctx, const Value& x, const Value& y);

Value _mmul(SPUContext* ctx, const Value& x, const Value& y);

std::optional<Value> _batch_mmul(SPUContext* ctx, const Value& x,
const Value& y);

Value _conv2d(SPUContext* ctx, const Value& x, const Value& y,
const Strides& strides);

5 changes: 5 additions & 0 deletions libspu/kernel/hlo/basic_binary.cc
@@ -80,6 +80,11 @@ spu::Value DotGeneral(SPUContext *ctx, const spu::Value &lhs,
const spu::Value &rhs) {
int64_t num_batch = lhs.shape()[0];

auto ret = kernel::hal::batch_matmul(ctx, lhs, rhs);
if (ret.has_value()) {
return *ret;
}

std::vector<spu::Value> results(num_batch);
Index lhs_slice_begin(3, 0);
Index lhs_slice_end(lhs.shape().begin(), lhs.shape().end());
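The batched contract DotGeneral now tries first is B x m x k times B x k x n giving B x m x n, one independent matmul per batch index. A standalone reference version over flat row-major buffers, independent of SPU types:

#include <cstdint>
#include <vector>

// out[b] = lhs[b] * rhs[b] for each of the B batches (row-major layout).
std::vector<int64_t> batch_matmul_ref(const std::vector<int64_t>& lhs,
                                      const std::vector<int64_t>& rhs,
                                      int64_t B, int64_t m, int64_t k,
                                      int64_t n) {
  std::vector<int64_t> out(B * m * n, 0);
  for (int64_t b = 0; b < B; ++b) {
    for (int64_t i = 0; i < m; ++i) {
      for (int64_t j = 0; j < n; ++j) {
        int64_t acc = 0;
        for (int64_t t = 0; t < k; ++t) {
          acc += lhs[(b * m + i) * k + t] * rhs[(b * k + t) * n + j];
        }
        out[(b * m + i) * n + j] = acc;
      }
    }
  }
  return out;
}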
12 changes: 12 additions & 0 deletions libspu/mpc/ab_api.cc
@@ -147,6 +147,18 @@ OptionalAPI<Value> mmul_av(SPUContext* ctx, const Value& x, const Value& y) {
return NotAvailable;
}

OptionalAPI<Value> batch_mmul_aa(SPUContext* ctx, const Value& x,
const Value& y) {
TRY_DISPATCH(ctx, x, y);
return NotAvailable;
}

OptionalAPI<Value> batch_mmul_av(SPUContext* ctx, const Value& x,
const Value& y) {
TRY_DISPATCH(ctx, x, y);
return NotAvailable;
}

Type common_type_b(SPUContext* ctx, const Type& a, const Type& b) {
SPU_TRACE_MPC_LEAF(ctx, a, b);
return dynDispatch<Type>(ctx, __func__, a, b);
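TRY_DISPATCH checks whether the active protocol registered a kernel under the current function name and dynamically dispatches to it; otherwise control falls through to the NotAvailable return, so batch_mmul_aa/batch_mmul_av stay unavailable until a backend (e.g. Cheetah) registers kernels. The macro is assumed to follow the pattern of the other optional ab_api entries, roughly:

#define TRY_DISPATCH(ctx, ...)                        \
  if ((ctx)->hasKernel(__func__)) {                   \
    return dynDispatch((ctx), __func__, __VA_ARGS__); \
  }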
4 changes: 4 additions & 0 deletions libspu/mpc/ab_api.h
@@ -50,6 +50,10 @@ Value trunc_a(SPUContext* ctx, const Value& x, size_t nbits, SignType sign);
Value mmul_ap(SPUContext* ctx, const Value& x, const Value& y);
Value mmul_aa(SPUContext* ctx, const Value& x, const Value& y);
OptionalAPI<Value> mmul_av(SPUContext* ctx, const Value& x, const Value& y);
OptionalAPI<Value> batch_mmul_aa(SPUContext* ctx, const Value& x,
const Value& y);
OptionalAPI<Value> batch_mmul_av(SPUContext* ctx, const Value& x,
const Value& y);

Type common_type_b(SPUContext* ctx, const Type& a, const Type& b);
Value cast_type_b(SPUContext* ctx, const Value& a, const Type& to_type);
44 changes: 44 additions & 0 deletions libspu/mpc/ab_api_test.cc
@@ -274,6 +274,50 @@ TEST_P(ArithmeticTest, MatMulAA) {
});
}

TEST_P(ArithmeticTest, BatchMatMulAA) {
const auto factory = std::get<0>(GetParam());
const RuntimeConfig& conf = std::get<1>(GetParam());
const size_t npc = std::get<2>(GetParam());

const int64_t B = 3;
const int64_t M = 4;
const int64_t K = 5;
const int64_t N = 6;
const Shape shape_A = {B, M, K};
const Shape shape_B = {B, K, N};
const Shape shape_C = {B, M, N};

utils::simulate(npc, [&](const std::shared_ptr<yacl::link::Context>& lctx) {
auto obj = factory(conf, lctx);

/* GIVEN */
auto p0 = rand_p(obj.get(), shape_A);
auto p1 = rand_p(obj.get(), shape_B);
auto a0 = p2a(obj.get(), p0);
auto a1 = p2a(obj.get(), p1);

/* WHEN */
auto prev = obj->prot()->getState<Communicator>()->getStats();
auto tmp = batch_mmul_aa(obj.get(), a0, a1);
if (not tmp.has_value()) {
return;
}
auto cost = obj->prot()->getState<Communicator>()->getStats() - prev;

auto r_aa = a2p(obj.get(), *tmp);
auto r_pp = batch_mmul_pp(obj.get(), p0, p1);

/* THEN */
EXPECT_VALUE_EQ(r_aa, r_pp);
ce::Params params = {{"K", SizeOf(conf.field()) * 8},
{"N", npc},
{"m", M},
{"n", N},
{"k", K}};
EXPECT_TRUE(verifyCost(obj->prot()->getKernel("mmul_aa"), "mmul_aa", params,
cost, 1));
});
}

TEST_P(ArithmeticTest, NotA) {
const auto factory = std::get<0>(GetParam());
const RuntimeConfig& conf = std::get<1>(GetParam());
24 changes: 24 additions & 0 deletions libspu/mpc/api.cc
@@ -495,6 +495,30 @@ Value mmul_pp(SPUContext* ctx, const Value& x, const Value& y) {
FORCE_DISPATCH(ctx, x, y);
}

// NOTE(lwj): LHS.shape: B x m x k, and RHS.shape: B x k x n
// Out shape is B x m x n
Value batch_mmul_pp(SPUContext* ctx, const Value& x, const Value& y) {
FORCE_DISPATCH(ctx, x, y);
}

OptionalAPI<Value> batch_mmul_ss(SPUContext* ctx, const Value& x,
const Value& y) {
SPU_TRACE_MPC_DISP(ctx, x, y);
if (auto ret = batch_mmul_aa(ctx, x, y)) {
return ret.value();
}
return NotAvailable;
}

OptionalAPI<Value> batch_mmul_sv(SPUContext* ctx, const Value& x,
const Value& y) {
SPU_TRACE_MPC_DISP(ctx, x, y);
if (auto ret = batch_mmul_av(ctx, x, y)) {
return ret.value();
}
return NotAvailable;
}

//////////////////////////////////////////////////////////////////////////////

Value and_ss(SPUContext* ctx, const Value& x, const Value& y) {
7 changes: 7 additions & 0 deletions libspu/mpc/api.h
@@ -125,6 +125,13 @@ Value mmul_sp(SPUContext* ctx, const Value& x, const Value& y);
Value mmul_vv(SPUContext* ctx, const Value& x, const Value& y);
Value mmul_vp(SPUContext* ctx, const Value& x, const Value& y);
Value mmul_pp(SPUContext* ctx, const Value& x, const Value& y);
// NOTE(lwj): LHS.shape: B x m x k, and RHS.shape: B x k x n
// Out shape is B x m x n
Value batch_mmul_pp(SPUContext* ctx, const Value& x, const Value& y);
OptionalAPI<Value> batch_mmul_ss(SPUContext* ctx, const Value& x,
const Value& y);
OptionalAPI<Value> batch_mmul_sv(SPUContext* ctx, const Value& x,
const Value& y);

Value and_ss(SPUContext* ctx, const Value& x, const Value& y);
Value and_sv(SPUContext* ctx, const Value& x, const Value& y);
9 changes: 6 additions & 3 deletions libspu/mpc/cheetah/arith/cheetah_dot.h
@@ -23,10 +23,13 @@

namespace spu::mpc::cheetah {

// clang-format off
// Implementation for Dot.
// Ref: Lu et al. "BumbleBee: Secure Two-party Inference Framework for Large Transformers"
// https://eprint.iacr.org/2023/1678
// Ref: Huang et al. "Cheetah: Lean and Fast Secure Two-Party Deep Neural Network Inference"
// https://eprint.iacr.org/2022/207.pdf
// clang-format on
class CheetahDot {
public:
explicit CheetahDot(const std::shared_ptr<yacl::link::Context>& lctx,
6 changes: 4 additions & 2 deletions libspu/mpc/cheetah/arith/cheetah_mul.h
@@ -22,9 +22,11 @@

namespace spu::mpc::cheetah {

// clang-format off
// Implementation for Mul
// Ref: Lu et al. "BumbleBee: Secure Two-party Inference Framework for Large Transformers"
// https://eprint.iacr.org/2023/1678
// Ref: Rathee et al. "Improved Multiplication Triple Generation over Rings via RLWE-based AHE"
// https://eprint.iacr.org/2019/577.pdf
class CheetahMul {
public: