# Source code for embodichain.lab.sim.atomic_actions.core

# ----------------------------------------------------------------------------
# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ----------------------------------------------------------------------------

from __future__ import annotations

import torch
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union, TYPE_CHECKING

from embodichain.lab.sim.planners import PlanResult, PlanState, MoveType
from embodichain.utils import configclass

from embodichain.toolkits.graspkit.pg_grasp import (
    GraspGenerator,
    GraspGeneratorCfg,
)
from embodichain.toolkits.graspkit.pg_grasp.gripper_collision_checker import (
    GripperCollisionCfg,
)
from embodichain.lab.sim.common import BatchEntity
from embodichain.utils import logger

if TYPE_CHECKING:
    from embodichain.lab.sim.planners import MotionGenerator, MotionGenOptions
    from embodichain.lab.sim.objects import Robot


# =============================================================================
# Affordance Classes
# =============================================================================


@dataclass
class Affordance:
    """Common base for every affordance description.

    An affordance captures one way an object can be interacted with.
    Concrete affordance types derive from this class and add their own data.
    """

    object_label: str = ""
    """Label of the object that owns this affordance."""

    geometry: Dict[str, Any] = field(default_factory=dict)
    """Geometry dictionary shared with ObjectSemantics.

    Mesh data, when present, lives under two keys:
        - ``mesh_vertices``: torch.Tensor with shape [N, 3]
        - ``mesh_triangles``: torch.Tensor with shape [M, 3]
    """

    custom_config: Dict[str, Any] = field(default_factory=dict)
    """Free-form, user-supplied settings used when creating or using the affordance."""

    @property
    def mesh_vertices(self) -> torch.Tensor | None:
        """Mesh vertices stored in ``geometry``.

        Returns:
            The ``mesh_vertices`` tensor ([N, 3]), or None when the key is
            absent or explicitly set to None.

        Raises:
            TypeError: If a non-None value that is not a torch tensor is stored.
        """
        stored = self.geometry.get("mesh_vertices")
        # None (missing) and genuine tensors are both handed back untouched.
        if stored is None or isinstance(stored, torch.Tensor):
            return stored
        raise TypeError("geometry['mesh_vertices'] must be a torch.Tensor")

    @property
    def mesh_triangles(self) -> torch.Tensor | None:
        """Mesh triangle indices stored in ``geometry``.

        Returns:
            The ``mesh_triangles`` tensor ([M, 3]), or None when the key is
            absent or explicitly set to None.

        Raises:
            TypeError: If a non-None value that is not a torch tensor is stored.
        """
        stored = self.geometry.get("mesh_triangles")
        # None (missing) and genuine tensors are both handed back untouched.
        if stored is None or isinstance(stored, torch.Tensor):
            return stored
        raise TypeError("geometry['mesh_triangles'] must be a torch.Tensor")

    def set_custom_config(self, key: str, value: Any) -> None:
        """Store ``value`` under ``key`` in ``custom_config``."""
        self.custom_config[key] = value

    def get_custom_config(self, key: str, default: Any = None) -> Any:
        """Look up ``key`` in ``custom_config``, falling back to ``default``."""
        settings = self.custom_config
        return settings.get(key, default)

    def get_batch_size(self) -> int:
        """Return the batch size of this affordance data (always 1 for the base class)."""
        return 1
@dataclass
class AntipodalAffordance(Affordance):
    """Affordance that provides grasp poses through a ``GraspGenerator``.

    The generator is built on first use from the mesh stored in ``geometry``
    (``mesh_vertices`` / ``mesh_triangles``) and from the optional
    ``generator_cfg`` / ``gripper_collision_cfg`` entries of ``custom_config``.
    """

    generator: GraspGenerator | None = None
    """Grasp generator instance, initialized lazily when needed."""

    force_reannotate: bool = False
    """Whether ``annotate()`` is called unconditionally when the generator is
    built lazily. When False, annotation only runs if the freshly built
    generator has no hit-point pairs. The flag is not consulted for a
    generator that was supplied by the caller."""

    is_draw_grasp_xpos: bool = False
    """Whether to visualize grasp poses in the simulator."""

    def _init_generator(self):
        """Build ``self.generator`` from the mesh in ``geometry`` and annotate it.

        Reports an error through ``logger.log_error`` when either mesh tensor
        is missing from ``geometry``.
        """
        if (
            self.geometry.get("mesh_vertices", None) is None
            or self.geometry.get("mesh_triangles", None) is None
        ):
            logger.log_error(
                "Mesh vertices and triangles must be provided in geometry to initialize AntipodalAffordance."
            )
        self.generator = GraspGenerator(
            vertices=self.geometry.get("mesh_vertices"),
            triangles=self.geometry.get("mesh_triangles"),
            cfg=self.custom_config.get("generator_cfg", None),
            gripper_collision_cfg=self.custom_config.get("gripper_collision_cfg", None),
        )
        # Annotate when explicitly forced, or when the generator holds no
        # hit-point pairs yet.
        if self.force_reannotate or self.generator._hit_point_pairs is None:
            self.generator.annotate()

    def get_best_grasp_poses(
        self,
        obj_poses: torch.Tensor,
        approach_direction: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Query one grasp pose per object pose.

        Args:
            obj_poses: Batch of object poses, iterated along the first
                dimension (presumably [B, 4, 4] -- confirm against
                ``GraspGenerator.get_grasp_poses``).
            approach_direction: Approach direction forwarded to the generator.
                Defaults to ``[0, 0, -1]`` (float32, CPU) when None.

        Returns:
            Tuple ``(is_success, grasp_xpos, open_length)``:
            ``is_success`` is a bool tensor with one entry per object,
            ``grasp_xpos`` stacks the grasp poses ([B, 4, 4]; the identity
            matrix is used for objects without a valid grasp), and
            ``open_length`` is a float32 tensor with one entry per object.
        """
        if approach_direction is None:
            # Built per call: a tensor default in the signature would be
            # created once at definition time and shared by every call.
            approach_direction = torch.tensor([0, 0, -1], dtype=torch.float32)
        if self.generator is None:
            self._init_generator()
        device = self.generator.device

        grasp_xpos_list = []
        is_success_list = []
        open_length_list = []
        for i, obj_pose in enumerate(obj_poses):
            found, pose, width = self.generator.get_grasp_poses(
                obj_pose, approach_direction
            )
            if found:
                grasp_xpos_list.append(pose.unsqueeze(0))
            else:
                logger.log_warning(f"No valid grasp pose found for {i}-th object.")
                # Default to identity pose if no grasp found
                grasp_xpos_list.append(
                    torch.eye(4, dtype=torch.float32, device=device).unsqueeze(0)
                )
            is_success_list.append(found)
            open_length_list.append(width)

        is_success = torch.tensor(is_success_list, dtype=torch.bool, device=device)
        grasp_xpos = torch.concatenate(grasp_xpos_list, dim=0)  # [B, 4, 4]
        open_length = torch.tensor(
            open_length_list, dtype=torch.float32, device=device
        )
        if self.is_draw_grasp_xpos:
            self._draw_grasp_xpos(grasp_xpos, open_length)
        return is_success, grasp_xpos, open_length

    def _draw_grasp_xpos(self, grasp_xpos: torch.Tensor, open_length: torch.Tensor):
        """Draw an axis marker in the simulator for every grasp pose.

        Args:
            grasp_xpos: Grasp poses, one per entry along the first dimension.
            open_length: Unused here; kept so the signature stays stable.
        """
        # Neither name is imported at module level, so drawing used to fail
        # with NameError. Imported here to keep the fix local to the method.
        # NOTE(review): confirm these module paths against the package layout.
        from embodichain.lab.sim import SimulationManager
        from embodichain.lab.sim.cfg import MarkerCfg

        sim = SimulationManager.get_instance()
        axis_xpos = [pose.to("cpu").numpy() for pose in grasp_xpos]
        sim.draw_marker(
            cfg=MarkerCfg(
                name="grasp_xpos",
                axis_xpos=axis_xpos,
                axis_len=0.05,
            )
        )
@dataclass
class InteractionPoints(Affordance):
    """Affordance made of a batch of 3D interaction points.

    The points mark places on an object's surface intended for contact-style
    interaction (pushing, poking, touching) as opposed to a full grasp.
    """

    points: torch.Tensor = field(default_factory=lambda: torch.zeros(1, 3))
    """Batch of 3D interaction points with shape [B, 3].

    Every point is a coordinate in the object's local frame.
    """

    normals: torch.Tensor | None = None
    """Optional surface normals, one per interaction point, with shape [B, 3].

    A normal gives the surface orientation at its point and is used to derive
    an approach direction.
    """

    point_types: List[str] = field(default_factory=list)
    """Optional interaction-type label for each point.

    Examples: "push", "poke", "touch", "pinch"
    """

    def get_points_by_type(self, point_type: str) -> torch.Tensor | None:
        """Select the points labelled with a given interaction type.

        Args:
            point_type: Interaction type to look for (e.g., "push", "poke").

        Returns:
            The rows of ``points`` whose label equals ``point_type``, or None
            when no point carries that label.
        """
        matching = [
            idx for idx, label in enumerate(self.point_types) if label == point_type
        ]
        if not matching:
            return None
        return self.points[matching]

    def get_batch_size(self) -> int:
        """Return how many interaction points this affordance holds."""
        return self.points.size(0)

    def get_approach_direction(self, point_idx: int) -> torch.Tensor:
        """Return the suggested approach direction for one point.

        Args:
            point_idx: Index of the point.

        Returns:
            The negated surface normal at ``point_idx`` when normals are
            available (returned as stored, without re-normalization);
            otherwise the +z unit vector with the dtype and device of
            ``points``.
        """
        if self.normals is None:
            # No normals recorded: fall back to approaching along positive z.
            return torch.tensor(
                [0, 0, 1], dtype=self.points.dtype, device=self.points.device
            )
        # Approach against the surface normal.
        return -self.normals[point_idx]
# =============================================================================
# ObjectSemantics
# =============================================================================
@dataclass
class ObjectSemantics:
    """Semantic description of an interaction target.

    Bundles the semantic and geometric information about an object that
    interaction planning needs.
    """

    affordance: Affordance
    """Affordance data (GraspPose, InteractionPoints, etc.)."""

    geometry: Dict[str, Any]
    """Geometric information including bounding box, mesh data."""

    properties: Dict[str, Any] = field(default_factory=dict)
    """Physical properties: mass, friction, etc."""

    label: str = "none"
    """Object category label (e.g., 'apple', 'bottle')."""

    entity: BatchEntity | None = None
    """Optional reference to the simulation entity that represents this object."""

    def __post_init__(self) -> None:
        """Propagate this object's label and geometry onto its affordance.

        The affordance receives the very same geometry dict instance as
        ``ObjectSemantics.geometry``, so mesh tensors only need to be written
        in one place.
        """
        linked = self.affordance
        linked.object_label = self.label
        linked.geometry = self.geometry
# =============================================================================
# ActionCfg and AtomicAction
# =============================================================================
@configclass
class ActionCfg:
    """Settings shared by atomic actions."""

    name: str = "default"
    """Action name, used to identify the action and in log output."""

    control_part: str = "arm"
    """Name of the control part the action drives."""

    interpolation_type: str = "linear"
    """Interpolation type: 'linear', 'cubic'."""

    velocity_limit: Optional[float] = None
    """Velocity limit for the motion; None leaves it unset."""

    acceleration_limit: Optional[float] = None
    """Acceleration limit for the motion; None leaves it unset."""
class AtomicAction(ABC):
    """Abstract base class for atomic actions.

    An atomic action holds a motion generator, the robot that generator
    drives, and an :class:`ActionCfg`. Subclasses implement :meth:`execute`
    and :meth:`validate`. Trajectory planning is delegated to the motion
    generator through :meth:`plan_trajectory`, which returns a ``PlanResult``
    from ``embodichain.lab.sim.planners``.
    """

    def __init__(
        self,
        motion_generator: MotionGenerator,
        cfg: ActionCfg | None = None,
    ):
        """Initialize the atomic action.

        Args:
            motion_generator: The motion generator instance to use for planning.
            cfg: Configuration for the action. A fresh ``ActionCfg()`` is
                created when omitted.
        """
        # Built per instance: a default constructed in the signature would be
        # a single object shared (and mutable) across every action.
        if cfg is None:
            cfg = ActionCfg()
        self.motion_generator = motion_generator
        self.cfg = cfg
        self.robot = motion_generator.robot
        self.control_part = cfg.control_part
        self.device = self.robot.device

    @abstractmethod
    def execute(
        self,
        target: Union[torch.Tensor, ObjectSemantics],
        start_qpos: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> tuple[bool, torch.Tensor, list[float]]:
        """Execute the action toward ``target``.

        Args:
            target: Either a tensor target or an :class:`ObjectSemantics`
                describing the object to interact with.
            start_qpos: Planning start qpos. Defaults to None.
            **kwargs: Action-specific options defined by subclasses.

        Returns:
            tuple[bool, torch.Tensor, list[float]]: is_success, trajectory of
            shape (n_envs, n_waypoints, dof), joint_ids corresponding to
            trajectory.
        """

    @abstractmethod
    def validate(
        self,
        target: Union[torch.Tensor, ObjectSemantics],
        start_qpos: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> bool:
        """Validate if the action is feasible without executing.

        Intended as a quick feasibility check (e.g., IK solvability) that
        does not generate a full trajectory.

        Returns:
            True if action appears feasible, False otherwise.
        """
        pass

    def _ik_solve(
        self, target_pose: torch.Tensor, qpos_seed: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Solve IK for a single target pose.

        Args:
            target_pose: Target pose [4, 4].
            qpos_seed: Seed configuration [DOF]. Defaults to
                ``self.robot.get_qpos()``.

        Returns:
            Joint configuration with the leading batch dimension removed.

        Raises:
            RuntimeError: If IK fails to find a solution.
        """
        if qpos_seed is None:
            # NOTE(review): the seed gets a batch dimension added below; if
            # get_qpos() already returns a batched tensor this adds an extra
            # dimension -- confirm against Robot.get_qpos.
            qpos_seed = self.robot.get_qpos()
        success, qpos = self.robot.compute_ik(
            pose=target_pose.unsqueeze(0),
            qpos_seed=qpos_seed.unsqueeze(0),
            name=self.control_part,
        )
        if not success.all():
            raise RuntimeError(f"IK failed for target pose: {target_pose}")
        return qpos.squeeze(0)

    def _fk_compute(self, qpos: torch.Tensor) -> torch.Tensor:
        """Compute forward kinematics for the control part.

        Args:
            qpos: Joint configuration [DOF] or [B, DOF].

        Returns:
            End-effector pose [4, 4] or [B, 4, 4]. The batch dimension is
            dropped whenever the result has batch size 1, so a [1, DOF] input
            also yields [4, 4].
        """
        if qpos.dim() == 1:
            qpos = qpos.unsqueeze(0)
        xpos = self.robot.compute_fk(
            qpos=qpos,
            name=self.control_part,
            to_matrix=True,
        )
        return xpos.squeeze(0) if xpos.shape[0] == 1 else xpos

    def _apply_offset(self, pose: torch.Tensor, offset: torch.Tensor) -> torch.Tensor:
        """Translate poses by an offset, leaving the input untouched.

        The offset is added directly to the translation column of each pose;
        it is not rotated by the pose's orientation.
        NOTE(review): earlier documentation described the offset as being in
        the local frame -- confirm which behaviour callers expect.

        Args:
            pose: Base pose [N, 4, 4].
            offset: Offset [N, 3] or [3] (a single offset applies to every pose).

        Returns:
            New tensor [N, 4, 4] with the offset added to the translation.
        """
        if pose.dim() != 3 or pose.shape[1:] != (4, 4):
            logger.log_error("pose must have shape [N, 4, 4]")
        if offset.dim() == 1:
            offset = offset.unsqueeze(0)
        if offset.dim() != 2 or offset.shape[1] != 3:
            logger.log_error("offset must have shape [N, 3] or [3]")
        result = pose.clone()
        result[:, :3, 3] += offset
        return result

    def plan_trajectory(
        self,
        target_states: List[PlanState],
        options: Optional["MotionGenOptions"] = None,
    ) -> "PlanResult":
        """Plan a trajectory through ``target_states`` with the motion generator.

        Args:
            target_states: Plan states handed to the motion generator.
            options: Generation options. Defaults to
                ``MotionGenOptions(control_part=self.control_part)``.

        Returns:
            The result of ``self.motion_generator.generate``.
        """
        # Imported at call time: the module-level import of MotionGenOptions
        # is guarded by TYPE_CHECKING and therefore absent at runtime.
        from embodichain.lab.sim.planners import MotionGenOptions

        if options is None:
            options = MotionGenOptions(control_part=self.control_part)
        return self.motion_generator.generate(target_states, options)