Flax llama7b split (#355)

# Pull Request

This PR adds an example of mixed plaintext/ciphertext computation with SPU: combining Split Learning with SPU, one layer of the llama7b model is run on SPU, reducing running time and memory usage. The trade-off is a downgrade in security: there is a risk of leaking the intermediate-layer embeddings.


- Performance:
```sh
   ------
    Run on CPU
    Q: What is the largest animal?
    A: The largest animal is the blue whale.
    Q: What is the smallest animal?
    A: The smallest animal is the bacterium.
    generate on CPU: 837.0810837745667 seconds

    ------
    Run on SPU
    [0000-00-00 00:00:00.000] [info] [thread_pool.cc:30] Create a fixed thread pool with size 127
    Q: What is the largest animal?
    A: The largest animal is the blue whale.
    Q: What is the smallest animal?
    A: The smallest animal is the bacterium.
    generate on SPU: 2760.5035905838013 seconds
 ```
Rainysponge authored Oct 31, 2023
1 parent 4ab16e2 commit 9f60e68
Showing 6 changed files with 2,944 additions and 0 deletions.
46 changes: 46 additions & 0 deletions examples/python/ml/flax_llama7b_split/3pc.json
@@ -0,0 +1,46 @@
{
    "id": "outsourcing.3pc",
    "nodes": {
        "node:0": "127.0.0.1:9920",
        "node:1": "127.0.0.1:9921",
        "node:2": "127.0.0.1:9922",
        "node:3": "127.0.0.1:9923",
        "node:4": "127.0.0.1:9924"
    },
    "devices": {
        "SPU": {
            "kind": "SPU",
            "config": {
                "node_ids": [
                    "node:0",
                    "node:1",
                    "node:2"
                ],
                "spu_internal_addrs": [
                    "127.0.0.1:9930",
                    "127.0.0.1:9931",
                    "127.0.0.1:9932"
                ],
                "runtime_config": {
                    "protocol": "ABY3",
                    "field": "FM64",
                    "enable_pphlo_profile": true,
                    "enable_hal_profile": true,
                    "fxp_exp_mode": 1
                }
            }
        },
        "P1": {
            "kind": "PYU",
            "config": {
                "node_id": "node:3"
            }
        },
        "P2": {
            "kind": "PYU",
            "config": {
                "node_id": "node:4"
            }
        }
    }
}
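
This config defines five nodes: `node:0`-`node:2` back the ABY3-based `SPU` device, while `node:3` and `node:4` host the plaintext PYU devices `P1` and `P2`. As a minimal sketch of how the topology can be loaded and inspected (the file path is the one added by this PR):

```python
import json

# Load the 3pc.json shown above and inspect the device topology.
with open("examples/python/ml/flax_llama7b_split/3pc.json") as f:
    conf = json.load(f)

# Three of the five nodes jointly back the ABY3 SPU device ...
spu_nodes = conf["devices"]["SPU"]["config"]["node_ids"]
# ... and the remaining two host the plaintext PYU devices P1 and P2.
pyu_nodes = [dev["config"]["node_id"]
             for dev in conf["devices"].values()
             if dev["kind"] == "PYU"]
print("SPU nodes:", spu_nodes)  # ['node:0', 'node:1', 'node:2']
print("PYU nodes:", pyu_nodes)  # ['node:3', 'node:4']
```
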
31 changes: 31 additions & 0 deletions examples/python/ml/flax_llama7b_split/BUILD.bazel
@@ -0,0 +1,31 @@
# Copyright 2023 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

load("@rules_python//python:defs.bzl", "py_binary")

package(default_visibility = ["//visibility:public"])

py_binary(
    name = "flax_llama7b_split",
    srcs = ["flax_llama7b_split.py"],
    data = [
        "//examples/python/ml/flax_llama7b_split:3pc.json",
    ],
    deps = [
        "//spu:api",
        "//spu/intrinsic:all_intrinsics",
        "//spu/utils:distributed",
        "//spu/utils:simulation",
    ],
)
181 changes: 181 additions & 0 deletions examples/python/ml/flax_llama7b_split/README.md
@@ -0,0 +1,181 @@
# Flax LLaMA-7B Example with Model Split

This example demonstrates how to use SPU to run secure inference on a pre-trained
[LLaMA-7B](https://research.facebook.com/publications/llama-open-and-efficient-foundation-language-models/) model with model split.

1. Motivation

- Time:
Running the full LLaMA model on SPU for inference can take a significant amount of time. Passing only a portion of the model through SPU greatly improves inference efficiency.

- RAM Usage:
Running the full LLaMA model on SPU for inference requires a large amount of memory (more than 256 GB). Splitting the model significantly reduces memory usage, making inference feasible in hardware-constrained environments.

2. Download the EasyLM library to support Flax LLaMA-7B

```sh
git clone https://github.com/young-geng/EasyLM.git
cd EasyLM
export PYTHONPATH="${PWD}:$PYTHONPATH"
```

Or use a fork created for the transformer split.

```sh
git clone -b split_llama_2 https://github.com/Rainysponge/EasyLM.git
cd EasyLM
export PYTHONPATH="${PWD}:$PYTHONPATH"
```

Install the EasyLM environment before installing SecretFlow & SPU:

```sh
conda env create -f examples/python/ml/flax_llama7b_split/gpu_environment.yml
conda activate EasyLM
pip install 'transformers[flax]'
pip install spu
```

If you do not want to use the GPU:

```sh
pip uninstall jax jaxlib
pip install jax==0.4.11 jaxlib==0.4.11
```

Download the trained LLaMA-7B weights (PyTorch version) from https://github.com/facebookresearch/llama, and convert them to EasyLM format:

```sh
cd path_to_EasyLM/EasyLM/models/llama
python convert_hf_to_easylm.py \
--checkpoint_dir path_to_llama_weights \
--output_file path_to_outputfile \
--model_size 7b \
--streaming
```

Copy the Python file into EasyLM if you are not using the fork created for the transformer split.

```sh
cp path-to-llama_model_split_transformer_py path_to_EasyLM/EasyLM/models/llama
```

3. Launch SPU backend runtime

```sh
bazel run -c opt //examples/python/utils:nodectl -- --config `pwd`/examples/python/ml/flax_llama7b_split/3pc.json up
```

or (recommended)

```sh
cd examples/python/utils
python nodectl.py --config `pwd`/examples/python/ml/flax_llama7b_split/3pc.json up
```
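
Once the nodes are up, the example script connects to them through `spu.utils.distributed`, following the same pattern as the other examples under `examples/python/ml`. A minimal sketch of that setup (assuming the script runs from the repo root; names are illustrative):

```python
import json

import spu.utils.distributed as ppd

# Connect to the nodes launched by nodectl, using the same 3pc.json.
with open("examples/python/ml/flax_llama7b_split/3pc.json") as f:
    conf = json.load(f)
ppd.init(conf["nodes"], conf["devices"])

# After init, plaintext work can be placed on the PYU devices (P1, P2)
# and the private middle blocks on the SPU device, for example:
#   result = ppd.get(ppd.device("SPU")(some_fn)(some_args))
```
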

4. Run `flax_llama7b_split` example

```sh
bazel run -c opt //examples/python/ml/flax_llama7b_split -- --config `pwd`/examples/python/ml/flax_llama7b_split/3pc.json
```

or (recommended)

```sh
cd examples/python/ml/flax_llama7b_split
python flax_llama7b_split.py --config `pwd`/examples/python/ml/flax_llama7b_split/3pc.json
```

Then you can get the following results from our example:

```md
------
Run on CPU
Q: What is the largest animal?
A: The largest animal is the blue whale.
generate on CPU: 256.08824276924133 seconds
------
Run on SPU
[0000-00-00 00:00:00.000] [info] [thread_pool.cc:30] Create a fixed thread pool with size 127
Q: What is the largest animal?
A: The largest animal is the blue whale.
generate on SPU: 812.9427680969238 seconds
```

RAM peak: 64.5888 GB

If you set `token_num` to 30, you can get the following results:

```sh
------
Run on CPU
Q: What is the largest animal?
A: The largest animal is the blue whale.
Q: What is the smallest animal?
A: The smallest animal is the bacterium.
generate on CPU: 837.0810837745667 seconds
------
Run on SPU
[0000-00-00 00:00:00.000] [info] [thread_pool.cc:30] Create a fixed thread pool with size 127
Q: What is the largest animal?
A: The largest animal is the blue whale.
Q: What is the smallest animal?
A: The smallest animal is the bacterium.
generate on SPU: 2760.5035905838013 seconds
```

5. Notes on the Split Strategy

In this example, we split the LLaMA-7B model into three parts as follows:

- **Client**: Embedding + LLaMA-Blocks 0-1
- **Mid**: LLaMA-Block 2 (_running on SPU_)
- **Server**: LLaMA-Blocks 3-31 + RMSNorm layer

If users want to split the LLaMA-7B model in a different way, this can easily be achieved by rewriting a few lines of code in `flax_llama7b_split.py` and `llama_model_splited_transformer.py`.

For example, we rewrite the files as follows:

```python
# flax_llama7b_split.py
# lines 76-89
client_params_dict = {
    "transformer": {
        "wte": params['params']["transformer"]["wte"],
        "ln_f": params['params']["transformer"]["ln_f"],
        "h": {str(i): params['params']["transformer"]["h"][str(i)] for i in range(2)},
    }
}
mid_params_dict = {
    "transformer": {
        "h": {str(i): params['params']["transformer"]["h"][str(i)] for i in range(2, 5)},
    }
}
```
```python
# llama_model_splited_transformer.py
# line 1194
for block in self.blocks[:2]:
    ...
# line 1274
for block in self.blocks[2:5]:
    ...
# line 1355
for block in self.blocks[5:]:
    ...
```
After that, the LLaMA-7B model will be split as follows:

- **Client**: Embedding + LLaMA-Blocks 0-1
- **Mid**: LLaMA-Blocks 2-4 (_running on SPU_)
- **Server**: LLaMA-Blocks 5-31 + RMSNorm layer
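
With this split, only the middle blocks run under MPC while the client and server parts stay in plaintext. A minimal sketch of the resulting device placement (`client_forward`, `mid_forward`, `server_forward`, and `server_params_dict` are illustrative placeholders, not the exact names in `flax_llama7b_split.py`):

```python
import spu.utils.distributed as ppd

# Illustrative placeholders for the three sub-models produced by the split.
def client_forward(params, input_ids): ...  # Embedding + blocks 0-1
def mid_forward(params, hidden): ...        # blocks 2-4
def server_forward(params, hidden): ...     # blocks 5-31 + RMSNorm

# input_ids: token ids from the tokenizer (not shown here).
# The client part runs in plaintext on P1; its output hidden states are the
# intermediate embeddings that the privacy warning below refers to.
hidden = ppd.device("P1")(client_forward)(client_params_dict, input_ids)

# The mid part runs secret-shared on the ABY3 SPU device.
hidden = ppd.device("SPU")(mid_forward)(mid_params_dict, hidden)

# The server part runs in plaintext on P2; ppd.get reveals the final output.
# server_params_dict holds the remaining blocks (not shown above).
logits = ppd.get(ppd.device("P2")(server_forward)(server_params_dict, hidden))
```
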
6. Privacy Security Warning

In this example, our main motivation is to reduce the hardware and time costs of [LLaMA-7B](https://research.facebook.com/publications/llama-open-and-efficient-foundation-language-models/) model inference by using SPU only for the middle blocks of the model. Consequently, its protection of the original data is weaker than running inference on the entire LLaMA-7B model with SPU, and the split model may be vulnerable to the model inversion attacks known from Split Learning, such as:
- [PCAT: Functionality and Data Stealing from Split Learning by Pseudo-Client Attack](https://www.usenix.org/system/files/usenixsecurity23-gao.pdf)
- [UnSplit: Data-Oblivious Model Inversion, Model Stealing, and Label Inference Attacks Against Split Learning](https://arxiv.org/pdf/2108.09033.pdf)