World Models

World models are trained to predict future observations, rewards, values, and/or latent embeddings. Once trained, the models can be used for decision-time planning, for background planning, or simply as an auxiliary training signal, for example to drive intrinsic curiosity.

World Models Trained to Predict Observation Targets

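We can use a world model to generate imaginary trajectories by sampling from the following joint distribution:

$$
p(s_{t+1:T}, r_{t+1:T}, a_{t:T-1} \mid s_t) = \prod_{i=t}^{T-1} \pi(a_i \mid s_i) \, M(s_{i+1} \mid s_i, a_i) \, R(r_{i+1} \mid s_i, a_i)
$$

where $\pi$ is the policy, $M$ is the learned dynamics model, and $R$ is the reward model.

This model can be augmented with latent variables. If the state space is high dimensional (e.g. images), then we denote the observable data by $o_t$. We can then learn the dynamics model $M(o_{t+1} \mid o_t, a_t)$ using standard techniques for conditional image generation, such as diffusion models.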
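As a minimal sketch of what sampling an imaginary trajectory looks like, the toy example below alternates between drawing an action from a policy, a next state from a dynamics model, and a reward from a reward model. The one-dimensional random-walk model and the names `policy`, `dynamics`, `reward`, and `imagine` are illustrative assumptions, not part of any library or of the methods cited here.

```python
import random

def policy(state, rng):
    """Hypothetical stochastic policy: step left (-1) or right (+1)."""
    return rng.choice([-1, 1])

def dynamics(state, action, rng):
    """Hypothetical learned dynamics: intended move plus Gaussian noise."""
    return state + action + rng.gauss(0.0, 0.1)

def reward(state, action):
    """Hypothetical reward model: negative distance from the origin."""
    return -abs(state)

def imagine(start_state, horizon, seed=0):
    """Sample one imagined trajectory of (state, action, reward, next_state)."""
    rng = random.Random(seed)
    state, trajectory = start_state, []
    for _ in range(horizon):
        action = policy(state, rng)
        next_state = dynamics(state, action, rng)
        trajectory.append((state, action, reward(state, action), next_state))
        state = next_state
    return trajectory

rollout = imagine(start_state=0.0, horizon=5)
```

No environment is queried inside `imagine`; every transition comes from the model, which is what makes the trajectory imaginary.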

Generative World Models without Latent Variables

The simplest approach is to define the world model $M(s_{t+1} \mid s_t, a_t)$ as a conditional generative model over states. If the observed states are low-dimensional vectors, we can use transformers [1].

Generative World Models with Latent Variables

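Methods that use latent variables as part of their world model can speed up the generation of imaginary futures, and can provide a compact latent space as input to a policy. We let $z_t$ denote the latent (or hidden) state at time $t$. This can be a discrete or continuous variable (or vector of variables). The generative model has the form of a controlled hidden Markov model (HMM):

$$
p(o_{t+1:T}, z_{t+1:T}, r_{t+1:T}, a_{t:T-1} \mid z_t) = \prod_{i=t}^{T-1} \pi(a_i \mid z_i) \, M(z_{i+1} \mid z_i, a_i) \, R(r_{i+1} \mid z_i, a_i) \, D(o_{i+1} \mid z_{i+1})
$$

where $D(o_t \mid z_t)$ is the decoder or likelihood function, $M(z_{t+1} \mid z_t, a_t)$ is the dynamics in latent space, $R$ is the reward model, and $\pi(a_t \mid z_t)$ is the policy in latent space.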

The world model is usually trained by maximizing the marginal likelihood of the observed outputs given an action sequence. Computing the marginal likelihood requires marginalizing over the hidden variables $z_{1:T}$. To make this computationally tractable, it is common to use amortized variational inference, in which we train an encoder network, $E(z_t \mid o_t)$, to approximate the posterior over the latents.
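To make the variational objective concrete, here is a sketch of the standard evidence lower bound (ELBO) for a model of this kind. It assumes a decoder $D(o_t \mid z_t)$, latent dynamics $M(z_t \mid z_{t-1}, a_{t-1})$, and a factorized encoder $q(z_{1:T}) = \prod_t E(z_t \mid o_t)$. These symbols and the factorized form of $q$ are assumptions made for illustration, and for $t = 1$ the dynamics term stands for the initial prior over $z_1$.

```latex
\begin{aligned}
\log p(o_{1:T} \mid a_{1:T-1})
  &= \log \int p(o_{1:T}, z_{1:T} \mid a_{1:T-1}) \, dz_{1:T} \\
  &\ge \mathbb{E}_{q(z_{1:T})}\left[ \log \frac{p(o_{1:T}, z_{1:T} \mid a_{1:T-1})}{q(z_{1:T})} \right] \\
  &= \sum_{t=1}^{T} \mathbb{E}_{q}\Big[ \log D(o_t \mid z_t)
     - D_{\mathrm{KL}}\big( E(z_t \mid o_t) \,\|\, M(z_t \mid z_{t-1}, a_{t-1}) \big) \Big]
\end{aligned}
```

The inequality is Jensen's inequality. The first term rewards accurate reconstruction, and the second keeps the encoder's posterior close to what the latent dynamics would have predicted.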

Example: Dreamer

The Dreamer methods [2, 3, 4, 5] are all based on the background planning approach, in which the policy is trained on imaginary trajectories generated by a latent variable world model. In Dreamer, the latent state is split into a deterministic recurrent state $h_t$ and a stochastic state $z_t$. The recurrent state summarizes history, while the stochastic state carries the information needed to reconstruct the current observation and to represent uncertainty about the future. It is standard to treat the pair $(h_t, z_t)$ as the full state of the recurrent state-space model (RSSM).

In more detail, Dreamer uses the following functions:

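  • A hidden dynamics model: $h_t = \mathcal{U}(h_{t-1}, z_{t-1}, a_{t-1})$
  • A latent prior: $z_t \sim P(z_t \mid h_t)$
  • A latent posterior (encoder): $z_t \sim E(z_t \mid h_t, o_t)$
  • An observation decoder: $o_t \sim D(o_t \mid h_t, z_t)$
  • A reward predictor: $r_t \sim R(r_t \mid h_t, z_t)$
  • A policy: $a_t \sim \pi(a_t \mid h_t, z_t)$
  • A value model: $v_t = V(h_t, z_t)$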

Intuitively, Dreamer separates $h_t$ and $z_t$ for the following reasons:

  • $h_t$ acts like a deterministic memory state. It can carry forward long-range temporal information through the recurrence without forcing the model to resample everything at each step.
  • $z_t$ captures the stochastic part of the latent state. It gives the model a way to represent uncertainty or fine-grained details that are difficult to compress into a purely deterministic recurrence.
  • If we tried to run the dynamics directly on $z_t$ alone, then $z_t$ would need to do two jobs at once: store stable history and represent stochastic variation. This tends to make the latent transition harder to model and train.
  • The split also makes imagination cheaper and more stable: $h_t$ provides a smooth summary of the past, while $z_t$ injects the random detail needed for reconstruction and prediction.
  • Another way to say it is that $h_t$ carries the “belief state,” and $z_t$ is a stochastic refinement of that belief at the current time step.
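The split between a deterministic memory and a stochastic refinement can be sketched with scalar states. In the toy code below, `h` is the deterministic recurrent state and `z` is the stochastic state. The `tanh` recurrence, the Gaussian prior, and all coefficients are hypothetical stand-ins for learned networks, chosen only to show which part of the step is random.

```python
import math
import random

def recurrent_update(h_prev, z_prev, a_prev):
    """Toy stand-in for the deterministic recurrence; no randomness is used."""
    return math.tanh(0.9 * h_prev + 0.5 * z_prev + 0.1 * a_prev)

def sample_prior(h, rng):
    """Toy stand-in for the latent prior: a Gaussian whose mean depends on h."""
    return rng.gauss(0.5 * h, 0.2)

def rssm_step(h_prev, z_prev, a_prev, rng):
    h = recurrent_update(h_prev, z_prev, a_prev)  # deterministic memory
    z = sample_prior(h, rng)                      # stochastic refinement
    return h, z

# Same inputs, different random seeds.
h1, z1 = rssm_step(0.0, 1.0, 1.0, random.Random(1))
h2, z2 = rssm_step(0.0, 1.0, 1.0, random.Random(2))
```

With identical inputs, the two calls agree on `h` and differ only in the sampled `z`, so all of the randomness in a step is confined to the stochastic state.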

The functions are used in two alternating phases:

  1. World-model update: encode real trajectories from the replay buffer into posterior latent states, and fit the dynamics / decoder / reward models.
  2. Behavior learning in imagination: start from posterior states inferred from real data, roll the RSSM forward using the prior and policy only, and train the actor-critic on those imagined trajectories.

Algorithm 12: Dreamer

Initialize replay buffer $\mathcal{B}$

repeat

Collect new environment experience with $a_t \sim \pi(a_t \mid h_t, z_t)$

Store trajectories $(o_t, a_t, r_t)$ in $\mathcal{B}$

// Phase 1: world-model learning on real trajectories

Sample a sequence $(o_{1:T}, a_{1:T-1}, r_{1:T})$ from $\mathcal{B}$

Initialize $h_0$

for $t = 1 : T$ do

$h_t = \mathcal{U}(h_{t-1}, z_{t-1}, a_{t-1})$

$z_t \sim E(z_t \mid h_t, o_t)$ // posterior state

Compute reconstruction, reward, and KL terms

end for

Update $\mathcal{U}, P, E, D, R$ using $\mathcal{L}^{\mathrm{WM}}$

// Phase 2: actor-critic learning in imagination

Sample posterior states $(h_t, z_t)$ from the sequence above

for imagination step $k = 0 : H-1$ do

$a_k \sim \pi(a_k \mid h_k, z_k)$

$h_{k+1} = \mathcal{U}(h_k, z_k, a_k)$

$z_{k+1} \sim P(z_{k+1} \mid h_{k+1})$ // prior state

$r_k \sim R(r_k \mid h_k, z_k)$

$v_k = V(h_k, z_k)$

end for

Form $\lambda$-returns from imagined rewards and values

Update $V$ with the critic loss and $\pi$ with the actor loss

until converged
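The imagination phase of the algorithm can be sketched as a short loop. The code below is a toy illustration with scalar states, not Dreamer's implementation: `U`, `sample_prior`, `sample_policy`, `reward_mean`, and `value` are hypothetical stand-ins for the learned recurrence, prior, policy, reward head, and critic.

```python
import math
import random

def U(h, z, a):                # toy deterministic recurrence
    return math.tanh(0.9 * h + 0.5 * z + 0.1 * a)

def sample_prior(h, rng):      # toy latent prior, conditioned on h only
    return rng.gauss(0.5 * h, 0.2)

def sample_policy(h, z, rng):  # toy policy over two actions
    return rng.choice([-1.0, 1.0])

def reward_mean(h, z):         # toy reward head (mean prediction only)
    return -abs(z)

def value(h, z):               # toy critic
    return -abs(h)

def imagine(h, z, horizon, seed=0):
    """Roll the toy model forward using the prior only (no observations)."""
    rng = random.Random(seed)
    rewards, values = [], []
    for _ in range(horizon):
        a = sample_policy(h, z, rng)
        rewards.append(reward_mean(h, z))
        values.append(value(h, z))
        h = U(h, z, a)
        z = sample_prior(h, rng)
    values.append(value(h, z))  # bootstrap value at the final imagined state
    return rewards, values

rewards, values = imagine(h=0.0, z=0.3, horizon=4)
```

The rollout starts from a state that would, in Dreamer, come from the posterior on real data. After that, no observation is consulted. The extra final entry in `values` is what a bootstrapped return needs at the end of the horizon.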

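The loss used to train the world model has the form

$$
\mathcal{L}^{\mathrm{WM}} = \mathbb{E}_{q(z_{1:T})} \left[ \sum_{t=1}^{T} \left( \beta_o \mathcal{L}^o_t + \beta_r \mathcal{L}^r_t + \beta_z \mathcal{L}^z_t \right) \right]
$$

where the $\beta$ terms are different weights for each loss, and $q$ is the posterior over the latents, given by

$$
q(z_{1:T} \mid h_0, o_{1:T}, a_{1:T-1}) = \prod_{t=1}^{T} E(z_t \mid h_t, o_t) \, \delta\big(h_t - \mathcal{U}(h_{t-1}, z_{t-1}, a_{t-1})\big)
$$

The loss terms correspond to the observation reconstruction term, the reward prediction term, and the posterior-to-prior KL penalty:

$$
\mathcal{L}^o_t = -\ln D(o_t \mid h_t, z_t), \quad
\mathcal{L}^r_t = -\ln R(r_t \mid h_t, z_t), \quad
\mathcal{L}^z_t = D_{\mathrm{KL}}\big( E(z_t \mid h_t, o_t) \,\|\, P(z_t \mid h_t) \big)
$$

where $\mathcal{L}^z_t$ is a divergence between the posterior and prior distributions over $z_t$, not a pointwise loss between samples. In addition to the world model loss, we have the following actor-critic losses on imagined rollouts:

$$
\mathcal{L}^V = \sum_{k=0}^{H-1} \big( V(h_k, z_k) - \mathrm{sg}(G^\lambda_k) \big)^2, \quad
\mathcal{L}^\pi = -\sum_{k=0}^{H-1} \mathrm{sg}\big( G^\lambda_k - V(h_k, z_k) \big) \ln \pi(a_k \mid h_k, z_k)
$$

where $\mathrm{sg}(\cdot)$ denotes the stop-gradient operator and $G^\lambda_k$ is the truncated $\lambda$-return used for imagined trajectories,

$$
G^\lambda_k = r_k + \gamma \big[ (1 - \lambda) V(h_{k+1}, z_{k+1}) + \lambda G^\lambda_{k+1} \big]
$$

with discount factor $\gamma$ and terminal condition $G^\lambda_H = V(h_H, z_H)$. Conceptually, Dreamer is a “learn a latent simulator, then improve the policy/value function entirely inside that simulator” algorithm.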
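A truncated $\lambda$-return is easiest to compute backwards from the end of the imagined rollout. The sketch below assumes the standard recursion $G_k = r_k + \gamma[(1-\lambda) v_{k+1} + \lambda G_{k+1}]$ with terminal condition $G_H = v_H$. The function name `lambda_returns` and the default values of `gamma` and `lam` are illustrative assumptions.

```python
def lambda_returns(rewards, values, gamma=0.99, lam=0.95):
    """Truncated lambda-returns for one imagined rollout.

    rewards: [r_0, ..., r_{H-1}]
    values:  [v_0, ..., v_H]  (one extra entry, used for bootstrapping)
    """
    horizon = len(rewards)
    assert len(values) == horizon + 1
    returns = [0.0] * (horizon + 1)
    returns[horizon] = values[horizon]  # terminal condition G_H = v_H
    for k in reversed(range(horizon)):
        # Blend the one-step bootstrap with the longer lambda-return.
        bootstrap = (1.0 - lam) * values[k + 1] + lam * returns[k + 1]
        returns[k] = rewards[k] + gamma * bootstrap
    return returns[:horizon]
```

Setting `lam=0` recovers the one-step TD target $r_k + \gamma v_{k+1}$. Setting `lam=1` gives the discounted sum of imagined rewards, bootstrapped with the value at the final imagined state.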

World Models Trained to Predict Other Targets

These types of world models are not necessarily able to predict all the future observations. These are often still (conditional) generative models (in that they return a distribution over potentially high dimensional outputs), but they are lossy models because they do not capture all the details of the data.

The Objective Mismatch Problem

If we can learn a sufficiently accurate world model, then solving for the optimal policy in simulation will give a policy that is close to optimal in the real world. However, a simple agent may not be able to capture the full complexity of the true environment; this is called the small agent, big world problem.

For example, if the states are images, a dynamics model with limited representational capacity may choose to focus on predicting the background pixels rather than more control-relevant features, like small moving objects, since predicting the background pixels reliably reduces the MSE more. This is due to objective mismatch, which refers to the discrepancy between the way a model is usually trained (to predict the observations) and the way its representation is used for control.
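A small numeric sketch shows why MSE can favor the wrong model. The 100-pixel frame and the two candidate predictions below are hypothetical values chosen only to illustrate the effect.

```python
def mse(pred, target):
    """Mean squared error over all pixels."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(target)

# Hypothetical frame: 99 background pixels at 0.5 and one object pixel at 1.0.
target = [0.5] * 99 + [1.0]

# Model A predicts the background exactly but misses the object entirely.
model_a = [0.5] * 100

# Model B captures the object but is off by 0.1 on every background pixel.
model_b = [0.6] * 99 + [1.0]

mse_a = mse(model_a, target)
mse_b = mse(model_b, target)
```

Here `mse_a` is about 0.0025 and `mse_b` is about 0.0099, so the reconstruction loss prefers model A even though it discards the only control-relevant pixel.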

The following subsections discuss ways to tackle this problem by learning representations and models that do not rely on predicting every detail of the observations.

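To keep the notation simple, let $h_t = (o_{1:t}, a_{1:t-1})$ denote the full observable history (here $h_t$ is the history, not Dreamer's recurrent state), let $z_t = \phi(h_t)$ be a learned latent representation, and let $M(z_{t+1} \mid z_t, a_t)$ be a latent dynamics model. The main design choice is then: what target should $\phi$ and $M$ be trained to preserve or predict?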

Observation Prediction

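We can train the latent model to predict the next observation, using a decoder $D$ that takes the latent state $z_t = \phi(h_t)$ of the history $h_t$ as input:

$$
D(o_{t+1} \mid \phi(h_t), a_t) \approx p(o_{t+1} \mid h_t, a_t)
$$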

This is the standard reconstruction-style objective used by generative world models. Its advantage is that it uses a very rich supervision signal. Its drawback is that in high-dimensional problems, such as images, the model may spend much of its capacity predicting visually dominant but control-irrelevant details.

Reward Prediction

We can instead ask the latent state $z_t = \phi(h_t)$ to preserve just enough information about the history $h_t$ to predict the immediate reward:

$$
\mathbb{E}[r_{t+1} \mid h_t, a_t] = \mathbb{E}[r_{t+1} \mid \phi(h_t), a_t]
$$

Equivalently, we learn a predictor $R$ so that $R(\phi(h_t), a_t)$ matches the reward that would have been predicted from the full history. Intuitively, this encourages $z_t$ to retain the parts of the history that matter for immediate task feedback, while discarding reward-irrelevant details.

Value Prediction

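For control, the more important target is often long-term value rather than immediate reward. The goal is to learn a representation $\phi$ for which the optimal value function of the history $h_t$ can be recovered from the latent state $z_t = \phi(h_t)$ alone:

$$
Q^*(h_t, a_t) = Q(\phi(h_t), a_t)
$$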

This is the idea behind value equivalence: if two histories map to the same latent state, then they should have the same control-relevant future. Intuitively, value prediction is more aligned with decision making than observation prediction, because it asks the representation to preserve what matters for return rather than what matters for pixel reconstruction.

Policy Prediction

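The value and reward signals may still be too sparse to shape the representation efficiently. If we have access to a stronger policy target $\pi^{\mathrm{target}}$, for example from Monte Carlo tree search (MCTS), we can train the latent state $z_t = \phi(h_t)$ to preserve the information needed to imitate that target policy:

$$
\pi(a \mid \phi(h_t)) \approx \pi^{\mathrm{target}}(a \mid h_t)
$$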

This gives denser supervision than reward alone. In MuZero, for example, the search policy produced by MCTS is used as a training target for the reactive policy network.
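As a minimal sketch of a policy-prediction loss, the code below turns root visit counts into a target distribution and measures the cross-entropy against a policy's output. The visit counts, the logits, and the helper names are hypothetical, and details such as temperature scaling of the counts are omitted.

```python
import math

def visit_counts_to_policy(counts):
    """Normalise search visit counts into a target distribution over actions."""
    total = sum(counts)
    return [c / total for c in counts]

def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    s = sum(exps)
    return [e / s for e in exps]

def cross_entropy(target, predicted):
    """H(target, predicted) = -sum_a target(a) * log predicted(a)."""
    return -sum(t * math.log(p) for t, p in zip(target, predicted) if t > 0)

target = visit_counts_to_policy([30, 10, 10])  # hypothetical visit counts
uniform = softmax([0.0, 0.0, 0.0])             # untrained policy
matched = softmax([math.log(3.0), 0.0, 0.0])   # reproduces the target exactly

loss_uniform = cross_entropy(target, uniform)
loss_matched = cross_entropy(target, matched)
```

Every action with a nonzero visit count contributes to the loss, so each search yields a full distribution as supervision rather than a single scalar.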

Self Prediction (Self Distillation)

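In sparse-reward settings, even value or policy targets may not be enough to learn quickly. A common auxiliary objective is therefore to make the latent dynamics model $M$ predict the next latent state directly, where $\phi$ is the encoder applied to the histories $h_t$ and $h_{t+1}$:

$$
\mathcal{L}^{\mathrm{SP}}(\phi, M) = \big\| M(\phi(h_t), a_t) - \phi(h_{t+1}) \big\|_2^2
$$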

This is often called a self-prediction or latent consistency loss. Intuitively, it says that the representation should evolve in a way that is predictable from the current latent state and action, even if we never reconstruct the full observation.

Avoiding Self-Prediction Collapse using Frozen Targets

A trivial way to minimize the self-prediction loss is to learn an embedding that maps everything to a constant vector, say $\phi(h) = \mathbf{0}$ for every history $h$, in which case the next latent state $z_{t+1}$ will be trivial for the dynamics model to predict. However, this is not a useful representation. This problem is called representation collapse.
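Collapse is easy to reproduce on toy data. In the sketch below, scalar observations stand in for histories, and the encoder and predictor are plain Python functions. All names and transition values are hypothetical.

```python
def self_prediction_loss(encoder, predictor, transitions):
    """Mean squared error between predicted and actual next latent (scalars)."""
    total = 0.0
    for obs, action, next_obs in transitions:
        z, z_next = encoder(obs), encoder(next_obs)
        total += (predictor(z, action) - z_next) ** 2
    return total / len(transitions)

# Hypothetical transitions: (observation, action, next observation).
transitions = [(0.0, 1.0, 1.0), (1.0, 1.0, 2.0), (2.0, -1.0, 1.0)]

# Collapsed solution: the encoder ignores its input and the predictor is constant.
collapsed_loss = self_prediction_loss(lambda o: 0.0, lambda z, a: 0.0, transitions)

# Informative (identity) encoder paired with the same constant predictor.
identity_loss = self_prediction_loss(lambda o: o, lambda z, a: 0.0, transitions)
```

The collapsed pair reaches a loss of exactly zero while carrying no information about the input, so minimizing the loss on its own cannot distinguish a useful representation from a degenerate one.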

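Fortunately, we can often reduce or avoid collapse by using a frozen target network. That is, we use the auxiliary loss

$$
\mathcal{L}^{\mathrm{EMA}}(\phi, M) = \big\| M(\phi(h_t), a_t) - \bar{\phi}(h_{t+1}) \big\|_2^2
$$

where

$$
\bar{\phi} \leftarrow \rho \, \bar{\phi} + (1 - \rho) \, \mathrm{sg}(\phi)
$$

is the exponential moving average (EMA) of the encoder weights $\phi$, with decay rate $\rho \in [0, 1)$ and stop-gradient operator $\mathrm{sg}$, so that no gradient flows into the target. If we use a frozen old copy of the weights instead, this is usually called a target network.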

This approach means that the goalposts (the target representation) evolve slowly and consistently over time, guided by the progress of the encoder and predictor. This adds stability to the training process, ensuring the target representations don’t change erratically from one step to the next, which would make the predictor’s job impossible.
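The EMA update itself is a one-line interpolation between the target weights and the current encoder weights. The sketch below uses plain lists of floats in place of network parameters. The name `ema_update` and the decay parameter `rho` are illustrative assumptions.

```python
def ema_update(target_weights, online_weights, rho=0.99):
    """Return rho * target + (1 - rho) * online, element-wise."""
    return [rho * t + (1.0 - rho) * w for t, w in zip(target_weights, online_weights)]

online = [1.0, -2.0]  # hypothetical encoder weights after some training
target = [0.0, 0.0]   # target weights lag behind

# A small rho is used here only to make the movement visible in three steps.
for _ in range(3):
    target = ema_update(target, online, rho=0.5)
```

Each update moves the target a fixed fraction of the way toward the online weights. With `rho` close to 1, as is typical in practice, the target changes only slightly per step, which is what keeps the goalposts stable.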

Sources

  • Murphy, K. (2025). Reinforcement Learning: An Overview. Chapter 4.

Footnotes

  1. A Generalist Dynamics Model for Control

  2. Dream to Control: Learning Behaviors by Latent Imagination (Dreamer)

  3. Mastering Atari with Discrete World Models (Dreamer V2)

  4. Mastering Diverse Domains through World Models (Dreamer V3)

  5. Training Agents Inside of Scalable World Models (Dreamer V4)