ResNets: Skip Connections

In principle, deeper networks should be at least as powerful as shallower ones. In practice they weren't, until skip connections.

This chapter explains the degradation problem and how residual connections solve it.


The Degradation Problem

Intuition: A 56-layer network should be at least as good as a 20-layer network. The extra layers could just learn the identity function.

Reality: Deeper networks performed worse on both training and test sets.

Test Error:
20-layer network: 6.7%
56-layer network: 7.8%  ← Worse!

This wasn't overfitting (training error was also worse). The network couldn't even learn to copy its input through the extra layers.


The Residual Solution

Instead of asking a stack of layers to learn the desired mapping $H(x)$ directly, have it learn the residual $F(x) = H(x) - x$.

The block's output is then: $y = F(x) + x$

         ┌──────── skip connection ────────┐
         │                                 ▼
    x ───┴──► Conv ──► ReLU ──► Conv ───► (+) ──► ReLU ──► y

Why is this easier?

If the optimal transformation is close to identity, learning $F(x) \approx 0$ is easier than learning $H(x) \approx x$.

Pushing weights toward zero is what weight decay already does naturally, so a residual block defaults to something close to the identity.
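As a quick sanity check, here is a minimal sketch (hypothetical toy example, using a zero-initialized linear layer to play the role of $F$): when $F$ contributes nothing, the residual form reduces exactly to the identity.

import torch as t
import torch.nn as nn

# Toy F: a linear layer with weights and bias set to zero,
# so the residual output y = F(x) + x is exactly x.
F_layer = nn.Linear(8, 8)
nn.init.zeros_(F_layer.weight)
nn.init.zeros_(F_layer.bias)

x = t.randn(4, 8)
y = F_layer(x) + x
assert t.allclose(y, x)  # the block starts out as the identity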


Implementing a Residual Block

import torch as t
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Conv2d, Linear, MaxPool2d

class ResidualBlock(nn.Module):
    def __init__(self, channels: int):
        super().__init__()
        self.conv1 = Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x: t.Tensor) -> t.Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = F.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out = out + identity  # Skip connection!
        out = F.relu(out)

        return out
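A quick shape check (a sketch, assuming the imports above): the block preserves its input's shape, so it can be stacked freely.

block = ResidualBlock(channels=64)
x = t.randn(1, 64, 56, 56)   # (batch, channels, height, width)
out = block(x)
print(out.shape)             # torch.Size([1, 64, 56, 56]) - same shape as the input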

Downsampling Blocks

When the spatial resolution or channel count changes, the identity shortcut no longer matches the main path's output shape, so we need a projection shortcut:

class DownsamplingBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, stride: int = 2):
        super().__init__()
        self.conv1 = Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # 1x1 convolution to match dimensions
        self.shortcut = nn.Sequential(
            Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
            nn.BatchNorm2d(out_channels)
        )

    def forward(self, x: t.Tensor) -> t.Tensor:
        identity = self.shortcut(x)  # Project to match dimensions

        out = self.conv1(x)
        out = self.bn1(out)
        out = F.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out = out + identity
        out = F.relu(out)

        return out
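For example (again a sketch, assuming the class above), halving the resolution while doubling the channels:

down = DownsamplingBlock(in_channels=64, out_channels=128, stride=2)
x = t.randn(1, 64, 56, 56)
out = down(x)
print(out.shape)   # torch.Size([1, 128, 28, 28]) - half the resolution, twice the channels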

Batch Normalization

Normalizes activations to have mean 0 and variance 1, then scales and shifts:

$$y = \gamma \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$

where $\mu$ and $\sigma^2$ are computed per channel over the batch (and over the spatial dimensions, for nn.BatchNorm2d), and $\gamma$, $\beta$ are learned per-channel scale and shift parameters.
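As a quick check (a minimal sketch, assuming the imports above; at initialization $\gamma = 1$ and $\beta = 0$, so the learned scale and shift do nothing yet), we can reproduce nn.BatchNorm2d's training-mode output by hand:

bn = nn.BatchNorm2d(3)          # training mode by default; gamma=1, beta=0 at init
x = t.randn(8, 3, 4, 4)

mu = x.mean(dim=(0, 2, 3), keepdim=True)                    # per-channel mean
var = x.var(dim=(0, 2, 3), keepdim=True, unbiased=False)    # per-channel (biased) variance
manual = (x - mu) / t.sqrt(var + bn.eps)

print(t.allclose(bn(x), manual, atol=1e-5))   # True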

Benefits:

  1. Allows higher learning rates
  2. Reduces sensitivity to initialization
  3. Acts as a mild regularizer

Batch norm behaves differently in training and evaluation mode:

# In training mode: normalizes with the current batch's statistics
# In eval mode: normalizes with running averages accumulated during training
model.train()  # batch norm uses batch stats
model.eval()   # batch norm uses running stats

ResNet-34 Architecture

class ResNet34(nn.Module):
    def __init__(self, num_classes: int = 1000):
        super().__init__()

        # Initial layers
        self.conv1 = Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Residual blocks
        self.layer1 = self._make_layer(64, 64, 3)
        self.layer2 = self._make_layer(64, 128, 4, stride=2)
        self.layer3 = self._make_layer(128, 256, 6, stride=2)
        self.layer4 = self._make_layer(256, 512, 3, stride=2)

        # Classification head
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = Linear(512, num_classes)

    def _make_layer(self, in_ch, out_ch, num_blocks, stride=1):
        layers = []
        # First block might downsample
        if stride != 1 or in_ch != out_ch:
            layers.append(DownsamplingBlock(in_ch, out_ch, stride))
        else:
            layers.append(ResidualBlock(out_ch))
        # Remaining blocks
        for _ in range(1, num_blocks):
            layers.append(ResidualBlock(out_ch))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.maxpool(F.relu(self.bn1(self.conv1(x))))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x
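A forward-pass shape check (a sketch, assuming the classes and imports above):

model = ResNet34(num_classes=1000)
x = t.randn(2, 3, 224, 224)    # a batch of two 224x224 RGB images
logits = model(x)
print(logits.shape)            # torch.Size([2, 1000])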

Capstone Connection

Residual connections in Transformers:

Every transformer layer uses residual connections:

# Attention sublayer
x = x + attention(layer_norm(x))

# MLP sublayer
x = x + mlp(layer_norm(x))
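As a minimal runnable sketch (hypothetical names, not any specific library's API; assuming the PyTorch imports from earlier), a pre-norm sublayer wrapper looks like this:

class PreNormResidual(nn.Module):
    """Wraps any sublayer (attention, MLP, ...) as x + sublayer(layer_norm(x))."""
    def __init__(self, d_model: int, sublayer: nn.Module):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.sublayer = sublayer

    def forward(self, x: t.Tensor) -> t.Tensor:
        return x + self.sublayer(self.norm(x))

# e.g. the MLP sublayer of a transformer block (dimensions are illustrative)
mlp = nn.Sequential(nn.Linear(512, 2048), nn.GELU(), nn.Linear(2048, 512))
block = PreNormResidual(512, mlp)
out = block(t.randn(2, 10, 512))   # (batch, sequence, d_model), shape preserved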

This creates a "residual stream" that information flows through. When analyzing sycophancy:

  1. The residual stream carries the "default" response
  2. Each layer adds or subtracts from it
  3. Sycophantic behavior might come from specific layers adding "agreement" features

Understanding residuals = understanding information flow in transformers.


🎓 Tyla's Exercise

  1. Why does $F(x) + x$ help with gradient flow during backpropagation? (Hint: What's the gradient of the addition operation?)

  2. In the original ResNet paper, they tried "plain" networks vs residual networks. What did they observe about the training loss?

  3. Why is batch normalization placed before the activation function in modern architectures (pre-activation ResNets)?


💻 Aaliyah's Exercise

Load pretrained weights into your ResNet:

def load_pretrained_resnet34():
    """
    1. Create your ResNet34 model
    2. Load weights from torchvision.models.resnet34(pretrained=True)
    3. Copy weights from their model to yours
    4. Verify outputs match on a test image

    The weight names will be different - you'll need to map them.
    """
    pass

# Test on ImageNet validation image
# Should get reasonable predictions (dog, cat, etc.)
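One possible starting point (a sketch, not a full solution): if your modules are defined in the same order as torchvision's, the two state dicts line up entry by entry and you can copy weights positionally. Note that torchvision's convolutions use bias=False, so the entry counts will only match if your blocks do the same; the weights= argument below is the newer torchvision API, while older versions use pretrained=True as in the docstring.

import torch.nn as nn
import torchvision

# Hypothetical helper: copy weights positionally, assuming both models define
# their parameters and buffers in the same order.
def copy_weights(my_model: nn.Module, pretrained: nn.Module) -> nn.Module:
    my_sd, their_sd = my_model.state_dict(), pretrained.state_dict()
    assert len(my_sd) == len(their_sd), "entry counts differ - check your architecture"
    mapped = {
        my_key: their_val
        for (my_key, _), (_, their_val) in zip(my_sd.items(), their_sd.items())
    }
    my_model.load_state_dict(mapped)
    return my_model

# model = copy_weights(ResNet34(), torchvision.models.resnet34(weights="IMAGENET1K_V1"))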

📚 Maneesha's Reflection

  1. Skip connections can be seen as "ensembling" many shallower networks. How does this relate to the "lottery ticket hypothesis"?

  2. The transition from AlexNet (2012) → VGG (2014) → ResNet (2015) shows increasingly deep networks. What were the key insights at each step?

  3. If you were explaining the degradation problem to a student who hasn't seen it before, what intuition would you build first?