forked from TransformerLensOrg/TransformerLens
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathpatching.py
692 lines (563 loc) · 29.8 KB
/
patching.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
# %%
"""
A module for patching activations in a transformer model, and measuring the effect of the patch on the output.
This implements the activation patching technique for a range of types of activation.
The structure is to have a single generic_activation_patch function that does everything, and to have a range of specialised functions for specific types of activation.
See this explanation for more https://dynalist.io/d/n2ZWtnoYHrU1s4vnFSAQ519J#z=qeWBvs-R-taFfcCq-S_hgMqx
And check out the Activation Patching in TransformerLens Demo notebook for a demo of how to use this module.
"""
from __future__ import annotations
import itertools
from functools import partial
from typing import Callable, Optional, Sequence, Tuple, Union
import einops
import pandas as pd
import torch
from jaxtyping import Float, Int
from tqdm.auto import tqdm
from typing_extensions import Literal
import transformer_lens.utils as utils
from transformer_lens import ActivationCache, HookedTransformer
# %%
# Alias used when annotating the logits returned by a model forward pass.
Logits = torch.Tensor

# Axis names accepted by generic_activation_patch's `index_axis_names`.
AxisNames = Literal[
    "layer",
    "pos",
    "head_index",
    "head",
    "src_pos",
    "dest_pos",
]
# %%
from typing import Sequence
def make_df_from_ranges(
    column_max_ranges: Sequence[int], column_names: Sequence[str]
) -> pd.DataFrame:
    """
    Enumerate every index combination for a set of axes as a DataFrame.

    Rows are the cartesian product of ``range(m)`` for each ``m`` in
    ``column_max_ranges``, in lexicographic order (the final column varies
    fastest). Columns are labelled by the matching entries of ``column_names``.

    Args:
        column_max_ranges: Exclusive upper bound for each column.
        column_names: Label for each column, aligned with ``column_max_ranges``.

    Returns:
        A DataFrame with one row per index combination.
    """
    axis_ranges = (range(upper) for upper in column_max_ranges)
    return pd.DataFrame(list(itertools.product(*axis_ranges)), columns=column_names)
# %%
# Aliases naming the role of the tensor a patch setter receives (the activation
# from the corrupted run) and the tensor it returns (that activation after patching).
CorruptedActivation = PatchedActivation = torch.Tensor
def generic_activation_patch(
    model: HookedTransformer,
    corrupted_tokens: Int[torch.Tensor, "batch pos"],
    clean_cache: ActivationCache,
    patching_metric: Callable[
        [Float[torch.Tensor, "batch pos d_vocab"]], Float[torch.Tensor, ""]
    ],
    patch_setter: Callable[
        [CorruptedActivation, Sequence[int], torch.Tensor], PatchedActivation
    ],
    activation_name: str,
    index_axis_names: Optional[Sequence[AxisNames]] = None,
    index_df: Optional[pd.DataFrame] = None,
    return_index_df: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, pd.DataFrame]]:
    """
    A generic function to do activation patching, specialised to specific use cases elsewhere.

    Activation patching studies the counterfactual effect of a specific activation between a
    clean run and a corrupted run. The idea is to have two inputs, clean and corrupted, which
    produce different outputs and differ in some key detail, e.g. "The Eiffel Tower is in" vs
    "The Colosseum is in". The activations of the clean run are cached. The model is then run on
    the corrupted input while one chunk of one activation is overwritten ("patched") with the
    corresponding clean values, and a metric measures how much the output recovers.

    Internally, the key ingredients are:
    - a table of index tuples (e.g. (layer, position, head_index));
    - the activation name, which is combined with the layer (the first element of each index
      tuple) via utils.get_act_name to pick the hook point to patch;
    - a patch_setter, which takes the corrupted activation, the index tuple and the clean
      activation at that hook point, and returns the patched activation;
    - a metric for how well the patched model has recovered.

    The indices can either be given explicitly as a pandas dataframe, or inferred by listing the
    relevant axis names, with sizes taken from the tokens and the model config. The first column
    is always assumed to be the layer.

    This function iterates over every tuple of indices, does the relevant patch, and stores the
    resulting metric.

    Args:
        model: The relevant model. Its cfg supplies n_layers, n_heads and device, and
            run_with_hooks is used for the patched forward passes.
        corrupted_tokens: The input tokens for the corrupted run. The size of the last axis
            is used as the range of every position axis.
        clean_cache: The cached activations from the clean run, indexed by hook name.
        patching_metric: A function from the model's output logits to some metric (e.g. loss
            or logit diff). Its result must be a single-element tensor, because .item() is
            called on it.
        patch_setter: A function acting on (corrupted_activation, index, clean_activation).
            It edits the activation to patch in the relevant chunk of the clean activation.
            clean_activation is the tensor clean_cache[hook_name] for the hook being patched,
            not the whole cache.
        activation_name: The name of the activation being patched, e.g. "resid_pre" or "z".
        index_axis_names: The names of the axes to (fully) iterate over. Used to build
            index_df. Must be None if index_df is given.
        index_df: The dataframe of indices. Columns are axis names and each row is a tuple
            of indices. If given, the output is a flattened tensor with one element per row
            of index_df.
        return_index_df: Whether to also return the dataframe of indices.

    Returns:
        patched_output: The tensor of the patching metric for each patch. By default it has
            one dimension per iterated axis. If index_df was given explicitly, it is flattened
            to one element per row.
        index_df *optional*: The dataframe of indices (only when return_index_df is True).

    Raises:
        AssertionError: If neither, or both, of index_axis_names and index_df are given.
    """
    if index_df is None:
        assert (
            index_axis_names is not None
        ), "Either index_axis_names or index_df must be provided"
        # Size of every axis this function knows how to iterate over.
        n_pos = corrupted_tokens.shape[-1]
        n_heads = model.cfg.n_heads
        max_axis_range = {
            "layer": model.cfg.n_layers,
            "pos": n_pos,
            "src_pos": n_pos,
            "dest_pos": n_pos,
            "head_index": n_heads,
            "head": n_heads,
        }
        index_axis_max_range = [
            max_axis_range[axis_name] for axis_name in index_axis_names
        ]
        # Each row of index_df is one tuple of indices to patch.
        index_df = make_df_from_ranges(index_axis_max_range, index_axis_names)
        flattened_output = False
        # One output dimension per iterated axis.
        patched_metric_output = torch.zeros(
            index_axis_max_range, device=model.cfg.device
        )
    else:
        # An explicit index dataframe was provided, so index_axis_names must not also be set.
        assert (
            index_axis_names is None
        ), "Pass either index_axis_names or index_df, not both"
        flattened_output = True
        # One output element per row of the user-provided index dataframe.
        patched_metric_output = torch.zeros(len(index_df), device=model.cfg.device)

    def patching_hook(corrupted_activation, hook, index, clean_activation):
        # `hook` is supplied by the hook machinery and is unused here.
        return patch_setter(corrupted_activation, index, clean_activation)

    # Iterate over every tuple of indices and make the appropriate patch.
    for c, (_, index_row) in enumerate(
        tqdm(index_df.iterrows(), total=len(index_df))
    ):
        index = index_row.to_list()
        # The hook name is the activation name plus the layer (assumed to be the first index).
        current_activation_name = utils.get_act_name(activation_name, layer=index[0])
        # Hook functions cannot take extra inputs, so bind the index and clean activation.
        current_hook = partial(
            patching_hook,
            index=index,
            clean_activation=clean_cache[current_activation_name],
        )
        patched_logits = model.run_with_hooks(
            corrupted_tokens, fwd_hooks=[(current_activation_name, current_hook)]
        )
        metric_value = patching_metric(patched_logits).item()
        if flattened_output:
            patched_metric_output[c] = metric_value
        else:
            patched_metric_output[tuple(index)] = metric_value

    if return_index_df:
        return patched_metric_output, index_df
    return patched_metric_output
# %%
# Defining patch setters for various shapes of activations
def layer_pos_patch_setter(corrupted_activation, index, clean_activation):
    """
    Patch a single sequence position, for an index of the form [layer, pos].

    The layer entry is not used by this function. It assumes the activation's
    axes are ordered [batch, pos, ...], which holds for every activation except
    attention-pattern-shaped tensors. The corrupted activation is overwritten
    in place and then returned.
    """
    assert len(index) == 2
    _layer, pos = index
    # Every batch element at this position, across all trailing axes.
    selector = (slice(None), pos, Ellipsis)
    corrupted_activation[selector] = clean_activation[selector]
    return corrupted_activation
def layer_pos_head_vector_patch_setter(
    corrupted_activation,
    index,
    clean_activation,
):
    """
    Patch one head at one position, for an index of the form [layer, pos, head_index].

    The layer entry is not used by this function. It assumes the activation's
    axes are ordered [batch, pos, head_index, ...], as for the per-head vector
    activations (q, k, v, z, result), but not for attention patterns. The
    corrupted activation is overwritten in place and then returned.
    """
    assert len(index) == 3
    _layer, pos, head_index = index
    selector = (slice(None), pos, head_index)
    corrupted_activation[selector] = clean_activation[selector]
    return corrupted_activation
def layer_head_vector_patch_setter(
    corrupted_activation,
    index,
    clean_activation,
):
    """
    Patch one head at every position, for an index of the form [layer, head_index].

    The layer entry is not used by this function. It assumes the activation's
    axes are ordered [batch, pos, head_index, ...], as for the per-head vector
    activations (q, k, v, z, result), but not for attention patterns. The
    corrupted activation is overwritten in place and then returned.
    """
    assert len(index) == 2
    _layer, head_index = index
    # All batch elements and all positions for the chosen head.
    selector = (slice(None), slice(None), head_index)
    corrupted_activation[selector] = clean_activation[selector]
    return corrupted_activation
def layer_head_pattern_patch_setter(
    corrupted_activation,
    index,
    clean_activation,
):
    """
    Patch a head's whole pattern, for an index of the form [layer, head_index].

    The layer entry is not used by this function. It assumes the activation's
    axes are ordered [batch, head_index, dest_pos, src_pos], as for attention
    scores and patterns. The corrupted activation is overwritten in place and
    then returned.
    """
    assert len(index) == 2
    _layer, head_index = index
    # Four explicit axes, so a non-4D tensor is rejected exactly as before.
    selector = (slice(None), head_index, slice(None), slice(None))
    corrupted_activation[selector] = clean_activation[selector]
    return corrupted_activation
def layer_head_pos_pattern_patch_setter(
    corrupted_activation,
    index,
    clean_activation,
):
    """
    Patch one destination row of a head's pattern, for an index [layer, head_index, dest_pos].

    The layer entry is not used by this function. It assumes the activation's
    axes are ordered [batch, head_index, dest_pos, src_pos], as for attention
    scores and patterns. All source positions of the row are patched. The
    corrupted activation is overwritten in place and then returned.
    """
    assert len(index) == 3
    _layer, head_index, dest_pos = index
    selector = (slice(None), head_index, dest_pos, slice(None))
    corrupted_activation[selector] = clean_activation[selector]
    return corrupted_activation
def layer_head_dest_src_pos_pattern_patch_setter(
    corrupted_activation,
    index,
    clean_activation,
):
    """
    Patch a single pattern entry, for an index [layer, head_index, dest_pos, src_pos].

    The layer entry is not used by this function. It assumes the activation's
    axes are ordered [batch, head_index, dest_pos, src_pos], as for attention
    scores and patterns. The corrupted activation is overwritten in place and
    then returned.
    """
    assert len(index) == 4
    _layer, head_index, dest_pos, src_pos = index
    selector = (slice(None), head_index, dest_pos, src_pos)
    corrupted_activation[selector] = clean_activation[selector]
    return corrupted_activation
# %%
# Defining activation patching functions for a range of common activation patches.
# Whole-layer activations, patched one (layer, position) at a time.
# Each entry is generic_activation_patch with the hook name, the iterated axes and
# the patch setter pre-bound. Callers supply (model, corrupted_tokens, clean_cache, patching_metric).

# Residual stream entering each block.
get_act_patch_resid_pre = partial(
    generic_activation_patch,
    activation_name="resid_pre",
    index_axis_names=("layer", "pos"),
    patch_setter=layer_pos_patch_setter,
)
get_act_patch_resid_pre.__doc__ = """
    Function to get activation patching results for the residual stream (at the start of each block) (by position). Returns a tensor of shape [n_layers, pos]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each resid_pre patch. Has shape [n_layers, pos]
    """

# Residual stream between the attention and MLP sublayers.
get_act_patch_resid_mid = partial(
    generic_activation_patch,
    activation_name="resid_mid",
    index_axis_names=("layer", "pos"),
    patch_setter=layer_pos_patch_setter,
)
get_act_patch_resid_mid.__doc__ = """
    Function to get activation patching results for the residual stream (between the attn and MLP layer of each block) (by position). Returns a tensor of shape [n_layers, pos]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, pos]
    """

# Output of each attention layer.
get_act_patch_attn_out = partial(
    generic_activation_patch,
    activation_name="attn_out",
    index_axis_names=("layer", "pos"),
    patch_setter=layer_pos_patch_setter,
)
get_act_patch_attn_out.__doc__ = """
    Function to get activation patching results for the output of each Attention layer (by position). Returns a tensor of shape [n_layers, pos]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, pos]
    """

# Output of each MLP layer.
get_act_patch_mlp_out = partial(
    generic_activation_patch,
    activation_name="mlp_out",
    index_axis_names=("layer", "pos"),
    patch_setter=layer_pos_patch_setter,
)
get_act_patch_mlp_out.__doc__ = """
    Function to get activation patching results for the output of each MLP layer (by position). Returns a tensor of shape [n_layers, pos]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, pos]
    """
# %%
# Per-head vector activations, patched one (layer, position, head) at a time.
# Results have shape [n_layers, pos, n_heads].

# Head outputs (the "z" hook).
get_act_patch_attn_head_out_by_pos = partial(
    generic_activation_patch,
    activation_name="z",
    index_axis_names=("layer", "pos", "head"),
    patch_setter=layer_pos_head_vector_patch_setter,
)
get_act_patch_attn_head_out_by_pos.__doc__ = """
    Function to get activation patching results for the output of each Attention Head (by position). Returns a tensor of shape [n_layers, pos, n_heads]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, pos, n_heads]
    """

# Queries.
get_act_patch_attn_head_q_by_pos = partial(
    generic_activation_patch,
    activation_name="q",
    index_axis_names=("layer", "pos", "head"),
    patch_setter=layer_pos_head_vector_patch_setter,
)
get_act_patch_attn_head_q_by_pos.__doc__ = """
    Function to get activation patching results for the queries of each Attention Head (by position). Returns a tensor of shape [n_layers, pos, n_heads]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, pos, n_heads]
    """

# Keys.
get_act_patch_attn_head_k_by_pos = partial(
    generic_activation_patch,
    activation_name="k",
    index_axis_names=("layer", "pos", "head"),
    patch_setter=layer_pos_head_vector_patch_setter,
)
get_act_patch_attn_head_k_by_pos.__doc__ = """
    Function to get activation patching results for the keys of each Attention Head (by position). Returns a tensor of shape [n_layers, pos, n_heads]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, pos, n_heads]
    """

# Values.
get_act_patch_attn_head_v_by_pos = partial(
    generic_activation_patch,
    activation_name="v",
    index_axis_names=("layer", "pos", "head"),
    patch_setter=layer_pos_head_vector_patch_setter,
)
get_act_patch_attn_head_v_by_pos.__doc__ = """
    Function to get activation patching results for the values of each Attention Head (by position). Returns a tensor of shape [n_layers, pos, n_heads]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, pos, n_heads]
    """
# %%
# Attention patterns, whose axes are [batch, head_index, dest_pos, src_pos].

# One destination row per (layer, head, dest_pos), over all source positions.
get_act_patch_attn_head_pattern_by_pos = partial(
    generic_activation_patch,
    activation_name="pattern",
    index_axis_names=("layer", "head_index", "dest_pos"),
    patch_setter=layer_head_pos_pattern_patch_setter,
)
get_act_patch_attn_head_pattern_by_pos.__doc__ = """
    Function to get activation patching results for the attention pattern of each Attention Head (by destination position). Returns a tensor of shape [n_layers, n_heads, dest_pos]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, n_heads, dest_pos]
    """

# A single (dest_pos, src_pos) entry per (layer, head).
get_act_patch_attn_head_pattern_dest_src_pos = partial(
    generic_activation_patch,
    activation_name="pattern",
    index_axis_names=("layer", "head_index", "dest_pos", "src_pos"),
    patch_setter=layer_head_dest_src_pos_pattern_patch_setter,
)
get_act_patch_attn_head_pattern_dest_src_pos.__doc__ = """
    Function to get activation patching results for each destination, source entry of the attention pattern for each Attention Head. Returns a tensor of shape [n_layers, n_heads, dest_pos, src_pos]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, n_heads, dest_pos, src_pos]
    """
# %%
# Per-head activations, patched one (layer, head) at a time across every position.
# Results have shape [n_layers, n_heads].

# Head outputs (the "z" hook).
get_act_patch_attn_head_out_all_pos = partial(
    generic_activation_patch,
    activation_name="z",
    index_axis_names=("layer", "head"),
    patch_setter=layer_head_vector_patch_setter,
)
get_act_patch_attn_head_out_all_pos.__doc__ = """
    Function to get activation patching results for the outputs of each Attention Head (across all positions). Returns a tensor of shape [n_layers, n_heads]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, n_heads]
    """

# Queries.
get_act_patch_attn_head_q_all_pos = partial(
    generic_activation_patch,
    activation_name="q",
    index_axis_names=("layer", "head"),
    patch_setter=layer_head_vector_patch_setter,
)
get_act_patch_attn_head_q_all_pos.__doc__ = """
    Function to get activation patching results for the queries of each Attention Head (across all positions). Returns a tensor of shape [n_layers, n_heads]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, n_heads]
    """

# Keys.
get_act_patch_attn_head_k_all_pos = partial(
    generic_activation_patch,
    activation_name="k",
    index_axis_names=("layer", "head"),
    patch_setter=layer_head_vector_patch_setter,
)
get_act_patch_attn_head_k_all_pos.__doc__ = """
    Function to get activation patching results for the keys of each Attention Head (across all positions). Returns a tensor of shape [n_layers, n_heads]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, n_heads]
    """

# Values.
get_act_patch_attn_head_v_all_pos = partial(
    generic_activation_patch,
    activation_name="v",
    index_axis_names=("layer", "head"),
    patch_setter=layer_head_vector_patch_setter,
)
get_act_patch_attn_head_v_all_pos.__doc__ = """
    Function to get activation patching results for the values of each Attention Head (across all positions). Returns a tensor of shape [n_layers, n_heads]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, n_heads]
    """

# Whole attention pattern of each head.
get_act_patch_attn_head_pattern_all_pos = partial(
    generic_activation_patch,
    activation_name="pattern",
    index_axis_names=("layer", "head_index"),
    patch_setter=layer_head_pattern_patch_setter,
)
get_act_patch_attn_head_pattern_all_pos.__doc__ = """
    Function to get activation patching results for the attention pattern of each Attention Head (across all positions). Returns a tensor of shape [n_layers, n_heads]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, n_heads]
    """
# %%
def get_act_patch_attn_head_all_pos_every(
    model, corrupted_tokens, clean_cache, metric
) -> Float[torch.Tensor, "patch_type layer head"]:
    """Run every per-head, all-position patch type and stack the results.

    The five patch types are head output (z), query, key, value and pattern, in
    that order along the leading axis. Each is patched per (layer, head) across
    all positions, giving a tensor of shape [5, n_layers, n_heads].

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [5, n_layers, n_heads]
    """
    # Evaluated in this order, which is also the order of the stacked leading axis.
    patch_fns = (
        get_act_patch_attn_head_out_all_pos,
        get_act_patch_attn_head_q_all_pos,
        get_act_patch_attn_head_k_all_pos,
        get_act_patch_attn_head_v_all_pos,
        get_act_patch_attn_head_pattern_all_pos,
    )
    per_type = [
        patch_fn(model, corrupted_tokens, clean_cache, metric) for patch_fn in patch_fns
    ]
    return torch.stack(per_type, dim=0)
def get_act_patch_attn_head_by_pos_every(
    model, corrupted_tokens, clean_cache, metric
) -> Float[torch.Tensor, "patch_type layer pos head"]:
    """Helper function to get activation patching results for every head (by position) for every act type (output, query, key, value, pattern). Wrapper around each's patching function, returns a stacked tensor of shape [5, n_layers, pos, n_heads]

    For the pattern entry, the pos axis is the *destination* position: that patch
    replaces a whole destination row of the pattern (all source positions).

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [5, n_layers, pos, n_heads]
    """
    act_patch_results = [
        get_act_patch_attn_head_out_by_pos(model, corrupted_tokens, clean_cache, metric),
        get_act_patch_attn_head_q_by_pos(model, corrupted_tokens, clean_cache, metric),
        get_act_patch_attn_head_k_by_pos(model, corrupted_tokens, clean_cache, metric),
        get_act_patch_attn_head_v_by_pos(model, corrupted_tokens, clean_cache, metric),
    ]
    # The pattern result is indexed [layer, head, dest_pos]. Move dest_pos before head
    # so it lines up with the [layer, pos, head] layout of the other four results.
    pattern_results = get_act_patch_attn_head_pattern_by_pos(
        model, corrupted_tokens, clean_cache, metric
    )
    act_patch_results.append(
        einops.rearrange(pattern_results, "layer head pos -> layer pos head")
    )
    return torch.stack(act_patch_results, dim=0)
def get_act_patch_block_every(
    model, corrupted_tokens, clean_cache, metric
) -> Float[torch.Tensor, "patch_type layer pos"]:
    """Run the three whole-layer, per-position patch types and stack the results.

    The patch types are, in order along the leading axis: the residual stream at
    the start of each block (resid_pre), the output of each attention layer, and
    the output of each MLP layer. The stacked tensor has shape [3, n_layers, pos].

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [3, n_layers, pos]
    """
    patch_fns = (
        get_act_patch_resid_pre,
        get_act_patch_attn_out,
        get_act_patch_mlp_out,
    )
    per_type = [
        patch_fn(model, corrupted_tokens, clean_cache, metric) for patch_fn in patch_fns
    ]
    return torch.stack(per_type, dim=0)