From 4a4c5fa9a68fa6fa274e23cb3fedaa924ea19f2f Mon Sep 17 00:00:00 2001 From: anakinxc <103552181+anakinxc@users.noreply.github.com> Date: Fri, 12 Jan 2024 14:18:54 +0800 Subject: [PATCH] Repo sync (#489) --- .vscode/settings.json | 3 +- CHANGELOG.md | 4 +- bazel/repositories.bzl | 6 +- examples/python/ml/flax_llama7b/README.md | 1 + libspu/compiler/passes/BUILD.bazel | 11 - .../compiler/passes/expand_secret_gather.cc | 37 +- .../compiler/passes/hlo_legalize_to_pphlo.cc | 112 ++-- libspu/compiler/passes/optimize_maxpool.cc | 24 +- libspu/compiler/passes/utils.cc | 36 -- libspu/compiler/tests/convert_push_down.mlir | 2 +- libspu/compiler/tests/enum_conversion_test.cc | 3 +- .../compiler/tests/expand_secret_gather.mlir | 4 +- .../tests/hlo_to_pphlo_ops_other.mlir | 10 +- .../tests/hlo_to_pphlo_reduce_window.mlir | 2 +- .../hlo_to_pphlo_select_and_scatter.mlir | 2 +- .../tests/no_expand_secret_gather.mlir | 2 +- libspu/compiler/tests/ops_negative.mlir | 40 +- .../tests/optimize_denominator_with_bcst.mlir | 2 +- libspu/compiler/tests/optimize_maxpool.mlir | 6 +- .../tests/pphlo_type_inference_reduce.mlir | 4 +- libspu/compiler/tools/BUILD.bazel | 16 + libspu/compiler/tools/mlir-pphlo-lsp.cc | 29 + libspu/core/context.h | 7 +- libspu/core/object.h | 6 + libspu/core/shape.h | 6 + libspu/core/type.h | 10 + libspu/core/value.cc | 5 + libspu/core/value.h | 1 + libspu/cuda_support/BUILD.bazel | 13 - libspu/device/pphlo/pphlo_executor.cc | 87 +-- libspu/device/pphlo/pphlo_executor_test.cc | 326 ++++++++++- .../pphlo/pphlo_executor_test_runner.cc | 1 + .../device/pphlo/pphlo_executor_test_runner.h | 5 +- libspu/device/pphlo/pphlo_verifier_test.cc | 2 +- libspu/device/test_utils.h | 5 +- libspu/dialect/pphlo_ops.cc | 161 ++---- libspu/dialect/pphlo_ops.td | 51 +- libspu/kernel/hal/BUILD.bazel | 2 +- libspu/kernel/hal/debug.cc | 2 +- libspu/kernel/hal/permute.cc | 538 ++++++++++++++---- libspu/kernel/hal/prot_wrapper.cc | 107 +++- libspu/kernel/hal/prot_wrapper.h | 33 +- libspu/kernel/hal/ring.cc | 51 +- libspu/kernel/hal/shape_ops.cc | 119 +--- libspu/kernel/hal/shape_ops.h | 2 +- libspu/kernel/hlo/geometrical.cc | 2 +- libspu/kernel/hlo/geometrical.h | 2 +- libspu/kernel/hlo/indexing.cc | 151 +++-- libspu/kernel/hlo/indexing.h | 8 +- libspu/kernel/hlo/shuffle.cc | 18 +- libspu/mpc/ab_api_test.cc | 59 +- libspu/mpc/aby3/BUILD.bazel | 1 + libspu/mpc/aby3/permute.cc | 6 +- libspu/mpc/aby3/permute.h | 12 +- libspu/mpc/aby3/protocol.cc | 79 +-- libspu/mpc/aby3/type.h | 4 +- libspu/mpc/api.cc | 98 +++- libspu/mpc/api.h | 63 +- libspu/mpc/cheetah/BUILD.bazel | 1 + libspu/mpc/cheetah/arith/cheetah_dot.cc | 4 +- libspu/mpc/cheetah/arith/conv2d_prot.cc | 2 +- libspu/mpc/cheetah/arith/matmat_prot.cc | 2 +- libspu/mpc/cheetah/arithmetic.cc | 25 + libspu/mpc/cheetah/arithmetic.h | 11 + libspu/mpc/cheetah/ot/yacl/BUILD.bazel | 4 +- libspu/mpc/cheetah/ot/yacl/ferret.cc | 118 ++-- libspu/mpc/cheetah/ot/yacl/ferret.h | 1 + .../mpc/cheetah/ot/yacl/yacl_ferret_test.cc | 211 ------- .../mpc/cheetah/ot/yacl/yacl_ote_adapter.cc | 219 ++++--- libspu/mpc/cheetah/ot/yacl/yacl_ote_adapter.h | 136 +++-- libspu/mpc/cheetah/ot/yacl/yacl_util.cc | 41 -- libspu/mpc/cheetah/ot/yacl/yacl_util.h | 6 + libspu/mpc/cheetah/protocol.cc | 57 +- libspu/mpc/cheetah/rlwe/modswitch_helper.cc | 6 +- libspu/mpc/common/pv2k.cc | 316 ++++++++-- libspu/mpc/kernel.cc | 109 +++- libspu/mpc/kernel.h | 76 +++ libspu/mpc/ref2k/BUILD.bazel | 1 + libspu/mpc/ref2k/ref2k.cc | 43 +- libspu/mpc/securenn/BUILD.bazel | 1 + libspu/mpc/securenn/protocol.cc | 64 
+-- libspu/mpc/semi2k/BUILD.bazel | 1 + libspu/mpc/semi2k/permute.cc | 26 +- libspu/mpc/semi2k/permute.h | 27 +- libspu/mpc/semi2k/protocol.cc | 62 +- libspu/mpc/semi2k/type.h | 4 +- libspu/mpc/spdz2k/BUILD.bazel | 1 + libspu/mpc/spdz2k/protocol.cc | 56 +- libspu/mpc/standard_shape/BUILD.bazel | 43 ++ libspu/mpc/standard_shape/kernels.cc | 116 ++++ libspu/mpc/standard_shape/kernels.h | 131 +++++ libspu/mpc/standard_shape/protocol.cc | 34 ++ .../utils.h => mpc/standard_shape/protocol.h} | 10 +- libspu/mpc/utils/ring_ops.cc | 4 +- spu/version.py | 2 +- 95 files changed, 2914 insertions(+), 1456 deletions(-) delete mode 100644 libspu/compiler/passes/utils.cc create mode 100644 libspu/compiler/tools/mlir-pphlo-lsp.cc delete mode 100644 libspu/mpc/cheetah/ot/yacl/yacl_ferret_test.cc delete mode 100644 libspu/mpc/cheetah/ot/yacl/yacl_util.cc create mode 100644 libspu/mpc/standard_shape/BUILD.bazel create mode 100644 libspu/mpc/standard_shape/kernels.cc create mode 100644 libspu/mpc/standard_shape/kernels.h create mode 100644 libspu/mpc/standard_shape/protocol.cc rename libspu/{compiler/passes/utils.h => mpc/standard_shape/protocol.h} (75%) diff --git a/.vscode/settings.json b/.vscode/settings.json index 24e63f35..c21d4f53 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -40,5 +40,6 @@ "git.ignoreLimitWarning": true, "[python]": { "editor.defaultFormatter": "ms-python.black-formatter" - } + }, + "mlir.server_path": "bazel-bin/libspu/compiler/tools/mlir-pphlo-lsp" } \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index 15306cdd..e7280d1a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,7 +10,9 @@ > > please add your unreleased change here. -## TBD +- [Improvement] Optimize one-time setup for yacl ot + +## 20240105 - [Feature] Add Odd-Even Merge Sort to replace the bitonic sort - [Feature] Add radix sort support for ABY3 diff --git a/bazel/repositories.bzl b/bazel/repositories.bzl index 39f8e808..a5cb186b 100644 --- a/bazel/repositories.bzl +++ b/bazel/repositories.bzl @@ -140,15 +140,15 @@ def _com_github_xtensor_xtl(): ) def _com_github_openxla_xla(): - OPENXLA_COMMIT = "d5791b01aa7541e3400224ac0a2985cc0f6940cb" - OPENXLA_SHA256 = "82dd50e6f51d79e8da69f109a234e33b8036f7b8798e41a03831b19c0c64d6e5" + OPENXLA_COMMIT = "fa9331a7e557b4ec1381f84cbbf7401a8f41ac66" + OPENXLA_SHA256 = "d19c570d434002b7b0490327d407fc7cf2b18633f4a2d3b1bb44f3f0e4b36533" maybe( http_archive, name = "bazel_skylib", sha256 = "74d544d96f4a5bb630d465ca8bbcfe231e3594e5aae57e1edbf17a6eb3ca2506", urls = [ - "https://github.com/bazelbuild/bazel-skylib/releases/download/{version}/bazel-skylib-1.3.0.tar.gz", + "https://github.com/bazelbuild/bazel-skylib/releases/download/1.3.0/bazel-skylib-1.3.0.tar.gz", ], ) diff --git a/examples/python/ml/flax_llama7b/README.md b/examples/python/ml/flax_llama7b/README.md index fb4f0980..c95d8cc1 100644 --- a/examples/python/ml/flax_llama7b/README.md +++ b/examples/python/ml/flax_llama7b/README.md @@ -22,6 +22,7 @@ This example demonstrates how to use SPU to run secure inference on a pre-traine Since EasyLM has an issue, we have to make a small change to support the option "streaming=false".
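To make the goal concrete before the step that follows: a minimal sketch of what the edited argument definition would plausibly look like, assuming the intent is simply to flip the default so weights are saved in non-streaming format unless `--streaming` is passed. The resulting line is not shown in this hunk, so `default=False` here is an assumption, not the repository's actual final code.

```python
import argparse

parser = argparse.ArgumentParser()
# Hypothetical result of the edit described below; the original line in
# convert_hf_to_easylm.py uses default=True, which forces stream-format output.
parser.add_argument(
    "--streaming",
    action="store_true",
    default=False,  # assumed change, so that "streaming=false" is the default
    help="whether is model weight saved stream format",
)
```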
Open and edit "convert_hf_to_easylm.py", change this: + ```python parser.add_argument("--streaming", action="store_true", default=True, help="whether is model weight saved stream format",) ``` diff --git a/libspu/compiler/passes/BUILD.bazel b/libspu/compiler/passes/BUILD.bazel index 23483b0a..c470e86e 100644 --- a/libspu/compiler/passes/BUILD.bazel +++ b/libspu/compiler/passes/BUILD.bazel @@ -92,7 +92,6 @@ spu_cc_library( deps = [ ":map_stablehlo_to_pphlo_op", ":pass_details", - ":utils", ":visibility_inference", "//libspu/compiler/common:compilation_context", "//libspu/core:prelude", @@ -204,7 +203,6 @@ spu_cc_library( hdrs = ["passes.h"], deps = [ ":pass_details", - ":utils", "//libspu/dialect:pphlo_dialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:TransformUtils", @@ -259,15 +257,6 @@ spu_cc_library( ], ) -spu_cc_library( - name = "utils", - srcs = ["utils.cc"], - hdrs = ["utils.h"], - deps = [ - "@llvm-project//mlir:IR", - ], -) - spu_cc_library( name = "convert_push_down", srcs = ["convert_push_down.cc"], diff --git a/libspu/compiler/passes/expand_secret_gather.cc b/libspu/compiler/passes/expand_secret_gather.cc index f37e9577..10117f31 100644 --- a/libspu/compiler/passes/expand_secret_gather.cc +++ b/libspu/compiler/passes/expand_secret_gather.cc @@ -20,7 +20,6 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "libspu/compiler/passes/pass_details.h" -#include "libspu/compiler/passes/utils.h" #include "libspu/dialect/pphlo_ops.h" namespace mlir::pphlo { @@ -28,7 +27,7 @@ namespace { bool GatherIsBroadcast(GatherOp &op) { - auto gather_slice_size = op.getSliceSizes().getValues(); + auto gather_slice_size = op.getSliceSizes(); auto op_shape = op.getOperand().getType().getShape(); return (gather_slice_size.size() == op_shape.size()) && (std::equal(gather_slice_size.begin(), gather_slice_size.end(), @@ -113,7 +112,7 @@ TransposeIndexVectorDimToLast(TypedValue &start_indices, start_indices.getLoc(), RankedTensorType::get(result_shape, start_indices.getType().getElementType()), - start_indices, ConvertDimensions(&builder, permutation)); + start_indices, permutation); return transpose.getResult(); } @@ -211,17 +210,16 @@ CanonicalizeGatherIndices(TypedValue &start_indices, } TypedValue CreateGatherLoopAccumulatorInitValue( - GatherOp op, Type element_type, DenseIntElementsAttr slice_sizes, + GatherOp op, Type element_type, llvm::ArrayRef slice_sizes, int64_t gather_loop_trip_count, const GatherDimensionNumbersAttr &dim_numbers) { std::vector accumulator_state_shape_dims; - auto array_slice_size = slice_sizes.getValues(); - accumulator_state_shape_dims.reserve(1 + array_slice_size.size()); + accumulator_state_shape_dims.reserve(1 + slice_sizes.size()); accumulator_state_shape_dims.push_back(gather_loop_trip_count); - for (int64_t i = 0; i < static_cast(array_slice_size.size()); i++) { + for (int64_t i = 0; i < static_cast(slice_sizes.size()); i++) { if (!std::binary_search(dim_numbers.getCollapsedSliceDims().begin(), dim_numbers.getCollapsedSliceDims().end(), i)) { - accumulator_state_shape_dims.emplace_back(array_slice_size[i]); + accumulator_state_shape_dims.emplace_back(slice_sizes[i]); } } @@ -392,9 +390,12 @@ llvm::SmallVector ExpandIndexVectorIntoOperandSpace( auto component_to_concat = builder->create( index_vector.getLoc(), RankedTensorType::get({1}, index_vector.getType().getElementType()), - index_vector, ConvertDimensions(builder, {index_vector_dim_index}), - ConvertDimensions(builder, {index_vector_dim_index + 1}), -
ConvertDimensions(builder, {1})); + index_vector, + DenseI64ArrayAttr::get(builder->getContext(), + {index_vector_dim_index}), + DenseI64ArrayAttr::get(builder->getContext(), + {index_vector_dim_index + 1}), + DenseI64ArrayAttr::get(builder->getContext(), {1})); auto reshaped = builder->create( index_vector.getLoc(), RankedTensorType::get({}, index_vector.getType().getElementType()), @@ -450,9 +451,9 @@ void GatherLoopBody(GatherOp gather, Region &body, if (has_scalar_indices) { // In this case start_indices has rank 1 and induction_var_as_vector (of // shape {1}) is an index into this rank 1 tensor. - auto ds = builder.create(gather->getLoc(), start_indices, - ValueRange{induction_var}, - ConvertDimensions(&builder, {1})); + auto ds = builder.create( + gather->getLoc(), start_indices, ValueRange{induction_var}, + DenseI64ArrayAttr::get(builder.getContext(), {1})); index_vector = ds.getResult(); } else { // In this case start_indices has rank 2 and induction_var_as_vector (of @@ -463,7 +464,7 @@ void GatherLoopBody(GatherOp gather, Region &body, auto index_vector_2d = builder.create( gather->getLoc(), start_indices, ValueRange{induction_var, index_zero}, - ConvertDimensions(&builder, {1, index_vector_size})); + DenseI64ArrayAttr::get(builder.getContext(), {1, index_vector_size})); index_vector = ElideDegenerateDims(&builder, index_vector_2d, {0}); } @@ -523,8 +524,8 @@ struct GatherConverter : public OpRewritePattern { op->getLoc(), reshaped_type, op.getOperand()); rewriter.replaceOpWithNewOp( op, op->getResults().getType(), broadcast_operand, - ConvertDimensions(&builder, - op.getDimensionNumbers().getOffsetDims())); + DenseI64ArrayAttr::get(builder.getContext(), + op.getDimensionNumbers().getOffsetDims())); return success(); } @@ -612,7 +613,7 @@ struct GatherConverter : public OpRewritePattern { rewriter.replaceOpWithNewOp( op, op.getResult().getType(), accumulator_with_batch_dims_decanonicalized, - ConvertDimensions(&builder, permutation)); + DenseI64ArrayAttr::get(builder.getContext(), permutation)); return success(); } diff --git a/libspu/compiler/passes/hlo_legalize_to_pphlo.cc b/libspu/compiler/passes/hlo_legalize_to_pphlo.cc index e77da754..c5c78b7f 100644 --- a/libspu/compiler/passes/hlo_legalize_to_pphlo.cc +++ b/libspu/compiler/passes/hlo_legalize_to_pphlo.cc @@ -28,7 +28,6 @@ #include "libspu/compiler/passes/map_stablehlo_to_pphlo_op.h" #include "libspu/compiler/passes/pass_details.h" -#include "libspu/compiler/passes/utils.h" #include "libspu/compiler/passes/value_visibility_map.h" #include "libspu/compiler/passes/visibility_inference.h" #include "libspu/core/prelude.h" @@ -40,6 +39,11 @@ namespace mlir::pphlo { namespace { +DenseI64ArrayAttr ConvertDenseIntElementAttr(const DenseIntElementsAttr &attr) { + llvm::SmallVector array(attr.getValues()); + return DenseI64ArrayAttr::get(attr.getContext(), array); +} + ValueVisibilityMap VisibilityDiscovery(const llvm::ArrayRef input_vis_list, ModuleOp op) { @@ -348,9 +352,14 @@ struct ReduceOpConverter : public OpConversionPattern { sig_conversion.addInputs(arg.getArgNumber(), lower_t); } + mlir::NamedAttribute dimAttr( + StringAttr::get(op->getContext(), "dimensions"), + ConvertDenseIntElementAttr(op.getDimensions())); + auto new_op = rewriter.replaceOpWithNewOp>( - op, result_types, materialized_operands, op->getAttrs()); + op, result_types, materialized_operands, + llvm::SmallVector{dimAttr}); // Copy over the operations inside the region. 
rewriter.inlineRegionBefore(op.getBody(), new_op.getBody(), @@ -466,9 +475,9 @@ struct ReduceWindowOpConverter materialized_operands[idx] = rewriter.create( op->getLoc(), materialized_operands[idx], materialized_operands[idx + num_results], - builder.getI64TensorAttr(padding_low), - builder.getI64TensorAttr(padding_high), - builder.getI64TensorAttr(interior_padding)); + DenseI64ArrayAttr::get(op->getContext(), padding_low), + DenseI64ArrayAttr::get(op->getContext(), padding_high), + DenseI64ArrayAttr::get(op->getContext(), interior_padding)); } } } @@ -476,17 +485,20 @@ struct ReduceWindowOpConverter llvm::SmallVector attrs; { // I64ElementsAttr:$window_dimensions, - attrs.push_back({builder.getStringAttr("window_dimensions"), - op.getWindowDimensionsAttr()}); + attrs.push_back( + {builder.getStringAttr("window_dimensions"), + ConvertDenseIntElementAttr(op.getWindowDimensionsAttr())}); // OptionalAttr:$window_strides, if (op.getWindowStrides().has_value()) { - attrs.push_back({builder.getStringAttr("window_strides"), - op.getWindowStridesAttr()}); + attrs.push_back( + {builder.getStringAttr("window_strides"), + ConvertDenseIntElementAttr(op.getWindowStridesAttr())}); } // OptionalAttr:$window_dilations, if (op.getWindowDilations().has_value()) { - attrs.push_back({builder.getStringAttr("window_dilations"), - op.getWindowDilationsAttr()}); + attrs.push_back( + {builder.getStringAttr("window_dilations"), + ConvertDenseIntElementAttr(op.getWindowDilationsAttr())}); } } @@ -755,6 +767,40 @@ class HloToPPHloOpConverter : public OpConversionPattern { } }; +template <> +class HloToPPHloOpConverter + : public OpConversionPattern { +private: + const ValueVisibilityMap &vis_; + +public: + HloToPPHloOpConverter(TypeConverter &type_converter, MLIRContext *context, + const ValueVisibilityMap &vis) + : OpConversionPattern(type_converter, + context), + vis_(vis) {} + + LogicalResult + matchAndRewrite(stablehlo::BroadcastInDimOp hlo_op, + stablehlo::BroadcastInDimOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto result_vis = vis_.getValueVisibility(hlo_op.getResult()); + + Type resultType = HloToPPHloTypeConverter::getTypeWithVisibility( + this->getTypeConverter()->convertType(hlo_op.getType()), result_vis); + + mlir::NamedAttribute dim( + StringAttr::get(hlo_op.getContext(), "broadcast_dimensions"), + ConvertDenseIntElementAttr(hlo_op.getBroadcastDimensions())); + + rewriter + .replaceOpWithNewOp>( + hlo_op, resultType, adaptor.getOperands(), dim); + + return success(); + } +}; + template <> class HloToPPHloOpConverter : public OpConversionPattern { @@ -1003,14 +1049,15 @@ struct HloToPPHloOpConverter materialized_operand = rewriter.create( op->getLoc(), materialized_operand, materialized_init_value, - builder.getI64TensorAttr(padding_low), - builder.getI64TensorAttr(padding_high), - builder.getI64TensorAttr(padding_interior)); + DenseI64ArrayAttr::get(op->getContext(), padding_low), + DenseI64ArrayAttr::get(op->getContext(), padding_high), + DenseI64ArrayAttr::get(op->getContext(), padding_interior)); new_op = rewriter.create( op->getLoc(), materialized_operand.getType(), materialized_operand, adaptor.getSource(), materialized_init_value, - op.getWindowDimensionsAttr(), op.getWindowStridesAttr()); + ConvertDenseIntElementAttr(op.getWindowDimensionsAttr()), + ConvertDenseIntElementAttr(op.getWindowStridesAttr())); llvm::SmallVector slice_end( new_op.getType().dyn_cast().getShape().begin(), @@ -1022,15 +1069,18 @@ struct HloToPPHloOpConverter // Slice back 
rewriter.replaceOpWithNewOp( - op, result_type, new_op, ConvertDimensions(&builder, padding_low), - ConvertDimensions(&builder, slice_end), - ConvertDimensions(&builder, - llvm::SmallVector(slice_end.size(), 1))); + op, result_type, new_op, + DenseI64ArrayAttr::get(builder.getContext(), padding_low), + DenseI64ArrayAttr::get(builder.getContext(), slice_end), + DenseI64ArrayAttr::get( + builder.getContext(), + llvm::SmallVector(slice_end.size(), 1))); } else { new_op = rewriter.replaceOpWithNewOp( op, result_type, materialized_operand, adaptor.getSource(), - materialized_init_value, op.getWindowDimensionsAttr(), - op.getWindowStridesAttr()); + materialized_init_value, + ConvertDenseIntElementAttr(op.getWindowDimensionsAttr()), + ConvertDenseIntElementAttr(op.getWindowStridesAttr())); } // Convert the region signature. @@ -1204,7 +1254,8 @@ class HloToPPHloOpConverter rewriter.replaceOpWithNewOp( op, resultType, adaptor.getOperands()[0], adaptor.getOperands()[1], - attr, op.getSliceSizes(), op.getIndicesAreSorted()); + attr, ConvertDenseIntElementAttr(op.getSliceSizes()), + op.getIndicesAreSorted()); return success(); } @@ -1262,17 +1313,15 @@ class HloToPPHloOpConverter } TypeTools type_tools; - auto indexType = rewriter.getIntegerType(64); - auto attrType = RankedTensorType::get({rank}, indexType); Value zero = rewriter.create( loc, rewriter.getZeroAttr(RankedTensorType::get( {}, type_tools.getExpressedType(inputType.getElementType())))); zero = rewriter.create( loc, RankedTensorType::get({}, inputType.getElementType()), zero); return rewriter.create( - loc, input, zero, DenseIntElementsAttr::get(attrType, padLow), - DenseIntElementsAttr::get(attrType, padHigh), - DenseIntElementsAttr::get(attrType, padInterior)); + loc, input, zero, DenseI64ArrayAttr::get(loc.getContext(), padLow), + DenseI64ArrayAttr::get(loc.getContext(), padHigh), + DenseI64ArrayAttr::get(loc.getContext(), padInterior)); } public: @@ -1338,16 +1387,13 @@ class HloToPPHloOpConverter modifiedRhs = rewriter.create( op.getLoc(), modifiedRhs, - mlir::DenseIntElementsAttr::get( - RankedTensorType::get(reversedDims.size(), - rewriter.getIntegerType(64)), - reversedDims)); + DenseI64ArrayAttr::get(op->getContext(), reversedDims)); } rewriter.replaceOpWithNewOp( op, resultType, modifiedLhs, modifiedRhs, - op.getWindowStrides().value_or(nullptr), attr, - op.getFeatureGroupCount(), op.getBatchGroupCount()); + ConvertDenseIntElementAttr(op.getWindowStrides().value_or(nullptr)), + attr, op.getFeatureGroupCount(), op.getBatchGroupCount()); return success(); } diff --git a/libspu/compiler/passes/optimize_maxpool.cc b/libspu/compiler/passes/optimize_maxpool.cc index 975765e2..8ca78097 100644 --- a/libspu/compiler/passes/optimize_maxpool.cc +++ b/libspu/compiler/passes/optimize_maxpool.cc @@ -37,10 +37,9 @@ struct SelectAndScatterConverter : public OpRewritePattern { Value rewriteReduceWindow(ReduceWindowOp op, PatternRewriter &rewriter) const { - auto window_size = - std::accumulate(op.getWindowDimensions().getValues().begin(), - op.getWindowDimensions().getValues().end(), 1, - std::multiplies()); + auto window_size = std::accumulate(op.getWindowDimensions().begin(), + op.getWindowDimensions().end(), 1, + std::multiplies()); auto current_ret_type = op.getResult(0).getType().dyn_cast(); @@ -59,8 +58,10 @@ struct SelectAndScatterConverter : public OpRewritePattern { auto argmax = builder.create( op->getLoc(), SmallVector{current_ret_type, index_result_type}, op.getInputs()[0], op.getWindowDimensions(), - 
op.getWindowStrides().value_or(nullptr), - op.getWindowDilations().value_or(nullptr)); + DenseI64ArrayAttr::get(op->getContext(), + op.getWindowStrides().value_or(std::nullopt)), + DenseI64ArrayAttr::get(op->getContext(), + op.getWindowDilations().value_or(std::nullopt))); op->getResult(0).replaceAllUsesWith(argmax->getResult(0)); @@ -101,10 +102,6 @@ struct SelectAndScatterConverter : public OpRewritePattern { Value selected_indices; bool rewritten = false; - auto isAllOne = [](const DenseIntElementsAttr &attr) { - return attr.isSplat() && attr.getSplatValue() == 1; - }; - for (const auto &u : uses) { if (auto previous_reduce_window = mlir::dyn_cast(u.getOwner())) { @@ -134,7 +131,8 @@ struct SelectAndScatterConverter : public OpRewritePattern { // Make sure no dilation auto window_dilation = previous_reduce_window.getWindowDilations(); - if (window_dilation.has_value() && !isAllOne(*window_dilation)) { + if (window_dilation.has_value() && + !llvm::all_of(*window_dilation, [](int64_t v) { return v == 1; })) { continue; } @@ -152,7 +150,9 @@ struct SelectAndScatterConverter : public OpRewritePattern { rewriter.replaceOpWithNewOp( op, op->getResultTypes()[0], selected_indices, op.getSource(), - op.getWindowDimensions(), op.getWindowStrides().value_or(nullptr)); + DenseI64ArrayAttr::get(op->getContext(), op.getWindowDimensions()), + DenseI64ArrayAttr::get(op->getContext(), + op.getWindowStrides().value_or(std::nullopt))); return status; } diff --git a/libspu/compiler/passes/utils.cc b/libspu/compiler/passes/utils.cc deleted file mode 100644 index 689ea14a..00000000 --- a/libspu/compiler/passes/utils.cc +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright 2023 Ant Group Co., Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "libspu/compiler/passes/utils.h" - -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinTypes.h" - -namespace mlir::pphlo { - -mlir::DenseIntElementsAttr -ConvertDimensions(OpBuilder *builder, llvm::ArrayRef op_dimensions) { - llvm::SmallVector dimensions; - dimensions.reserve(op_dimensions.size()); - for (auto value : op_dimensions) { - dimensions.emplace_back(APInt(64, value)); - } - - return DenseIntElementsAttr::get( - RankedTensorType::get(dimensions.size(), builder->getIntegerType(64)), - dimensions); -} - -} // namespace mlir::pphlo diff --git a/libspu/compiler/tests/convert_push_down.mlir b/libspu/compiler/tests/convert_push_down.mlir index d4dae75c..a8a877d3 100644 --- a/libspu/compiler/tests/convert_push_down.mlir +++ b/libspu/compiler/tests/convert_push_down.mlir @@ -15,7 +15,7 @@ func.func @main(%arg0: tensor<2x3x!pphlo.pub>, %arg1: tensor<2x3x!pphlo.pub // CHECK: %0 = "pphlo.transpose"(%arg0) // CHECK: %1 = "pphlo.convert"(%0) %0 = "pphlo.convert"(%arg0) : (tensor<2x3x!pphlo.pub>) -> tensor<2x3x!pphlo.pub> - %1 = "pphlo.transpose"(%0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<2x3x!pphlo.pub>) -> tensor<3x2x!pphlo.pub> + %1 = "pphlo.transpose"(%0) {permutation = array} : (tensor<2x3x!pphlo.pub>) -> tensor<3x2x!pphlo.pub> %2 = "pphlo.dot"(%1, %arg1) : (tensor<3x2x!pphlo.pub>, tensor<2x3x!pphlo.pub>) -> tensor<3x3x!pphlo.pub> return %2 : tensor<3x3x!pphlo.pub> } diff --git a/libspu/compiler/tests/enum_conversion_test.cc b/libspu/compiler/tests/enum_conversion_test.cc index 889110a5..8a41d105 100644 --- a/libspu/compiler/tests/enum_conversion_test.cc +++ b/libspu/compiler/tests/enum_conversion_test.cc @@ -29,7 +29,8 @@ TEST(EnumConversion, Public) { mlir::pphlo::symbolizeEnum(Visibility_Name(v)); \ EXPECT_EQ(mlir_v, mlir::pphlo::Visibility::T); - {CHECK(VIS_PUBLIC)} { CHECK(VIS_SECRET) } + { CHECK(VIS_PUBLIC) } + { CHECK(VIS_SECRET) } #undef CHECK } diff --git a/libspu/compiler/tests/expand_secret_gather.mlir b/libspu/compiler/tests/expand_secret_gather.mlir index 6e9a2848..36ac1553 100644 --- a/libspu/compiler/tests/expand_secret_gather.mlir +++ b/libspu/compiler/tests/expand_secret_gather.mlir @@ -3,12 +3,12 @@ func.func @main(%arg0: tensor<2x!pphlo.pub>, %arg1: tensor<1x!pphlo.sec>) -> (tensor>) { //CHECK-NOT: pphlo.gather //CHECK : pphlo.while - %0 = "pphlo.gather"(%arg0, %arg1) {dimension_numbers = #pphlo.gather, indices_are_sorted = true, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<2x!pphlo.pub>, tensor<1x!pphlo.sec>) -> tensor> + %0 = "pphlo.gather"(%arg0, %arg1) {dimension_numbers = #pphlo.gather, indices_are_sorted = true, slice_sizes = array} : (tensor<2x!pphlo.pub>, tensor<1x!pphlo.sec>) -> tensor> return %0: tensor> } // ----- func.func @main(%arg0: tensor<3x3x!pphlo.pub>, %arg1: tensor<2x!pphlo.sec>) -> (tensor<2x3x!pphlo.sec>) { - %0 = "pphlo.gather"(%arg0, %arg1) {dimension_numbers = #pphlo.gather, indices_are_sorted = false, slice_sizes = dense<[1, 3]> : tensor<2xi64>} : (tensor<3x3x!pphlo.pub>, tensor<2x!pphlo.sec>) -> tensor<2x3x!pphlo.sec> + %0 = "pphlo.gather"(%arg0, %arg1) {dimension_numbers = #pphlo.gather, indices_are_sorted = false, slice_sizes = array} : (tensor<3x3x!pphlo.pub>, tensor<2x!pphlo.sec>) -> tensor<2x3x!pphlo.sec> return %0 : tensor<2x3x!pphlo.sec> } diff --git a/libspu/compiler/tests/hlo_to_pphlo_ops_other.mlir b/libspu/compiler/tests/hlo_to_pphlo_ops_other.mlir index 218d39ee..6fbe32d8 100644 --- a/libspu/compiler/tests/hlo_to_pphlo_ops_other.mlir +++ 
b/libspu/compiler/tests/hlo_to_pphlo_ops_other.mlir @@ -1,19 +1,19 @@ // RUN: mlir-pphlo-opt --hlo-legalize-to-pphlo=input_vis_list=VIS_PUBLIC,VIS_PUBLIC,VIS_PUBLIC,VIS_PUBLIC,VIS_PUBLIC --split-input-file %s | FileCheck %s func.func @main(%arg0: tensor<16xf32>,%arg1: tensor<1024x1xi1>, %arg2: tensor<1024x1xf32>, %arg3: tensor<1024x1xf32>, %arg4: tensor<3x4xi32>) -> (tensor<1024x16xf32>) { - // CHECK: %0 = "pphlo.broadcast"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<16x!pphlo.pub>) -> tensor<1024x16x!pphlo.pub> + // CHECK: %0 = "pphlo.broadcast"(%arg0) {broadcast_dimensions = array} : (tensor<16x!pphlo.pub>) -> tensor<1024x16x!pphlo.pub> %0 = "stablehlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<16xf32>) -> tensor<1024x16xf32> // CHECK: %1 = "pphlo.reshape"(%arg0) : (tensor<16x!pphlo.pub>) -> tensor<1x16x!pphlo.pub> %1 = "stablehlo.reshape"(%arg0) : (tensor<16xf32>) -> tensor<1x16xf32> - // CHECK: %2 = "pphlo.transpose"(%1) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<1x16x!pphlo.pub>) -> tensor<16x1x!pphlo.pub> - %2 = "stablehlo.transpose"(%1) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<1x16xf32>) -> tensor<16x1xf32> + // CHECK: %2 = "pphlo.transpose"(%1) {permutation = array} : (tensor<1x16x!pphlo.pub>) -> tensor<16x1x!pphlo.pub> + %2 = "stablehlo.transpose"(%1) {permutation = array} : (tensor<1x16xf32>) -> tensor<16x1xf32> // CHECK: %3 = "pphlo.dot"(%0, %2) : (tensor<1024x16x!pphlo.pub>, tensor<16x1x!pphlo.pub>) -> tensor<1024x1x!pphlo.pub> %3 = "stablehlo.dot"(%0, %2) {precision_config = [#stablehlo, #stablehlo]} : (tensor<1024x16xf32>, tensor<16x1xf32>) -> tensor<1024x1xf32> // CHECK: %6 = "pphlo.concatenate"(%4, %5) {dimension = 1 : i64} : (tensor<1024x16x!pphlo.pub>, tensor<1024x1x!pphlo.pub>) -> tensor<1024x17x!pphlo.pub> %4 = "stablehlo.concatenate"(%0, %3) {dimension = 1 : i64} : (tensor<1024x16xf32>, tensor<1024x1xf32>) -> tensor<1024x17xf32> // CHECK: %7 = "pphlo.select"(%arg1, %arg2, %arg3) : (tensor<1024x1x!pphlo.pub>, tensor<1024x1x!pphlo.pub>, tensor<1024x1x!pphlo.pub>) -> tensor<1024x1x!pphlo.pub> %5 = "stablehlo.select"(%arg1, %arg2, %arg3) : (tensor<1024x1xi1>, tensor<1024x1xf32>, tensor<1024x1xf32>) -> tensor<1024x1xf32> - // CHECK: %8 = "pphlo.slice"(%arg4) {limit_indices = dense<[2, 4]> : tensor<2xi64>, start_indices = dense<[1, 0]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>} : (tensor<3x4x!pphlo.pub>) -> tensor<1x2x!pphlo.pub> - %6 = "stablehlo.slice"(%arg4) {start_indices = dense<[1, 0]> : tensor<2xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x2xi32> + // CHECK: %8 = "pphlo.slice"(%arg4) {limit_indices = array, start_indices = array, strides = array} : (tensor<3x4x!pphlo.pub>) -> tensor<1x2x!pphlo.pub> + %6 = "stablehlo.slice"(%arg4) {start_indices = array, limit_indices = array, strides = array} : (tensor<3x4xi32>) -> tensor<1x2xi32> return %0 : tensor<1024x16xf32> } diff --git a/libspu/compiler/tests/hlo_to_pphlo_reduce_window.mlir b/libspu/compiler/tests/hlo_to_pphlo_reduce_window.mlir index 723fe689..6c483c44 100644 --- a/libspu/compiler/tests/hlo_to_pphlo_reduce_window.mlir +++ b/libspu/compiler/tests/hlo_to_pphlo_reduce_window.mlir @@ -1,7 +1,7 @@ // RUN: mlir-pphlo-opt --hlo-legalize-to-pphlo=input_vis_list=VIS_PUBLIC,VIS_PUBLIC --split-input-file %s | FileCheck %s func.func @main(%arg0: tensor<3x2xi64>, %arg1: tensor) -> tensor<2x2xi64> { - // CHECK: %0 = "pphlo.pad"(%arg0, 
%arg1) {edge_padding_high = dense<[1, 0]> : tensor<2xi64>, edge_padding_low = dense<[2, 0]> : tensor<2xi64>, interior_padding = dense<[1, 0]> : tensor<2xi64>} : (tensor<3x2x!pphlo.pub>, tensor>) -> tensor<8x2x!pphlo.pub> + // CHECK: %0 = "pphlo.pad"(%arg0, %arg1) {edge_padding_high = array, edge_padding_low = array, interior_padding = array} : (tensor<3x2x!pphlo.pub>, tensor>) -> tensor<8x2x!pphlo.pub> // CHECK: %1 = "pphlo.reduce_window"(%0, %arg1) %result = "stablehlo.reduce_window"(%arg0, %arg1) ({ ^bb0(%arg2: tensor, %arg3: tensor): diff --git a/libspu/compiler/tests/hlo_to_pphlo_select_and_scatter.mlir b/libspu/compiler/tests/hlo_to_pphlo_select_and_scatter.mlir index d828a4cf..52ab9d53 100644 --- a/libspu/compiler/tests/hlo_to_pphlo_select_and_scatter.mlir +++ b/libspu/compiler/tests/hlo_to_pphlo_select_and_scatter.mlir @@ -9,7 +9,7 @@ func.func @main(%arg0: tensor<128x5x5x32xf32>, %arg1: tensor<128x4x4x32xf32>, %a // CHECK: ^bb0(%arg3: tensor>, %arg4: tensor>): // CHECK: %2 = "pphlo.add"(%arg3, %arg4) : (tensor>, tensor>) -> tensor> // CHECK: "pphlo.return"(%2) : (tensor>) -> () - // CHECK: }) {window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<1> : tensor<4xi64>} : (tensor<128x5x5x32x!pphlo.sec>, tensor<128x4x4x32x!pphlo.pub>, tensor>) -> tensor<128x5x5x32x!pphlo.sec> + // CHECK: }) {window_dimensions = array, window_strides = array} : (tensor<128x5x5x32x!pphlo.sec>, tensor<128x4x4x32x!pphlo.pub>, tensor>) -> tensor<128x5x5x32x!pphlo.sec> %0 = "stablehlo.select_and_scatter"(%arg0, %arg1, %arg2) ({ ^bb0(%arg3: tensor, %arg4: tensor): %1 = "stablehlo.compare"(%arg3, %arg4) {comparison_direction = #stablehlo} : (tensor, tensor) -> tensor diff --git a/libspu/compiler/tests/no_expand_secret_gather.mlir b/libspu/compiler/tests/no_expand_secret_gather.mlir index 71144956..155a4dca 100644 --- a/libspu/compiler/tests/no_expand_secret_gather.mlir +++ b/libspu/compiler/tests/no_expand_secret_gather.mlir @@ -3,6 +3,6 @@ func.func @main(%arg0: tensor<2x!pphlo.pub>, %arg1: tensor<1x!pphlo.pub>) -> (tensor>) { //CHECK-NOT: pphlo.while //CHECK : pphlo.gather - %0 = "pphlo.gather"(%arg0, %arg1) {dimension_numbers = #pphlo.gather, indices_are_sorted = true, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<2x!pphlo.pub>, tensor<1x!pphlo.pub>) -> tensor> + %0 = "pphlo.gather"(%arg0, %arg1) {dimension_numbers = #pphlo.gather, indices_are_sorted = true, slice_sizes = array} : (tensor<2x!pphlo.pub>, tensor<1x!pphlo.pub>) -> tensor> return %0: tensor> } diff --git a/libspu/compiler/tests/ops_negative.mlir b/libspu/compiler/tests/ops_negative.mlir index 2028c724..80943e15 100644 --- a/libspu/compiler/tests/ops_negative.mlir +++ b/libspu/compiler/tests/ops_negative.mlir @@ -15,7 +15,7 @@ func.func @main() -> tensor> { %3 = "pphlo.floor"(%2) : (tensor<9x!pphlo.pub>) -> tensor<9x!pphlo.pub> %9 = "pphlo.concatenate"(%3) {dimension = 0 : i64} : (tensor<9x!pphlo.pub>) -> tensor<9x!pphlo.pub> // expected-error @+1 {{broadcast_dimensions contains invalid value 13 for result with rank 1}} - %10 = "pphlo.broadcast"(%9) {broadcast_dimensions = dense<13> : tensor<1xi64>} : (tensor<9x!pphlo.pub>) -> tensor<9x!pphlo.pub> + %10 = "pphlo.broadcast"(%9) {broadcast_dimensions = array} : (tensor<9x!pphlo.pub>) -> tensor<9x!pphlo.pub> %51 = "pphlo.constant"() {value = dense<5> : tensor} : () -> tensor> "pphlo.return"(%51) : (tensor>) -> () } @@ -25,7 +25,7 @@ func.func @main() -> tensor> { func.func @main() -> tensor> { %0 = "pphlo.constant"() {value = dense<[0.000000e+00, -3.40282347E+38]> : 
tensor<2xf32>} : () -> tensor<2x!pphlo.pub> // expected-error @+1 {{op broadcast_dimensions contains invalid value -6 for result with rank 1}} - %1 = "pphlo.broadcast"(%0) {broadcast_dimensions = dense<-6> : tensor<1xi64>} : (tensor<2x!pphlo.pub>) -> tensor<2x!pphlo.pub> + %1 = "pphlo.broadcast"(%0) {broadcast_dimensions = array} : (tensor<2x!pphlo.pub>) -> tensor<2x!pphlo.pub> %2 = "pphlo.constant"() {value = dense<5> : tensor} : () -> tensor> "pphlo.return"(%2) : (tensor>) -> () } @@ -43,7 +43,7 @@ func.func @main() -> tensor> { func.func @main(%arg0: tensor<9x9x1x!pphlo.sec>) -> tensor<9x9x1x!pphlo.sec> { // expected-error @+1 {{op permutation -837266656812241085 out of range [0, 2]}} - %0 = "pphlo.transpose"(%arg0) {permutation = dense<[-837266656812241085, -1986534498277253088, -6908486506403635863]> : tensor<3xi64>} : (tensor<9x9x1x!pphlo.sec>) -> tensor<9x9x1x!pphlo.sec> + %0 = "pphlo.transpose"(%arg0) {permutation = array} : (tensor<9x9x1x!pphlo.sec>) -> tensor<9x9x1x!pphlo.sec> "pphlo.return"(%0) : (tensor<9x9x1x!pphlo.sec>) -> () } @@ -51,7 +51,7 @@ func.func @main(%arg0: tensor<9x9x1x!pphlo.sec>) -> tensor<9x9x1x!pphlo.sec func.func @main(%arg0: tensor<9x1x!pphlo.sec>) -> tensor<9x1x!pphlo.sec> { // expected-error @+1 {{op requires the same element type for all operands and results}} - %0 = "pphlo.transpose"(%arg0) {permutation = dense<[0, 1]> : tensor<2xi64>} : (tensor<9x1x!pphlo.sec>) -> tensor<9x1x!pphlo.sec> + %0 = "pphlo.transpose"(%arg0) {permutation = array} : (tensor<9x1x!pphlo.sec>) -> tensor<9x1x!pphlo.sec> "pphlo.return"(%0) : (tensor<9x1x!pphlo.sec>) -> () } @@ -59,7 +59,7 @@ func.func @main(%arg0: tensor<9x1x!pphlo.sec>) -> tensor<9x1x!pphlo.sec>) -> tensor<9x1x!pphlo.sec> { // expected-error @+1 {{op shape mismatch input shape = 9x1, result shape = 9x1, permutation = 1x0}} - %0 = "pphlo.transpose"(%arg0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<9x1x!pphlo.sec>) -> tensor<9x1x!pphlo.sec> + %0 = "pphlo.transpose"(%arg0) {permutation = array} : (tensor<9x1x!pphlo.sec>) -> tensor<9x1x!pphlo.sec> "pphlo.return"(%0) : (tensor<9x1x!pphlo.sec>) -> () } @@ -67,7 +67,7 @@ func.func @main(%arg0: tensor<9x1x!pphlo.sec>) -> tensor<9x1x!pphlo.sec>) -> tensor<9x9x1x!pphlo.pub> { // expected-error @+1 {{op all dimensions should be non-negative. Got dimension: -1191754011229144205.}} - %0 = "pphlo.reverse"(%arg0) {dimensions = dense<[-4367244339678518167, -1191754011229144205, -977434623931441042]> : tensor<3xi64>} : (tensor<9x9x1x!pphlo.pub>) -> tensor<9x9x1x!pphlo.pub> + %0 = "pphlo.reverse"(%arg0) {dimensions = array} : (tensor<9x9x1x!pphlo.pub>) -> tensor<9x9x1x!pphlo.pub> "pphlo.return"(%0) : (tensor<9x9x1x!pphlo.pub>) -> () } @@ -75,7 +75,7 @@ func.func @main(%arg0: tensor<9x9x1x!pphlo.pub>) -> tensor<9x9x1x!pphlo.pub func.func @main(%arg0: tensor<9x9x1x!pphlo.pub>) -> tensor<9x9x1x!pphlo.pub> { // expected-error @+1 {{op all dimensions should be between [0, 3). 
Got dimension: 4367244339678518167.}} - %0 = "pphlo.reverse"(%arg0) {dimensions = dense<[4367244339678518167, 1191754011229144205, 977434623931441042]> : tensor<3xi64>} : (tensor<9x9x1x!pphlo.pub>) -> tensor<9x9x1x!pphlo.pub> + %0 = "pphlo.reverse"(%arg0) {dimensions = array} : (tensor<9x9x1x!pphlo.pub>) -> tensor<9x9x1x!pphlo.pub> "pphlo.return"(%0) : (tensor<9x9x1x!pphlo.pub>) -> () } @@ -83,15 +83,7 @@ func.func @main(%arg0: tensor<9x9x1x!pphlo.pub>) -> tensor<9x9x1x!pphlo.pub func.func @main(%arg0: tensor<9x9x1x!pphlo.pub>) -> tensor<9x9x1x!pphlo.pub> { // expected-error @+1 {{op dimensions are not unique}} - %0 = "pphlo.reverse"(%arg0) {dimensions = dense<[1,1,1]> : tensor<3xi64>} : (tensor<9x9x1x!pphlo.pub>) -> tensor<9x9x1x!pphlo.pub> - "pphlo.return"(%0) : (tensor<9x9x1x!pphlo.pub>) -> () -} - -// ----- - -func.func @main(%arg0: tensor<9x9x1x!pphlo.pub>) -> tensor<9x9x1x!pphlo.pub> { - // expected-error @+1 {{op dimensions must be a 1-dimensional tensor}} - %0 = "pphlo.reverse"(%arg0) {dimensions = dense<[[1,2],[3,4]]> : tensor<2x2xi64>} : (tensor<9x9x1x!pphlo.pub>) -> tensor<9x9x1x!pphlo.pub> + %0 = "pphlo.reverse"(%arg0) {dimensions = array} : (tensor<9x9x1x!pphlo.pub>) -> tensor<9x9x1x!pphlo.pub> "pphlo.return"(%0) : (tensor<9x9x1x!pphlo.pub>) -> () } @@ -104,22 +96,18 @@ func.func @main(%arg0: tensor<10x!pphlo.pub>) -> (tensor>) ^bb0(%arg1: tensor>, %arg2: tensor>): // no predecessors %2 = "pphlo.add"(%arg1, %arg2) : (tensor>, tensor>) -> tensor> "pphlo.return"(%2) : (tensor>) -> () - }) {dimensions = dense<-12233434> : tensor<1xi64>} : (tensor<10x!pphlo.pub>, tensor>) -> tensor> + }) {dimensions = array} : (tensor<10x!pphlo.pub>, tensor>) -> tensor> return %1 : tensor> } // ----- func.func @main() -> tensor> { - %0 = "pphlo.constant"() {value = dense<127> : tensor} : () -> tensor> - %1 = "pphlo.slice"(%0) {limit_indices = dense<> : tensor<0xi64>, start_indices = dense<> : tensor<0xi64>, strides = dense<> : tensor<0xi64>} : (tensor>) -> tensor> - %2 = "pphlo.slice"(%1) {limit_indices = dense<> : tensor<0xi64>, start_indices = dense<> : tensor<0xi64>, strides = dense<> : tensor<0xi64>} : (tensor>) -> tensor> - %3 = "pphlo.slice"(%0) {limit_indices = dense<> : tensor<0xi64>, start_indices = dense<> : tensor<0xi64>, strides = dense<> : tensor<0xi64>} : (tensor>) -> tensor> - %4 = "pphlo.constant"() {value = dense<-1.7976931344453863E+308> : tensor<1x1xf64>} : () -> tensor<1x1x!pphlo.pub> + %0 = "pphlo.constant"() {value = dense<-1.7976931344453863E+308> : tensor<1x1xf64>} : () -> tensor<1x1x!pphlo.pub> // expected-error @+1 {{op negative start index -9220555925398487041 in dimension 0}} - %5 = "pphlo.slice"(%4) {limit_indices = dense<[-9220555925398487041, 0]> : tensor<2xi64>, start_indices = dense<[-9220555925398487041, 0]> : tensor<2xi64>, strides = dense<[-9220555925398487041, 0]> : tensor<2xi64>} : (tensor<1x1x!pphlo.pub>) -> tensor<1x1x!pphlo.pub> - %6 = "pphlo.slice"(%4) {limit_indices = dense<[-8502447508339815911, -9223371558496411295]> : tensor<2xi64>, start_indices = dense<[-8502447508339815911, -9223371558496411295]> : tensor<2xi64>, strides = dense<[-8502447508339815911, -9223371558496411295]> : tensor<2xi64>} : (tensor<1x1x!pphlo.pub>) -> tensor<1x1x!pphlo.pub> - %7 = "pphlo.constant"() {value = dense<5> : tensor} : () -> tensor> - "pphlo.return"(%7) : (tensor>) -> () + %1 = "pphlo.slice"(%0) {limit_indices = array, start_indices = array, strides = array} : (tensor<1x1x!pphlo.pub>) -> tensor<1x1x!pphlo.pub> + %2 = "pphlo.slice"(%0) {limit_indices = array, 
start_indices = array, strides = array} : (tensor<1x1x!pphlo.pub>) -> tensor<1x1x!pphlo.pub> + %3 = "pphlo.constant"() {value = dense<5> : tensor} : () -> tensor> + "pphlo.return"(%3) : (tensor>) -> () } diff --git a/libspu/compiler/tests/optimize_denominator_with_bcst.mlir b/libspu/compiler/tests/optimize_denominator_with_bcst.mlir index 47c9cd53..ee856fec 100644 --- a/libspu/compiler/tests/optimize_denominator_with_bcst.mlir +++ b/libspu/compiler/tests/optimize_denominator_with_bcst.mlir @@ -5,7 +5,7 @@ func.func @main(%arg0: tensor<16x!pphlo.sec>, %arg1: tensor<16x10000x!pphlo //CHECK: %1 = "pphlo.broadcast"(%0) //CHECK: %2 = "pphlo.multiply"(%arg1, %1) //CHECK: return %2 - %0 = "pphlo.broadcast"(%arg0) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<16x!pphlo.sec>) -> tensor<16x10000x!pphlo.sec> + %0 = "pphlo.broadcast"(%arg0) {broadcast_dimensions = array} : (tensor<16x!pphlo.sec>) -> tensor<16x10000x!pphlo.sec> %1 = "pphlo.divide"(%arg1, %0) : (tensor<16x10000x!pphlo.sec>, tensor<16x10000x!pphlo.sec>) -> tensor<16x10000x!pphlo.sec> return %1 : tensor<16x10000x!pphlo.sec> } diff --git a/libspu/compiler/tests/optimize_maxpool.mlir b/libspu/compiler/tests/optimize_maxpool.mlir index 3f37bf80..75fa8b0d 100644 --- a/libspu/compiler/tests/optimize_maxpool.mlir +++ b/libspu/compiler/tests/optimize_maxpool.mlir @@ -10,7 +10,7 @@ func.func @main(%arg0: tensor<129x24x24x16x!pphlo.sec>, %arg1: tensor<129x2 ^bb0(%arg2: tensor>, %arg3: tensor>): %6 = "pphlo.maximum"(%arg2, %arg3) : (tensor>, tensor>) -> tensor> "pphlo.return"(%6) : (tensor>) -> () - }) {base_dilations = dense<1> : tensor<4xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<1> : tensor<4xi64>} : (tensor<129x24x24x16x!pphlo.sec>, tensor>) -> tensor<129x23x23x16x!pphlo.sec> + }) {base_dilations = array, window_dilations = array, window_dimensions = array, window_strides = array} : (tensor<129x24x24x16x!pphlo.sec>, tensor>) -> tensor<129x23x23x16x!pphlo.sec> //CHECK-NOT: pphlo.select_and_scatter //CHECK : pphlo.maxpool_scatter %5 = "pphlo.select_and_scatter"(%arg0, %arg1, %3) ({ @@ -21,7 +21,7 @@ func.func @main(%arg0: tensor<129x24x24x16x!pphlo.sec>, %arg1: tensor<129x2 ^bb0(%arg2: tensor>, %arg3: tensor>): %6 = "pphlo.add"(%arg2, %arg3) : (tensor>, tensor>) -> tensor> "pphlo.return"(%6) : (tensor>) -> () - }) {window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<1> : tensor<4xi64>} : (tensor<129x24x24x16x!pphlo.sec>, tensor<129x23x23x16x!pphlo.sec>, tensor>) -> tensor<129x24x24x16x!pphlo.sec> + }) {window_dimensions = array, window_strides = array} : (tensor<129x24x24x16x!pphlo.sec>, tensor<129x23x23x16x!pphlo.sec>, tensor>) -> tensor<129x24x24x16x!pphlo.sec> return %4, %5 : tensor<129x23x23x16x!pphlo.sec>, tensor<129x24x24x16x!pphlo.sec> } @@ -42,7 +42,7 @@ func.func @main(%arg0: tensor<128x2x2x256x!pphlo.sec>, %arg1: tensor<128x1x ^bb0(%arg2: tensor>, %arg3: tensor>): %5 = "pphlo.add"(%arg2, %arg3) : (tensor>, tensor>) -> tensor> "pphlo.return"(%5) : (tensor>) -> () - }) {window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<1> : tensor<4xi64>} : (tensor<128x2x2x256x!pphlo.sec>, tensor<128x1x1x256x!pphlo.sec>, tensor>) -> tensor<128x2x2x256x!pphlo.sec> + }) {window_dimensions = array, window_strides = array} : (tensor<128x2x2x256x!pphlo.sec>, tensor<128x1x1x256x!pphlo.sec>, tensor>) -> tensor<128x2x2x256x!pphlo.sec> return %3, %4 : tensor<128x2x2x256x!pphlo.sec>, 
tensor<128x2x2x256x!pphlo.sec> } diff --git a/libspu/compiler/tests/pphlo_type_inference_reduce.mlir b/libspu/compiler/tests/pphlo_type_inference_reduce.mlir index c687fa2c..4b33127b 100644 --- a/libspu/compiler/tests/pphlo_type_inference_reduce.mlir +++ b/libspu/compiler/tests/pphlo_type_inference_reduce.mlir @@ -8,11 +8,11 @@ func.func @main(%arg1: tensor<1024x1xf32>) -> (tensor<1024xf32>) { // CHECK: ^bb0(%arg1: tensor>, %arg2: tensor>): // CHECK: %3 = "pphlo.add"(%arg1, %arg2) : (tensor>, tensor>) -> tensor> // CHECK: "pphlo.return"(%3) : (tensor>) -> () - // CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1024x1x!pphlo.sec>, tensor>) -> tensor<1024x!pphlo.sec> + // CHECK: }) {dimensions = array} : (tensor<1024x1x!pphlo.sec>, tensor>) -> tensor<1024x!pphlo.sec> %1 = "stablehlo.reduce"(%arg1, %0) ( { ^bb0(%arg2: tensor, %arg3: tensor): // no predecessors %2 = "stablehlo.add"(%arg2, %arg3) : (tensor, tensor) -> tensor "stablehlo.return"(%2) : (tensor) -> () - }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1024x1xf32>, tensor) -> tensor<1024xf32> + }) {dimensions = dense<[1]> : tensor<1xi64>} : (tensor<1024x1xf32>, tensor) -> tensor<1024xf32> return %1 : tensor<1024xf32> } diff --git a/libspu/compiler/tools/BUILD.bazel b/libspu/compiler/tools/BUILD.bazel index 4a370f64..152fa442 100644 --- a/libspu/compiler/tools/BUILD.bazel +++ b/libspu/compiler/tools/BUILD.bazel @@ -34,3 +34,19 @@ spu_cc_binary( "@xla//xla/mlir_hlo:mhlo_passes", ], ) + +spu_cc_binary( + name = "mlir-pphlo-lsp", + srcs = [ + "mlir-pphlo-lsp.cc", + ], + deps = [ + "//libspu/dialect:pphlo_dialect", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncExtensions", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:MlirLspServerLib", + "@stablehlo//:stablehlo_ops", + "@xla//xla/mlir_hlo", + ], +) diff --git a/libspu/compiler/tools/mlir-pphlo-lsp.cc b/libspu/compiler/tools/mlir-pphlo-lsp.cc new file mode 100644 index 00000000..c1960aee --- /dev/null +++ b/libspu/compiler/tools/mlir-pphlo-lsp.cc @@ -0,0 +1,29 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "mlir/Dialect/Func/Extensions/InlinerExtension.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Tools/mlir-lsp-server/MlirLspServerMain.h" +#include "stablehlo/dialect/StablehloOps.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" + +#include "libspu/dialect/pphlo_dialect.h" + +int main(int argc, char **argv) { + mlir::DialectRegistry registry; + registry.insert(); + mlir::func::registerInlinerExtension(registry); + return mlir::failed(mlir::MlirLspServerMain(argc, argv, registry)); +} diff --git a/libspu/core/context.h b/libspu/core/context.h index 1f9fe489..23ecb652 100644 --- a/libspu/core/context.h +++ b/libspu/core/context.h @@ -93,7 +93,12 @@ class KernelEvalContext final { Type, // type of type uint128_t, // ring constant int64_t, // - SignType // + SignType, // + std::vector, // + Axes, // + Index, // + Strides, // + Sizes // >; SPUContext* sctx_; diff --git a/libspu/core/object.h b/libspu/core/object.h index e987b8a1..a1aa3c03 100644 --- a/libspu/core/object.h +++ b/libspu/core/object.h @@ -115,6 +115,12 @@ class Object final { regKernel(KernelT::kBindName, std::make_unique()); } + template + void regKernel() { + regKernel(); + regKernel(); + } + template void regKernel(const std::string& name) { return regKernel(name, std::make_unique()); diff --git a/libspu/core/shape.h b/libspu/core/shape.h index 439f25bc..2263a2cb 100644 --- a/libspu/core/shape.h +++ b/libspu/core/shape.h @@ -78,6 +78,9 @@ class Index : public std::vector { public: using Base::Base; + /*explicit*/ Index(llvm::ArrayRef arr) + : Base(arr.begin(), arr.end()) {} + /// Checks if an element `e` at kth axis of `this` object follows /// `0 <= e <= bounds[k]`. bool inBounds(const Shape &bounds) const; @@ -119,6 +122,9 @@ class Sizes : public std::vector { public: using Base::Base; + /*explicit*/ Sizes(llvm::ArrayRef arr) + : Base(arr.begin(), arr.end()) {} + friend std::ostream &operator<<(std::ostream &out, const Sizes &s) { out << fmt::format("{}", fmt::join(s, "x")); return out; diff --git a/libspu/core/type.h b/libspu/core/type.h index c8f13203..cb8367ca 100644 --- a/libspu/core/type.h +++ b/libspu/core/type.h @@ -138,6 +138,16 @@ class BShare { void setNbits(size_t nbits) { nbits_ = nbits; } }; +// Permutation share, a secret permutation can be a composition of a series of +// individual permutations hold by different parties. Each individual +// permutation is represented as a PShare in SPU. PShare is a secret type. +// We use the letter m for naming PShare values in order to be distinguished +// from public values. +class PShare { + public: + virtual ~PShare() = default; +}; + //////////////////////////////////////////////////////////////////////////// // Type interfaces end. 
//////////////////////////////////////////////////////////////////////////// diff --git a/libspu/core/value.cc b/libspu/core/value.cc index 360ba623..cc327985 100644 --- a/libspu/core/value.cc +++ b/libspu/core/value.cc @@ -209,4 +209,9 @@ std::ostream& operator<<(std::ostream& out, const Value& v) { return out; } +std::ostream& operator<<(std::ostream& out, const std::vector& v) { + out << fmt::format("{}", fmt::join(v, ",")); + return out; +} + } // namespace spu diff --git a/libspu/core/value.h b/libspu/core/value.h index 90043164..75c10f78 100644 --- a/libspu/core/value.h +++ b/libspu/core/value.h @@ -145,6 +145,7 @@ struct SimdTrait { }; std::ostream& operator<<(std::ostream& out, const Value& v); +std::ostream& operator<<(std::ostream& out, const std::vector& v); inline auto format_as(const Value& v) { return fmt::streamed(v); } diff --git a/libspu/cuda_support/BUILD.bazel b/libspu/cuda_support/BUILD.bazel index 7e413426..f85ccc36 100644 --- a/libspu/cuda_support/BUILD.bazel +++ b/libspu/cuda_support/BUILD.bazel @@ -11,19 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# Copyright 2023 Ant Group Co., Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
load("@rules_cuda//cuda:defs.bzl", "cuda_library") load("//bazel:spu.bzl", "spu_cc_test") diff --git a/libspu/device/pphlo/pphlo_executor.cc b/libspu/device/pphlo/pphlo_executor.cc index 488f6dd8..6ed81a4d 100644 --- a/libspu/device/pphlo/pphlo_executor.cc +++ b/libspu/device/pphlo/pphlo_executor.cc @@ -42,15 +42,6 @@ namespace { -template -void convertDenseIntElementAttr(const mlir::DenseIntElementsAttr &attr, - T &out) { - out.clear(); - for (const auto &v : attr.getValues()) { - out.emplace_back(v); - } -} - template std::string mlirObjectToString(T &&mlir_obj) { std::string buf; @@ -206,7 +197,7 @@ void do_type_checker(mlir::Value key, const spu::Value &val, if (tool.isMPCType(mlir_type)) { SPU_ENFORCE(val.isPublic()); } else if (tool.isMPCType(mlir_type)) { - SPU_ENFORCE(val.isSecret()); + SPU_ENFORCE(val.isSecret() || val.isPrivate()); } else { SPU_ENFORCE("Unknown vtype"); } @@ -428,10 +419,7 @@ void execute(OpExecutor *, SPUContext *sctx, SymbolScope *sscope, Strides window_strides(dnums.getInputSpatialDimensions().size(), 1); if (op.getWindowStrides().has_value()) { - for (const auto &iter : llvm::enumerate( - op.getWindowStrides()->getValues())) { // NOLINT - window_strides[iter.index()] = iter.value(); - } + window_strides = *op.getWindowStrides(); } kernel::hlo::ConvolutionConfig config; @@ -480,8 +468,7 @@ void execute(OpExecutor *, SPUContext *sctx, SymbolScope *sscope, void execute(OpExecutor *, SPUContext *sctx, SymbolScope *sscope, mlir::pphlo::DynamicSliceOp &op, const ExecutionOptions &opts) { // Start indices - auto iter = op.getSliceSizes().getValues(); - Sizes slice_size{iter.begin(), iter.end()}; + Sizes slice_size{op.getSliceSizes().begin(), op.getSliceSizes().end()}; const auto &operand = lookupValue(sscope, op.getOperand(), opts); std::vector start_indices(op.getStartIndices().size()); @@ -510,9 +497,9 @@ void execute(OpExecutor *, SPUContext *sctx, SymbolScope *sscope, const auto &dim_numbers = op.getDimensionNumbers(); kernel::hlo::GatherConfig config; - Sizes ss; - convertDenseIntElementAttr(op.getSliceSizes(), ss); - config.sliceSizes = ss; + // Sizes ss; + // convertDenseIntElementAttr(op.getSliceSizes(), ss); + config.sliceSizes = op.getSliceSizes(); config.indexVectorDim = dim_numbers.getIndexVectorDim(); config.offsetDims = dim_numbers.getOffsetDims(); config.collapsedSliceDims = dim_numbers.getCollapsedSliceDims(); @@ -603,14 +590,13 @@ void execute(OpExecutor *executor, SPUContext *sctx, SymbolScope *sscope, auto source = lookupValue(sscope, op.getSource(), opts); auto init_val = lookupValue(sscope, op.getInitValue(), opts); - Shape window_shape; - convertDenseIntElementAttr(op.getWindowDimensions(), window_shape); + Shape window_shape(op.getWindowDimensions().begin(), + op.getWindowDimensions().end()); // build strides Strides window_strides(window_shape.size(), 1); if (op.getWindowStrides().has_value()) { - convertDenseIntElementAttr(*op.getWindowStrides(), // NOLINT - window_strides); + window_strides = *op.getWindowStrides(); } // window padding @@ -639,14 +625,13 @@ void execute(OpExecutor *, SPUContext *sctx, SymbolScope *sscope, auto scatter_indices = lookupValue(sscope, op.getScatterIndices(), opts); auto update = lookupValue(sscope, op.getUpdate(), opts); - Shape window_shape; - convertDenseIntElementAttr(op.getWindowDimensions().value(), window_shape); + Shape window_shape(op.getWindowDimensions().begin(), + op.getWindowDimensions().end()); // build strides Strides window_strides(window_shape.size(), 1); if (op.getWindowStrides().has_value()) { 
- convertDenseIntElementAttr(*op.getWindowStrides(), // NOLINT - window_strides); + window_strides = *op.getWindowStrides(); } // window padding @@ -764,8 +749,7 @@ void execute(OpExecutor *, SPUContext *sctx, SymbolScope *sscope, void execute(OpExecutor *, SPUContext *sctx, SymbolScope *sscope, mlir::pphlo::TransposeOp &op, const ExecutionOptions &opts) { - Axes permu; - convertDenseIntElementAttr(op.getPermutation(), permu); + Axes permu = op.getPermutation(); addValue(sscope, op.getResult(), kernel::hlo::Transpose( @@ -776,8 +760,7 @@ void execute(OpExecutor *, SPUContext *sctx, SymbolScope *sscope, void execute(OpExecutor *, SPUContext *sctx, SymbolScope *sscope, mlir::pphlo::BroadcastOp &op, const ExecutionOptions &opts) { auto to_shape = op.getType().dyn_cast().getShape(); - Axes in_dims; - convertDenseIntElementAttr(op.getBroadcastDimensions(), in_dims); + Axes in_dims = op.getBroadcastDimensions(); addValue( sscope, op.getResult(), kernel::hlo::Broadcast(sctx, lookupValue(sscope, op.getOperand(), opts), @@ -809,12 +792,9 @@ void execute(OpExecutor *, SPUContext *sctx, SymbolScope *sscope, void execute(OpExecutor *, SPUContext *sctx, SymbolScope *sscope, mlir::pphlo::SliceOp &op, const ExecutionOptions &opts) { - Index start; - Index end; - Strides s; - convertDenseIntElementAttr(op.getStartIndices(), start); - convertDenseIntElementAttr(op.getLimitIndices(), end); - convertDenseIntElementAttr(op.getStrides(), s); + Index start = op.getStartIndices(); + Index end = op.getLimitIndices(); + Strides s = op.getStrides(); addValue(sscope, op.getResult(), kernel::hlo::Slice(sctx, lookupValue(sscope, op.getOperand(), opts), start, end, s), @@ -828,16 +808,13 @@ void execute(OpExecutor *, SPUContext *sctx, SymbolScope *sscope, const auto &padding_value = lookupValue(sscope, op.getPaddingValue(), opts); SPU_ENFORCE(padding_value.shape().isScalar()); - Sizes edge_padding_low; - convertDenseIntElementAttr(op.getEdgePaddingLow(), edge_padding_low); + Sizes edge_padding_low = op.getEdgePaddingLow(); SPU_ENFORCE(edge_padding_low.size() == operand_rank); - Sizes edge_padding_high; - convertDenseIntElementAttr(op.getEdgePaddingHigh(), edge_padding_high); + Sizes edge_padding_high = op.getEdgePaddingHigh(); SPU_ENFORCE(edge_padding_high.size() == operand_rank); - Sizes interior_padding; - convertDenseIntElementAttr(op.getInteriorPadding(), interior_padding); + Sizes interior_padding = op.getInteriorPadding(); SPU_ENFORCE(interior_padding.size() == operand_rank); SPU_ENFORCE(std::all_of(interior_padding.begin(), interior_padding.end(), [](int64_t i) { return i >= 0; })); @@ -850,8 +827,7 @@ void execute(OpExecutor *, SPUContext *sctx, SymbolScope *sscope, void execute(OpExecutor *, SPUContext *sctx, SymbolScope *sscope, mlir::pphlo::ReverseOp &op, const ExecutionOptions &opts) { - Axes dims; - convertDenseIntElementAttr(op.getDimensions(), dims); + Axes dims = op.getDimensions(); addValue(sscope, op.getResult(), kernel::hlo::Reverse( sctx, lookupValue(sscope, op.getOperand(), opts), dims), @@ -861,8 +837,7 @@ void execute(OpExecutor *, SPUContext *sctx, SymbolScope *sscope, void execute(OpExecutor *executor, SPUContext *sctx, SymbolScope *sscope, mlir::pphlo::ReduceOp &op, const ExecutionOptions &opts) { int64_t num_args = op->getNumOperands() / 2; - Axes dimensions_to_reduce; - convertDenseIntElementAttr(op.getDimensions(), dimensions_to_reduce); + Axes dimensions_to_reduce = op.getDimensions(); std::vector input_args(num_args); std::vector init_values(num_args); @@ -910,21 +885,18 @@ void 
execute(OpExecutor *executor, SPUContext *sctx, SymbolScope *sscope, .getType() .dyn_cast() .getShape(); - Shape window_shape; - convertDenseIntElementAttr(op.getWindowDimensions(), window_shape); + Shape window_shape = op.getWindowDimensions(); // build strides Strides window_strides(window_shape.size(), 1); if (op.getWindowStrides().has_value()) { - convertDenseIntElementAttr(*op.getWindowStrides(), // NOLINT - window_strides); + window_strides = *op.getWindowStrides(); } // window dilation Sizes window_dilations(window_shape.size(), 1); if (op.getWindowDilations().has_value()) { - convertDenseIntElementAttr(*op.getWindowDilations(), // NOLINT - window_dilations); + window_dilations = *op.getWindowDilations(); } std::vector> window_padding(window_shape.size(), @@ -957,21 +929,18 @@ void execute(OpExecutor *executor, SPUContext *sctx, SymbolScope *sscope, void execute(OpExecutor *, SPUContext *sctx, SymbolScope *sscope, mlir::pphlo::ArgMaxOp &op, const ExecutionOptions &opts) { - Shape window_shape; - convertDenseIntElementAttr(op.getWindowDimensions(), window_shape); + Shape window_shape = op.getWindowDimensions(); // build strides Strides window_strides(window_shape.size(), 1); if (op.getWindowStrides().has_value()) { - convertDenseIntElementAttr(*op.getWindowStrides(), // NOLINT - window_strides); + window_strides = *op.getWindowStrides(); } // window dilation Sizes window_dilations(window_shape.size(), 1); if (op.getWindowDilations().has_value()) { - convertDenseIntElementAttr(*op.getWindowDilations(), // NOLINT - window_dilations); + window_dilations = *op.getWindowDilations(); } auto ret_shape = op->getResults()[0] diff --git a/libspu/device/pphlo/pphlo_executor_test.cc b/libspu/device/pphlo/pphlo_executor_test.cc index 8a5c0b75..46362969 100644 --- a/libspu/device/pphlo/pphlo_executor_test.cc +++ b/libspu/device/pphlo/pphlo_executor_test.cc @@ -125,7 +125,7 @@ func.func @main() -> tensor> { %2 = "pphlo.constant"() {value = dense<[0x41DA6E5887800000, 0x41C94E3940000000, 0x41C4BD2007000000, 0x41DC95133AC00000, 0x41D1650CEC000000, 0x41C9DF42E7800000, 0x41D46C43B6800000, 0x41C467EE0E800000, 0x41DC705F14400000]> : tensor<9xf64>} : () -> tensor<9x!pphlo.pub> %3 = "pphlo.floor"(%2) : (tensor<9x!pphlo.pub>) -> tensor<9x!pphlo.pub> %9 = "pphlo.concatenate"(%3) {dimension = 0 : i64} : (tensor<9x!pphlo.pub>) -> tensor<9x!pphlo.pub> - %10 = "pphlo.broadcast"(%9) {broadcast_dimensions = dense<13> : tensor<1xi64>} : (tensor<9x!pphlo.pub>) -> tensor<9x!pphlo.pub> + %10 = "pphlo.broadcast"(%9) {broadcast_dimensions = array} : (tensor<9x!pphlo.pub>) -> tensor<9x!pphlo.pub> %51 = "pphlo.constant"() {value = dense<5> : tensor} : () -> tensor> "pphlo.return"(%51) : (tensor>) -> () })"), @@ -203,7 +203,7 @@ TEST_P(ExecutorTest, Slice) { r.run(R"( func.func @main(%arg0: tensor<4x3x!pphlo.pub>) -> (tensor<2x2x!pphlo.pub>) { - %0 = "pphlo.slice"(%arg0) {limit_indices = dense<[4, 3]> : tensor<2xi64>, start_indices = dense<[2, 1]> : tensor<2xi64>, strides = dense<[1, 1]> : tensor<2xi64>} : (tensor<4x3x!pphlo.pub>) -> tensor<2x2x!pphlo.pub> + %0 = "pphlo.slice"(%arg0) {limit_indices = array, start_indices = array, strides = array} : (tensor<4x3x!pphlo.pub>) -> tensor<2x2x!pphlo.pub> return %0 : tensor<2x2x!pphlo.pub> })"); @@ -222,7 +222,7 @@ TEST_P(ExecutorTest, SliceStride) { r.run(R"( func.func @main(%arg0: tensor<4x6x!pphlo.pub>) -> (tensor<2x3x!pphlo.pub>) { - %0 = "pphlo.slice"(%arg0) {limit_indices = dense<[4, 6]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<[2, 2]> 
: tensor<2xi64>} : (tensor<4x6x!pphlo.pub>) -> tensor<2x3x!pphlo.pub> + %0 = "pphlo.slice"(%arg0) {limit_indices = array, start_indices = array, strides = array} : (tensor<4x6x!pphlo.pub>) -> tensor<2x3x!pphlo.pub> return %0 : tensor<2x3x!pphlo.pub> })"); @@ -287,7 +287,7 @@ func.func @main(%arg0: tensor<10x!pphlo.pub>) -> (tensor>) ^bb0(%arg1: tensor>, %arg2: tensor>): // no predecessors %2 = "pphlo.add"(%arg1, %arg2) : (tensor>, tensor>) -> tensor> "pphlo.return"(%2) : (tensor>) -> () - }) {dimensions = dense<0> : tensor<1xi64>} : (tensor<10x!pphlo.pub>, tensor>) -> tensor> + }) {dimensions = array} : (tensor<10x!pphlo.pub>, tensor>) -> tensor> return %1 : tensor> })"); @@ -309,7 +309,7 @@ func.func @main(%arg0: tensor<2x3x!pphlo.pub>) -> (tensor<2x!pphlo.pub ^bb0(%arg1: tensor>, %arg2: tensor>): // no predecessors %2 = "pphlo.add"(%arg1, %arg2) : (tensor>, tensor>) -> tensor> "pphlo.return"(%2) : (tensor>) -> () - }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3x!pphlo.pub>, tensor>) -> tensor<2x!pphlo.pub> + }) {dimensions = array} : (tensor<2x3x!pphlo.pub>, tensor>) -> tensor<2x!pphlo.pub> return %1 : tensor<2x!pphlo.pub> })"); @@ -331,7 +331,7 @@ func.func @main(%arg0: tensor<2x3x!pphlo.pub>) -> (tensor<3x!pphlo.pub ^bb0(%arg1: tensor>, %arg2: tensor>): // no predecessors %2 = "pphlo.add"(%arg1, %arg2) : (tensor>, tensor>) -> tensor> "pphlo.return"(%2) : (tensor>) -> () - }) {dimensions = dense<0> : tensor<1xi64>} : (tensor<2x3x!pphlo.pub>, tensor>) -> tensor<3x!pphlo.pub> + }) {dimensions = array} : (tensor<2x3x!pphlo.pub>, tensor>) -> tensor<3x!pphlo.pub> return %1 : tensor<3x!pphlo.pub> })"); @@ -354,7 +354,7 @@ func.func @main(%arg0: tensor<10x!pphlo.pub>) -> (tensor>, %2 = "pphlo.add"(%arg1, %arg3) : (tensor>, tensor>) -> tensor> %3 = "pphlo.maximum"(%arg2, %arg4) : (tensor>, tensor>) -> tensor> "pphlo.return"(%2, %3) : (tensor>, tensor>) -> () - }) {dimensions = dense<0> : tensor<1xi64>} : (tensor<10x!pphlo.pub>, tensor<10x!pphlo.pub>, tensor>, tensor>) -> (tensor>, tensor>) + }) {dimensions = array} : (tensor<10x!pphlo.pub>, tensor<10x!pphlo.pub>, tensor>, tensor>) -> (tensor>, tensor>) return %1#0, %1#1 : tensor>, tensor> })", 2); @@ -380,7 +380,7 @@ func.func @main(%arg0: tensor<1x10x!pphlo.pub>) -> (tensor<1x!pphlo.pub>, %arg2: tensor>): // no predecessors %2 = "pphlo.maximum"(%arg1, %arg2) : (tensor>, tensor>) -> tensor> "pphlo.return"(%2) : (tensor>) -> () - }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x10x!pphlo.pub>, tensor>) -> tensor<1x!pphlo.pub> + }) {dimensions = array} : (tensor<1x10x!pphlo.pub>, tensor>) -> tensor<1x!pphlo.pub> return %1 : tensor<1x!pphlo.pub> })"); @@ -406,7 +406,7 @@ func.func @main(%arg0: tensor<2x3x4x!pphlo.sec>, %arg1: tensor<2x3x4x!pphlo ^bb0(%arg2: tensor>, %arg3: tensor>): %4 = "pphlo.and"(%arg2, %arg3) : (tensor>, tensor>) -> tensor> "pphlo.return"(%4) : (tensor>) -> () - }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<2x3x4x!pphlo.sec>, tensor>) -> tensor> + }) {dimensions = array} : (tensor<2x3x4x!pphlo.sec>, tensor>) -> tensor> return %3 : tensor> })"); @@ -431,7 +431,7 @@ func.func @main(%arg0: tensor<4x6x!pphlo.pub>) -> (tensor<2x2x!pphlo.pub>, %arg2: tensor>): // no predecessors %2 = "pphlo.add"(%arg1, %arg2) : (tensor>, tensor>) -> tensor> "pphlo.return"(%2) : (tensor>) -> () - }) {base_dilations = dense<1> : tensor<2xi64>, padding = dense<0> : tensor<2x2xi64>, window_dilations = dense<1> : tensor<2xi64>, window_dimensions = dense<[2,3]> : tensor<2xi64>, window_strides = dense<[2,3]> : tensor<2xi64>} 
: (tensor<4x6x!pphlo.pub>, tensor>) -> tensor<2x2x!pphlo.pub> + }) {base_dilations = array, window_dilations = array, window_dimensions = array, window_strides = array} : (tensor<4x6x!pphlo.pub>, tensor>) -> tensor<2x2x!pphlo.pub> return %1 : tensor<2x2x!pphlo.pub> })"); @@ -515,7 +515,7 @@ func.func @main(%arg0: tensor<4x6x!pphlo.pub>) -> (tensor<3x4x!pphlo.pub>, %arg2: tensor>): // no predecessors %2 = "pphlo.maximum"(%arg1, %arg2) : (tensor>, tensor>) -> tensor> "pphlo.return"(%2) : (tensor>) -> () - }) {base_dilations = dense<1> : tensor<2xi64>, padding = dense<0> : tensor<2x2xi64>, window_dilations = dense<1> : tensor<2xi64>, window_dimensions = dense<[2,3]> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>} : (tensor<4x6x!pphlo.pub>, tensor>) -> tensor<3x4x!pphlo.pub> + }) {base_dilations = array, window_dilations = array, window_dimensions = array, window_strides = array} : (tensor<4x6x!pphlo.pub>, tensor>) -> tensor<3x4x!pphlo.pub> return %1 : tensor<3x4x!pphlo.pub> })"); @@ -540,7 +540,7 @@ func.func @main(%arg0: tensor<4x4x!pphlo.pub>) -> (tensor<2x2x!pphlo.pub>, %arg2: tensor>): // no predecessors %2 = "pphlo.maximum"(%arg1, %arg2) : (tensor>, tensor>) -> tensor> "pphlo.return"(%2) : (tensor>) -> () - }) {base_dilations = dense<1> : tensor<2xi64>, padding = dense<0> : tensor<2x2xi64>, window_dilations = dense<2> : tensor<2xi64>, window_dimensions = dense<2> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>} : (tensor<4x4x!pphlo.pub>, tensor>) -> tensor<2x2x!pphlo.pub> + }) {base_dilations = array, window_dilations = array, window_dimensions = array, window_strides = array} : (tensor<4x4x!pphlo.pub>, tensor>) -> tensor<2x2x!pphlo.pub> return %1 : tensor<2x2x!pphlo.pub> })"); @@ -564,7 +564,7 @@ func.func @main(%arg0: tensor<4x4x!pphlo.pub>) -> (tensor<1x1x!pphlo.pub>, %arg2: tensor>): // no predecessors %2 = "pphlo.maximum"(%arg1, %arg2) : (tensor>, tensor>) -> tensor> "pphlo.return"(%2) : (tensor>) -> () - }) {base_dilations = dense<1> : tensor<2xi64>, padding = dense<0> : tensor<2x2xi64>, window_dilations = dense<2> : tensor<2xi64>, window_dimensions = dense<2> : tensor<2xi64>, window_strides = dense<2> : tensor<2xi64>} : (tensor<4x4x!pphlo.pub>, tensor>) -> tensor<1x1x!pphlo.pub> + }) {base_dilations = array, window_dilations = array, window_dimensions = array, window_strides = array} : (tensor<4x4x!pphlo.pub>, tensor>) -> tensor<1x1x!pphlo.pub> return %1 : tensor<1x1x!pphlo.pub> })"); @@ -1261,7 +1261,7 @@ TEST_P(ExecutorTest, DynamicSlice1D) { r.run(R"( func.func @main(%arg0: tensor<5x!pphlo.pub>, %arg1: tensor>) -> tensor<2x!pphlo.pub> { - %0 = "pphlo.dynamic-slice"(%arg0, %arg1) {slice_sizes = dense<2> : tensor<1xi64>} : (tensor<5x!pphlo.pub>, tensor>) -> tensor<2x!pphlo.pub> + %0 = "pphlo.dynamic-slice"(%arg0, %arg1) {slice_sizes = array} : (tensor<5x!pphlo.pub>, tensor>) -> tensor<2x!pphlo.pub> return %0 : tensor<2x!pphlo.pub> })"); @@ -1284,7 +1284,7 @@ TEST_P(ExecutorTest, DynamicSlice2D) { r.run(R"( func.func @main(%arg0: tensor<4x3x!pphlo.pub>, %arg1: tensor>, %arg2: tensor>) -> tensor<2x2x!pphlo.pub> { - %0 = "pphlo.dynamic-slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[2, 2]> : tensor<2xi64>} : (tensor<4x3x!pphlo.pub>, tensor>, tensor>) -> tensor<2x2x!pphlo.pub> + %0 = "pphlo.dynamic-slice"(%arg0, %arg1, %arg2) {slice_sizes = array} : (tensor<4x3x!pphlo.pub>, tensor>, tensor>) -> tensor<2x2x!pphlo.pub> return %0 : tensor<2x2x!pphlo.pub> })"); @@ -1761,6 +1761,286 @@ func.func @main(%arg0: tensor<10x!pphlo.pub>, %arg1: tensor<10x!pphlo.sec key0 
= {10, 10, 10, 10, 10, 10, 10, 10, + -10, -10, -10, -10, -10, -10, -10, -10}; + xt::xarray key1 = {-3.0, -3.0, -3.0, -3.0, -3.0, -3.0, -3.0, -1.0, + 1.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0}; + xt::xarray key2 = {-10, -10, -10, -10, -10, 8, 9, 6, + 7, 5, 4, 10, 10, 10, 10, 10}; + xt::xarray key3 = {4.0, 4.0, 4.0, 4.0, -4.0, -3.0, -2.0, -1.0, + 0.0, 1.0, 2.0, 3.0, 4.0, 4.0, 4.0, 4.0}; + xt::xarray key4 = {-10, -10, -10, 1, 2, 3, 4, 5, + 6, 7, 8, 9, 9, 10, 10, 10}; + xt::xarray key5 = {10, 10, -1, -2, -3, -4, -5, -6, + 6, 5, 4, 3, 2, 1, 10, 10}; + xt::xarray key6 = {10, 9, -1, -2, -3, -4, -5, -6, + 6, 5, 4, 3, 2, 1, 9, 10}; + xt::xarray val = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + + // ascending + xt::xarray expected_key0_asc = {-10, -10, -10, -10, -10, -10, -10, -10, + 10, 10, 10, 10, 10, 10, 10, 10}; + xt::xarray expected_key1_asc = {1.0, 3.0, 3.0, 3.0, 3.0, 3.0, + 3.0, 3.0, -3.0, -3.0, -3.0, -3.0, + -3.0, -3.0, -3.0, -1.0}; + xt::xarray expected_key2_asc = {7, 4, 5, 10, 10, 10, 10, 10, + -10, -10, -10, -10, -10, 8, 9, 6}; + xt::xarray expected_key3_asc = {0.0, 2.0, 1.0, 3.0, 4.0, 4.0, + 4.0, 4.0, -4.0, 4.0, 4.0, 4.0, + 4.0, -3.0, -2.0, -1.0}; + xt::xarray expected_key4_asc = {6, 8, 7, 9, 9, 10, 10, 10, + 2, -10, -10, -10, 1, 3, 4, 5}; + xt::xarray expected_key5_asc = {6, 4, 5, 3, 2, 1, 10, 10, + -3, -1, 10, 10, -2, -4, -5, -6}; + xt::xarray expected_key6_asc = {6, 4, 5, 3, 2, 1, 9, 10, + -3, -1, 9, 10, -2, -4, -5, -6}; + xt::xarray expected_val_asc = {9, 11, 10, 12, 13, 14, 15, 16, + 5, 3, 2, 1, 4, 6, 7, 8}; + + // descending + xt::xarray expected_key0_des = {10, 10, 10, 10, 10, 10, 10, 10, + -10, -10, -10, -10, -10, -10, -10, -10}; + xt::xarray expected_key1_des = {-1.0, -3.0, -3.0, -3.0, -3.0, -3.0, + -3.0, -3.0, 3.0, 3.0, 3.0, 3.0, + 3.0, 3.0, 3.0, 1.0}; + xt::xarray expected_key2_des = {6, 9, 8, -10, -10, -10, -10, -10, + 10, 10, 10, 10, 10, 5, 4, 7}; + xt::xarray expected_key3_des = {-1.0, -2.0, -3.0, 4.0, 4.0, 4.0, + 4.0, -4.0, 4.0, 4.0, 4.0, 4.0, + 3.0, 1.0, 2.0, 0.0}; + xt::xarray expected_key4_des = {5, 4, 3, 1, -10, -10, -10, 2, + 10, 10, 10, 9, 9, 7, 8, 6}; + xt::xarray expected_key5_des = {-6, -5, -4, -2, 10, 10, -1, -3, + 10, 10, 1, 2, 3, 5, 4, 6}; + xt::xarray expected_key6_des = {-6, -5, -4, -2, 10, 9, -1, -3, + 10, 9, 1, 2, 3, 5, 4, 6}; + xt::xarray expected_val_des = {8, 7, 6, 4, 1, 2, 3, 5, + 16, 15, 14, 13, 12, 10, 11, 9}; + + auto VERIFY_RESULTS = [&](Runner &r, bool is_ascending) { + if (is_ascending) { + r.verifyOutput(expected_key0_asc.data(), 0); + r.verifyOutput(expected_key1_asc.data(), 1); + r.verifyOutput(expected_key2_asc.data(), 2); + r.verifyOutput(expected_key3_asc.data(), 3); + r.verifyOutput(expected_key4_asc.data(), 4); + r.verifyOutput(expected_key5_asc.data(), 5); + r.verifyOutput(expected_key6_asc.data(), 6); + r.verifyOutput(expected_val_asc.data(), 7); + } else { + r.verifyOutput(expected_key0_des.data(), 0); + r.verifyOutput(expected_key1_des.data(), 1); + r.verifyOutput(expected_key2_des.data(), 2); + r.verifyOutput(expected_key3_des.data(), 3); + r.verifyOutput(expected_key4_des.data(), 4); + r.verifyOutput(expected_key5_des.data(), 5); + r.verifyOutput(expected_key6_des.data(), 6); + r.verifyOutput(expected_val_des.data(), 7); + } + }; + + // ascending direction + { + Runner r(std::get<0>(GetParam()), std::get<1>(GetParam()), + std::get<2>(GetParam())); + + r.addInput(key0); + r.addInput(key1); + r.addInput(key2); + r.addInput(key3); + r.addInput(key4); + r.addInput(key5); + r.addInput(key6); + r.addInput(val); + + // all 
public + r.run(R"( +func.func @main(%arg0: tensor<16x!pphlo.pub>, %arg1: tensor<16x!pphlo.pub>, %arg2: tensor<16x!pphlo.pub>, %arg3: tensor<16x!pphlo.pub>, %arg4: tensor<16x!pphlo.pub>, %arg5: tensor<16x!pphlo.pub>, %arg6: tensor<16x!pphlo.pub>, %arg7: tensor<16x!pphlo.pub>) -> (tensor<16x!pphlo.pub>, tensor<16x!pphlo.pub>, tensor<16x!pphlo.pub>, tensor<16x!pphlo.pub>, tensor<16x!pphlo.pub>, tensor<16x!pphlo.pub>,tensor<16x!pphlo.pub>, tensor<16x!pphlo.pub>) { + %0:8 = "pphlo.simple_sort"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) {dimension = 0 : i64, num_keys = 7 : i64, sort_direction = 0 : i32} : (tensor<16x!pphlo.pub>, tensor<16x!pphlo.pub>, tensor<16x!pphlo.pub>, tensor<16x!pphlo.pub>, tensor<16x!pphlo.pub>, tensor<16x!pphlo.pub>,tensor<16x!pphlo.pub>, tensor<16x!pphlo.pub>) -> (tensor<16x!pphlo.pub>, tensor<16x!pphlo.pub>, tensor<16x!pphlo.pub>, tensor<16x!pphlo.pub>, tensor<16x!pphlo.pub>, tensor<16x!pphlo.pub>,tensor<16x!pphlo.pub>, tensor<16x!pphlo.pub>) + return %0#0, %0#1, %0#2, %0#3, %0#4, %0#5, %0#6, %0#7 : tensor<16x!pphlo.pub>, tensor<16x!pphlo.pub>, tensor<16x!pphlo.pub>, tensor<16x!pphlo.pub>, tensor<16x!pphlo.pub>, tensor<16x!pphlo.pub>,tensor<16x!pphlo.pub>, tensor<16x!pphlo.pub> +})", + 8); + + VERIFY_RESULTS(r, true); + } + + // descending direction + { + Runner r(std::get<0>(GetParam()), std::get<1>(GetParam()), + std::get<2>(GetParam())); + + r.addInput(key0); + r.addInput(key1); + r.addInput(key2); + r.addInput(key3); + r.addInput(key4); + r.addInput(key5); + r.addInput(key6); + r.addInput(val); + + // all public + r.run(R"( +func.func @main(%arg0: tensor<16x!pphlo.pub>, %arg1: tensor<16x!pphlo.pub>, %arg2: tensor<16x!pphlo.pub>, %arg3: tensor<16x!pphlo.pub>, %arg4: tensor<16x!pphlo.pub>, %arg5: tensor<16x!pphlo.pub>, %arg6: tensor<16x!pphlo.pub>, %arg7: tensor<16x!pphlo.pub>) -> (tensor<16x!pphlo.pub>, tensor<16x!pphlo.pub>, tensor<16x!pphlo.pub>, tensor<16x!pphlo.pub>, tensor<16x!pphlo.pub>, tensor<16x!pphlo.pub>,tensor<16x!pphlo.pub>, tensor<16x!pphlo.pub>) { + %0:8 = "pphlo.simple_sort"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) {dimension = 0 : i64, num_keys = 7 : i64, sort_direction = 1 : i32} : (tensor<16x!pphlo.pub>, tensor<16x!pphlo.pub>, tensor<16x!pphlo.pub>, tensor<16x!pphlo.pub>, tensor<16x!pphlo.pub>, tensor<16x!pphlo.pub>,tensor<16x!pphlo.pub>, tensor<16x!pphlo.pub>) -> (tensor<16x!pphlo.pub>, tensor<16x!pphlo.pub>, tensor<16x!pphlo.pub>, tensor<16x!pphlo.pub>, tensor<16x!pphlo.pub>, tensor<16x!pphlo.pub>,tensor<16x!pphlo.pub>, tensor<16x!pphlo.pub>) + return %0#0, %0#1, %0#2, %0#3, %0#4, %0#5, %0#6, %0#7 : tensor<16x!pphlo.pub>, tensor<16x!pphlo.pub>, tensor<16x!pphlo.pub>, tensor<16x!pphlo.pub>, tensor<16x!pphlo.pub>, tensor<16x!pphlo.pub>,tensor<16x!pphlo.pub>, tensor<16x!pphlo.pub> +})", + 8); + + VERIFY_RESULTS(r, false); + } + + // ascending direction + { + Runner r(std::get<0>(GetParam()), std::get<1>(GetParam()), + std::get<2>(GetParam())); + + r.addInput(key0, VIS_SECRET, 0); + r.addInput(key1, VIS_SECRET, 0); + r.addInput(key2, VIS_SECRET, 0); + r.addInput(key3, VIS_SECRET, 0); + r.addInput(key4, VIS_SECRET, 1); + r.addInput(key5, VIS_SECRET, 1); + r.addInput(key6, VIS_SECRET, 1); + r.addInput(val, VIS_SECRET); + + // all private + r.run(R"( +func.func @main(%arg0: tensor<16x!pphlo.sec>, %arg1: tensor<16x!pphlo.sec>, %arg2: tensor<16x!pphlo.sec>, %arg3: tensor<16x!pphlo.sec>, %arg4: tensor<16x!pphlo.sec>, %arg5: tensor<16x!pphlo.sec>, %arg6: tensor<16x!pphlo.sec>, %arg7: tensor<16x!pphlo.sec>) -> 
(tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>,tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>) { + %0:8 = "pphlo.simple_sort"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) {dimension = 0 : i64, num_keys = 7 : i64, sort_direction = 0 : i32} : (tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>,tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>) -> (tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>,tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>) + return %0#0, %0#1, %0#2, %0#3, %0#4, %0#5, %0#6, %0#7 : tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>,tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec> +})", + 8); + + VERIFY_RESULTS(r, true); + } + + // descending direction + { + Runner r(std::get<0>(GetParam()), std::get<1>(GetParam()), + std::get<2>(GetParam())); + + r.addInput(key0, VIS_SECRET, 0); + r.addInput(key1, VIS_SECRET, 0); + r.addInput(key2, VIS_SECRET, 0); + r.addInput(key3, VIS_SECRET, 0); + r.addInput(key4, VIS_SECRET, 1); + r.addInput(key5, VIS_SECRET, 1); + r.addInput(key6, VIS_SECRET, 1); + r.addInput(val, VIS_SECRET); + + // all private + r.run(R"( +func.func @main(%arg0: tensor<16x!pphlo.sec>, %arg1: tensor<16x!pphlo.sec>, %arg2: tensor<16x!pphlo.sec>, %arg3: tensor<16x!pphlo.sec>, %arg4: tensor<16x!pphlo.sec>, %arg5: tensor<16x!pphlo.sec>, %arg6: tensor<16x!pphlo.sec>, %arg7: tensor<16x!pphlo.sec>) -> (tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>,tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>) { + %0:8 = "pphlo.simple_sort"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) {dimension = 0 : i64, num_keys = 7 : i64, sort_direction = 1 : i32} : (tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>,tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>) -> (tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>,tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>) + return %0#0, %0#1, %0#2, %0#3, %0#4, %0#5, %0#6, %0#7 : tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>,tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec> +})", + 8); + + VERIFY_RESULTS(r, false); + } + + // ascending direction + { + Runner r(std::get<0>(GetParam()), std::get<1>(GetParam()), + std::get<2>(GetParam())); + + r.addInput(key0, VIS_SECRET, 0); + r.addInput(key1, VIS_PUBLIC, 0); + r.addInput(key2, VIS_SECRET, 0); + r.addInput(key3, VIS_SECRET, 0); + r.addInput(key4, VIS_SECRET, 1); + r.addInput(key5, VIS_PUBLIC, 1); + r.addInput(key6, VIS_SECRET, 1); + r.addInput(val, VIS_SECRET); + + // mixed visibility + r.run(R"( +func.func @main(%arg0: tensor<16x!pphlo.sec>, %arg1: tensor<16x!pphlo.pub>, %arg2: tensor<16x!pphlo.sec>, %arg3: tensor<16x!pphlo.sec>, %arg4: tensor<16x!pphlo.sec>, %arg5: tensor<16x!pphlo.pub>, %arg6: tensor<16x!pphlo.sec>, %arg7: tensor<16x!pphlo.sec>) -> (tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>,tensor<16x!pphlo.sec>, 
tensor<16x!pphlo.sec>) { + %0:8 = "pphlo.simple_sort"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) {dimension = 0 : i64, num_keys = 7 : i64, sort_direction = 0 : i32} : (tensor<16x!pphlo.sec>, tensor<16x!pphlo.pub>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.pub>,tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>) -> (tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>,tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>) + return %0#0, %0#1, %0#2, %0#3, %0#4, %0#5, %0#6, %0#7 : tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>,tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec> +})", + 8); + + VERIFY_RESULTS(r, true); + } + + // descending direction + { + Runner r(std::get<0>(GetParam()), std::get<1>(GetParam()), + std::get<2>(GetParam())); + + r.addInput(key0, VIS_SECRET, 0); + r.addInput(key1, VIS_PUBLIC, 0); + r.addInput(key2, VIS_SECRET, 0); + r.addInput(key3, VIS_SECRET, 0); + r.addInput(key4, VIS_SECRET, 1); + r.addInput(key5, VIS_PUBLIC, 1); + r.addInput(key6, VIS_SECRET, 1); + r.addInput(val, VIS_SECRET); + + // mixed visibility + r.run(R"( +func.func @main(%arg0: tensor<16x!pphlo.sec>, %arg1: tensor<16x!pphlo.pub>, %arg2: tensor<16x!pphlo.sec>, %arg3: tensor<16x!pphlo.sec>, %arg4: tensor<16x!pphlo.sec>, %arg5: tensor<16x!pphlo.pub>, %arg6: tensor<16x!pphlo.sec>, %arg7: tensor<16x!pphlo.sec>) -> (tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>,tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>) { + %0:8 = "pphlo.simple_sort"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) {dimension = 0 : i64, num_keys = 7 : i64, sort_direction = 1 : i32} : (tensor<16x!pphlo.sec>, tensor<16x!pphlo.pub>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.pub>,tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>) -> (tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>,tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>) + return %0#0, %0#1, %0#2, %0#3, %0#4, %0#5, %0#6, %0#7 : tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>,tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec> +})", + 8); + + VERIFY_RESULTS(r, false); + } + + // ascending direction + { + Runner r(std::get<0>(GetParam()), std::get<1>(GetParam()), + std::get<2>(GetParam())); + + r.addInput(key0, VIS_SECRET, 0); + r.addInput(key1, VIS_PUBLIC, 0); + r.addInput(key2, VIS_SECRET); + r.addInput(key3, VIS_SECRET, 0); + r.addInput(key4, VIS_SECRET, 1); + r.addInput(key5, VIS_PUBLIC, 1); + r.addInput(key6, VIS_SECRET); + r.addInput(val, VIS_SECRET); + + // mixed visibility + r.run(R"( +func.func @main(%arg0: tensor<16x!pphlo.sec>, %arg1: tensor<16x!pphlo.pub>, %arg2: tensor<16x!pphlo.sec>, %arg3: tensor<16x!pphlo.sec>, %arg4: tensor<16x!pphlo.sec>, %arg5: tensor<16x!pphlo.pub>, %arg6: tensor<16x!pphlo.sec>, %arg7: tensor<16x!pphlo.sec>) -> (tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>,tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>) { + %0:8 = "pphlo.simple_sort"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) {dimension = 0 : i64, num_keys = 7 : i64, sort_direction = 0 : 
i32} : (tensor<16x!pphlo.sec>, tensor<16x!pphlo.pub>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.pub>,tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>) -> (tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>,tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>) + return %0#0, %0#1, %0#2, %0#3, %0#4, %0#5, %0#6, %0#7 : tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>,tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec> +})", + 8); + + VERIFY_RESULTS(r, true); + } + + // descending direction + { + Runner r(std::get<0>(GetParam()), std::get<1>(GetParam()), + std::get<2>(GetParam())); + + r.addInput(key0, VIS_SECRET, 0); + r.addInput(key1, VIS_PUBLIC, 0); + r.addInput(key2, VIS_SECRET); + r.addInput(key3, VIS_SECRET, 0); + r.addInput(key4, VIS_SECRET, 1); + r.addInput(key5, VIS_PUBLIC, 1); + r.addInput(key6, VIS_SECRET); + r.addInput(val, VIS_SECRET); + + // mixed visibility + r.run(R"( +func.func @main(%arg0: tensor<16x!pphlo.sec>, %arg1: tensor<16x!pphlo.pub>, %arg2: tensor<16x!pphlo.sec>, %arg3: tensor<16x!pphlo.sec>, %arg4: tensor<16x!pphlo.sec>, %arg5: tensor<16x!pphlo.pub>, %arg6: tensor<16x!pphlo.sec>, %arg7: tensor<16x!pphlo.sec>) -> (tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>,tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>) { + %0:8 = "pphlo.simple_sort"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) {dimension = 0 : i64, num_keys = 7 : i64, sort_direction = 1 : i32} : (tensor<16x!pphlo.sec>, tensor<16x!pphlo.pub>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.pub>,tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>) -> (tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>,tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>) + return %0#0, %0#1, %0#2, %0#3, %0#4, %0#5, %0#6, %0#7 : tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec>,tensor<16x!pphlo.sec>, tensor<16x!pphlo.sec> +})", + 8); + + VERIFY_RESULTS(r, false); + } +} + TEST_P(ExecutorTest, SortComplicatedComparator) { xt::xarray x = {3, 1, 4, 2}; xt::xarray y = {42, 50, 49, 47}; @@ -2050,7 +2330,7 @@ func.func @main(%arg0: tensor<4x6x!pphlo.pub>, %arg1: tensor<2x2x!pphlo.pub ^bb0(%arg3: tensor>, %arg4: tensor>): %1 = "pphlo.add"(%arg3, %arg4) : (tensor>, tensor>) -> tensor> "pphlo.return"(%1) : (tensor>) -> () - }) {padding = dense<0> : tensor<2x2xi64>, window_dimensions = dense<[2,3]> : tensor<2xi64>, window_strides = dense<[2,3]> : tensor<2xi64>} : (tensor<4x6x!pphlo.pub>, tensor<2x2x!pphlo.pub>, tensor>) -> tensor<4x6x!pphlo.pub> + }) {window_dimensions = array, window_strides = array} : (tensor<4x6x!pphlo.pub>, tensor<2x2x!pphlo.pub>, tensor>) -> tensor<4x6x!pphlo.pub> return %0 : tensor<4x6x!pphlo.pub> })"); @@ -2088,7 +2368,7 @@ func.func @main(%arg0: tensor<4x5x!pphlo.pub>, %arg1: tensor<2x2x!pphlo.pub ^bb0(%arg3: tensor>, %arg4: tensor>): %1 = "pphlo.add"(%arg3, %arg4) : (tensor>, tensor>) -> tensor> "pphlo.return"(%1) : (tensor>) -> () - }) {padding = dense<0> : tensor<2x2xi64>, window_dimensions = dense<[2,3]> : tensor<2xi64>, window_strides = dense<[2,2]> : tensor<2xi64>} : (tensor<4x5x!pphlo.pub>, tensor<2x2x!pphlo.pub>, tensor>) -> 
tensor<4x5x!pphlo.pub> + }) {window_dimensions = array, window_strides = array} : (tensor<4x5x!pphlo.pub>, tensor<2x2x!pphlo.pub>, tensor>) -> tensor<4x5x!pphlo.pub> return %0 : tensor<4x5x!pphlo.pub> })"); @@ -2114,7 +2394,7 @@ TEST_P(ExecutorTest, MaxPoolScatter1) { r.run(R"( func.func @main(%arg0: tensor<2x2x6x!pphlo.pub>, %arg1: tensor<2x2x!pphlo.pub>) -> (tensor<4x6x!pphlo.pub>) { - %0 = "pphlo.maxpool_scatter"(%arg0, %arg1) {padding = dense<0> : tensor<2x2xi64>, window_dimensions = dense<[2,3]> : tensor<2xi64>, window_strides = dense<[2,3]> : tensor<2xi64>} : (tensor<2x2x6x!pphlo.pub>, tensor<2x2x!pphlo.pub>) -> tensor<4x6x!pphlo.pub> + %0 = "pphlo.maxpool_scatter"(%arg0, %arg1) {window_dimensions = array, window_strides = array} : (tensor<2x2x6x!pphlo.pub>, tensor<2x2x!pphlo.pub>) -> tensor<4x6x!pphlo.pub> return %0 : tensor<4x6x!pphlo.pub> })"); @@ -2141,7 +2421,7 @@ TEST_P(ExecutorTest, MaxPoolScatter2) { r.run(R"( func.func @main(%arg0: tensor<2x2x6x!pphlo.pub>, %arg1: tensor<2x2x!pphlo.pub>) -> (tensor<4x5x!pphlo.pub>) { - %0 = "pphlo.maxpool_scatter"(%arg0, %arg1) {padding = dense<0> : tensor<2x2xi64>, window_dimensions = dense<[2,3]> : tensor<2xi64>, window_strides = dense<[2,2]> : tensor<2xi64>} : (tensor<2x2x6x!pphlo.pub>, tensor<2x2x!pphlo.pub>) -> tensor<4x5x!pphlo.pub> + %0 = "pphlo.maxpool_scatter"(%arg0, %arg1) {window_dimensions = array, window_strides = array} : (tensor<2x2x6x!pphlo.pub>, tensor<2x2x!pphlo.pub>) -> tensor<4x5x!pphlo.pub> return %0 : tensor<4x5x!pphlo.pub> })"); @@ -2165,7 +2445,7 @@ TEST_P(ExecutorTest, MaxPoolReduce1) { r.run(R"( func.func @main(%arg0: tensor<4x6x!pphlo.pub>) -> (tensor<2x2x!pphlo.pub>, tensor<2x2x6x!pphlo.pub>) { - %4:2 = "pphlo.argmax"(%arg0) {base_dilations = dense<1> : tensor<4xi64>, padding = dense<0> : tensor<2x2xi64>, window_dilations = dense<1> : tensor<2xi64>, window_dimensions = dense<[2, 3]> : tensor<2xi64>, window_strides = dense<[2, 3]> : tensor<2xi64>} : (tensor<4x6x!pphlo.pub>) -> (tensor<2x2x!pphlo.pub>, tensor<2x2x6x!pphlo.pub>) + %4:2 = "pphlo.argmax"(%arg0) {window_dilations = array, window_dimensions = array, window_strides = array} : (tensor<4x6x!pphlo.pub>) -> (tensor<2x2x!pphlo.pub>, tensor<2x2x6x!pphlo.pub>) return %4#0, %4#1: tensor<2x2x!pphlo.pub>, tensor<2x2x6x!pphlo.pub> })", 2); @@ -2195,7 +2475,7 @@ TEST_P(ExecutorTest, MaxPoolReduce2) { r.run(R"( func.func @main(%arg0: tensor<4x5x!pphlo.pub>) -> (tensor<2x2x!pphlo.pub>, tensor<2x2x6x!pphlo.pub>) { - %4:2 = "pphlo.argmax"(%arg0) {base_dilations = dense<1> : tensor<4xi64>, padding = dense<0> : tensor<2x2xi64>, window_dilations = dense<1> : tensor<2xi64>, window_dimensions = dense<[2, 3]> : tensor<2xi64>, window_strides = dense<[2, 2]> : tensor<2xi64>} : (tensor<4x5x!pphlo.pub>) -> (tensor<2x2x!pphlo.pub>, tensor<2x2x6x!pphlo.pub>) + %4:2 = "pphlo.argmax"(%arg0) {window_dilations = array, window_dimensions = array, window_strides = array} : (tensor<4x5x!pphlo.pub>) -> (tensor<2x2x!pphlo.pub>, tensor<2x2x6x!pphlo.pub>) return %4#0, %4#1: tensor<2x2x!pphlo.pub>, tensor<2x2x6x!pphlo.pub> })", 2); @@ -2296,8 +2576,8 @@ TEST_P(ExecutorTest, MaxPoolReduce3) { r.run(R"( func.func @main(%arg0: tensor<1x4x4x1x!pphlo.pub>, %arg1: tensor<1x3x3x1x!pphlo.pub>) -> (tensor<1x3x3x1x!pphlo.pub>, tensor<1x3x3x1x4x!pphlo.pub>, tensor<1x4x4x1x!pphlo.pub>) { - %0:2 = "pphlo.argmax"(%arg0) {base_dilations = dense<1> : tensor<4xi64>, padding = dense<0> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, 
window_strides = dense<1> : tensor<4xi64>} : (tensor<1x4x4x1x!pphlo.pub>) -> (tensor<1x3x3x1x!pphlo.pub>, tensor<1x3x3x1x4x!pphlo.pub>) - %1 = "pphlo.maxpool_scatter"(%0#1, %arg1) {padding = dense<0> : tensor<4x2xi64>, window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<1> : tensor<4xi64>} : (tensor<1x3x3x1x4x!pphlo.pub>, tensor<1x3x3x1x!pphlo.pub>) -> tensor<1x4x4x1x!pphlo.pub> + %0:2 = "pphlo.argmax"(%arg0) {window_dilations = array, window_dimensions = array, window_strides = array} : (tensor<1x4x4x1x!pphlo.pub>) -> (tensor<1x3x3x1x!pphlo.pub>, tensor<1x3x3x1x4x!pphlo.pub>) + %1 = "pphlo.maxpool_scatter"(%0#1, %arg1) {window_dimensions = array, window_strides = array} : (tensor<1x3x3x1x4x!pphlo.pub>, tensor<1x3x3x1x!pphlo.pub>) -> tensor<1x4x4x1x!pphlo.pub> return %0#0, %0#1, %1: tensor<1x3x3x1x!pphlo.pub>, tensor<1x3x3x1x4x!pphlo.pub>, tensor<1x4x4x1x!pphlo.pub> })", 3); diff --git a/libspu/device/pphlo/pphlo_executor_test_runner.cc b/libspu/device/pphlo/pphlo_executor_test_runner.cc index 7c770cb2..be485319 100644 --- a/libspu/device/pphlo/pphlo_executor_test_runner.cc +++ b/libspu/device/pphlo/pphlo_executor_test_runner.cc @@ -28,6 +28,7 @@ Runner::Runner(size_t world_size, FieldType field, ProtocolKind protocol) config_.set_field(field); config_.set_protocol(protocol); config_.set_enable_type_checker(true); + config_.set_experimental_enable_colocated_optimization(true); io_ = std::make_unique(world_size_, config_); } diff --git a/libspu/device/pphlo/pphlo_executor_test_runner.h b/libspu/device/pphlo/pphlo_executor_test_runner.h index 0fb976f1..9263cc87 100644 --- a/libspu/device/pphlo/pphlo_executor_test_runner.h +++ b/libspu/device/pphlo/pphlo_executor_test_runner.h @@ -35,9 +35,10 @@ class Runner { auto &getConfig() { return config_; } template - void addInput(const T &input, Visibility vis = Visibility::VIS_PUBLIC) { + void addInput(const T &input, Visibility vis = Visibility::VIS_PUBLIC, + int owner_rank = -1) { const std::string name = fmt::format("input{}", input_idx_++); - io_->InFeed(name, input, vis); + io_->InFeed(name, input, vis, owner_rank); executable_.add_input_names(name); } diff --git a/libspu/device/pphlo/pphlo_verifier_test.cc b/libspu/device/pphlo/pphlo_verifier_test.cc index dcf0df75..e9e0f587 100644 --- a/libspu/device/pphlo/pphlo_verifier_test.cc +++ b/libspu/device/pphlo/pphlo_verifier_test.cc @@ -301,7 +301,7 @@ TEST(Verify, Clamp) { TEST(Verify, DynamicSlice) { std::string mlir = R"( func.func @main(%arg0: tensor<5x!pphlo.pub>, %arg1: tensor>) -> tensor<2x!pphlo.pub> { - %0 = "pphlo.dynamic-slice"(%arg0, %arg1) {slice_sizes = dense<2> : tensor<1xi64>} : (tensor<5x!pphlo.pub>, tensor>) -> tensor<2x!pphlo.pub> + %0 = "pphlo.dynamic-slice"(%arg0, %arg1) {slice_sizes = array} : (tensor<5x!pphlo.pub>, tensor>) -> tensor<2x!pphlo.pub> return %0 : tensor<2x!pphlo.pub> })"; diff --git a/libspu/device/test_utils.h b/libspu/device/test_utils.h index f52ed909..2e8e852b 100644 --- a/libspu/device/test_utils.h +++ b/libspu/device/test_utils.h @@ -33,8 +33,9 @@ class LocalIo { // } - void InFeed(const std::string &name, PtBufferView view, Visibility vtype) { - auto shares = io_client_.makeShares(view, vtype); + void InFeed(const std::string &name, PtBufferView view, Visibility vtype, + int owner_rank = -1) { + auto shares = io_client_.makeShares(view, vtype, owner_rank); SPU_ENFORCE(shares.size() == symbol_tables_.size()); for (size_t idx = 0; idx < symbol_tables_.size(); ++idx) { symbol_tables_[idx].setVar(name, shares[idx]); diff --git 
a/libspu/dialect/pphlo_ops.cc b/libspu/dialect/pphlo_ops.cc index 8785c8ff..0dd64ca0 100644 --- a/libspu/dialect/pphlo_ops.cc +++ b/libspu/dialect/pphlo_ops.cc @@ -120,7 +120,7 @@ class TransposeReshapeGenericDotGeneral } return b.create<TransposeOp>( loc, RankedTensorType::get(transposeShape, type.getElementType()), src, - b.getI64TensorAttr(target_order)); + target_order); } static Value ReshapeIfMorethan3D(OpBuilder& b, Location loc, Value src, @@ -275,12 +275,6 @@ std::vector<int64_t> InversePermutation( return output_permutation; } -mlir::DenseIntElementsAttr ConvertToDenseIntElementAttr( - OpBuilder* builder, llvm::ArrayRef<int64_t> value) { - return DenseIntElementsAttr::get( - RankedTensorType::get(value.size(), builder->getIntegerType(64)), value); -} - bool IsSameShape(llvm::ArrayRef<int64_t> lhs, llvm::ArrayRef<int64_t> rhs) { if (lhs.size() != rhs.size()) { return false; } @@ -358,8 +352,7 @@ class NormalizeDimensionOrder : public OpRewritePattern<ConvolutionOp> { auto new_input_type = RankedTensorType::get(new_input_dims, input_type.getElementType()); new_input = rewriter.create<TransposeOp>( - op->getLoc(), new_input_type, input, - ConvertToDenseIntElementAttr(&rewriter, new_input_dim_order)); + op->getLoc(), new_input_type, input, new_input_dim_order); } auto kernel = op.getRhs(); @@ -383,9 +376,8 @@ class NormalizeDimensionOrder : public OpRewritePattern<ConvolutionOp> { if (needTranspose(kernel_shape, new_kernel_dims, new_kernel_dim_order)) { auto new_kernel_type = RankedTensorType::get(new_kernel_dims, kernel_type.getElementType()); - new_kernel = rewriter.create<TransposeOp>( - op->getLoc(), new_kernel_type, kernel, - ConvertToDenseIntElementAttr(&rewriter, new_kernel_dim_order)); + new_kernel = rewriter.create<TransposeOp>(op->getLoc(), new_kernel_type, + kernel, new_kernel_dim_order); } if (input == new_input && kernel == new_kernel) { @@ -413,9 +405,9 @@ auto new_conv_type = RankedTensorType::get(new_conv_dims, result_type.getElementType()); - std::vector<int64_t> input_sd(num_spatial_dims); - std::vector<int64_t> kernel_sd(num_spatial_dims); - std::vector<int64_t> output_sd(num_spatial_dims); + llvm::SmallVector<int64_t> input_sd(num_spatial_dims); + llvm::SmallVector<int64_t> kernel_sd(num_spatial_dims); + llvm::SmallVector<int64_t> output_sd(num_spatial_dims); for (int64_t i = 0; i < num_spatial_dims; ++i) { input_sd[i] = i + 1; @@ -432,14 +424,14 @@ // example, input height and width are the same as before the reshapes. auto new_conv = rewriter.create<ConvolutionOp>( op->getLoc(), new_conv_type, new_input, new_kernel, - op.getWindowStrides().value_or(nullptr), new_dnums, - op.getFeatureGroupCount(), op.getBatchGroupCount()); + DenseI64ArrayAttr::get(op->getContext(), + op.getWindowStrides().value_or(std::nullopt)), + new_dnums, op.getFeatureGroupCount(), op.getBatchGroupCount()); // Reshape the output back to the shape of the original convolution. rewriter.replaceOpWithNewOp<TransposeOp>( op, op->getResultTypes()[0], new_conv, - ConvertToDenseIntElementAttr(&rewriter, - InversePermutation(new_output_dim_order))); + InversePermutation(new_output_dim_order)); return success(); } @@ -452,16 +444,15 @@ OpFoldResult ReverseOp::fold(FoldAdaptor) { // No dimensions to reverse. auto dims = getDimensions(); - if (dims.getNumElements() == 0) { + if (dims.empty()) { return input; } // If the dimensions to reverse are all statically 1, then the reverse is a // no-op.
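// [Editor's sketch] Concretely, with an illustrative operand type:
//
//   %r = "pphlo.reverse"(%x) {dimensions = array<i64: 0>}
//        : (tensor<1x5x!pphlo.pub<f32>>) -> tensor<1x5x!pphlo.pub<f32>>
//
// dimension 0 has static size 1, so reversing it moves no elements and the
// op folds to %x.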
auto shapedType = input.getType().cast<ShapedType>(); - if (llvm::all_of(dims.getValues<int64_t>(), [&](int64_t dim) { - return shapedType.getDimSize(dim) == 1; - })) { + if (llvm::all_of( + dims, [&](int64_t dim) { return shapedType.getDimSize(dim) == 1; })) { return input; } return {}; } @@ -477,13 +468,8 @@ LogicalResult ReverseOp::verify() { return emitOpError("operand and result type mismatch"); } - // dimensions: 1-dimensional tensor constant of type si64 - if (getDimensions().getType().getRank() != 1) { - return emitOpError("dimensions must be a 1-dimensional tensor"); - } - //(C2) All dimensions in dimensions are unique. - auto dims = getDimensions().getValues<int64_t>(); + auto dims = getDimensions(); llvm::SmallDenseSet<int64_t> unique_dims(dims.begin(), dims.end()); if (unique_dims.size() != dims.size()) { @@ -550,8 +536,9 @@ OpFoldResult ReciprocalOp::fold(FoldAdaptor operands) { LogicalResult verifyReduceOpInputsAndInferShape( std::optional<Location> location, SmallVector<ShapedType> inputArgTypes, - SmallVector<ShapedType> /*initValueTypes*/, DenseIntElementsAttr dimensions, - SmallVector<int64_t>& /*newDimensions*/, Attribute& /*encoding*/) { + SmallVector<ShapedType> /*initValueTypes*/, + llvm::ArrayRef<int64_t> dimensions, SmallVector<int64_t>& /*newDimensions*/, + Attribute& /*encoding*/) { uint64_t numInputs = inputArgTypes.size(); for (uint64_t inputIdx = 0; inputIdx < numInputs; ++inputIdx) { @@ -565,7 +552,7 @@ LogicalResult verifyReduceOpInputsAndInferShape( } DenseSet<int64_t> dimensionsToReduceSet; - for (int64_t dimension : dimensions.getValues<int64_t>()) { + for (int64_t dimension : dimensions) { if ((dimension >= inputArgTypes[0].getRank()) || dimension < 0) { return emitOptionalError( location, "Out-of-bounds dimension ", dimension, @@ -762,8 +749,8 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor) { } OpFoldResult TransposeOp::fold(FoldAdaptor) { - for (const auto& it : llvm::enumerate(getPermutation().getValues<int64_t>())) { - if (it.index() != it.value()) { + for (const auto& it : llvm::enumerate(getPermutation())) { + if (static_cast<int64_t>(it.index()) != it.value()) { return {}; } } @@ -784,7 +771,7 @@ LogicalResult TransposeOp::verify() { // (C2) permutation is a permutation of [0, 1, ..., R-1] where R is the rank // of operand.
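// [Editor's sketch] For a rank-3 operand, permutation = [2, 0, 1] passes the
// range check below, while a hypothetical permutation = [0, 1, 3] is
// rejected because 3 lies outside [0, 2].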
auto max_rank = inputType.getRank(); - auto permutation = getPermutation().getValues(); + auto permutation = getPermutation(); for (auto p : permutation) { if (p < 0 || p > max_rank - 1) { return emitOpError(llvm::formatv("permutation {0} out of range [0, {1}]", @@ -918,7 +905,7 @@ LogicalResult BroadcastOp::verify() { auto operandRank = operandType.getRank(); - if (!getBroadcastDimensions()) { + if (getBroadcastDimensions().empty()) { if (operandRank == 0) { return success(); } @@ -928,22 +915,14 @@ LogicalResult BroadcastOp::verify() { operandRank)); } - auto dimensionsType = getBroadcastDimensions().getType(); - auto dimensionsRank = dimensionsType.getRank(); - if (dimensionsRank != 1) { - return emitOpError(llvm::formatv( - "broadcast_dimensions has rank {0} instead of rank 1", dimensionsRank)); - } - - auto dimensionsSize = dimensionsType.getNumElements(); - if (dimensionsSize != operandRank) { + auto dimensionsSize = getBroadcastDimensions().size(); + if (static_cast(dimensionsSize) != operandRank) { return emitOpError(llvm::formatv( "broadcast_dimensions size ({0}) does not match operand rank ({1})", dimensionsSize, operandRank)); } - auto dimensions = - llvm::to_vector(getBroadcastDimensions().getValues()); + auto dimensions = getBroadcastDimensions(); if (hasDuplicates(dimensions)) { return emitOpError("broadcast_dimensions should not have duplicates"); } @@ -951,7 +930,7 @@ LogicalResult BroadcastOp::verify() { auto resultType = getResult().getType().cast(); auto resultRank = resultType.getRank(); - for (int i = 0; i != dimensionsSize; ++i) { + for (size_t i = 0; i != dimensionsSize; ++i) { auto dimIndex = dimensions[i]; if ((dimIndex >= resultRank) || (dimIndex < 0)) { return emitOpError( @@ -995,26 +974,19 @@ LogicalResult IotaOp::verify() { LogicalResult SliceOp::verify() { auto rankedTy = getOperand().getType(); - // slice_i2 - ShapedType attrTy = getStartIndices().getType(); - if (attrTy.getRank() != 1) { - return emitOpError( - llvm::formatv("start_indices has rank {0} instead of required rank 1", - attrTy.getRank())); - } // slice_c2 int64_t rank = rankedTy.getRank(); - if (attrTy.getNumElements() != rank) { + if (static_cast(getStartIndices().size()) != rank) { return emitOpError( llvm::formatv("the number of elements in start_indices ({0}) does not " "match the rank of the operand ({1})", - attrTy.getNumElements(), rank)); + getStartIndices().size(), rank)); } - auto start = getStartIndices().getValues(); - auto limit = getLimitIndices().getValues(); - auto strideVals = getStrides().getValues(); + auto start = getStartIndices(); + auto limit = getLimitIndices(); + auto strideVals = getStrides(); for (int64_t i = 0, e = rank; i != e; i++) { // slice_c3 @@ -1051,16 +1023,10 @@ LogicalResult SliceOp::verify() { LogicalResult inferDynamicSliceOp(std::optional location, Type operandType, TypeRange startIndicesTypes, - DenseIntElementsAttr sliceSizes, + llvm::ArrayRef sliceSizes, SmallVectorImpl& inferredReturnTypes) { - // dynamic_slice_i3 - if (sliceSizes.getType().getRank() != 1) { - return emitOptionalError(location, - "slice_sizes should be rank 1, but got rank ", - sliceSizes.getType().getRank(), "."); - } // dynamic_slice_c2 - int numSliceSizes = sliceSizes.getNumElements(); + int numSliceSizes = sliceSizes.size(); int numStartIndices = startIndicesTypes.size(); if (numStartIndices != numSliceSizes) { return emitOptionalError(location, "has mismatched number of slice sizes (", @@ -1077,7 +1043,7 @@ LogicalResult inferDynamicSliceOp(std::optional location, // 
dynamic_slice_c4 for (int i = 0; i < numSliceSizes; ++i) { - int64_t sliceSize = sliceSizes.getValues()[i]; + int64_t sliceSize = sliceSizes[i]; if (sliceSize < 0) { return emitOptionalError( location, "has negative size index to dynamic slice: ", sliceSize); @@ -1092,11 +1058,9 @@ LogicalResult inferDynamicSliceOp(std::optional location, } } - std::vector slice_size(sliceSizes.getValues().begin(), - sliceSizes.getValues().end()); // dynamic_slice_c5 inferredReturnTypes.emplace_back( - RankedTensorType::get(slice_size, rankedOperandType.getElementType())); + RankedTensorType::get(sliceSizes, rankedOperandType.getElementType())); return success(); } @@ -1650,57 +1614,23 @@ Attribute ConvDimensionNumbersAttr::parse(AsmParser& parser, Type) { namespace { // Custom formatting for convolution window attributes. -void printWindowAttribute(OpAsmPrinter& p, DenseElementsAttr attribute) { - if (attribute.getElementType().isInteger(/*width=*/1)) { - // boolean attribute. - llvm::interleaveComma(attribute.getValues(), p, - [&](bool b) { p << (b ? 1 : 0); }); - return; - } - if (attribute.getType().getRank() == 2) { - // Padding is Nx2 attribute. - auto it = attribute.value_begin(); - std::vector> values(attribute.getNumElements() / - 2); - for (auto& item : values) { - int64_t first = *it; - ++it; - int64_t second = *it; - ++it; - item = {first, second}; - } - llvm::interleaveComma( - values, p, [&](const std::pair pair) { - p << '[' << pair.first << ", " << pair.second << ']'; - }); - } else { - llvm::interleaveComma(attribute.getValues(), p); - } +void printWindowAttribute(OpAsmPrinter& p, llvm::ArrayRef attribute) { + llvm::interleaveComma(attribute, p); } } // namespace void printWindowAttributes(OpAsmPrinter& p, Operation*, - std::optional window_strides) { - using PairT = std::pair; - std::array printed_attributes = {{ - {window_strides ? *window_strides : nullptr, "stride"}, - }}; - - // Do not print attributes that do no exist. - auto non_null_attributes = llvm::make_filter_range( - printed_attributes, - [](const PairT& a) { return static_cast(a.first); }); - - llvm::interleaveComma(non_null_attributes, p, [&](const PairT& a) { - p << a.second << " = ["; - printWindowAttribute(p, a.first); + std::optional window_strides) { + if (window_strides.has_value()) { + p << "stride = ["; + printWindowAttribute(p, *window_strides); p << "]"; - }); + } } ParseResult parseWindowAttributes(OpAsmParser& parser, - DenseIntElementsAttr& window_strides) { + DenseI64ArrayAttr& window_strides) { StringRef attribute_name; // Helper to parse an array of the form [ e0, e1, .. ] @@ -1754,10 +1684,9 @@ ParseResult parseWindowAttributes(OpAsmParser& parser, if (parse_array(int64_parser)) { return failure(); } - auto attr = parser.getBuilder().getI64TensorAttr(values); if (attribute_name == "stride") { - window_strides = attr; + window_strides = DenseI64ArrayAttr::get(parser.getContext(), values); } else { llvm_unreachable("Unexpected attribute name"); } diff --git a/libspu/dialect/pphlo_ops.td b/libspu/dialect/pphlo_ops.td index c2a6a424..125012b1 100644 --- a/libspu/dialect/pphlo_ops.td +++ b/libspu/dialect/pphlo_ops.td @@ -613,7 +613,7 @@ def PPHLO_BroadcastOp See https://www.tensorflow.org/xla/broadcasting. 
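For example, with the dense-array attribute form (element type chosen for
illustration only):

  %0 = "pphlo.broadcast"(%arg0) {broadcast_dimensions = array<i64: 1>}
       : (tensor<3x!pphlo.pub<i32>>) -> tensor<2x3x!pphlo.pub<i32>>

maps operand dimension 0 to result dimension 1 and replicates the row along
result dimension 0.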
}]; let arguments = (ins PPHLO_Tensor: $operand, - I64ElementsAttr: $broadcast_dimensions); + DenseI64ArrayAttr: $broadcast_dimensions); let results = (outs PPHLO_Tensor); @@ -719,8 +719,8 @@ def PPHLO_SelectAndScatterOp: PPHLO_Op<"select_and_scatter", PPHLO_Tensor:$operand, PPHLO_Tensor:$source, PPHLO_Tensor:$init_value, - I64ElementsAttr:$window_dimensions, - OptionalAttr:$window_strides + DenseI64ArrayAttr:$window_dimensions, + OptionalAttr:$window_strides ); let regions = (region SizedRegion<1>:$select, SizedRegion<1>:$scatter); @@ -740,8 +740,8 @@ def PPHLO_MaxPoolScatterOp: PPHLO_Op<"maxpool_scatter", [Pure]> { let arguments = (ins PPHLO_IntTensor:$scatter_indices, PPHLO_Tensor:$update, - OptionalAttr:$window_dimensions, - OptionalAttr:$window_strides + DenseI64ArrayAttr:$window_dimensions, + OptionalAttr:$window_strides ); let results = (outs PPHLO_Tensor); @@ -761,14 +761,14 @@ def PPHLO_ReduceOp : PPHLO_Op<"reduce", [ let arguments = (ins Variadic:$inputs, Variadic:$init_values, - I64ElementsAttr:$dimensions + DenseI64ArrayAttr:$dimensions ); let results = (outs Variadic); let builders = [ OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$init_values, - "DenseIntElementsAttr":$dimensions)>]; + "DenseI64ArrayAttr":$dimensions)>]; let regions = (region SizedRegion<1> : $body); @@ -790,11 +790,11 @@ def PPHLO_ReduceWindowOp : PPHLO_Op<"reduce_window", [ let arguments = (ins Variadic:$inputs, Variadic:$init_values, - I64ElementsAttr:$window_dimensions, + DenseI64ArrayAttr:$window_dimensions, // If strides or dilations attributes are missing then the default value is // one for each of the input dimensions. - OptionalAttr:$window_strides, - OptionalAttr:$window_dilations + OptionalAttr:$window_strides, + OptionalAttr:$window_dilations ); let results = (outs Variadic); @@ -811,11 +811,11 @@ def PPHLO_ArgMaxOp: PPHLO_Op<"argmax", [Pure]> { let arguments = (ins PPHLO_Tensor:$input, - I64ElementsAttr:$window_dimensions, + DenseI64ArrayAttr:$window_dimensions, // If strides or dilations attributes are missing then the default value is // one for each of the input dimensions. - OptionalAttr:$window_strides, - OptionalAttr:$window_dilations, + OptionalAttr:$window_strides, + OptionalAttr:$window_dilations, DefaultValuedAttr:$onehot_index ); @@ -840,7 +840,7 @@ def PPHLO_TransposeOp See https://www.tensorflow.org/xla/operation_semantics#transpose. }]; - let arguments = (ins PPHLO_Tensor : $operand, I64ElementsAttr : $permutation); + let arguments = (ins PPHLO_Tensor : $operand, DenseI64ArrayAttr : $permutation); let results = (outs PPHLO_Tensor); let hasFolder = 1; @@ -849,7 +849,8 @@ def PPHLO_TransposeOp def PPHLO_SliceOp : PPHLO_Op<"slice", [ Pure, SameOperandsAndResultElementType, - AllTypesMatch<["start_indices", "limit_indices", "strides"]> + AllMatchSameOperatorTrait<["start_indices", "limit_indices", + "strides"], "$_self.size()", "size"> /*slice_c2*/, ]> { let description = [{ The dynamic shape version of SliceOp. Extracts a sub-array from the input @@ -860,9 +861,9 @@ def PPHLO_SliceOp : PPHLO_Op<"slice", [ See https://www.tensorflow.org/xla/operation_semantics#slice }]; let arguments = (ins PPHLO_Tensor - : $operand, I64ElementsAttr - : $start_indices, I64ElementsAttr - : $limit_indices, I64ElementsAttr + : $operand, DenseI64ArrayAttr + : $start_indices, DenseI64ArrayAttr + : $limit_indices, DenseI64ArrayAttr : $strides); let results = (outs PPHLO_Tensor); @@ -964,7 +965,7 @@ def PPHLO_ReverseOp See https://www.tensorflow.org/xla/operation_semantics#rev_reverse. 
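For example (element type chosen for illustration only):

  %0 = "pphlo.reverse"(%arg0) {dimensions = array<i64: 1>}
       : (tensor<2x3x!pphlo.pub<i32>>) -> tensor<2x3x!pphlo.pub<i32>>

reverses each row, i.e. element [i, j] moves to [i, 2 - j].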
}]; - let arguments = (ins PPHLO_Tensor : $operand, I64ElementsAttr : $dimensions); + let arguments = (ins PPHLO_Tensor : $operand, DenseI64ArrayAttr : $dimensions); let results = (outs PPHLO_Tensor); @@ -984,9 +985,9 @@ def HLO_PadOp }]; let arguments = (ins PPHLO_Tensor : $operand, PPHLO_Tensor - : $padding_value, I64ElementsAttr - : $edge_padding_low, I64ElementsAttr - : $edge_padding_high, I64ElementsAttr + : $padding_value, DenseI64ArrayAttr + : $edge_padding_low, DenseI64ArrayAttr + : $edge_padding_high, DenseI64ArrayAttr : $interior_padding); let results = (outs PPHLO_Tensor); @@ -1005,7 +1006,7 @@ def PPHLO_GatherOp : PPHLO_Op<"gather", [Pure]> { PPHLO_Tensor:$operand, PPHLO_IntTensor:$start_indices, GatherDimensionNumbers:$dimension_numbers, - I64ElementsAttr:$slice_sizes, + DenseI64ArrayAttr:$slice_sizes, DefaultValuedAttr:$indices_are_sorted ); @@ -1015,7 +1016,7 @@ def PPHLO_GatherOp : PPHLO_Op<"gather", [Pure]> { def ConvolutionAttributes { dag attributes = (ins // Default value: one for each of the spatial dimension. - OptionalAttr:$window_strides, + OptionalAttr:$window_strides, ConvDimensionNumbers:$dimension_numbers, I64Attr:$feature_group_count, I64Attr:$batch_group_count @@ -1057,7 +1058,7 @@ def PPHLO_DynamicSliceOp: PPHLO_WithShapeInferOp<"dynamic-slice", [Pure]> { let arguments = (ins PPHLO_Tensor:$operand, Variadic:$start_indices, - I64ElementsAttr:$slice_sizes + DenseI64ArrayAttr:$slice_sizes ); let results = (outs PPHLO_Tensor:$result); diff --git a/libspu/kernel/hal/BUILD.bazel b/libspu/kernel/hal/BUILD.bazel index 261dd4c6..09eedf47 100644 --- a/libspu/kernel/hal/BUILD.bazel +++ b/libspu/kernel/hal/BUILD.bazel @@ -33,7 +33,6 @@ spu_cc_library( hdrs = ["ring.h"], deps = [ ":prot_wrapper", - ":shape_ops", "//libspu/core:context", ], ) @@ -255,6 +254,7 @@ spu_cc_library( deps = [ # Please DONT add extra dependency here. 
":prot_wrapper", + ":ring", "//libspu/core:context", ], ) diff --git a/libspu/kernel/hal/debug.cc b/libspu/kernel/hal/debug.cc index f2b1825e..e695b4ec 100644 --- a/libspu/kernel/hal/debug.cc +++ b/libspu/kernel/hal/debug.cc @@ -36,7 +36,7 @@ void dbg_print(SPUContext* ctx, const Value& v) { if ((ctx->lctx() && ctx->lctx()->Rank() == 0) || ctx->lctx() == nullptr) { SPDLOG_INFO("dbg_print {}", ss.str()); } - } else if (v.isSecret()) { + } else if (v.isSecret() || v.isPrivate()) { dbg_print(ctx, reveal(ctx, v)); } else { SPU_THROW("unsupport vtype={}", v.vtype()); diff --git a/libspu/kernel/hal/permute.cc b/libspu/kernel/hal/permute.cc index 681f4c1d..f37a8106 100644 --- a/libspu/kernel/hal/permute.cc +++ b/libspu/kernel/hal/permute.cc @@ -16,7 +16,9 @@ #include +#include "libspu/core/bit_utils.h" #include "libspu/core/context.h" +#include "libspu/core/trace.h" #include "libspu/kernel/hal/constants.h" #include "libspu/kernel/hal/polymorphic.h" #include "libspu/kernel/hal/prot_wrapper.h" @@ -28,10 +30,18 @@ namespace spu::kernel::hal { -namespace { +namespace internal { + +inline int64_t _get_owner(const Value &x) { + return x.storage_type().as()->owner(); +} + +inline bool _has_same_owner(const Value &x, const Value &y) { + return _get_owner(x) == _get_owner(y); +} // generate inverse permutation -Index GenInvPerm(const Index &p) { +Index _inverse_index(const Index &p) { Index q(p.size()); const auto n = static_cast(p.size()); for (int64_t i = 0; i < n; ++i) { @@ -40,7 +50,16 @@ Index GenInvPerm(const Index &p) { return q; } -Value Permute1D(SPUContext *, const Value &x, const Index &indices) { +spu::Value _2s(SPUContext *ctx, const Value &x) { + if (x.isPublic()) { + return _p2s(ctx, x); + } else if (x.isPrivate()) { + return _v2s(ctx, x); + } + return x; +} + +Value _permute_1d(SPUContext *, const Value &x, const Index &indices) { SPU_ENFORCE(x.shape().size() == 1); return Value(x.data().linear_gather(indices), x.dtype()); } @@ -48,7 +67,7 @@ Value Permute1D(SPUContext *, const Value &x, const Index &indices) { // FIXME: move to mpc layer // Vectorized Prefix Sum // Ref: https://en.algorithmica.org/hpc/algorithms/prefix/ -Value PrefixSum(SPUContext *ctx, const Value &x) { +Value _prefix_sum(SPUContext *ctx, const Value &x) { SPU_ENFORCE(x.shape().ndim() == 2U && x.shape()[0] == 1, "x should be 1-row matrix"); @@ -63,9 +82,9 @@ Value PrefixSum(SPUContext *ctx, const Value &x) { return x_t; } -void CmpSwap(SPUContext *ctx, const CompFn &comparator_body, - absl::Span values_to_sort, const Index &lhs_indices, - const Index &rhs_indices) { +void _cmp_swap(SPUContext *ctx, const CompFn &comparator_body, + absl::Span values_to_sort, const Index &lhs_indices, + const Index &rhs_indices) { size_t num_operands = values_to_sort.size(); std::vector values; @@ -95,15 +114,15 @@ void CmpSwap(SPUContext *ctx, const CompFn &comparator_body, // Secure Odd-even mergesort // Ref: // https://hwlang.de/algorithmen/sortieren/networks/oemen.htm -std::vector OddEvenMergeSort(SPUContext *ctx, - const CompFn &comparator_body, - absl::Span inputs) { +std::vector odd_even_merge_sort( + SPUContext *ctx, const CompFn &comparator_body, + absl::Span inputs) { // make a copy for inplace sort std::vector ret; for (auto const &input : inputs) { - if (input.isPublic()) { + if (!input.isSecret()) { // we can not linear_scatter a secret value to a public operand - ret.emplace_back(_p2s(ctx, input.clone()).setDtype(input.dtype())); + ret.emplace_back(_2s(ctx, input.clone()).setDtype(input.dtype())); } else { 
ret.emplace_back(input.clone()); } } @@ -111,7 +130,7 @@ std::vector<spu::Value> OddEvenMergeSort(SPUContext *ctx, // sort by per network layer for memory optimizations, sorting N elements // needs log2(N) stages, and the i_th stage has i layers, which means the - // same latency cost as BitonicSort but less CmpSwap unit. + // same latency cost as BitonicSort but fewer _cmp_swap units. const auto n = inputs.front().numel(); for (int64_t max_gap_in_stage = 1; max_gap_in_stage < n; max_gap_in_stage += max_gap_in_stage) { @@ -136,35 +155,32 @@ std::vector<spu::Value> OddEvenMergeSort(SPUContext *ctx, } } - CmpSwap(ctx, comparator_body, absl::MakeSpan(ret), lhs_indices, - rhs_indices); + _cmp_swap(ctx, comparator_body, absl::MakeSpan(ret), lhs_indices, + rhs_indices); } } return ret; } -// Secure shuffle a shared permutation and use it to permute shared bit -// vectors of x. -// x is a list of shared bit vectors, <perm> is a shared permutation, -// random_perm is a permutation for shuffling <perm>, and m is the -// revealed permutation of shuffled <perm>. +// Ref: https://eprint.iacr.org/2019/695.pdf +// Algorithm 13: Optimized inverse application of a permutation // // The steps are as follows: // 1) secure shuffle <perm> as <sp> // 2) secure shuffle <x> as <sx> // 3) reveal securely shuffled <sp> as m // 4) inverse permute <sx> by m and return -std::pair<std::vector<spu::Value>, spu::Value> ShufflePerm( +std::pair<std::vector<spu::Value>, spu::Value> _opt_apply_inv_perm_ss( SPUContext *ctx, absl::Span<const spu::Value> x, spu::Value perm, spu::Value random_perm) { // 1. <sp> = secure shuffle <perm> - auto sp = _perm_ss(ctx, perm, random_perm); + auto sp = hal::_perm_ss(ctx, perm, random_perm); // 2. <sx> = secure shuffle <x> std::vector<spu::Value> sx; for (size_t i = 0; i < x.size(); ++i) { - sx.emplace_back(_perm_ss(ctx, x[i], random_perm)); + sx.emplace_back(hal::_perm_ss(ctx, x[i], random_perm)); } // 3. M = reveal(<sp>) @@ -175,7 +191,7 @@ std::pair<std::vector<spu::Value>, spu::Value> ShufflePerm( std::vector<spu::Value> v; for (size_t i = 0; i < sx.size(); ++i) { - auto t = _inv_perm_sp(ctx, sx[i], m); + auto t = hal::_inv_perm_sp(ctx, sx[i], m); v.emplace_back(std::move(t)); } @@ -185,14 +201,14 @@ // Process two bit vectors in one loop // Reference: https://eprint.iacr.org/2019/695.pdf (5.2 Optimizations) // -// perm = GenInvPermByTwoBitVectors(x, y) +// perm = _gen_inv_perm_by_bv(x, y) // input: bit vector x, bit vector y // bit vector y is more significant than x // output: shared inverse permutation // // We can generate the inverse permutation from two bit vectors in one loop. // It needs one more mul op and twice the memory for intermediate data -// than GenInvPermByBitVector. But the number of invocations of +// compared with _gen_inv_perm_by_bv. But the number of invocations of // permutation-related protocols such as SecureInvPerm or Compose will be // reduced to half.
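// [Editor's sketch] In the clear, the inverse permutation that stably sorts
// a 0/1 key (zeros first) takes one counting pass, which is what the
// f / prefix-sum / fs construction below computes obliviously (standalone,
// illustrative):
//
//   #include <algorithm>
//   #include <cstdint>
//   #include <vector>
//   std::vector<int64_t> inv_perm_of_bits(const std::vector<int64_t> &x) {
//     int64_t zeros = std::count(x.begin(), x.end(), int64_t(0));
//     std::vector<int64_t> res(x.size());
//     int64_t c0 = 0, c1 = 0;
//     for (size_t i = 0; i < x.size(); ++i)
//       res[i] = x[i] ? zeros + c1++ : c0++;  // res[i] = target slot of i
//     return res;
//   }
//
// e.g. a reconstructed input x = [1, 0, 1, 0, 0] yields res = [3, 0, 4, 1, 2],
// the tail values visible in the single-bit example below; the two-bit-vector
// variant performs the same radix step for two key bits at once.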
// @@ -223,8 +239,8 @@ std::pair<std::vector<spu::Value>, spu::Value> ShufflePerm( // r = [2, 1] // 8) get res by sub r by one // res = [1, 0] -spu::Value GenInvPermByTwoBitVectors(SPUContext *ctx, const spu::Value &x, - const spu::Value &y) { +spu::Value _gen_inv_perm_by_bv(SPUContext *ctx, const spu::Value &x, + const spu::Value &y) { SPU_ENFORCE(x.shape() == y.shape(), "x and y should have the same shape"); SPU_ENFORCE(x.shape().ndim() == 1, "x and y should be 1-d"); @@ -245,7 +261,7 @@ spu::Value GenInvPermByTwoBitVectors(SPUContext *ctx, const spu::Value &x, 1); // calculate prefix sum - auto ps = PrefixSum(ctx, f); + auto ps = _prefix_sum(ctx, f); // mul f and s auto fs = _mul(ctx, f, ps); @@ -284,7 +300,7 @@ spu::Value GenInvPermByTwoBitVectors(SPUContext *ctx, const spu::Value &x, // r = [4, 1, 5, 2, 3] // 8) get res by sub r by one // res = [3, 0, 4, 1, 2] -spu::Value GenInvPermByBitVector(SPUContext *ctx, const spu::Value &x) { +spu::Value _gen_inv_perm_by_bv(SPUContext *ctx, const spu::Value &x) { SPU_ENFORCE(x.shape().ndim() == 1, "x should be 1-d"); const auto k1 = _constant(ctx, 1U, x.shape()); @@ -296,7 +312,7 @@ spu::Value GenInvPermByBitVector(SPUContext *ctx, const spu::Value &x) { ctx, {reshape(ctx, rev_x, new_shape), reshape(ctx, x, new_shape)}, 1); // calculate prefix sum - auto ps = PrefixSum(ctx, f); + auto ps = _prefix_sum(ctx, f); // mul f and s auto fs = _mul(ctx, f, ps); @@ -310,27 +326,30 @@ spu::Value GenInvPermByBitVector(SPUContext *ctx, const spu::Value &x) { return res; } -// This is the inverse of ShufflePerm. +// Ref: https://eprint.iacr.org/2019/695.pdf +// Algorithm 14: Optimized composition of two permutations +// +// Compose is actually a special case of apply_perm where both inputs are +// permutations. +// +// The input is a shared inverse permutation <perm>, a public permutation -// shuffled_perm generated by ShufflePerm, and a secret permutation -// random_perm for secure unshuffle. +// shuffled_perm generated by _opt_apply_inv_perm_ss, and a secret permutation +// share random_perm for secure unshuffle. // // The steps are as follows: // 1) permute <perm> by shuffled_perm as <sm> // 2) secure unshuffle <sm> and return results -// -// By doing ShufflePerm and UnshufflePerm, we get the shared inverse -// permutation of initial shared bit vectors.
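The correctness of the shuffle-then-reveal trick above rests on a plain algebraic identity: masking both the data and the permutation with the same random shuffle leaves the inverse-permutation result unchanged, while the revealed permutation m on its own is uniformly random. A self-contained plaintext check (illustrative only; apply_perm and apply_inv_perm mirror the ret[i] = x[perm[i]] and ret[perm[i]] = x[i] conventions used later in this file):

#include <cassert>
#include <cstdint>
#include <vector>

using Vec = std::vector<int64_t>;

// ret[i] = x[p[i]]  (apply a permutation)
Vec apply_perm(const Vec& x, const Vec& p) {
  Vec ret(x.size());
  for (size_t i = 0; i < x.size(); ++i) ret[i] = x[p[i]];
  return ret;
}

// ret[p[i]] = x[i]  (apply an inverse permutation)
Vec apply_inv_perm(const Vec& x, const Vec& p) {
  Vec ret(x.size());
  for (size_t i = 0; i < x.size(); ++i) ret[p[i]] = x[i];
  return ret;
}

int main() {
  const Vec x = {10, 20, 30, 40};  // stand-in for the shared bit vectors
  const Vec perm = {2, 0, 3, 1};   // the (shared) permutation to apply
  const Vec r = {1, 3, 0, 2};      // the random shuffle mask
  // Shuffle both the permutation and the data by r, then "reveal" m.
  const Vec m = apply_perm(perm, r);
  const Vec sx = apply_perm(x, r);
  // Inverse-permuting the shuffled data by m equals inverse-permuting the
  // original data by perm: the mask r cancels out on both sides.
  assert(apply_inv_perm(sx, m) == apply_inv_perm(x, perm));
  return 0;
}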
-spu::Value UnshufflePerm(SPUContext *ctx, const spu::Value &perm, - const spu::Value &shuffled_perm, - const spu::Value &random_perm) { - auto sm = _perm_sp(ctx, perm, shuffled_perm); - auto res = _inv_perm_ss(ctx, sm, random_perm); +spu::Value _opt_apply_perm_ss(SPUContext *ctx, const spu::Value &perm, + const spu::Value &shuffled_perm, + const spu::Value &random_perm) { + auto sm = hal::_perm_sp(ctx, perm, shuffled_perm); + // this is actually shuffle + auto res = hal::_inv_perm_ss(ctx, sm, random_perm); return res; } -std::vector BitDecompose(SPUContext *ctx, const spu::Value &x, - int64_t valid_bits) { +std::vector _bit_decompose(SPUContext *ctx, const spu::Value &x, + int64_t valid_bits) { auto x_bshare = _prefer_b(ctx, x); const auto k1 = _constant(ctx, 1U, x.shape()); std::vector rets; @@ -349,15 +368,15 @@ std::vector BitDecompose(SPUContext *ctx, const spu::Value &x, } // Generate vector of bit decomposition of sorting keys -std::vector GenBvVector(SPUContext *ctx, - absl::Span inputs, - SortDirection direction, int64_t num_keys, - int64_t valid_bits) { +std::vector _gen_bv_vector(SPUContext *ctx, + absl::Span keys, + SortDirection direction, + int64_t valid_bits) { std::vector ret; - const auto k1 = _constant(ctx, 1U, inputs[0].shape()); - // inputs[0] is the most significant key - for (int64_t i = num_keys - 1; i >= 0; --i) { - const auto &t = BitDecompose(ctx, inputs[i], valid_bits); + const auto k1 = _constant(ctx, 1U, keys[0].shape()); + // keys[0] is the most significant key + for (size_t i = keys.size(); i > 0; --i) { + const auto t = _bit_decompose(ctx, keys[i - 1], valid_bits); SPU_ENFORCE(t.size() > 0); for (size_t j = 0; j < t.size() - 1; j++) { @@ -380,67 +399,65 @@ std::vector GenBvVector(SPUContext *ctx, } // Generate shared inverse permutation by key -spu::Value GenInvPerm(SPUContext *ctx, absl::Span inputs, - SortDirection direction, int64_t num_keys, - int64_t valid_bits) { +spu::Value _gen_inv_perm_s(SPUContext *ctx, absl::Span keys, + SortDirection direction, int64_t valid_bits) { // 1. generate bit decomposition vector of keys - std::vector bv = - GenBvVector(ctx, inputs, direction, num_keys, valid_bits); + std::vector bv = _gen_bv_vector(ctx, keys, direction, valid_bits); SPU_ENFORCE_GT(bv.size(), 0U); // 2. generate natural permutation for initialization auto dt = ctx->config().field() == FieldType::FM32 ? spu::DT_I32 : spu::DT_I64; - auto init_perm = iota(ctx, dt, inputs[0].numel()); + auto init_perm = iota(ctx, dt, keys[0].numel()); auto shared_perm = _p2s(ctx, init_perm); // 3. 
generate shared inverse permutation by bit vector and process size_t bv_size = bv.size(); size_t bv_idx = 0; for (; bv_idx < bv_size - 1; bv_idx += 2) { - auto random_perm = _rand_perm_s(ctx, inputs[0].shape()); - auto [shuffled_bv, shuffled_perm] = - ShufflePerm(ctx, std::vector{bv[bv_idx], bv[bv_idx + 1]}, - shared_perm, random_perm); - auto perm = GenInvPermByTwoBitVectors(ctx, shuffled_bv[0], shuffled_bv[1]); - shared_perm = UnshufflePerm(ctx, perm, shuffled_perm, random_perm); + // generate random permutation for shuffle + auto random_perm = hal::_rand_perm_s(ctx, keys[0].shape()); + auto [shuffled_bv, shuffled_perm] = _opt_apply_inv_perm_ss( + ctx, std::vector{bv[bv_idx], bv[bv_idx + 1]}, shared_perm, + random_perm); + auto perm = _gen_inv_perm_by_bv(ctx, shuffled_bv[0], shuffled_bv[1]); + shared_perm = _opt_apply_perm_ss(ctx, perm, shuffled_perm, random_perm); } if (bv_idx == bv_size - 1) { - auto random_perm = _rand_perm_s(ctx, inputs[0].shape()); - auto [shuffled_bv, shuffled_perm] = ShufflePerm( + // generate random permutation for shuffle + auto random_perm = hal::_rand_perm_s(ctx, keys[0].shape()); + auto [shuffled_bv, shuffled_perm] = _opt_apply_inv_perm_ss( ctx, std::vector{bv[bv_idx]}, shared_perm, random_perm); - auto perm = GenInvPermByBitVector(ctx, shuffled_bv[0]); - shared_perm = UnshufflePerm(ctx, perm, shuffled_perm, random_perm); + auto perm = _gen_inv_perm_by_bv(ctx, shuffled_bv[0]); + shared_perm = _opt_apply_perm_ss(ctx, perm, shuffled_perm, random_perm); } return shared_perm; } +spu::Value _gen_inv_perm_s(SPUContext *ctx, const spu::Value &key, + bool is_ascending, int64_t valid_bits) { + std::vector keys{key}; + auto direction = + is_ascending ? SortDirection::Ascending : SortDirection::Descending; + auto ret = _gen_inv_perm_s(ctx, keys, direction, valid_bits); + return ret; +} + // Apply inverse permutation on each tensor of x by a shared inverse // permutation -std::vector ApplyInvPerm(SPUContext *ctx, - absl::Span x, - const spu::Value &perm) { - // sanity check. - SPU_ENFORCE(!x.empty(), "inputs should not be empty"); - SPU_ENFORCE(x[0].shape().ndim() == 1, - "inputs should be 1-d but actually have {} dimensions", - x[0].shape().ndim()); - SPU_ENFORCE(std::all_of(x.begin(), x.end(), - [&x](const spu::Value &input) { - return input.shape() == x[0].shape(); - }), - "inputs shape mismatched"); - +std::vector _apply_inv_perm_ss(SPUContext *ctx, + absl::Span x, + const spu::Value &perm) { // 1. = secure shuffle - auto shuffle_perm = _rand_perm_s(ctx, x[0].shape()); - auto sp = _perm_ss(ctx, perm, shuffle_perm); + auto shuffle_perm = hal::_rand_perm_s(ctx, x[0].shape()); + auto sp = hal::_perm_ss(ctx, perm, shuffle_perm); // 2. = secure shuffle std::vector sx; for (size_t i = 0; i < x.size(); ++i) { - sx.emplace_back(_perm_ss(ctx, x[i], shuffle_perm)); + sx.emplace_back(hal::_perm_ss(ctx, x[i], shuffle_perm)); } // 3. M = reveal() @@ -450,29 +467,336 @@ std::vector ApplyInvPerm(SPUContext *ctx, // 4. 
= SP() std::vector<spu::Value> v; for (size_t i = 0; i < sx.size(); ++i) { - auto t = _inv_perm_sp(ctx, sx[i], m); + auto t = hal::_inv_perm_sp(ctx, sx[i], m); v.emplace_back(std::move(t)); } return v; } +spu::Value _apply_inv_perm_ss(SPUContext *ctx, const spu::Value &x, + const spu::Value &perm) { + std::vector<spu::Value> inputs{x}; + auto ret = _apply_inv_perm_ss(ctx, inputs, perm); + return std::move(ret[0]); +} + +// Ref: https://eprint.iacr.org/2019/695.pdf +// Algorithm 5: Composition of two share-vector permutations +// +// Compose is actually a special case of apply_perm where both inputs are +// permutations. So to be more general, we use the name _apply_perm_ss +// rather than _compose_ss here. +spu::Value _apply_perm_ss(SPUContext *ctx, const Value &x, const Value &perm) { + // 1. <sp> = secure shuffle <perm> + auto shuffle_perm = hal::_rand_perm_s(ctx, x.shape()); + auto sp = hal::_perm_ss(ctx, perm, shuffle_perm); + + // 2. M = reveal(<sp>) + auto m = _s2p(ctx, sp); + SPU_ENFORCE_EQ(m.shape().ndim(), 1U, "perm should be 1-d tensor"); + + // 3. sx = apply_perm(x, m) + auto sx = hal::_perm_sp(ctx, x, m); + + // 4. ret = unshuffle(<sx>) + auto ret = hal::_inv_perm_ss(ctx, sx, shuffle_perm); + + return ret; +} + +// Find mergeable keys from keys. Consecutive public/private (belonging to one +// owner) keys can be merged. Assume there are six keys, i.e., public_key0, +// bob_key0, bob_key1, alice_key0, alice_key1, secret_key0. We can merge the six +// keys into bob_new_key, alice_new_key, secret_key0 for the following sorting. +// This function will return a vector of indices [3,5,6] which means key[0,3), +// key[3,5), and key[5,6) can be merged. +std::vector<size_t> _find_mergeable_keys(SPUContext *ctx, + absl::Span<spu::Value const> keys) { + std::vector<size_t> split_indices; + split_indices.push_back(keys.size()); + auto idx = keys.size() - 1; + int64_t pre_owner = keys[idx].isPrivate() ? _get_owner(keys[idx]) : -1; + + while (idx > 0) { + idx--; + const auto &pre_key = keys[idx + 1]; + const auto &cur_key = keys[idx]; + // secret key cannot be merged + if (pre_key.isSecret()) { + split_indices.push_back(idx + 1); + } else { + // if the current key and the previous keys belong to different + // owners, they cannot be merged + if (cur_key.isPublic()) { + continue; + } else if (cur_key.isPrivate()) { + if (pre_owner == -1 || _get_owner(cur_key) == pre_owner) { + pre_owner = _get_owner(cur_key); + continue; + } else { + split_indices.push_back(idx + 1); + } + } else { + split_indices.push_back(idx + 1); + } + pre_owner = cur_key.isPrivate() ?
_get_owner(cur_key) : -1; + } + } + std::reverse(split_indices.begin(), split_indices.end()); + return split_indices; +} + +// Given a 1-d array input, generate its inverse permutation +spu::Value _gen_inv_perm(SPUContext *ctx, const Value &in, bool is_ascending, + int64_t valid_bits = -1) { + SPU_TRACE_HAL_DISP(ctx, in, is_ascending, valid_bits); + if (in.isPublic()) { + return _gen_inv_perm_p(ctx, in, is_ascending); + } else if (in.isSecret()) { + return _gen_inv_perm_s(ctx, in, is_ascending, valid_bits); + } else if (in.isPrivate()) { + return _gen_inv_perm_v(ctx, in, is_ascending); + } else { + SPU_THROW("should not be here"); + } +} + +spu::Value _apply_inv_perm_sv(SPUContext *ctx, const Value &in, + const Value &perm) { + if (ctx->hasKernel("inv_perm_av")) { + return hal::_inv_perm_sv(ctx, in, perm); + } else { + return _apply_inv_perm_ss(ctx, in, _v2s(ctx, perm)); + } +} + +#define MAP_APPLY_PERM_OP(NAME) \ + spu::Value _apply##NAME(SPUContext *ctx, const Value &in, \ + const Value &perm) { \ + return hal::NAME(ctx, in, perm); \ + } + +MAP_APPLY_PERM_OP(_perm_pp); +MAP_APPLY_PERM_OP(_perm_vv); +MAP_APPLY_PERM_OP(_perm_sp); +MAP_APPLY_PERM_OP(_inv_perm_pp); +MAP_APPLY_PERM_OP(_inv_perm_vv); +MAP_APPLY_PERM_OP(_inv_perm_sp); + +// Given a permutation, apply (inverse) permutation on a 1-d array input +#define MAP_PERM_OP(NAME) \ + spu::Value NAME(SPUContext *ctx, const Value &in, const Value &perm) { \ + SPU_TRACE_HAL_DISP(ctx, in, perm); \ + if (in.isPublic() && perm.isPublic()) { /*PP*/ \ + return NAME##_pp(ctx, in, perm); \ + } else if (in.isPublic() && perm.isSecret()) { /*PS*/ \ + return NAME##_ss(ctx, _p2s(ctx, in), perm); \ + } else if (in.isPublic() && perm.isPrivate()) { /*PV*/ \ + return NAME##_vv(ctx, _p2v(ctx, in, _get_owner(perm)), perm); \ + } else if (in.isPrivate() && perm.isPrivate()) { /*VV*/ \ + if (_has_same_owner(in, perm)) { \ + return NAME##_vv(ctx, in, perm); \ + } else { \ + return NAME##_sv(ctx, _v2s(ctx, in), perm); \ + } \ + } else if (in.isPrivate() && perm.isPublic()) { /*VP*/ \ + return NAME##_vv(ctx, in, _p2v(ctx, perm, _get_owner(in))); \ + } else if (in.isPrivate() && perm.isSecret()) { /*VS*/ \ + return NAME##_ss(ctx, _v2s(ctx, in), perm); \ + } else if (in.isSecret() && perm.isSecret()) { /*SS*/ \ + return NAME##_ss(ctx, in, perm); \ + } else if (in.isSecret() && perm.isPublic()) { /*SP*/ \ + return NAME##_sp(ctx, in, perm); \ + } else if (in.isSecret() && perm.isPrivate()) { /*SV*/ \ + return NAME##_sv(ctx, in, perm); \ + } else { \ + SPU_THROW("should not be here"); \ + } \ + } + +// Inverse permute 1-D array x with a permutation perm +// ret[perm[i]] = x[i] +MAP_PERM_OP(_apply_inv_perm) + +// Given a permutation, generate its inverse permutation +// ret[perm[i]] = i +spu::Value _inverse(SPUContext *ctx, const Value &perm) { + auto dt = + ctx->config().field() == FieldType::FM32 ? 
spu::DT_I32 : spu::DT_I64; + auto iota_perm = iota(ctx, dt, perm.numel()); + return _apply_inv_perm(ctx, iota_perm, perm); +} + +spu::Value _apply_perm_sv(SPUContext *ctx, const Value &in, const Value &perm) { + if (ctx->hasKernel("inv_perm_av")) { + return hal::_inv_perm_sv(ctx, in, _inverse(ctx, perm)); + } else { + return _apply_inv_perm_ss(ctx, in, _v2s(ctx, _inverse(ctx, perm))); + } +} + +// Permute 1-D array x with a permutation perm +// ret[i] = x[perm[i]] +MAP_PERM_OP(_apply_perm) + +// Compose two permutations into one permutation +// If we have two permutations x and y, we want to get a permutation z from x +// and y such that apply_inv_perm(in, z) = apply_inv_perm(apply_inv_perm(in, x), y) +spu::Value _compose_perm(SPUContext *ctx, const Value &x, const Value &y) { + return _apply_perm(ctx, y, x); +} + +spu::Value _merge_keys(SPUContext *ctx, absl::Span<spu::Value const> inputs, + bool is_ascending) { + if (inputs[0].isPublic()) { + SPU_ENFORCE(std::all_of(inputs.begin(), inputs.end(), + [](const spu::Value &v) { return v.isPublic(); }), + "keys should be all public"); + return _merge_keys_p(ctx, inputs, is_ascending); + } else if (inputs[0].isPrivate()) { + SPU_ENFORCE(std::all_of(inputs.begin(), inputs.end(), + [&inputs](const spu::Value &v) { + return v.isPrivate() && + _has_same_owner(v, inputs[0]); + }), + "keys should have the same owner"); + return _merge_keys_v(ctx, inputs, is_ascending); + } else if (inputs[0].isSecret()) { + SPU_THROW("merge secret permutation is currently not supported"); + } else { + SPU_THROW("should not be here"); + } +} + +spu::Value _merge_pub_pri_keys(SPUContext *ctx, + absl::Span<spu::Value const> keys, + bool is_ascending) { + SPU_ENFORCE(std::all_of(keys.begin(), keys.end(), + [](const spu::Value &v) { return !v.isSecret(); }), + "secret keys should not be here"); + SPU_ENFORCE_GE(keys.size(), 1U, "there should be at least 1 key to merge"); + const auto &pre_key = keys.back(); + + auto inv_perm = _gen_inv_perm(ctx, pre_key, is_ascending); + + for (int64_t i = keys.size() - 2; i >= 0; --i) { + const auto &cur_key = keys[i]; + auto cur_key_hat = _apply_inv_perm(ctx, cur_key, inv_perm); + auto cur_inv_perm = _gen_inv_perm(ctx, cur_key_hat, is_ascending); + inv_perm = _compose_perm(ctx, inv_perm, cur_inv_perm); + } + auto dt = + ctx->config().field() == FieldType::FM32 ? spu::DT_I32 : spu::DT_I64; + std::vector<spu::Value> permed_keys; + for (const auto &key : keys) { + permed_keys.emplace_back(_apply_inv_perm(ctx, key, inv_perm)); + } + auto merged_key = _merge_keys(ctx, permed_keys, is_ascending).setDtype(dt); + return _apply_perm(ctx, merged_key, inv_perm); +} + +// Merge consecutive private/public keys +std::vector<spu::Value> _merge_sorting_keys(SPUContext *ctx, + absl::Span<spu::Value const> keys, + bool is_ascending) { + auto merge_pos = _find_mergeable_keys(ctx, keys); + SPU_ENFORCE_GT(merge_pos.size(), 0U, "there is at least 1 key after merging"); + std::vector<spu::Value> new_keys; + size_t beg_idx = 0; + for (size_t end_idx : merge_pos) { + // a single secret key is kept as-is, so sorting by it can still use the + // valid_bits optimization + if (end_idx - beg_idx == 1 && keys[beg_idx].isSecret()) { + new_keys.push_back(keys[beg_idx]); + } else { + auto merged_key = _merge_pub_pri_keys( + ctx, keys.subspan(beg_idx, end_idx - beg_idx), is_ascending); + new_keys.push_back(std::move(merged_key)); + } + beg_idx = end_idx; + } + return new_keys; +} + +// Generate an inverse permutation vector according to sorting keys.
The +// permutation vector should be secret or private (if enabled) but cannot be +// public as we have already processed sorting with public keys outside of +// radix sort. +spu::Value gen_inv_perm(SPUContext *ctx, absl::Span<spu::Value const> inputs, + SortDirection direction, int64_t num_keys, + int64_t valid_bits) { + // merge consecutive private/public keys + auto keys = inputs.subspan(0, num_keys); + if (std::all_of(keys.begin(), keys.end(), + [](const spu::Value &v) { return v.isSecret(); })) { + auto perm = _gen_inv_perm_s(ctx, keys, direction, valid_bits); + return perm; + } + bool is_ascending = direction == SortDirection::Ascending; + auto merged_keys = _merge_sorting_keys(ctx, keys, is_ascending); + + // generate inverse permutation + const auto &pre_key = merged_keys.back(); + auto inv_perm = _gen_inv_perm(ctx, pre_key, is_ascending, valid_bits); + for (int64_t i = merged_keys.size() - 2; i >= 0; --i) { + const auto &cur_key = merged_keys[i]; + auto cur_key_hat = _apply_inv_perm(ctx, cur_key, inv_perm); + auto real_valid_bits = + cur_key.isSecret() ? valid_bits : Log2Floor(cur_key.numel()) + 2; + auto cur_inv_perm = + _gen_inv_perm(ctx, cur_key_hat, is_ascending, real_valid_bits); + inv_perm = _compose_perm(ctx, inv_perm, cur_inv_perm); + } + + return inv_perm; +} + +std::vector<spu::Value> apply_inv_perm(SPUContext *ctx, + absl::Span<spu::Value const> inputs, + const spu::Value &perm) { + if (perm.isSecret()) { + std::vector<spu::Value> inputs_s; + for (const auto &input : inputs) { + inputs_s.emplace_back(_2s(ctx, input).setDtype(input.dtype())); + } + return _apply_inv_perm_ss(ctx, inputs_s, perm); + } else if (perm.isPrivate()) { + if (ctx->hasKernel("inv_perm_av")) { + std::vector<spu::Value> ret; + for (const auto &input : inputs) { + ret.emplace_back( + _apply_inv_perm(ctx, input, perm).setDtype(input.dtype())); + } + return ret; + } else { + std::vector<spu::Value> inputs_s; + for (const auto &input : inputs) { + inputs_s.emplace_back(_2s(ctx, input).setDtype(input.dtype())); + } + return _apply_inv_perm_ss(ctx, inputs_s, _2s(ctx, perm)); + } + } else { + SPU_THROW("Should not be here"); + } +} + // Secure Radix Sort // Ref: // https://eprint.iacr.org/2019/695.pdf // // Each input is a 1-d tensor, inputs[0, num_keys) are the keys, and sort // inputs according to keys -std::vector<spu::Value> RadixSort(SPUContext *ctx, - absl::Span<spu::Value const> inputs, - SortDirection direction, int64_t num_keys, - int64_t valid_bits) { - auto perm = GenInvPerm(ctx, inputs, direction, num_keys, valid_bits); - auto res = ApplyInvPerm(ctx, inputs, perm); +std::vector<spu::Value> radix_sort(SPUContext *ctx, + absl::Span<spu::Value const> inputs, + SortDirection direction, int64_t num_keys, + int64_t valid_bits) { + auto perm = gen_inv_perm(ctx, inputs, direction, num_keys, valid_bits); + auto res = apply_inv_perm(ctx, inputs, perm); return res; } -} // namespace +} // namespace internal std::vector<spu::Value> sort1d(SPUContext *ctx, absl::Span<spu::Value const> inputs, @@ -513,13 +837,13 @@ std::vector<spu::Value> sort1d(SPUContext *ctx, ret.reserve(inputs.size()); for (const auto &input : inputs) { - ret.push_back(Permute1D(ctx, input, indices_to_sort)); + ret.push_back(internal::_permute_1d(ctx, input, indices_to_sort)); } } else if (comparator_ret_vis == VIS_SECRET) { SPU_ENFORCE(!is_stable, "Stable sort is unsupported if comparator return is secret."); - ret = OddEvenMergeSort(ctx, cmp, inputs); + ret = internal::odd_even_merge_sort(ctx, cmp, inputs); } else { SPU_THROW("Should not reach here"); } @@ -545,26 +869,22 @@ std::vector<spu::Value> simple_sort1d(SPUContext *ctx, "num_keys {} is not valid", num_keys); bool fallback = false; - // If all keys are secret values and the
protocol supports secret shuffle and - // unshuffle, we can use radix sort for fast 1-D sort. Otherwise, we fallback - // to generic sort1d, and use the inputs[0] as the sorting key - if (!std::all_of(inputs.begin(), inputs.begin() + num_keys, - [](const spu::Value &v) { return v.isSecret(); })) { + // if all keys are public, fallback to public sort + if (std::all_of(inputs.begin(), inputs.begin() + num_keys, + [](const spu::Value &v) { return v.isPublic(); })) { fallback = true; - SPDLOG_WARN("Fallback to generic sort1d because not all keys are secret"); } - + // If the protocol supports secret shuffle and unshuffle, we can use radix + // sort for fast 1-D sort. Otherwise, we fallback to generic sort1d if (!fallback && - !(ctx->hasKernel("rand_perm_s") && ctx->hasKernel("perm_as") && - ctx->hasKernel("perm_ap") && ctx->hasKernel("inv_perm_as") && + !(ctx->hasKernel("rand_perm_m") && ctx->hasKernel("perm_am") && + ctx->hasKernel("perm_ap") && ctx->hasKernel("inv_perm_am") && ctx->hasKernel("inv_perm_ap"))) { fallback = true; - SPDLOG_WARN( - "Fallback to generic sort1d because permutation-related kernels are " - "not supported"); } if (!fallback) { - auto ret = RadixSort(ctx, inputs, direction, num_keys, valid_bits); + auto ret = + internal::radix_sort(ctx, inputs, direction, num_keys, valid_bits); return ret; } else { auto scalar_cmp = [direction](spu::SPUContext *ctx, const spu::Value &lhs, @@ -629,7 +949,7 @@ std::vector permute(SPUContext *ctx, std::iota(perm.begin(), perm.end(), 0); std::swap(perm[permute_dim], perm.back()); - auto q = GenInvPerm(Index(perm.begin(), perm.end())); + auto q = internal::_inverse_index(Index(perm.begin(), perm.end())); unperm = Axes(q.begin(), q.end()); } diff --git a/libspu/kernel/hal/prot_wrapper.cc b/libspu/kernel/hal/prot_wrapper.cc index 45942b88..3f8c6c06 100644 --- a/libspu/kernel/hal/prot_wrapper.cc +++ b/libspu/kernel/hal/prot_wrapper.cc @@ -212,27 +212,116 @@ MAP_OPTIONAL_BINARY_OP(equal_ss) MAP_OPTIONAL_BINARY_OP(equal_sp) MAP_BINARY_OP(equal_pp) -#define MAP_OPTIONAL_PERM_OP(NAME) \ - Value _##NAME(SPUContext* ctx, const Value& x, const Value& y) { \ - SPU_TRACE_HAL_DISP(ctx, x, y); \ - SPU_ENFORCE(x.shape().ndim() == 1, "x should be a 1-d tensor"); \ - auto ret = mpc::NAME(ctx, x, y); \ - SPU_ENFORCE(ret.has_value(), "{} api not implemented", #NAME); \ - return ret.value().setDtype(x.dtype()); \ +#define MAP_OPTIONAL_PERM_OP(NAME) \ + Value _##NAME(SPUContext* ctx, const Value& x, const Value& y) { \ + SPU_TRACE_HAL_DISP(ctx, x, y); \ + SPU_ENFORCE(x.shape() == y.shape(), "shape mismatch: x={}, y={}", \ + x.shape(), y.shape()); \ + SPU_ENFORCE(x.shape().ndim() == 1, "x should be a 1-d tensor"); \ + auto ret = mpc::NAME(ctx, x, y); \ + SPU_ENFORCE(ret.has_value(), "{} api not implemented", #NAME); \ + return ret.value().setDtype(x.dtype()); \ } // namespace spu::kernel::hal MAP_OPTIONAL_PERM_OP(perm_ss); MAP_OPTIONAL_PERM_OP(perm_sp); MAP_OPTIONAL_PERM_OP(inv_perm_ss); MAP_OPTIONAL_PERM_OP(inv_perm_sp); +MAP_OPTIONAL_PERM_OP(inv_perm_sv); Value _rand_perm_s(SPUContext* ctx, const Shape& shape) { SPU_TRACE_HAL_DISP(ctx, shape); - SPU_ENFORCE(shape.ndim() == 1, "shape should be a 1-d"); - + SPU_ENFORCE(shape.ndim() == 1, "shape should be 1-d"); auto ret = mpc::rand_perm_s(ctx, shape); SPU_ENFORCE(ret.has_value(), "rand_perm_s api not implemented"); return ret.value(); } +Value _broadcast(SPUContext* ctx, const Value& in, const Shape& to_shape, + const Axes& in_dims) { + return mpc::broadcast(ctx, in, to_shape, in_dims).setDtype(in.dtype()); +} + 
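For readability, the MAP_OPTIONAL_PERM_OP macro above expands, for perm_ss, to roughly the following (hand-expanded sketch, whitespace adjusted):

Value _perm_ss(SPUContext* ctx, const Value& x, const Value& y) {
  SPU_TRACE_HAL_DISP(ctx, x, y);
  SPU_ENFORCE(x.shape() == y.shape(), "shape mismatch: x={}, y={}", x.shape(),
              y.shape());
  SPU_ENFORCE(x.shape().ndim() == 1, "x should be a 1-d tensor");
  // mpc::perm_ss returns an optional: protocols without the kernel yield none.
  auto ret = mpc::perm_ss(ctx, x, y);
  SPU_ENFORCE(ret.has_value(), "{} api not implemented", "perm_ss");
  return ret.value().setDtype(x.dtype());
}

The shape-op wrappers that follow use the same pattern: thin pass-throughs into the mpc layer that re-attach the value's dtype.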
+Value _reshape(SPUContext* ctx, const Value& in, const Shape& to_shape) { + return mpc::reshape(ctx, in, to_shape).setDtype(in.dtype()); +} + +Value _extract_slice(SPUContext* ctx, const Value& in, + const Index& start_indices, const Index& end_indices, + const Strides& strides) { + return mpc::extract_slice(ctx, in, start_indices, end_indices, strides) + .setDtype(in.dtype()); +} + +Value _update_slice(SPUContext* ctx, const Value& in, const Value& update, + const Index& start_indices) { + return mpc::update_slice(ctx, in, update, start_indices).setDtype(in.dtype()); +} + +Value _transpose(SPUContext* ctx, const Value& in, const Axes& permutation) { + return mpc::transpose(ctx, in, permutation).setDtype(in.dtype()); +} + +Value _reverse(SPUContext* ctx, const Value& in, const Axes& dimensions) { + return mpc::reverse(ctx, in, dimensions).setDtype(in.dtype()); +} + +Value _fill(SPUContext* ctx, const Value& in, const Shape& to_shape) { + return mpc::fill(ctx, in, to_shape).setDtype(in.dtype()); +} + +Value _pad(SPUContext* ctx, const Value& in, const Value& padding_value, + const Sizes& edge_padding_low, const Sizes& edge_padding_high, + const Sizes& interior_padding) { + return mpc::pad(ctx, in, padding_value, edge_padding_low, edge_padding_high, + interior_padding) + .setDtype(in.dtype()); +} + +Value _concatenate(SPUContext* ctx, const std::vector& values, + int64_t axis) { + return mpc::concatenate(ctx, values, axis).setDtype(values.front().dtype()); +} + +Value _gen_inv_perm_p(SPUContext* ctx, const Value& in, bool is_ascending) { + SPU_TRACE_HAL_DISP(ctx, in, is_ascending); + SPU_ENFORCE(in.shape().ndim() == 1, "input should be 1-d"); + return dynDispatch(ctx, "gen_inv_perm_p", in, is_ascending); +} + +Value _gen_inv_perm_v(SPUContext* ctx, const Value& in, bool is_ascending) { + SPU_TRACE_HAL_DISP(ctx, in, is_ascending); + SPU_ENFORCE(in.shape().ndim() == 1, "input should be 1-d"); + return dynDispatch(ctx, "gen_inv_perm_v", in, is_ascending); +} + +Value _merge_keys_p(SPUContext* ctx, absl::Span inputs, + bool is_ascending) { + SPU_TRACE_HAL_DISP(ctx, inputs.size(), inputs[0].shape(), is_ascending); + std::vector in(inputs.begin(), inputs.end()); + return dynDispatch(ctx, "merge_keys_p", in, is_ascending); +} + +Value _merge_keys_v(SPUContext* ctx, absl::Span inputs, + bool is_ascending) { + SPU_TRACE_HAL_DISP(ctx, inputs.size(), inputs[0].shape(), is_ascending); + std::vector in(inputs.begin(), inputs.end()); + return dynDispatch(ctx, "merge_keys_v", in, is_ascending); +} + +#define MAP_PERM_OP(NAME) \ + Value _##NAME(SPUContext* ctx, const Value& x, const Value& y) { \ + SPU_TRACE_HAL_DISP(ctx, x, y); \ + SPU_ENFORCE(x.shape() == y.shape(), "shape mismatch: x={}, y={}", \ + x.shape(), y.shape()); \ + SPU_ENFORCE(x.shape().ndim() == 1, "x should be a 1-d tensor"); \ + auto ret = mpc::NAME(ctx, x, y); \ + return ret.setDtype(x.dtype()); \ + } + +MAP_PERM_OP(inv_perm_pp); +MAP_PERM_OP(inv_perm_vv); +MAP_PERM_OP(perm_pp); +MAP_PERM_OP(perm_vv); + } // namespace spu::kernel::hal diff --git a/libspu/kernel/hal/prot_wrapper.h b/libspu/kernel/hal/prot_wrapper.h index 2d0a62f4..a0c5d5ef 100644 --- a/libspu/kernel/hal/prot_wrapper.h +++ b/libspu/kernel/hal/prot_wrapper.h @@ -115,12 +115,43 @@ Value _make_p(SPUContext* ctx, uint128_t init, const Shape& shape); Value _rand_p(SPUContext* ctx, const Shape& shape); Value _rand_s(SPUContext* ctx, const Shape& shape); -// FIXME: temporary API +// FIXME: temporary API, formalize later Value _rand_perm_s(SPUContext* ctx, const Shape& shape); Value 
_perm_ss(SPUContext* ctx, const Value& x, const Value& perm); Value _perm_sp(SPUContext* ctx, const Value& x, const Value& perm); +Value _perm_pp(SPUContext* ctx, const Value& x, const Value& perm); +Value _perm_vv(SPUContext* ctx, const Value& x, const Value& perm); Value _inv_perm_ss(SPUContext* ctx, const Value& x, const Value& perm); Value _inv_perm_sp(SPUContext* ctx, const Value& x, const Value& perm); +Value _inv_perm_sv(SPUContext* ctx, const Value& x, const Value& perm); +Value _inv_perm_pp(SPUContext* ctx, const Value& x, const Value& perm); +Value _inv_perm_vv(SPUContext* ctx, const Value& x, const Value& perm); + +Value _gen_inv_perm_p(SPUContext* ctx, const Value& x, bool is_ascending); +Value _gen_inv_perm_v(SPUContext* ctx, const Value& x, bool is_ascending); +Value _merge_keys_p(SPUContext* ctx, absl::Span inputs, + bool is_ascending); +Value _merge_keys_v(SPUContext* ctx, absl::Span inputs, + bool is_ascending); + +// Shape ops +Value _broadcast(SPUContext* ctx, const Value& in, const Shape& to_shape, + const Axes& in_dims); +Value _reshape(SPUContext* ctx, const Value& in, const Shape& to_shape); +Value _extract_slice(SPUContext* ctx, const Value& in, + const Index& start_indices, const Index& end_indices, + const Strides& strides); +Value _update_slice(SPUContext* ctx, const Value& in, const Value& update, + const Index& start_indices); +Value _transpose(SPUContext* ctx, const Value& in, + const Axes& permutation = {}); +Value _reverse(SPUContext* ctx, const Value& in, const Axes& dimensions); +Value _fill(SPUContext* ctx, const Value& in, const Shape& to_shape); +Value _pad(SPUContext* ctx, const Value& in, const Value& padding_value, + const Sizes& edge_padding_low, const Sizes& edge_padding_high, + const Sizes& interior_padding); +Value _concatenate(SPUContext* ctx, const std::vector& values, + int64_t axis); // NOLINTEND(readability-identifier-naming) diff --git a/libspu/kernel/hal/ring.cc b/libspu/kernel/hal/ring.cc index 528d8719..d10d6eb6 100644 --- a/libspu/kernel/hal/ring.cc +++ b/libspu/kernel/hal/ring.cc @@ -22,17 +22,22 @@ #include "libspu/core/prelude.h" #include "libspu/core/trace.h" #include "libspu/kernel/hal/prot_wrapper.h" -#include "libspu/kernel/hal/shape_ops.h" namespace spu::kernel::hal { Type _common_type(SPUContext* ctx, const Type& a, const Type& b) { if (a.isa() && b.isa()) { return _common_type_s(ctx, a, b); + } else if (a.isa() && b.isa()) { + return _common_type_v(ctx, a, b); } else if (a.isa()) { return a; } else if (b.isa()) { return b; + } else if (a.isa()) { + return a; + } else if (b.isa()) { + return b; } else { SPU_ENFORCE(a.isa() && b.isa()); return a; @@ -40,11 +45,18 @@ Type _common_type(SPUContext* ctx, const Type& a, const Type& b) { } Value _cast_type(SPUContext* ctx, const Value& x, const Type& to) { + if (x.storage_type() == to) { + return x; + } if (x.isPublic() && to.isa()) { return x; } else if (x.isPublic() && to.isa()) { // FIXME: casting to BShare semantic is wrong. 
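// For quick reference, the cast paths handled by this function are (a sketch
// of the branches below, not an exhaustive visibility lattice):
//   public  -> public  : no-op
//   public  -> secret  : _p2s (see the FIXME above)
//   public  -> private : _p2v, targeting the owner recorded in `to`
//   private -> secret  : _v2s
//   secret  -> secret  : _cast_type_s (re-share into the target type)
// Any other combination throws.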
return _p2s(ctx, x); + } else if (x.isPublic() && to.isa()) { + return _p2v(ctx, x, to.as()->owner()); + } else if (x.isPrivate() && to.isa()) { + return _v2s(ctx, x); } else if (x.isSecret() && to.isa()) { return _cast_type_s(ctx, x, to); } else { @@ -163,15 +175,18 @@ static Value _mmul_impl(SPUContext* ctx, const Value& x, const Value& y) { } else if (x.isSecret() && y.isPublic()) { // SP return _mmul_sp(ctx, x, y); } else if (x.isPublic() && y.isSecret()) { // PS - return transpose(ctx, _mmul_sp(ctx, transpose(ctx, y), transpose(ctx, x))); + return _transpose(ctx, + _mmul_sp(ctx, _transpose(ctx, y), _transpose(ctx, x))); } else if (x.isPrivate() && y.isPublic()) { // VP return _mmul_vp(ctx, x, y); } else if (x.isPublic() && y.isPrivate()) { // PV - return transpose(ctx, _mmul_vp(ctx, transpose(ctx, y), transpose(ctx, x))); + return _transpose(ctx, + _mmul_vp(ctx, _transpose(ctx, y), _transpose(ctx, x))); } else if (x.isSecret() && y.isPrivate()) { // SV return _mmul_sv(ctx, x, y); } else if (x.isPrivate() && y.isSecret()) { // VS - return transpose(ctx, _mmul_sv(ctx, transpose(ctx, y), transpose(ctx, x))); + return _transpose(ctx, + _mmul_sv(ctx, _transpose(ctx, y), _transpose(ctx, x))); } else { SPU_THROW("unsupported op {} for x={}, y={}", "_matmul", x, y); } @@ -314,17 +329,19 @@ Value _mmul(SPUContext* ctx, const Value& x, const Value& y) { Value x_block; if (x.shape().size() == 1) { SPU_ENFORCE(m_start == 0 && m_end == 1); - x_block = slice(ctx, x, {k_start}, {k_end}, {}); + x_block = _extract_slice(ctx, x, {k_start}, {k_end}, {}); } else { - x_block = slice(ctx, x, {m_start, k_start}, {m_end, k_end}, {}); + x_block = + _extract_slice(ctx, x, {m_start, k_start}, {m_end, k_end}, {}); } Value y_block; if (y.shape().size() == 1) { SPU_ENFORCE(n_start == 0 && n_end == 1); - y_block = slice(ctx, y, {k_start}, {k_end}, {}); + y_block = _extract_slice(ctx, y, {k_start}, {k_end}, {}); } else { - y_block = slice(ctx, y, {k_start, n_start}, {k_end, n_end}, {}); + y_block = + _extract_slice(ctx, y, {k_start, n_start}, {k_end, n_end}, {}); } auto mmul_ret = _mmul_impl(ctx, x_block, y_block); @@ -615,17 +632,17 @@ Value _tensordot(SPUContext* ctx, const Value& x, const Value& y, std::rotate(perm_x.begin(), perm_x.begin() + nc, perm_x.end()); // convert to mmul shape. - auto xx = transpose(ctx, x, Axes(perm_x)); + auto xx = _transpose(ctx, x, Axes(perm_x)); Shape xxs = xx.shape(); - xx = reshape(ctx, xx, - {product(xxs.begin(), xxs.end() - nc), - product(xxs.end() - nc, xxs.end())}); + xx = _reshape(ctx, xx, + {product(xxs.begin(), xxs.end() - nc), + product(xxs.end() - nc, xxs.end())}); - auto yy = transpose(ctx, y, Axes(perm_y)); + auto yy = _transpose(ctx, y, Axes(perm_y)); Shape yys = yy.shape(); - yy = reshape(ctx, yy, - {product(yys.begin(), yys.begin() + nc), - product(yys.begin() + nc, yys.end())}); + yy = _reshape(ctx, yy, + {product(yys.begin(), yys.begin() + nc), + product(yys.begin() + nc, yys.end())}); // do matrix multiplication. 
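// Worked shape example (illustrative only): contracting nc = 1 axis of
// x with shape (2, 3, 4) against y with shape (4, 5) gives
//   xx: (2*3, 4) = (6, 4),  yy: (4, 5),  zz: (6, 5),
// which reshapes back to the final result shape (2, 3, 5).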
auto zz = _mmul(ctx, xx, yy); @@ -634,7 +651,7 @@ Value _tensordot(SPUContext* ctx, const Value& x, const Value& y, Shape res_shape(xxs.begin(), xxs.end() - nc); res_shape.insert(res_shape.end(), yys.begin() + nc, yys.end()); - return reshape(ctx, zz, res_shape); + return _reshape(ctx, zz, res_shape); } } // namespace spu::kernel::hal diff --git a/libspu/kernel/hal/shape_ops.cc b/libspu/kernel/hal/shape_ops.cc index 733ce226..6337719f 100644 --- a/libspu/kernel/hal/shape_ops.cc +++ b/libspu/kernel/hal/shape_ops.cc @@ -14,101 +14,24 @@ #include "libspu/kernel/hal/shape_ops.h" -#include - #include "libspu/core/context.h" -#include "libspu/core/ndarray_ref.h" #include "libspu/core/trace.h" #include "libspu/kernel/hal/prot_wrapper.h" +#include "libspu/kernel/hal/ring.h" namespace spu::kernel::hal { -namespace { - -// TODO: these code is copied from ring.cc, remove it when shape ops is lowered -// to mpc layer. -Type _common_type(SPUContext* ctx, const Type& a, const Type& b) { - if (a.isa() && b.isa()) { - return _common_type_s(ctx, a, b); - } else if (a.isa() && b.isa()) { - return _common_type_v(ctx, a, b); - } else if (a.isa()) { - return a; - } else if (b.isa()) { - return b; - } else { - SPU_ENFORCE(a.isa() && b.isa()); - return a; - } -} - -Value _cast_type(SPUContext* ctx, const Value& x, const Type& to) { - if (x.storage_type() == to) { - return x; - } - if (x.isPublic() && to.isa()) { - return x; - } else if (x.isPublic() && to.isa()) { - // FIXME: casting to BShare semantic is wrong. - return _p2s(ctx, x); - } else if (x.isPublic() && to.isa()) { - return _p2v(ctx, x, to.as()->owner()); - } else if (x.isSecret() && to.isa()) { - return _s2v(ctx, x, to.as()->owner()); - } else if (x.isPrivate() && to.isa()) { - return _v2s(ctx, x); - } else if (x.isSecret() && to.isa()) { - return _cast_type_s(ctx, x, to); - } else { - SPU_THROW("should not be here x={}, to={}", x, to); - } -} - -// Compact threshold heuristic, try to make it same as L1 cache size -#define COMPACT_THRESHOLD (32 * 1024) // 32K - -SPU_ALWAYS_INLINE NdArrayRef _try_compact(const NdArrayRef& in) { - // If in data is not compact after some shape ops and small enough, make it - // compact - if (in.numel() * in.elsize() <= COMPACT_THRESHOLD && !in.isCompact()) { - return in.clone(); - } - return in; -} - -} // namespace Value transpose(SPUContext* ctx, const Value& in, const Axes& permutation) { - SPU_TRACE_HAL_DISP(ctx, in); - - Axes perm = permutation; - if (perm.empty()) { - // by default, transpose the data in reverse order. - perm.resize(in.shape().size()); - std::iota(perm.rbegin(), perm.rend(), 0); - } - - // sanity check. - SPU_ENFORCE_EQ(perm.size(), in.shape().size()); - std::set uniq(perm.begin(), perm.end()); - SPU_ENFORCE_EQ(uniq.size(), perm.size(), "perm={} is not unique", perm); - - // fast path, if identity permutation, return it. 
- Axes no_perm(in.shape().size()); - std::iota(no_perm.begin(), no_perm.end(), 0); - if (perm == no_perm) { - return in; - } + SPU_TRACE_HAL_DISP(ctx, in, permutation); - return Value(_try_compact(in.data().transpose(perm)), in.dtype()); + return _transpose(ctx, in, permutation); } Value slice(SPUContext* ctx, const Value& in, const Index& start_indices, const Index& end_indices, const Strides& strides) { SPU_TRACE_HAL_DISP(ctx, in, start_indices, end_indices, strides); - return Value( - _try_compact(in.data().slice(start_indices, end_indices, strides)), - in.dtype()); + return _extract_slice(ctx, in, start_indices, end_indices, strides); } Value slice_scalar_at(SPUContext*, const Value& input, const Index& indices) { @@ -117,6 +40,8 @@ Value slice_scalar_at(SPUContext*, const Value& input, const Index& indices) { Value update_slice(SPUContext* ctx, const Value& in, const Value& update, const Index& start_indices) { + SPU_TRACE_HAL_DISP(ctx, in, start_indices); + if (in.storage_type() != update.storage_type()) { auto u = _cast_type(ctx, update, in.storage_type()).setDtype(update.dtype()); @@ -124,32 +49,32 @@ Value update_slice(SPUContext* ctx, const Value& in, const Value& update, return update_slice(ctx, in, u, start_indices); } - auto ret = in.clone(); - ret.data().update_slice(update.data(), start_indices); - return ret; + return _update_slice(ctx, in, update, start_indices).setDtype(in.dtype()); } Value reshape(SPUContext* ctx, const Value& in, const Shape& to_shape) { SPU_TRACE_HAL_DISP(ctx, in, to_shape); - return Value(_try_compact(in.data().reshape(to_shape)), in.dtype()); + return _reshape(ctx, in, to_shape).setDtype(in.dtype()); } Value broadcast_to(SPUContext* ctx, const Value& in, const Shape& to_shape, const Axes& in_dims) { SPU_TRACE_HAL_DISP(ctx, in, to_shape); - return Value(in.data().broadcast_to(to_shape, in_dims), in.dtype()); + return _broadcast(ctx, in, to_shape, in_dims).setDtype(in.dtype()); } Value reverse(SPUContext* ctx, const Value& in, const Axes& dimensions) { SPU_TRACE_HAL_DISP(ctx, in, dimensions); - return Value(in.data().reverse(dimensions), in.dtype()); + return _reverse(ctx, in, dimensions); } -Value expand(SPUContext*, const Value& in, const Shape& to_shape) { - return Value(in.data().expand(to_shape), in.dtype()); +Value expand(SPUContext* ctx, const Value& in, const Shape& to_shape) { + SPU_TRACE_HAL_DISP(ctx, in, to_shape); + + return _fill(ctx, in, to_shape); } Value pad(SPUContext* ctx, const Value& in, const Value& padding_value, @@ -165,14 +90,13 @@ Value pad(SPUContext* ctx, const Value& in, const Value& padding_value, edge_padding_high, interior_padding); } - return Value(in.data().pad(padding_value.data(), edge_padding_low, - edge_padding_high, interior_padding), - in.dtype()); + return _pad(ctx, in, padding_value, edge_padding_low, edge_padding_high, + interior_padding); } -Value concatenate(SPUContext* ctx, absl::Span values, +Value concatenate(SPUContext* ctx, const std::vector& values, int64_t axis) { - SPU_TRACE_HAL_DISP(ctx, axis); + SPU_TRACE_HAL_DISP(ctx, values, axis); SPU_ENFORCE(!values.empty(), "got={}", values.size()); if (values.size() == 1) { @@ -207,12 +131,7 @@ Value concatenate(SPUContext* ctx, absl::Span values, SPU_ENFORCE(all_same_stype); - std::vector array(values.size() - 1); - for (int64_t idx = 1; idx < static_cast(values.size()); ++idx) { - array[idx - 1] = values[idx].data(); - } - - return Value(values[0].data().concatenate(array, axis), values[0].dtype()); + return _concatenate(ctx, values, axis); } } // namespace 
spu::kernel::hal diff --git a/libspu/kernel/hal/shape_ops.h b/libspu/kernel/hal/shape_ops.h index 61114cd2..24f0e859 100644 --- a/libspu/kernel/hal/shape_ops.h +++ b/libspu/kernel/hal/shape_ops.h @@ -82,7 +82,7 @@ Value pad(SPUContext* ctx, const Value& in, const Value& padding_value, // @param first, the first param // @param second, the second param // @param axis, the axis -Value concatenate(SPUContext* ctx, absl::Span values, +Value concatenate(SPUContext* ctx, const std::vector& values, int64_t axis); } // namespace spu::kernel::hal diff --git a/libspu/kernel/hlo/geometrical.cc b/libspu/kernel/hlo/geometrical.cc index 2e98e37f..0a239cc1 100644 --- a/libspu/kernel/hlo/geometrical.cc +++ b/libspu/kernel/hlo/geometrical.cc @@ -49,7 +49,7 @@ spu::Value Reshape(SPUContext *ctx, const spu::Value &in, return hal::reshape(ctx, in, to_shape); } -spu::Value Concatenate(SPUContext *ctx, absl::Span operands, +spu::Value Concatenate(SPUContext *ctx, const std::vector &operands, int64_t axis) { if (operands.front().isComplex()) { std::vector r_operands(operands.size()); diff --git a/libspu/kernel/hlo/geometrical.h b/libspu/kernel/hlo/geometrical.h index 677bc4f0..15c0e30c 100644 --- a/libspu/kernel/hlo/geometrical.h +++ b/libspu/kernel/hlo/geometrical.h @@ -31,7 +31,7 @@ spu::Value Broadcast(SPUContext *ctx, const spu::Value &in, spu::Value Reshape(SPUContext *ctx, const spu::Value &in, const Shape &to_shape); -spu::Value Concatenate(SPUContext *ctx, absl::Span operands, +spu::Value Concatenate(SPUContext *ctx, const std::vector &operands, int64_t axis); spu::Value Slice(SPUContext *ctx, const spu::Value &in, const Index &start, diff --git a/libspu/kernel/hlo/indexing.cc b/libspu/kernel/hlo/indexing.cc index 66b65525..b815be7d 100644 --- a/libspu/kernel/hlo/indexing.cc +++ b/libspu/kernel/hlo/indexing.cc @@ -36,19 +36,19 @@ void hintNumberOfBits(const Value &a, size_t nbits); namespace { struct IndexIterationSpace { - std::vector index_base; - std::vector index_count; - std::vector index_incr; + spu::Index index_base; + spu::Index index_count; + spu::Index index_incr; }; // Returns an IndexIterationSpace that iterates over the output batch // dimensions while keeping the rest of the output dimensions clamped to 0. IndexIterationSpace iterationSpaceForOutputBatchIndices( - absl::Span output_shape, + const spu::Shape &output_shape, const spu::kernel::hlo::GatherConfig &config) { int64_t output_rank = output_shape.size(); - std::vector index_base(output_rank, 0); - std::vector index_count; + spu::Index index_base(output_rank, 0); + spu::Index index_count; index_count.reserve(output_rank); for (int64_t i = 0; i < output_rank; i++) { @@ -58,15 +58,15 @@ IndexIterationSpace iterationSpaceForOutputBatchIndices( } return {std::move(index_base), std::move(index_count), - std::vector(output_rank, 1)}; + spu::Index(output_rank, 1)}; } // Return an IndexIterationSpace that iterates over the output slice // dimensions while keeping the rest of the output dimensions clamped to 0. 
IndexIterationSpace iterationSpaceForOutputOffsetIndices( int64_t output_rank, const spu::kernel::hlo::GatherConfig &config) { - std::vector index_base(output_rank, 0); - std::vector index_count(output_rank, 1); + spu::Index index_base(output_rank, 0); + spu::Index index_count(output_rank, 1); int64_t slice_sizes_idx = 0; for (int64_t i = 0; i < output_rank; i++) { @@ -83,7 +83,7 @@ IndexIterationSpace iterationSpaceForOutputOffsetIndices( } return {std::move(index_base), std::move(index_count), - std::vector(output_rank, 1)}; + spu::Index(output_rank, 1)}; } // This functor computes the contribution of start_indices to an input index @@ -97,8 +97,7 @@ class OutputBatchIndexToInputIndex { // iterations. explicit OutputBatchIndexToInputIndex( const spu::kernel::hlo::GatherConfig &config, - absl::Span input_shape, - absl::Span output_shape, + const spu::Shape &input_shape, const spu::Shape &output_shape, const xt::xarray &start_indices) : config_(config), start_indices_(start_indices) { for (int64_t i = 0; i < static_cast(output_shape.size()); ++i) { @@ -146,7 +145,7 @@ class OutputBatchIndexToInputIndex { // same storage for all invocations. // // This returns a Span into memory owned by the class. - absl::Span operator()(absl::Span output_index) { + spu::Index &operator()(const spu::Index &output_index) { propagateOutputIndexGatherDimsToIndexVectorIndex(output_index); fetchIndexVector(); propagateIndexVectorToInputIndex(); @@ -197,7 +196,7 @@ class OutputBatchIndexToInputIndex { // input_dim_value_to_index_vector_[i] tells us how to compute dimension i // of the input index from the index vector. See // PropagateIndexVectorToInputIndex. - std::vector input_dim_value_to_index_vector_; + spu::Index input_dim_value_to_index_vector_; // output_dim_is_batch_dims_[i] is true iff the output index i is a gather // dimension. @@ -208,11 +207,11 @@ class OutputBatchIndexToInputIndex { spu::Index index_vector_index_; // The index vector fetched from start_indices_. - std::vector index_vector_; + spu::Index index_vector_; // The result computed by this functor. operator() returns a Span into // this vector. - std::vector input_index_; + spu::Index input_index_; const spu::kernel::hlo::GatherConfig &config_; const xt::xarray &start_indices_; @@ -229,9 +228,8 @@ class OutputOffsetIndexToInputIndex { // iterations. explicit OutputOffsetIndexToInputIndex( const spu::kernel::hlo::GatherConfig &config, - absl::Span input_shape, - absl::Span output_shape) { - std::vector window_index_to_output_index; + const spu::Shape &input_shape, const spu::Shape &output_shape) { + spu::Index window_index_to_output_index; int64_t output_index_count = 0; for (int64_t i = 0; i < static_cast(output_shape.size()); i++) { if (std::binary_search(config.offsetDims.begin(), config.offsetDims.end(), @@ -265,7 +263,7 @@ class OutputOffsetIndexToInputIndex { // result (input_index_), mutating it in place. // // This returns a Span into memory owned by the class. - absl::Span operator()(absl::Span output_index) { + spu::Index &operator()(const spu::Index &output_index) { propagateOutputIndexWindowDimsToInputIndex(output_index); return input_index_; } @@ -291,11 +289,11 @@ class OutputOffsetIndexToInputIndex { // input_dim_value_to_index_vector_[i] tells us how to compute dimension i // of the input index from the output index. See // PropagateOutputIndexWindowDimsToInputIndex. - std::vector input_dim_value_to_output_index_; + spu::Index input_dim_value_to_output_index_; // The result computed by this functor. 
operator() returns a Span into // this vector. - std::vector input_index_; + spu::Index input_index_; }; spu::Value reshapedGatherIndices(spu::SPUContext *ctx, int64_t index_vector_dim, @@ -393,10 +391,10 @@ std::vector ClampAndFlattenIndex( } // Now compute offsets of each index - std::vector base(iterate_shape.size(), 0); - std::vector incr(iterate_shape.size(), 1); + spu::Index base(iterate_shape.size(), 0); + spu::Index incr(iterate_shape.size(), 1); - std::vector flatten_idx; + spu::Index flatten_idx; spu::kernel::forEachIndex( limit_shape, base, iterate_shape, incr, [&flatten_idx, &limit_shape](const spu::Index &idx) { @@ -470,58 +468,55 @@ spu::Value Gather(SPUContext *ctx, const spu::Value &operand, operand.dtype()); } - auto gather_inner_loop_body = - [&](absl::Span output_window_index, - absl::Span input_gather_index, - absl::Span output_gather_index) { - auto input_window_index = - output_offset_index_to_input_index(output_window_index); - for (int i = 0, e = output_index.size(); i < e; i++) { - output_index[i] = output_gather_index[i] + output_window_index[i]; - } - for (int i = 0, e = input_gather_index.size(); i < e; i++) { - int64_t output_dim = output_offset_index_to_input_index - .input_dim_value_to_output_index(i); - // If 'output_dim' is -1, it means 'i' is an elided window dim. This - // means we set the iteration index to 0, so for the purpose of the - // following calculations we can consider the output dimension size - // to be 1. - int64_t output_dim_size = - output_dim == -1 ? 1 : result_shape[output_dim]; - // Clamp the gather index so that the gather region fits in the - // operand. input_index_clamped[i] = clamp(input_gather_index[i], 0, - // operand_shape.dimensions(i) - // - output_dim_size); - input_index_clamped[i] = - std::min(operand_shape[i] - output_dim_size, - std::max(int64_t{0}, input_gather_index[i])); - } - for (int i = 0, e = input_index.size(); i < e; i++) { - input_index[i] = input_index_clamped[i] + input_window_index[i]; - } - - result.data().update_slice(operand.data().slice_scalar_at(input_index), - output_index); - - if (result.isComplex()) { - result.imag()->update_slice( - operand.imag()->slice_scalar_at(input_index), output_index); - } - }; - - auto gather_outer_loop_body = - [&](absl::Span output_gather_index) { - auto input_gather_index = - output_batch_index_to_input_index(output_gather_index); - forEachIndex(result_shape, offset_indices_iteration_space.index_base, - offset_indices_iteration_space.index_count, - offset_indices_iteration_space.index_incr, - [&](absl::Span output_window_index) { - return gather_inner_loop_body(output_window_index, - input_gather_index, - output_gather_index); - }); - }; + auto gather_inner_loop_body = [&](const spu::Index &output_window_index, + const spu::Index &input_gather_index, + const spu::Index &output_gather_index) { + auto input_window_index = + output_offset_index_to_input_index(output_window_index); + for (int i = 0, e = output_index.size(); i < e; i++) { + output_index[i] = output_gather_index[i] + output_window_index[i]; + } + for (int i = 0, e = input_gather_index.size(); i < e; i++) { + int64_t output_dim = + output_offset_index_to_input_index.input_dim_value_to_output_index(i); + // If 'output_dim' is -1, it means 'i' is an elided window dim. This + // means we set the iteration index to 0, so for the purpose of the + // following calculations we can consider the output dimension size + // to be 1. + int64_t output_dim_size = output_dim == -1 ? 
1 : result_shape[output_dim]; + // Clamp the gather index so that the gather region fits in the + // operand. input_index_clamped[i] = clamp(input_gather_index[i], 0, + // operand_shape.dimensions(i) + // - output_dim_size); + input_index_clamped[i] = + std::min(operand_shape[i] - output_dim_size, + std::max(int64_t{0}, input_gather_index[i])); + } + for (int i = 0, e = input_index.size(); i < e; i++) { + input_index[i] = input_index_clamped[i] + input_window_index[i]; + } + + result.data().update_slice(operand.data().slice_scalar_at(input_index), + output_index); + + if (result.isComplex()) { + result.imag()->update_slice(operand.imag()->slice_scalar_at(input_index), + output_index); + } + }; + + auto gather_outer_loop_body = [&](const spu::Index &output_gather_index) { + auto input_gather_index = + output_batch_index_to_input_index(output_gather_index); + forEachIndex(result_shape, offset_indices_iteration_space.index_base, + offset_indices_iteration_space.index_count, + offset_indices_iteration_space.index_incr, + [&](const spu::Index &output_window_index) { + return gather_inner_loop_body(output_window_index, + input_gather_index, + output_gather_index); + }); + }; forEachIndex(result_shape, start_indices_iteration_space.index_base, start_indices_iteration_space.index_count, @@ -706,7 +701,7 @@ spu::Value SecretDynamicSlice(SPUContext *ctx, const spu::Value &operand, hlo::Constant(ctx, std::vector(slice_size.size(), 0), {static_cast(slice_size.size())}); - std::vector limit = operand.shape(); + spu::Shape limit = operand.shape(); for (size_t idx = 0; idx < limit.size(); ++idx) { limit[idx] -= slice_size[idx]; } diff --git a/libspu/kernel/hlo/indexing.h b/libspu/kernel/hlo/indexing.h index 91f38fdf..1086e05a 100644 --- a/libspu/kernel/hlo/indexing.h +++ b/libspu/kernel/hlo/indexing.h @@ -23,11 +23,11 @@ class SPUContext; namespace spu::kernel::hlo { struct GatherConfig { - absl::Span sliceSizes; + spu::Sizes sliceSizes; int64_t indexVectorDim; - absl::Span offsetDims; - absl::Span collapsedSliceDims; - absl::Span startIndexMap; + spu::Axes offsetDims; + spu::Axes collapsedSliceDims; + spu::Axes startIndexMap; }; // This is ported from diff --git a/libspu/kernel/hlo/shuffle.cc b/libspu/kernel/hlo/shuffle.cc index e33c324c..2ed5d0a0 100644 --- a/libspu/kernel/hlo/shuffle.cc +++ b/libspu/kernel/hlo/shuffle.cc @@ -22,6 +22,19 @@ namespace spu::kernel::hlo { +namespace { + +spu::Value _2s(SPUContext* ctx, const Value& x) { + if (x.isPublic()) { + return hal::_p2s(ctx, x); + } else if (x.isPrivate()) { + return hal::_v2s(ctx, x); + } + return x; +} + +} // namespace + std::vector Shuffle(SPUContext* ctx, absl::Span inputs, int64_t axis) { @@ -32,12 +45,13 @@ std::vector Shuffle(SPUContext* ctx, auto input_shape = inputs[0].shape(); // TODO: Rename permute-related kernels - if (ctx->hasKernel("rand_perm_s") && ctx->hasKernel("perm_as")) { + if (ctx->hasKernel("rand_perm_m") && ctx->hasKernel("perm_am")) { auto shuffle_fn = [&](absl::Span input) { std::vector rets; auto rand_perm = hal::_rand_perm_s(ctx, input_shape); for (size_t i = 0; i < input.size(); ++i) { - rets.emplace_back(hal::_perm_ss(ctx, input[i], rand_perm)); + rets.emplace_back(hal::_perm_ss(ctx, _2s(ctx, input[i]), rand_perm) + .setDtype(input[i].dtype())); } return rets; }; diff --git a/libspu/mpc/ab_api_test.cc b/libspu/mpc/ab_api_test.cc index 6bc25ffe..ed031921 100644 --- a/libspu/mpc/ab_api_test.cc +++ b/libspu/mpc/ab_api_test.cc @@ -101,10 +101,10 @@ bool verifyCost(Kernel* kernel, std::string_view name, FieldType field, /* 
WHEN */ \ auto a0 = p2a(obj.get(), p0); \ auto a1 = p2a(obj.get(), p1); \ - auto prev = obj->prot()->getState()->getStats(); \ + auto prev = obj->prot() -> getState() -> getStats(); \ auto tmp = OP##_aa(obj.get(), a0, a1); \ auto cost = \ - obj->prot()->getState()->getStats() - prev; \ + obj->prot() -> getState() -> getStats() - prev; \ auto re = a2p(obj.get(), tmp); \ auto rp = OP##_pp(obj.get(), p0, p1); \ \ @@ -131,10 +131,10 @@ bool verifyCost(Kernel* kernel, std::string_view name, FieldType field, \ /* WHEN */ \ auto a0 = p2a(obj.get(), p0); \ - auto prev = obj->prot()->getState()->getStats(); \ + auto prev = obj->prot() -> getState() -> getStats(); \ auto tmp = OP##_ap(obj.get(), a0, p1); \ auto cost = \ - obj->prot()->getState()->getStats() - prev; \ + obj->prot() -> getState() -> getStats() - prev; \ auto re = a2p(obj.get(), tmp); \ auto rp = OP##_pp(obj.get(), p0, p1); \ \ @@ -274,6 +274,45 @@ TEST_P(ArithmeticTest, MatMulAA) { }); } +TEST_P(ArithmeticTest, MatMulAV) { + const auto factory = std::get<0>(GetParam()); + const RuntimeConfig& conf = std::get<1>(GetParam()); + const size_t npc = std::get<2>(GetParam()); + const int64_t M = 3; + const int64_t K = 4; + const int64_t N = 3; + const Shape shape_A = {M, K}; + const Shape shape_B = {K, N}; + const Shape shape_C = {M, N}; + utils::simulate(npc, [&](const std::shared_ptr& lctx) { + auto obj = factory(conf, lctx); + /* GIVEN */ + auto p0 = rand_p(obj.get(), shape_A); + auto p1 = rand_p(obj.get(), shape_B); + auto a0 = p2a(obj.get(), p0); + auto v1 = p2v(obj.get(), p1, 0); + /* WHEN */ + auto prev = obj->prot()->getState()->getStats(); + auto _tmp = mmul_av(obj.get(), a0, v1); + if (!_tmp.has_value()) { + return; + } + auto tmp = _tmp.value(); + auto cost = obj->prot()->getState()->getStats() - prev; + auto r_aa = a2p(obj.get(), tmp); + auto r_pp = mmul_pp(obj.get(), p0, p1); + /* THEN */ + EXPECT_VALUE_EQ(r_aa, r_pp); + ce::Params params = {{"K", SizeOf(conf.field()) * 8}, + {"N", npc}, + {"m", M}, + {"n", N}, + {"k", K}}; + EXPECT_TRUE(verifyCost(obj->prot()->getKernel("mmul_av"), "mmul_av", params, + cost, 1)); + }); +} + TEST_P(ArithmeticTest, NotA) { const auto factory = std::get<0>(GetParam()); const RuntimeConfig& conf = std::get<1>(GetParam()); @@ -441,10 +480,10 @@ TEST_P(ArithmeticTest, A2P) { /* WHEN */ \ auto b0 = p2b(obj.get(), p0); \ auto b1 = p2b(obj.get(), p1); \ - auto prev = obj->prot()->getState()->getStats(); \ + auto prev = obj->prot() -> getState() -> getStats(); \ auto tmp = OP##_bb(obj.get(), b0, b1); \ auto cost = \ - obj->prot()->getState()->getStats() - prev; \ + obj->prot() -> getState() -> getStats() - prev; \ auto re = b2p(obj.get(), tmp); \ auto rp = OP##_pp(obj.get(), p0, p1); \ \ @@ -471,10 +510,10 @@ TEST_P(ArithmeticTest, A2P) { \ /* WHEN */ \ auto b0 = p2b(obj.get(), p0); \ - auto prev = obj->prot()->getState()->getStats(); \ + auto prev = obj->prot() -> getState() -> getStats(); \ auto tmp = OP##_bp(obj.get(), b0, p1); \ auto cost = \ - obj->prot()->getState()->getStats() - prev; \ + obj->prot() -> getState() -> getStats() - prev; \ auto re = b2p(obj.get(), tmp); \ auto rp = OP##_pp(obj.get(), p0, p1); \ \ @@ -511,10 +550,10 @@ TEST_BOOLEAN_BINARY_OP(xor) continue; \ } \ /* WHEN */ \ - auto prev = obj->prot()->getState()->getStats(); \ + auto prev = obj->prot() -> getState() -> getStats(); \ auto tmp = OP##_b(obj.get(), b0, bits); \ auto cost = \ - obj->prot()->getState()->getStats() - prev; \ + obj->prot() -> getState() -> getStats() - prev; \ auto r_b = b2p(obj.get(), tmp); \ auto r_p = 
OP##_p(obj.get(), p0, bits); \ \ diff --git a/libspu/mpc/aby3/BUILD.bazel b/libspu/mpc/aby3/BUILD.bazel index dd7f81f6..92090ec0 100644 --- a/libspu/mpc/aby3/BUILD.bazel +++ b/libspu/mpc/aby3/BUILD.bazel @@ -34,6 +34,7 @@ spu_cc_library( ":conversion", ":permute", ":value", + "//libspu/mpc/standard_shape:protocol", ], ) diff --git a/libspu/mpc/aby3/permute.cc b/libspu/mpc/aby3/permute.cc index cf11241c..edbeed7f 100644 --- a/libspu/mpc/aby3/permute.cc +++ b/libspu/mpc/aby3/permute.cc @@ -39,7 +39,7 @@ PermVector ring2pv(const NdArrayRef& x) { } // namespace -NdArrayRef RandPermS::proc(KernelEvalContext* ctx, const Shape& shape) const { +NdArrayRef RandPermM::proc(KernelEvalContext* ctx, const Shape& shape) const { NdArrayRef out(makeType(), shape); // generate a RandU64 pair as permutation seeds @@ -67,7 +67,7 @@ NdArrayRef RandPermS::proc(KernelEvalContext* ctx, const Shape& shape) const { // Ref: https://eprint.iacr.org/2019/695.pdf // Algorithm 9: Optimized shuffling protocol -NdArrayRef PermAS::proc(KernelEvalContext* ctx, const NdArrayRef& in, +NdArrayRef PermAM::proc(KernelEvalContext* ctx, const NdArrayRef& in, const NdArrayRef& perm) const { auto* comm = ctx->getState(); const auto numel = in.numel(); @@ -187,7 +187,7 @@ NdArrayRef PermAP::proc(KernelEvalContext* ctx, const NdArrayRef& in, // Ref: https://eprint.iacr.org/2019/695.pdf // Algorithm 17: Optimized unshuffling protocol -NdArrayRef InvPermAS::proc(KernelEvalContext* ctx, const NdArrayRef& in, +NdArrayRef InvPermAM::proc(KernelEvalContext* ctx, const NdArrayRef& in, const NdArrayRef& perm) const { auto* comm = ctx->getState(); const auto numel = in.numel(); diff --git a/libspu/mpc/aby3/permute.h b/libspu/mpc/aby3/permute.h index 121422d3..8a332da5 100644 --- a/libspu/mpc/aby3/permute.h +++ b/libspu/mpc/aby3/permute.h @@ -18,9 +18,9 @@ namespace spu::mpc::aby3 { -class RandPermS : public RandKernel { +class RandPermM : public RandKernel { public: - static constexpr char kBindName[] = "rand_perm_s"; + static constexpr char kBindName[] = "rand_perm_m"; ce::CExpr latency() const override { return ce::Const(0); } @@ -29,9 +29,9 @@ class RandPermS : public RandKernel { NdArrayRef proc(KernelEvalContext* ctx, const Shape& shape) const override; }; -class PermAS : public PermKernel { +class PermAM : public PermKernel { public: - static constexpr char kBindName[] = "perm_as"; + static constexpr char kBindName[] = "perm_am"; Kind kind() const override { return Kind::Dynamic; } @@ -51,9 +51,9 @@ class PermAP : public PermKernel { const NdArrayRef& perm) const override; }; -class InvPermAS : public PermKernel { +class InvPermAM : public PermKernel { public: - static constexpr char kBindName[] = "inv_perm_as"; + static constexpr char kBindName[] = "inv_perm_am"; Kind kind() const override { return Kind::Dynamic; } diff --git a/libspu/mpc/aby3/protocol.cc b/libspu/mpc/aby3/protocol.cc index 068fe460..b9e7a80c 100644 --- a/libspu/mpc/aby3/protocol.cc +++ b/libspu/mpc/aby3/protocol.cc @@ -22,6 +22,9 @@ #include "libspu/mpc/common/communicator.h" #include "libspu/mpc/common/prg_state.h" #include "libspu/mpc/common/pv2k.h" +#include "libspu/mpc/standard_shape/protocol.h" + +#define ENABLE_PRECISE_ABY3_TRUNCPR namespace spu::mpc { @@ -40,58 +43,38 @@ void regAby3Protocol(SPUContext* ctx, // register public kernels. 
regPV2kKernels(ctx->prot()); - // register arithmetic & binary kernels - ctx->prot()->regKernel(); - ctx->prot()->regKernel(); - ctx->prot()->regKernel(); - ctx->prot()->regKernel(); - ctx->prot()->regKernel(); - ctx->prot()->regKernel(); - ctx->prot()->regKernel(); - ctx->prot()->regKernel(); - ctx->prot()->regKernel(); - ctx->prot()->regKernel(); - ctx->prot()->regKernel(); - ctx->prot()->regKernel(); - ctx->prot()->regKernel(); + // Register standard shape ops + regStandardShapeOps(ctx); -#define ENABLE_PRECISE_ABY3_TRUNCPR + // register arithmetic & binary kernels + ctx->prot() + ->regKernel< // + aby3::P2A, aby3::V2A, aby3::A2P, aby3::A2V, // Conversions + aby3::B2P, aby3::P2B, aby3::A2B, // Conversion2 + aby3::B2ASelector, /*aby3::B2AByOT, aby3::B2AByPPA*/ // B2A + aby3::CastTypeB, // Cast + aby3::NotA, // Not + aby3::AddAP, aby3::AddAA, // Add + aby3::MulAP, aby3::MulAA, aby3::MulA1B, // Mul + aby3::MatMulAP, aby3::MatMulAA, // MatMul + aby3::LShiftA, aby3::LShiftB, // LShift + aby3::RShiftB, aby3::ARShiftB, // (A)Rshift + aby3::MsbA2B, // MSB + aby3::EqualAA, aby3::EqualAP, // Equal + aby3::CommonTypeB, aby3::CommonTypeV, // CommonType + aby3::AndBP, aby3::AndBB, // And + aby3::XorBP, aby3::XorBB, // Xor + aby3::BitrevB, // bitreverse + aby3::BitIntlB, aby3::BitDeintlB, // bit(de)interleave + aby3::RandA, // rand #ifdef ENABLE_PRECISE_ABY3_TRUNCPR - ctx->prot()->regKernel(); + aby3::TruncAPr, // Trunc #else - ctx->prot()->regKernel(); + aby3::TruncA, #endif - - ctx->prot()->regKernel(); - ctx->prot()->regKernel(); - ctx->prot()->regKernel(); - - ctx->prot()->regKernel(); - ctx->prot()->regKernel(); - - ctx->prot()->regKernel(); - ctx->prot()->regKernel(); - ctx->prot()->regKernel(); - ctx->prot()->regKernel(); - ctx->prot()->regKernel(); - // ctx->prot()->regKernel(); - // ctx->prot()->regKernel(); - ctx->prot()->regKernel(); - ctx->prot()->regKernel(); - ctx->prot()->regKernel(); - ctx->prot()->regKernel(); - ctx->prot()->regKernel(); - ctx->prot()->regKernel(); - ctx->prot()->regKernel(); - ctx->prot()->regKernel(); - ctx->prot()->regKernel(); - ctx->prot()->regKernel(); - ctx->prot()->regKernel(); - ctx->prot()->regKernel(); - ctx->prot()->regKernel(); - ctx->prot()->regKernel(); - ctx->prot()->regKernel(); - ctx->prot()->regKernel(); + aby3::RandPermM, aby3::PermAM, aby3::PermAP, aby3::InvPermAM, // perm + aby3::InvPermAP // perm + >(); } std::unique_ptr makeAby3Protocol( diff --git a/libspu/mpc/aby3/type.h b/libspu/mpc/aby3/type.h index 4852d03f..74382b9c 100644 --- a/libspu/mpc/aby3/type.h +++ b/libspu/mpc/aby3/type.h @@ -90,8 +90,8 @@ class BShrTy : public TypeImpl { }; // Permutation share -class PShrTy : public TypeImpl { - using Base = TypeImpl; +class PShrTy : public TypeImpl { + using Base = TypeImpl; public: using Base::Base; diff --git a/libspu/mpc/api.cc b/libspu/mpc/api.cc index 77d33c46..ac93e39c 100644 --- a/libspu/mpc/api.cc +++ b/libspu/mpc/api.cc @@ -24,6 +24,7 @@ namespace { inline bool IsA(const Value& x) { return x.storage_type().isa(); } inline bool IsB(const Value& x) { return x.storage_type().isa(); } +inline bool IsPShr(const Value& x) { return x.storage_type().isa(); } [[maybe_unused]] inline bool IsP(const Value& x) { return x.storage_type().isa(); } @@ -669,13 +670,14 @@ Value bitrev_p(SPUContext* ctx, const Value& x, size_t start, size_t end) { OptionalAPI rand_perm_s(SPUContext* ctx, const Shape& shape) { SPU_TRACE_MPC_DISP(ctx, shape); - TRY_DISPATCH(ctx, shape); + TRY_NAMED_DISPATCH(ctx, "rand_perm_m", shape); return NotAvailable; } OptionalAPI 
perm_ss(SPUContext* ctx, const Value& x, const Value& perm) {
+  SPU_ENFORCE(IsPShr(perm), "perm should be a PShare");
   SPU_TRACE_MPC_DISP(ctx, x, perm);
-  TRY_NAMED_DISPATCH(ctx, "perm_as", _2a(ctx, x), perm);
+  TRY_NAMED_DISPATCH(ctx, "perm_am", _2a(ctx, x), perm);
   return NotAvailable;
 }

@@ -685,10 +687,21 @@ OptionalAPI<Value> perm_sp(SPUContext* ctx, const Value& x,
                            const Value& perm) {
   return NotAvailable;
 }

+spu::Value perm_pp(SPUContext* ctx, const Value& in, const Value& perm) {
+  FORCE_DISPATCH(ctx, in, perm);
+}
+
+spu::Value perm_vv(SPUContext* ctx, const Value& in, const Value& perm) {
+  SPU_ENFORCE(hasSameOwner(in, perm),
+              "in and perm should belong to the same owner");
+  FORCE_DISPATCH(ctx, in, perm);
+}
+
 OptionalAPI<Value> inv_perm_ss(SPUContext* ctx, const Value& x,
                                const Value& perm) {
+  SPU_ENFORCE(IsPShr(perm), "perm should be a PShare");
   SPU_TRACE_MPC_DISP(ctx, x, perm);
-  TRY_NAMED_DISPATCH(ctx, "inv_perm_as", _2a(ctx, x), perm);
+  TRY_NAMED_DISPATCH(ctx, "inv_perm_am", _2a(ctx, x), perm);
   return NotAvailable;
 }

@@ -699,4 +712,83 @@ OptionalAPI<Value> inv_perm_sp(SPUContext* ctx, const Value& x,
   return NotAvailable;
 }

+OptionalAPI<Value> inv_perm_sv(SPUContext* ctx, const Value& x,
+                               const Value& perm) {
+  SPU_TRACE_MPC_DISP(ctx, x, perm);
+  TRY_NAMED_DISPATCH(ctx, "inv_perm_av", _2a(ctx, x), perm);
+  return NotAvailable;
+}
+
+spu::Value inv_perm_pp(SPUContext* ctx, const Value& in, const Value& perm) {
+  FORCE_DISPATCH(ctx, in, perm);
+}
+
+spu::Value inv_perm_vv(SPUContext* ctx, const Value& in, const Value& perm) {
+  SPU_ENFORCE(hasSameOwner(in, perm),
+              "in and perm should belong to the same owner");
+  FORCE_DISPATCH(ctx, in, perm);
+}
+
+Value broadcast(SPUContext* ctx, const Value& in, const Shape& to_shape,
+                const Axes& in_dims) {
+  SPU_TRACE_MPC_DISP(ctx, in, to_shape, in_dims);
+  FORCE_DISPATCH(ctx, in, to_shape, in_dims);
+}
+
+// Reshape a Value
+Value reshape(SPUContext* ctx, const Value& in, const Shape& to_shape) {
+  SPU_TRACE_MPC_DISP(ctx, in, to_shape);
+  FORCE_DISPATCH(ctx, in, to_shape);
+}
+
+// Extract a slice from a Value
+Value extract_slice(SPUContext* ctx, const Value& in,
+                    const Index& start_indices, const Index& end_indices,
+                    const Strides& strides) {
+  SPU_TRACE_MPC_DISP(ctx, in, start_indices, end_indices, strides);
+  FORCE_DISPATCH(ctx, in, start_indices, end_indices, strides);
+}
+
+// Update a Value at index with given value
+Value update_slice(SPUContext* ctx, const Value& in, const Value& update,
+                   const Index& start_indices) {
+  SPU_TRACE_MPC_DISP(ctx, in, update, start_indices);
+  FORCE_DISPATCH(ctx, in, update, start_indices);
+}
+
+// Transpose a Value
+Value transpose(SPUContext* ctx, const Value& in, const Axes& permutation) {
+  SPU_TRACE_MPC_DISP(ctx, in, permutation);
+  FORCE_DISPATCH(ctx, in, permutation);
+}
+
+// Reverse a Value at dimensions
+Value reverse(SPUContext* ctx, const Value& in, const Axes& dimensions) {
+  SPU_TRACE_MPC_DISP(ctx, in, dimensions);
+  FORCE_DISPATCH(ctx, in, dimensions);
+}
+
+// Fill a Value with input value
+Value fill(SPUContext* ctx, const Value& in, const Shape& to_shape) {
+  SPU_TRACE_MPC_DISP(ctx, in, to_shape);
+  FORCE_DISPATCH(ctx, in, to_shape);
+}
+
+// Pad a Value
+Value pad(SPUContext* ctx, const Value& in, const Value& padding_value,
+          const Sizes& edge_padding_low, const Sizes& edge_padding_high,
+          const Sizes& interior_padding) {
+  SPU_TRACE_MPC_DISP(ctx, in, padding_value, edge_padding_low,
+                     edge_padding_high, interior_padding);
+  FORCE_DISPATCH(ctx, in, padding_value, edge_padding_low, edge_padding_high,
+                 interior_padding);
+}
+
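The `perm_*` and `inv_perm_*` families above implement the two indexing conventions documented in api.h below: `ret[i] = x[perm[i]]` for perm and `ret[perm[i]] = x[i]` for inverse perm. A standalone sketch of those semantics, together with the stable-argsort construction that the `gen_inv_perm_p` kernel later in this patch uses to turn sort keys into an inverse permutation (plain `std::vector` instead of SPU's `NdArrayRef`; illustrative only):

    #include <algorithm>
    #include <cassert>
    #include <cstdint>
    #include <numeric>
    #include <vector>

    // ret[i] = x[perm[i]]
    std::vector<int> apply_perm(const std::vector<int>& x,
                                const std::vector<int64_t>& perm) {
      std::vector<int> ret(x.size());
      for (size_t i = 0; i < x.size(); ++i) ret[i] = x[perm[i]];
      return ret;
    }

    // ret[perm[i]] = x[i]
    std::vector<int> apply_inv_perm(const std::vector<int>& x,
                                    const std::vector<int64_t>& perm) {
      std::vector<int> ret(x.size());
      for (size_t i = 0; i < x.size(); ++i) ret[perm[i]] = x[i];
      return ret;
    }

    // Inverse permutation from sort keys: argsort, then scatter the ranks
    // (the same stable_sort + scatter shape as gen_inv_perm_p below).
    std::vector<int64_t> gen_inv_perm(const std::vector<int>& keys) {
      std::vector<int64_t> order(keys.size());
      std::iota(order.begin(), order.end(), 0);
      std::stable_sort(order.begin(), order.end(),
                       [&](int64_t a, int64_t b) { return keys[a] < keys[b]; });
      std::vector<int64_t> inv(keys.size());
      for (int64_t r = 0; r < static_cast<int64_t>(order.size()); ++r) {
        inv[order[r]] = r;
      }
      return inv;
    }

    int main() {
      std::vector<int> keys = {30, 10, 20};
      auto inv = gen_inv_perm(keys);            // {2, 0, 1}
      auto sorted = apply_inv_perm(keys, inv);  // {10, 20, 30}
      assert((sorted == std::vector<int>{10, 20, 30}));
      assert(apply_perm(sorted, inv) == keys);  // perm undoes inv_perm
    }

Composing `gen_inv_perm` with `apply_inv_perm` is exactly the plaintext sorting path: the inverse permutation scatters each element to its sorted position.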
+// Concatenate Values along an axis
+Value concatenate(SPUContext* ctx, const std::vector<Value>& values,
+                  int64_t axis) {
+  SPU_TRACE_MPC_DISP(ctx, values, axis);
+  FORCE_DISPATCH(ctx, values, axis);
+}
+
 }  // namespace spu::mpc
diff --git a/libspu/mpc/api.h b/libspu/mpc/api.h
index b8842c89..7ec17422 100644
--- a/libspu/mpc/api.h
+++ b/libspu/mpc/api.h
@@ -161,28 +161,65 @@ Value bitrev_s(SPUContext* ctx, const Value& x, size_t start, size_t end);
 Value bitrev_v(SPUContext* ctx, const Value& x, size_t start, size_t end);
 Value bitrev_p(SPUContext* ctx, const Value& x, size_t start, size_t end);

-// Generate a 1-D random secret permutation, here secret means the permutation
+//////////////////////////////////////////////////////////////////////////////
+// TODO: Formalize these permutation APIs
+//////////////////////////////////////////////////////////////////////////////
+// Generate a 1-D random secret permutation. Here secret means the permutation
 // is composed of a series of individual permutations held by each party.
 // Specifically, if Perm = Perm1(Perm0), then party0 holds Perm0 and party1
 // holds Perm1
 OptionalAPI<Value> rand_perm_s(SPUContext* ctx, const Shape& shape);

-// Permute 1-D secret x with public permutation perm
-// ret[i] = x[perm[i]]
-OptionalAPI<Value> perm_ss(SPUContext* ctx, const Value& x, const Value& perm);
-
-// Permute 1-D secret x with secret permutation perm
-// ret[i] = x[perm[i]]
+// Permute 1-D x with permutation perm
+// ret[i] = x[perm[i]]
 OptionalAPI<Value> perm_sp(SPUContext* ctx, const Value& x, const Value& perm);
+OptionalAPI<Value> perm_ss(SPUContext* ctx, const Value& x, const Value& perm);
+Value perm_pp(SPUContext* ctx, const Value& x, const Value& perm);
+Value perm_vv(SPUContext* ctx, const Value& x, const Value& perm);

-// Inverse permute 1-D secret x with public permutation perm
-// ret[perm[i]] = x[i]
+// Inverse permute 1-D x with permutation perm
+// ret[perm[i]] = x[i]
+OptionalAPI<Value> inv_perm_sp(SPUContext* ctx, const Value& x,
+                               const Value& perm);
 OptionalAPI<Value> inv_perm_ss(SPUContext* ctx, const Value& x,
                                const Value& perm);
-
-// Inverse permute 1-D secret x with secret permutation perm
-// ret[perm[i]] = x[i]
-OptionalAPI<Value> inv_perm_sp(SPUContext* ctx, const Value& x,
+OptionalAPI<Value> inv_perm_sv(SPUContext* ctx, const Value& x,
                                const Value& perm);
+Value inv_perm_pp(SPUContext* ctx, const Value& x, const Value& perm);
+Value inv_perm_vv(SPUContext* ctx, const Value& x, const Value& perm);
+
+/*---------------------------- Value APIs ----------------------------------*/
+// Broadcast a Value
+Value broadcast(SPUContext* ctx, const Value& in, const Shape& to_shape,
+                const Axes& in_dims);
+
+// Reshape a Value
+Value reshape(SPUContext* ctx, const Value& in, const Shape& to_shape);
+
+// Extract a slice from a Value
+Value extract_slice(SPUContext* ctx, const Value& in,
+                    const Index& start_indices, const Index& end_indices,
+                    const Strides& strides);
+
+// Update a Value at index with given value
+Value update_slice(SPUContext* ctx, const Value& in, const Value& update,
+                   const Index& start_indices);
+
+// Transpose a Value
+Value transpose(SPUContext* ctx, const Value& in, const Axes& permutation);
+
+// Reverse a Value at dimensions
+Value reverse(SPUContext* ctx, const Value& in, const Axes& dimensions);
+
+// Fill a Value with input value
+Value fill(SPUContext* ctx, const Value& in, const Shape& to_shape);
+
+// Pad a Value
+Value pad(SPUContext* ctx, const Value& in, const Value& padding_value,
+          const Sizes& edge_padding_low, const Sizes& edge_padding_high,
+          const Sizes& interior_padding);
+// Concatenate Values along an axis
+Value concatenate(SPUContext* ctx, const std::vector<Value>& values,
+                  int64_t axis);
 }  // namespace spu::mpc
diff --git a/libspu/mpc/cheetah/BUILD.bazel b/libspu/mpc/cheetah/BUILD.bazel
index ba05873d..4112ecd4 100644
--- a/libspu/mpc/cheetah/BUILD.bazel
+++ b/libspu/mpc/cheetah/BUILD.bazel
@@ -99,6 +99,7 @@ spu_cc_library(
         ":state",
         "//libspu/mpc/common:prg_state",
         "//libspu/mpc/common:pv2k",
+        "//libspu/mpc/standard_shape:protocol",
     ],
 )
diff --git a/libspu/mpc/cheetah/arith/cheetah_dot.cc b/libspu/mpc/cheetah/arith/cheetah_dot.cc
index b43c68d5..e64324fb 100644
--- a/libspu/mpc/cheetah/arith/cheetah_dot.cc
+++ b/libspu/mpc/cheetah/arith/cheetah_dot.cc
@@ -556,7 +556,7 @@ NdArrayRef CheetahDot::Impl::DotOLE(const NdArrayRef &prv_mat,
     conn = lctx_.get();
   }
   auto eltype = prv_mat.eltype();
-  SPU_ENFORCE(eltype.isa<RingTy>(), "must be ring_type, got={}", eltype);
+  SPU_ENFORCE(eltype.isa<Ring2k>(), "must be ring_type, got={}", eltype);
   SPU_ENFORCE(prv_mat.numel() > 0 && prv_mat.ndim() == 2);

   if (is_self_lhs) {
@@ -576,7 +576,7 @@ NdArrayRef CheetahDot::Impl::BatchDotOLE(const NdArrayRef &prv_mat,
     conn = lctx_.get();
   }
   auto eltype = prv_mat.eltype();
-  SPU_ENFORCE(eltype.isa<RingTy>(), "must be ring_type, got={}", eltype);
+  SPU_ENFORCE(eltype.isa<Ring2k>(), "must be ring_type, got={}", eltype);
   SPU_ENFORCE(prv_mat.numel() > 0 && prv_mat.ndim() == 3);

   if (is_self_lhs) {
diff --git a/libspu/mpc/cheetah/arith/conv2d_prot.cc b/libspu/mpc/cheetah/arith/conv2d_prot.cc
index 1c197deb..64fc3e6e 100644
--- a/libspu/mpc/cheetah/arith/conv2d_prot.cc
+++ b/libspu/mpc/cheetah/arith/conv2d_prot.cc
@@ -416,7 +416,7 @@ ArrayRef Conv2DProtocol::ParseResult(FieldType field, const Meta &meta,
   Conv2DHelper helper(meta, GetSubTensorShape(meta));
   size_t poly_per_channel = helper.slice_size(kH) * helper.slice_size(kW);
-  NdArrayRef computed_tensor(makeType(field), oshape);
+  NdArrayRef computed_tensor(makeType(field), oshape);

   for (int64_t m = 0; m < meta.num_kernels; ++m) {
     size_t poly_idx = m * poly_per_channel;
diff --git a/libspu/mpc/cheetah/arith/matmat_prot.cc b/libspu/mpc/cheetah/arith/matmat_prot.cc
index 54206653..8e49b1b0 100644
--- a/libspu/mpc/cheetah/arith/matmat_prot.cc
+++ b/libspu/mpc/cheetah/arith/matmat_prot.cc
@@ -98,7 +98,7 @@ NdArrayRef ConcatSubMatrix(const NdArrayRef& mat, const Shape2D& mat_shape,
                            const Shape2D& submat_shape, int64_t num_coeff,
                            const Indexer& indexer) {
   const Type& eltype = mat.eltype();
-  SPU_ENFORCE(eltype.isa<RingTy>(), "must be ring_type, got={}", eltype);
+  SPU_ENFORCE(eltype.isa<Ring2k>(), "must be ring_type, got={}", eltype);
   // SPU_ENFORCE(mat.ndim() == 2, "should be a 2D matrix");
   SPU_ENFORCE_EQ(mat.numel(), mat_shape[0] * mat_shape[1]);
   SPU_ENFORCE(num_coeff >= submat_shape[0] * submat_shape[1]);
diff --git a/libspu/mpc/cheetah/arithmetic.cc b/libspu/mpc/cheetah/arithmetic.cc
index 9dac9189..2a0fe91a 100644
--- a/libspu/mpc/cheetah/arithmetic.cc
+++ b/libspu/mpc/cheetah/arithmetic.cc
@@ -340,4 +340,29 @@ NdArrayRef MatMulAA::proc(KernelEvalContext* ctx, const NdArrayRef& x,
   return ring_add(ret, task.get()).as(x.eltype());
 }

+NdArrayRef MatMulAV::proc(KernelEvalContext* ctx, const NdArrayRef& x,
+                          const NdArrayRef& y) const {
+  if (0 == x.numel() || 0 == y.numel()) {
+    return NdArrayRef(x.eltype(), {x.shape()[0], y.shape()[1]});
+  }
+  auto* comm = ctx->getState<Communicator>();
+  auto* dot_prot = ctx->getState<CheetahDotState>()->get();
+  const int rank = comm->getRank();
+  const auto* ptype = y.eltype().as<Priv2kTy>();
+  SPU_ENFORCE(ptype != nullptr, "rhs should be a private type");
+  const int owner = ptype->owner();
+  NdArrayRef out;
+  const Shape3D dim3 = {x.shape()[0], x.shape()[1], y.shape()[1]};
+  // (x0 + x1)*y = <x0*y>_0 + <x0*y>_1 + x1*y
+  if (rank == owner) {
+    // Compute <x0*y>
+    out = dot_prot->DotOLE(y, dim3, false);
+    auto local = ring_mmul(x, y);
+    ring_add_(out, local);
+  } else {
+    out = dot_prot->DotOLE(x, dim3, true);
+  }
+  return out.as(x.eltype());
+}
+
 }  // namespace spu::mpc::cheetah
diff --git a/libspu/mpc/cheetah/arithmetic.h b/libspu/mpc/cheetah/arithmetic.h
index b66d38b5..3af3713f 100644
--- a/libspu/mpc/cheetah/arithmetic.h
+++ b/libspu/mpc/cheetah/arithmetic.h
@@ -155,6 +155,17 @@ class MatMulAP : public MatmulKernel {
                   const NdArrayRef& y) const override;
 };

+class MatMulAV : public MatmulKernel {
+ public:
+  static constexpr char kBindName[] = "mmul_av";
+
+  Kind kind() const override { return Kind::Dynamic; }
+  // LHS: m x k
+  // RHS: k x n
+  NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& x,
+                  const NdArrayRef& y) const override;
+};
+
 class MatMulAA : public MatmulKernel {
  public:
   static constexpr char kBindName[] = "mmul_aa";
diff --git a/libspu/mpc/cheetah/ot/yacl/BUILD.bazel b/libspu/mpc/cheetah/ot/yacl/BUILD.bazel
index a7181688..1268c392 100644
--- a/libspu/mpc/cheetah/ot/yacl/BUILD.bazel
+++ b/libspu/mpc/cheetah/ot/yacl/BUILD.bazel
@@ -22,7 +22,6 @@ spu_cc_library(
     srcs = [
         "ferret.cc",
         "yacl_ote_adapter.cc",
-        "yacl_util.cc",
     ],
     hdrs = [
         "ferret.h",
@@ -37,12 +36,15 @@ spu_cc_library(
         "//libspu/mpc/cheetah/ot:ot_util",
         "//libspu/mpc/common:communicator",
         "//libspu/mpc/semi2k:conversion",
+        "@yacl//yacl/base:aligned_vector",
+        "@yacl//yacl/base:buffer",
         "@yacl//yacl/base:dynamic_bitset",
         "@yacl//yacl/base:int128",
         "@yacl//yacl/crypto/base/aes:aes_opt",
         "@yacl//yacl/crypto/primitives/ot:base_ot",
         "@yacl//yacl/crypto/primitives/ot:ferret_ote",
         "@yacl//yacl/crypto/primitives/ot:iknp_ote",
+        "@yacl//yacl/crypto/primitives/ot:softspoken_ote",
         "@yacl//yacl/crypto/tools:crhash",
         "@yacl//yacl/crypto/tools:rp",
         "@yacl//yacl/crypto/utils:rand",
diff --git a/libspu/mpc/cheetah/ot/yacl/ferret.cc b/libspu/mpc/cheetah/ot/yacl/ferret.cc
index b21a589d..941fe6f1 100644
--- a/libspu/mpc/cheetah/ot/yacl/ferret.cc
+++ b/libspu/mpc/cheetah/ot/yacl/ferret.cc
@@ -222,7 +222,8 @@ struct YaclFerretOt::Impl {
                 "bit_width={} out-of-range T={} bits", bit_width,
                 sizeof(T) * 8);

-    yacl::AlignedVector<uint128_t> rcm_output(n);
+    yacl::Buffer buf(n * sizeof(uint128_t));
+    auto rcm_output = MakeSpan_Uint128(buf);

     SendRandCorrelatedMsgChosenChoice(rcm_output.data(), n);

@@ -262,6 +263,7 @@ struct YaclFerretOt::Impl {
         io_->send_data(corr_output.data(), sizeof(T) * this_batch);
       }
     }
+    io_->flush();
   }

 template <typename T>
@@ -276,8 +278,10 @@ struct YaclFerretOt::Impl {
                 "bit_width={} out-of-range T={} bits", bit_width,
                 sizeof(T) * 8);

-    yacl::AlignedVector<uint128_t> rcm_output(n);
-    RecvRandCorrelatedMsgChosenChoice(choices, absl::MakeSpan(rcm_output));
+    yacl::Buffer buf(n * sizeof(uint128_t));
+    auto rcm_output = MakeSpan_Uint128(buf);
+
+    RecvRandCorrelatedMsgChosenChoice(choices, rcm_output);

     std::array<uint128_t, kOTBatchSize> pad;
     std::vector<T> corr_output(kOTBatchSize);
@@ -347,7 +351,10 @@ struct YaclFerretOt::Impl {
                                        size_t n) {
     SPU_ENFORCE(msg0 != nullptr && msg1 != nullptr);
     SPU_ENFORCE(n > 0);
-    yacl::AlignedVector<uint128_t> rcm_data(n);
+
+    yacl::Buffer buf(n * sizeof(uint128_t));
+    auto rcm_data = MakeSpan_Uint128(buf);
+
     SendRandCorrelatedMsgChosenChoice(rcm_data.data(), n);

     uint128_t delta = ferret_->GetDelta();
@@ -368,6 +375,7 @@
       io_->send_data(pad.data(), 2 * sizeof(uint128_t) * this_batch);
     }
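For reference, `MatMulAV` above hinges on the share bookkeeping in its comment: with x additively shared as x0 + x1 and y held in the clear by one party, x*y = x0*y + x1*y; the owner of y computes its local term directly, while DotOLE produces additive shares of the cross term. A plain-integer sanity check of that bookkeeping (uint64_t wraparound standing in for the ring Z_2^k; the `dotole0`/`dotole1` split is an arbitrary made-up stand-in for the OLE output shares):

    #include <cassert>
    #include <cstdint>

    int main() {
      // x = x0 + x1 (additive shares over Z_2^64), y private to party 1.
      uint64_t x0 = 1111, x1 = 2222, y = 7;
      // DotOLE yields additive shares of x0 * y, revealing neither input;
      // any share split works for checking the arithmetic.
      uint64_t dotole0 = 123456;            // party 0's share of x0 * y
      uint64_t dotole1 = x0 * y - dotole0;  // party 1's share of x0 * y
      // Party 1 (the owner of y) folds in its local product x1 * y.
      uint64_t out0 = dotole0;
      uint64_t out1 = dotole1 + x1 * y;
      assert(out0 + out1 == (x0 + x1) * y);  // reconstructs x * y
      return 0;
    }

+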
io_->flush(); } void RecvChosenMsgChosenChoice(absl::Span choices, @@ -399,15 +407,17 @@ struct YaclFerretOt::Impl { SPU_ENFORCE_EQ(n, output1.size()); const T mask = makeBitsMask(bit_width); - yacl::AlignedVector rm_data(2 * n); - auto* rm_data0 = rm_data.data(); - auto* rm_data1 = rm_data.data() + n; - SendRandMsgRandChoice({rm_data0, n}, {rm_data1, n}); + yacl::Buffer buf(2 * n * sizeof(uint128_t)); + auto rm_data = MakeSpan_Uint128(buf); + + auto rm_data0 = rm_data.subspan(0, n); + auto rm_data1 = rm_data.subspan(n, n); + SendRandMsgRandChoice(rm_data0, rm_data1); - std::transform(rm_data0, rm_data0 + n, output0.data(), - [mask](uint128_t x) { return (T)x & mask; }); - std::transform(rm_data1, rm_data1 + n, output1.data(), - [mask](uint128_t x) { return (T)x & mask; }); + std::transform(rm_data0.cbegin(), rm_data0.cend(), output0.data(), + [mask](const uint128_t& x) { return (T)x & mask; }); + std::transform(rm_data1.cbegin(), rm_data1.cend(), output1.data(), + [mask](const uint128_t& x) { return (T)x & mask; }); } template @@ -418,12 +428,13 @@ struct YaclFerretOt::Impl { SPU_ENFORCE_EQ(n, output.size()); const T mask = makeBitsMask(bit_width); - yacl::AlignedVector rm_data(n); + yacl::Buffer buf(n * sizeof(uint128_t)); + auto rm_data = MakeSpan_Uint128(buf); - RecvRandMsgRandChoice(choices, absl::MakeSpan(rm_data)); + RecvRandMsgRandChoice(choices, rm_data); - std::transform(rm_data.begin(), rm_data.end(), output.data(), - [mask](uint128_t x) { return ((T)x) & mask; }); + std::transform(rm_data.cbegin(), rm_data.cend(), output.data(), + [mask](const uint128_t& x) { return ((T)x) & mask; }); } template @@ -442,13 +453,18 @@ struct YaclFerretOt::Impl { // Send: (s_{0, j}, s_{1, j}) for 0 <= j < logN // Recv: c_j \in {0, 1} - yacl::AlignedVector rm_data0(n * logN); - yacl::AlignedVector rm_data1(n * logN); + yacl::Buffer buf_data0(n * logN * sizeof(uint128_t)); + yacl::Buffer buf_data1(n * logN * sizeof(uint128_t)); + auto rm_data0 = MakeSpan_Uint128(buf_data0); + auto rm_data1 = MakeSpan_Uint128(buf_data1); SendRandMsgChosenChoice(rm_data0.data(), rm_data1.data(), n * logN); - yacl::AlignedVector hash_in0(N - 1); - yacl::AlignedVector hash_in1(N - 1); + yacl::Buffer buf_in0((N - 1) * sizeof(uint128_t)); + yacl::Buffer buf_in1((N - 1) * sizeof(uint128_t)); + auto hash_in0 = MakeSpan_Uint128(buf_in0); + auto hash_in1 = MakeSpan_Uint128(buf_in1); + { size_t idx = 0; for (size_t x = 0; x < logN; ++x) { @@ -460,9 +476,13 @@ struct YaclFerretOt::Impl { } } - yacl::AlignedVector hash_out0(N - 1); - yacl::AlignedVector hash_out1(N - 1); - yacl::AlignedVector pad(kOTBatchSize * N); + yacl::Buffer buf_out0((N - 1) * sizeof(uint128_t)); + yacl::Buffer buf_out1((N - 1) * sizeof(uint128_t)); + yacl::Buffer buf_pad(kOTBatchSize * N * sizeof(uint128_t)); + + auto hash_out0 = MakeSpan_Uint128(buf_out0); + auto hash_out1 = MakeSpan_Uint128(buf_out1); + auto pad = MakeSpan_Uint128(buf_pad); const T msg_mask = makeBitsMask(bit_width); size_t eltsize = 8 * sizeof(T); @@ -481,10 +501,12 @@ struct YaclFerretOt::Impl { std::memset(pad.data(), 0, pad.size() * sizeof(uint128_t)); for (size_t j = 0; j < this_batch; ++j) { - mitccrh_exp_.renew_ks(&rm_data0[(i + j) * logN], logN); + mitccrh_exp_.renew_ks( + reinterpret_cast(&rm_data0[(i + j) * logN]), logN); mitccrh_exp_.hash_exp(hash_out0.data(), hash_in0.data(), logN); - mitccrh_exp_.renew_ks(&rm_data1[(i + j) * logN], logN); + mitccrh_exp_.renew_ks( + reinterpret_cast(&rm_data1[(i + j) * logN]), logN); mitccrh_exp_.hash_exp(hash_out1.data(), hash_in1.data(), 
logN); for (size_t k = 0; k < N; ++k) { @@ -521,6 +543,7 @@ struct YaclFerretOt::Impl { io_->send_data(to_send.data(), N * this_batch * sizeof(T)); } } + io_->flush(); } template @@ -548,13 +571,18 @@ struct YaclFerretOt::Impl { // rm_data[logN * i + k] = 1-of-2 OT on the k-th bits of the i-th // message - yacl::AlignedVector rm_data(n * logN); - RecvRandMsgChosenChoice(absl::MakeSpan(bool_choices), - absl::MakeSpan(rm_data)); + yacl::Buffer buf(n * logN * sizeof(uint128_t)); + auto rm_data = MakeSpan_Uint128(buf); + + RecvRandMsgChosenChoice(absl::MakeSpan(bool_choices), rm_data); + + yacl::Buffer buf_in(logN * sizeof(uint128_t)); + yacl::Buffer buf_out(logN * sizeof(uint128_t)); + yacl::Buffer buf_pad(kOTBatchSize * sizeof(uint128_t)); - yacl::AlignedVector hash_in(logN); - yacl::AlignedVector hash_out(logN); - yacl::AlignedVector pad(kOTBatchSize); + auto hash_in = MakeSpan_Uint128(buf_in); + auto hash_out = MakeSpan_Uint128(buf_out); + auto pad = MakeSpan_Uint128(buf_pad); const T msg_mask = makeBitsMask(bit_width); size_t eltsize = 8 * sizeof(T); @@ -584,7 +612,8 @@ struct YaclFerretOt::Impl { auto h = choices[i + j] & makeBitsMask(1 + s); hash_in[s] = yacl::MakeUint128(h, 0); } - mitccrh_exp_.renew_ks(&rm_data[(i + j) * logN], logN); + mitccrh_exp_.renew_ks( + reinterpret_cast(&rm_data[(i + j) * logN]), logN); mitccrh_exp_.hash_single(hash_out.data(), hash_in.data(), logN); pad[j] = std::accumulate(hash_out.begin(), hash_out.end(), pad[j], @@ -604,17 +633,21 @@ struct YaclFerretOt::Impl { SPU_ENFORCE(n > 0); SPU_ENFORCE_EQ(n, output1.size()); - yacl::AlignedVector rm_data(2 * n); + yacl::Buffer buf(2 * n * sizeof(uint128_t)); + auto rm_data = MakeSpan_Uint128(buf); + auto* rm_data0 = rm_data.data(); auto* rm_data1 = rm_data0 + n; SendRandMsgChosenChoice(rm_data0, rm_data1, n); // Type conversion const T msg_mask = makeBitsMask(bit_width); - std::transform(rm_data0, rm_data0 + n, output0.data(), - [msg_mask](uint128_t val) { return ((T)val) & msg_mask; }); - std::transform(rm_data1, rm_data1 + n, output1.data(), - [msg_mask](uint128_t val) { return ((T)val) & msg_mask; }); + std::transform( + rm_data0, rm_data0 + n, output0.data(), + [msg_mask](const uint128_t& val) { return ((T)val) & msg_mask; }); + std::transform( + rm_data1, rm_data1 + n, output1.data(), + [msg_mask](const uint128_t& val) { return ((T)val) & msg_mask; }); } // Modified by @wenfan @@ -625,13 +658,16 @@ struct YaclFerretOt::Impl { SPU_ENFORCE(n > 0); SPU_ENFORCE_EQ(n, output.size()); - yacl::AlignedVector rm_data(n); - RecvRandMsgChosenChoice(choices, absl::MakeSpan(rm_data)); + yacl::Buffer buf(n * sizeof(uint128_t)); + auto rm_data = MakeSpan_Uint128(buf); + + RecvRandMsgChosenChoice(choices, rm_data); // Type conversion const T msg_mask = makeBitsMask(bit_width); - std::transform(rm_data.begin(), rm_data.end(), output.begin(), - [msg_mask](uint128_t val) { return ((T)val) & msg_mask; }); + std::transform( + rm_data.begin(), rm_data.end(), output.begin(), + [msg_mask](const uint128_t& val) { return ((T)val) & msg_mask; }); } // Inplace diff --git a/libspu/mpc/cheetah/ot/yacl/ferret.h b/libspu/mpc/cheetah/ot/yacl/ferret.h index 6b0daf1d..25bcec06 100644 --- a/libspu/mpc/cheetah/ot/yacl/ferret.h +++ b/libspu/mpc/cheetah/ot/yacl/ferret.h @@ -20,6 +20,7 @@ #include "yacl/base/int128.h" #include "libspu/mpc/cheetah/ot/ferret_ot_interface.h" +#include "libspu/mpc/cheetah/ot/yacl/yacl_util.h" #include "libspu/mpc/common/communicator.h" namespace spu::mpc::cheetah { diff --git 
a/libspu/mpc/cheetah/ot/yacl/yacl_ferret_test.cc b/libspu/mpc/cheetah/ot/yacl/yacl_ferret_test.cc deleted file mode 100644 index a9298a69..00000000 --- a/libspu/mpc/cheetah/ot/yacl/yacl_ferret_test.cc +++ /dev/null @@ -1,211 +0,0 @@ -// Copyright 2023 Ant Group Co., Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "libspu/mpc/cheetah/ot/yacl/yacl_ferret.h" - -#include - -#include "gtest/gtest.h" - -#include "libspu/core/xt_helper.h" -#include "libspu/mpc/cheetah/yacl_ot/basic_ot_prot.h" -#include "libspu/mpc/semi2k/type.h" -#include "libspu/mpc/utils/ring_ops.h" -#include "libspu/mpc/utils/simulate.h" - -namespace spu::mpc::cheetah::test { - -class YaclFerretTest : public testing::TestWithParam {}; - -INSTANTIATE_TEST_SUITE_P( - Cheetah, YaclFerretTest, - testing::Values(FieldType::FM32, FieldType::FM64, FieldType::FM128), - [](const testing::TestParamInfo &p) { - return fmt::format("{}", p.param); - }); - -TEST_P(YaclFerretTest, ChosenCorrelationChosenChoice) { - size_t kWorldSize = 2; - int64_t n = 10; - auto field = GetParam(); - - auto _correlation = ring_rand(field, {n}); - std::vector choices(n); - std::default_random_engine rdv; - std::uniform_int_distribution uniform(0, -1); - std::generate_n(choices.begin(), n, [&]() -> uint8_t { - return static_cast(uniform(rdv) & 1); - }); - - DISPATCH_ALL_FIELDS(field, "", [&]() { - NdArrayView correlation(_correlation); - std::vector computed[2]; - utils::simulate(kWorldSize, [&](std::shared_ptr ctx) { - auto conn = std::make_shared(ctx); - int rank = ctx->Rank(); - computed[rank].resize(n); - YaclFerretOt ferret(conn, rank == 0); - if (rank == 0) { - ferret.SendCAMCC({correlation.data(), correlation.size()}, - absl::MakeSpan(computed[0])); - ferret.Flush(); - } else { - ferret.RecvCAMCC(absl::MakeSpan(choices), absl::MakeSpan(computed[1])); - } - }); - - for (int64_t i = 0; i < n; ++i) { - ring2k_t c = -computed[0][i] + computed[1][i]; - ring2k_t e = choices[i] ? correlation[i] : 0; - EXPECT_EQ(e, c); - } - }); -} - -TEST_P(YaclFerretTest, RndMsgRndChoice) { - size_t kWorldSize = 2; - auto field = GetParam(); - constexpr size_t bw = 2; - - size_t n = 10; - DISPATCH_ALL_FIELDS(field, "", [&]() { - std::vector msg0(n); - std::vector msg1(n); - ring2k_t max = static_cast(1) << bw; - - std::vector choices(n); - std::vector selected(n); - - utils::simulate(kWorldSize, [&](std::shared_ptr ctx) { - auto conn = std::make_shared(ctx); - int rank = ctx->Rank(); - YaclFerretOt ferret(conn, rank == 0); - if (rank == 0) { - ferret.SendRMRC(absl::MakeSpan(msg0), absl::MakeSpan(msg1), bw); - ferret.Flush(); - } else { - ferret.RecvRMRC(absl::MakeSpan(choices), absl::MakeSpan(selected), bw); - } - }); - - for (size_t i = 0; i < n; ++i) { - ring2k_t e = choices[i] ? 
msg1[i] : msg0[i]; - ring2k_t c = selected[i]; - EXPECT_TRUE(choices[i] < 2); - EXPECT_LT(e, max); - EXPECT_LT(c, max); - EXPECT_EQ(e, c); - } - }); -} - -TEST_P(YaclFerretTest, RndMsgChosenChoice) { - size_t kWorldSize = 2; - auto field = GetParam(); - constexpr size_t bw = 2; - - size_t n = 10; - DISPATCH_ALL_FIELDS(field, "", [&]() { - std::vector msg0(n); - std::vector msg1(n); - ring2k_t max = static_cast(1) << bw; - - std::vector choices(n); - std::default_random_engine rdv; - std::uniform_int_distribution uniform(0, -1); - std::generate_n(choices.begin(), n, [&]() -> uint8_t { - return static_cast(uniform(rdv) & 1); - }); - - std::vector selected(n); - - utils::simulate(kWorldSize, [&](std::shared_ptr ctx) { - auto conn = std::make_shared(ctx); - int rank = ctx->Rank(); - YaclFerretOt ferret(conn, rank == 0); - if (rank == 0) { - ferret.SendRMCC(absl::MakeSpan(msg0), absl::MakeSpan(msg1), bw); - ferret.Flush(); - } else { - ferret.RecvRMCC(absl::MakeSpan(choices), absl::MakeSpan(selected), bw); - } - }); - - for (size_t i = 0; i < n; ++i) { - ring2k_t e = choices[i] ? msg1[i] : msg0[i]; - ring2k_t c = selected[i]; - EXPECT_LT(e, max); - EXPECT_LT(c, max); - EXPECT_EQ(e, c); - } - }); -} - -TEST_P(YaclFerretTest, ChosenMsgChosenChoice) { - size_t kWorldSize = 2; - int64_t n = 106; - auto field = GetParam(); - DISPATCH_ALL_FIELDS(field, "", [&]() { - using scalar_t = ring2k_t; - std::default_random_engine rdv; - std::uniform_int_distribution uniform(0, -1); - for (size_t bw : {2UL, 4UL, sizeof(scalar_t) * 8}) { - scalar_t mask = (static_cast(1) << bw) - 1; - // @wenfan - // N = 4,5,6,7,8 would raise errors as following: - // 1. (N=8) --> corrupted double-linked list - // 2. (otherwise) --> free(): invalid next size (fast) - for (int64_t N : {2, 3, 4, 5, 6, 7, 8}) { - SPDLOG_INFO(fmt::format("bw is {}, N is {}", bw, N)); - auto _msg = ring_rand(field, {N * n}); - auto msg = xt_mutable_adapt(_msg); - msg &= mask; - std::vector choices(n); - - std::generate_n(choices.begin(), n, [&]() -> uint8_t { - return static_cast(uniform(rdv) % N); - }); - - std::vector selected(n); - - utils::simulate( - kWorldSize, [&](std::shared_ptr ctx) { - auto conn = std::make_shared(ctx); - int rank = ctx->Rank(); - { - YaclFerretOt ferret(conn, rank == 0); - if (rank == 0) { - ferret.SendCMCC({msg.data(), msg.size()}, N, bw); - ferret.Flush(); - } else { - ferret.RecvCMCC(absl::MakeSpan(choices), N, - absl::MakeSpan(selected), bw); - } - } - SPDLOG_INFO(fmt::format("Rank {} End of Simulation", rank)); - }); - SPDLOG_INFO(fmt::format("Test Checking")); - - for (int64_t i = 0; i < n; ++i) { - scalar_t e = msg[i * N + choices[i]]; - scalar_t c = selected[i]; - EXPECT_EQ(e, c); - } - SPDLOG_INFO(fmt::format("End of Test Case")); - } - } - }); -} - -} // namespace spu::mpc::cheetah::test diff --git a/libspu/mpc/cheetah/ot/yacl/yacl_ote_adapter.cc b/libspu/mpc/cheetah/ot/yacl/yacl_ote_adapter.cc index 19c3f2e2..21f11db7 100644 --- a/libspu/mpc/cheetah/ot/yacl/yacl_ote_adapter.cc +++ b/libspu/mpc/cheetah/ot/yacl/yacl_ote_adapter.cc @@ -19,6 +19,10 @@ namespace spu::mpc::cheetah { namespace yc = yacl::crypto; namespace yl = yacl::link; +// ------------------------ +// FerretOTeAdapter +// ------------------------ + uint128_t YaclFerretOTeAdapter::yacl_id_ = 0; void YaclFerretOTeAdapter::OneTimeSetup() { @@ -27,68 +31,54 @@ void YaclFerretOTeAdapter::OneTimeSetup() { } uint128_t pre_lpn_num_ = yc::FerretCotHelper(pre_lpn_param_, 0); - // Sender if (is_sender_) { - auto choices = yc::RandBits>(128, true); - // In 
Compact mode, the last bit of delta is one - choices.data()[0] = (choices.data()[0] | ~one); + auto ss_sender = yc::SoftspokenOtExtSender(); + ss_sender.OneTimeSetup(ctx_); + + auto ss_send_blocks = + yacl::AlignedVector>(pre_lpn_num_, {0, 0}); + ss_sender.Send(ctx_, absl::MakeSpan(ss_send_blocks), true); + + auto ss_send_block0 = yacl::AlignedVector(pre_lpn_num_, 0); + std::transform( + ss_send_blocks.cbegin(), ss_send_blocks.cend(), ss_send_block0.begin(), + [&one = std::as_const(one)](const std::array& blocks) { + return blocks[0] & one; + }); + + Delta = ss_sender.GetDelta() | ~one; + auto pre_ferret_sent_ot = + yc::MakeCompactOtSendStore(std::move(ss_send_block0), Delta); - // Generate BaseOT for IKNP-OTe - auto base_ot = yc::BaseOtRecv(ctx_, choices, 128); - // Invoke IKNP-OTe to generate COT - auto iknp_send_ot = yc::IknpOtExtSend(ctx_, base_ot, pre_lpn_num_, true); - Delta = iknp_send_ot.GetDelta(); - - // Notice !!! - // IknpOtExt Protocol would generate Normal mode OtStore - // But ferret OTe require COMPACT mode OtStore - yc::OtSendStore pre_ferret_sent_ot_ = - yc::OtSendStore(pre_lpn_num_, yc::OtStoreType::Compact); - pre_ferret_sent_ot_.SetDelta(Delta); - // Warning: copy, low efficiency - for (uint64_t i = 0; i < pre_lpn_num_; ++i) { - pre_ferret_sent_ot_.SetCompactBlock(i, iknp_send_ot.GetBlock(i, 0) & one); - } // pre ferret OTe - auto send_ot_store = yc::FerretOtExtSend(ctx_, pre_ferret_sent_ot_, - pre_lpn_param_, pre_lpn_param_.n); - // fill ot_buff_ - for (size_t i = 0; i < pre_lpn_param_.n; ++i) { - ot_buff_[i] = send_ot_store.GetBlock(i, 0); - } + yc::FerretOtExtSend_cheetah( + ctx_, pre_ferret_sent_ot, pre_lpn_param_, pre_lpn_param_.n, + absl::MakeSpan(ot_buff_.data(), pre_lpn_param_.n)); } // Receiver else { - // Generate BaseOT for IKNP-OTe - auto base_ot = yc::BaseOtSend(ctx_, 128); - // Random choices for IKNP-OTe - auto choices = + auto ss_receiver = yc::SoftspokenOtExtReceiver(); + ss_receiver.OneTimeSetup(ctx_); + + auto ss_choices = yc::RandBits>(pre_lpn_num_, true); - // Invoke IKNP-OTe to generate COT - auto iknp_recv_ot = - yc::IknpOtExtRecv(ctx_, base_ot, choices, pre_lpn_num_, true); - - // Notice !!! 
- // IknpOtExt Protocol would generate Normal mode OtStore - // But ferret OTe require COMPACT mode OtStore - yc::OtRecvStore pre_ferret_recv_ot_ = - yc::OtRecvStore(pre_lpn_num_, yc::OtStoreType::Compact); - // Warning: copy, low efficiency + auto ss_recv_blocks = yacl::AlignedVector(pre_lpn_num_, 0); + + ss_receiver.Recv(ctx_, ss_choices, absl::MakeSpan(ss_recv_blocks), true); + for (uint64_t i = 0; i < pre_lpn_num_; ++i) { - uint128_t block = (iknp_recv_ot.GetBlock(i) & one) | choices[i]; - pre_ferret_recv_ot_.SetBlock(i, block); + ss_recv_blocks[i] = (ss_recv_blocks[i] & one) | ss_choices[i]; } + yc::OtRecvStore pre_ferret_recv_ot = + yc::MakeCompactOtRecvStore(std::move(ss_recv_blocks)); + // pre ferret OTe - auto recv_ot_store = yc::FerretOtExtRecv(ctx_, pre_ferret_recv_ot_, - pre_lpn_param_, pre_lpn_param_.n); - // fill ot_buff_ - for (size_t i = 0; i < pre_lpn_param_.n; ++i) { - ot_buff_[i] = recv_ot_store.GetBlock(i); - } + yc::FerretOtExtRecv_cheetah( + ctx_, pre_ferret_recv_ot, pre_lpn_param_, pre_lpn_param_.n, + absl::MakeSpan(ot_buff_.data(), pre_lpn_param_.n)); } - is_setup_ = true; buff_used_num_ = reserve_num_; buff_upper_bound_ = pre_lpn_param_.n; @@ -104,14 +94,13 @@ void YaclFerretOTeAdapter::rcot(absl::Span data) { uint64_t data_offset = 0; uint64_t require_num = data.size(); uint64_t remain_num = buff_upper_bound_ - buff_used_num_; - // When require_num is greater than lpn_param.n // call FerretOTe with data's subspan to avoid memory copy { uint32_t bootstrap_inplace_counter = 0; absl::Span ot_span = - absl::MakeSpan(ot_buff_.data(), reserve_num_); - while (require_num > lpn_param_.n) { + absl::MakeSpan(ot_buff_.data(), reserve_num_); + while (require_num >= lpn_param_.n) { // avoid memory copy BootstrapInplace(ot_span, data.subspan(data_offset, lpn_param_.n)); @@ -123,14 +112,16 @@ void YaclFerretOTeAdapter::rcot(absl::Span data) { ot_span = data.subspan(data_offset, reserve_num_); } if (bootstrap_inplace_counter != 0) { - memcpy(ot_buff_.data(), ot_span.data(), reserve_num_ * sizeof(uint128_t)); + std::memcpy(reinterpret_cast(ot_buff_.data()), + ot_span.data(), reserve_num_ * sizeof(uint128_t)); } } uint64_t ot_num = std::min(remain_num, require_num); - memcpy(data.data() + data_offset, ot_buff_.data() + buff_used_num_, - ot_num * sizeof(uint128_t)); + std::memcpy(data.data() + data_offset, + ot_buff_.data() + buff_used_num_, + ot_num * sizeof(uint128_t)); buff_used_num_ += ot_num; // add state @@ -148,7 +139,8 @@ void YaclFerretOTeAdapter::rcot(absl::Span data) { if (require_num > (buff_upper_bound_ - reserve_num_)) { SPDLOG_WARN("[YACL] Worst Case!!! 
current require_num {}", require_num); // Bootstrap would reset buff_used_num_ - memcpy(data.data() + data_offset, ot_buff_.data() + buff_used_num_, + memcpy(data.data() + data_offset, + ot_buff_.data() + reserve_num_, (buff_upper_bound_ - reserve_num_) * sizeof(uint128_t)); require_num -= (buff_upper_bound_ - reserve_num_); consumed_ot_num_ += (buff_upper_bound_ - reserve_num_); @@ -157,7 +149,8 @@ void YaclFerretOTeAdapter::rcot(absl::Span data) { // Bootstrap would reset buff_used_num_ Bootstrap(); } - memcpy(data.data() + data_offset, ot_buff_.data() + buff_used_num_, + memcpy(data.data() + data_offset, + ot_buff_.data() + buff_used_num_, require_num * sizeof(uint128_t)); buff_used_num_ += require_num; consumed_ot_num_ += require_num; @@ -207,17 +200,17 @@ void YaclFerretOTeAdapter::recv_cot( void YaclFerretOTeAdapter::Bootstrap() { auto begin = std::chrono::high_resolution_clock::now(); if (is_sender_) { - yacl::AlignedVector send_ot(ot_buff_.begin(), - ot_buff_.begin() + reserve_num_); + yacl::AlignedVector send_ot( + ot_buff_.data(), ot_buff_.data() + reserve_num_); auto send_ot_store = yc::MakeCompactOtSendStore(std::move(send_ot), Delta); yc::FerretOtExtSend_cheetah(ctx_, send_ot_store, lpn_param_, lpn_param_.n, - absl::MakeSpan(ot_buff_.data(), lpn_param_.n)); + MakeSpan_Uint128(ot_buff_)); } else { - yacl::AlignedVector recv_ot(ot_buff_.begin(), - ot_buff_.begin() + reserve_num_); + yacl::AlignedVector recv_ot( + ot_buff_.data(), ot_buff_.data() + reserve_num_); auto recv_ot_store = yc::MakeCompactOtRecvStore(std::move(recv_ot)); yc::FerretOtExtRecv_cheetah(ctx_, recv_ot_store, lpn_param_, lpn_param_.n, - absl::MakeSpan(ot_buff_.data(), lpn_param_.n)); + MakeSpan_Uint128(ot_buff_)); } auto end = std::chrono::high_resolution_clock::now(); auto elapse = @@ -260,13 +253,16 @@ void YaclFerretOTeAdapter::BootstrapInplace(absl::Span ot, bootstrap_time_ += elapse * 1000; } +// ------------------------ +// IknpOTeAdapter +// ------------------------ + uint128_t YaclIknpOTeAdapter::yacl_id_ = 0; void YaclIknpOTeAdapter::OneTimeSetup() { if (is_setup_) { return; } - // Sender if (is_sender_) { auto choices = yc::RandBits>(128, true); @@ -286,4 +282,99 @@ void YaclIknpOTeAdapter::OneTimeSetup() { is_setup_ = true; } +void YaclIknpOTeAdapter::send_cot(absl::Span data) { + YACL_ENFORCE(is_sender_); + auto begin = std::chrono::high_resolution_clock::now(); + + // [Warning] copy, low efficiency + yacl::Buffer send_buf(2 * data.size() * sizeof(uint128_t)); + // std::vector> send_blocks(data.size()); + auto send_span = absl::MakeSpan( + reinterpret_cast*>(send_buf.data()), + data.size()); + yc::IknpOtExtSend(ctx_, *recv_ot_ptr_, send_span, true); + std::transform( + send_span.cbegin(), send_span.cend(), data.begin(), + [](const std::array& blocks) { return blocks[0]; }); + + auto end = std::chrono::high_resolution_clock::now(); + auto elapse = + std::chrono::duration_cast>(end - begin) + .count(); + ote_time_ += elapse * 1000; + consumed_ot_num_ += data.size(); + ++ote_num_; +} + +void YaclIknpOTeAdapter::recv_cot( + absl::Span data, + const yacl::dynamic_bitset& choices) { + YACL_ENFORCE(is_sender_ == false); + auto begin = std::chrono::high_resolution_clock::now(); + yc::IknpOtExtRecv(ctx_, *send_ot_ptr_, choices, absl::MakeSpan(data), true); + auto end = std::chrono::high_resolution_clock::now(); + auto elapse = + std::chrono::duration_cast>(end - begin) + .count(); + ote_time_ += elapse * 1000; + consumed_ot_num_ += data.size(); + ++ote_num_; +} + +// ------------------------ +// 
SoftspokenOTeAdapter
+// ------------------------
+
+uint128_t YaclSsOTeAdapter::yacl_id_ = 0;
+
+void YaclSsOTeAdapter::OneTimeSetup() {
+  if (is_setup_) {
+    return;
+  }
+  if (is_sender_) {
+    ss_sender_->OneTimeSetup(ctx_);
+    Delta = ss_sender_->GetDelta();
+  } else {
+    ss_receiver_->OneTimeSetup(ctx_);
+  }
+}
+
+void YaclSsOTeAdapter::send_cot(absl::Span<uint128_t> data) {
+  YACL_ENFORCE(is_sender_);
+  auto begin = std::chrono::high_resolution_clock::now();
+  // [Warning] copy, low efficiency
+  yacl::Buffer send_buf(2 * data.size() * sizeof(uint128_t));
+  // std::vector<std::array<uint128_t, 2>> send_blocks(data.size());
+  auto send_span = absl::MakeSpan(
+      reinterpret_cast<std::array<uint128_t, 2>*>(send_buf.data()),
+      data.size());
+
+  ss_sender_->Send(ctx_, send_span, true);
+  std::transform(
+      send_span.cbegin(), send_span.cend(), data.begin(),
+      [](const std::array<uint128_t, 2>& blocks) { return blocks[0]; });
+
+  auto end = std::chrono::high_resolution_clock::now();
+  auto elapse =
+      std::chrono::duration_cast<std::chrono::duration<double>>(end - begin)
+          .count();
+  ote_time_ += elapse * 1000;
+  consumed_ot_num_ += data.size();
+  ++ote_num_;
+}
+
+void YaclSsOTeAdapter::recv_cot(
+    absl::Span<uint128_t> data,
+    const yacl::dynamic_bitset<uint128_t>& choices) {
+  YACL_ENFORCE(is_sender_ == false);
+  auto begin = std::chrono::high_resolution_clock::now();
+  ss_receiver_->Recv(ctx_, choices, data, true);
+  auto end = std::chrono::high_resolution_clock::now();
+  auto elapse =
+      std::chrono::duration_cast<std::chrono::duration<double>>(end - begin)
+          .count();
+  ote_time_ += elapse * 1000;
+  consumed_ot_num_ += data.size();
+  ++ote_num_;
+}
 };  // namespace spu::mpc::cheetah
diff --git a/libspu/mpc/cheetah/ot/yacl/yacl_ote_adapter.h b/libspu/mpc/cheetah/ot/yacl/yacl_ote_adapter.h
index 1904a9c1..fcae303c 100644
--- a/libspu/mpc/cheetah/ot/yacl/yacl_ote_adapter.h
+++ b/libspu/mpc/cheetah/ot/yacl/yacl_ote_adapter.h
@@ -19,6 +19,7 @@
 #include "yacl/crypto/primitives/ot/ferret_ote.h"
 #include "yacl/crypto/primitives/ot/iknp_ote.h"
 #include "yacl/crypto/primitives/ot/ot_store.h"
+#include "yacl/crypto/primitives/ot/softspoken_ote.h"
 #include "yacl/crypto/utils/rand.h"

 #include "libspu/core/prelude.h"
@@ -48,12 +49,13 @@ class YaclOTeAdapter {

 class YaclFerretOTeAdapter : public YaclOTeAdapter {
  public:
-  YaclFerretOTeAdapter(const std::shared_ptr<yl::Context> ctx, bool is_sender) {
-    ctx_ = ctx->Spawn();  // Spawn link
+  YaclFerretOTeAdapter(const std::shared_ptr<yl::Context>& ctx,
+                       bool is_sender) {
+    ctx_ = ctx;
     is_sender_ = is_sender;
     reserve_num_ = yc::FerretCotHelper(lpn_param_, 0);
-    ot_buff_.resize(lpn_param_.n);
+    ot_buff_ = yacl::Buffer(lpn_param_.n * sizeof(uint128_t));

     id_ = yacl_id_;
     ++yacl_id_;
@@ -121,11 +123,15 @@ class YaclFerretOTeAdapter : public YaclOTeAdapter {

   uint64_t buff_upper_bound_{0};

-  yacl::AlignedVector<uint128_t> ot_buff_;  // ot buffer
+  // We choose `yacl::Buffer` instead of `yacl::AlignedVector`, because
+  // `yacl::AlignedVector<T>` or `std::vector<T>` would fill the
+  // vector with initializing data. When `size` is a big number, it would
+  // take lots of time to set the memory (ten million for thirty milliseconds).
+  // Thus, we use `yacl::Buffer` to avoid meaningless initialization.
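The comment above is the rationale for the recurring `yacl::Buffer buf(n * sizeof(uint128_t)); auto span = MakeSpan_Uint128(buf);` pattern throughout ferret.cc: allocate raw bytes once, then view them as `uint128_t` without value-initializing millions of elements that are about to be overwritten anyway. A self-contained sketch of the same allocate-without-initializing idea in standard C++ (`unsigned long long` standing in for `uint128_t`, `new[]` standing in for `yacl::Buffer`):

    #include <memory>
    #include <vector>

    using u128_stub = unsigned long long;  // stand-in for yacl's uint128_t

    int main() {
      const size_t n = 10'000'000;

      // std::vector<T>(n) value-initializes all n elements: O(n) writes up front.
      std::vector<u128_stub> zeroed(n);

      // new T[n] only allocates; contents stay indeterminate until written,
      // which suits an OT output buffer that is fully overwritten.
      std::unique_ptr<u128_stub[]> raw(new u128_stub[n]);

      // MakeSpan_Uint128(buf) plays this role over yacl::Buffer: a typed view
      // of buf.size() / sizeof(uint128_t) elements over the raw bytes.
      u128_stub* view = raw.get();
      view[0] = 42;  // write-before-read, as the ferret.cc call sites do

      return static_cast<int>(zeroed[0] + view[0] - 42);  // 0
    }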
+ yacl::Buffer ot_buff_; // ot buffer // Yacl Ferret OTe void Bootstrap(); - // Yacl Ferret OTe void BootstrapInplace(absl::Span ot, absl::Span data); @@ -139,15 +145,15 @@ class YaclFerretOTeAdapter : public YaclOTeAdapter { class YaclIknpOTeAdapter : public YaclOTeAdapter { public: - YaclIknpOTeAdapter(const std::shared_ptr ctx, bool is_sender) { - ctx_ = ctx->Spawn(); // Spawn link + YaclIknpOTeAdapter(const std::shared_ptr& ctx, bool is_sender) { + ctx_ = ctx; is_sender_ = is_sender; id_ = yacl_id_; ++yacl_id_; } ~YaclIknpOTeAdapter() { - SPDLOG_INFO( + SPDLOG_DEBUG( "[IknpAdapter {}]({}), comsume OT {}, total time {:.3e} ms," "invoke IKNP-OTe {} ( {:.2e} ms per iknp , {:.2e} ms per ot )", id_, (is_sender_ ? fmt::format("Sender") : fmt::format("Receiver")), @@ -175,38 +181,10 @@ class YaclIknpOTeAdapter : public YaclOTeAdapter { // IKNP ENTRY // Correlated Cot with Chosen Choices - void send_cot(absl::Span data) override { - YACL_ENFORCE(is_sender_); - auto begin = std::chrono::high_resolution_clock::now(); - - // [Warning] copy, low efficiency - std::vector> send_blocks(data.size()); - yc::IknpOtExtSend(ctx_, *recv_ot_ptr_, absl::MakeSpan(send_blocks), true); - for (uint64_t i = 0; i < data.size(); ++i) { - data[i] = send_blocks[i][0]; - } - auto end = std::chrono::high_resolution_clock::now(); - auto elapse = - std::chrono::duration_cast>(end - begin) - .count(); - ote_time_ += elapse * 1000; - consumed_ot_num_ += data.size(); - ++ote_num_; - } + void send_cot(absl::Span data) override; void recv_cot(absl::Span data, - const yacl::dynamic_bitset& choices) { - YACL_ENFORCE(is_sender_ == false); - auto begin = std::chrono::high_resolution_clock::now(); - yc::IknpOtExtRecv(ctx_, *send_ot_ptr_, choices, absl::MakeSpan(data), true); - auto end = std::chrono::high_resolution_clock::now(); - auto elapse = - std::chrono::duration_cast>(end - begin) - .count(); - ote_time_ += elapse * 1000; - consumed_ot_num_ += data.size(); - ++ote_num_; - } + const yacl::dynamic_bitset& choices); uint128_t GetDelta() const override { return Delta; } @@ -233,4 +211,86 @@ class YaclIknpOTeAdapter : public YaclOTeAdapter { static uint128_t yacl_id_; }; +class YaclSsOTeAdapter : public YaclOTeAdapter { + public: + // LocalHost or 10000Mbps, set k = 2 + // 1000Mbps, set k = 4 + // 500Mbps, set k = 5 + // 200Mbps, set k = 7 + // 100Mbps or lower, set k = 8 + YaclSsOTeAdapter(const std::shared_ptr& ctx, bool is_sender, + uint64_t k = 2) { + ctx_ = ctx; + is_sender_ = is_sender; + + if (is_sender_) { + ss_sender_ = std::make_unique(k); + } else { + ss_receiver_ = std::make_unique(k); + } + + id_ = yacl_id_; + ++yacl_id_; + } + + ~YaclSsOTeAdapter() { + SPDLOG_DEBUG( + "[Destructor] SoftspokenAdapter work as {}, total comsume OT {}, " + "invoke softspoken {}, softspoken time {} ms, {} ms per softspoken , " + "{} ms per ot ", + (is_sender_ ? 
fmt::format("Sender") : fmt::format("Receiver")), + consumed_ot_num_, ote_num_, ote_time_, ote_time_ / ote_num_, + ote_time_ / consumed_ot_num_); + } + + void OneTimeSetup() override; + + void recv_cot(absl::Span data, + absl::Span choices) override { + recv_cot(data, VecU8toBitset(choices)); + } + + inline void send_rcot(absl::Span data) override { send_cot(data); } + + inline void recv_rcot(absl::Span data, + absl::Span choices) override { + auto _choices = + yc::RandBits>(data.size(), true); + BitsettoVecU8(_choices, choices); + + recv_cot(data, _choices); + } + + // Softspoken ENTRY + // Correlated Cot with Chosen Choices + void send_cot(absl::Span data) override; + + void recv_cot(absl::Span data, + const yacl::dynamic_bitset& choices); + + uint128_t GetDelta() const override { return Delta; } + + uint128_t GetConsumed() const { return consumed_ot_num_; } + + double GetTime() const { return ote_time_; } + + private: + std::shared_ptr ctx_{nullptr}; + + std::unique_ptr ss_sender_{nullptr}; + + std::unique_ptr ss_receiver_{nullptr}; + + bool is_sender_{false}; + + bool is_setup_{false}; + + // Just for test + uint128_t consumed_ot_num_{0}; + uint128_t ote_num_{0}; // number of invoke ote protocol + double ote_time_{0.0}; // ms + // Debug only + uint128_t id_{0}; + static uint128_t yacl_id_; +}; } // namespace spu::mpc::cheetah diff --git a/libspu/mpc/cheetah/ot/yacl/yacl_util.cc b/libspu/mpc/cheetah/ot/yacl/yacl_util.cc deleted file mode 100644 index 74f57185..00000000 --- a/libspu/mpc/cheetah/ot/yacl/yacl_util.cc +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright 2023 Ant Group Co., Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "libspu/mpc/cheetah/ot/yacl/yacl_util.h" - -#include - -#include "libspu/core/prelude.h" - -namespace spu::mpc::cheetah { - -uint8_t BoolToU8(absl::Span bits) { - size_t len = bits.size(); - SPU_ENFORCE(len >= 1 && len <= 8); - return std::accumulate( - bits.data(), bits.data() + len, - /*init*/ static_cast(0), - [](uint8_t init, uint8_t next) { return (init << 1) | (next & 1); }); -} - -void U8ToBool(absl::Span bits, uint8_t u8) { - size_t len = std::min(8UL, bits.size()); - SPU_ENFORCE(len >= 1); - for (size_t i = 0; i < len; ++i) { - bits[i] = (u8 & 1); - u8 >>= 1; - } -} - -} // namespace spu::mpc::cheetah diff --git a/libspu/mpc/cheetah/ot/yacl/yacl_util.h b/libspu/mpc/cheetah/ot/yacl/yacl_util.h index 55987aae..c2ac2b73 100644 --- a/libspu/mpc/cheetah/ot/yacl/yacl_util.h +++ b/libspu/mpc/cheetah/ot/yacl/yacl_util.h @@ -15,6 +15,8 @@ #pragma once #include "absl/types/span.h" +#include "yacl/base/aligned_vector.h" +#include "yacl/base/buffer.h" #include "yacl/base/dynamic_bitset.h" #include "yacl/base/int128.h" @@ -57,4 +59,8 @@ inline std::vector BitsettoVecU8( return bits; } +absl::Span inline MakeSpan_Uint128(yacl::Buffer& buf) { + return absl::MakeSpan(buf.data(), buf.size() / sizeof(uint128_t)); +} + } // namespace spu::mpc::cheetah diff --git a/libspu/mpc/cheetah/protocol.cc b/libspu/mpc/cheetah/protocol.cc index 79edf22b..2a921fd3 100644 --- a/libspu/mpc/cheetah/protocol.cc +++ b/libspu/mpc/cheetah/protocol.cc @@ -24,6 +24,7 @@ #include "libspu/mpc/cheetah/state.h" #include "libspu/mpc/cheetah/type.h" #include "libspu/mpc/common/pv2k.h" +#include "libspu/mpc/standard_shape/protocol.h" namespace spu::mpc { @@ -48,43 +49,27 @@ void regCheetahProtocol(SPUContext* ctx, // register public kernels. regPV2kKernels(ctx->prot()); + // Register standard shape ops + regStandardShapeOps(ctx); + // register arithmetic & binary kernels - ctx->prot()->regKernel(); - ctx->prot()->regKernel(); - ctx->prot()->regKernel(); - ctx->prot()->regKernel(); - ctx->prot()->regKernel(); - ctx->prot()->regKernel(); - ctx->prot()->regKernel(); - ctx->prot()->regKernel(); - ctx->prot()->regKernel(); - ctx->prot()->regKernel(); - ctx->prot()->regKernel(); - ctx->prot()->regKernel(); - ctx->prot()->regKernel(); - ctx->prot()->regKernel(); - // ctx->prot()->regKernel(); - ctx->prot()->regKernel(); - ctx->prot()->regKernel(); - ctx->prot()->regKernel(); - - ctx->prot()->regKernel(); - ctx->prot()->regKernel(); - ctx->prot()->regKernel(); - ctx->prot()->regKernel(); - ctx->prot()->regKernel(); - ctx->prot()->regKernel(); - ctx->prot()->regKernel(); - - ctx->prot()->regKernel(); - ctx->prot()->regKernel(); - ctx->prot()->regKernel(); - ctx->prot()->regKernel(); - ctx->prot()->regKernel(); - ctx->prot()->regKernel(); - ctx->prot()->regKernel(); - ctx->prot()->regKernel(); - ctx->prot()->regKernel(); + ctx->prot() + ->regKernel(); } std::unique_ptr makeCheetahProtocol( diff --git a/libspu/mpc/cheetah/rlwe/modswitch_helper.cc b/libspu/mpc/cheetah/rlwe/modswitch_helper.cc index 2bfe3237..69c86704 100644 --- a/libspu/mpc/cheetah/rlwe/modswitch_helper.cc +++ b/libspu/mpc/cheetah/rlwe/modswitch_helper.cc @@ -433,7 +433,7 @@ void ModulusSwitchHelper::ModulusUpAt(const NdArrayRef &src, size_t mod_idx, const size_t numel = src.numel(); SPU_ENFORCE_EQ(numel, out.size()); SPU_ENFORCE(src.shape().size() == 1, "need 1D array"); - SPU_ENFORCE(eltype.isa(), "source must be ring_type, got={}", eltype); + SPU_ENFORCE(eltype.isa(), "source must be ring_type, got={}", eltype); const auto field = eltype.as()->field(); 
DISPATCH_ALL_FIELDS(field, "ModulusUpAt", [&]() { @@ -448,7 +448,7 @@ void ModulusSwitchHelper::CenteralizeAt(const NdArrayRef &src, size_t mod_idx, const Type &eltype = src.eltype(); const size_t numel = src.numel(); SPU_ENFORCE_EQ(numel, out.size()); - SPU_ENFORCE(eltype.isa(), "source must be ring_type, got={}", eltype); + SPU_ENFORCE(eltype.isa(), "source must be ring_type, got={}", eltype); const auto field = eltype.as()->field(); DISPATCH_ALL_FIELDS(field, "CenteralizeAt", [&]() { using ring2u = std::make_unsigned::type; @@ -473,7 +473,7 @@ void ModulusSwitchHelper::ModulusDownRNS(absl::Span src, NdArrayRef out) const { yacl::CheckNotNull(impl_.get()); auto eltype = out.eltype(); - SPU_ENFORCE(eltype.isa(), "must be ring_type, got={}", eltype); + SPU_ENFORCE(eltype.isa(), "must be ring_type, got={}", eltype); auto field = eltype.as()->field(); SPU_ENFORCE(out.isCompact(), "need compact output"); diff --git a/libspu/mpc/common/pv2k.cc b/libspu/mpc/common/pv2k.cc index cd020dd6..3fed08e3 100644 --- a/libspu/mpc/common/pv2k.cc +++ b/libspu/mpc/common/pv2k.cc @@ -14,6 +14,7 @@ #include "libspu/mpc/common/pv2k.h" +#include #include #include "libspu/core/ndarray_ref.h" @@ -25,11 +26,15 @@ namespace spu::mpc { namespace { -bool isOwner(KernelEvalContext* ctx, const Type& type) { +inline bool isOwner(KernelEvalContext* ctx, const Type& type) { auto* comm = ctx->getState(); return type.as()->owner() == static_cast(comm->getRank()); } +inline int64_t getOwner(const NdArrayRef& x) { + return x.eltype().as()->owner(); +} + class P2V : public RevealToKernel { public: static constexpr char kBindName[] = "p2v"; @@ -681,6 +686,260 @@ class BitrevV : public BitrevKernel { } }; +class GenInvPermP : public GenInvPermKernel { + public: + static constexpr char kBindName[] = "gen_inv_perm_p"; + + ce::CExpr latency() const override { return ce::Const(0); } + + ce::CExpr comm() const override { return ce::Const(0); } + + NdArrayRef proc(KernelEvalContext*, const NdArrayRef& in, + bool is_ascending) const override { + const auto field = in.eltype().as()->field(); + NdArrayRef out(makeType(field), in.shape()); + + auto numel = in.numel(); + + DISPATCH_ALL_FIELDS(field, "gen_inv_perm_p", [&]() { + using T = std::make_signed_t; + std::vector perm(numel); + std::iota(perm.begin(), perm.end(), 0); + // TODO: Add an iterator for NdArrayView + NdArrayView _in(in); + NdArrayView _out(out); + auto cmp = [&_in, is_ascending](int64_t a, int64_t b) { + return is_ascending ? _in[a] < _in[b] : _in[a] > _in[b]; + }; + std::stable_sort(perm.begin(), perm.end(), cmp); + for (int64_t idx = 0; idx < numel; ++idx) { + _out[perm[idx]] = idx; + } + }); + return out; + } +}; + +class GenInvPermV : public GenInvPermKernel { + public: + static constexpr char kBindName[] = "gen_inv_perm_v"; + + ce::CExpr latency() const override { return ce::Const(0); } + + ce::CExpr comm() const override { return ce::Const(0); } + + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, + bool is_ascending) const override { + if (isOwner(ctx, in.eltype())) { + NdArrayRef out(in.eltype(), in.shape()); + auto numel = in.numel(); + const auto field = in.eltype().as()->field(); + + DISPATCH_ALL_FIELDS(field, "gen_inv_perm_v", [&]() { + using T = std::make_signed_t; + std::vector perm(numel); + std::iota(perm.begin(), perm.end(), 0); + // TODO: Add an iterator for NdArrayView + NdArrayView _in(in); + NdArrayView _out(out); + auto cmp = [&_in, is_ascending](int64_t a, int64_t b) { + return is_ascending ? 
_in[a] < _in[b] : _in[a] > _in[b]; + }; + std::stable_sort(perm.begin(), perm.end(), cmp); + for (int64_t idx = 0; idx < numel; ++idx) { + _out[perm[idx]] = idx; + } + }); + return out; + } else { + return in; + } + } +}; + +class InvPermPP : public PermKernel { + public: + static constexpr char kBindName[] = "inv_perm_pp"; + + ce::CExpr latency() const override { return ce::Const(0); } + ce::CExpr comm() const override { return ce::Const(0); } + + NdArrayRef proc(KernelEvalContext*, const NdArrayRef& x, + const NdArrayRef& y) const override { + SPU_ENFORCE_EQ(x.eltype(), y.eltype()); + NdArrayRef z(x.eltype(), x.shape()); + const auto field = x.eltype().as()->field(); + DISPATCH_ALL_FIELDS(field, "_", [&]() { + using T = std::make_signed_t; + NdArrayView _x(x); + NdArrayView _y(y); + NdArrayView _z(z); + for (int64_t idx = 0; idx < x.numel(); ++idx) { + _z[_y[idx]] = _x[idx]; + } + }); + return z; + } +}; + +class InvPermVV : public PermKernel { + public: + static constexpr char kBindName[] = "inv_perm_vv"; + + ce::CExpr latency() const override { return ce::Const(0); } + ce::CExpr comm() const override { return ce::Const(0); } + + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& x, + const NdArrayRef& y) const override { + SPU_ENFORCE_EQ(x.eltype(), y.eltype()); + if (isOwner(ctx, x.eltype())) { + NdArrayRef z(x.eltype(), x.shape()); + const auto field = x.eltype().as()->field(); + DISPATCH_ALL_FIELDS(field, "_", [&]() { + using T = std::make_signed_t; + NdArrayView _x(x); + NdArrayView _y(y); + NdArrayView _z(z); + for (int64_t idx = 0; idx < x.numel(); ++idx) { + _z[_y[idx]] = _x[idx]; + } + }); + return z; + } else { + return x; + } + } +}; + +class PermPP : public PermKernel { + public: + static constexpr char kBindName[] = "perm_pp"; + + ce::CExpr latency() const override { return ce::Const(0); } + ce::CExpr comm() const override { return ce::Const(0); } + + NdArrayRef proc(KernelEvalContext*, const NdArrayRef& x, + const NdArrayRef& y) const override { + SPU_ENFORCE_EQ(x.eltype(), y.eltype()); + NdArrayRef z(x.eltype(), x.shape()); + const auto field = x.eltype().as()->field(); + DISPATCH_ALL_FIELDS(field, "_", [&]() { + using T = std::make_signed_t; + NdArrayView _x(x); + NdArrayView _y(y); + NdArrayView _z(z); + for (int64_t idx = 0; idx < x.numel(); ++idx) { + _z[idx] = _x[_y[idx]]; + } + }); + return z; + } +}; + +class PermVV : public PermKernel { + public: + static constexpr char kBindName[] = "perm_vv"; + + ce::CExpr latency() const override { return ce::Const(0); } + ce::CExpr comm() const override { return ce::Const(0); } + + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& x, + const NdArrayRef& y) const override { + SPU_ENFORCE_EQ(x.eltype(), y.eltype()); + if (isOwner(ctx, x.eltype())) { + NdArrayRef z(x.eltype(), x.shape()); + const auto field = x.eltype().as()->field(); + DISPATCH_ALL_FIELDS(field, "_", [&]() { + using T = std::make_signed_t; + NdArrayView _x(x); + NdArrayView _y(y); + NdArrayView _z(z); + for (int64_t idx = 0; idx < x.numel(); ++idx) { + _z[idx] = _x[_y[idx]]; + } + }); + return z; + } else { + return x; + } + } +}; + +class MergeKeysP : public MergeKeysKernel { + public: + static constexpr char kBindName[] = "merge_keys_p"; + + ce::CExpr latency() const override { return ce::Const(0); } + ce::CExpr comm() const override { return ce::Const(0); } + + NdArrayRef proc(KernelEvalContext* ctx, absl::Span inputs, + bool is_ascending) const override { + SPU_ENFORCE(!inputs.empty(), "Inputs should not be empty"); + NdArrayRef 
+class MergeKeysP : public MergeKeysKernel {
+ public:
+  static constexpr char kBindName[] = "merge_keys_p";
+
+  ce::CExpr latency() const override { return ce::Const(0); }
+  ce::CExpr comm() const override { return ce::Const(0); }
+
+  NdArrayRef proc(KernelEvalContext* ctx, absl::Span<const NdArrayRef> inputs,
+                  bool is_ascending) const override {
+    SPU_ENFORCE(!inputs.empty(), "Inputs should not be empty");
+    NdArrayRef out(inputs[0].eltype(), inputs[0].shape());
+    const auto field = inputs[0].eltype().as<Ring2k>()->field();
+    const auto numel = inputs[0].numel();
+    DISPATCH_ALL_FIELDS(field, "_", [&]() {
+      using T = std::make_signed_t<ring2k_t>;
+      NdArrayView<T> _out(out);
+      _out[0] = 0;
+      for (int64_t i = 1; i < numel; ++i) {
+        if (std::all_of(inputs.begin(), inputs.end(), [i](const NdArrayRef& x) {
+              NdArrayView<T> _x(x);
+              return _x[i] == _x[i - 1];
+            })) {
+          _out[i] = _out[i - 1];
+        } else {
+          _out[i] = is_ascending ? _out[i - 1] + 1 : _out[i - 1] - 1;
+        }
+      }
+    });
+    return out;
+  }
+};
+
+class MergeKeysV : public MergeKeysKernel {
+ public:
+  static constexpr char kBindName[] = "merge_keys_v";
+
+  ce::CExpr latency() const override { return ce::Const(0); }
+  ce::CExpr comm() const override { return ce::Const(0); }
+
+  NdArrayRef proc(KernelEvalContext* ctx, absl::Span<const NdArrayRef> inputs,
+                  bool is_ascending) const override {
+    SPU_ENFORCE(!inputs.empty(), "Inputs should not be empty");
+    SPU_ENFORCE(std::all_of(inputs.begin(), inputs.end(),
+                            [&inputs](const NdArrayRef& v) {
+                              return getOwner(v) == getOwner(inputs[0]);
+                            }),
+                "Inputs should belong to the same owner");
+
+    if (isOwner(ctx, inputs[0].eltype())) {
+      NdArrayRef out(inputs[0].eltype(), inputs[0].shape());
+      const auto field = inputs[0].eltype().as<Ring2k>()->field();
+      const auto numel = inputs[0].numel();
+      DISPATCH_ALL_FIELDS(field, "_", [&]() {
+        using T = std::make_signed_t<ring2k_t>;
+        NdArrayView<T> _out(out);
+        _out[0] = 0;
+        for (int64_t i = 1; i < numel; ++i) {
+          if (std::all_of(inputs.begin(), inputs.end(),
+                          [i](const NdArrayRef& x) {
+                            NdArrayView<T> _x(x);
+                            return _x[i] == _x[i - 1];
+                          })) {
+            _out[i] = _out[i - 1];
+          } else {
+            _out[i] = is_ascending ? _out[i - 1] + 1 : _out[i - 1] - 1;
+          }
+        }
+      });
+      return out;
+    } else {
+      return makeConstantArrayRef(inputs[0].eltype(), inputs[0].shape());
+    }
+  }
+};
+
 }  // namespace
 
 void regPV2kTypes() {
@@ -691,42 +950,25 @@ void regPV2kTypes() {
 }
 
 void regPV2kKernels(Object* obj) {
-  obj->regKernel();
-  obj->regKernel();
-  obj->regKernel();
-  obj->regKernel();
-  obj->regKernel();
-  obj->regKernel();
-  obj->regKernel();
-  obj->regKernel();
-  obj->regKernel();
-  obj->regKernel();
-  obj->regKernel();
-  obj->regKernel();
-  obj->regKernel();
-  obj->regKernel();
-  obj->regKernel();
-  obj->regKernel();
-  obj->regKernel();
-  obj->regKernel();
-  obj->regKernel();
-  obj->regKernel();
-  obj->regKernel();
-  obj->regKernel();
-  obj->regKernel();
-  obj->regKernel();
-  obj->regKernel();
-  obj->regKernel();
-  obj->regKernel();
-  obj->regKernel();
-  obj->regKernel();
-  obj->regKernel();
-  obj->regKernel();
-  obj->regKernel();
-  obj->regKernel();
-  obj->regKernel();
-  obj->regKernel();
-  obj->regKernel();
+  obj->regKernel();
 }
 
 }  // namespace spu::mpc
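merge_keys collapses several key columns into a single rank key: a position shares the previous rank exactly when every column repeats its previous value, and otherwise the rank steps by one (up or down with the sort direction). The same scan, sketched over plain vectors:

#include <cstdint>
#include <vector>

// Mirrors the MergeKeys scan: rows equal in all columns share a rank.
std::vector<int64_t> merge_keys(
    const std::vector<std::vector<int64_t>>& cols, bool is_ascending) {
  const size_t n = cols.front().size();
  std::vector<int64_t> out(n, 0);
  for (size_t i = 1; i < n; ++i) {
    bool same = true;
    for (const auto& c : cols) same = same && (c[i] == c[i - 1]);
    out[i] = same ? out[i - 1]
                  : (is_ascending ? out[i - 1] + 1 : out[i - 1] - 1);
  }
  return out;  // cols {1,1,2} and {5,5,7} -> {0,0,1}
}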
diff --git a/libspu/mpc/kernel.cc b/libspu/mpc/kernel.cc
index 6e8f8082..860db047 100644
--- a/libspu/mpc/kernel.cc
+++ b/libspu/mpc/kernel.cc
@@ -122,14 +122,113 @@ void CastTypeKernel::evaluate(KernelEvalContext* ctx) const {
 }
 
 void PermKernel::evaluate(KernelEvalContext* ctx) const {
-  const auto& in = ctx->getParam<Value>(0);
-  const auto& perm = ctx->getParam<Value>(1);
+  const auto& x = ctx->getParam<Value>(0);
+  const auto& y = ctx->getParam<Value>(1);
+
+  SPU_ENFORCE(x.shape() == y.shape(), "shape mismatch {} {}", x.shape(),
+              y.shape());
+  SPU_ENFORCE(x.shape().ndim() == 1, "input should be a 1-d tensor");
 
-  SPU_ENFORCE(in.shape() == perm.shape(), "shape mismatch {} {}", in.shape(),
-              perm.shape());
+  auto z = proc(ctx, UnwrapValue(x), UnwrapValue(y));
+
+  ctx->setOutput(WrapValue(z));
+}
+
+void GenInvPermKernel::evaluate(KernelEvalContext* ctx) const {
+  const auto& in = ctx->getParam<Value>(0);
+  bool is_ascending = ctx->getParam<bool>(1);
   SPU_ENFORCE(in.shape().ndim() == 1, "input should be a 1-d tensor");
 
-  auto z = proc(ctx, UnwrapValue(in), UnwrapValue(perm));
+  auto y = proc(ctx, UnwrapValue(in), is_ascending);
+
+  ctx->setOutput(WrapValue(y));
+}
+
+void MergeKeysKernel::evaluate(KernelEvalContext* ctx) const {
+  const auto& in = ctx->getParam<std::vector<Value>>(0);
+  bool is_ascending = ctx->getParam<bool>(1);
+  std::vector<NdArrayRef> inputs;
+  for (size_t i = 0; i < in.size(); ++i) {
+    inputs.push_back(UnwrapValue(in[i]));
+  }
+  auto y = proc(ctx, inputs, is_ascending);
+
+  ctx->setOutput(WrapValue(y));
+}
+
+void BroadcastKernel::evaluate(KernelEvalContext* ctx) const {
+  const auto& in = ctx->getParam<Value>(0);
+  const auto& to_shape = ctx->getParam<Shape>(1);
+  const auto& in_dims = ctx->getParam<Axes>(2);
+
+  auto z = proc(ctx, UnwrapValue(in), to_shape, in_dims);
+
+  ctx->setOutput(WrapValue(z));
+}
+
+void DimsBasedKernel::evaluate(KernelEvalContext* ctx) const {
+  const auto& in = ctx->getParam<Value>(0);
+  const auto& axes = ctx->getParam<Axes>(1);
+
+  auto z = proc(ctx, UnwrapValue(in), axes);
+
+  ctx->setOutput(WrapValue(z));
+}
+
+void ShapeBasedKernel::evaluate(KernelEvalContext* ctx) const {
+  const auto& in = ctx->getParam<Value>(0);
+  const auto& to_shape = ctx->getParam<Shape>(1);
+
+  auto z = proc(ctx, UnwrapValue(in), to_shape);
+
+  ctx->setOutput(WrapValue(z));
+}
+
+void ExtractSliceKernel::evaluate(KernelEvalContext* ctx) const {
+  const auto& in = ctx->getParam<Value>(0);
+  const auto& start = ctx->getParam<Index>(1);
+  const auto& end = ctx->getParam<Index>(2);
+  const auto& strides = ctx->getParam<Strides>(3);
+
+  auto z = proc(ctx, UnwrapValue(in), start, end, strides);
+
+  ctx->setOutput(WrapValue(z));
+}
+
+void UpdateSliceKernel::evaluate(KernelEvalContext* ctx) const {
+  const auto& in = ctx->getParam<Value>(0);
+  const auto& update = ctx->getParam<Value>(1);
+  const auto& start = ctx->getParam<Index>(2);
+
+  auto z = proc(ctx, UnwrapValue(in), UnwrapValue(update), start);
+
+  ctx->setOutput(WrapValue(z));
+}
+
+void PadKernel::evaluate(KernelEvalContext* ctx) const {
+  const auto& in = ctx->getParam<Value>(0);
+  const auto& padding_value = ctx->getParam<Value>(1);
+  const auto& edge_low = ctx->getParam<Sizes>(2);
+  const auto& edge_high = ctx->getParam<Sizes>(3);
+  const auto& interior_padding = ctx->getParam<Sizes>(4);
+
+  auto z = proc(ctx, UnwrapValue(in), UnwrapValue(padding_value), edge_low,
+                edge_high, interior_padding);
+
+  ctx->setOutput(WrapValue(z));
+}
+
+void ConcateKernel::evaluate(KernelEvalContext* ctx) const {
+  const auto& ins = ctx->getParam<std::vector<Value>>(0);
+  const auto& axis = ctx->getParam<int64_t>(1);
+
+  std::vector<NdArrayRef> unwrapped(ins.size());
+
+  for (size_t idx = 0; idx < ins.size(); ++idx) {
+    unwrapped[idx] = UnwrapValue(ins[idx]);
+  }
+
+  auto z = proc(ctx, unwrapped, axis);
 
   ctx->setOutput(WrapValue(z));
 }
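All of these evaluate wrappers follow one shape: read typed positional parameters from the KernelEvalContext, unwrap Value handles into NdArrayRef, call the virtual proc, and wrap the result back. That is what makes name-based dispatch work at the call site. A sketch of such a call, assuming the dynDispatch helper as it appears elsewhere in the SPU codebase (treat the exact spelling and signature here as an assumption, not a quote from this patch):

// Sketch only: invoking the "broadcast" kernel through dynamic dispatch.
// BroadcastKernel::evaluate decodes these positional parameters.
spu::Value broadcast(spu::SPUContext* ctx, const spu::Value& in,
                     const spu::Shape& to_shape, const spu::Axes& in_dims) {
  return spu::dynDispatch(ctx, "broadcast", in, to_shape, in_dims);
}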
diff --git a/libspu/mpc/kernel.h b/libspu/mpc/kernel.h
index 53b6cf62..bbc49903 100644
--- a/libspu/mpc/kernel.h
+++ b/libspu/mpc/kernel.h
@@ -127,4 +127,80 @@ class PermKernel : public Kernel {
                           const NdArrayRef& perm) const = 0;
 };
 
+class GenInvPermKernel : public Kernel {
+ public:
+  void evaluate(KernelEvalContext* ctx) const override;
+  virtual NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in,
+                          bool is_ascending) const = 0;
+};
+
+class MergeKeysKernel : public Kernel {
+ public:
+  void evaluate(KernelEvalContext* ctx) const override;
+  virtual NdArrayRef proc(KernelEvalContext* ctx,
+                          absl::Span<const NdArrayRef> inputs,
+                          bool is_ascending) const = 0;
+};
+
+class BroadcastKernel : public Kernel {
+ public:
+  void evaluate(KernelEvalContext* ctx) const override;
+
+  virtual NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in,
+                          const Shape& to_shape, const Axes& in_dims) const = 0;
+};
+
+class DimsBasedKernel : public Kernel {
+ public:
+  void evaluate(KernelEvalContext* ctx) const override;
+
+  virtual NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in,
+                          const Axes& perm) const = 0;
+};
+
+class ShapeBasedKernel : public Kernel {
+ public:
+  void evaluate(KernelEvalContext* ctx) const override;
+
+  virtual NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in,
+                          const Shape& to_shape) const = 0;
+};
+
+class ExtractSliceKernel : public Kernel {
+ public:
+  void evaluate(KernelEvalContext* ctx) const override;
+  virtual NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in,
+                          const Index& start, const Index& end,
+                          const Strides& strides) const = 0;
+};
+
+class UpdateSliceKernel : public Kernel {
+ public:
+  void evaluate(KernelEvalContext* ctx) const override;
+
+  virtual NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in,
+                          const NdArrayRef& update,
+                          const Index& start) const = 0;
+};
+
+class PadKernel : public Kernel {
+ public:
+  void evaluate(KernelEvalContext* ctx) const override;
+
+  virtual NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in,
+                          const NdArrayRef& padding_value,
+                          const Sizes& edge_padding_low,
+                          const Sizes& edge_padding_high,
+                          const Sizes& interior_padding) const = 0;
+};
+
+class ConcateKernel : public Kernel {
+ public:
+  void evaluate(KernelEvalContext* ctx) const override;
+
+  virtual NdArrayRef proc(KernelEvalContext* ctx,
+                          const std::vector<NdArrayRef>& values,
+                          int64_t axis) const = 0;
+};
+
 }  // namespace spu::mpc
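A backend only implements proc; evaluate stays in these base classes. For illustration, a hypothetical kernel built on ShapeBasedKernel (the class, bind name, and behavior are invented for this sketch and are not part of the patch):

// Hypothetical example only: a local reshape kernel on the new base class.
namespace spu::mpc::example {

class ReshapeNoCopy : public ShapeBasedKernel {
 public:
  static constexpr char kBindName[] = "reshape_nocopy";  // invented name

  ce::CExpr latency() const override { return ce::Const(0); }
  ce::CExpr comm() const override { return ce::Const(0); }

  NdArrayRef proc(KernelEvalContext*, const NdArrayRef& in,
                  const Shape& to_shape) const override {
    SPU_ENFORCE(in.shape().numel() == to_shape.numel(),
                "reshape cannot change the element count");
    return in.reshape(to_shape);  // metadata-only change, no communication
  }
};

}  // namespace spu::mpc::example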
   regPV2kKernels(ctx->prot());
 
+  // Register standard shape ops
+  regStandardShapeOps(ctx);
+
   // register compute kernels
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
+  ctx->prot()
+      ->regKernel();
 }
 
 std::unique_ptr<SPUContext> makeRef2kProtocol(
diff --git a/libspu/mpc/securenn/BUILD.bazel b/libspu/mpc/securenn/BUILD.bazel
index cad24960..cec7bc2e 100644
--- a/libspu/mpc/securenn/BUILD.bazel
+++ b/libspu/mpc/securenn/BUILD.bazel
@@ -83,6 +83,7 @@ spu_cc_library(
         ":conversion",
         ":state",
         "//libspu/mpc/common:prg_state",
+        "//libspu/mpc/standard_shape:protocol",
     ],
 )
diff --git a/libspu/mpc/securenn/protocol.cc b/libspu/mpc/securenn/protocol.cc
index 41b85be7..0fe800b0 100644
--- a/libspu/mpc/securenn/protocol.cc
+++ b/libspu/mpc/securenn/protocol.cc
@@ -14,13 +14,14 @@
 
 #include "libspu/mpc/securenn/protocol.h"
 
+#include "libspu/mpc/common/communicator.h"
 #include "libspu/mpc/common/prg_state.h"
 #include "libspu/mpc/common/pv2k.h"
 #include "libspu/mpc/securenn/arithmetic.h"
 #include "libspu/mpc/securenn/boolean.h"
 #include "libspu/mpc/securenn/conversion.h"
-#include "libspu/mpc/securenn/state.h"
 #include "libspu/mpc/securenn/type.h"
+#include "libspu/mpc/standard_shape/protocol.h"
 
 namespace spu::mpc {
 
@@ -40,47 +41,30 @@ void regSecurennProtocol(SPUContext* ctx,
   // register public kernels.
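The registrations in this patch lean on a variadic regKernel<Ks...>() so each protocol lists its kernels in a single statement. The pattern itself is a one-line C++17 fold; a self-contained sketch of the idea (this mirrors the technique, not the actual Object API):

#include <memory>
#include <string>
#include <unordered_map>

// Sketch of variadic registration: regKernel<A, B, C>() expands into one
// registration per kernel type via a fold expression.
struct Registry {
  std::unordered_map<std::string, std::shared_ptr<void>> kernels;

  template <typename K>
  void regOne() {
    kernels.emplace(K::kBindName, std::make_shared<K>());
  }

  template <typename... Ks>
  void regKernel() {
    (regOne<Ks>(), ...);
  }
};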
   regPV2kKernels(ctx->prot());
 
+  // Register standard shape ops
+  regStandardShapeOps(ctx);
+
   // register arithmetic & binary kernels
   // ctx->prot()->addState();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-
-  ctx->prot()->regKernel();
-  // ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
+
+  ctx->prot()
+      ->regKernel<
+          securenn::P2A, securenn::A2P, securenn::A2V, securenn::V2A,  //
+          securenn::NotA,                                              //
+          securenn::AddAP, securenn::AddAA,                            //
+          securenn::MulAP, securenn::MulAA,                            //
+          securenn::MatMulAP, securenn::MatMulAA, securenn::MatMulAA_simple,  //
+          securenn::LShiftA, securenn::LShiftB, securenn::RShiftB,
+          securenn::ARShiftB,                                          //
+          securenn::Msb, securenn::Msb_opt,                            //
+          securenn::TruncAPr,                                          //
+          securenn::CommonTypeB, securenn::CommonTypeV, securenn::CastTypeB,
+          securenn::B2P, securenn::P2B, securenn::A2B, securenn::Msb_a2b,
+          /*securenn::B2A,*/ securenn::B2A_Randbit,                    //
+          securenn::AndBP, securenn::AndBB,                            //
+          securenn::XorBP, securenn::XorBB,                            //
+          securenn::BitrevB, securenn::BitIntlB, securenn::BitDeintlB,
+          securenn::RandA>();
 }
 
 std::unique_ptr<SPUContext> makeSecurennProtocol(
diff --git a/libspu/mpc/semi2k/BUILD.bazel b/libspu/mpc/semi2k/BUILD.bazel
index 2fe06595..825bf36f 100644
--- a/libspu/mpc/semi2k/BUILD.bazel
+++ b/libspu/mpc/semi2k/BUILD.bazel
@@ -85,6 +85,7 @@ spu_cc_library(
         ":permute",
         ":state",
         "//libspu/mpc/common:prg_state",
+        "//libspu/mpc/standard_shape:protocol",
     ],
 )
diff --git a/libspu/mpc/semi2k/permute.cc b/libspu/mpc/semi2k/permute.cc
index 2c2146c7..d647306b 100644
--- a/libspu/mpc/semi2k/permute.cc
+++ b/libspu/mpc/semi2k/permute.cc
@@ -31,6 +31,15 @@ NdArrayRef wrap_a2v(SPUContext* ctx, const NdArrayRef& x, size_t rank) {
   return UnwrapValue(a2v(ctx, WrapValue(x), rank));
 }
 
+inline bool isOwner(KernelEvalContext* ctx, const Type& type) {
+  auto* comm = ctx->getState<Communicator>();
+  return type.as<Priv2kTy>()->owner() ==
+         static_cast<int64_t>(comm->getRank());
+}
+
+inline int64_t getOwner(const NdArrayRef& x) {
+  return x.eltype().as<Priv2kTy>()->owner();
+}
+
 // Secure inverse permutation of x by perm_rank's permutation pv
 // The idea here is:
 // Input permutation pv, beaver generates perm pair {, } that
@@ -58,7 +67,7 @@ NdArrayRef SecureInvPerm(KernelEvalContext* ctx, const NdArrayRef& x,
 
 }  // namespace
 
-NdArrayRef RandPermS::proc(KernelEvalContext* ctx, const Shape& shape) const {
+NdArrayRef RandPermM::proc(KernelEvalContext* ctx, const Shape& shape) const {
   NdArrayRef out(makeType<PShrTy>(), shape);
 
   // generate a RandU64 as permutation seed
@@ -77,7 +86,7 @@ NdArrayRef RandPermM::proc(KernelEvalContext* ctx, const Shape& shape) const {
   return out;
 }
 
-NdArrayRef PermAS::proc(KernelEvalContext* ctx, const NdArrayRef& in,
+NdArrayRef PermAM::proc(KernelEvalContext* ctx, const NdArrayRef& in,
                         const NdArrayRef& perm) const {
   auto* comm = ctx->getState<Communicator>();
@@ -97,7 +106,7 @@ NdArrayRef PermAP::proc(KernelEvalContext* ctx, const NdArrayRef& in,
   return out;
 }
 
-NdArrayRef InvPermAS::proc(KernelEvalContext* ctx, const NdArrayRef& in,
+NdArrayRef InvPermAM::proc(KernelEvalContext* ctx, const NdArrayRef& in,
                            const NdArrayRef& perm) const {
   auto* comm = ctx->getState<Communicator>();
   PermVector pv = ring2pv(perm);
@@ -117,4 +126,15 @@ NdArrayRef InvPermAP::proc(KernelEvalContext* ctx, const NdArrayRef& in,
   return out;
 }
 
+NdArrayRef InvPermAV::proc(KernelEvalContext* ctx, const NdArrayRef& in,
+                           const NdArrayRef& perm) const {
+  PermVector pv;
+  const auto lctx = ctx->lctx();
+  if (isOwner(ctx, perm.eltype())) {
+    pv = ring2pv(perm);
+  }
+  auto out = SecureInvPerm(ctx, in, getOwner(perm), pv);
+  return out;
+}
+
 }  // namespace spu::mpc::semi2k
\ No newline at end of file
diff --git a/libspu/mpc/semi2k/permute.h b/libspu/mpc/semi2k/permute.h
index 0f0c3dca..20f35996 100644
--- a/libspu/mpc/semi2k/permute.h
+++ b/libspu/mpc/semi2k/permute.h
@@ -18,9 +18,9 @@
 
 namespace spu::mpc::semi2k {
 
-class RandPermS : public RandKernel {
+class RandPermM : public RandKernel {
  public:
-  static constexpr char kBindName[] = "rand_perm_s";
+  static constexpr char kBindName[] = "rand_perm_m";
 
   ce::CExpr latency() const override { return ce::Const(0); }
 
@@ -29,9 +29,9 @@ class RandPermM : public RandKernel {
   NdArrayRef proc(KernelEvalContext* ctx, const Shape& shape) const override;
 };
 
-class PermAS : public PermKernel {
+class PermAM : public PermKernel {
  public:
-  static constexpr char kBindName[] = "perm_as";
+  static constexpr char kBindName[] = "perm_am";
 
   ce::CExpr latency() const override { return ce::N(); }
 
@@ -53,9 +53,9 @@ class PermAP : public PermKernel {
                   const NdArrayRef& perm) const override;
 };
 
-class InvPermAS : public PermKernel {
+class InvPermAM : public PermKernel {
  public:
-  static constexpr char kBindName[] = "inv_perm_as";
+  static constexpr char kBindName[] = "inv_perm_am";
 
   ce::CExpr latency() const override { return ce::N(); }
 
@@ -77,4 +77,19 @@ class InvPermAP : public PermKernel {
                   const NdArrayRef& perm) const override;
 };
 
+class InvPermAV : public PermKernel {
+ public:
+  static constexpr char kBindName[] = "inv_perm_av";
+
+  // communication is unbalanced
+  Kind kind() const override { return Kind::Dynamic; }
+
+  ce::CExpr latency() const override { return ce::Const(1); }
+
+  ce::CExpr comm() const override { return ce::K(); }
+
+  NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in,
+                  const NdArrayRef& perm) const override;
+};
+
 }  // namespace spu::mpc::semi2k
\ No newline at end of file
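InvPermAV only materializes the cleartext permutation on its owner: ring2pv decodes the ring-encoded perm into a PermVector there, and the other ranks hand an empty vector to SecureInvPerm. A standalone sketch of such a decode step with a bijection check (the PermVector alias here is an assumption of this sketch):

#include <cstdint>
#include <stdexcept>
#include <vector>

using PermVector = std::vector<int64_t>;  // assumption for this sketch

// Decode a permutation stored as ring elements, verifying it is a
// bijection over [0, n); this is what a ring2pv-style helper relies on.
PermVector decode_perm(const std::vector<uint64_t>& ring_values) {
  PermVector pv(ring_values.begin(), ring_values.end());
  std::vector<bool> seen(pv.size(), false);
  for (int64_t v : pv) {
    if (v < 0 || v >= static_cast<int64_t>(pv.size()) || seen[v]) {
      throw std::invalid_argument("not a permutation");
    }
    seen[v] = true;
  }
  return pv;
}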
   regPV2kKernels(ctx->prot());
 
+  // Register standard shape ops
+  regStandardShapeOps(ctx);
+
   // register arithmetic & binary kernels
   ctx->prot()->addState<Semi2kState>(ctx->config(), lctx);
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
+  ctx->prot()
+      ->regKernel<
+          semi2k::P2A, semi2k::A2P, semi2k::A2V, semi2k::V2A,          //
+          semi2k::NotA,                                                //
+          semi2k::AddAP, semi2k::AddAA,                                //
+          semi2k::MulAP, semi2k::MulAA,                                //
+          semi2k::MatMulAP, semi2k::MatMulAA,                          //
+          semi2k::LShiftA, semi2k::LShiftB, semi2k::RShiftB,
+          semi2k::ARShiftB,                                            //
+          semi2k::CommonTypeB, semi2k::CommonTypeV, semi2k::CastTypeB, //
+          semi2k::B2P, semi2k::P2B, semi2k::A2B, semi2k::B2A_Randbit,  //
+          semi2k::AndBP, semi2k::AndBB, semi2k::XorBP, semi2k::XorBB,
+          semi2k::BitrevB,                                             //
+          semi2k::BitIntlB, semi2k::BitDeintlB,                        //
+          semi2k::RandA, semi2k::RandPermM, semi2k::PermAM, semi2k::PermAP,
+          semi2k::InvPermAM, semi2k::InvPermAP, semi2k::InvPermAV,     //
+          semi2k::EqualAA, semi2k::EqualAP>();
+
   if (ctx->config().trunc_allow_msb_error()) {
     ctx->prot()->regKernel<semi2k::TruncA>();
   } else {
     ctx->prot()->regKernel<semi2k::TruncAPr>();
   }
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-
   if (lctx->WorldSize() == 2) {
     ctx->prot()->regKernel<semi2k::MulA1B>();
   }
   // ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
 }
 
 std::unique_ptr<SPUContext> makeSemi2kProtocol(
diff --git a/libspu/mpc/semi2k/type.h b/libspu/mpc/semi2k/type.h
index 287f2f64..d2f8763d 100644
--- a/libspu/mpc/semi2k/type.h
+++ b/libspu/mpc/semi2k/type.h
@@ -56,8 +56,8 @@ class BShrTy : public TypeImpl {
   }
 };
 
-class PShrTy : public TypeImpl {
-  using Base = TypeImpl;
+class PShrTy : public TypeImpl {
+  using Base = TypeImpl;
 
  public:
   using Base::Base;
diff --git a/libspu/mpc/spdz2k/BUILD.bazel b/libspu/mpc/spdz2k/BUILD.bazel
index 5e9e02b4..b7719ab2 100644
--- a/libspu/mpc/spdz2k/BUILD.bazel
+++ b/libspu/mpc/spdz2k/BUILD.bazel
@@ -36,6 +36,7 @@ spu_cc_library(
         ":value",
         "//libspu/core:context",
         "//libspu/mpc/common:prg_state",
+        "//libspu/mpc/standard_shape:protocol",
     ],
 )
diff --git a/libspu/mpc/spdz2k/protocol.cc b/libspu/mpc/spdz2k/protocol.cc
index 80c8006c..46db690a 100644
--- a/libspu/mpc/spdz2k/protocol.cc
+++ b/libspu/mpc/spdz2k/protocol.cc
@@ -23,6 +23,7 @@
 #include "libspu/mpc/spdz2k/conversion.h"
 #include "libspu/mpc/spdz2k/state.h"
 #include "libspu/mpc/spdz2k/type.h"
+#include "libspu/mpc/standard_shape/protocol.h"
 
 namespace spu::mpc {
 
@@ -42,50 +43,29 @@ void regSpdz2kProtocol(SPUContext* ctx,
   // register public kernels.
   regPV2kKernels(ctx->prot());
 
+  // Register standard shape ops
+  regStandardShapeOps(ctx);
+
   // register arithmetic kernels
   ctx->prot()->addState<Spdz2kState>(ctx->config(), lctx);
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
+  ctx->prot()
+      ->regKernel();
 
   // register boolean kernels
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
+  ctx->prot()
+      ->regKernel();
 
   // register conversion kernels
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
-  ctx->prot()->regKernel();
+  ctx->prot()
+      ->regKernel();
 }
 
 std::unique_ptr<SPUContext> makeSpdz2kProtocol(
diff --git a/libspu/mpc/standard_shape/BUILD.bazel b/libspu/mpc/standard_shape/BUILD.bazel
new file mode 100644
index 00000000..186f902f
--- /dev/null
+++ b/libspu/mpc/standard_shape/BUILD.bazel
@@ -0,0 +1,43 @@
+# Copyright 2023 Ant Group Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("//bazel:spu.bzl", "spu_cc_library", "spu_cc_test")
+
+package(default_visibility = ["//visibility:public"])
+
+spu_cc_library(
+    name = "standard_shape",
+    deps = [
+        ":protocol",
+    ],
+)
+
+spu_cc_library(
+    name = "protocol",
+    srcs = ["protocol.cc"],
+    hdrs = ["protocol.h"],
+    deps = [
+        ":kernels",
+        "//libspu/core:context",
+    ],
+)
+
+spu_cc_library(
+    name = "kernels",
+    srcs = ["kernels.cc"],
+    hdrs = ["kernels.h"],
+    deps = [
+        "//libspu/mpc:kernel",
+    ],
+)
+ +#include "libspu/mpc/standard_shape/kernels.h" + +#include + +#include "libspu/core/ndarray_ref.h" + +namespace spu::mpc::standard_shape { + +// Compact threshold heuristic, try to make it same as L1 cache size +#define COMPACT_THRESHOLD (32 * 1024) // 32K + +SPU_ALWAYS_INLINE NdArrayRef _try_compact(const NdArrayRef& in) { + // If in data is not compact after some shape ops and small enough, make it + // compact + if (in.numel() * in.elsize() <= COMPACT_THRESHOLD && !in.isCompact()) { + return in.clone(); + } + return in; +} + +NdArrayRef Broadcast::proc(KernelEvalContext* ctx, const NdArrayRef& in, + const Shape& to_shape, const Axes& in_dims) const { + return in.broadcast_to(to_shape, in_dims); +} + +NdArrayRef Reshape::proc(KernelEvalContext* ctx, const NdArrayRef& in, + const Shape& to_shape) const { + return _try_compact(in.reshape(to_shape)); +} + +NdArrayRef ExtractSlice::proc(KernelEvalContext* ctx, const NdArrayRef& in, + const Index& start, const Index& end, + const Strides& strides) const { + return _try_compact(in.slice(start, end, strides)); +} + +NdArrayRef UpdateSlice::proc(KernelEvalContext* ctx, const NdArrayRef& in, + const NdArrayRef& update, + const Index& start) const { + SPU_ENFORCE(in.eltype() == update.eltype(), + "Element type mismatch, in = {}, update ={}", in.eltype(), + update.eltype()); + + auto ret = in.clone(); + ret.update_slice(update, start); + return ret; +} + +NdArrayRef Transpose::proc(KernelEvalContext* ctx, const NdArrayRef& in, + const Axes& permutation) const { + Axes perm = permutation; + if (perm.empty()) { + // by default, transpose the data in reverse order. + perm.resize(in.shape().size()); + std::iota(perm.rbegin(), perm.rend(), 0); + } + + // sanity check. + SPU_ENFORCE_EQ(perm.size(), in.shape().size()); + std::set uniq(perm.begin(), perm.end()); + SPU_ENFORCE_EQ(uniq.size(), perm.size(), "perm={} is not unique", perm); + + // fast path, if identity permutation, return it. + Axes no_perm(in.shape().size()); + std::iota(no_perm.begin(), no_perm.end(), 0); + if (perm == no_perm) { + return in; + } + + return _try_compact(in.transpose(perm)); +} + +NdArrayRef Reverse::proc(KernelEvalContext* ctx, const NdArrayRef& in, + const Axes& dimensions) const { + return in.reverse(dimensions); +} + +NdArrayRef Fill::proc(KernelEvalContext* ctx, const NdArrayRef& in, + const Shape& to_shape) const { + return in.expand(to_shape); +} + +NdArrayRef Pad::proc(KernelEvalContext* ctx, const NdArrayRef& in, + const NdArrayRef& padding_value, + const Sizes& edge_padding_low, + const Sizes& edge_padding_high, + const Sizes& interior_padding) const { + SPU_ENFORCE(in.eltype() == padding_value.eltype(), + "Element type mismatch, in = {}, pad_value ={}", in.eltype(), + padding_value.eltype()); + return in.pad(padding_value, edge_padding_low, edge_padding_high, + interior_padding); +} + +NdArrayRef Concate::proc(KernelEvalContext* ctx, + const std::vector& values, + int64_t axis) const { + return values.front().concatenate( + absl::MakeSpan(&values[1], values.size() - 1), axis); +} + +} // namespace spu::mpc::standard_shape diff --git a/libspu/mpc/standard_shape/kernels.h b/libspu/mpc/standard_shape/kernels.h new file mode 100644 index 00000000..c4826504 --- /dev/null +++ b/libspu/mpc/standard_shape/kernels.h @@ -0,0 +1,131 @@ +// Copyright 2021 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
diff --git a/libspu/mpc/standard_shape/kernels.h b/libspu/mpc/standard_shape/kernels.h
new file mode 100644
index 00000000..c4826504
--- /dev/null
+++ b/libspu/mpc/standard_shape/kernels.h
@@ -0,0 +1,131 @@
+// Copyright 2021 Ant Group Co., Ltd.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "libspu/mpc/kernel.h"
+
+namespace spu::mpc::standard_shape {
+
+class Broadcast : public BroadcastKernel {
+ public:
+  static constexpr char kBindName[] = "broadcast";
+
+  ce::CExpr latency() const override { return ce::Const(0); }
+
+  ce::CExpr comm() const override { return ce::Const(0); }
+
+  NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in,
+                  const Shape& to_shape, const Axes& in_dims) const override;
+};
+
+class Reshape : public ShapeBasedKernel {
+ public:
+  static constexpr char kBindName[] = "reshape";
+
+  ce::CExpr latency() const override { return ce::Const(0); }
+
+  ce::CExpr comm() const override { return ce::Const(0); }
+
+  NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in,
+                  const Shape& to_shape) const override;
+};
+
+class ExtractSlice : public ExtractSliceKernel {
+ public:
+  static constexpr char kBindName[] = "extract_slice";
+
+  ce::CExpr latency() const override { return ce::Const(0); }
+
+  ce::CExpr comm() const override { return ce::Const(0); }
+
+  NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in,
+                  const Index& start, const Index& end,
+                  const Strides& strides) const override;
+};
+
+class UpdateSlice : public UpdateSliceKernel {
+ public:
+  static constexpr char kBindName[] = "update_slice";
+
+  ce::CExpr latency() const override { return ce::Const(0); }
+
+  ce::CExpr comm() const override { return ce::Const(0); }
+
+  NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in,
+                  const NdArrayRef& update, const Index& start) const override;
+};
+
+class Transpose : public DimsBasedKernel {
+ public:
+  static constexpr char kBindName[] = "transpose";
+
+  ce::CExpr latency() const override { return ce::Const(0); }
+
+  ce::CExpr comm() const override { return ce::Const(0); }
+
+  NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in,
+                  const Axes& permutation) const override;
+};
+
+class Reverse : public DimsBasedKernel {
+ public:
+  static constexpr char kBindName[] = "reverse";
+
+  ce::CExpr latency() const override { return ce::Const(0); }
+
+  ce::CExpr comm() const override { return ce::Const(0); }
+
+  NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in,
+                  const Axes& dimensions) const override;
+};
+
+class Fill : public ShapeBasedKernel {
+ public:
+  static constexpr char kBindName[] = "fill";
+
+  ce::CExpr latency() const override { return ce::Const(0); }
+
+  ce::CExpr comm() const override { return ce::Const(0); }
+
+  NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in,
+                  const Shape& to_shape) const override;
+};
+
+class Pad : public PadKernel {
+ public:
+  static constexpr char kBindName[] = "pad";
+
+  ce::CExpr latency() const override { return ce::Const(0); }
+
+  ce::CExpr comm() const override { return ce::Const(0); }
+
+  NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in,
+                  const NdArrayRef& padding_value,
+                  const Sizes& edge_padding_low, const Sizes& edge_padding_high,
+                  const Sizes& interior_padding) const override;
+};
+
+class Concate : public ConcateKernel {
+ public:
+  static constexpr char kBindName[] = "concatenate";
+
+  ce::CExpr latency() const override { return ce::Const(0); }
+
+  ce::CExpr comm() const override { return ce::Const(0); }
+
+  NdArrayRef proc(KernelEvalContext* ctx,
+                  const std::vector<NdArrayRef>& values,
+                  int64_t axis) const override;
+};
+
+}  // namespace spu::mpc::standard_shape
diff --git a/libspu/mpc/standard_shape/protocol.cc b/libspu/mpc/standard_shape/protocol.cc
new file mode 100644
index 00000000..ed2e2028
--- /dev/null
+++ b/libspu/mpc/standard_shape/protocol.cc
@@ -0,0 +1,34 @@
+// Copyright 2023 Ant Group Co., Ltd.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "libspu/mpc/standard_shape/protocol.h"
+
+#include "libspu/core/context.h"
+#include "libspu/mpc/standard_shape/kernels.h"
+
+namespace spu::mpc {
+
+void regStandardShapeOps(SPUContext* ctx) {
+  ctx->prot()->regKernel<standard_shape::Broadcast>();
+  ctx->prot()->regKernel<standard_shape::Reshape>();
+  ctx->prot()->regKernel<standard_shape::ExtractSlice>();
+  ctx->prot()->regKernel<standard_shape::UpdateSlice>();
+  ctx->prot()->regKernel<standard_shape::Transpose>();
+  ctx->prot()->regKernel<standard_shape::Reverse>();
+  ctx->prot()->regKernel<standard_shape::Fill>();
+  ctx->prot()->regKernel<standard_shape::Pad>();
+  ctx->prot()->regKernel<standard_shape::Concate>();
+}
+
+}  // namespace spu::mpc
diff --git a/libspu/compiler/passes/utils.h b/libspu/mpc/standard_shape/protocol.h
similarity index 75%
rename from libspu/compiler/passes/utils.h
rename to libspu/mpc/standard_shape/protocol.h
index 188b9496..b6c240b3 100644
--- a/libspu/compiler/passes/utils.h
+++ b/libspu/mpc/standard_shape/protocol.h
@@ -14,12 +14,10 @@
 
 #pragma once
 
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinAttributes.h"
+#include "libspu/core/context.h"
 
-namespace mlir::pphlo {
+namespace spu::mpc {
 
-mlir::DenseIntElementsAttr
-ConvertDimensions(OpBuilder *builder, llvm::ArrayRef<int64_t> op_dimensions);
+void regStandardShapeOps(SPUContext* ctx);
 
-}
+}  // namespace spu::mpc
diff --git a/libspu/mpc/utils/ring_ops.cc b/libspu/mpc/utils/ring_ops.cc
index cc87f7c5..fccd0292 100644
--- a/libspu/mpc/utils/ring_ops.cc
+++ b/libspu/mpc/utils/ring_ops.cc
@@ -45,7 +45,7 @@ constexpr char kModule[] = "RingOps";
 #define DEF_UNARY_RING_OP(NAME, OP)                            \
   void NAME##_impl(NdArrayRef& ret, const NdArrayRef& x) {     \
     ENFORCE_EQ_ELSIZE_AND_SHAPE(ret, x);                       \
-    const auto field = x.eltype().as<Ring2k>()->field();       \
+    const auto field = x.eltype().as<Ring2k>() -> field();     \
     const int64_t numel = ret.numel();                         \
     return DISPATCH_ALL_FIELDS(field, kModule, [&]() {         \
       using T = std::make_signed_t<ring2k_t>;                  \
@@ -65,7 +65,7 @@ DEF_UNARY_RING_OP(ring_neg, -);
                    const NdArrayRef& y) {                      \
     ENFORCE_EQ_ELSIZE_AND_SHAPE(ret, x);                       \
     ENFORCE_EQ_ELSIZE_AND_SHAPE(ret, y);                       \
-    const auto field = x.eltype().as<Ring2k>()->field();       \
+    const auto field = x.eltype().as<Ring2k>() -> field();     \
     const int64_t numel = ret.numel();                         \
     return DISPATCH_ALL_FIELDS(field, kModule, [&]() {         \
       NdArrayView<ring2k_t> _x(x);                             \
diff --git a/spu/version.py b/spu/version.py
index 50b8b8fa..d0584b79 100644
--- a/spu/version.py
+++ b/spu/version.py
@@ -13,4 +13,4 @@
 #  limitations under the License.
 
 
-__version__ = "0.7.0.dev$$DATE$$"
+__version__ = "0.8.0.dev$$DATE$$"
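With the shape kernels factored out of each protocol, every regXxxProtocol body now starts the same way. A sketch of how a new backend would wire this in, mirroring the registration functions above (the protocol name is invented for illustration):

#include "libspu/mpc/common/pv2k.h"
#include "libspu/mpc/standard_shape/protocol.h"

// Sketch of a new protocol factory following the pattern in this patch.
void regMyProtoProtocol(spu::SPUContext* ctx) {
  spu::mpc::regPV2kTypes();

  // public/private kernels shared by all 2^k protocols
  spu::mpc::regPV2kKernels(ctx->prot());

  // communication-free shape kernels shared by all protocols
  spu::mpc::regStandardShapeOps(ctx);

  // ... protocol-specific arithmetic/boolean kernels are registered here ...
}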