Chapter 0: Mastering einops
Before you can understand transformers, you need to think in tensors.
einops is the tool that makes tensor operations readable. Instead of juggling .reshape(), .permute(), and .transpose(), you describe the transformation you want in words.
The Mental Model
einops.rearrange transforms tensor shapes by describing:
- The input dimensions (left side of arrow)
- The output dimensions (right side of arrow)
- How dimensions combine (parentheses) or split (named values)
einops.einsum computes any combination of:
- Matrix multiplication
- Dot products
- Summing over dimensions
The pattern: dimensions that appear in the inputs but NOT in the output get summed over.
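Here's a minimal sketch of that rule before the worked examples (just an illustration): a dot product is the extreme case where every dimension is summed away.
import einops
import torch as t
a = t.randn(4)
b = t.randn(4)
# "i, i ->": i appears in the inputs but not in the output, so it gets summed over
dot = einops.einsum(a, b, "i, i ->")
assert t.allclose(dot, (a * b).sum())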
Worked Example 1: Reshaping Tensors
import einops
import torch as t
# Start with a flat tensor
x = t.arange(12) # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
print(f"Original: {x.shape}") # torch.Size([12])
# Reshape to matrix: "(h w) -> h w"
y = einops.rearrange(x, "(h w) -> h w", h=3, w=4)
print(f"Matrix: {y.shape}") # torch.Size([3, 4])
print(y)
# tensor([[ 0, 1, 2, 3],
# [ 4, 5, 6, 7],
# [ 8, 9, 10, 11]])
What happened?
- (h w) on the left: the 12 elements are grouped as h × w (3 × 4 = 12)
- h w on the right: h and w become separate dimensions
- Parentheses on the left = "collapsed together"
- No parentheses on the right = "separate dimensions"
💻 Aaliyah's Translation: Think of it like JavaScript:
// Conceptually similar to:
arr.reduce((rows, val, i) => {
const row = Math.floor(i / 4);
const col = i % 4;
rows[row] = rows[row] || [];
rows[row][col] = val;
return rows;
}, [])
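A small cross-check against plain PyTorch, continuing with the x and y defined above (just a sanity check, not new machinery):
# The split follows row-major order, so it matches .reshape()
assert t.equal(y, x.reshape(3, 4))
# And it reverses cleanly: collapsing (h w) gives back the original flat tensor
assert t.equal(einops.rearrange(y, "h w -> (h w)"), x)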
Worked Example 2: Adding a Batch Dimension
Neural networks expect batched inputs. Here's how to add a batch dimension:
# Image: (height, width, channels)
img = t.randn(28, 28, 1)
print(f"Image: {img.shape}") # torch.Size([28, 28, 1])
# Network expects: (batch, channels, height, width)
batched = einops.rearrange(img, "h w c -> 1 c h w")
print(f"Batched: {batched.shape}") # torch.Size([1, 1, 28, 28])
What happened?
h w c -> 1 c h w:
- Added a dimension of size 1 at the front (batch)
- Reordered c to come before h and w (channels first)
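The same pattern handles a whole batch of images at once. A minimal sketch (the sizes 16 × 28 × 28 × 3 are made up for illustration):
# A batch of 16 RGB images in channels-last layout: (batch, height, width, channels)
imgs = t.randn(16, 28, 28, 3)
# Move channels ahead of the spatial dimensions: channels-last -> channels-first
imgs_cf = einops.rearrange(imgs, "b h w c -> b c h w")
print(f"Channels-first: {imgs_cf.shape}") # torch.Size([16, 3, 28, 28])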
Worked Example 3: Flattening for Linear Layers
A common pattern: flatten spatial dimensions before a linear layer.
# (batch, channels, height, width)
x = t.randn(32, 3, 28, 28)
print(f"Before: {x.shape}") # torch.Size([32, 3, 28, 28])
# Flatten to (batch, features)
flat = einops.rearrange(x, "b c h w -> b (c h w)")
print(f"After: {flat.shape}") # torch.Size([32, 2352])
What happened?
- b stays separate (batch dimension)
- (c h w) collapses to a single dimension: 3 × 28 × 28 = 2352
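The reverse direction comes up too (for example, reshaping a linear layer's output back into an image grid). A minimal sketch, continuing with x and flat from above: einops can't infer c, h, and w from 2352 alone, so you pass them explicitly.
# Undo the flatten by naming the sizes einops can't infer on its own
restored = einops.rearrange(flat, "b (c h w) -> b c h w", c=3, h=28, w=28)
print(f"Restored: {restored.shape}") # torch.Size([32, 3, 28, 28])
assert t.equal(restored, x) # the round trip is exact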
Worked Example 4: Matrix Multiplication with einsum
This is the pattern you'll use constantly.
# Standard matrix multiply: (M, K) @ (K, N) -> (M, N)
A = t.randn(3, 4) # 3×4
B = t.randn(4, 5) # 4×5
C = einops.einsum(A, B, "m k, k n -> m n")
print(f"Result: {C.shape}") # torch.Size([3, 5])
# Verify: same as A @ B
assert t.allclose(C, A @ B)
What happened?
m k, k n -> m n:
- A has dimensions (m, k)
- B has dimensions (k, n)
- Output has dimensions (m, n)
- k appears on the left but NOT on the right → sum over k
- This is exactly matrix multiplication!
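The same pattern extends to batched matrix multiplication with no new syntax; just carry the batch dimension through. A minimal sketch (redefining A and B as batched tensors, with shapes chosen only for illustration):
# Batched matrix multiply: (B, M, K) @ (B, K, N) -> (B, M, N)
A = t.randn(8, 3, 4)
B = t.randn(8, 4, 5)
C = einops.einsum(A, B, "b m k, b k n -> b m n")
print(f"Result: {C.shape}") # torch.Size([8, 3, 5])
assert t.allclose(C, A @ B, atol=1e-5) # same as PyTorch's batched @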
Worked Example 5: The Linear Layer Pattern
This is what you'll use to implement nn.Linear:
batch_size = 32
in_features = 768
out_features = 256
x = t.randn(batch_size, in_features)
W = t.randn(out_features, in_features)
# Linear layer computation: x @ W.T
out = einops.einsum(x, W, "batch in_f, out_f in_f -> batch out_f")
print(f"Output: {out.shape}") # torch.Size([32, 256])
# Verify
assert t.allclose(out, x @ W.T, atol=1e-5)
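To tie this back to PyTorch's own module, here's a minimal sketch comparing against nn.Linear, with the weight copied over and the bias zeroed so only the matrix multiply is compared (the bias handling is just for this check):
import torch.nn as nn
# nn.Linear stores its weight as (out_features, in_features), the same shape as W
linear = nn.Linear(in_features, out_features)
with t.no_grad():
    linear.weight.copy_(W)
    linear.bias.zero_() # remove the bias so only the matmul is compared
assert t.allclose(linear(x), out, atol=1e-5)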
Why this matters for your capstone:
Attention uses this pattern:
scores = einops.einsum(Q, K, "b h sq d, b h sk d -> b h sq sk")
output = einops.einsum(attention_weights, V, "b h sq sk, b h sk d -> b h sq d")
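To see those shapes concretely, here's a minimal sketch with dummy tensors (the sizes are made up, and real attention also scales the scores by sqrt(d) and masks them; both are skipped here):
b, h, seq_len, d = 2, 4, 10, 16 # made-up sizes for illustration
Q = t.randn(b, h, seq_len, d)
K = t.randn(b, h, seq_len, d)
V = t.randn(b, h, seq_len, d)
# Scores: every query position against every key position
scores = einops.einsum(Q, K, "b h sq d, b h sk d -> b h sq sk")
print(f"Scores: {scores.shape}") # torch.Size([2, 4, 10, 10])
# Normalize over keys, then take a weighted sum of the values
attention_weights = scores.softmax(dim=-1)
output = einops.einsum(attention_weights, V, "b h sq sk, b h sk d -> b h sq d")
print(f"Output: {output.shape}") # torch.Size([2, 4, 10, 16])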
When you analyze sycophancy in Chapter 1, you'll look at what Q is "querying" and what K "contains." Understanding einsum now means you can read attention code later.
🎓 Tyla's Exercise
Predict the output shapes without running the code:
x = t.randn(4, 8, 16) # (batch, sequence, features)
# 1. einops.rearrange(x, "b s f -> s b f")
# 2. einops.rearrange(x, "b s f -> b (s f)")
# 3. einops.einsum(x, x, "b s f, b s f -> b")
Then verify your predictions.
After completing: What did each operation teach you about how transformers process sequences?
💻 Aaliyah's Exercise
Implement these without looking at the solutions:
def batch_and_channels_first(img):
"""
Input: (H, W, C) image
Output: (1, C, H, W) batched tensor
"""
# Your code here
pass
def flatten_for_linear(x):
"""
Input: (B, C, H, W) batch of images
Output: (B, C*H*W) flattened for linear layer
"""
# Your code here
pass
📚 Maneesha's Reflection
Don't implement anything. Instead, answer:
- Why does PyTorch use (batch, channels, height, width) order instead of (batch, height, width, channels)?
- What's the pedagogical insight here? Why is einops easier to learn than raw PyTorch operations?
- How would you teach einsum to someone with no math background?
Capstone Connection
In your sycophancy evaluation:
- You'll use einsum to compute attention scores
- You'll use rearrange to extract specific attention heads
- You'll need to understand tensor shapes to interpret what the model is "looking at"
Master this now. It's invisible infrastructure for everything that follows.