Reinforcement Learning Training#

This tutorial shows you how to train reinforcement learning agents using EmbodiChain’s RL framework. You’ll learn how to configure training via JSON, set up environments, policies, and algorithms, and launch training sessions.

Overview#

The RL framework provides a modular, extensible stack for robotics tasks:

  • Trainer: Orchestrates the training loop (calls algorithm for data collection and updates, handles logging/eval/save)

  • Algorithm: Controls data collection process (interacts with environment, fills buffer, computes advantages/returns) and updates the policy (e.g., PPO)

  • Policy: Neural network models implementing a unified interface (get_action/get_value/evaluate_actions)

  • Buffer: On-policy rollout storage and minibatch iterator (managed by algorithm)

  • Env Factory: Build environments from a JSON config via registry

Architecture#

The framework keeps a clean separation of concerns between the components listed above.

The core components and their relationships:

  • Trainer → Policy, Env, Algorithm (via callbacks for statistics)

  • Algorithm → Policy, RolloutBuffer (algorithm manages its own buffer)

Configuration via JSON#

Training is configured via a JSON file that defines runtime settings, environment, policy, and algorithm parameters.

Example Configuration#

The configuration file (e.g., train_config.json) is located in configs/agents/rl/push_cube:

Example: train_config.json
{
  "trainer": {
    "exp_name": "push_cube_ppo",
    "gym_config": "configs/agents/rl/push_cube/gym_config.json",
    "seed": 42,
    "device": "cuda:0",
    "headless": true,
    "iterations": 1000,
    "rollout_steps": 1024,
    "eval_freq": 2,
    "save_freq": 200,
    "use_wandb": false,
    "wandb_project_name": "embodychain-push_cube",
    "events": {
      "eval": {
        "record_camera": {
          "func": "record_camera_data_async",
          "mode": "interval",
          "interval_step": 1,
          "params": {
            "name": "main_cam",
            "resolution": [640, 480],
            "eye": [-1.4, 1.4, 2.0],
            "target": [0, 0, 0],
            "up": [0, 0, 1],
            "intrinsics": [600, 600, 320, 240],
            "save_path": "./outputs/videos/eval"
          }
        }
      }
    }
  },
  "policy": {
    "name": "actor_critic",
    "actor": {
      "type": "mlp",
      "network_cfg": {
        "hidden_sizes": [256, 256],
        "activation": "relu"
      }
    },
    "critic": {
      "type": "mlp",
      "network_cfg": {
        "hidden_sizes": [256, 256],
        "activation": "relu"
      }
    }
  },
  "algorithm": {
    "name": "ppo",
    "cfg": {
      "learning_rate": 0.0001,
      "n_epochs": 10,
      "batch_size": 8192,
      "gamma": 0.99,
      "gae_lambda": 0.95,
      "clip_coef": 0.2,
      "ent_coef": 0.01,
      "vf_coef": 0.5,
      "max_grad_norm": 0.5
    }
  }
}

Configuration Sections#

Runtime Settings#

The trainer section controls runtime and experiment setup; the sketch after this list shows the defaults train.py applies when a key is omitted:

  • exp_name: Experiment name (used for output directories)

  • gym_config: Path to the environment (gym) configuration JSON

  • seed: Random seed for reproducibility

  • device: Torch device string (e.g., “cpu” or “cuda:0”; default: “cpu”)

  • headless: Whether to run simulation in headless mode

  • iterations: Number of training iterations

  • rollout_steps: Steps per rollout (e.g., 1024)

  • eval_freq: Frequency of evaluation (in steps)

  • save_freq: Frequency of checkpoint saving (in steps)

  • use_wandb: Whether to enable Weights & Biases logging

  • wandb_project_name: Weights & Biases project name
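
When a key is missing from the trainer block, train.py falls back to a built-in default. The snippet below mirrors the relevant lines of the script (shown in full later); the partially filled trainer_cfg dict is only an illustration:

# Defaults applied by train.py when a trainer key is omitted.
trainer_cfg = {"exp_name": "push_cube_ppo", "seed": 42}  # illustrative partial block

exp_name = trainer_cfg.get("exp_name", "generic_exp")
seed = int(trainer_cfg.get("seed", 1))
device_str = trainer_cfg.get("device", "cpu")
iterations = int(trainer_cfg.get("iterations", 250))
rollout_steps = int(trainer_cfg.get("rollout_steps", 2048))
eval_freq = int(trainer_cfg.get("eval_freq", 10000))
save_freq = int(trainer_cfg.get("save_freq", 50000))
headless = bool(trainer_cfg.get("headless", True))
use_wandb = trainer_cfg.get("use_wandb", False)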

Environment Configuration#

The env section defines the task environment:

  • id: Environment registry ID (e.g., “PushCubeRL”)

  • cfg: Environment-specific configuration parameters

Example:

"env": {
  "id": "PushCubeRL",
  "cfg": {
    "num_envs": 4,
    "obs_mode": "state",
    "episode_length": 100,
    "action_scale": 0.1,
    "success_threshold": 0.1
  }
}
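
In train.py the environment is ultimately constructed by loading the gym config JSON referenced by trainer.gym_config, converting it to a typed config, and passing it to the registry factory. A condensed sketch of those calls, using the example path from the trainer block above:

# Sketch: turning an environment config into an env instance (mirrors train.py).
from embodichain.lab.gym.envs.tasks.rl import build_env
from embodichain.lab.gym.utils.gym_utils import config_to_rl_cfg
from embodichain.utils.utility import load_json

gym_config_data = load_json("configs/agents/rl/push_cube/gym_config.json")
gym_env_cfg = config_to_rl_cfg(gym_config_data)                # JSON -> typed RL env cfg
env = build_env(gym_env_cfg.env_id, base_env_cfg=gym_env_cfg)  # registry lookup by id (e.g. "PushCubeRL")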

Policy Configuration#

The policy section defines the neural network policy:

  • name: Policy name (e.g., “actor_critic”, “vla”)

  • cfg: Policy-specific hyperparameters (empty for actor_critic)

  • actor: Actor network configuration (required for actor_critic)

  • critic: Critic network configuration (required for actor_critic)

Example:

"policy": {
  "name": "actor_critic",
  "cfg": {},
  "actor": {
    "type": "mlp",
    "hidden_sizes": [256, 256],
    "activation": "relu"
  },
  "critic": {
    "type": "mlp",
    "hidden_sizes": [256, 256],
    "activation": "relu"
  }
}
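
For the actor_critic policy, train.py resolves this block by building the actor and critic MLPs from their sub-configs and handing them to the policy registry. A condensed sketch of those calls (env, device, and policy_block are assumed to already exist, as in the full script below):

# Sketch: building an actor_critic policy from the JSON block (mirrors train.py).
from embodichain.agents.rl.models import build_mlp_from_cfg, build_policy

obs_dim = env.observation_space.shape[-1]
action_dim = env.action_space.shape[-1]

actor = build_mlp_from_cfg(policy_block["actor"], obs_dim, action_dim)
critic = build_mlp_from_cfg(policy_block["critic"], obs_dim, 1)  # critic outputs a scalar value

policy = build_policy(
    policy_block,
    env.observation_space,
    env.action_space,
    device,
    actor=actor,
    critic=critic,
)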

Algorithm Configuration#

The algorithm section defines the RL algorithm:

  • name: Algorithm name (e.g., “ppo”)

  • cfg: Algorithm-specific hyperparameters

Example:

"algorithm": {
  "name": "ppo",
  "cfg": {
    "learning_rate": 0.0001,
    "n_epochs": 10,
    "batch_size": 64,
    "gamma": 0.99,
    "gae_lambda": 0.95,
    "clip_coef": 0.2,
    "ent_coef": 0.01,
    "vf_coef": 0.5,
    "max_grad_norm": 0.5
  }
}
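
The algorithm block is passed straight to the factory, as in train.py (policy and device assumed to already exist):

# Sketch: building the algorithm from the JSON block (mirrors train.py).
from embodichain.agents.rl.algo import build_algo

algo_name = algo_block["name"].lower()  # e.g. "ppo"
algo_cfg = algo_block["cfg"]
algo = build_algo(algo_name, algo_cfg, policy, device)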

Training Script#

The training script (train.py) is located in embodichain/agents/rl/:

Code for train.py
# ----------------------------------------------------------------------------
# Copyright (c) 2021-2025 DexForce Technology Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ----------------------------------------------------------------------------

import argparse
import os
import time
from pathlib import Path

import numpy as np
import torch
import wandb
import json
from torch.utils.tensorboard import SummaryWriter
from copy import deepcopy

from embodichain.agents.rl.models import build_policy, get_registered_policy_names
from embodichain.agents.rl.models import build_mlp_from_cfg
from embodichain.agents.rl.algo import build_algo, get_registered_algo_names
from embodichain.agents.rl.utils.trainer import Trainer
from embodichain.utils import logger
from embodichain.lab.gym.envs.tasks.rl import build_env
from embodichain.lab.gym.utils.gym_utils import config_to_rl_cfg
from embodichain.utils.utility import load_json
from embodichain.utils.module_utils import find_function_from_modules
from embodichain.lab.sim import SimulationManagerCfg
from embodichain.lab.gym.envs.managers.cfg import EventCfg


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True, help="Path to JSON config")
    args = parser.parse_args()

    with open(args.config, "r") as f:
        cfg_json = json.load(f)

    trainer_cfg = cfg_json["trainer"]
    policy_block = cfg_json["policy"]
    algo_block = cfg_json["algorithm"]

    # Runtime
    exp_name = trainer_cfg.get("exp_name", "generic_exp")
    seed = int(trainer_cfg.get("seed", 1))
    device_str = trainer_cfg.get("device", "cpu")
    iterations = int(trainer_cfg.get("iterations", 250))
    rollout_steps = int(trainer_cfg.get("rollout_steps", 2048))
    eval_freq = int(trainer_cfg.get("eval_freq", 10000))
    save_freq = int(trainer_cfg.get("save_freq", 50000))
    headless = bool(trainer_cfg.get("headless", True))
    wandb_project_name = trainer_cfg.get("wandb_project_name", "embodychain-generic")

    # Device
    if not isinstance(device_str, str):
        raise ValueError(
            f"runtime.device must be a string such as 'cpu' or 'cuda:0'. Got: {device_str!r}"
        )
    try:
        device = torch.device(device_str)
    except RuntimeError as exc:
        raise ValueError(
            f"Failed to parse runtime.device='{device_str}': {exc}"
        ) from exc

    if device.type == "cuda":
        if not torch.cuda.is_available():
            raise ValueError(
                "CUDA device requested but torch.cuda.is_available() is False."
            )
        index = (
            device.index if device.index is not None else torch.cuda.current_device()
        )
        device_count = torch.cuda.device_count()
        if index < 0 or index >= device_count:
            raise ValueError(
                f"CUDA device index {index} is out of range (available devices: {device_count})."
            )
        torch.cuda.set_device(index)
        device = torch.device(f"cuda:{index}")
    elif device.type != "cpu":
        raise ValueError(f"Unsupported device type: {device}")
    logger.log_info(f"Device: {device}")

    # Seeds
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    if device.type == "cuda":
        torch.cuda.manual_seed_all(seed)

    # Outputs
    run_stamp = time.strftime("%Y%m%d_%H%M%S")
    run_base = os.path.join("outputs", f"{exp_name}_{run_stamp}")
    log_dir = os.path.join(run_base, "logs")
    checkpoint_dir = os.path.join(run_base, "checkpoints")
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(checkpoint_dir, exist_ok=True)
    writer = SummaryWriter(f"{log_dir}/{exp_name}")

    # Initialize Weights & Biases (optional)
    use_wandb = trainer_cfg.get("use_wandb", False)
    if use_wandb:
        wandb.init(project=wandb_project_name, name=exp_name, config=cfg_json)

    gym_config_path = Path(trainer_cfg["gym_config"])
    logger.log_info(f"Current working directory: {Path.cwd()}")

    gym_config_data = load_json(str(gym_config_path))
    gym_env_cfg = config_to_rl_cfg(gym_config_data)

    # Ensure sim configuration mirrors runtime overrides
    if gym_env_cfg.sim_cfg is None:
        gym_env_cfg.sim_cfg = SimulationManagerCfg()
    if device.type == "cuda":
        gpu_index = device.index
        if gpu_index is None:
            gpu_index = torch.cuda.current_device()
        gym_env_cfg.sim_cfg.sim_device = torch.device(f"cuda:{gpu_index}")
        if hasattr(gym_env_cfg.sim_cfg, "gpu_id"):
            gym_env_cfg.sim_cfg.gpu_id = gpu_index
    else:
        gym_env_cfg.sim_cfg.sim_device = torch.device("cpu")
    gym_env_cfg.sim_cfg.headless = headless

    logger.log_info(
        f"Loaded gym_config from {gym_config_path} (env_id={gym_env_cfg.env_id}, headless={gym_env_cfg.sim_cfg.headless}, sim_device={gym_env_cfg.sim_cfg.sim_device})"
    )

    env = build_env(gym_env_cfg.env_id, base_env_cfg=gym_env_cfg)

    eval_gym_env_cfg = deepcopy(gym_env_cfg)
    eval_gym_env_cfg.num_envs = 4
    eval_gym_env_cfg.sim_cfg.headless = True

    eval_env = build_env(eval_gym_env_cfg.env_id, base_env_cfg=eval_gym_env_cfg)

    # Build Policy via registry (actor/critic must be explicitly defined in JSON when using actor_critic)
    policy_name = policy_block["name"]
    if policy_name.lower() == "actor_critic":
        obs_dim = env.observation_space.shape[-1]
        action_dim = env.action_space.shape[-1]

        actor_cfg = policy_block.get("actor")
        critic_cfg = policy_block.get("critic")
        if actor_cfg is None or critic_cfg is None:
            raise ValueError(
                "ActorCritic requires 'actor' and 'critic' definitions in JSON (policy.actor / policy.critic)."
            )

        actor = build_mlp_from_cfg(actor_cfg, obs_dim, action_dim)
        critic = build_mlp_from_cfg(critic_cfg, obs_dim, 1)

        policy = build_policy(
            policy_block,
            env.observation_space,
            env.action_space,
            device,
            actor=actor,
            critic=critic,
        )
    else:
        policy = build_policy(
            policy_block, env.observation_space, env.action_space, device
        )

    # Build Algorithm via factory
    algo_name = algo_block["name"].lower()
    algo_cfg = algo_block["cfg"]
    algo = build_algo(algo_name, algo_cfg, policy, device)

    # Build Trainer
    event_modules = [
        "embodichain.lab.gym.envs.managers.randomization",
        "embodichain.lab.gym.envs.managers.record",
        "embodichain.lab.gym.envs.managers.events",
    ]
    events_dict = trainer_cfg.get("events", {})
    train_event_cfg = {}
    eval_event_cfg = {}
    # Parse train events
    for event_name, event_info in events_dict.get("train", {}).items():
        event_func_str = event_info.get("func")
        mode = event_info.get("mode", "interval")
        params = event_info.get("params", {})
        interval_step = event_info.get("interval_step", 1)
        event_func = find_function_from_modules(
            event_func_str, event_modules, raise_if_not_found=True
        )
        train_event_cfg[event_name] = EventCfg(
            func=event_func,
            mode=mode,
            params=params,
            interval_step=interval_step,
        )
    # Parse eval events
    for event_name, event_info in events_dict.get("eval", {}).items():
        event_func_str = event_info.get("func")
        mode = event_info.get("mode", "interval")
        params = event_info.get("params", {})
        interval_step = event_info.get("interval_step", 1)
        event_func = find_function_from_modules(
            event_func_str, event_modules, raise_if_not_found=True
        )
        eval_event_cfg[event_name] = EventCfg(
            func=event_func,
            mode=mode,
            params=params,
            interval_step=interval_step,
        )
    trainer = Trainer(
        policy=policy,
        env=env,
        algorithm=algo,
        num_steps=rollout_steps,
        batch_size=algo_cfg["batch_size"],
        writer=writer,
        eval_freq=eval_freq,
        save_freq=save_freq,
        checkpoint_dir=checkpoint_dir,
        exp_name=exp_name,
        use_wandb=use_wandb,
        eval_env=eval_env,
        event_cfg=train_event_cfg,
        eval_event_cfg=eval_event_cfg,
    )

    logger.log_info("Generic training initialized")
    logger.log_info(f"Task: {type(env).__name__}")
    logger.log_info(
        f"Policy: {policy_name} (available: {get_registered_policy_names()})"
    )
    logger.log_info(
        f"Algorithm: {algo_name} (available: {get_registered_algo_names()})"
    )

    total_steps = int(iterations * rollout_steps * env.num_envs)
    logger.log_info(f"Total steps: {total_steps} (iterations≈{iterations})")

    try:
        trainer.train(total_steps)
    except KeyboardInterrupt:
        logger.log_info("Training interrupted by user")
    finally:
        trainer.save_checkpoint()
        writer.close()
        if use_wandb:
            try:
                wandb.finish()
            except Exception:
                pass
        logger.log_info("Training finished")


if __name__ == "__main__":
    main()

The Script Explained#

The training script performs the following steps:

  1. Parse Configuration: Loads JSON config and extracts runtime/env/policy/algorithm blocks

  2. Setup: Initializes device, seeds, output directories, TensorBoard, and Weights & Biases

  3. Build Components: Environment via the build_env() factory, policy via the build_policy() registry, and algorithm via the build_algo() factory

  4. Create Trainer: Instantiates the Trainer with all components

  5. Train: Runs the training loop until completion

Launching Training#

To start training, run:

python embodichain/agents/rl/train.py --config configs/agents/rl/push_cube/train_config.json

Outputs#

All outputs are written to ./outputs/<exp_name>_<timestamp>/:

  • logs/: TensorBoard logs

  • checkpoints/: Model checkpoints
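
To monitor a run while it trains, you can point TensorBoard at the log directory, e.g. tensorboard --logdir outputs/push_cube_ppo_<timestamp>/logs.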

Training Process#

The training process follows this sequence (a schematic sketch follows the list):

  1. Rollout Phase: Algorithm collects trajectories by interacting with the environment (via collect_rollout). During this phase, the trainer performs dense per-step logging of rewards and metrics from environment info.

  2. GAE Computation: Algorithm computes advantages and returns using Generalized Advantage Estimation (internal to algorithm, stored in buffer extras)

  3. Update Phase: Algorithm updates the policy using collected data (e.g., PPO)

  4. Logging: Trainer logs training losses and aggregated metrics to TensorBoard and Weights & Biases

  5. Evaluation (periodic): Trainer evaluates the current policy

  6. Checkpointing (periodic): Trainer saves model checkpoints
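
Schematically, one iteration of this loop corresponds to the sketch below. It is illustrative pseudostructure, not the actual Trainer code; obs stands for the current batched observation, and the periodic steps follow the configured eval_freq / save_freq.

# Illustrative sketch of one training iteration (not the actual Trainer code).
for iteration in range(iterations):
    # 1-2. Rollout + GAE: the algorithm fills its buffer and attaches
    #      advantages/returns via the buffer extras.
    algo.collect_rollout(env, policy, obs, num_steps=rollout_steps)

    # 3. Update phase: e.g. PPO epochs over minibatches drawn from the buffer.
    update_stats = algo.update()

    # 4. Logging: the trainer writes losses and aggregated metrics to
    #    TensorBoard (and Weights & Biases if enabled).
    # 5-6. At the configured eval_freq / save_freq, the trainer evaluates the
    #      policy on eval_env and saves a checkpoint.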

Policy Interface#

All policies must inherit from the Policy abstract base class:

from abc import ABC, abstractmethod
from typing import Tuple

import torch
import torch.nn as nn


class Policy(nn.Module, ABC):
    device: torch.device

    @abstractmethod
    def get_action(
        self, obs: torch.Tensor, deterministic: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Returns (action, log_prob, value)."""
        raise NotImplementedError

    @abstractmethod
    def get_value(self, obs: torch.Tensor) -> torch.Tensor:
        """Returns value estimate."""
        raise NotImplementedError

    @abstractmethod
    def evaluate_actions(
        self, obs: torch.Tensor, actions: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Returns (log_prob, entropy, value)."""
        raise NotImplementedError
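
In a typical on-policy flow, get_action and get_value are used while collecting rollouts, and evaluate_actions is called during the update to recompute log-probabilities, entropies, and values for the stored observation/action batches.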

Available Policies#

  • ActorCritic: MLP-based Gaussian policy with learnable log_std. Requires external actor and critic modules to be provided (defined in JSON config).

  • VLAPlaceholderPolicy: Placeholder for Vision-Language-Action policies

Algorithms#

Available Algorithms#

  • PPO: Proximal Policy Optimization with GAE
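
For reference, PPO's update minimizes a clipped surrogate objective combined with a value loss and an entropy bonus. The sketch below is the textbook form of that per-minibatch loss, written with the coefficient names from the algorithm config; it is not an excerpt of the framework's implementation.

import torch


def ppo_loss(log_prob, old_log_prob, advantages, returns, values, entropy,
             clip_coef=0.2, ent_coef=0.01, vf_coef=0.5):
    """Textbook PPO loss, using the coefficient names from the algorithm config."""
    ratio = torch.exp(log_prob - old_log_prob)                   # pi_new / pi_old per sample
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - clip_coef, 1.0 + clip_coef) * advantages
    policy_loss = -torch.min(unclipped, clipped).mean()          # clipped surrogate objective
    value_loss = ((values - returns) ** 2).mean()                # critic regression to returns
    entropy_bonus = entropy.mean()                               # encourages exploration
    return policy_loss + vf_coef * value_loss - ent_coef * entropy_bonus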

Adding a New Algorithm#

To add a new algorithm:

  1. Create a new algorithm class in embodichain/agents/rl/algo/

  2. Implement initialize_buffer(), collect_rollout(), and update() methods

  3. Register in algo/__init__.py:

import torch

from embodichain.agents.rl.algo import BaseAlgorithm, register_algo
from embodichain.agents.rl.buffer import RolloutBuffer


@register_algo("my_algo")
class MyAlgorithm(BaseAlgorithm):
    def __init__(self, cfg, policy):
        self.cfg = cfg
        self.policy = policy
        self.device = torch.device(cfg.device)
        self.buffer = None

    def initialize_buffer(self, num_steps, num_envs, obs_dim, action_dim):
        """Initialize the algorithm's buffer."""
        self.buffer = RolloutBuffer(num_steps, num_envs, obs_dim, action_dim, self.device)

    def collect_rollout(self, env, policy, obs, num_steps, on_step_callback=None):
        """Control data collection (interact with env, fill buffer, compute advantages/returns)."""
        # Collect trajectories
        # Compute advantages/returns (e.g., GAE for on-policy algorithms)
        # Attach extras to buffer: self.buffer.set_extras({"advantages": adv, "returns": ret})
        # Return an empty dict (dense logging is handled in the trainer)
        return {}

    def update(self):
        """Update the policy using collected data."""
        # Access extras from buffer: self.buffer._extras.get("advantages")
        # Use self.buffer to update the policy
        return {"loss": 0.0}

Adding a New Policy#

To add a new policy:

  1. Create a new policy class inheriting from the Policy abstract base class

  2. Register in models/__init__.py:

from embodichain.agents.rl.models import register_policy, Policy


@register_policy("my_policy")
class MyPolicy(Policy):
    def __init__(self, obs_space, action_space, device, config):
        super().__init__()
        self.device = device
        # Initialize your networks here

    def get_action(self, obs, deterministic=False):
        ...

    def get_value(self, obs):
        ...

    def evaluate_actions(self, obs, actions):
        ...

Adding a New Environment#

To add a new RL environment:

  1. Create an environment class inheriting from EmbodiedEnv

  2. Register it with the Gymnasium registry:

from embodichain.lab.gym.utils.registration import register_env

@register_env("MyTaskRL", max_episode_steps=100, override=True)
class MyTaskEnv(EmbodiedEnv):
    cfg: MyTaskEnvCfg
    ...

  3. Use the environment ID in your JSON config:

"env": {
  "id": "MyTaskRL",
  "cfg": {
    ...
  }
}

Best Practices#

  • Device Management: The device is single-sourced from the device setting in the trainer config (e.g., “cuda:0”). All components (trainer/algorithm/policy/env) share the same device.

  • Action Scaling: Keep action scaling in the environment, not in the policy.

  • Observation Format: Environments should provide consistent observation shape/types (torch.float32) and a single done = terminated | truncated.

  • Algorithm Interface: Algorithms must implement initialize_buffer(), collect_rollout(), and update() methods. The algorithm completely controls data collection and buffer management.

  • Reward Components: Organize reward components in an info["rewards"] dictionary and metrics in an info["metrics"] dictionary. The trainer performs dense per-step logging directly from environment info (see the sketch after this list).

  • Configuration: Use JSON for all hyperparameters. This makes experiments reproducible and easy to track.

  • Logging: Metrics are automatically logged to TensorBoard and Weights & Biases. Check outputs/<exp_name>_<timestamp>/logs/ for TensorBoard logs.

  • Checkpoints: Regular checkpoints are saved to outputs/<exp_name>_<timestamp>/checkpoints/. Use these to resume training or evaluate policies.
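
The snippet below is a minimal sketch of what a conforming environment step might return, assuming the batched torch conventions above; the observation dimension and the keys inside "rewards"/"metrics" are illustrative, not part of the framework.

import torch

num_envs = 4
obs = torch.zeros(num_envs, 32, dtype=torch.float32)   # consistent shape and dtype
reward = torch.zeros(num_envs)
terminated = torch.zeros(num_envs, dtype=torch.bool)
truncated = torch.zeros(num_envs, dtype=torch.bool)
done = terminated | truncated                          # single done signal

info = {
    "rewards": {                                       # per-component reward terms (illustrative keys)
        "reach": torch.zeros(num_envs),
        "push": torch.zeros(num_envs),
    },
    "metrics": {                                       # task metrics (illustrative keys)
        "success": torch.zeros(num_envs, dtype=torch.bool),
    },
}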