From a0feb321094d7db785d5d174f2a012d830335320 Mon Sep 17 00:00:00 2001 From: anakinxc <103552181+anakinxc@users.noreply.github.com> Date: Fri, 18 Aug 2023 12:44:35 +0800 Subject: [PATCH] Repo Sync (#312) --- .bazelrc | 8 +- .clang-format | 20 +- .clang-tidy | 3 +- .vscode/cspell.json | 1 + bazel/repositories.bzl | 2 +- docs/development/ir_dump.rst | 18 +- libspu/compiler/core/core.cc | 2 + .../compiler/passes/hlo_legalize_to_pphlo.cc | 291 +++++++++++--- libspu/compiler/passes/optimize_maxpool.cc | 17 +- .../tests/hlo_to_pphlo_reduce_window.mlir | 18 + .../hlo_to_pphlo_select_and_scatter.mlir | 2 +- libspu/compiler/tests/optimize_maxpool.mlir | 8 +- libspu/core/ndarray_ref.cc | 1 + libspu/device/BUILD.bazel | 11 + libspu/device/api.cc | 58 ++- libspu/device/debug_dump_constant.cc | 51 +++ libspu/device/debug_dump_constant.h | 40 ++ libspu/device/pphlo/BUILD.bazel | 13 + libspu/device/pphlo/pphlo_executor.cc | 64 +-- .../pphlo/pphlo_executor_debug_runner.cc | 195 ++++++++++ libspu/device/pphlo/pphlo_executor_test.cc | 99 +++-- libspu/device/symbol_table.h | 5 - libspu/dialect/pphlo_ops.cc | 22 +- libspu/dialect/pphlo_ops.td | 18 +- libspu/kernel/hal/shape_ops.h | 2 +- libspu/kernel/hlo/BUILD.bazel | 9 + libspu/kernel/hlo/basic_binary.cc | 2 +- libspu/kernel/hlo/reduce.cc | 8 +- libspu/kernel/hlo/select_and_scatter.cc | 368 ++++++++---------- libspu/kernel/hlo/select_and_scatter_test.cc | 113 ++++++ libspu/kernel/hlo/utils.cc | 49 ++- libspu/kernel/hlo/utils.h | 3 +- libspu/spu.proto | 4 +- 33 files changed, 1036 insertions(+), 489 deletions(-) create mode 100644 libspu/compiler/tests/hlo_to_pphlo_reduce_window.mlir create mode 100644 libspu/device/debug_dump_constant.cc create mode 100644 libspu/device/debug_dump_constant.h create mode 100644 libspu/device/pphlo/pphlo_executor_debug_runner.cc create mode 100644 libspu/kernel/hlo/select_and_scatter_test.cc diff --git a/.bazelrc b/.bazelrc index c509b77e..dc995af1 100644 --- a/.bazelrc +++ b/.bazelrc @@ -1,11 +1,11 @@ # Copyright 2023 Ant Group Co., Ltd. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -17,7 +17,7 @@ common --experimental_repo_remote_exec # Required by OpenXLA build --nocheck_visibility -build --incompatible_new_actions_api=false +build --incompatible_new_actions_api=false build --copt=-fdiagnostics-color=always build --enable_platform_specific_config diff --git a/.clang-format b/.clang-format index dd90f992..369aa979 100644 --- a/.clang-format +++ b/.clang-format @@ -3,13 +3,13 @@ BasedOnStyle: Google IncludeBlocks: Regroup IncludeCategories: - - Regex: '^<.*\.h>' - Priority: 1 - - Regex: '^<.*' - Priority: 2 - - Regex: '.*\.pb\.h"$' - Priority: 5 - - Regex: '^"libspu.*' - Priority: 4 - - Regex: '^".*' - Priority: 3 + - Regex: '^<.*\.h>' + Priority: 1 + - Regex: "^<.*" + Priority: 2 + - Regex: '.*\.pb\.h"$' + Priority: 5 + - Regex: '^"libspu.*' + Priority: 4 + - Regex: '^".*' + Priority: 3 diff --git a/.clang-tidy b/.clang-tidy index e4220347..4ad05fca 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -74,6 +74,5 @@ CheckOptions: - key: readability-identifier-naming.FunctionCase value: "CamelBack" - - key: performance-unnecessary-value-param.AllowedTypes + - key: performance-unnecessary-value-param.AllowedTypes value: PtBufferView - diff --git a/.vscode/cspell.json b/.vscode/cspell.json index 074e2f89..4b1e1b38 100644 --- a/.vscode/cspell.json +++ b/.vscode/cspell.json @@ -100,6 +100,7 @@ "ponit", "pphlo", "precheck", + "proto", "PRNG", "protobuf", "Prss", diff --git a/bazel/repositories.bzl b/bazel/repositories.bzl index 717eddfb..72a87b53 100644 --- a/bazel/repositories.bzl +++ b/bazel/repositories.bzl @@ -18,7 +18,7 @@ load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository") SECRETFLOW_GIT = "https://github.com/secretflow" -YACL_COMMIT_ID = "ff20dff1476071ca885c69bee94d2b3bdf85034c" +YACL_COMMIT_ID = "a9c1d7d119c80eb75d5ec63ee6cd77145dff18c2" def spu_deps(): _bazel_platform() diff --git a/docs/development/ir_dump.rst b/docs/development/ir_dump.rst index 194cd2ea..43bd7a42 100644 --- a/docs/development/ir_dump.rst +++ b/docs/development/ir_dump.rst @@ -6,7 +6,7 @@ Dump IR to DAG Introduction ------------ -This document provides the demo for how to dump the IR (Intermediate Representation, generated by `XLA `_) to a DAG. +This document provides the demo for how to dump the IR (Intermediate Representation, generated by `XLA `_) to a DAG. With the aid of visualized DAG, the execution logic and required operators will be more explicit. @@ -33,7 +33,7 @@ Please first have a look at the :spu_code_host:`spu.proto `_ to convert them to PDF or PNG to visualize the DAG. - + For **DOT** files, you should use `GraphViz `_ to convert them to PDF or PNG to visualize the DAG. + While for **HTML** files, you can directly open the them in your Web Browser, which shall render the DAG. .. code-block:: protobuf @@ -105,7 +105,7 @@ First of all, we declare an CompilerOptions object. Note that the **pretty_print Then we pass the CompilerOptions to the executed SPU code. -The code shall be modified from +The code shall be modified from .. code-block:: python :caption: SPU execution without customized compiler options @@ -137,7 +137,7 @@ We here provide the code snippet for dumping IR to HTML files. The DAG for the e Here, we define a `max` function use jax.numpy. """ return jnp.maximum(x, y) - + def get_data(seed=123): """ Any IO function that loads the data. @@ -145,7 +145,7 @@ We here provide the code snippet for dumping IR to HTML files. 
The DAG for the e np.random.seed(seed) data = np.random.randn(3, 4) return data - + x = get_data(1) y = get_data(2) @@ -162,7 +162,7 @@ We here provide the code snippet for dumping IR to HTML files. The DAG for the e res_spu = ppd.device("SPU")(func, copts=copts)(x_spu, y_spu) .. Note:: - You may find multiple files in the output directory since XLA has mutliple compile passes and generates multiple IRs, with each corresponding to one DAG. + You may find multiple files in the output directory since XLA has multiple compile passes and generates multiple IRs, with each corresponding to one DAG. The **HTML** output is rendered as follows. diff --git a/libspu/compiler/core/core.cc b/libspu/compiler/core/core.cc index 093d78be..4e531a08 100644 --- a/libspu/compiler/core/core.cc +++ b/libspu/compiler/core/core.cc @@ -47,6 +47,8 @@ void Core::buildPipeline(mlir::PassManager *pm) { // lowering auto &optPM = pm->nest(); if (!options.disable_maxpooling_optimization()) { + // Need a cse before maxpooling + optPM.addPass(mlir::createCSEPass()); optPM.addPass(mlir::pphlo::createOptimizeMaxPoolingPass()); } optPM.addPass(mlir::pphlo::createDecomposeComparisonPass()); diff --git a/libspu/compiler/passes/hlo_legalize_to_pphlo.cc b/libspu/compiler/passes/hlo_legalize_to_pphlo.cc index 9699be8f..a0e2b054 100644 --- a/libspu/compiler/passes/hlo_legalize_to_pphlo.cc +++ b/libspu/compiler/passes/hlo_legalize_to_pphlo.cc @@ -276,20 +276,18 @@ class HloCompToPPHloOpConverter } }; -template -struct ReduceOpConverter : public OpConversionPattern { +struct ReduceOpConverter : public OpConversionPattern { private: const ValueVisibilityMap &vis_; public: ReduceOpConverter(TypeConverter &type_converter, MLIRContext *context, const ValueVisibilityMap &vis) - : OpConversionPattern(type_converter, context), vis_(vis) { - } + : OpConversionPattern(type_converter, context), + vis_(vis) {} LogicalResult - matchAndRewrite(HloReduceOpTy op, - typename ReduceOpConverter::OpAdaptor adaptor, + matchAndRewrite(stablehlo::ReduceOp op, stablehlo::ReduceOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // We may need to materialize operands @@ -341,7 +339,7 @@ struct ReduceOpConverter : public OpConversionPattern { } auto new_op = - rewriter.replaceOpWithNewOp>( + rewriter.replaceOpWithNewOp>( op, result_types, materialized_operands, op->getAttrs()); // Copy over the operations inside the region. 
@@ -357,6 +355,148 @@ struct ReduceOpConverter : public OpConversionPattern { } }; +struct ReduceWindowOpConverter + : public OpConversionPattern { +private: + const ValueVisibilityMap &vis_; + +public: + ReduceWindowOpConverter(TypeConverter &type_converter, MLIRContext *context, + const ValueVisibilityMap &vis) + : OpConversionPattern(type_converter, context), + vis_(vis) {} + + LogicalResult + matchAndRewrite(stablehlo::ReduceWindowOp op, + stablehlo::ReduceWindowOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + // We may need to materialize operands + llvm::SmallVector materialized_operands; + llvm::SmallVector result_types; + size_t num_results = op.getNumResults(); + + materialized_operands.resize(2 * num_results); + result_types.resize(num_results); + + OpBuilder builder(op); + + auto materialize = [&, this](size_t idx) { + auto current_vis = getOperandVisibility(adaptor.getOperands()[idx]); + auto expected_vis = + vis_.getValueVisibility(op.getBody().getArguments()[idx]); + + if (expected_vis == current_vis) { + materialized_operands[idx] = adaptor.getOperands()[idx]; + } else { + auto new_type = HloToPPHloTypeConverter::getTypeWithVisibility( + adaptor.getOperands()[idx].getType(), expected_vis); + materialized_operands[idx] = + this->getTypeConverter()->materializeTargetConversion( + builder, op.getLoc(), new_type, adaptor.getOperands()[idx]); + } + }; + + for (size_t idx = 0; idx < num_results; ++idx) { + auto result_vis = vis_.getValueVisibility(op.getResult(idx)); + // Check input vis + materialize(idx); + materialize(idx + num_results); + // Push result type + result_types[idx] = HloToPPHloTypeConverter::getTypeWithVisibility( + this->getTypeConverter()->convertType(op.getType(idx)), result_vis); + } + + // Convert the region signature. 
+ auto &entry_block = op.getBody().front(); + TypeConverter::SignatureConversion sig_conversion( + entry_block.getNumArguments()); + + for (const auto &arg : entry_block.getArguments()) { + auto arg_t = this->getTypeConverter()->convertType(arg.getType()); + auto lower_t = HloToPPHloTypeConverter::getTypeWithVisibility( + arg_t, vis_.getValueVisibility(arg)); + sig_conversion.addInputs(arg.getArgNumber(), lower_t); + } + + if (op.getBaseDilations().has_value() || op.getPadding().has_value()) { + auto rank = + op->getOperandTypes()[0].dyn_cast().getRank(); + llvm::SmallVector interior_padding(rank, 0); + llvm::SmallVector padding_low(rank, 0); + llvm::SmallVector padding_high(rank, 0); + + bool has_dilation = + op.getBaseDilations().has_value() && + (!op.getBaseDilationsAttr().isSplat() || + op.getBaseDilationsAttr().getSplatValue() != 1); + + if (has_dilation) { + for (int64_t rank_idx = 0; rank_idx < rank; ++rank_idx) { + interior_padding[rank_idx] = + op.getBaseDilationsAttr().getValues()[rank_idx] - 1; + } + } + + bool has_padding = op.getPadding().has_value() && + (!op.getPaddingAttr().isSplat() || + op.getPaddingAttr().getSplatValue() != 0); + + if (has_padding) { + for (int64_t rank_idx = 0; rank_idx < rank; ++rank_idx) { + padding_low[rank_idx] = + op.getPaddingAttr().getValues()[2 * rank_idx]; + padding_high[rank_idx] = + op.getPaddingAttr().getValues()[2 * rank_idx + 1]; + } + } + + if (has_dilation || has_padding) { + for (size_t idx = 0; idx < num_results; ++idx) { + materialized_operands[idx] = rewriter.create( + op->getLoc(), materialized_operands[idx], + materialized_operands[idx + num_results], + builder.getI64TensorAttr(padding_low), + builder.getI64TensorAttr(padding_high), + builder.getI64TensorAttr(interior_padding)); + } + } + } + + llvm::SmallVector attrs; + { + // I64ElementsAttr:$window_dimensions, + attrs.push_back({builder.getStringAttr("window_dimensions"), + op.getWindowDimensionsAttr()}); + // OptionalAttr:$window_strides, + if (op.getWindowStrides().has_value()) { + attrs.push_back({builder.getStringAttr("window_strides"), + op.getWindowStridesAttr()}); + } + // OptionalAttr:$window_dilations, + if (op.getWindowDilations().has_value()) { + attrs.push_back({builder.getStringAttr("window_dilations"), + op.getWindowDilationsAttr()}); + } + } + + auto new_op = + rewriter + .replaceOpWithNewOp>( + op, result_types, materialized_operands, attrs); + + // Copy over the operations inside the region. 
+ rewriter.inlineRegionBefore(op.getBody(), new_op.getBody(), + new_op.getBody().end()); + + if (failed(rewriter.convertRegionTypes( + &new_op.getBody(), *this->getTypeConverter(), &sig_conversion))) { + return failure(); + } + + return success(); + } +}; struct IfOpConverter : public OpConversionPattern { private: const ValueVisibilityMap &vis_; @@ -832,10 +972,32 @@ struct HloToPPHloOpConverter auto result_type = HloToPPHloTypeConverter::getTypeWithVisibility( op.getType(), vis_.getValueVisibility(op.getResult())); + if (op.getPadding().has_value() && + (!op.getPaddingAttr().isSplat() || + op.getPaddingAttr().getSplatValue() != 0)) { + auto rank = + op->getOperandTypes()[0].dyn_cast().getRank(); + llvm::SmallVector padding_low(rank, 0); + llvm::SmallVector padding_high(rank, 0); + llvm::SmallVector padding_interior(rank, 0); + for (int64_t rank_idx = 0; rank_idx < rank; ++rank_idx) { + padding_low[rank_idx] = + op.getPaddingAttr().getValues()[2 * rank_idx]; + padding_high[rank_idx] = + op.getPaddingAttr().getValues()[2 * rank_idx + 1]; + } + + materialized_operand = rewriter.create( + op->getLoc(), materialized_operand, materialized_init_value, + builder.getI64TensorAttr(padding_low), + builder.getI64TensorAttr(padding_high), + builder.getI64TensorAttr(padding_interior)); + } + auto new_op = rewriter.replaceOpWithNewOp( op, result_type, materialized_operand, adaptor.getSource(), materialized_init_value, op.getWindowDimensionsAttr(), - op.getWindowStridesAttr(), op.getPaddingAttr()); + op.getWindowStridesAttr()); // Convert the region signature. TypeConverter::SignatureConversion select_sig_conversion( @@ -1365,62 +1527,63 @@ struct HloLegalizeToPPHlo const ValueVisibilityMap &vis_map) { auto *context = patterns.getContext(); - patterns.insert< - FuncOpConverter, ReturnOpConverter, HloCompToPPHloOpConverter, - CustomCallConverter, ReduceOpConverter, - ReduceOpConverter, WhileOpConverter, - IfOpConverter, CaseOpConverter, HloToPPHloOpConverter, - HloToPPHloOpConverter, - HloToPPHloOpConverter, - HloToPPHloOpConverter, - HloToPPHloOpConverter, - HloToPPHloOpConverter, - HloToPPHloOpConverter, - HloToPPHloOpConverter, - HloToPPHloOpConverter, - HloToPPHloOpConverter, - HloToPPHloOpConverter, - HloToPPHloOpConverter, - HloToPPHloOpConverter, - HloToPPHloOpConverter, - HloToPPHloOpConverter, - HloToPPHloOpConverter, - HloToPPHloOpConverter, - HloToPPHloOpConverter, - HloToPPHloOpConverter, - HloToPPHloOpConverter, - HloToPPHloOpConverter, - HloToPPHloOpConverter, - HloToPPHloOpConverter, - HloToPPHloOpConverter, - HloToPPHloOpConverter, - HloToPPHloOpConverter, - HloToPPHloOpConverter, - HloToPPHloOpConverter, - HloToPPHloOpConverter, - HloToPPHloOpConverter, - HloToPPHloOpConverter, - HloToPPHloOpConverter, - HloToPPHloOpConverter, - HloToPPHloOpConverter, - HloToPPHloOpConverter, - HloToPPHloOpConverter, - HloToPPHloOpConverter, - HloToPPHloOpConverter, - HloToPPHloOpConverter, - HloToPPHloOpConverter, - HloToPPHloOpConverter, - HloToPPHloOpConverter, - HloToPPHloOpConverter, - HloToPPHloOpConverter, - HloToPPHloOpConverter, - HloToPPHloOpConverter, - HloToPPHloOpConverter, - HloToPPHloOpConverter, - HloToPPHloOpConverter, - HloToPPHloOpConverter, - HloToPPHloOpConverter, - HloToPPHloOpConverter>(converter, context, vis_map); + patterns + .insert, + HloToPPHloOpConverter, + HloToPPHloOpConverter, + HloToPPHloOpConverter, + HloToPPHloOpConverter, + HloToPPHloOpConverter, + HloToPPHloOpConverter, + HloToPPHloOpConverter, + HloToPPHloOpConverter, + HloToPPHloOpConverter, + 
HloToPPHloOpConverter, + HloToPPHloOpConverter, + HloToPPHloOpConverter, + HloToPPHloOpConverter, + HloToPPHloOpConverter, + HloToPPHloOpConverter, + HloToPPHloOpConverter, + HloToPPHloOpConverter, + HloToPPHloOpConverter, + HloToPPHloOpConverter, + HloToPPHloOpConverter, + HloToPPHloOpConverter, + HloToPPHloOpConverter, + HloToPPHloOpConverter, + HloToPPHloOpConverter, + HloToPPHloOpConverter, + HloToPPHloOpConverter, + HloToPPHloOpConverter, + HloToPPHloOpConverter, + HloToPPHloOpConverter, + HloToPPHloOpConverter, + HloToPPHloOpConverter, + HloToPPHloOpConverter, + HloToPPHloOpConverter, + HloToPPHloOpConverter, + HloToPPHloOpConverter, + HloToPPHloOpConverter, + HloToPPHloOpConverter, + HloToPPHloOpConverter, + HloToPPHloOpConverter, + HloToPPHloOpConverter, + HloToPPHloOpConverter, + HloToPPHloOpConverter, + HloToPPHloOpConverter, + HloToPPHloOpConverter, + HloToPPHloOpConverter, + HloToPPHloOpConverter, + HloToPPHloOpConverter, + HloToPPHloOpConverter, + HloToPPHloOpConverter, + HloToPPHloOpConverter, + HloToPPHloOpConverter>(converter, context, + vis_map); } public: diff --git a/libspu/compiler/passes/optimize_maxpool.cc b/libspu/compiler/passes/optimize_maxpool.cc index f9e02354..975765e2 100644 --- a/libspu/compiler/passes/optimize_maxpool.cc +++ b/libspu/compiler/passes/optimize_maxpool.cc @@ -60,9 +60,7 @@ struct SelectAndScatterConverter : public OpRewritePattern { op->getLoc(), SmallVector{current_ret_type, index_result_type}, op.getInputs()[0], op.getWindowDimensions(), op.getWindowStrides().value_or(nullptr), - op.getBaseDilations().value_or(nullptr), - op.getWindowDilations().value_or(nullptr), - op.getPadding().value_or(nullptr)); + op.getWindowDilations().value_or(nullptr)); op->getResult(0).replaceAllUsesWith(argmax->getResult(0)); @@ -134,19 +132,11 @@ struct SelectAndScatterConverter : public OpRewritePattern { continue; } - if (op.getPadding() != previous_reduce_window.getPadding()) { - continue; - } - // Make sure no dilation auto window_dilation = previous_reduce_window.getWindowDilations(); - auto base_dilation = previous_reduce_window.getBaseDilations(); if (window_dilation.has_value() && !isAllOne(*window_dilation)) { continue; } - if (base_dilation.has_value() && !isAllOne(*base_dilation)) { - continue; - } selected_indices = rewriteReduceWindow(previous_reduce_window, rewriter); @@ -156,14 +146,13 @@ struct SelectAndScatterConverter : public OpRewritePattern { } } - if (rewritten == false) { + if (!rewritten) { return failure(); } rewriter.replaceOpWithNewOp( op, op->getResultTypes()[0], selected_indices, op.getSource(), - op.getWindowDimensions(), op.getWindowStrides().value_or(nullptr), - op.getPadding().value_or(nullptr)); + op.getWindowDimensions(), op.getWindowStrides().value_or(nullptr)); return status; } diff --git a/libspu/compiler/tests/hlo_to_pphlo_reduce_window.mlir b/libspu/compiler/tests/hlo_to_pphlo_reduce_window.mlir new file mode 100644 index 00000000..723fe689 --- /dev/null +++ b/libspu/compiler/tests/hlo_to_pphlo_reduce_window.mlir @@ -0,0 +1,18 @@ +// RUN: mlir-pphlo-opt --hlo-legalize-to-pphlo=input_vis_list=VIS_PUBLIC,VIS_PUBLIC --split-input-file %s | FileCheck %s + +func.func @main(%arg0: tensor<3x2xi64>, %arg1: tensor) -> tensor<2x2xi64> { + // CHECK: %0 = "pphlo.pad"(%arg0, %arg1) {edge_padding_high = dense<[1, 0]> : tensor<2xi64>, edge_padding_low = dense<[2, 0]> : tensor<2xi64>, interior_padding = dense<[1, 0]> : tensor<2xi64>} : (tensor<3x2x!pphlo.pub>, tensor>) -> tensor<8x2x!pphlo.pub> + // CHECK: %1 = "pphlo.reduce_window"(%0, 
%arg1) + %result = "stablehlo.reduce_window"(%arg0, %arg1) ({ + ^bb0(%arg2: tensor, %arg3: tensor): + %0 = "stablehlo.add"(%arg2, %arg3) : (tensor, tensor) -> tensor + "stablehlo.return"(%0) : (tensor) -> () + }) { + window_dimensions = dense<[2, 1]> : tensor<2xi64>, + window_strides = dense<[4, 1]> : tensor<2xi64>, + base_dilations = dense<[2, 1]> : tensor<2xi64>, + window_dilations = dense<[3, 1]> : tensor<2xi64>, + padding = dense<[[2, 1], [0, 0]]> : tensor<2x2xi64> + } : (tensor<3x2xi64>, tensor) -> tensor<2x2xi64> +return %result : tensor<2x2xi64> +} \ No newline at end of file diff --git a/libspu/compiler/tests/hlo_to_pphlo_select_and_scatter.mlir b/libspu/compiler/tests/hlo_to_pphlo_select_and_scatter.mlir index 2fbaea36..c4e03576 100644 --- a/libspu/compiler/tests/hlo_to_pphlo_select_and_scatter.mlir +++ b/libspu/compiler/tests/hlo_to_pphlo_select_and_scatter.mlir @@ -9,7 +9,7 @@ func.func @main(%arg0: tensor<128x5x5x32xf32>, %arg1: tensor<128x4x4x32xf32>, %a // CHECK: ^bb0(%arg3: tensor>, %arg4: tensor>): // CHECK: %2 = "pphlo.add"(%arg3, %arg4) : (tensor>, tensor>) -> tensor> // CHECK: "pphlo.return"(%2) : (tensor>) -> () - // CHECK: }) {padding = dense<0> : tensor<4x2xi64>, window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<1> : tensor<4xi64>} : (tensor<128x5x5x32x!pphlo.sec>, tensor<128x4x4x32x!pphlo.pub>, tensor>) -> tensor<128x5x5x32x!pphlo.sec> + // CHECK: }) {window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<1> : tensor<4xi64>} : (tensor<128x5x5x32x!pphlo.sec>, tensor<128x4x4x32x!pphlo.pub>, tensor>) -> tensor<128x5x5x32x!pphlo.sec> %0 = "stablehlo.select_and_scatter"(%arg0, %arg1, %arg2) ({ ^bb0(%arg3: tensor, %arg4: tensor): %1 = "stablehlo.compare"(%arg3, %arg4) {comparison_direction = #stablehlo} : (tensor, tensor) -> tensor diff --git a/libspu/compiler/tests/optimize_maxpool.mlir b/libspu/compiler/tests/optimize_maxpool.mlir index a8926d67..3f37bf80 100644 --- a/libspu/compiler/tests/optimize_maxpool.mlir +++ b/libspu/compiler/tests/optimize_maxpool.mlir @@ -5,12 +5,12 @@ func.func @main(%arg0: tensor<129x24x24x16x!pphlo.sec>, %arg1: tensor<129x2 %1 = "pphlo.constant"() {value = dense<0.000000e+00> : tensor} : () -> tensor> %2 = "pphlo.convert"(%0) : (tensor>) -> tensor> %3 = "pphlo.convert"(%1) : (tensor>) -> tensor> - //CHECK: "pphlo.argmax"(%arg0) {base_dilations = dense<1> : tensor<4xi64>, onehot_index = true, padding = dense<0> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<1> : tensor<4xi64>} : (tensor<129x24x24x16x!pphlo.sec>) -> (tensor<129x23x23x16x!pphlo.sec>, tensor<129x23x23x16x4x!pphlo.sec>) + //CHECK: "pphlo.argmax"(%arg0) %4 = "pphlo.reduce_window"(%arg0, %2) ({ ^bb0(%arg2: tensor>, %arg3: tensor>): %6 = "pphlo.maximum"(%arg2, %arg3) : (tensor>, tensor>) -> tensor> "pphlo.return"(%6) : (tensor>) -> () - }) {base_dilations = dense<1> : tensor<4xi64>, padding = dense<0> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<1> : tensor<4xi64>} : (tensor<129x24x24x16x!pphlo.sec>, tensor>) -> tensor<129x23x23x16x!pphlo.sec> + }) {base_dilations = dense<1> : tensor<4xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<1> : tensor<4xi64>} : (tensor<129x24x24x16x!pphlo.sec>, tensor>) -> tensor<129x23x23x16x!pphlo.sec> //CHECK-NOT: 
pphlo.select_and_scatter //CHECK : pphlo.maxpool_scatter %5 = "pphlo.select_and_scatter"(%arg0, %arg1, %3) ({ @@ -21,7 +21,7 @@ func.func @main(%arg0: tensor<129x24x24x16x!pphlo.sec>, %arg1: tensor<129x2 ^bb0(%arg2: tensor>, %arg3: tensor>): %6 = "pphlo.add"(%arg2, %arg3) : (tensor>, tensor>) -> tensor> "pphlo.return"(%6) : (tensor>) -> () - }) {padding = dense<0> : tensor<4x2xi64>, window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<1> : tensor<4xi64>} : (tensor<129x24x24x16x!pphlo.sec>, tensor<129x23x23x16x!pphlo.sec>, tensor>) -> tensor<129x24x24x16x!pphlo.sec> + }) {window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<1> : tensor<4xi64>} : (tensor<129x24x24x16x!pphlo.sec>, tensor<129x23x23x16x!pphlo.sec>, tensor>) -> tensor<129x24x24x16x!pphlo.sec> return %4, %5 : tensor<129x23x23x16x!pphlo.sec>, tensor<129x24x24x16x!pphlo.sec> } @@ -42,7 +42,7 @@ func.func @main(%arg0: tensor<128x2x2x256x!pphlo.sec>, %arg1: tensor<128x1x ^bb0(%arg2: tensor>, %arg3: tensor>): %5 = "pphlo.add"(%arg2, %arg3) : (tensor>, tensor>) -> tensor> "pphlo.return"(%5) : (tensor>) -> () - }) {padding = dense<0> : tensor<4x2xi64>, window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<1> : tensor<4xi64>} : (tensor<128x2x2x256x!pphlo.sec>, tensor<128x1x1x256x!pphlo.sec>, tensor>) -> tensor<128x2x2x256x!pphlo.sec> + }) {window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<1> : tensor<4xi64>} : (tensor<128x2x2x256x!pphlo.sec>, tensor<128x1x1x256x!pphlo.sec>, tensor>) -> tensor<128x2x2x256x!pphlo.sec> return %3, %4 : tensor<128x2x2x256x!pphlo.sec>, tensor<128x2x2x256x!pphlo.sec> } diff --git a/libspu/core/ndarray_ref.cc b/libspu/core/ndarray_ref.cc index 01c96d7b..9b0ab407 100644 --- a/libspu/core/ndarray_ref.cc +++ b/libspu/core/ndarray_ref.cc @@ -258,6 +258,7 @@ NdArrayRef NdArrayRef::broadcast_to(const Shape& to_shape, Strides new_strides(to_shape.size(), 0); + // TODO: check to_shape match broadcasting rules. if (!in_dims.empty()) { for (size_t idx = 0; idx < in_dims.size(); ++idx) { new_strides[in_dims[idx]] = strides()[idx]; diff --git a/libspu/device/BUILD.bazel b/libspu/device/BUILD.bazel index 5422f18b..e0625bc5 100644 --- a/libspu/device/BUILD.bazel +++ b/libspu/device/BUILD.bazel @@ -65,11 +65,22 @@ spu_cc_library( ], ) +spu_cc_library( + name = "debug_dump_constant", + srcs = [ + "debug_dump_constant.cc", + ], + hdrs = [ + "debug_dump_constant.h", + ], +) + spu_cc_library( name = "api", srcs = ["api.cc"], hdrs = ["api.h"], deps = [ + ":debug_dump_constant", ":executor", "//libspu/device/pphlo:pphlo_executor", "@llvm-project//mlir:FuncDialect", diff --git a/libspu/device/api.cc b/libspu/device/api.cc index aa2299eb..b2eed2ad 100644 --- a/libspu/device/api.cc +++ b/libspu/device/api.cc @@ -24,6 +24,7 @@ #include "mlir/Parser/Parser.h" #include "spdlog/spdlog.h" +#include "libspu/device/debug_dump_constant.h" #include "libspu/device/pphlo/pphlo_executor.h" #include "libspu/dialect/pphlo_dialect.h" @@ -106,28 +107,48 @@ struct ActionStats { } }; -/* - @shantang / @wuju - TODO: temporary remove, need to adapt value slice change void takeSnapshot(size_t rank, const RuntimeConfig &rt_config, const ExecutableProto &executable, const SymbolTable &env) { - - const std::string &dump_dir = rt_config.processor_dump_dir(); + const std::string &dump_dir = rt_config.snapshot_dump_dir(); // Naming convention for dumped files must align with debug runner. 
std::filesystem::path dump_folder(dump_dir); std::filesystem::create_directories(dump_folder); - auto dump_path = dump_folder / fmt::format("snapshot_{}.spu", rank); - SnapshotProto snapshot; - snapshot.set_rank(rank); - *snapshot.mutable_executable() = executable; - *snapshot.mutable_runtime_cfg() = rt_config; - *snapshot.mutable_environ() = env.toProto(); + // Dump executable + { + std::ofstream config_file(getConfigFilePath(dump_folder), + std::ios::binary | std::ios::out); + config_file << rt_config.SerializeAsString(); + } + + // Dump executable + { + std::ofstream main_file(getCodeFilePath(dump_folder), + std::ios::binary | std::ios::out); + main_file << executable.SerializeAsString(); + } + + auto value_dump_dir = getRankFolder(dump_folder, rank); + std::filesystem::create_directories(value_dump_dir); - std::ofstream dump_file(dump_path, std::ios::binary | std::ios::out); - dump_file << snapshot.SerializeAsString(); + // Dump inputs + for (const auto &[name, var] : env) { + auto serialized = var.toProto(std::numeric_limits::max()); + { + std::ofstream meta_file(getMetaFilePath(dump_folder, rank, name), + std::ios::binary | std::ios::out); + meta_file << serialized.meta.SerializeAsString(); + } + { + for (const auto &chunk : llvm::enumerate(serialized.chunks)) { + std::ofstream chunk_file( + getValueChunkFilePath(dump_folder, rank, name, chunk.index()), + std::ios::binary | std::ios::out); + chunk_file << chunk.value().SerializeAsString(); + } + } + } } -*/ void printProfilingData(spu::SPUContext *sctx, const std::string &name, const ExecutionStats &exec_stats, @@ -222,18 +243,13 @@ void executeImpl(OpExecutor *executor, spu::SPUContext *sctx, } } - // TODO: rename this flag, enable_execution_dump? const RuntimeConfig rt_config = sctx->config(); - /* - @shantang / @wuju - TODO: temporary remove, need to adapt value slice change - if (rt_config.enable_processor_dump()) { + + if (rt_config.enable_runtime_snapshot()) { const bool isRefHal = sctx->lctx() == nullptr; const size_t rank = isRefHal ? 0 : sctx->lctx()->Rank(); takeSnapshot(rank, rt_config, executable, *env); - } - */ // execution std::vector outputs; diff --git a/libspu/device/debug_dump_constant.cc b/libspu/device/debug_dump_constant.cc new file mode 100644 index 00000000..2aae4e5a --- /dev/null +++ b/libspu/device/debug_dump_constant.cc @@ -0,0 +1,51 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "libspu/device/debug_dump_constant.h" + +#include "fmt/format.h" + +namespace spu::device { + +std::string getMetaExtension() { return ".meta"; } + +std::filesystem::path getRankFolder(const std::filesystem::path& base, + int64_t rank) { + return base / fmt::format("rank_{}", rank); +} + +std::filesystem::path getConfigFilePath(const std::filesystem::path& base) { + return base / "config"; +} + +std::filesystem::path getCodeFilePath(const std::filesystem::path& base) { + return base / "code"; +} + +std::filesystem::path getMetaFilePath(const std::filesystem::path& base, + int64_t rank, + const std::string& var_name) { + return getRankFolder(base, rank) / + fmt::format("{}{}", var_name, getMetaExtension()); +} + +std::filesystem::path getValueChunkFilePath(const std::filesystem::path& base, + int64_t rank, + const std::string& var_name, + int64_t chunk_id) { + return getRankFolder(base, rank) / + fmt::format("{}_{}.chunk", var_name, chunk_id); +} + +} // namespace spu::device diff --git a/libspu/device/debug_dump_constant.h b/libspu/device/debug_dump_constant.h new file mode 100644 index 00000000..0b5ad89e --- /dev/null +++ b/libspu/device/debug_dump_constant.h @@ -0,0 +1,40 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include + +namespace spu::device { + +std::string getMetaExtension(); + +std::filesystem::path getRankFolder(const std::filesystem::path& base, + int64_t rank); + +std::filesystem::path getConfigFilePath(const std::filesystem::path& base); + +std::filesystem::path getCodeFilePath(const std::filesystem::path& base); + +std::filesystem::path getMetaFilePath(const std::filesystem::path& base, + int64_t rank, + const std::string& var_name); + +std::filesystem::path getValueChunkFilePath(const std::filesystem::path& base, + int64_t rank, + const std::string& var_name, + int64_t chunk_id); + +} // namespace spu::device diff --git a/libspu/device/pphlo/BUILD.bazel b/libspu/device/pphlo/BUILD.bazel index 34013690..09617fc6 100644 --- a/libspu/device/pphlo/BUILD.bazel +++ b/libspu/device/pphlo/BUILD.bazel @@ -87,3 +87,16 @@ spu_cc_test( "@llvm-project//mlir:Parser", ], ) + +spu_cc_binary( + name = "pphlo_executor_debug_runner", + testonly = True, + srcs = ["pphlo_executor_debug_runner.cc"], + deps = [ + ":pphlo_executor", + "//libspu/device:api", + "//libspu/device:debug_dump_constant", + "//libspu/device:test_utils", + "@llvm-project//llvm:Support", + ], +) diff --git a/libspu/device/pphlo/pphlo_executor.cc b/libspu/device/pphlo/pphlo_executor.cc index 481fdc90..4f5f8d40 100644 --- a/libspu/device/pphlo/pphlo_executor.cc +++ b/libspu/device/pphlo/pphlo_executor.cc @@ -567,16 +567,6 @@ void execute(OpExecutor *executor, SPUContext *sctx, SymbolScope *sscope, // window padding std::vector> window_padding(window_shape.size(), {0, 0}); - if (op.getPadding().has_value()) { - const auto v = *op.getPadding(); // NOLINT - - SPU_ENFORCE(window_padding.size() * 2 == (size_t)v.size()); - - for (size_t idx = 0; idx < window_padding.size(); ++idx) { - window_padding[idx] = {*(v.getValues().begin() + 2 * idx), - *(v.getValues().begin() + 2 * idx + 1)}; - } - } auto ret = kernel::hlo::SelectAndScatter( sctx, operand, source, init_val, window_shape, window_strides, @@ -612,16 +602,6 @@ void execute(OpExecutor *executor, SPUContext *sctx, SymbolScope *sscope, // window padding std::vector> window_padding(window_shape.size(), {0, 0}); - if (op.getPadding().has_value()) { - const auto v = *op.getPadding(); // NOLINT - - SPU_ENFORCE(window_padding.size() * 2 == (size_t)v.size()); - - for (size_t idx = 0; idx < window_padding.size(); ++idx) { - window_padding[idx] = {*(v.getValues().begin() + 2 * idx), - *(v.getValues().begin() + 2 * idx + 1)}; - } - } auto base_shape = op.getResult().getType().dyn_cast().getShape(); @@ -889,26 +869,9 @@ void execute(OpExecutor *executor, SPUContext *sctx, SymbolScope *sscope, window_dilations); // NOLINT } - // window padding std::vector> window_padding(window_shape.size(), {0, 0}); - if (op.getPadding().has_value()) { - const auto v = *op.getPadding(); // NOLINT - - SPU_ENFORCE(window_padding.size() * 2 == (size_t)v.size()); - - for (size_t idx = 0; idx < window_padding.size(); ++idx) { - window_padding[idx] = {*(v.getValues().begin() + 2 * idx), - *(v.getValues().begin() + 2 * idx + 1)}; - } - } - - // base dilation Sizes base_dilation(window_shape.size(), 1); - if (op.getBaseDilations().has_value()) { - convertDenseIntElementAttr(*op.getBaseDilations(), - base_dilation); // NOLINT - } kernel::hlo::ReduceWindowConfig config; config.window_shape = window_shape; @@ -953,38 +916,21 @@ void execute(OpExecutor *executor, SPUContext *sctx, SymbolScope *sscope, window_dilations); // NOLINT } - // window padding - std::vector> 
window_padding(window_shape.size(), - {0, 0}); - if (op.getPadding().has_value()) { - const auto v = *op.getPadding(); // NOLINT - - SPU_ENFORCE(window_padding.size() * 2 == (size_t)v.size()); - - for (size_t idx = 0; idx < window_padding.size(); ++idx) { - window_padding[idx] = {*(v.getValues().begin() + 2 * idx), - *(v.getValues().begin() + 2 * idx + 1)}; - } - } - - // base dilation - Sizes base_dilation(window_shape.size(), 1); - if (op.getBaseDilations().has_value()) { - convertDenseIntElementAttr(*op.getBaseDilations(), - base_dilation); // NOLINT - } - auto ret_shape = op->getResults()[0] .getType() .dyn_cast() .getShape(); + std::vector> window_padding(window_shape.size(), + {0, 0}); + Sizes base_dilations(window_shape.size(), 1); + kernel::hlo::ReduceWindowConfig config; config.window_shape = window_shape; config.window_strides = window_strides; config.window_dilations = window_dilations; config.window_padding = window_padding; - config.base_dilations = base_dilation; + config.base_dilations = base_dilations; auto ret = kernel::hlo::ArgMax(sctx, lookupValue(sscope, op.getInput(), opts), ret_shape, config); diff --git a/libspu/device/pphlo/pphlo_executor_debug_runner.cc b/libspu/device/pphlo/pphlo_executor_debug_runner.cc new file mode 100644 index 00000000..51f76768 --- /dev/null +++ b/libspu/device/pphlo/pphlo_executor_debug_runner.cc @@ -0,0 +1,195 @@ +// Copyright 2021 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include +#include + +#include "absl/strings/str_split.h" +#include "llvm/Support/CommandLine.h" +#include "spdlog/spdlog.h" + +#include "libspu/core/value.h" +#include "libspu/device/api.h" +#include "libspu/device/debug_dump_constant.h" +#include "libspu/device/pphlo/pphlo_executor.h" +#include "libspu/device/symbol_table.h" +#include "libspu/device/test_utils.h" +#include "libspu/kernel/hal/debug.h" +#include "libspu/mpc/factory.h" +#include "libspu/mpc/utils/simulate.h" + +llvm::cl::opt SnapshotDir( + "snapshot_dir", llvm::cl::desc("folder contains core snapshot files"), + llvm::cl::init(".")); + +// Mode switch +llvm::cl::opt LocalMode("local", llvm::cl::desc("local simulation mode"), + llvm::cl::init(false)); + +// Network only settings +llvm::cl::opt Parties( + "parties", llvm::cl::init("127.0.0.1:9530,127.0.0.1:9531,127.0.0.1:9532"), + llvm::cl::desc("server list, format: host1:port1[,host2:port2, ...]")); + +llvm::cl::opt Rank("rank", llvm::cl::init(0), + llvm::cl::desc("self rank")); + +// Local simulation only settings +llvm::cl::opt NumProc( + "num_processor", + llvm::cl::desc("number of processors to create (local simulation only)"), + llvm::cl::init(3)); + +std::shared_ptr MakeLink(const std::string &parties, + size_t rank) { + yacl::link::ContextDesc lctx_desc; + std::vector hosts = absl::StrSplit(parties, ','); + for (size_t rank = 0; rank < hosts.size(); rank++) { + const auto id = fmt::format("party{}", rank); + lctx_desc.parties.push_back({id, hosts[rank]}); + } + auto lctx = yacl::link::FactoryBrpc().CreateContext(lctx_desc, rank); + lctx->ConnectToMesh(); + return lctx; +} + +std::unique_ptr MakeSPUContext( + const spu::RuntimeConfig &config) { + auto lctx = MakeLink(Parties.getValue(), Rank.getValue()); + + return std::make_unique(config, lctx); +} + +spu::RuntimeConfig parseRuntimeConfig( + const std::filesystem::path &snapshot_dir) { + auto config_file = spu::device::getConfigFilePath(snapshot_dir); + SPU_ENFORCE(std::filesystem::exists(config_file), + "Serialized config file {} does not exit", config_file.c_str()); + SPDLOG_INFO("Read config file from {}", config_file.c_str()); + std::ifstream stream(config_file, std::ios::binary); + + spu::RuntimeConfig config; + SPU_ENFORCE(config.ParseFromIstream(&stream), + "Parse serialized config file {} failed", config_file.c_str()); + return config; +} + +spu::ExecutableProto parseExecutable( + const std::filesystem::path &snapshot_dir) { + auto code_file = spu::device::getCodeFilePath(snapshot_dir); + SPU_ENFORCE(std::filesystem::exists(code_file), + "Serialized executable file {} does not exit", code_file.c_str()); + SPDLOG_INFO("Read config file from {}", code_file.c_str()); + std::ifstream stream(code_file, std::ios::binary); + + spu::ExecutableProto code; + SPU_ENFORCE(code.ParseFromIstream(&stream), + "Parse serialized code file {} failed", code_file.c_str()); + return code; +} + +spu::device::SymbolTable parseSymbolTable( + const std::filesystem::path &snapshot_dir) { + auto data_dir = spu::device::getRankFolder(snapshot_dir, Rank.getValue()); + SPU_ENFORCE(std::filesystem::exists(data_dir), + "Serialized data dir {} does not exit", data_dir.c_str()); + SPDLOG_INFO("Read inputs file from {}", data_dir.c_str()); + + spu::device::SymbolTable table; + + for (const auto &file : std::filesystem::directory_iterator(data_dir)) { + const auto &filename = file.path().filename(); + + if (filename.extension() == spu::device::getMetaExtension()) { + spu::ValueProto vp; + { + SPDLOG_INFO("Read inputs meta 
{}", file.path().c_str()); + std::ifstream stream(file.path(), std::ios::binary); + vp.meta.ParseFromIstream(&stream); + } + const auto var_name = filename.stem().native(); + // Get slices + int64_t counter = 0; + while (true) { + auto chunk_file = spu::device::getValueChunkFilePath( + snapshot_dir, Rank.getValue(), var_name, counter); + if (std::filesystem::exists(chunk_file)) { + SPDLOG_INFO("Read inputs data chunk {}", chunk_file.c_str()); + std::ifstream stream(chunk_file, std::ios::binary); + vp.chunks.resize(counter + 1); + vp.chunks[counter].ParseFromIstream(&stream); + ++counter; + } else { + break; + } + } + + table.setVar(var_name, spu::Value::fromProto(vp)); + } + } + + return table; +} + +void RpcBasedRunner(const std::filesystem::path &snapshot_dir) { + auto sctx = MakeSPUContext(parseRuntimeConfig(snapshot_dir)); + + spu::device::SymbolTable table = parseSymbolTable(snapshot_dir); + + spu::device::pphlo::PPHloExecutor executor; + + SPDLOG_INFO("Run with config {}", sctx->config().DebugString()); + + spu::device::execute(&executor, sctx.get(), parseExecutable(snapshot_dir), + &table); +} + +void MemBasedRunner(const std::filesystem::path &snapshot_dir) { + auto world_size = NumProc.getValue(); + + SPDLOG_INFO("world size = {}", world_size); + + auto rt_config = parseRuntimeConfig(snapshot_dir); + rt_config.set_enable_runtime_snapshot(false); + + spu::mpc::utils::simulate( + world_size, [&](const std::shared_ptr<::yacl::link::Context> &lctx) { + spu::SPUContext sctx(rt_config, lctx); + + spu::mpc::Factory::RegisterProtocol(&sctx, sctx.lctx()); + + spu::device::pphlo::PPHloExecutor executor; + + auto executable = parseExecutable(snapshot_dir); + spu::device::SymbolTable table = parseSymbolTable(snapshot_dir); + + spu::device::execute(&executor, &sctx, executable, &table); + }); +} + +int main(int argc, char **argv) { + llvm::cl::ParseCommandLineOptions(argc, argv); + + std::filesystem::path snapshot_dir = SnapshotDir.getValue(); + + auto local = LocalMode.getValue(); + + if (local) { + MemBasedRunner(snapshot_dir); + } else { + RpcBasedRunner(snapshot_dir); + } +} diff --git a/libspu/device/pphlo/pphlo_executor_test.cc b/libspu/device/pphlo/pphlo_executor_test.cc index 2299c5c2..00373049 100644 --- a/libspu/device/pphlo/pphlo_executor_test.cc +++ b/libspu/device/pphlo/pphlo_executor_test.cc @@ -492,7 +492,7 @@ func.func @main(%arg0: tensor<4x4x!pphlo.pub>) -> (tensor<1x1x!pphlo.pub>) -> () }) {base_dilations = dense<1> : tensor<2xi64>, padding = dense<0> : tensor<2x2xi64>, window_dilations = dense<2> : tensor<2xi64>, window_dimensions = dense<2> : tensor<2xi64>, window_strides = dense<2> : tensor<2xi64>} : (tensor<4x4x!pphlo.pub>, tensor>) -> tensor<1x1x!pphlo.pub> - return %1 : tensor<1x1x!pphlo.pub> + return %1 : tensor<1x1x!pphlo.pub> })"); xt::xarray expect = {10}; @@ -507,17 +507,23 @@ TEST_P(ExecutorTest, ReduceWindowMaxIotaBaseDilation) { {0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 10, 11}, {12, 13, 14, 15}}; r.addInput(in1); - r.run(R"( -func.func @main(%arg0: tensor<4x4x!pphlo.pub>) -> (tensor<6x6x!pphlo.pub>) { - %0 = "pphlo.constant"() {value = dense<0> : tensor} : () -> tensor> - %1 = "pphlo.reduce_window"(%arg0, %0) ( { - ^bb0(%arg1: tensor>, %arg2: tensor>): // no predecessors - %2 = "pphlo.maximum"(%arg1, %arg2) : (tensor>, tensor>) -> tensor> - "pphlo.return"(%2) : (tensor>) -> () - }) {base_dilations = dense<2> : tensor<2xi64>, padding = dense<0> : tensor<2x2xi64>, window_dilations = dense<1> : tensor<2xi64>, window_dimensions = dense<2> : tensor<2xi64>, window_strides = 
dense<1> : tensor<2xi64>} : (tensor<4x4x!pphlo.pub>, tensor>) -> tensor<6x6x!pphlo.pub> - - return %1 : tensor<6x6x!pphlo.pub> -})"); + r.run(r.compileMHlo(R"( +func.func @main(%arg0: tensor<4x4xi32>) -> (tensor<6x6xi32>) { + %0 = "mhlo.constant"() {value = dense<0> : tensor} : () -> tensor + %1 = "mhlo.reduce_window"(%arg0, %0) ( { + ^bb0(%arg1: tensor, %arg2: tensor): // no predecessors + %2 = "mhlo.maximum"(%arg1, %arg2) : (tensor, tensor) -> tensor + "mhlo.return"(%2) : (tensor) -> () + }) { + base_dilations = dense<2> : tensor<2xi64>, + padding = dense<0> : tensor<2x2xi64>, + window_dilations = dense<1> : tensor<2xi64>, + window_dimensions = dense<2> : tensor<2xi64>, + window_strides = dense<1> : tensor<2xi64> + } : (tensor<4x4xi32>, tensor) -> tensor<6x6xi32> + return %1 : tensor<6x6xi32> +})", + {Visibility::VIS_PUBLIC})); xt::xarray expect = {{0, 1, 1, 2, 2, 3}, {4, 5, 5, 6, 6, 7}, {4, 5, 5, 6, 6, 7}, {8, 9, 9, 10, 10, 11}, @@ -533,17 +539,20 @@ TEST_P(ExecutorTest, ReduceWindowMaxIotaStrideBaseDilation) { {0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 10, 11}, {12, 13, 14, 15}}; r.addInput(in1); - r.run(R"( -func.func @main(%arg0: tensor<4x4x!pphlo.pub>) -> (tensor<3x3x!pphlo.pub>) { - %0 = "pphlo.constant"() {value = dense<0> : tensor} : () -> tensor> - %1 = "pphlo.reduce_window"(%arg0, %0) ( { - ^bb0(%arg1: tensor>, %arg2: tensor>): // no predecessors - %2 = "pphlo.maximum"(%arg1, %arg2) : (tensor>, tensor>) -> tensor> - "pphlo.return"(%2) : (tensor>) -> () - }) {base_dilations = dense<2> : tensor<2xi64>, padding = dense<0> : tensor<2x2xi64>, window_dilations = dense<1> : tensor<2xi64>, window_dimensions = dense<2> : tensor<2xi64>, window_strides = dense<2> : tensor<2xi64>} : (tensor<4x4x!pphlo.pub>, tensor>) -> tensor<3x3x!pphlo.pub> + auto compiled = r.compileMHlo(R"( +func.func @main(%arg0: tensor<4x4xi32>) -> (tensor<3x3xi32>) { + %0 = "mhlo.constant"() {value = dense<0> : tensor} : () -> tensor + %1 = "mhlo.reduce_window"(%arg0, %0) ( { + ^bb0(%arg1: tensor, %arg2: tensor): // no predecessors + %2 = "mhlo.maximum"(%arg1, %arg2) : (tensor, tensor) -> tensor + "mhlo.return"(%2) : (tensor) -> () + }) {base_dilations = dense<2> : tensor<2xi64>, padding = dense<0> : tensor<2x2xi64>, window_dilations = dense<1> : tensor<2xi64>, window_dimensions = dense<2> : tensor<2xi64>, window_strides = dense<2> : tensor<2xi64>} : (tensor<4x4xi32>, tensor) -> tensor<3x3xi32> - return %1 : tensor<3x3x!pphlo.pub> -})"); + return %1 : tensor<3x3xi32> +})", + {Visibility::VIS_PUBLIC}); + + r.run(compiled); xt::xarray expect = {{0, 1, 2}, {4, 5, 6}, {8, 9, 10}}; r.verifyOutput(expect.data()); @@ -557,17 +566,20 @@ TEST_P(ExecutorTest, ReduceWindowMaxIotaStrideBothDilation) { {0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 10, 11}, {12, 13, 14, 15}}; r.addInput(in1); - r.run(R"( -func.func @main(%arg0: tensor<4x4x!pphlo.pub>) -> (tensor<3x3x!pphlo.pub>) { - %0 = "pphlo.constant"() {value = dense<0> : tensor} : () -> tensor> - %1 = "pphlo.reduce_window"(%arg0, %0) ( { - ^bb0(%arg1: tensor>, %arg2: tensor>): // no predecessors - %2 = "pphlo.maximum"(%arg1, %arg2) : (tensor>, tensor>) -> tensor> - "pphlo.return"(%2) : (tensor>) -> () - }) {base_dilations = dense<2> : tensor<2xi64>, padding = dense<0> : tensor<2x2xi64>, window_dilations = dense<2> : tensor<2xi64>, window_dimensions = dense<2> : tensor<2xi64>, window_strides = dense<2> : tensor<2xi64>} : (tensor<4x4x!pphlo.pub>, tensor>) -> tensor<3x3x!pphlo.pub> + auto compiled = r.compileMHlo(R"( +func.func @main(%arg0: tensor<4x4xi32>) -> (tensor<3x3xi32>) { + %0 = 
"mhlo.constant"() {value = dense<0> : tensor} : () -> tensor + %1 = "mhlo.reduce_window"(%arg0, %0) ( { + ^bb0(%arg1: tensor, %arg2: tensor): // no predecessors + %2 = "mhlo.maximum"(%arg1, %arg2) : (tensor, tensor) -> tensor + "mhlo.return"(%2) : (tensor) -> () + }) {base_dilations = dense<2> : tensor<2xi64>, padding = dense<0> : tensor<2x2xi64>, window_dilations = dense<2> : tensor<2xi64>, window_dimensions = dense<2> : tensor<2xi64>, window_strides = dense<2> : tensor<2xi64>} : (tensor<4x4xi32>, tensor) -> tensor<3x3xi32> - return %1 : tensor<3x3x!pphlo.pub> -})"); + return %1 : tensor<3x3xi32> +})", + {Visibility::VIS_PUBLIC}); + + r.run(compiled); xt::xarray expect = {{5, 6, 7}, {9, 10, 11}, {13, 14, 15}}; r.verifyOutput(expect.data()); @@ -581,17 +593,20 @@ TEST_P(ExecutorTest, ReduceWindowMaxIotaPaddingStrideBaseDilation) { {0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 10, 11}, {12, 13, 14, 15}}; r.addInput(in1); - r.run(R"( -func.func @main(%arg0: tensor<4x4x!pphlo.pub>) -> (tensor<3x3x!pphlo.pub>) { - %0 = "pphlo.constant"() {value = dense<0> : tensor} : () -> tensor> - %1 = "pphlo.reduce_window"(%arg0, %0) ( { - ^bb0(%arg1: tensor>, %arg2: tensor>): // no predecessors - %2 = "pphlo.maximum"(%arg1, %arg2) : (tensor>, tensor>) -> tensor> - "pphlo.return"(%2) : (tensor>) -> () - }) {base_dilations = dense<2> : tensor<2xi64>, padding = dense<1> : tensor<2x2xi64>, window_dilations = dense<1> : tensor<2xi64>, window_dimensions = dense<3> : tensor<2xi64>, window_strides = dense<3> : tensor<2xi64>} : (tensor<4x4x!pphlo.pub>, tensor>) -> tensor<3x3x!pphlo.pub> + auto compiled = r.compileMHlo(R"( +func.func @main(%arg0: tensor<4x4xi32>) -> (tensor<3x3xi32>) { + %0 = "mhlo.constant"() {value = dense<0> : tensor} : () -> tensor + %1 = "mhlo.reduce_window"(%arg0, %0) ( { + ^bb0(%arg1: tensor, %arg2: tensor): // no predecessors + %2 = "mhlo.maximum"(%arg1, %arg2) : (tensor, tensor) -> tensor + "mhlo.return"(%2) : (tensor) -> () + }) {base_dilations = dense<2> : tensor<2xi64>, padding = dense<1> : tensor<2x2xi64>, window_dilations = dense<1> : tensor<2xi64>, window_dimensions = dense<3> : tensor<2xi64>, window_strides = dense<3> : tensor<2xi64>} : (tensor<4x4xi32>, tensor) -> tensor<3x3xi32> - return %1 : tensor<3x3x!pphlo.pub> -})"); + return %1 : tensor<3x3xi32> +})", + {Visibility::VIS_PUBLIC}); + + r.run(compiled); xt::xarray expect = {{0, 2, 3}, {8, 10, 11}, {12, 14, 15}}; r.verifyOutput(expect.data()); diff --git a/libspu/device/symbol_table.h b/libspu/device/symbol_table.h index f82a2b66..38bacd90 100644 --- a/libspu/device/symbol_table.h +++ b/libspu/device/symbol_table.h @@ -36,11 +36,6 @@ class SymbolTable { auto begin() { return data_.begin(); } auto end() { return data_.end(); } - - // @shantang / @wuju - // TODO: temporary remove, need to adapt value slice change - // SymbolTableProto toProto() const; - // static SymbolTable fromProto(const SymbolTableProto &proto); }; } // namespace spu::device diff --git a/libspu/dialect/pphlo_ops.cc b/libspu/dialect/pphlo_ops.cc index bee3f970..8cdf5395 100644 --- a/libspu/dialect/pphlo_ops.cc +++ b/libspu/dialect/pphlo_ops.cc @@ -825,9 +825,29 @@ LogicalResult PadOp::inferReturnTypeComponents( adaptor.getEdgePaddingHigh(), adaptor.getInteriorPadding(), types); // Convert type to STC + TypeTools tools; for (auto& t : types) { auto rt = t.dyn_cast(); - inferredReturnShapes.emplace_back(rt.getShape(), rt.getElementType()); + if (tools.isMPCType(rt)) { + llvm::SmallVector vis; + + for (const auto& op : operands) { + if (tools.isMPCType(op.getType())) { + auto 
p = op.getDefiningOp() + ->getOperandTypes()[0]; + vis.emplace_back(tools.getTypeVisibility(p)); + } else { + vis.emplace_back(tools.getTypeVisibility(op.getType())); + } + } + + auto result_vis = tools.inferResultVisibility(vis); + inferredReturnShapes.emplace_back( + rt.getShape(), + tools.getTypeWithVisibility(rt.getElementType(), result_vis)); + } else { + inferredReturnShapes.emplace_back(rt.getShape(), rt.getElementType()); + } } return status; diff --git a/libspu/dialect/pphlo_ops.td b/libspu/dialect/pphlo_ops.td index 96a317f4..cb933264 100644 --- a/libspu/dialect/pphlo_ops.td +++ b/libspu/dialect/pphlo_ops.td @@ -698,8 +698,7 @@ def PPHLO_SelectAndScatterOp: PPHLO_Op<"select_and_scatter", PPHLO_Tensor:$source, PPHLO_Tensor:$init_value, I64ElementsAttr:$window_dimensions, - OptionalAttr:$window_strides, - OptionalAttr:$padding + OptionalAttr:$window_strides ); let regions = (region SizedRegion<1>:$select, SizedRegion<1>:$scatter); @@ -720,8 +719,7 @@ def PPHLO_MaxPoolScatterOp: PPHLO_Op<"maxpool_scatter", [Pure]> { PPHLO_IntTensor:$scatter_indices, PPHLO_Tensor:$update, OptionalAttr:$window_dimensions, - OptionalAttr:$window_strides, - OptionalAttr:$padding + OptionalAttr:$window_strides ); let results = (outs PPHLO_Tensor); @@ -772,12 +770,9 @@ def PPHLO_ReduceWindowOp : PPHLO_Op<"reduce_window", [ Variadic:$init_values, I64ElementsAttr:$window_dimensions, // If strides or dilations attributes are missing then the default value is - // one for each of the input dimensions. Similarly, padding values are zero - // for both low and high in each of the dimensions, if not specified. + // one for each of the input dimensions. OptionalAttr:$window_strides, - OptionalAttr:$base_dilations, - OptionalAttr:$window_dilations, - OptionalAttr:$padding + OptionalAttr:$window_dilations ); let results = (outs Variadic); @@ -796,12 +791,9 @@ def PPHLO_ArgMaxOp: PPHLO_Op<"argmax", [Pure]> { PPHLO_Tensor:$input, I64ElementsAttr:$window_dimensions, // If strides or dilations attributes are missing then the default value is - // one for each of the input dimensions. Similarly, padding values are zero - // for both low and high in each of the dimensions, if not specified. + // one for each of the input dimensions. 
OptionalAttr:$window_strides, - OptionalAttr:$base_dilations, OptionalAttr:$window_dilations, - OptionalAttr:$padding, DefaultValuedAttr:$onehot_index ); diff --git a/libspu/kernel/hal/shape_ops.h b/libspu/kernel/hal/shape_ops.h index c4bef79e..ab9ad6e6 100644 --- a/libspu/kernel/hal/shape_ops.h +++ b/libspu/kernel/hal/shape_ops.h @@ -38,7 +38,7 @@ Value reshape(SPUContext* ctx, const Value& in, const Shape& to_shape); // @param end_indices, the end indices // @param strides, the strides Value slice(SPUContext* ctx, const Value& input, const Index& start_indices, - const Index& end_indices, const Strides& strides); + const Index& end_indices, const Strides& strides = {}); /// This is a special slice for single element at indices // @returns a array with empty shape (scalar) diff --git a/libspu/kernel/hlo/BUILD.bazel b/libspu/kernel/hlo/BUILD.bazel index 9168ea75..4a8192e7 100644 --- a/libspu/kernel/hlo/BUILD.bazel +++ b/libspu/kernel/hlo/BUILD.bazel @@ -243,6 +243,15 @@ spu_cc_library( ], ) +spu_cc_test( + name = "select_and_scatter_test", + srcs = ["select_and_scatter_test.cc"], + deps = [ + ":select_and_scatter", + "//libspu/kernel:test_util", + ], +) + spu_cc_library( name = "shift", srcs = ["shift.cc"], diff --git a/libspu/kernel/hlo/basic_binary.cc b/libspu/kernel/hlo/basic_binary.cc index 4fb11eeb..bf8211fc 100644 --- a/libspu/kernel/hlo/basic_binary.cc +++ b/libspu/kernel/hlo/basic_binary.cc @@ -55,7 +55,7 @@ spu::Value Remainder(SPUContext *ctx, const spu::Value &lhs, auto quotient = hal::div(ctx, lhs, rhs); if (lhs.isFxp() || rhs.isFxp()) { - // 2nd: round to nearst number through (x >= 0.0) ? floor(x) : ceil(x)... + // 2nd: round to nearest number through (x >= 0.0) ? floor(x) : ceil(x)... auto zero = hal::zeros(ctx, quotient.dtype(), quotient.shape()); quotient = hal::select(ctx, hal::greater_equal(ctx, quotient, zero), hal::floor(ctx, quotient), hal::ceil(ctx, quotient)); diff --git a/libspu/kernel/hlo/reduce.cc b/libspu/kernel/hlo/reduce.cc index 8fed7257..1881fd8a 100644 --- a/libspu/kernel/hlo/reduce.cc +++ b/libspu/kernel/hlo/reduce.cc @@ -127,8 +127,9 @@ std::vector ReduceWindowWithoutDilation( std::vector expanded; for (size_t idx = 0; idx < nargs; ++idx) { const auto &input = inputs[idx]; - expanded.emplace_back( - expandWindow(ctx, input, window_shape, window_strides, window_padding)); + const auto &init_val = init_values[idx]; + expanded.emplace_back(expandWindow(ctx, input, window_shape, window_strides, + window_padding, init_val)); } if (last_operand_is_window_mask) { @@ -493,7 +494,8 @@ std::pair ArgMax(SPUContext *ctx, auto mask = hal::constant(ctx, e, DT_I1); auto result = ReduceWindowImpl( - ctx, {input, mask}, {}, ret_shape, config, true, true, + ctx, {input, mask}, {spu::Value(), spu::Value()}, ret_shape, config, true, + true, [&](absl::Span lhs, absl::Span rhs) -> std::vector { SPU_ENFORCE(lhs.size() == 2); diff --git a/libspu/kernel/hlo/select_and_scatter.cc b/libspu/kernel/hlo/select_and_scatter.cc index 01686489..66ea1b38 100644 --- a/libspu/kernel/hlo/select_and_scatter.cc +++ b/libspu/kernel/hlo/select_and_scatter.cc @@ -15,264 +15,210 @@ #include "libspu/kernel/hlo/select_and_scatter.h" -#include -#include -#include -#include - -#include "yacl/utils/parallel.h" - -#include "libspu/core/context.h" -#include "libspu/core/value.h" #include "libspu/kernel/hal/constants.h" #include "libspu/kernel/hal/debug.h" -#include "libspu/kernel/hal/polymorphic.h" // for select -#include "libspu/kernel/hal/ring.h" +#include "libspu/kernel/hal/polymorphic.h" #include 
"libspu/kernel/hal/shape_ops.h" -#include "libspu/kernel/hal/type_cast.h" -#include "libspu/kernel/hlo/const.h" +#include "libspu/kernel/hlo/const.h" // iota #include "libspu/kernel/hlo/reduce.h" #include "libspu/kernel/hlo/utils.h" namespace spu::kernel::hlo { -spu::Value MaxPoolScatter1x2x2x1NoPaddingNoDilation( - SPUContext *ctx, const spu::Value &scatter_indices, - const spu::Value &source, const Strides &window_strides) { - std::vector slices(4); - for (int64_t idx = 0; idx < 4; ++idx) { - slices[idx] = hal::slice( - ctx, scatter_indices, {0, 0, 0, 0, idx}, - {scatter_indices.shape()[0], scatter_indices.shape()[1], - scatter_indices.shape()[2], scatter_indices.shape()[3], idx + 1}, - {1, 1, 1, 1, 1}); - slices[idx] = - hal::mul(ctx, hal::reshape(ctx, slices[idx], source.shape()), source); - - // FIXME(jint), handle int type promotion - slices[idx] = hal::dtype_cast(ctx, slices[idx], source.dtype()); +/// The simplified scatter function +static spu::Value ScatterWindow( + SPUContext *ctx, // + const spu::Value &source, // scatter source, shape = num_window + const spu::Value &scatter_indices, // the one-hot encoded scatter index + const spu::Value &init, // scalar value for non-scattered positions. + const Shape &base_shape, // + const Shape &window_shape, // + const Strides &window_strides, // + const ValueBinaryFn &scatter_fn) { + // alias shapes, use B,W,N. + const Shape &B = base_shape; // base shape + const Shape &W = window_shape; // window shape + const Shape &N = source.shape(); // number of window + const Shape NW2d = {N.numel(), W.numel()}; // flat N x W + + // sanity check. + const size_t ndim = source.shape().size(); + SPU_ENFORCE_EQ(ndim, window_shape.size()); + SPU_ENFORCE_EQ(ndim, window_strides.size()); + SPU_ENFORCE(init.shape().isScalar()); + + // scatter_indices is the one-hot encoding for each window. + // win0: [0, 0, 0, 1, 0, 0] // position 3 is selected. + // win1: [0, 1, 0, 0, 0, 0] // position 1 is selected. + // ... + SPU_ENFORCE_EQ(ndim + 1, scatter_indices.shape().size()); + SPU_ENFORCE_EQ(N, Shape(scatter_indices.shape().begin(), + scatter_indices.shape().begin() + ndim)); + SPU_ENFORCE_EQ(W.numel(), scatter_indices.shape()[ndim]); + auto scatter_indices_2d = hal::reshape(ctx, scatter_indices, NW2d); + + auto source2d = + hal::broadcast_to(ctx, hal::reshape(ctx, source, {N.numel()}), NW2d, {0}); + + // One hot selected value per-window. + // win0: [0, 0, 0, X, 0, 0] // position 3 is selected. + // win1: [0, Y, 0, 0, 0, 0] // position 1 is selected. + auto selected = hal::mul(ctx, source2d, scatter_indices_2d); + SPU_ENFORCE_EQ(selected.shape(), NW2d); + + // selected value per-window index. + std::vector base_per_widx(W.numel()); + for (int64_t widx = 0; widx < W.numel(); widx++) { + // for the i-th index in window, find all selected values. + // win0: _, [0], _, _, _, _ + // win1: _, [Y], _, _, _, _ + // .. + auto sel_pw = hal::slice(ctx, selected, {0, widx}, {N.numel(), widx + 1}); + SPU_ENFORCE_EQ(sel_pw.shape(), Shape({N.numel(), 1})); + sel_pw = hal::reshape(ctx, sel_pw, N); + + // scatter it from num_window space to base space. 
+ Index window_index = unflattenIndex(widx, W); + Sizes padding_lo(ndim, 0); + Sizes padding_hi(ndim, 0); + Sizes padding_in(ndim, 0); + for (size_t dim = 0; dim < ndim; dim++) { + padding_lo[dim] = window_index[dim]; + padding_hi[dim] = window_shape[dim] - window_index[dim] - 1; + padding_in[dim] = window_strides[dim] - 1; + } + + base_per_widx[widx] = + hal::pad(ctx, sel_pw, init, padding_lo, padding_hi, padding_in); + SPU_ENFORCE_EQ(base_per_widx[widx].shape(), B); } - // Improvement idea: If window strides is >= window size (no overlap), we - // should be able to compute scatter result with just one multiply - auto z = hal::zeros(ctx, slices[0].dtype()); - if (slices[0].isSecret()) { - z = hal::seal(ctx, z); - } + // last step, stack and reduce it. + auto res = hal::concatenate(ctx, base_per_widx, 0); + Shape WflatB = {W.numel()}; + WflatB.insert(WflatB.end(), B.begin(), B.end()); + res = hal::reshape(ctx, res, WflatB); - std::vector> f_slices(4); - f_slices[0] = - std::async(std::launch::async, hal::pad, ctx, slices[0], z, - Sizes{0, 0, 0, 0}, Sizes{0, 1, 1, 0}, - Sizes{0, window_strides[1] - 1, window_strides[2] - 1, 0}); - f_slices[1] = - std::async(std::launch::async, hal::pad, ctx, slices[1], z, - Sizes{0, 0, 1, 0}, Sizes{0, 1, 0, 0}, - Sizes{0, window_strides[1] - 1, window_strides[2] - 1, 0}); - f_slices[2] = - std::async(std::launch::async, hal::pad, ctx, slices[2], z, - Sizes{0, 1, 0, 0}, Sizes{0, 0, 1, 0}, - Sizes{0, window_strides[1] - 1, window_strides[2] - 1, 0}); - f_slices[3] = - std::async(std::launch::async, hal::pad, ctx, slices[3], z, - Sizes{0, 1, 1, 0}, Sizes{0, 0, 0, 0}, - Sizes{0, window_strides[1] - 1, window_strides[2] - 1, 0}); - - spu::Value ret = f_slices[0].get(); - for (size_t idx = 1; idx < 4; ++idx) { - ret = hal::add(ctx, ret, f_slices[idx].get()); - } + res = TreeReduce( + ctx, {res}, /* axis */ 0, + [&](absl::Span lhs, absl::Span rhs) { + return std::vector{scatter_fn(lhs[0], rhs[0])}; + })[0]; - return ret; -}; + // TODO: if the reshape failed, that maybe the right edge is not sampled by + // the window, we should add a padding operation here. + return hal::reshape(ctx, res, B); +} spu::Value MaxPoolScatter( SPUContext *ctx, const spu::Value &scatter_indices, const spu::Value &source, const Shape &window_shape, const Shape &base_shape, const Strides &window_strides, absl::Span> window_padding) { - // Add a fast 1x2x2x1, no padding fast reduce auto no_padding = std::all_of(window_padding.begin(), window_padding.end(), [](const std::pair &p) { return p.first == 0 && p.second == 0; }); - if (window_shape == absl::Span{1, 2, 2, 1} && no_padding) { - return MaxPoolScatter1x2x2x1NoPaddingNoDilation(ctx, scatter_indices, - source, window_strides); - } - // source_shape * window_numel - auto tiled_1d_shape = source.shape(); - const int64_t window_numel = std::accumulate( - window_shape.begin(), window_shape.end(), 1, std::multiplies<>()); - tiled_1d_shape.push_back(window_numel); - - Axes broadcast_dims(source.shape().size(), 0); - std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0); - - auto tiled_1d_source = - hal::broadcast_to(ctx, source, tiled_1d_shape, broadcast_dims); - - // selected_pos is the one hot encoding for each window. 
- auto selected = hal::mul(ctx, tiled_1d_source, scatter_indices); - - Shape tiled_shape(source.shape().begin(), source.shape().end()); - tiled_shape.insert(tiled_shape.end(), window_shape.begin(), - window_shape.end()); - - selected = hal::reshape(ctx, selected, tiled_shape); - - const size_t ndim = base_shape.size(); - auto base_x_window_shape = base_shape; - base_x_window_shape.insert(base_x_window_shape.end(), window_shape.begin(), - window_shape.end()); - auto output = hal::zeros(ctx, source.dtype(), base_x_window_shape); - if (source.isSecret()) { - output = hal::seal(ctx, output); - } + SPU_ENFORCE(no_padding, "Expect padding to be removed by previous pass"); - const std::vector window_dilations(window_shape.size(), 1); - const std::vector base_dilations(source.shape().size(), 1); - std::vector window_index(ndim, 0); - - do { - yacl::parallel_for( - 0, source.numel(), 2048, [&](int64_t begin, int64_t end) { - Index tiled_index(2 * ndim, 0); - Index base_x_window_index(2 * ndim, 0); - std::copy(window_index.begin(), window_index.end(), - base_x_window_index.begin() + ndim); - std::copy(window_index.begin(), window_index.end(), - tiled_index.begin() + ndim); - auto source_index = unflattenIndex(begin, source.shape()); - for (int64_t idx = begin; idx < end; ++idx) { - bool out_of_bound = getBaseIndexFromWindowIndex( - window_shape, window_strides, window_dilations, window_padding, - base_shape, base_dilations, - absl::MakeSpan(tiled_index).subspan(0, ndim), window_index, - absl::MakeSpan(base_x_window_index).subspan(0, ndim)); - if (!out_of_bound) { - // TODO: anti-pattern, do not use .data(), use ops instead. - output.data().update_slice( - selected.data().slice_scalar_at(tiled_index), - base_x_window_index); - } - bumpIndices(source.shape(), - absl::MakeSpan(tiled_index).subspan(0, ndim)); - } - }); - } while (bumpIndices(window_shape, absl::MakeSpan(window_index))); - - auto base_1d_shape = base_shape; - base_1d_shape.push_back(window_numel); - output = hal::reshape(ctx, output, base_1d_shape); - - output = TreeReduce( - ctx, {output}, base_1d_shape.size() - 1, - [&](absl::Span lhs, absl::Span rhs) { - return std::vector{hal::add(ctx, lhs[0], rhs[0])}; - })[0]; + // In MaxPoolScatter, one-hot scatter_indices is carried from the 'forward' + // reduce window operation. So we can avoid on equal test here. + + auto init = hal::zeros(ctx, source.dtype(), {}); + auto scatter_fn = [&ctx](spu::Value const &lhs, + spu::Value const &rhs) -> spu::Value { + return hal::add(ctx, lhs, rhs); + }; - return hal::reshape(ctx, output, base_shape); + return ScatterWindow(ctx, source, scatter_indices, init, base_shape, + window_shape, window_strides, scatter_fn); } spu::Value SelectAndScatter( SPUContext *ctx, const spu::Value &base, const spu::Value &source, const spu::Value &init_val, const Shape &window_shape, const Strides &window_strides, - absl::Span> window_padding, + absl::Span> padding, const ValueBinaryFn &select_fn, const ValueBinaryFn &scatter_fn) { + // sanity check. const size_t ndim = base.shape().size(); + SPU_ENFORCE_EQ(ndim, window_shape.size()); + SPU_ENFORCE_EQ(ndim, window_strides.size()); + SPU_ENFORCE(init_val.shape().isScalar()); - // expand the base, simplify following actions without strides and padding. 
- auto tiled = - expandWindow(ctx, base, window_shape, window_strides, window_padding); - - // collapse the tile to 1d for better reduce performance - auto tiled_1d_shape = source.shape(); - const int64_t window_numel = std::accumulate( - window_shape.begin(), window_shape.end(), 1, std::multiplies<>()); - tiled_1d_shape.push_back(window_numel); - auto tiled_1d = hal::reshape(ctx, tiled, tiled_1d_shape); + // alias shapes, use B,W,N. + const Shape &W = window_shape; + const Shape &N = source.shape(); + // clang-format off + // + // The algorithm: + // tiled = win_count x window : (N,W) + // index = iota(0, num_window) : (_,W) + // sel_pos = reduce(tiled, index) : (N,) # find selected position of each window + // onehot = sel_pos == index : (N,_)->(_,W)->(N,W) # the one-hot position of each window + // sel_val = sel(onehot, source, init) : (N,W)->(N,)->()->(N,W) + // sel_val = reduce(sel_val, 1) : (N,W)->(N) // - auto indices = Iota(ctx, DT_I64, window_numel); - indices = hal::broadcast_to(ctx, indices, tiled_1d_shape); + // clang-format on + + // Expand the base, simplify further actions without strides and padding. + // Now the tiled shape is (N0, N1, ..., Nn, W0, W1, ..., Wn), where + // window_count = (N0, N1, ..., Nn) and Ni = (Bi-Wi)/Strides{i} + 1 + auto tiled = expandWindow(ctx, base, W, window_strides, padding, init_val); + SPU_ENFORCE_EQ(tiled.shape().size(), 2 * ndim); + SPU_ENFORCE_EQ(N, Shape(tiled.shape().begin(), tiled.shape().begin() + ndim)); + SPU_ENFORCE_EQ(W, Shape(tiled.shape().begin() + ndim, tiled.shape().end())); + + // Use a 2-d view, (N, W) to (N.numel(), W.numel()), to make further processing simpler. + const Shape NW2d = {N.numel(), W.numel()}; + auto tiled2d = hal::reshape(ctx, tiled, NW2d); + + // indices is the iota for each window. + // win0: [0, 1, 2, 3, 4, 5] + // win1: [0, 1, 2, 3, 4, 5] + // ... + auto indices = hal::broadcast_to(ctx, Iota(ctx, DT_I64, W.numel()), NW2d); + SPU_ENFORCE_EQ(indices.shape(), NW2d); // Apply the reduce with indices. - // total number of select_fn call is log2(window_numel) auto reduced = TreeReduce( - ctx, {tiled_1d, indices}, tiled_1d_shape.size() - 1, + ctx, {tiled2d, indices}, 1, [&](absl::Span lhs, absl::Span rhs) { + SPU_ENFORCE(lhs.size() == 2 && rhs.size() == 2); auto pred = select_fn(lhs[0], rhs[0]); - pred = hal::_prefer_a(ctx, pred); - std::vector rets; for (size_t idx = 0; idx < lhs.size(); idx++) { + // TODO: if reduce window does not require lhs[0].shape == + // lhs[1].shape, then we could avoid the later comparison. rets.push_back(hal::select(ctx, pred, lhs[idx], rhs[idx])); } return rets; }); - Axes broadcast_dims(source.shape().size(), 0); - std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0); - - // selected_pos is the one hot encoding for each window. - auto selected_pos = hal::equal( - ctx, hal::broadcast_to(ctx, reduced[1], tiled_1d_shape), indices); - auto selected = hal::select( - ctx, selected_pos, - hal::broadcast_to(ctx, source, tiled_1d_shape, broadcast_dims), - hal::broadcast_to(ctx, init_val, tiled_1d_shape)); - - // last step, collapse expanded shape to strided window - // build a tensor with [base.shape() x window_shape], so each - // [base.shape(), window_index] does not overlap with each other.
- selected = hal::reshape(ctx, selected, tiled.shape()); - - auto base_x_window_shape = base.shape(); - base_x_window_shape.insert(base_x_window_shape.end(), window_shape.begin(), - window_shape.end()); - auto output = hal::expand(ctx, init_val, base_x_window_shape); - - const std::vector window_dilations(window_shape.size(), 1); - const std::vector base_dilations(source.shape().size(), 1); - std::vector window_index(ndim, 0); - Index tiled_index(2 * ndim, 0); - Index base_x_window_index(2 * ndim, 0); - std::vector base_index(ndim, 0); - do { - std::copy(window_index.begin(), window_index.end(), - base_x_window_index.begin() + ndim); - std::fill(tiled_index.begin(), tiled_index.begin() + ndim, 0); - std::copy(window_index.begin(), window_index.end(), - tiled_index.begin() + ndim); - - do { - bool out_of_bound = getBaseIndexFromWindowIndex( - window_shape, window_strides, window_dilations, window_padding, - base.shape(), base_dilations, - absl::MakeSpan(tiled_index).subspan(0, ndim), window_index, - absl::MakeSpan(base_x_window_index).subspan(0, ndim)); - if (!out_of_bound) { - output.data().update_slice(selected.data().slice_scalar_at(tiled_index), - base_x_window_index); - } - - } while (bumpIndices(source.shape(), - absl::MakeSpan(tiled_index).subspan(0, ndim))); - } while (bumpIndices(window_shape, absl::MakeSpan(window_index))); - - auto base_1d_shape = base.shape(); - base_1d_shape.push_back(window_numel); - output = hal::reshape(ctx, output, base_1d_shape); - - output = TreeReduce( - ctx, {output}, base_1d_shape.size() - 1, - [&](absl::Span lhs, absl::Span rhs) { - return std::vector{scatter_fn(lhs[0], rhs[0])}; - })[0]; - - return hal::reshape(ctx, output, base.shape()); + // indices is the iota for each window. + // win0: [3] // position 3 is selected. + // win1: [1] // position 1 is selected. + // ... + auto sel_pos = reduced[1]; + SPU_ENFORCE_EQ(sel_pos.shape(), Shape({N.numel(), 1})); + + // win0: [3, 3, 3, 3, 3, 3] + // win1: [1, 1, 1, 1, 1, 1] + sel_pos = hal::broadcast_to(ctx, sel_pos, NW2d, {0}); + + // one hot encoding for each window + // win0: [0, 0, 0, 1, 0, 0] // position 3 is selected. + // win1: [0, 1, 0, 0, 0, 0] // position 1 is selected. + // ... + auto onehot = hal::equal(ctx, sel_pos, indices); + SPU_ENFORCE_EQ(onehot.shape(), NW2d); + + Shape N_W1d = N; + N_W1d.push_back(W.numel()); + + return ScatterWindow(ctx, source, hal::reshape(ctx, onehot, N_W1d), init_val, + base.shape(), window_shape, window_strides, scatter_fn); } } // namespace spu::kernel::hlo diff --git a/libspu/kernel/hlo/select_and_scatter_test.cc b/libspu/kernel/hlo/select_and_scatter_test.cc new file mode 100644 index 00000000..6bd3d8d1 --- /dev/null +++ b/libspu/kernel/hlo/select_and_scatter_test.cc @@ -0,0 +1,113 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
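+// The parameterized cases below exercise SelectAndScatter over secret values
+// with 1-D and 2-D windows, in both overlapping (stride < window) and
+// non-overlapping (stride == window) configurations, all without padding.
+// Each case selects the per-window maximum and scatters the source values by
+// addition; positions that are never selected keep the init value 0.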
+ +#include "libspu/kernel/hlo/select_and_scatter.h" + +#include "gtest/gtest.h" +#include "xtensor/xio.hpp" + +#include "libspu/kernel/hal/polymorphic.h" +#include "libspu/kernel/hal/type_cast.h" +#include "libspu/kernel/test_util.h" + +namespace spu::kernel::hlo { + +struct SelectAndScatterTestParam { + xt::xarray operand; + xt::xarray source; + std::vector> window_padding; + Shape window_shape; + Strides window_strides; + xt::xarray expected; +}; + +class SelectAndScatterTest + : public ::testing::TestWithParam { + public: + SelectAndScatterTest() : ctx_(test::makeSPUContext()) {} + + SPUContext ctx_; +}; + +TEST_P(SelectAndScatterTest, ParamTest) { + xt::xarray operand = GetParam().operand; + xt::xarray source = GetParam().source; + xt::xarray expected = GetParam().expected; + xt::xarray init = 0; + + Value operand_s = test::makeValue(&ctx_, operand, VIS_SECRET); + Value source_s = test::makeValue(&ctx_, source, VIS_SECRET); + Value init_val = test::makeValue(&ctx_, init, VIS_SECRET); + + const auto ret = SelectAndScatter( + &ctx_, operand_s, source_s, init_val, GetParam().window_shape, + GetParam().window_strides, GetParam().window_padding, + [&](const spu::Value &lhs, const spu::Value &rhs) { + return hal::greater(&ctx_, lhs, rhs); + }, + [&](const spu::Value &lhs, const spu::Value &rhs) { + return hal::add(&ctx_, lhs, rhs); + }); + auto ret_hat = hal::dump_public_as(&ctx_, hal::reveal(&ctx_, ret)); + EXPECT_TRUE(xt::allclose(expected, ret_hat, 0.01, 0.001)); +} + +INSTANTIATE_TEST_CASE_P( + SelectAndScatterTest_Instantiation, SelectAndScatterTest, + ::testing::Values( + SelectAndScatterTestParam{{1, 9, 3, 7, 5, 6}, + {34, 42}, + {{0, 0}}, + {3}, + {3}, + {0, 34, 0, 42, 0, 0}}, + SelectAndScatterTestParam{{{7, 2, 5, 3, 10, 2}, + {3, 8, 9, 3, 4, 2}, + {1, 5, 7, 5, 6, 1}, + {0, 6, 2, 7, 2, 8}}, + {{2, 6}, {3, 1}}, + {{0, 0}, {0, 0}}, + {2, 3}, + {2, 3}, + {{0, 0, 0, 0, 6, 0}, + {0, 0, 2, 0, 0, 0}, + {0, 0, 3, 0, 0, 0}, + {0, 0, 0, 0, 0, 1}}}, + SelectAndScatterTestParam{{1, 9, 3, 7, 5, 6}, + {34, 42, 53, 19}, + {{0, 0}}, + {3}, + {1}, + {0, 76, 0, 72, 0, 0}}, + SelectAndScatterTestParam{{{7, 2, 5, 3, 10, 2}, {3, 8, 9, 3, 4, 2}}, + {{2, 6}}, + {{0, 0}, {0, 0}}, + {2, 3}, + {2, 3}, + {{0, 0, 0, 0, 6, 0}, {0, 0, 2, 0, 0, 0}}}, + SelectAndScatterTestParam{{{7, 2, 5, 3, 8}, {3, 8, 9, 3, 4}}, + {{2, 6, 4}}, + {{0, 0}, {0, 0}}, + {2, 3}, + {1, 1}, + {{0, 0, 0, 0, 0}, {0, 0, 12, 0, 0}}}, + SelectAndScatterTestParam{ + {{1.5, 2.5, 1.5}, {3.5, 1.5, 3.5}, {4.5, 2.5, 4.5}}, + {{1.0, 2.0}, {3.0, 4.0}}, + {{0, 0}, {0, 0}}, + {2, 2}, + {1, 1}, + {{0.0, 0.0, 0.0}, {1.0, 0.0, 2.0}, {3.0, 0.0, 4.0}}})); + +} // namespace spu::kernel::hlo diff --git a/libspu/kernel/hlo/utils.cc b/libspu/kernel/hlo/utils.cc index 369efb16..0e2bba61 100644 --- a/libspu/kernel/hlo/utils.cc +++ b/libspu/kernel/hlo/utils.cc @@ -45,28 +45,12 @@ xt::xarray getIndices(SPUContext *ctx, const spu::Value &value) { spu::Value expandWindow(SPUContext *ctx, const spu::Value &base, const Shape &window_shape, - const Strides &window_strides, - absl::Span> padding) { + const Strides &window_strides) { const size_t ndim = base.shape().size(); // sanity check. SPU_ENFORCE(ndim == window_shape.size()); SPU_ENFORCE(ndim == window_strides.size()); - SPU_ENFORCE(ndim == padding.size()); - - // pad the input. 
- Value padded; - { - Sizes padding_lo(ndim); - Sizes padding_hi(ndim); - Sizes padding_in(ndim, 0); // no dilation - for (size_t idx = 0; idx < padding.size(); idx++) { - padding_lo[idx] = padding[idx].first; - padding_hi[idx] = padding[idx].second; - } - padded = hal::pad(ctx, base, hal::constant(ctx, 0, base.dtype(), {}), - padding_lo, padding_hi, padding_in); - } // let base = (B0, B1, ..., Bn) // window = (W0, W1, ..., Wn) @@ -78,8 +62,6 @@ spu::Value expandWindow(SPUContext *ctx, const spu::Value &base, const Strides &S = window_strides; Shape N(ndim); for (size_t dim = 0; dim < ndim; dim++) { - SPU_ENFORCE_EQ((B[dim] - W[dim]) % S[dim], 0, - "window is not aligned, B={}, W={}, S={}", B, W, S); N[dim] = (B[dim] - W[dim]) / S[dim] + 1; } @@ -95,7 +77,7 @@ spu::Value expandWindow(SPUContext *ctx, const spu::Value &base, start[dim] = window_index[dim] * S[dim]; end[dim] = start[dim] + W[dim]; } - auto window = hal::slice(ctx, padded, start, end, {}); + auto window = hal::slice(ctx, base, start, end, {}); Shape new_shape = window.shape(); new_shape.insert(new_shape.begin(), 1); @@ -114,4 +96,31 @@ spu::Value expandWindow(SPUContext *ctx, const spu::Value &base, return hal::reshape(ctx, res, res_shape); } +spu::Value expandWindow(SPUContext *ctx, const spu::Value &base, + const Shape &window_shape, + const Strides &window_strides, + absl::Span> padding, + const spu::Value &init_val) { + // sanity check. + const size_t ndim = base.shape().size(); + SPU_ENFORCE(ndim == padding.size()); + + Sizes padding_lo(ndim); + Sizes padding_hi(ndim); + Sizes padding_in(ndim, 0); // no dilation + bool need_pad = false; + for (size_t idx = 0; idx < padding.size(); idx++) { + padding_lo[idx] = padding[idx].first; + padding_hi[idx] = padding[idx].second; + need_pad |= (padding[idx].first != 0 || padding[idx].second != 0); + } + if (need_pad) { + Value padded = + hal::pad(ctx, base, init_val, padding_lo, padding_hi, padding_in); + return expandWindow(ctx, padded, window_shape, window_strides); + } + + return expandWindow(ctx, base, window_shape, window_strides); +} + } // namespace spu::kernel diff --git a/libspu/kernel/hlo/utils.h b/libspu/kernel/hlo/utils.h index e42c992d..b94436f0 100644 --- a/libspu/kernel/hlo/utils.h +++ b/libspu/kernel/hlo/utils.h @@ -141,6 +141,7 @@ inline void RunOnWindowIndex( spu::Value expandWindow(SPUContext *ctx, const spu::Value &base, const Shape &window_shape, const Strides &window_strides, - absl::Span> padding); + absl::Span> padding, + const spu::Value &init_val); } // namespace spu::kernel diff --git a/libspu/spu.proto b/libspu/spu.proto index 33a75fb5..36affd27 100644 --- a/libspu/spu.proto +++ b/libspu/spu.proto @@ -185,8 +185,8 @@ message RuntimeConfig { // When enabled, runtime dumps executed executables in the dump_dir, debug // purpose only. - bool enable_processor_dump = 13; - string processor_dump_dir = 14; + bool enable_runtime_snapshot = 13; + string snapshot_dump_dir = 14; // When enabled, runtime records detailed pphlo timing data, debug purpose // only.
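Note on the renamed runtime dump options in spu.proto above: the sketch below shows how the new fields might be set from C++. It assumes the standard protobuf-generated setters, the conventional "libspu/spu.pb.h" header path, and the spu package namespace; the helper name and dump path are illustrative only, not part of the patch.

#include "libspu/spu.pb.h"

// Turn on the runtime snapshot dump (renamed in this patch from
// enable_processor_dump / processor_dump_dir).
void EnableRuntimeSnapshot(spu::RuntimeConfig* config) {
  config->set_enable_runtime_snapshot(true);
  config->set_snapshot_dump_dir("/tmp/spu_snapshot");  // illustrative path
}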