Generalized Advantage Estimation (GAE)

GAE is an advantage estimator used in on-policy actor-critic policy gradient algorithms.

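In A2C, we replaced the high-variance but unbiased MC return $G_t$ with the low-variance but biased one-step return $G_{t:t+1} = r_t + \gamma V_{\mathbf{\phi}}(s_{t+1})$. More generally, we can compute the $n$-step return

$$G_{t:t+n} = r_t + \gamma r_{t+1} + \cdots + \gamma^{n-1} r_{t+n-1} + \gamma^n V_{\mathbf{\phi}}(s_{t+n})$$

and thus obtain the (truncated) $n$-step advantage estimate as follows:

$$A_t^{(n)} = G_{t:t+n} - V_{\mathbf{\phi}}(s_t)$$

Writing $v_t = V_{\mathbf{\phi}}(s_t)$ and $\delta_t = r_t + \gamma v_{t+1} - v_t$, the first few $n$-step estimates are

$$\begin{aligned}
A_t^{(1)} &= r_t + \gamma v_{t+1} - v_t = \delta_t \\
A_t^{(2)} &= r_t + \gamma r_{t+1} + \gamma^2 v_{t+2} - v_t = \delta_t + \gamma \delta_{t+1} \\
A_t^{(3)} &= r_t + \gamma r_{t+1} + \gamma^2 r_{t+2} + \gamma^3 v_{t+3} - v_t = \delta_t + \gamma \delta_{t+1} + \gamma^2 \delta_{t+2} \\
A_t^{(\infty)} &= r_t + \gamma r_{t+1} + \gamma^2 r_{t+2} + \cdots - v_t = \sum_{l=0}^{\infty} \gamma^l \delta_{t+l}
\end{aligned}$$

$A_t^{(1)}$ is high bias but low variance, and $A_t^{(\infty)}$ is low bias but high variance.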
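As a rough check of the two ways of writing an $n$-step estimate (directly from rewards plus a bootstrap value, or as a discounted sum of TD residuals), here is a minimal sketch in plain Python. The function names `n_step_advantage` and `td_residuals` and the numbers are made up for illustration, and time is 0-indexed.

```python
def n_step_advantage(rewards, values, t, n, gamma):
    """Truncated n-step advantage A_t^(n) = G_{t:t+n} - v_t (0-indexed time)."""
    # Discounted sum of the n rewards r_t, ..., r_{t+n-1}.
    ret = sum(gamma**k * rewards[t + k] for k in range(n))
    # Bootstrap from the value estimate n steps ahead.
    ret += gamma**n * values[t + n]
    return ret - values[t]


def td_residuals(rewards, values, gamma):
    """One-step TD residuals delta_t = r_t + gamma * v_{t+1} - v_t."""
    return [r + gamma * values[t + 1] - values[t] for t, r in enumerate(rewards)]


# Illustrative numbers only; values carries one extra entry for bootstrapping.
rewards = [1.0, 0.0, 2.0, 1.0]
values = [0.5, 1.0, 0.2, 0.8, 0.3]
gamma = 0.9

deltas = td_residuals(rewards, values, gamma)
for n in range(1, len(rewards) + 1):
    direct = n_step_advantage(rewards, values, 0, n, gamma)
    via_td = sum(gamma**l * deltas[l] for l in range(n))
    print(n, round(direct, 6), round(via_td, 6))
```

For each $n$, the two printed numbers should agree up to floating-point error, because the intermediate value terms cancel in pairs.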

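Instead of using a single value of $n$, we can take a weighted average of these $n$-step estimators. If we choose weights proportional to $\lambda^{n-1}$, this gives exponentially less weight to longer-horizon estimates. In TD-residual form, the resulting estimator is

$$A_t = (1 - \lambda) \sum_{n=1}^{\infty} \lambda^{n-1} A_t^{(n)} = \sum_{l=0}^{\infty} (\gamma \lambda)^l \delta_{t+l}$$

To derive the recurrence, split off the first term and factor out $\gamma \lambda$ from the rest:

$$A_t = \delta_t + \sum_{l=1}^{\infty} (\gamma \lambda)^l \delta_{t+l} = \delta_t + \gamma \lambda \sum_{l=0}^{\infty} (\gamma \lambda)^l \delta_{t+1+l} = \delta_t + \gamma \lambda A_{t+1}$$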
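The equivalence between the weighted average, the direct sum, and the backward recurrence can be checked numerically. The sketch below is plain Python with hypothetical function names. It assumes a finite rollout: the sum is truncated at the end of the rollout, and in the weighted average the leftover weight $\lambda^{K-1}$ is placed on the longest available $n$-step estimate.

```python
def td_residuals(rewards, values, gamma):
    """One-step TD residuals delta_t = r_t + gamma * v_{t+1} - v_t."""
    return [r + gamma * values[t + 1] - values[t] for t, r in enumerate(rewards)]


def gae_recursive(rewards, values, gamma, lam):
    """Backward recursion A_t = delta_t + gamma * lam * A_{t+1}, starting from 0."""
    deltas = td_residuals(rewards, values, gamma)
    advantages = [0.0] * len(rewards)
    next_adv = 0.0
    for t in reversed(range(len(rewards))):
        next_adv = deltas[t] + gamma * lam * next_adv
        advantages[t] = next_adv
    return advantages


def gae_direct_sum(rewards, values, gamma, lam):
    """Direct sum of (gamma * lam)^l * delta_{t+l}, truncated at the rollout end."""
    deltas = td_residuals(rewards, values, gamma)
    T = len(rewards)
    return [
        sum((gamma * lam) ** l * deltas[t + l] for l in range(T - t))
        for t in range(T)
    ]


def gae_weighted_average(rewards, values, gamma, lam):
    """Weighted average of n-step estimates (assumption: tail weight on the longest)."""
    deltas = td_residuals(rewards, values, gamma)
    T = len(rewards)
    out = []
    for t in range(T):
        K = T - t  # number of n-step estimates available from time t
        # n-step estimates written in TD-residual form, for n = 1, ..., K.
        n_step = [sum(gamma**l * deltas[t + l] for l in range(n)) for n in range(1, K + 1)]
        avg = sum((1 - lam) * lam ** (n - 1) * n_step[n - 1] for n in range(1, K))
        avg += lam ** (K - 1) * n_step[K - 1]
        out.append(avg)
    return out


# Illustrative numbers only.
rewards = [1.0, 0.0, 2.0, 1.0]
values = [0.5, 1.0, 0.2, 0.8, 0.3]
print(gae_recursive(rewards, values, 0.9, 0.95))
print(gae_direct_sum(rewards, values, 0.9, 0.95))
print(gae_weighted_average(rewards, values, 0.9, 0.95))
```

The recursion is the form used in practice because it costs one pass over the rollout, whereas the other two forms are quadratic or worse in the rollout length.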

The general trick is to rewrite a long-horizon quantity in terms of local one-step TD errors. If we want to isolate $\delta_t$, then starting from an $n$-step advantage estimate we add and subtract the missing term $\gamma v_{t+1}$ so that the first few terms become exactly $\delta_t = r_t + \gamma v_{t+1} - v_t$. The remaining terms then have the same structure one step later in time, which is why the decomposition telescopes.
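Written out for a general $n$, the add-and-subtract step might look as follows. This is a worked sketch using the shorthand $v_t$ for the value estimate at $s_t$ and $A_t^{(n)}$ for the $n$-step advantage estimate.

```latex
\begin{aligned}
A_t^{(n)} &= r_t + \gamma r_{t+1} + \cdots + \gamma^{n-1} r_{t+n-1} + \gamma^n v_{t+n} - v_t \\
&= \underbrace{\left(r_t + \gamma v_{t+1} - v_t\right)}_{\delta_t}
 + \gamma \underbrace{\left(r_{t+1} + \cdots + \gamma^{n-2} r_{t+n-1} + \gamma^{n-1} v_{t+n} - v_{t+1}\right)}_{A_{t+1}^{(n-1)}} \\
&= \delta_t + \gamma A_{t+1}^{(n-1)}
\end{aligned}
```

Applying the same step to $A_{t+1}^{(n-1)}$, and so on, yields $A_t^{(n)} = \sum_{l=0}^{n-1} \gamma^l \delta_{t+l}$.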

Caution

GAE does not define a new target different from the advantage function. The target is still the true advantage $A^{\pi}(s_t, a_t)$. What changes is the estimator: one-step TD gives a local, high-bias/low-variance estimate; longer $n$-step returns reduce the bias but increase the variance; GAE combines these estimators with a parameter $\lambda$ to smoothly trade bias for variance.

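Here, $\lambda \in [0, 1]$ is a parameter that controls the bias-variance tradeoff: larger values decrease the bias but increase the variance. Setting $\lambda = 0$ recovers the one-step estimate $\delta_t$, while $\lambda = 1$ recovers the full-horizon estimate $\sum_{l \ge 0} \gamma^l \delta_{t+l}$. This estimator is called generalized advantage estimation (GAE).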
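To make the two extremes of $\lambda$ concrete, the following plain-Python sketch runs the backward recursion on a short made-up rollout. The function name `gae` and the numbers are illustrative assumptions, and the rollout is assumed to end with a bootstrap value rather than a terminal state.

```python
def gae(rewards, values, gamma, lam):
    """GAE via the backward recursion A_t = delta_t + gamma * lam * A_{t+1}."""
    advantages = [0.0] * len(rewards)
    next_adv = 0.0
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * values[t + 1] - values[t]
        next_adv = delta + gamma * lam * next_adv
        advantages[t] = next_adv
    return advantages


# Illustrative numbers only; values has one extra bootstrap entry.
rewards = [1.0, 0.0, 2.0, 1.0]
values = [0.5, 1.0, 0.2, 0.8, 0.3]
gamma = 0.9

for lam in (0.0, 0.5, 1.0):
    print(lam, [round(a, 4) for a in gae(rewards, values, gamma, lam)])
```

With `lam = 0.0` each advantage is just the TD residual at that step. With `lam = 1.0` each advantage is the discounted sum of all remaining rewards plus the discounted bootstrap value, minus the value estimate at that step.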

Algorithm 3 Actor Critic with GAE

Initialize parameters $\mathbf{\phi}$, environment state $s$

repeat

$(s_1, a_1, r_1, \dots, s_T, a_T, r_T, s_{T+1}) = \mathrm{rollout}(s, \pi_{\mathbf{\phi}})$

$v_{1:T+1} = V_{\mathbf{\phi}}(s_{1:T+1})$

$(A_{1:T}, y_{1:T}) = \mathrm{stopgrad}(\mathrm{GAE}(r_{1:T}, v_{1:T+1}, \gamma, \lambda))$

$\mathcal{L}(\mathbf{\phi}) = \frac{1}{T} \sum_{t=1}^T [\lambda_{\mathrm{TD}}(V_{\mathbf{\phi}}(s_t) - y_t)^2 - \lambda_{\mathrm{PG}} A_t \log \pi_{\mathbf{\phi}} (a_t \mid s_t) - \lambda_{\mathrm{ent}} \mathbb{H}(\pi_{\mathbf{\phi}}(\bullet \mid s_t))]$

$\mathbf{\phi} \gets \mathbf{\phi} - \eta \nabla_{\mathbf{\phi}} \mathcal{L}(\mathbf{\phi})$

until converged

function GAE($r_{1:T}, v_{1:T+1}, \gamma, \lambda$)

$A_{T+1} = 0$

for $t = T$ down to $1$ do

$\delta_t = r_t + \gamma v_{t+1} - v_t$

$A_t = \delta_t + \gamma \lambda A_{t+1}$ // advantage

$y_t = A_t + v_t$ // return

end for

return $(A_{1:T}, y_{1:T})$

end function
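The loss line in the algorithm combines three terms per time step. The sketch below evaluates that scalar in plain Python for given per-step quantities. The function name `actor_critic_loss`, its argument names, and the default coefficients are assumptions for illustration. Plain Python has no automatic differentiation, so this shows only the value of the loss; in an autodiff framework the advantages and targets would be detached (the stopgrad step) before use.

```python
def actor_critic_loss(values, targets, advantages, log_probs, entropies,
                      td_coef=0.5, pg_coef=1.0, ent_coef=0.01):
    """Average over T steps of the value, policy-gradient, and entropy terms."""
    T = len(values)
    total = 0.0
    for v, y, adv, logp, ent in zip(values, targets, advantages, log_probs, entropies):
        value_term = td_coef * (v - y) ** 2   # critic regression toward the target y_t
        policy_term = -pg_coef * adv * logp   # policy-gradient surrogate
        entropy_term = -ent_coef * ent        # entropy bonus lowers the loss
        total += value_term + policy_term + entropy_term
    return total / T


# Two made-up time steps, purely for illustration.
loss = actor_critic_loss(
    values=[1.0, 0.0],
    targets=[2.0, 0.0],
    advantages=[1.0, -1.0],
    log_probs=[-0.5, -1.0],
    entropies=[1.0, 1.0],
    td_coef=0.5,
    pg_coef=1.0,
    ent_coef=0.1,
)
print(loss)
```

A positive advantage makes the loss decrease as the log-probability of the taken action increases, while a negative advantage does the opposite.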

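More generally, many policy-gradient methods use estimators of the form

$$\nabla_{\mathbf{\phi}} J(\mathbf{\phi}) = \mathbb{E}\left[\sum_{t} \Psi_t \nabla_{\mathbf{\phi}} \log \pi_{\mathbf{\phi}}(a_t \mid s_t)\right]$$

The choice of $\Psi_t$ determines which estimator we use. Common choices include

  • $\Psi_t = G_t$, the Monte Carlo return (REINFORCE);
  • $\Psi_t = G_t - b(s_t)$, the return minus a state-dependent baseline;
  • $\Psi_t = Q^{\pi}(s_t, a_t)$, the action-value function;
  • $\Psi_t = A^{\pi}(s_t, a_t)$, the advantage function;
  • $\Psi_t = \delta_t = r_t + \gamma v_{t+1} - v_t$, the one-step TD residual (A2C);
  • $\Psi_t = \sum_{l \ge 0} (\gamma \lambda)^l \delta_{t+l}$, the GAE estimator.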

Tip

These are not arbitrary choices: to produce a valid policy-gradient estimator, $\Psi_t$ should estimate $Q^{\pi}(s_t, a_t)$ up to the subtraction of a baseline that depends only on the state. The main difference between methods is therefore not the target policy gradient itself, but which estimator of the return/advantage is used and what bias-variance tradeoff it induces.
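The claim that a state-only baseline leaves the expected gradient unchanged can be checked exactly in a small case. The sketch below assumes a single state with a softmax policy over three actions; the logits, Q-values, and baseline are made-up numbers, and `expected_weighted_score` is a hypothetical helper. It computes the exact expectation over actions rather than sampling.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return [e / z for e in exps]


def expected_weighted_score(logits, psi):
    """Exact E_{a ~ pi}[psi(a) * grad_logits log pi(a)] for a softmax policy."""
    probs = softmax(logits)
    k = len(logits)
    grad = [0.0] * k
    for a in range(k):
        for j in range(k):
            # d log pi(a) / d logit_j = 1[a == j] - pi(j) for a softmax policy.
            score = (1.0 if a == j else 0.0) - probs[j]
            grad[j] += probs[a] * psi[a] * score
    return grad


logits = [0.2, -1.0, 0.7]
q_values = [1.0, 3.0, -2.0]
baseline = 0.8  # any number that does not depend on the action

grad_q = expected_weighted_score(logits, q_values)
grad_adv = expected_weighted_score(logits, [q - baseline for q in q_values])
print(grad_q)
print(grad_adv)
```

The two gradients coincide because the expected score is zero, so any action-independent term drops out in expectation. With sampled actions the two estimators would differ in variance, which is the practical reason for subtracting a baseline.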

Implementation

A simple implementation is:

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
 
class ActorCriticNetwork(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=128):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
        )
        self.policy_head = nn.Linear(hidden_dim, action_dim)
        self.value_head = nn.Linear(hidden_dim, 1)
 
    def forward(self, state):
        if isinstance(state, np.ndarray):
            state = torch.from_numpy(state)
        if state.ndim == 1:
            state = state.unsqueeze(0)
        state = state.float()
 
        features = self.backbone(state)
        logits = self.policy_head(features)
        state_value = self.value_head(features).squeeze(-1)
        return logits, state_value
 
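def compute_gae(rewards, values, dones, gamma: float = 0.99, lam: float = 0.95):
    """Compute GAE advantages and value targets for one rollout.

    rewards and dones have length T; values has length T + 1, where the last
    entry is the bootstrap value of the state reached after the final step.
    """
    advantages = []
    gae = torch.tensor(0.0)

    for t in reversed(range(len(rewards))):
        # dones[t] == 1 masks both the bootstrap value and the carried-over
        # advantage, so nothing leaks across an episode boundary.
        not_done = 1 - dones[t]
        delta = rewards[t] + gamma * values[t + 1] * not_done - values[t]
        gae = delta + gamma * lam * not_done * gae
        advantages.append(gae)

    # Advantages were collected backwards in time; restore chronological order.
    advantages = torch.stack(advantages[::-1])
    returns = advantages + values[:-1]  # y_t = A_t + v_t
    return advantages, returns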
 
def train_with_gae(
    env,
    num_episodes: int,
    gamma: float = 0.99,
    lam: float = 0.95,
    lr: float = 3e-4,
    rollout_length: int = 128,
    value_coef: float = 0.5,
    entropy_coef: float = 0.01,
):
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n
 
    model = ActorCriticNetwork(state_dim, action_dim)
    optimizer = optim.Adam(model.parameters(), lr=lr)
 
    state, _ = env.reset()
 
    for episode in range(num_episodes):
        rewards = []
        dones = []
        log_probs = []
        entropies = []
        values = []
 
        for _ in range(rollout_length):
            logits, value = model(state)
            dist = Categorical(logits=logits)
            action = dist.sample()
 
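            next_state, reward, terminated, truncated, _ = env.step(action.item())
            done = terminated or truncated

            rewards.append(torch.tensor(reward, dtype=torch.float32))
            # Only a true termination zeroes the bootstrap value; a time-limit
            # truncation still bootstraps from the value of the final state.
            dones.append(torch.tensor(float(terminated), dtype=torch.float32))
            log_probs.append(dist.log_prob(action).squeeze(0))
            entropies.append(dist.entropy().squeeze(0))
            values.append(value.squeeze(0))

            state = next_state
            if done:
                break  # a rollout never spans an episode boundary

        # Bootstrap from the last observed state, before any reset.
        with torch.no_grad():
            _, bootstrap_value = model(state)
            values.append(bootstrap_value.squeeze(0))

        if done:
            state, _ = env.reset()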
 
        values = torch.stack(values)
        rewards = torch.stack(rewards)
        dones = torch.stack(dones)
        log_probs = torch.stack(log_probs)
        entropies = torch.stack(entropies)
 
        advantages, returns = compute_gae(rewards, values, dones, gamma=gamma, lam=lam)
 
        actor_loss = -(advantages.detach() * log_probs).mean()
        critic_loss = (returns.detach() - values[:-1]).pow(2).mean()
        entropy_bonus = entropies.mean()
        loss = actor_loss + value_coef * critic_loss - entropy_coef * entropy_bonus
 
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return model

Sources

  • Murphy, K. (2025). Reinforcement Learning: An Overview. Chapter 3.