
Commit

reformat
Rainysponge committed Oct 30, 2023
1 parent 34362bf commit b133688
Showing 2 changed files with 39 additions and 44 deletions.
10 changes: 3 additions & 7 deletions examples/python/ml/flax_llama7b_split/flax_llama7b_split.py
@@ -46,7 +46,8 @@

parser = argparse.ArgumentParser(description='distributed driver.')
parser.add_argument(
"-c", "--config", default="examples/python/ml/flax_llama_split/3pc.json")
"-c", "--config", default="examples/python/ml/flax_llama_split/3pc.json"
)
args = parser.parse_args()

with open(args.config, 'r') as file:
@@ -83,7 +84,6 @@

mid_params_dict = {
"transformer": {

"h": {str(i): params['params']["transformer"]["h"][str(i)] for i in range(2, 3)}
}
}
@@ -95,7 +95,6 @@
str(i): params['params']["transformer"]["h"][str(i)]
for i in range(3, len(params['params']["transformer"]["h"]))
},

},
"lm_head": {
"kernel": params['params']["lm_head"]["kernel"],
@@ -193,7 +192,6 @@ def embeding_generation(input_ids, params):


def mid_generation(input_ids, params, attention_mask, position_ids):

config = LLaMAConfig()
_model = FlaxLLaMAForCausalLMMid(config=config)

@@ -222,7 +220,6 @@ def server_generation(input_ids, params, attention_mask, position_ids):


def run_on_cpu(token_num=9):

input_ids = tokenizer.encode(
'Q: What is the largest animal?\nA:', return_tensors='jax'
)
@@ -288,8 +285,7 @@ def run_on_spu(token_num=9):
next_token_logits = outputs[0][0, -1, :]
next_token = jnp.argmax(next_token_logits)

input_ids = jnp.concatenate(
[input_ids, jnp.array([[next_token]])], axis=1)
input_ids = jnp.concatenate([input_ids, jnp.array([[next_token]])], axis=1)

return input_ids

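The decoding step reformatted in this hunk is plain greedy search. A standalone sketch, assuming a callable forward(input_ids) that returns logits of shape (batch, seq, vocab); the name forward is illustrative, not an API from this repository.

import jax.numpy as jnp

def greedy_step(forward, input_ids):
    logits = forward(input_ids)  # forward pass over the tokens generated so far
    next_token = jnp.argmax(logits[0, -1, :])  # most likely token at the last position
    return jnp.concatenate([input_ids, jnp.array([[next_token]])], axis=1)
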
73 changes: 36 additions & 37 deletions (second changed file)
@@ -802,8 +802,8 @@ def init_cache(self, batch_size, max_length):
# if other_shape is None:
input_ids = jnp.ones((batch_size, max_length))
attention_mask = jnp.ones_like(input_ids)
position_ids = jnp.broadcast_to(jnp.arange(
jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape
position_ids = jnp.broadcast_to(
jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape
)

init_variables = self.module.init(
@@ -939,18 +939,18 @@ def __init__(
input_shape=input_shape,
seed=seed,
dtype=dtype,
_do_init=_do_init
_do_init=_do_init,
)

def init_weights(
self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None
) -> FrozenDict:
self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None
) -> FrozenDict:
# init input tensors
input_ids = jnp.zeros(input_shape, dtype="i4")
tmp_ids = jnp.zeros(input_shape[:2], dtype="i4")
attention_mask = jnp.ones_like(jnp.zeros(input_shape[:2], dtype="i4"))
position_ids = jnp.broadcast_to(jnp.arange(
jnp.atleast_2d(tmp_ids).shape[-1]), input_shape[:2]
position_ids = jnp.broadcast_to(
jnp.arange(jnp.atleast_2d(tmp_ids).shape[-1]), input_shape[:2]
)
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
@@ -968,7 +968,6 @@ def init_weights(
return_dict=False,
)
else:

module_init_outputs = self.module.init(
rngs, input_ids, attention_mask, position_ids, return_dict=False
)
@@ -997,8 +996,8 @@ def init_cache(self, batch_size, max_length):
# init input variables to retrieve cache
input_ids = jnp.ones((batch_size, max_length))
attention_mask = jnp.ones_like(input_ids)
position_ids = jnp.broadcast_to(jnp.arange(
jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape
position_ids = jnp.broadcast_to(
jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape
)

init_variables = self.module.init(
@@ -1007,7 +1006,7 @@ def init_cache(self, batch_size, max_length):
attention_mask,
position_ids,
return_dict=False,
init_cache=True
init_cache=True,
)
return init_variables["cache"]

@@ -1108,17 +1107,19 @@ def setup(self):
block = FlaxLLaMABlock
if self.config.remat_block != '':
block = remat(
FlaxLLaMABlock, static_argnums=(3, 4, 5),
policy=get_gradient_checkpoint_policy(self.config.remat_block)
FlaxLLaMABlock,
static_argnums=(3, 4, 5),
policy=get_gradient_checkpoint_policy(self.config.remat_block),
)
self.blocks = [
block(
self.config,
name=str(i),
dtype=self.dtype,
param_dtype=self.param_dtype,
precision=self.precision
) for i in range(self.config.num_hidden_layers)
precision=self.precision,
)
for i in range(self.config.num_hidden_layers)
]

def __call__(
@@ -1638,7 +1639,6 @@ class FlaxLLaMAModuleServerEmbed(nn.Module):
precision: Optional[Union[jax.lax.Precision, str]] = None

def setup(self):

self.h = FlaxLLaMABlockCollectionServer(
self.config,
dtype=self.dtype,
@@ -1665,7 +1665,6 @@ def __call__(
server: bool = False,
splitlayer: int = [0, 1],
):

if len(input_ids.shape) == 2:
input_embeds = self.wte(input_ids.astype("i4"))
hidden_states = self.dropout(input_embeds, deterministic=deterministic)
@@ -1709,7 +1708,6 @@ class FlaxLLaMAModel(FlaxLLaMAPreTrainedModel):
module_class = FlaxLLaMAModule



@add_start_docstrings("", "")
class FlaxLLaMAModelServer(FlaxLLaMAPreTrainedModelServer):
module_class = FlaxLLaMAModuleServerEmbed
@@ -1719,6 +1717,7 @@ class FlaxLLaMAModelServer(FlaxLLaMAPreTrainedModelServer):
class FlaxLLaMAModelClient(FlaxLLaMAPreTrainedModel):
module_class = FlaxLLaMAModuleClientEmbed


# @add_start_docstrings("", "")
# class FlaxLLaMAModelServer(FlaxLLaMAPreTrainedModel):
# module_class = FlaxLLaMAModule
@@ -1797,7 +1796,7 @@ def __call__(
return FlaxCausalLMOutput(
logits=lm_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions
attentions=outputs.attentions,
)


@@ -1950,7 +1949,7 @@ def __call__(
if position_ids is None:
position_ids = jnp.broadcast_to(
jnp.clip(jnp.cumsum(attention_mask, axis=-1) - 1, a_min=0),
(batch_size, seq_length)
(batch_size, seq_length),
)
outputs = self.transformer(
input_ids,
@@ -1972,8 +1971,8 @@ class FlaxLLaMAForCausalLMMid(FlaxLLaMAPreTrainedModelServer):
module_class = FlaxLLaMAForCausalLMMidEmbedModule

def prepare_inputs_for_generation(
self, input_ids, max_length, attention_mask: Optional[jax.Array] = None
):
self, input_ids, max_length, attention_mask: Optional[jax.Array] = None
):
# initializing the cache

if len(input_ids.shape) == 2:
@@ -2014,8 +2013,8 @@ class FlaxLLaMAForCausalLMServer(FlaxLLaMAPreTrainedModelServer):
module_class = FlaxLLaMAForCausalLMServerEmbedModule

def prepare_inputs_for_generation(
self, input_ids, max_length, attention_mask: Optional[jax.Array] = None
):
self, input_ids, max_length, attention_mask: Optional[jax.Array] = None
):
# initializing the cache

if len(input_ids.shape) == 2:
@@ -2056,8 +2055,8 @@ class FlaxLLaMAForCausalLMClient(FlaxLLaMAPreTrainedModel):
module_class = FlaxLLaMAForCausalLMClientEmbedModule

def prepare_inputs_for_generation(
self, input_ids, max_length, attention_mask: Optional[jax.Array] = None
):
self, input_ids, max_length, attention_mask: Optional[jax.Array] = None
):
# initializing the cache
batch_size, seq_length = input_ids.shape

@@ -2094,8 +2093,8 @@ class FlaxLLaMAForCausalLM(FlaxLLaMAPreTrainedModel):
module_class = FlaxLLaMAForCausalLMModule

def prepare_inputs_for_generation(
self, input_ids, max_length, attention_mask: Optional[jax.Array] = None
):
self, input_ids, max_length, attention_mask: Optional[jax.Array] = None
):
# initializing the cache

batch_size, seq_length = input_ids.shape
@@ -2243,8 +2242,8 @@ def convert_tokens_to_string(self, tokens):
return out_string.strip()

def save_vocabulary(
self, save_directory, filename_prefix: Optional[str] = None
) -> Tuple[str]:
self, save_directory, filename_prefix: Optional[str] = None
) -> Tuple[str]:
"""
Save the vocabulary and special tokens file to a directory.
Args:
@@ -2263,8 +2262,8 @@ def save_vocabulary(
)

if os.path.abspath(self.vocab_file) != os.path.abspath(
out_vocab_file
) and os.path.isfile(self.vocab_file):
out_vocab_file
) and os.path.isfile(self.vocab_file):
copyfile(self.vocab_file, out_vocab_file)
elif not os.path.isfile(self.vocab_file):
with open(out_vocab_file, "wb") as fi:
@@ -2290,11 +2289,11 @@ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
return output

def get_special_tokens_mask(
self,
token_ids_0: List[int],
token_ids_1: Optional[List[int]] = None,
already_has_special_tokens: bool = False
) -> List[int]:
self,
token_ids_0: List[int],
token_ids_1: Optional[List[int]] = None,
already_has_special_tokens: bool = False
) -> List[int]:
"""
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
special tokens using the tokenizer `prepare_for_model` method.
