Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangwfjh committed Oct 30, 2023
1 parent f677768 commit e503426
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 108 deletions.
14 changes: 2 additions & 12 deletions libspu/psi/core/fnp04_mp_psi/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

load("//bazel:spu.bzl", "spu_cc_binary", "spu_cc_library", "spu_cc_test")
load("//bazel:spu.bzl", "spu_cc_library", "spu_cc_test")

package(default_visibility = ["//visibility:public"])

Expand All @@ -21,26 +21,16 @@ spu_cc_library(
srcs = ["fnp04_mp_psi.cc"],
hdrs = ["fnp04_mp_psi.h"],
deps = [
":serializable_cc_proto",
"//libspu/core:prelude",
"@yacl//yacl/link",
"@heu//heu/library/algorithms/paillier_zahlen",
"//libspu/psi/utils:test_utils",
"//libspu/psi/utils",
],
)

spu_cc_test(
name = "fnp04_mp_psi_test",
srcs = ["fnp04_mp_psi_test.cc"],
deps = [":fnp04_mp_psi"],
)

proto_library(
name = "serializable_proto",
srcs = ["serializable.proto"],
)

cc_proto_library(
name = "serializable_cc_proto",
deps = [":serializable_proto"],
)
125 changes: 58 additions & 67 deletions libspu/psi/core/fnp04_mp_psi/fnp04_mp_psi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,18 @@

#include <algorithm>
#include <iterator>
#include <limits>
#include <numeric>
#include <random>
#include <utility>

#include "yacl/link/link.h"
#include "yacl/utils/serialize.h"

#include "libspu/psi/core/fnp04_mp_psi/serializable.pb.h"
#include "libspu/core/prelude.h"
#include "libspu/psi/utils/utils.h"

namespace {

static std::random_device rd{};
static std::default_random_engine engine{rd()};
static std::independent_bits_engine<std::default_random_engine,
sizeof(size_t) * 8, size_t>
rng{rd()};
Expand Down Expand Up @@ -73,26 +72,16 @@ FNP04Party::FNP04Party(const Options& options) : options_{options} {
std::vector<std::string> FNP04Party::Run(
const std::vector<std::string>& inputs) {
auto [ctx, wsize, me, leader] = CollectContext();
// Step 0: Preprocessing inputs
auto count = inputs.size();
auto counts = yacl::link::AllGather(
options_.link_ctx, Serialize(count),
fmt::format("{} send item size", options_.link_ctx->Rank()));
for (const auto& buf : counts) {
if (auto cnt = Deserialize(buf); cnt == 0) {
auto counts = AllGatherItemsSize(ctx, inputs.size());
size_t count{};
for (auto cnt : counts) {
if (cnt == 0) {
return {};
} else {
count = std::max(count, cnt);
}
count = std::max(cnt, count);
}
std::vector<size_t> items;
std::transform(inputs.begin(), inputs.end(), std::back_inserter(items),
[](std::string_view input) {
return std::hash<std::string_view>{}(input);
});
// Add random dummy elements
std::generate_n(std::back_inserter(items), count - inputs.size(),
std::ref(rng));
// Step 0: Encode the inputs
auto items = EncodeInputs(inputs, count);
// Step 1: Broadcast the public key
BroadcastPubKey();
// Step 2: Zero sharing
Expand All @@ -102,20 +91,19 @@ std::vector<std::string> FNP04Party::Run(
shares[i][(leader + 1) % wsize] ^= items[i];
}
}
Share aggregate(count);
if (leader == me) {
// Step 3: Receive the encrypted set
auto hashings = RecvEncryptedSet(count);
// Step 4: Swap encrypted shares
// Step 4: Swap the encrypted shares
SwapShares(shares, items, hashings);
} else {
// Step 3: Send the encrypted set
SendEncryptedSet(items);
// Step 4: Swap encrypted shares
auto recv_shares = SwapShares(shares);
// Step 5: Aggregate share
aggregate = AggregateShare(recv_shares);
shares = SwapShares(shares);
}
// Step 5: Aggregate share
auto aggregate = AggregateShare(shares);
// Step 6: Get intersection
auto intersection_items = GetIntersection(items, aggregate);
std::vector<std::string> intersection;
Expand All @@ -128,6 +116,19 @@ std::vector<std::string> FNP04Party::Run(
return intersection;
}

std::vector<size_t> FNP04Party::EncodeInputs(
const std::vector<std::string>& inputs, size_t count) const {
std::vector<size_t> items;
std::transform(inputs.begin(), inputs.end(), std::back_inserter(items),
[](std::string_view input) {
return std::hash<std::string_view>{}(input);
});
// Add random dummy elements
std::generate_n(std::back_inserter(items), count - inputs.size(),
std::ref(rng));
return items;
}

void FNP04Party::BroadcastPubKey() {
auto [ctx, wsize, me, leader] = CollectContext();
// Publish
Expand Down Expand Up @@ -164,7 +165,7 @@ void FNP04Party::SendEncryptedSet(const std::vector<size_t>& items) const {
[&](auto item) { hashing[item % B].emplace_back(item); });
// Hashing
SecretPolynomial bins(B);
proto::BuffersProto proto;
std::vector<yacl::Buffer> buffers;
for (auto& roots : hashing) {
SPU_ENFORCE(roots.size() <= BinSize, "Severe hash collisions");
roots.resize(BinSize);
Expand All @@ -186,14 +187,13 @@ void FNP04Party::SendEncryptedSet(const std::vector<size_t>& items) const {
roots[3] * roots[4]),
-(roots[0] + roots[1] + roots[2] + roots[3] + roots[4])};
for (const auto& coeff : coeffs) {
auto buf = encryptors_[me]->Encrypt(Plaintext(coeff)).Serialize();
proto.add_buffers(buf.data(), buf.size());
buffers.emplace_back(
encryptors_[me]->Encrypt(Plaintext(coeff)).Serialize());
}
}
yacl::Buffer buf(proto.ByteSizeLong());
proto.SerializeToArray(buf.data(), buf.size());
ctx->SendAsyncThrottled(leader, buf,
fmt::format("Party {} sends the encrypted set", me));
ctx->SendAsyncThrottled(
leader, yacl::SerializeArrayOfBuffers({buffers.begin(), buffers.end()}),
fmt::format("Party {} sends the encrypted set", me));
}

auto FNP04Party::RecvEncryptedSet(size_t count) const
Expand All @@ -207,12 +207,11 @@ auto FNP04Party::RecvEncryptedSet(size_t count) const
auto buf = ctx->Recv(
src, fmt::format(
"The leader receives the encrypted set from party {}", src));
proto::BuffersProto proto;
proto.ParseFromArray(buf.data(), buf.size());
auto buffers = yacl::DeserializeArrayOfBuffers(buf);
size_t i{};
for (auto& bin : hashings[src]) {
for (auto& coeff : bin) {
coeff.Deserialize(proto.buffers(i++));
coeff.Deserialize(buffers[i++]);
}
}
}
Expand Down Expand Up @@ -240,16 +239,14 @@ auto FNP04Party::SwapShares(const std::vector<Share>& shares) const
// Send
for (size_t dst{}; dst != wsize; ++dst) {
if (dst != me && dst != leader) {
proto::BuffersProto proto;
std::vector<yacl::Buffer> buffers;
for (auto& share : shares) {
auto cipher =
encryptors_[dst]->Encrypt(Plaintext(share[dst])).Serialize();
proto.add_buffers(cipher.data(), cipher.size());
auto cipher = encryptors_[dst]->Encrypt(Plaintext(share[dst]));
buffers.emplace_back(cipher.Serialize());
}
yacl::Buffer buf(proto.ByteSizeLong());
proto.SerializeToArray(buf.data(), buf.size());
ctx->SendAsyncThrottled(
dst, buf, fmt::format("Party {} sends secret shares to {}", me, dst));
dst, yacl::SerializeArrayOfBuffers({buffers.begin(), buffers.end()}),
fmt::format("Party {} sends secret shares to {}", me, dst));
}
}
// Receive
Expand All @@ -258,12 +255,11 @@ auto FNP04Party::SwapShares(const std::vector<Share>& shares) const
if (src != me) {
auto buf = ctx->Recv(
src, fmt::format("Party {} receives secret shares from {}", me, src));
proto::BuffersProto proto;
proto.ParseFromArray(buf.data(), buf.size());
auto buffers = yacl::DeserializeArrayOfBuffers(buf);
size_t i{};
for (auto& share : recv_shares) {
Ciphertext cipher;
cipher.Deserialize(proto.buffers(i++));
cipher.Deserialize(buffers[i++]);
share[src] = ToUnsigned(decryptor_->Decrypt(cipher));
}
} else {
Expand Down Expand Up @@ -303,15 +299,12 @@ void FNP04Party::SwapShares(
evaluator.MulInplace(&share[i], scale);
evaluator.AddInplace(&share[i], bias);
}
proto::BuffersProto proto;
for (const auto& s : share) {
auto buf = s.Serialize();
proto.add_buffers(buf.data(), buf.size());
}
yacl::Buffer buf(proto.ByteSizeLong());
proto.SerializeToArray(buf.data(), buf.size());
std::vector<yacl::Buffer> buffers(count);
std::transform(share.begin(), share.end(), buffers.begin(),
[&](const auto& s) { return s.Serialize(); });
ctx->SendAsyncThrottled(
dst, buf, fmt::format("The leader sends secret shares to {}", dst));
dst, yacl::SerializeArrayOfBuffers({buffers.begin(), buffers.end()}),
fmt::format("The leader sends secret shares to {}", dst));
}
}
}
Expand All @@ -320,9 +313,11 @@ auto FNP04Party::AggregateShare(const std::vector<Share>& shares) const
-> Share {
auto [ctx, wsize, me, leader] = CollectContext();
Share share(shares.size());
for (size_t i{}; i != shares.size(); ++i) {
for (size_t src{}; src != wsize; ++src) {
share[i] ^= shares[i][src];
if (me != leader) {
for (size_t i{}; i != shares.size(); ++i) {
for (size_t src{}; src != wsize; ++src) {
share[i] ^= shares[i][src];
}
}
}
return share;
Expand All @@ -332,14 +327,11 @@ std::vector<size_t> FNP04Party::GetIntersection(
const std::vector<size_t>& items, const Share& share) const {
auto [ctx, wsize, me, leader] = CollectContext();
if (me != leader) {
proto::SizesProto proto;
for (auto& s : share) {
proto.add_sizes(s);
}
yacl::Buffer buf(proto.ByteSizeLong());
proto.SerializeToArray(buf.data(), buf.size());
std::vector<yacl::Buffer> buffers(share.size());
std::transform(share.begin(), share.end(), buffers.begin(),
[&](size_t s) { return Serialize(s); });
ctx->SendAsyncThrottled(
leader, buf,
leader, yacl::SerializeArrayOfBuffers({buffers.begin(), buffers.end()}),
fmt::format("Party {} sends aggregated shares to the leader", me));
return {};
}
Expand All @@ -349,11 +341,10 @@ std::vector<size_t> FNP04Party::GetIntersection(
auto buf = ctx->Recv(
src,
fmt::format("The leader receives aggregated shares from {}", src));
proto::SizesProto proto;
proto.ParseFromArray(buf.data(), buf.size());
auto buffers = yacl::DeserializeArrayOfBuffers(buf);
size_t i{};
for (auto& item : universe) {
item ^= proto.sizes(i++);
item ^= Deserialize(buffers[i++]);
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions libspu/psi/core/fnp04_mp_psi/fnp04_mp_psi.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@
#include "heu/library/algorithms/paillier_zahlen/paillier.h"
#include "yacl/link/link.h"

#include "libspu/core/prelude.h"

namespace spu::psi {

// Efficient Private Matching and Set Intersection
Expand All @@ -44,6 +42,8 @@ class FNP04Party {
virtual std::vector<std::string> Run(const std::vector<std::string>& inputs);

private:
std::vector<size_t> EncodeInputs(const std::vector<std::string>& inputs,
size_t count) const;
void BroadcastPubKey();
void SendEncryptedSet(const std::vector<size_t>& items) const;
std::vector<SecretPolynomial> RecvEncryptedSet(size_t count) const;
Expand Down
27 changes: 0 additions & 27 deletions libspu/psi/core/fnp04_mp_psi/serializable.proto

This file was deleted.

0 comments on commit e503426

Please sign in to comment.