Reinforcement Learning Training#
This tutorial shows you how to train reinforcement learning agents using EmbodiChain’s RL framework. You’ll learn how to configure training via JSON, set up environments, policies, and algorithms, and launch training sessions.
Overview#
The RL framework provides a modular, extensible stack for robotics tasks:
Trainer: Orchestrates the training loop (calls algorithm for data collection and updates, handles logging/eval/save)
Algorithm: Controls data collection process (interacts with environment, fills buffer, computes advantages/returns) and updates the policy (e.g., PPO)
Policy: Neural network models implementing a unified interface (get_action/get_value/evaluate_actions)
Buffer: On-policy rollout storage and minibatch iterator (managed by algorithm)
Env Factory: Build environments from a JSON config via registry
Architecture#
The framework follows a clean separation of concerns between the components described above.
The core components and their relationships:
Trainer → Policy, Env, Algorithm (via callbacks for statistics)
Algorithm → Policy, RolloutBuffer (algorithm manages its own buffer)
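These relationships can be sketched as a schematic training loop. This is illustrative only: `DummyAlgo`, `collect`, and `train_loop` are made-up names, not framework APIs; the real Trainer and algorithms live in `embodichain.agents.rl`.

```python
# Schematic of the Trainer/Algorithm split described above: the trainer drives
# iterations while the algorithm owns data collection and policy updates.
# DummyAlgo and train_loop are illustrative names, not framework APIs.
class DummyAlgo:
    def collect(self, env, policy):
        # Fill a rollout buffer by stepping the environment (stubbed here).
        return {"obs": [0.0], "reward": [1.0]}

    def update(self, rollout):
        # Compute advantages/returns and optimize the policy (stubbed here).
        return {"loss": 0.0}


def train_loop(iterations, algo, env=None, policy=None):
    stats_history = []
    for _ in range(iterations):
        rollout = algo.collect(env, policy)          # data collection
        stats_history.append(algo.update(rollout))   # policy update
        # logging / eval / checkpointing would happen here
    return stats_history
```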
Configuration via JSON#
Training is configured via a JSON file that defines runtime settings, environment, policy, and algorithm parameters.
Example Configuration#
The configuration file (e.g., train_config.json) is located in configs/agents/rl/push_cube:
Example: train_config.json
{
  "trainer": {
    "exp_name": "push_cube_ppo",
    "gym_config": "configs/agents/rl/push_cube/gym_config.json",
    "seed": 42,
    "device": "cuda:0",
    "headless": true,
    "enable_rt": false,
    "gpu_id": 0,
    "num_envs": 64,
    "iterations": 1000,
    "buffer_size": 1024,
    "enable_eval": true,
    "num_eval_envs": 16,
    "num_eval_episodes": 3,
    "eval_freq": 100,
    "save_freq": 100,
    "use_wandb": true,
    "wandb_project_name": "embodichain-push_cube",
    "events": {
      "eval": {
        "record_camera": {
          "func": "record_camera_data_async",
          "mode": "interval",
          "interval_step": 1,
          "params": {
            "name": "main_cam",
            "resolution": [640, 480],
            "eye": [-1.4, 1.4, 2.0],
            "target": [0, 0, 0],
            "up": [0, 0, 1],
            "intrinsics": [600, 600, 320, 240],
            "save_path": "./outputs/videos_ppo1/eval"
          }
        }
      }
    }
  },
  "policy": {
    "name": "actor_critic",
    "actor": {
      "type": "mlp",
      "network_cfg": {
        "hidden_sizes": [256, 256],
        "activation": "relu"
      }
    },
    "critic": {
      "type": "mlp",
      "network_cfg": {
        "hidden_sizes": [256, 256],
        "activation": "relu"
      }
    }
  },
  "algorithm": {
    "name": "ppo",
    "cfg": {
      "learning_rate": 0.0001,
      "n_epochs": 10,
      "batch_size": 8192,
      "gamma": 0.99,
      "gae_lambda": 0.95,
      "clip_coef": 0.2,
      "ent_coef": 0.01,
      "vf_coef": 0.5,
      "max_grad_norm": 0.5
    }
  }
}
Configuration Sections#
Runtime Settings#
The trainer section controls experiment setup:
exp_name: Experiment name (used for output directories)
seed: Random seed for reproducibility
device: Runtime device string, e.g. "cpu" or "cuda:0"
headless: Whether to run simulation in headless mode
iterations: Number of training iterations
buffer_size: Steps collected per rollout (e.g., 1024)
eval_freq: Frequency of evaluation (in steps)
save_freq: Frequency of checkpoint saving (in steps)
use_wandb: Whether to enable Weights & Biases logging (set in JSON config)
wandb_project_name: Weights & Biases project name
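The training script reads this block with plain `dict.get` fallbacks, so any omitted key falls back to a built-in default. A minimal sketch (the inline JSON is an abbreviated stand-in for a full `train_config.json`):

```python
import json

# Minimal sketch of how train.py resolves trainer settings via dict.get
# defaults; the inline JSON is an abbreviated stand-in for train_config.json.
cfg = json.loads('{"trainer": {"exp_name": "push_cube_ppo", "seed": 42}}')
trainer_cfg = cfg["trainer"]

exp_name = trainer_cfg.get("exp_name", "generic_exp")
seed = int(trainer_cfg.get("seed", 1))
# buffer_size falls back to the legacy rollout_steps key, then to 2048
buffer_size = int(
    trainer_cfg.get("buffer_size", trainer_cfg.get("rollout_steps", 2048))
)
```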
Environment Configuration#
The env section defines the task environment:
id: Environment registry ID (e.g., “PushCubeRL”)
cfg: Environment-specific configuration parameters
For RL environments, use the actions field for action preprocessing and extensions for task-specific parameters:
actions: Action Manager config (e.g., DeltaQposTerm with scale)
extensions: Task-specific parameters (e.g., success_threshold)
Example:
"env": {
"id": "PushCubeRL",
"cfg": {
"num_envs": 4,
"actions": {
"delta_qpos": {
"func": "DeltaQposTerm",
"params": { "scale": 0.1 }
}
},
"extensions": {
"success_threshold": 0.1
}
}
}
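Conceptually, a DeltaQposTerm with scale 0.1 interprets the policy output as a scaled joint-position delta. A hedged sketch of that idea (`apply_delta_qpos` is an illustrative name; the real preprocessing lives in the Action Manager):

```python
import numpy as np

def apply_delta_qpos(current_qpos, action, scale=0.1):
    # Interpret the policy action as a scaled delta on joint positions.
    # Illustrative only; the framework's DeltaQposTerm does the real work.
    return np.asarray(current_qpos) + scale * np.asarray(action)
```

With `scale=0.1`, an action of all ones starting from zero joint positions moves each joint by 0.1.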
Policy Configuration#
The policy section defines the neural network policy:
name: Policy name (e.g., “actor_critic”, “vla”)
action_dim: Optional policy output action dimension. If omitted, it is inferred from env.action_space
actor: Actor network configuration (required for actor_critic)
critic: Critic network configuration (required for actor_critic)
Example:
"policy": {
"name": "actor_critic",
"actor": {
"type": "mlp",
"network_cfg": {
"hidden_sizes": [256, 256],
"activation": "relu"
}
},
"critic": {
"type": "mlp",
"network_cfg": {
"hidden_sizes": [256, 256],
"activation": "relu"
}
}
}
Algorithm Configuration#
The algorithm section defines the RL algorithm:
name: Algorithm name (e.g., “ppo”, “grpo”)
cfg: Algorithm-specific hyperparameters
PPO example:
"algorithm": {
"name": "ppo",
"cfg": {
"learning_rate": 0.0001,
"n_epochs": 10,
"batch_size": 64,
"gamma": 0.99,
"gae_lambda": 0.95,
"clip_coef": 0.2,
"ent_coef": 0.01,
"vf_coef": 0.5,
"max_grad_norm": 0.5
}
}
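The clip_coef above controls PPO's clipped surrogate objective. As a quick illustration of what that hyperparameter does (a NumPy sketch, not the framework's loss code):

```python
import numpy as np

def ppo_clip_loss(log_prob_new, log_prob_old, advantages, clip_coef=0.2):
    # Probability ratio between the new and old policies.
    ratio = np.exp(np.asarray(log_prob_new) - np.asarray(log_prob_old))
    unclipped = ratio * advantages
    clipped = np.clip(ratio, 1.0 - clip_coef, 1.0 + clip_coef) * advantages
    # PPO maximizes the minimum of the two terms, so the loss is its negation.
    return float(-np.minimum(unclipped, clipped).mean())
```

When the new and old log-probs match, the ratio is 1 and the loss reduces to the negated mean advantage; clipping only engages once the ratio leaves [1 - clip_coef, 1 + clip_coef].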
GRPO example (for Embodied AI / from-scratch training, e.g. CartPole):
"algorithm": {
"name": "grpo",
"cfg": {
"learning_rate": 0.0001,
"n_epochs": 10,
"batch_size": 8192,
"gamma": 0.99,
"clip_coef": 0.2,
"ent_coef": 0.001,
"kl_coef": 0,
"group_size": 4,
"eps": 1e-8,
"reset_every_rollout": true,
"max_grad_norm": 0.5,
"truncate_at_first_done": true
}
}
For GRPO: use actor_only policy. Set kl_coef=0 for from-scratch training; kl_coef=0.02 for VLA/LLM fine-tuning.
Training Script#
The training script (train.py) is located in embodichain/agents/rl/:
Code for train.py
# ----------------------------------------------------------------------------
# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ----------------------------------------------------------------------------

import argparse
import os
import time
from pathlib import Path

import numpy as np
import torch
import wandb
import json
from torch.utils.tensorboard import SummaryWriter
from copy import deepcopy

from embodichain.agents.rl.models import build_policy, get_registered_policy_names
from embodichain.agents.rl.models import build_mlp_from_cfg
from embodichain.agents.rl.algo import build_algo, get_registered_algo_names
from embodichain.agents.rl.utils import dict_to_tensordict, flatten_dict_observation
from embodichain.agents.rl.utils.trainer import Trainer
from embodichain.utils import logger
from embodichain.lab.gym.envs.tasks.rl import build_env
from embodichain.lab.gym.utils.gym_utils import config_to_cfg, DEFAULT_MANAGER_MODULES
from embodichain.utils.utility import load_json
from embodichain.utils.module_utils import find_function_from_modules
from embodichain.lab.sim import SimulationManagerCfg
from embodichain.lab.gym.envs.managers.cfg import EventCfg


def parse_args():
    """Parse command line arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True, help="Path to JSON config")
    parser.add_argument(
        "--distributed",
        action=argparse.BooleanOptionalAction,
        default=None,
        help="Enable or disable multi-GPU distributed training",
    )
    return parser.parse_args()


def train_from_config(config_path: str, distributed: bool | None = None):
    """Run training from a config file path.

    Args:
        config_path: Path to the JSON config file
        distributed: If True, run multi-GPU distributed training.
            If None, use trainer.distributed from config.
    """
    with open(config_path, "r") as f:
        cfg_json = json.load(f)

    trainer_cfg = cfg_json["trainer"]
    policy_block = cfg_json["policy"]
    algo_block = cfg_json["algorithm"]

    # Resolve distributed flag
    if distributed is None:
        distributed = bool(trainer_cfg.get("distributed", False))

    # Distributed setup
    rank = 0
    world_size = 1
    local_rank = 0
    if distributed:
        if not torch.distributed.is_available():
            raise RuntimeError(
                "Distributed training requested but torch.distributed is not available."
            )
        if not torch.cuda.is_available():
            raise RuntimeError(
                "Distributed training with NCCL backend requires CUDA, "
                "but torch.cuda.is_available() is False."
            )
        local_rank = int(os.environ.get("LOCAL_RANK", 0))
        if local_rank < 0 or local_rank >= torch.cuda.device_count():
            raise ValueError(
                f"LOCAL_RANK {local_rank} is out of range "
                f"(available GPUs: {torch.cuda.device_count()})."
            )
        torch.cuda.set_device(local_rank)
        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group(backend="nccl")
        rank = torch.distributed.get_rank()
        world_size = torch.distributed.get_world_size()

    # Runtime
    exp_name = trainer_cfg.get("exp_name", "generic_exp")
    seed = int(trainer_cfg.get("seed", 1))
    device_str = trainer_cfg.get("device", "cpu")
    if distributed:
        device_str = f"cuda:{local_rank}"
    iterations = int(trainer_cfg.get("iterations", 250))
    buffer_size = int(
        trainer_cfg.get("buffer_size", trainer_cfg.get("rollout_steps", 2048))
    )
    enable_eval = bool(trainer_cfg.get("enable_eval", False))
    eval_freq = int(trainer_cfg.get("eval_freq", 10000))
    save_freq = int(trainer_cfg.get("save_freq", 50000))
    num_eval_episodes = int(trainer_cfg.get("num_eval_episodes", 5))
    headless = bool(trainer_cfg.get("headless", True))
    enable_rt = bool(trainer_cfg.get("enable_rt", False))
    gpu_id = int(trainer_cfg.get("gpu_id", 0))
    num_envs = trainer_cfg.get("num_envs", None)
    wandb_project_name = trainer_cfg.get("wandb_project_name", "embodichain-generic")

    # Device
    if not isinstance(device_str, str):
        raise ValueError(
            f"runtime.device must be a string such as 'cpu' or 'cuda:0'. Got: {device_str!r}"
        )
    try:
        device = torch.device(device_str)
    except RuntimeError as exc:
        raise ValueError(
            f"Failed to parse runtime.device='{device_str}': {exc}"
        ) from exc

    if device.type == "cuda":
        if not torch.cuda.is_available():
            raise ValueError(
                "CUDA device requested but torch.cuda.is_available() is False."
            )
        index = (
            device.index if device.index is not None else torch.cuda.current_device()
        )
        device_count = torch.cuda.device_count()
        if index < 0 or index >= device_count:
            raise ValueError(
                f"CUDA device index {index} is out of range (available devices: {device_count})."
            )
        torch.cuda.set_device(index)
        device = torch.device(f"cuda:{index}")
    elif device.type != "cpu":
        raise ValueError(f"Unsupported device type: {device}")
    if rank == 0:
        logger.log_info(f"Device: {device}")
    if distributed and rank == 0:
        logger.log_info(f"Distributed training: world_size={world_size}")

    # Seeds
    effective_seed = seed + rank
    np.random.seed(effective_seed)
    torch.manual_seed(effective_seed)
    torch.backends.cudnn.deterministic = True
    if device.type == "cuda":
        torch.cuda.manual_seed_all(effective_seed)

    # Outputs
    if distributed:
        run_stamp = time.strftime("%Y%m%d_%H%M%S") if rank == 0 else None
        run_stamp_list = [run_stamp]
        torch.distributed.broadcast_object_list(run_stamp_list, src=0)
        run_stamp = run_stamp_list[0]
    else:
        run_stamp = time.strftime("%Y%m%d_%H%M%S")
    run_base = os.path.join("outputs", f"{exp_name}_{run_stamp}")
    log_dir = os.path.join(run_base, "logs")
    checkpoint_dir = os.path.join(run_base, "checkpoints")
    if rank == 0:
        os.makedirs(log_dir, exist_ok=True)
        os.makedirs(checkpoint_dir, exist_ok=True)
    writer = SummaryWriter(f"{log_dir}/{exp_name}") if rank == 0 else None

    # Initialize Weights & Biases (optional)
    use_wandb = trainer_cfg.get("use_wandb", False)
    if use_wandb and rank == 0:
        wandb.init(project=wandb_project_name, name=exp_name, config=cfg_json)

    gym_config_path = Path(trainer_cfg["gym_config"])
    if rank == 0:
        logger.log_info(f"Current working directory: {Path.cwd()}")

    gym_config_data = load_json(str(gym_config_path))
    gym_env_cfg = config_to_cfg(
        gym_config_data, manager_modules=DEFAULT_MANAGER_MODULES
    )
    if num_envs is not None:
        gym_env_cfg.num_envs = int(num_envs)

    # Ensure sim configuration mirrors runtime overrides
    if gym_env_cfg.sim_cfg is None:
        gym_env_cfg.sim_cfg = SimulationManagerCfg()
    if device.type == "cuda":
        gpu_index = device.index
        if gpu_index is None:
            gpu_index = torch.cuda.current_device()
        gym_env_cfg.sim_cfg.sim_device = torch.device(f"cuda:{gpu_index}")
        if hasattr(gym_env_cfg.sim_cfg, "gpu_id"):
            gym_env_cfg.sim_cfg.gpu_id = gpu_index
    else:
        gym_env_cfg.sim_cfg.sim_device = torch.device("cpu")
    gym_env_cfg.sim_cfg.headless = headless
    gym_env_cfg.sim_cfg.enable_rt = enable_rt
    gym_env_cfg.sim_cfg.gpu_id = local_rank if distributed else gpu_id

    if rank == 0:
        logger.log_info(
            f"Loaded gym_config from {gym_config_path} (env_id={gym_config_data['id']}, num_envs={gym_env_cfg.num_envs}, headless={gym_env_cfg.sim_cfg.headless}, enable_rt={gym_env_cfg.sim_cfg.enable_rt}, sim_device={gym_env_cfg.sim_cfg.sim_device})"
        )

    env = build_env(gym_config_data["id"], base_env_cfg=gym_env_cfg)
    sample_obs, _ = env.reset()
    sample_obs_td = dict_to_tensordict(sample_obs, device)
    obs_dim = flatten_dict_observation(sample_obs_td).shape[-1]
    flat_obs_space = env.flattened_observation_space

    # Create evaluation environment only if enabled
    eval_env = None
    num_eval_envs = trainer_cfg.get("num_eval_envs", 4)
    if enable_eval and rank == 0:
        eval_gym_env_cfg = deepcopy(gym_env_cfg)
        eval_gym_env_cfg.num_envs = num_eval_envs
        eval_gym_env_cfg.sim_cfg.headless = True
        eval_env = build_env(gym_config_data["id"], base_env_cfg=eval_gym_env_cfg)
        logger.log_info(
            f"Evaluation environment created (num_envs={num_eval_envs}, headless=True)"
        )

    # Build Policy via registry
    policy_name = policy_block["name"]
    env_action_dim = (
        env.get_wrapper_attr("action_manager").total_action_dim
        if env.get_wrapper_attr("action_manager") is not None
        else len(env.get_wrapper_attr("active_joint_ids"))
    )
    action_dim = policy_block.get("action_dim", env_action_dim)
    action_dim = int(action_dim)
    if action_dim != env_action_dim:
        raise ValueError(
            f"Configured policy.action_dim={action_dim} does not match env action dim {env_action_dim}."
        )
    # Build Policy via registry (actor/critic must be explicitly defined in JSON when using actor_critic/actor_only)
    if policy_name.lower() == "actor_critic":
        actor_cfg = policy_block.get("actor")
        critic_cfg = policy_block.get("critic")
        if actor_cfg is None or critic_cfg is None:
            raise ValueError(
                "ActorCritic requires 'actor' and 'critic' definitions in JSON (policy.actor / policy.critic)."
            )

        actor = build_mlp_from_cfg(actor_cfg, obs_dim, action_dim)
        critic = build_mlp_from_cfg(critic_cfg, obs_dim, 1)

        policy = build_policy(
            policy_block,
            flat_obs_space,
            env.action_space,
            device,
            actor=actor,
            critic=critic,
        )
    elif policy_name.lower() == "actor_only":
        actor_cfg = policy_block.get("actor")
        if actor_cfg is None:
            raise ValueError(
                "ActorOnly requires 'actor' definition in JSON (policy.actor)."
            )

        actor = build_mlp_from_cfg(actor_cfg, obs_dim, action_dim)

        policy = build_policy(
            policy_block,
            flat_obs_space,
            env.action_space,
            device,
            actor=actor,
        )
    else:
        policy = build_policy(
            policy_block, env.observation_space, env.action_space, device
        )

    # Build Algorithm via factory
    algo_name = algo_block["name"].lower()
    algo_cfg = algo_block["cfg"]
    algo = build_algo(
        algo_name,
        algo_cfg,
        policy,
        device,
        distributed=distributed,
    )

    # Build Trainer
    event_modules = [
        "embodichain.lab.gym.envs.managers.randomization",
        "embodichain.lab.gym.envs.managers.record",
        "embodichain.lab.gym.envs.managers.events",
    ]
    events_dict = trainer_cfg.get("events", {})
    train_event_cfg = {}
    eval_event_cfg = {}
    # Parse train events
    for event_name, event_info in events_dict.get("train", {}).items():
        event_func_str = event_info.get("func")
        mode = event_info.get("mode", "interval")
        params = event_info.get("params", {})
        interval_step = event_info.get("interval_step", 1)
        event_func = find_function_from_modules(
            event_func_str, event_modules, raise_if_not_found=True
        )
        train_event_cfg[event_name] = EventCfg(
            func=event_func,
            mode=mode,
            params=params,
            interval_step=interval_step,
        )
    # Parse eval events (only if evaluation is enabled)
    if enable_eval:
        for event_name, event_info in events_dict.get("eval", {}).items():
            event_func_str = event_info.get("func")
            mode = event_info.get("mode", "interval")
            params = event_info.get("params", {})
            interval_step = event_info.get("interval_step", 1)
            event_func = find_function_from_modules(
                event_func_str, event_modules, raise_if_not_found=True
            )
            eval_event_cfg[event_name] = EventCfg(
                func=event_func,
                mode=mode,
                params=params,
                interval_step=interval_step,
            )
    trainer = Trainer(
        policy=policy,
        env=env,
        algorithm=algo,
        buffer_size=buffer_size,
        batch_size=algo_cfg["batch_size"],
        writer=writer,
        eval_freq=eval_freq if enable_eval else 0,  # Disable eval if not enabled
        save_freq=save_freq,
        checkpoint_dir=checkpoint_dir,
        exp_name=exp_name,
        use_wandb=use_wandb,
        eval_env=eval_env,  # None if enable_eval=False
        event_cfg=train_event_cfg,
        eval_event_cfg=eval_event_cfg if (enable_eval and rank == 0) else {},
        num_eval_episodes=num_eval_episodes,
        distributed=distributed,
        rank=rank,
        world_size=world_size,
    )

    if rank == 0:
        logger.log_info("Generic training initialized")
        logger.log_info(f"Task: {type(env).__name__}")
        logger.log_info(
            f"Policy: {policy_name} (available: {get_registered_policy_names()})"
        )
        logger.log_info(
            f"Algorithm: {algo_name} (available: {get_registered_algo_names()})"
        )

    total_steps = int(iterations * buffer_size * env.num_envs * world_size)
    if rank == 0:
        logger.log_info(
            f"Total steps: {total_steps} (iterations≈{iterations}, world_size={world_size})"
        )

    try:
        trainer.train(total_steps)
    except KeyboardInterrupt:
        if rank == 0:
            logger.log_info("Training interrupted by user")
    finally:
        trainer.save_checkpoint()
        if writer is not None:
            writer.close()
        if use_wandb and rank == 0:
            try:
                wandb.finish()
            except Exception:
                pass

    # Clean up environments to prevent resource leaks
    try:
        if env is not None:
            env.close()
    except Exception as e:
        if rank == 0:
            logger.log_warning(f"Failed to close training environment: {e}")

    try:
        if eval_env is not None:
            eval_env.close()
    except Exception as e:
        if rank == 0:
            logger.log_warning(f"Failed to close evaluation environment: {e}")

    if distributed and torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()

    if rank == 0:
        logger.log_info("Training finished")


def main():
    """Main entry point for command-line training."""
    args = parse_args()
    train_from_config(args.config, distributed=args.distributed)


if __name__ == "__main__":
    main()
The Script Explained#
The training script performs the following steps:
Parse Configuration: Loads JSON config and extracts runtime/env/policy/algorithm blocks
Setup: Initializes device, seeds, output directories, TensorBoard, and Weights & Biases
Build Components: Environment via the build_env() factory, Policy via the build_policy() registry, Algorithm via the build_algo() factory
Create Trainer: Instantiates the Trainer with all components
Train: Runs the training loop until completion
Launching Training#
To start training, run:
python -m embodichain.agents.rl.train --config configs/agents/rl/push_cube/train_config.json
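For multi-GPU training, the script reads LOCAL_RANK and initializes an NCCL process group, which matches the environment torchrun sets up. A plausible launch under that assumption (exact flags depend on your setup):

```shell
# Assumed launcher: torchrun sets LOCAL_RANK for each process, which
# train.py uses to pick its GPU. Adjust --nproc_per_node to your GPU count.
torchrun --nproc_per_node=2 -m embodichain.agents.rl.train \
    --config configs/agents/rl/push_cube/train_config.json --distributed
```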
Outputs#
All outputs are written to ./outputs/<exp_name>_<timestamp>/:
logs/: TensorBoard logs
checkpoints/: Model checkpoints
Training Process#
The training process follows this sequence:
Rollout Phase:
SyncCollector interacts with the environment and writes policy-side fields into a shared rollout TensorDict with a uniform [N, T + 1] layout. EmbodiedEnv writes environment-side step fields such as reward, done, terminated, and truncated into the same rollout via set_rollout_buffer(). The final slot of transition-only fields is reserved as padding, while obs[:, -1] and value[:, -1] remain valid bootstrap data.
Advantage/Return Computation: The algorithm computes advantages and returns from the collected rollout (e.g. GAE for PPO, step-wise group normalization for GRPO) and converts it to a transition-aligned view over the valid first T steps before minibatch optimization.
Update Phase: The algorithm updates the policy with update(rollout)
Logging: Trainer logs training losses and aggregated metrics to TensorBoard and Weights & Biases
Evaluation (periodic): Trainer evaluates the current policy
Checkpointing (periodic): Trainer saves model checkpoints
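The PPO advantage/return computation over the [N, T + 1] rollout layout can be sketched as follows, using the bootstrap value in the final slot. This is a NumPy illustration of GAE under assumed field shapes, not the framework's implementation:

```python
import numpy as np

def compute_gae(rewards, values, dones, gamma=0.99, gae_lambda=0.95):
    # rewards, dones: [N, T]; values: [N, T + 1], where values[:, -1] is the
    # bootstrap value kept in the extra slot described above.
    N, T = rewards.shape
    advantages = np.zeros((N, T))
    last_adv = np.zeros(N)
    for t in reversed(range(T)):
        not_done = 1.0 - dones[:, t]
        # TD error for step t, bootstrapping from the next value.
        delta = rewards[:, t] + gamma * values[:, t + 1] * not_done - values[:, t]
        # Exponentially weighted sum of TD errors (GAE recursion).
        last_adv = delta + gamma * gae_lambda * not_done * last_adv
        advantages[:, t] = last_adv
    returns = advantages + values[:, :T]
    return advantages, returns
```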
Policy Interface#
All policies must inherit from the Policy abstract base class:
from abc import ABC, abstractmethod

import torch
import torch.nn as nn


class Policy(nn.Module, ABC):
    device: torch.device

    def get_action(self, tensordict, deterministic: bool = False):
        """Sample an action, writing action, sample_log_prob, and value into the TensorDict."""
        ...

    @abstractmethod
    def forward(self, tensordict, deterministic: bool = False):
        """Writes action, sample_log_prob, and value into the TensorDict."""
        raise NotImplementedError

    @abstractmethod
    def get_value(self, tensordict):
        """Writes value estimate into the TensorDict."""
        raise NotImplementedError

    @abstractmethod
    def evaluate_actions(self, tensordict):
        """Returns a new TensorDict with log_prob, entropy, and value."""
        raise NotImplementedError
Available Policies#
ActorCritic: MLP-based Gaussian policy with learnable log_std. Requires external actor and critic modules to be provided (defined in JSON config). Used with PPO.
ActorOnly: Actor-only policy without a critic. Used with GRPO (group-relative advantage estimation).
VLAPlaceholderPolicy: Placeholder for Vision-Language-Action policies
Algorithms#
Available Algorithms#
PPO: Proximal Policy Optimization with GAE
GRPO: Group Relative Policy Optimization (no critic, step-wise returns, masked group normalization). Use the actor_only policy. Set kl_coef=0 for from-scratch training (CartPole, dense reward); kl_coef=0.02 for VLA/LLM fine-tuning.
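The group normalization at GRPO's core amounts to standardizing returns within each group of group_size environments. A minimal sketch of that idea (the masking tied to truncate_at_first_done is omitted here, and this is not the framework's implementation):

```python
import numpy as np

def group_relative_advantage(returns, group_size=4, eps=1e-8):
    # Normalize each group's returns by that group's mean and std,
    # mirroring GRPO's group-relative baseline. The eps term matches the
    # "eps" hyperparameter in the GRPO config example above.
    groups = np.asarray(returns, dtype=float).reshape(-1, group_size)
    mean = groups.mean(axis=1, keepdims=True)
    std = groups.std(axis=1, keepdims=True)
    return ((groups - mean) / (std + eps)).reshape(-1)
```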
Adding a New Algorithm#
To add a new algorithm:
Create a new algorithm class in embodichain/agents/rl/algo/
Implement update(rollout) and consume the shared rollout TensorDict
Register in algo/__init__.py:
import torch
from tensordict import TensorDict

from embodichain.agents.rl.algo import BaseAlgorithm, register_algo


@register_algo("my_algo")
class MyAlgorithm(BaseAlgorithm):
    def __init__(self, cfg, policy):
        self.cfg = cfg
        self.policy = policy
        self.device = torch.device(cfg.device)

    def update(self, rollout: TensorDict):
        """Update the policy using a collected rollout."""
        # compute advantages / returns from rollout
        # optimize policy parameters
        return {"loss": 0.0}
Adding a New Policy#
To add a new policy:
Create a new policy class inheriting from the Policy abstract base class
Register in models/__init__.py:
from embodichain.agents.rl.models import register_policy, Policy


@register_policy("my_policy")
class MyPolicy(Policy):
    def __init__(self, obs_dim, action_dim, device, config):
        super().__init__()
        self.device = device
        # Initialize your networks here

    def get_action(self, tensordict, deterministic=False):
        ...

    def forward(self, tensordict, deterministic=False):
        ...

    def get_value(self, tensordict):
        ...

    def evaluate_actions(self, tensordict):
        ...
Current built-in MLP policies use flattened observations in the training path. If your policy requires structured or multi-modal inputs, keep the richer obs_space interface and define a matching rollout/collector schema.
Adding a New Environment#
To add a new RL environment:
Create an environment class inheriting from EmbodiedEnv (with Action Manager configured for action preprocessing and a standardized info structure):
import torch

from embodichain.lab.gym.envs import EmbodiedEnv, EmbodiedEnvCfg
from embodichain.lab.gym.utils.registration import register_env


@register_env("MyTaskRL", override=True)
class MyTaskEnv(EmbodiedEnv):
    def __init__(self, cfg: EmbodiedEnvCfg = None, **kwargs):
        super().__init__(cfg, **kwargs)

    def compute_task_state(self, **kwargs):
        """Compute success/failure conditions and metrics."""
        is_success = ...  # Define success condition
        is_fail = torch.zeros_like(is_success)
        metrics = {"distance": ..., "error": ...}
        return is_success, is_fail, metrics
Configure the environment in your JSON config with actions and extensions:
"env": {
"id": "MyTaskRL",
"cfg": {
"num_envs": 4,
"actions": {
"delta_qpos": {
"func": "DeltaQposTerm",
"params": { "scale": 0.1 }
}
},
"extensions": {
"success_threshold": 0.05
}
}
}
The EmbodiedEnv with Action Manager provides:
Action Preprocessing: Configurable via actions (DeltaQposTerm, QposTerm, EefPoseTerm, etc.)
Standardized Info: Implements get_info() using the compute_task_state() template method
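As a concrete illustration, a distance-based compute_task_state() might look like the following. The tensor arguments are hypothetical names for this sketch; the threshold plays the role of the success_threshold extension shown earlier:

```python
import torch

def compute_task_state(eef_pos, goal_pos, success_threshold=0.05):
    # Hypothetical example: success when the end-effector is within
    # success_threshold of the goal; no failure condition for this task.
    distance = torch.linalg.norm(eef_pos - goal_pos, dim=-1)
    is_success = distance < success_threshold
    is_fail = torch.zeros_like(is_success)
    metrics = {"distance": distance}
    return is_success, is_fail, metrics
```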
Best Practices#
Use EmbodiedEnv with Action Manager for RL Tasks: Inherit from EmbodiedEnv and configure actions in your config. The Action Manager handles action preprocessing (delta_qpos, qpos, qvel, qf, eef_pose) in a modular way.
Action Configuration: Use the actions field in your JSON config. Example: "delta_qpos": {"func": "DeltaQposTerm", "params": {"scale": 0.1}}.
Device Management: The device is single-sourced from the runtime device setting (trainer.device in the JSON config). All components (trainer/algorithm/policy/env) share the same device.
Observation Format: Environments should provide consistent observation shapes/types (torch.float32) and a single done = terminated | truncated.
Algorithm Interface: Algorithms implement update(rollout) and consume a shared rollout TensorDict. Collection is handled by SyncCollector plus environment-side rollout writes in EmbodiedEnv.
Reward Configuration: Use the RewardManager in your environment config to define reward components. Organize reward components in the info["rewards"] dictionary and metrics in the info["metrics"] dictionary. The trainer performs dense per-step logging directly from environment info.
Template Methods: Override compute_task_state() to define success/failure conditions and metrics. Override check_truncated() for custom truncation logic.
Configuration: Use JSON for all hyperparameters. This makes experiments reproducible and easy to track.
Logging: Metrics are automatically logged to TensorBoard and Weights & Biases. Check outputs/<exp_name>_<timestamp>/logs/ for TensorBoard logs.
Checkpoints: Regular checkpoints are saved to outputs/<exp_name>_<timestamp>/checkpoints/. Use these to resume training or evaluate policies.