Commit

repo-sync-2023-12-19T18:04:30+0800

anakinxc committed Dec 19, 2023
1 parent 654ef80 commit 0197308
Showing 9 changed files with 51 additions and 13 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -10,6 +10,8 @@
>
> please add your unreleased change here.
## TBD

- [Feature] Add Odd-Even Merge Sort to replace the bitonic sort
- [Feature] Add radix sort support for ABY3
- [Feature] Integrate with secretflow/psi
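For context on the first entry above: Batcher's odd-even merge sort is a data-oblivious sorting network, meaning the sequence of compare-exchange steps depends only on the input length and never on the values, which suits secret-shared execution. Below is a minimal plain-Python sketch of the classic network generator for power-of-two lengths (an illustration of the textbook algorithm, not SPU's implementation).

```python
# Batcher's odd-even merge sort as a comparator-network generator.
# The (i, j) pairs depend only on the length, never on the data.
# Illustration of the classic algorithm; not SPU's implementation.

def oddeven_merge(lo, hi, r):
    # Merge the r-strided subsequences of indices lo..hi (hi inclusive).
    step = r * 2
    if step < hi - lo:
        yield from oddeven_merge(lo, hi, step)
        yield from oddeven_merge(lo + r, hi, step)
        yield from ((i, i + r) for i in range(lo + r, hi - r, step))
    else:
        yield (lo, lo + r)

def oddeven_merge_sort(lo, hi):
    # Emit comparators that sort indices lo..hi (length must be a power of two).
    if hi - lo >= 1:
        mid = lo + (hi - lo) // 2
        yield from oddeven_merge_sort(lo, mid)
        yield from oddeven_merge_sort(mid + 1, hi)
        yield from oddeven_merge(lo, hi, 1)

def sort_in_place(xs):
    for i, j in oddeven_merge_sort(0, len(xs) - 1):
        if xs[i] > xs[j]:
            xs[i], xs[j] = xs[j], xs[i]
    return xs

print(sort_in_place([7, 3, 0, 5, 6, 1, 4, 2]))  # [0, 1, 2, 3, 4, 5, 6, 7]
```

For the same width, the odd-even merge network needs fewer comparators than a bitonic network, which is presumably part of why it replaces the bitonic sort here.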
2 changes: 1 addition & 1 deletion docs/reference/pphlo_doc.rst
@@ -3,7 +3,7 @@ PPHlo API reference

PPHlo is short for (SPU High level ops), it's the assembly language of SPU.

PPHlo is built on `MLIR <https://mlir.llvm.org/>`_ infrastructure, the concrete ops definition could be found :spu_code_host:`here </spu/blob/master/spu%2Fdialect%2Fpphlo_ops.td>`.
PPHlo is built on `MLIR <https://mlir.llvm.org/>`_ infrastructure, the concrete ops definition could be found :spu_code_host:`here </spu/blob/master/libspu%2Fdialect%2Fpphlo_ops.td>`.

Op List
~~~~~~~
8 changes: 4 additions & 4 deletions docs/tutorials/develop_your_first_mpc_application.ipynb
@@ -1129,11 +1129,11 @@
"\n",
"Before diving into the problem deeper, we highly recommend the reader to read: \n",
"\n",
" 1. [spu_inside](https://www.secretflow.org.cn/docs/spu/en/getting_started/tutorials/spu_inside.html#Tracing): gives some introductions how spu works inside for float-point operations.\n",
" 1. [spu_inside](https://www.secretflow.org.cn/docs/spu/latest/en-US/tutorials/spu_inside#Tracing): gives some introductions how spu works inside for float-point operations.\n",
"\n",
" 2. [pitfall](https://www.secretflow.org.cn/docs/spu/en/reference/fxp.html): spu implements math function(like `reciprocal`, `log` and so on) with approximation algorithm, so some precision issue will occur when inputs fall into some intervals. We list some known issue about this.\n",
" 2. [pitfall](https://www.secretflow.org.cn/docs/spu/latest/en-US/development/fxp): spu implements math function(like `reciprocal`, `log` and so on) with approximation algorithm, so some precision issue will occur when inputs fall into some intervals. We list some known issue about this.\n",
"\n",
"3. [protocols](https://www.secretflow.org.cn/docs/spu/en/reference/mpc_status.html): list all protocols spu implements now. Generally speaking, for 2pc, it's safe to use cheetah, while for 3pc, ABY3 is the only choice.\n",
"3. [protocols](https://www.secretflow.org.cn/docs/spu/latest/en-US/reference/mpc_status): list all protocols spu implements now. Generally speaking, for 2pc, it's safe to use cheetah, while for 3pc, ABY3 is the only choice.\n",
"\n",
"First define a function just like `fit_and_predict` to get dk_arr."
]
@@ -1659,7 +1659,7 @@
"source": [
"#### Rsqrt v.s. Norm\n",
"\n",
"We can recall that the `compute_dk_func` function defined in Part 1 contains a `method` arg, and we just ignore this arg before. Indeed, we can tell simulator to print more information like [spu_inside](https://www.secretflow.org.cn/docs/spu/en/getting_started/tutorials/spu_inside.html#Tracing) do: enable **hlo**(High Level Operations) trace and profile on. Then we can figure out which op has been invoked and its time cost.\n",
"We can recall that the `compute_dk_func` function defined in Part 1 contains a `method` arg, and we just ignore this arg before. Indeed, we can tell simulator to print more information like [spu_inside](https://www.secretflow.org.cn/docs/spu/latest/en-US/tutorials/spu_inside#Tracing) do: enable **hlo**(High Level Operations) trace and profile on. Then we can figure out which op has been invoked and its time cost.\n",
"\n",
"Here, we list some advantages of using `jax.lax.rsqrt` rather than `jnp.linalg.norm`:\n",
"\n",
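To make the comparison above concrete, here is a minimal plain-JAX sketch of the two variants (it runs without the SPU simulator; the relative-cost remark in the comments restates the tutorial's argument rather than a measurement).

```python
# Two ways to normalize a vector. Under MPC, multiplying by an approximated
# reciprocal square root avoids the explicit division that norm-then-divide
# needs, which is the advantage the tutorial points at.
import jax.numpy as jnp
from jax import lax

def normalize_with_norm(x):
    return x / jnp.linalg.norm(x)          # sqrt followed by a division

def normalize_with_rsqrt(x):
    return x * lax.rsqrt(jnp.sum(x * x))   # one multiply by rsqrt(sum of squares)

x = jnp.array([3.0, 4.0])
print(normalize_with_norm(x))    # [0.6 0.8]
print(normalize_with_rsqrt(x))   # [0.6 0.8]
```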
31 changes: 28 additions & 3 deletions examples/python/ml/flax_llama7b/README.md
@@ -20,6 +26,32 @@ This example demonstrates how to use SPU to run secure inference on a pre-traine
cd ./EasyLM/models/llama
```

Since EasyLM has an issue, we have to make a small change to support the option "streaming=false".
Open "convert_hf_to_easylm.py" and change this:

```python
parser.add_argument("--streaming", action="store_true", default=True, help="whether is model weight saved stream format",)
```

to:

```python
parser.add_argument("--streaming", action="store_true", default=False, help="whether is model weight saved stream format",)
```
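A short, self-contained note on why the default matters (standard `argparse` behavior, not anything SPU-specific): with `action="store_true"` the flag can only ever switch the value to `True`, so `default=True` leaves no way to get `streaming=False` from the command line; changing the default to `False` restores both choices.

```python
# Standalone illustration of the argparse behavior described above.
import argparse

p = argparse.ArgumentParser()
p.add_argument("--streaming", action="store_true", default=False,
               help="whether the model weight is saved in stream format")

print(p.parse_args([]).streaming)               # False: flag omitted
print(p.parse_args(["--streaming"]).streaming)  # True: flag given
# With default=True, both calls would print True, so False is unreachable.
```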

@@ -31,7 +57,7 @@ This example demonstrates how to use SPU to run secure inference on a pre-traine
```
Download the trained LLaMA-7B [PyTorch version] from [Hugging Face](https://huggingface.co/openlm-research/open_llama_7b)
, and convert it to Flax.msgpack as:
```sh
```sh
python convert_hf_to_easylm.py \
--checkpoint_dir path-to-flax-llama7b-dir \
--output_file path-to-flax-llama7b-EasyLM.msgpack \
@@ -45,8 +71,7 @@ This example demonstrates how to use SPU to run secure inference on a pre-traine
bazel run -c opt //examples/python/utils:nodectl -- --config `pwd`/examples/python/ml/flax_llama7b/3pc.json up
```

or
(recommended)
or (recommended)

```sh
cd examples/python/utils
1 change: 1 addition & 0 deletions libspu/compiler/front_end/BUILD.bazel
@@ -39,6 +39,7 @@ spu_cc_library(
"@xla//xla/service:dot_decomposer",
"@xla//xla/service:eigh_expander",
"@xla//xla/service:float_normalization",
"@xla//xla/service:gather_expander",
"@xla//xla/service:gather_simplifier",
"@xla//xla/service:hlo_constant_folding",
"@xla//xla/service:hlo_cse",
2 changes: 2 additions & 0 deletions libspu/compiler/front_end/hlo_importer.cc
@@ -28,6 +28,7 @@
#include "xla/service/eigh_expander.h"
#include "xla/service/float_normalization.h"
#include "xla/service/float_support.h"
#include "xla/service/gather_expander.h"
#include "xla/service/gather_simplifier.h"
#include "xla/service/gpu/dot_dimension_sorter.h"
#include "xla/service/hlo_constant_folding.h"
@@ -126,6 +127,7 @@ void runHloPasses(xla::HloModule *module) {
/*allow_mixed_precision=*/false);

pipeline.AddPass<GatherSimplifier>();
pipeline.AddPass<GatherExpander>(GatherExpander::kEliminateSimpleGathers);
pipeline.AddPass<ScatterExpander>(ScatterExpander::kEliminateAllScatters);
pipeline.AddPass<AlgebraicSimplifier>(options);
pipeline.AddPass<BitcastDtypesExpander>();
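The pass list above is C++, but a small JAX sketch (hypothetical and illustrative, not SPU code) shows the kind of frontend program whose lowered HLO contains a gather op; as we understand XLA's `GatherExpander`, the `kEliminateSimpleGathers` mode expands only gathers simple enough to need no loop, leaving the rest to later passes.

```python
# Hypothetical illustration: a JAX function whose XLA lowering contains a
# gather HLO op, which passes like GatherSimplifier / GatherExpander then
# canonicalize or expand before SPU compiles the module further.
import jax
import jax.numpy as jnp

def pick_rows(table, idx):
    return jnp.take(table, idx, axis=0)  # lowers to a gather in HLO

table = jnp.arange(12.0).reshape(3, 4)
idx = jnp.array([2, 0])

print(pick_rows(table, idx))
# In recent JAX versions, the lowered text makes the gather visible:
print(jax.jit(pick_rows).lower(table, idx).as_text())
```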
2 changes: 1 addition & 1 deletion libspu/mpc/spdz2k/BUILD.bazel
@@ -111,8 +111,8 @@ spu_cc_library(
":boolean",
":state",
":type",
"//libspu/core:vectorize",
"//libspu/core:xt_helper",
"//libspu/core:vectorize",
"//libspu/mpc:ab_api",
"//libspu/mpc:api",
"//libspu/mpc:kernel",
2 changes: 1 addition & 1 deletion libspu/mpc/spdz2k/ot/BUILD.bazel
@@ -21,7 +21,7 @@ spu_cc_library(
name = "ferret",
hdrs = ["ferret.h"],
deps = [
"//libspu/mpc/cheetah/ot",
"//libspu/mpc/cheetah/ot:ot",
"//libspu/mpc/common:communicator",
],
)
14 changes: 11 additions & 3 deletions spu/utils/frontend.py
@@ -29,6 +29,7 @@ def _jax_compilation_key(
fn: Callable, static_argnums, static_argnames, args: List, kwargs: Dict
):
import jax
import numpy as np
from jax._src.api_util import argnames_partial_except, argnums_partial_except

try:
@@ -60,9 +61,15 @@ def _function_contents(func):
f, dkwargs = argnames_partial_except(f, static_argnames, kwargs)
f, dargs = argnums_partial_except(f, static_argnums, args, allow_invalid=True)

flat_args, _ = jax.tree_util.tree_flatten((dargs, dkwargs))
types = [(a.dtype, a.shape) if hasattr(a, 'dtype') else type(a) for a in flat_args]
hash_str = f'{hash(f)}-{hash(_function_contents(fn))}-{types}'
flat_args, tree = jax.tree_util.tree_flatten((dargs, dkwargs))
types = []
for a in flat_args:
if hasattr(a, 'dtype'):
types.append((a.dtype, a.shape))
else:
np_array = np.asarray(a)
types.append((np_array.dtype, np_array.shape))
hash_str = f'{hash(f)}-{hash(_function_contents(fn))}-{types}-{hash(tree)}'

return hash_str
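To show the intent of the new key, here is a minimal standalone sketch (the helper name is hypothetical, not SPU's API): both the dtype/shape of every flattened leaf and the hash of the pytree structure go into the key, so two calls with identical leaves but different container structure cannot share a cached compilation, and plain Python scalars are normalized through `np.asarray` just like arrays.

```python
# Hypothetical sketch of the keying idea above; not SPU's actual API.
import jax
import numpy as np

def sketch_cache_key(args, kwargs):
    flat, tree = jax.tree_util.tree_flatten((args, kwargs))
    types = []
    for leaf in flat:
        if hasattr(leaf, "dtype"):
            types.append((leaf.dtype, leaf.shape))
        else:
            arr = np.asarray(leaf)  # scalars/lists get a dtype and shape too
            types.append((arr.dtype, arr.shape))
    return f"{types}-{hash(tree)}"

a = np.ones((2, 2), dtype=np.float32)
# Same leaves, different pytree structure -> different keys.
print(sketch_cache_key((a, 1.0), {}))
print(sketch_cache_key(((a, 1.0),), {}))
```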

@@ -114,6 +121,7 @@ def _jax_compilation(
cfn, output = jax.xla_computation(
fn, return_shape=True, static_argnums=static_argnums, backend="interpreter"
)(*args, **kwargs)

return cfn.as_serialized_hlo_module_proto(), output


