VAEs: Variational Autoencoders

Autoencoders learn compressed representations. VAEs make those representations meaningful.


The Autoencoder Idea

Encoder: Compress the input to a low-dimensional latent space.
Decoder: Reconstruct the input from the latent representation.

Input (28×28) → Encoder → Latent (20) → Decoder → Output (28×28)
    784 dims              20 dims              784 dims

Train by minimizing reconstruction error: $$L = ||x - \hat{x}||^2$$
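
For concreteness, here is a minimal fully-connected autoencoder for 28×28 images trained with exactly this reconstruction loss. It's a sketch: the class name `Autoencoder` and the layer sizes are illustrative, not prescribed by the lesson.

import torch as t
from torch import nn

class Autoencoder(nn.Module):
    def __init__(self, latent_dim: int = 20):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Flatten(),                   # (B, 1, 28, 28) -> (B, 784)
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, latent_dim),     # compress to 20 dims
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 784),
            nn.Sigmoid(),                   # pixel values in [0, 1]
            nn.Unflatten(1, (1, 28, 28)),
        )

    def forward(self, x):
        return self.decoder(self.encoder(x))

model = Autoencoder()
x = t.rand(8, 1, 28, 28)                  # stand-in batch of images in [0, 1]
loss = ((model(x) - x) ** 2).mean()       # reconstruction error ||x - x_hat||^2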


The Problem with Autoencoders

The latent space isn't meaningful: random points in it usually decode to garbage, and interpolating between two encodings rarely gives a smooth transition.

Why? The encoder only needs to find SOME encoding the decoder can invert. Nothing forces nearby points to mean similar things.


The VAE Solution

Instead of encoding to a point, encode to a distribution:

Input → Encoder → μ, σ → Sample z ~ N(μ, σ²) → Decoder → Output

Key constraint: The latent distribution should be close to standard normal N(0, I).

This forces the latent space to be:

  1. Continuous: Nearby points decode to similar images
  2. Complete: Every point decodes to something reasonable
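
In symbols (a brief formalization; the φ subscript for the encoder parameters is standard VAE notation, not something introduced above): the encoder defines an approximate posterior, and the prior is a standard normal,

$$q_\phi(z|x) = N\!\big(\mu_\phi(x),\ \mathrm{diag}(\sigma_\phi^2(x))\big), \qquad p(z) = N(0, I)$$

The KL term in the loss (below) pulls $q_\phi(z|x)$ toward $p(z)$ for every input, which is what produces the continuity and completeness properties.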

The Reparameterization Trick

We can't backprop through random sampling directly.

Solution: Sample from N(0, 1) and transform: $$z = \mu + \sigma \odot \epsilon, \quad \epsilon \sim N(0, I)$$

Now gradients flow through μ and σ, not through the sampling operation.

import torch as t

def reparameterize(mu, log_var):
    std = t.exp(0.5 * log_var)   # sigma = exp(0.5 * log(sigma^2))
    eps = t.randn_like(std)      # epsilon ~ N(0, I); the randomness lives here
    return mu + std * eps        # deterministic in mu and sigma -> differentiable
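
As a quick sanity check (not part of the lesson itself), you can confirm that gradients reach μ and log σ² through a sample drawn this way:

mu = t.zeros(4, 20, requires_grad=True)
log_var = t.zeros(4, 20, requires_grad=True)

z = reparameterize(mu, log_var)    # uses the function above
z.sum().backward()

print(mu.grad is not None)         # True: dz/dmu = 1
print(log_var.grad is not None)    # True: dz/dlog_var = 0.5 * std * eps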

The ELBO Loss

VAE loss has two parts:

$$L = L_{recon} + \beta \cdot L_{KL}$$

Reconstruction loss: How well does the decoder reconstruct? $$L_{recon} = ||x - \hat{x}||^2$$ For pixel values in [0, 1] (like MNIST), binary cross-entropy is a common alternative, and it's what the training code below uses.

KL divergence: How close is q(z|x) to N(0, I)? $$L_{KL} = -\frac{1}{2}\sum(1 + \log\sigma^2 - \mu^2 - \sigma^2)$$
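
If you want to double-check the closed-form KL expression, it agrees with torch.distributions. A sanity-check sketch (the tensor shapes are arbitrary):

import torch as t
from torch.distributions import Normal, kl_divergence

mu = t.randn(8, 20)
log_var = t.randn(8, 20)
std = t.exp(0.5 * log_var)

# Closed-form expression from above, summed over latent dims and batch
analytic = -0.5 * t.sum(1 + log_var - mu.pow(2) - log_var.exp())

# Same quantity computed from the two Gaussians directly
library = kl_divergence(Normal(mu, std), Normal(t.zeros_like(mu), t.ones_like(std))).sum()

print(t.allclose(analytic, library, atol=1e-5))  # True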


Implementing a VAE

import torch as t
from torch import nn

class VAE(nn.Module):
    def __init__(self, latent_dim: int = 20):
        super().__init__()

        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 256),
            nn.ReLU(),
        )

        # Latent space parameters
        self.fc_mu = nn.Linear(256, latent_dim)
        self.fc_var = nn.Linear(256, latent_dim)  # predicts log-variance, not variance

        # Decoder
        self.decoder_input = nn.Linear(latent_dim, 256)
        self.decoder = nn.Sequential(
            nn.Linear(256, 64 * 7 * 7),
            nn.ReLU(),
            nn.Unflatten(1, (64, 7, 7)),
            nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, 3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid(),
        )

    def encode(self, x):
        h = self.encoder(x)
        return self.fc_mu(h), self.fc_var(h)

    def decode(self, z):
        h = self.decoder_input(z)
        return self.decoder(h)

    def forward(self, x):
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        return self.decode(z), mu, log_var

    def reparameterize(self, mu, log_var):
        std = t.exp(0.5 * log_var)
        eps = t.randn_like(std)
        return mu + std * eps
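
A quick shape check for this model (the random tensor is just a stand-in for a batch of MNIST images):

model = VAE(latent_dim=20)
x = t.rand(8, 1, 28, 28)
recon, mu, log_var = model(x)
print(recon.shape, mu.shape, log_var.shape)   # (8, 1, 28, 28), (8, 20), (8, 20)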

Training Loop

import torch.nn.functional as F

def vae_loss(recon_x, x, mu, log_var, beta=1.0):
    # Reconstruction loss (binary cross entropy for images)
    recon_loss = F.binary_cross_entropy(recon_x, x, reduction='sum')

    # KL divergence
    kl_loss = -0.5 * t.sum(1 + log_var - mu.pow(2) - log_var.exp())

    return recon_loss + beta * kl_loss

for epoch in range(epochs):
    for x, _ in dataloader:
        optimizer.zero_grad()
        recon, mu, log_var = model(x)
        loss = vae_loss(recon, x, mu, log_var)
        loss.backward()
        optimizer.step()
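
The loop assumes a model, optimizer, dataloader, and epochs already exist. A minimal setup sketch, assuming torchvision's MNIST (the batch size and learning rate are illustrative; device placement is omitted for brevity):

from torch import optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

dataset = datasets.MNIST("data", train=True, download=True, transform=transforms.ToTensor())
dataloader = DataLoader(dataset, batch_size=128, shuffle=True)

model = VAE(latent_dim=20)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
epochs = 20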

Generating New Images

Sample from the prior and decode:

def generate(model, n_samples=16, latent_dim=20):
    model.eval()
    with t.no_grad():
        z = t.randn(n_samples, latent_dim)   # sample from the prior N(0, I)
        samples = model.decode(z)
    return samples

# Generate MNIST digits
samples = generate(model)
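
To inspect the output, torchvision's save_image tiles a batch into a grid (assuming torchvision is available; the filename is arbitrary):

from torchvision.utils import save_image

samples = generate(model, n_samples=16)
save_image(samples, "vae_samples.png", nrow=4)   # 4x4 grid of generated digits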

Because the latent space is organized, random samples look like real digits!


Latent Space Interpolation

Walk between two images:

def interpolate(model, x1, x2, n_steps=10):
    with t.no_grad():
        mu1, _ = model.encode(x1)
        mu2, _ = model.encode(x2)

        # Linear interpolation in latent space
        alphas = t.linspace(0, 1, n_steps)
        z_interp = [a * mu2 + (1-a) * mu1 for a in alphas]

        return [model.decode(z) for z in z_interp]

# Morph from a "3" to an "8"
frames = interpolate(model, image_3, image_8)
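
Each frame comes back as a single-image batch, so concatenating them gives a strip you can save the same way (again assuming torchvision):

from torchvision.utils import save_image

strip = t.cat(frames, dim=0)                      # (n_steps, 1, 28, 28)
save_image(strip, "interpolation.png", nrow=10)   # one row, left-to-right morph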

Capstone Connection

VAEs and representation learning:

The latent space of a VAE is a learned representation. Similarly, transformer embeddings are learned representations.

When analyzing sycophancy in the capstone, keep the VAE's lesson in mind: how we train affects what we learn.


🎓 Tyla's Exercise

  1. Derive the KL divergence formula for two Gaussians: $D_{KL}(N(\mu, \sigma^2) || N(0, 1))$.

  2. Why do we use log_var instead of var directly? What numerical issue does this avoid?

  3. Explain why the reparameterization trick allows gradients to flow, but sampling directly doesn't.


💻 Aaliyah's Exercise

Train a VAE on MNIST and explore:

def train_vae():
    """
    1. Train for 20 epochs
    2. Plot loss curves (reconstruction and KL separately)
    3. Generate 64 random samples
    4. Show interpolation between digits
    """
    pass

def visualize_latent_space(model, test_loader):
    """
    1. Encode all test images
    2. Use t-SNE or PCA to project to 2D
    3. Color by digit class
    4. Do similar digits cluster together?
    """
    pass

📚 Maneesha's Reflection

  1. The VAE forces the latent space to be "meaningful" by adding the KL term. What does "meaningful" mean in this context?

  2. β-VAE increases the weight on KL divergence to learn more disentangled representations. What's the trade-off?

  3. How would you explain the difference between an autoencoder and a VAE to someone who hasn't seen either?