PPO: Proximal Policy Optimization
The algorithm that made RLHF possible. Simple enough to implement, stable enough to scale.
Why Policy Gradients Fail
Vanilla policy gradient has a fatal flaw: nothing limits the size of each update.
# Vanilla policy gradient
loss = -log_prob(action) * advantage
# Problem: if advantage is huge, gradient is huge
# Result: policy jumps too far, performance collapses
One bad update can destroy a policy that took hours to train.
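A minimal sketch of the failure (hypothetical numbers; the learning rate and advantage are exaggerated so the collapse is obvious):
import torch

# Toy categorical policy over 3 actions
logits = torch.zeros(3, requires_grad=True)
optimizer = torch.optim.SGD([logits], lr=1.0)  # deliberately large for illustration

dist = torch.distributions.Categorical(logits=logits)
action = torch.tensor(0)
advantage = 500.0  # one noisy, huge advantage estimate

# Vanilla policy gradient step: nothing bounds the update size
loss = -dist.log_prob(action) * advantage
loss.backward()
optimizer.step()

print(torch.softmax(logits.detach(), dim=-1))
# After ONE update the policy is essentially deterministic on action 0:
# exploration is gone and the damage is hard to undo.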
Trust Regions: The Core Idea
What if we constrained how much the policy can change?
# Trust Region concept:
# "Only update the policy within a region where our
# estimates are trustworthy"
# Old policy: pi_old(a|s)
# New policy: pi_new(a|s)
# Constraint: KL(pi_old, pi_new) < delta
# If new policy is too different from old,
# our advantage estimates become unreliable
The advantage was computed using the old policy. If the new policy is very different, those advantages are stale.
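A small illustration with hypothetical distributions: a gentle update stays inside the trust region, a drastic one does not.
import torch
from torch.distributions import Categorical, kl_divergence

pi_old = Categorical(probs=torch.tensor([0.5, 0.3, 0.2]))
pi_small = Categorical(probs=torch.tensor([0.45, 0.33, 0.22]))  # gentle update
pi_big = Categorical(probs=torch.tensor([0.05, 0.05, 0.90]))    # drastic update

delta = 0.01  # trust region size (illustrative)
for name, pi_new in [("small", pi_small), ("big", pi_big)]:
    kl = kl_divergence(pi_old, pi_new)
    print(f"{name} update: KL = {kl:.4f}, inside trust region: {bool(kl < delta)}")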
TRPO: Trust Region Policy Optimization
Schulman et al. (2015) formalized this:
# TRPO objective:
# maximize E[r(theta) * A]
# subject to: KL(pi_old, pi_new) <= delta
# where r(theta) = pi_new(a|s) / pi_old(a|s)
# (probability ratio)
import torch
from torch.distributions import kl_divergence

def trpo_objective(old_policy, new_policy, states, actions, advantages):
    """
    TRPO maximizes expected advantage while constraining KL divergence.
    Assumes old_policy(states) and new_policy(states) return torch.distributions objects.
    """
    old_dist = old_policy(states)
    new_dist = new_policy(states)
    # Probability ratio r(theta) = pi_new(a|s) / pi_old(a|s).
    # The old policy is a frozen snapshot, so detach it from the graph.
    ratio = (new_dist.log_prob(actions) - old_dist.log_prob(actions).detach()).exp()
    objective = (ratio * advantages).mean()
    # Constraint: KL divergence must be small
    kl = kl_divergence(old_dist, new_dist).mean()
    return objective, kl  # Need constrained optimization
Problem: this constrained optimization needs second-order machinery (a natural gradient step computed with conjugate gradient, plus a line search), which is complex to implement and slow at scale.
PPO: Making Trust Regions Practical
PPO's insight: replace the constraint with a clipped objective.
# PPO-Clip objective:
# L = min(r * A, clip(r, 1-eps, 1+eps) * A)
def ppo_clip_objective(ratio, advantages, epsilon=0.2):
"""
The famous PPO clipped surrogate objective.
ratio: pi_new(a|s) / pi_old(a|s)
advantages: estimated advantages
epsilon: clipping parameter (typically 0.1-0.3)
"""
# Unclipped objective
unclipped = ratio * advantages
# Clipped objective
clipped_ratio = torch.clamp(ratio, 1 - epsilon, 1 + epsilon)
clipped = clipped_ratio * advantages
# Take minimum (pessimistic bound)
loss = -torch.min(unclipped, clipped).mean()
return loss
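A quick usage check with random data: once some ratios leave [1-eps, 1+eps], the clipped loss diverges from the plain surrogate.
import torch

torch.manual_seed(0)
advantages = torch.randn(64)
ratio = (0.5 * torch.randn(64)).exp()  # fake ratios, some outside [0.8, 1.2]

loss = ppo_clip_objective(ratio, advantages, epsilon=0.2)
plain_loss = -(ratio * advantages).mean()
outside = ((ratio - 1.0).abs() > 0.2).float().mean()

print(f"clipped loss: {loss:.4f}, unclipped loss: {plain_loss:.4f}")
print(f"fraction of ratios outside the clip range: {outside:.2f}")
# The clipped loss is never lower than the unclipped one: min() is pessimistic.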
Understanding the Clip
The clipping creates a "pessimistic" bound:
# Case 1: Positive advantage (good action)
# ratio > 1: action more likely now -> clipped at 1+eps
# ratio < 1: action less likely now -> no clipping
# Result: Can't increase probability too much
# Case 2: Negative advantage (bad action)
# ratio > 1: action more likely now -> no clipping (want to fix this!)
# ratio < 1: action less likely now -> clipped at 1-eps
# Result: Can't decrease probability too much
import torch
import matplotlib.pyplot as plt

def visualize_clipping(advantages, epsilon=0.2):
    """
    Plots the PPO-Clip objective as a function of the probability ratio.
    Flat segments are where clipping zeroes out the gradient.
    """
    ratios = torch.linspace(0.5, 1.5, 100)
    for A in advantages:
        unclipped = ratios * A
        clipped_ratios = torch.clamp(ratios, 1 - epsilon, 1 + epsilon)
        clipped = clipped_ratios * A
        objective = torch.min(unclipped, clipped)
        plt.plot(ratios, objective, label=f"A = {A}")
    plt.axvline(1 - epsilon, linestyle="--", color="gray")
    plt.axvline(1 + epsilon, linestyle="--", color="gray")
    plt.xlabel("ratio pi_new / pi_old")
    plt.ylabel("clipped objective")
    plt.legend()
    plt.show()
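Plugging in concrete numbers (eps = 0.2) makes the four cases above explicit; "clipped" means the objective is flat, so the gradient is zero.
import torch

epsilon = 0.2
cases = [
    (1.5,  1.0),  # good action, already much more likely  -> clipped, no gradient
    (0.7,  1.0),  # good action, currently less likely     -> unclipped, pushed up
    (1.5, -1.0),  # bad action, currently more likely      -> unclipped, pushed down
    (0.7, -1.0),  # bad action, already much less likely   -> clipped, no gradient
]
for r, A in cases:
    r_t = torch.tensor(r)
    clipped_r = torch.clamp(r_t, 1 - epsilon, 1 + epsilon)
    objective = torch.min(r_t * A, clipped_r * A)
    print(f"ratio={r:.1f}, A={A:+.1f} -> objective={objective.item():+.2f}")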
The Surrogate Loss Derivation
Why does clipping approximate the trust region?
# When ratio = 1, policies are identical
# When ratio = 1 +/- eps, KL is approximately bounded
# Taylor expansion: KL(p, q) ~ (1/2) * E[(p-q)^2/p]
# For small changes: ratio near 1 means KL is small
# Clipping at [1-eps, 1+eps] approximately enforces
# that the policy hasn't changed too much
# This is an approximation, not exact!
# But empirically works very well
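A numerical sanity check on a hypothetical distribution: for a small perturbation, the exact KL and the quadratic approximation from the comment above nearly coincide, which is exactly the regime the clip keeps us in.
import torch
from torch.distributions import Categorical, kl_divergence

p = torch.tensor([0.5, 0.3, 0.2])
perturb = torch.tensor([0.02, -0.01, -0.01])  # small perturbation that sums to zero
q = p + perturb

exact_kl = kl_divergence(Categorical(probs=p), Categorical(probs=q))
approx_kl = 0.5 * ((p - q) ** 2 / p).sum()
print(f"exact KL: {exact_kl:.6f}  quadratic approximation: {approx_kl:.6f}")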
PPO Hyperparameters
The key hyperparameters and their effects:
class PPOConfig:
# Clipping
clip_epsilon: float = 0.2 # 0.1-0.3 typical
# Learning
learning_rate: float = 3e-4 # Often 1e-4 to 3e-4
num_epochs: int = 10 # Updates per batch
num_minibatches: int = 32 # Minibatches per epoch
# GAE (Generalized Advantage Estimation)
gamma: float = 0.99 # Discount factor
gae_lambda: float = 0.95 # GAE parameter
# Regularization
entropy_coef: float = 0.01 # Entropy bonus
value_coef: float = 0.5 # Value loss weight
max_grad_norm: float = 0.5 # Gradient clipping
# Data collection
num_steps: int = 2048 # Steps per rollout
num_envs: int = 8 # Parallel environments
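To make the data flow concrete with the defaults above: each rollout yields num_steps * num_envs transitions, which are split into minibatches and reused for several epochs.
cfg = PPOConfig()
batch_size = cfg.num_steps * cfg.num_envs           # 2048 * 8 = 16,384 transitions
minibatch_size = batch_size // cfg.num_minibatches  # 16,384 / 32 = 512 per gradient step
grad_steps = cfg.num_epochs * cfg.num_minibatches   # 10 * 32 = 320 updates per rollout
print(batch_size, minibatch_size, grad_steps)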
Epsilon: The Clipping Parameter
# epsilon controls trust region size
# epsilon = 0.1: Conservative updates
# - More stable but slower learning
# - Good for complex/sensitive tasks
# epsilon = 0.2: Standard (default)
# - Good balance for most tasks
# - OpenAI's recommendation
# epsilon = 0.3: Aggressive updates
# - Faster learning but less stable
# - Risk of policy collapse
def adaptive_epsilon(episode, max_episodes, start=0.3, end=0.1):
"""Some implementations anneal epsilon over training."""
return start - (start - end) * (episode / max_episodes)
Learning Rate Sensitivity
# PPO is sensitive to learning rate
# Too high: Policy collapse
# - Updates overshoot even with clipping
# - Ratio goes to 0 or infinity
# Too low: Slow learning
# - Never reaches optimal policy
# - Wastes compute
# The interaction: lr and epsilon together determine update magnitude
# High lr + high eps = unstable
# Low lr + low eps = stable but slow
# Common practice: linear annealing
def get_lr(update, total_updates, initial_lr=3e-4):
return initial_lr * (1 - update / total_updates)
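A quick sketch of how the two schedules above interact over a hypothetical 500-update run (treating one update per "episode" for adaptive_epsilon): both the step size and the trust region shrink together.
total_updates = 500
for update in (0, 250, 499):
    lr = get_lr(update, total_updates)
    eps = adaptive_epsilon(update, total_updates)
    print(f"update {update:3d}: lr = {lr:.2e}, clip epsilon = {eps:.2f}")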
Generalized Advantage Estimation
PPO typically uses GAE for advantage computation:
import torch

def compute_gae(rewards, values, dones, gamma=0.99, gae_lambda=0.95):
    """
    Generalized Advantage Estimation (Schulman et al., 2015)
    Balances bias vs variance in advantage estimates.
    lambda=0: High bias, low variance (TD)
    lambda=1: Low bias, high variance (MC)
    """
    advantages = torch.zeros_like(rewards)
    last_gae = 0
    for t in reversed(range(len(rewards))):
        if t == len(rewards) - 1:
            # Simplification: assumes the rollout ends at an episode boundary.
            # For truncated rollouts, bootstrap with V(s_T) here instead of 0.
            next_value = 0
        else:
            next_value = values[t + 1]
        # TD error
        delta = rewards[t] + gamma * next_value * (1 - dones[t]) - values[t]
        # GAE recursion
        advantages[t] = delta + gamma * gae_lambda * (1 - dones[t]) * last_gae
        last_gae = advantages[t]
    returns = advantages + values
    return advantages, returns
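A small usage example on a hypothetical 5-step rollout with one episode boundary:
import torch

rewards = torch.tensor([1.0, 0.0, 1.0, 0.0, 1.0])
values  = torch.tensor([0.5, 0.6, 0.4, 0.5, 0.3])
dones   = torch.tensor([0.0, 0.0, 1.0, 0.0, 0.0])  # episode ends at step 2

advantages, returns = compute_gae(rewards, values, dones)
print("advantages:", advantages)
print("returns:   ", returns)
# The (1 - dones) factor restarts the recursion at the boundary, so nothing
# from steps 3-4 leaks into the advantages for steps 0-2.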
GAE Lambda
# gae_lambda controls bias-variance tradeoff
# lambda = 0: TD(0)
# advantage = r + gamma*V(s') - V(s)
# High bias: V might be wrong
# Low variance: Only one step of randomness
# lambda = 1: Monte Carlo
# advantage = sum(gamma^t * r_t) - V(s)
# Low bias: Uses actual returns
# High variance: Many steps of randomness
# lambda = 0.95: Standard choice
# Exponentially-weighted average of n-step returns
# Good empirical tradeoff
# Intuition: Trust value estimates for ~20 steps (1/(1-0.95))
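The two extremes can be checked directly against compute_gae with hypothetical data: lambda = 0 gives one-step TD errors, lambda = 1 gives the within-rollout Monte Carlo return minus the value baseline.
import torch

rewards = torch.tensor([1.0, 1.0, 1.0, 1.0])
values  = torch.tensor([0.2, 0.4, 0.6, 0.8])
dones   = torch.zeros(4)

adv_td, _ = compute_gae(rewards, values, dones, gamma=0.99, gae_lambda=0.0)
adv_mc, _ = compute_gae(rewards, values, dones, gamma=0.99, gae_lambda=1.0)

# lambda = 0: advantage[t] = r[t] + gamma * V(s_{t+1}) - V(s_t)   (pure TD error)
# lambda = 1: advantage[t] = discounted sum of remaining rewards - V(s_t)
print("TD-style advantages:", adv_td)
print("MC-style advantages:", adv_mc)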
Capstone Connection
PPO is the foundation of RLHF:
# In RLHF (Reinforcement Learning from Human Feedback):
# 1. Train reward model on human preferences
# 2. Use PPO to optimize policy against reward model
def rlhf_training_step(policy, reward_model, prompts):
    """
    The core RLHF loop uses PPO.
    (old_policy and compute_ppo_loss are placeholders: the snapshot policy
    taken when the batch was collected, and the clipped loss defined above.)
    """
    # Generate responses
    responses = policy.generate(prompts)
    # Get rewards from trained reward model
    rewards = reward_model(prompts, responses)
    # PPO update
    # The clipping prevents the policy from
    # "hacking" the reward model by moving too far
    # from the original (SFT) policy
    ppo_loss = compute_ppo_loss(
        policy=policy,
        old_policy=old_policy,  # Snapshot from start of batch
        responses=responses,
        rewards=rewards,
        kl_penalty=0.01  # Additional KL term for stability
    )
    return ppo_loss
# Why PPO for RLHF?
# 1. Stable: Can't destroy the base model in one update
# 2. Sample efficient: Reuses data with multiple epochs
# 3. Simple: No complex constrained optimization
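One common way the extra KL term shows up in practice (a hedged sketch; the function name and tensor shapes are illustrative, not from a specific library) is as a per-token reward shaped by the distance to the frozen reference (SFT) policy:
import torch

def kl_shaped_rewards(rm_scores, policy_logprobs, ref_logprobs, kl_coef=0.01):
    """
    rm_scores:       (batch,) reward-model score per response
    policy_logprobs: (batch, seq_len) log pi_theta(token | context)
    ref_logprobs:    (batch, seq_len) log pi_ref(token | context) from the frozen SFT model
    """
    # Penalize every token for drifting away from the reference policy...
    per_token = -kl_coef * (policy_logprobs - ref_logprobs)
    # ...and add the sequence-level reward-model score on the final token.
    per_token[:, -1] += rm_scores
    return per_token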
PPO vs TRPO vs Other Methods
| Method | Constraint | Complexity | Stability |
|---|---|---|---|
| Vanilla PG | None | Simple | Unstable |
| TRPO | KL constraint | Complex | Stable |
| PPO-Clip | Clipped ratio | Simple | Stable |
| PPO-KL | Adaptive KL | Medium | Stable |
# PPO-KL (alternative): penalize KL instead of clipping.
# The full variant adapts beta up or down based on the KL observed after each update.
def ppo_kl_objective(ratio, advantages, old_dist, new_dist, beta=0.01):
    surrogate = (ratio * advantages).mean()
    kl_penalty = kl_divergence(old_dist, new_dist).mean()
    return surrogate - beta * kl_penalty  # maximize this (or minimize its negative)
# PPO-Clip is more common because:
# 1. No need to tune beta
# 2. Often works better empirically
🎓 Tyla's Exercise
1. Prove that at ratio = 1 (i.e., before any update to the new policy), the gradient of the PPO-Clip objective equals the vanilla policy gradient. What does this mean about the first update after collecting a batch?
2. The PPO paper presents the clipped objective as a pessimistic lower bound on the unclipped surrogate. Verify this mathematically for positive advantages. What happens when advantages are negative?
3. Derive the approximate relationship between epsilon and the KL divergence bound. Under what assumptions does clipping at [1-eps, 1+eps] approximately bound KL?
💻 Aaliyah's Exercise
Build a PPO objective function with diagnostics:
def ppo_objective_with_diagnostics(
old_log_probs: torch.Tensor,
new_log_probs: torch.Tensor,
advantages: torch.Tensor,
epsilon: float = 0.2
) -> dict:
"""
Compute PPO loss and return useful diagnostics.
Returns:
- loss: The PPO clipped loss
- clip_fraction: What fraction of samples were clipped?
- approx_kl: Approximate KL divergence
- ratio_mean: Mean probability ratio
- ratio_std: Std of probability ratio
"""
# Compute ratio
ratio = (new_log_probs - old_log_probs).exp()
# TODO: Implement clipped objective
# TODO: Compute clip_fraction
# TODO: Compute approx_kl = ((ratio - 1) - log(ratio)).mean()
# TODO: Return all diagnostics
pass
Test: Verify that clip_fraction increases when learning rate is too high.
📚 Maneesha's Reflection
PPO's "trust region" is about trusting our estimates within a certain range. What analogies from human learning capture this idea? When do we trust our intuitions vs seek external validation?
The clip creates a "pessimistic" bound - always assuming the worst case. This is conservative but safe. Where else in life do we apply pessimistic reasoning, and when is it appropriate vs limiting?
PPO made RLHF practical by being "stable enough." What does stability mean in the context of training AI systems? Why might instability be especially concerning when fine-tuning on human preferences?