ARENA Interpretability Hackathon project. Co-authored with @lucyfarnik, @Aaquib111, and @canrager. Supervised by @ArthurConmy. Awarded first prize.
Code liberally taken from ARENA Exercises and Neel Nanda's attribution patching demo.
- Install Poetry
- Install Python ^3.8
- Install dev dependencies:
  `poetry config virtualenvs.in-project true`
  `poetry install`
Recent advancements in automated circuit discovery have made it possible to find circuits responsible for particular behaviors without much human intervention. However, existing algorithms require a relatively large number of forward passes through the model, which means they are quite slow to run. We integrate these algorithms with the first-order Taylor approximations proposed by Neel Nanda, arriving at a much faster algorithm. Our final algorithm takes a model, clean dataset, corrupted dataset, and a metric evaluating the model's behavior on the given task, and returns a pruned model and a list of heads in the circuit implementing the task. Most importantly, the algorithm is faster than ACDC, discovering the IOI circuit in GPT2-small in 4.1 seconds compared to ACDC's 8 minutes.
Develop an algorithm which speeds up automated circuit discovery using attribution patching.
Automatic Circuit DisCovery (ACDC) is a technique which aims to automatically identify important units of a neural network in order to speed up the Mechanistic Interpretability workflow for finding circuits. The technique has had success finding many circuits, including those for IOI, induction, docstring, greater-than, and tracr tasks.
However, the ACDC technique is currently too computationally intensive to apply to models the size of GPT2-XL or larger. This is because multiple forward passes are required for each edge in the computational graph.
Our aim was to create a modified algorithm which speeds up ACDC using the extremely fast Attribution Patching technique.
- Accurately find circuits which have been found manually
- Work on arbitrary tasks with arbitrarily large models
- Require a constant number of forward/backward passes of the model
- Attribution patching is a technique which uses gradients to take a linear approximation to activation patching.
- This reduces the number of passes from linear in the number of activations patched to constant
- The approximation is more valid when patching small activations than large ones.
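As a concrete illustration, here is a minimal, self-contained sketch of the approximation on a toy metric (the function and numbers are illustrative, not from our codebase): one backward pass on the clean input gives a gradient, and its dot product with the activation difference estimates the effect of patching without any extra forward passes.

```python
# Toy illustration of attribution patching as a first-order Taylor expansion.
# The "metric" stands in for running the rest of the model and computing,
# e.g., a logit difference; all values here are illustrative.
import torch

def metric(act):
    return (act ** 2).sum()

clean_act = torch.tensor([1.0, 2.0], requires_grad=True)
corrupted_act = torch.tensor([1.1, 1.9])

# One backward pass on the clean run gives d(metric)/d(activation).
clean_metric = metric(clean_act)
clean_metric.backward()

# Attribution patching: linear estimate of the effect of patching in the
# corrupted activation, using only the clean value and the clean gradient.
estimate = ((corrupted_act - clean_act.detach()) * clean_act.grad).sum()

# Ground truth: actually patching costs one extra forward pass per activation.
exact = metric(corrupted_act) - clean_metric.detach()
print(f"estimate={estimate.item():.3f}, exact={exact.item():.3f}")
```

Here the estimate (-0.200) is close to the true patched effect (-0.180); the gap is the second-order term, which is why the approximation degrades as the patched activation difference grows.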
(source: ACDC paper, Conmy et al. 2023)
- ACDC is an algorithm which uses iterated pruning to find important subgraphs of a computational graph representing a neural network.
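For intuition about why this is expensive, here is a toy, self-contained sketch of the iterated-pruning loop (our paraphrase of the ACDC idea, not the reference implementation; the edge names and metric are made up). Every candidate edge costs at least one extra evaluation of the metric, i.e. extra forward passes through the model.

```python
# Toy sketch of ACDC-style iterated pruning over a set of edges.
# A real run evaluates the model's metric with the edge corrupted/removed;
# here the "metric" is a made-up function of which edges remain.
def acdc_prune(edges, metric, threshold):
    """Greedily remove edges whose removal changes the metric by < threshold."""
    kept = set(edges)
    for edge in sorted(edges):  # real ACDC sweeps from outputs back to inputs
        baseline = metric(kept)
        candidate = kept - {edge}
        if abs(baseline - metric(candidate)) < threshold:
            kept = candidate  # unimportant edge: prune it
    return kept

# Toy "circuit": the metric only depends on edges A and C.
def toy_metric(active_edges):
    return 1.0 * ("A" in active_edges) + 0.5 * ("C" in active_edges)

print(acdc_prune({"A", "B", "C", "D"}, toy_metric, threshold=0.1))
# {'A', 'C'} (set print order may vary)
```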
- Replicated detection of S-inhibition heads from the IOI paper, using attribution patching instead of activation patching.
- Designed and implemented an algorithm that identifies the IOI subgraph at least 100x faster than ACDC
We showed that attribution patching was able to identify the important heads in the IOI circuit by replicating Figure 5(b) from the IOI paper.
Figure: direct effect on S-inhibition heads' values, original (activation patching) vs. attribution patching.
We designed and implemented an algorithm which identified the important heads in the IOI circuit in 4.1 seconds on an A10 GPU, compared to 8 minutes for the ACDC algorithm running on an A100.
Nodes in the IOI circuit in GPT2-small that our algorithm finds for a given threshold value
The original IOI paper identifies 26 heads relevant to the task. Exploring a few of these thresholds:
- For threshold 0.2, our method identifies 33 heads, of which 18 are part of the circuit in the original IOI paper. Our method did not pick up on 8 heads, of which only 3 were not Backup Name Mover heads (BNMs). We believe this distinction is relevant because BNMs do not play a large role in the model's computation unless the regular Name Mover heads are ablated, so we would not expect our algorithm to detect them.
- For threshold 0.3, our method identifies 21 heads, of which 16 are also identified by the IOI paper. Our method did not pick up on 10 heads, of which only 5 were not Backup Name Mover heads.
- For threshold 0.4, our method identifies 16 heads, of which 14 are also identified by the original IOI paper. Our method did not pick up on 12 heads, of which 5 were not Backup Name Mover heads. 1 of the 5 heads was a Negative Name Mover head.
- For threshold 0.5, our method identifies 15 heads, of which all 15 are also identified by the IOI paper. Our method did not pick up on the remaining 15 heads, of which 10 are not Backup Name Mover heads. 1 of the 10 heads is a Negative Name Mover head.
We seem to miss Backup Name Mover heads but capture the Negative Name Mover heads; this is to be expected, as backup name movers do not contribute significantly unless parts of the model are ablated, while negative heads always contribute to the outcome. Layer 6, head 0 is also consistently falsely identified as part of the circuit; a case study of why would be worth exploring.
- Two forward passes: caching activations for the clean and corrupted prompts, yielding `clean_cache` and `corrupted_cache`
- One backward pass: caching gradients of the loss metric on the clean prompt w.r.t. head activations
- Compute the importance of each node in the computational graph using attribution patching: `(clean_head_activations - corr_head_activations) * clean_head_grad_in`
- Do ACDC-style thresholding based on the metric (e.g. logit difference for the IOI task)
- Prune nodes by filling their `W_Q` with zeros. We believe this is vaguely analogous to mean-pruning, since it effectively turns the QK circuit into an averaging circuit without taking the activations too far out of distribution (the OV circuit stays intact); we still need to verify this in future work by comparing it to resample-pruning. A sketch of these steps follows Algorithm 1 below.
Algorithm 1: Node-based automated circuit discovery using attribution patching. Data: model, clean and corrupted datasets, and a metric evaluating the model's behavior on the task.
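Below is a hedged sketch of these steps using TransformerLens, assuming GPT2-small and a single IOI-style prompt pair; the prompts, metric, threshold, and hook choices are illustrative simplifications of our actual implementation (linked at the end of this page).

```python
# Hedged sketch of node attribution patching + pruning with TransformerLens.
# Prompts, metric, and threshold are illustrative, not our exact setup.
import torch
from transformer_lens import HookedTransformer
import transformer_lens.utils as utils

model = HookedTransformer.from_pretrained("gpt2")

clean_tokens = model.to_tokens("When John and Mary went to the store, John gave a drink to")
corr_tokens = model.to_tokens("When John and Mary went to the store, Mary gave a drink to")

def logit_diff(logits):
    # Illustrative IOI metric: logit of the indirect object minus the subject.
    mary, john = model.to_single_token(" Mary"), model.to_single_token(" John")
    return logits[0, -1, mary] - logits[0, -1, john]

# Two forward passes: cache clean and corrupted activations.
_, clean_cache = model.run_with_cache(clean_tokens)
_, corr_cache = model.run_with_cache(corr_tokens)

# One backward pass: cache gradients of the metric w.r.t. each head's output.
grad_cache = {}
def save_grad(grad, hook):
    grad_cache[hook.name] = grad.detach()

model.reset_hooks()
for layer in range(model.cfg.n_layers):
    model.add_hook(utils.get_act_name("z", layer), save_grad, dir="bwd")
logit_diff(model(clean_tokens)).backward()
model.reset_hooks()

# Importance of each head: (clean - corrupted activation) . clean gradient.
attr = torch.zeros(model.cfg.n_layers, model.cfg.n_heads)
for layer in range(model.cfg.n_layers):
    name = utils.get_act_name("z", layer)
    delta = clean_cache[name] - corr_cache[name]  # [batch, pos, head, d_head]
    attr[layer] = (delta * grad_cache[name]).sum(dim=(0, 1, 3)).detach()

# ACDC-style thresholding, then prune excluded heads by zeroing their W_Q.
threshold = 0.4  # illustrative value; see the threshold sweep above
for layer in range(model.cfg.n_layers):
    for head in range(model.cfg.n_heads):
        if attr[layer, head].abs() < threshold:
            model.blocks[layer].attn.W_Q.data[head].zero_()
```

Note that the attribution for every head comes from exactly two forward passes and one backward pass, regardless of model size; this is where the speedup over per-node activation patching comes from.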
For a specific pair of sender and receiver heads, we approximate their attribution to the metric as follows:
- Caching activations and gradients:
  - Two forward passes: caching activations for the clean and corrupted prompts, yielding `clean_cache` and `corrupted_cache`
  - One backward pass: caching gradients of the loss metric w.r.t. activations for the clean prompt, yielding `clean_grad_cache`
- Approximate the effect of an edge on the logit difference using the formula `((clean_early_node_out - corrupted_early_node_out) * corrupted_late_node_in_grad).sum()`, where `corrupted_late_node_in_grad` is the derivative of the metric with respect to the residual stream input immediately before the layer norm preceding the receiver node (see the sketch below).
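Below is a hedged sketch of this edge-level estimate, again using TransformerLens with an illustrative prompt pair, metric, and edge. Per the caching step above, the gradient comes from the single backward pass on the clean prompt, read at the residual stream entering the receiver's layer.

```python
# Hedged sketch of the edge attribution estimate for one (sender, receiver)
# pair. The edge, prompts, and metric are illustrative.
import torch
from transformer_lens import HookedTransformer
import transformer_lens.utils as utils

model = HookedTransformer.from_pretrained("gpt2")
model.set_use_attn_result(True)  # expose per-head outputs via hook_result

clean_tokens = model.to_tokens("When John and Mary went to the store, John gave a drink to")
corr_tokens = model.to_tokens("When John and Mary went to the store, Mary gave a drink to")

def logit_diff(logits):
    mary, john = model.to_single_token(" Mary"), model.to_single_token(" John")
    return logits[0, -1, mary] - logits[0, -1, john]

sender_layer, sender_head, receiver_layer = 3, 0, 9  # illustrative edge

# Per-head sender outputs: what the sender writes into the residual stream.
_, clean_cache = model.run_with_cache(clean_tokens)
_, corr_cache = model.run_with_cache(corr_tokens)
sender_name = utils.get_act_name("result", sender_layer)
clean_out = clean_cache[sender_name][:, :, sender_head]  # [batch, pos, d_model]
corr_out = corr_cache[sender_name][:, :, sender_head]

# Gradient of the metric w.r.t. the residual stream entering the receiver's
# layer (immediately before its LayerNorm), from one clean backward pass.
grads = {}
def save_grad(grad, hook):
    grads[hook.name] = grad.detach()

receiver_name = utils.get_act_name("resid_pre", receiver_layer)
model.reset_hooks()
model.add_hook(receiver_name, save_grad, dir="bwd")
logit_diff(model(clean_tokens)).backward()
model.reset_hooks()

# Edge attribution, mirroring the formula above.
edge_attr = ((clean_out - corr_out) * grads[receiver_name]).sum()
print(edge_attr.item())
```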
- Investigate the validity of the approximation when freezing vs. not freezing LayerNorm in the receiver node while computing the gradient w.r.t. the receiver node's input
- Test both node attribution patching and path attribution patching for identifying induction heads in a small transformer (currently debugging an error regarding the loss metric)
Code available at: https://github.com/rusheb/arena-hackathon-attribution-patching