+
#### Operands:
| Operand | Description |
| :-----: | ----------- |
-| `lhs` | statically shaped tensor of pred (AKA boolean or 1-bit integer) or Secret of 1-bit signless integer values or 8/16/32/64-bit signless integer or Secret of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or 8/16/32/64-bit unsigned integer or Secret of 8/16/32/64-bit unsigned integer values or 16-bit float or 32-bit float or 64-bit float or Secret of 16-bit float or 32-bit float or 64-bit float values or complex type with 32-bit float or 64-bit float elements or Secret of complex type with 32-bit float or 64-bit float elements values values
-| `rhs` | statically shaped tensor of pred (AKA boolean or 1-bit integer) or Secret of 1-bit signless integer values or 8/16/32/64-bit signless integer or Secret of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or 8/16/32/64-bit unsigned integer or Secret of 8/16/32/64-bit unsigned integer values or 16-bit float or 32-bit float or 64-bit float or Secret of 16-bit float or 32-bit float or 64-bit float values or complex type with 32-bit float or 64-bit float elements or Secret of complex type with 32-bit float or 64-bit float elements values values
+| `operand` | statically shaped tensor of 8/16/32/64-bit signless integer or Secret of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or 8/16/32/64-bit unsigned integer or Secret of 8/16/32/64-bit unsigned integer values values
#### Results:
| Result | Description |
| :----: | ----------- |
-| `result` | statically shaped tensor of pred (AKA boolean or 1-bit integer) or Secret of 1-bit signless integer values or 8/16/32/64-bit signless integer or Secret of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or 8/16/32/64-bit unsigned integer or Secret of 8/16/32/64-bit unsigned integer values or 16-bit float or 32-bit float or 64-bit float or Secret of 16-bit float or 32-bit float or 64-bit float values or complex type with 32-bit float or 64-bit float elements or Secret of complex type with 32-bit float or 64-bit float elements values values
+| `result` | statically shaped tensor of 8/16/32/64-bit signless integer or Secret of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or 8/16/32/64-bit unsigned integer or Secret of 8/16/32/64-bit unsigned integer values values
-### `pphlo.prefer_a` (spu::pphlo::PreferAOp)
+### `pphlo.power` (spu::pphlo::PowOp)
-_Prefer AShare operator_
+_Power operator_
Syntax:
```
-operation ::= `pphlo.prefer_a` $operand attr-dict `:` custom(type($operand), type($result))
+operation ::= `pphlo.power` $lhs `,` $rhs attr-dict
+ `:` custom(type($lhs), type($rhs), type($result))
```
-Convert input to AShare if possible.
+Performs element-wise exponentiation of `lhs` tensor by `rhs` tensor and produces a `result` tensor.
-Traits: `AlwaysSpeculatableImplTrait`, `Elementwise`, `SameOperandsAndResultShape`, `SameOperandsAndResultType`
+Ref https://github.com/openxla/stablehlo/blob/main/docs/spec.md#power
+
+Traits: `AlwaysSpeculatableImplTrait`, `Elementwise`, `SameOperandsAndResultShape`
Interfaces: `ConditionallySpeculatable`, `InferShapedTypeOpInterface`, `InferTypeOpInterface`, `NoMemoryEffect (MemoryEffectOpInterface)`
@@ -1684,7 +1692,8 @@ Effects: `MemoryEffects::Effect{}`
| Operand | Description |
| :-----: | ----------- |
-| `operand` | statically shaped tensor of pred (AKA boolean or 1-bit integer) or Secret of 1-bit signless integer values or 8/16/32/64-bit signless integer or Secret of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or 8/16/32/64-bit unsigned integer or Secret of 8/16/32/64-bit unsigned integer values or 16-bit float or 32-bit float or 64-bit float or Secret of 16-bit float or 32-bit float or 64-bit float values or complex type with 32-bit float or 64-bit float elements or Secret of complex type with 32-bit float or 64-bit float elements values values
+| `lhs` | statically shaped tensor of pred (AKA boolean or 1-bit integer) or Secret of 1-bit signless integer values or 8/16/32/64-bit signless integer or Secret of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or 8/16/32/64-bit unsigned integer or Secret of 8/16/32/64-bit unsigned integer values or 16-bit float or 32-bit float or 64-bit float or Secret of 16-bit float or 32-bit float or 64-bit float values or complex type with 32-bit float or 64-bit float elements or Secret of complex type with 32-bit float or 64-bit float elements values values
+| `rhs` | statically shaped tensor of pred (AKA boolean or 1-bit integer) or Secret of 1-bit signless integer values or 8/16/32/64-bit signless integer or Secret of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or 8/16/32/64-bit unsigned integer or Secret of 8/16/32/64-bit unsigned integer values or 16-bit float or 32-bit float or 64-bit float or Secret of 16-bit float or 32-bit float or 64-bit float values or complex type with 32-bit float or 64-bit float elements or Secret of complex type with 32-bit float or 64-bit float elements values values
#### Results:
@@ -2270,12 +2279,21 @@ Returns the sign of the `operand` element-wise and produces a `result` tensor.
Ref https://github.com/openxla/stablehlo/blob/main/docs/spec.md#sign
+PPHLO Extension: when `ignore_zero` is set to true, sign does not enforce sign(0) to 0
+
Traits: `AlwaysSpeculatableImplTrait`, `Elementwise`, `SameOperandsAndResultShape`, `SameOperandsAndResultType`
Interfaces: `ConditionallySpeculatable`, `InferShapedTypeOpInterface`, `InferTypeOpInterface`, `NoMemoryEffect (MemoryEffectOpInterface)`
Effects: `MemoryEffects::Effect{}`
+#### Attributes:
+
+
+
Attribute
MLIR Type
Description
+
ignore_zero
::mlir::BoolAttr
bool attribute
+
+
#### Operands:
| Operand | Description |
@@ -2377,7 +2395,7 @@ Ref https://github.com/openxla/stablehlo/blob/main/docs/spec.md#slice
Traits: `AlwaysSpeculatableImplTrait`, `SameOperandsAndResultElementType`
-Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`
+Interfaces: `ConditionallySpeculatable`, `InferTypeOpInterface`, `NoMemoryEffect (MemoryEffectOpInterface)`
Effects: `MemoryEffects::Effect{}`
@@ -2551,7 +2569,7 @@ Ref https://github.com/openxla/stablehlo/blob/main/docs/spec.md#transpose
Traits: `AlwaysSpeculatableImplTrait`, `SameOperandsAndResultElementType`
-Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`
+Interfaces: `ConditionallySpeculatable`, `InferTypeOpInterface`, `NoMemoryEffect (MemoryEffectOpInterface)`
Effects: `MemoryEffects::Effect{}`
diff --git a/docs/reference/runtime_config.md b/docs/reference/runtime_config.md
index f556bce4..836095cc 100644
--- a/docs/reference/runtime_config.md
+++ b/docs/reference/runtime_config.md
@@ -179,8 +179,9 @@ The SPU runtime configuration.
| Field | Type | Description |
| ----- | ---- | ----------- |
| server_host | [ string](#string) | TrustedThirdParty beaver server's remote ip:port or load-balance uri. |
-| session_id | [ string](#string) | if empty, use link id as session id. |
| adjust_rank | [ int32](#int32) | which rank do adjust rpc call, usually choose the rank closer to the server. |
+| asym_crypto_schema | [ string](#string) | asym_crypto_schema: support ["SM2"] Will support 25519 in the future, after yacl supported it. |
+| server_public_key | [ bytes](#bytes) | server's public key |
diff --git a/libspu/compiler/common/compilation_context.cc b/libspu/compiler/common/compilation_context.cc
index ef22871c..553fe913 100644
--- a/libspu/compiler/common/compilation_context.cc
+++ b/libspu/compiler/common/compilation_context.cc
@@ -23,7 +23,7 @@ namespace {
void SPUErrorHandler(void * /*use_data*/, const char *reason,
bool /*gen_crash_diag*/) {
- SPU_THROW(reason);
+ SPU_THROW("{}", reason);
}
} // namespace
diff --git a/libspu/compiler/common/ir_printer_config.cc b/libspu/compiler/common/ir_printer_config.cc
index 47dff64a..f3c15ff2 100644
--- a/libspu/compiler/common/ir_printer_config.cc
+++ b/libspu/compiler/common/ir_printer_config.cc
@@ -51,6 +51,7 @@ void IRPrinterConfig::printBeforeIfEnabled(Pass *pass, Operation *,
if (ec.value() != 0) {
spdlog::error("Open file {} failed, error = {}", file_name.c_str(),
ec.message());
+ return;
}
print_callback(f);
}
@@ -64,6 +65,7 @@ void IRPrinterConfig::printAfterIfEnabled(Pass *pass, Operation *,
if (ec.value() != 0) {
spdlog::error("Open file {} failed, error = {}", file_name.c_str(),
ec.message());
+ return;
}
print_callback(f);
}
diff --git a/libspu/compiler/front_end/fe.cc b/libspu/compiler/front_end/fe.cc
index 682f1c5b..e560c772 100644
--- a/libspu/compiler/front_end/fe.cc
+++ b/libspu/compiler/front_end/fe.cc
@@ -54,6 +54,8 @@ mlir::OwningOpRef FE::doit(const CompilationSource &source) {
module = mlir::parseSourceString(source.ir_txt(),
ctx_->getMLIRContext());
+ SPU_ENFORCE(module, "MLIR parser failure");
+
// Convert stablehlo to mhlo first
mlir::PassManager pm(ctx_->getMLIRContext());
pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass());
diff --git a/libspu/compiler/front_end/hlo_importer.cc b/libspu/compiler/front_end/hlo_importer.cc
index 5845af57..363bd76c 100644
--- a/libspu/compiler/front_end/hlo_importer.cc
+++ b/libspu/compiler/front_end/hlo_importer.cc
@@ -196,12 +196,12 @@ HloImporter::parseXlaModuleFromString(const std::string &content) {
auto module_config =
xla::HloModule::CreateModuleConfigFromProto(hlo_module, debug_options);
if (!module_config.status().ok()) {
- SPU_THROW(module_config.status().message());
+ SPU_THROW("{}", module_config.status().message());
}
auto module = xla::HloModule::CreateFromProto(hlo_module, *module_config);
if (!module.status().ok()) {
- SPU_THROW(module.status().message());
+ SPU_THROW("{}", module.status().message());
}
xla::runHloPasses((*module).get());
@@ -214,7 +214,7 @@ HloImporter::parseXlaModuleFromString(const std::string &content) {
auto status = importer.Import(**module);
if (!status.ok()) {
- SPU_THROW(status.message());
+ SPU_THROW("{}", status.message());
}
return mlir_hlo;
diff --git a/libspu/device/api.cc b/libspu/device/api.cc
index 5535ad85..1b27f881 100644
--- a/libspu/device/api.cc
+++ b/libspu/device/api.cc
@@ -229,7 +229,7 @@ void printProfilingData(spu::SPUContext *sctx, const std::string &name,
void SPUErrorHandler(void *use_data, const char *reason, bool gen_crash_diag) {
(void)use_data;
(void)gen_crash_diag;
- SPU_THROW(reason);
+ SPU_THROW("{}", reason);
}
std::mutex ErrorHandlerMutex;
diff --git a/libspu/kernel/hal/permute.cc b/libspu/kernel/hal/permute.cc
index 657b81fc..1c67c997 100644
--- a/libspu/kernel/hal/permute.cc
+++ b/libspu/kernel/hal/permute.cc
@@ -19,6 +19,7 @@
#include "libspu/core/bit_utils.h"
#include "libspu/core/context.h"
#include "libspu/core/trace.h"
+#include "libspu/core/vectorize.h"
#include "libspu/kernel/hal/constants.h"
#include "libspu/kernel/hal/polymorphic.h"
#include "libspu/kernel/hal/prot_wrapper.h"
@@ -43,6 +44,12 @@ inline bool _has_same_owner(const Value &x, const Value &y) {
return _get_owner(x) == _get_owner(y);
}
+void _hint_nbits(const Value &a, size_t nbits) {
+ if (a.storage_type().isa()) {
+ const_cast(a.storage_type()).as()->setNbits(nbits);
+ }
+}
+
// generate inverse permutation
Index _inverse_index(const Index &p) {
Index q(p.size());
@@ -531,20 +538,29 @@ spu::Value _opt_apply_perm_ss(SPUContext *ctx, const spu::Value &perm,
std::vector _bit_decompose(SPUContext *ctx, const spu::Value &x,
int64_t valid_bits) {
auto x_bshare = _prefer_b(ctx, x);
- const auto k1 = _constant(ctx, 1U, x.shape());
- std::vector rets;
size_t nbits = valid_bits != -1
? static_cast(valid_bits)
: x_bshare.storage_type().as()->nbits();
- rets.reserve(nbits);
+ _hint_nbits(x_bshare, nbits);
+ if (ctx->hasKernel("b2a_disassemble")) {
+ auto ret =
+ dynDispatch>(ctx, "b2a_disassemble", x_bshare);
+ return ret;
+ }
+
+ const auto k1 = _constant(ctx, 1U, x.shape());
+ std::vector rets_b;
+ rets_b.reserve(nbits);
for (size_t bit = 0; bit < nbits; ++bit) {
auto x_bshare_shift = right_shift_logical(ctx, x_bshare, bit);
- auto lowest_bit = _and(ctx, x_bshare_shift, k1);
- rets.emplace_back(_prefer_a(ctx, lowest_bit));
+ rets_b.push_back(_and(ctx, x_bshare_shift, k1));
}
- return rets;
+ std::vector rets_a;
+ vmap(rets_b.begin(), rets_b.end(), std::back_inserter(rets_a),
+ [&](const Value &x) { return _prefer_a(ctx, x); });
+ return rets_a;
}
// Generate vector of bit decomposition of sorting keys
diff --git a/libspu/mpc/cheetah/boolean_semi2k.cc b/libspu/mpc/cheetah/boolean_semi2k.cc
index a64da4cc..786d0c38 100644
--- a/libspu/mpc/cheetah/boolean_semi2k.cc
+++ b/libspu/mpc/cheetah/boolean_semi2k.cc
@@ -81,8 +81,8 @@ NdArrayRef P2B::proc(KernelEvalContext* ctx, const NdArrayRef& in) const {
if (comm->getRank() == 0) {
ring_xor_(x, in);
}
-
- return makeBShare(x, field, getNumBits(in));
+ auto nbits = getNumBits(in) == 0 ? 1 : getNumBits(in);
+ return makeBShare(x, field, nbits);
}
NdArrayRef AndBP::proc(KernelEvalContext* ctx, const NdArrayRef& lhs,
diff --git a/libspu/mpc/kernel.cc b/libspu/mpc/kernel.cc
index 45bed576..90d51818 100644
--- a/libspu/mpc/kernel.cc
+++ b/libspu/mpc/kernel.cc
@@ -233,6 +233,17 @@ void ConcateKernel::evaluate(KernelEvalContext* ctx) const {
ctx->pushOutput(WrapValue(z));
}
+void DisassembleKernel::evaluate(KernelEvalContext* ctx) const {
+ const auto& in = ctx->getParam(0);
+ auto z = proc(ctx, UnwrapValue(in));
+
+ std::vector wrapped(z.size());
+ for (size_t idx = 0; idx < z.size(); ++idx) {
+ wrapped[idx] = WrapValue(z[idx]);
+ }
+ ctx->pushOutput(wrapped);
+};
+
void OramOneHotKernel::evaluate(KernelEvalContext* ctx) const {
auto target = ctx->getParam(0);
auto s = ctx->getParam(1);
diff --git a/libspu/mpc/kernel.h b/libspu/mpc/kernel.h
index 4a4c29d8..12383391 100644
--- a/libspu/mpc/kernel.h
+++ b/libspu/mpc/kernel.h
@@ -217,4 +217,12 @@ class ConcateKernel : public Kernel {
int64_t axis) const = 0;
};
+class DisassembleKernel : public Kernel {
+ public:
+ void evaluate(KernelEvalContext* ctx) const override;
+
+ virtual std::vector proc(KernelEvalContext* ctx,
+ const NdArrayRef& in) const = 0;
+};
+
} // namespace spu::mpc
diff --git a/libspu/mpc/semi2k/conversion.cc b/libspu/mpc/semi2k/conversion.cc
index 35d8c701..dce1857f 100644
--- a/libspu/mpc/semi2k/conversion.cc
+++ b/libspu/mpc/semi2k/conversion.cc
@@ -42,6 +42,26 @@ static NdArrayRef wrap_and_bb(SPUContext* ctx, const NdArrayRef& x,
return UnwrapValue(and_bb(ctx, WrapValue(x), WrapValue(y)));
}
+// TODO: Move to some common place
+PtType getBacktype(size_t nbits) {
+ if (nbits <= 8) {
+ return PT_U8;
+ }
+ if (nbits <= 16) {
+ return PT_U16;
+ }
+ if (nbits <= 32) {
+ return PT_U32;
+ }
+ if (nbits <= 64) {
+ return PT_U64;
+ }
+ if (nbits <= 128) {
+ return PT_U128;
+ }
+ SPU_THROW("invalid number of bits={}", nbits);
+}
+
NdArrayRef A2B::proc(KernelEvalContext* ctx, const NdArrayRef& x) const {
const auto field = x.eltype().as()->field();
auto* comm = ctx->getState();
@@ -90,6 +110,9 @@ NdArrayRef B2A::proc(KernelEvalContext* ctx, const NdArrayRef& x) const {
return r_a;
}
+// TODO(jimi): pack {numel * nbits} to fully make use of undelying storage to
+// save communications. If implemented, B2A_Disassemble kernel is also no longer
+// needed
NdArrayRef B2A_Randbit::proc(KernelEvalContext* ctx,
const NdArrayRef& x) const {
const auto field = x.eltype().as()->field();
@@ -105,6 +128,7 @@ NdArrayRef B2A_Randbit::proc(KernelEvalContext* ctx,
const auto numel = x.numel();
const auto rand_numel = numel * static_cast(nbits);
+ const PtType backtype = getBacktype(nbits);
auto randbits = beaver->RandBit(field, rand_numel);
SPU_ENFORCE(static_cast(randbits.size()) ==
@@ -119,32 +143,125 @@ NdArrayRef B2A_Randbit::proc(KernelEvalContext* ctx,
// algorithm begins.
// Ref: III.D @ https://eprint.iacr.org/2019/599.pdf (SPDZ-2K primitives)
- std::vector x_xor_r(numel);
-
- pforeach(0, numel, [&](int64_t idx) {
- // use _r[i*nbits, (i+1)*nbits) to construct rb[i]
- U mask = 0;
- for (int64_t bit = 0; bit < nbits; ++bit) {
- mask += (_randbits[idx * nbits + bit] & 0x1) << bit;
- }
- x_xor_r[idx] = _x[idx] ^ mask;
+ DISPATCH_UINT_PT_TYPES(backtype, "_", [&]() {
+ using V = ScalarT;
+ std::vector x_xor_r(numel);
+
+ pforeach(0, numel, [&](int64_t idx) {
+ // use _r[i*nbits, (i+1)*nbits) to construct rb[i]
+ V mask = 0;
+ for (int64_t bit = 0; bit < nbits; ++bit) {
+ mask += (static_cast(_randbits[idx * nbits + bit]) & 0x1) << bit;
+ }
+ x_xor_r[idx] = _x[idx] ^ mask;
+ });
+
+ // open c = x ^ r
+ x_xor_r = comm->allReduce(x_xor_r, "open(x^r)");
+
+ NdArrayView _res(res);
+ pforeach(0, numel, [&](int64_t idx) {
+ _res[idx] = 0;
+ for (int64_t bit = 0; bit < nbits; bit++) {
+ auto c_i = static_cast(x_xor_r[idx] >> bit) & 0x1;
+ if (comm->getRank() == 0) {
+ _res[idx] += (c_i + (1 - c_i * 2) * _randbits[idx * nbits + bit])
+ << bit;
+ } else {
+ _res[idx] += ((1 - c_i * 2) * _randbits[idx * nbits + bit]) << bit;
+ }
+ }
+ });
});
+ });
+
+ return res;
+}
- // open c = x ^ r
- x_xor_r = comm->allReduce(x_xor_r, "open(x^r)");
-
- NdArrayView _res(res);
- pforeach(0, numel, [&](int64_t idx) {
- _res[idx] = 0;
- for (int64_t bit = 0; bit < nbits; bit++) {
- auto c_i = (x_xor_r[idx] >> bit) & 0x1;
- if (comm->getRank() == 0) {
- _res[idx] += (c_i + (1 - c_i * 2) * _randbits[idx * nbits + bit])
- << bit;
- } else {
- _res[idx] += ((1 - c_i * 2) * _randbits[idx * nbits + bit]) << bit;
+// Reference:
+// III.D @ https://eprint.iacr.org/2019/599.pdf (SPDZ-2K primitives)
+//
+// Analysis:
+// Online Latency: 1 (x_xor_r reveal)
+// Communication: one element bits for one element
+// Vectorization: yes
+//
+// HighLevel Intuition:
+// Since: X = sum: Xi * 2^i
+// If we have A, then we can construct A = sum: A * 2^i.
+//
+// The problem is that we only have B in hand. Details for how to
+// construct A from B:
+// - trusted third party choose a random bit r, where r == 0 or r == 1.
+// - trusted third party send A to parties
+// - parties compute B from A
+// - parties xor_open c = Xi ^ r = open(B ^ B), Xi is still safe due
+// to protection from r.
+// - parties compute: = c + (1-2c)*
+// A = 1 - A if c == 1, i.e. Xi != r
+// A = A if c == 0, i.e. Xi == r
+// i.e. A = c + (1-2c) * A
+//
+// Online Communication:
+// = 1 (xor open)
+
+// Disassemble BShr to AShr bit-by-bit
+// Input: BShr
+// Return: a vector of k AShr, k is the valid bits of BShr
+std::vector B2A_Disassemble::proc(KernelEvalContext* ctx,
+ const NdArrayRef& x) const {
+ const auto field = x.eltype().as()->field();
+ auto* comm = ctx->getState();
+ auto* beaver = ctx->getState()->beaver();
+
+ const int64_t nbits = x.eltype().as()->nbits();
+ SPU_ENFORCE((size_t)nbits > 0 && (size_t)nbits <= SizeOf(field) * 8,
+ "invalid nbits={}", nbits);
+
+ const auto numel = x.numel();
+ const auto rand_numel = numel * static_cast(nbits);
+ const PtType backtype = getBacktype(nbits);
+
+ auto randbits = beaver->RandBit(field, rand_numel);
+
+ std::vector res;
+ res.reserve(nbits);
+ for (int64_t idx = 0; idx < nbits; ++idx) {
+ res.emplace_back(makeType(field), x.shape());
+ }
+ DISPATCH_ALL_FIELDS(field, "_", [&]() {
+ using U = ring2k_t;
+
+ absl::Span _randbits(randbits.data(), rand_numel);
+ NdArrayView _x(x);
+
+ DISPATCH_UINT_PT_TYPES(backtype, "_", [&]() {
+ using V = ScalarT;
+ std::vector x_xor_r(numel);
+
+ pforeach(0, numel, [&](int64_t idx) {
+ // use _r[i*nbits, (i+1)*nbits) to construct rb[i]
+ V mask = 0;
+ for (int64_t bit = 0; bit < nbits; ++bit) {
+ mask += (static_cast(_randbits[idx * nbits + bit]) & 0x1) << bit;
}
- }
+ x_xor_r[idx] = _x[idx] ^ mask;
+ });
+
+ // open c = x ^ r
+ x_xor_r = comm->allReduce(x_xor_r, "open(x^r)");
+
+ pforeach(0, numel, [&](int64_t idx) {
+ pforeach(0, nbits, [&](int64_t bit) {
+ NdArrayView _res(res[bit]);
+ auto c_i = static_cast(x_xor_r[idx] >> bit) & 0x1;
+ if (comm->getRank() == 0) {
+ _res[idx] = (c_i + (1 - c_i * 2) * _randbits[idx * nbits + bit]);
+ } else {
+ _res[idx] = ((1 - c_i * 2) * _randbits[idx * nbits + bit]);
+ }
+ });
+ });
});
});
diff --git a/libspu/mpc/semi2k/conversion.h b/libspu/mpc/semi2k/conversion.h
index bc249998..891a23cd 100644
--- a/libspu/mpc/semi2k/conversion.h
+++ b/libspu/mpc/semi2k/conversion.h
@@ -73,6 +73,21 @@ class B2A_Randbit : public UnaryKernel {
NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& x) const override;
};
+class B2A_Disassemble : public DisassembleKernel {
+ public:
+ static constexpr char kBindName[] = "b2a_disassemble";
+
+ ce::CExpr latency() const override { return ce::Const(1); }
+
+ ce::CExpr comm() const override {
+ return ce::K() * (ce::N() - 1) // Open bit masked value
+ ;
+ }
+
+ std::vector proc(KernelEvalContext* ctx,
+ const NdArrayRef& x) const override;
+};
+
// Note: current only for 2PC.
class MsbA2B : public UnaryKernel {
public:
diff --git a/libspu/mpc/semi2k/protocol.cc b/libspu/mpc/semi2k/protocol.cc
index a7d3b7e1..3acd8344 100644
--- a/libspu/mpc/semi2k/protocol.cc
+++ b/libspu/mpc/semi2k/protocol.cc
@@ -50,21 +50,23 @@ void regSemi2kProtocol(SPUContext* ctx,
ctx->prot()->addState(ctx->config(), lctx);
ctx->prot()
->regKernel<
- semi2k::P2A, semi2k::A2P, semi2k::A2V, semi2k::V2A, //
- semi2k::NotA, //
- semi2k::AddAP, semi2k::AddAA, //
- semi2k::MulAP, semi2k::MulAA, semi2k::SquareA, //
- semi2k::MatMulAP, semi2k::MatMulAA, //
- semi2k::LShiftA, semi2k::LShiftB, semi2k::RShiftB,
- semi2k::ARShiftB, //
- semi2k::CommonTypeB, semi2k::CommonTypeV, semi2k::CastTypeB, //
- semi2k::B2P, semi2k::P2B, semi2k::A2B, semi2k::B2A_Randbit, //
- semi2k::AndBP, semi2k::AndBB, semi2k::XorBP, semi2k::XorBB,
- semi2k::BitrevB, //
- semi2k::BitIntlB, semi2k::BitDeintlB, //
- semi2k::RandA, semi2k::RandPermM, semi2k::PermAM, semi2k::PermAP,
- semi2k::InvPermAM, semi2k::InvPermAP, semi2k::InvPermAV, //
- semi2k::EqualAA, semi2k::EqualAP, semi2k::BeaverCacheKernel>();
+ semi2k::P2A, semi2k::A2P, semi2k::A2V, semi2k::V2A, //
+ semi2k::NotA, //
+ semi2k::AddAP, semi2k::AddAA, //
+ semi2k::MulAP, semi2k::MulAA, semi2k::SquareA, //
+ semi2k::MatMulAP, semi2k::MatMulAA, //
+ semi2k::LShiftA, semi2k::LShiftB, semi2k::RShiftB, //
+ semi2k::ARShiftB, //
+ semi2k::CommonTypeB, semi2k::CommonTypeV, semi2k::CastTypeB, //
+ semi2k::B2P, semi2k::P2B, //
+ semi2k::A2B, semi2k::B2A_Randbit, semi2k::B2A_Disassemble, //
+ semi2k::AndBP, semi2k::AndBB, semi2k::XorBP, semi2k::XorBB, //
+ semi2k::BitrevB, //
+ semi2k::BitIntlB, semi2k::BitDeintlB, //
+ semi2k::RandA, semi2k::RandPermM, semi2k::PermAM, semi2k::PermAP, //
+ semi2k::InvPermAM, semi2k::InvPermAP, semi2k::InvPermAV, //
+ semi2k::EqualAA, semi2k::EqualAP, //
+ semi2k::BeaverCacheKernel>();
if (ctx->config().trunc_allow_msb_error()) {
ctx->prot()->regKernel();
diff --git a/requirements.txt b/requirements.txt
index bea5a7d1..ac8fd6cc 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,8 +1,8 @@
grpcio>=1.42.0,!=1.48.0
-numpy>=1.22.0, < 2
+numpy>=1.22.0
protobuf>=4, <5
cloudpickle>=2.0.0
multiprocess>=0.70.12.2
cachetools>=5.0.0
-jax[cpu]>=0.4.16, <=0.4.26 # FIXME
+jax[cpu]>=0.4.16, <=0.4.26 # FIXME: Jax 0.4.26+ select perf issue
termcolor>=2.0.0
diff --git a/spu/tests/BUILD.bazel b/spu/tests/BUILD.bazel
index d7453e60..cfd029b8 100644
--- a/spu/tests/BUILD.bazel
+++ b/spu/tests/BUILD.bazel
@@ -87,6 +87,7 @@ py_test(
py_test(
name = "jnp_cheetah_r64_test",
+ size = "enormous",
timeout = "long",
srcs = ["jnp_cheetah_r64_test.py"],
deps = [
@@ -96,7 +97,7 @@ py_test(
py_test(
name = "jnp_cheetah_r64_test_x64",
- timeout = "long",
+ size = "enormous",
srcs = ["jnp_cheetah_r64_test.py"],
env = {
"ENABLE_X64_TEST": "1",
diff --git a/spu/tests/jnp_cheetah_r64_test.py b/spu/tests/jnp_cheetah_r64_test.py
index b113dbcd..10c02a53 100644
--- a/spu/tests/jnp_cheetah_r64_test.py
+++ b/spu/tests/jnp_cheetah_r64_test.py
@@ -22,8 +22,7 @@
from spu.tests.jnp_testbase import JnpTests
-@unittest.skip("too slow, last run succeed")
-class JnpTestAby3FM64(JnpTests.JnpTestBase):
+class JnpTestCheetahFM64(JnpTests.JnpTestBase):
def setUp(self):
self._sim = ppsim.Simulator.simple(
2, spu_pb2.ProtocolKind.CHEETAH, spu_pb2.FieldType.FM64
diff --git a/spu/utils/distributed_impl.py b/spu/utils/distributed_impl.py
index 79b13426..ebd6a703 100644
--- a/spu/utils/distributed_impl.py
+++ b/spu/utils/distributed_impl.py
@@ -723,7 +723,10 @@ def mock_parameters(obj: Union[SPU.Object, np.ndarray]):
fn_name = repr(fn)
- import jax.extend.linear_util as lu
+ try:
+ import jax.extend.linear_util as lu
+ except ImportError:
+ import jax.linear_util as lu # fallback
from jax._src import api_util as japi_util
from jax.tree_util import tree_map, tree_flatten
diff --git a/spu/utils/frontend.py b/spu/utils/frontend.py
index e1a4c24c..fc20cd0c 100644
--- a/spu/utils/frontend.py
+++ b/spu/utils/frontend.py
@@ -115,13 +115,32 @@ def _jax_compilation(
register_backend_factory('interpreter', xla_back, priority=-100)
- fn, kwargs = _argnames_partial_except(fn, static_argnames, kwargs)
+ jax_version = jax.__version_info__
+
+ if jax_version[0] > 1 or jax_version[1] > 4 or jax_version[2] > 29:
+ # xla_computation is deprecated since 0.4.30, move to new api
+ lowered = (
+ jax.jit(
+ fn,
+ static_argnums=static_argnums,
+ static_argnames=static_argnames,
+ keep_unused=True,
+ )
+ .trace(*args, **kwargs)
+ .lower(lowering_platforms=('interpreter',))
+ )
+ return (
+ lowered.compiler_ir('hlo').as_serialized_hlo_module_proto(),
+ lowered.out_info,
+ )
+ else:
+ fn, kwargs = _argnames_partial_except(fn, static_argnames, kwargs)
- cfn, output = jax.xla_computation(
- fn, return_shape=True, static_argnums=static_argnums, backend="interpreter"
- )(*args, **kwargs)
+ cfn, output = jax.xla_computation(
+ fn, return_shape=True, static_argnums=static_argnums, backend="interpreter"
+ )(*args, **kwargs)
- return cfn.as_serialized_hlo_module_proto(), output
+ return cfn.as_serialized_hlo_module_proto(), output
## Frontend patches
diff --git a/spu/utils/simulation.py b/spu/utils/simulation.py
index 8df6185c..592ccaff 100644
--- a/spu/utils/simulation.py
+++ b/spu/utils/simulation.py
@@ -17,7 +17,11 @@
from typing import Callable
import jax
-import jax.extend.linear_util as jax_lu # Moved in jax 0.4.16
+
+try:
+ import jax.extend.linear_util as jax_lu
+except ImportError:
+ import jax.linear_util as jax_lu # fallback
import jax.numpy as jnp
import numpy as np
from jax._src import api_util as japi_util