VAEs: Variational Autoencoders
Autoencoders learn compressed representations. VAEs make those representations meaningful.
The Autoencoder Idea
- Encoder: Compress the input to a low-dimensional latent space
- Decoder: Reconstruct the input from the latent representation
Input (28×28 = 784 dims) → Encoder → Latent (20 dims) → Decoder → Output (28×28 = 784 dims)
Train by minimizing reconstruction error: $$L = ||x - \hat{x}||^2$$
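To make this concrete, here is a minimal plain-autoencoder sketch for MNIST-shaped inputs (an illustration only; the hidden sizes and layer choices are assumptions, not a prescribed architecture):

```python
import torch as t
import torch.nn as nn

class Autoencoder(nn.Module):
    def __init__(self, latent_dim: int = 20):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Flatten(),                   # 28×28 → 784
            nn.Linear(784, 256), nn.ReLU(),
            nn.Linear(256, latent_dim),     # 784 dims → 20 dims
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 256), nn.ReLU(),
            nn.Linear(256, 784), nn.Sigmoid(),
            nn.Unflatten(1, (1, 28, 28)),   # 20 dims → 784 dims → 28×28
        )

    def forward(self, x):
        return self.decoder(self.encoder(x))

# Train by minimizing reconstruction error L = ||x - x_hat||^2
ae = Autoencoder()
x = t.rand(16, 1, 28, 28)                   # dummy batch standing in for MNIST images
loss = ((x - ae(x)) ** 2).sum()
```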
The Problem with Autoencoders
The latent space isn't meaningful:
- Point `[1.0, 2.0, 0.5]` might decode to a "7"
- Point `[1.1, 2.0, 0.5]` might decode to noise
Why? The encoder only needs to find SOME encoding. It doesn't need nearby points to mean similar things.
The VAE Solution
Instead of encoding to a point, encode to a distribution:
Input → Encoder → μ, σ → Sample z ~ N(μ, σ²) → Decoder → Output
Key constraint: The latent distribution should be close to standard normal N(0, I).
This forces the latent space to be:
- Continuous: Nearby points decode to similar images
- Complete: Every point decodes to something reasonable
The Reparameterization Trick
We can't backprop through random sampling directly.
Solution: Sample from N(0, 1) and transform: $$z = \mu + \sigma \odot \epsilon, \quad \epsilon \sim N(0, I)$$
Now gradients flow through μ and σ, not through the sampling operation.
```python
import torch as t

def reparameterize(mu, log_var):
    # log_var = log(sigma^2), so std = exp(0.5 * log_var)
    std = t.exp(0.5 * log_var)
    eps = t.randn_like(std)          # eps ~ N(0, I)
    return mu + std * eps            # z = mu + sigma * eps
```
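A quick way to see that gradients really do flow through μ and σ is a toy backward pass (a minimal sketch; the tensor shapes here are arbitrary):

```python
import torch as t

mu = t.zeros(4, 20, requires_grad=True)
log_var = t.zeros(4, 20, requires_grad=True)

z = reparameterize(mu, log_var)
z.sum().backward()

# Both parameters receive gradients; the sampled eps is treated as a constant.
print(mu.grad.shape, log_var.grad.shape)   # torch.Size([4, 20]) torch.Size([4, 20])
```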
The ELBO Loss
VAE loss has two parts:
$$L = L_{recon} + \beta \cdot L_{KL}$$
Reconstruction loss: How well does the decoder reconstruct? $$L_{recon} = ||x - \hat{x}||^2$$ (The code below uses binary cross-entropy instead, which plays the same role for pixel values in [0, 1].)
KL divergence: How close is q(z|x) to N(0, I)? $$L_{KL} = -\frac{1}{2}\sum(1 + \log\sigma^2 - \mu^2 - \sigma^2)$$
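As a sanity check (a sketch using `torch.distributions`, which isn't used elsewhere in this lesson), the closed-form expression above matches PyTorch's analytic KL divergence between q(z|x) and the standard normal prior:

```python
import torch as t
from torch.distributions import Normal, kl_divergence

mu = t.randn(8, 20)
log_var = t.randn(8, 20)
sigma = t.exp(0.5 * log_var)

# Closed-form KL from the formula above, summed over the batch and latent dimensions
kl_manual = -0.5 * t.sum(1 + log_var - mu.pow(2) - log_var.exp())

# Same quantity computed via torch.distributions
kl_torch = kl_divergence(Normal(mu, sigma),
                         Normal(t.zeros_like(mu), t.ones_like(sigma))).sum()

print(t.allclose(kl_manual, kl_torch))   # True (up to floating-point error)
```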
Implementing a VAE
```python
import torch as t
import torch.nn as nn

class VAE(nn.Module):
    def __init__(self, latent_dim: int = 20):
        super().__init__()
        # Encoder: 28×28 image → 256-dim feature vector
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, 3, stride=2, padding=1),    # 28×28 → 14×14
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),   # 14×14 → 7×7
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 256),
            nn.ReLU(),
        )
        # Latent space parameters: mean and log-variance of q(z|x)
        self.fc_mu = nn.Linear(256, latent_dim)
        self.fc_var = nn.Linear(256, latent_dim)
        # Decoder: latent vector → 28×28 image
        self.decoder_input = nn.Linear(latent_dim, 256)
        self.decoder = nn.Sequential(
            nn.Linear(256, 64 * 7 * 7),
            nn.ReLU(),
            nn.Unflatten(1, (64, 7, 7)),
            nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1),  # 7×7 → 14×14
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, 3, stride=2, padding=1, output_padding=1),   # 14×14 → 28×28
            nn.Sigmoid(),   # pixel values in [0, 1]
        )

    def encode(self, x):
        h = self.encoder(x)
        return self.fc_mu(h), self.fc_var(h)   # fc_var predicts log(sigma^2)

    def decode(self, z):
        h = self.decoder_input(z)
        return self.decoder(h)

    def forward(self, x):
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        return self.decode(z), mu, log_var

    def reparameterize(self, mu, log_var):
        std = t.exp(0.5 * log_var)
        eps = t.randn_like(std)
        return mu + std * eps
```
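A quick smoke test (a sketch, assuming MNIST-shaped inputs of 1×28×28) confirms that the tensor shapes line up:

```python
model = VAE(latent_dim=20)
x = t.rand(16, 1, 28, 28)            # dummy batch of 16 "images"

recon, mu, log_var = model(x)
print(recon.shape, mu.shape, log_var.shape)
# torch.Size([16, 1, 28, 28]) torch.Size([16, 20]) torch.Size([16, 20])
```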
Training Loop
```python
import torch.nn.functional as F

def vae_loss(recon_x, x, mu, log_var, beta=1.0):
    # Reconstruction loss (binary cross-entropy, since pixels are in [0, 1])
    recon_loss = F.binary_cross_entropy(recon_x, x, reduction='sum')
    # KL divergence between q(z|x) = N(mu, sigma^2) and the prior N(0, I)
    kl_loss = -0.5 * t.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return recon_loss + beta * kl_loss

for epoch in range(epochs):
    for x, _ in dataloader:
        optimizer.zero_grad()
        recon, mu, log_var = model(x)
        loss = vae_loss(recon, x, mu, log_var)
        loss.backward()
        optimizer.step()
```
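The loop above assumes `model`, `optimizer`, `dataloader`, and `epochs` already exist. A minimal setup might look like this (a sketch, assuming `torchvision` is installed and MNIST can be downloaded to `./data`):

```python
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

epochs = 20
dataset = datasets.MNIST('./data', train=True, download=True,
                         transform=transforms.ToTensor())   # pixels scaled to [0, 1]
dataloader = DataLoader(dataset, batch_size=128, shuffle=True)

model = VAE(latent_dim=20)
optimizer = t.optim.Adam(model.parameters(), lr=1e-3)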
Generating New Images
Sample from the prior and decode:
```python
def generate(model, n_samples=16, latent_dim=20):
    with t.no_grad():
        z = t.randn(n_samples, latent_dim)   # sample from the prior N(0, I)
        samples = model.decode(z)
    return samples

# Generate MNIST digits
samples = generate(model)
```
Because the latent space is organized, random samples look like real digits!
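To actually look at the samples, one option (a sketch, assuming `matplotlib` and `torchvision` are available) is to tile them into a grid:

```python
import matplotlib.pyplot as plt
from torchvision.utils import make_grid

samples = generate(model, n_samples=64)   # (64, 1, 28, 28)
grid = make_grid(samples, nrow=8)         # tile into an 8×8 grid

plt.imshow(grid[0], cmap='gray')          # channels are identical for grayscale input
plt.axis('off')
plt.show()
```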
Latent Space Interpolation
Walk between two images:
```python
def interpolate(model, x1, x2, n_steps=10):
    with t.no_grad():
        mu1, _ = model.encode(x1)
        mu2, _ = model.encode(x2)
        # Linear interpolation in latent space
        alphas = t.linspace(0, 1, n_steps)
        z_interp = [a * mu2 + (1 - a) * mu1 for a in alphas]
        return [model.decode(z) for z in z_interp]

# Morph from a "3" to an "8"
frames = interpolate(model, image_3, image_8)
```
Capstone Connection
VAEs and representation learning:
The latent space of a VAE is a learned representation. Similarly, transformer embeddings are learned representations.
When analyzing sycophancy:
- What does the "sycophancy direction" in embedding space look like?
- Can we interpolate between sycophantic and honest responses?
- How do different training objectives shape the embedding structure?
VAEs teach us that how we train affects what we learn.
🎓 Tyla's Exercise
- Derive the KL divergence formula for two Gaussians: $D_{KL}(N(\mu, \sigma^2) || N(0, 1))$.
- Why do we use `log_var` instead of `var` directly? What numerical issue does this avoid?
- Explain why the reparameterization trick allows gradients to flow, but sampling directly doesn't.
💻 Aaliyah's Exercise
Train a VAE on MNIST and explore:
```python
def train_vae():
    """
    1. Train for 20 epochs
    2. Plot loss curves (reconstruction and KL separately)
    3. Generate 64 random samples
    4. Show interpolation between digits
    """
    pass

def visualize_latent_space(model, test_loader):
    """
    1. Encode all test images
    2. Use t-SNE or PCA to project to 2D
    3. Color by digit class
    4. Do similar digits cluster together?
    """
    pass
```
📚 Maneesha's Reflection
The VAE forces the latent space to be "meaningful" by adding the KL term. What does "meaningful" mean in this context?
β-VAE increases the weight on KL divergence to learn more disentangled representations. What's the trade-off?
How would you explain the difference between an autoencoder and a VAE to someone who hasn't seen either?