This repository has been archived by the owner on Oct 25, 2024. It is now read-only.

Commit a9488b4

update

Signed-off-by: n1ck-guo <[email protected]>
n1ck-guo committed Jun 25, 2024
1 parent 2cc6a8f

Showing 8 changed files with 22 additions and 15 deletions.
@@ -113,17 +113,17 @@
 # get optimized model
 if args.h2o:
     print('Enable Small Cache Size')
-    from intel_extension_for_transformers.transformers.modeling.kv_cache_compression import H2OConfig, H2OLlamaForCausalLM
+    from intel_extension_for_transformers.transformers.kv_cache_compression import H2OConfig, LlamaForCausalLM
     h2o_config = H2OConfig(
         heavy_ratio=args.heavy_ratio,
         recent_ratio=args.recent_ratio,
         h2o_min_seqlen=args.h2o_min_seqlen,
         real_drop=args.real_drop,
         mean=False,
     )
-    user_model = H2OLlamaForCausalLM.from_pretrained(
+    user_model = LlamaForCausalLM.from_pretrained(
         args.model,
-        h2o_config=h2o_config,
+        prune_config=h2o_config,
         trust_remote_code=args.trust_remote_code)
     print("converted model: ", user_model)
 else:
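
For context, the script now imports from the package root and passes the pruning configuration as prune_config. Below is a minimal standalone sketch of the same call path; the checkpoint name, the hyperparameter values, and the per-field comments are illustrative rather than taken from this commit.

from transformers import AutoTokenizer
from intel_extension_for_transformers.transformers.kv_cache_compression import (
    H2OConfig,
    LlamaForCausalLM,
)

h2o_config = H2OConfig(
    heavy_ratio=0.1,      # fraction of the KV cache reserved for heavy-hitter tokens
    recent_ratio=0.1,     # fraction reserved for the most recent tokens
    h2o_min_seqlen=1024,  # assumption: sequences shorter than this are not pruned
    real_drop=True,       # physically drop pruned entries instead of only masking scores
    mean=False,
)

model_name = "meta-llama/Llama-2-7b-hf"  # illustrative checkpoint, not from this commit
tokenizer = AutoTokenizer.from_pretrained(model_name)
user_model = LlamaForCausalLM.from_pretrained(
    model_name,
    prune_config=h2o_config,  # renamed from h2o_config= by this commit
    trust_remote_code=True,
)

inputs = tokenizer("Heavy-hitter oracle (H2O) shrinks the KV cache by", return_tensors="pt")
print(tokenizer.decode(user_model.generate(**inputs, max_new_tokens=32)[0]))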
@@ -14,3 +14,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+from .prune.h2o import H2OConfig, H2OKVPruner
+from .models.modeling_llama import LlamaForCausalLM
+from .models.modeling_gaudi_llama import GaudiLlamaForCausalLM
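
These re-exports define the public surface used by the example above. A short sketch of what a caller can now import from the package root; the one-line descriptions are my reading of how the classes are used elsewhere in this commit, and treating GaudiLlamaForCausalLM as the Habana Gaudi (HPU) counterpart is an assumption.

from intel_extension_for_transformers.transformers.kv_cache_compression import (
    H2OConfig,              # pruning hyperparameters (heavy_ratio, recent_ratio, ...)
    H2OKVPruner,            # scores cached tokens and drops or masks the cold ones
    LlamaForCausalLM,       # Llama variant whose attention layers consult the pruner
    GaudiLlamaForCausalLM,  # assumption: the same wiring for Habana Gaudi (HPU) runs
)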
@@ -417,6 +417,10 @@ def pre_attn_forward(

         # pruning kv cache
         if self.pruner.real_drop:
+            if attention_mask is not None:  # no matter the length, we just slice it
+                causal_mask = attention_mask
+                if cache_position is not None:
+                    causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
             if self.layer_idx == 0:
                 self.pruner.past_length += query_states.size(-2)
             new_key_states, new_value_states = self.pruner.prune(
@@ -479,6 +483,10 @@ def pre_attn_forward(
         attn_weights = self.matmul_qk(query_states, key_states.transpose(-2, -1)) * self.norm_factor

         if not self.pruner.real_drop:
+            if attention_mask is not None:  # no matter the length, we just slice it
+                causal_mask = attention_mask
+                if cache_position is not None:
+                    causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
             mask = self.pruner.get_mask(self, query_states, key_states, value_states,
                                         causal_mask=causal_mask)
             attn_weights[~mask] = torch.finfo(attn_weights.dtype).min
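
Both branches now derive causal_mask from the incoming attention_mask, narrowing it to the query rows indexed by cache_position and to the number of keys actually held in the (possibly pruned) cache. A toy reproduction of just that indexing, with arbitrary shapes chosen for illustration:

import torch

batch, max_len = 2, 16
attention_mask = torch.zeros(batch, 1, max_len, max_len)  # 4-D additive causal mask
key_states = torch.randn(batch, 32, 10, 128)              # (batch, kv_heads, kv_len, head_dim)
cache_position = torch.tensor([9])                        # position of the token being decoded

causal_mask = attention_mask
if cache_position is not None:
    # keep only the mask rows for the current positions and the columns that
    # cover the keys still present in the cache
    causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]

print(causal_mask.shape)  # torch.Size([2, 1, 1, 10])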
@@ -25,7 +25,7 @@
 import torch.nn.functional as F
 from torch.nn import CrossEntropyLoss

-from transformers.cache_utils import Cache, StaticCache
+from transformers.cache_utils import Cache, StaticCache  # pylint: disable=E0611
 from transformers.modeling_attn_mask_utils import AttentionMaskConverter
 from transformers.utils import (
     add_start_docstrings_to_model_forward,
@@ -61,10 +61,10 @@
 if version.parse(transformers.__version__) > version.parse("4.33.0"):
     from transformers.utils import is_flash_attn_greater_or_equal_2_10, is_flash_attn_2_available
     if is_flash_attn_2_available():
-        from flash_attn import (
+        from flash_attn import (  # pylint: disable=E0401
             flash_attn_func,
-            flash_attn_varlen_func)  # pylint: disable=E1101
-        from flash_attn.bert_padding import (
+            flash_attn_varlen_func)  # pylint: disable=E0401
+        from flash_attn.bert_padding import (  # pylint: disable=E1101
             index_first_axis,
             pad_input,
             unpad_input)  # pylint: disable=E1101
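
The flash-attn imports above are already guarded by a transformers version check and an availability probe; the updated pylint markers only silence warnings about the optional dependency. A generic sketch of the same guarded-import pattern, with illustrative names:

import importlib.util

import transformers
from packaging import version

_flash_attn_available = (
    version.parse(transformers.__version__) > version.parse("4.33.0")
    and importlib.util.find_spec("flash_attn") is not None
)

if _flash_attn_available:
    from flash_attn import flash_attn_func  # pylint: disable=E0401
else:
    flash_attn_func = None  # callers fall back to the eager attention path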
@@ -861,7 +861,7 @@ def prepare_inputs_for_generation(
             and attention_mask is not None
             and cache_length + input_ids.shape[1] > max_cache_length
         ):
-            attention_mask = attention_mask[:, -max_cache_length:]
+            attention_mask = attention_mask[:, -max_cache_length:]  # pylint: disable=E1130

         position_ids = kwargs.get("position_ids", None)
         if attention_mask is not None and position_ids is None:
Expand Down Expand Up @@ -905,11 +905,6 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114

if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
@@ -928,7 +923,7 @@ def _update_causal_mask(
             inputs_embeds=input_tensor,
             past_key_values_length=past_seen_tokens,
             is_training=self.training,
-        ):
+        ):  # pylint: disable=E1101
             return None

         dtype, device = input_tensor.dtype, input_tensor.device
@@ -153,7 +153,7 @@ def __call__(self, attn_score, key_states, value_states, **kwargs):
         # hh-selection
         mask = torch.zeros(self.hh_score.shape, dtype=attn_score.dtype).to(key_states.device)
         if not self.recent_budget == 0:
-            mask[:,:,-self.recent_budget:] = 1
+            mask[:,:,-self.recent_budget:] = 1  # pylint: disable=E1130
             select_hh_scores = self.hh_score[:,:,:seq_len - self.recent_budget]

         if not self.heavy_budget == 0:
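
For readers unfamiliar with the selection logic this hunk touches: H2O keeps the recent_budget newest cache entries unconditionally and fills heavy_budget further slots with the tokens that have accumulated the highest attention scores. A self-contained sketch of that idea; the function name, shapes, and values are illustrative and not copied from the repository.

import torch

def h2o_keep_mask(hh_score: torch.Tensor, heavy_budget: int, recent_budget: int) -> torch.Tensor:
    """Return a boolean mask over cached tokens: True = keep, False = prune.

    hh_score holds the accumulated attention each cached token has received,
    shape (num_heads, seq_len) in this toy version.
    """
    seq_len = hh_score.shape[-1]
    mask = torch.zeros_like(hh_score, dtype=torch.bool)

    # 1) always keep the most recent tokens
    if recent_budget > 0:
        mask[:, -recent_budget:] = True

    # 2) among the older tokens, keep the heavy hitters with the largest scores
    if heavy_budget > 0:
        older = hh_score[:, : seq_len - recent_budget]
        _, top_idx = older.topk(min(heavy_budget, older.shape[-1]), dim=-1)
        mask.scatter_(dim=-1, index=top_idx, value=True)

    return mask

scores = torch.rand(2, 12)  # 2 heads, 12 cached tokens
keep = h2o_keep_mask(scores, heavy_budget=3, recent_budget=4)
print(keep.sum(dim=-1))     # 7 kept entries per head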
