Skip to content

Commit

Permalink
Merge branch 'main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
f7ed authored Jan 3, 2025
2 parents 57978ac + eb25d5e commit 391c827
Show file tree
Hide file tree
Showing 70 changed files with 2,103 additions and 436 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/scorecard.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ jobs:

steps:
- name: "Checkout code"
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
persist-credentials: false

Expand Down Expand Up @@ -67,6 +67,6 @@ jobs:

# Upload the results to GitHub's code scanning dashboard.
- name: "Upload to code-scanning"
uses: github/codeql-action/upload-sarif@c36620d31ac7c881962c3d9dd939c40ec9434f2b # v3.26.12
uses: github/codeql-action/upload-sarif@48ab28a6f5dbc2a99bf1e0131198dd8f1df78169 # v3.28.0
with:
sarif_file: results.sarif
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
>
> please add your unreleased change here.
## 20241219

- [SPU] 0.9.3b0 release
- [Improvement] Optimize exponential computation for semi2k (**experimental**)
- [Feature] Add more send/recv actions profiling

## 20240716
Expand Down
18 changes: 9 additions & 9 deletions bazel/repositories.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -39,21 +39,21 @@ def _yacl():
http_archive,
name = "yacl",
urls = [
"https://github.com/secretflow/yacl/archive/refs/tags/0.4.5b7_nightly_20240930.tar.gz",
"https://github.com/secretflow/yacl/archive/refs/tags/0.4.5b8_nightly_20241014.tar.gz",
],
strip_prefix = "yacl-0.4.5b7_nightly_20240930",
sha256 = "cf8dc7cceb9c5d05df00f1c086feec99d554db3e3cbe101253cf2a5a1adb9072",
strip_prefix = "yacl-0.4.5b8_nightly_20241014",
sha256 = "9141792f07eba507ffd21c57ec3df2ad5fdf90ce605ffb7bc1b7b4e84a9c34fa",
)

def _libpsi():
maybe(
http_archive,
name = "psi",
urls = [
"https://github.com/secretflow/psi/archive/refs/tags/v0.4.3.dev240919.tar.gz",
"https://github.com/secretflow/psi/archive/refs/tags/v0.5.0.dev241115.tar.gz",
],
strip_prefix = "psi-0.4.3.dev240919",
sha256 = "1ee34fbbd9a8f36dea8f7c45588a858e8c31f3a38e60e1fc67cb428ea79334e3",
strip_prefix = "psi-0.5.0.dev241115",
sha256 = "4d5ccc61282c4f887cee2c12fe3f414dfd7e916952849e92ffb1f6835d657a35",
)

def _rules_proto_grpc():
Expand Down Expand Up @@ -242,10 +242,10 @@ def _com_github_nvidia_cutlass():
maybe(
http_archive,
name = "cutlass_archive",
strip_prefix = "cutlass-3.5.1",
strip_prefix = "cutlass-3.6.0",
urls = [
"https://github.com/NVIDIA/cutlass/archive/refs/tags/v3.5.1.tar.gz",
"https://github.com/NVIDIA/cutlass/archive/refs/tags/v3.6.0.tar.gz",
],
sha256 = "20b7247cda2d257cbf8ba59ba3ca40a9211c4da61a9c9913e32b33a2c5883a36",
sha256 = "7576f3437b90d0de5923560ccecebaa1357e5d72f36c0a59ad77c959c9790010",
build_file = "@spulib//bazel:nvidia_cutlass.BUILD",
)
8 changes: 4 additions & 4 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
myst-parser==4.0.0
rstcheck==6.2.4
sphinx==8.0.2
nbsphinx==0.9.5
sphinx==8.1.3
nbsphinx==0.9.6
sphinx-autobuild==2024.10.3
sphinx-markdown-parser==0.2.4
sphinxcontrib-actdiag==3.0.0
sphinxcontrib-blockdiag==3.0.0
sphinxcontrib-mermaid==0.9.2
sphinxcontrib-mermaid==1.0.0
sphinxcontrib-nwdiag==2.0.0
sphinxcontrib-seqdiag==3.0.0
pytablewriter==1.2.0
pytablewriter==1.2.1
linkify-it-py==2.0.3
mdutils==1.6.0
spu>=0.3.1b0
Expand Down
2 changes: 2 additions & 0 deletions libspu/compiler/front_end/hlo_importer.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ class CompilationContext;

class HloImporter final {
public:
// clang-format off
explicit HloImporter(CompilationContext *context) : context_(context) {};
// clang-format on

/// Load a xla module and returns a mlir-hlo module
mlir::OwningOpRef<mlir::ModuleOp>
Expand Down
11 changes: 11 additions & 0 deletions libspu/core/config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,17 @@ void populateRuntimeConfig(RuntimeConfig& cfg) {
if (cfg.fxp_exp_mode() == RuntimeConfig::EXP_DEFAULT) {
cfg.set_fxp_exp_mode(RuntimeConfig::EXP_TAYLOR);
}
if (cfg.fxp_exp_mode() == RuntimeConfig::EXP_PRIME) {
// 0 offset is not supported
if (cfg.experimental_exp_prime_offset() == 0) {
// For FM128 default offset is 13
if (cfg.field() == FieldType::FM128) {
cfg.set_experimental_exp_prime_offset(13);
}
// TODO: set defaults for other fields, currently only FM128 is
// supported
}
}

if (cfg.fxp_exp_iters() == 0) {
cfg.set_fxp_exp_iters(8);
Expand Down
8 changes: 8 additions & 0 deletions libspu/dialect/pphlo/IR/fold.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,14 @@ OpFoldResult ReverseOp::fold(FoldAdaptor) {
dims, [&](int64_t dim) { return shapedType.getDimSize(dim) == 1; })) {
return input;
}

// reverse(reverse(x, dims), dims) = x
if (auto prev = input.getDefiningOp<ReverseOp>()) {
if (prev.getDimensions() == dims) {
return prev.getOperand();
}
}

return {};
}

Expand Down
1 change: 1 addition & 0 deletions libspu/kernel/hal/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ spu_cc_test(
deps = [
":fxp_approx",
"//libspu/kernel:test_util",
"//libspu/mpc/utils:simulate",
],
)

Expand Down
36 changes: 35 additions & 1 deletion libspu/kernel/hal/fxp_approx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,31 @@ Value exp_taylor(SPUContext* ctx, const Value& x) {
return res;
}

Value exp_prime(SPUContext* ctx, const Value& x) {
auto clamped_x = x;
auto offset = ctx->config().experimental_exp_prime_offset();
auto fxp = ctx->getFxpBits();
if (!ctx->config().experimental_exp_prime_disable_lower_bound()) {
// currently the bound is tied to FM128
SPU_ENFORCE_EQ(ctx->getField(), FieldType::FM128);
auto lower_bound = (48.0 - offset - 2.0 * fxp) / M_LOG2E;
clamped_x = _clamp_lower(ctx, clamped_x,
constant(ctx, lower_bound, x.dtype(), x.shape()))
.setDtype(x.dtype());
}
if (ctx->config().experimental_exp_prime_enable_upper_bound()) {
// currently the bound is tied to FM128
SPU_ENFORCE_EQ(ctx->getField(), FieldType::FM128);
auto upper_bound = (124.0 - 2.0 * fxp - offset) / M_LOG2E;
clamped_x = _clamp_upper(ctx, clamped_x,
constant(ctx, upper_bound, x.dtype(), x.shape()))
.setDtype(x.dtype());
}

auto ret = dynDispatch<spu::Value>(ctx, "exp_a", clamped_x);
return ret.setDtype(x.dtype());
}

namespace {

// Pade approximation of exp2(x), x is in [0, 1].
Expand Down Expand Up @@ -439,13 +464,22 @@ Value f_exp(SPUContext* ctx, const Value& x) {
case RuntimeConfig::EXP_PADE: {
// The valid input for exp_pade is [-kInputLimit, kInputLimit].
// TODO(junfeng): should merge clamp into exp_pade to save msb ops.
const float kInputLimit = 32 / std::log2(std::exp(1));
const float kInputLimit = 32.0 / std::log2(std::exp(1));
const auto clamped_x =
_clamp(ctx, x, constant(ctx, -kInputLimit, x.dtype(), x.shape()),
constant(ctx, kInputLimit, x.dtype(), x.shape()))
.setDtype(x.dtype());
return detail::exp_pade(ctx, clamped_x);
}
case RuntimeConfig::EXP_PRIME:
if (ctx->hasKernel("exp_a")) {
return detail::exp_prime(ctx, x);
} else {
SPU_THROW(
"exp_a is not implemented for this protocol, currently only "
"2pc "
"semi2k is supported.");
}
default:
SPU_THROW("unexpected exp approximation method {}",
ctx->config().fxp_exp_mode());
Expand Down
2 changes: 2 additions & 0 deletions libspu/kernel/hal/fxp_approx.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ Value exp2_pade(SPUContext* ctx, const Value& x);
// Works for range [-12.0, 18.0]
Value exp_pade(SPUContext* ctx, const Value& x);

Value exp_prime(SPUContext* ctx, const Value& x);

Value tanh_chebyshev(SPUContext* ctx, const Value& x);

} // namespace detail
Expand Down
28 changes: 27 additions & 1 deletion libspu/kernel/hal/fxp_approx_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "libspu/kernel/hal/constants.h"
#include "libspu/kernel/hal/type_cast.h"
#include "libspu/kernel/test_util.h"
#include "libspu/mpc/utils/simulate.h"

namespace spu::kernel::hal {

Expand Down Expand Up @@ -78,10 +79,35 @@ TEST(FxpTest, ExponentialPade) {
<< y;
}

TEST(FxpTest, ExponentialPrime) {
std::cout << "test exp_prime" << std::endl;
spu::mpc::utils::simulate(2, [&](std::shared_ptr<yacl::link::Context> lctx) {
RuntimeConfig conf;
conf.set_protocol(ProtocolKind::SEMI2K);
conf.set_field(FieldType::FM128);
conf.set_fxp_fraction_bits(40);
conf.set_experimental_enable_exp_prime(true);
SPUContext ctx = test::makeSPUContext(conf, lctx);

auto offset = ctx.config().experimental_exp_prime_offset();
auto fxp = ctx.getFxpBits();
auto lower_bound = (48.0 - offset - 2.0 * fxp) / M_LOG2E;
auto upper_bound = (124.0 - 2.0 * fxp - offset) / M_LOG2E;

xt::xarray<float> x = xt::linspace<float>(lower_bound, upper_bound, 4000);

Value a = test::makeValue(&ctx, x, VIS_SECRET);
Value c = detail::exp_prime(&ctx, a);
auto y = dump_public_as<float>(&ctx, reveal(&ctx, c));
EXPECT_TRUE(xt::allclose(xt::exp(x), y, 0.01, 0.001))
<< xt::exp(x) << std::endl
<< y;
});
}

TEST(FxpTest, Log) {
// GIVEN
SPUContext ctx = test::makeSPUContext();

xt::xarray<float> x = {{0.05, 0.5}, {5, 50}};
// public log
{
Expand Down
15 changes: 13 additions & 2 deletions libspu/kernel/hal/ring.cc
Original file line number Diff line number Diff line change
Expand Up @@ -472,14 +472,25 @@ Value _mux(SPUContext* ctx, const Value& pred, const Value& a, const Value& b) {
Value _clamp(SPUContext* ctx, const Value& x, const Value& minv,
const Value& maxv) {
SPU_TRACE_HAL_LEAF(ctx, x, minv, maxv);

// clamp lower bound, res = x < minv ? minv : x
auto res = _mux(ctx, _less(ctx, x, minv), minv, x);

// clamp upper bound, res = res < maxv ? res, maxv
return _mux(ctx, _less(ctx, res, maxv), res, maxv);
}

// TODO: refactor polymorphic, and may use select functions in polymorphic
Value _clamp_lower(SPUContext* ctx, const Value& x, const Value& minv) {
SPU_TRACE_HAL_LEAF(ctx, x, minv);
// clamp lower bound, res = x < minv ? minv : x
return _mux(ctx, _less(ctx, x, minv), minv, x);
}

Value _clamp_upper(SPUContext* ctx, const Value& x, const Value& maxv) {
SPU_TRACE_HAL_LEAF(ctx, x, maxv);
// clamp upper bound, x = x < maxv ? x, maxv
return _mux(ctx, _less(ctx, x, maxv), x, maxv);
}

Value _constant(SPUContext* ctx, uint128_t init, const Shape& shape) {
return _make_p(ctx, init, shape);
}
Expand Down
5 changes: 5 additions & 0 deletions libspu/kernel/hal/ring.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,11 @@ Value _mux(SPUContext* ctx, const Value& pred, const Value& a, const Value& b);
// TODO: test me
Value _clamp(SPUContext* ctx, const Value& x, const Value& minv,
const Value& maxv);

Value _clamp_lower(SPUContext* ctx, const Value& x, const Value& minv);

Value _clamp_upper(SPUContext* ctx, const Value& x, const Value& maxv);

// Make a public value from uint128_t init value.
//
// If the current working field has less than 128bit, the lower sizeof(field)
Expand Down
8 changes: 4 additions & 4 deletions libspu/kernel/hal/shape_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,10 @@ Value update_slice(SPUContext* ctx, const Value& in, const Value& update,
SPU_TRACE_HAL_DISP(ctx, in, start_indices);

if (in.storage_type() != update.storage_type()) {
auto u =
_cast_type(ctx, update, in.storage_type()).setDtype(update.dtype());

return update_slice(ctx, in, u, start_indices);
auto ct = _common_type(ctx, update.storage_type(), in.storage_type());
auto u = _cast_type(ctx, update, ct).setDtype(update.dtype());
auto i = _cast_type(ctx, in, ct).setDtype(in.dtype());
return update_slice(ctx, i, u, start_indices);
}

return _update_slice(ctx, in, update, start_indices).setDtype(in.dtype());
Expand Down
6 changes: 6 additions & 0 deletions libspu/kernel/hal/type_cast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,12 @@ Value reveal(SPUContext* ctx, const Value& x) {
return _s2p(ctx, x).setDtype(x.dtype());
}

Value reveal_to(SPUContext* ctx, const Value& x, size_t rank) {
SPU_TRACE_HAL_LEAF(ctx, x, rank);
SPU_ENFORCE(x.isSecret());
return _s2v(ctx, x, rank).setDtype(x.dtype());
}

Value dtype_cast(SPUContext* ctx, const Value& in, DataType to_type) {
SPU_TRACE_HAL_DISP(ctx, in, to_type);

Expand Down
5 changes: 5 additions & 0 deletions libspu/kernel/hal/type_cast.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,9 @@ Value seal(SPUContext* ctx, const Value& x);
// @param in, the input value
Value reveal(SPUContext* ctx, const Value& x);

/// reveal a secret to a specific party
// @param in, the input value
// @param rank, the rank of the party to reveal to
Value reveal_to(SPUContext* ctx, const Value& x, size_t rank);

} // namespace spu::kernel::hal
1 change: 1 addition & 0 deletions libspu/kernel/hlo/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ spu_cc_test(
":casting",
":const",
"//libspu/kernel:test_util",
"//libspu/mpc/utils:simulate",
],
)

Expand Down
4 changes: 4 additions & 0 deletions libspu/kernel/hlo/casting.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ spu::Value Reveal(SPUContext *ctx, const spu::Value &in) {
return hal::reveal(ctx, in);
}

spu::Value RevealTo(SPUContext *ctx, const spu::Value &in, size_t rank) {
return hal::reveal_to(ctx, in, rank);
}

spu::Value Seal(SPUContext *ctx, const spu::Value &in) {
return hal::seal(ctx, in);
}
Expand Down
2 changes: 2 additions & 0 deletions libspu/kernel/hlo/casting.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ spu::Value Bitcast(SPUContext *ctx, const spu::Value &in, DataType dst_dtype);

spu::Value Reveal(SPUContext *ctx, const spu::Value &in);

spu::Value RevealTo(SPUContext *ctx, const spu::Value &in, size_t rank);

spu::Value Seal(SPUContext *ctx, const spu::Value &in);

} // namespace spu::kernel::hlo
Loading

0 comments on commit 391c827

Please sign in to comment.