Reinforcement Learning Training#
This tutorial shows you how to train reinforcement learning agents using EmbodiChain’s RL framework. You’ll learn how to configure training via JSON, set up environments, policies, and algorithms, and launch training sessions.
Overview#
The RL framework provides a modular, extensible stack for robotics tasks:
Trainer: Orchestrates the training loop (calls algorithm for data collection and updates, handles logging/eval/save)
Algorithm: Controls data collection process (interacts with environment, fills buffer, computes advantages/returns) and updates the policy (e.g., PPO)
Policy: Neural network models implementing a unified interface (get_action/get_value/evaluate_actions)
Buffer: On-policy rollout storage and minibatch iterator (managed by algorithm)
Env Factory: Build environments from a JSON config via registry
Architecture#
The framework follows a clean separation of concerns. The core components and their relationships:
Trainer → Policy, Env, Algorithm (via callbacks for statistics)
Algorithm → Policy, RolloutBuffer (algorithm manages its own buffer)
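These relationships can be pictured with a minimal runnable sketch. The classes below are hypothetical stubs, not the real EmbodiChain `Trainer`/`Algorithm`/`Policy` classes; they only illustrate the division of labor (trainer drives iterations, algorithm owns the buffer and the update, policy is a callable model):

```python
class ToyAlgorithm:
    def __init__(self):
        self.buffer = []  # stands in for RolloutBuffer (managed by the algorithm)

    def collect_rollout(self, env, policy, num_steps):
        # Interact with the env and fill the buffer (advantages/returns would go here too)
        self.buffer = [env.step(policy.get_action(obs=None)) for _ in range(num_steps)]
        return {}

    def update(self):
        # Consume the buffer, return loss statistics for the trainer to log
        loss = sum(self.buffer) / len(self.buffer)
        self.buffer = []
        return {"loss": loss}


class ToyPolicy:
    def get_action(self, obs, deterministic=False):
        return 0.0


class ToyEnv:
    def step(self, action):
        return 1.0  # pretend per-step reward


def train(env, policy, algorithm, iterations, rollout_steps):
    stats = []
    for _ in range(iterations):
        algorithm.collect_rollout(env, policy, rollout_steps)  # rollout phase
        stats.append(algorithm.update())                       # update phase
    return stats


stats = train(ToyEnv(), ToyPolicy(), ToyAlgorithm(), iterations=3, rollout_steps=4)
```

The real trainer additionally handles logging, evaluation, and checkpointing around this loop.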
Configuration via JSON#
Training is configured via a JSON file that defines runtime settings, environment, policy, and algorithm parameters.
Example Configuration#
The configuration file (e.g., train_config.json) is located in configs/agents/rl/push_cube:
Example: train_config.json
{
  "trainer": {
    "exp_name": "push_cube_ppo",
    "gym_config": "configs/agents/rl/push_cube/gym_config.json",
    "seed": 42,
    "device": "cuda:0",
    "headless": true,
    "enable_rt": false,
    "gpu_id": 0,
    "num_envs": 64,
    "iterations": 1000,
    "rollout_steps": 1024,
    "enable_eval": true,
    "num_eval_envs": 16,
    "num_eval_episodes": 3,
    "eval_freq": 2,
    "save_freq": 200,
    "use_wandb": false,
    "wandb_project_name": "embodichain-push_cube",
    "events": {
      "eval": {
        "record_camera": {
          "func": "record_camera_data_async",
          "mode": "interval",
          "interval_step": 1,
          "params": {
            "name": "main_cam",
            "resolution": [640, 480],
            "eye": [-1.4, 1.4, 2.0],
            "target": [0, 0, 0],
            "up": [0, 0, 1],
            "intrinsics": [600, 600, 320, 240],
            "save_path": "./outputs/videos/eval"
          }
        }
      }
    }
  },
  "policy": {
    "name": "actor_critic",
    "actor": {
      "type": "mlp",
      "network_cfg": {
        "hidden_sizes": [256, 256],
        "activation": "relu"
      }
    },
    "critic": {
      "type": "mlp",
      "network_cfg": {
        "hidden_sizes": [256, 256],
        "activation": "relu"
      }
    }
  },
  "algorithm": {
    "name": "ppo",
    "cfg": {
      "learning_rate": 0.0001,
      "n_epochs": 10,
      "batch_size": 8192,
      "gamma": 0.99,
      "gae_lambda": 0.95,
      "clip_coef": 0.2,
      "ent_coef": 0.01,
      "vf_coef": 0.5,
      "max_grad_norm": 0.5
    }
  }
}
Configuration Sections#
Runtime Settings#
The trainer section controls experiment setup:
exp_name: Experiment name (used for output directories)
gym_config: Path to the environment's gym config JSON
seed: Random seed for reproducibility
device: Compute device string, e.g. "cpu" or "cuda:0"
headless: Whether to run simulation in headless mode
num_envs: Number of parallel training environments
iterations: Number of training iterations
rollout_steps: Steps per rollout (e.g., 1024)
eval_freq: Frequency of evaluation (in training iterations)
save_freq: Frequency of checkpoint saving
use_wandb: Whether to enable Weights & Biases logging (set in JSON config)
wandb_project_name: Weights & Biases project name
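The iterations, rollout_steps, and num_envs settings jointly determine the total number of environment steps: train.py computes total_steps = iterations * rollout_steps * env.num_envs. With the example configuration's values:

```python
iterations = 1000     # from the example config
rollout_steps = 1024
num_envs = 64

# Mirrors the computation in train.py:
# total_steps = int(iterations * rollout_steps * env.num_envs)
total_steps = iterations * rollout_steps * num_envs
print(total_steps)  # 65536000
```

So the example schedules roughly 65.5 million environment steps in total.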
Environment Configuration#
The env section defines the task environment:
id: Environment registry ID (e.g., “PushCubeRL”)
cfg: Environment-specific configuration parameters
For RL environments, use the actions field for action preprocessing and extensions for task-specific parameters:
actions: Action Manager config (e.g., DeltaQposTerm with scale)
extensions: Task-specific parameters (e.g., success_threshold)
Example:
"env": {
"id": "PushCubeRL",
"cfg": {
"num_envs": 4,
"actions": {
"delta_qpos": {
"func": "DeltaQposTerm",
"params": { "scale": 0.1 }
}
},
"extensions": {
"success_threshold": 0.1
}
}
}
Policy Configuration#
The policy section defines the neural network policy:
name: Policy name (e.g., “actor_critic”, “vla”)
cfg: Policy-specific hyperparameters (empty for actor_critic)
actor: Actor network configuration (required for actor_critic)
critic: Critic network configuration (required for actor_critic)
Example:
"policy": {
"name": "actor_critic",
"cfg": {},
"actor": {
"type": "mlp",
"hidden_sizes": [256, 256],
"activation": "relu"
},
"critic": {
"type": "mlp",
"hidden_sizes": [256, 256],
"activation": "relu"
}
}
Algorithm Configuration#
The algorithm section defines the RL algorithm:
name: Algorithm name (e.g., “ppo”, “grpo”)
cfg: Algorithm-specific hyperparameters
PPO example:
"algorithm": {
"name": "ppo",
"cfg": {
"learning_rate": 0.0001,
"n_epochs": 10,
"batch_size": 64,
"gamma": 0.99,
"gae_lambda": 0.95,
"clip_coef": 0.2,
"ent_coef": 0.01,
"vf_coef": 0.5,
"max_grad_norm": 0.5
}
}
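The clip_coef hyperparameter above enters the standard PPO clipped surrogate objective. A minimal per-sample sketch in plain Python (illustrative only, not the framework's batched implementation):

```python
import math


def ppo_clip_loss(log_prob_new, log_prob_old, advantage, clip_coef=0.2):
    """Negative clipped surrogate objective for a single sample."""
    ratio = math.exp(log_prob_new - log_prob_old)
    unclipped = ratio * advantage
    # Clamp the probability ratio to [1 - clip_coef, 1 + clip_coef]
    clipped = max(min(ratio, 1.0 + clip_coef), 1.0 - clip_coef) * advantage
    # PPO takes the pessimistic (minimum) of the two, negated for gradient descent
    return -min(unclipped, clipped)


# With an unchanged policy (ratio = 1), the loss is just the negated advantage
print(ppo_clip_loss(0.0, 0.0, advantage=2.0))  # -2.0
```

In the full objective this term is combined with the value loss (weighted by vf_coef) and an entropy bonus (weighted by ent_coef), with gradients clipped to max_grad_norm.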
GRPO example (for Embodied AI / from-scratch training, e.g. CartPole):
"algorithm": {
"name": "grpo",
"cfg": {
"learning_rate": 0.0001,
"n_epochs": 10,
"batch_size": 8192,
"gamma": 0.99,
"clip_coef": 0.2,
"ent_coef": 0.001,
"kl_coef": 0,
"group_size": 4,
"eps": 1e-8,
"reset_every_rollout": true,
"max_grad_norm": 0.5,
"truncate_at_first_done": true
}
}
For GRPO: use actor_only policy. Set kl_coef=0 for from-scratch training; kl_coef=0.02 for VLA/LLM fine-tuning.
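How group_size and eps interact can be illustrated with a plain-Python sketch of group-relative advantage normalization. This shows the general GRPO idea; the framework's masked, step-wise version differs in detail:

```python
def group_relative_advantages(returns, group_size=4, eps=1e-8):
    """Normalize each return against its own group's mean and std (GRPO-style)."""
    advantages = []
    for i in range(0, len(returns), group_size):
        group = returns[i:i + group_size]
        mean = sum(group) / len(group)
        std = (sum((r - mean) ** 2 for r in group) / len(group)) ** 0.5
        # eps guards against division by zero when all returns in a group are equal
        advantages.extend((r - mean) / (std + eps) for r in group)
    return advantages


# In a group of [1, 1, 1, 3], only the outlier gets a large positive advantage
print(group_relative_advantages([1.0, 1.0, 1.0, 3.0]))
```

Because advantages are computed relative to the group, no critic (value network) is needed, which is why GRPO pairs with the actor_only policy.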
Training Script#
The training script (train.py) is located in embodichain/agents/rl/:
Code for train.py
# ----------------------------------------------------------------------------
# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ----------------------------------------------------------------------------

import argparse
import os
import time
from pathlib import Path

import numpy as np
import torch
import wandb
import json
from torch.utils.tensorboard import SummaryWriter
from copy import deepcopy

from embodichain.agents.rl.models import build_policy, get_registered_policy_names
from embodichain.agents.rl.models import build_mlp_from_cfg
from embodichain.agents.rl.algo import build_algo, get_registered_algo_names
from embodichain.agents.rl.utils.trainer import Trainer
from embodichain.utils import logger
from embodichain.lab.gym.envs.tasks.rl import build_env
from embodichain.lab.gym.utils.gym_utils import config_to_cfg, DEFAULT_MANAGER_MODULES
from embodichain.utils.utility import load_json
from embodichain.utils.module_utils import find_function_from_modules
from embodichain.lab.sim import SimulationManagerCfg
from embodichain.lab.gym.envs.managers.cfg import EventCfg


def parse_args():
    """Parse command line arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True, help="Path to JSON config")
    return parser.parse_args()


def train_from_config(config_path: str):
    """Run training from a config file path.

    Args:
        config_path: Path to the JSON config file
    """
    with open(config_path, "r") as f:
        cfg_json = json.load(f)

    trainer_cfg = cfg_json["trainer"]
    policy_block = cfg_json["policy"]
    algo_block = cfg_json["algorithm"]

    # Runtime
    exp_name = trainer_cfg.get("exp_name", "generic_exp")
    seed = int(trainer_cfg.get("seed", 1))
    device_str = trainer_cfg.get("device", "cpu")
    iterations = int(trainer_cfg.get("iterations", 250))
    rollout_steps = int(trainer_cfg.get("rollout_steps", 2048))
    enable_eval = bool(trainer_cfg.get("enable_eval", False))
    eval_freq = int(trainer_cfg.get("eval_freq", 10000))
    save_freq = int(trainer_cfg.get("save_freq", 50000))
    num_eval_episodes = int(trainer_cfg.get("num_eval_episodes", 5))
    headless = bool(trainer_cfg.get("headless", True))
    enable_rt = bool(trainer_cfg.get("enable_rt", False))
    gpu_id = int(trainer_cfg.get("gpu_id", 0))
    num_envs = trainer_cfg.get("num_envs", None)
    wandb_project_name = trainer_cfg.get("wandb_project_name", "embodichain-generic")

    # Device
    if not isinstance(device_str, str):
        raise ValueError(
            f"trainer.device must be a string such as 'cpu' or 'cuda:0'. Got: {device_str!r}"
        )
    try:
        device = torch.device(device_str)
    except RuntimeError as exc:
        raise ValueError(
            f"Failed to parse trainer.device='{device_str}': {exc}"
        ) from exc

    if device.type == "cuda":
        if not torch.cuda.is_available():
            raise ValueError(
                "CUDA device requested but torch.cuda.is_available() is False."
            )
        index = (
            device.index if device.index is not None else torch.cuda.current_device()
        )
        device_count = torch.cuda.device_count()
        if index < 0 or index >= device_count:
            raise ValueError(
                f"CUDA device index {index} is out of range (available devices: {device_count})."
            )
        torch.cuda.set_device(index)
        device = torch.device(f"cuda:{index}")
    elif device.type != "cpu":
        raise ValueError(f"Unsupported device type: {device}")
    logger.log_info(f"Device: {device}")

    # Seeds
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    if device.type == "cuda":
        torch.cuda.manual_seed_all(seed)

    # Outputs
    run_stamp = time.strftime("%Y%m%d_%H%M%S")
    run_base = os.path.join("outputs", f"{exp_name}_{run_stamp}")
    log_dir = os.path.join(run_base, "logs")
    checkpoint_dir = os.path.join(run_base, "checkpoints")
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(checkpoint_dir, exist_ok=True)
    writer = SummaryWriter(f"{log_dir}/{exp_name}")

    # Initialize Weights & Biases (optional)
    use_wandb = trainer_cfg.get("use_wandb", False)
    if use_wandb:
        wandb.init(project=wandb_project_name, name=exp_name, config=cfg_json)

    gym_config_path = Path(trainer_cfg["gym_config"])
    logger.log_info(f"Current working directory: {Path.cwd()}")

    gym_config_data = load_json(str(gym_config_path))
    gym_env_cfg = config_to_cfg(
        gym_config_data, manager_modules=DEFAULT_MANAGER_MODULES
    )
    if num_envs is not None:
        gym_env_cfg.num_envs = int(num_envs)

    # Ensure sim configuration mirrors runtime overrides
    if gym_env_cfg.sim_cfg is None:
        gym_env_cfg.sim_cfg = SimulationManagerCfg()
    if device.type == "cuda":
        gpu_index = device.index
        if gpu_index is None:
            gpu_index = torch.cuda.current_device()
        gym_env_cfg.sim_cfg.sim_device = torch.device(f"cuda:{gpu_index}")
        if hasattr(gym_env_cfg.sim_cfg, "gpu_id"):
            gym_env_cfg.sim_cfg.gpu_id = gpu_index
    else:
        gym_env_cfg.sim_cfg.sim_device = torch.device("cpu")
    gym_env_cfg.sim_cfg.headless = headless
    gym_env_cfg.sim_cfg.enable_rt = enable_rt
    gym_env_cfg.sim_cfg.gpu_id = gpu_id

    logger.log_info(
        f"Loaded gym_config from {gym_config_path} (env_id={gym_config_data['id']}, num_envs={gym_env_cfg.num_envs}, headless={gym_env_cfg.sim_cfg.headless}, enable_rt={gym_env_cfg.sim_cfg.enable_rt}, sim_device={gym_env_cfg.sim_cfg.sim_device})"
    )

    env = build_env(gym_config_data["id"], base_env_cfg=gym_env_cfg)

    # Create evaluation environment only if enabled
    eval_env = None
    num_eval_envs = trainer_cfg.get("num_eval_envs", 4)
    if enable_eval:
        eval_gym_env_cfg = deepcopy(gym_env_cfg)
        eval_gym_env_cfg.num_envs = num_eval_envs
        eval_gym_env_cfg.sim_cfg.headless = True
        eval_env = build_env(gym_config_data["id"], base_env_cfg=eval_gym_env_cfg)
        logger.log_info(
            f"Evaluation environment created (num_envs={num_eval_envs}, headless=True)"
        )

    # Build Policy via registry (actor/critic must be explicitly defined in JSON
    # when using actor_critic/actor_only)
    policy_name = policy_block["name"]
    if policy_name.lower() == "actor_critic":
        # Get observation dimension from the flattened observation space
        # (flattened_observation_space returns a Box space for RL training)
        obs_dim = env.flattened_observation_space.shape[-1]
        action_dim = env.action_space.shape[-1]

        actor_cfg = policy_block.get("actor")
        critic_cfg = policy_block.get("critic")
        if actor_cfg is None or critic_cfg is None:
            raise ValueError(
                "ActorCritic requires 'actor' and 'critic' definitions in JSON (policy.actor / policy.critic)."
            )

        actor = build_mlp_from_cfg(actor_cfg, obs_dim, action_dim)
        critic = build_mlp_from_cfg(critic_cfg, obs_dim, 1)

        policy = build_policy(
            policy_block,
            env.flattened_observation_space,
            env.action_space,
            device,
            actor=actor,
            critic=critic,
        )
    elif policy_name.lower() == "actor_only":
        obs_dim = env.flattened_observation_space.shape[-1]
        action_dim = env.action_space.shape[-1]

        actor_cfg = policy_block.get("actor")
        if actor_cfg is None:
            raise ValueError(
                "ActorOnly requires 'actor' definition in JSON (policy.actor)."
            )

        actor = build_mlp_from_cfg(actor_cfg, obs_dim, action_dim)

        policy = build_policy(
            policy_block,
            env.flattened_observation_space,
            env.action_space,
            device,
            actor=actor,
        )
    else:
        policy = build_policy(
            policy_block, env.flattened_observation_space, env.action_space, device
        )

    # Build Algorithm via factory
    algo_name = algo_block["name"].lower()
    algo_cfg = algo_block["cfg"]
    algo = build_algo(algo_name, algo_cfg, policy, device)

    # Build Trainer
    event_modules = [
        "embodichain.lab.gym.envs.managers.randomization",
        "embodichain.lab.gym.envs.managers.record",
        "embodichain.lab.gym.envs.managers.events",
    ]
    events_dict = trainer_cfg.get("events", {})
    train_event_cfg = {}
    eval_event_cfg = {}
    # Parse train events
    for event_name, event_info in events_dict.get("train", {}).items():
        event_func_str = event_info.get("func")
        mode = event_info.get("mode", "interval")
        params = event_info.get("params", {})
        interval_step = event_info.get("interval_step", 1)
        event_func = find_function_from_modules(
            event_func_str, event_modules, raise_if_not_found=True
        )
        train_event_cfg[event_name] = EventCfg(
            func=event_func,
            mode=mode,
            params=params,
            interval_step=interval_step,
        )
    # Parse eval events (only if evaluation is enabled)
    if enable_eval:
        for event_name, event_info in events_dict.get("eval", {}).items():
            event_func_str = event_info.get("func")
            mode = event_info.get("mode", "interval")
            params = event_info.get("params", {})
            interval_step = event_info.get("interval_step", 1)
            event_func = find_function_from_modules(
                event_func_str, event_modules, raise_if_not_found=True
            )
            eval_event_cfg[event_name] = EventCfg(
                func=event_func,
                mode=mode,
                params=params,
                interval_step=interval_step,
            )
    trainer = Trainer(
        policy=policy,
        env=env,
        algorithm=algo,
        num_steps=rollout_steps,
        batch_size=algo_cfg["batch_size"],
        writer=writer,
        eval_freq=eval_freq if enable_eval else 0,  # Disable eval if not enabled
        save_freq=save_freq,
        checkpoint_dir=checkpoint_dir,
        exp_name=exp_name,
        use_wandb=use_wandb,
        eval_env=eval_env,  # None if enable_eval=False
        event_cfg=train_event_cfg,
        eval_event_cfg=eval_event_cfg if enable_eval else {},
        num_eval_episodes=num_eval_episodes,
    )

    logger.log_info("Generic training initialized")
    logger.log_info(f"Task: {type(env).__name__}")
    logger.log_info(
        f"Policy: {policy_name} (available: {get_registered_policy_names()})"
    )
    logger.log_info(
        f"Algorithm: {algo_name} (available: {get_registered_algo_names()})"
    )

    total_steps = int(iterations * rollout_steps * env.num_envs)
    logger.log_info(f"Total steps: {total_steps} (iterations≈{iterations})")

    try:
        trainer.train(total_steps)
    except KeyboardInterrupt:
        logger.log_info("Training interrupted by user")
    finally:
        trainer.save_checkpoint()
        writer.close()
        if use_wandb:
            try:
                wandb.finish()
            except Exception:
                pass

        # Clean up environments to prevent resource leaks
        try:
            if env is not None:
                env.close()
        except Exception as e:
            logger.log_warning(f"Failed to close training environment: {e}")

        try:
            if eval_env is not None:
                eval_env.close()
        except Exception as e:
            logger.log_warning(f"Failed to close evaluation environment: {e}")

        logger.log_info("Training finished")


def main():
    """Main entry point for command-line training."""
    args = parse_args()
    train_from_config(args.config)


if __name__ == "__main__":
    main()
The Script Explained#
The training script performs the following steps:
Parse Configuration: Loads the JSON config and extracts the trainer/policy/algorithm blocks
Setup: Initializes device, seeds, output directories, TensorBoard, and Weights & Biases
Build Components: Environment via the build_env() factory, policy via the build_policy() registry, algorithm via the build_algo() factory
Create Trainer: Instantiates the Trainer with all components
Train: Runs the training loop until completion
Launching Training#
To start training, run:
python -m embodichain.agents.rl.train --config configs/agents/rl/push_cube/train_config.json
Outputs#
All outputs are written to ./outputs/<exp_name>_<timestamp>/:
logs/: TensorBoard logs
checkpoints/: Model checkpoints
Training Process#
The training process follows this sequence:
Rollout Phase: The algorithm collects trajectories by interacting with the environment (via collect_rollout). During this phase, the trainer performs dense per-step logging of rewards and metrics from environment info.
Advantage/Return Computation: The algorithm computes advantages and returns (e.g., GAE for PPO, step-wise group normalization for GRPO; stored in buffer extras)
Update Phase: The algorithm updates the policy using collected data (e.g., PPO)
Logging: The trainer logs training losses and aggregated metrics to TensorBoard and Weights & Biases
Evaluation (periodic): The trainer evaluates the current policy
Checkpointing (periodic): The trainer saves model checkpoints
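The advantage/return computation step for PPO can be sketched in plain Python. This is a minimal single-environment GAE recursion for illustration, not the framework's batched implementation (here dones[t] marks step t as terminal, which zeroes the bootstrap into step t+1):

```python
def compute_gae(rewards, values, dones, last_value, gamma=0.99, gae_lambda=0.95):
    """Generalized Advantage Estimation over one rollout (single env)."""
    advantages = [0.0] * len(rewards)
    last_adv = 0.0
    for t in reversed(range(len(rewards))):
        next_value = last_value if t == len(rewards) - 1 else values[t + 1]
        next_nonterminal = 1.0 - dones[t]
        # TD error for step t
        delta = rewards[t] + gamma * next_value * next_nonterminal - values[t]
        # Exponentially weighted sum of future TD errors
        last_adv = delta + gamma * gae_lambda * next_nonterminal * last_adv
        advantages[t] = last_adv
    # Returns are advantages plus the value baseline
    returns = [adv + v for adv, v in zip(advantages, values)]
    return advantages, returns
```

For a one-step rollout ending in a terminal state with reward 1 and a zero value estimate, the advantage and return both come out to 1.0.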
Policy Interface#
All policies must inherit from the Policy abstract base class:
from abc import ABC, abstractmethod
from typing import Tuple

import torch
import torch.nn as nn


class Policy(nn.Module, ABC):
    device: torch.device

    @abstractmethod
    def get_action(
        self, obs: torch.Tensor, deterministic: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Returns (action, log_prob, value)."""
        raise NotImplementedError

    @abstractmethod
    def get_value(self, obs: torch.Tensor) -> torch.Tensor:
        """Returns value estimate."""
        raise NotImplementedError

    @abstractmethod
    def evaluate_actions(
        self, obs: torch.Tensor, actions: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Returns (log_prob, entropy, value)."""
        raise NotImplementedError
Available Policies#
ActorCritic: MLP-based Gaussian policy with learnable log_std. Requires external actor and critic modules to be provided (defined in the JSON config). Used with PPO.
ActorOnly: Actor-only policy without a critic. Used with GRPO (group-relative advantage estimation).
VLAPlaceholderPolicy: Placeholder for Vision-Language-Action policies
Algorithms#
Available Algorithms#
PPO: Proximal Policy Optimization with GAE
GRPO: Group Relative Policy Optimization (no critic, step-wise returns, masked group normalization). Use the actor_only policy. Set kl_coef=0 for from-scratch training (CartPole, dense reward); kl_coef=0.02 for VLA/LLM fine-tuning.
Adding a New Algorithm#
To add a new algorithm:
Create a new algorithm class in embodichain/agents/rl/algo/
Implement the initialize_buffer(), collect_rollout(), and update() methods
Register it in algo/__init__.py:
import torch

from embodichain.agents.rl.algo import BaseAlgorithm, register_algo
from embodichain.agents.rl.buffer import RolloutBuffer


@register_algo("my_algo")
class MyAlgorithm(BaseAlgorithm):
    def __init__(self, cfg, policy):
        self.cfg = cfg
        self.policy = policy
        self.device = torch.device(cfg.device)
        self.buffer = None

    def initialize_buffer(self, num_steps, num_envs, obs_dim, action_dim):
        """Initialize the algorithm's buffer."""
        self.buffer = RolloutBuffer(num_steps, num_envs, obs_dim, action_dim, self.device)

    def collect_rollout(self, env, policy, obs, num_steps, on_step_callback=None):
        """Control data collection (interact with env, fill buffer, compute advantages/returns)."""
        # Collect trajectories
        # Compute advantages/returns (e.g., GAE for on-policy algorithms)
        # Attach extras to the buffer: self.buffer.set_extras({"advantages": adv, "returns": ret})
        # Return an empty dict (dense logging is handled in the trainer)
        return {}

    def update(self):
        """Update the policy using collected data."""
        # Access extras from the buffer: self.buffer._extras.get("advantages")
        # Use self.buffer to update the policy
        return {"loss": 0.0}
Adding a New Policy#
To add a new policy:
Create a new policy class inheriting from the Policy abstract base class
Register it in models/__init__.py:
from embodichain.agents.rl.models import register_policy, Policy


@register_policy("my_policy")
class MyPolicy(Policy):
    def __init__(self, obs_space, action_space, device, config):
        super().__init__()
        self.device = device
        # Initialize your networks here

    def get_action(self, obs, deterministic=False):
        ...

    def get_value(self, obs):
        ...

    def evaluate_actions(self, obs, actions):
        ...
Adding a New Environment#
To add a new RL environment:
Create an environment class inheriting from EmbodiedEnv (with the Action Manager configured for action preprocessing and a standardized info structure):
import torch

from embodichain.lab.gym.envs import EmbodiedEnv, EmbodiedEnvCfg
from embodichain.lab.gym.utils.registration import register_env


@register_env("MyTaskRL", override=True)
class MyTaskEnv(EmbodiedEnv):
    def __init__(self, cfg: EmbodiedEnvCfg = None, **kwargs):
        super().__init__(cfg, **kwargs)

    def compute_task_state(self, **kwargs):
        """Compute success/failure conditions and metrics."""
        is_success = ...  # Define success condition
        is_fail = torch.zeros_like(is_success)
        metrics = {"distance": ..., "error": ...}
        return is_success, is_fail, metrics
Configure the environment in your JSON config with actions and extensions:
"env": {
"id": "MyTaskRL",
"cfg": {
"num_envs": 4,
"actions": {
"delta_qpos": {
"func": "DeltaQposTerm",
"params": { "scale": 0.1 }
}
},
"extensions": {
"success_threshold": 0.05
}
}
}
The EmbodiedEnv with Action Manager provides:
Action Preprocessing: Configurable via actions (DeltaQposTerm, QposTerm, EefPoseTerm, etc.)
Standardized Info: Implements get_info() using the compute_task_state() template method
Best Practices#
Use EmbodiedEnv with Action Manager for RL Tasks: Inherit from EmbodiedEnv and configure actions in your config. The Action Manager handles action preprocessing (delta_qpos, qpos, qvel, qf, eef_pose) in a modular way.
Action Configuration: Use the actions field in your JSON config. Example: "delta_qpos": {"func": "DeltaQposTerm", "params": {"scale": 0.1}}.
Device Management: Device is single-sourced from trainer.device. All components (trainer/algorithm/policy/env) share the same device.
Observation Format: Environments should provide consistent observation shapes/types (torch.float32) and a single done = terminated | truncated.
Algorithm Interface: Algorithms must implement the initialize_buffer(), collect_rollout(), and update() methods. The algorithm fully controls data collection and buffer management.
Reward Configuration: Use the RewardManager in your environment config to define reward components. Organize reward components in the info["rewards"] dictionary and metrics in the info["metrics"] dictionary. The trainer performs dense per-step logging directly from environment info.
Template Methods: Override compute_task_state() to define success/failure conditions and metrics. Override check_truncated() for custom truncation logic.
Configuration: Use JSON for all hyperparameters. This makes experiments reproducible and easy to track.
Logging: Metrics are automatically logged to TensorBoard and Weights & Biases. Check outputs/<exp_name>_<timestamp>/logs/ for TensorBoard logs.
Checkpoints: Regular checkpoints are saved to outputs/<exp_name>_<timestamp>/checkpoints/. Use these to resume training or evaluate policies.
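The info["rewards"] / info["metrics"] convention can be pictured as the shape of the per-step info dict. The key names below are hypothetical examples, not names the framework mandates:

```python
# Hypothetical per-step info dict following the rewards/metrics convention
info = {
    "rewards": {
        "distance_to_cube": -0.12,  # shaping term (example name)
        "success_bonus": 1.0,       # sparse bonus (example name)
    },
    "metrics": {
        "distance": 0.08,
        "is_success": True,
    },
}

# The total step reward is the sum of the reward components,
# while each component and metric can be logged individually.
total_reward = sum(info["rewards"].values())
```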