NeuralIKSolver
Experimental
NeuralIKSolver is an experimental feature. The API, checkpoint format,
and default parameters may change without a deprecation cycle. It is currently
only validated on the Franka Panda robot.
NeuralIKSolver is a learning-based inverse kinematics (IK) solver that uses a
trained neural network policy to iteratively solve IK queries. It requires a
pre-trained checkpoint and supports batch processing.
Key Features
Iterative neural policy inference for IK solving
Batch processing for multiple target poses simultaneously
Multi-seed sampling: generate several random initial guesses and return the best solution
Joint limit enforcement at every iteration
PyTorch-based, with support for both CPU and CUDA devices
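The features above (iterative updates, multi-seed sampling, joint limit enforcement at every iteration, and batching) can be illustrated with a small toy sketch. This is not the library's implementation: the trained network is replaced by a hand-written Jacobian-transpose step on a 2-link planar arm, and the names `fk_planar`, `standin_policy`, and `solve_batch` are hypothetical. Only the control flow is meant to mirror the description.

```python
import math
import torch


def fk_planar(q, l1=1.0, l2=1.0):
    """Tip position (x, y) of a toy 2-link planar arm; q has shape (..., 2)."""
    a, ab = q[..., 0], q[..., 0] + q[..., 1]
    x = l1 * torch.cos(a) + l2 * torch.cos(ab)
    y = l1 * torch.sin(a) + l2 * torch.sin(ab)
    return torch.stack([x, y], dim=-1)


def standin_policy(q, target, l1=1.0, l2=1.0):
    """Stand-in for the trained network (assumption): a Jacobian-transpose step."""
    a, ab = q[..., 0], q[..., 0] + q[..., 1]
    err = target - fk_planar(q, l1, l2)
    ex, ey = err[..., 0], err[..., 1]
    # Rows of J^T applied to the position error.
    dq0 = (-l1 * torch.sin(a) - l2 * torch.sin(ab)) * ex \
        + (l1 * torch.cos(a) + l2 * torch.cos(ab)) * ey
    dq1 = (-l2 * torch.sin(ab)) * ex + (l2 * torch.cos(ab)) * ey
    return torch.stack([dq0, dq1], dim=-1)


def solve_batch(targets, lower, upper, num_samples=16, max_steps=500,
                action_scale=0.1, pos_eps=0.01, seed=0):
    """Return (success, qpos) with shapes (B,) and (B, dof)."""
    gen = torch.Generator().manual_seed(seed)
    batch, dof = targets.shape[0], lower.shape[0]
    # Multi-seed sampling: num_samples random initial guesses per target.
    q = lower + (upper - lower) * torch.rand(batch, num_samples, dof, generator=gen)
    tgt = targets[:, None, :].expand(batch, num_samples, targets.shape[-1])
    for _ in range(max_steps):
        # Iterative update: scaled action added to the current joint positions.
        q = q + action_scale * standin_policy(q, tgt)
        # Joint limits are enforced at every iteration.
        q = torch.maximum(torch.minimum(q, upper), lower)
    err = torch.linalg.norm(fk_planar(q) - tgt, dim=-1)  # (B, num_samples)
    best = err.argmin(dim=1)  # keep the best sample per target
    rows = torch.arange(batch)
    return err[rows, best] < pos_eps, q[rows, best]


lower = torch.tensor([-math.pi, -math.pi])
upper = torch.tensor([math.pi, math.pi])
targets = torch.tensor([[1.0, 1.0], [-0.5, 1.2]])
success, qpos = solve_batch(targets, lower, upper)
```

The toy defaults here (`max_steps=500`, `action_scale=0.1`) are chosen for the stand-in policy and are unrelated to the solver's own defaults.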
Configuration
The solver is configured using the NeuralIKSolverCfg class. Pre-trained
checkpoints are hosted on HuggingFace and can be downloaded with
download_neural_ik_checkpoint(), which requires the HF_TOKEN environment
variable to be set.
from embodichain.data.assets.solver_assets import download_neural_ik_checkpoint
from embodichain.lab.sim.solvers.neural_ik_solver import NeuralIKSolverCfg

checkpoint_path = download_neural_ik_checkpoint()

cfg = NeuralIKSolverCfg(
    checkpoint_path=checkpoint_path,
    num_arm_joints=7,
    max_steps=30,
    action_scale=0.2,
    hidden_dims=[256, 256],
    pos_eps=0.01,
    rot_eps=0.1,
    num_samples=1,
)
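The page does not say how pos_eps and rot_eps are evaluated. One common convention, assumed here, treats pos_eps as a Euclidean distance threshold (meters) and rot_eps as a threshold on the angle of the relative rotation (radians). The helpers `pose_error` and `within_tolerance` below are hypothetical and only sketch that convention; the solver's actual success test may differ.

```python
import math
import torch


def pose_error(current, target):
    """Position and rotation error for batches of 4x4 poses, shape (B, 4, 4)."""
    pos_err = torch.linalg.norm(current[:, :3, 3] - target[:, :3, 3], dim=-1)
    # Relative rotation R_cur^T @ R_tgt; its angle is acos((trace - 1) / 2).
    r_rel = current[:, :3, :3].transpose(-1, -2) @ target[:, :3, :3]
    trace = r_rel.diagonal(dim1=-2, dim2=-1).sum(-1)
    rot_err = torch.acos(torch.clamp((trace - 1.0) / 2.0, -1.0, 1.0))
    return pos_err, rot_err


def within_tolerance(current, target, pos_eps=0.01, rot_eps=0.1):
    """Assumed success test: both errors below their thresholds."""
    pos_err, rot_err = pose_error(current, target)
    return (pos_err < pos_eps) & (rot_err < rot_eps)


target = torch.eye(4).repeat(2, 1, 1)
current = torch.eye(4).repeat(2, 1, 1)
current[0, 0, 3] = 0.005  # 5 mm position offset, no rotation
c, s = math.cos(0.2), math.sin(0.2)
current[1, :3, :3] = torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])
ok = within_tolerance(current, target)
```

Under this assumed convention the first pose passes (5 mm offset) and the second fails (0.2 rad rotation exceeds rot_eps=0.1).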
Main Methods
get_ik(self, target_xpos, qpos_seed=None, num_samples=None, **kwargs)
Solve IK for the given target end-effector pose(s).
Parameters:
target_xpos (torch.Tensor): Target pose(s) as 4x4 matrix, shape (4, 4) or (B, 4, 4).
qpos_seed (torch.Tensor, optional): Initial joint positions, shape (dof,) or (B, dof).
num_samples (int, optional): Override cfg.num_samples for this call.
return_all_solutions (bool): If True, return all sampled solutions with shape (B, num_samples, dof).
Returns:
Tuple[torch.Tensor, torch.Tensor]:
Success flags, shape (B,).
Joint positions, shape (B, 1, dof), or (B, num_samples, dof) when return_all_solutions is True.
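A short sketch of building target_xpos tensors in the two accepted shapes may help. The helper `make_pose` is hypothetical, the rotation is restricted to a yaw about z for brevity, and the positions are illustrative values that have not been checked for reachability on any robot.

```python
import math
import torch


def make_pose(position, yaw=0.0):
    """Homogeneous 4x4 pose from an xyz position and a rotation about z (radians)."""
    pose = torch.eye(4)
    c, s = math.cos(yaw), math.sin(yaw)
    pose[:3, :3] = torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])
    pose[:3, 3] = torch.tensor(position, dtype=torch.float32)
    return pose


# Single target, shape (4, 4).
single = make_pose([0.5, 0.0, 0.4])

# Batch of targets, shape (B, 4, 4), built by stacking single poses.
batch = torch.stack([
    make_pose([0.5, 0.0, 0.4]),
    make_pose([0.4, 0.2, 0.5], yaw=math.pi / 2),
])
```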
Example:
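import torch

# Assumes `solver` is an already constructed NeuralIKSolver.
target_pose = torch.eye(4)  # placeholder (4, 4) pose; replace with a reachable target
qpos_seed = torch.tensor([0.0, -0.785, 0.0, -2.356, 0.0, 1.571, 0.785])  # illustrative seed, shape (dof,)
success, ik_qpos = solver.get_ik(target_xpos=target_pose, qpos_seed=qpos_seed)
print("Success:", success)
print("IK solution:", ik_qpos)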
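The returned tensors can be post-processed with ordinary indexing. The sketch below uses stand-in tensors with the documented shapes rather than a real solver call, and it assumes the success flags are boolean (or convertible with `.bool()`); the variable names are illustrative.

```python
import torch

# Stand-in results with the documented shapes; real values would come from get_ik.
success = torch.tensor([True, False, True])  # (B,)
ik_qpos = torch.arange(3 * 1 * 7, dtype=torch.float32).reshape(3, 1, 7)  # (B, 1, dof)

mask = success.bool()
best_qpos = ik_qpos[:, 0, :]  # drop the sample axis -> (B, dof)
solved_idx = torch.nonzero(mask).squeeze(-1)  # indices of targets that were solved
solved_qpos = best_qpos[mask]  # (num_solved, dof)

print("Solved targets:", solved_idx.tolist())
print("Solutions shape:", tuple(solved_qpos.shape))
```

Keeping `solved_idx` alongside `solved_qpos` makes it possible to map each solution back to its row in the original batch.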