From 338ff3d0935a8ee064769194c9eacce1be7eec2e Mon Sep 17 00:00:00 2001 From: anakinxc <103552181+anakinxc@users.noreply.github.com> Date: Fri, 22 Dec 2023 11:29:53 +0800 Subject: [PATCH] Repo sync (#447) --- .bazelrc | 1 + bazel/repositories.bzl | 4 +-- examples/python/ml/flax_llama7b/README.md | 2 +- libspu/device/io.cc | 10 ++++-- libspu/device/io.h | 3 +- libspu/device/io_test.cc | 32 +++++++++++++++++++ libspu/mpc/aby3/io.cc | 4 +-- libspu/mpc/cheetah/arith/common.cc | 4 +-- libspu/mpc/cheetah/arith/matmat_prot_test.cc | 2 +- libspu/mpc/cheetah/ot/yacl/BUILD.bazel | 3 +- libspu/mpc/cheetah/ot/yacl/ferret.cc | 3 +- .../mpc/cheetah/rlwe/modswitch_helper_test.cc | 2 +- libspu/mpc/common/prg_state.cc | 10 +++--- libspu/mpc/semi2k/beaver/beaver_tfp.cc | 2 +- libspu/mpc/semi2k/beaver/beaver_ttp.cc | 2 +- libspu/mpc/spdz2k/beaver/BUILD.bazel | 2 ++ libspu/mpc/spdz2k/beaver/beaver_tfp.cc | 2 ++ libspu/mpc/spdz2k/commitment.cc | 2 +- libspu/mpc/spdz2k/ot/BUILD.bazel | 7 ++-- libspu/mpc/spdz2k/ot/kos_ote.cc | 5 +-- libspu/mpc/utils/ring_ops.cc | 2 +- setup.py | 6 +++- spu/BUILD.bazel | 23 +++++++++++++ spu/libpsi.cc | 16 ++++++++++ spu/psi.py | 24 ++++++++++++-- 25 files changed, 142 insertions(+), 31 deletions(-) diff --git a/.bazelrc b/.bazelrc index a1604334..118c3bd2 100644 --- a/.bazelrc +++ b/.bazelrc @@ -13,6 +13,7 @@ # limitations under the License. common --experimental_repo_remote_exec +common --experimental_cc_shared_library # Required by OpenXLA build --nocheck_visibility diff --git a/bazel/repositories.bzl b/bazel/repositories.bzl index a88465f8..074c2fdf 100644 --- a/bazel/repositories.bzl +++ b/bazel/repositories.bzl @@ -18,9 +18,9 @@ load("@bazel_tools//tools/build_defs/repo:utils.bzl", "maybe") SECRETFLOW_GIT = "https://github.com/secretflow" -YACL_COMMIT_ID = "2b7d8882c78f07bd9e78217b7f9ca13135781e65" +YACL_COMMIT_ID = "816ac40ead311507ade19d521e1069b065747b63" -LIBSPI_COMMIT_ID = "dbf452b9d87619b40b73746b7c9afefc1845975b" +LIBSPI_COMMIT_ID = "dbe7028f9fceecbaf944b99d908d56f6c07449fc" def spu_deps(): _rules_cuda() diff --git a/examples/python/ml/flax_llama7b/README.md b/examples/python/ml/flax_llama7b/README.md index ab1f5a90..fb4f0980 100644 --- a/examples/python/ml/flax_llama7b/README.md +++ b/examples/python/ml/flax_llama7b/README.md @@ -22,7 +22,6 @@ This example demonstrates how to use SPU to run secure inference on a pre-traine Since EasyLM have an issue,so we have to make a samll change to support the option "streaming=false". Open and edit "convert_hf_to_easylm.py", chang this: - ```python parser.add_argument("--streaming", action="store_true", default=True, help="whether is model weight saved stream format",) ``` @@ -35,6 +34,7 @@ This example demonstrates how to use SPU to run secure inference on a pre-traine Download trained LLaMA-B[PyTroch-Version] from [Hugging Face](https://huggingface.co/openlm-research/open_llama_7b) , and convert it to Flax.msgpack as: + ```sh python convert_hf_to_easylm.py \ --checkpoint_dir path-to-flax-llama7b-dir \ diff --git a/libspu/device/io.cc b/libspu/device/io.cc index 4a40667e..69d7674e 100644 --- a/libspu/device/io.cc +++ b/libspu/device/io.cc @@ -178,7 +178,13 @@ ColocatedIo::ColocatedIo(SPUContext *sctx) : sctx_(sctx) {} void ColocatedIo::hostSetVar(const std::string &name, const PtBufferView &bv, Visibility vtype) { - unsynced_[name] = {convertToNdArray(bv), vtype}; + if (vtype == VIS_PRIVATE) { + // handle SECRET/PRIVATE compiler/runtime trick. + unsynced_[name] = {convertToNdArray(bv), VIS_SECRET, + static_cast(sctx_->lctx()->Rank())}; + } else { + unsynced_[name] = {convertToNdArray(bv), vtype}; + } } NdArrayRef ColocatedIo::hostGetVar(const std::string &name) const { @@ -325,7 +331,7 @@ void ColocatedIo::sync() { PtBufferView bv(arr.data(), arr.eltype().as()->pt_type(), arr.shape(), arr.strides()); - auto shares = io.makeShares(bv, priv.vtype); + auto shares = io.makeShares(bv, priv.vtype, priv.owner_rank); SPU_ENFORCE(shares.size() == lctx->WorldSize()); for (size_t idx = 0; idx < shares.size(); idx++) { diff --git a/libspu/device/io.h b/libspu/device/io.h index b1caa898..5cab1985 100644 --- a/libspu/device/io.h +++ b/libspu/device/io.h @@ -128,7 +128,8 @@ class ColocatedIo { // un-synchronized data. struct PrivData { NdArrayRef arr; - Visibility vtype; + Visibility vtype{Visibility::VIS_INVALID}; + int owner_rank{-1}; }; std::map unsynced_; diff --git a/libspu/device/io_test.cc b/libspu/device/io_test.cc index 8d606da3..12fb2f64 100644 --- a/libspu/device/io_test.cc +++ b/libspu/device/io_test.cc @@ -131,6 +131,38 @@ TEST_P(ColocatedIoTest, Works) { }); } +TEST(ColocatedIoTest, PrivateWorks) { + const size_t kWorldSize = 2; + + RuntimeConfig hconf; + hconf.set_protocol(ProtocolKind::SEMI2K); + hconf.set_field(FieldType::FM64); + hconf.set_experimental_enable_colocated_optimization(true); + + mpc::utils::simulate(kWorldSize, [&](auto lctx) { + SPUContext sctx(hconf, lctx); + ColocatedIo cio(&sctx); + + // WHEN + if (lctx->Rank() == 0) { + cio.hostSetVar("x", xt::xarray{{1, -2, 3, 0}}, + Visibility::VIS_PRIVATE); + } else if (lctx->Rank() == 1) { + cio.hostSetVar("y", xt::xarray{{1, -2, 3, 0}}, + Visibility::VIS_PRIVATE); + } + cio.sync(); + + // THEN + EXPECT_TRUE(cio.deviceHasVar("x")); + auto x = cio.deviceGetVar("x"); + EXPECT_TRUE(x.isPrivate()) << x; + EXPECT_TRUE(cio.deviceHasVar("y")); + auto y = cio.deviceGetVar("y"); + EXPECT_TRUE(y.isPrivate()) << y; + }); +} + INSTANTIATE_TEST_SUITE_P( ColocatedIoTestInstance, ColocatedIoTest, testing::Combine( diff --git a/libspu/mpc/aby3/io.cc b/libspu/mpc/aby3/io.cc index 6fdf4af6..de5b5f86 100644 --- a/libspu/mpc/aby3/io.cc +++ b/libspu/mpc/aby3/io.cc @@ -111,8 +111,8 @@ std::vector Aby3Io::makeBitSecret(const PtBufferView& in) const { std::vector r0(numel); std::vector r1(numel); - yacl::crypto::PrgAesCtr(yacl::crypto::RandSeed(), absl::MakeSpan(r0)); - yacl::crypto::PrgAesCtr(yacl::crypto::RandSeed(), absl::MakeSpan(r1)); + yacl::crypto::PrgAesCtr(yacl::crypto::SecureRandSeed(), absl::MakeSpan(r0)); + yacl::crypto::PrgAesCtr(yacl::crypto::SecureRandSeed(), absl::MakeSpan(r1)); NdArrayView _s0(shares[0]); NdArrayView _s1(shares[1]); diff --git a/libspu/mpc/cheetah/arith/common.cc b/libspu/mpc/cheetah/arith/common.cc index aecbd2ab..92c52f43 100644 --- a/libspu/mpc/cheetah/arith/common.cc +++ b/libspu/mpc/cheetah/arith/common.cc @@ -22,7 +22,7 @@ namespace spu::mpc::cheetah { EnableCPRNG::EnableCPRNG() - : seed_(yacl::crypto::RandSeed(/*drbg*/ true)), prng_counter_(0) {} + : seed_(yacl::crypto::SecureRandSeed()), prng_counter_(0) {} // Uniform random on prime field void EnableCPRNG::UniformPrime(const seal::Modulus &prime, @@ -75,7 +75,7 @@ NdArrayRef EnableCPRNG::CPRNG(FieldType field, size_t size) { // Lock prng_counter_ std::scoped_lock guard(counter_lock_); if (prng_counter_ > kPRNG_THREASHOLD) { - seed_ = yacl::crypto::RandSeed(true); + seed_ = yacl::crypto::SecureRandSeed(); prng_counter_ = 0; } return ring_rand(field, {static_cast(size)}, seed_, &prng_counter_); diff --git a/libspu/mpc/cheetah/arith/matmat_prot_test.cc b/libspu/mpc/cheetah/arith/matmat_prot_test.cc index e70d6abe..fadb8cbd 100644 --- a/libspu/mpc/cheetah/arith/matmat_prot_test.cc +++ b/libspu/mpc/cheetah/arith/matmat_prot_test.cc @@ -75,7 +75,7 @@ class MatMatProtTest seal::KeyGenerator keygen(*context_); rlwe_sk_ = std::make_shared(keygen.secret_key()); - seed_ = yacl::crypto::RandSeed(); + seed_ = yacl::crypto::SecureRandSeed(); prng_counter_ = 0; } }; diff --git a/libspu/mpc/cheetah/ot/yacl/BUILD.bazel b/libspu/mpc/cheetah/ot/yacl/BUILD.bazel index 53957f43..a7181688 100644 --- a/libspu/mpc/cheetah/ot/yacl/BUILD.bazel +++ b/libspu/mpc/cheetah/ot/yacl/BUILD.bazel @@ -43,7 +43,8 @@ spu_cc_library( "@yacl//yacl/crypto/primitives/ot:base_ot", "@yacl//yacl/crypto/primitives/ot:ferret_ote", "@yacl//yacl/crypto/primitives/ot:iknp_ote", - "@yacl//yacl/crypto/tools:random_permutation", + "@yacl//yacl/crypto/tools:crhash", + "@yacl//yacl/crypto/tools:rp", "@yacl//yacl/crypto/utils:rand", "@yacl//yacl/link", ], diff --git a/libspu/mpc/cheetah/ot/yacl/ferret.cc b/libspu/mpc/cheetah/ot/yacl/ferret.cc index f8b0f030..b21a589d 100644 --- a/libspu/mpc/cheetah/ot/yacl/ferret.cc +++ b/libspu/mpc/cheetah/ot/yacl/ferret.cc @@ -18,7 +18,8 @@ #include "spdlog/spdlog.h" #include "yacl/base/buffer.h" -#include "yacl/crypto/tools/random_permutation.h" +#include "yacl/crypto/tools/crhash.h" +#include "yacl/crypto/tools/rp.h" #include "yacl/link/link.h" #include "libspu/mpc/cheetah/ot/ot_util.h" diff --git a/libspu/mpc/cheetah/rlwe/modswitch_helper_test.cc b/libspu/mpc/cheetah/rlwe/modswitch_helper_test.cc index 7ebc4e69..75cb918b 100644 --- a/libspu/mpc/cheetah/rlwe/modswitch_helper_test.cc +++ b/libspu/mpc/cheetah/rlwe/modswitch_helper_test.cc @@ -91,7 +91,7 @@ class RLWE2LWETest : public testing::TestWithParam { seal::KeyGenerator keygen(*context_); rlwe_sk_ = std::make_shared(keygen.secret_key()); - seed_ = yacl::crypto::RandSeed(); + seed_ = yacl::crypto::SecureRandSeed(); prng_counter_ = 0; } }; diff --git a/libspu/mpc/common/prg_state.cc b/libspu/mpc/common/prg_state.cc index c7f0458a..e470a6d3 100644 --- a/libspu/mpc/common/prg_state.cc +++ b/libspu/mpc/common/prg_state.cc @@ -24,7 +24,7 @@ namespace spu::mpc { PrgState::PrgState() { pub_seed_ = 0; - priv_seed_ = yacl::crypto::RandSeed(); + priv_seed_ = yacl::crypto::SecureRandSeed(); self_seed_ = 0; next_seed_ = 0; @@ -33,7 +33,7 @@ PrgState::PrgState() { PrgState::PrgState(const std::shared_ptr& lctx) { // synchronize public state. { - uint128_t self_pk = yacl::crypto::RandSeed(); + uint128_t self_pk = yacl::crypto::SecureRandSeed(); const auto all_buf = yacl::link::AllGather( lctx, yacl::SerializeUint128(self_pk), "Random::PK"); @@ -46,11 +46,11 @@ PrgState::PrgState(const std::shared_ptr& lctx) { } // init private state. - priv_seed_ = yacl::crypto::RandSeed(); + priv_seed_ = yacl::crypto::SecureRandSeed(); // init PRSS state. { - self_seed_ = yacl::crypto::RandSeed(); + self_seed_ = yacl::crypto::SecureRandSeed(); constexpr char kCommTag[] = "Random:PRSS"; @@ -68,7 +68,7 @@ std::unique_ptr PrgState::fork() { fillPubl(absl::MakeSpan(&new_prg->pub_seed_, 1)); - new_prg->priv_seed_ = yacl::crypto::RandSeed(); + new_prg->priv_seed_ = yacl::crypto::SecureRandSeed(); fillPrssPair(&new_prg->self_seed_, &new_prg->next_seed_, 1, PrgState::GenPrssCtrl::Both); diff --git a/libspu/mpc/semi2k/beaver/beaver_tfp.cc b/libspu/mpc/semi2k/beaver/beaver_tfp.cc index db15af63..32ae2133 100644 --- a/libspu/mpc/semi2k/beaver/beaver_tfp.cc +++ b/libspu/mpc/semi2k/beaver/beaver_tfp.cc @@ -28,7 +28,7 @@ namespace spu::mpc::semi2k { BeaverTfpUnsafe::BeaverTfpUnsafe(std::shared_ptr lctx) : lctx_(std::move(std::move(lctx))), - seed_(yacl::crypto::RandSeed(true)), + seed_(yacl::crypto::SecureRandSeed()), counter_(0) { auto buf = yacl::SerializeUint128(seed_); std::vector all_bufs = diff --git a/libspu/mpc/semi2k/beaver/beaver_ttp.cc b/libspu/mpc/semi2k/beaver/beaver_ttp.cc index dad4f6be..a849cebe 100644 --- a/libspu/mpc/semi2k/beaver/beaver_ttp.cc +++ b/libspu/mpc/semi2k/beaver/beaver_ttp.cc @@ -146,7 +146,7 @@ BeaverTtp::~BeaverTtp() { BeaverTtp::BeaverTtp(std::shared_ptr lctx, Options ops) : lctx_(std::move(std::move(lctx))), - seed_(yacl::crypto::RandSeed(true)), + seed_(yacl::crypto::SecureRandSeed()), counter_(0), options_(std::move(ops)), child_counter_(0) { diff --git a/libspu/mpc/spdz2k/beaver/BUILD.bazel b/libspu/mpc/spdz2k/beaver/BUILD.bazel index fcf0f258..66bedd1d 100644 --- a/libspu/mpc/spdz2k/beaver/BUILD.bazel +++ b/libspu/mpc/spdz2k/beaver/BUILD.bazel @@ -36,6 +36,8 @@ spu_cc_library( "//libspu/mpc/spdz2k:commitment", "//libspu/mpc/utils:ring_ops", "@com_github_microsoft_seal//:seal", + "@yacl//yacl/crypto/base/block_cipher:symmetric_crypto", + "@yacl//yacl/crypto/tools:prg", "@yacl//yacl/link", "@yacl//yacl/utils:parallel", ], diff --git a/libspu/mpc/spdz2k/beaver/beaver_tfp.cc b/libspu/mpc/spdz2k/beaver/beaver_tfp.cc index f794c2a5..db9995f5 100644 --- a/libspu/mpc/spdz2k/beaver/beaver_tfp.cc +++ b/libspu/mpc/spdz2k/beaver/beaver_tfp.cc @@ -17,6 +17,8 @@ #include #include +#include "yacl/crypto/base/block_cipher/symmetric_crypto.h" +#include "yacl/crypto/tools/prg.h" #include "yacl/crypto/utils/rand.h" #include "yacl/utils/serialize.h" diff --git a/libspu/mpc/spdz2k/commitment.cc b/libspu/mpc/spdz2k/commitment.cc index 065e1811..4fc58ccf 100644 --- a/libspu/mpc/spdz2k/commitment.cc +++ b/libspu/mpc/spdz2k/commitment.cc @@ -50,7 +50,7 @@ bool commit_and_open(const std::shared_ptr& lctx, std::vector* z_strs) { bool res = true; size_t send_player = lctx->Rank(); - uint128_t rs = yacl::crypto::RandSeed(); + uint128_t rs = yacl::crypto::SecureRandSeed(); std::string rs_str(reinterpret_cast(&rs), sizeof(rs)); // 1. commit and send auto cmt = commit(send_player, z_str, rs_str); diff --git a/libspu/mpc/spdz2k/ot/BUILD.bazel b/libspu/mpc/spdz2k/ot/BUILD.bazel index 50651d6f..cf3aa52c 100644 --- a/libspu/mpc/spdz2k/ot/BUILD.bazel +++ b/libspu/mpc/spdz2k/ot/BUILD.bazel @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -load("//bazel:spu.bzl", "spu_cc_library", "spu_cc_test") +load("//bazel:spu.bzl", "spu_cc_library") load("@yacl//bazel:yacl.bzl", "AES_COPT_FLAGS") package(default_visibility = ["//visibility:public"]) @@ -46,9 +46,10 @@ spu_cc_library( "@yacl//yacl/crypto/base/hash:hash_interface", "@yacl//yacl/crypto/base/hash:hash_utils", "@yacl//yacl/crypto/primitives/ot:base_ot", + "@yacl//yacl/crypto/tools:crhash", "@yacl//yacl/crypto/tools:prg", - "@yacl//yacl/crypto/tools:random_oracle", - "@yacl//yacl/crypto/tools:random_permutation", + "@yacl//yacl/crypto/tools:ro", + "@yacl//yacl/crypto/tools:rp", "@yacl//yacl/link", "@yacl//yacl/utils:matrix_utils", "@yacl//yacl/utils:serialize", diff --git a/libspu/mpc/spdz2k/ot/kos_ote.cc b/libspu/mpc/spdz2k/ot/kos_ote.cc index 85c57eac..5b92abe2 100644 --- a/libspu/mpc/spdz2k/ot/kos_ote.cc +++ b/libspu/mpc/spdz2k/ot/kos_ote.cc @@ -19,9 +19,10 @@ #include "emp-tool/utils/block.h" #include "emp-tool/utils/f2k.h" #include "yacl/crypto/base/hash/hash_utils.h" +#include "yacl/crypto/tools/crhash.h" #include "yacl/crypto/tools/prg.h" -#include "yacl/crypto/tools/random_oracle.h" -#include "yacl/crypto/tools/random_permutation.h" +#include "yacl/crypto/tools/ro.h" +#include "yacl/crypto/tools/rp.h" #include "yacl/link/link.h" #include "yacl/utils/matrix_utils.h" #include "yacl/utils/serialize.h" diff --git a/libspu/mpc/utils/ring_ops.cc b/libspu/mpc/utils/ring_ops.cc index 57b7f0dd..cc87f7c5 100644 --- a/libspu/mpc/utils/ring_ops.cc +++ b/libspu/mpc/utils/ring_ops.cc @@ -205,7 +205,7 @@ void ring_print(const NdArrayRef& x, std::string_view name) { NdArrayRef ring_rand(FieldType field, const Shape& shape) { uint64_t cnt = 0; - return ring_rand(field, shape, yacl::crypto::RandSeed(), &cnt); + return ring_rand(field, shape, yacl::crypto::SecureRandSeed(), &cnt); } NdArrayRef ring_rand(FieldType field, const Shape& shape, uint128_t prg_seed, diff --git a/setup.py b/setup.py index 1a18f8fb..14ccc88a 100644 --- a/setup.py +++ b/setup.py @@ -112,7 +112,9 @@ def get_packages(self): setup_spec.install_requires = read_requirements('requirements.txt') -files_to_remove = ["spu/intrinsic/add_new_intrinsic.py"] +files_to_remove = [ + "spu/intrinsic/add_new_intrinsic.py", +] # Calls Bazel in PATH @@ -207,6 +209,8 @@ def pip_run(build_ext): # Change __module__ in psi_pb2.py and pir_pb2.py fix_pb('bazel-bin/spu/psi_pb2.py', 'psi.psi.psi_pb2', 'spu.psi_pb2') + fix_pb('bazel-bin/spu/link_pb2.py', 'yacl.link.link_pb2', 'link.pir_pb2') + fix_pb('bazel-bin/spu/psi_v2_pb2.py', 'psi.proto.psi_v2_pb2', 'spu.psi_pb2') fix_pb('bazel-bin/spu/pir_pb2.py', 'psi.pir.pir_pb2', 'spu.pir_pb2') setup_spec.files_to_include += spu_lib_files diff --git a/spu/BUILD.bazel b/spu/BUILD.bazel index 6b543099..52d14b92 100644 --- a/spu/BUILD.bazel +++ b/spu/BUILD.bazel @@ -67,6 +67,7 @@ pybind_extension( ":version_script.lds", "@psi//psi/pir", "@psi//psi/psi:bucket_psi", + "@psi//psi/psi:factory", "@psi//psi/psi:memory_psi", "@yacl//yacl/link", ], @@ -90,11 +91,33 @@ python_proto_compile( protos = ["@psi//psi/psi:psi_proto"], ) +python_proto_compile( + name = "link_py_proto", + output_mode = "NO_PREFIX_FLAT", + protos = ["@yacl//yacl/link:link_proto"], +) + +python_proto_compile( + name = "psi_v2_py_proto", + output_mode = "NO_PREFIX", + protos = ["@psi//psi/proto:psi_v2_proto"], +) + +# Hack generated protobuf due to https://github.com/protocolbuffers/protobuf/issues/1491 +genrule( + name = "psi_v2_py_proto_fixed", + srcs = [":psi_v2_py_proto"], + outs = ["psi_v2_pb2.py"], + cmd = "sed 's#from yacl.link import#from . import#g;s#from psi.psi import#from . import#g' $(SRCS) > $(OUTS)", +) + py_library( name = "psi", srcs = [ "psi.py", + ":link_py_proto", ":psi_py_proto", + ":psi_v2_py_proto_fixed", ], data = [ ":libpsi.so", diff --git a/spu/libpsi.cc b/spu/libpsi.cc index 70c06f8e..4b8a9df7 100644 --- a/spu/libpsi.cc +++ b/spu/libpsi.cc @@ -22,6 +22,7 @@ #include "psi/pir/pir.h" #include "psi/psi/bucket_psi.h" +#include "psi/psi/factory.h" #include "psi/psi/memory_psi.h" #include "psi/psi/utils/progress.h" @@ -76,6 +77,21 @@ void BindLibs(py::module& m) { py::arg("callbacks_interval_ms") = 5 * 1000, py::arg("ic_mode") = false, "Run bucket psi. ic_mode means run in interconnection mode", NO_GIL); + m.def( + "psi_v2", + [](const std::string& config_pb, + const std::shared_ptr& lctx) -> py::bytes { + psi::v2::PsiConfig psi_config; + YACL_ENFORCE(psi_config.ParseFromString(config_pb)); + + std::unique_ptr psi_party = + psi::createPSIParty(psi_config, lctx); + psi::v2::PsiReport report = psi_party->Run(); + return report.SerializeAsString(); + }, + py::arg("psi_config"), py::arg("link_context") = nullptr, + "Run PSI with v2 API.", NO_GIL); + m.def( "pir_setup", [](const std::string& config_pb) -> py::bytes { diff --git a/spu/psi.py b/spu/psi.py index b8287955..f65ce3b4 100644 --- a/spu/psi.py +++ b/spu/psi.py @@ -16,9 +16,9 @@ from typing import List -from .libspu.link import Context # type: ignore -from .libpsi.libs import ProgressData from . import libpsi # type: ignore +from .libpsi.libs import ProgressData +from .libspu.link import Context # type: ignore from .psi_pb2 import ( # type: ignore BucketPsiConfig, CurveType, @@ -28,6 +28,7 @@ PsiResultReport, PsiType, ) +from .psi_v2_pb2 import PsiConfig, PsiReport def mem_psi( @@ -87,3 +88,22 @@ def gen_cache_for_2pc_ub_psi(config: BucketPsiConfig) -> PsiResultReport: report = PsiResultReport() report.ParseFromString(report_str) return report + + +def psi_v2( + config: BucketPsiConfig, + link: Context = None, +) -> PsiReport: + """ + Run PSI with v2 API. + :param config: psi config + :param link: the transport layer + :return: statistical results + """ + report_str = libpsi.libs.psi_v2( + config.SerializeToString(), + link, + ) + report = PsiReport() + report.ParseFromString(report_str) + return report