# Source code for embodichain.agents.rl.train

# ----------------------------------------------------------------------------
# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ----------------------------------------------------------------------------

import argparse
import os
import time
from pathlib import Path

import numpy as np
import torch
import wandb
from torch.utils.tensorboard import SummaryWriter
from copy import deepcopy

from embodichain.agents.rl.models import build_policy, get_registered_policy_names
from embodichain.agents.rl.models import build_mlp_from_cfg
from embodichain.agents.rl.algo import build_algo, get_registered_algo_names
from embodichain.agents.rl.utils import dict_to_tensordict, flatten_dict_observation
from embodichain.agents.rl.utils.trainer import Trainer
from embodichain.utils import logger
from embodichain.lab.gym.envs.tasks.rl import build_env
from embodichain.lab.gym.utils.gym_utils import config_to_cfg, DEFAULT_MANAGER_MODULES
from embodichain.utils.utility import load_config
from embodichain.utils.module_utils import find_function_from_modules
from embodichain.lab.sim import SimulationManagerCfg
from embodichain.lab.sim.cfg import RenderCfg
from embodichain.lab.gym.envs.managers.cfg import EventCfg


[docs] def parse_args(): """Parse command line arguments.""" parser = argparse.ArgumentParser() parser.add_argument( "--config", type=str, required=True, help="Path to training config file (.json, .yaml, or .yml).", ) parser.add_argument( "--distributed", action=argparse.BooleanOptionalAction, default=None, help="Enable or disable multi-GPU distributed training", ) return parser.parse_args()
def _init_distributed() -> tuple[int, int, int]:
    """Join the NCCL process group and return ``(rank, world_size, local_rank)``."""
    if not torch.distributed.is_available():
        raise RuntimeError(
            "Distributed training requested but torch.distributed is not available."
        )
    if not torch.cuda.is_available():
        raise RuntimeError(
            "Distributed training with NCCL backend requires CUDA, "
            "but torch.cuda.is_available() is False."
        )
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    if local_rank < 0 or local_rank >= torch.cuda.device_count():
        raise ValueError(
            f"LOCAL_RANK {local_rank} is out of range "
            f"(available GPUs: {torch.cuda.device_count()})."
        )
    # Select this process's GPU before the process group is created.
    torch.cuda.set_device(local_rank)
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl")
    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()
    return rank, world_size, local_rank


def _resolve_device(device_str) -> torch.device:
    """Validate the ``trainer.device`` value and return an indexed torch device."""
    if not isinstance(device_str, str):
        raise ValueError(
            f"trainer.device must be a string such as 'cpu' or 'cuda:0'. Got: {device_str!r}"
        )
    try:
        device = torch.device(device_str)
    except RuntimeError as exc:
        raise ValueError(
            f"Failed to parse trainer.device='{device_str}': {exc}"
        ) from exc

    if device.type == "cpu":
        return device
    if device.type != "cuda":
        raise ValueError(f"Unsupported device type: {device}")

    if not torch.cuda.is_available():
        raise ValueError(
            "CUDA device requested but torch.cuda.is_available() is False."
        )
    # A bare "cuda" has no index; fall back to the currently selected device.
    index = device.index if device.index is not None else torch.cuda.current_device()
    device_count = torch.cuda.device_count()
    if index < 0 or index >= device_count:
        raise ValueError(
            f"CUDA device index {index} is out of range (available devices: {device_count})."
        )
    torch.cuda.set_device(index)
    return torch.device(f"cuda:{index}")


def _parse_event_cfgs(events: dict, event_modules: list[str]) -> dict:
    """Turn a ``{name: {func, mode, params, interval_step}}`` mapping into EventCfg objects."""
    parsed = {}
    for event_name, event_info in events.items():
        event_func = find_function_from_modules(
            event_info.get("func"), event_modules, raise_if_not_found=True
        )
        parsed[event_name] = EventCfg(
            func=event_func,
            mode=event_info.get("mode", "interval"),
            params=event_info.get("params", {}),
            interval_step=event_info.get("interval_step", 1),
        )
    return parsed


def _shutdown(writer, use_wandb, env, eval_env, distributed: bool, rank: int) -> None:
    """Close the writer, wandb run, environments and process group after training."""
    if writer is not None:
        writer.close()
    if use_wandb and rank == 0:
        # Best effort: a wandb failure must not block the remaining cleanup,
        # but it should be visible instead of silently swallowed.
        try:
            wandb.finish()
        except Exception as e:
            logger.log_warning(f"Failed to finish wandb run: {e}")

    # Clean up environments to prevent resource leaks
    try:
        if env is not None:
            env.close()
    except Exception as e:
        if rank == 0:
            logger.log_warning(f"Failed to close training environment: {e}")
    try:
        if eval_env is not None:
            eval_env.close()
    except Exception as e:
        if rank == 0:
            logger.log_warning(f"Failed to close evaluation environment: {e}")

    if distributed and torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()
    if rank == 0:
        logger.log_info("Training finished")


def train_from_config(config_path: str, distributed: bool | None = None):
    """Run training from a config file path.

    The config must contain ``trainer``, ``policy`` and ``algorithm``
    sections. Keys read from ``trainer`` (with their defaults): ``distributed``
    (False), ``exp_name`` ("generic_exp"), ``seed`` (1), ``device`` ("cpu"),
    ``iterations`` (250), ``buffer_size`` (falls back to ``rollout_steps``,
    then 2048), ``enable_eval`` (False), ``eval_freq`` (10000), ``save_freq``
    (50000), ``num_eval_episodes`` (5), ``num_eval_envs`` (4), ``headless``
    (True), ``renderer`` ("hybrid"), ``gpu_id`` (0), ``num_envs`` (None),
    ``wandb_project_name`` ("embodichain-generic"), ``use_wandb`` (False),
    ``events`` ({}), and the required ``gym_config`` path.

    Outputs are written under ``outputs/<exp_name>_<timestamp>/`` (``logs`` and
    ``checkpoints``); only rank 0 creates the directories, the TensorBoard
    writer, the wandb run and the evaluation environment.

    Args:
        config_path: Path to the training config file (.json, .yaml, or .yml).
        distributed: If True, run multi-GPU distributed training. If None, use
            trainer.distributed from config.

    Raises:
        RuntimeError: Distributed training was requested but
            ``torch.distributed`` or CUDA is unavailable.
        ValueError: ``LOCAL_RANK`` or the device is invalid, the configured
            ``policy.action_dim`` does not match the environment, or a
            required ``policy.actor`` / ``policy.critic`` block is missing.
    """
    cfg_data = load_config(config_path)
    trainer_cfg = cfg_data["trainer"]
    policy_block = cfg_data["policy"]
    algo_block = cfg_data["algorithm"]

    # Resolve distributed flag
    if distributed is None:
        distributed = bool(trainer_cfg.get("distributed", False))

    # Distributed setup
    rank, world_size, local_rank = _init_distributed() if distributed else (0, 1, 0)

    # Runtime
    exp_name = trainer_cfg.get("exp_name", "generic_exp")
    seed = int(trainer_cfg.get("seed", 1))
    device_str = trainer_cfg.get("device", "cpu")
    gpu_id = int(trainer_cfg.get("gpu_id", 0))
    if distributed:
        # One process per GPU: both the torch device and the simulator GPU
        # follow LOCAL_RANK. Previously the simulator gpu_id kept the config
        # value (default 0), so every rank's simulator targeted the same GPU.
        device_str = f"cuda:{local_rank}"
        gpu_id = local_rank
    iterations = int(trainer_cfg.get("iterations", 250))
    buffer_size = int(
        trainer_cfg.get("buffer_size", trainer_cfg.get("rollout_steps", 2048))
    )
    enable_eval = bool(trainer_cfg.get("enable_eval", False))
    eval_freq = int(trainer_cfg.get("eval_freq", 10000))
    save_freq = int(trainer_cfg.get("save_freq", 50000))
    num_eval_episodes = int(trainer_cfg.get("num_eval_episodes", 5))
    num_eval_envs = int(trainer_cfg.get("num_eval_envs", 4))
    headless = bool(trainer_cfg.get("headless", True))
    renderer = trainer_cfg.get("renderer", "hybrid")
    num_envs = trainer_cfg.get("num_envs", None)
    wandb_project_name = trainer_cfg.get("wandb_project_name", "embodichain-generic")

    # Device
    device = _resolve_device(device_str)
    if rank == 0:
        logger.log_info(f"Device: {device}")
    if distributed and rank == 0:
        logger.log_info(f"Distributed training: world_size={world_size}")

    # Seeds (offset by rank so each process draws a different random stream)
    effective_seed = seed + rank
    np.random.seed(effective_seed)
    torch.manual_seed(effective_seed)
    torch.backends.cudnn.deterministic = True
    if device.type == "cuda":
        torch.cuda.manual_seed_all(effective_seed)

    # Outputs
    if distributed:
        # Rank 0 picks the timestamp and broadcasts it so that all ranks agree
        # on the run directory name.
        run_stamp_list = [time.strftime("%Y%m%d_%H%M%S") if rank == 0 else None]
        torch.distributed.broadcast_object_list(run_stamp_list, src=0)
        run_stamp = run_stamp_list[0]
    else:
        run_stamp = time.strftime("%Y%m%d_%H%M%S")
    run_base = os.path.join("outputs", f"{exp_name}_{run_stamp}")
    log_dir = os.path.join(run_base, "logs")
    checkpoint_dir = os.path.join(run_base, "checkpoints")
    if rank == 0:
        os.makedirs(log_dir, exist_ok=True)
        os.makedirs(checkpoint_dir, exist_ok=True)
    writer = SummaryWriter(f"{log_dir}/{exp_name}") if rank == 0 else None

    # Initialize Weights & Biases (optional)
    use_wandb = trainer_cfg.get("use_wandb", False)
    if use_wandb and rank == 0:
        wandb.init(project=wandb_project_name, name=exp_name, config=cfg_data)

    gym_config_path = Path(trainer_cfg["gym_config"])
    if rank == 0:
        logger.log_info(f"Current working directory: {Path.cwd()}")

    gym_config_data = load_config(str(gym_config_path))
    gym_env_cfg = config_to_cfg(
        gym_config_data, manager_modules=DEFAULT_MANAGER_MODULES
    )
    if num_envs is not None:
        gym_env_cfg.num_envs = int(num_envs)

    # Ensure sim configuration mirrors runtime overrides
    if gym_env_cfg.sim_cfg is None:
        gym_env_cfg.sim_cfg = SimulationManagerCfg()
    if device.type == "cuda":
        # _resolve_device always returns an explicitly indexed CUDA device.
        gym_env_cfg.sim_cfg.sim_device = torch.device(f"cuda:{device.index}")
    else:
        gym_env_cfg.sim_cfg.sim_device = torch.device("cpu")
    gym_env_cfg.sim_cfg.headless = headless
    gym_env_cfg.sim_cfg.render_cfg = RenderCfg(renderer=renderer)
    # NOTE(review): outside distributed mode trainer.gpu_id (default 0) is used
    # as-is, even when trainer.device names another CUDA index - confirm that
    # this is intended.
    gym_env_cfg.sim_cfg.gpu_id = gpu_id

    logger.log_info(
        f"Loaded gym_config from {gym_config_path} (env_id={gym_config_data['id']}, num_envs={gym_env_cfg.num_envs}, headless={gym_env_cfg.sim_cfg.headless}, renderer={gym_env_cfg.sim_cfg.render_cfg.renderer}, sim_device={gym_env_cfg.sim_cfg.sim_device})"
    )

    env = build_env(gym_config_data["id"], base_env_cfg=gym_env_cfg)
    # A sample observation gives the flattened observation width for the MLPs.
    sample_obs, _ = env.reset()
    sample_obs_td = dict_to_tensordict(sample_obs, device)
    obs_dim = flatten_dict_observation(sample_obs_td).shape[-1]
    flat_obs_space = env.flattened_observation_space

    # Create evaluation environment only if enabled
    eval_env = None
    if enable_eval and rank == 0:
        eval_gym_env_cfg = deepcopy(gym_env_cfg)
        eval_gym_env_cfg.num_envs = num_eval_envs
        eval_gym_env_cfg.sim_cfg.headless = True
        eval_env = build_env(gym_config_data["id"], base_env_cfg=eval_gym_env_cfg)
        logger.log_info(
            f"Evaluation environment created (num_envs={num_eval_envs}, headless=True)"
        )

    # Action dimension comes from the action manager when the env has one,
    # otherwise from the number of active joints.
    policy_name = policy_block["name"]
    action_manager = env.get_wrapper_attr("action_manager")
    env_action_dim = (
        action_manager.total_action_dim
        if action_manager is not None
        else len(env.get_wrapper_attr("active_joint_ids"))
    )
    action_dim = int(policy_block.get("action_dim", env_action_dim))
    if action_dim != env_action_dim:
        raise ValueError(
            f"Configured policy.action_dim={action_dim} does not match env action dim {env_action_dim}."
        )

    # Build Policy via registry (actor/critic must be explicitly defined in
    # JSON when using actor_critic/actor_only)
    policy_kind = policy_name.lower()
    if policy_kind == "actor_critic":
        actor_cfg = policy_block.get("actor")
        critic_cfg = policy_block.get("critic")
        if actor_cfg is None or critic_cfg is None:
            raise ValueError(
                "ActorCritic requires 'actor' and 'critic' definitions in JSON (policy.actor / policy.critic)."
            )
        actor = build_mlp_from_cfg(actor_cfg, obs_dim, action_dim)
        critic = build_mlp_from_cfg(critic_cfg, obs_dim, 1)
        policy = build_policy(
            policy_block,
            flat_obs_space,
            env.action_space,
            device,
            actor=actor,
            critic=critic,
        )
    elif policy_kind == "actor_only":
        actor_cfg = policy_block.get("actor")
        if actor_cfg is None:
            raise ValueError(
                "ActorOnly requires 'actor' definition in JSON (policy.actor)."
            )
        actor = build_mlp_from_cfg(actor_cfg, obs_dim, action_dim)
        policy = build_policy(
            policy_block,
            flat_obs_space,
            env.action_space,
            device,
            actor=actor,
        )
    else:
        # Other registered policies receive the unflattened observation space.
        policy = build_policy(
            policy_block, env.observation_space, env.action_space, device
        )

    # Build Algorithm via factory
    algo_name = algo_block["name"].lower()
    algo_cfg = algo_block["cfg"]
    algo = build_algo(
        algo_name,
        algo_cfg,
        policy,
        device,
        distributed=distributed,
    )

    # Build Trainer
    event_modules = [
        "embodichain.lab.gym.envs.managers.randomization",
        "embodichain.lab.gym.envs.managers.record",
        "embodichain.lab.gym.envs.managers.events",
    ]
    events_dict = trainer_cfg.get("events", {})
    train_event_cfg = _parse_event_cfgs(events_dict.get("train", {}), event_modules)
    # Eval events are only parsed when evaluation is enabled.
    eval_event_cfg = (
        _parse_event_cfgs(events_dict.get("eval", {}), event_modules)
        if enable_eval
        else {}
    )

    trainer = Trainer(
        policy=policy,
        env=env,
        algorithm=algo,
        buffer_size=buffer_size,
        batch_size=algo_cfg["batch_size"],
        writer=writer,
        eval_freq=eval_freq if enable_eval else 0,  # Disable eval if not enabled
        save_freq=save_freq,
        checkpoint_dir=checkpoint_dir,
        exp_name=exp_name,
        use_wandb=use_wandb,
        eval_env=eval_env,  # None if enable_eval=False
        event_cfg=train_event_cfg,
        eval_event_cfg=eval_event_cfg if (enable_eval and rank == 0) else {},
        num_eval_episodes=num_eval_episodes,
        distributed=distributed,
        rank=rank,
        world_size=world_size,
    )

    if rank == 0:
        logger.log_info("Generic training initialized")
        logger.log_info(f"Task: {type(env).__name__}")
        logger.log_info(
            f"Policy: {policy_name} (available: {get_registered_policy_names()})"
        )
        logger.log_info(
            f"Algorithm: {algo_name} (available: {get_registered_algo_names()})"
        )

    total_steps = int(iterations * buffer_size * env.num_envs * world_size)
    if rank == 0:
        logger.log_info(
            f"Total steps: {total_steps} (iterations≈{iterations}, world_size={world_size})"
        )

    try:
        trainer.train(total_steps)
    except KeyboardInterrupt:
        if rank == 0:
            logger.log_info("Training interrupted by user")
    finally:
        # Always attempt a final checkpoint; the nested finally guarantees the
        # resources are released even if saving the checkpoint itself fails.
        try:
            trainer.save_checkpoint()
        finally:
            _shutdown(writer, use_wandb, env, eval_env, distributed, rank)
def cli() -> None:
    """Entry point for launching RL training from the command line.

    Reads the command-line options with ``parse_args`` and hands the config
    path and the distributed flag over to ``train_from_config``.
    """
    options = parse_args()
    config_file = options.config
    use_distributed = options.distributed
    train_from_config(config_file, distributed=use_distributed)
# Script entry point: launch the training CLI only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    cli()