PPO: Proximal Policy Optimization

The algorithm that made RLHF possible. Simple enough to implement, stable enough to scale.


Why Policy Gradients Fail

Vanilla policy gradient has a fatal flaw: nothing controls the magnitude of each update.

# Vanilla policy gradient
loss = -log_prob(action) * advantage

# Problem: if advantage is huge, gradient is huge
# Result: policy jumps too far, performance collapses

One bad update can destroy a policy that took hours to train.
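
A minimal sketch makes this concrete (a toy 4-action categorical policy with one deliberately huge advantage estimate): the gradient scales linearly with the advantage, so a single outlier sample can move the policy arbitrarily far.

import torch
import torch.nn.functional as F

# Toy categorical policy over 4 actions; the logits are the parameters.
logits = torch.zeros(4, requires_grad=True)

action = torch.tensor(2)
advantage = torch.tensor(1000.0)  # one unusually large advantage estimate

log_prob = F.log_softmax(logits, dim=-1)[action]
loss = -log_prob * advantage
loss.backward()

print(logits.grad)  # gradient entries scale directly with the advantage
# One SGD step with this gradient can collapse the policy onto action 2.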


Trust Regions: The Core Idea

What if we constrained how much the policy can change?

# Trust Region concept:
# "Only update the policy within a region where our
# estimates are trustworthy"

# Old policy: pi_old(a|s)
# New policy: pi_new(a|s)
# Constraint: KL(pi_old, pi_new) < delta

# If new policy is too different from old,
# our advantage estimates become unreliable

The advantage was computed using the old policy. If the new policy is very different, those advantages are stale.
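
As a toy illustration (hypothetical policies and an arbitrary delta), the trust-region check amounts to measuring the KL between the old and new action distributions and rejecting updates that exceed the threshold.

import torch
from torch.distributions import Categorical, kl_divergence

delta = 0.01  # illustrative trust region size

pi_old = Categorical(logits=torch.tensor([1.0, 0.5, 0.0]))
pi_new = Categorical(logits=torch.tensor([1.3, 0.4, -0.1]))

kl = kl_divergence(pi_old, pi_new).item()
print(f"KL(pi_old || pi_new) = {kl:.4f}")

if kl > delta:
    print("Too far: advantages estimated under pi_old are now stale.")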


TRPO: Trust Region Policy Optimization

Schulman et al. (2015) formalized this:

# TRPO objective:
# maximize E[r(theta) * A]
# subject to: KL(pi_old, pi_new) <= delta

# where r(theta) = pi_new(a|s) / pi_old(a|s)
# (probability ratio)

import torch
from torch.distributions import kl_divergence

def trpo_objective(old_policy, new_policy, states, actions, advantages):
    """
    TRPO maximizes expected advantage while constraining KL divergence.
    Assumes each policy maps states to a torch.distributions object.
    """
    # Ratio computed in log space for numerical stability; the old policy
    # is treated as fixed data, so no gradient flows through it.
    old_log_probs = old_policy(states).log_prob(actions).detach()
    new_log_probs = new_policy(states).log_prob(actions)

    ratio = (new_log_probs - old_log_probs).exp()
    objective = (ratio * advantages).mean()

    # Constraint: KL divergence must be small
    kl = kl_divergence(old_policy(states), new_policy(states)).mean()

    return objective, kl  # Need constrained optimization

Problem: solving this constrained optimization requires second-order machinery (conjugate gradient steps plus a line search), which is complex to implement and slow at scale.


PPO: Making Trust Regions Practical

PPO's insight: replace the constraint with a clipped objective.

# PPO-Clip objective:
# L = min(r * A, clip(r, 1-eps, 1+eps) * A)

def ppo_clip_objective(ratio, advantages, epsilon=0.2):
    """
    The famous PPO clipped surrogate objective.

    ratio: pi_new(a|s) / pi_old(a|s)
    advantages: estimated advantages
    epsilon: clipping parameter (typically 0.1-0.3)
    """
    # Unclipped objective
    unclipped = ratio * advantages

    # Clipped objective
    clipped_ratio = torch.clamp(ratio, 1 - epsilon, 1 + epsilon)
    clipped = clipped_ratio * advantages

    # Take minimum (pessimistic bound)
    loss = -torch.min(unclipped, clipped).mean()

    return loss
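
A quick smoke test of the function above (dummy tensors, values chosen by hand) shows the clip in action: ratios far outside [1 - eps, 1 + eps] stop contributing anything extra to the objective.

import torch

ratio = torch.tensor([0.7, 1.0, 1.5, 2.5])       # pi_new / pi_old per sample
advantages = torch.tensor([1.0, 1.0, 1.0, 1.0])  # all positive

loss = ppo_clip_objective(ratio, advantages, epsilon=0.2)
print(loss)  # the 1.5 and 2.5 ratios are clipped to 1.2 before averaging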

Understanding the Clip

The clipping creates a "pessimistic" bound:

# Case 1: Positive advantage (good action)
# ratio > 1: action more likely now -> clipped at 1+eps
# ratio < 1: action less likely now -> no clipping
# Result: Can't increase probability too much

# Case 2: Negative advantage (bad action)
# ratio > 1: action more likely now -> no clipping (want to fix this!)
# ratio < 1: action less likely now -> clipped at 1-eps
# Result: Can't decrease probability too much

import torch
import matplotlib.pyplot as plt

def visualize_clipping(advantages, epsilon=0.2):
    """
    Plots the clipped objective against the ratio and shows where
    clipping activates (the flat regions carry zero gradient).
    """
    ratios = torch.linspace(0.5, 1.5, 100)

    for A in advantages:
        unclipped = ratios * A
        clipped_ratios = torch.clamp(ratios, 1 - epsilon, 1 + epsilon)
        clipped = clipped_ratios * A

        objective = torch.min(unclipped, clipped)
        plt.plot(ratios, objective, label=f"A = {A}")

    plt.xlabel("ratio pi_new / pi_old")
    plt.ylabel("clipped objective")
    plt.legend()
    plt.show()
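
A usage sketch (assuming matplotlib is installed) that traces both cases from the comments above:

# Positive-A curve flattens above ratio = 1 + eps;
# negative-A curve flattens below ratio = 1 - eps.
visualize_clipping([1.0, -1.0], epsilon=0.2)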

The Surrogate Loss Derivation

Why does clipping approximate the trust region?

# When ratio = 1, policies are identical
# When ratio = 1 +/- eps, KL is approximately bounded

# Second-order Taylor expansion around q = p:
# KL(p || q) ~ (1/2) * sum((p - q)^2 / p)
# For small changes, a ratio near 1 for every action means KL is small

# Clipping at [1-eps, 1+eps] approximately enforces
# that the policy hasn't changed too much

# This is an approximation, not exact!
# But empirically works very well
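
A quick numerical check (toy categorical distributions, values chosen only for illustration) shows the second-order approximation tracking the exact KL when the new distribution stays close to the old one.

import torch
from torch.distributions import Categorical, kl_divergence

p = torch.tensor([0.5, 0.3, 0.2])
q = torch.tensor([0.45, 0.33, 0.22])  # a small perturbation of p

exact = kl_divergence(Categorical(probs=p), Categorical(probs=q)).item()
approx = (0.5 * (p - q) ** 2 / p).sum().item()

print(f"exact KL     = {exact:.5f}")
print(f"second-order = {approx:.5f}")  # nearly identical for small perturbations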

PPO Hyperparameters

The key hyperparameters and their effects:

class PPOConfig:
    # Clipping
    clip_epsilon: float = 0.2    # 0.1-0.3 typical

    # Learning
    learning_rate: float = 3e-4  # Often 1e-4 to 3e-4
    num_epochs: int = 10         # Optimization epochs per batch
    num_minibatches: int = 32    # Minibatches per epoch

    # GAE (Generalized Advantage Estimation)
    gamma: float = 0.99          # Discount factor
    gae_lambda: float = 0.95     # GAE parameter

    # Regularization
    entropy_coef: float = 0.01   # Entropy bonus
    value_coef: float = 0.5      # Value loss weight
    max_grad_norm: float = 0.5   # Gradient clipping

    # Data collection
    num_steps: int = 2048        # Steps per rollout
    num_envs: int = 8            # Parallel environments
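
These settings interact: the rollout determines the batch size, and the batch is split into minibatches that are reused for several epochs. A small sketch of the arithmetic with the defaults above:

config = PPOConfig()

batch_size = config.num_steps * config.num_envs                 # 2048 * 8 = 16384 transitions
minibatch_size = batch_size // config.num_minibatches           # 16384 / 32 = 512
updates_per_batch = config.num_epochs * config.num_minibatches  # 10 * 32 = 320 gradient steps

print(batch_size, minibatch_size, updates_per_batch)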

Epsilon: The Clipping Parameter

# epsilon controls trust region size

# epsilon = 0.1: Conservative updates
#   - More stable but slower learning
#   - Good for complex/sensitive tasks

# epsilon = 0.2: Standard (default)
#   - Good balance for most tasks
#   - OpenAI's recommendation

# epsilon = 0.3: Aggressive updates
#   - Faster learning but less stable
#   - Risk of policy collapse

def adaptive_epsilon(episode, max_episodes, start=0.3, end=0.1):
    """Some implementations anneal epsilon over training."""
    return start - (start - end) * (episode / max_episodes)

Learning Rate Sensitivity

# PPO is sensitive to learning rate

# Too high: Policy collapse
#   - Updates overshoot even with clipping
#   - Ratio goes to 0 or infinity

# Too low: Slow learning
#   - Never reaches optimal policy
#   - Wastes compute

# The interaction: lr and epsilon together determine update magnitude
# High lr + high eps = unstable
# Low lr + low eps = stable but slow

# Common practice: linear annealing
def get_lr(update, total_updates, initial_lr=3e-4):
    return initial_lr * (1 - update / total_updates)
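
Applied to a PyTorch optimizer, the anneal is a one-line update of the parameter groups each iteration (a sketch: optimizer, update, and total_updates are assumed to exist in the surrounding training loop).

# Inside the training loop (illustrative):
lr_now = get_lr(update, total_updates, initial_lr=3e-4)
for param_group in optimizer.param_groups:
    param_group["lr"] = lr_now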

Generalized Advantage Estimation

PPO typically uses GAE for advantage computation:

def compute_gae(rewards, values, dones, gamma=0.99, gae_lambda=0.95):
    """
    Generalized Advantage Estimation (Schulman et al., 2015)

    Balances bias vs variance in advantage estimates.
    lambda=0: High bias, low variance (TD)
    lambda=1: Low bias, high variance (MC)
    """
    advantages = torch.zeros_like(rewards)
    last_gae = 0

    for t in reversed(range(len(rewards))):
        if t == len(rewards) - 1:
            # Assumes the rollout ends at an episode boundary; if it is
            # truncated mid-episode, bootstrap with V(s_T) instead of 0.
            next_value = 0
        else:
            next_value = values[t + 1]

        # TD error
        delta = rewards[t] + gamma * next_value * (1 - dones[t]) - values[t]

        # GAE
        advantages[t] = delta + gamma * gae_lambda * (1 - dones[t]) * last_gae
        last_gae = advantages[t]

    returns = advantages + values
    return advantages, returns
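
A toy call with hand-picked numbers, plus the per-batch advantage normalization that most PPO implementations apply before the update (a common-practice sketch, not part of GAE itself):

import torch

rewards = torch.tensor([1.0, 0.0, 1.0, 0.0, 1.0])
values  = torch.tensor([0.5, 0.4, 0.6, 0.3, 0.5])
dones   = torch.tensor([0.0, 0.0, 0.0, 0.0, 1.0])  # episode ends on the last step

advantages, returns = compute_gae(rewards, values, dones)

# Normalize advantages across the batch before the PPO update
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)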

GAE Lambda

# gae_lambda controls bias-variance tradeoff

# lambda = 0: TD(0)
#   advantage = r + gamma*V(s') - V(s)
#   High bias: V might be wrong
#   Low variance: Only one step of randomness

# lambda = 1: Monte Carlo
#   advantage = sum(gamma^t * r_t) - V(s)
#   Low bias: Uses actual returns
#   High variance: Many steps of randomness

# lambda = 0.95: Standard choice
#   Exponentially-weighted average of n-step returns
#   Good empirical tradeoff

# Intuition: Trust value estimates for ~20 steps (1/(1-0.95))
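
The "~20 steps" figure is just the 1/(1 - lambda) horizon of the exponential weighting; a two-line check:

# Effective credit-assignment horizon of GAE for different lambdas
for lam in (0.0, 0.9, 0.95, 0.99):
    print(f"lambda = {lam}: ~{1.0 / (1.0 - lam):.0f} steps")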

Capstone Connection

PPO is the foundation of RLHF:

# In RLHF (Reinforcement Learning from Human Feedback):

# 1. Train reward model on human preferences
# 2. Use PPO to optimize policy against reward model

def rlhf_training_step(policy, old_policy, reward_model, prompts):
    """
    The core RLHF loop uses PPO.
    old_policy is a frozen snapshot of the policy taken at the start of the batch.
    """
    # Generate responses
    responses = policy.generate(prompts)

    # Get rewards from the trained reward model
    rewards = reward_model(prompts, responses)

    # PPO update
    # The clipping prevents the policy from
    # "hacking" the reward model by moving too far
    # from the original (SFT) policy

    ppo_loss = compute_ppo_loss(      # wraps the clipped objective defined above
        policy=policy,
        old_policy=old_policy,
        responses=responses,
        rewards=rewards,
        kl_penalty=0.01  # additional KL term for stability
    )

    return ppo_loss

# Why PPO for RLHF?
# 1. Stable: Can't destroy the base model in one update
# 2. Sample efficient: Reuses data with multiple epochs
# 3. Simple: No complex constrained optimization
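
In practice the extra KL term is usually folded into the reward itself as a per-token penalty against a frozen reference (SFT) policy, in the style of InstructGPT-like setups. A hedged sketch (the function name and shapes are illustrative, not from any specific library):

def kl_shaped_rewards(score, policy_logprobs, ref_logprobs, kl_coef=0.01):
    """
    score: scalar reward-model output for the whole response
    policy_logprobs, ref_logprobs: per-token log-probs of the generated tokens
    Returns per-token rewards: -kl_coef * (log pi - log pi_ref),
    with the reward-model score added on the final token.
    """
    per_token = -kl_coef * (policy_logprobs - ref_logprobs)
    per_token[-1] = per_token[-1] + score
    return per_token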

PPO vs TRPO vs Other Methods

Method       Constraint       Complexity   Stability
-----------  ---------------  -----------  ---------
Vanilla PG   None             Simple       Unstable
TRPO         KL constraint    Complex      Stable
PPO-Clip     Clipped ratio    Simple       Stable
PPO-KL       Adaptive KL      Medium       Stable

# PPO-KL (alternative): Penalize KL instead of clipping
def ppo_kl_objective(ratio, advantages, old_dist, new_dist, beta=0.01):
    surrogate = (ratio * advantages).mean()
    kl_penalty = kl_divergence(old_dist, new_dist).mean()
    return surrogate - beta * kl_penalty

# PPO-Clip is more common because:
# 1. No need to tune beta
# 2. Often works better empirically
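
For completeness, the adaptive variant from the PPO paper adjusts beta after each update based on the measured KL (the "Adaptive KL" row in the table above); target_kl below is an illustrative value:

def update_beta(beta, measured_kl, target_kl=0.01):
    """Adaptive KL coefficient rule from the PPO paper (Schulman et al., 2017)."""
    if measured_kl > 1.5 * target_kl:
        beta = beta * 2.0   # policy moved too far: penalize KL more
    elif measured_kl < target_kl / 1.5:
        beta = beta / 2.0   # policy barely moved: relax the penalty
    return beta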

🎓 Tyla's Exercise

  1. Prove that when ratio = 1, the PPO objective equals the vanilla policy gradient objective. What does this mean about the first update after collecting a batch?

  2. The PPO paper claims clipping is a "lower bound" on the unclipped objective when advantages are positive. Verify this mathematically. What happens when advantages are negative?

  3. Derive the approximate relationship between epsilon and the KL divergence bound. Under what assumptions does clipping at [1-eps, 1+eps] approximately bound KL?


💻 Aaliyah's Exercise

Build a PPO objective function with diagnostics:

def ppo_objective_with_diagnostics(
    old_log_probs: torch.Tensor,
    new_log_probs: torch.Tensor,
    advantages: torch.Tensor,
    epsilon: float = 0.2
) -> dict:
    """
    Compute PPO loss and return useful diagnostics.

    Returns:
        - loss: The PPO clipped loss
        - clip_fraction: What fraction of samples were clipped?
        - approx_kl: Approximate KL divergence
        - ratio_mean: Mean probability ratio
        - ratio_std: Std of probability ratio
    """
    # Compute ratio
    ratio = (new_log_probs - old_log_probs).exp()

    # TODO: Implement clipped objective
    # TODO: Compute clip_fraction
    # TODO: Compute approx_kl = ((ratio - 1) - log(ratio)).mean()
    # TODO: Return all diagnostics
    pass

Test: Verify that clip_fraction increases when learning rate is too high.


📚 Maneesha's Reflection

  1. PPO's "trust region" is about trusting our estimates within a certain range. What analogies from human learning capture this idea? When do we trust our intuitions vs seek external validation?

  2. The clip creates a "pessimistic" bound - always assuming the worst case. This is conservative but safe. Where else in life do we apply pessimistic reasoning, and when is it appropriate vs limiting?

  3. PPO made RLHF practical by being "stable enough." What does stability mean in the context of training AI systems? Why might instability be especially concerning when fine-tuning on human preferences?