Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Repo sync #444

Merged
merged 3 commits into from
Dec 19, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
>
> please add your unreleased change here.

## TBD

- [Feature] Add Odd-Even Merge Sort to replace the bitonic sort
- [Feature] Add radix sort support for ABY3
- [Feature] Integrate with secretflow/psi
Expand Down
2 changes: 1 addition & 1 deletion docs/reference/pphlo_doc.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ PPHlo API reference

PPHlo is short for (SPU High level ops), it's the assembly language of SPU.

PPHlo is built on `MLIR <https://mlir.llvm.org/>`_ infrastructure, the concrete ops definition could be found :spu_code_host:`here </spu/blob/master/spu%2Fdialect%2Fpphlo_ops.td>`.
PPHlo is built on `MLIR <https://mlir.llvm.org/>`_ infrastructure, the concrete ops definition could be found :spu_code_host:`here </spu/blob/master/libspu%2Fdialect%2Fpphlo_ops.td>`.

Op List
~~~~~~~
Expand Down
8 changes: 4 additions & 4 deletions docs/tutorials/develop_your_first_mpc_application.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1129,11 +1129,11 @@
"\n",
"Before diving into the problem deeper, we highly recommend the reader to read: \n",
"\n",
" 1. [spu_inside](https://www.secretflow.org.cn/docs/spu/en/getting_started/tutorials/spu_inside.html#Tracing): gives some introductions how spu works inside for float-point operations.\n",
" 1. [spu_inside](https://www.secretflow.org.cn/docs/spu/latest/en-US/tutorials/spu_inside#Tracing): gives some introductions how spu works inside for float-point operations.\n",
"\n",
" 2. [pitfall](https://www.secretflow.org.cn/docs/spu/en/reference/fxp.html): spu implements math function(like `reciprocal`, `log` and so on) with approximation algorithm, so some precision issue will occur when inputs fall into some intervals. We list some known issue about this.\n",
" 2. [pitfall](https://www.secretflow.org.cn/docs/spu/latest/en-US/development/fxp): spu implements math function(like `reciprocal`, `log` and so on) with approximation algorithm, so some precision issue will occur when inputs fall into some intervals. We list some known issue about this.\n",
"\n",
"3. [protocols](https://www.secretflow.org.cn/docs/spu/en/reference/mpc_status.html): list all protocols spu implements now. Generally speaking, for 2pc, it's safe to use cheetah, while for 3pc, ABY3 is the only choice.\n",
"3. [protocols](https://www.secretflow.org.cn/docs/spu/latest/en-US/reference/mpc_status): list all protocols spu implements now. Generally speaking, for 2pc, it's safe to use cheetah, while for 3pc, ABY3 is the only choice.\n",
"\n",
"First define a function just like `fit_and_predict` to get dk_arr."
]
Expand Down Expand Up @@ -1659,7 +1659,7 @@
"source": [
"#### Rsqrt v.s. Norm\n",
"\n",
"We can recall that the `compute_dk_func` function defined in Part 1 contains a `method` arg, and we just ignore this arg before. Indeed, we can tell simulator to print more information like [spu_inside](https://www.secretflow.org.cn/docs/spu/en/getting_started/tutorials/spu_inside.html#Tracing) do: enable **hlo**(High Level Operations) trace and profile on. Then we can figure out which op has been invoked and its time cost.\n",
"We can recall that the `compute_dk_func` function defined in Part 1 contains a `method` arg, and we just ignore this arg before. Indeed, we can tell simulator to print more information like [spu_inside](https://www.secretflow.org.cn/docs/spu/latest/en-US/tutorials/spu_inside#Tracing) do: enable **hlo**(High Level Operations) trace and profile on. Then we can figure out which op has been invoked and its time cost.\n",
"\n",
"Here, we list some advantages of using `jax.lax.rsqrt` rather than `jnp.linalg.norm`:\n",
"\n",
Expand Down
9 changes: 6 additions & 3 deletions examples/python/ml/flax_llama7b/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,20 @@ This example demonstrates how to use SPU to run secure inference on a pre-traine

Since EasyLM have an issue,so we have to make a samll change to support the option "streaming=false".
Open and edit "convert_hf_to_easylm.py", chang this:

```python
parser.add_argument("--streaming", action="store_true", default=True, help="whether is model weight saved stream format",)
```

to:

```python
parser.add_argument("--streaming", action="store_true", default=False, help="whether is model weight saved stream format",)
```

Download trained LLaMA-B[PyTroch-Version] from [Hugging Face](https://huggingface.co/openlm-research/open_llama_7b)
, and convert it to Flax.msgpack as:
```sh
```sh
python convert_hf_to_easylm.py \
--checkpoint_dir path-to-flax-llama7b-dir \
--output_file path-to-flax-llama7b-EasyLM.msgpack \
Expand All @@ -45,8 +49,7 @@ This example demonstrates how to use SPU to run secure inference on a pre-traine
bazel run -c opt //examples/python/utils:nodectl -- --config `pwd`/examples/python/ml/flax_llama7b/3pc.json up
```

or
(recommended)
or(recommended)

```sh
cd examples/python/utils
Expand Down
1 change: 1 addition & 0 deletions libspu/compiler/front_end/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ spu_cc_library(
"@xla//xla/service:dot_decomposer",
"@xla//xla/service:eigh_expander",
"@xla//xla/service:float_normalization",
"@xla//xla/service:gather_expander",
"@xla//xla/service:gather_simplifier",
"@xla//xla/service:hlo_constant_folding",
"@xla//xla/service:hlo_cse",
Expand Down
2 changes: 2 additions & 0 deletions libspu/compiler/front_end/hlo_importer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "xla/service/eigh_expander.h"
#include "xla/service/float_normalization.h"
#include "xla/service/float_support.h"
#include "xla/service/gather_expander.h"
#include "xla/service/gather_simplifier.h"
#include "xla/service/gpu/dot_dimension_sorter.h"
#include "xla/service/hlo_constant_folding.h"
Expand Down Expand Up @@ -126,6 +127,7 @@ void runHloPasses(xla::HloModule *module) {
/*allow_mixed_precision=*/false);

pipeline.AddPass<GatherSimplifier>();
pipeline.AddPass<GatherExpander>(GatherExpander::kEliminateSimpleGathers);
pipeline.AddPass<ScatterExpander>(ScatterExpander::kEliminateAllScatters);
pipeline.AddPass<AlgebraicSimplifier>(options);
pipeline.AddPass<BitcastDtypesExpander>();
Expand Down
2 changes: 1 addition & 1 deletion libspu/mpc/spdz2k/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,8 @@ spu_cc_library(
":boolean",
":state",
":type",
"//libspu/core:vectorize",
"//libspu/core:xt_helper",
"//libspu/core:vectorize",
"//libspu/mpc:ab_api",
"//libspu/mpc:api",
"//libspu/mpc:kernel",
Expand Down
2 changes: 1 addition & 1 deletion libspu/mpc/spdz2k/ot/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ spu_cc_library(
name = "ferret",
hdrs = ["ferret.h"],
deps = [
"//libspu/mpc/cheetah/ot",
"//libspu/mpc/cheetah/ot:ot",
"//libspu/mpc/common:communicator",
],
)
Expand Down
14 changes: 11 additions & 3 deletions spu/utils/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def _jax_compilation_key(
fn: Callable, static_argnums, static_argnames, args: List, kwargs: Dict
):
import jax
import numpy as np
from jax._src.api_util import argnames_partial_except, argnums_partial_except

try:
Expand Down Expand Up @@ -60,9 +61,15 @@ def _function_contents(func):
f, dkwargs = argnames_partial_except(f, static_argnames, kwargs)
f, dargs = argnums_partial_except(f, static_argnums, args, allow_invalid=True)

flat_args, _ = jax.tree_util.tree_flatten((dargs, dkwargs))
types = [(a.dtype, a.shape) if hasattr(a, 'dtype') else type(a) for a in flat_args]
hash_str = f'{hash(f)}-{hash(_function_contents(fn))}-{types}'
flat_args, tree = jax.tree_util.tree_flatten((dargs, dkwargs))
types = []
for a in flat_args:
if hasattr(a, 'dtype'):
types.append((a.dtype, a.shape))
else:
np_array = np.asarray(a)
types.append((np_array.dtype, np_array.shape))
hash_str = f'{hash(f)}-{hash(_function_contents(fn))}-{types}-{hash(tree)}'

return hash_str

Expand Down Expand Up @@ -114,6 +121,7 @@ def _jax_compilation(
cfn, output = jax.xla_computation(
fn, return_shape=True, static_argnums=static_argnums, backend="interpreter"
)(*args, **kwargs)

return cfn.as_serialized_hlo_module_proto(), output


Expand Down
Loading