Rollout Buffer
This module implements the data buffer for RL training, responsible for storing trajectory data from agent-environment interactions.
Main Classes and Structure
RolloutBuffer
Used by on-policy algorithms (such as PPO and GRPO), RolloutBuffer stores a single shared rollout TensorDict that both the collector and the algorithm stages operate on. It supports multi-environment parallelism with a rollout batch shape of [N, T + 1], where N is the number of environments and T is the rollout length, and all data is preallocated on the configured device (typically a GPU).
Structure fields:
- obs: flattened observation tensor, float32, shape [N, T + 1, obs_dim]
- action: action tensor, float32, shape [N, T + 1, action_dim]
- sample_log_prob: action log probabilities, float32, shape [N, T + 1]
- value: value estimates, float32, shape [N, T + 1]
- reward: reward tensor, float32, shape [N, T + 1]
- done: done flags, bool, shape [N, T + 1]
- terminated: termination flags, bool, shape [N, T + 1]
- truncated: truncation flags, bool, shape [N, T + 1]
- Algorithm-added fields such as advantage, return, seq_mask, and seq_return, which the algorithm attaches to the rollout during optimization
The final time index is valid for obs and value, where it stores the last
observation and bootstrap value. For transition-only fields (action, reward,
done, etc.), the final slot is padding so all rollout fields can share the
same [N, T + 1] batch shape.
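A minimal sketch of this layout, assuming NumPy arrays as a stand-in for the device-resident TensorDict: the hypothetical `allocate_rollout` helper preallocates every listed field with the shared [N, T + 1] leading shape, so index T exists for all fields even though it is only meaningful for obs and value.

```python
import numpy as np


def allocate_rollout(num_envs, rollout_len, obs_dim, action_dim):
    """Hypothetical NumPy stand-in for the preallocated rollout storage.

    Every field shares the leading batch shape [N, T + 1]; index T is only
    meaningful for `obs` (last observation) and `value` (bootstrap value).
    """
    batch = (num_envs, rollout_len + 1)
    rollout = {
        "obs": np.zeros(batch + (obs_dim,), dtype=np.float32),
        "action": np.zeros(batch + (action_dim,), dtype=np.float32),
    }
    # Scalar-per-step float fields.
    for key in ("sample_log_prob", "value", "reward"):
        rollout[key] = np.zeros(batch, dtype=np.float32)
    # Boolean episode-boundary flags.
    for key in ("done", "terminated", "truncated"):
        rollout[key] = np.zeros(batch, dtype=bool)
    return rollout


rollout = allocate_rollout(num_envs=4, rollout_len=8, obs_dim=3, action_dim=2)
print(rollout["obs"].shape, rollout["done"].shape)  # (4, 9, 3) (4, 9)
```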
Main Methods
- start_rollout(): returns the shared preallocated rollout TensorDict for the collector to write into.
- add(rollout): marks the shared rollout as ready for consumption.
- get(flatten=True): returns the stored rollout after converting it to a transition view over the valid first T steps.
- transition_view(rollout, flatten=False): builds a transition-aligned view that drops the padded final slot from transition-only fields.
- iterate_minibatches(rollout, batch_size, device): shared batching utility in buffer/utils.py.
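The slicing behind a transition view can be sketched in NumPy as follows. This `transition_view` is a hypothetical stand-in, not the module's implementation: it assumes every field has leading shape [N, T + 1] and slices all of them to the first T steps, whereas the real helper may treat obs and value differently.

```python
import numpy as np


def transition_view(rollout, flatten=False):
    """Hypothetical NumPy sketch: keep only the first T valid steps.

    Slicing drops the padded final slot of transition-only fields. In this
    sketch it also drops the final `obs` and the bootstrap `value`, so read
    value[:, -1] from the full rollout before taking the view.
    """
    view = {key: arr[:, :-1] for key, arr in rollout.items()}
    if flatten:
        # Merge env and time axes: [N, T, ...] -> [N * T, ...]
        view = {key: arr.reshape((-1,) + arr.shape[2:]) for key, arr in view.items()}
    return view


N, T = 2, 3
rollout = {
    "obs": np.arange(N * (T + 1) * 4, dtype=np.float32).reshape(N, T + 1, 4),
    "reward": np.ones((N, T + 1), dtype=np.float32),
}
flat = transition_view(rollout, flatten=True)
print(flat["obs"].shape, flat["reward"].shape)  # (6, 4) (6,)
```

Flattening is env-major, so flat index i corresponds to environment i // T and step i % T.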
Usage Example
buffer = RolloutBuffer(num_envs, rollout_len, obs_dim, action_dim, device)
rollout = collector.collect(num_steps=rollout_len, rollout=buffer.start_rollout())
buffer.add(rollout)
rollout = buffer.get(flatten=False)
# The algorithm attaches advantage, return, etc. to `rollout` here
flat_rollout = transition_view(rollout, flatten=True)
for batch in iterate_minibatches(flat_rollout, batch_size, device):
    # batch["obs"], batch["action"], batch["advantage"] ...
    pass
Design and Extension
- Supports multi-environment parallel collection and is compatible with Gymnasium-style vectorized environments.
- All tensors are preallocated on the device to avoid frequent CPU-GPU copies.
- Algorithm-specific fields are attached directly to the shared rollout TensorDict during optimization.
- The shared minibatch iterator automatically shuffles the flattened transition entries for PPO/GRPO-style updates.
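A minimal NumPy sketch of the shuffled minibatch iteration described above. The `rng` parameter is a hypothetical addition for reproducibility, `device` is accepted but unused, and this version keeps the final partial batch; the shared utility in buffer/utils.py may differ on each of these points.

```python
import numpy as np


def iterate_minibatches(rollout, batch_size, device=None, rng=None):
    """Hypothetical NumPy sketch of shuffled minibatching over a flat rollout.

    `rollout` maps field names to arrays sharing leading size B (= N * T).
    `device` only mirrors the documented signature; NumPy arrays stay on CPU.
    """
    rng = np.random.default_rng() if rng is None else rng
    num_items = next(iter(rollout.values())).shape[0]
    perm = rng.permutation(num_items)  # fresh shuffle of transition indices
    for start in range(0, num_items, batch_size):
        idx = perm[start:start + batch_size]
        # Index every field with the same indices so rows stay aligned.
        yield {key: arr[idx] for key, arr in rollout.items()}


flat = {"obs": np.arange(10, dtype=np.float32).reshape(10, 1),
        "advantage": np.arange(10, dtype=np.float32)}
batches = list(iterate_minibatches(flat, batch_size=4, rng=np.random.default_rng(0)))
print([len(b["advantage"]) for b in batches])  # [4, 4, 2]
```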
Code Example
class RolloutBuffer:
    def __init__(self, num_envs, rollout_len, obs_dim, action_dim, device):
        # Preallocate rollout TensorDict
        ...
    def start_rollout(self):
        # Return shared rollout storage
        ...
    def add(self, rollout):
        # Mark rollout as full
        ...
    def get(self, flatten=True):
        # Consume rollout
        ...
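To make the skeleton's lifecycle concrete, here is a hypothetical NumPy sketch, not the module's implementation. It allocates only obs and reward, tracks fullness with an assumed `_ready` flag, and its `get()` returns the raw padded storage instead of a transition view; the point is that `start_rollout()` hands out the same preallocated storage on every call.

```python
import numpy as np


class SketchRolloutBuffer:
    """Hypothetical NumPy sketch of the start_rollout/add/get protocol.

    Only `obs` and `reward` are allocated to keep the example short.
    """

    def __init__(self, num_envs, rollout_len, obs_dim):
        batch = (num_envs, rollout_len + 1)  # shared [N, T + 1] batch shape
        self._rollout = {
            "obs": np.zeros(batch + (obs_dim,), dtype=np.float32),
            "reward": np.zeros(batch, dtype=np.float32),
        }
        self._ready = False  # assumed flag: True once a full rollout is added

    def start_rollout(self):
        # Same preallocated storage every call; the collector writes in place.
        return self._rollout

    def add(self, rollout):
        # Mark the shared rollout as full and ready for consumption.
        self._rollout = rollout
        self._ready = True

    def get(self):
        # Consume the rollout; calling get() again before add() is an error.
        if not self._ready:
            raise RuntimeError("no full rollout to consume; call add() first")
        self._ready = False
        return self._rollout


buf = SketchRolloutBuffer(num_envs=2, rollout_len=4, obs_dim=3)
storage = buf.start_rollout()
storage["reward"][:, :4] = 1.0  # the collector fills the T valid reward slots
buf.add(storage)
print(buf.get()["reward"].sum())  # 8.0
```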
Practical Tips
- The rollout buffer stores flattened RL observations; structured observations should be flattened or encoded before entering this buffer.
- value[:, -1] stores the bootstrap value of the final observation. The final slot of transition-only fields is padding and should be ignored during optimization.
- Use transition_view() plus iterate_minibatches() instead of duplicating rollout slicing logic in each algorithm.
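The role of the bootstrap value is easiest to see in an advantage computation. The sketch below assumes the algorithm uses generalized advantage estimation (GAE) with hypothetical `gamma` and `lam` parameters, which the source does not specify, and uses NumPy in place of TensorDict. It treats `done` as cutting the bootstrap, ignoring the terminated/truncated distinction a full implementation would likely respect.

```python
import numpy as np


def attach_gae(rollout, gamma=0.99, lam=0.95):
    """Hypothetical sketch (assumes GAE): attach advantage/return fields.

    Uses value[:, t + 1] as the next-state value, so value[:, T] acts as the
    bootstrap for the last transition. Slot T of reward/done is never read.
    """
    value, reward = rollout["value"], rollout["reward"]
    not_done = 1.0 - rollout["done"].astype(np.float32)
    n, t_plus_1 = value.shape
    advantage = np.zeros((n, t_plus_1), dtype=np.float32)  # slot T stays padding
    gae = np.zeros(n, dtype=np.float32)
    for t in reversed(range(t_plus_1 - 1)):
        # One-step TD error; an episode end at t masks the next-state value.
        delta = reward[:, t] + gamma * value[:, t + 1] * not_done[:, t] - value[:, t]
        gae = delta + gamma * lam * not_done[:, t] * gae
        advantage[:, t] = gae
    rollout["advantage"] = advantage
    rollout["return"] = advantage + value  # slot T is not meaningful
    return rollout


rollout = {
    "value": np.array([[0.5, 0.5, 2.0]], dtype=np.float32),   # N=1, T=2
    "reward": np.array([[1.0, 1.0, 0.0]], dtype=np.float32),  # last slot padding
    "done": np.array([[False, False, False]]),
}
attach_gae(rollout, gamma=1.0, lam=1.0)
print(rollout["advantage"].tolist())  # [[3.5, 2.5, 0.0]]
```

With gamma = lam = 1, return[0, 0] is 4.0: the two rewards plus the bootstrap value 2.0 stored in value[:, -1].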