Source code for embodichain.agents.rl.utils.helper
# ----------------------------------------------------------------------------
# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ----------------------------------------------------------------------------
from __future__ import annotations
from collections.abc import Mapping
from typing import Any
import torch
from tensordict import TensorDict
__all__ = [
"dict_to_tensordict",
"flatten_dict_observation",
]
[docs]
def flatten_dict_observation(obs: TensorDict) -> torch.Tensor:
"""Flatten a hierarchical observation TensorDict into a 2D tensor.
Args:
obs: Observation TensorDict with batch dimension `[num_envs]`.
Returns:
Flattened observation tensor of shape `[num_envs, obs_dim]`.
"""
obs_list: list[torch.Tensor] = []
def _collect_tensors(data: TensorDict) -> None:
for key in sorted(data.keys()):
value = data[key]
if isinstance(value, TensorDict):
_collect_tensors(value)
elif isinstance(value, torch.Tensor):
obs_list.append(value.flatten(start_dim=1))
_collect_tensors(obs)
if not obs_list:
raise ValueError("No tensors found in observation TensorDict.")
return torch.cat(obs_list, dim=-1)
[docs]
def dict_to_tensordict(
obs_dict: TensorDict | Mapping[str, Any], device: torch.device | str
) -> TensorDict:
"""Convert an environment observation mapping into a TensorDict.
Args:
obs_dict: Environment observation returned by `reset()` or `step()`.
device: Target device for the resulting TensorDict.
Returns:
Observation TensorDict moved onto the target device.
"""
if isinstance(obs_dict, TensorDict):
return obs_dict.to(device)
if not isinstance(obs_dict, Mapping):
raise TypeError(
f"Expected observation mapping or TensorDict, got {type(obs_dict)!r}."
)
return TensorDict.from_dict(dict(obs_dict), device=device)