Fix the llama7b example (#440)
# Pull Request

## What problem does this PR solve?
Fix issue #437.

Issue Number: Fixed #437 

## Possible side effects?

- Performance: NO

- Backward compatibility: NO
gystar authored Dec 18, 2023
1 parent 4d8aac3 commit 58c47de
Showing 3 changed files with 32 additions and 7 deletions.
4 changes: 2 additions & 2 deletions examples/python/ml/flax_llama7b/3pc.json
```diff
@@ -34,13 +34,13 @@
     "P1": {
       "kind": "PYU",
       "config": {
-        "node_id": "node:0"
+        "node_id": "node:3"
       }
     },
     "P2": {
       "kind": "PYU",
       "config": {
-        "node_id": "node:0"
+        "node_id": "node:4"
       }
     }
```
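For context, both PYU devices previously pointed at `node:0`, so the two parties collided on one node; the fix gives each party its own node. A quick way to confirm that every PYU now maps to a distinct node (a hypothetical check; it assumes the file keeps its `devices` layout and that `jq` is installed):

```sh
# Prints one node_id per PYU device; the values should all differ.
jq '.devices | to_entries[] | select(.value.kind == "PYU") | .value.config.node_id' \
  examples/python/ml/flax_llama7b/3pc.json
```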
30 changes: 26 additions & 4 deletions examples/python/ml/flax_llama7b/README.md
@@ -17,18 +17,26 @@ This example demonstrates how to use SPU to run secure inference on a pre-traine

```sh
git clone https://github.com/young-geng/EasyLM.git
cd EasyLM
export PYTHONPATH="${PWD}:$PYTHONPATH"
cd ./EasyLM/models/llama
```
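As a quick sanity check before converting weights (hypothetical; it assumes EasyLM's top-level package layout is unchanged), verify the package is importable:

```sh
# Should print OK if PYTHONPATH was exported correctly.
python -c "import EasyLM; print('OK')"
```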

Since EasyLM has an issue, we have to make a small change to support the option `streaming=False`. Open `convert_hf_to_easylm.py` and change this:
```python
parser.add_argument("--streaming", action="store_true", default=True, help="whether is model weight saved stream format",)
```
to:
```python
parser.add_argument("--streaming", action="store_true", default=False, help="whether is model weight saved stream format",)
```
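If you prefer not to edit the file by hand, the same change can be scripted (a hypothetical one-liner; it assumes `default=True` occurs only in this argument definition):

```sh
# Flip the --streaming default from True to False in place.
sed -i 's/default=True/default=False/' convert_hf_to_easylm.py
```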
Download the trained LLaMA-7B (PyTorch version) from [Hugging Face](https://huggingface.co/openlm-research/open_llama_7b), and convert it to a Flax `.msgpack` checkpoint. Since the default is now `False`, omit the `--streaming` flag:

```sh
python convert_hf_to_easylm.py \
  --checkpoint_dir path-to-flax-llama7b-dir \
  --output_file path-to-flax-llama7b-EasyLM.msgpack \
  --model_size 7b
```
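A quick way to confirm the conversion succeeded is to check that the output file exists and has a plausible size (same placeholder path as above; a 7B checkpoint stored in 16-bit precision should be on the order of 13 GB):

```sh
ls -lh path-to-flax-llama7b-EasyLM.msgpack
```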

3. Launch SPU backend runtime
@@ -37,12 +45,26 @@ This example demonstrates how to use SPU to run secure inference on a pre-traine

```sh
bazel run -c opt //examples/python/utils:nodectl -- --config `pwd`/examples/python/ml/flax_llama7b/3pc.json up
```

or (recommended)

```sh
cd examples/python/utils
python nodectl.py --config ../ml/flax_llama7b/3pc.json up
```

4. Run `flax_llama7b` example

```sh
bazel run -c opt //examples/python/ml/flax_llama7b -- --config `pwd`/examples/python/ml/flax_llama7b/3pc.json
```

or (recommended)

```sh
cd examples/python/ml/flax_llama7b
python flax_llama7b.py --model_path dir-to-flax-llama7b-EasyLM --config ./3pc.json
```

and you can get the results from our example.
5 changes: 4 additions & 1 deletion examples/python/ml/flax_llama7b/flax_llama7b.py
```diff
@@ -36,6 +36,9 @@

 parser = argparse.ArgumentParser(description='distributed driver.')
 parser.add_argument("-c", "--config", default="examples/python/ml/flax_llama/3pc.json")
+parser.add_argument(
+    "-m", "--model_path", required=True, help="please input the dir-to-flax-llama7b"
+)
 args = parser.parse_args()

 with open(args.config, 'r') as file:

@@ -49,7 +52,7 @@

 # enable x / broadcast(y) -> x * broadcast(1/y)
 copts.enable_optimize_denominator_with_broadcast = True

-model_path = 'path-to-flax-llama7b'
+model_path = args.model_path

 tokenizer = LlamaTokenizer.from_pretrained(model_path)
 tokenizer.pad_token_id = tokenizer.eos_token_id
```
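Putting the change together, the resulting argument handling in flax_llama7b.py reads as the following sketch (surrounding model and SPU setup elided):

```python
import argparse

parser = argparse.ArgumentParser(description='distributed driver.')
parser.add_argument("-c", "--config", default="examples/python/ml/flax_llama/3pc.json")
parser.add_argument(
    "-m", "--model_path", required=True, help="please input the dir-to-flax-llama7b"
)
# parse_args() returns a Namespace; parsed values are read from it,
# not from the parser object itself.
args = parser.parse_args()

model_path = args.model_path
```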
