Source code for embodichain.lab.sim.objects.robot

# ----------------------------------------------------------------------------
# Copyright (c) 2021-2025 DexForce Technology Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ----------------------------------------------------------------------------

import torch
import numpy as np

from typing import List, Dict, Optional, Tuple, Union, Sequence
from dataclasses import dataclass, field

from dexsim.engine import Articulation as _Articulation
from embodichain.lab.sim.cfg import RobotCfg
from embodichain.lab.sim.solvers import SolverCfg, BaseSolver
from embodichain.lab.sim.objects import Articulation
from embodichain.lab.sim.utility.tensor import to_tensor
from embodichain.utils.math import quat_from_matrix
from embodichain.utils.string import (
    is_regular_expression,
    resolve_matching_names_values,
)
from embodichain.utils import logger


@dataclass
class ControlGroup:
    r"""Container describing one group of jointly controlled robot joints.

    Every field defaults to a fresh empty list, so instances never share
    list storage with each other.

    Attributes:
        joint_names (List[str]): Names of the joints in this control group.
        joint_ids (List[int]): IDs corresponding to the joints in this control group.
        link_names (List[str]): Names of child links associated with the joints.
    """

    # Joint names belonging to the group.
    joint_names: List[str] = field(default_factory=list)
    # Joint ids belonging to the group.
    joint_ids: List[int] = field(default_factory=list)
    # Link names associated with the group's joints.
    link_names: List[str] = field(default_factory=list)

    def __post_init__(self) -> None:
        # Intentionally a no-op hook; kept so the attribute stays available.
        return None


class Robot(Articulation):
    """A class representing a batch of robots in the simulation environment.

    Robot is a specific type of articulation that can have additional properties or methods.

    - `control_parts`: Specify the parts that can be controlled in a different manner. Different
      part may have different joint ids, drive properties, physical attributes, kinematic solvers
      or motion planners.
    - `solvers`: Specify the kinematic solvers for the robot.
    - `planners`: Specify the motion planner for the robot.
    """

    def __init__(
        self,
        cfg: RobotCfg,
        entities: List[_Articulation],
        device: torch.device = torch.device("cpu"),
    ) -> None:
        """Create the robot, then set up control parts and solvers from ``cfg``.

        Args:
            cfg (RobotCfg): The robot configuration.
            entities (List[_Articulation]): The underlying articulation entities.
            device (torch.device): Device passed to the parent class and to the solvers.
        """
        super().__init__(cfg, entities, device)

        self._solvers = {}

        # Initialize joint ids for control parts.
        self._joint_ids: Dict[str, List[int]] = {}
        self._control_groups: Dict[str, ControlGroup] = {}

        if self.cfg.control_parts:
            self._init_control_parts(self.cfg.control_parts)

        if self.cfg.solver_cfg:
            self.init_solver(self.cfg.solver_cfg)

    def __str__(self) -> str:
        parent_str = super().__str__()
        return (
            parent_str
            + f" | control_parts: {self.control_parts}, solvers: {self._solvers}"
        )

    @property
    def control_parts(self) -> Union[Dict[str, List[str]], None]:
        """Get the control parts of the robot."""
        return self.cfg.control_parts

    def get_joint_ids(
        self, name: Optional[str] = None, remove_mimic: bool = False
    ) -> List[int]:
        """Get the joint ids of the robot for a specific control part.

        Args:
            name (str, optional): The name of the control part to get the joint ids for.
                If None, the ids of all joints are returned.
            remove_mimic (bool, optional): If True, mimic joints will be excluded from the
                returned joint ids. Defaults to False.

        Returns:
            List[int]: The joint ids of the robot for the specified control part.
        """
        if not self.control_parts or name is None:
            all_ids = list(range(self.dof))
            if not remove_mimic:
                return all_ids
            return [i for i in all_ids if i not in self.mimic_ids]

        if name not in self.control_parts:
            # NOTE(review): if `logger.log_error` does not raise, the lookup below
            # fails with a KeyError for an unknown part -- confirm intended behavior.
            logger.log_error(
                f"The control part '{name}' does not exist in the robot's control parts."
            )

        part_ids = self._joint_ids[name]
        if not remove_mimic:
            return part_ids
        return [i for i in part_ids if i not in self.mimic_ids]

    def get_proprioception(self) -> Dict[str, torch.Tensor]:
        """Gets robot proprioception information, primarily for agent state representation
        in robot learning scenarios.

        The default proprioception information includes:
            - qpos: Joint positions.
            - qvel: Joint velocities.
            - qf: Joint efforts.

        Returns:
            Dict[str, torch.Tensor]: A dictionary containing the robot's proprioception information
        """
        return dict(
            qpos=self.body_data.qpos, qvel=self.body_data.qvel, qf=self.body_data.qf
        )

    @staticmethod
    def _pos_quat_to_matrix(pose: torch.Tensor) -> torch.Tensor:
        """Convert poses given as (position, quaternion) rows into 4x4 matrices.

        Args:
            pose (torch.Tensor): Poses with shape (N, 7): 3 position values followed by
                4 quaternion values.

        Returns:
            torch.Tensor: Homogeneous transforms with shape (N, 4, 4).
        """
        # NOTE(review): assumes `embodichain.utils.math` provides `matrix_from_quat`
        # as the inverse of the `quat_from_matrix` used by the FK methods (same
        # quaternion component order) -- confirm. Imported at function scope so a
        # missing name cannot break importing this module.
        from embodichain.utils.math import matrix_from_quat

        rotation = matrix_from_quat(pose[:, 3:7])
        matrix = torch.eye(4, dtype=rotation.dtype, device=rotation.device).repeat(
            pose.shape[0], 1, 1
        )
        matrix[:, :3, :3] = rotation
        matrix[:, :3, 3] = pose[:, :3].to(rotation.dtype)
        return matrix

    def compute_fk(
        self,
        qpos: Optional[Union[torch.Tensor, np.ndarray]],
        name: Optional[str] = None,
        link_names: Optional[List[str]] = None,
        end_link_name: Optional[str] = None,
        root_link_name: Optional[str] = None,
        env_ids: Optional[Sequence[int]] = None,
        to_matrix: bool = False,
    ) -> torch.Tensor:
        """Compute the forward kinematics of the robot given joint positions and optionally
        a specific part name. The output pose will be in the local arena frame.

        Args:
            qpos (Optional[Union[torch.Tensor, np.ndarray]]): Joint positions of the robot,
                (n_envs, num_joints).
            name (str, optional): The name of the control part to compute the FK for.
                If None, the default part is used.
            link_names (List[str], optional): The names of the links to compute the FK for.
                If None, all links are used.
            end_link_name (str, optional): The name of the end link to compute the FK for.
                If None, the default end link is used.
            root_link_name (str, optional): The name of the root link to compute the FK for.
                If None, the default root link is used.
            env_ids (Sequence[int], optional): The environment ids to compute the FK for.
                If None, all environments are used.
            to_matrix (bool, optional): If True, returns the transformation in the form of
                a 4x4 matrix.

        Returns:
            torch.Tensor: The forward kinematics result with shape (n_envs, 7) or
            (n_envs, 4, 4) if `to_matrix` is True.
        """
        local_env_ids = self._all_indices if env_ids is None else env_ids

        if name is None and hasattr(super(), "compute_fk"):
            # NOTE(review): `env_ids` and `to_matrix` are not forwarded to the parent
            # implementation -- confirm this matches the parent signature.
            return super().compute_fk(
                qpos=qpos,
                link_names=link_names,
                end_link_name=end_link_name,
                root_link_name=root_link_name,
            )

        if not self._solvers:
            logger.log_error(
                "No solvers are defined for the robot. Please ensure that the robot has solvers configured."
            )

        solver = self._solvers.get(name if name is not None else "default", None)
        if solver is None:
            logger.log_error(
                f"The control part '{name}' does not have an associated solver. Please ensure that a valid control part with an available solver is provided."
            )
            return None

        # Accept numpy input as documented, mirroring what compute_ik does for poses.
        qpos = to_tensor(qpos, device=self.device)
        if qpos.dim() == 1:
            qpos = qpos.unsqueeze(0)

        if qpos.shape[0] != len(local_env_ids):
            logger.log_error(
                f"Joint positions batch size mismatch. Expected {len(local_env_ids)} but got {qpos.shape[0]}."
            )

        if qpos.shape[1] != solver.dof:
            logger.log_error(
                f"Joint positions shape mismatch. Expected {solver.dof} joints, got {qpos.shape[1]}."
            )

        result_matrix = solver.get_fk(qpos=qpos)
        # The solver result is composed with the pose of the solver's root link.
        base_pose = self.get_link_pose(
            link_name=solver.root_link_name, env_ids=local_env_ids, to_matrix=True
        )
        result_matrix = torch.bmm(base_pose, result_matrix)

        if to_matrix:
            return result_matrix

        pos = result_matrix[:, :3, 3]
        quat = quat_from_matrix(result_matrix[:, :3, :3])
        return torch.cat((pos, quat), dim=-1)

    def compute_ik(
        self,
        pose: Union[torch.Tensor, np.ndarray],
        joint_seed: Optional[Union[torch.Tensor, np.ndarray]] = None,
        name: Optional[str] = None,
        env_ids: Optional[Sequence[int]] = None,
        return_all_solutions: bool = False,
    ) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
        """Compute the inverse kinematics of the robot given an end effector pose and
        optionally a specific part name. The input pose should be in the local arena frame.

        Args:
            pose (torch.Tensor): The end effector pose of the robot, (n_envs, 7) or
                (n_envs, 4, 4). A single (7,) vector or (4, 4) matrix is also accepted
                and treated as a batch of one.
            joint_seed (torch.Tensor, optional): The joint positions to use as a seed for
                the IK computation, (n_envs, dof). If None, None is forwarded to the solver.
            name (str, optional): The name of the control part to compute the IK for.
                If None, the default part is used.
            env_ids (Optional[Sequence[int]]): Environment indices to apply the positions.
                Defaults to all environments.
            return_all_solutions (bool, optional): Whether to return all IK solutions or
                just the best one. Defaults to False.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The success Tensor with shape (n_envs, ) and
            the qpos Tensor. The qpos Tensor is reshaped to (-1, dof) unless
            `return_all_solutions` is True. None is returned when no solver is found.
        """
        local_env_ids = self._all_indices if env_ids is None else env_ids

        solver = self._solvers.get(name if name is not None else "default", None)
        if solver is None:
            logger.log_error(
                f"The control part '{name}' does not have an associated solver. Please ensure that a valid control part with an available solver is provided."
            )
            return None

        pose = to_tensor(pose, device=self.device)
        # Promote a single pose to a batch of one. A 1-D tensor only has dimension 0,
        # so its length is read from shape[0].
        is_single_vector = pose.dim() == 1 and pose.shape[0] == 7
        is_single_matrix = pose.dim() == 2 and pose.shape[1] == 4
        if is_single_vector or is_single_matrix:
            pose = pose.unsqueeze(0)

        if pose.shape[0] != len(local_env_ids):
            logger.log_error(
                f"Pose batch size mismatch. Expected {len(local_env_ids)} but got {pose.shape[0]}."
            )

        if joint_seed is not None:
            joint_seed = to_tensor(joint_seed, device=self.device)
            if joint_seed.dim() == 1:
                joint_seed = joint_seed.unsqueeze(0)

            if joint_seed.shape[0] != len(local_env_ids):
                logger.log_error(
                    f"Joint seed batch size mismatch. Expected {len(local_env_ids)} but got {joint_seed.shape[0]}."
                )

        if pose.shape[-1] == 7 and pose.dim() == 2:
            # Convert pose from (batch, 7) to (batch, 4, 4)
            pose = self._pos_quat_to_matrix(pose)

        # Express the target relative to the solver's root link before solving.
        base_pose = self.get_link_pose(
            link_name=solver.root_link_name, env_ids=local_env_ids, to_matrix=True
        )
        pose = torch.bmm(torch.inverse(base_pose), pose)

        ret, qpos = solver.get_ik(
            target_xpos=pose,
            qpos_seed=joint_seed,
            return_all_solutions=return_all_solutions,
        )

        dof = qpos.shape[-1]
        if not return_all_solutions:
            qpos = qpos.reshape(-1, dof)

        return ret.to(self.device), qpos.to(self.device)

    def compute_batch_fk(
        self,
        qpos: torch.Tensor,
        name: str,
        env_ids: Optional[Sequence[int]] = None,
        to_matrix: bool = False,
    ):
        """Compute the forward kinematics of the robot for a batch of joint positions per
        environment. The output pose will be in the local arena frame.

        Args:
            qpos (Union[torch.Tensor, np.ndarray]): Joint positions of the robot,
                (n_envs, n_batch, num_joints).
            name (str): The name of the control part to compute the FK for.
                If None, the default part is used.
            env_ids (Sequence[int], optional): The environment ids to compute the FK for.
                If None, all environments are used.
            to_matrix (bool, optional): If True, returns the transformation in the form of
                a 4x4 matrix.

        Returns:
            torch.Tensor: The forward kinematics result with shape (n_envs, batch, 7) or
            (n_envs, batch, 4, 4) if `to_matrix` is True.
        """
        local_env_ids = self._all_indices if env_ids is None else env_ids

        if not self._solvers:
            logger.log_error(
                "No solvers are defined for the robot. Please ensure that the robot has solvers configured."
            )

        solver = self._solvers.get(name if name is not None else "default", None)
        if solver is None:
            logger.log_error(
                f"The control part '{name}' does not have an associated solver. Please ensure that a valid control part with an available solver is provided."
            )
            return None

        # Accept numpy input as documented, mirroring what compute_ik does for poses.
        qpos = to_tensor(qpos, device=self.device)

        if qpos.shape[0] != len(local_env_ids):
            logger.log_error(
                f"Joint positions batch size mismatch. Expected {len(local_env_ids)} but got {qpos.shape[0]}."
            )

        if qpos.shape[2] != solver.dof:
            # Report the dimension that was actually checked (the joint axis).
            logger.log_error(
                f"Joint positions shape mismatch. Expected {solver.dof} joints, got {qpos.shape[2]}."
            )

        n_envs = len(local_env_ids)
        n_batch = qpos.shape[1]
        # Flatten (n_envs, n_batch, dof) so the solver sees one flat batch.
        qpos_batch = qpos.reshape(-1, solver.dof)
        xpos_batch = solver.get_fk(qpos=qpos_batch)

        # get xpos from link root
        base_xpos_n_envs = self.get_link_pose(
            link_name=solver.root_link_name, env_ids=local_env_ids, to_matrix=True
        )
        # Repeat each environment's base pose once per batch entry.
        base_xpos_batch = (
            base_xpos_n_envs[:, None, :, :].repeat(1, n_batch, 1, 1).reshape(-1, 4, 4)
        )
        result_matrix = torch.bmm(base_xpos_batch, xpos_batch)

        if to_matrix:
            return result_matrix.reshape(n_envs, n_batch, 4, 4)

        pos = result_matrix[:, :3, 3]
        quat = quat_from_matrix(result_matrix[:, :3, :3])
        result = torch.cat((pos, quat), dim=-1)
        return result.reshape(n_envs, n_batch, 7)

    def compute_batch_ik(
        self,
        pose: Union[torch.Tensor, np.ndarray],
        joint_seed: Optional[Union[torch.Tensor, np.ndarray]],
        name: str,
        env_ids: Optional[Sequence[int]] = None,
    ):
        """Compute the inverse kinematics of the robot for a batch of end effector poses
        per environment. The input pose should be in the local arena frame.

        Args:
            pose (torch.Tensor): The end effector pose of the robot, (n_envs, n_batch, 7)
                or (n_envs, n_batch, 4, 4).
            joint_seed (torch.Tensor, optional): The joint positions to use as a seed for
                the IK computation, (n_envs, n_batch, dof). If None, the zero joint
                positions will be used as the seed.
            name (str): The name of the control part to compute the IK for.
                If None, the default part is used.
            env_ids (Optional[Sequence[int]]): Environment indices to apply the positions.
                Defaults to all environments.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
                Success Tensor with shape (n_envs, n_batch)
                Qpos Tensor with shape (n_envs, n_batch, dof).
            None is returned when no solver is found.
        """
        local_env_ids = self._all_indices if env_ids is None else env_ids

        solver = self._solvers.get(name if name is not None else "default", None)
        if solver is None:
            logger.log_error(
                f"The control part '{name}' does not have an associated solver. Please ensure that a valid control part with an available solver is provided."
            )
            return None

        pose = to_tensor(pose, device=self.device)
        n_envs = len(local_env_ids)
        if pose.shape[0] != n_envs:
            logger.log_error(
                f"Pose batch size mismatch. Expected {len(local_env_ids)} but got {pose.shape[0]}."
            )

        n_batch = pose.shape[1]
        n_dof = solver.dof

        if joint_seed is None:
            joint_seed = torch.zeros(
                (n_envs, n_batch, n_dof),
                dtype=torch.float32,
                device=self.device,
            )
        else:
            # Accept numpy seeds as documented, consistent with compute_ik.
            joint_seed = to_tensor(joint_seed, device=self.device)

        if joint_seed.shape[0] != n_envs:
            logger.log_error(
                f"Joint seed env size mismatch. Expected {len(local_env_ids)} but got {joint_seed.shape[0]}."
            )
        if joint_seed.shape[1] != n_batch:
            logger.log_error(
                f"Joint seed batch size mismatch. Expected {n_batch} but got {joint_seed.shape[1]}."
            )
        if joint_seed.shape[-1] != n_dof:
            # The expected value here is the solver dof, not the batch size.
            logger.log_error(
                f"Joint seed dof size mismatch. Expected {n_dof} but got {joint_seed.shape[-1]}."
            )

        if pose.shape[-1] == 7 and pose.dim() == 3:
            # Convert pose from (n_envs, n_batch, 7) to (n_envs * n_batch, 4, 4)
            pose_batch = self._pos_quat_to_matrix(pose.reshape(-1, 7))
        else:
            # Convert pose from (n_envs, n_batch, 4, 4) to (n_envs * n_batch, 4, 4)
            pose_batch = pose.reshape(-1, 4, 4)

        # get xpos from link root
        base_xpos_n_envs = self.get_link_pose(
            link_name=solver.root_link_name, env_ids=local_env_ids, to_matrix=True
        )
        base_inv_xpos_n_envs = torch.inverse(base_xpos_n_envs)
        # Repeat each environment's inverse base pose once per batch entry.
        base_inv_xpos_batch = (
            base_inv_xpos_n_envs[:, None, :, :]
            .repeat(1, n_batch, 1, 1)
            .reshape(-1, 4, 4)
        )
        pose_batch = torch.bmm(base_inv_xpos_batch, pose_batch)

        joint_seed_batch = joint_seed.reshape(-1, n_dof)
        ret, qpos_batch = solver.get_ik(
            target_xpos=pose_batch,
            qpos_seed=joint_seed_batch,
            return_all_solutions=False,
        )
        ret = ret.reshape(n_envs, n_batch)
        qpos = qpos_batch.reshape(n_envs, n_batch, n_dof)
        return ret, qpos

    def _init_control_parts(self, control_parts: Dict[str, List[str]]) -> None:
        """Initialize the control parts of the robot.

        Regular-expression entries are expanded to the matching joint names, the joint
        ids of every part are cached, and the expanded name list is written back to
        ``self.cfg.control_parts``.

        Args:
            control_parts (Dict[str, List[str]]): A dictionary where keys are control part
                names and values are lists of joint names or regular expressions that
                match joint names.
        """
        joint_name_to_ids = {name: i for i, name in enumerate(self.joint_names)}
        for name, joint_names in control_parts.items():
            # convert joint_names which is a regular expression to a list of joint names
            joint_names_expanded = []
            for jn in joint_names:
                if is_regular_expression(jn):
                    _, names, _ = resolve_matching_names_values(
                        {jn: None}, self.joint_names
                    )
                    joint_names_expanded.extend(names)
                else:
                    joint_names_expanded.append(jn)

            self._joint_ids[name] = [
                joint_name_to_ids[joint_name]
                for joint_name in joint_names_expanded
                if joint_name in joint_name_to_ids
            ]
            # Fewer ids than names means at least one name is not a joint of the robot.
            if len(self._joint_ids[name]) != len(joint_names_expanded):
                logger.log_error(
                    f"joint names in control part '{name}' do not match the robot's joint names. The full joint names are: {self.joint_names}."
                )
            # Only the value of an existing key is replaced, so this is safe even when
            # `control_parts` is the very dict being iterated.
            self.cfg.control_parts[name] = joint_names_expanded

        # Initialize control groups
        self._control_groups = self._extract_control_groups()

    def init_solver(self, cfg: Union[SolverCfg, Dict[str, SolverCfg]]) -> None:
        """Initialize the kinematic solver for the robot.

        A single `SolverCfg` is registered under the key "default". A dictionary is
        iterated and each entry is matched against the control part names.

        Args:
            cfg (Union[SolverCfg, Dict[str, SolverCfg]]): The configuration for the
                kinematic solver.
        """
        self.cfg: RobotCfg

        if isinstance(cfg, SolverCfg):
            if self.control_parts:
                logger.log_error(
                    "Control parts are defined in the robot configuration, solver_cfg must be a dictionary."
                )

            # Fall back to the robot's own description file.
            if cfg.urdf_path is None:
                cfg.urdf_path = self.cfg.fpath
            self._solvers["default"] = cfg.init_solver(device=self.device)
        elif isinstance(cfg, dict):
            if not isinstance(self.cfg.control_parts, dict):
                logger.log_error(
                    "When `solver_cfg` is a dictionary, `control_parts` must also be a dictionary."
                )

            # If solver_cfg is a dictionary, iterate through it to create solvers
            for name, solver_cfg in cfg.items():
                if solver_cfg.urdf_path is None:
                    solver_cfg.urdf_path = self.cfg.fpath
                _, part_names, value = resolve_matching_names_values(
                    {name: solver_cfg}, self.cfg.control_parts.keys()
                )
                for part_name in part_names:
                    if (
                        not hasattr(solver_cfg, "joint_names")
                        or solver_cfg.joint_names is None
                    ):
                        solver_cfg.joint_names = self.cfg.control_parts[part_name]
                    # NOTE(review): the solver is stored under the config key `name`,
                    # not under `part_name`. When `name` matches several parts, later
                    # matches overwrite the same entry -- confirm this is intended.
                    self._solvers[name] = solver_cfg.init_solver(device=self.device)

    def get_solver(self, name: Optional[str] = None) -> Optional[BaseSolver]:
        """Get the kinematic solver for a specific control part.

        Args:
            name (str, optional): The name of the control part to get the solver for.
                If None, the default part is used.

        Returns:
            Optional[BaseSolver]: The kinematic solver for the specified control part,
            or None if not found.
        """
        if not self._solvers:
            logger.log_error(
                "No solvers are defined for the robot. Please ensure that the robot has solvers configured."
            )
            return None

        return self._solvers.get(name if name is not None else "default", None)

    def get_control_part_base_pose(
        self,
        name: Optional[str] = None,
        env_ids: Optional[Sequence[int]] = None,
        to_matrix: bool = False,
    ) -> torch.Tensor:
        """Retrieves the base pose of the control part for a specified robot.

        Args:
            name (Optional[str]): The name of the control part. If it is not a known
                control group, ``None`` is passed as the link name to `get_link_pose`.
            env_ids (Optional[Sequence[int]]): A sequence of environment IDs to specify
                the environments. If None, all indices are used.
            to_matrix (bool): If True, returns the pose in the form of a 4x4 matrix.

        Returns:
            The pose returned by `get_link_pose` for the first link of the control group.
        """
        local_env_ids = self._all_indices if env_ids is None else env_ids

        root_link_name = None
        if name in self._control_groups:
            # The first recorded link of the group is used as its base link.
            root_link_name = self._control_groups[name].link_names[0]

        return self.get_link_pose(
            link_name=root_link_name, env_ids=local_env_ids, to_matrix=to_matrix
        )

    def _extract_control_groups(self) -> Dict[str, ControlGroup]:
        r"""Extract control groups from the active joint names.

        This method creates a dictionary of control groups where each control group is
        associated with its corresponding joint names. It utilizes the
        `_extract_control_group` method to populate the control groups.

        Returns:
            Dict[str, ControlGroup]: A dictionary mapping control group names to their
            corresponding ControlGroup instances. Empty when no control parts are set.
        """
        if not self.control_parts:
            return {}

        return {
            control_group_name: self._extract_control_group(joint_names)
            for control_group_name, joint_names in self.control_parts.items()
        }

    def _extract_control_group(self, joint_names: List[str]) -> ControlGroup:
        r"""Extract a control group from the given list of joint names.

        Names that are not joints of the robot are skipped.

        Args:
            joint_names (List[str]): A list of joint names to be included in the
                control group.

        Returns:
            ControlGroup: An instance of ControlGroup containing the specified joints
            and their associated links.
        """
        control_group = ControlGroup()
        joint_id_list = []

        for joint_name in joint_names:
            if joint_name not in self.joint_names:
                continue

            joint_index = self.joint_names.index(joint_name)
            joint_id_list.append(joint_index)
            control_group.joint_names.append(joint_name)

            # Set root link for first joint
            if len(control_group.link_names) == 0:
                parent_names = self._entities[0].get_ancestral_link_names(joint_index)
                control_group.link_names.extend(parent_names)

            child_name = self._entities[0].get_child_link_name(joint_index)
            control_group.link_names.append(child_name)

        control_group.joint_ids = joint_id_list
        return control_group

    def build_pk_serial_chain(self) -> None:
        """Build the kinematic serial chain for the robot.

        The chain is created by `self.cfg.build_pk_serial_chain` and stored in
        `self.pk_serial_chain`.

        This method is mainly used for robot learning scenarios, for example:
            - Imitation learning dataset generation.
        """
        self.pk_serial_chain = self.cfg.build_pk_serial_chain(device=self.device)

    def destroy(self) -> None:
        """Destroy the robot by delegating to the parent class."""
        return super().destroy()