From c03eccf7c2fa36b2f9da96eb1afa2bba377e85d8 Mon Sep 17 00:00:00 2001 From: Yancheng Zheng <103552181+anakinxc@users.noreply.github.com> Date: Wed, 3 Jul 2024 16:26:25 +0800 Subject: [PATCH 01/27] Update yacl (#755) --- bazel/repositories.bzl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/bazel/repositories.bzl b/bazel/repositories.bzl index ca1b5349..83bbfe20 100644 --- a/bazel/repositories.bzl +++ b/bazel/repositories.bzl @@ -40,10 +40,10 @@ def _yacl(): http_archive, name = "yacl", urls = [ - "https://github.com/secretflow/yacl/archive/refs/tags/0.4.5b1.tar.gz", + "https://github.com/secretflow/yacl/archive/refs/tags/0.4.5b3.tar.gz", ], - strip_prefix = "yacl-0.4.5b1", - sha256 = "28064053b9add0db8e1e8e648421a0579f1d3e7ee8a4bbd7bd5959cb59598088", + strip_prefix = "yacl-0.4.5b3", + sha256 = "bd89d63312e5e83eff5e001e2cf2135baff321c4b72a309f7d00cc53ce02e1a1", ) def _libpsi(): From 442158f34438c98ca9cefcbd3b6483fc10219373 Mon Sep 17 00:00:00 2001 From: Yancheng Zheng <103552181+anakinxc@users.noreply.github.com> Date: Mon, 8 Jul 2024 15:01:55 +0800 Subject: [PATCH 02/27] Repo sync (#757) --- bazel/repositories.bzl | 12 +- docs/reference/np_op_status.json | 2 +- docs/reference/np_op_status.md | 14 ++ docs/reference/pphlo_doc.rst | 6 +- docs/reference/pphlo_op_doc.md | 58 ++++--- docs/reference/runtime_config.md | 3 +- libspu/compiler/common/compilation_context.cc | 2 +- libspu/compiler/common/ir_printer_config.cc | 2 + libspu/compiler/front_end/fe.cc | 2 + libspu/compiler/front_end/hlo_importer.cc | 6 +- libspu/device/api.cc | 2 +- libspu/kernel/hal/permute.cc | 28 ++- libspu/mpc/cheetah/boolean_semi2k.cc | 4 +- libspu/mpc/kernel.cc | 11 ++ libspu/mpc/kernel.h | 8 + libspu/mpc/semi2k/conversion.cc | 163 +++++++++++++++--- libspu/mpc/semi2k/conversion.h | 15 ++ libspu/mpc/semi2k/protocol.cc | 32 ++-- requirements.txt | 4 +- spu/tests/BUILD.bazel | 3 +- spu/tests/jnp_cheetah_r64_test.py | 3 +- spu/utils/distributed_impl.py | 5 +- spu/utils/frontend.py | 29 +++- spu/utils/simulation.py | 6 +- 24 files changed, 326 insertions(+), 94 deletions(-) diff --git a/bazel/repositories.bzl b/bazel/repositories.bzl index 83bbfe20..3fe35ee0 100644 --- a/bazel/repositories.bzl +++ b/bazel/repositories.bzl @@ -51,10 +51,10 @@ def _libpsi(): http_archive, name = "psi", urls = [ - "https://github.com/secretflow/psi/archive/refs/tags/v0.4.0.dev240524.tar.gz", + "https://github.com/secretflow/psi/archive/refs/tags/v0.4.0beta.tar.gz", ], - strip_prefix = "psi-0.4.0.dev240524", - sha256 = "c2868fa6a9d804e6bbed9922dab6dc819ec6e180e15eafe7eb1b661302508c88", + strip_prefix = "psi-0.4.0beta", + sha256 = "c2fbf486a66eca9d3ec1725a81d93a7c6e80a9206ef1c9263a1608e0bef95e1a", ) def _rules_proto_grpc(): @@ -169,10 +169,10 @@ def _com_github_pybind11(): http_archive, name = "pybind11", build_file = "@pybind11_bazel//:pybind11.BUILD", - sha256 = "bf8f242abd1abcd375d516a7067490fb71abd79519a282d22b6e4d19282185a7", - strip_prefix = "pybind11-2.12.0", + sha256 = "51631e88960a8856f9c497027f55c9f2f9115cafb08c0005439838a05ba17bfc", + strip_prefix = "pybind11-2.13.1", urls = [ - "https://github.com/pybind/pybind11/archive/refs/tags/v2.12.0.tar.gz", + "https://github.com/pybind/pybind11/archive/refs/tags/v2.13.1.tar.gz", ], ) diff --git a/docs/reference/np_op_status.json b/docs/reference/np_op_status.json index 1272847c..20ed81a7 100644 --- a/docs/reference/np_op_status.json +++ b/docs/reference/np_op_status.json @@ -1 +1 @@ -[{"name": "abs", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "add", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "angle", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "angle", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "arccos", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "arccosh", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.SysError", "note": "nan"}, {"name": "arcsin", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "arcsinh", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "arctan", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "arctan2", "dtypes": ["float32"], "status": "Status.Pass", "note": ""}, {"name": "arctanh", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.SysError", "note": "[-1, 1] nan"}, {"name": "argmax", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "argmin", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "array_equal", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "array_equiv", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "atleast_1d", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "atleast_2d", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "atleast_3d", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "bitwise_and", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "bitwise_not", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "bitwise_or", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "bitwise_xor", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "cbrt", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.SysError", "note": "stablehlo.cbrt"}, {"name": "ceil", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "conj", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "conjugate", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "copysign", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.SysError", "note": "shift"}, {"name": "cos", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "cosh", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "deg2rad", "dtypes": ["float32"], "status": "Status.Pass", "note": ""}, {"name": "divide", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "divmod", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "ediff1d", "dtypes": ["int32"], "status": "Status.Pass", "note": ""}, {"name": "equal", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "exp", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "exp2", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "expm1", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "fabs", "dtypes": ["float32"], "status": "Status.Pass", "note": ""}, {"name": "fix", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "float_power", "dtypes": ["float32"], "status": "Status.Pass", "note": ""}, {"name": "floor", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "floor_divide", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "fmax", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "fmin", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "fmod", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "gcd", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.SysError", "note": "secret while"}, {"name": "greater", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "greater_equal", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "heaviside", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "hypot", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "i0", "dtypes": ["float32"], "status": "Status.Failed", "note": "accuracy"}, {"name": "imag", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "invert", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "isclose", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "iscomplex", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "isfinite", "dtypes": ["float32"], "status": "Status.UnSupport", "note": ""}, {"name": "isinf", "dtypes": ["float32"], "status": "Status.UnSupport", "note": ""}, {"name": "isnan", "dtypes": ["float32"], "status": "Status.UnSupport", "note": ""}, {"name": "isneginf", "dtypes": ["float32"], "status": "Status.UnSupport", "note": ""}, {"name": "isposinf", "dtypes": ["float32"], "status": "Status.UnSupport", "note": ""}, {"name": "isreal", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "isrealobj", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "kron", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "lcm", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.SysError", "note": "secret while"}, {"name": "ldexp", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.Failed", "note": "IEEE-754"}, {"name": "left_shift", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "less", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "less_equal", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "log", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "log10", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "log1p", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "log1p", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "log2", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "logaddexp", "dtypes": ["float32"], "status": "Status.Pass", "note": ""}, {"name": "logaddexp2", "dtypes": ["float32"], "status": "Status.Pass", "note": ""}, {"name": "logical_and", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "logical_not", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "logical_or", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "logical_xor", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "maximum", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "mean", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "mean", "dtypes": ["float32"], "status": "Status.PassNoGen", "note": ""}, {"name": "minimum", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "mod", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "modf", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "multiply", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "nanargmax", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.UnSupport", "note": ""}, {"name": "nanargmin", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.UnSupport", "note": ""}, {"name": "nanmean", "dtypes": ["float32"], "status": "Status.UnSupport", "note": ""}, {"name": "nanprod", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.UnSupport", "note": ""}, {"name": "nansum", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.UnSupport", "note": ""}, {"name": "negative", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "nextafter", "dtypes": ["float32"], "status": "Status.Failed", "note": "IEEE-754"}, {"name": "not_equal", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "outer", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "polyval", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "positive", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "power", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "prod", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "rad2deg", "dtypes": ["float32"], "status": "Status.Pass", "note": ""}, {"name": "ravel", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "real", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "reciprocal", "dtypes": ["float32"], "status": "Status.Pass", "note": ""}, {"name": "remainder", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "right_shift", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "rint", "dtypes": ["float32"], "status": "Status.Pass", "note": ""}, {"name": "sign", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "signbit", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "sin", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "sinc", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "sinh", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "sqrt", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "square", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "square", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "subtract", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "sum", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "tan", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.SysError", "note": "stablehlo.tan"}, {"name": "tanh", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "transpose", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "true_divide", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "trunc", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "unwrap", "dtypes": ["float32"], "status": "Status.Pass", "note": ""}] \ No newline at end of file +[{"name": "abs", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "add", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "angle", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "angle", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "arccos", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "arccosh", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.SysError", "note": "nan"}, {"name": "arcsin", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "arcsinh", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "arctan", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "arctan2", "dtypes": ["float32"], "status": "Status.Pass", "note": ""}, {"name": "arctanh", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.SysError", "note": "[-1, 1] nan"}, {"name": "argmax", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "argmin", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "array_equal", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "array_equiv", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "atleast_1d", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "atleast_2d", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "atleast_3d", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "bitwise_and", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "bitwise_count", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "bitwise_not", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "bitwise_or", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "bitwise_xor", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "cbrt", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.SysError", "note": "stablehlo.cbrt"}, {"name": "ceil", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "conj", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "conjugate", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "copysign", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.SysError", "note": "shift"}, {"name": "cos", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "cosh", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "deg2rad", "dtypes": ["float32"], "status": "Status.Pass", "note": ""}, {"name": "divide", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "divmod", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "ediff1d", "dtypes": ["int32"], "status": "Status.Pass", "note": ""}, {"name": "equal", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "exp", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "exp2", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "expm1", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "fabs", "dtypes": ["float32"], "status": "Status.Pass", "note": ""}, {"name": "fix", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "float_power", "dtypes": ["float32"], "status": "Status.Pass", "note": ""}, {"name": "floor", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "floor_divide", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "fmax", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "fmin", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "fmod", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "gcd", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.SysError", "note": "secret while"}, {"name": "greater", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "greater_equal", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "heaviside", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "hypot", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "i0", "dtypes": ["float32"], "status": "Status.Failed", "note": "accuracy"}, {"name": "imag", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "invert", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "isclose", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "iscomplex", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "isfinite", "dtypes": ["float32"], "status": "Status.UnSupport", "note": ""}, {"name": "isinf", "dtypes": ["float32"], "status": "Status.UnSupport", "note": ""}, {"name": "isnan", "dtypes": ["float32"], "status": "Status.UnSupport", "note": ""}, {"name": "isneginf", "dtypes": ["float32"], "status": "Status.UnSupport", "note": ""}, {"name": "isposinf", "dtypes": ["float32"], "status": "Status.UnSupport", "note": ""}, {"name": "isreal", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "isrealobj", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "kron", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "lcm", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.SysError", "note": "secret while"}, {"name": "ldexp", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.Failed", "note": "IEEE-754"}, {"name": "left_shift", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "less", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "less_equal", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "log", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "log10", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "log1p", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "log1p", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "log2", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "logaddexp", "dtypes": ["float32"], "status": "Status.Pass", "note": ""}, {"name": "logaddexp2", "dtypes": ["float32"], "status": "Status.Pass", "note": ""}, {"name": "logical_and", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "logical_not", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "logical_or", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "logical_xor", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "maximum", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "mean", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "mean", "dtypes": ["float32"], "status": "Status.PassNoGen", "note": ""}, {"name": "minimum", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "mod", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "modf", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "multiply", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "nanargmax", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.UnSupport", "note": ""}, {"name": "nanargmin", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.UnSupport", "note": ""}, {"name": "nanmean", "dtypes": ["float32"], "status": "Status.UnSupport", "note": ""}, {"name": "nanprod", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.UnSupport", "note": ""}, {"name": "nansum", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.UnSupport", "note": ""}, {"name": "negative", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "nextafter", "dtypes": ["float32"], "status": "Status.Failed", "note": "IEEE-754"}, {"name": "not_equal", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "outer", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "polyval", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "positive", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "power", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "prod", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "rad2deg", "dtypes": ["float32"], "status": "Status.Pass", "note": ""}, {"name": "ravel", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "real", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "reciprocal", "dtypes": ["float32"], "status": "Status.Pass", "note": ""}, {"name": "remainder", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "right_shift", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "rint", "dtypes": ["float32"], "status": "Status.Pass", "note": ""}, {"name": "sign", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "signbit", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "sin", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "sinc", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "sinh", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "sqrt", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "square", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "square", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "subtract", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "sum", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "tan", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.SysError", "note": "stablehlo.tan"}, {"name": "tanh", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "transpose", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "true_divide", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "trunc", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "unwrap", "dtypes": ["float32"], "status": "Status.Pass", "note": ""}] \ No newline at end of file diff --git a/docs/reference/np_op_status.md b/docs/reference/np_op_status.md index d69f20a0..ab3539a4 100644 --- a/docs/reference/np_op_status.md +++ b/docs/reference/np_op_status.md @@ -293,6 +293,20 @@ Please check *Supported Dtypes* as well. - uint16 - uint32 +## bitwise_count + +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.bitwise_count.html +### Status + +**PASS** +Please check *Supported Dtypes* as well. +### Supported Dtypes + +- int16 +- int32 +- uint16 +- uint32 + ## bitwise_not JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.bitwise_not.html diff --git a/docs/reference/pphlo_doc.rst b/docs/reference/pphlo_doc.rst index 751541e1..148d417e 100644 --- a/docs/reference/pphlo_doc.rst +++ b/docs/reference/pphlo_doc.rst @@ -1,9 +1,9 @@ -PPHlo API reference +PPHLO API reference =================== -PPHlo is short for (SPU High level ops), it's the assembly language of SPU. +PPHLO is short for (Privacy-Preserving High-Level Operations), it's the assembly language of SPU. -PPHlo is built on `MLIR `_ infrastructure, the concrete ops definition could be found :spu_code_host:`here `. +PPHLO is built on `MLIR `_ infrastructure, the concrete ops definition could be found :spu_code_host:`here `. Op List ~~~~~~~ diff --git a/docs/reference/pphlo_op_doc.md b/docs/reference/pphlo_op_doc.md index f6afc268..2839c9a1 100644 --- a/docs/reference/pphlo_op_doc.md +++ b/docs/reference/pphlo_op_doc.md @@ -747,7 +747,7 @@ Ref https://www.tensorflow.org/xla/operation_semantics#dot. Traits: `AlwaysSpeculatableImplTrait` -Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)` +Interfaces: `ConditionallySpeculatable`, `InferShapedTypeOpInterface`, `InferTypeOpInterface`, `NoMemoryEffect (MemoryEffectOpInterface)` Effects: `MemoryEffects::Effect{}` @@ -1626,55 +1626,63 @@ Effects: `MemoryEffects::Effect{}` | :----: | ----------- | | `result` | statically shaped tensor of 8/16/32/64-bit signless integer or Secret of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or 8/16/32/64-bit unsigned integer or Secret of 8/16/32/64-bit unsigned integer values or pred (AKA boolean or 1-bit integer) or Secret of 1-bit signless integer values values -### `pphlo.power` (spu::pphlo::PowOp) +### `pphlo.popcnt` (spu::pphlo::PopcntOp) -_Power operator_ +_Popcnt operator, ties away from zero_ Syntax: ``` -operation ::= `pphlo.power` $lhs `,` $rhs attr-dict - `:` custom(type($lhs), type($rhs), type($result)) +operation ::= `pphlo.popcnt` $operand attr-dict `:` custom(type($operand), type($result)) ``` -Performs element-wise exponentiation of `lhs` tensor by `rhs` tensor and produces a `result` tensor. +Performs element-wise count of the number of bits set in the `operand` tensor and produces a `result` tensor. -Ref https://github.com/openxla/stablehlo/blob/main/docs/spec.md#power +Ref https://github.com/openxla/stablehlo/blob/main/docs/spec.md#popcnt -Traits: `AlwaysSpeculatableImplTrait`, `Elementwise`, `SameOperandsAndResultShape` +Traits: `AlwaysSpeculatableImplTrait`, `Elementwise`, `SameOperandsAndResultShape`, `SameOperandsAndResultType` -Interfaces: `ConditionallySpeculatable`, `InferShapedTypeOpInterface`, `InferTypeOpInterface`, `NoMemoryEffect (MemoryEffectOpInterface)` +Interfaces: `ConditionallySpeculatable`, `InferTypeOpInterface`, `NoMemoryEffect (MemoryEffectOpInterface)` Effects: `MemoryEffects::Effect{}` +#### Attributes: + + + + +
AttributeMLIR TypeDescription
bits::mlir::IntegerAttr64-bit signless integer attribute
+ #### Operands: | Operand | Description | | :-----: | ----------- | -| `lhs` | statically shaped tensor of pred (AKA boolean or 1-bit integer) or Secret of 1-bit signless integer values or 8/16/32/64-bit signless integer or Secret of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or 8/16/32/64-bit unsigned integer or Secret of 8/16/32/64-bit unsigned integer values or 16-bit float or 32-bit float or 64-bit float or Secret of 16-bit float or 32-bit float or 64-bit float values or complex type with 32-bit float or 64-bit float elements or Secret of complex type with 32-bit float or 64-bit float elements values values -| `rhs` | statically shaped tensor of pred (AKA boolean or 1-bit integer) or Secret of 1-bit signless integer values or 8/16/32/64-bit signless integer or Secret of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or 8/16/32/64-bit unsigned integer or Secret of 8/16/32/64-bit unsigned integer values or 16-bit float or 32-bit float or 64-bit float or Secret of 16-bit float or 32-bit float or 64-bit float values or complex type with 32-bit float or 64-bit float elements or Secret of complex type with 32-bit float or 64-bit float elements values values +| `operand` | statically shaped tensor of 8/16/32/64-bit signless integer or Secret of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or 8/16/32/64-bit unsigned integer or Secret of 8/16/32/64-bit unsigned integer values values #### Results: | Result | Description | | :----: | ----------- | -| `result` | statically shaped tensor of pred (AKA boolean or 1-bit integer) or Secret of 1-bit signless integer values or 8/16/32/64-bit signless integer or Secret of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or 8/16/32/64-bit unsigned integer or Secret of 8/16/32/64-bit unsigned integer values or 16-bit float or 32-bit float or 64-bit float or Secret of 16-bit float or 32-bit float or 64-bit float values or complex type with 32-bit float or 64-bit float elements or Secret of complex type with 32-bit float or 64-bit float elements values values +| `result` | statically shaped tensor of 8/16/32/64-bit signless integer or Secret of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or 8/16/32/64-bit unsigned integer or Secret of 8/16/32/64-bit unsigned integer values values -### `pphlo.prefer_a` (spu::pphlo::PreferAOp) +### `pphlo.power` (spu::pphlo::PowOp) -_Prefer AShare operator_ +_Power operator_ Syntax: ``` -operation ::= `pphlo.prefer_a` $operand attr-dict `:` custom(type($operand), type($result)) +operation ::= `pphlo.power` $lhs `,` $rhs attr-dict + `:` custom(type($lhs), type($rhs), type($result)) ``` -Convert input to AShare if possible. +Performs element-wise exponentiation of `lhs` tensor by `rhs` tensor and produces a `result` tensor. -Traits: `AlwaysSpeculatableImplTrait`, `Elementwise`, `SameOperandsAndResultShape`, `SameOperandsAndResultType` +Ref https://github.com/openxla/stablehlo/blob/main/docs/spec.md#power + +Traits: `AlwaysSpeculatableImplTrait`, `Elementwise`, `SameOperandsAndResultShape` Interfaces: `ConditionallySpeculatable`, `InferShapedTypeOpInterface`, `InferTypeOpInterface`, `NoMemoryEffect (MemoryEffectOpInterface)` @@ -1684,7 +1692,8 @@ Effects: `MemoryEffects::Effect{}` | Operand | Description | | :-----: | ----------- | -| `operand` | statically shaped tensor of pred (AKA boolean or 1-bit integer) or Secret of 1-bit signless integer values or 8/16/32/64-bit signless integer or Secret of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or 8/16/32/64-bit unsigned integer or Secret of 8/16/32/64-bit unsigned integer values or 16-bit float or 32-bit float or 64-bit float or Secret of 16-bit float or 32-bit float or 64-bit float values or complex type with 32-bit float or 64-bit float elements or Secret of complex type with 32-bit float or 64-bit float elements values values +| `lhs` | statically shaped tensor of pred (AKA boolean or 1-bit integer) or Secret of 1-bit signless integer values or 8/16/32/64-bit signless integer or Secret of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or 8/16/32/64-bit unsigned integer or Secret of 8/16/32/64-bit unsigned integer values or 16-bit float or 32-bit float or 64-bit float or Secret of 16-bit float or 32-bit float or 64-bit float values or complex type with 32-bit float or 64-bit float elements or Secret of complex type with 32-bit float or 64-bit float elements values values +| `rhs` | statically shaped tensor of pred (AKA boolean or 1-bit integer) or Secret of 1-bit signless integer values or 8/16/32/64-bit signless integer or Secret of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or 8/16/32/64-bit unsigned integer or Secret of 8/16/32/64-bit unsigned integer values or 16-bit float or 32-bit float or 64-bit float or Secret of 16-bit float or 32-bit float or 64-bit float values or complex type with 32-bit float or 64-bit float elements or Secret of complex type with 32-bit float or 64-bit float elements values values #### Results: @@ -2270,12 +2279,21 @@ Returns the sign of the `operand` element-wise and produces a `result` tensor. Ref https://github.com/openxla/stablehlo/blob/main/docs/spec.md#sign +PPHLO Extension: when `ignore_zero` is set to true, sign does not enforce sign(0) to 0 + Traits: `AlwaysSpeculatableImplTrait`, `Elementwise`, `SameOperandsAndResultShape`, `SameOperandsAndResultType` Interfaces: `ConditionallySpeculatable`, `InferShapedTypeOpInterface`, `InferTypeOpInterface`, `NoMemoryEffect (MemoryEffectOpInterface)` Effects: `MemoryEffects::Effect{}` +#### Attributes: + + + + +
AttributeMLIR TypeDescription
ignore_zero::mlir::BoolAttrbool attribute
+ #### Operands: | Operand | Description | @@ -2377,7 +2395,7 @@ Ref https://github.com/openxla/stablehlo/blob/main/docs/spec.md#slice Traits: `AlwaysSpeculatableImplTrait`, `SameOperandsAndResultElementType` -Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)` +Interfaces: `ConditionallySpeculatable`, `InferTypeOpInterface`, `NoMemoryEffect (MemoryEffectOpInterface)` Effects: `MemoryEffects::Effect{}` @@ -2551,7 +2569,7 @@ Ref https://github.com/openxla/stablehlo/blob/main/docs/spec.md#transpose Traits: `AlwaysSpeculatableImplTrait`, `SameOperandsAndResultElementType` -Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)` +Interfaces: `ConditionallySpeculatable`, `InferTypeOpInterface`, `NoMemoryEffect (MemoryEffectOpInterface)` Effects: `MemoryEffects::Effect{}` diff --git a/docs/reference/runtime_config.md b/docs/reference/runtime_config.md index f556bce4..836095cc 100644 --- a/docs/reference/runtime_config.md +++ b/docs/reference/runtime_config.md @@ -179,8 +179,9 @@ The SPU runtime configuration. | Field | Type | Description | | ----- | ---- | ----------- | | server_host | [ string](#string) | TrustedThirdParty beaver server's remote ip:port or load-balance uri. | -| session_id | [ string](#string) | if empty, use link id as session id. | | adjust_rank | [ int32](#int32) | which rank do adjust rpc call, usually choose the rank closer to the server. | +| asym_crypto_schema | [ string](#string) | asym_crypto_schema: support ["SM2"] Will support 25519 in the future, after yacl supported it. | +| server_public_key | [ bytes](#bytes) | server's public key | diff --git a/libspu/compiler/common/compilation_context.cc b/libspu/compiler/common/compilation_context.cc index ef22871c..553fe913 100644 --- a/libspu/compiler/common/compilation_context.cc +++ b/libspu/compiler/common/compilation_context.cc @@ -23,7 +23,7 @@ namespace { void SPUErrorHandler(void * /*use_data*/, const char *reason, bool /*gen_crash_diag*/) { - SPU_THROW(reason); + SPU_THROW("{}", reason); } } // namespace diff --git a/libspu/compiler/common/ir_printer_config.cc b/libspu/compiler/common/ir_printer_config.cc index 47dff64a..f3c15ff2 100644 --- a/libspu/compiler/common/ir_printer_config.cc +++ b/libspu/compiler/common/ir_printer_config.cc @@ -51,6 +51,7 @@ void IRPrinterConfig::printBeforeIfEnabled(Pass *pass, Operation *, if (ec.value() != 0) { spdlog::error("Open file {} failed, error = {}", file_name.c_str(), ec.message()); + return; } print_callback(f); } @@ -64,6 +65,7 @@ void IRPrinterConfig::printAfterIfEnabled(Pass *pass, Operation *, if (ec.value() != 0) { spdlog::error("Open file {} failed, error = {}", file_name.c_str(), ec.message()); + return; } print_callback(f); } diff --git a/libspu/compiler/front_end/fe.cc b/libspu/compiler/front_end/fe.cc index 682f1c5b..e560c772 100644 --- a/libspu/compiler/front_end/fe.cc +++ b/libspu/compiler/front_end/fe.cc @@ -54,6 +54,8 @@ mlir::OwningOpRef FE::doit(const CompilationSource &source) { module = mlir::parseSourceString(source.ir_txt(), ctx_->getMLIRContext()); + SPU_ENFORCE(module, "MLIR parser failure"); + // Convert stablehlo to mhlo first mlir::PassManager pm(ctx_->getMLIRContext()); pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass()); diff --git a/libspu/compiler/front_end/hlo_importer.cc b/libspu/compiler/front_end/hlo_importer.cc index 5845af57..363bd76c 100644 --- a/libspu/compiler/front_end/hlo_importer.cc +++ b/libspu/compiler/front_end/hlo_importer.cc @@ -196,12 +196,12 @@ HloImporter::parseXlaModuleFromString(const std::string &content) { auto module_config = xla::HloModule::CreateModuleConfigFromProto(hlo_module, debug_options); if (!module_config.status().ok()) { - SPU_THROW(module_config.status().message()); + SPU_THROW("{}", module_config.status().message()); } auto module = xla::HloModule::CreateFromProto(hlo_module, *module_config); if (!module.status().ok()) { - SPU_THROW(module.status().message()); + SPU_THROW("{}", module.status().message()); } xla::runHloPasses((*module).get()); @@ -214,7 +214,7 @@ HloImporter::parseXlaModuleFromString(const std::string &content) { auto status = importer.Import(**module); if (!status.ok()) { - SPU_THROW(status.message()); + SPU_THROW("{}", status.message()); } return mlir_hlo; diff --git a/libspu/device/api.cc b/libspu/device/api.cc index 5535ad85..1b27f881 100644 --- a/libspu/device/api.cc +++ b/libspu/device/api.cc @@ -229,7 +229,7 @@ void printProfilingData(spu::SPUContext *sctx, const std::string &name, void SPUErrorHandler(void *use_data, const char *reason, bool gen_crash_diag) { (void)use_data; (void)gen_crash_diag; - SPU_THROW(reason); + SPU_THROW("{}", reason); } std::mutex ErrorHandlerMutex; diff --git a/libspu/kernel/hal/permute.cc b/libspu/kernel/hal/permute.cc index 657b81fc..1c67c997 100644 --- a/libspu/kernel/hal/permute.cc +++ b/libspu/kernel/hal/permute.cc @@ -19,6 +19,7 @@ #include "libspu/core/bit_utils.h" #include "libspu/core/context.h" #include "libspu/core/trace.h" +#include "libspu/core/vectorize.h" #include "libspu/kernel/hal/constants.h" #include "libspu/kernel/hal/polymorphic.h" #include "libspu/kernel/hal/prot_wrapper.h" @@ -43,6 +44,12 @@ inline bool _has_same_owner(const Value &x, const Value &y) { return _get_owner(x) == _get_owner(y); } +void _hint_nbits(const Value &a, size_t nbits) { + if (a.storage_type().isa()) { + const_cast(a.storage_type()).as()->setNbits(nbits); + } +} + // generate inverse permutation Index _inverse_index(const Index &p) { Index q(p.size()); @@ -531,20 +538,29 @@ spu::Value _opt_apply_perm_ss(SPUContext *ctx, const spu::Value &perm, std::vector _bit_decompose(SPUContext *ctx, const spu::Value &x, int64_t valid_bits) { auto x_bshare = _prefer_b(ctx, x); - const auto k1 = _constant(ctx, 1U, x.shape()); - std::vector rets; size_t nbits = valid_bits != -1 ? static_cast(valid_bits) : x_bshare.storage_type().as()->nbits(); - rets.reserve(nbits); + _hint_nbits(x_bshare, nbits); + if (ctx->hasKernel("b2a_disassemble")) { + auto ret = + dynDispatch>(ctx, "b2a_disassemble", x_bshare); + return ret; + } + + const auto k1 = _constant(ctx, 1U, x.shape()); + std::vector rets_b; + rets_b.reserve(nbits); for (size_t bit = 0; bit < nbits; ++bit) { auto x_bshare_shift = right_shift_logical(ctx, x_bshare, bit); - auto lowest_bit = _and(ctx, x_bshare_shift, k1); - rets.emplace_back(_prefer_a(ctx, lowest_bit)); + rets_b.push_back(_and(ctx, x_bshare_shift, k1)); } - return rets; + std::vector rets_a; + vmap(rets_b.begin(), rets_b.end(), std::back_inserter(rets_a), + [&](const Value &x) { return _prefer_a(ctx, x); }); + return rets_a; } // Generate vector of bit decomposition of sorting keys diff --git a/libspu/mpc/cheetah/boolean_semi2k.cc b/libspu/mpc/cheetah/boolean_semi2k.cc index a64da4cc..786d0c38 100644 --- a/libspu/mpc/cheetah/boolean_semi2k.cc +++ b/libspu/mpc/cheetah/boolean_semi2k.cc @@ -81,8 +81,8 @@ NdArrayRef P2B::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { if (comm->getRank() == 0) { ring_xor_(x, in); } - - return makeBShare(x, field, getNumBits(in)); + auto nbits = getNumBits(in) == 0 ? 1 : getNumBits(in); + return makeBShare(x, field, nbits); } NdArrayRef AndBP::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, diff --git a/libspu/mpc/kernel.cc b/libspu/mpc/kernel.cc index 45bed576..90d51818 100644 --- a/libspu/mpc/kernel.cc +++ b/libspu/mpc/kernel.cc @@ -233,6 +233,17 @@ void ConcateKernel::evaluate(KernelEvalContext* ctx) const { ctx->pushOutput(WrapValue(z)); } +void DisassembleKernel::evaluate(KernelEvalContext* ctx) const { + const auto& in = ctx->getParam(0); + auto z = proc(ctx, UnwrapValue(in)); + + std::vector wrapped(z.size()); + for (size_t idx = 0; idx < z.size(); ++idx) { + wrapped[idx] = WrapValue(z[idx]); + } + ctx->pushOutput(wrapped); +}; + void OramOneHotKernel::evaluate(KernelEvalContext* ctx) const { auto target = ctx->getParam(0); auto s = ctx->getParam(1); diff --git a/libspu/mpc/kernel.h b/libspu/mpc/kernel.h index 4a4c29d8..12383391 100644 --- a/libspu/mpc/kernel.h +++ b/libspu/mpc/kernel.h @@ -217,4 +217,12 @@ class ConcateKernel : public Kernel { int64_t axis) const = 0; }; +class DisassembleKernel : public Kernel { + public: + void evaluate(KernelEvalContext* ctx) const override; + + virtual std::vector proc(KernelEvalContext* ctx, + const NdArrayRef& in) const = 0; +}; + } // namespace spu::mpc diff --git a/libspu/mpc/semi2k/conversion.cc b/libspu/mpc/semi2k/conversion.cc index 35d8c701..dce1857f 100644 --- a/libspu/mpc/semi2k/conversion.cc +++ b/libspu/mpc/semi2k/conversion.cc @@ -42,6 +42,26 @@ static NdArrayRef wrap_and_bb(SPUContext* ctx, const NdArrayRef& x, return UnwrapValue(and_bb(ctx, WrapValue(x), WrapValue(y))); } +// TODO: Move to some common place +PtType getBacktype(size_t nbits) { + if (nbits <= 8) { + return PT_U8; + } + if (nbits <= 16) { + return PT_U16; + } + if (nbits <= 32) { + return PT_U32; + } + if (nbits <= 64) { + return PT_U64; + } + if (nbits <= 128) { + return PT_U128; + } + SPU_THROW("invalid number of bits={}", nbits); +} + NdArrayRef A2B::proc(KernelEvalContext* ctx, const NdArrayRef& x) const { const auto field = x.eltype().as()->field(); auto* comm = ctx->getState(); @@ -90,6 +110,9 @@ NdArrayRef B2A::proc(KernelEvalContext* ctx, const NdArrayRef& x) const { return r_a; } +// TODO(jimi): pack {numel * nbits} to fully make use of undelying storage to +// save communications. If implemented, B2A_Disassemble kernel is also no longer +// needed NdArrayRef B2A_Randbit::proc(KernelEvalContext* ctx, const NdArrayRef& x) const { const auto field = x.eltype().as()->field(); @@ -105,6 +128,7 @@ NdArrayRef B2A_Randbit::proc(KernelEvalContext* ctx, const auto numel = x.numel(); const auto rand_numel = numel * static_cast(nbits); + const PtType backtype = getBacktype(nbits); auto randbits = beaver->RandBit(field, rand_numel); SPU_ENFORCE(static_cast(randbits.size()) == @@ -119,32 +143,125 @@ NdArrayRef B2A_Randbit::proc(KernelEvalContext* ctx, // algorithm begins. // Ref: III.D @ https://eprint.iacr.org/2019/599.pdf (SPDZ-2K primitives) - std::vector x_xor_r(numel); - - pforeach(0, numel, [&](int64_t idx) { - // use _r[i*nbits, (i+1)*nbits) to construct rb[i] - U mask = 0; - for (int64_t bit = 0; bit < nbits; ++bit) { - mask += (_randbits[idx * nbits + bit] & 0x1) << bit; - } - x_xor_r[idx] = _x[idx] ^ mask; + DISPATCH_UINT_PT_TYPES(backtype, "_", [&]() { + using V = ScalarT; + std::vector x_xor_r(numel); + + pforeach(0, numel, [&](int64_t idx) { + // use _r[i*nbits, (i+1)*nbits) to construct rb[i] + V mask = 0; + for (int64_t bit = 0; bit < nbits; ++bit) { + mask += (static_cast(_randbits[idx * nbits + bit]) & 0x1) << bit; + } + x_xor_r[idx] = _x[idx] ^ mask; + }); + + // open c = x ^ r + x_xor_r = comm->allReduce(x_xor_r, "open(x^r)"); + + NdArrayView _res(res); + pforeach(0, numel, [&](int64_t idx) { + _res[idx] = 0; + for (int64_t bit = 0; bit < nbits; bit++) { + auto c_i = static_cast(x_xor_r[idx] >> bit) & 0x1; + if (comm->getRank() == 0) { + _res[idx] += (c_i + (1 - c_i * 2) * _randbits[idx * nbits + bit]) + << bit; + } else { + _res[idx] += ((1 - c_i * 2) * _randbits[idx * nbits + bit]) << bit; + } + } + }); }); + }); + + return res; +} - // open c = x ^ r - x_xor_r = comm->allReduce(x_xor_r, "open(x^r)"); - - NdArrayView _res(res); - pforeach(0, numel, [&](int64_t idx) { - _res[idx] = 0; - for (int64_t bit = 0; bit < nbits; bit++) { - auto c_i = (x_xor_r[idx] >> bit) & 0x1; - if (comm->getRank() == 0) { - _res[idx] += (c_i + (1 - c_i * 2) * _randbits[idx * nbits + bit]) - << bit; - } else { - _res[idx] += ((1 - c_i * 2) * _randbits[idx * nbits + bit]) << bit; +// Reference: +// III.D @ https://eprint.iacr.org/2019/599.pdf (SPDZ-2K primitives) +// +// Analysis: +// Online Latency: 1 (x_xor_r reveal) +// Communication: one element bits for one element +// Vectorization: yes +// +// HighLevel Intuition: +// Since: X = sum: Xi * 2^i +// If we have A, then we can construct A = sum: A * 2^i. +// +// The problem is that we only have B in hand. Details for how to +// construct A from B: +// - trusted third party choose a random bit r, where r == 0 or r == 1. +// - trusted third party send A to parties +// - parties compute B from A +// - parties xor_open c = Xi ^ r = open(B ^ B), Xi is still safe due +// to protection from r. +// - parties compute: = c + (1-2c)* +// A = 1 - A if c == 1, i.e. Xi != r +// A = A if c == 0, i.e. Xi == r +// i.e. A = c + (1-2c) * A +// +// Online Communication: +// = 1 (xor open) + +// Disassemble BShr to AShr bit-by-bit +// Input: BShr +// Return: a vector of k AShr, k is the valid bits of BShr +std::vector B2A_Disassemble::proc(KernelEvalContext* ctx, + const NdArrayRef& x) const { + const auto field = x.eltype().as()->field(); + auto* comm = ctx->getState(); + auto* beaver = ctx->getState()->beaver(); + + const int64_t nbits = x.eltype().as()->nbits(); + SPU_ENFORCE((size_t)nbits > 0 && (size_t)nbits <= SizeOf(field) * 8, + "invalid nbits={}", nbits); + + const auto numel = x.numel(); + const auto rand_numel = numel * static_cast(nbits); + const PtType backtype = getBacktype(nbits); + + auto randbits = beaver->RandBit(field, rand_numel); + + std::vector res; + res.reserve(nbits); + for (int64_t idx = 0; idx < nbits; ++idx) { + res.emplace_back(makeType(field), x.shape()); + } + DISPATCH_ALL_FIELDS(field, "_", [&]() { + using U = ring2k_t; + + absl::Span _randbits(randbits.data(), rand_numel); + NdArrayView _x(x); + + DISPATCH_UINT_PT_TYPES(backtype, "_", [&]() { + using V = ScalarT; + std::vector x_xor_r(numel); + + pforeach(0, numel, [&](int64_t idx) { + // use _r[i*nbits, (i+1)*nbits) to construct rb[i] + V mask = 0; + for (int64_t bit = 0; bit < nbits; ++bit) { + mask += (static_cast(_randbits[idx * nbits + bit]) & 0x1) << bit; } - } + x_xor_r[idx] = _x[idx] ^ mask; + }); + + // open c = x ^ r + x_xor_r = comm->allReduce(x_xor_r, "open(x^r)"); + + pforeach(0, numel, [&](int64_t idx) { + pforeach(0, nbits, [&](int64_t bit) { + NdArrayView _res(res[bit]); + auto c_i = static_cast(x_xor_r[idx] >> bit) & 0x1; + if (comm->getRank() == 0) { + _res[idx] = (c_i + (1 - c_i * 2) * _randbits[idx * nbits + bit]); + } else { + _res[idx] = ((1 - c_i * 2) * _randbits[idx * nbits + bit]); + } + }); + }); }); }); diff --git a/libspu/mpc/semi2k/conversion.h b/libspu/mpc/semi2k/conversion.h index bc249998..891a23cd 100644 --- a/libspu/mpc/semi2k/conversion.h +++ b/libspu/mpc/semi2k/conversion.h @@ -73,6 +73,21 @@ class B2A_Randbit : public UnaryKernel { NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& x) const override; }; +class B2A_Disassemble : public DisassembleKernel { + public: + static constexpr char kBindName[] = "b2a_disassemble"; + + ce::CExpr latency() const override { return ce::Const(1); } + + ce::CExpr comm() const override { + return ce::K() * (ce::N() - 1) // Open bit masked value + ; + } + + std::vector proc(KernelEvalContext* ctx, + const NdArrayRef& x) const override; +}; + // Note: current only for 2PC. class MsbA2B : public UnaryKernel { public: diff --git a/libspu/mpc/semi2k/protocol.cc b/libspu/mpc/semi2k/protocol.cc index a7d3b7e1..3acd8344 100644 --- a/libspu/mpc/semi2k/protocol.cc +++ b/libspu/mpc/semi2k/protocol.cc @@ -50,21 +50,23 @@ void regSemi2kProtocol(SPUContext* ctx, ctx->prot()->addState(ctx->config(), lctx); ctx->prot() ->regKernel< - semi2k::P2A, semi2k::A2P, semi2k::A2V, semi2k::V2A, // - semi2k::NotA, // - semi2k::AddAP, semi2k::AddAA, // - semi2k::MulAP, semi2k::MulAA, semi2k::SquareA, // - semi2k::MatMulAP, semi2k::MatMulAA, // - semi2k::LShiftA, semi2k::LShiftB, semi2k::RShiftB, - semi2k::ARShiftB, // - semi2k::CommonTypeB, semi2k::CommonTypeV, semi2k::CastTypeB, // - semi2k::B2P, semi2k::P2B, semi2k::A2B, semi2k::B2A_Randbit, // - semi2k::AndBP, semi2k::AndBB, semi2k::XorBP, semi2k::XorBB, - semi2k::BitrevB, // - semi2k::BitIntlB, semi2k::BitDeintlB, // - semi2k::RandA, semi2k::RandPermM, semi2k::PermAM, semi2k::PermAP, - semi2k::InvPermAM, semi2k::InvPermAP, semi2k::InvPermAV, // - semi2k::EqualAA, semi2k::EqualAP, semi2k::BeaverCacheKernel>(); + semi2k::P2A, semi2k::A2P, semi2k::A2V, semi2k::V2A, // + semi2k::NotA, // + semi2k::AddAP, semi2k::AddAA, // + semi2k::MulAP, semi2k::MulAA, semi2k::SquareA, // + semi2k::MatMulAP, semi2k::MatMulAA, // + semi2k::LShiftA, semi2k::LShiftB, semi2k::RShiftB, // + semi2k::ARShiftB, // + semi2k::CommonTypeB, semi2k::CommonTypeV, semi2k::CastTypeB, // + semi2k::B2P, semi2k::P2B, // + semi2k::A2B, semi2k::B2A_Randbit, semi2k::B2A_Disassemble, // + semi2k::AndBP, semi2k::AndBB, semi2k::XorBP, semi2k::XorBB, // + semi2k::BitrevB, // + semi2k::BitIntlB, semi2k::BitDeintlB, // + semi2k::RandA, semi2k::RandPermM, semi2k::PermAM, semi2k::PermAP, // + semi2k::InvPermAM, semi2k::InvPermAP, semi2k::InvPermAV, // + semi2k::EqualAA, semi2k::EqualAP, // + semi2k::BeaverCacheKernel>(); if (ctx->config().trunc_allow_msb_error()) { ctx->prot()->regKernel(); diff --git a/requirements.txt b/requirements.txt index bea5a7d1..ac8fd6cc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,8 @@ grpcio>=1.42.0,!=1.48.0 -numpy>=1.22.0, < 2 +numpy>=1.22.0 protobuf>=4, <5 cloudpickle>=2.0.0 multiprocess>=0.70.12.2 cachetools>=5.0.0 -jax[cpu]>=0.4.16, <=0.4.26 # FIXME +jax[cpu]>=0.4.16, <=0.4.26 # FIXME: Jax 0.4.26+ select perf issue termcolor>=2.0.0 diff --git a/spu/tests/BUILD.bazel b/spu/tests/BUILD.bazel index d7453e60..cfd029b8 100644 --- a/spu/tests/BUILD.bazel +++ b/spu/tests/BUILD.bazel @@ -87,6 +87,7 @@ py_test( py_test( name = "jnp_cheetah_r64_test", + size = "enormous", timeout = "long", srcs = ["jnp_cheetah_r64_test.py"], deps = [ @@ -96,7 +97,7 @@ py_test( py_test( name = "jnp_cheetah_r64_test_x64", - timeout = "long", + size = "enormous", srcs = ["jnp_cheetah_r64_test.py"], env = { "ENABLE_X64_TEST": "1", diff --git a/spu/tests/jnp_cheetah_r64_test.py b/spu/tests/jnp_cheetah_r64_test.py index b113dbcd..10c02a53 100644 --- a/spu/tests/jnp_cheetah_r64_test.py +++ b/spu/tests/jnp_cheetah_r64_test.py @@ -22,8 +22,7 @@ from spu.tests.jnp_testbase import JnpTests -@unittest.skip("too slow, last run succeed") -class JnpTestAby3FM64(JnpTests.JnpTestBase): +class JnpTestCheetahFM64(JnpTests.JnpTestBase): def setUp(self): self._sim = ppsim.Simulator.simple( 2, spu_pb2.ProtocolKind.CHEETAH, spu_pb2.FieldType.FM64 diff --git a/spu/utils/distributed_impl.py b/spu/utils/distributed_impl.py index 79b13426..ebd6a703 100644 --- a/spu/utils/distributed_impl.py +++ b/spu/utils/distributed_impl.py @@ -723,7 +723,10 @@ def mock_parameters(obj: Union[SPU.Object, np.ndarray]): fn_name = repr(fn) - import jax.extend.linear_util as lu + try: + import jax.extend.linear_util as lu + except ImportError: + import jax.linear_util as lu # fallback from jax._src import api_util as japi_util from jax.tree_util import tree_map, tree_flatten diff --git a/spu/utils/frontend.py b/spu/utils/frontend.py index e1a4c24c..fc20cd0c 100644 --- a/spu/utils/frontend.py +++ b/spu/utils/frontend.py @@ -115,13 +115,32 @@ def _jax_compilation( register_backend_factory('interpreter', xla_back, priority=-100) - fn, kwargs = _argnames_partial_except(fn, static_argnames, kwargs) + jax_version = jax.__version_info__ + + if jax_version[0] > 1 or jax_version[1] > 4 or jax_version[2] > 29: + # xla_computation is deprecated since 0.4.30, move to new api + lowered = ( + jax.jit( + fn, + static_argnums=static_argnums, + static_argnames=static_argnames, + keep_unused=True, + ) + .trace(*args, **kwargs) + .lower(lowering_platforms=('interpreter',)) + ) + return ( + lowered.compiler_ir('hlo').as_serialized_hlo_module_proto(), + lowered.out_info, + ) + else: + fn, kwargs = _argnames_partial_except(fn, static_argnames, kwargs) - cfn, output = jax.xla_computation( - fn, return_shape=True, static_argnums=static_argnums, backend="interpreter" - )(*args, **kwargs) + cfn, output = jax.xla_computation( + fn, return_shape=True, static_argnums=static_argnums, backend="interpreter" + )(*args, **kwargs) - return cfn.as_serialized_hlo_module_proto(), output + return cfn.as_serialized_hlo_module_proto(), output ## Frontend patches diff --git a/spu/utils/simulation.py b/spu/utils/simulation.py index 8df6185c..592ccaff 100644 --- a/spu/utils/simulation.py +++ b/spu/utils/simulation.py @@ -17,7 +17,11 @@ from typing import Callable import jax -import jax.extend.linear_util as jax_lu # Moved in jax 0.4.16 + +try: + import jax.extend.linear_util as jax_lu +except ImportError: + import jax.linear_util as jax_lu # fallback import jax.numpy as jnp import numpy as np from jax._src import api_util as japi_util From 279fc79bff1ca2f8768f2683417e1c7d1556ffaf Mon Sep 17 00:00:00 2001 From: "renovate[bot]" <29139614+renovate[bot]@users.noreply.github.com> Date: Mon, 8 Jul 2024 15:02:23 +0800 Subject: [PATCH 03/27] chore(deps): update dependency rstcheck to v6.2.4 (#756) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit [![Mend Renovate](https://app.renovatebot.com/images/banner.svg)](https://renovatebot.com) This PR contains the following updates: | Package | Change | Age | Adoption | Passing | Confidence | |---|---|---|---|---|---| | [rstcheck](https://togithub.com/rstcheck/rstcheck) ([changelog](https://togithub.com/rstcheck/rstcheck/blob/main/CHANGELOG.md)) | `==6.2.1` -> `==6.2.4` | [![age](https://developer.mend.io/api/mc/badges/age/pypi/rstcheck/6.2.4?slim=true)](https://docs.renovatebot.com/merge-confidence/) | [![adoption](https://developer.mend.io/api/mc/badges/adoption/pypi/rstcheck/6.2.4?slim=true)](https://docs.renovatebot.com/merge-confidence/) | [![passing](https://developer.mend.io/api/mc/badges/compatibility/pypi/rstcheck/6.2.1/6.2.4?slim=true)](https://docs.renovatebot.com/merge-confidence/) | [![confidence](https://developer.mend.io/api/mc/badges/confidence/pypi/rstcheck/6.2.1/6.2.4?slim=true)](https://docs.renovatebot.com/merge-confidence/) | --- ### Release Notes
rstcheck/rstcheck (rstcheck) ### [`v6.2.4`](https://togithub.com/rstcheck/rstcheck/blob/HEAD/CHANGELOG.md#v624-2024-07-07) [Compare Source](https://togithub.com/rstcheck/rstcheck/compare/v6.2.3...v6.2.4) [diff v6.2.3...v6.2.4](https://togithub.com/rstcheck/rstcheck/compare/v6.2.3...v6.2.4) ##### Documentation - Add note on how to disable pretty exception output ([#​228](https://togithub.com/rstcheck/rstcheck/pull/228)) ##### Miscellaneous - Add help text to `--version` flag ([#​228](https://togithub.com/rstcheck/rstcheck/pull/228)) ### [`v6.2.3`](https://togithub.com/rstcheck/rstcheck/blob/HEAD/CHANGELOG.md#v623-2024-07-07) [Compare Source](https://togithub.com/rstcheck/rstcheck/compare/v6.2.2...v6.2.3) [diff v6.2.2...v6.2.3](https://togithub.com/rstcheck/rstcheck/compare/v6.2.2...v6.2.3) ##### Bugfixes - Fix typer dependency by removing the `[standard]` extra which is only used on typer-slim. Typer by default has the extras included. ### [`v6.2.2`](https://togithub.com/rstcheck/rstcheck/blob/HEAD/CHANGELOG.md#v622-2024-07-07) [Compare Source](https://togithub.com/rstcheck/rstcheck/compare/v6.2.1...v6.2.2) [diff v6.2.1...v6.2.2](https://togithub.com/rstcheck/rstcheck/compare/v6.2.1...v6.2.2) ##### Miscellaneous - Bump min. version of typer and fix dependency group name ([#​223](https://togithub.com/rstcheck/rstcheck/issues/223)) - Update configs for dev tooling ([#​225](https://togithub.com/rstcheck/rstcheck/pull/225)) - Bump default python version to 3.12 ([#​225](https://togithub.com/rstcheck/rstcheck/pull/225))
--- ### Configuration 📅 **Schedule**: Branch creation - At any time (no schedule defined), Automerge - At any time (no schedule defined). 🚦 **Automerge**: Disabled by config. Please merge this manually once you are satisfied. ♻ **Rebasing**: Whenever PR becomes conflicted, or you tick the rebase/retry checkbox. 🔕 **Ignore**: Close this PR and you won't be reminded about this update again. --- - [ ] If you want to rebase/retry this PR, check this box --- This PR has been generated by [Mend Renovate](https://www.mend.io/free-developer-tools/renovate/). View repository job log [here](https://developer.mend.io/github/secretflow/spu). Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com> --- docs/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/requirements.txt b/docs/requirements.txt index 0c9b6330..9b6bb997 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,5 +1,5 @@ myst-parser==3.0.1 -rstcheck==6.2.1 +rstcheck==6.2.4 sphinx==7.3.7 nbsphinx==0.9.4 sphinx-autobuild==2024.4.16 From 4ecd05bdcfe6e530c8dce4c84869311ae38df584 Mon Sep 17 00:00:00 2001 From: "renovate[bot]" <29139614+renovate[bot]@users.noreply.github.com> Date: Mon, 8 Jul 2024 15:02:49 +0800 Subject: [PATCH 04/27] chore(deps): update github/codeql-action action to v3.25.11 (#749) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit [![Mend Renovate](https://app.renovatebot.com/images/banner.svg)](https://renovatebot.com) This PR contains the following updates: | Package | Type | Update | Change | |---|---|---|---| | [github/codeql-action](https://togithub.com/github/codeql-action) | action | patch | `v3.25.10` -> `v3.25.11` | --- ### Release Notes
github/codeql-action (github/codeql-action) ### [`v3.25.11`](https://togithub.com/github/codeql-action/compare/v3.25.10...v3.25.11) [Compare Source](https://togithub.com/github/codeql-action/compare/v3.25.10...v3.25.11)
--- ### Configuration 📅 **Schedule**: Branch creation - At any time (no schedule defined), Automerge - At any time (no schedule defined). 🚦 **Automerge**: Disabled by config. Please merge this manually once you are satisfied. ♻ **Rebasing**: Whenever PR becomes conflicted, or you tick the rebase/retry checkbox. 🔕 **Ignore**: Close this PR and you won't be reminded about this update again. --- - [ ] If you want to rebase/retry this PR, check this box --- This PR has been generated by [Mend Renovate](https://www.mend.io/free-developer-tools/renovate/). View repository job log [here](https://developer.mend.io/github/secretflow/spu). Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com> --- .github/workflows/scorecard.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/scorecard.yml b/.github/workflows/scorecard.yml index 68978a19..3b6e8be7 100644 --- a/.github/workflows/scorecard.yml +++ b/.github/workflows/scorecard.yml @@ -67,6 +67,6 @@ jobs: # Upload the results to GitHub's code scanning dashboard. - name: "Upload to code-scanning" - uses: github/codeql-action/upload-sarif@23acc5c183826b7a8a97bce3cecc52db901f8251 # v3.25.10 + uses: github/codeql-action/upload-sarif@b611370bb5703a7efb587f9d136a52ea24c5c38c # v3.25.11 with: sarif_file: results.sarif From 54c0885f9a790c13ba76aee7ba8d3a8b99a28300 Mon Sep 17 00:00:00 2001 From: Yancheng Zheng <103552181+anakinxc@users.noreply.github.com> Date: Mon, 8 Jul 2024 15:27:57 +0800 Subject: [PATCH 05/27] Repo sync (#759) --- libspu/compiler/tests/interpret/power.mlir | 4 ++-- libspu/compiler/tests/interpret/test_json/power.json | 2 +- spu/tests/BUILD.bazel | 1 - 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/libspu/compiler/tests/interpret/power.mlir b/libspu/compiler/tests/interpret/power.mlir index dd950ef1..307fce4c 100644 --- a/libspu/compiler/tests/interpret/power.mlir +++ b/libspu/compiler/tests/interpret/power.mlir @@ -58,7 +58,7 @@ func.func @power_op_test_f64_f64_pp() { %1 = pphlo.constant dense<[2.0, 2.0, 2.0, -1.0, 1.0]> : tensor<5xf64> %2 = pphlo.power %0,%1 : (tensor<5xf64>,tensor<5xf64>)->tensor<5xf64> %3 = pphlo.constant dense<[4.000000e+00, 0.000000e+00, 2.500000e+01, 0.33333333333333331, 10000.0]> : tensor<5xf64> - pphlo.custom_call @expect_almost_eq(%2, %3) { tol = 0.5 }: (tensor<5xf64>, tensor<5xf64>)->() + pphlo.custom_call @expect_almost_eq(%2, %3) { tol = 0.6 }: (tensor<5xf64>, tensor<5xf64>)->() func.return } @@ -72,6 +72,6 @@ func.func @power_op_test_f64_f64_ss() { %4 = pphlo.power %2, %3 : (tensor<5x!pphlo.secret>,tensor<5x!pphlo.secret>)->tensor<5x!pphlo.secret> %5 = pphlo.constant dense<[4.000000e+00, 0.000000e+00, 2.500000e+01, 0.33333333333333331, 10000.0]> : tensor<5xf64> %6 = pphlo.convert %4 : (tensor<5x!pphlo.secret>)->tensor<5xf64> - pphlo.custom_call @expect_almost_eq(%5, %6) { tol = 0.5 }: (tensor<5xf64>, tensor<5xf64>)->() + pphlo.custom_call @expect_almost_eq(%5, %6) { tol = 0.6 }: (tensor<5xf64>, tensor<5xf64>)->() func.return } diff --git a/libspu/compiler/tests/interpret/test_json/power.json b/libspu/compiler/tests/interpret/test_json/power.json index e5e6f3cf..a95ae33c 100644 --- a/libspu/compiler/tests/interpret/test_json/power.json +++ b/libspu/compiler/tests/interpret/test_json/power.json @@ -65,7 +65,7 @@ } ], "checker": "expect_almost_eq", - "tol": 0.5 + "tol": 0.6 } ] } \ No newline at end of file diff --git a/spu/tests/BUILD.bazel b/spu/tests/BUILD.bazel index cfd029b8..fd037279 100644 --- a/spu/tests/BUILD.bazel +++ b/spu/tests/BUILD.bazel @@ -88,7 +88,6 @@ py_test( py_test( name = "jnp_cheetah_r64_test", size = "enormous", - timeout = "long", srcs = ["jnp_cheetah_r64_test.py"], deps = [ ":jnp_testbase", From ddb1acbe8938be2328d8db9fe93a1bc84c093410 Mon Sep 17 00:00:00 2001 From: Yancheng Zheng <103552181+anakinxc@users.noreply.github.com> Date: Mon, 8 Jul 2024 15:37:59 +0800 Subject: [PATCH 06/27] Update acknowledgement (#760) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 56375409..8fad8f9c 100644 --- a/README.md +++ b/README.md @@ -71,4 +71,4 @@ If you think SPU is helpful for your research or development, please consider ci ## Acknowledgement -We thank the significant contributions made by [Alibaba Gemini Lab](https://alibaba-gemini-lab.github.io). +We thank the significant contributions made by [Alibaba Gemini Lab](https://alibaba-gemini-lab.github.io) and security advisories made by [NISL@THU](https://netsec.ccert.edu.cn/vul337). From 1e5aca15cf1e680b7c8630142bcea3d05ce02ae1 Mon Sep 17 00:00:00 2001 From: Yancheng Zheng <103552181+anakinxc@users.noreply.github.com> Date: Mon, 8 Jul 2024 15:40:54 +0800 Subject: [PATCH 07/27] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 8fad8f9c..6a1fe7e3 100644 --- a/README.md +++ b/README.md @@ -71,4 +71,4 @@ If you think SPU is helpful for your research or development, please consider ci ## Acknowledgement -We thank the significant contributions made by [Alibaba Gemini Lab](https://alibaba-gemini-lab.github.io) and security advisories made by [NISL@THU](https://netsec.ccert.edu.cn/vul337). +We thank the significant contributions made by [Alibaba Gemini Lab](https://alibaba-gemini-lab.github.io) and security advisories made by [VUL337@NISL@THU](https://netsec.ccert.edu.cn/vul337). From 630297dec8412c6404196be00f25b652b141c004 Mon Sep 17 00:00:00 2001 From: Yancheng Zheng <103552181+anakinxc@users.noreply.github.com> Date: Thu, 11 Jul 2024 10:15:40 +0800 Subject: [PATCH 08/27] Repo sync (#762) --- libspu/mpc/ab_api.cc | 4 ++-- libspu/mpc/cheetah/arith/matmat_prot.cc | 2 +- libspu/mpc/cheetah/rlwe/packlwes.cc | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/libspu/mpc/ab_api.cc b/libspu/mpc/ab_api.cc index a52de586..7d02fe47 100644 --- a/libspu/mpc/ab_api.cc +++ b/libspu/mpc/ab_api.cc @@ -252,7 +252,7 @@ Value bitintl_b(SPUContext* ctx, const Value& x, size_t stride) { idx--) { auto K = hack_make_p(ctx, spu::detail::kBitIntlKeepMasks[idx], x.shape()); auto M = hack_make_p(ctx, spu::detail::kBitIntlSwapMasks[idx], x.shape()); - int64_t S = 1 << idx; + int64_t S = static_cast(1) << idx; // out = (out & K) ^ ((out >> S) & M) ^ ((out & M) << S); out = xor_bb( ctx, @@ -281,7 +281,7 @@ Value bitdeintl_b(SPUContext* ctx, const Value& x, size_t stride) { for (int64_t idx = stride; idx + 1 < Log2Ceil(nbits); idx++) { auto K = hack_make_p(ctx, spu::detail::kBitIntlKeepMasks[idx], x.shape()); auto M = hack_make_p(ctx, spu::detail::kBitIntlSwapMasks[idx], x.shape()); - int64_t S = 1 << idx; + int64_t S = static_cast(1) << idx; // out = (out & K) ^ ((out >> S) & M) ^ ((out & M) << S); out = xor_bb( ctx, diff --git a/libspu/mpc/cheetah/arith/matmat_prot.cc b/libspu/mpc/cheetah/arith/matmat_prot.cc index a56e522f..ebb80f9f 100644 --- a/libspu/mpc/cheetah/arith/matmat_prot.cc +++ b/libspu/mpc/cheetah/arith/matmat_prot.cc @@ -183,7 +183,7 @@ Shape3D MatMatProtocol::GetSubMatShape(const Meta& meta, int64_t poly_deg, const double cpu_price = 1.0; const double bandwidth_price = 1000.0; - Shape3D subshape; + Shape3D subshape = {0, 0, 0}; Shape3D blk; const int64_t n = poly_deg; diff --git a/libspu/mpc/cheetah/rlwe/packlwes.cc b/libspu/mpc/cheetah/rlwe/packlwes.cc index 59f99e4c..f2c7272f 100644 --- a/libspu/mpc/cheetah/rlwe/packlwes.cc +++ b/libspu/mpc/cheetah/rlwe/packlwes.cc @@ -134,7 +134,7 @@ void PackingHelper::doPackingRLWEs(absl::Span rlwes, seal::Evaluator evaluator(context_); const int64_t logn = absl::bit_width(gap_) - 1; for (int64_t k = logn; k >= 1; --k) { - int64_t h = 1 << (k - 1); + int64_t h = static_cast(1) << (k - 1); yacl::parallel_for(0, h, [&](int64_t bgn, int64_t end) { RLWECt dummy; // zero-padding with zero RLWE for (int64_t i = bgn; i < end; ++i) { @@ -188,7 +188,7 @@ void GenerateGaloisKeyForPacking(const seal::SEALContext &context, size_t logN = absl::bit_width(N) - 1; std::vector galois_elt; for (uint32_t i = 1; i <= logN; i++) { - galois_elt.push_back((1u << i) + 1); + galois_elt.push_back((static_cast(1) << i) + 1); } seal::KeyGenerator keygen(context, key); From fcef2ff4ae58d15e4e2df48369b4cd8577a01552 Mon Sep 17 00:00:00 2001 From: Yancheng Zheng <103552181+anakinxc@users.noreply.github.com> Date: Thu, 11 Jul 2024 10:21:32 +0800 Subject: [PATCH 09/27] Update repositories.bzl --- bazel/repositories.bzl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bazel/repositories.bzl b/bazel/repositories.bzl index 3fe35ee0..59ed839d 100644 --- a/bazel/repositories.bzl +++ b/bazel/repositories.bzl @@ -229,7 +229,7 @@ def _com_github_microsoft_seal(): maybe( http_archive, name = "com_github_microsoft_seal", - sha256 = "78ef7334114de930daf7659e8ba60c5abfff85c86ec2b827a2b7c67c3c42da43", + sha256 = "acc2a1a127a85d1e1ffcca3ffd148f736e665df6d6b072df0e42fff64795a13c", strip_prefix = "SEAL-4.1.2", type = "tar.gz", patch_args = ["-p1"], From 5b42950f3d6d771cf719395e5e572e1f96a31540 Mon Sep 17 00:00:00 2001 From: Yancheng Zheng <103552181+anakinxc@users.noreply.github.com> Date: Mon, 15 Jul 2024 13:34:00 +0800 Subject: [PATCH 10/27] Repo sync (#768) # Pull Request ## What problem does this PR solve? Issue Number: Fixed #766 ## Possible side effects? - Performance: - Backward compatibility: --- libspu/compiler/core/core.cc | 1 + libspu/device/pphlo/pphlo_executor.cc | 11 ++ libspu/device/pphlo/pphlo_verifier.h | 1 + libspu/dialect/pphlo/IR/ops.td | 6 + libspu/dialect/pphlo/transforms/passes.h | 3 + libspu/dialect/pphlo/transforms/passes.td | 6 + .../pphlo/transforms/region_access_fixture.cc | 141 ++++++++++++++++++ 7 files changed, 169 insertions(+) create mode 100644 libspu/dialect/pphlo/transforms/region_access_fixture.cc diff --git a/libspu/compiler/core/core.cc b/libspu/compiler/core/core.cc index 355dd68a..e977050c 100644 --- a/libspu/compiler/core/core.cc +++ b/libspu/compiler/core/core.cc @@ -95,6 +95,7 @@ void Core::buildPipeline(mlir::PassManager *pm) { } optPM.addPass(mlir::createLoopInvariantCodeMotionPass()); + optPM.addPass(mlir::spu::pphlo::createRegionAccessFixture()); optPM.addPass(mlir::createCSEPass()); if (!options.disable_deallocation_insertion()) { diff --git a/libspu/device/pphlo/pphlo_executor.cc b/libspu/device/pphlo/pphlo_executor.cc index d20d54f9..af8a04b8 100644 --- a/libspu/device/pphlo/pphlo_executor.cc +++ b/libspu/device/pphlo/pphlo_executor.cc @@ -720,6 +720,17 @@ void execute(OpExecutor *, SPUContext *sctx, SymbolScope *sscope, addValue(sscope, op.getOutput(), std::move(iota_ret), opts); } +void execute(OpExecutor *, SPUContext *sctx, SymbolScope *sscope, + mlir::spu::pphlo::BroadcastShapeAsOp &op, + const ExecutionOptions &opts) { + // Start indices + const auto &lhs = lookupValue(sscope, op.getLhs(), opts); + const auto &rhs = lookupValue(sscope, op.getRhs(), opts); + + addValue(sscope, op.getResult(), + kernel::hlo::Broadcast(sctx, lhs, rhs.shape(), {}), opts); +} + void execute(OpExecutor *, SPUContext *sctx, SymbolScope *sscope, mlir::spu::pphlo::RemOp &op, const ExecutionOptions &opts) { // FIXME: When hal has a remainder, use that diff --git a/libspu/device/pphlo/pphlo_verifier.h b/libspu/device/pphlo/pphlo_verifier.h index 428883c8..478a3874 100644 --- a/libspu/device/pphlo/pphlo_verifier.h +++ b/libspu/device/pphlo/pphlo_verifier.h @@ -149,6 +149,7 @@ class PPHloVerifier { NO_VERIFY_DEFN(ImagOp) NO_VERIFY_DEFN(ComplexOp) NO_VERIFY_DEFN(SimpleSortOp) + NO_VERIFY_DEFN(BroadcastShapeAsOp) #undef NO_VERIFY_DEFN }; diff --git a/libspu/dialect/pphlo/IR/ops.td b/libspu/dialect/pphlo/IR/ops.td index b9f6b0c2..205dcf75 100644 --- a/libspu/dialect/pphlo/IR/ops.td +++ b/libspu/dialect/pphlo/IR/ops.td @@ -619,6 +619,12 @@ def PPHLO_BroadcastOp }]; } + +def PPHLO_BroadcastShapeAsOp: PPHLO_BinaryElementwiseOp<"broadcast_as", [Pure, + SameOperandsAndResultShape], PPHLO_Tensor> { + let summary = "BroadcastShapeAs operator"; +} + def PPHLO_ClampOp : PPHLO_Op<"clamp", [Pure, SameOperandsAndResultShape]> { let summary = "Clamp operator"; diff --git a/libspu/dialect/pphlo/transforms/passes.h b/libspu/dialect/pphlo/transforms/passes.h index b70ce1ab..9bcc2660 100644 --- a/libspu/dialect/pphlo/transforms/passes.h +++ b/libspu/dialect/pphlo/transforms/passes.h @@ -82,6 +82,9 @@ std::unique_ptr> createInlineSecretControlFlow(); // Convert signbit pattern to SignOp std::unique_ptr> createRewriteSignbitPatterns(); +// Fix region access shape mismatch +std::unique_ptr> createRegionAccessFixture(); + } // namespace spu::pphlo } // namespace mlir diff --git a/libspu/dialect/pphlo/transforms/passes.td b/libspu/dialect/pphlo/transforms/passes.td index 4fc28fdc..a13c70b2 100644 --- a/libspu/dialect/pphlo/transforms/passes.td +++ b/libspu/dialect/pphlo/transforms/passes.td @@ -122,3 +122,9 @@ def InlineSecretControlFlow: Pass<"inline-secret-control-flow", "func::FuncOp"> let constructor = "createInlineSecretControlFlow()"; let dependentDialects = ["pphlo::PPHloDialect"]; } + +def RegionAccessFixture: Pass<"region-access-fixture", "func::FuncOp"> { + let summary = "Fix region access mismatched shape"; + let constructor = "createRegionAccessFixture()"; + let dependentDialects = ["pphlo::PPHloDialect"]; +} \ No newline at end of file diff --git a/libspu/dialect/pphlo/transforms/region_access_fixture.cc b/libspu/dialect/pphlo/transforms/region_access_fixture.cc new file mode 100644 index 00000000..3bd9d83a --- /dev/null +++ b/libspu/dialect/pphlo/transforms/region_access_fixture.cc @@ -0,0 +1,141 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "mlir/Pass/Pass.h" + +#include "libspu/dialect/pphlo/IR/ops.h" +#include "libspu/dialect/pphlo/transforms/pass_details.h" + +namespace mlir::spu::pphlo { + +namespace { + +struct Deallocator { + public: + LogicalResult transformOp(Operation *op) { + for (auto &r : op->getRegions()) { + if (failed(transformRegion(r))) { + return failure(); + } + } + + const auto &operands = op->getOperands(); + + if (op->getNumOperands() < 2 || + !op->hasTrait<::mlir::OpTrait::Elementwise>() || + std::all_of(operands.begin(), operands.end(), [](const auto &operand) { + return operand.template getDefiningOp(); + })) { + return success(); + } + + auto *op_region = op->getParentRegion(); + + Value base_val; + llvm::SmallVector values_to_update; + + OpBuilder builder(op->getContext()); + builder.setInsertionPoint(op); + + for (const auto &[idx, operand] : llvm::enumerate(operands)) { + // Get defining region + auto *defining_op = operand.getDefiningOp(); + + Region *defining_region = nullptr; + + if (defining_op != nullptr) { + defining_region = defining_op->getParentRegion(); + } + + if (defining_op == nullptr || defining_region == op_region) { + // BlockArg or op defined in current region can be a base val + base_val = operand; + continue; + } + + if (defining_region != op_region) { + // This op is accessing a variable out of op's region. + // Insert a broadcast as to fix runtime shape mismatch during simd + // region execution + values_to_update.emplace_back(idx); + } + } + + if (!base_val) { + return values_to_update.empty() + ? failure() // same region however failed to pick base value + : success(); // can't pick base value since multi-level + // nesting + } + + for (const auto &idx : values_to_update) { + auto op_to_broadcast = op->getOperand(idx); + auto b = builder.create( + op->getLoc(), op_to_broadcast.getType(), op_to_broadcast, base_val); + op->setOperand(idx, b); + } + + return success(); + } + + LogicalResult transformBlock(Block &block) { + for (auto &op : llvm::make_early_inc_range(block.without_terminator())) { + auto opResult = transformOp(&op); + if (failed(opResult)) { + return failure(); + } + } + return success(); + } + + LogicalResult transformRegion(Region &r) { + for (auto &b : r.getBlocks()) { + if (failed(transformBlock(b))) { + return failure(); + } + } + return success(); + } + + LogicalResult transformFuncOp(func::FuncOp op) { + if (op->getNumRegions() == 0) { + return success(); + } + + // Transform function body. + if (failed(transformRegion(op.getBody()))) { + return failure(); + } + + return success(); + } +}; + +struct RegionAccessFixture + : public RegionAccessFixtureBase { + void runOnOperation() override { + if (failed(Deallocator().transformFuncOp(getOperation()))) { + signalPassFailure(); + } + } +}; +} // namespace + +std::unique_ptr> createRegionAccessFixture() { + return std::make_unique(); +} + +} // namespace mlir::spu::pphlo From 333f16cc76157c1f0aacdf33b08d4dc9b18304a8 Mon Sep 17 00:00:00 2001 From: "renovate[bot]" <29139614+renovate[bot]@users.noreply.github.com> Date: Tue, 16 Jul 2024 08:07:51 +0800 Subject: [PATCH 11/27] chore(deps): update dependency sphinx to v7.4.4 (#770) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit [![Mend Renovate](https://app.renovatebot.com/images/banner.svg)](https://renovatebot.com) This PR contains the following updates: | Package | Change | Age | Adoption | Passing | Confidence | |---|---|---|---|---|---| | [sphinx](https://togithub.com/sphinx-doc/sphinx) ([changelog](https://www.sphinx-doc.org/en/master/changes.html)) | `==7.3.7` -> `==7.4.4` | [![age](https://developer.mend.io/api/mc/badges/age/pypi/sphinx/7.4.4?slim=true)](https://docs.renovatebot.com/merge-confidence/) | [![adoption](https://developer.mend.io/api/mc/badges/adoption/pypi/sphinx/7.4.4?slim=true)](https://docs.renovatebot.com/merge-confidence/) | [![passing](https://developer.mend.io/api/mc/badges/compatibility/pypi/sphinx/7.3.7/7.4.4?slim=true)](https://docs.renovatebot.com/merge-confidence/) | [![confidence](https://developer.mend.io/api/mc/badges/confidence/pypi/sphinx/7.3.7/7.4.4?slim=true)](https://docs.renovatebot.com/merge-confidence/) | --- ### Release Notes
sphinx-doc/sphinx (sphinx) ### [`v7.4.4`](https://togithub.com/sphinx-doc/sphinx/blob/HEAD/CHANGES.rst#Release-744-released-Jul-15-2024) [Compare Source](https://togithub.com/sphinx-doc/sphinx/compare/v7.4.3...v7.4.4) \===================================== ## Bugs fixed - [#​12585](https://togithub.com/sphinx-doc/sphinx/issues/12585), [#​12586](https://togithub.com/sphinx-doc/sphinx/issues/12586): Do not warn when an intersphinx inventory contains case-insensitively ambiguous duplicate items. Patch by James Addison. ### [`v7.4.3`](https://togithub.com/sphinx-doc/sphinx/blob/HEAD/CHANGES.rst#Release-743-released-Jul-15-2024) [Compare Source](https://togithub.com/sphinx-doc/sphinx/compare/v7.4.2...v7.4.3) \===================================== ## Bugs fixed - [#​12582](https://togithub.com/sphinx-doc/sphinx/issues/12582): Restore support for list-styled :confval:`source_suffix` values with extensions that register parsers. Patch by Adam Turner. ### [`v7.4.2`](https://togithub.com/sphinx-doc/sphinx/blob/HEAD/CHANGES.rst#Release-742-released-Jul-15-2024) [Compare Source](https://togithub.com/sphinx-doc/sphinx/compare/v7.4.1...v7.4.2) \===================================== ## Bugs fixed - [#​12580](https://togithub.com/sphinx-doc/sphinx/issues/12580), [#​12583](https://togithub.com/sphinx-doc/sphinx/issues/12583): Resolve failures with the C domain on incremental builds with Sphinx 7.3.7 and earlier. Patch by Adam Turner. ### [`v7.4.1`](https://togithub.com/sphinx-doc/sphinx/blob/HEAD/CHANGES.rst#Release-741-in-development) [Compare Source](https://togithub.com/sphinx-doc/sphinx/compare/v7.4.0...v7.4.1) \============================== ## Dependencies ## Incompatible changes ## Deprecated ## Features added ## Bugs fixed - Fix invalid HTML when a rubric node with invalid `heading-level` is used. Patch by Adam Turner. - [#​12579](https://togithub.com/sphinx-doc/sphinx/issues/12579), [#​12581](https://togithub.com/sphinx-doc/sphinx/issues/12581): Restore support for `typing.ParamSpec` in autodoc. Patch by Adam Turner. ## Testing ### [`v7.4.0`](https://togithub.com/sphinx-doc/sphinx/blob/HEAD/CHANGES.rst#Release-740-released-Jul-15-2024) [Compare Source](https://togithub.com/sphinx-doc/sphinx/compare/v7.3.7...v7.4.0) \===================================== ## Dependencies - [#​12555](https://togithub.com/sphinx-doc/sphinx/issues/12555): Drop Docutils 0.18.1 and Docutils 0.19 support. Patch by Adam Turner. - LaTeX: the `xcolor` package is now required (but is for example part of Ubuntu `texlive-latex-recommended` which has always been required). - LaTeX: the `fontawesome5` LaTeX package is needed for the default choices of icons now used in admonition titles in PDF output; but if unavailable the PDF build will simply silently omit rendering such icons. Check the documentation of the `iconpackage` key of :ref:`'sphinxsetup' ` for more. ## Deprecated - LaTeX: the `sphinxlightbox` environment is not used anymore, all types of admonitions use (by default) only `sphinxheavybox`. ## Features added .. rst-class:: compact - [#​11165](https://togithub.com/sphinx-doc/sphinx/issues/11165): Support the `officially recommended`\_ `.jinja` suffix for template files. Patch by James Addison and Adam Turner .. \_officially recommended: https://jinja.palletsprojects.com/en/latest/templates/#template-file-extension - [#​12325](https://togithub.com/sphinx-doc/sphinx/issues/12325): Flatten `Union[Literal[T], Literal[U], ...]` to `Literal[T, U, ...]` when turning annotations into strings. Patch by Adam Turner. - [#​12319](https://togithub.com/sphinx-doc/sphinx/issues/12319): `sphinx.ext.extlinks`: Add `extlink-{name}` CSS class to links. Patch by Hugo van Kemenade. - [#​12387](https://togithub.com/sphinx-doc/sphinx/issues/12387): Improve CLI progress message, when copying assets. Patch by INADA Nakoi and Bénédikt Tran. - [#​12361](https://togithub.com/sphinx-doc/sphinx/issues/12361): Add :attr:`.BuildEnvironment.parser`. Patch by Chris Sewell. - [#​12358](https://togithub.com/sphinx-doc/sphinx/issues/12358): Add :attr:`.Sphinx.fresh_env_used`. Patch by Chris Sewell. - [#​12329](https://togithub.com/sphinx-doc/sphinx/issues/12329): Add detection of ambiguous `std:label` and `std:term` references during loading and resolution of Intersphinx targets. Patch by James Addison. - [#​12422](https://togithub.com/sphinx-doc/sphinx/issues/12422): Do not duplicate "navigation" in aria-label of built-in themes. Patch by Thomas Weißschuh - [#​12421](https://togithub.com/sphinx-doc/sphinx/issues/12421): Include project name in `logo_alt` of built-in themes. Patch by Thomas Weißschuh - [#​12448](https://togithub.com/sphinx-doc/sphinx/issues/12448): Add :option:`sphinx-apidoc --remove-old` option. Patch by Chris Sewell. - [#​12456](https://togithub.com/sphinx-doc/sphinx/issues/12456): Add :option:`sphinx-autogen --remove-old` option. Patch by Chris Sewell. - [#​12479](https://togithub.com/sphinx-doc/sphinx/issues/12479): Add warning subtype `toc.no_title`. Patch by Ondřej Navrátil. - [#​12492](https://togithub.com/sphinx-doc/sphinx/issues/12492): Add helper methods for parsing reStructuredText content into nodes from within a directive. - :py:meth:`~sphinx.util.docutils.SphinxDirective.parse_content_to_nodes()` parses the directive's content and returns a list of Docutils nodes. - :py:meth:`~sphinx.util.docutils.SphinxDirective.parse_text_to_nodes()` parses the provided text and returns a list of Docutils nodes. - :py:meth:`~sphinx.util.docutils.SphinxDirective.parse_inline()` parses the provided text into inline elements and text nodes. Patch by Adam Turner. - [#​12258](https://togithub.com/sphinx-doc/sphinx/issues/12258): Support `typing_extensions.Unpack` Patch by Bénédikt Tran and Adam Turner. - [#​12524](https://togithub.com/sphinx-doc/sphinx/issues/12524): Add a `class` option to the :rst:dir:`toctree` directive. Patch by Tim Hoffmann. - [#​12536](https://togithub.com/sphinx-doc/sphinx/issues/12536): Add the :rst:dir:`confval` directive. Patch by Adam Turner. - [#​12537](https://togithub.com/sphinx-doc/sphinx/issues/12537): :confval:`c_id_attributes`, :confval:`c_paren_attributes`, :confval:`cpp_id_attributes`, and :confval:`cpp_paren_attributes` can now be a tuple of strings. :confval:`c_extra_keywords`, :confval:`gettext_additional_targets`, :confval:`html_domain_indices`, :confval:`latex_domain_indices`, and :confval:`texinfo_domain_indices`, can now be a set of strings. Patch by Adam Turner. - [#​12523](https://togithub.com/sphinx-doc/sphinx/issues/12523): Added configuration option, :confval:`math_numsep`, to define the separator for math numbering. Patch by Thomas Fanning - [#​11592](https://togithub.com/sphinx-doc/sphinx/issues/11592): Add :confval:`coverage_modules` to the coverage builder to allow explicitly specifying which modules should be documented. Patch by Stephen Finucane. - [#​7896](https://togithub.com/sphinx-doc/sphinx/issues/7896), [#​11989](https://togithub.com/sphinx-doc/sphinx/issues/11989): Add a :rst:dir:`py:type` directive for documenting type aliases, and a :rst:role:`py:type` role for linking to them. Patch by Ashley Whetter. - [#​12549](https://togithub.com/sphinx-doc/sphinx/issues/12549): Add optional `description` argument to :meth:`.Sphinx.add_config_value`. Patch by Chris Sewell. - [#​6792](https://togithub.com/sphinx-doc/sphinx/issues/6792): Prohibit module import cycles in :mod:`sphinx.ext.autosummary`. Patch by Trevor Bekolay. - [#​12508](https://togithub.com/sphinx-doc/sphinx/issues/12508): LaTeX: Revamped styling of all admonitions, with addition of a title row with icon. Patch by Jean-François B. - [#​11773](https://togithub.com/sphinx-doc/sphinx/issues/11773): Display :py:class:`~typing.Annotated` annotations with their metadata in the Python domain. Patch by Adam Turner and David Stansby. - [#​12506](https://togithub.com/sphinx-doc/sphinx/issues/12506): Add `level` option to :rst:dir:`rubric` directive. Patch by Chris Sewell. - [#​12567](https://togithub.com/sphinx-doc/sphinx/issues/12567): Add the :event:`write-started` event. Patch by Chris Sewell. ## Bugs fixed - [#​12314](https://togithub.com/sphinx-doc/sphinx/issues/12314): Properly format `collections.abc.Callable` in annotations. Patch by Adam Turner. - [#​12162](https://togithub.com/sphinx-doc/sphinx/issues/12162): Fix a performance regression in the C domain that has been present since version 3.0.0. Patch by Donald Hunter. - [#​12320](https://togithub.com/sphinx-doc/sphinx/issues/12320): Fix removal of anchors from search summaries (regression in 7.3.0). Patch by Will Lachance. - [#​12251](https://togithub.com/sphinx-doc/sphinx/issues/12251): Fix `merge_domaindata()` in `sphinx.ext.duration`. Patch by Matthias Geier. - [#​12224](https://togithub.com/sphinx-doc/sphinx/issues/12224): Properly detect WebP files. Patch by Benjamin Cabé. - [#​12380](https://togithub.com/sphinx-doc/sphinx/issues/12380): LaTeX: Footnote mark sometimes indicates `Page N` where `N` is the current page number and the footnote does appear on that same page. Patch by Jean-François B. - [#​12410](https://togithub.com/sphinx-doc/sphinx/issues/12410): LaTeX: for French and `'lualatex'` as :confval:`latex_engine` `polyglossia` and not `babel` is used (contrarily to `'xelatex'`). Patch by Jean-François B. - [#​12416](https://togithub.com/sphinx-doc/sphinx/issues/12416): Ensure that configuration setting aliases are always synchronised when one value or the other is modified. Patch by Bénédikt Tran. - [#​12220](https://togithub.com/sphinx-doc/sphinx/issues/12220): Fix loading custom template translations for `en` locale. Patch by Nicolas Peugnet. - [#​12459](https://togithub.com/sphinx-doc/sphinx/issues/12459): Add valid-type arguments to the `linkcheck_rate_limit_timeout` configuration setting. Patch by James Addison. - [#​12331](https://togithub.com/sphinx-doc/sphinx/issues/12331): Resolve data-URI-image-extraction regression from v7.3.0 affecting builders without native support for data-URIs in their output format. Patch by James Addison. - [#​12494](https://togithub.com/sphinx-doc/sphinx/issues/12494): Fix invalid genindex.html file produced with translated docs (regression in 7.1.0). Patch by Nicolas Peugnet. - [#​11961](https://togithub.com/sphinx-doc/sphinx/issues/11961): Omit anchor references from document title entries in the search index, removing duplication of search results. Patch by James Addison. - [#​12425](https://togithub.com/sphinx-doc/sphinx/issues/12425): Use Docutils' SVG processing in the HTML builder and remove Sphinx's custom logic. Patch by Tunç Başar Köse. - [#​12391](https://togithub.com/sphinx-doc/sphinx/issues/12391): Adjust scoring of matches during HTML search so that document main titles tend to rank higher than subsection titles. In addition, boost matches on the name of programming domain objects relative to title/subtitle matches. Patch by James Addison and Will Lachance. - [#​9634](https://togithub.com/sphinx-doc/sphinx/issues/9634): Do not add a fallback language by stripping the country code. Patch by Alvin Wong. - [#​12352](https://togithub.com/sphinx-doc/sphinx/issues/12352): Add domain objects to the table of contents in the same order as defined in the document. Previously, each domain used language-specific nesting rules, which removed control from document authors. Patch by Jakob Lykke Andersen and Adam Turner. - [#​11041](https://togithub.com/sphinx-doc/sphinx/issues/11041): linkcheck: Ignore URLs that respond with non-Unicode content. Patch by James Addison. - [#​12543](https://togithub.com/sphinx-doc/sphinx/issues/12543): Fix :pep:`695` formatting for LaTeX output. Patch by Bénédikt Tran. ## Testing - karma: refactor HTML search tests to use fixtures generated by Sphinx. Patch by James Addison.
--- ### Configuration 📅 **Schedule**: Branch creation - At any time (no schedule defined), Automerge - At any time (no schedule defined). 🚦 **Automerge**: Disabled by config. Please merge this manually once you are satisfied. ♻ **Rebasing**: Whenever PR becomes conflicted, or you tick the rebase/retry checkbox. 🔕 **Ignore**: Close this PR and you won't be reminded about this update again. --- - [ ] If you want to rebase/retry this PR, check this box --- This PR has been generated by [Mend Renovate](https://www.mend.io/free-developer-tools/renovate/). View repository job log [here](https://developer.mend.io/github/secretflow/spu). Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com> --- docs/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/requirements.txt b/docs/requirements.txt index 9b6bb997..3b26b70d 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,6 +1,6 @@ myst-parser==3.0.1 rstcheck==6.2.4 -sphinx==7.3.7 +sphinx==7.4.4 nbsphinx==0.9.4 sphinx-autobuild==2024.4.16 sphinx-markdown-parser==0.2.4 From f3296773573842d254950a76428473f68b96eb92 Mon Sep 17 00:00:00 2001 From: "renovate[bot]" <29139614+renovate[bot]@users.noreply.github.com> Date: Tue, 16 Jul 2024 08:08:21 +0800 Subject: [PATCH 12/27] chore(deps): update github/codeql-action action to v3.25.12 (#767) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit [![Mend Renovate](https://app.renovatebot.com/images/banner.svg)](https://renovatebot.com) This PR contains the following updates: | Package | Type | Update | Change | |---|---|---|---| | [github/codeql-action](https://togithub.com/github/codeql-action) | action | patch | `v3.25.11` -> `v3.25.12` | --- ### Release Notes
github/codeql-action (github/codeql-action) ### [`v3.25.12`](https://togithub.com/github/codeql-action/compare/v3.25.11...v3.25.12) [Compare Source](https://togithub.com/github/codeql-action/compare/v3.25.11...v3.25.12)
--- ### Configuration 📅 **Schedule**: Branch creation - At any time (no schedule defined), Automerge - At any time (no schedule defined). 🚦 **Automerge**: Disabled by config. Please merge this manually once you are satisfied. ♻ **Rebasing**: Whenever PR becomes conflicted, or you tick the rebase/retry checkbox. 🔕 **Ignore**: Close this PR and you won't be reminded about this update again. --- - [ ] If you want to rebase/retry this PR, check this box --- This PR has been generated by [Mend Renovate](https://www.mend.io/free-developer-tools/renovate/). View repository job log [here](https://developer.mend.io/github/secretflow/spu). Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com> --- .github/workflows/scorecard.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/scorecard.yml b/.github/workflows/scorecard.yml index 3b6e8be7..5b25f5b2 100644 --- a/.github/workflows/scorecard.yml +++ b/.github/workflows/scorecard.yml @@ -67,6 +67,6 @@ jobs: # Upload the results to GitHub's code scanning dashboard. - name: "Upload to code-scanning" - uses: github/codeql-action/upload-sarif@b611370bb5703a7efb587f9d136a52ea24c5c38c # v3.25.11 + uses: github/codeql-action/upload-sarif@4fa2a7953630fd2f3fb380f21be14ede0169dd4f # v3.25.12 with: sarif_file: results.sarif From 63cce33618b9b3d76761126453851f72d56fbf8b Mon Sep 17 00:00:00 2001 From: Yancheng Zheng <103552181+anakinxc@users.noreply.github.com> Date: Wed, 17 Jul 2024 17:48:22 +0800 Subject: [PATCH 13/27] Repo sync (#772) --- .clang-tidy | 3 +- CHANGELOG.md | 3 +- bazel/repositories.bzl | 4 +- docs/development/add_protocols.rst | 2 +- examples/python/ml/jax_lr/README.md | 2 +- experimental/squirrel/bin_matvec_prot_test.cc | 4 +- experimental/squirrel/objectives.cc | 6 +- experimental/squirrel/tree_build_worker.cc | 2 +- experimental/squirrel/utils.cc | 6 +- experimental/squirrel/utils_test.cc | 4 +- libspu/compiler/common/compilation_context.h | 1 - libspu/compiler/compile.cc | 2 - libspu/compiler/core/core.cc | 1 - libspu/compiler/front_end/fe.cc | 1 - libspu/compiler/tests/interpret/abs.mlir | 2 + libspu/compiler/tests/interpret/add.mlir | 2 + libspu/compiler/tests/interpret/and.mlir | 2 + libspu/compiler/tests/interpret/atan2.mlir | 2 + .../compiler/tests/interpret/broadcast.mlir | 2 + libspu/compiler/tests/interpret/case.mlir | 2 + libspu/compiler/tests/interpret/ceil.mlir | 2 + libspu/compiler/tests/interpret/clamp.mlir | 2 + .../compiler/tests/interpret/concatenate.mlir | 2 + libspu/compiler/tests/interpret/convert.mlir | 2 + .../compiler/tests/interpret/convolution.mlir | 2 + libspu/compiler/tests/interpret/cosine.mlir | 2 + libspu/compiler/tests/interpret/divide.mlir | 2 + .../compiler/tests/interpret/dot_general.mlir | 2 + .../tests/interpret/dynamic_slice.mlir | 2 + .../tests/interpret/dynamic_update_slice.mlir | 2 + libspu/compiler/tests/interpret/equal.mlir | 2 + .../compiler/tests/interpret/exponential.mlir | 2 + .../interpret/exponential_minus_one.mlir | 2 + libspu/compiler/tests/interpret/floor.mlir | 2 + .../tests/interpret/generate_mlir_tests.py | 9 + libspu/compiler/tests/interpret/greater.mlir | 2 + .../tests/interpret/greater_equal.mlir | 2 + libspu/compiler/tests/interpret/if.mlir | 2 + libspu/compiler/tests/interpret/iota.mlir | 2 + libspu/compiler/tests/interpret/less.mlir | 2 + .../compiler/tests/interpret/less_equal.mlir | 2 + libspu/compiler/tests/interpret/log.mlir | 2 + .../tests/interpret/log_plus_one.mlir | 2 + libspu/compiler/tests/interpret/logistic.mlir | 2 + libspu/compiler/tests/interpret/maximum.mlir | 1 + libspu/compiler/tests/interpret/minimum.mlir | 1 + libspu/compiler/tests/interpret/multiply.mlir | 2 + libspu/compiler/tests/interpret/negate.mlir | 2 + libspu/compiler/tests/interpret/not.mlir | 2 + .../compiler/tests/interpret/not_equal.mlir | 2 + libspu/compiler/tests/interpret/or.mlir | 2 + libspu/compiler/tests/interpret/pad.mlir | 2 + libspu/compiler/tests/interpret/popcnt.mlir | 2 + libspu/compiler/tests/interpret/power.mlir | 2 + .../compiler/tests/interpret/reciprocal.mlir | 26 +++ libspu/compiler/tests/interpret/reduce.mlir | 2 + .../tests/interpret/reduce_window.mlir | 2 + libspu/compiler/tests/interpret/reshape.mlir | 2 + libspu/compiler/tests/interpret/reverse.mlir | 2 + .../compiler/tests/interpret/ring_cast.mlir | 2 + .../tests/interpret/round_nearest_afz.mlir | 2 + libspu/compiler/tests/interpret/rsqrt.mlir | 2 + libspu/compiler/tests/interpret/select.mlir | 2 + .../tests/interpret/select_and_scatter.mlir | 2 + .../interpret/shift_right_arithmetic.mlir | 2 + .../tests/interpret/shift_right_logical.mlir | 2 + libspu/compiler/tests/interpret/sign.mlir | 2 + libspu/compiler/tests/interpret/sine.mlir | 2 + libspu/compiler/tests/interpret/slice.mlir | 2 + libspu/compiler/tests/interpret/sort.mlir | 2 + libspu/compiler/tests/interpret/sqrt.mlir | 2 + libspu/compiler/tests/interpret/subtract.mlir | 2 + libspu/compiler/tests/interpret/tanh.mlir | 2 + .../tests/interpret/test_json/reciprocal.json | 24 +++ .../compiler/tests/interpret/transpose.mlir | 2 + libspu/compiler/tests/interpret/while.mlir | 2 + libspu/compiler/tests/interpret/xor.mlir | 2 + .../passes/hlo2pphlo/select_and_scatter.mlir | 4 +- .../tests/passes/hlo2pphlo/sort_p.mlir | 4 +- .../tests/passes/hlo2pphlo/sort_s.mlir | 4 +- libspu/compiler/tools/spu-lsp.cc | 3 +- libspu/compiler/tools/spu-translate.cc | 20 ++- libspu/core/encoding.cc | 12 +- libspu/core/pt_buffer_view.cc | 2 +- libspu/core/type_util.h | 170 +++++++++--------- libspu/dialect/pphlo/IR/dialect.td | 5 +- libspu/dialect/pphlo/IR/type_inference.cc | 23 +-- .../pphlo/transforms/hlo_legalize_to_pphlo.cc | 16 +- libspu/kernel/hal/constants.cc | 2 +- libspu/kernel/hal/fxp_approx.cc | 33 ++-- libspu/kernel/hal/fxp_base.cc | 21 ++- libspu/kernel/hal/fxp_base.h | 2 + libspu/kernel/hal/fxp_cleartext.cc | 6 +- libspu/kernel/hal/permute.cc | 11 +- libspu/kernel/hal/polymorphic.cc | 18 +- libspu/kernel/hal/polymorphic.h | 7 +- libspu/kernel/hal/prot_wrapper.cc | 10 +- libspu/kernel/hal/prot_wrapper.h | 18 +- libspu/kernel/hal/ring.cc | 36 ++-- libspu/kernel/hal/ring.h | 6 +- libspu/kernel/hal/type_cast.cc | 7 +- libspu/kernel/hlo/basic_unary.cc | 10 +- libspu/kernel/hlo/shift.cc | 29 +-- libspu/mpc/ab_api.cc | 46 ++--- libspu/mpc/ab_api.h | 8 +- libspu/mpc/ab_api_test.cc | 43 ++--- libspu/mpc/aby3/arithmetic.cc | 51 +++--- libspu/mpc/aby3/arithmetic.h | 2 +- libspu/mpc/aby3/boolean.cc | 91 +++++----- libspu/mpc/aby3/boolean.h | 6 +- libspu/mpc/aby3/conversion.cc | 45 ++--- libspu/mpc/aby3/io.cc | 6 +- libspu/mpc/aby3/oram.cc | 12 +- libspu/mpc/aby3/ot.cc | 2 - libspu/mpc/aby3/permute.cc | 8 +- libspu/mpc/aby3/value.h | 2 +- libspu/mpc/api.cc | 33 ++-- libspu/mpc/api.h | 18 +- libspu/mpc/api_test.cc | 167 ++++++++--------- .../mpc/cheetah/arith/cheetah_conv2d_test.cc | 2 +- libspu/mpc/cheetah/arith/cheetah_dot_test.cc | 4 +- libspu/mpc/cheetah/arith/common.cc | 2 +- libspu/mpc/cheetah/arith/matmat_prot.cc | 8 +- libspu/mpc/cheetah/arith/matmat_prot_test.cc | 6 +- libspu/mpc/cheetah/arith/vector_encoder.cc | 3 +- .../mpc/cheetah/arith/vector_encoder_test.cc | 2 +- libspu/mpc/cheetah/arithmetic.cc | 2 +- libspu/mpc/cheetah/arithmetic.h | 2 +- libspu/mpc/cheetah/arithmetic_semi2k.cc | 9 +- libspu/mpc/cheetah/boolean.h | 6 +- libspu/mpc/cheetah/boolean_semi2k.cc | 28 ++- libspu/mpc/cheetah/nonlinear/compare_prot.cc | 8 +- .../cheetah/nonlinear/compare_prot_test.cc | 19 +- libspu/mpc/cheetah/nonlinear/equal_prot.cc | 6 +- .../mpc/cheetah/nonlinear/equal_prot_test.cc | 4 +- libspu/mpc/cheetah/nonlinear/truncate_prot.cc | 17 +- .../cheetah/nonlinear/truncate_prot_test.cc | 14 +- libspu/mpc/cheetah/ot/basic_ot_prot.cc | 20 +-- libspu/mpc/cheetah/ot/basic_ot_prot_test.cc | 30 ++-- libspu/mpc/cheetah/ot/emp/ferret_test.cc | 8 +- libspu/mpc/cheetah/ot/ot_util.cc | 19 +- libspu/mpc/cheetah/ot/ot_util_test.cc | 6 +- libspu/mpc/cheetah/ot/yacl/ferret_test.cc | 8 +- libspu/mpc/cheetah/rlwe/modswitch_helper.cc | 6 +- .../mpc/cheetah/rlwe/modswitch_helper_test.cc | 13 +- libspu/mpc/cheetah/rlwe/packlwes_test.cc | 7 +- libspu/mpc/cheetah/state.cc | 4 +- libspu/mpc/common/prg_state.h | 31 ++-- libspu/mpc/common/pv2k.cc | 52 +++--- libspu/mpc/kernel.cc | 6 +- libspu/mpc/kernel.h | 2 +- libspu/mpc/ref2k/ref2k.cc | 17 +- libspu/mpc/securenn/arithmetic.cc | 33 ++-- libspu/mpc/securenn/arithmetic.h | 2 +- libspu/mpc/securenn/boolean.cc | 35 ++-- libspu/mpc/securenn/boolean.h | 6 +- libspu/mpc/securenn/conversion.cc | 10 +- libspu/mpc/semi2k/arithmetic.cc | 15 +- libspu/mpc/semi2k/arithmetic.h | 2 +- .../semi2k/beaver/beaver_impl/beaver_test.cc | 43 +++-- .../semi2k/beaver/beaver_impl/beaver_tfp.cc | 2 +- .../semi2k/beaver/beaver_impl/beaver_ttp.cc | 2 +- .../trusted_party/trusted_party.cc | 9 +- libspu/mpc/semi2k/boolean.cc | 31 ++-- libspu/mpc/semi2k/boolean.h | 6 +- libspu/mpc/semi2k/conversion.cc | 20 ++- libspu/mpc/semi2k/permute.cc | 2 +- libspu/mpc/spdz2k/abprotocol_spdz2k_test.cc | 4 +- libspu/mpc/spdz2k/arithmetic.cc | 21 +-- libspu/mpc/spdz2k/arithmetic.h | 2 +- libspu/mpc/spdz2k/beaver/beaver_test.cc | 7 +- libspu/mpc/spdz2k/beaver/beaver_tfp.cc | 8 +- libspu/mpc/spdz2k/beaver/beaver_tinyot.cc | 34 ++-- libspu/mpc/spdz2k/beaver/trusted_party.cc | 5 +- libspu/mpc/spdz2k/boolean.cc | 45 +++-- libspu/mpc/spdz2k/boolean.h | 6 +- libspu/mpc/spdz2k/conversion.cc | 30 ++-- libspu/mpc/spdz2k/io.cc | 8 +- libspu/mpc/spdz2k/value.cc | 8 +- libspu/mpc/tools/benchmark.h | 32 ++-- libspu/mpc/utils/BUILD.bazel | 1 + libspu/mpc/utils/circuits.h | 38 ++-- libspu/mpc/utils/circuits_test.cc | 12 +- libspu/mpc/utils/permute.cc | 8 +- libspu/mpc/utils/ring_ops.cc | 82 +++++---- libspu/mpc/utils/ring_ops.h | 12 +- libspu/version.h | 2 +- 187 files changed, 1197 insertions(+), 994 deletions(-) create mode 100644 libspu/compiler/tests/interpret/reciprocal.mlir create mode 100644 libspu/compiler/tests/interpret/test_json/reciprocal.json diff --git a/.clang-tidy b/.clang-tidy index bacbb3a2..227a407d 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -28,7 +28,8 @@ Checks: "abseil-cleanup-ctad, -readability-identifier-length, -readability-function-cognitive-complexity, -readability-magic-numbers, - -readability-named-parameter" + -readability-named-parameter, + -readability-convert-member-functions-to-static" CheckOptions: - key: bugprone-argument-comment.StrictMode diff --git a/CHANGELOG.md b/CHANGELOG.md index 0712ac66..6b86bb5b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,8 +10,9 @@ > > please add your unreleased change here. -## TBD +## 20240716 +- [SPU] 0.9.2b0 release - [Feature] Support jax.numpy.bitwise_count - [Bugfix] Fix jax.numpy.signbit wrong answer with very large input diff --git a/bazel/repositories.bzl b/bazel/repositories.bzl index 59ed839d..d97c8e68 100644 --- a/bazel/repositories.bzl +++ b/bazel/repositories.bzl @@ -136,8 +136,8 @@ def _bazel_skylib(): ) def _com_github_openxla_xla(): - OPENXLA_COMMIT = "9b0dd58c9b625a2e958f4fc7787a1ff5c95dbb40" - OPENXLA_SHA256 = "f150c5b49e4d4497aae2c79232f1efe2baccaa72223b21dc8715be73eab74417" + OPENXLA_COMMIT = "8533a6869ae02fb3b15a8a12739a982fc3c9f6e7" + OPENXLA_SHA256 = "d5b076825c992f59542f6b94e5480c7e7c6c627cd18c80ec60b6d5b295c160d4" # We need openxla to handle xla/mhlo/stablehlo maybe( diff --git a/docs/development/add_protocols.rst b/docs/development/add_protocols.rst index 31b4b5e2..988c846c 100644 --- a/docs/development/add_protocols.rst +++ b/docs/development/add_protocols.rst @@ -248,7 +248,7 @@ When kernels are implemented and registered, a new protocol is finally added. auto* prg_state = ctx->getState(); // dispatch the real implementation to different fields - return DISPATCH_ALL_FIELDS(field, "aby3.mulAA", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { // the real protocol implementation ... }); diff --git a/examples/python/ml/jax_lr/README.md b/examples/python/ml/jax_lr/README.md index fc0e6580..bc340199 100644 --- a/examples/python/ml/jax_lr/README.md +++ b/examples/python/ml/jax_lr/README.md @@ -5,7 +5,7 @@ This example demonstrates how to use SPU to train a logistic regression model pr 1. Launch SPU backend runtime ```sh - bazel run -c opt //examples/python/utils:nodectl -- up + bazel run -c opt //examples/python/utils:nodectl -- -c examples/python/conf/2pc_semi2k.json up ``` 2. Run `jax_lr` example diff --git a/experimental/squirrel/bin_matvec_prot_test.cc b/experimental/squirrel/bin_matvec_prot_test.cc index 0d5cae29..452e127d 100644 --- a/experimental/squirrel/bin_matvec_prot_test.cc +++ b/experimental/squirrel/bin_matvec_prot_test.cc @@ -110,7 +110,7 @@ TEST_P(BinMatVecProtTest, Basic) { }); NdArrayRef reveal = ring_add(out_shr[0], out_shr[1]); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _vec(vec); auto expected = BinAccumuate(_vec, mat); NdArrayView got(reveal); @@ -160,7 +160,7 @@ TEST_P(BinMatVecProtTest, WithIndicator) { }); NdArrayRef reveal = ring_add(out_shr[0], out_shr[1]); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _vec(vec); auto expected = BinAccumuate(_vec, mat, absl::MakeConstSpan(indicator)); diff --git a/experimental/squirrel/objectives.cc b/experimental/squirrel/objectives.cc index b9423d2d..baf5f08d 100644 --- a/experimental/squirrel/objectives.cc +++ b/experimental/squirrel/objectives.cc @@ -272,9 +272,9 @@ namespace { res = NdArrayRef(makeType(ftype), in.shape()); } - return DISPATCH_ALL_FIELDS(field, "cheetah.ring_cast", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using from_ring2k_t = ring2k_t; - return DISPATCH_ALL_FIELDS(ftype, "cheetah.ring_cast", [&]() { + return DISPATCH_ALL_FIELDS(ftype, [&]() { using to_ring2k_t = ring2k_t; NdArrayView _in(in); NdArrayView _res(res); @@ -383,7 +383,7 @@ spu::Value Logistic(spu::SPUContext* ctx, const spu::Value& x) { spu::Value Sigmoid(spu::SPUContext* ctx, const spu::Value& x) { namespace sk = spu::kernel; auto c05 = sk::hlo::Constant(ctx, 0.5F, x.shape()); - auto half = sk::hal::right_shift_arithmetic(ctx, x, 1); + auto half = sk::hal::right_shift_arithmetic(ctx, x, {1}); auto divisor = sk::hlo::Add(ctx, sk::hlo::Constant(ctx, 1, x.shape()), sk::hal::f_square(ctx, x)); return sk::hlo::Add(ctx, c05, diff --git a/experimental/squirrel/tree_build_worker.cc b/experimental/squirrel/tree_build_worker.cc index 30c441f9..06724354 100644 --- a/experimental/squirrel/tree_build_worker.cc +++ b/experimental/squirrel/tree_build_worker.cc @@ -230,7 +230,7 @@ void AccumulateHistogram(spu::NdArrayRef buckets_share, size_t nfeatures, // The buckets belong to the i-th feature is // `buckets[i*bucket_size:(i+1)*bucket_size]` auto field = buckets_share.eltype().as()->field(); - DISPATCH_ALL_FIELDS(field, "AccumulateHistogram", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView histogram(buckets_share); for (size_t j = 0; j < nfeatures; ++j) { size_t start = j * bucket_size; diff --git a/experimental/squirrel/utils.cc b/experimental/squirrel/utils.cc index 3f14436e..58a98d70 100644 --- a/experimental/squirrel/utils.cc +++ b/experimental/squirrel/utils.cc @@ -167,7 +167,7 @@ spu::Value MulPrivateArithWithPrivateBoolean(spu::SPUContext* ctx, &kctx, arith.data(), [&](const NdArrayRef& input, const std::shared_ptr& base_ot) { - return DISPATCH_ALL_FIELDS(ft, "ot", [&]() { + return DISPATCH_ALL_FIELDS(ft, [&]() { NdArrayRef ot_out = spu::mpc::ring_zeros(ft, input.shape()); auto inp = absl::MakeConstSpan(&input.at(0), input.numel()); auto oup = absl::MakeSpan(&ot_out.at(0), ot_out.numel()); @@ -193,7 +193,7 @@ spu::Value MulPrivateArithWithPrivateBoolean(spu::SPUContext* ctx, &kctx, boolean, [&](absl::Span input, const std::shared_ptr& base_ot) { - return DISPATCH_ALL_FIELDS(ft, "ot", [&]() { + return DISPATCH_ALL_FIELDS(ft, [&]() { NdArrayRef ot_out = spu::mpc::ring_zeros(ft, {(int64_t)input.size()}); auto oup = absl::MakeSpan(&ot_out.at(0), input.size()); base_ot->GetReceiverCOT()->RecvCAMCC(input, oup); @@ -222,7 +222,7 @@ spu::Value MulArithShareWithANDBoolShare(spu::SPUContext* ctx, std::shared_ptr base_ot) { NdArrayRef out(x.eltype(), x.shape()); - DISPATCH_ALL_FIELDS(ft, "camcc", [&]() { + DISPATCH_ALL_FIELDS(ft, [&]() { spu::NdArrayView _ashr(x); auto oup = absl::MakeSpan(&out.at(0), y.size()); std::vector corr(y.size()); diff --git a/experimental/squirrel/utils_test.cc b/experimental/squirrel/utils_test.cc index 3bdc004e..8cde97fd 100644 --- a/experimental/squirrel/utils_test.cc +++ b/experimental/squirrel/utils_test.cc @@ -85,7 +85,7 @@ TEST_F(UtilsTest, ReduceSum) { const double fxp = std::pow(2., rt_config.fxp_fraction_bits()); auto flatten = got.data().reshape({got.numel()}); - DISPATCH_ALL_FIELDS(field, "check", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using s2k = std::make_signed::type; NdArrayView got(flatten); for (int64_t i = 0; i < expected.numel(); ++i) { @@ -136,7 +136,7 @@ TEST_F(UtilsTest, ArgMax) { if (lctx->Rank() == 0) { auto flatten = got.data().reshape({got.numel()}); - DISPATCH_ALL_FIELDS(field, "check", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView got(flatten); for (size_t i = 0; i < expected.size(); ++i) { ASSERT_EQ(expected(i), got[i]); diff --git a/libspu/compiler/common/compilation_context.h b/libspu/compiler/common/compilation_context.h index 24ebe043..8abfed96 100644 --- a/libspu/compiler/common/compilation_context.h +++ b/libspu/compiler/common/compilation_context.h @@ -16,7 +16,6 @@ #include #include -#include #include "mlir/IR/MLIRContext.h" #include "mlir/Pass/PassManager.h" diff --git a/libspu/compiler/compile.cc b/libspu/compiler/compile.cc index 38a59acb..236d6d02 100644 --- a/libspu/compiler/compile.cc +++ b/libspu/compiler/compile.cc @@ -14,8 +14,6 @@ #include "libspu/compiler/compile.h" -#include "mlir/IR/BuiltinOps.h" - #include "libspu/compiler/codegen/codegen.h" #include "libspu/compiler/common/compilation_context.h" #include "libspu/compiler/core/core.h" diff --git a/libspu/compiler/core/core.cc b/libspu/compiler/core/core.cc index e977050c..b7c46d8f 100644 --- a/libspu/compiler/core/core.cc +++ b/libspu/compiler/core/core.cc @@ -16,7 +16,6 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/Passes.h" diff --git a/libspu/compiler/front_end/fe.cc b/libspu/compiler/front_end/fe.cc index e560c772..5bcb693f 100644 --- a/libspu/compiler/front_end/fe.cc +++ b/libspu/compiler/front_end/fe.cc @@ -21,7 +21,6 @@ #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/Passes.h" -#include "spdlog/spdlog.h" #include "stablehlo/dialect/StablehloOps.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" diff --git a/libspu/compiler/tests/interpret/abs.mlir b/libspu/compiler/tests/interpret/abs.mlir index dd655bf1..49e409a8 100644 --- a/libspu/compiler/tests/interpret/abs.mlir +++ b/libspu/compiler/tests/interpret/abs.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @abs_op_test_i64_i64_p() { diff --git a/libspu/compiler/tests/interpret/add.mlir b/libspu/compiler/tests/interpret/add.mlir index a763c0ae..5074d3eb 100644 --- a/libspu/compiler/tests/interpret/add.mlir +++ b/libspu/compiler/tests/interpret/add.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @add_op_test_i8_i8_pp() { diff --git a/libspu/compiler/tests/interpret/and.mlir b/libspu/compiler/tests/interpret/and.mlir index 6e3cdfa6..618d3222 100644 --- a/libspu/compiler/tests/interpret/and.mlir +++ b/libspu/compiler/tests/interpret/and.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @and_op_test_i8_i8_pp() { diff --git a/libspu/compiler/tests/interpret/atan2.mlir b/libspu/compiler/tests/interpret/atan2.mlir index e1938852..52c2b122 100644 --- a/libspu/compiler/tests/interpret/atan2.mlir +++ b/libspu/compiler/tests/interpret/atan2.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @atan2_op_test_f64_f64_pp() { diff --git a/libspu/compiler/tests/interpret/broadcast.mlir b/libspu/compiler/tests/interpret/broadcast.mlir index 1d0e9a63..7e2d17f8 100644 --- a/libspu/compiler/tests/interpret/broadcast.mlir +++ b/libspu/compiler/tests/interpret/broadcast.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s func.func @broadcast_in_dim() { %operand = pphlo.constant dense<[[1], [2], [3]]> : tensor<3x1xi64> diff --git a/libspu/compiler/tests/interpret/case.mlir b/libspu/compiler/tests/interpret/case.mlir index 543fbf86..96105d4e 100644 --- a/libspu/compiler/tests/interpret/case.mlir +++ b/libspu/compiler/tests/interpret/case.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s func.func @case_negative_index_default() { %index = pphlo.constant dense<-1> : tensor diff --git a/libspu/compiler/tests/interpret/ceil.mlir b/libspu/compiler/tests/interpret/ceil.mlir index d4f74b8b..23c05538 100644 --- a/libspu/compiler/tests/interpret/ceil.mlir +++ b/libspu/compiler/tests/interpret/ceil.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @ceil_op_test_f16_f16_p() { diff --git a/libspu/compiler/tests/interpret/clamp.mlir b/libspu/compiler/tests/interpret/clamp.mlir index 69c05696..9e76acc8 100644 --- a/libspu/compiler/tests/interpret/clamp.mlir +++ b/libspu/compiler/tests/interpret/clamp.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s func.func @clamp_op_test_si64() { %min = pphlo.constant dense<[1, 5, -5]> : tensor<3xi64> diff --git a/libspu/compiler/tests/interpret/concatenate.mlir b/libspu/compiler/tests/interpret/concatenate.mlir index 1d0d7655..63107f22 100644 --- a/libspu/compiler/tests/interpret/concatenate.mlir +++ b/libspu/compiler/tests/interpret/concatenate.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s func.func @concatenate() { %input0 = pphlo.constant dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi64> diff --git a/libspu/compiler/tests/interpret/convert.mlir b/libspu/compiler/tests/interpret/convert.mlir index 747047ee..b3b23f3b 100644 --- a/libspu/compiler/tests/interpret/convert.mlir +++ b/libspu/compiler/tests/interpret/convert.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s func.func @convert_op_test_1() { %0 = pphlo.constant dense<[0, 1, 8, -9, 0]> : tensor<5xi32> diff --git a/libspu/compiler/tests/interpret/convolution.mlir b/libspu/compiler/tests/interpret/convolution.mlir index 294cdd36..2ce1030c 100644 --- a/libspu/compiler/tests/interpret/convolution.mlir +++ b/libspu/compiler/tests/interpret/convolution.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s func.func @main() { %0 = pphlo.constant dense<[[[[ 1.0, 2.0, 3.0, 4.0], diff --git a/libspu/compiler/tests/interpret/cosine.mlir b/libspu/compiler/tests/interpret/cosine.mlir index 72d6f249..5ee36f58 100644 --- a/libspu/compiler/tests/interpret/cosine.mlir +++ b/libspu/compiler/tests/interpret/cosine.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @cosine_op_test_f16_f16_p() { diff --git a/libspu/compiler/tests/interpret/divide.mlir b/libspu/compiler/tests/interpret/divide.mlir index 56ce53c2..a7540602 100644 --- a/libspu/compiler/tests/interpret/divide.mlir +++ b/libspu/compiler/tests/interpret/divide.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @divide_op_test_i64_i64_pp() { diff --git a/libspu/compiler/tests/interpret/dot_general.mlir b/libspu/compiler/tests/interpret/dot_general.mlir index 29023c81..a4c506ed 100644 --- a/libspu/compiler/tests/interpret/dot_general.mlir +++ b/libspu/compiler/tests/interpret/dot_general.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s func.func @dot_general_op_test_si64() { %lhs = pphlo.constant dense<[[[1, 2], [3, 4]], diff --git a/libspu/compiler/tests/interpret/dynamic_slice.mlir b/libspu/compiler/tests/interpret/dynamic_slice.mlir index 672ace7c..d0275fd9 100644 --- a/libspu/compiler/tests/interpret/dynamic_slice.mlir +++ b/libspu/compiler/tests/interpret/dynamic_slice.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s func.func @dynamic_slice() { %operand = pphlo.constant dense<[[1, 1, 1], diff --git a/libspu/compiler/tests/interpret/dynamic_update_slice.mlir b/libspu/compiler/tests/interpret/dynamic_update_slice.mlir index a7053428..423528ac 100644 --- a/libspu/compiler/tests/interpret/dynamic_update_slice.mlir +++ b/libspu/compiler/tests/interpret/dynamic_update_slice.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s func.func @dynamic_update_slice() { %operand = pphlo.constant dense<[[1, 1, 1, 1], diff --git a/libspu/compiler/tests/interpret/equal.mlir b/libspu/compiler/tests/interpret/equal.mlir index f5364638..c8291d1a 100644 --- a/libspu/compiler/tests/interpret/equal.mlir +++ b/libspu/compiler/tests/interpret/equal.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @equal_op_test_i64_i1_pp() { diff --git a/libspu/compiler/tests/interpret/exponential.mlir b/libspu/compiler/tests/interpret/exponential.mlir index a85c5bb6..21ec3591 100644 --- a/libspu/compiler/tests/interpret/exponential.mlir +++ b/libspu/compiler/tests/interpret/exponential.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @exponential_op_test_f64_f64_p() { diff --git a/libspu/compiler/tests/interpret/exponential_minus_one.mlir b/libspu/compiler/tests/interpret/exponential_minus_one.mlir index 1a337a21..35131bbc 100644 --- a/libspu/compiler/tests/interpret/exponential_minus_one.mlir +++ b/libspu/compiler/tests/interpret/exponential_minus_one.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @exponential_minus_one_op_test_f64_f64_p() { diff --git a/libspu/compiler/tests/interpret/floor.mlir b/libspu/compiler/tests/interpret/floor.mlir index 97ccdb1e..602d0161 100644 --- a/libspu/compiler/tests/interpret/floor.mlir +++ b/libspu/compiler/tests/interpret/floor.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @floor_op_test_f16_f16_p() { diff --git a/libspu/compiler/tests/interpret/generate_mlir_tests.py b/libspu/compiler/tests/interpret/generate_mlir_tests.py index f9bdef86..7a271164 100755 --- a/libspu/compiler/tests/interpret/generate_mlir_tests.py +++ b/libspu/compiler/tests/interpret/generate_mlir_tests.py @@ -65,6 +65,7 @@ def TestCase(inputs, expected, checker='expect_eq', tol=None): "or", # "popcnt", "power", + "reciprocal", "reshape", "round_afz", "rsqrt", @@ -99,6 +100,14 @@ def TestCase(inputs, expected, checker='expect_eq', tol=None): f.write( "// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s\n" ) + f.write( + "// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s\n" + ) + # Some test values in max and min are not supported by protocol 5. + if test not in ["max", "min"]: + f.write( + "// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s\n" + ) f.write("// AUTO GENERATED, DO NOT EDIT\n\n") # Emit cases diff --git a/libspu/compiler/tests/interpret/greater.mlir b/libspu/compiler/tests/interpret/greater.mlir index 7f8e76be..92140f85 100644 --- a/libspu/compiler/tests/interpret/greater.mlir +++ b/libspu/compiler/tests/interpret/greater.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @greater_op_test_i64_i1_pp() { diff --git a/libspu/compiler/tests/interpret/greater_equal.mlir b/libspu/compiler/tests/interpret/greater_equal.mlir index 305aaff5..af3ffa7a 100644 --- a/libspu/compiler/tests/interpret/greater_equal.mlir +++ b/libspu/compiler/tests/interpret/greater_equal.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @greater_equal_op_test_i64_i1_pp() { diff --git a/libspu/compiler/tests/interpret/if.mlir b/libspu/compiler/tests/interpret/if.mlir index cc98b65d..ef73d547 100644 --- a/libspu/compiler/tests/interpret/if.mlir +++ b/libspu/compiler/tests/interpret/if.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s func.func @if_ops_true_branch() { %pred = pphlo.constant dense : tensor diff --git a/libspu/compiler/tests/interpret/iota.mlir b/libspu/compiler/tests/interpret/iota.mlir index a7ee86ed..cc71b040 100644 --- a/libspu/compiler/tests/interpret/iota.mlir +++ b/libspu/compiler/tests/interpret/iota.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s func.func @iota_op_test_si8_dim_0() { %0 = pphlo.iota dim = 0 : tensor<3x4xi8> diff --git a/libspu/compiler/tests/interpret/less.mlir b/libspu/compiler/tests/interpret/less.mlir index 58444a29..1a9d3060 100644 --- a/libspu/compiler/tests/interpret/less.mlir +++ b/libspu/compiler/tests/interpret/less.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @less_op_test_i64_i1_pp() { diff --git a/libspu/compiler/tests/interpret/less_equal.mlir b/libspu/compiler/tests/interpret/less_equal.mlir index 9951a569..c454bcc7 100644 --- a/libspu/compiler/tests/interpret/less_equal.mlir +++ b/libspu/compiler/tests/interpret/less_equal.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @less_equal_op_test_i64_i1_pp() { diff --git a/libspu/compiler/tests/interpret/log.mlir b/libspu/compiler/tests/interpret/log.mlir index af61d86a..bf309137 100644 --- a/libspu/compiler/tests/interpret/log.mlir +++ b/libspu/compiler/tests/interpret/log.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @log_op_test_f64_f64_p() { diff --git a/libspu/compiler/tests/interpret/log_plus_one.mlir b/libspu/compiler/tests/interpret/log_plus_one.mlir index 3aec9184..99bcf9ff 100644 --- a/libspu/compiler/tests/interpret/log_plus_one.mlir +++ b/libspu/compiler/tests/interpret/log_plus_one.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @log_plus_one_op_test_f64_f64_p() { diff --git a/libspu/compiler/tests/interpret/logistic.mlir b/libspu/compiler/tests/interpret/logistic.mlir index eeac6fab..655cb82b 100644 --- a/libspu/compiler/tests/interpret/logistic.mlir +++ b/libspu/compiler/tests/interpret/logistic.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @logistic_op_test_f64_f64_p() { diff --git a/libspu/compiler/tests/interpret/maximum.mlir b/libspu/compiler/tests/interpret/maximum.mlir index b90553e7..4919c356 100644 --- a/libspu/compiler/tests/interpret/maximum.mlir +++ b/libspu/compiler/tests/interpret/maximum.mlir @@ -1,6 +1,7 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @maximum_op_test_i8_i8_pp() { diff --git a/libspu/compiler/tests/interpret/minimum.mlir b/libspu/compiler/tests/interpret/minimum.mlir index 74bb2b83..1853124b 100644 --- a/libspu/compiler/tests/interpret/minimum.mlir +++ b/libspu/compiler/tests/interpret/minimum.mlir @@ -1,6 +1,7 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @minimum_op_test_i8_i8_pp() { diff --git a/libspu/compiler/tests/interpret/multiply.mlir b/libspu/compiler/tests/interpret/multiply.mlir index 82b780f3..f7d415c4 100644 --- a/libspu/compiler/tests/interpret/multiply.mlir +++ b/libspu/compiler/tests/interpret/multiply.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @multiply_op_test_i8_i8_pp() { diff --git a/libspu/compiler/tests/interpret/negate.mlir b/libspu/compiler/tests/interpret/negate.mlir index 8c5e74c3..6a00d9b8 100644 --- a/libspu/compiler/tests/interpret/negate.mlir +++ b/libspu/compiler/tests/interpret/negate.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @negate_op_test_i8_i8_p() { diff --git a/libspu/compiler/tests/interpret/not.mlir b/libspu/compiler/tests/interpret/not.mlir index 0fdf44e4..721e1850 100644 --- a/libspu/compiler/tests/interpret/not.mlir +++ b/libspu/compiler/tests/interpret/not.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @not_op_test_i8_i8_p() { diff --git a/libspu/compiler/tests/interpret/not_equal.mlir b/libspu/compiler/tests/interpret/not_equal.mlir index 1bbbd5b3..cb598070 100644 --- a/libspu/compiler/tests/interpret/not_equal.mlir +++ b/libspu/compiler/tests/interpret/not_equal.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @not_equal_op_test_i64_i1_pp() { diff --git a/libspu/compiler/tests/interpret/or.mlir b/libspu/compiler/tests/interpret/or.mlir index 79836813..1eef975f 100644 --- a/libspu/compiler/tests/interpret/or.mlir +++ b/libspu/compiler/tests/interpret/or.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @or_op_test_i8_i8_pp() { diff --git a/libspu/compiler/tests/interpret/pad.mlir b/libspu/compiler/tests/interpret/pad.mlir index abedc73a..521313e2 100644 --- a/libspu/compiler/tests/interpret/pad.mlir +++ b/libspu/compiler/tests/interpret/pad.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s func.func @pad() { %operand = pphlo.constant dense<[[0, 0, 0, 0], diff --git a/libspu/compiler/tests/interpret/popcnt.mlir b/libspu/compiler/tests/interpret/popcnt.mlir index e5f83152..efea9133 100644 --- a/libspu/compiler/tests/interpret/popcnt.mlir +++ b/libspu/compiler/tests/interpret/popcnt.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @popcnt_op_test_i64_i64_p() { diff --git a/libspu/compiler/tests/interpret/power.mlir b/libspu/compiler/tests/interpret/power.mlir index 307fce4c..b6bfc475 100644 --- a/libspu/compiler/tests/interpret/power.mlir +++ b/libspu/compiler/tests/interpret/power.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @power_op_test_i64_i64_pp() { diff --git a/libspu/compiler/tests/interpret/reciprocal.mlir b/libspu/compiler/tests/interpret/reciprocal.mlir new file mode 100644 index 00000000..fe1623a5 --- /dev/null +++ b/libspu/compiler/tests/interpret/reciprocal.mlir @@ -0,0 +1,26 @@ +// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s +// AUTO GENERATED, DO NOT EDIT + +func.func @reciprocal_op_test_f64_f64_p() { + %0 = pphlo.constant dense<[[1.0, -200.0], [100.0, 286991.875]]> : tensor<2x2xf64> + %1 = pphlo.reciprocal %0 : (tensor<2x2xf64>)->tensor<2x2xf64> + %2 = pphlo.constant dense<[[1.0, -0.005], [0.01, 0.0]]> : tensor<2x2xf64> + pphlo.custom_call @expect_almost_eq(%1, %2) { tol = 0.01 }: (tensor<2x2xf64>, tensor<2x2xf64>)->() + func.return +} + +// ----- + +func.func @reciprocal_op_test_f64_f64_s() { + %0 = pphlo.constant dense<[[1.0, -200.0], [100.0, 286991.875]]> : tensor<2x2xf64> + %1 = pphlo.convert %0 : (tensor<2x2xf64>)->tensor<2x2x!pphlo.secret> + %2 = pphlo.reciprocal %1 : (tensor<2x2x!pphlo.secret>)->tensor<2x2x!pphlo.secret> + %3 = pphlo.constant dense<[[1.0, -0.005], [0.01, 0.0]]> : tensor<2x2xf64> + %4 = pphlo.convert %2 : (tensor<2x2x!pphlo.secret>)->tensor<2x2xf64> + pphlo.custom_call @expect_almost_eq(%3, %4) { tol = 0.01 }: (tensor<2x2xf64>, tensor<2x2xf64>)->() + func.return +} diff --git a/libspu/compiler/tests/interpret/reduce.mlir b/libspu/compiler/tests/interpret/reduce.mlir index 3b34bf3e..efc1d40d 100644 --- a/libspu/compiler/tests/interpret/reduce.mlir +++ b/libspu/compiler/tests/interpret/reduce.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s func.func @reduce() { %input = pphlo.constant dense<[[0, 1, 2, 3, 4, 5]]> : tensor<1x6xi64> diff --git a/libspu/compiler/tests/interpret/reduce_window.mlir b/libspu/compiler/tests/interpret/reduce_window.mlir index 5385ab04..d15cd021 100644 --- a/libspu/compiler/tests/interpret/reduce_window.mlir +++ b/libspu/compiler/tests/interpret/reduce_window.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s func.func @reduce_window() { %input = pphlo.constant dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi64> diff --git a/libspu/compiler/tests/interpret/reshape.mlir b/libspu/compiler/tests/interpret/reshape.mlir index 9e483b77..0a670923 100644 --- a/libspu/compiler/tests/interpret/reshape.mlir +++ b/libspu/compiler/tests/interpret/reshape.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @reshape_op_test_i32_i32_p() { diff --git a/libspu/compiler/tests/interpret/reverse.mlir b/libspu/compiler/tests/interpret/reverse.mlir index 63ab9590..832262e2 100644 --- a/libspu/compiler/tests/interpret/reverse.mlir +++ b/libspu/compiler/tests/interpret/reverse.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s func.func @reverse() { %operand = pphlo.constant dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi64> diff --git a/libspu/compiler/tests/interpret/ring_cast.mlir b/libspu/compiler/tests/interpret/ring_cast.mlir index 94b53241..a6b6806c 100644 --- a/libspu/compiler/tests/interpret/ring_cast.mlir +++ b/libspu/compiler/tests/interpret/ring_cast.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s func.func @cast_1() { %c0 = pphlo.constant dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi32> diff --git a/libspu/compiler/tests/interpret/round_nearest_afz.mlir b/libspu/compiler/tests/interpret/round_nearest_afz.mlir index 6601fa1d..40e64ebe 100644 --- a/libspu/compiler/tests/interpret/round_nearest_afz.mlir +++ b/libspu/compiler/tests/interpret/round_nearest_afz.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @round_nearest_afz_op_test_f64_f64_p() { diff --git a/libspu/compiler/tests/interpret/rsqrt.mlir b/libspu/compiler/tests/interpret/rsqrt.mlir index ef6470c9..5e5e0869 100644 --- a/libspu/compiler/tests/interpret/rsqrt.mlir +++ b/libspu/compiler/tests/interpret/rsqrt.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @rsqrt_op_test_f64_f64_p() { diff --git a/libspu/compiler/tests/interpret/select.mlir b/libspu/compiler/tests/interpret/select.mlir index 33e9a9c7..c2752761 100644 --- a/libspu/compiler/tests/interpret/select.mlir +++ b/libspu/compiler/tests/interpret/select.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s func.func @select_op_test_si64() { %pred = pphlo.constant dense<[true, false, true]> : tensor<3xi1> diff --git a/libspu/compiler/tests/interpret/select_and_scatter.mlir b/libspu/compiler/tests/interpret/select_and_scatter.mlir index 5e55dc59..c7f0057a 100644 --- a/libspu/compiler/tests/interpret/select_and_scatter.mlir +++ b/libspu/compiler/tests/interpret/select_and_scatter.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // FIXME func.func @select_and_scatter_op_test() { diff --git a/libspu/compiler/tests/interpret/shift_right_arithmetic.mlir b/libspu/compiler/tests/interpret/shift_right_arithmetic.mlir index 494666ae..ce4bedbc 100644 --- a/libspu/compiler/tests/interpret/shift_right_arithmetic.mlir +++ b/libspu/compiler/tests/interpret/shift_right_arithmetic.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @shift_right_arithmetic_op_test_i64_i64_pp() { diff --git a/libspu/compiler/tests/interpret/shift_right_logical.mlir b/libspu/compiler/tests/interpret/shift_right_logical.mlir index 253d6cc3..73d69f0a 100644 --- a/libspu/compiler/tests/interpret/shift_right_logical.mlir +++ b/libspu/compiler/tests/interpret/shift_right_logical.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @shift_right_logical_op_test_i64_i64_pp() { diff --git a/libspu/compiler/tests/interpret/sign.mlir b/libspu/compiler/tests/interpret/sign.mlir index b153a373..105c5cb2 100644 --- a/libspu/compiler/tests/interpret/sign.mlir +++ b/libspu/compiler/tests/interpret/sign.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @sign_op_test_i64_i64_p() { diff --git a/libspu/compiler/tests/interpret/sine.mlir b/libspu/compiler/tests/interpret/sine.mlir index f0b59e5c..1d62529d 100644 --- a/libspu/compiler/tests/interpret/sine.mlir +++ b/libspu/compiler/tests/interpret/sine.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @sine_op_test_f16_f16_p() { diff --git a/libspu/compiler/tests/interpret/slice.mlir b/libspu/compiler/tests/interpret/slice.mlir index b4b8bdbe..583b977b 100644 --- a/libspu/compiler/tests/interpret/slice.mlir +++ b/libspu/compiler/tests/interpret/slice.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s func.func @slice_op() { %operand = pphlo.constant dense<[[0, 0, 1, 0, 0, 1], diff --git a/libspu/compiler/tests/interpret/sort.mlir b/libspu/compiler/tests/interpret/sort.mlir index 837ded37..2433d4ae 100644 --- a/libspu/compiler/tests/interpret/sort.mlir +++ b/libspu/compiler/tests/interpret/sort.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s func.func @sort_stable() { %input0 = pphlo.constant dense<[[1, 2, 3], [3, 2, 1]]> : tensor<2x3xi64> diff --git a/libspu/compiler/tests/interpret/sqrt.mlir b/libspu/compiler/tests/interpret/sqrt.mlir index 9248077a..59372c6b 100644 --- a/libspu/compiler/tests/interpret/sqrt.mlir +++ b/libspu/compiler/tests/interpret/sqrt.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @sqrt_op_test_f64_f64_p() { diff --git a/libspu/compiler/tests/interpret/subtract.mlir b/libspu/compiler/tests/interpret/subtract.mlir index ce8b9633..37032882 100644 --- a/libspu/compiler/tests/interpret/subtract.mlir +++ b/libspu/compiler/tests/interpret/subtract.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @subtract_op_test_i8_i8_pp() { diff --git a/libspu/compiler/tests/interpret/tanh.mlir b/libspu/compiler/tests/interpret/tanh.mlir index 413dc6d2..ab45fa2e 100644 --- a/libspu/compiler/tests/interpret/tanh.mlir +++ b/libspu/compiler/tests/interpret/tanh.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @tanh_op_test_f16_f16_p() { diff --git a/libspu/compiler/tests/interpret/test_json/reciprocal.json b/libspu/compiler/tests/interpret/test_json/reciprocal.json new file mode 100644 index 00000000..d8e07529 --- /dev/null +++ b/libspu/compiler/tests/interpret/test_json/reciprocal.json @@ -0,0 +1,24 @@ +{ + "name": "reciprocal", + "template": "basic_unary", + "testcases": [ + { + "inputs": [ + { + "data": "[[1.0, -200.0], [100.0, 286991.875]]", + "shape": "2x2", + "dtype": "f64" + } + ], + "expected": [ + { + "data": "[[1.0, -0.005], [0.01, 0.0]]", + "shape": "2x2", + "dtype": "f64" + } + ], + "checker": "expect_almost_eq", + "tol": 0.01 + } + ] +} \ No newline at end of file diff --git a/libspu/compiler/tests/interpret/transpose.mlir b/libspu/compiler/tests/interpret/transpose.mlir index d5ce9b6e..d36883de 100644 --- a/libspu/compiler/tests/interpret/transpose.mlir +++ b/libspu/compiler/tests/interpret/transpose.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s func.func @transpose_op_test_si32() { %0 = pphlo.constant dense<[[[1,2],[3,4],[5,6]], [[7,8],[9,10],[11,12]]]> : tensor<2x3x2xi32> diff --git a/libspu/compiler/tests/interpret/while.mlir b/libspu/compiler/tests/interpret/while.mlir index 32789fae..734e199e 100644 --- a/libspu/compiler/tests/interpret/while.mlir +++ b/libspu/compiler/tests/interpret/while.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s func.func @while() { // int i = 0; diff --git a/libspu/compiler/tests/interpret/xor.mlir b/libspu/compiler/tests/interpret/xor.mlir index c2e1f1ee..a9ee8348 100644 --- a/libspu/compiler/tests/interpret/xor.mlir +++ b/libspu/compiler/tests/interpret/xor.mlir @@ -1,6 +1,8 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s +// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT func.func @xor_op_test_i8_i8_pp() { diff --git a/libspu/compiler/tests/passes/hlo2pphlo/select_and_scatter.mlir b/libspu/compiler/tests/passes/hlo2pphlo/select_and_scatter.mlir index ac7492bc..bc21172f 100644 --- a/libspu/compiler/tests/passes/hlo2pphlo/select_and_scatter.mlir +++ b/libspu/compiler/tests/passes/hlo2pphlo/select_and_scatter.mlir @@ -1,7 +1,7 @@ // RUN: spu-opt --hlo-legalize-to-pphlo=input_vis_list=VIS_SECRET,VIS_PUBLIC,VIS_PUBLIC --lower-conversion-cast --split-input-file %s | FileCheck %s func.func @main(%arg0: tensor<128x5x5x32xf32>, %arg1: tensor<128x4x4x32xf32>, %arg2: tensor) -> tensor<128x5x5x32xf32> { - // CHECK: %1 = "pphlo.select_and_scatter"(%arg0, %arg1, %0) ({ + // CHECK: %1 = "pphlo.select_and_scatter"(%arg0, %arg1, %0) <{window_dimensions = array, window_strides = array}> ({ // CHECK: ^bb0(%arg3: tensor>, %arg4: tensor>): // CHECK: %2 = pphlo.greater_equal %arg3, %arg4 : (tensor>, tensor>) -> tensor> // CHECK: pphlo.return %2 : tensor> @@ -9,7 +9,7 @@ func.func @main(%arg0: tensor<128x5x5x32xf32>, %arg1: tensor<128x4x4x32xf32>, %a // CHECK: ^bb0(%arg3: tensor, %arg4: tensor>): // CHECK: %2 = pphlo.add %arg3, %arg4 : (tensor, tensor>) -> tensor> // CHECK: pphlo.return %2 : tensor> - // CHECK: }) {window_dimensions = array, window_strides = array} : (tensor<128x5x5x32x!pphlo.secret>, tensor<128x4x4x32xf32>, tensor>) -> tensor<128x5x5x32x!pphlo.secret> + // CHECK: }) : (tensor<128x5x5x32x!pphlo.secret>, tensor<128x4x4x32xf32>, tensor>) -> tensor<128x5x5x32x!pphlo.secret> %0 = "stablehlo.select_and_scatter"(%arg0, %arg1, %arg2) ({ ^bb0(%arg3: tensor, %arg4: tensor): %1 = "stablehlo.compare"(%arg3, %arg4) {comparison_direction = #stablehlo} : (tensor, tensor) -> tensor diff --git a/libspu/compiler/tests/passes/hlo2pphlo/sort_p.mlir b/libspu/compiler/tests/passes/hlo2pphlo/sort_p.mlir index 2224ee7d..facd338f 100644 --- a/libspu/compiler/tests/passes/hlo2pphlo/sort_p.mlir +++ b/libspu/compiler/tests/passes/hlo2pphlo/sort_p.mlir @@ -2,11 +2,11 @@ func.func @main(%arg0: tensor<20xi32>) -> (tensor<20xi32>) { %0 = stablehlo.iota dim = 0 : tensor<20xi32> - // CHECK: %1:2 = "pphlo.sort"(%arg0, %0) ({ + // CHECK: %1:2 = "pphlo.sort"(%arg0, %0) <{dimension = 0 : i64, is_stable = true}> ({ // CHECK: ^bb0(%arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor): // CHECK: %2 = pphlo.less %arg1, %arg2 : (tensor, tensor) -> tensor // CHECK: pphlo.return %2 : tensor - // CHECK: }) {dimension = 0 : i64, is_stable = true} : (tensor<20xi32>, tensor<20xi32>) -> (tensor<20xi32>, tensor<20xi32>) + // CHECK: }) : (tensor<20xi32>, tensor<20xi32>) -> (tensor<20xi32>, tensor<20xi32>) %1:2 = "stablehlo.sort"(%arg0, %0) ({ ^bb0(%arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor): %2 = stablehlo.compare LT, %arg1, %arg2 : (tensor, tensor) -> tensor diff --git a/libspu/compiler/tests/passes/hlo2pphlo/sort_s.mlir b/libspu/compiler/tests/passes/hlo2pphlo/sort_s.mlir index c4d4bd41..cf7ef73d 100644 --- a/libspu/compiler/tests/passes/hlo2pphlo/sort_s.mlir +++ b/libspu/compiler/tests/passes/hlo2pphlo/sort_s.mlir @@ -2,11 +2,11 @@ func.func @main(%arg0: tensor<20xi32>) -> (tensor<20xi32>) { %0 = stablehlo.iota dim = 0 : tensor<20xi32> - // CHECK: %1:2 = "pphlo.sort"(%arg0, %0) ({ + // CHECK: %1:2 = "pphlo.sort"(%arg0, %0) <{dimension = 0 : i64, is_stable = true}> ({ // CHECK: ^bb0(%arg1: tensor>, %arg2: tensor>, %arg3: tensor, %arg4: tensor): // CHECK: %2 = pphlo.less %arg1, %arg2 : (tensor>, tensor>) -> tensor> // CHECK: pphlo.return %2 : tensor> - // CHECK: }) {dimension = 0 : i64, is_stable = true} : (tensor<20x!pphlo.secret>, tensor<20xi32>) -> (tensor<20x!pphlo.secret>, tensor<20x!pphlo.secret>) + // CHECK: }) : (tensor<20x!pphlo.secret>, tensor<20xi32>) -> (tensor<20x!pphlo.secret>, tensor<20x!pphlo.secret>) %1:2 = "stablehlo.sort"(%arg0, %0) ({ ^bb0(%arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor): %2 = stablehlo.compare LT, %arg1, %arg2 : (tensor, tensor) -> tensor diff --git a/libspu/compiler/tools/spu-lsp.cc b/libspu/compiler/tools/spu-lsp.cc index 0b9ddbd6..9f59509d 100644 --- a/libspu/compiler/tools/spu-lsp.cc +++ b/libspu/compiler/tools/spu-lsp.cc @@ -25,5 +25,6 @@ int main(int argc, char **argv) { registry.insert(); mlir::func::registerInlinerExtension(registry); - return mlir::failed(mlir::MlirLspServerMain(argc, argv, registry)); + return static_cast( + mlir::failed(mlir::MlirLspServerMain(argc, argv, registry))); } diff --git a/libspu/compiler/tools/spu-translate.cc b/libspu/compiler/tools/spu-translate.cc index 423321a5..68857fa0 100644 --- a/libspu/compiler/tools/spu-translate.cc +++ b/libspu/compiler/tools/spu-translate.cc @@ -22,12 +22,10 @@ #include "mlir/Support/LogicalResult.h" #include "mlir/Tools/mlir-translate/MlirTranslateMain.h" #include "mlir/Tools/mlir-translate/Translation.h" -#include "mlir/Transforms/Passes.h" #include "stablehlo/dialect/StablehloOps.h" #include "xtensor/xio.hpp" #include "libspu/compiler/common/compilation_context.h" -#include "libspu/compiler/utils/utils.h" #include "libspu/core/prelude.h" #include "libspu/device/pphlo/pphlo_executor.h" #include "libspu/dialect/pphlo/IR/dialect.h" @@ -89,7 +87,7 @@ bool testOpHandler(::spu::SPUContext *sctx, mlir::Operation *op, auto callOp = mlir::dyn_cast(op); if (callOp.getCallTargetName() == "expect_almost_eq") { ::spu::Value runtimeLhs = inputs[0]; - ::spu::Value runtimeRhs = inputs[1]; + const ::spu::Value &runtimeRhs = inputs[1]; if (!runtimeLhs.isPublic()) { runtimeLhs = ::spu::kernel::hal::_s2p(sctx, runtimeLhs) @@ -123,7 +121,7 @@ bool testOpHandler(::spu::SPUContext *sctx, mlir::Operation *op, if (callOp.getCallTargetName() == "expect_eq") { ::spu::Value runtimeLhs = inputs[0]; - ::spu::Value runtimeRhs = inputs[1]; + const ::spu::Value &runtimeRhs = inputs[1]; if (!runtimeLhs.isPublic()) { runtimeLhs = ::spu::kernel::hal::_s2p(sctx, runtimeLhs) @@ -239,6 +237,16 @@ void evalModule(ModuleOp module) { numParties = 3; break; } + case 4: { + conf.set_protocol(::spu::CHEETAH); + numParties = 2; + break; + } + case 5: { + conf.set_protocol(::spu::SECURENN); + numParties = 3; + break; + } } SPDLOG_INFO(conf.DebugString()); @@ -278,6 +286,6 @@ TranslateFromMLIRRegistration interpretRegistration( } // namespace mlir int main(int argc, char **argv) { - return failed( - mlir::mlirTranslateMain(argc, argv, "SPU interpreter driver\n")); + return static_cast( + failed(mlir::mlirTranslateMain(argc, argv, "SPU interpreter driver\n"))); } diff --git a/libspu/core/encoding.cc b/libspu/core/encoding.cc index eb26ce9b..98a17a1a 100644 --- a/libspu/core/encoding.cc +++ b/libspu/core/encoding.cc @@ -60,8 +60,8 @@ NdArrayRef encodeToRing(const PtBufferView& bv, FieldType field, } if (pt_type == PT_F32 || pt_type == PT_F64 || pt_type == PT_F16) { - DISPATCH_FLOAT_PT_TYPES(pt_type, "_", [&]() { - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_FLOAT_PT_TYPES(pt_type, [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using Float = ScalarT; using T = std::make_signed_t; @@ -100,8 +100,8 @@ NdArrayRef encodeToRing(const PtBufferView& bv, FieldType field, return dst; } else { // handle integer & boolean - DISPATCH_INT_PT_TYPES(pt_type, "_", [&]() { - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_INT_PT_TYPES(pt_type, [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using Integer = ScalarT; SPU_ENFORCE(sizeof(ring2k_t) >= sizeof(Integer), "integer encoding failed, ring={} could not represent {}", @@ -138,8 +138,8 @@ void decodeFromRing(const NdArrayRef& src, DataType in_dtype, size_t fxp_bits, *out_pt_type = pt_type; } - DISPATCH_ALL_FIELDS(field, "field", [&]() { - DISPATCH_ALL_PT_TYPES(pt_type, "pt_type", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { + DISPATCH_ALL_PT_TYPES(pt_type, [&]() { using T = std::make_signed_t; auto _src = NdArrayView(src); diff --git a/libspu/core/pt_buffer_view.cc b/libspu/core/pt_buffer_view.cc index 50e6f891..3f7e1084 100644 --- a/libspu/core/pt_buffer_view.cc +++ b/libspu/core/pt_buffer_view.cc @@ -50,7 +50,7 @@ NdArrayRef convertToNdArray(PtBufferView bv) { } const auto type = makePtType(bv.pt_type); auto out = NdArrayRef(type, bv.shape); - return DISPATCH_ALL_PT_TYPES(bv.pt_type, "pt_type", [&]() { + return DISPATCH_ALL_PT_TYPES(bv.pt_type, [&]() { using T = ScalarT; if (bv.shape.numel() > 0) { auto* out_ptr = out.data(); diff --git a/libspu/core/type_util.h b/libspu/core/type_util.h index 2ac1d294..4e70ea7e 100644 --- a/libspu/core/type_util.h +++ b/libspu/core/type_util.h @@ -97,91 +97,90 @@ std::ostream& operator<<(std::ostream& os, const DataType& dtype); // Helper macros to enumerate all py types. // NOLINTNEXTLINE: Global internal used macro. -#define __CASE_PT_TYPE(PT_TYPE, NAME, ...) \ - case (PT_TYPE): { \ - [[maybe_unused]] constexpr std::string_view _kName = NAME; \ - using ScalarT = EnumToPtType::type; \ - return __VA_ARGS__(); \ +#define __CASE_PT_TYPE(PT_TYPE, ...) \ + case (PT_TYPE): { \ + using ScalarT = EnumToPtType::type; \ + return __VA_ARGS__(); \ } -#define DISPATCH_FLOAT_PT_TYPES(PT_TYPE, NAME, ...) \ - [&] { \ - switch (PT_TYPE) { \ - __CASE_PT_TYPE(spu::PT_F16, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_F32, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_F64, NAME, __VA_ARGS__) \ - default: \ - SPU_THROW("{} not implemented for pt_type={}", #NAME, PT_TYPE); \ - } \ +#define DISPATCH_FLOAT_PT_TYPES(PT_TYPE, ...) \ + [&] { \ + switch (PT_TYPE) { \ + __CASE_PT_TYPE(spu::PT_F16, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_F32, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_F64, __VA_ARGS__) \ + default: \ + SPU_THROW("unimplemented for pt_type={}", PT_TYPE); \ + } \ }() -#define DISPATCH_UINT_PT_TYPES(PT_TYPE, NAME, ...) \ - [&] { \ - switch (PT_TYPE) { \ - __CASE_PT_TYPE(spu::PT_U8, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_U16, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_U32, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_U64, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_U128, NAME, __VA_ARGS__) \ - default: \ - SPU_THROW("{} not implemented for pt_type={}", #NAME, PT_TYPE); \ - } \ +#define DISPATCH_UINT_PT_TYPES(PT_TYPE, ...) \ + [&] { \ + switch (PT_TYPE) { \ + __CASE_PT_TYPE(spu::PT_U8, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_U16, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_U32, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_U64, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_U128, __VA_ARGS__) \ + default: \ + SPU_THROW("unimplemented for pt_type={}", PT_TYPE); \ + } \ }() -#define DISPATCH_INT_PT_TYPES(PT_TYPE, NAME, ...) \ - [&] { \ - switch (PT_TYPE) { \ - __CASE_PT_TYPE(spu::PT_I1, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_I8, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_U8, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_I16, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_U16, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_I32, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_U32, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_I64, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_U64, NAME, __VA_ARGS__) \ - default: \ - SPU_THROW("{} not implemented for pt_type={}", #NAME, PT_TYPE); \ - } \ +#define DISPATCH_INT_PT_TYPES(PT_TYPE, ...) \ + [&] { \ + switch (PT_TYPE) { \ + __CASE_PT_TYPE(spu::PT_I1, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_I8, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_U8, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_I16, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_U16, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_I32, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_U32, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_I64, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_U64, __VA_ARGS__) \ + default: \ + SPU_THROW("unimplemented for pt_type={}", PT_TYPE); \ + } \ }() -#define DISPATCH_ALL_PT_TYPES(PT_TYPE, NAME, ...) \ - [&] { \ - switch (PT_TYPE) { \ - __CASE_PT_TYPE(spu::PT_I1, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_I8, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_U8, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_I16, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_U16, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_I32, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_U32, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_I64, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_U64, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_F16, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_F32, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_F64, NAME, __VA_ARGS__) \ - default: \ - SPU_THROW("{} not implemented for pt_type={}", #NAME, PT_TYPE); \ - } \ +#define DISPATCH_ALL_PT_TYPES(PT_TYPE, ...) \ + [&] { \ + switch (PT_TYPE) { \ + __CASE_PT_TYPE(spu::PT_I1, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_I8, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_U8, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_I16, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_U16, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_I32, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_U32, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_I64, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_U64, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_F16, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_F32, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_F64, __VA_ARGS__) \ + default: \ + SPU_THROW("unimplemented for pt_type={}", PT_TYPE); \ + } \ }() -#define DISPATCH_ALL_NONE_BOOL_PT_TYPES(PT_TYPE, NAME, ...) \ - [&] { \ - switch (PT_TYPE) { \ - __CASE_PT_TYPE(spu::PT_I8, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_U8, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_I16, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_U16, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_I32, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_U32, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_I64, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_U64, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_F16, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_F32, NAME, __VA_ARGS__) \ - __CASE_PT_TYPE(spu::PT_F64, NAME, __VA_ARGS__) \ - default: \ - SPU_THROW("{} not implemented for pt_type={}", #NAME, PT_TYPE); \ - } \ +#define DISPATCH_ALL_NONE_BOOL_PT_TYPES(PT_TYPE, ...) \ + [&] { \ + switch (PT_TYPE) { \ + __CASE_PT_TYPE(spu::PT_I8, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_U8, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_I16, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_U16, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_I32, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_U32, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_I64, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_U64, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_F16, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_F32, __VA_ARGS__) \ + __CASE_PT_TYPE(spu::PT_F64, __VA_ARGS__) \ + default: \ + SPU_THROW("unimplemented for pt_type={}", PT_TYPE); \ + } \ }() std::ostream& operator<<(std::ostream& os, const PtType& pt_type); @@ -241,24 +240,23 @@ inline size_t SizeOf(FieldType field) { return SizeOf(GetStorageType(field)); } // Helper macros to enumerate all fields // NOLINTNEXTLINE: Global internal used macro. -#define __CASE_FIELD(FIELD, NAME, ...) \ +#define __CASE_FIELD(FIELD, ...) \ case (FIELD): { \ /* inject `_kField` & `_kName` for the continuation call */ \ [[maybe_unused]] constexpr spu::FieldType _kField = FIELD; \ - [[maybe_unused]] constexpr std::string_view _kName = NAME; \ using ring2k_t [[maybe_unused]] = Ring2kTrait<_kField>::scalar_t; \ return __VA_ARGS__(); \ } -#define DISPATCH_ALL_FIELDS(FIELD, NAME, ...) \ - [&] { \ - switch (FIELD) { \ - __CASE_FIELD(spu::FieldType::FM32, NAME, __VA_ARGS__) \ - __CASE_FIELD(spu::FieldType::FM64, NAME, __VA_ARGS__) \ - __CASE_FIELD(spu::FieldType::FM128, NAME, __VA_ARGS__) \ - default: \ - SPU_THROW("{} not implemented for field={}", #NAME, FIELD); \ - } \ +#define DISPATCH_ALL_FIELDS(FIELD, ...) \ + [&] { \ + switch (FIELD) { \ + __CASE_FIELD(spu::FieldType::FM32, __VA_ARGS__) \ + __CASE_FIELD(spu::FieldType::FM64, __VA_ARGS__) \ + __CASE_FIELD(spu::FieldType::FM128, __VA_ARGS__) \ + default: \ + SPU_THROW("unimplemented for field={}", FIELD); \ + } \ }() ////////////////////////////////////////////////////////////// diff --git a/libspu/dialect/pphlo/IR/dialect.td b/libspu/dialect/pphlo/IR/dialect.td index de2c648c..7d47197c 100644 --- a/libspu/dialect/pphlo/IR/dialect.td +++ b/libspu/dialect/pphlo/IR/dialect.td @@ -32,15 +32,14 @@ def PPHlo_Dialect : Dialect { string summary = "Privacy-Preserving HLO(PPHLO) dialect"; string description = [{ PPHLO represents a high level abstraction for language use by SPU. - It implements a subset of mlir mhlo ops with it's own privacy-preserving focused type system. + It implements a subset of mlir stablehlo ops with it's own privacy-preserving focused type system. - Learn more about mlir hlo at https://github.com/tensorflow/mlir-hlo + Learn more about mlir stablehlo at https://github.com/openxla/stablehlo }]; let name = "pphlo"; let cppNamespace = "::mlir::spu::pphlo"; let useDefaultAttributePrinterParser = 0; let useDefaultTypePrinterParser = 0; - let usePropertiesForAttributes = 0; let hasConstantMaterializer = 1; let extraClassDeclaration = [{ Attribute parseAttribute(DialectAsmParser & parser, Type type) diff --git a/libspu/dialect/pphlo/IR/type_inference.cc b/libspu/dialect/pphlo/IR/type_inference.cc index 4e3a9464..6c95aac3 100644 --- a/libspu/dialect/pphlo/IR/type_inference.cc +++ b/libspu/dialect/pphlo/IR/type_inference.cc @@ -285,7 +285,7 @@ LogicalResult PadOp::inferReturnTypes( ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions, ::llvm::SmallVectorImpl<::mlir::Type>& inferredReturnTypes) { - PadOp::Adaptor adaptor(operands, attributes, {}, regions); + PadOp::Adaptor adaptor(operands, attributes, properties, regions); return hlo::inferPadOp(location, adaptor.getOperand().getType(), adaptor.getPaddingValue().getType(), adaptor.getEdgePaddingLow(), @@ -295,27 +295,27 @@ LogicalResult PadOp::inferReturnTypes( LogicalResult ConcatenateOp::inferReturnTypes( MLIRContext*, std::optional location, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties, RegionRange regions, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl& inferred_return_types) { - ConcatenateOp::Adaptor adaptor(operands, attributes, {}, regions); + ConcatenateOp::Adaptor adaptor(operands, attributes, properties, regions); return hlo::inferConcatenateOp(location, adaptor.getInputs().getTypes(), adaptor.getDimension(), inferred_return_types); } LogicalResult TransposeOp::inferReturnTypes( MLIRContext*, std::optional location, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties, RegionRange regions, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl& inferred_return_types) { - TransposeOp::Adaptor adaptor(operands, attributes, {}, regions); + TransposeOp::Adaptor adaptor(operands, attributes, properties, regions); return hlo::inferTransposeOp(location, adaptor.getOperand(), adaptor.getPermutation(), inferred_return_types); } LogicalResult SliceOp::inferReturnTypes( MLIRContext*, std::optional location, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties, RegionRange regions, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl& inferred_return_types) { - SliceOp::Adaptor adaptor(operands, attributes, {}, regions); + SliceOp::Adaptor adaptor(operands, attributes, properties, regions); return hlo::inferSliceOp(location, adaptor.getOperand().getType(), adaptor.getStartIndices(), adaptor.getLimitIndices(), adaptor.getStrides(), inferred_return_types); @@ -375,9 +375,9 @@ LogicalResult inferDynamicSliceOp(std::optional location, LogicalResult DynamicSliceOp::inferReturnTypes( MLIRContext*, std::optional location, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties, RegionRange regions, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl& inferredReturnTypes) { - DynamicSliceOp::Adaptor adaptor(operands, attributes, {}, regions); + DynamicSliceOp::Adaptor adaptor(operands, attributes, properties, regions); return inferDynamicSliceOp(location, adaptor.getOperand().getType(), adaptor.getStartIndices().getTypes(), adaptor.getSliceSizes(), inferredReturnTypes); @@ -427,9 +427,10 @@ LogicalResult inferDynamicUpdateSliceOp( LogicalResult DynamicUpdateSliceOp::inferReturnTypes( MLIRContext*, std::optional location, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties, RegionRange regions, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl& inferredReturnTypes) { - DynamicUpdateSliceOp::Adaptor adaptor(operands, attributes, {}, regions); + DynamicUpdateSliceOp::Adaptor adaptor(operands, attributes, properties, + regions); return inferDynamicUpdateSliceOp( location, adaptor.getOperand(), adaptor.getUpdate(), diff --git a/libspu/dialect/pphlo/transforms/hlo_legalize_to_pphlo.cc b/libspu/dialect/pphlo/transforms/hlo_legalize_to_pphlo.cc index e17f85e5..781e0940 100644 --- a/libspu/dialect/pphlo/transforms/hlo_legalize_to_pphlo.cc +++ b/libspu/dialect/pphlo/transforms/hlo_legalize_to_pphlo.cc @@ -133,24 +133,20 @@ class FuncOpConverter : public OpConversionPattern<::mlir::func::FuncOp> { auto ®ion = op.getBody(); // Convert non-entry blocks - SmallVector conversions; - for (Block &block : llvm::drop_begin(region, 1)) { - conversions.emplace_back(block.getNumArguments()); - TypeConverter::SignatureConversion &back = conversions.back(); + for (Block &block : + llvm::make_early_inc_range(llvm::drop_begin(region, 1))) { + TypeConverter::SignatureConversion conversion( + /*numOrigInputs=*/block.getNumArguments()); for (BlockArgument blockArgument : block.getArguments()) { auto idx = blockArgument.getArgNumber(); auto vis_v = vis_.getValueVisibility(blockArgument); auto convertedType = tools_.getType( typeConverter->convertType(blockArgument.getType()), vis_v); - back.addInputs(idx, convertedType); + conversion.addInputs(idx, convertedType); } - } - if (failed(rewriter.convertNonEntryRegionTypes(®ion, *typeConverter, - conversions))) { - rewriter.cancelOpModification(op); - return failure(); + rewriter.applySignatureConversion(&block, conversion, getTypeConverter()); } // Convert function arguments using the provided TypeConverter. diff --git a/libspu/kernel/hal/constants.cc b/libspu/kernel/hal/constants.cc index 1b44cbd0..8659c5e3 100644 --- a/libspu/kernel/hal/constants.cc +++ b/libspu/kernel/hal/constants.cc @@ -103,7 +103,7 @@ spu::Value zeros(SPUContext* ctx, DataType dtype, const Shape& shape) { } Value iota(SPUContext* ctx, DataType dtype, int64_t numel) { - return DISPATCH_ALL_NONE_BOOL_PT_TYPES(getDecodeType(dtype), "iota", [&]() { + return DISPATCH_ALL_NONE_BOOL_PT_TYPES(getDecodeType(dtype), [&]() { std::vector arr(numel); std::iota(arr.begin(), arr.end(), 0); return constant(ctx, arr, dtype, {numel}); diff --git a/libspu/kernel/hal/fxp_approx.cc b/libspu/kernel/hal/fxp_approx.cc index 5cda83df..0f9864fd 100644 --- a/libspu/kernel/hal/fxp_approx.cc +++ b/libspu/kernel/hal/fxp_approx.cc @@ -69,12 +69,12 @@ Value log_minmax(SPUContext* ctx, const Value& x) { // get most significant non-zero bit of x // we avoid direct using detail::highestOneBit for saving one _prefix_or - auto pre_x1 = _rshift(ctx, pre_x, 1); + auto pre_x1 = _rshift(ctx, pre_x, {1}); auto msb = _xor(ctx, pre_x, pre_x1); // let x = x_norm * factor, where x in [1.0, 2.0) auto factor = _bitrev(ctx, msb, 0, 2 * num_fxp_bits + 1).setDtype(x.dtype()); - detail::hintNumberOfBits(factor, 2 * num_fxp_bits + 1); + factor = maskNumberOfBits(ctx, factor, 2 * num_fxp_bits + 1); auto norm = f_mul(ctx, x, factor); // log(x) = log(x_norm * factor) @@ -83,7 +83,7 @@ Value log_minmax(SPUContext* ctx, const Value& x) { auto log_norm = log_minmax_normalized(ctx, norm); auto log2_e = _lshift(ctx, _sub(ctx, k, _constant(ctx, num_fxp_bits + 1, x.shape())), - num_fxp_bits) + {static_cast(num_fxp_bits)}) .setDtype(x.dtype()); auto k_log2 = constant(ctx, std::log(2), x.dtype(), x.shape()); auto log_e = f_mul(ctx, log2_e, k_log2); @@ -145,7 +145,7 @@ Value log2_pade(SPUContext* ctx, const Value& x) { // let x = x_norm * factor, where x in [0.5, 1.0) auto msb = detail::highestOneBit(ctx, x); auto factor = _bitrev(ctx, msb, 0, 2 * num_fxp_bits).setDtype(x.dtype()); - detail::hintNumberOfBits(factor, 2 * num_fxp_bits); + factor = maskNumberOfBits(ctx, factor, 2 * num_fxp_bits); auto norm = f_mul(ctx, x, factor); // log2(x) = log2(x_norm * factor) @@ -154,7 +154,7 @@ Value log2_pade(SPUContext* ctx, const Value& x) { return _add( ctx, log2_pade_normalized(ctx, norm), _lshift(ctx, _sub(ctx, k, _constant(ctx, num_fxp_bits, x.shape())), - num_fxp_bits)) + {static_cast(num_fxp_bits)})) .setDtype(x.dtype()); } @@ -260,15 +260,18 @@ Value exp2_pade(SPUContext* ctx, const Value& x) { const size_t bit_width = SizeOf(ctx->getField()) * 8; const auto x_bshare = _prefer_b(ctx, x); - const auto x_msb = _rshift(ctx, x_bshare, bit_width - 1); - auto x_integer = _rshift(ctx, x_bshare, fbits); + const auto x_msb = + _rshift(ctx, x_bshare, {static_cast(bit_width - 1)}); + auto x_integer = _rshift(ctx, x_bshare, {static_cast(fbits)}); auto x_fraction = - _sub(ctx, x, _lshift(ctx, x_integer, fbits)).setDtype(x.dtype()); + _sub(ctx, x, _lshift(ctx, x_integer, {static_cast(fbits)})) + .setDtype(x.dtype()); auto ret = exp2_pade_normalized(ctx, x_fraction); for (size_t idx = 0; idx < int_bits; idx++) { - auto a = _and(ctx, _rshift(ctx, x_integer, idx), k1); - detail::hintNumberOfBits(a, 1); + auto a = + _and(ctx, _rshift(ctx, x_integer, {static_cast(idx)}), k1); + a = detail::maskNumberOfBits(ctx, a, 1); a = _prefer_a(ctx, a); const auto K = 1U << std::min(1UL << idx, bit_width - 2); ret = _mul(ctx, ret, @@ -543,7 +546,7 @@ static Value rsqrt_init_guess(SPUContext* ctx, const Value& x, const Value& z) { // let u in [0.25, 0.5) auto z_rev = _bitrev(ctx, z, 0, 2 * f); - detail::hintNumberOfBits(z_rev, 2 * f); + z_rev = detail::maskNumberOfBits(ctx, z_rev, 2 * f); auto u = _trunc(ctx, _mul(ctx, x, z_rev)).setDtype(x.dtype()); @@ -583,17 +586,17 @@ static Value rsqrt_comp(SPUContext* ctx, const Value& x, const Value& z) { auto lo_mask = _constant(ctx, (static_cast(1) << (k / 2)) - 1, x.shape()); auto z_even = _and(ctx, z_sep, lo_mask); - auto z_odd = _and(ctx, _rshift(ctx, z_sep, k / 2), lo_mask); + auto z_odd = + _and(ctx, _rshift(ctx, z_sep, {static_cast(k / 2)}), lo_mask); // a[i] = z[2*i] ^ z[2*i+1] a = _xor(ctx, z_odd, z_even); // b ^= z[2*i] b = _bit_parity(ctx, z_even, k / 2); - detail::hintNumberOfBits(b, 1); } auto a_rev = _bitrev(ctx, a, 0, (f / 2) * 2); - detail::hintNumberOfBits(a_rev, (f / 2) * 2); + a_rev = detail::maskNumberOfBits(ctx, a_rev, (f / 2) * 2); // do compensation // Note: @@ -623,7 +626,7 @@ static Value rsqrt_np2(SPUContext* ctx, const Value& x) { SPU_TRACE_HAL_LEAF(ctx, x); // let e = NP2(x), z = 2^(e+f) - return _lshift(ctx, detail::highestOneBit(ctx, x), 1); + return _lshift(ctx, detail::highestOneBit(ctx, x), {1}); } // Reference: diff --git a/libspu/kernel/hal/fxp_base.cc b/libspu/kernel/hal/fxp_base.cc index 209782e1..1594a9ac 100644 --- a/libspu/kernel/hal/fxp_base.cc +++ b/libspu/kernel/hal/fxp_base.cc @@ -72,7 +72,7 @@ Value polynomial(SPUContext* ctx, const Value& x, Value highestOneBit(SPUContext* ctx, const Value& x) { auto y = _prefix_or(ctx, x); - auto y1 = _rshift(ctx, y, 1); + auto y1 = _rshift(ctx, y, {1}); return _xor(ctx, y, y1); } @@ -85,6 +85,13 @@ void hintNumberOfBits(const Value& a, size_t nbits) { } } +Value maskNumberOfBits(SPUContext* ctx, const Value& in, size_t nbits) { + auto k1 = constant(ctx, 1UL, spu::DT_I64, in.shape()); + auto mask = _sub(ctx, _lshift(ctx, k1, {static_cast(nbits)}), k1); + auto out = _and(ctx, in, mask).setDtype(in.dtype()); + return out; +} + namespace { Value reciprocal_goldschmidt_normalized_approx(SPUContext* ctx, @@ -178,7 +185,8 @@ Value div_goldschmidt_general(SPUContext* ctx, const Value& a, const Value& b, // factor = 2^{f-m} = 2^{-m} * 2^f, the fixed point repr of 2^{-m} const size_t num_fxp_bits = ctx->getFxpBits(); auto factor = _bitrev(ctx, b_msb, 0, 2 * num_fxp_bits).setDtype(b.dtype()); - detail::hintNumberOfBits(factor, 2 * num_fxp_bits); + + factor = maskNumberOfBits(ctx, factor, 2 * num_fxp_bits); // also, we use factor twice factor = _prefer_a(ctx, factor); @@ -209,7 +217,7 @@ Value reciprocal_goldschmidt_positive(SPUContext* ctx, const Value& b_abs) { const size_t num_fxp_bits = ctx->getFxpBits(); auto factor = _bitrev(ctx, b_msb, 0, 2 * num_fxp_bits).setDtype(b_abs.dtype()); - detail::hintNumberOfBits(factor, 2 * num_fxp_bits); + factor = maskNumberOfBits(ctx, factor, 2 * num_fxp_bits); // also, we use factor twice factor = _prefer_a(ctx, factor); @@ -237,13 +245,12 @@ Value reciprocal_goldschmidt(SPUContext* ctx, const Value& b) { // factor = 2^{f-m} = 2^{-m} * 2^f, the fixed point repr of 2^{-m} const size_t num_fxp_bits = ctx->getFxpBits(); auto factor = _bitrev(ctx, b_msb, 0, 2 * num_fxp_bits).setDtype(b.dtype()); - detail::hintNumberOfBits(factor, 2 * num_fxp_bits); + factor = maskNumberOfBits(ctx, factor, 2 * num_fxp_bits); // also, we use factor twice factor = _prefer_a(ctx, factor); // compute approximation of normalize b_abs auto r = reciprocal_goldschmidt_normalized_approx(ctx, b_abs, factor); - r = f_mul(ctx, r, factor, SignType::Positive); return _mux(ctx, is_negative, _negate(ctx, r), r).setDtype(b.dtype()); @@ -370,8 +377,8 @@ Value f_floor(SPUContext* ctx, const Value& x) { SPU_ENFORCE(x.isFxp()); - const size_t fbits = ctx->getFxpBits(); - return _lshift(ctx, _arshift(ctx, x, fbits), fbits).setDtype(x.dtype()); + const int64_t fbits = ctx->getFxpBits(); + return _lshift(ctx, _arshift(ctx, x, {fbits}), {fbits}).setDtype(x.dtype()); } Value f_ceil(SPUContext* ctx, const Value& x) { diff --git a/libspu/kernel/hal/fxp_base.h b/libspu/kernel/hal/fxp_base.h index c72022b6..61fe7bca 100644 --- a/libspu/kernel/hal/fxp_base.h +++ b/libspu/kernel/hal/fxp_base.h @@ -30,6 +30,8 @@ Value highestOneBit(SPUContext* ctx, const Value& x); void hintNumberOfBits(const Value& a, size_t nbits); +Value maskNumberOfBits(SPUContext* ctx, const Value& a, size_t nbits); + // we provide this general function to support some special cases (a or b has // guarranteed sign) in fxp_approx for better both performance and accuracy. Value div_goldschmidt_general(SPUContext* ctx, const Value& a, const Value& b, diff --git a/libspu/kernel/hal/fxp_cleartext.cc b/libspu/kernel/hal/fxp_cleartext.cc index 818e0b23..b5061d5b 100644 --- a/libspu/kernel/hal/fxp_cleartext.cc +++ b/libspu/kernel/hal/fxp_cleartext.cc @@ -57,7 +57,7 @@ Value applyFloatingPointFn(SPUContext* ctx, const Value& in, FN&& fn) { auto pt_type = getDecodeType(in.dtype()); for (auto iter = fp_arr.begin(); iter != fp_arr.end(); ++iter) { - DISPATCH_FLOAT_PT_TYPES(pt_type, "pt_type", [&]() { + DISPATCH_FLOAT_PT_TYPES(pt_type, [&]() { auto* ptr = reinterpret_cast(&*iter); *ptr = fn(*ptr); }); @@ -92,9 +92,9 @@ Value applyFloatingPointFn(SPUContext* ctx, const Value& x, const Value& y, for (auto itr_x = flp_x.begin(), itr_y = flp_y.begin(); itr_x != flp_x.end(); itr_x++, itr_y++) { - DISPATCH_FLOAT_PT_TYPES(x_pt_type, "x_pt_type", [&]() { + DISPATCH_FLOAT_PT_TYPES(x_pt_type, [&]() { auto* ptr_x = reinterpret_cast(&*itr_x); - DISPATCH_FLOAT_PT_TYPES(y_pt_type, "y_pt_type", [&]() { + DISPATCH_FLOAT_PT_TYPES(y_pt_type, [&]() { auto* ptr_y = reinterpret_cast(&*itr_y); *ptr_x = fn(*ptr_x, *ptr_y); }); diff --git a/libspu/kernel/hal/permute.cc b/libspu/kernel/hal/permute.cc index 1c67c997..7ae47cda 100644 --- a/libspu/kernel/hal/permute.cc +++ b/libspu/kernel/hal/permute.cc @@ -553,7 +553,8 @@ std::vector _bit_decompose(SPUContext *ctx, const spu::Value &x, rets_b.reserve(nbits); for (size_t bit = 0; bit < nbits; ++bit) { - auto x_bshare_shift = right_shift_logical(ctx, x_bshare, bit); + auto x_bshare_shift = + right_shift_logical(ctx, x_bshare, {static_cast(bit)}); rets_b.push_back(_and(ctx, x_bshare_shift, k1)); } @@ -703,10 +704,10 @@ spu::Value _apply_perm_ss(SPUContext *ctx, const Value &x, const Value &perm) { // Find mergeable keys from keys. Consecutive public/private(belong to one // owner) keys can be merged. Assume there are six keys, i.e., public_key0, -// bob_key0, bob_key1, alice_key0, alice_key1, secret_key0. We can merge the six -// keys into bob_new_key, alice_new_key, secret_key0 for the following sorting. -// This function will return a vector of indices [3,5,6] which means key[0,3), -// key[3,5), and key[5,6) can be merged. +// bob_key0, bob_key1, alice_key0, alice_key1, secret_key0. We can merge the +// six keys into bob_new_key, alice_new_key, secret_key0 for the following +// sorting. This function will return a vector of indices [3,5,6] which means +// key[0,3), key[3,5), and key[5,6) can be merged. std::vector _find_mergeable_keys(SPUContext *ctx, absl::Span keys) { std::vector split_indices; diff --git a/libspu/kernel/hal/polymorphic.cc b/libspu/kernel/hal/polymorphic.cc index 34cf5f77..21680cc6 100644 --- a/libspu/kernel/hal/polymorphic.cc +++ b/libspu/kernel/hal/polymorphic.cc @@ -356,7 +356,7 @@ Value power(SPUContext* ctx, const Value& x, const Value& y) { const auto bit_width = SizeOf(ctx->getField()) * 8; auto y_b = _prefer_b(ctx, y); - auto msb_y = _rshift(ctx, y_b, bit_width - 1); + auto msb_y = _rshift(ctx, y_b, {static_cast(bit_width - 1)}); auto x_abs1 = _equal(ctx, abs(ctx, x), k1); auto ret = _constant(ctx, 1, x.shape()); @@ -379,7 +379,9 @@ Value power(SPUContext* ctx, const Value& x, const Value& y) { // e.g. y=0101, then ret = (x) * (1) * (x^(2^2)) * (1) = x^5 for (size_t idx = 0; idx < y_bits; idx++) { // x^(2^idx) * y_{idx} - auto cur_pow = _mux(ctx, _and(ctx, _rshift(ctx, y_b, idx), k1), base, k1); + auto cur_pow = _mux( + ctx, _and(ctx, _rshift(ctx, y_b, {static_cast(idx)}), k1), + base, k1); ret = _mul(ctx, cur_pow, ret); if (idx < y_bits - 1) { base = _mul(ctx, base, base); @@ -409,8 +411,9 @@ Value power(SPUContext* ctx, const Value& x, const Value& y) { // the final sign is decided on both sign of x and the parity of y // when x<0 and y is odd, e.g. (-2)^3 = -8 - auto odd = _and(ctx, _rshift(ctx, y, ctx->getFxpBits()), - _constant(ctx, 1, y.shape())); + auto odd = + _and(ctx, _rshift(ctx, y, {static_cast(ctx->getFxpBits())}), + _constant(ctx, 1, y.shape())); auto sign = _and(ctx, msb, odd); return _mux(ctx, sign, _negate(ctx, val), val).setDtype(x.dtype()); @@ -488,19 +491,20 @@ Value bitcast(SPUContext* ctx, const Value& x, DataType dtype) { return Value(x.data().clone(), dtype); } -Value left_shift(SPUContext* ctx, const Value& x, size_t bits) { +Value left_shift(SPUContext* ctx, const Value& x, const Sizes& bits) { SPU_TRACE_HAL_DISP(ctx, x, bits); return _lshift(ctx, x, bits).setDtype(x.dtype()); } -Value right_shift_logical(SPUContext* ctx, const Value& x, size_t bits) { +Value right_shift_logical(SPUContext* ctx, const Value& x, const Sizes& bits) { SPU_TRACE_HAL_DISP(ctx, x, bits); return _rshift(ctx, x, bits).setDtype(x.dtype()); } -Value right_shift_arithmetic(SPUContext* ctx, const Value& x, size_t bits) { +Value right_shift_arithmetic(SPUContext* ctx, const Value& x, + const Sizes& bits) { SPU_TRACE_HAL_DISP(ctx, x, bits); return _arshift(ctx, x, bits).setDtype(x.dtype()); diff --git a/libspu/kernel/hal/polymorphic.h b/libspu/kernel/hal/polymorphic.h index 58e90f00..9b47b85d 100644 --- a/libspu/kernel/hal/polymorphic.h +++ b/libspu/kernel/hal/polymorphic.h @@ -187,11 +187,12 @@ Value clamp(SPUContext* ctx, const Value& x, const Value& min, // @param dtype, second input value Value bitcast(SPUContext* ctx, const Value& x, DataType dtype); -Value left_shift(SPUContext* ctx, const Value& x, size_t bits); +Value left_shift(SPUContext* ctx, const Value& x, const Sizes& bits); -Value right_shift_logical(SPUContext* ctx, const Value& x, size_t bits); +Value right_shift_logical(SPUContext* ctx, const Value& x, const Sizes& bits); -Value right_shift_arithmetic(SPUContext* ctx, const Value& x, size_t bits); +Value right_shift_arithmetic(SPUContext* ctx, const Value& x, + const Sizes& bits); /// the element-wise base-2 logarithm of x // @param in, should be positive, or the result is implementation defined. diff --git a/libspu/kernel/hal/prot_wrapper.cc b/libspu/kernel/hal/prot_wrapper.cc index 10743c10..23f46388 100644 --- a/libspu/kernel/hal/prot_wrapper.cc +++ b/libspu/kernel/hal/prot_wrapper.cc @@ -30,11 +30,11 @@ namespace spu::kernel::hal { return mpc::NAME(ctx, in); \ } -#define MAP_SHIFT_OP(NAME) \ - Value _##NAME(SPUContext* ctx, const Value& in, size_t bits) { \ - SPU_TRACE_HAL_DISP(ctx, in, bits); \ - auto ret = mpc::NAME(ctx, in, bits); \ - return ret; \ +#define MAP_SHIFT_OP(NAME) \ + Value _##NAME(SPUContext* ctx, const Value& in, const Sizes& bits) { \ + SPU_TRACE_HAL_DISP(ctx, in, bits); \ + auto ret = mpc::NAME(ctx, in, bits); \ + return ret; \ } #define MAP_BITREV_OP(NAME) \ diff --git a/libspu/kernel/hal/prot_wrapper.h b/libspu/kernel/hal/prot_wrapper.h index 6294f834..5fb110bc 100644 --- a/libspu/kernel/hal/prot_wrapper.h +++ b/libspu/kernel/hal/prot_wrapper.h @@ -52,17 +52,17 @@ Value _equal_pp(SPUContext* ctx, const Value& x, const Value& y); std::optional _equal_sp(SPUContext* ctx, const Value& x, const Value& y); std::optional _equal_ss(SPUContext* ctx, const Value& x, const Value& y); -Value _lshift_p(SPUContext* ctx, const Value& in, size_t bits); -Value _lshift_s(SPUContext* ctx, const Value& in, size_t bits); -Value _lshift_v(SPUContext* ctx, const Value& in, size_t bits); +Value _lshift_p(SPUContext* ctx, const Value& in, const Sizes& bits); +Value _lshift_s(SPUContext* ctx, const Value& in, const Sizes& bits); +Value _lshift_v(SPUContext* ctx, const Value& in, const Sizes& bits); -Value _rshift_p(SPUContext* ctx, const Value& in, size_t bits); -Value _rshift_s(SPUContext* ctx, const Value& in, size_t bits); -Value _rshift_v(SPUContext* ctx, const Value& in, size_t bits); +Value _rshift_p(SPUContext* ctx, const Value& in, const Sizes& bits); +Value _rshift_s(SPUContext* ctx, const Value& in, const Sizes& bits); +Value _rshift_v(SPUContext* ctx, const Value& in, const Sizes& bits); -Value _arshift_p(SPUContext* ctx, const Value& in, size_t bits); -Value _arshift_s(SPUContext* ctx, const Value& in, size_t bits); -Value _arshift_v(SPUContext* ctx, const Value& in, size_t bits); +Value _arshift_p(SPUContext* ctx, const Value& in, const Sizes& bits); +Value _arshift_s(SPUContext* ctx, const Value& in, const Sizes& bits); +Value _arshift_v(SPUContext* ctx, const Value& in, const Sizes& bits); Value _trunc_p(SPUContext* ctx, const Value& in, size_t bits, SignType sign); Value _trunc_s(SPUContext* ctx, const Value& in, size_t bits, SignType sign); diff --git a/libspu/kernel/hal/ring.cc b/libspu/kernel/hal/ring.cc index 08b873f6..b016cd71 100644 --- a/libspu/kernel/hal/ring.cc +++ b/libspu/kernel/hal/ring.cc @@ -84,18 +84,18 @@ IMPL_UNARY_OP(_square) #undef IMPL_UNARY_OP -#define IMPL_SHIFT_OP(Name) \ - Value Name(SPUContext* ctx, const Value& in, size_t bits) { \ - SPU_TRACE_HAL_LEAF(ctx, in, bits); \ - if (in.isPublic()) { \ - return Name##_p(ctx, in, bits); \ - } else if (in.isSecret()) { \ - return Name##_s(ctx, in, bits); \ - } else if (in.isPrivate()) { \ - return Name##_v(ctx, in, bits); \ - } else { \ - SPU_THROW("unsupport unary op={} for {}", #Name, in); \ - } \ +#define IMPL_SHIFT_OP(Name) \ + Value Name(SPUContext* ctx, const Value& in, const Sizes& bits) { \ + SPU_TRACE_HAL_LEAF(ctx, in, bits); \ + if (in.isPublic()) { \ + return Name##_p(ctx, in, bits); \ + } else if (in.isSecret()) { \ + return Name##_s(ctx, in, bits); \ + } else if (in.isPrivate()) { \ + return Name##_v(ctx, in, bits); \ + } else { \ + SPU_THROW("unsupport unary op={} for {}", #Name, in); \ + } \ } IMPL_SHIFT_OP(_lshift) @@ -497,7 +497,7 @@ Value _bit_parity(SPUContext* ctx, const Value& x, size_t bits) { SPU_ENFORCE(absl::has_single_bit(bits), "currently only support power of 2"); auto ret = _prefer_b(ctx, x); while (bits > 1) { - ret = _xor(ctx, ret, _rshift(ctx, ret, bits / 2)); + ret = _xor(ctx, ret, _rshift(ctx, ret, {static_cast(bits / 2)})); bits /= 2; } @@ -518,7 +518,7 @@ Value _popcount(SPUContext* ctx, const Value& x, size_t bits) { std::vector vs; vs.reserve(bits); for (size_t idx = 0; idx < bits; idx++) { - auto x_ = _rshift(ctx, xb, idx); + auto x_ = _rshift(ctx, xb, {static_cast(idx)}); x_ = _and(ctx, x_, _constant(ctx, 1U, x.shape())); if (x_.storage_type().isa()) { @@ -547,8 +547,8 @@ Value _prefix_or(SPUContext* ctx, const Value& x) { auto b0 = _prefer_b(ctx, x); const size_t bit_width = SizeOf(ctx->getField()) * 8; for (int idx = 0; idx < absl::bit_width(bit_width) - 1; idx++) { - const size_t offset = 1UL << idx; - auto b1 = _rshift(ctx, b0, offset); + const int64_t offset = 1L << idx; + auto b1 = _rshift(ctx, b0, {offset}); b0 = _or(ctx, b0, b1); } return b0; @@ -574,8 +574,8 @@ Value _bitdeintl(SPUContext* ctx, const Value& in) { // out = (out & keep) ^ ((out >> shift) & move) ^ ((out & move) << shift); out = _xor(ctx, _xor(ctx, _and(ctx, out, keep), - _and(ctx, _rshift(ctx, out, shift), move)), - _lshift(ctx, _and(ctx, out, move), shift)); + _and(ctx, _rshift(ctx, out, {shift}), move)), + _lshift(ctx, _and(ctx, out, move), {shift})); } return out; } diff --git a/libspu/kernel/hal/ring.h b/libspu/kernel/hal/ring.h index b2fb6f65..0dd7234a 100644 --- a/libspu/kernel/hal/ring.h +++ b/libspu/kernel/hal/ring.h @@ -71,11 +71,11 @@ Value _equal(SPUContext* ctx, const Value& x, const Value& y); Value _less(SPUContext* ctx, const Value& x, const Value& y); -Value _lshift(SPUContext* ctx, const Value& in, size_t bits); +Value _lshift(SPUContext* ctx, const Value& in, const Sizes& bits); -Value _rshift(SPUContext* ctx, const Value& in, size_t bits); +Value _rshift(SPUContext* ctx, const Value& in, const Sizes& bits); -Value _arshift(SPUContext* ctx, const Value& in, size_t bits); +Value _arshift(SPUContext* ctx, const Value& in, const Sizes& bits); Value _trunc(SPUContext* ctx, const Value& x, size_t bits = 0, SignType sign = SignType::Unknown); diff --git a/libspu/kernel/hal/type_cast.cc b/libspu/kernel/hal/type_cast.cc index 53b65742..9adb77a8 100644 --- a/libspu/kernel/hal/type_cast.cc +++ b/libspu/kernel/hal/type_cast.cc @@ -27,7 +27,8 @@ Value int2fxp(SPUContext* ctx, const Value& x, DataType to_type) { SPU_TRACE_HAL_LEAF(ctx, x); SPU_ENFORCE(x.isInt(), "expect integer, got {}", x.dtype()); - return _lshift(ctx, x, ctx->getFxpBits()).setDtype(to_type); + return _lshift(ctx, x, {static_cast(ctx->getFxpBits())}) + .setDtype(to_type); } // Casting fxp to integer. @@ -49,12 +50,12 @@ Value fxp2int(SPUContext* ctx, const Value& x, DataType to_type) { SPU_TRACE_HAL_LEAF(ctx, x); SPU_ENFORCE(x.isFxp()); - const size_t fxp_bits = ctx->getFxpBits(); + const int64_t fxp_bits = ctx->getFxpBits(); const Value kOneMinusEps = _constant(ctx, (1 << fxp_bits) - 1, x.shape()); // (x + 0.99 * (x < 0)) >> fxp_bits return _arshift(ctx, _add(ctx, x, _mul(ctx, kOneMinusEps, _msb(ctx, x))), - fxp_bits) + {fxp_bits}) .setDtype(to_type); } diff --git a/libspu/kernel/hlo/basic_unary.cc b/libspu/kernel/hlo/basic_unary.cc index ce1d2c17..44cc9e36 100644 --- a/libspu/kernel/hlo/basic_unary.cc +++ b/libspu/kernel/hlo/basic_unary.cc @@ -118,19 +118,19 @@ spu::Value Round_RNTE(SPUContext *ctx, const spu::Value &in) { // so comp = b && (c || a) SPU_ENFORCE(!in.isComplex()); SPU_ENFORCE(in.isFxp(), "Round only supports fxp"); - const auto fxp_bits = ctx->getFxpBits(); + const int64_t fxp_bits = ctx->getFxpBits(); const auto k1 = hal::_constant(ctx, 1U, in.shape()); auto x_prime = hal::_prefer_b(ctx, in); auto y = hal::floor(ctx, x_prime); - auto a = hal::_and(ctx, hal::_rshift(ctx, x_prime, fxp_bits), k1); - auto b = hal::_and(ctx, hal::_rshift(ctx, x_prime, fxp_bits - 1), k1); + auto a = hal::_and(ctx, hal::_rshift(ctx, x_prime, {fxp_bits}), k1); + auto b = hal::_and(ctx, hal::_rshift(ctx, x_prime, {fxp_bits - 1}), k1); std::vector cs; cs.reserve(fxp_bits - 1); - for (size_t idx = 0; idx < fxp_bits - 1; idx++) { - auto x_ = hal::_and(ctx, hal::_rshift(ctx, x_prime, idx), k1); + for (int64_t idx = 0; idx < fxp_bits - 1; idx++) { + auto x_ = hal::_and(ctx, hal::_rshift(ctx, x_prime, {idx}), k1); cs.push_back(std::move(x_)); } auto c = vreduce(cs.begin(), cs.end(), [&](const Value &a, const Value &b) { diff --git a/libspu/kernel/hlo/shift.cc b/libspu/kernel/hlo/shift.cc index 08fe7140..0b6f3f1d 100644 --- a/libspu/kernel/hlo/shift.cc +++ b/libspu/kernel/hlo/shift.cc @@ -26,31 +26,8 @@ namespace spu::kernel::hlo { template spu::Value shift_impl_p(SPUContext *ctx, const spu::Value &lhs, const spu::Value &rhs, const Fn &f) { - auto shift_bits = hal::dump_public_as(ctx, rhs); - if (std::all_of(rhs.strides().begin(), rhs.strides().end(), - [](int64_t s) { return s == 0; })) { - // rhs is a splat - return f(ctx, lhs, shift_bits[0]); - } - - // Not a splat... - spu::Value ret = - hal::constant(ctx, static_cast(0), lhs.dtype(), lhs.shape()); - auto dtype_size = getWidth(lhs.dtype()); - for (size_t bits = 0; bits < dtype_size; ++bits) { - if (std::none_of(shift_bits.begin(), shift_bits.end(), [&bits](int8_t b) { - return b == static_cast(bits); - })) { - continue; - } - auto current_bits = hal::constant(ctx, static_cast(bits), - rhs.dtype(), rhs.shape()); - auto mask = hal::equal(ctx, rhs, current_bits); - auto shifted = f(ctx, lhs, bits); - ret = hal::add(ctx, ret, hal::mul(ctx, mask, shifted)); - } - - return ret; + auto shift_bits = hal::dump_public_as(ctx, rhs); + return f(ctx, lhs, {shift_bits.begin(), shift_bits.end()}); } template @@ -63,7 +40,7 @@ spu::Value shift_impl_s(SPUContext *ctx, const spu::Value &lhs, auto current_bits = hal::constant(ctx, static_cast(bits), rhs.dtype(), rhs.shape()); auto mask = hal::equal(ctx, rhs, current_bits); - auto shifted = f(ctx, lhs, bits); + auto shifted = f(ctx, lhs, {static_cast(bits)}); ret = hal::add(ctx, ret, hal::mul(ctx, mask, shifted)); } diff --git a/libspu/mpc/ab_api.cc b/libspu/mpc/ab_api.cc index 7d02fe47..a0720870 100644 --- a/libspu/mpc/ab_api.cc +++ b/libspu/mpc/ab_api.cc @@ -133,7 +133,7 @@ OptionalAPI mul_a1bv(SPUContext* ctx, const Value& x, const Value& y) { return NotAvailable; } -Value lshift_a(SPUContext* ctx, const Value& x, size_t nbits) { +Value lshift_a(SPUContext* ctx, const Value& x, const Sizes& nbits) { FORCE_DISPATCH(ctx, x, nbits); } @@ -201,15 +201,15 @@ OptionalAPI xor_bv(SPUContext* ctx, const Value& x, const Value& y) { return NotAvailable; } -Value lshift_b(SPUContext* ctx, const Value& x, size_t nbits) { +Value lshift_b(SPUContext* ctx, const Value& x, const Sizes& nbits) { FORCE_DISPATCH(ctx, x, nbits); } -Value rshift_b(SPUContext* ctx, const Value& x, size_t nbits) { +Value rshift_b(SPUContext* ctx, const Value& x, const Sizes& nbits) { FORCE_DISPATCH(ctx, x, nbits); } -Value arshift_b(SPUContext* ctx, const Value& x, size_t nbits) { +Value arshift_b(SPUContext* ctx, const Value& x, const Sizes& nbits) { FORCE_DISPATCH(ctx, x, nbits); } @@ -254,10 +254,10 @@ Value bitintl_b(SPUContext* ctx, const Value& x, size_t stride) { auto M = hack_make_p(ctx, spu::detail::kBitIntlSwapMasks[idx], x.shape()); int64_t S = static_cast(1) << idx; // out = (out & K) ^ ((out >> S) & M) ^ ((out & M) << S); - out = xor_bb( - ctx, - xor_bb(ctx, and_bp(ctx, out, K), and_bp(ctx, rshift_b(ctx, out, S), M)), - lshift_b(ctx, and_bp(ctx, out, M), S)); + out = xor_bb(ctx, + xor_bb(ctx, and_bp(ctx, out, K), + and_bp(ctx, rshift_b(ctx, out, {S}), M)), + lshift_b(ctx, and_bp(ctx, out, M), {S})); } out = setNumBits(out, numBits(x)); return out; @@ -283,10 +283,10 @@ Value bitdeintl_b(SPUContext* ctx, const Value& x, size_t stride) { auto M = hack_make_p(ctx, spu::detail::kBitIntlSwapMasks[idx], x.shape()); int64_t S = static_cast(1) << idx; // out = (out & K) ^ ((out >> S) & M) ^ ((out & M) << S); - out = xor_bb( - ctx, - xor_bb(ctx, and_bp(ctx, out, K), and_bp(ctx, rshift_b(ctx, out, S), M)), - lshift_b(ctx, and_bp(ctx, out, M), S)); + out = xor_bb(ctx, + xor_bb(ctx, and_bp(ctx, out, K), + and_bp(ctx, rshift_b(ctx, out, {S}), M)), + lshift_b(ctx, and_bp(ctx, out, M), {S})); } out = setNumBits(out, numBits(x)); return out; @@ -318,9 +318,9 @@ Value ppa_kogge_stone(SPUContext* ctx, const Value& lhs, const Value& rhs, auto G = and_bb(ctx, lhs, rhs); for (int idx = 0; idx < Log2Ceil(nbits); ++idx) { - const size_t offset = 1UL << idx; - auto G1 = lshift_b(ctx, G, offset); - auto P1 = lshift_b(ctx, P, offset); + const int64_t offset = static_cast(1) << idx; + auto G1 = lshift_b(ctx, G, {offset}); + auto P1 = lshift_b(ctx, P, {offset}); // P1 = P & P1 // G1 = G ^ (P & G1) @@ -332,7 +332,7 @@ Value ppa_kogge_stone(SPUContext* ctx, const Value& lhs, const Value& rhs, } // out = (G << 1) ^ p0 - auto C = lshift_b(ctx, G, 1); + auto C = lshift_b(ctx, G, {1}); return xor_bb(ctx, xor_bb(ctx, lhs, rhs), C); } @@ -343,7 +343,7 @@ std::pair bit_scatter(SPUContext* ctx, const Value& in, SPU_ENFORCE(absl::has_single_bit(nbits), "unsupported {}", nbits); auto out = bitdeintl_b(ctx, in, stride); - auto hi = rshift_b(ctx, out, nbits / 2); + auto hi = rshift_b(ctx, out, {static_cast(nbits / 2)}); auto mask = hack_make_p(ctx, (static_cast(1) << (nbits / 2)) - 1, in.shape()); auto lo = and_bp(ctx, out, mask); @@ -357,7 +357,7 @@ Value bit_gather(SPUContext* ctx, const Value& hi, const Value& lo, SPU_ENFORCE(nbits == numBits(lo), "nbits mismatch {}, {}", nbits, numBits(lo)); - auto out = xor_bb(ctx, lshift_b(ctx, hi, nbits), lo); + auto out = xor_bb(ctx, lshift_b(ctx, hi, {static_cast(nbits)}), lo); return bitintl_b(ctx, out, stride); } @@ -395,8 +395,8 @@ Value ppa_sklansky(SPUContext* ctx, Value const& lhs, Value const& rhs, auto Gs = and_bp(ctx, Gl, s_mask); auto Ps = and_bp(ctx, Pl, s_mask); for (int j = 0; j < idx; j++) { - Gs = xor_bb(ctx, Gs, rshift_b(ctx, Gs, 1 << j)); - Ps = xor_bb(ctx, Ps, rshift_b(ctx, Ps, 1 << j)); + Gs = xor_bb(ctx, Gs, rshift_b(ctx, Gs, {1 << j})); + Ps = xor_bb(ctx, Ps, rshift_b(ctx, Ps, {1 << j})); } // SPU_ENFORCE(numBits(Ps) == bit_width / 2); // SPU_ENFORCE(numBits(Gs) == bit_width / 2); @@ -416,7 +416,7 @@ Value ppa_sklansky(SPUContext* ctx, Value const& lhs, Value const& rhs, } // out = (G0 << 1) ^ p0 - auto C = lshift_b(ctx, G, 1); + auto C = lshift_b(ctx, G, {1}); return xor_bb(ctx, xor_bb(ctx, lhs, rhs), C); } @@ -460,8 +460,8 @@ Value carry_a2b(SPUContext* ctx, const Value& x, const Value& y, size_t k) { while (k > 1) { if (k % 2 != 0) { k += 1; - P = lshift_b(ctx, P, 1); - G = lshift_b(ctx, G, 1); + P = lshift_b(ctx, P, {1}); + G = lshift_b(ctx, G, {1}); } auto [P1, P0] = bit_scatter(ctx, P, 0); auto [G1, G0] = bit_scatter(ctx, G, 0); diff --git a/libspu/mpc/ab_api.h b/libspu/mpc/ab_api.h index 72a8476f..94f17d98 100644 --- a/libspu/mpc/ab_api.h +++ b/libspu/mpc/ab_api.h @@ -46,7 +46,7 @@ OptionalAPI mul_av(SPUContext* ctx, const Value& x, const Value& y); Value mul_a1b(SPUContext* ctx, const Value& x, const Value& y); OptionalAPI mul_a1bv(SPUContext* ctx, const Value& x, const Value& y); -Value lshift_a(SPUContext* ctx, const Value& x, size_t nbits); +Value lshift_a(SPUContext* ctx, const Value& x, const Sizes& nbits); Value trunc_a(SPUContext* ctx, const Value& x, size_t nbits, SignType sign); Value mmul_ap(SPUContext* ctx, const Value& x, const Value& y); @@ -72,9 +72,9 @@ Value xor_bb(SPUContext* ctx, const Value& x, const Value& y); OptionalAPI xor_bv(SPUContext* ctx, const Value& x, const Value& y); // TODO -Value lshift_b(SPUContext* ctx, const Value& x, size_t nbits); -Value rshift_b(SPUContext* ctx, const Value& x, size_t nbits); -Value arshift_b(SPUContext* ctx, const Value& x, size_t nbits); +Value lshift_b(SPUContext* ctx, const Value& x, const Sizes& nbits); +Value rshift_b(SPUContext* ctx, const Value& x, const Sizes& nbits); +Value arshift_b(SPUContext* ctx, const Value& x, const Sizes& nbits); // Bit reverse for binary share. Value bitrev_b(SPUContext* ctx, const Value& x, size_t start, size_t end); diff --git a/libspu/mpc/ab_api_test.cc b/libspu/mpc/ab_api_test.cc index 13ef4080..592f9389 100644 --- a/libspu/mpc/ab_api_test.cc +++ b/libspu/mpc/ab_api_test.cc @@ -195,7 +195,7 @@ TEST_P(ArithmeticTest, MulA1B) { return; } - const size_t K = spu::SizeOf(conf.field()) * 8; + const int64_t K = spu::SizeOf(conf.field()) * 8; /* GIVEN */ auto p0 = rand_p(obj.get(), conf.protocol() == ProtocolKind::CHEETAH @@ -204,12 +204,12 @@ TEST_P(ArithmeticTest, MulA1B) { auto p1 = rand_p(obj.get(), conf.protocol() == ProtocolKind::CHEETAH ? Shape({200, 26}) : kShape); - p1 = rshift_p(obj.get(), p1, K - 1); + p1 = rshift_p(obj.get(), p1, {K - 1}); auto a0 = p2a(obj.get(), p0); auto a1 = p2b(obj.get(), p1); // hint runtime this is a 1bit value. - a1 = lshift_b(obj.get(), a1, K - 1); - a1 = rshift_b(obj.get(), a1, K - 1); + a1 = lshift_b(obj.get(), a1, {K - 1}); + a1 = rshift_b(obj.get(), a1, {K - 1}); /* WHEN */ auto prev = obj->prot()->getState()->getStats(); @@ -238,12 +238,12 @@ TEST_P(ArithmeticTest, MulAV) { return; } - const size_t K = spu::SizeOf(conf.field()) * 8; + const int64_t K = spu::SizeOf(conf.field()) * 8; /* GIVEN */ auto p0 = rand_p(obj.get(), kShape); auto p1 = rand_p(obj.get(), kShape); - p1 = rshift_p(obj.get(), p1, K - 1); + p1 = rshift_p(obj.get(), p1, {K - 1}); auto a0 = p2a(obj.get(), p0); auto a1 = p2v(obj.get(), p1, 0); @@ -275,17 +275,17 @@ TEST_P(ArithmeticTest, MulA1BV) { return; } - const size_t K = spu::SizeOf(conf.field()) * 8; + const int64_t K = spu::SizeOf(conf.field()) * 8; /* GIVEN */ auto p0 = rand_p(obj.get(), kShape); auto p1 = rand_p(obj.get(), kShape); - p1 = rshift_p(obj.get(), p1, K - 1); + p1 = rshift_p(obj.get(), p1, {K - 1}); auto a0 = p2a(obj.get(), p0); auto a1 = p2v(obj.get(), p1, 0); // hint runtime this is a 1bit value. - a1 = lshift_v(obj.get(), a1, K - 1); - a1 = rshift_v(obj.get(), a1, K - 1); + a1 = lshift_v(obj.get(), a1, {K - 1}); + a1 = rshift_v(obj.get(), a1, {K - 1}); // auto a1 = b2v(obj.get(), _a1, 0); /* WHEN */ @@ -482,10 +482,10 @@ TEST_P(ArithmeticTest, LShiftA) { } /* WHEN */ auto prev = obj->prot()->getState()->getStats(); - auto tmp = lshift_a(obj.get(), a0, bits); + auto tmp = lshift_a(obj.get(), a0, {static_cast(bits)}); auto cost = obj->prot()->getState()->getStats() - prev; auto r_b = a2p(obj.get(), tmp); - auto r_p = lshift_p(obj.get(), p0, bits); + auto r_p = lshift_p(obj.get(), p0, {static_cast(bits)}); /* THEN */ EXPECT_VALUE_EQ(r_b, r_p); @@ -513,10 +513,11 @@ TEST_P(ArithmeticTest, TruncA) { if (!kernel->hasMsbError()) { // trunc requires MSB to be zero. - p0 = arshift_p(obj.get(), p0, 1); + p0 = arshift_p(obj.get(), p0, {1}); } else { // has msb error, only use lowest 10 bits. - p0 = arshift_p(obj.get(), p0, SizeOf(conf.field()) * 8 - 10); + p0 = arshift_p(obj.get(), p0, + {static_cast(SizeOf(conf.field()) * 8 - 10)}); } /* GIVEN */ @@ -529,7 +530,7 @@ TEST_P(ArithmeticTest, TruncA) { auto cost = obj->prot()->getState()->getStats() - prev; auto r_a = a2p(obj.get(), a1); - auto r_p = arshift_p(obj.get(), p0, bits); + auto r_p = arshift_p(obj.get(), p0, {static_cast(bits)}); /* THEN */ EXPECT_VALUE_ALMOST_EQ(r_a, r_p, npc); @@ -674,11 +675,11 @@ TEST_BOOLEAN_BINARY_OP(xor) } \ /* WHEN */ \ auto prev = obj->prot()->getState()->getStats(); \ - auto tmp = OP##_b(obj.get(), b0, bits); \ + auto tmp = OP##_b(obj.get(), b0, {static_cast(bits)}); \ auto cost = \ obj->prot()->getState()->getStats() - prev; \ auto r_b = b2p(obj.get(), tmp); \ - auto r_p = OP##_p(obj.get(), p0, bits); \ + auto r_p = OP##_p(obj.get(), p0, {static_cast(bits)}); \ \ /* THEN */ \ EXPECT_VALUE_EQ(r_b, r_p); \ @@ -837,7 +838,7 @@ TEST_P(ConversionTest, MSB) { // SECURENN has an msb input range here if (conf.protocol() == ProtocolKind::SECURENN) { - p0 = arshift_p(obj.get(), p0, 1); + p0 = arshift_p(obj.get(), p0, {1}); } auto a0 = p2a(obj.get(), p0); @@ -850,8 +851,10 @@ TEST_P(ConversionTest, MSB) { /* THEN */ EXPECT_TRUE(verifyCost(obj->prot()->getKernel("msb_a2b"), "msb_a2b", conf.field(), kShape, npc, cost)); - EXPECT_VALUE_EQ(rshift_p(obj.get(), p0, SizeOf(conf.field()) * 8 - 1), - b2p(obj.get(), b1)); + EXPECT_VALUE_EQ( + rshift_p(obj.get(), p0, + {static_cast(SizeOf(conf.field()) * 8 - 1)}), + b2p(obj.get(), b1)); }); } diff --git a/libspu/mpc/aby3/arithmetic.cc b/libspu/mpc/aby3/arithmetic.cc index a23749ff..6c365220 100644 --- a/libspu/mpc/aby3/arithmetic.cc +++ b/libspu/mpc/aby3/arithmetic.cc @@ -41,7 +41,7 @@ std::vector a1b_offline(size_t sender, const NdArrayRef& a, auto numel = a.numel(); - return DISPATCH_ALL_FIELDS(field, "_", [&]() -> std::vector { + return DISPATCH_ALL_FIELDS(field, [&]() -> std::vector { using ashr_el_t = ring2k_t; NdArrayRef m0(makeType(field), a.shape()); NdArrayRef m1(makeType(field), a.shape()); @@ -52,7 +52,7 @@ std::vector a1b_offline(size_t sender, const NdArrayRef& a, NdArrayView _m1(m1); return DISPATCH_UINT_PT_TYPES( - b.eltype().as()->getBacktype(), "_", + b.eltype().as()->getBacktype(), [&]() -> std::vector { using bshr_t = std::array; if (self_rank == sender) { @@ -82,8 +82,6 @@ std::vector a1b_offline(size_t sender, const NdArrayRef& a, return {c1, c2, m0, m1}; } else if (self_rank == (sender + 1) % 3) { - prg_state->genPrssPair(field, a.shape(), - PrgState::GenPrssCtrl::None); auto c1 = prg_state ->genPrssPair(field, a.shape(), PrgState::GenPrssCtrl::First) @@ -95,8 +93,6 @@ std::vector a1b_offline(size_t sender, const NdArrayRef& a, ->genPrssPair(field, a.shape(), PrgState::GenPrssCtrl::Second) .second; - prg_state->genPrssPair(field, a.shape(), - PrgState::GenPrssCtrl::None); return {c2}; } @@ -109,7 +105,7 @@ std::vector ring_cast_boolean(const NdArrayRef& x) { const size_t numel = x.numel(); std::vector res(numel); - DISPATCH_UINT_PT_TYPES(x.eltype().as()->pt_type(), "_", [&]() { + DISPATCH_UINT_PT_TYPES(x.eltype().as()->pt_type(), [&]() { NdArrayView _x(x); pforeach(0, numel, [&](int64_t idx) { res[idx] = static_cast(_x[idx] & 0x1); @@ -127,7 +123,7 @@ NdArrayRef RandA::proc(KernelEvalContext* ctx, const Shape& shape) const { NdArrayRef out(makeType(field), shape); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using el_t = ring2k_t; std::vector r0(shape.numel()); @@ -153,7 +149,7 @@ NdArrayRef A2P::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { const auto field = in.eltype().as()->field(); auto numel = in.numel(); - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using pshr_el_t = ring2k_t; using ashr_el_t = ring2k_t; using ashr_t = std::array; @@ -184,7 +180,7 @@ NdArrayRef P2A::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { auto rank = comm->getRank(); - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using ashr_el_t = ring2k_t; using ashr_t = std::array; using pshr_el_t = ring2k_t; @@ -226,7 +222,7 @@ NdArrayRef A2V::proc(KernelEvalContext* ctx, const NdArrayRef& in, auto* comm = ctx->getState(); const auto field = in.eltype().as()->field(); - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using vshr_el_t = ring2k_t; using ashr_el_t = ring2k_t; using ashr_t = std::array; @@ -267,7 +263,7 @@ NdArrayRef V2A::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { size_t owner_rank = in_ty->owner(); - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using ashr_el_t = ring2k_t; using ashr_t = std::array; @@ -311,7 +307,7 @@ NdArrayRef NotA::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { auto rank = comm->getRank(); - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using el_t = std::make_unsigned_t; using shr_t = std::array; @@ -349,7 +345,7 @@ NdArrayRef AddAP::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, auto rank = comm->getRank(); - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using el_t = ring2k_t; using shr_t = std::array; @@ -376,7 +372,7 @@ NdArrayRef AddAA::proc(KernelEvalContext*, const NdArrayRef& lhs, SPU_ENFORCE(lhs_ty->field() == rhs_ty->field()); const auto field = lhs_ty->field(); - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using shr_t = std::array; NdArrayRef out(makeType(field), lhs.shape()); @@ -403,7 +399,7 @@ NdArrayRef MulAP::proc(KernelEvalContext*, const NdArrayRef& lhs, SPU_ENFORCE(lhs_ty->field() == rhs_ty->field()); const auto field = lhs_ty->field(); - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using el_t = ring2k_t; using shr_t = std::array; @@ -426,7 +422,7 @@ NdArrayRef MulAA::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, auto* comm = ctx->getState(); auto* prg_state = ctx->getState(); - return DISPATCH_ALL_FIELDS(field, "aby3.mulAA", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using el_t = ring2k_t; using shr_t = std::array; @@ -575,7 +571,7 @@ NdArrayRef MulA1B::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, .reshape(r2.first.shape()); } - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { // using = ring2k_t; NdArrayView r1_0(r1.first); NdArrayView r1_1(r1.second); @@ -675,11 +671,12 @@ NdArrayRef MatMulAA::proc(KernelEvalContext* ctx, const NdArrayRef& x, } NdArrayRef LShiftA::proc(KernelEvalContext*, const NdArrayRef& in, - size_t bits) const { + const Sizes& bits) const { const auto* in_ty = in.eltype().as(); const auto field = in_ty->field(); + bool is_splat = bits.size() == 1; - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using shr_t = std::array; NdArrayRef out(makeType(field), in.shape()); @@ -687,8 +684,9 @@ NdArrayRef LShiftA::proc(KernelEvalContext*, const NdArrayRef& in, NdArrayView _in(in); pforeach(0, in.numel(), [&](int64_t idx) { - _out[idx][0] = _in[idx][0] << bits; - _out[idx][1] = _in[idx][1] << bits; + auto shift_bit = is_splat ? bits[0] : bits[idx]; + _out[idx][0] = _in[idx][0] << shift_bit; + _out[idx][1] = _in[idx][1] << shift_bit; }); return out; @@ -722,22 +720,23 @@ NdArrayRef TruncA::proc(KernelEvalContext* ctx, const NdArrayRef& in, comm->addCommStatsManually(1, kComm); // comm => 1, 2 // ret + const Sizes shift_bit = {static_cast(bits)}; switch (comm->getRank()) { case 0: { - const auto z1 = ring_arshift(x1, bits); + const auto z1 = ring_arshift(x1, shift_bit); const auto z2 = comm->recv(1, x1.eltype(), kBindName); return makeAShare(z1, z2, field); } case 1: { auto r1 = r_future.get().second; - const auto z1 = ring_sub(ring_arshift(ring_add(x1, x2), bits), r1); + const auto z1 = ring_sub(ring_arshift(ring_add(x1, x2), shift_bit), r1); comm->sendAsync(0, z1, kBindName); return makeAShare(z1, r1, field); } case 2: { - const auto z2 = ring_arshift(x2, bits); + const auto z2 = ring_arshift(x2, shift_bit); return makeAShare(r_future.get().first, z2, field); } @@ -790,7 +789,7 @@ NdArrayRef TruncAPr::proc(KernelEvalContext* ctx, const NdArrayRef& in, size_t P2 = (pivot + 2) % 3; NdArrayRef out(in.eltype(), in.shape()); - DISPATCH_ALL_FIELDS(field, "aby3.truncpr", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using el_t = ring2k_t; using shr_t = std::array; diff --git a/libspu/mpc/aby3/arithmetic.h b/libspu/mpc/aby3/arithmetic.h index 75135fa7..94929715 100644 --- a/libspu/mpc/aby3/arithmetic.h +++ b/libspu/mpc/aby3/arithmetic.h @@ -243,7 +243,7 @@ class LShiftA : public ShiftKernel { ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const override; + const Sizes& bits) const override; }; // Refer to: diff --git a/libspu/mpc/aby3/boolean.cc b/libspu/mpc/aby3/boolean.cc index 8e07307a..d8ee6db5 100644 --- a/libspu/mpc/aby3/boolean.cc +++ b/libspu/mpc/aby3/boolean.cc @@ -42,11 +42,11 @@ void CommonTypeB::evaluate(KernelEvalContext* ctx) const { NdArrayRef CastTypeB::proc(KernelEvalContext*, const NdArrayRef& in, const Type& to_type) const { NdArrayRef out(to_type, in.shape()); - DISPATCH_UINT_PT_TYPES(in.eltype().as()->getBacktype(), "_", [&]() { + DISPATCH_UINT_PT_TYPES(in.eltype().as()->getBacktype(), [&]() { using in_el_t = ScalarT; using in_shr_t = std::array; - DISPATCH_UINT_PT_TYPES(to_type.as()->getBacktype(), "_", [&]() { + DISPATCH_UINT_PT_TYPES(to_type.as()->getBacktype(), [&]() { using out_el_t = ScalarT; using out_shr_t = std::array; @@ -69,11 +69,11 @@ NdArrayRef B2P::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { const PtType btype = in.eltype().as()->getBacktype(); const auto field = ctx->getState()->getDefaultField(); - return DISPATCH_UINT_PT_TYPES(btype, "aby3.b2p", [&]() { + return DISPATCH_UINT_PT_TYPES(btype, [&]() { using bshr_el_t = ScalarT; using bshr_t = std::array; - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using pshr_el_t = ring2k_t; NdArrayRef out(makeType(field), in.shape()); @@ -99,12 +99,12 @@ NdArrayRef P2B::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { const auto* in_ty = in.eltype().as(); const auto field = in_ty->field(); - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { const size_t nbits = maxBitWidth(in); const PtType btype = calcBShareBacktype(nbits); NdArrayView _in(in); - return DISPATCH_UINT_PT_TYPES(btype, "_", [&]() { + return DISPATCH_UINT_PT_TYPES(btype, [&]() { using bshr_el_t = ScalarT; using bshr_t = std::array; @@ -134,12 +134,12 @@ NdArrayRef B2V::proc(KernelEvalContext* ctx, const NdArrayRef& in, const PtType btype = in.eltype().as()->getBacktype(); const auto field = ctx->getState()->getDefaultField(); - return DISPATCH_UINT_PT_TYPES(btype, "aby3.b2v", [&]() { + return DISPATCH_UINT_PT_TYPES(btype, [&]() { using bshr_el_t = ScalarT; using bshr_t = std::array; NdArrayView _in(in); - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using vshr_scalar_t = ring2k_t; auto out_ty = makeType(field, rank); @@ -177,7 +177,7 @@ NdArrayRef AndBP::proc(KernelEvalContext*, const NdArrayRef& lhs, const auto* lhs_ty = lhs.eltype().as(); const auto* rhs_ty = rhs.eltype().as(); - return DISPATCH_ALL_FIELDS(rhs_ty->field(), "_", [&]() { + return DISPATCH_ALL_FIELDS(rhs_ty->field(), [&]() { using rhs_scalar_t = ring2k_t; const size_t rhs_nbits = maxBitWidth(rhs); @@ -186,13 +186,13 @@ NdArrayRef AndBP::proc(KernelEvalContext*, const NdArrayRef& lhs, NdArrayView _rhs(rhs); - return DISPATCH_UINT_PT_TYPES(lhs_ty->getBacktype(), "_", [&]() { + return DISPATCH_UINT_PT_TYPES(lhs_ty->getBacktype(), [&]() { using lhs_el_t = ScalarT; using lhs_shr_t = std::array; NdArrayView _lhs(lhs); - return DISPATCH_UINT_PT_TYPES(out_btype, "_", [&]() { + return DISPATCH_UINT_PT_TYPES(out_btype, [&]() { using out_el_t = ScalarT; using out_shr_t = std::array; @@ -224,17 +224,17 @@ NdArrayRef AndBB::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, const PtType out_btype = calcBShareBacktype(out_nbits); NdArrayRef out(makeType(out_btype, out_nbits), lhs.shape()); - return DISPATCH_UINT_PT_TYPES(rhs_ty->getBacktype(), "_", [&]() { + return DISPATCH_UINT_PT_TYPES(rhs_ty->getBacktype(), [&]() { using rhs_el_t = ScalarT; using rhs_shr_t = std::array; NdArrayView _rhs(rhs); - return DISPATCH_UINT_PT_TYPES(lhs_ty->getBacktype(), "_", [&]() { + return DISPATCH_UINT_PT_TYPES(lhs_ty->getBacktype(), [&]() { using lhs_el_t = ScalarT; using lhs_shr_t = std::array; NdArrayView _lhs(lhs); - return DISPATCH_UINT_PT_TYPES(out_btype, "_", [&]() { + return DISPATCH_UINT_PT_TYPES(out_btype, [&]() { using out_el_t = ScalarT; using out_shr_t = std::array; @@ -269,7 +269,7 @@ NdArrayRef XorBP::proc(KernelEvalContext*, const NdArrayRef& lhs, const auto* lhs_ty = lhs.eltype().as(); const auto* rhs_ty = rhs.eltype().as(); - return DISPATCH_ALL_FIELDS(rhs_ty->field(), "_", [&]() { + return DISPATCH_ALL_FIELDS(rhs_ty->field(), [&]() { using rhs_scalar_t = ring2k_t; const size_t rhs_nbits = maxBitWidth(rhs); @@ -280,13 +280,13 @@ NdArrayRef XorBP::proc(KernelEvalContext*, const NdArrayRef& lhs, NdArrayRef out(makeType(out_btype, out_nbits), lhs.shape()); - return DISPATCH_UINT_PT_TYPES(lhs_ty->getBacktype(), "_", [&]() { + return DISPATCH_UINT_PT_TYPES(lhs_ty->getBacktype(), [&]() { using lhs_el_t = ScalarT; using lhs_shr_t = std::array; NdArrayView _lhs(lhs); - return DISPATCH_UINT_PT_TYPES(out_btype, "_", [&]() { + return DISPATCH_UINT_PT_TYPES(out_btype, [&]() { using out_el_t = ScalarT; using out_shr_t = std::array; @@ -311,19 +311,19 @@ NdArrayRef XorBB::proc(KernelEvalContext*, const NdArrayRef& lhs, const size_t out_nbits = std::max(lhs_ty->nbits(), rhs_ty->nbits()); const PtType out_btype = calcBShareBacktype(out_nbits); - return DISPATCH_UINT_PT_TYPES(rhs_ty->getBacktype(), "_", [&]() { + return DISPATCH_UINT_PT_TYPES(rhs_ty->getBacktype(), [&]() { using rhs_el_t = ScalarT; using rhs_shr_t = std::array; NdArrayView _rhs(rhs); - return DISPATCH_UINT_PT_TYPES(lhs_ty->getBacktype(), "_", [&]() { + return DISPATCH_UINT_PT_TYPES(lhs_ty->getBacktype(), [&]() { using lhs_el_t = ScalarT; using lhs_shr_t = std::array; NdArrayView _lhs(lhs); - return DISPATCH_UINT_PT_TYPES(out_btype, "_", [&]() { + return DISPATCH_UINT_PT_TYPES(out_btype, [&]() { using out_el_t = ScalarT; using out_shr_t = std::array; @@ -343,21 +343,24 @@ NdArrayRef XorBB::proc(KernelEvalContext*, const NdArrayRef& lhs, } NdArrayRef LShiftB::proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const { + const Sizes& bits) const { const auto* in_ty = in.eltype().as(); // TODO: the hal dtype should tell us about the max number of possible bits. const auto field = ctx->getState()->getDefaultField(); - const size_t out_nbits = std::min(in_ty->nbits() + bits, SizeOf(field) * 8); + const size_t out_nbits = + std::min(in_ty->nbits() + *std::max_element(bits.begin(), bits.end()), + SizeOf(field) * 8); const PtType out_btype = calcBShareBacktype(out_nbits); + bool is_splat = bits.size() == 1; - return DISPATCH_UINT_PT_TYPES(in_ty->getBacktype(), "_", [&]() { + return DISPATCH_UINT_PT_TYPES(in_ty->getBacktype(), [&]() { using in_el_t = ScalarT; using in_shr_t = std::array; NdArrayView _in(in); - return DISPATCH_UINT_PT_TYPES(out_btype, "_", [&]() { + return DISPATCH_UINT_PT_TYPES(out_btype, [&]() { using out_el_t = ScalarT; using out_shr_t = std::array; @@ -366,8 +369,9 @@ NdArrayRef LShiftB::proc(KernelEvalContext* ctx, const NdArrayRef& in, pforeach(0, in.numel(), [&](int64_t idx) { const auto& v = _in[idx]; - _out[idx][0] = static_cast(v[0]) << bits; - _out[idx][1] = static_cast(v[1]) << bits; + auto shift_bit = is_splat ? bits[0] : bits[idx]; + _out[idx][0] = static_cast(v[0]) << shift_bit; + _out[idx][1] = static_cast(v[1]) << shift_bit; }); return out; @@ -376,19 +380,19 @@ NdArrayRef LShiftB::proc(KernelEvalContext* ctx, const NdArrayRef& in, } NdArrayRef RShiftB::proc(KernelEvalContext*, const NdArrayRef& in, - size_t bits) const { + const Sizes& bits) const { const auto* in_ty = in.eltype().as(); - bits = std::min(in_ty->nbits(), bits); - size_t out_nbits = in_ty->nbits(); - out_nbits -= std::min(out_nbits, bits); + int64_t out_nbits = in_ty->nbits(); + out_nbits -= std::min(out_nbits, *std::min_element(bits.begin(), bits.end())); const PtType out_btype = calcBShareBacktype(out_nbits); + bool is_splat = bits.size() == 1; - return DISPATCH_UINT_PT_TYPES(in_ty->getBacktype(), "_", [&]() { + return DISPATCH_UINT_PT_TYPES(in_ty->getBacktype(), [&]() { using in_shr_t = std::array; NdArrayView _in(in); - return DISPATCH_UINT_PT_TYPES(out_btype, "_", [&]() { + return DISPATCH_UINT_PT_TYPES(out_btype, [&]() { using out_el_t = ScalarT; using out_shr_t = std::array; @@ -397,8 +401,9 @@ NdArrayRef RShiftB::proc(KernelEvalContext*, const NdArrayRef& in, pforeach(0, in.numel(), [&](int64_t idx) { const auto& v = _in[idx]; - _out[idx][0] = static_cast(v[0] >> bits); - _out[idx][1] = static_cast(v[1] >> bits); + auto shift_bit = is_splat ? bits[0] : bits[idx]; + _out[idx][0] = static_cast(v[0] >> shift_bit); + _out[idx][1] = static_cast(v[1] >> shift_bit); }); return out; @@ -407,9 +412,10 @@ NdArrayRef RShiftB::proc(KernelEvalContext*, const NdArrayRef& in, } NdArrayRef ARShiftB::proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const { + const Sizes& bits) const { const auto field = ctx->getState()->getDefaultField(); const auto* in_ty = in.eltype().as(); + bool is_splat = bits.size() == 1; // arithmetic right shift expects to work on ring, or the behaviour is // undefined. @@ -418,7 +424,7 @@ NdArrayRef ARShiftB::proc(KernelEvalContext* ctx, const NdArrayRef& in, const PtType out_btype = in_ty->getBacktype(); const size_t out_nbits = in_ty->nbits(); - return DISPATCH_UINT_PT_TYPES(in_ty->getBacktype(), "_", [&]() { + return DISPATCH_UINT_PT_TYPES(in_ty->getBacktype(), [&]() { using el_t = std::make_signed_t; using shr_t = std::array; @@ -428,8 +434,9 @@ NdArrayRef ARShiftB::proc(KernelEvalContext* ctx, const NdArrayRef& in, pforeach(0, in.numel(), [&](int64_t idx) { const auto& v = _in[idx]; - _out[idx][0] = v[0] >> bits; - _out[idx][1] = v[1] >> bits; + auto shift_bit = is_splat ? bits[0] : bits[idx]; + _out[idx][0] = v[0] >> shift_bit; + _out[idx][1] = v[1] >> shift_bit; }); return out; @@ -444,13 +451,13 @@ NdArrayRef BitrevB::proc(KernelEvalContext*, const NdArrayRef& in, size_t start, const size_t out_nbits = std::max(in_ty->nbits(), end); const PtType out_btype = calcBShareBacktype(out_nbits); - return DISPATCH_UINT_PT_TYPES(in_ty->getBacktype(), "_", [&]() { + return DISPATCH_UINT_PT_TYPES(in_ty->getBacktype(), [&]() { using in_el_t = ScalarT; using in_shr_t = std::array; NdArrayView _in(in); - return DISPATCH_UINT_PT_TYPES(out_btype, "_", [&]() { + return DISPATCH_UINT_PT_TYPES(out_btype, [&]() { using out_el_t = ScalarT; using out_shr_t = std::array; @@ -488,7 +495,7 @@ NdArrayRef BitIntlB::proc(KernelEvalContext*, const NdArrayRef& in, SPU_ENFORCE(absl::has_single_bit(nbits)); NdArrayRef out(in.eltype(), in.shape()); - DISPATCH_UINT_PT_TYPES(in_ty->getBacktype(), "_", [&]() { + DISPATCH_UINT_PT_TYPES(in_ty->getBacktype(), [&]() { using el_t = ScalarT; using shr_t = std::array; NdArrayView _out(out); @@ -511,7 +518,7 @@ NdArrayRef BitDeintlB::proc(KernelEvalContext*, const NdArrayRef& in, SPU_ENFORCE(absl::has_single_bit(nbits)); NdArrayRef out(in.eltype(), in.shape()); - DISPATCH_UINT_PT_TYPES(in_ty->getBacktype(), "_", [&]() { + DISPATCH_UINT_PT_TYPES(in_ty->getBacktype(), [&]() { using el_t = ScalarT; using shr_t = std::array; NdArrayView _out(out); diff --git a/libspu/mpc/aby3/boolean.h b/libspu/mpc/aby3/boolean.h index b3036620..9fd83b45 100644 --- a/libspu/mpc/aby3/boolean.h +++ b/libspu/mpc/aby3/boolean.h @@ -152,7 +152,7 @@ class LShiftB : public ShiftKernel { ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const override; + const Sizes& bits) const override; }; class RShiftB : public ShiftKernel { @@ -164,7 +164,7 @@ class RShiftB : public ShiftKernel { ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const override; + const Sizes& bits) const override; }; class ARShiftB : public ShiftKernel { @@ -176,7 +176,7 @@ class ARShiftB : public ShiftKernel { ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const override; + const Sizes& bits) const override; }; class BitrevB : public BitrevKernel { diff --git a/libspu/mpc/aby3/conversion.cc b/libspu/mpc/aby3/conversion.cc index 070450c5..8154acac 100644 --- a/libspu/mpc/aby3/conversion.cc +++ b/libspu/mpc/aby3/conversion.cc @@ -65,11 +65,11 @@ NdArrayRef A2B::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { auto numel = in.numel(); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using ashr_t = std::array; NdArrayView _in(in); - DISPATCH_UINT_PT_TYPES(out_btype, "_", [&]() { + DISPATCH_UINT_PT_TYPES(out_btype, [&]() { using bshr_el_t = ScalarT; using bshr_t = std::array; @@ -152,7 +152,7 @@ NdArrayRef B2AByPPA::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { if (in_nbits == 0) { // special case, it's known to be zero. - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView> _out(out); pforeach(0, numel, [&](int64_t idx) { _out[idx][0] = 0; @@ -165,11 +165,11 @@ NdArrayRef B2AByPPA::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { auto* comm = ctx->getState(); auto* prg_state = ctx->getState(); - DISPATCH_UINT_PT_TYPES(in_ty->getBacktype(), "_", [&]() { + DISPATCH_UINT_PT_TYPES(in_ty->getBacktype(), [&]() { using bshr_t = std::array; NdArrayView _in(in); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using ashr_el_t = ring2k_t; using ashr_t = std::array; @@ -324,7 +324,7 @@ NdArrayRef B2AByOT::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { if (in_nbits == 0) { // special case, it's known to be zero. - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView> _out(out); pforeach(0, numel, [&](int64_t idx) { _out[idx][0] = 0; @@ -345,12 +345,12 @@ NdArrayRef B2AByOT::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { size_t P1 = (pivot + 1) % 3; size_t P2 = (pivot + 2) % 3; - DISPATCH_UINT_PT_TYPES(in_ty->getBacktype(), "_", [&]() { + DISPATCH_UINT_PT_TYPES(in_ty->getBacktype(), [&]() { using bshr_el_t = ScalarT; using bshr_t = std::array; NdArrayView _in(in); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using ashr_el_t = ring2k_t; using ashr_t = std::array; @@ -391,11 +391,6 @@ NdArrayRef B2AByOT::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { }); } else if (comm->getRank() == P1) { // the receiver - prg_state->fillPrssPair(nullptr, nullptr, r0.size(), - PrgState::GenPrssCtrl::None); - prg_state->fillPrssPair(nullptr, nullptr, r0.size(), - PrgState::GenPrssCtrl::None); - auto b2 = bitDecompose(getShare(in, 0), in_nbits); // ot.recv @@ -494,12 +489,12 @@ NdArrayRef B2AByOT::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { NdArrayRef lo(out_type, in.shape()); NdArrayRef hi(out_type, in.shape()); - DISPATCH_UINT_PT_TYPES(in_ty->getBacktype(), "_", [&]() { + DISPATCH_UINT_PT_TYPES(in_ty->getBacktype(), [&]() { using in_el_t = ScalarT; using in_shr_t = std::array; NdArrayView _in(in); - DISPATCH_UINT_PT_TYPES(out_backtype, "_", [&]() { + DISPATCH_UINT_PT_TYPES(out_backtype, [&]() { using out_el_t = ScalarT; using out_shr_t = std::array; @@ -574,7 +569,7 @@ NdArrayRef MsbA2B::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { makeType(GetStorageType(field), SizeOf(field) * 8); NdArrayRef m(bshr_type, in.shape()); NdArrayRef n(bshr_type, in.shape()); - DISPATCH_ALL_FIELDS(field, "aby3.msb.split", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using el_t = ring2k_t; using shr_t = std::array; @@ -620,7 +615,9 @@ NdArrayRef MsbA2B::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { // Compute the k'th bit. // (m^n)[k] ^ carry - auto msb = xor_bb(sctx, rshift_b(sctx, xor_bb(sctx, wrap_m, wrap_n), nbits), + auto msb = xor_bb(sctx, + rshift_b(sctx, xor_bb(sctx, wrap_m, wrap_n), + {static_cast(nbits)}), carry); return UnwrapValue(msb); @@ -660,10 +657,10 @@ NdArrayRef eqz(KernelEvalContext* ctx, const NdArrayRef& in) { size_t P1 = (pivot + 1) % 3; size_t P2 = (pivot + 2) % 3; - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using ashr_el_t = ring2k_t; using ashr_t = std::array; - DISPATCH_UINT_PT_TYPES(in_bshr_btype, "_", [&]() { + DISPATCH_UINT_PT_TYPES(in_bshr_btype, [&]() { using bshr_el_t = ScalarT; std::vector zero_flag_3pc_0(numel); std::vector zero_flag_3pc_1(numel); @@ -715,10 +712,6 @@ NdArrayRef eqz(KernelEvalContext* ctx, const NdArrayRef& in) { PrgState::GenPrssCtrl::First); } else { pforeach(0, numel, [&](int64_t idx) { a_s[idx] = _in[idx][1]; }); - prg_state->fillPrssPair({}, {}, numel, - PrgState::GenPrssCtrl::None); - prg_state->fillPrssPair({}, {}, numel, - PrgState::GenPrssCtrl::None); r_arith = comm->recv(P0, "r_arith"); r_bool = comm->recv(P0, "r_bool"); } @@ -754,8 +747,6 @@ NdArrayRef eqz(KernelEvalContext* ctx, const NdArrayRef& in) { // P1 zero_flag = (not(c_p xor [r]b0)^ rz, rb1) pforeach(0, numel, [&](int64_t idx) { zero_flag_3pc_1[idx] = r_bool[idx]; }); - prg_state->fillPrssPair({}, {}, numel, - PrgState::GenPrssCtrl::None); auto flag_split = comm->recv(P1, "flag_split"); pforeach(0, numel, [&](int64_t idx) { @@ -866,7 +857,7 @@ NdArrayRef EqualAA::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, const auto field = lhs_ty->field(); NdArrayRef out(makeType(field), lhs.shape()); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using shr_t = std::array; NdArrayView _out(out); NdArrayView _lhs(lhs); @@ -893,7 +884,7 @@ NdArrayRef EqualAP::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, auto rank = comm->getRank(); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using el_t = ring2k_t; using shr_t = std::array; diff --git a/libspu/mpc/aby3/io.cc b/libspu/mpc/aby3/io.cc index bc3be20c..3ded34a9 100644 --- a/libspu/mpc/aby3/io.cc +++ b/libspu/mpc/aby3/io.cc @@ -148,7 +148,7 @@ NdArrayRef Aby3Io::fromShares(const std::vector& shares) const { SPU_ENFORCE(field_ == eltype.as()->field()); NdArrayRef out(makeType(field_), shares[0].shape()); - DISPATCH_ALL_FIELDS(field_, "_", [&]() { + DISPATCH_ALL_FIELDS(field_, [&]() { using el_t = ring2k_t; using shr_t = std::array; NdArrayView _out(out); @@ -166,10 +166,10 @@ NdArrayRef Aby3Io::fromShares(const std::vector& shares) const { } else if (eltype.isa()) { NdArrayRef out(makeType(field_), shares[0].shape()); - DISPATCH_ALL_FIELDS(field_, "_", [&]() { + DISPATCH_ALL_FIELDS(field_, [&]() { NdArrayView _out(out); - DISPATCH_UINT_PT_TYPES(eltype.as()->getBacktype(), "_", [&] { + DISPATCH_UINT_PT_TYPES(eltype.as()->getBacktype(), [&] { using shr_t = std::array; for (size_t si = 0; si < shares.size(); si++) { NdArrayView _s(shares[si]); diff --git a/libspu/mpc/aby3/oram.cc b/libspu/mpc/aby3/oram.cc index 2ccf83f0..395a05ba 100644 --- a/libspu/mpc/aby3/oram.cc +++ b/libspu/mpc/aby3/oram.cc @@ -37,7 +37,7 @@ NdArrayRef OramOneHotAA::proc(KernelEvalContext *ctx, const NdArrayRef &in, const auto field = eltype.as()->field(); NdArrayRef out(makeType(field), {s}); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using el_t = ring2k_t; using shr_t = std::array; NdArrayView out_(out); @@ -91,7 +91,7 @@ NdArrayRef OramOneHotAP::proc(KernelEvalContext *ctx, const NdArrayRef &in, const auto numel = in.numel(); NdArrayRef out(makeType(field), {s}); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using el_t = ring2k_t; using shr_t = std::array; NdArrayView out_(out); @@ -150,7 +150,7 @@ NdArrayRef OramReadOA::proc(KernelEvalContext *ctx, const NdArrayRef &onehot, NdArrayRef out(makeType(field), {1, index_times}); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using el_t = ring2k_t; using shr_t = std::array; @@ -208,7 +208,7 @@ NdArrayRef OramReadOP::proc(KernelEvalContext *ctx, const NdArrayRef &onehot, auto o2 = getSecondShare(out); int64_t db_numel = onehot.numel(); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using el_t = ring2k_t; using shr_t = std::array; @@ -286,15 +286,11 @@ Triple> genOramBeaverPrim(KernelEvalContext *ctx, int64_t num, std::vector beaver_triple(num * 3); if (comm->getRank() == adjust_rank) { - prg->fillPrssPair(nullptr, nullptr, num * 3, - PrgState::GenPrssCtrl::None); prg->fillPrssPair(nullptr, beaver_triple.data(), num * 3, PrgState::GenPrssCtrl::Second); } else { prg->fillPrssPair(beaver_triple.data(), nullptr, num * 3, PrgState::GenPrssCtrl::First); - prg->fillPrssPair(nullptr, nullptr, num * 3, - PrgState::GenPrssCtrl::None); } std::vector a(beaver_triple.begin(), beaver_triple.begin() + num); diff --git a/libspu/mpc/aby3/ot.cc b/libspu/mpc/aby3/ot.cc index 122e0fa3..8f5f4863 100644 --- a/libspu/mpc/aby3/ot.cc +++ b/libspu/mpc/aby3/ot.cc @@ -71,8 +71,6 @@ std::pair Ot3::genMasks() { } } else { SPU_ENFORCE(comm_->getRank() == roles_.receiver); - prg_state_->genPrssPair(field_, shape_, PrgState::GenPrssCtrl::None); - prg_state_->genPrssPair(field_, shape_, PrgState::GenPrssCtrl::None); } return {w0, w1}; diff --git a/libspu/mpc/aby3/permute.cc b/libspu/mpc/aby3/permute.cc index edbeed7f..fb80bfd0 100644 --- a/libspu/mpc/aby3/permute.cc +++ b/libspu/mpc/aby3/permute.cc @@ -30,7 +30,7 @@ PermVector ring2pv(const NdArrayRef& x) { x.eltype()); const auto field = x.eltype().as()->field(); PermVector pv(x.numel()); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _x(x); pforeach(0, x.numel(), [&](int64_t idx) { pv[idx] = int64_t(_x[idx]); }); }); @@ -54,7 +54,7 @@ NdArrayRef RandPermM::proc(KernelEvalContext* ctx, const Shape& shape) const { const auto field = out.eltype().as()->field(); auto out1 = getFirstShare(out); auto out2 = getSecondShare(out); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _out1(out1); NdArrayView _out2(out2); pforeach(0, out.numel(), [&](int64_t idx) { @@ -78,7 +78,7 @@ NdArrayRef PermAM::proc(KernelEvalContext* ctx, const NdArrayRef& in, PermVector pv_next = ring2pv(getSecondShare(perm)); NdArrayRef out(in.eltype(), in.shape()); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using el_t = ring2k_t; using shr_t = std::array; @@ -198,7 +198,7 @@ NdArrayRef InvPermAM::proc(KernelEvalContext* ctx, const NdArrayRef& in, PermVector pv_next = ring2pv(getSecondShare(perm)); NdArrayRef out(in.eltype(), in.shape()); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using el_t = ring2k_t; using shr_t = std::array; diff --git a/libspu/mpc/aby3/value.h b/libspu/mpc/aby3/value.h index fbb413a0..8396b121 100644 --- a/libspu/mpc/aby3/value.h +++ b/libspu/mpc/aby3/value.h @@ -58,7 +58,7 @@ std::vector getShareAs(const NdArrayRef& in, size_t share_idx) { auto numel = in.numel(); std::vector res(numel); - DISPATCH_UINT_PT_TYPES(share.eltype().as()->pt_type(), "_", [&]() { + DISPATCH_UINT_PT_TYPES(share.eltype().as()->pt_type(), [&]() { NdArrayView _share(share); for (auto idx = 0; idx < numel; ++idx) { res[idx] = _share[idx]; diff --git a/libspu/mpc/api.cc b/libspu/mpc/api.cc index 2008e8aa..58a1f4c1 100644 --- a/libspu/mpc/api.cc +++ b/libspu/mpc/api.cc @@ -14,8 +14,6 @@ #include "libspu/mpc/api.h" -#include - #include "libspu/core/trace.h" #include "libspu/mpc/ab_api.h" @@ -303,13 +301,14 @@ Value msb_s(SPUContext* ctx, const Value& x) { if (ctx->hasKernel("msb_a2b")) { if (IsB(x)) { - return rshift_b(ctx, x, SizeOf(field) * 8 - 1); + return rshift_b(ctx, x, {static_cast(SizeOf(field) * 8 - 1)}); } else { // fast path, directly apply msb x AShare, result a BShare. return msb_a2b(ctx, x); } } else { - return rshift_b(ctx, _2b(ctx, x), SizeOf(field) * 8 - 1); + return rshift_b(ctx, _2b(ctx, x), + {static_cast(SizeOf(field) * 8 - 1)}); } } @@ -601,7 +600,7 @@ Value xor_pp(SPUContext* ctx, const Value& x, const Value& y) { ////////////////////////////////////////////////////////////////////////////// -Value lshift_s(SPUContext* ctx, const Value& x, size_t bits) { +Value lshift_s(SPUContext* ctx, const Value& x, const Sizes& bits) { SPU_TRACE_MPC_DISP(ctx, x, bits); TRY_DISPATCH(ctx, x, bits); if (IsA(x)) { @@ -613,43 +612,43 @@ Value lshift_s(SPUContext* ctx, const Value& x, size_t bits) { } } -Value lshift_v(SPUContext* ctx, const Value& x, size_t nbits) { +Value lshift_v(SPUContext* ctx, const Value& x, const Sizes& nbits) { FORCE_DISPATCH(ctx, x, nbits); } -Value lshift_p(SPUContext* ctx, const Value& x, size_t nbits) { +Value lshift_p(SPUContext* ctx, const Value& x, const Sizes& nbits) { FORCE_DISPATCH(ctx, x, nbits); } ////////////////////////////////////////////////////////////////////////////// -Value rshift_s(SPUContext* ctx, const Value& x, size_t bits) { +Value rshift_s(SPUContext* ctx, const Value& x, const Sizes& bits) { SPU_TRACE_MPC_DISP(ctx, x, bits); TRY_DISPATCH(ctx, x, bits); return rshift_b(ctx, _2b(ctx, x), bits); } -Value rshift_v(SPUContext* ctx, const Value& x, size_t nbits) { +Value rshift_v(SPUContext* ctx, const Value& x, const Sizes& nbits) { FORCE_DISPATCH(ctx, x, nbits); } -Value rshift_p(SPUContext* ctx, const Value& x, size_t nbits) { +Value rshift_p(SPUContext* ctx, const Value& x, const Sizes& nbits) { FORCE_DISPATCH(ctx, x, nbits); } ////////////////////////////////////////////////////////////////////////////// -Value arshift_s(SPUContext* ctx, const Value& x, size_t bits) { +Value arshift_s(SPUContext* ctx, const Value& x, const Sizes& bits) { SPU_TRACE_MPC_DISP(ctx, x, bits); TRY_DISPATCH(ctx, x, bits); return arshift_b(ctx, _2b(ctx, x), bits); } -Value arshift_v(SPUContext* ctx, const Value& x, size_t nbits) { +Value arshift_v(SPUContext* ctx, const Value& x, const Sizes& nbits) { FORCE_DISPATCH(ctx, x, nbits); } -Value arshift_p(SPUContext* ctx, const Value& x, size_t nbits) { +Value arshift_p(SPUContext* ctx, const Value& x, const Sizes& nbits) { FORCE_DISPATCH(ctx, x, nbits); } @@ -662,11 +661,15 @@ Value trunc_s(SPUContext* ctx, const Value& x, size_t bits, SignType sign) { } Value trunc_v(SPUContext* ctx, const Value& x, size_t nbits, SignType sign) { - FORCE_DISPATCH(ctx, x, nbits, sign); + // FIXME: trunc_v use shift kernel + const Sizes trunc_bits = {static_cast(nbits)}; + FORCE_DISPATCH(ctx, x, trunc_bits, sign); } Value trunc_p(SPUContext* ctx, const Value& x, size_t nbits, SignType sign) { - FORCE_DISPATCH(ctx, x, nbits, sign); + // FIXME: trunc_p use shift kernel + const Sizes trunc_bits = {static_cast(nbits)}; + FORCE_DISPATCH(ctx, x, trunc_bits, sign); } ////////////////////////////////////////////////////////////////////////////// diff --git a/libspu/mpc/api.h b/libspu/mpc/api.h index 7d0b2333..ea61cc81 100644 --- a/libspu/mpc/api.h +++ b/libspu/mpc/api.h @@ -143,17 +143,17 @@ Value xor_vv(SPUContext* ctx, const Value& x, const Value& y); Value xor_vp(SPUContext* ctx, const Value& x, const Value& y); Value xor_pp(SPUContext* ctx, const Value& x, const Value& y); -Value lshift_s(SPUContext* ctx, const Value& x, size_t nbits); -Value lshift_v(SPUContext* ctx, const Value& x, size_t nbits); -Value lshift_p(SPUContext* ctx, const Value& x, size_t nbits); +Value lshift_s(SPUContext* ctx, const Value& x, const Sizes& nbits); +Value lshift_v(SPUContext* ctx, const Value& x, const Sizes& nbits); +Value lshift_p(SPUContext* ctx, const Value& x, const Sizes& nbits); -Value rshift_s(SPUContext* ctx, const Value& x, size_t nbits); -Value rshift_v(SPUContext* ctx, const Value& x, size_t nbits); -Value rshift_p(SPUContext* ctx, const Value& x, size_t nbits); +Value rshift_s(SPUContext* ctx, const Value& x, const Sizes& nbits); +Value rshift_v(SPUContext* ctx, const Value& x, const Sizes& nbits); +Value rshift_p(SPUContext* ctx, const Value& x, const Sizes& nbits); -Value arshift_s(SPUContext* ctx, const Value& x, size_t nbits); -Value arshift_v(SPUContext* ctx, const Value& x, size_t nbits); -Value arshift_p(SPUContext* ctx, const Value& x, size_t nbits); +Value arshift_s(SPUContext* ctx, const Value& x, const Sizes& nbits); +Value arshift_v(SPUContext* ctx, const Value& x, const Sizes& nbits); +Value arshift_p(SPUContext* ctx, const Value& x, const Sizes& nbits); Value trunc_s(SPUContext* ctx, const Value& x, size_t nbits, SignType sign); Value trunc_v(SPUContext* ctx, const Value& x, size_t nbits, SignType sign); diff --git a/libspu/mpc/api_test.cc b/libspu/mpc/api_test.cc index 33568d77..43cb0cf0 100644 --- a/libspu/mpc/api_test.cc +++ b/libspu/mpc/api_test.cc @@ -261,7 +261,7 @@ TEST_P(ApiTest, MsbS) { // SECURENN has an msb input range requirement here if (conf.protocol() == ProtocolKind::SECURENN) { - p0 = arshift_p(sctx.get(), p0, 1); + p0 = arshift_p(sctx.get(), p0, {1}); } auto r_s = s2p(sctx.get(), msb_s(sctx.get(), p2s(sctx.get(), p0))); @@ -272,88 +272,92 @@ TEST_P(ApiTest, MsbS) { }); } -#define TEST_UNARY_OP_WITH_BIT_S(OP) \ - TEST_P(ApiTest, OP##S) { \ - const auto factory = std::get<0>(GetParam()); \ - const RuntimeConfig& conf = std::get<1>(GetParam()); \ - const size_t npc = std::get<2>(GetParam()); \ - \ - utils::simulate( \ - npc, [&](const std::shared_ptr& lctx) { \ - auto sctx = factory(conf, lctx); \ - \ - /* GIVEN */ \ - auto x_p = rand_p(sctx.get(), kShape); \ - auto x_s = p2s(sctx.get(), x_p); \ - \ - for (auto bits : kShiftBits) { \ - if (bits >= SizeOf(conf.field()) * 8) { \ - continue; \ - } \ - /* WHEN */ \ - auto r_s = s2p(sctx.get(), OP##_s(sctx.get(), x_s, bits)); \ - auto r_p = OP##_p(sctx.get(), x_p, bits); \ - \ - /* THEN */ \ - EXPECT_VALUE_EQ(r_s, r_p); \ - } \ - }); \ +#define TEST_UNARY_OP_WITH_BIT_S(OP) \ + TEST_P(ApiTest, OP##S) { \ + const auto factory = std::get<0>(GetParam()); \ + const RuntimeConfig& conf = std::get<1>(GetParam()); \ + const size_t npc = std::get<2>(GetParam()); \ + \ + utils::simulate( \ + npc, [&](const std::shared_ptr& lctx) { \ + auto sctx = factory(conf, lctx); \ + \ + /* GIVEN */ \ + auto x_p = rand_p(sctx.get(), kShape); \ + auto x_s = p2s(sctx.get(), x_p); \ + \ + for (auto bits : kShiftBits) { \ + if (bits >= SizeOf(conf.field()) * 8) { \ + continue; \ + } \ + /* WHEN */ \ + auto r_s = s2p(sctx.get(), OP##_s(sctx.get(), x_s, \ + {static_cast(bits)})); \ + auto r_p = OP##_p(sctx.get(), x_p, {static_cast(bits)}); \ + \ + /* THEN */ \ + EXPECT_VALUE_EQ(r_s, r_p); \ + } \ + }); \ } -#define TEST_UNARY_OP_WITH_BIT_V(OP) \ - TEST_P(ApiTest, OP##V) { \ - const auto factory = std::get<0>(GetParam()); \ - const RuntimeConfig& conf = std::get<1>(GetParam()); \ - const size_t npc = std::get<2>(GetParam()); \ - \ - utils::simulate( \ - npc, [&](const std::shared_ptr& lctx) { \ - auto sctx = factory(conf, lctx); \ - \ - for (size_t rank = 0; rank < npc; rank++) { \ - /* GIVEN */ \ - auto x_p = rand_p(sctx.get(), kShape); \ - auto x_v = p2v(sctx.get(), x_p, rank); \ - \ - for (auto bits : kShiftBits) { \ - if (bits >= SizeOf(conf.field()) * 8) { \ - continue; \ - } \ - /* WHEN */ \ - auto r_v = v2p(sctx.get(), OP##_v(sctx.get(), x_v, bits)); \ - auto r_p = OP##_p(sctx.get(), x_p, bits); \ - \ - /* THEN */ \ - EXPECT_VALUE_EQ(r_v, r_p); \ - } \ - } \ - }); \ +#define TEST_UNARY_OP_WITH_BIT_V(OP) \ + TEST_P(ApiTest, OP##V) { \ + const auto factory = std::get<0>(GetParam()); \ + const RuntimeConfig& conf = std::get<1>(GetParam()); \ + const size_t npc = std::get<2>(GetParam()); \ + \ + utils::simulate( \ + npc, [&](const std::shared_ptr& lctx) { \ + auto sctx = factory(conf, lctx); \ + \ + for (size_t rank = 0; rank < npc; rank++) { \ + /* GIVEN */ \ + auto x_p = rand_p(sctx.get(), kShape); \ + auto x_v = p2v(sctx.get(), x_p, rank); \ + \ + for (auto bits : kShiftBits) { \ + if (bits >= SizeOf(conf.field()) * 8) { \ + continue; \ + } \ + /* WHEN */ \ + auto r_v = \ + v2p(sctx.get(), \ + OP##_v(sctx.get(), x_v, {static_cast(bits)})); \ + auto r_p = \ + OP##_p(sctx.get(), x_p, {static_cast(bits)}); \ + \ + /* THEN */ \ + EXPECT_VALUE_EQ(r_v, r_p); \ + } \ + } \ + }); \ } -#define TEST_UNARY_OP_WITH_BIT_P(OP) \ - TEST_P(ApiTest, OP##P) { \ - const auto factory = std::get<0>(GetParam()); \ - const RuntimeConfig& conf = std::get<1>(GetParam()); \ - const size_t npc = std::get<2>(GetParam()); \ - \ - utils::simulate(npc, \ - [&](const std::shared_ptr& lctx) { \ - auto sctx = factory(conf, lctx); \ - \ - /* GIVEN */ \ - auto p0 = rand_p(sctx.get(), kShape); \ - \ - for (auto bits : kShiftBits) { /* WHEN */ \ - if (bits >= SizeOf(conf.field()) * 8) { \ - continue; \ - } \ - auto r_p = OP##_p(sctx.get(), p0, bits); \ - auto r_pp = OP##_p(sctx.get(), p0, bits); \ - \ - /* THEN */ \ - EXPECT_VALUE_EQ(r_p, r_pp); \ - } \ - }); \ +#define TEST_UNARY_OP_WITH_BIT_P(OP) \ + TEST_P(ApiTest, OP##P) { \ + const auto factory = std::get<0>(GetParam()); \ + const RuntimeConfig& conf = std::get<1>(GetParam()); \ + const size_t npc = std::get<2>(GetParam()); \ + \ + utils::simulate( \ + npc, [&](const std::shared_ptr& lctx) { \ + auto sctx = factory(conf, lctx); \ + \ + /* GIVEN */ \ + auto p0 = rand_p(sctx.get(), kShape); \ + \ + for (auto bits : kShiftBits) { /* WHEN */ \ + if (bits >= SizeOf(conf.field()) * 8) { \ + continue; \ + } \ + auto r_p = OP##_p(sctx.get(), p0, {static_cast(bits)}); \ + auto r_pp = OP##_p(sctx.get(), p0, {static_cast(bits)}); \ + \ + /* THEN */ \ + EXPECT_VALUE_EQ(r_p, r_pp); \ + } \ + }); \ } #define TEST_UNARY_OP_WITH_BIT(OP) \ @@ -379,12 +383,13 @@ TEST_P(ApiTest, TruncS) { : kShape); // TODO: here we assume has msb error, only use lowest 10 bits. - p0 = arshift_p(sctx.get(), p0, SizeOf(conf.field()) * 8 - 10); + p0 = arshift_p(sctx.get(), p0, + {static_cast(SizeOf(conf.field()) * 8 - 10)}); const size_t bits = 2; auto r_s = s2p(sctx.get(), trunc_s(sctx.get(), p2s(sctx.get(), p0), bits, SignType::Unknown)); - auto r_p = arshift_p(sctx.get(), p0, bits); + auto r_p = arshift_p(sctx.get(), p0, {bits}); /* THEN */ EXPECT_VALUE_ALMOST_EQ(r_s, r_p, npc); diff --git a/libspu/mpc/cheetah/arith/cheetah_conv2d_test.cc b/libspu/mpc/cheetah/arith/cheetah_conv2d_test.cc index b05235a6..d71d2bb1 100644 --- a/libspu/mpc/cheetah/arith/cheetah_conv2d_test.cc +++ b/libspu/mpc/cheetah/arith/cheetah_conv2d_test.cc @@ -84,7 +84,7 @@ TEST_P(CheetahConv2dTest, Basic) { flatten(ring_add(toNdArray(result[0]), toNdArray(result[1]))); const int64_t kMaxDiff = 1; - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { ArrayView c(computed); NdArrayView exp(expected); for (auto idx = 0; idx < expected.numel(); idx++) { diff --git a/libspu/mpc/cheetah/arith/cheetah_dot_test.cc b/libspu/mpc/cheetah/arith/cheetah_dot_test.cc index 79697937..df93112d 100644 --- a/libspu/mpc/cheetah/arith/cheetah_dot_test.cc +++ b/libspu/mpc/cheetah/arith/cheetah_dot_test.cc @@ -68,7 +68,7 @@ TEST_P(CheetahDotTest, Basic) { EXPECT_EQ(expected.numel(), computed.numel()); const int64_t kMaxDiff = 1; - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto e = NdArrayView(expected); auto c = NdArrayView(computed); @@ -119,7 +119,7 @@ TEST_P(CheetahDotTest, BatchDot) { [[maybe_unused]] constexpr int64_t kMaxDiff = 1; int64_t max_diff = 0; - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto e = NdArrayView(expected); auto c = NdArrayView(computed); diff --git a/libspu/mpc/cheetah/arith/common.cc b/libspu/mpc/cheetah/arith/common.cc index 3deda13e..f87d92ab 100644 --- a/libspu/mpc/cheetah/arith/common.cc +++ b/libspu/mpc/cheetah/arith/common.cc @@ -102,7 +102,7 @@ NdArrayRef ring_conv2d(const NdArrayRef &tensor, const NdArrayRef &filter, NdArrayRef _filter = filter.reshape(fs); NdArrayRef _ret = ring_zeros(field, result_shape); - DISPATCH_ALL_FIELDS(field, "ring_conv2d", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { // NOTE(juhou): valid padding so offset are always 0. constexpr int64_t padh = 0; constexpr int64_t padw = 0; diff --git a/libspu/mpc/cheetah/arith/matmat_prot.cc b/libspu/mpc/cheetah/arith/matmat_prot.cc index ebb80f9f..919a13fa 100644 --- a/libspu/mpc/cheetah/arith/matmat_prot.cc +++ b/libspu/mpc/cheetah/arith/matmat_prot.cc @@ -113,7 +113,7 @@ NdArrayRef ConcatSubMatrix(const NdArrayRef& mat, const Shape2D& mat_shape, // NOTE: zero padding via initialization NdArrayRef flatten = ring_zeros(field, {num_coeff}); - DISPATCH_ALL_FIELDS(field, "ConcatSubMat", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using uT = std::make_unsigned::type; for (int64_t r = 0, rr = starts[0]; r < extents[0]; ++r, ++rr) { @@ -423,7 +423,7 @@ NdArrayRef MatMatProtocol::ParseResult(FieldType field, const Meta& meta, for (int64_t r = 0; r < row_ext; ++r) { for (int64_t c = 0; c < col_ext; ++c) { int64_t dst_idx = (r + row_start) * meta.dims[2] + col_start + c; - DISPATCH_ALL_FIELDS(field, "ParseResult", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { matmat.at(dst_idx) = result_poly.at(r * subdims[2] + c); }); @@ -532,13 +532,13 @@ NdArrayRef MatMatProtocol::ParsePackLWEsResult( std::vector decoded_vectors(ans_poly.size()); for (size_t i = 0; i < ans_poly.size(); ++i) { decoded_vectors[i] = - msh.ModulusDownRNS(field, {(int64_t)packing_width}, + msh.ModulusDownRNS(field, {static_cast(packing_width)}, {ans_poly[i].data(), ans_poly[i].coeff_count()}); } NdArrayRef matmat = ring_zeros(field, {meta.dims[0] * meta.dims[2]}); - DISPATCH_ALL_FIELDS(field, "pack_lwes_results", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView xmatmat(matmat); for (size_t i = 0; i < ans_poly.size(); ++i) { diff --git a/libspu/mpc/cheetah/arith/matmat_prot_test.cc b/libspu/mpc/cheetah/arith/matmat_prot_test.cc index 0b01f0d0..3a5f6521 100644 --- a/libspu/mpc/cheetah/arith/matmat_prot_test.cc +++ b/libspu/mpc/cheetah/arith/matmat_prot_test.cc @@ -153,7 +153,7 @@ TEST_P(MatMatProtTest, Plain) { EXPECT_EQ(expected.numel(), computed.numel()); - DISPATCH_ALL_FIELDS(field_, "", [&]() { + DISPATCH_ALL_FIELDS(field_, [&]() { auto xe = NdArrayView(expected); auto xc = NdArrayView(computed); for (int64_t i = 0; i < xc.numel(); ++i) { @@ -215,7 +215,7 @@ TEST_P(MatMatProtTest, EncLHS) { EXPECT_EQ(expected.numel(), computed.numel()); - DISPATCH_ALL_FIELDS(field_, "", [&]() { + DISPATCH_ALL_FIELDS(field_, [&]() { auto xe = NdArrayView(expected); auto xc = NdArrayView(computed); for (int64_t i = 0; i < xc.numel(); ++i) { @@ -276,7 +276,7 @@ TEST_P(MatMatProtTest, EncRHS) { EXPECT_EQ(expected.numel(), computed.numel()); - DISPATCH_ALL_FIELDS(field_, "", [&]() { + DISPATCH_ALL_FIELDS(field_, [&]() { auto xe = NdArrayView(expected); auto xc = NdArrayView(computed); for (int64_t i = 0; i < xc.numel(); ++i) { diff --git a/libspu/mpc/cheetah/arith/vector_encoder.cc b/libspu/mpc/cheetah/arith/vector_encoder.cc index 52b71488..3ee37a2b 100644 --- a/libspu/mpc/cheetah/arith/vector_encoder.cc +++ b/libspu/mpc/cheetah/arith/vector_encoder.cc @@ -16,7 +16,6 @@ #include "libspu/core/prelude.h" #include "libspu/core/type_util.h" -#include "libspu/core/xt_helper.h" #include "libspu/mpc/utils/ring_ops.h" namespace spu::mpc::cheetah { @@ -84,7 +83,7 @@ void VectorEncoder::Backward(const NdArrayRef &vec, RLWEPt *out, const auto field = eltype.as()->field(); - DISPATCH_ALL_FIELDS(field, "Backward", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto tmp_buff = ring_zeros(field, {(int64_t)poly_deg_}); auto xvec = NdArrayView(vec); auto xtmp = NdArrayView(tmp_buff); diff --git a/libspu/mpc/cheetah/arith/vector_encoder_test.cc b/libspu/mpc/cheetah/arith/vector_encoder_test.cc index 6056a5c6..0231eeb7 100644 --- a/libspu/mpc/cheetah/arith/vector_encoder_test.cc +++ b/libspu/mpc/cheetah/arith/vector_encoder_test.cc @@ -125,7 +125,7 @@ TEST_P(VectorEncoderTest, ForwardBackward) { auto computed = ms_helper_->ModulusDownRNS(field_, {1L}, absl::MakeSpan(cnst)); - DISPATCH_ALL_FIELDS(field_, "Check", [&]() { + DISPATCH_ALL_FIELDS(field_, [&]() { NdArrayView got(computed); NdArrayView v0(vec0); NdArrayView v1(vec1); diff --git a/libspu/mpc/cheetah/arithmetic.cc b/libspu/mpc/cheetah/arithmetic.cc index e90e80ae..c1765b46 100644 --- a/libspu/mpc/cheetah/arithmetic.cc +++ b/libspu/mpc/cheetah/arithmetic.cc @@ -74,7 +74,7 @@ NdArrayRef MsbA2B::proc(KernelEvalContext* ctx, const NdArrayRef& x) const { const int rank = ctx->getState()->getRank(); - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using u2k = std::make_unsigned::type; const u2k mask = (static_cast(1) << shft) - 1; NdArrayRef adjusted = ring_zeros(field, x.shape()); diff --git a/libspu/mpc/cheetah/arithmetic.h b/libspu/mpc/cheetah/arithmetic.h index a0a1a5a6..26ecbb93 100644 --- a/libspu/mpc/cheetah/arithmetic.h +++ b/libspu/mpc/cheetah/arithmetic.h @@ -257,7 +257,7 @@ class LShiftA : public ShiftKernel { ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const override; + const Sizes& bits) const override; }; class MsbA2B : public UnaryKernel { diff --git a/libspu/mpc/cheetah/arithmetic_semi2k.cc b/libspu/mpc/cheetah/arithmetic_semi2k.cc index 89ef7daa..dff06e17 100644 --- a/libspu/mpc/cheetah/arithmetic_semi2k.cc +++ b/libspu/mpc/cheetah/arithmetic_semi2k.cc @@ -24,7 +24,7 @@ namespace spu::mpc::cheetah { NdArrayRef RandA::proc(KernelEvalContext* ctx, const Shape& shape) const { auto* prg_state = ctx->getState(); const auto field = ctx->getState()->getDefaultField(); - return ring_rshift(prg_state->genPriv(field, shape), 2) + return ring_rshift(prg_state->genPriv(field, shape), {2}) .as(makeType(field)); } @@ -59,7 +59,7 @@ NdArrayRef A2V::proc(KernelEvalContext* ctx, const NdArrayRef& in, auto numel = in.numel(); - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { std::vector share(numel); NdArrayView _in(in); pforeach(0, numel, [&](int64_t idx) { share[idx] = _in[idx]; }); @@ -143,10 +143,7 @@ NdArrayRef MatMulAP::proc(KernelEvalContext*, const NdArrayRef& x, } NdArrayRef LShiftA::proc(KernelEvalContext*, const NdArrayRef& in, - size_t bits) const { - const auto field = in.eltype().as()->field(); - bits %= SizeOf(field) * 8; - + const Sizes& bits) const { return ring_lshift(in, bits).as(in.eltype()); } diff --git a/libspu/mpc/cheetah/boolean.h b/libspu/mpc/cheetah/boolean.h index cdca12db..2fbb11fd 100644 --- a/libspu/mpc/cheetah/boolean.h +++ b/libspu/mpc/cheetah/boolean.h @@ -116,7 +116,7 @@ class LShiftB : public ShiftKernel { ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t shift) const override; + const Sizes& shift) const override; }; class RShiftB : public ShiftKernel { @@ -128,7 +128,7 @@ class RShiftB : public ShiftKernel { ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t shift) const override; + const Sizes& shift) const override; }; class ARShiftB : public ShiftKernel { @@ -140,7 +140,7 @@ class ARShiftB : public ShiftKernel { ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t shift) const override; + const Sizes& shift) const override; }; class BitrevB : public BitrevKernel { diff --git a/libspu/mpc/cheetah/boolean_semi2k.cc b/libspu/mpc/cheetah/boolean_semi2k.cc index 786d0c38..697b2287 100644 --- a/libspu/mpc/cheetah/boolean_semi2k.cc +++ b/libspu/mpc/cheetah/boolean_semi2k.cc @@ -25,7 +25,7 @@ namespace { size_t getNumBits(const NdArrayRef& in) { if (in.eltype().isa()) { const auto field = in.eltype().as()->field(); - return DISPATCH_ALL_FIELDS(field, "_", + return DISPATCH_ALL_FIELDS(field, [&]() { return maxBitWidth(in); }); } else if (in.eltype().isa()) { return in.eltype().as()->nbits(); @@ -93,7 +93,7 @@ NdArrayRef AndBP::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, const size_t out_nbits = std::min(getNumBits(lhs), getNumBits(rhs)); NdArrayRef out(makeType(field, out_nbits), lhs.shape()); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _lhs(lhs); NdArrayView _rhs(rhs); NdArrayView _out(out); @@ -130,32 +130,30 @@ NdArrayRef XorBB::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, } NdArrayRef LShiftB::proc(KernelEvalContext*, const NdArrayRef& in, - size_t shift) const { + const Sizes& shift) const { const auto field = in.eltype().as()->field(); - shift %= SizeOf(field) * 8; - - size_t out_nbits = in.eltype().as()->nbits() + shift; + size_t out_nbits = in.eltype().as()->nbits() + + *std::max_element(shift.begin(), shift.end()); out_nbits = std::clamp(out_nbits, static_cast(0), SizeOf(field) * 8); return makeBShare(ring_lshift(in, shift), field, out_nbits); } NdArrayRef RShiftB::proc(KernelEvalContext*, const NdArrayRef& in, - size_t shift) const { + const Sizes& shift) const { const auto field = in.eltype().as()->field(); - shift %= SizeOf(field) * 8; - size_t nbits = in.eltype().as()->nbits(); - size_t out_nbits = nbits - std::min(nbits, shift); - SPU_ENFORCE(nbits <= SizeOf(field) * 8); + int64_t nbits = in.eltype().as()->nbits(); + int64_t out_nbits = + nbits - std::min(nbits, *std::min_element(shift.begin(), shift.end())); + SPU_ENFORCE(nbits <= static_cast(SizeOf(field) * 8)); return makeBShare(ring_rshift(in, shift), field, out_nbits); } NdArrayRef ARShiftB::proc(KernelEvalContext*, const NdArrayRef& in, - size_t shift) const { + const Sizes& shift) const { const auto field = in.eltype().as()->field(); - shift %= SizeOf(field) * 8; // arithmetic right shift expects to work on ring, or the behaviour is // undefined. @@ -183,7 +181,7 @@ NdArrayRef BitIntlB::proc(KernelEvalContext*, const NdArrayRef& in, NdArrayRef out(in.eltype(), in.shape()); auto numel = in.numel(); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _in(in); NdArrayView _out(out); @@ -204,7 +202,7 @@ NdArrayRef BitDeintlB::proc(KernelEvalContext*, const NdArrayRef& in, NdArrayRef out(in.eltype(), in.shape()); auto numel = in.numel(); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _in(in); NdArrayView _out(out); diff --git a/libspu/mpc/cheetah/nonlinear/compare_prot.cc b/libspu/mpc/cheetah/nonlinear/compare_prot.cc index 7ee00a9f..94d8b2a1 100644 --- a/libspu/mpc/cheetah/nonlinear/compare_prot.cc +++ b/libspu/mpc/cheetah/nonlinear/compare_prot.cc @@ -16,13 +16,11 @@ #include "yacl/crypto/rand/rand.h" #include "yacl/crypto/tools/prg.h" -#include "yacl/link/link.h" #include "libspu/core/type.h" #include "libspu/mpc/cheetah/ot/basic_ot_prot.h" #include "libspu/mpc/cheetah/ot/ot_util.h" #include "libspu/mpc/cheetah/type.h" -#include "libspu/mpc/common/communicator.h" #include "libspu/mpc/utils/ring_ops.h" namespace spu::mpc::cheetah { @@ -58,14 +56,14 @@ void SetLeafOTMsg(absl::Span ot_messages, uint8_t digit, NdArrayRef CompareProtocol::DoCompute(const NdArrayRef& inp, bool greater_than, NdArrayRef* keep_eq, int64_t bitwidth) { auto field = inp.eltype().as()->field(); - int64_t num_digits = CeilDiv(bitwidth, (int64_t)compare_radix_); + int64_t num_digits = CeilDiv(bitwidth, static_cast(compare_radix_)); size_t radix = static_cast(1) << compare_radix_; // one-of-N OT int64_t num_cmp = inp.numel(); // init to all zero std::vector digits(num_cmp * num_digits, 0); // Step 1 break into digits \in [0, radix) - DISPATCH_ALL_FIELDS(field, "break_digits", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using u2k = std::make_unsigned::type; const auto mask_radix = makeBitsMask(compare_radix_); NdArrayView xinp(inp); @@ -142,7 +140,7 @@ NdArrayRef CompareProtocol::DoCompute(const NdArrayRef& inp, bool greater_than, ring_zeros(field, {static_cast(num_digits * num_cmp)}) .as(boolean_t); - DISPATCH_ALL_FIELDS(field, "copy_leaf", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView xprev_cmp(prev_cmp); NdArrayView xprev_eq(prev_eq); pforeach(0, prev_cmp.numel(), [&](int64_t i) { diff --git a/libspu/mpc/cheetah/nonlinear/compare_prot_test.cc b/libspu/mpc/cheetah/nonlinear/compare_prot_test.cc index a04da8b4..26f9fbd1 100644 --- a/libspu/mpc/cheetah/nonlinear/compare_prot_test.cc +++ b/libspu/mpc/cheetah/nonlinear/compare_prot_test.cc @@ -14,12 +14,9 @@ #include "libspu/mpc/cheetah/nonlinear/compare_prot.h" -#include - #include "gtest/gtest.h" #include "libspu/mpc/cheetah/ot/basic_ot_prot.h" -#include "libspu/mpc/cheetah/type.h" #include "libspu/mpc/utils/ring_ops.h" #include "libspu/mpc/utils/simulate.h" @@ -51,7 +48,7 @@ TEST_P(CompareProtTest, Compare) { inp[0] = ring_rand(field, shape); inp[1] = ring_rand(field, shape); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto xinp = NdArrayView(inp[0]); xinp[0] = 1; xinp[1] = 10; @@ -74,7 +71,7 @@ TEST_P(CompareProtTest, Compare) { cmp_oup[rank] = _c; }); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto xout0 = NdArrayView(cmp_oup[0]); auto xout1 = NdArrayView(cmp_oup[1]); auto xinp0 = NdArrayView(inp[0]); @@ -100,7 +97,7 @@ TEST_P(CompareProtTest, CompareBitWidth) { inp[0] = ring_rand(field, {n, 2}); inp[1] = ring_rand(field, {n, 2}); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { ring2k_t mask = (static_cast(1) << bw) - 1; auto xinp = NdArrayView(inp[0]); xinp[0] = 1; @@ -142,7 +139,7 @@ TEST_P(CompareProtTest, CompareBitWidth) { cmp_oup[rank] = _c; }); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto xout0 = NdArrayView(cmp_oup[0]); auto xout1 = NdArrayView(cmp_oup[1]); auto xinp0 = NdArrayView(inp[0]); @@ -172,7 +169,7 @@ TEST_P(CompareProtTest, WithEq) { inp[1] = _inp[0].slice({0, 0, 0}, {10, 10, 10}, {2, 3, 2}); shape = inp[0].shape(); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto xinp = NdArrayView(inp[0]); xinp[0] = 1; xinp[1] = 10; @@ -197,7 +194,7 @@ TEST_P(CompareProtTest, WithEq) { eq_oup[rank] = _e; }); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto xout0 = NdArrayView(cmp_oup[0]); auto xout1 = NdArrayView(cmp_oup[1]); auto xeq0 = NdArrayView(eq_oup[0]); @@ -229,7 +226,7 @@ TEST_P(CompareProtTest, WithEqBitWidth) { inp[0] = ring_rand(field, {n, 2}); inp[1] = ring_rand(field, {n, 2}); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { ring2k_t mask = (static_cast(1) << bw) - 1; auto xinp = NdArrayView(inp[0]); xinp[0] = 1; @@ -271,7 +268,7 @@ TEST_P(CompareProtTest, WithEqBitWidth) { eq_oup[rank] = _e; }); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto xout0 = NdArrayView(cmp_oup[0]); auto xout1 = NdArrayView(cmp_oup[1]); auto xeq0 = NdArrayView(eq_oup[0]); diff --git a/libspu/mpc/cheetah/nonlinear/equal_prot.cc b/libspu/mpc/cheetah/nonlinear/equal_prot.cc index 1ad852e6..2d04c886 100644 --- a/libspu/mpc/cheetah/nonlinear/equal_prot.cc +++ b/libspu/mpc/cheetah/nonlinear/equal_prot.cc @@ -16,13 +16,11 @@ #include "yacl/crypto/rand/rand.h" #include "yacl/crypto/tools/prg.h" -#include "yacl/link/link.h" #include "libspu/core/type.h" #include "libspu/mpc/cheetah/ot/basic_ot_prot.h" #include "libspu/mpc/cheetah/ot/ot_util.h" #include "libspu/mpc/cheetah/type.h" -#include "libspu/mpc/common/communicator.h" #include "libspu/mpc/utils/ring_ops.h" namespace spu::mpc::cheetah { @@ -61,7 +59,7 @@ NdArrayRef EqualProtocol::DoCompute(const NdArrayRef& inp, size_t bit_width) { // init to all zero std::vector digits(num_cmp * num_digits, 0); - DISPATCH_ALL_FIELDS(field, "Equal_break_digits", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using u2k = std::make_unsigned::type; const auto mask_radix = makeBitsMask(compare_radix_); const auto mask_remain = makeBitsMask(remain); @@ -133,7 +131,7 @@ NdArrayRef EqualProtocol::DoCompute(const NdArrayRef& inp, size_t bit_width) { // m0[1], m1[1], ..., mN[1] // ... // m0[M], m1[M], ..., mN[M] - DISPATCH_ALL_FIELDS(field, "Equal_transpose", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView xprev_eq(prev_eq); for (int64_t r = 0; r < num_cmp; ++r) { for (int64_t c = 0; c < num_digits; ++c) { diff --git a/libspu/mpc/cheetah/nonlinear/equal_prot_test.cc b/libspu/mpc/cheetah/nonlinear/equal_prot_test.cc index 7a51cbd5..4eeead58 100644 --- a/libspu/mpc/cheetah/nonlinear/equal_prot_test.cc +++ b/libspu/mpc/cheetah/nonlinear/equal_prot_test.cc @@ -45,7 +45,7 @@ TEST_P(EqualProtTest, Basic) { inp[0] = ring_rand(field, shape); inp[1] = ring_rand(field, shape); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto xinp0 = NdArrayView(inp[0]); auto xinp1 = NdArrayView(inp[1]); std::copy_n(&xinp1[0], 5, &xinp0[0]); @@ -64,7 +64,7 @@ TEST_P(EqualProtTest, Basic) { SPU_ENFORCE_EQ(eq_oup[0].shape(), shape); SPU_ENFORCE_EQ(eq_oup[1].shape(), shape); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto xeq0 = NdArrayView(eq_oup[0]); auto xeq1 = NdArrayView(eq_oup[1]); auto xinp0 = NdArrayView(inp[0]); diff --git a/libspu/mpc/cheetah/nonlinear/truncate_prot.cc b/libspu/mpc/cheetah/nonlinear/truncate_prot.cc index 3edb5117..32e8dff1 100644 --- a/libspu/mpc/cheetah/nonlinear/truncate_prot.cc +++ b/libspu/mpc/cheetah/nonlinear/truncate_prot.cc @@ -16,7 +16,6 @@ #include "libspu/core/type.h" #include "libspu/mpc/cheetah/nonlinear/compare_prot.h" #include "libspu/mpc/cheetah/ot/basic_ot_prot.h" -#include "libspu/mpc/cheetah/ot/ot_util.h" #include "libspu/mpc/cheetah/type.h" #include "libspu/mpc/utils/ring_ops.h" @@ -66,7 +65,7 @@ NdArrayRef TruncateProtocol::ComputeWrap(const NdArrayRef& inp, wrap_bool = compare_prot.Compute(inp, true); } else { auto adjusted = ring_neg(inp); - DISPATCH_ALL_FIELDS(field, "wrap_adjust", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView xadj(adjusted); pforeach(0, inp.numel(), [&](int64_t i) { xadj[i] -= 1; }); }); @@ -93,7 +92,7 @@ NdArrayRef TruncateProtocol::MSB1ToWrap(const NdArrayRef& inp, const size_t bw = SizeOf(field) * 8; NdArrayRef cot_output = ring_zeros(field, inp.shape()); - DISPATCH_ALL_FIELDS(field, "MSB1ToWrap", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using u2k = std::make_unsigned::type; NdArrayView xinp(inp); auto xout = absl::MakeSpan(&cot_output.at(0), cot_output.numel()); @@ -148,7 +147,7 @@ NdArrayRef TruncateProtocol::MSB0ToWrap(const NdArrayRef& inp, outp = ring_randbit(field, inp.shape()); std::vector send(numel * N); - DISPATCH_ALL_FIELDS(field, "MSB0_adjust", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using u2k = std::make_unsigned::type; NdArrayView xinp(inp); NdArrayView xrnd(outp); @@ -166,7 +165,7 @@ NdArrayRef TruncateProtocol::MSB0ToWrap(const NdArrayRef& inp, sender->Flush(); } else { std::vector choices(numel, 0); - DISPATCH_ALL_FIELDS(field, "MSB0_adjust", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using u2k = std::make_unsigned::type; NdArrayView xinp(inp); for (int64_t i = 0; i < numel; ++i) { @@ -179,7 +178,7 @@ NdArrayRef TruncateProtocol::MSB0ToWrap(const NdArrayRef& inp, absl::MakeSpan(recv), nbits); outp = ring_zeros(field, inp.shape()); - DISPATCH_ALL_FIELDS(field, "MSB0_finalize", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView xoup(outp); pforeach(0, numel, [&](int64_t i) { xoup[i] = static_cast(recv[i] & 1); @@ -220,7 +219,7 @@ NdArrayRef TruncateProtocol::Compute(const NdArrayRef& inp, Meta meta) { if (rank == 0) { NdArrayRef tmp = inp.clone(); - DISPATCH_ALL_FIELDS(field, "trunc_with_heuristic", [&] { + DISPATCH_ALL_FIELDS(field, [&] { NdArrayView _inp(tmp); ring2k_t big_value = static_cast(1) << (bit_width - kHeuristicBound); @@ -230,7 +229,7 @@ NdArrayRef TruncateProtocol::Compute(const NdArrayRef& inp, Meta meta) { tmp = Compute(tmp, meta); - DISPATCH_ALL_FIELDS(field, "trunc_with_heuristic", [&] { + DISPATCH_ALL_FIELDS(field, [&] { NdArrayView _outp(tmp); ring2k_t big_value = static_cast(1) << (bit_width - kHeuristicBound - shift); @@ -246,7 +245,7 @@ NdArrayRef TruncateProtocol::Compute(const NdArrayRef& inp, Meta meta) { NdArrayRef wrap_ashr; NdArrayRef out = ring_zeros(field, inp.shape()); - return DISPATCH_ALL_FIELDS(field, "Truncate", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { const ring2k_t component = (static_cast(1) << (bit_width - 1)); NdArrayView xinp(inp); diff --git a/libspu/mpc/cheetah/nonlinear/truncate_prot_test.cc b/libspu/mpc/cheetah/nonlinear/truncate_prot_test.cc index 4cf591ac..f1c737e8 100644 --- a/libspu/mpc/cheetah/nonlinear/truncate_prot_test.cc +++ b/libspu/mpc/cheetah/nonlinear/truncate_prot_test.cc @@ -14,12 +14,9 @@ #include "libspu/mpc/cheetah/nonlinear/truncate_prot.h" -#include - #include "gtest/gtest.h" #include "libspu/mpc/cheetah/ot/basic_ot_prot.h" -#include "libspu/mpc/cheetah/type.h" #include "libspu/mpc/utils/ring_ops.h" #include "libspu/mpc/utils/simulate.h" @@ -65,7 +62,7 @@ TEST_P(TruncateProtTest, Basic) { sign = SignType::Unknown; } else { auto msg = ring_rand(field, {n}); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto xmsg = NdArrayView(msg); size_t bw = SizeOf(field) * 8; if (msb == "Zero") { @@ -111,7 +108,7 @@ TEST_P(TruncateProtTest, Basic) { EXPECT_EQ(oup[0].shape(), oup[1].shape()); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using signed_t = std::make_signed::type; using usigned_t = std::make_unsigned::type; @@ -163,9 +160,10 @@ TEST_P(TruncateProtTest, Heuristic) { NdArrayRef inp[2]; inp[0] = ring_rand(field, {n}); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto msg = ring_rand(field, {n}); - ring_rshift_(msg, TruncateProtocol::kHeuristicBound); + ring_rshift_(msg, + {static_cast(TruncateProtocol::kHeuristicBound)}); NdArrayView xmsg(msg); for (int64_t i = 0; i < n; i += 2) { xmsg[i] = -xmsg[i]; @@ -191,7 +189,7 @@ TEST_P(TruncateProtTest, Heuristic) { [[maybe_unused]] int count_zero = 0; [[maybe_unused]] int count_pos = 0; [[maybe_unused]] int count_neg = 0; - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using signed_t = std::make_signed::type; auto xout0 = NdArrayView(oup[0]); diff --git a/libspu/mpc/cheetah/ot/basic_ot_prot.cc b/libspu/mpc/cheetah/ot/basic_ot_prot.cc index 201a57ad..f5a03549 100644 --- a/libspu/mpc/cheetah/ot/basic_ot_prot.cc +++ b/libspu/mpc/cheetah/ot/basic_ot_prot.cc @@ -79,7 +79,7 @@ NdArrayRef BasicOTProtocols::PackedB2A(const NdArrayRef &inp) { } SPU_ENFORCE(nbits > 0 && nbits <= 8 * SizeOf(field)); - auto rand_bits = DISPATCH_ALL_FIELDS(field, "single_b2a", [&]() { + auto rand_bits = DISPATCH_ALL_FIELDS(field, [&]() { if ((nbits & 7) or (n * inp.elsize()) & 7) { // The SseTranspose requires the #rows and #columns is multiple of 8. // Thus, we call the less efficient RandBits on margin cases. @@ -142,7 +142,7 @@ NdArrayRef BasicOTProtocols::PackedB2A(const NdArrayRef &inp) { const int64_t n = _bits.numel() / nbits; // init as all 0s. auto iform = ring_zeros(field, inp.shape()); - DISPATCH_ALL_FIELDS(field, "conv_to_bits", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto bits = NdArrayView(_bits); auto digit = NdArrayView(iform); for (int64_t i = 0; i < n; ++i) { @@ -163,7 +163,7 @@ NdArrayRef BasicOTProtocols::PackedB2A(const NdArrayRef &inp) { // compute c + (1 - 2*c)* NdArrayRef oup = ring_zeros(field, inp.shape()); - DISPATCH_ALL_FIELDS(field, "packed_b2a", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using u2k = std::make_unsigned::type; int rank = Rank(); auto xr = NdArrayView(rand_bits); @@ -205,7 +205,7 @@ NdArrayRef BasicOTProtocols::SingleB2A(const NdArrayRef &inp, int bit_width) { const int64_t n = inp.numel(); NdArrayRef oup = ring_zeros(field, inp.shape()); - DISPATCH_ALL_FIELDS(field, "single_b2a", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using u2k = std::make_unsigned::type; const u2k msk = makeBitsMask(bit_width); auto input = NdArrayView(inp); @@ -373,7 +373,7 @@ std::array BasicOTProtocols::AndTriple(FieldType field, auto AND_b = ring_zeros(field, shape); auto AND_c = ring_zeros(field, shape); - DISPATCH_ALL_FIELDS(field, "AndTriple", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto AND_xa = NdArrayView(AND_a); auto AND_xb = NdArrayView(AND_b); auto AND_xc = NdArrayView(AND_c); @@ -430,7 +430,7 @@ std::array BasicOTProtocols::CorrelatedAndTriple( auto AND_b1 = ring_zeros(field, shape); auto AND_c1 = ring_zeros(field, shape); - DISPATCH_ALL_FIELDS(field, "AndTriple", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto AND_xa = NdArrayView(AND_a); auto AND_xb0 = NdArrayView(AND_b0); auto AND_xc0 = NdArrayView(AND_c0); @@ -465,7 +465,7 @@ NdArrayRef BasicOTProtocols::Multiplexer(const NdArrayRef &msg, std::vector sel(size); // Compute (x0 + x1) * (b0 ^ b1) // Also b0 ^ b1 = 1 - 2*b0*b1 - return DISPATCH_ALL_FIELDS(field, "Multiplexer", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _msg(msg); NdArrayView _sel(select); auto corr_data = absl::MakeSpan(&_corr_data.at(0), size); @@ -506,7 +506,7 @@ NdArrayRef BasicOTProtocols::PrivateMulxRecv(const NdArrayRef &msg, auto recv = ring_zeros(field, msg.shape()); std::vector sel(size); - DISPATCH_ALL_FIELDS(field, "convert", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _sel(select); pforeach(0, size, [&](int64_t i) { sel[i] = static_cast(_sel[i] & 1); }); @@ -526,7 +526,7 @@ NdArrayRef BasicOTProtocols::PrivateMulxRecv(const NdArrayRef &msg, // Compute (x0 + x1) * b // x0 * b + x1 * b // COT compute - DISPATCH_ALL_FIELDS(field, "MultiplexerOnPrivate", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _msg(msg); auto _recv = absl::MakeSpan(&recv.at(0), size); @@ -549,7 +549,7 @@ NdArrayRef BasicOTProtocols::PrivateMulxSend(const NdArrayRef &msg) { auto recv = ring_zeros(field, msg.shape()); // Compute (x0 + x1) * b // x0 * b + x1 * b - DISPATCH_ALL_FIELDS(field, "MultiplexerOnPrivate", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto _msg = absl::MakeConstSpan(&msg.at(0), size); auto _recv = absl::MakeSpan(&recv.at(0), size); diff --git a/libspu/mpc/cheetah/ot/basic_ot_prot_test.cc b/libspu/mpc/cheetah/ot/basic_ot_prot_test.cc index aa18321c..17f78a4d 100644 --- a/libspu/mpc/cheetah/ot/basic_ot_prot_test.cc +++ b/libspu/mpc/cheetah/ot/basic_ot_prot_test.cc @@ -82,7 +82,7 @@ TEST_P(BasicOTProtTest, SingleB2A) { auto bshr0 = ring_rand(field, shape).as(boolean_t); auto bshr1 = ring_rand(field, shape).as(boolean_t); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto mask = static_cast(-1); if (nbits > 0) { mask = (static_cast(1) << packed_nbits) - 1; @@ -110,7 +110,7 @@ TEST_P(BasicOTProtTest, SingleB2A) { EXPECT_EQ(ashr0.shape(), ashr1.shape()); EXPECT_EQ(shape, ashr0.shape()); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView b0(bshr0); NdArrayView b1(bshr1); NdArrayView a0(ashr0); @@ -138,7 +138,7 @@ TEST_P(BasicOTProtTest, PackedB2A) { auto bshr0 = ring_rand(field, shape).as(boolean_t); auto bshr1 = ring_rand(field, shape).as(boolean_t); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto mask = static_cast(-1); if (nbits > 0) { mask = (static_cast(1) << packed_nbits) - 1; @@ -165,7 +165,7 @@ TEST_P(BasicOTProtTest, PackedB2A) { EXPECT_EQ(ashr0.shape(), ashr1.shape()); EXPECT_EQ(ashr0.shape(), shape); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView b0(bshr0); NdArrayView b1(bshr1); NdArrayView a0(ashr0); @@ -197,7 +197,7 @@ TEST_P(BasicOTProtTest, PackedB2AFull) { auto bshr0 = ring_rand(field, shape).as(boolean_t); auto bshr1 = ring_rand(field, shape).as(boolean_t); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto mask = static_cast(-1); if (nbits > 0) { mask = (static_cast(1) << packed_nbits) - 1; @@ -228,7 +228,7 @@ TEST_P(BasicOTProtTest, PackedB2AFull) { EXPECT_EQ(ashr0.shape(), ashr1.shape()); EXPECT_EQ(ashr0.shape(), shape); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView b0(bshr0); NdArrayView b1(bshr1); NdArrayView a0(ashr0); @@ -267,7 +267,7 @@ TEST_P(BasicOTProtTest, AndTripleSparse) { } }); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { ring2k_t max = static_cast(1) << target_nbits; NdArrayView a0(triple[0][0]); NdArrayView b0(triple[0][1]); @@ -304,7 +304,7 @@ TEST_P(BasicOTProtTest, BitwiseAnd) { for (int i : {0, 1}) { lhs[i] = ring_rand(field, shape).as(boolean_t); rhs[i] = ring_rand(field, shape).as(boolean_t); - DISPATCH_ALL_FIELDS(field, "mask", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { ring2k_t mask = makeBitsMask(bw); NdArrayView L(lhs[i]); NdArrayView R(rhs[i]); @@ -324,7 +324,7 @@ TEST_P(BasicOTProtTest, BitwiseAnd) { auto expected = ring_and(ring_xor(lhs[0], lhs[1]), ring_xor(rhs[0], rhs[1])); auto got = ring_xor(out[0], out[1]); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView e(expected); NdArrayView g(got); @@ -350,7 +350,7 @@ TEST_P(BasicOTProtTest, AndTripleFull) { } }); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView a0(packed_triple[0][0]); NdArrayView b0(packed_triple[0][1]); NdArrayView c0(packed_triple[0][2]); @@ -383,7 +383,7 @@ TEST_P(BasicOTProtTest, Multiplexer) { auto bshr0 = ring_rand(field, shape).as(boolean_t); auto bshr1 = ring_rand(field, shape).as(boolean_t); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto mask = static_cast(1); NdArrayView xb0(bshr0); NdArrayView xb1(bshr1); @@ -407,7 +407,7 @@ TEST_P(BasicOTProtTest, Multiplexer) { EXPECT_EQ(computed[0].shape(), computed[1].shape()); EXPECT_EQ(computed[0].shape(), shape); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView a0(ashr0); NdArrayView a1(ashr1); NdArrayView b0(bshr0); @@ -444,7 +444,7 @@ TEST_P(BasicOTProtTest, CorrelatedAndTriple) { EXPECT_EQ(corr_triple[1][0].shape(), corr_triple[1][i].shape()); } - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto a0 = NdArrayView(corr_triple[0][0]); auto b0 = NdArrayView(corr_triple[0][1]); auto c0 = NdArrayView(corr_triple[0][2]); @@ -487,7 +487,7 @@ TEST_P(BasicOTProtTest, PrivateMulx) { auto ashr1 = ring_rand(field, shape); auto choices = ring_rand(field, shape).as(boolean_t); - DISPATCH_ALL_FIELDS(field, "bit", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto mask = static_cast(1); NdArrayView xb(choices); pforeach(0, xb.numel(), [&](int64_t i) { xb[i] &= mask; }); @@ -507,7 +507,7 @@ TEST_P(BasicOTProtTest, PrivateMulx) { EXPECT_EQ(computed[0].shape(), computed[1].shape()); EXPECT_EQ(computed[0].shape(), shape); - DISPATCH_ALL_FIELDS(field, "check", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView a0(ashr0); NdArrayView a1(ashr1); NdArrayView c(choices); diff --git a/libspu/mpc/cheetah/ot/emp/ferret_test.cc b/libspu/mpc/cheetah/ot/emp/ferret_test.cc index af6cb96a..b878bdbe 100644 --- a/libspu/mpc/cheetah/ot/emp/ferret_test.cc +++ b/libspu/mpc/cheetah/ot/emp/ferret_test.cc @@ -55,7 +55,7 @@ TEST_P(FerretCOTTest, ChosenCorrelationChosenChoice) { return static_cast(uniform(rdv) & 1); }); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView correlation(_correlation); std::vector computed[2]; utils::simulate(kWorldSize, [&](std::shared_ptr ctx) { @@ -86,7 +86,7 @@ TEST_P(FerretCOTTest, RndMsgRndChoice) { constexpr size_t bw = 2; size_t n = 10; - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { std::vector msg0(n); std::vector msg1(n); ring2k_t max = static_cast(1) << bw; @@ -123,7 +123,7 @@ TEST_P(FerretCOTTest, RndMsgChosenChoice) { constexpr size_t bw = 2; size_t n = 10; - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { std::vector msg0(n); std::vector msg1(n); ring2k_t max = static_cast(1) << bw; @@ -163,7 +163,7 @@ TEST_P(FerretCOTTest, ChosenMsgChosenChoice) { size_t kWorldSize = 2; int64_t n = 100; auto field = GetParam(); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using scalar_t = ring2k_t; std::default_random_engine rdv; std::uniform_int_distribution uniform(0, -1); diff --git a/libspu/mpc/cheetah/ot/ot_util.cc b/libspu/mpc/cheetah/ot/ot_util.cc index 1d308a2f..3d593512 100644 --- a/libspu/mpc/cheetah/ot/ot_util.cc +++ b/libspu/mpc/cheetah/ot/ot_util.cc @@ -61,7 +61,7 @@ NdArrayRef OpenShare(const NdArrayRef &shr, ReduceOp op, size_t nbits, size_t compact_numel = CeilDiv(numel * nbits, fwidth); NdArrayRef out(shr.eltype(), {(int64_t)numel}); - DISPATCH_ALL_FIELDS(field, "zip", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto inp = absl::MakeConstSpan(&shr.at(0), numel); auto oup = absl::MakeSpan(&out.at(0), compact_numel); @@ -94,8 +94,10 @@ NdArrayRef OpenShare(const NdArrayRef &shr, ReduceOp op, size_t nbits, __attribute__((target("sse2"))) #endif void SseTranspose(uint8_t *out, uint8_t const *inp, uint64_t nrows, uint64_t ncols) { - uint64_t rr, cc; - int i, h; + uint64_t rr; + uint64_t cc; + int i; + int h; union { __m128i x; uint8_t b[16]; @@ -116,7 +118,9 @@ void SseTranspose(uint8_t *out, uint8_t const *inp, uint64_t nrows, uint64_t nco *(uint16_t *)&OUT(rr, cc + i) = _mm_movemask_epi8(vec); } } - if (rr == nrows) return; + if (rr == nrows) { + return; + } // The remainder is a block of 8x(16n+8) bits (n may be 0). // Do a PAIR of 8x8 blocks in each step: @@ -153,9 +157,12 @@ void SseTranspose(uint8_t *out, uint8_t const *inp, uint64_t nrows, uint64_t nco if (cc == ncols) return; // Do the remaining 8x8 block: - for (i = 0; i < 8; ++i) tmp.b[i] = INP(rr + i, cc); - for (i = 8; --i >= 0; tmp.x = _mm_slli_epi64(tmp.x, 1)) + for (i = 0; i < 8; ++i) { + tmp.b[i] = INP(rr + i, cc); + } + for (i = 8; --i >= 0; tmp.x = _mm_slli_epi64(tmp.x, 1)) { OUT(rr, cc + i) = _mm_movemask_epi8(tmp.x); + } } } // namespace spu::mpc::cheetah diff --git a/libspu/mpc/cheetah/ot/ot_util_test.cc b/libspu/mpc/cheetah/ot/ot_util_test.cc index ea5e0aac..149e9694 100644 --- a/libspu/mpc/cheetah/ot/ot_util_test.cc +++ b/libspu/mpc/cheetah/ot/ot_util_test.cc @@ -14,8 +14,6 @@ #include "libspu/mpc/cheetah/ot/ot_util.h" -#include - #include "gtest/gtest.h" #include "libspu/mpc/utils/ring_ops.h" @@ -40,7 +38,7 @@ TEST_P(OtUtilTest, ZipArray) { auto unzip = ring_zeros(field, {n}); - DISPATCH_ALL_FIELDS(field, "UT_ZipArray", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { for (size_t bw : {1, 2, 4, 7, 15, 16}) { int64_t pack_load = elsze * 8 / bw; auto zip = ring_zeros(field, {(n + pack_load - 1) / pack_load}); @@ -69,7 +67,7 @@ TEST_P(OtUtilTest, ZipArrayBit) { auto unzip = ring_zeros(field, {n}); - DISPATCH_ALL_FIELDS(field, "UT_ZipArrayBit", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { const size_t elsze = SizeOf(field); for (size_t bw : {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) { size_t width = elsze * 8; diff --git a/libspu/mpc/cheetah/ot/yacl/ferret_test.cc b/libspu/mpc/cheetah/ot/yacl/ferret_test.cc index c78abe62..a1eea9e7 100644 --- a/libspu/mpc/cheetah/ot/yacl/ferret_test.cc +++ b/libspu/mpc/cheetah/ot/yacl/ferret_test.cc @@ -55,7 +55,7 @@ TEST_P(FerretCOTTest, ChosenCorrelationChosenChoice) { return static_cast(uniform(rdv) & 1); }); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView correlation(_correlation); std::vector computed[2]; utils::simulate(kWorldSize, [&](std::shared_ptr ctx) { @@ -87,7 +87,7 @@ TEST_P(FerretCOTTest, RndMsgRndChoice) { constexpr size_t bw = 2; size_t n = 10; - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { std::vector msg0(n); std::vector msg1(n); ring2k_t max = static_cast(1) << bw; @@ -125,7 +125,7 @@ TEST_P(FerretCOTTest, RndMsgChosenChoice) { constexpr size_t bw = 2; size_t n = 10; - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { std::vector msg0(n); std::vector msg1(n); ring2k_t max = static_cast(1) << bw; @@ -166,7 +166,7 @@ TEST_P(FerretCOTTest, ChosenMsgChosenChoice) { int64_t n = 1 << 10; auto field = std::get<0>(GetParam()); auto use_ss = std::get<1>(GetParam()); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using scalar_t = ring2k_t; std::default_random_engine rdv; std::uniform_int_distribution uniform(0, -1); diff --git a/libspu/mpc/cheetah/rlwe/modswitch_helper.cc b/libspu/mpc/cheetah/rlwe/modswitch_helper.cc index 69c86704..41ea9ff1 100644 --- a/libspu/mpc/cheetah/rlwe/modswitch_helper.cc +++ b/libspu/mpc/cheetah/rlwe/modswitch_helper.cc @@ -436,7 +436,7 @@ void ModulusSwitchHelper::ModulusUpAt(const NdArrayRef &src, size_t mod_idx, SPU_ENFORCE(eltype.isa(), "source must be ring_type, got={}", eltype); const auto field = eltype.as()->field(); - DISPATCH_ALL_FIELDS(field, "ModulusUpAt", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using ring2u = std::make_unsigned::type; impl_->ModulusUpAt(NdArrayView(src), mod_idx, out); }); @@ -450,7 +450,7 @@ void ModulusSwitchHelper::CenteralizeAt(const NdArrayRef &src, size_t mod_idx, SPU_ENFORCE_EQ(numel, out.size()); SPU_ENFORCE(eltype.isa(), "source must be ring_type, got={}", eltype); const auto field = eltype.as()->field(); - DISPATCH_ALL_FIELDS(field, "CenteralizeAt", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using ring2u = std::make_unsigned::type; impl_->CenteralizeAt(NdArrayView(src), mod_idx, out); }); @@ -481,7 +481,7 @@ void ModulusSwitchHelper::ModulusDownRNS(absl::Span src, size_t num_elt = out.numel(); SPU_ENFORCE_EQ(num_elt * num_modulus, src.size()); - return DISPATCH_ALL_FIELDS(field, "ModulusDownRNS", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using ring2u = std::make_unsigned::type; absl::Span out_wrap(reinterpret_cast(out.data()), num_elt); diff --git a/libspu/mpc/cheetah/rlwe/modswitch_helper_test.cc b/libspu/mpc/cheetah/rlwe/modswitch_helper_test.cc index ecb5beee..2aa7a203 100644 --- a/libspu/mpc/cheetah/rlwe/modswitch_helper_test.cc +++ b/libspu/mpc/cheetah/rlwe/modswitch_helper_test.cc @@ -120,9 +120,10 @@ TEST_P(RLWE2LWETest, ModulusSwitch_UpDown) { } // e = round(r'/Delta) mod t \in R_t auto src = absl::MakeSpan(pt.data(), pt.coeff_count()); - auto cmp = ms_helper_->ModulusDownRNS(field_, {(int64_t)poly_deg}, src); + auto cmp = ms_helper_->ModulusDownRNS( + field_, {static_cast(poly_deg)}, src); // check r =? e - DISPATCH_ALL_FIELDS(field_, "", [&]() { + DISPATCH_ALL_FIELDS(field_, [&]() { auto expected = NdArrayView(_vec); auto computed = NdArrayView(cmp); for (int64_t i = 0; i < expected.numel(); ++i) { @@ -145,7 +146,7 @@ TEST_P(RLWE2LWETest, ModulusSwitch_DownUp) { // r <- R_q RLWEPt rnd; UniformPoly(*context_, &rnd); - auto &modulus = context_->first_context_data()->parms().coeff_modulus(); + const auto &modulus = context_->first_context_data()->parms().coeff_modulus(); // b = a' - r RLWEPt poly1; { @@ -158,8 +159,8 @@ TEST_P(RLWE2LWETest, ModulusSwitch_DownUp) { // r' = round(r/Delta) mod t \in R_t // b' = round(b/Delta) mod t \in R_t - auto shr0 = ring_zeros(field_, {(int64_t)poly_deg}); - auto shr1 = ring_zeros(field_, {(int64_t)poly_deg}); + auto shr0 = ring_zeros(field_, {static_cast(poly_deg)}); + auto shr1 = ring_zeros(field_, {static_cast(poly_deg)}); { auto src = absl::MakeSpan(rnd.data(), rnd.coeff_count()); ms_helper_->ModulusDownRNS(src, shr0); @@ -168,7 +169,7 @@ TEST_P(RLWE2LWETest, ModulusSwitch_DownUp) { ms_helper_->ModulusDownRNS(src, shr1); } - DISPATCH_ALL_FIELDS(field_, "", [&]() { + DISPATCH_ALL_FIELDS(field_, [&]() { auto expected = NdArrayView(vec_a); auto computed0 = NdArrayView(shr0); auto computed1 = NdArrayView(shr1); diff --git a/libspu/mpc/cheetah/rlwe/packlwes_test.cc b/libspu/mpc/cheetah/rlwe/packlwes_test.cc index 6d924c26..4bda1a27 100644 --- a/libspu/mpc/cheetah/rlwe/packlwes_test.cc +++ b/libspu/mpc/cheetah/rlwe/packlwes_test.cc @@ -13,8 +13,6 @@ // limitations under the License. #include "libspu/mpc/cheetah/rlwe/packlwes.h" -#include - #include "gtest/gtest.h" #include "seal/seal.h" @@ -235,7 +233,8 @@ TEST_P(PackLWEsTest, Phantom) { N_encoder_->Forward(array, &pt, true); NttInplace(pt, *N_context_); - RLWECt rlwe0, rlwe1; + RLWECt rlwe0; + RLWECt rlwe1; CATCH_SEAL_ERROR(encryptor.encrypt_symmetric(pt, rlwe0)); CATCH_SEAL_ERROR(encryptor.encrypt_symmetric(pt, rlwe1)); if (rlwe0.is_ntt_form()) { @@ -345,7 +344,7 @@ void VectorEncoder::Backward(const NdArrayRef &vec, RLWEPt *out, const auto field = eltype.as()->field(); - DISPATCH_ALL_FIELDS(field, "Backward", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { auto tmp_buff = ring_zeros(field, {(int64_t)poly_deg_}); auto xvec = NdArrayView(vec); auto xtmp = NdArrayView(tmp_buff); diff --git a/libspu/mpc/cheetah/state.cc b/libspu/mpc/cheetah/state.cc index 1f754baa..dded2f5f 100644 --- a/libspu/mpc/cheetah/state.cc +++ b/libspu/mpc/cheetah/state.cc @@ -16,8 +16,6 @@ #include -#include "spdlog/spdlog.h" - #include "libspu/core/context.h" #include "libspu/core/ndarray_ref.h" #include "libspu/core/prelude.h" @@ -89,7 +87,7 @@ void CheetahMulState::makeSureCacheSize(FieldType field, int64_t numel) { cross.slice({num_beaver}, {2 * num_beaver}, {1})), ring_mul(beaver[0], beaver[1])); - DISPATCH_ALL_FIELDS(field, "makeSureCacheSize", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { for (size_t i : {0, 1, 2}) { auto tmp = ring_zeros(field, {num_beaver + cached_sze_}); NdArrayView new_cache(tmp); diff --git a/libspu/mpc/common/prg_state.h b/libspu/mpc/common/prg_state.h index 7e8938e6..47a6eb66 100644 --- a/libspu/mpc/common/prg_state.h +++ b/libspu/mpc/common/prg_state.h @@ -39,9 +39,10 @@ class PrgState : public State { uint64_t priv_counter_ = 0; // Pseudorandom Secret Sharing seeds. - uint128_t next_seed_ = 0; uint128_t self_seed_ = 0; - uint64_t prss_counter_ = 0; + uint128_t next_seed_ = 0; + uint64_t r0_counter_ = 0; // cnt for self_seed + uint64_t r1_counter_ = 0; // cnt for next_seed public: static constexpr char kBindName[] = "PrgState"; @@ -65,7 +66,7 @@ class PrgState : public State { // This correlation could be used to construct zero shares. // // Note: ignore_first, ignore_second is for perf improvement. - enum class GenPrssCtrl { Both, First, Second, None }; + enum class GenPrssCtrl { Both, First, Second }; std::pair genPrssPair(FieldType field, const Shape& shape, GenPrssCtrl ctrl); @@ -73,29 +74,21 @@ class PrgState : public State { template void fillPrssPair(T* r0, T* r1, size_t numel, GenPrssCtrl ctrl) { switch (ctrl) { - case GenPrssCtrl::None: { - // Nothing to generate, pure dummy - prss_counter_ = yacl::crypto::DummyUpdateRandomCount(prss_counter_, - numel * sizeof(T)); - return; - } case GenPrssCtrl::First: { - prss_counter_ = yacl::crypto::FillPRand( - kAesType, self_seed_, 0, prss_counter_, absl::MakeSpan(r0, numel)); + r0_counter_ = yacl::crypto::FillPRand( + kAesType, self_seed_, 0, r0_counter_, absl::MakeSpan(r0, numel)); return; } case GenPrssCtrl::Second: { - prss_counter_ = yacl::crypto::FillPRand( - kAesType, next_seed_, 0, prss_counter_, absl::MakeSpan(r1, numel)); + r1_counter_ = yacl::crypto::FillPRand( + kAesType, next_seed_, 0, r1_counter_, absl::MakeSpan(r1, numel)); return; } case GenPrssCtrl::Both: { - auto counter0 = yacl::crypto::FillPRand( - kAesType, self_seed_, 0, prss_counter_, absl::MakeSpan(r0, numel)); - auto counter1 = yacl::crypto::FillPRand( - kAesType, next_seed_, 0, prss_counter_, absl::MakeSpan(r1, numel)); - SPU_ENFORCE(counter0 == counter1); - prss_counter_ = counter0; + r0_counter_ = yacl::crypto::FillPRand( + kAesType, self_seed_, 0, r0_counter_, absl::MakeSpan(r0, numel)); + r1_counter_ = yacl::crypto::FillPRand( + kAesType, next_seed_, 0, r1_counter_, absl::MakeSpan(r1, numel)); return; } } diff --git a/libspu/mpc/common/pv2k.cc b/libspu/mpc/common/pv2k.cc index 6c777779..ce59f9bb 100644 --- a/libspu/mpc/common/pv2k.cc +++ b/libspu/mpc/common/pv2k.cc @@ -73,7 +73,7 @@ class V2P : public UnaryKernel { auto numel = in.numel(); - DISPATCH_ALL_FIELDS(field, "v2p", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { std::vector priv(numel); NdArrayView _in(in); @@ -113,7 +113,7 @@ class MakeP : public Kernel { Strides(shape.size(), 0), // strides 0); - DISPATCH_ALL_FIELDS(field, "pub2k.make_p", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { arr.at(Index(shape.size(), 0)) = static_cast(init); }); return Value(arr, DT_INVALID); @@ -176,7 +176,8 @@ class MsbP : public UnaryKernel { ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext*, const NdArrayRef& in) const override { - return ring_rshift(in, in.elsize() * 8 - 1).as(in.eltype()); + return ring_rshift(in, {static_cast(in.elsize() * 8 - 1)}) + .as(in.eltype()); } }; @@ -190,7 +191,8 @@ class MsbV : public UnaryKernel { NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in) const override { if (isOwner(ctx, in.eltype())) { - return ring_rshift(in, in.elsize() * 8 - 1).as(in.eltype()); + return ring_rshift(in, {static_cast(in.elsize() * 8 - 1)}) + .as(in.eltype()); } else { return in; } @@ -512,7 +514,7 @@ class LShiftP : public ShiftKernel { ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext*, const NdArrayRef& in, - size_t bits) const override { + const Sizes& bits) const override { return ring_lshift(in, bits).as(in.eltype()); } }; @@ -526,7 +528,7 @@ class LShiftV : public ShiftKernel { ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const override { + const Sizes& bits) const override { if (isOwner(ctx, in.eltype())) { return ring_lshift(in, bits).as(in.eltype()); } else { @@ -544,7 +546,7 @@ class RShiftP : public ShiftKernel { ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext*, const NdArrayRef& in, - size_t bits) const override { + const Sizes& bits) const override { return ring_rshift(in, bits).as(in.eltype()); } }; @@ -558,7 +560,7 @@ class RShiftV : public ShiftKernel { ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const override { + const Sizes& bits) const override { if (isOwner(ctx, in.eltype())) { return ring_rshift(in, bits).as(in.eltype()); } else { @@ -576,7 +578,7 @@ class ARShiftP : public ShiftKernel { ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext*, const NdArrayRef& in, - size_t bits) const override { + const Sizes& bits) const override { return ring_arshift(in, bits).as(in.eltype()); } }; @@ -590,7 +592,7 @@ class ARShiftV : public ShiftKernel { ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const override { + const Sizes& bits) const override { if (isOwner(ctx, in.eltype())) { return ring_arshift(in, bits).as(in.eltype()); } else { @@ -607,8 +609,8 @@ NdArrayRef rounded_arshift(const NdArrayRef& in, size_t bits) { // https://stackoverflow.com/questions/14008330/how-do-you-multiply-two-fixed-point-numbers // Under certain pattern, like sum(mul(A, B)), error can accumulate in a // fairly significant way - auto v1 = ring_arshift(in, bits); - auto v2 = ring_arshift(in, bits - 1); + auto v1 = ring_arshift(in, {static_cast(bits)}); + auto v2 = ring_arshift(in, {static_cast(bits - 1)}); ring_and_(v2, ring_ones(in.eltype().as()->field(), in.shape())); ring_add_(v1, v2); return v1; @@ -623,8 +625,9 @@ class TruncP : public ShiftKernel { ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext*, const NdArrayRef& in, - size_t bits) const override { - return rounded_arshift(in, bits).as(in.eltype()); + const Sizes& bits) const override { + SPU_ENFORCE(bits.size() == 1, "truncation bits should be splat"); + return rounded_arshift(in, bits[0]).as(in.eltype()); } }; @@ -637,9 +640,10 @@ class TruncV : public ShiftKernel { ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const override { + const Sizes& bits) const override { if (isOwner(ctx, in.eltype())) { - return rounded_arshift(in, bits).as(in.eltype()); + SPU_ENFORCE(bits.size() == 1, "truncation bits should be splat"); + return rounded_arshift(in, bits[0]).as(in.eltype()); } else { return in; } @@ -701,7 +705,7 @@ class GenInvPermP : public GenInvPermKernel { auto numel = in.numel(); - DISPATCH_ALL_FIELDS(field, "gen_inv_perm_p", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using T = std::make_signed_t; std::vector perm(numel); std::iota(perm.begin(), perm.end(), 0); @@ -735,7 +739,7 @@ class GenInvPermV : public GenInvPermKernel { auto numel = in.numel(); const auto field = in.eltype().as()->field(); - DISPATCH_ALL_FIELDS(field, "gen_inv_perm_v", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using T = std::make_signed_t; std::vector perm(numel); std::iota(perm.begin(), perm.end(), 0); @@ -769,7 +773,7 @@ class InvPermPP : public PermKernel { SPU_ENFORCE_EQ(x.eltype(), y.eltype()); NdArrayRef z(x.eltype(), x.shape()); const auto field = x.eltype().as()->field(); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using T = std::make_signed_t; NdArrayView _x(x); NdArrayView _y(y); @@ -795,7 +799,7 @@ class InvPermVV : public PermKernel { if (isOwner(ctx, x.eltype())) { NdArrayRef z(x.eltype(), x.shape()); const auto field = x.eltype().as()->field(); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using T = std::make_signed_t; NdArrayView _x(x); NdArrayView _y(y); @@ -823,7 +827,7 @@ class PermPP : public PermKernel { SPU_ENFORCE_EQ(x.eltype(), y.eltype()); NdArrayRef z(x.eltype(), x.shape()); const auto field = x.eltype().as()->field(); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using T = std::make_signed_t; NdArrayView _x(x); NdArrayView _y(y); @@ -849,7 +853,7 @@ class PermVV : public PermKernel { if (isOwner(ctx, x.eltype())) { NdArrayRef z(x.eltype(), x.shape()); const auto field = x.eltype().as()->field(); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using T = std::make_signed_t; NdArrayView _x(x); NdArrayView _y(y); @@ -878,7 +882,7 @@ class MergeKeysP : public MergeKeysKernel { NdArrayRef out(inputs[0].eltype(), inputs[0].shape()); const auto field = inputs[0].eltype().as()->field(); const auto numel = inputs[0].numel(); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using T = std::make_signed_t; NdArrayView _out(out); _out[0] = 0; @@ -917,7 +921,7 @@ class MergeKeysV : public MergeKeysKernel { NdArrayRef out(inputs[0].eltype(), inputs[0].shape()); const auto field = inputs[0].eltype().as()->field(); const auto numel = inputs[0].numel(); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using T = std::make_signed_t; NdArrayView _out(out); _out[0] = 0; diff --git a/libspu/mpc/kernel.cc b/libspu/mpc/kernel.cc index 90d51818..56b5c1bb 100644 --- a/libspu/mpc/kernel.cc +++ b/libspu/mpc/kernel.cc @@ -43,7 +43,11 @@ void RevealToKernel::evaluate(KernelEvalContext* ctx) const { void ShiftKernel::evaluate(KernelEvalContext* ctx) const { const auto& in = ctx->getParam(0); - size_t bits = ctx->getParam(1); + const auto& bits = ctx->getParam(1); + + SPU_ENFORCE( + bits.size() == 1 || in.numel() == static_cast(bits.size()), + "numel mismatch {} {}", in.numel(), bits.size()); auto res = proc(ctx, UnwrapValue(in), bits); diff --git a/libspu/mpc/kernel.h b/libspu/mpc/kernel.h index 12383391..d75b3426 100644 --- a/libspu/mpc/kernel.h +++ b/libspu/mpc/kernel.h @@ -42,7 +42,7 @@ class ShiftKernel : public Kernel { public: void evaluate(KernelEvalContext* ctx) const override; virtual NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const = 0; + const Sizes& bits) const = 0; }; class BinaryKernel : public Kernel { diff --git a/libspu/mpc/ref2k/ref2k.cc b/libspu/mpc/ref2k/ref2k.cc index 1d76ea95..cc75e8a7 100644 --- a/libspu/mpc/ref2k/ref2k.cc +++ b/libspu/mpc/ref2k/ref2k.cc @@ -165,7 +165,7 @@ class Ref2kV2S : public UnaryKernel { int64_t numel = in.numel(); - DISPATCH_ALL_FIELDS(field, "v2s", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { std::vector _send(numel); NdArrayView _in(in); @@ -196,7 +196,7 @@ class Ref2kRandS : public RandKernel { const auto field = ctx->getState()->getDefaultField(); return ring_rshift( - state->genPubl(field, shape).as(makeType(field)), 2); + state->genPubl(field, shape).as(makeType(field)), {2}); } }; @@ -368,7 +368,7 @@ class Ref2kLShiftS : public ShiftKernel { ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const override { + const Sizes& bits) const override { return ring_lshift(in, bits).as(in.eltype()); } }; @@ -382,7 +382,7 @@ class Ref2kRShiftS : public ShiftKernel { ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const override { + const Sizes& bits) const override { return ring_rshift(in, bits).as(in.eltype()); } }; @@ -414,7 +414,7 @@ class Ref2kARShiftS : public ShiftKernel { ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const override { + const Sizes& bits) const override { return ring_arshift(in, bits).as(in.eltype()); } }; @@ -441,8 +441,8 @@ class Ref2kTruncS : public TruncAKernel { // https://stackoverflow.com/questions/14008330/how-do-you-multiply-two-fixed-point-numbers // Under certain pattern, like sum(mul(A, B)), error can accumulate in a // fairly significant way - auto v1 = ring_arshift(in, bits); - auto v2 = ring_arshift(in, bits - 1); + auto v1 = ring_arshift(in, {static_cast(bits)}); + auto v2 = ring_arshift(in, {static_cast(bits - 1)}); ring_and_(v2, ring_ones(in.eltype().as()->field(), in.shape())); ring_add_(v1, v2); return v1; @@ -458,7 +458,8 @@ class Ref2kMsbS : public UnaryKernel { ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in) const override { - return ring_rshift(in, in.elsize() * 8 - 1).as(in.eltype()); + return ring_rshift(in, {static_cast(in.elsize() * 8 - 1)}) + .as(in.eltype()); } }; diff --git a/libspu/mpc/securenn/arithmetic.cc b/libspu/mpc/securenn/arithmetic.cc index 242b8e75..365637d7 100644 --- a/libspu/mpc/securenn/arithmetic.cc +++ b/libspu/mpc/securenn/arithmetic.cc @@ -19,8 +19,6 @@ #include #include "libspu/core/type_util.h" -#include "libspu/core/vectorize.h" -#include "libspu/mpc/ab_api.h" #include "libspu/mpc/common/communicator.h" #include "libspu/mpc/common/prg_state.h" #include "libspu/mpc/common/pv2k.h" @@ -37,7 +35,7 @@ NdArrayRef A2V::proc(KernelEvalContext* ctx, const NdArrayRef& in, auto numel = in.numel(); - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { std::vector share(numel); NdArrayView _in(in); pforeach(0, numel, [&](int64_t idx) { share[idx] = _in[idx]; }); @@ -109,7 +107,7 @@ NdArrayRef RandA::proc(KernelEvalContext* ctx, const Shape& shape) const { // - https://eprint.iacr.org/2019/599.pdf // It's safer to keep the number within [-2**(k-2), 2**(k-2)) for comparison // operations. - return ring_rshift(prg_state->genPriv(field, shape), 2) + return ring_rshift(prg_state->genPriv(field, shape), {2}) .as(makeType(field)); } @@ -200,10 +198,7 @@ NdArrayRef MatMulAP::proc(KernelEvalContext* ctx, const NdArrayRef& x, } NdArrayRef LShiftA::proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const { - const auto field = in.eltype().as()->field(); - bits %= SizeOf(field) * 8; - + const Sizes& bits) const { return ring_lshift(in, bits).as(in.eltype()); } @@ -216,10 +211,10 @@ NdArrayRef TruncAPr::proc(KernelEvalContext* ctx, const NdArrayRef& in, auto rank = comm->getRank(); const auto numel = in.numel(); const auto field = in.eltype().as()->field(); - const size_t k = SizeOf(field) * 8; + const int k = SizeOf(field) * 8; NdArrayRef out(in.eltype(), in.shape()); - DISPATCH_ALL_FIELDS(field, "securenn.truncpr", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using U = ring2k_t; auto r = prg_state->genPriv(field, in.shape()); @@ -232,9 +227,11 @@ NdArrayRef TruncAPr::proc(KernelEvalContext* ctx, const NdArrayRef& in, auto rb_recon = comm->reduce(ReduceOp::ADD, rb, 2, "rb"); if (rank == 2) { - auto adjust1 = - ring_sub(ring_rshift(ring_lshift(r_recon, 1), bits + 1), rc_recon); - auto adjust2 = ring_sub(ring_rshift(r_recon, k - 1), rb_recon); + auto adjust1 = ring_sub(ring_rshift(ring_lshift(r_recon, {1}), + {static_cast(bits + 1)}), + rc_recon); + auto adjust2 = ring_sub( + ring_rshift(r_recon, {static_cast(k - 1)}), rb_recon); comm->sendAsync(0, adjust1, "adjust1"); comm->sendAsync(0, adjust2, "adjust2"); } @@ -362,7 +359,6 @@ NdArrayRef MulAA::proc(KernelEvalContext* ctx, const NdArrayRef& x, b = prg_state ->genPrssPair(field, x.shape(), PrgState::GenPrssCtrl::Second) .second; - prg_state->genPrssPair(field, x.shape(), PrgState::GenPrssCtrl::None); c = comm->recv(2, ty, "c"); c = c.reshape(x.shape()); } @@ -527,7 +523,6 @@ NdArrayRef MatMulAA::proc(KernelEvalContext* ctx, const NdArrayRef& x, .second; b = prg_state->genPrssPair(field, shape2, PrgState::GenPrssCtrl::Second) .second; - prg_state->genPrssPair(field, shape3, PrgState::GenPrssCtrl::None); c = comm->recv(2, ty, "c"); c = c.reshape(shape3); @@ -594,7 +589,7 @@ NdArrayRef ShareConvert::proc(KernelEvalContext* ctx, const auto kComm = a.elsize() * size; comm->addCommStatsManually(4, 4 * log_p * kComm + 6 * kComm); - DISPATCH_ALL_FIELDS(field, "securenn.sc", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using U = ring2k_t; const U L_1 = (U)(~0); // 2^k - 1 // P0 and P1 add the share of zero @@ -902,7 +897,7 @@ NdArrayRef Msb::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { const auto kComm = in.elsize() * size; comm->addCommStatsManually(5, 13 * kComm + 4 * kComm * log_p); - DISPATCH_ALL_FIELDS(field, "securenn.msb", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using U = ring2k_t; const U L_1 = (U)(~0); @@ -1050,7 +1045,6 @@ NdArrayRef Msb::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { prg_state ->genPrssPair(field, in.shape(), PrgState::GenPrssCtrl::Second) .second; - prg_state->genPrssPair(field, in.shape(), PrgState::GenPrssCtrl::None); beaver_c = comm->recv(2, ty, "beaver_c"); beaver_c = beaver_c.reshape(in.shape()); } @@ -1241,7 +1235,7 @@ NdArrayRef Msb_opt::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { const auto kComm = in.elsize() * size; comm->addCommStatsManually(5, 9 * kComm + 3 * kComm * log_p); - DISPATCH_ALL_FIELDS(field, "securenn.msb", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using U = ring2k_t; const U L_1 = (U)(~0); @@ -1395,7 +1389,6 @@ NdArrayRef Msb_opt::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { prg_state ->genPrssPair(field, in.shape(), PrgState::GenPrssCtrl::Second) .second; - prg_state->genPrssPair(field, in.shape(), PrgState::GenPrssCtrl::None); beaver_c = comm->recv(2, ty, "beaver_c"); beaver_c = beaver_c.reshape(in.shape()); } diff --git a/libspu/mpc/securenn/arithmetic.h b/libspu/mpc/securenn/arithmetic.h index 599bf14a..e0829464 100644 --- a/libspu/mpc/securenn/arithmetic.h +++ b/libspu/mpc/securenn/arithmetic.h @@ -172,7 +172,7 @@ class LShiftA : public ShiftKernel { ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const override; + const Sizes& bits) const override; }; // Refer to: diff --git a/libspu/mpc/securenn/boolean.cc b/libspu/mpc/securenn/boolean.cc index 7d08673e..9598b5c2 100644 --- a/libspu/mpc/securenn/boolean.cc +++ b/libspu/mpc/securenn/boolean.cc @@ -31,7 +31,7 @@ namespace { size_t getNumBits(const NdArrayRef& in) { if (in.eltype().isa()) { const auto field = in.eltype().as()->field(); - return DISPATCH_ALL_FIELDS(field, "_", + return DISPATCH_ALL_FIELDS(field, [&]() { return maxBitWidth(in); }); } else if (in.eltype().isa()) { return in.eltype().as()->nbits(); @@ -120,7 +120,7 @@ NdArrayRef AndBP::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, const size_t out_nbits = std::min(getNumBits(lhs), getNumBits(rhs)); NdArrayRef out(makeType(field, out_nbits), lhs.shape()); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _lhs(lhs); NdArrayView _rhs(rhs); NdArrayView _out(out); @@ -146,12 +146,12 @@ NdArrayRef AndBB::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, const int64_t numel = lhs.numel(); NdArrayRef out(makeType(field, out_nbits), lhs.shape()); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using T = ring2k_t; NdArrayView _lhs(lhs); NdArrayView _rhs(rhs); - DISPATCH_UINT_PT_TYPES(backtype, "_", [&]() { + DISPATCH_UINT_PT_TYPES(backtype, [&]() { using V = ScalarT; int64_t numBytes = numel * SizeOf(backtype); @@ -197,7 +197,6 @@ NdArrayRef AndBB::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, b = prg_state ->genPrssPair(field, {numField}, PrgState::GenPrssCtrl::Second) .second; - prg_state->genPrssPair(field, {numField}, PrgState::GenPrssCtrl::None); c = comm->recv(2, ty, "c"); c = c.reshape({numField}); } @@ -257,32 +256,32 @@ NdArrayRef XorBB::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, } NdArrayRef LShiftB::proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t shift) const { + const Sizes& shift) const { const auto field = in.eltype().as()->field(); - shift %= SizeOf(field) * 8; - size_t out_nbits = in.eltype().as()->nbits() + shift; - out_nbits = std::clamp(out_nbits, static_cast(0), SizeOf(field) * 8); + int64_t out_nbits = in.eltype().as()->nbits() + + *std::max_element(shift.begin(), shift.end()); + out_nbits = + std::clamp(out_nbits, 0L, static_cast(SizeOf(field) * 8)); return makeBShare(ring_lshift(in, shift), field, out_nbits); } NdArrayRef RShiftB::proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t shift) const { + const Sizes& shift) const { const auto field = in.eltype().as()->field(); - shift %= SizeOf(field) * 8; - size_t nbits = in.eltype().as()->nbits(); - size_t out_nbits = nbits - std::min(nbits, shift); - SPU_ENFORCE(nbits <= SizeOf(field) * 8); + int64_t nbits = in.eltype().as()->nbits(); + int64_t out_nbits = + nbits - std::min(nbits, *std::min_element(shift.begin(), shift.end())); + SPU_ENFORCE(nbits <= static_cast(SizeOf(field) * 8)); return makeBShare(ring_rshift(in, shift), field, out_nbits); } NdArrayRef ARShiftB::proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t shift) const { + const Sizes& shift) const { const auto field = in.eltype().as()->field(); - shift %= SizeOf(field) * 8; // arithmetic right shift expects to work on ring, or the behaviour is // undefined. @@ -310,7 +309,7 @@ NdArrayRef BitIntlB::proc(KernelEvalContext* ctx, const NdArrayRef& in, NdArrayRef out(in.eltype(), in.shape()); auto numel = in.numel(); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _in(in); NdArrayView _out(out); @@ -331,7 +330,7 @@ NdArrayRef BitDeintlB::proc(KernelEvalContext* ctx, const NdArrayRef& in, NdArrayRef out(in.eltype(), in.shape()); auto numel = in.numel(); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _in(in); NdArrayView _out(out); diff --git a/libspu/mpc/securenn/boolean.h b/libspu/mpc/securenn/boolean.h index 35433c0b..ccaaa90b 100644 --- a/libspu/mpc/securenn/boolean.h +++ b/libspu/mpc/securenn/boolean.h @@ -118,7 +118,7 @@ class LShiftB : public ShiftKernel { ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t shift) const override; + const Sizes& shift) const override; }; class RShiftB : public ShiftKernel { @@ -130,7 +130,7 @@ class RShiftB : public ShiftKernel { ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t shift) const override; + const Sizes& shift) const override; }; class ARShiftB : public ShiftKernel { @@ -142,7 +142,7 @@ class ARShiftB : public ShiftKernel { ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t shift) const override; + const Sizes& shift) const override; }; class BitrevB : public BitrevKernel { diff --git a/libspu/mpc/securenn/conversion.cc b/libspu/mpc/securenn/conversion.cc index 1fe99fe4..8a05d8a4 100644 --- a/libspu/mpc/securenn/conversion.cc +++ b/libspu/mpc/securenn/conversion.cc @@ -112,8 +112,12 @@ NdArrayRef B2A_Randbit::proc(KernelEvalContext* ctx, auto randbits = prg_state->genPriv(field, {numel * static_cast(nbits)}); // reconstruct ranbits - if (rank == 0) comm->sendAsync(2, randbits, "randbits0"); - if (rank == 1) comm->sendAsync(2, randbits, "randbits1"); + if (rank == 0) { + comm->sendAsync(2, randbits, "randbits0"); + } + if (rank == 1) { + comm->sendAsync(2, randbits, "randbits1"); + } if (rank == 2) { auto randbits0 = comm->recv(0, makeType(field), "randbits0"); randbits0 = randbits0.reshape(randbits.shape()); @@ -132,7 +136,7 @@ NdArrayRef B2A_Randbit::proc(KernelEvalContext* ctx, auto res = NdArrayRef(makeType(field), x.shape()); - DISPATCH_ALL_FIELDS(field, kBindName, [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using U = ring2k_t; NdArrayView _randbits(randbits); diff --git a/libspu/mpc/semi2k/arithmetic.cc b/libspu/mpc/semi2k/arithmetic.cc index 90591da6..017c4640 100644 --- a/libspu/mpc/semi2k/arithmetic.cc +++ b/libspu/mpc/semi2k/arithmetic.cc @@ -38,7 +38,7 @@ NdArrayRef RandA::proc(KernelEvalContext* ctx, const Shape& shape) const { // - https://eprint.iacr.org/2019/599.pdf // It's safer to keep the number within [-2**(k-2), 2**(k-2)) for comparison // operations. - return ring_rshift(prg_state->genPriv(field, shape), 2) + return ring_rshift(prg_state->genPriv(field, shape), {2}) .as(makeType(field)); } @@ -73,7 +73,7 @@ NdArrayRef A2V::proc(KernelEvalContext* ctx, const NdArrayRef& in, auto numel = in.numel(); - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { std::vector share(numel); NdArrayView _in(in); pforeach(0, numel, [&](int64_t idx) { share[idx] = _in[idx]; }); @@ -364,10 +364,7 @@ NdArrayRef MatMulAA::proc(KernelEvalContext* ctx, const NdArrayRef& x, } NdArrayRef LShiftA::proc(KernelEvalContext*, const NdArrayRef& in, - size_t bits) const { - const auto field = in.eltype().as()->field(); - bits %= SizeOf(field) * 8; - + const Sizes& bits) const { return ring_lshift(in, bits).as(in.eltype()); } @@ -381,7 +378,7 @@ NdArrayRef TruncA::proc(KernelEvalContext* ctx, const NdArrayRef& x, if (comm->getWorldSize() == 2) { // SecureML, local truncation. // Ref: Theorem 1. https://eprint.iacr.org/2017/396.pdf - return ring_arshift(x, bits).as(x.eltype()); + return ring_arshift(x, {static_cast(bits)}).as(x.eltype()); } else { // ABY3, truncation pair method. // Ref: Section 5.1.2 https://eprint.iacr.org/2018/403.pdf @@ -399,7 +396,7 @@ NdArrayRef TruncA::proc(KernelEvalContext* ctx, const NdArrayRef& x, auto x_r = comm->allReduce(ReduceOp::ADD, ring_sub(x, r), kBindName); auto res = rb; if (comm->getRank() == 0) { - ring_add_(res, ring_arshift(x_r, bits)); + ring_add_(res, ring_arshift(x_r, {static_cast(bits)})); } // res = [x-r] + [r], x which [*] is truncation operation. @@ -418,7 +415,7 @@ NdArrayRef TruncAPr::proc(KernelEvalContext* ctx, const NdArrayRef& in, NdArrayRef out(in.eltype(), in.shape()); - DISPATCH_ALL_FIELDS(field, "semi2k.truncpr", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using U = ring2k_t; auto [r, rc, rb] = beaver->TruncPr(field, numel, bits); SPU_ENFORCE(static_cast(r.size()) == numel * SizeOf(field)); diff --git a/libspu/mpc/semi2k/arithmetic.h b/libspu/mpc/semi2k/arithmetic.h index 221a9f44..8b9510f8 100644 --- a/libspu/mpc/semi2k/arithmetic.h +++ b/libspu/mpc/semi2k/arithmetic.h @@ -206,7 +206,7 @@ class LShiftA : public ShiftKernel { ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const override; + const Sizes& bits) const override; }; class TruncA : public TruncAKernel { diff --git a/libspu/mpc/semi2k/beaver/beaver_impl/beaver_test.cc b/libspu/mpc/semi2k/beaver/beaver_impl/beaver_test.cc index 636766ec..58dcdc80 100644 --- a/libspu/mpc/semi2k/beaver/beaver_impl/beaver_test.cc +++ b/libspu/mpc/semi2k/beaver/beaver_impl/beaver_test.cc @@ -14,7 +14,6 @@ #include -#include "fmt/format.h" #include "gtest/gtest.h" #include "yacl/crypto/key_utils.h" #include "yacl/link/algorithm/barrier.h" @@ -187,7 +186,7 @@ TEST_P(BeaverTest, Mul_large) { auto open = open_buffer(triples, kField, std::vector(3, {kNumel}), kWorldSize, true); - DISPATCH_ALL_FIELDS(kField, "_", [&]() { + DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _a(open[0]); NdArrayView _b(open[1]); NdArrayView _c(open[2]); @@ -214,7 +213,7 @@ TEST_P(BeaverTest, Mul_large) { auto open = open_buffer(triples, kField, std::vector(3, {kNumel}), kWorldSize, true); - DISPATCH_ALL_FIELDS(kField, "_", [&]() { + DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _a(open[0]); NdArrayView _cache_a(x_cache); NdArrayView _b(open[1]); @@ -240,7 +239,7 @@ TEST_P(BeaverTest, Mul_large) { auto open = open_buffer(triples, kField, std::vector(3, {kNumel}), kWorldSize, true); - DISPATCH_ALL_FIELDS(kField, "_", [&]() { + DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _a(open[0]); NdArrayView _b(open[1]); NdArrayView _cache_b(y_cache); @@ -267,7 +266,7 @@ TEST_P(BeaverTest, Mul_large) { auto open = open_buffer(triples, kField, std::vector(3, {kNumel}), kWorldSize, true); - DISPATCH_ALL_FIELDS(kField, "_", [&]() { + DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _a(open[0]); NdArrayView _b(open[1]); NdArrayView _cache_a(x_cache); @@ -297,7 +296,7 @@ TEST_P(BeaverTest, Mul_large) { auto open = open_buffer(triples, kField, std::vector(3, {kNumel}), kWorldSize, true); - DISPATCH_ALL_FIELDS(kField, "_", [&]() { + DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _a(open[0]); NdArrayView _b(open[1]); NdArrayView _cache_a(x_cache); @@ -342,7 +341,7 @@ TEST_P(BeaverTest, Mul) { auto open = open_buffer(triples, kField, std::vector(3, {kNumel}), kWorldSize, true); - DISPATCH_ALL_FIELDS(kField, "_", [&]() { + DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _a(open[0]); NdArrayView _b(open[1]); NdArrayView _c(open[2]); @@ -369,7 +368,7 @@ TEST_P(BeaverTest, Mul) { auto open = open_buffer(triples, kField, std::vector(3, {kNumel}), kWorldSize, true); - DISPATCH_ALL_FIELDS(kField, "_", [&]() { + DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _a(open[0]); NdArrayView _cache_a(x_cache); NdArrayView _b(open[1]); @@ -395,7 +394,7 @@ TEST_P(BeaverTest, Mul) { auto open = open_buffer(triples, kField, std::vector(3, {kNumel}), kWorldSize, true); - DISPATCH_ALL_FIELDS(kField, "_", [&]() { + DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _a(open[0]); NdArrayView _b(open[1]); NdArrayView _cache_b(y_cache); @@ -422,7 +421,7 @@ TEST_P(BeaverTest, Mul) { auto open = open_buffer(triples, kField, std::vector(3, {kNumel}), kWorldSize, true); - DISPATCH_ALL_FIELDS(kField, "_", [&]() { + DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _a(open[0]); NdArrayView _b(open[1]); NdArrayView _cache_a(x_cache); @@ -452,7 +451,7 @@ TEST_P(BeaverTest, Mul) { auto open = open_buffer(triples, kField, std::vector(3, {kNumel}), kWorldSize, true); - DISPATCH_ALL_FIELDS(kField, "_", [&]() { + DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _a(open[0]); NdArrayView _b(open[1]); NdArrayView _cache_a(x_cache); @@ -539,7 +538,7 @@ TEST_P(BeaverTest, Dot) { kWorldSize, true); auto res = ring_mmul(open[0], open[1]); - DISPATCH_ALL_FIELDS(kField, "_", [&]() { + DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _r(res); NdArrayView _c(open[2]); for (auto idx = 0; idx < res.numel(); idx++) { @@ -565,7 +564,7 @@ TEST_P(BeaverTest, Dot) { kWorldSize, true); auto res = ring_mmul(x_cache, open[1]); - DISPATCH_ALL_FIELDS(kField, "_", [&]() { + DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _a(open[0]); NdArrayView _cache_a(x_cache); NdArrayView _r(res); @@ -592,7 +591,7 @@ TEST_P(BeaverTest, Dot) { kWorldSize, true); auto res = ring_mmul(open[0], y_cache); - DISPATCH_ALL_FIELDS(kField, "_", [&]() { + DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _b(open[1]); NdArrayView _cache_b(y_cache); NdArrayView _r(res); @@ -620,7 +619,7 @@ TEST_P(BeaverTest, Dot) { kWorldSize, true); auto res = ring_mmul(x_cache, y_cache); - DISPATCH_ALL_FIELDS(kField, "_", [&]() { + DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _a(open[0]); NdArrayView _cache_a(x_cache); NdArrayView _b(open[1]); @@ -651,7 +650,7 @@ TEST_P(BeaverTest, Dot) { kWorldSize, true); auto res = ring_mmul(x_cache, y_cache); - DISPATCH_ALL_FIELDS(kField, "_", [&]() { + DISPATCH_ALL_FIELDS(kField, [&]() { auto transpose_a = open[0].transpose(); NdArrayView _a(transpose_a); NdArrayView _cache_a(y_cache); @@ -684,7 +683,7 @@ TEST_P(BeaverTest, Dot) { auto open = open_buffer(triples, kField, std::vector(3, {M * K}), kWorldSize, true); - DISPATCH_ALL_FIELDS(kField, "_", [&]() { + DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _a(open[0]); NdArrayView _cache_a(x_cache); NdArrayView _b(open[1]); @@ -725,7 +724,7 @@ TEST_P(BeaverTest, Dot_large) { open_buffer(triples, kField, {{M, K}, {K, N}, {M, N}}, kWorldSize, true); auto res = ring_mmul(open[0], open[1]); - DISPATCH_ALL_FIELDS(kField, "_", [&]() { + DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _r(res); NdArrayView _c(open[2]); for (auto idx = 0; idx < res.numel(); idx++) { @@ -741,7 +740,7 @@ TEST_P(BeaverTest, Trunc) { const FieldType kField = std::get<2>(GetParam()); const size_t adjust_rank = std::get<4>(GetParam()); const int64_t kNumel = 7; - const size_t kBits = 5; + const int64_t kBits = 5; std::vector pairs; pairs.resize(kWorldSize); @@ -756,7 +755,7 @@ TEST_P(BeaverTest, Trunc) { EXPECT_EQ(pairs.size(), kWorldSize); auto open = open_buffer(pairs, kField, {{kNumel}, {kNumel}}, kWorldSize, true); - EXPECT_TRUE(ring_all_equal(ring_arshift(open[0], kBits), open[1], 0)); + EXPECT_TRUE(ring_all_equal(ring_arshift(open[0], {kBits}), open[1], 0)); } TEST_P(BeaverTest, TruncPr) { @@ -782,7 +781,7 @@ TEST_P(BeaverTest, TruncPr) { auto open = open_buffer(rets, kField, std::vector(3, {kNumel}), kWorldSize, true); - DISPATCH_ALL_FIELDS(kField, "semi2k.truncpr.ut", [&]() { + DISPATCH_ALL_FIELDS(kField, [&]() { using T = ring2k_t; auto sum_r_iter = open[0].begin(); auto sum_rc_iter = open[1].begin(); @@ -821,7 +820,7 @@ TEST_P(BeaverTest, Randbit) { EXPECT_EQ(shares.size(), kWorldSize); auto open = open_buffer(shares, kField, {{kNumel}}, kWorldSize, true); - DISPATCH_ALL_FIELDS(kField, "_", [&]() { + DISPATCH_ALL_FIELDS(kField, [&]() { using scalar_t = typename Ring2kTrait<_kField>::scalar_t; auto x = xt_adapt(open[0]); EXPECT_TRUE(xt::all(x <= xt::ones_like(x))); diff --git a/libspu/mpc/semi2k/beaver/beaver_impl/beaver_tfp.cc b/libspu/mpc/semi2k/beaver/beaver_impl/beaver_tfp.cc index 574d892a..d3d09401 100644 --- a/libspu/mpc/semi2k/beaver/beaver_impl/beaver_tfp.cc +++ b/libspu/mpc/semi2k/beaver/beaver_impl/beaver_tfp.cc @@ -48,7 +48,7 @@ void FillReplayDesc(Beaver::ReplayDesc* desc, FieldType field, int64_t size, } // namespace BeaverTfpUnsafe::BeaverTfpUnsafe(std::shared_ptr lctx) - : lctx_(std::move(std::move(lctx))), + : lctx_(std::move(lctx)), seed_(yacl::crypto::SecureRandSeed()), counter_(0) { auto buf = yacl::SerializeUint128(seed_); diff --git a/libspu/mpc/semi2k/beaver/beaver_impl/beaver_ttp.cc b/libspu/mpc/semi2k/beaver/beaver_impl/beaver_ttp.cc index 4f3094c3..cca2c571 100644 --- a/libspu/mpc/semi2k/beaver/beaver_impl/beaver_ttp.cc +++ b/libspu/mpc/semi2k/beaver/beaver_impl/beaver_ttp.cc @@ -152,7 +152,7 @@ std::vector RpcCall(brpc::Channel& channel, AdjustRequest req, } // namespace BeaverTtp::BeaverTtp(std::shared_ptr lctx, Options ops) - : lctx_(std::move(std::move(lctx))), + : lctx_(std::move(lctx)), seed_(yacl::crypto::SecureRandSeed()), counter_(0), options_(std::move(ops)) { diff --git a/libspu/mpc/semi2k/beaver/beaver_impl/trusted_party/trusted_party.cc b/libspu/mpc/semi2k/beaver/beaver_impl/trusted_party/trusted_party.cc index c16c6af3..80ef18e7 100644 --- a/libspu/mpc/semi2k/beaver/beaver_impl/trusted_party/trusted_party.cc +++ b/libspu/mpc/semi2k/beaver/beaver_impl/trusted_party/trusted_party.cc @@ -115,7 +115,7 @@ NdArrayRef TrustedParty::adjustTrunc(absl::Span ops, auto rs = reconstruct(RecOp::ADD, ops); // adjust = (rs[0] >> bits) - rs[1]; - return ring_sub(ring_arshift(rs[0], bits), rs[1]); + return ring_sub(ring_arshift(rs[0], {static_cast(bits)}), rs[1]); } std::pair TrustedParty::adjustTruncPr( @@ -127,11 +127,14 @@ std::pair TrustedParty::adjustTruncPr( auto rs = reconstruct(RecOp::ADD, ops); // adjust1 = ((rs[0] << 1) >> (bits + 1)) - rs[1]; - auto adjust1 = ring_sub(ring_rshift(ring_lshift(rs[0], 1), bits + 1), rs[1]); + auto adjust1 = ring_sub( + ring_rshift(ring_lshift(rs[0], {1}), {static_cast(bits + 1)}), + rs[1]); // adjust2 = (rs[0] >> (k - 1)) - rs[2]; const size_t k = SizeOf(ops[0].desc.field) * 8; - auto adjust2 = ring_sub(ring_rshift(rs[0], k - 1), rs[2]); + auto adjust2 = + ring_sub(ring_rshift(rs[0], {static_cast(k - 1)}), rs[2]); return {adjust1, adjust2}; } diff --git a/libspu/mpc/semi2k/boolean.cc b/libspu/mpc/semi2k/boolean.cc index 4dae73a1..3fa9bf4a 100644 --- a/libspu/mpc/semi2k/boolean.cc +++ b/libspu/mpc/semi2k/boolean.cc @@ -30,7 +30,7 @@ namespace { size_t getNumBits(const NdArrayRef& in) { if (in.eltype().isa()) { const auto field = in.eltype().as()->field(); - return DISPATCH_ALL_FIELDS(field, "_", + return DISPATCH_ALL_FIELDS(field, [&]() { return maxBitWidth(in); }); } else if (in.eltype().isa()) { return in.eltype().as()->nbits(); @@ -119,7 +119,7 @@ NdArrayRef AndBP::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, const size_t out_nbits = std::min(getNumBits(lhs), getNumBits(rhs)); NdArrayRef out(makeType(field, out_nbits), lhs.shape()); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _lhs(lhs); NdArrayView _rhs(rhs); NdArrayView _out(out); @@ -144,12 +144,12 @@ NdArrayRef AndBB::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, // semi2k always use the same storage type. NdArrayRef out(makeType(field, out_nbits), lhs.shape()); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using T = ring2k_t; NdArrayView _lhs(lhs); NdArrayView _rhs(rhs); - DISPATCH_UINT_PT_TYPES(backtype, "_", [&]() { + DISPATCH_UINT_PT_TYPES(backtype, [&]() { using V = ScalarT; // TODO: redefine beaver interface, generate variadic beaver and bits. @@ -215,32 +215,31 @@ NdArrayRef XorBB::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, } NdArrayRef LShiftB::proc(KernelEvalContext*, const NdArrayRef& in, - size_t shift) const { + const Sizes& shift) const { const auto field = in.eltype().as()->field(); - shift %= SizeOf(field) * 8; - size_t out_nbits = in.eltype().as()->nbits() + shift; + size_t out_nbits = in.eltype().as()->nbits() + + *std::max_element(shift.begin(), shift.end()); out_nbits = std::clamp(out_nbits, static_cast(0), SizeOf(field) * 8); return makeBShare(ring_lshift(in, shift), field, out_nbits); } NdArrayRef RShiftB::proc(KernelEvalContext*, const NdArrayRef& in, - size_t shift) const { + const Sizes& shift) const { const auto field = in.eltype().as()->field(); - shift %= SizeOf(field) * 8; - size_t nbits = in.eltype().as()->nbits(); - size_t out_nbits = nbits - std::min(nbits, shift); - SPU_ENFORCE(nbits <= SizeOf(field) * 8); + int64_t nbits = in.eltype().as()->nbits(); + int64_t out_nbits = + nbits - std::min(nbits, *std::min_element(shift.begin(), shift.end())); + SPU_ENFORCE(nbits <= static_cast(SizeOf(field) * 8)); return makeBShare(ring_rshift(in, shift), field, out_nbits); } NdArrayRef ARShiftB::proc(KernelEvalContext*, const NdArrayRef& in, - size_t shift) const { + const Sizes& shift) const { const auto field = in.eltype().as()->field(); - shift %= SizeOf(field) * 8; // arithmetic right shift expects to work on ring, or the behaviour is // undefined. @@ -268,7 +267,7 @@ NdArrayRef BitIntlB::proc(KernelEvalContext*, const NdArrayRef& in, NdArrayRef out(in.eltype(), in.shape()); auto numel = in.numel(); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _in(in); NdArrayView _out(out); @@ -289,7 +288,7 @@ NdArrayRef BitDeintlB::proc(KernelEvalContext*, const NdArrayRef& in, NdArrayRef out(in.eltype(), in.shape()); auto numel = in.numel(); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _in(in); NdArrayView _out(out); diff --git a/libspu/mpc/semi2k/boolean.h b/libspu/mpc/semi2k/boolean.h index 8c7e6835..8ca39710 100644 --- a/libspu/mpc/semi2k/boolean.h +++ b/libspu/mpc/semi2k/boolean.h @@ -118,7 +118,7 @@ class LShiftB : public ShiftKernel { ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t shift) const override; + const Sizes& shift) const override; }; class RShiftB : public ShiftKernel { @@ -130,7 +130,7 @@ class RShiftB : public ShiftKernel { ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t shift) const override; + const Sizes& shift) const override; }; class ARShiftB : public ShiftKernel { @@ -142,7 +142,7 @@ class ARShiftB : public ShiftKernel { ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t shift) const override; + const Sizes& shift) const override; }; class BitrevB : public BitrevKernel { diff --git a/libspu/mpc/semi2k/conversion.cc b/libspu/mpc/semi2k/conversion.cc index dce1857f..ae638ce4 100644 --- a/libspu/mpc/semi2k/conversion.cc +++ b/libspu/mpc/semi2k/conversion.cc @@ -135,7 +135,7 @@ NdArrayRef B2A_Randbit::proc(KernelEvalContext* ctx, rand_numel * SizeOf(field)); auto res = NdArrayRef(makeType(field), x.shape()); - DISPATCH_ALL_FIELDS(field, kBindName, [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using U = ring2k_t; absl::Span _randbits(randbits.data(), rand_numel); @@ -143,7 +143,7 @@ NdArrayRef B2A_Randbit::proc(KernelEvalContext* ctx, // algorithm begins. // Ref: III.D @ https://eprint.iacr.org/2019/599.pdf (SPDZ-2K primitives) - DISPATCH_UINT_PT_TYPES(backtype, "_", [&]() { + DISPATCH_UINT_PT_TYPES(backtype, [&]() { using V = ScalarT; std::vector x_xor_r(numel); @@ -229,13 +229,13 @@ std::vector B2A_Disassemble::proc(KernelEvalContext* ctx, for (int64_t idx = 0; idx < nbits; ++idx) { res.emplace_back(makeType(field), x.shape()); } - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using U = ring2k_t; absl::Span _randbits(randbits.data(), rand_numel); NdArrayView _x(x); - DISPATCH_UINT_PT_TYPES(backtype, "_", [&]() { + DISPATCH_UINT_PT_TYPES(backtype, [&]() { using V = ScalarT; std::vector x_xor_r(numel); @@ -305,7 +305,9 @@ NdArrayRef MsbA2B::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { // Compute the k'th bit. // (m^n)[k] ^ carry - auto msb = xor_bb(sctx, rshift_b(sctx, xor_bb(sctx, m, n), k), carry); + auto msb = xor_bb( + sctx, rshift_b(sctx, xor_bb(sctx, m, n), {static_cast(k)}), + carry); return UnwrapValue(msb); } @@ -327,7 +329,7 @@ NdArrayRef eqz(KernelEvalContext* ctx, const NdArrayRef& in) { // beaver samples r and deals [r]a and [r]b // receal c = a+r // check a == 0 <=> c == r - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using el_t = ring2k_t; auto [ra_buf, rb_buf] = beaver->Eqz(field, numel); @@ -355,11 +357,11 @@ NdArrayRef eqz(KernelEvalContext* ctx, const NdArrayRef& in) { // TODO: fix AND triple // in beaver->AND(field, shape), min FM32, need min 1byte to reduce comm NdArrayRef round_out = rb.as(makeType(field)); - size_t cur_bits = round_out.eltype().as()->nbits(); + int64_t cur_bits = round_out.eltype().as()->nbits(); while (cur_bits != 1) { cur_bits /= 2; - round_out = - wrap_and_bb(ctx->sctx(), round_out, ring_rshift(round_out, cur_bits)); + round_out = wrap_and_bb(ctx->sctx(), round_out, + ring_rshift(round_out, {cur_bits})); } // 1 bit info in lsb diff --git a/libspu/mpc/semi2k/permute.cc b/libspu/mpc/semi2k/permute.cc index 28278f8c..21787a42 100644 --- a/libspu/mpc/semi2k/permute.cc +++ b/libspu/mpc/semi2k/permute.cc @@ -82,7 +82,7 @@ NdArrayRef RandPermM::proc(KernelEvalContext* ctx, const Shape& shape) const { const auto perm_vector = genRandomPerm(out.numel(), _seed[0]); const auto field = out.eltype().as()->field(); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _out(out); pforeach(0, out.numel(), [&](int64_t idx) { _out[idx] = ring2k_t(perm_vector[idx]); }); diff --git a/libspu/mpc/spdz2k/abprotocol_spdz2k_test.cc b/libspu/mpc/spdz2k/abprotocol_spdz2k_test.cc index b49e34c2..8bed069f 100644 --- a/libspu/mpc/spdz2k/abprotocol_spdz2k_test.cc +++ b/libspu/mpc/spdz2k/abprotocol_spdz2k_test.cc @@ -242,7 +242,7 @@ TEST_P(ConversionTest, BitLT) { auto re = b2p(obj.get(), tmp); const auto field = p0.storage_type().as()->field(); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using U = std::make_unsigned::type; size_t numel = kShape.numel(); @@ -291,7 +291,7 @@ TEST_P(ConversionTest, BitLE) { auto re = b2p(obj.get(), tmp); const auto field = p0.storage_type().as()->field(); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using U = std::make_unsigned::type; size_t numel = kShape.numel(); NdArrayView p0_data(p0.data()); diff --git a/libspu/mpc/spdz2k/arithmetic.cc b/libspu/mpc/spdz2k/arithmetic.cc index 6e902a61..11eeb186 100644 --- a/libspu/mpc/spdz2k/arithmetic.cc +++ b/libspu/mpc/spdz2k/arithmetic.cc @@ -47,9 +47,9 @@ NdArrayRef CastRing(const NdArrayRef& in, FieldType out_field) { const auto in_field = in_ty->field(); auto out = ring_zeros(out_field, in.shape()); - return DISPATCH_ALL_FIELDS(in_field, "_", [&]() { + return DISPATCH_ALL_FIELDS(in_field, [&]() { NdArrayView _in(in); - return DISPATCH_ALL_FIELDS(out_field, "_", [&]() { + return DISPATCH_ALL_FIELDS(out_field, [&]() { NdArrayView _out(out); pforeach(0, in.numel(), [&](int64_t idx) { _out[idx] = static_cast(_in[idx]); @@ -106,7 +106,7 @@ NdArrayRef RandA::proc(KernelEvalContext* ctx, const Shape& shape) const { // - https://eprint.iacr.org/2019/599.pdf // It's safer to keep the number within [-2**(k-2), 2**(k-2)) for comparison // operations. - auto x = ring_rshift(prg_state->genPriv(field, shape), 2) + auto x = ring_rshift(prg_state->genPriv(field, shape), {2}) .as(makeType(field)); auto x_mac = beaver->AuthArrayRef(x, field, k, s); return makeAShare(x, x_mac, field); @@ -296,7 +296,7 @@ bool SingleCheck(KernelEvalContext* ctx, const NdArrayRef& in) { auto* comm = ctx->getState(); auto* beaver = ctx->getState()->beaver(); const auto key = ctx->getState()->key(); - const size_t k = ctx->getState()->k(); + const int64_t k = ctx->getState()->k(); const size_t s = ctx->getState()->s(); // 1. Generate a random, shared value [r] @@ -305,8 +305,8 @@ bool SingleCheck(KernelEvalContext* ctx, const NdArrayRef& in) { // 2. Locally construct [y] const auto& x = getValueShare(in); const auto& x_mac = getMacShare(in); - auto y = ring_add(x, ring_lshift(r, k)); - auto y_mac = ring_add(x_mac, ring_lshift(r_mac, k)); + auto y = ring_add(x, ring_lshift(r, {k})); + auto y_mac = ring_add(x_mac, ring_lshift(r_mac, {k})); // 3. Open the value auto plain_y = comm->allReduce(ReduceOp::ADD, y, kBindName); @@ -334,7 +334,7 @@ bool SingleCheck(KernelEvalContext* ctx, const NdArrayRef& in) { static NdArrayRef wrap_lshift_a(SPUContext* ctx, const NdArrayRef& x, size_t k) { - return UnwrapValue(lshift_a(ctx, WrapValue(x), k)); + return UnwrapValue(lshift_a(ctx, WrapValue(x), {static_cast(k)})); } static NdArrayRef wrap_add_aa(SPUContext* ctx, const NdArrayRef& x, @@ -579,9 +579,8 @@ NdArrayRef MatMulAA::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, } NdArrayRef LShiftA::proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const { + const Sizes& bits) const { const auto field = in.eltype().as()->field(); - bits %= SizeOf(field) * 8; // in const auto& x = getValueShare(in); @@ -617,7 +616,9 @@ NdArrayRef TruncA::proc(KernelEvalContext* ctx, const NdArrayRef& in, beaver->BatchOpen(ring_sub(x, r), ring_sub(x_mac, r_mac), k, s); SPU_ENFORCE(beaver->BatchMacCheck(x_r, check_mac, k, s)); size_t bit_len = SizeOf(field) * 8; - auto tr_x_r = ring_arshift(ring_lshift(x_r, bit_len - k), bit_len - k + bits); + auto tr_x_r = + ring_arshift(ring_lshift(x_r, {static_cast(bit_len - k)}), + {static_cast(bit_len - k + bits)}); ring_bitmask_(tr_x_r, 0, k); // res = [x-r] + [r], which [*] is truncation operation. diff --git a/libspu/mpc/spdz2k/arithmetic.h b/libspu/mpc/spdz2k/arithmetic.h index deb782d5..7edfe33b 100644 --- a/libspu/mpc/spdz2k/arithmetic.h +++ b/libspu/mpc/spdz2k/arithmetic.h @@ -180,7 +180,7 @@ class LShiftA : public ShiftKernel { ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const override; + const Sizes& bits) const override; }; // Refer to: diff --git a/libspu/mpc/spdz2k/beaver/beaver_test.cc b/libspu/mpc/spdz2k/beaver/beaver_test.cc index 0395ee21..c1508522 100644 --- a/libspu/mpc/spdz2k/beaver/beaver_test.cc +++ b/libspu/mpc/spdz2k/beaver/beaver_test.cc @@ -126,7 +126,7 @@ TEST_P(BeaverTest, AuthAnd) { EXPECT_TRUE(ring_all_equal(ring_mul(sum_c, sum_key), sum_c_mac)) << sum_c << sum_key << sum_c_mac; - DISPATCH_ALL_FIELDS(kField, "_", [&]() { + DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _valid_a(valid_a); NdArrayView _valid_b(valid_b); NdArrayView _valid_c(valid_c); @@ -314,7 +314,8 @@ TEST_P(BeaverTest, AuthTrunc) { const size_t bit_len = SizeOf(kField) * 8; auto trunc_sum_a = - ring_arshift(ring_lshift(sum_a, bit_len - k), bit_len - k + kBits); + ring_arshift(ring_lshift(sum_a, {static_cast(bit_len - k)}), + {static_cast(bit_len - k + kBits)}); ring_bitmask_(trunc_sum_a, 0, k); EXPECT_TRUE(ring_all_equal(trunc_sum_a, ring_bitmask(sum_b, 0, k))) @@ -386,7 +387,7 @@ TEST_P(BeaverTest, AuthDot) { << sum_c << sum_key << sum_c_mac; auto res = ring_mmul(sum_a, sum_b); - DISPATCH_ALL_FIELDS(kField, "_", [&]() { + DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _sum_a(sum_a); NdArrayView _sum_c(sum_c); NdArrayView _res(res); diff --git a/libspu/mpc/spdz2k/beaver/beaver_tfp.cc b/libspu/mpc/spdz2k/beaver/beaver_tfp.cc index 3c7a5ca7..3ab9fd0e 100644 --- a/libspu/mpc/spdz2k/beaver/beaver_tfp.cc +++ b/libspu/mpc/spdz2k/beaver/beaver_tfp.cc @@ -49,7 +49,7 @@ uint128_t BeaverTfpUnsafe::InitSpdzKey(FieldType field, size_t s) { const int64_t size = 1; auto a = prgCreateArray(field, {size}, seed_, &counter_, &desc); - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _a(a); if (comm_->getRank() == 0) { auto t = tp_.adjustSpdzKey(desc); @@ -260,8 +260,10 @@ std::pair BeaverTfpUnsafe::BatchOpen( // Open the low k_bits only // value = value + r_val * 2^k // mac = mac + r_mac * 2^k - auto masked_val = ring_add(value, ring_lshift(r_val, k)); - auto masked_mac = ring_add(mac, ring_lshift(r_mac, k)); + auto masked_val = + ring_add(value, ring_lshift(r_val, {static_cast(k)})); + auto masked_mac = + ring_add(mac, ring_lshift(r_mac, {static_cast(k)})); auto open_val = comm_->allReduce(ReduceOp::ADD, masked_val, kBindName); return {open_val, masked_mac}; diff --git a/libspu/mpc/spdz2k/beaver/beaver_tinyot.cc b/libspu/mpc/spdz2k/beaver/beaver_tinyot.cc index f0481944..ed7f2ae8 100644 --- a/libspu/mpc/spdz2k/beaver/beaver_tinyot.cc +++ b/libspu/mpc/spdz2k/beaver/beaver_tinyot.cc @@ -68,7 +68,7 @@ NdArrayRef ring_sqrt2k(const NdArrayRef& x, size_t bits = 0) { } auto ret = ring_zeros(field, x.shape()); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using U = std::make_unsigned::type; NdArrayView _ret(ret); NdArrayView _x(x); @@ -101,7 +101,7 @@ NdArrayRef ring_inv2k(const NdArrayRef& x, size_t bits = 0) { } auto ret = ring_zeros(field, x.shape()); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using U = std::make_unsigned::type; NdArrayView _ret(ret); NdArrayView _x(x); @@ -118,7 +118,7 @@ std::vector ring_cast_vector_boolean(const NdArrayRef& x) { const auto field = x.eltype().as()->field(); std::vector res(x.numel()); - DISPATCH_ALL_FIELDS(field, "RingOps", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _x(x); yacl::parallel_for(0, x.numel(), 4096, [&](size_t start, size_t end) { for (size_t i = start; i < end; i++) { @@ -192,7 +192,7 @@ uint128_t BeaverTinyOt::InitSpdzKey(FieldType, size_t s) { // - https://eprint.iacr.org/2018/482.pdf NdArrayRef BeaverTinyOt::AuthArrayRef(const NdArrayRef& x, FieldType field, size_t k, size_t s) { - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using T = ring2k_t; // 1. l_ = max(l, r + s, 2s) @@ -346,7 +346,7 @@ BeaverTinyOt::Triple_Pair BeaverTinyOt::AuthAnd(FieldType field, // Generate authorize bits in the form of B-Share NdArrayRef spdz_choices(makeType(field), {tinyot_num * 3 + sigma}); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using U = std::make_unsigned::type; auto _size = auth_abcr.choices.size(); NdArrayView _spdz_choices(spdz_choices); @@ -392,7 +392,7 @@ BeaverTinyOt::Triple_Pair BeaverTinyOt::AuthAnd(FieldType field, auto seed = GenSharedSeed(comm_); auto prg = yacl::crypto::Prg(seed); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using U = std::make_unsigned::type; NdArrayView _check_spdz_bit(check_spdz_bit); NdArrayView _check_spdz_mac(check_spdz_mac); @@ -546,7 +546,7 @@ BeaverTinyOt::Pair_Pair BeaverTinyOt::AuthTrunc(FieldType field, NdArrayRef tr_val(b_val.eltype(), shape); NdArrayRef tr_mac(b_val.eltype(), shape); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using PShrT = ring2k_t; NdArrayView _val(b_val); NdArrayView _mac(b_mac); @@ -642,7 +642,7 @@ BeaverTinyOt::Pair BeaverTinyOt::AuthRandBit(FieldType field, SPU_ENFORCE(ring_all_equal(ring_bitmask(square, 0, 1), ones)); auto root = ring_sqrt2k(square, k + 2); auto root_inv = ring_inv2k(root, k + 2); - auto root_inv_div2 = ring_rshift(root_inv, 1); + auto root_inv_div2 = ring_rshift(root_inv, {1}); auto d = ring_mul(root_inv_div2, y); auto d_mac = ring_mul(root_inv_div2, y_mac); @@ -751,17 +751,19 @@ std::pair BeaverTinyOt::BatchOpen( // Open the low k_bits only // value = value + r * 2^k // mac = mac + r_mac * 2^k - auto masked_val = ring_add(value, ring_lshift(r_val, k)); - auto masked_mac = ring_add(mac, ring_lshift(r_mac, k)); + auto masked_val = + ring_add(value, ring_lshift(r_val, {static_cast(k)})); + auto masked_mac = + ring_add(mac, ring_lshift(r_mac, {static_cast(k)})); - // Because we would use Maccheck to comfirm the open value. + // Because we would use Maccheck to confirm the open value. // Thus, we don't need commit them. auto open_val = comm_->allReduce(ReduceOp::ADD, masked_val, kBindName); return {open_val, masked_mac}; } void BeaverTinyOt::rotSend(FieldType field, NdArrayRef* q0, NdArrayRef* q1) { - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using T = ring2k_t; SPDLOG_DEBUG("rotSend start with numel {}", q0->numel()); @@ -784,7 +786,7 @@ void BeaverTinyOt::rotSend(FieldType field, NdArrayRef* q0, NdArrayRef* q1) { // todo: use dynamic_bitset instead of ArrayRef for `a` to improve performance void BeaverTinyOt::rotRecv(FieldType field, const NdArrayRef& a, NdArrayRef* s) { - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using T = ring2k_t; SPDLOG_DEBUG("rotRecv start with numel {}", a.numel()); @@ -814,7 +816,7 @@ void BeaverTinyOt::rotRecv(FieldType field, const NdArrayRef& a, // SPDZ2k: Efficient MPC mod 2k for Dishonest Majority // - https://eprint.iacr.org/2018/482.pdf NdArrayRef BeaverTinyOt::voleSend(FieldType field, const NdArrayRef& x) { - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using T = ring2k_t; SPU_ENFORCE(spdz2k_ot_primitives_ != nullptr); @@ -831,7 +833,7 @@ NdArrayRef BeaverTinyOt::voleSend(FieldType field, const NdArrayRef& x) { } NdArrayRef BeaverTinyOt::voleRecv(FieldType field, const NdArrayRef& alpha) { - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using T = ring2k_t; SPU_ENFORCE(spdz2k_ot_primitives_ != nullptr); @@ -915,7 +917,7 @@ BeaverTinyOt::Triple_Pair BeaverTinyOt::AuthMul(FieldType field, size_t s) { auto _size = shape.numel(); - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using T = ring2k_t; SPDLOG_DEBUG("AuthMul start..."); diff --git a/libspu/mpc/spdz2k/beaver/trusted_party.cc b/libspu/mpc/spdz2k/beaver/trusted_party.cc index 978f1eee..287d1f76 100644 --- a/libspu/mpc/spdz2k/beaver/trusted_party.cc +++ b/libspu/mpc/spdz2k/beaver/trusted_party.cc @@ -248,9 +248,10 @@ std::vector TrustedParty::adjustAuthTrunc( ring_add_(r0[0], ring_sub(rs[0], t_rs)); // r0[1] += (rs[0] >> bits) - rs[1]; - const size_t bit_len = SizeOf(field) * 8; + const int64_t bit_len = SizeOf(field) * 8; auto tr_rs0 = - ring_arshift(ring_lshift(rs[0], bit_len - k), bit_len - k + bits); + ring_arshift(ring_lshift(rs[0], {static_cast(bit_len - k)}), + {static_cast(bit_len - k + bits)}); ring_bitmask_(tr_rs0, 0, k); ring_add_(r0[1], ring_sub(tr_rs0, rs[1])); diff --git a/libspu/mpc/spdz2k/boolean.cc b/libspu/mpc/spdz2k/boolean.cc index 226de5ad..7f230a28 100644 --- a/libspu/mpc/spdz2k/boolean.cc +++ b/libspu/mpc/spdz2k/boolean.cc @@ -38,7 +38,7 @@ NdArrayRef P2Value(FieldType out_field, const NdArrayRef& in, size_t k, size_t new_nbits = 0) { const auto* in_ty = in.eltype().as(); const auto in_field = in_ty->field(); - return DISPATCH_ALL_FIELDS(in_field, "_", [&]() { + return DISPATCH_ALL_FIELDS(in_field, [&]() { using PShrT = ring2k_t; size_t valid_nbits = k; @@ -52,7 +52,7 @@ NdArrayRef P2Value(FieldType out_field, const NdArrayRef& in, size_t k, out_shape.back() *= valid_nbits; auto out = ring_zeros(out_field, out_shape); - return DISPATCH_ALL_FIELDS(out_field, "_", [&]() { + return DISPATCH_ALL_FIELDS(out_field, [&]() { using BShrT = ring2k_t; NdArrayView _in(in); NdArrayView _out(out); @@ -279,10 +279,10 @@ NdArrayRef B2P::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { // 2. Maccheck SPU_ENFORCE(beaver_ptr->BatchMacCheck(pub, mac, 1, s)); - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using BShrT = ring2k_t; auto& value = pub; - return DISPATCH_ALL_FIELDS(out_field, "_", [&]() { + return DISPATCH_ALL_FIELDS(out_field, [&]() { using PShrT = ring2k_t; NdArrayRef out(makeType(out_field), in.shape()); @@ -374,7 +374,7 @@ NdArrayRef BitrevB::proc(KernelEvalContext* ctx, const NdArrayRef& in, auto ret = x.clone(); auto ret_mac = x_mac.clone(); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _ret(ret); NdArrayView _ret_mac(ret_mac); NdArrayView _x(x); @@ -515,37 +515,48 @@ NdArrayRef AndBP::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, } NdArrayRef LShiftB::proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const { + const Sizes& bits) const { + SPU_ENFORCE(bits.size() == 1, "only support splat shift bits, but got {}", + bits); + const auto field = in.eltype().as()->field(); const auto k = ctx->getState()->k(); const size_t nbits = in.eltype().as()->nbits(); - size_t res_nbits = nbits + bits; + size_t bit = bits[0]; + size_t res_nbits = nbits + bit; - if (bits >= k) { + if (bit >= k) { res_nbits = 1; } else if (res_nbits > k) { res_nbits = k; } - auto [ret, ret_mac] = LShiftBImpl(in, bits, k); + auto [ret, ret_mac] = LShiftBImpl(in, bit, k); return makeBShare(ret, ret_mac, field, res_nbits); } NdArrayRef RShiftB::proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const { + const Sizes& bits) const { + SPU_ENFORCE(bits.size() == 1, "only support splat shift bits, but got {}", + bits); + + size_t bit = bits[0]; const auto field = in.eltype().as()->field(); const auto nbits = in.eltype().as()->nbits(); - size_t new_nbis = nbits > bits ? nbits - bits : 1; - auto [ret, ret_mac] = RShiftBImpl(in, bits); + size_t new_nbis = nbits > bit ? nbits - bit : 1; + auto [ret, ret_mac] = RShiftBImpl(in, bit); return makeBShare(ret, ret_mac, field, new_nbis); } static NdArrayRef wrap_rshift_b(SPUContext* ctx, const NdArrayRef& x, - size_t bits) { + const Sizes& bits) { return UnwrapValue(rshift_b(ctx, WrapValue(x), bits)); } NdArrayRef ARShiftB::proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const { + const Sizes& bits) const { + SPU_ENFORCE(bits.size() == 1, "only support splat shift bits, but got {}", + bits); + const auto field = in.eltype().as()->field(); const auto k = ctx->getState()->k(); const auto nbits = in.eltype().as()->nbits(); @@ -553,7 +564,7 @@ NdArrayRef ARShiftB::proc(KernelEvalContext* ctx, const NdArrayRef& in, if (nbits != k) { return wrap_rshift_b(ctx->sctx(), in, bits); } else { - auto [ret, ret_mac] = ARShiftBImpl(in, bits, k); + auto [ret, ret_mac] = ARShiftBImpl(in, bits[0], k); return makeBShare(ret, ret_mac, field, k); } } @@ -565,7 +576,7 @@ NdArrayRef BitIntlB::proc(KernelEvalContext* ctx, const NdArrayRef& in, const auto k = ctx->getState()->k(); NdArrayRef out(in.eltype(), in.shape()); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using T = ring2k_t; if (in.eltype().isa()) { @@ -612,7 +623,7 @@ NdArrayRef BitDeintlB::proc(KernelEvalContext* ctx, const NdArrayRef& in, const auto k = ctx->getState()->k(); NdArrayRef out(in.eltype(), in.shape()); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using T = ring2k_t; if (in.eltype().isa()) { diff --git a/libspu/mpc/spdz2k/boolean.h b/libspu/mpc/spdz2k/boolean.h index 6d2d9c9f..7f42cd36 100644 --- a/libspu/mpc/spdz2k/boolean.h +++ b/libspu/mpc/spdz2k/boolean.h @@ -146,7 +146,7 @@ class LShiftB : public ShiftKernel { ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const override; + const Sizes& bits) const override; }; class RShiftB : public ShiftKernel { @@ -158,7 +158,7 @@ class RShiftB : public ShiftKernel { ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const override; + const Sizes& bits) const override; }; class ARShiftB : public ShiftKernel { @@ -170,7 +170,7 @@ class ARShiftB : public ShiftKernel { ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, - size_t bits) const override; + const Sizes& bits) const override; }; class BitIntlB : public BitSplitKernel { diff --git a/libspu/mpc/spdz2k/conversion.cc b/libspu/mpc/spdz2k/conversion.cc index 8f8ded08..122d1c26 100644 --- a/libspu/mpc/spdz2k/conversion.cc +++ b/libspu/mpc/spdz2k/conversion.cc @@ -113,7 +113,7 @@ CircuitBasicBlock MakeSPDZBasicBlock(SPUContext* ctx) { cbb._and = [=](T const& x, T const& y) -> T { COMMUTATIVE_DISPATCH(and_pp, and_bp, and_bb); }; - cbb.lshift = [=](T const& x, size_t bits) -> T { + cbb.lshift = [=](T const& x, const Sizes& bits) -> T { if (_IsP(x)) { return lshift_p(ctx, x, bits); } else if (_IsB(x)) { @@ -121,7 +121,7 @@ CircuitBasicBlock MakeSPDZBasicBlock(SPUContext* ctx) { } SPU_THROW("unsupported op x={}", x); }; - cbb.rshift = [=](T const& x, size_t bits) -> T { + cbb.rshift = [=](T const& x, const Sizes& bits) -> T { if (_IsP(x)) { return rshift_p(ctx, x, bits); } else if (_IsB(x)) { @@ -191,7 +191,7 @@ NdArrayRef Bit2A::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { ring_bitmask_(c, 0, 1); // 5. [x] = c + [r] - 2 * c * [r] - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _c(c); NdArrayView _r(r); NdArrayView _r_mac(r_mac); @@ -292,7 +292,7 @@ NdArrayRef B2A::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { ring_bitmask_(c, 0, 1); // 4. [x] = c + [r] - 2 * c * [r] - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { NdArrayRef out(makeType(field, true), out_shape); NdArrayRef expand_out(makeType(field, true), _in.shape()); @@ -354,8 +354,8 @@ NdArrayRef MSB::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { // then set r = \sum r_i 2^{i} for (int64_t i = 0; i < k; ++i) { auto [_r_i, _r_i_mac] = beaver->AuthRandBit(field, in.shape(), k, s); - ring_add_(_r_val, ring_lshift(_r_i, i)); - ring_add_(_r_mac, ring_lshift(_r_i_mac, i)); + ring_add_(_r_val, ring_lshift(_r_i, {i})); + ring_add_(_r_mac, ring_lshift(_r_i_mac, {i})); // record r_i & r_i_mac _r_vec.emplace_back(std::move(_r_i)); _r_mac_vec.emplace_back(std::move(_r_i_mac)); @@ -381,8 +381,8 @@ NdArrayRef MSB::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { const auto ty = makeType(field); for (int64_t i = 0; i < k - 1; ++i) { - ring_add_(_ar, ring_lshift(_r_vec[i], i)); - ring_add_(_ar_mac, ring_lshift(_r_mac_vec[i], i)); + ring_add_(_ar, ring_lshift(_r_vec[i], {i})); + ring_add_(_ar_mac, ring_lshift(_r_mac_vec[i], {i})); auto at_r_i = makeAShare(_r_vec[i], _r_mac_vec[i], field); auto bt_r_i = wrap_a2bit(ctx->sctx(), at_r_i); @@ -415,8 +415,8 @@ NdArrayRef MSB::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { // d = a - a' auto _au = getValueShare(au); auto _au_mac = getMacShare(au); - auto _aa = ring_sub(ring_lshift(_au, k - 1), _ar); - auto _aa_mac = ring_sub(ring_lshift(_au_mac, k - 1), _ar_mac); + auto _aa = ring_sub(ring_lshift(_au, {k - 1}), _ar); + auto _aa_mac = ring_sub(ring_lshift(_au_mac, {k - 1}), _ar_mac); if (comm->getRank() == 0) { ring_add_(_aa, _c_open); } @@ -426,18 +426,18 @@ NdArrayRef MSB::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { // 7. let e = d + 2^{k-1} b, then open e auto [_b, _b_mac] = beaver->AuthRandBit(field, in.shape(), k, s); - auto _e = ring_add(_d, ring_lshift(_b, k - 1)); - auto _e_mac = ring_add(_d_mac, ring_lshift(_b_mac, k - 1)); + auto _e = ring_add(_d, ring_lshift(_b, {k - 1})); + auto _e_mac = ring_add(_d_mac, ring_lshift(_b_mac, {k - 1})); auto [e_open, e_zero_mac] = beaver->BatchOpen(_e, _e_mac, k, s); SPU_ENFORCE(beaver->BatchMacCheck(e_open, e_zero_mac, k, s)); // 8. e' be the most significant bit of e - auto _ee = ring_bitmask(ring_rshift(e_open, k - 1), 0, 1); + auto _ee = ring_bitmask(ring_rshift(e_open, {k - 1}), 0, 1); // 9. output e_{k-1} + b - 2 e_{k-1} b - auto _ret = ring_sub(_b, ring_lshift(ring_mul(_b, _ee), 1)); - auto _ret_mac = ring_sub(_b_mac, ring_lshift(ring_mul(_b_mac, _ee), 1)); + auto _ret = ring_sub(_b, ring_lshift(ring_mul(_b, _ee), {1})); + auto _ret_mac = ring_sub(_b_mac, ring_lshift(ring_mul(_b_mac, _ee), {1})); if (comm->getRank() == 0) { ring_add_(_ret, _ee); } diff --git a/libspu/mpc/spdz2k/io.cc b/libspu/mpc/spdz2k/io.cc index 136ba2d7..dff59022 100644 --- a/libspu/mpc/spdz2k/io.cc +++ b/libspu/mpc/spdz2k/io.cc @@ -60,9 +60,9 @@ std::vector Spdz2kIo::toShares(const NdArrayRef& raw, const auto runtime_field = getRuntimeField(field); NdArrayRef x(makeType(runtime_field), raw.shape()); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _raw(raw); - DISPATCH_ALL_FIELDS(runtime_field, "_", [&]() { + DISPATCH_ALL_FIELDS(runtime_field, [&]() { NdArrayView _x(x); pforeach(0, raw.numel(), [&](int64_t idx) { _x[idx] = static_cast(_raw[idx]); @@ -113,9 +113,9 @@ NdArrayRef Spdz2kIo::fromShares(const std::vector& shares) const { { NdArrayRef x(makeType(field_), res.shape()); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _res(res); - DISPATCH_ALL_FIELDS(field_, "_", [&]() { + DISPATCH_ALL_FIELDS(field_, [&]() { NdArrayView _x(x); pforeach(0, x.numel(), [&](int64_t idx) { _x[idx] = static_cast(_res[idx]); diff --git a/libspu/mpc/spdz2k/value.cc b/libspu/mpc/spdz2k/value.cc index a4e49fb8..029f0651 100644 --- a/libspu/mpc/spdz2k/value.cc +++ b/libspu/mpc/spdz2k/value.cc @@ -56,7 +56,7 @@ NdArrayRef makeBShare(const NdArrayRef& s1, const NdArrayRef& s2, NdArrayRef res(ty, new_shape); int64_t res_numel = res.numel(); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView> _res(res); pforeach(0, res_numel * k, [&](int64_t i) { @@ -108,7 +108,7 @@ NdArrayRef getShare(const NdArrayRef& in, int64_t share_idx) { } else { NdArrayRef ret(ty, new_shape); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { size_t numel = in.numel(); NdArrayView _ret(ret); NdArrayView> _in(in); @@ -141,7 +141,7 @@ size_t maxNumBits(const NdArrayRef& lhs, const NdArrayRef& rhs) { } const auto* rhs_ty = rhs.eltype().as(); const auto rhs_field = rhs_ty->field(); - return DISPATCH_ALL_FIELDS(rhs_field, "_", [&]() { + return DISPATCH_ALL_FIELDS(rhs_field, [&]() { using PShrT = ring2k_t; return std::max(lhs.eltype().as()->nbits(), maxBitWidth(rhs)); @@ -158,7 +158,7 @@ size_t minNumBits(const NdArrayRef& lhs, const NdArrayRef& rhs) { } const auto* rhs_ty = rhs.eltype().as(); const auto rhs_field = rhs_ty->field(); - return DISPATCH_ALL_FIELDS(rhs_field, "_", [&]() { + return DISPATCH_ALL_FIELDS(rhs_field, [&]() { using PShrT = ring2k_t; return std::min(lhs.eltype().as()->nbits(), maxBitWidth(rhs)); diff --git a/libspu/mpc/tools/benchmark.h b/libspu/mpc/tools/benchmark.h index 811dc9aa..764a50e5 100644 --- a/libspu/mpc/tools/benchmark.h +++ b/libspu/mpc/tools/benchmark.h @@ -172,10 +172,14 @@ class OpData { } for (auto& b1 : b1s) { b1 = p2b(obj_, rand_p(obj_, Shape{state.range(1)})); - b1 = lshift_b(obj_, b1, - SizeOf(static_cast(state.range(0))) * 8 - 1); - b1 = rshift_b(obj_, b1, - SizeOf(static_cast(state.range(0))) * 8 - 1); + b1 = lshift_b( + obj_, b1, + {static_cast( + SizeOf(static_cast(state.range(0))) * 8 - 1)}); + b1 = rshift_b( + obj_, b1, + {static_cast( + SizeOf(static_cast(state.range(0))) * 8 - 1)}); } } virtual ~OpData() = default; @@ -218,12 +222,12 @@ MPC_BENCH_DEFINE(BenchAndSP, OpData1S1P, and_sp, ss[0], ps[0]) MPC_BENCH_DEFINE(BenchXorSP, OpData1S1P, xor_sp, ss[0], ps[0]) MPC_BENCH_DEFINE(BenchNotS, OpData1S, not_s, ss[0]) MPC_BENCH_DEFINE(BenchNotP, OpData1P, not_p, ps[0]) -MPC_BENCH_DEFINE(BenchLShiftS, OpData1S, lshift_s, ss[0], state.range(2)) -MPC_BENCH_DEFINE(BenchLShiftP, OpData1P, lshift_p, ps[0], state.range(2)) -MPC_BENCH_DEFINE(BenchRShiftS, OpData1S, rshift_s, ss[0], state.range(2)) -MPC_BENCH_DEFINE(BenchRShiftP, OpData1P, rshift_p, ps[0], state.range(2)) -MPC_BENCH_DEFINE(BenchARShiftS, OpData1S, arshift_s, ss[0], state.range(2)) -MPC_BENCH_DEFINE(BenchARShiftP, OpData1P, arshift_p, ps[0], state.range(2)) +MPC_BENCH_DEFINE(BenchLShiftS, OpData1S, lshift_s, ss[0], {state.range(2)}) +MPC_BENCH_DEFINE(BenchLShiftP, OpData1P, lshift_p, ps[0], {state.range(2)}) +MPC_BENCH_DEFINE(BenchRShiftS, OpData1S, rshift_s, ss[0], {state.range(2)}) +MPC_BENCH_DEFINE(BenchRShiftP, OpData1P, rshift_p, ps[0], {state.range(2)}) +MPC_BENCH_DEFINE(BenchARShiftS, OpData1S, arshift_s, ss[0], {state.range(2)}) +MPC_BENCH_DEFINE(BenchARShiftP, OpData1P, arshift_p, ps[0], {state.range(2)}) MPC_BENCH_DEFINE(BenchTruncS, OpData1S, trunc_s_wrapper, ss[0], state.range(2)) MPC_BENCH_DEFINE(BenchS2P, OpData1S, s2p, ss[0]) MPC_BENCH_DEFINE(BenchP2S, OpData1P, p2s, ps[0]) @@ -243,7 +247,7 @@ MPC_BENCH_DEFINE(BenchMulAP, OpData1A1P, mul_ap, as[0], ps[0]) MPC_BENCH_DEFINE(BenchAddAA, OpData2A, add_aa, as[0], as[1]) MPC_BENCH_DEFINE(BenchMulAA, OpData2A, mul_aa, as[0], as[1]) MPC_BENCH_DEFINE(BenchMulA1B, OpData1A1B1, mul_a1b, as[0], b1s[0]) -MPC_BENCH_DEFINE(BenchLShiftA, OpData1A, lshift_a, as[0], state.range(2)) +MPC_BENCH_DEFINE(BenchLShiftA, OpData1A, lshift_a, as[0], {state.range(2)}) MPC_BENCH_DEFINE(BenchTruncA, OpData1A, trunc_a_wrapper, as[0], state.range(2)) // MPC_BENCH_DEFINE(BenchMMulAP, OpData1MA1MP, mmul_ap, mas[0], mps[0], // state.range(1), state.range(1), state.range(1)) @@ -257,9 +261,9 @@ MPC_BENCH_DEFINE(BenchAndBP, OpData1B1P, and_bp, bs[0], ps[0]) MPC_BENCH_DEFINE(BenchAndBB, OpData2B, and_bb, bs[0], bs[1]) MPC_BENCH_DEFINE(BenchXorBP, OpData1B1P, xor_bp, bs[0], ps[0]) MPC_BENCH_DEFINE(BenchXorBB, OpData2B, xor_bb, bs[0], bs[1]) -MPC_BENCH_DEFINE(BenchLShiftB, OpData1B, lshift_b, bs[0], state.range(2)) -MPC_BENCH_DEFINE(BenchRShiftB, OpData1B, rshift_b, bs[0], state.range(2)) -MPC_BENCH_DEFINE(BenchARShiftB, OpData1B, arshift_b, bs[0], state.range(2)) +MPC_BENCH_DEFINE(BenchLShiftB, OpData1B, lshift_b, bs[0], {state.range(2)}) +MPC_BENCH_DEFINE(BenchRShiftB, OpData1B, rshift_b, bs[0], {state.range(2)}) +MPC_BENCH_DEFINE(BenchARShiftB, OpData1B, arshift_b, bs[0], {state.range(2)}) MPC_BENCH_DEFINE(BenchBitRevB, OpData1B, bitrev_b, bs[0], 0, SizeOf(static_cast(state.range(0)))) MPC_BENCH_DEFINE(BenchBitIntlB, OpData1B, bitintl_b, bs[0], 0) diff --git a/libspu/mpc/utils/BUILD.bazel b/libspu/mpc/utils/BUILD.bazel index dea5932c..7523c4f3 100644 --- a/libspu/mpc/utils/BUILD.bazel +++ b/libspu/mpc/utils/BUILD.bazel @@ -22,6 +22,7 @@ spu_cc_library( hdrs = ["circuits.h"], deps = [ "//libspu/core:bit_utils", + "//libspu/core:shape", "//libspu/core:vectorize", ], ) diff --git a/libspu/mpc/utils/circuits.h b/libspu/mpc/utils/circuits.h index 3d032d6b..06cad9a0 100644 --- a/libspu/mpc/utils/circuits.h +++ b/libspu/mpc/utils/circuits.h @@ -22,6 +22,7 @@ #include "yacl/base/int128.h" #include "libspu/core/bit_utils.h" +#include "libspu/core/shape.h" #include "libspu/core/vectorize.h" namespace spu::mpc { @@ -35,10 +36,10 @@ struct CircuitBasicBlock { using And = std::function; // (logical) left shift - using LShift = std::function; + using LShift = std::function; // (logical) right shift - using RShift = std::function; + using RShift = std::function; // Init a constant. using InitLike = std::function; @@ -72,9 +73,9 @@ T kogge_stone(const CircuitBasicBlock& ctx, T const& lhs, T const& rhs, auto G = ctx._and(lhs, rhs); for (int idx = 0; idx < Log2Ceil(nbits); ++idx) { - const size_t offset = 1UL << idx; - auto G1 = ctx.lshift(G, offset); - auto P1 = ctx.lshift(P, offset); + const int64_t offset = 1L << idx; + auto G1 = ctx.lshift(G, {offset}); + auto P1 = ctx.lshift(P, {offset}); // P1 = P & P1 // G1 = G ^ (P & G1) @@ -90,7 +91,7 @@ T kogge_stone(const CircuitBasicBlock& ctx, T const& lhs, T const& rhs, } // out = (G << 1) ^ p0 - auto C = ctx.lshift(G, 1); + auto C = ctx.lshift(G, {1}); return ctx._xor(ctx._xor(lhs, rhs), C); } @@ -122,12 +123,12 @@ T sklansky(const CircuitBasicBlock& ctx, T const& lhs, T const& rhs, auto G = ctx._and(lhs, rhs); for (int idx = 0; idx < Log2Ceil(nbits); ++idx) { const auto s_mask = ctx.init_like(G, kSelMask[idx]); - auto G1 = ctx.lshift(ctx._and(G, s_mask), 1); - auto P1 = ctx.lshift(ctx._and(P, s_mask), 1); + auto G1 = ctx.lshift(ctx._and(G, s_mask), {1}); + auto P1 = ctx.lshift(ctx._and(P, s_mask), {1}); for (int j = 0; j < idx; j++) { - G1 = ctx._xor(G1, ctx.lshift(G1, 1 << j)); - P1 = ctx._xor(P1, ctx.lshift(P1, 1 << j)); + G1 = ctx._xor(G1, ctx.lshift(G1, {1 << j})); + P1 = ctx._xor(P1, ctx.lshift(P1, {1 << j})); } const auto k_mask = ctx.init_like(G, kKeepMasks[idx]); @@ -147,7 +148,7 @@ T sklansky(const CircuitBasicBlock& ctx, T const& lhs, T const& rhs, } // out = (G0 << 1) ^ p0 - auto C = ctx.lshift(G, 1); + auto C = ctx.lshift(G, {1}); return ctx._xor(ctx._xor(lhs, rhs), C); } @@ -181,21 +182,22 @@ T odd_even_split(const CircuitBasicBlock& ctx, const T& v, size_t nbits) { }}; // let r = v - T r = ctx.lshift(v, 0); + T r = ctx.lshift(v, {0}); for (int idx = 0; idx + 1 < Log2Ceil(nbits); ++idx) { // r = (r & keep) ^ ((r >> i) & move) ^ ((r & move) << i) const auto keep = ctx.init_like(r, kKeepMasks[idx]); const auto move = ctx.init_like(r, kSwapMasks[idx]); r = ctx._xor(ctx._and(r, keep), - ctx._xor(ctx._and(ctx.rshift(r, 1 << idx), move), - ctx.lshift(ctx._and(r, move), 1 << idx))); + ctx._xor(ctx._and(ctx.rshift(r, {1 << idx}), move), + ctx.lshift(ctx._and(r, move), {1 << idx}))); } if (!absl::has_single_bit(nbits)) { // handle non 2^k bits case. T mask = ctx.init_like(r, (1ULL << (nbits / 2)) - 1); - r = ctx._xor(ctx.lshift(ctx.rshift(r, 1 << Log2Floor(nbits)), nbits / 2), + r = ctx._xor(ctx.lshift(ctx.rshift(r, {1 << Log2Floor(nbits)}), + {static_cast(nbits / 2)}), ctx._and(r, mask)); } @@ -228,7 +230,7 @@ T carry_out(const CircuitBasicBlock& ctx, const T& x, const T& y, auto perm = odd_even_split(ctx, in, kk); T mask = ctx.init_like(perm, (static_cast(1) << hk) - 1); T t0 = ctx._and(perm, mask); - T t1 = ctx._and(ctx.rshift(perm, hk), mask); + T t1 = ctx._and(ctx.rshift(perm, {static_cast(hk)}), mask); ctx.set_nbits(t0, hk); ctx.set_nbits(t1, hk); return std::make_tuple(t0, t1); @@ -247,8 +249,8 @@ T carry_out(const CircuitBasicBlock& ctx, const T& x, const T& y, while (k > 1) { if (k % 2 != 0) { k += 1; - P = ctx.lshift(P, 1); - G = ctx.lshift(G, 1); + P = ctx.lshift(P, {1}); + G = ctx.lshift(G, {1}); } auto [P0, P1] = bit_split(P, k); auto [G0, G1] = bit_split(G, k); diff --git a/libspu/mpc/utils/circuits_test.cc b/libspu/mpc/utils/circuits_test.cc index 1e8e034e..c0e82433 100644 --- a/libspu/mpc/utils/circuits_test.cc +++ b/libspu/mpc/utils/circuits_test.cc @@ -26,8 +26,8 @@ CircuitBasicBlock makeScalarCBB() { CircuitBasicBlock cbb; cbb._xor = [](T const& lhs, T const& rhs) -> T { return lhs ^ rhs; }; cbb._and = [](T const& lhs, T const& rhs) -> T { return lhs & rhs; }; - cbb.lshift = [](T const& x, size_t bits) -> T { return x << bits; }; - cbb.rshift = [](T const& x, size_t bits) -> T { return x >> bits; }; + cbb.lshift = [](T const& x, const Sizes& bits) -> T { return x << bits[0]; }; + cbb.rshift = [](T const& x, const Sizes& bits) -> T { return x >> bits[0]; }; cbb.init_like = [](T const&, uint128_t init) -> T { return static_cast(init); }; @@ -53,16 +53,16 @@ CircuitBasicBlock makeVectorCBB() { std::bit_and<>()); return res; }; - cbb.lshift = [](C const& x, size_t bits) -> C { + cbb.lshift = [](C const& x, const Sizes& bits) -> C { C res; std::transform(x.begin(), x.end(), std::back_inserter(res), - [&](const auto& e) { return e << bits; }); + [&](const auto& e) { return e << bits[0]; }); return res; }; - cbb.rshift = [](C const& x, size_t bits) -> C { + cbb.rshift = [](C const& x, const Sizes& bits) -> C { C res; std::transform(x.begin(), x.end(), std::back_inserter(res), - [&](const auto& e) { return e >> bits; }); + [&](const auto& e) { return e >> bits[0]; }); return res; }; cbb.set_nbits = [](C&, size_t) -> void {}; diff --git a/libspu/mpc/utils/permute.cc b/libspu/mpc/utils/permute.cc index 1c5d2920..9ef42a61 100644 --- a/libspu/mpc/utils/permute.cc +++ b/libspu/mpc/utils/permute.cc @@ -27,7 +27,7 @@ PermVector ring2pv(const NdArrayRef& x) { x.eltype()); const auto field = x.eltype().as()->field(); PermVector pv(x.numel()); - DISPATCH_ALL_FIELDS(field, "_", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _x(x); pforeach(0, x.numel(), [&](int64_t idx) { pv[idx] = int64_t(_x[idx]); }); }); @@ -39,7 +39,7 @@ NdArrayRef applyInvPerm(const NdArrayRef& x, absl::Span pv) { NdArrayRef y(x.eltype(), x.shape()); const auto field = x.eltype().as()->field(); - DISPATCH_ALL_FIELDS(field, kPermModule, [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _x(x); NdArrayView _y(y); for (int64_t i = 0; i < y.numel(); i++) { @@ -54,7 +54,7 @@ NdArrayRef applyPerm(const NdArrayRef& x, absl::Span pv) { NdArrayRef y(x.eltype(), x.shape()); const auto field = x.eltype().as()->field(); - DISPATCH_ALL_FIELDS(field, kPermModule, [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _x(x); NdArrayView _y(y); for (int64_t i = 0; i < y.numel(); i++) { @@ -86,7 +86,7 @@ PermVector genPermBySort(const NdArrayRef& x) { PermVector perm(x.shape()[0]); std::iota(perm.begin(), perm.end(), 0); const auto field = x.eltype().as()->field(); - DISPATCH_ALL_FIELDS(field, kPermModule, [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using T = std::make_signed_t; NdArrayView _x(x); diff --git a/libspu/mpc/utils/ring_ops.cc b/libspu/mpc/utils/ring_ops.cc index 0d49f743..a09bc56c 100644 --- a/libspu/mpc/utils/ring_ops.cc +++ b/libspu/mpc/utils/ring_ops.cc @@ -29,8 +29,6 @@ namespace spu::mpc { namespace { -constexpr char kModule[] = "RingOps"; - #define SPU_ENFORCE_RING(x) \ SPU_ENFORCE((x).eltype().isa(), "expect ring type, got={}", \ (x).eltype()); @@ -47,7 +45,7 @@ constexpr char kModule[] = "RingOps"; ENFORCE_EQ_ELSIZE_AND_SHAPE(ret, x); \ const auto field = x.eltype().as()->field(); \ const int64_t numel = ret.numel(); \ - return DISPATCH_ALL_FIELDS(field, kModule, [&]() { \ + return DISPATCH_ALL_FIELDS(field, [&]() { \ using T = std::make_signed_t; \ NdArrayView _x(x); \ NdArrayView _ret(ret); \ @@ -67,7 +65,7 @@ DEF_UNARY_RING_OP(ring_neg, -); ENFORCE_EQ_ELSIZE_AND_SHAPE(ret, y); \ const auto field = x.eltype().as()->field(); \ const int64_t numel = ret.numel(); \ - return DISPATCH_ALL_FIELDS(field, kModule, [&]() { \ + return DISPATCH_ALL_FIELDS(field, [&]() { \ NdArrayView _x(x); \ NdArrayView _y(y); \ NdArrayView _ret(ret); \ @@ -86,40 +84,56 @@ DEF_BINARY_RING_OP(ring_xor, ^); #undef DEF_BINARY_RING_OP -void ring_arshift_impl(NdArrayRef& ret, const NdArrayRef& x, size_t bits) { +void ring_arshift_impl(NdArrayRef& ret, const NdArrayRef& x, + const Sizes& bits) { ENFORCE_EQ_ELSIZE_AND_SHAPE(ret, x); + bool is_splat = bits.size() == 1; + SPU_ENFORCE(static_cast(bits.size()) == x.numel() || is_splat, + "mismatched numel {} vs {}", bits.size(), x.numel()); const auto numel = ret.numel(); const auto field = x.eltype().as()->field(); - return DISPATCH_ALL_FIELDS(field, kModule, [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { // According to K&R 2nd edition the results are implementation-dependent for // right shifts of signed values, but "usually" its arithmetic right shift. using S = std::make_signed::type; NdArrayView _ret(ret); NdArrayView _x(x); - pforeach(0, numel, [&](int64_t idx) { _ret[idx] = _x[idx] >> bits; }); + pforeach(0, numel, [&](int64_t idx) { + _ret[idx] = _x[idx] >> (is_splat ? bits[0] : bits[idx]); + }); }); } -void ring_rshift_impl(NdArrayRef& ret, const NdArrayRef& x, size_t bits) { +void ring_rshift_impl(NdArrayRef& ret, const NdArrayRef& x, const Sizes& bits) { ENFORCE_EQ_ELSIZE_AND_SHAPE(ret, x); + bool is_splat = bits.size() == 1; + SPU_ENFORCE(static_cast(bits.size()) == x.numel() || is_splat, + "mismatched numel {} vs {}", bits.size(), x.numel()); const auto numel = ret.numel(); const auto field = x.eltype().as()->field(); - return DISPATCH_ALL_FIELDS(field, kModule, [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using U = ring2k_t; NdArrayView _ret(ret); NdArrayView _x(x); - pforeach(0, numel, [&](int64_t idx) { _ret[idx] = _x[idx] >> bits; }); + pforeach(0, numel, [&](int64_t idx) { + _ret[idx] = _x[idx] >> (is_splat ? bits[0] : bits[idx]); + }); }); } -void ring_lshift_impl(NdArrayRef& ret, const NdArrayRef& x, size_t bits) { +void ring_lshift_impl(NdArrayRef& ret, const NdArrayRef& x, const Sizes& bits) { ENFORCE_EQ_ELSIZE_AND_SHAPE(ret, x); + bool is_splat = bits.size() == 1; + SPU_ENFORCE(static_cast(bits.size()) == x.numel() || is_splat, + "mismatched numel {} vs {}", bits.size(), x.numel()); const auto numel = ret.numel(); const auto field = x.eltype().as()->field(); - return DISPATCH_ALL_FIELDS(field, kModule, [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _ret(ret); NdArrayView _x(x); - pforeach(0, numel, [&](int64_t idx) { _ret[idx] = _x[idx] << bits; }); + pforeach(0, numel, [&](int64_t idx) { + _ret[idx] = _x[idx] << (is_splat ? bits[0] : bits[idx]); + }); }); } @@ -130,7 +144,7 @@ void ring_bitrev_impl(NdArrayRef& ret, const NdArrayRef& x, size_t start, const auto field = x.eltype().as()->field(); const auto numel = ret.numel(); - return DISPATCH_ALL_FIELDS(field, kModule, [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using U = ring2k_t; // optimize: use faster reverse method. @@ -161,7 +175,7 @@ void ring_bitmask_impl(NdArrayRef& ret, const NdArrayRef& x, size_t low, SPU_ENFORCE(low < high && high <= SizeOf(field) * 8); - return DISPATCH_ALL_FIELDS(field, kModule, [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using U = ring2k_t; U mask = 0; if (high - low < SizeOf(field) * 8) { @@ -184,7 +198,7 @@ void ring_print(const NdArrayRef& x, std::string_view name) { SPU_ENFORCE_RING(x); const auto field = x.eltype().as()->field(); - DISPATCH_ALL_FIELDS(field, kModule, [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using U = ring2k_t; std::string out; @@ -231,7 +245,7 @@ NdArrayRef ring_rand_range(FieldType field, const Shape& shape, int32_t min, NdArrayRef x(makeType(field), shape); auto numel = x.numel(); - DISPATCH_ALL_FIELDS(field, kModule, [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { SPU_ENFORCE(sizeof(ring2k_t) >= sizeof(int32_t)); auto iter = x.begin(); @@ -250,7 +264,7 @@ void ring_assign(NdArrayRef& x, const NdArrayRef& y) { const auto numel = x.numel(); const auto field = x.eltype().as()->field(); - return DISPATCH_ALL_FIELDS(field, kModule, [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _y(y); NdArrayView _x(x); pforeach(0, numel, [&](int64_t idx) { _x[idx] = _y[idx]; }); @@ -261,7 +275,7 @@ NdArrayRef ring_zeros(FieldType field, const Shape& shape) { NdArrayRef ret(makeType(field), shape); auto numel = ret.numel(); - return DISPATCH_ALL_FIELDS(field, kModule, [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _ret(ret); pforeach(0, numel, [&](int64_t idx) { _ret[idx] = ring2k_t(0); }); return ret; @@ -272,7 +286,7 @@ NdArrayRef ring_ones(FieldType field, const Shape& shape) { NdArrayRef ret(makeType(field), shape); auto numel = ret.numel(); - return DISPATCH_ALL_FIELDS(field, kModule, [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _ret(ret); pforeach(0, numel, [&](int64_t idx) { _ret[idx] = ring2k_t(1); }); return ret; @@ -287,7 +301,7 @@ NdArrayRef ring_randbit(FieldType field, const Shape& shape) { NdArrayRef ret(makeType(field), shape); auto numel = ret.numel(); - return DISPATCH_ALL_FIELDS(field, kModule, [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _ret(ret); for (auto idx = 0; idx < numel; ++idx) { _ret[idx] = distrib(gen) & 0x1; @@ -341,7 +355,7 @@ void ring_mul_impl(NdArrayRef& ret, const NdArrayRef& x, uint128_t y) { const auto numel = x.numel(); const auto field = x.eltype().as()->field(); - DISPATCH_ALL_FIELDS(field, kModule, [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using U = std::make_unsigned::type; NdArrayView _x(x); NdArrayView _ret(ret); @@ -368,7 +382,7 @@ void ring_mmul_impl(NdArrayRef& z, const NdArrayRef& lhs, SPU_ENFORCE(rhs.eltype().isa(), "rhs not ring, got={}", rhs.eltype()); const auto field = lhs.eltype().as()->field(); - return DISPATCH_ALL_FIELDS(field, kModule, [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { const auto lhs_stride_scale = lhs.elsize() / sizeof(ring2k_t); const auto rhs_stride_scale = rhs.elsize() / sizeof(ring2k_t); const auto ret_stride_scale = z.elsize() / sizeof(ring2k_t); @@ -436,31 +450,35 @@ void ring_equal_(NdArrayRef& x, const NdArrayRef& y) { ring_equal_impl(x, x, y); } -NdArrayRef ring_arshift(const NdArrayRef& x, size_t bits) { +NdArrayRef ring_arshift(const NdArrayRef& x, const Sizes& bits) { NdArrayRef res(x.eltype(), x.shape()); ring_arshift_impl(res, x, bits); return res; } -void ring_arshift_(NdArrayRef& x, size_t bits) { +void ring_arshift_(NdArrayRef& x, const Sizes& bits) { ring_arshift_impl(x, x, bits); } -NdArrayRef ring_rshift(const NdArrayRef& x, size_t bits) { +NdArrayRef ring_rshift(const NdArrayRef& x, const Sizes& bits) { NdArrayRef res(x.eltype(), x.shape()); ring_rshift_impl(res, x, bits); return res; } -void ring_rshift_(NdArrayRef& x, size_t bits) { ring_rshift_impl(x, x, bits); } +void ring_rshift_(NdArrayRef& x, const Sizes& bits) { + ring_rshift_impl(x, x, bits); +} -NdArrayRef ring_lshift(const NdArrayRef& x, size_t bits) { +NdArrayRef ring_lshift(const NdArrayRef& x, const Sizes& bits) { NdArrayRef res(x.eltype(), x.shape()); ring_lshift_impl(res, x, bits); return res; } -void ring_lshift_(NdArrayRef& x, size_t bits) { ring_lshift_impl(x, x, bits); } +void ring_lshift_(NdArrayRef& x, const Sizes& bits) { + ring_lshift_impl(x, x, bits); +} NdArrayRef ring_bitrev(const NdArrayRef& x, size_t start, size_t end) { NdArrayRef res(x.eltype(), x.shape()); @@ -503,7 +521,7 @@ bool ring_all_equal(const NdArrayRef& x, const NdArrayRef& y, size_t abs_err) { auto numel = x.numel(); const auto field = x.eltype().as()->field(); - return DISPATCH_ALL_FIELDS(field, "_", [&]() { + return DISPATCH_ALL_FIELDS(field, [&]() { using T = std::make_signed_t; NdArrayView _x(x); @@ -529,7 +547,7 @@ std::vector ring_cast_boolean(const NdArrayRef& x) { auto numel = x.numel(); std::vector res(numel); - DISPATCH_ALL_FIELDS(field, kModule, [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _x(x); pforeach(0, numel, [&](int64_t idx) { res[idx] = static_cast(_x[idx] & 0x1); @@ -549,7 +567,7 @@ NdArrayRef ring_select(const std::vector& c, const NdArrayRef& x, NdArrayRef z(x.eltype(), x.shape()); const int64_t numel = c.size(); - DISPATCH_ALL_FIELDS(field, kModule, [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _x(x); NdArrayView _y(y); NdArrayView _z(z); diff --git a/libspu/mpc/utils/ring_ops.h b/libspu/mpc/utils/ring_ops.h index cea10a06..8a960ebc 100644 --- a/libspu/mpc/utils/ring_ops.h +++ b/libspu/mpc/utils/ring_ops.h @@ -97,14 +97,14 @@ NdArrayRef ring_equal(const NdArrayRef& x, const NdArrayRef& y); void ring_equal_(NdArrayRef& x, const NdArrayRef& y); DEF_RVALUE_BINARY_RING_OP(ring_equal, true); -NdArrayRef ring_arshift(const NdArrayRef& x, size_t bits); -void ring_arshift_(NdArrayRef& x, size_t bits); +NdArrayRef ring_arshift(const NdArrayRef& x, const Sizes& bits); +void ring_arshift_(NdArrayRef& x, const Sizes& bits); -NdArrayRef ring_rshift(const NdArrayRef& x, size_t bits); -void ring_rshift_(NdArrayRef& x, size_t bits); +NdArrayRef ring_rshift(const NdArrayRef& x, const Sizes& bits); +void ring_rshift_(NdArrayRef& x, const Sizes& bits); -NdArrayRef ring_lshift(const NdArrayRef& x, size_t bits); -void ring_lshift_(NdArrayRef& x, size_t bits); +NdArrayRef ring_lshift(const NdArrayRef& x, const Sizes& bits); +void ring_lshift_(NdArrayRef& x, const Sizes& bits); NdArrayRef ring_bitrev(const NdArrayRef& x, size_t start, size_t end); void ring_bitrev_(NdArrayRef& x, size_t start, size_t end); diff --git a/libspu/version.h b/libspu/version.h index d4a7a9fe..a7f569ae 100644 --- a/libspu/version.h +++ b/libspu/version.h @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#define SPU_VERSION "0.9.2.dev$$DATE$$" +#define SPU_VERSION "0.9.3.dev$$DATE$$" #include From bba503fa2a0e2b4f7463045dd549d0ec28be1e05 Mon Sep 17 00:00:00 2001 From: "renovate[bot]" <29139614+renovate[bot]@users.noreply.github.com> Date: Wed, 17 Jul 2024 17:48:43 +0800 Subject: [PATCH 14/27] chore(deps): update dependency sphinx to v7.4.5 (#773) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit [![Mend Renovate](https://app.renovatebot.com/images/banner.svg)](https://renovatebot.com) This PR contains the following updates: | Package | Change | Age | Adoption | Passing | Confidence | |---|---|---|---|---|---| | [sphinx](https://togithub.com/sphinx-doc/sphinx) ([changelog](https://www.sphinx-doc.org/en/master/changes.html)) | `==7.4.4` -> `==7.4.5` | [![age](https://developer.mend.io/api/mc/badges/age/pypi/sphinx/7.4.5?slim=true)](https://docs.renovatebot.com/merge-confidence/) | [![adoption](https://developer.mend.io/api/mc/badges/adoption/pypi/sphinx/7.4.5?slim=true)](https://docs.renovatebot.com/merge-confidence/) | [![passing](https://developer.mend.io/api/mc/badges/compatibility/pypi/sphinx/7.4.4/7.4.5?slim=true)](https://docs.renovatebot.com/merge-confidence/) | [![confidence](https://developer.mend.io/api/mc/badges/confidence/pypi/sphinx/7.4.4/7.4.5?slim=true)](https://docs.renovatebot.com/merge-confidence/) | --- ### Release Notes
sphinx-doc/sphinx (sphinx) ### [`v7.4.5`](https://togithub.com/sphinx-doc/sphinx/blob/HEAD/CHANGES.rst#Release-745-released-Jul-16-2024) [Compare Source](https://togithub.com/sphinx-doc/sphinx/compare/v7.4.4...v7.4.5) \===================================== ## Bugs fixed - [#​12593](https://togithub.com/sphinx-doc/sphinx/issues/12593), [#​12600](https://togithub.com/sphinx-doc/sphinx/issues/12600): Revert coercing the type of selected :confval:`html_sidebars` values to a list. Log an error message when string values are detected. Patch by Adam Turner. - [#​12594](https://togithub.com/sphinx-doc/sphinx/issues/12594): LaTeX: since 7.4.0, :rst:dir:`seealso` and other "light" admonitions now break PDF builds if they contain a :dudir:`figure` directive; and also if they are contained in a table cell (rendered by `tabulary`). Patch by Jean-François B.
--- ### Configuration 📅 **Schedule**: Branch creation - At any time (no schedule defined), Automerge - At any time (no schedule defined). 🚦 **Automerge**: Disabled by config. Please merge this manually once you are satisfied. ♻ **Rebasing**: Whenever PR becomes conflicted, or you tick the rebase/retry checkbox. 🔕 **Ignore**: Close this PR and you won't be reminded about this update again. --- - [ ] If you want to rebase/retry this PR, check this box --- This PR has been generated by [Mend Renovate](https://www.mend.io/free-developer-tools/renovate/). View repository job log [here](https://developer.mend.io/github/secretflow/spu). Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com> --- docs/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/requirements.txt b/docs/requirements.txt index 3b26b70d..c1a970e9 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,6 +1,6 @@ myst-parser==3.0.1 rstcheck==6.2.4 -sphinx==7.4.4 +sphinx==7.4.5 nbsphinx==0.9.4 sphinx-autobuild==2024.4.16 sphinx-markdown-parser==0.2.4 From cb37d78252d7132f76c07aa2ab5c60b25ff9b1f0 Mon Sep 17 00:00:00 2001 From: Yancheng Zheng <103552181+anakinxc@users.noreply.github.com> Date: Wed, 17 Jul 2024 19:37:47 +0800 Subject: [PATCH 15/27] Update fxp_base.cc --- libspu/kernel/hal/fxp_base.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libspu/kernel/hal/fxp_base.cc b/libspu/kernel/hal/fxp_base.cc index 1594a9ac..601d42c7 100644 --- a/libspu/kernel/hal/fxp_base.cc +++ b/libspu/kernel/hal/fxp_base.cc @@ -86,7 +86,7 @@ void hintNumberOfBits(const Value& a, size_t nbits) { } Value maskNumberOfBits(SPUContext* ctx, const Value& in, size_t nbits) { - auto k1 = constant(ctx, 1UL, spu::DT_I64, in.shape()); + auto k1 = constant(ctx, static_cast(1), spu::DT_I64, in.shape()); auto mask = _sub(ctx, _lshift(ctx, k1, {static_cast(nbits)}), k1); auto out = _and(ctx, in, mask).setDtype(in.dtype()); return out; From 7d5e2274b6cd31aebd7c6a4f568c5f5bc1c5c54b Mon Sep 17 00:00:00 2001 From: Yancheng Zheng <103552181+anakinxc@users.noreply.github.com> Date: Wed, 17 Jul 2024 19:47:51 +0800 Subject: [PATCH 16/27] Update boolean.cc --- libspu/mpc/aby3/boolean.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libspu/mpc/aby3/boolean.cc b/libspu/mpc/aby3/boolean.cc index d8ee6db5..85a44542 100644 --- a/libspu/mpc/aby3/boolean.cc +++ b/libspu/mpc/aby3/boolean.cc @@ -349,7 +349,7 @@ NdArrayRef LShiftB::proc(KernelEvalContext* ctx, const NdArrayRef& in, // TODO: the hal dtype should tell us about the max number of possible bits. const auto field = ctx->getState()->getDefaultField(); const size_t out_nbits = - std::min(in_ty->nbits() + *std::max_element(bits.begin(), bits.end()), + std::min(in_ty->nbits() + *std::max_element(bits.begin(), bits.end()), SizeOf(field) * 8); const PtType out_btype = calcBShareBacktype(out_nbits); bool is_splat = bits.size() == 1; From a0cc9d0d435fa2180a4f6347d4857462f0c644dc Mon Sep 17 00:00:00 2001 From: Yancheng Zheng <103552181+anakinxc@users.noreply.github.com> Date: Wed, 17 Jul 2024 19:53:12 +0800 Subject: [PATCH 17/27] Update boolean.cc --- libspu/mpc/securenn/boolean.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libspu/mpc/securenn/boolean.cc b/libspu/mpc/securenn/boolean.cc index 9598b5c2..a985ac0b 100644 --- a/libspu/mpc/securenn/boolean.cc +++ b/libspu/mpc/securenn/boolean.cc @@ -262,7 +262,7 @@ NdArrayRef LShiftB::proc(KernelEvalContext* ctx, const NdArrayRef& in, int64_t out_nbits = in.eltype().as()->nbits() + *std::max_element(shift.begin(), shift.end()); out_nbits = - std::clamp(out_nbits, 0L, static_cast(SizeOf(field) * 8)); + std::clamp(out_nbits, 0L, static_cast(SizeOf(field) * 8)); return makeBShare(ring_lshift(in, shift), field, out_nbits); } From 6bbbddcc372bcc050b15d4a7a63a954ae29ca4cb Mon Sep 17 00:00:00 2001 From: Jimmy MA <9068135+tpppppub@users.noreply.github.com> Date: Mon, 22 Jul 2024 10:39:38 +0800 Subject: [PATCH 18/27] Add more send/recv actions profiling (#780) # Pull Request ## What problem does this PR solve? Issue Number: Fixed #777 ## Possible side effects? - Performance: N/A - Backward compatibility: N/A --- CHANGELOG.md | 1 + libspu/core/trace.h | 15 ++++++++++++++- libspu/device/api.cc | 19 +++++++++++++++---- 3 files changed, 30 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6b86bb5b..4dc066f0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ ## staging > > please add your unreleased change here. +- [Feature] Add more send/recv actions profiling ## 20240716 diff --git a/libspu/core/trace.h b/libspu/core/trace.h index c67dbff0..d05b798f 100644 --- a/libspu/core/trace.h +++ b/libspu/core/trace.h @@ -154,6 +154,10 @@ struct ActionRecord final { size_t send_bytes_end; size_t recv_bytes_start; size_t recv_bytes_end; + size_t send_actions_start; + size_t send_actions_end; + size_t recv_actions_start; + size_t recv_actions_end; }; class ProfState final { @@ -238,6 +242,10 @@ class TraceAction final { size_t send_bytes_end_; size_t recv_bytes_start_; size_t recv_bytes_end_; + size_t send_actions_start_; + size_t send_actions_end_; + size_t recv_actions_start_; + size_t recv_actions_end_; int64_t saved_tracer_flag_; @@ -247,6 +255,8 @@ class TraceAction final { if (lctx_) { send_bytes_start_ = lctx_->GetStats()->sent_bytes.load(); recv_bytes_start_ = lctx_->GetStats()->recv_bytes.load(); + send_actions_start_ = lctx_->GetStats()->sent_actions.load(); + recv_actions_start_ = lctx_->GetStats()->recv_actions.load(); } const auto flag = flag_ & tracer_->getFlag(); if ((flag & TR_LOGB) != 0) { @@ -269,6 +279,8 @@ class TraceAction final { if (lctx_) { send_bytes_end_ = lctx_->GetStats()->sent_bytes.load(); recv_bytes_end_ = lctx_->GetStats()->recv_bytes.load(); + send_actions_end_ = lctx_->GetStats()->sent_actions.load(); + recv_actions_end_ = lctx_->GetStats()->recv_actions.load(); } const auto flag = flag_ & tracer_->getFlag(); if ((flag & TR_LOGE) != 0) { @@ -279,7 +291,8 @@ class TraceAction final { tracer_->getProfState()->addRecord( ActionRecord{id_, name_, std::move(detail_), flag_, start_, end_, send_bytes_start_, send_bytes_end_, recv_bytes_start_, - recv_bytes_end_}); + recv_bytes_end_, send_actions_start_, send_actions_end_, + recv_actions_start_, recv_actions_end_}); } } diff --git a/libspu/device/api.cc b/libspu/device/api.cc index 1b27f881..bae9b462 100644 --- a/libspu/device/api.cc +++ b/libspu/device/api.cc @@ -71,12 +71,14 @@ struct CommunicationStats { size_t send_bytes = 0; size_t recv_bytes = 0; size_t send_actions = 0; + size_t recv_actions = 0; void reset(const std::shared_ptr &lctx) { if (!lctx) { return; } send_actions = lctx->GetStats()->sent_actions; + recv_actions = lctx->GetStats()->recv_actions; send_bytes = lctx->GetStats()->sent_bytes; recv_bytes = lctx->GetStats()->recv_bytes; } @@ -88,6 +90,7 @@ struct CommunicationStats { send_bytes = lctx->GetStats()->sent_bytes - send_bytes; recv_bytes = lctx->GetStats()->recv_bytes - recv_bytes; send_actions = lctx->GetStats()->sent_actions - send_actions; + recv_actions = lctx->GetStats()->recv_actions - recv_actions; } }; @@ -108,6 +111,10 @@ struct ActionStats { size_t send_bytes = 0; // total recv bytes. size_t recv_bytes = 0; + // total send actions. + size_t send_actions = 0; + // total recv actions. + size_t recv_actions = 0; inline double getTotalTimeInSecond() const { return std::chrono::duration_cast>(total_time) @@ -183,6 +190,8 @@ void printProfilingData(spu::SPUContext *sctx, const std::string &name, std::chrono::duration_cast(rec.end - rec.start); stat.send_bytes += (rec.send_bytes_end - rec.send_bytes_start); stat.recv_bytes += (rec.recv_bytes_end - rec.recv_bytes_start); + stat.send_actions += (rec.send_actions_end - rec.send_actions_start); + stat.recv_actions += (rec.recv_actions_end - rec.recv_actions_start); } static std::map kModules = { @@ -213,17 +222,19 @@ void printProfilingData(spu::SPUContext *sctx, const std::string &name, const auto &stat = stats.find(key)->second; SPDLOG_INFO( "- {}, executed {} times, duration {}s, send bytes {} recv " - "bytes {}", + "bytes {}, send actions {}, recv actions {}", key.name, stat.count, stat.getTotalTimeInSecond(), stat.send_bytes, - stat.recv_bytes); + stat.recv_bytes, stat.send_actions, stat.recv_actions); } } } // print link statistics SPDLOG_INFO( - "Link details: total send bytes {}, recv bytes {}, send actions {}", - comm_stats.send_bytes, comm_stats.recv_bytes, comm_stats.send_actions); + "Link details: total send bytes {}, recv bytes {}, send actions {}, recv " + "actions {}", + comm_stats.send_bytes, comm_stats.recv_bytes, comm_stats.send_actions, + comm_stats.recv_actions); } void SPUErrorHandler(void *use_data, const char *reason, bool gen_crash_diag) { From 75f65a58a4e1d2bd5294730db7267230a18b6d34 Mon Sep 17 00:00:00 2001 From: "renovate[bot]" <29139614+renovate[bot]@users.noreply.github.com> Date: Mon, 22 Jul 2024 10:40:15 +0800 Subject: [PATCH 19/27] chore(deps): update github/codeql-action action to v3.25.13 (#779) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit [![Mend Renovate](https://app.renovatebot.com/images/banner.svg)](https://renovatebot.com) This PR contains the following updates: | Package | Type | Update | Change | |---|---|---|---| | [github/codeql-action](https://togithub.com/github/codeql-action) | action | patch | `v3.25.12` -> `v3.25.13` | --- ### Release Notes
github/codeql-action (github/codeql-action) ### [`v3.25.13`](https://togithub.com/github/codeql-action/compare/v3.25.12...v3.25.13) [Compare Source](https://togithub.com/github/codeql-action/compare/v3.25.12...v3.25.13)
--- ### Configuration 📅 **Schedule**: Branch creation - At any time (no schedule defined), Automerge - At any time (no schedule defined). 🚦 **Automerge**: Disabled by config. Please merge this manually once you are satisfied. ♻ **Rebasing**: Whenever PR becomes conflicted, or you tick the rebase/retry checkbox. 🔕 **Ignore**: Close this PR and you won't be reminded about this update again. --- - [ ] If you want to rebase/retry this PR, check this box --- This PR has been generated by [Mend Renovate](https://www.mend.io/free-developer-tools/renovate/). View repository job log [here](https://developer.mend.io/github/secretflow/spu). Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com> --- .github/workflows/scorecard.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/scorecard.yml b/.github/workflows/scorecard.yml index 5b25f5b2..c327e62a 100644 --- a/.github/workflows/scorecard.yml +++ b/.github/workflows/scorecard.yml @@ -67,6 +67,6 @@ jobs: # Upload the results to GitHub's code scanning dashboard. - name: "Upload to code-scanning" - uses: github/codeql-action/upload-sarif@4fa2a7953630fd2f3fb380f21be14ede0169dd4f # v3.25.12 + uses: github/codeql-action/upload-sarif@2d790406f505036ef40ecba973cc774a50395aac # v3.25.13 with: sarif_file: results.sarif From 9e5d9bb69c11f8d22df30471d258f5c0790db1e0 Mon Sep 17 00:00:00 2001 From: "renovate[bot]" <29139614+renovate[bot]@users.noreply.github.com> Date: Mon, 22 Jul 2024 10:40:42 +0800 Subject: [PATCH 20/27] chore(deps): update dependency sphinx to v7.4.7 (#776) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit [![Mend Renovate](https://app.renovatebot.com/images/banner.svg)](https://renovatebot.com) This PR contains the following updates: | Package | Change | Age | Adoption | Passing | Confidence | |---|---|---|---|---|---| | [sphinx](https://togithub.com/sphinx-doc/sphinx) ([changelog](https://www.sphinx-doc.org/en/master/changes.html)) | `==7.4.5` -> `==7.4.7` | [![age](https://developer.mend.io/api/mc/badges/age/pypi/sphinx/7.4.7?slim=true)](https://docs.renovatebot.com/merge-confidence/) | [![adoption](https://developer.mend.io/api/mc/badges/adoption/pypi/sphinx/7.4.7?slim=true)](https://docs.renovatebot.com/merge-confidence/) | [![passing](https://developer.mend.io/api/mc/badges/compatibility/pypi/sphinx/7.4.5/7.4.7?slim=true)](https://docs.renovatebot.com/merge-confidence/) | [![confidence](https://developer.mend.io/api/mc/badges/confidence/pypi/sphinx/7.4.5/7.4.7?slim=true)](https://docs.renovatebot.com/merge-confidence/) | --- ### Release Notes
sphinx-doc/sphinx (sphinx) ### [`v7.4.7`](https://togithub.com/sphinx-doc/sphinx/blob/HEAD/CHANGES.rst#Release-747-released-Jul-20-2024) [Compare Source](https://togithub.com/sphinx-doc/sphinx/compare/v7.4.6...v7.4.7) \===================================== ## Bugs fixed - [#​12096](https://togithub.com/sphinx-doc/sphinx/issues/12096): Warn when files are overwritten in the build directory. Patch by Adam Turner and Bénédikt Tran. - [#​12620](https://togithub.com/sphinx-doc/sphinx/issues/12620): Ensure that old-style object description options are respected. Patch by Adam Turner. - [#​12601](https://togithub.com/sphinx-doc/sphinx/issues/12601), [#​12625](https://togithub.com/sphinx-doc/sphinx/issues/12625): Support callable objects in :py:class:`~typing.Annotated` type metadata in the Python domain. Patch by Adam Turner. - [#​12601](https://togithub.com/sphinx-doc/sphinx/issues/12601), [#​12622](https://togithub.com/sphinx-doc/sphinx/issues/12622): Resolve :py:class:`~typing.Annotated` warnings with `sphinx.ext.autodoc`, especially when using :mod:`dataclasses` as type metadata. Patch by Adam Turner. - [#​12589](https://togithub.com/sphinx-doc/sphinx/issues/12589), [#​12626](https://togithub.com/sphinx-doc/sphinx/issues/12626): autosummary: Fix warnings with :rst:role:`!autolink`. Patch by Adam Turner. ### [`v7.4.6`](https://togithub.com/sphinx-doc/sphinx/blob/HEAD/CHANGES.rst#Release-746-released-Jul-18-2024) [Compare Source](https://togithub.com/sphinx-doc/sphinx/compare/v7.4.5...v7.4.6) \===================================== ## Bugs fixed - [#​12859](https://togithub.com/sphinx-doc/sphinx/issues/12859), [#​9743](https://togithub.com/sphinx-doc/sphinx/issues/9743), [#​12609](https://togithub.com/sphinx-doc/sphinx/issues/12609): autosummary: Do not add the package prefix when generating autosummary directives for modules within a package. Patch by Adam Turner. - [#​12613](https://togithub.com/sphinx-doc/sphinx/issues/12613): Reduce log severity for ambiguity detection during inventory loading. Patch by James Addison.
--- ### Configuration 📅 **Schedule**: Branch creation - At any time (no schedule defined), Automerge - At any time (no schedule defined). 🚦 **Automerge**: Disabled by config. Please merge this manually once you are satisfied. ♻ **Rebasing**: Whenever PR becomes conflicted, or you tick the rebase/retry checkbox. 🔕 **Ignore**: Close this PR and you won't be reminded about this update again. --- - [ ] If you want to rebase/retry this PR, check this box --- This PR has been generated by [Mend Renovate](https://www.mend.io/free-developer-tools/renovate/). View repository job log [here](https://developer.mend.io/github/secretflow/spu). Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com> --- docs/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/requirements.txt b/docs/requirements.txt index c1a970e9..b6df9201 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,6 +1,6 @@ myst-parser==3.0.1 rstcheck==6.2.4 -sphinx==7.4.5 +sphinx==7.4.7 nbsphinx==0.9.4 sphinx-autobuild==2024.4.16 sphinx-markdown-parser==0.2.4 From 54ec5f1b6f5f2125fa84a01ff247d8ecb3d64bc4 Mon Sep 17 00:00:00 2001 From: "renovate[bot]" <29139614+renovate[bot]@users.noreply.github.com> Date: Mon, 22 Jul 2024 10:41:11 +0800 Subject: [PATCH 21/27] chore(deps): update dependency path-filtering to v1.1.0 (#775) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit [![Mend Renovate](https://app.renovatebot.com/images/banner.svg)](https://renovatebot.com) This PR contains the following updates: | Package | Type | Update | Change | |---|---|---|---| | [path-filtering](https://circleci.com/developer/orbs/orb/circleci/path-filtering) | orb | minor | `1.0.0` -> `1.1.0` | --- ### Configuration 📅 **Schedule**: Branch creation - At any time (no schedule defined), Automerge - At any time (no schedule defined). 🚦 **Automerge**: Disabled by config. Please merge this manually once you are satisfied. ♻ **Rebasing**: Whenever PR becomes conflicted, or you tick the rebase/retry checkbox. 🔕 **Ignore**: Close this PR and you won't be reminded about this update again. --- - [ ] If you want to rebase/retry this PR, check this box --- This PR has been generated by [Mend Renovate](https://www.mend.io/free-developer-tools/renovate/). View repository job log [here](https://developer.mend.io/github/secretflow/spu). Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com> --- .circleci/config.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 41e5522e..c4617cd9 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -17,7 +17,7 @@ version: 2.1 setup: true orbs: - path-filtering: circleci/path-filtering@1.0.0 + path-filtering: circleci/path-filtering@1.1.0 continuation: circleci/continuation@1.0.0 parameters: From e5b366a08a61300ba2416b85e423410dae92720f Mon Sep 17 00:00:00 2001 From: Yancheng Zheng <103552181+anakinxc@users.noreply.github.com> Date: Mon, 29 Jul 2024 16:22:54 +0800 Subject: [PATCH 22/27] Repo sync (#795) --- CHANGELOG.md | 1 + bazel/repositories.bzl | 16 ++-- libspu/compiler/tests/interpret/and.mlir | 1 - .../tests/interpret/generate_mlir_tests.py | 8 +- libspu/compiler/tests/interpret/or.mlir | 1 - libspu/compiler/tests/interpret/xor.mlir | 1 - .../passes/optimizations/ops_negative.mlir | 4 +- libspu/compiler/utils/utils.h | 1 + libspu/dialect/pphlo/IR/ops.cc | 80 +------------------ libspu/dialect/pphlo/IR/print_parse.cc | 10 +++ .../transforms/inline_secret_control_flow.cc | 5 +- libspu/mpc/aby3/boolean.cc | 6 +- libspu/mpc/cheetah/ot/yacl/yacl_ote_adapter.h | 2 +- libspu/mpc/securenn/boolean.cc | 4 +- .../ttp_server/beaver_server_main.cc | 3 +- libspu/mpc/spdz2k/beaver/BUILD.bazel | 2 +- libspu/mpc/spdz2k/beaver/beaver_tinyot.cc | 2 +- libspu/mpc/spdz2k/beaver/beaver_tinyot.h | 2 +- libspu/mpc/spdz2k/ot/BUILD.bazel | 2 +- libspu/mpc/spdz2k/ot/kos_ote.h | 2 +- libspu/mpc/spdz2k/ot/tiny_ot.h | 2 +- 21 files changed, 46 insertions(+), 109 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4dc066f0..27502fd6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ ## staging > > please add your unreleased change here. + - [Feature] Add more send/recv actions profiling ## 20240716 diff --git a/bazel/repositories.bzl b/bazel/repositories.bzl index d97c8e68..1318e460 100644 --- a/bazel/repositories.bzl +++ b/bazel/repositories.bzl @@ -40,10 +40,10 @@ def _yacl(): http_archive, name = "yacl", urls = [ - "https://github.com/secretflow/yacl/archive/refs/tags/0.4.5b3.tar.gz", + "https://github.com/secretflow/yacl/archive/refs/tags/0.4.5b3_nightly_20240722.tar.gz", ], - strip_prefix = "yacl-0.4.5b3", - sha256 = "bd89d63312e5e83eff5e001e2cf2135baff321c4b72a309f7d00cc53ce02e1a1", + strip_prefix = "yacl-0.4.5b3_nightly_20240722", + sha256 = "ccca599e6ded6089c5afbb87c8f5e09383195af256caacd50089f0c7443e8604", ) def _libpsi(): @@ -51,10 +51,10 @@ def _libpsi(): http_archive, name = "psi", urls = [ - "https://github.com/secretflow/psi/archive/refs/tags/v0.4.0beta.tar.gz", + "https://github.com/secretflow/psi/archive/refs/tags/v0.4.1.dev240722.tar.gz", ], - strip_prefix = "psi-0.4.0beta", - sha256 = "c2fbf486a66eca9d3ec1725a81d93a7c6e80a9206ef1c9263a1608e0bef95e1a", + strip_prefix = "psi-0.4.1.dev240722", + sha256 = "878cd8af2c7b9850944a27adf91f21dd4937d09d38e8365baad3b5165db8b39a", ) def _rules_proto_grpc(): @@ -136,8 +136,8 @@ def _bazel_skylib(): ) def _com_github_openxla_xla(): - OPENXLA_COMMIT = "8533a6869ae02fb3b15a8a12739a982fc3c9f6e7" - OPENXLA_SHA256 = "d5b076825c992f59542f6b94e5480c7e7c6c627cd18c80ec60b6d5b295c160d4" + OPENXLA_COMMIT = "04f2bfe797408c9efe742b89e2e4db6cf526ebb7" + OPENXLA_SHA256 = "7e1d24737815be7607eed5f02fe7f81d97ffe358dfb7b4876f97bce8f48b3b3e" # We need openxla to handle xla/mhlo/stablehlo maybe( diff --git a/libspu/compiler/tests/interpret/and.mlir b/libspu/compiler/tests/interpret/and.mlir index 618d3222..fbcc2a2d 100644 --- a/libspu/compiler/tests/interpret/and.mlir +++ b/libspu/compiler/tests/interpret/and.mlir @@ -1,7 +1,6 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s -// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT diff --git a/libspu/compiler/tests/interpret/generate_mlir_tests.py b/libspu/compiler/tests/interpret/generate_mlir_tests.py index 7a271164..460f67f6 100755 --- a/libspu/compiler/tests/interpret/generate_mlir_tests.py +++ b/libspu/compiler/tests/interpret/generate_mlir_tests.py @@ -100,9 +100,11 @@ def TestCase(inputs, expected, checker='expect_eq', tol=None): f.write( "// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s\n" ) - f.write( - "// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s\n" - ) + # FIXME: these tests are not stable for cheetah now + if test not in ["xor", "or", "and"]: + f.write( + "// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s\n" + ) # Some test values in max and min are not supported by protocol 5. if test not in ["max", "min"]: f.write( diff --git a/libspu/compiler/tests/interpret/or.mlir b/libspu/compiler/tests/interpret/or.mlir index 1eef975f..acd9c28f 100644 --- a/libspu/compiler/tests/interpret/or.mlir +++ b/libspu/compiler/tests/interpret/or.mlir @@ -1,7 +1,6 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s -// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT diff --git a/libspu/compiler/tests/interpret/xor.mlir b/libspu/compiler/tests/interpret/xor.mlir index a9ee8348..f2b8fba9 100644 --- a/libspu/compiler/tests/interpret/xor.mlir +++ b/libspu/compiler/tests/interpret/xor.mlir @@ -1,7 +1,6 @@ // RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s -// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s // RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s // AUTO GENERATED, DO NOT EDIT diff --git a/libspu/compiler/tests/passes/optimizations/ops_negative.mlir b/libspu/compiler/tests/passes/optimizations/ops_negative.mlir index 926a5de5..e8644b96 100644 --- a/libspu/compiler/tests/passes/optimizations/ops_negative.mlir +++ b/libspu/compiler/tests/passes/optimizations/ops_negative.mlir @@ -24,7 +24,7 @@ func.func @main() -> tensor { func.func @main() -> tensor { %0 = pphlo.constant dense<[0.000000e+00, -3.40282347E+38]> : tensor<2xf32> - // expected-error @+1 {{op broadcast_dimensions contains invalid value -6 for result with rank 1}} + // expected-error @+1 {{broadcast_dimensions contains invalid value -6 for result with rank 1}} %1 = pphlo.broadcast %0, dims = [-6] : (tensor<2xf32>) -> tensor<2xf32> %2 = pphlo.constant dense<5> : tensor pphlo.return %2 : tensor @@ -33,7 +33,7 @@ func.func @main() -> tensor { // ----- func.func @main() -> tensor { - // expected-error @+1 {{op iota dimension cannot go beyond the output rank or be negative}} + // expected-error @+1 {{iota dimension cannot go beyond the output rank}} %0 = pphlo.iota dim = 1000 : tensor<1xi32> %1 = pphlo.constant dense<5> : tensor pphlo.return %1 : tensor diff --git a/libspu/compiler/utils/utils.h b/libspu/compiler/utils/utils.h index faeff8d9..3b06397a 100644 --- a/libspu/compiler/utils/utils.h +++ b/libspu/compiler/utils/utils.h @@ -14,6 +14,7 @@ #pragma once +#include "llvm/ADT/Twine.h" #include "mlir/Support/LogicalResult.h" namespace mlir::spu { diff --git a/libspu/dialect/pphlo/IR/ops.cc b/libspu/dialect/pphlo/IR/ops.cc index 564d44e1..4d199e93 100644 --- a/libspu/dialect/pphlo/IR/ops.cc +++ b/libspu/dialect/pphlo/IR/ops.cc @@ -18,21 +18,12 @@ #include "llvm/Support/FormatVariadic.h" #include "mlir/IR/Builders.h" #include "mlir/IR/TypeUtilities.h" +#include "stablehlo/dialect/TypeInference.h" #include "libspu/dialect/pphlo/IR/ops.h.inc" namespace mlir::spu::pphlo { -namespace { - -// Checks if the vector `nums` has duplicates. -bool hasDuplicates(const ArrayRef nums) { - llvm::SmallDenseSet set(nums.begin(), nums.end()); - return set.size() != nums.size(); -} - -} // namespace - template static LogicalResult Verify(T /*op*/) { return success(); @@ -386,75 +377,12 @@ LogicalResult ConcatenateOp::verify() { } LogicalResult BroadcastOp::verify() { - auto operandType = mlir::dyn_cast(getOperand().getType()); - - auto operandRank = operandType.getRank(); - - if (getBroadcastDimensions().empty()) { - if (operandRank == 0) { - return success(); - } - return emitOpError( - llvm::formatv("broadcast_dimensions is absent, but required because " - "operand has non-zero rank ({0})", - operandRank)); - } - - auto dimensionsSize = getBroadcastDimensions().size(); - if (static_cast(dimensionsSize) != operandRank) { - return emitOpError(llvm::formatv( - "broadcast_dimensions size ({0}) does not match operand rank ({1})", - dimensionsSize, operandRank)); - } - - auto dimensions = getBroadcastDimensions(); - if (hasDuplicates(dimensions)) { - return emitOpError("broadcast_dimensions should not have duplicates"); - } - - auto resultType = mlir::dyn_cast(getResult().getType()); - auto resultRank = resultType.getRank(); - - for (size_t i = 0; i != dimensionsSize; ++i) { - auto dimIndex = dimensions[i]; - if ((dimIndex >= resultRank) || (dimIndex < 0)) { - return emitOpError( - llvm::formatv("broadcast_dimensions contains invalid value {0} for " - "result with rank {1}", - dimIndex, resultRank)); - } - - if (!operandType.isDynamicDim(i)) { - auto dimSize = operandType.getDimSize(i); - auto resultDimSize = resultType.getDimSize(dimIndex); - if (dimSize != 1 && dimSize != resultDimSize) { - return emitOpError( - llvm::formatv("size of operand dimension {0} ({1}) is not equal to " - "1 or size of result dimension {2} ({3})", - i, dimSize, dimIndex, resultDimSize)); - } - } - } - - return success(); + return hlo::verifyBroadcastInDimOp(getLoc(), getOperand(), + getBroadcastDimensions(), getResult()); } LogicalResult IotaOp::verify() { - auto shape = mlir::dyn_cast(getType()); - if (!shape.hasRank()) { - return success(); - } - - if (shape.getRank() == 0) { - return emitOpError() << "does not support scalars."; - } - - auto iotaDimension = static_cast(this->getIotaDimension()); - if (iotaDimension >= shape.getRank() || iotaDimension < 0) { - return emitOpError() - << "iota dimension cannot go beyond the output rank or be negative."; - } - return success(); + return hlo::verifyIotaOp(getLoc(), getIotaDimension(), getResult()); } LogicalResult SliceOp::verify() { diff --git a/libspu/dialect/pphlo/IR/print_parse.cc b/libspu/dialect/pphlo/IR/print_parse.cc index 8716e011..6f693e30 100644 --- a/libspu/dialect/pphlo/IR/print_parse.cc +++ b/libspu/dialect/pphlo/IR/print_parse.cc @@ -105,9 +105,19 @@ ParseResult ConstantOp::parse(OpAsmParser& parser, OperationState& result) { if (parser.parseRParen()) { return failure(); } + // Parse optional properties + if (succeeded(parser.parseOptionalLess()) && + (failed(parser.parseAttribute(result.propertiesAttr)) || + failed(parser.parseGreater()))) { + return failure(); + } + + // Parse optional attributes if (parser.parseOptionalAttrDict(result.attributes)) { return failure(); } + + // Parse type signature if (parser.parseColon() || parser.parseLParen() || parser.parseRParen() || parser.parseArrow()) { return failure(); diff --git a/libspu/dialect/pphlo/transforms/inline_secret_control_flow.cc b/libspu/dialect/pphlo/transforms/inline_secret_control_flow.cc index 981a8c41..9f801875 100644 --- a/libspu/dialect/pphlo/transforms/inline_secret_control_flow.cc +++ b/libspu/dialect/pphlo/transforms/inline_secret_control_flow.cc @@ -36,9 +36,8 @@ class CaseConverter : public OpRewritePattern { if (target_type.getNumElements() == in_type.getNumElements()) { return rewriter.create(loc, broadcasted_mask_type, in); } else { - return rewriter.create( - loc, broadcasted_mask_type, in, - llvm::SmallVector(target_type.getRank(), 0)); + return rewriter.create(loc, broadcasted_mask_type, in, + llvm::SmallVector{0}); } } diff --git a/libspu/mpc/aby3/boolean.cc b/libspu/mpc/aby3/boolean.cc index 85a44542..96033cc5 100644 --- a/libspu/mpc/aby3/boolean.cc +++ b/libspu/mpc/aby3/boolean.cc @@ -348,9 +348,9 @@ NdArrayRef LShiftB::proc(KernelEvalContext* ctx, const NdArrayRef& in, // TODO: the hal dtype should tell us about the max number of possible bits. const auto field = ctx->getState()->getDefaultField(); - const size_t out_nbits = - std::min(in_ty->nbits() + *std::max_element(bits.begin(), bits.end()), - SizeOf(field) * 8); + const size_t out_nbits = std::min( + in_ty->nbits() + *std::max_element(bits.begin(), bits.end()), + SizeOf(field) * 8); const PtType out_btype = calcBShareBacktype(out_nbits); bool is_splat = bits.size() == 1; diff --git a/libspu/mpc/cheetah/ot/yacl/yacl_ote_adapter.h b/libspu/mpc/cheetah/ot/yacl/yacl_ote_adapter.h index 294e095f..bd37640b 100644 --- a/libspu/mpc/cheetah/ot/yacl/yacl_ote_adapter.h +++ b/libspu/mpc/cheetah/ot/yacl/yacl_ote_adapter.h @@ -19,8 +19,8 @@ #include "yacl/kernel/algorithms/base_ot.h" #include "yacl/kernel/algorithms/ferret_ote.h" #include "yacl/kernel/algorithms/iknp_ote.h" -#include "yacl/kernel/algorithms/ot_store.h" #include "yacl/kernel/algorithms/softspoken_ote.h" +#include "yacl/kernel/type/ot_store.h" #include "libspu/core/prelude.h" #include "libspu/mpc/cheetah/ot/ot_util.h" diff --git a/libspu/mpc/securenn/boolean.cc b/libspu/mpc/securenn/boolean.cc index a985ac0b..9deccec8 100644 --- a/libspu/mpc/securenn/boolean.cc +++ b/libspu/mpc/securenn/boolean.cc @@ -261,8 +261,8 @@ NdArrayRef LShiftB::proc(KernelEvalContext* ctx, const NdArrayRef& in, int64_t out_nbits = in.eltype().as()->nbits() + *std::max_element(shift.begin(), shift.end()); - out_nbits = - std::clamp(out_nbits, 0L, static_cast(SizeOf(field) * 8)); + out_nbits = std::clamp(out_nbits, 0L, + static_cast(SizeOf(field) * 8)); return makeBShare(ring_lshift(in, shift), field, out_nbits); } diff --git a/libspu/mpc/semi2k/beaver/beaver_impl/ttp_server/beaver_server_main.cc b/libspu/mpc/semi2k/beaver/beaver_impl/ttp_server/beaver_server_main.cc index 1680325e..0a701cf3 100644 --- a/libspu/mpc/semi2k/beaver/beaver_impl/ttp_server/beaver_server_main.cc +++ b/libspu/mpc/semi2k/beaver/beaver_impl/ttp_server/beaver_server_main.cc @@ -91,8 +91,7 @@ int main(int argc, char* argv[]) { std::string key; SPU_ENFORCE( butil::Base64Decode(ttp_server_config::FLAGS_server_private_key, &key)); - decode_private_key = - yacl::Buffer(decode_private_key.data(), decode_private_key.size()); + decode_private_key = yacl::Buffer(key.data(), key.size()); } spu::mpc::semi2k::beaver::ttp_server::ServerOptions ops{ diff --git a/libspu/mpc/spdz2k/beaver/BUILD.bazel b/libspu/mpc/spdz2k/beaver/BUILD.bazel index 5ea2a3f6..b7f4d80e 100644 --- a/libspu/mpc/spdz2k/beaver/BUILD.bazel +++ b/libspu/mpc/spdz2k/beaver/BUILD.bazel @@ -80,7 +80,7 @@ spu_cc_library( "//libspu/mpc/spdz2k/ot:tiny_ot", "//libspu/mpc/utils:ring_ops", "@yacl//yacl/crypto/tools:prg", - "@yacl//yacl/kernel/algorithms:ot_store", + "@yacl//yacl/kernel/type:ot_store", "@yacl//yacl/link", "@yacl//yacl/utils:matrix_utils", "@yacl//yacl/utils:serialize", diff --git a/libspu/mpc/spdz2k/beaver/beaver_tinyot.cc b/libspu/mpc/spdz2k/beaver/beaver_tinyot.cc index ed7f2ae8..f8d090fc 100644 --- a/libspu/mpc/spdz2k/beaver/beaver_tinyot.cc +++ b/libspu/mpc/spdz2k/beaver/beaver_tinyot.cc @@ -23,7 +23,7 @@ #include "yacl/crypto/rand/rand.h" #include "yacl/crypto/tools/prg.h" #include "yacl/kernel/algorithms/base_ot.h" -#include "yacl/kernel/algorithms/ot_store.h" +#include "yacl/kernel/type/ot_store.h" #include "yacl/utils/serialize.h" #include "libspu/mpc/common/prg_tensor.h" diff --git a/libspu/mpc/spdz2k/beaver/beaver_tinyot.h b/libspu/mpc/spdz2k/beaver/beaver_tinyot.h index b0d35372..6f8ec9ad 100644 --- a/libspu/mpc/spdz2k/beaver/beaver_tinyot.h +++ b/libspu/mpc/spdz2k/beaver/beaver_tinyot.h @@ -14,7 +14,7 @@ #pragma once -#include "yacl/kernel/algorithms/ot_store.h" +#include "yacl/kernel/type/ot_store.h" #include "yacl/link/context.h" #include "libspu/mpc/common/prg_state.h" diff --git a/libspu/mpc/spdz2k/ot/BUILD.bazel b/libspu/mpc/spdz2k/ot/BUILD.bazel index 45143795..124adaf1 100644 --- a/libspu/mpc/spdz2k/ot/BUILD.bazel +++ b/libspu/mpc/spdz2k/ot/BUILD.bazel @@ -68,7 +68,7 @@ spu_cc_library( "//libspu/mpc/utils:ring_ops", "@com_github_emptoolkit_emp_tool//:emp-tool", "@yacl//yacl/crypto/tools:prg", - "@yacl//yacl/kernel/algorithms:ot_store", + "@yacl//yacl/kernel/type:ot_store", "@yacl//yacl/link", ], ) diff --git a/libspu/mpc/spdz2k/ot/kos_ote.h b/libspu/mpc/spdz2k/ot/kos_ote.h index 58b9fc07..4ba1f1b9 100644 --- a/libspu/mpc/spdz2k/ot/kos_ote.h +++ b/libspu/mpc/spdz2k/ot/kos_ote.h @@ -15,7 +15,7 @@ #pragma once #include "absl/types/span.h" #include "yacl/base/dynamic_bitset.h" -#include "yacl/kernel/algorithms/ot_store.h" +#include "yacl/kernel/type/ot_store.h" #include "yacl/link/link.h" namespace spu::mpc::spdz2k { diff --git a/libspu/mpc/spdz2k/ot/tiny_ot.h b/libspu/mpc/spdz2k/ot/tiny_ot.h index cb099095..ebe38563 100644 --- a/libspu/mpc/spdz2k/ot/tiny_ot.h +++ b/libspu/mpc/spdz2k/ot/tiny_ot.h @@ -13,7 +13,7 @@ // limitations under the License. #include -#include "yacl/kernel/algorithms/ot_store.h" +#include "yacl/kernel/type/ot_store.h" #include "libspu/mpc/common/communicator.h" From 5c922243c1ed42d959148a3869b17a37ff55346c Mon Sep 17 00:00:00 2001 From: lwj Date: Wed, 31 Jul 2024 10:45:21 +0800 Subject: [PATCH 23/27] =?UTF-8?q?[squirrel]=20Remove=20the=20assertation?= =?UTF-8?q?=20on=20empty=20indicator,=20only=20reporting=20=E2=80=A6=20(#7?= =?UTF-8?q?87)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …warnning. # Pull Request ## What problem does this PR solve? Issue Number: Fixed #785 ## Possible side effects? - Performance: - Backward compatibility: --- experimental/squirrel/README.md | 17 +++- experimental/squirrel/bin_matvec_prot.cc | 9 +- experimental/squirrel/bin_matvec_prot.h | 5 +- experimental/squirrel/bin_matvec_prot_test.cc | 96 +++++++++++++++++++ 4 files changed, 121 insertions(+), 6 deletions(-) diff --git a/experimental/squirrel/README.md b/experimental/squirrel/README.md index 5c33ff83..affce442 100644 --- a/experimental/squirrel/README.md +++ b/experimental/squirrel/README.md @@ -29,11 +29,24 @@ Code under this folder is purely for research demonstration and it's **NOT desig * On one terminal ```sh - bazel-bin/experimental/squirrel/squirrel_demo_main --rank0_nfeatures=85 --rank1_nfeatures=85 --standalone --train=BinaryClassification_Aps_Test_60000_171.csv --test=BinaryClassification_Aps_Test_16000_171.csv --standalone=true --rank=0 --has_label=0 --lr=1.0 --subsample=0.8 + bazel-bin/experimental/squirrel/squirrel_demo_main --rank0_nfeatures=85 --rank1_nfeatures=85 --standalone=true --train=BinaryClassification_Aps_Test_60000_171.csv --test=BinaryClassification_Aps_Test_16000_171.csv --rank=0 --has_label=0 --lr=1.0 --subsample=0.8 ``` * On another terminal ```sh - bazel-bin/experimental/squirrel/squirrel_demo_main --rank0_nfeatures=85 --rank1_nfeatures=85 --standalone --train=BinaryClassification_Aps_Test_60000_171.csv --test=BinaryClassification_Aps_Test_16000_171.csv --standalone=true --rank=1 --has_label=1 --lr=1.0 --subsample=0.8 + bazel-bin/experimental/squirrel/squirrel_demo_main --rank0_nfeatures=85 --rank1_nfeatures=85 --standalone=true --train=BinaryClassification_Aps_Test_60000_171.csv --test=BinaryClassification_Aps_Test_16000_171.csv --rank=1 --has_label=1 --lr=1.0 --subsample=0.8 ``` +* Run on distributed dataset, e.g., using the `breast_cancer` dataset from the SPU repo. + * On one terminal + + ```sh + bazel-bin/experimental/squirrel/squirrel_demo_main --rank0_nfeatures=15 --rank1_nfeatures=15 --standalone=false --train=examples/data/breast_cancer_a.csv --rank=0 --has_label=0 --lr=1.0 --subsample=0.8 + ``` + + * On another terminal + + ```sh + bazel-bin/experimental/squirrel/squirrel_demo_main --rank0_nfeatures=15 --rank1_nfeatures=15 --standalone=false --train=examples/data/breast_cancer_b.csv --rank=1 --has_label=1 --lr=1.0 --subsample=0.8 + ``` + diff --git a/experimental/squirrel/bin_matvec_prot.cc b/experimental/squirrel/bin_matvec_prot.cc index 6946fa4a..35e4ee3a 100644 --- a/experimental/squirrel/bin_matvec_prot.cc +++ b/experimental/squirrel/bin_matvec_prot.cc @@ -766,9 +766,12 @@ spu::NdArrayRef BinMatVecProtocol::Recv(const spu::NdArrayRef &vec_in, SPU_ENFORCE_EQ(vec_in.numel(), dim_in); if (not indicator.empty()) { // mat * diag(indicator) should not be all zeros. - SPU_ENFORCE(std::any_of(indicator.begin(), indicator.end(), - [](uint8_t x) { return x > 0; }), - "empty matrix is not allowed"); + if (std::all_of(indicator.begin(), indicator.end(), + [](uint8_t x) { return x == 0; })) { + SPDLOG_WARN( + "Empty matrix! Make sure the 1-bit error will not ruin your " + "computation."); + } } auto eltype = vec_in.eltype(); diff --git a/experimental/squirrel/bin_matvec_prot.h b/experimental/squirrel/bin_matvec_prot.h index 2ac8ea6c..ba34499f 100644 --- a/experimental/squirrel/bin_matvec_prot.h +++ b/experimental/squirrel/bin_matvec_prot.h @@ -70,7 +70,10 @@ struct StlSparseMatrix { // Outputs: // Sender: z0 \in Zk^{m} // Recv: z1 \in Zk^{m} -// such that z0 + z1 = M * (v0 + v1) mod Zk +// such that z0 + z1 = M * (v0 + v1) + e mod Zk +// +// Note that we might introduce 1-bit error `e` due to the coefficient-based +// resharing HE ciphertexts to additive shares. class BinMatVecProtocol { public: BinMatVecProtocol(size_t ring_bitwidth, diff --git a/experimental/squirrel/bin_matvec_prot_test.cc b/experimental/squirrel/bin_matvec_prot_test.cc index 452e127d..9ef145ad 100644 --- a/experimental/squirrel/bin_matvec_prot_test.cc +++ b/experimental/squirrel/bin_matvec_prot_test.cc @@ -172,4 +172,100 @@ TEST_P(BinMatVecProtTest, WithIndicator) { } }); } + +TEST_P(BinMatVecProtTest, EmptyMat) { + using namespace spu; + using namespace spu::mpc; + constexpr size_t kWorldSize = 2; + + FieldType field = std::get<0>(GetParam()); + int64_t dim_in = std::get<0>(std::get<1>(GetParam())); + int64_t dim_out = std::get<1>(std::get<1>(GetParam())); + + StlSparseMatrix mat; + mat.rows_data_.resize(dim_out); + mat.cols_ = dim_in; + + NdArrayRef vec_shr[2]; + vec_shr[0] = ring_rand(field, {dim_in}) + .as(spu::makeType(field)); + vec_shr[1] = ring_rand(field, {dim_in}) + .as(spu::makeType(field)); + + NdArrayRef vec = ring_add(vec_shr[0], vec_shr[1]); + + NdArrayRef out_shr[2]; + utils::simulate(kWorldSize, [&](std::shared_ptr lctx) { + BinMatVecProtocol binmat_prot(SizeOf(field) * 8, lctx); + if (0 == lctx->Rank()) { + out_shr[0] = binmat_prot.Send(vec_shr[0], dim_out, dim_in); + } else { + out_shr[1] = binmat_prot.Recv(vec_shr[1], dim_out, dim_in, mat); + } + }); + NdArrayRef reveal = ring_add(out_shr[0], out_shr[1]); + + DISPATCH_ALL_FIELDS(field, "", [&]() { + using sT = std::make_signed::type; + NdArrayView _vec(vec); + auto expected = BinAccumuate(_vec, mat); + NdArrayView got(reveal); + + EXPECT_EQ(expected.size(), (size_t)got.numel()); + for (int64_t i = 0; i < dim_out; ++i) { + EXPECT_NEAR(expected[i], got[i], 1); + } + }); +} + +TEST_P(BinMatVecProtTest, WithEmptyIndicator) { + using namespace spu; + using namespace spu::mpc; + constexpr size_t kWorldSize = 2; + + FieldType field = std::get<0>(GetParam()); + int64_t dim_in = std::get<0>(std::get<1>(GetParam())); + int64_t dim_out = std::get<1>(std::get<1>(GetParam())); + + StlSparseMatrix mat; + PrepareBinaryMat(mat, dim_out, dim_in, 0); + + NdArrayRef vec_shr[2]; + vec_shr[0] = ring_rand(field, {dim_in}) + .as(spu::makeType(field)); + vec_shr[1] = ring_rand(field, {dim_in}) + .as(spu::makeType(field)); + std::vector indicator(dim_in); + std::default_random_engine rdv; + std::uniform_int_distribution dist(0, 10); + // empty indicator + std::fill_n(indicator.data(), indicator.size(), 0); + + NdArrayRef vec = ring_add(vec_shr[0], vec_shr[1]); + + NdArrayRef out_shr[2]; + utils::simulate(kWorldSize, [&](std::shared_ptr lctx) { + BinMatVecProtocol binmat_prot(SizeOf(field) * 8, lctx); + if (0 == lctx->Rank()) { + out_shr[0] = binmat_prot.Send(vec_shr[0], dim_out, dim_in); + } else { + out_shr[1] = binmat_prot.Recv(vec_shr[1], dim_out, dim_in, mat, + absl::MakeConstSpan(indicator)); + } + }); + NdArrayRef reveal = ring_add(out_shr[0], out_shr[1]); + + DISPATCH_ALL_FIELDS(field, "", [&]() { + using sT = std::make_signed::type; + NdArrayView _vec(vec); + NdArrayView got(reveal); + auto expected = BinAccumuate(_vec, mat, absl::MakeConstSpan(indicator)); + + EXPECT_EQ(expected.size(), (size_t)got.numel()); + for (int64_t i = 0; i < dim_out; ++i) { + EXPECT_NEAR(expected[i], got[i], 1); + } + }); +} + } // namespace squirrel::test From b47f68654f853f46a120beaf322b1760d38f4a2a Mon Sep 17 00:00:00 2001 From: "renovate[bot]" <29139614+renovate[bot]@users.noreply.github.com> Date: Wed, 31 Jul 2024 10:51:06 +0800 Subject: [PATCH 24/27] chore(deps): update dependency sphinx to v8 (#796) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit [![Mend Renovate](https://app.renovatebot.com/images/banner.svg)](https://renovatebot.com) This PR contains the following updates: | Package | Change | Age | Adoption | Passing | Confidence | |---|---|---|---|---|---| | [sphinx](https://togithub.com/sphinx-doc/sphinx) ([changelog](https://www.sphinx-doc.org/en/master/changes.html)) | `==7.4.7` -> `==8.0.2` | [![age](https://developer.mend.io/api/mc/badges/age/pypi/sphinx/8.0.2?slim=true)](https://docs.renovatebot.com/merge-confidence/) | [![adoption](https://developer.mend.io/api/mc/badges/adoption/pypi/sphinx/8.0.2?slim=true)](https://docs.renovatebot.com/merge-confidence/) | [![passing](https://developer.mend.io/api/mc/badges/compatibility/pypi/sphinx/7.4.7/8.0.2?slim=true)](https://docs.renovatebot.com/merge-confidence/) | [![confidence](https://developer.mend.io/api/mc/badges/confidence/pypi/sphinx/7.4.7/8.0.2?slim=true)](https://docs.renovatebot.com/merge-confidence/) | --- ### Release Notes
sphinx-doc/sphinx (sphinx) ### [`v8.0.2`](https://togithub.com/sphinx-doc/sphinx/releases/tag/v8.0.2): Sphinx 8.0.2 [Compare Source](https://togithub.com/sphinx-doc/sphinx/compare/v8.0.1...v8.0.2) Changelog: https://www.sphinx-doc.org/en/master/changes.html ### [`v8.0.1`](https://togithub.com/sphinx-doc/sphinx/releases/tag/v8.0.1): Sphinx 8.0.1 [Compare Source](https://togithub.com/sphinx-doc/sphinx/compare/v8.0.0...v8.0.1) Changelog: https://www.sphinx-doc.org/en/master/changes.html ### [`v8.0.0`](https://togithub.com/sphinx-doc/sphinx/releases/tag/v8.0.0): Sphinx 8.0.0 [Compare Source](https://togithub.com/sphinx-doc/sphinx/compare/v7.4.7...v8.0.0) Changelog: https://www.sphinx-doc.org/en/master/changes.html ## Dependencies - [#​12633](https://togithub.com/sphinx-doc/sphinx/issues/12633): Drop Python 3.9 support. ## Incompatible changes - Remove deprecated functions from `sphinx.util`: - Removed `sphinx.util.path_stabilize` (use `sphinx.util.osutil.path_stabilize`). - Removed `sphinx.util.display_chunk` (use `sphinx.util.display.display_chunk`). - Removed `sphinx.util.status_iterator` (use `sphinx.util.display.status_iterator`). - Removed `sphinx.util.SkipProgressMessage` (use `sphinx.util.display.SkipProgressMessage`). - Removed `sphinx.util.progress_message` (use `sphinx.util.display.progress_message`). - Removed `sphinx.util.epoch_to_rfc1123` (use `sphinx.http_date.epoch_to_rfc1123`). - Removed `sphinx.util.rfc1123_to_epoch` (use `sphinx.http_date.rfc1123_to_epoch`). - Removed `sphinx.util.save_traceback` (use `sphinx.exceptions.save_traceback`). - Removed `sphinx.util.format_exception_cut_frames` (use `sphinx.exceptions.format_exception_cut_frames`). - Removed `sphinx.util.xmlname_checker` (use `sphinx.builders.epub3._XML_NAME_PATTERN`). Patch by Adam Turner. - Removed `sphinx.util.osutil.cd` (use `contextlib.chdir`). Patch by Adam Turner. - Removed `sphinx.util.typing.stringify` (use `sphinx.util.typing.stringify_annotation`). Patch by Adam Turner. - [#​12593](https://togithub.com/sphinx-doc/sphinx/issues/12593): Raise an error for invalid `html_sidebars` values. Patch by Adam Turner. - [#​12593](https://togithub.com/sphinx-doc/sphinx/issues/12593): Raise an error in `Theme.get_config` for invalid sections. Patch by Adam Turner. - [#​11693](https://togithub.com/sphinx-doc/sphinx/issues/11693): Remove support for old-style `Makefile` and `make.bat` output in `sphinx-quickstart`. - [#​11693](https://togithub.com/sphinx-doc/sphinx/issues/11693): Remove the `--no-use-make-mode`, `-M`, `--use-make-mode`, and `-m` options from `sphinx-quickstart`. Patch by Adam Turner. - Removed the tuple interface to `sphinx.ext.autodoc.ObjectMember`. Patch by Adam Turner. - [#​12630](https://togithub.com/sphinx-doc/sphinx/issues/12630): Sphinx 8 makes two changes to the `linkcheck` configuration defaults: - `linkcheck_allow_unauthorized` is now `False` by default. - `linkcheck_report_timeouts_as_broken` is now `False` by default. Patch by James Addison. - [#​12597](https://togithub.com/sphinx-doc/sphinx/issues/12597): Change the default of `show_warning_types` from `False` to `True`. Patch by Chris Sewell. - [#​12083](https://togithub.com/sphinx-doc/sphinx/issues/12083): Remove support for the old (2008--2010) Sphinx 0.5 and Sphinx 0.6 `intersphinx_mapping` format. Patch by Bénédikt Tran and Adam Turner. - [#​12096](https://togithub.com/sphinx-doc/sphinx/issues/12096): Do not overwrite user-supplied files when copying assets unless forced with `force=True`. Patch by Adam Turner. - [#​12646](https://togithub.com/sphinx-doc/sphinx/issues/12646): Remove `sphinx.util.inspect.isNewType`. Use `isinstance(obj, typing.NewType)` instead on Python 3.10 and newer. Patch by Adam Turner. - Remove the long-deprecated (since Sphinx 2) alias to `VersionChange` in `sphinx.directives.other` (Deprecated since Sphinx 2). Use `sphinx.domains.changeset.VersionChange` directly. Patch by Adam Turner. ## Deprecated - [#​12643](https://togithub.com/sphinx-doc/sphinx/issues/12643): Renamed `sphinx.ext.intersphinx.normalize_intersphinx_mapping` to `sphinx.ext.intersphinx.validate_intersphinx_mapping`. The old name will be removed in Sphinx 10. Patch by Adam Turner. - [#​12650](https://togithub.com/sphinx-doc/sphinx/issues/12650), [#​12686](https://togithub.com/sphinx-doc/sphinx/issues/12686), [#​12690](https://togithub.com/sphinx-doc/sphinx/issues/12690): Extend the deprecation for string methods on `pathlib.Path` objects to Sphinx 9. Use `os.fspath` to convert :py:class:`~pathlib.Path` objects to strings, or `pathlib.Path`'s methods to work with path objects. Patch by Adam Turner.
--- ### Configuration 📅 **Schedule**: Branch creation - At any time (no schedule defined), Automerge - At any time (no schedule defined). 🚦 **Automerge**: Disabled by config. Please merge this manually once you are satisfied. ♻ **Rebasing**: Whenever PR becomes conflicted, or you tick the rebase/retry checkbox. 🔕 **Ignore**: Close this PR and you won't be reminded about this update again. --- - [ ] If you want to rebase/retry this PR, check this box --- This PR was generated by [Mend Renovate](https://www.mend.io/free-developer-tools/renovate/). View the [repository job log](https://developer.mend.io/github/secretflow/spu). Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com> --- docs/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/requirements.txt b/docs/requirements.txt index b6df9201..937b1bf0 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,6 +1,6 @@ myst-parser==3.0.1 rstcheck==6.2.4 -sphinx==7.4.7 +sphinx==8.0.2 nbsphinx==0.9.4 sphinx-autobuild==2024.4.16 sphinx-markdown-parser==0.2.4 From cc97761e32a20496f4de79af102d9c195eec5142 Mon Sep 17 00:00:00 2001 From: "renovate[bot]" <29139614+renovate[bot]@users.noreply.github.com> Date: Wed, 31 Jul 2024 10:53:00 +0800 Subject: [PATCH 25/27] chore(deps): update ossf/scorecard-action action to v2.4.0 (#786) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit [![Mend Renovate](https://app.renovatebot.com/images/banner.svg)](https://renovatebot.com) This PR contains the following updates: | Package | Type | Update | Change | |---|---|---|---| | [ossf/scorecard-action](https://togithub.com/ossf/scorecard-action) | action | minor | `v2.3.3` -> `v2.4.0` | --- ### Release Notes
ossf/scorecard-action (ossf/scorecard-action) ### [`v2.4.0`](https://togithub.com/ossf/scorecard-action/compare/v2.3.3...v2.4.0) [Compare Source](https://togithub.com/ossf/scorecard-action/compare/v2.3.3...v2.4.0)
--- ### Configuration 📅 **Schedule**: Branch creation - At any time (no schedule defined), Automerge - At any time (no schedule defined). 🚦 **Automerge**: Disabled by config. Please merge this manually once you are satisfied. ♻ **Rebasing**: Whenever PR becomes conflicted, or you tick the rebase/retry checkbox. 🔕 **Ignore**: Close this PR and you won't be reminded about this update again. --- - [ ] If you want to rebase/retry this PR, check this box --- This PR was generated by [Mend Renovate](https://www.mend.io/free-developer-tools/renovate/). View the [repository job log](https://developer.mend.io/github/secretflow/spu). Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com> --- .github/workflows/scorecard.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/scorecard.yml b/.github/workflows/scorecard.yml index c327e62a..b83903c7 100644 --- a/.github/workflows/scorecard.yml +++ b/.github/workflows/scorecard.yml @@ -37,7 +37,7 @@ jobs: persist-credentials: false - name: "Run analysis" - uses: ossf/scorecard-action@dc50aa9510b46c811795eb24b2f1ba02a914e534 # v2.3.3 + uses: ossf/scorecard-action@62b2cac7ed8198b15735ed49ab1e5cf35480ba46 # v2.4.0 with: results_file: results.sarif results_format: sarif From 3601e8556a7718b74c3462ce071feb72bfdff0a8 Mon Sep 17 00:00:00 2001 From: "renovate[bot]" <29139614+renovate[bot]@users.noreply.github.com> Date: Wed, 31 Jul 2024 10:55:02 +0800 Subject: [PATCH 26/27] chore(deps): update github/codeql-action action to v3.25.15 (#783) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit [![Mend Renovate](https://app.renovatebot.com/images/banner.svg)](https://renovatebot.com) This PR contains the following updates: | Package | Type | Update | Change | |---|---|---|---| | [github/codeql-action](https://togithub.com/github/codeql-action) | action | patch | `v3.25.13` -> `v3.25.15` | --- ### Release Notes
github/codeql-action (github/codeql-action) ### [`v3.25.15`](https://togithub.com/github/codeql-action/compare/v3.25.14...v3.25.15) [Compare Source](https://togithub.com/github/codeql-action/compare/v3.25.14...v3.25.15) ### [`v3.25.14`](https://togithub.com/github/codeql-action/compare/v3.25.13...v3.25.14) [Compare Source](https://togithub.com/github/codeql-action/compare/v3.25.13...v3.25.14)
--- ### Configuration 📅 **Schedule**: Branch creation - At any time (no schedule defined), Automerge - At any time (no schedule defined). 🚦 **Automerge**: Disabled by config. Please merge this manually once you are satisfied. ♻ **Rebasing**: Whenever PR becomes conflicted, or you tick the rebase/retry checkbox. 🔕 **Ignore**: Close this PR and you won't be reminded about this update again. --- - [ ] If you want to rebase/retry this PR, check this box --- This PR was generated by [Mend Renovate](https://www.mend.io/free-developer-tools/renovate/). View the [repository job log](https://developer.mend.io/github/secretflow/spu). Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com> --- .github/workflows/scorecard.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/scorecard.yml b/.github/workflows/scorecard.yml index b83903c7..c446b981 100644 --- a/.github/workflows/scorecard.yml +++ b/.github/workflows/scorecard.yml @@ -67,6 +67,6 @@ jobs: # Upload the results to GitHub's code scanning dashboard. - name: "Upload to code-scanning" - uses: github/codeql-action/upload-sarif@2d790406f505036ef40ecba973cc774a50395aac # v3.25.13 + uses: github/codeql-action/upload-sarif@afb54ba388a7dca6ecae48f608c4ff05ff4cc77a # v3.25.15 with: sarif_file: results.sarif From 854f3ef0925bc419ae804ae77a4a9843ec15db7f Mon Sep 17 00:00:00 2001 From: Li Zhihang <97899797+xbw886@users.noreply.github.com> Date: Thu, 1 Aug 2024 14:43:29 +0800 Subject: [PATCH 27/27] =?UTF-8?q?[OSCP]=E4=BD=BF=E7=94=A8SPU=E5=AE=9E?= =?UTF-8?q?=E7=8E=B0=E9=9A=8F=E6=9C=BA=E6=A3=AE=E6=9E=97=E7=AE=97=E6=B3=95?= =?UTF-8?q?=20(#792)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit fixed : https://github.com/secretflow/spu/issues/254 --- sml/ensemble/BUILD.bazel | 14 +- sml/ensemble/emulations/BUILD.bazel | 15 +- sml/ensemble/emulations/forest_emul.py | 125 +++++++++++++++ sml/ensemble/forest.py | 201 +++++++++++++++++++++++++ sml/ensemble/tests/BUILD.bazel | 16 +- sml/ensemble/tests/forest_test.py | 117 ++++++++++++++ 6 files changed, 485 insertions(+), 3 deletions(-) create mode 100644 sml/ensemble/emulations/forest_emul.py create mode 100644 sml/ensemble/forest.py create mode 100644 sml/ensemble/tests/forest_test.py diff --git a/sml/ensemble/BUILD.bazel b/sml/ensemble/BUILD.bazel index 7832e732..ccdc1848 100644 --- a/sml/ensemble/BUILD.bazel +++ b/sml/ensemble/BUILD.bazel @@ -1,4 +1,4 @@ -# Copyright 2023 Ant Group Co., Ltd. +# Copyright 2024 Ant Group Co., Ltd. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,3 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +load("@rules_python//python:defs.bzl", "py_library") + +package(default_visibility = ["//visibility:public"]) + +py_library( + name = "forest", + srcs = ["forest.py"], + deps = [ + "//sml/tree", + ], +) diff --git a/sml/ensemble/emulations/BUILD.bazel b/sml/ensemble/emulations/BUILD.bazel index 7832e732..6775fe71 100644 --- a/sml/ensemble/emulations/BUILD.bazel +++ b/sml/ensemble/emulations/BUILD.bazel @@ -1,4 +1,4 @@ -# Copyright 2023 Ant Group Co., Ltd. +# Copyright 2024 Ant Group Co., Ltd. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,3 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +load("@rules_python//python:defs.bzl", "py_binary") + +package(default_visibility = ["//visibility:public"]) + +py_binary( + name = "forest_emul", + srcs = ["forest_emul.py"], + deps = [ + "//sml/ensemble:forest", + "//sml/utils:emulation", + ], +) diff --git a/sml/ensemble/emulations/forest_emul.py b/sml/ensemble/emulations/forest_emul.py new file mode 100644 index 00000000..90437386 --- /dev/null +++ b/sml/ensemble/emulations/forest_emul.py @@ -0,0 +1,125 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import time + +import jax.numpy as jnp +from sklearn.datasets import load_iris +from sklearn.ensemble import RandomForestClassifier + +import sml.utils.emulation as emulation +from sml.ensemble.forest import RandomForestClassifier as sml_rfc + +MAX_DEPTH = 3 +CONFIG_FILE = emulation.CLUSTER_ABY3_3PC + + +def emul_forest(mode=emulation.Mode.MULTIPROCESS): + def proc_wrapper( + n_estimators, + max_features, + criterion, + splitter, + max_depth, + bootstrap, + max_samples, + n_labels, + ): + rf_custom = sml_rfc( + n_estimators, + max_features, + criterion, + splitter, + max_depth, + bootstrap, + max_samples, + n_labels, + ) + + def proc(X, y): + rf_custom_fit = rf_custom.fit(X, y) + result = rf_custom_fit.predict(X) + return result + + return proc + + def load_data(): + iris = load_iris() + iris_data, iris_label = jnp.array(iris.data), jnp.array(iris.target) + # sorted_features: n_samples * n_features_in + n_samples, n_features_in = iris_data.shape + n_labels = len(jnp.unique(iris_label)) + sorted_features = jnp.sort(iris_data, axis=0) + new_threshold = (sorted_features[:-1, :] + sorted_features[1:, :]) / 2 + new_features = jnp.greater_equal( + iris_data[:, :], new_threshold[:, jnp.newaxis, :] + ) + new_features = new_features.transpose([1, 0, 2]).reshape(n_samples, -1) + + X, y = new_features[:, ::3], iris_label[:] + return X, y + + try: + # bandwidth and latency only work for docker mode + emulator = emulation.Emulator(CONFIG_FILE, mode, bandwidth=300, latency=20) + emulator.up() + + # load mock data + X, y = load_data() + n_labels = jnp.unique(y).shape[0] + + # compare with sklearn + rf = RandomForestClassifier( + n_estimators=3, + max_features=None, + criterion='gini', + max_depth=MAX_DEPTH, + bootstrap=False, + max_samples=None, + ) + start = time.time() + rf = rf.fit(X, y) + score_plain = rf.score(X, y) + end = time.time() + print(f"Running time in SKlearn: {end - start:.2f}s") + + # mark these data to be protected in SPU + X_spu, y_spu = emulator.seal(X, y) + + # run + proc = proc_wrapper( + n_estimators=3, + max_features=0.7, + criterion='gini', + splitter='best', + max_depth=3, + bootstrap=False, + max_samples=None, + n_labels=n_labels, + ) + start = time.time() + result = emulator.run(proc)(X_spu, y_spu) + end = time.time() + score_encrpted = jnp.mean((result == y)) + print(f"Running time in SPU: {end - start:.2f}s") + + # print acc + print(f"Accuracy in SKlearn: {score_plain:.2f}") + print(f"Accuracy in SPU: {score_encrpted:.2f}") + + finally: + emulator.down() + + +if __name__ == "__main__": + emul_forest(emulation.Mode.MULTIPROCESS) diff --git a/sml/ensemble/forest.py b/sml/ensemble/forest.py new file mode 100644 index 00000000..c2dd2649 --- /dev/null +++ b/sml/ensemble/forest.py @@ -0,0 +1,201 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import random + +import jax +import jax.numpy as jnp +from jax import lax + +from sml.tree.tree import DecisionTreeClassifier as sml_dtc + + +class RandomForestClassifier: + """A random forest classifier based on DecisionTreeClassifier. + + Parameters + ---------- + n_estimators : int + The number of trees in the forest. Must specify an integer > 0. + + max_features : int, float, "auto", "sqrt", "log2", or None. + The number of features to consider when looking for the best split. + If it's an integer, must 0 < integer < n_features. + If it's an float, must 0 < float <= 1. + + criterion : {"gini"}, default="gini" + The function to measure the quality of a split. Supported criteria are + "gini" for the Gini impurity. + + splitter : {"best"}, default="best" + The strategy used to choose the split at each node. Supported + strategies are "best" to choose the best split. + + max_depth : int + The maximum depth of the tree. Must specify an integer > 0. + + bootstrap : bool + Whether bootstrap samples are used when building trees. + + max_samples : int, float ,None, default=None + The number of samples to draw from X to train each base estimator. + This parameter is only valid if bootstrap is ture. + If it's an integer, must 0 < integer < n_samples. + If it's an float, must 0 < float <= 1. + + n_labels: int + The max number of labels. + + """ + + def __init__( + self, + n_estimators, + max_features, + criterion, + splitter, + max_depth, + bootstrap, + max_samples, + n_labels, + ): + assert criterion == "gini", "criteria other than gini is not supported." + assert splitter == "best", "splitter other than best is not supported." + assert ( + n_estimators is not None and n_estimators > 0 + ), "n_estimators should not be None and must > 0." + assert ( + max_depth is not None and max_depth > 0 + ), "max_depth should not be None and must > 0." + assert isinstance( + bootstrap, bool + ), "bootstrap should be a boolean value (True or False)" + + self.n_estimators = n_estimators + self.max_features = max_features + self.criterion = criterion + self.splitter = splitter + self.max_depth = max_depth + self.bootstrap = bootstrap + self.max_samples = max_samples + self.n_labels = n_labels + + self.trees = [] + self.features_indices = [] + + def _calculate_max_samples(self, max_samples, n_samples): + if isinstance(max_samples, int): + assert ( + max_samples <= n_samples + ), "max_samples should not exceed n_samples when it's an integer" + return max_samples + elif isinstance(max_samples, float): + assert ( + 0 < max_samples <= 1 + ), "max_samples should be in the range (0, 1] when it's a float" + return int(max_samples * n_samples) + else: + return n_samples + + def _bootstrap_sample(self, X, y): + n_samples = X.shape[0] + max_samples = self._calculate_max_samples(self.max_samples, n_samples) + + if not self.bootstrap: + return X, y + + # 实现bootstrap + population = range(n_samples) + indices = random.sample(population, max_samples) + + indices = jnp.array(indices) + return X[indices], y[indices] + + def _select_features(self, n, k): + indices = range(n) + selected_elements = random.sample(indices, k) + return selected_elements + + def _calculate_max_features(self, max_features, n_features): + if isinstance(max_features, int): + assert ( + 0 < max_features <= n_features + ), "0 < max_features <= n_features when it's an integer" + return max_features + + elif isinstance(max_features, float): + assert ( + 0 < max_features <= 1 + ), "max_features should be in the range (0, 1] when it's a float" + return int(max_features * n_features) + + elif isinstance(max_features, str): + if max_features == 'sqrt': + return int(math.sqrt(n_features)) + elif max_features == 'log2': + return int(math.log2(n_features)) + else: + return n_features + else: + return n_features + + def fit(self, X, y): + n_samples, n_features = X.shape + self.n_features = n_features + self.max_features = self._calculate_max_features( + self.max_features, self.n_features + ) + self.label_list = jnp.arange(self.n_labels) + + self.trees = [] + self.features_indices = [] + + for _ in range(self.n_estimators): + X_sample, y_sample = self._bootstrap_sample(X, y) + features = self._select_features(self.n_features, self.max_features) + + tree = sml_dtc(self.criterion, self.splitter, self.max_depth, self.n_labels) + tree.fit(X_sample[:, features], y_sample) + self.trees.append(tree) + self.features_indices.append(features) + + return self + + def jax_mode_row_vectorized(self, data): + label_list = jnp.array(self.label_list) + + data_expanded = jnp.expand_dims(data, axis=-1) + label_expanded = jnp.expand_dims(label_list, axis=0) + + mask = (data_expanded == label_expanded).astype(jnp.int32) + + counts = jnp.sum(mask, axis=1) + mode_indices = jnp.argmax(counts, axis=1) + + modes = label_list[mode_indices] + return modes + + def predict(self, X): + predictions_list = [] + for i, tree in enumerate(self.trees): + features = self.features_indices[i] + predictions = tree.predict(X[:, features]) + predictions_list.append(predictions) + + tree_predictions = jnp.array(predictions_list).T + + y_pred = self.jax_mode_row_vectorized(tree_predictions) + + return y_pred.ravel() diff --git a/sml/ensemble/tests/BUILD.bazel b/sml/ensemble/tests/BUILD.bazel index 7832e732..2caeadf0 100644 --- a/sml/ensemble/tests/BUILD.bazel +++ b/sml/ensemble/tests/BUILD.bazel @@ -1,4 +1,4 @@ -# Copyright 2023 Ant Group Co., Ltd. +# Copyright 2024 Ant Group Co., Ltd. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,3 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +load("@rules_python//python:defs.bzl", "py_test") + +package(default_visibility = ["//visibility:public"]) + +py_test( + name = "forest_test", + srcs = ["forest_test.py"], + deps = [ + "//sml/ensemble:forest", + "//spu:init", + "//spu/utils:simulation", + ], +) diff --git a/sml/ensemble/tests/forest_test.py b/sml/ensemble/tests/forest_test.py new file mode 100644 index 00000000..6c9d3280 --- /dev/null +++ b/sml/ensemble/tests/forest_test.py @@ -0,0 +1,117 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +import jax.numpy as jnp +from sklearn.datasets import load_iris +from sklearn.ensemble import RandomForestClassifier + +import spu.spu_pb2 as spu_pb2 # type: ignore +import spu.utils.simulation as spsim +from sml.ensemble.forest import RandomForestClassifier as sml_rfc + +MAX_DEPTH = 3 + + +class UnitTests(unittest.TestCase): + def test_forest(self): + def proc_wrapper( + n_estimators, + max_features, + criterion, + splitter, + max_depth, + bootstrap, + max_samples, + n_labels, + ): + rf_custom = sml_rfc( + n_estimators, + max_features, + criterion, + splitter, + max_depth, + bootstrap, + max_samples, + n_labels, + ) + + def proc(X, y): + rf_custom_fit = rf_custom.fit(X, y) + + result = rf_custom_fit.predict(X) + return result + + return proc + + def load_data(): + iris = load_iris() + iris_data, iris_label = jnp.array(iris.data), jnp.array(iris.target) + n_samples, n_features_in = iris_data.shape + n_labels = len(jnp.unique(iris_label)) + sorted_features = jnp.sort(iris_data, axis=0) + new_threshold = (sorted_features[:-1, :] + sorted_features[1:, :]) / 2 + new_features = jnp.greater_equal( + iris_data[:, :], new_threshold[:, jnp.newaxis, :] + ) + new_features = new_features.transpose([1, 0, 2]).reshape(n_samples, -1) + + X, y = new_features[:, ::3], iris_label[:] + return X, y + + # bandwidth and latency only work for docker mode + sim = spsim.Simulator.simple( + 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 + ) + + # load mock data + X, y = load_data() + n_labels = jnp.unique(y).shape[0] + + # compare with sklearn + rf = RandomForestClassifier( + n_estimators=3, + max_features="log2", + criterion='gini', + max_depth=MAX_DEPTH, + bootstrap=True, + max_samples=0.7, + ) + rf = rf.fit(X, y) + score_plain = rf.score(X, y) + tree_predictions = jnp.array([tree.predict(X) for tree in rf.estimators_]) + + # run + proc = proc_wrapper( + n_estimators=3, + max_features="log2", + criterion='gini', + splitter='best', + max_depth=3, + bootstrap=True, + max_samples=0.7, + n_labels=n_labels, + ) + + result = spsim.sim_jax(sim, proc)(X, y) + + score_encrpted = jnp.mean((result == y)) + + # print acc + print(f"Accuracy in SKlearn: {score_plain}") + print(f"Accuracy in SPU: {score_encrpted}") + + +if __name__ == "__main__": + unittest.main()