# Source code for embodichain.agents.rl.buffer.utils

# ----------------------------------------------------------------------------
# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ----------------------------------------------------------------------------

from __future__ import annotations

from collections.abc import Iterator

import torch
from tensordict import TensorDict

__all__ = ["iterate_minibatches", "transition_view"]


def transition_view(rollout: TensorDict, flatten: bool = False) -> TensorDict:
    """Build a transition-aligned TensorDict from a rollout.

    The shared rollout uses a uniform ``[num_envs, time + 1]`` layout. For
    transition-only fields such as ``action``, ``reward``, and ``done``, the
    final slot is reserved as padding so that all rollout fields share the
    same batch shape. This helper drops that padded slot and exposes the
    valid transition slices as a TensorDict with batch shape
    ``[num_envs, time]``.

    Args:
        rollout: Rollout TensorDict with root batch shape
            ``[num_envs, time + 1]``.
        flatten: If True, return a flattened ``[num_envs * time]`` view.

    Returns:
        TensorDict containing transition-aligned fields.
    """
    # Every transition field drops the padded final time slot. ``value`` is
    # additionally shifted forward one step to expose the bootstrap target
    # as ``next_value``.
    action = rollout["action"][:, :-1]
    num_envs, time_dim = action.shape[:2]
    td = TensorDict(
        {
            "obs": rollout["obs"][:, :-1],
            "action": action,
            "sample_log_prob": rollout["sample_log_prob"][:, :-1],
            "value": rollout["value"][:, :-1],
            "next_value": rollout["value"][:, 1:],
            "reward": rollout["reward"][:, :-1],
            "done": rollout["done"][:, :-1],
            "terminated": rollout["terminated"][:, :-1],
            "truncated": rollout["truncated"][:, :-1],
        },
        batch_size=[num_envs, time_dim],
        device=rollout.device,
    )
    # Optional fields (e.g. written by an advantage-estimation pass) are
    # carried over only when present in the rollout.
    for key in ("advantage", "return", "seq_mask", "seq_return", "entropy"):
        if key in rollout.keys():
            td[key] = rollout[key][:, :-1]
    if flatten:
        return td.reshape(num_envs * time_dim)
    return td
def iterate_minibatches(
    rollout: TensorDict, batch_size: int, device: torch.device
) -> Iterator[TensorDict]:
    """Yield shuffled minibatches from a flattened rollout.

    Args:
        rollout: Flattened rollout TensorDict with batch shape ``[total]``
            (e.g. the output of ``transition_view(..., flatten=True)``).
        batch_size: Number of transitions per minibatch. The final minibatch
            may be smaller when ``total`` is not divisible by ``batch_size``.
        device: Device on which the shuffled index permutation is allocated.

    Yields:
        TensorDict minibatches selected by a fresh random permutation.
    """
    total = rollout.batch_size[0]
    # A new permutation per call so each epoch visits a different shuffle.
    indices = torch.randperm(total, device=device)
    for start in range(0, total, batch_size):
        yield rollout[indices[start : start + batch_size]]