Chapter 0: Mastering einops

Before you can understand transformers, you need to think in tensors.

einops is the tool that makes tensor operations readable. Instead of memorizing what .reshape(), .permute(), and .transpose() each do to your axes, you describe the shape you want in words.


The Mental Model

einops.rearrange transforms tensor shapes by describing two things: the layout the tensor has now and the layout you want, written as named axes on either side of an arrow.

einops.einsum computes any combination of elementwise multiplication and summation over named axes.

The rule: any dimension that appears in the inputs but NOT in the output gets summed over.
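
To make the rule concrete, here is a minimal dot-product sketch, assuming the same import einops / import torch as t setup used in the worked examples below:

a = t.randn(5)
b = t.randn(5)

# "i" appears in both inputs but not in the output, so the products are summed over i
dot = einops.einsum(a, b, "i, i ->")
assert t.allclose(dot, (a * b).sum())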


Worked Example 1: Reshaping Tensors

import einops
import torch as t

# Start with a flat tensor
x = t.arange(12)  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
print(f"Original: {x.shape}")  # torch.Size([12])

# Reshape to matrix: "(h w) -> h w"
y = einops.rearrange(x, "(h w) -> h w", h=3, w=4)
print(f"Matrix: {y.shape}")  # torch.Size([3, 4])
print(y)
# tensor([[ 0,  1,  2,  3],
#         [ 4,  5,  6,  7],
#         [ 8,  9, 10, 11]])

What happened? The pattern "(h w) -> h w" reads the single axis of length 12 as two packed axes h and w, then splits it into a 3×4 matrix. Parentheses mean "these axes are packed together into one dimension."
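
A detail worth knowing: you only need to pin down enough axis sizes for einops to infer the rest. For the same split, passing just w=4 also works (a quick sketch):

# h is inferred as 12 / 4 = 3
y2 = einops.rearrange(x, "(h w) -> h w", w=4)
assert y2.shape == (3, 4)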

💻 Aaliyah's Translation: Think of it like JavaScript:

// Conceptually similar to:
arr.reduce((rows, val, i) => {
  const row = Math.floor(i / 4);
  const col = i % 4;
  rows[row] = rows[row] || [];
  rows[row][col] = val;
  return rows;
}, [])

Worked Example 2: Adding a Batch Dimension

Neural networks expect batched inputs. Here's how to add a batch dimension:

# Image: (height, width, channels)
img = t.randn(28, 28, 1)
print(f"Image: {img.shape}")  # torch.Size([28, 28, 1])

# Network expects: (batch, channels, height, width)
batched = einops.rearrange(img, "h w c -> 1 c h w")
print(f"Batched: {batched.shape}")  # torch.Size([1, 1, 28, 28])

What happened? The pattern names the input axes h, w, c and writes the output as 1 c h w: einops moves the channel axis in front of the spatial axes and inserts a new batch axis of size 1 (the literal 1 in the output).
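
For contrast, the same move in raw PyTorch takes a permute plus an unsqueeze; this sketch is one equivalent way to do it, not the only one:

# Move channels first, then add a leading batch axis of size 1
batched_torch = img.permute(2, 0, 1).unsqueeze(0)
assert batched_torch.shape == (1, 1, 28, 28)
assert t.equal(batched, batched_torch)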


Worked Example 3: Flattening for Linear Layers

A common pattern: flatten spatial dimensions before a linear layer.

# (batch, channels, height, width)
x = t.randn(32, 3, 28, 28)
print(f"Before: {x.shape}")  # torch.Size([32, 3, 28, 28])

# Flatten to (batch, features)
flat = einops.rearrange(x, "b c h w -> b (c h w)")
print(f"After: {flat.shape}")  # torch.Size([32, 2352])

What happened? The parentheses on the output side pack channels, height, and width into a single axis of size 3 × 28 × 28 = 2352, leaving the batch axis untouched.
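
The raw PyTorch equivalent hides that intent in a -1; a small comparison sketch:

# Keep dim 0, merge everything else (same row-major order as the einops pattern)
flat_torch = x.reshape(x.shape[0], -1)
assert flat_torch.shape == (32, 2352)
assert t.equal(flat, flat_torch)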


Worked Example 4: Matrix Multiplication with einsum

This is the pattern you'll use constantly.

# Standard matrix multiply: (M, K) @ (K, N) -> (M, N)
A = t.randn(3, 4)  # 3×4
B = t.randn(4, 5)  # 4×5

C = einops.einsum(A, B, "m k, k n -> m n")
print(f"Result: {C.shape}")  # torch.Size([3, 5])

# Verify: same as A @ B
assert t.allclose(C, A @ B)

What happened? In "m k, k n -> m n", the axis k appears in both inputs but not in the output, so the elementwise products are summed over k. That is exactly matrix multiplication.
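
The same pattern extends to batched matrix multiplication: share a b axis across both inputs and the output, and only k gets summed. A sketch with arbitrary sizes:

A2 = t.randn(8, 3, 4)
B2 = t.randn(8, 4, 5)

# b is shared and kept; k is shared and summed over
C2 = einops.einsum(A2, B2, "b m k, b k n -> b m n")
assert C2.shape == (8, 3, 5)
assert t.allclose(C2, A2 @ B2, atol=1e-5)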


Worked Example 5: The Linear Layer Pattern

This is what you'll use to implement nn.Linear:

batch_size = 32
in_features = 768
out_features = 256

x = t.randn(batch_size, in_features)
W = t.randn(out_features, in_features)

# Linear layer computation: x @ W.T
out = einops.einsum(x, W, "batch in_f, out_f in_f -> batch out_f")
print(f"Output: {out.shape}")  # torch.Size([32, 256])

# Verify
assert t.allclose(out, x @ W.T, atol=1e-5)
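
Putting it together, a minimal nn.Linear-style module built on this einsum can look like the sketch below. The class name and the 1/sqrt(in_features) init scale are illustrative choices, not the canonical implementation:

class EinsumLinear(t.nn.Module):
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        # Same weight shape as nn.Linear: (out_features, in_features)
        self.weight = t.nn.Parameter(t.randn(out_features, in_features) / in_features**0.5)
        self.bias = t.nn.Parameter(t.zeros(out_features))

    def forward(self, x: t.Tensor) -> t.Tensor:
        return einops.einsum(x, self.weight, "batch in_f, out_f in_f -> batch out_f") + self.bias

layer = EinsumLinear(in_features, out_features)
assert layer(x).shape == (batch_size, out_features)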

Why this matters for your capstone: attention is built from this exact pattern. Computing attention scores and applying them to the values are both einsum calls:

scores = einops.einsum(Q, K, "b h sq d, b h sk d -> b h sq sk")
output = einops.einsum(attention_weights, V, "b h sq sk, b h sk d -> b h sq d")

When you analyze sycophancy in Chapter 1, you'll look at what Q is "querying" and what K "contains." Understanding einsum now means you can read attention code later.
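
You can sanity-check those two lines with dummy tensors before ever touching a real model. The sizes below are arbitrary placeholders, not the dimensions of any particular model:

n_batch, n_heads, seq_len, d_head = 2, 4, 10, 16
Q = t.randn(n_batch, n_heads, seq_len, d_head)
K = t.randn(n_batch, n_heads, seq_len, d_head)
V = t.randn(n_batch, n_heads, seq_len, d_head)

# Scores: one number per (query position, key position) pair, per head
# (real attention also divides by sqrt(d_head) before the softmax; omitted to keep the shape check minimal)
scores = einops.einsum(Q, K, "b h sq d, b h sk d -> b h sq sk")
attention_weights = scores.softmax(dim=-1)

# Output: a weighted mix of value vectors for every query position
output = einops.einsum(attention_weights, V, "b h sq sk, b h sk d -> b h sq d")

assert scores.shape == (n_batch, n_heads, seq_len, seq_len)
assert output.shape == Q.shape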


🎓 Tyla's Exercise

Predict the output shapes without running the code:

x = t.randn(4, 8, 16)  # (batch, sequence, features)

# 1. einops.rearrange(x, "b s f -> s b f")
# 2. einops.rearrange(x, "b s f -> b (s f)")
# 3. einops.einsum(x, x, "b s f, b s f -> b")

Then verify your predictions.

After completing: What did each operation teach you about how transformers process sequences?


💻 Aaliyah's Exercise

Implement these without looking at the solutions:

def batch_and_channels_first(img):
    """
    Input: (H, W, C) image
    Output: (1, C, H, W) batched tensor
    """
    # Your code here
    pass

def flatten_for_linear(x):
    """
    Input: (B, C, H, W) batch of images
    Output: (B, C*H*W) flattened for linear layer
    """
    # Your code here
    pass

📚 Maneesha's Reflection

Don't implement anything. Instead, answer:

  1. Why does PyTorch use (batch, channels, height, width) order instead of (batch, height, width, channels)?

  2. What's the pedagogical insight here? Why is einops easier to learn than raw PyTorch operations?

  3. How would you teach einsum to someone with no math background?


Capstone Connection

In your sycophancy evaluation, every tensor you inspect — activations, attention scores, attention outputs — will be reshaped and combined with exactly these operations.

Master this now. It's invisible infrastructure for everything that follows.