diff --git a/examples/python/ml/flax_llama7b_split/flax_llama7b_split.py b/examples/python/ml/flax_llama7b_split/flax_llama7b_split.py
index 6798a105..14ef45fc 100644
--- a/examples/python/ml/flax_llama7b_split/flax_llama7b_split.py
+++ b/examples/python/ml/flax_llama7b_split/flax_llama7b_split.py
@@ -46,7 +46,8 @@
 parser = argparse.ArgumentParser(description='distributed driver.')
 parser.add_argument(
-    "-c", "--config", default="examples/python/ml/flax_llama_split/3pc.json")
+    "-c", "--config", default="examples/python/ml/flax_llama_split/3pc.json"
+)
 args = parser.parse_args()
 
 with open(args.config, 'r') as file:
@@ -83,7 +84,6 @@
 mid_params_dict = {
     "transformer": {
-        "h": {str(i): params['params']["transformer"]["h"][str(i)] for i in range(2, 3)}
     }
 }
@@ -95,7 +95,6 @@
         str(i): params['params']["transformer"]["h"][str(i)]
         for i in range(3, len(params['params']["transformer"]["h"]))
     },
-    },
     "lm_head": {
         "kernel": params['params']["lm_head"]["kernel"],
@@ -193,7 +192,6 @@ def embeding_generation(input_ids, params):
 
 
 def mid_generation(input_ids, params, attention_mask, position_ids):
-
     config = LLaMAConfig()
     _model = FlaxLLaMAForCausalLMMid(config=config)
@@ -222,7 +220,6 @@ def server_generation(input_ids, params, attention_mask, position_ids):
 
 
 def run_on_cpu(token_num=9):
-
     input_ids = tokenizer.encode(
         'Q: What is the largest animal?\nA:', return_tensors='jax'
     )
@@ -288,8 +285,7 @@ def run_on_spu(token_num=9):
         next_token_logits = outputs[0][0, -1, :]
         next_token = jnp.argmax(next_token_logits)
-        input_ids = jnp.concatenate(
-            [input_ids, jnp.array([[next_token]])], axis=1)
+        input_ids = jnp.concatenate([input_ids, jnp.array([[next_token]])], axis=1)
     return input_ids
diff --git a/examples/python/ml/flax_llama7b_split/llama_model_splited_transformer.py b/examples/python/ml/flax_llama7b_split/llama_model_splited_transformer.py
index 7d1bdcfb..f55356dd 100644
--- a/examples/python/ml/flax_llama7b_split/llama_model_splited_transformer.py
+++ b/examples/python/ml/flax_llama7b_split/llama_model_splited_transformer.py
@@ -802,8 +802,8 @@ def init_cache(self, batch_size, max_length):
         # if other_shape is None:
         input_ids = jnp.ones((batch_size, max_length))
         attention_mask = jnp.ones_like(input_ids)
-        position_ids = jnp.broadcast_to(jnp.arange(
-            jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape
+        position_ids = jnp.broadcast_to(
+            jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape
         )
 
         init_variables = self.module.init(
@@ -939,18 +939,18 @@ def __init__(
             input_shape=input_shape,
             seed=seed,
             dtype=dtype,
-            _do_init=_do_init
+            _do_init=_do_init,
         )
 
     def init_weights(
-            self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None
-        ) -> FrozenDict:
+        self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None
+    ) -> FrozenDict:
         # init input tensors
         input_ids = jnp.zeros(input_shape, dtype="i4")
         tmp_ids = jnp.zeros(input_shape[:2], dtype="i4")
         attention_mask = jnp.ones_like(jnp.zeros(input_shape[:2], dtype="i4"))
-        position_ids = jnp.broadcast_to(jnp.arange(
-            jnp.atleast_2d(tmp_ids).shape[-1]), input_shape[:2]
+        position_ids = jnp.broadcast_to(
+            jnp.arange(jnp.atleast_2d(tmp_ids).shape[-1]), input_shape[:2]
         )
         params_rng, dropout_rng = jax.random.split(rng)
         rngs = {"params": params_rng, "dropout": dropout_rng}
@@ -968,7 +968,6 @@ def init_weights(
                 return_dict=False,
             )
         else:
-
             module_init_outputs = self.module.init(
                 rngs, input_ids, attention_mask, position_ids, return_dict=False
             )
@@ -997,8 +996,8 @@ def init_cache(self, batch_size, max_length):
         # init input variables to retrieve cache
         input_ids = jnp.ones((batch_size, max_length))
         attention_mask = jnp.ones_like(input_ids)
-        position_ids = jnp.broadcast_to(jnp.arange(
-            jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape
+        position_ids = jnp.broadcast_to(
+            jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape
         )
 
         init_variables = self.module.init(
@@ -1007,7 +1006,7 @@ def init_cache(self, batch_size, max_length):
             attention_mask,
             position_ids,
             return_dict=False,
-            init_cache=True
+            init_cache=True,
         )
         return init_variables["cache"]
@@ -1108,8 +1107,9 @@ def setup(self):
         block = FlaxLLaMABlock
         if self.config.remat_block != '':
             block = remat(
-                FlaxLLaMABlock, static_argnums=(3, 4, 5),
-                policy=get_gradient_checkpoint_policy(self.config.remat_block)
+                FlaxLLaMABlock,
+                static_argnums=(3, 4, 5),
+                policy=get_gradient_checkpoint_policy(self.config.remat_block),
             )
         self.blocks = [
             block(
@@ -1117,8 +1117,9 @@
                 name=str(i),
                 dtype=self.dtype,
                 param_dtype=self.param_dtype,
-                precision=self.precision
-            ) for i in range(self.config.num_hidden_layers)
+                precision=self.precision,
+            )
+            for i in range(self.config.num_hidden_layers)
         ]
 
     def __call__(
@@ -1638,7 +1639,6 @@ class FlaxLLaMAModuleServerEmbed(nn.Module):
     precision: Optional[Union[jax.lax.Precision, str]] = None
 
     def setup(self):
-
         self.h = FlaxLLaMABlockCollectionServer(
             self.config,
             dtype=self.dtype,
@@ -1665,7 +1665,6 @@ def __call__(
         server: bool = False,
         splitlayer: int = [0, 1],
     ):
-
         if len(input_ids.shape) == 2:
             input_embeds = self.wte(input_ids.astype("i4"))
             hidden_states = self.dropout(input_embeds, deterministic=deterministic)
@@ -1709,7 +1708,6 @@ class FlaxLLaMAModel(FlaxLLaMAPreTrainedModel):
     module_class = FlaxLLaMAModule
 
 
-
 @add_start_docstrings("", "")
 class FlaxLLaMAModelServer(FlaxLLaMAPreTrainedModelServer):
     module_class = FlaxLLaMAModuleServerEmbed
@@ -1719,6 +1717,7 @@ class FlaxLLaMAModelServer(FlaxLLaMAPreTrainedModelServer):
 class FlaxLLaMAModelClient(FlaxLLaMAPreTrainedModel):
     module_class = FlaxLLaMAModuleClientEmbed
 
+
 # @add_start_docstrings("", "")
 # class FlaxLLaMAModelServer(FlaxLLaMAPreTrainedModel):
 #     module_class = FlaxLLaMAModule
@@ -1797,7 +1796,7 @@ def __call__(
         return FlaxCausalLMOutput(
             logits=lm_logits,
             hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions
+            attentions=outputs.attentions,
         )
@@ -1950,7 +1949,7 @@ def __call__(
         if position_ids is None:
             position_ids = jnp.broadcast_to(
                 jnp.clip(jnp.cumsum(attention_mask, axis=-1) - 1, a_min=0),
-                (batch_size, seq_length)
+                (batch_size, seq_length),
             )
         outputs = self.transformer(
             input_ids,
@@ -1972,8 +1971,8 @@ class FlaxLLaMAForCausalLMMid(FlaxLLaMAPreTrainedModelServer):
     module_class = FlaxLLaMAForCausalLMMidEmbedModule
 
     def prepare_inputs_for_generation(
-            self, input_ids, max_length, attention_mask: Optional[jax.Array] = None
-        ):
+        self, input_ids, max_length, attention_mask: Optional[jax.Array] = None
+    ):
         # initializing the cache
 
         if len(input_ids.shape) == 2:
@@ -2014,8 +2013,8 @@ class FlaxLLaMAForCausalLMServer(FlaxLLaMAPreTrainedModelServer):
     module_class = FlaxLLaMAForCausalLMServerEmbedModule
 
     def prepare_inputs_for_generation(
-            self, input_ids, max_length, attention_mask: Optional[jax.Array] = None
-        ):
+        self, input_ids, max_length, attention_mask: Optional[jax.Array] = None
+    ):
         # initializing the cache
 
         if len(input_ids.shape) == 2:
@@ -2056,8 +2055,8 @@ class FlaxLLaMAForCausalLMClient(FlaxLLaMAPreTrainedModel):
     module_class = FlaxLLaMAForCausalLMClientEmbedModule
 
     def prepare_inputs_for_generation(
-            self, input_ids, max_length, attention_mask: Optional[jax.Array] = None
-        ):
+        self, input_ids, max_length, attention_mask: Optional[jax.Array] = None
+    ):
         # initializing the cache
         batch_size, seq_length = input_ids.shape
@@ -2094,8 +2093,8 @@ class FlaxLLaMAForCausalLM(FlaxLLaMAPreTrainedModel):
     module_class = FlaxLLaMAForCausalLMModule
 
     def prepare_inputs_for_generation(
-            self, input_ids, max_length, attention_mask: Optional[jax.Array] = None
-        ):
+        self, input_ids, max_length, attention_mask: Optional[jax.Array] = None
+    ):
         # initializing the cache
         batch_size, seq_length = input_ids.shape
@@ -2243,8 +2242,8 @@ def convert_tokens_to_string(self, tokens):
         return out_string.strip()
 
     def save_vocabulary(
-            self, save_directory, filename_prefix: Optional[str] = None
-        ) -> Tuple[str]:
+        self, save_directory, filename_prefix: Optional[str] = None
+    ) -> Tuple[str]:
         """
         Save the vocabulary and special tokens file to a directory.
         Args:
@@ -2263,8 +2262,8 @@ def save_vocabulary(
         )
 
         if os.path.abspath(self.vocab_file) != os.path.abspath(
-                out_vocab_file
-            ) and os.path.isfile(self.vocab_file):
+            out_vocab_file
+        ) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
         elif not os.path.isfile(self.vocab_file):
             with open(out_vocab_file, "wb") as fi:
@@ -2290,11 +2289,11 @@ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
         return output
 
     def get_special_tokens_mask(
-            self,
-            token_ids_0: List[int],
-            token_ids_1: Optional[List[int]] = None,
-            already_has_special_tokens: bool = False
-        ) -> List[int]:
+        self,
+        token_ids_0: List[int],
+        token_ids_1: Optional[List[int]] = None,
+        already_has_special_tokens: bool = False
+    ) -> List[int]:
         """
         Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
         special tokens using the tokenizer `prepare_for_model` method.
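
Side note on the last hunk of flax_llama7b_split.py: the change is purely formatting, but the loop it touches is the example's plain greedy decoder. Below is a minimal sketch of that step, with a hypothetical logits_fn standing in for the split client/mid/server forward pass that the example drives over SPU (not part of this diff):

import jax.numpy as jnp


def greedy_decode(logits_fn, input_ids, token_num=9):
    # Append the argmax token `token_num` times, mirroring run_on_cpu/run_on_spu.
    for _ in range(token_num):
        logits = logits_fn(input_ids)          # assumed shape: (batch, seq_len, vocab)
        next_token_logits = logits[0, -1, :]   # logits at the last position
        next_token = jnp.argmax(next_token_logits)
        input_ids = jnp.concatenate([input_ids, jnp.array([[next_token]])], axis=1)
    return input_ids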