TransformerLens: Hooks & Interventions
Hooks let you read and modify activations during forward passes. This is the foundation for causal interventions.
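All of the examples below assume a loaded model and functools.partial. A minimal setup sketch (the model name is an arbitrary choice):

from functools import partial

from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")  # any model TransformerLens supports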
What Are Hooks?
Hooks are functions that run at specific points during the forward pass:
def my_hook(activation, hook):
    """
    activation: the tensor at this point
    hook: metadata about where we are
    """
    print(f"At {hook.name}: shape {activation.shape}")
    return activation  # Must return the (possibly modified) activation

# Run with hook
model.run_with_hooks(
    "Hello world",
    fwd_hooks=[("blocks.0.attn.hook_pattern", my_hook)]
)
Hook Points
Every interesting activation has a hook point:
# List all hook points
for name, hook in model.hook_dict.items():
    print(name)
# Output:
# hook_embed
# hook_pos_embed
# blocks.0.hook_resid_pre
# blocks.0.attn.hook_q
# blocks.0.attn.hook_k
# blocks.0.attn.hook_v
# blocks.0.attn.hook_pattern
# blocks.0.attn.hook_z
# blocks.0.hook_attn_out
# blocks.0.hook_resid_mid
# blocks.0.mlp.hook_pre
# blocks.0.mlp.hook_post
# blocks.0.hook_resid_post
# ... (repeated for each layer)
Reading Activations with Hooks
stored_activations = {}

def store_hook(activation, hook):
    stored_activations[hook.name] = activation.clone()
    return activation

# Store specific activations
model.run_with_hooks(
    "The cat sat",
    fwd_hooks=[
        ("blocks.5.hook_resid_post", store_hook),
        ("blocks.5.attn.hook_pattern", store_hook),
    ]
)

print(stored_activations["blocks.5.hook_resid_post"].shape)
Modifying Activations
Zero ablation: Remove a component's contribution
def zero_head_hook(activation, hook, head_idx):
    """Zero out a specific attention head."""
    # activation shape: (batch, seq, n_heads, d_head)
    activation[:, :, head_idx, :] = 0
    return activation

# Ablate head 5 in layer 3
model.run_with_hooks(
    "The cat sat",
    fwd_hooks=[
        ("blocks.3.attn.hook_z", partial(zero_head_hook, head_idx=5))
    ]
)
Activation Patching
Copy activations from one run to another:
def patch_hook(activation, hook, patch_cache, pos=None):
    """Replace the activation with the cached version (optionally at one position)."""
    patch_value = patch_cache[hook.name]
    if pos is not None:
        activation[:, pos] = patch_value[:, pos]
    else:
        activation[:] = patch_value
    return activation

# Get clean activations
_, clean_cache = model.run_with_cache("The capital of France is")

# Patch the clean activations into the corrupted run
patched_logits = model.run_with_hooks(
    "The capital of Germany is",  # Wrong country
    fwd_hooks=[
        ("blocks.8.hook_resid_post",
         partial(patch_hook, patch_cache=clean_cache))
    ]
)
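To check that the patch actually mattered, compare the logit for the clean answer with and without patching. A minimal sketch, assuming " Paris" is a single token for your tokenizer:

# Baseline: corrupted run with no patching
corrupted_logits = model("The capital of Germany is")
paris = model.to_single_token(" Paris")
print("corrupted:", corrupted_logits[0, -1, paris].item())
print("patched:  ", patched_logits[0, -1, paris].item())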
Causal Tracing
Find which components are important:
def causal_trace(model, clean_prompt, corrupted_prompt, target_token):
    """
    1. Run clean prompt, save activations
    2. Run corrupted prompt
    3. For each layer: patch clean → corrupted, measure recovery
    """
    clean_logits, clean_cache = model.run_with_cache(clean_prompt)
    corrupted_logits = model(corrupted_prompt)

    results = []
    for layer in range(model.cfg.n_layers):
        patched_logits = model.run_with_hooks(
            corrupted_prompt,
            fwd_hooks=[(
                f"blocks.{layer}.hook_resid_post",
                partial(patch_hook, patch_cache=clean_cache)
            )]
        )
        # Measure how much patching this layer recovers the correct prediction
        recovery = compute_recovery(
            corrupted_logits, patched_logits, clean_logits, target_token
        )
        results.append(recovery)
    return results
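compute_recovery is not defined above. A minimal sketch, assuming single-prompt batches and a target_token id, that reports what fraction of the clean-vs-corrupted logit gap the patch restores:

def compute_recovery(corrupted_logits, patched_logits, clean_logits, target_token):
    """0 = patch did nothing, 1 = fully restored the clean prediction."""
    corrupted = corrupted_logits[0, -1, target_token]
    patched = patched_logits[0, -1, target_token]
    clean = clean_logits[0, -1, target_token]
    return ((patched - corrupted) / (clean - corrupted)).item()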
Hook Filters
Instead of naming hook points one by one, you can pass a filter function that selects hook points by name:
# Using a hook filter
def hook_filter(name):
    return "attn" in name and "pattern" in name

model.run_with_hooks(
    "Hello",
    fwd_hooks=[(hook_filter, my_hook)]  # Runs my_hook at every attention-pattern hook point
)
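Hook-point names can also be built with transformer_lens.utils.get_act_name, which is less error-prone than hand-written f-strings. The expected outputs below are my understanding of the naming scheme; verify against your installed version:

from transformer_lens import utils

utils.get_act_name("pattern", 0)     # "blocks.0.attn.hook_pattern"
utils.get_act_name("resid_post", 5)  # "blocks.5.hook_resid_post"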
Common Patterns
Measuring component importance:
def measure_importance(model, prompt, layer, head):
    """How much does this head matter for this prompt?"""
    clean_loss = model(prompt, return_type="loss")
    ablated_loss = model.run_with_hooks(
        prompt,
        fwd_hooks=[(
            f"blocks.{layer}.attn.hook_z",
            partial(zero_head_hook, head_idx=head)
        )],
        return_type="loss"
    )
    return ablated_loss - clean_loss  # Positive = head was helpful
Direct logit attribution:
def get_head_direct_effect(cache, layer, head):
    """What does this head contribute directly to the output logits?
    (Ignores the final LayerNorm for simplicity.)"""
    head_output = cache["z", layer][:, :, head, :]        # (batch, seq, d_head)
    head_contrib = head_output @ model.W_O[layer, head]   # (batch, seq, d_model)
    logit_effect = head_contrib @ model.W_U                # (batch, seq, d_vocab)
    return logit_effect
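A quick usage sketch (the prompt, layer, and head are arbitrary choices; the cache comes from run_with_cache as above):

logits, cache = model.run_with_cache("The capital of France is")
effect = get_head_direct_effect(cache, layer=8, head=5)
paris = model.to_single_token(" Paris")
print(effect[0, -1, paris].item())  # this head's direct push toward " Paris" at the final position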
Capstone Connection
Hooks for sycophancy analysis:
def measure_sycophancy_contribution(model, prompt_pairs, threshold=1.0):
    """
    For each (honest_prompt, sycophantic_prompt) pair:
    1. Run both, cache activations
    2. Find heads where the activation difference predicts sycophancy
    3. Ablate those heads and measure whether the model becomes more honest
    """
    for honest, sycophantic in prompt_pairs:
        # Assumes each pair tokenizes to the same sequence length
        _, honest_cache = model.run_with_cache(honest)
        _, syco_cache = model.run_with_cache(sycophantic)

        for layer in range(model.cfg.n_layers):
            for head in range(model.cfg.n_heads):
                diff = (
                    honest_cache["z", layer][:, :, head] -
                    syco_cache["z", layer][:, :, head]
                ).norm()
                if diff > threshold:  # threshold is a tunable cutoff
                    # This head behaves differently! Investigate further
                    pass
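Step 3 of the docstring is left open above. One way to close the loop, reusing zero_head_hook and assuming step 2 produced a list of candidate (layer, head) pairs, is a sketch like:

def ablate_candidate_heads(model, sycophantic_prompt, candidate_heads):
    """Loss change on the sycophantic prompt when the candidate heads are zeroed out."""
    hooks = [
        (f"blocks.{layer}.attn.hook_z", partial(zero_head_hook, head_idx=head))
        for layer, head in candidate_heads
    ]
    baseline_loss = model(sycophantic_prompt, return_type="loss")
    ablated_loss = model.run_with_hooks(
        sycophantic_prompt, fwd_hooks=hooks, return_type="loss"
    )
    # Positive = removing these heads makes the sycophantic text less likely
    return (ablated_loss - baseline_loss).item()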
🎓 Tyla's Exercise
Explain why we must return activation in a hook function. What happens if we don't?
Design an experiment to test whether a specific attention head is necessary for a specific behavior. What's your null hypothesis?
Activation patching replaces activations from run A into run B. When might this NOT tell us about causation?
💻 Aaliyah's Exercise
Implement causal interventions:
def ablation_experiment(model, prompt, target_token):
    """
    1. Get baseline loss on predicting target_token
    2. For each head, ablate it and measure new loss
    3. Return ranking of heads by importance
    4. Visualize: which layers matter most?
    """
    pass

def activation_patching(model, clean_prompt, corrupted_prompt):
    """
    1. Clean prompt: "The Eiffel Tower is in Paris"
    2. Corrupted: "The Eiffel Tower is in Rome"
    3. Patch each layer's residual stream from clean → corrupted
    4. Find which layer most recovers the "Paris" prediction
    """
    pass
📚 Maneesha's Reflection
Hooks give us the ability to intervene on a running computation. What ethical considerations arise from this capability?
"Correlation is not causation" applies here too. How do interventions help us move from correlation to causation?
If you could add one new hook point to TransformerLens, what would it be and why?