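Consider the use of the one-step TD method to estimate the return in the episodic case, i.e., we replace the Monte Carlo return $G_t$ with the one-step bootstrapped target $r_t + \gamma V_w(s_{t+1})$. If we use $V_w(s_t)$ as a baseline, the REINFORCE update becomes

$$
\theta \leftarrow \theta + \eta_\theta \, \gamma^t \left( r_t + \gamma V_w(s_{t+1}) - V_w(s_t) \right) \nabla_\theta \log \pi_\theta(a_t \mid s_t)
$$

where $\delta_t = r_t + \gamma V_w(s_{t+1}) - V_w(s_t)$ is a single-sample approximation to the advantage function $A^{\pi}(s_t, a_t) = Q^{\pi}(s_t, a_t) - V^{\pi}(s_t)$. This method is therefore called advantage actor-critic, or A2C. Note that $V_w(s_{t+1}) = 0$ if $s_{t+1}$ is a done state, representing the end of an episode.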
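As a minimal sketch (plain Python; the helper name `one_step_advantages` is hypothetical, not from the source), the target and TD residual for a batch of transitions can be computed as follows. The `dones` flags zero out the bootstrap term, matching the convention $V_w(s_{t+1}) = 0$ at a done state.

```python
from typing import List, Sequence, Tuple


def one_step_advantages(
    rewards: Sequence[float],
    values: Sequence[float],
    next_values: Sequence[float],
    dones: Sequence[bool],
    gamma: float = 0.99,
) -> Tuple[List[float], List[float]]:
    """Return (targets, advantages) for a batch of transitions.

    targets[t]    = r_t + gamma * V_w(s_{t+1}), with V_w(s_{t+1}) := 0 if done
    advantages[t] = targets[t] - V_w(s_t), i.e. the TD residual delta_t
    """
    targets, advantages = [], []
    for r, v, v_next, done in zip(rewards, values, next_values, dones):
        y = r + gamma * (0.0 if done else v_next)  # bootstrap only if not done
        targets.append(y)
        advantages.append(y - v)
    return targets, advantages


# Three transitions; the last one ends the episode, so its next value (9.0)
# is ignored.
targets, advantages = one_step_advantages(
    rewards=[1.0, 0.0, 2.0],
    values=[0.5, 1.0, 1.5],
    next_values=[1.0, 1.5, 9.0],
    dones=[False, False, True],
    gamma=0.5,
)
print(targets)     # [1.5, 0.75, 2.0]
print(advantages)  # [1.0, -0.25, 0.5]
```

The second transition has a negative advantage: its target is below the critic's current estimate $V_w(s_t)$, so the actor update would make that action less likely.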
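Algorithm 1 A2C (episodic)
Initialize policy parameters $\theta$, critic parameters $w$
repeat
Sample starting state $s_0$ of a new episode
for $t = 0, 1, 2, \ldots$ do
Sample action $a_t \sim \pi_\theta(\cdot \mid s_t)$
$(s_{t+1}, r_t, \text{done}) \leftarrow \text{env.step}(s_t, a_t)$
$y_t = r_t + \gamma (1 - \text{done}) V_w(s_{t+1})$ // target
$\delta_t = y_t - V_w(s_t)$ // advantage
$w \leftarrow w + \eta_w \delta_t \nabla_w V_w(s_t)$ // critic
$\theta \leftarrow \theta + \eta_\theta \gamma^t \, \text{stopgrad}(\delta_t) \nabla_\theta \log \pi_\theta(a_t \mid s_t)$ // actor
if done then
break
end if
end for
until converged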
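To see Algorithm 1 line by line, here is a minimal, dependency-free sketch. Everything about the environment is an assumption for illustration: a hypothetical five-state corridor (`env_step`, `N_STATES`) where moving right from state 3 reaches the terminal state 4 with reward 1. With a tabular critic, $\nabla_w V_w(s_t)$ is one-hot, so the critic line reduces to `V[s] += lr_critic * delta`; with a tabular softmax actor, $\nabla_\theta \log \pi_\theta(a_t \mid s_t)$ is `onehot(a) - probs` on the row for $s_t$. The `max_steps` cap is an added safeguard, not part of Algorithm 1.

```python
import math
import random
from typing import List, Tuple

# Hypothetical toy environment (not from the source): corridor of states 0..4.
# Action 0 moves left (state 0 is a wall), action 1 moves right.
# Reaching state 4 gives reward 1 and ends the episode.
N_STATES = 5
ACTIONS = (0, 1)


def env_step(s: int, a: int) -> Tuple[int, float, bool]:
    s_next = max(0, s - 1) if a == 0 else s + 1
    done = s_next == N_STATES - 1
    return s_next, (1.0 if done else 0.0), done


def softmax(logits: List[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def a2c_tabular(num_episodes: int = 2000, gamma: float = 0.9,
                lr_actor: float = 0.1, lr_critic: float = 0.1,
                max_steps: int = 100, seed: int = 0
                ) -> Tuple[List[List[float]], List[float]]:
    rng = random.Random(seed)
    theta = [[0.0 for _ in ACTIONS] for _ in range(N_STATES)]  # policy logits
    V = [0.0] * N_STATES                                        # critic V_w
    for _ in range(num_episodes):
        s, discount = 0, 1.0  # discount tracks gamma ** t
        for _ in range(max_steps):
            probs = softmax(theta[s])
            a = rng.choices(ACTIONS, weights=probs)[0]       # a_t ~ pi(.|s_t)
            s_next, r, done = env_step(s, a)
            y = r + gamma * (0.0 if done else V[s_next])     # target y_t
            delta = y - V[s]                                  # advantage delta_t
            # Critic: for a table, grad_w V_w(s) is 1 at entry s, 0 elsewhere.
            V[s] += lr_critic * delta
            # Actor: grad of log softmax w.r.t. logits is onehot(a) - probs.
            # delta is a plain float here, so it acts like stopgrad(delta_t).
            for b in ACTIONS:
                grad_log_pi = (1.0 if b == a else 0.0) - probs[b]
                theta[s][b] += lr_actor * discount * delta * grad_log_pi
            if done:
                break
            s, discount = s_next, discount * gamma
    return theta, V


theta, V = a2c_tabular()
for s in range(N_STATES - 1):
    print(s, [round(p, 2) for p in softmax(theta[s])])
```

The same $\delta_t$ drives both updates, and it is computed before the critic step, exactly as in Algorithm 1. With these settings we expect the learned policy to favor moving right in every non-terminal state; the exact probabilities depend on the seed and hyperparameters.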
This is an on-policy algorithm: we update the critic $V_w$ to reflect the value $V^{\pi_\theta}$ of the current policy $\pi_\theta$, so the transitions used for each update must come from that policy.
Note
The true advantage function is $A^{\pi}(s, a) = Q^{\pi}(s, a) - V^{\pi}(s)$. In A2C we typically do not know this quantity exactly, so we replace it with the one-step TD residual $\delta_t$, which is a sample-based estimate of the true advantage.
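Why the TD residual is a sensible stand-in follows from one line of algebra, under the idealized assumption that the critic is exact, $V_w = V^{\pi}$:

$$
\mathbb{E}\left[\delta_t \mid s_t, a_t\right]
= \mathbb{E}\left[r_t + \gamma V^{\pi}(s_{t+1}) \mid s_t, a_t\right] - V^{\pi}(s_t)
= Q^{\pi}(s_t, a_t) - V^{\pi}(s_t)
= A^{\pi}(s_t, a_t)
$$

The middle step is the one-step Bellman equation for $Q^{\pi}$. With an approximate critic $V_w \neq V^{\pi}$ this equality no longer holds, so $\delta_t$ is a biased estimate, traded for typically lower variance than the Monte Carlo quantity $G_t - V_w(s_t)$.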
Caution
In practice, we should apply a stop-gradient operator to the target $y_t$ in the TD update, so that the critic follows a semi-gradient that moves only $V_w(s_t)$ toward the target. Furthermore, it is common to add an entropy bonus to the actor's objective as a regularizer; it keeps the policy stochastic, which aids exploration, and smooths the loss function.
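Both points can be sketched for a softmax actor and a linear critic in plain Python. The helper names (`actor_logit_grad`, `critic_semi_gradient_step`), the linear features `phi`, and the entropy coefficient `beta` are hypothetical, not from the source. For logits $z$ and $\pi = \mathrm{softmax}(z)$, the entropy gradient is $\partial H / \partial z_j = -\pi_j(\log \pi_j + H)$; the caller would still scale the actor gradient by $\eta_\theta \gamma^t$ as in Algorithm 1.

```python
import math
from typing import List, Tuple


def softmax(logits: List[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(probs: List[float]) -> float:
    # H(pi) = -sum_a pi(a) log pi(a), with 0 log 0 taken as 0.
    return -sum(p * math.log(p) for p in probs if p > 0.0)


def actor_logit_grad(logits: List[float], action: int, advantage: float,
                     beta: float) -> List[float]:
    """Gradient of advantage * log pi(action) + beta * H(pi) w.r.t. the logits.

    `advantage` is an ordinary float, so no gradient flows through it; this
    plays the role of stopgrad(delta_t).
    """
    probs = softmax(logits)
    h = entropy(probs)
    grad = []
    for j, p in enumerate(probs):
        d_log_pi = (1.0 if j == action else 0.0) - p            # d log pi(a) / d z_j
        d_entropy = -p * (math.log(p) + h) if p > 0.0 else 0.0  # d H / d z_j
        grad.append(advantage * d_log_pi + beta * d_entropy)
    return grad


def critic_semi_gradient_step(w: List[float], phi_s: List[float],
                              phi_next: List[float], reward: float,
                              gamma: float, lr: float,
                              done: bool) -> Tuple[List[float], float]:
    """One TD(0) step for a linear critic V_w(s) = w . phi(s).

    The target y_t is computed from the current w but treated as a constant
    (stop-gradient), so only the features of s_t enter the update.
    """
    v = sum(wi * fi for wi, fi in zip(w, phi_s))
    v_next = 0.0 if done else sum(wi * fi for wi, fi in zip(w, phi_next))
    delta = reward + gamma * v_next - v
    new_w = [wi + lr * delta * fi for wi, fi in zip(w, phi_s)]
    return new_w, delta


# The weight that only appears in phi(s_{t+1}) is left untouched.
w, delta = critic_semi_gradient_step([0.0, 0.0], [1.0, 0.0], [0.0, 1.0],
                                     reward=1.0, gamma=0.9, lr=0.5, done=False)
print(w, delta)  # [0.5, 0.0] 1.0
```

At uniform logits the entropy term contributes (numerically) zero gradient, since entropy is already maximal there; for a peaked policy it pushes the logits back toward uniform. Differentiating through the target instead would also change the second weight, which is what the stop-gradient prevents.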
Implementation
```python
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical


class ActorCriticNetwork(nn.Module):
    """Shared backbone with a policy head (logits) and a value head V_w(s)."""

    def __init__(self, state_dim, action_dim, hidden_dim=128):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
        )
        self.policy_head = nn.Linear(hidden_dim, action_dim)
        self.value_head = nn.Linear(hidden_dim, 1)

    def forward(self, state):
        if isinstance(state, np.ndarray):
            state = torch.from_numpy(state)
        if state.ndim == 1:  # add a batch dimension for a single observation
            state = state.unsqueeze(0)
        state = state.float()
        features = self.backbone(state)
        logits = self.policy_head(features)
        state_value = self.value_head(features).squeeze(-1)
        return logits, state_value


def train(env, num_episodes: int, gamma: float = 0.99, lr: float = 3e-4):
    """One-step A2C (Algorithm 1) for a Gymnasium env with discrete actions."""
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n
    model = ActorCriticNetwork(state_dim, action_dim)
    # One optimizer for actor and critic, i.e. eta_theta == eta_w here.
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for episode in range(num_episodes):
        state, _ = env.reset()
        discount = 1.0  # gamma ** t
        while True:
            logits, value = model(state)
            dist = Categorical(logits=logits)
            action = dist.sample()
            next_state, reward, terminated, truncated, _ = env.step(action.item())
            done = terminated or truncated

            # TD target y_t; no_grad implements the stop-gradient on the target.
            with torch.no_grad():
                _, next_value = model(next_state)
            td_target = torch.tensor([reward], dtype=torch.float32)
            # Bootstrap unless s_{t+1} is a true terminal state. A time-limit
            # truncation is not terminal, so we still bootstrap there.
            if not terminated:
                td_target = td_target + gamma * next_value
            advantage = td_target - value  # delta_t

            # Actor uses stopgrad(delta_t); critic minimizes the squared TD error.
            actor_loss = -discount * dist.log_prob(action) * advantage.detach()
            critic_loss = advantage.pow(2)
            loss = (actor_loss + critic_loss).sum()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            state = next_state
            discount *= gamma
            if done:
                break

    return model
```
Sources
Murphy, K. (2025). Reinforcement Learning: An Overview. Chapter 3.