# extract_subnetworks.py (from a fork of aws/amazon-sagemaker-examples)
# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
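"""Extract a smaller BERT sub-network from a trained sequence-classification
(or multiple-choice) model by slicing its weights down to the dimensions given
in a new BertConfig (fewer attention heads and a narrower feed-forward layer).
"""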
from collections import OrderedDict

import torch.nn as nn
from transformers import AutoModelForSequenceClassification
from transformers.models.bert.modeling_bert import (
    BertConfig,
    BertForMultipleChoice,
    BertForSequenceClassification,
)

def copy_linear_layer(new_layer, old_layer, weight_shape, bias_shape):
    """Copy the top-left ``weight_shape`` slice of ``old_layer``'s weight
    (and the first ``bias_shape`` bias entries) into ``new_layer``."""
    old_state = old_layer.state_dict()
    new_state_dict = OrderedDict()
    new_state_dict["weight"] = old_state["weight"][: weight_shape[0], : weight_shape[1]]
    new_state_dict["bias"] = old_state["bias"][:bias_shape]
    new_layer.load_state_dict(new_state_dict)

def copy_layer_norm(new_layer, old_layer):
    """Copy LayerNorm weight and bias verbatim (their size is unchanged)."""
    old_state = old_layer.state_dict()
    new_state_dict = OrderedDict()
    new_state_dict["weight"] = old_state["weight"]
    new_state_dict["bias"] = old_state["bias"]
    new_layer.load_state_dict(new_state_dict)

def get_final_bert_model(original_model, new_model_config):
    """Build a new BERT model from ``new_model_config`` and initialize it with
    sliced weights from ``original_model``."""
    # AutoModelForSequenceClassification is a factory class, not a model type,
    # so it is omitted from the isinstance check (it would never match).
    assert isinstance(
        original_model, (BertForSequenceClassification, BertForMultipleChoice)
    ), "Make sure to pass a valid BERT model for sequence classification or multiple choice Q/A"
    assert isinstance(new_model_config, BertConfig), "Make sure to pass a valid BertConfig"
    original_model.eval()

    new_model = AutoModelForSequenceClassification.from_config(new_model_config)
    new_model.eval()

    # The embeddings, pooler, and classifier keep their shapes (the hidden size
    # is unchanged), so their weights are copied wholesale.
    new_model.bert.embeddings.load_state_dict(original_model.bert.embeddings.state_dict())
    new_model.bert.pooler.load_state_dict(original_model.bert.pooler.state_dict())
    new_model.classifier.load_state_dict(original_model.classifier.state_dict())

    # ``attention_head_size`` is not a standard BertConfig field; it is expected
    # to be set on the config by the calling code.
    num_attention_heads = new_model_config.num_attention_heads
    attention_head_size = new_model_config.attention_head_size
    all_head_size = num_attention_heads * attention_head_size
    for li, layer in enumerate(new_model.bert.encoder.layer):
        # Re-create the self-attention projections with the reduced head count,
        # since from_config sizes them as hidden_size x hidden_size by default.
        attention = layer.attention
        attention.self.query = nn.Linear(new_model_config.hidden_size, all_head_size)
        attention.self.key = nn.Linear(new_model_config.hidden_size, all_head_size)
        attention.self.value = nn.Linear(new_model_config.hidden_size, all_head_size)
        attention.output.dense = nn.Linear(all_head_size, new_model_config.hidden_size)
        attention.self.all_head_size = all_head_size
        attention.self.attention_head_size = attention_head_size

        # Slice the query/key/value/output projections from the original layer.
        mha_original_model = original_model.bert.encoder.layer[li].attention
        copy_linear_layer(
            attention.self.query,
            mha_original_model.self.query,
            (all_head_size, new_model_config.hidden_size),
            all_head_size,
        )
        copy_linear_layer(
            attention.self.key,
            mha_original_model.self.key,
            (all_head_size, new_model_config.hidden_size),
            all_head_size,
        )
        copy_linear_layer(
            attention.self.value,
            mha_original_model.self.value,
            (all_head_size, new_model_config.hidden_size),
            all_head_size,
        )
        copy_linear_layer(
            attention.output.dense,
            mha_original_model.output.dense,
            (new_model_config.hidden_size, all_head_size),
            new_model_config.hidden_size,
        )
        copy_layer_norm(attention.output.LayerNorm, mha_original_model.output.LayerNorm)
        # Slice the feed-forward (intermediate) projection.
        ffn_layer = layer.intermediate.dense
        ffn_original_model = original_model.bert.encoder.layer[li].intermediate.dense
        copy_linear_layer(
            ffn_layer,
            ffn_original_model,
            (new_model_config.intermediate_size, new_model_config.hidden_size),
            new_model_config.intermediate_size,
        )

        # Slice the feed-forward output projection back to the hidden size.
        ffn_layer = layer.output.dense
        ffn_original_model = original_model.bert.encoder.layer[li].output.dense
        copy_linear_layer(
            ffn_layer,
            ffn_original_model,
            (new_model_config.hidden_size, new_model_config.intermediate_size),
            new_model_config.hidden_size,
        )
        copy_layer_norm(
            layer.output.LayerNorm,
            original_model.bert.encoder.layer[li].output.LayerNorm,
        )

    return new_model
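
if __name__ == "__main__":
    # Minimal usage sketch, not part of the original script. The checkpoint
    # name and the reduced sizes below are illustrative assumptions; in
    # practice you would load a fine-tuned checkpoint (here the classifier
    # head of bert-base-uncased is freshly initialized).
    import torch

    original = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")

    # Sub-network config: same hidden size, half the heads, half the FFN width.
    sub_config = BertConfig(
        hidden_size=768,
        num_attention_heads=6,
        intermediate_size=1536,
        num_hidden_layers=12,
    )
    # get_final_bert_model reads the per-head size from this non-standard
    # attribute, so it must be set explicitly.
    sub_config.attention_head_size = 64

    sub_model = get_final_bert_model(original, sub_config)

    # Sanity check: a forward pass on dummy input ids.
    input_ids = torch.randint(0, sub_config.vocab_size, (1, 16))
    with torch.no_grad():
        logits = sub_model(input_ids=input_ids).logits
    print(logits.shape)  # torch.Size([1, 2]) with the default two labels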