From b1ab771c4a6fd8ed2ebf814cbf9baa64af9affc6 Mon Sep 17 00:00:00 2001
From: n1ck-guo
Date: Tue, 16 Jul 2024 00:49:00 -0400
Subject: [PATCH] update

Signed-off-by: n1ck-guo
---
 .../unitTest/coverage/.optimize-coveragerc |  1 +
 .../kv_cache_compression/__init__.py       |  4 +++-
 .../models/modeling_llama.py               |  8 ++++----
 .../kv_cache_compression/prune/h2o.py      | 20 ++++++++++++++-----
 4 files changed, 23 insertions(+), 10 deletions(-)

diff --git a/.github/workflows/script/unitTest/coverage/.optimize-coveragerc b/.github/workflows/script/unitTest/coverage/.optimize-coveragerc
index 7503552f189..1164f66dd99 100644
--- a/.github/workflows/script/unitTest/coverage/.optimize-coveragerc
+++ b/.github/workflows/script/unitTest/coverage/.optimize-coveragerc
@@ -18,6 +18,7 @@ omit =
     */intel_extension_for_transformers/langchain/**
     */intel_extension_for_transformers/llama_index/**
     */intel_extension_for_transformers/transformers/utils/get_throughput.py
+    */intel_extension_for_transformers/transformers/kv_cache_compression/models/**
 exclude_lines =
     pragma: no cover
     raise NotImplementedError
diff --git a/intel_extension_for_transformers/transformers/kv_cache_compression/__init__.py b/intel_extension_for_transformers/transformers/kv_cache_compression/__init__.py
index 6be078e6993..f73db323734 100644
--- a/intel_extension_for_transformers/transformers/kv_cache_compression/__init__.py
+++ b/intel_extension_for_transformers/transformers/kv_cache_compression/__init__.py
@@ -17,4 +17,6 @@
 
 from .prune.h2o import H2OConfig, H2OKVPruner
 from .models.modeling_llama import LlamaForCausalLM
-from .models.modeling_gaudi_llama import GaudiLlamaForCausalLM
+from intel_extension_for_transformers.transformers.utils.utility import LazyImport
+
+GaudiLlamaForCausalLM = LazyImport(".models.modeling_gaudi_llama.GaudiLlamaForCausalLM")
diff --git a/intel_extension_for_transformers/transformers/kv_cache_compression/models/modeling_llama.py b/intel_extension_for_transformers/transformers/kv_cache_compression/models/modeling_llama.py
index 1c8928ce4d1..91b56a031ec 100644
--- a/intel_extension_for_transformers/transformers/kv_cache_compression/models/modeling_llama.py
+++ b/intel_extension_for_transformers/transformers/kv_cache_compression/models/modeling_llama.py
@@ -693,10 +693,10 @@ def __init__(
         # Initialize weights and apply final processing
         self.post_init()
 
-        def _generate(**kwargs):
-            self.pruner.before_generate(self, **kwargs)
-            result = self.ori_generate(**kwargs)
-            self.pruner.after_generate(self, **kwargs)
+        def _generate(*args, **kwargs):
+            self.pruner.before_generate(self, *args, **kwargs)
+            result = self.ori_generate(*args, **kwargs)
+            self.pruner.after_generate(self, *args, **kwargs)
             return result
 
         self.ori_generate = self.generate
diff --git a/intel_extension_for_transformers/transformers/kv_cache_compression/prune/h2o.py b/intel_extension_for_transformers/transformers/kv_cache_compression/prune/h2o.py
index 99df3584c4c..ecb2f8ca3fe 100644
--- a/intel_extension_for_transformers/transformers/kv_cache_compression/prune/h2o.py
+++ b/intel_extension_for_transformers/transformers/kv_cache_compression/prune/h2o.py
@@ -44,7 +44,7 @@ def local_heavy_hitter_mask(attn_weights, heavy_budget, no_padding_seq_length=No
 
     for token_index in range(heavy_budget+padding_length, seq_length):
         tmp_attn_index = nn.functional.softmax(
-            attn_weights[:,:,token_index,:], dim=-1, dtype=torch.float32).to(dtype_attn_weights)
+            attn_weights[:,:,token_index,:], dim=-1, dtype=torch.float32).to(dtype_attn_weights)
         _, tmp_topk_index = accumulated_attention_score.topk(k=heavy_budget-1, dim=-1)
         zeros_index = torch.zeros_like(tmp_attn_index, dtype=torch.bool)
         mask_bottom_index = zeros_index.scatter(-1, tmp_topk_index, True) #(head, keys)
@@ -123,6 +123,9 @@ def __init__(
         mean=False
     ):
         ## bsz, num_heads, seq_len, head_dim
+        assert 0 <= heavy_ratio <= 1 and 0 <= recent_ratio <= 1, "ratio should be in [0, 1]"
+        assert heavy_budget is None or heavy_budget >= 0, "heavy_budget should be non-negative"
+        assert recent_budget is None or recent_budget >= 0, "recent_budget should be non-negative"
         self.heavy_ratio = heavy_ratio
         self.recent_ratio = recent_ratio
         self.heavy_budget = heavy_budget
@@ -221,15 +224,22 @@ def self_attn_init(self, module):
         )
 
     def before_generate(self, model, inputs, *args, **kwargs):
+        assert self.real_drop is True, 'H2O only supports real drop mode when using the generate function.'
         self.past_length = 0
-        max_length = kwargs['max_new_tokens'] if kwargs.get('max_new_tokens') else kwargs['max_length']
-        max_length += inputs.size(-1)
+        if kwargs.get('max_new_tokens', None):
+            max_length = kwargs['max_new_tokens'] + inputs.size(-1)
+        elif kwargs.get('max_length', None):
+            max_length = kwargs['max_length']
+        else:
+            max_length = model.config.max_length
+        if max_length <= inputs.size(-1):
+            max_length += inputs.size(-1)
         for _, module in model.named_modules():
             if "Attention" in module.__class__.__name__:
                 if module.h2o_kv_cache.heavy_budget is None:
-                    module.h2o_kv_cache.heavy_budget = int(max_length * module.h2o_kv_cache.heavy_ratio)
+                    module.h2o_kv_cache.heavy_budget = round(max_length * module.h2o_kv_cache.heavy_ratio)
                 if module.h2o_kv_cache.recent_budget is None:
-                    module.h2o_kv_cache.recent_budget = int(max_length * module.h2o_kv_cache.recent_ratio)
+                    module.h2o_kv_cache.recent_budget = round(max_length * module.h2o_kv_cache.recent_ratio)
                 if self.prune_kv_cache_size is None:
                     self.prune_kv_cache_size = module.h2o_kv_cache.recent_budget + module.h2o_kv_cache.heavy_budget
 
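
Note on the modeling_llama.py hunk: the original wrapper only accepted **kwargs, so a call such as model.generate(input_ids, max_new_tokens=32), with input_ids passed positionally, raised a TypeError before the pruner hooks ever ran. Below is a minimal, self-contained sketch of the same monkey-patching pattern; the toy classes and names are illustrative only, not the module's real ones.

    # Toy sketch (hypothetical classes) of the generate wrapper this patch fixes.
    class TinyPruner:
        def before_generate(self, model, *args, **kwargs):
            print("before_generate:", args, kwargs)

        def after_generate(self, model, *args, **kwargs):
            print("after_generate:", args, kwargs)

    class TinyModel:
        def __init__(self):
            self.pruner = TinyPruner()
            self.ori_generate = self.generate  # keep the unwrapped bound method

            def _generate(*args, **kwargs):
                # Forward positional args too; **kwargs alone would reject
                # generate(input_ids, ...) called positionally.
                self.pruner.before_generate(self, *args, **kwargs)
                result = self.ori_generate(*args, **kwargs)
                self.pruner.after_generate(self, *args, **kwargs)
                return result

            self.generate = _generate  # instance attribute shadows the class method

        def generate(self, input_ids, **kwargs):
            return [input_ids, kwargs.get("max_new_tokens", 0)]

    m = TinyModel()
    print(m.generate([1, 2, 3], max_new_tokens=8))  # positional input_ids now works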
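
The prune/h2o.py before_generate hunk changes two things: the target length now falls back from max_new_tokens to max_length to the model config, and the heavy/recent KV-cache budgets use round() instead of int(), so a 25.6-token budget becomes 26 rather than 25. A standalone sketch of that arithmetic follows; resolve_budgets and its parameters are illustrative helpers, not part of the library API.

    # Hypothetical helper mirroring the patched budget resolution (illustration only).
    def resolve_budgets(prompt_len, heavy_ratio, recent_ratio,
                        max_new_tokens=None, max_length=None, config_max_length=20):
        # Precedence after the patch: max_new_tokens, then max_length, then the config default.
        if max_new_tokens:
            total = max_new_tokens + prompt_len
        elif max_length:
            total = max_length
        else:
            total = config_max_length
        # Guard added by the patch: never budget fewer tokens than the prompt itself.
        if total <= prompt_len:
            total += prompt_len
        # round() keeps the half token of budget that int() would always truncate away.
        return round(total * heavy_ratio), round(total * recent_ratio)

    # 32-token prompt, 96 new tokens, 20% ratios: 128 * 0.2 = 25.6 -> 26 (int() would give 25).
    print(resolve_budgets(prompt_len=32, heavy_ratio=0.2, recent_ratio=0.2, max_new_tokens=96))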