RLHF Implementation: PPO for Language Models

Theory meets practice: now we train a policy to maximize a reward signal learned from human preferences.


PPO for Language Models

Proximal Policy Optimization (PPO) is the workhorse of RLHF.

Why PPO?

PPO constrains each policy update with a clipped surrogate objective, so one noisy batch of rewards cannot push the policy far from the behavior it had at rollout time. That stability, without second-order trust-region machinery, is why PPO became the default RL algorithm for fine-tuning language models.

# The PPO objective:
# L = E[min(r_t * A_t, clip(r_t, 1-ε, 1+ε) * A_t)]

# Where:
# r_t = π(a|s) / π_old(a|s)  # probability ratio
# A_t = advantage estimate   # how good is this action?
# ε = clip range (typically 0.2)
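
To see the clipping in action, here is a toy numeric sketch (the log-probs and the advantage are made-up values for illustration):

import torch

# Made-up numbers for a single action
log_prob_new = torch.tensor(-1.0)   # under the current policy
log_prob_old = torch.tensor(-1.5)   # under the policy at rollout time
advantage = torch.tensor(2.0)
eps = 0.2

ratio = torch.exp(log_prob_new - log_prob_old)             # ~1.65
surr1 = ratio * advantage                                  # unclipped: ~3.30
surr2 = torch.clamp(ratio, 1 - eps, 1 + eps) * advantage   # clipped:    2.40
loss = -torch.min(surr1, surr2)                            # clipping is active here
print(f"ratio={ratio.item():.2f}, loss={loss.item():.2f}")

Because the advantage is positive and the ratio exceeds 1 + ε, the clipped term wins the min: PPO caps how much credit a single favorable update can take.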

The RLHF-PPO Loop

import torch
import torch.nn.functional as F
from torch.optim import AdamW

def rlhf_ppo_training(policy, ref_model, reward_model, prompts, config):
    """
    Full RLHF training loop with PPO.
    """
    value_head = ValueHead(policy.config.hidden_size)
    # Optimize the policy and the value head together; if the value head's
    # parameters are left out, the value loss never updates it.
    optimizer = AdamW(
        list(policy.parameters()) + list(value_head.parameters()), lr=config.lr
    )

    for epoch in range(config.epochs):
        for batch in prompts:
            # === ROLLOUT PHASE ===
            # Generate responses from current policy
            with torch.no_grad():
                responses, log_probs_old = policy.generate_with_logprobs(
                    batch, max_length=config.max_response_length
                )

                # Score with reward model
                rewards = reward_model(batch, responses)

                # Compute KL penalty
                ref_log_probs = ref_model.log_probs(batch, responses)
                kl = log_probs_old - ref_log_probs

                # Combined reward
                modified_rewards = rewards - config.beta * kl

                # Compute advantages and value targets
                values = value_head(policy.get_hidden_states(batch, responses))
                advantages = compute_gae(modified_rewards, values, config.gamma, config.lam)
                returns = advantages + values  # GAE returns: targets for the value head

            # === PPO UPDATE PHASE ===
            for _ in range(config.ppo_epochs):
                # Current policy log probs
                log_probs = policy.log_probs(batch, responses)

                # Probability ratio
                ratio = torch.exp(log_probs - log_probs_old)

                # Clipped objective
                surr1 = ratio * advantages
                surr2 = torch.clamp(ratio, 1 - config.eps, 1 + config.eps) * advantages
                policy_loss = -torch.min(surr1, surr2).mean()

                # Value loss: regress the value head toward the GAE returns
                new_values = value_head(policy.get_hidden_states(batch, responses))
                value_loss = F.mse_loss(new_values, returns)

                # Combined loss
                loss = policy_loss + config.vf_coef * value_loss

                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(policy.parameters(), config.max_grad_norm)
                optimizer.step()

        # Log metrics
        log_metrics(epoch, rewards, kl, policy_loss, value_loss)

    return policy
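
The loop above calls compute_gae without defining it. Here is a minimal single-trajectory sketch, assuming rewards and values are 1-D per-token tensors of the same length and there is nothing to bootstrap past the final token:

import torch

def compute_gae(rewards, values, gamma, lam):
    """Generalized Advantage Estimation over one trajectory of per-token rewards."""
    advantages = torch.zeros_like(rewards)
    last_gae = 0.0
    num_steps = rewards.shape[0]
    for t in reversed(range(num_steps)):
        # The value of the "next state" is zero beyond the end of the response
        next_value = values[t + 1] if t + 1 < num_steps else 0.0
        delta = rewards[t] + gamma * next_value - values[t]
        last_gae = delta + gamma * lam * last_gae
        advantages[t] = last_gae
    return advantages

In practice the advantages are usually also normalized per batch before the PPO update.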

The KL Penalty

The KL divergence penalty keeps the policy close to the reference model.

Why is this critical?

Without the KL penalty, the policy is free to exploit the reward model: it drifts toward degenerate, high-reward outputs (reward hacking), loses the fluency learned during pretraining, and can collapse into repetitive text that the reward model happens to score well.

def compute_kl_penalty(policy, ref_model, input_ids, response_ids):
    """
    KL divergence between policy and reference model.
    """
    # Get log probabilities from both models
    policy_logits = policy(input_ids, response_ids).logits
    ref_logits = ref_model(input_ids, response_ids).logits

    # Convert to log probabilities
    policy_log_probs = F.log_softmax(policy_logits, dim=-1)
    ref_log_probs = F.log_softmax(ref_logits, dim=-1)

    # Gather the log probs for actual tokens
    policy_lp = torch.gather(policy_log_probs, -1, response_ids.unsqueeze(-1)).squeeze(-1)
    ref_lp = torch.gather(ref_log_probs, -1, response_ids.unsqueeze(-1)).squeeze(-1)

    # KL divergence per token, then sum over response
    kl_per_token = policy_lp - ref_lp
    kl_total = kl_per_token.sum(dim=-1)

    return kl_total

# Typical beta values:
# beta = 0.01: Weak penalty, policy can drift far
# beta = 0.1:  Standard, good balance
# beta = 1.0:  Strong penalty, policy stays close to ref

The KL-reward tradeoff:

       High KL penalty (beta=0.5)          Low KL penalty (beta=0.01)
       ─────────────────────────          ─────────────────────────
       - Stable training                  - Aggressive optimization
       - Limited improvement              - Risk of reward hacking
       - Preserves capabilities           - May lose coherence
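
One practical detail behind this tradeoff: the reward model produces a single scalar per response, while the KL penalty is applied token by token. A minimal sketch of how the two are commonly combined into per-token rewards (the function name and tensor shapes here are illustrative assumptions):

import torch

def shape_rewards(rm_score, policy_logprobs, ref_logprobs, beta):
    """
    rm_score: scalar reward-model score for the whole response.
    policy_logprobs, ref_logprobs: per-token log-probs of the sampled response.
    Returns per-token rewards: -beta * KL estimate at every token, with the
    reward-model score added at the final token.
    """
    kl_per_token = policy_logprobs - ref_logprobs
    rewards = -beta * kl_per_token
    rewards[-1] = rewards[-1] + rm_score
    return rewards

# Example with made-up numbers
rewards = shape_rewards(
    rm_score=1.3,
    policy_logprobs=torch.tensor([-1.2, -0.8, -2.1]),
    ref_logprobs=torch.tensor([-1.0, -1.1, -2.0]),
    beta=0.1,
)

Placing the reward-model score on the final token lets the advantage estimator propagate it backwards through the response, while the KL term acts locally at every token.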

Reference Models

The reference model anchors the policy to sensible language.

from transformers import AutoModelForCausalLM

class RLHFTrainer:
    def __init__(self, model_name, config):
        # Policy: the model we're training
        self.policy = AutoModelForCausalLM.from_pretrained(model_name)

        # Reference: frozen copy of initial policy
        self.ref_model = AutoModelForCausalLM.from_pretrained(model_name)
        self.ref_model.eval()
        for param in self.ref_model.parameters():
            param.requires_grad = False  # Never train the reference!

        self.reward_model = RewardModel.from_pretrained(config.rm_path)

Common reference model choices:

Choice                 Pros                Cons
───────────────────    ────────────────    ───────────────
SFT model              Aligned baseline    May have biases
Base pretrained        Neutral anchor      Less aligned
Previous checkpoint    Gradual drift       Complexity
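
For the "previous checkpoint" option, a hypothetical sketch of the extra bookkeeping (the helper name and refresh interval are assumptions, not a standard API):

def maybe_refresh_reference(policy, ref_model, step, refresh_every=2000):
    # Periodically copy the current policy weights into the frozen reference,
    # so the KL anchor follows the policy with a lag instead of staying fixed.
    if step > 0 and step % refresh_every == 0:
        ref_model.load_state_dict(policy.state_dict())
        ref_model.eval()
        for param in ref_model.parameters():
            param.requires_grad = False
    return ref_model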

The TRL Library

Hugging Face's TRL makes RLHF implementation straightforward.

from transformers import AutoTokenizer
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead

# Configuration
config = PPOConfig(
    model_name="gpt2",
    learning_rate=1.41e-5,
    batch_size=128,
    mini_batch_size=32,
    gradient_accumulation_steps=1,
    ppo_epochs=4,
    kl_penalty="kl",  # or "abs" or "mse"
    init_kl_coef=0.2,
    target_kl=6.0,
    cliprange=0.2,
    cliprange_value=0.2,
    gamma=1.0,
    lam=0.95,
)

# Initialize models
model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default

# Initialize trainer
ppo_trainer = PPOTrainer(
    config=config,
    model=model,
    ref_model=ref_model,
    tokenizer=tokenizer,
)

TRL Training Loop

import torch
from tqdm import tqdm

def train_with_trl(ppo_trainer, reward_model, tokenizer, dataset, epochs=10):
    for epoch in range(epochs):
        for batch in tqdm(dataset):
            prompts = batch["prompt"]

            # Tokenize prompts
            query_tensors = [tokenizer.encode(p, return_tensors="pt").squeeze() for p in prompts]

            # Generate responses
            response_tensors = ppo_trainer.generate(
                query_tensors,
                max_new_tokens=128,
                do_sample=True,
                temperature=0.7,
            )

            # Decode for reward model
            responses = [tokenizer.decode(r) for r in response_tensors]

            # Compute rewards
            rewards = [torch.tensor(reward_model.score(p, r)) for p, r in zip(prompts, responses)]

            # PPO step
            stats = ppo_trainer.step(query_tensors, response_tensors, rewards)

            # Log
            print(f"Epoch {epoch} | Reward: {stats['ppo/mean_scores']:.3f} | KL: {stats['objective/kl']:.3f}")

    return ppo_trainer.model

Debugging RLHF

RLHF training is notoriously unstable. Here's how to debug.

Problem 1: Reward collapse

Symptoms: Rewards don't improve or suddenly collapse.

# Diagnosis
def check_reward_distribution(rewards_per_step):
    for step, rewards in enumerate(rewards_per_step):
        mean = rewards.mean()
        std = rewards.std()
        print(f"Step {step}: mean={mean:.3f}, std={std:.3f}")

        # Red flags:
        if std < 0.01:
            print("WARNING: Reward variance collapsed - model may be degenerate")
        if mean < -10:
            print("WARNING: Rewards very negative - KL penalty too high?")

# Fix: Lower beta, check reward model, inspect generated outputs

Problem 2: KL explosion

Symptoms: KL divergence grows unboundedly.

# Diagnosis
def check_kl_trajectory(kl_per_step):
    for step, kl in enumerate(kl_per_step):
        print(f"Step {step}: KL={kl:.3f}")

        if kl > 10:
            print("WARNING: KL too high - policy diverging from reference")

# Fix: Increase beta, reduce learning rate, use adaptive KL

Problem 3: Policy collapse

Symptoms: Model outputs become repetitive or degenerate.

import numpy as np

def check_policy_health(model, prompts):
    responses = [model.generate(p) for p in prompts]

    # Check diversity
    unique_responses = len(set(responses))
    print(f"Unique responses: {unique_responses}/{len(prompts)}")

    # Check length distribution
    lengths = [len(r) for r in responses]
    print(f"Length: mean={np.mean(lengths):.1f}, std={np.std(lengths):.1f}")

    # Check for degenerate patterns
    for r in responses[:5]:
        print(f"  Response: {r[:100]}...")

# Fix: Lower learning rate, increase KL penalty, check reward model

Adaptive KL Controller

TRL includes an adaptive KL controller that adjusts beta during training; a simplified version of the idea looks like this:

class AdaptiveKLController:
    """Simplified sketch of the adaptive KL control used in TRL."""

    def __init__(self, init_kl_coef=0.2, target_kl=6.0, horizon=10000):
        self.kl_coef = init_kl_coef
        self.target_kl = target_kl
        self.horizon = horizon

    def update(self, current_kl, n_steps):
        """
        If KL is too high, increase the penalty.
        If KL is too low, decrease the penalty.
        n_steps: number of samples processed since the last update.
        """
        # Clip the proportional error so one noisy batch can't swing the coefficient wildly
        proportional_error = max(-0.2, min(0.2, current_kl / self.target_kl - 1.0))

        # Multiplicative update, scaled by how much of the horizon this batch covers
        multiplier = 1.0 + proportional_error * (n_steps / self.horizon)
        self.kl_coef = self.kl_coef * multiplier

        # Clamp to a reasonable range
        self.kl_coef = max(0.001, min(10.0, self.kl_coef))

        return self.kl_coef

# Usage in a custom training loop (TRL's PPOTrainer runs its own adaptive
# controller by default, so this is only needed if you manage beta yourself):
kl_controller = AdaptiveKLController(target_kl=6.0)

for step, batch in enumerate(dataloader):
    # ... rollout and reward computation ...
    stats = ppo_trainer.step(queries, responses, rewards)

    # Update the KL coefficient using the number of samples in this batch
    new_beta = kl_controller.update(stats['objective/kl'], config.batch_size)
    # Apply new_beta in your reward shaping: reward - new_beta * kl

Monitoring RLHF Training

Essential metrics to track:

def log_rlhf_metrics(wandb, stats, step):
    wandb.log({
        # Reward metrics
        "reward/mean": stats["ppo/mean_scores"],
        "reward/std": stats["ppo/std_scores"],

        # KL metrics
        "kl/mean": stats["objective/kl"],
        "kl/coef": stats["objective/kl_coef"],

        # Policy metrics
        "policy/entropy": stats["objective/entropy"],
        "policy/approxkl": stats["ppo/policy/approxkl"],
        "policy/clipfrac": stats["ppo/policy/clipfrac"],

        # Value metrics
        "value/loss": stats["ppo/val/loss"],
        "value/explained_var": stats["ppo/val/var_explained"],

        # Learning metrics
        "lr": stats["ppo/learning_rate"],
        "step": step,
    })

Dashboard red flags:

Metric                Healthy Range        Red Flag
──────────────────    ─────────────────    ─────────────────────
KL                    0-10                 >15 (diverging)
Entropy               Decreasing slowly    Drops to 0 (collapse)
Clip fraction         0.1-0.3              >0.5 (updates too big)
Explained variance    >0.5                 <0 (value head broken)
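
These thresholds can double as automated alerts. A minimal sketch reusing the stat keys from the logging function above (the cutoffs are the judgment calls from the table, not TRL defaults):

def check_red_flags(stats):
    alerts = []
    if stats["objective/kl"] > 15:
        alerts.append("KL > 15: policy diverging from reference")
    if stats["objective/entropy"] < 0.1:
        alerts.append("Entropy near zero: possible policy collapse")
    if stats["ppo/policy/clipfrac"] > 0.5:
        alerts.append("Clip fraction > 0.5: updates too large")
    if stats["ppo/val/var_explained"] < 0:
        alerts.append("Explained variance < 0: value head not learning")
    return alerts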

Capstone Connection

RLHF implementation is where sycophancy is baked in or trained out.

Your Milestone 3 intervention will modify this training loop:

def antisycophancy_rlhf(policy, ref_model, rm, prompts, config):
    """
    Modified RLHF that penalizes sycophantic behavior.
    """
    for batch in prompts:
        responses = policy.generate(batch)

        # Standard reward
        rm_reward = rm(batch, responses)

        # === YOUR INTERVENTION ===
        # Add sycophancy penalty
        sycophancy_score = detect_sycophancy(batch, responses)
        honesty_bonus = detect_truthful_correction(batch, responses)

        modified_reward = (
            rm_reward
            - config.syc_penalty * sycophancy_score
            + config.honesty_bonus * honesty_bonus
        )
        # === END INTERVENTION ===

        # Continue with PPO...
        advantages = compute_advantages(modified_reward, values)
        ppo_update(policy, advantages)

# The question: How do you define detect_sycophancy and detect_truthful_correction?

Implementation options for your capstone:

  1. Classifier-based: Train a sycophancy classifier, use as penalty
  2. Rule-based: Detect agreement phrases when the user is wrong (see the sketch after this list)
  3. Contrastive: Compare to "honest" reference responses
  4. Steering: Apply anti-sycophancy steering during generation
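
As a concrete starting point for option 2, here is a minimal, hypothetical sketch; the phrase list and the known_false_claims argument are placeholders you would replace with your own evaluation data:

AGREEMENT_PHRASES = [
    "you're right", "you are right", "great point", "i agree", "absolutely correct",
]

def detect_sycophancy(prompt, response, known_false_claims):
    """Return 1.0 if the response agrees with a claim we know is false, else 0.0."""
    response_lower = response.lower()
    agrees = any(phrase in response_lower for phrase in AGREEMENT_PHRASES)
    user_is_wrong = any(claim in prompt.lower() for claim in known_false_claims)
    return 1.0 if (agrees and user_is_wrong) else 0.0

String matching is brittle; the classifier-based and contrastive options generally catch subtler sycophancy, at the cost of more setup.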

🎓 Tyla's Exercise

  1. PPO uses a clipped objective to prevent large policy updates. Derive the gradient of the clipped objective and show when clipping activates (hint: it depends on the sign of the advantage).

  2. The KL penalty prevents mode collapse but also limits improvement. Show that as beta → ∞ the optimal policy converges to the reference model (no learning), and characterize how the achievable policy improvement shrinks with beta as a function of the reward landscape.

  3. RLHF optimizes E[reward - beta * KL]. This is equivalent to maximizing reward subject to KL < some constant. Find the relationship between beta and this KL constraint (hint: Lagrangian duality).


💻 Aaliyah's Exercise

Build a minimal RLHF training loop from scratch:

class MinimalRLHF:
    def __init__(self, policy, ref_model, reward_model, config):
        self.policy = policy
        self.ref_model = ref_model
        self.rm = reward_model
        self.config = config

    def generate_with_logprobs(self, prompts):
        """
        Generate responses and compute log probabilities.
        Return: responses, log_probs
        """
        pass

    def compute_kl(self, prompts, responses):
        """
        Compute KL(policy || ref) for each response.
        """
        pass

    def compute_advantages(self, rewards, values):
        """
        Generalized Advantage Estimation (GAE).
        """
        pass

    def ppo_step(self, prompts, responses, advantages, old_log_probs):
        """
        Single PPO update step.
        1. Compute current log probs
        2. Compute ratio
        3. Compute clipped objective
        4. Backward pass
        """
        pass

    def train(self, dataset, epochs=10):
        """
        Full training loop with logging.
        Track: reward, KL, policy loss, value loss
        """
        pass

📚 Maneesha's Reflection

  1. RLHF requires extensive hyperparameter tuning (beta, learning rate, clip range, etc.). This makes it expensive and inaccessible to small teams. What are the implications for who gets to decide how AI models are aligned?

  2. The KL penalty keeps the model "close" to its reference. But what if the reference model is itself biased or problematic? Is "don't change too much" always a good objective?

  3. Debugging RLHF is notoriously difficult and requires deep expertise. If alignment techniques are too complex for most practitioners to implement correctly, what does this mean for AI safety as a field?