Deep Q-Networks: Advanced Techniques

From vanilla DQN to Rainbow: improvements that made deep RL practical.


The Maximization Bias Problem

Vanilla DQN uses the same network to select and evaluate actions:

$$y = r + \gamma \max_{a'} Q(s', a'; \theta^-)$$

The Problem: If $Q(s', a_1)$ and $Q(s', a_2)$ are both noisy estimates of the true value (say, 0), taking the max will systematically overestimate:

$$\mathbb{E}[\max(\hat{Q}_1, \hat{Q}_2)] > \max(\mathbb{E}[\hat{Q}_1], \mathbb{E}[\hat{Q}_2])$$

This maximization bias compounds across the entire trajectory, leading to overoptimistic value estimates and suboptimal policies.
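
A quick numerical check of this inequality (a standalone sketch, not part of the DQN code): draw noisy zero-mean estimates for two actions and compare the average of their maximum with the maximum of their averages.

import numpy as np

rng = np.random.default_rng(0)

# True Q-values are both 0; the estimates are corrupted by zero-mean noise.
q1_hat = rng.normal(loc=0.0, scale=1.0, size=100_000)
q2_hat = rng.normal(loc=0.0, scale=1.0, size=100_000)

mean_of_max = np.maximum(q1_hat, q2_hat).mean()   # E[max(Q1_hat, Q2_hat)]
max_of_means = max(q1_hat.mean(), q2_hat.mean())  # max(E[Q1_hat], E[Q2_hat])

print(f"E[max]       = {mean_of_max:.3f}")   # ~0.56 for two unit Gaussians
print(f"max of means = {max_of_means:.3f}")  # ~0.00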


Double DQN

Key Insight: Decouple action selection from action evaluation.

# Vanilla DQN target:
# max_a' Q(s', a'; theta-)  # Same network selects AND evaluates

# Double DQN target:
# Q(s', argmax_a' Q(s', a'; theta); theta-)
#       ^^^^^^^^^^^^^^^^^^^^^^^^^
#       Online network SELECTS best action
#                                  ^^^^^^^
#                                  Target network EVALUATES it

In code:

def compute_double_dqn_target(
    q_network: QNetwork,
    target_network: QNetwork,
    next_obs: Tensor,
    rewards: Tensor,
    terminated: Tensor,
    gamma: float
) -> Tensor:
    """Compute Double DQN target values."""
    with torch.no_grad():
        # Online network selects best actions
        next_q_values = q_network(next_obs)
        best_actions = next_q_values.argmax(dim=-1)

        # Target network evaluates those actions
        target_q_values = target_network(next_obs)
        next_q = target_q_values.gather(1, best_actions.unsqueeze(1)).squeeze(1)

        # TD target
        target = rewards + gamma * next_q * (1 - terminated.float())

    return target


# Compare with vanilla DQN:
def compute_vanilla_dqn_target(target_network, next_obs, rewards, terminated, gamma):
    with torch.no_grad():
        # Target network both selects AND evaluates
        next_q = target_network(next_obs).max(dim=-1).values
        target = rewards + gamma * next_q * (1 - terminated.float())
    return target

Why This Works:

The online and target networks make different estimation errors. An action that the online network happens to overestimate is unlikely to also be overestimated by the target network, so splitting selection from evaluation breaks the correlation that drives maximization bias. And because the target network's value of the selected action can never exceed its own maximum, the Double DQN target is never more optimistic than the vanilla target built from the same target network.

Dueling DQN

Key Insight: Decompose Q-values into state value + action advantage:

$$Q(s, a) = V(s) + A(s, a)$$

Where:

- $V(s)$ is the state value: how good it is to be in state $s$, independent of the action taken
- $A(s, a)$ is the advantage: how much better (or worse) action $a$ is than the average action in state $s$

The Architecture:

class DuelingQNetwork(nn.Module):
    def __init__(self, obs_shape: tuple[int], num_actions: int):
        super().__init__()

        # Shared feature extraction
        self.features = nn.Sequential(
            nn.Linear(obs_shape[0], 128),
            nn.ReLU(),
        )

        # Value stream: V(s)
        self.value_stream = nn.Sequential(
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1)  # Single value
        )

        # Advantage stream: A(s, a)
        self.advantage_stream = nn.Sequential(
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, num_actions)  # One per action
        )

    def forward(self, x: Tensor) -> Tensor:
        features = self.features(x)

        value = self.value_stream(features)  # [batch, 1]
        advantages = self.advantage_stream(features)  # [batch, num_actions]

        # Combine: Q = V + (A - mean(A))
        # Subtracting mean ensures identifiability
        q_values = value + (advantages - advantages.mean(dim=-1, keepdim=True))

        return q_values
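
A quick shape check for this architecture (assumes torch and the class above are already in scope, and a CartPole-sized observation):

net = DuelingQNetwork(obs_shape=(4,), num_actions=2)
dummy_obs = torch.zeros(8, 4)      # batch of 8 CartPole observations
q_values = net(dummy_obs)
print(q_values.shape)              # torch.Size([8, 2]) -- one Q-value per action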

Why Subtract Mean Advantage?

Without the constraint, $V$ and $A$ are not identifiable: adding a constant $c$ to $V(s)$ and subtracting $c$ from every $A(s, a)$ yields exactly the same $Q(s, a)$, so infinitely many $(V, A)$ pairs fit the same data.

By forcing the advantages to sum to zero, $\sum_a A(s, a) = 0$, we get a unique decomposition: $V(s)$ is pinned to the mean of the Q-values, $V(s) = \frac{1}{|\mathcal{A}|} \sum_a Q(s, a)$, and each $A(s, a)$ measures how far that action sits above or below the mean.
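
A tiny numerical check of that identity (standalone sketch; q holds arbitrary Q-values for one state):

import torch

q = torch.tensor([1.0, 3.0, 5.0])   # arbitrary Q-values for one state
v = q.mean()                        # V(s) recovered as the mean of the Q-values
a = q - v                           # advantages, guaranteed to average to zero
print(v.item(), a.tolist())         # 3.0 [-2.0, 0.0, 2.0]
print(torch.allclose(v + a, q))     # True: the decomposition reproduces Q exactly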


Why Dueling Helps

Consider a state where all actions are equally good (or bad):

Vanilla DQN: Must learn separate Q-values for each action, even if they're all the same.

Dueling DQN: Learns the shared value once as $V(s)$, with $A(s, a) \approx 0$ for all $a$. More efficient!

State: CartPole perfectly balanced, centered
- All actions roughly equivalent
- Vanilla DQN: Must estimate Q(s, left), Q(s, right) separately
- Dueling: V(s) = high, A(s, left) ≈ A(s, right) ≈ 0

State: CartPole tilting right
- Action matters!
- Dueling: V(s) = medium, A(s, left) > 0, A(s, right) < 0

Prioritized Experience Replay

Problem: Uniform sampling from replay buffer is inefficient. Some experiences are more "surprising" and informative than others.

Solution: Sample experiences proportional to their TD error:

$$P(i) \propto \left( |\mathrm{TD}_i| + \epsilon \right)^\alpha$$

Where:

- $\mathrm{TD}_i$ is the TD error of transition $i$
- $\alpha \in [0, 1]$ controls how strongly sampling is skewed toward high-error transitions ($\alpha = 0$ recovers uniform sampling)
- $\epsilon$ is a small constant that keeps zero-error transitions from never being replayed

class PrioritizedReplayBuffer:
    def __init__(self, capacity: int, alpha: float = 0.6):
        self.capacity = capacity
        self.alpha = alpha
        self.priorities = np.zeros(capacity, dtype=np.float32)
        self.data = []
        self.position = 0
        self.max_priority = 1.0

    def add(self, experience):
        """Add with maximum priority (will be updated after training)."""
        if len(self.data) < self.capacity:
            self.data.append(experience)
        else:
            self.data[self.position] = experience

        self.priorities[self.position] = self.max_priority  # already includes the alpha exponent
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size: int, beta: float = 0.4):
        """Sample proportional to priorities."""
        n = len(self.data)
        probs = self.priorities[:n] / self.priorities[:n].sum()

        indices = np.random.choice(n, batch_size, p=probs, replace=False)
        experiences = [self.data[i] for i in indices]

        # Importance sampling weights (for unbiased updates)
        weights = (n * probs[indices]) ** (-beta)
        weights /= weights.max()

        return experiences, indices, weights

    def update_priorities(self, indices, td_errors):
        """Update priorities based on new TD errors."""
        for idx, td_error in zip(indices, td_errors):
            priority = (abs(td_error) + 1e-6) ** self.alpha
            self.priorities[idx] = priority
            self.max_priority = max(self.max_priority, priority)

Importance Sampling Correction:

Prioritized sampling introduces bias. We correct with importance sampling weights:

$$w_i = \left( \frac{1}{N \cdot P(i)} \right)^\beta$$

def training_step_prioritized(self, step: int, beta: float):
    """Training step with prioritized replay."""
    experiences, indices, weights = self.buffer.sample(self.batch_size, beta)
    weights = torch.as_tensor(weights, dtype=torch.float32)  # numpy weights -> tensor for the loss

    # Compute TD errors
    td_errors = self.compute_td_errors(experiences)

    # Weight the loss by importance sampling weights
    loss = (weights * td_errors.pow(2)).mean()

    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()

    # Update priorities with new TD errors
    self.buffer.update_priorities(indices, td_errors.detach().cpu().numpy())
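
$\beta$ controls how strongly the correction is applied, and the update is only fully unbiased at $\beta = 1$, so a common choice is to anneal it from its starting value toward 1 over training. A minimal sketch (hypothetical beta_schedule helper, not defined elsewhere in this section):

def beta_schedule(step: int, total_steps: int, beta_start: float = 0.4) -> float:
    """Linearly anneal beta from beta_start to 1.0 over training."""
    fraction = min(step / total_steps, 1.0)
    return beta_start + fraction * (1.0 - beta_start)

# Usage in the training loop:
# beta = beta_schedule(step, config.total_timesteps)
# experiences, indices, weights = buffer.sample(batch_size, beta)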

The CartPole Environment

Our training ground for DQN:

import gymnasium as gym

env = gym.make("CartPole-v1")

# Observation space: Box(4)
# [cart_position, cart_velocity, pole_angle, pole_angular_velocity]

# Action space: Discrete(2)
# 0 = push cart left
# 1 = push cart right

# Reward: +1 for every step the pole stays upright

# Termination conditions:
# - Pole angle > 12 degrees
# - Cart moves > 2.4 units from center

# Truncation: Episode ends at 500 steps

# Solved: Average reward >= 475 over 100 episodes
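
A quick interface check (standalone sketch): run one episode with a random policy. It typically survives only around 20 steps, which is the baseline DQN should beat.

import gymnasium as gym

env = gym.make("CartPole-v1")
obs, info = env.reset(seed=0)
episode_reward = 0.0
done = False

while not done:
    action = env.action_space.sample()                         # random policy
    obs, reward, terminated, truncated, info = env.step(action)
    episode_reward += reward
    done = terminated or truncated

print(f"Random-policy episode reward: {episode_reward}")       # typically ~20
env.close()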

Why CartPole is a Good Benchmark:

- Fast: thousands of environment steps per second on a CPU, so experiments finish in minutes
- Low-dimensional: 4 observation features and 2 actions keep the network small and easy to debug
- Dense reward: +1 on every step gives a clear, immediate learning signal
- Unambiguous success criterion: average reward of at least 475 over 100 episodes means "solved"

Atari: The Original DQN Benchmark

The 2013 DQN paper trained on raw pixels from Atari games:

# Atari observation: 210 x 160 x 3 RGB image
# Preprocessed to: 84 x 84 x 4 grayscale frames (stacked)

class AtariQNetwork(nn.Module):
    def __init__(self, num_actions: int):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(4, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Flatten(),
        )

        # Calculate flattened size
        # 84 -> 20 -> 9 -> 7, so 64 * 7 * 7 = 3136
        self.fc = nn.Sequential(
            nn.Linear(3136, 512),
            nn.ReLU(),
            nn.Linear(512, num_actions)
        )

    def forward(self, x: Tensor) -> Tensor:
        # x: [batch, 4, 84, 84]
        features = self.conv(x / 255.0)  # Normalize pixels
        return self.fc(features)
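
And the corresponding shape check (assumes torch is in scope; the action count of 6 is arbitrary):

net = AtariQNetwork(num_actions=6)
frames = torch.zeros(2, 4, 84, 84)   # batch of 2 stacked-frame observations
print(net(frames).shape)             # torch.Size([2, 6])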

Atari Preprocessing Pipeline (a code sketch follows the list):

  1. Convert RGB to grayscale
  2. Resize from 210x160 to 84x84
  3. Stack 4 consecutive frames (to capture motion)
  4. Clip rewards to [-1, +1]
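
A minimal sketch of these steps using OpenCV and a deque for frame stacking. This is an illustration under those assumptions, not the wrapper stack used in the original paper (gymnasium also ships Atari preprocessing wrappers that cover the same steps).

import collections
import cv2
import numpy as np

def preprocess_frame(frame: np.ndarray) -> np.ndarray:
    """RGB (210, 160, 3) -> grayscale (84, 84)."""
    gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
    return cv2.resize(gray, (84, 84), interpolation=cv2.INTER_AREA)

def clip_reward(reward: float) -> float:
    """Clip rewards to [-1, +1]."""
    return float(np.clip(reward, -1.0, 1.0))

# Frame stacking: keep the 4 most recent preprocessed frames.
frame_stack = collections.deque(maxlen=4)
# After each env.step(...):
#   frame_stack.append(preprocess_frame(raw_frame))
#   obs = np.stack(frame_stack, axis=0)   # shape (4, 84, 84) once the deque is full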

Implementing DQN from Scratch

The complete training loop:

@dataclass
class DQNConfig:
    env_id: str = "CartPole-v1"
    total_timesteps: int = 500_000
    buffer_size: int = 10_000
    batch_size: int = 128
    gamma: float = 0.99
    learning_rate: float = 2.5e-4
    target_update_freq: int = 1000
    train_frequency: int = 10
    start_epsilon: float = 1.0
    end_epsilon: float = 0.1
    exploration_fraction: float = 0.2
    use_double_dqn: bool = True


class DQNTrainer:
    def __init__(self, config: DQNConfig):
        self.config = config
        self.env = gym.make(config.env_id)

        obs_shape = self.env.observation_space.shape
        num_actions = self.env.action_space.n

        self.q_network = QNetwork(obs_shape, num_actions)
        self.target_network = QNetwork(obs_shape, num_actions)
        self.target_network.load_state_dict(self.q_network.state_dict())

        self.optimizer = torch.optim.AdamW(
            self.q_network.parameters(),
            lr=config.learning_rate
        )

        self.buffer = ReplayBuffer(config.buffer_size, obs_shape)
        self.rng = np.random.default_rng(42)

    def train(self):
        obs, _ = self.env.reset()
        episode_reward = 0
        episode_rewards = []

        for step in range(self.config.total_timesteps):
            # Epsilon-greedy action selection
            epsilon = linear_schedule(
                step,
                self.config.start_epsilon,
                self.config.end_epsilon,
                self.config.exploration_fraction,
                self.config.total_timesteps
            )

            if self.rng.random() < epsilon:
                action = self.env.action_space.sample()
            else:
                with torch.no_grad():
                    q_values = self.q_network(torch.tensor(obs).unsqueeze(0))
                    action = q_values.argmax().item()

            # Environment step
            next_obs, reward, terminated, truncated, _ = self.env.step(action)
            episode_reward += reward

            # Store experience
            self.buffer.add(obs, action, reward, terminated, next_obs)
            obs = next_obs

            # Episode end
            if terminated or truncated:
                episode_rewards.append(episode_reward)
                obs, _ = self.env.reset()
                episode_reward = 0

            # Training
            if step >= self.config.buffer_size and step % self.config.train_frequency == 0:
                self._training_step(step)

            # Target network update
            if step % self.config.target_update_freq == 0:
                self.target_network.load_state_dict(self.q_network.state_dict())

        return episode_rewards

    def _training_step(self, step: int):
        data = self.buffer.sample(self.config.batch_size)

        with torch.no_grad():
            if self.config.use_double_dqn:
                # Double DQN: online selects, target evaluates
                next_actions = self.q_network(data.next_obs).argmax(dim=-1)
                next_q = self.target_network(data.next_obs).gather(
                    1, next_actions.unsqueeze(1)
                ).squeeze(1)
            else:
                # Vanilla DQN
                next_q = self.target_network(data.next_obs).max(dim=-1).values

            target = data.rewards + self.config.gamma * next_q * (1 - data.terminated.float())

        current_q = self.q_network(data.obs).gather(
            1, data.actions.unsqueeze(1)
        ).squeeze(1)

        loss = F.mse_loss(current_q, target)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
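
train() calls a linear_schedule helper for epsilon decay. If it isn't already defined from the vanilla DQN section, a minimal version consistent with how it is called above would be:

def linear_schedule(
    step: int,
    start_e: float,
    end_e: float,
    exploration_fraction: float,
    total_timesteps: int,
) -> float:
    """Decay epsilon linearly from start_e to end_e over the first
    exploration_fraction of training, then hold it at end_e."""
    decay_steps = exploration_fraction * total_timesteps
    slope = (end_e - start_e) / decay_steps
    return max(end_e, start_e + slope * step)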

Debugging DQN: What to Log

# Essential metrics to track:
wandb.log({
    # Performance
    "episode_reward": episode_reward,
    "episode_length": episode_length,

    # Learning dynamics
    "td_loss": loss.item(),
    "q_values_mean": current_q.mean().item(),
    "q_values_max": current_q.max().item(),

    # Exploration
    "epsilon": epsilon,

    # Throughput
    "steps_per_second": steps / elapsed_time,
})

Interpreting the Metrics:

| Metric | Healthy Sign | Warning Sign |
| --- | --- | --- |
| Episode reward | Increasing trend | Stuck at the random-policy level |
| TD loss | Decreasing initially, then stable | Exploding or NaN |
| Q-values | Trending toward $\frac{1}{1-\gamma}$ (about 100 for $\gamma = 0.99$ on CartPole) | Diverging or negative |
| Epsilon | Smooth decay | N/A |

Catastrophic Forgetting

A common failure mode:

Step 0-100k:   Agent learns to balance (reward ~500)
Step 100k-150k: Performance drops to ~50
Step 150k-200k: Recovers to ~500
Step 200k-250k: Drops again...

Why It Happens:

Once the agent plays well, the replay buffer fills up with long, successful episodes. Transitions from early failures (the pole falling, the cart drifting to the edge) get overwritten, so the network gradually forgets how to recover from those states. The policy eventually drifts back into them, performance collapses, the buffer refills with failures, the agent relearns, and the cycle repeats.

Solutions:

  1. Keep some old experiences in the buffer via reservoir sampling (see the sketch below)
  2. Use prioritized replay (rare failure cases stay prioritized)
  3. Periodically reset part of the buffer
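
A minimal sketch of the reservoir idea for option 1 (hypothetical ReservoirBuffer, not part of the trainer above): each incoming experience replaces a random slot with probability capacity / num_seen, so old experiences are never systematically evicted.

import numpy as np

class ReservoirBuffer:
    """Maintains a uniform random sample over everything ever added."""

    def __init__(self, capacity: int, seed: int = 0):
        self.capacity = capacity
        self.data = []
        self.num_seen = 0
        self.rng = np.random.default_rng(seed)

    def add(self, experience) -> None:
        self.num_seen += 1
        if len(self.data) < self.capacity:
            self.data.append(experience)
        else:
            # Keep each past experience with equal probability (Algorithm R).
            slot = self.rng.integers(0, self.num_seen)
            if slot < self.capacity:
                self.data[slot] = experience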

Rainbow DQN

The "Rainbow" paper combined 6 DQN improvements:

| Component | Contribution |
| --- | --- |
| Double DQN | Reduces overestimation |
| Dueling Networks | Better value decomposition |
| Prioritized Replay | Focus on surprising experiences |
| Multi-step Returns | Better credit assignment |
| Distributional RL | Model full return distribution |
| Noisy Nets | Learned exploration |

Each addition helps, but the combination is more than the sum of parts.
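
Of the pieces not yet implemented above, multi-step returns are the simplest to bolt onto the trainer. A minimal sketch of an n-step TD target under stated assumptions (the caller collects the last n rewards, truncates the list at episode end and sets terminated accordingly; names here are hypothetical):

def n_step_target(
    rewards: list[float],      # r_t, r_{t+1}, ..., up to n rewards
    bootstrap_value: float,    # e.g. max_a Q(s_{t+n}, a; theta-) from the target network
    terminated: bool,          # did the episode end within the window?
    gamma: float,
) -> float:
    """n-step TD target: discounted reward sum plus a discounted bootstrap term."""
    target = 0.0
    for k, r in enumerate(rewards):
        target += (gamma ** k) * r
    if not terminated:
        # Only bootstrap if the episode is still going after the window.
        target += (gamma ** len(rewards)) * bootstrap_value
    return target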


Capstone Connection

DQN failure modes preview alignment challenges:

  1. Reward Hacking in Practice:
# CoastRunners boat racing game
# Intended: Race around the track
# Actual: Agent found a loop of respawning powerups
#         Crashed repeatedly while collecting them
#         Got higher score than actually racing!

# Lesson: Your objective function IS your objective
# If you can game it, the agent will find a way
  2. Sparse Reward Challenge:
# Montezuma's Revenge: DQN scored 0%
# Rooms require: Find key -> Navigate maze -> Use key on door
# Random exploration almost never stumbles on this sequence
# Even with reward, credit assignment fails over long horizons

# Relevance: Complex alignment goals may have similar structure
# "Be helpful" requires many correct intermediate steps
  3. Distribution Shift:
# Buffer contains experiences from old policy
# Current policy is different
# Training on off-policy data can cause divergence

# In deployment: Training data != deployment distribution
# Models optimized for one context may fail in another

🎓 Tyla's Exercise

  1. Double DQN Theory: Prove that Double DQN has lower bias than vanilla DQN. Under what conditions might it have higher variance?

  2. Dueling Architecture: In the dueling network, we subtract mean advantage: $Q = V + (A - \bar{A})$. The original paper also tried max: $Q = V + (A - \max_a A)$. Derive the mathematical properties of each. Why might mean work better?

  3. Prioritized Replay Bias: Show mathematically why prioritized sampling introduces bias in the gradient estimate. Then prove that importance sampling weights correct this bias as $\beta \to 1$.

  4. Sample Complexity: DQN on Atari requires ~50 million frames. A human can learn Breakout in ~15 minutes. Estimate the sample efficiency gap and hypothesize what inductive biases humans have that DQN lacks.


💻 Aaliyah's Exercise

Implement the full Rainbow-lite DQN:

class RainbowDQN:
    """
    Implement these improvements incrementally:

    1. Double DQN (required)
       - Modify target computation
       - Test: Q-values should be more stable

    2. Dueling Architecture (required)
       - Split network into value and advantage streams
       - Test: Similar performance with fewer parameters

    3. Prioritized Experience Replay (stretch)
       - Implement priority-based sampling
       - Add importance sampling weights
       - Test: Faster learning on sparse reward tasks

    Evaluation checklist:
    [ ] Pass all 5 probe environments with each modification
    [ ] Solve CartPole (>475 avg reward)
    [ ] Compare learning curves: vanilla vs double vs dueling
    [ ] Log Q-value distributions over training
    """

    def __init__(self, config):
        # Your implementation here
        pass

    def compute_double_dqn_target(self, batch):
        """
        Online network selects: argmax_a Q(s', a; theta)
        Target network evaluates: Q(s', a*; theta-)
        """
        pass

    def create_dueling_network(self, obs_shape, num_actions):
        """
        Architecture:
        obs -> shared_layers -> [value_stream, advantage_stream]
        Q = V + (A - mean(A))
        """
        pass


# Test your implementation:
def test_rainbow():
    config = DQNConfig(use_double_dqn=True, use_dueling=True)  # assumes you add a use_dueling flag to DQNConfig
    trainer = RainbowDQN(config)

    # Run probe tests
    for probe in range(1, 6):
        test_probe(trainer, probe)
        print(f"Probe {probe} passed!")

    # Train on CartPole
    rewards = trainer.train()
    assert np.mean(rewards[-100:]) > 475, "Failed to solve CartPole"
    print("CartPole solved!")

📚 Maneesha's Reflection

  1. Bias-Variance in Learning Design:

Double DQN reduces bias at potential cost of variance. In instructional design, we face similar tradeoffs:

How do you balance these when designing curriculum?

  2. The Value-Advantage Decomposition:

Dueling DQN separates "how good is this situation" from "which action is best." This maps to educational concepts:

When should instruction focus on building foundational value vs. optimizing specific advantages?

  3. Prioritized Replay as Attention:

The brain naturally prioritizes surprising or emotionally significant memories. Prioritized replay mimics this:

How might this inform spaced repetition algorithms? Should learning software prioritize material where the student's predictions are most wrong?

  4. The Forgetting Problem:

Catastrophic forgetting in DQN - losing skills when the buffer fills with successes - mirrors educational challenges:

What mechanisms do effective educational systems use to prevent skill regression?

  5. From 50 Million Frames to 15 Minutes:

The sample efficiency gap between DQN and humans is enormous. Hypothesize what makes human learning so efficient:

How can AI education materials leverage these human advantages rather than fighting against them?