CNNs: Making Your Own Modules

Neural networks are made of modules. Understanding nn.Module is understanding PyTorch.

This chapter teaches you to build reusable components from scratch.


The nn.Module Pattern

Every PyTorch neural network component inherits from nn.Module:

import torch as t
import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self):  # plus whatever constructor arguments your module needs
        super().__init__()
        # Define parameters and sub-modules

    def forward(self, x):
        # Define computation
        return x  # return the computed output

The two key methods: __init__, where you register parameters and sub-modules, and forward, where you define the computation. You never call forward directly; calling the module itself (module(x)) dispatches to forward through __call__, which also runs any registered hooks.


Implementing ReLU

The simplest activation function:

$$\text{ReLU}(x) = \max(0, x)$$

class ReLU(nn.Module):
    def forward(self, x: t.Tensor) -> t.Tensor:
        return t.maximum(x, t.tensor(0.0))

No parameters, no __init__ needed. The module is just a wrapper around an operation.

Why wrap it in a module? Composability. We can include it in nn.Sequential.
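
A quick sanity check (a minimal sketch, assuming the imports above): our ReLU should agree with PyTorch's built-in nn.ReLU on random input.

x = t.randn(3, 4)
assert t.allclose(ReLU()(x), nn.ReLU()(x))   # identical outputs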


Implementing Linear

A linear layer: $y = xW^T + b$

class Linear(nn.Module):
    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        # Kaiming initialization
        k = 1 / in_features**0.5

        # Parameters MUST be wrapped in nn.Parameter
        weight = k * (2 * t.rand(out_features, in_features) - 1)
        self.weight = nn.Parameter(weight)

        if bias:
            b = k * (2 * t.rand(out_features) - 1)
            self.bias = nn.Parameter(b)
        else:
            self.bias = None

    def forward(self, x: t.Tensor) -> t.Tensor:
        out = x @ self.weight.T
        if self.bias is not None:
            out = out + self.bias
        return out
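
A quick shape check (a minimal sketch): the layer should map (batch, in_features) to (batch, out_features).

layer = Linear(in_features=3, out_features=5)
x = t.randn(2, 3)            # batch of 2, 3 features each
print(layer(x).shape)        # torch.Size([2, 5])
print(layer.weight.shape)    # torch.Size([5, 3]), stored as (out_features, in_features)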

Why nn.Parameter?

nn.Parameter is a special tensor that:

  1. Gets registered with the module
  2. Appears in model.parameters()
  3. Gets moved when you call model.to(device)
  4. Gets saved/loaded with model.state_dict()

Without it, the optimizer never sees your weights, so they won't be trained!

# WRONG - won't be trained
self.weight = t.randn(out_features, in_features)

# RIGHT - will be trained
self.weight = nn.Parameter(t.randn(out_features, in_features))
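
You can see the difference directly. Here is a minimal sketch with two throwaway modules (the names Broken and Working are just illustrative): only the nn.Parameter version shows up in parameters().

class Broken(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = t.randn(3, 3)                 # plain tensor: not registered

class Working(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(t.randn(3, 3))   # registered as a parameter

print(len(list(Broken().parameters())))    # 0, so the optimizer would never see it
print(len(list(Working().parameters())))   # 1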

Kaiming Initialization

Why $\frac{1}{\sqrt{n_{in}}}$?

Each output is a sum of $n_{in}$ products: $y_i = \sum_j w_{ij} x_j$

If weights have variance $\sigma_w^2$ and inputs have variance $\sigma_x^2$:

$$\text{Var}(y_i) = n_{in} \cdot \sigma_w^2 \cdot \sigma_x^2$$

For variance to stay constant through layers, we need $\sigma_w^2 = \frac{1}{n_{in}}$.

Bad initialization → gradients explode or vanish → training fails.
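
You can verify the formula empirically. A minimal sketch, using Gaussian weights with variance exactly $1/n_{in}$ to match the derivation (the uniform init above has variance $1/(3 n_{in})$, so it shrinks activations slightly each layer rather than exploding them):

n_in = 256
x = t.randn(1000, n_in)
good, bad = x.clone(), x.clone()
for _ in range(10):
    good = good @ (t.randn(n_in, n_in) / n_in**0.5)   # Var(w) = 1/n_in
    bad = bad @ (t.randn(n_in, n_in) * 0.5)           # Var(w) = 0.25, far too large
print(good.std())   # stays close to 1
print(bad.std())    # roughly multiplied by 8 every layer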


Assembling a Simple Network

class SimpleMLP(nn.Module):
    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int):
        super().__init__()
        self.layers = nn.Sequential(
            Linear(in_dim, hidden_dim),
            ReLU(),
            Linear(hidden_dim, hidden_dim),
            ReLU(),
            Linear(hidden_dim, out_dim),
        )

    def forward(self, x: t.Tensor) -> t.Tensor:
        return self.layers(x)

# Usage
model = SimpleMLP(784, 128, 10)  # MNIST: 28*28 → 128 → 10 classes
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")

The extra_repr Method

Make your modules printable:

class Linear(nn.Module):
    # ... __init__ and forward ...

    def extra_repr(self) -> str:
        return f"in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}"

# Now printing the model is informative:
print(Linear(784, 128))
# Linear(in_features=784, out_features=128, bias=True)

Capstone Connection

Modules in Transformers:

Every transformer is built from modules:

class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.attention = MultiHeadAttention(d_model, n_heads)
        self.mlp = MLP(d_model)
        self.ln1 = LayerNorm(d_model)
        self.ln2 = LayerNorm(d_model)

    def forward(self, x):
        x = x + self.attention(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x

When analyzing sycophancy, you'll inspect specific components of this structure: the attention layers, the MLPs, and the activations that flow between them.

Understanding modules = understanding where to look.
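
One concrete payoff, sketched with the SimpleMLP from earlier: because sub-modules are registered automatically, you can walk the module tree by name and reach any component's weights or outputs.

model = SimpleMLP(784, 128, 10)
for name, module in model.named_modules():
    print(name or "(root)", type(module).__name__)
# (root)    SimpleMLP
# layers    Sequential
# layers.0  Linear
# layers.1  ReLU
# ...

first_weight = model.layers[0].weight   # direct access to a registered parameter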


🎓 Tyla's Exercise

  1. Implement Linear without looking at the solution. Verify it produces the same output as nn.Linear (up to numerical precision).

  2. Why does PyTorch store weights as (out_features, in_features) instead of (in_features, out_features)?

  3. Derive the Xavier (Glorot) uniform initialization bound: $\frac{\sqrt{6}}{\sqrt{n_{in} + n_{out}}}$


💻 Aaliyah's Exercise

Build and train a simple classifier:

class MNISTClassifier(nn.Module):
    """
    Architecture:
    - Flatten: (batch, 1, 28, 28) → (batch, 784)
    - Linear: 784 → 256
    - ReLU
    - Linear: 256 → 128
    - ReLU
    - Linear: 128 → 10
    """
    pass

# Train it:
def train(model, train_loader, epochs=5):
    """
    Use CrossEntropyLoss and Adam optimizer.
    Print accuracy after each epoch.
    Target: >95% accuracy.
    """
    pass

📚 Maneesha's Reflection

  1. The nn.Module pattern is an example of the "Composite" design pattern. How does this relate to the compositional nature of neural networks?

  2. Initialization is "setting up the student before the lesson." What's the pedagogical equivalent of bad initialization?

  3. Why do you think PyTorch makes parameters() automatic instead of requiring explicit registration?