Skip to content

Commit

Permalink
Repo sync (#447)
Browse files Browse the repository at this point in the history
  • Loading branch information
anakinxc authored Dec 22, 2023
1 parent 76d2b8b commit 338ff3d
Show file tree
Hide file tree
Showing 25 changed files with 142 additions and 31 deletions.
1 change: 1 addition & 0 deletions .bazelrc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

common --experimental_repo_remote_exec
common --experimental_cc_shared_library

# Required by OpenXLA
build --nocheck_visibility
Expand Down
4 changes: 2 additions & 2 deletions bazel/repositories.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ load("@bazel_tools//tools/build_defs/repo:utils.bzl", "maybe")

SECRETFLOW_GIT = "https://github.com/secretflow"

YACL_COMMIT_ID = "2b7d8882c78f07bd9e78217b7f9ca13135781e65"
YACL_COMMIT_ID = "816ac40ead311507ade19d521e1069b065747b63"

LIBSPI_COMMIT_ID = "dbf452b9d87619b40b73746b7c9afefc1845975b"
LIBSPI_COMMIT_ID = "dbe7028f9fceecbaf944b99d908d56f6c07449fc"

def spu_deps():
_rules_cuda()
Expand Down
2 changes: 1 addition & 1 deletion examples/python/ml/flax_llama7b/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ This example demonstrates how to use SPU to run secure inference on a pre-traine

Since EasyLM have an issue,so we have to make a samll change to support the option "streaming=false".
Open and edit "convert_hf_to_easylm.py", chang this:

```python
parser.add_argument("--streaming", action="store_true", default=True, help="whether is model weight saved stream format",)
```
Expand All @@ -35,6 +34,7 @@ This example demonstrates how to use SPU to run secure inference on a pre-traine

Download trained LLaMA-B[PyTroch-Version] from [Hugging Face](https://huggingface.co/openlm-research/open_llama_7b)
, and convert it to Flax.msgpack as:

```sh
python convert_hf_to_easylm.py \
--checkpoint_dir path-to-flax-llama7b-dir \
Expand Down
10 changes: 8 additions & 2 deletions libspu/device/io.cc
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,13 @@ ColocatedIo::ColocatedIo(SPUContext *sctx) : sctx_(sctx) {}

void ColocatedIo::hostSetVar(const std::string &name, const PtBufferView &bv,
Visibility vtype) {
unsynced_[name] = {convertToNdArray(bv), vtype};
if (vtype == VIS_PRIVATE) {
// handle SECRET/PRIVATE compiler/runtime trick.
unsynced_[name] = {convertToNdArray(bv), VIS_SECRET,
static_cast<int>(sctx_->lctx()->Rank())};
} else {
unsynced_[name] = {convertToNdArray(bv), vtype};
}
}

NdArrayRef ColocatedIo::hostGetVar(const std::string &name) const {
Expand Down Expand Up @@ -325,7 +331,7 @@ void ColocatedIo::sync() {
PtBufferView bv(arr.data(), arr.eltype().as<PtTy>()->pt_type(), arr.shape(),
arr.strides());

auto shares = io.makeShares(bv, priv.vtype);
auto shares = io.makeShares(bv, priv.vtype, priv.owner_rank);
SPU_ENFORCE(shares.size() == lctx->WorldSize());

for (size_t idx = 0; idx < shares.size(); idx++) {
Expand Down
3 changes: 2 additions & 1 deletion libspu/device/io.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,8 @@ class ColocatedIo {
// un-synchronized data.
struct PrivData {
NdArrayRef arr;
Visibility vtype;
Visibility vtype{Visibility::VIS_INVALID};
int owner_rank{-1};
};
std::map<std::string, PrivData> unsynced_;

Expand Down
32 changes: 32 additions & 0 deletions libspu/device/io_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,38 @@ TEST_P(ColocatedIoTest, Works) {
});
}

TEST(ColocatedIoTest, PrivateWorks) {
const size_t kWorldSize = 2;

RuntimeConfig hconf;
hconf.set_protocol(ProtocolKind::SEMI2K);
hconf.set_field(FieldType::FM64);
hconf.set_experimental_enable_colocated_optimization(true);

mpc::utils::simulate(kWorldSize, [&](auto lctx) {
SPUContext sctx(hconf, lctx);
ColocatedIo cio(&sctx);

// WHEN
if (lctx->Rank() == 0) {
cio.hostSetVar("x", xt::xarray<int>{{1, -2, 3, 0}},
Visibility::VIS_PRIVATE);
} else if (lctx->Rank() == 1) {
cio.hostSetVar("y", xt::xarray<float>{{1, -2, 3, 0}},
Visibility::VIS_PRIVATE);
}
cio.sync();

// THEN
EXPECT_TRUE(cio.deviceHasVar("x"));
auto x = cio.deviceGetVar("x");
EXPECT_TRUE(x.isPrivate()) << x;
EXPECT_TRUE(cio.deviceHasVar("y"));
auto y = cio.deviceGetVar("y");
EXPECT_TRUE(y.isPrivate()) << y;
});
}

INSTANTIATE_TEST_SUITE_P(
ColocatedIoTestInstance, ColocatedIoTest,
testing::Combine(
Expand Down
4 changes: 2 additions & 2 deletions libspu/mpc/aby3/io.cc
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,8 @@ std::vector<NdArrayRef> Aby3Io::makeBitSecret(const PtBufferView& in) const {
std::vector<bshr_el_t> r0(numel);
std::vector<bshr_el_t> r1(numel);

yacl::crypto::PrgAesCtr(yacl::crypto::RandSeed(), absl::MakeSpan(r0));
yacl::crypto::PrgAesCtr(yacl::crypto::RandSeed(), absl::MakeSpan(r1));
yacl::crypto::PrgAesCtr(yacl::crypto::SecureRandSeed(), absl::MakeSpan(r0));
yacl::crypto::PrgAesCtr(yacl::crypto::SecureRandSeed(), absl::MakeSpan(r1));

NdArrayView<bshr_t> _s0(shares[0]);
NdArrayView<bshr_t> _s1(shares[1]);
Expand Down
4 changes: 2 additions & 2 deletions libspu/mpc/cheetah/arith/common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
namespace spu::mpc::cheetah {

EnableCPRNG::EnableCPRNG()
: seed_(yacl::crypto::RandSeed(/*drbg*/ true)), prng_counter_(0) {}
: seed_(yacl::crypto::SecureRandSeed()), prng_counter_(0) {}

// Uniform random on prime field
void EnableCPRNG::UniformPrime(const seal::Modulus &prime,
Expand Down Expand Up @@ -75,7 +75,7 @@ NdArrayRef EnableCPRNG::CPRNG(FieldType field, size_t size) {
// Lock prng_counter_
std::scoped_lock guard(counter_lock_);
if (prng_counter_ > kPRNG_THREASHOLD) {
seed_ = yacl::crypto::RandSeed(true);
seed_ = yacl::crypto::SecureRandSeed();
prng_counter_ = 0;
}
return ring_rand(field, {static_cast<int64_t>(size)}, seed_, &prng_counter_);
Expand Down
2 changes: 1 addition & 1 deletion libspu/mpc/cheetah/arith/matmat_prot_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class MatMatProtTest
seal::KeyGenerator keygen(*context_);
rlwe_sk_ = std::make_shared<RLWESecretKey>(keygen.secret_key());

seed_ = yacl::crypto::RandSeed();
seed_ = yacl::crypto::SecureRandSeed();
prng_counter_ = 0;
}
};
Expand Down
3 changes: 2 additions & 1 deletion libspu/mpc/cheetah/ot/yacl/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ spu_cc_library(
"@yacl//yacl/crypto/primitives/ot:base_ot",
"@yacl//yacl/crypto/primitives/ot:ferret_ote",
"@yacl//yacl/crypto/primitives/ot:iknp_ote",
"@yacl//yacl/crypto/tools:random_permutation",
"@yacl//yacl/crypto/tools:crhash",
"@yacl//yacl/crypto/tools:rp",
"@yacl//yacl/crypto/utils:rand",
"@yacl//yacl/link",
],
Expand Down
3 changes: 2 additions & 1 deletion libspu/mpc/cheetah/ot/yacl/ferret.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@

#include "spdlog/spdlog.h"
#include "yacl/base/buffer.h"
#include "yacl/crypto/tools/random_permutation.h"
#include "yacl/crypto/tools/crhash.h"
#include "yacl/crypto/tools/rp.h"
#include "yacl/link/link.h"

#include "libspu/mpc/cheetah/ot/ot_util.h"
Expand Down
2 changes: 1 addition & 1 deletion libspu/mpc/cheetah/rlwe/modswitch_helper_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ class RLWE2LWETest : public testing::TestWithParam<FieldType> {
seal::KeyGenerator keygen(*context_);
rlwe_sk_ = std::make_shared<RLWESecretKey>(keygen.secret_key());

seed_ = yacl::crypto::RandSeed();
seed_ = yacl::crypto::SecureRandSeed();
prng_counter_ = 0;
}
};
Expand Down
10 changes: 5 additions & 5 deletions libspu/mpc/common/prg_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ namespace spu::mpc {
PrgState::PrgState() {
pub_seed_ = 0;

priv_seed_ = yacl::crypto::RandSeed();
priv_seed_ = yacl::crypto::SecureRandSeed();

self_seed_ = 0;
next_seed_ = 0;
Expand All @@ -33,7 +33,7 @@ PrgState::PrgState() {
PrgState::PrgState(const std::shared_ptr<yacl::link::Context>& lctx) {
// synchronize public state.
{
uint128_t self_pk = yacl::crypto::RandSeed();
uint128_t self_pk = yacl::crypto::SecureRandSeed();

const auto all_buf = yacl::link::AllGather(
lctx, yacl::SerializeUint128(self_pk), "Random::PK");
Expand All @@ -46,11 +46,11 @@ PrgState::PrgState(const std::shared_ptr<yacl::link::Context>& lctx) {
}

// init private state.
priv_seed_ = yacl::crypto::RandSeed();
priv_seed_ = yacl::crypto::SecureRandSeed();

// init PRSS state.
{
self_seed_ = yacl::crypto::RandSeed();
self_seed_ = yacl::crypto::SecureRandSeed();

constexpr char kCommTag[] = "Random:PRSS";

Expand All @@ -68,7 +68,7 @@ std::unique_ptr<State> PrgState::fork() {

fillPubl(absl::MakeSpan(&new_prg->pub_seed_, 1));

new_prg->priv_seed_ = yacl::crypto::RandSeed();
new_prg->priv_seed_ = yacl::crypto::SecureRandSeed();

fillPrssPair(&new_prg->self_seed_, &new_prg->next_seed_, 1,
PrgState::GenPrssCtrl::Both);
Expand Down
2 changes: 1 addition & 1 deletion libspu/mpc/semi2k/beaver/beaver_tfp.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ namespace spu::mpc::semi2k {

BeaverTfpUnsafe::BeaverTfpUnsafe(std::shared_ptr<yacl::link::Context> lctx)
: lctx_(std::move(std::move(lctx))),
seed_(yacl::crypto::RandSeed(true)),
seed_(yacl::crypto::SecureRandSeed()),
counter_(0) {
auto buf = yacl::SerializeUint128(seed_);
std::vector<yacl::Buffer> all_bufs =
Expand Down
2 changes: 1 addition & 1 deletion libspu/mpc/semi2k/beaver/beaver_ttp.cc
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ BeaverTtp::~BeaverTtp() {

BeaverTtp::BeaverTtp(std::shared_ptr<yacl::link::Context> lctx, Options ops)
: lctx_(std::move(std::move(lctx))),
seed_(yacl::crypto::RandSeed(true)),
seed_(yacl::crypto::SecureRandSeed()),
counter_(0),
options_(std::move(ops)),
child_counter_(0) {
Expand Down
2 changes: 2 additions & 0 deletions libspu/mpc/spdz2k/beaver/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ spu_cc_library(
"//libspu/mpc/spdz2k:commitment",
"//libspu/mpc/utils:ring_ops",
"@com_github_microsoft_seal//:seal",
"@yacl//yacl/crypto/base/block_cipher:symmetric_crypto",
"@yacl//yacl/crypto/tools:prg",
"@yacl//yacl/link",
"@yacl//yacl/utils:parallel",
],
Expand Down
2 changes: 2 additions & 0 deletions libspu/mpc/spdz2k/beaver/beaver_tfp.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
#include <random>
#include <utility>

#include "yacl/crypto/base/block_cipher/symmetric_crypto.h"
#include "yacl/crypto/tools/prg.h"
#include "yacl/crypto/utils/rand.h"
#include "yacl/utils/serialize.h"

Expand Down
2 changes: 1 addition & 1 deletion libspu/mpc/spdz2k/commitment.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ bool commit_and_open(const std::shared_ptr<yacl::link::Context>& lctx,
std::vector<std::string>* z_strs) {
bool res = true;
size_t send_player = lctx->Rank();
uint128_t rs = yacl::crypto::RandSeed();
uint128_t rs = yacl::crypto::SecureRandSeed();
std::string rs_str(reinterpret_cast<char*>(&rs), sizeof(rs));
// 1. commit and send
auto cmt = commit(send_player, z_str, rs_str);
Expand Down
7 changes: 4 additions & 3 deletions libspu/mpc/spdz2k/ot/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

load("//bazel:spu.bzl", "spu_cc_library", "spu_cc_test")
load("//bazel:spu.bzl", "spu_cc_library")
load("@yacl//bazel:yacl.bzl", "AES_COPT_FLAGS")

package(default_visibility = ["//visibility:public"])
Expand Down Expand Up @@ -46,9 +46,10 @@ spu_cc_library(
"@yacl//yacl/crypto/base/hash:hash_interface",
"@yacl//yacl/crypto/base/hash:hash_utils",
"@yacl//yacl/crypto/primitives/ot:base_ot",
"@yacl//yacl/crypto/tools:crhash",
"@yacl//yacl/crypto/tools:prg",
"@yacl//yacl/crypto/tools:random_oracle",
"@yacl//yacl/crypto/tools:random_permutation",
"@yacl//yacl/crypto/tools:ro",
"@yacl//yacl/crypto/tools:rp",
"@yacl//yacl/link",
"@yacl//yacl/utils:matrix_utils",
"@yacl//yacl/utils:serialize",
Expand Down
5 changes: 3 additions & 2 deletions libspu/mpc/spdz2k/ot/kos_ote.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@
#include "emp-tool/utils/block.h"
#include "emp-tool/utils/f2k.h"
#include "yacl/crypto/base/hash/hash_utils.h"
#include "yacl/crypto/tools/crhash.h"
#include "yacl/crypto/tools/prg.h"
#include "yacl/crypto/tools/random_oracle.h"
#include "yacl/crypto/tools/random_permutation.h"
#include "yacl/crypto/tools/ro.h"
#include "yacl/crypto/tools/rp.h"
#include "yacl/link/link.h"
#include "yacl/utils/matrix_utils.h"
#include "yacl/utils/serialize.h"
Expand Down
2 changes: 1 addition & 1 deletion libspu/mpc/utils/ring_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ void ring_print(const NdArrayRef& x, std::string_view name) {
NdArrayRef ring_rand(FieldType field, const Shape& shape) {
uint64_t cnt = 0;
return ring_rand(field, shape, yacl::crypto::RandSeed(), &cnt);
return ring_rand(field, shape, yacl::crypto::SecureRandSeed(), &cnt);
}
NdArrayRef ring_rand(FieldType field, const Shape& shape, uint128_t prg_seed,
Expand Down
6 changes: 5 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,9 @@ def get_packages(self):

setup_spec.install_requires = read_requirements('requirements.txt')

files_to_remove = ["spu/intrinsic/add_new_intrinsic.py"]
files_to_remove = [
"spu/intrinsic/add_new_intrinsic.py",
]


# Calls Bazel in PATH
Expand Down Expand Up @@ -207,6 +209,8 @@ def pip_run(build_ext):

# Change __module__ in psi_pb2.py and pir_pb2.py
fix_pb('bazel-bin/spu/psi_pb2.py', 'psi.psi.psi_pb2', 'spu.psi_pb2')
fix_pb('bazel-bin/spu/link_pb2.py', 'yacl.link.link_pb2', 'link.pir_pb2')
fix_pb('bazel-bin/spu/psi_v2_pb2.py', 'psi.proto.psi_v2_pb2', 'spu.psi_pb2')
fix_pb('bazel-bin/spu/pir_pb2.py', 'psi.pir.pir_pb2', 'spu.pir_pb2')

setup_spec.files_to_include += spu_lib_files
Expand Down
23 changes: 23 additions & 0 deletions spu/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ pybind_extension(
":version_script.lds",
"@psi//psi/pir",
"@psi//psi/psi:bucket_psi",
"@psi//psi/psi:factory",
"@psi//psi/psi:memory_psi",
"@yacl//yacl/link",
],
Expand All @@ -90,11 +91,33 @@ python_proto_compile(
protos = ["@psi//psi/psi:psi_proto"],
)

python_proto_compile(
name = "link_py_proto",
output_mode = "NO_PREFIX_FLAT",
protos = ["@yacl//yacl/link:link_proto"],
)

python_proto_compile(
name = "psi_v2_py_proto",
output_mode = "NO_PREFIX",
protos = ["@psi//psi/proto:psi_v2_proto"],
)

# Hack generated protobuf due to https://github.com/protocolbuffers/protobuf/issues/1491
genrule(
name = "psi_v2_py_proto_fixed",
srcs = [":psi_v2_py_proto"],
outs = ["psi_v2_pb2.py"],
cmd = "sed 's#from yacl.link import#from . import#g;s#from psi.psi import#from . import#g' $(SRCS) > $(OUTS)",
)

py_library(
name = "psi",
srcs = [
"psi.py",
":link_py_proto",
":psi_py_proto",
":psi_v2_py_proto_fixed",
],
data = [
":libpsi.so",
Expand Down
16 changes: 16 additions & 0 deletions spu/libpsi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

#include "psi/pir/pir.h"
#include "psi/psi/bucket_psi.h"
#include "psi/psi/factory.h"
#include "psi/psi/memory_psi.h"
#include "psi/psi/utils/progress.h"

Expand Down Expand Up @@ -76,6 +77,21 @@ void BindLibs(py::module& m) {
py::arg("callbacks_interval_ms") = 5 * 1000, py::arg("ic_mode") = false,
"Run bucket psi. ic_mode means run in interconnection mode", NO_GIL);

m.def(
"psi_v2",
[](const std::string& config_pb,
const std::shared_ptr<yacl::link::Context>& lctx) -> py::bytes {
psi::v2::PsiConfig psi_config;
YACL_ENFORCE(psi_config.ParseFromString(config_pb));

std::unique_ptr<psi::AbstractPSIParty> psi_party =
psi::createPSIParty(psi_config, lctx);
psi::v2::PsiReport report = psi_party->Run();
return report.SerializeAsString();
},
py::arg("psi_config"), py::arg("link_context") = nullptr,
"Run PSI with v2 API.", NO_GIL);

m.def(
"pir_setup",
[](const std::string& config_pb) -> py::bytes {
Expand Down
Loading

0 comments on commit 338ff3d

Please sign in to comment.