# Source code for embodichain.agents.rl.train

# ----------------------------------------------------------------------------
# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ----------------------------------------------------------------------------

import argparse
import os
import time
from pathlib import Path

import numpy as np
import torch
import wandb
import json
from torch.utils.tensorboard import SummaryWriter
from copy import deepcopy

from embodichain.agents.rl.models import build_policy, get_registered_policy_names
from embodichain.agents.rl.models import build_mlp_from_cfg
from embodichain.agents.rl.algo import build_algo, get_registered_algo_names
from embodichain.agents.rl.utils import dict_to_tensordict, flatten_dict_observation
from embodichain.agents.rl.utils.trainer import Trainer
from embodichain.utils import logger
from embodichain.lab.gym.envs.tasks.rl import build_env
from embodichain.lab.gym.utils.gym_utils import config_to_cfg, DEFAULT_MANAGER_MODULES
from embodichain.utils.utility import load_json
from embodichain.utils.module_utils import find_function_from_modules
from embodichain.lab.sim import SimulationManagerCfg
from embodichain.lab.gym.envs.managers.cfg import EventCfg


def parse_args():
    """Parse command-line arguments for the RL training entry point."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--config", type=str, required=True, help="Path to JSON config")
    # BooleanOptionalAction exposes both --distributed and --no-distributed;
    # the None default lets the config file decide when neither flag is given.
    cli.add_argument(
        "--distributed",
        action=argparse.BooleanOptionalAction,
        default=None,
        help="Enable or disable multi-GPU distributed training",
    )
    return cli.parse_args()
def train_from_config(config_path: str, distributed: bool | None = None):
    """Run training from a config file path.

    Args:
        config_path: Path to the JSON config file
        distributed: If True, run multi-GPU distributed training. If None, use trainer.distributed from config.

    Raises:
        RuntimeError: If distributed training is requested but torch.distributed
            or CUDA is unavailable.
        ValueError: For invalid LOCAL_RANK, device strings, out-of-range CUDA
            indices, or a policy.action_dim that disagrees with the environment.
    """
    # The config JSON has three top-level sections: trainer, policy, algorithm.
    with open(config_path, "r") as f:
        cfg_json = json.load(f)
    trainer_cfg = cfg_json["trainer"]
    policy_block = cfg_json["policy"]
    algo_block = cfg_json["algorithm"]

    # Resolve distributed flag: CLI override wins, else fall back to config.
    if distributed is None:
        distributed = bool(trainer_cfg.get("distributed", False))

    # Distributed setup (NCCL backend; LOCAL_RANK is expected from the launcher,
    # e.g. torchrun — TODO confirm the intended launcher).
    rank = 0
    world_size = 1
    local_rank = 0
    if distributed:
        if not torch.distributed.is_available():
            raise RuntimeError(
                "Distributed training requested but torch.distributed is not available."
            )
        if not torch.cuda.is_available():
            raise RuntimeError(
                "Distributed training with NCCL backend requires CUDA, "
                "but torch.cuda.is_available() is False."
            )
        local_rank = int(os.environ.get("LOCAL_RANK", 0))
        if local_rank < 0 or local_rank >= torch.cuda.device_count():
            raise ValueError(
                f"LOCAL_RANK {local_rank} is out of range "
                f"(available GPUs: {torch.cuda.device_count()})."
            )
        # Pin this process to its GPU before initializing the process group.
        torch.cuda.set_device(local_rank)
        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group(backend="nccl")
        rank = torch.distributed.get_rank()
        world_size = torch.distributed.get_world_size()

    # Runtime settings (all optional, with defaults).
    exp_name = trainer_cfg.get("exp_name", "generic_exp")
    seed = int(trainer_cfg.get("seed", 1))
    device_str = trainer_cfg.get("device", "cpu")
    if distributed:
        # In distributed mode the device is dictated by LOCAL_RANK,
        # overriding any device named in the config.
        device_str = f"cuda:{local_rank}"
    iterations = int(trainer_cfg.get("iterations", 250))
    # buffer_size falls back to the legacy key rollout_steps.
    buffer_size = int(
        trainer_cfg.get("buffer_size", trainer_cfg.get("rollout_steps", 2048))
    )
    enable_eval = bool(trainer_cfg.get("enable_eval", False))
    eval_freq = int(trainer_cfg.get("eval_freq", 10000))
    save_freq = int(trainer_cfg.get("save_freq", 50000))
    num_eval_episodes = int(trainer_cfg.get("num_eval_episodes", 5))
    headless = bool(trainer_cfg.get("headless", True))
    enable_rt = bool(trainer_cfg.get("enable_rt", False))
    gpu_id = int(trainer_cfg.get("gpu_id", 0))
    num_envs = trainer_cfg.get("num_envs", None)
    wandb_project_name = trainer_cfg.get("wandb_project_name", "embodichain-generic")

    # Device validation: only 'cpu' and 'cuda[:N]' are accepted.
    if not isinstance(device_str, str):
        raise ValueError(
            f"runtime.device must be a string such as 'cpu' or 'cuda:0'. Got: {device_str!r}"
        )
    try:
        device = torch.device(device_str)
    except RuntimeError as exc:
        raise ValueError(
            f"Failed to parse runtime.device='{device_str}': {exc}"
        ) from exc
    if device.type == "cuda":
        if not torch.cuda.is_available():
            raise ValueError(
                "CUDA device requested but torch.cuda.is_available() is False."
            )
        # A bare 'cuda' (no index) resolves to the current device.
        index = (
            device.index if device.index is not None else torch.cuda.current_device()
        )
        device_count = torch.cuda.device_count()
        if index < 0 or index >= device_count:
            raise ValueError(
                f"CUDA device index {index} is out of range (available devices: {device_count})."
            )
        torch.cuda.set_device(index)
        device = torch.device(f"cuda:{index}")
    elif device.type != "cpu":
        raise ValueError(f"Unsupported device type: {device}")
    if rank == 0:
        logger.log_info(f"Device: {device}")
    if distributed and rank == 0:
        logger.log_info(f"Distributed training: world_size={world_size}")

    # Seeds: offset by rank so each worker draws a different stream.
    effective_seed = seed + rank
    np.random.seed(effective_seed)
    torch.manual_seed(effective_seed)
    torch.backends.cudnn.deterministic = True
    if device.type == "cuda":
        torch.cuda.manual_seed_all(effective_seed)

    # Output directories: rank 0 picks the timestamp and broadcasts it so
    # every worker writes under the same run directory.
    if distributed:
        run_stamp = time.strftime("%Y%m%d_%H%M%S") if rank == 0 else None
        run_stamp_list = [run_stamp]
        torch.distributed.broadcast_object_list(run_stamp_list, src=0)
        run_stamp = run_stamp_list[0]
    else:
        run_stamp = time.strftime("%Y%m%d_%H%M%S")
    run_base = os.path.join("outputs", f"{exp_name}_{run_stamp}")
    log_dir = os.path.join(run_base, "logs")
    checkpoint_dir = os.path.join(run_base, "checkpoints")
    if rank == 0:
        os.makedirs(log_dir, exist_ok=True)
        os.makedirs(checkpoint_dir, exist_ok=True)
    # Only rank 0 logs to TensorBoard; other ranks get writer=None.
    writer = SummaryWriter(f"{log_dir}/{exp_name}") if rank == 0 else None

    # Initialize Weights & Biases (optional)
    use_wandb = trainer_cfg.get("use_wandb", False)
    if use_wandb and rank == 0:
        wandb.init(project=wandb_project_name, name=exp_name, config=cfg_json)

    # Load the gym environment config and build its cfg object.
    gym_config_path = Path(trainer_cfg["gym_config"])
    if rank == 0:
        logger.log_info(f"Current working directory: {Path.cwd()}")
    gym_config_data = load_json(str(gym_config_path))
    gym_env_cfg = config_to_cfg(
        gym_config_data, manager_modules=DEFAULT_MANAGER_MODULES
    )
    if num_envs is not None:
        gym_env_cfg.num_envs = int(num_envs)

    # Ensure sim configuration mirrors runtime overrides
    if gym_env_cfg.sim_cfg is None:
        gym_env_cfg.sim_cfg = SimulationManagerCfg()
    if device.type == "cuda":
        gpu_index = device.index
        if gpu_index is None:
            gpu_index = torch.cuda.current_device()
        gym_env_cfg.sim_cfg.sim_device = torch.device(f"cuda:{gpu_index}")
        # NOTE(review): gpu_id set here is overwritten unconditionally a few
        # lines below (local_rank/gpu_id) — presumably intentional; confirm.
        if hasattr(gym_env_cfg.sim_cfg, "gpu_id"):
            gym_env_cfg.sim_cfg.gpu_id = gpu_index
    else:
        gym_env_cfg.sim_cfg.sim_device = torch.device("cpu")
    gym_env_cfg.sim_cfg.headless = headless
    gym_env_cfg.sim_cfg.enable_rt = enable_rt
    gym_env_cfg.sim_cfg.gpu_id = local_rank if distributed else gpu_id
    if rank == 0:
        logger.log_info(
            f"Loaded gym_config from {gym_config_path} (env_id={gym_config_data['id']}, num_envs={gym_env_cfg.num_envs}, headless={gym_env_cfg.sim_cfg.headless}, enable_rt={gym_env_cfg.sim_cfg.enable_rt}, sim_device={gym_env_cfg.sim_cfg.sim_device})"
        )
    env = build_env(gym_config_data["id"], base_env_cfg=gym_env_cfg)

    # Reset once to probe the observation layout; the flattened last-dim size
    # is used as the policy input dimension.
    sample_obs, _ = env.reset()
    sample_obs_td = dict_to_tensordict(sample_obs, device)
    obs_dim = flatten_dict_observation(sample_obs_td).shape[-1]
    flat_obs_space = env.flattened_observation_space

    # Create evaluation environment only if enabled
    eval_env = None
    num_eval_envs = trainer_cfg.get("num_eval_envs", 4)
    if enable_eval and rank == 0:
        eval_gym_env_cfg = deepcopy(gym_env_cfg)
        eval_gym_env_cfg.num_envs = num_eval_envs
        eval_gym_env_cfg.sim_cfg.headless = True
        eval_env = build_env(gym_config_data["id"], base_env_cfg=eval_gym_env_cfg)
        logger.log_info(
            f"Evaluation environment created (num_envs={num_eval_envs}, headless=True)"
        )

    # Determine the action dimension from the env and cross-check the config.
    policy_name = policy_block["name"]
    env_action_dim = (
        env.get_wrapper_attr("action_manager").total_action_dim
        if env.get_wrapper_attr("action_manager") is not None
        else len(env.get_wrapper_attr("active_joint_ids"))
    )
    action_dim = policy_block.get("action_dim", env_action_dim)
    action_dim = int(action_dim)
    if action_dim != env_action_dim:
        raise ValueError(
            f"Configured policy.action_dim={action_dim} does not match env action dim {env_action_dim}."
        )

    # Build Policy via registry (actor/critic must be explicitly defined in JSON when using actor_critic/actor_only)
    if policy_name.lower() == "actor_critic":
        actor_cfg = policy_block.get("actor")
        critic_cfg = policy_block.get("critic")
        if actor_cfg is None or critic_cfg is None:
            raise ValueError(
                "ActorCritic requires 'actor' and 'critic' definitions in JSON (policy.actor / policy.critic)."
            )
        actor = build_mlp_from_cfg(actor_cfg, obs_dim, action_dim)
        # Critic outputs a single scalar value estimate.
        critic = build_mlp_from_cfg(critic_cfg, obs_dim, 1)
        policy = build_policy(
            policy_block,
            flat_obs_space,
            env.action_space,
            device,
            actor=actor,
            critic=critic,
        )
    elif policy_name.lower() == "actor_only":
        actor_cfg = policy_block.get("actor")
        if actor_cfg is None:
            raise ValueError(
                "ActorOnly requires 'actor' definition in JSON (policy.actor)."
            )
        actor = build_mlp_from_cfg(actor_cfg, obs_dim, action_dim)
        policy = build_policy(
            policy_block,
            flat_obs_space,
            env.action_space,
            device,
            actor=actor,
        )
    else:
        # Other registered policies build themselves from the raw spaces.
        policy = build_policy(
            policy_block, env.observation_space, env.action_space, device
        )

    # Build Algorithm via factory
    algo_name = algo_block["name"].lower()
    algo_cfg = algo_block["cfg"]
    algo = build_algo(
        algo_name,
        algo_cfg,
        policy,
        device,
        distributed=distributed,
    )

    # Event configuration: event functions are looked up by name in these
    # manager modules.
    event_modules = [
        "embodichain.lab.gym.envs.managers.randomization",
        "embodichain.lab.gym.envs.managers.record",
        "embodichain.lab.gym.envs.managers.events",
    ]
    events_dict = trainer_cfg.get("events", {})
    train_event_cfg = {}
    eval_event_cfg = {}
    # Parse train events
    for event_name, event_info in events_dict.get("train", {}).items():
        event_func_str = event_info.get("func")
        mode = event_info.get("mode", "interval")
        params = event_info.get("params", {})
        interval_step = event_info.get("interval_step", 1)
        event_func = find_function_from_modules(
            event_func_str, event_modules, raise_if_not_found=True
        )
        train_event_cfg[event_name] = EventCfg(
            func=event_func,
            mode=mode,
            params=params,
            interval_step=interval_step,
        )
    # Parse eval events (only if evaluation is enabled)
    if enable_eval:
        for event_name, event_info in events_dict.get("eval", {}).items():
            event_func_str = event_info.get("func")
            mode = event_info.get("mode", "interval")
            params = event_info.get("params", {})
            interval_step = event_info.get("interval_step", 1)
            event_func = find_function_from_modules(
                event_func_str, event_modules, raise_if_not_found=True
            )
            eval_event_cfg[event_name] = EventCfg(
                func=event_func,
                mode=mode,
                params=params,
                interval_step=interval_step,
            )

    trainer = Trainer(
        policy=policy,
        env=env,
        algorithm=algo,
        buffer_size=buffer_size,
        batch_size=algo_cfg["batch_size"],
        writer=writer,
        eval_freq=eval_freq if enable_eval else 0,  # Disable eval if not enabled
        save_freq=save_freq,
        checkpoint_dir=checkpoint_dir,
        exp_name=exp_name,
        use_wandb=use_wandb,
        eval_env=eval_env,  # None if enable_eval=False
        event_cfg=train_event_cfg,
        eval_event_cfg=eval_event_cfg if (enable_eval and rank == 0) else {},
        num_eval_episodes=num_eval_episodes,
        distributed=distributed,
        rank=rank,
        world_size=world_size,
    )

    if rank == 0:
        logger.log_info("Generic training initialized")
        logger.log_info(f"Task: {type(env).__name__}")
        logger.log_info(
            f"Policy: {policy_name} (available: {get_registered_policy_names()})"
        )
        logger.log_info(
            f"Algorithm: {algo_name} (available: {get_registered_algo_names()})"
        )

    # Total env steps across all workers; each iteration collects
    # buffer_size steps per env.
    total_steps = int(iterations * buffer_size * env.num_envs * world_size)
    if rank == 0:
        logger.log_info(
            f"Total steps: {total_steps} (iterations≈{iterations}, world_size={world_size})"
        )
    try:
        trainer.train(total_steps)
    except KeyboardInterrupt:
        if rank == 0:
            logger.log_info("Training interrupted by user")
    finally:
        # Always persist a final checkpoint and release logging/env resources,
        # even on interrupt or error.
        trainer.save_checkpoint()
        if writer is not None:
            writer.close()
        if use_wandb and rank == 0:
            try:
                wandb.finish()
            except Exception:
                pass  # best-effort: a wandb teardown failure must not mask the run
        # Clean up environments to prevent resource leaks
        try:
            if env is not None:
                env.close()
        except Exception as e:
            if rank == 0:
                logger.log_warning(f"Failed to close training environment: {e}")
        try:
            if eval_env is not None:
                eval_env.close()
        except Exception as e:
            if rank == 0:
                logger.log_warning(f"Failed to close evaluation environment: {e}")
        if distributed and torch.distributed.is_initialized():
            torch.distributed.destroy_process_group()
    if rank == 0:
        logger.log_info("Training finished")
def main():
    """Command-line entry point: parse arguments, then launch training."""
    cli_args = parse_args()
    train_from_config(cli_args.config, distributed=cli_args.distributed)
# Script entry point when run directly (e.g. `python train.py --config ...`).
if __name__ == "__main__":
    main()