Repo sync (#717)
anakinxc authored Jun 4, 2024
1 parent 90f2365 commit d3cab0f
Showing 37 changed files with 382 additions and 73 deletions.
4 changes: 2 additions & 2 deletions .clang-tidy
@@ -9,7 +9,7 @@ Checks: "abseil-cleanup-ctad,
bugprone-*,
-bugprone-easily-swappable-parameters,
-bugprone-implicit-widening-of-multiplication-result,
-bugprone-narrowing-conversions, # too many false positives around `std::size_t` vs. `*::difference_type`.
-bugprone-narrowing-conversions,
google-build-using-namespace,
google-explicit-constructor,
google-global-names-in-headers,
@@ -20,7 +20,7 @@ Checks: "abseil-cleanup-ctad,
modernize-*,
-modernize-use-trailing-return-type,
-modernize-avoid-c-arrays,
-modernize-return-braced-init-list, # can hurt readability
-modernize-return-braced-init-list,
-modernize-use-nodiscard,
performance-*,
readability-*,
6 changes: 6 additions & 0 deletions libspu/BUILD.bazel
@@ -15,6 +15,7 @@
load("@rules_cc//cc:defs.bzl", "cc_proto_library")
load("@rules_proto//proto:defs.bzl", "proto_library")
load("@rules_proto_grpc//python:defs.bzl", "python_proto_compile")
load("//bazel:spu.bzl", "spu_cc_library")

package(default_visibility = ["//visibility:public"])

@@ -34,3 +35,8 @@ python_proto_compile(
prefix_path = "..",
protos = ["//libspu:spu_proto"],
)

spu_cc_library(
name = "version",
hdrs = ["version.h"],
)
27 changes: 16 additions & 11 deletions libspu/core/context.h
@@ -113,7 +113,7 @@ class KernelEvalContext final {
SPUContext* sctx_;

std::vector<ParamType> params_;
ParamType output_;
std::vector<ParamType> outputs_;

public:
explicit KernelEvalContext(SPUContext* sctx) : sctx_(sctx) {}
@@ -139,20 +139,24 @@ class KernelEvalContext final {
}

size_t numParams() const { return params_.size(); }
size_t numOutputs() const { return outputs_.size(); }

// Consume (steal) the output at position `pos` from this evaluation context.
//
// * usually called by kernel caller.
template <typename T = Value>
T&& stealOutput() {
return std::move(std::get<T>(output_));
T&& consumeOutput(size_t pos) {
SPU_DEBUG_ONLY_ENFORCE(pos < outputs_.size(),
"pos={} exceed num of outputs={}", pos,
outputs_.size());
return std::move(std::get<T>(outputs_[pos]));
}

// Push an input onto this evaluation context.
//
// * usually called by kernel caller.
template <typename T>
void bindParam(const T& in) {
void pushParam(const T& in) {
params_.emplace_back(in);
}

@@ -161,27 +165,28 @@
// * usually called by kernel callee.
template <typename T>
const T& getParam(size_t pos) const {
SPU_ENFORCE(pos < params_.size(), "pos={} exceed num of inputs={}", pos,
params_.size());
SPU_DEBUG_ONLY_ENFORCE(pos < params_.size(),
"pos={} exceed num of inputs={}", pos,
params_.size());
return std::get<T>(params_[pos]);
}

// Append an output.
//
// * usually called by kernel callee.
template <typename T = Value>
void setOutput(T&& out) {
output_ = std::forward<T>(out);
void pushOutput(T&& out) {
outputs_.emplace_back(std::forward<T>(out));
}
};

namespace detail {

template <typename First, typename... Args>
void bindParams(KernelEvalContext* ectx, First&& head, Args&&... tail) {
ectx->bindParam(std::forward<First>(head));
ectx->pushParam(std::forward<First>(head));
if constexpr (sizeof...(Args) > 0) {
return bindParams(ectx, std::forward<Args>(tail)...);
bindParams(ectx, std::forward<Args>(tail)...);
}
}

@@ -203,7 +208,7 @@ Ret dynDispatch(SPUContext* sctx, const std::string& name, Args&&... args) {
kernel->evaluate(&ectx);

// 4. consume the result and return it.
return ectx.stealOutput<Ret>();
return ectx.consumeOutput<Ret>(0);
}

// helper class
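The context.h change above is the core of this sync: the single output_ slot becomes an outputs_ vector, bindParam/setOutput/stealOutput are renamed to pushParam/pushOutput/consumeOutput(pos), and the bounds checks move behind SPU_DEBUG_ONLY_ENFORCE so they disappear in release builds. A minimal caller-side sketch of the new flow, using only the members shown in this diff; the two-output kernel, the function name, and the way the Kernel* is obtained are assumptions for illustration, not part of the commit:

#include "libspu/core/context.h"

// Sketch only. `kernel` is a Kernel* looked up elsewhere (lookup elided) and
// is assumed to push two outputs from its evaluate().
spu::Value callTwoOutputKernel(spu::SPUContext* sctx, spu::Kernel* kernel,
                               const spu::Value& x) {
  spu::KernelEvalContext ectx(sctx);

  ectx.pushParam(x);  // was bindParam(x)

  // Inside evaluate(), the callee now calls ctx->pushOutput(...) once per
  // result instead of setOutput(...), which held a single value.
  kernel->evaluate(&ectx);

  // Results are consumed by position (was stealOutput<Value>()).
  auto lo = ectx.consumeOutput<spu::Value>(0);
  auto hi = ectx.consumeOutput<spu::Value>(1);
  (void)hi;
  return lo;
}

This mirrors the updated dynDispatch above, which now returns ectx.consumeOutput<Ret>(0).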
6 changes: 4 additions & 2 deletions libspu/core/prelude.h
@@ -42,9 +42,11 @@
__VA_ARGS__) // NOLINT, readability-simplify-boolean-expr

#ifdef NDEBUG
#define SPU_DEBUG_ONLY_THROW static_cast<void>
#define SPU_DEBUG_ONLY_THROW(...) static_cast<void>(0)
#define SPU_DEBUG_ONLY_ENFORCE(...) static_cast<void>(0)
#else
#define SPU_DEBUG_ONLY_THROW YACL_THROW
#define SPU_DEBUG_ONLY_THROW(...) YACL_THROW(__VA_ARGS__)
#define SPU_DEBUG_ONLY_ENFORCE(COND, ...) SPU_ENFORCE(COND, __VA_ARGS__)
#endif

// Force compiler to inline something regardless of optimization level.
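The prelude.h change above makes the debug-only macros variadic: with NDEBUG defined, SPU_DEBUG_ONLY_THROW(...) and the new SPU_DEBUG_ONLY_ENFORCE(...) expand to static_cast<void>(0), so neither the condition nor the message arguments are evaluated; without NDEBUG they forward to YACL_THROW and SPU_ENFORCE. A small usage sketch, assuming the fmt-style "{}" messages used elsewhere in this diff; the helper function itself is hypothetical:

#include <cstddef>
#include <vector>

#include "libspu/core/prelude.h"

// Sketch only: a bounds check that costs nothing in release builds and
// expands to SPU_ENFORCE (throwing on failure) in debug builds.
int checkedAt(const std::vector<int>& v, size_t idx) {
  SPU_DEBUG_ONLY_ENFORCE(idx < v.size(), "idx={} exceeds size={}", idx,
                         v.size());
  return v[idx];
}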
2 changes: 1 addition & 1 deletion libspu/mpc/aby3/boolean.cc
@@ -36,7 +36,7 @@ void CommonTypeB::evaluate(KernelEvalContext* ctx) const {
const size_t out_nbits = std::max(lhs_nbits, rhs_nbits);
const PtType out_btype = calcBShareBacktype(out_nbits);

ctx->setOutput(makeType<BShrTy>(out_btype, out_nbits));
ctx->pushOutput(makeType<BShrTy>(out_btype, out_nbits));
}

NdArrayRef CastTypeB::proc(KernelEvalContext*, const NdArrayRef& in,
2 changes: 1 addition & 1 deletion libspu/mpc/aby3/conversion.cc
@@ -922,7 +922,7 @@ void CommonTypeV::evaluate(KernelEvalContext* ctx) const {
const auto* lhs_v = lhs.as<Priv2kTy>();
const auto* rhs_v = rhs.as<Priv2kTy>();

ctx->setOutput(makeType<AShrTy>(std::max(lhs_v->field(), rhs_v->field())));
ctx->pushOutput(makeType<AShrTy>(std::max(lhs_v->field(), rhs_v->field())));
}

} // namespace spu::mpc::aby3
2 changes: 1 addition & 1 deletion libspu/mpc/cheetah/boolean_semi2k.cc
@@ -53,7 +53,7 @@ void CommonTypeB::evaluate(KernelEvalContext* ctx) const {
"cheetah always use same bshare field, lhs={}, rhs={}", lhs_field,
rhs_field);

ctx->setOutput(makeType<BShrTy>(lhs_field, std::max(lhs_nbits, rhs_nbits)));
ctx->pushOutput(makeType<BShrTy>(lhs_field, std::max(lhs_nbits, rhs_nbits)));
}

NdArrayRef CastTypeB::proc(KernelEvalContext*, const NdArrayRef& in,
2 changes: 1 addition & 1 deletion libspu/mpc/cheetah/conversion.cc
@@ -77,7 +77,7 @@ void CommonTypeV::evaluate(KernelEvalContext* ctx) const {
const auto* lhs_v = lhs.as<Priv2kTy>();
const auto* rhs_v = rhs.as<Priv2kTy>();

ctx->setOutput(makeType<AShrTy>(std::max(lhs_v->field(), rhs_v->field())));
ctx->pushOutput(makeType<AShrTy>(std::max(lhs_v->field(), rhs_v->field())));
}

} // namespace spu::mpc::cheetah
2 changes: 1 addition & 1 deletion libspu/mpc/common/pv2k.cc
@@ -97,7 +97,7 @@ class MakeP : public Kernel {
ce::CExpr comm() const override { return ce::Const(0); }

void evaluate(KernelEvalContext* ctx) const override {
ctx->setOutput(
ctx->pushOutput(
proc(ctx, ctx->getParam<uint128_t>(0), ctx->getParam<Shape>(1)));
}

46 changes: 23 additions & 23 deletions libspu/mpc/kernel.cc
@@ -21,15 +21,15 @@ void RandKernel::evaluate(KernelEvalContext* ctx) const {

auto res = proc(ctx, shape);

ctx->setOutput(WrapValue(res));
ctx->pushOutput(WrapValue(res));
}

void UnaryKernel::evaluate(KernelEvalContext* ctx) const {
const auto& in = ctx->getParam<Value>(0);

auto res = proc(ctx, UnwrapValue(in));

ctx->setOutput(WrapValue(res));
ctx->pushOutput(WrapValue(res));
}

void RevealToKernel::evaluate(KernelEvalContext* ctx) const {
@@ -38,7 +38,7 @@ void RevealToKernel::evaluate(KernelEvalContext* ctx) const {

auto res = proc(ctx, UnwrapValue(in), rank);

ctx->setOutput(WrapValue(res));
ctx->pushOutput(WrapValue(res));
}

void ShiftKernel::evaluate(KernelEvalContext* ctx) const {
@@ -47,7 +47,7 @@ void ShiftKernel::evaluate(KernelEvalContext* ctx) const {

auto res = proc(ctx, UnwrapValue(in), bits);

ctx->setOutput(WrapValue(res));
ctx->pushOutput(WrapValue(res));
}

void BinaryKernel::evaluate(KernelEvalContext* ctx) const {
@@ -59,7 +59,7 @@ void BinaryKernel::evaluate(KernelEvalContext* ctx) const {

auto z = proc(ctx, UnwrapValue(lhs), UnwrapValue(rhs));

ctx->setOutput(WrapValue(z));
ctx->pushOutput(WrapValue(z));
}

void MatmulKernel::evaluate(KernelEvalContext* ctx) const {
@@ -69,7 +69,7 @@ void MatmulKernel::evaluate(KernelEvalContext* ctx) const {
SPU_ENFORCE(lhs.shape()[1] == rhs.shape()[0], "invalid shape {} {}", lhs,
rhs);

ctx->setOutput(WrapValue(proc(ctx, lhs.data(), rhs.data())));
ctx->pushOutput(WrapValue(proc(ctx, lhs.data(), rhs.data())));
}

void Conv2DKernel::evaluate(KernelEvalContext* ctx) const {
@@ -80,7 +80,7 @@ void Conv2DKernel::evaluate(KernelEvalContext* ctx) const {

auto z = proc(ctx, UnwrapValue(lhs), UnwrapValue(rhs), stride_h, stride_w);

ctx->setOutput(WrapValue(z));
ctx->pushOutput(WrapValue(z));
}

void BitrevKernel::evaluate(KernelEvalContext* ctx) const {
@@ -90,7 +90,7 @@ void BitrevKernel::evaluate(KernelEvalContext* ctx) const {

auto z = proc(ctx, UnwrapValue(in), start, end);

ctx->setOutput(WrapValue(z));
ctx->pushOutput(WrapValue(z));
}

void TruncAKernel::evaluate(KernelEvalContext* ctx) const {
@@ -100,7 +100,7 @@ void TruncAKernel::evaluate(KernelEvalContext* ctx) const {

auto z = proc(ctx, UnwrapValue(in), bits, sign);

ctx->setOutput(WrapValue(z));
ctx->pushOutput(WrapValue(z));
}

void BitSplitKernel::evaluate(KernelEvalContext* ctx) const {
@@ -109,7 +109,7 @@ void BitSplitKernel::evaluate(KernelEvalContext* ctx) const {

auto res = proc(ctx, UnwrapValue(in), stride);

ctx->setOutput(WrapValue(res));
ctx->pushOutput(WrapValue(res));
}

void CastTypeKernel::evaluate(KernelEvalContext* ctx) const {
@@ -118,7 +118,7 @@ void CastTypeKernel::evaluate(KernelEvalContext* ctx) const {

auto res = proc(ctx, UnwrapValue(val), to_type);

ctx->setOutput(WrapValue(res));
ctx->pushOutput(WrapValue(res));
}

void PermKernel::evaluate(KernelEvalContext* ctx) const {
@@ -131,7 +131,7 @@ void PermKernel::evaluate(KernelEvalContext* ctx) const {

auto z = proc(ctx, UnwrapValue(x), UnwrapValue(y));

ctx->setOutput(WrapValue(z));
ctx->pushOutput(WrapValue(z));
}

void GenInvPermKernel::evaluate(KernelEvalContext* ctx) const {
@@ -141,7 +141,7 @@ void GenInvPermKernel::evaluate(KernelEvalContext* ctx) const {

auto y = proc(ctx, UnwrapValue(in), is_ascending);

ctx->setOutput(WrapValue(y));
ctx->pushOutput(WrapValue(y));
}

void MergeKeysKernel::evaluate(KernelEvalContext* ctx) const {
@@ -153,7 +153,7 @@ void MergeKeysKernel::evaluate(KernelEvalContext* ctx) const {
}
auto y = proc(ctx, inputs, is_ascending);

ctx->setOutput(WrapValue(y));
ctx->pushOutput(WrapValue(y));
}

void BroadcastKernel::evaluate(KernelEvalContext* ctx) const {
@@ -163,7 +163,7 @@ void BroadcastKernel::evaluate(KernelEvalContext* ctx) const {

auto z = proc(ctx, UnwrapValue(in), to_shape, in_dims);

ctx->setOutput(WrapValue(z));
ctx->pushOutput(WrapValue(z));
}

void DimsBasedKernel::evaluate(KernelEvalContext* ctx) const {
@@ -172,7 +172,7 @@ void DimsBasedKernel::evaluate(KernelEvalContext* ctx) const {

auto z = proc(ctx, UnwrapValue(in), axes);

ctx->setOutput(WrapValue(z));
ctx->pushOutput(WrapValue(z));
}

void ShapeBasedKernel::evaluate(KernelEvalContext* ctx) const {
@@ -181,7 +181,7 @@ void ShapeBasedKernel::evaluate(KernelEvalContext* ctx) const {

auto z = proc(ctx, UnwrapValue(in), to_shape);

ctx->setOutput(WrapValue(z));
ctx->pushOutput(WrapValue(z));
}

void ExtractSliceKernel::evaluate(KernelEvalContext* ctx) const {
@@ -192,7 +192,7 @@ void ExtractSliceKernel::evaluate(KernelEvalContext* ctx) const {

auto z = proc(ctx, UnwrapValue(in), start, end, strides);

ctx->setOutput(WrapValue(z));
ctx->pushOutput(WrapValue(z));
}

void UpdateSliceKernel::evaluate(KernelEvalContext* ctx) const {
@@ -202,7 +202,7 @@ void UpdateSliceKernel::evaluate(KernelEvalContext* ctx) const {

auto z = proc(ctx, UnwrapValue(in), UnwrapValue(update), start);

ctx->setOutput(WrapValue(z));
ctx->pushOutput(WrapValue(z));
}

void PadKernel::evaluate(KernelEvalContext* ctx) const {
@@ -215,7 +215,7 @@ void PadKernel::evaluate(KernelEvalContext* ctx) const {
auto z = proc(ctx, UnwrapValue(in), UnwrapValue(padding_value), edge_low,
edge_high, interior_padding);

ctx->setOutput(WrapValue(z));
ctx->pushOutput(WrapValue(z));
}

void ConcateKernel::evaluate(KernelEvalContext* ctx) const {
@@ -230,7 +230,7 @@ void ConcateKernel::evaluate(KernelEvalContext* ctx) const {

auto z = proc(ctx, unwrapped, axis);

ctx->setOutput(WrapValue(z));
ctx->pushOutput(WrapValue(z));
}

void OramOneHotKernel::evaluate(KernelEvalContext* ctx) const {
@@ -242,7 +242,7 @@ void OramOneHotKernel::evaluate(KernelEvalContext* ctx) const {

auto res = proc(ctx, UnwrapValue(target), s);

ctx->setOutput(WrapValue(res));
ctx->pushOutput(WrapValue(res));
}

void OramReadKernel::evaluate(KernelEvalContext* ctx) const {
@@ -256,7 +256,7 @@ void OramReadKernel::evaluate(KernelEvalContext* ctx) const {
SPU_ENFORCE(onehot.shape()[1] == db.shape()[0],
"onehot and database shape mismatch");

ctx->setOutput(
ctx->pushOutput(
WrapValue(proc(ctx, UnwrapValue(onehot), UnwrapValue(db), offset)));
}

4 changes: 2 additions & 2 deletions libspu/mpc/ref2k/ref2k.cc
@@ -58,7 +58,7 @@ class Ref2kCommonTypeS : public Kernel {
SPU_TRACE_MPC_DISP(ctx, lhs, rhs);
SPU_ENFORCE(lhs.isa<Ref2kSecrTy>(), "invalid type, got={}", lhs);
SPU_ENFORCE(rhs.isa<Ref2kSecrTy>(), "invalid type, got={}", rhs);
ctx->setOutput(lhs);
ctx->pushOutput(lhs);
}
};

@@ -79,7 +79,7 @@ class Ref2kCommonTypeV : public Kernel {
const auto* lhs_v = lhs.as<Priv2kTy>();
const auto* rhs_v = rhs.as<Priv2kTy>();

ctx->setOutput(
ctx->pushOutput(
makeType<Ref2kSecrTy>(std::max(lhs_v->field(), rhs_v->field())));
}
};