This repository has been archived by the owner on Oct 25, 2024. It is now read-only.

Commit

refactor code
Signed-off-by: n1ck-guo <[email protected]>
n1ck-guo committed Jun 18, 2024
1 parent d892e74 commit 124bb72
Showing 9 changed files with 2,311 additions and 259 deletions.
79 changes: 53 additions & 26 deletions examples/huggingface/pytorch/text-generation/h2o/run_generation.py
@@ -1,6 +1,5 @@
import argparse
import sys
sys.path.insert(0, '/home/hengguo/code/intel-extension-for-transformers')
import time
import json
import torch
@@ -66,7 +65,6 @@
# transformers version >= 4.32.0 contained the mpt modeling definition.
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/mpt/modeling_mpt.py
# 4.31.0 for ipex.optimize_transformers
check_min_version("4.35.2")
# get model config
if args.peft_model_id:
from peft import PeftConfig
@@ -111,24 +109,26 @@
device = args.device
else:
device = f"cuda:{args.device}"
user_model = AutoModelForCausalLM.from_pretrained(args.model, trust_remote_code=args.trust_remote_code)
user_model.to(device)

# get optimized model
if args.h2o:
print('Enable Small Cache Size')
# checkpoint = copy.deepcopy(model.state_dict())
# model = ENABLE_Heavy_Hitter_FUNCTIONS[args.model_type](model, config)
from intel_extension_for_transformers.transformers.modeling.kv_cache_compression import convert_model
user_model = convert_model(
user_model,
from intel_extension_for_transformers.transformers.modeling.kv_cache_compression import H2OConfig, H2OLlamaForCausalLM
h2o_config = H2OConfig(
heavy_ratio=args.heavy_ratio,
recent_ratio=args.recent_ratio,
h2o_min_seqlen=args.h2o_min_seqlen,
real_drop=args.real_drop,
is_gen=args.is_gen
)
mean=False,
)
user_model = H2OLlamaForCausalLM.from_pretrained(
args.model,
h2o_config=h2o_config,
trust_remote_code=args.trust_remote_code)
print("converted model: ", user_model)
else:
user_model = AutoModelForCausalLM.from_pretrained(args.model, trust_remote_code=args.trust_remote_code)
user_model.to(device)

# save model
# if args.output_dir is not None:
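The hunk above switches the H2O example from the removed convert_model helper to a config-driven API: an H2OConfig is built from the command-line arguments and passed to H2OLlamaForCausalLM.from_pretrained. A minimal sketch of the new path follows; the checkpoint name, ratio values, and prompt are illustrative assumptions, not values taken from this commit.

# Sketch only: config-driven H2O KV-cache compression as introduced above.
from transformers import AutoTokenizer
from intel_extension_for_transformers.transformers.modeling.kv_cache_compression import (
    H2OConfig,
    H2OLlamaForCausalLM,
)

h2o_config = H2OConfig(
    heavy_ratio=0.1,      # fraction of the sequence budgeted for heavy-hitter tokens
    recent_ratio=0.1,     # fraction budgeted for the most recent tokens
    h2o_min_seqlen=1024,  # leave sequences at or below this length uncompressed
    real_drop=True,       # use the real-drop variant rather than the simulated one
    mean=False,           # accumulate heavy-hitter scores by sum, not mean (assumed)
)

model_name = "meta-llama/Llama-2-7b-hf"  # assumed checkpoint, for illustration
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = H2OLlamaForCausalLM.from_pretrained(model_name, h2o_config=h2o_config)

prompt = "H2O keeps heavy-hitter and recent tokens in the KV cache."
inputs = tokenizer(prompt, return_tensors="pt")
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=32)[0]))
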
@@ -194,18 +194,45 @@
if args.accuracy:
user_model = (user_model.eval() if (not (args.int8 or args.int8_bf16_mixed) and hasattr(user_model, "eval")) \
else user_model)
from intel_extension_for_transformers.transformers.llm.evaluation.lm_eval import evaluate, LMEvalParser
model_args="pretrained="+args.model+",trust_remote_code="+str(args.trust_remote_code)
args.tasks = ",".join(args.tasks)
tokenizer.pad_token = tokenizer.eos_token
eval_args = LMEvalParser(model = "hf",
user_model=user_model,
tokenizer=tokenizer,
model_args=model_args,
tasks = args.tasks,
device = device,
num_fewshot=args.num_fewshot,
output_path=args.save_accuracy_path,
batch_size = args.batch_size)
print("using device:", device)
results = evaluate(eval_args)
# from intel_extension_for_transformers.transformers.llm.evaluation.lm_eval import evaluate, LMEvalParser
# model_args="pretrained="+args.model+",trust_remote_code="+str(args.trust_remote_code)
# args.tasks = ",".join(args.tasks)
# tokenizer.pad_token = tokenizer.eos_token
# eval_args = LMEvalParser(model = "hf",
# user_model=user_model,
# tokenizer=tokenizer,
# model_args=model_args,
# tasks = args.tasks,
# device = device,
# num_fewshot=args.num_fewshot,
# output_path=args.save_accuracy_path,
# batch_size = args.batch_size)
# print("using device:", device)
# results = evaluate(eval_args)


# original lm_eval
from lm_eval.evaluator import simple_evaluate
from lm_eval.tasks import TaskManager
import lm_eval

verbosity = 'INFO'
task_manager = TaskManager(verbosity)
limit = None
cache_requests = False
lm = lm_eval.api.registry.get_model("hf")(
pretrained=user_model,
batch_size=args.batch_size,
max_batch_size=None,
)
model_args="pretrained="+ args.model+ ",tokenizer="+ args.model + ",dtype=float32"
use_cache = None
results = simple_evaluate(
model=lm,
model_args=model_args,
tasks=args.tasks,
num_fewshot=args.num_fewshot,
device=device
)
import pprint
pprint.pprint(results["results"])
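In the hunk above, the ITREX lm_eval wrapper (now kept only as commented-out code) is replaced by the upstream lm-evaluation-harness: the already-loaded user_model is wrapped in the registered "hf" model class and handed to simple_evaluate. A stripped-down sketch of the same pattern follows; the task name and batch size are assumptions. When an LM instance is passed this way, the harness typically does not reload the model from model_args.

# Sketch: evaluating the in-memory model with the upstream lm-evaluation-harness.
import lm_eval
from lm_eval.evaluator import simple_evaluate

lm = lm_eval.api.registry.get_model("hf")(pretrained=user_model, batch_size=8)  # assumed batch size
results = simple_evaluate(model=lm, tasks=["lambada_openai"], num_fewshot=0, device="cpu")  # assumed task
print(results["results"]["lambada_openai"])
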
@@ -15,4 +15,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .h2o import convert_model
from .h2o import H2OConfig
from .models.modeling_llama import H2OLlamaForCausalLM
from .models.modeling_bloom import H2OBloomForCausalLM
from .models.modeling_gpt_neox import H2OGPTNeoXForCausalLM
from .models.modeling_opt import H2OOPTForCausalLM
from .models.modeling_mistral import H2OMistralForCausalLM
from .models.modeling_mixtral import H2OMixtralForCausalLM
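The exports above replace the single convert_model entry point with H2OConfig plus one H2O*ForCausalLM class per supported architecture (Llama, BLOOM, GPT-NeoX, OPT, Mistral, Mixtral), presumably in the kv_cache_compression package __init__. Below is a hedged sketch of dispatching on a checkpoint's model_type; the mapping, the helper name, and the assumption that every class accepts h2o_config in from_pretrained the way the Llama variant does are all illustrative.

# Sketch: pick the H2O wrapper class by architecture, based on the exports above.
from intel_extension_for_transformers.transformers.modeling.kv_cache_compression import (
    H2OConfig,
    H2OLlamaForCausalLM,
    H2OOPTForCausalLM,
    H2OMistralForCausalLM,
)

H2O_CLASSES = {  # illustrative subset of the exported classes
    "llama": H2OLlamaForCausalLM,
    "opt": H2OOPTForCausalLM,
    "mistral": H2OMistralForCausalLM,
}

def load_h2o_model(name: str, model_type: str, h2o_config: H2OConfig):
    """Load the H2O variant matching model_type (hypothetical helper)."""
    return H2O_CLASSES[model_type].from_pretrained(name, h2o_config=h2o_config)
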
@@ -67,61 +67,22 @@ def get_module(model, op_name):

def clean_cache(model):
for _, module in model.named_modules():
if "H2O" in module.__class__.__name__:
if "Attention" in module.__class__.__name__:
module.h2o_kv_cache.clean_scores()

def generate(model, **kwargs):
max_length = kwargs['max_new_tokens'] if kwargs.get('max_new_tokens') else kwargs['max_length']
for _, module in model.named_modules():
if "H2O" in module.__class__.__name__:
if "Attention" in module.__class__.__name__:
module.is_gen = True
if module.h2o_kv_cache.heavy_budget is None:
module.h2o_kv_cache.heavy_budget = int(max_length * module.h2o_kv_cache.heavy_ratio)
if module.h2o_kv_cache.recent_budget is None:
module.h2o_kv_cache.recent_budget = int(max_length * module.h2o_kv_cache.recent_ratio)
result = model.ori_generate(**kwargs)
clean_cache(model)
return result

def convert_model(
model,
heavy_ratio,
recent_ratio,
h2o_min_seqlen=1024,
real_drop=True,
is_gen=False,
mean=False,
local=True
):
model_type = model.config.model_type
device = model.device
atten_layers = []
for name, module in model.named_modules():
if "Attention" in module.__class__.__name__:
atten_layers.append(name)

for layer_name in atten_layers:
module = get_module(model, layer_name)
cls_name = "H2O" + module.__class__.__name__ if real_drop else SIM_CLS_MAPPING[model_type]
h2o_cls = getattr(
importlib.import_module(
f".models.modeling_{model_type}",
"intel_extension_for_transformers.transformers.modeling.kv_cache_compression"
),
cls_name)
module = h2o_cls(
module,
model.config,
heavy_ratio,
recent_ratio,
h2o_min_seqlen=h2o_min_seqlen,
real_drop=real_drop,
is_gen=is_gen,
mean=mean,
local=local
)
set_module(model, layer_name, module)
model.clean_cache = lambda: clean_cache(model)
model.ori_generate = model.generate
model.generate = partial(generate, model)
model = model.to(device)
return model


def local_heavy_hitter_mask(attn_weights, heavy_budget, no_padding_seq_length=None):

@@ -188,28 +149,37 @@ def get_hh_mask(heavy_budget_ratio, recent_budget_ratio, attn_weights, local=Tru

return mask_bottom


class H2OKVCache:
def __init__(
self,
heavy_ratio=0.2,
recent_ratio=0.2,
heavy_budget=None,
recent_budget=None,
min_seqlen=-1
):
## bsz, num_heads, seq_len, head_dim
self.heavy_ratio = heavy_ratio
self.recent_ratio = recent_ratio
self.heavy_budget = heavy_budget
self.recent_budget = recent_budget
self.hh_score = None
self.min_seqlen = min_seqlen
self.idx = 0

self._past_length = 0


def __call__(self, attn_score, key_states, value_states, mean=False, **kwargs):
seq_len = key_states.size(-2)
heavy_budget = int(self.heavy_ratio * seq_len)
recent_budget = int(self.recent_ratio * seq_len)
cache_size = heavy_budget + recent_budget
if self.heavy_budget is None:
self.heavy_budget = int(self.heavy_ratio * seq_len)
if self.recent_budget is None:
self.recent_budget = int(self.recent_ratio * seq_len)
cache_size = self.heavy_budget + self.recent_budget
if seq_len <= self.min_seqlen or seq_len <= cache_size:
return torch.ones(attn_score.shape[:-1], dtype=attn_score.dtype).to(key_states.device)
return key_states, value_states
self.idx += 1
# attn_score shape (bsz, num_heads, seq_len, head_dim)
if len(attn_score.shape) == 3:
@@ -218,12 +188,12 @@ def __call__(self, attn_score, key_states, value_states, mean=False, **kwargs):

# hh-selection
mask = torch.zeros(self.hh_score.shape, dtype=attn_score.dtype).to(key_states.device)
if not recent_budget == 0:
mask[:,:,-recent_budget:] = 1
select_hh_scores = self.hh_score[:,:,:seq_len - recent_budget]
if not self.recent_budget == 0:
mask[:,:,-self.recent_budget:] = 1
select_hh_scores = self.hh_score[:,:,:seq_len - self.recent_budget]

if not heavy_budget == 0:
_, keep_topk = torch.topk(select_hh_scores, heavy_budget, dim=-1, largest=True)
if not self.heavy_budget == 0:
_, keep_topk = torch.topk(select_hh_scores, self.heavy_budget, dim=-1, largest=True)
mask = mask.scatter(-1, keep_topk, 1)

mask = mask.bool()
@@ -260,7 +230,36 @@ def _update_hh_score(self, attn_score_cache, mean=False):

self.hh_score = attn_score_cache


def clean_scores(self):
self.idx = 0
self.past_length = 0
self.hh_score = None

@property
def past_length(self):
return self._past_length

@past_length.setter
def past_length(self, value):
self._past_length = value

class H2OConfig(dict):
def __init__(
self,
heavy_ratio: float = None,
recent_ratio: float = None,
heavy_budget: int = None,
recent_budget: int = None,
h2o_min_seqlen: int = -1,
real_drop: bool = True,
mean: bool = False,
local: bool = True
):
self.heavy_ratio = heavy_ratio
self.recent_ratio = recent_ratio
self.heavy_budget = heavy_budget
self.recent_budget = recent_budget
self.h2o_min_seqlen = h2o_min_seqlen
self.real_drop = real_drop
self.mean = mean
self.local = local
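
The H2OKVCache.__call__ logic shown above keeps two groups of KV entries: a fixed window of the most recent tokens plus a top-k of the positions with the highest accumulated attention scores ("heavy hitters"), chosen only among the non-recent positions. A self-contained toy reproduction of that selection on random scores follows; batch size, head count, sequence length, and ratios are illustrative.

# Toy reproduction of the heavy-hitter + recent-window selection in H2OKVCache.__call__.
import torch

bsz, num_heads, seq_len = 1, 2, 16
heavy_ratio, recent_ratio = 0.25, 0.25
heavy_budget = int(heavy_ratio * seq_len)    # positions kept for high accumulated attention
recent_budget = int(recent_ratio * seq_len)  # positions kept from the end of the sequence

hh_score = torch.rand(bsz, num_heads, seq_len)  # stand-in for the accumulated self.hh_score

mask = torch.zeros_like(hh_score)
if recent_budget > 0:
    mask[:, :, -recent_budget:] = 1  # always keep the most recent positions
if heavy_budget > 0:
    # heavy hitters are picked only among the non-recent positions
    select_scores = hh_score[:, :, : seq_len - recent_budget]
    _, keep_topk = torch.topk(select_scores, heavy_budget, dim=-1, largest=True)
    mask = mask.scatter(-1, keep_topk, 1)

keep = mask.bool()
print(keep.int())        # 1 marks KV entries that survive eviction
print(keep.sum(dim=-1))  # heavy_budget + recent_budget kept per head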