Generalized Advantage Estimation (GAE)
GAE is an advantage estimator used in on-policy actor-critic policy gradient algorithms.
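In A2C, we replaced the high-variance but unbiased MC return $G_t$ with the low-variance but biased one-step bootstrap return $G_{t:t+1} = r_t + \gamma V(s_{t+1})$. More generally, we can use the $n$-step return
$$G_{t:t+n} = r_t + \gamma r_{t+1} + \cdots + \gamma^{n-1} r_{t+n-1} + \gamma^n V(s_{t+n})$$
and thus obtain the (truncated) $n$-step advantage estimate
$$A_t^{(n)} = G_{t:t+n} - V(s_t).$$
Writing $v_t = V(s_t)$ and $\delta_t = r_t + \gamma v_{t+1} - v_t$ for the one-step TD error, the estimates become
$$A_t^{(1)} = \delta_t, \qquad A_t^{(2)} = \delta_t + \gamma \delta_{t+1}, \qquad A_t^{(n)} = \sum_{l=0}^{n-1} \gamma^l \delta_{t+l}.$$
Instead of using a single value of $n$, we can take an exponentially weighted average over all of them, with weights proportional to $\lambda^{n-1}$:
$$A_t^{\text{GAE}} = (1-\lambda) \sum_{n=1}^{\infty} \lambda^{n-1} A_t^{(n)} = \sum_{l=0}^{\infty} (\gamma\lambda)^l \delta_{t+l}.$$
To derive the recurrence, split off the first term and factor out $\gamma\lambda$:
$$A_t^{\text{GAE}} = \delta_t + \gamma\lambda \sum_{l=0}^{\infty} (\gamma\lambda)^l \delta_{t+1+l} = \delta_t + \gamma\lambda A_{t+1}^{\text{GAE}}.$$
The general trick is to rewrite a long-horizon quantity in terms of local one-step TD errors. If we want to isolate $\delta_{t+l}$ inside the $n$-step advantage, we add and subtract $\gamma^l v_{t+l}$ for each intermediate step $l = 1, \ldots, n-1$, so that the sum telescopes into $\sum_{l=0}^{n-1} \gamma^l \delta_{t+l}$.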
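The two identities used in this derivation can be checked numerically: the $n$-step advantage equals a discounted sum of one-step TD errors, and the $\lambda$-weighted average of the $n$-step advantages equals the backward recurrence $A_t = \delta_t + \gamma\lambda A_{t+1}$. The sketch below is a minimal pure-Python check, not part of the original text. It assumes a short made-up trajectory that terminates after the last reward (so the value after the final step is zero); the numbers and the helper names `n_step_advantage`, `weighted_average`, and `gae_backward` are illustrative.

```python
# Illustrative trajectory (made-up numbers); the episode terminates after the
# last reward, so the value after the final step is 0.
rewards = [1.0, 0.0, 2.0, -1.0]
values = [0.5, 0.4, 1.0, -0.2, 0.0]  # values[t] = V(s_t); values[-1] is terminal
gamma, lam = 0.9, 0.8
T = len(rewards)

# One-step TD errors: delta_t = r_t + gamma * v_{t+1} - v_t
deltas = [rewards[t] + gamma * values[t + 1] - values[t] for t in range(T)]


def n_step_advantage(t, n):
    """A_t^(n): n-step return minus v_t, truncated at the end of the episode."""
    n = min(n, T - t)
    ret = sum(gamma**l * rewards[t + l] for l in range(n))
    ret += gamma**n * values[t + n]
    return ret - values[t]


def weighted_average(t):
    """(1 - lam) * sum_n lam^(n-1) * A_t^(n).

    Every n >= T - t gives the same value after termination, so the total
    weight of that tail, lam^(N-1), is collected into a single term.
    """
    N = T - t
    head = sum((1 - lam) * lam ** (n - 1) * n_step_advantage(t, n) for n in range(1, N))
    return head + lam ** (N - 1) * n_step_advantage(t, N)


def gae_backward():
    """Backward recurrence A_t = delta_t + gamma * lam * A_{t+1}."""
    adv, running = [0.0] * T, 0.0
    for t in reversed(range(T)):
        running = deltas[t] + gamma * lam * running
        adv[t] = running
    return adv


gae = gae_backward()
for t in range(T):
    # Telescoping: each n-step advantage is a discounted sum of TD errors.
    for n in range(1, T - t + 1):
        td_sum = sum(gamma**l * deltas[t + l] for l in range(n))
        assert abs(n_step_advantage(t, n) - td_sum) < 1e-9
    # The lambda-weighted average matches the backward recurrence.
    assert abs(weighted_average(t) - gae[t]) < 1e-9
```

The recurrence needs a single backward pass over the rollout, which is why implementations typically use it rather than forming the $n$-step estimates explicitly.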
Caution
GAE does not define a new target different from the advantage function. The target is still the true advantage $A^\pi(s_t, a_t)$. What changes is the estimator: one-step TD gives a local, high-bias/low-variance estimate; longer $n$-step returns reduce the bias but increase the variance; GAE combines these estimators with a parameter $\lambda$ to smoothly trade bias for variance.
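Here, $\lambda \in [0, 1]$ is a parameter that controls the bias-variance tradeoff: larger values decrease the bias but increase the variance. Setting $\lambda = 0$ recovers the one-step TD advantage, and $\lambda = 1$ recovers the Monte Carlo advantage (bootstrapped at the end of the rollout).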
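The two extremes of $\lambda$ can be made concrete with a small numeric check. The sketch below is a minimal pure-Python illustration under assumed inputs: the rewards, values, and `gamma` are made up, and the helper `gae` is a hypothetical name for the backward recursion. It confirms that $\lambda = 0$ yields the one-step TD errors and $\lambda = 1$ yields the discounted reward-to-go (bootstrapped with the final value) minus the critic's estimate.

```python
def gae(rewards, values, gamma, lam):
    """Backward GAE recursion; `values` has one extra bootstrap entry."""
    adv, running = [0.0] * len(rewards), 0.0
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * values[t + 1] - values[t]
        running = delta + gamma * lam * running
        adv[t] = running
    return adv


rewards = [1.0, 0.0, 2.0]       # illustrative rewards
values = [0.5, 0.4, 1.0, 0.3]   # last entry bootstraps the rollout tail
gamma = 0.9
T = len(rewards)

# lam = 0: each advantage is just the one-step TD error.
td_errors = [rewards[t] + gamma * values[t + 1] - values[t] for t in range(T)]
adv_td = gae(rewards, values, gamma, lam=0.0)

# lam = 1: each advantage is the discounted reward-to-go, bootstrapped with
# the final value, minus the critic's estimate at that step.
mc_adv = []
for t in range(T):
    ret = sum(gamma ** (k - t) * rewards[k] for k in range(t, T))
    ret += gamma ** (T - t) * values[T]
    mc_adv.append(ret - values[t])
adv_mc = gae(rewards, values, gamma, lam=1.0)

assert all(abs(a - b) < 1e-9 for a, b in zip(adv_td, td_errors))
assert all(abs(a - b) < 1e-9 for a, b in zip(adv_mc, mc_adv))
```

Intermediate values of `lam` interpolate between these two estimates, which is the bias-variance dial described above.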
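Algorithm 3 Actor-Critic with GAE
Initialize parameters $\phi$, environment state $s$
repeat
    $(s_1, a_1, r_1, \ldots, s_T, a_T, r_T, s_{T+1}) = \text{rollout}(s, \pi_\phi)$
    $v_{1:T+1} = V_\phi(s_{1:T+1})$
    $(A_{1:T}, y_{1:T}) = \text{GAE}(r_{1:T}, v_{1:T+1}, \gamma, \lambda)$
    update $\phi$ with a gradient step on the policy, value, and entropy losses, treating $A_{1:T}$ and $y_{1:T}$ as constants
until converged
function GAE($r_{1:T}, v_{1:T+1}, \gamma, \lambda$)
    $A' = 0$
    for $t = T$ down to $1$ do
        $\delta_t = r_t + \gamma v_{t+1} - v_t$
        $A' = \delta_t + \gamma \lambda A'$
        $A_t = A'$ // advantage
        $y_t = A_t + v_t$ // return
    end for
    return $(A_{1:T}, y_{1:T})$
end function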
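More generally, many policy-gradient methods use estimators of the form
$$\nabla_\theta J(\theta) \approx \mathbb{E}\left[\sum_t \Psi_t \, \nabla_\theta \log \pi_\theta(a_t \mid s_t)\right].$$
The choice of $\Psi_t$ distinguishes the methods: the Monte Carlo return $G_t$ gives REINFORCE, $G_t - b(s_t)$ gives REINFORCE with a baseline, the one-step TD error $r_t + \gamma V(s_{t+1}) - V(s_t)$ gives A2C, and the GAE estimate gives the method in this section.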
Tip
These are not arbitrary choices: to produce a valid policy-gradient estimator, the weight $\Psi_t$ on the score $\nabla_\theta \log \pi_\theta(a_t \mid s_t)$ should estimate $Q^\pi(s_t, a_t)$ up to the subtraction of a baseline that depends only on the state. The main difference between methods is therefore not the target policy gradient itself, but which estimator of the return/advantage is used and what bias-variance tradeoff it induces.
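The claim that a state-only baseline leaves the gradient unchanged can be verified exactly for a small softmax policy. The sketch below is a minimal pure-Python check under assumed inputs: the logits, action values, and baseline are made-up numbers for a single state, and `expected_gradient` is a hypothetical helper that computes the exact expectation over actions rather than sampling.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def expected_gradient(logits, psi):
    """Exact E_{a ~ pi}[psi[a] * grad_logits log pi(a)] for a softmax policy.

    For softmax, d log pi(a) / d logit_j = 1[a == j] - pi[j].
    """
    pi = softmax(logits)
    grad = [0.0] * len(logits)
    for a, p_a in enumerate(pi):
        for j in range(len(logits)):
            score = (1.0 if a == j else 0.0) - pi[j]
            grad[j] += p_a * psi[a] * score
    return grad


logits = [0.2, -1.0, 0.7]    # illustrative policy logits in one state
q_values = [1.0, 3.0, -0.5]  # illustrative action values Q(s, a)
baseline = 0.8               # any number that does not depend on the action

grad_q = expected_gradient(logits, q_values)
grad_adv = expected_gradient(logits, [q - baseline for q in q_values])
grad_baseline_only = expected_gradient(logits, [baseline] * len(logits))

# The baseline term contributes nothing in expectation...
assert all(abs(g) < 1e-12 for g in grad_baseline_only)
# ...so weighting by Q - b gives the same expected gradient as weighting by Q.
assert all(abs(a - b) < 1e-12 for a, b in zip(grad_q, grad_adv))
```

The expectation is unchanged, but the variance of a sampled estimate generally is not, which is why the choice of baseline and advantage estimator matters in practice.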
Implementation
A simple implementation is:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical


class ActorCriticNetwork(nn.Module):
    """Shared backbone with separate policy (logits) and value heads."""

    def __init__(self, state_dim, action_dim, hidden_dim=128):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
        )
        self.policy_head = nn.Linear(hidden_dim, action_dim)
        self.value_head = nn.Linear(hidden_dim, 1)

    def forward(self, state):
        # Accept raw NumPy observations and add a batch dimension if needed.
        if isinstance(state, np.ndarray):
            state = torch.from_numpy(state)
        if state.ndim == 1:
            state = state.unsqueeze(0)
        state = state.float()
        features = self.backbone(state)
        logits = self.policy_head(features)
        state_value = self.value_head(features).squeeze(-1)
        return logits, state_value
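
def compute_gae(rewards, values, dones, gamma: float = 0.99, lam: float = 0.95):
    """Compute GAE advantages and value targets for one rollout.

    `values` holds len(rewards) + 1 entries; the last one is the bootstrap value.
    """
    values = values.detach()  # targets should not backpropagate into the critic
    advantages = torch.zeros_like(rewards)
    gae = torch.tensor(0.0)
    for t in reversed(range(len(rewards))):
        not_done = 1.0 - dones[t]
        # One-step TD error; the bootstrap term is masked out where dones[t] == 1.
        delta = rewards[t] + gamma * values[t + 1] * not_done - values[t]
        # Backward recurrence A_t = delta_t + gamma * lam * A_{t+1}.
        gae = delta + gamma * lam * not_done * gae
        advantages[t] = gae
    returns = advantages + values[:-1]
    return advantages, returns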

def train_with_gae(
    env,
    num_episodes: int,
    gamma: float = 0.99,
    lam: float = 0.95,
    lr: float = 3e-4,
    rollout_length: int = 128,
    value_coef: float = 0.5,
    entropy_coef: float = 0.01,
):
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n
    model = ActorCriticNetwork(state_dim, action_dim)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    state, _ = env.reset()
    for episode in range(num_episodes):
        rewards, dones, log_probs, entropies, values = [], [], [], [], []

        # Collect one rollout: up to rollout_length steps or until the episode ends.
        for _ in range(rollout_length):
            logits, value = model(state)
            dist = Categorical(logits=logits)
            action = dist.sample()
            next_state, reward, terminated, truncated, _ = env.step(action.item())
            episode_over = terminated or truncated

            rewards.append(torch.tensor(reward, dtype=torch.float32))
            # Only true termination zeroes the bootstrap value; after a
            # time-limit truncation we still bootstrap from V(next_state).
            dones.append(torch.tensor(float(terminated), dtype=torch.float32))
            log_probs.append(dist.log_prob(action).squeeze(0))
            entropies.append(dist.entropy().squeeze(0))
            values.append(value.squeeze(0))

            state = next_state
            if episode_over:
                break

        # Bootstrap from the state that follows the last collected step,
        # then reset the environment if the episode has ended.
        with torch.no_grad():
            _, bootstrap_value = model(state)
        values.append(bootstrap_value.squeeze(0))
        if episode_over:
            state, _ = env.reset()

        values = torch.stack(values)
        rewards = torch.stack(rewards)
        dones = torch.stack(dones)
        log_probs = torch.stack(log_probs)
        entropies = torch.stack(entropies)

        advantages, returns = compute_gae(rewards, values, dones, gamma=gamma, lam=lam)

        # Advantages and returns are treated as constants in the loss.
        actor_loss = -(advantages.detach() * log_probs).mean()
        critic_loss = (returns.detach() - values[:-1]).pow(2).mean()
        entropy_bonus = entropies.mean()
        loss = actor_loss + value_coef * critic_loss - entropy_coef * entropy_bonus

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return model

Sources
- Murphy, K. (2025). Reinforcement Learning: An Overview. Chapter 3.