+
#### Operands:
| Operand | Description |
| :-----: | ----------- |
-| `lhs` | statically shaped tensor of pred (AKA boolean or 1-bit integer) or Secret of 1-bit signless integer values or 8/16/32/64-bit signless integer or Secret of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or 8/16/32/64-bit unsigned integer or Secret of 8/16/32/64-bit unsigned integer values or 16-bit float or 32-bit float or 64-bit float or Secret of 16-bit float or 32-bit float or 64-bit float values or complex type with 32-bit float or 64-bit float elements or Secret of complex type with 32-bit float or 64-bit float elements values values
-| `rhs` | statically shaped tensor of pred (AKA boolean or 1-bit integer) or Secret of 1-bit signless integer values or 8/16/32/64-bit signless integer or Secret of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or 8/16/32/64-bit unsigned integer or Secret of 8/16/32/64-bit unsigned integer values or 16-bit float or 32-bit float or 64-bit float or Secret of 16-bit float or 32-bit float or 64-bit float values or complex type with 32-bit float or 64-bit float elements or Secret of complex type with 32-bit float or 64-bit float elements values values
+| `operand` | statically shaped tensor of 8/16/32/64-bit signless integer or Secret of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or 8/16/32/64-bit unsigned integer or Secret of 8/16/32/64-bit unsigned integer values values
#### Results:
| Result | Description |
| :----: | ----------- |
-| `result` | statically shaped tensor of pred (AKA boolean or 1-bit integer) or Secret of 1-bit signless integer values or 8/16/32/64-bit signless integer or Secret of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or 8/16/32/64-bit unsigned integer or Secret of 8/16/32/64-bit unsigned integer values or 16-bit float or 32-bit float or 64-bit float or Secret of 16-bit float or 32-bit float or 64-bit float values or complex type with 32-bit float or 64-bit float elements or Secret of complex type with 32-bit float or 64-bit float elements values values
+| `result` | statically shaped tensor of 8/16/32/64-bit signless integer or Secret of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or 8/16/32/64-bit unsigned integer or Secret of 8/16/32/64-bit unsigned integer values values
-### `pphlo.prefer_a` (spu::pphlo::PreferAOp)
+### `pphlo.power` (spu::pphlo::PowOp)
-_Prefer AShare operator_
+_Power operator_
Syntax:
```
-operation ::= `pphlo.prefer_a` $operand attr-dict `:` custom(type($operand), type($result))
+operation ::= `pphlo.power` $lhs `,` $rhs attr-dict
+ `:` custom(type($lhs), type($rhs), type($result))
```
-Convert input to AShare if possible.
+Performs element-wise exponentiation of `lhs` tensor by `rhs` tensor and produces a `result` tensor.
-Traits: `AlwaysSpeculatableImplTrait`, `Elementwise`, `SameOperandsAndResultShape`, `SameOperandsAndResultType`
+Ref https://github.com/openxla/stablehlo/blob/main/docs/spec.md#power
+
+Traits: `AlwaysSpeculatableImplTrait`, `Elementwise`, `SameOperandsAndResultShape`
Interfaces: `ConditionallySpeculatable`, `InferShapedTypeOpInterface`, `InferTypeOpInterface`, `NoMemoryEffect (MemoryEffectOpInterface)`
@@ -1684,7 +1692,8 @@ Effects: `MemoryEffects::Effect{}`
| Operand | Description |
| :-----: | ----------- |
-| `operand` | statically shaped tensor of pred (AKA boolean or 1-bit integer) or Secret of 1-bit signless integer values or 8/16/32/64-bit signless integer or Secret of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or 8/16/32/64-bit unsigned integer or Secret of 8/16/32/64-bit unsigned integer values or 16-bit float or 32-bit float or 64-bit float or Secret of 16-bit float or 32-bit float or 64-bit float values or complex type with 32-bit float or 64-bit float elements or Secret of complex type with 32-bit float or 64-bit float elements values values
+| `lhs` | statically shaped tensor of pred (AKA boolean or 1-bit integer) or Secret of 1-bit signless integer values or 8/16/32/64-bit signless integer or Secret of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or 8/16/32/64-bit unsigned integer or Secret of 8/16/32/64-bit unsigned integer values or 16-bit float or 32-bit float or 64-bit float or Secret of 16-bit float or 32-bit float or 64-bit float values or complex type with 32-bit float or 64-bit float elements or Secret of complex type with 32-bit float or 64-bit float elements values values
+| `rhs` | statically shaped tensor of pred (AKA boolean or 1-bit integer) or Secret of 1-bit signless integer values or 8/16/32/64-bit signless integer or Secret of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or 8/16/32/64-bit unsigned integer or Secret of 8/16/32/64-bit unsigned integer values or 16-bit float or 32-bit float or 64-bit float or Secret of 16-bit float or 32-bit float or 64-bit float values or complex type with 32-bit float or 64-bit float elements or Secret of complex type with 32-bit float or 64-bit float elements values values
#### Results:
@@ -2270,12 +2279,21 @@ Returns the sign of the `operand` element-wise and produces a `result` tensor.
Ref https://github.com/openxla/stablehlo/blob/main/docs/spec.md#sign
+PPHLO Extension: when `ignore_zero` is set to true, sign does not enforce sign(0) to 0
+
Traits: `AlwaysSpeculatableImplTrait`, `Elementwise`, `SameOperandsAndResultShape`, `SameOperandsAndResultType`
Interfaces: `ConditionallySpeculatable`, `InferShapedTypeOpInterface`, `InferTypeOpInterface`, `NoMemoryEffect (MemoryEffectOpInterface)`
Effects: `MemoryEffects::Effect{}`
+#### Attributes:
+
+
+
Attribute
MLIR Type
Description
+
ignore_zero
::mlir::BoolAttr
bool attribute
+
+
#### Operands:
| Operand | Description |
@@ -2377,7 +2395,7 @@ Ref https://github.com/openxla/stablehlo/blob/main/docs/spec.md#slice
Traits: `AlwaysSpeculatableImplTrait`, `SameOperandsAndResultElementType`
-Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`
+Interfaces: `ConditionallySpeculatable`, `InferTypeOpInterface`, `NoMemoryEffect (MemoryEffectOpInterface)`
Effects: `MemoryEffects::Effect{}`
@@ -2551,7 +2569,7 @@ Ref https://github.com/openxla/stablehlo/blob/main/docs/spec.md#transpose
Traits: `AlwaysSpeculatableImplTrait`, `SameOperandsAndResultElementType`
-Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`
+Interfaces: `ConditionallySpeculatable`, `InferTypeOpInterface`, `NoMemoryEffect (MemoryEffectOpInterface)`
Effects: `MemoryEffects::Effect{}`
diff --git a/docs/reference/runtime_config.md b/docs/reference/runtime_config.md
index f556bce4..836095cc 100644
--- a/docs/reference/runtime_config.md
+++ b/docs/reference/runtime_config.md
@@ -179,8 +179,9 @@ The SPU runtime configuration.
| Field | Type | Description |
| ----- | ---- | ----------- |
| server_host | [ string](#string) | TrustedThirdParty beaver server's remote ip:port or load-balance uri. |
-| session_id | [ string](#string) | if empty, use link id as session id. |
| adjust_rank | [ int32](#int32) | which rank do adjust rpc call, usually choose the rank closer to the server. |
+| asym_crypto_schema | [ string](#string) | asym_crypto_schema: support ["SM2"] Will support 25519 in the future, after yacl supported it. |
+| server_public_key | [ bytes](#bytes) | server's public key |
diff --git a/libspu/compiler/common/compilation_context.cc b/libspu/compiler/common/compilation_context.cc
index ef22871c..553fe913 100644
--- a/libspu/compiler/common/compilation_context.cc
+++ b/libspu/compiler/common/compilation_context.cc
@@ -23,7 +23,7 @@ namespace {
void SPUErrorHandler(void * /*use_data*/, const char *reason,
bool /*gen_crash_diag*/) {
- SPU_THROW(reason);
+ SPU_THROW("{}", reason);
}
} // namespace
diff --git a/libspu/compiler/common/ir_printer_config.cc b/libspu/compiler/common/ir_printer_config.cc
index 47dff64a..f3c15ff2 100644
--- a/libspu/compiler/common/ir_printer_config.cc
+++ b/libspu/compiler/common/ir_printer_config.cc
@@ -51,6 +51,7 @@ void IRPrinterConfig::printBeforeIfEnabled(Pass *pass, Operation *,
if (ec.value() != 0) {
spdlog::error("Open file {} failed, error = {}", file_name.c_str(),
ec.message());
+ return;
}
print_callback(f);
}
@@ -64,6 +65,7 @@ void IRPrinterConfig::printAfterIfEnabled(Pass *pass, Operation *,
if (ec.value() != 0) {
spdlog::error("Open file {} failed, error = {}", file_name.c_str(),
ec.message());
+ return;
}
print_callback(f);
}
diff --git a/libspu/compiler/front_end/fe.cc b/libspu/compiler/front_end/fe.cc
index 682f1c5b..e560c772 100644
--- a/libspu/compiler/front_end/fe.cc
+++ b/libspu/compiler/front_end/fe.cc
@@ -54,6 +54,8 @@ mlir::OwningOpRef FE::doit(const CompilationSource &source) {
module = mlir::parseSourceString(source.ir_txt(),
ctx_->getMLIRContext());
+ SPU_ENFORCE(module, "MLIR parser failure");
+
// Convert stablehlo to mhlo first
mlir::PassManager pm(ctx_->getMLIRContext());
pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass());
diff --git a/libspu/compiler/front_end/hlo_importer.cc b/libspu/compiler/front_end/hlo_importer.cc
index 5845af57..363bd76c 100644
--- a/libspu/compiler/front_end/hlo_importer.cc
+++ b/libspu/compiler/front_end/hlo_importer.cc
@@ -196,12 +196,12 @@ HloImporter::parseXlaModuleFromString(const std::string &content) {
auto module_config =
xla::HloModule::CreateModuleConfigFromProto(hlo_module, debug_options);
if (!module_config.status().ok()) {
- SPU_THROW(module_config.status().message());
+ SPU_THROW("{}", module_config.status().message());
}
auto module = xla::HloModule::CreateFromProto(hlo_module, *module_config);
if (!module.status().ok()) {
- SPU_THROW(module.status().message());
+ SPU_THROW("{}", module.status().message());
}
xla::runHloPasses((*module).get());
@@ -214,7 +214,7 @@ HloImporter::parseXlaModuleFromString(const std::string &content) {
auto status = importer.Import(**module);
if (!status.ok()) {
- SPU_THROW(status.message());
+ SPU_THROW("{}", status.message());
}
return mlir_hlo;
diff --git a/libspu/device/api.cc b/libspu/device/api.cc
index 5535ad85..1b27f881 100644
--- a/libspu/device/api.cc
+++ b/libspu/device/api.cc
@@ -229,7 +229,7 @@ void printProfilingData(spu::SPUContext *sctx, const std::string &name,
void SPUErrorHandler(void *use_data, const char *reason, bool gen_crash_diag) {
(void)use_data;
(void)gen_crash_diag;
- SPU_THROW(reason);
+ SPU_THROW("{}", reason);
}
std::mutex ErrorHandlerMutex;
diff --git a/libspu/kernel/hal/permute.cc b/libspu/kernel/hal/permute.cc
index 657b81fc..1c67c997 100644
--- a/libspu/kernel/hal/permute.cc
+++ b/libspu/kernel/hal/permute.cc
@@ -19,6 +19,7 @@
#include "libspu/core/bit_utils.h"
#include "libspu/core/context.h"
#include "libspu/core/trace.h"
+#include "libspu/core/vectorize.h"
#include "libspu/kernel/hal/constants.h"
#include "libspu/kernel/hal/polymorphic.h"
#include "libspu/kernel/hal/prot_wrapper.h"
@@ -43,6 +44,12 @@ inline bool _has_same_owner(const Value &x, const Value &y) {
return _get_owner(x) == _get_owner(y);
}
+void _hint_nbits(const Value &a, size_t nbits) {
+ if (a.storage_type().isa()) {
+ const_cast(a.storage_type()).as()->setNbits(nbits);
+ }
+}
+
// generate inverse permutation
Index _inverse_index(const Index &p) {
Index q(p.size());
@@ -531,20 +538,29 @@ spu::Value _opt_apply_perm_ss(SPUContext *ctx, const spu::Value &perm,
std::vector _bit_decompose(SPUContext *ctx, const spu::Value &x,
int64_t valid_bits) {
auto x_bshare = _prefer_b(ctx, x);
- const auto k1 = _constant(ctx, 1U, x.shape());
- std::vector rets;
size_t nbits = valid_bits != -1
? static_cast(valid_bits)
: x_bshare.storage_type().as()->nbits();
- rets.reserve(nbits);
+ _hint_nbits(x_bshare, nbits);
+ if (ctx->hasKernel("b2a_disassemble")) {
+ auto ret =
+ dynDispatch>(ctx, "b2a_disassemble", x_bshare);
+ return ret;
+ }
+
+ const auto k1 = _constant(ctx, 1U, x.shape());
+ std::vector rets_b;
+ rets_b.reserve(nbits);
for (size_t bit = 0; bit < nbits; ++bit) {
auto x_bshare_shift = right_shift_logical(ctx, x_bshare, bit);
- auto lowest_bit = _and(ctx, x_bshare_shift, k1);
- rets.emplace_back(_prefer_a(ctx, lowest_bit));
+ rets_b.push_back(_and(ctx, x_bshare_shift, k1));
}
- return rets;
+ std::vector rets_a;
+ vmap(rets_b.begin(), rets_b.end(), std::back_inserter(rets_a),
+ [&](const Value &x) { return _prefer_a(ctx, x); });
+ return rets_a;
}
// Generate vector of bit decomposition of sorting keys
diff --git a/libspu/mpc/cheetah/boolean_semi2k.cc b/libspu/mpc/cheetah/boolean_semi2k.cc
index a64da4cc..786d0c38 100644
--- a/libspu/mpc/cheetah/boolean_semi2k.cc
+++ b/libspu/mpc/cheetah/boolean_semi2k.cc
@@ -81,8 +81,8 @@ NdArrayRef P2B::proc(KernelEvalContext* ctx, const NdArrayRef& in) const {
if (comm->getRank() == 0) {
ring_xor_(x, in);
}
-
- return makeBShare(x, field, getNumBits(in));
+ auto nbits = getNumBits(in) == 0 ? 1 : getNumBits(in);
+ return makeBShare(x, field, nbits);
}
NdArrayRef AndBP::proc(KernelEvalContext* ctx, const NdArrayRef& lhs,
diff --git a/libspu/mpc/kernel.cc b/libspu/mpc/kernel.cc
index 45bed576..90d51818 100644
--- a/libspu/mpc/kernel.cc
+++ b/libspu/mpc/kernel.cc
@@ -233,6 +233,17 @@ void ConcateKernel::evaluate(KernelEvalContext* ctx) const {
ctx->pushOutput(WrapValue(z));
}
+void DisassembleKernel::evaluate(KernelEvalContext* ctx) const {
+ const auto& in = ctx->getParam(0);
+ auto z = proc(ctx, UnwrapValue(in));
+
+ std::vector wrapped(z.size());
+ for (size_t idx = 0; idx < z.size(); ++idx) {
+ wrapped[idx] = WrapValue(z[idx]);
+ }
+ ctx->pushOutput(wrapped);
+};
+
void OramOneHotKernel::evaluate(KernelEvalContext* ctx) const {
auto target = ctx->getParam(0);
auto s = ctx->getParam(1);
diff --git a/libspu/mpc/kernel.h b/libspu/mpc/kernel.h
index 4a4c29d8..12383391 100644
--- a/libspu/mpc/kernel.h
+++ b/libspu/mpc/kernel.h
@@ -217,4 +217,12 @@ class ConcateKernel : public Kernel {
int64_t axis) const = 0;
};
+class DisassembleKernel : public Kernel {
+ public:
+ void evaluate(KernelEvalContext* ctx) const override;
+
+ virtual std::vector proc(KernelEvalContext* ctx,
+ const NdArrayRef& in) const = 0;
+};
+
} // namespace spu::mpc
diff --git a/libspu/mpc/semi2k/conversion.cc b/libspu/mpc/semi2k/conversion.cc
index 35d8c701..dce1857f 100644
--- a/libspu/mpc/semi2k/conversion.cc
+++ b/libspu/mpc/semi2k/conversion.cc
@@ -42,6 +42,26 @@ static NdArrayRef wrap_and_bb(SPUContext* ctx, const NdArrayRef& x,
return UnwrapValue(and_bb(ctx, WrapValue(x), WrapValue(y)));
}
+// TODO: Move to some common place
+PtType getBacktype(size_t nbits) {
+ if (nbits <= 8) {
+ return PT_U8;
+ }
+ if (nbits <= 16) {
+ return PT_U16;
+ }
+ if (nbits <= 32) {
+ return PT_U32;
+ }
+ if (nbits <= 64) {
+ return PT_U64;
+ }
+ if (nbits <= 128) {
+ return PT_U128;
+ }
+ SPU_THROW("invalid number of bits={}", nbits);
+}
+
NdArrayRef A2B::proc(KernelEvalContext* ctx, const NdArrayRef& x) const {
const auto field = x.eltype().as()->field();
auto* comm = ctx->getState();
@@ -90,6 +110,9 @@ NdArrayRef B2A::proc(KernelEvalContext* ctx, const NdArrayRef& x) const {
return r_a;
}
+// TODO(jimi): pack {numel * nbits} to fully make use of undelying storage to
+// save communications. If implemented, B2A_Disassemble kernel is also no longer
+// needed
NdArrayRef B2A_Randbit::proc(KernelEvalContext* ctx,
const NdArrayRef& x) const {
const auto field = x.eltype().as()->field();
@@ -105,6 +128,7 @@ NdArrayRef B2A_Randbit::proc(KernelEvalContext* ctx,
const auto numel = x.numel();
const auto rand_numel = numel * static_cast(nbits);
+ const PtType backtype = getBacktype(nbits);
auto randbits = beaver->RandBit(field, rand_numel);
SPU_ENFORCE(static_cast(randbits.size()) ==
@@ -119,32 +143,125 @@ NdArrayRef B2A_Randbit::proc(KernelEvalContext* ctx,
// algorithm begins.
// Ref: III.D @ https://eprint.iacr.org/2019/599.pdf (SPDZ-2K primitives)
- std::vector x_xor_r(numel);
-
- pforeach(0, numel, [&](int64_t idx) {
- // use _r[i*nbits, (i+1)*nbits) to construct rb[i]
- U mask = 0;
- for (int64_t bit = 0; bit < nbits; ++bit) {
- mask += (_randbits[idx * nbits + bit] & 0x1) << bit;
- }
- x_xor_r[idx] = _x[idx] ^ mask;
+ DISPATCH_UINT_PT_TYPES(backtype, "_", [&]() {
+ using V = ScalarT;
+ std::vector x_xor_r(numel);
+
+ pforeach(0, numel, [&](int64_t idx) {
+ // use _r[i*nbits, (i+1)*nbits) to construct rb[i]
+ V mask = 0;
+ for (int64_t bit = 0; bit < nbits; ++bit) {
+ mask += (static_cast(_randbits[idx * nbits + bit]) & 0x1) << bit;
+ }
+ x_xor_r[idx] = _x[idx] ^ mask;
+ });
+
+ // open c = x ^ r
+ x_xor_r = comm->allReduce(x_xor_r, "open(x^r)");
+
+ NdArrayView _res(res);
+ pforeach(0, numel, [&](int64_t idx) {
+ _res[idx] = 0;
+ for (int64_t bit = 0; bit < nbits; bit++) {
+ auto c_i = static_cast(x_xor_r[idx] >> bit) & 0x1;
+ if (comm->getRank() == 0) {
+ _res[idx] += (c_i + (1 - c_i * 2) * _randbits[idx * nbits + bit])
+ << bit;
+ } else {
+ _res[idx] += ((1 - c_i * 2) * _randbits[idx * nbits + bit]) << bit;
+ }
+ }
+ });
});
+ });
+
+ return res;
+}
- // open c = x ^ r
- x_xor_r = comm->allReduce(x_xor_r, "open(x^r)");
-
- NdArrayView _res(res);
- pforeach(0, numel, [&](int64_t idx) {
- _res[idx] = 0;
- for (int64_t bit = 0; bit < nbits; bit++) {
- auto c_i = (x_xor_r[idx] >> bit) & 0x1;
- if (comm->getRank() == 0) {
- _res[idx] += (c_i + (1 - c_i * 2) * _randbits[idx * nbits + bit])
- << bit;
- } else {
- _res[idx] += ((1 - c_i * 2) * _randbits[idx * nbits + bit]) << bit;
+// Reference:
+// III.D @ https://eprint.iacr.org/2019/599.pdf (SPDZ-2K primitives)
+//
+// Analysis:
+// Online Latency: 1 (x_xor_r reveal)
+// Communication: one element bits for one element
+// Vectorization: yes
+//
+// HighLevel Intuition:
+// Since: X = sum: Xi * 2^i
+// If we have A, then we can construct A = sum: A * 2^i.
+//
+// The problem is that we only have B in hand. Details for how to
+// construct A from B:
+// - trusted third party choose a random bit r, where r == 0 or r == 1.
+// - trusted third party send A to parties
+// - parties compute B from A
+// - parties xor_open c = Xi ^ r = open(B ^ B), Xi is still safe due
+// to protection from r.
+// - parties compute: = c + (1-2c)*
+// A = 1 - A if c == 1, i.e. Xi != r
+// A = A if c == 0, i.e. Xi == r
+// i.e. A = c + (1-2c) * A
+//
+// Online Communication:
+// = 1 (xor open)
+
+// Disassemble BShr to AShr bit-by-bit
+// Input: BShr
+// Return: a vector of k AShr, k is the valid bits of BShr
+std::vector B2A_Disassemble::proc(KernelEvalContext* ctx,
+ const NdArrayRef& x) const {
+ const auto field = x.eltype().as()->field();
+ auto* comm = ctx->getState();
+ auto* beaver = ctx->getState()->beaver();
+
+ const int64_t nbits = x.eltype().as()->nbits();
+ SPU_ENFORCE((size_t)nbits > 0 && (size_t)nbits <= SizeOf(field) * 8,
+ "invalid nbits={}", nbits);
+
+ const auto numel = x.numel();
+ const auto rand_numel = numel * static_cast(nbits);
+ const PtType backtype = getBacktype(nbits);
+
+ auto randbits = beaver->RandBit(field, rand_numel);
+
+ std::vector res;
+ res.reserve(nbits);
+ for (int64_t idx = 0; idx < nbits; ++idx) {
+ res.emplace_back(makeType(field), x.shape());
+ }
+ DISPATCH_ALL_FIELDS(field, "_", [&]() {
+ using U = ring2k_t;
+
+ absl::Span _randbits(randbits.data(), rand_numel);
+ NdArrayView _x(x);
+
+ DISPATCH_UINT_PT_TYPES(backtype, "_", [&]() {
+ using V = ScalarT;
+ std::vector x_xor_r(numel);
+
+ pforeach(0, numel, [&](int64_t idx) {
+ // use _r[i*nbits, (i+1)*nbits) to construct rb[i]
+ V mask = 0;
+ for (int64_t bit = 0; bit < nbits; ++bit) {
+ mask += (static_cast(_randbits[idx * nbits + bit]) & 0x1) << bit;
}
- }
+ x_xor_r[idx] = _x[idx] ^ mask;
+ });
+
+ // open c = x ^ r
+ x_xor_r = comm->allReduce(x_xor_r, "open(x^r)");
+
+ pforeach(0, numel, [&](int64_t idx) {
+ pforeach(0, nbits, [&](int64_t bit) {
+ NdArrayView _res(res[bit]);
+ auto c_i = static_cast(x_xor_r[idx] >> bit) & 0x1;
+ if (comm->getRank() == 0) {
+ _res[idx] = (c_i + (1 - c_i * 2) * _randbits[idx * nbits + bit]);
+ } else {
+ _res[idx] = ((1 - c_i * 2) * _randbits[idx * nbits + bit]);
+ }
+ });
+ });
});
});
diff --git a/libspu/mpc/semi2k/conversion.h b/libspu/mpc/semi2k/conversion.h
index bc249998..891a23cd 100644
--- a/libspu/mpc/semi2k/conversion.h
+++ b/libspu/mpc/semi2k/conversion.h
@@ -73,6 +73,21 @@ class B2A_Randbit : public UnaryKernel {
NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& x) const override;
};
+class B2A_Disassemble : public DisassembleKernel {
+ public:
+ static constexpr char kBindName[] = "b2a_disassemble";
+
+ ce::CExpr latency() const override { return ce::Const(1); }
+
+ ce::CExpr comm() const override {
+ return ce::K() * (ce::N() - 1) // Open bit masked value
+ ;
+ }
+
+ std::vector proc(KernelEvalContext* ctx,
+ const NdArrayRef& x) const override;
+};
+
// Note: current only for 2PC.
class MsbA2B : public UnaryKernel {
public:
diff --git a/libspu/mpc/semi2k/protocol.cc b/libspu/mpc/semi2k/protocol.cc
index a7d3b7e1..3acd8344 100644
--- a/libspu/mpc/semi2k/protocol.cc
+++ b/libspu/mpc/semi2k/protocol.cc
@@ -50,21 +50,23 @@ void regSemi2kProtocol(SPUContext* ctx,
ctx->prot()->addState(ctx->config(), lctx);
ctx->prot()
->regKernel<
- semi2k::P2A, semi2k::A2P, semi2k::A2V, semi2k::V2A, //
- semi2k::NotA, //
- semi2k::AddAP, semi2k::AddAA, //
- semi2k::MulAP, semi2k::MulAA, semi2k::SquareA, //
- semi2k::MatMulAP, semi2k::MatMulAA, //
- semi2k::LShiftA, semi2k::LShiftB, semi2k::RShiftB,
- semi2k::ARShiftB, //
- semi2k::CommonTypeB, semi2k::CommonTypeV, semi2k::CastTypeB, //
- semi2k::B2P, semi2k::P2B, semi2k::A2B, semi2k::B2A_Randbit, //
- semi2k::AndBP, semi2k::AndBB, semi2k::XorBP, semi2k::XorBB,
- semi2k::BitrevB, //
- semi2k::BitIntlB, semi2k::BitDeintlB, //
- semi2k::RandA, semi2k::RandPermM, semi2k::PermAM, semi2k::PermAP,
- semi2k::InvPermAM, semi2k::InvPermAP, semi2k::InvPermAV, //
- semi2k::EqualAA, semi2k::EqualAP, semi2k::BeaverCacheKernel>();
+ semi2k::P2A, semi2k::A2P, semi2k::A2V, semi2k::V2A, //
+ semi2k::NotA, //
+ semi2k::AddAP, semi2k::AddAA, //
+ semi2k::MulAP, semi2k::MulAA, semi2k::SquareA, //
+ semi2k::MatMulAP, semi2k::MatMulAA, //
+ semi2k::LShiftA, semi2k::LShiftB, semi2k::RShiftB, //
+ semi2k::ARShiftB, //
+ semi2k::CommonTypeB, semi2k::CommonTypeV, semi2k::CastTypeB, //
+ semi2k::B2P, semi2k::P2B, //
+ semi2k::A2B, semi2k::B2A_Randbit, semi2k::B2A_Disassemble, //
+ semi2k::AndBP, semi2k::AndBB, semi2k::XorBP, semi2k::XorBB, //
+ semi2k::BitrevB, //
+ semi2k::BitIntlB, semi2k::BitDeintlB, //
+ semi2k::RandA, semi2k::RandPermM, semi2k::PermAM, semi2k::PermAP, //
+ semi2k::InvPermAM, semi2k::InvPermAP, semi2k::InvPermAV, //
+ semi2k::EqualAA, semi2k::EqualAP, //
+ semi2k::BeaverCacheKernel>();
if (ctx->config().trunc_allow_msb_error()) {
ctx->prot()->regKernel();
diff --git a/requirements.txt b/requirements.txt
index bea5a7d1..ac8fd6cc 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,8 +1,8 @@
grpcio>=1.42.0,!=1.48.0
-numpy>=1.22.0, < 2
+numpy>=1.22.0
protobuf>=4, <5
cloudpickle>=2.0.0
multiprocess>=0.70.12.2
cachetools>=5.0.0
-jax[cpu]>=0.4.16, <=0.4.26 # FIXME
+jax[cpu]>=0.4.16, <=0.4.26 # FIXME: Jax 0.4.26+ select perf issue
termcolor>=2.0.0
diff --git a/spu/tests/BUILD.bazel b/spu/tests/BUILD.bazel
index d7453e60..cfd029b8 100644
--- a/spu/tests/BUILD.bazel
+++ b/spu/tests/BUILD.bazel
@@ -87,6 +87,7 @@ py_test(
py_test(
name = "jnp_cheetah_r64_test",
+ size = "enormous",
timeout = "long",
srcs = ["jnp_cheetah_r64_test.py"],
deps = [
@@ -96,7 +97,7 @@ py_test(
py_test(
name = "jnp_cheetah_r64_test_x64",
- timeout = "long",
+ size = "enormous",
srcs = ["jnp_cheetah_r64_test.py"],
env = {
"ENABLE_X64_TEST": "1",
diff --git a/spu/tests/jnp_cheetah_r64_test.py b/spu/tests/jnp_cheetah_r64_test.py
index b113dbcd..10c02a53 100644
--- a/spu/tests/jnp_cheetah_r64_test.py
+++ b/spu/tests/jnp_cheetah_r64_test.py
@@ -22,8 +22,7 @@
from spu.tests.jnp_testbase import JnpTests
-@unittest.skip("too slow, last run succeed")
-class JnpTestAby3FM64(JnpTests.JnpTestBase):
+class JnpTestCheetahFM64(JnpTests.JnpTestBase):
def setUp(self):
self._sim = ppsim.Simulator.simple(
2, spu_pb2.ProtocolKind.CHEETAH, spu_pb2.FieldType.FM64
diff --git a/spu/utils/distributed_impl.py b/spu/utils/distributed_impl.py
index 79b13426..ebd6a703 100644
--- a/spu/utils/distributed_impl.py
+++ b/spu/utils/distributed_impl.py
@@ -723,7 +723,10 @@ def mock_parameters(obj: Union[SPU.Object, np.ndarray]):
fn_name = repr(fn)
- import jax.extend.linear_util as lu
+ try:
+ import jax.extend.linear_util as lu
+ except ImportError:
+ import jax.linear_util as lu # fallback
from jax._src import api_util as japi_util
from jax.tree_util import tree_map, tree_flatten
diff --git a/spu/utils/frontend.py b/spu/utils/frontend.py
index e1a4c24c..fc20cd0c 100644
--- a/spu/utils/frontend.py
+++ b/spu/utils/frontend.py
@@ -115,13 +115,32 @@ def _jax_compilation(
register_backend_factory('interpreter', xla_back, priority=-100)
- fn, kwargs = _argnames_partial_except(fn, static_argnames, kwargs)
+ jax_version = jax.__version_info__
+
+ if jax_version[0] > 1 or jax_version[1] > 4 or jax_version[2] > 29:
+ # xla_computation is deprecated since 0.4.30, move to new api
+ lowered = (
+ jax.jit(
+ fn,
+ static_argnums=static_argnums,
+ static_argnames=static_argnames,
+ keep_unused=True,
+ )
+ .trace(*args, **kwargs)
+ .lower(lowering_platforms=('interpreter',))
+ )
+ return (
+ lowered.compiler_ir('hlo').as_serialized_hlo_module_proto(),
+ lowered.out_info,
+ )
+ else:
+ fn, kwargs = _argnames_partial_except(fn, static_argnames, kwargs)
- cfn, output = jax.xla_computation(
- fn, return_shape=True, static_argnums=static_argnums, backend="interpreter"
- )(*args, **kwargs)
+ cfn, output = jax.xla_computation(
+ fn, return_shape=True, static_argnums=static_argnums, backend="interpreter"
+ )(*args, **kwargs)
- return cfn.as_serialized_hlo_module_proto(), output
+ return cfn.as_serialized_hlo_module_proto(), output
## Frontend patches
diff --git a/spu/utils/simulation.py b/spu/utils/simulation.py
index 8df6185c..592ccaff 100644
--- a/spu/utils/simulation.py
+++ b/spu/utils/simulation.py
@@ -17,7 +17,11 @@
from typing import Callable
import jax
-import jax.extend.linear_util as jax_lu # Moved in jax 0.4.16
+
+try:
+ import jax.extend.linear_util as jax_lu
+except ImportError:
+ import jax.linear_util as jax_lu # fallback
import jax.numpy as jnp
import numpy as np
from jax._src import api_util as japi_util
From 279fc79bff1ca2f8768f2683417e1c7d1556ffaf Mon Sep 17 00:00:00 2001
From: "renovate[bot]" <29139614+renovate[bot]@users.noreply.github.com>
Date: Mon, 8 Jul 2024 15:02:23 +0800
Subject: [PATCH 03/27] chore(deps): update dependency rstcheck to v6.2.4
(#756)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
[![Mend
Renovate](https://app.renovatebot.com/images/banner.svg)](https://renovatebot.com)
This PR contains the following updates:
| Package | Change | Age | Adoption | Passing | Confidence |
|---|---|---|---|---|---|
| [rstcheck](https://togithub.com/rstcheck/rstcheck)
([changelog](https://togithub.com/rstcheck/rstcheck/blob/main/CHANGELOG.md))
| `==6.2.1` -> `==6.2.4` |
[![age](https://developer.mend.io/api/mc/badges/age/pypi/rstcheck/6.2.4?slim=true)](https://docs.renovatebot.com/merge-confidence/)
|
[![adoption](https://developer.mend.io/api/mc/badges/adoption/pypi/rstcheck/6.2.4?slim=true)](https://docs.renovatebot.com/merge-confidence/)
|
[![passing](https://developer.mend.io/api/mc/badges/compatibility/pypi/rstcheck/6.2.1/6.2.4?slim=true)](https://docs.renovatebot.com/merge-confidence/)
|
[![confidence](https://developer.mend.io/api/mc/badges/confidence/pypi/rstcheck/6.2.1/6.2.4?slim=true)](https://docs.renovatebot.com/merge-confidence/)
|
---
### Release Notes
rstcheck/rstcheck (rstcheck)
###
[`v6.2.4`](https://togithub.com/rstcheck/rstcheck/blob/HEAD/CHANGELOG.md#v624-2024-07-07)
[Compare
Source](https://togithub.com/rstcheck/rstcheck/compare/v6.2.3...v6.2.4)
[diff
v6.2.3...v6.2.4](https://togithub.com/rstcheck/rstcheck/compare/v6.2.3...v6.2.4)
##### Documentation
- Add note on how to disable pretty exception output
([#228](https://togithub.com/rstcheck/rstcheck/pull/228))
##### Miscellaneous
- Add help text to `--version` flag
([#228](https://togithub.com/rstcheck/rstcheck/pull/228))
###
[`v6.2.3`](https://togithub.com/rstcheck/rstcheck/blob/HEAD/CHANGELOG.md#v623-2024-07-07)
[Compare
Source](https://togithub.com/rstcheck/rstcheck/compare/v6.2.2...v6.2.3)
[diff
v6.2.2...v6.2.3](https://togithub.com/rstcheck/rstcheck/compare/v6.2.2...v6.2.3)
##### Bugfixes
- Fix typer dependency by removing the `[standard]` extra which is only
used on typer-slim.
Typer by default has the extras included.
###
[`v6.2.2`](https://togithub.com/rstcheck/rstcheck/blob/HEAD/CHANGELOG.md#v622-2024-07-07)
[Compare
Source](https://togithub.com/rstcheck/rstcheck/compare/v6.2.1...v6.2.2)
[diff
v6.2.1...v6.2.2](https://togithub.com/rstcheck/rstcheck/compare/v6.2.1...v6.2.2)
##### Miscellaneous
- Bump min. version of typer and fix dependency group name
([#223](https://togithub.com/rstcheck/rstcheck/issues/223))
- Update configs for dev tooling
([#225](https://togithub.com/rstcheck/rstcheck/pull/225))
- Bump default python version to 3.12
([#225](https://togithub.com/rstcheck/rstcheck/pull/225))
---
### Configuration
📅 **Schedule**: Branch creation - At any time (no schedule defined),
Automerge - At any time (no schedule defined).
🚦 **Automerge**: Disabled by config. Please merge this manually once you
are satisfied.
♻ **Rebasing**: Whenever PR becomes conflicted, or you tick the
rebase/retry checkbox.
🔕 **Ignore**: Close this PR and you won't be reminded about this update
again.
---
- [ ] If you want to rebase/retry this PR, check
this box
---
This PR has been generated by [Mend
Renovate](https://www.mend.io/free-developer-tools/renovate/). View
repository job log
[here](https://developer.mend.io/github/secretflow/spu).
Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com>
---
docs/requirements.txt | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/docs/requirements.txt b/docs/requirements.txt
index 0c9b6330..9b6bb997 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -1,5 +1,5 @@
myst-parser==3.0.1
-rstcheck==6.2.1
+rstcheck==6.2.4
sphinx==7.3.7
nbsphinx==0.9.4
sphinx-autobuild==2024.4.16
From 4ecd05bdcfe6e530c8dce4c84869311ae38df584 Mon Sep 17 00:00:00 2001
From: "renovate[bot]" <29139614+renovate[bot]@users.noreply.github.com>
Date: Mon, 8 Jul 2024 15:02:49 +0800
Subject: [PATCH 04/27] chore(deps): update github/codeql-action action to
v3.25.11 (#749)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
[![Mend
Renovate](https://app.renovatebot.com/images/banner.svg)](https://renovatebot.com)
This PR contains the following updates:
| Package | Type | Update | Change |
|---|---|---|---|
| [github/codeql-action](https://togithub.com/github/codeql-action) |
action | patch | `v3.25.10` -> `v3.25.11` |
---
### Release Notes
github/codeql-action (github/codeql-action)
###
[`v3.25.11`](https://togithub.com/github/codeql-action/compare/v3.25.10...v3.25.11)
[Compare
Source](https://togithub.com/github/codeql-action/compare/v3.25.10...v3.25.11)
---
### Configuration
📅 **Schedule**: Branch creation - At any time (no schedule defined),
Automerge - At any time (no schedule defined).
🚦 **Automerge**: Disabled by config. Please merge this manually once you
are satisfied.
♻ **Rebasing**: Whenever PR becomes conflicted, or you tick the
rebase/retry checkbox.
🔕 **Ignore**: Close this PR and you won't be reminded about this update
again.
---
- [ ] If you want to rebase/retry this PR, check
this box
---
This PR has been generated by [Mend
Renovate](https://www.mend.io/free-developer-tools/renovate/). View
repository job log
[here](https://developer.mend.io/github/secretflow/spu).
Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com>
---
.github/workflows/scorecard.yml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/.github/workflows/scorecard.yml b/.github/workflows/scorecard.yml
index 68978a19..3b6e8be7 100644
--- a/.github/workflows/scorecard.yml
+++ b/.github/workflows/scorecard.yml
@@ -67,6 +67,6 @@ jobs:
# Upload the results to GitHub's code scanning dashboard.
- name: "Upload to code-scanning"
- uses: github/codeql-action/upload-sarif@23acc5c183826b7a8a97bce3cecc52db901f8251 # v3.25.10
+ uses: github/codeql-action/upload-sarif@b611370bb5703a7efb587f9d136a52ea24c5c38c # v3.25.11
with:
sarif_file: results.sarif
From 54c0885f9a790c13ba76aee7ba8d3a8b99a28300 Mon Sep 17 00:00:00 2001
From: Yancheng Zheng <103552181+anakinxc@users.noreply.github.com>
Date: Mon, 8 Jul 2024 15:27:57 +0800
Subject: [PATCH 05/27] Repo sync (#759)
---
libspu/compiler/tests/interpret/power.mlir | 4 ++--
libspu/compiler/tests/interpret/test_json/power.json | 2 +-
spu/tests/BUILD.bazel | 1 -
3 files changed, 3 insertions(+), 4 deletions(-)
diff --git a/libspu/compiler/tests/interpret/power.mlir b/libspu/compiler/tests/interpret/power.mlir
index dd950ef1..307fce4c 100644
--- a/libspu/compiler/tests/interpret/power.mlir
+++ b/libspu/compiler/tests/interpret/power.mlir
@@ -58,7 +58,7 @@ func.func @power_op_test_f64_f64_pp() {
%1 = pphlo.constant dense<[2.0, 2.0, 2.0, -1.0, 1.0]> : tensor<5xf64>
%2 = pphlo.power %0,%1 : (tensor<5xf64>,tensor<5xf64>)->tensor<5xf64>
%3 = pphlo.constant dense<[4.000000e+00, 0.000000e+00, 2.500000e+01, 0.33333333333333331, 10000.0]> : tensor<5xf64>
- pphlo.custom_call @expect_almost_eq(%2, %3) { tol = 0.5 }: (tensor<5xf64>, tensor<5xf64>)->()
+ pphlo.custom_call @expect_almost_eq(%2, %3) { tol = 0.6 }: (tensor<5xf64>, tensor<5xf64>)->()
func.return
}
@@ -72,6 +72,6 @@ func.func @power_op_test_f64_f64_ss() {
%4 = pphlo.power %2, %3 : (tensor<5x!pphlo.secret>,tensor<5x!pphlo.secret>)->tensor<5x!pphlo.secret>
%5 = pphlo.constant dense<[4.000000e+00, 0.000000e+00, 2.500000e+01, 0.33333333333333331, 10000.0]> : tensor<5xf64>
%6 = pphlo.convert %4 : (tensor<5x!pphlo.secret>)->tensor<5xf64>
- pphlo.custom_call @expect_almost_eq(%5, %6) { tol = 0.5 }: (tensor<5xf64>, tensor<5xf64>)->()
+ pphlo.custom_call @expect_almost_eq(%5, %6) { tol = 0.6 }: (tensor<5xf64>, tensor<5xf64>)->()
func.return
}
diff --git a/libspu/compiler/tests/interpret/test_json/power.json b/libspu/compiler/tests/interpret/test_json/power.json
index e5e6f3cf..a95ae33c 100644
--- a/libspu/compiler/tests/interpret/test_json/power.json
+++ b/libspu/compiler/tests/interpret/test_json/power.json
@@ -65,7 +65,7 @@
}
],
"checker": "expect_almost_eq",
- "tol": 0.5
+ "tol": 0.6
}
]
}
\ No newline at end of file
diff --git a/spu/tests/BUILD.bazel b/spu/tests/BUILD.bazel
index cfd029b8..fd037279 100644
--- a/spu/tests/BUILD.bazel
+++ b/spu/tests/BUILD.bazel
@@ -88,7 +88,6 @@ py_test(
py_test(
name = "jnp_cheetah_r64_test",
size = "enormous",
- timeout = "long",
srcs = ["jnp_cheetah_r64_test.py"],
deps = [
":jnp_testbase",
From ddb1acbe8938be2328d8db9fe93a1bc84c093410 Mon Sep 17 00:00:00 2001
From: Yancheng Zheng <103552181+anakinxc@users.noreply.github.com>
Date: Mon, 8 Jul 2024 15:37:59 +0800
Subject: [PATCH 06/27] Update acknowledgement (#760)
---
README.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/README.md b/README.md
index 56375409..8fad8f9c 100644
--- a/README.md
+++ b/README.md
@@ -71,4 +71,4 @@ If you think SPU is helpful for your research or development, please consider ci
## Acknowledgement
-We thank the significant contributions made by [Alibaba Gemini Lab](https://alibaba-gemini-lab.github.io).
+We thank the significant contributions made by [Alibaba Gemini Lab](https://alibaba-gemini-lab.github.io) and security advisories made by [NISL@THU](https://netsec.ccert.edu.cn/vul337).
From 1e5aca15cf1e680b7c8630142bcea3d05ce02ae1 Mon Sep 17 00:00:00 2001
From: Yancheng Zheng <103552181+anakinxc@users.noreply.github.com>
Date: Mon, 8 Jul 2024 15:40:54 +0800
Subject: [PATCH 07/27] Update README.md
---
README.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/README.md b/README.md
index 8fad8f9c..6a1fe7e3 100644
--- a/README.md
+++ b/README.md
@@ -71,4 +71,4 @@ If you think SPU is helpful for your research or development, please consider ci
## Acknowledgement
-We thank the significant contributions made by [Alibaba Gemini Lab](https://alibaba-gemini-lab.github.io) and security advisories made by [NISL@THU](https://netsec.ccert.edu.cn/vul337).
+We thank the significant contributions made by [Alibaba Gemini Lab](https://alibaba-gemini-lab.github.io) and security advisories made by [VUL337@NISL@THU](https://netsec.ccert.edu.cn/vul337).
From 630297dec8412c6404196be00f25b652b141c004 Mon Sep 17 00:00:00 2001
From: Yancheng Zheng <103552181+anakinxc@users.noreply.github.com>
Date: Thu, 11 Jul 2024 10:15:40 +0800
Subject: [PATCH 08/27] Repo sync (#762)
---
libspu/mpc/ab_api.cc | 4 ++--
libspu/mpc/cheetah/arith/matmat_prot.cc | 2 +-
libspu/mpc/cheetah/rlwe/packlwes.cc | 4 ++--
3 files changed, 5 insertions(+), 5 deletions(-)
diff --git a/libspu/mpc/ab_api.cc b/libspu/mpc/ab_api.cc
index a52de586..7d02fe47 100644
--- a/libspu/mpc/ab_api.cc
+++ b/libspu/mpc/ab_api.cc
@@ -252,7 +252,7 @@ Value bitintl_b(SPUContext* ctx, const Value& x, size_t stride) {
idx--) {
auto K = hack_make_p(ctx, spu::detail::kBitIntlKeepMasks[idx], x.shape());
auto M = hack_make_p(ctx, spu::detail::kBitIntlSwapMasks[idx], x.shape());
- int64_t S = 1 << idx;
+ int64_t S = static_cast(1) << idx;
// out = (out & K) ^ ((out >> S) & M) ^ ((out & M) << S);
out = xor_bb(
ctx,
@@ -281,7 +281,7 @@ Value bitdeintl_b(SPUContext* ctx, const Value& x, size_t stride) {
for (int64_t idx = stride; idx + 1 < Log2Ceil(nbits); idx++) {
auto K = hack_make_p(ctx, spu::detail::kBitIntlKeepMasks[idx], x.shape());
auto M = hack_make_p(ctx, spu::detail::kBitIntlSwapMasks[idx], x.shape());
- int64_t S = 1 << idx;
+ int64_t S = static_cast(1) << idx;
// out = (out & K) ^ ((out >> S) & M) ^ ((out & M) << S);
out = xor_bb(
ctx,
diff --git a/libspu/mpc/cheetah/arith/matmat_prot.cc b/libspu/mpc/cheetah/arith/matmat_prot.cc
index a56e522f..ebb80f9f 100644
--- a/libspu/mpc/cheetah/arith/matmat_prot.cc
+++ b/libspu/mpc/cheetah/arith/matmat_prot.cc
@@ -183,7 +183,7 @@ Shape3D MatMatProtocol::GetSubMatShape(const Meta& meta, int64_t poly_deg,
const double cpu_price = 1.0;
const double bandwidth_price = 1000.0;
- Shape3D subshape;
+ Shape3D subshape = {0, 0, 0};
Shape3D blk;
const int64_t n = poly_deg;
diff --git a/libspu/mpc/cheetah/rlwe/packlwes.cc b/libspu/mpc/cheetah/rlwe/packlwes.cc
index 59f99e4c..f2c7272f 100644
--- a/libspu/mpc/cheetah/rlwe/packlwes.cc
+++ b/libspu/mpc/cheetah/rlwe/packlwes.cc
@@ -134,7 +134,7 @@ void PackingHelper::doPackingRLWEs(absl::Span rlwes,
seal::Evaluator evaluator(context_);
const int64_t logn = absl::bit_width(gap_) - 1;
for (int64_t k = logn; k >= 1; --k) {
- int64_t h = 1 << (k - 1);
+ int64_t h = static_cast(1) << (k - 1);
yacl::parallel_for(0, h, [&](int64_t bgn, int64_t end) {
RLWECt dummy; // zero-padding with zero RLWE
for (int64_t i = bgn; i < end; ++i) {
@@ -188,7 +188,7 @@ void GenerateGaloisKeyForPacking(const seal::SEALContext &context,
size_t logN = absl::bit_width(N) - 1;
std::vector galois_elt;
for (uint32_t i = 1; i <= logN; i++) {
- galois_elt.push_back((1u << i) + 1);
+ galois_elt.push_back((static_cast(1) << i) + 1);
}
seal::KeyGenerator keygen(context, key);
From fcef2ff4ae58d15e4e2df48369b4cd8577a01552 Mon Sep 17 00:00:00 2001
From: Yancheng Zheng <103552181+anakinxc@users.noreply.github.com>
Date: Thu, 11 Jul 2024 10:21:32 +0800
Subject: [PATCH 09/27] Update repositories.bzl
---
bazel/repositories.bzl | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/bazel/repositories.bzl b/bazel/repositories.bzl
index 3fe35ee0..59ed839d 100644
--- a/bazel/repositories.bzl
+++ b/bazel/repositories.bzl
@@ -229,7 +229,7 @@ def _com_github_microsoft_seal():
maybe(
http_archive,
name = "com_github_microsoft_seal",
- sha256 = "78ef7334114de930daf7659e8ba60c5abfff85c86ec2b827a2b7c67c3c42da43",
+ sha256 = "acc2a1a127a85d1e1ffcca3ffd148f736e665df6d6b072df0e42fff64795a13c",
strip_prefix = "SEAL-4.1.2",
type = "tar.gz",
patch_args = ["-p1"],
From 5b42950f3d6d771cf719395e5e572e1f96a31540 Mon Sep 17 00:00:00 2001
From: Yancheng Zheng <103552181+anakinxc@users.noreply.github.com>
Date: Mon, 15 Jul 2024 13:34:00 +0800
Subject: [PATCH 10/27] Repo sync (#768)
# Pull Request
## What problem does this PR solve?
Issue Number: Fixed #766
## Possible side effects?
- Performance:
- Backward compatibility:
---
libspu/compiler/core/core.cc | 1 +
libspu/device/pphlo/pphlo_executor.cc | 11 ++
libspu/device/pphlo/pphlo_verifier.h | 1 +
libspu/dialect/pphlo/IR/ops.td | 6 +
libspu/dialect/pphlo/transforms/passes.h | 3 +
libspu/dialect/pphlo/transforms/passes.td | 6 +
.../pphlo/transforms/region_access_fixture.cc | 141 ++++++++++++++++++
7 files changed, 169 insertions(+)
create mode 100644 libspu/dialect/pphlo/transforms/region_access_fixture.cc
diff --git a/libspu/compiler/core/core.cc b/libspu/compiler/core/core.cc
index 355dd68a..e977050c 100644
--- a/libspu/compiler/core/core.cc
+++ b/libspu/compiler/core/core.cc
@@ -95,6 +95,7 @@ void Core::buildPipeline(mlir::PassManager *pm) {
}
optPM.addPass(mlir::createLoopInvariantCodeMotionPass());
+ optPM.addPass(mlir::spu::pphlo::createRegionAccessFixture());
optPM.addPass(mlir::createCSEPass());
if (!options.disable_deallocation_insertion()) {
diff --git a/libspu/device/pphlo/pphlo_executor.cc b/libspu/device/pphlo/pphlo_executor.cc
index d20d54f9..af8a04b8 100644
--- a/libspu/device/pphlo/pphlo_executor.cc
+++ b/libspu/device/pphlo/pphlo_executor.cc
@@ -720,6 +720,17 @@ void execute(OpExecutor *, SPUContext *sctx, SymbolScope *sscope,
addValue(sscope, op.getOutput(), std::move(iota_ret), opts);
}
+void execute(OpExecutor *, SPUContext *sctx, SymbolScope *sscope,
+ mlir::spu::pphlo::BroadcastShapeAsOp &op,
+ const ExecutionOptions &opts) {
+ // Start indices
+ const auto &lhs = lookupValue(sscope, op.getLhs(), opts);
+ const auto &rhs = lookupValue(sscope, op.getRhs(), opts);
+
+ addValue(sscope, op.getResult(),
+ kernel::hlo::Broadcast(sctx, lhs, rhs.shape(), {}), opts);
+}
+
void execute(OpExecutor *, SPUContext *sctx, SymbolScope *sscope,
mlir::spu::pphlo::RemOp &op, const ExecutionOptions &opts) {
// FIXME: When hal has a remainder, use that
diff --git a/libspu/device/pphlo/pphlo_verifier.h b/libspu/device/pphlo/pphlo_verifier.h
index 428883c8..478a3874 100644
--- a/libspu/device/pphlo/pphlo_verifier.h
+++ b/libspu/device/pphlo/pphlo_verifier.h
@@ -149,6 +149,7 @@ class PPHloVerifier {
NO_VERIFY_DEFN(ImagOp)
NO_VERIFY_DEFN(ComplexOp)
NO_VERIFY_DEFN(SimpleSortOp)
+ NO_VERIFY_DEFN(BroadcastShapeAsOp)
#undef NO_VERIFY_DEFN
};
diff --git a/libspu/dialect/pphlo/IR/ops.td b/libspu/dialect/pphlo/IR/ops.td
index b9f6b0c2..205dcf75 100644
--- a/libspu/dialect/pphlo/IR/ops.td
+++ b/libspu/dialect/pphlo/IR/ops.td
@@ -619,6 +619,12 @@ def PPHLO_BroadcastOp
}];
}
+
+def PPHLO_BroadcastShapeAsOp: PPHLO_BinaryElementwiseOp<"broadcast_as", [Pure,
+ SameOperandsAndResultShape], PPHLO_Tensor> {
+ let summary = "BroadcastShapeAs operator";
+}
+
def PPHLO_ClampOp
: PPHLO_Op<"clamp", [Pure, SameOperandsAndResultShape]> {
let summary = "Clamp operator";
diff --git a/libspu/dialect/pphlo/transforms/passes.h b/libspu/dialect/pphlo/transforms/passes.h
index b70ce1ab..9bcc2660 100644
--- a/libspu/dialect/pphlo/transforms/passes.h
+++ b/libspu/dialect/pphlo/transforms/passes.h
@@ -82,6 +82,9 @@ std::unique_ptr> createInlineSecretControlFlow();
// Convert signbit pattern to SignOp
std::unique_ptr> createRewriteSignbitPatterns();
+// Fix region access shape mismatch
+std::unique_ptr> createRegionAccessFixture();
+
} // namespace spu::pphlo
} // namespace mlir
diff --git a/libspu/dialect/pphlo/transforms/passes.td b/libspu/dialect/pphlo/transforms/passes.td
index 4fc28fdc..a13c70b2 100644
--- a/libspu/dialect/pphlo/transforms/passes.td
+++ b/libspu/dialect/pphlo/transforms/passes.td
@@ -122,3 +122,9 @@ def InlineSecretControlFlow: Pass<"inline-secret-control-flow", "func::FuncOp">
let constructor = "createInlineSecretControlFlow()";
let dependentDialects = ["pphlo::PPHloDialect"];
}
+
+def RegionAccessFixture: Pass<"region-access-fixture", "func::FuncOp"> {
+ let summary = "Fix region access mismatched shape";
+ let constructor = "createRegionAccessFixture()";
+ let dependentDialects = ["pphlo::PPHloDialect"];
+}
\ No newline at end of file
diff --git a/libspu/dialect/pphlo/transforms/region_access_fixture.cc b/libspu/dialect/pphlo/transforms/region_access_fixture.cc
new file mode 100644
index 00000000..3bd9d83a
--- /dev/null
+++ b/libspu/dialect/pphlo/transforms/region_access_fixture.cc
@@ -0,0 +1,141 @@
+// Copyright 2024 Ant Group Co., Ltd.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include
+
+#include "mlir/Pass/Pass.h"
+
+#include "libspu/dialect/pphlo/IR/ops.h"
+#include "libspu/dialect/pphlo/transforms/pass_details.h"
+
+namespace mlir::spu::pphlo {
+
+namespace {
+
+struct Deallocator {
+ public:
+ LogicalResult transformOp(Operation *op) {
+ for (auto &r : op->getRegions()) {
+ if (failed(transformRegion(r))) {
+ return failure();
+ }
+ }
+
+ const auto &operands = op->getOperands();
+
+ if (op->getNumOperands() < 2 ||
+ !op->hasTrait<::mlir::OpTrait::Elementwise>() ||
+ std::all_of(operands.begin(), operands.end(), [](const auto &operand) {
+ return operand.template getDefiningOp();
+ })) {
+ return success();
+ }
+
+ auto *op_region = op->getParentRegion();
+
+ Value base_val;
+ llvm::SmallVector values_to_update;
+
+ OpBuilder builder(op->getContext());
+ builder.setInsertionPoint(op);
+
+ for (const auto &[idx, operand] : llvm::enumerate(operands)) {
+ // Get defining region
+ auto *defining_op = operand.getDefiningOp();
+
+ Region *defining_region = nullptr;
+
+ if (defining_op != nullptr) {
+ defining_region = defining_op->getParentRegion();
+ }
+
+ if (defining_op == nullptr || defining_region == op_region) {
+ // BlockArg or op defined in current region can be a base val
+ base_val = operand;
+ continue;
+ }
+
+ if (defining_region != op_region) {
+ // This op is accessing a variable out of op's region.
+ // Insert a broadcast as to fix runtime shape mismatch during simd
+ // region execution
+ values_to_update.emplace_back(idx);
+ }
+ }
+
+ if (!base_val) {
+ return values_to_update.empty()
+ ? failure() // same region however failed to pick base value
+ : success(); // can't pick base value since multi-level
+ // nesting
+ }
+
+ for (const auto &idx : values_to_update) {
+ auto op_to_broadcast = op->getOperand(idx);
+ auto b = builder.create(
+ op->getLoc(), op_to_broadcast.getType(), op_to_broadcast, base_val);
+ op->setOperand(idx, b);
+ }
+
+ return success();
+ }
+
+ LogicalResult transformBlock(Block &block) {
+ for (auto &op : llvm::make_early_inc_range(block.without_terminator())) {
+ auto opResult = transformOp(&op);
+ if (failed(opResult)) {
+ return failure();
+ }
+ }
+ return success();
+ }
+
+ LogicalResult transformRegion(Region &r) {
+ for (auto &b : r.getBlocks()) {
+ if (failed(transformBlock(b))) {
+ return failure();
+ }
+ }
+ return success();
+ }
+
+ LogicalResult transformFuncOp(func::FuncOp op) {
+ if (op->getNumRegions() == 0) {
+ return success();
+ }
+
+ // Transform function body.
+ if (failed(transformRegion(op.getBody()))) {
+ return failure();
+ }
+
+ return success();
+ }
+};
+
+struct RegionAccessFixture
+ : public RegionAccessFixtureBase {
+ void runOnOperation() override {
+ if (failed(Deallocator().transformFuncOp(getOperation()))) {
+ signalPassFailure();
+ }
+ }
+};
+} // namespace
+
+std::unique_ptr> createRegionAccessFixture() {
+ return std::make_unique();
+}
+
+} // namespace mlir::spu::pphlo
From 333f16cc76157c1f0aacdf33b08d4dc9b18304a8 Mon Sep 17 00:00:00 2001
From: "renovate[bot]" <29139614+renovate[bot]@users.noreply.github.com>
Date: Tue, 16 Jul 2024 08:07:51 +0800
Subject: [PATCH 11/27] chore(deps): update dependency sphinx to v7.4.4 (#770)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
[![Mend
Renovate](https://app.renovatebot.com/images/banner.svg)](https://renovatebot.com)
This PR contains the following updates:
| Package | Change | Age | Adoption | Passing | Confidence |
|---|---|---|---|---|---|
| [sphinx](https://togithub.com/sphinx-doc/sphinx)
([changelog](https://www.sphinx-doc.org/en/master/changes.html)) |
`==7.3.7` -> `==7.4.4` |
[![age](https://developer.mend.io/api/mc/badges/age/pypi/sphinx/7.4.4?slim=true)](https://docs.renovatebot.com/merge-confidence/)
|
[![adoption](https://developer.mend.io/api/mc/badges/adoption/pypi/sphinx/7.4.4?slim=true)](https://docs.renovatebot.com/merge-confidence/)
|
[![passing](https://developer.mend.io/api/mc/badges/compatibility/pypi/sphinx/7.3.7/7.4.4?slim=true)](https://docs.renovatebot.com/merge-confidence/)
|
[![confidence](https://developer.mend.io/api/mc/badges/confidence/pypi/sphinx/7.3.7/7.4.4?slim=true)](https://docs.renovatebot.com/merge-confidence/)
|
---
### Release Notes
sphinx-doc/sphinx (sphinx)
###
[`v7.4.4`](https://togithub.com/sphinx-doc/sphinx/blob/HEAD/CHANGES.rst#Release-744-released-Jul-15-2024)
[Compare
Source](https://togithub.com/sphinx-doc/sphinx/compare/v7.4.3...v7.4.4)
\=====================================
## Bugs fixed
- [#12585](https://togithub.com/sphinx-doc/sphinx/issues/12585),
[#12586](https://togithub.com/sphinx-doc/sphinx/issues/12586): Do
not warn when an intersphinx inventory contains
case-insensitively ambiguous duplicate items.
Patch by James Addison.
###
[`v7.4.3`](https://togithub.com/sphinx-doc/sphinx/blob/HEAD/CHANGES.rst#Release-743-released-Jul-15-2024)
[Compare
Source](https://togithub.com/sphinx-doc/sphinx/compare/v7.4.2...v7.4.3)
\=====================================
## Bugs fixed
- [#12582](https://togithub.com/sphinx-doc/sphinx/issues/12582):
Restore support for list-styled :confval:`source_suffix` values
with extensions that register parsers.
Patch by Adam Turner.
###
[`v7.4.2`](https://togithub.com/sphinx-doc/sphinx/blob/HEAD/CHANGES.rst#Release-742-released-Jul-15-2024)
[Compare
Source](https://togithub.com/sphinx-doc/sphinx/compare/v7.4.1...v7.4.2)
\=====================================
## Bugs fixed
- [#12580](https://togithub.com/sphinx-doc/sphinx/issues/12580),
[#12583](https://togithub.com/sphinx-doc/sphinx/issues/12583):
Resolve failures with the C domain on incremental builds
with Sphinx 7.3.7 and earlier.
Patch by Adam Turner.
###
[`v7.4.1`](https://togithub.com/sphinx-doc/sphinx/blob/HEAD/CHANGES.rst#Release-741-in-development)
[Compare
Source](https://togithub.com/sphinx-doc/sphinx/compare/v7.4.0...v7.4.1)
\==============================
## Dependencies
## Incompatible changes
## Deprecated
## Features added
## Bugs fixed
- Fix invalid HTML when a rubric node with invalid `heading-level` is
used.
Patch by Adam Turner.
- [#12579](https://togithub.com/sphinx-doc/sphinx/issues/12579),
[#12581](https://togithub.com/sphinx-doc/sphinx/issues/12581):
Restore support for `typing.ParamSpec` in autodoc.
Patch by Adam Turner.
## Testing
###
[`v7.4.0`](https://togithub.com/sphinx-doc/sphinx/blob/HEAD/CHANGES.rst#Release-740-released-Jul-15-2024)
[Compare
Source](https://togithub.com/sphinx-doc/sphinx/compare/v7.3.7...v7.4.0)
\=====================================
## Dependencies
- [#12555](https://togithub.com/sphinx-doc/sphinx/issues/12555):
Drop Docutils 0.18.1 and Docutils 0.19 support.
Patch by Adam Turner.
- LaTeX: the `xcolor` package is now required (but is for example part
of
Ubuntu `texlive-latex-recommended` which has always been required).
- LaTeX: the `fontawesome5` LaTeX package is needed for the default
choices
of icons now used in admonition titles in PDF output; but if unavailable
the
PDF build will simply silently omit rendering such icons. Check the
documentation of the `iconpackage` key of :ref:`'sphinxsetup'
` for more.
## Deprecated
- LaTeX: the `sphinxlightbox` environment is not used anymore, all types
of admonitions use (by default) only `sphinxheavybox`.
## Features added
.. rst-class:: compact
- [#11165](https://togithub.com/sphinx-doc/sphinx/issues/11165):
Support the `officially recommended`\_ `.jinja` suffix for template
files.
Patch by James Addison and Adam Turner
.. \_officially recommended:
https://jinja.palletsprojects.com/en/latest/templates/#template-file-extension
- [#12325](https://togithub.com/sphinx-doc/sphinx/issues/12325):
Flatten `Union[Literal[T], Literal[U], ...]` to `Literal[T, U, ...]`
when turning annotations into strings.
Patch by Adam Turner.
- [#12319](https://togithub.com/sphinx-doc/sphinx/issues/12319):
`sphinx.ext.extlinks`: Add `extlink-{name}` CSS class to links.
Patch by Hugo van Kemenade.
- [#12387](https://togithub.com/sphinx-doc/sphinx/issues/12387):
Improve CLI progress message, when copying assets.
Patch by INADA Nakoi and Bénédikt Tran.
- [#12361](https://togithub.com/sphinx-doc/sphinx/issues/12361):
Add :attr:`.BuildEnvironment.parser`.
Patch by Chris Sewell.
- [#12358](https://togithub.com/sphinx-doc/sphinx/issues/12358):
Add :attr:`.Sphinx.fresh_env_used`.
Patch by Chris Sewell.
- [#12329](https://togithub.com/sphinx-doc/sphinx/issues/12329):
Add detection of ambiguous `std:label` and `std:term` references during
loading and resolution of Intersphinx targets.
Patch by James Addison.
- [#12422](https://togithub.com/sphinx-doc/sphinx/issues/12422):
Do not duplicate "navigation" in aria-label of built-in themes.
Patch by Thomas Weißschuh
- [#12421](https://togithub.com/sphinx-doc/sphinx/issues/12421):
Include project name in `logo_alt` of built-in themes.
Patch by Thomas Weißschuh
- [#12448](https://togithub.com/sphinx-doc/sphinx/issues/12448):
Add :option:`sphinx-apidoc --remove-old` option.
Patch by Chris Sewell.
- [#12456](https://togithub.com/sphinx-doc/sphinx/issues/12456):
Add :option:`sphinx-autogen --remove-old` option.
Patch by Chris Sewell.
- [#12479](https://togithub.com/sphinx-doc/sphinx/issues/12479):
Add warning subtype `toc.no_title`.
Patch by Ondřej Navrátil.
- [#12492](https://togithub.com/sphinx-doc/sphinx/issues/12492):
Add helper methods for parsing reStructuredText content into nodes from
within a directive.
-
:py:meth:`~sphinx.util.docutils.SphinxDirective.parse_content_to_nodes()`
parses the directive's content and returns a list of Docutils nodes.
- :py:meth:`~sphinx.util.docutils.SphinxDirective.parse_text_to_nodes()`
parses the provided text and returns a list of Docutils nodes.
- :py:meth:`~sphinx.util.docutils.SphinxDirective.parse_inline()`
parses the provided text into inline elements and text nodes.
Patch by Adam Turner.
- [#12258](https://togithub.com/sphinx-doc/sphinx/issues/12258):
Support `typing_extensions.Unpack`
Patch by Bénédikt Tran and Adam Turner.
- [#12524](https://togithub.com/sphinx-doc/sphinx/issues/12524):
Add a `class` option to the :rst:dir:`toctree` directive.
Patch by Tim Hoffmann.
- [#12536](https://togithub.com/sphinx-doc/sphinx/issues/12536):
Add the :rst:dir:`confval` directive.
Patch by Adam Turner.
- [#12537](https://togithub.com/sphinx-doc/sphinx/issues/12537):
:confval:`c_id_attributes`, :confval:`c_paren_attributes`,
:confval:`cpp_id_attributes`, and :confval:`cpp_paren_attributes`
can now be a tuple of strings.
:confval:`c_extra_keywords`, :confval:`gettext_additional_targets`,
:confval:`html_domain_indices`, :confval:`latex_domain_indices`,
and :confval:`texinfo_domain_indices`,
can now be a set of strings.
Patch by Adam Turner.
- [#12523](https://togithub.com/sphinx-doc/sphinx/issues/12523):
Added configuration option, :confval:`math_numsep`, to define the
separator for math numbering.
Patch by Thomas Fanning
- [#11592](https://togithub.com/sphinx-doc/sphinx/issues/11592):
Add :confval:`coverage_modules` to the coverage builder
to allow explicitly specifying which modules should be documented.
Patch by Stephen Finucane.
- [#7896](https://togithub.com/sphinx-doc/sphinx/issues/7896),
[#11989](https://togithub.com/sphinx-doc/sphinx/issues/11989):
Add a :rst:dir:`py:type` directive for documenting type aliases,
and a :rst:role:`py:type` role for linking to them.
Patch by Ashley Whetter.
- [#12549](https://togithub.com/sphinx-doc/sphinx/issues/12549):
Add optional `description` argument to
:meth:`.Sphinx.add_config_value`.
Patch by Chris Sewell.
- [#6792](https://togithub.com/sphinx-doc/sphinx/issues/6792):
Prohibit module import cycles in :mod:`sphinx.ext.autosummary`.
Patch by Trevor Bekolay.
- [#12508](https://togithub.com/sphinx-doc/sphinx/issues/12508):
LaTeX: Revamped styling of all admonitions, with addition of a
title row with icon.
Patch by Jean-François B.
- [#11773](https://togithub.com/sphinx-doc/sphinx/issues/11773):
Display :py:class:`~typing.Annotated` annotations
with their metadata in the Python domain.
Patch by Adam Turner and David Stansby.
- [#12506](https://togithub.com/sphinx-doc/sphinx/issues/12506):
Add `level` option to :rst:dir:`rubric` directive.
Patch by Chris Sewell.
- [#12567](https://togithub.com/sphinx-doc/sphinx/issues/12567):
Add the :event:`write-started` event.
Patch by Chris Sewell.
## Bugs fixed
- [#12314](https://togithub.com/sphinx-doc/sphinx/issues/12314):
Properly format `collections.abc.Callable` in annotations.
Patch by Adam Turner.
- [#12162](https://togithub.com/sphinx-doc/sphinx/issues/12162):
Fix a performance regression in the C domain that has
been present since version 3.0.0.
Patch by Donald Hunter.
- [#12320](https://togithub.com/sphinx-doc/sphinx/issues/12320):
Fix removal of anchors from search summaries (regression in 7.3.0).
Patch by Will Lachance.
- [#12251](https://togithub.com/sphinx-doc/sphinx/issues/12251):
Fix `merge_domaindata()` in `sphinx.ext.duration`.
Patch by Matthias Geier.
- [#12224](https://togithub.com/sphinx-doc/sphinx/issues/12224):
Properly detect WebP files.
Patch by Benjamin Cabé.
- [#12380](https://togithub.com/sphinx-doc/sphinx/issues/12380):
LaTeX: Footnote mark sometimes indicates `Page N` where `N` is
the current page number and the footnote does appear on that same page.
Patch by Jean-François B.
- [#12410](https://togithub.com/sphinx-doc/sphinx/issues/12410):
LaTeX: for French and `'lualatex'` as :confval:`latex_engine`
`polyglossia` and not `babel` is used (contrarily to `'xelatex'`).
Patch by Jean-François B.
- [#12416](https://togithub.com/sphinx-doc/sphinx/issues/12416):
Ensure that configuration setting aliases are always synchronised
when one value or the other is modified.
Patch by Bénédikt Tran.
- [#12220](https://togithub.com/sphinx-doc/sphinx/issues/12220):
Fix loading custom template translations for `en` locale.
Patch by Nicolas Peugnet.
- [#12459](https://togithub.com/sphinx-doc/sphinx/issues/12459):
Add valid-type arguments to the `linkcheck_rate_limit_timeout`
configuration setting.
Patch by James Addison.
- [#12331](https://togithub.com/sphinx-doc/sphinx/issues/12331):
Resolve data-URI-image-extraction regression from v7.3.0 affecting
builders without native support for data-URIs in their output format.
Patch by James Addison.
- [#12494](https://togithub.com/sphinx-doc/sphinx/issues/12494):
Fix invalid genindex.html file produced with translated docs
(regression in 7.1.0).
Patch by Nicolas Peugnet.
- [#11961](https://togithub.com/sphinx-doc/sphinx/issues/11961):
Omit anchor references from document title entries in the search index,
removing duplication of search results.
Patch by James Addison.
- [#12425](https://togithub.com/sphinx-doc/sphinx/issues/12425):
Use Docutils' SVG processing in the HTML builder
and remove Sphinx's custom logic.
Patch by Tunç Başar Köse.
- [#12391](https://togithub.com/sphinx-doc/sphinx/issues/12391):
Adjust scoring of matches during HTML search so that document main
titles tend to rank higher than subsection titles. In addition, boost
matches
on the name of programming domain objects relative to title/subtitle
matches.
Patch by James Addison and Will Lachance.
- [#9634](https://togithub.com/sphinx-doc/sphinx/issues/9634): Do
not add a fallback language by stripping the country code.
Patch by Alvin Wong.
- [#12352](https://togithub.com/sphinx-doc/sphinx/issues/12352):
Add domain objects to the table of contents
in the same order as defined in the document.
Previously, each domain used language-specific nesting rules,
which removed control from document authors.
Patch by Jakob Lykke Andersen and Adam Turner.
- [#11041](https://togithub.com/sphinx-doc/sphinx/issues/11041):
linkcheck: Ignore URLs that respond with non-Unicode content.
Patch by James Addison.
- [#12543](https://togithub.com/sphinx-doc/sphinx/issues/12543):
Fix :pep:`695` formatting for LaTeX output.
Patch by Bénédikt Tran.
## Testing
- karma: refactor HTML search tests to use fixtures generated by Sphinx.
Patch by James Addison.
---
### Configuration
📅 **Schedule**: Branch creation - At any time (no schedule defined),
Automerge - At any time (no schedule defined).
🚦 **Automerge**: Disabled by config. Please merge this manually once you
are satisfied.
♻ **Rebasing**: Whenever PR becomes conflicted, or you tick the
rebase/retry checkbox.
🔕 **Ignore**: Close this PR and you won't be reminded about this update
again.
---
- [ ] If you want to rebase/retry this PR, check
this box
---
This PR has been generated by [Mend
Renovate](https://www.mend.io/free-developer-tools/renovate/). View
repository job log
[here](https://developer.mend.io/github/secretflow/spu).
Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com>
---
docs/requirements.txt | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/docs/requirements.txt b/docs/requirements.txt
index 9b6bb997..3b26b70d 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -1,6 +1,6 @@
myst-parser==3.0.1
rstcheck==6.2.4
-sphinx==7.3.7
+sphinx==7.4.4
nbsphinx==0.9.4
sphinx-autobuild==2024.4.16
sphinx-markdown-parser==0.2.4
From f3296773573842d254950a76428473f68b96eb92 Mon Sep 17 00:00:00 2001
From: "renovate[bot]" <29139614+renovate[bot]@users.noreply.github.com>
Date: Tue, 16 Jul 2024 08:08:21 +0800
Subject: [PATCH 12/27] chore(deps): update github/codeql-action action to
v3.25.12 (#767)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
[![Mend
Renovate](https://app.renovatebot.com/images/banner.svg)](https://renovatebot.com)
This PR contains the following updates:
| Package | Type | Update | Change |
|---|---|---|---|
| [github/codeql-action](https://togithub.com/github/codeql-action) |
action | patch | `v3.25.11` -> `v3.25.12` |
---
### Release Notes
github/codeql-action (github/codeql-action)
###
[`v3.25.12`](https://togithub.com/github/codeql-action/compare/v3.25.11...v3.25.12)
[Compare
Source](https://togithub.com/github/codeql-action/compare/v3.25.11...v3.25.12)
---
### Configuration
📅 **Schedule**: Branch creation - At any time (no schedule defined),
Automerge - At any time (no schedule defined).
🚦 **Automerge**: Disabled by config. Please merge this manually once you
are satisfied.
♻ **Rebasing**: Whenever PR becomes conflicted, or you tick the
rebase/retry checkbox.
🔕 **Ignore**: Close this PR and you won't be reminded about this update
again.
---
- [ ] If you want to rebase/retry this PR, check
this box
---
This PR has been generated by [Mend
Renovate](https://www.mend.io/free-developer-tools/renovate/). View
repository job log
[here](https://developer.mend.io/github/secretflow/spu).
Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com>
---
.github/workflows/scorecard.yml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/.github/workflows/scorecard.yml b/.github/workflows/scorecard.yml
index 3b6e8be7..5b25f5b2 100644
--- a/.github/workflows/scorecard.yml
+++ b/.github/workflows/scorecard.yml
@@ -67,6 +67,6 @@ jobs:
# Upload the results to GitHub's code scanning dashboard.
- name: "Upload to code-scanning"
- uses: github/codeql-action/upload-sarif@b611370bb5703a7efb587f9d136a52ea24c5c38c # v3.25.11
+ uses: github/codeql-action/upload-sarif@4fa2a7953630fd2f3fb380f21be14ede0169dd4f # v3.25.12
with:
sarif_file: results.sarif
From 63cce33618b9b3d76761126453851f72d56fbf8b Mon Sep 17 00:00:00 2001
From: Yancheng Zheng <103552181+anakinxc@users.noreply.github.com>
Date: Wed, 17 Jul 2024 17:48:22 +0800
Subject: [PATCH 13/27] Repo sync (#772)
---
.clang-tidy | 3 +-
CHANGELOG.md | 3 +-
bazel/repositories.bzl | 4 +-
docs/development/add_protocols.rst | 2 +-
examples/python/ml/jax_lr/README.md | 2 +-
experimental/squirrel/bin_matvec_prot_test.cc | 4 +-
experimental/squirrel/objectives.cc | 6 +-
experimental/squirrel/tree_build_worker.cc | 2 +-
experimental/squirrel/utils.cc | 6 +-
experimental/squirrel/utils_test.cc | 4 +-
libspu/compiler/common/compilation_context.h | 1 -
libspu/compiler/compile.cc | 2 -
libspu/compiler/core/core.cc | 1 -
libspu/compiler/front_end/fe.cc | 1 -
libspu/compiler/tests/interpret/abs.mlir | 2 +
libspu/compiler/tests/interpret/add.mlir | 2 +
libspu/compiler/tests/interpret/and.mlir | 2 +
libspu/compiler/tests/interpret/atan2.mlir | 2 +
.../compiler/tests/interpret/broadcast.mlir | 2 +
libspu/compiler/tests/interpret/case.mlir | 2 +
libspu/compiler/tests/interpret/ceil.mlir | 2 +
libspu/compiler/tests/interpret/clamp.mlir | 2 +
.../compiler/tests/interpret/concatenate.mlir | 2 +
libspu/compiler/tests/interpret/convert.mlir | 2 +
.../compiler/tests/interpret/convolution.mlir | 2 +
libspu/compiler/tests/interpret/cosine.mlir | 2 +
libspu/compiler/tests/interpret/divide.mlir | 2 +
.../compiler/tests/interpret/dot_general.mlir | 2 +
.../tests/interpret/dynamic_slice.mlir | 2 +
.../tests/interpret/dynamic_update_slice.mlir | 2 +
libspu/compiler/tests/interpret/equal.mlir | 2 +
.../compiler/tests/interpret/exponential.mlir | 2 +
.../interpret/exponential_minus_one.mlir | 2 +
libspu/compiler/tests/interpret/floor.mlir | 2 +
.../tests/interpret/generate_mlir_tests.py | 9 +
libspu/compiler/tests/interpret/greater.mlir | 2 +
.../tests/interpret/greater_equal.mlir | 2 +
libspu/compiler/tests/interpret/if.mlir | 2 +
libspu/compiler/tests/interpret/iota.mlir | 2 +
libspu/compiler/tests/interpret/less.mlir | 2 +
.../compiler/tests/interpret/less_equal.mlir | 2 +
libspu/compiler/tests/interpret/log.mlir | 2 +
.../tests/interpret/log_plus_one.mlir | 2 +
libspu/compiler/tests/interpret/logistic.mlir | 2 +
libspu/compiler/tests/interpret/maximum.mlir | 1 +
libspu/compiler/tests/interpret/minimum.mlir | 1 +
libspu/compiler/tests/interpret/multiply.mlir | 2 +
libspu/compiler/tests/interpret/negate.mlir | 2 +
libspu/compiler/tests/interpret/not.mlir | 2 +
.../compiler/tests/interpret/not_equal.mlir | 2 +
libspu/compiler/tests/interpret/or.mlir | 2 +
libspu/compiler/tests/interpret/pad.mlir | 2 +
libspu/compiler/tests/interpret/popcnt.mlir | 2 +
libspu/compiler/tests/interpret/power.mlir | 2 +
.../compiler/tests/interpret/reciprocal.mlir | 26 +++
libspu/compiler/tests/interpret/reduce.mlir | 2 +
.../tests/interpret/reduce_window.mlir | 2 +
libspu/compiler/tests/interpret/reshape.mlir | 2 +
libspu/compiler/tests/interpret/reverse.mlir | 2 +
.../compiler/tests/interpret/ring_cast.mlir | 2 +
.../tests/interpret/round_nearest_afz.mlir | 2 +
libspu/compiler/tests/interpret/rsqrt.mlir | 2 +
libspu/compiler/tests/interpret/select.mlir | 2 +
.../tests/interpret/select_and_scatter.mlir | 2 +
.../interpret/shift_right_arithmetic.mlir | 2 +
.../tests/interpret/shift_right_logical.mlir | 2 +
libspu/compiler/tests/interpret/sign.mlir | 2 +
libspu/compiler/tests/interpret/sine.mlir | 2 +
libspu/compiler/tests/interpret/slice.mlir | 2 +
libspu/compiler/tests/interpret/sort.mlir | 2 +
libspu/compiler/tests/interpret/sqrt.mlir | 2 +
libspu/compiler/tests/interpret/subtract.mlir | 2 +
libspu/compiler/tests/interpret/tanh.mlir | 2 +
.../tests/interpret/test_json/reciprocal.json | 24 +++
.../compiler/tests/interpret/transpose.mlir | 2 +
libspu/compiler/tests/interpret/while.mlir | 2 +
libspu/compiler/tests/interpret/xor.mlir | 2 +
.../passes/hlo2pphlo/select_and_scatter.mlir | 4 +-
.../tests/passes/hlo2pphlo/sort_p.mlir | 4 +-
.../tests/passes/hlo2pphlo/sort_s.mlir | 4 +-
libspu/compiler/tools/spu-lsp.cc | 3 +-
libspu/compiler/tools/spu-translate.cc | 20 ++-
libspu/core/encoding.cc | 12 +-
libspu/core/pt_buffer_view.cc | 2 +-
libspu/core/type_util.h | 170 +++++++++---------
libspu/dialect/pphlo/IR/dialect.td | 5 +-
libspu/dialect/pphlo/IR/type_inference.cc | 23 +--
.../pphlo/transforms/hlo_legalize_to_pphlo.cc | 16 +-
libspu/kernel/hal/constants.cc | 2 +-
libspu/kernel/hal/fxp_approx.cc | 33 ++--
libspu/kernel/hal/fxp_base.cc | 21 ++-
libspu/kernel/hal/fxp_base.h | 2 +
libspu/kernel/hal/fxp_cleartext.cc | 6 +-
libspu/kernel/hal/permute.cc | 11 +-
libspu/kernel/hal/polymorphic.cc | 18 +-
libspu/kernel/hal/polymorphic.h | 7 +-
libspu/kernel/hal/prot_wrapper.cc | 10 +-
libspu/kernel/hal/prot_wrapper.h | 18 +-
libspu/kernel/hal/ring.cc | 36 ++--
libspu/kernel/hal/ring.h | 6 +-
libspu/kernel/hal/type_cast.cc | 7 +-
libspu/kernel/hlo/basic_unary.cc | 10 +-
libspu/kernel/hlo/shift.cc | 29 +--
libspu/mpc/ab_api.cc | 46 ++---
libspu/mpc/ab_api.h | 8 +-
libspu/mpc/ab_api_test.cc | 43 ++---
libspu/mpc/aby3/arithmetic.cc | 51 +++---
libspu/mpc/aby3/arithmetic.h | 2 +-
libspu/mpc/aby3/boolean.cc | 91 +++++-----
libspu/mpc/aby3/boolean.h | 6 +-
libspu/mpc/aby3/conversion.cc | 45 ++---
libspu/mpc/aby3/io.cc | 6 +-
libspu/mpc/aby3/oram.cc | 12 +-
libspu/mpc/aby3/ot.cc | 2 -
libspu/mpc/aby3/permute.cc | 8 +-
libspu/mpc/aby3/value.h | 2 +-
libspu/mpc/api.cc | 33 ++--
libspu/mpc/api.h | 18 +-
libspu/mpc/api_test.cc | 167 ++++++++---------
.../mpc/cheetah/arith/cheetah_conv2d_test.cc | 2 +-
libspu/mpc/cheetah/arith/cheetah_dot_test.cc | 4 +-
libspu/mpc/cheetah/arith/common.cc | 2 +-
libspu/mpc/cheetah/arith/matmat_prot.cc | 8 +-
libspu/mpc/cheetah/arith/matmat_prot_test.cc | 6 +-
libspu/mpc/cheetah/arith/vector_encoder.cc | 3 +-
.../mpc/cheetah/arith/vector_encoder_test.cc | 2 +-
libspu/mpc/cheetah/arithmetic.cc | 2 +-
libspu/mpc/cheetah/arithmetic.h | 2 +-
libspu/mpc/cheetah/arithmetic_semi2k.cc | 9 +-
libspu/mpc/cheetah/boolean.h | 6 +-
libspu/mpc/cheetah/boolean_semi2k.cc | 28 ++-
libspu/mpc/cheetah/nonlinear/compare_prot.cc | 8 +-
.../cheetah/nonlinear/compare_prot_test.cc | 19 +-
libspu/mpc/cheetah/nonlinear/equal_prot.cc | 6 +-
.../mpc/cheetah/nonlinear/equal_prot_test.cc | 4 +-
libspu/mpc/cheetah/nonlinear/truncate_prot.cc | 17 +-
.../cheetah/nonlinear/truncate_prot_test.cc | 14 +-
libspu/mpc/cheetah/ot/basic_ot_prot.cc | 20 +--
libspu/mpc/cheetah/ot/basic_ot_prot_test.cc | 30 ++--
libspu/mpc/cheetah/ot/emp/ferret_test.cc | 8 +-
libspu/mpc/cheetah/ot/ot_util.cc | 19 +-
libspu/mpc/cheetah/ot/ot_util_test.cc | 6 +-
libspu/mpc/cheetah/ot/yacl/ferret_test.cc | 8 +-
libspu/mpc/cheetah/rlwe/modswitch_helper.cc | 6 +-
.../mpc/cheetah/rlwe/modswitch_helper_test.cc | 13 +-
libspu/mpc/cheetah/rlwe/packlwes_test.cc | 7 +-
libspu/mpc/cheetah/state.cc | 4 +-
libspu/mpc/common/prg_state.h | 31 ++--
libspu/mpc/common/pv2k.cc | 52 +++---
libspu/mpc/kernel.cc | 6 +-
libspu/mpc/kernel.h | 2 +-
libspu/mpc/ref2k/ref2k.cc | 17 +-
libspu/mpc/securenn/arithmetic.cc | 33 ++--
libspu/mpc/securenn/arithmetic.h | 2 +-
libspu/mpc/securenn/boolean.cc | 35 ++--
libspu/mpc/securenn/boolean.h | 6 +-
libspu/mpc/securenn/conversion.cc | 10 +-
libspu/mpc/semi2k/arithmetic.cc | 15 +-
libspu/mpc/semi2k/arithmetic.h | 2 +-
.../semi2k/beaver/beaver_impl/beaver_test.cc | 43 +++--
.../semi2k/beaver/beaver_impl/beaver_tfp.cc | 2 +-
.../semi2k/beaver/beaver_impl/beaver_ttp.cc | 2 +-
.../trusted_party/trusted_party.cc | 9 +-
libspu/mpc/semi2k/boolean.cc | 31 ++--
libspu/mpc/semi2k/boolean.h | 6 +-
libspu/mpc/semi2k/conversion.cc | 20 ++-
libspu/mpc/semi2k/permute.cc | 2 +-
libspu/mpc/spdz2k/abprotocol_spdz2k_test.cc | 4 +-
libspu/mpc/spdz2k/arithmetic.cc | 21 +--
libspu/mpc/spdz2k/arithmetic.h | 2 +-
libspu/mpc/spdz2k/beaver/beaver_test.cc | 7 +-
libspu/mpc/spdz2k/beaver/beaver_tfp.cc | 8 +-
libspu/mpc/spdz2k/beaver/beaver_tinyot.cc | 34 ++--
libspu/mpc/spdz2k/beaver/trusted_party.cc | 5 +-
libspu/mpc/spdz2k/boolean.cc | 45 +++--
libspu/mpc/spdz2k/boolean.h | 6 +-
libspu/mpc/spdz2k/conversion.cc | 30 ++--
libspu/mpc/spdz2k/io.cc | 8 +-
libspu/mpc/spdz2k/value.cc | 8 +-
libspu/mpc/tools/benchmark.h | 32 ++--
libspu/mpc/utils/BUILD.bazel | 1 +
libspu/mpc/utils/circuits.h | 38 ++--
libspu/mpc/utils/circuits_test.cc | 12 +-
libspu/mpc/utils/permute.cc | 8 +-
libspu/mpc/utils/ring_ops.cc | 82 +++++----
libspu/mpc/utils/ring_ops.h | 12 +-
libspu/version.h | 2 +-
187 files changed, 1197 insertions(+), 994 deletions(-)
create mode 100644 libspu/compiler/tests/interpret/reciprocal.mlir
create mode 100644 libspu/compiler/tests/interpret/test_json/reciprocal.json
diff --git a/.clang-tidy b/.clang-tidy
index bacbb3a2..227a407d 100644
--- a/.clang-tidy
+++ b/.clang-tidy
@@ -28,7 +28,8 @@ Checks: "abseil-cleanup-ctad,
-readability-identifier-length,
-readability-function-cognitive-complexity,
-readability-magic-numbers,
- -readability-named-parameter"
+ -readability-named-parameter,
+ -readability-convert-member-functions-to-static"
CheckOptions:
- key: bugprone-argument-comment.StrictMode
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0712ac66..6b86bb5b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,8 +10,9 @@
>
> please add your unreleased change here.
-## TBD
+## 20240716
+- [SPU] 0.9.2b0 release
- [Feature] Support jax.numpy.bitwise_count
- [Bugfix] Fix jax.numpy.signbit wrong answer with very large input
diff --git a/bazel/repositories.bzl b/bazel/repositories.bzl
index 59ed839d..d97c8e68 100644
--- a/bazel/repositories.bzl
+++ b/bazel/repositories.bzl
@@ -136,8 +136,8 @@ def _bazel_skylib():
)
def _com_github_openxla_xla():
- OPENXLA_COMMIT = "9b0dd58c9b625a2e958f4fc7787a1ff5c95dbb40"
- OPENXLA_SHA256 = "f150c5b49e4d4497aae2c79232f1efe2baccaa72223b21dc8715be73eab74417"
+ OPENXLA_COMMIT = "8533a6869ae02fb3b15a8a12739a982fc3c9f6e7"
+ OPENXLA_SHA256 = "d5b076825c992f59542f6b94e5480c7e7c6c627cd18c80ec60b6d5b295c160d4"
# We need openxla to handle xla/mhlo/stablehlo
maybe(
diff --git a/docs/development/add_protocols.rst b/docs/development/add_protocols.rst
index 31b4b5e2..988c846c 100644
--- a/docs/development/add_protocols.rst
+++ b/docs/development/add_protocols.rst
@@ -248,7 +248,7 @@ When kernels are implemented and registered, a new protocol is finally added.
auto* prg_state = ctx->getState();
// dispatch the real implementation to different fields
- return DISPATCH_ALL_FIELDS(field, "aby3.mulAA", [&]() {
+ return DISPATCH_ALL_FIELDS(field, [&]() {
// the real protocol implementation
...
});
diff --git a/examples/python/ml/jax_lr/README.md b/examples/python/ml/jax_lr/README.md
index fc0e6580..bc340199 100644
--- a/examples/python/ml/jax_lr/README.md
+++ b/examples/python/ml/jax_lr/README.md
@@ -5,7 +5,7 @@ This example demonstrates how to use SPU to train a logistic regression model pr
1. Launch SPU backend runtime
```sh
- bazel run -c opt //examples/python/utils:nodectl -- up
+ bazel run -c opt //examples/python/utils:nodectl -- -c examples/python/conf/2pc_semi2k.json up
```
2. Run `jax_lr` example
diff --git a/experimental/squirrel/bin_matvec_prot_test.cc b/experimental/squirrel/bin_matvec_prot_test.cc
index 0d5cae29..452e127d 100644
--- a/experimental/squirrel/bin_matvec_prot_test.cc
+++ b/experimental/squirrel/bin_matvec_prot_test.cc
@@ -110,7 +110,7 @@ TEST_P(BinMatVecProtTest, Basic) {
});
NdArrayRef reveal = ring_add(out_shr[0], out_shr[1]);
- DISPATCH_ALL_FIELDS(field, "", [&]() {
+ DISPATCH_ALL_FIELDS(field, [&]() {
NdArrayView _vec(vec);
auto expected = BinAccumuate(_vec, mat);
NdArrayView got(reveal);
@@ -160,7 +160,7 @@ TEST_P(BinMatVecProtTest, WithIndicator) {
});
NdArrayRef reveal = ring_add(out_shr[0], out_shr[1]);
- DISPATCH_ALL_FIELDS(field, "", [&]() {
+ DISPATCH_ALL_FIELDS(field, [&]() {
NdArrayView _vec(vec);
auto expected =
BinAccumuate(_vec, mat, absl::MakeConstSpan(indicator));
diff --git a/experimental/squirrel/objectives.cc b/experimental/squirrel/objectives.cc
index b9423d2d..baf5f08d 100644
--- a/experimental/squirrel/objectives.cc
+++ b/experimental/squirrel/objectives.cc
@@ -272,9 +272,9 @@ namespace {
res = NdArrayRef(makeType(ftype), in.shape());
}
- return DISPATCH_ALL_FIELDS(field, "cheetah.ring_cast", [&]() {
+ return DISPATCH_ALL_FIELDS(field, [&]() {
using from_ring2k_t = ring2k_t;
- return DISPATCH_ALL_FIELDS(ftype, "cheetah.ring_cast", [&]() {
+ return DISPATCH_ALL_FIELDS(ftype, [&]() {
using to_ring2k_t = ring2k_t;
NdArrayView _in(in);
NdArrayView _res(res);
@@ -383,7 +383,7 @@ spu::Value Logistic(spu::SPUContext* ctx, const spu::Value& x) {
spu::Value Sigmoid(spu::SPUContext* ctx, const spu::Value& x) {
namespace sk = spu::kernel;
auto c05 = sk::hlo::Constant(ctx, 0.5F, x.shape());
- auto half = sk::hal::right_shift_arithmetic(ctx, x, 1);
+ auto half = sk::hal::right_shift_arithmetic(ctx, x, {1});
auto divisor = sk::hlo::Add(ctx, sk::hlo::Constant(ctx, 1, x.shape()),
sk::hal::f_square(ctx, x));
return sk::hlo::Add(ctx, c05,
diff --git a/experimental/squirrel/tree_build_worker.cc b/experimental/squirrel/tree_build_worker.cc
index 30c441f9..06724354 100644
--- a/experimental/squirrel/tree_build_worker.cc
+++ b/experimental/squirrel/tree_build_worker.cc
@@ -230,7 +230,7 @@ void AccumulateHistogram(spu::NdArrayRef buckets_share, size_t nfeatures,
// The buckets belong to the i-th feature is
// `buckets[i*bucket_size:(i+1)*bucket_size]`
auto field = buckets_share.eltype().as()->field();
- DISPATCH_ALL_FIELDS(field, "AccumulateHistogram", [&]() {
+ DISPATCH_ALL_FIELDS(field, [&]() {
NdArrayView histogram(buckets_share);
for (size_t j = 0; j < nfeatures; ++j) {
size_t start = j * bucket_size;
diff --git a/experimental/squirrel/utils.cc b/experimental/squirrel/utils.cc
index 3f14436e..58a98d70 100644
--- a/experimental/squirrel/utils.cc
+++ b/experimental/squirrel/utils.cc
@@ -167,7 +167,7 @@ spu::Value MulPrivateArithWithPrivateBoolean(spu::SPUContext* ctx,
&kctx, arith.data(),
[&](const NdArrayRef& input,
const std::shared_ptr& base_ot) {
- return DISPATCH_ALL_FIELDS(ft, "ot", [&]() {
+ return DISPATCH_ALL_FIELDS(ft, [&]() {
NdArrayRef ot_out = spu::mpc::ring_zeros(ft, input.shape());
auto inp = absl::MakeConstSpan(&input.at(0), input.numel());
auto oup = absl::MakeSpan(&ot_out.at(0), ot_out.numel());
@@ -193,7 +193,7 @@ spu::Value MulPrivateArithWithPrivateBoolean(spu::SPUContext* ctx,
&kctx, boolean,
[&](absl::Span input,
const std::shared_ptr& base_ot) {
- return DISPATCH_ALL_FIELDS(ft, "ot", [&]() {
+ return DISPATCH_ALL_FIELDS(ft, [&]() {
NdArrayRef ot_out = spu::mpc::ring_zeros(ft, {(int64_t)input.size()});
auto oup = absl::MakeSpan(&ot_out.at(0), input.size());
base_ot->GetReceiverCOT()->RecvCAMCC(input, oup);
@@ -222,7 +222,7 @@ spu::Value MulArithShareWithANDBoolShare(spu::SPUContext* ctx,
std::shared_ptr base_ot) {
NdArrayRef out(x.eltype(), x.shape());
- DISPATCH_ALL_FIELDS(ft, "camcc", [&]() {
+ DISPATCH_ALL_FIELDS(ft, [&]() {
spu::NdArrayView _ashr(x);
auto oup = absl::MakeSpan(&out.at(0), y.size());
std::vector corr(y.size());
diff --git a/experimental/squirrel/utils_test.cc b/experimental/squirrel/utils_test.cc
index 3bdc004e..8cde97fd 100644
--- a/experimental/squirrel/utils_test.cc
+++ b/experimental/squirrel/utils_test.cc
@@ -85,7 +85,7 @@ TEST_F(UtilsTest, ReduceSum) {
const double fxp = std::pow(2., rt_config.fxp_fraction_bits());
auto flatten = got.data().reshape({got.numel()});
- DISPATCH_ALL_FIELDS(field, "check", [&]() {
+ DISPATCH_ALL_FIELDS(field, [&]() {
using s2k = std::make_signed::type;
NdArrayView got(flatten);
for (int64_t i = 0; i < expected.numel(); ++i) {
@@ -136,7 +136,7 @@ TEST_F(UtilsTest, ArgMax) {
if (lctx->Rank() == 0) {
auto flatten = got.data().reshape({got.numel()});
- DISPATCH_ALL_FIELDS(field, "check", [&]() {
+ DISPATCH_ALL_FIELDS(field, [&]() {
NdArrayView got(flatten);
for (size_t i = 0; i < expected.size(); ++i) {
ASSERT_EQ(expected(i), got[i]);
diff --git a/libspu/compiler/common/compilation_context.h b/libspu/compiler/common/compilation_context.h
index 24ebe043..8abfed96 100644
--- a/libspu/compiler/common/compilation_context.h
+++ b/libspu/compiler/common/compilation_context.h
@@ -16,7 +16,6 @@
#include
#include
-#include
#include "mlir/IR/MLIRContext.h"
#include "mlir/Pass/PassManager.h"
diff --git a/libspu/compiler/compile.cc b/libspu/compiler/compile.cc
index 38a59acb..236d6d02 100644
--- a/libspu/compiler/compile.cc
+++ b/libspu/compiler/compile.cc
@@ -14,8 +14,6 @@
#include "libspu/compiler/compile.h"
-#include "mlir/IR/BuiltinOps.h"
-
#include "libspu/compiler/codegen/codegen.h"
#include "libspu/compiler/common/compilation_context.h"
#include "libspu/compiler/core/core.h"
diff --git a/libspu/compiler/core/core.cc b/libspu/compiler/core/core.cc
index e977050c..b7c46d8f 100644
--- a/libspu/compiler/core/core.cc
+++ b/libspu/compiler/core/core.cc
@@ -16,7 +16,6 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
-#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
diff --git a/libspu/compiler/front_end/fe.cc b/libspu/compiler/front_end/fe.cc
index e560c772..5bcb693f 100644
--- a/libspu/compiler/front_end/fe.cc
+++ b/libspu/compiler/front_end/fe.cc
@@ -21,7 +21,6 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
-#include "spdlog/spdlog.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
#include "xla/mlir_hlo/mhlo/transforms/passes.h"
diff --git a/libspu/compiler/tests/interpret/abs.mlir b/libspu/compiler/tests/interpret/abs.mlir
index dd655bf1..49e409a8 100644
--- a/libspu/compiler/tests/interpret/abs.mlir
+++ b/libspu/compiler/tests/interpret/abs.mlir
@@ -1,6 +1,8 @@
// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s
// AUTO GENERATED, DO NOT EDIT
func.func @abs_op_test_i64_i64_p() {
diff --git a/libspu/compiler/tests/interpret/add.mlir b/libspu/compiler/tests/interpret/add.mlir
index a763c0ae..5074d3eb 100644
--- a/libspu/compiler/tests/interpret/add.mlir
+++ b/libspu/compiler/tests/interpret/add.mlir
@@ -1,6 +1,8 @@
// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s
// AUTO GENERATED, DO NOT EDIT
func.func @add_op_test_i8_i8_pp() {
diff --git a/libspu/compiler/tests/interpret/and.mlir b/libspu/compiler/tests/interpret/and.mlir
index 6e3cdfa6..618d3222 100644
--- a/libspu/compiler/tests/interpret/and.mlir
+++ b/libspu/compiler/tests/interpret/and.mlir
@@ -1,6 +1,8 @@
// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s
// AUTO GENERATED, DO NOT EDIT
func.func @and_op_test_i8_i8_pp() {
diff --git a/libspu/compiler/tests/interpret/atan2.mlir b/libspu/compiler/tests/interpret/atan2.mlir
index e1938852..52c2b122 100644
--- a/libspu/compiler/tests/interpret/atan2.mlir
+++ b/libspu/compiler/tests/interpret/atan2.mlir
@@ -1,6 +1,8 @@
// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s
// AUTO GENERATED, DO NOT EDIT
func.func @atan2_op_test_f64_f64_pp() {
diff --git a/libspu/compiler/tests/interpret/broadcast.mlir b/libspu/compiler/tests/interpret/broadcast.mlir
index 1d0e9a63..7e2d17f8 100644
--- a/libspu/compiler/tests/interpret/broadcast.mlir
+++ b/libspu/compiler/tests/interpret/broadcast.mlir
@@ -1,6 +1,8 @@
// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s
func.func @broadcast_in_dim() {
%operand = pphlo.constant dense<[[1], [2], [3]]> : tensor<3x1xi64>
diff --git a/libspu/compiler/tests/interpret/case.mlir b/libspu/compiler/tests/interpret/case.mlir
index 543fbf86..96105d4e 100644
--- a/libspu/compiler/tests/interpret/case.mlir
+++ b/libspu/compiler/tests/interpret/case.mlir
@@ -1,6 +1,8 @@
// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s
func.func @case_negative_index_default() {
%index = pphlo.constant dense<-1> : tensor
diff --git a/libspu/compiler/tests/interpret/ceil.mlir b/libspu/compiler/tests/interpret/ceil.mlir
index d4f74b8b..23c05538 100644
--- a/libspu/compiler/tests/interpret/ceil.mlir
+++ b/libspu/compiler/tests/interpret/ceil.mlir
@@ -1,6 +1,8 @@
// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s
// AUTO GENERATED, DO NOT EDIT
func.func @ceil_op_test_f16_f16_p() {
diff --git a/libspu/compiler/tests/interpret/clamp.mlir b/libspu/compiler/tests/interpret/clamp.mlir
index 69c05696..9e76acc8 100644
--- a/libspu/compiler/tests/interpret/clamp.mlir
+++ b/libspu/compiler/tests/interpret/clamp.mlir
@@ -1,6 +1,8 @@
// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s
func.func @clamp_op_test_si64() {
%min = pphlo.constant dense<[1, 5, -5]> : tensor<3xi64>
diff --git a/libspu/compiler/tests/interpret/concatenate.mlir b/libspu/compiler/tests/interpret/concatenate.mlir
index 1d0d7655..63107f22 100644
--- a/libspu/compiler/tests/interpret/concatenate.mlir
+++ b/libspu/compiler/tests/interpret/concatenate.mlir
@@ -1,6 +1,8 @@
// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s
func.func @concatenate() {
%input0 = pphlo.constant dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi64>
diff --git a/libspu/compiler/tests/interpret/convert.mlir b/libspu/compiler/tests/interpret/convert.mlir
index 747047ee..b3b23f3b 100644
--- a/libspu/compiler/tests/interpret/convert.mlir
+++ b/libspu/compiler/tests/interpret/convert.mlir
@@ -1,6 +1,8 @@
// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s
func.func @convert_op_test_1() {
%0 = pphlo.constant dense<[0, 1, 8, -9, 0]> : tensor<5xi32>
diff --git a/libspu/compiler/tests/interpret/convolution.mlir b/libspu/compiler/tests/interpret/convolution.mlir
index 294cdd36..2ce1030c 100644
--- a/libspu/compiler/tests/interpret/convolution.mlir
+++ b/libspu/compiler/tests/interpret/convolution.mlir
@@ -1,6 +1,8 @@
// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s
func.func @main() {
%0 = pphlo.constant dense<[[[[ 1.0, 2.0, 3.0, 4.0],
diff --git a/libspu/compiler/tests/interpret/cosine.mlir b/libspu/compiler/tests/interpret/cosine.mlir
index 72d6f249..5ee36f58 100644
--- a/libspu/compiler/tests/interpret/cosine.mlir
+++ b/libspu/compiler/tests/interpret/cosine.mlir
@@ -1,6 +1,8 @@
// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s
// AUTO GENERATED, DO NOT EDIT
func.func @cosine_op_test_f16_f16_p() {
diff --git a/libspu/compiler/tests/interpret/divide.mlir b/libspu/compiler/tests/interpret/divide.mlir
index 56ce53c2..a7540602 100644
--- a/libspu/compiler/tests/interpret/divide.mlir
+++ b/libspu/compiler/tests/interpret/divide.mlir
@@ -1,6 +1,8 @@
// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s
// AUTO GENERATED, DO NOT EDIT
func.func @divide_op_test_i64_i64_pp() {
diff --git a/libspu/compiler/tests/interpret/dot_general.mlir b/libspu/compiler/tests/interpret/dot_general.mlir
index 29023c81..a4c506ed 100644
--- a/libspu/compiler/tests/interpret/dot_general.mlir
+++ b/libspu/compiler/tests/interpret/dot_general.mlir
@@ -1,6 +1,8 @@
// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s
func.func @dot_general_op_test_si64() {
%lhs = pphlo.constant dense<[[[1, 2], [3, 4]],
diff --git a/libspu/compiler/tests/interpret/dynamic_slice.mlir b/libspu/compiler/tests/interpret/dynamic_slice.mlir
index 672ace7c..d0275fd9 100644
--- a/libspu/compiler/tests/interpret/dynamic_slice.mlir
+++ b/libspu/compiler/tests/interpret/dynamic_slice.mlir
@@ -1,6 +1,8 @@
// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s
func.func @dynamic_slice() {
%operand = pphlo.constant dense<[[1, 1, 1],
diff --git a/libspu/compiler/tests/interpret/dynamic_update_slice.mlir b/libspu/compiler/tests/interpret/dynamic_update_slice.mlir
index a7053428..423528ac 100644
--- a/libspu/compiler/tests/interpret/dynamic_update_slice.mlir
+++ b/libspu/compiler/tests/interpret/dynamic_update_slice.mlir
@@ -1,6 +1,8 @@
// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s
func.func @dynamic_update_slice() {
%operand = pphlo.constant dense<[[1, 1, 1, 1],
diff --git a/libspu/compiler/tests/interpret/equal.mlir b/libspu/compiler/tests/interpret/equal.mlir
index f5364638..c8291d1a 100644
--- a/libspu/compiler/tests/interpret/equal.mlir
+++ b/libspu/compiler/tests/interpret/equal.mlir
@@ -1,6 +1,8 @@
// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s
// AUTO GENERATED, DO NOT EDIT
func.func @equal_op_test_i64_i1_pp() {
diff --git a/libspu/compiler/tests/interpret/exponential.mlir b/libspu/compiler/tests/interpret/exponential.mlir
index a85c5bb6..21ec3591 100644
--- a/libspu/compiler/tests/interpret/exponential.mlir
+++ b/libspu/compiler/tests/interpret/exponential.mlir
@@ -1,6 +1,8 @@
// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s
// AUTO GENERATED, DO NOT EDIT
func.func @exponential_op_test_f64_f64_p() {
diff --git a/libspu/compiler/tests/interpret/exponential_minus_one.mlir b/libspu/compiler/tests/interpret/exponential_minus_one.mlir
index 1a337a21..35131bbc 100644
--- a/libspu/compiler/tests/interpret/exponential_minus_one.mlir
+++ b/libspu/compiler/tests/interpret/exponential_minus_one.mlir
@@ -1,6 +1,8 @@
// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s
// AUTO GENERATED, DO NOT EDIT
func.func @exponential_minus_one_op_test_f64_f64_p() {
diff --git a/libspu/compiler/tests/interpret/floor.mlir b/libspu/compiler/tests/interpret/floor.mlir
index 97ccdb1e..602d0161 100644
--- a/libspu/compiler/tests/interpret/floor.mlir
+++ b/libspu/compiler/tests/interpret/floor.mlir
@@ -1,6 +1,8 @@
// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s
// AUTO GENERATED, DO NOT EDIT
func.func @floor_op_test_f16_f16_p() {
diff --git a/libspu/compiler/tests/interpret/generate_mlir_tests.py b/libspu/compiler/tests/interpret/generate_mlir_tests.py
index f9bdef86..7a271164 100755
--- a/libspu/compiler/tests/interpret/generate_mlir_tests.py
+++ b/libspu/compiler/tests/interpret/generate_mlir_tests.py
@@ -65,6 +65,7 @@ def TestCase(inputs, expected, checker='expect_eq', tol=None):
"or",
# "popcnt",
"power",
+ "reciprocal",
"reshape",
"round_afz",
"rsqrt",
@@ -99,6 +100,14 @@ def TestCase(inputs, expected, checker='expect_eq', tol=None):
f.write(
"// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s\n"
)
+ f.write(
+ "// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s\n"
+ )
+ # Some test values in max and min are not supported by protocol 5.
+ if test not in ["max", "min"]:
+ f.write(
+ "// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s\n"
+ )
f.write("// AUTO GENERATED, DO NOT EDIT\n\n")
# Emit cases
diff --git a/libspu/compiler/tests/interpret/greater.mlir b/libspu/compiler/tests/interpret/greater.mlir
index 7f8e76be..92140f85 100644
--- a/libspu/compiler/tests/interpret/greater.mlir
+++ b/libspu/compiler/tests/interpret/greater.mlir
@@ -1,6 +1,8 @@
// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s
// AUTO GENERATED, DO NOT EDIT
func.func @greater_op_test_i64_i1_pp() {
diff --git a/libspu/compiler/tests/interpret/greater_equal.mlir b/libspu/compiler/tests/interpret/greater_equal.mlir
index 305aaff5..af3ffa7a 100644
--- a/libspu/compiler/tests/interpret/greater_equal.mlir
+++ b/libspu/compiler/tests/interpret/greater_equal.mlir
@@ -1,6 +1,8 @@
// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s
// AUTO GENERATED, DO NOT EDIT
func.func @greater_equal_op_test_i64_i1_pp() {
diff --git a/libspu/compiler/tests/interpret/if.mlir b/libspu/compiler/tests/interpret/if.mlir
index cc98b65d..ef73d547 100644
--- a/libspu/compiler/tests/interpret/if.mlir
+++ b/libspu/compiler/tests/interpret/if.mlir
@@ -1,6 +1,8 @@
// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s
func.func @if_ops_true_branch() {
%pred = pphlo.constant dense : tensor
diff --git a/libspu/compiler/tests/interpret/iota.mlir b/libspu/compiler/tests/interpret/iota.mlir
index a7ee86ed..cc71b040 100644
--- a/libspu/compiler/tests/interpret/iota.mlir
+++ b/libspu/compiler/tests/interpret/iota.mlir
@@ -1,6 +1,8 @@
// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s
func.func @iota_op_test_si8_dim_0() {
%0 = pphlo.iota dim = 0 : tensor<3x4xi8>
diff --git a/libspu/compiler/tests/interpret/less.mlir b/libspu/compiler/tests/interpret/less.mlir
index 58444a29..1a9d3060 100644
--- a/libspu/compiler/tests/interpret/less.mlir
+++ b/libspu/compiler/tests/interpret/less.mlir
@@ -1,6 +1,8 @@
// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s
// AUTO GENERATED, DO NOT EDIT
func.func @less_op_test_i64_i1_pp() {
diff --git a/libspu/compiler/tests/interpret/less_equal.mlir b/libspu/compiler/tests/interpret/less_equal.mlir
index 9951a569..c454bcc7 100644
--- a/libspu/compiler/tests/interpret/less_equal.mlir
+++ b/libspu/compiler/tests/interpret/less_equal.mlir
@@ -1,6 +1,8 @@
// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s
// AUTO GENERATED, DO NOT EDIT
func.func @less_equal_op_test_i64_i1_pp() {
diff --git a/libspu/compiler/tests/interpret/log.mlir b/libspu/compiler/tests/interpret/log.mlir
index af61d86a..bf309137 100644
--- a/libspu/compiler/tests/interpret/log.mlir
+++ b/libspu/compiler/tests/interpret/log.mlir
@@ -1,6 +1,8 @@
// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s
// AUTO GENERATED, DO NOT EDIT
func.func @log_op_test_f64_f64_p() {
diff --git a/libspu/compiler/tests/interpret/log_plus_one.mlir b/libspu/compiler/tests/interpret/log_plus_one.mlir
index 3aec9184..99bcf9ff 100644
--- a/libspu/compiler/tests/interpret/log_plus_one.mlir
+++ b/libspu/compiler/tests/interpret/log_plus_one.mlir
@@ -1,6 +1,8 @@
// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s
// AUTO GENERATED, DO NOT EDIT
func.func @log_plus_one_op_test_f64_f64_p() {
diff --git a/libspu/compiler/tests/interpret/logistic.mlir b/libspu/compiler/tests/interpret/logistic.mlir
index eeac6fab..655cb82b 100644
--- a/libspu/compiler/tests/interpret/logistic.mlir
+++ b/libspu/compiler/tests/interpret/logistic.mlir
@@ -1,6 +1,8 @@
// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s
// AUTO GENERATED, DO NOT EDIT
func.func @logistic_op_test_f64_f64_p() {
diff --git a/libspu/compiler/tests/interpret/maximum.mlir b/libspu/compiler/tests/interpret/maximum.mlir
index b90553e7..4919c356 100644
--- a/libspu/compiler/tests/interpret/maximum.mlir
+++ b/libspu/compiler/tests/interpret/maximum.mlir
@@ -1,6 +1,7 @@
// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s
// AUTO GENERATED, DO NOT EDIT
func.func @maximum_op_test_i8_i8_pp() {
diff --git a/libspu/compiler/tests/interpret/minimum.mlir b/libspu/compiler/tests/interpret/minimum.mlir
index 74bb2b83..1853124b 100644
--- a/libspu/compiler/tests/interpret/minimum.mlir
+++ b/libspu/compiler/tests/interpret/minimum.mlir
@@ -1,6 +1,7 @@
// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s
// AUTO GENERATED, DO NOT EDIT
func.func @minimum_op_test_i8_i8_pp() {
diff --git a/libspu/compiler/tests/interpret/multiply.mlir b/libspu/compiler/tests/interpret/multiply.mlir
index 82b780f3..f7d415c4 100644
--- a/libspu/compiler/tests/interpret/multiply.mlir
+++ b/libspu/compiler/tests/interpret/multiply.mlir
@@ -1,6 +1,8 @@
// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s
// AUTO GENERATED, DO NOT EDIT
func.func @multiply_op_test_i8_i8_pp() {
diff --git a/libspu/compiler/tests/interpret/negate.mlir b/libspu/compiler/tests/interpret/negate.mlir
index 8c5e74c3..6a00d9b8 100644
--- a/libspu/compiler/tests/interpret/negate.mlir
+++ b/libspu/compiler/tests/interpret/negate.mlir
@@ -1,6 +1,8 @@
// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s
// AUTO GENERATED, DO NOT EDIT
func.func @negate_op_test_i8_i8_p() {
diff --git a/libspu/compiler/tests/interpret/not.mlir b/libspu/compiler/tests/interpret/not.mlir
index 0fdf44e4..721e1850 100644
--- a/libspu/compiler/tests/interpret/not.mlir
+++ b/libspu/compiler/tests/interpret/not.mlir
@@ -1,6 +1,8 @@
// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s
// AUTO GENERATED, DO NOT EDIT
func.func @not_op_test_i8_i8_p() {
diff --git a/libspu/compiler/tests/interpret/not_equal.mlir b/libspu/compiler/tests/interpret/not_equal.mlir
index 1bbbd5b3..cb598070 100644
--- a/libspu/compiler/tests/interpret/not_equal.mlir
+++ b/libspu/compiler/tests/interpret/not_equal.mlir
@@ -1,6 +1,8 @@
// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s
// AUTO GENERATED, DO NOT EDIT
func.func @not_equal_op_test_i64_i1_pp() {
diff --git a/libspu/compiler/tests/interpret/or.mlir b/libspu/compiler/tests/interpret/or.mlir
index 79836813..1eef975f 100644
--- a/libspu/compiler/tests/interpret/or.mlir
+++ b/libspu/compiler/tests/interpret/or.mlir
@@ -1,6 +1,8 @@
// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s
// AUTO GENERATED, DO NOT EDIT
func.func @or_op_test_i8_i8_pp() {
diff --git a/libspu/compiler/tests/interpret/pad.mlir b/libspu/compiler/tests/interpret/pad.mlir
index abedc73a..521313e2 100644
--- a/libspu/compiler/tests/interpret/pad.mlir
+++ b/libspu/compiler/tests/interpret/pad.mlir
@@ -1,6 +1,8 @@
// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s
func.func @pad() {
%operand = pphlo.constant dense<[[0, 0, 0, 0],
diff --git a/libspu/compiler/tests/interpret/popcnt.mlir b/libspu/compiler/tests/interpret/popcnt.mlir
index e5f83152..efea9133 100644
--- a/libspu/compiler/tests/interpret/popcnt.mlir
+++ b/libspu/compiler/tests/interpret/popcnt.mlir
@@ -1,6 +1,8 @@
// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s
// AUTO GENERATED, DO NOT EDIT
func.func @popcnt_op_test_i64_i64_p() {
diff --git a/libspu/compiler/tests/interpret/power.mlir b/libspu/compiler/tests/interpret/power.mlir
index 307fce4c..b6bfc475 100644
--- a/libspu/compiler/tests/interpret/power.mlir
+++ b/libspu/compiler/tests/interpret/power.mlir
@@ -1,6 +1,8 @@
// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s
// AUTO GENERATED, DO NOT EDIT
func.func @power_op_test_i64_i64_pp() {
diff --git a/libspu/compiler/tests/interpret/reciprocal.mlir b/libspu/compiler/tests/interpret/reciprocal.mlir
new file mode 100644
index 00000000..fe1623a5
--- /dev/null
+++ b/libspu/compiler/tests/interpret/reciprocal.mlir
@@ -0,0 +1,26 @@
+// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s
+// AUTO GENERATED, DO NOT EDIT
+
+func.func @reciprocal_op_test_f64_f64_p() {
+ %0 = pphlo.constant dense<[[1.0, -200.0], [100.0, 286991.875]]> : tensor<2x2xf64>
+ %1 = pphlo.reciprocal %0 : (tensor<2x2xf64>)->tensor<2x2xf64>
+ %2 = pphlo.constant dense<[[1.0, -0.005], [0.01, 0.0]]> : tensor<2x2xf64>
+ pphlo.custom_call @expect_almost_eq(%1, %2) { tol = 0.01 }: (tensor<2x2xf64>, tensor<2x2xf64>)->()
+ func.return
+}
+
+// -----
+
+func.func @reciprocal_op_test_f64_f64_s() {
+ %0 = pphlo.constant dense<[[1.0, -200.0], [100.0, 286991.875]]> : tensor<2x2xf64>
+ %1 = pphlo.convert %0 : (tensor<2x2xf64>)->tensor<2x2x!pphlo.secret>
+ %2 = pphlo.reciprocal %1 : (tensor<2x2x!pphlo.secret>)->tensor<2x2x!pphlo.secret>
+ %3 = pphlo.constant dense<[[1.0, -0.005], [0.01, 0.0]]> : tensor<2x2xf64>
+ %4 = pphlo.convert %2 : (tensor<2x2x!pphlo.secret>)->tensor<2x2xf64>
+ pphlo.custom_call @expect_almost_eq(%3, %4) { tol = 0.01 }: (tensor<2x2xf64>, tensor<2x2xf64>)->()
+ func.return
+}
diff --git a/libspu/compiler/tests/interpret/reduce.mlir b/libspu/compiler/tests/interpret/reduce.mlir
index 3b34bf3e..efc1d40d 100644
--- a/libspu/compiler/tests/interpret/reduce.mlir
+++ b/libspu/compiler/tests/interpret/reduce.mlir
@@ -1,6 +1,8 @@
// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s
func.func @reduce() {
%input = pphlo.constant dense<[[0, 1, 2, 3, 4, 5]]> : tensor<1x6xi64>
diff --git a/libspu/compiler/tests/interpret/reduce_window.mlir b/libspu/compiler/tests/interpret/reduce_window.mlir
index 5385ab04..d15cd021 100644
--- a/libspu/compiler/tests/interpret/reduce_window.mlir
+++ b/libspu/compiler/tests/interpret/reduce_window.mlir
@@ -1,6 +1,8 @@
// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s
func.func @reduce_window() {
%input = pphlo.constant dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi64>
diff --git a/libspu/compiler/tests/interpret/reshape.mlir b/libspu/compiler/tests/interpret/reshape.mlir
index 9e483b77..0a670923 100644
--- a/libspu/compiler/tests/interpret/reshape.mlir
+++ b/libspu/compiler/tests/interpret/reshape.mlir
@@ -1,6 +1,8 @@
// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s
// AUTO GENERATED, DO NOT EDIT
func.func @reshape_op_test_i32_i32_p() {
diff --git a/libspu/compiler/tests/interpret/reverse.mlir b/libspu/compiler/tests/interpret/reverse.mlir
index 63ab9590..832262e2 100644
--- a/libspu/compiler/tests/interpret/reverse.mlir
+++ b/libspu/compiler/tests/interpret/reverse.mlir
@@ -1,6 +1,8 @@
// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s
func.func @reverse() {
%operand = pphlo.constant dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi64>
diff --git a/libspu/compiler/tests/interpret/ring_cast.mlir b/libspu/compiler/tests/interpret/ring_cast.mlir
index 94b53241..a6b6806c 100644
--- a/libspu/compiler/tests/interpret/ring_cast.mlir
+++ b/libspu/compiler/tests/interpret/ring_cast.mlir
@@ -1,6 +1,8 @@
// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s
func.func @cast_1() {
%c0 = pphlo.constant dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi32>
diff --git a/libspu/compiler/tests/interpret/round_nearest_afz.mlir b/libspu/compiler/tests/interpret/round_nearest_afz.mlir
index 6601fa1d..40e64ebe 100644
--- a/libspu/compiler/tests/interpret/round_nearest_afz.mlir
+++ b/libspu/compiler/tests/interpret/round_nearest_afz.mlir
@@ -1,6 +1,8 @@
// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s
// AUTO GENERATED, DO NOT EDIT
func.func @round_nearest_afz_op_test_f64_f64_p() {
diff --git a/libspu/compiler/tests/interpret/rsqrt.mlir b/libspu/compiler/tests/interpret/rsqrt.mlir
index ef6470c9..5e5e0869 100644
--- a/libspu/compiler/tests/interpret/rsqrt.mlir
+++ b/libspu/compiler/tests/interpret/rsqrt.mlir
@@ -1,6 +1,8 @@
// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s
// AUTO GENERATED, DO NOT EDIT
func.func @rsqrt_op_test_f64_f64_p() {
diff --git a/libspu/compiler/tests/interpret/select.mlir b/libspu/compiler/tests/interpret/select.mlir
index 33e9a9c7..c2752761 100644
--- a/libspu/compiler/tests/interpret/select.mlir
+++ b/libspu/compiler/tests/interpret/select.mlir
@@ -1,6 +1,8 @@
// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s
func.func @select_op_test_si64() {
%pred = pphlo.constant dense<[true, false, true]> : tensor<3xi1>
diff --git a/libspu/compiler/tests/interpret/select_and_scatter.mlir b/libspu/compiler/tests/interpret/select_and_scatter.mlir
index 5e55dc59..c7f0057a 100644
--- a/libspu/compiler/tests/interpret/select_and_scatter.mlir
+++ b/libspu/compiler/tests/interpret/select_and_scatter.mlir
@@ -1,6 +1,8 @@
// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s
// FIXME
func.func @select_and_scatter_op_test() {
diff --git a/libspu/compiler/tests/interpret/shift_right_arithmetic.mlir b/libspu/compiler/tests/interpret/shift_right_arithmetic.mlir
index 494666ae..ce4bedbc 100644
--- a/libspu/compiler/tests/interpret/shift_right_arithmetic.mlir
+++ b/libspu/compiler/tests/interpret/shift_right_arithmetic.mlir
@@ -1,6 +1,8 @@
// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s
// AUTO GENERATED, DO NOT EDIT
func.func @shift_right_arithmetic_op_test_i64_i64_pp() {
diff --git a/libspu/compiler/tests/interpret/shift_right_logical.mlir b/libspu/compiler/tests/interpret/shift_right_logical.mlir
index 253d6cc3..73d69f0a 100644
--- a/libspu/compiler/tests/interpret/shift_right_logical.mlir
+++ b/libspu/compiler/tests/interpret/shift_right_logical.mlir
@@ -1,6 +1,8 @@
// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s
// AUTO GENERATED, DO NOT EDIT
func.func @shift_right_logical_op_test_i64_i64_pp() {
diff --git a/libspu/compiler/tests/interpret/sign.mlir b/libspu/compiler/tests/interpret/sign.mlir
index b153a373..105c5cb2 100644
--- a/libspu/compiler/tests/interpret/sign.mlir
+++ b/libspu/compiler/tests/interpret/sign.mlir
@@ -1,6 +1,8 @@
// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s
// AUTO GENERATED, DO NOT EDIT
func.func @sign_op_test_i64_i64_p() {
diff --git a/libspu/compiler/tests/interpret/sine.mlir b/libspu/compiler/tests/interpret/sine.mlir
index f0b59e5c..1d62529d 100644
--- a/libspu/compiler/tests/interpret/sine.mlir
+++ b/libspu/compiler/tests/interpret/sine.mlir
@@ -1,6 +1,8 @@
// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s
// AUTO GENERATED, DO NOT EDIT
func.func @sine_op_test_f16_f16_p() {
diff --git a/libspu/compiler/tests/interpret/slice.mlir b/libspu/compiler/tests/interpret/slice.mlir
index b4b8bdbe..583b977b 100644
--- a/libspu/compiler/tests/interpret/slice.mlir
+++ b/libspu/compiler/tests/interpret/slice.mlir
@@ -1,6 +1,8 @@
// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s
func.func @slice_op() {
%operand = pphlo.constant dense<[[0, 0, 1, 0, 0, 1],
diff --git a/libspu/compiler/tests/interpret/sort.mlir b/libspu/compiler/tests/interpret/sort.mlir
index 837ded37..2433d4ae 100644
--- a/libspu/compiler/tests/interpret/sort.mlir
+++ b/libspu/compiler/tests/interpret/sort.mlir
@@ -1,6 +1,8 @@
// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s
func.func @sort_stable() {
%input0 = pphlo.constant dense<[[1, 2, 3], [3, 2, 1]]> : tensor<2x3xi64>
diff --git a/libspu/compiler/tests/interpret/sqrt.mlir b/libspu/compiler/tests/interpret/sqrt.mlir
index 9248077a..59372c6b 100644
--- a/libspu/compiler/tests/interpret/sqrt.mlir
+++ b/libspu/compiler/tests/interpret/sqrt.mlir
@@ -1,6 +1,8 @@
// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s
// AUTO GENERATED, DO NOT EDIT
func.func @sqrt_op_test_f64_f64_p() {
diff --git a/libspu/compiler/tests/interpret/subtract.mlir b/libspu/compiler/tests/interpret/subtract.mlir
index ce8b9633..37032882 100644
--- a/libspu/compiler/tests/interpret/subtract.mlir
+++ b/libspu/compiler/tests/interpret/subtract.mlir
@@ -1,6 +1,8 @@
// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s
// AUTO GENERATED, DO NOT EDIT
func.func @subtract_op_test_i8_i8_pp() {
diff --git a/libspu/compiler/tests/interpret/tanh.mlir b/libspu/compiler/tests/interpret/tanh.mlir
index 413dc6d2..ab45fa2e 100644
--- a/libspu/compiler/tests/interpret/tanh.mlir
+++ b/libspu/compiler/tests/interpret/tanh.mlir
@@ -1,6 +1,8 @@
// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s
// AUTO GENERATED, DO NOT EDIT
func.func @tanh_op_test_f16_f16_p() {
diff --git a/libspu/compiler/tests/interpret/test_json/reciprocal.json b/libspu/compiler/tests/interpret/test_json/reciprocal.json
new file mode 100644
index 00000000..d8e07529
--- /dev/null
+++ b/libspu/compiler/tests/interpret/test_json/reciprocal.json
@@ -0,0 +1,24 @@
+{
+ "name": "reciprocal",
+ "template": "basic_unary",
+ "testcases": [
+ {
+ "inputs": [
+ {
+ "data": "[[1.0, -200.0], [100.0, 286991.875]]",
+ "shape": "2x2",
+ "dtype": "f64"
+ }
+ ],
+ "expected": [
+ {
+ "data": "[[1.0, -0.005], [0.01, 0.0]]",
+ "shape": "2x2",
+ "dtype": "f64"
+ }
+ ],
+ "checker": "expect_almost_eq",
+ "tol": 0.01
+ }
+ ]
+}
\ No newline at end of file
diff --git a/libspu/compiler/tests/interpret/transpose.mlir b/libspu/compiler/tests/interpret/transpose.mlir
index d5ce9b6e..d36883de 100644
--- a/libspu/compiler/tests/interpret/transpose.mlir
+++ b/libspu/compiler/tests/interpret/transpose.mlir
@@ -1,6 +1,8 @@
// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s
func.func @transpose_op_test_si32() {
%0 = pphlo.constant dense<[[[1,2],[3,4],[5,6]], [[7,8],[9,10],[11,12]]]> : tensor<2x3x2xi32>
diff --git a/libspu/compiler/tests/interpret/while.mlir b/libspu/compiler/tests/interpret/while.mlir
index 32789fae..734e199e 100644
--- a/libspu/compiler/tests/interpret/while.mlir
+++ b/libspu/compiler/tests/interpret/while.mlir
@@ -1,6 +1,8 @@
// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s
func.func @while() {
// int i = 0;
diff --git a/libspu/compiler/tests/interpret/xor.mlir b/libspu/compiler/tests/interpret/xor.mlir
index c2e1f1ee..a9ee8348 100644
--- a/libspu/compiler/tests/interpret/xor.mlir
+++ b/libspu/compiler/tests/interpret/xor.mlir
@@ -1,6 +1,8 @@
// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s
+// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s
// AUTO GENERATED, DO NOT EDIT
func.func @xor_op_test_i8_i8_pp() {
diff --git a/libspu/compiler/tests/passes/hlo2pphlo/select_and_scatter.mlir b/libspu/compiler/tests/passes/hlo2pphlo/select_and_scatter.mlir
index ac7492bc..bc21172f 100644
--- a/libspu/compiler/tests/passes/hlo2pphlo/select_and_scatter.mlir
+++ b/libspu/compiler/tests/passes/hlo2pphlo/select_and_scatter.mlir
@@ -1,7 +1,7 @@
// RUN: spu-opt --hlo-legalize-to-pphlo=input_vis_list=VIS_SECRET,VIS_PUBLIC,VIS_PUBLIC --lower-conversion-cast --split-input-file %s | FileCheck %s
func.func @main(%arg0: tensor<128x5x5x32xf32>, %arg1: tensor<128x4x4x32xf32>, %arg2: tensor) -> tensor<128x5x5x32xf32> {
- // CHECK: %1 = "pphlo.select_and_scatter"(%arg0, %arg1, %0) ({
+ // CHECK: %1 = "pphlo.select_and_scatter"(%arg0, %arg1, %0) <{window_dimensions = array, window_strides = array}> ({
// CHECK: ^bb0(%arg3: tensor>, %arg4: tensor>):
// CHECK: %2 = pphlo.greater_equal %arg3, %arg4 : (tensor>, tensor>) -> tensor>
// CHECK: pphlo.return %2 : tensor>
@@ -9,7 +9,7 @@ func.func @main(%arg0: tensor<128x5x5x32xf32>, %arg1: tensor<128x4x4x32xf32>, %a
// CHECK: ^bb0(%arg3: tensor, %arg4: tensor>):
// CHECK: %2 = pphlo.add %arg3, %arg4 : (tensor, tensor>) -> tensor>
// CHECK: pphlo.return %2 : tensor>
- // CHECK: }) {window_dimensions = array, window_strides = array} : (tensor<128x5x5x32x!pphlo.secret>, tensor<128x4x4x32xf32>, tensor>) -> tensor<128x5x5x32x!pphlo.secret>
+ // CHECK: }) : (tensor<128x5x5x32x!pphlo.secret>, tensor<128x4x4x32xf32>, tensor>) -> tensor<128x5x5x32x!pphlo.secret>
%0 = "stablehlo.select_and_scatter"(%arg0, %arg1, %arg2) ({
^bb0(%arg3: tensor, %arg4: tensor):
%1 = "stablehlo.compare"(%arg3, %arg4) {comparison_direction = #stablehlo} : (tensor, tensor) -> tensor
diff --git a/libspu/compiler/tests/passes/hlo2pphlo/sort_p.mlir b/libspu/compiler/tests/passes/hlo2pphlo/sort_p.mlir
index 2224ee7d..facd338f 100644
--- a/libspu/compiler/tests/passes/hlo2pphlo/sort_p.mlir
+++ b/libspu/compiler/tests/passes/hlo2pphlo/sort_p.mlir
@@ -2,11 +2,11 @@
func.func @main(%arg0: tensor<20xi32>) -> (tensor<20xi32>) {
%0 = stablehlo.iota dim = 0 : tensor<20xi32>
- // CHECK: %1:2 = "pphlo.sort"(%arg0, %0) ({
+ // CHECK: %1:2 = "pphlo.sort"(%arg0, %0) <{dimension = 0 : i64, is_stable = true}> ({
// CHECK: ^bb0(%arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor):
// CHECK: %2 = pphlo.less %arg1, %arg2 : (tensor, tensor) -> tensor
// CHECK: pphlo.return %2 : tensor
- // CHECK: }) {dimension = 0 : i64, is_stable = true} : (tensor<20xi32>, tensor<20xi32>) -> (tensor<20xi32>, tensor<20xi32>)
+ // CHECK: }) : (tensor<20xi32>, tensor<20xi32>) -> (tensor<20xi32>, tensor<20xi32>)
%1:2 = "stablehlo.sort"(%arg0, %0) ({
^bb0(%arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor):
%2 = stablehlo.compare LT, %arg1, %arg2 : (tensor, tensor) -> tensor
diff --git a/libspu/compiler/tests/passes/hlo2pphlo/sort_s.mlir b/libspu/compiler/tests/passes/hlo2pphlo/sort_s.mlir
index c4d4bd41..cf7ef73d 100644
--- a/libspu/compiler/tests/passes/hlo2pphlo/sort_s.mlir
+++ b/libspu/compiler/tests/passes/hlo2pphlo/sort_s.mlir
@@ -2,11 +2,11 @@
func.func @main(%arg0: tensor<20xi32>) -> (tensor<20xi32>) {
%0 = stablehlo.iota dim = 0 : tensor<20xi32>
- // CHECK: %1:2 = "pphlo.sort"(%arg0, %0) ({
+ // CHECK: %1:2 = "pphlo.sort"(%arg0, %0) <{dimension = 0 : i64, is_stable = true}> ({
// CHECK: ^bb0(%arg1: tensor>, %arg2: tensor>, %arg3: tensor, %arg4: tensor):
// CHECK: %2 = pphlo.less %arg1, %arg2 : (tensor>, tensor>) -> tensor>
// CHECK: pphlo.return %2 : tensor>
- // CHECK: }) {dimension = 0 : i64, is_stable = true} : (tensor<20x!pphlo.secret>, tensor<20xi32>) -> (tensor<20x!pphlo.secret>, tensor<20x!pphlo.secret>)
+ // CHECK: }) : (tensor<20x!pphlo.secret>, tensor<20xi32>) -> (tensor<20x!pphlo.secret>, tensor<20x!pphlo.secret>)
%1:2 = "stablehlo.sort"(%arg0, %0) ({
^bb0(%arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor):
%2 = stablehlo.compare LT, %arg1, %arg2 : (tensor, tensor) -> tensor
diff --git a/libspu/compiler/tools/spu-lsp.cc b/libspu/compiler/tools/spu-lsp.cc
index 0b9ddbd6..9f59509d 100644
--- a/libspu/compiler/tools/spu-lsp.cc
+++ b/libspu/compiler/tools/spu-lsp.cc
@@ -25,5 +25,6 @@ int main(int argc, char **argv) {
registry.insert();
mlir::func::registerInlinerExtension(registry);
- return mlir::failed(mlir::MlirLspServerMain(argc, argv, registry));
+ return static_cast(
+ mlir::failed(mlir::MlirLspServerMain(argc, argv, registry)));
}
diff --git a/libspu/compiler/tools/spu-translate.cc b/libspu/compiler/tools/spu-translate.cc
index 423321a5..68857fa0 100644
--- a/libspu/compiler/tools/spu-translate.cc
+++ b/libspu/compiler/tools/spu-translate.cc
@@ -22,12 +22,10 @@
#include "mlir/Support/LogicalResult.h"
#include "mlir/Tools/mlir-translate/MlirTranslateMain.h"
#include "mlir/Tools/mlir-translate/Translation.h"
-#include "mlir/Transforms/Passes.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "xtensor/xio.hpp"
#include "libspu/compiler/common/compilation_context.h"
-#include "libspu/compiler/utils/utils.h"
#include "libspu/core/prelude.h"
#include "libspu/device/pphlo/pphlo_executor.h"
#include "libspu/dialect/pphlo/IR/dialect.h"
@@ -89,7 +87,7 @@ bool testOpHandler(::spu::SPUContext *sctx, mlir::Operation *op,
auto callOp = mlir::dyn_cast(op);
if (callOp.getCallTargetName() == "expect_almost_eq") {
::spu::Value runtimeLhs = inputs[0];
- ::spu::Value runtimeRhs = inputs[1];
+ const ::spu::Value &runtimeRhs = inputs[1];
if (!runtimeLhs.isPublic()) {
runtimeLhs = ::spu::kernel::hal::_s2p(sctx, runtimeLhs)
@@ -123,7 +121,7 @@ bool testOpHandler(::spu::SPUContext *sctx, mlir::Operation *op,
if (callOp.getCallTargetName() == "expect_eq") {
::spu::Value runtimeLhs = inputs[0];
- ::spu::Value runtimeRhs = inputs[1];
+ const ::spu::Value &runtimeRhs = inputs[1];
if (!runtimeLhs.isPublic()) {
runtimeLhs = ::spu::kernel::hal::_s2p(sctx, runtimeLhs)
@@ -239,6 +237,16 @@ void evalModule(ModuleOp module) {
numParties = 3;
break;
}
+ case 4: {
+ conf.set_protocol(::spu::CHEETAH);
+ numParties = 2;
+ break;
+ }
+ case 5: {
+ conf.set_protocol(::spu::SECURENN);
+ numParties = 3;
+ break;
+ }
}
SPDLOG_INFO(conf.DebugString());
@@ -278,6 +286,6 @@ TranslateFromMLIRRegistration interpretRegistration(
} // namespace mlir
int main(int argc, char **argv) {
- return failed(
- mlir::mlirTranslateMain(argc, argv, "SPU interpreter driver\n"));
+ return static_cast(
+ failed(mlir::mlirTranslateMain(argc, argv, "SPU interpreter driver\n")));
}
diff --git a/libspu/core/encoding.cc b/libspu/core/encoding.cc
index eb26ce9b..98a17a1a 100644
--- a/libspu/core/encoding.cc
+++ b/libspu/core/encoding.cc
@@ -60,8 +60,8 @@ NdArrayRef encodeToRing(const PtBufferView& bv, FieldType field,
}
if (pt_type == PT_F32 || pt_type == PT_F64 || pt_type == PT_F16) {
- DISPATCH_FLOAT_PT_TYPES(pt_type, "_", [&]() {
- DISPATCH_ALL_FIELDS(field, "_", [&]() {
+ DISPATCH_FLOAT_PT_TYPES(pt_type, [&]() {
+ DISPATCH_ALL_FIELDS(field, [&]() {
using Float = ScalarT;
using T = std::make_signed_t;
@@ -100,8 +100,8 @@ NdArrayRef encodeToRing(const PtBufferView& bv, FieldType field,
return dst;
} else {
// handle integer & boolean
- DISPATCH_INT_PT_TYPES(pt_type, "_", [&]() {
- DISPATCH_ALL_FIELDS(field, "_", [&]() {
+ DISPATCH_INT_PT_TYPES(pt_type, [&]() {
+ DISPATCH_ALL_FIELDS(field, [&]() {
using Integer = ScalarT;
SPU_ENFORCE(sizeof(ring2k_t) >= sizeof(Integer),
"integer encoding failed, ring={} could not represent {}",
@@ -138,8 +138,8 @@ void decodeFromRing(const NdArrayRef& src, DataType in_dtype, size_t fxp_bits,
*out_pt_type = pt_type;
}
- DISPATCH_ALL_FIELDS(field, "field", [&]() {
- DISPATCH_ALL_PT_TYPES(pt_type, "pt_type", [&]() {
+ DISPATCH_ALL_FIELDS(field, [&]() {
+ DISPATCH_ALL_PT_TYPES(pt_type, [&]() {
using T = std::make_signed_t;
auto _src = NdArrayView(src);
diff --git a/libspu/core/pt_buffer_view.cc b/libspu/core/pt_buffer_view.cc
index 50e6f891..3f7e1084 100644
--- a/libspu/core/pt_buffer_view.cc
+++ b/libspu/core/pt_buffer_view.cc
@@ -50,7 +50,7 @@ NdArrayRef convertToNdArray(PtBufferView bv) {
}
const auto type = makePtType(bv.pt_type);
auto out = NdArrayRef(type, bv.shape);
- return DISPATCH_ALL_PT_TYPES(bv.pt_type, "pt_type", [&]() {
+ return DISPATCH_ALL_PT_TYPES(bv.pt_type, [&]() {
using T = ScalarT;
if (bv.shape.numel() > 0) {
auto* out_ptr = out.data();
diff --git a/libspu/core/type_util.h b/libspu/core/type_util.h
index 2ac1d294..4e70ea7e 100644
--- a/libspu/core/type_util.h
+++ b/libspu/core/type_util.h
@@ -97,91 +97,90 @@ std::ostream& operator<<(std::ostream& os, const DataType& dtype);
// Helper macros to enumerate all py types.
// NOLINTNEXTLINE: Global internal used macro.
-#define __CASE_PT_TYPE(PT_TYPE, NAME, ...) \
- case (PT_TYPE): { \
- [[maybe_unused]] constexpr std::string_view _kName = NAME; \
- using ScalarT = EnumToPtType::type; \
- return __VA_ARGS__(); \
+#define __CASE_PT_TYPE(PT_TYPE, ...) \
+ case (PT_TYPE): { \
+ using ScalarT = EnumToPtType::type; \
+ return __VA_ARGS__(); \
}
-#define DISPATCH_FLOAT_PT_TYPES(PT_TYPE, NAME, ...) \
- [&] { \
- switch (PT_TYPE) { \
- __CASE_PT_TYPE(spu::PT_F16, NAME, __VA_ARGS__) \
- __CASE_PT_TYPE(spu::PT_F32, NAME, __VA_ARGS__) \
- __CASE_PT_TYPE(spu::PT_F64, NAME, __VA_ARGS__) \
- default: \
- SPU_THROW("{} not implemented for pt_type={}", #NAME, PT_TYPE); \
- } \
+#define DISPATCH_FLOAT_PT_TYPES(PT_TYPE, ...) \
+ [&] { \
+ switch (PT_TYPE) { \
+ __CASE_PT_TYPE(spu::PT_F16, __VA_ARGS__) \
+ __CASE_PT_TYPE(spu::PT_F32, __VA_ARGS__) \
+ __CASE_PT_TYPE(spu::PT_F64, __VA_ARGS__) \
+ default: \
+ SPU_THROW("unimplemented for pt_type={}", PT_TYPE); \
+ } \
}()
-#define DISPATCH_UINT_PT_TYPES(PT_TYPE, NAME, ...) \
- [&] { \
- switch (PT_TYPE) { \
- __CASE_PT_TYPE(spu::PT_U8, NAME, __VA_ARGS__) \
- __CASE_PT_TYPE(spu::PT_U16, NAME, __VA_ARGS__) \
- __CASE_PT_TYPE(spu::PT_U32, NAME, __VA_ARGS__) \
- __CASE_PT_TYPE(spu::PT_U64, NAME, __VA_ARGS__) \
- __CASE_PT_TYPE(spu::PT_U128, NAME, __VA_ARGS__) \
- default: \
- SPU_THROW("{} not implemented for pt_type={}", #NAME, PT_TYPE); \
- } \
+#define DISPATCH_UINT_PT_TYPES(PT_TYPE, ...) \
+ [&] { \
+ switch (PT_TYPE) { \
+ __CASE_PT_TYPE(spu::PT_U8, __VA_ARGS__) \
+ __CASE_PT_TYPE(spu::PT_U16, __VA_ARGS__) \
+ __CASE_PT_TYPE(spu::PT_U32, __VA_ARGS__) \
+ __CASE_PT_TYPE(spu::PT_U64, __VA_ARGS__) \
+ __CASE_PT_TYPE(spu::PT_U128, __VA_ARGS__) \
+ default: \
+ SPU_THROW("unimplemented for pt_type={}", PT_TYPE); \
+ } \
}()
-#define DISPATCH_INT_PT_TYPES(PT_TYPE, NAME, ...) \
- [&] { \
- switch (PT_TYPE) { \
- __CASE_PT_TYPE(spu::PT_I1, NAME, __VA_ARGS__) \
- __CASE_PT_TYPE(spu::PT_I8, NAME, __VA_ARGS__) \
- __CASE_PT_TYPE(spu::PT_U8, NAME, __VA_ARGS__) \
- __CASE_PT_TYPE(spu::PT_I16, NAME, __VA_ARGS__) \
- __CASE_PT_TYPE(spu::PT_U16, NAME, __VA_ARGS__) \
- __CASE_PT_TYPE(spu::PT_I32, NAME, __VA_ARGS__) \
- __CASE_PT_TYPE(spu::PT_U32, NAME, __VA_ARGS__) \
- __CASE_PT_TYPE(spu::PT_I64, NAME, __VA_ARGS__) \
- __CASE_PT_TYPE(spu::PT_U64, NAME, __VA_ARGS__) \
- default: \
- SPU_THROW("{} not implemented for pt_type={}", #NAME, PT_TYPE); \
- } \
+#define DISPATCH_INT_PT_TYPES(PT_TYPE, ...) \
+ [&] { \
+ switch (PT_TYPE) { \
+ __CASE_PT_TYPE(spu::PT_I1, __VA_ARGS__) \
+ __CASE_PT_TYPE(spu::PT_I8, __VA_ARGS__) \
+ __CASE_PT_TYPE(spu::PT_U8, __VA_ARGS__) \
+ __CASE_PT_TYPE(spu::PT_I16, __VA_ARGS__) \
+ __CASE_PT_TYPE(spu::PT_U16, __VA_ARGS__) \
+ __CASE_PT_TYPE(spu::PT_I32, __VA_ARGS__) \
+ __CASE_PT_TYPE(spu::PT_U32, __VA_ARGS__) \
+ __CASE_PT_TYPE(spu::PT_I64, __VA_ARGS__) \
+ __CASE_PT_TYPE(spu::PT_U64, __VA_ARGS__) \
+ default: \
+ SPU_THROW("unimplemented for pt_type={}", PT_TYPE); \
+ } \
}()
-#define DISPATCH_ALL_PT_TYPES(PT_TYPE, NAME, ...) \
- [&] { \
- switch (PT_TYPE) { \
- __CASE_PT_TYPE(spu::PT_I1, NAME, __VA_ARGS__) \
- __CASE_PT_TYPE(spu::PT_I8, NAME, __VA_ARGS__) \
- __CASE_PT_TYPE(spu::PT_U8, NAME, __VA_ARGS__) \
- __CASE_PT_TYPE(spu::PT_I16, NAME, __VA_ARGS__) \
- __CASE_PT_TYPE(spu::PT_U16, NAME, __VA_ARGS__) \
- __CASE_PT_TYPE(spu::PT_I32, NAME, __VA_ARGS__) \
- __CASE_PT_TYPE(spu::PT_U32, NAME, __VA_ARGS__) \
- __CASE_PT_TYPE(spu::PT_I64, NAME, __VA_ARGS__) \
- __CASE_PT_TYPE(spu::PT_U64, NAME, __VA_ARGS__) \
- __CASE_PT_TYPE(spu::PT_F16, NAME, __VA_ARGS__) \
- __CASE_PT_TYPE(spu::PT_F32, NAME, __VA_ARGS__) \
- __CASE_PT_TYPE(spu::PT_F64, NAME, __VA_ARGS__) \
- default: \
- SPU_THROW("{} not implemented for pt_type={}", #NAME, PT_TYPE); \
- } \
+#define DISPATCH_ALL_PT_TYPES(PT_TYPE, ...) \
+ [&] { \
+ switch (PT_TYPE) { \
+ __CASE_PT_TYPE(spu::PT_I1, __VA_ARGS__) \
+ __CASE_PT_TYPE(spu::PT_I8, __VA_ARGS__) \
+ __CASE_PT_TYPE(spu::PT_U8, __VA_ARGS__) \
+ __CASE_PT_TYPE(spu::PT_I16, __VA_ARGS__) \
+ __CASE_PT_TYPE(spu::PT_U16, __VA_ARGS__) \
+ __CASE_PT_TYPE(spu::PT_I32, __VA_ARGS__) \
+ __CASE_PT_TYPE(spu::PT_U32, __VA_ARGS__) \
+ __CASE_PT_TYPE(spu::PT_I64, __VA_ARGS__) \
+ __CASE_PT_TYPE(spu::PT_U64, __VA_ARGS__) \
+ __CASE_PT_TYPE(spu::PT_F16, __VA_ARGS__) \
+ __CASE_PT_TYPE(spu::PT_F32, __VA_ARGS__) \
+ __CASE_PT_TYPE(spu::PT_F64, __VA_ARGS__) \
+ default: \
+ SPU_THROW("unimplemented for pt_type={}", PT_TYPE); \
+ } \
}()
-#define DISPATCH_ALL_NONE_BOOL_PT_TYPES(PT_TYPE, NAME, ...) \
- [&] { \
- switch (PT_TYPE) { \
- __CASE_PT_TYPE(spu::PT_I8, NAME, __VA_ARGS__) \
- __CASE_PT_TYPE(spu::PT_U8, NAME, __VA_ARGS__) \
- __CASE_PT_TYPE(spu::PT_I16, NAME, __VA_ARGS__) \
- __CASE_PT_TYPE(spu::PT_U16, NAME, __VA_ARGS__) \
- __CASE_PT_TYPE(spu::PT_I32, NAME, __VA_ARGS__) \
- __CASE_PT_TYPE(spu::PT_U32, NAME, __VA_ARGS__) \
- __CASE_PT_TYPE(spu::PT_I64, NAME, __VA_ARGS__) \
- __CASE_PT_TYPE(spu::PT_U64, NAME, __VA_ARGS__) \
- __CASE_PT_TYPE(spu::PT_F16, NAME, __VA_ARGS__) \
- __CASE_PT_TYPE(spu::PT_F32, NAME, __VA_ARGS__) \
- __CASE_PT_TYPE(spu::PT_F64, NAME, __VA_ARGS__) \
- default: \
- SPU_THROW("{} not implemented for pt_type={}", #NAME, PT_TYPE); \
- } \
+#define DISPATCH_ALL_NONE_BOOL_PT_TYPES(PT_TYPE, ...) \
+ [&] { \
+ switch (PT_TYPE) { \
+ __CASE_PT_TYPE(spu::PT_I8, __VA_ARGS__) \
+ __CASE_PT_TYPE(spu::PT_U8, __VA_ARGS__) \
+ __CASE_PT_TYPE(spu::PT_I16, __VA_ARGS__) \
+ __CASE_PT_TYPE(spu::PT_U16, __VA_ARGS__) \
+ __CASE_PT_TYPE(spu::PT_I32, __VA_ARGS__) \
+ __CASE_PT_TYPE(spu::PT_U32, __VA_ARGS__) \
+ __CASE_PT_TYPE(spu::PT_I64, __VA_ARGS__) \
+ __CASE_PT_TYPE(spu::PT_U64, __VA_ARGS__) \
+ __CASE_PT_TYPE(spu::PT_F16, __VA_ARGS__) \
+ __CASE_PT_TYPE(spu::PT_F32, __VA_ARGS__) \
+ __CASE_PT_TYPE(spu::PT_F64, __VA_ARGS__) \
+ default: \
+ SPU_THROW("unimplemented for pt_type={}", PT_TYPE); \
+ } \
}()
std::ostream& operator<<(std::ostream& os, const PtType& pt_type);
@@ -241,24 +240,23 @@ inline size_t SizeOf(FieldType field) { return SizeOf(GetStorageType(field)); }
// Helper macros to enumerate all fields
// NOLINTNEXTLINE: Global internal used macro.
-#define __CASE_FIELD(FIELD, NAME, ...) \
+#define __CASE_FIELD(FIELD, ...) \
case (FIELD): { \
/* inject `_kField` & `_kName` for the continuation call */ \
[[maybe_unused]] constexpr spu::FieldType _kField = FIELD; \
- [[maybe_unused]] constexpr std::string_view _kName = NAME; \
using ring2k_t [[maybe_unused]] = Ring2kTrait<_kField>::scalar_t; \
return __VA_ARGS__(); \
}
-#define DISPATCH_ALL_FIELDS(FIELD, NAME, ...) \
- [&] { \
- switch (FIELD) { \
- __CASE_FIELD(spu::FieldType::FM32, NAME, __VA_ARGS__) \
- __CASE_FIELD(spu::FieldType::FM64, NAME, __VA_ARGS__) \
- __CASE_FIELD(spu::FieldType::FM128, NAME, __VA_ARGS__) \
- default: \
- SPU_THROW("{} not implemented for field={}", #NAME, FIELD); \
- } \
+#define DISPATCH_ALL_FIELDS(FIELD, ...) \
+ [&] { \
+ switch (FIELD) { \
+ __CASE_FIELD(spu::FieldType::FM32, __VA_ARGS__) \
+ __CASE_FIELD(spu::FieldType::FM64, __VA_ARGS__) \
+ __CASE_FIELD(spu::FieldType::FM128, __VA_ARGS__) \
+ default: \
+ SPU_THROW("unimplemented for field={}", FIELD); \
+ } \
}()
//////////////////////////////////////////////////////////////
diff --git a/libspu/dialect/pphlo/IR/dialect.td b/libspu/dialect/pphlo/IR/dialect.td
index de2c648c..7d47197c 100644
--- a/libspu/dialect/pphlo/IR/dialect.td
+++ b/libspu/dialect/pphlo/IR/dialect.td
@@ -32,15 +32,14 @@ def PPHlo_Dialect : Dialect {
string summary = "Privacy-Preserving HLO(PPHLO) dialect";
string description = [{
PPHLO represents a high level abstraction for language use by SPU.
- It implements a subset of mlir mhlo ops with it's own privacy-preserving focused type system.
+ It implements a subset of mlir stablehlo ops with it's own privacy-preserving focused type system.
- Learn more about mlir hlo at https://github.com/tensorflow/mlir-hlo
+ Learn more about mlir stablehlo at https://github.com/openxla/stablehlo
}];
let name = "pphlo";
let cppNamespace = "::mlir::spu::pphlo";
let useDefaultAttributePrinterParser = 0;
let useDefaultTypePrinterParser = 0;
- let usePropertiesForAttributes = 0;
let hasConstantMaterializer = 1;
let extraClassDeclaration = [{
Attribute parseAttribute(DialectAsmParser & parser, Type type)
diff --git a/libspu/dialect/pphlo/IR/type_inference.cc b/libspu/dialect/pphlo/IR/type_inference.cc
index 4e3a9464..6c95aac3 100644
--- a/libspu/dialect/pphlo/IR/type_inference.cc
+++ b/libspu/dialect/pphlo/IR/type_inference.cc
@@ -285,7 +285,7 @@ LogicalResult PadOp::inferReturnTypes(
::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
::mlir::OpaqueProperties properties, ::mlir::RegionRange regions,
::llvm::SmallVectorImpl<::mlir::Type>& inferredReturnTypes) {
- PadOp::Adaptor adaptor(operands, attributes, {}, regions);
+ PadOp::Adaptor adaptor(operands, attributes, properties, regions);
return hlo::inferPadOp(location, adaptor.getOperand().getType(),
adaptor.getPaddingValue().getType(),
adaptor.getEdgePaddingLow(),
@@ -295,27 +295,27 @@ LogicalResult PadOp::inferReturnTypes(
LogicalResult ConcatenateOp::inferReturnTypes(
MLIRContext*, std::optional location, ValueRange operands,
- DictionaryAttr attributes, OpaqueProperties, RegionRange regions,
+ DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl& inferred_return_types) {
- ConcatenateOp::Adaptor adaptor(operands, attributes, {}, regions);
+ ConcatenateOp::Adaptor adaptor(operands, attributes, properties, regions);
return hlo::inferConcatenateOp(location, adaptor.getInputs().getTypes(),
adaptor.getDimension(), inferred_return_types);
}
LogicalResult TransposeOp::inferReturnTypes(
MLIRContext*, std::optional location, ValueRange operands,
- DictionaryAttr attributes, OpaqueProperties, RegionRange regions,
+ DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl& inferred_return_types) {
- TransposeOp::Adaptor adaptor(operands, attributes, {}, regions);
+ TransposeOp::Adaptor adaptor(operands, attributes, properties, regions);
return hlo::inferTransposeOp(location, adaptor.getOperand(),
adaptor.getPermutation(), inferred_return_types);
}
LogicalResult SliceOp::inferReturnTypes(
MLIRContext*, std::optional location, ValueRange operands,
- DictionaryAttr attributes, OpaqueProperties, RegionRange regions,
+ DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl& inferred_return_types) {
- SliceOp::Adaptor adaptor(operands, attributes, {}, regions);
+ SliceOp::Adaptor adaptor(operands, attributes, properties, regions);
return hlo::inferSliceOp(location, adaptor.getOperand().getType(),
adaptor.getStartIndices(), adaptor.getLimitIndices(),
adaptor.getStrides(), inferred_return_types);
@@ -375,9 +375,9 @@ LogicalResult inferDynamicSliceOp(std::optional location,
LogicalResult DynamicSliceOp::inferReturnTypes(
MLIRContext*, std::optional location, ValueRange operands,
- DictionaryAttr attributes, OpaqueProperties, RegionRange regions,
+ DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl& inferredReturnTypes) {
- DynamicSliceOp::Adaptor adaptor(operands, attributes, {}, regions);
+ DynamicSliceOp::Adaptor adaptor(operands, attributes, properties, regions);
return inferDynamicSliceOp(location, adaptor.getOperand().getType(),
adaptor.getStartIndices().getTypes(),
adaptor.getSliceSizes(), inferredReturnTypes);
@@ -427,9 +427,10 @@ LogicalResult inferDynamicUpdateSliceOp(
LogicalResult DynamicUpdateSliceOp::inferReturnTypes(
MLIRContext*, std::optional location, ValueRange operands,
- DictionaryAttr attributes, OpaqueProperties, RegionRange regions,
+ DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl& inferredReturnTypes) {
- DynamicUpdateSliceOp::Adaptor adaptor(operands, attributes, {}, regions);
+ DynamicUpdateSliceOp::Adaptor adaptor(operands, attributes, properties,
+ regions);
return inferDynamicUpdateSliceOp(
location, adaptor.getOperand(), adaptor.getUpdate(),
diff --git a/libspu/dialect/pphlo/transforms/hlo_legalize_to_pphlo.cc b/libspu/dialect/pphlo/transforms/hlo_legalize_to_pphlo.cc
index e17f85e5..781e0940 100644
--- a/libspu/dialect/pphlo/transforms/hlo_legalize_to_pphlo.cc
+++ b/libspu/dialect/pphlo/transforms/hlo_legalize_to_pphlo.cc
@@ -133,24 +133,20 @@ class FuncOpConverter : public OpConversionPattern<::mlir::func::FuncOp> {
auto ®ion = op.getBody();
// Convert non-entry blocks
- SmallVector conversions;
- for (Block &block : llvm::drop_begin(region, 1)) {
- conversions.emplace_back(block.getNumArguments());
- TypeConverter::SignatureConversion &back = conversions.back();
+ for (Block &block :
+ llvm::make_early_inc_range(llvm::drop_begin(region, 1))) {
+ TypeConverter::SignatureConversion conversion(
+ /*numOrigInputs=*/block.getNumArguments());
for (BlockArgument blockArgument : block.getArguments()) {
auto idx = blockArgument.getArgNumber();
auto vis_v = vis_.getValueVisibility(blockArgument);
auto convertedType = tools_.getType(
typeConverter->convertType(blockArgument.getType()), vis_v);
- back.addInputs(idx, convertedType);
+ conversion.addInputs(idx, convertedType);
}
- }
- if (failed(rewriter.convertNonEntryRegionTypes(®ion, *typeConverter,
- conversions))) {
- rewriter.cancelOpModification(op);
- return failure();
+ rewriter.applySignatureConversion(&block, conversion, getTypeConverter());
}
// Convert function arguments using the provided TypeConverter.
diff --git a/libspu/kernel/hal/constants.cc b/libspu/kernel/hal/constants.cc
index 1b44cbd0..8659c5e3 100644
--- a/libspu/kernel/hal/constants.cc
+++ b/libspu/kernel/hal/constants.cc
@@ -103,7 +103,7 @@ spu::Value zeros(SPUContext* ctx, DataType dtype, const Shape& shape) {
}
Value iota(SPUContext* ctx, DataType dtype, int64_t numel) {
- return DISPATCH_ALL_NONE_BOOL_PT_TYPES(getDecodeType(dtype), "iota", [&]() {
+ return DISPATCH_ALL_NONE_BOOL_PT_TYPES(getDecodeType(dtype), [&]() {
std::vector arr(numel);
std::iota(arr.begin(), arr.end(), 0);
return constant(ctx, arr, dtype, {numel});
diff --git a/libspu/kernel/hal/fxp_approx.cc b/libspu/kernel/hal/fxp_approx.cc
index 5cda83df..0f9864fd 100644
--- a/libspu/kernel/hal/fxp_approx.cc
+++ b/libspu/kernel/hal/fxp_approx.cc
@@ -69,12 +69,12 @@ Value log_minmax(SPUContext* ctx, const Value& x) {
// get most significant non-zero bit of x
// we avoid direct using detail::highestOneBit for saving one _prefix_or
- auto pre_x1 = _rshift(ctx, pre_x, 1);
+ auto pre_x1 = _rshift(ctx, pre_x, {1});
auto msb = _xor(ctx, pre_x, pre_x1);
// let x = x_norm * factor, where x in [1.0, 2.0)
auto factor = _bitrev(ctx, msb, 0, 2 * num_fxp_bits + 1).setDtype(x.dtype());
- detail::hintNumberOfBits(factor, 2 * num_fxp_bits + 1);
+ factor = maskNumberOfBits(ctx, factor, 2 * num_fxp_bits + 1);
auto norm = f_mul(ctx, x, factor);
// log(x) = log(x_norm * factor)
@@ -83,7 +83,7 @@ Value log_minmax(SPUContext* ctx, const Value& x) {
auto log_norm = log_minmax_normalized(ctx, norm);
auto log2_e =
_lshift(ctx, _sub(ctx, k, _constant(ctx, num_fxp_bits + 1, x.shape())),
- num_fxp_bits)
+ {static_cast(num_fxp_bits)})
.setDtype(x.dtype());
auto k_log2 = constant(ctx, std::log(2), x.dtype(), x.shape());
auto log_e = f_mul(ctx, log2_e, k_log2);
@@ -145,7 +145,7 @@ Value log2_pade(SPUContext* ctx, const Value& x) {
// let x = x_norm * factor, where x in [0.5, 1.0)
auto msb = detail::highestOneBit(ctx, x);
auto factor = _bitrev(ctx, msb, 0, 2 * num_fxp_bits).setDtype(x.dtype());
- detail::hintNumberOfBits(factor, 2 * num_fxp_bits);
+ factor = maskNumberOfBits(ctx, factor, 2 * num_fxp_bits);
auto norm = f_mul(ctx, x, factor);
// log2(x) = log2(x_norm * factor)
@@ -154,7 +154,7 @@ Value log2_pade(SPUContext* ctx, const Value& x) {
return _add(
ctx, log2_pade_normalized(ctx, norm),
_lshift(ctx, _sub(ctx, k, _constant(ctx, num_fxp_bits, x.shape())),
- num_fxp_bits))
+ {static_cast(num_fxp_bits)}))
.setDtype(x.dtype());
}
@@ -260,15 +260,18 @@ Value exp2_pade(SPUContext* ctx, const Value& x) {
const size_t bit_width = SizeOf(ctx->getField()) * 8;
const auto x_bshare = _prefer_b(ctx, x);
- const auto x_msb = _rshift(ctx, x_bshare, bit_width - 1);
- auto x_integer = _rshift(ctx, x_bshare, fbits);
+ const auto x_msb =
+ _rshift(ctx, x_bshare, {static_cast(bit_width - 1)});
+ auto x_integer = _rshift(ctx, x_bshare, {static_cast(fbits)});
auto x_fraction =
- _sub(ctx, x, _lshift(ctx, x_integer, fbits)).setDtype(x.dtype());
+ _sub(ctx, x, _lshift(ctx, x_integer, {static_cast(fbits)}))
+ .setDtype(x.dtype());
auto ret = exp2_pade_normalized(ctx, x_fraction);
for (size_t idx = 0; idx < int_bits; idx++) {
- auto a = _and(ctx, _rshift(ctx, x_integer, idx), k1);
- detail::hintNumberOfBits(a, 1);
+ auto a =
+ _and(ctx, _rshift(ctx, x_integer, {static_cast(idx)}), k1);
+ a = detail::maskNumberOfBits(ctx, a, 1);
a = _prefer_a(ctx, a);
const auto K = 1U << std::min(1UL << idx, bit_width - 2);
ret = _mul(ctx, ret,
@@ -543,7 +546,7 @@ static Value rsqrt_init_guess(SPUContext* ctx, const Value& x, const Value& z) {
// let u in [0.25, 0.5)
auto z_rev = _bitrev(ctx, z, 0, 2 * f);
- detail::hintNumberOfBits(z_rev, 2 * f);
+ z_rev = detail::maskNumberOfBits(ctx, z_rev, 2 * f);
auto u = _trunc(ctx, _mul(ctx, x, z_rev)).setDtype(x.dtype());
@@ -583,17 +586,17 @@ static Value rsqrt_comp(SPUContext* ctx, const Value& x, const Value& z) {
auto lo_mask =
_constant(ctx, (static_cast(1) << (k / 2)) - 1, x.shape());
auto z_even = _and(ctx, z_sep, lo_mask);
- auto z_odd = _and(ctx, _rshift(ctx, z_sep, k / 2), lo_mask);
+ auto z_odd =
+ _and(ctx, _rshift(ctx, z_sep, {static_cast(k / 2)}), lo_mask);
// a[i] = z[2*i] ^ z[2*i+1]
a = _xor(ctx, z_odd, z_even);
// b ^= z[2*i]
b = _bit_parity(ctx, z_even, k / 2);
- detail::hintNumberOfBits(b, 1);
}
auto a_rev = _bitrev(ctx, a, 0, (f / 2) * 2);
- detail::hintNumberOfBits(a_rev, (f / 2) * 2);
+ a_rev = detail::maskNumberOfBits(ctx, a_rev, (f / 2) * 2);
// do compensation
// Note:
@@ -623,7 +626,7 @@ static Value rsqrt_np2(SPUContext* ctx, const Value& x) {
SPU_TRACE_HAL_LEAF(ctx, x);
// let e = NP2(x), z = 2^(e+f)
- return _lshift(ctx, detail::highestOneBit(ctx, x), 1);
+ return _lshift(ctx, detail::highestOneBit(ctx, x), {1});
}
// Reference:
diff --git a/libspu/kernel/hal/fxp_base.cc b/libspu/kernel/hal/fxp_base.cc
index 209782e1..1594a9ac 100644
--- a/libspu/kernel/hal/fxp_base.cc
+++ b/libspu/kernel/hal/fxp_base.cc
@@ -72,7 +72,7 @@ Value polynomial(SPUContext* ctx, const Value& x,
Value highestOneBit(SPUContext* ctx, const Value& x) {
auto y = _prefix_or(ctx, x);
- auto y1 = _rshift(ctx, y, 1);
+ auto y1 = _rshift(ctx, y, {1});
return _xor(ctx, y, y1);
}
@@ -85,6 +85,13 @@ void hintNumberOfBits(const Value& a, size_t nbits) {
}
}
+Value maskNumberOfBits(SPUContext* ctx, const Value& in, size_t nbits) {
+ auto k1 = constant(ctx, 1UL, spu::DT_I64, in.shape());
+ auto mask = _sub(ctx, _lshift(ctx, k1, {static_cast(nbits)}), k1);
+ auto out = _and(ctx, in, mask).setDtype(in.dtype());
+ return out;
+}
+
namespace {
Value reciprocal_goldschmidt_normalized_approx(SPUContext* ctx,
@@ -178,7 +185,8 @@ Value div_goldschmidt_general(SPUContext* ctx, const Value& a, const Value& b,
// factor = 2^{f-m} = 2^{-m} * 2^f, the fixed point repr of 2^{-m}
const size_t num_fxp_bits = ctx->getFxpBits();
auto factor = _bitrev(ctx, b_msb, 0, 2 * num_fxp_bits).setDtype(b.dtype());
- detail::hintNumberOfBits(factor, 2 * num_fxp_bits);
+
+ factor = maskNumberOfBits(ctx, factor, 2 * num_fxp_bits);
// also, we use factor twice
factor = _prefer_a(ctx, factor);
@@ -209,7 +217,7 @@ Value reciprocal_goldschmidt_positive(SPUContext* ctx, const Value& b_abs) {
const size_t num_fxp_bits = ctx->getFxpBits();
auto factor =
_bitrev(ctx, b_msb, 0, 2 * num_fxp_bits).setDtype(b_abs.dtype());
- detail::hintNumberOfBits(factor, 2 * num_fxp_bits);
+ factor = maskNumberOfBits(ctx, factor, 2 * num_fxp_bits);
// also, we use factor twice
factor = _prefer_a(ctx, factor);
@@ -237,13 +245,12 @@ Value reciprocal_goldschmidt(SPUContext* ctx, const Value& b) {
// factor = 2^{f-m} = 2^{-m} * 2^f, the fixed point repr of 2^{-m}
const size_t num_fxp_bits = ctx->getFxpBits();
auto factor = _bitrev(ctx, b_msb, 0, 2 * num_fxp_bits).setDtype(b.dtype());
- detail::hintNumberOfBits(factor, 2 * num_fxp_bits);
+ factor = maskNumberOfBits(ctx, factor, 2 * num_fxp_bits);
// also, we use factor twice
factor = _prefer_a(ctx, factor);
// compute approximation of normalize b_abs
auto r = reciprocal_goldschmidt_normalized_approx(ctx, b_abs, factor);
-
r = f_mul(ctx, r, factor, SignType::Positive);
return _mux(ctx, is_negative, _negate(ctx, r), r).setDtype(b.dtype());
@@ -370,8 +377,8 @@ Value f_floor(SPUContext* ctx, const Value& x) {
SPU_ENFORCE(x.isFxp());
- const size_t fbits = ctx->getFxpBits();
- return _lshift(ctx, _arshift(ctx, x, fbits), fbits).setDtype(x.dtype());
+ const int64_t fbits = ctx->getFxpBits();
+ return _lshift(ctx, _arshift(ctx, x, {fbits}), {fbits}).setDtype(x.dtype());
}
Value f_ceil(SPUContext* ctx, const Value& x) {
diff --git a/libspu/kernel/hal/fxp_base.h b/libspu/kernel/hal/fxp_base.h
index c72022b6..61fe7bca 100644
--- a/libspu/kernel/hal/fxp_base.h
+++ b/libspu/kernel/hal/fxp_base.h
@@ -30,6 +30,8 @@ Value highestOneBit(SPUContext* ctx, const Value& x);
void hintNumberOfBits(const Value& a, size_t nbits);
+Value maskNumberOfBits(SPUContext* ctx, const Value& a, size_t nbits);
+
// we provide this general function to support some special cases (a or b has
// guarranteed sign) in fxp_approx for better both performance and accuracy.
Value div_goldschmidt_general(SPUContext* ctx, const Value& a, const Value& b,
diff --git a/libspu/kernel/hal/fxp_cleartext.cc b/libspu/kernel/hal/fxp_cleartext.cc
index 818e0b23..b5061d5b 100644
--- a/libspu/kernel/hal/fxp_cleartext.cc
+++ b/libspu/kernel/hal/fxp_cleartext.cc
@@ -57,7 +57,7 @@ Value applyFloatingPointFn(SPUContext* ctx, const Value& in, FN&& fn) {
auto pt_type = getDecodeType(in.dtype());
for (auto iter = fp_arr.begin(); iter != fp_arr.end(); ++iter) {
- DISPATCH_FLOAT_PT_TYPES(pt_type, "pt_type", [&]() {
+ DISPATCH_FLOAT_PT_TYPES(pt_type, [&]() {
auto* ptr = reinterpret_cast(&*iter);
*ptr = fn(*ptr);
});
@@ -92,9 +92,9 @@ Value applyFloatingPointFn(SPUContext* ctx, const Value& x, const Value& y,
for (auto itr_x = flp_x.begin(), itr_y = flp_y.begin(); itr_x != flp_x.end();
itr_x++, itr_y++) {
- DISPATCH_FLOAT_PT_TYPES(x_pt_type, "x_pt_type", [&]() {
+ DISPATCH_FLOAT_PT_TYPES(x_pt_type, [&]() {
auto* ptr_x = reinterpret_cast(&*itr_x);
- DISPATCH_FLOAT_PT_TYPES(y_pt_type, "y_pt_type", [&]() {
+ DISPATCH_FLOAT_PT_TYPES(y_pt_type, [&]() {
auto* ptr_y = reinterpret_cast(&*itr_y);
*ptr_x = fn(*ptr_x, *ptr_y);
});
diff --git a/libspu/kernel/hal/permute.cc b/libspu/kernel/hal/permute.cc
index 1c67c997..7ae47cda 100644
--- a/libspu/kernel/hal/permute.cc
+++ b/libspu/kernel/hal/permute.cc
@@ -553,7 +553,8 @@ std::vector _bit_decompose(SPUContext *ctx, const spu::Value &x,
rets_b.reserve(nbits);
for (size_t bit = 0; bit < nbits; ++bit) {
- auto x_bshare_shift = right_shift_logical(ctx, x_bshare, bit);
+ auto x_bshare_shift =
+ right_shift_logical(ctx, x_bshare, {static_cast(bit)});
rets_b.push_back(_and(ctx, x_bshare_shift, k1));
}
@@ -703,10 +704,10 @@ spu::Value _apply_perm_ss(SPUContext *ctx, const Value &x, const Value &perm) {
// Find mergeable keys from keys. Consecutive public/private(belong to one
// owner) keys can be merged. Assume there are six keys, i.e., public_key0,
-// bob_key0, bob_key1, alice_key0, alice_key1, secret_key0. We can merge the six
-// keys into bob_new_key, alice_new_key, secret_key0 for the following sorting.
-// This function will return a vector of indices [3,5,6] which means key[0,3),
-// key[3,5), and key[5,6) can be merged.
+// bob_key0, bob_key1, alice_key0, alice_key1, secret_key0. We can merge the
+// six keys into bob_new_key, alice_new_key, secret_key0 for the following
+// sorting. This function will return a vector of indices [3,5,6] which means
+// key[0,3), key[3,5), and key[5,6) can be merged.
std::vector _find_mergeable_keys(SPUContext *ctx,
absl::Span keys) {
std::vector split_indices;
diff --git a/libspu/kernel/hal/polymorphic.cc b/libspu/kernel/hal/polymorphic.cc
index 34cf5f77..21680cc6 100644
--- a/libspu/kernel/hal/polymorphic.cc
+++ b/libspu/kernel/hal/polymorphic.cc
@@ -356,7 +356,7 @@ Value power(SPUContext* ctx, const Value& x, const Value& y) {
const auto bit_width = SizeOf(ctx->getField()) * 8;
auto y_b = _prefer_b(ctx, y);
- auto msb_y = _rshift(ctx, y_b, bit_width - 1);
+ auto msb_y = _rshift(ctx, y_b, {static_cast(bit_width - 1)});
auto x_abs1 = _equal(ctx, abs(ctx, x), k1);
auto ret = _constant(ctx, 1, x.shape());
@@ -379,7 +379,9 @@ Value power(SPUContext* ctx, const Value& x, const Value& y) {
// e.g. y=0101, then ret = (x) * (1) * (x^(2^2)) * (1) = x^5
for (size_t idx = 0; idx < y_bits; idx++) {
// x^(2^idx) * y_{idx}
- auto cur_pow = _mux(ctx, _and(ctx, _rshift(ctx, y_b, idx), k1), base, k1);
+ auto cur_pow = _mux(
+ ctx, _and(ctx, _rshift(ctx, y_b, {static_cast(idx)}), k1),
+ base, k1);
ret = _mul(ctx, cur_pow, ret);
if (idx < y_bits - 1) {
base = _mul(ctx, base, base);
@@ -409,8 +411,9 @@ Value power(SPUContext* ctx, const Value& x, const Value& y) {
// the final sign is decided on both sign of x and the parity of y
// when x<0 and y is odd, e.g. (-2)^3 = -8
- auto odd = _and(ctx, _rshift(ctx, y, ctx->getFxpBits()),
- _constant(ctx, 1, y.shape()));
+ auto odd =
+ _and(ctx, _rshift(ctx, y, {static_cast(ctx->getFxpBits())}),
+ _constant(ctx, 1, y.shape()));
auto sign = _and(ctx, msb, odd);
return _mux(ctx, sign, _negate(ctx, val), val).setDtype(x.dtype());
@@ -488,19 +491,20 @@ Value bitcast(SPUContext* ctx, const Value& x, DataType dtype) {
return Value(x.data().clone(), dtype);
}
-Value left_shift(SPUContext* ctx, const Value& x, size_t bits) {
+Value left_shift(SPUContext* ctx, const Value& x, const Sizes& bits) {
SPU_TRACE_HAL_DISP(ctx, x, bits);
return _lshift(ctx, x, bits).setDtype(x.dtype());
}
-Value right_shift_logical(SPUContext* ctx, const Value& x, size_t bits) {
+Value right_shift_logical(SPUContext* ctx, const Value& x, const Sizes& bits) {
SPU_TRACE_HAL_DISP(ctx, x, bits);
return _rshift(ctx, x, bits).setDtype(x.dtype());
}
-Value right_shift_arithmetic(SPUContext* ctx, const Value& x, size_t bits) {
+Value right_shift_arithmetic(SPUContext* ctx, const Value& x,
+ const Sizes& bits) {
SPU_TRACE_HAL_DISP(ctx, x, bits);
return _arshift(ctx, x, bits).setDtype(x.dtype());
diff --git a/libspu/kernel/hal/polymorphic.h b/libspu/kernel/hal/polymorphic.h
index 58e90f00..9b47b85d 100644
--- a/libspu/kernel/hal/polymorphic.h
+++ b/libspu/kernel/hal/polymorphic.h
@@ -187,11 +187,12 @@ Value clamp(SPUContext* ctx, const Value& x, const Value& min,
// @param dtype, second input value
Value bitcast(SPUContext* ctx, const Value& x, DataType dtype);
-Value left_shift(SPUContext* ctx, const Value& x, size_t bits);
+Value left_shift(SPUContext* ctx, const Value& x, const Sizes& bits);
-Value right_shift_logical(SPUContext* ctx, const Value& x, size_t bits);
+Value right_shift_logical(SPUContext* ctx, const Value& x, const Sizes& bits);
-Value right_shift_arithmetic(SPUContext* ctx, const Value& x, size_t bits);
+Value right_shift_arithmetic(SPUContext* ctx, const Value& x,
+ const Sizes& bits);
/// the element-wise base-2 logarithm of x
// @param in, should be positive, or the result is implementation defined.
diff --git a/libspu/kernel/hal/prot_wrapper.cc b/libspu/kernel/hal/prot_wrapper.cc
index 10743c10..23f46388 100644
--- a/libspu/kernel/hal/prot_wrapper.cc
+++ b/libspu/kernel/hal/prot_wrapper.cc
@@ -30,11 +30,11 @@ namespace spu::kernel::hal {
return mpc::NAME(ctx, in); \
}
-#define MAP_SHIFT_OP(NAME) \
- Value _##NAME(SPUContext* ctx, const Value& in, size_t bits) { \
- SPU_TRACE_HAL_DISP(ctx, in, bits); \
- auto ret = mpc::NAME(ctx, in, bits); \
- return ret; \
+#define MAP_SHIFT_OP(NAME) \
+ Value _##NAME(SPUContext* ctx, const Value& in, const Sizes& bits) { \
+ SPU_TRACE_HAL_DISP(ctx, in, bits); \
+ auto ret = mpc::NAME(ctx, in, bits); \
+ return ret; \
}
#define MAP_BITREV_OP(NAME) \
diff --git a/libspu/kernel/hal/prot_wrapper.h b/libspu/kernel/hal/prot_wrapper.h
index 6294f834..5fb110bc 100644
--- a/libspu/kernel/hal/prot_wrapper.h
+++ b/libspu/kernel/hal/prot_wrapper.h
@@ -52,17 +52,17 @@ Value _equal_pp(SPUContext* ctx, const Value& x, const Value& y);
std::optional _equal_sp(SPUContext* ctx, const Value& x, const Value& y);
std::optional _equal_ss(SPUContext* ctx, const Value& x, const Value& y);
-Value _lshift_p(SPUContext* ctx, const Value& in, size_t bits);
-Value _lshift_s(SPUContext* ctx, const Value& in, size_t bits);
-Value _lshift_v(SPUContext* ctx, const Value& in, size_t bits);
+Value _lshift_p(SPUContext* ctx, const Value& in, const Sizes& bits);
+Value _lshift_s(SPUContext* ctx, const Value& in, const Sizes& bits);
+Value _lshift_v(SPUContext* ctx, const Value& in, const Sizes& bits);
-Value _rshift_p(SPUContext* ctx, const Value& in, size_t bits);
-Value _rshift_s(SPUContext* ctx, const Value& in, size_t bits);
-Value _rshift_v(SPUContext* ctx, const Value& in, size_t bits);
+Value _rshift_p(SPUContext* ctx, const Value& in, const Sizes& bits);
+Value _rshift_s(SPUContext* ctx, const Value& in, const Sizes& bits);
+Value _rshift_v(SPUContext* ctx, const Value& in, const Sizes& bits);
-Value _arshift_p(SPUContext* ctx, const Value& in, size_t bits);
-Value _arshift_s(SPUContext* ctx, const Value& in, size_t bits);
-Value _arshift_v(SPUContext* ctx, const Value& in, size_t bits);
+Value _arshift_p(SPUContext* ctx, const Value& in, const Sizes& bits);
+Value _arshift_s(SPUContext* ctx, const Value& in, const Sizes& bits);
+Value _arshift_v(SPUContext* ctx, const Value& in, const Sizes& bits);
Value _trunc_p(SPUContext* ctx, const Value& in, size_t bits, SignType sign);
Value _trunc_s(SPUContext* ctx, const Value& in, size_t bits, SignType sign);
diff --git a/libspu/kernel/hal/ring.cc b/libspu/kernel/hal/ring.cc
index 08b873f6..b016cd71 100644
--- a/libspu/kernel/hal/ring.cc
+++ b/libspu/kernel/hal/ring.cc
@@ -84,18 +84,18 @@ IMPL_UNARY_OP(_square)
#undef IMPL_UNARY_OP
-#define IMPL_SHIFT_OP(Name) \
- Value Name(SPUContext* ctx, const Value& in, size_t bits) { \
- SPU_TRACE_HAL_LEAF(ctx, in, bits); \
- if (in.isPublic()) { \
- return Name##_p(ctx, in, bits); \
- } else if (in.isSecret()) { \
- return Name##_s(ctx, in, bits); \
- } else if (in.isPrivate()) { \
- return Name##_v(ctx, in, bits); \
- } else { \
- SPU_THROW("unsupport unary op={} for {}", #Name, in); \
- } \
+#define IMPL_SHIFT_OP(Name) \
+ Value Name(SPUContext* ctx, const Value& in, const Sizes& bits) { \
+ SPU_TRACE_HAL_LEAF(ctx, in, bits); \
+ if (in.isPublic()) { \
+ return Name##_p(ctx, in, bits); \
+ } else if (in.isSecret()) { \
+ return Name##_s(ctx, in, bits); \
+ } else if (in.isPrivate()) { \
+ return Name##_v(ctx, in, bits); \
+ } else { \
+ SPU_THROW("unsupport unary op={} for {}", #Name, in); \
+ } \
}
IMPL_SHIFT_OP(_lshift)
@@ -497,7 +497,7 @@ Value _bit_parity(SPUContext* ctx, const Value& x, size_t bits) {
SPU_ENFORCE(absl::has_single_bit(bits), "currently only support power of 2");
auto ret = _prefer_b(ctx, x);
while (bits > 1) {
- ret = _xor(ctx, ret, _rshift(ctx, ret, bits / 2));
+ ret = _xor(ctx, ret, _rshift(ctx, ret, {static_cast(bits / 2)}));
bits /= 2;
}
@@ -518,7 +518,7 @@ Value _popcount(SPUContext* ctx, const Value& x, size_t bits) {
std::vector vs;
vs.reserve(bits);
for (size_t idx = 0; idx < bits; idx++) {
- auto x_ = _rshift(ctx, xb, idx);
+ auto x_ = _rshift(ctx, xb, {static_cast(idx)});
x_ = _and(ctx, x_, _constant(ctx, 1U, x.shape()));
if (x_.storage_type().isa()) {
@@ -547,8 +547,8 @@ Value _prefix_or(SPUContext* ctx, const Value& x) {
auto b0 = _prefer_b(ctx, x);
const size_t bit_width = SizeOf(ctx->getField()) * 8;
for (int idx = 0; idx < absl::bit_width(bit_width) - 1; idx++) {
- const size_t offset = 1UL << idx;
- auto b1 = _rshift(ctx, b0, offset);
+ const int64_t offset = 1L << idx;
+ auto b1 = _rshift(ctx, b0, {offset});
b0 = _or(ctx, b0, b1);
}
return b0;
@@ -574,8 +574,8 @@ Value _bitdeintl(SPUContext* ctx, const Value& in) {
// out = (out & keep) ^ ((out >> shift) & move) ^ ((out & move) << shift);
out = _xor(ctx,
_xor(ctx, _and(ctx, out, keep),
- _and(ctx, _rshift(ctx, out, shift), move)),
- _lshift(ctx, _and(ctx, out, move), shift));
+ _and(ctx, _rshift(ctx, out, {shift}), move)),
+ _lshift(ctx, _and(ctx, out, move), {shift}));
}
return out;
}
diff --git a/libspu/kernel/hal/ring.h b/libspu/kernel/hal/ring.h
index b2fb6f65..0dd7234a 100644
--- a/libspu/kernel/hal/ring.h
+++ b/libspu/kernel/hal/ring.h
@@ -71,11 +71,11 @@ Value _equal(SPUContext* ctx, const Value& x, const Value& y);
Value _less(SPUContext* ctx, const Value& x, const Value& y);
-Value _lshift(SPUContext* ctx, const Value& in, size_t bits);
+Value _lshift(SPUContext* ctx, const Value& in, const Sizes& bits);
-Value _rshift(SPUContext* ctx, const Value& in, size_t bits);
+Value _rshift(SPUContext* ctx, const Value& in, const Sizes& bits);
-Value _arshift(SPUContext* ctx, const Value& in, size_t bits);
+Value _arshift(SPUContext* ctx, const Value& in, const Sizes& bits);
Value _trunc(SPUContext* ctx, const Value& x, size_t bits = 0,
SignType sign = SignType::Unknown);
diff --git a/libspu/kernel/hal/type_cast.cc b/libspu/kernel/hal/type_cast.cc
index 53b65742..9adb77a8 100644
--- a/libspu/kernel/hal/type_cast.cc
+++ b/libspu/kernel/hal/type_cast.cc
@@ -27,7 +27,8 @@ Value int2fxp(SPUContext* ctx, const Value& x, DataType to_type) {
SPU_TRACE_HAL_LEAF(ctx, x);
SPU_ENFORCE(x.isInt(), "expect integer, got {}", x.dtype());
- return _lshift(ctx, x, ctx->getFxpBits()).setDtype(to_type);
+ return _lshift(ctx, x, {static_cast(ctx->getFxpBits())})
+ .setDtype(to_type);
}
// Casting fxp to integer.
@@ -49,12 +50,12 @@ Value fxp2int(SPUContext* ctx, const Value& x, DataType to_type) {
SPU_TRACE_HAL_LEAF(ctx, x);
SPU_ENFORCE(x.isFxp());
- const size_t fxp_bits = ctx->getFxpBits();
+ const int64_t fxp_bits = ctx->getFxpBits();
const Value kOneMinusEps = _constant(ctx, (1 << fxp_bits) - 1, x.shape());
// (x + 0.99 * (x < 0)) >> fxp_bits
return _arshift(ctx, _add(ctx, x, _mul(ctx, kOneMinusEps, _msb(ctx, x))),
- fxp_bits)
+ {fxp_bits})
.setDtype(to_type);
}
diff --git a/libspu/kernel/hlo/basic_unary.cc b/libspu/kernel/hlo/basic_unary.cc
index ce1d2c17..44cc9e36 100644
--- a/libspu/kernel/hlo/basic_unary.cc
+++ b/libspu/kernel/hlo/basic_unary.cc
@@ -118,19 +118,19 @@ spu::Value Round_RNTE(SPUContext *ctx, const spu::Value &in) {
// so comp = b && (c || a)
SPU_ENFORCE(!in.isComplex());
SPU_ENFORCE(in.isFxp(), "Round only supports fxp");
- const auto fxp_bits = ctx->getFxpBits();
+ const int64_t fxp_bits = ctx->getFxpBits();
const auto k1 = hal::_constant(ctx, 1U, in.shape());
auto x_prime = hal::_prefer_b(ctx, in);
auto y = hal::floor(ctx, x_prime);
- auto a = hal::_and(ctx, hal::_rshift(ctx, x_prime, fxp_bits), k1);
- auto b = hal::_and(ctx, hal::_rshift(ctx, x_prime, fxp_bits - 1), k1);
+ auto a = hal::_and(ctx, hal::_rshift(ctx, x_prime, {fxp_bits}), k1);
+ auto b = hal::_and(ctx, hal::_rshift(ctx, x_prime, {fxp_bits - 1}), k1);
std::vector cs;
cs.reserve(fxp_bits - 1);
- for (size_t idx = 0; idx < fxp_bits - 1; idx++) {
- auto x_ = hal::_and(ctx, hal::_rshift(ctx, x_prime, idx), k1);
+ for (int64_t idx = 0; idx < fxp_bits - 1; idx++) {
+ auto x_ = hal::_and(ctx, hal::_rshift(ctx, x_prime, {idx}), k1);
cs.push_back(std::move(x_));
}
auto c = vreduce(cs.begin(), cs.end(), [&](const Value &a, const Value &b) {
diff --git a/libspu/kernel/hlo/shift.cc b/libspu/kernel/hlo/shift.cc
index 08fe7140..0b6f3f1d 100644
--- a/libspu/kernel/hlo/shift.cc
+++ b/libspu/kernel/hlo/shift.cc
@@ -26,31 +26,8 @@ namespace spu::kernel::hlo {
template
spu::Value shift_impl_p(SPUContext *ctx, const spu::Value &lhs,
const spu::Value &rhs, const Fn &f) {
- auto shift_bits = hal::dump_public_as(ctx, rhs);
- if (std::all_of(rhs.strides().begin(), rhs.strides().end(),
- [](int64_t s) { return s == 0; })) {
- // rhs is a splat
- return f(ctx, lhs, shift_bits[0]);
- }
-
- // Not a splat...
- spu::Value ret =
- hal::constant(ctx, static_cast(0), lhs.dtype(), lhs.shape());
- auto dtype_size = getWidth(lhs.dtype());
- for (size_t bits = 0; bits < dtype_size; ++bits) {
- if (std::none_of(shift_bits.begin(), shift_bits.end(), [&bits](int8_t b) {
- return b == static_cast(bits);
- })) {
- continue;
- }
- auto current_bits = hal::constant(ctx, static_cast(bits),
- rhs.dtype(), rhs.shape());
- auto mask = hal::equal(ctx, rhs, current_bits);
- auto shifted = f(ctx, lhs, bits);
- ret = hal::add(ctx, ret, hal::mul(ctx, mask, shifted));
- }
-
- return ret;
+ auto shift_bits = hal::dump_public_as(ctx, rhs);
+ return f(ctx, lhs, {shift_bits.begin(), shift_bits.end()});
}
template
@@ -63,7 +40,7 @@ spu::Value shift_impl_s(SPUContext *ctx, const spu::Value &lhs,
auto current_bits = hal::constant(ctx, static_cast(bits),
rhs.dtype(), rhs.shape());
auto mask = hal::equal(ctx, rhs, current_bits);
- auto shifted = f(ctx, lhs, bits);
+ auto shifted = f(ctx, lhs, {static_cast(bits)});
ret = hal::add(ctx, ret, hal::mul(ctx, mask, shifted));
}
diff --git a/libspu/mpc/ab_api.cc b/libspu/mpc/ab_api.cc
index 7d02fe47..a0720870 100644
--- a/libspu/mpc/ab_api.cc
+++ b/libspu/mpc/ab_api.cc
@@ -133,7 +133,7 @@ OptionalAPI mul_a1bv(SPUContext* ctx, const Value& x, const Value& y) {
return NotAvailable;
}
-Value lshift_a(SPUContext* ctx, const Value& x, size_t nbits) {
+Value lshift_a(SPUContext* ctx, const Value& x, const Sizes& nbits) {
FORCE_DISPATCH(ctx, x, nbits);
}
@@ -201,15 +201,15 @@ OptionalAPI xor_bv(SPUContext* ctx, const Value& x, const Value& y) {
return NotAvailable;
}
-Value lshift_b(SPUContext* ctx, const Value& x, size_t nbits) {
+Value lshift_b(SPUContext* ctx, const Value& x, const Sizes& nbits) {
FORCE_DISPATCH(ctx, x, nbits);
}
-Value rshift_b(SPUContext* ctx, const Value& x, size_t nbits) {
+Value rshift_b(SPUContext* ctx, const Value& x, const Sizes& nbits) {
FORCE_DISPATCH(ctx, x, nbits);
}
-Value arshift_b(SPUContext* ctx, const Value& x, size_t nbits) {
+Value arshift_b(SPUContext* ctx, const Value& x, const Sizes& nbits) {
FORCE_DISPATCH(ctx, x, nbits);
}
@@ -254,10 +254,10 @@ Value bitintl_b(SPUContext* ctx, const Value& x, size_t stride) {
auto M = hack_make_p(ctx, spu::detail::kBitIntlSwapMasks[idx], x.shape());
int64_t S = static_cast(1) << idx;
// out = (out & K) ^ ((out >> S) & M) ^ ((out & M) << S);
- out = xor_bb(
- ctx,
- xor_bb(ctx, and_bp(ctx, out, K), and_bp(ctx, rshift_b(ctx, out, S), M)),
- lshift_b(ctx, and_bp(ctx, out, M), S));
+ out = xor_bb(ctx,
+ xor_bb(ctx, and_bp(ctx, out, K),
+ and_bp(ctx, rshift_b(ctx, out, {S}), M)),
+ lshift_b(ctx, and_bp(ctx, out, M), {S}));
}
out = setNumBits(out, numBits(x));
return out;
@@ -283,10 +283,10 @@ Value bitdeintl_b(SPUContext* ctx, const Value& x, size_t stride) {
auto M = hack_make_p(ctx, spu::detail::kBitIntlSwapMasks[idx], x.shape());
int64_t S = static_cast(1) << idx;
// out = (out & K) ^ ((out >> S) & M) ^ ((out & M) << S);
- out = xor_bb(
- ctx,
- xor_bb(ctx, and_bp(ctx, out, K), and_bp(ctx, rshift_b(ctx, out, S), M)),
- lshift_b(ctx, and_bp(ctx, out, M), S));
+ out = xor_bb(ctx,
+ xor_bb(ctx, and_bp(ctx, out, K),
+ and_bp(ctx, rshift_b(ctx, out, {S}), M)),
+ lshift_b(ctx, and_bp(ctx, out, M), {S}));
}
out = setNumBits(out, numBits(x));
return out;
@@ -318,9 +318,9 @@ Value ppa_kogge_stone(SPUContext* ctx, const Value& lhs, const Value& rhs,
auto G = and_bb(ctx, lhs, rhs);
for (int idx = 0; idx < Log2Ceil(nbits); ++idx) {
- const size_t offset = 1UL << idx;
- auto G1 = lshift_b(ctx, G, offset);
- auto P1 = lshift_b(ctx, P, offset);
+ const int64_t offset = static_cast(1) << idx;
+ auto G1 = lshift_b(ctx, G, {offset});
+ auto P1 = lshift_b(ctx, P, {offset});
// P1 = P & P1
// G1 = G ^ (P & G1)
@@ -332,7 +332,7 @@ Value ppa_kogge_stone(SPUContext* ctx, const Value& lhs, const Value& rhs,
}
// out = (G << 1) ^ p0
- auto C = lshift_b(ctx, G, 1);
+ auto C = lshift_b(ctx, G, {1});
return xor_bb(ctx, xor_bb(ctx, lhs, rhs), C);
}
@@ -343,7 +343,7 @@ std::pair bit_scatter(SPUContext* ctx, const Value& in,
SPU_ENFORCE(absl::has_single_bit(nbits), "unsupported {}", nbits);
auto out = bitdeintl_b(ctx, in, stride);
- auto hi = rshift_b(ctx, out, nbits / 2);
+ auto hi = rshift_b(ctx, out, {static_cast(nbits / 2)});
auto mask = hack_make_p(ctx, (static_cast(1) << (nbits / 2)) - 1,
in.shape());
auto lo = and_bp(ctx, out, mask);
@@ -357,7 +357,7 @@ Value bit_gather(SPUContext* ctx, const Value& hi, const Value& lo,
SPU_ENFORCE(nbits == numBits(lo), "nbits mismatch {}, {}", nbits,
numBits(lo));
- auto out = xor_bb(ctx, lshift_b(ctx, hi, nbits), lo);
+ auto out = xor_bb(ctx, lshift_b(ctx, hi, {static_cast(nbits)}), lo);
return bitintl_b(ctx, out, stride);
}
@@ -395,8 +395,8 @@ Value ppa_sklansky(SPUContext* ctx, Value const& lhs, Value const& rhs,
auto Gs = and_bp(ctx, Gl, s_mask);
auto Ps = and_bp(ctx, Pl, s_mask);
for (int j = 0; j < idx; j++) {
- Gs = xor_bb(ctx, Gs, rshift_b(ctx, Gs, 1 << j));
- Ps = xor_bb(ctx, Ps, rshift_b(ctx, Ps, 1 << j));
+ Gs = xor_bb(ctx, Gs, rshift_b(ctx, Gs, {1 << j}));
+ Ps = xor_bb(ctx, Ps, rshift_b(ctx, Ps, {1 << j}));
}
// SPU_ENFORCE(numBits(Ps) == bit_width / 2);
// SPU_ENFORCE(numBits(Gs) == bit_width / 2);
@@ -416,7 +416,7 @@ Value ppa_sklansky(SPUContext* ctx, Value const& lhs, Value const& rhs,
}
// out = (G0 << 1) ^ p0
- auto C = lshift_b(ctx, G, 1);
+ auto C = lshift_b(ctx, G, {1});
return xor_bb(ctx, xor_bb(ctx, lhs, rhs), C);
}
@@ -460,8 +460,8 @@ Value carry_a2b(SPUContext* ctx, const Value& x, const Value& y, size_t k) {
while (k > 1) {
if (k % 2 != 0) {
k += 1;
- P = lshift_b(ctx, P, 1);
- G = lshift_b(ctx, G, 1);
+ P = lshift_b(ctx, P, {1});
+ G = lshift_b(ctx, G, {1});
}
auto [P1, P0] = bit_scatter(ctx, P, 0);
auto [G1, G0] = bit_scatter(ctx, G, 0);
diff --git a/libspu/mpc/ab_api.h b/libspu/mpc/ab_api.h
index 72a8476f..94f17d98 100644
--- a/libspu/mpc/ab_api.h
+++ b/libspu/mpc/ab_api.h
@@ -46,7 +46,7 @@ OptionalAPI mul_av(SPUContext* ctx, const Value& x, const Value& y);
Value mul_a1b(SPUContext* ctx, const Value& x, const Value& y);
OptionalAPI mul_a1bv(SPUContext* ctx, const Value& x, const Value& y);
-Value lshift_a(SPUContext* ctx, const Value& x, size_t nbits);
+Value lshift_a(SPUContext* ctx, const Value& x, const Sizes& nbits);
Value trunc_a(SPUContext* ctx, const Value& x, size_t nbits, SignType sign);
Value mmul_ap(SPUContext* ctx, const Value& x, const Value& y);
@@ -72,9 +72,9 @@ Value xor_bb(SPUContext* ctx, const Value& x, const Value& y);
OptionalAPI xor_bv(SPUContext* ctx, const Value& x,
const Value& y); // TODO
-Value lshift_b(SPUContext* ctx, const Value& x, size_t nbits);
-Value rshift_b(SPUContext* ctx, const Value& x, size_t nbits);
-Value arshift_b(SPUContext* ctx, const Value& x, size_t nbits);
+Value lshift_b(SPUContext* ctx, const Value& x, const Sizes& nbits);
+Value rshift_b(SPUContext* ctx, const Value& x, const Sizes& nbits);
+Value arshift_b(SPUContext* ctx, const Value& x, const Sizes& nbits);
// Bit reverse for binary share.
Value bitrev_b(SPUContext* ctx, const Value& x, size_t start, size_t end);
diff --git a/libspu/mpc/ab_api_test.cc b/libspu/mpc/ab_api_test.cc
index 13ef4080..592f9389 100644
--- a/libspu/mpc/ab_api_test.cc
+++ b/libspu/mpc/ab_api_test.cc
@@ -195,7 +195,7 @@ TEST_P(ArithmeticTest, MulA1B) {
return;
}
- const size_t K = spu::SizeOf(conf.field()) * 8;
+ const int64_t K = spu::SizeOf(conf.field()) * 8;
/* GIVEN */
auto p0 = rand_p(obj.get(), conf.protocol() == ProtocolKind::CHEETAH
@@ -204,12 +204,12 @@ TEST_P(ArithmeticTest, MulA1B) {
auto p1 = rand_p(obj.get(), conf.protocol() == ProtocolKind::CHEETAH
? Shape({200, 26})
: kShape);
- p1 = rshift_p(obj.get(), p1, K - 1);
+ p1 = rshift_p(obj.get(), p1, {K - 1});
auto a0 = p2a(obj.get(), p0);
auto a1 = p2b(obj.get(), p1);
// hint runtime this is a 1bit value.
- a1 = lshift_b(obj.get(), a1, K - 1);
- a1 = rshift_b(obj.get(), a1, K - 1);
+ a1 = lshift_b(obj.get(), a1, {K - 1});
+ a1 = rshift_b(obj.get(), a1, {K - 1});
/* WHEN */
auto prev = obj->prot()->getState()->getStats();
@@ -238,12 +238,12 @@ TEST_P(ArithmeticTest, MulAV) {
return;
}
- const size_t K = spu::SizeOf(conf.field()) * 8;
+ const int64_t K = spu::SizeOf(conf.field()) * 8;
/* GIVEN */
auto p0 = rand_p(obj.get(), kShape);
auto p1 = rand_p(obj.get(), kShape);
- p1 = rshift_p(obj.get(), p1, K - 1);
+ p1 = rshift_p(obj.get(), p1, {K - 1});
auto a0 = p2a(obj.get(), p0);
auto a1 = p2v(obj.get(), p1, 0);
@@ -275,17 +275,17 @@ TEST_P(ArithmeticTest, MulA1BV) {
return;
}
- const size_t K = spu::SizeOf(conf.field()) * 8;
+ const int64_t K = spu::SizeOf(conf.field()) * 8;
/* GIVEN */
auto p0 = rand_p(obj.get(), kShape);
auto p1 = rand_p(obj.get(), kShape);
- p1 = rshift_p(obj.get(), p1, K - 1);
+ p1 = rshift_p(obj.get(), p1, {K - 1});
auto a0 = p2a(obj.get(), p0);
auto a1 = p2v(obj.get(), p1, 0);
// hint runtime this is a 1bit value.
- a1 = lshift_v(obj.get(), a1, K - 1);
- a1 = rshift_v(obj.get(), a1, K - 1);
+ a1 = lshift_v(obj.get(), a1, {K - 1});
+ a1 = rshift_v(obj.get(), a1, {K - 1});
// auto a1 = b2v(obj.get(), _a1, 0);
/* WHEN */
@@ -482,10 +482,10 @@ TEST_P(ArithmeticTest, LShiftA) {
}
/* WHEN */
auto prev = obj->prot()->getState()->getStats();
- auto tmp = lshift_a(obj.get(), a0, bits);
+ auto tmp = lshift_a(obj.get(), a0, {static_cast(bits)});
auto cost = obj->prot()->getState()->getStats() - prev;
auto r_b = a2p(obj.get(), tmp);
- auto r_p = lshift_p(obj.get(), p0, bits);
+ auto r_p = lshift_p(obj.get(), p0, {static_cast(bits)});
/* THEN */
EXPECT_VALUE_EQ(r_b, r_p);
@@ -513,10 +513,11 @@ TEST_P(ArithmeticTest, TruncA) {
if (!kernel->hasMsbError()) {
// trunc requires MSB to be zero.
- p0 = arshift_p(obj.get(), p0, 1);
+ p0 = arshift_p(obj.get(), p0, {1});
} else {
// has msb error, only use lowest 10 bits.
- p0 = arshift_p(obj.get(), p0, SizeOf(conf.field()) * 8 - 10);
+ p0 = arshift_p(obj.get(), p0,
+ {static_cast(SizeOf(conf.field()) * 8 - 10)});
}
/* GIVEN */
@@ -529,7 +530,7 @@ TEST_P(ArithmeticTest, TruncA) {
auto cost = obj->prot()->getState()->getStats() - prev;
auto r_a = a2p(obj.get(), a1);
- auto r_p = arshift_p(obj.get(), p0, bits);
+ auto r_p = arshift_p(obj.get(), p0, {static_cast