Sparse Autoencoders: Untangling Features

If superposition is the disease, sparse autoencoders are the treatment.


The SAE Idea

Expand the compressed space back into interpretable features:

Residual Stream (768D) → SAE Encoder → Latent Space (16000D) → SAE Decoder → Reconstructed (768D)
                              ↓
                    Sparse, interpretable features

The key constraint is sparsity: for any given input, only a few latents should be active (nonzero).


SAE Architecture

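import torch
import torch.nn as nn
import torch.nn.functional as F


class SparseAutoencoder(nn.Module):
    def __init__(self, d_model, n_latents):
        super().__init__()
        self.encoder = nn.Linear(d_model, n_latents)              # d_model -> n_latents
        self.decoder = nn.Linear(n_latents, d_model, bias=False)  # n_latents -> d_model

    def forward(self, x):
        # Encode: ReLU zeroes negative pre-activations; the L1 penalty in the
        # loss is what pushes most latents to exactly zero
        pre_acts = self.encoder(x)
        latents = F.relu(pre_acts)

        # Decode: reconstruct x as a linear combination of decoder columns
        reconstructed = self.decoder(latents)

        return reconstructed, latents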

The expansion ratio (n_latents / d_model) is a key hyperparameter: typical values range from 8x to 64x the model dimension.

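For a rough sense of scale, the sketch below counts latents and parameters for an SAE shaped like the SparseAutoencoder class above (encoder with bias, decoder without) at d_model = 768, the residual width of GPT-2 small. The helper name `sae_size` is hypothetical, introduced only for this illustration.

```python
def sae_size(d_model: int, expansion: int) -> tuple[int, int]:
    """Return (n_latents, n_params) for an SAE shaped like SparseAutoencoder:
    nn.Linear(d_model, n_latents) encoder with bias, bias-free decoder."""
    n_latents = d_model * expansion
    encoder_params = d_model * n_latents + n_latents  # weight + bias
    decoder_params = n_latents * d_model              # weight only (bias=False)
    return n_latents, encoder_params + decoder_params


for expansion in (8, 16, 32, 64):
    n_latents, n_params = sae_size(768, expansion)
    print(f"{expansion:>2}x -> {n_latents:>6} latents, {n_params / 1e6:.1f}M parameters")
```

Both counts grow linearly with the expansion ratio: at 16x this gives 12,288 latents and roughly 18.9M parameters.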

The SAE Loss Function

Two competing objectives:

def sae_loss(x, reconstructed, latents, l1_coeff):
    # 1. Reconstruction: mean squared error against the original activations
    reconstruction_loss = (x - reconstructed).pow(2).mean()

    # 2. Sparsity: L1 norm of each input's latents, averaged over the batch
    sparsity_loss = latents.abs().sum(dim=-1).mean()

    return reconstruction_loss + l1_coeff * sparsity_loss

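The l1_coeff controls the trade-off: a larger value drives more latents to zero (lower L0, sparser codes) at the cost of worse reconstruction, while a smaller value reconstructs better but leaves more latents active per input.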


What SAEs Learn

After training, many SAE latents correspond to recognizable, human-interpretable features:

# Example latents from a trained SAE on GPT-2
latent_2847   "code context"     (activates on programming)
latent_9123   "past tense"       (activates on -ed endings)
latent_15234  "scientific text"  (activates on academic language)

Latents like these are typically much more interpretable than individual neurons, which tend to be polysemantic because of superposition.

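Labels such as "past tense" typically come from inspecting the tokens on which a latent fires most strongly. A minimal sketch, assuming you already have SAE latents of shape (n_tokens, n_latents) and the aligned token strings; the helper `top_activating_tokens` and the toy data are hypothetical.

```python
import torch


def top_activating_tokens(latents: torch.Tensor, tokens: list[str], latent_idx: int, k: int = 5):
    """Return up to k (token, activation) pairs where latent `latent_idx` fires most strongly.

    latents: (n_tokens, n_latents) SAE latent activations
    tokens:  the n_tokens token strings, aligned with the rows of `latents`
    """
    acts = latents[:, latent_idx]
    k = min(k, acts.numel())
    values, indices = acts.topk(k)
    # Drop positions where the latent did not fire at all
    return [(tokens[i], v) for i, v in zip(indices.tolist(), values.tolist()) if v > 0]


# Toy data: latent 1 fires only on "-ed" words
tokens = ["walk", "walked", "jump", "jumped", "the"]
latents = torch.tensor([
    [0.9, 0.0],
    [0.1, 2.0],
    [0.8, 0.0],
    [0.0, 1.5],
    [0.0, 0.0],
])
print(top_activating_tokens(latents, tokens, latent_idx=1, k=3))  # [('walked', 2.0), ('jumped', 1.5)]
```

On a real SAE you would show a window of surrounding context for each hit rather than a single token.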

SAE Variants

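Gated SAE (DeepMind): separates which latents are active (a binary gate) from how strongly they fire (a magnitude path), so the sparsity penalty applied to the gate does not shrink the magnitudes.

class GatedSAE(nn.Module):
    def __init__(self, d_model, n_latents):
        super().__init__()
        self.W_gate = nn.Linear(d_model, n_latents)
        self.W_mag = nn.Linear(d_model, n_latents)  # the paper ties this to W_gate via a per-latent rescaling
        self.decoder = nn.Linear(n_latents, d_model)

    def forward(self, x):
        pi_gate = self.W_gate(x)
        gate = (pi_gate > 0).float()       # binary gate: which latents fire
        magnitude = F.relu(self.W_mag(x))  # how strongly they fire
        # The step has no gradient; the paper trains the gate path with an
        # L1 penalty on ReLU(pi_gate) plus an auxiliary reconstruction loss
        latents = gate * magnitude  # Element-wise product
        return self.decoder(latents), latents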

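JumpReLU SAE: zeroes each pre-activation below a (learned, per-latent) threshold and passes values above it through unchanged:

def jump_relu(x, threshold):
    # x * H(x - threshold): keep x where it exceeds the threshold, zero elsewhere.
    # F.relu(x - threshold) would also subtract the threshold from surviving values.
    return x * (x > threshold)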

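To see why the shift matters, the sketch below compares JumpReLU with ReLU(x - threshold) on a few values. In a JumpReLU SAE the threshold is a learned per-latent vector; the fixed values here are assumptions for illustration.

```python
import torch
import torch.nn.functional as F


def jump_relu(x: torch.Tensor, threshold) -> torch.Tensor:
    # Zero everything at or below the threshold, keep the rest unchanged
    return x * (x > threshold)


def shifted_relu(x: torch.Tensor, threshold) -> torch.Tensor:
    # For contrast: ReLU(x - threshold) subtracts the threshold from surviving values
    return F.relu(x - threshold)


x = torch.tensor([0.2, 0.5, 1.5, 3.0])
print(jump_relu(x, 1.0))     # tensor([0.0000, 0.0000, 1.5000, 3.0000])
print(shifted_relu(x, 1.0))  # tensor([0.0000, 0.0000, 0.5000, 2.0000])

# Per-latent thresholds broadcast over the last dimension
thresholds = torch.tensor([0.1, 1.0, 1.0, 4.0])
print(jump_relu(x, thresholds))  # tensor([0.2000, 0.0000, 1.5000, 0.0000])
```

Keeping the full magnitude means a latent that clears its threshold reports its actual strength instead of a value biased toward zero.
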
Training SAEs

Key considerations:

  1. What to train on: Residual stream? MLP outputs? Attention outputs?

  2. Which layer: Earlier layers tend to yield more basic (token-level, syntactic) features; later layers yield more abstract ones

  3. Dataset: Diverse text for general features

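# Training loop sketch (assumes a TransformerLens `model`, an `sae`, an `optimizer`,
# and `layer` / `l1_coeff` already defined)
hook_name = f"blocks.{layer}.hook_resid_post"
for batch in dataloader:
    # Run the model without gradients (only the SAE is trained), caching one hook
    with torch.no_grad():
        _, cache = model.run_with_cache(batch, names_filter=hook_name)
    activations = cache[hook_name]  # (batch, seq, d_model)

    # Forward through SAE
    reconstructed, latents = sae(activations)

    # Compute loss, clear old gradients, then update
    loss = sae_loss(activations, reconstructed, latents, l1_coeff)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()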

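The loop above needs a language model and a dataloader. For experimenting without either, here is a self-contained toy version that trains on synthetic activations and also makes the l1_coeff trade-off visible: a larger coefficient gives fewer active latents (lower L0) and higher reconstruction error. The synthetic data (sparse combinations of random unit directions), the sizes, the step counts, and the per-example L1 norm in the loss are all assumptions chosen so it runs in seconds on CPU.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SparseAutoencoder(nn.Module):
    def __init__(self, d_model, n_latents):
        super().__init__()
        self.encoder = nn.Linear(d_model, n_latents)
        self.decoder = nn.Linear(n_latents, d_model, bias=False)

    def forward(self, x):
        latents = F.relu(self.encoder(x))
        return self.decoder(latents), latents


def sae_loss(x, reconstructed, latents, l1_coeff):
    reconstruction_loss = (x - reconstructed).pow(2).mean()
    sparsity_loss = latents.abs().sum(dim=-1).mean()  # L1 norm per example, averaged
    return reconstruction_loss + l1_coeff * sparsity_loss


def make_toy_activations(n_samples=2048, d_model=16, n_features=32, p_active=0.05, seed=0):
    """Synthetic 'activations': each sample is a sparse sum of random unit directions."""
    g = torch.Generator().manual_seed(seed)
    directions = F.normalize(torch.randn(n_features, d_model, generator=g), dim=-1)
    active = (torch.rand(n_samples, n_features, generator=g) < p_active).float()
    magnitudes = torch.rand(n_samples, n_features, generator=g)
    return (active * magnitudes) @ directions


def train_sae(acts, n_latents, l1_coeff, steps=500, lr=1e-2, batch_size=256, seed=0):
    torch.manual_seed(seed)
    sae = SparseAutoencoder(acts.shape[-1], n_latents)
    optimizer = torch.optim.Adam(sae.parameters(), lr=lr)
    for _ in range(steps):
        batch = acts[torch.randint(0, acts.shape[0], (batch_size,))]
        reconstructed, latents = sae(batch)
        loss = sae_loss(batch, reconstructed, latents, l1_coeff)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # Evaluate on the full dataset
    with torch.no_grad():
        reconstructed, latents = sae(acts)
        mse = (acts - reconstructed).pow(2).mean().item()
        l0 = (latents > 0).float().sum(dim=-1).mean().item()
    return sae, mse, l0


acts = make_toy_activations()
results = {}
for l1_coeff in (0.0, 1e-3, 1e-2, 1e-1):
    _, mse, l0 = train_sae(acts, n_latents=64, l1_coeff=l1_coeff)
    results[l1_coeff] = (mse, l0)
    print(f"l1_coeff={l1_coeff:<6} MSE={mse:.5f}  L0={l0:.1f}")
```

The exact numbers depend on the seed and data, but the direction of the trade-off is the point: push l1_coeff too high and latents collapse to zero.
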
Neuron Resampling

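Dead latents (latents that never activate) waste capacity. Periodically resampling them gives that capacity another chance to learn something useful: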

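def resample_dead_neurons(sae, activations, dead_threshold=1e-7):
    """
    Find latents that (almost) never activate on `activations` (n_examples, d_model)
    and re-initialize them to point at the highest-error examples.
    Returns the number of latents resampled.
    """
    with torch.no_grad():
        reconstructed, latents = sae(activations)

        # Dead latents: firing frequency (fraction of examples with latent > 0)
        firing_freq = (latents > 0).float().mean(dim=0)
        dead_idx = (firing_freq < dead_threshold).nonzero().squeeze(-1)

        # Examples with the highest reconstruction error
        errors = (activations - reconstructed).pow(2).sum(dim=-1)
        k = min(len(dead_idx), activations.shape[0])
        if k == 0:
            return 0
        dead_idx = dead_idx[:k]
        directions = activations[errors.topk(k).indices]
        directions = directions / directions.norm(dim=-1, keepdim=True).clamp_min(1e-8)

        # Point dead encoder rows and decoder columns at those examples
        # (the optimizer state for these parameters is usually reset as well)
        sae.encoder.weight[dead_idx] = directions
        sae.encoder.bias[dead_idx] = 0.0
        sae.decoder.weight[:, dead_idx] = directions.T
    return k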

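Declaring a latent dead from a single batch, as resample_dead_neurons does, is noisy: a rare but genuine feature can easily stay silent for a few thousand tokens. A common alternative is to accumulate firing counts over many batches and resample only latents whose frequency stays below the threshold. A minimal sketch; the class name `FiringTracker` is hypothetical.

```python
import torch


class FiringTracker:
    """Accumulates how often each SAE latent fires (activation > 0) across batches."""

    def __init__(self, n_latents: int):
        self.fire_counts = torch.zeros(n_latents)
        self.n_tokens = 0

    def update(self, latents: torch.Tensor) -> None:
        # Flatten any leading (batch, seq, ...) dims into a single token dim
        flat = latents.reshape(-1, latents.shape[-1])
        self.fire_counts += (flat > 0).float().sum(dim=0)
        self.n_tokens += flat.shape[0]

    def frequencies(self) -> torch.Tensor:
        return self.fire_counts / max(self.n_tokens, 1)

    def dead_mask(self, dead_threshold: float = 1e-7) -> torch.Tensor:
        return self.frequencies() < dead_threshold


tracker = FiringTracker(n_latents=3)
tracker.update(torch.tensor([[0.0, 1.0, 0.0], [0.0, 0.5, 0.0]]))  # latent 2 silent so far
tracker.update(torch.tensor([[0.0, 0.0, 2.0]]))                    # latent 2 fires once
print(tracker.frequencies())  # tensor([0.0000, 0.6667, 0.3333])
print(tracker.dead_mask())    # tensor([ True, False, False])
```

Latent 2 would have looked dead after the first batch alone, which is exactly the false positive this avoids.
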
Evaluating SAEs

Metric                    What it measures
Reconstruction loss       How well does SAE preserve information?
L0 (avg active latents)   How sparse is the representation?
Explained variance        What fraction of variance is captured?
Downstream loss           Does patching SAE output hurt model performance?

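def evaluate_sae(sae, model, layer, test_data):
    hook_name = f"blocks.{layer}.hook_resid_post"
    with torch.no_grad():
        clean_loss, cache = model.run_with_cache(test_data, return_type="loss")
        acts = cache[hook_name]  # (batch, seq, d_model)
        recon, latents = sae(acts)

        l2_loss = (acts - recon).pow(2).mean()
        l0 = (latents > 0).float().sum(dim=-1).mean()
        # Explained variance relative to each dimension's own mean
        total_var = (acts - acts.mean(dim=(0, 1))).pow(2).mean()
        explained_var = 1 - l2_loss / total_var

        # Downstream loss: replace the activations with the SAE reconstruction
        spliced_loss = model.run_with_hooks(
            test_data,
            return_type="loss",
            fwd_hooks=[(hook_name, lambda act, hook: sae(act)[0])],
        )

    return {"l2": l2_loss, "l0": l0, "var_explained": explained_var,
            "clean_loss": clean_loss, "spliced_loss": spliced_loss}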

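Explained variance should be measured against each dimension's own mean. A tiny worked example shows why: the "reconstruction" below just outputs each dimension's mean, so it explains none of the variation between examples and its explained variance should be 0. Dividing by `acts.var()` (variance around one global mean) instead credits it with most of the variance, because the two dimensions have very different offsets.

```python
import torch

acts = torch.tensor([[0.0, 10.0],
                     [2.0, 12.0]])
recon = acts.mean(dim=0, keepdim=True).expand_as(acts)  # predicts [1, 11] for every row

mse = (acts - recon).pow(2).mean()  # 1.0

# Variance around one global mean (what acts.var() computes; unbiased by default)
naive_ev = 1 - mse / acts.var()
# Variance around each dimension's own mean
per_dim_var = (acts - acts.mean(dim=0)).pow(2).mean()
correct_ev = 1 - mse / per_dim_var

print(f"naive: {naive_ev:.3f}, per-dimension: {correct_ev:.3f}")  # naive: 0.971, per-dimension: 0.000
```

Residual-stream activations often have large per-dimension offsets, so the naive version can make a weak SAE look excellent.
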
SAELens: The Standard Tool

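from sae_lens import SAE, HookedSAETransformer

# Load model with SAE integration
model = HookedSAETransformer.from_pretrained("gpt2-small")

# Load a pretrained SAE. Older SAELens releases return (sae, cfg_dict, sparsity)
# as below; newer releases may return only the SAE, so check your installed version.
sae, cfg, sparsity = SAE.from_pretrained(
    release="gpt2-small-res-jb",  # SAE release name
    sae_id="blocks.8.hook_resid_pre",  # Hook point: residual stream entering layer 8
)

# Run with the SAE spliced in at its hook point
with model.saes(saes=[sae]):
    logits = model("Hello world")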

Neuronpedia: Exploring SAE Latents

Neuronpedia lets you:

  1. Search for latents by behavior
  2. See example activations
  3. View feature dashboards
  4. Steer models with latents

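Each latent's dashboard typically shows its top-activating text examples, a histogram of its activation density, and the output tokens it most strongly promotes and suppresses.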


Capstone Connection

SAEs for sycophancy detection:

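# Find latents that fire much more on sycophantic than on honest responses.
# Assumes sae_latents has shape (n_prompts, seq_len, n_latents) and that
# sycophantic_prompts / honest_prompts index into the prompt dimension.
sycophantic_activation = sae_latents[sycophantic_prompts].mean(dim=(0, 1))  # (n_latents,)
honest_activation = sae_latents[honest_prompts].mean(dim=(0, 1))            # (n_latents,)

# Same 3x heuristic as a per-latent loop, but vectorized over all latents
candidates = (sycophantic_activation > 3 * honest_activation).nonzero().squeeze(-1)
for latent_idx in candidates.tolist():
    print(f"Latent {latent_idx} might detect sycophancy!")
    # Next: ablate this latent and test whether the behavior changes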

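For the ablation step, one useful fact: because the decoder is linear, zeroing a single latent changes the reconstruction by exactly that latent's activation times its decoder column. A minimal sketch using a randomly initialized decoder and fake latents (assumptions standing in for a trained SAE's); the candidate index is hypothetical. In practice you would splice the ablated reconstruction back into the model and check whether the sycophantic behavior changes.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, n_latents = 8, 32
decoder = nn.Linear(n_latents, d_model, bias=False)  # stand-in for sae.decoder

latents = torch.relu(torch.randn(4, n_latents))      # fake SAE latents for 4 tokens
latent_idx = 5                                       # hypothetical candidate latent

ablated = latents.clone()
ablated[:, latent_idx] = 0.0                         # zero-ablate the candidate latent

with torch.no_grad():
    delta = decoder(latents) - decoder(ablated)
    # decoder.weight has shape (d_model, n_latents); column latent_idx is this latent's direction
    expected = latents[:, latent_idx : latent_idx + 1] * decoder.weight[:, latent_idx]

print(torch.allclose(delta, expected, atol=1e-5))  # True
```

This is why a latent's decoder column is often read as "the direction it writes": ablation removes exactly that contribution, scaled per token.
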
🎓 Tyla's Exercise

  1. Why does L1 regularization encourage sparsity? Derive this from the gradient of the L1 penalty.

  2. If an SAE has expansion ratio 16x (768D → 12288D), how many independent features could it theoretically represent? What's the practical limit?

  3. Explain the "feature splitting" phenomenon: why might multiple latents represent similar features?


💻 Aaliyah's Exercise

Train and analyze an SAE:

def train_toy_sae(activations, n_latents, l1_coeff, epochs=1000):
    """
    1. Initialize an SAE with n_latents latents (expansion ratio = n_latents / d_model)
    2. Train on provided activations
    3. Plot loss curves (reconstruction vs sparsity)
    4. Return trained SAE and training history
    """
    pass

def find_interpretable_latents(sae, model, prompts):
    """
    1. Run prompts through model and SAE
    2. For each latent, find max-activating tokens
    3. Display top-10 latents with their activating contexts
    4. Manually label what each latent represents
    """
    pass

📚 Maneesha's Reflection

  1. SAEs impose our prior that features are sparse and linear. What if the model uses fundamentally different representations?

  2. "Dead latents" waste capacity but might represent rare features. How would you design an SAE training procedure that balances both?

  3. If SAE latents become the basis for AI interpretability, what are the risks of over-relying on this paradigm?