Reinforcement Learning Training#
This tutorial shows you how to train reinforcement learning agents using EmbodiChain’s RL framework. You’ll learn how to configure training via JSON or YAML, set up environments, policies, and algorithms, and launch training sessions.
Overview#
The RL framework provides a modular, extensible stack for robotics tasks:
Trainer: Orchestrates the training loop (calls algorithm for data collection and updates, handles logging/eval/save)
Algorithm: Controls data collection process (interacts with environment, fills buffer, computes advantages/returns) and updates the policy (e.g., PPO)
Policy: Neural network models implementing a unified interface (get_action/get_value/evaluate_actions)
Buffer: On-policy rollout storage and minibatch iterator (managed by algorithm)
Env Factory: Build environments from a JSON or YAML config via registry
Architecture#
The framework follows a clean separation of concerns: each of the components listed above owns a single responsibility.
The core components and their relationships:
Trainer → Policy, Env, Algorithm (via callbacks for statistics)
Algorithm → Policy, RolloutBuffer (algorithm manages its own buffer)
Configuration via JSON or YAML#
Training is configured via a JSON or YAML file that defines runtime settings, environment, policy, and algorithm parameters. EmbodiChain loads either format with load_config(); the nested trainer.gym_config path supports the same extensions.
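EmbodiChain's own loader is load_config() from embodichain.utils.utility. The hypothetical helper below is not the library's implementation. It is a rough sketch of what dispatching on the file extension involves: it picks a parser from the suffix, and the YAML branch assumes the PyYAML package is installed.

```python
import json
from pathlib import Path


def load_config_sketch(path: str) -> dict:
    """Load a JSON or YAML config, choosing the parser from the file extension.

    Hypothetical stand-in for illustration only.
    """
    suffix = Path(path).suffix.lower()
    text = Path(path).read_text(encoding="utf-8")
    if suffix == ".json":
        return json.loads(text)
    if suffix in (".yaml", ".yml"):
        # Assumption: PyYAML is available. It is imported lazily so JSON-only
        # use needs no extra dependency.
        import yaml

        return yaml.safe_load(text)
    raise ValueError(f"Unsupported config extension: {suffix!r}")
```

For example, calling the helper on configs/agents/rl/basic/cart_pole/train_config.yaml and reading ["trainer"]["gym_config"] from the result gives the path of the gym config, which is then loaded the same way.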
Example Configuration#
Example configuration files live in configs/agents/rl/push_cube (train_config.json) and configs/agents/rl/basic/cart_pole (train_config.yaml):
Example: train_config.json
{
    "trainer": {
        "exp_name": "push_cube_ppo",
        "gym_config": "configs/agents/rl/push_cube/gym_config.json",
        "seed": 42,
        "device": "cuda:0",
        "headless": true,
        "gpu_id": 0,
        "num_envs": 64,
        "iterations": 1000,
        "buffer_size": 1024,
        "enable_eval": true,
        "num_eval_envs": 16,
        "num_eval_episodes": 3,
        "eval_freq": 100,
        "save_freq": 100,
        "use_wandb": true,
        "wandb_project_name": "embodichain-push_cube",
        "events": {
            "eval": {
                "record_camera": {
                    "func": "record_camera_data_async",
                    "mode": "interval",
                    "interval_step": 1,
                    "params": {
                        "name": "main_cam",
                        "resolution": [640, 480],
                        "eye": [-1.4, 1.4, 2.0],
                        "target": [0, 0, 0],
                        "up": [0, 0, 1],
                        "intrinsics": [600, 600, 320, 240],
                        "save_path": "./outputs/videos_ppo1/eval"
                    }
                }
            }
        },
        "renderer": "hybrid"
    },
    "policy": {
        "name": "actor_critic",
        "actor": {
            "type": "mlp",
            "network_cfg": {
                "hidden_sizes": [
                    256,
                    256
                ],
                "activation": "relu"
            }
        },
        "critic": {
            "type": "mlp",
            "network_cfg": {
                "hidden_sizes": [
                    256,
                    256
                ],
                "activation": "relu"
            }
        }
    },
    "algorithm": {
        "name": "ppo",
        "cfg": {
            "learning_rate": 0.0001,
            "n_epochs": 10,
            "batch_size": 8192,
            "gamma": 0.99,
            "gae_lambda": 0.95,
            "clip_coef": 0.2,
            "ent_coef": 0.01,
            "vf_coef": 0.5,
            "max_grad_norm": 0.5
        }
    }
}
Example: train_config.yaml (CartPole)
trainer:
  exp_name: cart_pole_ppo
  gym_config: configs/agents/rl/basic/cart_pole/gym_config.yaml
  seed: 42
  device: cuda:0
  headless: true
  num_envs: 64
  iterations: 1000
  buffer_size: 1024
  eval_freq: 200
  save_freq: 200
  use_wandb: false
  wandb_project_name: embodichain-cart_pole
  events:
    eval:
      record_camera:
        func: record_camera_data_async
        mode: interval
        interval_step: 1
        params:
          name: main_cam
          resolution:
          - 640
          - 480
          eye:
          - -1.4
          - 1.4
          - 2.5
          target:
          - 0
          - 0
          - 0.7
          up:
          - 0
          - 0
          - 1
          intrinsics:
          - 600
          - 600
          - 320
          - 240
          save_path: ./outputs/videos/eval
  renderer: fast-rt
policy:
  name: actor_critic
  actor:
    type: mlp
    network_cfg:
      hidden_sizes:
      - 256
      - 256
      activation: relu
  critic:
    type: mlp
    network_cfg:
      hidden_sizes:
      - 256
      - 256
      activation: relu
algorithm:
  name: ppo
  cfg:
    learning_rate: 0.0001
    n_epochs: 10
    batch_size: 8192
    gamma: 0.99
    gae_lambda: 0.95
    clip_coef: 0.2
    ent_coef: 0.01
    vf_coef: 0.5
    max_grad_norm: 0.5
Configuration Sections#
Runtime Settings#
The trainer section controls experiment setup:
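exp_name: Experiment name (used for output directories)
gym_config: Path to the environment (gym) config file, in JSON or YAML
seed: Random seed for reproducibility (in distributed runs each process uses seed + rank)
device: Runtime device string, e.g. "cpu" or "cuda:0" (in distributed runs each process uses cuda:<LOCAL_RANK>)
headless: Whether to run the simulation in headless mode
renderer: Renderer passed to the simulation render config (e.g., "hybrid", "fast-rt"; default "hybrid")
gpu_id: GPU index written to the simulation config
num_envs: Number of parallel training environments; overrides the value in the gym config when set
iterations: Number of training iterations
buffer_size: Steps collected per environment in each rollout (e.g., 1024)
enable_eval: Whether to build a separate evaluation environment and run periodic evaluation (default false)
num_eval_envs: Number of parallel evaluation environments (default 4)
num_eval_episodes: Number of episodes per evaluation (default 5)
eval_freq: Frequency of evaluation (in steps); has no effect unless enable_eval is true
save_freq: Frequency of checkpoint saving (in steps)
use_wandb: Whether to enable Weights & Biases logging
wandb_project_name: Weights & Biases project name
events: Optional train and eval event specs (func, mode, interval_step, params), e.g. camera recording during evaluation
distributed: Optional; enables multi-GPU training when the --distributed/--no-distributed CLI flag is not given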
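The defaults below are transcribed from the trainer_cfg.get(...) calls in train.py. The merging helper, resolve_trainer_settings, is hypothetical and only makes the fallback rules easy to see. Under those rules, buffer_size falls back to a legacy rollout_steps key and then to 2048. Evaluation is also switched off, with eval_freq=0 passed to the Trainer, unless enable_eval is true.

```python
# Defaults mirrored from the trainer_cfg.get(...) calls in train.py.
TRAINER_DEFAULTS = {
    "exp_name": "generic_exp",
    "seed": 1,
    "device": "cpu",
    "iterations": 250,
    "enable_eval": False,
    "eval_freq": 10000,
    "save_freq": 50000,
    "num_eval_envs": 4,
    "num_eval_episodes": 5,
    "headless": True,
    "renderer": "hybrid",
    "gpu_id": 0,
    "num_envs": None,  # None keeps num_envs from the gym config
    "use_wandb": False,
    "wandb_project_name": "embodichain-generic",
}


def resolve_trainer_settings(trainer_cfg: dict) -> dict:
    """Merge a trainer block over the defaults (hypothetical helper)."""
    settings = {**TRAINER_DEFAULTS, **trainer_cfg}
    # buffer_size falls back to the legacy rollout_steps key, then to 2048.
    settings["buffer_size"] = int(
        trainer_cfg.get("buffer_size", trainer_cfg.get("rollout_steps", 2048))
    )
    settings["enable_eval"] = bool(settings["enable_eval"])
    # train.py passes eval_freq=0 to the Trainer when evaluation is disabled.
    if not settings["enable_eval"]:
        settings["eval_freq"] = 0
    return settings
```

The CartPole YAML above sets eval_freq: 200 but no enable_eval, so under these rules it would run without periodic evaluation. The push_cube JSON sets enable_eval: true, so its eval_freq of 100 takes effect.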
Environment Configuration#
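The env section defines the task environment. Note that train.py does not read an env block from the training config. Instead, it loads the environment from the file referenced by trainer.gym_config and uses that file's top-level id to select the registered environment. The fields are: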
id: Environment registry ID (e.g., “PushCubeRL”)
cfg: Environment-specific configuration parameters
For RL environments, use the actions field for action preprocessing and extensions for task-specific parameters:
actions: Action Manager config (e.g., DeltaQposTerm with scale)
extensions: Task-specific parameters (e.g., success_threshold)
Example:
"env": {
"id": "PushCubeRL",
"cfg": {
"num_envs": 4,
"actions": {
"delta_qpos": {
"func": "DeltaQposTerm",
"params": { "scale": 0.1 }
}
},
"extensions": {
"success_threshold": 0.1
}
}
}
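The Action Manager defines the exact semantics of DeltaQposTerm, and this page does not spell them out. Assume it treats the policy output as a per-joint offset that is multiplied by scale and added to the current joint positions. Its effect can then be sketched as follows (delta_qpos_target is a hypothetical name):

```python
def delta_qpos_target(current_qpos, action, scale=0.1):
    """Hypothetical sketch of a delta joint-position action term.

    Assumption: each action entry is an offset for one controlled joint that is
    scaled and added to that joint's current position to form the target.
    """
    if len(current_qpos) != len(action):
        raise ValueError("action must have one entry per controlled joint")
    return [q + scale * a for q, a in zip(current_qpos, action)]
```

With scale 0.1, a policy output of [1.0, -1.0] would move joints currently at [0.0, 1.0] toward [0.1, 0.9]. A small scale keeps each step's joint change bounded even when the raw policy output is large.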
Policy Configuration#
The policy section defines the neural network policy:
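name: Policy name (e.g., “actor_critic”, “actor_only”, “vla”)
action_dim: Optional policy output action dimension. If omitted, it is inferred from the environment (the Action Manager's total_action_dim, or the number of active joints when no Action Manager is configured). If set, it must match that value.
actor: Actor network configuration (required for actor_critic and actor_only)
critic: Critic network configuration (required for actor_critic)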
Example:
"policy": {
"name": "actor_critic",
"actor": {
"type": "mlp",
"network_cfg": {
"hidden_sizes": [256, 256],
"activation": "relu"
}
},
"critic": {
"type": "mlp",
"network_cfg": {
"hidden_sizes": [256, 256],
"activation": "relu"
}
}
}
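train.py builds the actor with build_mlp_from_cfg(actor_cfg, obs_dim, action_dim) and the critic with build_mlp_from_cfg(critic_cfg, obs_dim, 1). Assume network_cfg describes a plain stack of fully connected layers, with the activation between them. The layer shapes and parameter counts then work out as in this sketch. The helper names and the 10-dimensional observation are hypothetical.

```python
def mlp_layer_shapes(in_dim, hidden_sizes, out_dim):
    """(fan_in, fan_out) of each fully connected layer in a plain MLP."""
    dims = [in_dim, *hidden_sizes, out_dim]
    return list(zip(dims[:-1], dims[1:]))


def mlp_param_count(in_dim, hidden_sizes, out_dim):
    """Weights plus biases summed over all fully connected layers."""
    return sum(
        fan_in * fan_out + fan_out
        for fan_in, fan_out in mlp_layer_shapes(in_dim, hidden_sizes, out_dim)
    )


# Actor: observation -> action; critic: observation -> scalar value.
actor_shapes = mlp_layer_shapes(10, [256, 256], 4)
critic_params = mlp_param_count(10, [256, 256], 1)
```

With hidden_sizes [256, 256], a 10-dimensional observation, and a 4-dimensional action, the actor has 69,636 parameters and the critic (one output) has 68,865. The 256 × 256 hidden layer dominates both.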
Algorithm Configuration#
The algorithm section defines the RL algorithm:
name: Algorithm name (e.g., “ppo”, “grpo”)
cfg: Algorithm-specific hyperparameters
PPO example:
"algorithm": {
"name": "ppo",
"cfg": {
"learning_rate": 0.0001,
"n_epochs": 10,
"batch_size": 64,
"gamma": 0.99,
"gae_lambda": 0.95,
"clip_coef": 0.2,
"ent_coef": 0.01,
"vf_coef": 0.5,
"max_grad_norm": 0.5
}
}
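PPO's gamma and gae_lambda feed Generalized Advantage Estimation (GAE). The following is a textbook GAE sketch for a single environment, not EmbodiChain's implementation. It assumes values carries one extra bootstrap entry, matching the value[:, -1] slot described under Training Process. For simplicity it treats every done as terminal, whereas a real implementation may bootstrap through time-limit truncations.

```python
def compute_gae(rewards, values, dones, gamma=0.99, gae_lambda=0.95):
    """Generalized Advantage Estimation for a single environment.

    rewards, dones: length T (transition fields).
    values: length T + 1; values[T] is the bootstrap value of the last observation.
    dones[t] is 1 if the episode ended at step t (treated as terminal here).
    Returns (advantages, returns), each of length T.
    """
    T = len(rewards)
    if len(values) != T + 1 or len(dones) != T:
        raise ValueError("expected len(values) == len(rewards) + 1 == len(dones) + 1")
    advantages = [0.0] * T
    last_gae = 0.0
    for t in reversed(range(T)):
        not_done = 1.0 - float(dones[t])
        # One-step TD error; the next value is ignored after an episode ends.
        delta = rewards[t] + gamma * values[t + 1] * not_done - values[t]
        # Exponentially weighted sum of future TD errors, reset at episode ends.
        last_gae = delta + gamma * gae_lambda * not_done * last_gae
        advantages[t] = last_gae
    returns = [a + v for a, v in zip(advantages, values[:T])]
    return advantages, returns
```

The returns computed this way serve as value targets for the critic (weighted by vf_coef), and the advantages drive the clipped policy objective (clip_coef).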
GRPO example (for Embodied AI / from-scratch training, e.g. CartPole):
"algorithm": {
"name": "grpo",
"cfg": {
"learning_rate": 0.0001,
"n_epochs": 10,
"batch_size": 8192,
"gamma": 0.99,
"clip_coef": 0.2,
"ent_coef": 0.001,
"kl_coef": 0,
"group_size": 4,
"eps": 1e-8,
"reset_every_rollout": true,
"max_grad_norm": 0.5,
"truncate_at_first_done": true
}
}
For GRPO, use the actor_only policy. Set kl_coef=0 for from-scratch training and kl_coef=0.02 for VLA/LLM fine-tuning.
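GRPO replaces the critic with group-relative advantages. The sketch below shows masked group normalization for one time step, using group_size and eps from the config above. This page does not specify how EmbodiChain forms groups or builds the mask, so both are assumptions here. Consecutive blocks of group_size environments form a group, and masked-out entries are excluded from the statistics and get zero advantage. One plausible reading of truncate_at_first_done is that steps after an episode's first done are masked. The population standard deviation is used.

```python
import math


def group_normalize(returns, mask, group_size=4, eps=1e-8):
    """Group-relative advantages for one time step across environments.

    Hypothetical sketch: consecutive blocks of group_size environments form a
    group; entries with a falsy mask are ignored and keep advantage 0.0.
    """
    n = len(returns)
    if n % group_size != 0:
        raise ValueError("number of environments must be divisible by group_size")
    advantages = [0.0] * n
    for start in range(0, n, group_size):
        idx = [i for i in range(start, start + group_size) if mask[i]]
        if not idx:
            continue
        mean = sum(returns[i] for i in idx) / len(idx)
        var = sum((returns[i] - mean) ** 2 for i in idx) / len(idx)
        std = math.sqrt(var)
        for i in idx:
            # eps keeps the division finite when every return in a group is equal.
            advantages[i] = (returns[i] - mean) / (std + eps)
    return advantages
```

Because advantages are relative within a group, a group whose members all earn the same return contributes no learning signal at that step.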
Training Script#
The training script (train.py) is located in embodichain/agents/rl/:
Code for train.py
# ----------------------------------------------------------------------------
# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ----------------------------------------------------------------------------

import argparse
import os
import time
from pathlib import Path

import numpy as np
import torch
import wandb
from torch.utils.tensorboard import SummaryWriter
from copy import deepcopy

from embodichain.agents.rl.models import build_policy, get_registered_policy_names
from embodichain.agents.rl.models import build_mlp_from_cfg
from embodichain.agents.rl.algo import build_algo, get_registered_algo_names
from embodichain.agents.rl.utils import dict_to_tensordict, flatten_dict_observation
from embodichain.agents.rl.utils.trainer import Trainer
from embodichain.utils import logger
from embodichain.lab.gym.envs.tasks.rl import build_env
from embodichain.lab.gym.utils.gym_utils import config_to_cfg, DEFAULT_MANAGER_MODULES
from embodichain.utils.utility import load_config
from embodichain.utils.module_utils import find_function_from_modules
from embodichain.lab.sim import SimulationManagerCfg
from embodichain.lab.sim.cfg import RenderCfg
from embodichain.lab.gym.envs.managers.cfg import EventCfg


def parse_args():
    """Parse command line arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="Path to training config file (.json, .yaml, or .yml).",
    )
    parser.add_argument(
        "--distributed",
        action=argparse.BooleanOptionalAction,
        default=None,
        help="Enable or disable multi-GPU distributed training",
    )
    return parser.parse_args()


def train_from_config(config_path: str, distributed: bool | None = None):
    """Run training from a config file path.

    Args:
        config_path: Path to the training config file (.json, .yaml, or .yml).
        distributed: If True, run multi-GPU distributed training.
            If None, use trainer.distributed from config.
    """
    cfg_data = load_config(config_path)

    trainer_cfg = cfg_data["trainer"]
    policy_block = cfg_data["policy"]
    algo_block = cfg_data["algorithm"]

    # Resolve distributed flag
    if distributed is None:
        distributed = bool(trainer_cfg.get("distributed", False))

    # Distributed setup
    rank = 0
    world_size = 1
    local_rank = 0
    if distributed:
        if not torch.distributed.is_available():
            raise RuntimeError(
                "Distributed training requested but torch.distributed is not available."
            )
        if not torch.cuda.is_available():
            raise RuntimeError(
                "Distributed training with NCCL backend requires CUDA, "
                "but torch.cuda.is_available() is False."
            )
        local_rank = int(os.environ.get("LOCAL_RANK", 0))
        if local_rank < 0 or local_rank >= torch.cuda.device_count():
            raise ValueError(
                f"LOCAL_RANK {local_rank} is out of range "
                f"(available GPUs: {torch.cuda.device_count()})."
            )
        torch.cuda.set_device(local_rank)
        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group(backend="nccl")
        rank = torch.distributed.get_rank()
        world_size = torch.distributed.get_world_size()

    # Runtime
    exp_name = trainer_cfg.get("exp_name", "generic_exp")
    seed = int(trainer_cfg.get("seed", 1))
    device_str = trainer_cfg.get("device", "cpu")
    if distributed:
        device_str = f"cuda:{local_rank}"
    iterations = int(trainer_cfg.get("iterations", 250))
    buffer_size = int(
        trainer_cfg.get("buffer_size", trainer_cfg.get("rollout_steps", 2048))
    )
    enable_eval = bool(trainer_cfg.get("enable_eval", False))
    eval_freq = int(trainer_cfg.get("eval_freq", 10000))
    save_freq = int(trainer_cfg.get("save_freq", 50000))
    num_eval_episodes = int(trainer_cfg.get("num_eval_episodes", 5))
    headless = bool(trainer_cfg.get("headless", True))
    renderer = trainer_cfg.get("renderer", "hybrid")
    gpu_id = int(trainer_cfg.get("gpu_id", 0))
    num_envs = trainer_cfg.get("num_envs", None)
    wandb_project_name = trainer_cfg.get("wandb_project_name", "embodichain-generic")

    # Device
    if not isinstance(device_str, str):
        raise ValueError(
            f"runtime.device must be a string such as 'cpu' or 'cuda:0'. Got: {device_str!r}"
        )
    try:
        device = torch.device(device_str)
    except RuntimeError as exc:
        raise ValueError(
            f"Failed to parse runtime.device='{device_str}': {exc}"
        ) from exc

    if device.type == "cuda":
        if not torch.cuda.is_available():
            raise ValueError(
                "CUDA device requested but torch.cuda.is_available() is False."
            )
        index = (
            device.index if device.index is not None else torch.cuda.current_device()
        )
        device_count = torch.cuda.device_count()
        if index < 0 or index >= device_count:
            raise ValueError(
                f"CUDA device index {index} is out of range (available devices: {device_count})."
            )
        torch.cuda.set_device(index)
        device = torch.device(f"cuda:{index}")
    elif device.type != "cpu":
        raise ValueError(f"Unsupported device type: {device}")
    if rank == 0:
        logger.log_info(f"Device: {device}")
    if distributed and rank == 0:
        logger.log_info(f"Distributed training: world_size={world_size}")

    # Seeds
    effective_seed = seed + rank
    np.random.seed(effective_seed)
    torch.manual_seed(effective_seed)
    torch.backends.cudnn.deterministic = True
    if device.type == "cuda":
        torch.cuda.manual_seed_all(effective_seed)

    # Outputs
    if distributed:
        run_stamp = time.strftime("%Y%m%d_%H%M%S") if rank == 0 else None
        run_stamp_list = [run_stamp]
        torch.distributed.broadcast_object_list(run_stamp_list, src=0)
        run_stamp = run_stamp_list[0]
    else:
        run_stamp = time.strftime("%Y%m%d_%H%M%S")
    run_base = os.path.join("outputs", f"{exp_name}_{run_stamp}")
    log_dir = os.path.join(run_base, "logs")
    checkpoint_dir = os.path.join(run_base, "checkpoints")
    if rank == 0:
        os.makedirs(log_dir, exist_ok=True)
        os.makedirs(checkpoint_dir, exist_ok=True)
    writer = SummaryWriter(f"{log_dir}/{exp_name}") if rank == 0 else None

    # Initialize Weights & Biases (optional)
    use_wandb = trainer_cfg.get("use_wandb", False)
    if use_wandb and rank == 0:
        wandb.init(project=wandb_project_name, name=exp_name, config=cfg_data)

    gym_config_path = Path(trainer_cfg["gym_config"])
    if rank == 0:
        logger.log_info(f"Current working directory: {Path.cwd()}")

    gym_config_data = load_config(str(gym_config_path))
    gym_env_cfg = config_to_cfg(
        gym_config_data, manager_modules=DEFAULT_MANAGER_MODULES
    )
    if num_envs is not None:
        gym_env_cfg.num_envs = int(num_envs)

    # Ensure sim configuration mirrors runtime overrides
    if gym_env_cfg.sim_cfg is None:
        gym_env_cfg.sim_cfg = SimulationManagerCfg()
    if device.type == "cuda":
        gpu_index = device.index
        if gpu_index is None:
            gpu_index = torch.cuda.current_device()
        gym_env_cfg.sim_cfg.sim_device = torch.device(f"cuda:{gpu_index}")
        if hasattr(gym_env_cfg.sim_cfg, "gpu_id"):
            gym_env_cfg.sim_cfg.gpu_id = gpu_index
    else:
        gym_env_cfg.sim_cfg.sim_device = torch.device("cpu")
    gym_env_cfg.sim_cfg.headless = headless
    gym_env_cfg.sim_cfg.render_cfg = RenderCfg(renderer=renderer)
    gym_env_cfg.sim_cfg.gpu_id = gpu_id
    logger.log_info(
        f"Loaded gym_config from {gym_config_path} (env_id={gym_config_data['id']}, num_envs={gym_env_cfg.num_envs}, headless={gym_env_cfg.sim_cfg.headless}, renderer={gym_env_cfg.sim_cfg.render_cfg.renderer}, sim_device={gym_env_cfg.sim_cfg.sim_device})"
    )

    env = build_env(gym_config_data["id"], base_env_cfg=gym_env_cfg)
    sample_obs, _ = env.reset()
    sample_obs_td = dict_to_tensordict(sample_obs, device)
    obs_dim = flatten_dict_observation(sample_obs_td).shape[-1]
    flat_obs_space = env.flattened_observation_space

    # Create evaluation environment only if enabled
    eval_env = None
    num_eval_envs = trainer_cfg.get("num_eval_envs", 4)
    if enable_eval and rank == 0:
        eval_gym_env_cfg = deepcopy(gym_env_cfg)
        eval_gym_env_cfg.num_envs = num_eval_envs
        eval_gym_env_cfg.sim_cfg.headless = True
        eval_env = build_env(gym_config_data["id"], base_env_cfg=eval_gym_env_cfg)
        logger.log_info(
            f"Evaluation environment created (num_envs={num_eval_envs}, headless=True)"
        )

    # Build Policy via registry
    policy_name = policy_block["name"]
    env_action_dim = (
        env.get_wrapper_attr("action_manager").total_action_dim
        if env.get_wrapper_attr("action_manager") is not None
        else len(env.get_wrapper_attr("active_joint_ids"))
    )
    action_dim = policy_block.get("action_dim", env_action_dim)
    action_dim = int(action_dim)
    if action_dim != env_action_dim:
        raise ValueError(
            f"Configured policy.action_dim={action_dim} does not match env action dim {env_action_dim}."
        )
    # Build Policy via registry (actor/critic must be explicitly defined in JSON when using actor_critic/actor_only)
    if policy_name.lower() == "actor_critic":
        actor_cfg = policy_block.get("actor")
        critic_cfg = policy_block.get("critic")
        if actor_cfg is None or critic_cfg is None:
            raise ValueError(
                "ActorCritic requires 'actor' and 'critic' definitions in JSON (policy.actor / policy.critic)."
            )

        actor = build_mlp_from_cfg(actor_cfg, obs_dim, action_dim)
        critic = build_mlp_from_cfg(critic_cfg, obs_dim, 1)

        policy = build_policy(
            policy_block,
            flat_obs_space,
            env.action_space,
            device,
            actor=actor,
            critic=critic,
        )
    elif policy_name.lower() == "actor_only":
        actor_cfg = policy_block.get("actor")
        if actor_cfg is None:
            raise ValueError(
                "ActorOnly requires 'actor' definition in JSON (policy.actor)."
            )

        actor = build_mlp_from_cfg(actor_cfg, obs_dim, action_dim)

        policy = build_policy(
            policy_block,
            flat_obs_space,
            env.action_space,
            device,
            actor=actor,
        )
    else:
        policy = build_policy(
            policy_block, env.observation_space, env.action_space, device
        )

    # Build Algorithm via factory
    algo_name = algo_block["name"].lower()
    algo_cfg = algo_block["cfg"]
    algo = build_algo(
        algo_name,
        algo_cfg,
        policy,
        device,
        distributed=distributed,
    )

    # Build Trainer
    event_modules = [
        "embodichain.lab.gym.envs.managers.randomization",
        "embodichain.lab.gym.envs.managers.record",
        "embodichain.lab.gym.envs.managers.events",
    ]
    events_dict = trainer_cfg.get("events", {})
    train_event_cfg = {}
    eval_event_cfg = {}
    # Parse train events
    for event_name, event_info in events_dict.get("train", {}).items():
        event_func_str = event_info.get("func")
        mode = event_info.get("mode", "interval")
        params = event_info.get("params", {})
        interval_step = event_info.get("interval_step", 1)
        event_func = find_function_from_modules(
            event_func_str, event_modules, raise_if_not_found=True
        )
        train_event_cfg[event_name] = EventCfg(
            func=event_func,
            mode=mode,
            params=params,
            interval_step=interval_step,
        )
    # Parse eval events (only if evaluation is enabled)
    if enable_eval:
        for event_name, event_info in events_dict.get("eval", {}).items():
            event_func_str = event_info.get("func")
            mode = event_info.get("mode", "interval")
            params = event_info.get("params", {})
            interval_step = event_info.get("interval_step", 1)
            event_func = find_function_from_modules(
                event_func_str, event_modules, raise_if_not_found=True
            )
            eval_event_cfg[event_name] = EventCfg(
                func=event_func,
                mode=mode,
                params=params,
                interval_step=interval_step,
            )
    trainer = Trainer(
        policy=policy,
        env=env,
        algorithm=algo,
        buffer_size=buffer_size,
        batch_size=algo_cfg["batch_size"],
        writer=writer,
        eval_freq=eval_freq if enable_eval else 0,  # Disable eval if not enabled
        save_freq=save_freq,
        checkpoint_dir=checkpoint_dir,
        exp_name=exp_name,
        use_wandb=use_wandb,
        eval_env=eval_env,  # None if enable_eval=False
        event_cfg=train_event_cfg,
        eval_event_cfg=eval_event_cfg if (enable_eval and rank == 0) else {},
        num_eval_episodes=num_eval_episodes,
        distributed=distributed,
        rank=rank,
        world_size=world_size,
    )

    if rank == 0:
        logger.log_info("Generic training initialized")
        logger.log_info(f"Task: {type(env).__name__}")
        logger.log_info(
            f"Policy: {policy_name} (available: {get_registered_policy_names()})"
        )
        logger.log_info(
            f"Algorithm: {algo_name} (available: {get_registered_algo_names()})"
        )

    total_steps = int(iterations * buffer_size * env.num_envs * world_size)
    if rank == 0:
        logger.log_info(
            f"Total steps: {total_steps} (iterations≈{iterations}, world_size={world_size})"
        )

    try:
        trainer.train(total_steps)
    except KeyboardInterrupt:
        if rank == 0:
            logger.log_info("Training interrupted by user")
    finally:
        trainer.save_checkpoint()
        if writer is not None:
            writer.close()
        if use_wandb and rank == 0:
            try:
                wandb.finish()
            except Exception:
                pass

        # Clean up environments to prevent resource leaks
        try:
            if env is not None:
                env.close()
        except Exception as e:
            if rank == 0:
                logger.log_warning(f"Failed to close training environment: {e}")

        try:
            if eval_env is not None:
                eval_env.close()
        except Exception as e:
            if rank == 0:
                logger.log_warning(f"Failed to close evaluation environment: {e}")

        if distributed and torch.distributed.is_initialized():
            torch.distributed.destroy_process_group()

    if rank == 0:
        logger.log_info("Training finished")


def cli() -> None:
    """Command-line interface for RL training.

    Parses CLI arguments and launches training from a config file.
    """
    args = parse_args()
    train_from_config(args.config, distributed=args.distributed)


if __name__ == "__main__":
    cli()
The Script Explained#
The training script performs the following steps:
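Parse Configuration: Loads the config file (.json, .yaml, or .yml) and extracts the trainer, policy, and algorithm blocks. The environment config is then loaded from the file referenced by trainer.gym_config.
Setup: Initializes the device, seeds, output directories, TensorBoard, and (optionally) Weights & Biases
Build Components: Builds the environment via the build_env() factory, the policy via the build_policy() registry, and the algorithm via the build_algo() factory
Create Trainer: Instantiates the Trainer with all components
Train: Runs the training loop until completion, then saves a final checkpoint and closes the environments (also when training is interrupted with Ctrl+C)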
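The script turns each event's func string into a callable with find_function_from_modules. It searches the randomization, record, and events manager modules in order and fills in defaults: mode "interval", interval_step 1, and empty params. The stdlib sketch below reproduces that lookup-and-defaults logic with importlib. find_function_sketch and parse_event_specs are hypothetical stand-ins, and plain dicts take the place of EventCfg.

```python
import importlib


def find_function_sketch(name, module_names, raise_if_not_found=False):
    """Return the first callable attribute called `name` in the given modules.

    Hypothetical stand-in for find_function_from_modules.
    """
    for module_name in module_names:
        module = importlib.import_module(module_name)
        func = getattr(module, name, None)
        if callable(func):
            return func
    if raise_if_not_found:
        raise ValueError(f"{name!r} not found in {module_names}")
    return None


def parse_event_specs(events_dict, section, module_names):
    """Fill in defaults and resolve each event's func name, as train.py does."""
    parsed = {}
    for event_name, info in events_dict.get(section, {}).items():
        parsed[event_name] = {
            "func": find_function_sketch(
                info.get("func"), module_names, raise_if_not_found=True
            ),
            "mode": info.get("mode", "interval"),
            "params": info.get("params", {}),
            "interval_step": info.get("interval_step", 1),
        }
    return parsed
```

Because the first match wins, a function name that exists in more than one of the listed modules resolves to the earliest module in the list.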
Launching Training#
To start training, run:
python -m embodichain train-rl --config configs/agents/rl/basic/cart_pole/train_config.yaml
JSON configs are also supported:
python -m embodichain train-rl --config configs/agents/rl/push_cube/train_config.json
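Before calling trainer.train(), the script converts iterations into a total step budget: iterations × buffer_size × num_envs × world_size. Here num_envs is the environment's count after the trainer override. The worked example below uses the two example configs above, both of which use iterations 1000, buffer_size 1024, and num_envs 64. total_env_steps is a hypothetical helper that reproduces the script's expression.

```python
def total_env_steps(iterations, buffer_size, num_envs, world_size=1):
    """Total environment steps, computed the same way as in train.py."""
    return int(iterations * buffer_size * num_envs * world_size)


single_gpu = total_env_steps(1000, 1024, 64)
two_gpus = total_env_steps(1000, 1024, 64, world_size=2)
```

Each iteration therefore collects 1024 × 64 = 65,536 environment steps per process, for 65,536,000 in total. A second GPU doubles the total to 131,072,000 without changing the number of iterations.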
Outputs#
All outputs are written to ./outputs/<exp_name>_<timestamp>/:
logs/: TensorBoard logs (written under logs/<exp_name>/)
checkpoints/: Model checkpoints
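The run directory name combines exp_name with a %Y%m%d_%H%M%S timestamp. In distributed runs the timestamp is broadcast from rank 0 so all processes agree. The hypothetical helper below rebuilds the same paths, including the logs/<exp_name> subdirectory that the SummaryWriter writes to.

```python
import os
import time


def run_directories(exp_name, run_stamp=None):
    """Build the run, log, TensorBoard, and checkpoint paths as train.py does."""
    if run_stamp is None:
        run_stamp = time.strftime("%Y%m%d_%H%M%S")
    run_base = os.path.join("outputs", f"{exp_name}_{run_stamp}")
    log_dir = os.path.join(run_base, "logs")
    return {
        "run_base": run_base,
        "logs": log_dir,
        "tensorboard": f"{log_dir}/{exp_name}",
        "checkpoints": os.path.join(run_base, "checkpoints"),
    }
```

Pointing TensorBoard at the outputs/ directory therefore picks up every run, with each one distinguished by its timestamp suffix.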
Training Process#
The training process follows this sequence:
Rollout Phase: SyncCollector interacts with the environment and writes policy-side fields into a shared rollout TensorDict with a uniform [N, T + 1] layout (N environments, T rollout steps). EmbodiedEnv writes environment-side step fields such as reward, done, terminated, and truncated into the same rollout via set_rollout_buffer(). The final slot of transition-only fields is reserved as padding, while obs[:, -1] and value[:, -1] remain valid bootstrap data.
Advantage/Return Computation: The algorithm computes advantages and returns from the collected rollout (e.g., GAE for PPO, step-wise group normalization for GRPO). It then converts the rollout to a transition-aligned view over the valid first T steps before minibatch optimization.
Update Phase: The algorithm updates the policy with update(rollout).
Logging: The trainer logs training losses and aggregated metrics to TensorBoard and, if enabled, Weights & Biases.
Evaluation (periodic): The trainer evaluates the current policy.
Checkpointing (periodic): The trainer saves model checkpoints.
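The following plain-Python sketch makes the [N, T + 1] layout concrete. It slices a rollout into the transition-aligned first T steps while keeping the last slot of bootstrap fields such as value. Nested lists stand in for a TensorDict, and both transition_view and the bootstrap_ key prefix are hypothetical.

```python
def _check_slots(key, rows, T):
    """Every environment row must have T + 1 slots."""
    if any(len(row) != T + 1 for row in rows):
        raise ValueError(f"{key!r} must have T + 1 = {T + 1} slots per environment")


def transition_view(rollout, transition_keys, bootstrap_keys, T):
    """Split a [N, T + 1] rollout into a transition-aligned view plus bootstrap data.

    rollout maps field name -> N lists of length T + 1. Transition-only fields keep
    their first T slots (the last slot is padding); bootstrap fields additionally
    expose their valid final slot under "bootstrap_<key>".
    """
    view = {}
    for key in transition_keys:
        _check_slots(key, rollout[key], T)
        view[key] = [row[:T] for row in rollout[key]]  # drop the padding slot
    for key in bootstrap_keys:
        _check_slots(key, rollout[key], T)
        view[key] = [row[:T] for row in rollout[key]]
        view["bootstrap_" + key] = [row[T] for row in rollout[key]]
    return view
```

Storing every field with the same T + 1 length keeps the buffer uniform. The extra slot only matters for fields like obs and value that need a bootstrap entry.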
Policy Interface#
All policies must inherit from the Policy abstract base class:
from abc import ABC, abstractmethod

import torch
import torch.nn as nn


class Policy(nn.Module, ABC):
    device: torch.device

    def get_action(self, tensordict, deterministic: bool = False):
        """Samples an action and writes action, sample_log_prob, and value into the TensorDict."""
        ...

    @abstractmethod
    def forward(self, tensordict, deterministic: bool = False):
        """Writes action, sample_log_prob, and value into the TensorDict."""
        raise NotImplementedError

    @abstractmethod
    def get_value(self, tensordict):
        """Writes value estimate into the TensorDict."""
        raise NotImplementedError

    @abstractmethod
    def evaluate_actions(self, tensordict):
        """Returns a new TensorDict with log_prob, entropy, and value."""
        raise NotImplementedError
Available Policies#
ActorCritic: MLP-based Gaussian policy with learnable log_std. Requires external actor and critic modules (defined in the training config file). Used with PPO.
ActorOnly: Actor-only policy without a critic. Used with GRPO (group-relative advantage estimation).
VLAPlaceholderPolicy: Placeholder for Vision-Language-Action policies
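A Gaussian policy like ActorCritic needs the log-probability and entropy of actions in evaluate_actions. The sketch below gives the standard diagonal-Gaussian formulas in plain Python. It is not EmbodiChain's code, and it assumes independent action dimensions (no covariance) with the mean coming from the actor and a per-dimension log_std.

```python
import math


def diag_gaussian_log_prob(action, mean, log_std):
    """Log-density of an action under N(mean, diag(exp(log_std)) ** 2)."""
    if not (len(action) == len(mean) == len(log_std)):
        raise ValueError("action, mean, and log_std must have the same length")
    total = 0.0
    for a, mu, ls in zip(action, mean, log_std):
        std = math.exp(ls)
        # Per-dimension Gaussian log-density; independent dimensions add up.
        total += -0.5 * ((a - mu) / std) ** 2 - ls - 0.5 * math.log(2 * math.pi)
    return total


def diag_gaussian_entropy(log_std):
    """Entropy of a diagonal Gaussian; it depends only on the standard deviations."""
    return sum(ls + 0.5 * math.log(2 * math.pi * math.e) for ls in log_std)
```

Since the entropy grows linearly with log_std, the ent_coef bonus in the algorithm configs directly rewards keeping the learnable log_std from collapsing too early.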
Algorithms#
Available Algorithms#
PPO: Proximal Policy Optimization with GAE
GRPO: Group Relative Policy Optimization (no critic, step-wise returns, masked group normalization). Use the actor_only policy. Set kl_coef=0 for from-scratch training (CartPole, dense reward) and kl_coef=0.02 for VLA/LLM fine-tuning.
Adding a New Algorithm#
To add a new algorithm:
Create a new algorithm class in embodichain/agents/rl/algo/
Implement update(rollout) so that it consumes the shared rollout TensorDict
Register it in algo/__init__.py:
import torch
from tensordict import TensorDict

from embodichain.agents.rl.algo import BaseAlgorithm, register_algo


@register_algo("my_algo")
class MyAlgorithm(BaseAlgorithm):
    def __init__(self, cfg, policy):
        self.cfg = cfg
        self.policy = policy
        self.device = torch.device(cfg.device)

    def update(self, rollout: TensorDict):
        """Update the policy using a collected rollout."""
        # compute advantages / returns from rollout
        # optimize policy parameters
        return {"loss": 0.0}
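The register/build pattern behind register_algo and build_algo can be sketched in a few lines. This is a hypothetical, self-contained version for illustration, not EmbodiChain's registry. Like train.py, which lowercases the algorithm name before calling build_algo, it matches names case-insensitively.

```python
_ALGO_REGISTRY = {}


def register_algo_sketch(name):
    """Decorator that records a class under a lowercase name (hypothetical)."""
    def decorator(cls):
        key = name.lower()
        if key in _ALGO_REGISTRY:
            raise KeyError(f"algorithm {name!r} is already registered")
        _ALGO_REGISTRY[key] = cls
        return cls

    return decorator


def build_algo_sketch(name, *args, **kwargs):
    """Instantiate a registered class by name, listing the options on failure."""
    try:
        cls = _ALGO_REGISTRY[name.lower()]
    except KeyError:
        raise ValueError(
            f"unknown algorithm {name!r}; available: {sorted(_ALGO_REGISTRY)}"
        ) from None
    return cls(*args, **kwargs)


@register_algo_sketch("my_algo")
class MyAlgorithm:
    def __init__(self, cfg, policy):
        self.cfg = cfg
        self.policy = policy

    def update(self, rollout):
        return {"loss": 0.0}
```

Registration runs as a side effect of importing the module that defines the class. That is why new algorithms must be imported from the package's __init__.py before they can be built by name.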
Adding a New Policy#
To add a new policy:
Create a new policy class inheriting from the Policy abstract base class
Register it in models/__init__.py:
from embodichain.agents.rl.models import register_policy, Policy


@register_policy("my_policy")
class MyPolicy(Policy):
    def __init__(self, obs_dim, action_dim, device, config):
        super().__init__()
        self.device = device
        # Initialize your networks here

    def get_action(self, tensordict, deterministic=False):
        ...

    def forward(self, tensordict, deterministic=False):
        ...

    def get_value(self, tensordict):
        ...

    def evaluate_actions(self, tensordict):
        ...
Current built-in MLP policies use flattened observations in the training path. If your policy requires structured or multi-modal inputs, keep the richer obs_space interface and define a matching rollout/collector schema.
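Because the built-in MLP policies consume a flat vector, obs_dim in train.py is the last dimension of flatten_dict_observation(...) applied to a sample observation. The hypothetical pure-Python flattener below assumes nested dicts of 1-D values and visits keys in sorted order so that every step yields the same layout. The real function's key ordering may differ.

```python
def flatten_obs_sketch(obs):
    """Flatten a nested dict of 1-D numeric sequences into one float vector.

    Hypothetical: keys are visited in sorted order at every nesting level.
    """
    flat = []
    for key in sorted(obs):
        value = obs[key]
        if isinstance(value, dict):
            flat.extend(flatten_obs_sketch(value))
        else:
            flat.extend(float(x) for x in value)
    return flat
```

A deterministic key order matters more than the particular order chosen. If two steps flattened the same dict differently, the policy's input features would silently change meaning.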
Adding a New Environment#
To add a new RL environment:
Create an environment class inheriting from EmbodiedEnv (with the Action Manager configured for action preprocessing and a standardized info structure):
from embodichain.lab.gym.envs import EmbodiedEnv, EmbodiedEnvCfg
from embodichain.lab.gym.utils.registration import register_env
import torch


@register_env("MyTaskRL", override=True)
class MyTaskEnv(EmbodiedEnv):
    def __init__(self, cfg: EmbodiedEnvCfg | None = None, **kwargs):
        super().__init__(cfg, **kwargs)

    def compute_task_state(self, **kwargs):
        """Compute success/failure conditions and metrics."""
        is_success = ...  # Define success condition (a tensor)
        is_fail = torch.zeros_like(is_success)  # This example never fails early
        metrics = {"distance": ..., "error": ...}
        return is_success, is_fail, metrics
Configure the environment in your config file with actions and extensions:
"env": {
"id": "MyTaskRL",
"cfg": {
"num_envs": 4,
"actions": {
"delta_qpos": {
"func": "DeltaQposTerm",
"params": { "scale": 0.1 }
}
},
"extensions": {
"success_threshold": 0.05
}
}
}
The EmbodiedEnv with Action Manager provides:
Action Preprocessing: Configurable via actions (DeltaQposTerm, QposTerm, EefPoseTerm, etc.)
Standardized Info: Implements get_info() using the compute_task_state() template method
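For a push-style task, compute_task_state might compare the object's distance to the goal against the success_threshold from extensions. The sketch below assumes per-environment xyz positions given as plain lists and assumes the threshold reaches the environment unchanged from the config. compute_task_state_sketch is a hypothetical name, and it returns the same (is_success, is_fail, metrics) triple as the template method.

```python
import math


def compute_task_state_sketch(object_pos, goal_pos, success_threshold=0.05):
    """Hypothetical success/failure/metrics logic for a push-style task.

    object_pos, goal_pos: one xyz position (3 floats) per environment.
    Returns (is_success, is_fail, metrics) with one entry per environment.
    """
    distances = [math.dist(o, g) for o, g in zip(object_pos, goal_pos)]
    is_success = [d < success_threshold for d in distances]
    is_fail = [False] * len(distances)  # this sketch never fails early
    metrics = {"distance": distances}
    return is_success, is_fail, metrics
```

Reporting the raw distance as a metric alongside the boolean success flag makes it easier to see progress in the logs before any environment actually crosses the threshold.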
Best Practices#
Use EmbodiedEnv with Action Manager for RL Tasks: Inherit from EmbodiedEnv and configure actions in your config. The Action Manager handles action preprocessing (delta_qpos, qpos, qvel, qf, eef_pose) in a modular way.
Action Configuration: Use the actions field in your config file. Example: "delta_qpos": {"func": "DeltaQposTerm", "params": {"scale": 0.1}}.
Device Management: The device is single-sourced from trainer.device (in distributed runs, each process uses cuda:<LOCAL_RANK>). All components (trainer/algorithm/policy/env) share the same device.
Observation Format: Environments should provide consistent observation shapes and types (torch.float32) and a single done = terminated | truncated.
Algorithm Interface: Algorithms implement update(rollout) and consume a shared rollout TensorDict. Collection is handled by SyncCollector plus environment-side rollout writes in EmbodiedEnv.
Reward Configuration: Use the RewardManager in your environment config to define reward components. Organize reward components in the info["rewards"] dictionary and metrics in the info["metrics"] dictionary. The trainer performs dense per-step logging directly from environment info.
Template Methods: Override compute_task_state() to define success/failure conditions and metrics. Override check_truncated() for custom truncation logic.
Configuration: Keep all hyperparameters in the JSON or YAML config file. This makes experiments reproducible and easy to track.
Logging: Metrics are automatically logged to TensorBoard and, when use_wandb is enabled, to Weights & Biases. TensorBoard logs are written to outputs/<exp_name>_<timestamp>/logs/.
Checkpoints: Regular checkpoints are saved to outputs/<exp_name>_<timestamp>/checkpoints/. Use these to resume training or evaluate policies.
See Also#
Reinforcement Learning — RL module architecture and component reference
Embodied Environments — EmbodiedEnv configuration and Action Manager
Creating a Basic Environment — Creating basic Gymnasium environments
Creating a Modular Environment — Advanced modular environments with managers
Supported Tasks — List of available RL task environments