Skip to content

Commit

Permalink
Repo sync (#377)
Browse files Browse the repository at this point in the history
  • Loading branch information
anakinxc authored Oct 27, 2023
1 parent d386d3f commit 759fb6e
Show file tree
Hide file tree
Showing 27 changed files with 3,265 additions and 58 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
- [Feature] Add radix sort support for SEMI2K
- [Feature] Experimental: ABY3 matmul CUDA support
- [Feature] Experimental: Private support under colocated mode
- [Feature] Add yacl ot support for Cheetah

## 20230906

Expand Down
2 changes: 1 addition & 1 deletion bazel/repositories.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository")

SECRETFLOW_GIT = "https://github.com/secretflow"

YACL_COMMIT_ID = "5418371c4335f4a64fbd0bdabb0efd94da2af808"
YACL_COMMIT_ID = "f933d7ff4caf0d9f7ea84cc3e9f51a9a6ee9eeca"

def spu_deps():
_rules_cuda()
Expand Down
78 changes: 36 additions & 42 deletions libspu/device/pphlo/pphlo_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -367,54 +367,48 @@ void execute(OpExecutor *, SPUContext *sctx, SymbolScope *sscope,
void execute(OpExecutor *, SPUContext *sctx, SymbolScope *sscope,
mlir::pphlo::DotGeneralOp &op, const ExecutionOptions &opts) {
auto dnum = op.getDotDimensionNumbers();
// Should in order
SPU_ENFORCE(dnum.getLhsBatchingDimensions().size() == 1 &&
dnum.getLhsContractingDimensions().size() == 1 &&
dnum.getLhsBatchingDimensions()[0] == 0 &&
dnum.getLhsContractingDimensions()[0] == 2,
"LHS dims is not in order");
SPU_ENFORCE(dnum.getRhsBatchingDimensions().size() == 1 &&
dnum.getRhsContractingDimensions().size() == 1 &&
dnum.getRhsBatchingDimensions()[0] == 0 &&
dnum.getRhsContractingDimensions()[0] == 1,
"RHS dims is not in order");

auto lhs = lookupValue(sscope, op.getLhs(), opts);
auto rhs = lookupValue(sscope, op.getRhs(), opts);
SPU_ENFORCE(lhs.shape()[0] == rhs.shape()[0], "Batch dim should equal");
int64_t num_batch = lhs.shape()[0];

std::vector<spu::Value> results(num_batch);
Index lhs_slice_begin(3, 0);
Index lhs_slice_end(lhs.shape().begin(), lhs.shape().end());
Index rhs_slice_begin(3, 0);
Index rhs_slice_end(rhs.shape().begin(), rhs.shape().end());
Strides strides(lhs.shape().size(), 1);

Shape lhs_slice_shape{lhs.shape()[1], lhs.shape()[2]};
Shape rhs_slice_shape{rhs.shape()[1], rhs.shape()[2]};
Shape ret_slice_shape{1, lhs.shape()[1], rhs.shape()[2]};

for (int64_t batch_idx = 0; batch_idx < num_batch; ++batch_idx) {
lhs_slice_begin[0] = batch_idx;
lhs_slice_end[0] = batch_idx + 1;
rhs_slice_begin[0] = batch_idx;
rhs_slice_end[0] = batch_idx + 1;
auto lhs_slice = kernel::hlo::Reshape(
sctx,
kernel::hlo::Slice(sctx, lhs, lhs_slice_begin, lhs_slice_end, strides),
lhs_slice_shape);
auto rhs_slice = kernel::hlo::Reshape(
sctx,
kernel::hlo::Slice(sctx, rhs, rhs_slice_begin, rhs_slice_end, strides),
rhs_slice_shape);
results[batch_idx] = kernel::hlo::Reshape(
sctx, kernel::hlo::Dot(sctx, lhs_slice, rhs_slice), ret_slice_shape);
SPU_ENFORCE(lhs.shape().ndim() == 3 && rhs.shape().ndim() == 3);

SPU_ENFORCE(dnum.getLhsContractingDimensions().size() == 1 &&
dnum.getRhsContractingDimensions().size() == 1);
if (dnum.getLhsBatchingDimensions().size() == 1) {
// LHS should be [b,m,k]
SPU_ENFORCE(dnum.getLhsBatchingDimensions()[0] == 0 &&
dnum.getLhsContractingDimensions()[0] == 2,
"LHS dims is not in order");
} else {
// LHS should be [b0, b1, k]
SPU_ENFORCE(dnum.getLhsBatchingDimensions().size() == 2 &&
dnum.getLhsContractingDimensions()[0] == 2,
"LHS dims is not in order");
// reshape to [b0xb1, 1, k]
lhs = kernel::hlo::Reshape(
sctx, lhs, {lhs.shape()[0] * lhs.shape()[1], 1, lhs.shape()[2]});
}

if (dnum.getRhsBatchingDimensions().size() == 1) {
// RHS should be [b,k,n]
SPU_ENFORCE(dnum.getRhsBatchingDimensions()[0] == 0 &&
dnum.getRhsContractingDimensions()[0] == 1,
"RHS dims is not in order");
} else {
// RHS should be [b0, b1, k]
SPU_ENFORCE(dnum.getRhsBatchingDimensions().size() == 2 &&
dnum.getRhsContractingDimensions()[0] == 2,
"LHS dims is not in order");
// reshape to [b0xb1, k, 1]
rhs = kernel::hlo::Reshape(
sctx, rhs, {rhs.shape()[0] * rhs.shape()[1], rhs.shape()[2], 1});
}

// Must be [b,m,k] * [b, k, n]
SPU_ENFORCE(lhs.shape()[0] == rhs.shape()[0], "Batch dim should equal");
auto ret_type = op.getResult().getType().dyn_cast<mlir::RankedTensorType>();
auto ret = kernel::hlo::Reshape(
sctx, kernel::hlo::Concatenate(sctx, results, 0), ret_type.getShape());
auto ret = kernel::hlo::Reshape(sctx, kernel::hlo::DotGeneral(sctx, lhs, rhs),
ret_type.getShape());

addValue(sscope, op.getResult(), std::move(ret), opts);
}
Expand Down
36 changes: 36 additions & 0 deletions libspu/kernel/hlo/basic_binary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "libspu/kernel/hal/complex.h"
#include "libspu/kernel/hal/constants.h"
#include "libspu/kernel/hal/polymorphic.h"
#include "libspu/kernel/hal/shape_ops.h"

namespace spu::kernel::hlo {

Expand Down Expand Up @@ -75,4 +76,39 @@ spu::Value Dot(SPUContext *ctx, const spu::Value &lhs, const spu::Value &rhs) {
return hal::matmul(ctx, lhs, rhs);
}

spu::Value DotGeneral(SPUContext *ctx, const spu::Value &lhs,
const spu::Value &rhs) {
int64_t num_batch = lhs.shape()[0];

std::vector<spu::Value> results(num_batch);
Index lhs_slice_begin(3, 0);
Index lhs_slice_end(lhs.shape().begin(), lhs.shape().end());
Index rhs_slice_begin(3, 0);
Index rhs_slice_end(rhs.shape().begin(), rhs.shape().end());
Strides strides(lhs.shape().size(), 1);

Shape lhs_slice_shape{lhs.shape()[1], lhs.shape()[2]};
Shape rhs_slice_shape{rhs.shape()[1], rhs.shape()[2]};
Shape ret_slice_shape{1, lhs.shape()[1], rhs.shape()[2]};

for (int64_t batch_idx = 0; batch_idx < num_batch; ++batch_idx) {
lhs_slice_begin[0] = batch_idx;
lhs_slice_end[0] = batch_idx + 1;
rhs_slice_begin[0] = batch_idx;
rhs_slice_end[0] = batch_idx + 1;
auto lhs_slice = kernel::hal::reshape(
ctx,
kernel::hal::slice(ctx, lhs, lhs_slice_begin, lhs_slice_end, strides),
lhs_slice_shape);
auto rhs_slice = kernel::hal::reshape(
ctx,
kernel::hal::slice(ctx, rhs, rhs_slice_begin, rhs_slice_end, strides),
rhs_slice_shape);
results[batch_idx] = kernel::hal::reshape(
ctx, kernel::hal::matmul(ctx, lhs_slice, rhs_slice), ret_slice_shape);
}

return kernel::hal::concatenate(ctx, results, 0);
}

} // namespace spu::kernel::hlo
1 change: 1 addition & 0 deletions libspu/kernel/hlo/basic_binary.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ SIMPLE_BINARY_KERNEL_DECL(Div)
SIMPLE_BINARY_KERNEL_DECL(Remainder)
SIMPLE_BINARY_KERNEL_DECL(Dot)
SIMPLE_BINARY_KERNEL_DECL(Complex)
SIMPLE_BINARY_KERNEL_DECL(DotGeneral)

#undef SIMPLE_BINARY_KERNEL_DECL

Expand Down
2 changes: 1 addition & 1 deletion libspu/mpc/cheetah/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ spu_cc_library(
deps = [
"//libspu/mpc/cheetah/arith:cheetah_arith",
"//libspu/mpc/cheetah/nonlinear:cheetah_nonlinear",
"//libspu/mpc/cheetah/ot:cheetah_ot",
"//libspu/mpc/cheetah/rlwe:cheetah_rlwe",
"//libspu/mpc/cheetah/yacl_ot:yacl_ferret_ot",
],
)

Expand Down
4 changes: 2 additions & 2 deletions libspu/mpc/cheetah/nonlinear/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ spu_cc_library(
srcs = ["compare_prot.cc"],
hdrs = ["compare_prot.h"],
deps = [
"//libspu/mpc/cheetah/ot:cheetah_ot",
"//libspu/mpc/cheetah/yacl_ot:yacl_ferret_ot",
"@yacl//yacl/link",
],
)
Expand All @@ -40,7 +40,7 @@ spu_cc_library(
srcs = ["equal_prot.cc"],
hdrs = ["equal_prot.h"],
deps = [
"//libspu/mpc/cheetah/ot:cheetah_ot",
"//libspu/mpc/cheetah/yacl_ot:yacl_ferret_ot",
"@yacl//yacl/link",
],
)
Expand Down
6 changes: 3 additions & 3 deletions libspu/mpc/cheetah/nonlinear/compare_prot.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
#include "yacl/link/link.h"

#include "libspu/core/type.h"
#include "libspu/mpc/cheetah/ot/basic_ot_prot.h"
#include "libspu/mpc/cheetah/ot/ferret.h"
#include "libspu/mpc/cheetah/ot/util.h"
#include "libspu/mpc/cheetah/type.h"
#include "libspu/mpc/cheetah/yacl_ot/basic_ot_prot.h"
#include "libspu/mpc/cheetah/yacl_ot/util.h"
#include "libspu/mpc/cheetah/yacl_ot/yacl_ferret.h"
#include "libspu/mpc/common/communicator.h"
#include "libspu/mpc/utils/ring_ops.h"

Expand Down
2 changes: 1 addition & 1 deletion libspu/mpc/cheetah/nonlinear/compare_prot_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@

#include "gtest/gtest.h"

#include "libspu/mpc/cheetah/ot/basic_ot_prot.h"
#include "libspu/mpc/cheetah/type.h"
#include "libspu/mpc/cheetah/yacl_ot/basic_ot_prot.h"
#include "libspu/mpc/utils/ring_ops.h"
#include "libspu/mpc/utils/simulate.h"

Expand Down
6 changes: 3 additions & 3 deletions libspu/mpc/cheetah/nonlinear/equal_prot.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
#include "yacl/link/link.h"

#include "libspu/core/type.h"
#include "libspu/mpc/cheetah/ot/basic_ot_prot.h"
#include "libspu/mpc/cheetah/ot/ferret.h"
#include "libspu/mpc/cheetah/ot/util.h"
#include "libspu/mpc/cheetah/type.h"
#include "libspu/mpc/cheetah/yacl_ot/basic_ot_prot.h"
#include "libspu/mpc/cheetah/yacl_ot/util.h"
#include "libspu/mpc/cheetah/yacl_ot/yacl_ferret.h"
#include "libspu/mpc/common/communicator.h"
#include "libspu/mpc/utils/ring_ops.h"

Expand Down
2 changes: 1 addition & 1 deletion libspu/mpc/cheetah/nonlinear/equal_prot_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
#include "gtest/gtest.h"

#include "libspu/core/xt_helper.h"
#include "libspu/mpc/cheetah/ot/basic_ot_prot.h"
#include "libspu/mpc/cheetah/type.h"
#include "libspu/mpc/cheetah/yacl_ot/basic_ot_prot.h"
#include "libspu/mpc/utils/ring_ops.h"
#include "libspu/mpc/utils/simulate.h"

Expand Down
4 changes: 2 additions & 2 deletions libspu/mpc/cheetah/nonlinear/truncate_prot.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@

#include "libspu/core/type.h"
#include "libspu/mpc/cheetah/nonlinear/compare_prot.h"
#include "libspu/mpc/cheetah/ot/basic_ot_prot.h"
#include "libspu/mpc/cheetah/ot/util.h"
#include "libspu/mpc/cheetah/type.h"
#include "libspu/mpc/cheetah/yacl_ot/basic_ot_prot.h"
#include "libspu/mpc/cheetah/yacl_ot/util.h"
#include "libspu/mpc/utils/ring_ops.h"

namespace spu::mpc::cheetah {
Expand Down
2 changes: 1 addition & 1 deletion libspu/mpc/cheetah/nonlinear/truncate_prot_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@

#include "gtest/gtest.h"

#include "libspu/mpc/cheetah/ot/basic_ot_prot.h"
#include "libspu/mpc/cheetah/type.h"
#include "libspu/mpc/cheetah/yacl_ot/basic_ot_prot.h"
#include "libspu/mpc/utils/ring_ops.h"
#include "libspu/mpc/utils/simulate.h"

Expand Down
2 changes: 1 addition & 1 deletion libspu/mpc/cheetah/state.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
#include "libspu/core/object.h"
#include "libspu/mpc/cheetah/arith/cheetah_dot.h"
#include "libspu/mpc/cheetah/arith/cheetah_mul.h"
#include "libspu/mpc/cheetah/ot/basic_ot_prot.h"
#include "libspu/mpc/cheetah/yacl_ot/basic_ot_prot.h"

namespace spu::mpc::cheetah {

Expand Down
82 changes: 82 additions & 0 deletions libspu/mpc/cheetah/yacl_ot/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Copyright 2023 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

load("//bazel:spu.bzl", "spu_cc_library", "spu_cc_test")
load("@yacl//bazel:yacl.bzl", "AES_COPT_FLAGS")

package(default_visibility = ["//visibility:public"])

spu_cc_test(
name = "yacl_ferret_test",
srcs = ["yacl_ferret_test.cc"],
deps = [
":yacl_ferret_ot",
"//libspu/mpc/utils:simulate",
],
)

spu_cc_library(
name = "yacl_ferret_ot",
srcs = [
"basic_ot_prot.cc",
"util.cc",
"yacl_ferret.cc",
"yacl_ote_adapter.cc",
],
hdrs = [
"basic_ot_prot.h",
"mitccrh_exp.h",
"util.h",
"yacl_ferret.h",
"yacl_ote_adapter.h",
],
copts = AES_COPT_FLAGS + ["-Wno-ignored-attributes"],
deps = [
"//libspu/core:xt_helper",
"//libspu/mpc/cheetah:type",
"//libspu/mpc/common:communicator",
"//libspu/mpc/semi2k:conversion",
"@com_github_emptoolkit_emp_tool//:emp-tool",
"@yacl//yacl/base:dynamic_bitset",
"@yacl//yacl/base:int128",
"@yacl//yacl/crypto/base/aes:aes_opt",
"@yacl//yacl/crypto/primitives/ot:base_ot",
"@yacl//yacl/crypto/primitives/ot:ferret_ote",
"@yacl//yacl/crypto/primitives/ot:iknp_ote",
"@yacl//yacl/crypto/tools:random_permutation",
"@yacl//yacl/crypto/utils:rand",
"@yacl//yacl/link",
],
)

spu_cc_test(
name = "basic_ot_prot_test",
size = "large",
srcs = ["basic_ot_prot_test.cc"],
tags = [
"exclusive-if-local",
],
deps = [
":yacl_ferret_ot",
"//libspu/mpc/utils:simulate",
],
)

spu_cc_test(
name = "util_test",
srcs = ["util_test.cc"],
deps = [
":yacl_ferret_ot",
],
)
Loading

0 comments on commit 759fb6e

Please sign in to comment.