Balanced Bracket Classifier: Algorithmic Interpretability
Toy models trained on synthetic tasks often learn clean, interpretable algorithms. Time to reverse-engineer one.
Why Study Toy Models?
Algorithmic interpretability offers unique advantages:
| Benefit | Why It Matters |
|---|---|
| Ground truth | We know the correct algorithm |
| Small models | Fast experiments, complete enumeration |
| Clean signals | One task, no competing behaviors |
| Generalizable insights | Techniques transfer to larger models |
The bracket classifier is "interpretability on easy mode" - but the lessons apply everywhere.
The Task: Bracket Balancing
Classify whether a parenthesis string is balanced:
# Balanced examples
"()" -> True
"(())" -> True
"()()" -> True
"((()))" -> True
# Unbalanced examples
")(" -> False # Wrong order
"(()" -> False # Missing close
"())" -> False # Extra close
"((())" -> False # Mismatched count
Two failure modes: elevation failure (unequal opens/closes) and negative failure (too many closes at some point).
The Human Algorithm
How would you solve this?
def is_balanced_forloop(parens: str) -> bool:
"""
Track 'elevation' - the number of unclosed opens.
"""
elevation = 0
for char in parens:
elevation += 1 if char == "(" else -1
if elevation < 0: # More closes than opens so far
return False
return elevation == 0 # Must end at zero
Key insight: This is a cumulative sum problem with two checks:
- Final elevation must be zero
- Elevation never goes negative
Vectorized Algorithm
The same logic, without loops:
import torch
from torch import Tensor

def is_balanced_vectorized(tokens: Tensor) -> bool:
# Map tokens: '(' -> +1, ')' -> -1, others -> 0
table = torch.tensor([0, 0, 0, 1, -1]) # [start, pad, end, (, )]
changes = table[tokens]
# Cumulative sum gives elevation at each position
elevation = torch.cumsum(changes, dim=-1)
# Check both conditions
no_elevation_failure = elevation[-1] == 0
no_negative_failure = elevation.min() >= 0
    return bool(no_elevation_failure and no_negative_failure)
This is exactly what the transformer learns to compute!
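As a quick sanity check, here is a minimal comparison of the two implementations, using the token ids from the mapping table above (3 = open bracket, 4 = close bracket). The to_tokens helper is just for this check, not the model's real tokenizer.
def to_tokens(s: str) -> Tensor:
    # Bare-bones encoding for this check only: 3 = "(", 4 = ")"
    return torch.tensor([3 if c == "(" else 4 for c in s])

for s in ["()", "(())", ")(", "(()", "())"]:
    assert bool(is_balanced_vectorized(to_tokens(s))) == is_balanced_forloop(s)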
The Model Architecture
A small bidirectional transformer:
from transformer_lens import HookedTransformer, HookedTransformerConfig

cfg = HookedTransformerConfig(
n_ctx=42, # Max sequence length
d_model=56, # Hidden dimension
d_head=28, # Head dimension
n_heads=2, # 2 heads per layer
d_mlp=56, # MLP width
n_layers=3, # 3 transformer layers
attention_dir="bidirectional", # Not causal!
act_fn="relu",
d_vocab=5, # [start], [pad], [end], (, )
d_vocab_out=2, # Binary classification
)
Classification reads the model's output at position 0 (the [start] token).
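A minimal sketch of instantiating this architecture with TransformerLens. The weights here are random; the actual trained checkpoint would be loaded separately. The example tokens assume the vocabulary ordering from the mapping table above.
import torch

model = HookedTransformer(cfg)   # randomly initialized; trained weights would be loaded separately

# "(())" with the vocab above: 0 = [start], 1 = [pad], 2 = [end], 3 = "(", 4 = ")"
tokens = torch.tensor([[0, 3, 3, 4, 4, 2]])
logits = model(tokens)           # [batch, seq, d_vocab_out] = [1, 6, 2]
class_logits = logits[:, 0, :]   # the classification readout lives at position 0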
Bidirectional vs Causal Attention
Causal attention (GPT-style):
- Information flows forward only
- Position i can only attend to positions 0..i
- Used for autoregressive generation
Bidirectional attention (BERT-style):
- Information flows both directions
- Position i can attend to all positions
- Used for classification, understanding
# Causal: Triangular mask
# [ 1 0 0 0 ]
# [ 1 1 0 0 ]
# [ 1 1 1 0 ]
# [ 1 1 1 1 ]
# Bidirectional: No positional mask
# [ 1 1 1 1 ]
# [ 1 1 1 1 ]
# [ 1 1 1 1 ]
# [ 1 1 1 1 ]
For bracket classification, bidirectional makes sense - we need to see the whole string.
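Here is a small sketch of how the two masking schemes can be built in PyTorch (scores is a placeholder attention-score matrix, not taken from the model):
import torch

seq_len = 4
scores = torch.randn(seq_len, seq_len)  # placeholder attention scores

# Causal: lower-triangular mask, position i sees only positions 0..i
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
causal_pattern = scores.masked_fill(~causal_mask, float("-inf")).softmax(-1)

# Bidirectional: no positional mask (only padding would be masked out)
bidirectional_pattern = scores.softmax(-1)

print(causal_pattern)         # triangular attention pattern
print(bidirectional_pattern)  # every position attends everywhere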
Classification with Transformers
Instead of predicting next tokens, output class probabilities:
# Language model: d_vocab_out = 50257 (all tokens)
# Classifier: d_vocab_out = 2 (balanced/unbalanced)
# Get logits at position 0 (the [start] token)
logits = model(tokens)[:, 0, :] # Shape: [batch, 2]
# Class 0 = unbalanced, Class 1 = balanced
prob_balanced = logits.softmax(-1)[:, 1]
The model aggregates information from all positions into position 0.
Our Key Metric: Logit Difference
def logit_difference(model, tokens):
"""
Positive = predicts unbalanced
Negative = predicts balanced
"""
logits = model(tokens)[:, 0] # [batch, 2]
return logits[:, 0] - logits[:, 1] # unbalanced - balanced
We can decompose this into contributions from each component.
The Unbalanced Direction
The model outputs logits via the unembedding matrix:
def get_unbalanced_direction(model):
"""
Direction in residual stream that increases P(unbalanced).
"""
# W_U[:, 0] increases unbalanced logit
# W_U[:, 1] increases balanced logit
# Their difference is the "unbalanced direction"
return model.W_U[:, 0] - model.W_U[:, 1]
Project any component's output onto this direction to see its contribution.
Logit Attribution
Which components drive the prediction?
def logit_attribution(model, tokens, data):
"""
Decompose logit difference by component.
"""
_, cache = model.run_with_cache(tokens)
unbalanced_dir = get_unbalanced_direction(model)
contributions = {}
# Embedding contribution
embed = cache["hook_embed"] + cache["hook_pos_embed"]
contributions["embed"] = (embed[:, 0] @ unbalanced_dir).mean()
# Each attention head
for layer in range(model.cfg.n_layers):
for head in range(model.cfg.n_heads):
result = cache[f"blocks.{layer}.attn.hook_result"][:, 0, head]
contributions[f"head_{layer}.{head}"] = (result @ unbalanced_dir).mean()
# Each MLP
for layer in range(model.cfg.n_layers):
mlp_out = cache[f"blocks.{layer}.hook_mlp_out"][:, 0]
contributions[f"mlp_{layer}"] = (mlp_out @ unbalanced_dir).mean()
return contributions
The Total Elevation Circuit
Heads 2.0 and 2.1 dominate the final logit contribution:
Component contributions to unbalanced prediction:
head_2.0: +8.3 (detects elevation failure)
head_2.1: +4.2 (detects negative failure)
mlp_0: +2.1 (intermediate computation)
others: ~0 (minimal contribution)
These two heads implement different parts of the algorithm!
Head 2.0: Elevation Detector
Head 2.0 fires when total elevation is non-zero:
# Attention pattern for head 2.0
# Query at position 0 attends uniformly to all bracket positions
#
# Position: [start] ( ( ) ( ) ) [end]
# Attention: 0.0 0.17 0.17 0.17 0.17 0.17 0.17 0.0
# What it computes:
# Sum of all bracket values = total elevation
The OV circuit maps:
- ( embeddings -> positive contribution
- ) embeddings -> negative contribution
Net result: output proportional to #opens - #closes.
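A tiny numeric sketch of why this works (toy vectors, not the trained weights): if the OV circuit sends ( to +v and ) to -v, then uniform attention over the n bracket positions outputs (#opens - #closes) / n times v.
import torch

v = torch.randn(8)                    # toy output direction, not the real OV output
brackets = "(()(())"                  # 4 opens, 3 closes
ov_outputs = torch.stack([v if c == "(" else -v for c in brackets])
uniform_attention_output = ov_outputs.mean(0)   # uniform attention = simple average
net = brackets.count("(") - brackets.count(")")
print(torch.allclose(uniform_attention_output, (net / len(brackets)) * v))  # True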
Head 2.1: Negative Elevation Detector
Head 2.1 detects if elevation ever goes negative:
# More complex attention pattern
# Focuses on positions where cumulative sum might go negative
# Hypothesis: Uses information from earlier layers
# that track "running elevation" at each position
This head activates when ) appears before enough (.
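One way to probe this hypothesis (a sketch: data.has_negative_failure is an assumed boolean mask, analogous to data.has_elevation_failure used later, and hook_result is only cached when use_attn_result is enabled):
model.set_use_attn_result(True)  # needed so hook_result appears in the cache
_, cache = model.run_with_cache(data.tokens)
unbalanced_dir = get_unbalanced_direction(model)

head_21_out = cache["blocks.2.attn.hook_result"][:, 0, 1]  # head 2.1's output at [start]
proj = head_21_out @ unbalanced_dir                        # push toward "unbalanced"

print("negative-failure strings:", proj[data.has_negative_failure].mean().item())
print("other strings:           ", proj[~data.has_negative_failure].mean().item())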
The Role of MLPs
MLPs contribute to the intermediate "running elevation" signal:
def analyze_mlp_neurons(model, data, layer):
    """
    Which neurons matter for the elevation computation?
    Returns each neuron's correlation with the running elevation.
    """
    _, cache = model.run_with_cache(data.tokens)
    # Get post-activation neuron values
    neuron_acts = cache[f"blocks.{layer}.mlp.hook_post"]  # [batch, seq, d_mlp]
    # Running elevation at each position (vocab order: [start, pad, end, (, ))
    table = torch.tensor([0, 0, 0, 1, -1], device=data.tokens.device)
    elevation = torch.cumsum(table[data.tokens], dim=-1).float()  # [batch, seq]
    # For each neuron, check its correlation with the running elevation
    correlations = []
    for neuron_idx in range(model.cfg.d_mlp):
        acts = neuron_acts[:, :, neuron_idx]
        corr = torch.corrcoef(torch.stack([acts.flatten(), elevation.flatten()]))[0, 1]
        correlations.append(corr)
    return torch.stack(correlations)  # [d_mlp]
Key neurons compute running totals that feed into layer 2 heads.
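For example, the sketch above can be used to list the layer-0 neurons most correlated with the running elevation:
corr = analyze_mlp_neurons(model, data, layer=0)
top_vals, top_idx = corr.abs().topk(5)
print(top_idx.tolist(), top_vals.tolist())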
Handling LayerNorm
LayerNorm complicates direct logit attribution:
# The residual stream passes through LayerNorm before unembedding
# resid -> LN(resid) -> W_U -> logits
# Problem: LN is nonlinear (normalization depends on all components)
Solution: Fit a linear approximation:
from sklearn.linear_model import LinearRegression
def fit_layernorm_linear(model, data, layernorm, seq_pos):
"""
Approximate LayerNorm as a linear transformation.
"""
_, cache = model.run_with_cache(data.tokens)
# Get inputs and outputs of LayerNorm
ln_input = cache[f"blocks.{model.cfg.n_layers-1}.hook_resid_post"][:, seq_pos]
ln_output = cache["ln_final.hook_normalized"][:, seq_pos]
    # Fit linear regression (detach activations before converting to numpy)
    fit = LinearRegression().fit(ln_input.detach().numpy(), ln_output.detach().numpy())
    # High R^2 means the linear approximation is good
    r2 = fit.score(ln_input.detach().numpy(), ln_output.detach().numpy())
    print(f"R^2 = {r2:.4f}")  # Often > 0.99!
    return torch.tensor(fit.coef_, dtype=torch.float32)
Surprisingly, LayerNorm is nearly linear in this model!
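One way this fit gets used (a sketch, reusing the helpers defined earlier): pull the unbalanced direction back through the fitted matrix, so that components writing to the residual stream before ln_final can be attributed directly.
# fit.coef_ has shape [d_model, d_model] with ln_output ~ ln_input @ coef.T,
# so the pre-LayerNorm direction is coef.T @ (post-LayerNorm direction)
L_final = fit_layernorm_linear(model, data, layernorm="ln_final", seq_pos=0)
pre_ln_unbalanced_dir = L_final.T @ get_unbalanced_direction(model)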
MLPs as Neuron Collections
View the MLP as a bag of independent neurons:
def get_neuron_contributions(model, data, layer, direction):
"""
Decompose MLP output by neuron.
"""
_, cache = model.run_with_cache(data.tokens)
# Post-ReLU activations
neuron_acts = cache[f"blocks.{layer}.mlp.hook_post"] # [batch, seq, d_mlp]
# Each neuron's output direction
W_out = model.W_out[layer] # [d_mlp, d_model]
# Contribution = activation * (output_direction . target_direction)
output_projections = W_out @ direction # [d_mlp]
contributions = neuron_acts * output_projections # [batch, seq, d_mlp]
return contributions
This reveals which neurons matter for specific computations.
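Example usage (a sketch, reusing the helpers defined earlier): rank a layer's neurons by how strongly they push position 0 toward "unbalanced".
direction = get_unbalanced_direction(model)
contribs = get_neuron_contributions(model, data, layer=1, direction=direction)
mean_at_start = contribs[:, 0].mean(0)            # average over the batch at position 0
top_vals, top_idx = mean_at_start.abs().topk(10)  # the 10 most influential neurons
print(top_idx.tolist())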
The Full Circuit Diagram
Input: "(())"
Layer 0:
- Embeddings encode token identity + position
- Head 0.0: Duplicate token detection (finds matching brackets?)
- Head 0.1: Position-based patterns
- MLP 0: Computes per-position features
Layer 1:
- Head 1.0/1.1: Aggregate information across positions
- MLP 1: Compute "running elevation" at each position
Layer 2:
- Head 2.0: Reads position 1, outputs total elevation
- Head 2.1: Detects negative elevation anywhere
- MLP 2: Final adjustments
Output (position 0):
- Large positive = unbalanced
- Large negative = balanced
Attention Patterns in Layer 0
Head 0.0 shows a distinctive pattern:
# Attention from position i to position j
# Head 0.0 attends to tokens at similar "elevation levels"
# For "(())", elevation profile is: 1, 2, 1, 0
# Position 0 (elev=1) attends to position 2 (elev=1)
This helps propagate matching-bracket information.
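To check this yourself, head 0.0's attention pattern can be read straight from the cache (a sketch; tokens is the tokenized "(())" example from earlier, including the special tokens):
_, cache = model.run_with_cache(tokens)
pattern = cache["blocks.0.attn.hook_pattern"][0, 0]  # [query_pos, key_pos] for head 0.0
print(torch.round(pattern, decimals=2))              # rows = query, columns = key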
Testing Hypotheses with Interventions
def test_elevation_hypothesis(model, data):
"""
If head 2.0 computes elevation, ablating it should break
elevation-failure detection but not negative-failure detection.
"""
def ablate_head_20(result, hook):
result[:, :, 0] = 0 # Zero out head 2.0
return result
    # Test on elevation failures only (ground-truth class 0 = unbalanced)
    elevation_failures = data.tokens[data.has_elevation_failure]
    clean_acc = (model(elevation_failures)[:, 0].argmax(-1) == 0).float().mean()
    # hook_result is only cached when model.cfg.use_attn_result is True
    ablated_acc = (model.run_with_hooks(
        elevation_failures,
        fwd_hooks=[("blocks.2.attn.hook_result", ablate_head_20)]
    )[:, 0].argmax(-1) == 0).float().mean()
print(f"Clean accuracy: {clean_acc:.2%}")
print(f"Ablated accuracy: {ablated_acc:.2%}")
# Expect big drop!
Adversarial Examples
Understanding the circuit reveals weaknesses:
# The model might fail on:
# 1. Very long sequences (out of distribution)
adversarial_1 = "(" * 100 + ")" * 100
# 2. Patterns that confuse the elevation tracking
adversarial_2 = "(" * 15 + ")(" + ")" * 14  # Sneaky unbalanced: one extra open hidden in a long string
# 3. Edge cases in attention patterns
adversarial_3 = "()" * 19 + ")("  # Hidden negative failure right at the end (length 40, the training max)
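Before blaming the model, it is worth confirming the ground truth and lengths of these candidates with the reference implementation from earlier:
for name, s in [("adv_1", adversarial_1), ("adv_2", adversarial_2), ("adv_3", adversarial_3)]:
    print(f"{name}: length={len(s)}, balanced={is_balanced_forloop(s)}")
# adv_1 is balanced but far longer than anything seen in training;
# adv_2 and adv_3 are genuinely unbalanced.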
What the OV Circuits Compute
For head 2.0:
def analyze_ov_circuit(model, tokenizer):
"""
What does head 2.0's OV circuit do to bracket embeddings?
"""
W_OV = model.W_V[2, 0] @ model.W_O[2, 0]
# Get bracket embeddings
open_embed = model.W_E[tokenizer.t_to_i["("]]
close_embed = model.W_E[tokenizer.t_to_i[")"]]
# Transform through OV
open_transformed = open_embed @ W_OV
close_transformed = close_embed @ W_OV
# Check: Are these opposite directions?
cosine_sim = torch.cosine_similarity(open_transformed, close_transformed, dim=0)
print(f"Cosine similarity: {cosine_sim:.4f}") # Expect ~ -1.0!
Open and close brackets map to opposite directions - perfect for computing elevation!
Capstone Connection
Bracket balancing and sycophancy share a pattern:
Both require tracking cumulative state across a sequence:
- Brackets: Running count of opens vs closes
- Sycophancy: Running sentiment of user opinions
# Sycophancy might use similar circuits:
# - "User approval tracker" neurons that accumulate positive/negative signals
# - Final heads that read this tracker and bias toward agreement
# Hypothesis to test:
# 1. Find neurons that track "user sentiment" across conversation
# 2. Check if sycophantic responses correlate with high positive-sentiment accumulation
# 3. Test: Does ablating these neurons reduce sycophancy?
def find_sentiment_tracking_neurons(model, conversations):
"""
Look for neurons whose activation correlates with
cumulative positive/negative user sentiment.
"""
# Similar analysis to bracket elevation tracking!
pass
The bracket classifier teaches us how to find and verify "accumulator circuits."
Key Takeaways
- Clean algorithms emerge: The model learned exactly the cumulative-sum solution
- Heads specialize: Different heads handle different failure modes
- MLPs compute features: Neurons implement specific sub-computations
- LayerNorm is ~linear: Often well-approximated for interpretability
- Interventions verify hypotheses: Ablation tests confirm causal roles
🎓 Tyla's Exercise
- Prove that detecting balanced brackets requires Ω(log n) bits of state, where n is the sequence length. How does the model's d_model of 56 relate to this?
- The model uses bidirectional attention. Could a causal (autoregressive) model solve bracket classification? What architectural modifications would be needed?
- Head 2.0 attends uniformly across positions. Show mathematically that uniform attention plus the right OV circuit computes a scaled sum of the OV-transformed token embeddings.
- LayerNorm normalizes by subtracting the mean and dividing by the standard deviation. Under what conditions is this approximately linear?
💻 Aaliyah's Exercise
Implement the core analysis:
def get_component_contributions(model, tokens):
"""
1. Run forward pass with caching
2. Compute unbalanced direction
3. Project each component's output onto this direction
4. Return dict of {component_name: mean_contribution}
"""
pass
def identify_failure_mode(model, unbalanced_string):
"""
1. Compute elevation at each position
2. Classify as 'elevation_failure', 'negative_failure', or 'both'
3. Return failure type and relevant positions
"""
pass
def ablate_and_test(model, data, component_name):
"""
1. Run clean forward pass, record accuracy by failure type
2. Ablate specified component (zero its output)
3. Record accuracy by failure type again
4. Report which failure types are affected
"""
pass
def find_accumulator_neurons(model, data, target_quantity):
"""
1. Compute target_quantity (e.g., running elevation) for each position
2. Get neuron activations at each position
3. Correlate each neuron with target_quantity
4. Return neurons with high correlation
"""
pass
📚 Maneesha's Reflection
- This toy model learns a clean algorithm matching human intuition. Do you expect language models to learn similarly clean algorithms for complex behaviors? Why or why not?
- The bracket task has a known optimal solution. For tasks without known solutions (like "be helpful"), how would you verify that you've found the "real" algorithm?
- Bidirectional attention seems natural for classification. Yet GPT (causal) models can also classify. What are the tradeoffs? When would you choose each?
- The model achieves near-perfect accuracy on brackets of length <= 40. What experiments would you run to understand its out-of-distribution generalization?