Source code for embodichain.lab.sim.solvers.pytorch_solver

# ----------------------------------------------------------------------------
# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ----------------------------------------------------------------------------

import functools
from copy import deepcopy
from dataclasses import MISSING
from typing import Union, Tuple, List, TYPE_CHECKING

import torch

from embodichain.utils import configclass, logger
from embodichain.lab.sim.solvers import SolverCfg, BaseSolver
from embodichain.lab.sim.solvers.qpos_seed_sampler import QposSeedSampler
from embodichain.lab.sim.utility.solver_utils import validate_iteration_params

if TYPE_CHECKING:
    from typing import Self

from embodichain.lab.sim.utility.import_utils import (
    lazy_import_pytorch_kinematics,
)


@configclass
class PytorchSolverCfg(SolverCfg):
    """Configuration for :class:`PytorchSolver`, the ``pytorch_kinematics``-based solver.

    Robot-model settings such as the URDF path, the root/end link names and the
    Tool Center Point (TCP) come from :class:`SolverCfg`; the fields below tune
    the iterative inverse-kinematics search.
    """

    class_type: str = "PytorchSolver"

    # ---- iteration parameters -------------------------------------------
    pos_eps: float = 5e-4
    """Convergence tolerance on the position error."""

    rot_eps: float = 5e-4
    """Convergence tolerance on the rotation error."""

    max_iterations: int = 500
    """Upper limit on the number of solver iterations."""

    dt: float = 0.1
    """Time step of the numerical integration."""

    damp: float = 1e-6
    """Damping factor guarding against numerical instability."""

    is_only_position_constraint: bool = False
    """If True, the solver should constrain only the target position."""

    num_samples: int = 5
    """How many joint seeds to try per IK query.

    More seeds raise the chance of finding a valid solution.
    """

    ik_nearest_weight: list[float] | None = None
    """Weights used when picking the IK solution nearest to the seed.

    When several solutions are valid, these weights control how strongly the
    distance to the seed position counts in the selection.
    """

    def init_solver(
        self, device: torch.device = torch.device("cpu"), **kwargs
    ) -> "PytorchSolver":
        """Create a :class:`PytorchSolver` from this configuration.

        Args:
            device (torch.device): Device the solver runs on. Defaults to CPU.
            **kwargs: Extra keyword arguments forwarded to the solver constructor.

        Returns:
            PytorchSolver: The new solver, with its TCP taken from this config.
        """
        new_solver = PytorchSolver(cfg=self, device=device, **kwargs)
        # The TCP is applied after construction, exactly as configured here.
        tcp = self._get_tcp_as_numpy()
        new_solver.set_tcp(tcp)
        return new_solver
def ensure_pose_shape(func):
    """Decorator normalizing ``target_xpos`` to a batched ``(n, 4, 4)`` tensor.

    The wrapped callable must be a method whose first argument after ``self``
    is ``target_xpos`` and whose instance exposes ``self.device``. The argument
    is converted with :func:`torch.as_tensor` to ``float32`` on that device, and
    a single ``(4, 4)`` pose is promoted to shape ``(1, 4, 4)``.

    ``functools.wraps`` keeps the wrapped method's ``__name__``/``__doc__`` so
    documentation tools and ``help()`` show the real method, not ``wrapper``.

    Raises:
        ValueError: If ``target_xpos`` is not of shape ``(4, 4)`` or ``(n, 4, 4)``.
    """

    @functools.wraps(func)
    def wrapper(self, target_xpos, *args, **kwargs):
        target_xpos = torch.as_tensor(
            target_xpos, device=self.device, dtype=torch.float32
        )
        ndim = target_xpos.dim()
        if ndim == 2:
            if target_xpos.shape != (4, 4):
                raise ValueError("target_xpos must be of shape (4, 4) or (n, 4, 4).")
            # Promote a single pose to a batch of one.
            target_xpos = target_xpos.unsqueeze(0)
        elif ndim == 3:
            if target_xpos.shape[1:] != (4, 4):
                raise ValueError("target_xpos must be of shape (4, 4) or (n, 4, 4).")
        else:
            raise ValueError(
                "target_xpos must be a tensor of shape (4, 4) or (n, 4, 4)."
            )
        return func(self, target_xpos, *args, **kwargs)

    return wrapper
class PytorchSolver(BaseSolver):
    r"""Kinematics solver built on ``pytorch_kinematics``.

    Inverse kinematics is solved with ``pytorch_kinematics.PseudoInverseIK``
    from several joint seeds per target pose (generated by
    :class:`QposSeedSampler`). Candidate solutions are wrapped into the joint
    limits, invalid ones are discarded, and the candidate closest to the
    caller's seed (weighted by ``ik_nearest_weight``) is returned.
    """

    def __init__(
        self,
        cfg: PytorchSolverCfg,
        device: str | torch.device | None = None,
        **kwargs,
    ):
        r"""Initializes the PyTorch kinematics solver.

        Args:
            cfg: The configuration for the solver.
            device: The device to use for the solver (e.g. ``"cpu"`` or
                ``"cuda"``). Forwarded to the base solver.
            **kwargs: Additional keyword arguments passed to the base solver.
        """
        super().__init__(cfg=cfg, device=device, **kwargs)

        self.pk = lazy_import_pytorch_kinematics()

        # Iteration parameters, seeded from the configuration.
        self._pos_eps = cfg.pos_eps
        self._rot_eps = cfg.rot_eps
        self._max_iterations = cfg.max_iterations
        self._dt = cfg.dt
        # NOTE(review): ``damp`` and ``is_only_position_constraint`` are stored
        # and reported but not forwarded to PseudoInverseIK in this class.
        self._damp = cfg.damp
        self._is_only_position_constraint = cfg.is_only_position_constraint
        self._num_samples = cfg.num_samples

        # Joint limits as reported by the kinematic chain; the transpose
        # ``self.lim.T`` is what PseudoInverseIK receives as ``joint_limits``.
        self.lim = torch.tensor(
            self.pk_serial_chain.get_joint_limits(), device=self.device
        )

        self.pik = self._build_pseudo_inverse_ik()
        self.dof = self.pk_serial_chain.n_joints

    def _build_pseudo_inverse_ik(self):
        r"""Create the ``PseudoInverseIK`` instance for the current parameters.

        Shared by :meth:`__init__` and :meth:`set_iteration_params` so both
        construction sites always use identical solver options.
        """
        return self.pk.PseudoInverseIK(
            self.pk_serial_chain,
            pos_tolerance=self._pos_eps,
            rot_tolerance=self._rot_eps,
            joint_limits=self.lim.T,
            early_stopping_any_converged=True,
            max_iterations=self._max_iterations,
            lr=self._dt,
            num_retries=1,
            use_compile=True,
        )

    def get_iteration_params(self) -> dict:
        r"""Returns the current iteration parameters.

        Returns:
            dict: A dictionary containing the current values of:
                - pos_eps (float): Position convergence threshold.
                - rot_eps (float): Rotation convergence threshold.
                - max_iterations (int): Maximum number of iterations.
                - dt (float): Time step size.
                - damp (float): Damping factor.
                - num_samples (int): Number of samples.
                - is_only_position_constraint (bool): Whether the solver should
                  only consider position constraints.
        """
        return {
            "pos_eps": self._pos_eps,
            "rot_eps": self._rot_eps,
            "max_iterations": self._max_iterations,
            "dt": self._dt,
            "damp": self._damp,
            "num_samples": self._num_samples,
            "is_only_position_constraint": self._is_only_position_constraint,
        }

    def set_iteration_params(
        self,
        pos_eps: float = 5e-4,
        rot_eps: float = 5e-4,
        max_iterations: int = 1000,
        dt: float = 0.1,
        damp: float = 1e-6,
        num_samples: int = 30,
        is_only_position_constraint: bool = False,
    ) -> bool:
        r"""Sets the iteration parameters and rebuilds the IK solver.

        Args:
            pos_eps (float): Position convergence threshold, must be positive.
            rot_eps (float): Rotation convergence threshold, must be positive.
            max_iterations (int): Maximum number of iterations, must be positive.
            dt (float): Time step size, must be positive.
            damp (float): Damping factor, must be non-negative.
            num_samples (int): Number of samples, must be positive.
            is_only_position_constraint (bool): Whether the solver should only
                consider position constraints.

        Returns:
            bool: True if all parameters are valid and were applied, False
            otherwise (in which case nothing is changed).
        """
        if not validate_iteration_params(
            pos_eps, rot_eps, max_iterations, dt, damp, num_samples
        ):
            return False

        self._pos_eps = pos_eps
        self._rot_eps = rot_eps
        self._max_iterations = max_iterations
        self._dt = dt
        self._damp = damp
        self._num_samples = num_samples
        self._is_only_position_constraint = is_only_position_constraint

        # The tolerances/iterations are baked into the solver instance.
        self.pik = self._build_pseudo_inverse_ik()
        return True

    def _compute_inverse_kinematics(
        self, target_pose: torch.Tensor, joint_seed: torch.Tensor
    ) -> Tuple[Union[bool, torch.Tensor], torch.Tensor]:
        r"""Runs ``PseudoInverseIK`` once for a batch of target poses.

        Args:
            target_pose (torch.Tensor): Target poses of shape (batch_size, 4, 4).
            joint_seed (torch.Tensor): Initial joint positions, assigned to
                ``self.pik.initial_config``. :meth:`get_ik` passes a
                (batch_size, dof) tensor.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
                - ``result.converged_any`` from the IK solver (expected to hold
                  one convergence flag per pose).
                - Joint solutions of shape (batch_size, dof), taken from the
                  solver's first retry slot. The batch dimension is kept even
                  when batch_size == 1.
        """
        target_pose = target_pose.to(self.device).float()
        joint_seed = joint_seed.to(self.device).float()

        tf = self.pk.Transform3d(
            pos=target_pose[:, :3, 3],
            rot=target_pose[:, :3, :3],
            device=self.device,
        )

        self.pik.initial_config = joint_seed
        result = self.pik.solve(tf)
        # No ``.squeeze(0)``: a 1-D result for batch_size == 1 would break the
        # per-row reductions in ``_qpos_map_to_limits``.
        return result.converged_any, result.solutions[:, 0, :]

    def _qpos_map_to_limits(
        self, qpos: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        r"""Wraps joint positions by multiples of :math:`2\pi` into the joint limits.

        Each value is shifted to the smallest value ``>= lower_qpos_limits``
        congruent to it modulo :math:`2\pi`. If the representative in
        :math:`[-\pi, \pi)` also lies within the limits, that one is preferred
        so solutions stay near zero.

        Args:
            qpos (torch.Tensor): Candidate joint positions, shape (N, dof).

        Returns:
            tuple[torch.Tensor, torch.Tensor]:
                - Bool tensor of shape (N,): True where every joint of the
                  wrapped candidate lies within the joint limits.
                - The wrapped joint positions, shape (N, dof).
        """
        two_pi = 2.0 * torch.pi
        lower = self.lower_qpos_limits
        upper = self.upper_qpos_limits

        k = torch.ceil((lower - qpos) / two_pi)
        qpos_mapped = qpos + k * two_pi
        is_within_limits = (qpos_mapped >= lower) & (qpos_mapped <= upper)

        # Representative in [-pi, pi); use it wherever it is itself valid.
        k_zero = torch.ceil((-torch.pi - qpos) / two_pi)
        qpos_near_zero = qpos + k_zero * two_pi
        near_zero_ok = (qpos_near_zero >= lower) & (qpos_near_zero <= upper)
        qpos_mapped = torch.where(near_zero_ok, qpos_near_zero, qpos_mapped)

        return is_within_limits.all(dim=-1), qpos_mapped

    def _prepare_qpos_seed(self, qpos_seed, n_batch: int) -> torch.Tensor:
        r"""Returns ``qpos_seed`` as a (n_batch, dof) tensor on the solver device.

        ``None`` means the zero configuration; a (dof,) seed is shared by all
        targets. Any other shape is reported via
        ``logger.log_error(..., ValueError)``.
        """
        if qpos_seed is None:
            qpos_seed = torch.zeros(self.dof, device=self.device)
        else:
            qpos_seed = torch.as_tensor(qpos_seed, device=self.device)

        if qpos_seed.shape == (self.dof,):
            return qpos_seed.unsqueeze(0).repeat(n_batch, 1)
        if qpos_seed.shape != (n_batch, self.dof):
            logger.log_error(
                f"Invalid qpos_seed shape {qpos_seed.shape} for batch_size {n_batch} and dof {self.dof}",
                ValueError,
            )
        return qpos_seed

    def _tcp_inverse(self) -> torch.Tensor:
        r"""Returns the inverse of the rigid TCP transform ``self.tcp_xpos`` (float32).

        Uses the rigid-transform identity ``R^-1 = R^T`` and ``t^-1 = -R^T t``.
        The rotation is read from the untouched ``tcp_xpos``: transposing a
        tensor slice onto itself in place aliases memory and yields a
        symmetrised matrix instead of ``R^T``.
        """
        tcp_xpos = torch.as_tensor(
            self.tcp_xpos, device=self.device, dtype=torch.float32
        )
        rot_inv = tcp_xpos[:3, :3].T
        tcp_xpos_inv = tcp_xpos.clone()
        tcp_xpos_inv[:3, :3] = rot_inv
        tcp_xpos_inv[:3, 3] = -rot_inv @ tcp_xpos[:3, 3]
        return tcp_xpos_inv

    @ensure_pose_shape
    def get_ik(
        self,
        target_xpos: torch.Tensor,
        qpos_seed: torch.Tensor | None = None,
        num_samples: int | None = None,
        return_all_solutions: bool = False,
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        r"""Computes inverse kinematics for one or more target poses.

        For every target pose, ``num_samples`` joint seeds are drawn by
        :class:`QposSeedSampler` from ``qpos_seed`` and the joint limits, and
        ``PseudoInverseIK`` is run from each. Converged candidates are wrapped
        into the joint limits. Among the valid ones, the candidate closest to
        ``qpos_seed`` is selected. Closeness is the norm of the difference,
        weighted by ``ik_nearest_weight``.

        Args:
            target_xpos (torch.Tensor): Target TCP poses as homogeneous matrices
                of shape (batch_size, 4, 4) or (4, 4). The ``ensure_pose_shape``
                decorator converts them to a float32 batch on ``self.device``.
            qpos_seed (torch.Tensor | None): Reference joint positions, either
                (dof,) (shared by all targets) or (batch_size, dof). Defaults to
                the zero configuration.
            num_samples (int | None): Number of seeds per target for this call
                only; must be positive. Defaults to the solver's configured value
                (the stored value is not modified).
            return_all_solutions (bool): If True, return every candidate (valid
                or not) instead of the closest valid one.
            **kwargs: Unused; accepted for interface compatibility.

        Returns:
            tuple[torch.Tensor, torch.Tensor]:
                - Bool tensor of shape (batch_size,): True where at least one
                  valid solution was found.
                - Joint positions of shape (batch_size, 1, dof) (closest valid
                  solution), or (batch_size, num_samples, dof) when
                  ``return_all_solutions`` is True. If no target converged at
                  all, the raw IK output is returned in the same shape.

        Raises:
            ValueError: Reported via ``logger.log_error`` when ``qpos_seed`` has
                an invalid shape or ``num_samples`` is not positive.
        """
        if num_samples is None:
            num_samples = self._num_samples
        elif num_samples <= 0:
            logger.log_error(
                f"num_samples must be positive, got {num_samples}", ValueError
            )

        batch_size = target_xpos.shape[0]
        qpos_seed = self._prepare_qpos_seed(qpos_seed, batch_size)

        # Targets are TCP poses; right-multiplying by the inverse TCP offset
        # gives the pose the kinematic chain itself must reach.
        target_xpos = target_xpos @ self._tcp_inverse()

        sampler = QposSeedSampler(
            num_samples=num_samples, dof=self.dof, device=self.device
        )
        random_qpos_seeds = sampler.sample(
            qpos_seed,
            self.lower_qpos_limits,
            self.upper_qpos_limits,
            batch_size,
        )
        target_xpos_repeated = sampler.repeat_target_xpos(target_xpos, num_samples)

        is_ik_success, ik_qpos = self._compute_inverse_kinematics(
            target_xpos_repeated, random_qpos_seeds
        )

        if not is_ik_success.any():
            logger.log_warning("No IK solutions found for any of the target poses.")
            failed_state = is_ik_success.reshape(batch_size, num_samples)[:, 0]
            failed_qpos = ik_qpos.reshape(batch_size, num_samples, self.dof)
            if not return_all_solutions:
                # Same (batch_size, 1, dof) layout as the success path.
                failed_qpos = failed_qpos[:, :1, :]
            return failed_state, failed_qpos

        # Wrap candidates into the joint limits and drop out-of-limit ones.
        is_mask_valid, ik_qpos_mapped = self._qpos_map_to_limits(ik_qpos)
        is_success = torch.logical_and(is_ik_success, is_mask_valid)

        all_is_success = is_success.reshape(batch_size, num_samples)
        all_results = ik_qpos_mapped.reshape(batch_size, num_samples, self.dof)

        if return_all_solutions:
            return all_is_success.any(dim=1), all_results

        weighted_diff = self.ik_nearest_weight * (
            all_results - qpos_seed[:, None, :]
        )
        qpos_seed_dis = torch.norm(weighted_diff, dim=2)
        # Invalid candidates get infinite distance so argmin never selects them
        # while a valid one exists.
        qpos_seed_dis[~all_is_success] = float("inf")

        closest_indices = torch.argmin(qpos_seed_dis, dim=1)
        rows = torch.arange(batch_size, device=closest_indices.device)
        closest_qpos = all_results[rows, closest_indices]
        return all_is_success.any(dim=1), closest_qpos[:, None, :]

    def get_all_fk(self, qpos: torch.Tensor) -> list[torch.Tensor]:
        r"""Gets the forward kinematics of every link from the root to the end link.

        Links are taken in the order returned by
        ``pk_serial_chain.forward_kinematics(..., end_only=False)``. The list is
        sliced from ``root_link_name`` (or the first link) through
        ``end_link_name`` (or the last link), inclusive.

        Args:
            qpos (torch.Tensor): The joint positions.

        Returns:
            list[torch.Tensor]: One (4, 4) homogeneous transform per link.
            NOTE(review): each link's matrix is copied into a (4, 4) tensor,
            which presumably requires a single configuration (FK batch of 1).

        Raises:
            KeyError: If the root or end link name is not in the chain.
            AttributeError: If a link result has no ``get_matrix`` method.
        """
        qpos = torch.as_tensor(qpos).to(self.device)
        link_poses = self.pk_serial_chain.forward_kinematics(qpos, end_only=False)
        link_names = list(link_poses.keys())

        if self.root_link_name is not None:
            try:
                start_index = link_names.index(self.root_link_name)
            except ValueError:
                raise KeyError(
                    f"Root link name '{self.root_link_name}' not found in the kinematic chain"
                ) from None
        else:
            start_index = 0

        if self.end_link_name is not None:
            try:
                end_index = link_names.index(self.end_link_name) + 1
            except ValueError:
                raise KeyError(
                    f"End link name '{self.end_link_name}' not found in the kinematic chain"
                ) from None
        else:
            end_index = len(link_names)

        poses = []
        for link_name in link_names[start_index:end_index]:
            xpos = link_poses[link_name]
            if not hasattr(xpos, "get_matrix"):
                raise AttributeError(
                    f"The result for link '{link_name}' must have 'get_matrix' attributes."
                )
            matrix = xpos.get_matrix()
            xpos_t = torch.eye(4, device=matrix.device)
            xpos_t[:3, 3] = matrix[:, :3, 3]
            xpos_t[:3, :3] = matrix[:, :3, :3]
            poses.append(xpos_t)
        return poses