Transformers: The Attention Mechanism

Attention is the transformer's core innovation. It lets every position talk to every other position.


The Attention Question

At each position, the model asks: "What information from other positions is relevant here?"

"The cat sat on the mat because it was tired"

At "it": Which earlier word does "it" refer to?
Attention should weight "cat" more heavily than "mat".

Queries, Keys, and Values

Three projections of each token:

Q = x @ W_Q  # (batch, seq, d_head)
K = x @ W_K  # (batch, seq, d_head)
V = x @ W_V  # (batch, seq, d_head)

The attention score measures how well a query matches a key. The output is a weighted sum of the values, with weights given by the softmax of the scores.


The Attention Formula

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V$$

Step by step:

# 1. Compute attention scores
scores = Q @ K.transpose(-2, -1)  # (batch, seq_q, seq_k)

# 2. Scale (prevents softmax saturation)
scores = scores / np.sqrt(d_head)

# 3. Apply softmax (attention weights sum to 1)
pattern = F.softmax(scores, dim=-1)  # (batch, seq_q, seq_k)

# 4. Weighted sum of values
output = pattern @ V  # (batch, seq_q, d_head)
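
Putting the four steps together, here is a minimal self-contained sketch on random tensors (shapes chosen arbitrarily for illustration), which also checks the claim that each row of attention weights sums to 1:

import numpy as np
import torch as t
import torch.nn.functional as F

batch, seq, d_head = 2, 5, 16
Q = t.randn(batch, seq, d_head)
K = t.randn(batch, seq, d_head)
V = t.randn(batch, seq, d_head)

scores = Q @ K.transpose(-2, -1) / np.sqrt(d_head)  # (batch, seq, seq)
pattern = F.softmax(scores, dim=-1)                 # each row sums to 1
output = pattern @ V                                # (batch, seq, d_head)

assert t.allclose(pattern.sum(dim=-1), t.ones(batch, seq))
print(output.shape)  # torch.Size([2, 5, 16])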

Implementing Single-Head Attention

import numpy as np
import torch as t
import torch.nn as nn
import torch.nn.functional as F

class Attention(nn.Module):
    def __init__(self, d_model: int, d_head: int):
        super().__init__()
        self.d_head = d_head

        self.W_Q = nn.Linear(d_model, d_head, bias=False)
        self.W_K = nn.Linear(d_model, d_head, bias=False)
        self.W_V = nn.Linear(d_model, d_head, bias=False)
        self.W_O = nn.Linear(d_head, d_model, bias=False)

    def forward(self, x: t.Tensor) -> t.Tensor:
        Q = self.W_Q(x)
        K = self.W_K(x)
        V = self.W_V(x)

        scores = Q @ K.transpose(-2, -1) / np.sqrt(self.d_head)
        pattern = F.softmax(scores, dim=-1)
        out = pattern @ V

        return self.W_O(out)
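
A quick sanity check with toy dimensions (chosen arbitrarily): the output has the same shape as the input, because W_O projects from d_head back to d_model:

attn = Attention(d_model=64, d_head=16)
x = t.randn(2, 10, 64)   # (batch, seq, d_model)
print(attn(x).shape)     # torch.Size([2, 10, 64])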

Multi-Head Attention

Multiple attention "heads" in parallel, each learning different patterns:

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_head = d_model // n_heads
        self.n_heads = n_heads

        self.W_Q = nn.Linear(d_model, d_model, bias=False)
        self.W_K = nn.Linear(d_model, d_model, bias=False)
        self.W_V = nn.Linear(d_model, d_model, bias=False)
        self.W_O = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x: t.Tensor) -> t.Tensor:
        batch, seq, _ = x.shape

        # Project and reshape to (batch, n_heads, seq, d_head)
        Q = self.W_Q(x).view(batch, seq, self.n_heads, self.d_head).transpose(1, 2)
        K = self.W_K(x).view(batch, seq, self.n_heads, self.d_head).transpose(1, 2)
        V = self.W_V(x).view(batch, seq, self.n_heads, self.d_head).transpose(1, 2)

        # Attention per head
        scores = Q @ K.transpose(-2, -1) / np.sqrt(self.d_head)
        pattern = F.softmax(scores, dim=-1)
        out = pattern @ V

        # Concatenate heads and project
        out = out.transpose(1, 2).contiguous().view(batch, seq, -1)
        return self.W_O(out)
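
The same kind of check works here (again with arbitrary toy sizes): splitting into heads and concatenating them back preserves the input shape:

mha = MultiHeadAttention(d_model=64, n_heads=8)
x = t.randn(2, 10, 64)   # (batch, seq, d_model)
print(mha(x).shape)      # torch.Size([2, 10, 64])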

Causal Masking

For autoregressive models (like GPT), each position can only attend to itself and earlier positions:

def causal_mask(seq_len: int) -> t.Tensor:
    # Lower-triangular matrix: 1 where attention is allowed, 0 elsewhere
    allowed = t.tril(t.ones(seq_len, seq_len))
    # Additive mask: allowed positions → 0, future positions → -inf
    return t.zeros(seq_len, seq_len).masked_fill(allowed == 0, float('-inf'))

# Apply before softmax
scores = scores + causal_mask(seq_len)
pattern = F.softmax(scores, dim=-1)

The mask ensures position i can only see positions 0, 1, ..., i.
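
A short sketch to confirm this on random inputs (toy sizes): after masking, the attention pattern is lower triangular, so no weight falls on future positions:

seq_len, d_head = 6, 16
Q = t.randn(1, seq_len, d_head)
K = t.randn(1, seq_len, d_head)

scores = Q @ K.transpose(-2, -1) / np.sqrt(d_head)
scores = scores + causal_mask(seq_len)
pattern = F.softmax(scores, dim=-1)

# Everything above the diagonal should be zero after the -inf mask
assert t.allclose(pattern.triu(diagonal=1), t.zeros_like(pattern))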


Visualizing Attention Patterns

import circuitsvis as cv

# Get attention patterns from a model
_, cache = model.run_with_cache(tokens)
attention_pattern = cache["pattern", layer_idx]  # (batch, n_heads, seq_q, seq_k)

# Visualize
cv.attention.attention_patterns(
    tokens=model.to_str_tokens(tokens),
    attention=attention_pattern[0],  # First batch item
)

Different heads learn different patterns: some attend mostly to the previous token, some to delimiter tokens such as punctuation, and some to semantically related tokens much earlier in the sequence.
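
One way to quantify this, assuming the cache and attention_pattern variables from the snippet above: a "previous-token head" puts most of its attention mass on the sub-diagonal of its pattern. This is an illustrative heuristic, not the only diagnostic:

# Average attention each head places on the immediately preceding token
pattern = attention_pattern[0]  # (n_heads, seq_q, seq_k), first batch item
prev_token_score = pattern.diagonal(offset=-1, dim1=-2, dim2=-1).mean(-1)
for head, score in enumerate(prev_token_score.tolist()):
    print(f"head {head}: previous-token attention ≈ {score:.2f}")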


Capstone Connection

Attention and sycophancy:

Attention patterns reveal what the model "looks at" when generating:

# When generating a response to "What do you think about X?"
# Where does the model attend?

# Sycophantic model might:
# - Strongly attend to sentiment words from user
# - Attend to "you think" more than factual context

# Honest model might:
# - Attend more to factual claims
# - Balance attention between user query and context
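
As a sketch of how you might start measuring this (the helper and the opinion_positions list below are hypothetical; you would pick the positions by hand after inspecting model.to_str_tokens(tokens)):

# Hypothetical sketch: what fraction of each head's attention (from the final
# query position, where the response is being generated) lands on the tokens
# carrying the user's stated opinion?
def attention_mass_on(pattern, positions):
    # pattern: (n_heads, seq_q, seq_k), e.g. attention_pattern[0] from above
    last_query = pattern[:, -1, :]           # (n_heads, seq_k)
    return last_query[:, positions].sum(-1)  # (n_heads,)

# opinion_positions = [...]  # hand-labelled key positions (hypothetical)
# print(attention_mass_on(attention_pattern[0], opinion_positions))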

In Chapter 1.2, you'll use TransformerLens to inspect these patterns.


🎓 Tyla's Exercise

  1. Why do we divide by $\sqrt{d_k}$? (Hint: if the entries of $Q$ and $K$ are independent with unit variance, what is the variance of a single query-key dot product?)

  2. Prove that attention is permutation-equivariant: if we permute the input, the output gets permuted the same way.

  3. Multi-head attention with n_heads=1 and d_head=d_model is equivalent to single-head attention. What's the benefit of multiple smaller heads?


💻 Aaliyah's Exercise

Implement attention from scratch:

def manual_attention(Q, K, V, mask=None):
    """
    Q: (batch, n_heads, seq_q, d_head)
    K: (batch, n_heads, seq_k, d_head)
    V: (batch, n_heads, seq_k, d_head)
    mask: optional additive mask, (seq_q, seq_k) or (batch, n_heads, seq_q, seq_k);
          0 where attention is allowed, -inf where masked

    Returns: (batch, n_heads, seq_q, d_head)
    """
    pass

# Verify against PyTorch's F.scaled_dot_product_attention
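
One way to run that check, assuming PyTorch 2.x (where F.scaled_dot_product_attention is available) and assuming your mask argument is additive (0 where allowed, -inf where masked), as with causal_mask above:

import torch as t
import torch.nn.functional as F

Q = t.randn(2, 4, 10, 16)  # (batch, n_heads, seq_q, d_head)
K = t.randn(2, 4, 10, 16)
V = t.randn(2, 4, 10, 16)

# Unmasked case
assert t.allclose(manual_attention(Q, K, V),
                  F.scaled_dot_product_attention(Q, K, V), atol=1e-5)

# Causal case
assert t.allclose(manual_attention(Q, K, V, mask=causal_mask(10)),
                  F.scaled_dot_product_attention(Q, K, V, is_causal=True), atol=1e-5)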

📚 Maneesha's Reflection

  1. The attention mechanism can be viewed as a "soft database lookup." Q is the query, K is the key, V is the value. How does this analogy help or hurt understanding?

  2. Attention lets every position see every other position. What's the computational cost? When might this be a problem?

  3. Human attention is selective and limited. Machine attention is (by default) complete. What are the implications for how these systems process information?