PPO Implementation: From Theory to Code
Building PPO from scratch reveals why each component matters.
The PPO Algorithm Structure
# PPO pseudocode
for iteration in range(num_iterations):
    # 1. Collect rollouts with current policy
    trajectories = collect_rollouts(policy, envs, num_steps)

    # 2. Compute advantages
    advantages = compute_gae(trajectories)

    # 3. Multiple epochs of minibatch updates
    for epoch in range(num_epochs):
        for minibatch in create_minibatches(trajectories):
            # Update policy and value function
            loss = compute_ppo_loss(minibatch, advantages)
            optimizer.step(loss)
Let's implement each piece.
Vectorized Environments
PPO's sample efficiency comes from parallel environments:
import gymnasium as gym
from gymnasium.vector import SyncVectorEnv, AsyncVectorEnv

def make_env(env_id, seed, idx):
    """Factory function for creating environments."""
    def thunk():
        env = gym.make(env_id)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        env.reset(seed=seed + idx)
        return env
    return thunk

# Create vectorized environment
num_envs = 8
envs = SyncVectorEnv([
    make_env("CartPole-v1", seed=42, idx=i)
    for i in range(num_envs)
])

# Now actions/observations are batched
obs, info = envs.reset()   # obs.shape = (num_envs, obs_dim)
actions = policy(obs)      # actions.shape = (num_envs,)
next_obs, rewards, terminated, truncated, infos = envs.step(actions)
Why Vectorized Environments?
# Single environment: Sequential, slow
# for step in range(1000):
# obs -> action -> next_obs # One at a time
# Vectorized: Parallel, fast
# obs_batch -> action_batch -> next_obs_batch # All at once
# Benefits:
# 1. GPU parallelism: Batch inference is efficient
# 2. Decorrelated samples: Different envs = different states
# 3. More data: 8 envs = 8x the experience per wall-clock second
# The math: With 8 envs and 128 steps each, one rollout gives:
# 8 * 128 = 1024 transitions for training
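To make the arithmetic concrete, here is a minimal sketch that counts transitions per rollout with a vectorized CartPole. Random actions stand in for the learned policy, so the only point is the shape of the data, not the behavior:

import numpy as np
import gymnasium as gym
from gymnasium.vector import SyncVectorEnv

num_envs, num_steps = 8, 128
envs = SyncVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(num_envs)])
obs, _ = envs.reset(seed=0)

transitions = 0
for _ in range(num_steps):
    # Random actions stand in for the policy here
    actions = np.array([envs.single_action_space.sample() for _ in range(num_envs)])
    obs, rewards, terminated, truncated, infos = envs.step(actions)
    transitions += num_envs

print(transitions)  # 8 * 128 = 1024 transitions per rollout
envs.close()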
The Rollout Buffer
Store trajectories for training:
import torch

class RolloutBuffer:
    """Stores rollout data for PPO training."""

    def __init__(self, num_steps, num_envs, obs_shape, action_shape):
        self.obs = torch.zeros((num_steps, num_envs) + obs_shape)
        self.actions = torch.zeros((num_steps, num_envs) + action_shape)
        self.log_probs = torch.zeros((num_steps, num_envs))
        self.rewards = torch.zeros((num_steps, num_envs))
        self.dones = torch.zeros((num_steps, num_envs))
        self.values = torch.zeros((num_steps, num_envs))
        self.advantages = torch.zeros((num_steps, num_envs))
        self.returns = torch.zeros((num_steps, num_envs))
        self.ptr = 0
        self.num_steps = num_steps

    def store(self, obs, action, log_prob, reward, done, value):
        """Store one step of experience."""
        self.obs[self.ptr] = obs
        self.actions[self.ptr] = action
        self.log_probs[self.ptr] = log_prob
        self.rewards[self.ptr] = reward
        self.dones[self.ptr] = done
        self.values[self.ptr] = value
        self.ptr += 1

    def reset(self):
        self.ptr = 0
Collecting Rollouts
def collect_rollouts(policy, value_net, envs, buffer, num_steps, obs):
    """
    Collect num_steps of experience from vectorized environments.

    `obs` is the current batched observation. It is passed in and returned so
    episodes continue across rollouts instead of being reset every update
    (otherwise no episode could ever run longer than num_steps).
    """
    for step in range(num_steps):
        with torch.no_grad():
            # Get action distribution
            dist = policy(obs)
            action = dist.sample()
            log_prob = dist.log_prob(action)

            # Get value estimate
            value = value_net(obs).squeeze(-1)

        # Step environment
        next_obs, reward, terminated, truncated, info = envs.step(
            action.cpu().numpy()
        )
        # Simplification: treats truncation like termination.
        # See "Implementation Pitfalls" below for the more careful handling.
        done = terminated | truncated

        # Store transition
        buffer.store(
            obs=obs,
            action=action,
            log_prob=log_prob,
            reward=torch.tensor(reward, dtype=torch.float32),
            done=torch.tensor(done, dtype=torch.float32),
            value=value,
        )

        obs = torch.tensor(next_obs, dtype=torch.float32)

    # Get final value for GAE computation
    with torch.no_grad():
        final_value = value_net(obs).squeeze(-1)

    return final_value, obs
Computing GAE
def compute_gae(buffer, final_value, gamma=0.99, gae_lambda=0.95):
    """
    Compute Generalized Advantage Estimation.
    """
    advantages = torch.zeros_like(buffer.rewards)
    last_gae = 0

    for t in reversed(range(buffer.num_steps)):
        if t == buffer.num_steps - 1:
            next_value = final_value
        else:
            next_value = buffer.values[t + 1]
        # dones[t] == 1 means the episode ended after step t, so we neither
        # bootstrap from the next state nor carry GAE across the boundary
        next_non_terminal = 1.0 - buffer.dones[t]

        # TD error: r + gamma * V(s') - V(s)
        delta = (
            buffer.rewards[t]
            + gamma * next_value * next_non_terminal
            - buffer.values[t]
        )

        # GAE: exponentially-weighted average of TD errors
        advantages[t] = delta + gamma * gae_lambda * next_non_terminal * last_gae
        last_gae = advantages[t]

    buffer.advantages = advantages
    buffer.returns = advantages + buffer.values
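A quick sanity check, as a minimal sketch assuming the RolloutBuffer and compute_gae defined above: with gae_lambda=0 every advantage should collapse to its one-step TD error.

# Sketch: with gae_lambda=0, each advantage equals its one-step TD error.
buffer = RolloutBuffer(num_steps=4, num_envs=1, obs_shape=(1,), action_shape=())
buffer.rewards = torch.tensor([[1.0], [1.0], [1.0], [1.0]])
buffer.values = torch.tensor([[0.5], [0.5], [0.5], [0.5]])
buffer.dones = torch.zeros((4, 1))
final_value = torch.tensor([0.5])

compute_gae(buffer, final_value, gamma=0.99, gae_lambda=0.0)
td_errors = (
    buffer.rewards
    + 0.99 * torch.cat([buffer.values[1:], final_value[None]])
    - buffer.values
)
assert torch.allclose(buffer.advantages, td_errors)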
Minibatch Updates
The key to PPO's sample efficiency:
def create_minibatches(buffer, batch_size):
    """
    Yield minibatches for multiple epochs of training.
    """
    # Flatten the buffer
    num_samples = buffer.num_steps * buffer.obs.shape[1]  # steps * envs
    obs = buffer.obs.reshape(num_samples, -1)
    actions = buffer.actions.reshape(num_samples, -1)
    log_probs = buffer.log_probs.reshape(num_samples)
    advantages = buffer.advantages.reshape(num_samples)
    returns = buffer.returns.reshape(num_samples)

    # Normalize advantages (important for stability!)
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

    # Random permutation
    indices = torch.randperm(num_samples)

    for start in range(0, num_samples, batch_size):
        end = start + batch_size
        batch_indices = indices[start:end]
        yield {
            'obs': obs[batch_indices],
            'actions': actions[batch_indices],
            'old_log_probs': log_probs[batch_indices],
            'advantages': advantages[batch_indices],
            'returns': returns[batch_indices],
        }
Why Multiple Epochs?
# Traditional RL: Use each sample once, then discard
# PPO: Reuse each sample for num_epochs updates
# Why this works:
# 1. Clipping prevents large updates even with reused data
# 2. Each epoch uses different minibatch orderings
# 3. Significantly improves sample efficiency
# Typical values:
# num_epochs = 4 (Atari)
# num_epochs = 10 (MuJoCo)
# Warning: Too many epochs -> overfitting to current batch
# Monitor: clip_fraction and approx_kl should stay reasonable
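One widely used safeguard, which is an addition here rather than part of the pseudocode above, is to end the epoch loop early once the approximate KL between the new and old policy grows too large. A sketch, reusing names from the surrounding sections (policy, value_net, buffer, optimizer) and a hypothetical target_kl threshold; compute_ppo_loss is defined in the next section:

# Sketch: stop reusing the batch once approx_kl exceeds a (hypothetical) target_kl.
target_kl = 0.02

for epoch in range(num_epochs):
    for batch in create_minibatches(buffer, minibatch_size):
        losses = compute_ppo_loss(policy, value_net, batch)
        optimizer.zero_grad()
        losses['loss'].backward()
        optimizer.step()

    if losses['approx_kl'] > target_kl:
        break  # the policy has drifted far enough from the data-collecting policy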
The PPO Loss Function
def compute_ppo_loss(
    policy,
    value_net,
    batch,
    clip_epsilon=0.2,
    value_coef=0.5,
    entropy_coef=0.01,
):
    """
    Compute the full PPO loss.
    """
    # Get new policy distribution
    dist = policy(batch['obs'])
    new_log_probs = dist.log_prob(batch['actions'].squeeze(-1))
    entropy = dist.entropy().mean()

    # Get new value predictions
    new_values = value_net(batch['obs']).squeeze(-1)

    # --- Policy Loss (clipped surrogate) ---
    ratio = (new_log_probs - batch['old_log_probs']).exp()

    # Clipped objective
    surr1 = ratio * batch['advantages']
    surr2 = torch.clamp(ratio, 1 - clip_epsilon, 1 + clip_epsilon) * batch['advantages']
    policy_loss = -torch.min(surr1, surr2).mean()

    # --- Value Loss ---
    value_loss = 0.5 * ((new_values - batch['returns']) ** 2).mean()

    # --- Total Loss ---
    total_loss = policy_loss + value_coef * value_loss - entropy_coef * entropy

    # Diagnostics
    with torch.no_grad():
        approx_kl = ((ratio - 1) - ratio.log()).mean()
        clip_fraction = ((ratio - 1).abs() > clip_epsilon).float().mean()

    return {
        'loss': total_loss,
        'policy_loss': policy_loss,
        'value_loss': value_loss,
        'entropy': entropy,
        'approx_kl': approx_kl,
        'clip_fraction': clip_fraction,
    }
Value Function Clipping
Some implementations clip value updates too:
def clipped_value_loss(
    new_values,
    old_values,
    returns,
    clip_epsilon=0.2,
):
    """
    Clipped value loss (optional, used in some implementations).
    Idea: Don't let value predictions change too much either.
    """
    # Unclipped loss
    value_loss_unclipped = (new_values - returns) ** 2

    # Clipped value prediction
    values_clipped = old_values + torch.clamp(
        new_values - old_values,
        -clip_epsilon,
        clip_epsilon,
    )
    value_loss_clipped = (values_clipped - returns) ** 2

    # Take maximum (pessimistic)
    value_loss = 0.5 * torch.max(value_loss_unclipped, value_loss_clipped).mean()
    return value_loss
# Note: Value clipping is controversial
# - Original PPO paper: Does not use it
# - OpenAI baselines: Uses it
# - Empirically: Sometimes helps, sometimes hurts
# - Recommendation: Try both, keep what works
Entropy Bonus
Encourages exploration:
# Entropy measures "randomness" of the policy
# High entropy: Actions spread across many options
# Low entropy: Policy concentrated on one action
# Why add entropy bonus?
# 1. Prevents premature convergence to deterministic policy
# 2. Maintains exploration
# 3. Acts as regularization
def compute_entropy(dist):
    """
    Entropy for common distributions.
    """
    if isinstance(dist, torch.distributions.Categorical):
        # H = -sum(p * log(p))
        return dist.entropy()
    elif isinstance(dist, torch.distributions.Normal):
        # H = 0.5 * log(2 * pi * e * sigma^2), summed over action dimensions
        return dist.entropy().sum(dim=-1)

# Entropy coefficient decay:
# Early training: High entropy (explore)
# Late training: Low entropy (exploit)
def get_entropy_coef(step, total_steps, start=0.01, end=0.001):
    return start - (start - end) * (step / total_steps)
The Actor-Critic Networks
import numpy as np
import torch
import torch.nn as nn

class ActorCritic(nn.Module):
    """
    Shared network for policy and value function (continuous actions).
    """

    def __init__(self, obs_dim, action_dim, hidden_dim=64):
        super().__init__()
        # Shared feature extractor
        self.shared = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
        )
        # Policy head
        self.policy_mean = nn.Linear(hidden_dim, action_dim)
        self.policy_log_std = nn.Parameter(torch.zeros(action_dim))
        # Value head
        self.value = nn.Linear(hidden_dim, 1)
        # Initialize with orthogonal weights
        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
                nn.init.constant_(m.bias, 0)

    def forward(self, obs):
        features = self.shared(obs)
        return features

    def get_policy(self, obs):
        features = self.forward(obs)
        mean = self.policy_mean(features)
        std = self.policy_log_std.exp()
        return torch.distributions.Normal(mean, std)

    def get_value(self, obs):
        features = self.forward(obs)
        return self.value(features)
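The class above assumes a continuous action space. For discrete environments like CartPole, which the training loop below targets, a minimal sketch of the same idea swaps the Normal head for a Categorical over logits (the class name here is illustrative):

class DiscreteActorCritic(nn.Module):
    """Sketch: same shared trunk, but a Categorical policy head for discrete actions."""

    def __init__(self, obs_dim, num_actions, hidden_dim=64):
        super().__init__()
        self.shared = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim), nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim), nn.Tanh(),
        )
        self.policy_logits = nn.Linear(hidden_dim, num_actions)
        self.value = nn.Linear(hidden_dim, 1)

    def get_policy(self, obs):
        logits = self.policy_logits(self.shared(obs))
        return torch.distributions.Categorical(logits=logits)

    def get_value(self, obs):
        return self.value(self.shared(obs))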
Separate vs Shared Networks
# Shared (above): Policy and value share feature extractor
# - Fewer parameters
# - Features learned for both objectives
# - Can have conflicting gradients
# Separate: Independent networks
class SeparateActorCritic(nn.Module):
    def __init__(self, obs_dim, action_dim, hidden_dim=64):
        super().__init__()
        # Separate policy network
        self.policy_net = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, action_dim),
        )
        self.policy_log_std = nn.Parameter(torch.zeros(action_dim))

        # Separate value network
        self.value_net = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1),
        )
# Which to use?
# - Simple tasks (CartPole): Shared is fine
# - Complex tasks (MuJoCo): Separate often better
# - LLMs/RLHF: Separate (huge policy network)
MuJoCo Environments
Continuous control benchmarks:
import gymnasium as gym

# MuJoCo environments (continuous action spaces)
env_ids = [
    "HalfCheetah-v4",  # 6 DoF, run fast
    "Hopper-v4",       # 3 DoF, hop forward
    "Walker2d-v4",     # 6 DoF, walk forward
    "Ant-v4",          # 8 DoF, walk in any direction
    "Humanoid-v4",     # 17 DoF, humanoid walking
]

# Key differences from discrete (Atari):
# 1. Continuous actions: Use Normal distribution (see the sampling sketch below)
# 2. Dense rewards: Reward at every step
# 3. No frame stacking: Single observation
# 4. Smaller networks: 64-256 hidden units

# Standard PPO hyperparameters for MuJoCo:
mujoco_config = {
    'num_envs': 1,          # Often just 1 for MuJoCo
    'num_steps': 2048,      # Steps per rollout
    'num_epochs': 10,       # Updates per batch
    'minibatch_size': 64,
    'learning_rate': 3e-4,
    'gamma': 0.99,
    'gae_lambda': 0.95,
    'clip_epsilon': 0.2,
    'entropy_coef': 0.0,    # Often 0 for MuJoCo
    'value_coef': 0.5,
}
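Since MuJoCo actions are bounded (typically [-1, 1]) while a Normal distribution is unbounded, sampled actions are usually clipped to the action space before stepping. A minimal sketch, assuming the imports and the ActorCritic class from earlier in this chapter and a local MuJoCo install:

# Sketch: sampling a continuous action and clipping it to the env's bounds.
env = gym.make("HalfCheetah-v4")
obs, _ = env.reset(seed=0)

agent = ActorCritic(
    obs_dim=env.observation_space.shape[0],
    action_dim=env.action_space.shape[0],
)

with torch.no_grad():
    dist = agent.get_policy(torch.tensor(obs, dtype=torch.float32))
    action = dist.sample()
    # log_prob is computed on the unclipped sample; only the env sees the clipped one
    log_prob = dist.log_prob(action).sum(-1)

clipped = action.clamp(
    torch.tensor(env.action_space.low),
    torch.tensor(env.action_space.high),
).numpy()
next_obs, reward, terminated, truncated, info = env.step(clipped)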
Full Training Loop
def train_ppo(
    env_id,
    total_timesteps=1_000_000,
    num_envs=8,
    num_steps=128,
    num_epochs=4,
    minibatch_size=256,
    lr=3e-4,
    gamma=0.99,
    gae_lambda=0.95,
    clip_epsilon=0.2,
    entropy_coef=0.01,
    value_coef=0.5,
    max_grad_norm=0.5,
):
    """Complete PPO training loop."""
    # Create environments
    envs = SyncVectorEnv([make_env(env_id, 42, i) for i in range(num_envs)])
    obs_dim = envs.single_observation_space.shape[0]
    action_dim = envs.single_action_space.n  # Discrete envs; use .shape[0] for continuous

    # Create networks (for discrete envs, swap in the Categorical-head variant sketched above)
    agent = ActorCritic(obs_dim, action_dim)
    optimizer = torch.optim.Adam(agent.parameters(), lr=lr)

    # Create buffer
    buffer = RolloutBuffer(num_steps, num_envs, (obs_dim,), ())

    num_updates = total_timesteps // (num_envs * num_steps)

    # Reset once; episodes then continue across rollouts
    obs, _ = envs.reset()
    obs = torch.tensor(obs, dtype=torch.float32)

    for update in range(num_updates):
        # Anneal learning rate
        frac = 1 - update / num_updates
        optimizer.param_groups[0]['lr'] = lr * frac

        # Collect rollouts
        buffer.reset()
        final_value, obs = collect_rollouts(
            agent.get_policy, agent.get_value, envs, buffer, num_steps, obs
        )

        # Compute advantages
        compute_gae(buffer, final_value, gamma, gae_lambda)

        # PPO update
        for epoch in range(num_epochs):
            for batch in create_minibatches(buffer, minibatch_size):
                losses = compute_ppo_loss(
                    agent.get_policy,
                    agent.get_value,
                    batch,
                    clip_epsilon,
                    value_coef,
                    entropy_coef,
                )

                optimizer.zero_grad()
                losses['loss'].backward()
                nn.utils.clip_grad_norm_(agent.parameters(), max_grad_norm)
                optimizer.step()

        # Logging
        if update % 10 == 0:
            print(f"Update {update}: "
                  f"approx_kl={losses['approx_kl']:.4f}, "
                  f"clip_frac={losses['clip_fraction']:.4f}")

    return agent
Implementation Pitfalls
Common bugs and how to avoid them:
# 1. Forgetting to normalize advantages
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
# Without this: Training is unstable

# 2. Wrong ratio computation
ratio = (new_log_probs - old_log_probs).exp()  # Correct
# ratio = new_probs / old_probs                # Numerically unstable!

# 3. Not detaching old log probs
# old_log_probs should be from BEFORE the update, not recomputed

# 4. Gradient through value target
returns = advantages + values.detach()  # Detach values!
# Otherwise: Value loss gradients affect advantage computation

# 5. Wrong done handling in GAE
# terminated=True: the next state is truly terminal -> do NOT bootstrap
# truncated=True: the episode was cut off (e.g. time limit) -> still bootstrap from V(next_obs)
# Only terminated should zero out the bootstrap (see the sketch below)

# 6. Forgetting gradient clipping
nn.utils.clip_grad_norm_(agent.parameters(), max_grad_norm)
# Without this: Exploding gradients
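A minimal sketch of pitfall 5: the rollout code above folds terminated and truncated into a single done flag, which is a common simplification but biases value targets at time limits. The correct mask depends only on terminated; the variable names below follow the GAE code earlier in the chapter.

# Sketch: only `terminated` should zero the bootstrap term in the TD error / GAE delta.
def bootstrap_mask(terminated, truncated):
    # truncated is intentionally unused: a time-limit cut-off should NOT drop the bootstrap
    return 1.0 - torch.tensor(terminated, dtype=torch.float32)

# Inside the GAE loop, with per-step terminated/truncated flags instead of one `done`:
# delta = reward + gamma * next_value * bootstrap_mask(terminated, truncated) - value
#
# Note: with auto-resetting vector envs, V(next_obs) after a truncation must be computed
# from the pre-reset observation (available in the step infos; details vary by Gymnasium version).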
Debugging PPO
Key metrics to monitor:
def log_training_metrics(losses, buffer):
    """Metrics that indicate training health."""
    print(f"""
    === Training Diagnostics ===
    approx_kl:     {losses['approx_kl']:.4f}      # Should be < 0.02
    clip_fraction: {losses['clip_fraction']:.4f}  # Should be 0.1-0.3
    entropy:       {losses['entropy']:.4f}        # Should decrease slowly
    value_loss:    {losses['value_loss']:.4f}     # Should decrease
    policy_loss:   {losses['policy_loss']:.4f}    # Can be noisy

    === Warning Signs ===
    - approx_kl > 0.05: Learning rate too high
    - clip_fraction > 0.5: Updates too aggressive
    - clip_fraction = 0: Epsilon too large or lr too small
    - entropy = 0: Policy collapsed, no exploration
    - value_loss increasing: Value network diverging
    """)
Capstone Connection
PPO is how we train AI on human preferences:
# RLHF uses PPO because:
# 1. Stability: Can't destroy the base model
# 2. KL constraint: Stays close to original policy
# 3. Sample efficiency: Expensive to get human feedback

class RLHFTrainer:
    """
    PPO-based RLHF training.
    """

    def __init__(self, policy, ref_policy, reward_model, kl_coef=0.1):
        self.policy = policy              # Model we're training
        self.ref_policy = ref_policy      # Frozen original model
        self.reward_model = reward_model
        self.kl_coef = kl_coef            # Weight of the KL penalty (illustrative default)

    def compute_rewards(self, prompts, responses):
        """
        Reward = RM score - KL penalty
        """
        rm_scores = self.reward_model(prompts, responses)

        # KL penalty keeps policy close to reference
        policy_logprobs = self.policy.log_prob(responses)
        ref_logprobs = self.ref_policy.log_prob(responses)
        kl_penalty = policy_logprobs - ref_logprobs

        rewards = rm_scores - self.kl_coef * kl_penalty
        return rewards

    def update(self, prompts):
        """
        One PPO update step for RLHF.
        """
        # Generate responses with current policy
        responses = self.policy.generate(prompts)

        # Compute rewards
        rewards = self.compute_rewards(prompts, responses)

        # Standard PPO update follows: the clipping prevents reward hacking
        # by limiting how far the policy can move

# This is how ChatGPT, Claude, etc. are trained
# PPO is the workhorse of AI alignment
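To connect this back to the earlier sections: the rewards produced above simply take the place of environment rewards in the rollout-buffer / GAE / clipped-loss pipeline. A brief sketch, reusing only names already defined in this chapter:

# Sketch: RLHF rewards slot into the same rollout -> GAE -> clipped-loss machinery.
trainer = RLHFTrainer(policy, ref_policy, reward_model, kl_coef=0.1)

responses = trainer.policy.generate(prompts)            # generated tokens play the role of actions
rewards = trainer.compute_rewards(prompts, responses)   # RM score minus KL penalty

# From here the update is the same as before:
#   store (obs, action, log_prob, reward, value) in a rollout buffer,
#   compute advantages with compute_gae, and optimize compute_ppo_loss with clipping.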
From PPO to DPO
Modern alternatives to PPO for RLHF:
# PPO for RLHF is complex:
# - Need reward model
# - Need reference policy
# - Need to generate samples
# - Hyperparameter sensitive

# DPO (Direct Preference Optimization) simplifies:
# - No reward model
# - No RL at all!
# - Direct optimization from preferences

import torch.nn.functional as F

def dpo_loss(policy, ref_policy, preferred, rejected, beta=0.1):
    """
    DPO: Train directly on preference pairs.
    """
    # Log prob differences
    policy_log_ratio = (
        policy.log_prob(preferred) - policy.log_prob(rejected)
    )
    ref_log_ratio = (
        ref_policy.log_prob(preferred) - ref_policy.log_prob(rejected)
    )

    # DPO loss
    loss = -F.logsigmoid(beta * (policy_log_ratio - ref_log_ratio))
    return loss.mean()

# But PPO is still important:
# 1. More flexible (any reward signal)
# 2. Better for complex objectives
# 3. Foundation of online RLHF
🎓 Tyla's Exercise
Prove that PPO with infinite epsilon (no clipping) reduces to vanilla policy gradient. What does this say about the role of clipping?
The GAE formula is: A_t = sum_{l=0}^{inf} (gamma * lambda)^l * delta_{t+l}. Derive this from first principles. What is the effective horizon of GAE when lambda = 0.95 and gamma = 0.99?
Why does normalizing advantages improve training stability? Prove that normalized advantages have zero mean. What about variance?
💻 Aaliyah's Exercise
Implement a complete PPO trainer:
class PPOTrainer:
    """
    Your task: Implement each method.
    """

    def __init__(self, env_id, config):
        # TODO: Initialize environments, networks, optimizer, buffer
        pass

    def collect_rollouts(self):
        # TODO: Fill buffer with num_steps of experience
        pass

    def compute_gae(self):
        # TODO: Compute advantages and returns
        pass

    def update(self):
        # TODO: Run num_epochs of minibatch updates
        pass

    def train(self, total_timesteps):
        # TODO: Main training loop with logging
        pass

# Test on CartPole-v1: Should reach 500 reward
# Test on LunarLander-v2: Should reach 200+ reward
# Bonus: Test on HalfCheetah-v4 (continuous)
📚 Maneesha's Reflection
PPO collects experience, then trains on it multiple times. This is like studying the same material repeatedly. What learning science principles explain why this works (spaced repetition, interleaving, etc.)?
The entropy bonus encourages exploration: trying new things even when the current strategy works. How does this relate to human learning? When is it beneficial, and when harmful, to keep exploring rather than exploiting what we already know?
RLHF uses PPO to align AI with human preferences. But the "human preferences" are from paid contractors answering questions. What are the limitations of this approach? How might it create blind spots in AI systems?