This repository has been archived by the owner on Oct 25, 2024. It is now read-only.

Commit

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Jun 18, 2024
1 parent 93ad39b commit 58bfcd0
Showing 7 changed files with 41 additions and 40 deletions.
File 1 of 7 (H2O KV cache and H2OConfig)
@@ -167,7 +167,7 @@ def __init__(
self.hh_score = None
self.min_seqlen = min_seqlen
self.idx = 0

self._past_length = 0


@@ -234,15 +234,15 @@ def clean_scores(self):
self.idx = 0
self.past_length = 0
self.hh_score = None

@property
def past_length(self):
return self._past_length

@past_length.setter
def past_length(self, value):
self._past_length = value

class H2OConfig(dict):
def __init__(
self,
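For orientation, the cache managed in this file follows the H2O (heavy-hitter oracle) idea: keep the key/value entries of tokens that have accumulated the largest attention scores (the `hh_score` above) plus a window of the most recent tokens, and evict the rest once the sequence exceeds `min_seqlen`. A minimal, hypothetical sketch of that selection rule, not this repository's exact implementation:

```python
import torch

def h2o_keep_mask(attn_scores: torch.Tensor, heavy_ratio: float, recent_ratio: float) -> torch.Tensor:
    """Toy H2O-style selection. attn_scores: (num_heads, q_len, kv_len).

    Returns a boolean mask over the kv_len cached positions to keep.
    """
    kv_len = attn_scores.shape[-1]
    num_heavy = int(heavy_ratio * kv_len)
    num_recent = int(recent_ratio * kv_len)

    # Accumulated attention each cached position has received ("heavy-hitter" score).
    hh_score = attn_scores.sum(dim=(0, 1))             # (kv_len,)

    keep = torch.zeros(kv_len, dtype=torch.bool)
    keep[kv_len - num_recent:] = True                  # always keep the recent window
    if num_heavy > 0:
        keep[hh_score.topk(num_heavy).indices] = True  # plus the heaviest hitters
    return keep

# With 100 cached tokens and heavy_ratio = recent_ratio = 0.1,
# at most ~20 positions of the KV cache survive eviction.
print(h2o_keep_mask(torch.rand(8, 1, 100).softmax(-1), 0.1, 0.1).sum().item())
```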
File 2 of 7 (Bloom)
@@ -287,7 +287,7 @@ def prepare_inputs_for_generation(

input_ids = input_ids[:, remove_prefix_length:]

# the cache may be in the stardard format (e.g. in contrastive search), convert to bloom's format if needed
# the cache may be in the standard format (e.g. in contrastive search), convert to bloom's format if needed
if past_key_values[0][0].shape[0] == input_ids.shape[0]:
past_key_values = self._convert_to_bloom_cache(past_key_values)

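For context on the hunk above, a hedged sketch of the two cache layouts the comment refers to and why the `shape[0] == input_ids.shape[0]` check tells them apart (illustrative shapes only; the actual conversion is done by `_convert_to_bloom_cache`):

```python
import torch

batch, num_heads, seq_len, head_dim = 2, 4, 7, 16

# Standard per-layer cache: (key, value), each (batch, num_heads, seq_len, head_dim).
std_key = torch.randn(batch, num_heads, seq_len, head_dim)

# Bloom's legacy layout fuses batch and heads, and stores keys transposed:
#   key:   (batch * num_heads, head_dim, seq_len)
#   value: (batch * num_heads, seq_len, head_dim)
bloom_key = std_key.permute(0, 1, 3, 2).reshape(batch * num_heads, head_dim, seq_len)

# Only the standard layout has batch_size as its leading dimension, which is
# exactly what the check in prepare_inputs_for_generation relies on.
print(std_key.shape[0] == batch)    # True  -> standard format, convert it
print(bloom_key.shape[0] == batch)  # False -> already in Bloom's format (num_heads > 1)
```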
@@ -326,11 +326,11 @@ def forward(
return_dict: Optional[bool] = None,
**deprecated_arguments,
) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
r"""Labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
"""
if deprecated_arguments.pop("position_ids", False) is not False:
# `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
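As a quick illustration of the `labels` convention described in the docstring above, here is a minimal sketch of what happens inside the model (assuming the standard Hugging Face causal-LM loss): tokens are shifted by one position and every label equal to `-100` is excluded from the loss.

```python
import torch
from torch.nn import CrossEntropyLoss

vocab_size, seq_len = 10, 5
logits = torch.randn(1, seq_len, vocab_size)  # model output
labels = torch.tensor([[3, -100, 7, 2, 9]])   # e.g. labels = input_ids with some positions masked

# Shift so that position t is trained to predict token t + 1.
shift_logits = logits[:, :-1, :]
shift_labels = labels[:, 1:]

# Entries equal to -100 contribute nothing to the loss.
loss = CrossEntropyLoss(ignore_index=-100)(
    shift_logits.reshape(-1, vocab_size), shift_labels.reshape(-1)
)
print(loss.item())
```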
@@ -388,8 +388,7 @@ def forward(
def _reorder_cache(
self, past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
"""
This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
"""This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
[`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
beam_idx at every generation step.
@@ -408,4 +407,4 @@ def _reorder_cache(
)
for layer_past in standardized_past
)
return self._convert_to_bloom_cache(reordered_past)
return self._convert_to_bloom_cache(reordered_past)
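To make the reordering concrete: for each layer, the cached key and value tensors are re-indexed along the batch-of-beams dimension so that every current beam carries the cache of the beam it was forked from. A small self-contained sketch using the standard layout (the Bloom variant additionally converts layouts, as shown above):

```python
import torch

num_beams, num_heads, seq_len, head_dim = 3, 2, 4, 8
layer_key = torch.randn(num_beams, num_heads, seq_len, head_dim)
layer_value = torch.randn(num_beams, num_heads, seq_len, head_dim)

# beam_idx[i] is the index of the previous beam that the i-th current beam continues.
beam_idx = torch.tensor([2, 0, 0])

reordered = tuple(
    past_state.index_select(0, beam_idx.to(past_state.device))
    for past_state in (layer_key, layer_value)
)
print(reordered[0].shape)                          # same shape as before
print(torch.equal(reordered[0][1], layer_key[0]))  # beam 1 now carries old beam 0's cache
```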
File 3 of 7 (GPT-NeoX)
@@ -64,7 +64,7 @@ def __init__(
self.hidden_size = config.hidden_size
if self.hidden_size % self.num_attention_heads != 0:
raise ValueError(
"The hidden size is not divisble by the number of attention heads! Make sure to update them"
"The hidden size is not divisible by the number of attention heads! Make sure to update them"
)
self.head_size = self.hidden_size // self.num_attention_heads
self.rotary_ndims = int(self.head_size * config.rotary_pct)
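A quick worked example of the arithmetic above, with hypothetical config values: `hidden_size=768` and `num_attention_heads=12` divide evenly, giving a head size of 64, and `rotary_pct=0.25` applies rotary embeddings to the first 16 dimensions of each head.

```python
hidden_size, num_attention_heads, rotary_pct = 768, 12, 0.25  # hypothetical values

assert hidden_size % num_attention_heads == 0, "hidden size must be divisible by the number of heads"
head_size = hidden_size // num_attention_heads  # 64
rotary_ndims = int(head_size * rotary_pct)      # 16
print(head_size, rotary_ndims)
```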
@@ -236,8 +236,9 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
return attn_output, attn_weights

class H2OGPTNeoXFlashAttention2(H2OGPTNeoXAttention):
"""
GPTNeoX flash attention module. This module inherits from `GPTNeoXAttention` as the weights of the module stays
"""GPTNeoX flash attention module.
This module inherits from `GPTNeoXAttention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
@@ -246,7 +247,7 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

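To illustrate what the comment above means by top-left versus bottom-right aligned causal masks when `q_seqlen != k_seqlen` (a sketch of the two conventions, not flash-attention code): with bottom-right alignment the queries line up with the last keys, which is what cached decoding expects.

```python
import torch

q_len, k_len = 2, 5

# Top-left aligned (flash_attn < 2.1): query i attends to keys 0..i.
top_left = torch.tril(torch.ones(q_len, k_len, dtype=torch.bool))

# Bottom-right aligned (flash_attn >= 2.1): query i attends to keys
# 0..(k_len - q_len + i), i.e. the queries align with the *last* keys.
bottom_right = torch.tril(torch.ones(q_len, k_len, dtype=torch.bool), diagonal=k_len - q_len)

print(top_left.int())      # [[1, 0, 0, 0, 0],
                           #  [1, 1, 0, 0, 0]]
print(bottom_right.int())  # [[1, 1, 1, 1, 0],
                           #  [1, 1, 1, 1, 1]]
```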
@@ -515,7 +516,7 @@ def __init__(self, config, h2o_config):
cls = H2OGPTNeoXFlashAttention2
else:
cls = H2OGPTNeoXAttention

self.gpt_neox.layers[layer_idx].self_attn = cls(
config,
layer_idx,
@@ -547,8 +548,7 @@ def forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
r"""past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional tensors are
@@ -585,7 +585,8 @@ def forward(
>>> outputs = model(**inputs)
>>> prediction_logits = outputs.logits
```"""
```
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs = self.gpt_neox(
@@ -678,4 +679,4 @@ def _reorder_cache(self, past_key_values, beam_idx):
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
+ layer_past[2:],
)
return reordered_past
return reordered_past
File 4 of 7 (Llama)
@@ -149,7 +149,7 @@ def __init__(
self.h2o_min_seqlen = h2o_config.h2o_min_seqlen

self.h2o_kv_cache = H2OKVCache(self.heavy_ratio, self.recent_ratio, h2o_config.h2o_min_seqlen)

def _init_rope(self):
if self.config.rope_scaling is None:
self.rotary_emb = LlamaRotaryEmbedding(
@@ -176,7 +176,7 @@ def _init_rope(self):
)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")


def forward(
self,
@@ -645,7 +645,7 @@ def __init__(
self.model = LlamaModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

num_layers = len(self.model.layers)
for layer_idx in range(num_layers):
module = self.model.layers[layer_idx].self_attn
@@ -658,19 +658,19 @@ def __init__(
cls = H2OLlamaSdpaAttention
else:
cls = H2OLlamaAttention

self.model.layers[layer_idx].self_attn = cls(
config,
layer_idx,
h2o_config
)

# Initialize weights and apply final processing
self.post_init()

self.ori_generate = self.generate
self.generate = partial(generate, self)

def get_input_embeddings(self):
return self.model.embed_tokens

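A small sketch of the patching pattern used a few lines above (`self.ori_generate = self.generate` followed by `self.generate = partial(generate, self)`): the bound `generate` method is saved and then shadowed by a module-level function that receives the model as its first argument. The names below are hypothetical stand-ins for illustration.

```python
from functools import partial

def custom_generate(model, prompt, **kwargs):
    # Stand-in for the module-level `generate`, e.g. it could reset the H2O
    # KV cache here before delegating to the original implementation.
    model.wrapper_calls += 1
    return model.ori_generate(prompt, **kwargs)

class TinyModel:
    def __init__(self):
        self.wrapper_calls = 0
        self.ori_generate = self.generate               # keep the original bound method
        self.generate = partial(custom_generate, self)  # route future calls through the wrapper

    def generate(self, prompt, **kwargs):
        return f"generated from: {prompt}"

model = TinyModel()
print(model.generate("hello"))  # goes through custom_generate, then the original generate
print(model.wrapper_calls)      # 1
```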
@@ -688,7 +688,7 @@ def set_decoder(self, decoder):

def get_decoder(self):
return self.model

@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
@@ -783,7 +783,7 @@ def forward(
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)

def prepare_inputs_for_generation(
self,
input_ids,
@@ -864,7 +864,7 @@ def prepare_inputs_for_generation(
}
)
return model_inputs

def _update_causal_mask(
self,
attention_mask: torch.Tensor,
@@ -943,4 +943,4 @@ def _update_causal_mask(
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

return causal_mask
return causal_mask
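For context on the comment above (pytorch/pytorch#110213): if a row of the additive causal mask is entirely `min_dtype` (a query position that is allowed to attend to nothing, e.g. pure padding), the softmax inside SDPA's memory-efficient backend produces NaNs, so such rows are "unmasked". A hedged, hand-rolled sketch of that idea; the real logic lives in `AttentionMaskConverter._unmask_unattended`:

```python
import torch

min_dtype = torch.finfo(torch.float32).min

# Additive mask of shape (batch=1, 1, q_len=3, kv_len=3); row 0 is fully masked.
causal_mask = torch.full((1, 1, 3, 3), min_dtype)
causal_mask[0, 0, 1, :2] = 0.0
causal_mask[0, 0, 2, :] = 0.0

# Rows that attend to nothing are reset to zero so softmax stays well defined.
fully_masked_rows = (causal_mask == min_dtype).all(dim=-1, keepdim=True)
causal_mask = causal_mask.masked_fill(fully_masked_rows, 0.0)

print(causal_mask[0, 0])  # row 0 is now all zeros; the other rows are unchanged
```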
File 5 of 7 (Mistral)
@@ -231,8 +231,9 @@ def forward(
return attn_output, attn_weights, past_key_value

class H2OMistralFlashAttention2(H2OMistralAttention):
"""
Mistral flash attention module. This module inherits from `MistralAttention` as the weights of the module stays
"""Mistral flash attention module.
This module inherits from `MistralAttention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
@@ -242,7 +243,7 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

@@ -689,7 +690,7 @@ def __init__(
cls = H2OMistralSdpaAttention
else:
cls = H2OMistralAttention

self.model.layers[layer_idx].self_attn = cls(
config,
layer_idx,
@@ -876,4 +877,4 @@ def _reorder_cache(past_key_values, beam_idx):
reordered_past += (
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
)
return reordered_past
return reordered_past
File 6 of 7 (Mixtral)
@@ -678,7 +678,7 @@ def __init__(self, config, h2o_config: H2OConfig):
cls = H2OMixtralSdpaAttention
else:
cls = H2OMixtralAttention

self.model.layers[layer_idx].self_attn = cls(
config,
layer_idx,
File 7 of 7 (OPT)
@@ -43,7 +43,7 @@ def __init__(
self,
config: OPTConfig,
is_decoder: bool = False,
h2o_config: H2OConfig = None,
h2o_config: H2OConfig = None,
):
super().__init__()
self.config = config
@@ -618,4 +618,4 @@ def _reorder_cache(past_key_values, beam_idx):
reordered_past += (
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
)
return reordered_past
return reordered_past
