diff --git a/examples/python/ml/flax_llama7b_split/README.md b/examples/python/ml/flax_llama7b_split/README.md index 950b9083..4a4d6a0d 100644 --- a/examples/python/ml/flax_llama7b_split/README.md +++ b/examples/python/ml/flax_llama7b_split/README.md @@ -5,11 +5,13 @@ This example demonstrates how to use SPU to run secure inference on a pre-traine 1. Motivation -- Time: -Using the full LlaMA model for inference on SPU can take a significant amount of time. If only a portion of the model is passed through SPU to ensure privacy, it can greatly improve inference efficiency. + - Time: + Using the full LlaMA model for inference on SPU can take a significant amount of time. + If only a portion of the model is passed through SPU to ensure privacy, it can greatly improve inference efficiency. -- RAM Usage: -Using the full LLaMA model for inference on SPU requires a large amount of memory(more than 256 GB). Splitting the model can significantly reduce memory usage, making it available for use in hardware-constrained environments. + - RAM Usage: + Using the full LLaMA model for inference on SPU requires a large amount of memory(more than 256 GB). + Splitting the model can significantly reduce memory usage, making it available for use in hardware-constrained environments. 2. Download EasyML library to support Flax-LLaMA-7B @@ -18,6 +20,7 @@ Using the full LLaMA model for inference on SPU requires a large amount of memor cd EasyLM export PYTHONPATH="${PWD}:$PYTHONPATH" ``` + Or use a fork created for transformer split. ```sh @@ -42,7 +45,7 @@ Using the full LLaMA model for inference on SPU requires a large amount of memor pip install jax==0.4.11 jaxlib==0.4.11 ``` - Download trained LLaMA-7B[PyTroch-Version] from "https://github.com/facebookresearch/llama", and convert it to EasyLM format as: + Download trained LLaMA-7B[PyTroch-Version]("https://github.com/facebookresearch/llama"), and convert it to EasyLM format as: ```sh cd path_to_EasyLM/EasyLM/models/llama @@ -102,9 +105,11 @@ Using the full LLaMA model for inference on SPU requires a large amount of memor A: The largest animal is the blue whale. generate on SPU: 812.9427680969238 seconds ``` + RAM peak: 64.5888GB And If you set token_num to 30, you can get the following results: + ```sh ------ Run on CPU @@ -126,56 +131,64 @@ Using the full LLaMA model for inference on SPU requires a large amount of memor 5. Supplement the Split Strategy -In this example, we split the LLaMA-7B model into three parts as follows: + In this example, we split the LLaMA-7B model into three parts as follows: -- **Client**: Embedding + 0-1 LLaMA-Block -- **Mid**: 2nd LLaMA-Block (_runing on the spu_) -- **Server**: 3-31 LLaMA-Block + RMSNorm Layer + - **Client**: Embedding + 0-1 LLaMA-Block + - **Mid**: 2nd LLaMA-Block (_runing on the spu_) + - **Server**: 3-31 LLaMA-Block + RMSNorm Layer -Actually, if users want to split LLaMA-7B model in other way, this can be easily achieved by rewriting few lines of code in `flax_llama7b_split.py` and `llama_model_splited_transformer.py`. + Actually, if users want to split LLaMA-7B model in other way, this can be easily achieved by rewriting few lines of code in `flax_llama7b_split.py` + and `llama_model_splited_transformer.py`. -For example, we rewrite the files as follow. + For example, we rewrite the files as follow. -```python -# flax_llama7b_split.py -# lines 76-89 + ```python + # flax_llama7b_split.py + # lines 76-89 -client_params_dict = { + client_params_dict = { + "transformer":{ + "wte":params['params']["transformer"]["wte"], + "ln_f": params['params']["transformer"]["ln_f"], + "h":{str(i): params['params']["transformer"]["h"][str(i)] for i in range(2)} + } + } + + mid_params_dict = { "transformer":{ - "wte":params['params']["transformer"]["wte"], - "ln_f": params['params']["transformer"]["ln_f"], - "h":{str(i): params['params']["transformer"]["h"][str(i)] for i in range(2)} + + "h":{str(i): params['params']["transformer"]["h"][str(i)] for i in range(2, 5)} } } + ``` -mid_params_dict = { - "transformer":{ + ```python + # llama_model_splited_transformer.py + # lines 1194 + for block in self.blocks[: 2]: + ... - "h":{str(i): params['params']["transformer"]["h"][str(i)] for i in range(2, 5)} - } -} -``` -```python -# llama_model_splited_transformer.py -# lines 1194 -for block in self.blocks[: 2]: - ... - -# lines 1274 -for block in self.blocks[2: 5]: - ... - -# lines 1355 -for block in self.blocks[5:]: - ... -``` -After that, the LLaMA-7B model will be split as follows: -- **Client**: Embedding + 0-1 LLaMA-Block -- **Mid**: 2-4 LLaMA-Block (_runing on the spu_) -- **Server**: 5-31 LLaMA-Block + RMSNorm Layer + # lines 1274 + for block in self.blocks[2: 5]: + ... + + # lines 1355 + for block in self.blocks[5:]: + ... + ``` + + After that, the LLaMA-7B model will be split as follows: + + - **Client**: Embedding + 0-1 LLaMA-Block + - **Mid**: 2-4 LLaMA-Block (_runing on the spu_) + - **Server**: 5-31 LLaMA-Block + RMSNorm Layer 6. Privacy Security Warning -In this example, our main motivation is to reduce the hardware and time resource costs of [Llama-7B](https://research.facebook.com/publications/llama-open-and-efficient-foundation-language-models/) model inference using the SPU. Therefore, spu is only used for inference on the middle blocks of the model. Its privacy protection capability for the original data is weaker when using spu for inference on the entire Llama-7B model. It may be vulnerable to Model Inversion Attacks known in Split Learning as follows: +In this example, our main motivation is to reduce the hardware and time resource costs of [Llama-7B](https://research.facebook.com/publications/llama-open-and-efficient-foundation-language-models/) +model inference using the SPU. Therefore, spu is only used for inference on the middle blocks of the model. +Its privacy protection capability for the original data is weaker when using spu for inference on the entire Llama-7B model. +It may be vulnerable to Model Inversion Attacks known in Split Learning as follows: + - [PCAT: Functionality and Data Stealing from Split Learning by Pseudo-Client Attack](https://www.usenix.org/system/files/usenixsecurity23-gao.pdf) - [UnSplit: Data-Oblivious Model Inversion, Model Stealing, and Label Inference Attacks Against Split Learning](https://arxiv.org/pdf/2108.09033.pdf) diff --git a/examples/python/ml/flax_llama7b_split/gpu_environment.yml b/examples/python/ml/flax_llama7b_split/gpu_environment.yml index 68c38e6f..b5965d2f 100644 --- a/examples/python/ml/flax_llama7b_split/gpu_environment.yml +++ b/examples/python/ml/flax_llama7b_split/gpu_environment.yml @@ -38,4 +38,3 @@ dependencies: - fastapi - uvicorn - gradio - diff --git a/libspu/core/ndarray_ref.cc b/libspu/core/ndarray_ref.cc index 512f68f3..e6c9f0c7 100644 --- a/libspu/core/ndarray_ref.cc +++ b/libspu/core/ndarray_ref.cc @@ -219,13 +219,6 @@ NdArrayRef NdArrayRef::clone() const { return res; } -std::shared_ptr NdArrayRef::getOrCreateCompactBuf() const { - if (isCompact() && offset_ == 0) { - return buf(); - } - return clone().buf(); -} - void NdArrayRef::copy_slice(const NdArrayRef& src, const Index& src_base, const Index& dst_base, int64_t num_copy) { NdArrayRef::Iterator src_iter(src, src_base); diff --git a/libspu/core/ndarray_ref.h b/libspu/core/ndarray_ref.h index b2595048..ec4d9f7e 100644 --- a/libspu/core/ndarray_ref.h +++ b/libspu/core/ndarray_ref.h @@ -121,8 +121,6 @@ class NdArrayRef { std::shared_ptr buf() const { return buf_; } - std::shared_ptr getOrCreateCompactBuf() const; - // create a compact clone. NdArrayRef clone() const; @@ -447,7 +445,7 @@ template size_t maxBitWidth(const NdArrayRef& in) { auto numel = in.numel(); if (numel == 0) { - return 0; + return sizeof(T) * 8; } if (std::all_of(in.strides().begin(), in.strides().end(), diff --git a/libspu/kernel/hal/polymorphic.h b/libspu/kernel/hal/polymorphic.h index 5563eb29..e9e82d0b 100644 --- a/libspu/kernel/hal/polymorphic.h +++ b/libspu/kernel/hal/polymorphic.h @@ -189,8 +189,6 @@ Value right_shift_logical(SPUContext* ctx, const Value& x, size_t bits); Value right_shift_arithmetic(SPUContext* ctx, const Value& x, size_t bits); -Value popcount(SPUContext* ctx, const Value& x); - /// the element-wise base-2 logarithm of x // @param in, should be positive, or the result is implementation defined. Value log2(SPUContext* ctx, const Value& in); diff --git a/libspu/kernel/hal/polymorphic_test.cc b/libspu/kernel/hal/polymorphic_test.cc index 0d9e114d..c623e013 100644 --- a/libspu/kernel/hal/polymorphic_test.cc +++ b/libspu/kernel/hal/polymorphic_test.cc @@ -769,35 +769,4 @@ TYPED_TEST(MathTest, Div) { << z << std::endl; } -// TEST(PopcountTest, Works) { -// { -// // GIVEN -// xt::xarray x = xt::xarray{1, 100, 1000, -1, -100, -// -1000}; - -// // WHAT -// auto z = test::evalUnaryOp(secret_v(), popcount, x); -// auto expected = xt::xarray{1, 3, 6, 64, 60, 56}; - -// // THEN -// EXPECT_TRUE(xt::allclose(expected, z, 0.01, 0.001)) << expected << -// std::endl -// << z; -// } - -// { -// // GIVEN -// xt::xarray x = xt::xarray{1, 100, 1000, -1, -100, -1000}; - -// // WHAT -// auto z = test::evalUnaryOp(secret_v(), popcount, x); -// auto expected = xt::xarray{1, 3, 6, 46, 42, 38}; - -// // THEN -// EXPECT_TRUE(xt::allclose(expected, z, 0.01, 0.001)) << expected << -// std::endl -// << z; -// } -//} - } // namespace spu::kernel::hal diff --git a/libspu/kernel/hlo/BUILD.bazel b/libspu/kernel/hlo/BUILD.bazel index 139a5c9f..0fc0edf9 100644 --- a/libspu/kernel/hlo/BUILD.bazel +++ b/libspu/kernel/hlo/BUILD.bazel @@ -271,6 +271,7 @@ spu_cc_library( deps = [ ":basic_binary", ":casting", + ":const", ":utils", "//libspu/kernel/hal:shape_ops", "//libspu/kernel/hal:sort", diff --git a/libspu/kernel/hlo/sort.cc b/libspu/kernel/hlo/sort.cc index ad4e3017..4fd38127 100644 --- a/libspu/kernel/hlo/sort.cc +++ b/libspu/kernel/hlo/sort.cc @@ -61,6 +61,9 @@ std::vector Sort(SPUContext *ctx, // - W is the vector length. const int64_t M = inputs.size(); const int64_t W = shape.dim(sort_dim); + if (W == 0) { + return std::vector(inputs.begin(), inputs.end()); + } const int64_t N = shape.numel() / W; Axes perm(shape.ndim()); Axes unperm; diff --git a/libspu/kernel/hlo/sort_test.cc b/libspu/kernel/hlo/sort_test.cc index 32fb4737..5773e70e 100644 --- a/libspu/kernel/hlo/sort_test.cc +++ b/libspu/kernel/hlo/sort_test.cc @@ -22,6 +22,8 @@ #include "libspu/kernel/hal/constants.h" #include "libspu/kernel/hal/polymorphic.h" #include "libspu/kernel/hal/type_cast.h" +#include "libspu/kernel/hlo/casting.h" +#include "libspu/kernel/hlo/const.h" #include "libspu/kernel/test_util.h" namespace spu::kernel::hlo { @@ -147,4 +149,21 @@ TEST(SortTest, MultiOperands) { << sorted_k2_hat << std::endl; } +TEST(SortTest, EmptyOperands) { + SPUContext ctx = test::makeSPUContext(); + auto empty_x = Seal(&ctx, Constant(&ctx, 1, {0})); + + std::vector rets = Sort( + &ctx, {empty_x}, 0, false, + [&](absl::Span inputs) { + return hal::less(&ctx, inputs[0], inputs[1]); + }, + Visibility::VIS_SECRET); + + EXPECT_EQ(rets.size(), 1); + EXPECT_EQ(rets[0].numel(), 0); + EXPECT_EQ(rets[0].shape().size(), 1); + EXPECT_EQ(rets[0].shape()[0], 0); +} + } // namespace spu::kernel::hlo diff --git a/libspu/mpc/aby3/boolean.cc b/libspu/mpc/aby3/boolean.cc index a95a42ee..ef030c28 100644 --- a/libspu/mpc/aby3/boolean.cc +++ b/libspu/mpc/aby3/boolean.cc @@ -220,7 +220,7 @@ NdArrayRef AndBB::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, const auto* lhs_ty = lhs.eltype().as(); const auto* rhs_ty = rhs.eltype().as(); - const size_t out_nbits = std::max(lhs_ty->nbits(), rhs_ty->nbits()); + const size_t out_nbits = std::min(lhs_ty->nbits(), rhs_ty->nbits()); const PtType out_btype = calcBShareBacktype(out_nbits); NdArrayRef out(makeType(out_btype, out_nbits), lhs.shape()); diff --git a/libspu/mpc/aby3/conversion.cc b/libspu/mpc/aby3/conversion.cc index 2f030970..f73d3037 100644 --- a/libspu/mpc/aby3/conversion.cc +++ b/libspu/mpc/aby3/conversion.cc @@ -255,6 +255,7 @@ template static std::vector bitDecompose(const NdArrayRef& in, size_t nbits) { auto numel = in.numel(); // decompose each bit of an array of element. + // FIXME: this is not thread-safe. std::vector dep(numel * nbits); NdArrayView _in(in); diff --git a/libspu/mpc/api.cc b/libspu/mpc/api.cc index 7503042b..cea2583a 100644 --- a/libspu/mpc/api.cc +++ b/libspu/mpc/api.cc @@ -131,7 +131,7 @@ Value SvpUnaryDisp(SPUContext* ctx, const Value& x, Args&&... args) { FORCE_NAMED_DISPATCH(CTX, __func__, __VA_ARGS__) #define TRY_NAMED_DISPATCH(CTX, FNAME, ...) \ - if ((CTX)->hasKernel(__func__)) { \ + if ((CTX)->hasKernel(FNAME)) { \ SPU_TRACE_MPC_LEAF(CTX, __VA_ARGS__); \ return dynDispatch((CTX), FNAME, __VA_ARGS__); \ } diff --git a/libspu/mpc/common/communicator.cc b/libspu/mpc/common/communicator.cc index 17f9dfa9..0d20dddd 100644 --- a/libspu/mpc/common/communicator.cc +++ b/libspu/mpc/common/communicator.cc @@ -26,12 +26,18 @@ std::shared_ptr stealBuffer(yacl::Buffer&& buf) { return std::make_shared(std::move(buf)); } +std::shared_ptr getOrCreateCompactBuf(const NdArrayRef& in) { + if (in.numel() * in.elsize() != static_cast(in.buf()->size())) { + return in.clone().buf(); + } + return in.buf(); +} + } // namespace NdArrayRef Communicator::allReduce(ReduceOp op, const NdArrayRef& in, std::string_view tag) { - const auto buf = in.getOrCreateCompactBuf(); - + const auto buf = getOrCreateCompactBuf(in); std::vector bufs = yacl::link::AllGather(lctx_, *buf, tag); SPU_ENFORCE(bufs.size() == getWorldSize()); @@ -61,8 +67,7 @@ NdArrayRef Communicator::allReduce(ReduceOp op, const NdArrayRef& in, NdArrayRef Communicator::reduce(ReduceOp op, const NdArrayRef& in, size_t root, std::string_view tag) { SPU_ENFORCE(root < lctx_->WorldSize()); - const auto buf = in.getOrCreateCompactBuf(); - + const auto buf = getOrCreateCompactBuf(in); std::vector bufs = yacl::link::Gather(lctx_, *buf, root, tag); auto res = in.clone(); @@ -92,8 +97,7 @@ NdArrayRef Communicator::reduce(ReduceOp op, const NdArrayRef& in, size_t root, } NdArrayRef Communicator::rotate(const NdArrayRef& in, std::string_view tag) { - const auto buf = in.getOrCreateCompactBuf(); - + const auto buf = getOrCreateCompactBuf(in); lctx_->SendAsync(lctx_->PrevRank(), *buf, tag); auto res_buf = lctx_->Recv(lctx_->NextRank(), tag); @@ -107,8 +111,7 @@ NdArrayRef Communicator::rotate(const NdArrayRef& in, std::string_view tag) { void Communicator::sendAsync(size_t dst_rank, const NdArrayRef& in, std::string_view tag) { - const auto buf = in.getOrCreateCompactBuf(); - + const auto buf = getOrCreateCompactBuf(in); lctx_->SendAsync(dst_rank, *buf, tag); }