Grokking: When Models Suddenly Understand
Grokking reveals something profound: neural networks can memorize first, then generalize much later. Understanding how this happens with modular arithmetic teaches us how models discover algorithms.
What Is Grokking?
Grokking is delayed generalization: a model memorizes the training data perfectly, then long after training loss hits zero, suddenly learns to generalize.
Training Loss: ████████░░░░░░░░░░░░░░░░ → 0 (early)
Test Loss: ████████████████████░░░░ → 0 (much later!)
↑
"Grokking" happens here
First observed by Power et al. (2022) on algorithmic tasks. The model memorizes lookup tables, then discovers the underlying algorithm.
The Modular Addition Task
The classic grokking setup:
Task: Learn $(x + y) \mod p$ for prime $p$ (typically $p = 113$)
# Input: two tokens x, y (each in range [0, p-1])
# Output: (x + y) mod p
# Example for p = 5:
# (2, 3) → 0 because (2 + 3) mod 5 = 0
# (4, 4) → 3 because (4 + 4) mod 5 = 3
# (1, 1) → 2 because (1 + 1) mod 5 = 2
Training setup:
- Vocabulary: $[0, 1, 2, ..., p-1]$
- Input format: `[x, y]` (two tokens)
- Output: single token representing $(x + y) \mod p$
- Train on ~30% of all $(x, y)$ pairs
- Test on the remaining ~70%
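For concreteness, here is a minimal sketch of the dataset construction and split (the function name and seed handling are my own, not from any particular codebase):

```python
import itertools
import random

def make_mod_add_dataset(p=113, train_fraction=0.3, seed=0):
    """All (x, y) pairs labeled with (x + y) mod p, shuffled and split."""
    pairs = [(x, y, (x + y) % p) for x, y in itertools.product(range(p), repeat=2)]
    random.Random(seed).shuffle(pairs)
    n_train = int(train_fraction * len(pairs))
    return pairs[:n_train], pairs[n_train:]

train, test = make_mod_add_dataset()
# Every one of the p*p pairs lands in exactly one split
assert len(train) + len(test) == 113 * 113
```

Note that the test set contains pairs the model has literally never seen, so the only way to get them right is to learn the underlying rule.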
The Grokking Phenomenon
# Typical training curves for this task:
#   train_loss: drops to ~0 by epoch ~1,000
#   test_loss:  stays high until epoch ~30,000, then suddenly drops!
#
# The model first memorizes the training set,
# then discovers the general algorithm.
# This "sudden understanding" is grokking.
Why does this happen?
- Memorization is the "easy" solution (low complexity per sample)
- Generalization requires discovering structure
- Weight decay slowly pushes toward simpler (generalizing) solutions
How Does the Model Compute Modular Addition?
Here's the key insight: the model uses Fourier analysis on the cyclic group $\mathbb{Z}/p\mathbb{Z}$.
The residual stream learns to represent numbers using trigonometric functions:
$$\text{embed}(x) \approx [\cos(2\pi kx/p), \sin(2\pi kx/p)]_{k=1,2,...}$$
The Fourier Basis
For modular arithmetic, numbers live on a circle (they "wrap around"):
0 → 1 → 2 → ... → (p-1) → 0 → 1 → ...
The natural basis for functions on a circle is Fourier:
$$\cos(2\pi k x / p), \quad \sin(2\pi k x / p) \quad \text{for } k = 0, 1, ..., p-1$$
import torch
import numpy as np

p = 113  # Prime modulus

def fourier_basis(p):
    """Create the Fourier basis for Z_p."""
    x = torch.arange(p)
    basis = []
    for k in range(p):
        cos_k = torch.cos(2 * np.pi * k * x / p)
        sin_k = torch.sin(2 * np.pi * k * x / p)
        basis.extend([cos_k, sin_k])
    return torch.stack(basis)  # Shape: (2p, p)
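A quick standalone sanity check (numpy rather than torch, my own sketch): distinct frequencies in this basis are orthogonal over $\mathbb{Z}_p$, which is what makes the decomposition well-defined.

```python
import numpy as np

p = 113
x = np.arange(p)
cos = lambda k: np.cos(2 * np.pi * k * x / p)

# Distinct frequencies are orthogonal; a frequency against itself gives p/2
assert abs(cos(3) @ cos(5)) < 1e-8
assert abs(cos(3) @ cos(3) - p / 2) < 1e-8
```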
The Trigonometric Identity for Addition
The model exploits a fundamental identity:
$$\cos(a + b) = \cos(a)\cos(b) - \sin(a)\sin(b)$$ $$\sin(a + b) = \sin(a)\cos(b) + \cos(a)\sin(b)$$
This means: if you know $\cos(\theta_x)$, $\sin(\theta_x)$ and $\cos(\theta_y)$, $\sin(\theta_y)$, you can compute $\cos(\theta_x + \theta_y)$ and $\sin(\theta_x + \theta_y)$!
def compute_sum_fourier(cos_x, sin_x, cos_y, sin_y):
    """Compute cos(x+y) and sin(x+y) from components."""
    cos_sum = cos_x * cos_y - sin_x * sin_y
    sin_sum = sin_x * cos_y + cos_x * sin_y
    return cos_sum, sin_sum
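A standalone numerical check (my own sketch) that this identity really handles the modular wraparound: the angle $2\pi k(x+y)/p$ is the same, mod $2\pi$, as the angle for $(x+y) \bmod p$.

```python
import numpy as np

p, k = 113, 5
x, y = 100, 50  # x + y = 150, which wraps to 37 mod 113
tx, ty = 2 * np.pi * k * x / p, 2 * np.pi * k * y / p

cos_sum = np.cos(tx) * np.cos(ty) - np.sin(tx) * np.sin(ty)
sin_sum = np.sin(tx) * np.cos(ty) + np.cos(tx) * np.sin(ty)

# The result matches the Fourier features of the *wrapped* sum
z = (x + y) % p
assert np.isclose(cos_sum, np.cos(2 * np.pi * k * z / p))
assert np.isclose(sin_sum, np.sin(2 * np.pi * k * z / p))
```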
The Algorithm: Step by Step
Here's what the trained model learns to do:
Step 1: EMBED
─────────────
Input tokens x, y → Fourier representation
x → [cos(2πkx/p), sin(2πkx/p)] for key frequencies k
Step 2: COMPUTE QUADRATICS (in MLP)
───────────────────────────────────
Compute products:
• cos(θx) · cos(θy)
• sin(θx) · sin(θy)
• cos(θx) · sin(θy)
• sin(θx) · cos(θy)
Step 3: COMBINE (via trig identity)
───────────────────────────────────
cos(θx + θy) = cos(θx)cos(θy) - sin(θx)sin(θy)
sin(θx + θy) = sin(θx)cos(θy) + cos(θx)sin(θy)
Step 4: UNEMBED
───────────────
Convert [cos(2πk(x+y)/p), sin(2πk(x+y)/p)] → logits
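The four steps above can be written out as a hand-coded reference implementation (an idealized sketch in numpy, not the trained network; the particular frequencies below are arbitrary, whereas a trained model picks its own handful):

```python
import numpy as np

def mod_add_via_fourier(x, y, p=113, freqs=(3, 17, 42)):
    """Hand-coded version of the learned algorithm for (x + y) mod p."""
    z = np.arange(p)      # all candidate answers
    logits = np.zeros(p)
    for k in freqs:
        # Step 1: Fourier features of x and y
        tx, ty = 2 * np.pi * k * x / p, 2 * np.pi * k * y / p
        # Steps 2-3: quadratic terms combined via the angle-addition identities
        cos_sum = np.cos(tx) * np.cos(ty) - np.sin(tx) * np.sin(ty)
        sin_sum = np.sin(tx) * np.cos(ty) + np.cos(tx) * np.sin(ty)
        # Step 4: interference against each candidate z
        logits += (cos_sum * np.cos(2 * np.pi * k * z / p)
                   + sin_sum * np.sin(2 * np.pi * k * z / p))
    return int(logits.argmax())

assert all(mod_add_via_fourier(x, y) == (x + y) % 113
           for x, y in [(0, 0), (1, 112), (56, 99), (100, 100)])
```

Each frequency contributes $\cos(2\pi k(x+y-z)/p)$ to the logit for $z$, which equals 1 exactly when $z = (x+y) \bmod p$, so even a few frequencies suffice for a clean argmax.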
Key Frequencies
Not all frequencies are equally important. Empirically, a trained model concentrates almost all of its Fourier energy on a small handful of key frequencies $k$ (around five or six for $p = 113$), and which particular frequencies it picks varies from training run to training run. The reliable way to identify them is to look at where the embedding's Fourier coefficients are large.
def find_key_frequencies(model, p, threshold=0.1):
    """Find the frequencies the model actually uses, from its embedding."""
    W_E = model.W_E[:p]         # (p, d_model)
    fourier = fourier_basis(p)  # (2p, p)
    coeffs = fourier @ W_E      # (2p, d_model)
    # Norm of the (cos_k, sin_k) coefficient pair for each frequency k
    norms = (coeffs[0::2] ** 2 + coeffs[1::2] ** 2).sum(dim=-1).sqrt()  # (p,)
    return [k for k in range(1, p) if norms[k] > threshold * norms.max()]
Analyzing the Embedding
We can verify the model learns Fourier embeddings:
def analyze_embedding_fourier(model, p):
    """Check if the embeddings are Fourier-like."""
    W_E = model.W_E  # Shape: (vocab_size, d_model)
    # Project the number embeddings onto the Fourier basis
    fourier = fourier_basis(p)  # Shape: (2p, p)
    coeffs = fourier @ W_E[:p]  # Shape: (2p, d_model)
    # Strong coefficients at key frequencies indicate Fourier structure
    return coeffs

# Expect: large coefficients in rows 2k and 2k+1 for key frequencies k
Analyzing the MLP
The MLP computes the quadratic terms:
def analyze_mlp_quadratics(model, p, key_freqs=(1, 2, 4)):
    """
    The MLP should compute cos(x)cos(y), sin(x)sin(y), etc.
    We can verify by:
      1. Running the model on all (x, y) pairs
      2. Looking at MLP activations
      3. Checking if they correlate with the expected quadratic terms
    (Pass the model's actual key frequencies as key_freqs.)
    """
    all_pairs = [(x, y) for x in range(p) for y in range(p)]
    # Get MLP activations at the last position
    activations = []
    for x, y in all_pairs:
        tokens = torch.tensor([[x, y]])
        _, cache = model.run_with_cache(tokens)
        activations.append(cache["mlp_out", 0][0, -1])
    activations = torch.stack(activations)  # (p^2, d_model)
    # Compute the expected quadratics: one row per (x, y) pair
    expected = []
    for x, y in all_pairs:
        row = []
        for k in key_freqs:
            theta_x = 2 * np.pi * k * x / p
            theta_y = 2 * np.pi * k * y / p
            row += [np.cos(theta_x) * np.cos(theta_y),
                    np.sin(theta_x) * np.sin(theta_y)]
        expected.append(row)
    expected = torch.tensor(expected)  # (p^2, 2 * len(key_freqs))
    # Correlate each activation dimension with each expected quadratic;
    # high correlation = that dimension computes that quadratic term
    return activations, expected
Constructive and Destructive Interference
How does the model convert Fourier components to output logits?
Key insight: The unembedding uses interference patterns.
# For the correct answer z = (x + y) mod p:
# Fourier components add constructively
# For wrong answers z' != (x + y) mod p:
# Fourier components interfere destructively
def compute_output_interference(cos_sum, sin_sum, z, p, k):
    """
    Compute the contribution to the logit for answer z.

    For the correct z = (x+y) mod p:
        cos(θ_{x+y})·cos(θ_z) + sin(θ_{x+y})·sin(θ_z)
        = cos(θ_{x+y} - θ_z) = cos(0) = 1  (constructive!)
    For wrong z:
        cos(θ_{x+y} - θ_z) = cos(2πk(x+y-z)/p) ≈ 0 on average
    """
    theta_z = 2 * np.pi * k * z / p
    contribution = cos_sum * np.cos(theta_z) + sin_sum * np.sin(theta_z)
    return contribution
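The "≈ 0 on average" claim can be made exact with a standalone check (my own sketch): summed over all nonzero frequencies, a wrong answer's contributions cancel to exactly −1, versus p − 1 for the correct answer.

```python
import numpy as np

p = 113
k = np.arange(1, p)  # all nonzero frequencies

def total_contribution(d):
    """Sum of cos(2*pi*k*d/p) over k = 1..p-1, where d = (x + y) - z."""
    return np.cos(2 * np.pi * k * d / p).sum()

assert np.isclose(total_contribution(0), p - 1)  # correct z: fully constructive
assert np.isclose(total_contribution(5), -1.0)   # wrong z: cancels to -1
```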
Visualizing the Algorithm
Input: (x, y)
│
▼
┌─────────────────────────────────────────┐
│ EMBEDDING │
│ x → [cos(2πkx/p), sin(2πkx/p)]_k │
│ y → [cos(2πky/p), sin(2πky/p)]_k │
└─────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────┐
│ ATTENTION │
│ Moves x information to y position │
│ (needed so MLP can see both) │
└─────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────┐
│ MLP │
│ Computes quadratic terms: │
│ cos(θx)·cos(θy), sin(θx)·sin(θy), etc. │
│ → Forms cos(θx+θy), sin(θx+θy) │
└─────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────┐
│ UNEMBEDDING │
│ [cos(θ_{x+y}), sin(θ_{x+y})] │
│ → logits for each z │
│ Correct z gets constructive boost │
└─────────────────────────────────────────┘
│
▼
Output: (x + y) mod p
Why Grokking Happens
The phase transition has two competing solutions:
| Solution | Train Loss | Test Loss | Complexity |
|---|---|---|---|
| Memorization | 0 | High | Grows with the training set (high total weight norm) |
| Algorithm | 0 | 0 | Fixed (lower total weight norm) |
Weight decay penalizes total parameter norm, eventually favoring the algorithm.
# Simplified view of the loss landscape
def effective_loss(weights, train_data, weight_decay):
    train_loss = compute_loss(weights, train_data)
    regularization = weight_decay * (weights ** 2).sum()
    # Early training: train_loss dominates → memorize
    # Late training: regularization dominates → find the simpler solution
    return train_loss + regularization
The Circuit in Practice
def run_grokking_model(model, x, y, p):
    """Trace the Fourier algorithm through the model."""
    # 1. Tokenize
    tokens = torch.tensor([[x, y]])
    # 2. Run with cache
    logits, cache = model.run_with_cache(tokens)
    # 3. Extract key activations
    embed_x = cache["embed"][0, 0]              # Embedding of x
    embed_y = cache["embed"][0, 1]              # Embedding of y
    attn_out = cache["attn_out", 0][0, 1]       # Attention moved x to y's position
    mlp_out = cache["mlp_out", 0][0, 1]         # MLP computed the quadratics
    final_resid = cache["resid_post", 0][0, 1]  # Final residual stream
    # 4. Check the prediction
    predicted = logits[0, 1].argmax().item()
    actual = (x + y) % p
    return {
        "embed_x": embed_x,
        "embed_y": embed_y,
        "attn_out": attn_out,
        "mlp_out": mlp_out,
        "final_resid": final_resid,
        "predicted": predicted,
        "actual": actual,
        "correct": predicted == actual,
    }
Fourier Analysis of Weights
def fourier_analyze_weights(W, p):
    """
    Decompose a weight matrix into Fourier components.
    W: (d_model, p) - e.g., embedding or unembedding
    Returns: (d_model, 2p) - Fourier coefficients
    """
    fourier = fourier_basis(p)  # (2p, p)
    # Each row of the result: how much each Fourier component contributes
    coeffs = W @ fourier.T / p  # Rough normalization (basis rows have squared norm ~ p/2)
    return coeffs
def plot_fourier_spectrum(coeffs, title="Fourier Spectrum"):
    """Plot the magnitude of Fourier coefficients."""
    import matplotlib.pyplot as plt
    # magnitudes[i, k] = strength of frequency k in dimension i
    magnitudes = (coeffs[:, ::2] ** 2 + coeffs[:, 1::2] ** 2).sqrt()
    plt.figure(figsize=(12, 4))
    plt.imshow(magnitudes.T.detach(), aspect='auto', cmap='hot')
    plt.xlabel("Model dimension")
    plt.ylabel("Frequency k")
    plt.title(title)
    plt.colorbar(label="Magnitude")
    plt.show()
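As a sanity check on the decomposition itself (a standalone numpy sketch with a synthetic weight matrix, not a trained model): a matrix built from a single frequency should light up only that frequency's coefficients.

```python
import numpy as np

p = 113
x = np.arange(p)
# Basis rows ordered [cos_0, sin_0, cos_1, sin_1, ...], shape (2p, p)
basis = np.stack([f(2 * np.pi * k * x / p)
                  for k in range(p) for f in (np.cos, np.sin)])

k = 7
W = np.tile(np.cos(2 * np.pi * k * x / p), (4, 1))  # fake (d_model=4, p) weights

coeffs = W @ basis.T / p  # shape (4, 2p)
assert np.isclose(coeffs[0, 2 * k], 0.5)      # strong cos-7 coefficient
assert np.abs(coeffs[0]).max() <= 0.5 + 1e-9  # nothing else is stronger
```

One subtlety: frequency $p - k$ has the same cosine values as frequency $k$, so its coefficient is equally large; analyses usually restrict attention to $k \le (p-1)/2$.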
Progress Measures
Track whether the model has "grokked":
def measure_grokking_progress(model, p, train_pairs, test_pairs):
    """Compute metrics that track grokking."""
    metrics = {}
    # 1. Train/test accuracy
    metrics["train_acc"] = compute_accuracy(model, train_pairs, p)
    metrics["test_acc"] = compute_accuracy(model, test_pairs, p)
    # 2. Fourier structure: does the embedding concentrate on a few frequencies?
    W_E = model.W_E[:p]
    coeffs = fourier_analyze_weights(W_E.T, p)  # (d_model, 2p)
    magnitudes = (coeffs[:, ::2] ** 2 + coeffs[:, 1::2] ** 2).sqrt()  # (d_model, p)
    freq_energy = magnitudes.pow(2).sum(dim=0)  # total energy per frequency
    metrics["fourier_strength"] = (freq_energy.topk(6).values.sum()
                                   / freq_energy.sum()).item()
    # 3. Weight norm (are the weights getting simpler?)
    # Note: don't shadow the modulus p with the loop variable
    metrics["weight_norm"] = sum(w.norm() for w in model.parameters()).item()
    return metrics
# Grokking signature:
# - train_acc hits 100% early
# - test_acc stays low, then jumps to 100%
# - fourier_strength increases during grokking
# - weight_norm decreases (simpler solution)
Capstone Connection
Grokking and sycophancy detection:
Grokking teaches us that models can discover algorithms hidden in seemingly arbitrary weights. This connects to your sycophancy capstone in several ways:
# 1. ALGORITHMIC REPRESENTATIONS
# Just as modular addition uses Fourier bases,
# sycophancy might use specific representational schemes
# 2. PHASE TRANSITIONS IN BEHAVIOR
# Like grokking, sycophancy might "turn on" at specific layers
# or only activate under certain conditions
def analyze_sycophancy_circuit(model, sycophantic_prompts, honest_prompts):
    """
    Apply grokking-style analysis to sycophancy:
    - What basis does the model use for user sentiment?
    - How are agreement/disagreement computed?
    - Where does the "sycophancy algorithm" live?
    """
    # Look for the computational structure,
    # just as we found Fourier structure in modular addition:
    # Step 1: Find the representation of the user's opinion
    # Step 2: Find where agreement/disagreement is computed
    # Step 3: Find the circuit that converts this into sycophantic output
    pass
# 3. WEIGHT ANALYSIS
# Fourier analysis found structure in embeddings
# Similar techniques might reveal sycophancy circuits
The deeper lesson: Mechanistic interpretability works because neural networks discover structured algorithms. Grokking is proof that even simple networks find elegant mathematical solutions. Your capstone applies these techniques to find the structure underlying sycophantic behavior.
🎓 Tyla's Exercise
Prove the Fourier identity: Show that $\sum_{k=0}^{p-1} \cos(2\pi k z / p) \cos(2\pi k z' / p) = \frac{p}{2}\,\delta_{z,z'}$ for $1 \le z, z' \le p-1$ with $z + z' \ne p$ (Fourier orthogonality over frequencies). Why does this enable the interference-based output computation?
Analyze the circuit depth: The algorithm requires computing quadratic terms (products of sines and cosines). Prove that no purely linear map of the embeddings can produce these products, so the MLP's nonlinearity is essential. Why does a single nonlinear layer suffice?
Generalize to other operations: The Fourier approach works for addition. Sketch how you might modify it for multiplication $(x \cdot y) \mod p$. What additional structure would the model need to learn?
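A hint for the last part (a standard number-theory fact, not a claim about what trained models actually do): the nonzero residues mod $p$ form a cyclic group of order $p - 1$, so multiplication turns into addition of discrete logarithms, and the same Fourier machinery can in principle operate in log-space.

```python
# For p = 13, g = 2 generates the multiplicative group mod 13
p, g = 13, 2
log = {pow(g, e, p): e for e in range(p - 1)}  # discrete-log lookup table

for x in range(1, p):
    for y in range(1, p):
        # multiplication mod p == addition of logs mod (p - 1)
        assert (x * y) % p == pow(g, (log[x] + log[y]) % (p - 1), p)
```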
💻 Aaliyah's Exercise
Build a grokking analysis toolkit:
def train_grokking_model(p=113, train_fraction=0.3, epochs=50000):
    """
    1. Create the modular addition dataset
    2. Split into train/test (train_fraction of all pairs)
    3. Train a 1-layer transformer
    4. Log train/test accuracy every 100 epochs
    5. Return the model and training curves
    """
    # Hyperparameters that encourage grokking:
    # - High weight decay (0.1 or more)
    # - Full-batch training
    # - Small model (1 layer, 128-dim)
    pass

def analyze_trained_grokking_model(model, p):
    """
    1. Extract and visualize the Fourier components of W_E
    2. Find which frequencies are used (should be the key frequencies)
    3. Verify the MLP computes the quadratic terms
    4. Show constructive interference for correct answers
    5. Produce a summary: "This model uses frequencies k=... to compute mod-p addition"
    """
    pass

def intervene_on_grokking(model, p, freq_to_ablate):
    """
    1. Identify the Fourier direction for frequency k
    2. Project it out of the embedding/unembedding
    3. Measure the impact on accuracy
    4. Hypothesis: ablating key frequencies breaks the model
    """
    pass
📚 Maneesha's Reflection
On discovery vs design: The model was trained on input-output pairs, not taught the Fourier algorithm. It discovered this elegant solution through gradient descent. What does this tell us about the relationship between learning and understanding? Does a model that uses Fourier analysis "understand" modular arithmetic?
On interpretability methodology: We verified the Fourier hypothesis by analyzing weights and activations. But we started with a hypothesis (Fourier analysis is natural for cyclic groups). How would you discover such structure if you didn't already know what to look for? What general principles guide mechanistic interpretability?
On grokking as a teaching moment: Grokking shows that train loss can be misleading - a model with zero train loss might not have learned anything generalizable. How would you modify standard ML education to incorporate this insight? What visualizations or metrics would help practitioners detect whether their model has truly learned vs. memorized?