# ----------------------------------------------------------------------------
# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ----------------------------------------------------------------------------
from __future__ import annotations
from collections.abc import Iterator
import torch
from tensordict import TensorDict
__all__ = ["iterate_minibatches", "transition_view"]
def transition_view(rollout: TensorDict, flatten: bool = False) -> TensorDict:
    """Expose the valid transition slices of a padded rollout.

    The shared rollout buffer stores every field with a uniform
    `[num_envs, time + 1]` batch shape; for transition-only fields such as
    `action`, `reward`, and `done` the final time slot is padding. This
    function strips that padding and pairs each step's `value` with the
    following step's `value` (as `next_value`), producing a TensorDict with
    batch shape `[num_envs, time]`.

    Args:
        rollout: Rollout TensorDict with root batch shape `[num_envs, time + 1]`.
        flatten: If True, return a flattened `[num_envs * time]` view.

    Returns:
        TensorDict containing transition-aligned fields.
    """

    def trimmed(name: str):
        # Drop the padded final time slot of a rollout field.
        return rollout[name][:, :-1]

    actions = trimmed("action")
    n_envs, horizon = actions.shape[:2]

    fields = {
        "obs": trimmed("obs"),
        "action": actions,
        "sample_log_prob": trimmed("sample_log_prob"),
        "value": trimmed("value"),
        # Value at t+1, aligned with the transition taken at t.
        "next_value": rollout["value"][:, 1:],
        "reward": trimmed("reward"),
        "done": trimmed("done"),
        "terminated": trimmed("terminated"),
        "truncated": trimmed("truncated"),
    }
    view = TensorDict(fields, batch_size=[n_envs, horizon], device=rollout.device)

    # Optional fields are copied over only when the rollout provides them.
    for key in ("advantage", "return", "seq_mask", "seq_return", "entropy"):
        if key in rollout.keys():
            view[key] = trimmed(key)

    return view.reshape(n_envs * horizon) if flatten else view
def iterate_minibatches(
    rollout: TensorDict, batch_size: int, device: torch.device
) -> Iterator[TensorDict]:
    """Yield shuffled minibatches from a flattened rollout.

    A single random permutation of all sample indices is drawn, then sliced
    into contiguous chunks of `batch_size`; the trailing chunk may be smaller
    when the sample count is not an exact multiple of `batch_size`.
    """
    sample_count = rollout.batch_size[0]
    order = torch.randperm(sample_count, device=device)
    chunks = (
        order[offset : offset + batch_size]
        for offset in range(0, sample_count, batch_size)
    )
    for chunk in chunks:
        yield rollout[chunk]