# ----------------------------------------------------------------------------
# Copyright (c) 2021-2025 DexForce Technology Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ----------------------------------------------------------------------------
from __future__ import annotations
import torch
import os
import random
from typing import TYPE_CHECKING, Literal, Union, Optional, List, Dict, Sequence
from embodichain.lab.sim.objects import RigidObject, Articulation, Robot
from embodichain.lab.sim.sensors import Camera, StereoCamera
from embodichain.lab.sim.types import EnvObs
from embodichain.lab.gym.envs.managers.cfg import SceneEntityCfg
from embodichain.lab.gym.envs.managers.events import resolve_dict
from embodichain.lab.gym.envs.managers import Functor, FunctorCfg
from embodichain.utils import logger
if TYPE_CHECKING:
from embodichain.lab.gym.envs import EmbodiedEnv
def get_rigid_object_pose(
env: EmbodiedEnv,
obs: EnvObs,
entity_cfg: SceneEntityCfg,
) -> torch.Tensor:
"""Get the world poses of the rigid objects in the environment.
Args:
env: The environment instance.
obs: The observation dictionary.
entity_cfg: The configuration of the scene entity.
Returns:
        A tensor of shape (num_envs, 4, 4) representing the arena-frame poses of the rigid objects.
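
    Example:
        A minimal usage sketch (assumes a rigid object registered under the illustrative uid "cube" and
        that ``SceneEntityCfg`` can be constructed from a uid)::

            entity_cfg = SceneEntityCfg(uid="cube")
            poses = get_rigid_object_pose(env, obs, entity_cfg)  # (num_envs, 4, 4)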
"""
obj = env.sim.get_rigid_object(entity_cfg.uid)
return obj.get_local_pose(to_matrix=True)
def normalize_robot_joint_data(
env: EmbodiedEnv,
data: torch.Tensor,
joint_ids: Sequence[int],
limit: Literal["qpos_limits", "qvel_limits"] = "qpos_limits",
) -> torch.Tensor:
"""Normalize the robot joint positions to the range of [0, 1] based on the joint limits.
Args:
env: The environment instance.
obs: The observation dictionary.
joint_ids: The indices of the joints to be normalized.
limit: The type of joint limits to be used for normalization. Options are:
- `qpos_limits`: Use the joint position limits for normalization.
- `qvel_limits`: Use the joint velocity limits for normalization.
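
    Example:
        The per-joint normalization is ``(q - low) / (high - low)``. A self-contained sketch with dummy
        limits (values are illustrative)::

            import torch

            qpos = torch.tensor([[0.5, -0.25]])                    # (num_envs=1, 2 joints)
            limits = torch.tensor([[[-1.0, 1.0], [-0.5, 0.5]]])    # (1, 2, 2) low/high per joint
            normalized = (qpos - limits[..., 0]) / (limits[..., 1] - limits[..., 0])
            # tensor([[0.7500, 0.2500]])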
"""
robot = env.robot
# shape of target_limits: (num_envs, len(joint_ids), 2)
target_limits = getattr(robot.body_data, limit)[:, joint_ids, :]
# normalize the joint data to the range of [0, 1]
data[:, joint_ids] = (data[:, joint_ids] - target_limits[:, :, 0]) / (
target_limits[:, :, 1] - target_limits[:, :, 0]
)
return data
def compute_semantic_mask(
env: EmbodiedEnv,
obs: EnvObs,
entity_cfg: SceneEntityCfg,
foreground_uids: Sequence[str],
is_right: bool = False,
) -> torch.Tensor:
"""Compute the semantic mask for the specified scene entity.
Note:
The semantic mask is defined as (B, H, W, 3) where the three channels represents:
- robot channel: the instance id of the robot is set to 1 (0 if not robot)
- background channel: the instance id of the background is set to 1 (0 if not background)
- foreground channel: the instance id of the foreground objects is set to 1 (0 if not foreground)
Args:
env: The environment instance.
obs: The observation dictionary.
entity_cfg: The configuration of the scene entity.
foreground_uids: The list of uids for the foreground objects.
is_right: Whether to use the right camera for stereo cameras. Default is False.
Only applicable if the sensor is a StereoCamera.
Returns:
A tensor of shape (num_envs, height, width) representing the semantic mask.
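
    Example:
        Channel-construction sketch on a dummy 2x2 instance-id mask (ids are illustrative)::

            import torch

            mask = torch.tensor([[[1, 2], [3, 0]]])                  # (1, H=2, W=2)
            robot_ids = torch.tensor([[1]])                           # (1, 1)
            robot_ch = (mask.unsqueeze(-1) == robot_ids.view(1, 1, 1, -1)).any(-1)
            # tensor([[[ True, False], [False, False]]])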
"""
sensor: Union[Camera, StereoCamera] = env.sim.get_sensor(entity_cfg.uid)
    if not sensor.cfg.enable_mask:
        logger.log_error(
            f"Sensor '{entity_cfg.uid}' does not have mask enabled. Please enable the mask in the sensor configuration."
        )
if isinstance(sensor, StereoCamera) and is_right:
mask = obs["sensor"][entity_cfg.uid]["mask_right"]
else:
mask = obs["sensor"][entity_cfg.uid]["mask"]
robot_uids = env.robot.get_user_ids()
mask_exp = mask.unsqueeze(-1)
robot_uids_exp = robot_uids.unsqueeze_(1).unsqueeze_(1)
robot_mask = (mask_exp == robot_uids_exp).any(-1).squeeze_(-1)
foreground_assets = [env.sim.get_asset(uid) for uid in foreground_uids]
    # concatenate the per-asset user ids into a single (num_envs, n) tensor
    foreground_user_ids = torch.cat(
        [
            (
                asset.get_user_ids().unsqueeze(1)
                if asset.get_user_ids().dim() == 1
                else asset.get_user_ids()
            )
            for asset in foreground_assets
        ],
        dim=1,
    )
    foreground_user_ids_exp = foreground_user_ids.unsqueeze_(1).unsqueeze_(1)
    foreground_mask = (mask_exp == foreground_user_ids_exp).any(-1).squeeze_(-1)
background_mask = ~(robot_mask | foreground_mask).squeeze_(-1)
return torch.stack([robot_mask, background_mask, foreground_mask], dim=-1)
class compute_exteroception(Functor):
"""Compute the exteroception for the observation space.
The exteroception is currently defined as a set of keypoints around a reference pose, which are prjected from 3D
space to 2D image plane.
The reference pose can derive from the following sources:
- Pose from robot control part (e.g., end-effector, usually tcp pose)
- Object affordance pose (e.g., handle pose of a mug or a pick pose of a cube)
Therefore, the exteroception are defined in the camera-like sensor, for example.
descriptor = {
"cam_high": [
{
"type": "affordance",
"obj_uid": "obj1",
"key": "grasp_pose",
"is_arena_coord": True
},
{
"type": "affordance",
"obj_uid": "obj1",
"key": "place_pose",
},
{
"type": "robot",
"control_part": "left_arm",
},
{
"type": "robot",
"control_part": "right_arm",
}
],
...
}
    Explanation of the parameters:
    - The key of the dictionary is the sensor uid.
    - The value is a list of source dictionaries; each dictionary selects its source via the `type` key
      and carries source-specific parameters.
    - For the `affordance` source type, the parameters are:
        - `obj_uid`: The uid of the object to get the affordance pose from.
        - `key`: The key of the affordance pose in the affordance data.
        - `is_arena_coord`: Whether the affordance pose is already in the arena coordinate system. Default is False.
        - `follow_eef`: Optional gripper control part name used to scale the keypoint y-interval with the
          gripper opening ratio. Default is None.
    - For the `robot` source type, the parameters are:
        - `control_part`: The control part of the robot to get the pose from.
        - `offset`: Optional xyz offset, expressed in the control-part frame, that is subtracted from the
          pose translation. Default is None.
        - `follow_eef`: Whether to scale the keypoint y-interval with the gripper opening ratio. Default is False.
"""
def __init__(
self,
cfg: FunctorCfg,
env: EmbodiedEnv,
):
super().__init__(cfg, env)
if self._env.num_envs != 1:
logger.log_error(
f"Exteroception functor only supported env with 'num_envs=1' but got 'num_envs={self._env.num_envs}'. Please check again."
)
self._valid_source = ["robot", "affordance"]
@staticmethod
def shift_pose(pose: torch.Tensor, axis: int, shift: float) -> torch.Tensor:
"""Shift the pose along the specified axis by the given amount.
Args:
pose: The original pose tensor of shape (B, 4, 4).
axis: The axis along which to shift (0 for x, 1 for y, 2 for z).
shift: The amount to shift along the specified axis.
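
        Example:
            A minimal sketch with an identity pose (illustrative values; note the sign flip introduced by
            the double inversion)::

                import torch

                pose = torch.eye(4).unsqueeze(0)  # (1, 4, 4)
                shifted = compute_exteroception.shift_pose(pose, axis=0, shift=0.05)
                # shifted[0, 0, 3] == -0.05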
"""
        shifted = torch.linalg.inv(pose)
        shifted[:, axis, -1] += shift
        shifted = torch.linalg.inv(shifted)
        return shifted
@staticmethod
def expand_pose(
pose: torch.Tensor,
x_interval: float,
y_interval: float,
kpnts_number: int,
        ref_pose: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Expand pose with keypoints along x and y axes.
Args:
pose: The original pose tensor of shape (B, 4, 4).
x_interval: The interval for expanding along x-axis.
y_interval: The interval for expanding along y-axis.
kpnts_number: Number of keypoints to generate for each axis.
ref_pose: Reference pose tensor of shape (B, 4, 4). If None, uses identity matrix.
Returns:
Expanded poses tensor of shape (B, 1 + 2*kpnts_number, 4, 4).
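
        Example:
            A minimal sketch with an identity pose (illustrative values)::

                import torch

                pose = torch.eye(4).unsqueeze(0)  # (1, 4, 4)
                kpnts = compute_exteroception.expand_pose(pose, 0.02, 0.02, kpnts_number=12)
                # kpnts.shape -> torch.Size([1, 25, 4, 4]): 1 center + 12 x-offsets + 12 y-offsets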
"""
batch_size = pose.shape[0]
device = pose.device
# Create default reference pose if not provided
if ref_pose is None:
ref_pose = (
torch.eye(4, device=device).unsqueeze_(0).repeat(batch_size, 1, 1)
)
# Start with the original pose transformed by ref_pose
ret = [ref_pose @ pose]
# Generate x-axis offsets and expand poses
# TODO: only support 1 env
xoffset = torch.linspace(-x_interval, x_interval, kpnts_number, device=device)
for x_shift in xoffset:
shifted_pose = compute_exteroception.shift_pose(pose, 0, x_shift.item())
x_expanded = ref_pose @ shifted_pose
ret.append(x_expanded)
# Generate y-axis offsets and expand poses
# TODO: only support 1 env
yoffset = torch.linspace(-y_interval, y_interval, kpnts_number, device=device)
for y_shift in yoffset:
shifted_pose = compute_exteroception.shift_pose(pose, 1, y_shift.item())
y_expanded = ref_pose @ shifted_pose
ret.append(y_expanded)
# Stack all poses along a new dimension
return torch.stack(ret, dim=1)
@staticmethod
def _project_3d_to_2d(
cam_pose: torch.Tensor,
intrinsics: torch.Tensor,
height: int,
width: int,
target_poses: torch.Tensor,
normalize: bool = True,
) -> torch.Tensor:
"""Project 3D poses to 2D image plane.
Args:
            cam_pose: Camera pose in the arena frame, of shape (B, 4, 4).
intrinsics: Camera intrinsic matrix of shape (B, 3, 3).
height: Image height.
width: Image width.
target_poses: 3D poses of shape (B, N, 4, 4).
normalize: Whether to normalize the projected points to [0, 1] range.
Returns:
Projected 2D points of shape (B, N, 2).
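
        Example:
            The underlying pinhole projection maps a camera-frame point (x, y, z) to pixel coordinates via
            u = fx * x / z + cx and v = fy * y / z + cy. A self-contained sketch of that step alone
            (illustrative intrinsics)::

                import torch

                K = torch.tensor([[500.0, 0.0, 320.0],
                                  [0.0, 500.0, 240.0],
                                  [0.0, 0.0, 1.0]])
                p_cam = torch.tensor([0.1, -0.05, 1.0])  # camera-frame point with z > 0
                uv = (K @ (p_cam / p_cam[2]))[:2]        # -> tensor([370., 215.])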
"""
batch_size, num_poses = target_poses.shape[:2]
        # Convert to the OpenCV camera convention (flip the y and z axes).
        # Clone first so the caller's pose tensor is not modified in place.
        cam_pose = cam_pose.clone()
        cam_pose[:, :3, 1] = -cam_pose[:, :3, 1]
        cam_pose[:, :3, 2] = -cam_pose[:, :3, 2]
# Expand cam_pose_inv and intrinsics to match target_poses batch dimension
cam_pose_inv = torch.linalg.inv(cam_pose) # (B, 4, 4)
cam_pose_inv_expanded = cam_pose_inv.unsqueeze(1).expand(
-1, num_poses, -1, -1
) # (B, N, 4, 4)
cam_pose_inv_reshaped = cam_pose_inv_expanded.reshape(-1, 4, 4) # (B*N, 4, 4)
intrinsics_expanded = intrinsics.unsqueeze(1).expand(
-1, num_poses, -1, -1
) # (B, N, 3, 3)
intrinsics_reshaped = intrinsics_expanded.reshape(-1, 3, 3) # (B*N, 3, 3)
# Reshape target_poses to (B*N, 4, 4)
target_poses_reshaped = target_poses.reshape(-1, 4, 4) # (B*N, 4, 4)
# Transform 3D points to camera coordinates in parallel
# Extract translation part (position) from target poses: (B*N, 4, 1)
target_positions = target_poses_reshaped[:, :, 3:4] # (B*N, 4, 1)
# Transform to camera coordinates: (B*N, 4, 1)
cam_positions = cam_pose_inv_reshaped.bmm(target_positions) # (B*N, 4, 1)
cam_positions_3d = cam_positions[:, :3, 0] # (B*N, 3)
# Project to 2D using intrinsics in parallel
# Add small epsilon to avoid division by zero
eps = 1e-8
z_safe = torch.clamp(cam_positions_3d[:, 2], min=eps) # (B*N,)
# Normalize by depth
normalized_points = cam_positions_3d[:, :2] / z_safe.unsqueeze(-1) # (B*N, 2)
# Convert to homogeneous coordinates and apply intrinsics
normalized_homogeneous = torch.cat(
[normalized_points, torch.ones_like(normalized_points[:, :1])], dim=-1
) # (B*N, 3)
pixel_coords = intrinsics_reshaped.bmm(
normalized_homogeneous.unsqueeze(-1)
).squeeze(
-1
) # (B*N, 3)
# Extract 2D coordinates
points_2d_flat = pixel_coords[:, :2] # (B*N, 2)
# Reshape back to (B, N, 2)
points_2d = points_2d_flat.reshape(batch_size, num_poses, 2)
# clip to range [0, width] and [0, height]
points_2d[..., 0] = torch.clamp(points_2d[..., 0], 0, width - 1)
points_2d[..., 1] = torch.clamp(points_2d[..., 1], 0, height - 1)
if normalize:
# Normalize to [0, 1] range
points_2d[..., 0] /= width
points_2d[..., 1] /= height
return points_2d
def _get_gripper_ratio(
self, control_part: str, gripper_qpos: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Compute the gripper opening ratio (gripper joint position divided by its upper joint limit).

        Args:
            control_part: The gripper control part name (e.g., an end-effector part).
            gripper_qpos: Optional gripper joint positions. If None, the current joint positions of the
                control part are used.

        Returns:
            A tensor of shape (num_envs,) with the gripper opening ratio.
        """
robot: Robot = self._env.robot
gripper_max_limit = robot.body_data.qpos_limits[
:, robot.get_joint_ids(control_part)
][:, 0, 1]
if gripper_qpos is None:
gripper_qpos = robot.get_qpos()[:, robot.get_joint_ids(control_part)][:, 0]
return gripper_qpos / gripper_max_limit
def _get_robot_exteroception(
self,
control_part: Optional[str] = None,
x_interval: float = 0.02,
y_interval: float = 0.02,
kpnts_number: int = 12,
offset: Optional[Union[List, torch.Tensor]] = None,
follow_eef: bool = False,
) -> torch.Tensor:
"""Get the robot exteroception poses.
Args:
            control_part: The robot control part to use as the reference pose. Currently required; if None,
                an error is logged.
x_interval: The interval for expanding along x-axis.
y_interval: The interval for expanding along y-axis.
kpnts_number: Number of keypoints to generate for each axis.
            offset: An xyz offset, expressed in the control-part frame, that is subtracted from the pose
                translation. Default is None.
            follow_eef: Whether to scale the keypoint y-interval with the gripper opening ratio. Default is False.
Returns:
A tensor of shape (num_envs, 1 + 2*kpnts_number, 4, 4) representing the exteroception poses.
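
        Example:
            A hedged sketch (assumes a control part named "left_arm" is configured on the robot; the name
            is illustrative)::

                kpnts = self._get_robot_exteroception(control_part="left_arm", offset=[0.0, 0.0, 0.1])
                # kpnts.shape -> (num_envs, 1 + 2 * 12, 4, 4)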
"""
robot: Robot = self._env.robot
if control_part is not None:
current_qpos = robot.get_qpos()[:, robot.get_joint_ids(control_part)]
robot_pose = robot.compute_fk(
current_qpos, name=control_part, to_matrix=True
)
if follow_eef:
gripper_ratio = self._get_gripper_ratio(
control_part.replace("_arm", "_eef")
) # TODO: "_eef" hardcode
# TODO: only support 1 env
y_interval = (y_interval * gripper_ratio)[0].item()
else:
logger.log_error("Not supported Robot without control part yet.")
if offset is not None:
offset = torch.as_tensor(
offset, dtype=torch.float32, device=self._env.device
)
if (offset.ndim > 2) or (offset.shape[-1] != 3):
logger.log_error(
f"Only (N, 3) shaped xyz-intrinsic offset supported, got shape {offset.shape}"
)
elif offset.ndim == 1:
offset = offset[None]
            # TODO: This operation may be slow for large-scale parallelization, but it is faster for small
            # batches (num_envs=1).
robot_pose[:, :3, 3] = robot_pose[:, :3, 3] - torch.einsum(
"bij,bj->bi", robot_pose[:, :3, :3], offset
)
return compute_exteroception.expand_pose(
robot_pose,
x_interval,
y_interval,
kpnts_number,
)
def _get_object_exteroception(
self,
uid: str,
affordance_key: str,
x_interval: float = 0.02,
y_interval: float = 0.02,
kpnts_number: int = 12,
is_arena_coord: bool = False,
follow_eef: Optional[str] = None,
) -> torch.Tensor:
"""Get the rigid object exteroception poses.
Args:
uid: The UID of the object.
affordance_key: The key of the affordance to use for the object pose.
x_interval: The interval for expanding along x-axis.
y_interval: The interval for expanding along y-axis.
kpnts_number: Number of keypoints to generate for each axis.
            is_arena_coord: Whether the affordance pose is already in the arena coordinate system. Default is False.
            follow_eef: Optional gripper control part name used to scale the keypoint y-interval with the
                gripper opening ratio. Default is None.
Returns:
A tensor of shape (num_envs, 1 + 2*kpnts_number, 4, 4) representing the exteroception poses.
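
        Example:
            A hedged sketch (assumes an object registered under the illustrative uid "cube" with a
            "grasp_pose" entry in the environment's affordance data)::

                kpnts = self._get_object_exteroception(uid="cube", affordance_key="grasp_pose")
                # kpnts.shape -> (num_envs, 1 + 2 * 12, 4, 4)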
"""
obj: RigidObject = self._env.sim.get_rigid_object(uid)
if obj is None:
logger.log_error(
f"Rigid object with UID '{uid}' not found in the simulation."
)
        if not hasattr(self._env, "affordance_datas"):
logger.log_error(
"Affordance data is not available in the environment. We cannot compute object exteroception."
)
if affordance_key not in self._env.affordance_datas:
# TODO: should this default behavior be warned?
# logger.log_warning(
# f"Affordance key '{affordance_key}' not found in the affordance data, using identity pose.."
# )
pass
affordance_pose = torch.as_tensor(
self._env.affordance_datas.get(
affordance_key, torch.eye(4).repeat(self._env.num_envs, 1, 1)
),
dtype=torch.float32,
)
if affordance_pose.ndim < 3:
affordance_pose = affordance_pose.repeat(self._env.num_envs, 1, 1)
ref_pose = None if is_arena_coord else obj.get_local_pose(to_matrix=True)
if follow_eef is not None:
gripper_ratio = self._get_gripper_ratio(control_part=follow_eef)
# TODO: only support 1 env
y_interval = (y_interval * gripper_ratio)[0].item()
return compute_exteroception.expand_pose(
affordance_pose,
x_interval,
y_interval,
kpnts_number,
ref_pose=ref_pose,
)
def _check_source_valid(self, source: str) -> bool:
if source not in self._valid_source:
logger.log_error(
f"Invalid exteroception source '{source}'. Supported sources are {self._valid_source}."
)
return True
def __call__(
self,
env: EmbodiedEnv,
obs: EnvObs,
descriptor: Dict[str, Dict[str, str]],
x_interval: float = 0.02,
y_interval: float = 0.02,
kpnts_number: int = 12,
groups: int = 6,
) -> Dict[str, Dict[str, torch.Tensor]]:
"""Compute the exteroception poses based on the asset type.
Args:
descriptor: The observation dictionary.
Returns:
A dictionary containing the exteroception poses with key 'exteroception'.
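
        Example:
            Illustrative returned structure for one stereo and one mono camera (uids are placeholders,
            with num_envs=1, kpnts_number=12 and groups=6, so N = 6 * 25 = 150)::

                {
                    "cam_high": {"l": <(1, 150, 2) tensor>, "r": <(1, 150, 2) tensor>},
                    "cam_wrist": <(1, 150, 2) tensor>,
                }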
"""
exteroception = {}
descriptor = resolve_dict(self._env, descriptor)
for sensor_uid, sources in descriptor.items():
sensor: Union[Camera, StereoCamera] = self._env.sim.get_sensor(sensor_uid)
if sensor is None:
logger.log_error(
f"Sensor with UID '{sensor_uid}' not found in the simulation."
)
if not isinstance(sensor, (Camera, StereoCamera)):
logger.log_error(
f"Sensor with UID '{sensor_uid}' is not a Camera or StereoCamera."
)
height, width = sensor.cfg.height, sensor.cfg.width
exteroception[sensor_uid] = {}
            target_pose_list = []
for source in sources:
source_type = source["type"]
self._check_source_valid(source_type)
if source_type == "robot":
target_pose = self._get_robot_exteroception(
control_part=source["control_part"],
x_interval=x_interval,
y_interval=y_interval,
kpnts_number=kpnts_number,
offset=source.get("offset", None),
follow_eef=source.get("follow_eef", False),
)
elif source_type == "affordance":
target_pose = self._get_object_exteroception(
uid=source["obj_uid"],
affordance_key=source["key"],
x_interval=x_interval,
y_interval=y_interval,
kpnts_number=kpnts_number,
is_arena_coord=source["is_arena_coord"],
follow_eef=source.get("follow_eef", None),
)
else:
logger.log_error(
f"Unsupported exteroception source '{source_type}'. Supported sources are 'robot' and 'affordance."
)
                target_pose_list.append(target_pose)
            target_poses = torch.cat(target_pose_list, dim=1)
if target_poses.shape[1] / (2 * kpnts_number + 1) != groups:
logger.log_error(
f"Exteroception groups number mismatch. Expected {groups}, but got {int(target_poses.shape[1] / (2 * kpnts_number + 1))}."
)
if isinstance(sensor, StereoCamera):
intrinsics, right_intrinsics = sensor.get_intrinsics()
left_arena_pose, right_arena_pose = sensor.get_left_right_arena_pose()
projected_kpnts = compute_exteroception._project_3d_to_2d(
left_arena_pose,
intrinsics,
height,
width,
target_poses,
)
exteroception[sensor_uid]["l"] = projected_kpnts
projected_kpnts = compute_exteroception._project_3d_to_2d(
right_arena_pose,
right_intrinsics,
height,
width,
target_poses,
)
exteroception[sensor_uid]["r"] = projected_kpnts
else:
intrinsics = sensor.get_intrinsics()
projected_kpnts = compute_exteroception._project_3d_to_2d(
sensor.get_arena_pose(to_matrix=True),
intrinsics,
height,
width,
target_poses,
)
exteroception[sensor_uid] = projected_kpnts
return exteroception