diff --git a/CHANGELOG.md b/CHANGELOG.md
index afd2d802..614416ac 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,6 +10,8 @@
>
> please add your unreleased change here.
+## TBD
+
- [Feature] Add Odd-Even Merge Sort to replace the bitonic sort
- [Feature] Add radix sort support for ABY3
- [Feature] Integrate with secretflow/psi
diff --git a/docs/reference/pphlo_doc.rst b/docs/reference/pphlo_doc.rst
index 48f61e20..e539ab63 100644
--- a/docs/reference/pphlo_doc.rst
+++ b/docs/reference/pphlo_doc.rst
@@ -3,7 +3,7 @@ PPHlo API reference
PPHlo is short for SPU High Level Ops; it is the assembly language of SPU.
-PPHlo is built on `MLIR `_ infrastructure, the concrete ops definition could be found :spu_code_host:`here `.
+PPHlo is built on the `MLIR `_ infrastructure; the concrete op definitions can be found :spu_code_host:`here `.
Op List
~~~~~~~
diff --git a/docs/tutorials/develop_your_first_mpc_application.ipynb b/docs/tutorials/develop_your_first_mpc_application.ipynb
index 57908d91..d89ca5e1 100644
--- a/docs/tutorials/develop_your_first_mpc_application.ipynb
+++ b/docs/tutorials/develop_your_first_mpc_application.ipynb
@@ -1129,11 +1129,11 @@
"\n",
"Before diving into the problem deeper, we highly recommend the reader to read: \n",
"\n",
- " 1. [spu_inside](https://www.secretflow.org.cn/docs/spu/en/getting_started/tutorials/spu_inside.html#Tracing): gives some introductions how spu works inside for float-point operations.\n",
+ " 1. [spu_inside](https://www.secretflow.org.cn/docs/spu/latest/en-US/tutorials/spu_inside#Tracing): gives some introductions how spu works inside for float-point operations.\n",
"\n",
- " 2. [pitfall](https://www.secretflow.org.cn/docs/spu/en/reference/fxp.html): spu implements math function(like `reciprocal`, `log` and so on) with approximation algorithm, so some precision issue will occur when inputs fall into some intervals. We list some known issue about this.\n",
+ " 2. [pitfall](https://www.secretflow.org.cn/docs/spu/latest/en-US/development/fxp): spu implements math function(like `reciprocal`, `log` and so on) with approximation algorithm, so some precision issue will occur when inputs fall into some intervals. We list some known issue about this.\n",
"\n",
- "3. [protocols](https://www.secretflow.org.cn/docs/spu/en/reference/mpc_status.html): list all protocols spu implements now. Generally speaking, for 2pc, it's safe to use cheetah, while for 3pc, ABY3 is the only choice.\n",
+ "3. [protocols](https://www.secretflow.org.cn/docs/spu/latest/en-US/reference/mpc_status): list all protocols spu implements now. Generally speaking, for 2pc, it's safe to use cheetah, while for 3pc, ABY3 is the only choice.\n",
"\n",
"First define a function just like `fit_and_predict` to get dk_arr."
]
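The pitfall item above can be made concrete. The following is a minimal, hedged sketch (assuming the `spu.utils.simulation` helpers used elsewhere in this tutorial; the `recip` function is illustrative): it compares a plaintext `jnp.reciprocal` with the same function under 3PC ABY3 simulation, where the fixed-point approximation may drift for inputs near the edges of its working range.

```python
# Hedged sketch: compare a plaintext op with its SPU fixed-point approximation.
# Assumes the spu.utils.simulation helpers used in this tutorial.
import jax.numpy as jnp
import numpy as np

import spu
import spu.utils.simulation as spsim


def recip(x):
    return jnp.reciprocal(x)


# 3-party ABY3 simulator over a 64-bit ring, matching the tutorial's setup.
sim = spsim.Simulator.simple(3, spu.ProtocolKind.ABY3, spu.FieldType.FM64)

x = np.array([1e-4, 0.5, 100.0], dtype=np.float32)
print(recip(x))                      # plaintext reference
print(spsim.sim_jax(sim, recip)(x))  # approximated; may deviate for extreme inputs
```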
@@ -1659,7 +1659,7 @@
"source": [
"#### Rsqrt v.s. Norm\n",
"\n",
- "We can recall that the `compute_dk_func` function defined in Part 1 contains a `method` arg, and we just ignore this arg before. Indeed, we can tell simulator to print more information like [spu_inside](https://www.secretflow.org.cn/docs/spu/en/getting_started/tutorials/spu_inside.html#Tracing) do: enable **hlo**(High Level Operations) trace and profile on. Then we can figure out which op has been invoked and its time cost.\n",
+ "We can recall that the `compute_dk_func` function defined in Part 1 contains a `method` arg, and we just ignore this arg before. Indeed, we can tell simulator to print more information like [spu_inside](https://www.secretflow.org.cn/docs/spu/latest/en-US/tutorials/spu_inside#Tracing) do: enable **hlo**(High Level Operations) trace and profile on. Then we can figure out which op has been invoked and its time cost.\n",
"\n",
"Here, we list some advantages of using `jax.lax.rsqrt` rather than `jnp.linalg.norm`:\n",
"\n",
diff --git a/examples/python/ml/flax_llama7b/README.md b/examples/python/ml/flax_llama7b/README.md
index 4410abe4..ab1f5a90 100644
--- a/examples/python/ml/flax_llama7b/README.md
+++ b/examples/python/ml/flax_llama7b/README.md
@@ -22,16 +22,20 @@ This example demonstrates how to use SPU to run secure inference on a pre-traine
Since EasyLM has an issue, we have to make a small change to support the option "streaming=false".
Open and edit "convert_hf_to_easylm.py", and change this:
+
```python
parser.add_argument("--streaming", action="store_true", default=True, help="whether is model weight saved stream format",)
```
+
to:
+
```python
parser.add_argument("--streaming", action="store_true", default=False, help="whether is model weight saved stream format",)
```
+
Download the trained LLaMA-7B [PyTorch version] from [Hugging Face](https://huggingface.co/openlm-research/open_llama_7b),
and convert it to a Flax `.msgpack` file as follows:
- ```sh
+ ```sh
python convert_hf_to_easylm.py \
--checkpoint_dir path-to-flax-llama7b-dir \
--output_file path-to-flax-llama7b-EasyLM.msgpack \
@@ -45,8 +49,7 @@ This example demonstrates how to use SPU to run secure inference on a pre-traine
bazel run -c opt //examples/python/utils:nodectl -- --config `pwd`/examples/python/ml/flax_llama7b/3pc.json up
```
- or
- (recommended)
+ or (recommended)
```sh
cd examples/python/utils
diff --git a/libspu/compiler/front_end/BUILD.bazel b/libspu/compiler/front_end/BUILD.bazel
index c7e8c6f4..8a6e1f97 100644
--- a/libspu/compiler/front_end/BUILD.bazel
+++ b/libspu/compiler/front_end/BUILD.bazel
@@ -39,6 +39,7 @@ spu_cc_library(
"@xla//xla/service:dot_decomposer",
"@xla//xla/service:eigh_expander",
"@xla//xla/service:float_normalization",
+ "@xla//xla/service:gather_expander",
"@xla//xla/service:gather_simplifier",
"@xla//xla/service:hlo_constant_folding",
"@xla//xla/service:hlo_cse",
diff --git a/libspu/compiler/front_end/hlo_importer.cc b/libspu/compiler/front_end/hlo_importer.cc
index 296406d7..149f5a4b 100644
--- a/libspu/compiler/front_end/hlo_importer.cc
+++ b/libspu/compiler/front_end/hlo_importer.cc
@@ -28,6 +28,7 @@
#include "xla/service/eigh_expander.h"
#include "xla/service/float_normalization.h"
#include "xla/service/float_support.h"
+#include "xla/service/gather_expander.h"
#include "xla/service/gather_simplifier.h"
#include "xla/service/gpu/dot_dimension_sorter.h"
#include "xla/service/hlo_constant_folding.h"
@@ -126,6 +127,7 @@ void runHloPasses(xla::HloModule *module) {
/*allow_mixed_precision=*/false);
pipeline.AddPass();
+ pipeline.AddPass<GatherExpander>(GatherExpander::kEliminateSimpleGathers);
pipeline.AddPass<ScatterExpander>(ScatterExpander::kEliminateAllScatters);
pipeline.AddPass(options);
pipeline.AddPass();
diff --git a/spu/utils/frontend.py b/spu/utils/frontend.py
index 5e6024f7..097c3a4e 100644
--- a/spu/utils/frontend.py
+++ b/spu/utils/frontend.py
@@ -29,6 +29,7 @@ def _jax_compilation_key(
fn: Callable, static_argnums, static_argnames, args: List, kwargs: Dict
):
import jax
+ import numpy as np
from jax._src.api_util import argnames_partial_except, argnums_partial_except
try:
@@ -60,9 +61,15 @@ def _function_contents(func):
f, dkwargs = argnames_partial_except(f, static_argnames, kwargs)
f, dargs = argnums_partial_except(f, static_argnums, args, allow_invalid=True)
- flat_args, _ = jax.tree_util.tree_flatten((dargs, dkwargs))
- types = [(a.dtype, a.shape) if hasattr(a, 'dtype') else type(a) for a in flat_args]
- hash_str = f'{hash(f)}-{hash(_function_contents(fn))}-{types}'
+ flat_args, tree = jax.tree_util.tree_flatten((dargs, dkwargs))
+ types = []
+ for a in flat_args:
+ if hasattr(a, 'dtype'):
+ types.append((a.dtype, a.shape))
+ else:
+ np_array = np.asarray(a)
+ types.append((np_array.dtype, np_array.shape))
+ hash_str = f'{hash(f)}-{hash(_function_contents(fn))}-{types}-{hash(tree)}'
return hash_str
@@ -114,6 +121,7 @@ def _jax_compilation(
cfn, output = jax.xla_computation(
fn, return_shape=True, static_argnums=static_argnums, backend="interpreter"
)(*args, **kwargs)
+
return cfn.as_serialized_hlo_module_proto(), output
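To illustrate why the cache key in `_jax_compilation_key` now hashes both the leaf dtypes/shapes and the pytree structure, here is a hedged, standalone sketch (the `cache_key_types` helper is hypothetical and not part of `spu.utils.frontend`): two calls with identical leaves but different argument structure produce different keys, so they no longer collide in the compilation cache.

```python
# Hedged, standalone illustration of the cache-key logic above.
# cache_key_types is a hypothetical helper, not part of spu.utils.frontend.
import jax
import numpy as np


def cache_key_types(*args, **kwargs):
    flat_args, tree = jax.tree_util.tree_flatten((args, kwargs))
    types = []
    for a in flat_args:
        if hasattr(a, 'dtype'):
            types.append((a.dtype, a.shape))
        else:
            # Plain Python scalars still contribute a (dtype, shape) pair.
            np_array = np.asarray(a)
            types.append((np_array.dtype, np_array.shape))
    return f'{types}-{hash(tree)}'


# Same two float leaves, different pytree structure: the tree hash keeps the
# keys distinct, so the calls do not share a cached compilation.
print(cache_key_types([1.0, 2.0]))  # one list argument
print(cache_key_types(1.0, 2.0))    # two scalar arguments
```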