Reinforcement Learning Training#

This tutorial shows you how to train reinforcement learning agents using EmbodiChain’s RL framework. You’ll learn how to configure training via JSON, set up environments, policies, and algorithms, and launch training sessions.

Overview#

The RL framework provides a modular, extensible stack for robotics tasks:

  • Trainer: Orchestrates the training loop (calls algorithm for data collection and updates, handles logging/eval/save)

  • Algorithm: Controls data collection process (interacts with environment, fills buffer, computes advantages/returns) and updates the policy (e.g., PPO)

  • Policy: Neural network models implementing a unified interface (get_action/get_value/evaluate_actions)

  • Buffer: On-policy rollout storage and minibatch iterator (managed by algorithm)

  • Env Factory: Build environments from a JSON config via registry

Architecture#

The framework follows a clean separation of concerns, with each of the components above owning a single responsibility. The core components and their relationships (a conceptual sketch of the resulting control flow follows the list):

  • Trainer → Policy, Env, Algorithm (via callbacks for statistics)

  • Algorithm → Policy, RolloutBuffer (algorithm manages its own buffer)
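
To make this division of labor concrete, here is a conceptual sketch of the control flow. It is not the framework's Trainer implementation, and the signatures are simplified; the full wiring is shown in train.py later in this tutorial.

def training_loop(iterations, rollout_steps, env, policy, algorithm):
    """Conceptual control flow only; the real Trainer also handles logging, eval, and checkpoints."""
    obs, _ = env.reset()
    for it in range(iterations):
        # 1. The algorithm drives data collection: it steps the env with the policy,
        #    fills its own RolloutBuffer, and computes advantages/returns.
        #    (The real algorithm tracks the latest observation across iterations.)
        algorithm.collect_rollout(env, policy, obs, rollout_steps)

        # 2. The algorithm updates the policy from its buffer (e.g., PPO epochs over minibatches).
        update_stats = algorithm.update()

        # 3. Trainer-level bookkeeping: losses/metrics to TensorBoard and W&B,
        #    periodic evaluation, periodic checkpointing.
        print(f"iteration {it}: {update_stats}")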

Configuration via JSON#

Training is configured via a JSON file that defines runtime settings, environment, policy, and algorithm parameters.

Example Configuration#

The configuration file (e.g., train_config.json) is located in configs/agents/rl/push_cube:

Example: train_config.json
{
    "trainer": {
        "exp_name": "push_cube_ppo",
        "gym_config": "configs/agents/rl/push_cube/gym_config.json",
        "seed": 42,
        "device": "cuda:0",
        "headless": false,
        "enable_rt": false,
        "gpu_id": 0,
        "num_envs": 8,
        "iterations": 1000,
        "rollout_steps": 1024,
        "eval_freq": 200,
        "save_freq": 200,
        "use_wandb": true,
        "wandb_project_name": "embodychain-push_cube",
        "events": {
            "eval": {
                "record_camera": {
                    "func": "record_camera_data_async",
                    "mode": "interval",
                    "interval_step": 1,
                    "params": {
                        "name": "main_cam",
                        "resolution": [640, 480],
                        "eye": [-1.4, 1.4, 2.0],
                        "target": [0, 0, 0],
                        "up": [0, 0, 1],
                        "intrinsics": [600, 600, 320, 240],
                        "save_path": "./outputs/videos/eval"
                    }
                }
            }
        }
    },
    "policy": {
        "name": "actor_critic",
        "actor": {
            "type": "mlp",
            "network_cfg": {
                "hidden_sizes": [256, 256],
                "activation": "relu"
            }
        },
        "critic": {
            "type": "mlp",
            "network_cfg": {
                "hidden_sizes": [256, 256],
                "activation": "relu"
            }
        }
    },
    "algorithm": {
        "name": "ppo",
        "cfg": {
            "learning_rate": 0.0001,
            "n_epochs": 10,
            "batch_size": 8192,
            "gamma": 0.99,
            "gae_lambda": 0.95,
            "clip_coef": 0.2,
            "ent_coef": 0.01,
            "vf_coef": 0.5,
            "max_grad_norm": 0.5
        }
    }
}

Configuration Sections#

Runtime Settings#

The trainer section controls the experiment runtime:

  • exp_name: Experiment name (used for output directories)

  • gym_config: Path to the environment (gym) config JSON used to build the task

  • seed: Random seed for reproducibility

  • device: Compute device string, e.g. "cpu" or "cuda:0"

  • headless: Whether to run simulation in headless mode

  • enable_rt: Simulator flag forwarded to sim_cfg.enable_rt

  • gpu_id: GPU index forwarded to sim_cfg.gpu_id

  • num_envs: Number of parallel environments (overrides the value in the gym config)

  • iterations: Number of training iterations

  • rollout_steps: Steps per rollout (e.g., 1024)

  • eval_freq: Frequency of evaluation (in steps)

  • save_freq: Frequency of checkpoint saving (in steps)

  • use_wandb: Whether to enable Weights & Biases logging

  • wandb_project_name: Weights & Biases project name

  • events: Optional train/eval event hooks (e.g., asynchronous camera recording during evaluation)
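
For reference, iterations, rollout_steps, and num_envs together determine the total number of environment steps in a run; train.py computes this as iterations * rollout_steps * env.num_envs.

# How the example values above translate into total environment steps.
iterations = 1000
rollout_steps = 1024
num_envs = 8

total_steps = iterations * rollout_steps * num_envs
print(total_steps)  # 8192000 environment steps for the whole run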

Environment Configuration#

The environment is defined in a separate gym config file (referenced by trainer.gym_config, e.g., configs/agents/rl/push_cube/gym_config.json) rather than inline in the training config. Its key fields are:

  • id: Environment registry ID (e.g., “PushCubeRL”)

  • cfg: Environment-specific configuration parameters

Example:

"env": {
  "id": "PushCubeRL",
  "cfg": {
    "num_envs": 4,
    "obs_mode": "state",
    "episode_length": 100,
    "action_scale": 0.1,
    "success_threshold": 0.1
  }
}
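
At training time, train.py loads this file and builds the environment through the registry, roughly as follows (a simplified excerpt of the script shown later in this tutorial; the sim_cfg overrides for device/headless are omitted):

# Simplified excerpt of how train.py builds the environment from the gym config.
from embodichain.lab.gym.envs.tasks.rl import build_env
from embodichain.lab.gym.utils.gym_utils import config_to_cfg
from embodichain.utils.utility import load_json

gym_config_data = load_json("configs/agents/rl/push_cube/gym_config.json")
gym_env_cfg = config_to_cfg(gym_config_data)
gym_env_cfg.num_envs = 8  # optional override from trainer.num_envs

env = build_env(gym_config_data["id"], base_env_cfg=gym_env_cfg)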

Policy Configuration#

The policy section defines the neural network policy:

  • name: Policy name (e.g., “actor_critic”, “vla”)

  • cfg: Policy-specific hyperparameters (empty for actor_critic)

  • actor: Actor network configuration (required for actor_critic)

  • critic: Critic network configuration (required for actor_critic)

Example:

"policy": {
  "name": "actor_critic",
  "cfg": {},
  "actor": {
    "type": "mlp",
    "hidden_sizes": [256, 256],
    "activation": "relu"
  },
  "critic": {
    "type": "mlp",
    "hidden_sizes": [256, 256],
    "activation": "relu"
  }
}
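
For the actor_critic policy, train.py consumes this block roughly as follows (simplified excerpt; "env" is the environment built from the gym config and "cfg_json" is the parsed training config):

# Simplified excerpt of the registry-based policy construction in train.py.
import torch

from embodichain.agents.rl.models import build_policy, build_mlp_from_cfg

device = torch.device("cuda:0")
policy_block = cfg_json["policy"]

obs_dim = env.flattened_observation_space.shape[-1]  # flattened Box space for RL training
action_dim = env.action_space.shape[-1]

actor = build_mlp_from_cfg(policy_block["actor"], obs_dim, action_dim)
critic = build_mlp_from_cfg(policy_block["critic"], obs_dim, 1)  # critic outputs a scalar value

policy = build_policy(
    policy_block,
    env.flattened_observation_space,
    env.action_space,
    device,
    actor=actor,
    critic=critic,
)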

Algorithm Configuration#

The algorithm section defines the RL algorithm:

  • name: Algorithm name (e.g., “ppo”)

  • cfg: Algorithm-specific hyperparameters

Example:

"algorithm": {
  "name": "ppo",
  "cfg": {
    "learning_rate": 0.0001,
    "n_epochs": 10,
    "batch_size": 64,
    "gamma": 0.99,
    "gae_lambda": 0.95,
    "clip_coef": 0.2,
    "ent_coef": 0.01,
    "vf_coef": 0.5,
    "max_grad_norm": 0.5
  }
}
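
train.py then builds the algorithm from this block via the factory (simplified excerpt; "policy" and "device" come from the policy construction step):

# Simplified excerpt of the algorithm construction in train.py.
from embodichain.agents.rl.algo import build_algo

algo_block = cfg_json["algorithm"]
algo = build_algo(algo_block["name"].lower(), algo_block["cfg"], policy, device)
# Note: algo_block["cfg"]["batch_size"] is also handed to the Trainer for minibatching.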

Training Script#

The training script (train.py) is located in embodichain/agents/rl/:

Code for train.py
  1# ----------------------------------------------------------------------------
  2# Copyright (c) 2021-2025 DexForce Technology Co., Ltd.
  3#
  4# Licensed under the Apache License, Version 2.0 (the "License");
  5# you may not use this file except in compliance with the License.
  6# You may obtain a copy of the License at
  7#
  8#     http://www.apache.org/licenses/LICENSE-2.0
  9#
 10# Unless required by applicable law or agreed to in writing, software
 11# distributed under the License is distributed on an "AS IS" BASIS,
 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13# See the License for the specific language governing permissions and
 14# limitations under the License.
 15# ----------------------------------------------------------------------------
 16
 17import argparse
 18import os
 19import time
 20from pathlib import Path
 21
 22import numpy as np
 23import torch
 24import wandb
 25import json
 26from torch.utils.tensorboard import SummaryWriter
 27from copy import deepcopy
 28
 29from embodichain.agents.rl.models import build_policy, get_registered_policy_names
 30from embodichain.agents.rl.models import build_mlp_from_cfg
 31from embodichain.agents.rl.algo import build_algo, get_registered_algo_names
 32from embodichain.agents.rl.utils.trainer import Trainer
 33from embodichain.utils import logger
 34from embodichain.lab.gym.envs.tasks.rl import build_env
 35from embodichain.lab.gym.utils.gym_utils import config_to_cfg
 36from embodichain.utils.utility import load_json
 37from embodichain.utils.module_utils import find_function_from_modules
 38from embodichain.lab.sim import SimulationManagerCfg
 39from embodichain.lab.gym.envs.managers.cfg import EventCfg
 40
 41
 42def main():
 43    parser = argparse.ArgumentParser()
 44    parser.add_argument("--config", type=str, required=True, help="Path to JSON config")
 45    args = parser.parse_args()
 46
 47    with open(args.config, "r") as f:
 48        cfg_json = json.load(f)
 49
 50    trainer_cfg = cfg_json["trainer"]
 51    policy_block = cfg_json["policy"]
 52    algo_block = cfg_json["algorithm"]
 53
 54    # Runtime
 55    exp_name = trainer_cfg.get("exp_name", "generic_exp")
 56    seed = int(trainer_cfg.get("seed", 1))
 57    device_str = trainer_cfg.get("device", "cpu")
 58    iterations = int(trainer_cfg.get("iterations", 250))
 59    rollout_steps = int(trainer_cfg.get("rollout_steps", 2048))
 60    eval_freq = int(trainer_cfg.get("eval_freq", 10000))
 61    save_freq = int(trainer_cfg.get("save_freq", 50000))
 62    headless = bool(trainer_cfg.get("headless", True))
 63    enable_rt = bool(trainer_cfg.get("enable_rt", False))
 64    gpu_id = int(trainer_cfg.get("gpu_id", 0))
 65    num_envs = trainer_cfg.get("num_envs", None)
 66    wandb_project_name = trainer_cfg.get("wandb_project_name", "embodychain-generic")
 67
 68    # Device
 69    if not isinstance(device_str, str):
 70        raise ValueError(
 71            f"runtime.device must be a string such as 'cpu' or 'cuda:0'. Got: {device_str!r}"
 72        )
 73    try:
 74        device = torch.device(device_str)
 75    except RuntimeError as exc:
 76        raise ValueError(
 77            f"Failed to parse runtime.device='{device_str}': {exc}"
 78        ) from exc
 79
 80    if device.type == "cuda":
 81        if not torch.cuda.is_available():
 82            raise ValueError(
 83                "CUDA device requested but torch.cuda.is_available() is False."
 84            )
 85        index = (
 86            device.index if device.index is not None else torch.cuda.current_device()
 87        )
 88        device_count = torch.cuda.device_count()
 89        if index < 0 or index >= device_count:
 90            raise ValueError(
 91                f"CUDA device index {index} is out of range (available devices: {device_count})."
 92            )
 93        torch.cuda.set_device(index)
 94        device = torch.device(f"cuda:{index}")
 95    elif device.type != "cpu":
 96        raise ValueError(f"Unsupported device type: {device}")
 97    logger.log_info(f"Device: {device}")
 98
 99    # Seeds
100    np.random.seed(seed)
101    torch.manual_seed(seed)
102    torch.backends.cudnn.deterministic = True
103    if device.type == "cuda":
104        torch.cuda.manual_seed_all(seed)
105
106    # Outputs
107    run_stamp = time.strftime("%Y%m%d_%H%M%S")
108    run_base = os.path.join("outputs", f"{exp_name}_{run_stamp}")
109    log_dir = os.path.join(run_base, "logs")
110    checkpoint_dir = os.path.join(run_base, "checkpoints")
111    os.makedirs(log_dir, exist_ok=True)
112    os.makedirs(checkpoint_dir, exist_ok=True)
113    writer = SummaryWriter(f"{log_dir}/{exp_name}")
114
115    # Initialize Weights & Biases (optional)
116    use_wandb = trainer_cfg.get("use_wandb", False)
117
118    # Initialize Weights & Biases (optional)
119    if use_wandb:
120        wandb.init(project=wandb_project_name, name=exp_name, config=cfg_json)
121
122    gym_config_path = Path(trainer_cfg["gym_config"])
123    logger.log_info(f"Current working directory: {Path.cwd()}")
124
125    gym_config_data = load_json(str(gym_config_path))
126    gym_env_cfg = config_to_cfg(gym_config_data)
127
128    # Override num_envs from train config if provided
129    if num_envs is not None:
130        gym_env_cfg.num_envs = num_envs
131
132    # Ensure sim configuration mirrors runtime overrides
133    if gym_env_cfg.sim_cfg is None:
134        gym_env_cfg.sim_cfg = SimulationManagerCfg()
135    if device.type == "cuda":
136        gpu_index = device.index
137        if gpu_index is None:
138            gpu_index = torch.cuda.current_device()
139        gym_env_cfg.sim_cfg.sim_device = torch.device(f"cuda:{gpu_index}")
140        if hasattr(gym_env_cfg.sim_cfg, "gpu_id"):
141            gym_env_cfg.sim_cfg.gpu_id = gpu_index
142    else:
143        gym_env_cfg.sim_cfg.sim_device = torch.device("cpu")
144    gym_env_cfg.sim_cfg.headless = headless
145    gym_env_cfg.sim_cfg.enable_rt = enable_rt
146    gym_env_cfg.sim_cfg.gpu_id = gpu_id
147
148    logger.log_info(
149        f"Loaded gym_config from {gym_config_path} (env_id={gym_config_data['id']}, num_envs={gym_env_cfg.num_envs}, headless={gym_env_cfg.sim_cfg.headless}, enable_rt={gym_env_cfg.sim_cfg.enable_rt}, sim_device={gym_env_cfg.sim_cfg.sim_device})"
150    )
151
152    env = build_env(gym_config_data["id"], base_env_cfg=gym_env_cfg)
153
154    eval_gym_env_cfg = deepcopy(gym_env_cfg)
155    eval_gym_env_cfg.num_envs = 4
156    eval_gym_env_cfg.sim_cfg.headless = True
157
158    eval_env = build_env(gym_config_data["id"], base_env_cfg=eval_gym_env_cfg)
159
160    # Build Policy via registry
161    policy_name = policy_block["name"]
162    # Build Policy via registry (actor/critic must be explicitly defined in JSON when using actor_critic)
163    if policy_name.lower() == "actor_critic":
164        # Get observation dimension from flattened observation space
165        # flattened_observation_space returns Box space for RL training
166        obs_dim = env.flattened_observation_space.shape[-1]
167        action_dim = env.action_space.shape[-1]
168
169        actor_cfg = policy_block.get("actor")
170        critic_cfg = policy_block.get("critic")
171        if actor_cfg is None or critic_cfg is None:
172            raise ValueError(
173                "ActorCritic requires 'actor' and 'critic' definitions in JSON (policy.actor / policy.critic)."
174            )
175
176        actor = build_mlp_from_cfg(actor_cfg, obs_dim, action_dim)
177        critic = build_mlp_from_cfg(critic_cfg, obs_dim, 1)
178
179        policy = build_policy(
180            policy_block,
181            env.flattened_observation_space,
182            env.action_space,
183            device,
184            actor=actor,
185            critic=critic,
186        )
187    else:
188        policy = build_policy(
189            policy_block, env.flattened_observation_space, env.action_space, device
190        )
191
192    # Build Algorithm via factory
193    algo_name = algo_block["name"].lower()
194    algo_cfg = algo_block["cfg"]
195    algo = build_algo(algo_name, algo_cfg, policy, device)
196
197    # Build Trainer
198    event_modules = [
199        "embodichain.lab.gym.envs.managers.randomization",
200        "embodichain.lab.gym.envs.managers.record",
201        "embodichain.lab.gym.envs.managers.events",
202    ]
203    events_dict = trainer_cfg.get("events", {})
204    train_event_cfg = {}
205    eval_event_cfg = {}
206    # Parse train events
207    for event_name, event_info in events_dict.get("train", {}).items():
208        event_func_str = event_info.get("func")
209        mode = event_info.get("mode", "interval")
210        params = event_info.get("params", {})
211        interval_step = event_info.get("interval_step", 1)
212        event_func = find_function_from_modules(
213            event_func_str, event_modules, raise_if_not_found=True
214        )
215        train_event_cfg[event_name] = EventCfg(
216            func=event_func,
217            mode=mode,
218            params=params,
219            interval_step=interval_step,
220        )
221    # Parse eval events
222    for event_name, event_info in events_dict.get("eval", {}).items():
223        event_func_str = event_info.get("func")
224        mode = event_info.get("mode", "interval")
225        params = event_info.get("params", {})
226        interval_step = event_info.get("interval_step", 1)
227        event_func = find_function_from_modules(
228            event_func_str, event_modules, raise_if_not_found=True
229        )
230        eval_event_cfg[event_name] = EventCfg(
231            func=event_func,
232            mode=mode,
233            params=params,
234            interval_step=interval_step,
235        )
236    trainer = Trainer(
237        policy=policy,
238        env=env,
239        algorithm=algo,
240        num_steps=rollout_steps,
241        batch_size=algo_cfg["batch_size"],
242        writer=writer,
243        eval_freq=eval_freq,
244        save_freq=save_freq,
245        checkpoint_dir=checkpoint_dir,
246        exp_name=exp_name,
247        use_wandb=use_wandb,
248        eval_env=eval_env,
249        event_cfg=train_event_cfg,
250        eval_event_cfg=eval_event_cfg,
251    )
252
253    logger.log_info("Generic training initialized")
254    logger.log_info(f"Task: {type(env).__name__}")
255    logger.log_info(
256        f"Policy: {policy_name} (available: {get_registered_policy_names()})"
257    )
258    logger.log_info(
259        f"Algorithm: {algo_name} (available: {get_registered_algo_names()})"
260    )
261
262    total_steps = int(iterations * rollout_steps * env.num_envs)
263    logger.log_info(f"Total steps: {total_steps} (iterations≈{iterations})")
264
265    try:
266        trainer.train(total_steps)
267    except KeyboardInterrupt:
268        logger.log_info("Training interrupted by user")
269    finally:
270        trainer.save_checkpoint()
271        writer.close()
272        if use_wandb:
273            try:
274                wandb.finish()
275            except Exception:
276                pass
277        logger.log_info("Training finished")
278
279
280if __name__ == "__main__":
281    main()

The Script Explained#

The training script performs the following steps:

  1. Parse Configuration: Loads the JSON config and extracts the trainer, policy, and algorithm blocks

  2. Setup: Initializes device, seeds, output directories, TensorBoard, and Weights & Biases

  3. Build Components: Environment via the build_env() factory (using the gym config referenced by trainer.gym_config), policy via the build_policy() registry, and algorithm via the build_algo() factory

  4. Create Trainer: Instantiates the Trainer with all components

  5. Train: Runs the training loop until completion

Launching Training#

To start training, run:

python -m embodichain.agents.rl.train --config configs/agents/rl/push_cube/train_config.json

Outputs#

All outputs are written to ./outputs/<exp_name>_<timestamp>/:

  • logs/: TensorBoard logs

  • checkpoints/: Model checkpoints

Training Process#

The training process follows this sequence:

  1. Rollout Phase: Algorithm collects trajectories by interacting with the environment (via collect_rollout). During this phase, the trainer performs dense per-step logging of rewards and metrics from environment info.

  2. GAE Computation: Algorithm computes advantages and returns using Generalized Advantage Estimation (internal to algorithm, stored in buffer extras); see the reference sketch after this list

  3. Update Phase: Algorithm updates the policy using collected data (e.g., PPO)

  4. Logging: Trainer logs training losses and aggregated metrics to TensorBoard and Weights & Biases

  5. Evaluation (periodic): Trainer evaluates the current policy

  6. Checkpointing (periodic): Trainer saves model checkpoints
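
Step 2 relies on standard Generalized Advantage Estimation, driven by the gamma and gae_lambda values from the algorithm config. The framework's own implementation lives inside the algorithm; the following is a generic reference sketch of GAE(λ), not the library's code:

import torch

def compute_gae(rewards, values, dones, last_value, gamma=0.99, gae_lambda=0.95):
    """Reference GAE(lambda). Shapes: (num_steps, num_envs); dones[t] marks an episode end at step t."""
    num_steps = rewards.shape[0]
    advantages = torch.zeros_like(rewards)
    last_gae = torch.zeros_like(last_value)
    for t in reversed(range(num_steps)):
        next_value = last_value if t == num_steps - 1 else values[t + 1]
        next_nonterminal = 1.0 - dones[t].float()
        # TD error: r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
        delta = rewards[t] + gamma * next_value * next_nonterminal - values[t]
        # Discounted, lambda-weighted accumulation of TD errors
        last_gae = delta + gamma * gae_lambda * next_nonterminal * last_gae
        advantages[t] = last_gae
    returns = advantages + values
    return advantages, returns

Advantages and returns computed this way are attached to the buffer extras and consumed during the update phase.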

Policy Interface#

All policies must inherit from the Policy abstract base class:

from abc import ABC, abstractmethod
from typing import Tuple

import torch
import torch.nn as nn

class Policy(nn.Module, ABC):
    device: torch.device

    @abstractmethod
    def get_action(
        self, obs: torch.Tensor, deterministic: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Returns (action, log_prob, value)"""
        raise NotImplementedError

    @abstractmethod
    def get_value(self, obs: torch.Tensor) -> torch.Tensor:
        """Returns value estimate"""
        raise NotImplementedError

    @abstractmethod
    def evaluate_actions(
        self, obs: torch.Tensor, actions: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Returns (log_prob, entropy, value)"""
        raise NotImplementedError

Available Policies#

  • ActorCritic: MLP-based Gaussian policy with learnable log_std. Requires external actor and critic modules to be provided (defined in JSON config).

  • VLAPlaceholderPolicy: Placeholder for Vision-Language-Action policies
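
A quick way to exercise the unified interface (a sketch: "policy" stands for any registered policy such as ActorCritic, and obs_dim is a hypothetical flattened observation dimension it was built with):

import torch

obs_dim = 32                                          # hypothetical observation size
obs = torch.randn(8, obs_dim, device=policy.device)   # batch of 8 observations

with torch.no_grad():
    action, log_prob, value = policy.get_action(obs, deterministic=True)  # evaluation-style action
    value_only = policy.get_value(obs)                                    # critic estimate only
    log_prob2, entropy, value2 = policy.evaluate_actions(obs, action)     # used during updates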

Algorithms#

Available Algorithms#

  • PPO: Proximal Policy Optimization with GAE

Adding a New Algorithm#

To add a new algorithm:

  1. Create a new algorithm class in embodichain/agents/rl/algo/

  2. Implement initialize_buffer(), collect_rollout(), and update() methods

  3. Register in algo/__init__.py:

import torch

from embodichain.agents.rl.algo import BaseAlgorithm, register_algo
from embodichain.agents.rl.buffer import RolloutBuffer

@register_algo("my_algo")
class MyAlgorithm(BaseAlgorithm):
    def __init__(self, cfg, policy):
        self.cfg = cfg
        self.policy = policy
        self.device = torch.device(cfg.device)
        self.buffer = None

    def initialize_buffer(self, num_steps, num_envs, obs_dim, action_dim):
        """Initialize the algorithm's buffer."""
        self.buffer = RolloutBuffer(num_steps, num_envs, obs_dim, action_dim, self.device)

    def collect_rollout(self, env, policy, obs, num_steps, on_step_callback=None):
        """Control data collection process (interact with env, fill buffer, compute advantages/returns)."""
        # Collect trajectories
        # Compute advantages/returns (e.g., GAE for on-policy algorithms)
        # Attach extras to buffer: self.buffer.set_extras({"advantages": adv, "returns": ret})
        # Return empty dict (dense logging handled in trainer)
        return {}

    def update(self):
        """Update the policy using collected data."""
        # Access extras from buffer: self.buffer._extras.get("advantages")
        # Use self.buffer to update policy
        return {"loss": 0.0}

Adding a New Policy#

To add a new policy:

  1. Create a new policy class inheriting from the Policy abstract base class

  2. Register in models/__init__.py:

from embodichain.agents.rl.models import register_policy, Policy

@register_policy("my_policy")
class MyPolicy(Policy):
    def __init__(self, obs_space, action_space, device, config):
        super().__init__()
        self.device = device
        # Initialize your networks here

    def get_action(self, obs, deterministic=False):
        ...
    def get_value(self, obs):
        ...
    def evaluate_actions(self, obs, actions):
        ...

Adding a New Environment#

To add a new RL environment:

  1. Create an environment class inheriting from EmbodiedEnv

  2. Register it with the Gymnasium registry:

from embodichain.lab.gym.utils.registration import register_env

@register_env("MyTaskRL", max_episode_steps=100, override=True)
class MyTaskEnv(EmbodiedEnv):
    cfg: MyTaskEnvCfg
    ...
  3. Use the environment ID in your JSON config:

"env": {
  "id": "MyTaskRL",
  "cfg": {
    ...
  }
}
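
The string passed to register_env is what the factory looks up at training time (train.py calls build_env(gym_config_data["id"], base_env_cfg=...)), so the ID in the config must match the registered name exactly.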

Best Practices#

  • Device Management: The device is single-sourced from trainer.device (e.g., "cpu" or "cuda:0"). All components (trainer/algorithm/policy/env) share the same device.

  • Action Scaling: Keep action scaling in the environment, not in the policy.

  • Observation Format: Environments should provide consistent observation shape/types (torch.float32) and a single done = terminated | truncated.

  • Algorithm Interface: Algorithms must implement initialize_buffer(), collect_rollout(), and update() methods. The algorithm completely controls data collection and buffer management.

  • Reward Components: Organize reward components in the info["rewards"] dictionary and metrics in the info["metrics"] dictionary. The trainer performs dense per-step logging directly from environment info (a sketch follows this list).

  • Configuration: Use JSON for all hyperparameters. This makes experiments reproducible and easy to track.

  • Logging: Metrics are automatically logged to TensorBoard and Weights & Biases. Check outputs/<exp_name>_<timestamp>/logs/ for TensorBoard logs.

  • Checkpoints: Regular checkpoints are saved to outputs/<exp_name>_<timestamp>/checkpoints/. Use these to resume training or evaluate policies.
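
To illustrate the Reward Components practice, here is a hypothetical sketch of how an environment might populate info; the component names are made up for illustration and are not taken from any EmbodiChain task.

import torch

num_envs = 8  # hypothetical number of parallel environments

# Per-env reward components and metrics, grouped so the trainer can log them per step.
reach_reward = torch.rand(num_envs)
push_reward = torch.rand(num_envs)
success = torch.rand(num_envs) > 0.9

info = {
    "rewards": {"reach": reach_reward, "push": push_reward},
    "metrics": {"success_rate": success.float()},
}

# One common convention (an assumption here, not mandated by the framework):
# the env's scalar reward is the sum of the components.
total_reward = sum(info["rewards"].values())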