diff --git a/libspu/psi/core/fnp04_mp_psi/BUILD.bazel b/libspu/psi/core/fnp04_mp_psi/BUILD.bazel index b013963c..982d6232 100644 --- a/libspu/psi/core/fnp04_mp_psi/BUILD.bazel +++ b/libspu/psi/core/fnp04_mp_psi/BUILD.bazel @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -load("//bazel:spu.bzl", "spu_cc_binary", "spu_cc_library", "spu_cc_test") +load("//bazel:spu.bzl", "spu_cc_library", "spu_cc_test") package(default_visibility = ["//visibility:public"]) @@ -21,11 +21,11 @@ spu_cc_library( srcs = ["fnp04_mp_psi.cc"], hdrs = ["fnp04_mp_psi.h"], deps = [ - ":serializable_cc_proto", "//libspu/core:prelude", "@yacl//yacl/link", "@heu//heu/library/algorithms/paillier_zahlen", "//libspu/psi/utils:test_utils", + "//libspu/psi/utils", ], ) @@ -33,14 +33,4 @@ spu_cc_test( name = "fnp04_mp_psi_test", srcs = ["fnp04_mp_psi_test.cc"], deps = [":fnp04_mp_psi"], -) - -proto_library( - name = "serializable_proto", - srcs = ["serializable.proto"], -) - -cc_proto_library( - name = "serializable_cc_proto", - deps = [":serializable_proto"], ) \ No newline at end of file diff --git a/libspu/psi/core/fnp04_mp_psi/fnp04_mp_psi.cc b/libspu/psi/core/fnp04_mp_psi/fnp04_mp_psi.cc index 9b3605e0..807ec9cd 100644 --- a/libspu/psi/core/fnp04_mp_psi/fnp04_mp_psi.cc +++ b/libspu/psi/core/fnp04_mp_psi/fnp04_mp_psi.cc @@ -16,19 +16,18 @@ #include #include -#include #include #include #include -#include "yacl/link/link.h" +#include "yacl/utils/serialize.h" -#include "libspu/psi/core/fnp04_mp_psi/serializable.pb.h" +#include "libspu/core/prelude.h" +#include "libspu/psi/utils/utils.h" namespace { static std::random_device rd{}; -static std::default_random_engine engine{rd()}; static std::independent_bits_engine rng{rd()}; @@ -73,26 +72,16 @@ FNP04Party::FNP04Party(const Options& options) : options_{options} { std::vector FNP04Party::Run( const std::vector& inputs) { auto [ctx, wsize, me, leader] = CollectContext(); - // Step 0: Preprocessing inputs - auto count = inputs.size(); - auto counts = yacl::link::AllGather( - options_.link_ctx, Serialize(count), - fmt::format("{} send item size", options_.link_ctx->Rank())); - for (const auto& buf : counts) { - if (auto cnt = Deserialize(buf); cnt == 0) { + auto counts = AllGatherItemsSize(ctx, inputs.size()); + size_t count{}; + for (auto cnt : counts) { + if (cnt == 0) { return {}; - } else { - count = std::max(count, cnt); } + count = std::max(cnt, count); } - std::vector items; - std::transform(inputs.begin(), inputs.end(), std::back_inserter(items), - [](std::string_view input) { - return std::hash{}(input); - }); - // Add random dummy elements - std::generate_n(std::back_inserter(items), count - inputs.size(), - std::ref(rng)); + // Step 0: Encode the inputs + auto items = EncodeInputs(inputs, count); // Step 1: Broadcast the public key BroadcastPubKey(); // Step 2: Zero sharing @@ -102,20 +91,19 @@ std::vector FNP04Party::Run( shares[i][(leader + 1) % wsize] ^= items[i]; } } - Share aggregate(count); if (leader == me) { // Step 3: Receive the encrypted set auto hashings = RecvEncryptedSet(count); - // Step 4: Swap encrypted shares + // Step 4: Swap the encrypted shares SwapShares(shares, items, hashings); } else { // Step 3: Send the encrypted set SendEncryptedSet(items); // Step 4: Swap encrypted shares - auto recv_shares = SwapShares(shares); - // Step 5: Aggregate share - aggregate = AggregateShare(recv_shares); + shares = SwapShares(shares); } + // Step 5: Aggregate share + auto aggregate = AggregateShare(shares); // Step 6: Get intersection auto intersection_items = GetIntersection(items, aggregate); std::vector intersection; @@ -128,6 +116,19 @@ std::vector FNP04Party::Run( return intersection; } +std::vector FNP04Party::EncodeInputs( + const std::vector& inputs, size_t count) const { + std::vector items; + std::transform(inputs.begin(), inputs.end(), std::back_inserter(items), + [](std::string_view input) { + return std::hash{}(input); + }); + // Add random dummy elements + std::generate_n(std::back_inserter(items), count - inputs.size(), + std::ref(rng)); + return items; +} + void FNP04Party::BroadcastPubKey() { auto [ctx, wsize, me, leader] = CollectContext(); // Publish @@ -164,7 +165,7 @@ void FNP04Party::SendEncryptedSet(const std::vector& items) const { [&](auto item) { hashing[item % B].emplace_back(item); }); // Hashing SecretPolynomial bins(B); - proto::BuffersProto proto; + std::vector buffers; for (auto& roots : hashing) { SPU_ENFORCE(roots.size() <= BinSize, "Severe hash collisions"); roots.resize(BinSize); @@ -186,14 +187,13 @@ void FNP04Party::SendEncryptedSet(const std::vector& items) const { roots[3] * roots[4]), -(roots[0] + roots[1] + roots[2] + roots[3] + roots[4])}; for (const auto& coeff : coeffs) { - auto buf = encryptors_[me]->Encrypt(Plaintext(coeff)).Serialize(); - proto.add_buffers(buf.data(), buf.size()); + buffers.emplace_back( + encryptors_[me]->Encrypt(Plaintext(coeff)).Serialize()); } } - yacl::Buffer buf(proto.ByteSizeLong()); - proto.SerializeToArray(buf.data(), buf.size()); - ctx->SendAsyncThrottled(leader, buf, - fmt::format("Party {} sends the encrypted set", me)); + ctx->SendAsyncThrottled( + leader, yacl::SerializeArrayOfBuffers({buffers.begin(), buffers.end()}), + fmt::format("Party {} sends the encrypted set", me)); } auto FNP04Party::RecvEncryptedSet(size_t count) const @@ -207,12 +207,11 @@ auto FNP04Party::RecvEncryptedSet(size_t count) const auto buf = ctx->Recv( src, fmt::format( "The leader receives the encrypted set from party {}", src)); - proto::BuffersProto proto; - proto.ParseFromArray(buf.data(), buf.size()); + auto buffers = yacl::DeserializeArrayOfBuffers(buf); size_t i{}; for (auto& bin : hashings[src]) { for (auto& coeff : bin) { - coeff.Deserialize(proto.buffers(i++)); + coeff.Deserialize(buffers[i++]); } } } @@ -240,16 +239,14 @@ auto FNP04Party::SwapShares(const std::vector& shares) const // Send for (size_t dst{}; dst != wsize; ++dst) { if (dst != me && dst != leader) { - proto::BuffersProto proto; + std::vector buffers; for (auto& share : shares) { - auto cipher = - encryptors_[dst]->Encrypt(Plaintext(share[dst])).Serialize(); - proto.add_buffers(cipher.data(), cipher.size()); + auto cipher = encryptors_[dst]->Encrypt(Plaintext(share[dst])); + buffers.emplace_back(cipher.Serialize()); } - yacl::Buffer buf(proto.ByteSizeLong()); - proto.SerializeToArray(buf.data(), buf.size()); ctx->SendAsyncThrottled( - dst, buf, fmt::format("Party {} sends secret shares to {}", me, dst)); + dst, yacl::SerializeArrayOfBuffers({buffers.begin(), buffers.end()}), + fmt::format("Party {} sends secret shares to {}", me, dst)); } } // Receive @@ -258,12 +255,11 @@ auto FNP04Party::SwapShares(const std::vector& shares) const if (src != me) { auto buf = ctx->Recv( src, fmt::format("Party {} receives secret shares from {}", me, src)); - proto::BuffersProto proto; - proto.ParseFromArray(buf.data(), buf.size()); + auto buffers = yacl::DeserializeArrayOfBuffers(buf); size_t i{}; for (auto& share : recv_shares) { Ciphertext cipher; - cipher.Deserialize(proto.buffers(i++)); + cipher.Deserialize(buffers[i++]); share[src] = ToUnsigned(decryptor_->Decrypt(cipher)); } } else { @@ -303,15 +299,12 @@ void FNP04Party::SwapShares( evaluator.MulInplace(&share[i], scale); evaluator.AddInplace(&share[i], bias); } - proto::BuffersProto proto; - for (const auto& s : share) { - auto buf = s.Serialize(); - proto.add_buffers(buf.data(), buf.size()); - } - yacl::Buffer buf(proto.ByteSizeLong()); - proto.SerializeToArray(buf.data(), buf.size()); + std::vector buffers(count); + std::transform(share.begin(), share.end(), buffers.begin(), + [&](const auto& s) { return s.Serialize(); }); ctx->SendAsyncThrottled( - dst, buf, fmt::format("The leader sends secret shares to {}", dst)); + dst, yacl::SerializeArrayOfBuffers({buffers.begin(), buffers.end()}), + fmt::format("The leader sends secret shares to {}", dst)); } } } @@ -320,9 +313,11 @@ auto FNP04Party::AggregateShare(const std::vector& shares) const -> Share { auto [ctx, wsize, me, leader] = CollectContext(); Share share(shares.size()); - for (size_t i{}; i != shares.size(); ++i) { - for (size_t src{}; src != wsize; ++src) { - share[i] ^= shares[i][src]; + if (me != leader) { + for (size_t i{}; i != shares.size(); ++i) { + for (size_t src{}; src != wsize; ++src) { + share[i] ^= shares[i][src]; + } } } return share; @@ -332,14 +327,11 @@ std::vector FNP04Party::GetIntersection( const std::vector& items, const Share& share) const { auto [ctx, wsize, me, leader] = CollectContext(); if (me != leader) { - proto::SizesProto proto; - for (auto& s : share) { - proto.add_sizes(s); - } - yacl::Buffer buf(proto.ByteSizeLong()); - proto.SerializeToArray(buf.data(), buf.size()); + std::vector buffers(share.size()); + std::transform(share.begin(), share.end(), buffers.begin(), + [&](size_t s) { return Serialize(s); }); ctx->SendAsyncThrottled( - leader, buf, + leader, yacl::SerializeArrayOfBuffers({buffers.begin(), buffers.end()}), fmt::format("Party {} sends aggregated shares to the leader", me)); return {}; } @@ -349,11 +341,10 @@ std::vector FNP04Party::GetIntersection( auto buf = ctx->Recv( src, fmt::format("The leader receives aggregated shares from {}", src)); - proto::SizesProto proto; - proto.ParseFromArray(buf.data(), buf.size()); + auto buffers = yacl::DeserializeArrayOfBuffers(buf); size_t i{}; for (auto& item : universe) { - item ^= proto.sizes(i++); + item ^= Deserialize(buffers[i++]); } } } diff --git a/libspu/psi/core/fnp04_mp_psi/fnp04_mp_psi.h b/libspu/psi/core/fnp04_mp_psi/fnp04_mp_psi.h index 6b572622..68e90079 100644 --- a/libspu/psi/core/fnp04_mp_psi/fnp04_mp_psi.h +++ b/libspu/psi/core/fnp04_mp_psi/fnp04_mp_psi.h @@ -21,8 +21,6 @@ #include "heu/library/algorithms/paillier_zahlen/paillier.h" #include "yacl/link/link.h" -#include "libspu/core/prelude.h" - namespace spu::psi { // Efficient Private Matching and Set Intersection @@ -44,6 +42,8 @@ class FNP04Party { virtual std::vector Run(const std::vector& inputs); private: + std::vector EncodeInputs(const std::vector& inputs, + size_t count) const; void BroadcastPubKey(); void SendEncryptedSet(const std::vector& items) const; std::vector RecvEncryptedSet(size_t count) const; diff --git a/libspu/psi/core/fnp04_mp_psi/serializable.proto b/libspu/psi/core/fnp04_mp_psi/serializable.proto deleted file mode 100644 index 21ba4c77..00000000 --- a/libspu/psi/core/fnp04_mp_psi/serializable.proto +++ /dev/null @@ -1,27 +0,0 @@ -// -// Copyright 2023 zhangwfjh -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -syntax = "proto3"; - -package spu.psi.proto; - -message BuffersProto { - repeated bytes buffers = 1; -} - -message SizesProto { - repeated uint64 sizes = 1; -} \ No newline at end of file