SAE Interpretability: Finding Circuits
SAEs give us interpretable features. Now let's find circuits between them.
The SAE Dashboard
Every SAE latent can be characterized by:
┌─────────────────────────────────────────┐
│ Latent 2847: "Python code context"      │
├─────────────────────────────────────────┤
│ Top Activating Examples:                │
│ • "def train_model(x):" → 0.95          │
│ • "import numpy as np" → 0.87           │
│ • "for i in range(10):" → 0.82          │
├─────────────────────────────────────────┤
│ Logit Attribution:                      │
│ ↑ "def", "class", "import"              │
│ ↓ "the", "and", "is"                    │
├─────────────────────────────────────────┤
│ Activation Histogram: [sparse, peaked]  │
└─────────────────────────────────────────┘
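Dashboards like this are built by scanning a dataset for the inputs that most strongly activate a latent. Here is a minimal sketch of that first step, assuming a TransformerLens-style `model`, an `sae` object with an `encode` method, and a list of texts; `hook_name` is whichever hook point the SAE was trained on.

import torch

def top_activating_examples(model, sae, hook_name, dataset, latent_idx, k=10):
    """Return the k texts whose peak activation of `latent_idx` is largest."""
    scores = []
    for text in dataset:
        _, cache = model.run_with_cache(text)
        latents = sae.encode(cache[hook_name])              # (batch, seq, n_latents)
        scores.append(latents[..., latent_idx].max().item())
    top = torch.tensor(scores).topk(min(k, len(dataset)))
    return [(dataset[i], scores[i]) for i in top.indices.tolist()]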
Finding Latents by Behavior
Direct Logit Attribution:
def get_latent_logit_effect(sae, model, latent_idx, token):
    """What effect does this latent have on a token's logit?"""
    # Get the decoder direction for this latent
    latent_direction = sae.decoder.weight[:, latent_idx]
    # Project through the unembedding onto the target token
    token_id = model.tokenizer.encode(token)[0]
    logit_effect = latent_direction @ model.W_U[:, token_id]
    return logit_effect

# Find latents that boost "Paris"
for idx in range(sae.n_latents):
    effect = get_latent_logit_effect(sae, model, idx, " Paris")
    if effect > 5.0:
        print(f"Latent {idx} strongly promotes 'Paris'")
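Since each latent's effect is just the dot product of its decoder column with the unembedding column for the token, the whole scan can be done in one matrix product instead of a Python loop (same `sae` and `model` assumptions as above):

# Vectorized version of the scan above
token_id = model.tokenizer.encode(" Paris")[0]
all_effects = sae.decoder.weight.T @ model.W_U[:, token_id]   # (n_latents,)
promoting = (all_effects > 5.0).nonzero().flatten()           # latents that strongly promote " Paris"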
Attribution Patching with SAEs
Find which latents matter for a behavior:
def sae_attribution_patching(model, sae, layer, clean_prompt, corrupted_prompt):
    """
    For each SAE latent:
    1. Run the clean prompt and save its latent activations
    2. Run the corrupted prompt with this latent patched in from the clean run
    3. Measure how much of the clean behavior is recovered

    This is exhaustive per-latent patching (one forward pass per latent);
    a cheaper gradient-based approximation is sketched below.
    """
    hook_name = f"blocks.{layer}.hook_resid_post"
    _, clean_cache = model.run_with_cache(clean_prompt)
    clean_latents = sae.encode(clean_cache[hook_name])

    results = []
    for latent_idx in range(sae.n_latents):
        def patch_latent(activations, hook):
            latents = sae.encode(activations)
            latents[:, :, latent_idx] = clean_latents[:, :, latent_idx]
            # Note: returning the SAE reconstruction also discards the SAE's
            # reconstruction error for every other latent
            return sae.decode(latents)

        patched_logits = model.run_with_hooks(
            corrupted_prompt,
            fwd_hooks=[(hook_name, patch_latent)],
        )
        results.append(measure_recovery(patched_logits))  # metric left as a placeholder
    return results
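Patching every latent one at a time costs a forward pass per latent, which adds up fast for SAEs with tens of thousands of latents. Attribution patching proper approximates the same quantity with a first-order Taylor expansion: the estimated effect of patching latent i is (clean_i - corrupted_i) * d(metric)/d(latent_i), obtained from one forward and one backward pass. A sketch, assuming the clean and corrupted prompts tokenize to the same length and that `metric_fn` (any scalar function of the logits, e.g. a logit difference) is supplied by you:

def sae_attribution_patching_approx(model, sae, layer, clean_prompt,
                                    corrupted_prompt, metric_fn):
    """Gradient-based approximation to patching each latent from clean into corrupted."""
    hook_name = f"blocks.{layer}.hook_resid_post"

    _, clean_cache = model.run_with_cache(clean_prompt)
    clean_latents = sae.encode(clean_cache[hook_name]).detach()

    _, corrupted_cache = model.run_with_cache(corrupted_prompt)
    corrupted_latents = sae.encode(corrupted_cache[hook_name]).detach().requires_grad_(True)

    # Re-inject the SAE reconstruction so gradients flow back to the latents
    def replace_with_recon(activations, hook):
        return sae.decode(corrupted_latents)

    logits = model.run_with_hooks(corrupted_prompt,
                                  fwd_hooks=[(hook_name, replace_with_recon)])
    metric_fn(logits).backward()

    # First-order estimate of each latent's patching effect, summed over positions
    # (assumes clean and corrupted prompts have the same token length)
    attribution = (clean_latents - corrupted_latents) * corrupted_latents.grad
    return attribution.sum(dim=(0, 1))                     # (n_latents,)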
SAE Circuits
Latents in different layers can form circuits:
Layer 2 Latent: "is_entity"
↓ (via attention)
Layer 5 Latent: "entity_attribute"
↓ (via MLP)
Layer 8 Latent: "attribute_value"
↓
Output: Predicted token
Finding these connections is the frontier of SAE research.
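One crude but useful first pass, when both SAEs read and write the residual stream, is to score candidate connections by how well an earlier latent's decoder direction aligns with a later latent's encoder direction. This ignores everything the intervening attention and MLP layers do, so treat it as a filter for pairs worth testing causally, not as a circuit claim. A sketch assuming two such SAEs:

def latent_connection_strength(sae_early, sae_late, early_idx, late_idx):
    """Dot product between what the early latent writes and what the late latent reads."""
    write_direction = sae_early.decoder.weight[:, early_idx]   # (d_model,)
    read_direction = sae_late.encoder.weight[late_idx, :]      # (d_model,)
    return (write_direction @ read_direction).item()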
Attention SAEs
SAEs can also be trained on attention outputs:
# Different hook points for attention SAEs
hook_points = [
    "blocks.0.attn.hook_z",      # Per-head attention output, before the W_O projection
    "blocks.0.hook_attn_out",    # Attention output after the W_O projection
]

# Attention SAE latents often represent:
# - "copying information from position X"
# - "attending to tokens matching pattern Y"
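A quick usage sketch, assuming an `attn_sae` trained on the "blocks.0.hook_attn_out" hook point:

# Which attention-SAE latents fire at the final token of a prompt?
_, cache = model.run_with_cache("The Eiffel Tower is in Paris")
attn_out = cache["blocks.0.hook_attn_out"]                 # (batch, seq, d_model)
attn_latents = attn_sae.encode(attn_out)                   # (batch, seq, n_latents)
active = (attn_latents[0, -1] > 0).nonzero().flatten()     # active latents at the last position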
Direct Latent Attribution
For attention SAEs, we can attribute to specific head behaviors:
import torch

def attention_sae_dla(attn_sae, model, cache, layer):
    """
    Which attention SAE latents most affect the logits (per unit of activation)?
    """
    # Get the attention output
    attn_out = cache["attn_out", layer]                # (batch, seq, d_model)
    # Encode with the SAE (the activations are used below to weight the effects)
    latents = attn_sae.encode(attn_out)                # (batch, seq, n_latents)

    # Direct per-unit effect of each latent on the logits
    effects = []
    for idx in range(attn_sae.n_latents):
        # Decoder direction × unembedding
        direction = attn_sae.decoder.weight[:, idx]    # (d_model,)
        logit_effect = direction @ model.W_U           # (d_vocab,)
        effects.append(logit_effect)
    return torch.stack(effects)                        # (n_latents, d_vocab)
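The function above gives each latent's effect per unit of activation. To get the actual contribution on a specific prompt, weight those effects by how strongly each latent fired at the position you care about (here the final token; `cache` is assumed to come from `model.run_with_cache` on that prompt):

per_unit = attention_sae_dla(attn_sae, model, cache, layer)       # (n_latents, d_vocab)
final_acts = attn_sae.encode(cache["attn_out", layer])[0, -1]     # (n_latents,)
weighted = final_acts[:, None] * per_unit                         # direct logit contributions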
Transcoders: MLP Replacement
Transcoders replace the MLP for analysis purposes: they are trained to reproduce the MLP's output from the MLP's input through a sparse latent layer:

import torch.nn as nn
import torch.nn.functional as F

class Transcoder(nn.Module):
    """
    Instead of: MLP(x)        = W_out @ ReLU(W_in @ x)
    Use:        Transcoder(x) = W_dec @ ReLU(W_enc @ x)

    where W_enc maps d_model -> n_latents and W_dec maps n_latents -> d_model.
    """
    def __init__(self, d_model, n_latents):
        super().__init__()
        self.encoder = nn.Linear(d_model, n_latents)
        self.decoder = nn.Linear(n_latents, d_model, bias=False)

    def forward(self, x):
        latents = F.relu(self.encoder(x))
        return self.decoder(latents), latents

Transcoders make circuit analysis easier: each latent fires as a simple linear-plus-ReLU function of the MLP's input, and writes a fixed decoder direction to the output, so the latent→output path is direct.
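As a sketch of what that buys you in practice: the transcoder is trained to reproduce the MLP's output from the MLP's input through its sparse latent layer, after which a latent's per-unit effect on the logits is a single matrix product. The hook names follow TransformerLens conventions; `dataloader`, `sparsity_coeff`, `layer`, and `idx` are placeholders for this example.

import torch

transcoder = Transcoder(d_model=model.cfg.d_model, n_latents=16384)
optimizer = torch.optim.Adam(transcoder.parameters(), lr=1e-4)

for batch in dataloader:                                    # batches of token ids (assumed)
    with torch.no_grad():
        _, cache = model.run_with_cache(batch)
    mlp_in = cache[f"blocks.{layer}.ln2.hook_normalized"]   # what the MLP reads
    mlp_out = cache[f"blocks.{layer}.hook_mlp_out"]         # what the MLP writes

    recon, latents = transcoder(mlp_in)
    loss = (recon - mlp_out).pow(2).mean() + sparsity_coeff * latents.abs().mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

# After training, a latent's per-unit effect on the logits is one matrix product:
latent_logit_effect = transcoder.decoder.weight[:, idx] @ model.W_U   # (d_vocab,)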
Feature Splitting
Same concept, multiple latents:
# Multiple latents for "code"
code_latents = [
    latent_234,    # Python code
    latent_1892,   # JavaScript code
    latent_4521,   # Code in general
    latent_7234,   # Code inside strings
]

# This is called "feature splitting"
# It's not a bug - it's the SAE capturing different aspects of the same concept
Understanding splitting helps avoid double-counting in analyses.
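One cheap way to surface candidate splits is to look for latents whose decoder directions point in very similar directions, then check whether their top activating examples overlap. A sketch (the 0.7 threshold is illustrative):

import torch.nn.functional as F

decoder_dirs = F.normalize(sae.decoder.weight.T.detach(), dim=-1)   # (n_latents, d_model), unit norm
sims = decoder_dirs @ decoder_dirs.T                                # pairwise cosine similarities
sims.fill_diagonal_(-1.0)                                           # ignore self-similarity
candidate_pairs = (sims > 0.7).nonzero()                            # pairs worth inspecting by hand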
Autointerp: Automated Interpretation
Use LLMs to label SAE latents:
def autointerp_latent(sae, latent_idx, model, tokenizer):
    """
    1. Get top activating examples
    2. Format as prompt for an LLM
    3. Ask LLM to describe the pattern
    """
    examples = get_max_activating_examples(sae, latent_idx, n=20)
    prompt = f"""
These text excerpts all strongly activate a particular feature detector.
What pattern do they share?

Examples:
{format_examples(examples)}

The feature detects:"""
    return llm_complete(prompt)
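Usage is a one-liner once the placeholder helpers exist; before trusting a label, it is worth holding out some activating and non-activating excerpts and scoring how well the label lets another model predict which ones fire.

label = autointerp_latent(sae, latent_idx=2847, model=model, tokenizer=model.tokenizer)
print(f"Latent 2847: {label}")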
Capstone Connection
SAE circuit analysis for sycophancy:
def trace_sycophancy_circuit(model, sae_dict, honest_prompts, sycophantic_prompts):
    """
    1. Find latents that activate differently for sycophantic vs. honest prompts
    2. Trace connections between layers
    3. Identify the "sycophancy circuit"
    """
    sycophancy_latents = {}
    for layer, sae in sae_dict.items():
        honest_acts = get_activations(model, honest_prompts, layer)
        syco_acts = get_activations(model, sycophantic_prompts, layer)

        honest_latents = sae.encode(honest_acts)
        syco_latents = sae.encode(syco_acts)

        # Find differentially active latents (mean over prompts and positions)
        diff = syco_latents.mean(dim=(0, 1)) - honest_latents.mean(dim=(0, 1))
        top_diff = diff.abs().topk(10)
        sycophancy_latents[layer] = top_diff.indices

    # Now trace connections between these latents across layers...
    return sycophancy_latents
🎓 Tyla's Exercise
Explain why attention SAEs might be harder to train than MLP SAEs. (Hint: think about the structure of attention outputs.)
If an SAE has high reconstruction loss but its latents have clearly interpretable activations, what does this tell us?
Derive the formula for "direct logit attribution" from a latent: how the latent's activation affects token probabilities.
💻 Aaliyah's Exercise
Build an SAE analysis pipeline:
def create_latent_dashboard(sae, model, latent_idx, dataset):
    """
    1. Find the top 50 activating examples
    2. Compute logit attribution (which tokens are promoted/suppressed)
    3. Plot an activation histogram
    4. Find correlated latents (latents that co-activate)
    5. Return a dashboard data structure
    """
    pass

def find_circuit_between_layers(sae_l1, sae_l2, model, layer1, layer2):
    """
    1. For each latent in layer 1, measure its effect on layer 2 latents
    2. Build a connection graph
    3. Identify strongly connected latent pairs
    """
    pass
📚 Maneesha's Reflection
SAE latents are often described with natural language labels. What are the limits of this approach?
If two research groups train SAEs on the same model but get different latents, what does this mean for interpretability claims?
How would you verify that an SAE circuit explanation is "correct" rather than just one of many valid explanations?