# HookedEncoder.py
# Forked from TransformerLensOrg/TransformerLens.
# Listed in the repository web view as 397 lines (336 loc), 15.1 KB.
# NOTE: web-page chrome (notification/sign-in text, "Copy path") and the
# line-number gutter captured with this file were removed; they are not source.
from __future__ import annotations
import logging
from functools import lru_cache
from typing import Dict, Optional, Tuple, Union, cast, overload
import torch
from einops import repeat
from jaxtyping import Float, Int
from torch import nn
from transformers import AutoTokenizer
from typeguard import typeguard_ignore
from typing_extensions import Literal
import transformer_lens.loading_from_pretrained as loading
from transformer_lens import ActivationCache, FactoredMatrix, HookedTransformerConfig
from transformer_lens.components import BertBlock, BertEmbed, BertMLMHead, Unembed
from transformer_lens.hook_points import HookedRootModule, HookPoint
from transformer_lens.utilities import devices
class HookedEncoder(HookedRootModule):
    """
    This class implements a BERT-style encoder using the components in ./components.py, with HookPoints on every interesting activation. It inherits from HookedRootModule.

    Limitations:
    - The current MVP implementation supports only the masked language modelling (MLM) task. Next sentence prediction (NSP), causal language modelling, and other tasks are not yet supported.
    - Also note that model does not include dropouts, which may lead to inconsistent results from training or fine-tuning.

    Like HookedTransformer, it can have a pretrained Transformer's weights loaded via `.from_pretrained`. There are a few features you might know from HookedTransformer which are not yet supported:
    - There is no preprocessing (e.g. LayerNorm folding) when loading a pretrained model
    - The model only accepts tokens as inputs, and not strings, or lists of strings
    """

    def __init__(self, cfg, tokenizer=None, move_to_device=True, **kwargs):
        """Build the encoder from a config.

        cfg: A HookedTransformerConfig, or a dictionary of keyword arguments used to construct one. Passing a string raises ValueError (use `from_pretrained` for model names).
        tokenizer: Optional tokenizer. If omitted and `cfg.tokenizer_name` is set, one is loaded with `AutoTokenizer.from_pretrained`; otherwise `self.tokenizer` is None.
        move_to_device: If True, the model is moved to `cfg.device` once all submodules exist.
        **kwargs: Accepted for call-site compatibility; not used by this constructor.
        """
        super().__init__()
        if isinstance(cfg, dict):
            cfg = HookedTransformerConfig(**cfg)
        elif isinstance(cfg, str):
            raise ValueError(
                "Please pass in a config dictionary or HookedTransformerConfig object. If you want to load a pretrained model, use HookedEncoder.from_pretrained() instead."
            )
        self.cfg = cfg

        assert (
            self.cfg.n_devices == 1
        ), "Multiple devices not supported for HookedEncoder"

        if tokenizer is not None:
            self.tokenizer = tokenizer
        elif self.cfg.tokenizer_name is not None:
            self.tokenizer = AutoTokenizer.from_pretrained(self.cfg.tokenizer_name)
        else:
            self.tokenizer = None

        if self.cfg.d_vocab == -1:
            # If we have a tokenizer, vocab size can be inferred from it.
            assert (
                self.tokenizer is not None
            ), "Must provide a tokenizer if d_vocab is not provided"
            self.cfg.d_vocab = max(self.tokenizer.vocab.values()) + 1
        if self.cfg.d_vocab_out == -1:
            self.cfg.d_vocab_out = self.cfg.d_vocab

        self.embed = BertEmbed(self.cfg)
        self.blocks = nn.ModuleList(
            [BertBlock(self.cfg) for _ in range(self.cfg.n_layers)]
        )
        self.mlm_head = BertMLMHead(self.cfg)
        self.unembed = Unembed(self.cfg)

        self.hook_full_embed = HookPoint()

        # The move has to happen after the submodules above are registered:
        # a module with no children has no parameters to move, so doing this
        # earlier would leave every weight on its construction device.
        if move_to_device:
            self.to(self.cfg.device)

        self.setup()

    @overload
    def forward(
        self,
        input: Int[torch.Tensor, "batch pos"],
        return_type: Literal["logits"],
        token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None,
        one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
    ) -> Float[torch.Tensor, "batch pos d_vocab"]:
        ...

    @overload
    def forward(
        self,
        input: Int[torch.Tensor, "batch pos"],
        return_type: Literal[None],
        token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None,
        one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
    ) -> Optional[Float[torch.Tensor, "batch pos d_vocab"]]:
        ...

    def forward(
        self,
        input: Int[torch.Tensor, "batch pos"],
        return_type: Optional[str] = "logits",
        token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None,
        one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
    ) -> Optional[Float[torch.Tensor, "batch pos d_vocab"]]:
        """Input must be a batch of tokens. Strings and lists of strings are not yet supported.

        return_type Optional[str]: The type of output to return. Can be one of: None (return nothing, don't calculate logits), or 'logits' (return logits).

        token_type_ids Optional[torch.Tensor]: Binary ids indicating whether a token belongs to sequence A or B. For example, for two sentences: "[CLS] Sentence A [SEP] Sentence B [SEP]", token_type_ids would be [0, 0, ..., 0, 1, ..., 1, 1]. `0` represents tokens from Sentence A, `1` from Sentence B. If not provided, BERT assumes a single sequence input. Typically, shape is (batch_size, sequence_length).

        one_zero_attention_mask: Optional[torch.Tensor]: A binary mask which indicates which tokens should be attended to (1) and which should be ignored (0). Primarily used for padding variable-length sentences in a batch. For instance, in a batch with sentences of differing lengths, shorter sentences are padded with 0s on the right. If not provided, the model assumes all tokens should be attended to.
        """
        tokens = input

        if tokens.device.type != self.cfg.device:
            tokens = tokens.to(self.cfg.device)
        # The optional inputs are combined with the token embeddings /
        # attention scores, so they are sent to the same configured device.
        if one_zero_attention_mask is not None:
            one_zero_attention_mask = one_zero_attention_mask.to(self.cfg.device)
        if token_type_ids is not None:
            token_type_ids = token_type_ids.to(self.cfg.device)

        resid = self.hook_full_embed(self.embed(tokens, token_type_ids))

        # Masked positions (mask value 0) receive a large negative additive
        # bias; attended positions (mask value 1) receive 0.
        large_negative_number = -1e5
        additive_attention_mask = (
            large_negative_number
            * repeat(1 - one_zero_attention_mask, "batch pos -> batch 1 1 pos")
            if one_zero_attention_mask is not None
            else None
        )

        for block in self.blocks:
            resid = block(resid, additive_attention_mask)
        resid = self.mlm_head(resid)

        if return_type is None:
            # Everything up to and including the MLM head has run (so hooks
            # fire); only the unembedding is skipped.
            return None

        logits = self.unembed(resid)
        return logits

    @overload
    def run_with_cache(
        self, *model_args, return_cache_object: Literal[True] = True, **kwargs
    ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], ActivationCache]:
        ...

    @overload
    def run_with_cache(
        self, *model_args, return_cache_object: Literal[False] = False, **kwargs
    ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], Dict[str, torch.Tensor]]:
        ...

    def run_with_cache(
        self,
        *model_args,
        return_cache_object: bool = True,
        remove_batch_dim: bool = False,
        **kwargs,
    ) -> Tuple[
        Float[torch.Tensor, "batch pos d_vocab"],
        Union[ActivationCache, Dict[str, torch.Tensor]],
    ]:
        """
        Wrapper around run_with_cache in HookedRootModule. If return_cache_object is True, this will return an ActivationCache object, with a bunch of useful HookedTransformer specific methods, otherwise it will return a dictionary of activations as in HookedRootModule. This function was copied directly from HookedTransformer.
        """
        out, cache_dict = super().run_with_cache(
            *model_args, remove_batch_dim=remove_batch_dim, **kwargs
        )
        if return_cache_object:
            cache = ActivationCache(
                cache_dict, self, has_batch_dim=not remove_batch_dim
            )
            return out, cache
        else:
            return out, cache_dict

    def to(
        self,
        device_or_dtype: Union[torch.device, str, torch.dtype],
        print_details: bool = True,
    ):
        """Move/cast the model by delegating to `devices.move_to_and_update_config`, and return its result."""
        return devices.move_to_and_update_config(self, device_or_dtype, print_details)

    def cuda(self):
        # Wrapper around cuda that also changes self.cfg.device
        return self.to("cuda")

    def cpu(self):
        # Wrapper around cpu that also changes self.cfg.device
        return self.to("cpu")

    @classmethod
    def from_pretrained(
        cls,
        model_name: str,
        checkpoint_index: Optional[int] = None,
        checkpoint_value: Optional[int] = None,
        hf_model=None,
        device: Optional[str] = None,
        **model_kwargs,
    ) -> HookedEncoder:
        """Loads in the pretrained weights from huggingface. Currently supports loading weight from HuggingFace BertForMaskedLM. Unlike HookedTransformer, this does not yet do any preprocessing on the model.

        model_name: Name or alias of the model; resolved with `loading.get_official_model_name`.
        checkpoint_index / checkpoint_value: Forwarded to `loading.get_pretrained_model_config`.
        hf_model: Forwarded to `loading.get_pretrained_state_dict`.
        device: Forwarded to `loading.get_pretrained_model_config`.
        **model_kwargs: Forwarded to the class constructor.
        """
        logging.warning(
            "Support for BERT in TransformerLens is currently experimental, until such a time when it has feature "
            "parity with HookedTransformer and has been tested on real research tasks. Until then, backward "
            "compatibility is not guaranteed. Please see the docs for information on the limitations of the current "
            "implementation."
            "\n"
            "If using BERT for interpretability research, keep in mind that BERT has some significant architectural "
            "differences to GPT. For example, LayerNorms are applied *after* the attention and MLP components, meaning "
            "that the last LayerNorm in a block cannot be folded."
        )

        official_model_name = loading.get_official_model_name(model_name)

        cfg = loading.get_pretrained_model_config(
            official_model_name,
            checkpoint_index=checkpoint_index,
            checkpoint_value=checkpoint_value,
            fold_ln=False,
            device=device,
            n_devices=1,
        )

        state_dict = loading.get_pretrained_state_dict(
            official_model_name, cfg, hf_model
        )

        model = cls(cfg, **model_kwargs)

        # strict=False: keys that do not match between the converted state
        # dict and this module are tolerated rather than raising.
        model.load_state_dict(state_dict, strict=False)

        print(f"Loaded pretrained model {model_name} into HookedEncoder")

        return model

    @property
    @typeguard_ignore
    def W_U(self) -> Float[torch.Tensor, "d_model d_vocab"]:
        """
        Convenience to get the unembedding matrix (ie the linear map from the final residual stream to the output logits)
        """
        return self.unembed.W_U

    @property
    @typeguard_ignore
    def b_U(self) -> Float[torch.Tensor, "d_vocab"]:
        """
        Convenience to get the unembedding bias
        """
        return self.unembed.b_U

    @property
    @typeguard_ignore
    def W_E(self) -> Float[torch.Tensor, "d_vocab d_model"]:
        """
        Convenience to get the embedding matrix
        """
        return self.embed.embed.W_E

    @property
    @typeguard_ignore
    def W_pos(self) -> Float[torch.Tensor, "n_ctx d_model"]:
        """
        Convenience function to get the positional embedding. Only works on models with absolute positional embeddings!
        """
        return self.embed.pos_embed.W_pos

    @property
    @typeguard_ignore
    def W_E_pos(self) -> Float[torch.Tensor, "d_vocab+n_ctx d_model"]:
        """
        Concatenated W_E and W_pos. Used as a full (overcomplete) basis of the input space, useful for full QK and full OV circuits.
        """
        return torch.cat([self.W_E, self.W_pos], dim=0)

    # NOTE: the stacked-weight properties below are intentionally not cached.
    # torch.stack builds a new tensor, so a cached result would go stale as
    # soon as the underlying weights change (load_state_dict, .to(), training),
    # and a cache keyed on `self` would keep every model instance alive.

    @property
    @typeguard_ignore
    def W_K(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
        """Stacks the key weights across all layers"""
        return torch.stack(
            [cast(BertBlock, block).attn.W_K for block in self.blocks], dim=0
        )

    @property
    @typeguard_ignore
    def W_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
        """Stacks the query weights across all layers"""
        return torch.stack(
            [cast(BertBlock, block).attn.W_Q for block in self.blocks], dim=0
        )

    @property
    @typeguard_ignore
    def W_V(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
        """Stacks the value weights across all layers"""
        return torch.stack(
            [cast(BertBlock, block).attn.W_V for block in self.blocks], dim=0
        )

    @property
    @typeguard_ignore
    def W_O(self) -> Float[torch.Tensor, "n_layers n_heads d_head d_model"]:
        """Stacks the attn output weights across all layers"""
        return torch.stack(
            [cast(BertBlock, block).attn.W_O for block in self.blocks], dim=0
        )

    @property
    @typeguard_ignore
    def W_in(self) -> Float[torch.Tensor, "n_layers d_model d_mlp"]:
        """Stacks the MLP input weights across all layers"""
        return torch.stack(
            [cast(BertBlock, block).mlp.W_in for block in self.blocks], dim=0
        )

    @property
    @typeguard_ignore
    def W_out(self) -> Float[torch.Tensor, "n_layers d_mlp d_model"]:
        """Stacks the MLP output weights across all layers"""
        return torch.stack(
            [cast(BertBlock, block).mlp.W_out for block in self.blocks], dim=0
        )

    @property
    @typeguard_ignore
    def b_K(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
        """Stacks the key biases across all layers"""
        return torch.stack(
            [cast(BertBlock, block).attn.b_K for block in self.blocks], dim=0
        )

    @property
    @typeguard_ignore
    def b_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
        """Stacks the query biases across all layers"""
        return torch.stack(
            [cast(BertBlock, block).attn.b_Q for block in self.blocks], dim=0
        )

    @property
    @typeguard_ignore
    def b_V(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
        """Stacks the value biases across all layers"""
        return torch.stack(
            [cast(BertBlock, block).attn.b_V for block in self.blocks], dim=0
        )

    @property
    @typeguard_ignore
    def b_O(self) -> Float[torch.Tensor, "n_layers d_model"]:
        """Stacks the attn output biases across all layers"""
        return torch.stack(
            [cast(BertBlock, block).attn.b_O for block in self.blocks], dim=0
        )

    @property
    @typeguard_ignore
    def b_in(self) -> Float[torch.Tensor, "n_layers d_mlp"]:
        """Stacks the MLP input biases across all layers"""
        return torch.stack(
            [cast(BertBlock, block).mlp.b_in for block in self.blocks], dim=0
        )

    @property
    @typeguard_ignore
    def b_out(self) -> Float[torch.Tensor, "n_layers d_model"]:
        """Stacks the MLP output biases across all layers"""
        return torch.stack(
            [cast(BertBlock, block).mlp.b_out for block in self.blocks], dim=0
        )

    @property
    @typeguard_ignore
    def QK(self) -> FactoredMatrix:  # [n_layers, n_heads, d_model, d_model]
        """Returns a FactoredMatrix built from W_Q and W_K with its last two dimensions transposed."""
        return FactoredMatrix(self.W_Q, self.W_K.transpose(-2, -1))

    @property
    @typeguard_ignore
    def OV(self) -> FactoredMatrix:  # [n_layers, n_heads, d_model, d_model]
        """Returns a FactoredMatrix built from W_V and W_O."""
        return FactoredMatrix(self.W_V, self.W_O)

    def all_head_labels(self) -> list[str]:
        """Returns a list of strings with the format "L{l}H{h}", where l is the layer index and h is the head index, ordered by layer and then by head."""
        return [
            f"L{l}H{h}"
            for l in range(self.cfg.n_layers)
            for h in range(self.cfg.n_heads)
        ]