Skip to content

Commit

Permalink
repo-sync-2023-11-03T13:21:43+0800
Browse files Browse the repository at this point in the history
  • Loading branch information
anakinxc committed Nov 3, 2023
1 parent 9f60e68 commit b06aa85
Show file tree
Hide file tree
Showing 13 changed files with 94 additions and 97 deletions.
99 changes: 56 additions & 43 deletions examples/python/ml/flax_llama7b_split/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@ This example demonstrates how to use SPU to run secure inference on a pre-traine

1. Motivation

- Time:
Using the full LlaMA model for inference on SPU can take a significant amount of time. If only a portion of the model is passed through SPU to ensure privacy, it can greatly improve inference efficiency.
- Time:
Using the full LlaMA model for inference on SPU can take a significant amount of time.
If only a portion of the model is passed through SPU to ensure privacy, it can greatly improve inference efficiency.

- RAM Usage:
Using the full LLaMA model for inference on SPU requires a large amount of memory(more than 256 GB). Splitting the model can significantly reduce memory usage, making it available for use in hardware-constrained environments.
- RAM Usage:
Using the full LLaMA model for inference on SPU requires a large amount of memory(more than 256 GB).
Splitting the model can significantly reduce memory usage, making it available for use in hardware-constrained environments.

2. Download EasyML library to support Flax-LLaMA-7B

Expand All @@ -18,6 +20,7 @@ Using the full LLaMA model for inference on SPU requires a large amount of memor
cd EasyLM
export PYTHONPATH="${PWD}:$PYTHONPATH"
```

Or use a fork created for transformer split.

```sh
Expand All @@ -42,7 +45,7 @@ Using the full LLaMA model for inference on SPU requires a large amount of memor
pip install jax==0.4.11 jaxlib==0.4.11
```

Download trained LLaMA-7B[PyTroch-Version] from "https://github.com/facebookresearch/llama", and convert it to EasyLM format as:
Download trained LLaMA-7B[PyTroch-Version]("https://github.com/facebookresearch/llama"), and convert it to EasyLM format as:

```sh
cd path_to_EasyLM/EasyLM/models/llama
Expand Down Expand Up @@ -102,9 +105,11 @@ Using the full LLaMA model for inference on SPU requires a large amount of memor
A: The largest animal is the blue whale.
generate on SPU: 812.9427680969238 seconds
```

RAM peak: 64.5888GB

And If you set token_num to 30, you can get the following results:

```sh
------
Run on CPU
Expand All @@ -126,56 +131,64 @@ Using the full LLaMA model for inference on SPU requires a large amount of memor

5. Supplement the Split Strategy

In this example, we split the LLaMA-7B model into three parts as follows:
In this example, we split the LLaMA-7B model into three parts as follows:

- **Client**: Embedding + 0-1 LLaMA-Block
- **Mid**: 2nd LLaMA-Block (_runing on the spu_)
- **Server**: 3-31 LLaMA-Block + RMSNorm Layer
- **Client**: Embedding + 0-1 LLaMA-Block
- **Mid**: 2nd LLaMA-Block (_runing on the spu_)
- **Server**: 3-31 LLaMA-Block + RMSNorm Layer

Actually, if users want to split LLaMA-7B model in other way, this can be easily achieved by rewriting few lines of code in `flax_llama7b_split.py` and `llama_model_splited_transformer.py`.
Actually, if users want to split LLaMA-7B model in other way, this can be easily achieved by rewriting few lines of code in `flax_llama7b_split.py`
and `llama_model_splited_transformer.py`.

For example, we rewrite the files as follow.
For example, we rewrite the files as follow.

```python
# flax_llama7b_split.py
# lines 76-89
```python
# flax_llama7b_split.py
# lines 76-89
client_params_dict = {
client_params_dict = {
"transformer":{
"wte":params['params']["transformer"]["wte"],
"ln_f": params['params']["transformer"]["ln_f"],
"h":{str(i): params['params']["transformer"]["h"][str(i)] for i in range(2)}
}
}
mid_params_dict = {
"transformer":{
"wte":params['params']["transformer"]["wte"],
"ln_f": params['params']["transformer"]["ln_f"],
"h":{str(i): params['params']["transformer"]["h"][str(i)] for i in range(2)}
"h":{str(i): params['params']["transformer"]["h"][str(i)] for i in range(2, 5)}
}
}
```
mid_params_dict = {
"transformer":{
```python
# llama_model_splited_transformer.py
# lines 1194
for block in self.blocks[: 2]:
...
"h":{str(i): params['params']["transformer"]["h"][str(i)] for i in range(2, 5)}
}
}
```
```python
# llama_model_splited_transformer.py
# lines 1194
for block in self.blocks[: 2]:
...
# lines 1274
for block in self.blocks[2: 5]:
...
# lines 1355
for block in self.blocks[5:]:
...
```
After that, the LLaMA-7B model will be split as follows:
- **Client**: Embedding + 0-1 LLaMA-Block
- **Mid**: 2-4 LLaMA-Block (_runing on the spu_)
- **Server**: 5-31 LLaMA-Block + RMSNorm Layer
# lines 1274
for block in self.blocks[2: 5]:
...
# lines 1355
for block in self.blocks[5:]:
...
```
After that, the LLaMA-7B model will be split as follows:
- **Client**: Embedding + 0-1 LLaMA-Block
- **Mid**: 2-4 LLaMA-Block (_runing on the spu_)
- **Server**: 5-31 LLaMA-Block + RMSNorm Layer
6. Privacy Security Warning
In this example, our main motivation is to reduce the hardware and time resource costs of [Llama-7B](https://research.facebook.com/publications/llama-open-and-efficient-foundation-language-models/) model inference using the SPU. Therefore, spu is only used for inference on the middle blocks of the model. Its privacy protection capability for the original data is weaker when using spu for inference on the entire Llama-7B model. It may be vulnerable to Model Inversion Attacks known in Split Learning as follows:
In this example, our main motivation is to reduce the hardware and time resource costs of [Llama-7B](https://research.facebook.com/publications/llama-open-and-efficient-foundation-language-models/)
model inference using the SPU. Therefore, spu is only used for inference on the middle blocks of the model.
Its privacy protection capability for the original data is weaker when using spu for inference on the entire Llama-7B model.
It may be vulnerable to Model Inversion Attacks known in Split Learning as follows:
- [PCAT: Functionality and Data Stealing from Split Learning by Pseudo-Client Attack](https://www.usenix.org/system/files/usenixsecurity23-gao.pdf)
- [UnSplit: Data-Oblivious Model Inversion, Model Stealing, and Label Inference Attacks Against Split Learning](https://arxiv.org/pdf/2108.09033.pdf)
1 change: 0 additions & 1 deletion examples/python/ml/flax_llama7b_split/gpu_environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,3 @@ dependencies:
- fastapi
- uvicorn
- gradio

7 changes: 0 additions & 7 deletions libspu/core/ndarray_ref.cc
Original file line number Diff line number Diff line change
Expand Up @@ -219,13 +219,6 @@ NdArrayRef NdArrayRef::clone() const {
return res;
}

std::shared_ptr<yacl::Buffer> NdArrayRef::getOrCreateCompactBuf() const {
if (isCompact() && offset_ == 0) {
return buf();
}
return clone().buf();
}

void NdArrayRef::copy_slice(const NdArrayRef& src, const Index& src_base,
const Index& dst_base, int64_t num_copy) {
NdArrayRef::Iterator src_iter(src, src_base);
Expand Down
4 changes: 1 addition & 3 deletions libspu/core/ndarray_ref.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,6 @@ class NdArrayRef {

std::shared_ptr<yacl::Buffer> buf() const { return buf_; }

std::shared_ptr<yacl::Buffer> getOrCreateCompactBuf() const;

// create a compact clone.
NdArrayRef clone() const;

Expand Down Expand Up @@ -447,7 +445,7 @@ template <typename T>
size_t maxBitWidth(const NdArrayRef& in) {
auto numel = in.numel();
if (numel == 0) {
return 0;
return sizeof(T) * 8;
}

if (std::all_of(in.strides().begin(), in.strides().end(),
Expand Down
2 changes: 0 additions & 2 deletions libspu/kernel/hal/polymorphic.h
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,6 @@ Value right_shift_logical(SPUContext* ctx, const Value& x, size_t bits);

Value right_shift_arithmetic(SPUContext* ctx, const Value& x, size_t bits);

Value popcount(SPUContext* ctx, const Value& x);

/// the element-wise base-2 logarithm of x
// @param in, should be positive, or the result is implementation defined.
Value log2(SPUContext* ctx, const Value& in);
Expand Down
31 changes: 0 additions & 31 deletions libspu/kernel/hal/polymorphic_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -769,35 +769,4 @@ TYPED_TEST(MathTest, Div) {
<< z << std::endl;
}

// TEST(PopcountTest, Works) {
// {
// // GIVEN
// xt::xarray<int32_t> x = xt::xarray<int32_t>{1, 100, 1000, -1, -100,
// -1000};

// // WHAT
// auto z = test::evalUnaryOp<int>(secret_v(), popcount, x);
// auto expected = xt::xarray<int>{1, 3, 6, 64, 60, 56};

// // THEN
// EXPECT_TRUE(xt::allclose(expected, z, 0.01, 0.001)) << expected <<
// std::endl
// << z;
// }

// {
// // GIVEN
// xt::xarray<float> x = xt::xarray<float>{1, 100, 1000, -1, -100, -1000};

// // WHAT
// auto z = test::evalUnaryOp<int>(secret_v(), popcount, x);
// auto expected = xt::xarray<int>{1, 3, 6, 46, 42, 38};

// // THEN
// EXPECT_TRUE(xt::allclose(expected, z, 0.01, 0.001)) << expected <<
// std::endl
// << z;
// }
//}

} // namespace spu::kernel::hal
1 change: 1 addition & 0 deletions libspu/kernel/hlo/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,7 @@ spu_cc_library(
deps = [
":basic_binary",
":casting",
":const",
":utils",
"//libspu/kernel/hal:shape_ops",
"//libspu/kernel/hal:sort",
Expand Down
3 changes: 3 additions & 0 deletions libspu/kernel/hlo/sort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ std::vector<spu::Value> Sort(SPUContext *ctx,
// - W is the vector length.
const int64_t M = inputs.size();
const int64_t W = shape.dim(sort_dim);
if (W == 0) {
return std::vector<spu::Value>(inputs.begin(), inputs.end());
}
const int64_t N = shape.numel() / W;
Axes perm(shape.ndim());
Axes unperm;
Expand Down
19 changes: 19 additions & 0 deletions libspu/kernel/hlo/sort_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
#include "libspu/kernel/hal/constants.h"
#include "libspu/kernel/hal/polymorphic.h"
#include "libspu/kernel/hal/type_cast.h"
#include "libspu/kernel/hlo/casting.h"
#include "libspu/kernel/hlo/const.h"
#include "libspu/kernel/test_util.h"

namespace spu::kernel::hlo {
Expand Down Expand Up @@ -147,4 +149,21 @@ TEST(SortTest, MultiOperands) {
<< sorted_k2_hat << std::endl;
}

TEST(SortTest, EmptyOperands) {
SPUContext ctx = test::makeSPUContext();
auto empty_x = Seal(&ctx, Constant(&ctx, 1, {0}));

std::vector<spu::Value> rets = Sort(
&ctx, {empty_x}, 0, false,
[&](absl::Span<const spu::Value> inputs) {
return hal::less(&ctx, inputs[0], inputs[1]);
},
Visibility::VIS_SECRET);

EXPECT_EQ(rets.size(), 1);
EXPECT_EQ(rets[0].numel(), 0);
EXPECT_EQ(rets[0].shape().size(), 1);
EXPECT_EQ(rets[0].shape()[0], 0);
}

} // namespace spu::kernel::hlo
2 changes: 1 addition & 1 deletion libspu/mpc/aby3/boolean.cc
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ NdArrayRef AndBB::proc(KernelEvalContext* ctx, const NdArrayRef& lhs,
const auto* lhs_ty = lhs.eltype().as<BShrTy>();
const auto* rhs_ty = rhs.eltype().as<BShrTy>();

const size_t out_nbits = std::max(lhs_ty->nbits(), rhs_ty->nbits());
const size_t out_nbits = std::min(lhs_ty->nbits(), rhs_ty->nbits());
const PtType out_btype = calcBShareBacktype(out_nbits);
NdArrayRef out(makeType<BShrTy>(out_btype, out_nbits), lhs.shape());

Expand Down
1 change: 1 addition & 0 deletions libspu/mpc/aby3/conversion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ template <typename T>
static std::vector<bool> bitDecompose(const NdArrayRef& in, size_t nbits) {
auto numel = in.numel();
// decompose each bit of an array of element.
// FIXME: this is not thread-safe.
std::vector<bool> dep(numel * nbits);

NdArrayView<T> _in(in);
Expand Down
2 changes: 1 addition & 1 deletion libspu/mpc/api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ Value SvpUnaryDisp(SPUContext* ctx, const Value& x, Args&&... args) {
FORCE_NAMED_DISPATCH(CTX, __func__, __VA_ARGS__)

#define TRY_NAMED_DISPATCH(CTX, FNAME, ...) \
if ((CTX)->hasKernel(__func__)) { \
if ((CTX)->hasKernel(FNAME)) { \
SPU_TRACE_MPC_LEAF(CTX, __VA_ARGS__); \
return dynDispatch((CTX), FNAME, __VA_ARGS__); \
}
Expand Down
19 changes: 11 additions & 8 deletions libspu/mpc/common/communicator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,18 @@ std::shared_ptr<yacl::Buffer> stealBuffer(yacl::Buffer&& buf) {
return std::make_shared<yacl::Buffer>(std::move(buf));
}

std::shared_ptr<yacl::Buffer> getOrCreateCompactBuf(const NdArrayRef& in) {
if (in.numel() * in.elsize() != static_cast<size_t>(in.buf()->size())) {
return in.clone().buf();
}
return in.buf();
}

} // namespace

NdArrayRef Communicator::allReduce(ReduceOp op, const NdArrayRef& in,
std::string_view tag) {
const auto buf = in.getOrCreateCompactBuf();

const auto buf = getOrCreateCompactBuf(in);
std::vector<yacl::Buffer> bufs = yacl::link::AllGather(lctx_, *buf, tag);

SPU_ENFORCE(bufs.size() == getWorldSize());
Expand Down Expand Up @@ -61,8 +67,7 @@ NdArrayRef Communicator::allReduce(ReduceOp op, const NdArrayRef& in,
NdArrayRef Communicator::reduce(ReduceOp op, const NdArrayRef& in, size_t root,
std::string_view tag) {
SPU_ENFORCE(root < lctx_->WorldSize());
const auto buf = in.getOrCreateCompactBuf();

const auto buf = getOrCreateCompactBuf(in);
std::vector<yacl::Buffer> bufs = yacl::link::Gather(lctx_, *buf, root, tag);

auto res = in.clone();
Expand Down Expand Up @@ -92,8 +97,7 @@ NdArrayRef Communicator::reduce(ReduceOp op, const NdArrayRef& in, size_t root,
}

NdArrayRef Communicator::rotate(const NdArrayRef& in, std::string_view tag) {
const auto buf = in.getOrCreateCompactBuf();

const auto buf = getOrCreateCompactBuf(in);
lctx_->SendAsync(lctx_->PrevRank(), *buf, tag);

auto res_buf = lctx_->Recv(lctx_->NextRank(), tag);
Expand All @@ -107,8 +111,7 @@ NdArrayRef Communicator::rotate(const NdArrayRef& in, std::string_view tag) {

void Communicator::sendAsync(size_t dst_rank, const NdArrayRef& in,
std::string_view tag) {
const auto buf = in.getOrCreateCompactBuf();

const auto buf = getOrCreateCompactBuf(in);
lctx_->SendAsync(dst_rank, *buf, tag);
}

Expand Down

0 comments on commit b06aa85

Please sign in to comment.