From a9488b4ddc9f9346e677e64d7ecd58a11cb9ea3f Mon Sep 17 00:00:00 2001
From: n1ck-guo
Date: Mon, 24 Jun 2024 21:48:15 -0400
Subject: [PATCH] update

Signed-off-by: n1ck-guo
---
 .../text-generation/h2o/run_generation.py          |  6 +++---
 .../kv_cache_compression/__init__.py               |  4 ++++
 .../kv_cache_compression/models/__init__.py        |  0
 .../models/modeling_gaudi_llama.py                 |  8 ++++++++
 .../models/modeling_llama.py                       | 17 ++++++-----------
 .../kv_cache_compression/prune/__init__.py         |  0
 .../kv_cache_compression/prune/base.py             |  0
 .../kv_cache_compression/prune/h2o.py              |  2 +-
 8 files changed, 22 insertions(+), 15 deletions(-)
 rename intel_extension_for_transformers/transformers/{modeling => }/kv_cache_compression/__init__.py (79%)
 rename intel_extension_for_transformers/transformers/{modeling => }/kv_cache_compression/models/__init__.py (100%)
 rename intel_extension_for_transformers/transformers/{modeling => }/kv_cache_compression/models/modeling_gaudi_llama.py (98%)
 rename intel_extension_for_transformers/transformers/{modeling => }/kv_cache_compression/models/modeling_llama.py (98%)
 rename intel_extension_for_transformers/transformers/{modeling => }/kv_cache_compression/prune/__init__.py (100%)
 rename intel_extension_for_transformers/transformers/{modeling => }/kv_cache_compression/prune/base.py (100%)
 rename intel_extension_for_transformers/transformers/{modeling => }/kv_cache_compression/prune/h2o.py (99%)

diff --git a/examples/huggingface/pytorch/text-generation/h2o/run_generation.py b/examples/huggingface/pytorch/text-generation/h2o/run_generation.py
index 94ba87dd634..225275d6323 100644
--- a/examples/huggingface/pytorch/text-generation/h2o/run_generation.py
+++ b/examples/huggingface/pytorch/text-generation/h2o/run_generation.py
@@ -113,7 +113,7 @@
 # get optimized model
 if args.h2o:
     print('Enable Small Cache Size')
-    from intel_extension_for_transformers.transformers.modeling.kv_cache_compression import H2OConfig, H2OLlamaForCausalLM
+    from intel_extension_for_transformers.transformers.kv_cache_compression import H2OConfig, LlamaForCausalLM
     h2o_config = H2OConfig(
         heavy_ratio=args.heavy_ratio,
         recent_ratio=args.recent_ratio,
@@ -121,9 +121,9 @@
         real_drop=args.real_drop,
         mean=False,
     )
-    user_model = H2OLlamaForCausalLM.from_pretrained(
+    user_model = LlamaForCausalLM.from_pretrained(
         args.model,
-        h2o_config=h2o_config,
+        prune_config=h2o_config,
         trust_remote_code=args.trust_remote_code)
     print("converted model: ", user_model)
 else:
diff --git a/intel_extension_for_transformers/transformers/modeling/kv_cache_compression/__init__.py b/intel_extension_for_transformers/transformers/kv_cache_compression/__init__.py
similarity index 79%
rename from intel_extension_for_transformers/transformers/modeling/kv_cache_compression/__init__.py
rename to intel_extension_for_transformers/transformers/kv_cache_compression/__init__.py
index 1e8078e9b40..98d0edde605 100644
--- a/intel_extension_for_transformers/transformers/modeling/kv_cache_compression/__init__.py
+++ b/intel_extension_for_transformers/transformers/kv_cache_compression/__init__.py
@@ -14,3 +14,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+from .prune.h2o import H2OConfig, H2OKVPruner
+from .models.modeling_llama import LlamaForCausalLM
+from .models.modeling_gaudi_llama import GaudiLlamaForCausalLM
\ No newline at end of file
diff --git a/intel_extension_for_transformers/transformers/modeling/kv_cache_compression/models/__init__.py b/intel_extension_for_transformers/transformers/kv_cache_compression/models/__init__.py
similarity index 100%
rename from intel_extension_for_transformers/transformers/modeling/kv_cache_compression/models/__init__.py
rename to intel_extension_for_transformers/transformers/kv_cache_compression/models/__init__.py
diff --git a/intel_extension_for_transformers/transformers/modeling/kv_cache_compression/models/modeling_gaudi_llama.py b/intel_extension_for_transformers/transformers/kv_cache_compression/models/modeling_gaudi_llama.py
similarity index 98%
rename from intel_extension_for_transformers/transformers/modeling/kv_cache_compression/models/modeling_gaudi_llama.py
rename to intel_extension_for_transformers/transformers/kv_cache_compression/models/modeling_gaudi_llama.py
index ac2cba82bcf..ed3248687a0 100644
--- a/intel_extension_for_transformers/transformers/modeling/kv_cache_compression/models/modeling_gaudi_llama.py
+++ b/intel_extension_for_transformers/transformers/kv_cache_compression/models/modeling_gaudi_llama.py
@@ -417,6 +417,10 @@ def pre_attn_forward(
 
         # pruning kv cache
        if self.pruner.real_drop:
+            if attention_mask is not None: # no matter the length, we just slice it
+                causal_mask = attention_mask
+                if cache_position is not None:
+                    causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
             if self.layer_idx == 0:
                 self.pruner.past_length += query_states.size(-2)
             new_key_states, new_value_states = self.pruner.prune(
@@ -479,6 +483,10 @@ def pre_attn_forward(
         attn_weights = self.matmul_qk(query_states, key_states.transpose(-2, -1)) * self.norm_factor
 
         if not self.pruner.real_drop:
+            if attention_mask is not None: # no matter the length, we just slice it
+                causal_mask = attention_mask
+                if cache_position is not None:
+                    causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
             mask = self.pruner.get_mask(self, query_states, key_states, value_states,
                                         causal_mask=causal_mask)
             attn_weights[~mask] = torch.finfo(attn_weights.dtype).min
diff --git a/intel_extension_for_transformers/transformers/modeling/kv_cache_compression/models/modeling_llama.py b/intel_extension_for_transformers/transformers/kv_cache_compression/models/modeling_llama.py
similarity index 98%
rename from intel_extension_for_transformers/transformers/modeling/kv_cache_compression/models/modeling_llama.py
rename to intel_extension_for_transformers/transformers/kv_cache_compression/models/modeling_llama.py
index 3a961b9c1e0..ca97047f1d3 100644
--- a/intel_extension_for_transformers/transformers/modeling/kv_cache_compression/models/modeling_llama.py
+++ b/intel_extension_for_transformers/transformers/kv_cache_compression/models/modeling_llama.py
@@ -25,7 +25,7 @@
 import torch.nn.functional as F
 from torch.nn import CrossEntropyLoss
 
-from transformers.cache_utils import Cache, StaticCache
+from transformers.cache_utils import Cache, StaticCache # pylint: disable=E0611
 from transformers.modeling_attn_mask_utils import AttentionMaskConverter
 from transformers.utils import (
     add_start_docstrings_to_model_forward,
@@ -61,10 +61,10 @@
 if version.parse(transformers.__version__) > version.parse("4.33.0"):
     from transformers.utils import is_flash_attn_greater_or_equal_2_10, is_flash_attn_2_available
     if is_flash_attn_2_available():
-        from flash_attn import (
+        from flash_attn import ( # pylint: disable=E0401
             flash_attn_func,
-            flash_attn_varlen_func) # pylint: disable=E1101
-        from flash_attn.bert_padding import (
+            flash_attn_varlen_func) # pylint: disable=E0401
+        from flash_attn.bert_padding import ( # pylint: disable=E1101
             index_first_axis,
             pad_input,
             unpad_input) # pylint: disable=E1101
@@ -861,7 +861,7 @@ def prepare_inputs_for_generation(
             and attention_mask is not None
             and cache_length + input_ids.shape[1] > max_cache_length
         ):
-            attention_mask = attention_mask[:, -max_cache_length:]
+            attention_mask = attention_mask[:, -max_cache_length:] # pylint: disable=E1130
 
         position_ids = kwargs.get("position_ids", None)
         if attention_mask is not None and position_ids is None:
@@ -905,11 +905,6 @@ def _update_causal_mask(
         past_key_values: Cache,
         output_attentions: bool,
     ):
-        # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
-        # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
-        # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
-        # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
-
         if self.config._attn_implementation == "flash_attention_2":
             if attention_mask is not None and 0.0 in attention_mask:
                 return attention_mask
@@ -928,7 +923,7 @@
             inputs_embeds=input_tensor,
             past_key_values_length=past_seen_tokens,
             is_training=self.training,
-        ):
+        ): # pylint: disable=E1101
             return None
 
         dtype, device = input_tensor.dtype, input_tensor.device
diff --git a/intel_extension_for_transformers/transformers/modeling/kv_cache_compression/prune/__init__.py b/intel_extension_for_transformers/transformers/kv_cache_compression/prune/__init__.py
similarity index 100%
rename from intel_extension_for_transformers/transformers/modeling/kv_cache_compression/prune/__init__.py
rename to intel_extension_for_transformers/transformers/kv_cache_compression/prune/__init__.py
diff --git a/intel_extension_for_transformers/transformers/modeling/kv_cache_compression/prune/base.py b/intel_extension_for_transformers/transformers/kv_cache_compression/prune/base.py
similarity index 100%
rename from intel_extension_for_transformers/transformers/modeling/kv_cache_compression/prune/base.py
rename to intel_extension_for_transformers/transformers/kv_cache_compression/prune/base.py
diff --git a/intel_extension_for_transformers/transformers/modeling/kv_cache_compression/prune/h2o.py b/intel_extension_for_transformers/transformers/kv_cache_compression/prune/h2o.py
similarity index 99%
rename from intel_extension_for_transformers/transformers/modeling/kv_cache_compression/prune/h2o.py
rename to intel_extension_for_transformers/transformers/kv_cache_compression/prune/h2o.py
index 802ebc495be..ffbb939237b 100644
--- a/intel_extension_for_transformers/transformers/modeling/kv_cache_compression/prune/h2o.py
+++ b/intel_extension_for_transformers/transformers/kv_cache_compression/prune/h2o.py
@@ -153,7 +153,7 @@ def __call__(self, attn_score, key_states, value_states, **kwargs):
         # hh-selection
         mask = torch.zeros(self.hh_score.shape, dtype=attn_score.dtype).to(key_states.device)
         if not self.recent_budget == 0:
-            mask[:,:,-self.recent_budget:] = 1
+            mask[:,:,-self.recent_budget:] = 1 # pylint: disable=E1130
         select_hh_scores = self.hh_score[:,:,:seq_len - self.recent_budget]
 
         if not self.heavy_budget == 0:
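
Usage sketch (not part of the patch): the run_generation.py hunk above shows the new call pattern — import from the relocated kv_cache_compression package, build an H2OConfig, and pass it via the renamed prune_config argument. The model id and ratio values below are illustrative placeholders; only the import path, class names, and keyword names come from this patch.

    from intel_extension_for_transformers.transformers.kv_cache_compression import H2OConfig, LlamaForCausalLM

    # Placeholder ratios; the example script takes these from its CLI arguments.
    h2o_config = H2OConfig(
        heavy_ratio=0.1,    # fraction of heavy-hitter tokens kept in the KV cache
        recent_ratio=0.1,   # fraction of most recent tokens kept in the KV cache
        real_drop=True,     # physically drop pruned entries instead of masking them
        mean=False,
    )
    user_model = LlamaForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf",  # placeholder model id
        prune_config=h2o_config,     # renamed from the previous h2o_config= keyword
        trust_remote_code=True,
    )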