SAE Interpretability: Finding Circuits

SAEs give us interpretable features. Now let's find circuits between them.


The SAE Dashboard

Every SAE latent can be characterized by:

┌─────────────────────────────────────────┐
│ Latent 2847: "Python code context"      │
├─────────────────────────────────────────┤
│ Top Activating Examples:                │
│  • "def train_model(x):" → 0.95         │
│  • "import numpy as np" → 0.87          │
│  • "for i in range(10):" → 0.82         │
├─────────────────────────────────────────┤
│ Logit Attribution:                      │
│  ↑ "def", "class", "import"             │
│  ↓ "the", "and", "is"                   │
├─────────────────────────────────────────┤
│ Activation Histogram: [sparse, peaked]  │
└─────────────────────────────────────────┘
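
The "Top Activating Examples" panel comes from a simple max-activation scan over a dataset. Here is a minimal sketch, assuming a TransformerLens-style model, a residual-stream SAE with an encode method, and an illustrative layer argument; the get_max_activating_examples helper used in the autointerp section below does essentially this:

def top_activating_prompts(sae, model, dataset, latent_idx, layer, n=10):
    """Scan a dataset and keep the prompts where this latent fires hardest."""
    scored = []
    for prompt in dataset:
        _, cache = model.run_with_cache(prompt)
        latents = sae.encode(cache["resid_post", layer])   # (batch, seq, n_latents)
        max_act = latents[0, :, latent_idx].max().item()   # strongest position in this prompt
        scored.append((max_act, prompt))

    scored.sort(key=lambda pair: pair[0], reverse=True)    # highest activations first
    return scored[:n]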

Finding Latents by Behavior

Direct Logit Attribution:

def get_latent_logit_effect(sae, model, latent_idx, token):
    """Per unit of activation, how much does this latent boost a token's logit?"""
    # Decoder direction for this latent: (d_model,)
    latent_direction = sae.decoder.weight[:, latent_idx]

    # Project through the unembedding to get the direct effect on this token's logit
    token_id = model.to_single_token(token)  # assumes a single-token string like " Paris"
    logit_effect = latent_direction @ model.W_U[:, token_id]

    return logit_effect

# Find latents that boost "Paris"
for idx in range(sae.n_latents):
    effect = get_latent_logit_effect(sae, model, idx, " Paris")
    if effect > 5.0:
        print(f"Latent {idx} strongly promotes 'Paris'")

Attribution Patching with SAEs

Find which latents matter for a behavior. The brute-force baseline patches each latent, one at a time, from the clean run into the corrupted run; the gradient-based shortcut that gives this section its name follows after the code:

def sae_latent_patching(model, sae, clean_prompt, corrupted_prompt, layer):
    """
    For each SAE latent:
    1. Run the clean prompt, save its latent activations
    2. Run the corrupted prompt with this one latent patched from clean
    3. Measure recovery of the clean behavior
    """
    _, clean_cache = model.run_with_cache(clean_prompt)
    clean_latents = sae.encode(clean_cache["resid_post", layer])

    results = []
    for latent_idx in range(sae.n_latents):
        def patch_latent(activations, hook):
            latents = sae.encode(activations)
            latents[:, :, latent_idx] = clean_latents[:, :, latent_idx]
            # Note: returning the reconstruction also drops the SAE's reconstruction error
            return sae.decode(latents)

        patched_logits = model.run_with_hooks(
            corrupted_prompt,
            fwd_hooks=[(f"blocks.{layer}.hook_resid_post", patch_latent)]
        )
        results.append(measure_recovery(patched_logits))  # assumed scoring helper

    return results
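
The loop above costs one forward pass per latent, which gets expensive for wide SAEs. Attribution patching approximates the whole sweep with a first-order Taylor expansion: run the corrupted prompt once with gradients flowing through the SAE latents, then score every latent as (clean activation - corrupted activation) * gradient of the metric. A hedged sketch, assuming gradients are enabled, that both prompts tokenize to the same length, and that metric maps logits to a scalar (e.g. a logit difference):

def sae_attribution_patching(model, sae, clean_prompt, corrupted_prompt, layer, metric):
    """Approximate the per-latent patching sweep with a single backward pass."""
    hook_name = f"blocks.{layer}.hook_resid_post"

    # Clean latent activations (values only, no gradients needed)
    _, clean_cache = model.run_with_cache(clean_prompt)
    clean_latents = sae.encode(clean_cache["resid_post", layer]).detach()

    # Corrupted run: splice the SAE in and keep gradients w.r.t. its latents
    corrupted_latents = None

    def splice_sae(activations, hook):
        nonlocal corrupted_latents
        corrupted_latents = sae.encode(activations).detach().requires_grad_(True)
        return sae.decode(corrupted_latents)  # reconstruction error is dropped here too

    logits = model.run_with_hooks(corrupted_prompt, fwd_hooks=[(hook_name, splice_sae)])
    metric(logits).backward()

    # Per-latent attribution ~ (clean - corrupted) * d(metric)/d(latent)
    delta = clean_latents - corrupted_latents.detach()
    return (delta * corrupted_latents.grad).sum(dim=(0, 1))  # (n_latents,)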

SAE Circuits

Latents in different layers can form circuits:

Layer 2 Latent: "is_entity"
        ↓ (via attention)
Layer 5 Latent: "entity_attribute"
        ↓ (via MLP)
Layer 8 Latent: "attribute_value"
        ↓
Output: Predicted token

Finding these connections is the frontier of SAE research.
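
A crude first pass at finding such connections, for SAEs trained on the residual stream at two layers, is the virtual weight along the direct residual path: how strongly does an upstream latent's decoder (writing) direction line up with a downstream latent's encoder (reading) direction? This ignores everything attention and MLPs do in between, so treat it as a screening heuristic rather than evidence of a circuit. A minimal sketch, assuming each SAE exposes its encoder and decoder as nn.Linear modules (as in the Transcoder class later in this section):

def direct_path_connection(sae_early, sae_late, early_idx, late_idx):
    """Virtual weight from an early latent to a late latent along the residual stream only."""
    write_dir = sae_early.decoder.weight[:, early_idx]   # what the early latent writes: (d_model,)
    read_dir = sae_late.encoder.weight[late_idx, :]      # what the late latent reads: (d_model,)
    return write_dir @ read_dir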


Attention SAEs

SAEs can also be trained on attention outputs:

# Different hook points for attention SAEs
hook_points = [
    "blocks.0.attn.hook_z",      # Attention output before projection
    "blocks.0.hook_attn_out",    # After W_O projection
]

# Attention SAE latents often represent:
# - "copying information from position X"
# - "attending to tokens matching pattern Y"

Direct Latent Attribution

For attention SAEs, we can attribute each latent's direct effect on the logits, and (for hook_z SAEs) break a latent down into per-head contributions, as sketched after the code:

def attention_sae_dla(attn_sae, model, cache, layer):
    """
    Which attention SAE latents most directly affect the logits at the final position?
    """
    # Get attention output
    attn_out = cache["attn_out", layer]           # (batch, seq, d_model)

    # Encode with SAE
    latents = attn_sae.encode(attn_out)           # (batch, seq, n_latents)

    # Direct effect of each latent on the logits: decoder direction through the unembedding...
    directions = attn_sae.decoder.weight          # (d_model, n_latents)
    logit_effects = directions.T @ model.W_U      # (n_latents, d_vocab)

    # ...scaled by how strongly each latent fired at the final position
    return latents[:, -1, :, None] * logit_effects   # (batch, n_latents, d_vocab)
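
For an SAE trained on the flattened hook_z, a latent's decoder direction splits into one chunk per head, so you can also ask which heads a latent draws on. A hedged sketch, assuming the same head ordering as the reshape above:

def latent_head_attribution(z_sae, latent_idx, n_heads, d_head):
    """How much of this hook_z latent's decoder direction lives in each head's slice?"""
    direction = z_sae.decoder.weight[:, latent_idx]    # (n_heads * d_head,)
    per_head = direction.reshape(n_heads, d_head)      # one chunk per head
    return per_head.norm(dim=-1)                       # (n_heads,)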

Transcoders: MLP Replacement

Transcoders replace MLPs entirely:

import torch.nn as nn
import torch.nn.functional as F

class Transcoder(nn.Module):
    """
    The MLP computes:        MLP(x) = W_out @ ReLU(W_in @ x)
    A transcoder computes:   T(x)   = W_dec @ ReLU(W_enc @ x)

    Same shape of computation, but the transcoder is much wider and is trained
    with a sparsity penalty to match the MLP's output, so its latents are
    (hopefully) interpretable.

    W_enc: (n_latents, d_model)
    W_dec: (d_model, n_latents)
    """
    def __init__(self, d_model, n_latents):
        super().__init__()
        self.encoder = nn.Linear(d_model, n_latents)
        self.decoder = nn.Linear(n_latents, d_model, bias=False)

    def forward(self, x):
        latents = F.relu(self.encoder(x))
        return self.decoder(latents), latents

Transcoders make circuit analysis easier because the latent→output map is direct: each latent writes through a fixed decoder column, so its downstream effect can be read straight off the weights, independent of the input.


Feature Splitting

Same concept, multiple latents:

# Multiple latents for "code"
code_latents = [
    latent_234,   # Python code
    latent_1892,  # JavaScript code
    latent_4521,  # Code in general
    latent_7234,  # Code inside strings
]

# This is called "feature splitting"
# It's not a bug: wider SAEs tend to split one coarse concept into finer
# sub-features, each capturing a different aspect

Understanding splitting helps avoid double-counting in analyses.
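
One quick way to spot candidate splits is decoder cosine similarity: latents that split a concept often have strongly aligned decoder directions. A rough sketch, assuming the decoder weight layout used throughout this section:

import torch.nn.functional as F

def similar_latents(sae, latent_idx, top_k=10):
    """Find latents whose decoder directions point nearly the same way."""
    directions = F.normalize(sae.decoder.weight.detach(), dim=0)   # unit columns: (d_model, n_latents)
    sims = directions[:, latent_idx] @ directions                  # cosine similarity to every latent
    sims[latent_idx] = -1.0                                        # exclude the latent itself
    return sims.topk(top_k)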


Autointerp: Automated Interpretation

Use LLMs to label SAE latents:

def autointerp_latent(sae, latent_idx, model, tokenizer):
    """
    1. Get top activating examples
    2. Format them as a prompt for an LLM
    3. Ask the LLM to describe the shared pattern
    """
    # Assumed helpers: a max-activation scan over a dataset (cf. the sketch in
    # the dashboard section) and simple formatting of each example
    examples = get_max_activating_examples(sae, latent_idx, n=20)

    prompt = f"""
These text excerpts all strongly activate a particular feature detector.
What pattern do they share?

Examples:
{format_examples(examples)}

The feature detects:"""

    return llm_complete(prompt)
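
Labels produced this way need validation, since LLM-written descriptions can sound plausible without being predictive. One common check, sketched here reusing the assumed llm_complete helper: give a second model only the label plus a mix of top-activating and random excerpts, ask it which ones activate the latent, and score its accuracy.

import random

def detection_score(label, activating_examples, random_examples):
    """Can a model, given only the label, tell activating excerpts from random ones?"""
    items = [(text, True) for text in activating_examples] + \
            [(text, False) for text in random_examples]
    random.shuffle(items)

    correct = 0
    for text, is_activating in items:
        answer = llm_complete(
            f"A feature detector is described as: {label}\n"
            f"Does this excerpt activate it? Answer YES or NO.\n\n"
            f"Excerpt: {text}\nAnswer:"
        )
        if ("YES" in answer.upper()) == is_activating:
            correct += 1

    return correct / len(items)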

Capstone Connection

SAE circuit analysis for sycophancy:

def trace_sycophancy_circuit(model, sae_dict, honest_prompts, sycophantic_prompts):
    """
    1. Find latents that activate differently for sycophantic vs honest prompts
    2. Trace connections between layers
    3. Identify the "sycophancy circuit"
    """
    sycophancy_latents = {}

    for layer, sae in sae_dict.items():
        honest_acts = get_activations(model, honest_prompts, layer)
        syco_acts = get_activations(model, sycophantic_prompts, layer)

        honest_latents = sae.encode(honest_acts)
        syco_latents = sae.encode(syco_acts)

        # Mean activation per latent (over prompts and positions), then the difference
        diff = syco_latents.mean(dim=(0, 1)) - honest_latents.mean(dim=(0, 1))
        top_diff = diff.abs().topk(10)

        sycophancy_latents[layer] = top_diff.indices

    # Next step: trace connections between the per-layer latent sets...
    return sycophancy_latents

🎓 Tyla's Exercise

  1. Explain why attention SAEs might be harder to train than MLP SAEs. (Hint: think about the structure of attention outputs.)

  2. If an SAE has high reconstruction loss but its latents still look clearly interpretable, what does this tell us?

  3. Derive the formula for "direct logit attribution" from a latent: how the latent's activation affects token probabilities.


💻 Aaliyah's Exercise

Build an SAE analysis pipeline:

def create_latent_dashboard(sae, model, latent_idx, dataset):
    """
    1. Find top 50 activating examples
    2. Compute logit attribution (which tokens are promoted/suppressed)
    3. Plot activation histogram
    4. Find correlated latents (latents that co-activate)
    5. Return dashboard data structure
    """
    pass

def find_circuit_between_layers(sae_l1, sae_l2, model, layer1, layer2):
    """
    1. For each latent in layer 1, measure effect on layer 2 latents
    2. Build connection graph
    3. Identify strongly connected latent pairs
    """
    pass

📚 Maneesha's Reflection

  1. SAE latents are often described with natural language labels. What are the limits of this approach?

  2. If two research groups train SAEs on the same model but get different latents, what does this mean for interpretability claims?

  3. How would you verify that an SAE circuit explanation is "correct" rather than just one of many valid explanations?