Source code for embodichain.lab.sim.atomic_actions.engine

# ----------------------------------------------------------------------------
# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ----------------------------------------------------------------------------

from __future__ import annotations

import torch
from typing import Any, Dict, List, Optional, Type, Union, TYPE_CHECKING

from embodichain.lab.sim.planners import PlanResult
from embodichain.utils import logger
from .core import AtomicAction, ObjectSemantics, ActionCfg

if TYPE_CHECKING:
    from embodichain.lab.sim.planners import MotionGenerator
    from embodichain.lab.sim.objects import Robot


# =============================================================================
# Global Action Registry
# =============================================================================

# Maps an action-type name to the AtomicAction subclass registered under it.
# Written by register_action(), removed by unregister_action(), copied out by
# get_registered_actions(), and consulted by AtomicActionEngine._init_actions()
# when a config name is not one of the built-in actions.
_global_action_registry: Dict[str, Type[AtomicAction]] = {}
# Maps an action-type name to the ActionCfg subclass optionally supplied at
# registration time. Written by register_action() and removed by
# unregister_action(); nothing in the code visible here reads it back.
_global_action_configs: Dict[str, Type[ActionCfg]] = {}


def register_action(
    name: str,
    action_class: Type[AtomicAction],
    config_class: Optional[Type[ActionCfg]] = None,
) -> None:
    """Register a custom atomic action class globally.

    This function allows registration of custom action types that can then be
    instantiated by the AtomicActionEngine. Registering a name that already
    exists replaces the previous registration.

    Args:
        name: Unique identifier for the action type
        action_class: The AtomicAction subclass to register
        config_class: Optional configuration class for the action. When
            omitted, any config class left over from an earlier registration
            of the same name is removed.

    Example:
        >>> class MyCustomAction(AtomicAction):
        ...     def execute(self, target, **kwargs):
        ...         # Implementation
        ...         pass
        ...     def validate(self, target, **kwargs):
        ...         return True
        >>> register_action("my_custom", MyCustomAction)
    """
    _global_action_registry[name] = action_class
    if config_class is not None:
        _global_action_configs[name] = config_class
    else:
        # Re-registering without a config must not leave the previous
        # registration's config class paired with the new action class.
        _global_action_configs.pop(name, None)
def unregister_action(name: str) -> None:
    """Remove an action type from the global registries.

    Both the action class and any config class stored under ``name`` are
    dropped. Names that were never registered are ignored.

    Args:
        name: The action type identifier to remove
    """
    for registry in (_global_action_registry, _global_action_configs):
        registry.pop(name, None)
def get_registered_actions() -> Dict[str, Type[AtomicAction]]:
    """Return a snapshot of the globally registered action types.

    The result is a shallow copy, so modifying it does not affect the
    registry itself.

    Returns:
        Dictionary mapping action names to their classes
    """
    return dict(_global_action_registry)
# =============================================================================
# Semantic Analyzer
# =============================================================================


class SemanticAnalyzer:
    """Analyzes objects and provides ObjectSemantics for atomic actions.

    Results of the default analysis path (no geometry, no custom config) are
    memoized per label in ``_object_cache``.
    """

    def __init__(self):
        # label -> semantics produced by the default analysis path
        self._object_cache: Dict[str, ObjectSemantics] = {}

    def analyze(
        self,
        label: str,
        geometry: Optional[Dict[str, Any]] = None,
        custom_config: Optional[Dict[str, Any]] = None,
        use_cache: bool = True,
    ) -> ObjectSemantics:
        """Analyze object by label and return ObjectSemantics.

        This is a placeholder implementation that should be extended with
        actual object detection and affordance computation.

        Args:
            label: Object category label (e.g., "apple", "bottle")
            geometry: Optional geometry payload. Can include mesh tensors:
                ``mesh_vertices`` [N, 3] and ``mesh_triangles`` [M, 3].
                Entries are merged over the default ``bounding_box`` entry.
            custom_config: Optional user-defined affordance configuration.
            use_cache: Whether to use cached semantics when available.

        Returns:
            ObjectSemantics containing affordance data. On a cache hit the
            cached instance itself is returned, not a copy.
        """
        # Only the default analyze path (no per-call overrides) is cacheable.
        cacheable = use_cache and geometry is None and custom_config is None
        if cacheable and label in self._object_cache:
            return self._object_cache[label]

        # Create default semantics (placeholder implementation)
        from .core import AntipodalAffordance

        default_geometry: Dict[str, Any] = {"bounding_box": [0.1, 0.1, 0.1]}
        if geometry is not None:
            default_geometry.update(geometry)

        grasp_affordance = AntipodalAffordance(
            object_label=label,
            custom_config=custom_config or {},
        )

        semantics = ObjectSemantics(
            label=label,
            affordance=grasp_affordance,
            geometry=default_geometry,
            properties={"mass": 1.0, "friction": 0.5},
        )

        # Cache only default path
        if cacheable:
            self._object_cache[label] = semantics

        return semantics

    def clear_cache(self) -> None:
        """Clear the object semantics cache."""
        self._object_cache.clear()


# =============================================================================
# Atomic Action Engine
# =============================================================================
class AtomicActionEngine:
    """Central engine for managing and executing atomic actions.

    The engine owns one action instance per entry of ``actions_cfg_list``
    (keyed by ``cfg.name``) and a SemanticAnalyzer used to turn labels or
    dict descriptions into ObjectSemantics targets.
    """

    def __init__(
        self,
        motion_generator: "MotionGenerator",
        actions_cfg_list: Optional[List[ActionCfg]] = None,
    ):
        """Create the engine.

        Args:
            motion_generator: Planner handed to every action; its ``robot``
                and ``device`` attributes are mirrored on the engine.
            actions_cfg_list: Configs of the actions to instantiate, in
                execution order. ``None`` creates an engine with no actions.
        """
        self.motion_generator = motion_generator
        self.robot = self.motion_generator.robot
        self.device = self.motion_generator.device

        # Semantic analyzer for object understanding
        self._semantic_analyzer = SemanticAnalyzer()

        # Initialize default actions
        self._actions: Dict[str, AtomicAction] = self._init_actions(actions_cfg_list)

    def _init_actions(
        self, actions_cfg_list: Optional[List[ActionCfg]] = None
    ) -> Dict[str, "AtomicAction"]:
        """Instantiate one action per config, keyed by ``cfg.name``.

        Built-in actions take precedence over globally registered ones. A
        config whose name resolves to no class is reported and skipped. If two
        configs share a name, the later instance replaces the earlier one.
        """
        actions: Dict[str, AtomicAction] = {}

        from .actions import MoveAction, PickUpAction, PlaceAction

        builtin_action_map: Dict[str, Type[AtomicAction]] = {
            "move": MoveAction,
            "pick_up": PickUpAction,
            "place": PlaceAction,
        }

        if actions_cfg_list is not None:
            for cfg in actions_cfg_list:
                action_class = builtin_action_map.get(
                    cfg.name
                ) or _global_action_registry.get(cfg.name)
                if action_class is None:
                    logger.log_error(f"Unknown action name in config: {cfg.name}")
                    continue
                instance = action_class(motion_generator=self.motion_generator, cfg=cfg)
                actions[cfg.name] = instance

        return actions

    def execute_static(
        self,
        target_list: List[Union[torch.Tensor, str, ObjectSemantics, Dict[str, Any]]],
    ) -> tuple[bool, torch.Tensor]:
        """Execute a sequence of actions to target poses.

        Each element in ``target_list`` corresponds to an action in the order
        they were registered via ``actions_cfg_list``.

        Args:
            target_list: One target per configured action; every entry is
                passed through ``_resolve_target`` before execution.

        Returns:
            ``(success, trajectory)``. ``trajectory`` is a float32 tensor of
            shape ``(n_envs, total_waypoints, dof)`` holding the concatenated
            full-robot trajectories. If an action fails, ``success`` is False
            and the trajectory accumulated before the failure is returned.
        """
        action_names = list(self._actions)
        robot = self.motion_generator.robot

        # Work on a private copy: the running start state is updated in place
        # below and must not write through to the robot's own qpos buffer.
        start_qpos = robot.get_qpos().clone()
        n_envs = start_qpos.shape[0]
        all_dof = robot.dof
        all_trajectory = torch.empty(
            size=(n_envs, 0, all_dof), dtype=torch.float32, device=self.device
        )

        if len(target_list) != len(action_names):
            logger.log_error(
                f"Length of target_list ({len(target_list)}) must match number of actions ({len(action_names)})."
            )
            # Only reached if log_error returns instead of raising; never let
            # zip() silently pair a truncated list of targets with actions.
            return False, all_trajectory

        for action_name, target in zip(action_names, target_list):
            atom_action = self._actions[action_name]
            target = self._resolve_target(target)

            control_part = atom_action.control_part
            arm_joint_ids = robot.get_joint_ids(name=control_part)
            start_qpos_part = start_qpos[:, arm_joint_ids]

            is_success, traj, joint_ids = atom_action.execute(
                target=target, start_qpos=start_qpos_part
            )
            if not is_success:
                return False, all_trajectory

            n_waypoints = traj.shape[1]
            traj_full = torch.zeros(
                size=(n_envs, n_waypoints, all_dof),
                dtype=torch.float32,
                device=self.device,
            )
            # Hold every joint at its current value for all waypoints. The
            # explicit waypoint axis is required: (n_envs, dof) would otherwise
            # be broadcast against (n_waypoints, dof) instead of per env.
            traj_full[:, :] = start_qpos.unsqueeze(1)
            # Overwrite the joints driven by this action with its plan.
            traj_full[:, :, joint_ids] = traj
            all_trajectory = torch.cat((all_trajectory, traj_full), dim=1)

            # update start qpos for the next action
            start_qpos[:, joint_ids] = traj[:, -1, :]

        return True, all_trajectory

    def validate(
        self,
        action_name: str,
        target: Union[torch.Tensor, str, ObjectSemantics, Dict[str, Any]],
        **kwargs,
    ) -> bool:
        """Validate if a named action is feasible without executing.

        Args:
            action_name: Name of a configured action.
            target: Target in any form accepted by ``_resolve_target``.
            **kwargs: Forwarded to the action's ``validate``.

        Returns:
            False (after a warning) when the action is not configured on this
            engine, otherwise the action's own validation result.
        """
        if action_name not in self._actions:
            logger.log_warning(f"Action '{action_name}' is not registered.")
            return False

        action = self._actions[action_name]
        target = self._resolve_target(target)
        return action.validate(target, **kwargs)

    def _resolve_target(
        self,
        target: Union[torch.Tensor, str, ObjectSemantics, Dict[str, Any]],
    ) -> Union[torch.Tensor, ObjectSemantics]:
        """Resolve user target input into tensor pose or ObjectSemantics.

        Supports the convenience dict format in ``execute_static`` and
        ``validate``. A dict is resolved by the first matching rule:

        * ``"pose"``: returned as-is (must be a tensor);
        * ``"semantics"``: returned as-is (must be ObjectSemantics);
        * ``"label"``: analyzed with the optional ``"geometry"``,
          ``"custom_config"`` and ``"use_cache"`` entries, then optional
          ``"properties"`` / ``"uid"`` overrides are applied.

        Raises:
            TypeError: If the target, or a typed dict entry, has the wrong type.
            ValueError: If a dict target has none of the recognised keys.
        """
        if isinstance(target, torch.Tensor):
            return target

        if isinstance(target, ObjectSemantics):
            return target

        if isinstance(target, str):
            return self._semantic_analyzer.analyze(target)

        if isinstance(target, dict):
            if "pose" in target:
                pose = target["pose"]
                if not isinstance(pose, torch.Tensor):
                    raise TypeError("target['pose'] must be a torch.Tensor")
                return pose

            if "semantics" in target:
                semantics = target["semantics"]
                if not isinstance(semantics, ObjectSemantics):
                    raise TypeError(
                        "target['semantics'] must be an ObjectSemantics instance"
                    )
                return semantics

            label = target.get("label")
            if label is None:
                raise ValueError(
                    "Dict target must provide 'label', or use 'pose'/'semantics'."
                )
            if not isinstance(label, str):
                raise TypeError("target['label'] must be a string")

            semantics = self._semantic_analyzer.analyze(
                label=label,
                geometry=target.get("geometry"),
                custom_config=target.get("custom_config"),
                use_cache=target.get("use_cache", True),
            )

            properties = target.get("properties")
            uid = target.get("uid")
            if properties is not None or uid is not None:
                import copy

                # The analyzer may have returned its cached instance. Apply
                # per-target overrides to a shallow copy so they do not leak
                # into later lookups of the same label.
                semantics = copy.copy(semantics)
                if properties is not None:
                    merged_properties = dict(semantics.properties)
                    merged_properties.update(properties)
                    semantics.properties = merged_properties
                if uid is not None:
                    semantics.uid = uid

            return semantics

        raise TypeError(
            "target must be torch.Tensor, str, ObjectSemantics, or Dict[str, Any]"
        )

    def get_semantic_analyzer(self) -> SemanticAnalyzer:
        """Get the semantic analyzer for object understanding."""
        return self._semantic_analyzer

    def set_semantic_analyzer(self, analyzer: SemanticAnalyzer) -> None:
        """Set a custom semantic analyzer."""
        self._semantic_analyzer = analyzer