Transformers: The Attention Mechanism
Attention is the transformer's core innovation. It lets every position talk to every other position.
The Attention Question
At each position, the model asks: "What information from other positions is relevant here?"
"The cat sat on the mat because it was tired"
At "it": Which earlier word does "it" refer to?
Attention should look at "cat" more than "mat"
Queries, Keys, and Values
Three projections of each token:
- Query (Q): "What am I looking for?"
- Key (K): "What do I contain?"
- Value (V): "What information should I contribute?"
```python
Q = x @ W_Q  # (batch, seq, d_head)
K = x @ W_K  # (batch, seq, d_head)
V = x @ W_V  # (batch, seq, d_head)
```
The attention score measures how well a query matches a key; the output is a weighted sum of the values, with weights given by the (softmaxed) attention scores.
The Attention Formula
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V$$
Step by step:
```python
# 1. Compute attention scores
scores = Q @ K.transpose(-2, -1)     # (batch, seq_q, seq_k)

# 2. Scale (prevents softmax saturation)
scores = scores / np.sqrt(d_head)

# 3. Apply softmax (attention weights sum to 1 over the key positions)
pattern = F.softmax(scores, dim=-1)  # (batch, seq_q, seq_k)

# 4. Weighted sum of values
output = pattern @ V                 # (batch, seq_q, d_head)
```
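Why step 2 helps, as a brief sketch of the usual argument (assuming the entries of Q and K are roughly independent with zero mean and unit variance): each dot product $q \cdot k$ is a sum of $d_k$ such products, so its variance grows linearly with the head dimension,

$$\operatorname{Var}(q \cdot k) = \operatorname{Var}\!\left(\sum_{i=1}^{d_k} q_i k_i\right) = \sum_{i=1}^{d_k} \operatorname{Var}(q_i k_i) = d_k.$$

Without the division by $\sqrt{d_k}$, the logits grow with $d_k$, the softmax saturates toward a one-hot distribution, and gradients through it become tiny.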
Implementing Single-Head Attention
```python
import numpy as np
import torch as t
import torch.nn as nn
import torch.nn.functional as F


class Attention(nn.Module):
    def __init__(self, d_model: int, d_head: int):
        super().__init__()
        self.d_head = d_head
        self.W_Q = nn.Linear(d_model, d_head, bias=False)
        self.W_K = nn.Linear(d_model, d_head, bias=False)
        self.W_V = nn.Linear(d_model, d_head, bias=False)
        self.W_O = nn.Linear(d_head, d_model, bias=False)

    def forward(self, x: t.Tensor) -> t.Tensor:
        # x: (batch, seq, d_model)
        Q = self.W_Q(x)                                           # (batch, seq, d_head)
        K = self.W_K(x)                                           # (batch, seq, d_head)
        V = self.W_V(x)                                           # (batch, seq, d_head)
        scores = Q @ K.transpose(-2, -1) / np.sqrt(self.d_head)   # (batch, seq, seq)
        pattern = F.softmax(scores, dim=-1)
        out = pattern @ V                                         # (batch, seq, d_head)
        return self.W_O(out)                                      # (batch, seq, d_model)
```
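A quick sanity check that the shapes line up (a minimal sketch; the sizes and the random input below are placeholder assumptions, not values used elsewhere in the course):

```python
attn = Attention(d_model=64, d_head=16)
x = t.randn(2, 10, 64)   # (batch=2, seq=10, d_model=64)
out = attn(x)
print(out.shape)         # torch.Size([2, 10, 64]) — W_O maps back to d_model
```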
Multi-Head Attention
Multiple attention "heads" in parallel, each learning different patterns:
```python
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_head = d_model // n_heads
        self.n_heads = n_heads
        self.W_Q = nn.Linear(d_model, d_model, bias=False)
        self.W_K = nn.Linear(d_model, d_model, bias=False)
        self.W_V = nn.Linear(d_model, d_model, bias=False)
        self.W_O = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x: t.Tensor) -> t.Tensor:
        batch, seq, _ = x.shape
        # Project and reshape to (batch, n_heads, seq, d_head)
        Q = self.W_Q(x).view(batch, seq, self.n_heads, self.d_head).transpose(1, 2)
        K = self.W_K(x).view(batch, seq, self.n_heads, self.d_head).transpose(1, 2)
        V = self.W_V(x).view(batch, seq, self.n_heads, self.d_head).transpose(1, 2)
        # Attention per head
        scores = Q @ K.transpose(-2, -1) / np.sqrt(self.d_head)
        pattern = F.softmax(scores, dim=-1)
        out = pattern @ V
        # Concatenate heads and project
        out = out.transpose(1, 2).contiguous().view(batch, seq, -1)
        return self.W_O(out)
```
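The view/transpose shuffling above is a common source of silent bugs. If you have einops installed (an assumption; it isn't used elsewhere in this section), the same two reshapes can be written more declaratively. `split_heads` and `merge_heads` are hypothetical helper names, just for illustration:

```python
import einops
import torch as t

def split_heads(x: t.Tensor, n_heads: int) -> t.Tensor:
    # Same as .view(batch, seq, n_heads, d_head).transpose(1, 2)
    return einops.rearrange(x, "batch seq (heads d_head) -> batch heads seq d_head", heads=n_heads)

def merge_heads(x: t.Tensor) -> t.Tensor:
    # Same as .transpose(1, 2).contiguous().view(batch, seq, -1)
    return einops.rearrange(x, "batch heads seq d_head -> batch seq (heads d_head)")
```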
Causal Masking
For autoregressive models (GPT), positions can only attend to earlier positions:
```python
def causal_mask(seq_len: int) -> t.Tensor:
    # Lower triangular matrix: True where attention is allowed
    allowed = t.tril(t.ones(seq_len, seq_len, dtype=t.bool))
    # Convert to an additive attention mask (allowed → 0, disallowed → -inf)
    return t.zeros(seq_len, seq_len).masked_fill(~allowed, float('-inf'))

# Apply before softmax
scores = scores + causal_mask(seq_len)
pattern = F.softmax(scores, dim=-1)
```
The mask ensures position i can only see positions 0, 1, ..., i.
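For instance, with seq_len=4 the additive mask looks like this (rows are query positions, columns are key positions; the exact spacing of torch's printout may differ):

```python
print(causal_mask(4))
# tensor([[0., -inf, -inf, -inf],
#         [0.,   0., -inf, -inf],
#         [0.,   0.,   0., -inf],
#         [0.,   0.,   0.,   0.]])
```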
Visualizing Attention Patterns
```python
import circuitsvis as cv

# Get attention patterns from a model
_, cache = model.run_with_cache(tokens)
attention_pattern = cache["pattern", layer_idx]  # (batch, n_heads, seq_q, seq_k)

# Visualize
cv.attention.attention_patterns(
    tokens=model.to_str_tokens(tokens),
    attention=attention_pattern[0],  # First batch item
)
```
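For a quick static look at a single head without the interactive widget, a plain heatmap also works (a sketch; `head_idx` is an arbitrary choice here, and the `.detach().cpu()` is just defensive):

```python
import matplotlib.pyplot as plt

head_idx = 4  # arbitrary head for illustration
plt.imshow(attention_pattern[0, head_idx].detach().cpu(), cmap="Blues")
plt.xlabel("Key position")
plt.ylabel("Query position")
plt.colorbar()
plt.title(f"Layer {layer_idx}, head {head_idx}")
plt.show()
```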
Different heads learn different patterns:
- Previous token heads: attend mostly to position i-1 (one way to detect these is sketched below)
- Induction heads: complete repeated patterns; if "A B" occurred earlier, a later "A" attends back to that "B"
- Positional heads: fixed patterns that depend on position rather than content
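One simple way to spot a previous-token head is to measure how much of its attention sits on the i-1 diagonal. This is a sketch that assumes you already have `attention_pattern` from the cache above; the 0.5 threshold is an arbitrary illustration, not a standard value:

```python
def prev_token_score(pattern: t.Tensor) -> t.Tensor:
    # pattern: (n_heads, seq_q, seq_k) for a single batch item.
    # Average attention each head puts on the token immediately before the query.
    return pattern.diagonal(offset=-1, dim1=-2, dim2=-1).mean(dim=-1)

scores = prev_token_score(attention_pattern[0])
for head, score in enumerate(scores):
    if score > 0.5:  # arbitrary threshold for illustration
        print(f"Head {head} looks like a previous-token head (score {score:.2f})")
```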
Capstone Connection
Attention and sycophancy:
Attention patterns reveal what the model "looks at" when generating:
When generating a response to "What do you think about X?", where does the model attend?

A sycophantic model might:
- Attend strongly to sentiment words from the user
- Attend to "you think" more than to the factual context

An honest model might:
- Attend more to the factual claims
- Balance attention between the user's query and the surrounding context
In Chapter 1.2, you'll use TransformerLens to inspect these patterns.
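A rough sketch of what such a comparison could look like in code (illustrative only: the split index `k`, the choice of the final query position, and the ratio are assumptions for the sake of the example, not a prescribed recipe):

```python
# Hypothetical split: tokens [0, k) are the user's stated opinion,
# tokens [k, n) are neutral factual context appended to the prompt.
k = 12                                          # hypothetical index where the context starts
pattern = cache["pattern", layer_idx][0]        # (n_heads, seq_q, seq_k)
last_query = pattern[:, -1, :]                  # attention from the final position, per head

opinion_mass = last_query[:, :k].sum(dim=-1)    # total attention on the opinion tokens
context_mass = last_query[:, k:].sum(dim=-1)    # total attention on the factual context
ratio = opinion_mass / (context_mass + 1e-9)    # per head; persistently large ratios merit a closer look
print(ratio)
```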
🎓 Tyla's Exercise
- Why do we divide by $\sqrt{d_k}$? (Hint: what is the variance of $Q \cdot K$ if the entries of Q and K have unit variance?)
- Prove that attention (ignoring the causal mask and any positional information) is permutation-equivariant: if we permute the input positions, the output gets permuted the same way.
- Multi-head attention with n_heads=1 and d_head=d_model is equivalent to single-head attention. What's the benefit of multiple smaller heads?
💻 Aaliyah's Exercise
Implement attention from scratch:
```python
def manual_attention(Q, K, V, mask=None):
    """
    Q: (batch, n_heads, seq_q, d_head)
    K: (batch, n_heads, seq_k, d_head)
    V: (batch, n_heads, seq_k, d_head)
    mask: optional (seq_q, seq_k) or (batch, n_heads, seq_q, seq_k)
    Returns: (batch, n_heads, seq_q, d_head)
    """
    pass

# Verify against PyTorch's F.scaled_dot_product_attention
```
📚 Maneesha's Reflection
- The attention mechanism can be viewed as a "soft database lookup": Q is the query, K is the key, V is the value. How does this analogy help or hurt understanding?
- Attention lets every position see every other position. What's the computational cost? When might this be a problem?
- Human attention is selective and limited. Machine attention is (by default) complete. What are the implications for how these systems process information?