# Sparse Autoencoders: Untangling Features
If superposition is the disease, sparse autoencoders are the treatment.
## The SAE Idea
Expand the compressed space back into interpretable features:
```
Residual Stream (768D) → SAE Encoder → Latent Space (16000D) → SAE Decoder → Reconstructed (768D)
                                            ↓
                              Sparse, interpretable features
```
The key constraint: sparsity. Only a few latents should be active at once.
## SAE Architecture

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SparseAutoencoder(nn.Module):
    def __init__(self, d_model, n_latents):
        super().__init__()
        self.encoder = nn.Linear(d_model, n_latents)
        self.decoder = nn.Linear(n_latents, d_model, bias=False)

    def forward(self, x):
        # Encode
        pre_acts = self.encoder(x)
        latents = F.relu(pre_acts)  # Sparsity via ReLU
        # Decode
        reconstructed = self.decoder(latents)
        return reconstructed, latents
```
The expansion ratio is crucial: typically 8x-64x the model dimension.
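As a concrete sketch of those sizes, here is the class above instantiated with an 8x expansion for GPT-2 small's 768-dimensional residual stream (the batch of random activations is just a stand-in for real model activations):

```python
# Hypothetical sizes: GPT-2 small's residual stream is 768-dimensional;
# an 8x expansion gives 6,144 latents, a 64x expansion would give 49,152.
d_model, expansion = 768, 8
sae = SparseAutoencoder(d_model=d_model, n_latents=expansion * d_model)

x = torch.randn(32, d_model)               # fake batch of activations
reconstructed, latents = sae(x)
print(reconstructed.shape, latents.shape)  # torch.Size([32, 768]) torch.Size([32, 6144])
```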
## The SAE Loss Function

Two competing objectives:

```python
def sae_loss(x, reconstructed, latents, l1_coeff):
    # 1. Reconstruction: match the original
    reconstruction_loss = (x - reconstructed).pow(2).mean()
    # 2. Sparsity: few latents should be active
    sparsity_loss = latents.abs().mean()
    return reconstruction_loss + l1_coeff * sparsity_loss
```
The `l1_coeff` controls the trade-off, as the sweep sketched below illustrates:

- Too low → dense, uninterpretable latents
- Too high → poor reconstruction
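A minimal sketch of that trade-off, assuming the `SparseAutoencoder` and `sae_loss` defined above; random Gaussian data stands in for real activations, and the coefficient values are illustrative rather than tuned:

```python
# Train the same SAE briefly under different L1 coefficients and compare
# reconstruction error (MSE) against sparsity (L0 = average active latents).
data = torch.randn(4096, 768)

for l1_coeff in [0.0, 1e-3, 1e-1]:
    sae = SparseAutoencoder(d_model=768, n_latents=8 * 768)
    opt = torch.optim.Adam(sae.parameters(), lr=1e-3)
    for step in range(200):
        x = data[torch.randint(0, len(data), (256,))]
        reconstructed, latents = sae(x)
        loss = sae_loss(x, reconstructed, latents, l1_coeff)
        opt.zero_grad()
        loss.backward()
        opt.step()
    with torch.no_grad():
        reconstructed, latents = sae(data)
        mse = (data - reconstructed).pow(2).mean().item()
        l0 = (latents > 0).float().sum(dim=-1).mean().item()
    print(f"l1_coeff={l1_coeff:g}  mse={mse:.4f}  avg active latents={l0:.0f}")
```

Larger `l1_coeff` drives L0 down (sparser codes) at the cost of higher reconstruction error.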
## What SAEs Learn

After training, each SAE latent corresponds to a feature:

```
# Example latents from a trained SAE on GPT-2
latent_2847  → "code context"    (activates on programming)
latent_9123  → "past tense"      (activates on -ed endings)
latent_15234 → "scientific text" (activates on academic language)
```
These are more interpretable than neurons!
## SAE Variants

**Gated SAE** (DeepMind), shown here as a simplified sketch:

```python
class GatedSAE(nn.Module):
    def __init__(self, d_model, n_latents):
        super().__init__()
        self.gate = nn.Linear(d_model, n_latents)       # decides which latents fire
        self.magnitude = nn.Linear(d_model, n_latents)  # decides how strongly they fire
        self.decoder = nn.Linear(n_latents, d_model, bias=False)

    def forward(self, x):
        # Separate magnitude and gate
        gate = torch.sigmoid(self.gate(x))
        magnitude = F.relu(self.magnitude(x))
        latents = gate * magnitude  # Element-wise product
        return self.decoder(latents), latents
```
**JumpReLU SAE:**

```python
def jump_relu(x, threshold):
    # Zero out activations below the threshold; pass the rest through unchanged
    # (unlike a shifted ReLU, values above the threshold keep their full magnitude).
    return x * (x > threshold).float()
```
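A quick comparison on a toy tensor makes the "jump" visible: values below the threshold vanish, values above it pass through at full strength. This assumes the `jump_relu` defined above.

```python
x = torch.tensor([-1.0, 0.2, 0.5, 2.0])
print(F.relu(x))          # tensor([0.0000, 0.2000, 0.5000, 2.0000])
print(jump_relu(x, 0.4))  # tensor([0.0000, 0.0000, 0.5000, 2.0000])
```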
## Training SAEs

Key considerations:

- **What to train on:** Residual stream? MLP outputs? Attention outputs?
- **Which layer:** Earlier layers → more basic features
- **Dataset:** Diverse text for general features
```python
# Training loop sketch
sae = SparseAutoencoder(d_model=model.cfg.d_model, n_latents=16 * model.cfg.d_model)
optimizer = torch.optim.Adam(sae.parameters(), lr=1e-4)

for batch in dataloader:
    # Run model, get activations (no gradients needed through the model itself)
    with torch.no_grad():
        _, cache = model.run_with_cache(batch)
    activations = cache["resid_post", layer]

    # Forward through SAE
    reconstructed, latents = sae(activations)

    # Compute loss and update
    loss = sae_loss(activations, reconstructed, latents, l1_coeff)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```
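One practical detail not shown in the loop above: with a plain L1 penalty, the SAE can cheat by shrinking latent activations and compensating with larger decoder weights. Standard practice is to constrain each latent's decoder direction to unit norm after every step. A minimal sketch:

```python
# After each optimizer step, renormalize each decoder column to unit norm
# so the L1 penalty on latent activations stays meaningful.
with torch.no_grad():
    sae.decoder.weight /= sae.decoder.weight.norm(dim=0, keepdim=True)
```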
## Neuron Resampling

Dead latents (never activate) waste capacity:

```python
@torch.no_grad()
def resample_dead_neurons(sae, activations, dead_threshold=1e-7):
    """
    Find latents that rarely activate.
    Re-initialize them to explain high-error examples.
    """
    reconstructed, latents = sae(activations)

    # Find dead latents (near-zero mean activation over the batch)
    dead_mask = latents.mean(dim=0) < dead_threshold
    n_dead = int(dead_mask.sum())
    if n_dead == 0:
        return

    # Find examples with high reconstruction error
    errors = (activations - reconstructed).pow(2).sum(dim=-1)
    top_error_examples = activations[errors.topk(n_dead).indices]

    # Re-initialize dead encoder weights to point at high-error examples
    sae.encoder.weight[dead_mask] = top_error_examples
```
## Evaluating SAEs
| Metric | What it measures |
|---|---|
| Reconstruction loss | How well does SAE preserve information? |
| L0 (avg active latents) | How sparse is the representation? |
| Explained variance | What fraction of variance is captured? |
| Downstream loss | Does patching SAE output hurt model performance? |
```python
@torch.no_grad()
def evaluate_sae(sae, model, layer, test_data):
    _, cache = model.run_with_cache(test_data)
    acts = cache["resid_post", layer]
    recon, latents = sae(acts)

    l2_loss = (acts - recon).pow(2).mean()
    l0 = (latents > 0).float().sum(dim=-1).mean()  # avg active latents per token
    explained_var = 1 - l2_loss / acts.var()       # fraction of variance captured
    return {"l2": l2_loss, "l0": l0, "var_explained": explained_var}
```
## SAELens: The Standard Tool

```python
from sae_lens import SAE, HookedSAETransformer

# Load model with SAE integration
model = HookedSAETransformer.from_pretrained("gpt2-small")

# Load pretrained SAE
sae, cfg, sparsity = SAE.from_pretrained(
    release="gpt2-small-res-jb",       # SAE release name
    sae_id="blocks.8.hook_resid_pre",  # Which layer
)

# Run with SAE
with model.saes(saes=[sae]):
    logits = model("Hello world")
```
## Neuronpedia: Exploring SAE Latents
Neuronpedia lets you:
- Search for latents by behavior
- See example activations
- View feature dashboards
- Steer models with latents
Each latent has a dashboard showing:
- Top activating examples
- Logit attribution
- Feature activation histogram
## Capstone Connection

SAEs for sycophancy detection:

```python
# Find sycophancy-related latents.
# Assumes sae_latents holds cached SAE activations with shape
# [n_prompts, seq_len, n_latents], and sycophantic_prompts / honest_prompts
# are index arrays selecting the two groups of prompts.
for latent_idx in range(sae.cfg.d_sae):
    # Check if this latent activates on sycophantic responses
    sycophantic_activation = sae_latents[sycophantic_prompts, :, latent_idx].mean()
    honest_activation = sae_latents[honest_prompts, :, latent_idx].mean()

    if sycophantic_activation > 3 * honest_activation:
        print(f"Latent {latent_idx} might detect sycophancy!")
        # Ablate and test (see the sketch below)
```
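One way to run that ablate-and-test step is to splice the SAE into the residual stream and zero the candidate latent inside its code. A sketch, assuming the SAELens `sae.encode` / `sae.decode` methods and the SAE loaded above; `prompt` and `candidate` are hypothetical, and the comparison is against a reconstruction-only baseline so that SAE reconstruction error is not mistaken for an ablation effect:

```python
prompt = "I think the answer is 5. Am I right?"  # hypothetical sycophancy probe
candidate = 1234                                 # hypothetical latent flagged above

def splice_sae(resid, hook, ablate_idx=None):
    # Replace the residual stream with the SAE reconstruction,
    # optionally zeroing one latent first.
    latents = sae.encode(resid)
    if ablate_idx is not None:
        latents[:, :, ablate_idx] = 0.0
    return sae.decode(latents)

baseline_logits = model.run_with_hooks(
    prompt, fwd_hooks=[(sae.cfg.hook_name, splice_sae)]
)
ablated_logits = model.run_with_hooks(
    prompt, fwd_hooks=[(sae.cfg.hook_name,
                        lambda resid, hook: splice_sae(resid, hook, candidate))]
)
# Compare baseline vs. ablated behaviour, e.g. the logit of an agreement token.
```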
## 🎓 Tyla's Exercise

1. Why does L1 regularization encourage sparsity? Derive this from the gradient of the L1 penalty.
2. If an SAE has expansion ratio 16x (768D → 12288D), how many independent features could it theoretically represent? What's the practical limit?
3. Explain the "feature splitting" phenomenon: why might multiple latents represent similar features?
## 💻 Aaliyah's Exercise

Train and analyze an SAE:

```python
def train_toy_sae(activations, n_latents, l1_coeff, epochs=1000):
    """
    1. Initialize an SAE with given expansion ratio
    2. Train on provided activations
    3. Plot loss curves (reconstruction vs sparsity)
    4. Return trained SAE and training history
    """
    pass


def find_interpretable_latents(sae, model, prompts):
    """
    1. Run prompts through model and SAE
    2. For each latent, find max-activating tokens
    3. Display top-10 latents with their activating contexts
    4. Manually label what each latent represents
    """
    pass
```
## 📚 Maneesha's Reflection

1. SAEs impose our prior that features are sparse and linear. What if the model uses fundamentally different representations?
2. "Dead latents" waste capacity but might represent rare features. How would you design an SAE training procedure that balances both?
3. If SAE latents become the basis for AI interpretability, what are the risks of over-relying on this paradigm?