Reinforcement Learning Training#

This tutorial shows you how to train reinforcement learning agents using EmbodiChain’s RL framework. You’ll learn how to configure training via JSON, set up environments, policies, and algorithms, and launch training sessions.

Overview#

The RL framework provides a modular, extensible stack for robotics tasks:

  • Trainer: Orchestrates the training loop (calls algorithm for data collection and updates, handles logging/eval/save)

  • Algorithm: Controls data collection process (interacts with environment, fills buffer, computes advantages/returns) and updates the policy (e.g., PPO)

  • Policy: Neural network models implementing a unified interface (get_action/get_value/evaluate_actions)

  • Buffer: On-policy rollout storage and minibatch iterator (managed by algorithm)

  • Env Factory: Build environments from a JSON config via registry

Architecture#

The framework follows a clean separation of concerns. The core components and their relationships:

  • Trainer → Policy, Env, Algorithm (via callbacks for statistics)

  • Algorithm → Policy, RolloutBuffer (algorithm manages its own buffer)
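
The Trainer/Algorithm split above can be sketched as a toy loop. This is a schematic only: the stub classes and names below are illustrative, not the actual Trainer or Algorithm implementations.

```python
# Toy sketch of the Trainer -> Algorithm -> Buffer control flow.
# ToyAlgorithm and toy_train are illustrative stand-ins.
class ToyAlgorithm:
    def collect(self, env_step, buffer_size):
        # "data collection": fill a rollout buffer by stepping the env
        return [env_step() for _ in range(buffer_size)]

    def update(self, rollout):
        # "policy update": here we just report the mean reward as a stat
        return {"mean_reward": sum(rollout) / len(rollout)}


def toy_train(env_step, algorithm, iterations, buffer_size):
    # Trainer role: orchestrate collect/update and gather statistics
    stats = []
    for _ in range(iterations):
        rollout = algorithm.collect(env_step, buffer_size)
        stats.append(algorithm.update(rollout))
    return stats


stats = toy_train(lambda: 1.0, ToyAlgorithm(), iterations=3, buffer_size=4)
```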

Configuration via JSON#

Training is configured via a JSON file that defines runtime settings, environment, policy, and algorithm parameters.

Example Configuration#

The configuration file (e.g., train_config.json) is located in configs/agents/rl/push_cube:

Example: train_config.json
{
    "trainer": {
        "exp_name": "push_cube_ppo",
        "gym_config": "configs/agents/rl/push_cube/gym_config.json",
        "seed": 42,
        "device": "cuda:0",
        "headless": true,
        "enable_rt": false,
        "gpu_id": 0,
        "num_envs": 64,
        "iterations": 1000,
        "buffer_size": 1024,
        "enable_eval": true,
        "num_eval_envs": 16,
        "num_eval_episodes": 3,
        "eval_freq": 100,
        "save_freq": 100,
        "use_wandb": true,
        "wandb_project_name": "embodichain-push_cube",
        "events": {
            "eval": {
                "record_camera": {
                    "func": "record_camera_data_async",
                    "mode": "interval",
                    "interval_step": 1,
                    "params": {
                        "name": "main_cam",
                        "resolution": [640, 480],
                        "eye": [-1.4, 1.4, 2.0],
                        "target": [0, 0, 0],
                        "up": [0, 0, 1],
                        "intrinsics": [600, 600, 320, 240],
                        "save_path": "./outputs/videos_ppo1/eval"
                    }
                }
            }
        }
    },
    "policy": {
        "name": "actor_critic",
        "actor": {
            "type": "mlp",
            "network_cfg": {
                "hidden_sizes": [256, 256],
                "activation": "relu"
            }
        },
        "critic": {
            "type": "mlp",
            "network_cfg": {
                "hidden_sizes": [256, 256],
                "activation": "relu"
            }
        }
    },
    "algorithm": {
        "name": "ppo",
        "cfg": {
            "learning_rate": 0.0001,
            "n_epochs": 10,
            "batch_size": 8192,
            "gamma": 0.99,
            "gae_lambda": 0.95,
            "clip_coef": 0.2,
            "ent_coef": 0.01,
            "vf_coef": 0.5,
            "max_grad_norm": 0.5
        }
    }
}

Configuration Sections#

Runtime Settings#

The trainer section controls experiment setup:

  • exp_name: Experiment name (used for output directories)

  • seed: Random seed for reproducibility

  • device: Runtime device string, e.g. "cpu" or "cuda:0"

  • headless: Whether to run simulation in headless mode

  • iterations: Number of training iterations

  • buffer_size: Steps collected per rollout (e.g., 1024)

  • eval_freq: Frequency of evaluation (in steps)

  • save_freq: Frequency of checkpoint saving (in steps)

  • use_wandb: Whether to enable Weights & Biases logging (set in JSON config)

  • wandb_project_name: Weights & Biases project name

Environment Configuration#

The env section defines the task environment:

  • id: Environment registry ID (e.g., "PushCubeRL")

  • cfg: Environment-specific configuration parameters

For RL environments, use the actions field for action preprocessing and extensions for task-specific parameters:

  • actions: Action Manager config (e.g., DeltaQposTerm with scale)

  • extensions: Task-specific parameters (e.g., success_threshold)

Example:

"env": {
  "id": "PushCubeRL",
  "cfg": {
    "num_envs": 4,
    "actions": {
      "delta_qpos": {
        "func": "DeltaQposTerm",
        "params": { "scale": 0.1 }
      }
    },
    "extensions": {
      "success_threshold": 0.1
    }
  }
}

Policy Configuration#

The policy section defines the neural network policy:

  • name: Policy name (e.g., "actor_critic", "vla")

  • action_dim: Optional policy output action dimension. If omitted, it is inferred from the environment's action dimension (the Action Manager's total_action_dim).

  • actor: Actor network configuration (required for actor_critic)

  • critic: Critic network configuration (required for actor_critic)

Example:

"policy": {
  "name": "actor_critic",
  "actor": {
    "type": "mlp",
    "network_cfg": {
      "hidden_sizes": [256, 256],
      "activation": "relu"
    }
  },
  "critic": {
    "type": "mlp",
    "network_cfg": {
      "hidden_sizes": [256, 256],
      "activation": "relu"
    }
  }
}

Algorithm Configuration#

The algorithm section defines the RL algorithm:

  • name: Algorithm name (e.g., "ppo", "grpo")

  • cfg: Algorithm-specific hyperparameters

PPO example:

"algorithm": {
  "name": "ppo",
  "cfg": {
    "learning_rate": 0.0001,
    "n_epochs": 10,
    "batch_size": 64,
    "gamma": 0.99,
    "gae_lambda": 0.95,
    "clip_coef": 0.2,
    "ent_coef": 0.01,
    "vf_coef": 0.5,
    "max_grad_norm": 0.5
  }
}

GRPO example (for Embodied AI / from-scratch training, e.g. CartPole):

"algorithm": {
  "name": "grpo",
  "cfg": {
    "learning_rate": 0.0001,
    "n_epochs": 10,
    "batch_size": 8192,
    "gamma": 0.99,
    "clip_coef": 0.2,
    "ent_coef": 0.001,
    "kl_coef": 0,
    "group_size": 4,
    "eps": 1e-8,
    "reset_every_rollout": true,
    "max_grad_norm": 0.5,
    "truncate_at_first_done": true
  }
}

For GRPO, use the actor_only policy. Set kl_coef=0 for from-scratch training and kl_coef=0.02 for VLA/LLM fine-tuning.
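
The core idea behind GRPO's advantage estimation, normalizing each step's returns against its group of parallel environments, can be sketched as follows. This is an illustrative sketch, not EmbodiChain's implementation; the group_size and eps parameters follow the cfg above.

```python
import torch

def group_relative_advantages(returns, group_size=4, eps=1e-8):
    # Group-relative advantages: normalize per-step returns against the
    # mean/std of each group of environments (no critic needed).
    num_envs, horizon = returns.shape
    groups = returns.view(num_envs // group_size, group_size, horizon)
    mean = groups.mean(dim=1, keepdim=True)
    std = groups.std(dim=1, keepdim=True)
    adv = (groups - mean) / (std + eps)
    return adv.view(num_envs, horizon)

# Four envs in one group: below-mean returns get negative advantages
returns = torch.tensor([[1.0], [3.0], [1.0], [3.0]])
adv = group_relative_advantages(returns, group_size=4)
```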

Training Script#

The training script (train.py) is located in embodichain/agents/rl/:

Code for train.py
# ----------------------------------------------------------------------------
# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ----------------------------------------------------------------------------

import argparse
import os
import time
from pathlib import Path

import numpy as np
import torch
import wandb
import json
from torch.utils.tensorboard import SummaryWriter
from copy import deepcopy

from embodichain.agents.rl.models import build_policy, get_registered_policy_names
from embodichain.agents.rl.models import build_mlp_from_cfg
from embodichain.agents.rl.algo import build_algo, get_registered_algo_names
from embodichain.agents.rl.utils import dict_to_tensordict, flatten_dict_observation
from embodichain.agents.rl.utils.trainer import Trainer
from embodichain.utils import logger
from embodichain.lab.gym.envs.tasks.rl import build_env
from embodichain.lab.gym.utils.gym_utils import config_to_cfg, DEFAULT_MANAGER_MODULES
from embodichain.utils.utility import load_json
from embodichain.utils.module_utils import find_function_from_modules
from embodichain.lab.sim import SimulationManagerCfg
from embodichain.lab.gym.envs.managers.cfg import EventCfg


def parse_args():
    """Parse command line arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True, help="Path to JSON config")
    parser.add_argument(
        "--distributed",
        action=argparse.BooleanOptionalAction,
        default=None,
        help="Enable or disable multi-GPU distributed training",
    )
    return parser.parse_args()


def train_from_config(config_path: str, distributed: bool | None = None):
    """Run training from a config file path.

    Args:
        config_path: Path to the JSON config file
        distributed: If True, run multi-GPU distributed training.
            If None, use trainer.distributed from config.
    """
    with open(config_path, "r") as f:
        cfg_json = json.load(f)

    trainer_cfg = cfg_json["trainer"]
    policy_block = cfg_json["policy"]
    algo_block = cfg_json["algorithm"]

    # Resolve distributed flag
    if distributed is None:
        distributed = bool(trainer_cfg.get("distributed", False))

    # Distributed setup
    rank = 0
    world_size = 1
    local_rank = 0
    if distributed:
        if not torch.distributed.is_available():
            raise RuntimeError(
                "Distributed training requested but torch.distributed is not available."
            )
        if not torch.cuda.is_available():
            raise RuntimeError(
                "Distributed training with NCCL backend requires CUDA, "
                "but torch.cuda.is_available() is False."
            )
        local_rank = int(os.environ.get("LOCAL_RANK", 0))
        if local_rank < 0 or local_rank >= torch.cuda.device_count():
            raise ValueError(
                f"LOCAL_RANK {local_rank} is out of range "
                f"(available GPUs: {torch.cuda.device_count()})."
            )
        torch.cuda.set_device(local_rank)
        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group(backend="nccl")
        rank = torch.distributed.get_rank()
        world_size = torch.distributed.get_world_size()

    # Runtime
    exp_name = trainer_cfg.get("exp_name", "generic_exp")
    seed = int(trainer_cfg.get("seed", 1))
    device_str = trainer_cfg.get("device", "cpu")
    if distributed:
        device_str = f"cuda:{local_rank}"
    iterations = int(trainer_cfg.get("iterations", 250))
    buffer_size = int(
        trainer_cfg.get("buffer_size", trainer_cfg.get("rollout_steps", 2048))
    )
    enable_eval = bool(trainer_cfg.get("enable_eval", False))
    eval_freq = int(trainer_cfg.get("eval_freq", 10000))
    save_freq = int(trainer_cfg.get("save_freq", 50000))
    num_eval_episodes = int(trainer_cfg.get("num_eval_episodes", 5))
    headless = bool(trainer_cfg.get("headless", True))
    enable_rt = bool(trainer_cfg.get("enable_rt", False))
    gpu_id = int(trainer_cfg.get("gpu_id", 0))
    num_envs = trainer_cfg.get("num_envs", None)
    wandb_project_name = trainer_cfg.get("wandb_project_name", "embodichain-generic")

    # Device
    if not isinstance(device_str, str):
        raise ValueError(
            f"runtime.device must be a string such as 'cpu' or 'cuda:0'. Got: {device_str!r}"
        )
    try:
        device = torch.device(device_str)
    except RuntimeError as exc:
        raise ValueError(
            f"Failed to parse runtime.device='{device_str}': {exc}"
        ) from exc

    if device.type == "cuda":
        if not torch.cuda.is_available():
            raise ValueError(
                "CUDA device requested but torch.cuda.is_available() is False."
            )
        index = (
            device.index if device.index is not None else torch.cuda.current_device()
        )
        device_count = torch.cuda.device_count()
        if index < 0 or index >= device_count:
            raise ValueError(
                f"CUDA device index {index} is out of range (available devices: {device_count})."
            )
        torch.cuda.set_device(index)
        device = torch.device(f"cuda:{index}")
    elif device.type != "cpu":
        raise ValueError(f"Unsupported device type: {device}")
    if rank == 0:
        logger.log_info(f"Device: {device}")
    if distributed and rank == 0:
        logger.log_info(f"Distributed training: world_size={world_size}")

    # Seeds
    effective_seed = seed + rank
    np.random.seed(effective_seed)
    torch.manual_seed(effective_seed)
    torch.backends.cudnn.deterministic = True
    if device.type == "cuda":
        torch.cuda.manual_seed_all(effective_seed)

    # Outputs
    if distributed:
        run_stamp = time.strftime("%Y%m%d_%H%M%S") if rank == 0 else None
        run_stamp_list = [run_stamp]
        torch.distributed.broadcast_object_list(run_stamp_list, src=0)
        run_stamp = run_stamp_list[0]
    else:
        run_stamp = time.strftime("%Y%m%d_%H%M%S")
    run_base = os.path.join("outputs", f"{exp_name}_{run_stamp}")
    log_dir = os.path.join(run_base, "logs")
    checkpoint_dir = os.path.join(run_base, "checkpoints")
    if rank == 0:
        os.makedirs(log_dir, exist_ok=True)
        os.makedirs(checkpoint_dir, exist_ok=True)
    writer = SummaryWriter(f"{log_dir}/{exp_name}") if rank == 0 else None

    # Initialize Weights & Biases (optional)
    use_wandb = trainer_cfg.get("use_wandb", False)
    if use_wandb and rank == 0:
        wandb.init(project=wandb_project_name, name=exp_name, config=cfg_json)

    gym_config_path = Path(trainer_cfg["gym_config"])
    if rank == 0:
        logger.log_info(f"Current working directory: {Path.cwd()}")

    gym_config_data = load_json(str(gym_config_path))
    gym_env_cfg = config_to_cfg(
        gym_config_data, manager_modules=DEFAULT_MANAGER_MODULES
    )
    if num_envs is not None:
        gym_env_cfg.num_envs = int(num_envs)

    # Ensure sim configuration mirrors runtime overrides
    if gym_env_cfg.sim_cfg is None:
        gym_env_cfg.sim_cfg = SimulationManagerCfg()
    if device.type == "cuda":
        gpu_index = device.index
        if gpu_index is None:
            gpu_index = torch.cuda.current_device()
        gym_env_cfg.sim_cfg.sim_device = torch.device(f"cuda:{gpu_index}")
        if hasattr(gym_env_cfg.sim_cfg, "gpu_id"):
            gym_env_cfg.sim_cfg.gpu_id = gpu_index
    else:
        gym_env_cfg.sim_cfg.sim_device = torch.device("cpu")
    gym_env_cfg.sim_cfg.headless = headless
    gym_env_cfg.sim_cfg.enable_rt = enable_rt
    gym_env_cfg.sim_cfg.gpu_id = local_rank if distributed else gpu_id

    if rank == 0:
        logger.log_info(
            f"Loaded gym_config from {gym_config_path} (env_id={gym_config_data['id']}, num_envs={gym_env_cfg.num_envs}, headless={gym_env_cfg.sim_cfg.headless}, enable_rt={gym_env_cfg.sim_cfg.enable_rt}, sim_device={gym_env_cfg.sim_cfg.sim_device})"
        )

    env = build_env(gym_config_data["id"], base_env_cfg=gym_env_cfg)
    sample_obs, _ = env.reset()
    sample_obs_td = dict_to_tensordict(sample_obs, device)
    obs_dim = flatten_dict_observation(sample_obs_td).shape[-1]
    flat_obs_space = env.flattened_observation_space

    # Create evaluation environment only if enabled
    eval_env = None
    num_eval_envs = trainer_cfg.get("num_eval_envs", 4)
    if enable_eval and rank == 0:
        eval_gym_env_cfg = deepcopy(gym_env_cfg)
        eval_gym_env_cfg.num_envs = num_eval_envs
        eval_gym_env_cfg.sim_cfg.headless = True
        eval_env = build_env(gym_config_data["id"], base_env_cfg=eval_gym_env_cfg)
        logger.log_info(
            f"Evaluation environment created (num_envs={num_eval_envs}, headless=True)"
        )

    # Build Policy via registry
    policy_name = policy_block["name"]
    env_action_dim = (
        env.get_wrapper_attr("action_manager").total_action_dim
        if env.get_wrapper_attr("action_manager") is not None
        else len(env.get_wrapper_attr("active_joint_ids"))
    )
    action_dim = policy_block.get("action_dim", env_action_dim)
    action_dim = int(action_dim)
    if action_dim != env_action_dim:
        raise ValueError(
            f"Configured policy.action_dim={action_dim} does not match env action dim {env_action_dim}."
        )
    # Build Policy via registry (actor/critic must be explicitly defined in JSON when using actor_critic/actor_only)
    if policy_name.lower() == "actor_critic":
        actor_cfg = policy_block.get("actor")
        critic_cfg = policy_block.get("critic")
        if actor_cfg is None or critic_cfg is None:
            raise ValueError(
                "ActorCritic requires 'actor' and 'critic' definitions in JSON (policy.actor / policy.critic)."
            )

        actor = build_mlp_from_cfg(actor_cfg, obs_dim, action_dim)
        critic = build_mlp_from_cfg(critic_cfg, obs_dim, 1)

        policy = build_policy(
            policy_block,
            flat_obs_space,
            env.action_space,
            device,
            actor=actor,
            critic=critic,
        )
    elif policy_name.lower() == "actor_only":
        actor_cfg = policy_block.get("actor")
        if actor_cfg is None:
            raise ValueError(
                "ActorOnly requires 'actor' definition in JSON (policy.actor)."
            )

        actor = build_mlp_from_cfg(actor_cfg, obs_dim, action_dim)

        policy = build_policy(
            policy_block,
            flat_obs_space,
            env.action_space,
            device,
            actor=actor,
        )
    else:
        policy = build_policy(
            policy_block, env.observation_space, env.action_space, device
        )

    # Build Algorithm via factory
    algo_name = algo_block["name"].lower()
    algo_cfg = algo_block["cfg"]
    algo = build_algo(
        algo_name,
        algo_cfg,
        policy,
        device,
        distributed=distributed,
    )

    # Build Trainer
    event_modules = [
        "embodichain.lab.gym.envs.managers.randomization",
        "embodichain.lab.gym.envs.managers.record",
        "embodichain.lab.gym.envs.managers.events",
    ]
    events_dict = trainer_cfg.get("events", {})
    train_event_cfg = {}
    eval_event_cfg = {}
    # Parse train events
    for event_name, event_info in events_dict.get("train", {}).items():
        event_func_str = event_info.get("func")
        mode = event_info.get("mode", "interval")
        params = event_info.get("params", {})
        interval_step = event_info.get("interval_step", 1)
        event_func = find_function_from_modules(
            event_func_str, event_modules, raise_if_not_found=True
        )
        train_event_cfg[event_name] = EventCfg(
            func=event_func,
            mode=mode,
            params=params,
            interval_step=interval_step,
        )
    # Parse eval events (only if evaluation is enabled)
    if enable_eval:
        for event_name, event_info in events_dict.get("eval", {}).items():
            event_func_str = event_info.get("func")
            mode = event_info.get("mode", "interval")
            params = event_info.get("params", {})
            interval_step = event_info.get("interval_step", 1)
            event_func = find_function_from_modules(
                event_func_str, event_modules, raise_if_not_found=True
            )
            eval_event_cfg[event_name] = EventCfg(
                func=event_func,
                mode=mode,
                params=params,
                interval_step=interval_step,
            )
    trainer = Trainer(
        policy=policy,
        env=env,
        algorithm=algo,
        buffer_size=buffer_size,
        batch_size=algo_cfg["batch_size"],
        writer=writer,
        eval_freq=eval_freq if enable_eval else 0,  # Disable eval if not enabled
        save_freq=save_freq,
        checkpoint_dir=checkpoint_dir,
        exp_name=exp_name,
        use_wandb=use_wandb,
        eval_env=eval_env,  # None if enable_eval=False
        event_cfg=train_event_cfg,
        eval_event_cfg=eval_event_cfg if (enable_eval and rank == 0) else {},
        num_eval_episodes=num_eval_episodes,
        distributed=distributed,
        rank=rank,
        world_size=world_size,
    )

    if rank == 0:
        logger.log_info("Generic training initialized")
        logger.log_info(f"Task: {type(env).__name__}")
        logger.log_info(
            f"Policy: {policy_name} (available: {get_registered_policy_names()})"
        )
        logger.log_info(
            f"Algorithm: {algo_name} (available: {get_registered_algo_names()})"
        )

    total_steps = int(iterations * buffer_size * env.num_envs * world_size)
    if rank == 0:
        logger.log_info(
            f"Total steps: {total_steps} (iterations≈{iterations}, world_size={world_size})"
        )

    try:
        trainer.train(total_steps)
    except KeyboardInterrupt:
        if rank == 0:
            logger.log_info("Training interrupted by user")
    finally:
        trainer.save_checkpoint()
        if writer is not None:
            writer.close()
        if use_wandb and rank == 0:
            try:
                wandb.finish()
            except Exception:
                pass

        # Clean up environments to prevent resource leaks
        try:
            if env is not None:
                env.close()
        except Exception as e:
            if rank == 0:
                logger.log_warning(f"Failed to close training environment: {e}")

        try:
            if eval_env is not None:
                eval_env.close()
        except Exception as e:
            if rank == 0:
                logger.log_warning(f"Failed to close evaluation environment: {e}")

        if distributed and torch.distributed.is_initialized():
            torch.distributed.destroy_process_group()

        if rank == 0:
            logger.log_info("Training finished")


def main():
    """Main entry point for command-line training."""
    args = parse_args()
    train_from_config(args.config, distributed=args.distributed)


if __name__ == "__main__":
    main()

The Script Explained#

The training script performs the following steps:

  1. Parse Configuration: Loads JSON config and extracts runtime/env/policy/algorithm blocks

  2. Setup: Initializes device, seeds, output directories, TensorBoard, and Weights & Biases

  3. Build Components: the environment via the build_env() factory, the policy via the build_policy() registry, and the algorithm via the build_algo() factory

  4. Create Trainer: Instantiates the Trainer with all components

  5. Train: Runs the training loop until completion

Launching Training#

To start training, run:

python -m embodichain.agents.rl.train --config configs/agents/rl/push_cube/train_config.json
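
For multi-GPU training, the script's --distributed flag and LOCAL_RANK handling are compatible with a torchrun launch. The sketch below assumes two GPUs on one machine; adjust the process count for your hardware:

```shell
# One process per GPU; torchrun sets LOCAL_RANK for each process,
# and the script initializes the NCCL process group accordingly.
torchrun --nproc_per_node=2 -m embodichain.agents.rl.train \
    --config configs/agents/rl/push_cube/train_config.json \
    --distributed
```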

Outputs#

All outputs are written to ./outputs/<exp_name>_<timestamp>/:

  • logs/: TensorBoard logs

  • checkpoints/: Model checkpoints

Training Process#

The training process follows this sequence:

  1. Rollout Phase: SyncCollector interacts with the environment and writes policy-side fields into a shared rollout TensorDict with uniform [N, T + 1] layout. EmbodiedEnv writes environment-side step fields such as reward, done, terminated, and truncated into the same rollout via set_rollout_buffer(). The final slot of transition-only fields is reserved as padding, while obs[:, -1] and value[:, -1] remain valid bootstrap data.

  2. Advantage/Return Computation: Algorithm computes advantages and returns from the collected rollout (e.g. GAE for PPO, step-wise group normalization for GRPO) and converts it to a transition-aligned view over the valid first T steps before minibatch optimization.

  3. Update Phase: Algorithm updates the policy with update(rollout)

  4. Logging: Trainer logs training losses and aggregated metrics to TensorBoard and Weights & Biases

  5. Evaluation (periodic): Trainer evaluates the current policy

  6. Checkpointing (periodic): Trainer saves model checkpoints
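
For PPO, the advantage/return computation in step 2 is GAE over the [N, T + 1] value layout described above. The following is a minimal self-contained sketch of the recursion (illustrative, not EmbodiChain's implementation):

```python
import torch

def compute_gae(rewards, values, dones, gamma=0.99, gae_lambda=0.95):
    # rewards, dones: [N, T]; values: [N, T + 1] (last column is the bootstrap value)
    num_envs, horizon = rewards.shape
    advantages = torch.zeros_like(rewards)
    last_adv = torch.zeros(num_envs)
    for t in reversed(range(horizon)):
        not_done = 1.0 - dones[:, t]
        # TD residual, bootstrapping from the next value unless the episode ended
        delta = rewards[:, t] + gamma * values[:, t + 1] * not_done - values[:, t]
        last_adv = delta + gamma * gae_lambda * not_done * last_adv
        advantages[:, t] = last_adv
    returns = advantages + values[:, :horizon]
    return advantages, returns

# With gamma = lambda = 1 and zero values, advantages are reverse cumulative rewards
rewards = torch.ones(1, 3)
values = torch.zeros(1, 4)
dones = torch.zeros(1, 3)
adv, ret = compute_gae(rewards, values, dones, gamma=1.0, gae_lambda=1.0)
```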

Policy Interface#

All policies must inherit from the Policy abstract base class:

from abc import ABC, abstractmethod
import torch
import torch.nn as nn

class Policy(nn.Module, ABC):
    device: torch.device

    def get_action(self, tensordict, deterministic: bool = False):
        """Writes sampled action, sample_log_prob, and value into the TensorDict."""
        ...

    @abstractmethod
    def forward(self, tensordict, deterministic: bool = False):
        """Writes action, sample_log_prob, and value into the TensorDict."""
        raise NotImplementedError

    @abstractmethod
    def get_value(self, tensordict):
        """Writes value estimate into the TensorDict."""
        raise NotImplementedError

    @abstractmethod
    def evaluate_actions(self, tensordict):
        """Returns a new TensorDict with log_prob, entropy, and value."""
        raise NotImplementedError

Available Policies#

  • ActorCritic: MLP-based Gaussian policy with learnable log_std. Requires external actor and critic modules to be provided (defined in JSON config). Used with PPO.

  • ActorOnly: Actor-only policy without Critic. Used with GRPO (group-relative advantage estimation).

  • VLAPlaceholderPolicy: Placeholder for Vision-Language-Action policies

Algorithms#

Available Algorithms#

  • PPO: Proximal Policy Optimization with GAE

  • GRPO: Group Relative Policy Optimization (no Critic, step-wise returns, masked group normalization). Use actor_only policy. Set kl_coef=0 for from-scratch training (CartPole, dense reward); kl_coef=0.02 for VLA/LLM fine-tuning.
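
At the core of the PPO update is the clipped surrogate objective controlled by clip_coef. A minimal sketch of that loss (illustrative, not EmbodiChain's code):

```python
import torch

def ppo_clip_loss(log_prob, old_log_prob, advantages, clip_coef=0.2):
    # Probability ratio between the current policy and the rollout-time policy
    ratio = (log_prob - old_log_prob).exp()
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - clip_coef, 1.0 + clip_coef) * advantages
    # Pessimistic (elementwise min) bound, negated for gradient descent
    return -torch.min(unclipped, clipped).mean()

# When the policy has not moved (ratio == 1), the loss is just -mean(advantages)
loss = ppo_clip_loss(
    log_prob=torch.tensor([0.0, 0.0]),
    old_log_prob=torch.tensor([0.0, 0.0]),
    advantages=torch.tensor([1.0, 2.0]),
)
```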

Adding a New Algorithm#

To add a new algorithm:

  1. Create a new algorithm class in embodichain/agents/rl/algo/

  2. Implement update(rollout) and consume the shared rollout TensorDict

  3. Register in algo/__init__.py:

import torch
from tensordict import TensorDict
from embodichain.agents.rl.algo import BaseAlgorithm, register_algo

@register_algo("my_algo")
class MyAlgorithm(BaseAlgorithm):
    def __init__(self, cfg, policy, device):
        self.cfg = cfg
        self.policy = policy
        self.device = torch.device(device)

    def update(self, rollout: TensorDict):
        """Update the policy using a collected rollout."""
        # compute advantages / returns from rollout
        # optimize policy parameters
        return {"loss": 0.0}

Adding a New Policy#

To add a new policy:

  1. Create a new policy class inheriting from the Policy abstract base class

  2. Register in models/__init__.py:

from embodichain.agents.rl.models import register_policy, Policy

@register_policy("my_policy")
class MyPolicy(Policy):
    def __init__(self, obs_dim, action_dim, device, config):
        super().__init__()
        self.device = device
        # Initialize your networks here

    def get_action(self, tensordict, deterministic=False):
        ...
    def forward(self, tensordict, deterministic=False):
        ...
    def get_value(self, tensordict):
        ...
    def evaluate_actions(self, tensordict):
        ...

Current built-in MLP policies use flattened observations in the training path. If your policy requires structured or multi-modal inputs, keep the richer obs_space interface and define a matching rollout/collector schema.
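
Flattening a dict observation can look like the following hypothetical helper (the framework's flatten_dict_observation may differ in ordering and handling of nested keys):

```python
import torch

def flatten_obs(obs: dict) -> torch.Tensor:
    # Concatenate all leaf tensors along the feature dim, in sorted key order,
    # preserving the leading batch (num_envs) dimension.
    parts = [obs[key].reshape(obs[key].shape[0], -1) for key in sorted(obs)]
    return torch.cat(parts, dim=-1)

# Two envs, a 7-dim joint state and a 3-dim object position -> 10 features
obs = {"qpos": torch.zeros(2, 7), "cube_pos": torch.zeros(2, 3)}
flat = flatten_obs(obs)
```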

Adding a New Environment#

To add a new RL environment:

  1. Create an environment class inheriting from EmbodiedEnv (with Action Manager configured for action preprocessing and standardized info structure):

from embodichain.lab.gym.envs import EmbodiedEnv, EmbodiedEnvCfg
from embodichain.lab.gym.utils.registration import register_env
import torch

@register_env("MyTaskRL", override=True)
class MyTaskEnv(EmbodiedEnv):
    def __init__(self, cfg: EmbodiedEnvCfg = None, **kwargs):
        super().__init__(cfg, **kwargs)

    def compute_task_state(self, **kwargs):
        """Compute success/failure conditions and metrics."""
        is_success = ...  # Define success condition
        is_fail = torch.zeros_like(is_success)
        metrics = {"distance": ..., "error": ...}
        return is_success, is_fail, metrics

  2. Configure the environment in your JSON config with actions and extensions:

"env": {
  "id": "MyTaskRL",
  "cfg": {
    "num_envs": 4,
    "actions": {
      "delta_qpos": {
        "func": "DeltaQposTerm",
        "params": { "scale": 0.1 }
      }
    },
    "extensions": {
      "success_threshold": 0.05
    }
  }
}

The EmbodiedEnv with Action Manager provides:

  • Action Preprocessing: Configurable via actions (DeltaQposTerm, QposTerm, EefPoseTerm, etc.)

  • Standardized Info: Implements get_info() using compute_task_state() template method

Best Practices#

  • Use EmbodiedEnv with Action Manager for RL Tasks: Inherit from EmbodiedEnv and configure actions in your config. The Action Manager handles action preprocessing (delta_qpos, qpos, qvel, qf, eef_pose) in a modular way.

  • Action Configuration: Use the actions field in your JSON config. Example: "delta_qpos": {"func": "DeltaQposTerm", "params": {"scale": 0.1}}.

  • Device Management: Device is single-sourced from the device setting in the trainer config (e.g., "cuda:0"). All components (trainer/algorithm/policy/env) share the same device.

  • Observation Format: Environments should provide consistent observation shape/types (torch.float32) and a single done = terminated | truncated.

  • Algorithm Interface: Algorithms implement update(rollout) and consume a shared rollout TensorDict. Collection is handled by SyncCollector plus environment-side rollout writes in EmbodiedEnv.

  • Reward Configuration: Use the RewardManager in your environment config to define reward components. Organize reward components in info["rewards"] dictionary and metrics in info["metrics"] dictionary. The trainer performs dense per-step logging directly from environment info.

  • Template Methods: Override compute_task_state() to define success/failure conditions and metrics. Override check_truncated() for custom truncation logic.

  • Configuration: Use JSON for all hyperparameters. This makes experiments reproducible and easy to track.

  • Logging: Metrics are automatically logged to TensorBoard and Weights & Biases. Check outputs/<exp_name>_<timestamp>/logs/ for TensorBoard logs.

  • Checkpoints: Regular checkpoints are saved to outputs/<exp_name>_<timestamp>/checkpoints/. Use these to resume training or evaluate policies.