Source code for embodichain.agents.rl.algo.ppo
# ----------------------------------------------------------------------------
# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ----------------------------------------------------------------------------
from typing import Dict
import torch
from tensordict import TensorDict
from embodichain.agents.rl.buffer import iterate_minibatches, transition_view
from embodichain.agents.rl.utils import AlgorithmCfg
from embodichain.utils import configclass
from .common import compute_gae
from .base import BaseAlgorithm
[docs]
@configclass
class PPOCfg(AlgorithmCfg):
    """Configuration for the PPO algorithm.

    Extends ``AlgorithmCfg`` with PPO-specific hyper-parameters. ``PPO`` also
    reads ``device``, ``learning_rate``, ``gamma``, ``gae_lambda``,
    ``batch_size`` and ``max_grad_norm`` from its config; those are not
    declared here (presumably inherited from ``AlgorithmCfg`` — confirm).
    """

    # Number of full passes over each collected rollout per ``PPO.update`` call.
    n_epochs: int = 10
    # Clipping range for the probability ratio: it is clamped to
    # [1 - clip_coef, 1 + clip_coef] in the surrogate objective.
    clip_coef: float = 0.2
    # Weight of the entropy term; the total loss adds ent_coef * (-mean entropy),
    # so a larger value rewards higher policy entropy more strongly.
    ent_coef: float = 0.01
    # Weight of the value-function MSE loss in the total loss.
    vf_coef: float = 0.5
[docs]
class PPO(BaseAlgorithm):
    """PPO algorithm consuming TensorDict rollouts.

    Optimizes the clipped surrogate objective together with a value-function
    MSE loss and an entropy bonus, using a single Adam optimizer over all
    policy parameters.
    """

    def __init__(self, cfg: PPOCfg, policy):
        """Create the algorithm.

        Args:
            cfg: PPO hyper-parameters. Besides the fields declared on
                ``PPOCfg`` this reads ``device``, ``learning_rate``,
                ``gamma``, ``gae_lambda``, ``batch_size`` and
                ``max_grad_norm`` (presumably inherited from
                ``AlgorithmCfg`` — confirm).
            policy: Module whose parameters are optimized. It (or its
                ``.module`` attribute, when wrapped) must provide
                ``evaluate_actions``.
        """
        self.cfg = cfg
        self.policy = policy
        self.device = torch.device(cfg.device)
        self.optimizer = torch.optim.Adam(policy.parameters(), lr=cfg.learning_rate)
        # no per-rollout aggregation for dense logging

    @staticmethod
    def _advantage_stats(advantages: torch.Tensor) -> "tuple[torch.Tensor, torch.Tensor]":
        """Return ``(mean, std)`` used to normalize advantages.

        ``std`` is clamped to at least 1e-8. With fewer than two samples the
        unbiased ``std()`` is NaN (it divides by n - 1 = 0) and ``clamp_min``
        keeps NaN. In that case ``std`` is 1, so advantages are only
        mean-centred instead of turning the whole loss into NaN.
        """
        if advantages.numel() < 2:
            mean = advantages.mean() if advantages.numel() else advantages.new_zeros(())
            return mean, advantages.new_ones(())
        return advantages.mean(), advantages.std().clamp_min(1e-8)

    def update(self, rollout: TensorDict) -> Dict[str, float]:
        """Update the policy using a collected rollout.

        Runs ``cfg.n_epochs`` passes over the rollout in minibatches of
        ``cfg.batch_size``, taking one optimizer step per minibatch.

        Args:
            rollout: Collected transitions. It is cloned first, so the
                caller's TensorDict is not modified. Must contain
                ``sample_log_prob`` (log-probs under the behaviour policy)
                plus whatever ``compute_gae`` needs; ``compute_gae`` is
                expected to populate ``advantage`` and ``return`` (read below).

        Returns:
            Sample-weighted means over all minibatches of ``actor_loss``,
            ``value_loss`` and ``entropy``. All are 0.0 if no minibatch
            was produced.
        """
        rollout = rollout.clone()
        compute_gae(rollout, gamma=self.cfg.gamma, gae_lambda=self.cfg.gae_lambda)
        flat_rollout = transition_view(rollout, flatten=True)
        # Normalization statistics come from the whole rollout, not per minibatch.
        adv_mean, adv_std = self._advantage_stats(flat_rollout["advantage"])

        # Unwrap wrappers that expose the real model as `.module` (e.g. DDP).
        # NOTE(review): calling the inner module directly bypasses the wrapper's
        # forward(); for DistributedDataParallel that skips its gradient-sync
        # preparation — confirm this is intended.
        policy_module = getattr(self.policy, "module", self.policy)

        total_actor_loss = 0.0
        total_value_loss = 0.0
        total_entropy = 0.0
        total_steps = 0
        for _ in range(self.cfg.n_epochs):
            for batch in iterate_minibatches(
                flat_rollout, self.cfg.batch_size, self.device
            ):
                # Clone before evaluate_actions: it returns the same
                # "sample_log_prob" key and may write into `batch` in place.
                old_logprobs = batch["sample_log_prob"].clone()
                returns = batch["return"].clone()
                batch_advantages = ((batch["advantage"] - adv_mean) / adv_std).detach()

                eval_batch = policy_module.evaluate_actions(batch)
                logprobs = eval_batch["sample_log_prob"]
                entropy = eval_batch["entropy"]
                # Match the target's shape so mse_loss cannot silently broadcast,
                # e.g. (B, 1) against (B,) into a (B, B) loss. No-op if equal.
                values = eval_batch["value"].reshape_as(returns)

                ratio = (logprobs - old_logprobs).exp()
                surr1 = ratio * batch_advantages
                surr2 = (
                    torch.clamp(
                        ratio, 1.0 - self.cfg.clip_coef, 1.0 + self.cfg.clip_coef
                    )
                    * batch_advantages
                )
                # Pessimistic (elementwise minimum) clipped surrogate objective.
                actor_loss = -torch.min(surr1, surr2).mean()
                value_loss = torch.nn.functional.mse_loss(values, returns)
                entropy_loss = -entropy.mean()
                loss = (
                    actor_loss
                    + self.cfg.vf_coef * value_loss
                    + self.cfg.ent_coef * entropy_loss
                )

                self.optimizer.zero_grad(set_to_none=True)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(
                    self.policy.parameters(), self.cfg.max_grad_norm
                )
                self.optimizer.step()

                # Weight by minibatch size so a short final batch counts less.
                bs = batch.batch_size[0]
                total_actor_loss += actor_loss.item() * bs
                total_value_loss += value_loss.item() * bs
                total_entropy += (-entropy_loss.item()) * bs
                total_steps += bs

        denom = max(1, total_steps)
        return {
            "actor_loss": total_actor_loss / denom,
            "value_loss": total_value_loss / denom,
            "entropy": total_entropy / denom,
        }