# Source code for embodichain.agents.rl.algo.grpo

# ----------------------------------------------------------------------------
# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ----------------------------------------------------------------------------

from __future__ import annotations

from copy import deepcopy
from typing import Dict

import torch
from tensordict import TensorDict

from embodichain.agents.rl.buffer import iterate_minibatches, transition_view
from embodichain.agents.rl.utils import AlgorithmCfg
from embodichain.utils import configclass
from .base import BaseAlgorithm


@configclass
class GRPOCfg(AlgorithmCfg):
    """Configuration for GRPO.

    Only the GRPO-specific fields are declared here; everything else
    (for example the learning rate, discount, batch size, device and gradient
    clipping threshold that ``GRPO`` reads from its config) is expected to come
    from ``AlgorithmCfg``.
    """

    # Number of full passes over the rollout performed by each `update` call.
    n_epochs: int = 10

    # Half-width of the clamp applied to the policy probability ratio:
    # the ratio is clamped to [1 - clip_coef, 1 + clip_coef].
    clip_coef: float = 0.2

    # Weight of the (negated, masked) entropy term in the total loss.
    ent_coef: float = 0.0

    # Weight of the reference-policy KL term. A frozen reference copy of the
    # policy is only created when this is strictly positive.
    kl_coef: float = 0.02

    # Number of consecutive environments whose returns are normalized together.
    # GRPO rejects values below 2 and requires num_envs to be divisible by it.
    group_size: int = 4

    # Added to the per-group standard deviation to avoid division by zero.
    eps: float = 1e-8

    # Not read by the GRPO algorithm class itself.
    # NOTE(review): presumably consumed by the rollout collector/trainer - confirm.
    reset_every_rollout: bool = True

    # When True, steps after the first `done` in each environment are masked
    # out of the loss and of the group statistics.
    truncate_at_first_done: bool = True
class GRPO(BaseAlgorithm):
    """Group Relative Policy Optimization on top of TensorDict rollouts.

    Environments are partitioned into consecutive groups of
    ``cfg.group_size``. Per-step discounted returns are standardized within
    each group to form advantages, which then drive a clipped-ratio policy
    loss, optionally regularized by an entropy bonus and by a KL penalty
    towards a frozen copy of the initial policy.
    """

    def __init__(self, cfg: GRPOCfg, policy):
        """Set up the optimizer and, if requested, the frozen reference policy.

        Args:
            cfg: Algorithm configuration. Besides the ``GRPOCfg`` fields this
                reads ``device`` and ``learning_rate``.
            policy: Trainable policy exposing ``parameters()``. It may be a
                wrapper holding the real policy under a ``module`` attribute
                (as distributed wrappers do), in which case the wrapped
                policy is the one copied for the reference.

        Raises:
            ValueError: If ``cfg.group_size`` is smaller than 2.
        """
        if cfg.group_size < 2:
            raise ValueError(
                f"GRPO requires group_size >= 2 for within-group normalization, got {cfg.group_size}."
            )

        self.cfg = cfg
        self.policy = policy
        self.device = torch.device(cfg.device)
        self.optimizer = torch.optim.Adam(policy.parameters(), lr=cfg.learning_rate)

        if self.cfg.kl_coef > 0.0:
            # Snapshot the unwrapped policy so the reference is a plain module,
            # then freeze it: it is only ever evaluated, never trained.
            raw_policy = getattr(policy, "module", policy)
            self.ref_policy = deepcopy(raw_policy).to(self.device).eval()
            for param in self.ref_policy.parameters():
                param.requires_grad_(False)
        else:
            self.ref_policy = None

    def _compute_step_returns_and_mask(
        self, rewards: torch.Tensor, dones: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute discounted returns and valid-step mask over `[N, T]` rollout.

        Args:
            rewards: Per-step rewards of shape ``[N, T]``.
            dones: Boolean episode-termination flags of shape ``[N, T]``.

        Returns:
            A ``(step_returns, seq_mask)`` pair of float32 ``[N, T]`` tensors
            located on the device of ``rewards``. ``step_returns[:, t]`` is
            the discounted reward-to-go from step ``t``, with the accumulation
            cut at every ``done``. ``seq_mask`` is 1 up to and including the
            first ``done`` of each environment and 0 afterwards when
            ``cfg.truncate_at_first_done`` is set; otherwise it is all ones.
        """
        n_envs, t_steps = rewards.shape
        # Work on the rollout's own device. The rollout may be stored on a
        # different device than the one used for optimization, and mixing
        # tensors from two devices in the arithmetic below would raise.
        device = rewards.device
        not_done = (~dones).to(device=device, dtype=torch.float32)

        seq_mask = torch.ones((n_envs, t_steps), dtype=torch.float32, device=device)
        step_returns = torch.zeros(
            (n_envs, t_steps), dtype=torch.float32, device=device
        )

        # Forward pass: a step is valid while no earlier step was terminal.
        # The terminal step itself stays valid because `alive` is updated
        # only after it has been written to the mask.
        alive = torch.ones(n_envs, dtype=torch.float32, device=device)
        for t in range(t_steps):
            seq_mask[:, t] = alive
            if self.cfg.truncate_at_first_done:
                alive = alive * not_done[:, t]

        # Backward pass: reward-to-go, reset across episode boundaries.
        running_return = torch.zeros(n_envs, dtype=torch.float32, device=device)
        for t in reversed(range(t_steps)):
            running_return = (
                rewards[:, t] + self.cfg.gamma * running_return * not_done[:, t]
            )
            step_returns[:, t] = running_return

        return step_returns, seq_mask

    def _compute_step_group_advantages(
        self, step_returns: torch.Tensor, seq_mask: torch.Tensor
    ) -> torch.Tensor:
        """Normalize per-step returns within each environment group.

        Consecutive blocks of ``cfg.group_size`` environments form a group.
        For every group and time step, the mean and (population) standard
        deviation are taken over the group members whose mask is non-zero.

        Args:
            step_returns: Returns of shape ``[N, T]``; ``N`` must be a
                multiple of ``cfg.group_size``.
            seq_mask: Validity mask of shape ``[N, T]``.

        Returns:
            Advantages of shape ``[N, T]``, zeroed wherever the mask is zero.
        """
        n_envs, t_steps = step_returns.shape
        group_size = self.cfg.group_size
        n_groups = n_envs // group_size

        returns_grouped = step_returns.view(n_groups, group_size, t_steps)
        mask_grouped = seq_mask.view(n_groups, group_size, t_steps)

        # Clamp so that a (group, step) cell with no valid member divides by 1
        # instead of 0; its advantages are masked out at the end anyway.
        valid_count = mask_grouped.sum(dim=1, keepdim=True)
        valid_count_safe = torch.clamp(valid_count, min=1.0)

        group_mean = (returns_grouped * mask_grouped).sum(
            dim=1, keepdim=True
        ) / valid_count_safe
        diff_sq = ((returns_grouped - group_mean) ** 2) * mask_grouped
        group_var = diff_sq.sum(dim=1, keepdim=True) / valid_count_safe
        group_std = torch.sqrt(group_var)

        advantages = (returns_grouped - group_mean) / (group_std + self.cfg.eps)
        return advantages.view(n_envs, t_steps) * seq_mask

    def update(self, rollout: TensorDict) -> Dict[str, float]:
        """Run ``cfg.n_epochs`` of clipped policy-gradient updates on a rollout.

        Args:
            rollout: Rollout whose first batch dimension is the number of
                environments. It must provide ``"reward"`` and ``"done"``
                entries laid out as ``[N, T + 1]``; the last time slot is
                excluded from returns and advantages and left at zero.
                The rollout is cloned, so the caller's object is not modified.
                NOTE(review): the last slot presumably holds the bootstrap
                observation only - confirm against the collector.

        Returns:
            Mask-weighted averages over all minibatches of ``"actor_loss"``,
            ``"entropy"`` and ``"approx_ref_kl"``.

        Raises:
            ValueError: If the number of environments is not divisible by
                ``cfg.group_size``.
        """
        rollout = rollout.clone()
        num_envs = rollout.batch_size[0]
        if num_envs % self.cfg.group_size != 0:
            raise ValueError(
                f"GRPO requires num_envs divisible by group_size, got "
                f"num_envs={num_envs}, group_size={self.cfg.group_size}."
            )

        rewards = rollout["reward"][:, :-1].float()
        dones = rollout["done"][:, :-1].bool()
        step_returns, seq_mask = self._compute_step_returns_and_mask(rewards, dones)

        # Store the per-step quantities alongside the rollout so that they are
        # flattened and shuffled together with the transitions.
        rollout["advantage"] = torch.zeros_like(rollout["reward"], dtype=torch.float32)
        rollout["advantage"][:, :-1] = self._compute_step_group_advantages(
            step_returns, seq_mask
        )
        rollout["seq_mask"] = torch.zeros_like(rollout["reward"], dtype=torch.float32)
        rollout["seq_mask"][:, :-1] = seq_mask
        rollout["seq_return"] = torch.zeros_like(rollout["reward"], dtype=torch.float32)
        rollout["seq_return"][:, :-1] = step_returns

        flat_rollout = transition_view(rollout, flatten=True)

        # Evaluate through the unwrapped policy; the lookup does not depend on
        # the minibatch, so do it once.
        policy_module = getattr(self.policy, "module", self.policy)

        total_actor_loss = 0.0
        total_entropy = 0.0
        total_kl = 0.0
        total_weight = 0.0

        for _ in range(self.cfg.n_epochs):
            for batch in iterate_minibatches(
                flat_rollout, self.cfg.batch_size, self.device
            ):
                # Cloned before evaluation so the behaviour-policy log-probs
                # survive even if `evaluate_actions` writes into `batch`.
                old_logprobs = batch["sample_log_prob"].clone()
                advantages = batch["advantage"].detach()
                seq_mask_batch = batch["seq_mask"].float()

                eval_batch = policy_module.evaluate_actions(batch)
                logprobs = eval_batch["sample_log_prob"]
                entropy = eval_batch["entropy"]

                # Clipped surrogate objective, averaged over valid steps only.
                ratio = (logprobs - old_logprobs).exp()
                surr1 = ratio * advantages
                surr2 = (
                    torch.clamp(
                        ratio, 1.0 - self.cfg.clip_coef, 1.0 + self.cfg.clip_coef
                    )
                    * advantages
                )
                denom = torch.clamp(seq_mask_batch.sum(), min=1.0)
                actor_loss = -(torch.min(surr1, surr2) * seq_mask_batch).sum() / denom

                masked_entropy = (entropy * seq_mask_batch).sum() / denom
                entropy_loss = -masked_entropy

                if self.ref_policy is not None:
                    with torch.no_grad():
                        ref_batch = self.ref_policy.evaluate_actions(batch)
                    ref_logprobs = ref_batch["sample_log_prob"]
                    # exp(x) - x - 1 with x = log(ref / pi): a per-sample KL
                    # estimate that is never negative.
                    log_ref_over_pi = ref_logprobs - logprobs
                    kl_per = torch.exp(log_ref_over_pi) - log_ref_over_pi - 1.0
                    kl = (kl_per * seq_mask_batch).sum() / denom
                else:
                    kl = torch.tensor(0.0, device=self.device)

                loss = (
                    actor_loss
                    + self.cfg.kl_coef * kl
                    + self.cfg.ent_coef * entropy_loss
                )

                self.optimizer.zero_grad(set_to_none=True)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(
                    self.policy.parameters(), self.cfg.max_grad_norm
                )
                self.optimizer.step()

                # Weight each minibatch by its number of valid steps so the
                # reported statistics are per-step averages.
                weight = float(denom.item())
                total_actor_loss += actor_loss.item() * weight
                total_entropy += masked_entropy.item() * weight
                total_kl += kl.item() * weight
                total_weight += weight

        norm = max(1.0, total_weight)
        return {
            "actor_loss": total_actor_loss / norm,
            "entropy": total_entropy / norm,
            "approx_ref_kl": total_kl / norm,
        }