From 58c47de67ae6b0d4cc49f0a18d93b5e3ebf75335 Mon Sep 17 00:00:00 2001
From: Yi Gui
Date: Mon, 18 Dec 2023 16:44:32 +0800
Subject: [PATCH] Fix the llama7b example (#440)

# Pull Request

## What problem does this PR solve?

Fix issue #437.

Issue Number: Fixed #437

## Possible side effects?

- Performance: NO
- Backward compatibility: NO
---
 examples/python/ml/flax_llama7b/3pc.json   |  4 ++--
 examples/python/ml/flax_llama7b/README.md  | 30 ++++++++++++++++++++++++++++----
 .../python/ml/flax_llama7b/flax_llama7b.py |  5 ++++-
 3 files changed, 32 insertions(+), 7 deletions(-)

diff --git a/examples/python/ml/flax_llama7b/3pc.json b/examples/python/ml/flax_llama7b/3pc.json
index 506dbe86..20ab0c91 100644
--- a/examples/python/ml/flax_llama7b/3pc.json
+++ b/examples/python/ml/flax_llama7b/3pc.json
@@ -34,13 +34,13 @@
     "P1": {
       "kind": "PYU",
       "config": {
-        "node_id": "node:0"
+        "node_id": "node:3"
       }
     },
     "P2": {
       "kind": "PYU",
       "config": {
-        "node_id": "node:0"
+        "node_id": "node:4"
       }
     }
   }
diff --git a/examples/python/ml/flax_llama7b/README.md b/examples/python/ml/flax_llama7b/README.md
index 13fc76bb..4410abe4 100644
--- a/examples/python/ml/flax_llama7b/README.md
+++ b/examples/python/ml/flax_llama7b/README.md
@@ -17,18 +17,26 @@ This example demonstrates how to use SPU to run secure inference on a pre-traine
    git clone https://github.com/young-geng/EasyLM.git
    cd EasyLM
    export PYTHONPATH="${PWD}:$PYTHONPATH"
+   cd ./EasyLM/models/llama
    ```

+   EasyLM's `convert_hf_to_easylm.py` defines `--streaming` as a `store_true` flag with `default=True`, so the streaming format can never be switched off from the command line. We have to make a small change to get a non-streaming checkpoint.
+   Open `convert_hf_to_easylm.py` and change this:
+
+   ```python
+   parser.add_argument("--streaming", action="store_true", default=True, help="whether is model weight saved stream format",)
+   ```
+
+   to:
+
+   ```python
+   parser.add_argument("--streaming", action="store_true", default=False, help="whether is model weight saved stream format",)
+   ```

    Download the trained LLaMA-7B [PyTorch version] from [Hugging Face](https://huggingface.co/openlm-research/open_llama_7b), and convert it to Flax.msgpack as below.
+   (After the change above, the script defaults to the regular, non-streaming format, so run it without the `--streaming` flag.)
-
-   ```sh
-   cd path_to_EasyLM/EasyLM/models/llama
+   ```sh
    python convert_hf_to_easylm.py \
       --checkpoint_dir path-to-flax-llama7b-dir \
       --output_file path-to-flax-llama7b-EasyLM.msgpack \
-      --model_size 7b \
-      --streaming
+      --model_size 7b
    ```
@@ -37,12 +45,26 @@
 3. Launch SPU backend runtime

    ```sh
    bazel run -c opt //examples/python/utils:nodectl -- --config `pwd`/examples/python/ml/flax_llama7b/3pc.json up
    ```

+   or (recommended):
+
+   ```sh
+   cd examples/python/utils
+   python nodectl.py --config ../ml/flax_llama7b/3pc.json up
+   ```
+
 4. Run `flax_llama7b` example

    ```sh
    bazel run -c opt //examples/python/ml/flax_llama7b -- --config `pwd`/examples/python/ml/flax_llama7b/3pc.json
    ```

+   or (recommended):
+
+   ```sh
+   cd examples/python/ml/flax_llama7b
+   python flax_llama7b.py --model_path dir-to-flax-llama7b-EasyLM --config ./3pc.json
+   ```
+
    and you can get the following results from our example:

    ```md
diff --git a/examples/python/ml/flax_llama7b/flax_llama7b.py b/examples/python/ml/flax_llama7b/flax_llama7b.py
index 88754fbd..ac3dc3e0 100644
--- a/examples/python/ml/flax_llama7b/flax_llama7b.py
+++ b/examples/python/ml/flax_llama7b/flax_llama7b.py
@@ -36,6 +36,9 @@
 parser = argparse.ArgumentParser(description='distributed driver.')
 parser.add_argument("-c", "--config", default="examples/python/ml/flax_llama/3pc.json")
+parser.add_argument(
+    "-m", "--model_path", required=True, help="directory of the converted Flax LLaMA-7B model"
+)
 args = parser.parse_args()

 with open(args.config, 'r') as file:
@@ -49,7 +52,7 @@
 # enable x / broadcast(y) -> x * broadcast(1/y)
 copts.enable_optimize_denominator_with_broadcast = True

-model_path = 'path-to-flax-llama7b'
+model_path = args.model_path
 tokenizer = LlamaTokenizer.from_pretrained(model_path)
 tokenizer.pad_token_id = tokenizer.eos_token_id
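
Note: the `model_path` change above relies on standard `argparse` behavior: parsed values live on the `Namespace` returned by `parse_args()`, not on the `ArgumentParser` object, so `parser.model_path` would raise `AttributeError`. Below is a minimal, self-contained sketch of that pattern; the flag names mirror the patch, while the final `print` is purely illustrative:

```python
import argparse

# Build the CLI the same way flax_llama7b.py does after this patch.
parser = argparse.ArgumentParser(description="distributed driver.")
parser.add_argument(
    "-c", "--config", default="examples/python/ml/flax_llama7b/3pc.json"
)
parser.add_argument(
    "-m", "--model_path", required=True, help="directory of the Flax LLaMA-7B model"
)

# parse_args() returns a Namespace; every parsed value is an attribute of it.
args = parser.parse_args()

# Correct: read the value from the parsed Namespace.
model_path = args.model_path

# Wrong: parser.model_path raises AttributeError, because the
# ArgumentParser never stores parsed values on itself.
print(f"loading tokenizer and weights from {model_path}")
```

Because the argument is declared with `required=True`, running the example without `-m`/`--model_path` fails fast with a usage error instead of silently falling back to a placeholder path.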