Source code for embodichain.lab.gym.envs.managers.observations

# ----------------------------------------------------------------------------
# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ----------------------------------------------------------------------------

from __future__ import annotations

import torch
import os

from tensordict import TensorDict
from typing import TYPE_CHECKING, Literal, Union, List, Dict, Sequence

from embodichain.lab.sim.objects import RigidObject, Articulation, Robot
from embodichain.lab.sim.sensors import Camera, StereoCamera
from embodichain.lab.sim.types import EnvObs
from embodichain.lab.gym.envs.managers.cfg import SceneEntityCfg
from embodichain.lab.gym.envs.managers.events import resolve_dict
from embodichain.lab.gym.envs.managers import Functor, FunctorCfg
from embodichain.utils import logger
from embodichain.utils.math import quat_from_matrix

if TYPE_CHECKING:
    from embodichain.lab.gym.envs import EmbodiedEnv


[docs] def get_object_pose( env: EmbodiedEnv, obs: EnvObs, entity_cfg: SceneEntityCfg, to_matrix: bool = True, ) -> torch.Tensor: """Get the arena poses of the objects in the environment. If the object with the specified UID does not exist in the environment, a zero tensor will be returned. Args: env: The environment instance. obs: The observation dictionary. entity_cfg: The configuration of the scene entity. to_matrix: Whether to return the pose as a 4x4 transformation matrix. If False, returns as (position, quaternion). Returns: A tensor of shape (num_envs, 7) or (num_envs, 4, 4) representing the world poses of the objects. """ if entity_cfg.uid not in env.sim.asset_uids: if to_matrix: return torch.zeros( (env.num_envs, 4, 4), dtype=torch.float32, device=env.device ) else: return torch.zeros( (env.num_envs, 7), dtype=torch.float32, device=env.device ) obj = env.sim.get_asset(entity_cfg.uid) return obj.get_local_pose(to_matrix=to_matrix)
[docs] def get_rigid_object_pose( env: EmbodiedEnv, obs: EnvObs, entity_cfg: SceneEntityCfg, to_matrix: bool = True, ) -> torch.Tensor: """Get the world poses of the rigid objects in the environment. If the rigid object with the specified UID does not exist in the environment, a zero tensor will be returned. Note: This method will be deprecated in the future and replaced by `get_object_pose` as the distinction between rigid objects and general objects is being removed. Please use `get_object_pose` instead when possible. Args: env: The environment instance. obs: The observation dictionary. entity_cfg: The configuration of the scene entity. to_matrix: Whether to return the pose as a 4x4 transformation matrix. If False, returns as (position, quaternion). Returns: A tensor of shape (num_envs, 7) or (num_envs, 4, 4) representing the world poses of the rigid objects. """ if entity_cfg.uid not in env.sim.get_rigid_object_uid_list(): if to_matrix: return torch.zeros( (env.num_envs, 4, 4), dtype=torch.float32, device=env.device ) else: return torch.zeros( (env.num_envs, 7), dtype=torch.float32, device=env.device ) obj = env.sim.get_rigid_object(entity_cfg.uid) return obj.get_local_pose(to_matrix=to_matrix)
[docs] def get_object_body_scale( env: EmbodiedEnv, obs: EnvObs, entity_cfg: SceneEntityCfg, ) -> torch.Tensor: """Get the body scale of the objects in the environment. If the object with the specified UID does not exist in the environment, a zero tensor will be returned. Args: env: The environment instance. obs: The observation dictionary. entity_cfg: The configuration of the scene entity. Returns: A tensor of shape (num_envs, 3) representing the body scale of the objects. """ if entity_cfg.uid not in env.sim.asset_uids: return torch.zeros((env.num_envs, 3), dtype=torch.float32, device=env.device) obj = env.sim.get_asset(entity_cfg.uid) if isinstance(obj, RigidObject) is False: logger.log_error( f"Object with UID '{entity_cfg.uid}' is not a RigidObject. Currently only support getting body scale for RigidObject, please check again." ) return obj.get_body_scale()
[docs] def get_object_uid( env: EmbodiedEnv, obs: EnvObs, entity_cfg: SceneEntityCfg, ) -> torch.Tensor: """Get the user IDs of the objects in the environment. If the object with the specified UID does not exist in the environment, a zero tensor will be returned. Note: - If asset is RigidObject, the user IDs is shaped as (num_envs,) - If asset is Articulation or Robot, the user IDs is shaped as (num_envs, num_links) and ordered by link_names in the configuration. Args: env: The environment instance. obs: The observation dictionary. entity_cfg: The configuration of the scene entity. Returns: A tensor of shape (num_envs,) representing the user IDs of the objects. """ if entity_cfg.uid not in env.sim.asset_uids: return torch.zeros((env.num_envs,), dtype=torch.int32, device=env.device) obj = env.sim.get_asset(entity_cfg.uid) if isinstance(obj, (Articulation, Robot, RigidObject)) is False: logger.log_error( f"Object with UID '{entity_cfg.uid}' is not an Articulation, Robot or RigidObject. Currently only support getting user IDs for Articulation, Robot and RigidObject, please check again." ) return obj.get_user_ids()
[docs] def get_rigid_object_velocity( env: EmbodiedEnv, obs: EnvObs, entity_cfg: SceneEntityCfg, ) -> torch.Tensor: """Get the world velocities of the rigid objects in the environment. If the rigid object with the specified UID does not exist in the environment, a zero tensor will be returned. Args: env: The environment instance. obs: The observation dictionary. entity_cfg: The configuration of the scene entity. Returns: A tensor of shape (num_envs, 6) representing the linear and angular velocities of the rigid objects. """ if entity_cfg.uid not in env.sim.get_rigid_object_uid_list(): return torch.zeros((env.num_envs, 6), dtype=torch.float32, device=env.device) obj = env.sim.get_rigid_object(entity_cfg.uid) return obj.body_data.vel
[docs] def normalize_robot_joint_data( env: EmbodiedEnv, data: torch.Tensor, joint_ids: Sequence[int], limit: Literal["qpos_limits", "qvel_limits"] = "qpos_limits", range: Sequence[float] = [0.0, 1.0], ) -> torch.Tensor: """Normalize the robot joint positions to the range of [0, 1] based on the joint limits. Args: env: The environment instance. obs: The observation dictionary. joint_ids: The indices of the joints to be normalized. limit: The type of joint limits to be used for normalization. Options are: - `qpos_limits`: Use the joint position limits for normalization. - `qvel_limits`: Use the joint velocity limits for normalization. range: The range to normalize the joint data to. Defaults to [0.0, 1.0]. """ robot = env.robot joint_ids_set = torch.as_tensor(env.active_joint_ids)[joint_ids] # shape of target_limits: (num_envs, len(joint_ids), 2) target_limits = getattr(robot.body_data, limit)[:, joint_ids_set, :] low = target_limits[:, :, 0] high = target_limits[:, :, 1] # normalize the joint data to the specified range data[:, joint_ids] = (data[:, joint_ids] - low) / (high - low) * ( range[1] - range[0] ) + range[0] return data
[docs] def get_sensor_pose_in_robot_frame( env: EmbodiedEnv, obs: EnvObs, entity_cfg: SceneEntityCfg, robot_uid: str | None = None, is_right: bool = False, ) -> torch.Tensor: """Get the pose of a sensor in the robot's base coordinate frame. Args: env: The environment instance. obs: The observation dictionary. entity_cfg: The configuration of the sensor entity. robot_uid: The uid of the robot. If None, uses the default robot from env. is_right: Whether to return the right camera intrinsics for stereo cameras. Defaults to False (left camera). Ignored for monocular cameras. Returns: A tensor of shape (num_envs, 4, 4) representing the sensor pose in robot coordinates as a transformation matrix. """ # Get robot base pose in world frame robot = env.sim.get_robot(robot_uid) if robot_uid else env.robot robot_pose = robot.get_local_pose(to_matrix=True) robot_pose_inv = torch.linalg.inv(robot_pose) # Get sensor pose in world frame sensor: Union[Camera, StereoCamera] = env.sim.get_sensor(entity_cfg.uid) if sensor is None: logger.log_error( f"Sensor with UID '{entity_cfg.uid}' not found in the simulation." ) if isinstance(sensor, StereoCamera): cam_left_pose, cam_right_pose = sensor.get_left_right_arena_pose() if is_right: cam_in_robot = torch.matmul(robot_pose_inv, cam_right_pose) else: cam_in_robot = torch.matmul(robot_pose_inv, cam_left_pose) else: cam_pose = sensor.get_arena_pose(to_matrix=True) # Compute sensor pose in robot coordinate frame: T_robot_cam = inv(T_world_robot) @ T_world_cam cam_in_robot = torch.matmul(robot_pose_inv, cam_pose) return cam_in_robot
[docs] def get_sensor_intrinsics( env: EmbodiedEnv, obs: EnvObs, entity_cfg: SceneEntityCfg, is_right: bool = False, ) -> torch.Tensor: """Get the intrinsic matrix of a sensor (camera). Args: env: The environment instance. obs: The observation dictionary. entity_cfg: The configuration of the sensor entity. is_right: Whether to return the right camera intrinsics for stereo cameras. Defaults to False (left camera). Ignored for monocular cameras. Returns: A tensor of shape (num_envs, 3, 3) representing the camera intrinsics. """ sensor = env.sim.get_sensor(entity_cfg.uid) if sensor is None: logger.log_error( f"Sensor with UID '{entity_cfg.uid}' not found in the simulation." ) if isinstance(sensor, StereoCamera): left_intrinsics, right_intrinsics = sensor.get_intrinsics() return right_intrinsics if is_right else left_intrinsics elif isinstance(sensor, Camera): return sensor.get_intrinsics() # (num_envs, 3, 3) else: logger.log_error(f"Sensor '{entity_cfg.uid}' is not Camera or StereoCamera.") return torch.zeros((env.num_envs, 3, 3), dtype=torch.float32)
[docs] def compute_semantic_mask( env: EmbodiedEnv, obs: EnvObs, entity_cfg: SceneEntityCfg, foreground_uids: Sequence[str], is_right: bool = False, ) -> torch.Tensor: """Compute the semantic mask for the specified scene entity. Note: The semantic mask is defined as (B, H, W, len(SemanticMask)) where these channels represents: - background channel: the instance id of the background is set to 1 (0 if not background) - foreground channel: the instance id of the foreground objects is set to 1 (0 if not foreground) - robot left-side channel: the instance id of the robot left-side is set to 1 - robot right-side channel: the instance id of the robot right-side is set to 1 Args: env: The environment instance. obs: The observation dictionary. entity_cfg: The configuration of the scene entity. foreground_uids: The list of uids for the foreground objects. is_right: Whether to use the right camera for stereo cameras. Default is False. Only applicable if the sensor is a StereoCamera. Returns: A tensor of shape (num_envs, height, width) representing the semantic mask. """ from embodichain.data.enum import SemanticMask sensor: Union[Camera, StereoCamera] = env.sim.get_sensor(entity_cfg.uid) if sensor.cfg.enable_mask is False: logger.log_error( f"Sensor '{entity_cfg.uid}' does not have mask enabled. Please enable the mask in the sensor configuration." ) if isinstance(sensor, StereoCamera) and is_right: mask = obs["sensor"][entity_cfg.uid]["mask_right"] else: mask = obs["sensor"][entity_cfg.uid]["mask"] left_link_indices = [ i for i, link_name in enumerate(env.robot.link_names) if link_name.startswith("left_") ] left_robot_uids = env.robot.user_ids[:, left_link_indices] right_link_indices = [ i for i, link_name in enumerate(env.robot.link_names) if link_name.startswith("right_") ] right_robot_uids = env.robot.user_ids[:, right_link_indices] mask_exp = mask.unsqueeze(-1) left_robot_uids_exp = left_robot_uids.unsqueeze_(1).unsqueeze_(1) right_robot_uids_exp = right_robot_uids.unsqueeze_(1).unsqueeze_(1) left_robot_mask = (mask_exp == left_robot_uids_exp).any(-1).squeeze_(-1) right_robot_mask = (mask_exp == right_robot_uids_exp).any(-1).squeeze_(-1) asset_uids = env.sim.asset_uids foreground_assets = [ env.sim.get_asset(uid) for uid in foreground_uids if uid in asset_uids ] # cat assets uid (num_envs, n) into dim 1 foreground_uids = torch.cat( [ ( asset.get_user_ids().unsqueeze(1) if asset.get_user_ids().dim() == 1 else asset.get_user_ids() ) for asset in foreground_assets ], dim=1, ) foreground_uids_exp = foreground_uids.unsqueeze_(1).unsqueeze_(1) foreground_mask = (mask_exp == foreground_uids_exp).any(-1).squeeze_(-1) background_mask = ~(left_robot_mask | right_robot_mask | foreground_mask).squeeze_( -1 ) masks = [None, None, None, None] masks_ids = [member.value for member in SemanticMask] assert len(masks) == len( masks_ids ), "Different length of mask slots and SemanticMask Enum {}.".format(masks_ids) mask_id_to_label = { SemanticMask.BACKGROUND.value: background_mask, SemanticMask.FOREGROUND.value: foreground_mask, SemanticMask.ROBOT_LEFT.value: left_robot_mask, SemanticMask.ROBOT_RIGHT.value: right_robot_mask, } for mask_id in masks_ids: masks[mask_id] = mask_id_to_label[mask_id] return torch.stack(masks, dim=-1)
[docs] def get_robot_eef_pose( env: "EmbodiedEnv", obs: EnvObs, part_name: str | None = None, position_only: bool = False, ) -> torch.Tensor: """Get robot end-effector pose using forward kinematics. Args: env: The environment instance. obs: The observation dictionary. part_name: The name of the control part. If None, uses default part. position_only: If True, returns only position (3D). If False, returns full pose (4x4 matrix). Returns: A tensor of shape (num_envs, 3) if position_only=True, or (num_envs, 4, 4) otherwise. """ robot = env.robot if part_name is not None: joint_ids = robot.get_joint_ids(part_name) qpos = robot.body_data.qpos[:, joint_ids] ee_pose = robot.compute_fk(name=part_name, qpos=qpos, to_matrix=True) else: qpos = robot.get_qpos() ee_pose = robot.compute_fk(qpos=qpos, to_matrix=True) if position_only: return ee_pose[:, :3, 3] return ee_pose
[docs] def target_position( env: "EmbodiedEnv", obs: EnvObs, target_pose_key: str = "goal_pose", ) -> torch.Tensor: """Get virtual target position from env state. Reads target pose from env.{target_pose_key} (set by randomize_target_pose event). Returns zeros if not yet initialized (e.g., during env initialization before reset). Args: env: The environment instance obs: Observation dict (unused, for API compatibility) target_pose_key: Key for target pose in env (default: "goal_pose") Returns: Target position tensor of shape (num_envs, 3). Returns zeros if target_pose_key is not found (e.g., before first reset). """ if not hasattr(env, target_pose_key): # Return zeros during initialization (before reset event triggers) return torch.zeros(env.num_envs, 3, device=env.device) target_poses = getattr(env, target_pose_key) if target_poses.dim() == 2: # (num_envs, 3) return target_poses else: # (num_envs, 4, 4) return target_poses[:, :3, 3]
[docs] class compute_exteroception(Functor): """Compute the exteroception for the observation space. The exteroception is currently defined as a set of keypoints around a reference pose, which are prjected from 3D space to 2D image plane. The reference pose can derive from the following sources: - Pose from robot control part (e.g., end-effector, usually tcp pose) - Object affordance pose (e.g., handle pose of a mug or a pick pose of a cube) Therefore, the exteroception are defined in the camera-like sensor, for example. descriptor = { "cam_high": [ { "type": "affordance", "obj_uid": "obj1", "key": "grasp_pose", "is_arena_coord": True }, { "type": "affordance", "obj_uid": "obj1", "key": "place_pose", }, { "type": "robot", "control_part": "left_arm", }, { "type": "robot", "control_part": "right_arm", } ], ... } Explanation of the parameters: - The key of the dictionary is the sensor uid. - The value is another dictionary, where the key is the source type, and the value is a dictionary of parameters. - For `affordance` source type, the parameters are: - `obj_uid`: The uid of the object to get the affordance pose from. - `key`: The key of the affordance pose in the affordance data. - `is_arena_coord`: Whether the affordance pose is in the arena coordinate system. Default is False. - For `robot` source type, the parameters are: - `control_part`: The control part of the robot to get the pose from. """
[docs] def __init__( self, cfg: FunctorCfg, env: EmbodiedEnv, ): super().__init__(cfg, env) if self._env.num_envs != 1: logger.log_error( f"Exteroception functor only supported env with 'num_envs=1' but got 'num_envs={self._env.num_envs}'. Please check again." ) self._valid_source = ["robot", "affordance"]
[docs] @staticmethod def shift_pose(pose: torch.Tensor, axis: int, shift: float) -> torch.Tensor: """Shift the pose along the specified axis by the given amount. Args: pose: The original pose tensor of shape (B, 4, 4). axis: The axis along which to shift (0 for x, 1 for y, 2 for z). shift: The amount to shift along the specified axis. """ shift_pose = torch.linalg.inv(pose) shift_pose[:, axis, -1] += shift shift_pose = torch.linalg.inv(shift_pose) return shift_pose
[docs] @staticmethod def expand_pose( pose: torch.Tensor, x_interval: float, y_interval: float, kpnts_number: int, ref_pose: torch.Tensor = None, ) -> torch.Tensor: """Expand pose with keypoints along x and y axes. Args: pose: The original pose tensor of shape (B, 4, 4). x_interval: The interval for expanding along x-axis. y_interval: The interval for expanding along y-axis. kpnts_number: Number of keypoints to generate for each axis. ref_pose: Reference pose tensor of shape (B, 4, 4). If None, uses identity matrix. Returns: Expanded poses tensor of shape (B, 1 + 2*kpnts_number, 4, 4). """ batch_size = pose.shape[0] device = pose.device # Create default reference pose if not provided if ref_pose is None: ref_pose = ( torch.eye(4, device=device).unsqueeze_(0).repeat(batch_size, 1, 1) ) # Start with the original pose transformed by ref_pose ret = [ref_pose @ pose] # Generate x-axis offsets and expand poses # TODO: only support 1 env xoffset = torch.linspace(-x_interval, x_interval, kpnts_number, device=device) for x_shift in xoffset: shifted_pose = compute_exteroception.shift_pose(pose, 0, x_shift.item()) x_expanded = ref_pose @ shifted_pose ret.append(x_expanded) # Generate y-axis offsets and expand poses # TODO: only support 1 env yoffset = torch.linspace(-y_interval, y_interval, kpnts_number, device=device) for y_shift in yoffset: shifted_pose = compute_exteroception.shift_pose(pose, 1, y_shift.item()) y_expanded = ref_pose @ shifted_pose ret.append(y_expanded) # Stack all poses along a new dimension return torch.stack(ret, dim=1)
@staticmethod def _project_3d_to_2d( cam_pose: torch.Tensor, intrinsics: torch.Tensor, height: int, width: int, target_poses: torch.Tensor, normalize: bool = True, ) -> torch.Tensor: """Project 3D poses to 2D image plane. Args: cam_pose: Camera pose of in arena frame of shape (B, 4, 4). intrinsics: Camera intrinsic matrix of shape (B, 3, 3). height: Image height. width: Image width. target_poses: 3D poses of shape (B, N, 4, 4). normalize: Whether to normalize the projected points to [0, 1] range. Returns: Projected 2D points of shape (B, N, 2). """ batch_size, num_poses = target_poses.shape[:2] # Convert to opencv coordinate system cam_pose[:, :3, 1] = -cam_pose[:, :3, 1] cam_pose[:, :3, 2] = -cam_pose[:, :3, 2] # Expand cam_pose_inv and intrinsics to match target_poses batch dimension cam_pose_inv = torch.linalg.inv(cam_pose) # (B, 4, 4) cam_pose_inv_expanded = cam_pose_inv.unsqueeze(1).expand( -1, num_poses, -1, -1 ) # (B, N, 4, 4) cam_pose_inv_reshaped = cam_pose_inv_expanded.reshape(-1, 4, 4) # (B*N, 4, 4) intrinsics_expanded = intrinsics.unsqueeze(1).expand( -1, num_poses, -1, -1 ) # (B, N, 3, 3) intrinsics_reshaped = intrinsics_expanded.reshape(-1, 3, 3) # (B*N, 3, 3) # Reshape target_poses to (B*N, 4, 4) target_poses_reshaped = target_poses.reshape(-1, 4, 4) # (B*N, 4, 4) # Transform 3D points to camera coordinates in parallel # Extract translation part (position) from target poses: (B*N, 4, 1) target_positions = target_poses_reshaped[:, :, 3:4] # (B*N, 4, 1) # Transform to camera coordinates: (B*N, 4, 1) cam_positions = cam_pose_inv_reshaped.bmm(target_positions) # (B*N, 4, 1) cam_positions_3d = cam_positions[:, :3, 0] # (B*N, 3) # Project to 2D using intrinsics in parallel # Add small epsilon to avoid division by zero eps = 1e-8 z_safe = torch.clamp(cam_positions_3d[:, 2], min=eps) # (B*N,) # Normalize by depth normalized_points = cam_positions_3d[:, :2] / z_safe.unsqueeze(-1) # (B*N, 2) # Convert to homogeneous coordinates and apply intrinsics normalized_homogeneous = torch.cat( [normalized_points, torch.ones_like(normalized_points[:, :1])], dim=-1 ) # (B*N, 3) pixel_coords = intrinsics_reshaped.bmm( normalized_homogeneous.unsqueeze(-1) ).squeeze( -1 ) # (B*N, 3) # Extract 2D coordinates points_2d_flat = pixel_coords[:, :2] # (B*N, 2) # Reshape back to (B, N, 2) points_2d = points_2d_flat.reshape(batch_size, num_poses, 2) # clip to range [0, width] and [0, height] points_2d[..., 0] = torch.clamp(points_2d[..., 0], 0, width - 1) points_2d[..., 1] = torch.clamp(points_2d[..., 1], 0, height - 1) if normalize: # Normalize to [0, 1] range points_2d[..., 0] /= width points_2d[..., 1] /= height return points_2d def _get_gripper_ratio( self, control_part: str, gripper_qpos: torch.Tensor | None = None ): robot: Robot = self._env.robot gripper_max_limit = robot.body_data.qpos_limits[ :, robot.get_joint_ids(control_part) ][:, 0, 1] if gripper_qpos is None: gripper_qpos = robot.get_qpos()[:, robot.get_joint_ids(control_part)][:, 0] return gripper_qpos / gripper_max_limit def _get_robot_exteroception( self, control_part: str | None = None, x_interval: float = 0.02, y_interval: float = 0.02, kpnts_number: int = 12, offset: list | torch.Tensor | None = None, follow_eef: bool = False, ) -> torch.Tensor: """Get the robot exteroception poses. Args: control_part: The part of the robot to use as reference. If None, uses the base. x_interval: The interval for expanding along x-axis. y_interval: The interval for expanding along y-axis. kpnts_number: Number of keypoints to generate for each axis. offset: Intrinsic offset that need to be substracted. follow_eef: Whether to follow the gripper or not. Returns: A tensor of shape (num_envs, 1 + 2*kpnts_number, 4, 4) representing the exteroception poses. """ robot: Robot = self._env.robot if control_part is not None: current_qpos = robot.get_qpos()[:, robot.get_joint_ids(control_part)] robot_pose = robot.compute_fk( current_qpos, name=control_part, to_matrix=True ) if follow_eef: gripper_ratio = self._get_gripper_ratio( control_part.replace("_arm", "_eef") ) # TODO: "_eef" hardcode # TODO: only support 1 env y_interval = (y_interval * gripper_ratio)[0].item() else: logger.log_error("Not supported Robot without control part yet.") if offset is not None: offset = torch.as_tensor( offset, dtype=torch.float32, device=self._env.device ) if (offset.ndim > 2) or (offset.shape[-1] != 3): logger.log_error( f"Only (N, 3) shaped xyz-intrinsic offset supported, got shape {offset.shape}" ) elif offset.ndim == 1: offset = offset[None] # TODO: This operation may be slow when large scale Parallelization, but when small (num_envs=1) this operation is faster robot_pose[:, :3, 3] = robot_pose[:, :3, 3] - torch.einsum( "bij,bj->bi", robot_pose[:, :3, :3], offset ) return compute_exteroception.expand_pose( robot_pose, x_interval, y_interval, kpnts_number, ) def _get_object_exteroception( self, uid: str, affordance_key: str, x_interval: float = 0.02, y_interval: float = 0.02, kpnts_number: int = 12, is_arena_coord: bool = False, follow_eef: str | None = None, ) -> torch.Tensor: """Get the rigid object exteroception poses. Args: uid: The UID of the object. affordance_key: The key of the affordance to use for the object pose. x_interval: The interval for expanding along x-axis. y_interval: The interval for expanding along y-axis. kpnts_number: Number of keypoints to generate for each axis. is_arena_coord: Whether to use the arena coordinate system. Default is False. Returns: A tensor of shape (num_envs, 1 + 2*kpnts_number, 4, 4) representing the exteroception poses. """ obj: RigidObject = self._env.sim.get_rigid_object(uid) if obj is None: logger.log_error( f"Rigid object with UID '{uid}' not found in the simulation." ) if hasattr(self._env, "affordance_datas") is False: logger.log_error( "Affordance data is not available in the environment. We cannot compute object exteroception." ) if affordance_key not in self._env.affordance_datas: # TODO: should this default behavior be warned? # logger.log_warning( # f"Affordance key '{affordance_key}' not found in the affordance data, using identity pose.." # ) pass affordance_pose = torch.as_tensor( self._env.affordance_datas.get( affordance_key, torch.eye(4).repeat(self._env.num_envs, 1, 1) ), dtype=torch.float32, ) if affordance_pose.ndim < 3: affordance_pose = affordance_pose.repeat(self._env.num_envs, 1, 1) ref_pose = None if is_arena_coord else obj.get_local_pose(to_matrix=True) if follow_eef is not None: gripper_ratio = self._get_gripper_ratio(control_part=follow_eef) # TODO: only support 1 env y_interval = (y_interval * gripper_ratio)[0].item() return compute_exteroception.expand_pose( affordance_pose, x_interval, y_interval, kpnts_number, ref_pose=ref_pose, ) def _check_source_valid(self, source: str) -> bool: if source not in self._valid_source: logger.log_error( f"Invalid exteroception source '{source}'. Supported sources are {self._valid_source}." ) return True def __call__( self, env: EmbodiedEnv, obs: EnvObs, descriptor: Dict[str, Dict[str, str]], x_interval: float = 0.02, y_interval: float = 0.02, kpnts_number: int = 12, groups: int = 6, ) -> Dict[str, Dict[str, torch.Tensor]]: """Compute the exteroception poses based on the asset type. Args: descriptor: The observation dictionary. Returns: A dictionary containing the exteroception poses with key 'exteroception'. """ exteroception = {} descriptor = resolve_dict(self._env, descriptor) for sensor_uid, sources in descriptor.items(): sensor: Union[Camera, StereoCamera] = self._env.sim.get_sensor(sensor_uid) if sensor is None: logger.log_error( f"Sensor with UID '{sensor_uid}' not found in the simulation." ) if not isinstance(sensor, (Camera, StereoCamera)): logger.log_error( f"Sensor with UID '{sensor_uid}' is not a Camera or StereoCamera." ) height, width = sensor.cfg.height, sensor.cfg.width exteroception[sensor_uid] = {} taget_pose_list = [] for source in sources: source_type = source["type"] self._check_source_valid(source_type) if source_type == "robot": target_pose = self._get_robot_exteroception( control_part=source["control_part"], x_interval=x_interval, y_interval=y_interval, kpnts_number=kpnts_number, offset=source.get("offset", None), follow_eef=source.get("follow_eef", False), ) elif source_type == "affordance": target_pose = self._get_object_exteroception( uid=source["obj_uid"], affordance_key=source["key"], x_interval=x_interval, y_interval=y_interval, kpnts_number=kpnts_number, is_arena_coord=source["is_arena_coord"], follow_eef=source.get("follow_eef", None), ) else: logger.log_error( f"Unsupported exteroception source '{source_type}'. Supported sources are 'robot' and 'affordance." ) taget_pose_list.append(target_pose) target_poses = torch.cat(taget_pose_list, dim=1) if target_poses.shape[1] / (2 * kpnts_number + 1) != groups: logger.log_error( f"Exteroception groups number mismatch. Expected {groups}, but got {int(target_poses.shape[1] / (2 * kpnts_number + 1))}." ) if isinstance(sensor, StereoCamera): intrinsics, right_intrinsics = sensor.get_intrinsics() left_arena_pose, right_arena_pose = sensor.get_left_right_arena_pose() projected_kpnts = compute_exteroception._project_3d_to_2d( left_arena_pose, intrinsics, height, width, target_poses, ) exteroception[sensor_uid]["l"] = projected_kpnts projected_kpnts = compute_exteroception._project_3d_to_2d( right_arena_pose, right_intrinsics, height, width, target_poses, ) exteroception[sensor_uid]["r"] = projected_kpnts else: intrinsics = sensor.get_intrinsics() projected_kpnts = compute_exteroception._project_3d_to_2d( sensor.get_arena_pose(to_matrix=True), intrinsics, height, width, target_poses, ) exteroception[sensor_uid] = projected_kpnts return exteroception
[docs] class get_rigid_object_physics_attributes(Functor): """Get the physics attributes of the rigid object in the environment with caching. This functor retrieves and caches physics attributes (mass, friction, damping, inertia) for rigid objects. The cache is cleared when the environment resets, ensuring fresh values are fetched at the start of each episode. If the rigid object with the specified UID does not exist in the environment, a zero tensor will be returned for each attribute. The cached data is stored per entity UID. When called, if data is cached, it returns a clone of the cached tensor to prevent accidental modifications. .. note:: Physics attributes are typically constant during an episode, so caching improves performance by avoiding repeated queries to the physics engine. Args: cfg: The configuration object. env: The environment instance. """
[docs] def __init__(self, cfg: FunctorCfg, env: EmbodiedEnv): """Initialize the physics attributes functor. Args: cfg: The configuration object. env: The environment instance. """ super().__init__(cfg, env) self._cache: Dict[str, TensorDict] = {}
[docs] def reset(self, env_ids: Sequence[int] | None = None) -> None: """Clear the cached physics attributes. Args: env_ids: The environment ids. Defaults to None, which clears all cache. """ self._cache.clear()
def __call__( self, env: EmbodiedEnv, obs: EnvObs, entity_cfg: SceneEntityCfg, ) -> TensorDict: """Get the physics attributes of the rigid object. Args: env: The environment instance. obs: The observation dictionary. entity_cfg: The configuration of the scene entity. Returns: A TensorDict containing the physics attributes of the rigid object. If the object does not exist, zero tensors are returned for each attribute. """ uid = entity_cfg.uid # Return cached data if available if uid in self._cache: cached_dict = self._cache[uid] # Return clones to prevent accidental modifications return cached_dict.clone() # Fetch physics attributes from the rigid object if entity_cfg.uid not in env.sim.get_rigid_object_uid_list(): result = TensorDict( { "mass": torch.zeros( (env.num_envs, 1), dtype=torch.float32, device=env.device ), "friction": torch.zeros( (env.num_envs, 1), dtype=torch.float32, device=env.device ), "damping": torch.zeros( (env.num_envs, 1), dtype=torch.float32, device=env.device ), "inertia": torch.zeros( (env.num_envs, 3), dtype=torch.float32, device=env.device ), }, batch_size=[env.num_envs], device=env.device, ) else: obj = env.sim.get_rigid_object(entity_cfg.uid) result = TensorDict( { "mass": obj.get_mass(), "friction": obj.get_friction(), "damping": obj.get_damping(), "inertia": obj.get_inertia(), }, batch_size=[env.num_envs], device=env.device, ) # Cache the result (store clones to avoid modifying cached data) self._cache[uid] = result.clone() return result
[docs] class get_articulation_joint_drive(Functor): """Get the joint drive properties of the articulation in the environment with caching. This functor retrieves and caches joint drive properties (stiffness, damping, max_effort, max_velocity, friction) for articulations (including robots). The cache is cleared when the environment resets, ensuring fresh values are fetched at the start of each episode. If the articulation with the specified UID does not exist in the environment, a zero tensor will be returned for each attribute. The cached data is stored per entity UID. When called, if data is cached, it returns a clone of the cached tensor to prevent accidental modifications. .. note:: Joint drive properties are typically constant during an episode, so caching improves performance by avoiding repeated queries. Args: cfg: The configuration object. env: The environment instance. """
[docs] def __init__(self, cfg: FunctorCfg, env: EmbodiedEnv): """Initialize the joint drive functor. Args: cfg: The configuration object. env: The environment instance. """ super().__init__(cfg, env) self._cache: Dict[str, TensorDict] = {}
[docs] def reset(self, env_ids: Sequence[int] | None = None) -> None: """Clear the cached joint drive properties. Args: env_ids: The environment ids. Defaults to None, which clears all cache. """ self._cache.clear()
def __call__( self, env: EmbodiedEnv, obs: EnvObs, entity_cfg: SceneEntityCfg, ) -> TensorDict: """Get the joint drive properties of the articulation. Args: env: The environment instance. obs: The observation dictionary. entity_cfg: The configuration of the scene entity. Returns: A TensorDict containing the joint drive properties of the articulation. If the object does not exist, zero tensors are returned for each attribute. """ uid = entity_cfg.uid # Return cached data if available if uid in self._cache: cached_dict = self._cache[uid] # Return clones to prevent accidental modifications return cached_dict.clone() # Fetch joint drive properties from the articulation or robot if uid in env.sim.get_articulation_uid_list(): art = env.sim.get_articulation(uid) elif uid in env.sim.get_robot_uid_list(): art = env.sim.get_robot(uid) else: art = None if art is None: # We don't know the exact DOF of a non-existent articulation, # but usually it's 0 if we don't have it. We will just use 1 as fallback or return empty # Wait, Articulation's DOF might not be 1. But to support tensor shape consistency, # perhaps 1 is better than failing. We can use a 0-size dimension or 1. # get_rigid_object_physics_attributes uses shape (num_envs, 1) for mass, etc. # Here we default to 1 joint if not found. result = TensorDict( { "stiffness": torch.zeros( (env.num_envs, 1), dtype=torch.float32, device=env.device ), "damping": torch.zeros( (env.num_envs, 1), dtype=torch.float32, device=env.device ), "max_effort": torch.zeros( (env.num_envs, 1), dtype=torch.float32, device=env.device ), "max_velocity": torch.zeros( (env.num_envs, 1), dtype=torch.float32, device=env.device ), "friction": torch.zeros( (env.num_envs, 1), dtype=torch.float32, device=env.device ), }, batch_size=[env.num_envs], device=env.device, ) else: stiffness, damping, max_effort, max_velocity, friction = ( art.get_joint_drive() ) result = TensorDict( { "stiffness": stiffness, "damping": damping, "max_effort": max_effort, "max_velocity": max_velocity, "friction": friction, }, batch_size=[env.num_envs], device=env.device, ) # Cache the result (store clones to avoid modifying cached data) self._cache[uid] = result.clone() return result