# Source code for the ``embodichain.lab.sim.atomic_actions.actions`` module.

# ----------------------------------------------------------------------------
# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ----------------------------------------------------------------------------

from __future__ import annotations

import torch
from typing import Optional, Union, TYPE_CHECKING, Any

from embodichain.lab.sim.planners import PlanResult, PlanState, MoveType
from embodichain.lab.sim.planners.motion_generator import MotionGenOptions
from embodichain.lab.sim.planners.toppra_planner import ToppraPlanOptions
from .core import AtomicAction, ObjectSemantics, AntipodalAffordance, ActionCfg
from embodichain.utils import logger
from embodichain.utils import configclass
from embodichain.lab.sim.utility.action_utils import interpolate_with_distance
import numpy as np

if TYPE_CHECKING:
    from embodichain.lab.sim.planners import MotionGenerator
    from embodichain.lab.sim.objects import Robot


@configclass
class MoveActionCfg(ActionCfg):
    """Settings for :class:`MoveAction`."""

    name: str = "move"
    """Identifier of the action; used for lookup and in log output."""

    sample_interval: int = 50
    """How many waypoints the planned motion is resampled to.

    Larger values give smoother motion at the cost of extra computation.
    """
@configclass
class GraspActionCfg(MoveActionCfg):
    """Common settings for actions that open or close the gripper."""

    hand_open_qpos: torch.Tensor | None = None
    """Hand joint positions for the open state, shape [hand_dof]."""

    hand_close_qpos: torch.Tensor | None = None
    """Hand joint positions for the closed state, shape [hand_dof]."""

    hand_control_part: str = "hand"
    """Name of the robot part whose joints are driven as the hand."""

    lift_height: float = 0.1
    """Distance (m) the end-effector is raised after the gripper phase."""

    sample_interval: int = 80
    """Total waypoint count over all phases (approach + hand + lift/back)."""

    hand_interp_steps: int = 5
    """Waypoint count of the gripper open/close interpolation phase."""
class MoveAction(AtomicAction):
    """Atomic action that moves the arm end-effector to a target pose.

    Besides :meth:`execute`, this class provides planning helpers (pose and
    start-qpos resolution, phase splitting, IK-based arm trajectory planning
    and hand joint interpolation) that subclasses reuse.
    """

    def __init__(
        self,
        motion_generator: MotionGenerator,
        cfg: MoveActionCfg | None = None,
    ):
        """
        Initialize the atomic action.

        Args:
            motion_generator: The motion generator instance to use for planning.
            cfg: Configuration for the action. ``MoveActionCfg()`` is used
                when ``None``.
        """
        super().__init__(
            motion_generator, cfg=cfg if cfg is not None else MoveActionCfg()
        )
        # Batch size of the simulation: the leading dimension of the robot qpos.
        self.n_envs = self.robot.get_qpos().shape[0]
        self.arm_joint_ids = self.robot.get_joint_ids(name=self.cfg.control_part)
        self.dof = len(self.arm_joint_ids)

    def _resolve_pose_target(
        self,
        target: Union[ObjectSemantics, torch.Tensor],
        *,
        action_name: str,
    ) -> tuple[bool, torch.Tensor]:
        """Resolve a pose target into a batched homogeneous transform tensor.

        Args:
            target: Target pose of shape (4, 4), which is repeated for every
                environment, or (n_envs, 4, 4). ObjectSemantics targets are
                reported through ``logger.log_error`` with NotImplementedError.
            action_name: Action name used in the error message.

        Returns:
            ``(True, pose)`` with ``pose`` of shape (n_envs, 4, 4) on
            ``self.device``.
        """
        if isinstance(target, ObjectSemantics):
            logger.log_error(
                f"{action_name} currently does not support ObjectSemantics target. "
                f"Please provide target pose as torch.Tensor of shape (4, 4) or "
                f"(n_envs, 4, 4)",
                NotImplementedError,
            )
        if not isinstance(target, torch.Tensor):
            logger.log_error(
                "Target must be either ObjectSemantics or torch.Tensor of shape "
                f"(4, 4) or ({self.n_envs}, 4, 4)",
                TypeError,
            )
        if target.shape == (4, 4):
            target = target.unsqueeze(0).repeat(self.n_envs, 1, 1)
        if target.shape != (self.n_envs, 4, 4):
            logger.log_error(
                f"Target tensor must have shape (4, 4) or ({self.n_envs}, 4, 4), but got {target.shape}",
                ValueError,
            )
        # Keep every pose on the action device so later arithmetic with
        # device tensors (e.g. pose offsets) never mixes devices.
        return True, target.to(self.device)

    def _resolve_start_qpos(
        self,
        start_qpos: Optional[torch.Tensor],
        arm_dof: Optional[int] = None,
    ) -> torch.Tensor:
        """Resolve planning start joint positions into batched arm joint positions.

        Args:
            start_qpos: Start arm qpos of shape (arm_dof,), which is repeated
                for every environment, or (n_envs, arm_dof). ``None`` reads the
                robot's current qpos of ``cfg.control_part``.
            arm_dof: Number of arm joints; defaults to ``self.dof``.

        Returns:
            Tensor of shape (n_envs, arm_dof) on ``self.device``.
        """
        arm_dof = self.dof if arm_dof is None else arm_dof
        if start_qpos is None:
            start_qpos = self.robot.get_qpos(name=self.cfg.control_part)
        if start_qpos.shape == (arm_dof,):
            start_qpos = start_qpos.unsqueeze(0).repeat(self.n_envs, 1)
        if start_qpos.shape != (self.n_envs, arm_dof):
            logger.log_error(
                f"start_qpos must have shape ({self.n_envs}, {arm_dof}), but got {start_qpos.shape}",
                ValueError,
            )
        # The start qpos is concatenated with device-side trajectories in
        # _plan_arm_trajectory, so it must live on the same device.
        return start_qpos.to(self.device)

    def _compute_three_phase_waypoints(
        self,
        hand_interp_steps: int,
        *,
        first_phase_name: str,
        third_phase_name: str,
        first_phase_ratio: float = 0.6,
    ) -> tuple[int, int, int]:
        """Split total sample interval into motion, hand interpolation, and motion phases.

        The hand phase gets exactly ``hand_interp_steps`` waypoints. The
        remaining ``cfg.sample_interval - hand_interp_steps`` waypoints are
        split between the two motion phases by ``first_phase_ratio``. The three
        counts always sum to ``cfg.sample_interval``.

        Args:
            hand_interp_steps: Waypoints reserved for the hand phase.
            first_phase_name: Name of the first motion phase (for error messages).
            third_phase_name: Name of the last motion phase (for error messages).
            first_phase_ratio: Fraction of the motion waypoints given to the
                first phase.

        Returns:
            ``(n_first, n_hand, n_third)`` waypoint counts.
        """
        n_motion_waypoint = self.cfg.sample_interval - hand_interp_steps
        # Round the *scaled* count. Rounding the integer difference before
        # scaling would be a no-op and int() would then truncate.
        first_phase_waypoint = int(np.round(n_motion_waypoint * first_phase_ratio))
        if first_phase_waypoint < 2:
            logger.log_error(
                f"Not enough waypoints for {first_phase_name} trajectory. "
                "Please increase sample_interval or decrease hand_interp_steps.",
                ValueError,
            )
        second_phase_waypoint = hand_interp_steps
        third_phase_waypoint = (
            self.cfg.sample_interval - first_phase_waypoint - second_phase_waypoint
        )
        if third_phase_waypoint < 2:
            logger.log_error(
                f"Not enough waypoints for {third_phase_name} trajectory. "
                "Please increase sample_interval or decrease hand_interp_steps.",
                ValueError,
            )
        return first_phase_waypoint, second_phase_waypoint, third_phase_waypoint

    def _build_motion_gen_options(
        self,
        start_qpos: torch.Tensor,
        sample_interval: int,
    ) -> MotionGenOptions:
        """Build default motion generation options for an atomic action.

        Only the first environment's start qpos (``start_qpos[0]``) is used.
        """
        return MotionGenOptions(
            start_qpos=start_qpos[0],
            control_part=self.cfg.control_part,
            is_interpolate=True,
            is_linear=False,
            interpolate_position_step=0.001,
            plan_opts=ToppraPlanOptions(
                sample_interval=sample_interval,
            ),
        )

    def _plan_arm_trajectory(
        self,
        target_states_list: list[list[PlanState]],
        start_qpos: torch.Tensor,
        n_waypoints: int,
        arm_dof: Optional[int] = None,
    ) -> tuple[bool, torch.Tensor]:
        """Plan batched arm trajectories for all environments.

        IK is solved sequentially for every target state, seeding each solve
        with the previous solution. The resulting joint path, prefixed by
        ``start_qpos``, is resampled to ``n_waypoints`` waypoints with
        ``interpolate_with_distance``.

        Args:
            target_states_list: Per-environment lists of target states. Every
                environment is assumed to have the same number of states.
            start_qpos: Start arm qpos of shape (n_envs, arm_dof).
            n_waypoints: Number of waypoints of the resampled trajectory.
            arm_dof: Number of arm joints; defaults to ``self.dof``.

        Returns:
            ``(True, traj)`` with the interpolated trajectory on success.
            ``(False, raw)`` with the partially filled, non-interpolated IK
            buffer of shape (n_envs, n_state, arm_dof) if any IK solve fails.
        """
        arm_dof = self.dof if arm_dof is None else arm_dof
        n_state = len(target_states_list[0])
        xpos_traj = torch.zeros(
            size=(self.n_envs, n_state, 4, 4), dtype=torch.float32, device=self.device
        )
        for i, target_states in enumerate(target_states_list):
            for j, target_state in enumerate(target_states):
                # [env_i, state_j, 4, 4]
                xpos_traj[i, j] = target_state.xpos

        trajectory = torch.zeros(
            size=(self.n_envs, n_state, arm_dof),
            dtype=torch.float32,
            device=self.device,
        )
        qpos_seed = start_qpos
        for j in range(n_state):
            is_success, qpos = self.robot.compute_ik(
                pose=xpos_traj[:, j], name=self.cfg.control_part, joint_seed=qpos_seed
            )
            # compute_ik may report success as a plain bool or as a per-env
            # tensor. Reduce it to one flag so `not` never evaluates an
            # ambiguous multi-element tensor.
            if not bool(torch.as_tensor(is_success).all()):
                logger.log_warning(
                    f"Failed to compute IK for target state {j} in some environments. "
                    "The resulting trajectory may be invalid."
                )
                return False, trajectory
            trajectory[:, j] = qpos
            qpos_seed = qpos

        trajectory = torch.concatenate([start_qpos.unsqueeze(1), trajectory], dim=1)
        interp_traj = interpolate_with_distance(
            trajectory=trajectory, interp_num=n_waypoints, device=self.device
        )
        return True, interp_traj

    def _interpolate_hand_qpos(
        self,
        start_hand_qpos: torch.Tensor,
        end_hand_qpos: torch.Tensor,
        n_waypoints: int,
    ) -> torch.Tensor:
        """Interpolate hand joint positions between two gripper states.

        Args:
            start_hand_qpos: Hand qpos at the first waypoint.
            end_hand_qpos: Hand qpos at the last waypoint.
            n_waypoints: Number of evenly spaced waypoints (endpoints included).

        Returns:
            Tensor of shape (n_waypoints, *start_hand_qpos.shape).
        """
        weights = torch.linspace(0, 1, steps=n_waypoints, device=self.device)
        # One broadcast lerp instead of a Python loop: the weights run along a
        # new leading waypoint axis.
        weights = weights.reshape(-1, *([1] * start_hand_qpos.dim()))
        return torch.lerp(
            start_hand_qpos.unsqueeze(0), end_hand_qpos.unsqueeze(0), weights
        )

    def execute(
        self,
        target: Union[ObjectSemantics, torch.Tensor],
        start_qpos: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> tuple[bool, torch.Tensor, list[float]]:
        """Execute the move action: plan an arm trajectory to a target pose.

        Args:
            target (Union[ObjectSemantics, torch.Tensor]): target end-effector
                pose of shape (4, 4) or (n_envs, 4, 4). ObjectSemantics targets
                are not supported yet.
            start_qpos (Optional[torch.Tensor], optional): Planning start qpos.
                Defaults to None (the robot's current arm qpos).

        Returns:
            tuple[bool, torch.Tensor, list[float]]: is_success, trajectory of
            shape (n_envs, n_waypoints, dof), joint_ids corresponding to trajectory
        """
        is_success, move_xpos = self._resolve_pose_target(
            target, action_name=self.__class__.__name__
        )
        start_qpos = self._resolve_start_qpos(start_qpos)
        # TODO: warning and fallback if no valid grasp pose found
        if not is_success:
            logger.log_warning(
                "Failed to resolve grasp pose, using default approach pose"
            )
            return False, torch.empty(0), self.arm_joint_ids

        target_states_list = [
            [PlanState(xpos=move_xpos[i], move_type=MoveType.EEF_MOVE)]
            for i in range(self.n_envs)
        ]
        is_plan_success, trajectory = self._plan_arm_trajectory(
            target_states_list, start_qpos, self.cfg.sample_interval
        )
        return is_plan_success, trajectory, self.arm_joint_ids

    def validate(self, target, start_qpos=None, **kwargs):
        """Validate a move target. Not implemented yet; always returns True."""
        # TODO: implement proper validation logic for move action
        return True
@configclass
class PickUpActionCfg(GraspActionCfg):
    """Settings for :class:`PickUpAction`."""

    name: str = "pick_up"
    """Identifier of the action; used for lookup and in log output."""

    pre_grasp_distance: float = 0.15
    """How far (m) the pre-grasp pose is backed off from the grasp pose.

    The offset is taken along the approach direction and should be large
    enough to avoid collisions while approaching.
    """

    approach_direction: torch.Tensor = torch.tensor([0, 0, -1], dtype=torch.float32)
    """Gripper approach direction in the object's local frame.

    The default ``[0, 0, -1]`` approaches the object from above.
    """
class PickUpAction(MoveAction):
    """Approach, grasp and lift an object with a gripper.

    The produced trajectory concatenates three phases along the waypoint axis:
    approach (pre-grasp -> grasp, hand open), hand closing at the grasp pose,
    and lift (hand closed). Each waypoint holds the arm joints followed by the
    hand joints, in the order of ``self.joint_ids``.
    """

    def __init__(
        self,
        motion_generator: MotionGenerator,
        cfg: PickUpActionCfg | None = None,
    ):
        """
        Initialize the atomic action.

        Args:
            motion_generator: The motion generator instance to use for planning.
            cfg: Configuration for the action. ``PickUpActionCfg()`` is used
                when ``None``. Both ``hand_open_qpos`` and ``hand_close_qpos``
                must be set.
        """
        # Resolve the default before storing it: assigning the raw argument
        # would leave ``self.cfg`` as None when no cfg is given.
        cfg = cfg if cfg is not None else PickUpActionCfg()
        super().__init__(motion_generator, cfg=cfg)
        self.cfg = cfg
        self.approach_direction = self.cfg.approach_direction.to(self.device)
        if self.cfg.hand_open_qpos is None:
            logger.log_error("hand_open_qpos must be specified in PickUpActionCfg")
        if self.cfg.hand_close_qpos is None:
            logger.log_error("hand_close_qpos must be specified in PickUpActionCfg")
        self.hand_open_qpos = self.cfg.hand_open_qpos.to(self.device)
        self.hand_close_qpos = self.cfg.hand_close_qpos.to(self.device)
        self.hand_joint_ids = self.robot.get_joint_ids(name=self.cfg.hand_control_part)
        self.joint_ids = self.arm_joint_ids + self.hand_joint_ids
        # From here on ``self.dof`` counts arm + hand joints, so arm-only
        # helpers must be given ``self.arm_dof`` explicitly.
        self.arm_dof = len(self.arm_joint_ids)
        self.dof = len(self.joint_ids)

    def execute(
        self,
        target: Union[ObjectSemantics, torch.Tensor],
        start_qpos: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> tuple[bool, torch.Tensor, list[float]]:
        """Execute the pick up action.

        Args:
            target (Union[ObjectSemantics, torch.Tensor]): target object
                semantics, or a grasp pose of shape (4, 4) / (n_envs, 4, 4).
            start_qpos (Optional[torch.Tensor], optional): Planning start arm
                qpos. Defaults to None (the robot's current arm qpos).

        Returns:
            tuple[bool, torch.Tensor, list[float]]: is_success, trajectory of
            shape (n_envs, n_waypoints, dof) with dof = arm + hand joints,
            joint_ids corresponding to trajectory
        """
        # Resolve grasp pose
        if isinstance(target, ObjectSemantics):
            is_success, grasp_xpos, _open_length = self._resolve_grasp_pose(target)
        else:
            is_success, grasp_xpos = self._resolve_pose_target(
                target, action_name=self.__class__.__name__
            )

        # TODO: warning and fallback if no valid grasp pose found
        if not is_success:
            logger.log_warning(
                "Failed to resolve grasp pose, using default approach pose"
            )
            return False, torch.empty(0), self.joint_ids

        # Compute pre-grasp pose by backing off along the grasp z axis.
        # TODO: only for parallel gripper, approach in negative grasp z direction
        grasp_z = grasp_xpos[:, :3, 2]
        pre_grasp_xpos = self._apply_offset(
            pose=grasp_xpos,
            offset=-grasp_z * self.cfg.pre_grasp_distance,
        )

        start_qpos = self._resolve_start_qpos(start_qpos, self.arm_dof)

        # compute waypoint number for each phase
        n_approach_waypoint, n_close_waypoint, n_lift_waypoint = (
            self._compute_three_phase_waypoints(
                self.cfg.hand_interp_steps,
                first_phase_name="approach",
                third_phase_name="lift",
            )
        )

        # Phase 1: approach (start -> pre-grasp -> grasp) with the hand open.
        target_states_list = [
            [
                PlanState(xpos=pre_grasp_xpos[i], move_type=MoveType.EEF_MOVE),
                PlanState(xpos=grasp_xpos[i], move_type=MoveType.EEF_MOVE),
            ]
            for i in range(self.n_envs)
        ]
        pick_trajectory = torch.zeros(
            size=(self.n_envs, n_approach_waypoint, self.dof),
            dtype=torch.float32,
            device=self.device,
        )
        is_success, plan_traj = self._plan_arm_trajectory(
            target_states_list,
            start_qpos,
            n_approach_waypoint,
            self.arm_dof,
        )
        if not is_success:
            logger.log_warning("Failed to plan approach trajectory.")
            return False, pick_trajectory, self.joint_ids
        pick_trajectory[:, :, : self.arm_dof] = plan_traj
        pick_trajectory[:, :, self.arm_dof :] = self.hand_open_qpos

        # Phase 2: close the hand while the arm holds the last approach qpos,
        # assumed to be the grasp configuration.
        grasp_qpos = pick_trajectory[:, -1, : self.arm_dof]
        hand_close_path = self._interpolate_hand_qpos(
            self.hand_open_qpos,
            self.hand_close_qpos,
            n_close_waypoint,
        )
        hand_close_trajectory = torch.zeros(
            size=(self.n_envs, n_close_waypoint, self.dof),
            device=self.device,
        )
        hand_close_trajectory[:, :, : self.arm_dof] = grasp_qpos
        hand_close_trajectory[:, :, self.arm_dof :] = hand_close_path

        # Phase 3: lift by lift_height along +z with the hand closed.
        lift_trajectory = torch.zeros(
            size=(self.n_envs, n_lift_waypoint, self.dof),
            dtype=torch.float32,
            device=self.device,
        )
        lift_xpos = self._apply_offset(
            pose=grasp_xpos,
            offset=torch.tensor([0, 0, 1], device=self.device) * self.cfg.lift_height,
        )
        target_states_list = [
            [PlanState(xpos=lift_xpos[i], move_type=MoveType.EEF_MOVE)]
            for i in range(self.n_envs)
        ]
        is_success, plan_traj = self._plan_arm_trajectory(
            target_states_list,
            grasp_qpos,
            n_lift_waypoint,
            self.arm_dof,
        )
        if not is_success:
            logger.log_warning("Failed to plan lift trajectory.")
            return False, lift_trajectory, self.joint_ids
        lift_trajectory[:, :, : self.arm_dof] = plan_traj
        lift_trajectory[:, :, self.arm_dof :] = self.hand_close_qpos

        trajectory = torch.cat(
            [pick_trajectory, hand_close_trajectory, lift_trajectory], dim=1
        )
        return True, trajectory, self.joint_ids

    def _resolve_grasp_pose(
        self, semantics: ObjectSemantics
    ) -> tuple[bool, torch.Tensor, torch.Tensor]:
        """Query the object's antipodal affordance for the best grasp poses.

        Args:
            semantics: Object semantics whose ``affordance`` must be an
                ``AntipodalAffordance`` and whose ``entity`` supplies the
                object pose.

        Returns:
            ``(is_success, grasp_xpos, open_length)`` as returned by
            ``affordance.get_best_grasp_poses``.
        """
        if not isinstance(semantics.affordance, AntipodalAffordance):
            logger.log_error(
                "Grasp pose affordance must be of type AntipodalAffordance"
            )
        if semantics.entity is None:
            logger.log_error(
                "ObjectSemantics must be associated with an entity to get object pose"
            )
        obj_poses = semantics.entity.get_local_pose(to_matrix=True)
        is_success, grasp_xpos, open_length = semantics.affordance.get_best_grasp_poses(
            obj_poses=obj_poses, approach_direction=self.approach_direction
        )
        return is_success, grasp_xpos, open_length

    def validate(self, target, start_qpos=None, **kwargs):
        """Validate a pick-up target. Not implemented yet; always returns True."""
        # TODO: implement proper validation logic for pick up action
        return True
@configclass
class PlaceActionCfg(GraspActionCfg):
    """Settings for :class:`PlaceAction`."""

    name: str = "place"
    """Identifier of the action; used for lookup and in log output."""
class PlaceAction(MoveAction):
    """Lower a held object onto a place pose, release it and retreat.

    The produced trajectory concatenates three phases along the waypoint axis:
    down (start -> hover pose -> place pose, hand closed), hand opening at the
    place pose, and back to the hover pose (hand open). Each waypoint holds the
    arm joints followed by the hand joints, in the order of ``self.joint_ids``.
    """

    def __init__(
        self,
        motion_generator: MotionGenerator,
        cfg: PlaceActionCfg | None = None,
    ):
        """
        Initialize the atomic action.

        Args:
            motion_generator: The motion generator instance to use for planning.
            cfg: Configuration for the action. ``PlaceActionCfg()`` is used
                when ``None``. Both ``hand_open_qpos`` and ``hand_close_qpos``
                must be set.
        """
        # Resolve the default before storing it: assigning the raw argument
        # would leave ``self.cfg`` as None when no cfg is given.
        cfg = cfg if cfg is not None else PlaceActionCfg()
        super().__init__(motion_generator, cfg=cfg)
        self.cfg = cfg
        if self.cfg.hand_open_qpos is None:
            logger.log_error("hand_open_qpos must be specified in PlaceActionCfg")
        if self.cfg.hand_close_qpos is None:
            logger.log_error("hand_close_qpos must be specified in PlaceActionCfg")
        self.hand_open_qpos = self.cfg.hand_open_qpos.to(self.device)
        self.hand_close_qpos = self.cfg.hand_close_qpos.to(self.device)
        self.hand_joint_ids = self.robot.get_joint_ids(name=self.cfg.hand_control_part)
        self.joint_ids = self.arm_joint_ids + self.hand_joint_ids
        # From here on ``self.dof`` counts arm + hand joints, so arm-only
        # helpers must be given ``self.arm_dof`` explicitly.
        self.arm_dof = len(self.arm_joint_ids)
        self.dof = len(self.joint_ids)

    def execute(
        self,
        target: Union[ObjectSemantics, torch.Tensor],
        start_qpos: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> tuple[bool, torch.Tensor, list[float]]:
        """Execute the place action.

        Args:
            target (Union[ObjectSemantics, torch.Tensor]): place pose of the
                end-effector, shape (4, 4) or (n_envs, 4, 4). ObjectSemantics
                targets are not supported yet.
            start_qpos (Optional[torch.Tensor], optional): Planning start arm
                qpos. Defaults to None (the robot's current arm qpos).

        Returns:
            tuple[bool, torch.Tensor, list[float]]: is_success, trajectory of
            shape (n_envs, n_waypoints, dof) with dof = arm + hand joints,
            joint_ids corresponding to trajectory
        """
        is_success, place_xpos = self._resolve_pose_target(
            target, action_name=self.__class__.__name__
        )
        start_qpos = self._resolve_start_qpos(start_qpos, self.arm_dof)
        # TODO: warning and fallback if no valid place pose found
        if not is_success:
            logger.log_warning("Failed to resolve place pose.")
            return False, torch.empty(0), self.joint_ids

        # compute waypoint number for each phase
        n_down_waypoint, n_open_waypoint, n_back_waypoint = (
            self._compute_three_phase_waypoints(
                self.cfg.hand_interp_steps,
                first_phase_name="down",
                third_phase_name="back",
            )
        )

        # Hover pose: the place pose offset by lift_height along +z. It is the
        # intermediate waypoint on the way down and the retreat target.
        lift_xpos = self._apply_offset(
            pose=place_xpos,
            offset=torch.tensor([0, 0, 1], device=self.device) * self.cfg.lift_height,
        )

        # Phase 1: move down (start -> hover -> place) with the hand closed.
        down_trajectory = torch.zeros(
            size=(self.n_envs, n_down_waypoint, self.dof),
            dtype=torch.float32,
            device=self.device,
        )
        target_states_list = [
            [
                PlanState(xpos=lift_xpos[i], move_type=MoveType.EEF_MOVE),
                PlanState(xpos=place_xpos[i], move_type=MoveType.EEF_MOVE),
            ]
            for i in range(self.n_envs)
        ]
        is_success, plan_traj = self._plan_arm_trajectory(
            target_states_list,
            start_qpos,
            n_down_waypoint,
            self.arm_dof,
        )
        if not is_success:
            logger.log_warning("Failed to plan down trajectory.")
            return False, down_trajectory, self.joint_ids
        down_trajectory[:, :, : self.arm_dof] = plan_traj
        down_trajectory[:, :, self.arm_dof :] = self.hand_close_qpos

        # Phase 2: open the hand while the arm holds the last down qpos,
        # assumed to be the place configuration.
        reach_qpos = down_trajectory[:, -1, : self.arm_dof]
        hand_open_path = self._interpolate_hand_qpos(
            self.hand_close_qpos,
            self.hand_open_qpos,
            n_open_waypoint,
        )
        hand_open_trajectory = torch.zeros(
            size=(self.n_envs, n_open_waypoint, self.dof),
            device=self.device,
        )
        hand_open_trajectory[:, :, : self.arm_dof] = reach_qpos
        hand_open_trajectory[:, :, self.arm_dof :] = hand_open_path

        # Phase 3: move back up to the hover pose with the hand open.
        back_trajectory = torch.zeros(
            size=(self.n_envs, n_back_waypoint, self.dof),
            dtype=torch.float32,
            device=self.device,
        )
        target_states_list = [
            [PlanState(xpos=lift_xpos[i], move_type=MoveType.EEF_MOVE)]
            for i in range(self.n_envs)
        ]
        is_success, plan_traj = self._plan_arm_trajectory(
            target_states_list,
            reach_qpos,
            n_back_waypoint,
            self.arm_dof,
        )
        if not is_success:
            logger.log_warning("Failed to plan back trajectory.")
            return False, back_trajectory, self.joint_ids
        back_trajectory[:, :, : self.arm_dof] = plan_traj
        back_trajectory[:, :, self.arm_dof :] = self.hand_open_qpos

        trajectory = torch.cat(
            [down_trajectory, hand_open_trajectory, back_trajectory], dim=1
        )
        return True, trajectory, self.joint_ids

    def validate(self, target, start_qpos=None, **kwargs):
        """Validate a place target. Not implemented yet; always returns True."""
        # TODO: implement proper validation logic for place action
        return True