Skip to content

Commit

Permalink
Repo sync (#492)
Browse files Browse the repository at this point in the history
  • Loading branch information
anakinxc authored Jan 12, 2024
1 parent 43a9e3e commit 1f61217
Show file tree
Hide file tree
Showing 8 changed files with 46 additions and 12 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
> please add your unreleased change here.
- [Improvement] Optimize one-time setup for yacl ot
- [Improvement] Optimize sort performance

## 20240105

Expand All @@ -21,7 +22,7 @@
- [Feature] Add equal support for SEMI2K and ABY3
- [Improvement] Optimize sort memory usage
- [Improvement] Improve compatibility with latest Jax
- [Bugfix] Fix compilation cache collision under certian cases
- [Bugfix] Fix compilation cache collision under certain cases
- [Deprecated] macOS 11.x is no longer supported

## 20231108
Expand Down
10 changes: 6 additions & 4 deletions libspu/mpc/cheetah/arith/cheetah_dot.cc
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,6 @@ struct CheetahDot::Impl : public EnableCPRNG {
std::shared_ptr<yacl::link::Context> lctx_;
bool disable_pack_ = false;

mutable std::shared_mutex context_lock_;
// field_bitlen -> functor mapping
std::unordered_map<size_t, std::shared_ptr<seal::SEALContext>> seal_cntxts_;
std::unordered_map<size_t, seal::SEALContext> galoi_cntxts_;
Expand All @@ -197,7 +196,6 @@ struct CheetahDot::Impl : public EnableCPRNG {
};

void CheetahDot::Impl::LazyInitGaloisKey(size_t field_bitlen) {
// NOTE: make sure context_lock_ is obtained.
if (galoi_cntxts_.find(field_bitlen) != galoi_cntxts_.end()) {
return;
}
Expand Down Expand Up @@ -233,7 +231,6 @@ void CheetahDot::Impl::LazyInitGaloisKey(size_t field_bitlen) {
}

void CheetahDot::Impl::LazyInit(size_t field_bitlen, bool need_galois_keys) {
std::unique_lock guard(context_lock_);
if (seal_cntxts_.find(field_bitlen) != seal_cntxts_.end()) {
if (need_galois_keys) {
LazyInitGaloisKey(field_bitlen);
Expand Down Expand Up @@ -760,6 +757,11 @@ CheetahDot::CheetahDot(const std::shared_ptr<yacl::link::Context> &lctx,

CheetahDot::~CheetahDot() = default;

void CheetahDot::LazyInitKeys(FieldType field) {
SPU_ENFORCE(impl_ != nullptr);
return impl_->LazyInit(SizeOf(field) * 8, /*create_galois*/ true);
}

NdArrayRef CheetahDot::DotOLE(const NdArrayRef &inp, yacl::link::Context *conn,
const Shape3D &dim3, bool is_self_lhs) {
SPU_ENFORCE(impl_ != nullptr);
Expand All @@ -780,4 +782,4 @@ NdArrayRef CheetahDot::BatchDotOLE(const NdArrayRef &inp,
return impl_->BatchDotOLE(inp, conn, dim4, is_self_lhs);
}

} // namespace spu::mpc::cheetah
} // namespace spu::mpc::cheetah
5 changes: 5 additions & 0 deletions libspu/mpc/cheetah/arith/cheetah_dot.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,19 @@ class CheetahDot {

CheetahDot(CheetahDot&&) = delete;

void LazyInitKeys(FieldType field);

// make sure to call InitKeys first
NdArrayRef DotOLE(const NdArrayRef& inp, const Shape3D& dim3,
bool is_self_lhs);

// LHS.shape MxK, RHS.shape KxL => MxL
// make sure to call InitKeys first
NdArrayRef DotOLE(const NdArrayRef& inp, yacl::link::Context* conn,
const Shape3D& dim3, bool is_self_lhs);

// LHS.shape BxMxK, RHS.shape BxKxL => BxMxL
// make sure to call InitKeys first
NdArrayRef BatchDotOLE(const NdArrayRef& inp, yacl::link::Context* conn,
const Shape4D& dim4, bool is_self_lhs);

Expand Down
19 changes: 15 additions & 4 deletions libspu/mpc/cheetah/arith/cheetah_mul.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
#include "seal/keygenerator.h"
#include "seal/publickey.h"
#include "seal/secretkey.h"
#include "seal/util/locks.h"
#include "seal/util/polyarithsmallmod.h"
#include "seal/valcheck.h"
#include "spdlog/spdlog.h"
Expand Down Expand Up @@ -103,6 +102,15 @@ struct CheetahMul::Impl : public EnableCPRNG {

int64_t num_slots() const { return parms_.poly_modulus_degree(); }

void LazyInit(FieldType field, uint32_t msg_width_hint) {
Options options;
options.ring_bitlen = SizeOf(field) * 8;
options.msg_bitlen =
msg_width_hint == 0 ? options.ring_bitlen : msg_width_hint;
LazyExpandSEALContexts(options);
LazyInitModSwitchHelper(options);
}

void LazyExpandSEALContexts(const Options &options,
yacl::link::Context *conn = nullptr);

Expand Down Expand Up @@ -189,7 +197,6 @@ struct CheetahMul::Impl : public EnableCPRNG {
uint32_t current_crt_plain_bitlen_{0};

// SEAL's contexts for ZZ_{2^k}
mutable std::mutex context_lock_;
std::vector<seal::SEALContext> seal_cntxts_;

// own secret key
Expand All @@ -206,7 +213,6 @@ struct CheetahMul::Impl : public EnableCPRNG {
};

void CheetahMul::Impl::LazyInitModSwitchHelper(const Options &options) {
std::lock_guard guard(context_lock_);
if (ms_helpers_.count(options) > 0) {
return;
}
Expand Down Expand Up @@ -269,7 +275,6 @@ void CheetahMul::Impl::LocalExpandSEALContexts(size_t target) {
void CheetahMul::Impl::LazyExpandSEALContexts(const Options &options,
yacl::link::Context *conn) {
uint32_t target_plain_bitlen = TotalCRTBitLen(options);
std::lock_guard guard(context_lock_);
if (current_crt_plain_bitlen_ >= target_plain_bitlen) {
return;
}
Expand Down Expand Up @@ -719,4 +724,10 @@ NdArrayRef CheetahMul::MulOLE(const NdArrayRef &inp, bool is_evaluator,
return impl_->MulOLE(inp, nullptr, is_evaluator, msg_width_hint);
}

void CheetahMul::LazyInitKeys(FieldType field, uint32_t msg_width_hint) {
SPU_ENFORCE(impl_ != nullptr);
SPU_ENFORCE(msg_width_hint <= SizeOf(field) * 8);
return impl_->LazyInit(field, msg_width_hint);
}

} // namespace spu::mpc::cheetah
4 changes: 4 additions & 0 deletions libspu/mpc/cheetah/arith/cheetah_mul.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,13 @@ class CheetahMul {

CheetahMul(CheetahMul&&) = delete;

void LazyInitKeys(FieldType field, uint32_t msg_width_hint = 0);

// NOTE: make sure to call InitKeys first
NdArrayRef MulOLE(const NdArrayRef& inp, yacl::link::Context* conn,
bool is_evaluator, uint32_t msg_width_hint = 0);

// NOTE: make sure to call InitKeys first
NdArrayRef MulOLE(const NdArrayRef& inp, bool is_evaluator,
uint32_t msg_width_hint = 0);

Expand Down
6 changes: 6 additions & 0 deletions libspu/mpc/cheetah/arithmetic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,8 @@ NdArrayRef MulAA::mulDirectly(KernelEvalContext* ctx, const NdArrayRef& x,
// Compute the cross terms x0*y1, x1*y0 homomorphically
auto* comm = ctx->getState<Communicator>();
auto* mul_prot = ctx->getState<CheetahMulState>()->get();
mul_prot->LazyInitKeys(x.eltype().as<Ring2k>()->field());

const int rank = comm->getRank();
auto fx = x.reshape({x.numel()});
auto fy = y.reshape({y.numel()});
Expand Down Expand Up @@ -311,6 +313,8 @@ NdArrayRef MatMulAA::proc(KernelEvalContext* ctx, const NdArrayRef& x,

auto* comm = ctx->getState<Communicator>();
auto* dot_prot = ctx->getState<CheetahDotState>()->get();
dot_prot->LazyInitKeys(x.eltype().as<Ring2k>()->field());

const int rank = comm->getRank();

// (x0 + x1) * (y0 + y1)
Expand Down Expand Up @@ -347,6 +351,8 @@ NdArrayRef MatMulAV::proc(KernelEvalContext* ctx, const NdArrayRef& x,
}
auto* comm = ctx->getState<Communicator>();
auto* dot_prot = ctx->getState<CheetahDotState>()->get();
dot_prot->LazyInitKeys(x.eltype().as<Ring2k>()->field());

const int rank = comm->getRank();
const auto* ptype = y.eltype().as<Priv2kTy>();
SPU_ENFORCE(ptype != nullptr, "rhs should be a private type");
Expand Down
3 changes: 2 additions & 1 deletion libspu/mpc/cheetah/state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

namespace spu::mpc::cheetah {
size_t InitOTState(KernelEvalContext* ctx, size_t njobs) {
constexpr size_t kMinWorkSize = 1500;
constexpr size_t kMinWorkSize = 5000;
if (njobs == 0) {
return 0;
}
Expand Down Expand Up @@ -70,6 +70,7 @@ void CheetahMulState::makeSureCacheSize(FieldType field, int64_t numel) {
// Then the beaver (a0, b0, c0) and (a1, b1, c1)
// where c0 = a0*b0 + <a0*b1> + <a1*b0>
// c1 = a1*b1 + <a0*b1> + <a1*b0>
mul_prot_->LazyInitKeys(field);
const int rank = mul_prot_->Rank();
const int64_t ole_sze = mul_prot_->OLEBatchSize();
const int64_t num_ole = CeilDiv<size_t>(2 * numel, ole_sze);
Expand Down
8 changes: 6 additions & 2 deletions libspu/mpc/cheetah/state.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,12 @@ class CheetahOTState : public State {
if (basic_ot_prot_[idx]) {
return;
}
// NOTE: create a separated link for OT
auto _comm = std::make_shared<Communicator>(comm->lctx()->Spawn());
// NOTE(lwj): create a separated link for OT
// We **do not** block on the OT link since the message volume is small for
// LPN-based OTe
auto link = comm->lctx()->Spawn();
link->SetThrottleWindowSize(0);
auto _comm = std::make_shared<Communicator>(std::move(link));
basic_ot_prot_[idx] = std::make_shared<BasicOTProtocols>(std::move(_comm));
}

Expand Down

0 comments on commit 1f61217

Please sign in to comment.