From 0197308bfc9bf2f3903edaabf20e154e7993fa9d Mon Sep 17 00:00:00 2001 From: anakinxc Date: Tue, 19 Dec 2023 18:04:36 +0800 Subject: [PATCH 1/3] repo-sync-2023-12-19T18:04:30+0800 --- CHANGELOG.md | 2 ++ docs/reference/pphlo_doc.rst | 2 +- .../develop_your_first_mpc_application.ipynb | 8 ++--- examples/python/ml/flax_llama7b/README.md | 31 +++++++++++++++++-- libspu/compiler/front_end/BUILD.bazel | 1 + libspu/compiler/front_end/hlo_importer.cc | 2 ++ libspu/mpc/spdz2k/BUILD.bazel | 2 +- libspu/mpc/spdz2k/ot/BUILD.bazel | 2 +- spu/utils/frontend.py | 14 +++++++-- 9 files changed, 51 insertions(+), 13 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index afd2d802..614416ac 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,8 @@ > > please add your unreleased change here. +## TBD + - [Feature] Add Odd-Even Merge Sort to replace the bitonic sort - [Feature] Add radix sort support for ABY3 - [Feature] Integrate with secretflow/psi diff --git a/docs/reference/pphlo_doc.rst b/docs/reference/pphlo_doc.rst index 48f61e20..e539ab63 100644 --- a/docs/reference/pphlo_doc.rst +++ b/docs/reference/pphlo_doc.rst @@ -3,7 +3,7 @@ PPHlo API reference PPHlo is short for (SPU High level ops), it's the assembly language of SPU. -PPHlo is built on `MLIR `_ infrastructure, the concrete ops definition could be found :spu_code_host:`here `. +PPHlo is built on `MLIR `_ infrastructure, the concrete ops definition could be found :spu_code_host:`here `. Op List ~~~~~~~ diff --git a/docs/tutorials/develop_your_first_mpc_application.ipynb b/docs/tutorials/develop_your_first_mpc_application.ipynb index 57908d91..d89ca5e1 100644 --- a/docs/tutorials/develop_your_first_mpc_application.ipynb +++ b/docs/tutorials/develop_your_first_mpc_application.ipynb @@ -1129,11 +1129,11 @@ "\n", "Before diving into the problem deeper, we highly recommend the reader to read: \n", "\n", - " 1. [spu_inside](https://www.secretflow.org.cn/docs/spu/en/getting_started/tutorials/spu_inside.html#Tracing): gives some introductions how spu works inside for float-point operations.\n", + " 1. [spu_inside](https://www.secretflow.org.cn/docs/spu/latest/en-US/tutorials/spu_inside#Tracing): gives some introductions how spu works inside for float-point operations.\n", "\n", - " 2. [pitfall](https://www.secretflow.org.cn/docs/spu/en/reference/fxp.html): spu implements math function(like `reciprocal`, `log` and so on) with approximation algorithm, so some precision issue will occur when inputs fall into some intervals. We list some known issue about this.\n", + " 2. [pitfall](https://www.secretflow.org.cn/docs/spu/latest/en-US/development/fxp): spu implements math function(like `reciprocal`, `log` and so on) with approximation algorithm, so some precision issue will occur when inputs fall into some intervals. We list some known issue about this.\n", "\n", - "3. [protocols](https://www.secretflow.org.cn/docs/spu/en/reference/mpc_status.html): list all protocols spu implements now. Generally speaking, for 2pc, it's safe to use cheetah, while for 3pc, ABY3 is the only choice.\n", + "3. [protocols](https://www.secretflow.org.cn/docs/spu/latest/en-US/reference/mpc_status): list all protocols spu implements now. Generally speaking, for 2pc, it's safe to use cheetah, while for 3pc, ABY3 is the only choice.\n", "\n", "First define a function just like `fit_and_predict` to get dk_arr." ] @@ -1659,7 +1659,7 @@ "source": [ "#### Rsqrt v.s. 
Norm\n", "\n", - "We can recall that the `compute_dk_func` function defined in Part 1 contains a `method` arg, and we just ignore this arg before. Indeed, we can tell simulator to print more information like [spu_inside](https://www.secretflow.org.cn/docs/spu/en/getting_started/tutorials/spu_inside.html#Tracing) do: enable **hlo**(High Level Operations) trace and profile on. Then we can figure out which op has been invoked and its time cost.\n", + "We can recall that the `compute_dk_func` function defined in Part 1 contains a `method` arg, and we just ignore this arg before. Indeed, we can tell simulator to print more information like [spu_inside](https://www.secretflow.org.cn/docs/spu/latest/en-US/tutorials/spu_inside#Tracing) do: enable **hlo**(High Level Operations) trace and profile on. Then we can figure out which op has been invoked and its time cost.\n", "\n", "Here, we list some advantages of using `jax.lax.rsqrt` rather than `jnp.linalg.norm`:\n", "\n", diff --git a/examples/python/ml/flax_llama7b/README.md b/examples/python/ml/flax_llama7b/README.md index 4410abe4..949947dc 100644 --- a/examples/python/ml/flax_llama7b/README.md +++ b/examples/python/ml/flax_llama7b/README.md @@ -20,6 +20,32 @@ This example demonstrates how to use SPU to run secure inference on a pre-traine cd ./EasyLM/models/llama ``` + Since EasyLM have an issue,so we have to make a samll change to support the option "streaming=false". + Open and edit "convert_hf_to_easylm.py", chang this: + + ```python + parser.add_argument("--streaming", action="store_true", default=True, help="whether is model weight saved stream format",) + ``` + + to: + + ```python + parser.add_argument("--streaming", action="store_true", default=False, help="whether is model weight saved stream format",) + ``` + + Since EasyLM have an issue,so we have to make a samll change to support the option "streaming=false". + Open and edit "convert_hf_to_easylm.py", chang this: + + ```python + parser.add_argument("--streaming", action="store_true", default=True, help="whether is model weight saved stream format",) + ``` + + to: + + ```python + parser.add_argument("--streaming", action="store_true", default=False, help="whether is model weight saved stream format",) + ``` + Since EasyLM have an issue,so we have to make a samll change to support the option "streaming=false". 
Open and edit "convert_hf_to_easylm.py", chang this: ```python @@ -31,7 +57,7 @@ This example demonstrates how to use SPU to run secure inference on a pre-traine ``` Download trained LLaMA-B[PyTroch-Version] from [Hugging Face](https://huggingface.co/openlm-research/open_llama_7b) , and convert it to Flax.msgpack as: - ```sh + ```sh python convert_hf_to_easylm.py \ --checkpoint_dir path-to-flax-llama7b-dir \ --output_file path-to-flax-llama7b-EasyLM.msgpack \ @@ -45,8 +71,7 @@ This example demonstrates how to use SPU to run secure inference on a pre-traine bazel run -c opt //examples/python/utils:nodectl -- --config `pwd`/examples/python/ml/flax_llama7b/3pc.json up ``` - or - (recommended) + or(recommended) ```sh cd examples/python/utils diff --git a/libspu/compiler/front_end/BUILD.bazel b/libspu/compiler/front_end/BUILD.bazel index c7e8c6f4..8a6e1f97 100644 --- a/libspu/compiler/front_end/BUILD.bazel +++ b/libspu/compiler/front_end/BUILD.bazel @@ -39,6 +39,7 @@ spu_cc_library( "@xla//xla/service:dot_decomposer", "@xla//xla/service:eigh_expander", "@xla//xla/service:float_normalization", + "@xla//xla/service:gather_expander", "@xla//xla/service:gather_simplifier", "@xla//xla/service:hlo_constant_folding", "@xla//xla/service:hlo_cse", diff --git a/libspu/compiler/front_end/hlo_importer.cc b/libspu/compiler/front_end/hlo_importer.cc index 296406d7..149f5a4b 100644 --- a/libspu/compiler/front_end/hlo_importer.cc +++ b/libspu/compiler/front_end/hlo_importer.cc @@ -28,6 +28,7 @@ #include "xla/service/eigh_expander.h" #include "xla/service/float_normalization.h" #include "xla/service/float_support.h" +#include "xla/service/gather_expander.h" #include "xla/service/gather_simplifier.h" #include "xla/service/gpu/dot_dimension_sorter.h" #include "xla/service/hlo_constant_folding.h" @@ -126,6 +127,7 @@ void runHloPasses(xla::HloModule *module) { /*allow_mixed_precision=*/false); pipeline.AddPass(); + pipeline.AddPass(GatherExpander::kEliminateSimpleGathers); pipeline.AddPass(ScatterExpander::kEliminateAllScatters); pipeline.AddPass(options); pipeline.AddPass(); diff --git a/libspu/mpc/spdz2k/BUILD.bazel b/libspu/mpc/spdz2k/BUILD.bazel index 85c0699d..2a680ce1 100644 --- a/libspu/mpc/spdz2k/BUILD.bazel +++ b/libspu/mpc/spdz2k/BUILD.bazel @@ -111,8 +111,8 @@ spu_cc_library( ":boolean", ":state", ":type", - "//libspu/core:vectorize", "//libspu/core:xt_helper", + "//libspu/core:vectorize", "//libspu/mpc:ab_api", "//libspu/mpc:api", "//libspu/mpc:kernel", diff --git a/libspu/mpc/spdz2k/ot/BUILD.bazel b/libspu/mpc/spdz2k/ot/BUILD.bazel index 50651d6f..17604930 100644 --- a/libspu/mpc/spdz2k/ot/BUILD.bazel +++ b/libspu/mpc/spdz2k/ot/BUILD.bazel @@ -21,7 +21,7 @@ spu_cc_library( name = "ferret", hdrs = ["ferret.h"], deps = [ - "//libspu/mpc/cheetah/ot", + "//libspu/mpc/cheetah/ot:ot", "//libspu/mpc/common:communicator", ], ) diff --git a/spu/utils/frontend.py b/spu/utils/frontend.py index 5e6024f7..097c3a4e 100644 --- a/spu/utils/frontend.py +++ b/spu/utils/frontend.py @@ -29,6 +29,7 @@ def _jax_compilation_key( fn: Callable, static_argnums, static_argnames, args: List, kwargs: Dict ): import jax + import numpy as np from jax._src.api_util import argnames_partial_except, argnums_partial_except try: @@ -60,9 +61,15 @@ def _function_contents(func): f, dkwargs = argnames_partial_except(f, static_argnames, kwargs) f, dargs = argnums_partial_except(f, static_argnums, args, allow_invalid=True) - flat_args, _ = jax.tree_util.tree_flatten((dargs, dkwargs)) - types = [(a.dtype, a.shape) if hasattr(a, 'dtype') else 
type(a) for a in flat_args] - hash_str = f'{hash(f)}-{hash(_function_contents(fn))}-{types}' + flat_args, tree = jax.tree_util.tree_flatten((dargs, dkwargs)) + types = [] + for a in flat_args: + if hasattr(a, 'dtype'): + types.append((a.dtype, a.shape)) + else: + np_array = np.asarray(a) + types.append((np_array.dtype, np_array.shape)) + hash_str = f'{hash(f)}-{hash(_function_contents(fn))}-{types}-{hash(tree)}' return hash_str @@ -114,6 +121,7 @@ def _jax_compilation( cfn, output = jax.xla_computation( fn, return_shape=True, static_argnums=static_argnums, backend="interpreter" )(*args, **kwargs) + return cfn.as_serialized_hlo_module_proto(), output From 19a43322a9b1c4e8edfbca6824ce5f2084c75fc3 Mon Sep 17 00:00:00 2001 From: anakinxc <103552181+anakinxc@users.noreply.github.com> Date: Tue, 19 Dec 2023 18:07:52 +0800 Subject: [PATCH 2/3] Update README.md --- examples/python/ml/flax_llama7b/README.md | 22 ---------------------- 1 file changed, 22 deletions(-) diff --git a/examples/python/ml/flax_llama7b/README.md b/examples/python/ml/flax_llama7b/README.md index 949947dc..ab1f5a90 100644 --- a/examples/python/ml/flax_llama7b/README.md +++ b/examples/python/ml/flax_llama7b/README.md @@ -33,28 +33,6 @@ This example demonstrates how to use SPU to run secure inference on a pre-traine parser.add_argument("--streaming", action="store_true", default=False, help="whether is model weight saved stream format",) ``` - Since EasyLM have an issue,so we have to make a samll change to support the option "streaming=false". - Open and edit "convert_hf_to_easylm.py", chang this: - - ```python - parser.add_argument("--streaming", action="store_true", default=True, help="whether is model weight saved stream format",) - ``` - - to: - - ```python - parser.add_argument("--streaming", action="store_true", default=False, help="whether is model weight saved stream format",) - ``` - - Since EasyLM have an issue,so we have to make a samll change to support the option "streaming=false". 
- Open and edit "convert_hf_to_easylm.py", chang this: - ```python - parser.add_argument("--streaming", action="store_true", default=True, help="whether is model weight saved stream format",) - ``` - to: - ```python - parser.add_argument("--streaming", action="store_true", default=False, help="whether is model weight saved stream format",) - ``` Download trained LLaMA-B[PyTroch-Version] from [Hugging Face](https://huggingface.co/openlm-research/open_llama_7b) , and convert it to Flax.msgpack as: ```sh From 55b711038c798c18563c21e1de8f9740fdc5ddbe Mon Sep 17 00:00:00 2001 From: anakinxc <103552181+anakinxc@users.noreply.github.com> Date: Tue, 19 Dec 2023 10:11:38 +0000 Subject: [PATCH 3/3] format --- libspu/mpc/spdz2k/BUILD.bazel | 2 +- libspu/mpc/spdz2k/ot/BUILD.bazel | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/libspu/mpc/spdz2k/BUILD.bazel b/libspu/mpc/spdz2k/BUILD.bazel index 2a680ce1..85c0699d 100644 --- a/libspu/mpc/spdz2k/BUILD.bazel +++ b/libspu/mpc/spdz2k/BUILD.bazel @@ -111,8 +111,8 @@ spu_cc_library( ":boolean", ":state", ":type", - "//libspu/core:xt_helper", "//libspu/core:vectorize", + "//libspu/core:xt_helper", "//libspu/mpc:ab_api", "//libspu/mpc:api", "//libspu/mpc:kernel", diff --git a/libspu/mpc/spdz2k/ot/BUILD.bazel b/libspu/mpc/spdz2k/ot/BUILD.bazel index 17604930..50651d6f 100644 --- a/libspu/mpc/spdz2k/ot/BUILD.bazel +++ b/libspu/mpc/spdz2k/ot/BUILD.bazel @@ -21,7 +21,7 @@ spu_cc_library( name = "ferret", hdrs = ["ferret.h"], deps = [ - "//libspu/mpc/cheetah/ot:ot", + "//libspu/mpc/cheetah/ot", "//libspu/mpc/common:communicator", ], )
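
A note on the `convert_hf_to_easylm.py` tweak described in the README changes above: because the `--streaming` option is declared with `action="store_true"`, a `default=True` value can never be switched back to `False` from the command line, which is why the instructions flip the default to `False`. Below is a minimal, hedged sketch of that argparse behaviour; it is a standalone toy parser, not the real EasyLM script.

```python
import argparse

# Toy parser reproducing only the flag discussed in the README changes above.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--streaming",
    action="store_true",  # passing --streaming stores True
    default=False,        # omitting it now yields False (the patched behaviour)
    help="whether the model weight is saved in streaming format",
)

print(parser.parse_args([]).streaming)               # False -> single msgpack checkpoint
print(parser.parse_args(["--streaming"]).streaming)  # True  -> streaming format
```

With the original `default=True` there is no built-in way to turn the flag off, so every conversion would have produced a streaming checkpoint regardless of what the user passed.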
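
The `hlo_importer.cc` change in the first patch registers XLA's `GatherExpander` pass in `kEliminateSimpleGathers` mode right before the existing `ScatterExpander` pass, so that simple gather ops are lowered to more primitive HLO before SPU compiles them. For intuition only, here is a hedged JAX-side sketch of where such gather ops come from; the pass list in the diff is the authoritative part, this snippet just shows that ordinary fancy indexing is staged as a `gather` primitive.

```python
import jax
import jax.numpy as jnp

def take_rows(x, idx):
    # Advanced indexing like x[idx] is staged as a single `gather` op,
    # which passes such as GatherExpander may rewrite into simpler HLO.
    return x[idx]

x = jnp.arange(12.0).reshape(4, 3)
idx = jnp.array([0, 2])

print(jax.make_jaxpr(take_rows)(x, idx))  # the jaxpr contains a `gather` equation
print(take_rows(x, idx))                  # rows 0 and 2 of x
```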
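
The `spu/utils/frontend.py` change makes the JAX compilation cache key depend on the dtype and shape of every flattened argument (non-array leaves such as Python scalars are normalized with `np.asarray`) and on the pytree structure, instead of falling back to `type(a)`. The sketch below isolates that keying idea; it is a simplified stand-in that omits the static-argument handling and the function source hashing done by the real `_jax_compilation_key`.

```python
import jax
import numpy as np

def cache_key(args, kwargs):
    # Flatten positional and keyword arguments into leaves plus tree structure.
    flat_args, tree = jax.tree_util.tree_flatten((args, kwargs))
    types = []
    for a in flat_args:
        if hasattr(a, "dtype"):
            types.append((a.dtype, a.shape))
        else:
            # Normalizing scalars means 1 and 1.0 produce different keys.
            np_array = np.asarray(a)
            types.append((np_array.dtype, np_array.shape))
    # Hashing the tree structure distinguishes e.g. (x, y) from ((x, y),).
    return f"{types}-{hash(tree)}"

print(cache_key((np.ones((2, 3)), 1), {"scale": 2.0}))
```

Keying on dtype and shape rather than on values keeps the cache hit rate high while still forcing a recompilation whenever an argument's type or shape changes.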