Source code for embodichain.agents.rl.utils.helper

# ----------------------------------------------------------------------------
# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ----------------------------------------------------------------------------

import torch
from tensordict import TensorDict


[docs] def flatten_dict_observation(obs: TensorDict) -> torch.Tensor: """ Flatten hierarchical TensorDict observations from ObservationManager. Recursively traverse nested TensorDicts, collect all tensor values, flatten each to (num_envs, -1), and concatenate in sorted key order. Args: obs: Nested TensorDict structure, e.g. TensorDict(robot=TensorDict(qpos=..., qvel=...), ...) Returns: Concatenated flat tensor of shape (num_envs, total_dim) """ obs_list = [] def _collect_tensors(d, prefix=""): """Recursively collect tensors from nested TensorDicts in sorted order.""" for key in sorted(d.keys()): full_key = f"{prefix}/{key}" if prefix else key value = d[key] if isinstance(value, TensorDict): _collect_tensors(value, full_key) elif isinstance(value, torch.Tensor): # Flatten tensor to (num_envs, -1) shape obs_list.append(value.flatten(start_dim=1)) _collect_tensors(obs) if not obs_list: raise ValueError("No tensors found in observation TensorDict") result = torch.cat(obs_list, dim=-1) return result