Advantage Actor Critic (A2C)

A2C is an on-policy actor critic policy gradient algorithm [1].

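Consider the use of the one-step TD method to estimate the return in the episodic case, i.e., we replace $G_t$ with $G_{t:t+1} = r_t + \gamma V_{\mathbf{w}}(s_{t+1})$. If we use $V_{\mathbf{w}}(s_t)$ as a baseline, the REINFORCE update becomes

$$\mathbf{\theta} \gets \mathbf{\theta} + \eta_{\mathbf{\theta}} \sum_{t=0}^{T-1} \gamma^t \, \delta_t \, \nabla_{\mathbf{\theta}} \log \pi_{\mathbf{\theta}}(a_t \mid s_t), \qquad \delta_t = r_t + \gamma V_{\mathbf{w}}(s_{t+1}) - V_{\mathbf{w}}(s_t)$$

where $\delta_t$ is a single sample approximation to the advantage function $A^{\pi}(s_t, a_t) = Q^{\pi}(s_t, a_t) - V^{\pi}(s_t)$. This method is therefore called advantage actor critic or A2C. Note that $V_{\mathbf{w}}(s_{t+1}) = 0$ if $s_{t+1}$ is a done state, representing the end of an episode.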

Algorithm 1 A2C (episodic)

Initialize policy parameters $\mathbf{\theta}$, critic parameters $\mathbf{w}$

repeat

Sample starting state $s_0$ of a new episode

for $t = 0, 1, 2, \dots$ do

Sample action $a_t \sim \pi_{\mathbf{\theta}}(\cdot \mid s_t)$

$(s_{t+1}, r_t, \mathrm{done}) \gets \mathrm{env.step}(s_t, a_t)$

$y_t = r_t + \gamma (1 - \mathrm{done}) V_{\mathbf{w}}(s_{t+1})$ // target

$\delta_t = y_t - V_{\mathbf{w}}(s_t)$ // advantage

$\mathbf{w} \gets \mathbf{w} + \eta_{\mathbf{w}} \delta_t \nabla_{\mathbf{w}} V_{\mathbf{w}}(s_t)$ // critic

$\mathbf{\theta} \gets \mathbf{\theta} + \eta_{\mathbf{\theta}} \gamma^t \, \mathrm{stopgrad}(\delta_t) \nabla_{\mathbf{\theta}} \log \pi_{\mathbf{\theta}}(a_t \mid s_t)$ // actor

if $\mathrm{done}$ then

break

end if

end for

until converged
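The inner-loop updates of Algorithm 1 can be sketched for a small tabular problem. This is a minimal illustration under stated assumptions, not a reference implementation: the function name `a2c_step`, the tabular softmax actor `theta[s][a]`, the tabular critic `w[s]`, and the learning rates are all hypothetical choices made for this example.

```python
import math


def softmax(prefs):
    """Numerically stable softmax over a list of action preferences."""
    m = max(prefs)
    exps = [math.exp(p - m) for p in prefs]
    total = sum(exps)
    return [e / total for e in exps]


def a2c_step(theta, w, s, a, r, s_next, done, t,
             gamma=0.99, lr_actor=0.1, lr_critic=0.1):
    """Apply one A2C update in place and return the TD residual delta_t.

    theta[s][a]: softmax action preferences (hypothetical tabular actor).
    w[s]: state-value estimates (hypothetical tabular critic).
    """
    # Target: no bootstrapping from a done next state.
    y = r + gamma * (0.0 if done else w[s_next])
    # One-sample advantage estimate, computed before any parameter changes.
    delta = y - w[s]

    # Critic: for a tabular critic, dV_w(s)/dw[s] = 1.
    w[s] += lr_critic * delta

    # Actor: for a tabular softmax policy,
    # d log pi(a|s) / d theta[s][b] = 1[b == a] - pi(b|s).
    probs = softmax(theta[s])
    for b in range(len(probs)):
        grad_log_pi = (1.0 if b == a else 0.0) - probs[b]
        theta[s][b] += lr_actor * (gamma ** t) * delta * grad_log_pi

    return delta


# Usage: two states, two actions, one rewarding terminal transition.
theta = [[0.0, 0.0], [0.0, 0.0]]
w = [0.0, 0.0]
delta = a2c_step(theta, w, s=0, a=1, r=1.0, s_next=1, done=True, t=0)
```

Because `delta` is a plain number here, no explicit stop-gradient is needed; it only matters when an autodiff framework would otherwise differentiate the actor term through the critic.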

This is an on-policy algorithm, where we update the value function to reflect the value of the current policy $\pi_{\mathbf{\theta}}$.

Note

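The true advantage function is $A^{\pi}(s_t, a_t) = Q^{\pi}(s_t, a_t) - V^{\pi}(s_t)$. In A2C, we typically do not know this quantity exactly, so we replace it with the one-step TD residual $\delta_t = r_t + \gamma V_{\mathbf{w}}(s_{t+1}) - V_{\mathbf{w}}(s_t)$, which is a sample-based estimate of the true advantage.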
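A small numerical check may help here. The sketch below uses a hypothetical toy model (the values in `V`, the transition table `OUTCOMES`, and all function names are made up for illustration) and assumes the critic equals the true value function. Under that assumption, averaging the TD residual over the possible next states recovers $Q^{\pi} - V^{\pi}$, while any single residual is a noisy sample of it.

```python
import random

GAMMA = 0.9

# Hypothetical state values, assumed to be exact for the current policy.
V = {0: 1.0, 1: 2.0, 2: 0.0}

# Hypothetical outcomes of one fixed action taken in state 0:
# (probability, reward, next_state, done)
OUTCOMES = [(0.7, 1.0, 1, False), (0.3, 0.0, 2, True)]


def td_residual(r, s, s_next, done):
    """delta = r + gamma * (1 - done) * V(s') - V(s)."""
    return r + GAMMA * (0.0 if done else V[s_next]) - V[s]


def true_advantage(s=0):
    """A(s, a) = Q(s, a) - V(s), with Q computed from the known model."""
    q = sum(p * (r + GAMMA * (0.0 if done else V[s2]))
            for p, r, s2, done in OUTCOMES)
    return q - V[s]


def mean_sampled_residual(n, seed=0):
    """Average n single-sample TD residuals drawn from the toy model."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n):
        first = rng.random() < OUTCOMES[0][0]
        _, r, s2, done = OUTCOMES[0] if first else OUTCOMES[1]
        total += td_residual(r, 0, s2, done)
    return total / n
```

In this toy model the two possible residuals are 1.8 and -1.0, and their probability-weighted mean equals the exact advantage of 0.96. When the critic is only an approximation, the residual is a biased estimate of the advantage.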

Caution

In practice, we should apply a stop-gradient operator to the target value $y_t$ in the TD update, so that the critic is regressed toward a fixed target. Furthermore, it is common to add an entropy bonus to the policy objective as a regularizer; this keeps the policy stochastic and smooths the loss function.
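Both points can be sketched as a single-transition loss in PyTorch. This is an illustrative sketch, not a prescribed implementation: the function name `a2c_loss` and the coefficients `value_coef` and `entropy_coef` (with their default values) are assumptions introduced for this example.

```python
import torch
from torch.distributions import Categorical


def a2c_loss(logits, action, value, next_value, reward, terminated,
             gamma=0.99, value_coef=0.5, entropy_coef=0.01):
    """One-step A2C loss for a single transition (illustrative sketch).

    logits: unnormalized action scores, shape (num_actions,)
    action: index of the action that was taken (0-d integer tensor)
    value, next_value: 0-d critic outputs for s_t and s_{t+1}
    value_coef, entropy_coef: hypothetical weights, chosen for illustration
    """
    dist = Categorical(logits=logits)

    # Stop-gradient on the TD target: detach V(s_{t+1}) so the critic is
    # regressed toward a fixed number.
    bootstrap = 0.0 if terminated else gamma * next_value.detach()
    td_target = reward + bootstrap
    advantage = td_target - value

    # Stop-gradient on the advantage inside the actor term, so the policy
    # loss does not push gradients into the critic.
    actor_loss = -dist.log_prob(action) * advantage.detach()
    critic_loss = advantage.pow(2)

    # Entropy bonus: subtracting it from the loss favors stochastic policies.
    entropy = dist.entropy()
    return actor_loss + value_coef * critic_loss - entropy_coef * entropy


# Usage with made-up numbers.
logits = torch.zeros(3, requires_grad=True)
value = torch.tensor(0.5, requires_grad=True)
next_value = torch.tensor(2.0, requires_grad=True)
loss = a2c_loss(logits, torch.tensor(1), value, next_value,
                reward=1.0, terminated=False)
loss.backward()
```

After `backward()`, `next_value.grad` is still `None` because the target was detached, while `value` and `logits` both receive gradients.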

Implementation

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
 
 
class ActorCriticNetwork(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=128):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
        )
        self.policy_head = nn.Linear(hidden_dim, action_dim)
        self.value_head = nn.Linear(hidden_dim, 1)
 
    def forward(self, state):
        if isinstance(state, np.ndarray):
            state = torch.from_numpy(state)
        if state.ndim == 1:
            state = state.unsqueeze(0)
        state = state.float()
 
        features = self.backbone(state)
        logits = self.policy_head(features)
        state_value = self.value_head(features).squeeze(-1)
        return logits, state_value
 
 
def train(env, num_episodes: int, gamma: float = 0.99, lr: float = 3e-4):
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n
 
    model = ActorCriticNetwork(state_dim, action_dim)
    optimizer = optim.Adam(model.parameters(), lr=lr)
 
    for episode in range(num_episodes):
        state, _ = env.reset()
        discount = 1.0
 
        while True:
            logits, value = model(state)
            dist = Categorical(logits=logits)
            action = dist.sample()
 
            next_state, reward, terminated, truncated, _ = env.step(action.item())
            done = terminated or truncated
 
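            # Build the TD target without tracking gradients (stop-gradient).
            # Bootstrap unless the episode truly terminated: a time-limit
            # truncation still has a meaningful next-state value.
            with torch.no_grad():
                _, next_value = model(next_state)
                td_target = torch.tensor([reward], dtype=torch.float32)
                if not terminated:
                    td_target = td_target + gamma * next_value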
 
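            # One-step TD residual, used as a single-sample advantage estimate.
            advantage = td_target - value
            # Actor: policy gradient weighted by gamma^t and the detached advantage.
            actor_loss = -discount * dist.log_prob(action) * advantage.detach()
            # Critic: squared TD error against the gradient-free target.
            critic_loss = advantage.pow(2)
            loss = actor_loss + critic_loss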
 
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
 
            state = next_state
            discount *= gamma
            if done:
                break

    return model

Sources

  • Murphy, K. (2025). Reinforcement Learning: An Overview. Chapter 3.

Footnotes

  1. Mnih, V., et al. (2016). Asynchronous Methods for Deep Reinforcement Learning. ICML.