Reinforcement Learning Training#
This tutorial shows you how to train reinforcement learning agents using EmbodiChain’s RL framework. You’ll learn how to configure training via JSON, set up environments, policies, and algorithms, and launch training sessions.
Overview#
The RL framework provides a modular, extensible stack for robotics tasks:
Trainer: Orchestrates the training loop (calls algorithm for data collection and updates, handles logging/eval/save)
Algorithm: Controls data collection process (interacts with environment, fills buffer, computes advantages/returns) and updates the policy (e.g., PPO)
Policy: Neural network models implementing a unified interface (get_action/get_value/evaluate_actions)
Buffer: On-policy rollout storage and minibatch iterator (managed by algorithm)
Env Factory: Build environments from a JSON config via registry
Architecture#
The framework follows a clean separation of concerns among these components.
The core components and their relationships:
Trainer → Policy, Env, Algorithm (via callbacks for statistics)
Algorithm → Policy, RolloutBuffer (algorithm manages its own buffer)
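The sketch below illustrates this division of responsibilities. It is schematic pseudocode of the control flow, not the actual Trainer class; only collect_rollout() and update() correspond to the algorithm interface documented later on this page, and the logging/eval/save hooks are placeholders.

# Illustrative sketch of the trainer/algorithm split (not the real Trainer class).
class SketchTrainer:
    def __init__(self, env, policy, algorithm, num_steps):
        self.env, self.policy, self.algorithm = env, policy, algorithm
        self.num_steps = num_steps

    def train(self, iterations):
        obs, _ = self.env.reset()  # Gymnasium-style reset
        for it in range(iterations):
            # The algorithm controls data collection: it steps the environment,
            # fills its own rollout buffer, and computes advantages/returns.
            self.algorithm.collect_rollout(self.env, self.policy, obs, self.num_steps)
            # The algorithm then updates the policy from its buffer (e.g., PPO epochs).
            losses = self.algorithm.update()
            # The trainer logs `losses`, runs periodic evaluation, and saves checkpoints here.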
Configuration via JSON#
Training is configured via a JSON file that defines runtime settings, environment, policy, and algorithm parameters.
Example Configuration#
The configuration file (e.g., train_config.json) is located in configs/agents/rl/push_cube:
Example: train_config.json
{
    "trainer": {
        "exp_name": "push_cube_ppo",
        "gym_config": "configs/agents/rl/push_cube/gym_config.json",
        "seed": 42,
        "device": "cuda:0",
        "headless": true,
        "iterations": 1000,
        "rollout_steps": 1024,
        "eval_freq": 2,
        "save_freq": 200,
        "use_wandb": false,
        "wandb_project_name": "embodychain-push_cube",
        "events": {
            "eval": {
                "record_camera": {
                    "func": "record_camera_data_async",
                    "mode": "interval",
                    "interval_step": 1,
                    "params": {
                        "name": "main_cam",
                        "resolution": [640, 480],
                        "eye": [-1.4, 1.4, 2.0],
                        "target": [0, 0, 0],
                        "up": [0, 0, 1],
                        "intrinsics": [600, 600, 320, 240],
                        "save_path": "./outputs/videos/eval"
                    }
                }
            }
        }
    },
    "policy": {
        "name": "actor_critic",
        "actor": {
            "type": "mlp",
            "network_cfg": {
                "hidden_sizes": [256, 256],
                "activation": "relu"
            }
        },
        "critic": {
            "type": "mlp",
            "network_cfg": {
                "hidden_sizes": [256, 256],
                "activation": "relu"
            }
        }
    },
    "algorithm": {
        "name": "ppo",
        "cfg": {
            "learning_rate": 0.0001,
            "n_epochs": 10,
            "batch_size": 8192,
            "gamma": 0.99,
            "gae_lambda": 0.95,
            "clip_coef": 0.2,
            "ent_coef": 0.01,
            "vf_coef": 0.5,
            "max_grad_norm": 0.5
        }
    }
}
Configuration Sections#
Runtime Settings#
The trainer section controls experiment setup:
exp_name: Experiment name (used for output directories)
gym_config: Path to the environment (gym) configuration JSON
seed: Random seed for reproducibility
device: Compute device string such as "cpu" or "cuda:0" (default: "cpu")
headless: Whether to run the simulation in headless mode
iterations: Number of training iterations
rollout_steps: Steps per rollout (e.g., 1024)
eval_freq: Frequency of evaluation (in steps)
save_freq: Frequency of checkpoint saving (in steps)
use_wandb: Whether to enable Weights & Biases logging
wandb_project_name: Weights & Biases project name
events: Optional train/eval event definitions (e.g., asynchronous camera recording during evaluation)
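For reference, these keys are read from the trainer block in train.py with the defaults shown below (an excerpt mirroring the full script later on this page):

# How train.py consumes the trainer block (defaults match the full script below).
import json

with open("configs/agents/rl/push_cube/train_config.json") as f:
    cfg_json = json.load(f)

trainer_cfg = cfg_json["trainer"]
exp_name = trainer_cfg.get("exp_name", "generic_exp")
seed = int(trainer_cfg.get("seed", 1))
device_str = trainer_cfg.get("device", "cpu")      # e.g. "cuda:0"
iterations = int(trainer_cfg.get("iterations", 250))
rollout_steps = int(trainer_cfg.get("rollout_steps", 2048))
eval_freq = int(trainer_cfg.get("eval_freq", 10000))
save_freq = int(trainer_cfg.get("save_freq", 50000))
headless = bool(trainer_cfg.get("headless", True))
use_wandb = trainer_cfg.get("use_wandb", False)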
Environment Configuration#
The env section, part of the gym configuration file referenced by trainer.gym_config, defines the task environment:
id: Environment registry ID (e.g., “PushCubeRL”)
cfg: Environment-specific configuration parameters
Example:
"env": {
"id": "PushCubeRL",
"cfg": {
"num_envs": 4,
"obs_mode": "state",
"episode_length": 100,
"action_scale": 0.1,
"success_threshold": 0.1
}
}
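In train.py, the environment is built from the gym config referenced by trainer.gym_config: the JSON is loaded, converted to an RL env config, and passed to the registry-based factory. The excerpt below follows the script; the path is the one from the example configuration.

from embodichain.lab.gym.envs.tasks.rl import build_env
from embodichain.lab.gym.utils.gym_utils import config_to_rl_cfg
from embodichain.utils.utility import load_json

# Load the gym config, convert it to an RL env cfg, then build the env.
gym_config_data = load_json("configs/agents/rl/push_cube/gym_config.json")
gym_env_cfg = config_to_rl_cfg(gym_config_data)
env = build_env(gym_env_cfg.env_id, base_env_cfg=gym_env_cfg)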
Policy Configuration#
The policy section defines the neural network policy:
name: Policy name (e.g., “actor_critic”, “vla”)
cfg: Policy-specific hyperparameters (empty for actor_critic)
actor: Actor network configuration (required for actor_critic)
critic: Critic network configuration (required for actor_critic)
Example:
"policy": {
"name": "actor_critic",
"cfg": {},
"actor": {
"type": "mlp",
"hidden_sizes": [256, 256],
"activation": "relu"
},
"critic": {
"type": "mlp",
"hidden_sizes": [256, 256],
"activation": "relu"
}
}
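For actor_critic, the actor and critic blocks are turned into MLPs and handed to the policy registry. The sketch below is adapted from train.py and assumes env, device, and policy_block (the parsed policy section) are already defined as in that script:

from embodichain.agents.rl.models import build_mlp_from_cfg, build_policy

# Infer input/output sizes from the environment spaces.
obs_dim = env.observation_space.shape[-1]
action_dim = env.action_space.shape[-1]

# Build the actor/critic MLPs from their JSON blocks, then the policy itself.
actor = build_mlp_from_cfg(policy_block["actor"], obs_dim, action_dim)
critic = build_mlp_from_cfg(policy_block["critic"], obs_dim, 1)
policy = build_policy(
    policy_block,
    env.observation_space,
    env.action_space,
    device,
    actor=actor,
    critic=critic,
)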
Algorithm Configuration#
The algorithm section defines the RL algorithm:
name: Algorithm name (e.g., “ppo”)
cfg: Algorithm-specific hyperparameters
Example:
"algorithm": {
"name": "ppo",
"cfg": {
"learning_rate": 0.0001,
"n_epochs": 10,
"batch_size": 64,
"gamma": 0.99,
"gae_lambda": 0.95,
"clip_coef": 0.2,
"ent_coef": 0.01,
"vf_coef": 0.5,
"max_grad_norm": 0.5
}
}
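To make the roles of clip_coef, vf_coef, and ent_coef concrete, here is a minimal sketch of the standard PPO clipped objective these hyperparameters feed into; the in-tree PPO implementation may differ in details (e.g., advantage normalization or value clipping):

import torch

def ppo_loss(new_log_prob, old_log_prob, advantages, values, returns, entropy,
             clip_coef=0.2, ent_coef=0.01, vf_coef=0.5):
    # Probability ratio between the updated policy and the rollout policy.
    ratio = torch.exp(new_log_prob - old_log_prob)
    # Clipped surrogate objective.
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - clip_coef, 1.0 + clip_coef) * advantages
    policy_loss = -torch.min(unclipped, clipped).mean()
    # Value-function regression to the (GAE) returns.
    value_loss = ((values - returns) ** 2).mean()
    # Entropy bonus encourages exploration.
    entropy_loss = -entropy.mean()
    return policy_loss + vf_coef * value_loss + ent_coef * entropy_loss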
Training Script#
The training script (train.py) is located in embodichain/agents/rl/:
Code for train.py
# ----------------------------------------------------------------------------
# Copyright (c) 2021-2025 DexForce Technology Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ----------------------------------------------------------------------------

import argparse
import os
import time
from pathlib import Path

import numpy as np
import torch
import wandb
import json
from torch.utils.tensorboard import SummaryWriter
from copy import deepcopy

from embodichain.agents.rl.models import build_policy, get_registered_policy_names
from embodichain.agents.rl.models import build_mlp_from_cfg
from embodichain.agents.rl.algo import build_algo, get_registered_algo_names
from embodichain.agents.rl.utils.trainer import Trainer
from embodichain.utils import logger
from embodichain.lab.gym.envs.tasks.rl import build_env
from embodichain.lab.gym.utils.gym_utils import config_to_rl_cfg
from embodichain.utils.utility import load_json
from embodichain.utils.module_utils import find_function_from_modules
from embodichain.lab.sim import SimulationManagerCfg
from embodichain.lab.gym.envs.managers.cfg import EventCfg


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True, help="Path to JSON config")
    args = parser.parse_args()

    with open(args.config, "r") as f:
        cfg_json = json.load(f)

    trainer_cfg = cfg_json["trainer"]
    policy_block = cfg_json["policy"]
    algo_block = cfg_json["algorithm"]

    # Runtime
    exp_name = trainer_cfg.get("exp_name", "generic_exp")
    seed = int(trainer_cfg.get("seed", 1))
    device_str = trainer_cfg.get("device", "cpu")
    iterations = int(trainer_cfg.get("iterations", 250))
    rollout_steps = int(trainer_cfg.get("rollout_steps", 2048))
    eval_freq = int(trainer_cfg.get("eval_freq", 10000))
    save_freq = int(trainer_cfg.get("save_freq", 50000))
    headless = bool(trainer_cfg.get("headless", True))
    wandb_project_name = trainer_cfg.get("wandb_project_name", "embodychain-generic")

    # Device
    if not isinstance(device_str, str):
        raise ValueError(
            f"runtime.device must be a string such as 'cpu' or 'cuda:0'. Got: {device_str!r}"
        )
    try:
        device = torch.device(device_str)
    except RuntimeError as exc:
        raise ValueError(
            f"Failed to parse runtime.device='{device_str}': {exc}"
        ) from exc

    if device.type == "cuda":
        if not torch.cuda.is_available():
            raise ValueError(
                "CUDA device requested but torch.cuda.is_available() is False."
            )
        index = (
            device.index if device.index is not None else torch.cuda.current_device()
        )
        device_count = torch.cuda.device_count()
        if index < 0 or index >= device_count:
            raise ValueError(
                f"CUDA device index {index} is out of range (available devices: {device_count})."
            )
        torch.cuda.set_device(index)
        device = torch.device(f"cuda:{index}")
    elif device.type != "cpu":
        raise ValueError(f"Unsupported device type: {device}")
    logger.log_info(f"Device: {device}")

    # Seeds
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    if device.type == "cuda":
        torch.cuda.manual_seed_all(seed)

    # Outputs
    run_stamp = time.strftime("%Y%m%d_%H%M%S")
    run_base = os.path.join("outputs", f"{exp_name}_{run_stamp}")
    log_dir = os.path.join(run_base, "logs")
    checkpoint_dir = os.path.join(run_base, "checkpoints")
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(checkpoint_dir, exist_ok=True)
    writer = SummaryWriter(f"{log_dir}/{exp_name}")

    # Initialize Weights & Biases (optional)
    use_wandb = trainer_cfg.get("use_wandb", False)
    if use_wandb:
        wandb.init(project=wandb_project_name, name=exp_name, config=cfg_json)

    gym_config_path = Path(trainer_cfg["gym_config"])
    logger.log_info(f"Current working directory: {Path.cwd()}")

    gym_config_data = load_json(str(gym_config_path))
    gym_env_cfg = config_to_rl_cfg(gym_config_data)

    # Ensure sim configuration mirrors runtime overrides
    if gym_env_cfg.sim_cfg is None:
        gym_env_cfg.sim_cfg = SimulationManagerCfg()
    if device.type == "cuda":
        gpu_index = device.index
        if gpu_index is None:
            gpu_index = torch.cuda.current_device()
        gym_env_cfg.sim_cfg.sim_device = torch.device(f"cuda:{gpu_index}")
        if hasattr(gym_env_cfg.sim_cfg, "gpu_id"):
            gym_env_cfg.sim_cfg.gpu_id = gpu_index
    else:
        gym_env_cfg.sim_cfg.sim_device = torch.device("cpu")
    gym_env_cfg.sim_cfg.headless = headless

    logger.log_info(
        f"Loaded gym_config from {gym_config_path} (env_id={gym_env_cfg.env_id}, headless={gym_env_cfg.sim_cfg.headless}, sim_device={gym_env_cfg.sim_cfg.sim_device})"
    )

    env = build_env(gym_env_cfg.env_id, base_env_cfg=gym_env_cfg)

    eval_gym_env_cfg = deepcopy(gym_env_cfg)
    eval_gym_env_cfg.num_envs = 4
    eval_gym_env_cfg.sim_cfg.headless = True

    eval_env = build_env(eval_gym_env_cfg.env_id, base_env_cfg=eval_gym_env_cfg)

    # Build Policy via registry (actor/critic must be explicitly defined in JSON when using actor_critic)
    policy_name = policy_block["name"]
    if policy_name.lower() == "actor_critic":
        obs_dim = env.observation_space.shape[-1]
        action_dim = env.action_space.shape[-1]

        actor_cfg = policy_block.get("actor")
        critic_cfg = policy_block.get("critic")
        if actor_cfg is None or critic_cfg is None:
            raise ValueError(
                "ActorCritic requires 'actor' and 'critic' definitions in JSON (policy.actor / policy.critic)."
            )

        actor = build_mlp_from_cfg(actor_cfg, obs_dim, action_dim)
        critic = build_mlp_from_cfg(critic_cfg, obs_dim, 1)

        policy = build_policy(
            policy_block,
            env.observation_space,
            env.action_space,
            device,
            actor=actor,
            critic=critic,
        )
    else:
        policy = build_policy(
            policy_block, env.observation_space, env.action_space, device
        )

    # Build Algorithm via factory
    algo_name = algo_block["name"].lower()
    algo_cfg = algo_block["cfg"]
    algo = build_algo(algo_name, algo_cfg, policy, device)

    # Build Trainer
    event_modules = [
        "embodichain.lab.gym.envs.managers.randomization",
        "embodichain.lab.gym.envs.managers.record",
        "embodichain.lab.gym.envs.managers.events",
    ]
    events_dict = trainer_cfg.get("events", {})
    train_event_cfg = {}
    eval_event_cfg = {}
    # Parse train events
    for event_name, event_info in events_dict.get("train", {}).items():
        event_func_str = event_info.get("func")
        mode = event_info.get("mode", "interval")
        params = event_info.get("params", {})
        interval_step = event_info.get("interval_step", 1)
        event_func = find_function_from_modules(
            event_func_str, event_modules, raise_if_not_found=True
        )
        train_event_cfg[event_name] = EventCfg(
            func=event_func,
            mode=mode,
            params=params,
            interval_step=interval_step,
        )
    # Parse eval events
    for event_name, event_info in events_dict.get("eval", {}).items():
        event_func_str = event_info.get("func")
        mode = event_info.get("mode", "interval")
        params = event_info.get("params", {})
        interval_step = event_info.get("interval_step", 1)
        event_func = find_function_from_modules(
            event_func_str, event_modules, raise_if_not_found=True
        )
        eval_event_cfg[event_name] = EventCfg(
            func=event_func,
            mode=mode,
            params=params,
            interval_step=interval_step,
        )
    trainer = Trainer(
        policy=policy,
        env=env,
        algorithm=algo,
        num_steps=rollout_steps,
        batch_size=algo_cfg["batch_size"],
        writer=writer,
        eval_freq=eval_freq,
        save_freq=save_freq,
        checkpoint_dir=checkpoint_dir,
        exp_name=exp_name,
        use_wandb=use_wandb,
        eval_env=eval_env,
        event_cfg=train_event_cfg,
        eval_event_cfg=eval_event_cfg,
    )

    logger.log_info("Generic training initialized")
    logger.log_info(f"Task: {type(env).__name__}")
    logger.log_info(
        f"Policy: {policy_name} (available: {get_registered_policy_names()})"
    )
    logger.log_info(
        f"Algorithm: {algo_name} (available: {get_registered_algo_names()})"
    )

    total_steps = int(iterations * rollout_steps * env.num_envs)
    logger.log_info(f"Total steps: {total_steps} (iterations≈{iterations})")

    try:
        trainer.train(total_steps)
    except KeyboardInterrupt:
        logger.log_info("Training interrupted by user")
    finally:
        trainer.save_checkpoint()
        writer.close()
        if use_wandb:
            try:
                wandb.finish()
            except Exception:
                pass
        logger.log_info("Training finished")


if __name__ == "__main__":
    main()
The Script Explained#
The training script performs the following steps:
Parse Configuration: Loads the JSON config and extracts the trainer, policy, and algorithm blocks (the environment config is loaded from the file referenced by trainer.gym_config)
Setup: Initializes the device, seeds, output directories, TensorBoard, and Weights & Biases
Build Components: Environment via the build_env() factory, policy via the build_policy() registry, algorithm via the build_algo() factory
Create Trainer: Instantiates the Trainer with all components
Train: Runs the training loop until completion
Launching Training#
To start training, run:
python embodichain/agents/rl/train.py --config configs/agents/rl/push_cube/train_config.json
Outputs#
All outputs are written to ./outputs/<exp_name>_<timestamp>/:
logs/: TensorBoard logs
checkpoints/: Model checkpoints
Training Process#
The training process follows this sequence:
Rollout Phase: Algorithm collects trajectories by interacting with the environment (via
collect_rollout). During this phase, the trainer performs dense per-step logging of rewards and metrics from environment info.GAE Computation: Algorithm computes advantages and returns using Generalized Advantage Estimation (internal to algorithm, stored in buffer extras)
Update Phase: Algorithm updates the policy using collected data (e.g., PPO)
Logging: Trainer logs training losses and aggregated metrics to TensorBoard and Weights & Biases
Evaluation (periodic): Trainer evaluates the current policy
Checkpointing (periodic): Trainer saves model checkpoints
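For reference, here is a minimal sketch of standard GAE(λ) as used in step 2; the in-tree implementation may differ in details such as how truncated episodes are bootstrapped. In this sketch, dones[t] == 1 means the episode ended on step t, so no value is bootstrapped from the next state:

import torch

def compute_gae(rewards, values, dones, last_value, gamma=0.99, gae_lambda=0.95):
    """rewards, values, dones: float tensors of shape (num_steps, num_envs);
    last_value: bootstrap value of shape (num_envs,) for the state after the last step."""
    num_steps = rewards.shape[0]
    advantages = torch.zeros_like(rewards)
    last_gae = torch.zeros_like(last_value)
    for t in reversed(range(num_steps)):
        next_value = last_value if t == num_steps - 1 else values[t + 1]
        next_nonterminal = 1.0 - dones[t]
        # TD residual with bootstrapping masked out at episode ends.
        delta = rewards[t] + gamma * next_value * next_nonterminal - values[t]
        last_gae = delta + gamma * gae_lambda * next_nonterminal * last_gae
        advantages[t] = last_gae
    returns = advantages + values
    return advantages, returns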
Policy Interface#
All policies must inherit from the Policy abstract base class:
from abc import ABC, abstractmethod
from typing import Tuple

import torch
import torch.nn as nn


class Policy(nn.Module, ABC):
    device: torch.device

    @abstractmethod
    def get_action(
        self, obs: torch.Tensor, deterministic: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Returns (action, log_prob, value)."""
        raise NotImplementedError

    @abstractmethod
    def get_value(self, obs: torch.Tensor) -> torch.Tensor:
        """Returns value estimate."""
        raise NotImplementedError

    @abstractmethod
    def evaluate_actions(
        self, obs: torch.Tensor, actions: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Returns (log_prob, entropy, value)."""
        raise NotImplementedError
Available Policies#
ActorCritic: MLP-based Gaussian policy with a learnable log_std. Requires external actor and critic modules to be provided (defined in the JSON config); see the sketch below.
VLAPlaceholderPolicy: Placeholder for Vision-Language-Action policies
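A minimal sketch of what an MLP-based Gaussian head with a learnable log_std looks like; this is illustrative only and not the shipped ActorCritic class (which also implements get_value and evaluate_actions per the Policy interface above):

import torch
import torch.nn as nn
from torch.distributions import Normal

class GaussianActorCriticSketch(nn.Module):
    """Illustrative Gaussian policy head wrapping externally built actor/critic MLPs."""

    def __init__(self, actor: nn.Module, critic: nn.Module, action_dim: int):
        super().__init__()
        self.actor = actor          # maps obs -> action mean
        self.critic = critic        # maps obs -> state value
        self.log_std = nn.Parameter(torch.zeros(action_dim))  # learnable log_std

    def get_action(self, obs, deterministic=False):
        mean = self.actor(obs)
        dist = Normal(mean, self.log_std.exp())
        action = mean if deterministic else dist.sample()
        log_prob = dist.log_prob(action).sum(-1)
        value = self.critic(obs).squeeze(-1)
        return action, log_prob, value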
Algorithms#
Available Algorithms#
PPO: Proximal Policy Optimization with GAE
Adding a New Algorithm#
To add a new algorithm:
Create a new algorithm class in embodichain/agents/rl/algo/
Implement the initialize_buffer(), collect_rollout(), and update() methods
Register it in algo/__init__.py:
import torch

from embodichain.agents.rl.algo import BaseAlgorithm, register_algo
from embodichain.agents.rl.buffer import RolloutBuffer


@register_algo("my_algo")
class MyAlgorithm(BaseAlgorithm):
    def __init__(self, cfg, policy):
        self.cfg = cfg
        self.policy = policy
        self.device = torch.device(cfg.device)
        self.buffer = None

    def initialize_buffer(self, num_steps, num_envs, obs_dim, action_dim):
        """Initialize the algorithm's buffer."""
        self.buffer = RolloutBuffer(num_steps, num_envs, obs_dim, action_dim, self.device)

    def collect_rollout(self, env, policy, obs, num_steps, on_step_callback=None):
        """Control data collection (interact with env, fill buffer, compute advantages/returns)."""
        # Collect trajectories
        # Compute advantages/returns (e.g., GAE for on-policy algorithms)
        # Attach extras to buffer: self.buffer.set_extras({"advantages": adv, "returns": ret})
        # Return an empty dict (dense logging is handled in the trainer)
        return {}

    def update(self):
        """Update the policy using collected data."""
        # Access extras from buffer: self.buffer._extras.get("advantages")
        # Use self.buffer to update the policy
        return {"loss": 0.0}
Adding a New Policy#
To add a new policy:
Create a new policy class inheriting from the Policy abstract base class
Register it in models/__init__.py:
from embodichain.agents.rl.models import register_policy, Policy
@register_policy("my_policy")
class MyPolicy(Policy):
def __init__(self, obs_space, action_space, device, config):
super().__init__()
self.device = device
# Initialize your networks here
def get_action(self, obs, deterministic=False):
...
def get_value(self, obs):
...
def evaluate_actions(self, obs, actions):
...
Adding a New Environment#
To add a new RL environment:
Create an environment class inheriting from EmbodiedEnv
Register it with the Gymnasium registry:
from embodichain.lab.gym.utils.registration import register_env
@register_env("MyTaskRL", max_episode_steps=100, override=True)
class MyTaskEnv(EmbodiedEnv):
cfg: MyTaskEnvCfg
...
Use the environment ID in your JSON config:
"env": {
"id": "MyTaskRL",
"cfg": {
...
}
}
Best Practices#
Device Management: The device is single-sourced from trainer.device. All components (trainer/algorithm/policy/env) share the same device.
Action Scaling: Keep action scaling in the environment, not in the policy.
Observation Format: Environments should provide consistent observation shapes/types (torch.float32) and a single done = terminated | truncated.
Algorithm Interface: Algorithms must implement the initialize_buffer(), collect_rollout(), and update() methods. The algorithm fully controls data collection and buffer management.
Reward Components: Organize reward components in the info["rewards"] dictionary and metrics in the info["metrics"] dictionary; see the example after this list. The trainer performs dense per-step logging directly from environment info.
Configuration: Use JSON for all hyperparameters. This makes experiments reproducible and easy to track.
Logging: Metrics are automatically logged to TensorBoard and Weights & Biases. Check outputs/<exp_name>_<timestamp>/logs/ for TensorBoard logs.
Checkpoints: Regular checkpoints are saved to outputs/<exp_name>_<timestamp>/checkpoints/. Use these to resume training or evaluate policies.
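As an illustration of this layout, a step's info might be structured as below; the component names are hypothetical examples, not a fixed schema:

import torch

num_envs = 4
# Hypothetical info dict returned by env.step(); keys inside "rewards" and
# "metrics" are illustrative examples only.
info = {
    "rewards": {
        "reach": torch.zeros(num_envs),            # per-component shaped reward terms
        "push": torch.zeros(num_envs),
    },
    "metrics": {
        "success": torch.zeros(num_envs),          # per-env task metrics
        "cube_to_goal_dist": torch.ones(num_envs),
    },
}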