Skip to content

Commit

Permalink
add prot ReLU
Browse files Browse the repository at this point in the history
  • Loading branch information
f7ed committed Jan 2, 2025
1 parent dabb7f9 commit 57978ac
Show file tree
Hide file tree
Showing 7 changed files with 74 additions and 0 deletions.
4 changes: 4 additions & 0 deletions libspu/mpc/ab_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,10 @@ Value mul_aa(SPUContext* ctx, const Value& x, const Value& y) {
TILED_DISPATCH(ctx, x, y);
}

Value mul_aaa(SPUContext* ctx, const Value& x, const Value& y, const Value& z) {
TILED_DISPATCH(ctx, x, y, z);
}

Value mul_aa_p(SPUContext* ctx, const Value& x, const Value& y) {
TILED_DISPATCH(ctx, x, y);
}
Expand Down
1 change: 1 addition & 0 deletions libspu/mpc/ab_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ OptionalAPI<Value> add_av(SPUContext* ctx, const Value& x, const Value& y);

Value mul_ap(SPUContext* ctx, const Value& x, const Value& y);
Value mul_aa(SPUContext* ctx, const Value& x, const Value& y);
Value mul_aaa(SPUContext* ctx, const Value& x, const Value& y, const Value& z);
Value mul_aa_p(SPUContext* ctx, const Value& x, const Value& y);

Value square_a(SPUContext* ctx, const Value& x);
Expand Down
31 changes: 31 additions & 0 deletions libspu/mpc/ab_api_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,37 @@ TEST_P(ArithmeticTest, MulAA) {
});
}

TEST_P(ArithmeticTest, MulAAA) {
const auto factory = std::get<0>(GetParam());
const RuntimeConfig& conf = std::get<1>(GetParam());
const size_t npc = std::get<2>(GetParam());

utils::simulate(npc, [&](const std::shared_ptr<yacl::link::Context>& lctx) {
auto sctx = factory(conf, lctx);

auto p0 = rand_p(sctx.get(), kShape);
auto p1 = rand_p(sctx.get(), kShape);
auto p2 = rand_p(sctx.get(), kShape);

auto v0 = p2v(sctx.get(), p0, 0);
auto v1 = p2v(sctx.get(), p1, 1);
auto v2 = p2v(sctx.get(), p2, 2);

auto a0 = v2a(sctx.get(), v0);
auto a1 = v2a(sctx.get(), v1);
auto a2 = v2a(sctx.get(), v2);

auto prod = mul_aaa(sctx.get(), a0, a1, a2);
auto p_prod = a2p(sctx.get(), prod);

auto s = mul_pp(sctx.get(), p0, p1);
auto s_prime = mul_pp(sctx.get(), s, p2);

/* THEN */
EXPECT_VALUE_EQ(s_prime, p_prod);
});
}

TEST_P(ArithmeticTest, MulAAP) {
const auto factory = std::get<0>(GetParam());
const RuntimeConfig& conf = std::get<1>(GetParam());
Expand Down
15 changes: 15 additions & 0 deletions libspu/mpc/kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,21 @@ void BinaryKernel::evaluate(KernelEvalContext* ctx) const {
ctx->pushOutput(WrapValue(z));
}

void TernaryKnernel::evaluate(KernelEvalContext* ctx) const {
const auto& x = ctx->getParam<Value>(0);
const auto& y = ctx->getParam<Value>(1);
const auto& z = ctx->getParam<Value>(2);

SPU_ENFORCE(x.shape() == y.shape(), "shape mismatch {} {}", x.shape(),
y.shape());
SPU_ENFORCE(x.shape() == z.shape(), "shape mismatch {} {}", x.shape(),
z.shape());

auto out = proc(ctx, UnwrapValue(x), UnwrapValue(y), UnwrapValue(z));

ctx->pushOutput(WrapValue(out));
}

void MatmulKernel::evaluate(KernelEvalContext* ctx) const {
const auto& lhs = ctx->getParam<Value>(0);
const auto& rhs = ctx->getParam<Value>(1);
Expand Down
7 changes: 7 additions & 0 deletions libspu/mpc/kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,13 @@ class BinaryKernel : public Kernel {
const NdArrayRef& rhs) const = 0;
};

class TernaryKnernel : public Kernel {
public:
void evaluate(KernelEvalContext* ctx) const override;
virtual NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& x,
const NdArrayRef& y, const NdArrayRef& z) const = 0;
};

class MatmulKernel : public Kernel {
public:
void evaluate(KernelEvalContext* ctx) const override;
Expand Down
4 changes: 4 additions & 0 deletions libspu/mpc/shamir/arithmetic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ NdArrayRef wrap_a2p(SPUContext* ctx, const NdArrayRef& x) {
return UnwrapValue(a2p(ctx, WrapValue(x)));
}

NdArrayRef wrap_negate_a(SPUContext* ctx, const NdArrayRef& x) {
return UnwrapValue(negate_a(ctx, WrapValue(x)));
}

// Generate zero sharings of degree = threshold
NdArrayRef gen_zero_shares(KernelEvalContext* ctx, int64_t numel,
int64_t threshold) {
Expand Down
12 changes: 12 additions & 0 deletions libspu/mpc/shamir/arithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,18 @@ class MulAA : public BinaryKernel {
const NdArrayRef& rhs) const override;
};

class MulAAA : public TernaryKnernel {
public:
static constexpr const char* kBindName() { return "mul_aaa"; }

ce::CExpr latency() const override { return ce::Const(2); }

ce::CExpr comm() const override { return ce::K() * 4; }

NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& x,
const NdArrayRef& y, const NdArrayRef& z) const override;
};

class MulAAP : public BinaryKernel {
public:
static constexpr const char* kBindName() { return "mul_aa_p"; }
Expand Down

0 comments on commit 57978ac

Please sign in to comment.