diff --git a/examples/python/ml/flax_llama7b_split/3pc.json b/examples/python/ml/flax_llama7b_split/3pc.json new file mode 100644 index 00000000..74284f99 --- /dev/null +++ b/examples/python/ml/flax_llama7b_split/3pc.json @@ -0,0 +1,46 @@ +{ + "id": "outsourcing.3pc", + "nodes": { + "node:0": "127.0.0.1:9920", + "node:1": "127.0.0.1:9921", + "node:2": "127.0.0.1:9922", + "node:3": "127.0.0.1:9923", + "node:4": "127.0.0.1:9924" + }, + "devices": { + "SPU": { + "kind": "SPU", + "config": { + "node_ids": [ + "node:0", + "node:1", + "node:2" + ], + "spu_internal_addrs": [ + "127.0.0.1:9930", + "127.0.0.1:9931", + "127.0.0.1:9932" + ], + "runtime_config": { + "protocol": "ABY3", + "field": "FM64", + "enable_pphlo_profile": true, + "enable_hal_profile": true, + "fxp_exp_mode": 1 + } + } + }, + "P1": { + "kind": "PYU", + "config": { + "node_id": "node:3" + } + }, + "P2": { + "kind": "PYU", + "config": { + "node_id": "node:4" + } + } + } +} \ No newline at end of file diff --git a/examples/python/ml/flax_llama7b_split/BUILD.bazel b/examples/python/ml/flax_llama7b_split/BUILD.bazel new file mode 100644 index 00000000..6a735840 --- /dev/null +++ b/examples/python/ml/flax_llama7b_split/BUILD.bazel @@ -0,0 +1,31 @@ +# Copyright 2023 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_python//python:defs.bzl", "py_binary") + +package(default_visibility = ["//visibility:public"]) + +py_binary( + name = "flax_llama7b_split", + srcs = ["flax_llama7b_split.py"], + data = [ + "//examples/python/ml/flax_llama7b_split:3pc.json", + ], + deps = [ + "//spu:api", + "//spu/intrinsic:all_intrinsics", + "//spu/utils:distributed", + "//spu/utils:simulation", + ], +) diff --git a/examples/python/ml/flax_llama7b_split/README.md b/examples/python/ml/flax_llama7b_split/README.md new file mode 100644 index 00000000..950b9083 --- /dev/null +++ b/examples/python/ml/flax_llama7b_split/README.md @@ -0,0 +1,181 @@ +# Flax LlaMA-7B Example with Model Split + +This example demonstrates how to use SPU to run secure inference on a pre-trained +[LlaMA-7B](https://research.facebook.com/publications/llama-open-and-efficient-foundation-language-models/) model with model split. + +1. Motivation + +- Time: +Using the full LlaMA model for inference on SPU can take a significant amount of time. If only a portion of the model is passed through SPU to ensure privacy, it can greatly improve inference efficiency. + +- RAM Usage: +Using the full LLaMA model for inference on SPU requires a large amount of memory(more than 256 GB). Splitting the model can significantly reduce memory usage, making it available for use in hardware-constrained environments. + +2. Download EasyML library to support Flax-LLaMA-7B + + ```sh + git clone https://github.com/young-geng/EasyLM.git + cd EasyLM + export PYTHONPATH="${PWD}:$PYTHONPATH" + ``` + Or use a fork created for transformer split. 
+
+    ```sh
+    git clone -b split_llama_2 https://github.com/Rainysponge/EasyLM.git
+    cd EasyLM
+    export PYTHONPATH="${PWD}:$PYTHONPATH"
+    ```
+
+    Install the EasyLM environment before installing SecretFlow & SPU:
+
+    ```sh
+    conda env create -f examples/python/ml/flax_llama7b_split/gpu_environment.yml
+    conda activate EasyLM
+    pip install 'transformers[flax]'
+    pip install spu
+    ```
+
+    If you do not want to use a GPU:
+
+    ```sh
+    pip uninstall jax jaxlib
+    pip install jax==0.4.11 jaxlib==0.4.11
+    ```
+
+    Download the trained LLaMA-7B weights (PyTorch version) from <https://github.com/facebookresearch/llama>, and convert them to EasyLM format:
+
+    ```sh
+    cd path_to_EasyLM/EasyLM/models/llama
+    python convert_hf_to_easylm.py \
+        --checkpoint_dir path_to_llama_weights \
+        --output_file path_to_outputfile \
+        --model_size 7b \
+        --streaming
+    ```
+
+    Move the Python file to EasyLM if you do not use the fork created for the transformer split:
+
+    ```sh
+    cp path-to-llama_model_split_transformer_py path_to_EasyLM/EasyLM/models/llama
+    ```
+
+3. Launch SPU backend runtime
+
+    ```sh
+    bazel run -c opt //examples/python/utils:nodectl -- --config `pwd`/examples/python/ml/flax_llama7b_split/3pc.json up
+    ```
+
+    or (recommended)
+
+    ```sh
+    cd examples/python/utils
+    python nodectl.py --config `pwd`/examples/python/ml/flax_llama7b_split/3pc.json up
+    ```
+
+4. Run `flax_llama7b_split` example
+
+    ```sh
+    bazel run -c opt //examples/python/ml/flax_llama7b_split -- --config `pwd`/examples/python/ml/flax_llama7b_split/3pc.json
+    ```
+
+    or (recommended)
+
+    ```sh
+    cd examples/python/ml/flax_llama7b_split
+    python flax_llama7b_split.py --config `pwd`/examples/python/ml/flax_llama7b_split/3pc.json
+    ```
+
+    You can get the following results from our example:
+
+    ```md
+    ------
+    Run on CPU
+    Q: What is the largest animal?
+    A: The largest animal is the blue whale.
+    generate on CPU: 256.08824276924133 seconds
+
+    ------
+    Run on SPU
+    [0000-00-00 00:00:00.000] [info] [thread_pool.cc:30] Create a fixed thread pool with size 127
+    Q: What is the largest animal?
+    A: The largest animal is the blue whale.
+    generate on SPU: 812.9427680969238 seconds
+    ```
+
+    RAM peak: 64.5888 GB
+
+    If you set `token_num` to 30, you can get the following results:
+
+    ```md
+    ------
+    Run on CPU
+    Q: What is the largest animal?
+    A: The largest animal is the blue whale.
+    Q: What is the smallest animal?
+    A: The smallest animal is the bacterium.
+    generate on CPU: 837.0810837745667 seconds
+
+    ------
+    Run on SPU
+    [0000-00-00 00:00:00.000] [info] [thread_pool.cc:30] Create a fixed thread pool with size 127
+    Q: What is the largest animal?
+    A: The largest animal is the blue whale.
+    Q: What is the smallest animal?
+    A: The smallest animal is the bacterium.
+    generate on SPU: 2760.5035905838013 seconds
+    ```
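+
+    The number of generated tokens is controlled by the `token_num` argument of `run_on_cpu` and `run_on_spu` in `flax_llama7b_split.py` (both default to 9). For example, the 30-token run above can be reproduced by changing the calls in the `__main__` section roughly as follows:
+
+    ```python
+    # generate 30 tokens instead of the default 9
+    outputs_ids = run_on_cpu(token_num=30)   # in the "Run on CPU" section
+    outputs_ids = run_on_spu(token_num=30)   # in the "Run on SPU" section
+    ```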
+
+5. Supplement the Split Strategy
+
+In this example, we split the LLaMA-7B model into three parts as follows:
+
+- **Client**: Embedding + LLaMA-Blocks 0-1
+- **Mid**: LLaMA-Block 2 (_running on the SPU_)
+- **Server**: LLaMA-Blocks 3-31 + RMSNorm layer
+
+If users want to split the LLaMA-7B model in a different way, this can easily be achieved by rewriting a few lines of code in `flax_llama7b_split.py` and `llama_model_splited_transformer.py`.
+
+For example, we rewrite the files as follows:
+
+```python
+# flax_llama7b_split.py
+# lines 76-89
+
+client_params_dict = {
+    "transformer": {
+        "wte": params['params']["transformer"]["wte"],
+        "ln_f": params['params']["transformer"]["ln_f"],
+        "h": {str(i): params['params']["transformer"]["h"][str(i)] for i in range(2)},
+    }
+}
+
+mid_params_dict = {
+    "transformer": {
+        "h": {str(i): params['params']["transformer"]["h"][str(i)] for i in range(2, 5)}
+    }
+}
+```
+
+```python
+# llama_model_splited_transformer.py
+# line 1194
+for block in self.blocks[:2]:
+    ...
+
+# line 1274
+for block in self.blocks[2:5]:
+    ...
+
+# line 1355
+for block in self.blocks[5:]:
+    ...
+```
+
+After that, the LLaMA-7B model will be split as follows:
+
+- **Client**: Embedding + LLaMA-Blocks 0-1
+- **Mid**: LLaMA-Blocks 2-4 (_running on the SPU_)
+- **Server**: LLaMA-Blocks 5-31 + RMSNorm layer
+
+6. Privacy and Security Warning
+
+In this example, our main motivation is to reduce the hardware and time costs of [LLaMA-7B](https://research.facebook.com/publications/llama-open-and-efficient-foundation-language-models/) model inference by using SPU only for the middle blocks of the model. Its protection of the original data is therefore weaker than running the entire LLaMA-7B model on SPU, and it may be vulnerable to the model inversion attacks known from split learning, such as:
+
+- [PCAT: Functionality and Data Stealing from Split Learning by Pseudo-Client Attack](https://www.usenix.org/system/files/usenixsecurity23-gao.pdf)
+- [UnSplit: Data-Oblivious Model Inversion, Model Stealing, and Label Inference Attacks Against Split Learning](https://arxiv.org/pdf/2108.09033.pdf)
diff --git a/examples/python/ml/flax_llama7b_split/flax_llama7b_split.py b/examples/python/ml/flax_llama7b_split/flax_llama7b_split.py
new file mode 100644
index 00000000..14ef45fc
--- /dev/null
+++ b/examples/python/ml/flax_llama7b_split/flax_llama7b_split.py
@@ -0,0 +1,306 @@
+# Copyright 2023 Ant Group Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Start nodes.
+# > bazel run -c opt //examples/python/utils:nodectl -- --config `pwd`/examples/python/ml/flax_llama7b_split/3pc.json up
+# Run this example script.
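+# (either directly with python, as recommended in the README:)
+# > cd examples/python/ml/flax_llama7b_split
+# > python flax_llama7b_split.py --config `pwd`/examples/python/ml/flax_llama7b_split/3pc.json
+# (or via bazel:)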
+# > bazel run -c opt //examples/python/ml/flax_llama7b -- --config `pwd`/examples/python/ml/flax_llama_split/3pc.json +import time +import argparse +import json +import jax +import jax.numpy as jnp +import jax.nn as jnn +import flax.linen as nn +from flax.linen.linear import Array +from typing import Any, Optional, Tuple, Union +from transformers import LlamaTokenizer +from EasyLM.checkpoint import StreamingCheckpointer +from EasyLM.models.llama.llama_model import FlaxLLaMAForCausalLM +from EasyLM.models.llama.llama_model_splited_transformer import ( + FlaxLLaMAForCausalLMClient, + FlaxLLaMAForCausalLMServer, + FlaxLLaMAModule, + FlaxLLaMAForCausalLMMid, + LLaMAConfig, +) + + +import spu.utils.distributed as ppd +from contextlib import contextmanager +import spu.spu_pb2 as spu_pb2 + +from flax.linen.linear import Array +from typing import Any, Optional, Tuple, Union + +parser = argparse.ArgumentParser(description='distributed driver.') +parser.add_argument( + "-c", "--config", default="examples/python/ml/flax_llama_split/3pc.json" +) +args = parser.parse_args() + +with open(args.config, 'r') as file: + conf = json.load(file) + +ppd.init(conf["nodes"], conf["devices"]) + +copts = spu_pb2.CompilerOptions() +copts.enable_pretty_print = False +copts.xla_pp_kind = 2 +# enable x / broadcast(y) -> x * broadcast(1/y) +copts.enable_optimize_denominator_with_broadcast = True + +# model_path = 'path-to-flax-llama7b' + +model_path = "params::path-to-flax-llama7b-checkpoint" +tokenizer_path = "path-to-flax-llama7b" +tokenizer = LlamaTokenizer.from_pretrained(tokenizer_path) +tokenizer.pad_token_id = tokenizer.eos_token_id +config = LLaMAConfig() +# pretrained_model = FlaxLLaMAForCausalLM.from_pretrained(model_path, config=config) +with jax.default_device(jax.devices("cpu")[0]): + # llama_config = LLaMAConfig.load_config(FLAGS.load_llama_config) + _, params = StreamingCheckpointer.load_trainstate_checkpoint( + model_path, disallow_trainstate=True + ) +client_params_dict = { + "transformer": { + "wte": params['params']["transformer"]["wte"], + "ln_f": params['params']["transformer"]["ln_f"], + "h": {str(i): params['params']["transformer"]["h"][str(i)] for i in range(2)}, + } +} + +mid_params_dict = { + "transformer": { + "h": {str(i): params['params']["transformer"]["h"][str(i)] for i in range(2, 3)} + } +} + +server_params_dict = { + "transformer": { + "ln_f": params['params']["transformer"]["ln_f"], + "h": { + str(i): params['params']["transformer"]["h"][str(i)] + for i in range(3, len(params['params']["transformer"]["h"])) + }, + }, + "lm_head": { + "kernel": params['params']["lm_head"]["kernel"], + }, +} + + +def hack_softmax( + x: Array, + axis: Optional[Union[int, Tuple[int, ...]]] = -1, + where: Optional[Array] = None, + initial: Optional[Array] = None, +) -> Array: + x_max = jnp.max(x, axis, where=where, initial=initial, keepdims=True) + x = x - x_max + # exp on large negative is clipped to zero + b = x > -14 + nexp = jnp.exp(x) * b + divisor = jnp.sum(nexp, axis, where=where, keepdims=True) + return nexp / divisor + + +@contextmanager +def hack_softmax_context(msg: str, enabled: bool = True): + if not enabled: + yield + return + # hijack some target functions + raw_softmax = jnn.softmax + jnn.softmax = hack_softmax + yield + # recover back + jnn.softmax = raw_softmax + + +def hack_silu(x: Array) -> Array: + b0 = x < -8.0 + b1 = x < -4.0 + b2 = x > 4.0 + b3 = b1 ^ b2 ^ True # x in [-4.0, 4.0) + b4 = b0 ^ b1 # x in [-8.0, -4.0) + # seg1 = a[2] * x^2 + a[1] * x + a[0] + # seg2 = b[6] * x^6 + b[4] 
* x^4 + b[2] * x^2 + b[0] + a_coeffs = jnp.array( + [-0.3067541139982155, -0.0819767021525476, -0.0055465625580307] + ) + b_coeffs = jnp.array( + [ + 0.0085064025895951, + 0.5, + 0.2281430841728270, + -0.011113046708173, + 0.0002743776353465, + ] + ) + x2 = jnp.square(x) + x4 = jnp.square(x2) + x6 = x2 * x4 + seg1 = a_coeffs[2] * x2 + a_coeffs[1] * x + a_coeffs[0] + seg2 = ( + b_coeffs[4] * x6 + + b_coeffs[3] * x4 + + b_coeffs[2] * x2 + + b_coeffs[1] * x + + b_coeffs[0] + ) + ret = b2 * x + b4 * seg1 + b3 * seg2 + return ret + + +@contextmanager +def hack_silu_context(msg: str, enabled: bool = True): + if not enabled: + yield + return + # hijack some target functions + raw_silu = nn.silu + nn.silu = hack_silu + yield + # recover back + nn.silu = raw_silu + + +# greedy search +# ref: https://huggingface.co/blog/how-to-generate +# for embedding generation +def embeding_generation(input_ids, params): + config = LLaMAConfig() + model = FlaxLLaMAForCausalLMClient(config=config) + smasheddata, attention_mask, position_ids = model( + input_ids=input_ids, params=params + ) + del model + return smasheddata, attention_mask, position_ids + + +def mid_generation(input_ids, params, attention_mask, position_ids): + config = LLaMAConfig() + _model = FlaxLLaMAForCausalLMMid(config=config) + + _smasheddata = _model( + input_ids=input_ids, + params=params, + attention_mask=attention_mask, + position_ids=position_ids, + ) + + return _smasheddata, attention_mask, position_ids + + +def server_generation(input_ids, params, attention_mask, position_ids): + config = LLaMAConfig() + _model = FlaxLLaMAForCausalLMServer(config=config) + + _smasheddata = _model( + input_ids=input_ids, + params=params, + attention_mask=attention_mask, + position_ids=position_ids, + ) + + return _smasheddata + + +def run_on_cpu(token_num=9): + input_ids = tokenizer.encode( + 'Q: What is the largest animal?\nA:', return_tensors='jax' + ) + for _ in range(token_num): + smasheddata, attention_mask, position_ids = embeding_generation( + input_ids=input_ids, params=client_params_dict + ) + + _smasheddata, attention_mask, position_ids = mid_generation( + input_ids=smasheddata, + params=mid_params_dict, + attention_mask=attention_mask, + position_ids=position_ids, + ) + + outputs = server_generation( + input_ids=_smasheddata, + params=server_params_dict, + attention_mask=attention_mask, + position_ids=position_ids, + ) + + next_token_logits = outputs[0][0, -1, :] + next_token = jnp.argmax(next_token_logits) + + input_ids = jnp.concatenate([input_ids, jnp.array([[next_token]])], axis=1) + + return input_ids + + +def run_on_spu(token_num=9): + # encode context the generation is conditioned on + input_ids = tokenizer.encode( + 'Q: What is the largest animal?\nA:', return_tensors='jax' + ) + for _ in range(token_num): + smasheddata, attention_mask, position_ids = embeding_generation( + input_ids=input_ids, params=client_params_dict + ) + with hack_softmax_context( + "hack exp of softmax", enabled=True + ), hack_silu_context("hack silu", enabled=True): + _input_ids = ppd.device("P1")(lambda x: x)(smasheddata) + _params = ppd.device("P2")(lambda x: x)(mid_params_dict) + + _smasheddata, attention_mask, position_ids = ppd.device("SPU")( + mid_generation + )(_input_ids, _params, attention_mask, position_ids) + + _smasheddata, attention_mask, position_ids = ( + ppd.get(_smasheddata), + ppd.get(attention_mask), + ppd.get(position_ids), + ) + + outputs = server_generation( + input_ids=_smasheddata, + params=server_params_dict, + 
attention_mask=attention_mask, + position_ids=position_ids, + ) + + next_token_logits = outputs[0][0, -1, :] + next_token = jnp.argmax(next_token_logits) + + input_ids = jnp.concatenate([input_ids, jnp.array([[next_token]])], axis=1) + + return input_ids + + +if __name__ == '__main__': + print('\n------\nRun on CPU') + start_time = time.time() + outputs_ids = run_on_cpu() + print(tokenizer.decode(outputs_ids[0], skip_special_tokens=True)) + end_time = time.time() + print(f"generate on CPU: {end_time - start_time} seconds") + + print('\n------\nRun on SPU') + start_time = time.time() + outputs_ids = run_on_spu() + print(tokenizer.decode(outputs_ids[0], skip_special_tokens=True)) + end_time = time.time() + print(f"generate on SPU: {end_time - start_time} seconds") diff --git a/examples/python/ml/flax_llama7b_split/gpu_environment.yml b/examples/python/ml/flax_llama7b_split/gpu_environment.yml new file mode 100644 index 00000000..68c38e6f --- /dev/null +++ b/examples/python/ml/flax_llama7b_split/gpu_environment.yml @@ -0,0 +1,41 @@ +name: EasyLM +channels: + - conda-forge +dependencies: + - python=3.8 + - pip + - numpy + - scipy + - matplotlib + - seaborn + - jupyter + - tqdm + - sentencepiece + - pip: + - -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html + - --extra-index-url https://download.pytorch.org/whl/cu118 + - jax[cuda11_pip]==0.4.11 + - flax==0.7.0 + - optax==0.1.7 + - distrax==0.1.3 + - chex==0.1.7 + - transformers==4.31.0 + - torch==2.0.1 + - huggingface_hub==0.16.4 + - datasets==2.14.2 + - einops + - tensorflow==2.11.1 + - dill + - absl-py + - wandb + - ml_collections + - gcsfs + - requests + - jupyter_http_over_ws + - lm-eval + - mlxu==0.1.11 + - pydantic + - fastapi + - uvicorn + - gradio + diff --git a/examples/python/ml/flax_llama7b_split/llama_model_splited_transformer.py b/examples/python/ml/flax_llama7b_split/llama_model_splited_transformer.py new file mode 100644 index 00000000..4af22853 --- /dev/null +++ b/examples/python/ml/flax_llama7b_split/llama_model_splited_transformer.py @@ -0,0 +1,2339 @@ +# Copyright 2023 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
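+
+# This file adapts EasyLM's Flax LLaMA implementation into three sub-models used by
+# the split-inference example in this directory (see README.md):
+#   * FlaxLLaMAForCausalLMClient: token embedding + the first transformer blocks (blocks 0-1 by default)
+#   * FlaxLLaMAForCausalLMMid:    the middle block(s), intended to run on SPU
+#   * FlaxLLaMAForCausalLMServer: the remaining blocks + the final RMSNorm and lm_head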
+ +# Additional Reference Link: +# Original Source Code Form +# [EasyLM](https://github.com/young-geng/EasyLM/tree/main) + +import os +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple, Union +import json +import tempfile +from functools import partial +from jax import jit +import numpy as np +import jax +import jax.numpy as jnp +from jax import lax +from jax.sharding import PartitionSpec as PS +import flax.linen as nn +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze +from flax.linen import combine_masks, make_causal_mask +from flax.linen.attention import dot_product_attention_weights +from flax.traverse_util import flatten_dict, unflatten_dict +from flax.linen import partitioning as nn_partitioning +import einops + +import sentencepiece as spm +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging +from transformers.tokenization_utils import PreTrainedTokenizer +from transformers.modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput +from transformers.modeling_flax_utils import ( + ACT2FN, + FlaxPreTrainedModel, + append_call_sample_docstring, +) +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, +) + + +from ml_collections import ConfigDict +from ml_collections.config_dict import config_dict +from mlxu import function_args_to_config, load_pickle, open_file + +from EasyLM.bpt import blockwise_ffn, blockwise_attn +from EasyLM.jax_utils import ( + with_sharding_constraint, + get_jax_mesh, + get_gradient_checkpoint_policy, +) + + +LLAMA_STANDARD_CONFIGS = { + '7b': { + 'vocab_size': 32000, + 'hidden_size': 4096, + 'intermediate_size': 11008, + 'num_hidden_layers': 32, + 'num_attention_heads': 32, + 'max_sequence_length': 2048, + 'initializer_range': 0.02, + 'rms_norm_eps': 1e-6, + 'use_cache': True, + 'tie_word_embeddings': False, + }, +} + + +class LLaMAConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`~LLaMAModel`]. It is used to instantiate an LLaMA + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the LLaMA-7B. + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`~LLaMAModel`] or [`~TFLLaMAModel`]. + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. 
+ max_sequence_length (`int`, *optional*, defaults to 2048): + Max sequence length for model (for RoPE computation) + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings(`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + Example: + ```python + >>> from transformers import LLaMAModel, LLaMAConfig + >>> # Initializing a LLaMA llama-7b style configuration + >>> configuration = LLaMAConfig() + >>> # Initializing a model from the llama-7b style configuration + >>> model = LLaMAModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + model_type = "llama" + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + max_sequence_length=2048, + rms_norm_eps=1e-6, + initializer_range=0.02, + use_cache=True, + # pad_token_id=-1, + bos_token_id=0, + eos_token_id=1, + resid_pdrop=0.0, + embd_pdrop=0.0, + attn_pdrop=0.0, + tie_word_embeddings=False, + remat_block='', + remat_attention='', + remat_mlp='', + scan_attention=False, + scan_mlp=False, + scan_query_chunk_size=1024, + scan_key_chunk_size=1024, + scan_mlp_chunk_size=1024, + fcm_min_ratio=0.0, + fcm_max_ratio=0.0, + splitlayer=(0, 1), + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.initializer_range = initializer_range + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.max_sequence_length = max_sequence_length + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.resid_pdrop = resid_pdrop + self.embd_pdrop = embd_pdrop + self.attn_pdrop = attn_pdrop + self.remat_block = remat_block + self.remat_attention = remat_attention + self.remat_mlp = remat_mlp + self.scan_attention = scan_attention + self.scan_mlp = scan_mlp + self.scan_query_chunk_size = scan_query_chunk_size + self.scan_key_chunk_size = scan_key_chunk_size + self.scan_mlp_chunk_size = scan_mlp_chunk_size + self.fcm_min_ratio = fcm_min_ratio + self.fcm_max_ratio = fcm_max_ratio + self.splitlayer = splitlayer + + super().__init__( + # pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + @classmethod + def get_default_config(cls, updates=None): + config = function_args_to_config(cls.__init__) + + if updates is not None: + config.update(ConfigDict(updates).copy_and_resolve_references()) + + return config + + @staticmethod + def get_jax_mesh(axis_dims): + return get_jax_mesh(axis_dims, ('dp', 'fsdp', 'mp')) + + @staticmethod + def get_partition_rules(): + """Parition rules for GPTJ. Note that these rules are orderd, so that + the beginning rules match first. It is important to use + PartitionSpec() instead of None here because JAX does not treat + None as a pytree leaf. 
+ """ + return ( + # embeddings + ("transformer/wte/embedding", PS("mp", "fsdp")), + # atention + ("attention/(wq|wk|wv)/kernel", PS("fsdp", "mp")), + ("attention/wo/kernel", PS("mp", "fsdp")), + # mlp + ("feed_forward/w1/kernel", PS("fsdp", "mp")), + ("feed_forward/w2/kernel", PS("mp", "fsdp")), + ("feed_forward/w3/kernel", PS("fsdp", "mp")), + # layer norms + ("attention_norm/kernel", PS(None)), + ("ffn_norm/kernel", PS(None)), + # output head + ("transformer/ln_f/kernel", PS(None)), + ("lm_head/kernel", PS("fsdp", "mp")), + ('.*', PS(None)), + ) + + @staticmethod + def get_weight_decay_exclusions(): + return tuple() + + @staticmethod + def rng_keys(): + return ('params', 'dropout', 'fcm') + + @staticmethod + def get_tokenizer_config(updates=None): + config = ConfigDict() + config.vocab_file = '' + config.add_bos_token = False + config.add_eos_token = False + + if updates is not None: + config.update(ConfigDict(updates).copy_and_resolve_references()) + return config + + @classmethod + def get_tokenizer(cls, config, padding_side='left', truncation_side='right'): + config = cls.get_tokenizer_config(config) + assert config.vocab_file != '', 'vocab_file must be specified' + tokenizer = LLaMATokenizer( + vocab_file=config.vocab_file, + add_bos_token=config.add_bos_token, + add_eos_token=config.add_eos_token, + padding_side=padding_side, + truncation_side=truncation_side, + ) + return tokenizer + + @classmethod + def load_config(cls, path): + if path in LLAMA_STANDARD_CONFIGS: + return cls.from_dict(LLAMA_STANDARD_CONFIGS[path]) + load_type, load_path = path.split('::', 1) + if load_type == 'pickle': + return cls.from_dict(load_pickle(load_path)['llama_config']) + elif load_type == 'json': + with open_file(load_path, 'r') as fin: + raw_config = fin.read() + return cls.from_dict(json.loads(raw_config)) + else: + raise ValueError(f'Unsupported load config type: {load_type}') + + +remat = nn_partitioning.remat + +logger = logging.get_logger(__name__) + + +class RMSNorm(nn.Module): + dim: int + eps: float = 1e-6 + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + + def setup(self) -> None: + self.weight = self.param( + 'kernel', + nn.initializers.ones, + (self.dim,), + self.param_dtype, + ) + + def _norm(self, x: jnp.ndarray) -> jnp.ndarray: + return x * jax.lax.rsqrt(jnp.square(x).mean(-1, keepdims=True) + self.eps) + + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + x = x.astype(jnp.promote_types(self.dtype, jnp.float32)) + output = self._norm(x).astype(self.dtype) + weight = jnp.asarray(self.weight, self.dtype) + + return output * weight + + +def precompute_freqs_cis( + dim: int, end: int, theta: float = 10000.0, dtype: jnp.dtype = jnp.float32 +) -> jnp.ndarray: + freqs = 1.0 / (theta ** (np.arange(0, dim, 2)[: (dim // 2)].astype(dtype) / dim)) + t = np.arange(end) # type: ignore + freqs = np.outer(t, freqs).astype(dtype) # type: ignore + sin, cos = np.sin(freqs), np.cos(freqs) + freqs_cis = np.complex64(cos + 1j * sin) + return jnp.asarray(freqs_cis) + + +def apply_rotary_emb( + xq: jnp.ndarray, + xk: jnp.ndarray, + freqs_cis: jnp.ndarray, + dtype: jnp.dtype = jnp.float32, +) -> Tuple[jnp.ndarray, jnp.ndarray]: + reshape_xq = xq.astype(jnp.float32).reshape(*xq.shape[:-1], -1, 2) + reshape_xk = xk.astype(jnp.float32).reshape(*xk.shape[:-1], -1, 2) + + xq_ = jax.lax.complex(reshape_xq[..., 0], reshape_xq[..., 1]) + xk_ = jax.lax.complex(reshape_xk[..., 0], reshape_xk[..., 1]) + + # add head dim + freqs_cis = jnp.reshape(freqs_cis, (*freqs_cis.shape[:2], 1, 
*freqs_cis.shape[2:])) + + xq_out = xq_ * freqs_cis + xq_out = jnp.stack((jnp.real(xq_out), jnp.imag(xq_out)), axis=-1).reshape( + *xq_out.shape[:-1], -1 + ) + + xk_out = xk_ * freqs_cis + xk_out = jnp.stack((jnp.real(xk_out), jnp.imag(xk_out)), axis=-1).reshape( + *xk_out.shape[:-1], -1 + ) + + return xq_out.astype(dtype), xk_out.astype(dtype) + + +class FlaxLLaMAAttention(nn.Module): + config: LLaMAConfig + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[Union[jax.lax.Precision, str]] = None + + def setup(self): + config = self.config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + + self.wq = nn.Dense( + config.num_attention_heads * self.head_dim, + dtype=self.dtype, + param_dtype=self.param_dtype, + use_bias=False, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + precision=self.precision, + ) + self.wk = nn.Dense( + config.num_attention_heads * self.head_dim, + dtype=self.dtype, + param_dtype=self.param_dtype, + use_bias=False, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + precision=self.precision, + ) + self.wv = nn.Dense( + config.num_attention_heads * self.head_dim, + dtype=self.dtype, + param_dtype=self.param_dtype, + use_bias=False, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + precision=self.precision, + ) + self.wo = nn.Dense( + config.hidden_size, + dtype=self.dtype, + param_dtype=self.param_dtype, + use_bias=False, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + precision=self.precision, + ) + + self.resid_dropout = nn.Dropout(rate=config.resid_pdrop) + + self.causal_mask = make_causal_mask( + jnp.ones((1, config.max_sequence_length), dtype="bool"), dtype="bool" + ) + + self.freqs_cis = precompute_freqs_cis( + self.head_dim, + config.max_sequence_length * 2, + dtype=self.dtype, + ) + + def _split_heads(self, hidden_states): + return hidden_states.reshape( + hidden_states.shape[:2] + (self.num_heads, self.head_dim) + ) + + def _merge_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,)) + + @nn.compact + def _concatenate_to_cache(self, key, value, query, attention_mask): + """ + This function takes projected key, value states from a single input token and concatenates the states to cached + states from previous steps. This function is slighly adapted from the official Flax repository: + https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 + """ + # detect if we're initializing by absence of existing cache data. 
+ is_initialized = self.has_variable("cache", "cached_key") + cached_key = self.variable( + "cache", "cached_key", jnp.zeros, key.shape, key.dtype + ) + cached_value = self.variable( + "cache", "cached_value", jnp.zeros, value.shape, value.dtype + ) + cache_index = self.variable( + "cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32) + ) + + if is_initialized: + *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape + # update key, value caches with our new 1d spatial slices + cur_index = cache_index.value + indices = (0,) * len(batch_dims) + (cur_index, 0, 0) + key = lax.dynamic_update_slice(cached_key.value, key, indices) + value = lax.dynamic_update_slice(cached_value.value, value, indices) + cached_key.value = key + cached_value.value = value + num_updated_cache_vectors = query.shape[1] + cache_index.value = cache_index.value + num_updated_cache_vectors + # causal mask for cached decoder self-attention: our single query + # position should only attend to those key positions that have + # already been generated and cached, not the remaining zero + # elements. + pad_mask = jnp.broadcast_to( + jnp.arange(max_length) < cur_index + num_updated_cache_vectors, + tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), + ) + attention_mask = combine_masks(pad_mask, attention_mask) + return key, value, attention_mask + + def __call__( + self, + hidden_states, + attention_mask, + position_ids, + deterministic: bool = True, + init_cache: bool = False, + output_attentions: bool = False, + fcm_mask=None, + ): + xq, xk, xv = ( + self.wq(hidden_states), + self.wk(hidden_states), + self.wv(hidden_states), + ) + + xq = with_sharding_constraint(xq, PS(("dp", "fsdp"), None, "mp")) + xk = with_sharding_constraint(xk, PS(("dp", "fsdp"), None, "mp")) + xv = with_sharding_constraint(xv, PS(("dp", "fsdp"), None, "mp")) + + xq = self._split_heads(xq) + xk = self._split_heads(xk) + xv = self._split_heads(xv) + + freqs_cis = jnp.take(self.freqs_cis, position_ids, axis=0) + + xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis, dtype=self.dtype) + + dropout_rng = None + if not deterministic and self.config.attn_pdrop > 0.0: + dropout_rng = self.make_rng("dropout") + + if self.config.scan_attention and not ( + self.has_variable("cache", "cached_key") or init_cache + ): + # print('self.config.scan_attention and not (self.has_variable("cache", "cached_key") or init_cache)') + # doesn't need blockwise attention if we are doing autoregressive + # decoding since no quadratic memory + + # attention mask without nxn materlization, blockwise_attn will + # handle the rest + attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) + # transform boolean mask into float mask + attention_bias = lax.select( + attention_mask > 0, + jnp.full(attention_mask.shape, 0.0).astype(self.dtype), + jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype( + self.dtype + ), + ) + attn_weights = None + attn_output = blockwise_attn( + xq, + xk, + xv, + bias=attention_bias, + deterministic=deterministic, + dropout_rng=dropout_rng, + attn_pdrop=self.config.attn_pdrop, + causal=True, + query_chunk_size=self.config.scan_query_chunk_size, + key_chunk_size=self.config.scan_key_chunk_size, + dtype=self.dtype, + policy=get_gradient_checkpoint_policy('nothing_saveable'), + precision=self.precision, + float32_logits=True, + prevent_cse=True, + ) + attn_output = with_sharding_constraint( + attn_output, PS(("dp", "fsdp"), None, "mp", None) + ) + else: + query_length, key_length = xq.shape[1], 
xk.shape[1] + + if self.has_variable("cache", "cached_key"): + mask_shift = self.variables["cache"]["cache_index"] + max_decoder_length = self.variables["cache"]["cached_key"].shape[1] + causal_mask = lax.dynamic_slice( + self.causal_mask, + (0, 0, mask_shift, 0), + (1, 1, query_length, max_decoder_length), + ) + else: + causal_mask = self.causal_mask[:, :, :query_length, :key_length] + + batch_size = hidden_states.shape[0] + causal_mask = jnp.broadcast_to( + causal_mask, (batch_size,) + causal_mask.shape[1:] + ) + + attention_mask = jnp.broadcast_to( + jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape + ) + attention_mask = combine_masks(attention_mask, causal_mask, fcm_mask) + + # During fast autoregressive decoding, we feed one position at a time, + # and cache the keys and values step by step. + if self.has_variable("cache", "cached_key") or init_cache: + xk, xv, attention_mask = self._concatenate_to_cache( + xk, xv, xq, attention_mask + ) + + # transform boolean mask into float mask + attention_bias = lax.select( + attention_mask > 0, + jnp.full(attention_mask.shape, 0.0).astype(self.dtype), + jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype( + self.dtype + ), + ) + attn_weights = dot_product_attention_weights( + xq, + xk, + bias=attention_bias, + dropout_rng=dropout_rng, + dropout_rate=self.config.attn_pdrop, + deterministic=deterministic, + dtype=jnp.promote_types(self.dtype, jnp.float32), + precision=self.precision, + ) + attn_weights = with_sharding_constraint( + attn_weights, PS(("dp", "fsdp"), "mp", None, None) + ) + attn_output = jnp.einsum( + "...hqk,...khd->...qhd", attn_weights, xv, precision=self.precision + ) + + attn_output = self._merge_heads(attn_output) + attn_output = self.wo(attn_output) + attn_output = self.resid_dropout(attn_output, deterministic=deterministic) + outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) + return outputs + + +class FlaxLLaMAMLP(nn.Module): + config: LLaMAConfig + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[Union[jax.lax.Precision, str]] = None + + def setup(self) -> None: + config = self.config + + self.w1 = nn.Dense( + config.intermediate_size, + dtype=self.dtype, + param_dtype=self.param_dtype, + use_bias=False, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + precision=self.precision, + ) + self.w2 = nn.Dense( + config.hidden_size, + dtype=self.dtype, + param_dtype=self.param_dtype, + use_bias=False, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + precision=self.precision, + ) + self.w3 = nn.Dense( + config.intermediate_size, + dtype=self.dtype, + param_dtype=self.param_dtype, + use_bias=False, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + precision=self.precision, + ) + self.dropout = nn.Dropout(rate=self.config.resid_pdrop) + + def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: + x = self.w2(nn.silu(self.w1(x)) * self.w3(x)) + x = self.dropout(x, deterministic=deterministic) + return x + + +class FlaxLLaMABlock(nn.Module): + config: LLaMAConfig + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[Union[jax.lax.Precision, str]] = None + + def setup(self) -> None: + attention_module = FlaxLLaMAAttention + mlp_module = FlaxLLaMAMLP + if self.config.remat_attention != '': + attention_module = remat( + FlaxLLaMAAttention, + static_argnums=(3, 4, 5), + 
policy=get_gradient_checkpoint_policy(self.config.remat_attention), + prevent_cse=True, + ) + if self.config.remat_mlp != '': + mlp_module = remat( + FlaxLLaMAMLP, + static_argnums=(1,), + policy=get_gradient_checkpoint_policy(self.config.remat_mlp), + prevent_cse=True, + ) + + self.attention = attention_module( + self.config, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + ) + self.feed_forward = mlp_module( + self.config, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + ) + self.attention_norm = RMSNorm( + self.config.hidden_size, + eps=self.config.rms_norm_eps, + dtype=self.dtype, + param_dtype=self.param_dtype, + ) + self.ffn_norm = RMSNorm( + self.config.hidden_size, + eps=self.config.rms_norm_eps, + dtype=self.dtype, + param_dtype=self.param_dtype, + ) + + def __call__( + self, + hidden_states, + attention_mask=None, + position_ids=None, + deterministic: bool = True, + init_cache: bool = False, + output_attentions: bool = False, + fcm_mask: Optional[jnp.ndarray] = None, + ): + attn_outputs = self.attention( + self.attention_norm(hidden_states), + attention_mask, + position_ids, + deterministic, + init_cache, + output_attentions, + fcm_mask, + ) + attn_output = attn_outputs[0] + hidden_states = hidden_states + attn_output + + feed_forward_input = self.ffn_norm(hidden_states) + + if self.config.scan_mlp: + feed_forward_hidden_states = blockwise_ffn( + self.feed_forward, + feed_forward_input, + self.config.scan_mlp_chunk_size, + deterministic, + ) + else: + feed_forward_hidden_states = self.feed_forward( + feed_forward_input, + deterministic, + ) + feed_forward_hidden_states = with_sharding_constraint( + feed_forward_hidden_states, PS(("dp", "fsdp"), None, "mp") + ) + + hidden_states = hidden_states + feed_forward_hidden_states + + return (hidden_states,) + attn_outputs[1:] + + +class FlaxLLaMAPreTrainedModel(FlaxPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = LLaMAConfig + base_model_prefix = "transformer" + module_class: nn.Module = None + + def __init__( + self, + config: LLaMAConfig, + input_shape: Tuple = (1, 1), + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, + **kwargs, + ): + module = self.module_class(config=config, dtype=dtype, **kwargs) + super().__init__( + config, + module, + input_shape=input_shape, + seed=seed, + dtype=dtype, + _do_init=_do_init, + ) + + def init_weights( + self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None + ) -> FrozenDict: + # init input tensors + + input_ids = jnp.zeros(input_shape, dtype="i4") + attention_mask = jnp.ones_like(input_ids) + position_ids = jnp.broadcast_to( + jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape + ) + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + if self.config.add_cross_attention: + encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,)) + encoder_attention_mask = attention_mask + module_init_outputs = self.module.init( + rngs, + input_ids, + attention_mask, + position_ids, + encoder_hidden_states, + encoder_attention_mask, + return_dict=False, + ) + else: + module_init_outputs = self.module.init( + rngs, input_ids, attention_mask, position_ids, return_dict=False + ) + + random_params = module_init_outputs["params"] + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + + def init_cache(self, batch_size, max_length): + r""" + Args: + batch_size (`int`): + batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. + max_length (`int`): + maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized + cache. 
+ """ + # init input variables to retrieve cache + # if other_shape is None: + input_ids = jnp.ones((batch_size, max_length)) + attention_mask = jnp.ones_like(input_ids) + position_ids = jnp.broadcast_to( + jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape + ) + + init_variables = self.module.init( + jax.random.PRNGKey(0), + input_ids, + attention_mask, + position_ids, + return_dict=False, + init_cache=True, + ) + return init_variables["cache"] + # else: + # input_ids = jnp.ones(other_shape) + # attention_mask = jnp.ones_like(jnp.ones((1, other_shape[1]))) + # position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(attention_mask).shape[-1]), attention_mask.shape) + + # init_variables = self.module.init( + # jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True + # ) + # return init_variables["cache"] + + @add_start_docstrings_to_model_forward("") + def __call__( + self, + input_ids, + attention_mask=None, + position_ids=None, + params: dict = None, + past_key_values: dict = None, + dropout_rng: jax.random.PRNGKey = None, + train: bool = False, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + splitlayer: Tuple = (0, 1), + ): + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.return_dict + ) + if len(input_ids.shape) == 2: + batch_size, sequence_length = input_ids.shape + + if position_ids is None: + if past_key_values is not None: + raise ValueError( + "Make sure to provide `position_ids` when passing `past_key_values`." + ) + + position_ids = jnp.broadcast_to( + jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) + ) + + if attention_mask is None: + attention_mask = jnp.ones((batch_size, sequence_length)) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + inputs = {"params": params or self.params} + + # if past_key_values are passed then cache is already initialized a + # private flag init_cache has to be passed down to ensure cache is used. + # It has to be made sure that cache is marked as mutable so that it can + # be changed by FlaxGPTJAttention module + if past_key_values: + inputs["cache"] = past_key_values + mutable = ["cache"] + else: + mutable = False + + outputs = self.module.apply( + inputs, + jnp.array(input_ids, dtype="i4"), + jnp.array(attention_mask, dtype="i4"), + jnp.array(position_ids, dtype="i4"), + not train, + False, + output_attentions, + output_hidden_states, + return_dict, + rngs=rngs, + mutable=mutable, + splitlayer=splitlayer, + ) + + # add updated cache to model output + if past_key_values is not None and return_dict: + outputs, past_key_values = outputs + outputs["past_key_values"] = unfreeze(past_key_values["cache"]) + return outputs + elif past_key_values is not None and not return_dict: + outputs, past_key_values = outputs + outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:] + + return outputs + + +class FlaxLLaMAPreTrainedModelServer(FlaxPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = LLaMAConfig + base_model_prefix = "transformer" + module_class: nn.Module = None + + def __init__( + self, + config: LLaMAConfig, + input_shape: Tuple = (1, 1, 4096), + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, + **kwargs, + ): + module = self.module_class(config=config, dtype=dtype, **kwargs) + super().__init__( + config, + module, + input_shape=input_shape, + seed=seed, + dtype=dtype, + _do_init=_do_init, + ) + + def init_weights( + self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None + ) -> FrozenDict: + # init input tensors + input_ids = jnp.zeros(input_shape, dtype="i4") + tmp_ids = jnp.zeros(input_shape[:2], dtype="i4") + attention_mask = jnp.ones_like(jnp.zeros(input_shape[:2], dtype="i4")) + position_ids = jnp.broadcast_to( + jnp.arange(jnp.atleast_2d(tmp_ids).shape[-1]), input_shape[:2] + ) + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + if self.config.add_cross_attention: + encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,)) + encoder_attention_mask = attention_mask + module_init_outputs = self.module.init( + rngs, + input_ids, + attention_mask, + position_ids, + encoder_hidden_states, + encoder_attention_mask, + return_dict=False, + ) + else: + module_init_outputs = self.module.init( + rngs, input_ids, attention_mask, position_ids, return_dict=False + ) + + random_params = module_init_outputs["params"] + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + + def init_cache(self, batch_size, max_length): + r""" + Args: + batch_size (`int`): + batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. + max_length (`int`): + maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized + cache. + """ + # init input variables to retrieve cache + input_ids = jnp.ones((batch_size, max_length)) + attention_mask = jnp.ones_like(input_ids) + position_ids = jnp.broadcast_to( + jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape + ) + + init_variables = self.module.init( + jax.random.PRNGKey(0), + input_ids, + attention_mask, + position_ids, + return_dict=False, + init_cache=True, + ) + return init_variables["cache"] + + @add_start_docstrings_to_model_forward("") + def __call__( + self, + input_ids, + attention_mask=None, + position_ids=None, + params: dict = None, + past_key_values: dict = None, + dropout_rng: jax.random.PRNGKey = None, + train: bool = False, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + splitlayer: Tuple = (0, 1), + ): + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.return_dict + ) + + batch_size, sequence_length = input_ids.shape[:2] + + if position_ids is None: + if past_key_values is not None: + raise ValueError( + "Make sure to provide `position_ids` when passing `past_key_values`." 
+ ) + + position_ids = jnp.broadcast_to( + jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) + ) + + if attention_mask is None: + attention_mask = jnp.ones((batch_size, sequence_length)) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + inputs = {"params": params or self.params} + + # if past_key_values are passed then cache is already initialized a + # private flag init_cache has to be passed down to ensure cache is used. + # It has to be made sure that cache is marked as mutable so that it can + # be changed by FlaxGPTJAttention module + if past_key_values: + inputs["cache"] = past_key_values + mutable = ["cache"] + else: + mutable = False + outputs = self.module.apply( + inputs, + jnp.array(input_ids, dtype=jnp.float32), + jnp.array(attention_mask, dtype="i4"), + jnp.array(position_ids, dtype="i4"), + not train, + False, + output_attentions, + output_hidden_states, + return_dict, + rngs=rngs, + mutable=mutable, + splitlayer=splitlayer, + ) + + # add updated cache to model output + if past_key_values is not None and return_dict: + outputs, past_key_values = outputs + outputs["past_key_values"] = unfreeze(past_key_values["cache"]) + return outputs + elif past_key_values is not None and not return_dict: + outputs, past_key_values = outputs + outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:] + + return outputs + + +class FlaxLLaMABlockCollection(nn.Module): + config: LLaMAConfig + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[Union[jax.lax.Precision, str]] = None + + def setup(self): + block = FlaxLLaMABlock + if self.config.remat_block != '': + block = remat( + FlaxLLaMABlock, + static_argnums=(3, 4, 5), + policy=get_gradient_checkpoint_policy(self.config.remat_block), + ) + self.blocks = [ + block( + self.config, + name=str(i), + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + ) + for i in range(self.config.num_hidden_layers) + ] + + def __call__( + self, + hidden_states, + attention_mask=None, + position_ids=None, + deterministic: bool = True, + init_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + all_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + if not deterministic and self.config.fcm_max_ratio > 0: + # Apply forgetful causal mask + batch_size, seq_length = hidden_states.shape[0], hidden_states.shape[1] + fcm_ratio = jax.random.uniform( + self.make_rng('fcm'), + shape=(batch_size, 1, 1, 1), + minval=self.config.fcm_min_ratio, + maxval=self.config.fcm_max_ratio, + ) + fcm_mask = ( + jax.random.uniform( + self.make_rng('fcm'), shape=(batch_size, 1, 1, seq_length) + ) + > fcm_ratio + ) + fcm_mask = fcm_mask.at[:, :, :, 0].set(True) + fcm_mask = fcm_mask.astype('bool') + else: + fcm_mask = None + + for block in self.blocks: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = block( + hidden_states, + attention_mask, + position_ids, + deterministic, + init_cache, + output_attentions, + fcm_mask, + ) + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions += (layer_outputs[1],) + + # this contains possible `None` values - `FlaxGPTJModule` will filter + # them out + outputs = (hidden_states, all_hidden_states, all_attentions) + + return outputs + + +class FlaxLLaMABlockCollectionClient(nn.Module): + config: 
LLaMAConfig + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[Union[jax.lax.Precision, str]] = None + + def setup(self): + block = FlaxLLaMABlock + if self.config.remat_block != '': + block = remat( + FlaxLLaMABlock, + static_argnums=(3, 4, 5), + policy=get_gradient_checkpoint_policy(self.config.remat_block), + ) + self.blocks = [ + block( + self.config, + name=str(i), + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + ) + for i in range(self.config.num_hidden_layers) + ] + + def __call__( + self, + hidden_states, + attention_mask=None, + position_ids=None, + deterministic: bool = True, + init_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + splitlayer: Tuple = (0, 1), + ): + all_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + if not deterministic and self.config.fcm_max_ratio > 0: + # Apply forgetful causal mask + batch_size, seq_length = hidden_states.shape[0], hidden_states.shape[1] + fcm_ratio = jax.random.uniform( + self.make_rng('fcm'), + shape=(batch_size, 1, 1, 1), + minval=self.config.fcm_min_ratio, + maxval=self.config.fcm_max_ratio, + ) + fcm_mask = ( + jax.random.uniform( + self.make_rng('fcm'), shape=(batch_size, 1, 1, seq_length) + ) + > fcm_ratio + ) + fcm_mask = fcm_mask.at[:, :, :, 0].set(True) + fcm_mask = fcm_mask.astype('bool') + else: + fcm_mask = None + + for block in self.blocks[:2]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = block( + hidden_states, + attention_mask, + position_ids, + deterministic, + init_cache, + output_attentions, + fcm_mask, + ) + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions += (layer_outputs[1],) + + # this contains possible `None` values - `FlaxGPTJModule` will filter + # them out + outputs = (hidden_states, all_hidden_states, all_attentions) + + # return outputs + return hidden_states + + +class FlaxLLaMABlockCollectionMid(nn.Module): + config: LLaMAConfig + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[Union[jax.lax.Precision, str]] = None + + def setup(self): + block = FlaxLLaMABlock + if self.config.remat_block != '': + block = remat( + FlaxLLaMABlock, + static_argnums=(3, 4, 5), + policy=get_gradient_checkpoint_policy(self.config.remat_block), + ) + self.blocks = [ + block( + self.config, + name=str(i), + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + ) + for i in range(self.config.num_hidden_layers) + ] + + def __call__( + self, + hidden_states, + attention_mask=None, + position_ids=None, + deterministic: bool = True, + init_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + splitlayer: Tuple = (0, 1), + ): + all_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + if not deterministic and self.config.fcm_max_ratio > 0: + # Apply forgetful causal mask + batch_size, seq_length = hidden_states.shape[0], hidden_states.shape[1] + fcm_ratio = jax.random.uniform( + self.make_rng('fcm'), + shape=(batch_size, 1, 1, 1), + minval=self.config.fcm_min_ratio, + maxval=self.config.fcm_max_ratio, + ) + fcm_mask = ( + jax.random.uniform( + self.make_rng('fcm'), shape=(batch_size, 1, 1, seq_length) + ) + > fcm_ratio + ) + fcm_mask = fcm_mask.at[:, :, :, 0].set(True) + 
fcm_mask = fcm_mask.astype('bool') + else: + fcm_mask = None + + for block in self.blocks[2:3]: + # print(block) + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = block( + hidden_states, + attention_mask, + position_ids, + deterministic, + init_cache, + output_attentions, + fcm_mask, + ) + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions += (layer_outputs[1],) + + # this contains possible `None` values - `FlaxGPTJModule` will filter + # them out + outputs = (hidden_states, all_hidden_states, all_attentions) + + # return outputs + return hidden_states + + +class FlaxLLaMABlockCollectionServer(nn.Module): + config: LLaMAConfig + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[Union[jax.lax.Precision, str]] = None + + def setup(self): + block = FlaxLLaMABlock + if self.config.remat_block != '': + block = remat( + FlaxLLaMABlock, + static_argnums=(3, 4, 5), + policy=get_gradient_checkpoint_policy(self.config.remat_block), + ) + self.blocks = [ + block( + self.config, + name=str(i), + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + ) + for i in range(self.config.num_hidden_layers) + ] + + def __call__( + self, + hidden_states, + attention_mask=None, + position_ids=None, + deterministic: bool = True, + init_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + splitlayer: Tuple = (0, 1), + ): + all_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + if not deterministic and self.config.fcm_max_ratio > 0: + # Apply forgetful causal mask + batch_size, seq_length = hidden_states.shape[0], hidden_states.shape[1] + fcm_ratio = jax.random.uniform( + self.make_rng('fcm'), + shape=(batch_size, 1, 1, 1), + minval=self.config.fcm_min_ratio, + maxval=self.config.fcm_max_ratio, + ) + fcm_mask = ( + jax.random.uniform( + self.make_rng('fcm'), shape=(batch_size, 1, 1, seq_length) + ) + > fcm_ratio + ) + fcm_mask = fcm_mask.at[:, :, :, 0].set(True) + fcm_mask = fcm_mask.astype('bool') + else: + fcm_mask = None + + for block in self.blocks[3:]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = block( + hidden_states, + attention_mask, + position_ids, + deterministic, + init_cache, + output_attentions, + fcm_mask, + ) + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions += (layer_outputs[1],) + # this contains possible `None` values - `FlaxGPTJModule` will filter + # them out + outputs = (hidden_states, all_hidden_states, all_attentions) + + return outputs + + +class FlaxLLaMAModule(nn.Module): + config: LLaMAConfig + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[Union[jax.lax.Precision, str]] = None + + def setup(self): + self.embed_dim = self.config.hidden_size + + self.wte = nn.Embed( + self.config.vocab_size, + self.config.hidden_size, + embedding_init=jax.nn.initializers.normal( + stddev=self.config.initializer_range + ), + dtype=self.dtype, + param_dtype=self.param_dtype, + ) + self.dropout = nn.Dropout(rate=self.config.embd_pdrop) + self.h = FlaxLLaMABlockCollection( + self.config, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + ) + self.ln_f = RMSNorm( + self.config.hidden_size, + eps=self.config.rms_norm_eps, + dtype=self.dtype, + param_dtype=self.param_dtype, + ) + + def __call__( + self, + 
input_ids, + attention_mask, + position_ids, + deterministic=True, + init_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + input_embeds = self.wte(input_ids.astype("i4")) + + hidden_states = self.dropout(input_embeds, deterministic=deterministic) + + outputs = self.h( + hidden_states, + attention_mask, + position_ids=position_ids, + deterministic=deterministic, + init_cache=init_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + hidden_states = self.ln_f(hidden_states) + + if output_hidden_states: + all_hidden_states = outputs[1] + (hidden_states,) + outputs = (hidden_states, all_hidden_states) + outputs[2:] + else: + outputs = (hidden_states,) + outputs[1:] + + if not return_dict: + return tuple(v for v in outputs if v is not None) + + return FlaxBaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=outputs[1], + attentions=outputs[-1], + ) + + +class FlaxLLaMAModuleClientEmbed(nn.Module): + config: LLaMAConfig + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[Union[jax.lax.Precision, str]] = None + + def setup(self): + self.embed_dim = self.config.hidden_size + + self.wte = nn.Embed( + self.config.vocab_size, + self.config.hidden_size, + embedding_init=jax.nn.initializers.normal( + stddev=self.config.initializer_range + ), + dtype=self.dtype, + param_dtype=self.param_dtype, + ) + self.dropout = nn.Dropout(rate=self.config.embd_pdrop) + self.h = FlaxLLaMABlockCollectionClient( + self.config, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + ) + + def __call__( + self, + input_ids, + attention_mask, + position_ids, + deterministic=True, + init_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + splitlayer: Tuple = (0, 1), + ): + input_embeds = self.wte(input_ids.astype("i4")) + + hidden_states = self.dropout(input_embeds, deterministic=deterministic) + outputs = self.h( + hidden_states, + attention_mask, + position_ids=position_ids, + deterministic=deterministic, + init_cache=init_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + splitlayer=splitlayer, + ) + + return outputs, attention_mask, position_ids + + +class FlaxLLaMAModuleMidEmbed(nn.Module): + config: LLaMAConfig + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[Union[jax.lax.Precision, str]] = None + + def setup(self): + self.embed_dim = self.config.hidden_size + + self.wte = nn.Embed( + self.config.vocab_size, + self.config.hidden_size, + embedding_init=jax.nn.initializers.normal( + stddev=self.config.initializer_range + ), + dtype=self.dtype, + param_dtype=self.param_dtype, + ) + self.dropout = nn.Dropout(rate=self.config.embd_pdrop) + self.h = FlaxLLaMABlockCollectionMid( + self.config, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + ) + + def __call__( + self, + input_ids, + attention_mask, + position_ids, + deterministic=True, + init_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + splitlayer: Tuple = (0, 1), + ): + if len(input_ids.shape) == 2: + input_embeds = self.wte(input_ids.astype("i4")) + + hidden_states = self.dropout(input_embeds, 
deterministic=deterministic) + + else: + hidden_states = input_ids # temporary workaround to bypass Jax (input is already hidden states) + + outputs = self.h( + hidden_states, + attention_mask, + position_ids=position_ids, + deterministic=deterministic, + init_cache=init_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + splitlayer=splitlayer, + ) + + return outputs + + +class FlaxLLaMAModuleServerEmbed(nn.Module): + config: LLaMAConfig + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[Union[jax.lax.Precision, str]] = None + + def setup(self): + self.h = FlaxLLaMABlockCollectionServer( + self.config, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + ) + self.ln_f = RMSNorm( + self.config.hidden_size, + eps=self.config.rms_norm_eps, + dtype=self.dtype, + param_dtype=self.param_dtype, + ) + + def __call__( + self, + input_ids, + attention_mask, + position_ids, + deterministic=True, + init_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + server: bool = False, + splitlayer: Tuple = (0, 1), + ): + if len(input_ids.shape) == 2: + input_embeds = self.wte(input_ids.astype("i4")) + hidden_states = self.dropout(input_embeds, deterministic=deterministic) + + else: + hidden_states = input_ids + + outputs = self.h( + hidden_states, + attention_mask, + position_ids=position_ids, + deterministic=deterministic, + init_cache=init_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + splitlayer=splitlayer, + ) + + hidden_states = outputs[0] + hidden_states = self.ln_f(hidden_states) + + if output_hidden_states: + all_hidden_states = outputs[1] + (hidden_states,) + outputs = (hidden_states, all_hidden_states) + outputs[2:] + else: + outputs = (hidden_states,) + outputs[1:] + + if not return_dict: + return tuple(v for v in outputs if v is not None) + + return FlaxBaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=outputs[1], + attentions=outputs[-1], + ) + + +@add_start_docstrings("", "") +class FlaxLLaMAModel(FlaxLLaMAPreTrainedModel): + module_class = FlaxLLaMAModule + + +@add_start_docstrings("", "") +class FlaxLLaMAModelServer(FlaxLLaMAPreTrainedModelServer): + module_class = FlaxLLaMAModuleServerEmbed + + +@add_start_docstrings("", "") +class FlaxLLaMAModelClient(FlaxLLaMAPreTrainedModel): + module_class = FlaxLLaMAModuleClientEmbed + + +# @add_start_docstrings("", "") +# class FlaxLLaMAModelServer(FlaxLLaMAPreTrainedModel): +# module_class = FlaxLLaMAModule + +# append_call_sample_docstring( +# FlaxLLaMAModel, +# _TOKENIZER_FOR_DOC, +# _CHECKPOINT_FOR_DOC, +# FlaxCausalLMOutput, +# _CONFIG_FOR_DOC, +# ) + + +class FlaxLLaMAForCausalLMModule(nn.Module): + config: LLaMAConfig + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[Union[jax.lax.Precision, str]] = None + + def setup(self): + self.transformer = FlaxLLaMAModule(self.config, dtype=self.dtype) + self.lm_head = nn.Dense( + self.config.vocab_size, + dtype=self.dtype, + param_dtype=self.param_dtype, + use_bias=False, + kernel_init=jax.nn.initializers.normal( + stddev=self.config.initializer_range + ), + precision=self.precision, + ) + + def __call__( + self, + input_ids, + attention_mask=None, + position_ids=None, + deterministic: bool = True, + init_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True,
+ ): + batch_size, seq_length = input_ids.shape + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + if position_ids is None: + position_ids = jnp.broadcast_to( + jnp.clip(jnp.cumsum(attention_mask, axis=-1) - 1, a_min=0), + (batch_size, seq_length), + ) + outputs = self.transformer( + input_ids, + attention_mask, + position_ids, + deterministic=deterministic, + init_cache=init_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + + if self.config.tie_word_embeddings: + shared_kernel = self.transformer.variables["params"]["wte"]["embedding"].T + lm_logits = self.lm_head.apply( + {"params": {"kernel": shared_kernel}}, hidden_states + ) + else: + lm_logits = self.lm_head(hidden_states) + + if not return_dict: + return (lm_logits,) + outputs[1:] + + return FlaxCausalLMOutput( + logits=lm_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class FlaxLLaMAForCausalLMServerEmbedModule(nn.Module): + config: LLaMAConfig + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[Union[jax.lax.Precision, str]] = None + + def setup(self): + self.transformer = FlaxLLaMAModuleServerEmbed(self.config, dtype=self.dtype) + self.lm_head = nn.Dense( + self.config.vocab_size, + dtype=self.dtype, + param_dtype=self.param_dtype, + use_bias=False, + kernel_init=jax.nn.initializers.normal( + stddev=self.config.initializer_range + ), + precision=self.precision, + ) + + def __call__( + self, + input_ids, + attention_mask=None, + position_ids=None, + deterministic: bool = True, + init_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + splitlayer: Tuple = (0, 1), + ): + if len(input_ids.shape) == 2: + batch_size, seq_length = input_ids.shape + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + if position_ids is None: + position_ids = jnp.broadcast_to( + jnp.clip(jnp.cumsum(attention_mask, axis=-1) - 1, a_min=0), + (batch_size, seq_length), + ) + else: + assert attention_mask is not None + outputs = self.transformer( + input_ids, + attention_mask, + position_ids, + deterministic=deterministic, + init_cache=init_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + splitlayer=splitlayer, + ) + + hidden_states = outputs[0] + + if self.config.tie_word_embeddings: + shared_kernel = self.transformer.variables["params"]["wte"]["embedding"].T + lm_logits = self.lm_head.apply( + {"params": {"kernel": shared_kernel}}, hidden_states + ) + else: + lm_logits = self.lm_head(hidden_states) + + if not return_dict: + return (lm_logits,) + outputs[1:] + + return FlaxCausalLMOutput( + logits=lm_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class FlaxLLaMAForCausalLMMidEmbedModule(nn.Module): + config: LLaMAConfig + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[Union[jax.lax.Precision, str]] = None + + def setup(self): + self.transformer = FlaxLLaMAModuleMidEmbed(self.config, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask=None, + position_ids=None, + deterministic: bool = True, + init_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + splitlayer: Tuple = (0, 1), + ): + if len(input_ids.shape) == 2: + 
batch_size, seq_length = input_ids.shape + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + if position_ids is None: + position_ids = jnp.broadcast_to( + jnp.clip(jnp.cumsum(attention_mask, axis=-1) - 1, a_min=0), + (batch_size, seq_length), + ) + else: + assert attention_mask is not None + + outputs = self.transformer( + input_ids, + attention_mask, + position_ids, + deterministic=deterministic, + init_cache=init_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + splitlayer=splitlayer, + ) + + return outputs + + +class FlaxLLaMAForCausalLMClientEmbedModule(nn.Module): + config: LLaMAConfig + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + precision: Optional[Union[jax.lax.Precision, str]] = None + + def setup(self): + self.transformer = FlaxLLaMAModuleClientEmbed(self.config, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask=None, + position_ids=None, + deterministic: bool = True, + init_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + splitlayer: Tuple = (0, 1), + ): + batch_size, seq_length = input_ids.shape + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + if position_ids is None: + position_ids = jnp.broadcast_to( + jnp.clip(jnp.cumsum(attention_mask, axis=-1) - 1, a_min=0), + (batch_size, seq_length), + ) + outputs = self.transformer( + input_ids, + attention_mask, + position_ids, + deterministic=deterministic, + init_cache=init_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + splitlayer=splitlayer, + ) + + return outputs + + +@add_start_docstrings("", "") +class FlaxLLaMAForCausalLMMid(FlaxLLaMAPreTrainedModelServer): + module_class = FlaxLLaMAForCausalLMMidEmbedModule + + def prepare_inputs_for_generation( + self, input_ids, max_length, attention_mask: Optional[jax.Array] = None + ): + # initializing the cache + + if len(input_ids.shape) == 2: + batch_size, seq_length = input_ids.shape + else: + batch_size, seq_length = 1, 2048 + + past_key_values = self.init_cache(batch_size, max_length) + # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. + # But since GPTJ uses a causal mask, those positions are masked anyways. 
+ # Thus we can create a single static attention_mask here, which is more + # efficient for compilation + extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") + if attention_mask is not None: + position_ids = attention_mask.cumsum(axis=-1) - 1 + extended_attention_mask = lax.dynamic_update_slice( + extended_attention_mask, attention_mask, (0, 0) + ) + else: + position_ids = jnp.broadcast_to( + jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length) + ) + + return { + "past_key_values": past_key_values, + "attention_mask": extended_attention_mask, + "position_ids": position_ids, + } + + def update_inputs_for_generation(self, model_outputs, model_kwargs): + model_kwargs["past_key_values"] = model_outputs.past_key_values + model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1 + return model_kwargs + + +@add_start_docstrings("", "") +class FlaxLLaMAForCausalLMServer(FlaxLLaMAPreTrainedModelServer): + module_class = FlaxLLaMAForCausalLMServerEmbedModule + + def prepare_inputs_for_generation( + self, input_ids, max_length, attention_mask: Optional[jax.Array] = None + ): + # initializing the cache + + if len(input_ids.shape) == 2: + batch_size, seq_length = input_ids.shape + else: + batch_size, seq_length = input_ids.shape[:2] + + past_key_values = self.init_cache(batch_size, max_length) + # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. + # But since GPTJ uses a causal mask, those positions are masked anyways. + # Thus we can create a single static attention_mask here, which is more + # efficient for compilation + extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") + if attention_mask is not None: + position_ids = attention_mask.cumsum(axis=-1) - 1 + extended_attention_mask = lax.dynamic_update_slice( + extended_attention_mask, attention_mask, (0, 0) + ) + else: + position_ids = jnp.broadcast_to( + jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length) + ) + + return { + "past_key_values": past_key_values, + "attention_mask": extended_attention_mask, + "position_ids": position_ids, + } + + def update_inputs_for_generation(self, model_outputs, model_kwargs): + model_kwargs["past_key_values"] = model_outputs.past_key_values + model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1 + return model_kwargs + + +@add_start_docstrings("", "") +class FlaxLLaMAForCausalLMClient(FlaxLLaMAPreTrainedModel): + module_class = FlaxLLaMAForCausalLMClientEmbedModule + + def prepare_inputs_for_generation( + self, input_ids, max_length, attention_mask: Optional[jax.Array] = None + ): + # initializing the cache + batch_size, seq_length = input_ids.shape + + past_key_values = self.init_cache(batch_size, max_length) + # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. + # But since GPTJ uses a causal mask, those positions are masked anyways. 
+ # Thus we can create a single static attention_mask here, which is more + # efficient for compilation + extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") + if attention_mask is not None: + position_ids = attention_mask.cumsum(axis=-1) - 1 + extended_attention_mask = lax.dynamic_update_slice( + extended_attention_mask, attention_mask, (0, 0) + ) + else: + position_ids = jnp.broadcast_to( + jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length) + ) + + return { + "past_key_values": past_key_values, + "attention_mask": extended_attention_mask, + "position_ids": position_ids, + } + + def update_inputs_for_generation(self, model_outputs, model_kwargs): + model_kwargs["past_key_values"] = model_outputs.past_key_values + model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1 + return model_kwargs + + +@add_start_docstrings("", "") +class FlaxLLaMAForCausalLM(FlaxLLaMAPreTrainedModel): + module_class = FlaxLLaMAForCausalLMModule + + def prepare_inputs_for_generation( + self, input_ids, max_length, attention_mask: Optional[jax.Array] = None + ): + # initializing the cache + + batch_size, seq_length = input_ids.shape + + past_key_values = self.init_cache(batch_size, max_length) + # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. + # But since GPTJ uses a causal mask, those positions are masked anyways. + # Thus we can create a single static attention_mask here, which is more + # efficient for compilation + extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") + if attention_mask is not None: + position_ids = attention_mask.cumsum(axis=-1) - 1 + extended_attention_mask = lax.dynamic_update_slice( + extended_attention_mask, attention_mask, (0, 0) + ) + else: + position_ids = jnp.broadcast_to( + jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length) + ) + + return { + "past_key_values": past_key_values, + "attention_mask": extended_attention_mask, + "position_ids": position_ids, + } + + def update_inputs_for_generation(self, model_outputs, model_kwargs): + model_kwargs["past_key_values"] = model_outputs.past_key_values + model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1 + return model_kwargs + + +# append_call_sample_docstring( +# FlaxGPTJForCausalLM, +# _TOKENIZER_FOR_DOC, +# _CHECKPOINT_FOR_DOC, +# FlaxCausalLMOutput, +# _CONFIG_FOR_DOC, +# ) + + +VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"} + +PRETRAINED_VOCAB_FILES_MAP = {} + + +class LLaMATokenizer(PreTrainedTokenizer): + """ + Construct a LLaMA tokenizer. Based on byte-level Byte-Pair-Encoding. + Args: + vocab_file (`str`): + Path to the vocabulary file. 
+ """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file, + unk_token="", + bos_token="", + eos_token="", + sp_model_kwargs: Optional[Dict[str, Any]] = None, + add_bos_token=False, + add_eos_token=False, + **kwargs, + ): + self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs + super().__init__( + bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs + ) + self.vocab_file = vocab_file + self.add_bos_token = add_bos_token + self.add_eos_token = add_eos_token + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + + with tempfile.NamedTemporaryFile() as tfile: + with open_file(self.vocab_file, 'rb') as fin: + tfile.write(fin.read()) + tfile.flush() + tfile.seek(0) + self.sp_model.Load(tfile.name) + """ Initialisation""" + self.add_special_tokens( + dict( + unk_token=unk_token, + bos_token=bos_token, + eos_token=eos_token, + ) + ) + self.pad_token_id = self.unk_token_id + + @property + def vocab_size(self): + """Returns vocab size""" + return self.sp_model.get_piece_size() + + @property + def bos_token_id(self) -> Optional[int]: + return self.sp_model.bos_id() + + @property + def eos_token_id(self) -> Optional[int]: + return self.sp_model.eos_id() + + def get_vocab(self): + """Returns vocab as a dict""" + vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} + vocab.update(self.added_tokens_encoder) + return vocab + + def _tokenize(self, text): + """Returns a tokenized string.""" + return self.sp_model.encode(text, out_type=str) + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.sp_model.piece_to_id(token) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + token = self.sp_model.IdToPiece(index) + return token + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + current_sub_tokens = [] + out_string = "" + prev_is_special = False + for token in tokens: + # make sure that special tokens are not decoded using sentencepiece + # model + if token in self.all_special_tokens: + if not prev_is_special: + out_string += " " + out_string += self.sp_model.decode(current_sub_tokens) + token + prev_is_special = True + current_sub_tokens = [] + else: + current_sub_tokens.append(token) + prev_is_special = False + out_string += self.sp_model.decode(current_sub_tokens) + return out_string.strip() + + def save_vocabulary( + self, save_directory, filename_prefix: Optional[str] = None + ) -> Tuple[str]: + """ + Save the vocabulary and special tokens file to a directory. + Args: + save_directory (`str`): + The directory in which to save the vocabulary. + Returns: + `Tuple(str)`: Paths to the files saved. 
+ """ + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join( + save_directory, + (filename_prefix + "-" if filename_prefix else "") + + VOCAB_FILES_NAMES["vocab_file"], + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath( + out_vocab_file + ) and os.path.isfile(self.vocab_file): + copyfile(self.vocab_file, out_vocab_file) + elif not os.path.isfile(self.vocab_file): + with open(out_vocab_file, "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) + + return (out_vocab_file,) + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + if self.add_bos_token: + bos_token_ids = [self.bos_token_id] + else: + bos_token_ids = [] + + output = bos_token_ids + token_ids_0 + + if token_ids_1 is not None: + output = output + token_ids_1 + + if self.add_eos_token: + output = output + [self.eos_token_id] + + return output + + def get_special_tokens_mask( + self, + token_ids_0: List[int], + token_ids_1: Optional[List[int]] = None, + already_has_special_tokens: bool = False, + ) -> List[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, + token_ids_1=token_ids_1, + already_has_special_tokens=True, + ) + + if token_ids_1 is None: + return [1] + ([0] * len(token_ids_0)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1] + + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make + use of token type ids, therefore a list of zeros is returned. + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + Returns: + `List[int]`: List of zeros. + """ + eos = [self.eos_token_id] + + if token_ids_1 is None: + return len(token_ids_0 + eos) * [0] + return len(token_ids_0 + eos + token_ids_1 + eos) * [0]