This repository has been archived by the owner on Oct 25, 2024. It is now read-only.

Commit 91c5f3c: new api
Signed-off-by: n1ck-guo <[email protected]>
n1ck-guo committed Jun 20, 2024
1 parent 76656c9 commit 91c5f3c
Showing 4 changed files with 220 additions and 178 deletions.
@@ -50,8 +50,8 @@
     CausalLMOutputWithPast,
 )
 
-from ..h2o import H2OKVCache, H2OConfig, generate
-
 from intel_extension_for_transformers.transformers.modeling.modeling_gaudi import adapt_transformers_to_gaudi
+from ..prune import PruneConfig, H2OConfig
+
 logger = logging.get_logger(__name__)
 _CONFIG_FOR_DOC = "LlamaConfig"
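
The import swap above is the entire surface change for callers: they now build a prune config instead of reaching into H2O internals. A minimal usage sketch, assuming `H2OConfig` keeps the knobs the deleted attention code read (`real_drop`, `heavy_ratio`, `recent_ratio`, `h2o_min_seqlen`); the constructor signature is an assumption, since `prune/h2o.py` is not part of this commit:

# Hypothetical wiring of the new API; field names are inferred from the
# attributes the old H2OLlamaAttention read, not confirmed by this diff.
h2o_config = H2OConfig(
    real_drop=True,      # physically evict KV entries instead of masking
    heavy_ratio=0.1,     # fraction of heavy-hitter tokens to keep
    recent_ratio=0.1,    # fraction of most recent tokens to keep
    h2o_min_seqlen=512,  # skip pruning for short sequences
)
model = LlamaForCausalLM(config, h2o_config)  # dispatches to H2OKVPruner below
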
@@ -96,14 +96,13 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed, k_embed
 
-class H2OLlamaAttention(nn.Module):
+class LlamaAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper."""
 
     def __init__(
         self,
         config: LlamaConfig,
         layer_idx: Optional[int] = None,
-        h2o_config: H2OConfig = None,
     ):
         super().__init__()
         self.config = config
@@ -137,18 +136,14 @@ def __init__(
         self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
         self._init_rope()
 
-        # for h2o
-        self.h2o_config = h2o_config
-        self.is_gen = False
-        self.real_drop = h2o_config.real_drop
-        self.mean = h2o_config.mean
-        self.local = h2o_config.local
-
-        self.heavy_ratio = h2o_config.heavy_ratio
-        self.recent_ratio = h2o_config.recent_ratio
-        self.h2o_min_seqlen = h2o_config.h2o_min_seqlen
-
-        self.h2o_kv_cache = H2OKVCache(self.heavy_ratio, self.recent_ratio, h2o_config.h2o_min_seqlen)
+        self._init_func = []
+
+    def register_init_func(self, func):
+        self._init_func.append(func)
+
+    def post_init(self):
+        for func in self._init_func:
+            func(self)
 
     def _init_rope(self):
         if self.config.rope_scaling is None:
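
The hard-coded H2O attributes above give way to a small hook mechanism: the owner registers initializer callbacks and fires them through `post_init()`, so any pruner can attach per-layer state without the attention class knowing its details. A self-contained sketch of the pattern (`token_scores` is an illustrative attribute, not the real API):

class HookedModule:
    def __init__(self):
        self._init_func = []

    def register_init_func(self, func):
        self._init_func.append(func)

    def post_init(self):
        # run every registered initializer against this module
        for func in self._init_func:
            func(self)

def self_attn_init(module):
    module.token_scores = None  # hypothetical pruner-specific state

m = HookedModule()
m.register_init_func(self_attn_init)
m.post_init()  # m now carries the pruner's per-layer state
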
@@ -177,7 +172,6 @@ def _init_rope(self):
         else:
             raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
 
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -239,33 +233,46 @@ def forward(
 
         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 
+        causal_mask = None
         if attention_mask is not None:  # no matter the length, we just slice it
             causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
             attn_weights = attn_weights + causal_mask
 
-        # H2O
-        if past_key_value is not None:
-            if not self.is_gen:
-                self.h2o_kv_cache.clean_scores()
-            if self.real_drop:
-                new_key_states, new_value_states = self.h2o_kv_cache(
-                    attn_weights.detach().clone(),
-                    past_key_value.key_cache[self.layer_idx],
-                    past_key_value.value_cache[self.layer_idx],
-                    mean=self.mean
+        # pruning kv cache
+        if self.pruner.real_drop:
+            if self.layer_idx == 0:
+                self.pruner.past_length += query_states.size(-2)
+            if past_key_value is not None:
+                new_key_states, new_value_states = self.pruner.prune(
+                    self,
+                    query_states,
+                    key_states,
+                    value_states,
+                    causal_mask=causal_mask
                 )
+                # reshape kv cache
+                if self.num_key_value_groups > 1:
+                    n_rep = self.num_key_value_groups
+                    drop_mask = torch.tensor(
+                        [True if i % n_rep == 0 else False for i in range(0, new_key_states.size(1))]
+                    ).repeat(new_key_states.size(0), 1).to(new_key_states.device)
+                    new_key_states = new_key_states[drop_mask].view(
+                        new_key_states.shape[0],
+                        int(new_key_states.shape[1] / n_rep),
+                        -1,
+                        new_key_states.shape[-1])
+                    new_value_states = new_value_states[drop_mask].view(
+                        new_value_states.shape[0],
+                        int(new_value_states.shape[1] / n_rep),
+                        -1,
+                        new_value_states.shape[-1])
 
                 past_key_value.key_cache[self.layer_idx] = new_key_states
                 past_key_value.value_cache[self.layer_idx] = new_value_states
-            self.h2o_kv_cache.past_length += attn_weights.size(-2)
-            else:
-                from ..h2o import get_hh_mask
-                mask = get_hh_mask(
-                    self.heavy_ratio,
-                    self.recent_ratio,
-                    attn_weights.detach().clone(),
-                    local=self.local)
-                attn_weights[~mask] = torch.finfo(attn_weights.dtype).min
-
+        else:  # simulate pruning to calculate acc
+            mask = self.pruner.get_mask(self, query_states, key_states, value_states,
+                                        causal_mask=causal_mask)
+            attn_weights[~mask] = torch.finfo(attn_weights.dtype).min
         # upcast attention to fp32
         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
         attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
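
The reshape block exists because, under grouped-query attention, `prune()` works on key/value states already repeated to `num_attention_heads`; keeping every `n_rep`-th head row collapses the result back to `num_key_value_heads` before it is written into the cache. A small check showing that the boolean `drop_mask` is equivalent to a strided slice (toy shapes, assuming the usual `(batch, heads, seq, head_dim)` layout):

import torch

bsz, num_heads, seq_len, head_dim, n_rep = 2, 8, 5, 4, 4
kv = torch.randn(bsz, num_heads, seq_len, head_dim)  # KV repeated n_rep times

drop_mask = torch.tensor(
    [i % n_rep == 0 for i in range(num_heads)]
).repeat(bsz, 1)
kept = kv[drop_mask].view(bsz, num_heads // n_rep, -1, head_dim)

# keeping head 0 of each group is just a stride over the head dimension
assert torch.equal(kept, kv[:, ::n_rep])
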
@@ -294,7 +301,7 @@ def forward(
 
         return attn_output, attn_weights, past_key_value
 
-class H2OLlamaFlashAttention2(H2OLlamaAttention):
+class LlamaFlashAttention2(LlamaAttention):
     """Llama flash attention module.
 
     This module inherits from `LlamaAttention` as the weights of the module stay
@@ -389,18 +396,20 @@ def forward(
 
 
-        # h2o
+        # pruning kv cache
+        if self.layer_idx == 0:
+            self.pruner.past_length += query_states.size(-2)
         if past_key_value is not None:
-            if not self.is_gen:
-                self.h2o_kv_cache.clean_scores()
-            new_key_states, new_value_states = self.h2o_kv_cache(
-                attn_weights,
-                past_key_value.key_cache[self.layer_idx],
-                past_key_value.value_cache[self.layer_idx],
-                mean=self.mean
+            new_key_states, new_value_states = self.pruner.prune(
+                self,
+                query_states,
+                key_states,
+                value_states,
+                causal_mask=causal_mask
             )
 
             past_key_value.key_cache[self.layer_idx] = new_key_states
             past_key_value.value_cache[self.layer_idx] = new_value_states
-            self.h2o_kv_cache.past_length += attn_weights.size(-2)
 
         attn_output = self._flash_attention_forward(
             query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
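
Passing `query_states`/`key_states` instead of a precomputed `attn_weights` matters most in this flash path: flash attention never materializes the full attention matrix, so a score-based pruner has to recompute whatever it needs itself. A sketch of such a recomputation, assuming the usual `(batch, heads, seq, head_dim)` layout; this is not the code of `H2OKVPruner.prune`, which this commit does not show:

import math
import torch

def attention_scores(query_states, key_states, causal_mask=None):
    # recompute attention probabilities only for scoring, then discard them
    w = torch.matmul(query_states, key_states.transpose(2, 3))
    w = w / math.sqrt(query_states.size(-1))
    if causal_mask is not None:
        w = w + causal_mask
    return torch.softmax(w, dim=-1, dtype=torch.float32)
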
@@ -510,7 +519,7 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query
             (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
         )
 
-class H2OLlamaSdpaAttention(H2OLlamaAttention):
+class LlamaSdpaAttention(LlamaAttention):
     """Llama attention module using torch.nn.functional.scaled_dot_product_attention.
 
     This module inherits from
@@ -589,28 +598,36 @@ def forward(
         key_states = key_states.contiguous()
         value_states = value_states.contiguous()
 
-        # h2o
-        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-
-        if attention_mask is not None:  # no matter the length, we just slice it
-            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
-            attn_weights = attn_weights + causal_mask
-
+        # pruning kv cache
+        if self.layer_idx == 0:
+            self.pruner.past_length += query_states.size(-2)
         if past_key_value is not None:
-            if not self.is_gen:
-                self.h2o_kv_cache.clean_scores()
-            new_key_states, new_value_states = self.h2o_kv_cache(
-                attn_weights,
-                past_key_value.key_cache[self.layer_idx],
-                past_key_value.value_cache[self.layer_idx],
-                mean=self.mean
+            new_key_states, new_value_states = self.pruner.prune(
+                self,
+                query_states,
+                key_states,
+                value_states,
+                causal_mask=causal_mask
             )
-            # if self.layer_idx == 0:
-            #     print(self.layer_idx, self.is_gen, query_states.shape, new_key_states.shape)
-            #     print(query_states.shape, key_states.shape, value_states.shape)
+            # reshape kv cache
+            if self.num_key_value_groups > 1:
+                n_rep = self.num_key_value_groups
+                drop_mask = torch.tensor(
+                    [True if i % n_rep == 0 else False for i in range(0, new_key_states.size(1))]
+                ).repeat(new_key_states.size(0), 1).to(new_key_states.device)
+                new_key_states = new_key_states[drop_mask].view(
+                    new_key_states.shape[0],
+                    int(new_key_states.shape[1] / n_rep),
+                    -1,
+                    new_key_states.shape[-1])
+                new_value_states = new_value_states[drop_mask].view(
+                    new_value_states.shape[0],
+                    int(new_value_states.shape[1] / n_rep),
+                    -1,
+                    new_value_states.shape[-1])
 
             past_key_value.key_cache[self.layer_idx] = new_key_states
             past_key_value.value_cache[self.layer_idx] = new_value_states
-            self.h2o_kv_cache.past_length += attn_weights.size(-2)
 
         # In case we are not compiling, we may set `causal_mask` to None,
         # which is required to dispatch to SDPA's Flash Attention 2 backend, rather
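
For context on the truncated comment above: `torch.nn.functional.scaled_dot_product_attention` can only pick its flash kernel when no explicit mask tensor is passed, which is why this code prefers `attn_mask=None` plus `is_causal` when possible. A sketch of the standard call shape:

import torch.nn.functional as F

def sdpa(query_states, key_states, value_states, causal_mask, dropout_p=0.0):
    # attn_mask=None is what lets SDPA choose its flash backend; is_causal
    # then recreates the causal structure (only valid when q_len > 1)
    return F.scaled_dot_product_attention(
        query_states, key_states, value_states,
        attn_mask=causal_mask,
        dropout_p=dropout_p,
        is_causal=causal_mask is None and query_states.shape[2] > 1,
    )
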
@@ -633,43 +650,58 @@ def forward(
 
         return attn_output, None, past_key_value
 
-class H2OLlamaForCausalLM(LlamaPreTrainedModel):
+class LlamaForCausalLM(LlamaPreTrainedModel):
     _tied_weights_keys = ["lm_head.weight"]
 
     def __init__(
         self,
         config: LlamaConfig,
-        h2o_config: H2OConfig,
+        prune_config: PruneConfig,
     ):
         super().__init__(config)
         self.model = LlamaModel(config)
         self.vocab_size = config.vocab_size
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
+        if isinstance(prune_config, H2OConfig):
+            from ..prune import H2OKVPruner
+            self.pruner = H2OKVPruner(prune_config)
+        else:
+            from ..prune import KVPruner
+            self.pruner = KVPruner(prune_config)
+
         num_layers = len(self.model.layers)
         for layer_idx in range(num_layers):
            module = self.model.layers[layer_idx].self_attn
            cls_name = module.__class__.__name__
-            if not h2o_config.real_drop:
-                cls = H2OLlamaAttention
+            if not prune_config.real_drop:
+                cls = LlamaAttention
             elif cls_name == "LlamaFlashAttention2":
-                cls = H2OLlamaFlashAttention2
+                cls = LlamaFlashAttention2
             elif cls_name == "LlamaSdpaAttention":
-                cls = H2OLlamaSdpaAttention
+                cls = LlamaSdpaAttention
             else:
-                cls = H2OLlamaAttention
-
+                cls = LlamaAttention
             self.model.layers[layer_idx].self_attn = cls(
                 config,
-                layer_idx,
-                h2o_config
+                layer_idx
             )
+            self.model.layers[layer_idx].self_attn.register_init_func(self.pruner.self_attn_init)
+            self.model.layers[layer_idx].self_attn.post_init()
+
+            self.model.layers[layer_idx].self_attn.pruner = self.pruner
 
         # Initialize weights and apply final processing
         self.post_init()
 
+        def _generate(**kwargs):
+            self.pruner.before_generate(self, **kwargs)
+            result = self.ori_generate(**kwargs)
+            self.pruner.after_generate(self, **kwargs)
+            return result
+
         self.ori_generate = self.generate
-        self.generate = partial(generate, self)
+        self.generate = _generate
 
     def get_input_embeddings(self):
         return self.model.embed_tokens
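
The `_generate` closure replaces the old `partial(generate, self)` binding: the pruner now gets a chance to reset its bookkeeping before decoding (`before_generate` zeroes `past_length`, see the base class later in this commit) and to clean up afterwards. A standalone sketch of the same hook pattern; the `*args` and `try/finally` shown here are additions for robustness, not the commit's code:

class TinyModel:
    def __init__(self, pruner):
        self.pruner = pruner
        self.ori_generate = self.generate

        def _generate(*args, **kwargs):
            self.pruner.before_generate(self, **kwargs)
            try:
                # also forwards positional arguments, e.g. input_ids
                return self.ori_generate(*args, **kwargs)
            finally:
                self.pruner.after_generate(self, **kwargs)

        self.generate = _generate

    def generate(self, *args, **kwargs):
        return ["tok"]  # stand-in for real decoding
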
@@ -810,7 +842,7 @@ def prepare_inputs_for_generation(
             max_cache_length = None
 
         # cache_length = past_length = input_ids.shape[-1] - 1
-        cache_length = past_length = self.model.layers[0].self_attn.h2o_kv_cache.past_length
+        cache_length = past_length = self.pruner.past_length
         max_cache_length = None
         # Keep only the unprocessed tokens:
         # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
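
The switch to `self.pruner.past_length` is forced by real dropping: once cache entries are evicted, the cache length no longer equals the number of tokens actually processed, so the old per-layer lookup had nothing reliable to read. In numbers:

# Illustrative figures only.
tokens_seen = 1000       # tracked by pruner.past_length during decoding
kept_in_cache = 200      # KV entries that survived eviction
assert tokens_seen != kept_in_cache  # cache length would mislead generation
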
@@ -0,0 +1,2 @@
+from .base import PruneConfig, KVPruner
+from .h2o import H2OConfig, H2OKVPruner
@@ -0,0 +1,30 @@
+class PruneConfig(dict):
+    def __init__(self, real_drop=True):
+        self.real_drop = real_drop
+
+class KVPruner:
+    def __init__(self, prune_config) -> None:
+        self._past_length = 0
+
+    def self_attn_init(self, module):
+        pass
+
+    def prune(self, module, query_states, key_states, value_states, **kwargs):
+        pass
+
+    def before_generate(self, model, **kwargs):
+        self.past_length = 0
+
+    def after_generate(self, model, **kwargs):
+        pass
+
+    def get_mask(self, model, **kwargs):
+        pass
+
+    @property
+    def past_length(self):
+        return self._past_length
+
+    @past_length.setter
+    def past_length(self, value):
+        self._past_length = value
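
`KVPruner` is deliberately skeletal: `prune()` must return the new key/value tensors for real dropping, and `get_mask()` a boolean mask for simulation. A hypothetical top-k subclass to make the contract concrete; `H2OKVPruner` itself lives in `prune/h2o.py`, which this commit does not include:

import math
import torch

class ToyTopKPruner(KVPruner):
    """Keep the KV positions with the largest accumulated attention mass."""

    def __init__(self, prune_config, keep_ratio=0.2):
        super().__init__(prune_config)
        self.keep_ratio = keep_ratio  # illustrative knob, not in PruneConfig

    def prune(self, module, query_states, key_states, value_states,
              causal_mask=None, **kwargs):
        w = torch.matmul(query_states, key_states.transpose(2, 3))
        w = w / math.sqrt(query_states.size(-1))
        if causal_mask is not None:
            w = w + causal_mask                    # respect causal masking
        scores = w.softmax(dim=-1).sum(dim=-2)     # (bsz, heads, kv_len)
        k = max(1, int(scores.size(-1) * self.keep_ratio))
        idx = scores.topk(k, dim=-1).indices.sort(dim=-1).values
        idx = idx.unsqueeze(-1).expand(-1, -1, -1, key_states.size(-1))
        return key_states.gather(2, idx), value_states.gather(2, idx)
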