TransformerLens: Hooks & Interventions

Hooks let you read and modify activations during forward passes. This is the foundation for causal interventions.


What Are Hooks?

Hooks are functions that run at specific points during the forward pass:

def my_hook(activation, hook):
    """
    activation: the tensor at this point
    hook: metadata about where we are
    """
    print(f"At {hook.name}: shape {activation.shape}")
    return activation  # Must return (possibly modified) activation

# Run with hook
model.run_with_hooks(
    "Hello world",
    fwd_hooks=[("blocks.0.attn.hook_pattern", my_hook)]
)
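
Hooks passed to run_with_hooks are attached for that forward pass only and removed afterwards. For hooks that should persist across calls, TransformerLens also provides add_hook and reset_hooks; a minimal sketch:

# Persistent hooks stay attached until you remove them
model.add_hook("blocks.0.attn.hook_pattern", my_hook)
logits = model("Hello world")  # my_hook fires here too
model.reset_hooks()            # clean up, or later runs stay hooked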

Hook Points

Every interesting activation has a hook point:

# List all hook points
for name, hook in model.hook_dict.items():
    print(name)

# Output:
# hook_embed
# hook_pos_embed
# blocks.0.hook_resid_pre
# blocks.0.attn.hook_q
# blocks.0.attn.hook_k
# blocks.0.attn.hook_v
# blocks.0.attn.hook_pattern
# blocks.0.attn.hook_z
# blocks.0.hook_attn_out
# blocks.0.hook_resid_mid
# blocks.0.mlp.hook_pre
# blocks.0.mlp.hook_post
# blocks.0.hook_mlp_out
# blocks.0.hook_resid_post
# ... (repeated for each layer)
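
Rather than hand-writing these strings, transformer_lens.utils.get_act_name can build them for you, which is handy when sweeping over layers:

import transformer_lens.utils as utils

utils.get_act_name("resid_post", 5)  # "blocks.5.hook_resid_post"
utils.get_act_name("pattern", 0)     # "blocks.0.attn.hook_pattern"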

Reading Activations with Hooks

stored_activations = {}

def store_hook(activation, hook):
    stored_activations[hook.name] = activation.clone()
    return activation

# Store specific activations
model.run_with_hooks(
    "The cat sat",
    fwd_hooks=[
        ("blocks.5.hook_resid_post", store_hook),
        ("blocks.5.attn.hook_pattern", store_hook),
    ]
)

print(stored_activations["blocks.5.hook_resid_post"].shape)
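
For the common case of caching everything, run_with_cache (used below for activation patching) does this in one call:

# run_with_cache returns the logits plus a cache of all activations
logits, cache = model.run_with_cache("The cat sat")
print(cache["blocks.5.hook_resid_post"].shape)  # same tensor as above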

Modifying Activations

Zero ablation: Remove a component's contribution

from functools import partial

def zero_head_hook(activation, hook, head_idx):
    """Zero out a specific attention head's output."""
    # activation shape: (batch, seq, n_heads, d_head)
    activation[:, :, head_idx, :] = 0
    return activation

# Ablate head 5 in layer 3
model.run_with_hooks(
    "The cat sat",
    fwd_hooks=[
        ("blocks.3.attn.hook_z", partial(zero_head_hook, head_idx=5))
    ]
)
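
A common, gentler variant is mean ablation, which replaces the head's output with its average rather than zero. A sketch; note that in practice the mean is usually taken over a reference dataset rather than the current input:

def mean_ablate_head_hook(activation, hook, head_idx):
    """Replace a head's output with its mean over batch and positions."""
    # NOTE: averaging over the current input is a simplification here;
    # a dataset mean is the more standard choice.
    activation[:, :, head_idx, :] = activation[:, :, head_idx, :].mean(
        dim=(0, 1), keepdim=True
    )
    return activation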

Activation Patching

Copy activations from one run to another:

def patch_hook(activation, hook, patch_cache, pos=None):
    """Replace activation with the cached version.
    (Assumes both prompts tokenize to the same number of tokens.)"""
    patch_value = patch_cache[hook.name]
    if pos is not None:
        activation[:, pos] = patch_value[:, pos]
    else:
        activation[:] = patch_value
    return activation

# Get clean activations
_, clean_cache = model.run_with_cache("The capital of France is")

# Patch into corrupted run
corrupted_output = model.run_with_hooks(
    "The capital of Germany is",  # Wrong country
    fwd_hooks=[
        ("blocks.8.hook_resid_post",
         partial(patch_hook, patch_cache=clean_cache))
    ]
)
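
To see whether the patch worked, compare the logits for the competing answers. A sketch, assuming " Paris" and " Berlin" each tokenize to a single token:

paris = model.to_single_token(" Paris")
berlin = model.to_single_token(" Berlin")
logit_diff = corrupted_output[0, -1, paris] - corrupted_output[0, -1, berlin]
print(f"Patched Paris-vs-Berlin logit diff: {logit_diff.item():.3f}")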

Causal Tracing

Find which components are important:

def causal_trace(model, clean_prompt, corrupted_prompt, target_token):
    """
    1. Run clean prompt, save activations
    2. Run corrupted prompt
    3. For each layer: patch clean → corrupted, measure recovery
    """
    clean_logits, clean_cache = model.run_with_cache(clean_prompt)
    corrupted_logits = model(corrupted_prompt)

    results = []
    for layer in range(model.cfg.n_layers):
        patched_logits = model.run_with_hooks(
            corrupted_prompt,
            fwd_hooks=[(
                f"blocks.{layer}.hook_resid_post",
                partial(patch_hook, patch_cache=clean_cache)
            )]
        )

        # Measure how much patching this layer recovers correct prediction
        recovery = compute_recovery(
            corrupted_logits, patched_logits, clean_logits, target_token
        )
        results.append(recovery)

    return results
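
compute_recovery is left undefined above; here is one plausible metric (a sketch, not a canonical definition): 0 means patching changed nothing, 1 means it fully restored the clean run's logit on the target token.

def compute_recovery(corrupted_logits, patched_logits, clean_logits, target_token):
    """Fraction of the clean-vs-corrupted logit gap recovered by the patch."""
    corrupted = corrupted_logits[0, -1, target_token]
    patched = patched_logits[0, -1, target_token]
    clean = clean_logits[0, -1, target_token]
    return ((patched - corrupted) / (clean - corrupted)).item()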

Hook Filters

Instead of naming a single hook point, you can pass a filter function that selects hook points by name:

# Using hook filter
def hook_filter(name):
    return "attn" in name and "pattern" in name

model.run_with_hooks(
    "Hello",
    fwd_hooks=[(hook_filter, my_hook)]  # Runs on all attention patterns
)
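
run_with_cache accepts the same kind of filter via its names_filter argument, which saves memory by caching only the matching activations:

# Cache only the attention patterns, nothing else
_, pattern_cache = model.run_with_cache(
    "Hello",
    names_filter=lambda name: name.endswith("hook_pattern"),
)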

Common Patterns

Measuring component importance:

def measure_importance(model, prompt, layer, head):
    """How much does this head matter for this prompt?"""
    clean_loss = model(prompt, return_type="loss")

    ablated_loss = model.run_with_hooks(
        prompt,
        fwd_hooks=[(
            f"blocks.{layer}.attn.hook_z",
            partial(zero_head_hook, head_idx=head)
        )],
        return_type="loss"
    )

    return ablated_loss - clean_loss  # Positive = head was helpful
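
Typical usage sweeps the heads of one layer (a hypothetical prompt and layer; the full sweep over all layers is Aaliyah's exercise below):

for head in range(model.cfg.n_heads):
    delta = measure_importance(model, "The cat sat on the mat", layer=3, head=head)
    print(f"L3H{head}: {delta.item():+.4f}")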

Direct logit attribution:

def get_head_direct_effect(cache, layer, head):
    """What does this head contribute to the output logits?
    (Approximate: this skips the final LayerNorm before unembedding.)"""
    head_output = cache["z", layer][:, :, head, :]  # (batch, seq, d_head)
    head_contrib = head_output @ model.W_O[layer, head]  # (batch, seq, d_model)
    logit_effect = head_contrib @ model.W_U  # (batch, seq, vocab)
    return logit_effect
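
Example usage, tying back to the patching prompt (an arbitrary layer and head for illustration; pick ones that exist in your model, and remember the LayerNorm caveat above):

_, cache = model.run_with_cache("The capital of France is")
effect = get_head_direct_effect(cache, layer=9, head=6)
print(effect[0, -1, model.to_single_token(" Paris")].item())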

Capstone Connection

Hooks for sycophancy analysis:

def measure_sycophancy_contribution(model, prompt_pairs, threshold=1.0):
    """
    For each (honest_prompt, sycophantic_prompt) pair:
    1. Run both, cache activations
    2. Find heads where the activation difference predicts sycophancy
    3. Ablate those heads and measure whether the model becomes more honest
    (Assumes each pair tokenizes to the same length; threshold is a
    tunable cutoff, not a principled constant.)
    """
    flagged_heads = []
    for honest, sycophantic in prompt_pairs:
        _, honest_cache = model.run_with_cache(honest)
        _, syco_cache = model.run_with_cache(sycophantic)

        for layer in range(model.cfg.n_layers):
            for head in range(model.cfg.n_heads):
                diff = (
                    honest_cache["z", layer][:, :, head] -
                    syco_cache["z", layer][:, :, head]
                ).norm()

                if diff > threshold:
                    # This head behaves differently! Investigate further
                    flagged_heads.append((layer, head, diff.item()))
    return flagged_heads

🎓 Tyla's Exercise

  1. Explain why we must return activation in a hook function. What happens if we don't?

  2. Design an experiment to test if a specific attention head is necessary for a specific behavior. What's your null hypothesis?

  3. Activation patching replaces activations from run A into run B. When might this NOT tell us about causation?


💻 Aaliyah's Exercise

Implement causal interventions:

def ablation_experiment(model, prompt, target_token):
    """
    1. Get baseline loss on predicting target_token
    2. For each head, ablate it and measure new loss
    3. Return ranking of heads by importance
    4. Visualize: which layers matter most?
    """
    pass

def activation_patching(model, clean_prompt, corrupted_prompt):
    """
    1. Clean prompt: "The Eiffel Tower is in Paris"
    2. Corrupted: "The Eiffel Tower is in Rome"
    3. Patch each layer's residual stream from clean → corrupted
    4. Find which layer most recovers "Paris" prediction
    """
    pass

📚 Maneesha's Reflection

  1. Hooks give us the ability to intervene on a running computation. What ethical considerations arise from this capability?

  2. "Correlation is not causation" applies here too. How do interventions help us move from correlation to causation?

  3. If you could add one new hook point to TransformerLens, what would it be and why?