Source code for embodichain.agents.rl.algo.ppo

# ----------------------------------------------------------------------------
# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ----------------------------------------------------------------------------

import torch
from typing import Dict, Any, Tuple, Callable

from tensordict import TensorDict

from embodichain.agents.rl.utils import AlgorithmCfg, flatten_dict_observation
from embodichain.agents.rl.buffer import RolloutBuffer
from embodichain.utils import configclass
from .base import BaseAlgorithm


@configclass
class PPOCfg(AlgorithmCfg):
    """Configuration for the PPO algorithm.

    Only the PPO-specific hyper-parameters are declared here. ``PPO`` also
    reads ``device``, ``learning_rate``, ``gamma``, ``gae_lambda``,
    ``batch_size`` and ``max_grad_norm`` from its config; those are not
    declared in this class, so they are presumably inherited from
    ``AlgorithmCfg`` (confirm against its definition).
    """

    # Number of full passes over the rollout buffer per ``PPO.update()`` call.
    n_epochs: int = 10
    # Half-width of the clipping interval: the probability ratio is clamped
    # to [1 - clip_coef, 1 + clip_coef] in the surrogate objective.
    clip_coef: float = 0.2
    # Weight of the entropy term in the total loss.
    ent_coef: float = 0.01
    # Weight of the value-function (MSE) term in the total loss.
    vf_coef: float = 0.5
class PPO(BaseAlgorithm):
    """PPO algorithm operating via Policy and RolloutBuffer (algo-agnostic design).

    Typical usage, as implied by the error messages raised below:

    1. ``initialize_buffer(...)`` once, before the first rollout.
    2. ``collect_rollout(...)`` to fill the buffer and attach the
       ``"advantages"`` / ``"returns"`` extras.
    3. ``update()`` to run ``cfg.n_epochs`` passes of clipped-surrogate
       optimisation over the buffer.
    """

    def __init__(self, cfg: PPOCfg, policy):
        """Store the config and policy and build the optimizer.

        Args:
            cfg: PPO configuration. ``cfg.device`` and ``cfg.learning_rate``
                are read here; the remaining fields are read lazily by the
                other methods.
            policy: Object exposing ``parameters()`` (used for the Adam
                optimizer and gradient clipping) and ``evaluate_actions``
                (used in ``update``).
        """
        self.cfg = cfg
        self.policy = policy
        self.device = torch.device(cfg.device)
        self.optimizer = torch.optim.Adam(policy.parameters(), lr=cfg.learning_rate)
        # Created by ``initialize_buffer``; ``None`` until then.
        self.buffer: RolloutBuffer | None = None
        # no per-rollout aggregation for dense logging

    def _compute_gae(
        self,
        rewards: torch.Tensor,
        values: torch.Tensor,
        dones: torch.Tensor,
        last_value: torch.Tensor | None = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Internal method to compute GAE. Only called by collect_rollout.

        Args:
            rewards: Per-step rewards of shape ``(T, N)``.
            values: Value estimates, added element-wise to the advantages to
                form the returns (so it must match ``rewards`` in shape).
            dones: Boolean done flags, indexed per step like ``rewards``.
            last_value: Optional value estimate of the observation that
                follows the final stored step. It is reshaped to the shape
                of one ``values`` row. When ``None`` the value after the
                final step is taken to be zero.

        Returns:
            ``(advantages, returns)`` where ``returns = advantages + values``.
        """
        T, N = rewards.shape
        gamma = self.cfg.gamma
        gae_lambda = self.cfg.gae_lambda

        # Value used as V(s_{T}) for the last stored step. Built from the
        # trailing shape of ``values`` so that an empty rollout (T == 0) does
        # not index into an empty tensor.
        if last_value is None:
            bootstrap = values.new_zeros(values.shape[1:])
        else:
            bootstrap = (
                last_value.detach()
                .to(device=values.device, dtype=values.dtype)
                .reshape(values.shape[1:])
            )

        advantages = torch.zeros_like(rewards, device=self.device)
        last_adv = torch.zeros(N, device=self.device)
        # Backward recursion: A_t = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}
        for t in reversed(range(T)):
            next_value = values[t + 1] if t < T - 1 else bootstrap
            # A done flag cuts both the bootstrap and the advantage carry-over.
            not_done = (~dones[t]).float()
            delta = rewards[t] + gamma * next_value * not_done - values[t]
            last_adv = delta + gamma * gae_lambda * not_done * last_adv
            advantages[t] = last_adv

        returns = advantages + values
        return advantages, returns

    def initialize_buffer(
        self, num_steps: int, num_envs: int, obs_dim: int, action_dim: int
    ):
        """Initialize the rollout buffer. Called by trainer before first rollout.

        Args:
            num_steps: Number of steps stored per rollout.
            num_envs: Number of parallel environments.
            obs_dim: Observation dimension passed to the buffer.
            action_dim: Action dimension passed to the buffer.
        """
        self.buffer = RolloutBuffer(
            num_steps, num_envs, obs_dim, action_dim, self.device
        )

    def collect_rollout(
        self,
        env,
        policy,
        obs: torch.Tensor,
        num_steps: int,
        on_step_callback: Callable | None = None,
    ) -> Dict[str, Any]:
        """Collect a rollout. Algorithm controls the data collection process.

        Args:
            env: Environment whose ``step`` accepts a ``{action_type: actions}``
                dict and returns ``(obs, reward, terminated, truncated, info)``.
            policy: Policy used for acting; put into train mode here.
            obs: Observation to start the rollout from.
            num_steps: Number of environment steps to collect.
            on_step_callback: Optional hook called after every step as
                ``on_step_callback(obs, actions, reward, done, info, next_obs)``.

        Returns:
            An empty dict; dense per-step logging is left to the callback.

        Raises:
            RuntimeError: If ``initialize_buffer`` has not been called.
        """
        if self.buffer is None:
            raise RuntimeError(
                "Buffer not initialized. Call initialize_buffer() first."
            )

        policy.train()
        self.buffer.step = 0
        current_obs = obs

        # The action key does not depend on the step, so resolve it once:
        # prefer the env's action manager, then the env attribute, then the
        # "delta_qpos" default.
        am = getattr(env, "action_manager", None)
        action_type = (
            am.action_type if am else getattr(env, "action_type", "delta_qpos")
        )

        for _ in range(num_steps):
            # Rollout data is only consumed as constants by ``update``, so no
            # autograd graph is needed while acting.
            with torch.no_grad():
                actions, log_prob, value = policy.get_action(
                    current_obs, deterministic=False
                )

            # Wrap action as dict for env processing and step the environment
            next_obs, reward, terminated, truncated, env_info = env.step(
                {action_type: actions}
            )

            # Light dtype normalization
            done = (terminated | truncated).bool()
            reward = reward.float()

            # Flatten TensorDict observation from ObservationManager if needed
            if isinstance(next_obs, TensorDict):
                next_obs = flatten_dict_observation(next_obs)

            self.buffer.add(current_obs, actions, reward, done, value, log_prob)

            # Dense logging is handled in Trainer.on_step via info; no aggregation here
            if on_step_callback is not None:
                on_step_callback(current_obs, actions, reward, done, env_info, next_obs)

            current_obs = next_obs

        # Value of the observation reached after the last stored step, so that
        # environments still running at the rollout boundary are bootstrapped
        # from the critic rather than from zero. Only the value output is used.
        with torch.no_grad():
            _, _, last_value = policy.get_action(current_obs, deterministic=True)

        # Compute advantages/returns and attach to buffer extras
        adv, ret = self._compute_gae(
            self.buffer.rewards,
            self.buffer.values,
            self.buffer.dones,
            last_value=last_value,
        )
        self.buffer.set_extras({"advantages": adv, "returns": ret})

        # No aggregated logging results; Trainer performs dense per-step logging
        return {}

    def update(self) -> Dict[str, float]:
        """Update the policy using the collected rollout buffer.

        Runs ``cfg.n_epochs`` passes over minibatches of size
        ``cfg.batch_size``, minimising
        ``actor_loss + vf_coef * value_loss + ent_coef * (-entropy)`` with
        gradient-norm clipping at ``cfg.max_grad_norm``.

        Returns:
            Sample-weighted means of ``"actor_loss"``, ``"value_loss"`` and
            ``"entropy"`` over all processed minibatches.

        Raises:
            RuntimeError: If the buffer is missing or holds no advantages.
        """
        if self.buffer is None:
            raise RuntimeError("Buffer not initialized. Call collect_rollout() first.")

        raw_adv = self.buffer._extras.get("advantages")
        if raw_adv is None:
            raise RuntimeError(
                "Advantages not found in buffer. Call collect_rollout() first."
            )

        # Normalization statistics are taken once from the raw advantages of
        # the whole rollout and then applied to each minibatch. Taking them
        # from an already-normalized copy would make the step a no-op.
        with torch.no_grad():
            adv_mean = raw_adv.mean()
            # std of a single element is NaN; fall back to unit scale.
            adv_std = raw_adv.std() if raw_adv.numel() > 1 else torch.ones_like(adv_mean)

        clip_low = 1.0 - self.cfg.clip_coef
        clip_high = 1.0 + self.cfg.clip_coef

        total_actor_loss = 0.0
        total_value_loss = 0.0
        total_entropy = 0.0
        total_steps = 0

        for _ in range(self.cfg.n_epochs):
            for batch in self.buffer.iterate_minibatches(self.cfg.batch_size):
                obs = batch["obs"]
                actions = batch["actions"]
                old_logprobs = batch["logprobs"]
                returns = batch["returns"]
                advantages = (
                    (batch["advantages"] - adv_mean) / (adv_std + 1e-8)
                ).detach()

                logprobs, entropy, values = self.policy.evaluate_actions(obs, actions)

                # Clipped surrogate objective on the new/old probability ratio.
                ratio = (logprobs - old_logprobs).exp()
                surr1 = ratio * advantages
                surr2 = torch.clamp(ratio, clip_low, clip_high) * advantages
                actor_loss = -torch.min(surr1, surr2).mean()

                value_loss = torch.nn.functional.mse_loss(values, returns)
                mean_entropy = entropy.mean()
                entropy_loss = -mean_entropy

                loss = (
                    actor_loss
                    + self.cfg.vf_coef * value_loss
                    + self.cfg.ent_coef * entropy_loss
                )

                self.optimizer.zero_grad(set_to_none=True)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(
                    self.policy.parameters(), self.cfg.max_grad_norm
                )
                self.optimizer.step()

                # Weight each minibatch by its size so the reported means are
                # per-sample even when the last minibatch is smaller.
                bs = obs.shape[0]
                total_actor_loss += actor_loss.item() * bs
                total_value_loss += value_loss.item() * bs
                total_entropy += mean_entropy.item() * bs
                total_steps += bs

        denom = max(1, total_steps)
        return {
            "actor_loss": total_actor_loss / denom,
            "value_loss": total_value_loss / denom,
            "entropy": total_entropy / denom,
        }