merge_weights.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import os
import sys

import torch
from transformers import (
    HfArgumentParser,
    TrainingArguments,
)

from args import AdditionalArguments
from models.modeling_llama import LlamaConfig, LlamaForCausalLM
from models.model_args import ModelArguments
from utils.nuteprune_utils import load_zs

logger = logging.getLogger(__name__)
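
# update_params: fold the learned pruning masks ("z" values) into the weight
# matrices in place, so the saved checkpoint no longer needs the masks at inference.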
def update_params(lm_model, zs):
    model = lm_model.model
    config = lm_model.config
    hidden_dims = config.hidden_size
    num_heads = config.num_attention_heads
    dims_per_head = hidden_dims // num_heads
    num_layers = config.num_hidden_layers
    if zs is not None:
        if "intermediate_z" in zs:
            # Scale the MLP gate/up projections along the intermediate dimension,
            # skipping layers whose whole MLP is pruned (mlp_z == 0).
            for layer in range(num_layers):
                if "mlp_z" in zs and zs["mlp_z"][layer] == 0:
                    continue
                intermediate_z = zs["intermediate_z"][layer].cpu().squeeze().clone()
                model.layers[layer].mlp.gate_proj.weight.data = model.layers[layer].mlp.gate_proj.weight.transpose(0, 1).data.mul(intermediate_z).transpose(0, 1)
                model.layers[layer].mlp.up_proj.weight.data = model.layers[layer].mlp.up_proj.weight.transpose(0, 1).data.mul(intermediate_z).transpose(0, 1)
        if "head_z" in zs:
            # Scale the value projection by the per-head mask (expanded to head_dim
            # columns), skipping layers whose whole attention block is pruned.
            for layer in range(num_layers):
                if "head_layer_z" in zs and zs["head_layer_z"][layer] == 0:
                    continue
                head_z = zs["head_z"][layer].cpu().squeeze().clone()
                head_z = torch.repeat_interleave(head_z, dims_per_head)
                model.layers[layer].self_attn.v_proj.weight.data = model.layers[layer].self_attn.v_proj.weight.transpose(0, 1).data.mul(head_z).transpose(0, 1)
                # GQA pruning?
                # model.layers[layer].self_attn.o_proj.weight.data = model.layers[layer].self_attn.o_proj.weight.data.mul(head_z)
        if "hidden_z" in zs:
            # Scale the attention output and MLP down projections along the hidden
            # (residual) dimension.
            hidden_z = zs["hidden_z"].cpu().squeeze().clone()
            for layer in range(num_layers):
                model.layers[layer].self_attn.o_proj.weight.data = model.layers[layer].self_attn.o_proj.weight.transpose(0, 1).data.mul(hidden_z).transpose(0, 1)
                model.layers[layer].mlp.down_proj.weight.data = model.layers[layer].mlp.down_proj.weight.transpose(0, 1).data.mul(hidden_z).transpose(0, 1)
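
# set_lora_args: copy the LoRA hyperparameters from the parsed ModelArguments onto
# the LlamaConfig so the custom LlamaForCausalLM builds its LoRA-enabled layers.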
def set_lora_args(config, modeling_args):
    config.use_lora = modeling_args.use_lora
    config.lora_rank = modeling_args.lora_rank
    config.lora_train_bias = modeling_args.lora_train_bias
    config.lora_alpha = modeling_args.lora_alpha
    config.lora_param = modeling_args.lora_param
    config.lora_layers = modeling_args.lora_layers
    return config
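
# Example invocation (model name and paths are illustrative):
#   python merge_weights.py \
#       --model_name_or_path meta-llama/Llama-2-7b-hf \
#       --pretrained_pruned_model ./output/nuteprune_run \
#       --output_dir ./llama_pruned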
def main():
    parser = HfArgumentParser((ModelArguments, TrainingArguments, AdditionalArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, training_args, additional_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, training_args, additional_args = parser.parse_args_into_dataclasses()

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    training_args.report_to = []

    # model initialize
    config = LlamaConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
    )
    config.use_cache = False
    config = set_lora_args(config, model_args)
    lora_ckpt = os.path.join(additional_args.pretrained_pruned_model, 'lora_weights.pt')

    model = LlamaForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
    )
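
    # Load the LoRA weights saved during pruning on top of the base model, then
    # recursively merge every LoRA delta into its base weight matrix.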
    if lora_ckpt is not None:
        model.load_state_dict(torch.load(lora_ckpt), strict=False)

    def merge_lora(module):
        if hasattr(module, 'merge'):
            module.merge()
        for m in module.children():
            merge_lora(m)

    merge_lora(model)

    # Re-instantiate a plain (LoRA-free) model and copy the merged weights into it.
    config.use_lora = False
    llama = LlamaForCausalLM.from_pretrained(model_args.model_name_or_path, config=config)
    output_path = "./llama_pruned" if training_args.output_dir == "./" else training_args.output_dir
    llama.load_state_dict(model.state_dict(), strict=False)
    print(f"LoRA weights merged! Output path: {output_path}")
    zs = load_zs(os.path.join(additional_args.pretrained_pruned_model, 'zs.pt'))
    for key in zs:
        zs[key] = zs[key].detach()
    update_params(llama, zs)

    llama.half()
    llama.save_pretrained(output_path)
    print(f"Pruning mask merged! Output path: {output_path}")


if __name__ == "__main__":
    main()