Reinforcement Learning Training#

This tutorial shows you how to train reinforcement learning agents using EmbodiChain’s RL framework. You’ll learn how to configure training via JSON, set up environments, policies, and algorithms, and launch training sessions.

Overview#

The RL framework provides a modular, extensible stack for robotics tasks:

  • Trainer: Orchestrates the training loop (calls algorithm for data collection and updates, handles logging/eval/save)

  • Algorithm: Controls data collection process (interacts with environment, fills buffer, computes advantages/returns) and updates the policy (e.g., PPO)

  • Policy: Neural network models implementing a unified interface (get_action/get_value/evaluate_actions)

  • Buffer: On-policy rollout storage and minibatch iterator (managed by algorithm)

  • Env Factory: Build environments from a JSON config via registry

Architecture#

The framework follows a clean separation of concerns, with each component's responsibilities as listed in the Overview above. The core components relate as follows:

  • Trainer → Policy, Env, Algorithm (via callbacks for statistics)

  • Algorithm → Policy, RolloutBuffer (algorithm manages its own buffer)
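The relationships above can be sketched with toy stand-ins. Only the collect_rollout()/update() split mirrors the documented interface; the class names and the fake env below are illustrative, not the actual EmbodiChain API:

```python
# Toy sketch of the Trainer/Algorithm relationship (illustrative names).
class ToyAlgorithm:
    def __init__(self):
        self.buffer = []  # the algorithm owns and manages its own buffer

    def collect_rollout(self, env, num_steps):
        """Fill the buffer by interacting with the environment."""
        self.buffer = [env() for _ in range(num_steps)]
        return {}  # dense per-step logging is handled by the trainer

    def update(self):
        """Update the policy from the collected data (e.g., PPO epochs)."""
        mean_reward = sum(self.buffer) / len(self.buffer)
        return {"loss": 1.0 - mean_reward}


class ToyTrainer:
    def __init__(self, algorithm, env, num_steps):
        self.algorithm, self.env, self.num_steps = algorithm, env, num_steps

    def train(self, iterations):
        logs = []
        for _ in range(iterations):
            self.algorithm.collect_rollout(self.env, self.num_steps)  # rollout phase
            logs.append(self.algorithm.update())                      # update phase
        return logs
```

The trainer never touches the buffer directly; it only drives the algorithm's two-phase loop, which is the separation the real Trainer/Algorithm classes follow.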

Configuration via JSON#

Training is configured via a JSON file that defines runtime settings, environment, policy, and algorithm parameters.

Example Configuration#

The configuration file (e.g., train_config.json) is located in configs/agents/rl/push_cube:

Example: train_config.json
{
    "trainer": {
        "exp_name": "push_cube_ppo",
        "gym_config": "configs/agents/rl/push_cube/gym_config.json",
        "seed": 42,
        "device": "cuda:0",
        "headless": true,
        "enable_rt": false,
        "gpu_id": 0,
        "num_envs": 64,
        "iterations": 1000,
        "rollout_steps": 1024,
        "enable_eval": true,
        "num_eval_envs": 16,
        "num_eval_episodes": 3,
        "eval_freq": 2,
        "save_freq": 200,
        "use_wandb": false,
        "wandb_project_name": "embodichain-push_cube",
        "events": {
            "eval": {
                "record_camera": {
                    "func": "record_camera_data_async",
                    "mode": "interval",
                    "interval_step": 1,
                    "params": {
                        "name": "main_cam",
                        "resolution": [640, 480],
                        "eye": [-1.4, 1.4, 2.0],
                        "target": [0, 0, 0],
                        "up": [0, 0, 1],
                        "intrinsics": [600, 600, 320, 240],
                        "save_path": "./outputs/videos/eval"
                    }
                }
            }
        }
    },
    "policy": {
        "name": "actor_critic",
        "actor": {
            "type": "mlp",
            "network_cfg": {
                "hidden_sizes": [256, 256],
                "activation": "relu"
            }
        },
        "critic": {
            "type": "mlp",
            "network_cfg": {
                "hidden_sizes": [256, 256],
                "activation": "relu"
            }
        }
    },
    "algorithm": {
        "name": "ppo",
        "cfg": {
            "learning_rate": 0.0001,
            "n_epochs": 10,
            "batch_size": 8192,
            "gamma": 0.99,
            "gae_lambda": 0.95,
            "clip_coef": 0.2,
            "ent_coef": 0.01,
            "vf_coef": 0.5,
            "max_grad_norm": 0.5
        }
    }
}

Configuration Sections#

Runtime Settings#

The trainer section controls experiment setup:

  • exp_name: Experiment name (used for output directories)

  • seed: Random seed for reproducibility

  • device: Compute device string, e.g. "cpu" or "cuda:0"

  • headless: Whether to run simulation in headless mode

  • iterations: Number of training iterations

  • rollout_steps: Steps per rollout (e.g., 1024)

  • eval_freq: Frequency of evaluation (in steps)

  • save_freq: Frequency of checkpoint saving (in steps)

  • use_wandb: Whether to enable Weights & Biases logging

  • wandb_project_name: Weights & Biases project name

Environment Configuration#

The env section defines the task environment:

  • id: Environment registry ID (e.g., “PushCubeRL”)

  • cfg: Environment-specific configuration parameters

For RL environments, use the actions field for action preprocessing and extensions for task-specific parameters:

  • actions: Action Manager config (e.g., DeltaQposTerm with scale)

  • extensions: Task-specific parameters (e.g., success_threshold)

Example:

"env": {
  "id": "PushCubeRL",
  "cfg": {
    "num_envs": 4,
    "actions": {
      "delta_qpos": {
        "func": "DeltaQposTerm",
        "params": { "scale": 0.1 }
      }
    },
    "extensions": {
      "success_threshold": 0.1
    }
  }
}
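As an illustration of what a DeltaQposTerm-style action term with "scale": 0.1 does, here is a minimal sketch of delta-position preprocessing. The function name, clipping behavior, and joint limits are assumptions for illustration, not the actual term implementation:

```python
def apply_delta_qpos(qpos, action, scale=0.1, limits=(-3.14, 3.14)):
    """Illustrative delta-qpos preprocessing: target = current + scale * action.

    Targets are clipped to hypothetical joint limits; the real Action Manager
    term may handle limits and batching differently.
    """
    lo, hi = limits
    return [min(max(q + scale * a, lo), hi) for q, a in zip(qpos, action)]
```

Because the policy outputs are scaled down before being applied as offsets, small `scale` values keep per-step joint motion bounded, which is why it is exposed as a tunable parameter in the config.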

Policy Configuration#

The policy section defines the neural network policy:

  • name: Policy name (e.g., “actor_critic”, “vla”)

  • cfg: Policy-specific hyperparameters (empty for actor_critic)

  • actor: Actor network configuration (required for actor_critic)

  • critic: Critic network configuration (required for actor_critic)

Example:

"policy": {
  "name": "actor_critic",
  "cfg": {},
  "actor": {
    "type": "mlp",
    "hidden_sizes": [256, 256],
    "activation": "relu"
  },
  "critic": {
    "type": "mlp",
    "hidden_sizes": [256, 256],
    "activation": "relu"
  }
}

Algorithm Configuration#

The algorithm section defines the RL algorithm:

  • name: Algorithm name (e.g., “ppo”, “grpo”)

  • cfg: Algorithm-specific hyperparameters

PPO example:

"algorithm": {
  "name": "ppo",
  "cfg": {
    "learning_rate": 0.0001,
    "n_epochs": 10,
    "batch_size": 64,
    "gamma": 0.99,
    "gae_lambda": 0.95,
    "clip_coef": 0.2,
    "ent_coef": 0.01,
    "vf_coef": 0.5,
    "max_grad_norm": 0.5
  }
}

GRPO example (for Embodied AI / from-scratch training, e.g. CartPole):

"algorithm": {
  "name": "grpo",
  "cfg": {
    "learning_rate": 0.0001,
    "n_epochs": 10,
    "batch_size": 8192,
    "gamma": 0.99,
    "clip_coef": 0.2,
    "ent_coef": 0.001,
    "kl_coef": 0,
    "group_size": 4,
    "eps": 1e-8,
    "reset_every_rollout": true,
    "max_grad_norm": 0.5,
    "truncate_at_first_done": true
  }
}

For GRPO: use actor_only policy. Set kl_coef=0 for from-scratch training; kl_coef=0.02 for VLA/LLM fine-tuning.
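The group_size and eps fields above drive GRPO's group-relative advantage estimate: each group's returns are normalized by the group's own mean and standard deviation instead of a critic baseline. A minimal sketch of that normalization, using a population-std variant (the framework's step-wise, masked version may differ in detail):

```python
import math

def group_relative_advantages(returns, group_size, eps=1e-8):
    """Normalize each contiguous group of returns by its own mean/std (illustrative)."""
    advantages = []
    for start in range(0, len(returns), group_size):
        group = returns[start:start + group_size]
        mean = sum(group) / len(group)
        var = sum((r - mean) ** 2 for r in group) / len(group)
        std = math.sqrt(var)
        # eps keeps the division stable when a group has (near-)identical returns
        advantages.extend((r - mean) / (std + eps) for r in group)
    return advantages
```

Since the baseline comes from siblings in the same group rather than a value network, no critic is needed, which is why GRPO pairs with the actor_only policy.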

Training Script#

The training script (train.py) is located in embodichain/agents/rl/:

Code for train.py
# ----------------------------------------------------------------------------
# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ----------------------------------------------------------------------------

import argparse
import os
import time
from pathlib import Path

import numpy as np
import torch
import wandb
import json
from torch.utils.tensorboard import SummaryWriter
from copy import deepcopy

from embodichain.agents.rl.models import build_policy, get_registered_policy_names
from embodichain.agents.rl.models import build_mlp_from_cfg
from embodichain.agents.rl.algo import build_algo, get_registered_algo_names
from embodichain.agents.rl.utils.trainer import Trainer
from embodichain.utils import logger
from embodichain.lab.gym.envs.tasks.rl import build_env
from embodichain.lab.gym.utils.gym_utils import config_to_cfg, DEFAULT_MANAGER_MODULES
from embodichain.utils.utility import load_json
from embodichain.utils.module_utils import find_function_from_modules
from embodichain.lab.sim import SimulationManagerCfg
from embodichain.lab.gym.envs.managers.cfg import EventCfg


def parse_args():
    """Parse command line arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True, help="Path to JSON config")
    return parser.parse_args()


def train_from_config(config_path: str):
    """Run training from a config file path.

    Args:
        config_path: Path to the JSON config file
    """
    with open(config_path, "r") as f:
        cfg_json = json.load(f)

    trainer_cfg = cfg_json["trainer"]
    policy_block = cfg_json["policy"]
    algo_block = cfg_json["algorithm"]

    # Runtime
    exp_name = trainer_cfg.get("exp_name", "generic_exp")
    seed = int(trainer_cfg.get("seed", 1))
    device_str = trainer_cfg.get("device", "cpu")
    iterations = int(trainer_cfg.get("iterations", 250))
    rollout_steps = int(trainer_cfg.get("rollout_steps", 2048))
    enable_eval = bool(trainer_cfg.get("enable_eval", False))
    eval_freq = int(trainer_cfg.get("eval_freq", 10000))
    save_freq = int(trainer_cfg.get("save_freq", 50000))
    num_eval_episodes = int(trainer_cfg.get("num_eval_episodes", 5))
    headless = bool(trainer_cfg.get("headless", True))
    enable_rt = bool(trainer_cfg.get("enable_rt", False))
    gpu_id = int(trainer_cfg.get("gpu_id", 0))
    num_envs = trainer_cfg.get("num_envs", None)
    wandb_project_name = trainer_cfg.get("wandb_project_name", "embodichain-generic")

    # Device
    if not isinstance(device_str, str):
        raise ValueError(
            f"runtime.device must be a string such as 'cpu' or 'cuda:0'. Got: {device_str!r}"
        )
    try:
        device = torch.device(device_str)
    except RuntimeError as exc:
        raise ValueError(
            f"Failed to parse runtime.device='{device_str}': {exc}"
        ) from exc

    if device.type == "cuda":
        if not torch.cuda.is_available():
            raise ValueError(
                "CUDA device requested but torch.cuda.is_available() is False."
            )
        index = (
            device.index if device.index is not None else torch.cuda.current_device()
        )
        device_count = torch.cuda.device_count()
        if index < 0 or index >= device_count:
            raise ValueError(
                f"CUDA device index {index} is out of range (available devices: {device_count})."
            )
        torch.cuda.set_device(index)
        device = torch.device(f"cuda:{index}")
    elif device.type != "cpu":
        raise ValueError(f"Unsupported device type: {device}")
    logger.log_info(f"Device: {device}")

    # Seeds
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    if device.type == "cuda":
        torch.cuda.manual_seed_all(seed)

    # Outputs
    run_stamp = time.strftime("%Y%m%d_%H%M%S")
    run_base = os.path.join("outputs", f"{exp_name}_{run_stamp}")
    log_dir = os.path.join(run_base, "logs")
    checkpoint_dir = os.path.join(run_base, "checkpoints")
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(checkpoint_dir, exist_ok=True)
    writer = SummaryWriter(f"{log_dir}/{exp_name}")

    # Initialize Weights & Biases (optional)
    use_wandb = trainer_cfg.get("use_wandb", False)
    if use_wandb:
        wandb.init(project=wandb_project_name, name=exp_name, config=cfg_json)

    gym_config_path = Path(trainer_cfg["gym_config"])
    logger.log_info(f"Current working directory: {Path.cwd()}")

    gym_config_data = load_json(str(gym_config_path))
    gym_env_cfg = config_to_cfg(
        gym_config_data, manager_modules=DEFAULT_MANAGER_MODULES
    )
    if num_envs is not None:
        gym_env_cfg.num_envs = int(num_envs)

    # Ensure sim configuration mirrors runtime overrides
    if gym_env_cfg.sim_cfg is None:
        gym_env_cfg.sim_cfg = SimulationManagerCfg()
    if device.type == "cuda":
        gpu_index = device.index
        if gpu_index is None:
            gpu_index = torch.cuda.current_device()
        gym_env_cfg.sim_cfg.sim_device = torch.device(f"cuda:{gpu_index}")
        if hasattr(gym_env_cfg.sim_cfg, "gpu_id"):
            gym_env_cfg.sim_cfg.gpu_id = gpu_index
    else:
        gym_env_cfg.sim_cfg.sim_device = torch.device("cpu")
    gym_env_cfg.sim_cfg.headless = headless
    gym_env_cfg.sim_cfg.enable_rt = enable_rt
    gym_env_cfg.sim_cfg.gpu_id = gpu_id

    logger.log_info(
        f"Loaded gym_config from {gym_config_path} (env_id={gym_config_data['id']}, num_envs={gym_env_cfg.num_envs}, headless={gym_env_cfg.sim_cfg.headless}, enable_rt={gym_env_cfg.sim_cfg.enable_rt}, sim_device={gym_env_cfg.sim_cfg.sim_device})"
    )

    env = build_env(gym_config_data["id"], base_env_cfg=gym_env_cfg)

    # Create evaluation environment only if enabled
    eval_env = None
    num_eval_envs = trainer_cfg.get("num_eval_envs", 4)
    if enable_eval:
        eval_gym_env_cfg = deepcopy(gym_env_cfg)
        eval_gym_env_cfg.num_envs = num_eval_envs
        eval_gym_env_cfg.sim_cfg.headless = True
        eval_env = build_env(gym_config_data["id"], base_env_cfg=eval_gym_env_cfg)
        logger.log_info(
            f"Evaluation environment created (num_envs={num_eval_envs}, headless=True)"
        )

    # Build Policy via registry (actor/critic must be explicitly defined in JSON
    # when using actor_critic/actor_only)
    policy_name = policy_block["name"]
    if policy_name.lower() == "actor_critic":
        # Get observation dimension from flattened observation space
        # flattened_observation_space returns Box space for RL training
        obs_dim = env.flattened_observation_space.shape[-1]
        action_dim = env.action_space.shape[-1]

        actor_cfg = policy_block.get("actor")
        critic_cfg = policy_block.get("critic")
        if actor_cfg is None or critic_cfg is None:
            raise ValueError(
                "ActorCritic requires 'actor' and 'critic' definitions in JSON (policy.actor / policy.critic)."
            )

        actor = build_mlp_from_cfg(actor_cfg, obs_dim, action_dim)
        critic = build_mlp_from_cfg(critic_cfg, obs_dim, 1)

        policy = build_policy(
            policy_block,
            env.flattened_observation_space,
            env.action_space,
            device,
            actor=actor,
            critic=critic,
        )
    elif policy_name.lower() == "actor_only":
        obs_dim = env.flattened_observation_space.shape[-1]
        action_dim = env.action_space.shape[-1]

        actor_cfg = policy_block.get("actor")
        if actor_cfg is None:
            raise ValueError(
                "ActorOnly requires 'actor' definition in JSON (policy.actor)."
            )

        actor = build_mlp_from_cfg(actor_cfg, obs_dim, action_dim)

        policy = build_policy(
            policy_block,
            env.flattened_observation_space,
            env.action_space,
            device,
            actor=actor,
        )
    else:
        policy = build_policy(
            policy_block, env.flattened_observation_space, env.action_space, device
        )

    # Build Algorithm via factory
    algo_name = algo_block["name"].lower()
    algo_cfg = algo_block["cfg"]
    algo = build_algo(algo_name, algo_cfg, policy, device)

    # Build Trainer
    event_modules = [
        "embodichain.lab.gym.envs.managers.randomization",
        "embodichain.lab.gym.envs.managers.record",
        "embodichain.lab.gym.envs.managers.events",
    ]
    events_dict = trainer_cfg.get("events", {})
    train_event_cfg = {}
    eval_event_cfg = {}
    # Parse train events
    for event_name, event_info in events_dict.get("train", {}).items():
        event_func_str = event_info.get("func")
        mode = event_info.get("mode", "interval")
        params = event_info.get("params", {})
        interval_step = event_info.get("interval_step", 1)
        event_func = find_function_from_modules(
            event_func_str, event_modules, raise_if_not_found=True
        )
        train_event_cfg[event_name] = EventCfg(
            func=event_func,
            mode=mode,
            params=params,
            interval_step=interval_step,
        )
    # Parse eval events (only if evaluation is enabled)
    if enable_eval:
        for event_name, event_info in events_dict.get("eval", {}).items():
            event_func_str = event_info.get("func")
            mode = event_info.get("mode", "interval")
            params = event_info.get("params", {})
            interval_step = event_info.get("interval_step", 1)
            event_func = find_function_from_modules(
                event_func_str, event_modules, raise_if_not_found=True
            )
            eval_event_cfg[event_name] = EventCfg(
                func=event_func,
                mode=mode,
                params=params,
                interval_step=interval_step,
            )
    trainer = Trainer(
        policy=policy,
        env=env,
        algorithm=algo,
        num_steps=rollout_steps,
        batch_size=algo_cfg["batch_size"],
        writer=writer,
        eval_freq=eval_freq if enable_eval else 0,  # Disable eval if not enabled
        save_freq=save_freq,
        checkpoint_dir=checkpoint_dir,
        exp_name=exp_name,
        use_wandb=use_wandb,
        eval_env=eval_env,  # None if enable_eval=False
        event_cfg=train_event_cfg,
        eval_event_cfg=eval_event_cfg if enable_eval else {},
        num_eval_episodes=num_eval_episodes,
    )

    logger.log_info("Generic training initialized")
    logger.log_info(f"Task: {type(env).__name__}")
    logger.log_info(
        f"Policy: {policy_name} (available: {get_registered_policy_names()})"
    )
    logger.log_info(
        f"Algorithm: {algo_name} (available: {get_registered_algo_names()})"
    )

    total_steps = int(iterations * rollout_steps * env.num_envs)
    logger.log_info(f"Total steps: {total_steps} (iterations≈{iterations})")

    try:
        trainer.train(total_steps)
    except KeyboardInterrupt:
        logger.log_info("Training interrupted by user")
    finally:
        trainer.save_checkpoint()
        writer.close()
        if use_wandb:
            try:
                wandb.finish()
            except Exception:
                pass

        # Clean up environments to prevent resource leaks
        try:
            if env is not None:
                env.close()
        except Exception as e:
            logger.log_warning(f"Failed to close training environment: {e}")

        try:
            if eval_env is not None:
                eval_env.close()
        except Exception as e:
            logger.log_warning(f"Failed to close evaluation environment: {e}")

        logger.log_info("Training finished")


def main():
    """Main entry point for command-line training."""
    args = parse_args()
    train_from_config(args.config)


if __name__ == "__main__":
    main()

The Script Explained#

The training script performs the following steps:

  1. Parse Configuration: Loads JSON config and extracts runtime/env/policy/algorithm blocks

  2. Setup: Initializes device, seeds, output directories, TensorBoard, and Weights & Biases

  3. Build Components: the environment via the build_env() factory, the policy via the build_policy() registry, and the algorithm via the build_algo() factory

  4. Create Trainer: Instantiates the Trainer with all components

  5. Train: Runs the training loop until completion

Launching Training#

To start training, run:

python -m embodichain.agents.rl.train --config configs/agents/rl/push_cube/train_config.json

Outputs#

All outputs are written to ./outputs/<exp_name>_<timestamp>/:

  • logs/: TensorBoard logs

  • checkpoints/: Model checkpoints

Training Process#

The training process follows this sequence:

  1. Rollout Phase: Algorithm collects trajectories by interacting with the environment (via collect_rollout). During this phase, the trainer performs dense per-step logging of rewards and metrics from environment info.

  2. Advantage/Return Computation: Algorithm computes advantages and returns (e.g. GAE for PPO, step-wise group normalization for GRPO; stored in buffer extras)

  3. Update Phase: Algorithm updates the policy using collected data (e.g., PPO)

  4. Logging: Trainer logs training losses and aggregated metrics to TensorBoard and Weights & Biases

  5. Evaluation (periodic): Trainer evaluates the current policy

  6. Checkpointing (periodic): Trainer saves model checkpoints
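Step 2's advantage computation for PPO is Generalized Advantage Estimation. The following sketch shows GAE over a single-environment trajectory of scalar rewards and values (the framework's buffered, batched version differs in shape handling, but the recursion is the same):

```python
def compute_gae(rewards, values, dones, last_value, gamma=0.99, gae_lambda=0.95):
    """Generalized Advantage Estimation over one trajectory (illustrative).

    Walks the trajectory backwards, accumulating the discounted TD residuals
    delta_t = r_t + gamma * V(s_{t+1}) - V(s_t), masked at episode ends.
    """
    advantages = [0.0] * len(rewards)
    gae = 0.0
    for t in reversed(range(len(rewards))):
        next_value = last_value if t == len(rewards) - 1 else values[t + 1]
        next_nonterminal = 0.0 if dones[t] else 1.0
        delta = rewards[t] + gamma * next_value * next_nonterminal - values[t]
        gae = delta + gamma * gae_lambda * next_nonterminal * gae
        advantages[t] = gae
    returns = [a + v for a, v in zip(advantages, values)]
    return advantages, returns
```

The resulting advantages and returns are what the algorithm attaches to the buffer extras before the update phase.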

Policy Interface#

All policies must inherit from the Policy abstract base class:

from abc import ABC, abstractmethod
from typing import Tuple

import torch
import torch.nn as nn

class Policy(nn.Module, ABC):
    device: torch.device

    @abstractmethod
    def get_action(
        self, obs: torch.Tensor, deterministic: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Returns (action, log_prob, value)"""
        raise NotImplementedError

    @abstractmethod
    def get_value(self, obs: torch.Tensor) -> torch.Tensor:
        """Returns value estimate"""
        raise NotImplementedError

    @abstractmethod
    def evaluate_actions(
        self, obs: torch.Tensor, actions: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Returns (log_prob, entropy, value)"""
        raise NotImplementedError

Available Policies#

  • ActorCritic: MLP-based Gaussian policy with learnable log_std. Requires external actor and critic modules to be provided (defined in JSON config). Used with PPO.

  • ActorOnly: Actor-only policy without Critic. Used with GRPO (group-relative advantage estimation).

  • VLAPlaceholderPolicy: Placeholder for Vision-Language-Action policies

Algorithms#

Available Algorithms#

  • PPO: Proximal Policy Optimization with GAE

  • GRPO: Group Relative Policy Optimization (no Critic, step-wise returns, masked group normalization). Use actor_only policy. Set kl_coef=0 for from-scratch training (CartPole, dense reward); kl_coef=0.02 for VLA/LLM fine-tuning.

Adding a New Algorithm#

To add a new algorithm:

  1. Create a new algorithm class in embodichain/agents/rl/algo/

  2. Implement initialize_buffer(), collect_rollout(), and update() methods

  3. Register in algo/__init__.py:

import torch

from embodichain.agents.rl.algo import BaseAlgorithm, register_algo
from embodichain.agents.rl.buffer import RolloutBuffer

@register_algo("my_algo")
class MyAlgorithm(BaseAlgorithm):
    def __init__(self, cfg, policy):
        self.cfg = cfg
        self.policy = policy
        self.device = torch.device(cfg.device)
        self.buffer = None

    def initialize_buffer(self, num_steps, num_envs, obs_dim, action_dim):
        """Initialize the algorithm's buffer."""
        self.buffer = RolloutBuffer(num_steps, num_envs, obs_dim, action_dim, self.device)

    def collect_rollout(self, env, policy, obs, num_steps, on_step_callback=None):
        """Control data collection process (interact with env, fill buffer, compute advantages/returns)."""
        # Collect trajectories
        # Compute advantages/returns (e.g., GAE for on-policy algorithms)
        # Attach extras to buffer: self.buffer.set_extras({"advantages": adv, "returns": ret})
        # Return empty dict (dense logging handled in trainer)
        return {}

    def update(self):
        """Update the policy using collected data."""
        # Access extras from buffer: self.buffer._extras.get("advantages")
        # Use self.buffer to update policy
        return {"loss": 0.0}

Adding a New Policy#

To add a new policy:

  1. Create a new policy class inheriting from the Policy abstract base class

  2. Register in models/__init__.py:

from embodichain.agents.rl.models import register_policy, Policy

@register_policy("my_policy")
class MyPolicy(Policy):
    def __init__(self, obs_space, action_space, device, config):
        super().__init__()
        self.device = device
        # Initialize your networks here

    def get_action(self, obs, deterministic=False):
        ...
    def get_value(self, obs):
        ...
    def evaluate_actions(self, obs, actions):
        ...
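For a Gaussian policy like ActorCritic, the (log_prob, entropy, value) triple returned by evaluate_actions reduces to closed-form diagonal-Normal formulas. A dependency-free sketch of those two scalar quantities (not the framework's torch implementation, which batches over tensors):

```python
import math

def diag_gaussian_log_prob_entropy(action, mean, std):
    """Log-probability and entropy of one action under a diagonal Gaussian.

    log N(a | m, s) summed over dimensions, and the closed-form entropy
    0.5 * log(2*pi*e*s^2) per dimension.
    """
    log_prob = sum(
        -0.5 * ((a - m) / s) ** 2 - math.log(s) - 0.5 * math.log(2 * math.pi)
        for a, m, s in zip(action, mean, std)
    )
    entropy = sum(0.5 * math.log(2 * math.pi * math.e * s ** 2) for s in std)
    return log_prob, entropy
```

These are the quantities that feed PPO's ratio (via log_prob differences) and entropy bonus (via ent_coef) during the update phase.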

Adding a New Environment#

To add a new RL environment:

  1. Create an environment class inheriting from EmbodiedEnv (with Action Manager configured for action preprocessing and standardized info structure):

from embodichain.lab.gym.envs import EmbodiedEnv, EmbodiedEnvCfg
from embodichain.lab.gym.utils.registration import register_env
import torch

@register_env("MyTaskRL", override=True)
class MyTaskEnv(EmbodiedEnv):
    def __init__(self, cfg: EmbodiedEnvCfg = None, **kwargs):
        super().__init__(cfg, **kwargs)

    def compute_task_state(self, **kwargs):
        """Compute success/failure conditions and metrics."""
        is_success = ...  # Define success condition
        is_fail = torch.zeros_like(is_success)
        metrics = {"distance": ..., "error": ...}
        return is_success, is_fail, metrics

  2. Configure the environment in your JSON config with actions and extensions:

"env": {
  "id": "MyTaskRL",
  "cfg": {
    "num_envs": 4,
    "actions": {
      "delta_qpos": {
        "func": "DeltaQposTerm",
        "params": { "scale": 0.1 }
      }
    },
    "extensions": {
      "success_threshold": 0.05
    }
  }
}

The EmbodiedEnv with Action Manager provides:

  • Action Preprocessing: Configurable via actions (DeltaQposTerm, QposTerm, EefPoseTerm, etc.)

  • Standardized Info: Implements get_info() using compute_task_state() template method

Best Practices#

  • Use EmbodiedEnv with Action Manager for RL Tasks: Inherit from EmbodiedEnv and configure actions in your config. The Action Manager handles action preprocessing (delta_qpos, qpos, qvel, qf, eef_pose) in a modular way.

  • Action Configuration: Use the actions field in your JSON config. Example: "delta_qpos": {"func": "DeltaQposTerm", "params": {"scale": 0.1}}.

  • Device Management: Device is single-sourced from the trainer.device setting (e.g. "cuda:0"). All components (trainer/algorithm/policy/env) share the same device.

  • Observation Format: Environments should provide consistent observation shape/types (torch.float32) and a single done = terminated | truncated.

  • Algorithm Interface: Algorithms must implement initialize_buffer(), collect_rollout(), and update() methods. The algorithm completely controls data collection and buffer management.

  • Reward Configuration: Use the RewardManager in your environment config to define reward components. Organize reward components in info["rewards"] dictionary and metrics in info["metrics"] dictionary. The trainer performs dense per-step logging directly from environment info.

  • Template Methods: Override compute_task_state() to define success/failure conditions and metrics. Override check_truncated() for custom truncation logic.

  • Configuration: Use JSON for all hyperparameters. This makes experiments reproducible and easy to track.

  • Logging: Metrics are automatically logged to TensorBoard and Weights & Biases. Check outputs/<exp_name>_<timestamp>/logs/ for TensorBoard logs.

  • Checkpoints: Regular checkpoints are saved to outputs/<exp_name>_<timestamp>/checkpoints/. Use these to resume training or evaluate policies.