# ----------------------------------------------------------------------------
# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ----------------------------------------------------------------------------
from __future__ import annotations
from typing import Callable
import torch
from tensordict import TensorDict
from embodichain.agents.rl.utils import dict_to_tensordict, flatten_dict_observation
from .base import BaseCollector
__all__ = ["SyncCollector"]
[docs]
class SyncCollector(BaseCollector):
    """Synchronously collect rollouts from a vectorized environment.

    The collector steps ``env`` and ``policy`` in lock-step and writes every
    transition into a caller-provided, preallocated rollout ``TensorDict``
    whose batch size is ``(num_envs, num_steps + 1)``. The extra time slot
    holds the observation, and the bootstrap value, seen after the last step.

    If the environment exposes ``set_rollout_buffer``, it is handed the
    rollout each call. In that case the collector does not write ``reward``,
    ``done``, ``terminated`` or ``truncated`` itself; presumably the
    environment fills them in (TODO confirm against env implementations).
    """

    def __init__(
        self,
        env,
        policy,
        device: torch.device,
        reset_every_rollout: bool = False,
    ) -> None:
        """Create the collector and perform an initial environment reset.

        Args:
            env: Vectorized environment.
                - Must provide ``reset()``, ``step()`` and ``num_envs``.
                - ``set_rollout_buffer`` and ``action_manager`` are used when
                  present.
            policy: Policy object.
                - Must provide ``train()``, ``get_action()`` and
                  ``get_value()``.
                - Must expose ``obs_dim`` and ``action_dim``.
            device: Device that observations and environment outputs are
                moved to.
            reset_every_rollout: If True, reset the environment at the start
                of every :meth:`collect` call instead of continuing from the
                last observed state.
        """
        self.env = env
        self.policy = policy
        self.device = device
        self.reset_every_rollout = reset_every_rollout
        # Environments that write transition fields straight into the rollout
        # advertise it through ``set_rollout_buffer``.
        self._supports_shared_rollout = hasattr(self.env, "set_rollout_buffer")
        # Latest observation; each collect() continues from here.
        self.obs_td = self._reset_env()

    @torch.no_grad()
    def collect(
        self,
        num_steps: int,
        rollout: TensorDict | None = None,
        on_step_callback: Callable[[TensorDict, dict], None] | None = None,
    ) -> TensorDict:
        """Fill ``rollout`` in place with ``num_steps`` transitions.

        Args:
            num_steps: Number of environment steps to take; must be >= 0.
            rollout: Preallocated rollout.
                - Batch size must be ``(env.num_envs, num_steps + 1)``.
                - It must contain the fields checked by
                  :meth:`_validate_rollout`.
                - Despite the ``None`` default, it is required.
            on_step_callback: Optional hook called after each step. It
                receives the rollout slice for that step and the info object
                returned by ``env.step``.

        Returns:
            The same ``rollout`` object, filled in place.

        Raises:
            ValueError: If ``rollout`` is missing, ``num_steps`` is negative,
                or the rollout batch size / field shapes do not match.
            KeyError: If a required rollout field is absent.
        """
        # Validate everything before touching the policy or environment, so
        # that an invalid call does not reset the env as a side effect.
        if rollout is None:
            raise ValueError(
                "SyncCollector.collect() requires a preallocated rollout TensorDict."
            )
        if num_steps < 0:
            raise ValueError(f"num_steps must be non-negative, got {num_steps}.")
        if tuple(rollout.batch_size) != (self.env.num_envs, num_steps + 1):
            raise ValueError(
                "Preallocated rollout batch size mismatch: "
                f"expected ({self.env.num_envs}, {num_steps + 1}), got {tuple(rollout.batch_size)}."
            )
        self._validate_rollout(rollout, num_steps)

        self.policy.train()
        if self.reset_every_rollout:
            self.obs_td = self._reset_env()
        if self._supports_shared_rollout:
            self.env.set_rollout_buffer(rollout)

        num_envs = rollout.batch_size[0]
        rollout["obs"][:, 0] = flatten_dict_observation(self.obs_td)
        for step_idx in range(num_steps):
            # The policy reads the observation directly from the rollout slot.
            step_td = TensorDict(
                {"obs": rollout["obs"][:, step_idx]},
                batch_size=[num_envs],
                device=self.device,
            )
            step_td = self.policy.get_action(step_td)
            next_obs, reward, terminated, truncated, env_info = self.env.step(
                self._to_action_dict(step_td["action"])
            )
            next_obs_td = dict_to_tensordict(next_obs, self.device)
            self._write_step(rollout=rollout, step_idx=step_idx, step_td=step_td)
            if not self._supports_shared_rollout:
                self._write_env_step(
                    rollout=rollout,
                    step_idx=step_idx,
                    reward=reward,
                    terminated=terminated,
                    truncated=truncated,
                )
            rollout["obs"][:, step_idx + 1] = flatten_dict_observation(next_obs_td)
            if on_step_callback is not None:
                on_step_callback(rollout[:, step_idx], env_info)
            self.obs_td = next_obs_td

        self._attach_final_value(rollout)
        return rollout

    def _attach_final_value(self, rollout: TensorDict) -> None:
        """Populate the bootstrap value for the final observed state.

        Relies on ``policy.get_value`` writing a ``"value"`` entry into the
        TensorDict it is given (the return value is not used).
        """
        final_td = TensorDict(
            {"obs": rollout["obs"][:, -1]},
            batch_size=[rollout.batch_size[0]],
            device=self.device,
        )
        self.policy.get_value(final_td)
        rollout["value"][:, -1] = final_td["value"]

    def _reset_env(self) -> TensorDict:
        """Reset the environment and return its observation as a TensorDict."""
        # ``env.reset()`` returns ``(obs, info)``; the info is discarded.
        obs, _ = self.env.reset()
        return dict_to_tensordict(obs, self.device)

    def _to_action_dict(self, action: torch.Tensor) -> TensorDict | torch.Tensor:
        """Convert a policy action into what ``env.step`` expects.

        If the environment has an ``action_manager``, that manager does the
        conversion; otherwise the raw action tensor is passed through.
        """
        action_manager = getattr(self.env, "action_manager", None)
        if action_manager is None:
            return action
        return action_manager.convert_policy_action_to_env_action(action)

    def _write_step(
        self,
        rollout: TensorDict,
        step_idx: int,
        step_td: TensorDict,
    ) -> None:
        """Write policy-side fields for one transition into the rollout."""
        rollout["action"][:, step_idx] = step_td["action"]
        rollout["sample_log_prob"][:, step_idx] = step_td["sample_log_prob"]
        rollout["value"][:, step_idx] = step_td["value"]

    def _write_env_step(
        self,
        rollout: TensorDict,
        step_idx: int,
        reward: torch.Tensor,
        terminated: torch.Tensor,
        truncated: torch.Tensor,
    ) -> None:
        """Populate transition-side fields when the environment does not own the rollout."""
        # Move flags to the target device first so they can be combined even
        # if the env returned them on different devices.
        terminated = terminated.to(self.device)
        truncated = truncated.to(self.device)
        rollout["reward"][:, step_idx] = reward.to(self.device)
        # logical_or (rather than bitwise ``|``) also accepts non-bool flag
        # dtypes such as float 0/1 tensors.
        rollout["done"][:, step_idx] = torch.logical_or(terminated, truncated)
        rollout["terminated"][:, step_idx] = terminated
        rollout["truncated"][:, step_idx] = truncated

    def _validate_rollout(self, rollout: TensorDict, num_steps: int) -> None:
        """Validate the rollout layout expected by the collector.

        Raises:
            KeyError: If any required field is missing.
            ValueError: If a field's shape differs from the expected one.
        """
        num_envs = self.env.num_envs
        horizon = num_steps + 1
        expected_shapes = {
            "obs": (num_envs, horizon, self.policy.obs_dim),
            "action": (num_envs, horizon, self.policy.action_dim),
            "sample_log_prob": (num_envs, horizon),
            "value": (num_envs, horizon),
            "reward": (num_envs, horizon),
            "done": (num_envs, horizon),
            "terminated": (num_envs, horizon),
            "truncated": (num_envs, horizon),
        }
        # Report every missing field at once instead of failing on the first.
        present = rollout.keys()
        missing = [key for key in expected_shapes if key not in present]
        if missing:
            raise KeyError(
                f"Preallocated rollout is missing required fields: {missing}."
            )
        for key, expected_shape in expected_shapes.items():
            actual_shape = tuple(rollout[key].shape)
            if actual_shape != expected_shape:
                raise ValueError(
                    f"Preallocated rollout field '{key}' shape mismatch: "
                    f"expected {expected_shape}, got {actual_shape}."
                )