Grokking: When Models Suddenly Understand

Grokking reveals something profound: neural networks can memorize first, then generalize much later. Understanding how this happens with modular arithmetic teaches us how models discover algorithms.


What Is Grokking?

Grokking is delayed generalization: a model memorizes the training data perfectly, then long after training loss hits zero, suddenly learns to generalize.

Training Loss: ████████░░░░░░░░░░░░░░░░ → 0 (early)
Test Loss:     ████████████████████░░░░ → 0 (much later!)
                                  ↑
                   "Grokking" happens here

First observed by Power et al. (2022) on algorithmic tasks. The model memorizes lookup tables, then discovers the underlying algorithm.


The Modular Addition Task

The classic grokking setup:

Task: Learn $(x + y) \mod p$ for prime $p$ (typically $p = 113$)

# Input: two tokens x, y (each in range [0, p-1])
# Output: (x + y) mod p

# Example for p = 5:
# (2, 3) → 0  because (2 + 3) mod 5 = 0
# (4, 4) → 3  because (4 + 4) mod 5 = 3
# (1, 1) → 2  because (1 + 1) mod 5 = 2

Training setup:

• Enumerate all $p^2$ ordered pairs $(x, y)$; a fixed fraction (e.g. 30% of all pairs) is used for training and the remaining pairs are held out as the test set.
• The model is a small transformer (e.g. 1 layer, 128-dimensional residual stream), far more capacity than the task needs.
• Training is full-batch with an optimizer such as AdamW and substantial weight decay, run for tens of thousands of epochs.

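A minimal dataset-construction sketch (the split fraction and the seed are illustrative choices, not fixed by the task):

import torch

p = 113
pairs = torch.cartesian_prod(torch.arange(p), torch.arange(p))  # all (x, y) pairs, shape (p*p, 2)
labels = (pairs[:, 0] + pairs[:, 1]) % p

# Random train/test split over the full table of pairs
torch.manual_seed(0)
perm = torch.randperm(p * p)
n_train = int(0.3 * p * p)  # 30% train fraction, matching the exercise later in this section
train_pairs, train_labels = pairs[perm[:n_train]], labels[perm[:n_train]]
test_pairs, test_labels = pairs[perm[n_train:]], labels[perm[n_train:]]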

The Grokking Phenomenon

# Training curves look like this:
epochs = []
train_loss = []  # Drops to ~0 by epoch 1000
test_loss = []   # Stays high until epoch ~30,000, then drops!

# The model first memorizes the training set
# Then discovers the general algorithm
# This "sudden understanding" is grokking
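
A minimal plotting sketch, assuming the epochs, train_loss, and test_loss lists above are filled in during training; the delayed drop in test loss is easiest to see on a log-scaled epoch axis:

import matplotlib.pyplot as plt

plt.plot(epochs, train_loss, label="train loss")
plt.plot(epochs, test_loss, label="test loss")
plt.xscale("log")  # grokking curves are usually shown against log(epoch)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.show()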

Why does this happen?


How Does the Model Compute Modular Addition?

Here's the key insight: the model uses Fourier analysis on the cyclic group $\mathbb{Z}/p\mathbb{Z}$.

The residual stream learns to represent numbers using trigonometric functions:

$$\text{embed}(x) \approx [\cos(2\pi kx/p), \sin(2\pi kx/p)]_{k=1,2,...}$$


The Fourier Basis

For modular arithmetic, numbers live on a circle (they "wrap around"):

0 → 1 → 2 → ... → (p-1) → 0 → 1 → ...

The natural basis for functions on a circle is Fourier:

$$\cos(2\pi k x / p), \quad \sin(2\pi k x / p) \quad \text{for } k = 0, 1, \ldots, p-1$$

Only $k \le (p-1)/2$ gives independent functions (frequency $p - k$ reproduces frequency $k$ up to a sign flip in the sine), but the code below builds the full redundant set because it is simpler to index.

import torch
import numpy as np

p = 113  # Prime modulus

def fourier_basis(p):
    """Create Fourier basis for Z_p."""
    x = torch.arange(p)
    basis = []
    for k in range(p):
        cos_k = torch.cos(2 * np.pi * k * x / p)
        sin_k = torch.sin(2 * np.pi * k * x / p)
        basis.extend([cos_k, sin_k])
    return torch.stack(basis)  # Shape: (2p, p)

The Trigonometric Identity for Addition

The model exploits a fundamental identity:

$$\cos(a + b) = \cos(a)\cos(b) - \sin(a)\sin(b)$$
$$\sin(a + b) = \sin(a)\cos(b) + \cos(a)\sin(b)$$

This means: if you know $\cos(\theta_x)$, $\sin(\theta_x)$ and $\cos(\theta_y)$, $\sin(\theta_y)$, you can compute $\cos(\theta_x + \theta_y)$ and $\sin(\theta_x + \theta_y)$!

def compute_sum_fourier(cos_x, sin_x, cos_y, sin_y):
    """Compute cos(x+y) and sin(x+y) from components."""
    cos_sum = cos_x * cos_y - sin_x * sin_y
    sin_sum = sin_x * cos_y + cos_x * sin_y
    return cos_sum, sin_sum
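
A quick sanity check of the identity (reusing p and the imports from above; the particular k, x, y are arbitrary):

k, x, y = 3, 10, 50
theta_x = 2 * np.pi * k * x / p
theta_y = 2 * np.pi * k * y / p
cos_sum, sin_sum = compute_sum_fourier(
    np.cos(theta_x), np.sin(theta_x), np.cos(theta_y), np.sin(theta_y)
)
# Agrees with evaluating the angle of x + y directly
assert np.isclose(cos_sum, np.cos(2 * np.pi * k * (x + y) / p))
assert np.isclose(sin_sum, np.sin(2 * np.pi * k * (x + y) / p))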

The Algorithm: Step by Step

Here's what the trained model learns to do:

Step 1: EMBED
─────────────
Input tokens x, y → Fourier representation
x → [cos(2πkx/p), sin(2πkx/p)] for key frequencies k

Step 2: COMPUTE QUADRATICS (in MLP)
───────────────────────────────────
Compute products:
• cos(θx) · cos(θy)
• sin(θx) · sin(θy)
• cos(θx) · sin(θy)
• sin(θx) · cos(θy)

Step 3: COMBINE (via trig identity)
───────────────────────────────────
cos(θx + θy) = cos(θx)cos(θy) - sin(θx)sin(θy)
sin(θx + θy) = sin(θx)cos(θy) + cos(θx)sin(θy)

Step 4: UNEMBED
───────────────
Convert [cos(2πk(x+y)/p), sin(2πk(x+y)/p)] → logits

Key Frequencies

Not all frequencies are equally important. A trained model concentrates almost all of its embedding weight on a handful of key frequencies, typically five or six of the $(p-1)/2$ distinct frequencies. Which frequencies get chosen appears arbitrary and varies from run to run, so they are identified empirically: Fourier-transform the embedding and keep the frequencies with the largest coefficients.

def find_key_frequencies(W_E, p, top_k=6):
    """Find the frequencies the embedding concentrates on."""
    fourier = fourier_basis(p)           # (2p, p)
    coeffs = fourier @ W_E[:p].detach()  # (2p, d_model)
    # Combine the cos/sin rows of each frequency into a single magnitude
    mags = (coeffs[0::2] ** 2 + coeffs[1::2] ** 2).sum(dim=-1).sqrt()  # (p,)
    mags[0] = 0  # ignore the constant (k = 0) component
    return mags.topk(top_k).indices.sort().values.tolist()

Analyzing the Embedding

We can verify the model learns Fourier embeddings:

def analyze_embedding_fourier(model, p):
    """Check if embeddings are Fourier-like."""
    W_E = model.W_E  # Shape: (vocab_size, d_model)

    # Project onto Fourier basis
    fourier = fourier_basis(p)  # Shape: (2p, p)

    # For each embedding dimension, compute Fourier coefficients
    coeffs = fourier @ W_E[:p]  # Project onto Fourier basis; shape (2p, d_model)

    # Strong coefficients at key frequencies indicate Fourier structure
    return coeffs

# Expect: large coefficients in rows 2k and 2k+1 for the key frequencies k

Analyzing the MLP

The MLP computes the quadratic terms:

def analyze_mlp_quadratics(model, p, key_freqs=(1, 2, 4)):
    """
    Check whether MLP activations track the quadratic terms
    cos(θx)cos(θy), sin(θx)sin(θy), cos(θx)sin(θy), sin(θx)cos(θy)
    at the given frequencies.

    Approach:
    1. Run the model on all (x, y) pairs
    2. Collect MLP output activations at the answer position
    3. Correlate each activation dimension with the expected quadratic terms
    """
    # All (x, y) pairs as one batch (for p = 113 this is ~12.8k sequences; batch it if memory is tight)
    xs = torch.arange(p).repeat_interleave(p)
    ys = torch.arange(p).repeat(p)
    tokens = torch.stack([xs, ys], dim=1)  # (p^2, 2)

    _, cache = model.run_with_cache(tokens)
    acts = cache["mlp_out", 0][:, -1].detach()  # (p^2, d_model), last position

    # Expected quadratic terms, one column per (frequency, term) pair
    expected = []
    for k in key_freqs:  # default frequencies are illustrative; use the model's own
        theta_x = 2 * np.pi * k * xs / p
        theta_y = 2 * np.pi * k * ys / p
        expected += [
            torch.cos(theta_x) * torch.cos(theta_y),
            torch.sin(theta_x) * torch.sin(theta_y),
            torch.cos(theta_x) * torch.sin(theta_y),
            torch.sin(theta_x) * torch.cos(theta_y),
        ]
    expected = torch.stack(expected, dim=1)  # (p^2, 4 * len(key_freqs))

    # Pearson correlation between each expected quadratic and each MLP dimension
    acts_c = (acts - acts.mean(0)) / (acts.std(0) + 1e-8)
    exp_c = (expected - expected.mean(0)) / (expected.std(0) + 1e-8)
    corr = exp_c.T @ acts_c / acts.shape[0]  # (num_terms, d_model)

    # High |corr| = that MLP dimension computes that quadratic
    return corr
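
Usage might look like this (assuming a trained model is in scope; the frequencies passed in should be whichever ones the Fourier analysis of the embedding identifies):

corr = analyze_mlp_quadratics(model, p, key_freqs=(1, 2, 4))  # frequencies illustrative
best = corr.abs().max(dim=1)  # best-matching MLP dimension for each quadratic term
for term_idx, (val, dim) in enumerate(zip(best.values, best.indices)):
    print(f"quadratic term {term_idx}: MLP dim {dim.item()} (|corr| = {val.item():.2f})")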

Constructive and Destructive Interference

How does the model convert Fourier components to output logits?

Key insight: The unembedding uses interference patterns.

# For the correct answer z = (x + y) mod p:
# Fourier components add constructively

# For wrong answers z' != (x + y) mod p:
# Fourier components interfere destructively

def compute_output_interference(cos_sum, sin_sum, z, p, k):
    """
    Compute contribution to logit for answer z.

    For correct z = (x+y) mod p:
        cos(θ_{x+y}) · cos(θ_z) + sin(θ_{x+y}) · sin(θ_z)
      = cos(θ_{x+y} - θ_z) = cos(0) = 1  (constructive!)

    For wrong z:
        cos(θ_{x+y} - θ_z) = cos(2πk(x+y-z)/p) ≈ 0 on average
    """
    theta_z = 2 * np.pi * k * z / p
    contribution = cos_sum * np.cos(theta_z) + sin_sum * np.sin(theta_z)
    return contribution
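
Putting the pieces together, here is a NumPy-only sketch of the full read-out using the helpers defined above. The inputs and the frequency set are arbitrary illustrative choices; a trained model sums the same kind of per-frequency contributions through its unembedding:

x, y = 17, 45
freqs = [1, 2, 4, 7, 8]  # illustrative; a trained model picks its own frequencies

logits = np.zeros(p)
for k in freqs:
    theta_x = 2 * np.pi * k * x / p
    theta_y = 2 * np.pi * k * y / p
    cos_sum, sin_sum = compute_sum_fourier(
        np.cos(theta_x), np.sin(theta_x), np.cos(theta_y), np.sin(theta_y)
    )
    for z in range(p):
        logits[z] += compute_output_interference(cos_sum, sin_sum, z, p, k)

print(logits.argmax(), (x + y) % p)  # the peak lands on the correct answer: 62 62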

Visualizing the Algorithm

                    Input: (x, y)
                         │
                         ▼
    ┌─────────────────────────────────────────┐
    │              EMBEDDING                   │
    │  x → [cos(2πkx/p), sin(2πkx/p)]_k       │
    │  y → [cos(2πky/p), sin(2πky/p)]_k       │
    └─────────────────────────────────────────┘
                         │
                         ▼
    ┌─────────────────────────────────────────┐
    │              ATTENTION                   │
    │  Moves x information to y position      │
    │  (needed so MLP can see both)           │
    └─────────────────────────────────────────┘
                         │
                         ▼
    ┌─────────────────────────────────────────┐
    │                MLP                       │
    │  Computes quadratic terms:              │
    │  cos(θx)·cos(θy), sin(θx)·sin(θy), etc. │
    │  → Forms cos(θx+θy), sin(θx+θy)         │
    └─────────────────────────────────────────┘
                         │
                         ▼
    ┌─────────────────────────────────────────┐
    │            UNEMBEDDING                   │
    │  [cos(θ_{x+y}), sin(θ_{x+y})]           │
    │         → logits for each z             │
    │  Correct z gets constructive boost      │
    └─────────────────────────────────────────┘
                         │
                         ▼
                Output: (x + y) mod p

Why Grokking Happens

The phase transition has two competing solutions:

Solution        Train Loss   Test Loss   Complexity
Memorization    0            High        Cheap per example, high total weight norm
Algorithm       0            0           Higher fixed cost, lower total weight norm

Weight decay penalizes the total parameter norm. Once training loss is near zero, the regularization term dominates the gradient and slowly shrinks the weights, eventually favoring the lower-norm algorithmic solution; the test loss drops when that solution takes over.

# Simplified view of loss landscape
def effective_loss(weights, train_data, test_data, weight_decay):
    train_loss = compute_loss(weights, train_data)
    regularization = weight_decay * (weights ** 2).sum()

    # Early training: minimize train_loss (memorize)
    # Late training: regularization dominates → find simpler solution
    return train_loss + regularization

The Circuit in Practice

def run_grokking_model(model, x, y, p):
    """
    Trace the Fourier algorithm through the model.
    """
    # 1. Tokenize
    tokens = torch.tensor([[x, y]])

    # 2. Run with cache
    logits, cache = model.run_with_cache(tokens)

    # 3. Extract key activations
    embed_x = cache["embed"][0, 0]  # Embedding of x
    embed_y = cache["embed"][0, 1]  # Embedding of y

    attn_out = cache["attn_out", 0][0, 1]  # Attention moved x to y position
    mlp_out = cache["mlp_out", 0][0, 1]    # MLP computed quadratics

    final_resid = cache["resid_post", 0][0, 1]  # Final residual

    # 4. Check prediction
    predicted = logits[0, 1].argmax().item()
    actual = (x + y) % p

    return {
        "embed_x": embed_x,
        "embed_y": embed_y,
        "attn_out": attn_out,
        "mlp_out": mlp_out,
        "final_resid": final_resid,
        "predicted": predicted,
        "actual": actual,
        "correct": predicted == actual,
    }
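
Example usage (assumes model is a trained grokking transformer; the sampled pairs are arbitrary):

import random

random.seed(0)
for _ in range(5):
    x, y = random.randrange(113), random.randrange(113)
    out = run_grokking_model(model, x, y, p=113)
    print(f"({x} + {y}) % 113: predicted {out['predicted']}, "
          f"actual {out['actual']}, correct={out['correct']}")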

Fourier Analysis of Weights

def fourier_analyze_weights(W, p):
    """
    Decompose weight matrix into Fourier components.

    W: (d_model, p) - e.g., embedding or unembedding
    Returns: (d_model, 2p) - Fourier coefficients
    """
    fourier = fourier_basis(p)  # (2p, p)

    # W @ fourier.T gives coefficients
    # Each row of result: how much each Fourier component contributes
    coeffs = W @ fourier.T / p  # Normalize

    return coeffs

def plot_fourier_spectrum(coeffs, title="Fourier Spectrum"):
    """Plot the magnitude of Fourier coefficients."""
    import matplotlib.pyplot as plt

    magnitudes = (coeffs[:, ::2]**2 + coeffs[:, 1::2]**2).sqrt()
    # magnitudes[i, k] = strength of frequency k in dimension i

    plt.figure(figsize=(12, 4))
    plt.imshow(magnitudes.T.detach(), aspect='auto', cmap='hot')
    plt.xlabel("Model dimension")
    plt.ylabel("Frequency k")
    plt.title(title)
    plt.colorbar(label="Magnitude")
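
Typical usage, assuming a trained model whose embedding model.W_E has the p number tokens as its first rows:

import matplotlib.pyplot as plt

W_E = model.W_E[:p].detach()                # (p, d_model)
coeffs = fourier_analyze_weights(W_E.T, p)  # (d_model, 2p)
plot_fourier_spectrum(coeffs, title="Embedding in the Fourier basis")
plt.show()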

Progress Measures

Track whether the model has "grokked":

def measure_grokking_progress(model, p, train_pairs, test_pairs):
    """
    Compute metrics that track grokking.
    """
    metrics = {}

    # 1. Train/test accuracy
    metrics["train_acc"] = compute_accuracy(model, train_pairs, p)
    metrics["test_acc"] = compute_accuracy(model, test_pairs, p)

    # 2. Fourier coefficient strength (does the embedding look Fourier-sparse?)
    W_E = model.W_E[:p].detach()
    coeffs = fourier_analyze_weights(W_E.T, p)  # (d_model, 2p)
    mags = (coeffs[:, ::2] ** 2 + coeffs[:, 1::2] ** 2).sqrt()  # (d_model, p)
    key_freqs = find_key_frequencies(model.W_E, p)
    metrics["fourier_strength"] = mags[:, key_freqs].mean().item()

    # 3. Weight norm (are weights getting simpler?)
    metrics["weight_norm"] = sum(w.norm() for w in model.parameters()).item()

    return metrics

# Grokking signature:
# - train_acc hits 100% early
# - test_acc stays low, then jumps to 100%
# - fourier_strength increases during grokking
# - weight_norm decreases (simpler solution)
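
A sketch of how these metrics might be logged over a run; model, optimizer, train_pairs, test_pairs, and the train_one_epoch helper are placeholders here, not part of any particular library:

history = []
for epoch in range(50_000):
    train_one_epoch(model, optimizer, train_pairs)  # hypothetical training-step helper
    if epoch % 100 == 0:
        metrics = measure_grokking_progress(model, p, train_pairs, test_pairs)
        history.append({"epoch": epoch, **metrics})
        # Watch for: train_acc saturating early, then test_acc and fourier_strength
        # jumping while weight_norm falls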

Capstone Connection

Grokking and sycophancy detection:

Grokking teaches us that models can discover algorithms hidden in seemingly arbitrary weights. This connects to your sycophancy capstone in several ways:

# 1. ALGORITHMIC REPRESENTATIONS
# Just as modular addition uses Fourier bases,
# sycophancy might use specific representational schemes

# 2. PHASE TRANSITIONS IN BEHAVIOR
# Like grokking, sycophancy might "turn on" at specific layers
# or only activate under certain conditions

def analyze_sycophancy_circuit(model, sycophantic_prompts, honest_prompts):
    """
    Apply grokking-style analysis to sycophancy:
    - What basis does the model use for user sentiment?
    - How are agreement/disagreement computed?
    - Where does the "sycophancy algorithm" live?
    """
    # Look for the computational structure
    # Just like we found Fourier structure in mod addition

    # Step 1: Find the representation of user opinion
    # Step 2: Find where agreement/disagreement is computed
    # Step 3: Find the circuit that converts this to sycophantic output

# 3. WEIGHT ANALYSIS
# Fourier analysis found structure in embeddings
# Similar techniques might reveal sycophancy circuits

The deeper lesson: Mechanistic interpretability works because neural networks discover structured algorithms. Grokking is proof that even simple networks find elegant mathematical solutions. Your capstone applies these techniques to find the structure underlying sycophantic behavior.


🎓 Tyla's Exercise

  1. Prove the Fourier orthogonality identity: Show that $\sum_{x=0}^{p-1} \cos(2\pi k x / p) \cos(2\pi k' x / p) = \frac{p}{2}\delta_{k,k'}$ for $1 \le k, k' \le (p-1)/2$ (and similarly for the sines). Why does this orthogonality enable the interference-based output computation?

  2. Analyze the required nonlinearity: The algorithm needs quadratic terms (products of sines and cosines). Prove that no purely linear map of the two embeddings can produce these products, so the nonlinearity in the attention and MLP is essential. What is the smallest architecture you could construct by hand that computes modular addition this way?

  3. Generalize to other operations: The Fourier approach works for addition. Sketch how you might modify it for multiplication $(x \cdot y) \mod p$. What additional structure would the model need to learn?


💻 Aaliyah's Exercise

Build a grokking analysis toolkit:

def train_grokking_model(p=113, train_fraction=0.3, epochs=50000):
    """
    1. Create the modular addition dataset
    2. Split into train/test (train_fraction of all pairs)
    3. Train a 1-layer transformer
    4. Log train/test accuracy every 100 epochs
    5. Return model and training curves
    """
    # Hyperparameters that encourage grokking:
    # - High weight decay (0.1 or more)
    # - Full batch training
    # - Small model (1 layer, 128 dim)
    pass

def analyze_trained_grokking_model(model, p):
    """
    1. Extract and visualize Fourier components of W_E
    2. Find which frequencies are used (should be key frequencies)
    3. Verify MLP computes quadratic terms
    4. Show constructive interference for correct answers
    5. Produce a summary: "This model uses frequencies k=... to compute mod-p addition"
    """
    pass

def intervene_on_grokking(model, p, freq_to_ablate):
    """
    1. Identify the Fourier direction for frequency k
    2. Project it out of the embedding/unembedding
    3. Measure impact on accuracy
    4. Hypothesis: Ablating key frequencies breaks the model
    """
    pass

📚 Maneesha's Reflection

  1. On discovery vs design: The model was trained on input-output pairs, not taught the Fourier algorithm. It discovered this elegant solution through gradient descent. What does this tell us about the relationship between learning and understanding? Does a model that uses Fourier analysis "understand" modular arithmetic?

  2. On interpretability methodology: We verified the Fourier hypothesis by analyzing weights and activations. But we started with a hypothesis (Fourier analysis is natural for cyclic groups). How would you discover such structure if you didn't already know what to look for? What general principles guide mechanistic interpretability?

  3. On grokking as a teaching moment: Grokking shows that train loss can be misleading - a model with zero train loss might not have learned anything generalizable. How would you modify standard ML education to incorporate this insight? What visualizations or metrics would help practitioners detect whether their model has truly learned vs. memorized?