Actor-Critic Methods: The Best of Both Worlds

REINFORCE waits until the episode ends. Actor-Critic learns at every step.


The Actor-Critic Idea

Combine policy gradients with value function learning:

# REINFORCE: Wait for episode to end, use actual returns
gradient = log_prob * (actual_return - baseline)

# Actor-Critic: Learn at each step, use estimated returns
gradient = log_prob * (estimated_advantage)

# The critic provides the advantage estimate
# No need to wait for the episode to finish!

Why Actor-Critic?

REINFORCE problems:

  1. Must wait until episode ends (can't learn mid-episode)
  2. High variance (returns fluctuate wildly)
  3. No value estimate to guide exploration

Actor-Critic solutions:

  1. Learn after every step (or every few steps)
  2. Critic provides lower-variance estimates
  3. Value function helps with credit assignment

The tradeoff: bootstrapping from a learned critic cuts variance and allows online updates, but it introduces bias, because the critic's value estimates are only approximations (and are poor early in training).


The Advantage Function

The advantage tells us: "How much better is this action than average?"

# Value function V(s): Expected return starting from state s
# Q function Q(s,a): Expected return taking action a in state s
#
# Advantage A(s,a) = Q(s,a) - V(s)
#
# Intuition:
# - A > 0: This action is better than average for this state
# - A < 0: This action is worse than average
# - A = 0: This action is exactly as good as average

def compute_advantage(states, actions, rewards, next_states, dones, critic, gamma=0.99):
    """
    Compute advantage using TD(0) estimate.
    """
    # V(s) from critic
    values = critic(states)

    # V(s') from critic (0 if terminal)
    next_values = critic(next_states)
    next_values = next_values * (1 - dones)

    # TD target: r + gamma * V(s')
    # This is our estimate of Q(s,a)
    td_targets = rewards + gamma * next_values

    # Advantage = Q(s,a) - V(s) ≈ TD_target - V(s)
    advantages = td_targets - values

    return advantages
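
A quick shape check with dummy data (the linear stand-in critic and the batch size below are illustrative, not part of the algorithm):

import torch
import torch.nn as nn

critic = nn.Linear(4, 1)             # stand-in critic: V(s) for 4-dim states
states      = torch.randn(8, 4)      # batch of 8 transitions
next_states = torch.randn(8, 4)
rewards     = torch.randn(8, 1)
dones       = torch.zeros(8, 1)      # float mask: 1.0 marks terminal transitions

# actions are unused by this helper, so None is passed
advantages = compute_advantage(states, None, rewards, next_states, dones, critic)
print(advantages.shape)  # torch.Size([8, 1])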

One-Step Actor-Critic

The simplest actor-critic algorithm:

import torch
import torch.nn as nn

class ActorCritic:
    def __init__(self, state_dim, action_dim, hidden_dim=64, lr=3e-4, gamma=0.99):
        # Actor (policy network)
        self.actor = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
            nn.Softmax(dim=-1)
        )

        # Critic (value network)
        self.critic = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=lr)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=lr)
        self.gamma = gamma

    def select_action(self, state):
        state = torch.FloatTensor(state).unsqueeze(0)
        probs = self.actor(state)
        dist = torch.distributions.Categorical(probs)
        action = dist.sample()
        return action.item(), dist.log_prob(action)

    def update(self, state, action, reward, next_state, done, log_prob):
        """
        One-step TD update for both actor and critic.
        """
        state = torch.FloatTensor(state).unsqueeze(0)
        next_state = torch.FloatTensor(next_state).unsqueeze(0)
        reward = torch.FloatTensor([reward])
        done = torch.FloatTensor([done])

        # Critic update
        value = self.critic(state)
        next_value = self.critic(next_state) * (1 - done)
        td_target = reward + self.gamma * next_value.detach()
        td_error = td_target - value

        critic_loss = td_error.pow(2)

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # Actor update: use the TD error (computed before the critic step) as the advantage.
        # Detach so no gradient flows from the actor loss into the critic.
        advantage = td_error.detach()
        actor_loss = -log_prob * advantage

        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        return actor_loss.item(), critic_loss.item()

Training Loop for Actor-Critic

def train_actor_critic(env, agent, num_episodes=1000):
    """Train actor-critic with per-step updates."""
    episode_rewards = []

    for episode in range(num_episodes):
        state = env.reset()
        episode_reward = 0
        done = False

        while not done:
            # Act
            action, log_prob = agent.select_action(state)
            next_state, reward, done, _ = env.step(action)

            # Learn immediately!
            agent.update(state, action, reward, next_state, done, log_prob)

            state = next_state
            episode_reward += reward

        episode_rewards.append(episode_reward)

        if episode % 100 == 0:
            avg = sum(episode_rewards[-100:]) / min(100, len(episode_rewards))
            print(f"Episode {episode}, Avg Reward: {avg:.2f}")

    return episode_rewards
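
A hedged usage sketch, assuming the classic Gym API (four-tuple step, plain reset) that the loop above uses; newer gymnasium versions return extra values and would need small changes:

import gym

# CartPole-v1: 4-dimensional observation, 2 discrete actions
env = gym.make("CartPole-v1")
agent = ActorCritic(state_dim=4, action_dim=2)
rewards = train_actor_critic(env, agent, num_episodes=500)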

A2C: Advantage Actor-Critic

A2C adds several improvements over basic actor-critic: a shared feature trunk, n-step returns, an entropy bonus, a combined actor-critic loss, and gradient clipping:

import torch.nn.functional as F

class A2C:
    def __init__(self, state_dim, action_dim, hidden_dim=256, lr=7e-4,
                 gamma=0.99, entropy_coef=0.01, value_coef=0.5, n_steps=5):
        # Shared feature extractor (optional but common)
        self.features = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
        )

        # Actor head
        self.actor_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
        )

        # Critic head
        self.critic_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

        self.optimizer = torch.optim.Adam(self.parameters(), lr=lr)
        self.gamma = gamma
        self.entropy_coef = entropy_coef
        self.value_coef = value_coef
        self.n_steps = n_steps

    def parameters(self):
        return list(self.features.parameters()) + \
               list(self.actor_head.parameters()) + \
               list(self.critic_head.parameters())

    def forward(self, state):
        features = self.features(state)
        action_logits = self.actor_head(features)
        value = self.critic_head(features)
        return action_logits, value

    def select_action(self, state):
        state = torch.FloatTensor(state).unsqueeze(0)
        logits, value = self.forward(state)
        dist = torch.distributions.Categorical(logits=logits)
        action = dist.sample()
        # Squeeze to 0-d tensors so torch.stack in update() produces clean 1-D batches
        return action.item(), dist.log_prob(action).squeeze(0), dist.entropy().squeeze(0), value.squeeze()

    def compute_returns_and_advantages(self, rewards, values, dones, next_value):
        """
        Compute n-step returns and advantages.
        """
        returns = []
        advantages = []

        # Bootstrap from final value
        R = next_value

        for t in reversed(range(len(rewards))):
            R = rewards[t] + self.gamma * R * (1 - dones[t])
            returns.insert(0, R)
            advantages.insert(0, R - values[t])

        returns = torch.stack(returns)
        advantages = torch.stack(advantages)

        # Normalize advantages
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        return returns, advantages

    def update(self, states, actions, rewards, dones, log_probs, values, entropies, next_value):
        """
        A2C update after n steps.
        """
        returns, advantages = self.compute_returns_and_advantages(
            rewards, values, dones, next_value
        )

        # Policy loss (negative because we want to maximize)
        policy_loss = -(torch.stack(log_probs) * advantages.detach()).mean()

        # Value loss
        value_loss = F.mse_loss(torch.stack(values).squeeze(), returns.detach())

        # Entropy bonus (encourages exploration)
        entropy_loss = -torch.stack(entropies).mean()

        # Combined loss
        loss = policy_loss + self.value_coef * value_loss + self.entropy_coef * entropy_loss

        self.optimizer.zero_grad()
        loss.backward()

        # Gradient clipping (helps stability)
        torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=0.5)

        self.optimizer.step()

        return {
            'policy_loss': policy_loss.item(),
            'value_loss': value_loss.item(),
            'entropy': -entropy_loss.item()
        }
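
A2C.update expects n steps of experience plus a bootstrap value. Here is a minimal rollout-collection sketch that feeds it; the helper name and environment handling are assumptions, using the same classic Gym API as the earlier loop:

def collect_and_update(env, agent, state):
    """Collect agent.n_steps transitions, then call agent.update once."""
    states, rewards, dones = [], [], []
    log_probs, values, entropies = [], [], []

    for _ in range(agent.n_steps):
        action, log_prob, entropy, value = agent.select_action(state)
        next_state, reward, done, _ = env.step(action)

        states.append(state)
        rewards.append(reward)
        dones.append(float(done))
        log_probs.append(log_prob)
        values.append(value)
        entropies.append(entropy)

        state = env.reset() if done else next_state

    # Bootstrap from the value of the state the rollout stopped in
    with torch.no_grad():
        _, _, _, next_value = agent.select_action(state)

    # update() ignores the actions argument, so None is passed
    agent.update(states, None, rewards, dones, log_probs, values, entropies, next_value)
    return state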

N-Step Returns

Instead of one-step TD, use n-step returns for better bias-variance tradeoff:

# 1-step return (high bias, low variance):
# G_t = r_t + gamma * V(s_{t+1})

# Monte Carlo return (low bias, high variance):
# G_t = r_t + gamma*r_{t+1} + gamma^2*r_{t+2} + ... (full episode)

# N-step return (balanced):
# G_t = r_t + gamma*r_{t+1} + ... + gamma^{n-1}*r_{t+n-1} + gamma^n * V(s_{t+n})

def compute_n_step_return(rewards, values, gamma, n_steps):
    """
    Compute n-step returns with bootstrapping.
    """
    returns = []
    T = len(rewards)

    for t in range(T):
        G = 0
        for k in range(min(n_steps, T - t)):
            G += (gamma ** k) * rewards[t + k]

        # Bootstrap from value function if not at episode end
        if t + n_steps < T:
            G += (gamma ** n_steps) * values[t + n_steps]

        returns.append(G)

    return torch.tensor(returns)
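
A tiny worked example with made-up numbers, showing how the n-step window mixes real rewards with a bootstrapped value:

# rewards = [1, 1, 1], values = [0.5, 0.6, 0.7], gamma = 0.9, n_steps = 2
#
# t = 0: G = 1 + 0.9*1 + 0.9^2 * V(s_2) = 1.9 + 0.81*0.7 = 2.467
# t = 1: G = 1 + 0.9*1                  = 1.9  (window reaches the end of the buffer)
# t = 2: G = 1                                 (only the final reward remains)
print(compute_n_step_return([1, 1, 1], [0.5, 0.6, 0.7], gamma=0.9, n_steps=2))
# tensor([2.4670, 1.9000, 1.0000])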

Generalized Advantage Estimation (GAE)

GAE smoothly interpolates between different n-step estimates:

def compute_gae(rewards, values, dones, next_value, gamma=0.99, gae_lambda=0.95):
    """
    Generalized Advantage Estimation (GAE).

    Expects rewards, dones, values, and next_value as plain floats (detach
    critic outputs with .item() before calling).

    gae_lambda controls bias-variance tradeoff:
    - lambda = 0: One-step TD (high bias, low variance)
    - lambda = 1: Monte Carlo (low bias, high variance)
    - lambda = 0.95: Good default (balanced)
    """
    advantages = []
    gae = 0.0

    # Append next_value so values[t + 1] is defined at the final step
    values = list(values) + [next_value]

    for t in reversed(range(len(rewards))):
        # TD error at step t
        delta = rewards[t] + gamma * values[t + 1] * (1 - dones[t]) - values[t]

        # GAE: exponentially weighted sum of TD errors
        gae = delta + gamma * gae_lambda * (1 - dones[t]) * gae

        advantages.insert(0, gae)

    return torch.tensor(advantages, dtype=torch.float32)

# GAE intuition:
# advantage_t = delta_t + (gamma*lambda)*delta_{t+1} + (gamma*lambda)^2*delta_{t+2} + ...
#
# delta_t = r_t + gamma*V(s_{t+1}) - V(s_t)  [one-step TD error]
#
# The lambda decay means:
# - Recent TD errors matter most
# - Distant TD errors are downweighted
# - This reduces variance while maintaining low bias
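
A quick sanity check with illustrative numbers: at gae_lambda = 0, each advantage collapses to the one-step TD error delta_t:

rewards, values, dones, next_value = [1.0, 1.0], [0.5, 0.4], [0.0, 0.0], 0.3
adv = compute_gae(rewards, values, dones, next_value, gamma=0.9, gae_lambda=0.0)
# delta_0 = 1 + 0.9*0.4 - 0.5 = 0.86,  delta_1 = 1 + 0.9*0.3 - 0.4 = 0.87
print(adv)  # tensor([0.8600, 0.8700])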

A2C with GAE

class A2CWithGAE(A2C):
    def __init__(self, *args, gae_lambda=0.95, **kwargs):
        super().__init__(*args, **kwargs)
        self.gae_lambda = gae_lambda

    def compute_returns_and_advantages(self, rewards, values, dones, next_value):
        """Use GAE instead of simple n-step returns."""
        # compute_gae works on plain floats, so detach the critic outputs first
        values_f = [v.item() for v in values]
        next_value_f = next_value.item() if torch.is_tensor(next_value) else float(next_value)

        # GAE for advantages
        advantages = compute_gae(
            rewards, values_f, dones, next_value_f,
            gamma=self.gamma, gae_lambda=self.gae_lambda
        )

        # Returns = advantages + values (targets for the critic)
        returns = advantages + torch.tensor(values_f)

        # Normalize advantages
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        return returns, advantages

Continuous Action Spaces

For continuous actions, we output distribution parameters instead of discrete probabilities:

class ContinuousActorCritic(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super().__init__()

        # Shared features
        self.features = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )

        # Actor outputs mean and log_std for Gaussian policy
        self.actor_mean = nn.Linear(hidden_dim, action_dim)
        self.actor_log_std = nn.Parameter(torch.zeros(action_dim))

        # Critic outputs value
        self.critic = nn.Linear(hidden_dim, 1)

    def forward(self, state):
        features = self.features(state)
        return features

    def get_action_distribution(self, state):
        """Return a Gaussian distribution over actions."""
        features = self.forward(state)
        mean = self.actor_mean(features)

        # Clamp log_std for stability
        log_std = self.actor_log_std.clamp(-20, 2)
        std = log_std.exp()

        return torch.distributions.Normal(mean, std)

    def get_value(self, state):
        features = self.forward(state)
        return self.critic(features)

    def select_action(self, state, deterministic=False):
        """Sample action from Gaussian policy."""
        state = torch.FloatTensor(state).unsqueeze(0)
        dist = self.get_action_distribution(state)

        if deterministic:
            action = dist.mean
        else:
            action = dist.rsample()  # reparameterized sample (plain .sample() also works; the gradient flows through log_prob)

        log_prob = dist.log_prob(action).sum(dim=-1)
        value = self.get_value(state)

        return action.squeeze(0).detach().numpy(), log_prob, value

Continuous Action Training

def train_continuous_a2c(env, agent, num_updates=1000, n_steps=2048):
    """
    Train continuous A2C (similar to PPO's data collection).
    """
    state = env.reset()
    episode_rewards = []
    current_episode_reward = 0

    for update in range(num_updates):
        # Collect n_steps of experience
        states, actions, rewards, dones = [], [], [], []
        log_probs, values = [], []

        for _ in range(n_steps):
            action, log_prob, value = agent.select_action(state)

            next_state, reward, done, _ = env.step(action)

            states.append(state)
            actions.append(action)
            rewards.append(reward)
            dones.append(done)
            log_probs.append(log_prob)
            values.append(value)

            current_episode_reward += reward
            state = next_state

            if done:
                episode_rewards.append(current_episode_reward)
                current_episode_reward = 0
                state = env.reset()

        # Compute next value for bootstrapping
        with torch.no_grad():
            _, _, next_value = agent.select_action(state, deterministic=True)

        # Update
        agent.update(states, actions, rewards, dones, log_probs, values, next_value)

        if update % 10 == 0 and episode_rewards:
            print(f"Update {update}, Avg Reward: {sum(episode_rewards[-10:]) / len(episode_rewards[-10:]):.2f}")

    return episode_rewards
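
The loop above calls agent.update(...), which ContinuousActorCritic never defines. Below is a minimal sketch of one way to fill that gap, reusing compute_gae and the stored (still differentiable) log_probs and values; the function name, signature, and coefficients are assumptions, the optimizer is expected to wrap agent.parameters(), and the entropy bonus is omitted for brevity:

import torch
import torch.nn.functional as F

def continuous_a2c_update(agent, optimizer, rewards, dones, log_probs, values, next_value,
                          gamma=0.99, gae_lambda=0.95, value_coef=0.5):
    """Sketch of an A2C-style update for the continuous agent."""
    # GAE works on plain floats, so detach the critic outputs first
    values_f = [v.item() for v in values]
    next_value_f = next_value.item() if torch.is_tensor(next_value) else float(next_value)
    advantages = compute_gae(rewards, values_f, dones, next_value_f, gamma, gae_lambda)
    returns = advantages + torch.tensor(values_f)
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

    # The stored log_probs and values still carry gradients from the rollout forward passes
    log_probs_t = torch.stack(list(log_probs)).reshape(-1)
    values_t = torch.stack(list(values)).reshape(-1)

    policy_loss = -(log_probs_t * advantages).mean()
    value_loss = F.mse_loss(values_t, returns)
    loss = policy_loss + value_coef * value_loss

    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(agent.parameters(), max_norm=0.5)
    optimizer.step()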

Squashing for Bounded Actions

Many environments have bounded action spaces. Use tanh squashing:

class SquashedGaussianPolicy(nn.Module):
    """
    Gaussian policy with tanh squashing for bounded actions.
    Used in SAC and other continuous control algorithms.
    """
    def __init__(self, state_dim, action_dim, hidden_dim=256, action_limit=1.0):
        super().__init__()
        self.action_limit = action_limit

        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )
        self.mean = nn.Linear(hidden_dim, action_dim)
        self.log_std = nn.Linear(hidden_dim, action_dim)

    def forward(self, state):
        x = self.net(state)
        mean = self.mean(x)
        log_std = self.log_std(x).clamp(-20, 2)
        return mean, log_std

    def sample(self, state):
        """Sample action with log probability correction for squashing."""
        mean, log_std = self.forward(state)
        std = log_std.exp()

        # Sample from Gaussian
        dist = torch.distributions.Normal(mean, std)
        u = dist.rsample()  # Unsquashed action

        # Squash through tanh
        action = torch.tanh(u) * self.action_limit

        # Correct the log probability for the squashing (change of variables):
        # log pi(a) = log pi(u) - sum_i log(1 - tanh^2(u_i))
        # (rescaling by action_limit only shifts log_prob by a constant, omitted here)
        log_prob = dist.log_prob(u) - torch.log(1 - torch.tanh(u).pow(2) + 1e-6)
        log_prob = log_prob.sum(dim=-1)

        return action, log_prob

    def deterministic_action(self, state):
        mean, _ = self.forward(state)
        return torch.tanh(mean) * self.action_limit
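
A quick sanity check (the state dimension and bound below are made up) that sampled actions respect the limit and that log_prob has one value per batch element:

policy = SquashedGaussianPolicy(state_dim=3, action_dim=1, action_limit=2.0)
state = torch.randn(16, 3)
action, log_prob = policy.sample(state)
print(action.abs().max() <= 2.0)      # tensor(True)
print(action.shape, log_prob.shape)   # torch.Size([16, 1]) torch.Size([16])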

Entropy Regularization

Entropy bonus prevents premature convergence to deterministic policies:

def compute_entropy(distribution):
    """
    Compute entropy of action distribution.

    Discrete: H = -sum(p * log(p))
    Gaussian: H = 0.5 * log(2 * pi * e * sigma^2) per dimension
    """
    if isinstance(distribution, torch.distributions.Categorical):
        return distribution.entropy()
    elif isinstance(distribution, torch.distributions.Normal):
        return distribution.entropy().sum(dim=-1)
    else:
        raise ValueError(f"Unknown distribution type: {type(distribution)}")

# In the loss function:
# total_loss = policy_loss + value_coef * value_loss - entropy_coef * entropy
#
# Subtracting entropy (with positive coef) MAXIMIZES entropy
# This encourages exploration by preferring uncertain policies
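
A small illustration (the probabilities are made up) of why the bonus favors uncertain policies:

# A uniform policy has maximal entropy (ln 4 ≈ 1.386); a nearly deterministic one is close to 0.
uniform = torch.distributions.Categorical(probs=torch.tensor([0.25, 0.25, 0.25, 0.25]))
peaked = torch.distributions.Categorical(probs=torch.tensor([0.97, 0.01, 0.01, 0.01]))
print(compute_entropy(uniform))  # tensor(1.3863)
print(compute_entropy(peaked))   # tensor(0.1677)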

Capstone Connection

Actor-Critic is the core of RLHF algorithms like PPO.

# PPO (used in ChatGPT training) is essentially A2C with:
# 1. Clipped surrogate objective (prevents too-large updates)
# 2. Multiple epochs per batch (better sample efficiency)
# 3. GAE for advantage estimation

def ppo_policy_loss(old_log_probs, new_log_probs, advantages, clip_epsilon=0.2):
    """
    PPO's clipped surrogate objective.
    """
    ratio = (new_log_probs - old_log_probs).exp()

    # Clipped and unclipped objective
    obj_unclipped = ratio * advantages
    obj_clipped = ratio.clamp(1 - clip_epsilon, 1 + clip_epsilon) * advantages

    # Take minimum (pessimistic bound)
    return -torch.min(obj_unclipped, obj_clipped).mean()
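
A quick check with dummy values: when old and new log probabilities match, the ratio is 1 and the loss reduces to -advantages.mean(); once the ratio drifts, clipping caps the objective:

old_lp = torch.zeros(4)
new_lp = torch.zeros(4)
adv = torch.tensor([1.0, -1.0, 0.5, 2.0])
print(ppo_policy_loss(old_lp, new_lp, adv))  # tensor(-0.6250) = -adv.mean()

new_lp = torch.full((4,), 1.0)  # ratio = e ≈ 2.72; clipped to 1.2 where advantages are positive
print(ppo_policy_loss(old_lp, new_lp, adv))  # ≈ tensor(-0.3704)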

# RLHF training loop sketch:
def rlhf_training_step(policy, ref_policy, reward_model, prompts, beta=0.1):
    """
    One step of RLHF training (sketch: compute_kl and ppo_update are assumed
    helpers, and beta scales the KL penalty).
    """
    # Generate responses with current policy
    responses = policy.generate(prompts)

    # Score with reward model (the "critic" in some sense)
    rewards = reward_model(prompts, responses)

    # KL penalty to stay close to reference (prevents reward hacking)
    kl = compute_kl(policy, ref_policy, prompts, responses)
    rewards = rewards - beta * kl

    # PPO update (actor-critic under the hood)
    ppo_update(policy, prompts, responses, rewards)

For sycophancy evaluation:

Actor-critic methods reveal why RLHF can amplify sycophancy:

  1. The reward model (critic) may have learned to give high scores to agreeable responses
  2. The policy (actor) learns to maximize these scores
  3. Advantage = "this agreeable response is better than average" - even when wrong
  4. No mechanism to distinguish helpful agreement from harmful sycophancy

Understanding actor-critic = understanding RLHF's failure modes.


🎓 Tyla's Exercise

  1. Bias-variance analysis: TD(0) has high bias, Monte Carlo has high variance. Derive the bias and variance of n-step returns as a function of n.

  2. Prove GAE interpolation: Show that GAE with lambda=0 gives TD(0) advantage, and lambda=1 gives Monte Carlo advantage.

  3. Log probability correction: For squashed Gaussian policies, derive why we need to subtract log(1 - tanh^2(u)) from the log probability.

  4. Optimal critic: If the critic perfectly estimates V(s), what is the variance of the policy gradient? Compare to REINFORCE.


💻 Aaliyah's Exercise

Build a complete A2C implementation with continuous actions:

def build_continuous_a2c():
    """
    Implement A2C for continuous control with:
    1. Gaussian policy with learnable std
    2. GAE advantage estimation
    3. Entropy regularization
    4. Gradient clipping
    5. Proper handling of episode boundaries
    """
    pass

def train_on_pendulum():
    """
    Train on Pendulum-v1:
    1. Implement the agent
    2. Train for 500 episodes
    3. Plot learning curve
    4. Visualize learned policy (action vs state)
    5. Compare different GAE lambda values
    """
    pass

def ablation_study():
    """
    Compare:
    - A2C vs REINFORCE
    - Different n_steps values
    - With/without entropy bonus
    - Different GAE lambda values

    Report sample efficiency and final performance.
    """
    pass

📚 Maneesha's Reflection

  1. Actor and critic as teacher and student: The critic evaluates, the actor performs. How is this similar to teacher-student relationships? What happens when the "critic" gives biased feedback?

  2. Advantage vs absolute reward: Actor-critic cares about "better than expected," not "good or bad." How does this relative framing affect learning? When might absolute feedback be more useful than relative feedback?

  3. Entropy and exploration: Entropy regularization encourages the agent to maintain uncertainty. How does this relate to the "beginner's mind" concept in learning theory? When should learners be encouraged to stay uncertain?

  4. RLHF and human values: If RLHF uses actor-critic to optimize for human feedback, what happens when human feedback is inconsistent? How might this inform how we design feedback systems for human learners?

  5. The credit assignment problem: Actor-critic improves credit assignment through bootstrapping. How do human teachers help learners understand which actions led to which outcomes? What can we learn from actor-critic about effective feedback timing?