ResNets: Skip Connections
Deeper networks should be at least as powerful as shallower ones. In practice they weren't, until skip connections.
This chapter explains the degradation problem and how residual connections solve it.
The Degradation Problem
Intuition: A 56-layer network should be at least as good as a 20-layer network. The extra layers could just learn the identity function.
Reality: Deeper networks performed worse on both training and test sets.
Test error:
- 20-layer network: 6.7%
- 56-layer network: 7.8% ← Worse!
This wasn't overfitting (training error was also worse). The network couldn't even learn to copy its input through the extra layers.
The Residual Solution
Instead of learning $H(x)$, learn $F(x) = H(x) - x$.
The output becomes: $y = F(x) + x$
x ────┬──► Conv ──► ReLU ──► Conv ──► + ──► ReLU ──► y
      │                               ↑
      └───────────────────────────────┘
                skip connection
Why is this easier?
If the optimal transformation is close to identity, learning $F(x) \approx 0$ is easier than learning $H(x) \approx x$.
Pushing weights toward zero is exactly what weight decay already does, so the network's default behaviour is close to the identity.
Implementing a Residual Block
import torch as t
import torch.nn as nn
import torch.nn.functional as F
# If you implemented your own Conv2d, Linear, and MaxPool2d earlier in the
# course, import those instead of the torch.nn versions below.
from torch.nn import Conv2d, Linear, MaxPool2d

class ResidualBlock(nn.Module):
    def __init__(self, channels: int):
        super().__init__()
        self.conv1 = Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x: t.Tensor) -> t.Tensor:
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = F.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = out + identity  # Skip connection!
        out = F.relu(out)
        return out
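A quick sanity check, as a minimal sketch assuming the imports above: because both convolutions use 3x3 kernels with padding 1 and the channel count is unchanged, the block preserves the input shape, which is exactly what lets us add the identity back in.

block = ResidualBlock(channels=64)
x = t.randn(8, 64, 32, 32)        # (batch, channels, height, width)
assert block(x).shape == x.shape  # shape is preserved end to end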
Downsampling Blocks
When the spatial dimensions or channel count change, the identity no longer matches the block's output shape, so we need a projection shortcut:
class DownsamplingBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, stride: int = 2):
        super().__init__()
        self.conv1 = Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # 1x1 convolution to match dimensions
        self.shortcut = nn.Sequential(
            Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, x: t.Tensor) -> t.Tensor:
        identity = self.shortcut(x)  # Project to match dimensions
        out = self.conv1(x)
        out = self.bn1(out)
        out = F.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = out + identity
        out = F.relu(out)
        return out
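A shape sketch for the downsampling case, assuming the same imports (the particular sizes are illustrative): with stride 2 the spatial dimensions halve while the channel count grows, and the 1x1 projection makes the shortcut match.

block = DownsamplingBlock(in_channels=64, out_channels=128, stride=2)
x = t.randn(8, 64, 32, 32)
assert block(x).shape == (8, 128, 16, 16)  # channels doubled, spatial dims halved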
Batch Normalization
Normalizes activations to have mean 0 and variance 1, then scales and shifts:
$$y = \gamma \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$
where $\mu$ and $\sigma^2$ are computed per channel over the batch (and over spatial positions, for BatchNorm2d), and $\gamma$, $\beta$ are learned scale and shift parameters.
Benefits:
- Allows higher learning rates
- Reduces sensitivity to initialization
- Acts as regularization
model.train()  # batch norm normalizes with the current batch's statistics
model.eval()   # batch norm normalizes with running averages collected during training
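A minimal check of the formula above, assuming the earlier imports (the tensor sizes are arbitrary): in training mode, with $\gamma$ initialized to 1 and $\beta$ to 0, each channel of the output has roughly zero mean and unit variance.

bn = nn.BatchNorm2d(3)                    # gamma starts at 1, beta at 0
x = t.randn(16, 3, 8, 8) * 5 + 2          # deliberately shifted and scaled input
out = bn(x)                               # training mode: uses batch statistics
print(out.mean(dim=(0, 2, 3)))                 # ≈ 0 for each channel
print(out.var(dim=(0, 2, 3), unbiased=False))  # ≈ 1 for each channel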
ResNet-34 Architecture
class ResNet34(nn.Module):
    def __init__(self, num_classes: int = 1000):
        super().__init__()
        # Initial layers
        self.conv1 = Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = MaxPool2d(kernel_size=3, stride=2, padding=1)
        # Residual blocks
        self.layer1 = self._make_layer(64, 64, 3)
        self.layer2 = self._make_layer(64, 128, 4, stride=2)
        self.layer3 = self._make_layer(128, 256, 6, stride=2)
        self.layer4 = self._make_layer(256, 512, 3, stride=2)
        # Classification head
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = Linear(512, num_classes)

    def _make_layer(self, in_ch: int, out_ch: int, num_blocks: int, stride: int = 1) -> nn.Sequential:
        layers = []
        # First block might downsample
        if stride != 1 or in_ch != out_ch:
            layers.append(DownsamplingBlock(in_ch, out_ch, stride))
        else:
            layers.append(ResidualBlock(out_ch))
        # Remaining blocks
        for _ in range(1, num_blocks):
            layers.append(ResidualBlock(out_ch))
        return nn.Sequential(*layers)

    def forward(self, x: t.Tensor) -> t.Tensor:
        x = self.maxpool(F.relu(self.bn1(self.conv1(x))))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x
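A smoke test for the full model, a sketch under the assumption that the blocks above are defined as written: a batch of 224x224 RGB images should come out as one logit vector per image.

model = ResNet34(num_classes=1000)
x = t.randn(2, 3, 224, 224)
logits = model(x)
assert logits.shape == (2, 1000)  # one row of class logits per image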
Capstone Connection
Residual connections in Transformers:
Every transformer layer uses residual connections:
# Attention sublayer
x = x + attention(layer_norm(x))
# MLP sublayer
x = x + mlp(layer_norm(x))
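Here is a minimal pre-norm block sketch that makes those two residual additions concrete. The class name and layer sizes are illustrative, not the capstone's actual code, and it uses PyTorch's built-in nn.MultiheadAttention rather than hand-rolled attention (assuming the imports from earlier in the chapter).

class PreNormBlock(nn.Module):
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ln2 = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
        )

    def forward(self, x: t.Tensor) -> t.Tensor:
        normed = self.ln1(x)
        x = x + self.attn(normed, normed, normed)[0]  # attention sublayer adds to the stream
        x = x + self.mlp(self.ln2(x))                 # MLP sublayer adds to the stream
        return x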
This creates a "residual stream" that information flows through. When analyzing sycophancy:
- The residual stream carries the "default" response
- Each layer adds or subtracts from it
- Sycophantic behavior might come from specific layers adding "agreement" features
Understanding residuals = understanding information flow in transformers.
🎓 Tyla's Exercise
1. Why does $F(x) + x$ help with gradient flow during backpropagation? (Hint: what is the gradient of the addition operation?)
2. In the original ResNet paper, the authors compared "plain" networks with residual networks. What did they observe about the training loss?
3. Why is batch normalization placed before the activation function in modern architectures (pre-activation ResNets)?
💻 Aaliyah's Exercise
Load pretrained weights into your ResNet:
def load_pretrained_resnet34():
    """
    1. Create your ResNet34 model
    2. Load weights from torchvision.models.resnet34(pretrained=True)
       (newer torchvision versions use weights="IMAGENET1K_V1" instead)
    3. Copy weights from their model to yours
    4. Verify outputs match on a test image

    The weight names will be different - you'll need to map them.
    """
    pass
# Test on ImageNet validation image
# Should get reasonable predictions (dog, cat, etc.)
📚 Maneesha's Reflection
1. Skip connections can be seen as "ensembling" many shallower networks. How does this relate to the "lottery ticket hypothesis"?
2. The transition from AlexNet (2012) → VGG (2014) → ResNet (2015) shows increasingly deep networks. What were the key insights at each step?
3. If you were explaining the degradation problem to a student who hasn't seen it before, what intuition would you build first?