diff --git a/.circleci/config.yml b/.circleci/config.yml index 41e5522e..c4617cd9 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -17,7 +17,7 @@ version: 2.1 setup: true orbs: - path-filtering: circleci/path-filtering@1.0.0 + path-filtering: circleci/path-filtering@1.1.0 continuation: circleci/continuation@1.0.0 parameters: diff --git a/.clang-tidy b/.clang-tidy index bacbb3a2..227a407d 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -28,7 +28,8 @@ Checks: "abseil-cleanup-ctad, -readability-identifier-length, -readability-function-cognitive-complexity, -readability-magic-numbers, - -readability-named-parameter" + -readability-named-parameter, + -readability-convert-member-functions-to-static" CheckOptions: - key: bugprone-argument-comment.StrictMode diff --git a/.github/workflows/scorecard.yml b/.github/workflows/scorecard.yml index 68978a19..c446b981 100644 --- a/.github/workflows/scorecard.yml +++ b/.github/workflows/scorecard.yml @@ -37,7 +37,7 @@ jobs: persist-credentials: false - name: "Run analysis" - uses: ossf/scorecard-action@dc50aa9510b46c811795eb24b2f1ba02a914e534 # v2.3.3 + uses: ossf/scorecard-action@62b2cac7ed8198b15735ed49ab1e5cf35480ba46 # v2.4.0 with: results_file: results.sarif results_format: sarif @@ -67,6 +67,6 @@ jobs: # Upload the results to GitHub's code scanning dashboard. - name: "Upload to code-scanning" - uses: github/codeql-action/upload-sarif@23acc5c183826b7a8a97bce3cecc52db901f8251 # v3.25.10 + uses: github/codeql-action/upload-sarif@afb54ba388a7dca6ecae48f608c4ff05ff4cc77a # v3.25.15 with: sarif_file: results.sarif diff --git a/CHANGELOG.md b/CHANGELOG.md index 0712ac66..27502fd6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,8 +10,11 @@ > > please add your unreleased change here. -## TBD +- [Feature] Add more send/recv actions profiling +## 20240716 + +- [SPU] 0.9.2b0 release - [Feature] Support jax.numpy.bitwise_count - [Bugfix] Fix jax.numpy.signbit wrong answer with very large input diff --git a/README.md b/README.md index 56375409..6a1fe7e3 100644 --- a/README.md +++ b/README.md @@ -71,4 +71,4 @@ If you think SPU is helpful for your research or development, please consider ci ## Acknowledgement -We thank the significant contributions made by [Alibaba Gemini Lab](https://alibaba-gemini-lab.github.io). +We thank the significant contributions made by [Alibaba Gemini Lab](https://alibaba-gemini-lab.github.io) and security advisories made by [VUL337@NISL@THU](https://netsec.ccert.edu.cn/vul337). 
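One of the configuration changes above disables the clang-tidy check `readability-convert-member-functions-to-static`, which flags member functions that never touch `this` and suggests marking them `static`. A minimal sketch of the kind of code that triggers it, using a hypothetical class not taken from the SPU sources:

```cpp
#include <cstdint>

// Hypothetical example: clang-tidy's readability-convert-member-functions-to-static
// would report "method 'bitWidth' can be made static" because the function
// never reads or writes any member state.
class FieldInfo {
 public:
  int64_t bitWidth(int64_t bytes) { return bytes * 8; }
};
```

Silencing the check keeps such helpers as ordinary member functions; presumably the codebase prefers a uniform non-static interface over the check's suggestion.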
diff --git a/bazel/repositories.bzl b/bazel/repositories.bzl index ca1b5349..1318e460 100644 --- a/bazel/repositories.bzl +++ b/bazel/repositories.bzl @@ -40,10 +40,10 @@ def _yacl(): http_archive, name = "yacl", urls = [ - "https://github.com/secretflow/yacl/archive/refs/tags/0.4.5b1.tar.gz", + "https://github.com/secretflow/yacl/archive/refs/tags/0.4.5b3_nightly_20240722.tar.gz", ], - strip_prefix = "yacl-0.4.5b1", - sha256 = "28064053b9add0db8e1e8e648421a0579f1d3e7ee8a4bbd7bd5959cb59598088", + strip_prefix = "yacl-0.4.5b3_nightly_20240722", + sha256 = "ccca599e6ded6089c5afbb87c8f5e09383195af256caacd50089f0c7443e8604", ) def _libpsi(): @@ -51,10 +51,10 @@ def _libpsi(): http_archive, name = "psi", urls = [ - "https://github.com/secretflow/psi/archive/refs/tags/v0.4.0.dev240524.tar.gz", + "https://github.com/secretflow/psi/archive/refs/tags/v0.4.1.dev240722.tar.gz", ], - strip_prefix = "psi-0.4.0.dev240524", - sha256 = "c2868fa6a9d804e6bbed9922dab6dc819ec6e180e15eafe7eb1b661302508c88", + strip_prefix = "psi-0.4.1.dev240722", + sha256 = "878cd8af2c7b9850944a27adf91f21dd4937d09d38e8365baad3b5165db8b39a", ) def _rules_proto_grpc(): @@ -136,8 +136,8 @@ def _bazel_skylib(): ) def _com_github_openxla_xla(): - OPENXLA_COMMIT = "9b0dd58c9b625a2e958f4fc7787a1ff5c95dbb40" - OPENXLA_SHA256 = "f150c5b49e4d4497aae2c79232f1efe2baccaa72223b21dc8715be73eab74417" + OPENXLA_COMMIT = "04f2bfe797408c9efe742b89e2e4db6cf526ebb7" + OPENXLA_SHA256 = "7e1d24737815be7607eed5f02fe7f81d97ffe358dfb7b4876f97bce8f48b3b3e" # We need openxla to handle xla/mhlo/stablehlo maybe( @@ -169,10 +169,10 @@ def _com_github_pybind11(): http_archive, name = "pybind11", build_file = "@pybind11_bazel//:pybind11.BUILD", - sha256 = "bf8f242abd1abcd375d516a7067490fb71abd79519a282d22b6e4d19282185a7", - strip_prefix = "pybind11-2.12.0", + sha256 = "51631e88960a8856f9c497027f55c9f2f9115cafb08c0005439838a05ba17bfc", + strip_prefix = "pybind11-2.13.1", urls = [ - "https://github.com/pybind/pybind11/archive/refs/tags/v2.12.0.tar.gz", + "https://github.com/pybind/pybind11/archive/refs/tags/v2.13.1.tar.gz", ], ) @@ -229,7 +229,7 @@ def _com_github_microsoft_seal(): maybe( http_archive, name = "com_github_microsoft_seal", - sha256 = "78ef7334114de930daf7659e8ba60c5abfff85c86ec2b827a2b7c67c3c42da43", + sha256 = "acc2a1a127a85d1e1ffcca3ffd148f736e665df6d6b072df0e42fff64795a13c", strip_prefix = "SEAL-4.1.2", type = "tar.gz", patch_args = ["-p1"], diff --git a/docs/development/add_protocols.rst b/docs/development/add_protocols.rst index 31b4b5e2..988c846c 100644 --- a/docs/development/add_protocols.rst +++ b/docs/development/add_protocols.rst @@ -248,7 +248,7 @@ When kernels are implemented and registered, a new protocol is finally added. auto* prg_state = ctx->getState(); // dispatch the real implementation to different fields - return DISPATCH_ALL_FIELDS(field, "aby3.mulAA", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { // the real protocol implementation ... 
}); diff --git a/docs/reference/np_op_status.json b/docs/reference/np_op_status.json index 1272847c..20ed81a7 100644 --- a/docs/reference/np_op_status.json +++ b/docs/reference/np_op_status.json @@ -1 +1 @@ -[{"name": "abs", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "add", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "angle", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "angle", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "arccos", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "arccosh", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.SysError", "note": "nan"}, {"name": "arcsin", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "arcsinh", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "arctan", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "arctan2", "dtypes": ["float32"], "status": "Status.Pass", "note": ""}, {"name": "arctanh", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.SysError", "note": "[-1, 1] nan"}, {"name": "argmax", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "argmin", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "array_equal", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "array_equiv", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "atleast_1d", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "atleast_2d", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "atleast_3d", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "bitwise_and", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "bitwise_not", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "bitwise_or", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "bitwise_xor", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "cbrt", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.SysError", "note": "stablehlo.cbrt"}, {"name": "ceil", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "conj", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "conjugate", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "copysign", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.SysError", "note": "shift"}, {"name": "cos", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, 
{"name": "cosh", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "deg2rad", "dtypes": ["float32"], "status": "Status.Pass", "note": ""}, {"name": "divide", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "divmod", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "ediff1d", "dtypes": ["int32"], "status": "Status.Pass", "note": ""}, {"name": "equal", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "exp", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "exp2", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "expm1", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "fabs", "dtypes": ["float32"], "status": "Status.Pass", "note": ""}, {"name": "fix", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "float_power", "dtypes": ["float32"], "status": "Status.Pass", "note": ""}, {"name": "floor", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "floor_divide", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "fmax", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "fmin", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "fmod", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "gcd", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.SysError", "note": "secret while"}, {"name": "greater", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "greater_equal", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "heaviside", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "hypot", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "i0", "dtypes": ["float32"], "status": "Status.Failed", "note": "accuracy"}, {"name": "imag", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "invert", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "isclose", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "iscomplex", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "isfinite", "dtypes": ["float32"], "status": "Status.UnSupport", "note": ""}, {"name": "isinf", "dtypes": ["float32"], "status": "Status.UnSupport", "note": ""}, {"name": "isnan", "dtypes": ["float32"], "status": "Status.UnSupport", "note": ""}, {"name": "isneginf", "dtypes": ["float32"], "status": "Status.UnSupport", "note": ""}, {"name": "isposinf", "dtypes": ["float32"], "status": "Status.UnSupport", "note": ""}, {"name": "isreal", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": 
"Status.Pass", "note": ""}, {"name": "isrealobj", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "kron", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "lcm", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.SysError", "note": "secret while"}, {"name": "ldexp", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.Failed", "note": "IEEE-754"}, {"name": "left_shift", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "less", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "less_equal", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "log", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "log10", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "log1p", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "log1p", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "log2", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "logaddexp", "dtypes": ["float32"], "status": "Status.Pass", "note": ""}, {"name": "logaddexp2", "dtypes": ["float32"], "status": "Status.Pass", "note": ""}, {"name": "logical_and", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "logical_not", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "logical_or", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "logical_xor", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "maximum", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "mean", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "mean", "dtypes": ["float32"], "status": "Status.PassNoGen", "note": ""}, {"name": "minimum", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "mod", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "modf", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "multiply", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "nanargmax", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.UnSupport", "note": ""}, {"name": "nanargmin", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.UnSupport", "note": ""}, {"name": "nanmean", "dtypes": ["float32"], "status": "Status.UnSupport", "note": ""}, {"name": "nanprod", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.UnSupport", "note": ""}, {"name": "nansum", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.UnSupport", "note": ""}, 
{"name": "negative", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "nextafter", "dtypes": ["float32"], "status": "Status.Failed", "note": "IEEE-754"}, {"name": "not_equal", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "outer", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "polyval", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "positive", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "power", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "prod", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "rad2deg", "dtypes": ["float32"], "status": "Status.Pass", "note": ""}, {"name": "ravel", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "real", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "reciprocal", "dtypes": ["float32"], "status": "Status.Pass", "note": ""}, {"name": "remainder", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "right_shift", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "rint", "dtypes": ["float32"], "status": "Status.Pass", "note": ""}, {"name": "sign", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "signbit", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "sin", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "sinc", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "sinh", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "sqrt", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "square", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "square", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "subtract", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "sum", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "tan", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.SysError", "note": "stablehlo.tan"}, {"name": "tanh", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "transpose", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "true_divide", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "trunc", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "unwrap", "dtypes": ["float32"], "status": "Status.Pass", "note": ""}] \ No newline at end of file 
+[{"name": "abs", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "add", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "angle", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "angle", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "arccos", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "arccosh", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.SysError", "note": "nan"}, {"name": "arcsin", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "arcsinh", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "arctan", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "arctan2", "dtypes": ["float32"], "status": "Status.Pass", "note": ""}, {"name": "arctanh", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.SysError", "note": "[-1, 1] nan"}, {"name": "argmax", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "argmin", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "array_equal", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "array_equiv", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "atleast_1d", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "atleast_2d", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "atleast_3d", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "bitwise_and", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "bitwise_count", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "bitwise_not", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "bitwise_or", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "bitwise_xor", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "cbrt", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.SysError", "note": "stablehlo.cbrt"}, {"name": "ceil", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "conj", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "conjugate", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "copysign", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.SysError", "note": "shift"}, {"name": "cos", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "cosh", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": 
"Status.Pass", "note": ""}, {"name": "deg2rad", "dtypes": ["float32"], "status": "Status.Pass", "note": ""}, {"name": "divide", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "divmod", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "ediff1d", "dtypes": ["int32"], "status": "Status.Pass", "note": ""}, {"name": "equal", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "exp", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "exp2", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "expm1", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "fabs", "dtypes": ["float32"], "status": "Status.Pass", "note": ""}, {"name": "fix", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "float_power", "dtypes": ["float32"], "status": "Status.Pass", "note": ""}, {"name": "floor", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "floor_divide", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "fmax", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "fmin", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "fmod", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "gcd", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.SysError", "note": "secret while"}, {"name": "greater", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "greater_equal", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "heaviside", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "hypot", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "i0", "dtypes": ["float32"], "status": "Status.Failed", "note": "accuracy"}, {"name": "imag", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "invert", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "isclose", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "iscomplex", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "isfinite", "dtypes": ["float32"], "status": "Status.UnSupport", "note": ""}, {"name": "isinf", "dtypes": ["float32"], "status": "Status.UnSupport", "note": ""}, {"name": "isnan", "dtypes": ["float32"], "status": "Status.UnSupport", "note": ""}, {"name": "isneginf", "dtypes": ["float32"], "status": "Status.UnSupport", "note": ""}, {"name": "isposinf", "dtypes": ["float32"], "status": "Status.UnSupport", "note": ""}, {"name": "isreal", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "isrealobj", "dtypes": ["float32", "int16", "int32", 
"uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "kron", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "lcm", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.SysError", "note": "secret while"}, {"name": "ldexp", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.Failed", "note": "IEEE-754"}, {"name": "left_shift", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "less", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "less_equal", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "log", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "log10", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "log1p", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "log1p", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "log2", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "logaddexp", "dtypes": ["float32"], "status": "Status.Pass", "note": ""}, {"name": "logaddexp2", "dtypes": ["float32"], "status": "Status.Pass", "note": ""}, {"name": "logical_and", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "logical_not", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "logical_or", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "logical_xor", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "maximum", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "mean", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "mean", "dtypes": ["float32"], "status": "Status.PassNoGen", "note": ""}, {"name": "minimum", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "mod", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "modf", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "multiply", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "nanargmax", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.UnSupport", "note": ""}, {"name": "nanargmin", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.UnSupport", "note": ""}, {"name": "nanmean", "dtypes": ["float32"], "status": "Status.UnSupport", "note": ""}, {"name": "nanprod", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.UnSupport", "note": ""}, {"name": "nansum", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.UnSupport", "note": ""}, {"name": "negative", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": 
"Status.Pass", "note": ""}, {"name": "nextafter", "dtypes": ["float32"], "status": "Status.Failed", "note": "IEEE-754"}, {"name": "not_equal", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "outer", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "polyval", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "positive", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "power", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "prod", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "rad2deg", "dtypes": ["float32"], "status": "Status.Pass", "note": ""}, {"name": "ravel", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "real", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "reciprocal", "dtypes": ["float32"], "status": "Status.Pass", "note": ""}, {"name": "remainder", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "right_shift", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "rint", "dtypes": ["float32"], "status": "Status.Pass", "note": ""}, {"name": "sign", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "signbit", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "sin", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "sinc", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "sinh", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "sqrt", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "square", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "square", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "subtract", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "sum", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "tan", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.SysError", "note": "stablehlo.tan"}, {"name": "tanh", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "transpose", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "true_divide", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "trunc", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "unwrap", "dtypes": ["float32"], "status": "Status.Pass", "note": ""}] \ No newline at end of file diff --git a/docs/reference/np_op_status.md b/docs/reference/np_op_status.md index 
d69f20a0..ab3539a4 100644 --- a/docs/reference/np_op_status.md +++ b/docs/reference/np_op_status.md @@ -293,6 +293,20 @@ Please check *Supported Dtypes* as well. - uint16 - uint32 +## bitwise_count + +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.bitwise_count.html +### Status + +**PASS** +Please check *Supported Dtypes* as well. +### Supported Dtypes + +- int16 +- int32 +- uint16 +- uint32 + ## bitwise_not JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.bitwise_not.html diff --git a/docs/reference/pphlo_doc.rst b/docs/reference/pphlo_doc.rst index 751541e1..148d417e 100644 --- a/docs/reference/pphlo_doc.rst +++ b/docs/reference/pphlo_doc.rst @@ -1,9 +1,9 @@ -PPHlo API reference +PPHLO API reference =================== -PPHlo is short for (SPU High level ops), it's the assembly language of SPU. +PPHLO is short for (Privacy-Preserving High-Level Operations), it's the assembly language of SPU. -PPHlo is built on `MLIR `_ infrastructure, the concrete ops definition could be found :spu_code_host:`here `. +PPHLO is built on `MLIR `_ infrastructure, the concrete ops definition could be found :spu_code_host:`here `. Op List ~~~~~~~ diff --git a/docs/reference/pphlo_op_doc.md b/docs/reference/pphlo_op_doc.md index f6afc268..2839c9a1 100644 --- a/docs/reference/pphlo_op_doc.md +++ b/docs/reference/pphlo_op_doc.md @@ -747,7 +747,7 @@ Ref https://www.tensorflow.org/xla/operation_semantics#dot. Traits: `AlwaysSpeculatableImplTrait` -Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)` +Interfaces: `ConditionallySpeculatable`, `InferShapedTypeOpInterface`, `InferTypeOpInterface`, `NoMemoryEffect (MemoryEffectOpInterface)` Effects: `MemoryEffects::Effect{}` @@ -1626,55 +1626,63 @@ Effects: `MemoryEffects::Effect{}` | :----: | ----------- | | `result` | statically shaped tensor of 8/16/32/64-bit signless integer or Secret of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or 8/16/32/64-bit unsigned integer or Secret of 8/16/32/64-bit unsigned integer values or pred (AKA boolean or 1-bit integer) or Secret of 1-bit signless integer values values -### `pphlo.power` (spu::pphlo::PowOp) +### `pphlo.popcnt` (spu::pphlo::PopcntOp) -_Power operator_ +_Popcnt operator, ties away from zero_ Syntax: ``` -operation ::= `pphlo.power` $lhs `,` $rhs attr-dict - `:` custom(type($lhs), type($rhs), type($result)) +operation ::= `pphlo.popcnt` $operand attr-dict `:` custom(type($operand), type($result)) ``` -Performs element-wise exponentiation of `lhs` tensor by `rhs` tensor and produces a `result` tensor. +Performs element-wise count of the number of bits set in the `operand` tensor and produces a `result` tensor. -Ref https://github.com/openxla/stablehlo/blob/main/docs/spec.md#power +Ref https://github.com/openxla/stablehlo/blob/main/docs/spec.md#popcnt -Traits: `AlwaysSpeculatableImplTrait`, `Elementwise`, `SameOperandsAndResultShape` +Traits: `AlwaysSpeculatableImplTrait`, `Elementwise`, `SameOperandsAndResultShape`, `SameOperandsAndResultType` -Interfaces: `ConditionallySpeculatable`, `InferShapedTypeOpInterface`, `InferTypeOpInterface`, `NoMemoryEffect (MemoryEffectOpInterface)` +Interfaces: `ConditionallySpeculatable`, `InferTypeOpInterface`, `NoMemoryEffect (MemoryEffectOpInterface)` Effects: `MemoryEffects::Effect{}` +#### Attributes: + + + + +
+<table>
+<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
+<tr><td><code>bits</code></td><td>::mlir::IntegerAttr</td><td>64-bit signless integer attribute</td></tr>
+</table>
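To make the `pphlo.popcnt` entry above concrete: per its description it is an element-wise population count, so each result element is the number of set bits in the corresponding operand element. A plaintext C++ sketch of that semantic only — it illustrates the spec text, not the SPU kernel, and `std::popcount` requires C++20:

```cpp
#include <bit>       // std::popcount (C++20)
#include <cstdint>
#include <vector>

// Element-wise population count over a plaintext buffer, mirroring the
// documented pphlo.popcnt semantics: result[i] = number of set bits in x[i].
std::vector<uint64_t> popcnt(const std::vector<uint64_t>& x) {
  std::vector<uint64_t> result(x.size());
  for (size_t i = 0; i < x.size(); ++i) {
    result[i] = static_cast<uint64_t>(std::popcount(x[i]));
  }
  return result;
}
```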
+ #### Operands: | Operand | Description | | :-----: | ----------- | -| `lhs` | statically shaped tensor of pred (AKA boolean or 1-bit integer) or Secret of 1-bit signless integer values or 8/16/32/64-bit signless integer or Secret of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or 8/16/32/64-bit unsigned integer or Secret of 8/16/32/64-bit unsigned integer values or 16-bit float or 32-bit float or 64-bit float or Secret of 16-bit float or 32-bit float or 64-bit float values or complex type with 32-bit float or 64-bit float elements or Secret of complex type with 32-bit float or 64-bit float elements values values -| `rhs` | statically shaped tensor of pred (AKA boolean or 1-bit integer) or Secret of 1-bit signless integer values or 8/16/32/64-bit signless integer or Secret of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or 8/16/32/64-bit unsigned integer or Secret of 8/16/32/64-bit unsigned integer values or 16-bit float or 32-bit float or 64-bit float or Secret of 16-bit float or 32-bit float or 64-bit float values or complex type with 32-bit float or 64-bit float elements or Secret of complex type with 32-bit float or 64-bit float elements values values +| `operand` | statically shaped tensor of 8/16/32/64-bit signless integer or Secret of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or 8/16/32/64-bit unsigned integer or Secret of 8/16/32/64-bit unsigned integer values values #### Results: | Result | Description | | :----: | ----------- | -| `result` | statically shaped tensor of pred (AKA boolean or 1-bit integer) or Secret of 1-bit signless integer values or 8/16/32/64-bit signless integer or Secret of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or 8/16/32/64-bit unsigned integer or Secret of 8/16/32/64-bit unsigned integer values or 16-bit float or 32-bit float or 64-bit float or Secret of 16-bit float or 32-bit float or 64-bit float values or complex type with 32-bit float or 64-bit float elements or Secret of complex type with 32-bit float or 64-bit float elements values values +| `result` | statically shaped tensor of 8/16/32/64-bit signless integer or Secret of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or 8/16/32/64-bit unsigned integer or Secret of 8/16/32/64-bit unsigned integer values values -### `pphlo.prefer_a` (spu::pphlo::PreferAOp) +### `pphlo.power` (spu::pphlo::PowOp) -_Prefer AShare operator_ +_Power operator_ Syntax: ``` -operation ::= `pphlo.prefer_a` $operand attr-dict `:` custom(type($operand), type($result)) +operation ::= `pphlo.power` $lhs `,` $rhs attr-dict + `:` custom(type($lhs), type($rhs), type($result)) ``` -Convert input to AShare if possible. +Performs element-wise exponentiation of `lhs` tensor by `rhs` tensor and produces a `result` tensor. 
-Traits: `AlwaysSpeculatableImplTrait`, `Elementwise`, `SameOperandsAndResultShape`, `SameOperandsAndResultType` +Ref https://github.com/openxla/stablehlo/blob/main/docs/spec.md#power + +Traits: `AlwaysSpeculatableImplTrait`, `Elementwise`, `SameOperandsAndResultShape` Interfaces: `ConditionallySpeculatable`, `InferShapedTypeOpInterface`, `InferTypeOpInterface`, `NoMemoryEffect (MemoryEffectOpInterface)` @@ -1684,7 +1692,8 @@ Effects: `MemoryEffects::Effect{}` | Operand | Description | | :-----: | ----------- | -| `operand` | statically shaped tensor of pred (AKA boolean or 1-bit integer) or Secret of 1-bit signless integer values or 8/16/32/64-bit signless integer or Secret of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or 8/16/32/64-bit unsigned integer or Secret of 8/16/32/64-bit unsigned integer values or 16-bit float or 32-bit float or 64-bit float or Secret of 16-bit float or 32-bit float or 64-bit float values or complex type with 32-bit float or 64-bit float elements or Secret of complex type with 32-bit float or 64-bit float elements values values +| `lhs` | statically shaped tensor of pred (AKA boolean or 1-bit integer) or Secret of 1-bit signless integer values or 8/16/32/64-bit signless integer or Secret of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or 8/16/32/64-bit unsigned integer or Secret of 8/16/32/64-bit unsigned integer values or 16-bit float or 32-bit float or 64-bit float or Secret of 16-bit float or 32-bit float or 64-bit float values or complex type with 32-bit float or 64-bit float elements or Secret of complex type with 32-bit float or 64-bit float elements values values +| `rhs` | statically shaped tensor of pred (AKA boolean or 1-bit integer) or Secret of 1-bit signless integer values or 8/16/32/64-bit signless integer or Secret of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or 8/16/32/64-bit unsigned integer or Secret of 8/16/32/64-bit unsigned integer values or 16-bit float or 32-bit float or 64-bit float or Secret of 16-bit float or 32-bit float or 64-bit float values or complex type with 32-bit float or 64-bit float elements or Secret of complex type with 32-bit float or 64-bit float elements values values #### Results: @@ -2270,12 +2279,21 @@ Returns the sign of the `operand` element-wise and produces a `result` tensor. Ref https://github.com/openxla/stablehlo/blob/main/docs/spec.md#sign +PPHLO Extension: when `ignore_zero` is set to true, sign does not enforce sign(0) to 0 + Traits: `AlwaysSpeculatableImplTrait`, `Elementwise`, `SameOperandsAndResultShape`, `SameOperandsAndResultType` Interfaces: `ConditionallySpeculatable`, `InferShapedTypeOpInterface`, `InferTypeOpInterface`, `NoMemoryEffect (MemoryEffectOpInterface)` Effects: `MemoryEffects::Effect{}` +#### Attributes: + + + + +
+<table>
+<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
+<tr><td><code>ignore_zero</code></td><td>::mlir::BoolAttr</td><td>bool attribute</td></tr>
+</table>
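The `ignore_zero` attribute documented above only affects the zero case of `sign`. A plaintext sketch of the behaviour as the doc text describes it — illustration only, not the SPU kernel; what sign(0) evaluates to when `ignore_zero` is true is left open by the text, so here it simply falls through to an MSB-style rule:

```cpp
#include <cstdint>

// pphlo.sign per the doc text: with ignore_zero == false the StableHLO rule
// sign(0) == 0 is enforced; with ignore_zero == true that special case is
// skipped, so 0 takes whatever the rule below yields (+1 here).
int64_t sign(int64_t x, bool ignore_zero) {
  if (!ignore_zero && x == 0) {
    return 0;
  }
  return x < 0 ? -1 : 1;
}
```

Skipping the zero test is typically cheaper under MPC, which is presumably why the extension exists.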
+ #### Operands: | Operand | Description | @@ -2377,7 +2395,7 @@ Ref https://github.com/openxla/stablehlo/blob/main/docs/spec.md#slice Traits: `AlwaysSpeculatableImplTrait`, `SameOperandsAndResultElementType` -Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)` +Interfaces: `ConditionallySpeculatable`, `InferTypeOpInterface`, `NoMemoryEffect (MemoryEffectOpInterface)` Effects: `MemoryEffects::Effect{}` @@ -2551,7 +2569,7 @@ Ref https://github.com/openxla/stablehlo/blob/main/docs/spec.md#transpose Traits: `AlwaysSpeculatableImplTrait`, `SameOperandsAndResultElementType` -Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)` +Interfaces: `ConditionallySpeculatable`, `InferTypeOpInterface`, `NoMemoryEffect (MemoryEffectOpInterface)` Effects: `MemoryEffects::Effect{}` diff --git a/docs/reference/runtime_config.md b/docs/reference/runtime_config.md index f556bce4..836095cc 100644 --- a/docs/reference/runtime_config.md +++ b/docs/reference/runtime_config.md @@ -179,8 +179,9 @@ The SPU runtime configuration. | Field | Type | Description | | ----- | ---- | ----------- | | server_host | [ string](#string) | TrustedThirdParty beaver server's remote ip:port or load-balance uri. | -| session_id | [ string](#string) | if empty, use link id as session id. | | adjust_rank | [ int32](#int32) | which rank do adjust rpc call, usually choose the rank closer to the server. | +| asym_crypto_schema | [ string](#string) | asym_crypto_schema: support ["SM2"] Will support 25519 in the future, after yacl supported it. | +| server_public_key | [ bytes](#bytes) | server's public key | diff --git a/docs/requirements.txt b/docs/requirements.txt index 0c9b6330..937b1bf0 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,6 +1,6 @@ myst-parser==3.0.1 -rstcheck==6.2.1 -sphinx==7.3.7 +rstcheck==6.2.4 +sphinx==8.0.2 nbsphinx==0.9.4 sphinx-autobuild==2024.4.16 sphinx-markdown-parser==0.2.4 diff --git a/examples/python/ml/jax_lr/README.md b/examples/python/ml/jax_lr/README.md index fc0e6580..bc340199 100644 --- a/examples/python/ml/jax_lr/README.md +++ b/examples/python/ml/jax_lr/README.md @@ -5,7 +5,7 @@ This example demonstrates how to use SPU to train a logistic regression model pr 1. Launch SPU backend runtime ```sh - bazel run -c opt //examples/python/utils:nodectl -- up + bazel run -c opt //examples/python/utils:nodectl -- -c examples/python/conf/2pc_semi2k.json up ``` 2. 
Run `jax_lr` example diff --git a/experimental/squirrel/README.md b/experimental/squirrel/README.md index 5c33ff83..affce442 100644 --- a/experimental/squirrel/README.md +++ b/experimental/squirrel/README.md @@ -29,11 +29,24 @@ Code under this folder is purely for research demonstration and it's **NOT desig * On one terminal ```sh - bazel-bin/experimental/squirrel/squirrel_demo_main --rank0_nfeatures=85 --rank1_nfeatures=85 --standalone --train=BinaryClassification_Aps_Test_60000_171.csv --test=BinaryClassification_Aps_Test_16000_171.csv --standalone=true --rank=0 --has_label=0 --lr=1.0 --subsample=0.8 + bazel-bin/experimental/squirrel/squirrel_demo_main --rank0_nfeatures=85 --rank1_nfeatures=85 --standalone=true --train=BinaryClassification_Aps_Test_60000_171.csv --test=BinaryClassification_Aps_Test_16000_171.csv --rank=0 --has_label=0 --lr=1.0 --subsample=0.8 ``` * On another terminal ```sh - bazel-bin/experimental/squirrel/squirrel_demo_main --rank0_nfeatures=85 --rank1_nfeatures=85 --standalone --train=BinaryClassification_Aps_Test_60000_171.csv --test=BinaryClassification_Aps_Test_16000_171.csv --standalone=true --rank=1 --has_label=1 --lr=1.0 --subsample=0.8 + bazel-bin/experimental/squirrel/squirrel_demo_main --rank0_nfeatures=85 --rank1_nfeatures=85 --standalone=true --train=BinaryClassification_Aps_Test_60000_171.csv --test=BinaryClassification_Aps_Test_16000_171.csv --rank=1 --has_label=1 --lr=1.0 --subsample=0.8 ``` +* Run on distributed dataset, e.g., using the `breast_cancer` dataset from the SPU repo. + * On one terminal + + ```sh + bazel-bin/experimental/squirrel/squirrel_demo_main --rank0_nfeatures=15 --rank1_nfeatures=15 --standalone=false --train=examples/data/breast_cancer_a.csv --rank=0 --has_label=0 --lr=1.0 --subsample=0.8 + ``` + + * On another terminal + + ```sh + bazel-bin/experimental/squirrel/squirrel_demo_main --rank0_nfeatures=15 --rank1_nfeatures=15 --standalone=false --train=examples/data/breast_cancer_b.csv --rank=1 --has_label=1 --lr=1.0 --subsample=0.8 + ``` + diff --git a/experimental/squirrel/bin_matvec_prot.cc b/experimental/squirrel/bin_matvec_prot.cc index 6946fa4a..35e4ee3a 100644 --- a/experimental/squirrel/bin_matvec_prot.cc +++ b/experimental/squirrel/bin_matvec_prot.cc @@ -766,9 +766,12 @@ spu::NdArrayRef BinMatVecProtocol::Recv(const spu::NdArrayRef &vec_in, SPU_ENFORCE_EQ(vec_in.numel(), dim_in); if (not indicator.empty()) { // mat * diag(indicator) should not be all zeros. - SPU_ENFORCE(std::any_of(indicator.begin(), indicator.end(), - [](uint8_t x) { return x > 0; }), - "empty matrix is not allowed"); + if (std::all_of(indicator.begin(), indicator.end(), + [](uint8_t x) { return x == 0; })) { + SPDLOG_WARN( + "Empty matrix! Make sure the 1-bit error will not ruin your " + "computation."); + } } auto eltype = vec_in.eltype(); diff --git a/experimental/squirrel/bin_matvec_prot.h b/experimental/squirrel/bin_matvec_prot.h index 2ac8ea6c..ba34499f 100644 --- a/experimental/squirrel/bin_matvec_prot.h +++ b/experimental/squirrel/bin_matvec_prot.h @@ -70,7 +70,10 @@ struct StlSparseMatrix { // Outputs: // Sender: z0 \in Zk^{m} // Recv: z1 \in Zk^{m} -// such that z0 + z1 = M * (v0 + v1) mod Zk +// such that z0 + z1 = M * (v0 + v1) + e mod Zk +// +// Note that we might introduce 1-bit error `e` due to the coefficient-based +// resharing HE ciphertexts to additive shares. 
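The header comment added just above is the behavioural contract the new tests exercise: because the protocol reshares HE ciphertext coefficients into additive shares, the reconstructed product can be off by at most 1. A minimal sketch of how a caller can check results under that contract, mirroring the `EXPECT_NEAR(expected[i], got[i], 1)` assertions added in `bin_matvec_prot_test.cc` (illustrative helper only, not part of the SPU API):

```cpp
#include <cstdint>
#include <cstdlib>

// Accept a reconstructed element if it is within the documented 1-bit (+/- 1)
// error of the expected plaintext value of M * (v0 + v1).
bool WithinOneBitError(int64_t expected, int64_t got) {
  return std::llabs(expected - got) <= 1;
}
```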
class BinMatVecProtocol { public: BinMatVecProtocol(size_t ring_bitwidth, diff --git a/experimental/squirrel/bin_matvec_prot_test.cc b/experimental/squirrel/bin_matvec_prot_test.cc index 0d5cae29..9ef145ad 100644 --- a/experimental/squirrel/bin_matvec_prot_test.cc +++ b/experimental/squirrel/bin_matvec_prot_test.cc @@ -110,7 +110,7 @@ TEST_P(BinMatVecProtTest, Basic) { }); NdArrayRef reveal = ring_add(out_shr[0], out_shr[1]); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _vec(vec); auto expected = BinAccumuate(_vec, mat); NdArrayView got(reveal); @@ -160,7 +160,7 @@ TEST_P(BinMatVecProtTest, WithIndicator) { }); NdArrayRef reveal = ring_add(out_shr[0], out_shr[1]); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _vec(vec); auto expected = BinAccumuate(_vec, mat, absl::MakeConstSpan(indicator)); @@ -172,4 +172,100 @@ TEST_P(BinMatVecProtTest, WithIndicator) { } }); } + +TEST_P(BinMatVecProtTest, EmptyMat) { + using namespace spu; + using namespace spu::mpc; + constexpr size_t kWorldSize = 2; + + FieldType field = std::get<0>(GetParam()); + int64_t dim_in = std::get<0>(std::get<1>(GetParam())); + int64_t dim_out = std::get<1>(std::get<1>(GetParam())); + + StlSparseMatrix mat; + mat.rows_data_.resize(dim_out); + mat.cols_ = dim_in; + + NdArrayRef vec_shr[2]; + vec_shr[0] = ring_rand(field, {dim_in}) + .as(spu::makeType(field)); + vec_shr[1] = ring_rand(field, {dim_in}) + .as(spu::makeType(field)); + + NdArrayRef vec = ring_add(vec_shr[0], vec_shr[1]); + + NdArrayRef out_shr[2]; + utils::simulate(kWorldSize, [&](std::shared_ptr lctx) { + BinMatVecProtocol binmat_prot(SizeOf(field) * 8, lctx); + if (0 == lctx->Rank()) { + out_shr[0] = binmat_prot.Send(vec_shr[0], dim_out, dim_in); + } else { + out_shr[1] = binmat_prot.Recv(vec_shr[1], dim_out, dim_in, mat); + } + }); + NdArrayRef reveal = ring_add(out_shr[0], out_shr[1]); + + DISPATCH_ALL_FIELDS(field, "", [&]() { + using sT = std::make_signed::type; + NdArrayView _vec(vec); + auto expected = BinAccumuate(_vec, mat); + NdArrayView got(reveal); + + EXPECT_EQ(expected.size(), (size_t)got.numel()); + for (int64_t i = 0; i < dim_out; ++i) { + EXPECT_NEAR(expected[i], got[i], 1); + } + }); +} + +TEST_P(BinMatVecProtTest, WithEmptyIndicator) { + using namespace spu; + using namespace spu::mpc; + constexpr size_t kWorldSize = 2; + + FieldType field = std::get<0>(GetParam()); + int64_t dim_in = std::get<0>(std::get<1>(GetParam())); + int64_t dim_out = std::get<1>(std::get<1>(GetParam())); + + StlSparseMatrix mat; + PrepareBinaryMat(mat, dim_out, dim_in, 0); + + NdArrayRef vec_shr[2]; + vec_shr[0] = ring_rand(field, {dim_in}) + .as(spu::makeType(field)); + vec_shr[1] = ring_rand(field, {dim_in}) + .as(spu::makeType(field)); + std::vector indicator(dim_in); + std::default_random_engine rdv; + std::uniform_int_distribution dist(0, 10); + // empty indicator + std::fill_n(indicator.data(), indicator.size(), 0); + + NdArrayRef vec = ring_add(vec_shr[0], vec_shr[1]); + + NdArrayRef out_shr[2]; + utils::simulate(kWorldSize, [&](std::shared_ptr lctx) { + BinMatVecProtocol binmat_prot(SizeOf(field) * 8, lctx); + if (0 == lctx->Rank()) { + out_shr[0] = binmat_prot.Send(vec_shr[0], dim_out, dim_in); + } else { + out_shr[1] = binmat_prot.Recv(vec_shr[1], dim_out, dim_in, mat, + absl::MakeConstSpan(indicator)); + } + }); + NdArrayRef reveal = ring_add(out_shr[0], out_shr[1]); + + DISPATCH_ALL_FIELDS(field, "", [&]() { + using sT = std::make_signed::type; + NdArrayView _vec(vec); + 
NdArrayView got(reveal); + auto expected = BinAccumuate(_vec, mat, absl::MakeConstSpan(indicator)); + + EXPECT_EQ(expected.size(), (size_t)got.numel()); + for (int64_t i = 0; i < dim_out; ++i) { + EXPECT_NEAR(expected[i], got[i], 1); + } + }); +} + } // namespace squirrel::test diff --git a/experimental/squirrel/objectives.cc b/experimental/squirrel/objectives.cc index b9423d2d..baf5f08d 100644 --- a/experimental/squirrel/objectives.cc +++ b/experimental/squirrel/objectives.cc @@ -272,9 +272,9 @@ namespace { res = NdArrayRef(makeType(ftype), in.shape()); } - return DISPATCH_ALL_FIELDS(field, "cheetah.ring_cast", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using from_ring2k_t = ring2k_t; - return DISPATCH_ALL_FIELDS(ftype, "cheetah.ring_cast", [&]() { + return DISPATCH_ALL_FIELDS(ftype, [&]() { using to_ring2k_t = ring2k_t; NdArrayView _in(in); NdArrayView _res(res); @@ -383,7 +383,7 @@ spu::Value Logistic(spu::SPUContext* ctx, const spu::Value& x) { spu::Value Sigmoid(spu::SPUContext* ctx, const spu::Value& x) { namespace sk = spu::kernel; auto c05 = sk::hlo::Constant(ctx, 0.5F, x.shape()); - auto half = sk::hal::right_shift_arithmetic(ctx, x, 1); + auto half = sk::hal::right_shift_arithmetic(ctx, x, {1}); auto divisor = sk::hlo::Add(ctx, sk::hlo::Constant(ctx, 1, x.shape()), sk::hal::f_square(ctx, x)); return sk::hlo::Add(ctx, c05, diff --git a/experimental/squirrel/tree_build_worker.cc b/experimental/squirrel/tree_build_worker.cc index 30c441f9..06724354 100644 --- a/experimental/squirrel/tree_build_worker.cc +++ b/experimental/squirrel/tree_build_worker.cc @@ -230,7 +230,7 @@ void AccumulateHistogram(spu::NdArrayRef buckets_share, size_t nfeatures, // The buckets belong to the i-th feature is // `buckets[i*bucket_size:(i+1)*bucket_size]` auto field = buckets_share.eltype().as()->field(); - DISPATCH_ALL_FIELDS(field, "AccumulateHistogram", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView histogram(buckets_share); for (size_t j = 0; j < nfeatures; ++j) { size_t start = j * bucket_size; diff --git a/experimental/squirrel/utils.cc b/experimental/squirrel/utils.cc index 3f14436e..58a98d70 100644 --- a/experimental/squirrel/utils.cc +++ b/experimental/squirrel/utils.cc @@ -167,7 +167,7 @@ spu::Value MulPrivateArithWithPrivateBoolean(spu::SPUContext* ctx, &kctx, arith.data(), [&](const NdArrayRef& input, const std::shared_ptr& base_ot) { - return DISPATCH_ALL_FIELDS(ft, "ot", [&]() { + return DISPATCH_ALL_FIELDS(ft, [&]() { NdArrayRef ot_out = spu::mpc::ring_zeros(ft, input.shape()); auto inp = absl::MakeConstSpan(&input.at(0), input.numel()); auto oup = absl::MakeSpan(&ot_out.at(0), ot_out.numel()); @@ -193,7 +193,7 @@ spu::Value MulPrivateArithWithPrivateBoolean(spu::SPUContext* ctx, &kctx, boolean, [&](absl::Span input, const std::shared_ptr& base_ot) { - return DISPATCH_ALL_FIELDS(ft, "ot", [&]() { + return DISPATCH_ALL_FIELDS(ft, [&]() { NdArrayRef ot_out = spu::mpc::ring_zeros(ft, {(int64_t)input.size()}); auto oup = absl::MakeSpan(&ot_out.at(0), input.size()); base_ot->GetReceiverCOT()->RecvCAMCC(input, oup); @@ -222,7 +222,7 @@ spu::Value MulArithShareWithANDBoolShare(spu::SPUContext* ctx, std::shared_ptr base_ot) { NdArrayRef out(x.eltype(), x.shape()); - DISPATCH_ALL_FIELDS(ft, "camcc", [&]() { + DISPATCH_ALL_FIELDS(ft, [&]() { spu::NdArrayView _ashr(x); auto oup = absl::MakeSpan(&out.at(0), y.size()); std::vector corr(y.size()); diff --git a/experimental/squirrel/utils_test.cc b/experimental/squirrel/utils_test.cc index 3bdc004e..8cde97fd 100644 --- 
a/experimental/squirrel/utils_test.cc +++ b/experimental/squirrel/utils_test.cc @@ -85,7 +85,7 @@ TEST_F(UtilsTest, ReduceSum) { const double fxp = std::pow(2., rt_config.fxp_fraction_bits()); auto flatten = got.data().reshape({got.numel()}); - DISPATCH_ALL_FIELDS(field, "check", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using s2k = std::make_signed::type; NdArrayView got(flatten); for (int64_t i = 0; i < expected.numel(); ++i) { @@ -136,7 +136,7 @@ TEST_F(UtilsTest, ArgMax) { if (lctx->Rank() == 0) { auto flatten = got.data().reshape({got.numel()}); - DISPATCH_ALL_FIELDS(field, "check", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView got(flatten); for (size_t i = 0; i < expected.size(); ++i) { ASSERT_EQ(expected(i), got[i]); diff --git a/libspu/compiler/common/compilation_context.cc b/libspu/compiler/common/compilation_context.cc index ef22871c..553fe913 100644 --- a/libspu/compiler/common/compilation_context.cc +++ b/libspu/compiler/common/compilation_context.cc @@ -23,7 +23,7 @@ namespace { void SPUErrorHandler(void * /*use_data*/, const char *reason, bool /*gen_crash_diag*/) { - SPU_THROW(reason); + SPU_THROW("{}", reason); } } // namespace diff --git a/libspu/compiler/common/compilation_context.h b/libspu/compiler/common/compilation_context.h index 24ebe043..8abfed96 100644 --- a/libspu/compiler/common/compilation_context.h +++ b/libspu/compiler/common/compilation_context.h @@ -16,7 +16,6 @@ #include #include -#include #include "mlir/IR/MLIRContext.h" #include "mlir/Pass/PassManager.h" diff --git a/libspu/compiler/common/ir_printer_config.cc b/libspu/compiler/common/ir_printer_config.cc index 47dff64a..f3c15ff2 100644 --- a/libspu/compiler/common/ir_printer_config.cc +++ b/libspu/compiler/common/ir_printer_config.cc @@ -51,6 +51,7 @@ void IRPrinterConfig::printBeforeIfEnabled(Pass *pass, Operation *, if (ec.value() != 0) { spdlog::error("Open file {} failed, error = {}", file_name.c_str(), ec.message()); + return; } print_callback(f); } @@ -64,6 +65,7 @@ void IRPrinterConfig::printAfterIfEnabled(Pass *pass, Operation *, if (ec.value() != 0) { spdlog::error("Open file {} failed, error = {}", file_name.c_str(), ec.message()); + return; } print_callback(f); } diff --git a/libspu/compiler/compile.cc b/libspu/compiler/compile.cc index 38a59acb..236d6d02 100644 --- a/libspu/compiler/compile.cc +++ b/libspu/compiler/compile.cc @@ -14,8 +14,6 @@ #include "libspu/compiler/compile.h" -#include "mlir/IR/BuiltinOps.h" - #include "libspu/compiler/codegen/codegen.h" #include "libspu/compiler/common/compilation_context.h" #include "libspu/compiler/core/core.h" diff --git a/libspu/compiler/core/core.cc b/libspu/compiler/core/core.cc index 355dd68a..b7c46d8f 100644 --- a/libspu/compiler/core/core.cc +++ b/libspu/compiler/core/core.cc @@ -16,7 +16,6 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/Passes.h" @@ -95,6 +94,7 @@ void Core::buildPipeline(mlir::PassManager *pm) { } optPM.addPass(mlir::createLoopInvariantCodeMotionPass()); + optPM.addPass(mlir::spu::pphlo::createRegionAccessFixture()); optPM.addPass(mlir::createCSEPass()); if (!options.disable_deallocation_insertion()) { diff --git a/libspu/compiler/front_end/fe.cc b/libspu/compiler/front_end/fe.cc index 682f1c5b..5bcb693f 100644 --- a/libspu/compiler/front_end/fe.cc +++ b/libspu/compiler/front_end/fe.cc @@ -21,7 +21,6 @@ #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include 
"mlir/Transforms/Passes.h" -#include "spdlog/spdlog.h" #include "stablehlo/dialect/StablehloOps.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" @@ -54,6 +53,8 @@ mlir::OwningOpRef FE::doit(const CompilationSource &source) { module = mlir::parseSourceString(source.ir_txt(), ctx_->getMLIRContext()); + SPU_ENFORCE(module, "MLIR parser failure"); + // Convert stablehlo to mhlo first mlir::PassManager pm(ctx_->getMLIRContext()); pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass()); diff --git a/libspu/compiler/front_end/hlo_importer.cc b/libspu/compiler/front_end/hlo_importer.cc index 5845af57..363bd76c 100644 --- a/libspu/compiler/front_end/hlo_importer.cc +++ b/libspu/compiler/front_end/hlo_importer.cc @@ -196,12 +196,12 @@ HloImporter::parseXlaModuleFromString(const std::string &content) { auto module_config = xla::HloModule::CreateModuleConfigFromProto(hlo_module, debug_options); if (!module_config.status().ok()) { - SPU_THROW(module_config.status().message()); + SPU_THROW("{}", module_config.status().message()); } auto module = xla::HloModule::CreateFromProto(hlo_module, *module_config); if (!module.status().ok()) { - SPU_THROW(module.status().message()); + SPU_THROW("{}", module.status().message()); } xla::runHloPasses((*module).get()); @@ -214,7 +214,7 @@ HloImporter::parseXlaModuleFromString(const std::string &content) { auto status = importer.Import(**module); if (!status.ok()) { - SPU_THROW(status.message()); + SPU_THROW("{}", status.message()); } return mlir_hlo; diff --git a/libspu/compiler/tests/interpret/abs.mlir b/libspu/compiler/tests/interpret/abs.mlir index dd655bf1..49e409a8 100644 --- a/libspu/compiler/tests/interpret/abs.mlir +++ b/libspu/compiler/tests/interpret/abs.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @abs_op_test_i64_i64_p() { diff --git a/libspu/compiler/tests/interpret/add.mlir b/libspu/compiler/tests/interpret/add.mlir index a763c0ae..5074d3eb 100644 --- a/libspu/compiler/tests/interpret/add.mlir +++ b/libspu/compiler/tests/interpret/add.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @add_op_test_i8_i8_pp() { diff --git a/libspu/compiler/tests/interpret/and.mlir b/libspu/compiler/tests/interpret/and.mlir index 6e3cdfa6..fbcc2a2d 100644 --- a/libspu/compiler/tests/interpret/and.mlir +++ b/libspu/compiler/tests/interpret/and.mlir @@ -1,6 +1,7 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @and_op_test_i8_i8_pp() { diff --git a/libspu/compiler/tests/interpret/atan2.mlir 
b/libspu/compiler/tests/interpret/atan2.mlir index e1938852..52c2b122 100644 --- a/libspu/compiler/tests/interpret/atan2.mlir +++ b/libspu/compiler/tests/interpret/atan2.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @atan2_op_test_f64_f64_pp() { diff --git a/libspu/compiler/tests/interpret/broadcast.mlir b/libspu/compiler/tests/interpret/broadcast.mlir index 1d0e9a63..7e2d17f8 100644 --- a/libspu/compiler/tests/interpret/broadcast.mlir +++ b/libspu/compiler/tests/interpret/broadcast.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s func.func @broadcast_in_dim() { %operand = pphlo.constant dense<[[1], [2], [3]]> : tensor<3x1xi64> diff --git a/libspu/compiler/tests/interpret/case.mlir b/libspu/compiler/tests/interpret/case.mlir index 543fbf86..96105d4e 100644 --- a/libspu/compiler/tests/interpret/case.mlir +++ b/libspu/compiler/tests/interpret/case.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s func.func @case_negative_index_default() { %index = pphlo.constant dense<-1> : tensor diff --git a/libspu/compiler/tests/interpret/ceil.mlir b/libspu/compiler/tests/interpret/ceil.mlir index d4f74b8b..23c05538 100644 --- a/libspu/compiler/tests/interpret/ceil.mlir +++ b/libspu/compiler/tests/interpret/ceil.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @ceil_op_test_f16_f16_p() { diff --git a/libspu/compiler/tests/interpret/clamp.mlir b/libspu/compiler/tests/interpret/clamp.mlir index 69c05696..9e76acc8 100644 --- a/libspu/compiler/tests/interpret/clamp.mlir +++ b/libspu/compiler/tests/interpret/clamp.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s func.func @clamp_op_test_si64() { %min = pphlo.constant dense<[1, 5, -5]> : tensor<3xi64> diff --git a/libspu/compiler/tests/interpret/concatenate.mlir b/libspu/compiler/tests/interpret/concatenate.mlir 
index 1d0d7655..63107f22 100644 --- a/libspu/compiler/tests/interpret/concatenate.mlir +++ b/libspu/compiler/tests/interpret/concatenate.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s func.func @concatenate() { %input0 = pphlo.constant dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi64> diff --git a/libspu/compiler/tests/interpret/convert.mlir b/libspu/compiler/tests/interpret/convert.mlir index 747047ee..b3b23f3b 100644 --- a/libspu/compiler/tests/interpret/convert.mlir +++ b/libspu/compiler/tests/interpret/convert.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s func.func @convert_op_test_1() { %0 = pphlo.constant dense<[0, 1, 8, -9, 0]> : tensor<5xi32> diff --git a/libspu/compiler/tests/interpret/convolution.mlir b/libspu/compiler/tests/interpret/convolution.mlir index 294cdd36..2ce1030c 100644 --- a/libspu/compiler/tests/interpret/convolution.mlir +++ b/libspu/compiler/tests/interpret/convolution.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s func.func @main() { %0 = pphlo.constant dense<[[[[ 1.0, 2.0, 3.0, 4.0], diff --git a/libspu/compiler/tests/interpret/cosine.mlir b/libspu/compiler/tests/interpret/cosine.mlir index 72d6f249..5ee36f58 100644 --- a/libspu/compiler/tests/interpret/cosine.mlir +++ b/libspu/compiler/tests/interpret/cosine.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @cosine_op_test_f16_f16_p() { diff --git a/libspu/compiler/tests/interpret/divide.mlir b/libspu/compiler/tests/interpret/divide.mlir index 56ce53c2..a7540602 100644 --- a/libspu/compiler/tests/interpret/divide.mlir +++ b/libspu/compiler/tests/interpret/divide.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @divide_op_test_i64_i64_pp() { diff --git a/libspu/compiler/tests/interpret/dot_general.mlir b/libspu/compiler/tests/interpret/dot_general.mlir index 
29023c81..a4c506ed 100644 --- a/libspu/compiler/tests/interpret/dot_general.mlir +++ b/libspu/compiler/tests/interpret/dot_general.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s func.func @dot_general_op_test_si64() { %lhs = pphlo.constant dense<[[[1, 2], [3, 4]], diff --git a/libspu/compiler/tests/interpret/dynamic_slice.mlir b/libspu/compiler/tests/interpret/dynamic_slice.mlir index 672ace7c..d0275fd9 100644 --- a/libspu/compiler/tests/interpret/dynamic_slice.mlir +++ b/libspu/compiler/tests/interpret/dynamic_slice.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s func.func @dynamic_slice() { %operand = pphlo.constant dense<[[1, 1, 1], diff --git a/libspu/compiler/tests/interpret/dynamic_update_slice.mlir b/libspu/compiler/tests/interpret/dynamic_update_slice.mlir index a7053428..423528ac 100644 --- a/libspu/compiler/tests/interpret/dynamic_update_slice.mlir +++ b/libspu/compiler/tests/interpret/dynamic_update_slice.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s func.func @dynamic_update_slice() { %operand = pphlo.constant dense<[[1, 1, 1, 1], diff --git a/libspu/compiler/tests/interpret/equal.mlir b/libspu/compiler/tests/interpret/equal.mlir index f5364638..c8291d1a 100644 --- a/libspu/compiler/tests/interpret/equal.mlir +++ b/libspu/compiler/tests/interpret/equal.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @equal_op_test_i64_i1_pp() { diff --git a/libspu/compiler/tests/interpret/exponential.mlir b/libspu/compiler/tests/interpret/exponential.mlir index a85c5bb6..21ec3591 100644 --- a/libspu/compiler/tests/interpret/exponential.mlir +++ b/libspu/compiler/tests/interpret/exponential.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @exponential_op_test_f64_f64_p() { diff --git a/libspu/compiler/tests/interpret/exponential_minus_one.mlir 
b/libspu/compiler/tests/interpret/exponential_minus_one.mlir index 1a337a21..35131bbc 100644 --- a/libspu/compiler/tests/interpret/exponential_minus_one.mlir +++ b/libspu/compiler/tests/interpret/exponential_minus_one.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @exponential_minus_one_op_test_f64_f64_p() { diff --git a/libspu/compiler/tests/interpret/floor.mlir b/libspu/compiler/tests/interpret/floor.mlir index 97ccdb1e..602d0161 100644 --- a/libspu/compiler/tests/interpret/floor.mlir +++ b/libspu/compiler/tests/interpret/floor.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @floor_op_test_f16_f16_p() { diff --git a/libspu/compiler/tests/interpret/generate_mlir_tests.py b/libspu/compiler/tests/interpret/generate_mlir_tests.py index f9bdef86..460f67f6 100755 --- a/libspu/compiler/tests/interpret/generate_mlir_tests.py +++ b/libspu/compiler/tests/interpret/generate_mlir_tests.py @@ -65,6 +65,7 @@ def TestCase(inputs, expected, checker='expect_eq', tol=None): "or", # "popcnt", "power", + "reciprocal", "reshape", "round_afz", "rsqrt", @@ -99,6 +100,16 @@ def TestCase(inputs, expected, checker='expect_eq', tol=None): f.write( "// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s\n" ) + # FIXME: these tests are not stable for cheetah now + if test not in ["xor", "or", "and"]: + f.write( + "// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s\n" + ) + # Some test values in max and min are not supported by protocol 5. 
+ if test not in ["max", "min"]: + f.write( + "// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s\n" + ) f.write("// AUTO GENERATED, DO NOT EDIT\n\n") # Emit cases diff --git a/libspu/compiler/tests/interpret/greater.mlir b/libspu/compiler/tests/interpret/greater.mlir index 7f8e76be..92140f85 100644 --- a/libspu/compiler/tests/interpret/greater.mlir +++ b/libspu/compiler/tests/interpret/greater.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @greater_op_test_i64_i1_pp() { diff --git a/libspu/compiler/tests/interpret/greater_equal.mlir b/libspu/compiler/tests/interpret/greater_equal.mlir index 305aaff5..af3ffa7a 100644 --- a/libspu/compiler/tests/interpret/greater_equal.mlir +++ b/libspu/compiler/tests/interpret/greater_equal.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @greater_equal_op_test_i64_i1_pp() { diff --git a/libspu/compiler/tests/interpret/if.mlir b/libspu/compiler/tests/interpret/if.mlir index cc98b65d..ef73d547 100644 --- a/libspu/compiler/tests/interpret/if.mlir +++ b/libspu/compiler/tests/interpret/if.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s func.func @if_ops_true_branch() { %pred = pphlo.constant dense : tensor diff --git a/libspu/compiler/tests/interpret/iota.mlir b/libspu/compiler/tests/interpret/iota.mlir index a7ee86ed..cc71b040 100644 --- a/libspu/compiler/tests/interpret/iota.mlir +++ b/libspu/compiler/tests/interpret/iota.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s func.func @iota_op_test_si8_dim_0() { %0 = pphlo.iota dim = 0 : tensor<3x4xi8> diff --git a/libspu/compiler/tests/interpret/less.mlir b/libspu/compiler/tests/interpret/less.mlir index 58444a29..1a9d3060 100644 --- a/libspu/compiler/tests/interpret/less.mlir +++ b/libspu/compiler/tests/interpret/less.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 
--interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @less_op_test_i64_i1_pp() { diff --git a/libspu/compiler/tests/interpret/less_equal.mlir b/libspu/compiler/tests/interpret/less_equal.mlir index 9951a569..c454bcc7 100644 --- a/libspu/compiler/tests/interpret/less_equal.mlir +++ b/libspu/compiler/tests/interpret/less_equal.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @less_equal_op_test_i64_i1_pp() { diff --git a/libspu/compiler/tests/interpret/log.mlir b/libspu/compiler/tests/interpret/log.mlir index af61d86a..bf309137 100644 --- a/libspu/compiler/tests/interpret/log.mlir +++ b/libspu/compiler/tests/interpret/log.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @log_op_test_f64_f64_p() { diff --git a/libspu/compiler/tests/interpret/log_plus_one.mlir b/libspu/compiler/tests/interpret/log_plus_one.mlir index 3aec9184..99bcf9ff 100644 --- a/libspu/compiler/tests/interpret/log_plus_one.mlir +++ b/libspu/compiler/tests/interpret/log_plus_one.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @log_plus_one_op_test_f64_f64_p() { diff --git a/libspu/compiler/tests/interpret/logistic.mlir b/libspu/compiler/tests/interpret/logistic.mlir index eeac6fab..655cb82b 100644 --- a/libspu/compiler/tests/interpret/logistic.mlir +++ b/libspu/compiler/tests/interpret/logistic.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @logistic_op_test_f64_f64_p() { diff --git a/libspu/compiler/tests/interpret/maximum.mlir b/libspu/compiler/tests/interpret/maximum.mlir index b90553e7..4919c356 100644 --- a/libspu/compiler/tests/interpret/maximum.mlir +++ b/libspu/compiler/tests/interpret/maximum.mlir @@ -1,6 +1,7 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @maximum_op_test_i8_i8_pp() { diff --git 
a/libspu/compiler/tests/interpret/minimum.mlir b/libspu/compiler/tests/interpret/minimum.mlir index 74bb2b83..1853124b 100644 --- a/libspu/compiler/tests/interpret/minimum.mlir +++ b/libspu/compiler/tests/interpret/minimum.mlir @@ -1,6 +1,7 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @minimum_op_test_i8_i8_pp() { diff --git a/libspu/compiler/tests/interpret/multiply.mlir b/libspu/compiler/tests/interpret/multiply.mlir index 82b780f3..f7d415c4 100644 --- a/libspu/compiler/tests/interpret/multiply.mlir +++ b/libspu/compiler/tests/interpret/multiply.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @multiply_op_test_i8_i8_pp() { diff --git a/libspu/compiler/tests/interpret/negate.mlir b/libspu/compiler/tests/interpret/negate.mlir index 8c5e74c3..6a00d9b8 100644 --- a/libspu/compiler/tests/interpret/negate.mlir +++ b/libspu/compiler/tests/interpret/negate.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @negate_op_test_i8_i8_p() { diff --git a/libspu/compiler/tests/interpret/not.mlir b/libspu/compiler/tests/interpret/not.mlir index 0fdf44e4..721e1850 100644 --- a/libspu/compiler/tests/interpret/not.mlir +++ b/libspu/compiler/tests/interpret/not.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @not_op_test_i8_i8_p() { diff --git a/libspu/compiler/tests/interpret/not_equal.mlir b/libspu/compiler/tests/interpret/not_equal.mlir index 1bbbd5b3..cb598070 100644 --- a/libspu/compiler/tests/interpret/not_equal.mlir +++ b/libspu/compiler/tests/interpret/not_equal.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @not_equal_op_test_i64_i1_pp() { diff --git a/libspu/compiler/tests/interpret/or.mlir b/libspu/compiler/tests/interpret/or.mlir index 79836813..acd9c28f 100644 --- a/libspu/compiler/tests/interpret/or.mlir +++ 
b/libspu/compiler/tests/interpret/or.mlir @@ -1,6 +1,7 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @or_op_test_i8_i8_pp() { diff --git a/libspu/compiler/tests/interpret/pad.mlir b/libspu/compiler/tests/interpret/pad.mlir index abedc73a..521313e2 100644 --- a/libspu/compiler/tests/interpret/pad.mlir +++ b/libspu/compiler/tests/interpret/pad.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s func.func @pad() { %operand = pphlo.constant dense<[[0, 0, 0, 0], diff --git a/libspu/compiler/tests/interpret/popcnt.mlir b/libspu/compiler/tests/interpret/popcnt.mlir index e5f83152..efea9133 100644 --- a/libspu/compiler/tests/interpret/popcnt.mlir +++ b/libspu/compiler/tests/interpret/popcnt.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @popcnt_op_test_i64_i64_p() { diff --git a/libspu/compiler/tests/interpret/power.mlir b/libspu/compiler/tests/interpret/power.mlir index dd950ef1..b6bfc475 100644 --- a/libspu/compiler/tests/interpret/power.mlir +++ b/libspu/compiler/tests/interpret/power.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @power_op_test_i64_i64_pp() { @@ -58,7 +60,7 @@ func.func @power_op_test_f64_f64_pp() { %1 = pphlo.constant dense<[2.0, 2.0, 2.0, -1.0, 1.0]> : tensor<5xf64> %2 = pphlo.power %0,%1 : (tensor<5xf64>,tensor<5xf64>)->tensor<5xf64> %3 = pphlo.constant dense<[4.000000e+00, 0.000000e+00, 2.500000e+01, 0.33333333333333331, 10000.0]> : tensor<5xf64> - pphlo.custom_call @expect_almost_eq(%2, %3) { tol = 0.5 }: (tensor<5xf64>, tensor<5xf64>)->() + pphlo.custom_call @expect_almost_eq(%2, %3) { tol = 0.6 }: (tensor<5xf64>, tensor<5xf64>)->() func.return } @@ -72,6 +74,6 @@ func.func @power_op_test_f64_f64_ss() { %4 = pphlo.power %2, %3 : (tensor<5x!pphlo.secret>,tensor<5x!pphlo.secret>)->tensor<5x!pphlo.secret> %5 = pphlo.constant dense<[4.000000e+00, 0.000000e+00, 2.500000e+01, 0.33333333333333331, 10000.0]> : tensor<5xf64> %6 = pphlo.convert %4 : (tensor<5x!pphlo.secret>)->tensor<5xf64> - pphlo.custom_call @expect_almost_eq(%5, %6) { tol = 0.5 }: (tensor<5xf64>, tensor<5xf64>)->() + pphlo.custom_call @expect_almost_eq(%5, %6) { tol = 0.6 }: (tensor<5xf64>, tensor<5xf64>)->() func.return } diff --git 
a/libspu/compiler/tests/interpret/reciprocal.mlir b/libspu/compiler/tests/interpret/reciprocal.mlir new file mode 100644 index 00000000..fe1623a5 --- /dev/null +++ b/libspu/compiler/tests/interpret/reciprocal.mlir @@ -0,0 +1,26 @@ +// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s +// AUTO GENERATED, DO NOT EDIT + +func.func @reciprocal_op_test_f64_f64_p() { + %0 = pphlo.constant dense<[[1.0, -200.0], [100.0, 286991.875]]> : tensor<2x2xf64> + %1 = pphlo.reciprocal %0 : (tensor<2x2xf64>)->tensor<2x2xf64> + %2 = pphlo.constant dense<[[1.0, -0.005], [0.01, 0.0]]> : tensor<2x2xf64> + pphlo.custom_call @expect_almost_eq(%1, %2) { tol = 0.01 }: (tensor<2x2xf64>, tensor<2x2xf64>)->() + func.return +} + +// ----- + +func.func @reciprocal_op_test_f64_f64_s() { + %0 = pphlo.constant dense<[[1.0, -200.0], [100.0, 286991.875]]> : tensor<2x2xf64> + %1 = pphlo.convert %0 : (tensor<2x2xf64>)->tensor<2x2x!pphlo.secret> + %2 = pphlo.reciprocal %1 : (tensor<2x2x!pphlo.secret>)->tensor<2x2x!pphlo.secret> + %3 = pphlo.constant dense<[[1.0, -0.005], [0.01, 0.0]]> : tensor<2x2xf64> + %4 = pphlo.convert %2 : (tensor<2x2x!pphlo.secret>)->tensor<2x2xf64> + pphlo.custom_call @expect_almost_eq(%3, %4) { tol = 0.01 }: (tensor<2x2xf64>, tensor<2x2xf64>)->() + func.return +} diff --git a/libspu/compiler/tests/interpret/reduce.mlir b/libspu/compiler/tests/interpret/reduce.mlir index 3b34bf3e..efc1d40d 100644 --- a/libspu/compiler/tests/interpret/reduce.mlir +++ b/libspu/compiler/tests/interpret/reduce.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s func.func @reduce() { %input = pphlo.constant dense<[[0, 1, 2, 3, 4, 5]]> : tensor<1x6xi64> diff --git a/libspu/compiler/tests/interpret/reduce_window.mlir b/libspu/compiler/tests/interpret/reduce_window.mlir index 5385ab04..d15cd021 100644 --- a/libspu/compiler/tests/interpret/reduce_window.mlir +++ b/libspu/compiler/tests/interpret/reduce_window.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s func.func @reduce_window() { %input = pphlo.constant dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi64> diff --git a/libspu/compiler/tests/interpret/reshape.mlir b/libspu/compiler/tests/interpret/reshape.mlir index 9e483b77..0a670923 100644 --- a/libspu/compiler/tests/interpret/reshape.mlir +++ b/libspu/compiler/tests/interpret/reshape.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate 
--protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @reshape_op_test_i32_i32_p() { diff --git a/libspu/compiler/tests/interpret/reverse.mlir b/libspu/compiler/tests/interpret/reverse.mlir index 63ab9590..832262e2 100644 --- a/libspu/compiler/tests/interpret/reverse.mlir +++ b/libspu/compiler/tests/interpret/reverse.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s func.func @reverse() { %operand = pphlo.constant dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi64> diff --git a/libspu/compiler/tests/interpret/ring_cast.mlir b/libspu/compiler/tests/interpret/ring_cast.mlir index 94b53241..a6b6806c 100644 --- a/libspu/compiler/tests/interpret/ring_cast.mlir +++ b/libspu/compiler/tests/interpret/ring_cast.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s func.func @cast_1() { %c0 = pphlo.constant dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi32> diff --git a/libspu/compiler/tests/interpret/round_nearest_afz.mlir b/libspu/compiler/tests/interpret/round_nearest_afz.mlir index 6601fa1d..40e64ebe 100644 --- a/libspu/compiler/tests/interpret/round_nearest_afz.mlir +++ b/libspu/compiler/tests/interpret/round_nearest_afz.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @round_nearest_afz_op_test_f64_f64_p() { diff --git a/libspu/compiler/tests/interpret/rsqrt.mlir b/libspu/compiler/tests/interpret/rsqrt.mlir index ef6470c9..5e5e0869 100644 --- a/libspu/compiler/tests/interpret/rsqrt.mlir +++ b/libspu/compiler/tests/interpret/rsqrt.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @rsqrt_op_test_f64_f64_p() { diff --git a/libspu/compiler/tests/interpret/select.mlir b/libspu/compiler/tests/interpret/select.mlir index 33e9a9c7..c2752761 100644 --- a/libspu/compiler/tests/interpret/select.mlir +++ b/libspu/compiler/tests/interpret/select.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: 
spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s func.func @select_op_test_si64() { %pred = pphlo.constant dense<[true, false, true]> : tensor<3xi1> diff --git a/libspu/compiler/tests/interpret/select_and_scatter.mlir b/libspu/compiler/tests/interpret/select_and_scatter.mlir index 5e55dc59..c7f0057a 100644 --- a/libspu/compiler/tests/interpret/select_and_scatter.mlir +++ b/libspu/compiler/tests/interpret/select_and_scatter.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // FIXME func.func @select_and_scatter_op_test() { diff --git a/libspu/compiler/tests/interpret/shift_right_arithmetic.mlir b/libspu/compiler/tests/interpret/shift_right_arithmetic.mlir index 494666ae..ce4bedbc 100644 --- a/libspu/compiler/tests/interpret/shift_right_arithmetic.mlir +++ b/libspu/compiler/tests/interpret/shift_right_arithmetic.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @shift_right_arithmetic_op_test_i64_i64_pp() { diff --git a/libspu/compiler/tests/interpret/shift_right_logical.mlir b/libspu/compiler/tests/interpret/shift_right_logical.mlir index 253d6cc3..73d69f0a 100644 --- a/libspu/compiler/tests/interpret/shift_right_logical.mlir +++ b/libspu/compiler/tests/interpret/shift_right_logical.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @shift_right_logical_op_test_i64_i64_pp() { diff --git a/libspu/compiler/tests/interpret/sign.mlir b/libspu/compiler/tests/interpret/sign.mlir index b153a373..105c5cb2 100644 --- a/libspu/compiler/tests/interpret/sign.mlir +++ b/libspu/compiler/tests/interpret/sign.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @sign_op_test_i64_i64_p() { diff --git a/libspu/compiler/tests/interpret/sine.mlir b/libspu/compiler/tests/interpret/sine.mlir index f0b59e5c..1d62529d 100644 --- a/libspu/compiler/tests/interpret/sine.mlir +++ b/libspu/compiler/tests/interpret/sine.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // 
RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @sine_op_test_f16_f16_p() { diff --git a/libspu/compiler/tests/interpret/slice.mlir b/libspu/compiler/tests/interpret/slice.mlir index b4b8bdbe..583b977b 100644 --- a/libspu/compiler/tests/interpret/slice.mlir +++ b/libspu/compiler/tests/interpret/slice.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s func.func @slice_op() { %operand = pphlo.constant dense<[[0, 0, 1, 0, 0, 1], diff --git a/libspu/compiler/tests/interpret/sort.mlir b/libspu/compiler/tests/interpret/sort.mlir index 837ded37..2433d4ae 100644 --- a/libspu/compiler/tests/interpret/sort.mlir +++ b/libspu/compiler/tests/interpret/sort.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s func.func @sort_stable() { %input0 = pphlo.constant dense<[[1, 2, 3], [3, 2, 1]]> : tensor<2x3xi64> diff --git a/libspu/compiler/tests/interpret/sqrt.mlir b/libspu/compiler/tests/interpret/sqrt.mlir index 9248077a..59372c6b 100644 --- a/libspu/compiler/tests/interpret/sqrt.mlir +++ b/libspu/compiler/tests/interpret/sqrt.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @sqrt_op_test_f64_f64_p() { diff --git a/libspu/compiler/tests/interpret/subtract.mlir b/libspu/compiler/tests/interpret/subtract.mlir index ce8b9633..37032882 100644 --- a/libspu/compiler/tests/interpret/subtract.mlir +++ b/libspu/compiler/tests/interpret/subtract.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @subtract_op_test_i8_i8_pp() { diff --git a/libspu/compiler/tests/interpret/tanh.mlir b/libspu/compiler/tests/interpret/tanh.mlir index 413dc6d2..ab45fa2e 100644 --- a/libspu/compiler/tests/interpret/tanh.mlir +++ b/libspu/compiler/tests/interpret/tanh.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate 
--protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @tanh_op_test_f16_f16_p() { diff --git a/libspu/compiler/tests/interpret/test_json/power.json b/libspu/compiler/tests/interpret/test_json/power.json index e5e6f3cf..a95ae33c 100644 --- a/libspu/compiler/tests/interpret/test_json/power.json +++ b/libspu/compiler/tests/interpret/test_json/power.json @@ -65,7 +65,7 @@ } ], "checker": "expect_almost_eq", - "tol": 0.5 + "tol": 0.6 } ] } \ No newline at end of file diff --git a/libspu/compiler/tests/interpret/test_json/reciprocal.json b/libspu/compiler/tests/interpret/test_json/reciprocal.json new file mode 100644 index 00000000..d8e07529 --- /dev/null +++ b/libspu/compiler/tests/interpret/test_json/reciprocal.json @@ -0,0 +1,24 @@ +{ + "name": "reciprocal", + "template": "basic_unary", + "testcases": [ + { + "inputs": [ + { + "data": "[[1.0, -200.0], [100.0, 286991.875]]", + "shape": "2x2", + "dtype": "f64" + } + ], + "expected": [ + { + "data": "[[1.0, -0.005], [0.01, 0.0]]", + "shape": "2x2", + "dtype": "f64" + } + ], + "checker": "expect_almost_eq", + "tol": 0.01 + } + ] +} \ No newline at end of file diff --git a/libspu/compiler/tests/interpret/transpose.mlir b/libspu/compiler/tests/interpret/transpose.mlir index d5ce9b6e..d36883de 100644 --- a/libspu/compiler/tests/interpret/transpose.mlir +++ b/libspu/compiler/tests/interpret/transpose.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s func.func @transpose_op_test_si32() { %0 = pphlo.constant dense<[[[1,2],[3,4],[5,6]], [[7,8],[9,10],[11,12]]]> : tensor<2x3x2xi32> diff --git a/libspu/compiler/tests/interpret/while.mlir b/libspu/compiler/tests/interpret/while.mlir index 32789fae..734e199e 100644 --- a/libspu/compiler/tests/interpret/while.mlir +++ b/libspu/compiler/tests/interpret/while.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s func.func @while() { // int i = 0; diff --git a/libspu/compiler/tests/interpret/xor.mlir b/libspu/compiler/tests/interpret/xor.mlir index c2e1f1ee..f2b8fba9 100644 --- a/libspu/compiler/tests/interpret/xor.mlir +++ b/libspu/compiler/tests/interpret/xor.mlir @@ -1,6 +1,7 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @xor_op_test_i8_i8_pp() { diff --git a/libspu/compiler/tests/passes/hlo2pphlo/select_and_scatter.mlir b/libspu/compiler/tests/passes/hlo2pphlo/select_and_scatter.mlir index ac7492bc..bc21172f 100644 --- a/libspu/compiler/tests/passes/hlo2pphlo/select_and_scatter.mlir +++ b/libspu/compiler/tests/passes/hlo2pphlo/select_and_scatter.mlir @@ 
-1,7 +1,7 @@ // RUN: spu-opt --hlo-legalize-to-pphlo=input_vis_list=VIS_SECRET,VIS_PUBLIC,VIS_PUBLIC --lower-conversion-cast --split-input-file %s | FileCheck %s func.func @main(%arg0: tensor<128x5x5x32xf32>, %arg1: tensor<128x4x4x32xf32>, %arg2: tensor) -> tensor<128x5x5x32xf32> { - // CHECK: %1 = "pphlo.select_and_scatter"(%arg0, %arg1, %0) ({ + // CHECK: %1 = "pphlo.select_and_scatter"(%arg0, %arg1, %0) <{window_dimensions = array, window_strides = array}> ({ // CHECK: ^bb0(%arg3: tensor>, %arg4: tensor>): // CHECK: %2 = pphlo.greater_equal %arg3, %arg4 : (tensor>, tensor>) -> tensor> // CHECK: pphlo.return %2 : tensor> @@ -9,7 +9,7 @@ func.func @main(%arg0: tensor<128x5x5x32xf32>, %arg1: tensor<128x4x4x32xf32>, %a // CHECK: ^bb0(%arg3: tensor, %arg4: tensor>): // CHECK: %2 = pphlo.add %arg3, %arg4 : (tensor, tensor>) -> tensor> // CHECK: pphlo.return %2 : tensor> - // CHECK: }) {window_dimensions = array, window_strides = array} : (tensor<128x5x5x32x!pphlo.secret>, tensor<128x4x4x32xf32>, tensor>) -> tensor<128x5x5x32x!pphlo.secret> + // CHECK: }) : (tensor<128x5x5x32x!pphlo.secret>, tensor<128x4x4x32xf32>, tensor>) -> tensor<128x5x5x32x!pphlo.secret> %0 = "stablehlo.select_and_scatter"(%arg0, %arg1, %arg2) ({ ^bb0(%arg3: tensor, %arg4: tensor): %1 = "stablehlo.compare"(%arg3, %arg4) {comparison_direction = #stablehlo} : (tensor, tensor) -> tensor diff --git a/libspu/compiler/tests/passes/hlo2pphlo/sort_p.mlir b/libspu/compiler/tests/passes/hlo2pphlo/sort_p.mlir index 2224ee7d..facd338f 100644 --- a/libspu/compiler/tests/passes/hlo2pphlo/sort_p.mlir +++ b/libspu/compiler/tests/passes/hlo2pphlo/sort_p.mlir @@ -2,11 +2,11 @@ func.func @main(%arg0: tensor<20xi32>) -> (tensor<20xi32>) { %0 = stablehlo.iota dim = 0 : tensor<20xi32> - // CHECK: %1:2 = "pphlo.sort"(%arg0, %0) ({ + // CHECK: %1:2 = "pphlo.sort"(%arg0, %0) <{dimension = 0 : i64, is_stable = true}> ({ // CHECK: ^bb0(%arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor): // CHECK: %2 = pphlo.less %arg1, %arg2 : (tensor, tensor) -> tensor // CHECK: pphlo.return %2 : tensor - // CHECK: }) {dimension = 0 : i64, is_stable = true} : (tensor<20xi32>, tensor<20xi32>) -> (tensor<20xi32>, tensor<20xi32>) + // CHECK: }) : (tensor<20xi32>, tensor<20xi32>) -> (tensor<20xi32>, tensor<20xi32>) %1:2 = "stablehlo.sort"(%arg0, %0) ({ ^bb0(%arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor): %2 = stablehlo.compare LT, %arg1, %arg2 : (tensor, tensor) -> tensor diff --git a/libspu/compiler/tests/passes/hlo2pphlo/sort_s.mlir b/libspu/compiler/tests/passes/hlo2pphlo/sort_s.mlir index c4d4bd41..cf7ef73d 100644 --- a/libspu/compiler/tests/passes/hlo2pphlo/sort_s.mlir +++ b/libspu/compiler/tests/passes/hlo2pphlo/sort_s.mlir @@ -2,11 +2,11 @@ func.func @main(%arg0: tensor<20xi32>) -> (tensor<20xi32>) { %0 = stablehlo.iota dim = 0 : tensor<20xi32> - // CHECK: %1:2 = "pphlo.sort"(%arg0, %0) ({ + // CHECK: %1:2 = "pphlo.sort"(%arg0, %0) <{dimension = 0 : i64, is_stable = true}> ({ // CHECK: ^bb0(%arg1: tensor>, %arg2: tensor>, %arg3: tensor, %arg4: tensor): // CHECK: %2 = pphlo.less %arg1, %arg2 : (tensor>, tensor>) -> tensor> // CHECK: pphlo.return %2 : tensor> - // CHECK: }) {dimension = 0 : i64, is_stable = true} : (tensor<20x!pphlo.secret>, tensor<20xi32>) -> (tensor<20x!pphlo.secret>, tensor<20x!pphlo.secret>) + // CHECK: }) : (tensor<20x!pphlo.secret>, tensor<20xi32>) -> (tensor<20x!pphlo.secret>, tensor<20x!pphlo.secret>) %1:2 = "stablehlo.sort"(%arg0, %0) ({ ^bb0(%arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor): %2 = 
stablehlo.compare LT, %arg1, %arg2 : (tensor, tensor) -> tensor diff --git a/libspu/compiler/tests/passes/optimizations/ops_negative.mlir b/libspu/compiler/tests/passes/optimizations/ops_negative.mlir index 926a5de5..e8644b96 100644 --- a/libspu/compiler/tests/passes/optimizations/ops_negative.mlir +++ b/libspu/compiler/tests/passes/optimizations/ops_negative.mlir @@ -24,7 +24,7 @@ func.func @main() -> tensor { func.func @main() -> tensor { %0 = pphlo.constant dense<[0.000000e+00, -3.40282347E+38]> : tensor<2xf32> - // expected-error @+1 {{op broadcast_dimensions contains invalid value -6 for result with rank 1}} + // expected-error @+1 {{broadcast_dimensions contains invalid value -6 for result with rank 1}} %1 = pphlo.broadcast %0, dims = [-6] : (tensor<2xf32>) -> tensor<2xf32> %2 = pphlo.constant dense<5> : tensor pphlo.return %2 : tensor @@ -33,7 +33,7 @@ func.func @main() -> tensor { // ----- func.func @main() -> tensor { - // expected-error @+1 {{op iota dimension cannot go beyond the output rank or be negative}} + // expected-error @+1 {{iota dimension cannot go beyond the output rank}} %0 = pphlo.iota dim = 1000 : tensor<1xi32> %1 = pphlo.constant dense<5> : tensor pphlo.return %1 : tensor diff --git a/libspu/compiler/tools/spu-lsp.cc b/libspu/compiler/tools/spu-lsp.cc index 0b9ddbd6..9f59509d 100644 --- a/libspu/compiler/tools/spu-lsp.cc +++ b/libspu/compiler/tools/spu-lsp.cc @@ -25,5 +25,6 @@ int main(int argc, char **argv) { registry.insert(); mlir::func::registerInlinerExtension(registry); - return mlir::failed(mlir::MlirLspServerMain(argc, argv, registry)); + return static_cast( + mlir::failed(mlir::MlirLspServerMain(argc, argv, registry))); } diff --git a/libspu/compiler/tools/spu-translate.cc b/libspu/compiler/tools/spu-translate.cc index 423321a5..68857fa0 100644 --- a/libspu/compiler/tools/spu-translate.cc +++ b/libspu/compiler/tools/spu-translate.cc @@ -22,12 +22,10 @@ #include "mlir/Support/LogicalResult.h" #include "mlir/Tools/mlir-translate/MlirTranslateMain.h" #include "mlir/Tools/mlir-translate/Translation.h" -#include "mlir/Transforms/Passes.h" #include "stablehlo/dialect/StablehloOps.h" #include "xtensor/xio.hpp" #include "libspu/compiler/common/compilation_context.h" -#include "libspu/compiler/utils/utils.h" #include "libspu/core/prelude.h" #include "libspu/device/pphlo/pphlo_executor.h" #include "libspu/dialect/pphlo/IR/dialect.h" @@ -89,7 +87,7 @@ bool testOpHandler(::spu::SPUContext *sctx, mlir::Operation *op, auto callOp = mlir::dyn_cast(op); if (callOp.getCallTargetName() == "expect_almost_eq") { ::spu::Value runtimeLhs = inputs[0]; - ::spu::Value runtimeRhs = inputs[1]; + const ::spu::Value &runtimeRhs = inputs[1]; if (!runtimeLhs.isPublic()) { runtimeLhs = ::spu::kernel::hal::_s2p(sctx, runtimeLhs) @@ -123,7 +121,7 @@ bool testOpHandler(::spu::SPUContext *sctx, mlir::Operation *op, if (callOp.getCallTargetName() == "expect_eq") { ::spu::Value runtimeLhs = inputs[0]; - ::spu::Value runtimeRhs = inputs[1]; + const ::spu::Value &runtimeRhs = inputs[1]; if (!runtimeLhs.isPublic()) { runtimeLhs = ::spu::kernel::hal::_s2p(sctx, runtimeLhs) @@ -239,6 +237,16 @@ void evalModule(ModuleOp module) { numParties = 3; break; } + case 4: { + conf.set_protocol(::spu::CHEETAH); + numParties = 2; + break; + } + case 5: { + conf.set_protocol(::spu::SECURENN); + numParties = 3; + break; + } } SPDLOG_INFO(conf.DebugString()); @@ -278,6 +286,6 @@ TranslateFromMLIRRegistration interpretRegistration( } // namespace mlir int main(int argc, char **argv) { - return failed( - 
mlir::mlirTranslateMain(argc, argv, "SPU interpreter driver\n")); + return static_cast( + failed(mlir::mlirTranslateMain(argc, argv, "SPU interpreter driver\n"))); } diff --git a/libspu/compiler/utils/utils.h b/libspu/compiler/utils/utils.h index faeff8d9..3b06397a 100644 --- a/libspu/compiler/utils/utils.h +++ b/libspu/compiler/utils/utils.h @@ -14,6 +14,7 @@ #pragma once +#include "llvm/ADT/Twine.h" #include "mlir/Support/LogicalResult.h" namespace mlir::spu { diff --git a/libspu/core/encoding.cc b/libspu/core/encoding.cc index eb26ce9b..98a17a1a 100644 --- a/libspu/core/encoding.cc +++ b/libspu/core/encoding.cc @@ -60,8 +60,8 @@ NdArrayRef encodeToRing(const PtBufferView& bv, FieldType field, } if (pt_type == PT_F32 || pt_type == PT_F64 || pt_type == PT_F16) { - DISPATCH_FLOAT_PT_TYPES(pt_type, "_", [&]() { - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_FLOAT_PT_TYPES(pt_type, [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using Float = ScalarT; using T = std::make_signed_t; @@ -100,8 +100,8 @@ NdArrayRef encodeToRing(const PtBufferView& bv, FieldType field, return dst; } else { // handle integer & boolean - DISPATCH_INT_PT_TYPES(pt_type, "_", [&]() { - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_INT_PT_TYPES(pt_type, [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using Integer = ScalarT; SPU_ENFORCE(sizeof(ring2k_t) >= sizeof(Integer), "integer encoding failed, ring={} could not represent {}", @@ -138,8 +138,8 @@ void decodeFromRing(const NdArrayRef& src, DataType in_dtype, size_t fxp_bits, *out_pt_type = pt_type; } - DISPATCH_ALL_FIELDS(field, "field", [&]() { - DISPATCH_ALL_PT_TYPES(pt_type, "pt_type", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { + DISPATCH_ALL_PT_TYPES(pt_type, [&]() { using T = std::make_signed_t; auto _src = NdArrayView(src); diff --git a/libspu/core/pt_buffer_view.cc b/libspu/core/pt_buffer_view.cc index 50e6f891..3f7e1084 100644 --- a/libspu/core/pt_buffer_view.cc +++ b/libspu/core/pt_buffer_view.cc @@ -50,7 +50,7 @@ NdArrayRef convertToNdArray(PtBufferView bv) { } const auto type = makePtType(bv.pt_type); auto out = NdArrayRef(type, bv.shape); - return DISPATCH_ALL_PT_TYPES(bv.pt_type, "pt_type", [&]() { + return DISPATCH_ALL_PT_TYPES(bv.pt_type, [&]() { using T = ScalarT; if (bv.shape.numel() > 0) { auto* out_ptr = out.data(); diff --git a/libspu/core/trace.h b/libspu/core/trace.h index c67dbff0..d05b798f 100644 --- a/libspu/core/trace.h +++ b/libspu/core/trace.h @@ -154,6 +154,10 @@ struct ActionRecord final { size_t send_bytes_end; size_t recv_bytes_start; size_t recv_bytes_end; + size_t send_actions_start; + size_t send_actions_end; + size_t recv_actions_start; + size_t recv_actions_end; }; class ProfState final { @@ -238,6 +242,10 @@ class TraceAction final { size_t send_bytes_end_; size_t recv_bytes_start_; size_t recv_bytes_end_; + size_t send_actions_start_; + size_t send_actions_end_; + size_t recv_actions_start_; + size_t recv_actions_end_; int64_t saved_tracer_flag_; @@ -247,6 +255,8 @@ class TraceAction final { if (lctx_) { send_bytes_start_ = lctx_->GetStats()->sent_bytes.load(); recv_bytes_start_ = lctx_->GetStats()->recv_bytes.load(); + send_actions_start_ = lctx_->GetStats()->sent_actions.load(); + recv_actions_start_ = lctx_->GetStats()->recv_actions.load(); } const auto flag = flag_ & tracer_->getFlag(); if ((flag & TR_LOGB) != 0) { @@ -269,6 +279,8 @@ class TraceAction final { if (lctx_) { send_bytes_end_ = lctx_->GetStats()->sent_bytes.load(); recv_bytes_end_ = lctx_->GetStats()->recv_bytes.load(); + send_actions_end_ = 
lctx_->GetStats()->sent_actions.load(); + recv_actions_end_ = lctx_->GetStats()->recv_actions.load(); } const auto flag = flag_ & tracer_->getFlag(); if ((flag & TR_LOGE) != 0) { @@ -279,7 +291,8 @@ class TraceAction final { tracer_->getProfState()->addRecord( ActionRecord{id_, name_, std::move(detail_), flag_, start_, end_, send_bytes_start_, send_bytes_end_, recv_bytes_start_, - recv_bytes_end_}); + recv_bytes_end_, send_actions_start_, send_actions_end_, + recv_actions_start_, recv_actions_end_}); } } diff --git a/libspu/core/type_util.h b/libspu/core/type_util.h index 2ac1d294..4e70ea7e 100644 --- a/libspu/core/type_util.h +++ b/libspu/core/type_util.h @@ -97,91 +97,90 @@ std::ostream& operator<<(std::ostream& os, const DataType& dtype); // Helper macros to enumerate all py types. // NOLINTNEXTLINE: Global internal used macro. -#define __CASE_PT_TYPE(PT_TYPE, NAME, ...) \ - case (PT_TYPE): { \ - [[maybe_unused]] constexpr std::string_view _kName = NAME; \ - using ScalarT = EnumToPtType::type; \ - return __VA_ARGS__(); \ +#define __CASE_PT_TYPE(PT_TYPE, ...) \ + case (PT_TYPE): { \ + using ScalarT = EnumToPtType::type; \ + return __VA_ARGS__(); \ } -#define DISPATCH_FLOAT_PT_TYPES(PT_TYPE, NAME, ...) \ - [&] { \ - switch (PT_TYPE) { \ - __CASE_PT_TYPE(spu::PT_F16, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_F32, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_F64, NAME, __VA_ARGS__) \ - default: \ - SPU_THROW("{} not implemented for pt_type={}", #NAME, PT_TYPE); \ - } \ +#define DISPATCH_FLOAT_PT_TYPES(PT_TYPE, ...) \ + [&] { \ + switch (PT_TYPE) { \ + __CASE_PT_TYPE(spu::PT_F16, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_F32, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_F64, __VA_ARGS__) \ + default: \ + SPU_THROW("unimplemented for pt_type={}", PT_TYPE); \ + } \ }() -#define DISPATCH_UINT_PT_TYPES(PT_TYPE, NAME, ...) \ - [&] { \ - switch (PT_TYPE) { \ - __CASE_PT_TYPE(spu::PT_U8, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_U16, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_U32, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_U64, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_U128, NAME, __VA_ARGS__) \ - default: \ - SPU_THROW("{} not implemented for pt_type={}", #NAME, PT_TYPE); \ - } \ +#define DISPATCH_UINT_PT_TYPES(PT_TYPE, ...) \ + [&] { \ + switch (PT_TYPE) { \ + __CASE_PT_TYPE(spu::PT_U8, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_U16, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_U32, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_U64, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_U128, __VA_ARGS__) \ + default: \ + SPU_THROW("unimplemented for pt_type={}", PT_TYPE); \ + } \ }() -#define DISPATCH_INT_PT_TYPES(PT_TYPE, NAME, ...) \ - [&] { \ - switch (PT_TYPE) { \ - __CASE_PT_TYPE(spu::PT_I1, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_I8, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_U8, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_I16, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_U16, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_I32, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_U32, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_I64, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_U64, NAME, __VA_ARGS__) \ - default: \ - SPU_THROW("{} not implemented for pt_type={}", #NAME, PT_TYPE); \ - } \ +#define DISPATCH_INT_PT_TYPES(PT_TYPE, ...) 
\ + [&] { \ + switch (PT_TYPE) { \ + __CASE_PT_TYPE(spu::PT_I1, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_I8, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_U8, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_I16, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_U16, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_I32, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_U32, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_I64, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_U64, __VA_ARGS__) \ + default: \ + SPU_THROW("unimplemented for pt_type={}", PT_TYPE); \ + } \ }() -#define DISPATCH_ALL_PT_TYPES(PT_TYPE, NAME, ...) \ - [&] { \ - switch (PT_TYPE) { \ - __CASE_PT_TYPE(spu::PT_I1, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_I8, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_U8, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_I16, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_U16, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_I32, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_U32, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_I64, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_U64, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_F16, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_F32, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_F64, NAME, __VA_ARGS__) \ - default: \ - SPU_THROW("{} not implemented for pt_type={}", #NAME, PT_TYPE); \ - } \ +#define DISPATCH_ALL_PT_TYPES(PT_TYPE, ...) \ + [&] { \ + switch (PT_TYPE) { \ + __CASE_PT_TYPE(spu::PT_I1, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_I8, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_U8, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_I16, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_U16, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_I32, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_U32, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_I64, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_U64, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_F16, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_F32, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_F64, __VA_ARGS__) \ + default: \ + SPU_THROW("unimplemented for pt_type={}", PT_TYPE); \ + } \ }() -#define DISPATCH_ALL_NONE_BOOL_PT_TYPES(PT_TYPE, NAME, ...) \ - [&] { \ - switch (PT_TYPE) { \ - __CASE_PT_TYPE(spu::PT_I8, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_U8, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_I16, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_U16, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_I32, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_U32, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_I64, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_U64, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_F16, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_F32, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_F64, NAME, __VA_ARGS__) \ - default: \ - SPU_THROW("{} not implemented for pt_type={}", #NAME, PT_TYPE); \ - } \ +#define DISPATCH_ALL_NONE_BOOL_PT_TYPES(PT_TYPE, ...) 
\ + [&] { \ + switch (PT_TYPE) { \ + __CASE_PT_TYPE(spu::PT_I8, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_U8, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_I16, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_U16, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_I32, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_U32, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_I64, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_U64, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_F16, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_F32, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_F64, __VA_ARGS__) \ + default: \ + SPU_THROW("unimplemented for pt_type={}", PT_TYPE); \ + } \ }() std::ostream& operator<<(std::ostream& os, const PtType& pt_type); @@ -241,24 +240,23 @@ inline size_t SizeOf(FieldType field) { return SizeOf(GetStorageType(field)); } // Helper macros to enumerate all fields // NOLINTNEXTLINE: Global internal used macro. -#define __CASE_FIELD(FIELD, NAME, ...) \ +#define __CASE_FIELD(FIELD, ...) \ case (FIELD): { \ /* inject `_kField` & `_kName` for the continuation call */ \ [[maybe_unused]] constexpr spu::FieldType _kField = FIELD; \ - [[maybe_unused]] constexpr std::string_view _kName = NAME; \ using ring2k_t [[maybe_unused]] = Ring2kTrait<_kField>::scalar_t; \ return __VA_ARGS__(); \ } -#define DISPATCH_ALL_FIELDS(FIELD, NAME, ...) \ - [&] { \ - switch (FIELD) { \ - __CASE_FIELD(spu::FieldType::FM32, NAME, __VA_ARGS__) \ - __CASE_FIELD(spu::FieldType::FM64, NAME, __VA_ARGS__) \ - __CASE_FIELD(spu::FieldType::FM128, NAME, __VA_ARGS__) \ - default: \ - SPU_THROW("{} not implemented for field={}", #NAME, FIELD); \ - } \ +#define DISPATCH_ALL_FIELDS(FIELD, ...) \ + [&] { \ + switch (FIELD) { \ + __CASE_FIELD(spu::FieldType::FM32, __VA_ARGS__) \ + __CASE_FIELD(spu::FieldType::FM64, __VA_ARGS__) \ + __CASE_FIELD(spu::FieldType::FM128, __VA_ARGS__) \ + default: \ + SPU_THROW("unimplemented for field={}", FIELD); \ + } \ }() ////////////////////////////////////////////////////////////// diff --git a/libspu/device/api.cc b/libspu/device/api.cc index 5535ad85..bae9b462 100644 --- a/libspu/device/api.cc +++ b/libspu/device/api.cc @@ -71,12 +71,14 @@ struct CommunicationStats { size_t send_bytes = 0; size_t recv_bytes = 0; size_t send_actions = 0; + size_t recv_actions = 0; void reset(const std::shared_ptr &lctx) { if (!lctx) { return; } send_actions = lctx->GetStats()->sent_actions; + recv_actions = lctx->GetStats()->recv_actions; send_bytes = lctx->GetStats()->sent_bytes; recv_bytes = lctx->GetStats()->recv_bytes; } @@ -88,6 +90,7 @@ struct CommunicationStats { send_bytes = lctx->GetStats()->sent_bytes - send_bytes; recv_bytes = lctx->GetStats()->recv_bytes - recv_bytes; send_actions = lctx->GetStats()->sent_actions - send_actions; + recv_actions = lctx->GetStats()->recv_actions - recv_actions; } }; @@ -108,6 +111,10 @@ struct ActionStats { size_t send_bytes = 0; // total recv bytes. size_t recv_bytes = 0; + // total send actions. + size_t send_actions = 0; + // total recv actions. 
+ size_t recv_actions = 0; inline double getTotalTimeInSecond() const { return std::chrono::duration_cast>(total_time) @@ -183,6 +190,8 @@ void printProfilingData(spu::SPUContext *sctx, const std::string &name, std::chrono::duration_cast(rec.end - rec.start); stat.send_bytes += (rec.send_bytes_end - rec.send_bytes_start); stat.recv_bytes += (rec.recv_bytes_end - rec.recv_bytes_start); + stat.send_actions += (rec.send_actions_end - rec.send_actions_start); + stat.recv_actions += (rec.recv_actions_end - rec.recv_actions_start); } static std::map kModules = { @@ -213,23 +222,25 @@ void printProfilingData(spu::SPUContext *sctx, const std::string &name, const auto &stat = stats.find(key)->second; SPDLOG_INFO( "- {}, executed {} times, duration {}s, send bytes {} recv " - "bytes {}", + "bytes {}, send actions {}, recv actions {}", key.name, stat.count, stat.getTotalTimeInSecond(), stat.send_bytes, - stat.recv_bytes); + stat.recv_bytes, stat.send_actions, stat.recv_actions); } } } // print link statistics SPDLOG_INFO( - "Link details: total send bytes {}, recv bytes {}, send actions {}", - comm_stats.send_bytes, comm_stats.recv_bytes, comm_stats.send_actions); + "Link details: total send bytes {}, recv bytes {}, send actions {}, recv " + "actions {}", + comm_stats.send_bytes, comm_stats.recv_bytes, comm_stats.send_actions, + comm_stats.recv_actions); } void SPUErrorHandler(void *use_data, const char *reason, bool gen_crash_diag) { (void)use_data; (void)gen_crash_diag; - SPU_THROW(reason); + SPU_THROW("{}", reason); } std::mutex ErrorHandlerMutex; diff --git a/libspu/device/pphlo/pphlo_executor.cc b/libspu/device/pphlo/pphlo_executor.cc index d20d54f9..af8a04b8 100644 --- a/libspu/device/pphlo/pphlo_executor.cc +++ b/libspu/device/pphlo/pphlo_executor.cc @@ -720,6 +720,17 @@ void execute(OpExecutor *, SPUContext *sctx, SymbolScope *sscope, addValue(sscope, op.getOutput(), std::move(iota_ret), opts); } +void execute(OpExecutor *, SPUContext *sctx, SymbolScope *sscope, + mlir::spu::pphlo::BroadcastShapeAsOp &op, + const ExecutionOptions &opts) { + // Start indices + const auto &lhs = lookupValue(sscope, op.getLhs(), opts); + const auto &rhs = lookupValue(sscope, op.getRhs(), opts); + + addValue(sscope, op.getResult(), + kernel::hlo::Broadcast(sctx, lhs, rhs.shape(), {}), opts); +} + void execute(OpExecutor *, SPUContext *sctx, SymbolScope *sscope, mlir::spu::pphlo::RemOp &op, const ExecutionOptions &opts) { // FIXME: When hal has a remainder, use that diff --git a/libspu/device/pphlo/pphlo_verifier.h b/libspu/device/pphlo/pphlo_verifier.h index 428883c8..478a3874 100644 --- a/libspu/device/pphlo/pphlo_verifier.h +++ b/libspu/device/pphlo/pphlo_verifier.h @@ -149,6 +149,7 @@ class PPHloVerifier { NO_VERIFY_DEFN(ImagOp) NO_VERIFY_DEFN(ComplexOp) NO_VERIFY_DEFN(SimpleSortOp) + NO_VERIFY_DEFN(BroadcastShapeAsOp) #undef NO_VERIFY_DEFN }; diff --git a/libspu/dialect/pphlo/IR/dialect.td b/libspu/dialect/pphlo/IR/dialect.td index de2c648c..7d47197c 100644 --- a/libspu/dialect/pphlo/IR/dialect.td +++ b/libspu/dialect/pphlo/IR/dialect.td @@ -32,15 +32,14 @@ def PPHlo_Dialect : Dialect { string summary = "Privacy-Preserving HLO(PPHLO) dialect"; string description = [{ PPHLO represents a high level abstraction for language use by SPU. - It implements a subset of mlir mhlo ops with it's own privacy-preserving focused type system. + It implements a subset of mlir stablehlo ops with it's own privacy-preserving focused type system. 
- Learn more about mlir hlo at https://github.com/tensorflow/mlir-hlo + Learn more about mlir stablehlo at https://github.com/openxla/stablehlo }]; let name = "pphlo"; let cppNamespace = "::mlir::spu::pphlo"; let useDefaultAttributePrinterParser = 0; let useDefaultTypePrinterParser = 0; - let usePropertiesForAttributes = 0; let hasConstantMaterializer = 1; let extraClassDeclaration = [{ Attribute parseAttribute(DialectAsmParser & parser, Type type) diff --git a/libspu/dialect/pphlo/IR/ops.cc b/libspu/dialect/pphlo/IR/ops.cc index 564d44e1..4d199e93 100644 --- a/libspu/dialect/pphlo/IR/ops.cc +++ b/libspu/dialect/pphlo/IR/ops.cc @@ -18,21 +18,12 @@ #include "llvm/Support/FormatVariadic.h" #include "mlir/IR/Builders.h" #include "mlir/IR/TypeUtilities.h" +#include "stablehlo/dialect/TypeInference.h" #include "libspu/dialect/pphlo/IR/ops.h.inc" namespace mlir::spu::pphlo { -namespace { - -// Checks if the vector `nums` has duplicates. -bool hasDuplicates(const ArrayRef nums) { - llvm::SmallDenseSet set(nums.begin(), nums.end()); - return set.size() != nums.size(); -} - -} // namespace - template static LogicalResult Verify(T /*op*/) { return success(); @@ -386,75 +377,12 @@ LogicalResult ConcatenateOp::verify() { } LogicalResult BroadcastOp::verify() { - auto operandType = mlir::dyn_cast(getOperand().getType()); - - auto operandRank = operandType.getRank(); - - if (getBroadcastDimensions().empty()) { - if (operandRank == 0) { - return success(); - } - return emitOpError( - llvm::formatv("broadcast_dimensions is absent, but required because " - "operand has non-zero rank ({0})", - operandRank)); - } - - auto dimensionsSize = getBroadcastDimensions().size(); - if (static_cast(dimensionsSize) != operandRank) { - return emitOpError(llvm::formatv( - "broadcast_dimensions size ({0}) does not match operand rank ({1})", - dimensionsSize, operandRank)); - } - - auto dimensions = getBroadcastDimensions(); - if (hasDuplicates(dimensions)) { - return emitOpError("broadcast_dimensions should not have duplicates"); - } - - auto resultType = mlir::dyn_cast(getResult().getType()); - auto resultRank = resultType.getRank(); - - for (size_t i = 0; i != dimensionsSize; ++i) { - auto dimIndex = dimensions[i]; - if ((dimIndex >= resultRank) || (dimIndex < 0)) { - return emitOpError( - llvm::formatv("broadcast_dimensions contains invalid value {0} for " - "result with rank {1}", - dimIndex, resultRank)); - } - - if (!operandType.isDynamicDim(i)) { - auto dimSize = operandType.getDimSize(i); - auto resultDimSize = resultType.getDimSize(dimIndex); - if (dimSize != 1 && dimSize != resultDimSize) { - return emitOpError( - llvm::formatv("size of operand dimension {0} ({1}) is not equal to " - "1 or size of result dimension {2} ({3})", - i, dimSize, dimIndex, resultDimSize)); - } - } - } - - return success(); + return hlo::verifyBroadcastInDimOp(getLoc(), getOperand(), + getBroadcastDimensions(), getResult()); } LogicalResult IotaOp::verify() { - auto shape = mlir::dyn_cast(getType()); - if (!shape.hasRank()) { - return success(); - } - - if (shape.getRank() == 0) { - return emitOpError() << "does not support scalars."; - } - - auto iotaDimension = static_cast(this->getIotaDimension()); - if (iotaDimension >= shape.getRank() || iotaDimension < 0) { - return emitOpError() - << "iota dimension cannot go beyond the output rank or be negative."; - } - return success(); + return hlo::verifyIotaOp(getLoc(), getIotaDimension(), getResult()); } LogicalResult SliceOp::verify() { diff --git a/libspu/dialect/pphlo/IR/ops.td 
b/libspu/dialect/pphlo/IR/ops.td index b9f6b0c2..205dcf75 100644 --- a/libspu/dialect/pphlo/IR/ops.td +++ b/libspu/dialect/pphlo/IR/ops.td @@ -619,6 +619,12 @@ def PPHLO_BroadcastOp }]; } + +def PPHLO_BroadcastShapeAsOp: PPHLO_BinaryElementwiseOp<"broadcast_as", [Pure, + SameOperandsAndResultShape], PPHLO_Tensor> { + let summary = "BroadcastShapeAs operator"; +} + def PPHLO_ClampOp : PPHLO_Op<"clamp", [Pure, SameOperandsAndResultShape]> { let summary = "Clamp operator"; diff --git a/libspu/dialect/pphlo/IR/print_parse.cc b/libspu/dialect/pphlo/IR/print_parse.cc index 8716e011..6f693e30 100644 --- a/libspu/dialect/pphlo/IR/print_parse.cc +++ b/libspu/dialect/pphlo/IR/print_parse.cc @@ -105,9 +105,19 @@ ParseResult ConstantOp::parse(OpAsmParser& parser, OperationState& result) { if (parser.parseRParen()) { return failure(); } + // Parse optional properties + if (succeeded(parser.parseOptionalLess()) && + (failed(parser.parseAttribute(result.propertiesAttr)) || + failed(parser.parseGreater()))) { + return failure(); + } + + // Parse optional attributes if (parser.parseOptionalAttrDict(result.attributes)) { return failure(); } + + // Parse type signature if (parser.parseColon() || parser.parseLParen() || parser.parseRParen() || parser.parseArrow()) { return failure(); diff --git a/libspu/dialect/pphlo/IR/type_inference.cc b/libspu/dialect/pphlo/IR/type_inference.cc index 4e3a9464..6c95aac3 100644 --- a/libspu/dialect/pphlo/IR/type_inference.cc +++ b/libspu/dialect/pphlo/IR/type_inference.cc @@ -285,7 +285,7 @@ LogicalResult PadOp::inferReturnTypes( ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions, ::llvm::SmallVectorImpl<::mlir::Type>& inferredReturnTypes) { - PadOp::Adaptor adaptor(operands, attributes, {}, regions); + PadOp::Adaptor adaptor(operands, attributes, properties, regions); return hlo::inferPadOp(location, adaptor.getOperand().getType(), adaptor.getPaddingValue().getType(), adaptor.getEdgePaddingLow(), @@ -295,27 +295,27 @@ LogicalResult PadOp::inferReturnTypes( LogicalResult ConcatenateOp::inferReturnTypes( MLIRContext*, std::optional location, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties, RegionRange regions, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl& inferred_return_types) { - ConcatenateOp::Adaptor adaptor(operands, attributes, {}, regions); + ConcatenateOp::Adaptor adaptor(operands, attributes, properties, regions); return hlo::inferConcatenateOp(location, adaptor.getInputs().getTypes(), adaptor.getDimension(), inferred_return_types); } LogicalResult TransposeOp::inferReturnTypes( MLIRContext*, std::optional location, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties, RegionRange regions, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl& inferred_return_types) { - TransposeOp::Adaptor adaptor(operands, attributes, {}, regions); + TransposeOp::Adaptor adaptor(operands, attributes, properties, regions); return hlo::inferTransposeOp(location, adaptor.getOperand(), adaptor.getPermutation(), inferred_return_types); } LogicalResult SliceOp::inferReturnTypes( MLIRContext*, std::optional location, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties, RegionRange regions, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl& inferred_return_types) { - SliceOp::Adaptor adaptor(operands, attributes, {}, regions); + 
SliceOp::Adaptor adaptor(operands, attributes, properties, regions); return hlo::inferSliceOp(location, adaptor.getOperand().getType(), adaptor.getStartIndices(), adaptor.getLimitIndices(), adaptor.getStrides(), inferred_return_types); @@ -375,9 +375,9 @@ LogicalResult inferDynamicSliceOp(std::optional location, LogicalResult DynamicSliceOp::inferReturnTypes( MLIRContext*, std::optional location, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties, RegionRange regions, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl& inferredReturnTypes) { - DynamicSliceOp::Adaptor adaptor(operands, attributes, {}, regions); + DynamicSliceOp::Adaptor adaptor(operands, attributes, properties, regions); return inferDynamicSliceOp(location, adaptor.getOperand().getType(), adaptor.getStartIndices().getTypes(), adaptor.getSliceSizes(), inferredReturnTypes); @@ -427,9 +427,10 @@ LogicalResult inferDynamicUpdateSliceOp( LogicalResult DynamicUpdateSliceOp::inferReturnTypes( MLIRContext*, std::optional location, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties, RegionRange regions, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl& inferredReturnTypes) { - DynamicUpdateSliceOp::Adaptor adaptor(operands, attributes, {}, regions); + DynamicUpdateSliceOp::Adaptor adaptor(operands, attributes, properties, + regions); return inferDynamicUpdateSliceOp( location, adaptor.getOperand(), adaptor.getUpdate(), diff --git a/libspu/dialect/pphlo/transforms/hlo_legalize_to_pphlo.cc b/libspu/dialect/pphlo/transforms/hlo_legalize_to_pphlo.cc index e17f85e5..781e0940 100644 --- a/libspu/dialect/pphlo/transforms/hlo_legalize_to_pphlo.cc +++ b/libspu/dialect/pphlo/transforms/hlo_legalize_to_pphlo.cc @@ -133,24 +133,20 @@ class FuncOpConverter : public OpConversionPattern<::mlir::func::FuncOp> { auto ®ion = op.getBody(); // Convert non-entry blocks - SmallVector conversions; - for (Block &block : llvm::drop_begin(region, 1)) { - conversions.emplace_back(block.getNumArguments()); - TypeConverter::SignatureConversion &back = conversions.back(); + for (Block &block : + llvm::make_early_inc_range(llvm::drop_begin(region, 1))) { + TypeConverter::SignatureConversion conversion( + /*numOrigInputs=*/block.getNumArguments()); for (BlockArgument blockArgument : block.getArguments()) { auto idx = blockArgument.getArgNumber(); auto vis_v = vis_.getValueVisibility(blockArgument); auto convertedType = tools_.getType( typeConverter->convertType(blockArgument.getType()), vis_v); - back.addInputs(idx, convertedType); + conversion.addInputs(idx, convertedType); } - } - if (failed(rewriter.convertNonEntryRegionTypes(®ion, *typeConverter, - conversions))) { - rewriter.cancelOpModification(op); - return failure(); + rewriter.applySignatureConversion(&block, conversion, getTypeConverter()); } // Convert function arguments using the provided TypeConverter. 
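The FuncOpConverter hunk above drops the batched convertNonEntryRegionTypes call in favour of converting each non-entry block eagerly. A minimal sketch of that idiom follows, with SPU's visibility lookup omitted; applySignatureConversion may replace the block it converts, which is why the loop iterates with make_early_inc_range.

```cpp
// Sketch only: `region`, `rewriter` and the type converter come from the
// surrounding OpConversionPattern, as in the patch. applySignatureConversion
// may materialize a new block and erase the old one, so the range must
// tolerate the current element disappearing.
for (Block &block : llvm::make_early_inc_range(llvm::drop_begin(region, 1))) {
  TypeConverter::SignatureConversion conversion(block.getNumArguments());
  for (BlockArgument arg : block.getArguments()) {
    conversion.addInputs(arg.getArgNumber(),
                         typeConverter->convertType(arg.getType()));
  }
  rewriter.applySignatureConversion(&block, conversion, getTypeConverter());
}
```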
diff --git a/libspu/dialect/pphlo/transforms/inline_secret_control_flow.cc b/libspu/dialect/pphlo/transforms/inline_secret_control_flow.cc index 981a8c41..9f801875 100644 --- a/libspu/dialect/pphlo/transforms/inline_secret_control_flow.cc +++ b/libspu/dialect/pphlo/transforms/inline_secret_control_flow.cc @@ -36,9 +36,8 @@ class CaseConverter : public OpRewritePattern { if (target_type.getNumElements() == in_type.getNumElements()) { return rewriter.create(loc, broadcasted_mask_type, in); } else { - return rewriter.create( - loc, broadcasted_mask_type, in, - llvm::SmallVector(target_type.getRank(), 0)); + return rewriter.create(loc, broadcasted_mask_type, in, + llvm::SmallVector{0}); } } diff --git a/libspu/dialect/pphlo/transforms/passes.h b/libspu/dialect/pphlo/transforms/passes.h index b70ce1ab..9bcc2660 100644 --- a/libspu/dialect/pphlo/transforms/passes.h +++ b/libspu/dialect/pphlo/transforms/passes.h @@ -82,6 +82,9 @@ std::unique_ptr> createInlineSecretControlFlow(); // Convert signbit pattern to SignOp std::unique_ptr> createRewriteSignbitPatterns(); +// Fix region access shape mismatch +std::unique_ptr> createRegionAccessFixture(); + } // namespace spu::pphlo } // namespace mlir diff --git a/libspu/dialect/pphlo/transforms/passes.td b/libspu/dialect/pphlo/transforms/passes.td index 4fc28fdc..a13c70b2 100644 --- a/libspu/dialect/pphlo/transforms/passes.td +++ b/libspu/dialect/pphlo/transforms/passes.td @@ -122,3 +122,9 @@ def InlineSecretControlFlow: Pass<"inline-secret-control-flow", "func::FuncOp"> let constructor = "createInlineSecretControlFlow()"; let dependentDialects = ["pphlo::PPHloDialect"]; } + +def RegionAccessFixture: Pass<"region-access-fixture", "func::FuncOp"> { + let summary = "Fix region access mismatched shape"; + let constructor = "createRegionAccessFixture()"; + let dependentDialects = ["pphlo::PPHloDialect"]; +} \ No newline at end of file diff --git a/libspu/dialect/pphlo/transforms/region_access_fixture.cc b/libspu/dialect/pphlo/transforms/region_access_fixture.cc new file mode 100644 index 00000000..3bd9d83a --- /dev/null +++ b/libspu/dialect/pphlo/transforms/region_access_fixture.cc @@ -0,0 +1,141 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include + +#include "mlir/Pass/Pass.h" + +#include "libspu/dialect/pphlo/IR/ops.h" +#include "libspu/dialect/pphlo/transforms/pass_details.h" + +namespace mlir::spu::pphlo { + +namespace { + +struct Deallocator { + public: + LogicalResult transformOp(Operation *op) { + for (auto &r : op->getRegions()) { + if (failed(transformRegion(r))) { + return failure(); + } + } + + const auto &operands = op->getOperands(); + + if (op->getNumOperands() < 2 || + !op->hasTrait<::mlir::OpTrait::Elementwise>() || + std::all_of(operands.begin(), operands.end(), [](const auto &operand) { + return operand.template getDefiningOp(); + })) { + return success(); + } + + auto *op_region = op->getParentRegion(); + + Value base_val; + llvm::SmallVector values_to_update; + + OpBuilder builder(op->getContext()); + builder.setInsertionPoint(op); + + for (const auto &[idx, operand] : llvm::enumerate(operands)) { + // Get defining region + auto *defining_op = operand.getDefiningOp(); + + Region *defining_region = nullptr; + + if (defining_op != nullptr) { + defining_region = defining_op->getParentRegion(); + } + + if (defining_op == nullptr || defining_region == op_region) { + // BlockArg or op defined in current region can be a base val + base_val = operand; + continue; + } + + if (defining_region != op_region) { + // This op is accessing a variable out of op's region. + // Insert a broadcast as to fix runtime shape mismatch during simd + // region execution + values_to_update.emplace_back(idx); + } + } + + if (!base_val) { + return values_to_update.empty() + ? failure() // same region however failed to pick base value + : success(); // can't pick base value since multi-level + // nesting + } + + for (const auto &idx : values_to_update) { + auto op_to_broadcast = op->getOperand(idx); + auto b = builder.create( + op->getLoc(), op_to_broadcast.getType(), op_to_broadcast, base_val); + op->setOperand(idx, b); + } + + return success(); + } + + LogicalResult transformBlock(Block &block) { + for (auto &op : llvm::make_early_inc_range(block.without_terminator())) { + auto opResult = transformOp(&op); + if (failed(opResult)) { + return failure(); + } + } + return success(); + } + + LogicalResult transformRegion(Region &r) { + for (auto &b : r.getBlocks()) { + if (failed(transformBlock(b))) { + return failure(); + } + } + return success(); + } + + LogicalResult transformFuncOp(func::FuncOp op) { + if (op->getNumRegions() == 0) { + return success(); + } + + // Transform function body. 
+ if (failed(transformRegion(op.getBody()))) { + return failure(); + } + + return success(); + } +}; + +struct RegionAccessFixture + : public RegionAccessFixtureBase { + void runOnOperation() override { + if (failed(Deallocator().transformFuncOp(getOperation()))) { + signalPassFailure(); + } + } +}; +} // namespace + +std::unique_ptr> createRegionAccessFixture() { + return std::make_unique(); +} + +} // namespace mlir::spu::pphlo diff --git a/libspu/kernel/hal/constants.cc b/libspu/kernel/hal/constants.cc index 1b44cbd0..8659c5e3 100644 --- a/libspu/kernel/hal/constants.cc +++ b/libspu/kernel/hal/constants.cc @@ -103,7 +103,7 @@ spu::Value zeros(SPUContext* ctx, DataType dtype, const Shape& shape) { } Value iota(SPUContext* ctx, DataType dtype, int64_t numel) { - return DISPATCH_ALL_NONE_BOOL_PT_TYPES(getDecodeType(dtype), "iota", [&]() { + return DISPATCH_ALL_NONE_BOOL_PT_TYPES(getDecodeType(dtype), [&]() { std::vector arr(numel); std::iota(arr.begin(), arr.end(), 0); return constant(ctx, arr, dtype, {numel}); diff --git a/libspu/kernel/hal/fxp_approx.cc b/libspu/kernel/hal/fxp_approx.cc index 5cda83df..0f9864fd 100644 --- a/libspu/kernel/hal/fxp_approx.cc +++ b/libspu/kernel/hal/fxp_approx.cc @@ -69,12 +69,12 @@ Value log_minmax(SPUContext* ctx, const Value& x) { // get most significant non-zero bit of x // we avoid direct using detail::highestOneBit for saving one _prefix_or - auto pre_x1 = _rshift(ctx, pre_x, 1); + auto pre_x1 = _rshift(ctx, pre_x, {1}); auto msb = _xor(ctx, pre_x, pre_x1); // let x = x_norm * factor, where x in [1.0, 2.0) auto factor = _bitrev(ctx, msb, 0, 2 * num_fxp_bits + 1).setDtype(x.dtype()); - detail::hintNumberOfBits(factor, 2 * num_fxp_bits + 1); + factor = maskNumberOfBits(ctx, factor, 2 * num_fxp_bits + 1); auto norm = f_mul(ctx, x, factor); // log(x) = log(x_norm * factor) @@ -83,7 +83,7 @@ Value log_minmax(SPUContext* ctx, const Value& x) { auto log_norm = log_minmax_normalized(ctx, norm); auto log2_e = _lshift(ctx, _sub(ctx, k, _constant(ctx, num_fxp_bits + 1, x.shape())), - num_fxp_bits) + {static_cast(num_fxp_bits)}) .setDtype(x.dtype()); auto k_log2 = constant(ctx, std::log(2), x.dtype(), x.shape()); auto log_e = f_mul(ctx, log2_e, k_log2); @@ -145,7 +145,7 @@ Value log2_pade(SPUContext* ctx, const Value& x) { // let x = x_norm * factor, where x in [0.5, 1.0) auto msb = detail::highestOneBit(ctx, x); auto factor = _bitrev(ctx, msb, 0, 2 * num_fxp_bits).setDtype(x.dtype()); - detail::hintNumberOfBits(factor, 2 * num_fxp_bits); + factor = maskNumberOfBits(ctx, factor, 2 * num_fxp_bits); auto norm = f_mul(ctx, x, factor); // log2(x) = log2(x_norm * factor) @@ -154,7 +154,7 @@ Value log2_pade(SPUContext* ctx, const Value& x) { return _add( ctx, log2_pade_normalized(ctx, norm), _lshift(ctx, _sub(ctx, k, _constant(ctx, num_fxp_bits, x.shape())), - num_fxp_bits)) + {static_cast(num_fxp_bits)})) .setDtype(x.dtype()); } @@ -260,15 +260,18 @@ Value exp2_pade(SPUContext* ctx, const Value& x) { const size_t bit_width = SizeOf(ctx->getField()) * 8; const auto x_bshare = _prefer_b(ctx, x); - const auto x_msb = _rshift(ctx, x_bshare, bit_width - 1); - auto x_integer = _rshift(ctx, x_bshare, fbits); + const auto x_msb = + _rshift(ctx, x_bshare, {static_cast(bit_width - 1)}); + auto x_integer = _rshift(ctx, x_bshare, {static_cast(fbits)}); auto x_fraction = - _sub(ctx, x, _lshift(ctx, x_integer, fbits)).setDtype(x.dtype()); + _sub(ctx, x, _lshift(ctx, x_integer, {static_cast(fbits)})) + .setDtype(x.dtype()); auto ret = exp2_pade_normalized(ctx, x_fraction); for 
(size_t idx = 0; idx < int_bits; idx++) { - auto a = _and(ctx, _rshift(ctx, x_integer, idx), k1); - detail::hintNumberOfBits(a, 1); + auto a = + _and(ctx, _rshift(ctx, x_integer, {static_cast(idx)}), k1); + a = detail::maskNumberOfBits(ctx, a, 1); a = _prefer_a(ctx, a); const auto K = 1U << std::min(1UL << idx, bit_width - 2); ret = _mul(ctx, ret, @@ -543,7 +546,7 @@ static Value rsqrt_init_guess(SPUContext* ctx, const Value& x, const Value& z) { // let u in [0.25, 0.5) auto z_rev = _bitrev(ctx, z, 0, 2 * f); - detail::hintNumberOfBits(z_rev, 2 * f); + z_rev = detail::maskNumberOfBits(ctx, z_rev, 2 * f); auto u = _trunc(ctx, _mul(ctx, x, z_rev)).setDtype(x.dtype()); @@ -583,17 +586,17 @@ static Value rsqrt_comp(SPUContext* ctx, const Value& x, const Value& z) { auto lo_mask = _constant(ctx, (static_cast(1) << (k / 2)) - 1, x.shape()); auto z_even = _and(ctx, z_sep, lo_mask); - auto z_odd = _and(ctx, _rshift(ctx, z_sep, k / 2), lo_mask); + auto z_odd = + _and(ctx, _rshift(ctx, z_sep, {static_cast(k / 2)}), lo_mask); // a[i] = z[2*i] ^ z[2*i+1] a = _xor(ctx, z_odd, z_even); // b ^= z[2*i] b = _bit_parity(ctx, z_even, k / 2); - detail::hintNumberOfBits(b, 1); } auto a_rev = _bitrev(ctx, a, 0, (f / 2) * 2); - detail::hintNumberOfBits(a_rev, (f / 2) * 2); + a_rev = detail::maskNumberOfBits(ctx, a_rev, (f / 2) * 2); // do compensation // Note: @@ -623,7 +626,7 @@ static Value rsqrt_np2(SPUContext* ctx, const Value& x) { SPU_TRACE_HAL_LEAF(ctx, x); // let e = NP2(x), z = 2^(e+f) - return _lshift(ctx, detail::highestOneBit(ctx, x), 1); + return _lshift(ctx, detail::highestOneBit(ctx, x), {1}); } // Reference: diff --git a/libspu/kernel/hal/fxp_base.cc b/libspu/kernel/hal/fxp_base.cc index 209782e1..601d42c7 100644 --- a/libspu/kernel/hal/fxp_base.cc +++ b/libspu/kernel/hal/fxp_base.cc @@ -72,7 +72,7 @@ Value polynomial(SPUContext* ctx, const Value& x, Value highestOneBit(SPUContext* ctx, const Value& x) { auto y = _prefix_or(ctx, x); - auto y1 = _rshift(ctx, y, 1); + auto y1 = _rshift(ctx, y, {1}); return _xor(ctx, y, y1); } @@ -85,6 +85,13 @@ void hintNumberOfBits(const Value& a, size_t nbits) { } } +Value maskNumberOfBits(SPUContext* ctx, const Value& in, size_t nbits) { + auto k1 = constant(ctx, static_cast(1), spu::DT_I64, in.shape()); + auto mask = _sub(ctx, _lshift(ctx, k1, {static_cast(nbits)}), k1); + auto out = _and(ctx, in, mask).setDtype(in.dtype()); + return out; +} + namespace { Value reciprocal_goldschmidt_normalized_approx(SPUContext* ctx, @@ -178,7 +185,8 @@ Value div_goldschmidt_general(SPUContext* ctx, const Value& a, const Value& b, // factor = 2^{f-m} = 2^{-m} * 2^f, the fixed point repr of 2^{-m} const size_t num_fxp_bits = ctx->getFxpBits(); auto factor = _bitrev(ctx, b_msb, 0, 2 * num_fxp_bits).setDtype(b.dtype()); - detail::hintNumberOfBits(factor, 2 * num_fxp_bits); + + factor = maskNumberOfBits(ctx, factor, 2 * num_fxp_bits); // also, we use factor twice factor = _prefer_a(ctx, factor); @@ -209,7 +217,7 @@ Value reciprocal_goldschmidt_positive(SPUContext* ctx, const Value& b_abs) { const size_t num_fxp_bits = ctx->getFxpBits(); auto factor = _bitrev(ctx, b_msb, 0, 2 * num_fxp_bits).setDtype(b_abs.dtype()); - detail::hintNumberOfBits(factor, 2 * num_fxp_bits); + factor = maskNumberOfBits(ctx, factor, 2 * num_fxp_bits); // also, we use factor twice factor = _prefer_a(ctx, factor); @@ -237,13 +245,12 @@ Value reciprocal_goldschmidt(SPUContext* ctx, const Value& b) { // factor = 2^{f-m} = 2^{-m} * 2^f, the fixed point repr of 2^{-m} const size_t num_fxp_bits = 
ctx->getFxpBits(); auto factor = _bitrev(ctx, b_msb, 0, 2 * num_fxp_bits).setDtype(b.dtype()); - detail::hintNumberOfBits(factor, 2 * num_fxp_bits); + factor = maskNumberOfBits(ctx, factor, 2 * num_fxp_bits); // also, we use factor twice factor = _prefer_a(ctx, factor); // compute approximation of normalize b_abs auto r = reciprocal_goldschmidt_normalized_approx(ctx, b_abs, factor); - r = f_mul(ctx, r, factor, SignType::Positive); return _mux(ctx, is_negative, _negate(ctx, r), r).setDtype(b.dtype()); @@ -370,8 +377,8 @@ Value f_floor(SPUContext* ctx, const Value& x) { SPU_ENFORCE(x.isFxp()); - const size_t fbits = ctx->getFxpBits(); - return _lshift(ctx, _arshift(ctx, x, fbits), fbits).setDtype(x.dtype()); + const int64_t fbits = ctx->getFxpBits(); + return _lshift(ctx, _arshift(ctx, x, {fbits}), {fbits}).setDtype(x.dtype()); } Value f_ceil(SPUContext* ctx, const Value& x) { diff --git a/libspu/kernel/hal/fxp_base.h b/libspu/kernel/hal/fxp_base.h index c72022b6..61fe7bca 100644 --- a/libspu/kernel/hal/fxp_base.h +++ b/libspu/kernel/hal/fxp_base.h @@ -30,6 +30,8 @@ Value highestOneBit(SPUContext* ctx, const Value& x); void hintNumberOfBits(const Value& a, size_t nbits); +Value maskNumberOfBits(SPUContext* ctx, const Value& a, size_t nbits); + // we provide this general function to support some special cases (a or b has // guarranteed sign) in fxp_approx for better both performance and accuracy. Value div_goldschmidt_general(SPUContext* ctx, const Value& a, const Value& b, diff --git a/libspu/kernel/hal/fxp_cleartext.cc b/libspu/kernel/hal/fxp_cleartext.cc index 818e0b23..b5061d5b 100644 --- a/libspu/kernel/hal/fxp_cleartext.cc +++ b/libspu/kernel/hal/fxp_cleartext.cc @@ -57,7 +57,7 @@ Value applyFloatingPointFn(SPUContext* ctx, const Value& in, FN&& fn) { auto pt_type = getDecodeType(in.dtype()); for (auto iter = fp_arr.begin(); iter != fp_arr.end(); ++iter) { - DISPATCH_FLOAT_PT_TYPES(pt_type, "pt_type", [&]() { + DISPATCH_FLOAT_PT_TYPES(pt_type, [&]() { auto* ptr = reinterpret_cast(&*iter); *ptr = fn(*ptr); }); @@ -92,9 +92,9 @@ Value applyFloatingPointFn(SPUContext* ctx, const Value& x, const Value& y, for (auto itr_x = flp_x.begin(), itr_y = flp_y.begin(); itr_x != flp_x.end(); itr_x++, itr_y++) { - DISPATCH_FLOAT_PT_TYPES(x_pt_type, "x_pt_type", [&]() { + DISPATCH_FLOAT_PT_TYPES(x_pt_type, [&]() { auto* ptr_x = reinterpret_cast(&*itr_x); - DISPATCH_FLOAT_PT_TYPES(y_pt_type, "y_pt_type", [&]() { + DISPATCH_FLOAT_PT_TYPES(y_pt_type, [&]() { auto* ptr_y = reinterpret_cast(&*itr_y); *ptr_x = fn(*ptr_x, *ptr_y); }); diff --git a/libspu/kernel/hal/permute.cc b/libspu/kernel/hal/permute.cc index 657b81fc..7ae47cda 100644 --- a/libspu/kernel/hal/permute.cc +++ b/libspu/kernel/hal/permute.cc @@ -19,6 +19,7 @@ #include "libspu/core/bit_utils.h" #include "libspu/core/context.h" #include "libspu/core/trace.h" +#include "libspu/core/vectorize.h" #include "libspu/kernel/hal/constants.h" #include "libspu/kernel/hal/polymorphic.h" #include "libspu/kernel/hal/prot_wrapper.h" @@ -43,6 +44,12 @@ inline bool _has_same_owner(const Value &x, const Value &y) { return _get_owner(x) == _get_owner(y); } +void _hint_nbits(const Value &a, size_t nbits) { + if (a.storage_type().isa()) { + const_cast(a.storage_type()).as()->setNbits(nbits); + } +} + // generate inverse permutation Index _inverse_index(const Index &p) { Index q(p.size()); @@ -531,20 +538,30 @@ spu::Value _opt_apply_perm_ss(SPUContext *ctx, const spu::Value &perm, std::vector _bit_decompose(SPUContext *ctx, const spu::Value &x, int64_t valid_bits) { 
auto x_bshare = _prefer_b(ctx, x); - const auto k1 = _constant(ctx, 1U, x.shape()); - std::vector rets; size_t nbits = valid_bits != -1 ? static_cast(valid_bits) : x_bshare.storage_type().as()->nbits(); - rets.reserve(nbits); + _hint_nbits(x_bshare, nbits); + if (ctx->hasKernel("b2a_disassemble")) { + auto ret = + dynDispatch>(ctx, "b2a_disassemble", x_bshare); + return ret; + } + + const auto k1 = _constant(ctx, 1U, x.shape()); + std::vector rets_b; + rets_b.reserve(nbits); for (size_t bit = 0; bit < nbits; ++bit) { - auto x_bshare_shift = right_shift_logical(ctx, x_bshare, bit); - auto lowest_bit = _and(ctx, x_bshare_shift, k1); - rets.emplace_back(_prefer_a(ctx, lowest_bit)); + auto x_bshare_shift = + right_shift_logical(ctx, x_bshare, {static_cast(bit)}); + rets_b.push_back(_and(ctx, x_bshare_shift, k1)); } - return rets; + std::vector rets_a; + vmap(rets_b.begin(), rets_b.end(), std::back_inserter(rets_a), + [&](const Value &x) { return _prefer_a(ctx, x); }); + return rets_a; } // Generate vector of bit decomposition of sorting keys @@ -687,10 +704,10 @@ spu::Value _apply_perm_ss(SPUContext *ctx, const Value &x, const Value &perm) { // Find mergeable keys from keys. Consecutive public/private(belong to one // owner) keys can be merged. Assume there are six keys, i.e., public_key0, -// bob_key0, bob_key1, alice_key0, alice_key1, secret_key0. We can merge the six -// keys into bob_new_key, alice_new_key, secret_key0 for the following sorting. -// This function will return a vector of indices [3,5,6] which means key[0,3), -// key[3,5), and key[5,6) can be merged. +// bob_key0, bob_key1, alice_key0, alice_key1, secret_key0. We can merge the +// six keys into bob_new_key, alice_new_key, secret_key0 for the following +// sorting. This function will return a vector of indices [3,5,6] which means +// key[0,3), key[3,5), and key[5,6) can be merged. std::vector _find_mergeable_keys(SPUContext *ctx, absl::Span keys) { std::vector split_indices; diff --git a/libspu/kernel/hal/polymorphic.cc b/libspu/kernel/hal/polymorphic.cc index 34cf5f77..21680cc6 100644 --- a/libspu/kernel/hal/polymorphic.cc +++ b/libspu/kernel/hal/polymorphic.cc @@ -356,7 +356,7 @@ Value power(SPUContext* ctx, const Value& x, const Value& y) { const auto bit_width = SizeOf(ctx->getField()) * 8; auto y_b = _prefer_b(ctx, y); - auto msb_y = _rshift(ctx, y_b, bit_width - 1); + auto msb_y = _rshift(ctx, y_b, {static_cast(bit_width - 1)}); auto x_abs1 = _equal(ctx, abs(ctx, x), k1); auto ret = _constant(ctx, 1, x.shape()); @@ -379,7 +379,9 @@ Value power(SPUContext* ctx, const Value& x, const Value& y) { // e.g. y=0101, then ret = (x) * (1) * (x^(2^2)) * (1) = x^5 for (size_t idx = 0; idx < y_bits; idx++) { // x^(2^idx) * y_{idx} - auto cur_pow = _mux(ctx, _and(ctx, _rshift(ctx, y_b, idx), k1), base, k1); + auto cur_pow = _mux( + ctx, _and(ctx, _rshift(ctx, y_b, {static_cast(idx)}), k1), + base, k1); ret = _mul(ctx, cur_pow, ret); if (idx < y_bits - 1) { base = _mul(ctx, base, base); @@ -409,8 +411,9 @@ Value power(SPUContext* ctx, const Value& x, const Value& y) { // the final sign is decided on both sign of x and the parity of y // when x<0 and y is odd, e.g. 
(-2)^3 = -8 - auto odd = _and(ctx, _rshift(ctx, y, ctx->getFxpBits()), - _constant(ctx, 1, y.shape())); + auto odd = + _and(ctx, _rshift(ctx, y, {static_cast(ctx->getFxpBits())}), + _constant(ctx, 1, y.shape())); auto sign = _and(ctx, msb, odd); return _mux(ctx, sign, _negate(ctx, val), val).setDtype(x.dtype()); @@ -488,19 +491,20 @@ Value bitcast(SPUContext* ctx, const Value& x, DataType dtype) { return Value(x.data().clone(), dtype); } -Value left_shift(SPUContext* ctx, const Value& x, size_t bits) { +Value left_shift(SPUContext* ctx, const Value& x, const Sizes& bits) { SPU_TRACE_HAL_DISP(ctx, x, bits); return _lshift(ctx, x, bits).setDtype(x.dtype()); } -Value right_shift_logical(SPUContext* ctx, const Value& x, size_t bits) { +Value right_shift_logical(SPUContext* ctx, const Value& x, const Sizes& bits) { SPU_TRACE_HAL_DISP(ctx, x, bits); return _rshift(ctx, x, bits).setDtype(x.dtype()); } -Value right_shift_arithmetic(SPUContext* ctx, const Value& x, size_t bits) { +Value right_shift_arithmetic(SPUContext* ctx, const Value& x, + const Sizes& bits) { SPU_TRACE_HAL_DISP(ctx, x, bits); return _arshift(ctx, x, bits).setDtype(x.dtype()); diff --git a/libspu/kernel/hal/polymorphic.h b/libspu/kernel/hal/polymorphic.h index 58e90f00..9b47b85d 100644 --- a/libspu/kernel/hal/polymorphic.h +++ b/libspu/kernel/hal/polymorphic.h @@ -187,11 +187,12 @@ Value clamp(SPUContext* ctx, const Value& x, const Value& min, // @param dtype, second input value Value bitcast(SPUContext* ctx, const Value& x, DataType dtype); -Value left_shift(SPUContext* ctx, const Value& x, size_t bits); +Value left_shift(SPUContext* ctx, const Value& x, const Sizes& bits); -Value right_shift_logical(SPUContext* ctx, const Value& x, size_t bits); +Value right_shift_logical(SPUContext* ctx, const Value& x, const Sizes& bits); -Value right_shift_arithmetic(SPUContext* ctx, const Value& x, size_t bits); +Value right_shift_arithmetic(SPUContext* ctx, const Value& x, + const Sizes& bits); /// the element-wise base-2 logarithm of x // @param in, should be positive, or the result is implementation defined. 
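From here on, the hal shift entry points take a Sizes argument instead of a single size_t. A hypothetical caller under the new signatures is sketched below (not part of the patch): a one-element initializer keeps the old shift-everything-by-k behaviour, while a longer vector, as the shift_impl_p rewrite later in this patch suggests, supplies one amount per element.

```cpp
#include "libspu/kernel/hal/polymorphic.h"

namespace hal = spu::kernel::hal;

// Hypothetical helper illustrating the Sizes-based shift API; the name and
// the per-element interpretation of multi-entry initializers are assumptions
// drawn from the shift_impl_p change, not guaranteed by this patch.
spu::Value halve_then_restore(spu::SPUContext* ctx, const spu::Value& x) {
  auto down = hal::right_shift_logical(ctx, x, {1});  // every element >> 1
  return hal::left_shift(ctx, down, {1});             // every element << 1
}
```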
diff --git a/libspu/kernel/hal/prot_wrapper.cc b/libspu/kernel/hal/prot_wrapper.cc index 10743c10..23f46388 100644 --- a/libspu/kernel/hal/prot_wrapper.cc +++ b/libspu/kernel/hal/prot_wrapper.cc @@ -30,11 +30,11 @@ namespace spu::kernel::hal { return mpc::NAME(ctx, in); \ } -#define MAP_SHIFT_OP(NAME) \ - Value _##NAME(SPUContext* ctx, const Value& in, size_t bits) { \ - SPU_TRACE_HAL_DISP(ctx, in, bits); \ - auto ret = mpc::NAME(ctx, in, bits); \ - return ret; \ +#define MAP_SHIFT_OP(NAME) \ + Value _##NAME(SPUContext* ctx, const Value& in, const Sizes& bits) { \ + SPU_TRACE_HAL_DISP(ctx, in, bits); \ + auto ret = mpc::NAME(ctx, in, bits); \ + return ret; \ } #define MAP_BITREV_OP(NAME) \ diff --git a/libspu/kernel/hal/prot_wrapper.h b/libspu/kernel/hal/prot_wrapper.h index 6294f834..5fb110bc 100644 --- a/libspu/kernel/hal/prot_wrapper.h +++ b/libspu/kernel/hal/prot_wrapper.h @@ -52,17 +52,17 @@ Value _equal_pp(SPUContext* ctx, const Value& x, const Value& y); std::optional _equal_sp(SPUContext* ctx, const Value& x, const Value& y); std::optional _equal_ss(SPUContext* ctx, const Value& x, const Value& y); -Value _lshift_p(SPUContext* ctx, const Value& in, size_t bits); -Value _lshift_s(SPUContext* ctx, const Value& in, size_t bits); -Value _lshift_v(SPUContext* ctx, const Value& in, size_t bits); +Value _lshift_p(SPUContext* ctx, const Value& in, const Sizes& bits); +Value _lshift_s(SPUContext* ctx, const Value& in, const Sizes& bits); +Value _lshift_v(SPUContext* ctx, const Value& in, const Sizes& bits); -Value _rshift_p(SPUContext* ctx, const Value& in, size_t bits); -Value _rshift_s(SPUContext* ctx, const Value& in, size_t bits); -Value _rshift_v(SPUContext* ctx, const Value& in, size_t bits); +Value _rshift_p(SPUContext* ctx, const Value& in, const Sizes& bits); +Value _rshift_s(SPUContext* ctx, const Value& in, const Sizes& bits); +Value _rshift_v(SPUContext* ctx, const Value& in, const Sizes& bits); -Value _arshift_p(SPUContext* ctx, const Value& in, size_t bits); -Value _arshift_s(SPUContext* ctx, const Value& in, size_t bits); -Value _arshift_v(SPUContext* ctx, const Value& in, size_t bits); +Value _arshift_p(SPUContext* ctx, const Value& in, const Sizes& bits); +Value _arshift_s(SPUContext* ctx, const Value& in, const Sizes& bits); +Value _arshift_v(SPUContext* ctx, const Value& in, const Sizes& bits); Value _trunc_p(SPUContext* ctx, const Value& in, size_t bits, SignType sign); Value _trunc_s(SPUContext* ctx, const Value& in, size_t bits, SignType sign); diff --git a/libspu/kernel/hal/ring.cc b/libspu/kernel/hal/ring.cc index 08b873f6..b016cd71 100644 --- a/libspu/kernel/hal/ring.cc +++ b/libspu/kernel/hal/ring.cc @@ -84,18 +84,18 @@ IMPL_UNARY_OP(_square) #undef IMPL_UNARY_OP -#define IMPL_SHIFT_OP(Name) \ - Value Name(SPUContext* ctx, const Value& in, size_t bits) { \ - SPU_TRACE_HAL_LEAF(ctx, in, bits); \ - if (in.isPublic()) { \ - return Name##_p(ctx, in, bits); \ - } else if (in.isSecret()) { \ - return Name##_s(ctx, in, bits); \ - } else if (in.isPrivate()) { \ - return Name##_v(ctx, in, bits); \ - } else { \ - SPU_THROW("unsupport unary op={} for {}", #Name, in); \ - } \ +#define IMPL_SHIFT_OP(Name) \ + Value Name(SPUContext* ctx, const Value& in, const Sizes& bits) { \ + SPU_TRACE_HAL_LEAF(ctx, in, bits); \ + if (in.isPublic()) { \ + return Name##_p(ctx, in, bits); \ + } else if (in.isSecret()) { \ + return Name##_s(ctx, in, bits); \ + } else if (in.isPrivate()) { \ + return Name##_v(ctx, in, bits); \ + } else { \ + SPU_THROW("unsupport unary op={} for {}", #Name, in); \ 
+ } \ } IMPL_SHIFT_OP(_lshift) @@ -497,7 +497,7 @@ Value _bit_parity(SPUContext* ctx, const Value& x, size_t bits) { SPU_ENFORCE(absl::has_single_bit(bits), "currently only support power of 2"); auto ret = _prefer_b(ctx, x); while (bits > 1) { - ret = _xor(ctx, ret, _rshift(ctx, ret, bits / 2)); + ret = _xor(ctx, ret, _rshift(ctx, ret, {static_cast(bits / 2)})); bits /= 2; } @@ -518,7 +518,7 @@ Value _popcount(SPUContext* ctx, const Value& x, size_t bits) { std::vector vs; vs.reserve(bits); for (size_t idx = 0; idx < bits; idx++) { - auto x_ = _rshift(ctx, xb, idx); + auto x_ = _rshift(ctx, xb, {static_cast(idx)}); x_ = _and(ctx, x_, _constant(ctx, 1U, x.shape())); if (x_.storage_type().isa()) { @@ -547,8 +547,8 @@ Value _prefix_or(SPUContext* ctx, const Value& x) { auto b0 = _prefer_b(ctx, x); const size_t bit_width = SizeOf(ctx->getField()) * 8; for (int idx = 0; idx < absl::bit_width(bit_width) - 1; idx++) { - const size_t offset = 1UL << idx; - auto b1 = _rshift(ctx, b0, offset); + const int64_t offset = 1L << idx; + auto b1 = _rshift(ctx, b0, {offset}); b0 = _or(ctx, b0, b1); } return b0; @@ -574,8 +574,8 @@ Value _bitdeintl(SPUContext* ctx, const Value& in) { // out = (out & keep) ^ ((out >> shift) & move) ^ ((out & move) << shift); out = _xor(ctx, _xor(ctx, _and(ctx, out, keep), - _and(ctx, _rshift(ctx, out, shift), move)), - _lshift(ctx, _and(ctx, out, move), shift)); + _and(ctx, _rshift(ctx, out, {shift}), move)), + _lshift(ctx, _and(ctx, out, move), {shift})); } return out; } diff --git a/libspu/kernel/hal/ring.h b/libspu/kernel/hal/ring.h index b2fb6f65..0dd7234a 100644 --- a/libspu/kernel/hal/ring.h +++ b/libspu/kernel/hal/ring.h @@ -71,11 +71,11 @@ Value _equal(SPUContext* ctx, const Value& x, const Value& y); Value _less(SPUContext* ctx, const Value& x, const Value& y); -Value _lshift(SPUContext* ctx, const Value& in, size_t bits); +Value _lshift(SPUContext* ctx, const Value& in, const Sizes& bits); -Value _rshift(SPUContext* ctx, const Value& in, size_t bits); +Value _rshift(SPUContext* ctx, const Value& in, const Sizes& bits); -Value _arshift(SPUContext* ctx, const Value& in, size_t bits); +Value _arshift(SPUContext* ctx, const Value& in, const Sizes& bits); Value _trunc(SPUContext* ctx, const Value& x, size_t bits = 0, SignType sign = SignType::Unknown); diff --git a/libspu/kernel/hal/type_cast.cc b/libspu/kernel/hal/type_cast.cc index 53b65742..9adb77a8 100644 --- a/libspu/kernel/hal/type_cast.cc +++ b/libspu/kernel/hal/type_cast.cc @@ -27,7 +27,8 @@ Value int2fxp(SPUContext* ctx, const Value& x, DataType to_type) { SPU_TRACE_HAL_LEAF(ctx, x); SPU_ENFORCE(x.isInt(), "expect integer, got {}", x.dtype()); - return _lshift(ctx, x, ctx->getFxpBits()).setDtype(to_type); + return _lshift(ctx, x, {static_cast(ctx->getFxpBits())}) + .setDtype(to_type); } // Casting fxp to integer. 
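The two hunks around this point are the integer/fixed-point conversions. A plain-integer sketch of the arithmetic, with illustrative names that are not SPU APIs: int2fxp is a multiply by 2^f, and fxp2int biases negative inputs by 2^f - 1 so the arithmetic right shift truncates toward zero instead of flooring, which is what the "(x + 0.99 * (x < 0)) >> fxp_bits" comment in the hunk below expresses.

```cpp
#include <cstdint>

// Reference sketch only (f = number of fractional fxp bits).
int64_t int2fxp_ref(int64_t x, int64_t f) {
  // shift in unsigned arithmetic to keep the wrap-around well defined
  return static_cast<int64_t>(static_cast<uint64_t>(x) << f);
}

int64_t fxp2int_ref(int64_t x, int64_t f) {
  const int64_t bias = x < 0 ? (int64_t{1} << f) - 1 : 0;
  return (x + bias) >> f;  // assumes the usual arithmetic shift on signed values
}
```

For example with f = 2, the value -6 (i.e. -1.5) becomes (-6 + 3) >> 2 = -1, matching truncation toward zero, whereas a plain shift would floor to -2.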
@@ -49,12 +50,12 @@ Value fxp2int(SPUContext* ctx, const Value& x, DataType to_type) { SPU_TRACE_HAL_LEAF(ctx, x); SPU_ENFORCE(x.isFxp()); - const size_t fxp_bits = ctx->getFxpBits(); + const int64_t fxp_bits = ctx->getFxpBits(); const Value kOneMinusEps = _constant(ctx, (1 << fxp_bits) - 1, x.shape()); // (x + 0.99 * (x < 0)) >> fxp_bits return _arshift(ctx, _add(ctx, x, _mul(ctx, kOneMinusEps, _msb(ctx, x))), - fxp_bits) + {fxp_bits}) .setDtype(to_type); } diff --git a/libspu/kernel/hlo/basic_unary.cc b/libspu/kernel/hlo/basic_unary.cc index ce1d2c17..44cc9e36 100644 --- a/libspu/kernel/hlo/basic_unary.cc +++ b/libspu/kernel/hlo/basic_unary.cc @@ -118,19 +118,19 @@ spu::Value Round_RNTE(SPUContext *ctx, const spu::Value &in) { // so comp = b && (c || a) SPU_ENFORCE(!in.isComplex()); SPU_ENFORCE(in.isFxp(), "Round only supports fxp"); - const auto fxp_bits = ctx->getFxpBits(); + const int64_t fxp_bits = ctx->getFxpBits(); const auto k1 = hal::_constant(ctx, 1U, in.shape()); auto x_prime = hal::_prefer_b(ctx, in); auto y = hal::floor(ctx, x_prime); - auto a = hal::_and(ctx, hal::_rshift(ctx, x_prime, fxp_bits), k1); - auto b = hal::_and(ctx, hal::_rshift(ctx, x_prime, fxp_bits - 1), k1); + auto a = hal::_and(ctx, hal::_rshift(ctx, x_prime, {fxp_bits}), k1); + auto b = hal::_and(ctx, hal::_rshift(ctx, x_prime, {fxp_bits - 1}), k1); std::vector cs; cs.reserve(fxp_bits - 1); - for (size_t idx = 0; idx < fxp_bits - 1; idx++) { - auto x_ = hal::_and(ctx, hal::_rshift(ctx, x_prime, idx), k1); + for (int64_t idx = 0; idx < fxp_bits - 1; idx++) { + auto x_ = hal::_and(ctx, hal::_rshift(ctx, x_prime, {idx}), k1); cs.push_back(std::move(x_)); } auto c = vreduce(cs.begin(), cs.end(), [&](const Value &a, const Value &b) { diff --git a/libspu/kernel/hlo/shift.cc b/libspu/kernel/hlo/shift.cc index 08fe7140..0b6f3f1d 100644 --- a/libspu/kernel/hlo/shift.cc +++ b/libspu/kernel/hlo/shift.cc @@ -26,31 +26,8 @@ namespace spu::kernel::hlo { template spu::Value shift_impl_p(SPUContext *ctx, const spu::Value &lhs, const spu::Value &rhs, const Fn &f) { - auto shift_bits = hal::dump_public_as(ctx, rhs); - if (std::all_of(rhs.strides().begin(), rhs.strides().end(), - [](int64_t s) { return s == 0; })) { - // rhs is a splat - return f(ctx, lhs, shift_bits[0]); - } - - // Not a splat... 
- spu::Value ret = - hal::constant(ctx, static_cast(0), lhs.dtype(), lhs.shape()); - auto dtype_size = getWidth(lhs.dtype()); - for (size_t bits = 0; bits < dtype_size; ++bits) { - if (std::none_of(shift_bits.begin(), shift_bits.end(), [&bits](int8_t b) { - return b == static_cast(bits); - })) { - continue; - } - auto current_bits = hal::constant(ctx, static_cast(bits), - rhs.dtype(), rhs.shape()); - auto mask = hal::equal(ctx, rhs, current_bits); - auto shifted = f(ctx, lhs, bits); - ret = hal::add(ctx, ret, hal::mul(ctx, mask, shifted)); - } - - return ret; + auto shift_bits = hal::dump_public_as(ctx, rhs); + return f(ctx, lhs, {shift_bits.begin(), shift_bits.end()}); } template @@ -63,7 +40,7 @@ spu::Value shift_impl_s(SPUContext *ctx, const spu::Value &lhs, auto current_bits = hal::constant(ctx, static_cast(bits), rhs.dtype(), rhs.shape()); auto mask = hal::equal(ctx, rhs, current_bits); - auto shifted = f(ctx, lhs, bits); + auto shifted = f(ctx, lhs, {static_cast(bits)}); ret = hal::add(ctx, ret, hal::mul(ctx, mask, shifted)); } diff --git a/libspu/mpc/ab_api.cc b/libspu/mpc/ab_api.cc index a52de586..a0720870 100644 --- a/libspu/mpc/ab_api.cc +++ b/libspu/mpc/ab_api.cc @@ -133,7 +133,7 @@ OptionalAPI mul_a1bv(SPUContext* ctx, const Value& x, const Value& y) { return NotAvailable; } -Value lshift_a(SPUContext* ctx, const Value& x, size_t nbits) { +Value lshift_a(SPUContext* ctx, const Value& x, const Sizes& nbits) { FORCE_DISPATCH(ctx, x, nbits); } @@ -201,15 +201,15 @@ OptionalAPI xor_bv(SPUContext* ctx, const Value& x, const Value& y) { return NotAvailable; } -Value lshift_b(SPUContext* ctx, const Value& x, size_t nbits) { +Value lshift_b(SPUContext* ctx, const Value& x, const Sizes& nbits) { FORCE_DISPATCH(ctx, x, nbits); } -Value rshift_b(SPUContext* ctx, const Value& x, size_t nbits) { +Value rshift_b(SPUContext* ctx, const Value& x, const Sizes& nbits) { FORCE_DISPATCH(ctx, x, nbits); } -Value arshift_b(SPUContext* ctx, const Value& x, size_t nbits) { +Value arshift_b(SPUContext* ctx, const Value& x, const Sizes& nbits) { FORCE_DISPATCH(ctx, x, nbits); } @@ -252,12 +252,12 @@ Value bitintl_b(SPUContext* ctx, const Value& x, size_t stride) { idx--) { auto K = hack_make_p(ctx, spu::detail::kBitIntlKeepMasks[idx], x.shape()); auto M = hack_make_p(ctx, spu::detail::kBitIntlSwapMasks[idx], x.shape()); - int64_t S = 1 << idx; + int64_t S = static_cast(1) << idx; // out = (out & K) ^ ((out >> S) & M) ^ ((out & M) << S); - out = xor_bb( - ctx, - xor_bb(ctx, and_bp(ctx, out, K), and_bp(ctx, rshift_b(ctx, out, S), M)), - lshift_b(ctx, and_bp(ctx, out, M), S)); + out = xor_bb(ctx, + xor_bb(ctx, and_bp(ctx, out, K), + and_bp(ctx, rshift_b(ctx, out, {S}), M)), + lshift_b(ctx, and_bp(ctx, out, M), {S})); } out = setNumBits(out, numBits(x)); return out; @@ -281,12 +281,12 @@ Value bitdeintl_b(SPUContext* ctx, const Value& x, size_t stride) { for (int64_t idx = stride; idx + 1 < Log2Ceil(nbits); idx++) { auto K = hack_make_p(ctx, spu::detail::kBitIntlKeepMasks[idx], x.shape()); auto M = hack_make_p(ctx, spu::detail::kBitIntlSwapMasks[idx], x.shape()); - int64_t S = 1 << idx; + int64_t S = static_cast(1) << idx; // out = (out & K) ^ ((out >> S) & M) ^ ((out & M) << S); - out = xor_bb( - ctx, - xor_bb(ctx, and_bp(ctx, out, K), and_bp(ctx, rshift_b(ctx, out, S), M)), - lshift_b(ctx, and_bp(ctx, out, M), S)); + out = xor_bb(ctx, + xor_bb(ctx, and_bp(ctx, out, K), + and_bp(ctx, rshift_b(ctx, out, {S}), M)), + lshift_b(ctx, and_bp(ctx, out, M), {S})); } out = setNumBits(out, numBits(x)); return out; 
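The next hunk touches ppa_kogge_stone. For reference, a cleartext sketch of the same prefix-carry recurrence that the MPC version evaluates on boolean shares:

```cpp
#include <cstdint>

// Cleartext Kogge-Stone adder mirroring the P/G recurrence in the hunk below.
// Since a propagate bit and a generate bit are never both set, the xor used
// to combine group generates is equivalent to an or.
uint64_t kogge_stone_add(uint64_t x, uint64_t y) {
  uint64_t P = x ^ y;  // propagate
  uint64_t G = x & y;  // generate
  for (int offset = 1; offset < 64; offset <<= 1) {
    uint64_t G1 = G << offset;
    uint64_t P1 = P << offset;
    G = G ^ (P & G1);  // group generate over the widened span
    P = P & P1;        // group propagate over the widened span
  }
  return (x ^ y) ^ (G << 1);  // sum bit = p0 ^ carry-in
}
```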
@@ -318,9 +318,9 @@ Value ppa_kogge_stone(SPUContext* ctx, const Value& lhs, const Value& rhs, auto G = and_bb(ctx, lhs, rhs); for (int idx = 0; idx < Log2Ceil(nbits); ++idx) { - const size_t offset = 1UL << idx; - auto G1 = lshift_b(ctx, G, offset); - auto P1 = lshift_b(ctx, P, offset); + const int64_t offset = static_cast(1) << idx; + auto G1 = lshift_b(ctx, G, {offset}); + auto P1 = lshift_b(ctx, P, {offset}); // P1 = P & P1 // G1 = G ^ (P & G1) @@ -332,7 +332,7 @@ Value ppa_kogge_stone(SPUContext* ctx, const Value& lhs, const Value& rhs, } // out = (G << 1) ^ p0 - auto C = lshift_b(ctx, G, 1); + auto C = lshift_b(ctx, G, {1}); return xor_bb(ctx, xor_bb(ctx, lhs, rhs), C); } @@ -343,7 +343,7 @@ std::pair bit_scatter(SPUContext* ctx, const Value& in, SPU_ENFORCE(absl::has_single_bit(nbits), "unsupported {}", nbits); auto out = bitdeintl_b(ctx, in, stride); - auto hi = rshift_b(ctx, out, nbits / 2); + auto hi = rshift_b(ctx, out, {static_cast(nbits / 2)}); auto mask = hack_make_p(ctx, (static_cast(1) << (nbits / 2)) - 1, in.shape()); auto lo = and_bp(ctx, out, mask); @@ -357,7 +357,7 @@ Value bit_gather(SPUContext* ctx, const Value& hi, const Value& lo, SPU_ENFORCE(nbits == numBits(lo), "nbits mismatch {}, {}", nbits, numBits(lo)); - auto out = xor_bb(ctx, lshift_b(ctx, hi, nbits), lo); + auto out = xor_bb(ctx, lshift_b(ctx, hi, {static_cast(nbits)}), lo); return bitintl_b(ctx, out, stride); } @@ -395,8 +395,8 @@ Value ppa_sklansky(SPUContext* ctx, Value const& lhs, Value const& rhs, auto Gs = and_bp(ctx, Gl, s_mask); auto Ps = and_bp(ctx, Pl, s_mask); for (int j = 0; j < idx; j++) { - Gs = xor_bb(ctx, Gs, rshift_b(ctx, Gs, 1 << j)); - Ps = xor_bb(ctx, Ps, rshift_b(ctx, Ps, 1 << j)); + Gs = xor_bb(ctx, Gs, rshift_b(ctx, Gs, {1 << j})); + Ps = xor_bb(ctx, Ps, rshift_b(ctx, Ps, {1 << j})); } // SPU_ENFORCE(numBits(Ps) == bit_width / 2); // SPU_ENFORCE(numBits(Gs) == bit_width / 2); @@ -416,7 +416,7 @@ Value ppa_sklansky(SPUContext* ctx, Value const& lhs, Value const& rhs, } // out = (G0 << 1) ^ p0 - auto C = lshift_b(ctx, G, 1); + auto C = lshift_b(ctx, G, {1}); return xor_bb(ctx, xor_bb(ctx, lhs, rhs), C); } @@ -460,8 +460,8 @@ Value carry_a2b(SPUContext* ctx, const Value& x, const Value& y, size_t k) { while (k > 1) { if (k % 2 != 0) { k += 1; - P = lshift_b(ctx, P, 1); - G = lshift_b(ctx, G, 1); + P = lshift_b(ctx, P, {1}); + G = lshift_b(ctx, G, {1}); } auto [P1, P0] = bit_scatter(ctx, P, 0); auto [G1, G0] = bit_scatter(ctx, G, 0); diff --git a/libspu/mpc/ab_api.h b/libspu/mpc/ab_api.h index 72a8476f..94f17d98 100644 --- a/libspu/mpc/ab_api.h +++ b/libspu/mpc/ab_api.h @@ -46,7 +46,7 @@ OptionalAPI mul_av(SPUContext* ctx, const Value& x, const Value& y); Value mul_a1b(SPUContext* ctx, const Value& x, const Value& y); OptionalAPI mul_a1bv(SPUContext* ctx, const Value& x, const Value& y); -Value lshift_a(SPUContext* ctx, const Value& x, size_t nbits); +Value lshift_a(SPUContext* ctx, const Value& x, const Sizes& nbits); Value trunc_a(SPUContext* ctx, const Value& x, size_t nbits, SignType sign); Value mmul_ap(SPUContext* ctx, const Value& x, const Value& y); @@ -72,9 +72,9 @@ Value xor_bb(SPUContext* ctx, const Value& x, const Value& y); OptionalAPI xor_bv(SPUContext* ctx, const Value& x, const Value& y); // TODO -Value lshift_b(SPUContext* ctx, const Value& x, size_t nbits); -Value rshift_b(SPUContext* ctx, const Value& x, size_t nbits); -Value arshift_b(SPUContext* ctx, const Value& x, size_t nbits); +Value lshift_b(SPUContext* ctx, const Value& x, const Sizes& nbits); +Value 
rshift_b(SPUContext* ctx, const Value& x, const Sizes& nbits); +Value arshift_b(SPUContext* ctx, const Value& x, const Sizes& nbits); // Bit reverse for binary share. Value bitrev_b(SPUContext* ctx, const Value& x, size_t start, size_t end); diff --git a/libspu/mpc/ab_api_test.cc b/libspu/mpc/ab_api_test.cc index 13ef4080..592f9389 100644 --- a/libspu/mpc/ab_api_test.cc +++ b/libspu/mpc/ab_api_test.cc @@ -195,7 +195,7 @@ TEST_P(ArithmeticTest, MulA1B) { return; } - const size_t K = spu::SizeOf(conf.field()) * 8; + const int64_t K = spu::SizeOf(conf.field()) * 8; /* GIVEN */ auto p0 = rand_p(obj.get(), conf.protocol() == ProtocolKind::CHEETAH @@ -204,12 +204,12 @@ TEST_P(ArithmeticTest, MulA1B) { auto p1 = rand_p(obj.get(), conf.protocol() == ProtocolKind::CHEETAH ? Shape({200, 26}) : kShape); - p1 = rshift_p(obj.get(), p1, K - 1); + p1 = rshift_p(obj.get(), p1, {K - 1}); auto a0 = p2a(obj.get(), p0); auto a1 = p2b(obj.get(), p1); // hint runtime this is a 1bit value. - a1 = lshift_b(obj.get(), a1, K - 1); - a1 = rshift_b(obj.get(), a1, K - 1); + a1 = lshift_b(obj.get(), a1, {K - 1}); + a1 = rshift_b(obj.get(), a1, {K - 1}); /* WHEN */ auto prev = obj->prot()->getState()->getStats(); @@ -238,12 +238,12 @@ TEST_P(ArithmeticTest, MulAV) { return; } - const size_t K = spu::SizeOf(conf.field()) * 8; + const int64_t K = spu::SizeOf(conf.field()) * 8; /* GIVEN */ auto p0 = rand_p(obj.get(), kShape); auto p1 = rand_p(obj.get(), kShape); - p1 = rshift_p(obj.get(), p1, K - 1); + p1 = rshift_p(obj.get(), p1, {K - 1}); auto a0 = p2a(obj.get(), p0); auto a1 = p2v(obj.get(), p1, 0); @@ -275,17 +275,17 @@ TEST_P(ArithmeticTest, MulA1BV) { return; } - const size_t K = spu::SizeOf(conf.field()) * 8; + const int64_t K = spu::SizeOf(conf.field()) * 8; /* GIVEN */ auto p0 = rand_p(obj.get(), kShape); auto p1 = rand_p(obj.get(), kShape); - p1 = rshift_p(obj.get(), p1, K - 1); + p1 = rshift_p(obj.get(), p1, {K - 1}); auto a0 = p2a(obj.get(), p0); auto a1 = p2v(obj.get(), p1, 0); // hint runtime this is a 1bit value. - a1 = lshift_v(obj.get(), a1, K - 1); - a1 = rshift_v(obj.get(), a1, K - 1); + a1 = lshift_v(obj.get(), a1, {K - 1}); + a1 = rshift_v(obj.get(), a1, {K - 1}); // auto a1 = b2v(obj.get(), _a1, 0); /* WHEN */ @@ -482,10 +482,10 @@ TEST_P(ArithmeticTest, LShiftA) { } /* WHEN */ auto prev = obj->prot()->getState()->getStats(); - auto tmp = lshift_a(obj.get(), a0, bits); + auto tmp = lshift_a(obj.get(), a0, {static_cast(bits)}); auto cost = obj->prot()->getState()->getStats() - prev; auto r_b = a2p(obj.get(), tmp); - auto r_p = lshift_p(obj.get(), p0, bits); + auto r_p = lshift_p(obj.get(), p0, {static_cast(bits)}); /* THEN */ EXPECT_VALUE_EQ(r_b, r_p); @@ -513,10 +513,11 @@ TEST_P(ArithmeticTest, TruncA) { if (!kernel->hasMsbError()) { // trunc requires MSB to be zero. - p0 = arshift_p(obj.get(), p0, 1); + p0 = arshift_p(obj.get(), p0, {1}); } else { // has msb error, only use lowest 10 bits. 
- p0 = arshift_p(obj.get(), p0, SizeOf(conf.field()) * 8 - 10); + p0 = arshift_p(obj.get(), p0, + {static_cast(SizeOf(conf.field()) * 8 - 10)}); } /* GIVEN */ @@ -529,7 +530,7 @@ TEST_P(ArithmeticTest, TruncA) { auto cost = obj->prot()->getState()->getStats() - prev; auto r_a = a2p(obj.get(), a1); - auto r_p = arshift_p(obj.get(), p0, bits); + auto r_p = arshift_p(obj.get(), p0, {static_cast(bits)}); /* THEN */ EXPECT_VALUE_ALMOST_EQ(r_a, r_p, npc); @@ -674,11 +675,11 @@ TEST_BOOLEAN_BINARY_OP(xor) } \ /* WHEN */ \ auto prev = obj->prot()->getState()->getStats(); \ - auto tmp = OP##_b(obj.get(), b0, bits); \ + auto tmp = OP##_b(obj.get(), b0, {static_cast(bits)}); \ auto cost = \ obj->prot()->getState()->getStats() - prev; \ auto r_b = b2p(obj.get(), tmp); \ - auto r_p = OP##_p(obj.get(), p0, bits); \ + auto r_p = OP##_p(obj.get(), p0, {static_cast(bits)}); \ \ /* THEN */ \ EXPECT_VALUE_EQ(r_b, r_p); \ @@ -837,7 +838,7 @@ TEST_P(ConversionTest, MSB) { // SECURENN has an msb input range here if (conf.protocol() == ProtocolKind::SECURENN) { - p0 = arshift_p(obj.get(), p0, 1); + p0 = arshift_p(obj.get(), p0, {1}); } auto a0 = p2a(obj.get(), p0); @@ -850,8 +851,10 @@ TEST_P(ConversionTest, MSB) { /* THEN */ EXPECT_TRUE(verifyCost(obj->prot()->getKernel("msb_a2b"), "msb_a2b", conf.field(), kShape, npc, cost)); - EXPECT_VALUE_EQ(rshift_p(obj.get(), p0, SizeOf(conf.field()) * 8 - 1), - b2p(obj.get(), b1)); + EXPECT_VALUE_EQ( + rshift_p(obj.get(), p0, + {static_cast(SizeOf(conf.field()) * 8 - 1)}), + b2p(obj.get(), b1)); }); } diff --git a/libspu/mpc/aby3/arithmetic.cc b/libspu/mpc/aby3/arithmetic.cc index a23749ff..6c365220 100644 --- a/libspu/mpc/aby3/arithmetic.cc +++ b/libspu/mpc/aby3/arithmetic.cc @@ -41,7 +41,7 @@ std::vector a1b_offline(size_t sender, const NdArrayRef& a, auto numel = a.numel(); - return DISPATCH_ALL_FIELDS(field, "_", [&]() -> std::vector { + return DISPATCH_ALL_FIELDS(field, [&]() -> std::vector { using ashr_el_t = ring2k_t; NdArrayRef m0(makeType(field), a.shape()); NdArrayRef m1(makeType(field), a.shape()); @@ -52,7 +52,7 @@ std::vector a1b_offline(size_t sender, const NdArrayRef& a, NdArrayView _m1(m1); return DISPATCH_UINT_PT_TYPES( - b.eltype().as()->getBacktype(), "_", + b.eltype().as()->getBacktype(), [&]() -> std::vector { using bshr_t = std::array; if (self_rank == sender) { @@ -82,8 +82,6 @@ std::vector a1b_offline(size_t sender, const NdArrayRef& a, return {c1, c2, m0, m1}; } else if (self_rank == (sender + 1) % 3) { - prg_state->genPrssPair(field, a.shape(), - PrgState::GenPrssCtrl::None); auto c1 = prg_state ->genPrssPair(field, a.shape(), PrgState::GenPrssCtrl::First) @@ -95,8 +93,6 @@ std::vector a1b_offline(size_t sender, const NdArrayRef& a, ->genPrssPair(field, a.shape(), PrgState::GenPrssCtrl::Second) .second; - prg_state->genPrssPair(field, a.shape(), - PrgState::GenPrssCtrl::None); return {c2}; } @@ -109,7 +105,7 @@ std::vector ring_cast_boolean(const NdArrayRef& x) { const size_t numel = x.numel(); std::vector res(numel); - DISPATCH_UINT_PT_TYPES(x.eltype().as()->pt_type(), "_", [&]() { + DISPATCH_UINT_PT_TYPES(x.eltype().as()->pt_type(), [&]() { NdArrayView _x(x); pforeach(0, numel, [&](int64_t idx) { res[idx] = static_cast(_x[idx] & 0x1); @@ -127,7 +123,7 @@ NdArrayRef RandA::proc(KernelEvalContext* ctx, const Shape& shape) const { NdArrayRef out(makeType(field), shape); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using el_t = ring2k_t; std::vector r0(shape.numel()); @@ -153,7 +149,7 @@ NdArrayRef 
A2P::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { const auto field = in.eltype().as()->field(); auto numel = in.numel(); - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using pshr_el_t = ring2k_t; using ashr_el_t = ring2k_t; using ashr_t = std::array; @@ -184,7 +180,7 @@ NdArrayRef P2A::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { auto rank = comm->getRank(); - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using ashr_el_t = ring2k_t; using ashr_t = std::array; using pshr_el_t = ring2k_t; @@ -226,7 +222,7 @@ NdArrayRef A2V::proc(KernelEvalContext* ctx, const NdArrayRef& in, auto* comm = ctx->getState(); const auto field = in.eltype().as()->field(); - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using vshr_el_t = ring2k_t; using ashr_el_t = ring2k_t; using ashr_t = std::array; @@ -267,7 +263,7 @@ NdArrayRef V2A::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { size_t owner_rank = in_ty->owner(); - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using ashr_el_t = ring2k_t; using ashr_t = std::array; @@ -311,7 +307,7 @@ NdArrayRef NotA::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { auto rank = comm->getRank(); - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using el_t = std::make_unsigned_t; using shr_t = std::array; @@ -349,7 +345,7 @@ NdArrayRef AddAP::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, auto rank = comm->getRank(); - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using el_t = ring2k_t; using shr_t = std::array; @@ -376,7 +372,7 @@ NdArrayRef AddAA::proc(KernelEvalContext*, const NdArrayRef& lhs, SPU_ENFORCE(lhs_ty->field() == rhs_ty->field()); const auto field = lhs_ty->field(); - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using shr_t = std::array; NdArrayRef out(makeType(field), lhs.shape()); @@ -403,7 +399,7 @@ NdArrayRef MulAP::proc(KernelEvalContext*, const NdArrayRef& lhs, SPU_ENFORCE(lhs_ty->field() == rhs_ty->field()); const auto field = lhs_ty->field(); - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using el_t = ring2k_t; using shr_t = std::array; @@ -426,7 +422,7 @@ NdArrayRef MulAA::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, auto* comm = ctx->getState(); auto* prg_state = ctx->getState(); - return DISPATCH_ALL_FIELDS(field, "aby3.mulAA", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using el_t = ring2k_t; using shr_t = std::array; @@ -575,7 +571,7 @@ NdArrayRef MulA1B::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, .reshape(r2.first.shape()); } - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { // using = ring2k_t; NdArrayView r1_0(r1.first); NdArrayView r1_1(r1.second); @@ -675,11 +671,12 @@ NdArrayRef MatMulAA::proc(KernelEvalContext* ctx, const NdArrayRef& x, } NdArrayRef LShiftA::proc(KernelEvalContext*, const NdArrayRef& in, - size_t bits) const { + const Sizes& bits) const { const auto* in_ty = in.eltype().as(); const auto field = in_ty->field(); + bool is_splat = bits.size() == 1; - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using shr_t = std::array; NdArrayRef out(makeType(field), in.shape()); @@ -687,8 +684,9 @@ NdArrayRef LShiftA::proc(KernelEvalContext*, 
const NdArrayRef& in, NdArrayView _in(in); pforeach(0, in.numel(), [&](int64_t idx) { - _out[idx][0] = _in[idx][0] << bits; - _out[idx][1] = _in[idx][1] << bits; + auto shift_bit = is_splat ? bits[0] : bits[idx]; + _out[idx][0] = _in[idx][0] << shift_bit; + _out[idx][1] = _in[idx][1] << shift_bit; }); return out; @@ -722,22 +720,23 @@ NdArrayRef TruncA::proc(KernelEvalContext* ctx, const NdArrayRef& in, comm->addCommStatsManually(1, kComm); // comm => 1, 2 // ret + const Sizes shift_bit = {static_cast(bits)}; switch (comm->getRank()) { case 0: { - const auto z1 = ring_arshift(x1, bits); + const auto z1 = ring_arshift(x1, shift_bit); const auto z2 = comm->recv(1, x1.eltype(), kBindName); return makeAShare(z1, z2, field); } case 1: { auto r1 = r_future.get().second; - const auto z1 = ring_sub(ring_arshift(ring_add(x1, x2), bits), r1); + const auto z1 = ring_sub(ring_arshift(ring_add(x1, x2), shift_bit), r1); comm->sendAsync(0, z1, kBindName); return makeAShare(z1, r1, field); } case 2: { - const auto z2 = ring_arshift(x2, bits); + const auto z2 = ring_arshift(x2, shift_bit); return makeAShare(r_future.get().first, z2, field); } @@ -790,7 +789,7 @@ NdArrayRef TruncAPr::proc(KernelEvalContext* ctx, const NdArrayRef& in, size_t P2 = (pivot + 2) % 3; NdArrayRef out(in.eltype(), in.shape()); - DISPATCH_ALL_FIELDS(field, "aby3.truncpr", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using el_t = ring2k_t; using shr_t = std::array; diff --git a/libspu/mpc/aby3/arithmetic.h b/libspu/mpc/aby3/arithmetic.h index 75135fa7..94929715 100644 --- a/libspu/mpc/aby3/arithmetic.h +++ b/libspu/mpc/aby3/arithmetic.h @@ -243,7 +243,7 @@ class LShiftA : public ShiftKernel { ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const override; + const Sizes& bits) const override; }; // Refer to: diff --git a/libspu/mpc/aby3/boolean.cc b/libspu/mpc/aby3/boolean.cc index 8e07307a..96033cc5 100644 --- a/libspu/mpc/aby3/boolean.cc +++ b/libspu/mpc/aby3/boolean.cc @@ -42,11 +42,11 @@ void CommonTypeB::evaluate(KernelEvalContext* ctx) const { NdArrayRef CastTypeB::proc(KernelEvalContext*, const NdArrayRef& in, const Type& to_type) const { NdArrayRef out(to_type, in.shape()); - DISPATCH_UINT_PT_TYPES(in.eltype().as()->getBacktype(), "_", [&]() { + DISPATCH_UINT_PT_TYPES(in.eltype().as()->getBacktype(), [&]() { using in_el_t = ScalarT; using in_shr_t = std::array; - DISPATCH_UINT_PT_TYPES(to_type.as()->getBacktype(), "_", [&]() { + DISPATCH_UINT_PT_TYPES(to_type.as()->getBacktype(), [&]() { using out_el_t = ScalarT; using out_shr_t = std::array; @@ -69,11 +69,11 @@ NdArrayRef B2P::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { const PtType btype = in.eltype().as()->getBacktype(); const auto field = ctx->getState()->getDefaultField(); - return DISPATCH_UINT_PT_TYPES(btype, "aby3.b2p", [&]() { + return DISPATCH_UINT_PT_TYPES(btype, [&]() { using bshr_el_t = ScalarT; using bshr_t = std::array; - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using pshr_el_t = ring2k_t; NdArrayRef out(makeType(field), in.shape()); @@ -99,12 +99,12 @@ NdArrayRef P2B::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { const auto* in_ty = in.eltype().as(); const auto field = in_ty->field(); - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { const size_t nbits = maxBitWidth(in); const PtType btype = calcBShareBacktype(nbits); NdArrayView _in(in); - return 
DISPATCH_UINT_PT_TYPES(btype, "_", [&]() { + return DISPATCH_UINT_PT_TYPES(btype, [&]() { using bshr_el_t = ScalarT; using bshr_t = std::array; @@ -134,12 +134,12 @@ NdArrayRef B2V::proc(KernelEvalContext* ctx, const NdArrayRef& in, const PtType btype = in.eltype().as()->getBacktype(); const auto field = ctx->getState()->getDefaultField(); - return DISPATCH_UINT_PT_TYPES(btype, "aby3.b2v", [&]() { + return DISPATCH_UINT_PT_TYPES(btype, [&]() { using bshr_el_t = ScalarT; using bshr_t = std::array; NdArrayView _in(in); - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using vshr_scalar_t = ring2k_t; auto out_ty = makeType(field, rank); @@ -177,7 +177,7 @@ NdArrayRef AndBP::proc(KernelEvalContext*, const NdArrayRef& lhs, const auto* lhs_ty = lhs.eltype().as(); const auto* rhs_ty = rhs.eltype().as(); - return DISPATCH_ALL_FIELDS(rhs_ty->field(), "_", [&]() { + return DISPATCH_ALL_FIELDS(rhs_ty->field(), [&]() { using rhs_scalar_t = ring2k_t; const size_t rhs_nbits = maxBitWidth(rhs); @@ -186,13 +186,13 @@ NdArrayRef AndBP::proc(KernelEvalContext*, const NdArrayRef& lhs, NdArrayView _rhs(rhs); - return DISPATCH_UINT_PT_TYPES(lhs_ty->getBacktype(), "_", [&]() { + return DISPATCH_UINT_PT_TYPES(lhs_ty->getBacktype(), [&]() { using lhs_el_t = ScalarT; using lhs_shr_t = std::array; NdArrayView _lhs(lhs); - return DISPATCH_UINT_PT_TYPES(out_btype, "_", [&]() { + return DISPATCH_UINT_PT_TYPES(out_btype, [&]() { using out_el_t = ScalarT; using out_shr_t = std::array; @@ -224,17 +224,17 @@ NdArrayRef AndBB::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, const PtType out_btype = calcBShareBacktype(out_nbits); NdArrayRef out(makeType(out_btype, out_nbits), lhs.shape()); - return DISPATCH_UINT_PT_TYPES(rhs_ty->getBacktype(), "_", [&]() { + return DISPATCH_UINT_PT_TYPES(rhs_ty->getBacktype(), [&]() { using rhs_el_t = ScalarT; using rhs_shr_t = std::array; NdArrayView _rhs(rhs); - return DISPATCH_UINT_PT_TYPES(lhs_ty->getBacktype(), "_", [&]() { + return DISPATCH_UINT_PT_TYPES(lhs_ty->getBacktype(), [&]() { using lhs_el_t = ScalarT; using lhs_shr_t = std::array; NdArrayView _lhs(lhs); - return DISPATCH_UINT_PT_TYPES(out_btype, "_", [&]() { + return DISPATCH_UINT_PT_TYPES(out_btype, [&]() { using out_el_t = ScalarT; using out_shr_t = std::array; @@ -269,7 +269,7 @@ NdArrayRef XorBP::proc(KernelEvalContext*, const NdArrayRef& lhs, const auto* lhs_ty = lhs.eltype().as(); const auto* rhs_ty = rhs.eltype().as(); - return DISPATCH_ALL_FIELDS(rhs_ty->field(), "_", [&]() { + return DISPATCH_ALL_FIELDS(rhs_ty->field(), [&]() { using rhs_scalar_t = ring2k_t; const size_t rhs_nbits = maxBitWidth(rhs); @@ -280,13 +280,13 @@ NdArrayRef XorBP::proc(KernelEvalContext*, const NdArrayRef& lhs, NdArrayRef out(makeType(out_btype, out_nbits), lhs.shape()); - return DISPATCH_UINT_PT_TYPES(lhs_ty->getBacktype(), "_", [&]() { + return DISPATCH_UINT_PT_TYPES(lhs_ty->getBacktype(), [&]() { using lhs_el_t = ScalarT; using lhs_shr_t = std::array; NdArrayView _lhs(lhs); - return DISPATCH_UINT_PT_TYPES(out_btype, "_", [&]() { + return DISPATCH_UINT_PT_TYPES(out_btype, [&]() { using out_el_t = ScalarT; using out_shr_t = std::array; @@ -311,19 +311,19 @@ NdArrayRef XorBB::proc(KernelEvalContext*, const NdArrayRef& lhs, const size_t out_nbits = std::max(lhs_ty->nbits(), rhs_ty->nbits()); const PtType out_btype = calcBShareBacktype(out_nbits); - return DISPATCH_UINT_PT_TYPES(rhs_ty->getBacktype(), "_", [&]() { + return DISPATCH_UINT_PT_TYPES(rhs_ty->getBacktype(), [&]() { using rhs_el_t = 
ScalarT; using rhs_shr_t = std::array; NdArrayView _rhs(rhs); - return DISPATCH_UINT_PT_TYPES(lhs_ty->getBacktype(), "_", [&]() { + return DISPATCH_UINT_PT_TYPES(lhs_ty->getBacktype(), [&]() { using lhs_el_t = ScalarT; using lhs_shr_t = std::array; NdArrayView _lhs(lhs); - return DISPATCH_UINT_PT_TYPES(out_btype, "_", [&]() { + return DISPATCH_UINT_PT_TYPES(out_btype, [&]() { using out_el_t = ScalarT; using out_shr_t = std::array; @@ -343,21 +343,24 @@ NdArrayRef XorBB::proc(KernelEvalContext*, const NdArrayRef& lhs, } NdArrayRef LShiftB::proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const { + const Sizes& bits) const { const auto* in_ty = in.eltype().as(); // TODO: the hal dtype should tell us about the max number of possible bits. const auto field = ctx->getState()->getDefaultField(); - const size_t out_nbits = std::min(in_ty->nbits() + bits, SizeOf(field) * 8); + const size_t out_nbits = std::min( + in_ty->nbits() + *std::max_element(bits.begin(), bits.end()), + SizeOf(field) * 8); const PtType out_btype = calcBShareBacktype(out_nbits); + bool is_splat = bits.size() == 1; - return DISPATCH_UINT_PT_TYPES(in_ty->getBacktype(), "_", [&]() { + return DISPATCH_UINT_PT_TYPES(in_ty->getBacktype(), [&]() { using in_el_t = ScalarT; using in_shr_t = std::array; NdArrayView _in(in); - return DISPATCH_UINT_PT_TYPES(out_btype, "_", [&]() { + return DISPATCH_UINT_PT_TYPES(out_btype, [&]() { using out_el_t = ScalarT; using out_shr_t = std::array; @@ -366,8 +369,9 @@ NdArrayRef LShiftB::proc(KernelEvalContext* ctx, const NdArrayRef& in, pforeach(0, in.numel(), [&](int64_t idx) { const auto& v = _in[idx]; - _out[idx][0] = static_cast(v[0]) << bits; - _out[idx][1] = static_cast(v[1]) << bits; + auto shift_bit = is_splat ? bits[0] : bits[idx]; + _out[idx][0] = static_cast(v[0]) << shift_bit; + _out[idx][1] = static_cast(v[1]) << shift_bit; }); return out; @@ -376,19 +380,19 @@ NdArrayRef LShiftB::proc(KernelEvalContext* ctx, const NdArrayRef& in, } NdArrayRef RShiftB::proc(KernelEvalContext*, const NdArrayRef& in, - size_t bits) const { + const Sizes& bits) const { const auto* in_ty = in.eltype().as(); - bits = std::min(in_ty->nbits(), bits); - size_t out_nbits = in_ty->nbits(); - out_nbits -= std::min(out_nbits, bits); + int64_t out_nbits = in_ty->nbits(); + out_nbits -= std::min(out_nbits, *std::min_element(bits.begin(), bits.end())); const PtType out_btype = calcBShareBacktype(out_nbits); + bool is_splat = bits.size() == 1; - return DISPATCH_UINT_PT_TYPES(in_ty->getBacktype(), "_", [&]() { + return DISPATCH_UINT_PT_TYPES(in_ty->getBacktype(), [&]() { using in_shr_t = std::array; NdArrayView _in(in); - return DISPATCH_UINT_PT_TYPES(out_btype, "_", [&]() { + return DISPATCH_UINT_PT_TYPES(out_btype, [&]() { using out_el_t = ScalarT; using out_shr_t = std::array; @@ -397,8 +401,9 @@ NdArrayRef RShiftB::proc(KernelEvalContext*, const NdArrayRef& in, pforeach(0, in.numel(), [&](int64_t idx) { const auto& v = _in[idx]; - _out[idx][0] = static_cast(v[0] >> bits); - _out[idx][1] = static_cast(v[1] >> bits); + auto shift_bit = is_splat ? 
bits[0] : bits[idx]; + _out[idx][0] = static_cast(v[0] >> shift_bit); + _out[idx][1] = static_cast(v[1] >> shift_bit); }); return out; @@ -407,9 +412,10 @@ NdArrayRef RShiftB::proc(KernelEvalContext*, const NdArrayRef& in, } NdArrayRef ARShiftB::proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const { + const Sizes& bits) const { const auto field = ctx->getState()->getDefaultField(); const auto* in_ty = in.eltype().as(); + bool is_splat = bits.size() == 1; // arithmetic right shift expects to work on ring, or the behaviour is // undefined. @@ -418,7 +424,7 @@ NdArrayRef ARShiftB::proc(KernelEvalContext* ctx, const NdArrayRef& in, const PtType out_btype = in_ty->getBacktype(); const size_t out_nbits = in_ty->nbits(); - return DISPATCH_UINT_PT_TYPES(in_ty->getBacktype(), "_", [&]() { + return DISPATCH_UINT_PT_TYPES(in_ty->getBacktype(), [&]() { using el_t = std::make_signed_t; using shr_t = std::array; @@ -428,8 +434,9 @@ NdArrayRef ARShiftB::proc(KernelEvalContext* ctx, const NdArrayRef& in, pforeach(0, in.numel(), [&](int64_t idx) { const auto& v = _in[idx]; - _out[idx][0] = v[0] >> bits; - _out[idx][1] = v[1] >> bits; + auto shift_bit = is_splat ? bits[0] : bits[idx]; + _out[idx][0] = v[0] >> shift_bit; + _out[idx][1] = v[1] >> shift_bit; }); return out; @@ -444,13 +451,13 @@ NdArrayRef BitrevB::proc(KernelEvalContext*, const NdArrayRef& in, size_t start, const size_t out_nbits = std::max(in_ty->nbits(), end); const PtType out_btype = calcBShareBacktype(out_nbits); - return DISPATCH_UINT_PT_TYPES(in_ty->getBacktype(), "_", [&]() { + return DISPATCH_UINT_PT_TYPES(in_ty->getBacktype(), [&]() { using in_el_t = ScalarT; using in_shr_t = std::array; NdArrayView _in(in); - return DISPATCH_UINT_PT_TYPES(out_btype, "_", [&]() { + return DISPATCH_UINT_PT_TYPES(out_btype, [&]() { using out_el_t = ScalarT; using out_shr_t = std::array; @@ -488,7 +495,7 @@ NdArrayRef BitIntlB::proc(KernelEvalContext*, const NdArrayRef& in, SPU_ENFORCE(absl::has_single_bit(nbits)); NdArrayRef out(in.eltype(), in.shape()); - DISPATCH_UINT_PT_TYPES(in_ty->getBacktype(), "_", [&]() { + DISPATCH_UINT_PT_TYPES(in_ty->getBacktype(), [&]() { using el_t = ScalarT; using shr_t = std::array; NdArrayView _out(out); @@ -511,7 +518,7 @@ NdArrayRef BitDeintlB::proc(KernelEvalContext*, const NdArrayRef& in, SPU_ENFORCE(absl::has_single_bit(nbits)); NdArrayRef out(in.eltype(), in.shape()); - DISPATCH_UINT_PT_TYPES(in_ty->getBacktype(), "_", [&]() { + DISPATCH_UINT_PT_TYPES(in_ty->getBacktype(), [&]() { using el_t = ScalarT; using shr_t = std::array; NdArrayView _out(out); diff --git a/libspu/mpc/aby3/boolean.h b/libspu/mpc/aby3/boolean.h index b3036620..9fd83b45 100644 --- a/libspu/mpc/aby3/boolean.h +++ b/libspu/mpc/aby3/boolean.h @@ -152,7 +152,7 @@ class LShiftB : public ShiftKernel { ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const override; + const Sizes& bits) const override; }; class RShiftB : public ShiftKernel { @@ -164,7 +164,7 @@ class RShiftB : public ShiftKernel { ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const override; + const Sizes& bits) const override; }; class ARShiftB : public ShiftKernel { @@ -176,7 +176,7 @@ class ARShiftB : public ShiftKernel { ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const override; + const Sizes& bits) 
const override; }; class BitrevB : public BitrevKernel { diff --git a/libspu/mpc/aby3/conversion.cc b/libspu/mpc/aby3/conversion.cc index 070450c5..8154acac 100644 --- a/libspu/mpc/aby3/conversion.cc +++ b/libspu/mpc/aby3/conversion.cc @@ -65,11 +65,11 @@ NdArrayRef A2B::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { auto numel = in.numel(); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using ashr_t = std::array; NdArrayView _in(in); - DISPATCH_UINT_PT_TYPES(out_btype, "_", [&]() { + DISPATCH_UINT_PT_TYPES(out_btype, [&]() { using bshr_el_t = ScalarT; using bshr_t = std::array; @@ -152,7 +152,7 @@ NdArrayRef B2AByPPA::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { if (in_nbits == 0) { // special case, it's known to be zero. - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView> _out(out); pforeach(0, numel, [&](int64_t idx) { _out[idx][0] = 0; @@ -165,11 +165,11 @@ NdArrayRef B2AByPPA::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { auto* comm = ctx->getState(); auto* prg_state = ctx->getState(); - DISPATCH_UINT_PT_TYPES(in_ty->getBacktype(), "_", [&]() { + DISPATCH_UINT_PT_TYPES(in_ty->getBacktype(), [&]() { using bshr_t = std::array; NdArrayView _in(in); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using ashr_el_t = ring2k_t; using ashr_t = std::array; @@ -324,7 +324,7 @@ NdArrayRef B2AByOT::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { if (in_nbits == 0) { // special case, it's known to be zero. - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView> _out(out); pforeach(0, numel, [&](int64_t idx) { _out[idx][0] = 0; @@ -345,12 +345,12 @@ NdArrayRef B2AByOT::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { size_t P1 = (pivot + 1) % 3; size_t P2 = (pivot + 2) % 3; - DISPATCH_UINT_PT_TYPES(in_ty->getBacktype(), "_", [&]() { + DISPATCH_UINT_PT_TYPES(in_ty->getBacktype(), [&]() { using bshr_el_t = ScalarT; using bshr_t = std::array; NdArrayView _in(in); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using ashr_el_t = ring2k_t; using ashr_t = std::array; @@ -391,11 +391,6 @@ NdArrayRef B2AByOT::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { }); } else if (comm->getRank() == P1) { // the receiver - prg_state->fillPrssPair(nullptr, nullptr, r0.size(), - PrgState::GenPrssCtrl::None); - prg_state->fillPrssPair(nullptr, nullptr, r0.size(), - PrgState::GenPrssCtrl::None); - auto b2 = bitDecompose(getShare(in, 0), in_nbits); // ot.recv @@ -494,12 +489,12 @@ NdArrayRef B2AByOT::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { NdArrayRef lo(out_type, in.shape()); NdArrayRef hi(out_type, in.shape()); - DISPATCH_UINT_PT_TYPES(in_ty->getBacktype(), "_", [&]() { + DISPATCH_UINT_PT_TYPES(in_ty->getBacktype(), [&]() { using in_el_t = ScalarT; using in_shr_t = std::array; NdArrayView _in(in); - DISPATCH_UINT_PT_TYPES(out_backtype, "_", [&]() { + DISPATCH_UINT_PT_TYPES(out_backtype, [&]() { using out_el_t = ScalarT; using out_shr_t = std::array; @@ -574,7 +569,7 @@ NdArrayRef MsbA2B::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { makeType(GetStorageType(field), SizeOf(field) * 8); NdArrayRef m(bshr_type, in.shape()); NdArrayRef n(bshr_type, in.shape()); - DISPATCH_ALL_FIELDS(field, "aby3.msb.split", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using el_t = ring2k_t; using shr_t = std::array; @@ -620,7 +615,9 @@ NdArrayRef 
MsbA2B::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { // Compute the k'th bit. // (m^n)[k] ^ carry - auto msb = xor_bb(sctx, rshift_b(sctx, xor_bb(sctx, wrap_m, wrap_n), nbits), + auto msb = xor_bb(sctx, + rshift_b(sctx, xor_bb(sctx, wrap_m, wrap_n), + {static_cast(nbits)}), carry); return UnwrapValue(msb); @@ -660,10 +657,10 @@ NdArrayRef eqz(KernelEvalContext* ctx, const NdArrayRef& in) { size_t P1 = (pivot + 1) % 3; size_t P2 = (pivot + 2) % 3; - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using ashr_el_t = ring2k_t; using ashr_t = std::array; - DISPATCH_UINT_PT_TYPES(in_bshr_btype, "_", [&]() { + DISPATCH_UINT_PT_TYPES(in_bshr_btype, [&]() { using bshr_el_t = ScalarT; std::vector zero_flag_3pc_0(numel); std::vector zero_flag_3pc_1(numel); @@ -715,10 +712,6 @@ NdArrayRef eqz(KernelEvalContext* ctx, const NdArrayRef& in) { PrgState::GenPrssCtrl::First); } else { pforeach(0, numel, [&](int64_t idx) { a_s[idx] = _in[idx][1]; }); - prg_state->fillPrssPair({}, {}, numel, - PrgState::GenPrssCtrl::None); - prg_state->fillPrssPair({}, {}, numel, - PrgState::GenPrssCtrl::None); r_arith = comm->recv(P0, "r_arith"); r_bool = comm->recv(P0, "r_bool"); } @@ -754,8 +747,6 @@ NdArrayRef eqz(KernelEvalContext* ctx, const NdArrayRef& in) { // P1 zero_flag = (not(c_p xor [r]b0)^ rz, rb1) pforeach(0, numel, [&](int64_t idx) { zero_flag_3pc_1[idx] = r_bool[idx]; }); - prg_state->fillPrssPair({}, {}, numel, - PrgState::GenPrssCtrl::None); auto flag_split = comm->recv(P1, "flag_split"); pforeach(0, numel, [&](int64_t idx) { @@ -866,7 +857,7 @@ NdArrayRef EqualAA::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, const auto field = lhs_ty->field(); NdArrayRef out(makeType(field), lhs.shape()); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using shr_t = std::array; NdArrayView _out(out); NdArrayView _lhs(lhs); @@ -893,7 +884,7 @@ NdArrayRef EqualAP::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, auto rank = comm->getRank(); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using el_t = ring2k_t; using shr_t = std::array; diff --git a/libspu/mpc/aby3/io.cc b/libspu/mpc/aby3/io.cc index bc3be20c..3ded34a9 100644 --- a/libspu/mpc/aby3/io.cc +++ b/libspu/mpc/aby3/io.cc @@ -148,7 +148,7 @@ NdArrayRef Aby3Io::fromShares(const std::vector& shares) const { SPU_ENFORCE(field_ == eltype.as()->field()); NdArrayRef out(makeType(field_), shares[0].shape()); - DISPATCH_ALL_FIELDS(field_, "_", [&]() { + DISPATCH_ALL_FIELDS(field_, [&]() { using el_t = ring2k_t; using shr_t = std::array; NdArrayView _out(out); @@ -166,10 +166,10 @@ NdArrayRef Aby3Io::fromShares(const std::vector& shares) const { } else if (eltype.isa()) { NdArrayRef out(makeType(field_), shares[0].shape()); - DISPATCH_ALL_FIELDS(field_, "_", [&]() { + DISPATCH_ALL_FIELDS(field_, [&]() { NdArrayView _out(out); - DISPATCH_UINT_PT_TYPES(eltype.as()->getBacktype(), "_", [&] { + DISPATCH_UINT_PT_TYPES(eltype.as()->getBacktype(), [&] { using shr_t = std::array; for (size_t si = 0; si < shares.size(); si++) { NdArrayView _s(shares[si]); diff --git a/libspu/mpc/aby3/oram.cc b/libspu/mpc/aby3/oram.cc index 2ccf83f0..395a05ba 100644 --- a/libspu/mpc/aby3/oram.cc +++ b/libspu/mpc/aby3/oram.cc @@ -37,7 +37,7 @@ NdArrayRef OramOneHotAA::proc(KernelEvalContext *ctx, const NdArrayRef &in, const auto field = eltype.as()->field(); NdArrayRef out(makeType(field), {s}); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using el_t = 
ring2k_t; using shr_t = std::array; NdArrayView out_(out); @@ -91,7 +91,7 @@ NdArrayRef OramOneHotAP::proc(KernelEvalContext *ctx, const NdArrayRef &in, const auto numel = in.numel(); NdArrayRef out(makeType(field), {s}); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using el_t = ring2k_t; using shr_t = std::array; NdArrayView out_(out); @@ -150,7 +150,7 @@ NdArrayRef OramReadOA::proc(KernelEvalContext *ctx, const NdArrayRef &onehot, NdArrayRef out(makeType(field), {1, index_times}); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using el_t = ring2k_t; using shr_t = std::array; @@ -208,7 +208,7 @@ NdArrayRef OramReadOP::proc(KernelEvalContext *ctx, const NdArrayRef &onehot, auto o2 = getSecondShare(out); int64_t db_numel = onehot.numel(); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using el_t = ring2k_t; using shr_t = std::array; @@ -286,15 +286,11 @@ Triple> genOramBeaverPrim(KernelEvalContext *ctx, int64_t num, std::vector beaver_triple(num * 3); if (comm->getRank() == adjust_rank) { - prg->fillPrssPair(nullptr, nullptr, num * 3, - PrgState::GenPrssCtrl::None); prg->fillPrssPair(nullptr, beaver_triple.data(), num * 3, PrgState::GenPrssCtrl::Second); } else { prg->fillPrssPair(beaver_triple.data(), nullptr, num * 3, PrgState::GenPrssCtrl::First); - prg->fillPrssPair(nullptr, nullptr, num * 3, - PrgState::GenPrssCtrl::None); } std::vector a(beaver_triple.begin(), beaver_triple.begin() + num); diff --git a/libspu/mpc/aby3/ot.cc b/libspu/mpc/aby3/ot.cc index 122e0fa3..8f5f4863 100644 --- a/libspu/mpc/aby3/ot.cc +++ b/libspu/mpc/aby3/ot.cc @@ -71,8 +71,6 @@ std::pair Ot3::genMasks() { } } else { SPU_ENFORCE(comm_->getRank() == roles_.receiver); - prg_state_->genPrssPair(field_, shape_, PrgState::GenPrssCtrl::None); - prg_state_->genPrssPair(field_, shape_, PrgState::GenPrssCtrl::None); } return {w0, w1}; diff --git a/libspu/mpc/aby3/permute.cc b/libspu/mpc/aby3/permute.cc index edbeed7f..fb80bfd0 100644 --- a/libspu/mpc/aby3/permute.cc +++ b/libspu/mpc/aby3/permute.cc @@ -30,7 +30,7 @@ PermVector ring2pv(const NdArrayRef& x) { x.eltype()); const auto field = x.eltype().as()->field(); PermVector pv(x.numel()); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _x(x); pforeach(0, x.numel(), [&](int64_t idx) { pv[idx] = int64_t(_x[idx]); }); }); @@ -54,7 +54,7 @@ NdArrayRef RandPermM::proc(KernelEvalContext* ctx, const Shape& shape) const { const auto field = out.eltype().as()->field(); auto out1 = getFirstShare(out); auto out2 = getSecondShare(out); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _out1(out1); NdArrayView _out2(out2); pforeach(0, out.numel(), [&](int64_t idx) { @@ -78,7 +78,7 @@ NdArrayRef PermAM::proc(KernelEvalContext* ctx, const NdArrayRef& in, PermVector pv_next = ring2pv(getSecondShare(perm)); NdArrayRef out(in.eltype(), in.shape()); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using el_t = ring2k_t; using shr_t = std::array; @@ -198,7 +198,7 @@ NdArrayRef InvPermAM::proc(KernelEvalContext* ctx, const NdArrayRef& in, PermVector pv_next = ring2pv(getSecondShare(perm)); NdArrayRef out(in.eltype(), in.shape()); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using el_t = ring2k_t; using shr_t = std::array; diff --git a/libspu/mpc/aby3/value.h b/libspu/mpc/aby3/value.h index fbb413a0..8396b121 100644 --- a/libspu/mpc/aby3/value.h +++ 
b/libspu/mpc/aby3/value.h @@ -58,7 +58,7 @@ std::vector getShareAs(const NdArrayRef& in, size_t share_idx) { auto numel = in.numel(); std::vector res(numel); - DISPATCH_UINT_PT_TYPES(share.eltype().as()->pt_type(), "_", [&]() { + DISPATCH_UINT_PT_TYPES(share.eltype().as()->pt_type(), [&]() { NdArrayView _share(share); for (auto idx = 0; idx < numel; ++idx) { res[idx] = _share[idx]; diff --git a/libspu/mpc/api.cc b/libspu/mpc/api.cc index 2008e8aa..58a1f4c1 100644 --- a/libspu/mpc/api.cc +++ b/libspu/mpc/api.cc @@ -14,8 +14,6 @@ #include "libspu/mpc/api.h" -#include - #include "libspu/core/trace.h" #include "libspu/mpc/ab_api.h" @@ -303,13 +301,14 @@ Value msb_s(SPUContext* ctx, const Value& x) { if (ctx->hasKernel("msb_a2b")) { if (IsB(x)) { - return rshift_b(ctx, x, SizeOf(field) * 8 - 1); + return rshift_b(ctx, x, {static_cast(SizeOf(field) * 8 - 1)}); } else { // fast path, directly apply msb x AShare, result a BShare. return msb_a2b(ctx, x); } } else { - return rshift_b(ctx, _2b(ctx, x), SizeOf(field) * 8 - 1); + return rshift_b(ctx, _2b(ctx, x), + {static_cast(SizeOf(field) * 8 - 1)}); } } @@ -601,7 +600,7 @@ Value xor_pp(SPUContext* ctx, const Value& x, const Value& y) { ////////////////////////////////////////////////////////////////////////////// -Value lshift_s(SPUContext* ctx, const Value& x, size_t bits) { +Value lshift_s(SPUContext* ctx, const Value& x, const Sizes& bits) { SPU_TRACE_MPC_DISP(ctx, x, bits); TRY_DISPATCH(ctx, x, bits); if (IsA(x)) { @@ -613,43 +612,43 @@ Value lshift_s(SPUContext* ctx, const Value& x, size_t bits) { } } -Value lshift_v(SPUContext* ctx, const Value& x, size_t nbits) { +Value lshift_v(SPUContext* ctx, const Value& x, const Sizes& nbits) { FORCE_DISPATCH(ctx, x, nbits); } -Value lshift_p(SPUContext* ctx, const Value& x, size_t nbits) { +Value lshift_p(SPUContext* ctx, const Value& x, const Sizes& nbits) { FORCE_DISPATCH(ctx, x, nbits); } ////////////////////////////////////////////////////////////////////////////// -Value rshift_s(SPUContext* ctx, const Value& x, size_t bits) { +Value rshift_s(SPUContext* ctx, const Value& x, const Sizes& bits) { SPU_TRACE_MPC_DISP(ctx, x, bits); TRY_DISPATCH(ctx, x, bits); return rshift_b(ctx, _2b(ctx, x), bits); } -Value rshift_v(SPUContext* ctx, const Value& x, size_t nbits) { +Value rshift_v(SPUContext* ctx, const Value& x, const Sizes& nbits) { FORCE_DISPATCH(ctx, x, nbits); } -Value rshift_p(SPUContext* ctx, const Value& x, size_t nbits) { +Value rshift_p(SPUContext* ctx, const Value& x, const Sizes& nbits) { FORCE_DISPATCH(ctx, x, nbits); } ////////////////////////////////////////////////////////////////////////////// -Value arshift_s(SPUContext* ctx, const Value& x, size_t bits) { +Value arshift_s(SPUContext* ctx, const Value& x, const Sizes& bits) { SPU_TRACE_MPC_DISP(ctx, x, bits); TRY_DISPATCH(ctx, x, bits); return arshift_b(ctx, _2b(ctx, x), bits); } -Value arshift_v(SPUContext* ctx, const Value& x, size_t nbits) { +Value arshift_v(SPUContext* ctx, const Value& x, const Sizes& nbits) { FORCE_DISPATCH(ctx, x, nbits); } -Value arshift_p(SPUContext* ctx, const Value& x, size_t nbits) { +Value arshift_p(SPUContext* ctx, const Value& x, const Sizes& nbits) { FORCE_DISPATCH(ctx, x, nbits); } @@ -662,11 +661,15 @@ Value trunc_s(SPUContext* ctx, const Value& x, size_t bits, SignType sign) { } Value trunc_v(SPUContext* ctx, const Value& x, size_t nbits, SignType sign) { - FORCE_DISPATCH(ctx, x, nbits, sign); + // FIXME: trunc_v use shift kernel + const Sizes trunc_bits = {static_cast(nbits)}; + 
FORCE_DISPATCH(ctx, x, trunc_bits, sign); } Value trunc_p(SPUContext* ctx, const Value& x, size_t nbits, SignType sign) { - FORCE_DISPATCH(ctx, x, nbits, sign); + // FIXME: trunc_p use shift kernel + const Sizes trunc_bits = {static_cast(nbits)}; + FORCE_DISPATCH(ctx, x, trunc_bits, sign); } ////////////////////////////////////////////////////////////////////////////// diff --git a/libspu/mpc/api.h b/libspu/mpc/api.h index 7d0b2333..ea61cc81 100644 --- a/libspu/mpc/api.h +++ b/libspu/mpc/api.h @@ -143,17 +143,17 @@ Value xor_vv(SPUContext* ctx, const Value& x, const Value& y); Value xor_vp(SPUContext* ctx, const Value& x, const Value& y); Value xor_pp(SPUContext* ctx, const Value& x, const Value& y); -Value lshift_s(SPUContext* ctx, const Value& x, size_t nbits); -Value lshift_v(SPUContext* ctx, const Value& x, size_t nbits); -Value lshift_p(SPUContext* ctx, const Value& x, size_t nbits); +Value lshift_s(SPUContext* ctx, const Value& x, const Sizes& nbits); +Value lshift_v(SPUContext* ctx, const Value& x, const Sizes& nbits); +Value lshift_p(SPUContext* ctx, const Value& x, const Sizes& nbits); -Value rshift_s(SPUContext* ctx, const Value& x, size_t nbits); -Value rshift_v(SPUContext* ctx, const Value& x, size_t nbits); -Value rshift_p(SPUContext* ctx, const Value& x, size_t nbits); +Value rshift_s(SPUContext* ctx, const Value& x, const Sizes& nbits); +Value rshift_v(SPUContext* ctx, const Value& x, const Sizes& nbits); +Value rshift_p(SPUContext* ctx, const Value& x, const Sizes& nbits); -Value arshift_s(SPUContext* ctx, const Value& x, size_t nbits); -Value arshift_v(SPUContext* ctx, const Value& x, size_t nbits); -Value arshift_p(SPUContext* ctx, const Value& x, size_t nbits); +Value arshift_s(SPUContext* ctx, const Value& x, const Sizes& nbits); +Value arshift_v(SPUContext* ctx, const Value& x, const Sizes& nbits); +Value arshift_p(SPUContext* ctx, const Value& x, const Sizes& nbits); Value trunc_s(SPUContext* ctx, const Value& x, size_t nbits, SignType sign); Value trunc_v(SPUContext* ctx, const Value& x, size_t nbits, SignType sign); diff --git a/libspu/mpc/api_test.cc b/libspu/mpc/api_test.cc index 33568d77..43cb0cf0 100644 --- a/libspu/mpc/api_test.cc +++ b/libspu/mpc/api_test.cc @@ -261,7 +261,7 @@ TEST_P(ApiTest, MsbS) { // SECURENN has an msb input range requirement here if (conf.protocol() == ProtocolKind::SECURENN) { - p0 = arshift_p(sctx.get(), p0, 1); + p0 = arshift_p(sctx.get(), p0, {1}); } auto r_s = s2p(sctx.get(), msb_s(sctx.get(), p2s(sctx.get(), p0))); @@ -272,88 +272,92 @@ TEST_P(ApiTest, MsbS) { }); } -#define TEST_UNARY_OP_WITH_BIT_S(OP) \ - TEST_P(ApiTest, OP##S) { \ - const auto factory = std::get<0>(GetParam()); \ - const RuntimeConfig& conf = std::get<1>(GetParam()); \ - const size_t npc = std::get<2>(GetParam()); \ - \ - utils::simulate( \ - npc, [&](const std::shared_ptr& lctx) { \ - auto sctx = factory(conf, lctx); \ - \ - /* GIVEN */ \ - auto x_p = rand_p(sctx.get(), kShape); \ - auto x_s = p2s(sctx.get(), x_p); \ - \ - for (auto bits : kShiftBits) { \ - if (bits >= SizeOf(conf.field()) * 8) { \ - continue; \ - } \ - /* WHEN */ \ - auto r_s = s2p(sctx.get(), OP##_s(sctx.get(), x_s, bits)); \ - auto r_p = OP##_p(sctx.get(), x_p, bits); \ - \ - /* THEN */ \ - EXPECT_VALUE_EQ(r_s, r_p); \ - } \ - }); \ +#define TEST_UNARY_OP_WITH_BIT_S(OP) \ + TEST_P(ApiTest, OP##S) { \ + const auto factory = std::get<0>(GetParam()); \ + const RuntimeConfig& conf = std::get<1>(GetParam()); \ + const size_t npc = std::get<2>(GetParam()); \ + \ + utils::simulate( \ + npc, 
[&](const std::shared_ptr& lctx) { \ + auto sctx = factory(conf, lctx); \ + \ + /* GIVEN */ \ + auto x_p = rand_p(sctx.get(), kShape); \ + auto x_s = p2s(sctx.get(), x_p); \ + \ + for (auto bits : kShiftBits) { \ + if (bits >= SizeOf(conf.field()) * 8) { \ + continue; \ + } \ + /* WHEN */ \ + auto r_s = s2p(sctx.get(), OP##_s(sctx.get(), x_s, \ + {static_cast(bits)})); \ + auto r_p = OP##_p(sctx.get(), x_p, {static_cast(bits)}); \ + \ + /* THEN */ \ + EXPECT_VALUE_EQ(r_s, r_p); \ + } \ + }); \ } -#define TEST_UNARY_OP_WITH_BIT_V(OP) \ - TEST_P(ApiTest, OP##V) { \ - const auto factory = std::get<0>(GetParam()); \ - const RuntimeConfig& conf = std::get<1>(GetParam()); \ - const size_t npc = std::get<2>(GetParam()); \ - \ - utils::simulate( \ - npc, [&](const std::shared_ptr& lctx) { \ - auto sctx = factory(conf, lctx); \ - \ - for (size_t rank = 0; rank < npc; rank++) { \ - /* GIVEN */ \ - auto x_p = rand_p(sctx.get(), kShape); \ - auto x_v = p2v(sctx.get(), x_p, rank); \ - \ - for (auto bits : kShiftBits) { \ - if (bits >= SizeOf(conf.field()) * 8) { \ - continue; \ - } \ - /* WHEN */ \ - auto r_v = v2p(sctx.get(), OP##_v(sctx.get(), x_v, bits)); \ - auto r_p = OP##_p(sctx.get(), x_p, bits); \ - \ - /* THEN */ \ - EXPECT_VALUE_EQ(r_v, r_p); \ - } \ - } \ - }); \ +#define TEST_UNARY_OP_WITH_BIT_V(OP) \ + TEST_P(ApiTest, OP##V) { \ + const auto factory = std::get<0>(GetParam()); \ + const RuntimeConfig& conf = std::get<1>(GetParam()); \ + const size_t npc = std::get<2>(GetParam()); \ + \ + utils::simulate( \ + npc, [&](const std::shared_ptr& lctx) { \ + auto sctx = factory(conf, lctx); \ + \ + for (size_t rank = 0; rank < npc; rank++) { \ + /* GIVEN */ \ + auto x_p = rand_p(sctx.get(), kShape); \ + auto x_v = p2v(sctx.get(), x_p, rank); \ + \ + for (auto bits : kShiftBits) { \ + if (bits >= SizeOf(conf.field()) * 8) { \ + continue; \ + } \ + /* WHEN */ \ + auto r_v = \ + v2p(sctx.get(), \ + OP##_v(sctx.get(), x_v, {static_cast(bits)})); \ + auto r_p = \ + OP##_p(sctx.get(), x_p, {static_cast(bits)}); \ + \ + /* THEN */ \ + EXPECT_VALUE_EQ(r_v, r_p); \ + } \ + } \ + }); \ } -#define TEST_UNARY_OP_WITH_BIT_P(OP) \ - TEST_P(ApiTest, OP##P) { \ - const auto factory = std::get<0>(GetParam()); \ - const RuntimeConfig& conf = std::get<1>(GetParam()); \ - const size_t npc = std::get<2>(GetParam()); \ - \ - utils::simulate(npc, \ - [&](const std::shared_ptr& lctx) { \ - auto sctx = factory(conf, lctx); \ - \ - /* GIVEN */ \ - auto p0 = rand_p(sctx.get(), kShape); \ - \ - for (auto bits : kShiftBits) { /* WHEN */ \ - if (bits >= SizeOf(conf.field()) * 8) { \ - continue; \ - } \ - auto r_p = OP##_p(sctx.get(), p0, bits); \ - auto r_pp = OP##_p(sctx.get(), p0, bits); \ - \ - /* THEN */ \ - EXPECT_VALUE_EQ(r_p, r_pp); \ - } \ - }); \ +#define TEST_UNARY_OP_WITH_BIT_P(OP) \ + TEST_P(ApiTest, OP##P) { \ + const auto factory = std::get<0>(GetParam()); \ + const RuntimeConfig& conf = std::get<1>(GetParam()); \ + const size_t npc = std::get<2>(GetParam()); \ + \ + utils::simulate( \ + npc, [&](const std::shared_ptr& lctx) { \ + auto sctx = factory(conf, lctx); \ + \ + /* GIVEN */ \ + auto p0 = rand_p(sctx.get(), kShape); \ + \ + for (auto bits : kShiftBits) { /* WHEN */ \ + if (bits >= SizeOf(conf.field()) * 8) { \ + continue; \ + } \ + auto r_p = OP##_p(sctx.get(), p0, {static_cast(bits)}); \ + auto r_pp = OP##_p(sctx.get(), p0, {static_cast(bits)}); \ + \ + /* THEN */ \ + EXPECT_VALUE_EQ(r_p, r_pp); \ + } \ + }); \ } #define TEST_UNARY_OP_WITH_BIT(OP) \ @@ -379,12 +383,13 @@ TEST_P(ApiTest, TruncS) { : 
kShape); // TODO: here we assume has msb error, only use lowest 10 bits. - p0 = arshift_p(sctx.get(), p0, SizeOf(conf.field()) * 8 - 10); + p0 = arshift_p(sctx.get(), p0, + {static_cast(SizeOf(conf.field()) * 8 - 10)}); const size_t bits = 2; auto r_s = s2p(sctx.get(), trunc_s(sctx.get(), p2s(sctx.get(), p0), bits, SignType::Unknown)); - auto r_p = arshift_p(sctx.get(), p0, bits); + auto r_p = arshift_p(sctx.get(), p0, {bits}); /* THEN */ EXPECT_VALUE_ALMOST_EQ(r_s, r_p, npc); diff --git a/libspu/mpc/cheetah/arith/cheetah_conv2d_test.cc b/libspu/mpc/cheetah/arith/cheetah_conv2d_test.cc index b05235a6..d71d2bb1 100644 --- a/libspu/mpc/cheetah/arith/cheetah_conv2d_test.cc +++ b/libspu/mpc/cheetah/arith/cheetah_conv2d_test.cc @@ -84,7 +84,7 @@ TEST_P(CheetahConv2dTest, Basic) { flatten(ring_add(toNdArray(result[0]), toNdArray(result[1]))); const int64_t kMaxDiff = 1; - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { ArrayView c(computed); NdArrayView exp(expected); for (auto idx = 0; idx < expected.numel(); idx++) { diff --git a/libspu/mpc/cheetah/arith/cheetah_dot_test.cc b/libspu/mpc/cheetah/arith/cheetah_dot_test.cc index 79697937..df93112d 100644 --- a/libspu/mpc/cheetah/arith/cheetah_dot_test.cc +++ b/libspu/mpc/cheetah/arith/cheetah_dot_test.cc @@ -68,7 +68,7 @@ TEST_P(CheetahDotTest, Basic) { EXPECT_EQ(expected.numel(), computed.numel()); const int64_t kMaxDiff = 1; - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto e = NdArrayView(expected); auto c = NdArrayView(computed); @@ -119,7 +119,7 @@ TEST_P(CheetahDotTest, BatchDot) { [[maybe_unused]] constexpr int64_t kMaxDiff = 1; int64_t max_diff = 0; - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto e = NdArrayView(expected); auto c = NdArrayView(computed); diff --git a/libspu/mpc/cheetah/arith/common.cc b/libspu/mpc/cheetah/arith/common.cc index 3deda13e..f87d92ab 100644 --- a/libspu/mpc/cheetah/arith/common.cc +++ b/libspu/mpc/cheetah/arith/common.cc @@ -102,7 +102,7 @@ NdArrayRef ring_conv2d(const NdArrayRef &tensor, const NdArrayRef &filter, NdArrayRef _filter = filter.reshape(fs); NdArrayRef _ret = ring_zeros(field, result_shape); - DISPATCH_ALL_FIELDS(field, "ring_conv2d", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { // NOTE(juhou): valid padding so offset are always 0. 
constexpr int64_t padh = 0; constexpr int64_t padw = 0; diff --git a/libspu/mpc/cheetah/arith/matmat_prot.cc b/libspu/mpc/cheetah/arith/matmat_prot.cc index a56e522f..919a13fa 100644 --- a/libspu/mpc/cheetah/arith/matmat_prot.cc +++ b/libspu/mpc/cheetah/arith/matmat_prot.cc @@ -113,7 +113,7 @@ NdArrayRef ConcatSubMatrix(const NdArrayRef& mat, const Shape2D& mat_shape, // NOTE: zero padding via initialization NdArrayRef flatten = ring_zeros(field, {num_coeff}); - DISPATCH_ALL_FIELDS(field, "ConcatSubMat", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using uT = std::make_unsigned::type; for (int64_t r = 0, rr = starts[0]; r < extents[0]; ++r, ++rr) { @@ -183,7 +183,7 @@ Shape3D MatMatProtocol::GetSubMatShape(const Meta& meta, int64_t poly_deg, const double cpu_price = 1.0; const double bandwidth_price = 1000.0; - Shape3D subshape; + Shape3D subshape = {0, 0, 0}; Shape3D blk; const int64_t n = poly_deg; @@ -423,7 +423,7 @@ NdArrayRef MatMatProtocol::ParseResult(FieldType field, const Meta& meta, for (int64_t r = 0; r < row_ext; ++r) { for (int64_t c = 0; c < col_ext; ++c) { int64_t dst_idx = (r + row_start) * meta.dims[2] + col_start + c; - DISPATCH_ALL_FIELDS(field, "ParseResult", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { matmat.at(dst_idx) = result_poly.at(r * subdims[2] + c); }); @@ -532,13 +532,13 @@ NdArrayRef MatMatProtocol::ParsePackLWEsResult( std::vector decoded_vectors(ans_poly.size()); for (size_t i = 0; i < ans_poly.size(); ++i) { decoded_vectors[i] = - msh.ModulusDownRNS(field, {(int64_t)packing_width}, + msh.ModulusDownRNS(field, {static_cast(packing_width)}, {ans_poly[i].data(), ans_poly[i].coeff_count()}); } NdArrayRef matmat = ring_zeros(field, {meta.dims[0] * meta.dims[2]}); - DISPATCH_ALL_FIELDS(field, "pack_lwes_results", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView xmatmat(matmat); for (size_t i = 0; i < ans_poly.size(); ++i) { diff --git a/libspu/mpc/cheetah/arith/matmat_prot_test.cc b/libspu/mpc/cheetah/arith/matmat_prot_test.cc index 0b01f0d0..3a5f6521 100644 --- a/libspu/mpc/cheetah/arith/matmat_prot_test.cc +++ b/libspu/mpc/cheetah/arith/matmat_prot_test.cc @@ -153,7 +153,7 @@ TEST_P(MatMatProtTest, Plain) { EXPECT_EQ(expected.numel(), computed.numel()); - DISPATCH_ALL_FIELDS(field_, "", [&]() { + DISPATCH_ALL_FIELDS(field_, [&]() { auto xe = NdArrayView(expected); auto xc = NdArrayView(computed); for (int64_t i = 0; i < xc.numel(); ++i) { @@ -215,7 +215,7 @@ TEST_P(MatMatProtTest, EncLHS) { EXPECT_EQ(expected.numel(), computed.numel()); - DISPATCH_ALL_FIELDS(field_, "", [&]() { + DISPATCH_ALL_FIELDS(field_, [&]() { auto xe = NdArrayView(expected); auto xc = NdArrayView(computed); for (int64_t i = 0; i < xc.numel(); ++i) { @@ -276,7 +276,7 @@ TEST_P(MatMatProtTest, EncRHS) { EXPECT_EQ(expected.numel(), computed.numel()); - DISPATCH_ALL_FIELDS(field_, "", [&]() { + DISPATCH_ALL_FIELDS(field_, [&]() { auto xe = NdArrayView(expected); auto xc = NdArrayView(computed); for (int64_t i = 0; i < xc.numel(); ++i) { diff --git a/libspu/mpc/cheetah/arith/vector_encoder.cc b/libspu/mpc/cheetah/arith/vector_encoder.cc index 52b71488..3ee37a2b 100644 --- a/libspu/mpc/cheetah/arith/vector_encoder.cc +++ b/libspu/mpc/cheetah/arith/vector_encoder.cc @@ -16,7 +16,6 @@ #include "libspu/core/prelude.h" #include "libspu/core/type_util.h" -#include "libspu/core/xt_helper.h" #include "libspu/mpc/utils/ring_ops.h" namespace spu::mpc::cheetah { @@ -84,7 +83,7 @@ void VectorEncoder::Backward(const NdArrayRef &vec, RLWEPt *out, const auto field = eltype.as()->field(); - 
DISPATCH_ALL_FIELDS(field, "Backward", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto tmp_buff = ring_zeros(field, {(int64_t)poly_deg_}); auto xvec = NdArrayView(vec); auto xtmp = NdArrayView(tmp_buff); diff --git a/libspu/mpc/cheetah/arith/vector_encoder_test.cc b/libspu/mpc/cheetah/arith/vector_encoder_test.cc index 6056a5c6..0231eeb7 100644 --- a/libspu/mpc/cheetah/arith/vector_encoder_test.cc +++ b/libspu/mpc/cheetah/arith/vector_encoder_test.cc @@ -125,7 +125,7 @@ TEST_P(VectorEncoderTest, ForwardBackward) { auto computed = ms_helper_->ModulusDownRNS(field_, {1L}, absl::MakeSpan(cnst)); - DISPATCH_ALL_FIELDS(field_, "Check", [&]() { + DISPATCH_ALL_FIELDS(field_, [&]() { NdArrayView got(computed); NdArrayView v0(vec0); NdArrayView v1(vec1); diff --git a/libspu/mpc/cheetah/arithmetic.cc b/libspu/mpc/cheetah/arithmetic.cc index e90e80ae..c1765b46 100644 --- a/libspu/mpc/cheetah/arithmetic.cc +++ b/libspu/mpc/cheetah/arithmetic.cc @@ -74,7 +74,7 @@ NdArrayRef MsbA2B::proc(KernelEvalContext* ctx, const NdArrayRef& x) const { const int rank = ctx->getState()->getRank(); - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using u2k = std::make_unsigned::type; const u2k mask = (static_cast(1) << shft) - 1; NdArrayRef adjusted = ring_zeros(field, x.shape()); diff --git a/libspu/mpc/cheetah/arithmetic.h b/libspu/mpc/cheetah/arithmetic.h index a0a1a5a6..26ecbb93 100644 --- a/libspu/mpc/cheetah/arithmetic.h +++ b/libspu/mpc/cheetah/arithmetic.h @@ -257,7 +257,7 @@ class LShiftA : public ShiftKernel { ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const override; + const Sizes& bits) const override; }; class MsbA2B : public UnaryKernel { diff --git a/libspu/mpc/cheetah/arithmetic_semi2k.cc b/libspu/mpc/cheetah/arithmetic_semi2k.cc index 89ef7daa..dff06e17 100644 --- a/libspu/mpc/cheetah/arithmetic_semi2k.cc +++ b/libspu/mpc/cheetah/arithmetic_semi2k.cc @@ -24,7 +24,7 @@ namespace spu::mpc::cheetah { NdArrayRef RandA::proc(KernelEvalContext* ctx, const Shape& shape) const { auto* prg_state = ctx->getState(); const auto field = ctx->getState()->getDefaultField(); - return ring_rshift(prg_state->genPriv(field, shape), 2) + return ring_rshift(prg_state->genPriv(field, shape), {2}) .as(makeType(field)); } @@ -59,7 +59,7 @@ NdArrayRef A2V::proc(KernelEvalContext* ctx, const NdArrayRef& in, auto numel = in.numel(); - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { std::vector share(numel); NdArrayView _in(in); pforeach(0, numel, [&](int64_t idx) { share[idx] = _in[idx]; }); @@ -143,10 +143,7 @@ NdArrayRef MatMulAP::proc(KernelEvalContext*, const NdArrayRef& x, } NdArrayRef LShiftA::proc(KernelEvalContext*, const NdArrayRef& in, - size_t bits) const { - const auto field = in.eltype().as()->field(); - bits %= SizeOf(field) * 8; - + const Sizes& bits) const { return ring_lshift(in, bits).as(in.eltype()); } diff --git a/libspu/mpc/cheetah/boolean.h b/libspu/mpc/cheetah/boolean.h index cdca12db..2fbb11fd 100644 --- a/libspu/mpc/cheetah/boolean.h +++ b/libspu/mpc/cheetah/boolean.h @@ -116,7 +116,7 @@ class LShiftB : public ShiftKernel { ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t shift) const override; + const Sizes& shift) const override; }; class RShiftB : public ShiftKernel { @@ -128,7 +128,7 @@ class RShiftB : public ShiftKernel { ce::CExpr comm() const 
override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t shift) const override; + const Sizes& shift) const override; }; class ARShiftB : public ShiftKernel { @@ -140,7 +140,7 @@ class ARShiftB : public ShiftKernel { ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t shift) const override; + const Sizes& shift) const override; }; class BitrevB : public BitrevKernel { diff --git a/libspu/mpc/cheetah/boolean_semi2k.cc b/libspu/mpc/cheetah/boolean_semi2k.cc index a64da4cc..697b2287 100644 --- a/libspu/mpc/cheetah/boolean_semi2k.cc +++ b/libspu/mpc/cheetah/boolean_semi2k.cc @@ -25,7 +25,7 @@ namespace { size_t getNumBits(const NdArrayRef& in) { if (in.eltype().isa()) { const auto field = in.eltype().as()->field(); - return DISPATCH_ALL_FIELDS(field, "_", + return DISPATCH_ALL_FIELDS(field, [&]() { return maxBitWidth(in); }); } else if (in.eltype().isa()) { return in.eltype().as()->nbits(); @@ -81,8 +81,8 @@ NdArrayRef P2B::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { if (comm->getRank() == 0) { ring_xor_(x, in); } - - return makeBShare(x, field, getNumBits(in)); + auto nbits = getNumBits(in) == 0 ? 1 : getNumBits(in); + return makeBShare(x, field, nbits); } NdArrayRef AndBP::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, @@ -93,7 +93,7 @@ NdArrayRef AndBP::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, const size_t out_nbits = std::min(getNumBits(lhs), getNumBits(rhs)); NdArrayRef out(makeType(field, out_nbits), lhs.shape()); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _lhs(lhs); NdArrayView _rhs(rhs); NdArrayView _out(out); @@ -130,32 +130,30 @@ NdArrayRef XorBB::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, } NdArrayRef LShiftB::proc(KernelEvalContext*, const NdArrayRef& in, - size_t shift) const { + const Sizes& shift) const { const auto field = in.eltype().as()->field(); - shift %= SizeOf(field) * 8; - - size_t out_nbits = in.eltype().as()->nbits() + shift; + size_t out_nbits = in.eltype().as()->nbits() + + *std::max_element(shift.begin(), shift.end()); out_nbits = std::clamp(out_nbits, static_cast(0), SizeOf(field) * 8); return makeBShare(ring_lshift(in, shift), field, out_nbits); } NdArrayRef RShiftB::proc(KernelEvalContext*, const NdArrayRef& in, - size_t shift) const { + const Sizes& shift) const { const auto field = in.eltype().as()->field(); - shift %= SizeOf(field) * 8; - size_t nbits = in.eltype().as()->nbits(); - size_t out_nbits = nbits - std::min(nbits, shift); - SPU_ENFORCE(nbits <= SizeOf(field) * 8); + int64_t nbits = in.eltype().as()->nbits(); + int64_t out_nbits = + nbits - std::min(nbits, *std::min_element(shift.begin(), shift.end())); + SPU_ENFORCE(nbits <= static_cast(SizeOf(field) * 8)); return makeBShare(ring_rshift(in, shift), field, out_nbits); } NdArrayRef ARShiftB::proc(KernelEvalContext*, const NdArrayRef& in, - size_t shift) const { + const Sizes& shift) const { const auto field = in.eltype().as()->field(); - shift %= SizeOf(field) * 8; // arithmetic right shift expects to work on ring, or the behaviour is // undefined. 
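The `boolean_semi2k.cc` hunk above recomputes the output bit width from the whole shift vector: a left shift can grow the share by at most the largest shift, and a right shift is only guaranteed to remove the smallest one. The sketch below mirrors that bookkeeping with illustrative function names and the clamp against the ring width; in the patch the same arithmetic appears inline in `LShiftB::proc` and `RShiftB::proc`.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

using Sizes = std::vector<int64_t>;

// Upper bound on the significant bits after a left shift: the largest shift
// in `bits` widens the share the most, clamped to the ring width.
size_t lshift_out_nbits(size_t in_nbits, const Sizes& bits, size_t ring_width) {
  const size_t grow = *std::max_element(bits.begin(), bits.end());
  return std::min(in_nbits + grow, ring_width);
}

// Counterpart for a right shift: only the smallest shift is guaranteed for
// every element, so that is what the new bit width subtracts.
size_t rshift_out_nbits(size_t in_nbits, const Sizes& bits) {
  const size_t shrink = *std::min_element(bits.begin(), bits.end());
  return in_nbits - std::min(in_nbits, shrink);
}

int main() {
  // 20 significant bits, per-element shifts {3, 5}, 64-bit ring.
  return static_cast<int>(lshift_out_nbits(20, {3, 5}, 64) +  // 25
                          rshift_out_nbits(20, {3, 5}));      // 17
}
```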
@@ -183,7 +181,7 @@ NdArrayRef BitIntlB::proc(KernelEvalContext*, const NdArrayRef& in, NdArrayRef out(in.eltype(), in.shape()); auto numel = in.numel(); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _in(in); NdArrayView _out(out); @@ -204,7 +202,7 @@ NdArrayRef BitDeintlB::proc(KernelEvalContext*, const NdArrayRef& in, NdArrayRef out(in.eltype(), in.shape()); auto numel = in.numel(); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _in(in); NdArrayView _out(out); diff --git a/libspu/mpc/cheetah/nonlinear/compare_prot.cc b/libspu/mpc/cheetah/nonlinear/compare_prot.cc index 7ee00a9f..94d8b2a1 100644 --- a/libspu/mpc/cheetah/nonlinear/compare_prot.cc +++ b/libspu/mpc/cheetah/nonlinear/compare_prot.cc @@ -16,13 +16,11 @@ #include "yacl/crypto/rand/rand.h" #include "yacl/crypto/tools/prg.h" -#include "yacl/link/link.h" #include "libspu/core/type.h" #include "libspu/mpc/cheetah/ot/basic_ot_prot.h" #include "libspu/mpc/cheetah/ot/ot_util.h" #include "libspu/mpc/cheetah/type.h" -#include "libspu/mpc/common/communicator.h" #include "libspu/mpc/utils/ring_ops.h" namespace spu::mpc::cheetah { @@ -58,14 +56,14 @@ void SetLeafOTMsg(absl::Span ot_messages, uint8_t digit, NdArrayRef CompareProtocol::DoCompute(const NdArrayRef& inp, bool greater_than, NdArrayRef* keep_eq, int64_t bitwidth) { auto field = inp.eltype().as()->field(); - int64_t num_digits = CeilDiv(bitwidth, (int64_t)compare_radix_); + int64_t num_digits = CeilDiv(bitwidth, static_cast(compare_radix_)); size_t radix = static_cast(1) << compare_radix_; // one-of-N OT int64_t num_cmp = inp.numel(); // init to all zero std::vector digits(num_cmp * num_digits, 0); // Step 1 break into digits \in [0, radix) - DISPATCH_ALL_FIELDS(field, "break_digits", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using u2k = std::make_unsigned::type; const auto mask_radix = makeBitsMask(compare_radix_); NdArrayView xinp(inp); @@ -142,7 +140,7 @@ NdArrayRef CompareProtocol::DoCompute(const NdArrayRef& inp, bool greater_than, ring_zeros(field, {static_cast(num_digits * num_cmp)}) .as(boolean_t); - DISPATCH_ALL_FIELDS(field, "copy_leaf", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView xprev_cmp(prev_cmp); NdArrayView xprev_eq(prev_eq); pforeach(0, prev_cmp.numel(), [&](int64_t i) { diff --git a/libspu/mpc/cheetah/nonlinear/compare_prot_test.cc b/libspu/mpc/cheetah/nonlinear/compare_prot_test.cc index a04da8b4..26f9fbd1 100644 --- a/libspu/mpc/cheetah/nonlinear/compare_prot_test.cc +++ b/libspu/mpc/cheetah/nonlinear/compare_prot_test.cc @@ -14,12 +14,9 @@ #include "libspu/mpc/cheetah/nonlinear/compare_prot.h" -#include - #include "gtest/gtest.h" #include "libspu/mpc/cheetah/ot/basic_ot_prot.h" -#include "libspu/mpc/cheetah/type.h" #include "libspu/mpc/utils/ring_ops.h" #include "libspu/mpc/utils/simulate.h" @@ -51,7 +48,7 @@ TEST_P(CompareProtTest, Compare) { inp[0] = ring_rand(field, shape); inp[1] = ring_rand(field, shape); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto xinp = NdArrayView(inp[0]); xinp[0] = 1; xinp[1] = 10; @@ -74,7 +71,7 @@ TEST_P(CompareProtTest, Compare) { cmp_oup[rank] = _c; }); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto xout0 = NdArrayView(cmp_oup[0]); auto xout1 = NdArrayView(cmp_oup[1]); auto xinp0 = NdArrayView(inp[0]); @@ -100,7 +97,7 @@ TEST_P(CompareProtTest, CompareBitWidth) { inp[0] = ring_rand(field, {n, 2}); inp[1] = ring_rand(field, {n, 2}); - 
DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { ring2k_t mask = (static_cast(1) << bw) - 1; auto xinp = NdArrayView(inp[0]); xinp[0] = 1; @@ -142,7 +139,7 @@ TEST_P(CompareProtTest, CompareBitWidth) { cmp_oup[rank] = _c; }); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto xout0 = NdArrayView(cmp_oup[0]); auto xout1 = NdArrayView(cmp_oup[1]); auto xinp0 = NdArrayView(inp[0]); @@ -172,7 +169,7 @@ TEST_P(CompareProtTest, WithEq) { inp[1] = _inp[0].slice({0, 0, 0}, {10, 10, 10}, {2, 3, 2}); shape = inp[0].shape(); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto xinp = NdArrayView(inp[0]); xinp[0] = 1; xinp[1] = 10; @@ -197,7 +194,7 @@ TEST_P(CompareProtTest, WithEq) { eq_oup[rank] = _e; }); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto xout0 = NdArrayView(cmp_oup[0]); auto xout1 = NdArrayView(cmp_oup[1]); auto xeq0 = NdArrayView(eq_oup[0]); @@ -229,7 +226,7 @@ TEST_P(CompareProtTest, WithEqBitWidth) { inp[0] = ring_rand(field, {n, 2}); inp[1] = ring_rand(field, {n, 2}); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { ring2k_t mask = (static_cast(1) << bw) - 1; auto xinp = NdArrayView(inp[0]); xinp[0] = 1; @@ -271,7 +268,7 @@ TEST_P(CompareProtTest, WithEqBitWidth) { eq_oup[rank] = _e; }); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto xout0 = NdArrayView(cmp_oup[0]); auto xout1 = NdArrayView(cmp_oup[1]); auto xeq0 = NdArrayView(eq_oup[0]); diff --git a/libspu/mpc/cheetah/nonlinear/equal_prot.cc b/libspu/mpc/cheetah/nonlinear/equal_prot.cc index 1ad852e6..2d04c886 100644 --- a/libspu/mpc/cheetah/nonlinear/equal_prot.cc +++ b/libspu/mpc/cheetah/nonlinear/equal_prot.cc @@ -16,13 +16,11 @@ #include "yacl/crypto/rand/rand.h" #include "yacl/crypto/tools/prg.h" -#include "yacl/link/link.h" #include "libspu/core/type.h" #include "libspu/mpc/cheetah/ot/basic_ot_prot.h" #include "libspu/mpc/cheetah/ot/ot_util.h" #include "libspu/mpc/cheetah/type.h" -#include "libspu/mpc/common/communicator.h" #include "libspu/mpc/utils/ring_ops.h" namespace spu::mpc::cheetah { @@ -61,7 +59,7 @@ NdArrayRef EqualProtocol::DoCompute(const NdArrayRef& inp, size_t bit_width) { // init to all zero std::vector digits(num_cmp * num_digits, 0); - DISPATCH_ALL_FIELDS(field, "Equal_break_digits", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using u2k = std::make_unsigned::type; const auto mask_radix = makeBitsMask(compare_radix_); const auto mask_remain = makeBitsMask(remain); @@ -133,7 +131,7 @@ NdArrayRef EqualProtocol::DoCompute(const NdArrayRef& inp, size_t bit_width) { // m0[1], m1[1], ..., mN[1] // ... 
// m0[M], m1[M], ..., mN[M] - DISPATCH_ALL_FIELDS(field, "Equal_transpose", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView xprev_eq(prev_eq); for (int64_t r = 0; r < num_cmp; ++r) { for (int64_t c = 0; c < num_digits; ++c) { diff --git a/libspu/mpc/cheetah/nonlinear/equal_prot_test.cc b/libspu/mpc/cheetah/nonlinear/equal_prot_test.cc index 7a51cbd5..4eeead58 100644 --- a/libspu/mpc/cheetah/nonlinear/equal_prot_test.cc +++ b/libspu/mpc/cheetah/nonlinear/equal_prot_test.cc @@ -45,7 +45,7 @@ TEST_P(EqualProtTest, Basic) { inp[0] = ring_rand(field, shape); inp[1] = ring_rand(field, shape); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto xinp0 = NdArrayView(inp[0]); auto xinp1 = NdArrayView(inp[1]); std::copy_n(&xinp1[0], 5, &xinp0[0]); @@ -64,7 +64,7 @@ TEST_P(EqualProtTest, Basic) { SPU_ENFORCE_EQ(eq_oup[0].shape(), shape); SPU_ENFORCE_EQ(eq_oup[1].shape(), shape); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto xeq0 = NdArrayView(eq_oup[0]); auto xeq1 = NdArrayView(eq_oup[1]); auto xinp0 = NdArrayView(inp[0]); diff --git a/libspu/mpc/cheetah/nonlinear/truncate_prot.cc b/libspu/mpc/cheetah/nonlinear/truncate_prot.cc index 3edb5117..32e8dff1 100644 --- a/libspu/mpc/cheetah/nonlinear/truncate_prot.cc +++ b/libspu/mpc/cheetah/nonlinear/truncate_prot.cc @@ -16,7 +16,6 @@ #include "libspu/core/type.h" #include "libspu/mpc/cheetah/nonlinear/compare_prot.h" #include "libspu/mpc/cheetah/ot/basic_ot_prot.h" -#include "libspu/mpc/cheetah/ot/ot_util.h" #include "libspu/mpc/cheetah/type.h" #include "libspu/mpc/utils/ring_ops.h" @@ -66,7 +65,7 @@ NdArrayRef TruncateProtocol::ComputeWrap(const NdArrayRef& inp, wrap_bool = compare_prot.Compute(inp, true); } else { auto adjusted = ring_neg(inp); - DISPATCH_ALL_FIELDS(field, "wrap_adjust", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView xadj(adjusted); pforeach(0, inp.numel(), [&](int64_t i) { xadj[i] -= 1; }); }); @@ -93,7 +92,7 @@ NdArrayRef TruncateProtocol::MSB1ToWrap(const NdArrayRef& inp, const size_t bw = SizeOf(field) * 8; NdArrayRef cot_output = ring_zeros(field, inp.shape()); - DISPATCH_ALL_FIELDS(field, "MSB1ToWrap", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using u2k = std::make_unsigned::type; NdArrayView xinp(inp); auto xout = absl::MakeSpan(&cot_output.at(0), cot_output.numel()); @@ -148,7 +147,7 @@ NdArrayRef TruncateProtocol::MSB0ToWrap(const NdArrayRef& inp, outp = ring_randbit(field, inp.shape()); std::vector send(numel * N); - DISPATCH_ALL_FIELDS(field, "MSB0_adjust", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using u2k = std::make_unsigned::type; NdArrayView xinp(inp); NdArrayView xrnd(outp); @@ -166,7 +165,7 @@ NdArrayRef TruncateProtocol::MSB0ToWrap(const NdArrayRef& inp, sender->Flush(); } else { std::vector choices(numel, 0); - DISPATCH_ALL_FIELDS(field, "MSB0_adjust", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using u2k = std::make_unsigned::type; NdArrayView xinp(inp); for (int64_t i = 0; i < numel; ++i) { @@ -179,7 +178,7 @@ NdArrayRef TruncateProtocol::MSB0ToWrap(const NdArrayRef& inp, absl::MakeSpan(recv), nbits); outp = ring_zeros(field, inp.shape()); - DISPATCH_ALL_FIELDS(field, "MSB0_finalize", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView xoup(outp); pforeach(0, numel, [&](int64_t i) { xoup[i] = static_cast(recv[i] & 1); @@ -220,7 +219,7 @@ NdArrayRef TruncateProtocol::Compute(const NdArrayRef& inp, Meta meta) { if (rank == 0) { NdArrayRef tmp = inp.clone(); - DISPATCH_ALL_FIELDS(field, "trunc_with_heuristic", [&] { + 
DISPATCH_ALL_FIELDS(field, [&] { NdArrayView _inp(tmp); ring2k_t big_value = static_cast(1) << (bit_width - kHeuristicBound); @@ -230,7 +229,7 @@ NdArrayRef TruncateProtocol::Compute(const NdArrayRef& inp, Meta meta) { tmp = Compute(tmp, meta); - DISPATCH_ALL_FIELDS(field, "trunc_with_heuristic", [&] { + DISPATCH_ALL_FIELDS(field, [&] { NdArrayView _outp(tmp); ring2k_t big_value = static_cast(1) << (bit_width - kHeuristicBound - shift); @@ -246,7 +245,7 @@ NdArrayRef TruncateProtocol::Compute(const NdArrayRef& inp, Meta meta) { NdArrayRef wrap_ashr; NdArrayRef out = ring_zeros(field, inp.shape()); - return DISPATCH_ALL_FIELDS(field, "Truncate", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { const ring2k_t component = (static_cast(1) << (bit_width - 1)); NdArrayView xinp(inp); diff --git a/libspu/mpc/cheetah/nonlinear/truncate_prot_test.cc b/libspu/mpc/cheetah/nonlinear/truncate_prot_test.cc index 4cf591ac..f1c737e8 100644 --- a/libspu/mpc/cheetah/nonlinear/truncate_prot_test.cc +++ b/libspu/mpc/cheetah/nonlinear/truncate_prot_test.cc @@ -14,12 +14,9 @@ #include "libspu/mpc/cheetah/nonlinear/truncate_prot.h" -#include - #include "gtest/gtest.h" #include "libspu/mpc/cheetah/ot/basic_ot_prot.h" -#include "libspu/mpc/cheetah/type.h" #include "libspu/mpc/utils/ring_ops.h" #include "libspu/mpc/utils/simulate.h" @@ -65,7 +62,7 @@ TEST_P(TruncateProtTest, Basic) { sign = SignType::Unknown; } else { auto msg = ring_rand(field, {n}); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto xmsg = NdArrayView(msg); size_t bw = SizeOf(field) * 8; if (msb == "Zero") { @@ -111,7 +108,7 @@ TEST_P(TruncateProtTest, Basic) { EXPECT_EQ(oup[0].shape(), oup[1].shape()); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using signed_t = std::make_signed::type; using usigned_t = std::make_unsigned::type; @@ -163,9 +160,10 @@ TEST_P(TruncateProtTest, Heuristic) { NdArrayRef inp[2]; inp[0] = ring_rand(field, {n}); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto msg = ring_rand(field, {n}); - ring_rshift_(msg, TruncateProtocol::kHeuristicBound); + ring_rshift_(msg, + {static_cast(TruncateProtocol::kHeuristicBound)}); NdArrayView xmsg(msg); for (int64_t i = 0; i < n; i += 2) { xmsg[i] = -xmsg[i]; @@ -191,7 +189,7 @@ TEST_P(TruncateProtTest, Heuristic) { [[maybe_unused]] int count_zero = 0; [[maybe_unused]] int count_pos = 0; [[maybe_unused]] int count_neg = 0; - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using signed_t = std::make_signed::type; auto xout0 = NdArrayView(oup[0]); diff --git a/libspu/mpc/cheetah/ot/basic_ot_prot.cc b/libspu/mpc/cheetah/ot/basic_ot_prot.cc index 201a57ad..f5a03549 100644 --- a/libspu/mpc/cheetah/ot/basic_ot_prot.cc +++ b/libspu/mpc/cheetah/ot/basic_ot_prot.cc @@ -79,7 +79,7 @@ NdArrayRef BasicOTProtocols::PackedB2A(const NdArrayRef &inp) { } SPU_ENFORCE(nbits > 0 && nbits <= 8 * SizeOf(field)); - auto rand_bits = DISPATCH_ALL_FIELDS(field, "single_b2a", [&]() { + auto rand_bits = DISPATCH_ALL_FIELDS(field, [&]() { if ((nbits & 7) or (n * inp.elsize()) & 7) { // The SseTranspose requires the #rows and #columns is multiple of 8. // Thus, we call the less efficient RandBits on margin cases. @@ -142,7 +142,7 @@ NdArrayRef BasicOTProtocols::PackedB2A(const NdArrayRef &inp) { const int64_t n = _bits.numel() / nbits; // init as all 0s. 
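// Illustrative sketch: an in-the-clear check of the offset trick used by the
// heuristic branch of TruncateProtocol::Compute changed above. Rank 0 adds
// 2^(k - kHeuristicBound) before truncating and subtracts the truncated
// offset afterwards, so inputs that are small in absolute value behave as if
// they were non-negative. This collapses the two-party protocol into plain
// local arithmetic on a 64-bit ring; the constants are stand-ins.
#include <cstdint>
#include <cstdio>

int main() {
  const int kBound = 4;    // stand-in for TruncateProtocol::kHeuristicBound
  const int kBits = 64;    // ring width
  const int shift = 12;    // truncation amount

  int64_t x = -123456789;  // assumed |x| < 2^(kBits - kBound)

  uint64_t big = uint64_t(1) << (kBits - kBound);        // offset added before truncation
  uint64_t adjusted = uint64_t(x) + big;                 // now "positive" in the ring
  uint64_t truncated = adjusted >> shift;                // a plain logical shift suffices
  int64_t result = int64_t(truncated - (big >> shift));  // remove the truncated offset

  std::printf("%lld vs %lld\n", (long long)result, (long long)(x >> shift));  // both -30141
}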
auto iform = ring_zeros(field, inp.shape()); - DISPATCH_ALL_FIELDS(field, "conv_to_bits", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto bits = NdArrayView(_bits); auto digit = NdArrayView(iform); for (int64_t i = 0; i < n; ++i) { @@ -163,7 +163,7 @@ NdArrayRef BasicOTProtocols::PackedB2A(const NdArrayRef &inp) { // compute c + (1 - 2*c)* NdArrayRef oup = ring_zeros(field, inp.shape()); - DISPATCH_ALL_FIELDS(field, "packed_b2a", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using u2k = std::make_unsigned::type; int rank = Rank(); auto xr = NdArrayView(rand_bits); @@ -205,7 +205,7 @@ NdArrayRef BasicOTProtocols::SingleB2A(const NdArrayRef &inp, int bit_width) { const int64_t n = inp.numel(); NdArrayRef oup = ring_zeros(field, inp.shape()); - DISPATCH_ALL_FIELDS(field, "single_b2a", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using u2k = std::make_unsigned::type; const u2k msk = makeBitsMask(bit_width); auto input = NdArrayView(inp); @@ -373,7 +373,7 @@ std::array BasicOTProtocols::AndTriple(FieldType field, auto AND_b = ring_zeros(field, shape); auto AND_c = ring_zeros(field, shape); - DISPATCH_ALL_FIELDS(field, "AndTriple", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto AND_xa = NdArrayView(AND_a); auto AND_xb = NdArrayView(AND_b); auto AND_xc = NdArrayView(AND_c); @@ -430,7 +430,7 @@ std::array BasicOTProtocols::CorrelatedAndTriple( auto AND_b1 = ring_zeros(field, shape); auto AND_c1 = ring_zeros(field, shape); - DISPATCH_ALL_FIELDS(field, "AndTriple", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto AND_xa = NdArrayView(AND_a); auto AND_xb0 = NdArrayView(AND_b0); auto AND_xc0 = NdArrayView(AND_c0); @@ -465,7 +465,7 @@ NdArrayRef BasicOTProtocols::Multiplexer(const NdArrayRef &msg, std::vector sel(size); // Compute (x0 + x1) * (b0 ^ b1) // Also b0 ^ b1 = 1 - 2*b0*b1 - return DISPATCH_ALL_FIELDS(field, "Multiplexer", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _msg(msg); NdArrayView _sel(select); auto corr_data = absl::MakeSpan(&_corr_data.at(0), size); @@ -506,7 +506,7 @@ NdArrayRef BasicOTProtocols::PrivateMulxRecv(const NdArrayRef &msg, auto recv = ring_zeros(field, msg.shape()); std::vector sel(size); - DISPATCH_ALL_FIELDS(field, "convert", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _sel(select); pforeach(0, size, [&](int64_t i) { sel[i] = static_cast(_sel[i] & 1); }); @@ -526,7 +526,7 @@ NdArrayRef BasicOTProtocols::PrivateMulxRecv(const NdArrayRef &msg, // Compute (x0 + x1) * b // x0 * b + x1 * b // COT compute - DISPATCH_ALL_FIELDS(field, "MultiplexerOnPrivate", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _msg(msg); auto _recv = absl::MakeSpan(&recv.at(0), size); @@ -549,7 +549,7 @@ NdArrayRef BasicOTProtocols::PrivateMulxSend(const NdArrayRef &msg) { auto recv = ring_zeros(field, msg.shape()); // Compute (x0 + x1) * b // x0 * b + x1 * b - DISPATCH_ALL_FIELDS(field, "MultiplexerOnPrivate", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto _msg = absl::MakeConstSpan(&msg.at(0), size); auto _recv = absl::MakeSpan(&recv.at(0), size); diff --git a/libspu/mpc/cheetah/ot/basic_ot_prot_test.cc b/libspu/mpc/cheetah/ot/basic_ot_prot_test.cc index aa18321c..17f78a4d 100644 --- a/libspu/mpc/cheetah/ot/basic_ot_prot_test.cc +++ b/libspu/mpc/cheetah/ot/basic_ot_prot_test.cc @@ -82,7 +82,7 @@ TEST_P(BasicOTProtTest, SingleB2A) { auto bshr0 = ring_rand(field, shape).as(boolean_t); auto bshr1 = ring_rand(field, shape).as(boolean_t); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto mask = static_cast(-1); if 
(nbits > 0) { mask = (static_cast(1) << packed_nbits) - 1; @@ -110,7 +110,7 @@ TEST_P(BasicOTProtTest, SingleB2A) { EXPECT_EQ(ashr0.shape(), ashr1.shape()); EXPECT_EQ(shape, ashr0.shape()); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView b0(bshr0); NdArrayView b1(bshr1); NdArrayView a0(ashr0); @@ -138,7 +138,7 @@ TEST_P(BasicOTProtTest, PackedB2A) { auto bshr0 = ring_rand(field, shape).as(boolean_t); auto bshr1 = ring_rand(field, shape).as(boolean_t); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto mask = static_cast(-1); if (nbits > 0) { mask = (static_cast(1) << packed_nbits) - 1; @@ -165,7 +165,7 @@ TEST_P(BasicOTProtTest, PackedB2A) { EXPECT_EQ(ashr0.shape(), ashr1.shape()); EXPECT_EQ(ashr0.shape(), shape); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView b0(bshr0); NdArrayView b1(bshr1); NdArrayView a0(ashr0); @@ -197,7 +197,7 @@ TEST_P(BasicOTProtTest, PackedB2AFull) { auto bshr0 = ring_rand(field, shape).as(boolean_t); auto bshr1 = ring_rand(field, shape).as(boolean_t); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto mask = static_cast(-1); if (nbits > 0) { mask = (static_cast(1) << packed_nbits) - 1; @@ -228,7 +228,7 @@ TEST_P(BasicOTProtTest, PackedB2AFull) { EXPECT_EQ(ashr0.shape(), ashr1.shape()); EXPECT_EQ(ashr0.shape(), shape); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView b0(bshr0); NdArrayView b1(bshr1); NdArrayView a0(ashr0); @@ -267,7 +267,7 @@ TEST_P(BasicOTProtTest, AndTripleSparse) { } }); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { ring2k_t max = static_cast(1) << target_nbits; NdArrayView a0(triple[0][0]); NdArrayView b0(triple[0][1]); @@ -304,7 +304,7 @@ TEST_P(BasicOTProtTest, BitwiseAnd) { for (int i : {0, 1}) { lhs[i] = ring_rand(field, shape).as(boolean_t); rhs[i] = ring_rand(field, shape).as(boolean_t); - DISPATCH_ALL_FIELDS(field, "mask", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { ring2k_t mask = makeBitsMask(bw); NdArrayView L(lhs[i]); NdArrayView R(rhs[i]); @@ -324,7 +324,7 @@ TEST_P(BasicOTProtTest, BitwiseAnd) { auto expected = ring_and(ring_xor(lhs[0], lhs[1]), ring_xor(rhs[0], rhs[1])); auto got = ring_xor(out[0], out[1]); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView e(expected); NdArrayView g(got); @@ -350,7 +350,7 @@ TEST_P(BasicOTProtTest, AndTripleFull) { } }); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView a0(packed_triple[0][0]); NdArrayView b0(packed_triple[0][1]); NdArrayView c0(packed_triple[0][2]); @@ -383,7 +383,7 @@ TEST_P(BasicOTProtTest, Multiplexer) { auto bshr0 = ring_rand(field, shape).as(boolean_t); auto bshr1 = ring_rand(field, shape).as(boolean_t); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto mask = static_cast(1); NdArrayView xb0(bshr0); NdArrayView xb1(bshr1); @@ -407,7 +407,7 @@ TEST_P(BasicOTProtTest, Multiplexer) { EXPECT_EQ(computed[0].shape(), computed[1].shape()); EXPECT_EQ(computed[0].shape(), shape); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView a0(ashr0); NdArrayView a1(ashr1); NdArrayView b0(bshr0); @@ -444,7 +444,7 @@ TEST_P(BasicOTProtTest, CorrelatedAndTriple) { EXPECT_EQ(corr_triple[1][0].shape(), corr_triple[1][i].shape()); } - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto a0 = 
NdArrayView(corr_triple[0][0]); auto b0 = NdArrayView(corr_triple[0][1]); auto c0 = NdArrayView(corr_triple[0][2]); @@ -487,7 +487,7 @@ TEST_P(BasicOTProtTest, PrivateMulx) { auto ashr1 = ring_rand(field, shape); auto choices = ring_rand(field, shape).as(boolean_t); - DISPATCH_ALL_FIELDS(field, "bit", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto mask = static_cast(1); NdArrayView xb(choices); pforeach(0, xb.numel(), [&](int64_t i) { xb[i] &= mask; }); @@ -507,7 +507,7 @@ TEST_P(BasicOTProtTest, PrivateMulx) { EXPECT_EQ(computed[0].shape(), computed[1].shape()); EXPECT_EQ(computed[0].shape(), shape); - DISPATCH_ALL_FIELDS(field, "check", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView a0(ashr0); NdArrayView a1(ashr1); NdArrayView c(choices); diff --git a/libspu/mpc/cheetah/ot/emp/ferret_test.cc b/libspu/mpc/cheetah/ot/emp/ferret_test.cc index af6cb96a..b878bdbe 100644 --- a/libspu/mpc/cheetah/ot/emp/ferret_test.cc +++ b/libspu/mpc/cheetah/ot/emp/ferret_test.cc @@ -55,7 +55,7 @@ TEST_P(FerretCOTTest, ChosenCorrelationChosenChoice) { return static_cast(uniform(rdv) & 1); }); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView correlation(_correlation); std::vector computed[2]; utils::simulate(kWorldSize, [&](std::shared_ptr ctx) { @@ -86,7 +86,7 @@ TEST_P(FerretCOTTest, RndMsgRndChoice) { constexpr size_t bw = 2; size_t n = 10; - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { std::vector msg0(n); std::vector msg1(n); ring2k_t max = static_cast(1) << bw; @@ -123,7 +123,7 @@ TEST_P(FerretCOTTest, RndMsgChosenChoice) { constexpr size_t bw = 2; size_t n = 10; - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { std::vector msg0(n); std::vector msg1(n); ring2k_t max = static_cast(1) << bw; @@ -163,7 +163,7 @@ TEST_P(FerretCOTTest, ChosenMsgChosenChoice) { size_t kWorldSize = 2; int64_t n = 100; auto field = GetParam(); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using scalar_t = ring2k_t; std::default_random_engine rdv; std::uniform_int_distribution uniform(0, -1); diff --git a/libspu/mpc/cheetah/ot/ot_util.cc b/libspu/mpc/cheetah/ot/ot_util.cc index 1d308a2f..3d593512 100644 --- a/libspu/mpc/cheetah/ot/ot_util.cc +++ b/libspu/mpc/cheetah/ot/ot_util.cc @@ -61,7 +61,7 @@ NdArrayRef OpenShare(const NdArrayRef &shr, ReduceOp op, size_t nbits, size_t compact_numel = CeilDiv(numel * nbits, fwidth); NdArrayRef out(shr.eltype(), {(int64_t)numel}); - DISPATCH_ALL_FIELDS(field, "zip", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto inp = absl::MakeConstSpan(&shr.at(0), numel); auto oup = absl::MakeSpan(&out.at(0), compact_numel); @@ -94,8 +94,10 @@ NdArrayRef OpenShare(const NdArrayRef &shr, ReduceOp op, size_t nbits, __attribute__((target("sse2"))) #endif void SseTranspose(uint8_t *out, uint8_t const *inp, uint64_t nrows, uint64_t ncols) { - uint64_t rr, cc; - int i, h; + uint64_t rr; + uint64_t cc; + int i; + int h; union { __m128i x; uint8_t b[16]; @@ -116,7 +118,9 @@ void SseTranspose(uint8_t *out, uint8_t const *inp, uint64_t nrows, uint64_t nco *(uint16_t *)&OUT(rr, cc + i) = _mm_movemask_epi8(vec); } } - if (rr == nrows) return; + if (rr == nrows) { + return; + } // The remainder is a block of 8x(16n+8) bits (n may be 0). 
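// Illustrative sketch: a plain (non-SIMD) reference for the 8x8 bit-matrix
// blocks that SseTranspose above handles with SSE intrinsics. The packing
// convention here (one row per byte, column = bit index) is an assumption for
// the demo, not a statement about the library's layout.
#include <cstdint>
#include <cstdio>

void transpose8x8(const uint8_t in[8], uint8_t out[8]) {
  for (int r = 0; r < 8; ++r) {
    uint8_t v = 0;
    for (int c = 0; c < 8; ++c) {
      v |= ((in[c] >> r) & 1) << c;  // element (c, r) of input -> element (r, c) of output
    }
    out[r] = v;
  }
}

int main() {
  uint8_t in[8] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80};  // identity matrix
  uint8_t out[8];
  transpose8x8(in, out);
  for (int i = 0; i < 8; ++i) std::printf("%02x ", out[i]);  // identity is its own transpose
  std::printf("\n");
}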
// Do a PAIR of 8x8 blocks in each step: @@ -153,9 +157,12 @@ void SseTranspose(uint8_t *out, uint8_t const *inp, uint64_t nrows, uint64_t nco if (cc == ncols) return; // Do the remaining 8x8 block: - for (i = 0; i < 8; ++i) tmp.b[i] = INP(rr + i, cc); - for (i = 8; --i >= 0; tmp.x = _mm_slli_epi64(tmp.x, 1)) + for (i = 0; i < 8; ++i) { + tmp.b[i] = INP(rr + i, cc); + } + for (i = 8; --i >= 0; tmp.x = _mm_slli_epi64(tmp.x, 1)) { OUT(rr, cc + i) = _mm_movemask_epi8(tmp.x); + } } } // namespace spu::mpc::cheetah diff --git a/libspu/mpc/cheetah/ot/ot_util_test.cc b/libspu/mpc/cheetah/ot/ot_util_test.cc index ea5e0aac..149e9694 100644 --- a/libspu/mpc/cheetah/ot/ot_util_test.cc +++ b/libspu/mpc/cheetah/ot/ot_util_test.cc @@ -14,8 +14,6 @@ #include "libspu/mpc/cheetah/ot/ot_util.h" -#include - #include "gtest/gtest.h" #include "libspu/mpc/utils/ring_ops.h" @@ -40,7 +38,7 @@ TEST_P(OtUtilTest, ZipArray) { auto unzip = ring_zeros(field, {n}); - DISPATCH_ALL_FIELDS(field, "UT_ZipArray", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { for (size_t bw : {1, 2, 4, 7, 15, 16}) { int64_t pack_load = elsze * 8 / bw; auto zip = ring_zeros(field, {(n + pack_load - 1) / pack_load}); @@ -69,7 +67,7 @@ TEST_P(OtUtilTest, ZipArrayBit) { auto unzip = ring_zeros(field, {n}); - DISPATCH_ALL_FIELDS(field, "UT_ZipArrayBit", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { const size_t elsze = SizeOf(field); for (size_t bw : {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) { size_t width = elsze * 8; diff --git a/libspu/mpc/cheetah/ot/yacl/ferret_test.cc b/libspu/mpc/cheetah/ot/yacl/ferret_test.cc index c78abe62..a1eea9e7 100644 --- a/libspu/mpc/cheetah/ot/yacl/ferret_test.cc +++ b/libspu/mpc/cheetah/ot/yacl/ferret_test.cc @@ -55,7 +55,7 @@ TEST_P(FerretCOTTest, ChosenCorrelationChosenChoice) { return static_cast(uniform(rdv) & 1); }); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView correlation(_correlation); std::vector computed[2]; utils::simulate(kWorldSize, [&](std::shared_ptr ctx) { @@ -87,7 +87,7 @@ TEST_P(FerretCOTTest, RndMsgRndChoice) { constexpr size_t bw = 2; size_t n = 10; - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { std::vector msg0(n); std::vector msg1(n); ring2k_t max = static_cast(1) << bw; @@ -125,7 +125,7 @@ TEST_P(FerretCOTTest, RndMsgChosenChoice) { constexpr size_t bw = 2; size_t n = 10; - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { std::vector msg0(n); std::vector msg1(n); ring2k_t max = static_cast(1) << bw; @@ -166,7 +166,7 @@ TEST_P(FerretCOTTest, ChosenMsgChosenChoice) { int64_t n = 1 << 10; auto field = std::get<0>(GetParam()); auto use_ss = std::get<1>(GetParam()); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using scalar_t = ring2k_t; std::default_random_engine rdv; std::uniform_int_distribution uniform(0, -1); diff --git a/libspu/mpc/cheetah/ot/yacl/yacl_ote_adapter.h b/libspu/mpc/cheetah/ot/yacl/yacl_ote_adapter.h index 294e095f..bd37640b 100644 --- a/libspu/mpc/cheetah/ot/yacl/yacl_ote_adapter.h +++ b/libspu/mpc/cheetah/ot/yacl/yacl_ote_adapter.h @@ -19,8 +19,8 @@ #include "yacl/kernel/algorithms/base_ot.h" #include "yacl/kernel/algorithms/ferret_ote.h" #include "yacl/kernel/algorithms/iknp_ote.h" -#include "yacl/kernel/algorithms/ot_store.h" #include "yacl/kernel/algorithms/softspoken_ote.h" +#include "yacl/kernel/type/ot_store.h" #include "libspu/core/prelude.h" #include "libspu/mpc/cheetah/ot/ot_util.h" diff --git a/libspu/mpc/cheetah/rlwe/modswitch_helper.cc 
b/libspu/mpc/cheetah/rlwe/modswitch_helper.cc index 69c86704..41ea9ff1 100644 --- a/libspu/mpc/cheetah/rlwe/modswitch_helper.cc +++ b/libspu/mpc/cheetah/rlwe/modswitch_helper.cc @@ -436,7 +436,7 @@ void ModulusSwitchHelper::ModulusUpAt(const NdArrayRef &src, size_t mod_idx, SPU_ENFORCE(eltype.isa(), "source must be ring_type, got={}", eltype); const auto field = eltype.as()->field(); - DISPATCH_ALL_FIELDS(field, "ModulusUpAt", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using ring2u = std::make_unsigned::type; impl_->ModulusUpAt(NdArrayView(src), mod_idx, out); }); @@ -450,7 +450,7 @@ void ModulusSwitchHelper::CenteralizeAt(const NdArrayRef &src, size_t mod_idx, SPU_ENFORCE_EQ(numel, out.size()); SPU_ENFORCE(eltype.isa(), "source must be ring_type, got={}", eltype); const auto field = eltype.as()->field(); - DISPATCH_ALL_FIELDS(field, "CenteralizeAt", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using ring2u = std::make_unsigned::type; impl_->CenteralizeAt(NdArrayView(src), mod_idx, out); }); @@ -481,7 +481,7 @@ void ModulusSwitchHelper::ModulusDownRNS(absl::Span src, size_t num_elt = out.numel(); SPU_ENFORCE_EQ(num_elt * num_modulus, src.size()); - return DISPATCH_ALL_FIELDS(field, "ModulusDownRNS", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using ring2u = std::make_unsigned::type; absl::Span out_wrap(reinterpret_cast(out.data()), num_elt); diff --git a/libspu/mpc/cheetah/rlwe/modswitch_helper_test.cc b/libspu/mpc/cheetah/rlwe/modswitch_helper_test.cc index ecb5beee..2aa7a203 100644 --- a/libspu/mpc/cheetah/rlwe/modswitch_helper_test.cc +++ b/libspu/mpc/cheetah/rlwe/modswitch_helper_test.cc @@ -120,9 +120,10 @@ TEST_P(RLWE2LWETest, ModulusSwitch_UpDown) { } // e = round(r'/Delta) mod t \in R_t auto src = absl::MakeSpan(pt.data(), pt.coeff_count()); - auto cmp = ms_helper_->ModulusDownRNS(field_, {(int64_t)poly_deg}, src); + auto cmp = ms_helper_->ModulusDownRNS( + field_, {static_cast(poly_deg)}, src); // check r =? 
e - DISPATCH_ALL_FIELDS(field_, "", [&]() { + DISPATCH_ALL_FIELDS(field_, [&]() { auto expected = NdArrayView(_vec); auto computed = NdArrayView(cmp); for (int64_t i = 0; i < expected.numel(); ++i) { @@ -145,7 +146,7 @@ TEST_P(RLWE2LWETest, ModulusSwitch_DownUp) { // r <- R_q RLWEPt rnd; UniformPoly(*context_, &rnd); - auto &modulus = context_->first_context_data()->parms().coeff_modulus(); + const auto &modulus = context_->first_context_data()->parms().coeff_modulus(); // b = a' - r RLWEPt poly1; { @@ -158,8 +159,8 @@ TEST_P(RLWE2LWETest, ModulusSwitch_DownUp) { // r' = round(r/Delta) mod t \in R_t // b' = round(b/Delta) mod t \in R_t - auto shr0 = ring_zeros(field_, {(int64_t)poly_deg}); - auto shr1 = ring_zeros(field_, {(int64_t)poly_deg}); + auto shr0 = ring_zeros(field_, {static_cast(poly_deg)}); + auto shr1 = ring_zeros(field_, {static_cast(poly_deg)}); { auto src = absl::MakeSpan(rnd.data(), rnd.coeff_count()); ms_helper_->ModulusDownRNS(src, shr0); @@ -168,7 +169,7 @@ TEST_P(RLWE2LWETest, ModulusSwitch_DownUp) { ms_helper_->ModulusDownRNS(src, shr1); } - DISPATCH_ALL_FIELDS(field_, "", [&]() { + DISPATCH_ALL_FIELDS(field_, [&]() { auto expected = NdArrayView(vec_a); auto computed0 = NdArrayView(shr0); auto computed1 = NdArrayView(shr1); diff --git a/libspu/mpc/cheetah/rlwe/packlwes.cc b/libspu/mpc/cheetah/rlwe/packlwes.cc index 59f99e4c..f2c7272f 100644 --- a/libspu/mpc/cheetah/rlwe/packlwes.cc +++ b/libspu/mpc/cheetah/rlwe/packlwes.cc @@ -134,7 +134,7 @@ void PackingHelper::doPackingRLWEs(absl::Span rlwes, seal::Evaluator evaluator(context_); const int64_t logn = absl::bit_width(gap_) - 1; for (int64_t k = logn; k >= 1; --k) { - int64_t h = 1 << (k - 1); + int64_t h = static_cast(1) << (k - 1); yacl::parallel_for(0, h, [&](int64_t bgn, int64_t end) { RLWECt dummy; // zero-padding with zero RLWE for (int64_t i = bgn; i < end; ++i) { @@ -188,7 +188,7 @@ void GenerateGaloisKeyForPacking(const seal::SEALContext &context, size_t logN = absl::bit_width(N) - 1; std::vector galois_elt; for (uint32_t i = 1; i <= logN; i++) { - galois_elt.push_back((1u << i) + 1); + galois_elt.push_back((static_cast(1) << i) + 1); } seal::KeyGenerator keygen(context, key); diff --git a/libspu/mpc/cheetah/rlwe/packlwes_test.cc b/libspu/mpc/cheetah/rlwe/packlwes_test.cc index 6d924c26..4bda1a27 100644 --- a/libspu/mpc/cheetah/rlwe/packlwes_test.cc +++ b/libspu/mpc/cheetah/rlwe/packlwes_test.cc @@ -13,8 +13,6 @@ // limitations under the License. 
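// Illustrative sketch: the packlwes.cc hunks above widen the literal 1 before
// shifting (static_cast to a 64-bit type before <<). The general reason, shown
// below, is that shifting a plain 32-bit int by 31 or more positions is
// undefined behaviour / overflow; whether the exponents in packlwes.cc can
// actually get that large is not asserted here -- the change also simply
// satisfies the usual lint rules about shift widths.
#include <cstdint>
#include <cstdio>

int main() {
  int k = 40;
  // int64_t bad = 1 << (k - 1);                        // UB: shift count exceeds int width
  int64_t good = static_cast<int64_t>(1) << (k - 1);    // well-defined 2^39
  std::printf("%lld\n", static_cast<long long>(good));  // 549755813888
}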
#include "libspu/mpc/cheetah/rlwe/packlwes.h" -#include - #include "gtest/gtest.h" #include "seal/seal.h" @@ -235,7 +233,8 @@ TEST_P(PackLWEsTest, Phantom) { N_encoder_->Forward(array, &pt, true); NttInplace(pt, *N_context_); - RLWECt rlwe0, rlwe1; + RLWECt rlwe0; + RLWECt rlwe1; CATCH_SEAL_ERROR(encryptor.encrypt_symmetric(pt, rlwe0)); CATCH_SEAL_ERROR(encryptor.encrypt_symmetric(pt, rlwe1)); if (rlwe0.is_ntt_form()) { @@ -345,7 +344,7 @@ void VectorEncoder::Backward(const NdArrayRef &vec, RLWEPt *out, const auto field = eltype.as()->field(); - DISPATCH_ALL_FIELDS(field, "Backward", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto tmp_buff = ring_zeros(field, {(int64_t)poly_deg_}); auto xvec = NdArrayView(vec); auto xtmp = NdArrayView(tmp_buff); diff --git a/libspu/mpc/cheetah/state.cc b/libspu/mpc/cheetah/state.cc index 1f754baa..dded2f5f 100644 --- a/libspu/mpc/cheetah/state.cc +++ b/libspu/mpc/cheetah/state.cc @@ -16,8 +16,6 @@ #include -#include "spdlog/spdlog.h" - #include "libspu/core/context.h" #include "libspu/core/ndarray_ref.h" #include "libspu/core/prelude.h" @@ -89,7 +87,7 @@ void CheetahMulState::makeSureCacheSize(FieldType field, int64_t numel) { cross.slice({num_beaver}, {2 * num_beaver}, {1})), ring_mul(beaver[0], beaver[1])); - DISPATCH_ALL_FIELDS(field, "makeSureCacheSize", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { for (size_t i : {0, 1, 2}) { auto tmp = ring_zeros(field, {num_beaver + cached_sze_}); NdArrayView new_cache(tmp); diff --git a/libspu/mpc/common/prg_state.h b/libspu/mpc/common/prg_state.h index 7e8938e6..47a6eb66 100644 --- a/libspu/mpc/common/prg_state.h +++ b/libspu/mpc/common/prg_state.h @@ -39,9 +39,10 @@ class PrgState : public State { uint64_t priv_counter_ = 0; // Pseudorandom Secret Sharing seeds. - uint128_t next_seed_ = 0; uint128_t self_seed_ = 0; - uint64_t prss_counter_ = 0; + uint128_t next_seed_ = 0; + uint64_t r0_counter_ = 0; // cnt for self_seed + uint64_t r1_counter_ = 0; // cnt for next_seed public: static constexpr char kBindName[] = "PrgState"; @@ -65,7 +66,7 @@ class PrgState : public State { // This correlation could be used to construct zero shares. // // Note: ignore_first, ignore_second is for perf improvement. 
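// Illustrative sketch (not yacl/libspu code): the per-seed counters that this
// PrgState refactor introduces, with a toy counter-mode PRG standing in for
// yacl's FillPRand. Because each seed keeps its own counter, drawing from
// only one stream advances only that stream, so the old dummy update for
// GenPrssCtrl::None is no longer needed to keep the pair aligned -- which is
// why the None variant disappears from the enum below.
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

// Toy counter-mode PRG: block i of a stream is f(seed, counter + i).
uint64_t toy_prg(uint64_t seed, uint64_t counter) {
  uint64_t x = seed ^ (counter * 0x9E3779B97F4A7C15ULL);
  x ^= x >> 33; x *= 0xFF51AFD7ED558CCDULL; x ^= x >> 33;
  return x;
}

struct DemoPrssState {
  uint64_t self_seed = 1, next_seed = 2;
  uint64_t r0_counter = 0, r1_counter = 0;  // one counter per seed

  std::vector<uint64_t> fill_first(size_t n) {   // analogue of GenPrssCtrl::First
    std::vector<uint64_t> out(n);
    for (auto& v : out) v = toy_prg(self_seed, r0_counter++);
    return out;
  }
  std::vector<uint64_t> fill_second(size_t n) {  // analogue of GenPrssCtrl::Second
    std::vector<uint64_t> out(n);
    for (auto& v : out) v = toy_prg(next_seed, r1_counter++);
    return out;
  }
};

int main() {
  DemoPrssState s;
  s.fill_first(3);   // advances only r0_counter
  s.fill_second(5);  // advances only r1_counter
  std::printf("r0=%llu r1=%llu\n", (unsigned long long)s.r0_counter,
              (unsigned long long)s.r1_counter);  // r0=3 r1=5
}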
- enum class GenPrssCtrl { Both, First, Second, None }; + enum class GenPrssCtrl { Both, First, Second }; std::pair genPrssPair(FieldType field, const Shape& shape, GenPrssCtrl ctrl); @@ -73,29 +74,21 @@ class PrgState : public State { template void fillPrssPair(T* r0, T* r1, size_t numel, GenPrssCtrl ctrl) { switch (ctrl) { - case GenPrssCtrl::None: { - // Nothing to generate, pure dummy - prss_counter_ = yacl::crypto::DummyUpdateRandomCount(prss_counter_, - numel * sizeof(T)); - return; - } case GenPrssCtrl::First: { - prss_counter_ = yacl::crypto::FillPRand( - kAesType, self_seed_, 0, prss_counter_, absl::MakeSpan(r0, numel)); + r0_counter_ = yacl::crypto::FillPRand( + kAesType, self_seed_, 0, r0_counter_, absl::MakeSpan(r0, numel)); return; } case GenPrssCtrl::Second: { - prss_counter_ = yacl::crypto::FillPRand( - kAesType, next_seed_, 0, prss_counter_, absl::MakeSpan(r1, numel)); + r1_counter_ = yacl::crypto::FillPRand( + kAesType, next_seed_, 0, r1_counter_, absl::MakeSpan(r1, numel)); return; } case GenPrssCtrl::Both: { - auto counter0 = yacl::crypto::FillPRand( - kAesType, self_seed_, 0, prss_counter_, absl::MakeSpan(r0, numel)); - auto counter1 = yacl::crypto::FillPRand( - kAesType, next_seed_, 0, prss_counter_, absl::MakeSpan(r1, numel)); - SPU_ENFORCE(counter0 == counter1); - prss_counter_ = counter0; + r0_counter_ = yacl::crypto::FillPRand( + kAesType, self_seed_, 0, r0_counter_, absl::MakeSpan(r0, numel)); + r1_counter_ = yacl::crypto::FillPRand( + kAesType, next_seed_, 0, r1_counter_, absl::MakeSpan(r1, numel)); return; } } diff --git a/libspu/mpc/common/pv2k.cc b/libspu/mpc/common/pv2k.cc index 6c777779..ce59f9bb 100644 --- a/libspu/mpc/common/pv2k.cc +++ b/libspu/mpc/common/pv2k.cc @@ -73,7 +73,7 @@ class V2P : public UnaryKernel { auto numel = in.numel(); - DISPATCH_ALL_FIELDS(field, "v2p", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { std::vector priv(numel); NdArrayView _in(in); @@ -113,7 +113,7 @@ class MakeP : public Kernel { Strides(shape.size(), 0), // strides 0); - DISPATCH_ALL_FIELDS(field, "pub2k.make_p", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { arr.at(Index(shape.size(), 0)) = static_cast(init); }); return Value(arr, DT_INVALID); @@ -176,7 +176,8 @@ class MsbP : public UnaryKernel { ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext*, const NdArrayRef& in) const override { - return ring_rshift(in, in.elsize() * 8 - 1).as(in.eltype()); + return ring_rshift(in, {static_cast(in.elsize() * 8 - 1)}) + .as(in.eltype()); } }; @@ -190,7 +191,8 @@ class MsbV : public UnaryKernel { NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in) const override { if (isOwner(ctx, in.eltype())) { - return ring_rshift(in, in.elsize() * 8 - 1).as(in.eltype()); + return ring_rshift(in, {static_cast(in.elsize() * 8 - 1)}) + .as(in.eltype()); } else { return in; } @@ -512,7 +514,7 @@ class LShiftP : public ShiftKernel { ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext*, const NdArrayRef& in, - size_t bits) const override { + const Sizes& bits) const override { return ring_lshift(in, bits).as(in.eltype()); } }; @@ -526,7 +528,7 @@ class LShiftV : public ShiftKernel { ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const override { + const Sizes& bits) const override { if (isOwner(ctx, in.eltype())) { return ring_lshift(in, bits).as(in.eltype()); } else { @@ -544,7 +546,7 @@ class RShiftP : public ShiftKernel { 
ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext*, const NdArrayRef& in, - size_t bits) const override { + const Sizes& bits) const override { return ring_rshift(in, bits).as(in.eltype()); } }; @@ -558,7 +560,7 @@ class RShiftV : public ShiftKernel { ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const override { + const Sizes& bits) const override { if (isOwner(ctx, in.eltype())) { return ring_rshift(in, bits).as(in.eltype()); } else { @@ -576,7 +578,7 @@ class ARShiftP : public ShiftKernel { ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext*, const NdArrayRef& in, - size_t bits) const override { + const Sizes& bits) const override { return ring_arshift(in, bits).as(in.eltype()); } }; @@ -590,7 +592,7 @@ class ARShiftV : public ShiftKernel { ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const override { + const Sizes& bits) const override { if (isOwner(ctx, in.eltype())) { return ring_arshift(in, bits).as(in.eltype()); } else { @@ -607,8 +609,8 @@ NdArrayRef rounded_arshift(const NdArrayRef& in, size_t bits) { // https://stackoverflow.com/questions/14008330/how-do-you-multiply-two-fixed-point-numbers // Under certain pattern, like sum(mul(A, B)), error can accumulate in a // fairly significant way - auto v1 = ring_arshift(in, bits); - auto v2 = ring_arshift(in, bits - 1); + auto v1 = ring_arshift(in, {static_cast(bits)}); + auto v2 = ring_arshift(in, {static_cast(bits - 1)}); ring_and_(v2, ring_ones(in.eltype().as()->field(), in.shape())); ring_add_(v1, v2); return v1; @@ -623,8 +625,9 @@ class TruncP : public ShiftKernel { ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext*, const NdArrayRef& in, - size_t bits) const override { - return rounded_arshift(in, bits).as(in.eltype()); + const Sizes& bits) const override { + SPU_ENFORCE(bits.size() == 1, "truncation bits should be splat"); + return rounded_arshift(in, bits[0]).as(in.eltype()); } }; @@ -637,9 +640,10 @@ class TruncV : public ShiftKernel { ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const override { + const Sizes& bits) const override { if (isOwner(ctx, in.eltype())) { - return rounded_arshift(in, bits).as(in.eltype()); + SPU_ENFORCE(bits.size() == 1, "truncation bits should be splat"); + return rounded_arshift(in, bits[0]).as(in.eltype()); } else { return in; } @@ -701,7 +705,7 @@ class GenInvPermP : public GenInvPermKernel { auto numel = in.numel(); - DISPATCH_ALL_FIELDS(field, "gen_inv_perm_p", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using T = std::make_signed_t; std::vector perm(numel); std::iota(perm.begin(), perm.end(), 0); @@ -735,7 +739,7 @@ class GenInvPermV : public GenInvPermKernel { auto numel = in.numel(); const auto field = in.eltype().as()->field(); - DISPATCH_ALL_FIELDS(field, "gen_inv_perm_v", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using T = std::make_signed_t; std::vector perm(numel); std::iota(perm.begin(), perm.end(), 0); @@ -769,7 +773,7 @@ class InvPermPP : public PermKernel { SPU_ENFORCE_EQ(x.eltype(), y.eltype()); NdArrayRef z(x.eltype(), x.shape()); const auto field = x.eltype().as()->field(); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using T = std::make_signed_t; NdArrayView 
_x(x); NdArrayView _y(y); @@ -795,7 +799,7 @@ class InvPermVV : public PermKernel { if (isOwner(ctx, x.eltype())) { NdArrayRef z(x.eltype(), x.shape()); const auto field = x.eltype().as()->field(); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using T = std::make_signed_t; NdArrayView _x(x); NdArrayView _y(y); @@ -823,7 +827,7 @@ class PermPP : public PermKernel { SPU_ENFORCE_EQ(x.eltype(), y.eltype()); NdArrayRef z(x.eltype(), x.shape()); const auto field = x.eltype().as()->field(); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using T = std::make_signed_t; NdArrayView _x(x); NdArrayView _y(y); @@ -849,7 +853,7 @@ class PermVV : public PermKernel { if (isOwner(ctx, x.eltype())) { NdArrayRef z(x.eltype(), x.shape()); const auto field = x.eltype().as()->field(); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using T = std::make_signed_t; NdArrayView _x(x); NdArrayView _y(y); @@ -878,7 +882,7 @@ class MergeKeysP : public MergeKeysKernel { NdArrayRef out(inputs[0].eltype(), inputs[0].shape()); const auto field = inputs[0].eltype().as()->field(); const auto numel = inputs[0].numel(); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using T = std::make_signed_t; NdArrayView _out(out); _out[0] = 0; @@ -917,7 +921,7 @@ class MergeKeysV : public MergeKeysKernel { NdArrayRef out(inputs[0].eltype(), inputs[0].shape()); const auto field = inputs[0].eltype().as()->field(); const auto numel = inputs[0].numel(); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using T = std::make_signed_t; NdArrayView _out(out); _out[0] = 0; diff --git a/libspu/mpc/kernel.cc b/libspu/mpc/kernel.cc index 45bed576..56b5c1bb 100644 --- a/libspu/mpc/kernel.cc +++ b/libspu/mpc/kernel.cc @@ -43,7 +43,11 @@ void RevealToKernel::evaluate(KernelEvalContext* ctx) const { void ShiftKernel::evaluate(KernelEvalContext* ctx) const { const auto& in = ctx->getParam(0); - size_t bits = ctx->getParam(1); + const auto& bits = ctx->getParam(1); + + SPU_ENFORCE( + bits.size() == 1 || in.numel() == static_cast(bits.size()), + "numel mismatch {} {}", in.numel(), bits.size()); auto res = proc(ctx, UnwrapValue(in), bits); @@ -233,6 +237,17 @@ void ConcateKernel::evaluate(KernelEvalContext* ctx) const { ctx->pushOutput(WrapValue(z)); } +void DisassembleKernel::evaluate(KernelEvalContext* ctx) const { + const auto& in = ctx->getParam(0); + auto z = proc(ctx, UnwrapValue(in)); + + std::vector wrapped(z.size()); + for (size_t idx = 0; idx < z.size(); ++idx) { + wrapped[idx] = WrapValue(z[idx]); + } + ctx->pushOutput(wrapped); +}; + void OramOneHotKernel::evaluate(KernelEvalContext* ctx) const { auto target = ctx->getParam(0); auto s = ctx->getParam(1); diff --git a/libspu/mpc/kernel.h b/libspu/mpc/kernel.h index 4a4c29d8..d75b3426 100644 --- a/libspu/mpc/kernel.h +++ b/libspu/mpc/kernel.h @@ -42,7 +42,7 @@ class ShiftKernel : public Kernel { public: void evaluate(KernelEvalContext* ctx) const override; virtual NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const = 0; + const Sizes& bits) const = 0; }; class BinaryKernel : public Kernel { @@ -217,4 +217,12 @@ class ConcateKernel : public Kernel { int64_t axis) const = 0; }; +class DisassembleKernel : public Kernel { + public: + void evaluate(KernelEvalContext* ctx) const override; + + virtual std::vector proc(KernelEvalContext* ctx, + const NdArrayRef& in) const = 0; +}; + } // namespace spu::mpc diff --git 
a/libspu/mpc/ref2k/ref2k.cc b/libspu/mpc/ref2k/ref2k.cc index 1d76ea95..cc75e8a7 100644 --- a/libspu/mpc/ref2k/ref2k.cc +++ b/libspu/mpc/ref2k/ref2k.cc @@ -165,7 +165,7 @@ class Ref2kV2S : public UnaryKernel { int64_t numel = in.numel(); - DISPATCH_ALL_FIELDS(field, "v2s", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { std::vector _send(numel); NdArrayView _in(in); @@ -196,7 +196,7 @@ class Ref2kRandS : public RandKernel { const auto field = ctx->getState()->getDefaultField(); return ring_rshift( - state->genPubl(field, shape).as(makeType(field)), 2); + state->genPubl(field, shape).as(makeType(field)), {2}); } }; @@ -368,7 +368,7 @@ class Ref2kLShiftS : public ShiftKernel { ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const override { + const Sizes& bits) const override { return ring_lshift(in, bits).as(in.eltype()); } }; @@ -382,7 +382,7 @@ class Ref2kRShiftS : public ShiftKernel { ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const override { + const Sizes& bits) const override { return ring_rshift(in, bits).as(in.eltype()); } }; @@ -414,7 +414,7 @@ class Ref2kARShiftS : public ShiftKernel { ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const override { + const Sizes& bits) const override { return ring_arshift(in, bits).as(in.eltype()); } }; @@ -441,8 +441,8 @@ class Ref2kTruncS : public TruncAKernel { // https://stackoverflow.com/questions/14008330/how-do-you-multiply-two-fixed-point-numbers // Under certain pattern, like sum(mul(A, B)), error can accumulate in a // fairly significant way - auto v1 = ring_arshift(in, bits); - auto v2 = ring_arshift(in, bits - 1); + auto v1 = ring_arshift(in, {static_cast(bits)}); + auto v2 = ring_arshift(in, {static_cast(bits - 1)}); ring_and_(v2, ring_ones(in.eltype().as()->field(), in.shape())); ring_add_(v1, v2); return v1; @@ -458,7 +458,8 @@ class Ref2kMsbS : public UnaryKernel { ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in) const override { - return ring_rshift(in, in.elsize() * 8 - 1).as(in.eltype()); + return ring_rshift(in, {static_cast(in.elsize() * 8 - 1)}) + .as(in.eltype()); } }; diff --git a/libspu/mpc/securenn/arithmetic.cc b/libspu/mpc/securenn/arithmetic.cc index 242b8e75..365637d7 100644 --- a/libspu/mpc/securenn/arithmetic.cc +++ b/libspu/mpc/securenn/arithmetic.cc @@ -19,8 +19,6 @@ #include #include "libspu/core/type_util.h" -#include "libspu/core/vectorize.h" -#include "libspu/mpc/ab_api.h" #include "libspu/mpc/common/communicator.h" #include "libspu/mpc/common/prg_state.h" #include "libspu/mpc/common/pv2k.h" @@ -37,7 +35,7 @@ NdArrayRef A2V::proc(KernelEvalContext* ctx, const NdArrayRef& in, auto numel = in.numel(); - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { std::vector share(numel); NdArrayView _in(in); pforeach(0, numel, [&](int64_t idx) { share[idx] = _in[idx]; }); @@ -109,7 +107,7 @@ NdArrayRef RandA::proc(KernelEvalContext* ctx, const Shape& shape) const { // - https://eprint.iacr.org/2019/599.pdf // It's safer to keep the number within [-2**(k-2), 2**(k-2)) for comparison // operations. 
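// Illustrative sketch: several RandA kernels in this diff (securenn, semi2k)
// sample a full-width random ring element and right-shift it by two,
// ring_rshift(..., {2}), matching the comment about keeping values inside
// [-2^(k-2), 2^(k-2)) so later comparison protocols cannot wrap. A quick
// in-the-clear check with k = 64:
#include <cstdint>
#include <cstdio>

int main() {
  uint64_t raw = ~0ULL;          // worst-case PRG output (all ones)
  uint64_t bounded = raw >> 2;   // what the kernel does element-wise
  int64_t as_signed = static_cast<int64_t>(bounded);
  std::printf("%d\n", as_signed >= 0 && as_signed < (int64_t(1) << 62));  // prints 1
}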
- return ring_rshift(prg_state->genPriv(field, shape), 2) + return ring_rshift(prg_state->genPriv(field, shape), {2}) .as(makeType(field)); } @@ -200,10 +198,7 @@ NdArrayRef MatMulAP::proc(KernelEvalContext* ctx, const NdArrayRef& x, } NdArrayRef LShiftA::proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const { - const auto field = in.eltype().as()->field(); - bits %= SizeOf(field) * 8; - + const Sizes& bits) const { return ring_lshift(in, bits).as(in.eltype()); } @@ -216,10 +211,10 @@ NdArrayRef TruncAPr::proc(KernelEvalContext* ctx, const NdArrayRef& in, auto rank = comm->getRank(); const auto numel = in.numel(); const auto field = in.eltype().as()->field(); - const size_t k = SizeOf(field) * 8; + const int k = SizeOf(field) * 8; NdArrayRef out(in.eltype(), in.shape()); - DISPATCH_ALL_FIELDS(field, "securenn.truncpr", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using U = ring2k_t; auto r = prg_state->genPriv(field, in.shape()); @@ -232,9 +227,11 @@ NdArrayRef TruncAPr::proc(KernelEvalContext* ctx, const NdArrayRef& in, auto rb_recon = comm->reduce(ReduceOp::ADD, rb, 2, "rb"); if (rank == 2) { - auto adjust1 = - ring_sub(ring_rshift(ring_lshift(r_recon, 1), bits + 1), rc_recon); - auto adjust2 = ring_sub(ring_rshift(r_recon, k - 1), rb_recon); + auto adjust1 = ring_sub(ring_rshift(ring_lshift(r_recon, {1}), + {static_cast(bits + 1)}), + rc_recon); + auto adjust2 = ring_sub( + ring_rshift(r_recon, {static_cast(k - 1)}), rb_recon); comm->sendAsync(0, adjust1, "adjust1"); comm->sendAsync(0, adjust2, "adjust2"); } @@ -362,7 +359,6 @@ NdArrayRef MulAA::proc(KernelEvalContext* ctx, const NdArrayRef& x, b = prg_state ->genPrssPair(field, x.shape(), PrgState::GenPrssCtrl::Second) .second; - prg_state->genPrssPair(field, x.shape(), PrgState::GenPrssCtrl::None); c = comm->recv(2, ty, "c"); c = c.reshape(x.shape()); } @@ -527,7 +523,6 @@ NdArrayRef MatMulAA::proc(KernelEvalContext* ctx, const NdArrayRef& x, .second; b = prg_state->genPrssPair(field, shape2, PrgState::GenPrssCtrl::Second) .second; - prg_state->genPrssPair(field, shape3, PrgState::GenPrssCtrl::None); c = comm->recv(2, ty, "c"); c = c.reshape(shape3); @@ -594,7 +589,7 @@ NdArrayRef ShareConvert::proc(KernelEvalContext* ctx, const auto kComm = a.elsize() * size; comm->addCommStatsManually(4, 4 * log_p * kComm + 6 * kComm); - DISPATCH_ALL_FIELDS(field, "securenn.sc", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using U = ring2k_t; const U L_1 = (U)(~0); // 2^k - 1 // P0 and P1 add the share of zero @@ -902,7 +897,7 @@ NdArrayRef Msb::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { const auto kComm = in.elsize() * size; comm->addCommStatsManually(5, 13 * kComm + 4 * kComm * log_p); - DISPATCH_ALL_FIELDS(field, "securenn.msb", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using U = ring2k_t; const U L_1 = (U)(~0); @@ -1050,7 +1045,6 @@ NdArrayRef Msb::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { prg_state ->genPrssPair(field, in.shape(), PrgState::GenPrssCtrl::Second) .second; - prg_state->genPrssPair(field, in.shape(), PrgState::GenPrssCtrl::None); beaver_c = comm->recv(2, ty, "beaver_c"); beaver_c = beaver_c.reshape(in.shape()); } @@ -1241,7 +1235,7 @@ NdArrayRef Msb_opt::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { const auto kComm = in.elsize() * size; comm->addCommStatsManually(5, 9 * kComm + 3 * kComm * log_p); - DISPATCH_ALL_FIELDS(field, "securenn.msb", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using U = ring2k_t; const U L_1 = (U)(~0); @@ -1395,7 +1389,6 @@ NdArrayRef 
Msb_opt::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { prg_state ->genPrssPair(field, in.shape(), PrgState::GenPrssCtrl::Second) .second; - prg_state->genPrssPair(field, in.shape(), PrgState::GenPrssCtrl::None); beaver_c = comm->recv(2, ty, "beaver_c"); beaver_c = beaver_c.reshape(in.shape()); } diff --git a/libspu/mpc/securenn/arithmetic.h b/libspu/mpc/securenn/arithmetic.h index 599bf14a..e0829464 100644 --- a/libspu/mpc/securenn/arithmetic.h +++ b/libspu/mpc/securenn/arithmetic.h @@ -172,7 +172,7 @@ class LShiftA : public ShiftKernel { ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const override; + const Sizes& bits) const override; }; // Refer to: diff --git a/libspu/mpc/securenn/boolean.cc b/libspu/mpc/securenn/boolean.cc index 7d08673e..9deccec8 100644 --- a/libspu/mpc/securenn/boolean.cc +++ b/libspu/mpc/securenn/boolean.cc @@ -31,7 +31,7 @@ namespace { size_t getNumBits(const NdArrayRef& in) { if (in.eltype().isa()) { const auto field = in.eltype().as()->field(); - return DISPATCH_ALL_FIELDS(field, "_", + return DISPATCH_ALL_FIELDS(field, [&]() { return maxBitWidth(in); }); } else if (in.eltype().isa()) { return in.eltype().as()->nbits(); @@ -120,7 +120,7 @@ NdArrayRef AndBP::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, const size_t out_nbits = std::min(getNumBits(lhs), getNumBits(rhs)); NdArrayRef out(makeType(field, out_nbits), lhs.shape()); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _lhs(lhs); NdArrayView _rhs(rhs); NdArrayView _out(out); @@ -146,12 +146,12 @@ NdArrayRef AndBB::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, const int64_t numel = lhs.numel(); NdArrayRef out(makeType(field, out_nbits), lhs.shape()); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using T = ring2k_t; NdArrayView _lhs(lhs); NdArrayView _rhs(rhs); - DISPATCH_UINT_PT_TYPES(backtype, "_", [&]() { + DISPATCH_UINT_PT_TYPES(backtype, [&]() { using V = ScalarT; int64_t numBytes = numel * SizeOf(backtype); @@ -197,7 +197,6 @@ NdArrayRef AndBB::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, b = prg_state ->genPrssPair(field, {numField}, PrgState::GenPrssCtrl::Second) .second; - prg_state->genPrssPair(field, {numField}, PrgState::GenPrssCtrl::None); c = comm->recv(2, ty, "c"); c = c.reshape({numField}); } @@ -257,32 +256,32 @@ NdArrayRef XorBB::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, } NdArrayRef LShiftB::proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t shift) const { + const Sizes& shift) const { const auto field = in.eltype().as()->field(); - shift %= SizeOf(field) * 8; - size_t out_nbits = in.eltype().as()->nbits() + shift; - out_nbits = std::clamp(out_nbits, static_cast(0), SizeOf(field) * 8); + int64_t out_nbits = in.eltype().as()->nbits() + + *std::max_element(shift.begin(), shift.end()); + out_nbits = std::clamp(out_nbits, 0L, + static_cast(SizeOf(field) * 8)); return makeBShare(ring_lshift(in, shift), field, out_nbits); } NdArrayRef RShiftB::proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t shift) const { + const Sizes& shift) const { const auto field = in.eltype().as()->field(); - shift %= SizeOf(field) * 8; - size_t nbits = in.eltype().as()->nbits(); - size_t out_nbits = nbits - std::min(nbits, shift); - SPU_ENFORCE(nbits <= SizeOf(field) * 8); + int64_t nbits = in.eltype().as()->nbits(); + int64_t out_nbits = + nbits - std::min(nbits, *std::min_element(shift.begin(), 
shift.end())); + SPU_ENFORCE(nbits <= static_cast(SizeOf(field) * 8)); return makeBShare(ring_rshift(in, shift), field, out_nbits); } NdArrayRef ARShiftB::proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t shift) const { + const Sizes& shift) const { const auto field = in.eltype().as()->field(); - shift %= SizeOf(field) * 8; // arithmetic right shift expects to work on ring, or the behaviour is // undefined. @@ -310,7 +309,7 @@ NdArrayRef BitIntlB::proc(KernelEvalContext* ctx, const NdArrayRef& in, NdArrayRef out(in.eltype(), in.shape()); auto numel = in.numel(); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _in(in); NdArrayView _out(out); @@ -331,7 +330,7 @@ NdArrayRef BitDeintlB::proc(KernelEvalContext* ctx, const NdArrayRef& in, NdArrayRef out(in.eltype(), in.shape()); auto numel = in.numel(); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _in(in); NdArrayView _out(out); diff --git a/libspu/mpc/securenn/boolean.h b/libspu/mpc/securenn/boolean.h index 35433c0b..ccaaa90b 100644 --- a/libspu/mpc/securenn/boolean.h +++ b/libspu/mpc/securenn/boolean.h @@ -118,7 +118,7 @@ class LShiftB : public ShiftKernel { ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t shift) const override; + const Sizes& shift) const override; }; class RShiftB : public ShiftKernel { @@ -130,7 +130,7 @@ class RShiftB : public ShiftKernel { ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t shift) const override; + const Sizes& shift) const override; }; class ARShiftB : public ShiftKernel { @@ -142,7 +142,7 @@ class ARShiftB : public ShiftKernel { ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t shift) const override; + const Sizes& shift) const override; }; class BitrevB : public BitrevKernel { diff --git a/libspu/mpc/securenn/conversion.cc b/libspu/mpc/securenn/conversion.cc index 1fe99fe4..8a05d8a4 100644 --- a/libspu/mpc/securenn/conversion.cc +++ b/libspu/mpc/securenn/conversion.cc @@ -112,8 +112,12 @@ NdArrayRef B2A_Randbit::proc(KernelEvalContext* ctx, auto randbits = prg_state->genPriv(field, {numel * static_cast(nbits)}); // reconstruct ranbits - if (rank == 0) comm->sendAsync(2, randbits, "randbits0"); - if (rank == 1) comm->sendAsync(2, randbits, "randbits1"); + if (rank == 0) { + comm->sendAsync(2, randbits, "randbits0"); + } + if (rank == 1) { + comm->sendAsync(2, randbits, "randbits1"); + } if (rank == 2) { auto randbits0 = comm->recv(0, makeType(field), "randbits0"); randbits0 = randbits0.reshape(randbits.shape()); @@ -132,7 +136,7 @@ NdArrayRef B2A_Randbit::proc(KernelEvalContext* ctx, auto res = NdArrayRef(makeType(field), x.shape()); - DISPATCH_ALL_FIELDS(field, kBindName, [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using U = ring2k_t; NdArrayView _randbits(randbits); diff --git a/libspu/mpc/semi2k/arithmetic.cc b/libspu/mpc/semi2k/arithmetic.cc index 90591da6..017c4640 100644 --- a/libspu/mpc/semi2k/arithmetic.cc +++ b/libspu/mpc/semi2k/arithmetic.cc @@ -38,7 +38,7 @@ NdArrayRef RandA::proc(KernelEvalContext* ctx, const Shape& shape) const { // - https://eprint.iacr.org/2019/599.pdf // It's safer to keep the number within [-2**(k-2), 2**(k-2)) for comparison // operations. 
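// Illustrative sketch: the arithmetic identity that randbit-style B2A
// conversions (such as B2A_Randbit touched above) rely on, shown in the
// clear. This is standard MPC folklore rather than a description of the exact
// protocol flow: for a bit x and a random bit r, with c = x XOR r made
// public, x = c + r - 2*c*r over the integers, so arithmetic shares of r can
// be turned into arithmetic shares of x by local affine post-processing.
#include <cstdio>

int main() {
  for (int x = 0; x <= 1; ++x) {
    for (int r = 0; r <= 1; ++r) {
      int c = x ^ r;
      int rebuilt = c + r - 2 * c * r;
      std::printf("x=%d r=%d c=%d -> %d\n", x, r, c, rebuilt);  // rebuilt always equals x
    }
  }
}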
- return ring_rshift(prg_state->genPriv(field, shape), 2) + return ring_rshift(prg_state->genPriv(field, shape), {2}) .as(makeType(field)); } @@ -73,7 +73,7 @@ NdArrayRef A2V::proc(KernelEvalContext* ctx, const NdArrayRef& in, auto numel = in.numel(); - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { std::vector share(numel); NdArrayView _in(in); pforeach(0, numel, [&](int64_t idx) { share[idx] = _in[idx]; }); @@ -364,10 +364,7 @@ NdArrayRef MatMulAA::proc(KernelEvalContext* ctx, const NdArrayRef& x, } NdArrayRef LShiftA::proc(KernelEvalContext*, const NdArrayRef& in, - size_t bits) const { - const auto field = in.eltype().as()->field(); - bits %= SizeOf(field) * 8; - + const Sizes& bits) const { return ring_lshift(in, bits).as(in.eltype()); } @@ -381,7 +378,7 @@ NdArrayRef TruncA::proc(KernelEvalContext* ctx, const NdArrayRef& x, if (comm->getWorldSize() == 2) { // SecureML, local truncation. // Ref: Theorem 1. https://eprint.iacr.org/2017/396.pdf - return ring_arshift(x, bits).as(x.eltype()); + return ring_arshift(x, {static_cast(bits)}).as(x.eltype()); } else { // ABY3, truncation pair method. // Ref: Section 5.1.2 https://eprint.iacr.org/2018/403.pdf @@ -399,7 +396,7 @@ NdArrayRef TruncA::proc(KernelEvalContext* ctx, const NdArrayRef& x, auto x_r = comm->allReduce(ReduceOp::ADD, ring_sub(x, r), kBindName); auto res = rb; if (comm->getRank() == 0) { - ring_add_(res, ring_arshift(x_r, bits)); + ring_add_(res, ring_arshift(x_r, {static_cast(bits)})); } // res = [x-r] + [r], x which [*] is truncation operation. @@ -418,7 +415,7 @@ NdArrayRef TruncAPr::proc(KernelEvalContext* ctx, const NdArrayRef& in, NdArrayRef out(in.eltype(), in.shape()); - DISPATCH_ALL_FIELDS(field, "semi2k.truncpr", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using U = ring2k_t; auto [r, rc, rb] = beaver->TruncPr(field, numel, bits); SPU_ENFORCE(static_cast(r.size()) == numel * SizeOf(field)); diff --git a/libspu/mpc/semi2k/arithmetic.h b/libspu/mpc/semi2k/arithmetic.h index 221a9f44..8b9510f8 100644 --- a/libspu/mpc/semi2k/arithmetic.h +++ b/libspu/mpc/semi2k/arithmetic.h @@ -206,7 +206,7 @@ class LShiftA : public ShiftKernel { ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const override; + const Sizes& bits) const override; }; class TruncA : public TruncAKernel { diff --git a/libspu/mpc/semi2k/beaver/beaver_impl/beaver_test.cc b/libspu/mpc/semi2k/beaver/beaver_impl/beaver_test.cc index 636766ec..58dcdc80 100644 --- a/libspu/mpc/semi2k/beaver/beaver_impl/beaver_test.cc +++ b/libspu/mpc/semi2k/beaver/beaver_impl/beaver_test.cc @@ -14,7 +14,6 @@ #include -#include "fmt/format.h" #include "gtest/gtest.h" #include "yacl/crypto/key_utils.h" #include "yacl/link/algorithm/barrier.h" @@ -187,7 +186,7 @@ TEST_P(BeaverTest, Mul_large) { auto open = open_buffer(triples, kField, std::vector(3, {kNumel}), kWorldSize, true); - DISPATCH_ALL_FIELDS(kField, "_", [&]() { + DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _a(open[0]); NdArrayView _b(open[1]); NdArrayView _c(open[2]); @@ -214,7 +213,7 @@ TEST_P(BeaverTest, Mul_large) { auto open = open_buffer(triples, kField, std::vector(3, {kNumel}), kWorldSize, true); - DISPATCH_ALL_FIELDS(kField, "_", [&]() { + DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _a(open[0]); NdArrayView _cache_a(x_cache); NdArrayView _b(open[1]); @@ -240,7 +239,7 @@ TEST_P(BeaverTest, Mul_large) { auto open = open_buffer(triples, kField, std::vector(3, {kNumel}), kWorldSize, 
true); - DISPATCH_ALL_FIELDS(kField, "_", [&]() { + DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _a(open[0]); NdArrayView _b(open[1]); NdArrayView _cache_b(y_cache); @@ -267,7 +266,7 @@ TEST_P(BeaverTest, Mul_large) { auto open = open_buffer(triples, kField, std::vector(3, {kNumel}), kWorldSize, true); - DISPATCH_ALL_FIELDS(kField, "_", [&]() { + DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _a(open[0]); NdArrayView _b(open[1]); NdArrayView _cache_a(x_cache); @@ -297,7 +296,7 @@ TEST_P(BeaverTest, Mul_large) { auto open = open_buffer(triples, kField, std::vector(3, {kNumel}), kWorldSize, true); - DISPATCH_ALL_FIELDS(kField, "_", [&]() { + DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _a(open[0]); NdArrayView _b(open[1]); NdArrayView _cache_a(x_cache); @@ -342,7 +341,7 @@ TEST_P(BeaverTest, Mul) { auto open = open_buffer(triples, kField, std::vector(3, {kNumel}), kWorldSize, true); - DISPATCH_ALL_FIELDS(kField, "_", [&]() { + DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _a(open[0]); NdArrayView _b(open[1]); NdArrayView _c(open[2]); @@ -369,7 +368,7 @@ TEST_P(BeaverTest, Mul) { auto open = open_buffer(triples, kField, std::vector(3, {kNumel}), kWorldSize, true); - DISPATCH_ALL_FIELDS(kField, "_", [&]() { + DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _a(open[0]); NdArrayView _cache_a(x_cache); NdArrayView _b(open[1]); @@ -395,7 +394,7 @@ TEST_P(BeaverTest, Mul) { auto open = open_buffer(triples, kField, std::vector(3, {kNumel}), kWorldSize, true); - DISPATCH_ALL_FIELDS(kField, "_", [&]() { + DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _a(open[0]); NdArrayView _b(open[1]); NdArrayView _cache_b(y_cache); @@ -422,7 +421,7 @@ TEST_P(BeaverTest, Mul) { auto open = open_buffer(triples, kField, std::vector(3, {kNumel}), kWorldSize, true); - DISPATCH_ALL_FIELDS(kField, "_", [&]() { + DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _a(open[0]); NdArrayView _b(open[1]); NdArrayView _cache_a(x_cache); @@ -452,7 +451,7 @@ TEST_P(BeaverTest, Mul) { auto open = open_buffer(triples, kField, std::vector(3, {kNumel}), kWorldSize, true); - DISPATCH_ALL_FIELDS(kField, "_", [&]() { + DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _a(open[0]); NdArrayView _b(open[1]); NdArrayView _cache_a(x_cache); @@ -539,7 +538,7 @@ TEST_P(BeaverTest, Dot) { kWorldSize, true); auto res = ring_mmul(open[0], open[1]); - DISPATCH_ALL_FIELDS(kField, "_", [&]() { + DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _r(res); NdArrayView _c(open[2]); for (auto idx = 0; idx < res.numel(); idx++) { @@ -565,7 +564,7 @@ TEST_P(BeaverTest, Dot) { kWorldSize, true); auto res = ring_mmul(x_cache, open[1]); - DISPATCH_ALL_FIELDS(kField, "_", [&]() { + DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _a(open[0]); NdArrayView _cache_a(x_cache); NdArrayView _r(res); @@ -592,7 +591,7 @@ TEST_P(BeaverTest, Dot) { kWorldSize, true); auto res = ring_mmul(open[0], y_cache); - DISPATCH_ALL_FIELDS(kField, "_", [&]() { + DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _b(open[1]); NdArrayView _cache_b(y_cache); NdArrayView _r(res); @@ -620,7 +619,7 @@ TEST_P(BeaverTest, Dot) { kWorldSize, true); auto res = ring_mmul(x_cache, y_cache); - DISPATCH_ALL_FIELDS(kField, "_", [&]() { + DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _a(open[0]); NdArrayView _cache_a(x_cache); NdArrayView _b(open[1]); @@ -651,7 +650,7 @@ TEST_P(BeaverTest, Dot) { kWorldSize, true); auto res = ring_mmul(x_cache, y_cache); - DISPATCH_ALL_FIELDS(kField, "_", [&]() { + DISPATCH_ALL_FIELDS(kField, [&]() { auto transpose_a = open[0].transpose(); NdArrayView 
_a(transpose_a); NdArrayView _cache_a(y_cache); @@ -684,7 +683,7 @@ TEST_P(BeaverTest, Dot) { auto open = open_buffer(triples, kField, std::vector(3, {M * K}), kWorldSize, true); - DISPATCH_ALL_FIELDS(kField, "_", [&]() { + DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _a(open[0]); NdArrayView _cache_a(x_cache); NdArrayView _b(open[1]); @@ -725,7 +724,7 @@ TEST_P(BeaverTest, Dot_large) { open_buffer(triples, kField, {{M, K}, {K, N}, {M, N}}, kWorldSize, true); auto res = ring_mmul(open[0], open[1]); - DISPATCH_ALL_FIELDS(kField, "_", [&]() { + DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _r(res); NdArrayView _c(open[2]); for (auto idx = 0; idx < res.numel(); idx++) { @@ -741,7 +740,7 @@ TEST_P(BeaverTest, Trunc) { const FieldType kField = std::get<2>(GetParam()); const size_t adjust_rank = std::get<4>(GetParam()); const int64_t kNumel = 7; - const size_t kBits = 5; + const int64_t kBits = 5; std::vector pairs; pairs.resize(kWorldSize); @@ -756,7 +755,7 @@ TEST_P(BeaverTest, Trunc) { EXPECT_EQ(pairs.size(), kWorldSize); auto open = open_buffer(pairs, kField, {{kNumel}, {kNumel}}, kWorldSize, true); - EXPECT_TRUE(ring_all_equal(ring_arshift(open[0], kBits), open[1], 0)); + EXPECT_TRUE(ring_all_equal(ring_arshift(open[0], {kBits}), open[1], 0)); } TEST_P(BeaverTest, TruncPr) { @@ -782,7 +781,7 @@ TEST_P(BeaverTest, TruncPr) { auto open = open_buffer(rets, kField, std::vector(3, {kNumel}), kWorldSize, true); - DISPATCH_ALL_FIELDS(kField, "semi2k.truncpr.ut", [&]() { + DISPATCH_ALL_FIELDS(kField, [&]() { using T = ring2k_t; auto sum_r_iter = open[0].begin(); auto sum_rc_iter = open[1].begin(); @@ -821,7 +820,7 @@ TEST_P(BeaverTest, Randbit) { EXPECT_EQ(shares.size(), kWorldSize); auto open = open_buffer(shares, kField, {{kNumel}}, kWorldSize, true); - DISPATCH_ALL_FIELDS(kField, "_", [&]() { + DISPATCH_ALL_FIELDS(kField, [&]() { using scalar_t = typename Ring2kTrait<_kField>::scalar_t; auto x = xt_adapt(open[0]); EXPECT_TRUE(xt::all(x <= xt::ones_like(x))); diff --git a/libspu/mpc/semi2k/beaver/beaver_impl/beaver_tfp.cc b/libspu/mpc/semi2k/beaver/beaver_impl/beaver_tfp.cc index 574d892a..d3d09401 100644 --- a/libspu/mpc/semi2k/beaver/beaver_impl/beaver_tfp.cc +++ b/libspu/mpc/semi2k/beaver/beaver_impl/beaver_tfp.cc @@ -48,7 +48,7 @@ void FillReplayDesc(Beaver::ReplayDesc* desc, FieldType field, int64_t size, } // namespace BeaverTfpUnsafe::BeaverTfpUnsafe(std::shared_ptr lctx) - : lctx_(std::move(std::move(lctx))), + : lctx_(std::move(lctx)), seed_(yacl::crypto::SecureRandSeed()), counter_(0) { auto buf = yacl::SerializeUint128(seed_); diff --git a/libspu/mpc/semi2k/beaver/beaver_impl/beaver_ttp.cc b/libspu/mpc/semi2k/beaver/beaver_impl/beaver_ttp.cc index 4f3094c3..cca2c571 100644 --- a/libspu/mpc/semi2k/beaver/beaver_impl/beaver_ttp.cc +++ b/libspu/mpc/semi2k/beaver/beaver_impl/beaver_ttp.cc @@ -152,7 +152,7 @@ std::vector RpcCall(brpc::Channel& channel, AdjustRequest req, } // namespace BeaverTtp::BeaverTtp(std::shared_ptr lctx, Options ops) - : lctx_(std::move(std::move(lctx))), + : lctx_(std::move(lctx)), seed_(yacl::crypto::SecureRandSeed()), counter_(0), options_(std::move(ops)) { diff --git a/libspu/mpc/semi2k/beaver/beaver_impl/trusted_party/trusted_party.cc b/libspu/mpc/semi2k/beaver/beaver_impl/trusted_party/trusted_party.cc index c16c6af3..80ef18e7 100644 --- a/libspu/mpc/semi2k/beaver/beaver_impl/trusted_party/trusted_party.cc +++ b/libspu/mpc/semi2k/beaver/beaver_impl/trusted_party/trusted_party.cc @@ -115,7 +115,7 @@ NdArrayRef TrustedParty::adjustTrunc(absl::Span ops, auto 
rs = reconstruct(RecOp::ADD, ops); // adjust = (rs[0] >> bits) - rs[1]; - return ring_sub(ring_arshift(rs[0], bits), rs[1]); + return ring_sub(ring_arshift(rs[0], {static_cast(bits)}), rs[1]); } std::pair TrustedParty::adjustTruncPr( @@ -127,11 +127,14 @@ std::pair TrustedParty::adjustTruncPr( auto rs = reconstruct(RecOp::ADD, ops); // adjust1 = ((rs[0] << 1) >> (bits + 1)) - rs[1]; - auto adjust1 = ring_sub(ring_rshift(ring_lshift(rs[0], 1), bits + 1), rs[1]); + auto adjust1 = ring_sub( + ring_rshift(ring_lshift(rs[0], {1}), {static_cast(bits + 1)}), + rs[1]); // adjust2 = (rs[0] >> (k - 1)) - rs[2]; const size_t k = SizeOf(ops[0].desc.field) * 8; - auto adjust2 = ring_sub(ring_rshift(rs[0], k - 1), rs[2]); + auto adjust2 = + ring_sub(ring_rshift(rs[0], {static_cast(k - 1)}), rs[2]); return {adjust1, adjust2}; } diff --git a/libspu/mpc/semi2k/beaver/beaver_impl/ttp_server/beaver_server_main.cc b/libspu/mpc/semi2k/beaver/beaver_impl/ttp_server/beaver_server_main.cc index 1680325e..0a701cf3 100644 --- a/libspu/mpc/semi2k/beaver/beaver_impl/ttp_server/beaver_server_main.cc +++ b/libspu/mpc/semi2k/beaver/beaver_impl/ttp_server/beaver_server_main.cc @@ -91,8 +91,7 @@ int main(int argc, char* argv[]) { std::string key; SPU_ENFORCE( butil::Base64Decode(ttp_server_config::FLAGS_server_private_key, &key)); - decode_private_key = - yacl::Buffer(decode_private_key.data(), decode_private_key.size()); + decode_private_key = yacl::Buffer(key.data(), key.size()); } spu::mpc::semi2k::beaver::ttp_server::ServerOptions ops{ diff --git a/libspu/mpc/semi2k/boolean.cc b/libspu/mpc/semi2k/boolean.cc index 4dae73a1..3fa9bf4a 100644 --- a/libspu/mpc/semi2k/boolean.cc +++ b/libspu/mpc/semi2k/boolean.cc @@ -30,7 +30,7 @@ namespace { size_t getNumBits(const NdArrayRef& in) { if (in.eltype().isa()) { const auto field = in.eltype().as()->field(); - return DISPATCH_ALL_FIELDS(field, "_", + return DISPATCH_ALL_FIELDS(field, [&]() { return maxBitWidth(in); }); } else if (in.eltype().isa()) { return in.eltype().as()->nbits(); @@ -119,7 +119,7 @@ NdArrayRef AndBP::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, const size_t out_nbits = std::min(getNumBits(lhs), getNumBits(rhs)); NdArrayRef out(makeType(field, out_nbits), lhs.shape()); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _lhs(lhs); NdArrayView _rhs(rhs); NdArrayView _out(out); @@ -144,12 +144,12 @@ NdArrayRef AndBB::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, // semi2k always use the same storage type. NdArrayRef out(makeType(field, out_nbits), lhs.shape()); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using T = ring2k_t; NdArrayView _lhs(lhs); NdArrayView _rhs(rhs); - DISPATCH_UINT_PT_TYPES(backtype, "_", [&]() { + DISPATCH_UINT_PT_TYPES(backtype, [&]() { using V = ScalarT; // TODO: redefine beaver interface, generate variadic beaver and bits. 
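/* [Editor's aside: minimal sketch of the dispatch pattern after this change; not part
   of the patch, and it assumes only what the diff itself shows. DISPATCH_ALL_FIELDS
   now takes the field and a callable (the old bind-name string argument is dropped)
   and binds `ring2k_t` inside the lambda; DISPATCH_UINT_PT_TYPES likewise binds
   `ScalarT` for the chosen storage type. Given `field` and `backtype` computed as in
   the kernels above:

     DISPATCH_ALL_FIELDS(field, [&]() {
       using T = ring2k_t;               // concrete ring element type for `field`
       NdArrayView<T> _lhs(lhs);
       NdArrayView<T> _rhs(rhs);
       DISPATCH_UINT_PT_TYPES(backtype, [&]() {
         using V = ScalarT;              // concrete unsigned storage type for `backtype`
         // ... kernel body specialised on both T and V ...
       });
     });
*/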
@@ -215,32 +215,31 @@ NdArrayRef XorBB::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, } NdArrayRef LShiftB::proc(KernelEvalContext*, const NdArrayRef& in, - size_t shift) const { + const Sizes& shift) const { const auto field = in.eltype().as()->field(); - shift %= SizeOf(field) * 8; - size_t out_nbits = in.eltype().as()->nbits() + shift; + size_t out_nbits = in.eltype().as()->nbits() + + *std::max_element(shift.begin(), shift.end()); out_nbits = std::clamp(out_nbits, static_cast(0), SizeOf(field) * 8); return makeBShare(ring_lshift(in, shift), field, out_nbits); } NdArrayRef RShiftB::proc(KernelEvalContext*, const NdArrayRef& in, - size_t shift) const { + const Sizes& shift) const { const auto field = in.eltype().as()->field(); - shift %= SizeOf(field) * 8; - size_t nbits = in.eltype().as()->nbits(); - size_t out_nbits = nbits - std::min(nbits, shift); - SPU_ENFORCE(nbits <= SizeOf(field) * 8); + int64_t nbits = in.eltype().as()->nbits(); + int64_t out_nbits = + nbits - std::min(nbits, *std::min_element(shift.begin(), shift.end())); + SPU_ENFORCE(nbits <= static_cast(SizeOf(field) * 8)); return makeBShare(ring_rshift(in, shift), field, out_nbits); } NdArrayRef ARShiftB::proc(KernelEvalContext*, const NdArrayRef& in, - size_t shift) const { + const Sizes& shift) const { const auto field = in.eltype().as()->field(); - shift %= SizeOf(field) * 8; // arithmetic right shift expects to work on ring, or the behaviour is // undefined. @@ -268,7 +267,7 @@ NdArrayRef BitIntlB::proc(KernelEvalContext*, const NdArrayRef& in, NdArrayRef out(in.eltype(), in.shape()); auto numel = in.numel(); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _in(in); NdArrayView _out(out); @@ -289,7 +288,7 @@ NdArrayRef BitDeintlB::proc(KernelEvalContext*, const NdArrayRef& in, NdArrayRef out(in.eltype(), in.shape()); auto numel = in.numel(); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _in(in); NdArrayView _out(out); diff --git a/libspu/mpc/semi2k/boolean.h b/libspu/mpc/semi2k/boolean.h index 8c7e6835..8ca39710 100644 --- a/libspu/mpc/semi2k/boolean.h +++ b/libspu/mpc/semi2k/boolean.h @@ -118,7 +118,7 @@ class LShiftB : public ShiftKernel { ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t shift) const override; + const Sizes& shift) const override; }; class RShiftB : public ShiftKernel { @@ -130,7 +130,7 @@ class RShiftB : public ShiftKernel { ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t shift) const override; + const Sizes& shift) const override; }; class ARShiftB : public ShiftKernel { @@ -142,7 +142,7 @@ class ARShiftB : public ShiftKernel { ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t shift) const override; + const Sizes& shift) const override; }; class BitrevB : public BitrevKernel { diff --git a/libspu/mpc/semi2k/conversion.cc b/libspu/mpc/semi2k/conversion.cc index 35d8c701..ae638ce4 100644 --- a/libspu/mpc/semi2k/conversion.cc +++ b/libspu/mpc/semi2k/conversion.cc @@ -42,6 +42,26 @@ static NdArrayRef wrap_and_bb(SPUContext* ctx, const NdArrayRef& x, return UnwrapValue(and_bb(ctx, WrapValue(x), WrapValue(y))); } +// TODO: Move to some common place +PtType getBacktype(size_t nbits) { + if (nbits <= 8) { + return PT_U8; + } + if (nbits <= 16) { + return PT_U16; + } + if (nbits <= 
32) { + return PT_U32; + } + if (nbits <= 64) { + return PT_U64; + } + if (nbits <= 128) { + return PT_U128; + } + SPU_THROW("invalid number of bits={}", nbits); +} + NdArrayRef A2B::proc(KernelEvalContext* ctx, const NdArrayRef& x) const { const auto field = x.eltype().as()->field(); auto* comm = ctx->getState(); @@ -90,6 +110,9 @@ NdArrayRef B2A::proc(KernelEvalContext* ctx, const NdArrayRef& x) const { return r_a; } +// TODO(jimi): pack {numel * nbits} to fully make use of undelying storage to +// save communications. If implemented, B2A_Disassemble kernel is also no longer +// needed NdArrayRef B2A_Randbit::proc(KernelEvalContext* ctx, const NdArrayRef& x) const { const auto field = x.eltype().as()->field(); @@ -105,13 +128,14 @@ NdArrayRef B2A_Randbit::proc(KernelEvalContext* ctx, const auto numel = x.numel(); const auto rand_numel = numel * static_cast(nbits); + const PtType backtype = getBacktype(nbits); auto randbits = beaver->RandBit(field, rand_numel); SPU_ENFORCE(static_cast(randbits.size()) == rand_numel * SizeOf(field)); auto res = NdArrayRef(makeType(field), x.shape()); - DISPATCH_ALL_FIELDS(field, kBindName, [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using U = ring2k_t; absl::Span _randbits(randbits.data(), rand_numel); @@ -119,32 +143,125 @@ NdArrayRef B2A_Randbit::proc(KernelEvalContext* ctx, // algorithm begins. // Ref: III.D @ https://eprint.iacr.org/2019/599.pdf (SPDZ-2K primitives) - std::vector x_xor_r(numel); - - pforeach(0, numel, [&](int64_t idx) { - // use _r[i*nbits, (i+1)*nbits) to construct rb[i] - U mask = 0; - for (int64_t bit = 0; bit < nbits; ++bit) { - mask += (_randbits[idx * nbits + bit] & 0x1) << bit; - } - x_xor_r[idx] = _x[idx] ^ mask; + DISPATCH_UINT_PT_TYPES(backtype, [&]() { + using V = ScalarT; + std::vector x_xor_r(numel); + + pforeach(0, numel, [&](int64_t idx) { + // use _r[i*nbits, (i+1)*nbits) to construct rb[i] + V mask = 0; + for (int64_t bit = 0; bit < nbits; ++bit) { + mask += (static_cast(_randbits[idx * nbits + bit]) & 0x1) << bit; + } + x_xor_r[idx] = _x[idx] ^ mask; + }); + + // open c = x ^ r + x_xor_r = comm->allReduce(x_xor_r, "open(x^r)"); + + NdArrayView _res(res); + pforeach(0, numel, [&](int64_t idx) { + _res[idx] = 0; + for (int64_t bit = 0; bit < nbits; bit++) { + auto c_i = static_cast(x_xor_r[idx] >> bit) & 0x1; + if (comm->getRank() == 0) { + _res[idx] += (c_i + (1 - c_i * 2) * _randbits[idx * nbits + bit]) + << bit; + } else { + _res[idx] += ((1 - c_i * 2) * _randbits[idx * nbits + bit]) << bit; + } + } + }); }); + }); + + return res; +} + +// Reference: +// III.D @ https://eprint.iacr.org/2019/599.pdf (SPDZ-2K primitives) +// +// Analysis: +// Online Latency: 1 (x_xor_r reveal) +// Communication: one element bits for one element +// Vectorization: yes +// +// HighLevel Intuition: +// Since: X = sum: Xi * 2^i +// If we have A, then we can construct A = sum: A * 2^i. +// +// The problem is that we only have B in hand. Details for how to +// construct A from B: +// - trusted third party choose a random bit r, where r == 0 or r == 1. +// - trusted third party send A to parties +// - parties compute B from A +// - parties xor_open c = Xi ^ r = open(B ^ B), Xi is still safe due +// to protection from r. +// - parties compute: = c + (1-2c)* +// A = 1 - A if c == 1, i.e. Xi != r +// A = A if c == 0, i.e. Xi == r +// i.e. 
A = c + (1-2c) * A +// +// Online Communication: +// = 1 (xor open) + +// Disassemble BShr to AShr bit-by-bit +// Input: BShr +// Return: a vector of k AShr, k is the valid bits of BShr +std::vector B2A_Disassemble::proc(KernelEvalContext* ctx, + const NdArrayRef& x) const { + const auto field = x.eltype().as()->field(); + auto* comm = ctx->getState(); + auto* beaver = ctx->getState()->beaver(); + + const int64_t nbits = x.eltype().as()->nbits(); + SPU_ENFORCE((size_t)nbits > 0 && (size_t)nbits <= SizeOf(field) * 8, + "invalid nbits={}", nbits); - // open c = x ^ r - x_xor_r = comm->allReduce(x_xor_r, "open(x^r)"); - - NdArrayView _res(res); - pforeach(0, numel, [&](int64_t idx) { - _res[idx] = 0; - for (int64_t bit = 0; bit < nbits; bit++) { - auto c_i = (x_xor_r[idx] >> bit) & 0x1; - if (comm->getRank() == 0) { - _res[idx] += (c_i + (1 - c_i * 2) * _randbits[idx * nbits + bit]) - << bit; - } else { - _res[idx] += ((1 - c_i * 2) * _randbits[idx * nbits + bit]) << bit; + const auto numel = x.numel(); + const auto rand_numel = numel * static_cast(nbits); + const PtType backtype = getBacktype(nbits); + + auto randbits = beaver->RandBit(field, rand_numel); + + std::vector res; + res.reserve(nbits); + for (int64_t idx = 0; idx < nbits; ++idx) { + res.emplace_back(makeType(field), x.shape()); + } + DISPATCH_ALL_FIELDS(field, [&]() { + using U = ring2k_t; + + absl::Span _randbits(randbits.data(), rand_numel); + NdArrayView _x(x); + + DISPATCH_UINT_PT_TYPES(backtype, [&]() { + using V = ScalarT; + std::vector x_xor_r(numel); + + pforeach(0, numel, [&](int64_t idx) { + // use _r[i*nbits, (i+1)*nbits) to construct rb[i] + V mask = 0; + for (int64_t bit = 0; bit < nbits; ++bit) { + mask += (static_cast(_randbits[idx * nbits + bit]) & 0x1) << bit; } - } + x_xor_r[idx] = _x[idx] ^ mask; + }); + + // open c = x ^ r + x_xor_r = comm->allReduce(x_xor_r, "open(x^r)"); + + pforeach(0, numel, [&](int64_t idx) { + pforeach(0, nbits, [&](int64_t bit) { + NdArrayView _res(res[bit]); + auto c_i = static_cast(x_xor_r[idx] >> bit) & 0x1; + if (comm->getRank() == 0) { + _res[idx] = (c_i + (1 - c_i * 2) * _randbits[idx * nbits + bit]); + } else { + _res[idx] = ((1 - c_i * 2) * _randbits[idx * nbits + bit]); + } + }); + }); }); }); @@ -188,7 +305,9 @@ NdArrayRef MsbA2B::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { // Compute the k'th bit. 
// (m^n)[k] ^ carry - auto msb = xor_bb(sctx, rshift_b(sctx, xor_bb(sctx, m, n), k), carry); + auto msb = xor_bb( + sctx, rshift_b(sctx, xor_bb(sctx, m, n), {static_cast(k)}), + carry); return UnwrapValue(msb); } @@ -210,7 +329,7 @@ NdArrayRef eqz(KernelEvalContext* ctx, const NdArrayRef& in) { // beaver samples r and deals [r]a and [r]b // receal c = a+r // check a == 0 <=> c == r - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using el_t = ring2k_t; auto [ra_buf, rb_buf] = beaver->Eqz(field, numel); @@ -238,11 +357,11 @@ NdArrayRef eqz(KernelEvalContext* ctx, const NdArrayRef& in) { // TODO: fix AND triple // in beaver->AND(field, shape), min FM32, need min 1byte to reduce comm NdArrayRef round_out = rb.as(makeType(field)); - size_t cur_bits = round_out.eltype().as()->nbits(); + int64_t cur_bits = round_out.eltype().as()->nbits(); while (cur_bits != 1) { cur_bits /= 2; - round_out = - wrap_and_bb(ctx->sctx(), round_out, ring_rshift(round_out, cur_bits)); + round_out = wrap_and_bb(ctx->sctx(), round_out, + ring_rshift(round_out, {cur_bits})); } // 1 bit info in lsb diff --git a/libspu/mpc/semi2k/conversion.h b/libspu/mpc/semi2k/conversion.h index bc249998..891a23cd 100644 --- a/libspu/mpc/semi2k/conversion.h +++ b/libspu/mpc/semi2k/conversion.h @@ -73,6 +73,21 @@ class B2A_Randbit : public UnaryKernel { NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& x) const override; }; +class B2A_Disassemble : public DisassembleKernel { + public: + static constexpr char kBindName[] = "b2a_disassemble"; + + ce::CExpr latency() const override { return ce::Const(1); } + + ce::CExpr comm() const override { + return ce::K() * (ce::N() - 1) // Open bit masked value + ; + } + + std::vector proc(KernelEvalContext* ctx, + const NdArrayRef& x) const override; +}; + // Note: current only for 2PC. 
class MsbA2B : public UnaryKernel { public: diff --git a/libspu/mpc/semi2k/permute.cc b/libspu/mpc/semi2k/permute.cc index 28278f8c..21787a42 100644 --- a/libspu/mpc/semi2k/permute.cc +++ b/libspu/mpc/semi2k/permute.cc @@ -82,7 +82,7 @@ NdArrayRef RandPermM::proc(KernelEvalContext* ctx, const Shape& shape) const { const auto perm_vector = genRandomPerm(out.numel(), _seed[0]); const auto field = out.eltype().as()->field(); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _out(out); pforeach(0, out.numel(), [&](int64_t idx) { _out[idx] = ring2k_t(perm_vector[idx]); }); diff --git a/libspu/mpc/semi2k/protocol.cc b/libspu/mpc/semi2k/protocol.cc index a7d3b7e1..3acd8344 100644 --- a/libspu/mpc/semi2k/protocol.cc +++ b/libspu/mpc/semi2k/protocol.cc @@ -50,21 +50,23 @@ void regSemi2kProtocol(SPUContext* ctx, ctx->prot()->addState(ctx->config(), lctx); ctx->prot() ->regKernel< - semi2k::P2A, semi2k::A2P, semi2k::A2V, semi2k::V2A, // - semi2k::NotA, // - semi2k::AddAP, semi2k::AddAA, // - semi2k::MulAP, semi2k::MulAA, semi2k::SquareA, // - semi2k::MatMulAP, semi2k::MatMulAA, // - semi2k::LShiftA, semi2k::LShiftB, semi2k::RShiftB, - semi2k::ARShiftB, // - semi2k::CommonTypeB, semi2k::CommonTypeV, semi2k::CastTypeB, // - semi2k::B2P, semi2k::P2B, semi2k::A2B, semi2k::B2A_Randbit, // - semi2k::AndBP, semi2k::AndBB, semi2k::XorBP, semi2k::XorBB, - semi2k::BitrevB, // - semi2k::BitIntlB, semi2k::BitDeintlB, // - semi2k::RandA, semi2k::RandPermM, semi2k::PermAM, semi2k::PermAP, - semi2k::InvPermAM, semi2k::InvPermAP, semi2k::InvPermAV, // - semi2k::EqualAA, semi2k::EqualAP, semi2k::BeaverCacheKernel>(); + semi2k::P2A, semi2k::A2P, semi2k::A2V, semi2k::V2A, // + semi2k::NotA, // + semi2k::AddAP, semi2k::AddAA, // + semi2k::MulAP, semi2k::MulAA, semi2k::SquareA, // + semi2k::MatMulAP, semi2k::MatMulAA, // + semi2k::LShiftA, semi2k::LShiftB, semi2k::RShiftB, // + semi2k::ARShiftB, // + semi2k::CommonTypeB, semi2k::CommonTypeV, semi2k::CastTypeB, // + semi2k::B2P, semi2k::P2B, // + semi2k::A2B, semi2k::B2A_Randbit, semi2k::B2A_Disassemble, // + semi2k::AndBP, semi2k::AndBB, semi2k::XorBP, semi2k::XorBB, // + semi2k::BitrevB, // + semi2k::BitIntlB, semi2k::BitDeintlB, // + semi2k::RandA, semi2k::RandPermM, semi2k::PermAM, semi2k::PermAP, // + semi2k::InvPermAM, semi2k::InvPermAP, semi2k::InvPermAV, // + semi2k::EqualAA, semi2k::EqualAP, // + semi2k::BeaverCacheKernel>(); if (ctx->config().trunc_allow_msb_error()) { ctx->prot()->regKernel(); diff --git a/libspu/mpc/spdz2k/abprotocol_spdz2k_test.cc b/libspu/mpc/spdz2k/abprotocol_spdz2k_test.cc index b49e34c2..8bed069f 100644 --- a/libspu/mpc/spdz2k/abprotocol_spdz2k_test.cc +++ b/libspu/mpc/spdz2k/abprotocol_spdz2k_test.cc @@ -242,7 +242,7 @@ TEST_P(ConversionTest, BitLT) { auto re = b2p(obj.get(), tmp); const auto field = p0.storage_type().as()->field(); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using U = std::make_unsigned::type; size_t numel = kShape.numel(); @@ -291,7 +291,7 @@ TEST_P(ConversionTest, BitLE) { auto re = b2p(obj.get(), tmp); const auto field = p0.storage_type().as()->field(); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using U = std::make_unsigned::type; size_t numel = kShape.numel(); NdArrayView p0_data(p0.data()); diff --git a/libspu/mpc/spdz2k/arithmetic.cc b/libspu/mpc/spdz2k/arithmetic.cc index 6e902a61..11eeb186 100644 --- a/libspu/mpc/spdz2k/arithmetic.cc +++ b/libspu/mpc/spdz2k/arithmetic.cc @@ -47,9 +47,9 @@ NdArrayRef 
CastRing(const NdArrayRef& in, FieldType out_field) { const auto in_field = in_ty->field(); auto out = ring_zeros(out_field, in.shape()); - return DISPATCH_ALL_FIELDS(in_field, "_", [&]() { + return DISPATCH_ALL_FIELDS(in_field, [&]() { NdArrayView _in(in); - return DISPATCH_ALL_FIELDS(out_field, "_", [&]() { + return DISPATCH_ALL_FIELDS(out_field, [&]() { NdArrayView _out(out); pforeach(0, in.numel(), [&](int64_t idx) { _out[idx] = static_cast(_in[idx]); @@ -106,7 +106,7 @@ NdArrayRef RandA::proc(KernelEvalContext* ctx, const Shape& shape) const { // - https://eprint.iacr.org/2019/599.pdf // It's safer to keep the number within [-2**(k-2), 2**(k-2)) for comparison // operations. - auto x = ring_rshift(prg_state->genPriv(field, shape), 2) + auto x = ring_rshift(prg_state->genPriv(field, shape), {2}) .as(makeType(field)); auto x_mac = beaver->AuthArrayRef(x, field, k, s); return makeAShare(x, x_mac, field); @@ -296,7 +296,7 @@ bool SingleCheck(KernelEvalContext* ctx, const NdArrayRef& in) { auto* comm = ctx->getState(); auto* beaver = ctx->getState()->beaver(); const auto key = ctx->getState()->key(); - const size_t k = ctx->getState()->k(); + const int64_t k = ctx->getState()->k(); const size_t s = ctx->getState()->s(); // 1. Generate a random, shared value [r] @@ -305,8 +305,8 @@ bool SingleCheck(KernelEvalContext* ctx, const NdArrayRef& in) { // 2. Locally construct [y] const auto& x = getValueShare(in); const auto& x_mac = getMacShare(in); - auto y = ring_add(x, ring_lshift(r, k)); - auto y_mac = ring_add(x_mac, ring_lshift(r_mac, k)); + auto y = ring_add(x, ring_lshift(r, {k})); + auto y_mac = ring_add(x_mac, ring_lshift(r_mac, {k})); // 3. Open the value auto plain_y = comm->allReduce(ReduceOp::ADD, y, kBindName); @@ -334,7 +334,7 @@ bool SingleCheck(KernelEvalContext* ctx, const NdArrayRef& in) { static NdArrayRef wrap_lshift_a(SPUContext* ctx, const NdArrayRef& x, size_t k) { - return UnwrapValue(lshift_a(ctx, WrapValue(x), k)); + return UnwrapValue(lshift_a(ctx, WrapValue(x), {static_cast(k)})); } static NdArrayRef wrap_add_aa(SPUContext* ctx, const NdArrayRef& x, @@ -579,9 +579,8 @@ NdArrayRef MatMulAA::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, } NdArrayRef LShiftA::proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const { + const Sizes& bits) const { const auto field = in.eltype().as()->field(); - bits %= SizeOf(field) * 8; // in const auto& x = getValueShare(in); @@ -617,7 +616,9 @@ NdArrayRef TruncA::proc(KernelEvalContext* ctx, const NdArrayRef& in, beaver->BatchOpen(ring_sub(x, r), ring_sub(x_mac, r_mac), k, s); SPU_ENFORCE(beaver->BatchMacCheck(x_r, check_mac, k, s)); size_t bit_len = SizeOf(field) * 8; - auto tr_x_r = ring_arshift(ring_lshift(x_r, bit_len - k), bit_len - k + bits); + auto tr_x_r = + ring_arshift(ring_lshift(x_r, {static_cast(bit_len - k)}), + {static_cast(bit_len - k + bits)}); ring_bitmask_(tr_x_r, 0, k); // res = [x-r] + [r], which [*] is truncation operation. 
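/* [Editor's aside: worked illustration of the k-bit truncation idiom used just above;
   not part of the patch.] SPDZ2k keeps a k-bit value in the low k bits of a wider ring
   element, so a plain arithmetic shift would not see the k-bit sign. The idiom

     ring_arshift(ring_lshift(x_r, {bit_len - k}), {bit_len - k + bits})

   first moves bit (k-1) into the machine sign position, then the arithmetic right shift
   sign-extends and truncates in one go; ring_bitmask_(tr_x_r, 0, k) drops the extension
   again. Tiny example with bit_len = 8, k = 4, bits = 1:

     x_r              = 0b1010        // -6 as a 4-bit two's-complement value
     x_r << 4         = 0b10100000    // -96 as int8: the 4-bit sign is now bit 7
     (x_r << 4) >> 5  = 0b11111101    // arithmetic shift: -3
     low 4 bits       = 0b1101        // -3 in 4 bits == (-6) >> 1, as expected
*/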
diff --git a/libspu/mpc/spdz2k/arithmetic.h b/libspu/mpc/spdz2k/arithmetic.h index deb782d5..7edfe33b 100644 --- a/libspu/mpc/spdz2k/arithmetic.h +++ b/libspu/mpc/spdz2k/arithmetic.h @@ -180,7 +180,7 @@ class LShiftA : public ShiftKernel { ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const override; + const Sizes& bits) const override; }; // Refer to: diff --git a/libspu/mpc/spdz2k/beaver/BUILD.bazel b/libspu/mpc/spdz2k/beaver/BUILD.bazel index 5ea2a3f6..b7f4d80e 100644 --- a/libspu/mpc/spdz2k/beaver/BUILD.bazel +++ b/libspu/mpc/spdz2k/beaver/BUILD.bazel @@ -80,7 +80,7 @@ spu_cc_library( "//libspu/mpc/spdz2k/ot:tiny_ot", "//libspu/mpc/utils:ring_ops", "@yacl//yacl/crypto/tools:prg", - "@yacl//yacl/kernel/algorithms:ot_store", + "@yacl//yacl/kernel/type:ot_store", "@yacl//yacl/link", "@yacl//yacl/utils:matrix_utils", "@yacl//yacl/utils:serialize", diff --git a/libspu/mpc/spdz2k/beaver/beaver_test.cc b/libspu/mpc/spdz2k/beaver/beaver_test.cc index 0395ee21..c1508522 100644 --- a/libspu/mpc/spdz2k/beaver/beaver_test.cc +++ b/libspu/mpc/spdz2k/beaver/beaver_test.cc @@ -126,7 +126,7 @@ TEST_P(BeaverTest, AuthAnd) { EXPECT_TRUE(ring_all_equal(ring_mul(sum_c, sum_key), sum_c_mac)) << sum_c << sum_key << sum_c_mac; - DISPATCH_ALL_FIELDS(kField, "_", [&]() { + DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _valid_a(valid_a); NdArrayView _valid_b(valid_b); NdArrayView _valid_c(valid_c); @@ -314,7 +314,8 @@ TEST_P(BeaverTest, AuthTrunc) { const size_t bit_len = SizeOf(kField) * 8; auto trunc_sum_a = - ring_arshift(ring_lshift(sum_a, bit_len - k), bit_len - k + kBits); + ring_arshift(ring_lshift(sum_a, {static_cast(bit_len - k)}), + {static_cast(bit_len - k + kBits)}); ring_bitmask_(trunc_sum_a, 0, k); EXPECT_TRUE(ring_all_equal(trunc_sum_a, ring_bitmask(sum_b, 0, k))) @@ -386,7 +387,7 @@ TEST_P(BeaverTest, AuthDot) { << sum_c << sum_key << sum_c_mac; auto res = ring_mmul(sum_a, sum_b); - DISPATCH_ALL_FIELDS(kField, "_", [&]() { + DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _sum_a(sum_a); NdArrayView _sum_c(sum_c); NdArrayView _res(res); diff --git a/libspu/mpc/spdz2k/beaver/beaver_tfp.cc b/libspu/mpc/spdz2k/beaver/beaver_tfp.cc index 3c7a5ca7..3ab9fd0e 100644 --- a/libspu/mpc/spdz2k/beaver/beaver_tfp.cc +++ b/libspu/mpc/spdz2k/beaver/beaver_tfp.cc @@ -49,7 +49,7 @@ uint128_t BeaverTfpUnsafe::InitSpdzKey(FieldType field, size_t s) { const int64_t size = 1; auto a = prgCreateArray(field, {size}, seed_, &counter_, &desc); - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _a(a); if (comm_->getRank() == 0) { auto t = tp_.adjustSpdzKey(desc); @@ -260,8 +260,10 @@ std::pair BeaverTfpUnsafe::BatchOpen( // Open the low k_bits only // value = value + r_val * 2^k // mac = mac + r_mac * 2^k - auto masked_val = ring_add(value, ring_lshift(r_val, k)); - auto masked_mac = ring_add(mac, ring_lshift(r_mac, k)); + auto masked_val = + ring_add(value, ring_lshift(r_val, {static_cast(k)})); + auto masked_mac = + ring_add(mac, ring_lshift(r_mac, {static_cast(k)})); auto open_val = comm_->allReduce(ReduceOp::ADD, masked_val, kBindName); return {open_val, masked_mac}; diff --git a/libspu/mpc/spdz2k/beaver/beaver_tinyot.cc b/libspu/mpc/spdz2k/beaver/beaver_tinyot.cc index f0481944..f8d090fc 100644 --- a/libspu/mpc/spdz2k/beaver/beaver_tinyot.cc +++ b/libspu/mpc/spdz2k/beaver/beaver_tinyot.cc @@ -23,7 +23,7 @@ #include "yacl/crypto/rand/rand.h" #include "yacl/crypto/tools/prg.h" 
#include "yacl/kernel/algorithms/base_ot.h" -#include "yacl/kernel/algorithms/ot_store.h" +#include "yacl/kernel/type/ot_store.h" #include "yacl/utils/serialize.h" #include "libspu/mpc/common/prg_tensor.h" @@ -68,7 +68,7 @@ NdArrayRef ring_sqrt2k(const NdArrayRef& x, size_t bits = 0) { } auto ret = ring_zeros(field, x.shape()); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using U = std::make_unsigned::type; NdArrayView _ret(ret); NdArrayView _x(x); @@ -101,7 +101,7 @@ NdArrayRef ring_inv2k(const NdArrayRef& x, size_t bits = 0) { } auto ret = ring_zeros(field, x.shape()); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using U = std::make_unsigned::type; NdArrayView _ret(ret); NdArrayView _x(x); @@ -118,7 +118,7 @@ std::vector ring_cast_vector_boolean(const NdArrayRef& x) { const auto field = x.eltype().as()->field(); std::vector res(x.numel()); - DISPATCH_ALL_FIELDS(field, "RingOps", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _x(x); yacl::parallel_for(0, x.numel(), 4096, [&](size_t start, size_t end) { for (size_t i = start; i < end; i++) { @@ -192,7 +192,7 @@ uint128_t BeaverTinyOt::InitSpdzKey(FieldType, size_t s) { // - https://eprint.iacr.org/2018/482.pdf NdArrayRef BeaverTinyOt::AuthArrayRef(const NdArrayRef& x, FieldType field, size_t k, size_t s) { - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using T = ring2k_t; // 1. l_ = max(l, r + s, 2s) @@ -346,7 +346,7 @@ BeaverTinyOt::Triple_Pair BeaverTinyOt::AuthAnd(FieldType field, // Generate authorize bits in the form of B-Share NdArrayRef spdz_choices(makeType(field), {tinyot_num * 3 + sigma}); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using U = std::make_unsigned::type; auto _size = auth_abcr.choices.size(); NdArrayView _spdz_choices(spdz_choices); @@ -392,7 +392,7 @@ BeaverTinyOt::Triple_Pair BeaverTinyOt::AuthAnd(FieldType field, auto seed = GenSharedSeed(comm_); auto prg = yacl::crypto::Prg(seed); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using U = std::make_unsigned::type; NdArrayView _check_spdz_bit(check_spdz_bit); NdArrayView _check_spdz_mac(check_spdz_mac); @@ -546,7 +546,7 @@ BeaverTinyOt::Pair_Pair BeaverTinyOt::AuthTrunc(FieldType field, NdArrayRef tr_val(b_val.eltype(), shape); NdArrayRef tr_mac(b_val.eltype(), shape); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using PShrT = ring2k_t; NdArrayView _val(b_val); NdArrayView _mac(b_mac); @@ -642,7 +642,7 @@ BeaverTinyOt::Pair BeaverTinyOt::AuthRandBit(FieldType field, SPU_ENFORCE(ring_all_equal(ring_bitmask(square, 0, 1), ones)); auto root = ring_sqrt2k(square, k + 2); auto root_inv = ring_inv2k(root, k + 2); - auto root_inv_div2 = ring_rshift(root_inv, 1); + auto root_inv_div2 = ring_rshift(root_inv, {1}); auto d = ring_mul(root_inv_div2, y); auto d_mac = ring_mul(root_inv_div2, y_mac); @@ -751,17 +751,19 @@ std::pair BeaverTinyOt::BatchOpen( // Open the low k_bits only // value = value + r * 2^k // mac = mac + r_mac * 2^k - auto masked_val = ring_add(value, ring_lshift(r_val, k)); - auto masked_mac = ring_add(mac, ring_lshift(r_mac, k)); + auto masked_val = + ring_add(value, ring_lshift(r_val, {static_cast(k)})); + auto masked_mac = + ring_add(mac, ring_lshift(r_mac, {static_cast(k)})); - // Because we would use Maccheck to comfirm the open value. + // Because we would use Maccheck to confirm the open value. // Thus, we don't need commit them. 
auto open_val = comm_->allReduce(ReduceOp::ADD, masked_val, kBindName); return {open_val, masked_mac}; } void BeaverTinyOt::rotSend(FieldType field, NdArrayRef* q0, NdArrayRef* q1) { - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using T = ring2k_t; SPDLOG_DEBUG("rotSend start with numel {}", q0->numel()); @@ -784,7 +786,7 @@ void BeaverTinyOt::rotSend(FieldType field, NdArrayRef* q0, NdArrayRef* q1) { // todo: use dynamic_bitset instead of ArrayRef for `a` to improve performance void BeaverTinyOt::rotRecv(FieldType field, const NdArrayRef& a, NdArrayRef* s) { - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using T = ring2k_t; SPDLOG_DEBUG("rotRecv start with numel {}", a.numel()); @@ -814,7 +816,7 @@ void BeaverTinyOt::rotRecv(FieldType field, const NdArrayRef& a, // SPDZ2k: Efficient MPC mod 2k for Dishonest Majority // - https://eprint.iacr.org/2018/482.pdf NdArrayRef BeaverTinyOt::voleSend(FieldType field, const NdArrayRef& x) { - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using T = ring2k_t; SPU_ENFORCE(spdz2k_ot_primitives_ != nullptr); @@ -831,7 +833,7 @@ NdArrayRef BeaverTinyOt::voleSend(FieldType field, const NdArrayRef& x) { } NdArrayRef BeaverTinyOt::voleRecv(FieldType field, const NdArrayRef& alpha) { - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using T = ring2k_t; SPU_ENFORCE(spdz2k_ot_primitives_ != nullptr); @@ -915,7 +917,7 @@ BeaverTinyOt::Triple_Pair BeaverTinyOt::AuthMul(FieldType field, size_t s) { auto _size = shape.numel(); - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using T = ring2k_t; SPDLOG_DEBUG("AuthMul start..."); diff --git a/libspu/mpc/spdz2k/beaver/beaver_tinyot.h b/libspu/mpc/spdz2k/beaver/beaver_tinyot.h index b0d35372..6f8ec9ad 100644 --- a/libspu/mpc/spdz2k/beaver/beaver_tinyot.h +++ b/libspu/mpc/spdz2k/beaver/beaver_tinyot.h @@ -14,7 +14,7 @@ #pragma once -#include "yacl/kernel/algorithms/ot_store.h" +#include "yacl/kernel/type/ot_store.h" #include "yacl/link/context.h" #include "libspu/mpc/common/prg_state.h" diff --git a/libspu/mpc/spdz2k/beaver/trusted_party.cc b/libspu/mpc/spdz2k/beaver/trusted_party.cc index 978f1eee..287d1f76 100644 --- a/libspu/mpc/spdz2k/beaver/trusted_party.cc +++ b/libspu/mpc/spdz2k/beaver/trusted_party.cc @@ -248,9 +248,10 @@ std::vector TrustedParty::adjustAuthTrunc( ring_add_(r0[0], ring_sub(rs[0], t_rs)); // r0[1] += (rs[0] >> bits) - rs[1]; - const size_t bit_len = SizeOf(field) * 8; + const int64_t bit_len = SizeOf(field) * 8; auto tr_rs0 = - ring_arshift(ring_lshift(rs[0], bit_len - k), bit_len - k + bits); + ring_arshift(ring_lshift(rs[0], {static_cast(bit_len - k)}), + {static_cast(bit_len - k + bits)}); ring_bitmask_(tr_rs0, 0, k); ring_add_(r0[1], ring_sub(tr_rs0, rs[1])); diff --git a/libspu/mpc/spdz2k/boolean.cc b/libspu/mpc/spdz2k/boolean.cc index 226de5ad..7f230a28 100644 --- a/libspu/mpc/spdz2k/boolean.cc +++ b/libspu/mpc/spdz2k/boolean.cc @@ -38,7 +38,7 @@ NdArrayRef P2Value(FieldType out_field, const NdArrayRef& in, size_t k, size_t new_nbits = 0) { const auto* in_ty = in.eltype().as(); const auto in_field = in_ty->field(); - return DISPATCH_ALL_FIELDS(in_field, "_", [&]() { + return DISPATCH_ALL_FIELDS(in_field, [&]() { using PShrT = ring2k_t; size_t valid_nbits = k; @@ -52,7 +52,7 @@ NdArrayRef P2Value(FieldType out_field, const NdArrayRef& in, size_t k, out_shape.back() *= valid_nbits; auto out = 
ring_zeros(out_field, out_shape); - return DISPATCH_ALL_FIELDS(out_field, "_", [&]() { + return DISPATCH_ALL_FIELDS(out_field, [&]() { using BShrT = ring2k_t; NdArrayView _in(in); NdArrayView _out(out); @@ -279,10 +279,10 @@ NdArrayRef B2P::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { // 2. Maccheck SPU_ENFORCE(beaver_ptr->BatchMacCheck(pub, mac, 1, s)); - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using BShrT = ring2k_t; auto& value = pub; - return DISPATCH_ALL_FIELDS(out_field, "_", [&]() { + return DISPATCH_ALL_FIELDS(out_field, [&]() { using PShrT = ring2k_t; NdArrayRef out(makeType(out_field), in.shape()); @@ -374,7 +374,7 @@ NdArrayRef BitrevB::proc(KernelEvalContext* ctx, const NdArrayRef& in, auto ret = x.clone(); auto ret_mac = x_mac.clone(); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _ret(ret); NdArrayView _ret_mac(ret_mac); NdArrayView _x(x); @@ -515,37 +515,48 @@ NdArrayRef AndBP::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, } NdArrayRef LShiftB::proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const { + const Sizes& bits) const { + SPU_ENFORCE(bits.size() == 1, "only support splat shift bits, but got {}", + bits); + const auto field = in.eltype().as()->field(); const auto k = ctx->getState()->k(); const size_t nbits = in.eltype().as()->nbits(); - size_t res_nbits = nbits + bits; + size_t bit = bits[0]; + size_t res_nbits = nbits + bit; - if (bits >= k) { + if (bit >= k) { res_nbits = 1; } else if (res_nbits > k) { res_nbits = k; } - auto [ret, ret_mac] = LShiftBImpl(in, bits, k); + auto [ret, ret_mac] = LShiftBImpl(in, bit, k); return makeBShare(ret, ret_mac, field, res_nbits); } NdArrayRef RShiftB::proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const { + const Sizes& bits) const { + SPU_ENFORCE(bits.size() == 1, "only support splat shift bits, but got {}", + bits); + + size_t bit = bits[0]; const auto field = in.eltype().as()->field(); const auto nbits = in.eltype().as()->nbits(); - size_t new_nbis = nbits > bits ? nbits - bits : 1; - auto [ret, ret_mac] = RShiftBImpl(in, bits); + size_t new_nbis = nbits > bit ? 
nbits - bit : 1; + auto [ret, ret_mac] = RShiftBImpl(in, bit); return makeBShare(ret, ret_mac, field, new_nbis); } static NdArrayRef wrap_rshift_b(SPUContext* ctx, const NdArrayRef& x, - size_t bits) { + const Sizes& bits) { return UnwrapValue(rshift_b(ctx, WrapValue(x), bits)); } NdArrayRef ARShiftB::proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const { + const Sizes& bits) const { + SPU_ENFORCE(bits.size() == 1, "only support splat shift bits, but got {}", + bits); + const auto field = in.eltype().as()->field(); const auto k = ctx->getState()->k(); const auto nbits = in.eltype().as()->nbits(); @@ -553,7 +564,7 @@ NdArrayRef ARShiftB::proc(KernelEvalContext* ctx, const NdArrayRef& in, if (nbits != k) { return wrap_rshift_b(ctx->sctx(), in, bits); } else { - auto [ret, ret_mac] = ARShiftBImpl(in, bits, k); + auto [ret, ret_mac] = ARShiftBImpl(in, bits[0], k); return makeBShare(ret, ret_mac, field, k); } } @@ -565,7 +576,7 @@ NdArrayRef BitIntlB::proc(KernelEvalContext* ctx, const NdArrayRef& in, const auto k = ctx->getState()->k(); NdArrayRef out(in.eltype(), in.shape()); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using T = ring2k_t; if (in.eltype().isa()) { @@ -612,7 +623,7 @@ NdArrayRef BitDeintlB::proc(KernelEvalContext* ctx, const NdArrayRef& in, const auto k = ctx->getState()->k(); NdArrayRef out(in.eltype(), in.shape()); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using T = ring2k_t; if (in.eltype().isa()) { diff --git a/libspu/mpc/spdz2k/boolean.h b/libspu/mpc/spdz2k/boolean.h index 6d2d9c9f..7f42cd36 100644 --- a/libspu/mpc/spdz2k/boolean.h +++ b/libspu/mpc/spdz2k/boolean.h @@ -146,7 +146,7 @@ class LShiftB : public ShiftKernel { ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const override; + const Sizes& bits) const override; }; class RShiftB : public ShiftKernel { @@ -158,7 +158,7 @@ class RShiftB : public ShiftKernel { ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const override; + const Sizes& bits) const override; }; class ARShiftB : public ShiftKernel { @@ -170,7 +170,7 @@ class ARShiftB : public ShiftKernel { ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const override; + const Sizes& bits) const override; }; class BitIntlB : public BitSplitKernel { diff --git a/libspu/mpc/spdz2k/conversion.cc b/libspu/mpc/spdz2k/conversion.cc index 8f8ded08..122d1c26 100644 --- a/libspu/mpc/spdz2k/conversion.cc +++ b/libspu/mpc/spdz2k/conversion.cc @@ -113,7 +113,7 @@ CircuitBasicBlock MakeSPDZBasicBlock(SPUContext* ctx) { cbb._and = [=](T const& x, T const& y) -> T { COMMUTATIVE_DISPATCH(and_pp, and_bp, and_bb); }; - cbb.lshift = [=](T const& x, size_t bits) -> T { + cbb.lshift = [=](T const& x, const Sizes& bits) -> T { if (_IsP(x)) { return lshift_p(ctx, x, bits); } else if (_IsB(x)) { @@ -121,7 +121,7 @@ CircuitBasicBlock MakeSPDZBasicBlock(SPUContext* ctx) { } SPU_THROW("unsupported op x={}", x); }; - cbb.rshift = [=](T const& x, size_t bits) -> T { + cbb.rshift = [=](T const& x, const Sizes& bits) -> T { if (_IsP(x)) { return rshift_p(ctx, x, bits); } else if (_IsB(x)) { @@ -191,7 +191,7 @@ NdArrayRef Bit2A::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { ring_bitmask_(c, 0, 1); // 5. 
[x] = c + [r] - 2 * c * [r] - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _c(c); NdArrayView _r(r); NdArrayView _r_mac(r_mac); @@ -292,7 +292,7 @@ NdArrayRef B2A::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { ring_bitmask_(c, 0, 1); // 4. [x] = c + [r] - 2 * c * [r] - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { NdArrayRef out(makeType(field, true), out_shape); NdArrayRef expand_out(makeType(field, true), _in.shape()); @@ -354,8 +354,8 @@ NdArrayRef MSB::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { // then set r = \sum r_i 2^{i} for (int64_t i = 0; i < k; ++i) { auto [_r_i, _r_i_mac] = beaver->AuthRandBit(field, in.shape(), k, s); - ring_add_(_r_val, ring_lshift(_r_i, i)); - ring_add_(_r_mac, ring_lshift(_r_i_mac, i)); + ring_add_(_r_val, ring_lshift(_r_i, {i})); + ring_add_(_r_mac, ring_lshift(_r_i_mac, {i})); // record r_i & r_i_mac _r_vec.emplace_back(std::move(_r_i)); _r_mac_vec.emplace_back(std::move(_r_i_mac)); @@ -381,8 +381,8 @@ NdArrayRef MSB::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { const auto ty = makeType(field); for (int64_t i = 0; i < k - 1; ++i) { - ring_add_(_ar, ring_lshift(_r_vec[i], i)); - ring_add_(_ar_mac, ring_lshift(_r_mac_vec[i], i)); + ring_add_(_ar, ring_lshift(_r_vec[i], {i})); + ring_add_(_ar_mac, ring_lshift(_r_mac_vec[i], {i})); auto at_r_i = makeAShare(_r_vec[i], _r_mac_vec[i], field); auto bt_r_i = wrap_a2bit(ctx->sctx(), at_r_i); @@ -415,8 +415,8 @@ NdArrayRef MSB::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { // d = a - a' auto _au = getValueShare(au); auto _au_mac = getMacShare(au); - auto _aa = ring_sub(ring_lshift(_au, k - 1), _ar); - auto _aa_mac = ring_sub(ring_lshift(_au_mac, k - 1), _ar_mac); + auto _aa = ring_sub(ring_lshift(_au, {k - 1}), _ar); + auto _aa_mac = ring_sub(ring_lshift(_au_mac, {k - 1}), _ar_mac); if (comm->getRank() == 0) { ring_add_(_aa, _c_open); } @@ -426,18 +426,18 @@ NdArrayRef MSB::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { // 7. let e = d + 2^{k-1} b, then open e auto [_b, _b_mac] = beaver->AuthRandBit(field, in.shape(), k, s); - auto _e = ring_add(_d, ring_lshift(_b, k - 1)); - auto _e_mac = ring_add(_d_mac, ring_lshift(_b_mac, k - 1)); + auto _e = ring_add(_d, ring_lshift(_b, {k - 1})); + auto _e_mac = ring_add(_d_mac, ring_lshift(_b_mac, {k - 1})); auto [e_open, e_zero_mac] = beaver->BatchOpen(_e, _e_mac, k, s); SPU_ENFORCE(beaver->BatchMacCheck(e_open, e_zero_mac, k, s)); // 8. e' be the most significant bit of e - auto _ee = ring_bitmask(ring_rshift(e_open, k - 1), 0, 1); + auto _ee = ring_bitmask(ring_rshift(e_open, {k - 1}), 0, 1); // 9. 
output e_{k-1} + b - 2 e_{k-1} b - auto _ret = ring_sub(_b, ring_lshift(ring_mul(_b, _ee), 1)); - auto _ret_mac = ring_sub(_b_mac, ring_lshift(ring_mul(_b_mac, _ee), 1)); + auto _ret = ring_sub(_b, ring_lshift(ring_mul(_b, _ee), {1})); + auto _ret_mac = ring_sub(_b_mac, ring_lshift(ring_mul(_b_mac, _ee), {1})); if (comm->getRank() == 0) { ring_add_(_ret, _ee); } diff --git a/libspu/mpc/spdz2k/io.cc b/libspu/mpc/spdz2k/io.cc index 136ba2d7..dff59022 100644 --- a/libspu/mpc/spdz2k/io.cc +++ b/libspu/mpc/spdz2k/io.cc @@ -60,9 +60,9 @@ std::vector Spdz2kIo::toShares(const NdArrayRef& raw, const auto runtime_field = getRuntimeField(field); NdArrayRef x(makeType(runtime_field), raw.shape()); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _raw(raw); - DISPATCH_ALL_FIELDS(runtime_field, "_", [&]() { + DISPATCH_ALL_FIELDS(runtime_field, [&]() { NdArrayView _x(x); pforeach(0, raw.numel(), [&](int64_t idx) { _x[idx] = static_cast(_raw[idx]); @@ -113,9 +113,9 @@ NdArrayRef Spdz2kIo::fromShares(const std::vector& shares) const { { NdArrayRef x(makeType(field_), res.shape()); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _res(res); - DISPATCH_ALL_FIELDS(field_, "_", [&]() { + DISPATCH_ALL_FIELDS(field_, [&]() { NdArrayView _x(x); pforeach(0, x.numel(), [&](int64_t idx) { _x[idx] = static_cast(_res[idx]); diff --git a/libspu/mpc/spdz2k/ot/BUILD.bazel b/libspu/mpc/spdz2k/ot/BUILD.bazel index 45143795..124adaf1 100644 --- a/libspu/mpc/spdz2k/ot/BUILD.bazel +++ b/libspu/mpc/spdz2k/ot/BUILD.bazel @@ -68,7 +68,7 @@ spu_cc_library( "//libspu/mpc/utils:ring_ops", "@com_github_emptoolkit_emp_tool//:emp-tool", "@yacl//yacl/crypto/tools:prg", - "@yacl//yacl/kernel/algorithms:ot_store", + "@yacl//yacl/kernel/type:ot_store", "@yacl//yacl/link", ], ) diff --git a/libspu/mpc/spdz2k/ot/kos_ote.h b/libspu/mpc/spdz2k/ot/kos_ote.h index 58b9fc07..4ba1f1b9 100644 --- a/libspu/mpc/spdz2k/ot/kos_ote.h +++ b/libspu/mpc/spdz2k/ot/kos_ote.h @@ -15,7 +15,7 @@ #pragma once #include "absl/types/span.h" #include "yacl/base/dynamic_bitset.h" -#include "yacl/kernel/algorithms/ot_store.h" +#include "yacl/kernel/type/ot_store.h" #include "yacl/link/link.h" namespace spu::mpc::spdz2k { diff --git a/libspu/mpc/spdz2k/ot/tiny_ot.h b/libspu/mpc/spdz2k/ot/tiny_ot.h index cb099095..ebe38563 100644 --- a/libspu/mpc/spdz2k/ot/tiny_ot.h +++ b/libspu/mpc/spdz2k/ot/tiny_ot.h @@ -13,7 +13,7 @@ // limitations under the License. 
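/* [Editor's aside: the share-conversion identity used repeatedly above (B2A, Bit2A,
   MSB output); illustrative only, not part of the patch.] For a public bit c and a
   secret bit r held as an arithmetic sharing [r],

     c XOR r = c + r - 2*c*r = c + (1 - 2c) * r

   so every party computes (1 - 2c) * [r] locally and exactly one party (rank 0 in the
   kernels above) adds the public constant c, giving an arithmetic sharing of c XOR r
   with no interaction beyond opening c. Sanity check of both cases:

     c = 0:  0 + (1 - 0) * r = r       == 0 XOR r
     c = 1:  1 + (1 - 2) * r = 1 - r   == 1 XOR r
*/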
#include -#include "yacl/kernel/algorithms/ot_store.h" +#include "yacl/kernel/type/ot_store.h" #include "libspu/mpc/common/communicator.h" diff --git a/libspu/mpc/spdz2k/value.cc b/libspu/mpc/spdz2k/value.cc index a4e49fb8..029f0651 100644 --- a/libspu/mpc/spdz2k/value.cc +++ b/libspu/mpc/spdz2k/value.cc @@ -56,7 +56,7 @@ NdArrayRef makeBShare(const NdArrayRef& s1, const NdArrayRef& s2, NdArrayRef res(ty, new_shape); int64_t res_numel = res.numel(); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView> _res(res); pforeach(0, res_numel * k, [&](int64_t i) { @@ -108,7 +108,7 @@ NdArrayRef getShare(const NdArrayRef& in, int64_t share_idx) { } else { NdArrayRef ret(ty, new_shape); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { size_t numel = in.numel(); NdArrayView _ret(ret); NdArrayView> _in(in); @@ -141,7 +141,7 @@ size_t maxNumBits(const NdArrayRef& lhs, const NdArrayRef& rhs) { } const auto* rhs_ty = rhs.eltype().as(); const auto rhs_field = rhs_ty->field(); - return DISPATCH_ALL_FIELDS(rhs_field, "_", [&]() { + return DISPATCH_ALL_FIELDS(rhs_field, [&]() { using PShrT = ring2k_t; return std::max(lhs.eltype().as()->nbits(), maxBitWidth(rhs)); @@ -158,7 +158,7 @@ size_t minNumBits(const NdArrayRef& lhs, const NdArrayRef& rhs) { } const auto* rhs_ty = rhs.eltype().as(); const auto rhs_field = rhs_ty->field(); - return DISPATCH_ALL_FIELDS(rhs_field, "_", [&]() { + return DISPATCH_ALL_FIELDS(rhs_field, [&]() { using PShrT = ring2k_t; return std::min(lhs.eltype().as()->nbits(), maxBitWidth(rhs)); diff --git a/libspu/mpc/tools/benchmark.h b/libspu/mpc/tools/benchmark.h index 811dc9aa..764a50e5 100644 --- a/libspu/mpc/tools/benchmark.h +++ b/libspu/mpc/tools/benchmark.h @@ -172,10 +172,14 @@ class OpData { } for (auto& b1 : b1s) { b1 = p2b(obj_, rand_p(obj_, Shape{state.range(1)})); - b1 = lshift_b(obj_, b1, - SizeOf(static_cast(state.range(0))) * 8 - 1); - b1 = rshift_b(obj_, b1, - SizeOf(static_cast(state.range(0))) * 8 - 1); + b1 = lshift_b( + obj_, b1, + {static_cast( + SizeOf(static_cast(state.range(0))) * 8 - 1)}); + b1 = rshift_b( + obj_, b1, + {static_cast( + SizeOf(static_cast(state.range(0))) * 8 - 1)}); } } virtual ~OpData() = default; @@ -218,12 +222,12 @@ MPC_BENCH_DEFINE(BenchAndSP, OpData1S1P, and_sp, ss[0], ps[0]) MPC_BENCH_DEFINE(BenchXorSP, OpData1S1P, xor_sp, ss[0], ps[0]) MPC_BENCH_DEFINE(BenchNotS, OpData1S, not_s, ss[0]) MPC_BENCH_DEFINE(BenchNotP, OpData1P, not_p, ps[0]) -MPC_BENCH_DEFINE(BenchLShiftS, OpData1S, lshift_s, ss[0], state.range(2)) -MPC_BENCH_DEFINE(BenchLShiftP, OpData1P, lshift_p, ps[0], state.range(2)) -MPC_BENCH_DEFINE(BenchRShiftS, OpData1S, rshift_s, ss[0], state.range(2)) -MPC_BENCH_DEFINE(BenchRShiftP, OpData1P, rshift_p, ps[0], state.range(2)) -MPC_BENCH_DEFINE(BenchARShiftS, OpData1S, arshift_s, ss[0], state.range(2)) -MPC_BENCH_DEFINE(BenchARShiftP, OpData1P, arshift_p, ps[0], state.range(2)) +MPC_BENCH_DEFINE(BenchLShiftS, OpData1S, lshift_s, ss[0], {state.range(2)}) +MPC_BENCH_DEFINE(BenchLShiftP, OpData1P, lshift_p, ps[0], {state.range(2)}) +MPC_BENCH_DEFINE(BenchRShiftS, OpData1S, rshift_s, ss[0], {state.range(2)}) +MPC_BENCH_DEFINE(BenchRShiftP, OpData1P, rshift_p, ps[0], {state.range(2)}) +MPC_BENCH_DEFINE(BenchARShiftS, OpData1S, arshift_s, ss[0], {state.range(2)}) +MPC_BENCH_DEFINE(BenchARShiftP, OpData1P, arshift_p, ps[0], {state.range(2)}) MPC_BENCH_DEFINE(BenchTruncS, OpData1S, trunc_s_wrapper, ss[0], state.range(2)) MPC_BENCH_DEFINE(BenchS2P, OpData1S, s2p, ss[0]) 
MPC_BENCH_DEFINE(BenchP2S, OpData1P, p2s, ps[0]) @@ -243,7 +247,7 @@ MPC_BENCH_DEFINE(BenchMulAP, OpData1A1P, mul_ap, as[0], ps[0]) MPC_BENCH_DEFINE(BenchAddAA, OpData2A, add_aa, as[0], as[1]) MPC_BENCH_DEFINE(BenchMulAA, OpData2A, mul_aa, as[0], as[1]) MPC_BENCH_DEFINE(BenchMulA1B, OpData1A1B1, mul_a1b, as[0], b1s[0]) -MPC_BENCH_DEFINE(BenchLShiftA, OpData1A, lshift_a, as[0], state.range(2)) +MPC_BENCH_DEFINE(BenchLShiftA, OpData1A, lshift_a, as[0], {state.range(2)}) MPC_BENCH_DEFINE(BenchTruncA, OpData1A, trunc_a_wrapper, as[0], state.range(2)) // MPC_BENCH_DEFINE(BenchMMulAP, OpData1MA1MP, mmul_ap, mas[0], mps[0], // state.range(1), state.range(1), state.range(1)) @@ -257,9 +261,9 @@ MPC_BENCH_DEFINE(BenchAndBP, OpData1B1P, and_bp, bs[0], ps[0]) MPC_BENCH_DEFINE(BenchAndBB, OpData2B, and_bb, bs[0], bs[1]) MPC_BENCH_DEFINE(BenchXorBP, OpData1B1P, xor_bp, bs[0], ps[0]) MPC_BENCH_DEFINE(BenchXorBB, OpData2B, xor_bb, bs[0], bs[1]) -MPC_BENCH_DEFINE(BenchLShiftB, OpData1B, lshift_b, bs[0], state.range(2)) -MPC_BENCH_DEFINE(BenchRShiftB, OpData1B, rshift_b, bs[0], state.range(2)) -MPC_BENCH_DEFINE(BenchARShiftB, OpData1B, arshift_b, bs[0], state.range(2)) +MPC_BENCH_DEFINE(BenchLShiftB, OpData1B, lshift_b, bs[0], {state.range(2)}) +MPC_BENCH_DEFINE(BenchRShiftB, OpData1B, rshift_b, bs[0], {state.range(2)}) +MPC_BENCH_DEFINE(BenchARShiftB, OpData1B, arshift_b, bs[0], {state.range(2)}) MPC_BENCH_DEFINE(BenchBitRevB, OpData1B, bitrev_b, bs[0], 0, SizeOf(static_cast(state.range(0)))) MPC_BENCH_DEFINE(BenchBitIntlB, OpData1B, bitintl_b, bs[0], 0) diff --git a/libspu/mpc/utils/BUILD.bazel b/libspu/mpc/utils/BUILD.bazel index dea5932c..7523c4f3 100644 --- a/libspu/mpc/utils/BUILD.bazel +++ b/libspu/mpc/utils/BUILD.bazel @@ -22,6 +22,7 @@ spu_cc_library( hdrs = ["circuits.h"], deps = [ "//libspu/core:bit_utils", + "//libspu/core:shape", "//libspu/core:vectorize", ], ) diff --git a/libspu/mpc/utils/circuits.h b/libspu/mpc/utils/circuits.h index 3d032d6b..06cad9a0 100644 --- a/libspu/mpc/utils/circuits.h +++ b/libspu/mpc/utils/circuits.h @@ -22,6 +22,7 @@ #include "yacl/base/int128.h" #include "libspu/core/bit_utils.h" +#include "libspu/core/shape.h" #include "libspu/core/vectorize.h" namespace spu::mpc { @@ -35,10 +36,10 @@ struct CircuitBasicBlock { using And = std::function; // (logical) left shift - using LShift = std::function; + using LShift = std::function; // (logical) right shift - using RShift = std::function; + using RShift = std::function; // Init a constant. 
using InitLike = std::function; @@ -72,9 +73,9 @@ T kogge_stone(const CircuitBasicBlock& ctx, T const& lhs, T const& rhs, auto G = ctx._and(lhs, rhs); for (int idx = 0; idx < Log2Ceil(nbits); ++idx) { - const size_t offset = 1UL << idx; - auto G1 = ctx.lshift(G, offset); - auto P1 = ctx.lshift(P, offset); + const int64_t offset = 1L << idx; + auto G1 = ctx.lshift(G, {offset}); + auto P1 = ctx.lshift(P, {offset}); // P1 = P & P1 // G1 = G ^ (P & G1) @@ -90,7 +91,7 @@ T kogge_stone(const CircuitBasicBlock& ctx, T const& lhs, T const& rhs, } // out = (G << 1) ^ p0 - auto C = ctx.lshift(G, 1); + auto C = ctx.lshift(G, {1}); return ctx._xor(ctx._xor(lhs, rhs), C); } @@ -122,12 +123,12 @@ T sklansky(const CircuitBasicBlock& ctx, T const& lhs, T const& rhs, auto G = ctx._and(lhs, rhs); for (int idx = 0; idx < Log2Ceil(nbits); ++idx) { const auto s_mask = ctx.init_like(G, kSelMask[idx]); - auto G1 = ctx.lshift(ctx._and(G, s_mask), 1); - auto P1 = ctx.lshift(ctx._and(P, s_mask), 1); + auto G1 = ctx.lshift(ctx._and(G, s_mask), {1}); + auto P1 = ctx.lshift(ctx._and(P, s_mask), {1}); for (int j = 0; j < idx; j++) { - G1 = ctx._xor(G1, ctx.lshift(G1, 1 << j)); - P1 = ctx._xor(P1, ctx.lshift(P1, 1 << j)); + G1 = ctx._xor(G1, ctx.lshift(G1, {1 << j})); + P1 = ctx._xor(P1, ctx.lshift(P1, {1 << j})); } const auto k_mask = ctx.init_like(G, kKeepMasks[idx]); @@ -147,7 +148,7 @@ T sklansky(const CircuitBasicBlock& ctx, T const& lhs, T const& rhs, } // out = (G0 << 1) ^ p0 - auto C = ctx.lshift(G, 1); + auto C = ctx.lshift(G, {1}); return ctx._xor(ctx._xor(lhs, rhs), C); } @@ -181,21 +182,22 @@ T odd_even_split(const CircuitBasicBlock& ctx, const T& v, size_t nbits) { }}; // let r = v - T r = ctx.lshift(v, 0); + T r = ctx.lshift(v, {0}); for (int idx = 0; idx + 1 < Log2Ceil(nbits); ++idx) { // r = (r & keep) ^ ((r >> i) & move) ^ ((r & move) << i) const auto keep = ctx.init_like(r, kKeepMasks[idx]); const auto move = ctx.init_like(r, kSwapMasks[idx]); r = ctx._xor(ctx._and(r, keep), - ctx._xor(ctx._and(ctx.rshift(r, 1 << idx), move), - ctx.lshift(ctx._and(r, move), 1 << idx))); + ctx._xor(ctx._and(ctx.rshift(r, {1 << idx}), move), + ctx.lshift(ctx._and(r, move), {1 << idx}))); } if (!absl::has_single_bit(nbits)) { // handle non 2^k bits case. 
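The {offset} edits in kogge_stone and sklansky are call-site adaptations only; the routine itself remains a log-depth carry-lookahead adder built from XOR, AND and left shifts over the basic block above. A plain-integer sketch of the same recurrence (the real code is templated over secret-shared values; this is just the arithmetic it performs):

def kogge_stone_add(x: int, y: int, nbits: int = 64) -> int:
    """Add two nbits-wide integers using only XOR, AND and left shift."""
    mask = (1 << nbits) - 1
    p = (x ^ y) & mask          # propagate bits
    g = (x & y) & mask          # generate bits
    offset = 1
    while offset < nbits:       # Log2Ceil(nbits) rounds
        g = (g ^ (p & (g << offset))) & mask
        p = (p & (p << offset)) & mask
        offset <<= 1
    carry = (g << 1) & mask     # out = (G << 1) ^ lhs ^ rhs
    return (x ^ y ^ carry) & mask

assert kogge_stone_add(7, 5, nbits=8) == 12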
T mask = ctx.init_like(r, (1ULL << (nbits / 2)) - 1); - r = ctx._xor(ctx.lshift(ctx.rshift(r, 1 << Log2Floor(nbits)), nbits / 2), + r = ctx._xor(ctx.lshift(ctx.rshift(r, {1 << Log2Floor(nbits)}), + {static_cast(nbits / 2)}), ctx._and(r, mask)); } @@ -228,7 +230,7 @@ T carry_out(const CircuitBasicBlock& ctx, const T& x, const T& y, auto perm = odd_even_split(ctx, in, kk); T mask = ctx.init_like(perm, (static_cast(1) << hk) - 1); T t0 = ctx._and(perm, mask); - T t1 = ctx._and(ctx.rshift(perm, hk), mask); + T t1 = ctx._and(ctx.rshift(perm, {static_cast(hk)}), mask); ctx.set_nbits(t0, hk); ctx.set_nbits(t1, hk); return std::make_tuple(t0, t1); @@ -247,8 +249,8 @@ T carry_out(const CircuitBasicBlock& ctx, const T& x, const T& y, while (k > 1) { if (k % 2 != 0) { k += 1; - P = ctx.lshift(P, 1); - G = ctx.lshift(G, 1); + P = ctx.lshift(P, {1}); + G = ctx.lshift(G, {1}); } auto [P0, P1] = bit_split(P, k); auto [G0, G1] = bit_split(G, k); diff --git a/libspu/mpc/utils/circuits_test.cc b/libspu/mpc/utils/circuits_test.cc index 1e8e034e..c0e82433 100644 --- a/libspu/mpc/utils/circuits_test.cc +++ b/libspu/mpc/utils/circuits_test.cc @@ -26,8 +26,8 @@ CircuitBasicBlock makeScalarCBB() { CircuitBasicBlock cbb; cbb._xor = [](T const& lhs, T const& rhs) -> T { return lhs ^ rhs; }; cbb._and = [](T const& lhs, T const& rhs) -> T { return lhs & rhs; }; - cbb.lshift = [](T const& x, size_t bits) -> T { return x << bits; }; - cbb.rshift = [](T const& x, size_t bits) -> T { return x >> bits; }; + cbb.lshift = [](T const& x, const Sizes& bits) -> T { return x << bits[0]; }; + cbb.rshift = [](T const& x, const Sizes& bits) -> T { return x >> bits[0]; }; cbb.init_like = [](T const&, uint128_t init) -> T { return static_cast(init); }; @@ -53,16 +53,16 @@ CircuitBasicBlock makeVectorCBB() { std::bit_and<>()); return res; }; - cbb.lshift = [](C const& x, size_t bits) -> C { + cbb.lshift = [](C const& x, const Sizes& bits) -> C { C res; std::transform(x.begin(), x.end(), std::back_inserter(res), - [&](const auto& e) { return e << bits; }); + [&](const auto& e) { return e << bits[0]; }); return res; }; - cbb.rshift = [](C const& x, size_t bits) -> C { + cbb.rshift = [](C const& x, const Sizes& bits) -> C { C res; std::transform(x.begin(), x.end(), std::back_inserter(res), - [&](const auto& e) { return e >> bits; }); + [&](const auto& e) { return e >> bits[0]; }); return res; }; cbb.set_nbits = [](C&, size_t) -> void {}; diff --git a/libspu/mpc/utils/permute.cc b/libspu/mpc/utils/permute.cc index 1c5d2920..9ef42a61 100644 --- a/libspu/mpc/utils/permute.cc +++ b/libspu/mpc/utils/permute.cc @@ -27,7 +27,7 @@ PermVector ring2pv(const NdArrayRef& x) { x.eltype()); const auto field = x.eltype().as()->field(); PermVector pv(x.numel()); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _x(x); pforeach(0, x.numel(), [&](int64_t idx) { pv[idx] = int64_t(_x[idx]); }); }); @@ -39,7 +39,7 @@ NdArrayRef applyInvPerm(const NdArrayRef& x, absl::Span pv) { NdArrayRef y(x.eltype(), x.shape()); const auto field = x.eltype().as()->field(); - DISPATCH_ALL_FIELDS(field, kPermModule, [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _x(x); NdArrayView _y(y); for (int64_t i = 0; i < y.numel(); i++) { @@ -54,7 +54,7 @@ NdArrayRef applyPerm(const NdArrayRef& x, absl::Span pv) { NdArrayRef y(x.eltype(), x.shape()); const auto field = x.eltype().as()->field(); - DISPATCH_ALL_FIELDS(field, kPermModule, [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _x(x); NdArrayView _y(y); for (int64_t i = 
0; i < y.numel(); i++) { @@ -86,7 +86,7 @@ PermVector genPermBySort(const NdArrayRef& x) { PermVector perm(x.shape()[0]); std::iota(perm.begin(), perm.end(), 0); const auto field = x.eltype().as()->field(); - DISPATCH_ALL_FIELDS(field, kPermModule, [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using T = std::make_signed_t; NdArrayView _x(x); diff --git a/libspu/mpc/utils/ring_ops.cc b/libspu/mpc/utils/ring_ops.cc index 0d49f743..a09bc56c 100644 --- a/libspu/mpc/utils/ring_ops.cc +++ b/libspu/mpc/utils/ring_ops.cc @@ -29,8 +29,6 @@ namespace spu::mpc { namespace { -constexpr char kModule[] = "RingOps"; - #define SPU_ENFORCE_RING(x) \ SPU_ENFORCE((x).eltype().isa(), "expect ring type, got={}", \ (x).eltype()); @@ -47,7 +45,7 @@ constexpr char kModule[] = "RingOps"; ENFORCE_EQ_ELSIZE_AND_SHAPE(ret, x); \ const auto field = x.eltype().as()->field(); \ const int64_t numel = ret.numel(); \ - return DISPATCH_ALL_FIELDS(field, kModule, [&]() { \ + return DISPATCH_ALL_FIELDS(field, [&]() { \ using T = std::make_signed_t; \ NdArrayView _x(x); \ NdArrayView _ret(ret); \ @@ -67,7 +65,7 @@ DEF_UNARY_RING_OP(ring_neg, -); ENFORCE_EQ_ELSIZE_AND_SHAPE(ret, y); \ const auto field = x.eltype().as()->field(); \ const int64_t numel = ret.numel(); \ - return DISPATCH_ALL_FIELDS(field, kModule, [&]() { \ + return DISPATCH_ALL_FIELDS(field, [&]() { \ NdArrayView _x(x); \ NdArrayView _y(y); \ NdArrayView _ret(ret); \ @@ -86,40 +84,56 @@ DEF_BINARY_RING_OP(ring_xor, ^); #undef DEF_BINARY_RING_OP -void ring_arshift_impl(NdArrayRef& ret, const NdArrayRef& x, size_t bits) { +void ring_arshift_impl(NdArrayRef& ret, const NdArrayRef& x, + const Sizes& bits) { ENFORCE_EQ_ELSIZE_AND_SHAPE(ret, x); + bool is_splat = bits.size() == 1; + SPU_ENFORCE(static_cast(bits.size()) == x.numel() || is_splat, + "mismatched numel {} vs {}", bits.size(), x.numel()); const auto numel = ret.numel(); const auto field = x.eltype().as()->field(); - return DISPATCH_ALL_FIELDS(field, kModule, [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { // According to K&R 2nd edition the results are implementation-dependent for // right shifts of signed values, but "usually" its arithmetic right shift. using S = std::make_signed::type; NdArrayView _ret(ret); NdArrayView _x(x); - pforeach(0, numel, [&](int64_t idx) { _ret[idx] = _x[idx] >> bits; }); + pforeach(0, numel, [&](int64_t idx) { + _ret[idx] = _x[idx] >> (is_splat ? bits[0] : bits[idx]); + }); }); } -void ring_rshift_impl(NdArrayRef& ret, const NdArrayRef& x, size_t bits) { +void ring_rshift_impl(NdArrayRef& ret, const NdArrayRef& x, const Sizes& bits) { ENFORCE_EQ_ELSIZE_AND_SHAPE(ret, x); + bool is_splat = bits.size() == 1; + SPU_ENFORCE(static_cast(bits.size()) == x.numel() || is_splat, + "mismatched numel {} vs {}", bits.size(), x.numel()); const auto numel = ret.numel(); const auto field = x.eltype().as()->field(); - return DISPATCH_ALL_FIELDS(field, kModule, [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using U = ring2k_t; NdArrayView _ret(ret); NdArrayView _x(x); - pforeach(0, numel, [&](int64_t idx) { _ret[idx] = _x[idx] >> bits; }); + pforeach(0, numel, [&](int64_t idx) { + _ret[idx] = _x[idx] >> (is_splat ? 
bits[0] : bits[idx]); + }); }); } -void ring_lshift_impl(NdArrayRef& ret, const NdArrayRef& x, size_t bits) { +void ring_lshift_impl(NdArrayRef& ret, const NdArrayRef& x, const Sizes& bits) { ENFORCE_EQ_ELSIZE_AND_SHAPE(ret, x); + bool is_splat = bits.size() == 1; + SPU_ENFORCE(static_cast(bits.size()) == x.numel() || is_splat, + "mismatched numel {} vs {}", bits.size(), x.numel()); const auto numel = ret.numel(); const auto field = x.eltype().as()->field(); - return DISPATCH_ALL_FIELDS(field, kModule, [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _ret(ret); NdArrayView _x(x); - pforeach(0, numel, [&](int64_t idx) { _ret[idx] = _x[idx] << bits; }); + pforeach(0, numel, [&](int64_t idx) { + _ret[idx] = _x[idx] << (is_splat ? bits[0] : bits[idx]); + }); }); } @@ -130,7 +144,7 @@ void ring_bitrev_impl(NdArrayRef& ret, const NdArrayRef& x, size_t start, const auto field = x.eltype().as()->field(); const auto numel = ret.numel(); - return DISPATCH_ALL_FIELDS(field, kModule, [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using U = ring2k_t; // optimize: use faster reverse method. @@ -161,7 +175,7 @@ void ring_bitmask_impl(NdArrayRef& ret, const NdArrayRef& x, size_t low, SPU_ENFORCE(low < high && high <= SizeOf(field) * 8); - return DISPATCH_ALL_FIELDS(field, kModule, [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using U = ring2k_t; U mask = 0; if (high - low < SizeOf(field) * 8) { @@ -184,7 +198,7 @@ void ring_print(const NdArrayRef& x, std::string_view name) { SPU_ENFORCE_RING(x); const auto field = x.eltype().as()->field(); - DISPATCH_ALL_FIELDS(field, kModule, [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using U = ring2k_t; std::string out; @@ -231,7 +245,7 @@ NdArrayRef ring_rand_range(FieldType field, const Shape& shape, int32_t min, NdArrayRef x(makeType(field), shape); auto numel = x.numel(); - DISPATCH_ALL_FIELDS(field, kModule, [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { SPU_ENFORCE(sizeof(ring2k_t) >= sizeof(int32_t)); auto iter = x.begin(); @@ -250,7 +264,7 @@ void ring_assign(NdArrayRef& x, const NdArrayRef& y) { const auto numel = x.numel(); const auto field = x.eltype().as()->field(); - return DISPATCH_ALL_FIELDS(field, kModule, [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _y(y); NdArrayView _x(x); pforeach(0, numel, [&](int64_t idx) { _x[idx] = _y[idx]; }); @@ -261,7 +275,7 @@ NdArrayRef ring_zeros(FieldType field, const Shape& shape) { NdArrayRef ret(makeType(field), shape); auto numel = ret.numel(); - return DISPATCH_ALL_FIELDS(field, kModule, [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _ret(ret); pforeach(0, numel, [&](int64_t idx) { _ret[idx] = ring2k_t(0); }); return ret; @@ -272,7 +286,7 @@ NdArrayRef ring_ones(FieldType field, const Shape& shape) { NdArrayRef ret(makeType(field), shape); auto numel = ret.numel(); - return DISPATCH_ALL_FIELDS(field, kModule, [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _ret(ret); pforeach(0, numel, [&](int64_t idx) { _ret[idx] = ring2k_t(1); }); return ret; @@ -287,7 +301,7 @@ NdArrayRef ring_randbit(FieldType field, const Shape& shape) { NdArrayRef ret(makeType(field), shape); auto numel = ret.numel(); - return DISPATCH_ALL_FIELDS(field, kModule, [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _ret(ret); for (auto idx = 0; idx < numel; ++idx) { _ret[idx] = distrib(gen) & 0x1; @@ -341,7 +355,7 @@ void ring_mul_impl(NdArrayRef& ret, const NdArrayRef& x, uint128_t y) { const auto numel = x.numel(); const auto field = 
x.eltype().as()->field(); - DISPATCH_ALL_FIELDS(field, kModule, [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using U = std::make_unsigned::type; NdArrayView _x(x); NdArrayView _ret(ret); @@ -368,7 +382,7 @@ void ring_mmul_impl(NdArrayRef& z, const NdArrayRef& lhs, SPU_ENFORCE(rhs.eltype().isa(), "rhs not ring, got={}", rhs.eltype()); const auto field = lhs.eltype().as()->field(); - return DISPATCH_ALL_FIELDS(field, kModule, [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { const auto lhs_stride_scale = lhs.elsize() / sizeof(ring2k_t); const auto rhs_stride_scale = rhs.elsize() / sizeof(ring2k_t); const auto ret_stride_scale = z.elsize() / sizeof(ring2k_t); @@ -436,31 +450,35 @@ void ring_equal_(NdArrayRef& x, const NdArrayRef& y) { ring_equal_impl(x, x, y); } -NdArrayRef ring_arshift(const NdArrayRef& x, size_t bits) { +NdArrayRef ring_arshift(const NdArrayRef& x, const Sizes& bits) { NdArrayRef res(x.eltype(), x.shape()); ring_arshift_impl(res, x, bits); return res; } -void ring_arshift_(NdArrayRef& x, size_t bits) { +void ring_arshift_(NdArrayRef& x, const Sizes& bits) { ring_arshift_impl(x, x, bits); } -NdArrayRef ring_rshift(const NdArrayRef& x, size_t bits) { +NdArrayRef ring_rshift(const NdArrayRef& x, const Sizes& bits) { NdArrayRef res(x.eltype(), x.shape()); ring_rshift_impl(res, x, bits); return res; } -void ring_rshift_(NdArrayRef& x, size_t bits) { ring_rshift_impl(x, x, bits); } +void ring_rshift_(NdArrayRef& x, const Sizes& bits) { + ring_rshift_impl(x, x, bits); +} -NdArrayRef ring_lshift(const NdArrayRef& x, size_t bits) { +NdArrayRef ring_lshift(const NdArrayRef& x, const Sizes& bits) { NdArrayRef res(x.eltype(), x.shape()); ring_lshift_impl(res, x, bits); return res; } -void ring_lshift_(NdArrayRef& x, size_t bits) { ring_lshift_impl(x, x, bits); } +void ring_lshift_(NdArrayRef& x, const Sizes& bits) { + ring_lshift_impl(x, x, bits); +} NdArrayRef ring_bitrev(const NdArrayRef& x, size_t start, size_t end) { NdArrayRef res(x.eltype(), x.shape()); @@ -503,7 +521,7 @@ bool ring_all_equal(const NdArrayRef& x, const NdArrayRef& y, size_t abs_err) { auto numel = x.numel(); const auto field = x.eltype().as()->field(); - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using T = std::make_signed_t; NdArrayView _x(x); @@ -529,7 +547,7 @@ std::vector ring_cast_boolean(const NdArrayRef& x) { auto numel = x.numel(); std::vector res(numel); - DISPATCH_ALL_FIELDS(field, kModule, [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _x(x); pforeach(0, numel, [&](int64_t idx) { res[idx] = static_cast(_x[idx] & 0x1); @@ -549,7 +567,7 @@ NdArrayRef ring_select(const std::vector& c, const NdArrayRef& x, NdArrayRef z(x.eltype(), x.shape()); const int64_t numel = c.size(); - DISPATCH_ALL_FIELDS(field, kModule, [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _x(x); NdArrayView _y(y); NdArrayView _z(z); diff --git a/libspu/mpc/utils/ring_ops.h b/libspu/mpc/utils/ring_ops.h index cea10a06..8a960ebc 100644 --- a/libspu/mpc/utils/ring_ops.h +++ b/libspu/mpc/utils/ring_ops.h @@ -97,14 +97,14 @@ NdArrayRef ring_equal(const NdArrayRef& x, const NdArrayRef& y); void ring_equal_(NdArrayRef& x, const NdArrayRef& y); DEF_RVALUE_BINARY_RING_OP(ring_equal, true); -NdArrayRef ring_arshift(const NdArrayRef& x, size_t bits); -void ring_arshift_(NdArrayRef& x, size_t bits); +NdArrayRef ring_arshift(const NdArrayRef& x, const Sizes& bits); +void ring_arshift_(NdArrayRef& x, const Sizes& bits); -NdArrayRef ring_rshift(const NdArrayRef& x, size_t bits); 
-void ring_rshift_(NdArrayRef& x, size_t bits); +NdArrayRef ring_rshift(const NdArrayRef& x, const Sizes& bits); +void ring_rshift_(NdArrayRef& x, const Sizes& bits); -NdArrayRef ring_lshift(const NdArrayRef& x, size_t bits); -void ring_lshift_(NdArrayRef& x, size_t bits); +NdArrayRef ring_lshift(const NdArrayRef& x, const Sizes& bits); +void ring_lshift_(NdArrayRef& x, const Sizes& bits); NdArrayRef ring_bitrev(const NdArrayRef& x, size_t start, size_t end); void ring_bitrev_(NdArrayRef& x, size_t start, size_t end); diff --git a/libspu/version.h b/libspu/version.h index d4a7a9fe..a7f569ae 100644 --- a/libspu/version.h +++ b/libspu/version.h @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#define SPU_VERSION "0.9.2.dev$$DATE$$" +#define SPU_VERSION "0.9.3.dev$$DATE$$" #include diff --git a/requirements.txt b/requirements.txt index bea5a7d1..ac8fd6cc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,8 @@ grpcio>=1.42.0,!=1.48.0 -numpy>=1.22.0, < 2 +numpy>=1.22.0 protobuf>=4, <5 cloudpickle>=2.0.0 multiprocess>=0.70.12.2 cachetools>=5.0.0 -jax[cpu]>=0.4.16, <=0.4.26 # FIXME +jax[cpu]>=0.4.16, <=0.4.26 # FIXME: Jax 0.4.26+ select perf issue termcolor>=2.0.0 diff --git a/sml/ensemble/BUILD.bazel b/sml/ensemble/BUILD.bazel index 50bfc8c4..32e3eb55 100644 --- a/sml/ensemble/BUILD.bazel +++ b/sml/ensemble/BUILD.bazel @@ -21,5 +21,13 @@ py_library( srcs = ["adaboost.py"], deps = [ "//sml/tree:tree", + ], +) + +py_library( + name = "forest", + srcs = ["forest.py"], + deps = [ + "//sml/tree", ], ) diff --git a/sml/ensemble/emulations/BUILD.bazel b/sml/ensemble/emulations/BUILD.bazel index 17ef2d53..16c1a939 100644 --- a/sml/ensemble/emulations/BUILD.bazel +++ b/sml/ensemble/emulations/BUILD.bazel @@ -17,10 +17,19 @@ load("@rules_python//python:defs.bzl", "py_binary") package(default_visibility = ["//visibility:public"]) py_binary( name = "adaboost_emul", srcs = ["adaboost_emul.py"], deps = [ "//sml/ensemble:adaboost", + "//sml/utils:emulation", + ], +) + +py_binary( + name = "forest_emul", + srcs = ["forest_emul.py"], + deps = [ + "//sml/ensemble:forest", "//sml/utils:emulation", ], ) diff --git a/sml/ensemble/emulations/forest_emul.py b/sml/ensemble/emulations/forest_emul.py new file mode 100644 index 00000000..90437386 --- /dev/null +++ b/sml/ensemble/emulations/forest_emul.py @@ -0,0 +1,125 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+import time + +import jax.numpy as jnp +from sklearn.datasets import load_iris +from sklearn.ensemble import RandomForestClassifier + +import sml.utils.emulation as emulation +from sml.ensemble.forest import RandomForestClassifier as sml_rfc + +MAX_DEPTH = 3 +CONFIG_FILE = emulation.CLUSTER_ABY3_3PC + + +def emul_forest(mode=emulation.Mode.MULTIPROCESS): + def proc_wrapper( + n_estimators, + max_features, + criterion, + splitter, + max_depth, + bootstrap, + max_samples, + n_labels, + ): + rf_custom = sml_rfc( + n_estimators, + max_features, + criterion, + splitter, + max_depth, + bootstrap, + max_samples, + n_labels, + ) + + def proc(X, y): + rf_custom_fit = rf_custom.fit(X, y) + result = rf_custom_fit.predict(X) + return result + + return proc + + def load_data(): + iris = load_iris() + iris_data, iris_label = jnp.array(iris.data), jnp.array(iris.target) + # sorted_features: n_samples * n_features_in + n_samples, n_features_in = iris_data.shape + n_labels = len(jnp.unique(iris_label)) + sorted_features = jnp.sort(iris_data, axis=0) + new_threshold = (sorted_features[:-1, :] + sorted_features[1:, :]) / 2 + new_features = jnp.greater_equal( + iris_data[:, :], new_threshold[:, jnp.newaxis, :] + ) + new_features = new_features.transpose([1, 0, 2]).reshape(n_samples, -1) + + X, y = new_features[:, ::3], iris_label[:] + return X, y + + try: + # bandwidth and latency only work for docker mode + emulator = emulation.Emulator(CONFIG_FILE, mode, bandwidth=300, latency=20) + emulator.up() + + # load mock data + X, y = load_data() + n_labels = jnp.unique(y).shape[0] + + # compare with sklearn + rf = RandomForestClassifier( + n_estimators=3, + max_features=None, + criterion='gini', + max_depth=MAX_DEPTH, + bootstrap=False, + max_samples=None, + ) + start = time.time() + rf = rf.fit(X, y) + score_plain = rf.score(X, y) + end = time.time() + print(f"Running time in SKlearn: {end - start:.2f}s") + + # mark these data to be protected in SPU + X_spu, y_spu = emulator.seal(X, y) + + # run + proc = proc_wrapper( + n_estimators=3, + max_features=0.7, + criterion='gini', + splitter='best', + max_depth=3, + bootstrap=False, + max_samples=None, + n_labels=n_labels, + ) + start = time.time() + result = emulator.run(proc)(X_spu, y_spu) + end = time.time() + score_encrpted = jnp.mean((result == y)) + print(f"Running time in SPU: {end - start:.2f}s") + + # print acc + print(f"Accuracy in SKlearn: {score_plain:.2f}") + print(f"Accuracy in SPU: {score_encrpted:.2f}") + + finally: + emulator.down() + + +if __name__ == "__main__": + emul_forest(emulation.Mode.MULTIPROCESS) diff --git a/sml/ensemble/forest.py b/sml/ensemble/forest.py new file mode 100644 index 00000000..c2dd2649 --- /dev/null +++ b/sml/ensemble/forest.py @@ -0,0 +1,201 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import math +import random + +import jax +import jax.numpy as jnp +from jax import lax + +from sml.tree.tree import DecisionTreeClassifier as sml_dtc + + +class RandomForestClassifier: + """A random forest classifier based on DecisionTreeClassifier. + + Parameters + ---------- + n_estimators : int + The number of trees in the forest. Must specify an integer > 0. + + max_features : int, float, "auto", "sqrt", "log2", or None + The number of features to consider when looking for the best split. + If it's an integer, it must satisfy 0 < integer <= n_features. + If it's a float, it must satisfy 0 < float <= 1. + + criterion : {"gini"}, default="gini" + The function to measure the quality of a split. The only supported + criterion is "gini" for the Gini impurity. + + splitter : {"best"}, default="best" + The strategy used to choose the split at each node. The only supported + strategy is "best", which chooses the best split. + + max_depth : int + The maximum depth of the tree. Must specify an integer > 0. + + bootstrap : bool + Whether bootstrap samples are used when building trees. + + max_samples : int, float, or None, default=None + The number of samples to draw from X to train each base estimator. + This parameter is only valid when bootstrap is True. + If it's an integer, it must satisfy 0 < integer <= n_samples. + If it's a float, it must satisfy 0 < float <= 1. + + n_labels: int + The number of distinct class labels. + + """ + + def __init__( + self, + n_estimators, + max_features, + criterion, + splitter, + max_depth, + bootstrap, + max_samples, + n_labels, + ): + assert criterion == "gini", "criterion other than gini is not supported." + assert splitter == "best", "splitter other than best is not supported." + assert ( + n_estimators is not None and n_estimators > 0 + ), "n_estimators should not be None and must be > 0." + assert ( + max_depth is not None and max_depth > 0 + ), "max_depth should not be None and must be > 0."
+ assert isinstance( + bootstrap, bool + ), "bootstrap should be a boolean value (True or False)" + + self.n_estimators = n_estimators + self.max_features = max_features + self.criterion = criterion + self.splitter = splitter + self.max_depth = max_depth + self.bootstrap = bootstrap + self.max_samples = max_samples + self.n_labels = n_labels + + self.trees = [] + self.features_indices = [] + + def _calculate_max_samples(self, max_samples, n_samples): + if isinstance(max_samples, int): + assert ( + max_samples <= n_samples + ), "max_samples should not exceed n_samples when it's an integer" + return max_samples + elif isinstance(max_samples, float): + assert ( + 0 < max_samples <= 1 + ), "max_samples should be in the range (0, 1] when it's a float" + return int(max_samples * n_samples) + else: + return n_samples + + def _bootstrap_sample(self, X, y): + n_samples = X.shape[0] + max_samples = self._calculate_max_samples(self.max_samples, n_samples) + + if not self.bootstrap: + return X, y + + # draw a bootstrap sample of at most max_samples rows + population = range(n_samples) + indices = random.sample(population, max_samples) + + indices = jnp.array(indices) + return X[indices], y[indices] + + def _select_features(self, n, k): + indices = range(n) + selected_elements = random.sample(indices, k) + return selected_elements + + def _calculate_max_features(self, max_features, n_features): + if isinstance(max_features, int): + assert ( + 0 < max_features <= n_features + ), "0 < max_features <= n_features when it's an integer" + return max_features + + elif isinstance(max_features, float): + assert ( + 0 < max_features <= 1 + ), "max_features should be in the range (0, 1] when it's a float" + return int(max_features * n_features) + + elif isinstance(max_features, str): + if max_features == 'sqrt': + return int(math.sqrt(n_features)) + elif max_features == 'log2': + return int(math.log2(n_features)) + else: + return n_features + else: + return n_features + + def fit(self, X, y): + n_samples, n_features = X.shape + self.n_features = n_features + self.max_features = self._calculate_max_features( + self.max_features, self.n_features + ) + self.label_list = jnp.arange(self.n_labels) + + self.trees = [] + self.features_indices = [] + + for _ in range(self.n_estimators): + X_sample, y_sample = self._bootstrap_sample(X, y) + features = self._select_features(self.n_features, self.max_features) + + tree = sml_dtc(self.criterion, self.splitter, self.max_depth, self.n_labels) + tree.fit(X_sample[:, features], y_sample) + self.trees.append(tree) + self.features_indices.append(features) + + return self + + def jax_mode_row_vectorized(self, data): + label_list = jnp.array(self.label_list) + + data_expanded = jnp.expand_dims(data, axis=-1) + label_expanded = jnp.expand_dims(label_list, axis=0) + + mask = (data_expanded == label_expanded).astype(jnp.int32) + + counts = jnp.sum(mask, axis=1) + mode_indices = jnp.argmax(counts, axis=1) + + modes = label_list[mode_indices] + return modes + + def predict(self, X): + predictions_list = [] + for i, tree in enumerate(self.trees): + features = self.features_indices[i] + predictions = tree.predict(X[:, features]) + predictions_list.append(predictions) + + tree_predictions = jnp.array(predictions_list).T + + y_pred = self.jax_mode_row_vectorized(tree_predictions) + + return y_pred.ravel() diff --git a/sml/ensemble/tests/BUILD.bazel b/sml/ensemble/tests/BUILD.bazel index b0d1d676..a0b7756b 100644 --- a/sml/ensemble/tests/BUILD.bazel +++ b/sml/ensemble/tests/BUILD.bazel @@ -24,4 +24,14 @@ py_test( "//spu:init",
"//spu/utils:simulation", ], -) \ No newline at end of file + name = "forest_test", + srcs = ["forest_test.py"], + deps = [ + "//sml/ensemble:forest", + "//spu:init", + "//spu/utils:simulation", + ], +) diff --git a/sml/ensemble/tests/forest_test.py b/sml/ensemble/tests/forest_test.py new file mode 100644 index 00000000..6c9d3280 --- /dev/null +++ b/sml/ensemble/tests/forest_test.py @@ -0,0 +1,117 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +import jax.numpy as jnp +from sklearn.datasets import load_iris +from sklearn.ensemble import RandomForestClassifier + +import spu.spu_pb2 as spu_pb2 # type: ignore +import spu.utils.simulation as spsim +from sml.ensemble.forest import RandomForestClassifier as sml_rfc + +MAX_DEPTH = 3 + + +class UnitTests(unittest.TestCase): + def test_forest(self): + def proc_wrapper( + n_estimators, + max_features, + criterion, + splitter, + max_depth, + bootstrap, + max_samples, + n_labels, + ): + rf_custom = sml_rfc( + n_estimators, + max_features, + criterion, + splitter, + max_depth, + bootstrap, + max_samples, + n_labels, + ) + + def proc(X, y): + rf_custom_fit = rf_custom.fit(X, y) + + result = rf_custom_fit.predict(X) + return result + + return proc + + def load_data(): + iris = load_iris() + iris_data, iris_label = jnp.array(iris.data), jnp.array(iris.target) + n_samples, n_features_in = iris_data.shape + n_labels = len(jnp.unique(iris_label)) + sorted_features = jnp.sort(iris_data, axis=0) + new_threshold = (sorted_features[:-1, :] + sorted_features[1:, :]) / 2 + new_features = jnp.greater_equal( + iris_data[:, :], new_threshold[:, jnp.newaxis, :] + ) + new_features = new_features.transpose([1, 0, 2]).reshape(n_samples, -1) + + X, y = new_features[:, ::3], iris_label[:] + return X, y + + # bandwidth and latency only work for docker mode + sim = spsim.Simulator.simple( + 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 + ) + + # load mock data + X, y = load_data() + n_labels = jnp.unique(y).shape[0] + + # compare with sklearn + rf = RandomForestClassifier( + n_estimators=3, + max_features="log2", + criterion='gini', + max_depth=MAX_DEPTH, + bootstrap=True, + max_samples=0.7, + ) + rf = rf.fit(X, y) + score_plain = rf.score(X, y) + tree_predictions = jnp.array([tree.predict(X) for tree in rf.estimators_]) + + # run + proc = proc_wrapper( + n_estimators=3, + max_features="log2", + criterion='gini', + splitter='best', + max_depth=3, + bootstrap=True, + max_samples=0.7, + n_labels=n_labels, + ) + + result = spsim.sim_jax(sim, proc)(X, y) + + score_encrpted = jnp.mean((result == y)) + + # print acc + print(f"Accuracy in SKlearn: {score_plain}") + print(f"Accuracy in SPU: {score_encrpted}") + + +if __name__ == "__main__": + unittest.main() diff --git a/spu/tests/BUILD.bazel b/spu/tests/BUILD.bazel index d7453e60..fd037279 100644 --- a/spu/tests/BUILD.bazel +++ b/spu/tests/BUILD.bazel @@ -87,7 +87,7 @@ py_test( py_test( name = "jnp_cheetah_r64_test", - timeout = "long", + size = "enormous", srcs = 
["jnp_cheetah_r64_test.py"], deps = [ ":jnp_testbase", @@ -96,7 +96,7 @@ py_test( py_test( name = "jnp_cheetah_r64_test_x64", - timeout = "long", + size = "enormous", srcs = ["jnp_cheetah_r64_test.py"], env = { "ENABLE_X64_TEST": "1", diff --git a/spu/tests/jnp_cheetah_r64_test.py b/spu/tests/jnp_cheetah_r64_test.py index b113dbcd..10c02a53 100644 --- a/spu/tests/jnp_cheetah_r64_test.py +++ b/spu/tests/jnp_cheetah_r64_test.py @@ -22,8 +22,7 @@ from spu.tests.jnp_testbase import JnpTests -@unittest.skip("too slow, last run succeed") -class JnpTestAby3FM64(JnpTests.JnpTestBase): +class JnpTestCheetahFM64(JnpTests.JnpTestBase): def setUp(self): self._sim = ppsim.Simulator.simple( 2, spu_pb2.ProtocolKind.CHEETAH, spu_pb2.FieldType.FM64 diff --git a/spu/utils/distributed_impl.py b/spu/utils/distributed_impl.py index 79b13426..ebd6a703 100644 --- a/spu/utils/distributed_impl.py +++ b/spu/utils/distributed_impl.py @@ -723,7 +723,10 @@ def mock_parameters(obj: Union[SPU.Object, np.ndarray]): fn_name = repr(fn) - import jax.extend.linear_util as lu + try: + import jax.extend.linear_util as lu + except ImportError: + import jax.linear_util as lu # fallback from jax._src import api_util as japi_util from jax.tree_util import tree_map, tree_flatten diff --git a/spu/utils/frontend.py b/spu/utils/frontend.py index e1a4c24c..fc20cd0c 100644 --- a/spu/utils/frontend.py +++ b/spu/utils/frontend.py @@ -115,13 +115,32 @@ def _jax_compilation( register_backend_factory('interpreter', xla_back, priority=-100) - fn, kwargs = _argnames_partial_except(fn, static_argnames, kwargs) + jax_version = jax.__version_info__ + + if jax_version[0] > 1 or jax_version[1] > 4 or jax_version[2] > 29: + # xla_computation is deprecated since 0.4.30, move to new api + lowered = ( + jax.jit( + fn, + static_argnums=static_argnums, + static_argnames=static_argnames, + keep_unused=True, + ) + .trace(*args, **kwargs) + .lower(lowering_platforms=('interpreter',)) + ) + return ( + lowered.compiler_ir('hlo').as_serialized_hlo_module_proto(), + lowered.out_info, + ) + else: + fn, kwargs = _argnames_partial_except(fn, static_argnames, kwargs) - cfn, output = jax.xla_computation( - fn, return_shape=True, static_argnums=static_argnums, backend="interpreter" - )(*args, **kwargs) + cfn, output = jax.xla_computation( + fn, return_shape=True, static_argnums=static_argnums, backend="interpreter" + )(*args, **kwargs) - return cfn.as_serialized_hlo_module_proto(), output + return cfn.as_serialized_hlo_module_proto(), output ## Frontend patches diff --git a/spu/utils/simulation.py b/spu/utils/simulation.py index 8df6185c..592ccaff 100644 --- a/spu/utils/simulation.py +++ b/spu/utils/simulation.py @@ -17,7 +17,11 @@ from typing import Callable import jax -import jax.extend.linear_util as jax_lu # Moved in jax 0.4.16 + +try: + import jax.extend.linear_util as jax_lu +except ImportError: + import jax.linear_util as jax_lu # fallback import jax.numpy as jnp import numpy as np from jax._src import api_util as japi_util