Actor-Critic Methods: The Best of Both Worlds
REINFORCE waits until the episode ends. Actor-Critic learns at every step.
The Actor-Critic Idea
Combine policy gradients with value function learning:
- Actor: The policy pi(a|s) - decides what to do
- Critic: The value function V(s) or Q(s,a) - evaluates how good decisions are
# REINFORCE: Wait for episode to end, use actual returns
gradient = log_prob * (actual_return - baseline)
# Actor-Critic: Learn at each step, use estimated returns
gradient = log_prob * (estimated_advantage)
# The critic provides the advantage estimate
# No need to wait for the episode to finish!
Why Actor-Critic?
REINFORCE problems:
- Must wait until episode ends (can't learn mid-episode)
- High variance (returns fluctuate wildly)
- No value estimate to guide exploration
Actor-Critic solutions:
- Learn after every step (or every few steps)
- Critic provides lower-variance estimates
- Value function helps with credit assignment
The tradeoff:
- Introduces bias (critic's estimates aren't perfect)
- But greatly reduces variance (a small sketch below makes this concrete)
- Net effect: faster, more stable learning
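To make the variance claim concrete, here is a minimal synthetic sketch (the reward noise, horizon, and discount are invented for illustration, not taken from any environment). It compares the spread of a full Monte Carlo return against a one-step TD target that bootstraps from a fixed value estimate:
import numpy as np

rng = np.random.default_rng(0)
gamma, horizon, trials = 0.99, 100, 10_000

# Monte Carlo return: accumulates reward noise from every remaining step
mc_returns = np.array([
    sum(gamma**k * (1.0 + rng.normal()) for k in range(horizon))
    for _ in range(trials)
])

# One-step TD target: a single noisy reward plus a fixed (noise-free) value estimate
v_estimate = sum(gamma**k for k in range(horizon - 1))  # "true" value of the next state
td_targets = 1.0 + rng.normal(size=trials) + gamma * v_estimate

print(f"MC return std: {mc_returns.std():.2f}")   # roughly 6-7: noise from ~100 steps
print(f"TD target std: {td_targets.std():.2f}")   # roughly 1: noise from a single step
Both estimators target the same expected return; the TD target trades a little bias (the value estimate may be wrong) for a large drop in variance.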
The Advantage Function
The advantage tells us: "How much better is this action than average?"
# Value function V(s): Expected return starting from state s
# Q function Q(s,a): Expected return taking action a in state s
#
# Advantage A(s,a) = Q(s,a) - V(s)
#
# Intuition:
# - A > 0: This action is better than average for this state
# - A < 0: This action is worse than average
# - A = 0: This action is exactly as good as average
# Shared imports for the code examples in this lesson
import torch
import torch.nn as nn
import torch.nn.functional as F

def compute_advantage(states, actions, rewards, next_states, dones, critic, gamma=0.99):
    """
    Compute advantages using a TD(0) estimate.
    """
    # V(s) from critic (squeeze the trailing dim so shapes broadcast with rewards)
    values = critic(states).squeeze(-1)
    # V(s') from critic (0 if terminal)
    next_values = critic(next_states).squeeze(-1)
    next_values = next_values * (1 - dones)
    # TD target: r + gamma * V(s')
    # This is our estimate of Q(s,a)
    td_targets = rewards + gamma * next_values
    # Advantage = Q(s,a) - V(s) ≈ TD_target - V(s)
    # Detach so the result can feed the actor loss without leaking critic gradients
    advantages = (td_targets - values).detach()
    return advantages
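A quick usage sketch (the critic architecture and the batch values are made up just to exercise the helper above):
critic = nn.Sequential(nn.Linear(3, 32), nn.ReLU(), nn.Linear(32, 1))

states      = torch.randn(4, 3)
next_states = torch.randn(4, 3)
actions     = torch.tensor([0, 1, 1, 0])
rewards     = torch.tensor([1.0, 0.0, 1.0, 1.0])
dones       = torch.tensor([0.0, 0.0, 0.0, 1.0])

advantages = compute_advantage(states, actions, rewards, next_states, dones, critic)
print(advantages.shape)  # torch.Size([4]) -- one advantage per transition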
One-Step Actor-Critic
The simplest actor-critic algorithm:
class ActorCritic:
def __init__(self, state_dim, action_dim, hidden_dim=64, lr=3e-4, gamma=0.99):
# Actor (policy network)
self.actor = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, action_dim),
nn.Softmax(dim=-1)
)
# Critic (value network)
self.critic = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 1)
)
self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=lr)
self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=lr)
self.gamma = gamma
def select_action(self, state):
state = torch.FloatTensor(state).unsqueeze(0)
probs = self.actor(state)
dist = torch.distributions.Categorical(probs)
action = dist.sample()
return action.item(), dist.log_prob(action)
def update(self, state, action, reward, next_state, done, log_prob):
"""
One-step TD update for both actor and critic.
"""
state = torch.FloatTensor(state).unsqueeze(0)
next_state = torch.FloatTensor(next_state).unsqueeze(0)
reward = torch.FloatTensor([reward])
done = torch.FloatTensor([done])
# Critic update
value = self.critic(state)
next_value = self.critic(next_state) * (1 - done)
td_target = reward + self.gamma * next_value.detach()
td_error = td_target - value
critic_loss = td_error.pow(2)
self.critic_optimizer.zero_grad()
critic_loss.backward()
self.critic_optimizer.step()
        # Actor update (use the TD error as the advantage estimate)
        # Reuse the TD error computed above; detach it so no critic gradients reach the actor
        advantage = td_error.detach()
actor_loss = -log_prob * advantage
self.actor_optimizer.zero_grad()
actor_loss.backward()
self.actor_optimizer.step()
return actor_loss.item(), critic_loss.item()
Training Loop for Actor-Critic
def train_actor_critic(env, agent, num_episodes=1000):
"""Train actor-critic with per-step updates."""
episode_rewards = []
for episode in range(num_episodes):
state = env.reset()
episode_reward = 0
done = False
while not done:
# Act
action, log_prob = agent.select_action(state)
next_state, reward, done, _ = env.step(action)
# Learn immediately!
agent.update(state, action, reward, next_state, done, log_prob)
state = next_state
episode_reward += reward
episode_rewards.append(episode_reward)
if episode % 100 == 0:
avg = sum(episode_rewards[-100:]) / min(100, len(episode_rewards))
print(f"Episode {episode}, Avg Reward: {avg:.2f}")
return episode_rewards
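A usage sketch, assuming a classic Gym-style CartPole environment where env.step() returns four values as in the loop above (newer Gymnasium releases return five values and a tuple from reset(), so they would need a small wrapper):
import gym

env = gym.make("CartPole-v1")
agent = ActorCritic(state_dim=env.observation_space.shape[0],
                    action_dim=env.action_space.n)
rewards = train_actor_critic(env, agent, num_episodes=500)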
A2C: Advantage Actor-Critic
A2C adds several improvements over basic actor-critic:
class A2C:
def __init__(self, state_dim, action_dim, hidden_dim=256, lr=7e-4,
gamma=0.99, entropy_coef=0.01, value_coef=0.5, n_steps=5):
# Shared feature extractor (optional but common)
self.features = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
nn.ReLU(),
)
# Actor head
self.actor_head = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, action_dim),
)
# Critic head
self.critic_head = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 1),
)
self.optimizer = torch.optim.Adam(self.parameters(), lr=lr)
self.gamma = gamma
self.entropy_coef = entropy_coef
self.value_coef = value_coef
self.n_steps = n_steps
def parameters(self):
return list(self.features.parameters()) + \
list(self.actor_head.parameters()) + \
list(self.critic_head.parameters())
def forward(self, state):
features = self.features(state)
action_logits = self.actor_head(features)
value = self.critic_head(features)
return action_logits, value
def select_action(self, state):
state = torch.FloatTensor(state).unsqueeze(0)
logits, value = self.forward(state)
dist = torch.distributions.Categorical(logits=logits)
action = dist.sample()
return action.item(), dist.log_prob(action), dist.entropy(), value
def compute_returns_and_advantages(self, rewards, values, dones, next_value):
"""
Compute n-step returns and advantages.
"""
returns = []
advantages = []
# Bootstrap from final value
R = next_value
for t in reversed(range(len(rewards))):
R = rewards[t] + self.gamma * R * (1 - dones[t])
returns.insert(0, R)
advantages.insert(0, R - values[t])
        # Flatten to 1-D so shapes line up in the losses below
        returns = torch.stack(returns).reshape(-1)
        advantages = torch.stack(advantages).reshape(-1)
# Normalize advantages
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
return returns, advantages
def update(self, states, actions, rewards, dones, log_probs, values, entropies, next_value):
"""
A2C update after n steps.
"""
returns, advantages = self.compute_returns_and_advantages(
rewards, values, dones, next_value
)
# Policy loss (negative because we want to maximize)
        policy_loss = -(torch.stack(log_probs).reshape(-1) * advantages.detach()).mean()
# Value loss
        value_loss = F.mse_loss(torch.stack(values).reshape(-1), returns.detach())
# Entropy bonus (encourages exploration)
entropy_loss = -torch.stack(entropies).mean()
# Combined loss
loss = policy_loss + self.value_coef * value_loss + self.entropy_coef * entropy_loss
self.optimizer.zero_grad()
loss.backward()
# Gradient clipping (helps stability)
torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=0.5)
self.optimizer.step()
return {
'policy_loss': policy_loss.item(),
'value_loss': value_loss.item(),
'entropy': -entropy_loss.item()
}
N-Step Returns
Instead of one-step TD, use n-step returns for a better bias-variance tradeoff:
# 1-step return (high bias, low variance):
# G_t = r_t + gamma * V(s_{t+1})
# Monte Carlo return (low bias, high variance):
# G_t = r_t + gamma*r_{t+1} + gamma^2*r_{t+2} + ... (full episode)
# N-step return (balanced):
# G_t = r_t + gamma*r_{t+1} + ... + gamma^{n-1}*r_{t+n-1} + gamma^n * V(s_{t+n})
def compute_n_step_return(rewards, values, gamma, n_steps):
"""
Compute n-step returns with bootstrapping.
"""
returns = []
T = len(rewards)
for t in range(T):
G = 0
for k in range(min(n_steps, T - t)):
G += (gamma ** k) * rewards[t + k]
# Bootstrap from value function if not at episode end
if t + n_steps < T:
G += (gamma ** n_steps) * values[t + n_steps]
returns.append(G)
return torch.tensor(returns)
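A quick sanity check with toy numbers (made up for illustration): with n_steps=1, each return reduces to the one-step TD target r_t + gamma * V(s_{t+1}), except at the final step where there is nothing left to bootstrap from:
rewards = [1.0, 1.0, 1.0, 0.0]
values  = [0.5, 0.4, 0.3, 0.2]
print(compute_n_step_return(rewards, values, gamma=0.99, n_steps=1))
# ≈ tensor([1.3960, 1.2970, 1.1980, 0.0000])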
Generalized Advantage Estimation (GAE)
GAE smoothly interpolates between different n-step estimates:
def compute_gae(rewards, values, dones, next_value, gamma=0.99, gae_lambda=0.95):
"""
Generalized Advantage Estimation (GAE).
gae_lambda controls bias-variance tradeoff:
- lambda = 0: One-step TD (high bias, low variance)
- lambda = 1: Monte Carlo (low bias, high variance)
- lambda = 0.95: Good default (balanced)
"""
advantages = []
gae = 0
# Add next_value for bootstrapping
values = list(values) + [next_value]
for t in reversed(range(len(rewards))):
# TD error at step t
delta = rewards[t] + gamma * values[t + 1] * (1 - dones[t]) - values[t]
# GAE: exponentially weighted sum of TD errors
gae = delta + gamma * gae_lambda * (1 - dones[t]) * gae
advantages.insert(0, gae)
return torch.tensor(advantages)
# GAE intuition:
# advantage_t = delta_t + (gamma*lambda)*delta_{t+1} + (gamma*lambda)^2*delta_{t+2} + ...
#
# delta_t = r_t + gamma*V(s_{t+1}) - V(s_t) [one-step TD error]
#
# The lambda decay means:
# - Recent TD errors matter most
# - Distant TD errors are downweighted
# - This reduces variance while maintaining low bias
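A small numerical check (toy values, invented for illustration) that lambda = 0 collapses GAE to the one-step TD errors, which is the first half of the interpolation claim from Tyla's exercise:
rewards = [1.0, 0.5, 2.0]
values  = [0.9, 0.8, 1.1]
dones   = [0, 0, 0]
next_value = 0.7

gae_0 = compute_gae(rewards, values, dones, next_value, gamma=0.99, gae_lambda=0.0)

# One-step TD errors computed by hand: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
vs = values + [next_value]
deltas = [rewards[t] + 0.99 * vs[t + 1] - values[t] for t in range(len(rewards))]

print(gae_0)    # ≈ tensor([0.8920, 0.7890, 1.5930])
print(deltas)   # ≈ [0.892, 0.789, 1.593] -- the same numbers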
A2C with GAE
class A2CWithGAE(A2C):
def __init__(self, *args, gae_lambda=0.95, **kwargs):
super().__init__(*args, **kwargs)
self.gae_lambda = gae_lambda
def compute_returns_and_advantages(self, rewards, values, dones, next_value):
"""Use GAE instead of simple n-step returns."""
# GAE for advantages
advantages = compute_gae(
rewards, values, dones, next_value,
gamma=self.gamma, gae_lambda=self.gae_lambda
)
# Returns = advantages + values
returns = advantages + torch.tensor(values)
# Normalize advantages
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
return returns, advantages
Continuous Action Spaces
For continuous actions, we output distribution parameters instead of discrete probabilities:
class ContinuousActorCritic(nn.Module):
def __init__(self, state_dim, action_dim, hidden_dim=256):
super().__init__()
# Shared features
self.features = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
)
# Actor outputs mean and log_std for Gaussian policy
self.actor_mean = nn.Linear(hidden_dim, action_dim)
self.actor_log_std = nn.Parameter(torch.zeros(action_dim))
# Critic outputs value
self.critic = nn.Linear(hidden_dim, 1)
def forward(self, state):
features = self.features(state)
return features
def get_action_distribution(self, state):
"""Return a Gaussian distribution over actions."""
features = self.forward(state)
mean = self.actor_mean(features)
# Clamp log_std for stability
log_std = self.actor_log_std.clamp(-20, 2)
std = log_std.exp()
return torch.distributions.Normal(mean, std)
def get_value(self, state):
features = self.forward(state)
return self.critic(features)
def select_action(self, state, deterministic=False):
"""Sample action from Gaussian policy."""
state = torch.FloatTensor(state).unsqueeze(0)
dist = self.get_action_distribution(state)
if deterministic:
action = dist.mean
else:
            action = dist.rsample()  # Reparameterized sample (keeps the graph if needed)
        log_prob = dist.log_prob(action).sum(dim=-1)
        value = self.get_value(state)
        # Detach before converting to numpy; the environment only needs raw numbers
        return action.squeeze(0).detach().numpy(), log_prob, value
Continuous Action Training
def train_continuous_a2c(env, agent, num_updates=1000, n_steps=2048):
    """
    Train continuous A2C by collecting fixed-length rollouts
    (similar to PPO's data collection). Assumes the agent exposes
    select_action, get_value, and an A2C-style update method.
    """
    state = env.reset()
    episode_rewards = []
    current_episode_reward = 0
    for update in range(num_updates):
# Collect n_steps of experience
states, actions, rewards, dones = [], [], [], []
log_probs, values = [], []
for _ in range(n_steps):
action, log_prob, value = agent.select_action(state)
next_state, reward, done, _ = env.step(action)
states.append(state)
actions.append(action)
rewards.append(reward)
dones.append(done)
log_probs.append(log_prob)
values.append(value)
current_episode_reward += reward
state = next_state
if done:
episode_rewards.append(current_episode_reward)
current_episode_reward = 0
state = env.reset()
        # Compute V(s) of the final state for bootstrapping
        with torch.no_grad():
            next_value = agent.get_value(torch.FloatTensor(state).unsqueeze(0))
# Update
agent.update(states, actions, rewards, dones, log_probs, values, next_value)
if update % 10 == 0 and episode_rewards:
print(f"Update {update}, Avg Reward: {sum(episode_rewards[-10:]) / len(episode_rewards[-10:]):.2f}")
return episode_rewards
Squashing for Bounded Actions
Many environments have bounded action spaces. Use tanh squashing:
class SquashedGaussianPolicy(nn.Module):
"""
Gaussian policy with tanh squashing for bounded actions.
Used in SAC and other continuous control algorithms.
"""
def __init__(self, state_dim, action_dim, hidden_dim=256, action_limit=1.0):
super().__init__()
self.action_limit = action_limit
self.net = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
)
self.mean = nn.Linear(hidden_dim, action_dim)
self.log_std = nn.Linear(hidden_dim, action_dim)
def forward(self, state):
x = self.net(state)
mean = self.mean(x)
log_std = self.log_std(x).clamp(-20, 2)
return mean, log_std
def sample(self, state):
"""Sample action with log probability correction for squashing."""
mean, log_std = self.forward(state)
std = log_std.exp()
# Sample from Gaussian
dist = torch.distributions.Normal(mean, std)
        u = dist.rsample()  # Unsquashed action
        # Squash through tanh, then scale to the action limit
        squashed = torch.tanh(u)
        action = squashed * self.action_limit
        # Correct log probability for the squashing:
        # log pi(a) = log pi(u) - sum(log(1 - tanh^2(u)))
        # (the constant log(action_limit) term is dropped, as is standard)
        log_prob = dist.log_prob(u) - torch.log(1 - squashed.pow(2) + 1e-6)
        log_prob = log_prob.sum(dim=-1)
        return action, log_prob
def deterministic_action(self, state):
mean, _ = self.forward(state)
return torch.tanh(mean) * self.action_limit
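A usage sketch (the state dimension, action dimension, and limit are made up): sample a batch of actions and confirm the tanh squashing keeps every action inside the bound:
policy = SquashedGaussianPolicy(state_dim=3, action_dim=2, action_limit=2.0)
states = torch.randn(8, 3)

actions, log_probs = policy.sample(states)
print(actions.shape, log_probs.shape)        # torch.Size([8, 2]) torch.Size([8])
print(bool((actions.abs() <= 2.0).all()))    # True: every action respects the limit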
Entropy Regularization
Entropy bonus prevents premature convergence to deterministic policies:
def compute_entropy(distribution):
"""
Compute entropy of action distribution.
Discrete: H = -sum(p * log(p))
Gaussian: H = 0.5 * log(2 * pi * e * sigma^2) per dimension
"""
if isinstance(distribution, torch.distributions.Categorical):
return distribution.entropy()
elif isinstance(distribution, torch.distributions.Normal):
return distribution.entropy().sum(dim=-1)
else:
raise ValueError(f"Unknown distribution type: {type(distribution)}")
# In the loss function:
# total_loss = policy_loss + value_coef * value_loss - entropy_coef * entropy
#
# Subtracting entropy (with positive coef) MAXIMIZES entropy
# This encourages exploration by preferring uncertain policies
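For example (probabilities invented for illustration), a near-deterministic policy has far lower entropy than a uniform one, so the bonus pushes the agent away from committing too early:
peaked  = torch.distributions.Categorical(probs=torch.tensor([0.97, 0.01, 0.01, 0.01]))
uniform = torch.distributions.Categorical(probs=torch.tensor([0.25, 0.25, 0.25, 0.25]))

print(compute_entropy(peaked))    # ~0.17 nats: almost no exploration left
print(compute_entropy(uniform))   # ~1.39 nats (= log 4): maximum exploration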
Capstone Connection
Actor-Critic is the core of RLHF algorithms like PPO.
# PPO (used in ChatGPT training) is essentially A2C with:
# 1. Clipped surrogate objective (prevents too-large updates)
# 2. Multiple epochs per batch (better sample efficiency)
# 3. GAE for advantage estimation
def ppo_policy_loss(old_log_probs, new_log_probs, advantages, clip_epsilon=0.2):
"""
PPO's clipped surrogate objective.
"""
ratio = (new_log_probs - old_log_probs).exp()
# Clipped and unclipped objective
obj_unclipped = ratio * advantages
obj_clipped = ratio.clamp(1 - clip_epsilon, 1 + clip_epsilon) * advantages
# Take minimum (pessimistic bound)
return -torch.min(obj_unclipped, obj_clipped).mean()
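A toy illustration (numbers invented): once the probability ratio drifts outside [1 - eps, 1 + eps], the clipped term caps the objective, so the update has no incentive to push that ratio further:
old_log_probs = torch.zeros(4)
new_log_probs = torch.log(torch.tensor([0.5, 0.9, 1.1, 2.0]))  # ratios 0.5, 0.9, 1.1, 2.0
advantages    = torch.tensor([1.0, 1.0, -1.0, 1.0])

loss = ppo_policy_loss(old_log_probs, new_log_probs, advantages, clip_epsilon=0.2)
print(loss)  # ≈ tensor(-0.3750): the ratio-2.0 sample only contributes 1.2 * advantage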
# RLHF training loop sketch:
def rlhf_training_step(policy, ref_policy, reward_model, prompts, beta=0.1):
    """
    One step of RLHF training (high-level sketch; compute_kl and ppo_update
    stand in for the full machinery).
    """
# Generate responses with current policy
responses = policy.generate(prompts)
# Score with reward model (the "critic" in some sense)
rewards = reward_model(prompts, responses)
# KL penalty to stay close to reference (prevents reward hacking)
kl = compute_kl(policy, ref_policy, prompts, responses)
rewards = rewards - beta * kl
# PPO update (actor-critic under the hood)
ppo_update(policy, prompts, responses, rewards)
For sycophancy evaluation:
Actor-critic methods reveal why RLHF can amplify sycophancy:
- The reward model (critic) may have learned to give high scores to agreeable responses
- The policy (actor) learns to maximize these scores
- Advantage = "this agreeable response is better than average" - even when wrong
- No mechanism to distinguish helpful agreement from harmful sycophancy
Understanding actor-critic = understanding RLHF's failure modes.
🎓 Tyla's Exercise
Bias-variance analysis: TD(0) has high bias, Monte Carlo has high variance. Derive the bias and variance of n-step returns as a function of n.
Prove GAE interpolation: Show that GAE with lambda=0 gives TD(0) advantage, and lambda=1 gives Monte Carlo advantage.
Log probability correction: For squashed Gaussian policies, derive why we need to subtract log(1 - tanh^2(u)) from the log probability.
Optimal critic: If the critic perfectly estimates V(s), what is the variance of the policy gradient? Compare to REINFORCE.
💻 Aaliyah's Exercise
Build a complete A2C implementation with continuous actions:
def build_continuous_a2c():
"""
Implement A2C for continuous control with:
1. Gaussian policy with learnable std
2. GAE advantage estimation
3. Entropy regularization
4. Gradient clipping
5. Proper handling of episode boundaries
"""
pass
def train_on_pendulum():
"""
Train on Pendulum-v1:
1. Implement the agent
2. Train for 500 episodes
3. Plot learning curve
4. Visualize learned policy (action vs state)
5. Compare different GAE lambda values
"""
pass
def ablation_study():
"""
Compare:
- A2C vs REINFORCE
- Different n_steps values
- With/without entropy bonus
- Different GAE lambda values
Report sample efficiency and final performance.
"""
pass
📚 Maneesha's Reflection
Actor and critic as teacher and student: The critic evaluates, the actor performs. How is this similar to teacher-student relationships? What happens when the "critic" gives biased feedback?
Advantage vs absolute reward: Actor-critic cares about "better than expected," not "good or bad." How does this relative framing affect learning? When might absolute feedback be more useful than relative feedback?
Entropy and exploration: Entropy regularization encourages the agent to maintain uncertainty. How does this relate to the "beginner's mind" concept in learning theory? When should learners be encouraged to stay uncertain?
RLHF and human values: If RLHF uses actor-critic to optimize for human feedback, what happens when human feedback is inconsistent? How might this inform how we design feedback systems for human learners?
The credit assignment problem: Actor-critic improves credit assignment through bootstrapping. How do human teachers help learners understand which actions led to which outcomes? What can we learn from actor-critic about effective feedback timing?