RLHF Implementation: PPO for Language Models
Theory meets practice: now we train a policy to maximize a reward model learned from human preferences.
PPO for Language Models
Proximal Policy Optimization (PPO) is the workhorse of RLHF.
Why PPO?
- Stable training (unlike vanilla policy gradient)
- Sample efficient (reuses each batch of rollouts for multiple gradient updates)
- Works with discrete actions (tokens)
# The PPO objective:
# L = E[min(r_t * A_t, clip(r_t, 1-ε, 1+ε) * A_t)]
# Where:
# r_t = π(a|s) / π_old(a|s) # probability ratio
# A_t = advantage estimate # how good is this action?
# ε = clip range (typically 0.2)
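To make the objective concrete, here is a minimal PyTorch sketch of the clipped surrogate loss (names like `ppo_clipped_loss` are placeholders, not a library API; the sign is flipped because optimizers minimize):

import torch

def ppo_clipped_loss(log_probs, log_probs_old, advantages, eps=0.2):
    # r_t = pi(a|s) / pi_old(a|s), computed in log space for numerical stability
    ratio = torch.exp(log_probs - log_probs_old)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1 - eps, 1 + eps) * advantages
    # Pessimistic (minimum) surrogate, negated because we minimize the loss
    return -torch.min(unclipped, clipped).mean()

With a positive advantage, clipping caps how much credit the update can get for pushing the ratio above 1 + ε; with a negative advantage, it caps the benefit of pushing the ratio below 1 − ε.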
The RLHF-PPO Loop
import torch
import torch.nn.functional as F
from torch.optim import AdamW

def rlhf_ppo_training(policy, ref_model, reward_model, prompts, config):
    """
    Full RLHF training loop with PPO.
    """
    value_head = ValueHead(policy.config.hidden_size)
    # Optimize the policy and the value head together
    optimizer = AdamW(
        list(policy.parameters()) + list(value_head.parameters()), lr=config.lr
    )
for epoch in range(config.epochs):
for batch in prompts:
# === ROLLOUT PHASE ===
# Generate responses from current policy
with torch.no_grad():
responses, log_probs_old = policy.generate_with_logprobs(
batch, max_length=config.max_response_length
)
# Score with reward model
rewards = reward_model(batch, responses)
# Compute KL penalty
ref_log_probs = ref_model.log_probs(batch, responses)
kl = log_probs_old - ref_log_probs
            # Combined reward: RM score minus the KL penalty
            # (simplification: in practice the scalar RM score is added at the final
            #  response token and the KL penalty is applied per token)
            modified_rewards = rewards - config.beta * kl
            # Compute advantages (values detached: the value head is trained below)
            with torch.no_grad():
                values = value_head(policy.get_hidden_states(batch, responses))
            advantages = compute_gae(modified_rewards, values, config.gamma, config.lam)
            returns = advantages + values  # regression target for the value head
# === PPO UPDATE PHASE ===
for _ in range(config.ppo_epochs):
# Current policy log probs
log_probs = policy.log_probs(batch, responses)
# Probability ratio
ratio = torch.exp(log_probs - log_probs_old)
# Clipped objective
surr1 = ratio * advantages
surr2 = torch.clamp(ratio, 1 - config.eps, 1 + config.eps) * advantages
policy_loss = -torch.min(surr1, surr2).mean()
                # Value loss: regress the value head onto the returns, not the raw rewards
                new_values = value_head(policy.get_hidden_states(batch, responses))
                value_loss = F.mse_loss(new_values, returns)
# Combined loss
loss = policy_loss + config.vf_coef * value_loss
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(policy.parameters(), config.max_grad_norm)
optimizer.step()
# Log metrics
log_metrics(epoch, rewards, kl, policy_loss, value_loss)
return policy
The KL Penalty
The KL divergence penalty keeps the policy close to the reference model.
Why is this critical?
Without KL penalty:
- Policy collapses to degenerate outputs
- Reward hacking runs rampant
- Model loses language capabilities
def compute_kl_penalty(policy, ref_model, input_ids, response_ids):
"""
KL divergence between policy and reference model.
"""
    # Forward the full (prompt + response) sequence through both models.
    # For a causal LM, the logits at position i predict token i+1, so the logits
    # that score the response tokens start at the last prompt position.
    full_ids = torch.cat([input_ids, response_ids], dim=-1)
    prompt_len = input_ids.shape[-1]
    policy_logits = policy(full_ids).logits[:, prompt_len - 1 : -1, :]
    ref_logits = ref_model(full_ids).logits[:, prompt_len - 1 : -1, :]
    # Convert to log probabilities
    policy_log_probs = F.log_softmax(policy_logits, dim=-1)
    ref_log_probs = F.log_softmax(ref_logits, dim=-1)
# Gather the log probs for actual tokens
policy_lp = torch.gather(policy_log_probs, -1, response_ids.unsqueeze(-1)).squeeze(-1)
ref_lp = torch.gather(ref_log_probs, -1, response_ids.unsqueeze(-1)).squeeze(-1)
    # Per-token log-ratio (a single-sample estimate of the KL); sum over the
    # response to get a sequence-level penalty
    kl_per_token = policy_lp - ref_lp
    kl_total = kl_per_token.sum(dim=-1)
return kl_total
# Typical beta values:
# beta = 0.01: Weak penalty, policy can drift far
# beta = 0.1: Standard, good balance
# beta = 1.0: Strong penalty, policy stays close to ref
The KL-reward tradeoff:
| High KL penalty (beta=0.5) | Low KL penalty (beta=0.01) |
|---|---|
| Stable training | Aggressive optimization |
| Limited improvement | Risk of reward hacking |
| Preserves capabilities | May lose coherence |
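A back-of-the-envelope illustration of the tradeoff, with made-up numbers: suppose a candidate update would raise the reward-model score by 2.0 at the cost of 15 nats of KL.

reward_gain = 2.0  # hypothetical improvement in RM score
kl = 15.0          # hypothetical KL from the reference, in nats

for beta in (0.5, 0.1, 0.01):
    print(f"beta={beta}: net objective = {reward_gain - beta * kl:+.2f}")
# beta=0.5:  -5.50  -> penalty dominates, the update is discouraged
# beta=0.1:  +0.50  -> marginal
# beta=0.01: +1.85  -> reward dominates, reward hacking becomes attractive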
Reference Models
The reference model anchors the policy to sensible language.
from transformers import AutoModelForCausalLM

class RLHFTrainer:
    def __init__(self, model_name, config):
        # Policy: the model we're training
        self.policy = AutoModelForCausalLM.from_pretrained(model_name)
# Reference: frozen copy of initial policy
self.ref_model = AutoModelForCausalLM.from_pretrained(model_name)
self.ref_model.eval()
for param in self.ref_model.parameters():
param.requires_grad = False # Never train the reference!
self.reward_model = RewardModel.from_pretrained(config.rm_path)
Common reference model choices:
| Choice | Pros | Cons |
|---|---|---|
| SFT model | Aligned baseline | May have biases |
| Base pretrained | Neutral anchor | Less aligned |
| Previous checkpoint | Allows gradual, controlled drift (see the sketch below) | Extra bookkeeping |
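The "previous checkpoint" row can be implemented by periodically copying the current policy weights into the frozen reference. A minimal sketch, assuming the `policy` and `ref_model` objects from the class above and a hypothetical `refresh_every` interval:

import torch

def maybe_refresh_reference(policy, ref_model, step, refresh_every=2000):
    # Every `refresh_every` PPO steps, re-anchor the KL penalty to the
    # current policy instead of the original SFT model
    if step > 0 and step % refresh_every == 0:
        with torch.no_grad():
            ref_model.load_state_dict(policy.state_dict())
        ref_model.eval()  # keep the reference frozen between refreshes

A moving anchor lets the policy drift further over the whole run, so a tighter per-interval KL target is one way to compensate.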
The TRL Library
Hugging Face's TRL makes RLHF implementation straightforward.
from transformers import AutoTokenizer
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
# Configuration
config = PPOConfig(
model_name="gpt2",
learning_rate=1.41e-5,
batch_size=128,
mini_batch_size=32,
gradient_accumulation_steps=1,
ppo_epochs=4,
kl_penalty="kl", # or "abs" or "mse"
init_kl_coef=0.2,
    target=6.0,  # target KL for the adaptive controller
cliprange=0.2,
cliprange_value=0.2,
gamma=1.0,
lam=0.95,
)
# Initialize models
model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default
# Initialize trainer
ppo_trainer = PPOTrainer(
config=config,
model=model,
ref_model=ref_model,
tokenizer=tokenizer,
)
TRL Training Loop
from tqdm import tqdm
def train_with_trl(ppo_trainer, reward_model, dataset, epochs=10):
for epoch in range(epochs):
for batch in tqdm(dataset):
prompts = batch["prompt"]
# Tokenize prompts
query_tensors = [tokenizer.encode(p, return_tensors="pt").squeeze() for p in prompts]
            # Generate responses (return only the new tokens, not the prompt)
            response_tensors = ppo_trainer.generate(
                query_tensors,
                return_prompt=False,
                max_new_tokens=128,
                do_sample=True,
                temperature=0.7,
            )
# Decode for reward model
responses = [tokenizer.decode(r) for r in response_tensors]
# Compute rewards
rewards = [torch.tensor(reward_model.score(p, r)) for p, r in zip(prompts, responses)]
# PPO step
stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
# Log
print(f"Epoch {epoch} | Reward: {stats['ppo/mean_scores']:.3f} | KL: {stats['objective/kl']:.3f}")
return ppo_trainer.model
Debugging RLHF
RLHF training is notoriously unstable. Here's how to debug.
Problem 1: Reward collapse
Symptoms: Rewards don't improve or suddenly collapse.
# Diagnosis
def check_reward_distribution(rewards_per_step):
for step, rewards in enumerate(rewards_per_step):
mean = rewards.mean()
std = rewards.std()
print(f"Step {step}: mean={mean:.3f}, std={std:.3f}")
# Red flags:
if std < 0.01:
print("WARNING: Reward variance collapsed - model may be degenerate")
if mean < -10:
print("WARNING: Rewards very negative - KL penalty too high?")
# Fix: Lower beta, check reward model, inspect generated outputs
Problem 2: KL explosion
Symptoms: KL divergence grows unboundedly.
# Diagnosis
def check_kl_trajectory(kl_per_step):
for step, kl in enumerate(kl_per_step):
print(f"Step {step}: KL={kl:.3f}")
if kl > 10:
print("WARNING: KL too high - policy diverging from reference")
# Fix: Increase beta, reduce learning rate, use adaptive KL
Problem 3: Policy collapse
Symptoms: Model outputs become repetitive or degenerate.
import numpy as np

def check_policy_health(model, prompts):
    responses = [model.generate(p) for p in prompts]
# Check diversity
unique_responses = len(set(responses))
print(f"Unique responses: {unique_responses}/{len(prompts)}")
# Check length distribution
lengths = [len(r) for r in responses]
print(f"Length: mean={np.mean(lengths):.1f}, std={np.std(lengths):.1f}")
# Check for degenerate patterns
for r in responses[:5]:
print(f" Response: {r[:100]}...")
# Fix: Lower learning rate, increase KL penalty, check reward model
Adaptive KL Controller
TRL includes an adaptive KL controller that adjusts beta during training. A simplified version of the same idea:
class AdaptiveKLController:
def __init__(self, init_kl_coef=0.2, target_kl=6.0, horizon=10000):
self.kl_coef = init_kl_coef
self.target_kl = target_kl
self.horizon = horizon
    def update(self, current_kl, n_steps):
        """
        If KL is too high, increase the penalty; if too low, decrease it.
        n_steps is the number of samples processed since the last update.
        """
        # Clipped proportional error (clipping keeps one noisy batch from
        # swinging the coefficient too hard)
        proportional_error = max(-0.2, min(0.2, current_kl / self.target_kl - 1.0))
        # Multiplicative update, scaled by how much data this batch represents
        multiplier = 1.0 + proportional_error * (n_steps / self.horizon)
        self.kl_coef = self.kl_coef * multiplier
        # Clamp to a reasonable range
        self.kl_coef = max(0.001, min(10.0, self.kl_coef))
        return self.kl_coef
# Usage in a custom training loop (TRL's PPOTrainer runs its own controller
# internally when adaptive KL is enabled; shown here for illustration):
kl_controller = AdaptiveKLController(target_kl=6.0)
for batch in dataloader:
    # ... rollout and reward code ...
    stats = ppo_trainer.step(queries, responses, rewards)
    # Update the KL coefficient using the number of samples in this batch
    beta = kl_controller.update(stats['objective/kl'], n_steps=len(batch))
    # Use `beta` wherever the KL penalty is applied to the rewards
Monitoring RLHF Training
Essential metrics to track:
def log_rlhf_metrics(wandb, stats, step):
wandb.log({
# Reward metrics
"reward/mean": stats["ppo/mean_scores"],
"reward/std": stats["ppo/std_scores"],
# KL metrics
"kl/mean": stats["objective/kl"],
"kl/coef": stats["objective/kl_coef"],
# Policy metrics
"policy/entropy": stats["objective/entropy"],
"policy/approxkl": stats["ppo/policy/approxkl"],
"policy/clipfrac": stats["ppo/policy/clipfrac"],
# Value metrics
"value/loss": stats["ppo/val/loss"],
"value/explained_var": stats["ppo/val/var_explained"],
# Learning metrics
"lr": stats["ppo/learning_rate"],
"step": step,
})
Dashboard red flags:
| Metric | Healthy Range | Red Flag |
|---|---|---|
| KL | 0-10 | >15 (diverging) |
| Entropy | Decreasing slowly | Drops to 0 (collapse) |
| Clip fraction | 0.1-0.3 | >0.5 (updates too big) |
| Explained variance | >0.5 | <0 (value head broken) |
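The red-flag table can be turned into automated warnings. A small helper, assuming the same TRL stat keys used in `log_rlhf_metrics` above (the 0.1 entropy threshold is an arbitrary stand-in for "drops to 0"):

def check_red_flags(stats):
    """Print a warning for each dashboard red flag from the table above."""
    if stats["objective/kl"] > 15:
        print("RED FLAG: KL > 15 - policy diverging from reference")
    if stats["objective/entropy"] < 0.1:
        print("RED FLAG: entropy near zero - possible policy collapse")
    if stats["ppo/policy/clipfrac"] > 0.5:
        print("RED FLAG: clip fraction > 0.5 - updates too big")
    if stats["ppo/val/var_explained"] < 0:
        print("RED FLAG: explained variance < 0 - value head broken")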
Capstone Connection
RLHF implementation is where sycophancy is baked in or trained out.
Your Milestone 3 intervention will modify this training loop:
def antisycophancy_rlhf(policy, ref_model, rm, prompts, config):
"""
Modified RLHF that penalizes sycophantic behavior.
"""
for batch in prompts:
responses = policy.generate(batch)
# Standard reward
rm_reward = rm(batch, responses)
# === YOUR INTERVENTION ===
# Add sycophancy penalty
sycophancy_score = detect_sycophancy(batch, responses)
honesty_bonus = detect_truthful_correction(batch, responses)
        modified_reward = (
            rm_reward
            - config.syc_penalty * sycophancy_score
            + config.honesty_bonus * honesty_bonus
        )
# === END INTERVENTION ===
# Continue with PPO...
advantages = compute_advantages(modified_reward, values)
ppo_update(policy, advantages)
# The question: How do you define detect_sycophancy and detect_truthful_correction?
Implementation options for your capstone:
- Classifier-based: Train a sycophancy classifier, use as penalty
- Rule-based: Detect agreement phrases when the user is wrong (a crude sketch follows this list)
- Contrastive: Compare to "honest" reference responses
- Steering: Apply anti-sycophancy steering during generation
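As a starting point for the rule-based option, here is a deliberately crude sketch. It scores a single prompt/response pair (the capstone loop above operates on batches), and `known_false_claims` is a hypothetical list you would have to supply; a trained classifier would replace this in a serious intervention.

AGREEMENT_PHRASES = [
    "you're absolutely right", "great point", "i completely agree", "that's correct",
]
CORRECTION_PHRASES = [
    "actually,", "that's not quite right", "i have to disagree", "in fact,",
]

def detect_sycophancy(prompt, response, known_false_claims):
    """Return 1.0 if the response agrees with a claim known to be false, else 0.0."""
    has_false_claim = any(c in prompt.lower() for c in known_false_claims)
    agrees = any(p in response.lower() for p in AGREEMENT_PHRASES)
    return float(has_false_claim and agrees)

def detect_truthful_correction(prompt, response, known_false_claims):
    """Return 1.0 if the response pushes back on a known-false claim, else 0.0."""
    has_false_claim = any(c in prompt.lower() for c in known_false_claims)
    corrects = any(p in response.lower() for p in CORRECTION_PHRASES)
    return float(has_false_claim and corrects)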
🎓 Tyla's Exercise
PPO uses a clipped objective to prevent large policy updates. Derive the gradient of the clipped objective and show when clipping activates (hint: it depends on the sign of the advantage).
The KL penalty prevents mode collapse but also limits improvement. Prove that as beta → ∞ the optimal policy converges to the reference model (no learning), and characterize how quickly learning shuts off as a function of beta and the reward landscape. (Hint: the KL-regularized optimum has the closed form π*(y|x) ∝ π_ref(y|x) exp(r(x,y)/β).)
RLHF optimizes E[reward - beta * KL]. This is equivalent to maximizing reward subject to KL < some constant. Find the relationship between beta and this KL constraint (hint: Lagrangian duality).
💻 Aaliyah's Exercise
Build a minimal RLHF training loop from scratch:
class MinimalRLHF:
def __init__(self, policy, ref_model, reward_model, config):
self.policy = policy
self.ref_model = ref_model
self.rm = reward_model
self.config = config
def generate_with_logprobs(self, prompts):
"""
Generate responses and compute log probabilities.
Return: responses, log_probs
"""
pass
def compute_kl(self, prompts, responses):
"""
Compute KL(policy || ref) for each response.
"""
pass
def compute_advantages(self, rewards, values):
"""
Generalized Advantage Estimation (GAE).
"""
pass
def ppo_step(self, prompts, responses, advantages, old_log_probs):
"""
Single PPO update step.
1. Compute current log probs
2. Compute ratio
3. Compute clipped objective
4. Backward pass
"""
pass
def train(self, dataset, epochs=10):
"""
Full training loop with logging.
Track: reward, KL, policy loss, value loss
"""
pass
📚 Maneesha's Reflection
RLHF requires extensive hyperparameter tuning (beta, learning rate, clip range, etc.). This makes it expensive and inaccessible to small teams. What are the implications for who gets to decide how AI models are aligned?
The KL penalty keeps the model "close" to its reference. But what if the reference model is itself biased or problematic? Is "don't change too much" always a good objective?
Debugging RLHF is notoriously difficult and requires deep expertise. If alignment techniques are too complex for most practitioners to implement correctly, what does this mean for AI safety as a field?