PPO Implementation: From Theory to Code

Building PPO from scratch reveals why each component matters.


The PPO Algorithm Structure

# PPO pseudocode
for iteration in range(num_iterations):
    # 1. Collect rollouts with current policy
    trajectories = collect_rollouts(policy, envs, num_steps)

    # 2. Compute advantages
    advantages = compute_gae(trajectories)

    # 3. Multiple epochs of minibatch updates
    for epoch in range(num_epochs):
        for minibatch in create_minibatches(trajectories):
            # Update policy and value function
            loss = compute_ppo_loss(minibatch, advantages)
            optimizer.step(loss)

Let's implement each piece.


Vectorized Environments

PPO's data throughput comes from running many environments in parallel (the sample efficiency comes later, from reusing that data over multiple epochs):

import gymnasium as gym
from gymnasium.vector import SyncVectorEnv, AsyncVectorEnv

def make_env(env_id, seed, idx):
    """Factory function for creating environments."""
    def thunk():
        env = gym.make(env_id)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        env.reset(seed=seed + idx)
        return env
    return thunk

# Create vectorized environment
num_envs = 8
envs = SyncVectorEnv([
    make_env("CartPole-v1", seed=42, idx=i)
    for i in range(num_envs)
])

# Now actions/observations are batched
obs, info = envs.reset()  # obs.shape = (num_envs, obs_dim)
actions = policy(obs)      # actions.shape = (num_envs,)
next_obs, rewards, terminated, truncated, infos = envs.step(actions)

Why Vectorized Environments?

# Single environment: Sequential, slow
# for step in range(1000):
#     obs -> action -> next_obs  # One at a time

# Vectorized: Parallel, fast
# obs_batch -> action_batch -> next_obs_batch  # All at once

# Benefits:
# 1. GPU parallelism: Batch inference is efficient
# 2. Decorrelated samples: Different envs = different states
# 3. More data: 8 envs = 8x the experience per wall-clock second

# The math: With 8 envs and 128 steps each, one rollout gives:
# 8 * 128 = 1024 transitions for training
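
The import above also brings in AsyncVectorEnv, which runs each environment in its own process. A sketch of the trade-off, reusing the make_env factory from above:

# AsyncVectorEnv is worth it when a single env.step() is expensive
# (heavy physics, rendering). For cheap environments like CartPole,
# SyncVectorEnv is usually faster because it avoids inter-process overhead.
envs = AsyncVectorEnv([
    make_env("CartPole-v1", seed=42, idx=i)
    for i in range(num_envs)
])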

The Rollout Buffer

Store trajectories for training:

import torch

class RolloutBuffer:
    """Stores rollout data for PPO training."""

    def __init__(self, num_steps, num_envs, obs_shape, action_shape):
        self.obs = torch.zeros((num_steps, num_envs) + obs_shape)
        self.actions = torch.zeros((num_steps, num_envs) + action_shape)
        self.log_probs = torch.zeros((num_steps, num_envs))
        self.rewards = torch.zeros((num_steps, num_envs))
        self.dones = torch.zeros((num_steps, num_envs))
        self.values = torch.zeros((num_steps, num_envs))

        self.advantages = torch.zeros((num_steps, num_envs))
        self.returns = torch.zeros((num_steps, num_envs))

        self.ptr = 0
        self.num_steps = num_steps

    def store(self, obs, action, log_prob, reward, done, value):
        """Store one step of experience."""
        self.obs[self.ptr] = obs
        self.actions[self.ptr] = action
        self.log_probs[self.ptr] = log_prob
        self.rewards[self.ptr] = reward
        self.dones[self.ptr] = done
        self.values[self.ptr] = value
        self.ptr += 1

    def reset(self):
        self.ptr = 0
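
For the 8-env, 128-step rollout described earlier, one buffer holds 8 * 128 = 1024 transitions. A quick shape check (assuming CartPole's 4-dimensional observations):

buffer = RolloutBuffer(num_steps=128, num_envs=8, obs_shape=(4,), action_shape=())
print(buffer.obs.shape)      # torch.Size([128, 8, 4])
print(buffer.actions.shape)  # torch.Size([128, 8])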

Collecting Rollouts

def collect_rollouts(policy, value_net, envs, buffer, num_steps):
    """
    Collect num_steps of experience from vectorized environments.
    """
    obs, _ = envs.reset()
    obs = torch.tensor(obs, dtype=torch.float32)

    for step in range(num_steps):
        with torch.no_grad():
            # Get action distribution
            dist = policy(obs)
            action = dist.sample()
            log_prob = dist.log_prob(action)

            # Get value estimate
            value = value_net(obs).squeeze(-1)

        # Step environment
        next_obs, reward, terminated, truncated, info = envs.step(
            action.cpu().numpy()
        )
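        # Simplification: truncation is folded into `done` here; see
        # "Implementation Pitfalls" below for why truncated episodes
        # should really still bootstrap from V(s').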
        done = terminated | truncated

        # Store transition
        buffer.store(
            obs=obs,
            action=action,
            log_prob=log_prob,
            reward=torch.tensor(reward),
            done=torch.tensor(done, dtype=torch.float32),
            value=value
        )

        obs = torch.tensor(next_obs, dtype=torch.float32)

    # Get final value for GAE computation
    with torch.no_grad():
        final_value = value_net(obs).squeeze(-1)

    return final_value

Computing GAE

def compute_gae(buffer, final_value, gamma=0.99, gae_lambda=0.95):
    """
    Compute Generalized Advantage Estimation.
    """
    advantages = torch.zeros_like(buffer.rewards)
    last_gae = 0

    for t in reversed(range(buffer.num_steps)):
        # dones[t] records whether the episode ended during step t,
        # i.e. whether s_{t+1} is terminal and must not be bootstrapped
        next_non_terminal = 1.0 - buffer.dones[t]
        if t == buffer.num_steps - 1:
            next_value = final_value
        else:
            next_value = buffer.values[t + 1]

        # TD error: r + gamma * V(s') - V(s)
        delta = (
            buffer.rewards[t]
            + gamma * next_value * next_non_terminal
            - buffer.values[t]
        )

        # GAE: exponentially-weighted average of TD errors
        advantages[t] = delta + gamma * gae_lambda * next_non_terminal * last_gae
        last_gae = advantages[t]

    buffer.advantages = advantages
    buffer.returns = advantages + buffer.values
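
Two limiting cases make the lambda trade-off concrete; a quick sanity check (a sketch, assuming the buffer and final_value from the rollout above):

# gae_lambda = 0 reduces GAE to the one-step TD error (low variance, high bias);
# gae_lambda = 1 gives the full discounted return minus the baseline (Monte Carlo).
compute_gae(buffer, final_value, gamma=0.99, gae_lambda=0.0)
td_errors = (
    buffer.rewards[:-1]
    + 0.99 * buffer.values[1:] * (1 - buffer.dones[:-1])
    - buffer.values[:-1]
)
assert torch.allclose(buffer.advantages[:-1], td_errors)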

Minibatch Updates

The key to PPO's sample efficiency:

def create_minibatches(buffer, batch_size):
    """
    Yield shuffled minibatches covering one pass over the rollout data
    (called once per epoch by the training loop).
    """
    # Flatten the buffer
    num_samples = buffer.num_steps * buffer.obs.shape[1]  # steps * envs

    obs = buffer.obs.reshape(num_samples, -1)
    actions = buffer.actions.reshape(num_samples, -1)
    log_probs = buffer.log_probs.reshape(num_samples)
    advantages = buffer.advantages.reshape(num_samples)
    returns = buffer.returns.reshape(num_samples)

    # Normalize advantages (important for stability!)
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

    # Random permutation
    indices = torch.randperm(num_samples)

    for start in range(0, num_samples, batch_size):
        end = start + batch_size
        batch_indices = indices[start:end]

        yield {
            'obs': obs[batch_indices],
            'actions': actions[batch_indices],
            'old_log_probs': log_probs[batch_indices],
            'advantages': advantages[batch_indices],
            'returns': returns[batch_indices],
        }

Why Multiple Epochs?

# Traditional RL: Use each sample once, then discard
# PPO: Reuse each sample for num_epochs updates

# Why this works:
# 1. Clipping prevents large updates even with reused data
# 2. Each epoch uses different minibatch orderings
# 3. Significantly improves sample efficiency

# Typical values:
# num_epochs = 4 (Atari)
# num_epochs = 10 (MuJoCo)

# Warning: Too many epochs -> overfitting to current batch
# Monitor: clip_fraction and approx_kl should stay reasonable
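
A common safeguard (a sketch, not part of the minimal loop shown later): stop the epoch loop early once the policy has drifted too far from the rollout policy. target_kl here is an assumed extra hyperparameter:

target_kl = 0.02

for epoch in range(num_epochs):
    for batch in create_minibatches(buffer, minibatch_size):
        losses = compute_ppo_loss(policy, value_net, batch)
        optimizer.zero_grad()
        losses['loss'].backward()
        optimizer.step()
    if losses['approx_kl'] > target_kl:
        break  # policy moved too far; skip the remaining epochs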

The PPO Loss Function

def compute_ppo_loss(
    policy,
    value_net,
    batch,
    clip_epsilon=0.2,
    value_coef=0.5,
    entropy_coef=0.01
):
    """
    Compute the full PPO loss.
    """
    # Get new policy distribution
    dist = policy(batch['obs'])
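    # Note: the line below fits a Categorical policy with actions stored as a
    # (batch, 1) column. For a multi-dimensional Normal policy, drop the squeeze
    # and sum over the action dimension: dist.log_prob(batch['actions']).sum(-1)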
    new_log_probs = dist.log_prob(batch['actions'].squeeze(-1))
    entropy = dist.entropy().mean()

    # Get new value predictions
    new_values = value_net(batch['obs']).squeeze(-1)

    # --- Policy Loss (clipped surrogate) ---
    ratio = (new_log_probs - batch['old_log_probs']).exp()

    # Clipped objective
    surr1 = ratio * batch['advantages']
    surr2 = torch.clamp(ratio, 1 - clip_epsilon, 1 + clip_epsilon) * batch['advantages']
    policy_loss = -torch.min(surr1, surr2).mean()

    # --- Value Loss ---
    value_loss = 0.5 * ((new_values - batch['returns']) ** 2).mean()

    # --- Total Loss ---
    total_loss = policy_loss + value_coef * value_loss - entropy_coef * entropy

    # Diagnostics
    with torch.no_grad():
        approx_kl = ((ratio - 1) - ratio.log()).mean()
        clip_fraction = ((ratio - 1).abs() > clip_epsilon).float().mean()

    return {
        'loss': total_loss,
        'policy_loss': policy_loss,
        'value_loss': value_loss,
        'entropy': entropy,
        'approx_kl': approx_kl,
        'clip_fraction': clip_fraction,
    }

Value Function Clipping

Some implementations clip value updates too:

def clipped_value_loss(
    new_values,
    old_values,
    returns,
    clip_epsilon=0.2
):
    """
    Clipped value loss (optional, used in some implementations).

    Idea: Don't let value predictions change too much either.
    """
    # Unclipped loss
    value_loss_unclipped = (new_values - returns) ** 2

    # Clipped value prediction
    values_clipped = old_values + torch.clamp(
        new_values - old_values,
        -clip_epsilon,
        clip_epsilon
    )
    value_loss_clipped = (values_clipped - returns) ** 2

    # Take maximum (pessimistic)
    value_loss = 0.5 * torch.max(value_loss_unclipped, value_loss_clipped).mean()

    return value_loss

# Note: Value clipping is controversial
# - Original PPO paper: Does not use it
# - OpenAI baselines: Uses it
# - Empirically: Sometimes helps, sometimes hurts
# - Recommendation: Try both, keep what works
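
To wire it in, the minibatches also need the value predictions recorded at rollout time; 'old_values' below is an assumed extra key that create_minibatches would have to provide:

value_loss = clipped_value_loss(
    new_values=value_net(batch['obs']).squeeze(-1),
    old_values=batch['old_values'],   # value estimates stored during the rollout
    returns=batch['returns'],
    clip_epsilon=0.2,
)
# Swap this in for the plain MSE value loss inside compute_ppo_loss.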

Entropy Bonus

Encourages exploration:

# Entropy measures "randomness" of the policy
# High entropy: Actions spread across many options
# Low entropy: Policy concentrated on one action

# Why add entropy bonus?
# 1. Prevents premature convergence to deterministic policy
# 2. Maintains exploration
# 3. Acts as regularization

def compute_entropy(dist):
    """
    Entropy for common distributions.
    """
    if isinstance(dist, torch.distributions.Categorical):
        # H = -sum(p * log(p))
        return dist.entropy()

    elif isinstance(dist, torch.distributions.Normal):
        # H = 0.5 * log(2 * pi * e * sigma^2)
        return dist.entropy().sum(dim=-1)

# Entropy coefficient decay:
# Early training: High entropy (explore)
# Late training: Low entropy (exploit)
def get_entropy_coef(step, total_steps, start=0.01, end=0.001):
    return start - (start - end) * (step / total_steps)

The Actor-Critic Networks

import numpy as np
import torch
import torch.nn as nn

class ActorCritic(nn.Module):
    """
    Shared network for policy and value function.
    """
    def __init__(self, obs_dim, action_dim, hidden_dim=64):
        super().__init__()

        # Shared feature extractor
        self.shared = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
        )

        # Policy head
        self.policy_mean = nn.Linear(hidden_dim, action_dim)
        self.policy_log_std = nn.Parameter(torch.zeros(action_dim))

        # Value head
        self.value = nn.Linear(hidden_dim, 1)

        # Orthogonal weight initialization (standard practice for PPO)
        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
                nn.init.constant_(m.bias, 0)

    def forward(self, obs):
        features = self.shared(obs)
        return features

    def get_policy(self, obs):
        features = self.forward(obs)
        mean = self.policy_mean(features)
        std = self.policy_log_std.exp()
        return torch.distributions.Normal(mean, std)

    def get_value(self, obs):
        features = self.forward(obs)
        return self.value(features)
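
The ActorCritic above emits a Normal distribution, which suits continuous control. For a discrete action space such as CartPole (used in the training loop below), the policy head produces logits for a Categorical distribution instead. A minimal sketch under that assumption:

class DiscreteActorCritic(nn.Module):
    """Actor-critic with a Categorical policy head for discrete actions."""
    def __init__(self, obs_dim, num_actions, hidden_dim=64):
        super().__init__()
        self.shared = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
        )
        self.policy_logits = nn.Linear(hidden_dim, num_actions)
        self.value = nn.Linear(hidden_dim, 1)

    def get_policy(self, obs):
        logits = self.policy_logits(self.shared(obs))
        return torch.distributions.Categorical(logits=logits)

    def get_value(self, obs):
        return self.value(self.shared(obs))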

Separate vs Shared Networks

# Shared (above): Policy and value share feature extractor
# - Fewer parameters
# - Features learned for both objectives
# - Can have conflicting gradients

# Separate: Independent networks
class SeparateActorCritic(nn.Module):
    def __init__(self, obs_dim, action_dim, hidden_dim=64):
        super().__init__()

        # Separate policy network
        self.policy_net = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, action_dim),
        )
        self.policy_log_std = nn.Parameter(torch.zeros(action_dim))

        # Separate value network
        self.value_net = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1),
        )

# Which to use?
# - Simple tasks (CartPole): Shared is fine
# - Complex tasks (MuJoCo): Separate often better
# - LLMs/RLHF: Separate (huge policy network)

MuJoCo Environments

Continuous control benchmarks:

import gymnasium as gym

# MuJoCo environments (continuous action spaces)
env_ids = [
    "HalfCheetah-v4",    # 6 DoF, run fast
    "Hopper-v4",         # 3 DoF, hop forward
    "Walker2d-v4",       # 6 DoF, walk forward
    "Ant-v4",            # 8 DoF, walk in any direction
    "Humanoid-v4",       # 17 DoF, humanoid walking
]

# Key differences from discrete (Atari):
# 1. Continuous actions: Use Normal distribution
# 2. Dense rewards: Reward at every step
# 3. No frame stacking: Single observation
# 4. Smaller networks: 64-256 hidden units

# Standard PPO hyperparameters for MuJoCo:
mujoco_config = {
    'num_envs': 1,          # Often just 1 for MuJoCo
    'num_steps': 2048,      # Steps per rollout
    'num_epochs': 10,       # Updates per batch
    'minibatch_size': 64,
    'learning_rate': 3e-4,
    'gamma': 0.99,
    'gae_lambda': 0.95,
    'clip_epsilon': 0.2,
    'entropy_coef': 0.0,    # Often 0 for MuJoCo
    'value_coef': 0.5,
}
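
MuJoCo results also depend heavily on observation and reward normalization. A sketch of a make_env variant using gymnasium's built-in wrappers (a common setup, not required by anything above):

def make_mujoco_env(env_id, seed, idx):
    def thunk():
        env = gym.make(env_id)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        env = gym.wrappers.ClipAction(env)            # keep actions inside the valid box
        env = gym.wrappers.NormalizeObservation(env)  # running mean/std of observations
        env = gym.wrappers.NormalizeReward(env, gamma=0.99)  # scale rewards by running return std
        env.reset(seed=seed + idx)
        return env
    return thunk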

Full Training Loop

def train_ppo(
    env_id,
    total_timesteps=1_000_000,
    num_envs=8,
    num_steps=128,
    num_epochs=4,
    minibatch_size=256,
    lr=3e-4,
    gamma=0.99,
    gae_lambda=0.95,
    clip_epsilon=0.2,
    entropy_coef=0.01,
    value_coef=0.5,
    max_grad_norm=0.5,
):
    """Complete PPO training loop."""

    # Create environments
    envs = SyncVectorEnv([make_env(env_id, 42, i) for i in range(num_envs)])
    obs_dim = envs.single_observation_space.shape[0]
    action_dim = envs.single_action_space.n  # number of discrete actions (e.g. CartPole)

    # Create networks. Note: for a discrete env like CartPole the policy needs a
    # Categorical head (see the DiscreteActorCritic sketch above); the ActorCritic
    # class's Normal head is meant for continuous action spaces.
    agent = ActorCritic(obs_dim, action_dim)
    optimizer = torch.optim.Adam(agent.parameters(), lr=lr)

    # Create buffer
    buffer = RolloutBuffer(num_steps, num_envs, (obs_dim,), ())

    num_updates = total_timesteps // (num_envs * num_steps)

    for update in range(num_updates):
        # Anneal learning rate
        frac = 1 - update / num_updates
        optimizer.param_groups[0]['lr'] = lr * frac

        # Collect rollouts
        buffer.reset()
        final_value = collect_rollouts(
            agent.get_policy, agent.get_value, envs, buffer, num_steps
        )

        # Compute advantages
        compute_gae(buffer, final_value, gamma, gae_lambda)

        # PPO update
        for epoch in range(num_epochs):
            for batch in create_minibatches(buffer, minibatch_size):
                losses = compute_ppo_loss(
                    agent.get_policy,
                    agent.get_value,
                    batch,
                    clip_epsilon,
                    value_coef,
                    entropy_coef,
                )

                optimizer.zero_grad()
                losses['loss'].backward()
                nn.utils.clip_grad_norm_(agent.parameters(), max_grad_norm)
                optimizer.step()

        # Logging
        if update % 10 == 0:
            print(f"Update {update}: "
                  f"approx_kl={losses['approx_kl']:.4f}, "
                  f"clip_frac={losses['clip_fraction']:.4f}")

    return agent

Implementation Pitfalls

Common bugs and how to avoid them:

# 1. Forgetting to normalize advantages
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
# Without this: Training is unstable

# 2. Wrong ratio computation
ratio = (new_log_probs - old_log_probs).exp()  # Correct
# ratio = new_probs / old_probs  # Numerically unstable!

# 3. Not detaching old log probs
# old_log_probs should be from BEFORE the update, not recomputed

# 4. Gradient through value targets
returns = advantages + values.detach()  # Detach values!
# Otherwise the regression target moves with the value network during the update

# 5. Wrong done handling in GAE
# terminated means the next state is truly terminal (value 0); truncated means
# a time limit was hit and the next state still has value
# Only zero the bootstrap on terminated; for truncated episodes, bootstrap from V(s')

# 6. Forgetting gradient clipping
nn.utils.clip_grad_norm_(agent.parameters(), max_grad_norm)
# Without this: Exploding gradients

Debugging PPO

Key metrics to monitor:

def log_training_metrics(losses, buffer):
    """Metrics that indicate training health."""

    print(f"""
    === Training Diagnostics ===
    approx_kl:     {losses['approx_kl']:.4f}  # Should be < 0.02
    clip_fraction: {losses['clip_fraction']:.4f}  # Should be 0.1-0.3
    entropy:       {losses['entropy']:.4f}  # Should decrease slowly
    value_loss:    {losses['value_loss']:.4f}  # Should decrease
    policy_loss:   {losses['policy_loss']:.4f}  # Can be noisy

    === Warning Signs ===
    - approx_kl > 0.05: Learning rate too high
    - clip_fraction > 0.5: Updates too aggressive
    - clip_fraction = 0: Epsilon too large or lr too small
    - entropy = 0: Policy collapsed, no exploration
    - value_loss increasing: Value network diverging
    """)

Capstone Connection

PPO is how we train AI on human preferences:

# RLHF uses PPO because:
# 1. Stability: Can't destroy the base model
# 2. KL constraint: Stays close to original policy
# 3. Sample efficiency: Expensive to get human feedback

class RLHFTrainer:
    """
    PPO-based RLHF training.
    """
    def __init__(self, policy, ref_policy, reward_model, kl_coef=0.1):
        self.policy = policy          # Model we're training
        self.ref_policy = ref_policy  # Frozen original model
        self.reward_model = reward_model
        self.kl_coef = kl_coef        # Strength of the KL penalty

    def compute_rewards(self, prompts, responses):
        """
        Reward = RM score - KL penalty
        """
        rm_scores = self.reward_model(prompts, responses)

        # KL penalty keeps policy close to reference
        policy_logprobs = self.policy.log_prob(responses)
        ref_logprobs = self.ref_policy.log_prob(responses)
        kl_penalty = policy_logprobs - ref_logprobs

        rewards = rm_scores - self.kl_coef * kl_penalty
        return rewards

    def update(self, prompts):
        """
        One PPO update step for RLHF.
        """
        # Generate responses with current policy
        responses = self.policy.generate(prompts)

        # Compute rewards
        rewards = self.compute_rewards(prompts, responses)

        # Standard PPO update
        # The clipping prevents reward hacking
        # by limiting how far policy can move

# This is how ChatGPT, Claude, etc. are trained
# PPO is the workhorse of AI alignment

From PPO to DPO

Modern alternatives to PPO for RLHF:

# PPO for RLHF is complex:
# - Need reward model
# - Need reference policy
# - Need to generate samples
# - Hyperparameter sensitive

# DPO (Direct Preference Optimization) simplifies:
# - No reward model
# - No RL at all!
# - Direct optimization from preferences

import torch.nn.functional as F

def dpo_loss(policy, ref_policy, preferred, rejected, beta=0.1):
    """
    DPO: Train directly on preference pairs.
    """
    # Log prob differences
    policy_log_ratio = (
        policy.log_prob(preferred) - policy.log_prob(rejected)
    )
    ref_log_ratio = (
        ref_policy.log_prob(preferred) - ref_policy.log_prob(rejected)
    )

    # DPO loss
    loss = -F.logsigmoid(beta * (policy_log_ratio - ref_log_ratio))
    return loss.mean()

# But PPO is still important:
# 1. More flexible (any reward signal)
# 2. Better for complex objectives
# 3. Foundation of online RLHF

🎓 Tyla's Exercise

  1. Prove that PPO with infinite epsilon (no clipping) reduces to vanilla policy gradient. What does this say about the role of clipping?

  2. The GAE formula is: A_t = sum_{l=0}^{inf} (gamma * lambda)^l * delta_{t+l}. Derive this from first principles. What is the effective horizon of GAE when lambda = 0.95 and gamma = 0.99?

  3. Why does normalizing advantages improve training stability? Prove that normalized advantages have zero mean. What about variance?


💻 Aaliyah's Exercise

Implement a complete PPO trainer:

class PPOTrainer:
    """
    Your task: Implement each method.
    """
    def __init__(self, env_id, config):
        # TODO: Initialize environments, networks, optimizer, buffer
        pass

    def collect_rollouts(self):
        # TODO: Fill buffer with num_steps of experience
        pass

    def compute_gae(self):
        # TODO: Compute advantages and returns
        pass

    def update(self):
        # TODO: Run num_epochs of minibatch updates
        pass

    def train(self, total_timesteps):
        # TODO: Main training loop with logging
        pass

# Test on CartPole-v1: Should reach 500 reward
# Test on LunarLander-v2: Should reach 200+ reward
# Bonus: Test on HalfCheetah-v4 (continuous)

📚 Maneesha's Reflection

  1. PPO collects experience, then trains on it multiple times. This is like studying the same material repeatedly. What learning science principles explain why this works (spaced repetition, interleaving, etc.)?

  2. The entropy bonus encourages exploration - trying new things even when current strategy works. How does this relate to human learning? When is it beneficial vs harmful to explore vs exploit what we know?

  3. RLHF uses PPO to align AI with human preferences. But the "human preferences" are from paid contractors answering questions. What are the limitations of this approach? How might it create blind spots in AI systems?