diff --git a/CHANGELOG.md b/CHANGELOG.md
index afd2d802..614416ac 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,6 +10,8 @@
 >
 > please add your unreleased change here.
 
+## TBD
+
 - [Feature] Add Odd-Even Merge Sort to replace the bitonic sort
 - [Feature] Add radix sort support for ABY3
 - [Feature] Integrate with secretflow/psi
diff --git a/docs/reference/pphlo_doc.rst b/docs/reference/pphlo_doc.rst
index 48f61e20..e539ab63 100644
--- a/docs/reference/pphlo_doc.rst
+++ b/docs/reference/pphlo_doc.rst
@@ -3,7 +3,7 @@ PPHlo API reference
 
 PPHlo is short for (SPU High level ops), it's the assembly language of SPU.
 
-PPHlo is built on `MLIR `_ infrastructure, the concrete ops definition could be found :spu_code_host:`here `.
+PPHlo is built on `MLIR `_ infrastructure, the concrete ops definition could be found :spu_code_host:`here `.
 
 Op List
 ~~~~~~~
diff --git a/docs/tutorials/develop_your_first_mpc_application.ipynb b/docs/tutorials/develop_your_first_mpc_application.ipynb
index 57908d91..d89ca5e1 100644
--- a/docs/tutorials/develop_your_first_mpc_application.ipynb
+++ b/docs/tutorials/develop_your_first_mpc_application.ipynb
@@ -1129,11 +1129,11 @@
     "\n",
     "Before diving into the problem deeper, we highly recommend the reader to read: \n",
     "\n",
-    " 1. [spu_inside](https://www.secretflow.org.cn/docs/spu/en/getting_started/tutorials/spu_inside.html#Tracing): gives some introductions how spu works inside for float-point operations.\n",
+    " 1. [spu_inside](https://www.secretflow.org.cn/docs/spu/latest/en-US/tutorials/spu_inside#Tracing): gives some introductions how spu works inside for float-point operations.\n",
     "\n",
-    " 2. [pitfall](https://www.secretflow.org.cn/docs/spu/en/reference/fxp.html): spu implements math function(like `reciprocal`, `log` and so on) with approximation algorithm, so some precision issue will occur when inputs fall into some intervals. We list some known issue about this.\n",
+    " 2. [pitfall](https://www.secretflow.org.cn/docs/spu/latest/en-US/development/fxp): spu implements math function(like `reciprocal`, `log` and so on) with approximation algorithm, so some precision issue will occur when inputs fall into some intervals. We list some known issue about this.\n",
     "\n",
-    "3. [protocols](https://www.secretflow.org.cn/docs/spu/en/reference/mpc_status.html): list all protocols spu implements now. Generally speaking, for 2pc, it's safe to use cheetah, while for 3pc, ABY3 is the only choice.\n",
+    "3. [protocols](https://www.secretflow.org.cn/docs/spu/latest/en-US/reference/mpc_status): list all protocols spu implements now. Generally speaking, for 2pc, it's safe to use cheetah, while for 3pc, ABY3 is the only choice.\n",
     "\n",
     "First define a function just like `fit_and_predict` to get dk_arr."
    ]
@@ -1659,7 +1659,7 @@
    "source": [
     "#### Rsqrt v.s. Norm\n",
     "\n",
-    "We can recall that the `compute_dk_func` function defined in Part 1 contains a `method` arg, and we just ignore this arg before. Indeed, we can tell simulator to print more information like [spu_inside](https://www.secretflow.org.cn/docs/spu/en/getting_started/tutorials/spu_inside.html#Tracing) do: enable **hlo**(High Level Operations) trace and profile on. Then we can figure out which op has been invoked and its time cost.\n",
+    "We can recall that the `compute_dk_func` function defined in Part 1 contains a `method` arg, and we just ignore this arg before. Indeed, we can tell simulator to print more information like [spu_inside](https://www.secretflow.org.cn/docs/spu/latest/en-US/tutorials/spu_inside#Tracing) do: enable **hlo**(High Level Operations) trace and profile on. Then we can figure out which op has been invoked and its time cost.\n",
     "\n",
     "Here, we list some advantages of using `jax.lax.rsqrt` rather than `jnp.linalg.norm`:\n",
     "\n",
diff --git a/examples/python/ml/flax_llama7b/README.md b/examples/python/ml/flax_llama7b/README.md
index 4410abe4..ab1f5a90 100644
--- a/examples/python/ml/flax_llama7b/README.md
+++ b/examples/python/ml/flax_llama7b/README.md
@@ -22,16 +22,20 @@ This example demonstrates how to use SPU to run secure inference on a pre-traine
 Since EasyLM has an issue, we have to make a small change to support the option "streaming=false".
 Open and edit "convert_hf_to_easylm.py" and change this:
+
 ```python
 parser.add_argument("--streaming", action="store_true", default=True, help="whether is model weight saved stream format",)
 ```
+
 to:
+
 ```python
 parser.add_argument("--streaming", action="store_true", default=False, help="whether is model weight saved stream format",)
 ```
+
 Download the trained LLaMA-7B [PyTorch version] from [Hugging Face](https://huggingface.co/openlm-research/open_llama_7b), and convert it to Flax.msgpack as:
-```sh
+```sh
 python convert_hf_to_easylm.py \
 --checkpoint_dir path-to-flax-llama7b-dir \
 --output_file path-to-flax-llama7b-EasyLM.msgpack \
@@ -45,8 +49,7 @@
 bazel run -c opt //examples/python/utils:nodectl -- --config `pwd`/examples/python/ml/flax_llama7b/3pc.json up
 ```
 
-or
-(recommended)
+or (recommended)
 
 ```sh
 cd examples/python/utils
diff --git a/libspu/compiler/front_end/BUILD.bazel b/libspu/compiler/front_end/BUILD.bazel
index c7e8c6f4..8a6e1f97 100644
--- a/libspu/compiler/front_end/BUILD.bazel
+++ b/libspu/compiler/front_end/BUILD.bazel
@@ -39,6 +39,7 @@ spu_cc_library(
         "@xla//xla/service:dot_decomposer",
         "@xla//xla/service:eigh_expander",
         "@xla//xla/service:float_normalization",
+        "@xla//xla/service:gather_expander",
         "@xla//xla/service:gather_simplifier",
         "@xla//xla/service:hlo_constant_folding",
         "@xla//xla/service:hlo_cse",
diff --git a/libspu/compiler/front_end/hlo_importer.cc b/libspu/compiler/front_end/hlo_importer.cc
index 296406d7..149f5a4b 100644
--- a/libspu/compiler/front_end/hlo_importer.cc
+++ b/libspu/compiler/front_end/hlo_importer.cc
@@ -28,6 +28,7 @@
 #include "xla/service/eigh_expander.h"
 #include "xla/service/float_normalization.h"
 #include "xla/service/float_support.h"
+#include "xla/service/gather_expander.h"
 #include "xla/service/gather_simplifier.h"
 #include "xla/service/gpu/dot_dimension_sorter.h"
 #include "xla/service/hlo_constant_folding.h"
@@ -126,6 +127,7 @@ void runHloPasses(xla::HloModule *module) {
                                  /*allow_mixed_precision=*/false);
   pipeline.AddPass();
+  pipeline.AddPass<GatherExpander>(GatherExpander::kEliminateSimpleGathers);
   pipeline.AddPass<ScatterExpander>(ScatterExpander::kEliminateAllScatters);
   pipeline.AddPass(options);
   pipeline.AddPass();
diff --git a/spu/utils/frontend.py b/spu/utils/frontend.py
index 5e6024f7..097c3a4e 100644
--- a/spu/utils/frontend.py
+++ b/spu/utils/frontend.py
@@ -29,6 +29,7 @@ def _jax_compilation_key(
     fn: Callable, static_argnums, static_argnames, args: List, kwargs: Dict
 ):
     import jax
+    import numpy as np
     from jax._src.api_util import argnames_partial_except, argnums_partial_except
 
     try:
@@ -60,9 +61,15 @@ def _function_contents(func):
     f, dkwargs = argnames_partial_except(f, static_argnames, kwargs)
     f, dargs = argnums_partial_except(f, static_argnums, args, allow_invalid=True)
 
-    flat_args, _ = jax.tree_util.tree_flatten((dargs, dkwargs))
-    types = [(a.dtype, a.shape) if hasattr(a, 'dtype') else type(a) for a in flat_args]
-    hash_str = f'{hash(f)}-{hash(_function_contents(fn))}-{types}'
+    flat_args, tree = jax.tree_util.tree_flatten((dargs, dkwargs))
+    types = []
+    for a in flat_args:
+        if hasattr(a, 'dtype'):
+            types.append((a.dtype, a.shape))
+        else:
+            np_array = np.asarray(a)
+            types.append((np_array.dtype, np_array.shape))
+    hash_str = f'{hash(f)}-{hash(_function_contents(fn))}-{types}-{hash(tree)}'
 
     return hash_str
 
@@ -114,6 +121,7 @@ def _jax_compilation(
     cfn, output = jax.xla_computation(
         fn, return_shape=True, static_argnums=static_argnums, backend="interpreter"
     )(*args, **kwargs)
+
     return cfn.as_serialized_hlo_module_proto(), output
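The `_jax_compilation_key` hunk above changes how the compilation-cache key is built: every flattened argument now contributes a `(dtype, shape)` pair (plain Python scalars are normalized through `np.asarray`), and a hash of the pytree structure is appended so that calls with identical leaves but different nesting no longer collide. Below is a minimal, self-contained sketch of that idea; `toy_cache_key` is an illustrative name and not part of SPU.

```python
import jax
import numpy as np


def toy_cache_key(args, kwargs):
    """Illustrative only: mimics the (dtype, shape) + tree-structure keying."""
    # Flatten the (args, kwargs) pytree into leaves plus a tree definition.
    leaves, tree = jax.tree_util.tree_flatten((args, kwargs))
    types = []
    for leaf in leaves:
        if hasattr(leaf, 'dtype'):
            # Array-like leaves already carry dtype/shape information.
            types.append((leaf.dtype, leaf.shape))
        else:
            # Plain Python scalars (int/float/bool) go through numpy so they
            # also contribute a concrete (dtype, shape) pair.
            arr = np.asarray(leaf)
            types.append((arr.dtype, arr.shape))
    # hash(tree) separates calls whose leaves match but whose nesting differs.
    return f'{types}-{hash(tree)}'


# Same leaves, different pytree structure -> different keys.
print(toy_cache_key((1.5, np.ones((2, 3))), {'flag': True}))
print(toy_cache_key(((1.5, np.ones((2, 3))), True), {}))
```

In this sketch a Python `1.5` and a NumPy `float64` scalar produce the same `(dtype, shape)` entry, which is the point of routing scalars through `np.asarray` instead of hashing `type(a)`.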