From 124bb72a3ca1a1069b92d17c239807e36293e0f4 Mon Sep 17 00:00:00 2001 From: n1ck-guo Date: Tue, 18 Jun 2024 02:07:52 -0400 Subject: [PATCH] refactor code Signed-off-by: n1ck-guo --- .../text-generation/h2o/run_generation.py | 79 ++- .../modeling/kv_cache_compression/__init__.py | 8 +- .../modeling/kv_cache_compression/h2o.py | 111 ++-- .../models/modeling_bloom.py | 223 ++++++- .../models/modeling_gpt_neox.py | 485 +++++++++++++- .../models/modeling_llama.py | 426 +++++++++++- .../models/modeling_mistral.py | 615 +++++++++++++++++- .../models/modeling_mixtral.py | 318 ++++++++- .../models/modeling_opt.py | 305 +++++++-- 9 files changed, 2311 insertions(+), 259 deletions(-) diff --git a/examples/huggingface/pytorch/text-generation/h2o/run_generation.py b/examples/huggingface/pytorch/text-generation/h2o/run_generation.py index b41712d9d96..94ba87dd634 100644 --- a/examples/huggingface/pytorch/text-generation/h2o/run_generation.py +++ b/examples/huggingface/pytorch/text-generation/h2o/run_generation.py @@ -1,6 +1,5 @@ import argparse import sys -sys.path.insert(0, '/home/hengguo/code/intel-extension-for-transformers') import time import json import torch @@ -66,7 +65,6 @@ # transformers version >= 4.32.0 contained the mpt modeling definition. # https://github.com/huggingface/transformers/blob/main/src/transformers/models/mpt/modeling_mpt.py # 4.31.0 for ipex.optimize_transformers -check_min_version("4.35.2") # get model config if args.peft_model_id: from peft import PeftConfig @@ -111,24 +109,26 @@ device = args.device else: device = f"cuda:{args.device}" -user_model = AutoModelForCausalLM.from_pretrained(args.model, trust_remote_code=args.trust_remote_code) -user_model.to(device) # get optimized model if args.h2o: print('Enable Small Cache Size') - # checkpoint = copy.deepcopy(model.state_dict()) - # model = ENABLE_Heavy_Hitter_FUNCTIONS[args.model_type](model, config) - from intel_extension_for_transformers.transformers.modeling.kv_cache_compression import convert_model - user_model = convert_model( - user_model, + from intel_extension_for_transformers.transformers.modeling.kv_cache_compression import H2OConfig, H2OLlamaForCausalLM + h2o_config = H2OConfig( heavy_ratio=args.heavy_ratio, recent_ratio=args.recent_ratio, h2o_min_seqlen=args.h2o_min_seqlen, real_drop=args.real_drop, - is_gen=args.is_gen - ) + mean=False, + ) + user_model = H2OLlamaForCausalLM.from_pretrained( + args.model, + h2o_config=h2o_config, + trust_remote_code=args.trust_remote_code) print("converted model: ", user_model) +else: + user_model = AutoModelForCausalLM.from_pretrained(args.model, trust_remote_code=args.trust_remote_code) +user_model.to(device) # save model # if args.output_dir is not None: @@ -194,18 +194,45 @@ if args.accuracy: user_model = (user_model.eval() if (not (args.int8 or args.int8_bf16_mixed) and hasattr(user_model, "eval")) \ else user_model) - from intel_extension_for_transformers.transformers.llm.evaluation.lm_eval import evaluate, LMEvalParser - model_args="pretrained="+args.model+",trust_remote_code="+str(args.trust_remote_code) - args.tasks = ",".join(args.tasks) - tokenizer.pad_token = tokenizer.eos_token - eval_args = LMEvalParser(model = "hf", - user_model=user_model, - tokenizer=tokenizer, - model_args=model_args, - tasks = args.tasks, - device = device, - num_fewshot=args.num_fewshot, - output_path=args.save_accuracy_path, - batch_size = args.batch_size) - print("using device:", device) - results = evaluate(eval_args) \ No newline at end of file + # from 
intel_extension_for_transformers.transformers.llm.evaluation.lm_eval import evaluate, LMEvalParser + # model_args="pretrained="+args.model+",trust_remote_code="+str(args.trust_remote_code) + # args.tasks = ",".join(args.tasks) + # tokenizer.pad_token = tokenizer.eos_token + # eval_args = LMEvalParser(model = "hf", + # user_model=user_model, + # tokenizer=tokenizer, + # model_args=model_args, + # tasks = args.tasks, + # device = device, + # num_fewshot=args.num_fewshot, + # output_path=args.save_accuracy_path, + # batch_size = args.batch_size) + # print("using device:", device) + # results = evaluate(eval_args) + + + # original lm_eval + from lm_eval.evaluator import simple_evaluate + from lm_eval.tasks import TaskManager + import lm_eval + + verbosity = 'INFO' + task_manager = TaskManager(verbosity) + limit = None + cache_requests = False + lm = lm_eval.api.registry.get_model("hf")( + pretrained=user_model, + batch_size=args.batch_size, + max_batch_size=None, + ) + model_args="pretrained="+ args.model+ ",tokenizer="+ args.model + ",dtype=float32" + use_cache = None + results = simple_evaluate( + model=lm, + model_args=model_args, + tasks=args.tasks, + num_fewshot=args.num_fewshot, + device=device + ) + import pprint + pprint.pprint(results["results"]) diff --git a/intel_extension_for_transformers/transformers/modeling/kv_cache_compression/__init__.py b/intel_extension_for_transformers/transformers/modeling/kv_cache_compression/__init__.py index 99ca690b652..471a68ace95 100644 --- a/intel_extension_for_transformers/transformers/modeling/kv_cache_compression/__init__.py +++ b/intel_extension_for_transformers/transformers/modeling/kv_cache_compression/__init__.py @@ -15,4 +15,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from .h2o import convert_model +from .h2o import H2OConfig +from .models.modeling_llama import H2OLlamaForCausalLM +from .models.modeling_bloom import H2OBloomForCausalLM +from .models.modeling_gpt_neox import H2OGPTNeoXForCausalLM +from .models.modeling_opt import H2OOPTForCausalLM +from .models.modeling_mistral import H2OMistralForCausalLM +from .models.modeling_mixtral import H2OMixtralForCausalLM diff --git a/intel_extension_for_transformers/transformers/modeling/kv_cache_compression/h2o.py b/intel_extension_for_transformers/transformers/modeling/kv_cache_compression/h2o.py index b90b2a52490..2c15ab3725e 100644 --- a/intel_extension_for_transformers/transformers/modeling/kv_cache_compression/h2o.py +++ b/intel_extension_for_transformers/transformers/modeling/kv_cache_compression/h2o.py @@ -67,61 +67,22 @@ def get_module(model, op_name): def clean_cache(model): for _, module in model.named_modules(): - if "H2O" in module.__class__.__name__: + if "Attention" in module.__class__.__name__: module.h2o_kv_cache.clean_scores() def generate(model, **kwargs): + max_length = kwargs['max_new_tokens'] if kwargs.get('max_new_tokens') else kwargs['max_length'] for _, module in model.named_modules(): - if "H2O" in module.__class__.__name__: + if "Attention" in module.__class__.__name__: module.is_gen = True + if module.h2o_kv_cache.heavy_budget is None: + module.h2o_kv_cache.heavy_budget = int(max_length * module.h2o_kv_cache.heavy_ratio) + if module.h2o_kv_cache.recent_budget is None: + module.h2o_kv_cache.recent_budget = int(max_length * module.h2o_kv_cache.recent_ratio) result = model.ori_generate(**kwargs) clean_cache(model) return result -def convert_model( - model, - heavy_ratio, - recent_ratio, - h2o_min_seqlen=1024, - real_drop=True, - is_gen=False, - mean=False, - local=True - ): - model_type = model.config.model_type - device = model.device - atten_layers = [] - for name, module in model.named_modules(): - if "Attention" in module.__class__.__name__: - atten_layers.append(name) - - for layer_name in atten_layers: - module = get_module(model, layer_name) - cls_name = "H2O" + module.__class__.__name__ if real_drop else SIM_CLS_MAPPING[model_type] - h2o_cls = getattr( - importlib.import_module( - f".models.modeling_{model_type}", - "intel_extension_for_transformers.transformers.modeling.kv_cache_compression" - ), - cls_name) - module = h2o_cls( - module, - model.config, - heavy_ratio, - recent_ratio, - h2o_min_seqlen=h2o_min_seqlen, - real_drop=real_drop, - is_gen=is_gen, - mean=mean, - local=local - ) - set_module(model, layer_name, module) - model.clean_cache = lambda: clean_cache(model) - model.ori_generate = model.generate - model.generate = partial(generate, model) - model = model.to(device) - return model - def local_heavy_hitter_mask(attn_weights, heavy_budget, no_padding_seq_length=None): @@ -188,28 +149,37 @@ def get_hh_mask(heavy_budget_ratio, recent_budget_ratio, attn_weights, local=Tru return mask_bottom + class H2OKVCache: def __init__( self, heavy_ratio=0.2, recent_ratio=0.2, + heavy_budget=None, + recent_budget=None, min_seqlen=-1 ): ## bsz, num_heads, seq_len, head_dim self.heavy_ratio = heavy_ratio self.recent_ratio = recent_ratio + self.heavy_budget = heavy_budget + self.recent_budget = recent_budget self.hh_score = None self.min_seqlen = min_seqlen self.idx = 0 + + self._past_length = 0 def __call__(self, attn_score, key_states, value_states, mean=False, **kwargs): seq_len = key_states.size(-2) - heavy_budget = int(self.heavy_ratio * seq_len) - recent_budget = 
int(self.recent_ratio * seq_len) - cache_size = heavy_budget + recent_budget + if self.heavy_budget is None: + self.heavy_budget = int(self.heavy_ratio * seq_len) + if self.recent_budget is None: + self.recent_budget = int(self.recent_ratio * seq_len) + cache_size = self.heavy_budget + self.recent_budget if seq_len <= self.min_seqlen or seq_len <= cache_size: - return torch.ones(attn_score.shape[:-1], dtype=attn_score.dtype).to(key_states.device) + return key_states, value_states self.idx += 1 # attn_score shape (bsz, num_heads, seq_len, head_dim) if len(attn_score.shape) == 3: @@ -218,12 +188,12 @@ def __call__(self, attn_score, key_states, value_states, mean=False, **kwargs): # hh-selection mask = torch.zeros(self.hh_score.shape, dtype=attn_score.dtype).to(key_states.device) - if not recent_budget == 0: - mask[:,:,-recent_budget:] = 1 - select_hh_scores = self.hh_score[:,:,:seq_len - recent_budget] + if not self.recent_budget == 0: + mask[:,:,-self.recent_budget:] = 1 + select_hh_scores = self.hh_score[:,:,:seq_len - self.recent_budget] - if not heavy_budget == 0: - _, keep_topk = torch.topk(select_hh_scores, heavy_budget, dim=-1, largest=True) + if not self.heavy_budget == 0: + _, keep_topk = torch.topk(select_hh_scores, self.heavy_budget, dim=-1, largest=True) mask = mask.scatter(-1, keep_topk, 1) mask = mask.bool() @@ -260,7 +230,36 @@ def _update_hh_score(self, attn_score_cache, mean=False): self.hh_score = attn_score_cache - def clean_scores(self): self.idx = 0 + self.past_length = 0 self.hh_score = None + + @property + def past_length(self): + return self._past_length + + @past_length.setter + def past_length(self, value): + self._past_length = value + +class H2OConfig(dict): + def __init__( + self, + heavy_ratio: float = None, + recent_ratio: float = None, + heavy_budget: int = None, + recent_budget: int = None, + h2o_min_seqlen: int = -1, + real_drop: bool = True, + mean: bool = False, + local: bool = True + ): + self.heavy_ratio = heavy_ratio + self.recent_ratio = recent_ratio + self.heavy_budget = heavy_budget + self.recent_budget = recent_budget + self.h2o_min_seqlen = h2o_min_seqlen + self.real_drop = real_drop + self.mean = mean + self.local = local diff --git a/intel_extension_for_transformers/transformers/modeling/kv_cache_compression/models/modeling_bloom.py b/intel_extension_for_transformers/transformers/modeling/kv_cache_compression/models/modeling_bloom.py index d61b312931e..2494208d2d2 100644 --- a/intel_extension_for_transformers/transformers/modeling/kv_cache_compression/models/modeling_bloom.py +++ b/intel_extension_for_transformers/transformers/modeling/kv_cache_compression/models/modeling_bloom.py @@ -15,30 +15,38 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import math +import warnings from torch.nn import functional as F from typing import List, Optional, Tuple, Union import torch import torch.nn as nn - -from ..h2o import H2OKVCache +from torch.nn import CrossEntropyLoss from transformers.utils import logging +from transformers.models.bloom import ( + BloomConfig, + BloomModel, + BloomPreTrainedModel, + BLOOM_INPUTS_DOCSTRING, + _CHECKPOINT_FOR_DOC, + _CONFIG_FOR_DOC +) +from transformers.file_utils import ( + add_code_sample_docstrings, + add_start_docstrings_to_model_forward + ) +from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions + +from ..h2o import H2OKVCache, H2OConfig logger = logging.get_logger(__name__) class H2OBloomAttention(nn.Module): def __init__( self, - model, - config, - heavy_ratio, - recent_ratio, - h2o_min_seqlen=1024, - real_drop=False, - is_gen=False, - mean=False, - local=True + config: BloomConfig, + h2o_config: H2OConfig ): super().__init__() @@ -61,24 +69,24 @@ def __init__( self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim) self.beta = 1.0 - self.query_key_value = model.query_key_value - self.dense = model.dense - self.attention_dropout = model.attention_dropout + self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True) + self.dense = nn.Linear(self.hidden_size, self.hidden_size) + self.attention_dropout = nn.Dropout(config.attention_dropout) # for h2o if real_drop: real_drop = False logger.warning_once("BloomAttention not support for kv cache, usning simulation mode.") - self.real_drop = real_drop - self.is_gen = is_gen - self.mean = mean - self.local = local + self.h2o_config = h2o_config + self.is_gen = False + self.mean = h2o_config.mean + self.local = h2o_config.local - self.heavy_ratio = heavy_ratio - self.recent_ratio = recent_ratio - self.h2o_min_seqlen = h2o_min_seqlen + self.heavy_ratio = h2o_config.heavy_ratio + self.recent_ratio = h2o_config.recent_ratio + self.h2o_min_seqlen = h2o_config.h2o_min_seqlen - self.h2o_kv_cache = H2OKVCache(self.heavy_ratio, self.recent_ratio, real_drop, h2o_min_seqlen) + self.h2o_kv_cache = H2OKVCache(self.heavy_ratio, self.recent_ratio, h2o_config.h2o_min_seqlen) def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Split the last dimension into (num_heads, head_dim) without making any copies, results share same memory @@ -230,3 +238,174 @@ def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: out = F.dropout(x, p=prob, training=training) out = residual + out return out + +class H2OBloomForCausalLM(BloomPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__( + self, + config: BloomConfig, + h2o_config: H2OConfig + ): + super().__init__(config) + self.transformer = BloomModel(config) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + num_layers = len(self.transformer.h) + for layer_idx in range(num_layers): + self.transformer.h[layer_idx].self_attention = H2OBloomAttention( + config, + h2o_config + ) + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings: torch.Tensor): + self.lm_head = new_embeddings + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor, + past_key_values: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ) -> dict: + # only last tokens for 
input_ids if past is not None + if past_key_values is not None: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] + + # the cache may be in the stardard format (e.g. in contrastive search), convert to bloom's format if needed + if past_key_values[0][0].shape[0] == input_ids.shape[0]: + past_key_values = self._convert_to_bloom_cache(past_key_values) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs + + @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=CausalLMOutputWithCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **deprecated_arguments, + ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + """ + if deprecated_arguments.pop("position_ids", False) is not False: + # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None` + warnings.warn( + "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. 
You can safely ignore" + " passing `position_ids`.", + FutureWarning, + ) + if len(deprecated_arguments) > 0: + raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + + lm_logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(lm_logits.device) + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + batch_size, seq_length, vocab_size = shift_logits.shape + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct( + shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length) + ) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + def _reorder_cache( + self, past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor + ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct + beam_idx at every generation step. + + Output shares the same memory storage as `past`. + """ + standardized_past = self._convert_to_standard_cache(past, batch_size=len(beam_idx)) + + # Get a copy of `beam_idx` on all the devices where we need those indices. 
+ device_to_beam_idx = { + past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past + } + reordered_past = tuple( + ( + layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]), + layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]), + ) + for layer_past in standardized_past + ) + return self._convert_to_bloom_cache(reordered_past) \ No newline at end of file diff --git a/intel_extension_for_transformers/transformers/modeling/kv_cache_compression/models/modeling_gpt_neox.py b/intel_extension_for_transformers/transformers/modeling/kv_cache_compression/models/modeling_gpt_neox.py index 840095ec62a..3aa13dbca05 100644 --- a/intel_extension_for_transformers/transformers/modeling/kv_cache_compression/models/modeling_gpt_neox.py +++ b/intel_extension_for_transformers/transformers/modeling/kv_cache_compression/models/modeling_gpt_neox.py @@ -18,60 +18,81 @@ import torch import torch.nn as nn +from torch.nn import CrossEntropyLoss from transformers.utils import logging - -from ..h2o import H2OKVCache +from transformers.models.gpt_neox.modeling_gpt_neox import ( + _get_unpad_data, + GPTNeoXPreTrainedModel, + GPTNeoXModel, + CausalLMOutputWithPast, + GPT_NEOX_INPUTS_DOCSTRING, + GPT_NEOX_START_DOCSTRING, + _CONFIG_FOR_DOC + +) +from transformers.file_utils import ( + add_start_docstrings, + replace_return_docstrings, + add_start_docstrings_to_model_forward, + ) + +from ..h2o import H2OKVCache, H2OConfig logger = logging.get_logger(__name__) +from packaging import version +import transformers +if version.parse(transformers.__version__) > version.parse("4.33.0"): + from transformers.utils import is_flash_attn_greater_or_equal_2_10, is_flash_attn_2_available + if is_flash_attn_2_available(): + from flash_attn import ( + flash_attn_func, + flash_attn_varlen_func) # pylint: disable=E1101 + from flash_attn.bert_padding import ( + index_first_axis, + pad_input, + unpad_input) # pylint: disable=E1101 + class H2OGPTNeoXAttention(nn.Module): def __init__( self, - model, config, - heavy_ratio, - recent_ratio, - h2o_min_seqlen=1024, - real_drop=False, - is_gen=False, - mean=False, - local=True + h2o_config: H2OConfig, ): - super().__init__() self.config = config self.num_attention_heads = config.num_attention_heads self.hidden_size = config.hidden_size if self.hidden_size % self.num_attention_heads != 0: raise ValueError( - "The hidden size is not divisible by the number of attention heads! Make sure to update them" + "The hidden size is not divisble by the number of attention heads! 
Make sure to update them" ) self.head_size = self.hidden_size // self.num_attention_heads self.rotary_ndims = int(self.head_size * config.rotary_pct) - self.bias = model.bias + self._init_bias(config.max_position_embeddings) self.register_buffer("masked_bias", torch.tensor(-1e9), persistent=False) - self.rotary_emb = model.rotary_emb + self._init_rope() self.norm_factor = self.head_size**-0.5 - self.query_key_value = model.query_key_value - self.dense = model.dense - self.attention_dropout = model.attention_dropout + self.query_key_value = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=config.attention_bias) + self.dense = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias) + self.attention_dropout = nn.Dropout(config.attention_dropout) self.is_causal = True # for h2o if real_drop: real_drop = False logger.warning_once("GPTNeoXAttention not support for kv cache, usning simulation mode.") - self.real_drop = real_drop - self.is_gen = is_gen - self.mean = mean - self.local = local + self.h2o_config = h2o_config + self.is_gen = False + self.mean = h2o_config.mean + self.local = h2o_config.local - self.heavy_ratio = heavy_ratio - self.recent_ratio = recent_ratio - self.h2o_min_seqlen = h2o_min_seqlen + self.heavy_ratio = h2o_config.heavy_ratio + self.recent_ratio = h2o_config.recent_ratio + self.h2o_min_seqlen = h2o_config.h2o_min_seqlen - self.h2o_kv_cache = H2OKVCache(self.heavy_ratio, self.recent_ratio, real_drop, h2o_min_seqlen) + self.h2o_kv_cache = H2OKVCache(self.heavy_ratio, self.recent_ratio, h2o_config.h2o_min_seqlen) def forward( self, @@ -200,7 +221,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): attn_weights = nn.functional.softmax(attn_scores, dim=-1) attn_weights = attn_weights.to(value.dtype) - # get hh mask + # h2o from ..h2o import get_hh_mask mask = get_hh_mask(self.heavy_ratio, self.recent_ratio, attn_weights.detach().clone(), local=self.local) attn_weights[~mask] = torch.finfo(attn_weights.dtype).min @@ -214,6 +235,232 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): attn_output = torch.matmul(attn_weights, value) return attn_output, attn_weights +class H2OGPTNeoXFlashAttention2(H2OGPTNeoXAttention): + """ + GPTNeoX flash attention module. This module inherits from `GPTNeoXAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: torch.FloatTensor, + position_ids: torch.LongTensor, + head_mask: Optional[torch.FloatTensor] = None, + layer_past: Optional[Tuple[torch.Tensor]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ): + has_layer_past = layer_past is not None + + # Compute QKV + # Attention heads [batch, seq_len, hidden_size] + # --> [batch, seq_len, (np * 3 * head_size)] + qkv = self.query_key_value(hidden_states) + + # [batch, seq_len, (num_heads * 3 * head_size)] + # --> [batch, seq_len, num_heads, 3 * head_size] + new_qkv_shape = qkv.size()[:-1] + (self.num_attention_heads, 3 * self.head_size) + qkv = qkv.view(*new_qkv_shape) + + # [batch, seq_len, num_attention_heads, 3 * head_size] --> 3 [batch, num_attention_heads, seq_len, head_size] + query = qkv[..., : self.head_size].permute(0, 2, 1, 3) + key = qkv[..., self.head_size : 2 * self.head_size].permute(0, 2, 1, 3) + value = qkv[..., 2 * self.head_size :].permute(0, 2, 1, 3) + + query_length = query.shape[-2] + + # Compute rotary embeddings on rotary_ndims + query_rot = query[..., : self.rotary_ndims] + query_pass = query[..., self.rotary_ndims :] + key_rot = key[..., : self.rotary_ndims] + key_pass = key[..., self.rotary_ndims :] + + # Compute token offset for rotary embeddings (when decoding) + seq_len = key.shape[-2] + if has_layer_past: + seq_len += layer_past[0].shape[-2] + cos, sin = self.rotary_emb(value, seq_len=seq_len) + query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids) + query = torch.cat((query, query_pass), dim=-1) + key = torch.cat((key, key_pass), dim=-1) + + # Cache QKV values + if has_layer_past: + past_key = layer_past[0] + past_value = layer_past[1] + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) + present = (key, value) if use_cache else None + + # GPT-neo-X casts query and key in fp32 to apply rotary embedding in full precision + target_dtype = value.dtype + if query.dtype != target_dtype: + query = query.to(target_dtype) + if key.dtype != target_dtype: + key = key.to(target_dtype) + + # Permute to get the expected shape for Flash Attention + query = query.permute(0, 2, 1, 3) + key = key.permute(0, 2, 1, 3) + value = value.permute(0, 2, 1, 3) + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 / bfloat16 just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + input_dtype = query.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.query_key_value.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." 
+ ) + + query = query.to(target_dtype) + key = key.to(target_dtype) + value = value.to(target_dtype) + + attention_dropout = self.config.attention_dropout if self.training else 0.0 + + # Compute attention + attn_weights = self._flash_attention_forward( + query, key, value, attention_mask, query_length, dropout=attention_dropout, softmax_scale=self.norm_factor + ) + + # h2o + from ..h2o import get_hh_mask + mask = get_hh_mask(self.heavy_ratio, self.recent_ratio, attn_weights.detach().clone(), local=self.local) + attn_weights[~mask] = torch.finfo(attn_weights.dtype).min + + # Reshape outputs + attn_output = attn_weights.reshape( + attn_weights.shape[0], attn_weights.shape[1], self.num_attention_heads * self.head_size + ) + attn_output = self.dense(attn_output) + + outputs = (attn_output, present) + if output_attentions: + outputs += (attn_weights,) + + return outputs + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. 
+ causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input with num_heads->num_attention_heads + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_attention_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
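[Editor's note, not part of the patch] Both attention paths added in this file apply the same "heavy hitter + recent window" selection through `get_hh_mask` in simulation mode: the full KV cache is kept and non-selected positions are only masked, whereas the `real_drop` path (`H2OKVCache.__call__` in h2o.py) returns pruned key/value tensors instead. The snippet below is a minimal, self-contained sketch of that selection idea for readers of this patch; the helper name `toy_h2o_keep_mask`, its tensor shapes, and its defaults are illustrative assumptions, not the repository's `get_hh_mask` implementation.

```python
# Illustrative sketch only (not part of the patch): choose KV positions to keep
# from accumulated attention mass ("heavy hitters") plus a recent window.
import torch


def toy_h2o_keep_mask(attn_weights: torch.Tensor,
                      heavy_ratio: float = 0.2,
                      recent_ratio: float = 0.2) -> torch.Tensor:
    # attn_weights: softmax-normalized scores of shape [batch, heads, q_len, kv_len]
    bsz, num_heads, _, kv_len = attn_weights.shape
    heavy_budget = int(heavy_ratio * kv_len)
    recent_budget = int(recent_ratio * kv_len)

    # Accumulated attention mass received by each KV position.
    hh_score = attn_weights.sum(dim=-2)  # [batch, heads, kv_len]

    keep = torch.zeros_like(hh_score, dtype=torch.bool)
    if recent_budget > 0:
        keep[..., -recent_budget:] = True  # always keep the most recent tokens
    if heavy_budget > 0:
        # Heaviest positions outside the recent window.
        candidates = hh_score[..., : kv_len - recent_budget]
        topk = candidates.topk(min(heavy_budget, candidates.shape[-1]), dim=-1).indices
        keep.scatter_(-1, topk, torch.ones_like(topk, dtype=torch.bool))
    return keep  # [batch, heads, kv_len]


if __name__ == "__main__":
    scores = torch.softmax(torch.randn(1, 2, 8, 8), dim=-1)
    keep = toy_h2o_keep_mask(scores, heavy_ratio=0.25, recent_ratio=0.25)
    # Simulation mode: mask the dropped positions instead of evicting them.
    masked = scores.masked_fill(~keep.unsqueeze(-2), 0.0)
    print(keep[0, 0], masked.shape)
```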
@@ -246,3 +493,189 @@ def rotate_half(x):
     x1 = x[..., : x.shape[-1] // 2]
     x2 = x[..., x.shape[-1] // 2 :]
     return torch.cat((-x2, x1), dim=-1)
+
+
+@add_start_docstrings(
+    """GPTNeoX Model with a `language modeling` head on top for CLM fine-tuning.""", GPT_NEOX_START_DOCSTRING
+)
+class H2OGPTNeoXForCausalLM(GPTNeoXPreTrainedModel):
+    _tied_weights_keys = ["embed_out.weight"]
+
+    def __init__(self, config, h2o_config):
+        super().__init__(config)
+
+        self.gpt_neox = GPTNeoXModel(config)
+        self.embed_out = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        num_layers = len(self.gpt_neox.layers)
+        for layer_idx in range(num_layers):
+            module = self.gpt_neox.layers[layer_idx].attention
+            cls_name = module.__class__.__name__
+            if cls_name == "GPTNeoXFlashAttention2":
+                cls = H2OGPTNeoXFlashAttention2
+            else:
+                cls = H2OGPTNeoXAttention
+
+            self.gpt_neox.layers[layer_idx].attention = cls(
+                config,
+                h2o_config
+            )
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_output_embeddings(self):
+        return self.embed_out
+
+    def set_output_embeddings(self, new_embeddings):
+        self.embed_out = new_embeddings
+
+    @add_start_docstrings_to_model_forward(GPT_NEOX_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        r"""
+        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
+            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional tensors are
+            only required when the model is used as a decoder in a Sequence to Sequence model.
+
+            Contains pre-computed hidden-states (key and values in the self-attention blocks that can be used (see
+            `past_key_values` input) to speed up sequential decoding.
+
+            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
+            `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
+            ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+        use_cache (`bool`, *optional*):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`).
+ + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, GPTNeoXForCausalLM, GPTNeoXConfig + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b") + >>> config = GPTNeoXConfig.from_pretrained("EleutherAI/gpt-neox-20b") + >>> config.is_decoder = True + >>> model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/gpt-neox-20b", config=config) + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> prediction_logits = outputs.logits + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.gpt_neox( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + lm_logits = self.embed_out(hidden_states) + + lm_loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(lm_logits.device) + # we are doing next-token prediction; shift prediction scores and input ids by one + shift_logits = lm_logits[:, :-1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss() + lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)) + + if not return_dict: + output = (lm_logits,) + outputs[1:] + return ((lm_loss,) + output) if lm_loss is not None else output + + return CausalLMOutputWithPast( + loss=lm_loss, + logits=lm_logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + ): + input_shape = input_ids.shape + # cut decoder_input_ids if past is used + if past_key_values is not None: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_shape) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + model_inputs.update( + { + "attention_mask": attention_mask, + "past_key_values": past_key_values, + "position_ids": position_ids, + "use_cache": kwargs.get("use_cache"), + } + ) + + return model_inputs + + def _reorder_cache(self, past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + 
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2]) + + layer_past[2:], + ) + return reordered_past \ No newline at end of file diff --git a/intel_extension_for_transformers/transformers/modeling/kv_cache_compression/models/modeling_llama.py b/intel_extension_for_transformers/transformers/modeling/kv_cache_compression/models/modeling_llama.py index d9aaa7a5336..7cc47e31737 100644 --- a/intel_extension_for_transformers/transformers/modeling/kv_cache_compression/models/modeling_llama.py +++ b/intel_extension_for_transformers/transformers/modeling/kv_cache_compression/models/modeling_llama.py @@ -18,20 +18,43 @@ import math from typing import List, Optional, Tuple, Union +from functools import partial import torch import torch.nn as nn import torch.nn.functional as F - -from transformers.cache_utils import Cache -from transformers.utils import logging +from torch.nn import CrossEntropyLoss + +from transformers.cache_utils import Cache, StaticCache +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.utils import ( + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, + ) from transformers.models.llama.configuration_llama import LlamaConfig -from transformers.models.llama.modeling_llama import rotate_half, repeat_kv, _get_unpad_data - -from ..h2o import H2OKVCache +from transformers.models.llama.modeling_llama import ( + rotate_half, + repeat_kv, + _get_unpad_data, + LlamaRotaryEmbedding, + LlamaLinearScalingRotaryEmbedding, + LlamaDynamicNTKScalingRotaryEmbedding, + LlamaPreTrainedModel, + LlamaModel, + LLAMA_INPUTS_DOCSTRING +) +from transformers.modeling_outputs import ( + CausalLMOutputWithPast, +) + +from ..h2o import H2OKVCache, H2OConfig, generate logger = logging.get_logger(__name__) +_CONFIG_FOR_DOC = "LlamaConfig" from packaging import version import transformers @@ -78,19 +101,13 @@ class H2OLlamaAttention(nn.Module): def __init__( self, - model, config: LlamaConfig, - heavy_ratio, - recent_ratio, - h2o_min_seqlen=1024, - real_drop=False, - is_gen=False, - mean=False, - local=True + layer_idx: Optional[int] = None, + h2o_config: H2OConfig = None, ): super().__init__() self.config = config - self.layer_idx = model.layer_idx + self.layer_idx = layer_idx if self.layer_idx is None: logger.warning_once( f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " @@ -114,23 +131,52 @@ def __init__( f" and `num_heads`: {self.num_heads})." 
) - self.q_proj = model.q_proj - self.k_proj = model.k_proj - self.v_proj = model.v_proj - self.o_proj = model.o_proj - self.rotary_emb = model.rotary_emb + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias) + self._init_rope() # for h2o - self.is_gen = is_gen - self.real_drop = real_drop - self.mean = mean - self.local = local - - self.heavy_ratio = heavy_ratio - self.recent_ratio = recent_ratio - self.h2o_min_seqlen = h2o_min_seqlen - - self.h2o_kv_cache = H2OKVCache(self.heavy_ratio, self.recent_ratio, h2o_min_seqlen) + self.h2o_config = h2o_config + self.is_gen = False + self.real_drop = h2o_config.real_drop + self.mean = h2o_config.mean + self.local = h2o_config.local + + self.heavy_ratio = h2o_config.heavy_ratio + self.recent_ratio = h2o_config.recent_ratio + self.h2o_min_seqlen = h2o_config.h2o_min_seqlen + + self.h2o_kv_cache = H2OKVCache(self.heavy_ratio, self.recent_ratio, h2o_config.h2o_min_seqlen) + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = LlamaRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = LlamaLinearScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "dynamic": + self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + def forward( self, @@ -210,6 +256,7 @@ def forward( ) past_key_value.key_cache[self.layer_idx] = new_key_states past_key_value.value_cache[self.layer_idx] = new_value_states + self.h2o_kv_cache.past_length += attn_weights.size(-2) else: from ..h2o import get_hh_mask mask = get_hh_mask( @@ -353,6 +400,7 @@ def forward( ) past_key_value.key_cache[self.layer_idx] = new_key_states past_key_value.value_cache[self.layer_idx] = new_value_states + self.h2o_kv_cache.past_length += attn_weights.size(-2) attn_output = self._flash_attention_forward( query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate @@ -557,8 +605,12 @@ def forward( past_key_value.value_cache[self.layer_idx], mean=self.mean ) + # if self.layer_idx == 0: + # print(self.layer_idx, self.is_gen, query_states.shape, new_key_states.shape) + # print(query_states.shape, key_states.shape, value_states.shape) past_key_value.key_cache[self.layer_idx] = new_key_states past_key_value.value_cache[self.layer_idx] = new_value_states + self.h2o_kv_cache.past_length += attn_weights.size(-2) # In case we are not compiling, we may set `causal_mask` to None, # which is required to dispatch to SDPA's Flash Attention 2 backend, rather @@ -576,5 +628,319 @@ def forward( attn_output = attn_output.view(bsz, q_len, self.hidden_size) attn_output = self.o_proj(attn_output) + # if self.layer_idx == 0: + # print(attn_output.shape) return 
attn_output, None, past_key_value + +class H2OLlamaForCausalLM(LlamaPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__( + self, + config: LlamaConfig, + h2o_config: H2OConfig, + ): + super().__init__(config) + self.model = LlamaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + num_layers = len(self.model.layers) + for layer_idx in range(num_layers): + module = self.model.layers[layer_idx].self_attn + cls_name = module.__class__.__name__ + if not h2o_config.real_drop: + cls = H2OLlamaAttention + elif cls_name == "LlamaFlashAttention2": + cls = H2OLlamaFlashAttention2 + elif cls_name == "LlamaSdpaAttention": + cls = H2OLlamaSdpaAttention + else: + cls = H2OLlamaAttention + + self.model.layers[layer_idx].self_attn = cls( + config, + layer_idx, + h2o_config + ) + + # Initialize weights and apply final processing + self.post_init() + + self.ori_generate = self.generate + self.generate = partial(generate, self) + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + if self.config.pretraining_tp > 1: + lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) + logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] + logits = torch.cat(logits, dim=-1) + else: + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + use_cache=True, + **kwargs, + ): + past_length = 0 + if past_key_values is not None: + if isinstance(past_key_values, Cache): + past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() + max_cache_length = ( + torch.tensor(past_key_values.get_max_length(), device=input_ids.device) + if past_key_values.get_max_length() is not None + else None + ) + cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) + # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects + else: + cache_length = past_length = past_key_values[0][0].shape[2] + max_cache_length = None + + # cache_length = past_length = input_ids.shape[-1] - 1 + cache_length = past_length = self.model.layers[0].self_attn.h2o_kv_cache.past_length + max_cache_length = None + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input) + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. 
+ elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. + if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. + model_inputs = {"input_ids": input_ids.contiguous()} + + input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] + if cache_position is None: + cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) + elif use_cache: + cache_position = cache_position[-input_length:] + + model_inputs.update( + { + "position_ids": position_ids, + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + } + ) + return model_inputs + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static + # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. + # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using + # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 + + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. 
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + if attention_mask is not None and attention_mask.dim() == 4: + # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing + if attention_mask.max() != 0: + raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") + causal_mask = attention_mask + else: + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask \ No newline at end of file diff --git a/intel_extension_for_transformers/transformers/modeling/kv_cache_compression/models/modeling_mistral.py b/intel_extension_for_transformers/transformers/modeling/kv_cache_compression/models/modeling_mistral.py index 6a5e9a4a9e1..468db0d64a1 100644 --- a/intel_extension_for_transformers/transformers/modeling/kv_cache_compression/models/modeling_mistral.py +++ b/intel_extension_for_transformers/transformers/modeling/kv_cache_compression/models/modeling_mistral.py @@ -15,19 +15,52 @@ # See the License for the specific language governing permissions and # limitations under the License. 
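Note on the `past_length` bookkeeping used throughout this patch: once H2O physically evicts key/value entries, the length reported by the `Cache` object no longer equals the number of positions the model has actually processed, so `prepare_inputs_for_generation` reads `h2o_kv_cache.past_length` instead, and every attention forward advances that counter by `attn_weights.size(-2)`. The toy sketch below only illustrates that accounting; the class name, budgets, and `step` method are invented for the example and are not part of `H2OKVCache`.

```python
# Minimal sketch of the past_length accounting assumed above. Only the idea is
# taken from the patch (a counter advanced by the number of new query positions);
# the class and the numbers here are illustrative.
class ToyH2OKVCache:
    def __init__(self, heavy_budget: int, recent_budget: int):
        self.budget = heavy_budget + recent_budget  # max physical KV entries kept
        self.cached = 0                             # entries currently stored
        self.past_length = 0                        # total positions processed

    def step(self, new_tokens: int) -> None:
        self.past_length += new_tokens              # mirrors past_length += attn_weights.size(-2)
        self.cached = min(self.cached + new_tokens, self.budget)

cache = ToyH2OKVCache(heavy_budget=64, recent_budget=64)
cache.step(512)  # prefill
cache.step(1)    # one decode step
assert (cache.past_length, cache.cached) == (513, 128)
```

With such a counter, `prepare_inputs_for_generation` can still slice `input_ids[:, past_length:]` correctly even though the physical cache holds far fewer than `past_length` entries.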
import math +import inspect import logging from typing import List, Optional, Tuple, Union +from functools import partial import torch import torch.nn as nn +from torch.nn import CrossEntropyLoss from transformers.cache_utils import Cache -from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb, repeat_kv - -from ..h2o import H2OKVCache +from transformers.models.mistral.modeling_mistral import ( + apply_rotary_pos_emb, + repeat_kv, + _get_unpad_data, + MistralConfig, + MistralModel, + MistralPreTrainedModel, + MistralRotaryEmbedding, + CausalLMOutputWithPast, + MISTRAL_INPUTS_DOCSTRING, + _CONFIG_FOR_DOC, + + ) +from transformers.file_utils import ( + replace_return_docstrings, + add_start_docstrings_to_model_forward, + ) + +from ..h2o import H2OKVCache, H2OConfig, generate logger = logging.getLogger(__name__) +from packaging import version +import transformers +if version.parse(transformers.__version__) > version.parse("4.33.0"): + from transformers.utils import is_flash_attn_greater_or_equal_2_10, is_flash_attn_2_available + if is_flash_attn_2_available(): + from flash_attn import ( + flash_attn_func, + flash_attn_varlen_func) # pylint: disable=E1101 + from flash_attn.bert_padding import ( + index_first_axis, + pad_input, + unpad_input) # pylint: disable=E1101 + _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) + class H2OMistralAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper. @@ -36,21 +69,14 @@ class H2OMistralAttention(nn.Module): """ def __init__( - self, - model, - config, - heavy_ratio, - recent_ratio, - h2o_min_seqlen=1024, - real_drop=False, - is_gen=False, - mean=False, - local=True + self, config: MistralConfig, + layer_idx: Optional[int] = None, + h2o_config: H2OConfig = None, ): super().__init__() self.config = config - self.layer_idx = model.layer_idx - if self.layer_idx is None: + self.layer_idx = layer_idx + if layer_idx is None: logger.warning_once( f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " @@ -72,23 +98,29 @@ def __init__( f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" f" and `num_heads`: {self.num_heads})." 
) - self.q_proj = model.q_proj - self.k_proj = model.k_proj - self.v_proj = model.v_proj - self.o_proj = model.o_proj - self.rotary_emb = model.rotary_emb + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + self.rotary_emb = MistralRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) # for h2o - self.is_gen = is_gen - self.real_drop = real_drop - self.mean = mean - self.local = local + self.h2o_config = h2o_config + self.is_gen = False + self.real_drop = h2o_config.real_drop + self.mean = h2o_config.mean + self.local = h2o_config.local - self.heavy_ratio = heavy_ratio - self.recent_ratio = recent_ratio - self.h2o_min_seqlen = h2o_min_seqlen + self.heavy_ratio = h2o_config.heavy_ratio + self.recent_ratio = h2o_config.recent_ratio + self.h2o_min_seqlen = h2o_config.h2o_min_seqlen - self.h2o_kv_cache = H2OKVCache(self.heavy_ratio, self.recent_ratio, real_drop, h2o_min_seqlen) + self.h2o_kv_cache = H2OKVCache(self.heavy_ratio, self.recent_ratio, h2o_config.h2o_min_seqlen) def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() @@ -166,6 +198,7 @@ def forward( ) past_key_value.key_cache[self.layer_idx] = new_key_states past_key_value.value_cache[self.layer_idx] = new_value_states + self.h2o_kv_cache.past_length += attn_weights.size(-2) else: from ..h2o import get_hh_mask mask = get_hh_mask( @@ -197,6 +230,318 @@ def forward( return attn_output, attn_weights, past_key_value +class H2OMistralFlashAttention2(H2OMistralAttention): + """ + Mistral flash attention module. This module inherits from `MistralAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + ): + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + # Because the input can be padded, the absolute sequence length depends on the max position id. + rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1 + cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + use_sliding_windows = ( + _flash_supports_window_size + and getattr(self.config, "sliding_window", None) is not None + and kv_seq_len > self.config.sliding_window + ) + + if not _flash_supports_window_size: + logger.warning_once( + "The current flash attention version does not support sliding window attention, for a more memory efficient implementation" + " make sure to upgrade flash-attn library." 
+ ) + + if past_key_value is not None: + # Activate slicing cache only if the config has a value `sliding_windows` attribute + cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 + if ( + getattr(self.config, "sliding_window", None) is not None + and kv_seq_len > self.config.sliding_window + and cache_has_contents + ): + slicing_tokens = 1 - self.config.sliding_window + + past_key = past_key_value[self.layer_idx][0] + past_value = past_key_value[self.layer_idx][1] + + past_key = past_key[:, :, slicing_tokens:, :].contiguous() + past_value = past_value[:, :, slicing_tokens:, :].contiguous() + + if past_key.shape[-2] != self.config.sliding_window - 1: + raise ValueError( + f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" + f" {past_key.shape}" + ) + + if attention_mask is not None: + attention_mask = attention_mask[:, slicing_tokens:] + attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) + + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + dropout_rate = 0.0 if not self.training else self.attention_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." 
+ ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + # h2o + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + if past_key_value is not None: + if not self.is_gen: + self.h2o_kv_cache.clean_scores() + new_key_states, new_value_states = self.h2o_kv_cache( + attn_weights, + past_key_value.key_cache[self.layer_idx], + past_key_value.value_cache[self.layer_idx], + mean=self.mean + ) + past_key_value.key_cache[self.layer_idx] = new_key_states + past_key_value.value_cache[self.layer_idx] = new_value_states + self.h2o_kv_cache.past_length += attn_weights.size(-2) + + attn_output = self._flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + use_sliding_windows=use_sliding_windows, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def _flash_attention_forward( + self, + query_states, + key_states, + value_states, + attention_mask, + query_length, + dropout=0.0, + softmax_scale=None, + use_sliding_windows=False, + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + use_sliding_windows (`bool`, *optional*): + Whether to activate sliding window attention. + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. 
+ causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + if not use_sliding_windows: + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + else: + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + window_size=(self.config.sliding_window, self.config.sliding_window), + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + if not use_sliding_windows: + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + else: + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal, + window_size=(self.config.sliding_window, self.config.sliding_window), + ) + + return attn_output + + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape + + # On the first iteration we need to properly re-create the padding mask + # by slicing it on the proper place + if kv_seq_len != attention_mask.shape[-1]: + attention_mask_num_tokens = attention_mask.shape[-1] + attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :] + + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + + key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) + value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) + + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + class H2OMistralSdpaAttention(H2OMistralAttention): """ Mistral attention module using torch.nn.functional.scaled_dot_product_attention. 
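All three H2O attention variants (eager, flash, SDPA) funnel the raw attention weights into `self.h2o_kv_cache(attn_weights, key_cache, value_cache, mean=...)`. The `H2OKVCache` implementation lives in `..h2o` and is not shown in this diff; the sketch below is an assumed illustration of the selection it performs, keeping the top-`heavy_ratio` tokens by accumulated attention mass plus a `recent_ratio` window. The function name, shapes, and rounding are assumptions; only the two ratios and the `mean` switch come from the call sites above.

```python
import torch

def h2o_keep_mask(attn_weights: torch.Tensor,
                  heavy_ratio: float,
                  recent_ratio: float,
                  mean: bool = False) -> torch.Tensor:
    """Illustrative H2O selection: bool mask [bsz, heads, kv_len] of KV positions
    to keep (heavy hitters by accumulated attention mass + recent window)."""
    bsz, num_heads, _, kv_len = attn_weights.shape
    heavy_budget = int(heavy_ratio * kv_len)
    recent_budget = int(recent_ratio * kv_len)

    probs = attn_weights.softmax(dim=-1)
    scores = probs.mean(dim=-2) if mean else probs.sum(dim=-2)   # attention mass per key

    keep = torch.zeros(bsz, num_heads, kv_len, dtype=torch.bool, device=attn_weights.device)
    if heavy_budget > 0:
        heavy_idx = scores.topk(heavy_budget, dim=-1).indices
        keep.scatter_(-1, heavy_idx, torch.ones_like(heavy_idx, dtype=torch.bool))
    if recent_budget > 0:
        keep[..., -recent_budget:] = True   # always keep the most recent tokens
    return keep

# e.g. a 1024-token prompt with 10%/10% budgets keeps at most ~204 positions per head
mask = h2o_keep_mask(torch.rand(1, 8, 1024, 1024), heavy_ratio=0.1, recent_ratio=0.1)
print(mask[0, 0].sum().item())
```

Under `real_drop=True` the kept positions are gathered out of `past_key_value.key_cache[self.layer_idx]` / `value_cache[self.layer_idx]`; under `real_drop=False` the eager path instead applies a mask from `get_hh_mask` without shrinking the cache.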
@@ -295,6 +640,7 @@ def forward( ) past_key_value.key_cache[self.layer_idx] = new_key_states past_key_value.value_cache[self.layer_idx] = new_value_states + self.h2o_kv_cache.past_length += attn_weights.size(-2) if query_states.device.type == "cuda" and attention_mask is not None: query_states = query_states.contiguous() @@ -318,3 +664,216 @@ def forward( attn_output = self.o_proj(attn_output) return attn_output, None, past_key_value + +class MistralForCausalLM(MistralPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__( + self, + config, + h2o_config: H2OConfig): + super().__init__(config) + self.model = MistralModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + num_layers = len(self.model.layers) + for layer_idx in range(num_layers): + module = self.model.layers[layer_idx].self_attn + cls_name = module.__class__.__name__ + if not h2o_config.real_drop: + cls = H2OMistralAttention + elif cls_name == "MistralFlashAttention2": + cls = H2OMistralFlashAttention2 + elif cls_name == "LlamaSdpaAttention": + cls = H2OMistralSdpaAttention + else: + cls = H2OMistralAttention + + self.model.layers[layer_idx].self_attn = cls( + config, + layer_idx, + h2o_config + ) + + # Initialize weights and apply final processing + self.post_init() + + self.ori_generate = self.generate + self.generate = partial(generate, self) + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, MistralForCausalLM + + >>> model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1") + >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? 
Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Ensure tensors are on the same device + shift_labels = shift_labels.to(shift_logits.device) + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + ): + # Omit tokens covered by past_key_values + if past_key_values is not None: + if isinstance(past_key_values, Cache): + cache_length = past_key_values.get_seq_length() + past_length = past_key_values.seen_tokens + max_cache_length = past_key_values.get_max_length() + else: + cache_length = past_length = past_key_values[0][0].shape[2] + max_cache_length = None + + cache_length = past_length = self.model.layers[0].self_attn.h2o_kv_cache.past_length + max_cache_length = None + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as + # input) + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. 
+ if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past \ No newline at end of file diff --git a/intel_extension_for_transformers/transformers/modeling/kv_cache_compression/models/modeling_mixtral.py b/intel_extension_for_transformers/transformers/modeling/kv_cache_compression/models/modeling_mixtral.py index 3c0fd4f7e1e..16213841b45 100644 --- a/intel_extension_for_transformers/transformers/modeling/kv_cache_compression/models/modeling_mixtral.py +++ b/intel_extension_for_transformers/transformers/modeling/kv_cache_compression/models/modeling_mixtral.py @@ -18,15 +18,33 @@ import math import logging from typing import List, Optional, Tuple, Union +from functools import partial import torch import torch.nn as nn +from torch.nn import CrossEntropyLoss from transformers.cache_utils import Cache -from transformers.models.mixtral.modeling_mixtral import apply_rotary_pos_emb, repeat_kv, _get_unpad_data +from transformers.models.mixtral.modeling_mixtral import ( + apply_rotary_pos_emb, + repeat_kv, + _get_unpad_data, + load_balancing_loss_func, + MixtralConfig, + MixtralRotaryEmbedding, + MixtralModel, + MixtralPreTrainedModel, + MoeCausalLMOutputWithPast, + MIXTRAL_INPUTS_DOCSTRING, + _CONFIG_FOR_DOC, + ) +from transformers.file_utils import ( + replace_return_docstrings, + add_start_docstrings_to_model_forward, + ) from transformers.utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10 -from ..h2o import H2OKVCache +from ..h2o import H2OKVCache, H2OConfig, generate logger = logging.getLogger(__name__) @@ -44,21 +62,14 @@ class H2OMixtralAttention(nn.Module): """ def __init__( - self, - model, - config, - heavy_ratio, - recent_ratio, - h2o_min_seqlen=1024, - real_drop=False, - is_gen=False, - mean=False, - local=True + self, config: MixtralConfig, + layer_idx: Optional[int] = None, + h2o_config: H2OConfig = None, ): super().__init__() self.config = config - self.layer_idx = model.layer_idx - if self.layer_idx is None: + self.layer_idx = layer_idx + if layer_idx is None: logger.warning_once( f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " "lead to errors during the forward call if caching is used. 
Please make sure to provide a `layer_idx` " @@ -80,24 +91,29 @@ def __init__( f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" f" and `num_heads`: {self.num_heads})." ) - self.q_proj = model.q_proj - self.k_proj = model.k_proj - self.v_proj = model.v_proj - self.o_proj = model.o_proj - - self.rotary_emb = model.rotary_emb + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + self.rotary_emb = MixtralRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) # for h2o - self.is_gen = is_gen - self.real_drop = real_drop - self.mean = mean - self.local = local + self.h2o_config = h2o_config + self.is_gen = False + self.real_drop = h2o_config.real_drop + self.mean = h2o_config.mean + self.local = h2o_config.local - self.heavy_ratio = heavy_ratio - self.recent_ratio = recent_ratio - self.h2o_min_seqlen = h2o_min_seqlen + self.heavy_ratio = h2o_config.heavy_ratio + self.recent_ratio = h2o_config.recent_ratio + self.h2o_min_seqlen = h2o_config.h2o_min_seqlen - self.h2o_kv_cache = H2OKVCache(self.heavy_ratio, self.recent_ratio, real_drop, h2o_min_seqlen) + self.h2o_kv_cache = H2OKVCache(self.heavy_ratio, self.recent_ratio, h2o_config.h2o_min_seqlen) def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() @@ -176,6 +192,7 @@ def forward( ) past_key_value.key_cache[self.layer_idx] = new_key_states past_key_value.value_cache[self.layer_idx] = new_value_states + self.h2o_kv_cache.past_length += attn_weights.size(-2) else: from ..h2o import get_hh_mask mask = get_hh_mask( @@ -330,6 +347,7 @@ def forward( ) past_key_value.key_cache[self.layer_idx] = new_key_states past_key_value.value_cache[self.layer_idx] = new_value_states + self.h2o_kv_cache.past_length += attn_weights.size(-2) # In PEFT, usually we cast the layer norms in float32 for training stability reasons # therefore the input hidden states gets silently casted in float32. 
Hence, we need @@ -601,6 +619,7 @@ def forward( ) past_key_value.key_cache[self.layer_idx] = new_key_states past_key_value.value_cache[self.layer_idx] = new_value_states + self.h2o_kv_cache.past_length += attn_weights.size(-2) if attention_mask is not None: if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): @@ -630,4 +649,245 @@ def forward( attn_output = self.o_proj(attn_output) - return attn_output, None, past_key_value \ No newline at end of file + return attn_output, None, past_key_value + + +class MixtralForCausalLM(MixtralPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config, h2o_config: H2OConfig): + super().__init__(config) + self.model = MixtralModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.router_aux_loss_coef = config.router_aux_loss_coef + self.num_experts = config.num_local_experts + self.num_experts_per_tok = config.num_experts_per_tok + + num_layers = len(self.model.layers) + for layer_idx in range(num_layers): + module = self.model.layers[layer_idx].self_attn + cls_name = module.__class__.__name__ + if not h2o_config.real_drop: + cls = H2OMixtralAttention + elif cls_name == "MixtralFlashAttention2": + cls = H2OMixtralFlashAttention2 + elif cls_name == "MixtralSdpaAttention": + cls = H2OMixtralSdpaAttention + else: + cls = H2OMixtralAttention + + self.model.layers[layer_idx].self_attn = cls( + config, + layer_idx, + h2o_config + ) + + # Initialize weights and apply final processing + self.post_init() + + self.ori_generate = self.generate + self.generate = partial(generate, self) + + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + # Ignore copy + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, MoeCausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, MixtralForCausalLM + + >>> model = MixtralForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1") + >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" 
+ >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + aux_loss = None + if output_router_logits: + aux_loss = load_balancing_loss_func( + outputs.router_logits if return_dict else outputs[-1], + self.num_experts, + self.num_experts_per_tok, + attention_mask, + ) + if labels is not None: + loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device + + if not return_dict: + output = (logits,) + outputs[1:] + if output_router_logits: + output = (aux_loss,) + output + return (loss,) + output if loss is not None else output + + return MoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + output_router_logits=False, + **kwargs, + ): + # Omit tokens covered by past_key_values + if past_key_values is not None: + if isinstance(past_key_values, Cache): + cache_length = past_key_values.get_seq_length() + past_length = past_key_values.seen_tokens + max_cache_length = past_key_values.get_max_length() + else: + cache_length = past_length = past_key_values[0][0].shape[2] + max_cache_length = None + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusively passed as part of the cache (e.g. 
when passing input_embeds as + # input) + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. + if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + "output_router_logits": output_router_logits, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past \ No newline at end of file diff --git a/intel_extension_for_transformers/transformers/modeling/kv_cache_compression/models/modeling_opt.py b/intel_extension_for_transformers/transformers/modeling/kv_cache_compression/models/modeling_opt.py index ff5ef8ace53..0780c3414a5 100644 --- a/intel_extension_for_transformers/transformers/modeling/kv_cache_compression/models/modeling_opt.py +++ b/intel_extension_for_transformers/transformers/modeling/kv_cache_compression/models/modeling_opt.py @@ -16,29 +16,34 @@ # limitations under the License. 
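Each causal-LM wrapper in this patch stores the original `generate` under `self.ori_generate` and rebinds `self.generate = partial(generate, self)` with the `generate` helper imported from `..h2o`. That helper is not part of this diff; the sketch below is only one plausible shape for it, assuming it resets the per-layer H2O state before delegating to the stock HuggingFace `generate`. The reset logic is a guess for illustration, not the actual implementation.

```python
from functools import partial

def generate(model, *args, **kwargs):
    """Assumed sketch of the `generate` helper from ..h2o. The diff only shows
    that it wraps `ori_generate`; the per-layer reset below is an illustrative guess."""
    layers = getattr(model.model, "layers", None) or model.model.decoder.layers
    for layer in layers:
        attn = layer.self_attn
        attn.h2o_kv_cache.clean_scores()   # start generation with fresh H2O statistics
        attn.h2o_kv_cache.past_length = 0
        attn.is_gen = True                 # skip clean_scores() inside the decode-step forwards
    try:
        return model.ori_generate(*args, **kwargs)
    finally:
        for layer in layers:
            layer.self_attn.is_gen = False

# Rebinding as done in each __init__ above:
# self.ori_generate = self.generate
# self.generate = partial(generate, self)
```

Keeping `ori_generate` around also leaves the stock, un-instrumented generation path reachable.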
"""PyTorch OPT model.""" from typing import List, Optional, Tuple, Union +from functools import partial import torch import torch.nn as nn +from torch.nn import CrossEntropyLoss, MSELoss from ..h2o import get_hh_mask, H2OKVCache -from transformers.utils import is_flash_attn_greater_or_equal_2_10, logging +from transformers.utils import ( + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings + ) +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.models.opt.modeling_opt import OPTPreTrainedModel, OPTModel, OPTConfig + +from ..h2o import H2OKVCache, H2OConfig, generate logger = logging.get_logger(__name__) +_CONFIG_FOR_DOC = "OPTConfig" class H2OOPTAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper.""" def __init__( self, - model, - config, - heavy_ratio, - recent_ratio, - h2o_min_seqlen=1024, - real_drop=False, - is_gen=False, - mean=False, - local=True + config: OPTConfig, + is_decoder: bool = False, + h2o_config: H2OConfig =- None, ): super().__init__() self.config = config @@ -56,24 +61,25 @@ def __init__( f" and `num_heads`: {self.num_heads})." ) self.scaling = self.head_dim**-0.5 - self.is_decoder = model.is_decoder + self.is_decoder = is_decoder - self.k_proj = model.k_proj - self.v_proj = model.v_proj - self.q_proj = model.q_proj - self.out_proj = model.out_proj + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias) # for h2o - self.is_gen = is_gen - self.real_drop = real_drop - self.mean = mean - self.local = local + self.h2o_config = h2o_config + self.is_gen = False + self.real_drop = h2o_config.real_drop + self.mean = h2o_config.mean + self.local = h2o_config.local - self.heavy_ratio = heavy_ratio - self.recent_ratio = recent_ratio - self.h2o_min_seqlen = h2o_min_seqlen + self.heavy_ratio = h2o_config.heavy_ratio + self.recent_ratio = h2o_config.recent_ratio + self.h2o_min_seqlen = h2o_config.h2o_min_seqlen - self.h2o_kv_cache = H2OKVCache(self.heavy_ratio, self.recent_ratio, real_drop, h2o_min_seqlen) + self.h2o_kv_cache = H2OKVCache(self.heavy_ratio, self.recent_ratio, h2o_config.h2o_min_seqlen) def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() @@ -163,8 +169,8 @@ def forward( mean=self.mean ) past_key_value = (new_key_states, new_value_states) + self.h2o_kv_cache.past_length += attn_weights.size(-2) else: - from ..h2o import get_hh_mask mask = get_hh_mask( self.heavy_ratio, self.recent_ratio, @@ -367,23 +373,14 @@ def forward( if not self.is_gen: self.h2o_kv_cache.clean_scores() - if self.real_drop: - new_key_states, new_value_states = self.h2o_kv_cache( - attn_weights, - past_key_value[0], - past_key_value[1], - mean=self.mean - ) - past_key_value = (new_key_states, new_value_states) - else: - mask = self.h2o_kv_cache( - attn_weights, - past_key_value[0], - past_key_value[1], - mean=self.mean - ) - key_states = key_states * mask.unsqueeze(-1) - value_states = value_states * mask.unsqueeze(-1) + new_key_states, new_value_states = self.h2o_kv_cache( + attn_weights, + past_key_value[0], + past_key_value[1], + mean=self.mean + ) + past_key_value = (new_key_states, new_value_states) + self.h2o_kv_cache.past_length += 
attn_weights.size(-2) attn_output = self._flash_attention_forward( query_states, key_states, value_states, attention_mask, query_length, dropout=attn_dropout @@ -396,3 +393,229 @@ def forward( attn_weights_reshaped = None return attn_output, attn_weights_reshaped, past_key_value + + +class H2OOPTForCausalLM(OPTPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__( + self, + config, + h2o_config: H2OConfig, + ): + super().__init__(config) + self.model = OPTModel(config) + + # the lm_head weight is automatically tied to the embed tokens weight + self.lm_head = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False) + + num_layers = len(self.model.decoder.layers) + for layer_idx in range(num_layers): + module = self.model.layers[layer_idx].self_attn + cls_name = module.__class__.__name__ + if not h2o_config.real_drop: + cls = H2OOPTAttention + elif cls_name == "OptFlashAttention2": + cls = H2OOptFlashAttention2 + else: + cls = H2OOPTAttention + self.model.decoder.layers[layer_idx].self_attn = cls( + config, + is_decoder=True, + h2o_config=h2o_config + ) + # Initialize weights and apply final processing + self.post_init() + + self.ori_generate = self.generate + self.generate = partial(generate, self) + + def get_input_embeddings(self): + return self.model.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.model.decoder.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model.decoder = decoder + + def get_decoder(self): + return self.model.decoder + + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. 
+ + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional + tensors are only required when the model is used as a decoder in a Sequence to Sequence model. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, OPTForCausalLM + + >>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m") + >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious. I'm just a little bit of a weirdo." 
+ ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + logits = self.lm_head(outputs[0]).contiguous() + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + ): + if past_key_values is not None: + # past_length = past_key_values[0][0].shape[2] + past_length = self.model.decoder.layers[0].self_attn.h2o_kv_cache.past_length + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past \ No newline at end of file
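For a rough sense of the memory effect of `real_drop=True`: once the prompt is at least `h2o_min_seqlen` tokens long, the retained KV length per layer is approximately `(heavy_ratio + recent_ratio) * seq_len`. The helper below is only a back-of-the-envelope check; the exact rounding, the handling of overlap between heavy and recent tokens, and the threshold semantics inside `H2OKVCache` may differ. The 1024 default mirrors the previous constructor signatures removed in this diff.

```python
def retained_kv_len(seq_len: int, heavy_ratio: float, recent_ratio: float,
                    h2o_min_seqlen: int = 1024) -> int:
    """Approximate KV entries kept per layer under real_drop (rounding assumed)."""
    if seq_len < h2o_min_seqlen:
        return seq_len  # compression assumed to be skipped below the minimum length
    return int(heavy_ratio * seq_len) + int(recent_ratio * seq_len)

print(retained_kv_len(4096, heavy_ratio=0.1, recent_ratio=0.1))  # 818 of 4096 entries
```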