Reinforcement Learning Training#
This tutorial shows you how to train reinforcement learning agents using EmbodiChain’s RL framework. You’ll learn how to configure training via JSON, set up environments, policies, and algorithms, and launch training sessions.
Overview#
The RL framework provides a modular, extensible stack for robotics tasks:
Trainer: Orchestrates the training loop (calls algorithm for data collection and updates, handles logging/eval/save)
Algorithm: Controls data collection process (interacts with environment, fills buffer, computes advantages/returns) and updates the policy (e.g., PPO)
Policy: Neural network models implementing a unified interface (get_action/get_value/evaluate_actions)
Buffer: On-policy rollout storage and minibatch iterator (managed by algorithm)
Env Factory: Builds environments from a JSON config via a registry
Architecture#
The framework follows a clean separation of concerns. The core components and their relationships (a conceptual sketch of the training loop follows this list):
Trainer → Policy, Env, Algorithm (via callbacks for statistics)
Algorithm → Policy, RolloutBuffer (algorithm manages its own buffer)
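Conceptually, one training iteration wires these pieces together as sketched below. This is a rough, illustrative sketch only: run_training and the trainer_like helper (log / maybe_evaluate / maybe_save) are hypothetical stand-ins for the Trainer's bookkeeping, while collect_rollout and update follow the algorithm interface described later in this page.

def run_training(trainer_like, algorithm, policy, env, obs, iterations, rollout_steps):
    for it in range(iterations):
        # Algorithm: step the environment, fill its rollout buffer,
        # and compute advantages/returns internally.
        # (The real trainer tracks the latest obs between rollouts.)
        algorithm.collect_rollout(env, policy, obs, rollout_steps)
        # Algorithm: optimize the policy from its buffer (e.g., PPO epochs).
        stats = algorithm.update()
        # Trainer-side bookkeeping: logging, periodic evaluation, checkpoints.
        trainer_like.log(stats, step=it)
        trainer_like.maybe_evaluate(step=it)  # every eval_freq
        trainer_like.maybe_save(step=it)      # every save_freq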
Configuration via JSON#
Training is configured via a JSON file that defines runtime settings, environment, policy, and algorithm parameters.
Example Configuration#
The configuration file (e.g., train_config.json) is located in configs/agents/rl/push_cube:
Example: train_config.json
{
  "trainer": {
    "exp_name": "push_cube_ppo",
    "gym_config": "configs/agents/rl/push_cube/gym_config.json",
    "seed": 42,
    "device": "cuda:0",
    "headless": false,
    "enable_rt": false,
    "gpu_id": 0,
    "num_envs": 8,
    "iterations": 1000,
    "rollout_steps": 1024,
    "eval_freq": 200,
    "save_freq": 200,
    "use_wandb": true,
    "wandb_project_name": "embodychain-push_cube",
    "events": {
      "eval": {
        "record_camera": {
          "func": "record_camera_data_async",
          "mode": "interval",
          "interval_step": 1,
          "params": {
            "name": "main_cam",
            "resolution": [640, 480],
            "eye": [-1.4, 1.4, 2.0],
            "target": [0, 0, 0],
            "up": [0, 0, 1],
            "intrinsics": [600, 600, 320, 240],
            "save_path": "./outputs/videos/eval"
          }
        }
      }
    }
  },
  "policy": {
    "name": "actor_critic",
    "actor": {
      "type": "mlp",
      "network_cfg": {
        "hidden_sizes": [256, 256],
        "activation": "relu"
      }
    },
    "critic": {
      "type": "mlp",
      "network_cfg": {
        "hidden_sizes": [256, 256],
        "activation": "relu"
      }
    }
  },
  "algorithm": {
    "name": "ppo",
    "cfg": {
      "learning_rate": 0.0001,
      "n_epochs": 10,
      "batch_size": 8192,
      "gamma": 0.99,
      "gae_lambda": 0.95,
      "clip_coef": 0.2,
      "ent_coef": 0.01,
      "vf_coef": 0.5,
      "max_grad_norm": 0.5
    }
  }
}
Configuration Sections#
Runtime Settings#
The trainer section controls experiment setup and runtime behavior:
exp_name: Experiment name (used for output directories)
gym_config: Path to the environment (gym) JSON config
seed: Random seed for reproducibility
device: Compute device string such as "cpu" or "cuda:0"
headless: Whether to run the simulation in headless mode
enable_rt: Simulator flag forwarded to sim_cfg.enable_rt
gpu_id: GPU index forwarded to sim_cfg.gpu_id
num_envs: Number of parallel environments (overrides the gym config value when set)
iterations: Number of training iterations
rollout_steps: Steps per rollout (e.g., 1024)
eval_freq: Frequency of evaluation (in steps)
save_freq: Frequency of checkpoint saving (in steps)
use_wandb: Whether to enable Weights & Biases logging
wandb_project_name: Weights & Biases project name
events: Optional event hooks for train and eval (e.g., asynchronous camera recording during evaluation)
Environment Configuration#
The environment section, supplied through the gym config file referenced by trainer.gym_config, defines the task environment:
id: Environment registry ID (e.g., “PushCubeRL”)
cfg: Environment-specific configuration parameters
Example:
"env": {
"id": "PushCubeRL",
"cfg": {
"num_envs": 4,
"obs_mode": "state",
"episode_length": 100,
"action_scale": 0.1,
"success_threshold": 0.1
}
}
Policy Configuration#
The policy section defines the neural network policy:
name: Policy name (e.g., “actor_critic”, “vla”)
cfg: Policy-specific hyperparameters (empty for actor_critic)
actor: Actor network configuration (required for actor_critic)
critic: Critic network configuration (required for actor_critic)
Example:
"policy": {
"name": "actor_critic",
"cfg": {},
"actor": {
"type": "mlp",
"hidden_sizes": [256, 256],
"activation": "relu"
},
"critic": {
"type": "mlp",
"hidden_sizes": [256, 256],
"activation": "relu"
}
}
Algorithm Configuration#
The algorithm section defines the RL algorithm:
name: Algorithm name (e.g., “ppo”)
cfg: Algorithm-specific hyperparameters
Example:
"algorithm": {
"name": "ppo",
"cfg": {
"learning_rate": 0.0001,
"n_epochs": 10,
"batch_size": 64,
"gamma": 0.99,
"gae_lambda": 0.95,
"clip_coef": 0.2,
"ent_coef": 0.01,
"vf_coef": 0.5,
"max_grad_norm": 0.5
}
}
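Note that each rollout collects rollout_steps × num_envs transitions; in the push_cube example above that is 1024 × 8 = 8192. Assuming batch_size is the minibatch size drawn from the rollout buffer, a batch_size of 8192 processes the whole rollout in one minibatch per epoch, whereas a smaller value such as 64 splits the rollout into many minibatches per epoch.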
Training Script#
The training script (train.py) is located in embodichain/agents/rl/:
Code for train.py
# ----------------------------------------------------------------------------
# Copyright (c) 2021-2025 DexForce Technology Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ----------------------------------------------------------------------------

import argparse
import os
import time
from pathlib import Path

import numpy as np
import torch
import wandb
import json
from torch.utils.tensorboard import SummaryWriter
from copy import deepcopy

from embodichain.agents.rl.models import build_policy, get_registered_policy_names
from embodichain.agents.rl.models import build_mlp_from_cfg
from embodichain.agents.rl.algo import build_algo, get_registered_algo_names
from embodichain.agents.rl.utils.trainer import Trainer
from embodichain.utils import logger
from embodichain.lab.gym.envs.tasks.rl import build_env
from embodichain.lab.gym.utils.gym_utils import config_to_cfg
from embodichain.utils.utility import load_json
from embodichain.utils.module_utils import find_function_from_modules
from embodichain.lab.sim import SimulationManagerCfg
from embodichain.lab.gym.envs.managers.cfg import EventCfg


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True, help="Path to JSON config")
    args = parser.parse_args()

    with open(args.config, "r") as f:
        cfg_json = json.load(f)

    trainer_cfg = cfg_json["trainer"]
    policy_block = cfg_json["policy"]
    algo_block = cfg_json["algorithm"]

    # Runtime
    exp_name = trainer_cfg.get("exp_name", "generic_exp")
    seed = int(trainer_cfg.get("seed", 1))
    device_str = trainer_cfg.get("device", "cpu")
    iterations = int(trainer_cfg.get("iterations", 250))
    rollout_steps = int(trainer_cfg.get("rollout_steps", 2048))
    eval_freq = int(trainer_cfg.get("eval_freq", 10000))
    save_freq = int(trainer_cfg.get("save_freq", 50000))
    headless = bool(trainer_cfg.get("headless", True))
    enable_rt = bool(trainer_cfg.get("enable_rt", False))
    gpu_id = int(trainer_cfg.get("gpu_id", 0))
    num_envs = trainer_cfg.get("num_envs", None)
    wandb_project_name = trainer_cfg.get("wandb_project_name", "embodychain-generic")

    # Device
    if not isinstance(device_str, str):
        raise ValueError(
            f"runtime.device must be a string such as 'cpu' or 'cuda:0'. Got: {device_str!r}"
        )
    try:
        device = torch.device(device_str)
    except RuntimeError as exc:
        raise ValueError(
            f"Failed to parse runtime.device='{device_str}': {exc}"
        ) from exc

    if device.type == "cuda":
        if not torch.cuda.is_available():
            raise ValueError(
                "CUDA device requested but torch.cuda.is_available() is False."
            )
        index = (
            device.index if device.index is not None else torch.cuda.current_device()
        )
        device_count = torch.cuda.device_count()
        if index < 0 or index >= device_count:
            raise ValueError(
                f"CUDA device index {index} is out of range (available devices: {device_count})."
            )
        torch.cuda.set_device(index)
        device = torch.device(f"cuda:{index}")
    elif device.type != "cpu":
        raise ValueError(f"Unsupported device type: {device}")
    logger.log_info(f"Device: {device}")

    # Seeds
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    if device.type == "cuda":
        torch.cuda.manual_seed_all(seed)

    # Outputs
    run_stamp = time.strftime("%Y%m%d_%H%M%S")
    run_base = os.path.join("outputs", f"{exp_name}_{run_stamp}")
    log_dir = os.path.join(run_base, "logs")
    checkpoint_dir = os.path.join(run_base, "checkpoints")
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(checkpoint_dir, exist_ok=True)
    writer = SummaryWriter(f"{log_dir}/{exp_name}")

    # Initialize Weights & Biases (optional)
    use_wandb = trainer_cfg.get("use_wandb", False)

    # Initialize Weights & Biases (optional)
    if use_wandb:
        wandb.init(project=wandb_project_name, name=exp_name, config=cfg_json)

    gym_config_path = Path(trainer_cfg["gym_config"])
    logger.log_info(f"Current working directory: {Path.cwd()}")

    gym_config_data = load_json(str(gym_config_path))
    gym_env_cfg = config_to_cfg(gym_config_data)

    # Override num_envs from train config if provided
    if num_envs is not None:
        gym_env_cfg.num_envs = num_envs

    # Ensure sim configuration mirrors runtime overrides
    if gym_env_cfg.sim_cfg is None:
        gym_env_cfg.sim_cfg = SimulationManagerCfg()
    if device.type == "cuda":
        gpu_index = device.index
        if gpu_index is None:
            gpu_index = torch.cuda.current_device()
        gym_env_cfg.sim_cfg.sim_device = torch.device(f"cuda:{gpu_index}")
        if hasattr(gym_env_cfg.sim_cfg, "gpu_id"):
            gym_env_cfg.sim_cfg.gpu_id = gpu_index
    else:
        gym_env_cfg.sim_cfg.sim_device = torch.device("cpu")
    gym_env_cfg.sim_cfg.headless = headless
    gym_env_cfg.sim_cfg.enable_rt = enable_rt
    gym_env_cfg.sim_cfg.gpu_id = gpu_id

    logger.log_info(
        f"Loaded gym_config from {gym_config_path} (env_id={gym_config_data['id']}, num_envs={gym_env_cfg.num_envs}, headless={gym_env_cfg.sim_cfg.headless}, enable_rt={gym_env_cfg.sim_cfg.enable_rt}, sim_device={gym_env_cfg.sim_cfg.sim_device})"
    )

    env = build_env(gym_config_data["id"], base_env_cfg=gym_env_cfg)

    eval_gym_env_cfg = deepcopy(gym_env_cfg)
    eval_gym_env_cfg.num_envs = 4
    eval_gym_env_cfg.sim_cfg.headless = True

    eval_env = build_env(gym_config_data["id"], base_env_cfg=eval_gym_env_cfg)

    # Build Policy via registry
    policy_name = policy_block["name"]
    # Build Policy via registry (actor/critic must be explicitly defined in JSON when using actor_critic)
    if policy_name.lower() == "actor_critic":
        # Get observation dimension from flattened observation space
        # flattened_observation_space returns Box space for RL training
        obs_dim = env.flattened_observation_space.shape[-1]
        action_dim = env.action_space.shape[-1]

        actor_cfg = policy_block.get("actor")
        critic_cfg = policy_block.get("critic")
        if actor_cfg is None or critic_cfg is None:
            raise ValueError(
                "ActorCritic requires 'actor' and 'critic' definitions in JSON (policy.actor / policy.critic)."
            )

        actor = build_mlp_from_cfg(actor_cfg, obs_dim, action_dim)
        critic = build_mlp_from_cfg(critic_cfg, obs_dim, 1)

        policy = build_policy(
            policy_block,
            env.flattened_observation_space,
            env.action_space,
            device,
            actor=actor,
            critic=critic,
        )
    else:
        policy = build_policy(
            policy_block, env.flattened_observation_space, env.action_space, device
        )

    # Build Algorithm via factory
    algo_name = algo_block["name"].lower()
    algo_cfg = algo_block["cfg"]
    algo = build_algo(algo_name, algo_cfg, policy, device)

    # Build Trainer
    event_modules = [
        "embodichain.lab.gym.envs.managers.randomization",
        "embodichain.lab.gym.envs.managers.record",
        "embodichain.lab.gym.envs.managers.events",
    ]
    events_dict = trainer_cfg.get("events", {})
    train_event_cfg = {}
    eval_event_cfg = {}
    # Parse train events
    for event_name, event_info in events_dict.get("train", {}).items():
        event_func_str = event_info.get("func")
        mode = event_info.get("mode", "interval")
        params = event_info.get("params", {})
        interval_step = event_info.get("interval_step", 1)
        event_func = find_function_from_modules(
            event_func_str, event_modules, raise_if_not_found=True
        )
        train_event_cfg[event_name] = EventCfg(
            func=event_func,
            mode=mode,
            params=params,
            interval_step=interval_step,
        )
    # Parse eval events
    for event_name, event_info in events_dict.get("eval", {}).items():
        event_func_str = event_info.get("func")
        mode = event_info.get("mode", "interval")
        params = event_info.get("params", {})
        interval_step = event_info.get("interval_step", 1)
        event_func = find_function_from_modules(
            event_func_str, event_modules, raise_if_not_found=True
        )
        eval_event_cfg[event_name] = EventCfg(
            func=event_func,
            mode=mode,
            params=params,
            interval_step=interval_step,
        )
    trainer = Trainer(
        policy=policy,
        env=env,
        algorithm=algo,
        num_steps=rollout_steps,
        batch_size=algo_cfg["batch_size"],
        writer=writer,
        eval_freq=eval_freq,
        save_freq=save_freq,
        checkpoint_dir=checkpoint_dir,
        exp_name=exp_name,
        use_wandb=use_wandb,
        eval_env=eval_env,
        event_cfg=train_event_cfg,
        eval_event_cfg=eval_event_cfg,
    )

    logger.log_info("Generic training initialized")
    logger.log_info(f"Task: {type(env).__name__}")
    logger.log_info(
        f"Policy: {policy_name} (available: {get_registered_policy_names()})"
    )
    logger.log_info(
        f"Algorithm: {algo_name} (available: {get_registered_algo_names()})"
    )

    total_steps = int(iterations * rollout_steps * env.num_envs)
    logger.log_info(f"Total steps: {total_steps} (iterations≈{iterations})")

    try:
        trainer.train(total_steps)
    except KeyboardInterrupt:
        logger.log_info("Training interrupted by user")
    finally:
        trainer.save_checkpoint()
        writer.close()
        if use_wandb:
            try:
                wandb.finish()
            except Exception:
                pass
        logger.log_info("Training finished")


if __name__ == "__main__":
    main()
The Script Explained#
The training script performs the following steps:
Parse Configuration: Loads the JSON config and extracts the trainer/policy/algorithm blocks
Setup: Initializes the device, seeds, output directories, TensorBoard, and Weights & Biases
Build Components: Environment via the build_env() factory, policy via the build_policy() registry, algorithm via the build_algo() factory
Create Trainer: Instantiates the Trainer with all components
Train: Runs the training loop until completion
Launching Training#
To start training, run:
python -m embodichain.agents.rl.train --config configs/agents/rl/push_cube/train_config.json
Outputs#
All outputs are written to ./outputs/<exp_name>_<timestamp>/:
logs/: TensorBoard logs
checkpoints/: Model checkpoints
Training Process#
The training process follows this sequence:
Rollout Phase: Algorithm collects trajectories by interacting with the environment (via collect_rollout). During this phase, the trainer performs dense per-step logging of rewards and metrics from environment info.
GAE Computation: Algorithm computes advantages and returns using Generalized Advantage Estimation (internal to the algorithm, stored in buffer extras; see the sketch after this list)
Update Phase: Algorithm updates the policy using collected data (e.g., PPO)
Logging: Trainer logs training losses and aggregated metrics to TensorBoard and Weights & Biases
Evaluation (periodic): Trainer evaluates the current policy
Checkpointing (periodic): Trainer saves model checkpoints
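For reference, here is a minimal sketch of the GAE recursion, assuming reward/value/done tensors shaped [num_steps, num_envs] and the convention that dones[t] marks an episode ending at step t. It is illustrative only, not the algorithm's internal implementation.

import torch

def compute_gae(rewards, values, dones, last_value, gamma=0.99, gae_lambda=0.95):
    """Compute advantages and returns with Generalized Advantage Estimation."""
    num_steps = rewards.shape[0]
    advantages = torch.zeros_like(rewards)
    last_gae = torch.zeros_like(last_value)
    for t in reversed(range(num_steps)):
        # Bootstrap with the value after the final step, otherwise the next stored value.
        next_value = last_value if t == num_steps - 1 else values[t + 1]
        next_nonterminal = 1.0 - dones[t].float()
        # TD residual, masked when the episode ended at step t.
        delta = rewards[t] + gamma * next_value * next_nonterminal - values[t]
        last_gae = delta + gamma * gae_lambda * next_nonterminal * last_gae
        advantages[t] = last_gae
    returns = advantages + values
    return advantages, returns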
Policy Interface#
All policies must inherit from the Policy abstract base class:
from abc import ABC, abstractmethod
from typing import Tuple

import torch
import torch.nn as nn


class Policy(nn.Module, ABC):
    device: torch.device

    @abstractmethod
    def get_action(
        self, obs: torch.Tensor, deterministic: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Returns (action, log_prob, value)"""
        raise NotImplementedError

    @abstractmethod
    def get_value(self, obs: torch.Tensor) -> torch.Tensor:
        """Returns value estimate"""
        raise NotImplementedError

    @abstractmethod
    def evaluate_actions(
        self, obs: torch.Tensor, actions: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Returns (log_prob, entropy, value)"""
        raise NotImplementedError
Available Policies#
ActorCritic: MLP-based Gaussian policy with learnable log_std. Requires external actor and critic modules to be provided (defined in the JSON config); a minimal sketch follows below.
VLAPlaceholderPolicy: Placeholder for Vision-Language-Action policies
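To illustrate how an ActorCritic-style policy satisfies the get_action contract, here is a minimal sketch of a Gaussian head with a learnable log_std. GaussianHead is a hypothetical name for illustration, not the framework's ActorCritic class.

import torch
import torch.nn as nn


class GaussianHead(nn.Module):
    """Illustrative Gaussian actor-critic head (not the framework's ActorCritic)."""

    def __init__(self, actor: nn.Module, critic: nn.Module, action_dim: int):
        super().__init__()
        self.actor = actor      # maps obs -> action mean
        self.critic = critic    # maps obs -> state value
        self.log_std = nn.Parameter(torch.zeros(action_dim))

    def get_action(self, obs: torch.Tensor, deterministic: bool = False):
        mean = self.actor(obs)
        value = self.critic(obs).squeeze(-1)
        dist = torch.distributions.Normal(mean, self.log_std.exp())
        action = mean if deterministic else dist.sample()
        log_prob = dist.log_prob(action).sum(-1)
        return action, log_prob, value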
Algorithms#
Available Algorithms#
PPO: Proximal Policy Optimization with GAE
Adding a New Algorithm#
To add a new algorithm:
Create a new algorithm class in embodichain/agents/rl/algo/
Implement the initialize_buffer(), collect_rollout(), and update() methods
Register it in algo/__init__.py:
import torch

from embodichain.agents.rl.algo import BaseAlgorithm, register_algo
from embodichain.agents.rl.buffer import RolloutBuffer


@register_algo("my_algo")
class MyAlgorithm(BaseAlgorithm):
    def __init__(self, cfg, policy):
        self.cfg = cfg
        self.policy = policy
        self.device = torch.device(cfg.device)
        self.buffer = None

    def initialize_buffer(self, num_steps, num_envs, obs_dim, action_dim):
        """Initialize the algorithm's buffer."""
        self.buffer = RolloutBuffer(num_steps, num_envs, obs_dim, action_dim, self.device)

    def collect_rollout(self, env, policy, obs, num_steps, on_step_callback=None):
        """Control data collection (interact with env, fill buffer, compute advantages/returns)."""
        # Collect trajectories
        # Compute advantages/returns (e.g., GAE for on-policy algorithms)
        # Attach extras to the buffer: self.buffer.set_extras({"advantages": adv, "returns": ret})
        # Return an empty dict (dense logging is handled by the trainer)
        return {}

    def update(self):
        """Update the policy using collected data."""
        # Access extras from the buffer: self.buffer._extras.get("advantages")
        # Use self.buffer to update the policy
        return {"loss": 0.0}
Adding a New Policy#
To add a new policy:
Create a new policy class inheriting from the Policy abstract base class
Register it in models/__init__.py:
from embodichain.agents.rl.models import register_policy, Policy


@register_policy("my_policy")
class MyPolicy(Policy):
    def __init__(self, obs_space, action_space, device, config):
        super().__init__()
        self.device = device
        # Initialize your networks here

    def get_action(self, obs, deterministic=False):
        ...

    def get_value(self, obs):
        ...

    def evaluate_actions(self, obs, actions):
        ...
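Once registered, select the new policy by name in the JSON config. This is a sketch; "my_policy" is the illustrative name used above, and any policy-specific options would go under cfg:

"policy": {
  "name": "my_policy",
  "cfg": {}
}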
Adding a New Environment#
To add a new RL environment:
Create an environment class inheriting from EmbodiedEnv
Register it with the Gymnasium registry:
from embodichain.lab.gym.utils.registration import register_env


@register_env("MyTaskRL", max_episode_steps=100, override=True)
class MyTaskEnv(EmbodiedEnv):
    cfg: MyTaskEnvCfg
    ...
Use the environment ID in your JSON config:
"env": {
"id": "MyTaskRL",
"cfg": {
...
}
}
Best Practices#
Device Management: The device is single-sourced from trainer.device. All components (trainer/algorithm/policy/env) share the same device.
Action Scaling: Keep action scaling in the environment, not in the policy.
Observation Format: Environments should provide consistent observation shapes/types (torch.float32) and a single done = terminated | truncated.
Algorithm Interface: Algorithms must implement the initialize_buffer(), collect_rollout(), and update() methods. The algorithm completely controls data collection and buffer management.
Reward Components: Organize reward components in the info["rewards"] dictionary and metrics in the info["metrics"] dictionary. The trainer performs dense per-step logging directly from environment info (see the sketch after this list).
Configuration: Use JSON for all hyperparameters. This makes experiments reproducible and easy to track.
Logging: Metrics are automatically logged to TensorBoard and Weights & Biases. Check outputs/<exp_name>_<timestamp>/logs/ for TensorBoard logs.
Checkpoints: Regular checkpoints are saved to outputs/<exp_name>_<timestamp>/checkpoints/. Use these to resume training or evaluate policies.
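As a sketch of the rewards/metrics convention above, a hypothetical helper that assembles the per-step info dictionary; the component names ("reach", "push", "success", "cube_to_goal_dist") are illustrative and not part of the framework's API.

import torch

def build_step_info(reach_reward: torch.Tensor, push_reward: torch.Tensor,
                    success: torch.Tensor, cube_to_goal_dist: torch.Tensor) -> dict:
    """Assemble per-step info following the rewards/metrics convention."""
    return {
        "rewards": {
            "reach": reach_reward,        # individual reward terms
            "push": push_reward,
        },
        "metrics": {
            "success": success.float(),   # task metrics for dense per-step logging
            "cube_to_goal_dist": cube_to_goal_dist,
        },
    }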