Reinforcement Learning Training#

This tutorial shows you how to train reinforcement learning agents using EmbodiChain’s RL framework. You’ll learn how to configure training via JSON or YAML, set up environments, policies, and algorithms, and launch training sessions.

Overview#

The RL framework provides a modular, extensible stack for robotics tasks:

  • Trainer: Orchestrates the training loop (calls algorithm for data collection and updates, handles logging/eval/save)

  • Algorithm: Controls data collection process (interacts with environment, fills buffer, computes advantages/returns) and updates the policy (e.g., PPO)

  • Policy: Neural network models implementing a unified interface (get_action/get_value/evaluate_actions)

  • Buffer: On-policy rollout storage and minibatch iterator (managed by algorithm)

  • Env Factory: Build environments from a JSON or YAML config via registry

Architecture#

The framework follows a clean separation of concerns across the five components listed in the Overview: Trainer, Algorithm, Policy, Buffer, and Env Factory.

The core components and their relationships:

  • Trainer → Policy, Env, Algorithm (via callbacks for statistics)

  • Algorithm → Policy, RolloutBuffer (algorithm manages its own buffer)

Configuration via JSON or YAML#

Training is configured via a JSON or YAML file that defines runtime settings, environment, policy, and algorithm parameters. EmbodiChain loads either format with load_config(); the nested trainer.gym_config path supports the same extensions.
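
As a rough illustration of extension-based loading, the sketch below picks a parser from the file suffix and reports which top-level blocks are missing. It is not EmbodiChain's load_config() implementation: the helper names load_config_sketch and required_blocks are hypothetical, and the YAML branch assumes PyYAML is installed.

```python
import json
from pathlib import Path


def load_config_sketch(path: str) -> dict:
    """Hypothetical stand-in for load_config(): choose a parser from the suffix."""
    suffix = Path(path).suffix.lower()
    if suffix not in (".json", ".yaml", ".yml"):
        raise ValueError(f"Unsupported config extension: {suffix!r}")
    text = Path(path).read_text(encoding="utf-8")
    if suffix == ".json":
        return json.loads(text)
    import yaml  # assumption: PyYAML is available; imported lazily for YAML only

    return yaml.safe_load(text)


def required_blocks(cfg: dict) -> list[str]:
    """Return the top-level blocks that train.py reads but the config lacks."""
    return [key for key in ("trainer", "policy", "algorithm") if key not in cfg]
```

Checking for missing blocks up front gives a clearer error than the KeyError that a direct cfg["policy"] lookup would raise later.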

Example Configuration#

The configuration file (e.g., train_config.json or train_config.yaml) is located in configs/agents/rl/push_cube or configs/agents/rl/basic/cart_pole:

Example: train_config.json
{
    "trainer": {
        "exp_name": "push_cube_ppo",
        "gym_config": "configs/agents/rl/push_cube/gym_config.json",
        "seed": 42,
        "device": "cuda:0",
        "headless": true,
        "gpu_id": 0,
        "num_envs": 64,
        "iterations": 1000,
        "buffer_size": 1024,
        "enable_eval": true,
        "num_eval_envs": 16,
        "num_eval_episodes": 3,
        "eval_freq": 100,
        "save_freq": 100,
        "use_wandb": true,
        "wandb_project_name": "embodichain-push_cube",
        "events": {
            "eval": {
                "record_camera": {
                    "func": "record_camera_data_async",
                    "mode": "interval",
                    "interval_step": 1,
                    "params": {
                        "name": "main_cam",
                        "resolution": [640, 480],
                        "eye": [-1.4, 1.4, 2.0],
                        "target": [0, 0, 0],
                        "up": [0, 0, 1],
                        "intrinsics": [600, 600, 320, 240],
                        "save_path": "./outputs/videos_ppo1/eval"
                    }
                }
            }
        },
        "renderer": "hybrid"
    },
    "policy": {
        "name": "actor_critic",
        "actor": {
            "type": "mlp",
            "network_cfg": {
                "hidden_sizes": [
                    256,
                    256
                ],
                "activation": "relu"
            }
        },
        "critic": {
            "type": "mlp",
            "network_cfg": {
                "hidden_sizes": [
                    256,
                    256
                ],
                "activation": "relu"
            }
        }
    },
    "algorithm": {
        "name": "ppo",
        "cfg": {
            "learning_rate": 0.0001,
            "n_epochs": 10,
            "batch_size": 8192,
            "gamma": 0.99,
            "gae_lambda": 0.95,
            "clip_coef": 0.2,
            "ent_coef": 0.01,
            "vf_coef": 0.5,
            "max_grad_norm": 0.5
        }
    }
}
Example: train_config.yaml (CartPole)
trainer:
  exp_name: cart_pole_ppo
  gym_config: configs/agents/rl/basic/cart_pole/gym_config.yaml
  seed: 42
  device: cuda:0
  headless: true
  num_envs: 64
  iterations: 1000
  buffer_size: 1024
  eval_freq: 200
  save_freq: 200
  use_wandb: false
  wandb_project_name: embodichain-cart_pole
  events:
    eval:
      record_camera:
        func: record_camera_data_async
        mode: interval
        interval_step: 1
        params:
          name: main_cam
          resolution:
          - 640
          - 480
          eye:
          - -1.4
          - 1.4
          - 2.5
          target:
          - 0
          - 0
          - 0.7
          up:
          - 0
          - 0
          - 1
          intrinsics:
          - 600
          - 600
          - 320
          - 240
          save_path: ./outputs/videos/eval
  renderer: fast-rt
policy:
  name: actor_critic
  actor:
    type: mlp
    network_cfg:
      hidden_sizes:
      - 256
      - 256
      activation: relu
  critic:
    type: mlp
    network_cfg:
      hidden_sizes:
      - 256
      - 256
      activation: relu
algorithm:
  name: ppo
  cfg:
    learning_rate: 0.0001
    n_epochs: 10
    batch_size: 8192
    gamma: 0.99
    gae_lambda: 0.95
    clip_coef: 0.2
    ent_coef: 0.01
    vf_coef: 0.5
    max_grad_norm: 0.5

Configuration Sections#

Runtime Settings#

The trainer section controls experiment setup:

  • exp_name: Experiment name (used for output directories)

  • seed: Random seed for reproducibility

  • device: Runtime device string, e.g. "cpu" or "cuda:0"

  • headless: Whether to run simulation in headless mode

  • iterations: Number of training iterations

  • buffer_size: Steps collected per rollout (e.g., 1024)

  • eval_freq: Frequency of evaluation (in steps); only takes effect when enable_eval is true, which is not the default

  • save_freq: Frequency of checkpoint saving (in steps)

  • use_wandb: Whether to enable Weights & Biases logging (set in the config file)

  • wandb_project_name: Weights & Biases project name
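
To make the size of a run concrete, the sketch below works through the arithmetic for the example settings (iterations=1000, buffer_size=1024, num_envs=64). The total-step product mirrors the one computed in train.py; the minibatch count is only an assumption that each epoch splits the whole rollout into batch_size-sized chunks, and both function names are hypothetical.

```python
def training_budget(iterations, buffer_size, num_envs, world_size=1):
    """Total environment steps, following the product used in train.py."""
    return iterations * buffer_size * num_envs * world_size


def minibatches_per_epoch(buffer_size, num_envs, batch_size):
    """Assumption: one epoch splits the whole rollout into batch_size chunks."""
    transitions = buffer_size * num_envs
    return -(-transitions // batch_size)  # ceiling division


total = training_budget(iterations=1000, buffer_size=1024, num_envs=64)
chunks = minibatches_per_epoch(buffer_size=1024, num_envs=64, batch_size=8192)
print(total, chunks)  # 65536000 8
```

Because eval_freq and save_freq are expressed in steps, comparing them against this total is a quick way to estimate how many evaluations and checkpoints a run will produce.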

Environment Configuration#

The env section defines the task environment:

  • id: Environment registry ID (e.g., “PushCubeRL”)

  • cfg: Environment-specific configuration parameters

For RL environments, use the actions field for action preprocessing and extensions for task-specific parameters:

  • actions: Action Manager config (e.g., DeltaQposTerm with scale)

  • extensions: Task-specific parameters (e.g., success_threshold)

Example:

"env": {
  "id": "PushCubeRL",
  "cfg": {
    "num_envs": 4,
    "actions": {
      "delta_qpos": {
        "func": "DeltaQposTerm",
        "params": { "scale": 0.1 }
      }
    },
    "extensions": {
      "success_threshold": 0.1
    }
  }
}
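
The effect of scale can be pictured with a small sketch. It assumes, based only on the term's name, that a delta-qpos action term adds the scaled action to the current joint positions; the function apply_delta_qpos is hypothetical and is not the DeltaQposTerm implementation.

```python
def apply_delta_qpos(current_qpos, action, scale=0.1):
    """Assumed behaviour: target = current joint position + scale * action."""
    if len(current_qpos) != len(action):
        raise ValueError("action and joint position vectors must have equal length")
    return [q + scale * a for q, a in zip(current_qpos, action)]


# A policy output in [-1, 1] moves each joint by at most `scale` per step.
target_qpos = apply_delta_qpos([0.0, 0.5], [1.0, -1.0], scale=0.1)
```

Under this assumption, a smaller scale yields smoother but slower motion, which is why it is exposed as a config parameter rather than hard-coded.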

Policy Configuration#

The policy section defines the neural network policy:

  • name: Policy name (e.g., “actor_critic”, “vla”)

  • action_dim: Optional policy output action dimension. If omitted, it is inferred from env.action_space.

  • actor: Actor network configuration (required for actor_critic)

  • critic: Critic network configuration (required for actor_critic)

Example:

"policy": {
  "name": "actor_critic",
  "actor": {
    "type": "mlp",
    "network_cfg": {
      "hidden_sizes": [256, 256],
      "activation": "relu"
    }
  },
  "critic": {
    "type": "mlp",
    "network_cfg": {
      "hidden_sizes": [256, 256],
      "activation": "relu"
    }
  }
}
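
To see what hidden_sizes implies for network size, the sketch below lists the layer shapes and parameter counts of a plain fully connected network. It assumes the MLP is a stack of linear layers with biases, which the config suggests but does not confirm; the observation and action dimensions (8 and 2) and both function names are made up for illustration.

```python
def mlp_layer_shapes(input_dim, hidden_sizes, output_dim):
    """Return (in_features, out_features) for each linear layer."""
    dims = [input_dim, *hidden_sizes, output_dim]
    return list(zip(dims[:-1], dims[1:]))


def mlp_param_count(input_dim, hidden_sizes, output_dim):
    """Count weights plus biases, assuming plain linear layers."""
    shapes = mlp_layer_shapes(input_dim, hidden_sizes, output_dim)
    return sum(n_in * n_out + n_out for n_in, n_out in shapes)


obs_dim, action_dim = 8, 2  # hypothetical dimensions
actor_params = mlp_param_count(obs_dim, [256, 256], action_dim)
critic_params = mlp_param_count(obs_dim, [256, 256], 1)  # critic outputs a scalar
print(actor_params, critic_params)  # 68610 68353
```

The actor maps observations to action_dim outputs while the critic maps them to a single value, matching the output sizes passed to build_mlp_from_cfg in train.py.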

Algorithm Configuration#

The algorithm section defines the RL algorithm:

  • name: Algorithm name (e.g., “ppo”, “grpo”)

  • cfg: Algorithm-specific hyperparameters

PPO example:

"algorithm": {
  "name": "ppo",
  "cfg": {
    "learning_rate": 0.0001,
    "n_epochs": 10,
    "batch_size": 64,
    "gamma": 0.99,
    "gae_lambda": 0.95,
    "clip_coef": 0.2,
    "ent_coef": 0.01,
    "vf_coef": 0.5,
    "max_grad_norm": 0.5
  }
}
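
The role of clip_coef, vf_coef, and ent_coef can be shown with the standard PPO clipped surrogate. This is a minimal pure-Python sketch of the textbook objective, not EmbodiChain's PPO code; the function names are hypothetical and the way the three terms are combined is the conventional form, assumed here.

```python
def ppo_clip_loss(ratios, advantages, clip_coef=0.2):
    """Mean clipped surrogate loss over (probability ratio, advantage) pairs."""
    terms = []
    for ratio, adv in zip(ratios, advantages):
        clipped = min(max(ratio, 1.0 - clip_coef), 1.0 + clip_coef)
        # Take the pessimistic (smaller) objective, then negate to get a loss.
        terms.append(-min(ratio * adv, clipped * adv))
    return sum(terms) / len(terms)


def ppo_total_loss(policy_loss, value_loss, entropy, vf_coef=0.5, ent_coef=0.01):
    """Conventional combination: policy + vf_coef * value - ent_coef * entropy."""
    return policy_loss + vf_coef * value_loss - ent_coef * entropy


# A ratio of 1.5 with a positive advantage is capped at 1 + clip_coef = 1.2.
capped = ppo_clip_loss([1.5], [1.0], clip_coef=0.2)
```

Clipping removes the incentive to move the new policy far from the one that collected the rollout, which is what allows n_epochs passes over the same data.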

GRPO example (for Embodied AI / from-scratch training, e.g. CartPole):

"algorithm": {
  "name": "grpo",
  "cfg": {
    "learning_rate": 0.0001,
    "n_epochs": 10,
    "batch_size": 8192,
    "gamma": 0.99,
    "clip_coef": 0.2,
    "ent_coef": 0.001,
    "kl_coef": 0,
    "group_size": 4,
    "eps": 1e-8,
    "reset_every_rollout": true,
    "max_grad_norm": 0.5,
    "truncate_at_first_done": true
  }
}

For GRPO, use the actor_only policy. Set kl_coef to 0 for from-scratch training and to 0.02 for VLA/LLM fine-tuning.
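
The group_size and eps parameters are easiest to understand from the group-relative normalization itself. The sketch below shows the basic idea for scalar returns, assuming that consecutive entries form a group; it omits the step-wise returns and masking that the built-in GRPO applies, and the function name is hypothetical.

```python
import math


def group_relative_advantages(returns, group_size=4, eps=1e-8):
    """Normalize each return against the mean and std of its own group."""
    if len(returns) % group_size != 0:
        raise ValueError("number of returns must be a multiple of group_size")
    advantages = []
    for start in range(0, len(returns), group_size):
        group = returns[start:start + group_size]
        mean = sum(group) / group_size
        std = math.sqrt(sum((g - mean) ** 2 for g in group) / group_size)
        # eps keeps the division finite when every return in the group is equal.
        advantages.extend((g - mean) / (std + eps) for g in group)
    return advantages


group_adv = group_relative_advantages([1.0, 2.0, 3.0, 4.0], group_size=4)
```

Since the baseline comes from the group mean rather than a learned value function, no critic network is needed, which is why GRPO pairs with an actor-only policy.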

Training Script#

The training script (train.py) is located in embodichain/agents/rl/:

Code for train.py
# ----------------------------------------------------------------------------
# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ----------------------------------------------------------------------------

import argparse
import os
import time
from pathlib import Path

import numpy as np
import torch
import wandb
from torch.utils.tensorboard import SummaryWriter
from copy import deepcopy

from embodichain.agents.rl.models import build_policy, get_registered_policy_names
from embodichain.agents.rl.models import build_mlp_from_cfg
from embodichain.agents.rl.algo import build_algo, get_registered_algo_names
from embodichain.agents.rl.utils import dict_to_tensordict, flatten_dict_observation
from embodichain.agents.rl.utils.trainer import Trainer
from embodichain.utils import logger
from embodichain.lab.gym.envs.tasks.rl import build_env
from embodichain.lab.gym.utils.gym_utils import config_to_cfg, DEFAULT_MANAGER_MODULES
from embodichain.utils.utility import load_config
from embodichain.utils.module_utils import find_function_from_modules
from embodichain.lab.sim import SimulationManagerCfg
from embodichain.lab.sim.cfg import RenderCfg
from embodichain.lab.gym.envs.managers.cfg import EventCfg


def parse_args():
    """Parse command line arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="Path to training config file (.json, .yaml, or .yml).",
    )
    parser.add_argument(
        "--distributed",
        action=argparse.BooleanOptionalAction,
        default=None,
        help="Enable or disable multi-GPU distributed training",
    )
    return parser.parse_args()


def train_from_config(config_path: str, distributed: bool | None = None):
    """Run training from a config file path.

    Args:
        config_path: Path to the training config file (.json, .yaml, or .yml).
        distributed: If True, run multi-GPU distributed training.
            If None, use trainer.distributed from config.
    """
    cfg_data = load_config(config_path)

    trainer_cfg = cfg_data["trainer"]
    policy_block = cfg_data["policy"]
    algo_block = cfg_data["algorithm"]

    # Resolve distributed flag
    if distributed is None:
        distributed = bool(trainer_cfg.get("distributed", False))

    # Distributed setup
    rank = 0
    world_size = 1
    local_rank = 0
    if distributed:
        if not torch.distributed.is_available():
            raise RuntimeError(
                "Distributed training requested but torch.distributed is not available."
            )
        if not torch.cuda.is_available():
            raise RuntimeError(
                "Distributed training with NCCL backend requires CUDA, "
                "but torch.cuda.is_available() is False."
            )
        local_rank = int(os.environ.get("LOCAL_RANK", 0))
        if local_rank < 0 or local_rank >= torch.cuda.device_count():
            raise ValueError(
                f"LOCAL_RANK {local_rank} is out of range "
                f"(available GPUs: {torch.cuda.device_count()})."
            )
        torch.cuda.set_device(local_rank)
        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group(backend="nccl")
        rank = torch.distributed.get_rank()
        world_size = torch.distributed.get_world_size()

    # Runtime
    exp_name = trainer_cfg.get("exp_name", "generic_exp")
    seed = int(trainer_cfg.get("seed", 1))
    device_str = trainer_cfg.get("device", "cpu")
    if distributed:
        device_str = f"cuda:{local_rank}"
    iterations = int(trainer_cfg.get("iterations", 250))
    buffer_size = int(
        trainer_cfg.get("buffer_size", trainer_cfg.get("rollout_steps", 2048))
    )
    enable_eval = bool(trainer_cfg.get("enable_eval", False))
    eval_freq = int(trainer_cfg.get("eval_freq", 10000))
    save_freq = int(trainer_cfg.get("save_freq", 50000))
    num_eval_episodes = int(trainer_cfg.get("num_eval_episodes", 5))
    headless = bool(trainer_cfg.get("headless", True))
    renderer = trainer_cfg.get("renderer", "hybrid")
    gpu_id = int(trainer_cfg.get("gpu_id", 0))
    num_envs = trainer_cfg.get("num_envs", None)
    wandb_project_name = trainer_cfg.get("wandb_project_name", "embodichain-generic")

    # Device
    if not isinstance(device_str, str):
        raise ValueError(
            f"runtime.device must be a string such as 'cpu' or 'cuda:0'. Got: {device_str!r}"
        )
    try:
        device = torch.device(device_str)
    except RuntimeError as exc:
        raise ValueError(
            f"Failed to parse runtime.device='{device_str}': {exc}"
        ) from exc

    if device.type == "cuda":
        if not torch.cuda.is_available():
            raise ValueError(
                "CUDA device requested but torch.cuda.is_available() is False."
            )
        index = (
            device.index if device.index is not None else torch.cuda.current_device()
        )
        device_count = torch.cuda.device_count()
        if index < 0 or index >= device_count:
            raise ValueError(
                f"CUDA device index {index} is out of range (available devices: {device_count})."
            )
        torch.cuda.set_device(index)
        device = torch.device(f"cuda:{index}")
    elif device.type != "cpu":
        raise ValueError(f"Unsupported device type: {device}")
    if rank == 0:
        logger.log_info(f"Device: {device}")
    if distributed and rank == 0:
        logger.log_info(f"Distributed training: world_size={world_size}")

    # Seeds
    effective_seed = seed + rank
    np.random.seed(effective_seed)
    torch.manual_seed(effective_seed)
    torch.backends.cudnn.deterministic = True
    if device.type == "cuda":
        torch.cuda.manual_seed_all(effective_seed)

    # Outputs
    if distributed:
        run_stamp = time.strftime("%Y%m%d_%H%M%S") if rank == 0 else None
        run_stamp_list = [run_stamp]
        torch.distributed.broadcast_object_list(run_stamp_list, src=0)
        run_stamp = run_stamp_list[0]
    else:
        run_stamp = time.strftime("%Y%m%d_%H%M%S")
    run_base = os.path.join("outputs", f"{exp_name}_{run_stamp}")
    log_dir = os.path.join(run_base, "logs")
    checkpoint_dir = os.path.join(run_base, "checkpoints")
    if rank == 0:
        os.makedirs(log_dir, exist_ok=True)
        os.makedirs(checkpoint_dir, exist_ok=True)
    writer = SummaryWriter(f"{log_dir}/{exp_name}") if rank == 0 else None

    # Initialize Weights & Biases (optional)
    use_wandb = trainer_cfg.get("use_wandb", False)
    if use_wandb and rank == 0:
        wandb.init(project=wandb_project_name, name=exp_name, config=cfg_data)

    gym_config_path = Path(trainer_cfg["gym_config"])
    if rank == 0:
        logger.log_info(f"Current working directory: {Path.cwd()}")

    gym_config_data = load_config(str(gym_config_path))
    gym_env_cfg = config_to_cfg(
        gym_config_data, manager_modules=DEFAULT_MANAGER_MODULES
    )
    if num_envs is not None:
        gym_env_cfg.num_envs = int(num_envs)

    # Ensure sim configuration mirrors runtime overrides
    if gym_env_cfg.sim_cfg is None:
        gym_env_cfg.sim_cfg = SimulationManagerCfg()
    if device.type == "cuda":
        gpu_index = device.index
        if gpu_index is None:
            gpu_index = torch.cuda.current_device()
        gym_env_cfg.sim_cfg.sim_device = torch.device(f"cuda:{gpu_index}")
        if hasattr(gym_env_cfg.sim_cfg, "gpu_id"):
            gym_env_cfg.sim_cfg.gpu_id = gpu_index
    else:
        gym_env_cfg.sim_cfg.sim_device = torch.device("cpu")
    gym_env_cfg.sim_cfg.headless = headless
    gym_env_cfg.sim_cfg.render_cfg = RenderCfg(renderer=renderer)
    gym_env_cfg.sim_cfg.gpu_id = gpu_id
    logger.log_info(
        f"Loaded gym_config from {gym_config_path} (env_id={gym_config_data['id']}, num_envs={gym_env_cfg.num_envs}, headless={gym_env_cfg.sim_cfg.headless}, renderer={gym_env_cfg.sim_cfg.render_cfg.renderer}, sim_device={gym_env_cfg.sim_cfg.sim_device})"
    )

    env = build_env(gym_config_data["id"], base_env_cfg=gym_env_cfg)
    sample_obs, _ = env.reset()
    sample_obs_td = dict_to_tensordict(sample_obs, device)
    obs_dim = flatten_dict_observation(sample_obs_td).shape[-1]
    flat_obs_space = env.flattened_observation_space

    # Create evaluation environment only if enabled
    eval_env = None
    num_eval_envs = trainer_cfg.get("num_eval_envs", 4)
    if enable_eval and rank == 0:
        eval_gym_env_cfg = deepcopy(gym_env_cfg)
        eval_gym_env_cfg.num_envs = num_eval_envs
        eval_gym_env_cfg.sim_cfg.headless = True
        eval_env = build_env(gym_config_data["id"], base_env_cfg=eval_gym_env_cfg)
        logger.log_info(
            f"Evaluation environment created (num_envs={num_eval_envs}, headless=True)"
        )

    # Build Policy via registry
    policy_name = policy_block["name"]
    env_action_dim = (
        env.get_wrapper_attr("action_manager").total_action_dim
        if env.get_wrapper_attr("action_manager") is not None
        else len(env.get_wrapper_attr("active_joint_ids"))
    )
    action_dim = policy_block.get("action_dim", env_action_dim)
    action_dim = int(action_dim)
    if action_dim != env_action_dim:
        raise ValueError(
            f"Configured policy.action_dim={action_dim} does not match env action dim {env_action_dim}."
        )
    # Build Policy via registry (actor/critic must be explicitly defined in JSON when using actor_critic/actor_only)
    if policy_name.lower() == "actor_critic":
        actor_cfg = policy_block.get("actor")
        critic_cfg = policy_block.get("critic")
        if actor_cfg is None or critic_cfg is None:
            raise ValueError(
                "ActorCritic requires 'actor' and 'critic' definitions in JSON (policy.actor / policy.critic)."
            )

        actor = build_mlp_from_cfg(actor_cfg, obs_dim, action_dim)
        critic = build_mlp_from_cfg(critic_cfg, obs_dim, 1)

        policy = build_policy(
            policy_block,
            flat_obs_space,
            env.action_space,
            device,
            actor=actor,
            critic=critic,
        )
    elif policy_name.lower() == "actor_only":
        actor_cfg = policy_block.get("actor")
        if actor_cfg is None:
            raise ValueError(
                "ActorOnly requires 'actor' definition in JSON (policy.actor)."
            )

        actor = build_mlp_from_cfg(actor_cfg, obs_dim, action_dim)

        policy = build_policy(
            policy_block,
            flat_obs_space,
            env.action_space,
            device,
            actor=actor,
        )
    else:
        policy = build_policy(
            policy_block, env.observation_space, env.action_space, device
        )

    # Build Algorithm via factory
    algo_name = algo_block["name"].lower()
    algo_cfg = algo_block["cfg"]
    algo = build_algo(
        algo_name,
        algo_cfg,
        policy,
        device,
        distributed=distributed,
    )

    # Build Trainer
    event_modules = [
        "embodichain.lab.gym.envs.managers.randomization",
        "embodichain.lab.gym.envs.managers.record",
        "embodichain.lab.gym.envs.managers.events",
    ]
    events_dict = trainer_cfg.get("events", {})
    train_event_cfg = {}
    eval_event_cfg = {}
    # Parse train events
    for event_name, event_info in events_dict.get("train", {}).items():
        event_func_str = event_info.get("func")
        mode = event_info.get("mode", "interval")
        params = event_info.get("params", {})
        interval_step = event_info.get("interval_step", 1)
        event_func = find_function_from_modules(
            event_func_str, event_modules, raise_if_not_found=True
        )
        train_event_cfg[event_name] = EventCfg(
            func=event_func,
            mode=mode,
            params=params,
            interval_step=interval_step,
        )
    # Parse eval events (only if evaluation is enabled)
    if enable_eval:
        for event_name, event_info in events_dict.get("eval", {}).items():
            event_func_str = event_info.get("func")
            mode = event_info.get("mode", "interval")
            params = event_info.get("params", {})
            interval_step = event_info.get("interval_step", 1)
            event_func = find_function_from_modules(
                event_func_str, event_modules, raise_if_not_found=True
            )
            eval_event_cfg[event_name] = EventCfg(
                func=event_func,
                mode=mode,
                params=params,
                interval_step=interval_step,
            )
    trainer = Trainer(
        policy=policy,
        env=env,
        algorithm=algo,
        buffer_size=buffer_size,
        batch_size=algo_cfg["batch_size"],
        writer=writer,
        eval_freq=eval_freq if enable_eval else 0,  # Disable eval if not enabled
        save_freq=save_freq,
        checkpoint_dir=checkpoint_dir,
        exp_name=exp_name,
        use_wandb=use_wandb,
        eval_env=eval_env,  # None if enable_eval=False
        event_cfg=train_event_cfg,
        eval_event_cfg=eval_event_cfg if (enable_eval and rank == 0) else {},
        num_eval_episodes=num_eval_episodes,
        distributed=distributed,
        rank=rank,
        world_size=world_size,
    )

    if rank == 0:
        logger.log_info("Generic training initialized")
        logger.log_info(f"Task: {type(env).__name__}")
        logger.log_info(
            f"Policy: {policy_name} (available: {get_registered_policy_names()})"
        )
        logger.log_info(
            f"Algorithm: {algo_name} (available: {get_registered_algo_names()})"
        )

    total_steps = int(iterations * buffer_size * env.num_envs * world_size)
    if rank == 0:
        logger.log_info(
            f"Total steps: {total_steps} (iterations≈{iterations}, world_size={world_size})"
        )

    try:
        trainer.train(total_steps)
    except KeyboardInterrupt:
        if rank == 0:
            logger.log_info("Training interrupted by user")
    finally:
        trainer.save_checkpoint()
        if writer is not None:
            writer.close()
        if use_wandb and rank == 0:
            try:
                wandb.finish()
            except Exception:
                pass

        # Clean up environments to prevent resource leaks
        try:
            if env is not None:
                env.close()
        except Exception as e:
            if rank == 0:
                logger.log_warning(f"Failed to close training environment: {e}")

        try:
            if eval_env is not None:
                eval_env.close()
        except Exception as e:
            if rank == 0:
                logger.log_warning(f"Failed to close evaluation environment: {e}")

        if distributed and torch.distributed.is_initialized():
            torch.distributed.destroy_process_group()

        if rank == 0:
            logger.log_info("Training finished")


def cli() -> None:
    """Command-line interface for RL training.

    Parses CLI arguments and launches training from a config file.
    """
    args = parse_args()
    train_from_config(args.config, distributed=args.distributed)


if __name__ == "__main__":
    cli()

The Script Explained#

The training script performs the following steps:

  1. Parse Configuration: Loads the config file (.json, .yaml, or .yml) and extracts the trainer, policy, and algorithm blocks; the environment config is loaded separately from the file referenced by trainer.gym_config

  2. Setup: Initializes device, seeds, output directories, TensorBoard, and Weights & Biases

  3. Build Components:

     • Environment via build_env() factory

     • Policy via build_policy() registry

     • Algorithm via build_algo() factory

  4. Create Trainer: Instantiates the Trainer with all components

  5. Train: Runs the training loop until completion

Launching Training#

To start training, run:

python -m embodichain train-rl --config configs/agents/rl/basic/cart_pole/train_config.yaml

JSON configs are also supported:

python -m embodichain train-rl --config configs/agents/rl/push_cube/train_config.json

Outputs#

All outputs are written to ./outputs/<exp_name>_<timestamp>/:

  • logs/: TensorBoard logs

  • checkpoints/: Model checkpoints
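
The directory layout can be reproduced with a few lines of Python, which is handy when a script needs to locate a run's logs or checkpoints. The sketch mirrors the path construction in train.py (timestamp format %Y%m%d_%H%M%S); the helper name run_directories is hypothetical, and the timestamp in the example is a made-up value.

```python
import time
from pathlib import PurePosixPath


def run_directories(exp_name, run_stamp=None):
    """Build the per-run output paths the same way train.py does."""
    if run_stamp is None:
        run_stamp = time.strftime("%Y%m%d_%H%M%S")
    run_base = PurePosixPath("outputs") / f"{exp_name}_{run_stamp}"
    return {
        "logs": str(run_base / "logs"),
        "checkpoints": str(run_base / "checkpoints"),
    }


# Hypothetical timestamp, fixed here so the result is reproducible.
dirs = run_directories("push_cube_ppo", run_stamp="20250101_120000")
print(dirs["logs"])  # outputs/push_cube_ppo_20250101_120000/logs
```

TensorBoard can then be pointed at the logs directory with tensorboard --logdir followed by that path.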

Training Process#

The training process follows this sequence:

  1. Rollout Phase: SyncCollector interacts with the environment and writes policy-side fields into a shared rollout TensorDict with uniform [N, T + 1] layout. EmbodiedEnv writes environment-side step fields such as reward, done, terminated, and truncated into the same rollout via set_rollout_buffer(). The final slot of transition-only fields is reserved as padding, while obs[:, -1] and value[:, -1] remain valid bootstrap data.

  2. Advantage/Return Computation: Algorithm computes advantages and returns from the collected rollout (e.g. GAE for PPO, step-wise group normalization for GRPO) and converts it to a transition-aligned view over the valid first T steps before minibatch optimization.

  3. Update Phase: Algorithm updates the policy with update(rollout)

  4. Logging: Trainer logs training losses and aggregated metrics to TensorBoard and Weights & Biases

  5. Evaluation (periodic): Trainer evaluates the current policy

  6. Checkpointing (periodic): Trainer saves model checkpoints
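
The advantage step for PPO can be sketched with Generalized Advantage Estimation over a single environment. The example keeps the T + 1 layout described above, where the last value entry serves only as the bootstrap; it uses plain Python lists instead of a TensorDict, and the function name gae is hypothetical rather than the framework's own routine.

```python
def gae(rewards, values, dones, gamma=0.99, gae_lambda=0.95):
    """Compute advantages and returns; `values` has one extra bootstrap entry."""
    steps = len(rewards)
    if len(values) != steps + 1:
        raise ValueError("values must hold T + 1 entries (bootstrap value last)")
    advantages = [0.0] * steps
    running = 0.0
    for t in reversed(range(steps)):
        not_done = 1.0 - float(dones[t])
        # TD error; the next value is dropped when the episode ended at step t.
        delta = rewards[t] + gamma * values[t + 1] * not_done - values[t]
        running = delta + gamma * gae_lambda * not_done * running
        advantages[t] = running
    returns = [adv + val for adv, val in zip(advantages, values[:steps])]
    return advantages, returns


adv, ret = gae(
    rewards=[1.0, 1.0, 1.0],
    values=[0.5, 0.5, 0.5, 0.0],  # T + 1 entries
    dones=[False, False, True],
)
```

Only the first T entries yield advantages and returns, which corresponds to the transition-aligned view over the valid steps.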

Policy Interface#

All policies must inherit from the Policy abstract base class:

from abc import ABC, abstractmethod

import torch
import torch.nn as nn

class Policy(nn.Module, ABC):
    device: torch.device

    def get_action(self, tensordict, deterministic: bool = False):
        """Samples action, sample_log_prob, and value into the TensorDict."""
        ...

    @abstractmethod
    def forward(self, tensordict, deterministic: bool = False):
        """Writes action, sample_log_prob, and value into the TensorDict."""
        raise NotImplementedError

    @abstractmethod
    def get_value(self, tensordict):
        """Writes value estimate into the TensorDict."""
        raise NotImplementedError

    @abstractmethod
    def evaluate_actions(self, tensordict):
        """Returns a new TensorDict with log_prob, entropy, and value."""
        raise NotImplementedError

Available Policies#

  • ActorCritic: MLP-based Gaussian policy with learnable log_std. Requires external actor and critic modules to be provided (defined in the training config file). Used with PPO.

  • ActorOnly: Actor-only policy without Critic. Used with GRPO (group-relative advantage estimation).

  • VLAPlaceholderPolicy: Placeholder for Vision-Language-Action policies
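
For intuition about a Gaussian policy with a learnable log_std, the sketch below computes the log-probability and entropy of a diagonal Gaussian in plain Python. These are the standard closed-form expressions, not code taken from ActorCritic, and the function names are hypothetical.

```python
import math

LOG_2PI = math.log(2.0 * math.pi)


def gaussian_log_prob(action, mean, log_std):
    """Sum of per-dimension log densities of a diagonal Gaussian."""
    total = 0.0
    for a, m, ls in zip(action, mean, log_std):
        z = (a - m) / math.exp(ls)
        total += -0.5 * z * z - ls - 0.5 * LOG_2PI
    return total


def gaussian_entropy(log_std):
    """Entropy of a diagonal Gaussian; it depends only on log_std."""
    return sum(0.5 + 0.5 * LOG_2PI + ls for ls in log_std)


log_p = gaussian_log_prob(action=[0.0], mean=[0.0], log_std=[0.0])
entropy = gaussian_entropy([0.0])
```

Because the entropy grows with log_std, an entropy bonus such as ent_coef pushes the policy to keep exploring instead of collapsing its standard deviation early.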

Algorithms#

Available Algorithms#

  • PPO: Proximal Policy Optimization with GAE

  • GRPO: Group Relative Policy Optimization (no Critic, step-wise returns, masked group normalization). Use actor_only policy. Set kl_coef=0 for from-scratch training (CartPole, dense reward); kl_coef=0.02 for VLA/LLM fine-tuning.

Adding a New Algorithm#

To add a new algorithm:

  1. Create a new algorithm class in embodichain/agents/rl/algo/

  2. Implement update(rollout) and consume the shared rollout TensorDict

  3. Register in algo/__init__.py:

import torch
from tensordict import TensorDict
from embodichain.agents.rl.algo import BaseAlgorithm, register_algo

@register_algo("my_algo")
class MyAlgorithm(BaseAlgorithm):
    def __init__(self, cfg, policy):
        self.cfg = cfg
        self.policy = policy
        self.device = torch.device(cfg.device)

    def update(self, rollout: TensorDict):
        """Update the policy using a collected rollout."""
        # compute advantages / returns from rollout
        # optimize policy parameters
        return {"loss": 0.0}
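
The decorator-based registration used above follows a common registry pattern, sketched below in plain Python. This illustrates the pattern only and is not the code in algo/__init__.py: the names register_algo_sketch and build_algo_sketch, the _ALGO_REGISTRY dictionary, and the two-argument constructor are all assumptions.

```python
_ALGO_REGISTRY = {}  # hypothetical name -> class table


def register_algo_sketch(name):
    """Return a class decorator that stores the class under a lowercase name."""
    def decorator(cls):
        key = name.lower()
        if key in _ALGO_REGISTRY:
            raise ValueError(f"Algorithm {key!r} is already registered")
        _ALGO_REGISTRY[key] = cls
        return cls
    return decorator


def build_algo_sketch(name, cfg, policy):
    """Look up a registered class by name and instantiate it."""
    try:
        cls = _ALGO_REGISTRY[name.lower()]
    except KeyError:
        known = sorted(_ALGO_REGISTRY)
        raise ValueError(f"Unknown algorithm {name!r}; registered: {known}") from None
    return cls(cfg, policy)


@register_algo_sketch("my_algo")
class MyAlgorithmSketch:
    def __init__(self, cfg, policy):
        self.cfg = cfg
        self.policy = policy

    def update(self, rollout):
        # A real algorithm would compute losses from the rollout here.
        return {"loss": 0.0}


algo = build_algo_sketch("MY_ALGO", {"learning_rate": 1e-4}, policy=None)
```

With a registry, the config file can select an algorithm by its name string, and adding a new one does not require changes to the training script.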

Adding a New Policy#

To add a new policy:

  1. Create a new policy class inheriting from the Policy abstract base class

  2. Register in models/__init__.py:

from embodichain.agents.rl.models import register_policy, Policy

@register_policy("my_policy")
class MyPolicy(Policy):
    def __init__(self, obs_dim, action_dim, device, config):
        super().__init__()
        self.device = device
        # Initialize your networks here

    def get_action(self, tensordict, deterministic=False):
        ...
    def forward(self, tensordict, deterministic=False):
        ...
    def get_value(self, tensordict):
        ...
    def evaluate_actions(self, tensordict):
        ...

Current built-in MLP policies use flattened observations in the training path. If your policy requires structured or multi-modal inputs, keep the richer obs_space interface and define a matching rollout/collector schema.

Adding a New Environment#

To add a new RL environment:

  1. Create an environment class inheriting from EmbodiedEnv (with Action Manager configured for action preprocessing and standardized info structure):

from embodichain.lab.gym.envs import EmbodiedEnv, EmbodiedEnvCfg
from embodichain.lab.gym.utils.registration import register_env
import torch

@register_env("MyTaskRL", override=True)
class MyTaskEnv(EmbodiedEnv):
    def __init__(self, cfg: EmbodiedEnvCfg = None, **kwargs):
        super().__init__(cfg, **kwargs)

    def compute_task_state(self, **kwargs):
        """Compute success/failure conditions and metrics."""
        is_success = ...  # Define success condition
        is_fail = torch.zeros_like(is_success)
        metrics = {"distance": ..., "error": ...}
        return is_success, is_fail, metrics

  2. Configure the environment in your config file with actions and extensions:

"env": {
  "id": "MyTaskRL",
  "cfg": {
    "num_envs": 4,
    "actions": {
      "delta_qpos": {
        "func": "DeltaQposTerm",
        "params": { "scale": 0.1 }
      }
    },
    "extensions": {
      "success_threshold": 0.05
    }
  }
}

The EmbodiedEnv with Action Manager provides:

  • Action Preprocessing: Configurable via actions (DeltaQposTerm, QposTerm, EefPoseTerm, etc.)

  • Standardized Info: Implements get_info() using compute_task_state() template method

Best Practices#

  • Use EmbodiedEnv with Action Manager for RL Tasks: Inherit from EmbodiedEnv and configure actions in your config. The Action Manager handles action preprocessing (delta_qpos, qpos, qvel, qf, eef_pose) in a modular way.

  • Action Configuration: Use the actions field in your config file. Example: "delta_qpos": {"func": "DeltaQposTerm", "params": {"scale": 0.1}}.

  • Device Management: The device is single-sourced from trainer.device (distributed runs override it with cuda:<LOCAL_RANK>). All components (trainer/algorithm/policy/env) share the same device.

  • Observation Format: Environments should provide consistent observation shape/types (torch.float32) and a single done = terminated | truncated.

  • Algorithm Interface: Algorithms implement update(rollout) and consume a shared rollout TensorDict. Collection is handled by SyncCollector plus environment-side rollout writes in EmbodiedEnv.

  • Reward Configuration: Use the RewardManager in your environment config to define reward components. Organize reward components in info["rewards"] dictionary and metrics in info["metrics"] dictionary. The trainer performs dense per-step logging directly from environment info.

  • Template Methods: Override compute_task_state() to define success/failure conditions and metrics. Override check_truncated() for custom truncation logic.

  • Configuration: Keep all hyperparameters in the JSON or YAML config file. This makes experiments reproducible and easy to track.

  • Logging: Metrics are automatically logged to TensorBoard and, when use_wandb is enabled, to Weights & Biases. Check outputs/<exp_name>_<timestamp>/logs/ for TensorBoard logs.

  • Checkpoints: Regular checkpoints are saved to outputs/<exp_name>_<timestamp>/checkpoints/. Use these to resume training or evaluate policies.
