import logging
from functools import lru_cache
from typing import Dict, List, NamedTuple, Optional, Tuple, Union, overload
import einops
import numpy as np
import torch
import torch.nn as nn
import tqdm.auto as tqdm
from fancy_einsum import einsum
from jaxtyping import Float, Int
from transformers import AutoTokenizer, PreTrainedTokenizerBase
from typeguard import typeguard_ignore
from typing_extensions import Literal
import transformer_lens.loading_from_pretrained as loading
import transformer_lens.utils as utils
from transformer_lens import HookedTransformerConfig
from transformer_lens.ActivationCache import ActivationCache
from transformer_lens.components import (
Embed,
LayerNorm,
LayerNormPre,
PosEmbed,
RMSNorm,
RMSNormPre,
TransformerBlock,
Unembed,
)
from transformer_lens.FactoredMatrix import FactoredMatrix
from transformer_lens.hook_points import HookedRootModule, HookPoint
# Note - activation cache is used with run_with_cache, past_key_value_caching is used for generation.
from transformer_lens.past_key_value_caching import HookedTransformerKeyValueCache
from transformer_lens.utilities import devices
SingleLoss = Float[torch.Tensor, ""] # Type alias for a single element tensor
LossPerToken = Float[torch.Tensor, "batch pos-1"]
Loss = Union[SingleLoss, LossPerToken]
# Named tuple object for if we want to output both logits and loss
class Output(NamedTuple):
logits: Float[torch.Tensor, "batch pos d_vocab"]
loss: Loss
class HookedTransformer(HookedRootModule):
"""
This class implements a full Transformer using the components in ./components.py, with
HookPoints on every interesting activation. It inherits from HookedRootModule.
It can have a pretrained Transformer's weights automatically loaded in via the HookedTransformer.from_pretrained
class method. It can also be instantiated with randomly initialized weights via __init__ by passing in a dict or
HookedTransformerConfig object.
"""
def __init__(
self,
cfg,
tokenizer=None,
move_to_device=True,
):
"""
Model initialization. Note that if you want to load the model from pretrained weights, you should use the
HookedTransformer.from_pretrained() class method instead of this one.
cfg Union[HookedTransformerConfig, Dict]: The config to use for the
model.
tokenizer (*optional): The tokenizer to use for the model. If not
provided, it is inferred from cfg.tokenizer_name or initialized to None.
If None, then the model cannot be passed strings, and d_vocab must be explicitly set.
move_to_device (bool): Whether to move the model to the device specified in cfg.device.
Must be true if `n_devices` in the config is greater than 1, since the model's layers
will be split across multiple devices.
"""
super().__init__()
if isinstance(cfg, Dict):
cfg = HookedTransformerConfig(**cfg)
elif isinstance(cfg, str):
raise ValueError(
"Please pass in a config dictionary or HookedTransformerConfig object. If you want to load a "
"pretrained model, use HookedTransformer.from_pretrained() instead."
)
self.cfg = cfg
assert (
self.cfg.n_devices == 1 or move_to_device
), "If n_devices > 1, must move_to_device"
if tokenizer is not None:
self.set_tokenizer(tokenizer)
elif self.cfg.tokenizer_name is not None:
# If we have a tokenizer name, we can load it from HuggingFace
if "llama" in self.cfg.tokenizer_name:
# llama tokenizer requires special handling
print("Warning: LLaMA tokenizer not loaded. Please load manually.")
else:
self.set_tokenizer(
AutoTokenizer.from_pretrained(self.cfg.tokenizer_name)
)
else:
# If no tokenizer name is provided, we assume we're training on an algorithmic task and will pass in tokens
# directly. In this case, we don't need a tokenizer.
assert (
self.cfg.d_vocab != -1
), "Must provide a tokenizer if d_vocab is not provided"
self.tokenizer = None
self.embed = Embed(self.cfg)
self.hook_embed = HookPoint() # [batch, pos, d_model]
if self.cfg.positional_embedding_type != "rotary":
self.pos_embed = PosEmbed(self.cfg)
self.hook_pos_embed = HookPoint() # [batch, pos, d_model]
if self.cfg.use_hook_tokens:
self.hook_tokens = HookPoint() # [batch, pos]
self.blocks = nn.ModuleList(
[
TransformerBlock(self.cfg, block_index)
for block_index in range(self.cfg.n_layers)
]
)
if self.cfg.normalization_type == "RMS":
self.ln_final = RMSNorm(self.cfg)
elif self.cfg.normalization_type == "RMSPre":
self.ln_final = RMSNormPre(self.cfg)
elif self.cfg.normalization_type == "LN":
if self.cfg.final_rms:
self.ln_final = RMSNorm(self.cfg)
else:
self.ln_final = LayerNorm(self.cfg)
elif self.cfg.normalization_type == "LNPre":
# We've folded in LayerNorm weights, so just need the center + scale parts
if self.cfg.final_rms:
self.ln_final = RMSNormPre(self.cfg)
else:
self.ln_final = LayerNormPre(self.cfg)
elif self.cfg.normalization_type is None:
# If it's None, don't create either layer
pass
else:
logging.warning(
f"Invalid normalization_type passed in {self.cfg.normalization_type}"
)
self.unembed = Unembed(self.cfg)
if self.cfg.init_weights:
self.init_weights()
if move_to_device:
# We load the devices in a pipeline manner - the first device gets the embed and pos_embed layers and the
# first n_layers // n_devices blocks,
# the second gets the next n_layers // n_devices blocks ... the last gets the last n_layers // n_devices
# blocks, the final
# normalization layer (if it exists) and the unembed layer
HookedTransformer.move_model_modules_to_device(self)
# Helper variable to store a small (10K-20K) dataset of training data. Empty by default, can be loaded with
# load_sample_training_dataset
self.dataset = None
# Gives each module a parameter with its name (relative to this root module)
# Needed for HookPoints to work
self.setup()
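# Example usage (illustrative sketch, not part of the original file): constructing a small
# randomly initialized model from a config, with no tokenizer, so d_vocab must be set
# explicitly. The hyperparameter values below are arbitrary placeholders.
#
#     cfg = HookedTransformerConfig(
#         n_layers=2, d_model=128, n_ctx=64, d_head=32, d_vocab=1000, act_fn="relu"
#     )
#     model = HookedTransformer(cfg)
#     tokens = torch.randint(0, cfg.d_vocab, (1, 10))
#     logits = model(tokens)  # [1, 10, d_vocab]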
def check_hooks_to_add(
self, hook_point, hook_point_name, hook, dir="fwd", is_permanent=False
) -> None:
if hook_point_name.endswith("attn.hook_result"):
assert (
self.cfg.use_attn_result
), f"Cannot add hook {hook_point_name} if use_attn_result_hook is False"
if hook_point_name.endswith(("hook_q_input", "hook_k_input", "hook_v_input")):
assert (
self.cfg.use_split_qkv_input
), f"Cannot add hook {hook_point_name} if use_split_qkv_input is False"
@overload
def forward(
self,
input,
return_type: Literal["logits"],
loss_per_token: bool = False,
prepend_bos: bool = True,
stop_at_layer: Optional[int] = None,
past_kv_cache: Optional[HookedTransformerKeyValueCache] = None,
) -> Loss:
...
@overload
def forward(
self,
input,
return_type: Literal["loss"],
loss_per_token: bool = False,
prepend_bos: bool = True,
stop_at_layer: Optional[int] = None,
past_kv_cache: Optional[HookedTransformerKeyValueCache] = None,
) -> Loss:
...
@overload
def forward(
self,
input,
return_type: Literal["both"],
loss_per_token: bool = False,
prepend_bos: bool = True,
stop_at_layer: Optional[int] = None,
past_kv_cache: Optional[HookedTransformerKeyValueCache] = None,
) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss]:
...
@overload
def forward(
self,
input,
return_type: Literal[None],
loss_per_token: bool = False,
prepend_bos: bool = True,
stop_at_layer: Optional[int] = None,
past_kv_cache: Optional[HookedTransformerKeyValueCache] = None,
) -> None:
...
# TODO make sure type assertions are provided
def forward(
self,
input: Union[str, List[str], Int[torch.Tensor, "batch pos"]],
return_type: Optional[str] = "logits",
loss_per_token: bool = False,
prepend_bos: bool = True,
stop_at_layer: Optional[int] = None,
past_kv_cache: Optional[HookedTransformerKeyValueCache] = None,
) -> Union[
None,
Float[torch.Tensor, "batch pos d_vocab"],
Loss,
Tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss],
]:
"""Input is either a batch of tokens ([batch, pos]) or a text string, a string is automatically tokenized to a
batch of a single element. The prepend_bos flag only applies when inputting a text string.
return_type Optional[str]: The type of output to return. Can be one of: None (return nothing, don't calculate
logits), 'logits' (return logits), 'loss' (return cross-entropy loss), 'both' (return logits and loss)
loss_per_token bool: Whether to return the (next token prediction) loss per token (True) or average (False).
Average loss is a scalar (averaged over position *and* batch), per-token loss is a tensor ([batch, position-1])
- position-1 because we're predicting the next token, and there's no specified next token for the final
token. Defaults to False.
prepend_bos bool: Whether to prepend the BOS token to the input. Only applies when input is a string. Defaults
to True (unlike to_tokens) - even for models not explicitly trained with this, heads often use the first
position as a resting position and accordingly lose information from the first token, so this empirically
seems to give better results.
stop_at_layer Optional[int]: If not None, stop the forward pass at the specified layer. Exclusive - ie,
stop_at_layer = 0 will only run the embedding layer, stop_at_layer = 1 will run the embedding layer and the
first transformer block, etc. Supports negative indexing. Useful for analysis of intermediate layers, eg finding
neuron activations in layer 3 of a 24 layer model. Defaults to None (run the full model).
Note that loss is the standard "predict the next token" cross-entropy loss for GPT-2 style language models -
if you want a custom loss function, the recommended behaviour is returning the logits and then applying your
custom loss function.
"""
if isinstance(input, (str, list)):
# If text, convert to tokens (batch_size=1)
assert (
self.tokenizer is not None
), "Must provide a tokenizer if passing a string to the model"
# This is only intended to support passing in a single string
tokens = self.to_tokens(input, prepend_bos=prepend_bos)
else:
tokens = input
if len(tokens.shape) == 1:
# If tokens are a rank 1 tensor, add a dummy batch dimension to avoid things breaking.
tokens = tokens[None]
if tokens.device.type != self.cfg.device:
tokens = tokens.to(devices.get_device_for_block_index(0, self.cfg))
# If we're doing caching, then we reuse keys and values from previous runs, as that's the only
# way that past activations will affect the final logits. The cache contains those so we don't
# need to recompute them. This is useful for generating text. As we have absolute positional
# encodings, to implement this we have a `pos_offset` variable, defaulting to zero, which says
# to offset which positional encodings are used (cached keys and values were calculated with
# their own positional encodings).
if past_kv_cache is None:
pos_offset = 0
else:
batch_size, ctx_length = tokens.shape
(
cached_batch_size,
cache_ctx_length,
num_heads_in_cache,
d_head_in_cache,
) = past_kv_cache[0].past_keys.shape
assert cached_batch_size == batch_size
assert num_heads_in_cache == self.cfg.n_heads
assert d_head_in_cache == self.cfg.d_head
# If we want to generate from the empty string, we'd pass in an empty cache, so we need to handle that case
assert (
cache_ctx_length == 0 or ctx_length == 1
), "Pass in one token at a time after loading cache"
pos_offset = cache_ctx_length
if self.cfg.use_hook_tokens:
tokens = self.hook_tokens(tokens)
embed = self.hook_embed(self.embed(tokens)) # [batch, pos, d_model]
if self.cfg.positional_embedding_type == "standard":
pos_embed = self.hook_pos_embed(
self.pos_embed(tokens, pos_offset)
) # [batch, pos, d_model]
residual = embed + pos_embed # [batch, pos, d_model]
shortformer_pos_embed = None
elif self.cfg.positional_embedding_type == "shortformer":
# If we're using shortformer style attention, we don't add the positional embedding to the residual stream.
# See HookedTransformerConfig for details
pos_embed = self.hook_pos_embed(
self.pos_embed(tokens, pos_offset)
) # [batch, pos, d_model]
residual = embed
shortformer_pos_embed = pos_embed
elif self.cfg.positional_embedding_type == "rotary":
# Rotary doesn't use positional embeddings, instead they're applied when dot producting keys and queries.
# See HookedTransformerConfig for details
residual = embed
shortformer_pos_embed = None
else:
raise ValueError(
f"Invalid positional_embedding_type passed in {self.cfg.positional_embedding_type}"
)
if stop_at_layer is None:
# We iterate through every block by default
transformer_block_list = self.blocks
else:
# If we explicitly want to stop at a layer, we only iterate through the blocks up to that layer. Note that
# this is exclusive, eg stop_at_layer==0 means to only run the embed, stop_at_layer==-1 means to run every
# layer *apart* from the final one, etc.
transformer_block_list = self.blocks[:stop_at_layer] # type: ignore
for i, block in enumerate(transformer_block_list): # type: ignore
# Note that each block includes skip connections, so we don't need
# residual + block(residual)
# If we're using multiple GPUs, we need to send the residual and shortformer_pos_embed to the correct GPU
residual = residual.to(devices.get_device_for_block_index(i, self.cfg))
if shortformer_pos_embed is not None:
shortformer_pos_embed = shortformer_pos_embed.to(
devices.get_device_for_block_index(i, self.cfg)
)
residual = block(
residual,
past_kv_cache_entry=past_kv_cache[i]
if past_kv_cache is not None
else None, # Cache contains a list of HookedTransformerKeyValueCache objects, one for each block
shortformer_pos_embed=shortformer_pos_embed,
) # [batch, pos, d_model]
if stop_at_layer is not None:
# When we stop at an early layer, we end here rather than doing further computation
return None
if self.cfg.normalization_type is not None:
residual = self.ln_final(residual) # [batch, pos, d_model]
if return_type is None:
return None
else:
logits = self.unembed(residual) # [batch, pos, d_vocab]
if return_type == "logits":
return logits
else:
loss = self.loss_fn(logits, tokens, per_token=loss_per_token)
if return_type == "loss":
return loss
elif return_type == "both":
return Output(logits, loss)
else:
logging.warning(f"Invalid return_type passed in: {return_type}")
return None
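# Example (illustrative sketch, assuming `model` is a HookedTransformer with a tokenizer):
#
#     logits = model("Hello world")                            # [batch, pos, d_vocab]
#     loss = model("Hello world", return_type="loss")          # scalar next-token loss
#     logits, loss = model("Hello world", return_type="both")
#     per_tok = model("Hello world", return_type="loss", loss_per_token=True)  # [batch, pos-1]
#     model("Hello world", stop_at_layer=1)                    # embed + block 0 only, returns None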
def loss_fn(
self,
logits: Float[torch.Tensor, "batch pos d_vocab"],
tokens: Int[torch.Tensor, "batch pos"],
per_token: bool = False,
):
"""
Wrapper around utils.lm_cross_entropy_loss, used in forward() with return_type=="loss" or "both".
"""
if tokens.device != logits.device:
tokens = tokens.to(logits.device)
return utils.lm_cross_entropy_loss(logits, tokens, per_token)
@overload
def run_with_cache(
self, *model_args, return_cache_object: Literal[True] = True, **kwargs
) -> Tuple[Output, ActivationCache]:
...
@overload
def run_with_cache(
self, *model_args, return_cache_object: Literal[False] = False, **kwargs
) -> Tuple[Output, Dict[str, torch.Tensor]]:
...
def run_with_cache(
self, *model_args, return_cache_object=True, remove_batch_dim=False, **kwargs
) -> Tuple[
Union[
None,
Float[torch.Tensor, "batch pos d_vocab"],
Loss,
Tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss],
],
Union[ActivationCache, Dict[str, torch.Tensor]],
]:
"""
Wrapper around run_with_cache in HookedRootModule. If return_cache_object is True, this will return an
ActivationCache object, with a bunch of useful HookedTransformer specific methods, otherwise it will return a
dictionary of activations as in HookedRootModule.
"""
out, cache_dict = super().run_with_cache(
*model_args, remove_batch_dim=remove_batch_dim, **kwargs
)
if return_cache_object:
cache = ActivationCache(
cache_dict, self, has_batch_dim=not remove_batch_dim
)
return out, cache
else:
return out, cache_dict
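# Example (illustrative): run_with_cache returns the model output together with an
# ActivationCache keyed by hook name.
#
#     logits, cache = model.run_with_cache("Hello world")
#     resid = cache["blocks.0.hook_resid_post"]  # [batch, pos, d_model]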
def set_tokenizer(self, tokenizer):
"""
Sets the tokenizer to use for this model.
tokenizer (PreTrainedTokenizer): a pretrained HuggingFace tokenizer
"""
assert isinstance(
tokenizer, PreTrainedTokenizerBase
), f"{type(tokenizer)} is not a supported tokenizer, please use PreTrainedTokenizer or PreTrainedTokenizerFast"
self.tokenizer = tokenizer
if self.tokenizer.eos_token is None:
self.tokenizer.eos_token = "<|endoftext|>"
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
if self.tokenizer.bos_token is None:
self.tokenizer.bos_token = self.tokenizer.eos_token
# Infer vocab size from tokenizer
if self.cfg.d_vocab == -1:
self.cfg.d_vocab = max(self.tokenizer.vocab.values()) + 1
if self.cfg.d_vocab_out == -1:
self.cfg.d_vocab_out = self.cfg.d_vocab
def to_tokens(
self,
input: Union[str, List[str]],
prepend_bos: bool = True,
move_to_device: bool = True,
truncate: bool = True,
) -> Int[torch.Tensor, "batch pos"]:
"""
Converts a string to a tensor of tokens. If prepend_bos is True, prepends the BOS token to the input - this is
recommended when creating a sequence of tokens to be input to a model.
Args:
input (Union[str, List[str]]). The input to tokenize
prepend_bos (bool): Whether to prepend a beginning of sequence token. Defaults to True
move_to_device (bool): Whether to move the output tensor of tokens to the device the model lives on.
Defaults to True
truncate (bool): If the output tokens are too long, whether to truncate the output tokens to the model's
max context window. Does nothing for shorter inputs. Defaults to True.
Gotcha: prepend_bos prepends a beginning of string token. This is a recommended default when inputting a prompt
to the model as the first token is often treated weirdly, but should only be done at the START of the prompt.
Make sure to turn it off if you're looking at the tokenization of part of the prompt!
(Note: some models eg GPT-2 were not trained with a BOS token, others (OPT and my models) were)
Gotcha2: Tokenization of a string depends on whether there is a preceding space and whether the first letter is
capitalized. It's easy to shoot yourself in the foot here if you're not careful!
"""
assert self.tokenizer is not None, "Cannot use to_tokens without a tokenizer"
if prepend_bos:
if isinstance(input, str):
input = self.tokenizer.bos_token + input
else:
input = [self.tokenizer.bos_token + string for string in input]
tokens = self.tokenizer(
input,
return_tensors="pt",
padding=True,
truncation=truncate,
max_length=self.cfg.n_ctx if truncate else None,
add_special_tokens=False
if self.tokenizer.name_or_path.startswith("facebook/opt")
else True, # As we manually add the BOS token
)["input_ids"]
if move_to_device:
tokens = tokens.to(self.cfg.device)
return tokens
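# Example (illustrative): prepend_bos only makes sense at the start of a prompt, so turn it
# off when tokenizing a fragment from the middle of one.
#
#     prompt_tokens = model.to_tokens("The cat sat on the mat")            # BOS + prompt tokens
#     fragment_tokens = model.to_tokens(" on the mat", prepend_bos=False)  # no BOS for a fragment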
def to_string(
self,
tokens: Union[
Int[torch.Tensor, ""],
Int[torch.Tensor, "batch pos"],
Int[torch.Tensor, "pos"],
np.ndarray,
List[Int[torch.Tensor, "pos"]],
],
) -> Union[str, List[str]]:
"""
Converts a tensor of tokens to a string (if rank 1) or a list of strings (if rank 2).
Accepts lists of tokens and numpy arrays as inputs too (and converts to tensors internally)
"""
assert self.tokenizer is not None, "Cannot use to_string without a tokenizer"
if not isinstance(tokens, torch.Tensor):
# We allow lists to be input
tokens = torch.tensor(tokens)
# I'm not sure what exactly clean_up_tokenization_spaces does, but if
# it's set, then tokenization is no longer invertible, and some tokens
# with a bunch of whitespace get collapsed together
if len(tokens.shape) == 2:
return self.tokenizer.batch_decode(
tokens, clean_up_tokenization_spaces=False
)
elif len(tokens.shape) <= 1:
return self.tokenizer.decode(tokens, clean_up_tokenization_spaces=False)
else:
raise ValueError(f"Invalid shape passed in: {tokens.shape}")
def to_str_tokens(
self,
input: Union[
str,
Int[torch.Tensor, "pos"],
Int[torch.Tensor, "1 pos"],
Int[np.ndarray, "pos"],
Int[np.ndarray, "1 pos"],
list,
],
prepend_bos: bool = True,
) -> List[str]:
"""Method to map text, a list of text or tokens to a list of tokens as strings
Gotcha: prepend_bos prepends a beginning of string token. This is a recommended default when inputting a prompt
to the model as the first token is often treated weirdly, but should only be done at the START of the prompt.
Make sure to turn it off if you're looking at the tokenization of part of the prompt!
(Note: some models eg GPT-2 were not trained with a BOS token, others (OPT and my models) were)
Gotcha2: Tokenization of a string depends on whether there is a preceding space and whether the first letter is
capitalized. It's easy to shoot yourself in the foot here if you're not careful!
Gotcha3: If passing a string that exceeds the model's context length (model.cfg.n_ctx), it will be truncated.
Args:
input (Union[str, list, torch.Tensor]): The input - either a string or a tensor of tokens. If tokens, should
be a tensor of shape [pos] or [1, pos]
prepend_bos (bool, optional): Whether to prepend a BOS token. Only applies if input is a string. Defaults to
True.
Returns:
str_tokens: List of individual tokens as strings
"""
if isinstance(input, list):
return list(
map(lambda tokens: self.to_str_tokens(tokens, prepend_bos), input)
) # type: ignore
elif isinstance(input, str):
tokens = self.to_tokens(input, prepend_bos=prepend_bos)[0]
elif isinstance(input, torch.Tensor):
tokens = input
tokens = tokens.squeeze() # Get rid of a trivial batch dimension
if tokens.dim() == 0:
# Don't pass dimensionless tensor
tokens = tokens.unsqueeze(0)
assert (
tokens.dim() == 1
), f"Invalid tokens input to to_str_tokens, has shape: {tokens.shape}"
elif isinstance(input, np.ndarray):
tokens = input
tokens = tokens.squeeze() # Get rid of a trivial batch dimension
if tokens.ndim == 0:
# Don't pass dimensionless tensor
tokens = np.expand_dims(tokens, axis=0)
assert (
tokens.ndim == 1
), f"Invalid tokens input to to_str_tokens, has shape: {tokens.shape}"
else:
raise ValueError(f"Invalid input type to to_str_tokens: {type(input)}")
str_tokens = self.tokenizer.batch_decode(
tokens, clean_up_tokenization_spaces=False
)
return str_tokens
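# Example (illustrative; exact token boundaries depend on the tokenizer):
#
#     tokens = model.to_tokens("Hello world")     # [1, pos], including the BOS token
#     model.to_str_tokens("Hello world")          # e.g. ["<|endoftext|>", "Hello", " world"]
#     model.to_string(tokens[0])                  # back to a single string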
def to_single_token(self, string):
"""Maps a string that makes up a single token to the id for that token. Raises an error for strings that are
not a single token! If uncertain use to_tokens"""
# We use the to_tokens method, without prepending a BOS token
token = self.to_tokens(string, prepend_bos=False).squeeze()
# If token shape is non-empty, raise error
assert not token.shape, f"Input string: {string} is not a single token!"
return token.item()
def to_single_str_token(self, int_token: int) -> str:
# Gives the single token corresponding to an int in string form
assert isinstance(int_token, int)
token = self.to_str_tokens(torch.tensor([int_token]))
assert len(token) == 1
return token[0]
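# Example (illustrative):
#
#     token_id = model.to_single_token(" world")   # raises if " world" is not a single token
#     model.to_single_str_token(token_id)          # " world"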
def get_token_position(
self,
single_token: Union[str, int],
input: Union[
str, Union[Float[torch.Tensor, "pos"], Float[torch.Tensor, "1 pos"]]
],
mode="first",
prepend_bos=True,
):
"""
Get the position of a single_token in a string or sequence of tokens. Raises an error if the token is not
present.
Gotcha: If you're inputting a string, it'll automatically be tokenized. Be careful about whether prepend_bos is true or
false! When a string is input to the model, a BOS (beginning of sequence) token is prepended by default when the
string is tokenized. But this should only be done at the START of the input, not when inputting part of the
prompt. If you're getting weird off-by-one errors, check carefully for what the setting should be!
Args:
single_token (Union[str, int]): The token to search for. Can
be a token index, or a string (but the string must correspond to a
single token)
input (Union[str, torch.Tensor]): The sequence to
search in. Can be a string or a rank 1 tensor of tokens or a rank 2 tensor of tokens with a dummy batch
dimension.
mode (str, optional): If there are multiple matches, which match to return. Supports "first" or "last".
Defaults to "first".
prepend_bos (bool): Prepends a BOS (beginning of sequence) token when tokenizing a string. Only matters when
inputting a string to the function, otherwise ignored.
"""
if isinstance(input, str):
# If the input is a string, convert to tensor
tokens = self.to_tokens(input, prepend_bos=prepend_bos)
else:
tokens = input
if len(tokens.shape) == 2:
# If the tokens have shape [1, seq_len], flatten to [seq_len]
assert (
tokens.shape[0] == 1
), f"If tokens are rank two, they must have shape [1, seq_len], not {tokens.shape}"
tokens = tokens[0]
if isinstance(single_token, str):
# If the single token is a string, convert to an integer
single_token = self.to_single_token(single_token)
elif isinstance(single_token, torch.Tensor):
single_token = single_token.item()
indices = torch.arange(len(tokens))[tokens == single_token]
assert len(indices) > 0, "The token does not occur in the prompt"
if mode == "first":
return indices[0].item()
elif mode == "last":
return indices[-1].item()
else:
raise ValueError(f"mode must be 'first' or 'last', not {mode}")
def tokens_to_residual_directions(
self,
tokens: Union[
str,
int,
Int[torch.Tensor, ""],
Int[torch.Tensor, "pos"],
Int[torch.Tensor, "batch pos"],
],
) -> Union[
Float[torch.Tensor, "d_model"],
Float[torch.Tensor, "pos d_model"],
Float[torch.Tensor, "batch pos d_model"],
]:
"""Maps tokens to a tensor with the unembedding vector for those tokens, ie the vector in the residual stream
that we dot with to get the logit for that token.
WARNING: If you use this without folding in LayerNorm, the results will be misleading and may be incorrect, as
the LN weights change the unembed map. This is done automatically with the fold_ln flag on from_pretrained
WARNING 2: LayerNorm scaling will scale up or down the effective direction in the residual stream for each
output token on any given input token position. ActivationCache.apply_ln_to_stack will apply the appropriate
scaling to these directions.
Args:
tokens (Union[str, int, torch.Tensor]): The token(s). If a single token, can be a single element tensor, an
integer, or string. If string, will be mapped to a single token using to_single_token, and an error
raised if it's multiple tokens. The method also works for a batch of input tokens
Returns:
residual_direction torch.Tensor: The unembedding vector for the token(s), a stack of [d_model] tensors.
"""
if isinstance(tokens, torch.Tensor) and tokens.numel() > 1:
# If the tokens are a tensor, and have more than one element, assume they are a batch of tokens
residual_directions = self.W_U[:, tokens]
residual_directions = einops.rearrange(
residual_directions, "d_model ... -> ... d_model"
)
return residual_directions
else:
# Otherwise there is a single token
if isinstance(tokens, str):
token = self.to_single_token(tokens)
elif isinstance(tokens, int):
token = tokens
elif isinstance(tokens, torch.Tensor) and tokens.numel() == 1:
token = tokens.item()
else:
raise ValueError(f"Invalid token type: {type(tokens)}")
residual_direction = self.W_U[:, token]
return residual_direction
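# Example (illustrative, logit-lens style sketch): dot the final (normalized) residual stream
# with a token's unembedding direction. Assumes `cache` comes from run_with_cache, that
# LayerNorm was folded in via from_pretrained, and that the "ln_final.hook_normalized" hook
# name is available.
#
#     direction = model.tokens_to_residual_directions(" world")    # [d_model]
#     resid = cache["ln_final.hook_normalized"][0, -1]              # [d_model] at the last position
#     approx_logit = resid @ direction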
def to(
self,
device_or_dtype: Union[torch.device, str, torch.dtype],
print_details: bool = True,
):
return devices.move_to_and_update_config(self, device_or_dtype, print_details)
def cuda(self):
# Wrapper around cuda that also changes self.cfg.device
return self.to("cuda")
def cpu(self):
# Wrapper around cpu that also changes self.cfg.device
return self.to("cpu")
@classmethod
def move_model_modules_to_device(cls, model: "HookedTransformer"):
model.embed.to(devices.get_device_for_block_index(0, model.cfg))
model.hook_embed.to(devices.get_device_for_block_index(0, model.cfg))
if model.cfg.positional_embedding_type != "rotary":
model.pos_embed.to(devices.get_device_for_block_index(0, model.cfg))
model.hook_pos_embed.to(devices.get_device_for_block_index(0, model.cfg))
if hasattr(model, "ln_final"):
model.ln_final.to(
devices.get_device_for_block_index(model.cfg.n_layers - 1, model.cfg)
)
model.unembed.to(
devices.get_device_for_block_index(model.cfg.n_layers - 1, model.cfg)
)
for i, block in enumerate(model.blocks):
block.to(devices.get_device_for_block_index(i, model.cfg))
@classmethod
def from_pretrained(
cls,
model_name: str,
fold_ln=True,
center_writing_weights=True,
center_unembed=True,
refactor_factored_attn_matrices=False,
checkpoint_index=None,
checkpoint_value=None,
hf_model=None,
device=None,
n_devices=1,
move_state_dict_to_device=True,
**model_kwargs,
) -> "HookedTransformer":
"""Class method to load in a pretrained model weights to the HookedTransformer format and optionally to do some
processing to make the model easier to interpret. Currently supports loading from most autoregressive
HuggingFace models (GPT2, GPTNeo, GPTJ, OPT) and from a range of toy models and SoLU models trained by me (Neel Nanda).
Also supports loading from a checkpoint for checkpointed models (currently, models trained by me (NeelNanda) and
the stanford-crfm models). These can either be determined by the checkpoint index (the index of the checkpoint
in the checkpoint list) or by the checkpoint value (the value of the checkpoint, eg 1000 for a checkpoint taken
at step 1000 or after 1000 tokens; each model's checkpoints are labelled with exactly one of steps or tokens).
If neither is specified, the final model is loaded. If both are specified, the checkpoint index is used.
See load_and_process_state_dict for details on the processing (folding layer norm, centering the unembedding and
centering the writing weights)
Args:
model_name (str): The model name - must be an element of OFFICIAL_MODEL_NAMES or an alias of one.
fold_ln (bool, optional): Whether to fold in the LayerNorm weights to the
subsequent linear layer. This does not change the computation.
Defaults to True.
center_writing_weights (bool, optional): Whether to center weights
writing to
the residual stream (ie set mean to be zero). Due to LayerNorm
this doesn't change the computation. Defaults to True.
center_unembed (bool, optional): Whether to center W_U (ie set mean
to be zero).
Softmax is translation invariant so this doesn't affect log
probs or loss, but does change logits. Defaults to True.
refactor_factored_attn_matrices (bool, optional): Whether to convert the factored
matrices (W_Q & W_K, and W_O & W_V) to be "even". Defaults to False
checkpoint_index (int, optional): If loading from a checkpoint, the index of
the checkpoint to load. Defaults to None.
checkpoint_value (int, optional): If loading from a checkpoint, the value of
the checkpoint to load, ie the step or token number (each model
has checkpoints labelled with exactly one of these). Defaults to
None.
hf_model (AutoModelForCausalLM, optional): If you have already loaded in the
HuggingFace model, you can pass it in here rather than needing
to recreate the object. Defaults to None.
device (str, optional): The device to load the model onto. By
default will load to CUDA if available, else CPU.
n_devices (int, optional): The number of devices to split the model
across. Defaults to 1. If greater than 1, `device` must be cuda.
move_state_dict_to_device (bool): Whether to move the state dict to the
relevant device before processing and loading in the weights.
Defaults to True.
model_kwargs (dict, optional): Any additional kwargs to pass to the
HookedTransformer initialization.
"""
# Get the model name used in HuggingFace, rather than the alias.
official_model_name = loading.get_official_model_name(model_name)
# Load the config into an HookedTransformerConfig object. If loading from a
# checkpoint, the config object will contain the information about the
# checkpoint
cfg = loading.get_pretrained_model_config(
official_model_name,
checkpoint_index=checkpoint_index,
checkpoint_value=checkpoint_value,
fold_ln=fold_ln,
device=device,
n_devices=n_devices,
)
if cfg.positional_embedding_type == "shortformer":
if fold_ln:
logging.warning(
"You tried to specify fold_ln=True for a shortformer model, but this can't be done! Setting fold_"
"ln=False instead."
)
fold_ln = False
if center_unembed:
logging.warning(
"You tried to specify center_unembed=True for a shortformer model, but this can't be done! "
"Setting center_unembed=False instead."
)
center_unembed = False
if center_writing_weights:
logging.warning(
"You tried to specify center_writing_weights=True for a shortformer model, but this can't be done! "
"Setting center_writing_weights=False instead."
)
center_writing_weights = False
# Get the state dict of the model (ie a mapping of parameter names to tensors), processed to match the
# HookedTransformer parameter names.
state_dict = loading.get_pretrained_state_dict(
official_model_name, cfg, hf_model
)
# Create the HookedTransformer object
model = cls(cfg, **model_kwargs)
model.load_and_process_state_dict(
state_dict,
fold_ln=fold_ln,
center_writing_weights=center_writing_weights,
center_unembed=center_unembed,
refactor_factored_attn_matrices=refactor_factored_attn_matrices,
move_state_dict_to_device=move_state_dict_to_device,
)
print(f"Loaded pretrained model {model_name} into HookedTransformer")
return model
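# Example (illustrative): loading a pretrained model with the default processing flags
# (fold_ln, center_writing_weights and center_unembed all True).
#
#     model = HookedTransformer.from_pretrained("gpt2", device="cpu")
#     loss = model("Hello world", return_type="loss")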
@classmethod
def from_pretrained_no_processing(
cls,
model_name: str,
fold_ln=False,
center_writing_weights=False,
center_unembed=False,
refactor_factored_attn_matrices=False,
**from_pretrained_kwargs,
):
"""Wrapper for from_pretrained with all boolean flags related to simplifying the model set to False. Refer to
from_pretrained for details."""
return cls.from_pretrained(
model_name,
fold_ln=fold_ln,
center_writing_weights=center_writing_weights,
center_unembed=center_unembed,
refactor_factored_attn_matrices=refactor_factored_attn_matrices,
**from_pretrained_kwargs,
)
def init_weights(self):
"""
Initialize weights matrices with a normal of std=initializer_range (default=0.02). This roughly follows the
GPT-2 paper's scheme (but with truncation, and not halving the std for W_pos).
LayerNorm weights are already initialized to 1.0, and all biases are initialized to 0.0 (including LayerNorm),
so this just initializes weight matrices.
Weight matrices are set to empty by default (to save space + compute, since they're the bulk of the parameters),
so it is important to call this if you are not loading in pretrained weights! Note that this function assumes that weight names begin with W_.
Set seed here to ensure determinism.
This does NOT follow the PyTorch scheme, which as far as I can tell is super out of date but no one has gotten
round to updating it?
https://github.com/pytorch/pytorch/issues/18182
PyTorch Transformers are especially bad - TransformerEncoder initializes all layers to the exact same weights?!
https://github.com/pytorch/pytorch/issues/72253
The best paper I've found on transformer initialization is the muP paper, but I haven't integrated those ideas yet:
https://arxiv.org/abs/2203.03466
"""
if self.cfg.seed is not None:
torch.manual_seed(self.cfg.seed)
for name, param in self.named_parameters():
if "W_" in name:
nn.init.normal_(param, std=self.cfg.initializer_range)
def load_and_process_state_dict(
self,
state_dict: Dict[str, torch.Tensor],
fold_ln: bool = True,
center_writing_weights: bool = True,
center_unembed: bool = True,
fold_value_biases: bool = True,
refactor_factored_attn_matrices: bool = False,
move_state_dict_to_device: bool = True,
):
"""Method to load a state dict into the model, and to apply processing to simplify it. The state dict is assumed
to be in the HookedTransformer format.
See the relevant method (same name as the flag) for more details on the folding, centering and processing flags.
Args:
state_dict (dict): The state dict of the model, in HookedTransformer format
fold_ln (bool, optional): Whether to fold in the LayerNorm weights to the
subsequent linear layer. This does not change the computation. Defaults to True.
center_writing_weights (bool, optional): Whether to center weights writing to the
residual stream (ie set mean to be zero). Due to LayerNorm this doesn't change the computation.
Defaults to True.
center_unembed (bool, optional): Whether to center W_U (ie set mean to be zero).
Softmax is translation invariant so this doesn't affect log probs or loss, but does change logits.
Defaults to True.
fold_value_biases (bool, optional): Whether to fold the value biases into the output bias.
Because attention patterns add up to 1, the value biases always have a constant effect on a layer's
output, and it doesn't matter which head a bias is associated with. We can factor this all into a single
output bias to the layer, and make it easier to interpret the head's output.
refactor_factored_attn_matrices (bool, optional): Whether to convert the factored
matrices (W_Q & W_K, and W_O & W_V) to be "even". Defaults to False
move_state_dict_to_device (bool, optional): Whether to move the state dict to the device of the model.
Defaults to True.
model_name (str, optional): checks the model name for special cases of state dict loading. Only used for
Redwood 2L model currently
"""
assert (
self.cfg.n_devices == 1 or move_state_dict_to_device
), "If n_devices > 1, move_state_dict_to_device must be True"
if move_state_dict_to_device:
for k, v in state_dict.items():
if k.startswith("embed") or k.startswith("pos_embed"):
state_dict[k] = v.to(
devices.get_device_for_block_index(0, self.cfg)
)
elif k.startswith("ln_final") or k.startswith("unembed"):
state_dict[k] = v.to(
devices.get_device_for_block_index(
self.cfg.n_layers - 1, self.cfg
)
)
elif k.startswith("blocks"):
state_dict[k] = v.to(
devices.get_device_for_block_index(
int(k.split(".")[1]), self.cfg
)
)
else:
raise KeyError(
f"State Dict contains a key not in the HookedTransformer format: {k}"
)
state_dict = self.fill_missing_keys(state_dict)
if fold_ln:
if self.cfg.normalization_type not in ["LN", "LNPre"]:
logging.warning(
"You are not using LayerNorm, so the layer norm weights can't be folded! Skipping"