Balanced Bracket Classifier: Algorithmic Interpretability

Toy models trained on synthetic tasks often learn clean, interpretable algorithms. Time to reverse-engineer one.


Why Study Toy Models?

Algorithmic interpretability offers unique advantages:

Benefit                  Why It Matters
Ground truth             We know the correct algorithm
Small models             Fast experiments, complete enumeration
Clean signals            One task, no competing behaviors
Generalizable insights   Techniques transfer to larger models

The bracket classifier is "interpretability on easy mode" - but the lessons apply everywhere.


The Task: Bracket Balancing

Classify whether a parenthesis string is balanced:

# Balanced examples
"()"      -> True
"(())"    -> True
"()()"    -> True
"((()))"  -> True

# Unbalanced examples
")("      -> False  # Wrong order
"(()"     -> False  # Missing close
"())"     -> False  # Extra close
"((())"   -> False  # Mismatched count

Two failure modes: elevation failure (unequal opens/closes) and negative failure (too many closes at some point).
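
A tiny reference implementation that separates the two failure modes (a sketch for illustration, not the model's own computation):

def failure_modes(parens: str) -> tuple[bool, bool]:
    """Return (elevation_failure, negative_failure) for a bracket string."""
    elevation, went_negative = 0, False
    for char in parens:
        elevation += 1 if char == "(" else -1
        went_negative |= elevation < 0
    return elevation != 0, went_negative

print(failure_modes("(()"))   # (True, False): unequal counts only
print(failure_modes("())("))  # (False, True): dips negative but ends at zero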


The Human Algorithm

How would you solve this?

def is_balanced_forloop(parens: str) -> bool:
    """
    Track 'elevation' - the number of unclosed opens.
    """
    elevation = 0
    for char in parens:
        elevation += 1 if char == "(" else -1
        if elevation < 0:  # More closes than opens so far
            return False

    return elevation == 0  # Must end at zero

Key insight: This is a cumulative sum problem with two checks:

  1. Final elevation must be zero
  2. Elevation never goes negative

Vectorized Algorithm

The same logic, without loops:

import torch
from torch import Tensor

def is_balanced_vectorized(tokens: Tensor) -> bool:
    # Map tokens: '(' -> +1, ')' -> -1, others -> 0
    table = torch.tensor([0, 0, 0, 1, -1])  # [start, pad, end, (, )]
    changes = table[tokens]

    # Cumulative sum gives elevation at each position
    elevation = torch.cumsum(changes, dim=-1)

    # Check both conditions
    no_elevation_failure = elevation[-1] == 0
    no_negative_failure = elevation.min() >= 0

    return no_elevation_failure and no_negative_failure

This is exactly what the transformer learns to compute!


The Model Architecture

A small bidirectional transformer:

from transformer_lens import HookedTransformerConfig

cfg = HookedTransformerConfig(
    n_ctx=42,           # Max sequence length
    d_model=56,         # Hidden dimension
    d_head=28,          # Head dimension
    n_heads=2,          # 2 heads per layer
    d_mlp=56,           # MLP width
    n_layers=3,         # 3 transformer layers
    attention_dir="bidirectional",  # Not causal!
    act_fn="relu",
    d_vocab=5,          # [start], [pad], [end], (, )
    d_vocab_out=2,      # Binary classification
)

Classification uses position 0 ([start] token) output.
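
A minimal sketch of building this model and checking the classification read-out (the model here is randomly initialized; in practice the trained weights are loaded from a checkpoint, and the token ids follow the vocab comment above):

import torch
from transformer_lens import HookedTransformer

model = HookedTransformer(cfg)   # randomly initialized; load trained weights in practice

# "(())" -> [start] ( ( ) ) [end], with ids [start]=0, [end]=2, "("=3, ")"=4
tokens = torch.tensor([[0, 3, 3, 4, 4, 2]])
logits = model(tokens)           # shape [batch=1, seq=6, d_vocab_out=2]
print(logits[:, 0, :].shape)     # classification is read off position 0 -> torch.Size([1, 2])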


Bidirectional vs Causal Attention

Causal attention (GPT-style) masks out future positions; bidirectional attention (BERT-style) lets every position attend to every other:

# Causal: Triangular mask
#   [ 1  0  0  0 ]
#   [ 1  1  0  0 ]
#   [ 1  1  1  0 ]
#   [ 1  1  1  1 ]

# Bidirectional: No positional mask
#   [ 1  1  1  1 ]
#   [ 1  1  1  1 ]
#   [ 1  1  1  1 ]
#   [ 1  1  1  1 ]

For bracket classification, bidirectional makes sense - we need to see the whole string.
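
As a standalone PyTorch sketch (not the model's internal implementation), the only difference is whether future positions are masked out before the softmax:

import torch

seq_len = 4
scores = torch.randn(seq_len, seq_len)  # pre-softmax attention scores

# Causal: mask strictly-upper-triangular (future) entries before the softmax
causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
causal_pattern = scores.masked_fill(causal_mask, float("-inf")).softmax(dim=-1)

# Bidirectional: no mask, softmax over the full row
bidirectional_pattern = scores.softmax(dim=-1)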


Classification with Transformers

Instead of predicting next tokens, output class probabilities:

# Language model: d_vocab_out = 50257 (all tokens)
# Classifier: d_vocab_out = 2 (balanced/unbalanced)

# Get logits at position 0 (the [start] token)
logits = model(tokens)[:, 0, :]  # Shape: [batch, 2]

# Class 0 = unbalanced, Class 1 = balanced
prob_balanced = logits.softmax(-1)[:, 1]

The model aggregates information from all positions into position 0.


Our Key Metric: Logit Difference

def logit_difference(model, tokens):
    """
    Positive = predicts unbalanced
    Negative = predicts balanced
    """
    logits = model(tokens)[:, 0]  # [batch, 2]
    return logits[:, 0] - logits[:, 1]  # unbalanced - balanced

We can decompose this into contributions from each component.


The Unbalanced Direction

The model outputs logits via the unembedding matrix:

def get_unbalanced_direction(model):
    """
    Direction in residual stream that increases P(unbalanced).
    """
    # W_U[:, 0] increases unbalanced logit
    # W_U[:, 1] increases balanced logit
    # Their difference is the "unbalanced direction"
    return model.W_U[:, 0] - model.W_U[:, 1]

Project any component's output onto this direction to see its contribution.


Logit Attribution

Which components drive the prediction?

def logit_attribution(model, tokens):
    """
    Decompose the logit difference into per-component contributions.
    """
    _, cache = model.run_with_cache(tokens)
    unbalanced_dir = get_unbalanced_direction(model)

    contributions = {}

    # Embedding contribution
    embed = cache["hook_embed"] + cache["hook_pos_embed"]
    contributions["embed"] = (embed[:, 0] @ unbalanced_dir).mean()

    # Each attention head
    for layer in range(model.cfg.n_layers):
        for head in range(model.cfg.n_heads):
            result = cache[f"blocks.{layer}.attn.hook_result"][:, 0, head]
            contributions[f"head_{layer}.{head}"] = (result @ unbalanced_dir).mean()

    # Each MLP
    for layer in range(model.cfg.n_layers):
        mlp_out = cache[f"blocks.{layer}.hook_mlp_out"][:, 0]
        contributions[f"mlp_{layer}"] = (mlp_out @ unbalanced_dir).mean()

    return contributions
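
A usage sketch (assuming the model was created with use_attn_result=True so that hook_result is cached, and that data.tokens holds a batch of tokenized bracket strings):

contributions = logit_attribution(model, data.tokens)
for name, value in sorted(contributions.items(), key=lambda kv: -abs(float(kv[1]))):
    print(f"{name:>10}: {float(value):+.3f}")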

The Total Elevation Circuit

Heads 2.0 and 2.1 dominate the final logit contribution:

Component contributions to unbalanced prediction:

head_2.0:  +8.3  (detects elevation failure)
head_2.1:  +4.2  (detects negative failure)
mlp_0:     +2.1  (intermediate computation)
others:    ~0    (minimal contribution)

These two heads implement different parts of the algorithm!


Head 2.0: Elevation Detector

Head 2.0 fires when total elevation is non-zero:

# Attention pattern for head 2.0
# Query at position 0 attends uniformly to all bracket positions
#
# Position:  [start]  (  (  )  (  )  )  [end]
# Attention:  0.0    0.17 0.17 0.17 0.17 0.17 0.17  0.0

# What it computes:
# Sum of all bracket values = total elevation

The OV circuit maps "(" and ")" to (approximately) opposite directions in the residual stream: "(" contributes +v and ")" contributes -v for a shared vector v (verified in the OV analysis below).

Net result: output proportional to #opens - #closes.
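
A toy numerical check of that claim (a standalone sketch, assuming token ids 3 = "(" and 4 = ")" as in the config above): with uniform attention weights and value vectors that are +u for "(" and -u for ")", the head output is exactly (#opens - #closes)/n times u.

import torch

tokens = torch.tensor([3, 3, 4, 3, 4, 4])             # "(()())" -- balanced
u = torch.randn(56)                                    # an arbitrary output direction
values = torch.where((tokens == 3)[:, None], u, -u)    # +u for "(", -u for ")"
attn = torch.full((len(tokens),), 1 / len(tokens))     # uniform attention weights
head_out = attn @ values                               # = (#opens - #closes)/n * u
print(head_out.norm())                                 # ~0, since opens == closes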


Head 2.1: Negative Elevation Detector

Head 2.1 detects if elevation ever goes negative:

# More complex attention pattern
# Focuses on positions where cumulative sum might go negative

# Hypothesis: Uses information from earlier layers
# that track "running elevation" at each position

This head activates when ) appears before enough (.
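
A quick way to see where a negative failure occurs (a standalone sketch using the same cumulative-sum trick and the token ids from the config comment):

import torch

tokens = torch.tensor([3, 4, 4, 3])                    # "())("
changes = torch.where(tokens == 3, 1, -1)              # +1 for "(", -1 for ")"
elevation = torch.cumsum(changes, dim=-1)              # tensor([ 1,  0, -1,  0])
negative_positions = (elevation < 0).nonzero().squeeze(-1)
print(negative_positions)                              # tensor([2]) -- failure at index 2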


The Role of MLPs

MLPs contribute to the intermediate "running elevation" signal:

def analyze_mlp_neurons(model, data, layer):
    """
    Which neurons track the running elevation?
    """
    _, cache = model.run_with_cache(data.tokens)

    # Get post-activation neuron values
    neuron_acts = cache[f"blocks.{layer}.mlp.hook_post"]  # [batch, seq, d_mlp]

    # Running elevation at each position: +1 for '(', -1 for ')'
    # (token ids 3 and 4, per the vocab comment in the config above)
    changes = (data.tokens == 3).float() - (data.tokens == 4).float()
    elevation = torch.cumsum(changes, dim=-1)              # [batch, seq]

    # (Approximate) Pearson correlation of each neuron with running elevation
    acts = neuron_acts.flatten(0, 1)                        # [batch*seq, d_mlp]
    target = elevation.flatten()                            # [batch*seq]
    acts_c, target_c = acts - acts.mean(0), target - target.mean()
    corr = (acts_c * target_c[:, None]).mean(0) / (acts.std(0) * target.std() + 1e-8)

    return corr  # [d_mlp]

Key neurons compute running totals that feed into layer 2 heads.
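
With the sketch above, we can rank layer-1 neurons by how strongly they track running elevation:

corr = analyze_mlp_neurons(model, data, layer=1)
top_corr, top_neurons = corr.abs().topk(5)
print(top_neurons, top_corr)     # candidate "running elevation" accumulator neurons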


Handling LayerNorm

LayerNorm complicates direct logit attribution:

# The residual stream passes through LayerNorm before unembedding
# resid -> LN(resid) -> W_U -> logits

# Problem: LN is nonlinear (normalization depends on all components)

Solution: Fit a linear approximation:

from sklearn.linear_model import LinearRegression

def fit_layernorm_linear(model, data, seq_pos):
    """
    Approximate the final LayerNorm (ln_final) as a linear transformation.
    """
    _, cache = model.run_with_cache(data.tokens)

    # Inputs and outputs of ln_final at the chosen sequence position
    ln_input = cache[f"blocks.{model.cfg.n_layers-1}.hook_resid_post"][:, seq_pos]
    ln_output = cache["ln_final.hook_normalized"][:, seq_pos]

    X = ln_input.detach().cpu().numpy()
    Y = ln_output.detach().cpu().numpy()

    # Fit linear regression from pre-LN to post-LN residual stream
    fit = LinearRegression().fit(X, Y)

    # High R^2 means the linear approximation is good
    r2 = fit.score(X, Y)
    print(f"R^2 = {r2:.4f}")  # Often > 0.99!

    return torch.tensor(fit.coef_)  # [d_model, d_model]

Surprisingly, LayerNorm is nearly linear in this model!
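
A usage sketch: the fitted matrix lets us pull the unbalanced direction back to the pre-LayerNorm residual stream, so each component's output can be projected directly (seq_pos=0 because classification reads the [start] position; the fitted intercept is a constant and does not affect per-component attribution):

L_final = fit_layernorm_linear(model, data, seq_pos=0)       # [d_model, d_model]
unbalanced_dir = get_unbalanced_direction(model)
pre_ln_unbalanced_dir = L_final.float().T @ unbalanced_dir   # direction before ln_final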


MLPs as Neuron Collections

View the MLP as a bag of independent neurons:

def get_neuron_contributions(model, data, layer, direction):
    """
    Decompose MLP output by neuron.
    """
    _, cache = model.run_with_cache(data.tokens)

    # Post-ReLU activations
    neuron_acts = cache[f"blocks.{layer}.mlp.hook_post"]  # [batch, seq, d_mlp]

    # Each neuron's output direction
    W_out = model.W_out[layer]  # [d_mlp, d_model]

    # Contribution = activation * (output_direction . target_direction)
    output_projections = W_out @ direction  # [d_mlp]
    contributions = neuron_acts * output_projections  # [batch, seq, d_mlp]

    return contributions

This reveals which neurons matter for specific computations.
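
For example (a sketch that ignores the final LayerNorm correction discussed above), we can ask which layer-1 neurons push the position-0 residual stream toward "unbalanced" via the direct path:

unbalanced_dir = get_unbalanced_direction(model)
contribs = get_neuron_contributions(model, data, layer=1, direction=unbalanced_dir)
mean_at_pos0 = contribs[:, 0].mean(0)     # [d_mlp], averaged over the batch
print(mean_at_pos0.topk(5))               # neurons pushing hardest toward "unbalanced"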


The Full Circuit Diagram

Input: "(())"

Layer 0:
  - Embeddings encode token identity + position
  - Head 0.0: Duplicate token detection (finds matching brackets?)
  - Head 0.1: Position-based patterns
  - MLP 0: Computes per-position features

Layer 1:
  - Head 1.0/1.1: Aggregate information across positions
  - MLP 1: Compute "running elevation" at each position

Layer 2:
  - Head 2.0: Reads position 1, outputs total elevation
  - Head 2.1: Detects negative elevation anywhere
  - MLP 2: Final adjustments

Output (position 0):
  - Large positive = unbalanced
  - Large negative = balanced

Attention Patterns in Layer 0

Head 0.0 shows a distinctive pattern:

# Attention from position i to position j
# Head 0.0 attends to tokens at similar "elevation levels"

# For "(())", elevation profile is: 1, 2, 1, 0
# Position 0 (elev=1) attends to position 2 (elev=1)

This helps propagate matching-bracket information.


Testing Hypotheses with Interventions

def test_elevation_hypothesis(model, data):
    """
    If head 2.0 computes elevation, ablating it should break
    elevation-failure detection but not negative-failure detection.
    """
    def ablate_head_20(result, hook):
        result[:, :, 0] = 0  # Zero out head 2.0
        return result

    # Test on elevation failures only (correct label: class 0 = unbalanced)
    elevation_failures = data.tokens[data.has_elevation_failure]

    clean_preds = model(elevation_failures)[:, 0].argmax(-1)
    clean_acc = (clean_preds == 0).float().mean()

    ablated_logits = model.run_with_hooks(
        elevation_failures,
        fwd_hooks=[("blocks.2.attn.hook_result", ablate_head_20)]
    )
    ablated_acc = (ablated_logits[:, 0].argmax(-1) == 0).float().mean()

    print(f"Clean accuracy: {clean_acc:.2%}")
    print(f"Ablated accuracy: {ablated_acc:.2%}")
    # Expect a big drop in detecting elevation failures!

Adversarial Examples

Understanding the circuit reveals weaknesses:

# The model might fail on:

# 1. Very long sequences (out of distribution; note n_ctx=42, so anything
#    beyond 40 brackets can't even be run without extending positional embeddings)
adversarial_1 = "(" * 100 + ")" * 100

# 2. Patterns that stress the elevation tracking
adversarial_2 = "(" * 15 + ")(" + ")" * 15  # Balanced overall, but the embedded ")(" may trip the negative-failure detector

# 3. Edge cases in attention patterns
adversarial_3 = "()" * 19 + ")("  # Unbalanced: negative failure hidden at the very end
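
To actually run these, the raw strings need to be tokenized; here is a minimal sketch using a hypothetical to_tokens helper built from the vocab comment in the config ([start]=0, [end]=2, "("=3, ")"=4). adversarial_1 is skipped because it exceeds n_ctx:

import torch

def to_tokens(s: str) -> torch.Tensor:
    # Hypothetical helper: wrap the string in [start]/[end] and map brackets to ids
    ids = [0] + [3 if c == "(" else 4 for c in s] + [2]
    return torch.tensor([ids])

for s in [adversarial_2, adversarial_3]:
    prob_balanced = model(to_tokens(s))[:, 0].softmax(-1)[:, 1]
    print(f"{s[:12]}... P(balanced) = {prob_balanced.item():.3f}")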

What the OV Circuits Compute

For head 2.0:

def analyze_ov_circuit(model, tokenizer):
    """
    What does head 2.0's OV circuit do to bracket embeddings?
    """
    W_OV = model.W_V[2, 0] @ model.W_O[2, 0]

    # Get bracket embeddings
    open_embed = model.W_E[tokenizer.t_to_i["("]]
    close_embed = model.W_E[tokenizer.t_to_i[")"]]

    # Transform through OV
    open_transformed = open_embed @ W_OV
    close_transformed = close_embed @ W_OV

    # Check: Are these opposite directions?
    cosine_sim = torch.cosine_similarity(open_transformed, close_transformed, dim=0)
    print(f"Cosine similarity: {cosine_sim:.4f}")  # Expect ~ -1.0!

Open and close brackets map to opposite directions - perfect for computing elevation!


Capstone Connection

Bracket balancing and sycophancy share a pattern: both require tracking cumulative state across a sequence:

# Sycophancy might use similar circuits:
# - "User approval tracker" neurons that accumulate positive/negative signals
# - Final heads that read this tracker and bias toward agreement

# Hypothesis to test:
# 1. Find neurons that track "user sentiment" across conversation
# 2. Check if sycophantic responses correlate with high positive-sentiment accumulation
# 3. Test: Does ablating these neurons reduce sycophancy?

def find_sentiment_tracking_neurons(model, conversations):
    """
    Look for neurons whose activation correlates with
    cumulative positive/negative user sentiment.
    """
    # Similar analysis to bracket elevation tracking!
    pass

The bracket classifier teaches us how to find and verify "accumulator circuits."


Key Takeaways

  1. Clean algorithms emerge: The model learned exactly the cumulative-sum solution
  2. Heads specialize: Different heads handle different failure modes
  3. MLPs compute features: Neurons implement specific sub-computations
  4. LayerNorm is ~linear: Often well-approximated for interpretability
  5. Interventions verify hypotheses: Ablation tests confirm causal roles

🎓 Tyla's Exercise

  1. Prove that detecting balanced brackets requires Ω(log n) bits of state, where n is the sequence length. How does the model's d_model of 56 relate to this?

  2. The model uses bidirectional attention. Could a causal (autoregressive) model solve bracket classification? What architectural modifications would be needed?

  3. Head 2.0 attends uniformly across positions. Show mathematically that uniform attention + the right OV circuit computes exactly the sum of token embeddings.

  4. LayerNorm normalizes by subtracting mean and dividing by std. Under what conditions is this approximately linear?


💻 Aaliyah's Exercise

Implement the core analysis:

def get_component_contributions(model, tokens):
    """
    1. Run forward pass with caching
    2. Compute unbalanced direction
    3. Project each component's output onto this direction
    4. Return dict of {component_name: mean_contribution}
    """
    pass


def identify_failure_mode(model, unbalanced_string):
    """
    1. Compute elevation at each position
    2. Classify as 'elevation_failure', 'negative_failure', or 'both'
    3. Return failure type and relevant positions
    """
    pass


def ablate_and_test(model, data, component_name):
    """
    1. Run clean forward pass, record accuracy by failure type
    2. Ablate specified component (zero its output)
    3. Record accuracy by failure type again
    4. Report which failure types are affected
    """
    pass


def find_accumulator_neurons(model, data, target_quantity):
    """
    1. Compute target_quantity (e.g., running elevation) for each position
    2. Get neuron activations at each position
    3. Correlate each neuron with target_quantity
    4. Return neurons with high correlation
    """
    pass

📚 Maneesha's Reflection

  1. This toy model learns a clean algorithm matching human intuition. Do you expect language models to learn similarly clean algorithms for complex behaviors? Why or why not?

  2. The bracket task has a known optimal solution. For tasks without known solutions (like "be helpful"), how would you verify that you've found the "real" algorithm?

  3. Bidirectional attention seems natural for classification. Yet GPT (causal) models can also classify. What are the tradeoffs? When would you choose each?

  4. The model achieves near-perfect accuracy on brackets of length <= 40. What experiments would you run to understand its out-of-distribution generalization?