RL Algorithms
This module contains the core implementations of reinforcement learning algorithms, including PPO (Proximal Policy Optimization) and GRPO (Group Relative Policy Optimization).
Main Classes and Functions
BaseAlgorithm
Abstract base class for RL algorithms. It defines the common interface for buffer initialization, rollout collection, and policy updates.
Key methods:
initialize_buffer(num_steps, num_envs, obs_dim, action_dim): Initialize the trajectory buffer.
collect_rollout(env, policy, obs, num_steps, on_step_callback): Collect interaction data.
update(): Update the policy based on the collected data.
The interface is algorithm-agnostic: the Trainer depends only on it, so any algorithm that implements it can be plugged in.
Supports parallel collection across multiple environments and is compatible with Gymnasium and IsaacGym environments.
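To make the shapes passed to initialize_buffer concrete, here is a minimal sketch of a time-major trajectory buffer indexed as [step][env]. The RolloutBuffer class, its field names, and its add method are hypothetical illustrations rather than this module's API, and plain Python lists stand in for the tensors a real implementation would likely use.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class RolloutBuffer:
    """Hypothetical time-major buffer (not the module's API): index as [step][env][...]."""
    num_steps: int
    num_envs: int
    obs_dim: int
    action_dim: int
    obs: List[List[List[float]]] = field(init=False)
    actions: List[List[List[float]]] = field(init=False)
    rewards: List[List[float]] = field(init=False)
    dones: List[List[float]] = field(init=False)

    def __post_init__(self) -> None:
        # One zero-filled slot per (step, env); vectors get their own list objects.
        self.obs = [[[0.0] * self.obs_dim for _ in range(self.num_envs)]
                    for _ in range(self.num_steps)]
        self.actions = [[[0.0] * self.action_dim for _ in range(self.num_envs)]
                        for _ in range(self.num_steps)]
        self.rewards = [[0.0] * self.num_envs for _ in range(self.num_steps)]
        self.dones = [[0.0] * self.num_envs for _ in range(self.num_steps)]

    def add(self, step, obs, actions, rewards, dones):
        # Store one parallel step; every argument holds one entry per env.
        self.obs[step] = [list(o) for o in obs]
        self.actions[step] = [list(a) for a in actions]
        self.rewards[step] = list(rewards)
        self.dones[step] = [float(d) for d in dones]


buf = RolloutBuffer(num_steps=4, num_envs=2, obs_dim=3, action_dim=1)
buf.add(0, obs=[[1, 2, 3], [4, 5, 6]], actions=[[0.5], [-0.5]],
        rewards=[1.0, 0.0], dones=[False, True])
print(len(buf.obs), len(buf.obs[0]), len(buf.obs[0][0]))  # 4 2 3
```

Keeping the step axis first makes the reverse-time loops used for return and advantage computation straightforward.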
PPO
A widely used on-policy algorithm. Supports Generalized Advantage Estimation (GAE), clipped policy updates, and configurable hyperparameters.
Key methods:
_compute_gae(rewards, values, dones): Compute advantages with Generalized Advantage Estimation.
collect_rollout: Collect trajectories and compute advantages and returns.
update: Run multi-epoch minibatch optimization of the combined policy, value, and entropy losses, with gradient clipping.
Supports custom callbacks, detailed logging, and GPU acceleration.
Typical training flow: collect rollout → compute advantage/return → multi-epoch minibatch optimization.
Supports advantage normalization, entropy regularization, value loss weighting, etc.
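The GAE recursion behind _compute_gae can be sketched as follows. This standalone compute_gae is an illustration under assumptions: it takes an extra last_values argument for bootstrapping past the final step, its gamma and lam parameters (names and defaults) are hypothetical, and it treats dones[t][e] = 1 as "env e's episode ended at step t". The normalize helper shows batch-level advantage normalization.

```python
from typing import List, Tuple

Grid = List[List[float]]  # indexed [step][env]


def compute_gae(rewards: Grid, values: Grid, dones: Grid, last_values: List[float],
                gamma: float = 0.99, lam: float = 0.95) -> Tuple[Grid, Grid]:
    """GAE sketch; dones[t][e] = 1.0 means env e's episode ended at step t."""
    num_steps, num_envs = len(rewards), len(rewards[0])
    advantages = [[0.0] * num_envs for _ in range(num_steps)]
    running_adv = [0.0] * num_envs
    next_values = list(last_values)  # V(s_{t+1}) for the final step (bootstrap)
    for t in reversed(range(num_steps)):
        for e in range(num_envs):
            not_done = 1.0 - dones[t][e]
            # TD error: delta_t = r_t + gamma * V(s_{t+1}) * (1 - d_t) - V(s_t)
            delta = rewards[t][e] + gamma * next_values[e] * not_done - values[t][e]
            # A_t = delta_t + gamma * lam * (1 - d_t) * A_{t+1}
            running_adv[e] = delta + gamma * lam * not_done * running_adv[e]
            advantages[t][e] = running_adv[e]
        next_values = values[t]
    # Value targets for the critic: R_t = A_t + V(s_t)
    returns = [[a + v for a, v in zip(adv_row, val_row)]
               for adv_row, val_row in zip(advantages, values)]
    return advantages, returns


def normalize(advantages: Grid, eps: float = 1e-8) -> Grid:
    """Advantage normalization: zero mean, unit std over the whole batch."""
    flat = [a for row in advantages for a in row]
    mean = sum(flat) / len(flat)
    std = (sum((a - mean) ** 2 for a in flat) / len(flat)) ** 0.5
    return [[(a - mean) / (std + eps) for a in row] for row in advantages]


# One env, two steps; the episode ends at step 1, so last_values is not used.
adv, ret = compute_gae([[1.0], [1.0]], [[0.0], [0.0]], [[0.0], [1.0]],
                       last_values=[5.0], gamma=1.0, lam=1.0)
print(adv)  # [[2.0], [1.0]]
```

In the example, the done flag at the last step blocks the bootstrap value 5.0, so with gamma = lam = 1 and zero value estimates the advantages reduce to the sums of the remaining rewards.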
GRPO
Group Relative Policy Optimization replaces the critic network with group-level return comparison: each return is judged relative to its peers in the group, which saves the memory a critic would require.
Step-wise returns: computes per-step discounted returns $R_t = r_t + \gamma R_{t+1}$ by reverse accumulation, so each step's return depends only on the current and future rewards. This avoids causality issues and discount bias in dense-reward Embodied AI tasks.
Masked group normalization: for variable-length sequences (e.g. with truncate_at_first_done), the group mean and std at each step are computed only over peers that are still alive, so the zeros of terminated envs do not drag down the mean.
Optional reference policy: when kl_coef > 0, a frozen reference policy is created for KL regularization (e.g. VLA fine-tuning). When kl_coef = 0, no reference policy is created (recommended for from-scratch training such as CartPole).
Key methods:
_compute_step_returns_and_mask(rewards, dones): Compute step-wise discounted returns and the valid-step mask.
_compute_step_group_advantages(step_returns, seq_mask): Per-step group normalization with masked mean/std.
collect_rollout: Collect trajectories and compute step-wise advantages.
update: Multi-epoch minibatch optimization with an optional KL penalty.
Supports both Embodied AI (dense reward, from-scratch training) and VLA (sparse reward, fine-tuning) modes via the kl_coef setting.
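A minimal sketch of the two GRPO helpers, assuming that the group at each step is the set of parallel envs, that a sequence is truncated at its first done (the done step itself stays valid), and that the std is the population std. The names mirror _compute_step_returns_and_mask and _compute_step_group_advantages, but these are illustrative re-implementations, not the module's code; gamma and eps are hypothetical parameters.

```python
from typing import List, Tuple

Grid = List[List[float]]  # indexed [step][env]; the envs at one step form a group


def step_returns_and_mask(rewards: Grid, dones: Grid,
                          gamma: float = 0.99) -> Tuple[Grid, Grid]:
    """Per-step returns R_t = r_t + gamma * R_{t+1}, truncated at each env's first done."""
    num_steps, num_envs = len(rewards), len(rewards[0])
    mask = [[0.0] * num_envs for _ in range(num_steps)]
    for e in range(num_envs):
        for t in range(num_steps):
            mask[t][e] = 1.0  # valid up to and including the first done
            if dones[t][e]:
                break
    returns = [[0.0] * num_envs for _ in range(num_steps)]
    for e in range(num_envs):
        running = 0.0
        for t in reversed(range(num_steps)):
            # Reverse accumulation; steps after the first done contribute nothing.
            running = (rewards[t][e] + gamma * running) if mask[t][e] else 0.0
            returns[t][e] = running
    return returns, mask


def step_group_advantages(returns: Grid, mask: Grid, eps: float = 1e-8) -> Grid:
    """Normalize R_t across envs at each step, using only alive (mask == 1) peers."""
    advantages = []
    for ret_row, mask_row in zip(returns, mask):
        alive = [r for r, m in zip(ret_row, mask_row) if m]
        if not alive:  # every env has already terminated at this step
            advantages.append([0.0] * len(ret_row))
            continue
        mean = sum(alive) / len(alive)
        std = (sum((r - mean) ** 2 for r in alive) / len(alive)) ** 0.5
        # Dead envs get zero advantage and are excluded from mean/std.
        advantages.append([(r - mean) / (std + eps) if m else 0.0
                           for r, m in zip(ret_row, mask_row)])
    return advantages


# Three envs, two steps; env 1 terminates at step 0.
ret, mask = step_returns_and_mask([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
                                  [[0, 1, 0], [0, 0, 0]], gamma=1.0)
adv = step_group_advantages(ret, mask)
print(ret)  # [[2.0, 1.0, 2.0], [1.0, 0.0, 1.0]]
```

At step 1 the terminated env is excluded, so the two surviving envs are compared only with each other. Counting its zero return would have pulled the group mean down to about 0.67 and given both survivors spuriously positive advantages.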
Config Classes
AlgorithmCfg, PPOCfg, GRPOCfg: centralize the learning rate, batch size, clip_coef, ent_coef, vf_coef, and other parameters.
Support automatic loading from JSON config files for batch experiments and parameter tuning.
Can be extended via inheritance to cover additional algorithms and tasks.
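One way to realize inheritable config classes with JSON loading is with dataclasses, as sketched below. Only the class names AlgorithmCfg, PPOCfg, and GRPOCfg and the parameter names listed above come from this page; the other fields (num_epochs, gamma, gae_lambda), all default values, and the from_json helper are hypothetical.

```python
import json
from dataclasses import dataclass, fields


@dataclass
class AlgorithmCfg:
    # Shared hyperparameters; every default here is illustrative only.
    learning_rate: float = 3e-4
    batch_size: int = 64
    num_epochs: int = 4
    gamma: float = 0.99
    clip_coef: float = 0.2
    ent_coef: float = 0.0

    @classmethod
    def from_json(cls, text: str):
        """Build a config from a JSON object, rejecting keys the class does not define."""
        data = json.loads(text)
        known = {f.name for f in fields(cls)}
        unknown = set(data) - known
        if unknown:
            raise ValueError(f"unknown config keys: {sorted(unknown)}")
        return cls(**data)


@dataclass
class PPOCfg(AlgorithmCfg):
    vf_coef: float = 0.5      # weight of the value loss (PPO has a critic)
    gae_lambda: float = 0.95


@dataclass
class GRPOCfg(AlgorithmCfg):
    kl_coef: float = 0.0      # 0 => no reference policy is created


cfg = GRPOCfg.from_json('{"learning_rate": 1e-5, "kl_coef": 0.02}')
print(cfg.kl_coef, cfg.batch_size)  # 0.02 64
```

Rejecting unknown keys, instead of silently dropping them, turns a typo in a JSON file into an immediate error, which helps keep batch experiments reproducible.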
Code Example
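```python
class BaseAlgorithm:
    """Algorithm-agnostic interface; the Trainer depends only on these methods."""

    def initialize_buffer(self, num_steps, num_envs, obs_dim, action_dim):
        ...  # allocate the trajectory buffer

    def collect_rollout(self, env, policy, obs, num_steps, on_step_callback=None):
        ...  # collect interaction data from the environments

    def update(self):
        ...  # update the policy from the collected data


class PPO(BaseAlgorithm):
    def _compute_gae(self, rewards, values, dones):
        ...  # Generalized Advantage Estimation

    def collect_rollout(self, env, policy, obs, num_steps, on_step_callback=None):
        ...  # collect trajectories, then compute advantages and returns

    def update(self):
        ...  # multi-epoch minibatch optimization
```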
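The PPO update combines a clipped policy loss, a value loss, and an entropy bonus. The scalar sketch below shows that combination for one minibatch; the function name ppo_loss, the 0.5 factor on the squared-error value loss, and the default coefficients are assumptions, and plain floats replace the tensors a real implementation would differentiate.

```python
import math
from typing import List


def ppo_loss(new_logp: List[float], old_logp: List[float], advantages: List[float],
             values: List[float], returns: List[float], entropies: List[float],
             clip_coef: float = 0.2, vf_coef: float = 0.5, ent_coef: float = 0.01) -> float:
    """Combined PPO loss (to be minimized) for one minibatch of samples."""
    n = len(advantages)
    policy_loss = 0.0
    for lp_new, lp_old, adv in zip(new_logp, old_logp, advantages):
        ratio = math.exp(lp_new - lp_old)  # pi_new(a|s) / pi_old(a|s)
        clipped = min(max(ratio, 1.0 - clip_coef), 1.0 + clip_coef)
        # Pessimistic surrogate: the smaller of the unclipped and clipped objectives.
        policy_loss -= min(ratio * adv, clipped * adv)
    policy_loss /= n
    value_loss = 0.5 * sum((v - r) ** 2 for v, r in zip(values, returns)) / n
    entropy = sum(entropies) / n
    # Subtracting the entropy bonus rewards more exploratory policies.
    return policy_loss + vf_coef * value_loss - ent_coef * entropy


# Ratio 2.0 with a positive advantage is clipped to 1 + clip_coef = 1.2.
print(ppo_loss([math.log(2.0)], [0.0], [1.0], [0.0], [0.0], [0.0]))  # -1.2
```

In a PyTorch implementation the same expression would be computed on tensors and backpropagated, with the gradients clipped (for example via torch.nn.utils.clip_grad_norm_) before each optimizer step.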
Usage Recommendations
Manage all algorithm parameters through config classes and JSON config files for reproducibility and easier tuning.
Use multi-environment parallel collection to improve sampling efficiency.
Implement custom algorithm classes to add new RL methods.
GRPO: use an actor_only policy (no critic). Set kl_coef=0 for from-scratch training (e.g. CartPole, dense reward); set kl_coef=0.02 for VLA/LLM fine-tuning.
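To show what kl_coef controls, here is a sketch of a GRPO loss with an optional KL penalty toward a frozen reference policy. The KL term uses the k3 estimator exp(d) - d - 1 with d = log π_ref - log π, as in the original GRPO formulation; whether this module uses the same estimator, as well as the function name grpo_loss and its signature, are assumptions.

```python
import math
from typing import List, Optional


def grpo_loss(new_logp: List[float], old_logp: List[float], advantages: List[float],
              ref_logp: Optional[List[float]] = None,
              clip_coef: float = 0.2, kl_coef: float = 0.0) -> float:
    """Clipped surrogate with an optional KL penalty toward a frozen reference policy."""
    if kl_coef > 0 and ref_logp is None:
        raise ValueError("kl_coef > 0 requires log-probs from the reference policy")
    total = 0.0
    for i, adv in enumerate(advantages):
        ratio = math.exp(new_logp[i] - old_logp[i])
        clipped = min(max(ratio, 1.0 - clip_coef), 1.0 + clip_coef)
        total -= min(ratio * adv, clipped * adv)
        if kl_coef > 0:
            # k3 estimator of KL(pi || pi_ref): exp(d) - d - 1 >= 0, d = log pi_ref - log pi.
            d = ref_logp[i] - new_logp[i]
            total += kl_coef * (math.exp(d) - d - 1.0)
    return total / len(advantages)


# kl_coef = 0: no reference log-probs needed (from-scratch training).
print(grpo_loss([0.0], [0.0], [1.0]))  # -1.0
```

Because the KL branch is skipped when kl_coef = 0, no reference log-probs, and hence no reference model, are needed, which is consistent with recommending kl_coef=0 for from-scratch training.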
Extension Notes
Users can inherit from BaseAlgorithm to implement custom algorithms and integrate them into the RL framework.
Supports multi-environment parallelism and event-driven extension.
Typical usage:
```python
algo = PPO(cfg, policy)
buffer = algo.initialize_buffer(...)
for _ in range(num_iterations):
    algo.collect_rollout(...)  # gather a batch of trajectories
    algo.update()              # optimize the policy on that batch
```
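As a self-contained illustration of extending the framework, the sketch below subclasses a stand-in BaseAlgorithm (it only mirrors the interface described on this page) with a hypothetical running-mean bandit algorithm, and uses on_step_callback as an event hook. ToyBanditEnv, ArmValueAlgorithm, epsilon_greedy, the policy(obs, values) calling convention, and the env.step return format are all hypothetical.

```python
import random
from typing import List


class BaseAlgorithm:
    """Stand-in that mirrors the documented interface so this sketch runs on its own."""

    def initialize_buffer(self, num_steps, num_envs, obs_dim, action_dim):
        raise NotImplementedError

    def collect_rollout(self, env, policy, obs, num_steps, on_step_callback=None):
        raise NotImplementedError

    def update(self):
        raise NotImplementedError


class ToyBanditEnv:
    """Hypothetical vectorized 2-armed bandit: arm 1 pays 1.0, arm 0 pays 0.0."""

    def __init__(self, num_envs: int):
        self.num_envs = num_envs

    def step(self, actions: List[int]):
        rewards = [float(a == 1) for a in actions]
        obs = [[0.0] for _ in actions]   # stateless bandit: constant observation
        dones = [True] * self.num_envs   # every pull is a one-step episode
        return obs, rewards, dones


class ArmValueAlgorithm(BaseAlgorithm):
    """Hypothetical custom algorithm: running-mean value estimate per arm."""

    def initialize_buffer(self, num_steps, num_envs, obs_dim, action_dim):
        self.actions: List[List[int]] = []
        self.rewards: List[List[float]] = []
        self.values = [0.0] * action_dim
        self.counts = [0] * action_dim

    def collect_rollout(self, env, policy, obs, num_steps, on_step_callback=None):
        self.actions, self.rewards = [], []
        for step in range(num_steps):
            actions = [policy(o, self.values) for o in obs]
            obs, rewards, dones = env.step(actions)
            self.actions.append(actions)
            self.rewards.append(rewards)
            if on_step_callback is not None:
                on_step_callback(step, rewards)  # event hook, e.g. for logging
        return obs

    def update(self):
        # Incremental mean: V_a <- V_a + (r - V_a) / N_a
        for acts, rews in zip(self.actions, self.rewards):
            for a, r in zip(acts, rews):
                self.counts[a] += 1
                self.values[a] += (r - self.values[a]) / self.counts[a]


def epsilon_greedy(obs, values, eps=0.2):
    if random.random() < eps:
        return random.randrange(len(values))
    return max(range(len(values)), key=lambda a: values[a])


random.seed(0)
num_envs = 8
env = ToyBanditEnv(num_envs)
algo = ArmValueAlgorithm()
algo.initialize_buffer(num_steps=20, num_envs=num_envs, obs_dim=1, action_dim=2)
steps_seen = []
obs = [[0.0] for _ in range(num_envs)]
for _ in range(3):
    obs = algo.collect_rollout(env, epsilon_greedy, obs, num_steps=20,
                               on_step_callback=lambda step, rewards: steps_seen.append(step))
    algo.update()
print(algo.values)
```

The training loop only calls the three interface methods, so swapping ArmValueAlgorithm for any other subclass leaves the loop unchanged, which is the property the algorithm-agnostic Trainer relies on.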