# Source code for embodichain.agents.rl.utils.trainer

# ----------------------------------------------------------------------------
# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ----------------------------------------------------------------------------

from __future__ import annotations

from typing import Any, Dict
import time
import numpy as np
import torch
import wandb
from embodichain.utils import logger
from torch.utils.tensorboard import SummaryWriter
from collections import deque
from tensordict import TensorDict
from typing import Dict

from embodichain.lab.gym.envs.managers.action_manager import ActionManager
from embodichain.agents.rl.buffer import RolloutBuffer
from embodichain.agents.rl.collector import SyncCollector
from embodichain.lab.gym.envs.managers.event_manager import EventManager
from .helper import flatten_dict_observation


class Trainer:
    """Algorithm-agnostic trainer that coordinates training loop, logging, and evaluation.

    The trainer owns a ``RolloutBuffer`` and a ``SyncCollector``. Each training
    iteration collects one rollout, hands the buffer contents to
    ``algorithm.update`` and logs the returned losses. Evaluation and
    checkpointing are triggered from the same loop based on ``global_step``.
    """

    # Number of most recent finished episodes used for the running averages.
    # Shared by the local deques and the cross-rank synchronisation buffers.
    _EPISODE_WINDOW = 100

    def __init__(
        self,
        policy,
        env,
        algorithm,
        buffer_size: int,
        batch_size: int,
        writer: SummaryWriter | None,
        eval_freq: int,
        save_freq: int,
        checkpoint_dir: str,
        exp_name: str,
        use_wandb: bool = True,
        eval_env=None,
        event_cfg=None,
        eval_event_cfg=None,
        num_eval_episodes: int = 5,
        distributed: bool = False,
        rank: int = 0,
        world_size: int = 1,
    ):
        """Build the rollout buffer, the collector and the bookkeeping state.

        Args:
            policy: Policy to train; must expose ``obs_dim`` and ``action_dim``.
            env: Training environment; must expose ``num_envs``.
            algorithm: Learning algorithm; must expose ``device`` and ``update``.
            buffer_size: Number of steps collected per rollout.
            batch_size: Stored on the trainer; not read by this class itself.
            writer: TensorBoard writer, or ``None`` to disable TensorBoard logs.
            eval_freq: Evaluate when ``global_step`` is a multiple of this value
                (values ``<= 0`` disable evaluation).
            save_freq: Checkpoint when ``global_step`` is a multiple of this
                value (values ``<= 0`` disable periodic checkpoints).
            checkpoint_dir: Directory that receives checkpoint files.
            exp_name: Prefix used in checkpoint file names.
            use_wandb: Mirror logs to Weights & Biases when ``True``.
            eval_env: Optional environment used by evaluation.
            event_cfg: Optional event configuration bound to ``env``.
            eval_event_cfg: Optional event configuration bound to ``eval_env``.
            num_eval_episodes: Episodes run per evaluation.
            distributed: Whether step counts and episode stats are synchronised
                through ``torch.distributed``.
            rank: Rank of this process; only rank 0 prints, logs to wandb and
                saves checkpoints.
            world_size: Number of participating processes.

        Raises:
            RuntimeError: If ``env`` lacks ``num_envs`` or ``policy`` lacks
                ``obs_dim`` / ``action_dim``.
        """
        self.policy = policy
        self.distributed = distributed
        self.rank = rank
        self.world_size = world_size
        self.env = env
        self.eval_env = eval_env
        self.algorithm = algorithm
        self.buffer_size = buffer_size
        self.batch_size = batch_size
        self.writer = writer
        self.eval_freq = eval_freq
        self.save_freq = save_freq
        self.checkpoint_dir = checkpoint_dir
        self.exp_name = exp_name
        self.use_wandb = use_wandb
        self.num_eval_episodes = num_eval_episodes

        # Event managers exist only when a configuration was supplied, so the
        # rest of the class probes for them with ``hasattr``.
        if event_cfg is not None:
            self.event_manager = EventManager(event_cfg, env=self.env)
        if eval_event_cfg is not None:
            self.eval_event_manager = EventManager(eval_event_cfg, env=self.eval_env)

        # Get device from algorithm
        self.device = self.algorithm.device
        self.global_step = 0
        self.start_time = time.time()
        self.ret_window = deque(maxlen=self._EPISODE_WINDOW)
        self.len_window = deque(maxlen=self._EPISODE_WINDOW)
        self.train_history: list[dict[str, float]] = []
        self.eval_history: list[dict[str, float]] = []
        self.last_eval_metrics: dict[str, float] = {}
        self.last_train_metrics: dict[str, float] = {}
        self.latest_checkpoint_path: str | None = None

        num_envs = getattr(self.env, "num_envs", None)
        if num_envs is None:
            raise RuntimeError("Env must expose num_envs for trainer statistics.")

        obs_dim = getattr(self.policy, "obs_dim", None)
        action_dim = getattr(self.policy, "action_dim", None)
        if obs_dim is None or action_dim is None:
            raise RuntimeError("Policy must expose obs_dim and action_dim.")

        self.buffer = RolloutBuffer(
            num_envs=num_envs,
            rollout_len=self.buffer_size,
            obs_dim=obs_dim,
            action_dim=action_dim,
            device=self.device,
        )
        self.collector = SyncCollector(
            env=self.env,
            policy=self.policy,
            device=self.device,
            reset_every_rollout=bool(
                getattr(
                    getattr(self.algorithm, "cfg", None), "reset_every_rollout", False
                )
            ),
        )

        # episode stats tracked on device to avoid repeated CPU round-trips
        self.curr_ret = torch.zeros(num_envs, dtype=torch.float32, device=self.device)
        self.curr_len = torch.zeros(num_envs, dtype=torch.int32, device=self.device)

    # ---- lightweight helpers for dense logging ----
    @staticmethod
    def _mean_scalar(x) -> float:
        """Return the mean of ``x`` (tensor or array-like) as a Python float."""
        if hasattr(x, "detach"):
            x = x.detach().cpu().numpy()
        else:
            x = np.asarray(x)
        return float(np.mean(x))

    @staticmethod
    def _nanmean(values) -> float:
        """Return the mean of ``values``, or NaN when ``values`` is empty."""
        return float(np.mean(values)) if values else float("nan")

    def _write_scalars(self, scalars: dict) -> None:
        """Write already-reduced ``tag -> float`` pairs to the TensorBoard writer."""
        if not self.writer:
            return
        for tag, value in scalars.items():
            try:
                self.writer.add_scalar(tag, value, self.global_step)
            except Exception:
                # Dense logging is best-effort: one bad entry must not abort
                # rollout collection.
                continue

    def _log_scalar_dict(self, prefix: str, data: dict):
        """Log the mean of every entry of ``data`` under ``prefix/<key>``.

        Does nothing when no writer is configured or ``data`` is not a dict.
        Entries that cannot be reduced to a scalar are skipped.
        """
        if not self.writer or not isinstance(data, dict):
            return
        self._write_scalars(self._pack_log_dict(prefix, data))

    def _pack_log_dict(self, prefix: str, data: dict) -> dict:
        """Return ``{"prefix/key": mean(value)}`` for every reducible entry.

        Returns an empty dict when ``data`` is not a dict. Entries whose mean
        cannot be computed are silently dropped.
        """
        if not isinstance(data, dict):
            return {}
        out = {}
        for k, v in data.items():
            try:
                out[f"{prefix}/{k}"] = self._mean_scalar(v)
            except Exception:
                # Best-effort: skip values that are not numeric.
                continue
        return out

    def _is_due(self, freq: int) -> bool:
        """Return True when ``global_step`` is an exact multiple of ``freq``.

        A missing or non-positive ``freq`` disables the trigger instead of
        raising ``ZeroDivisionError``.
        """
        return bool(freq) and freq > 0 and self.global_step % freq == 0

    def train(self, total_timesteps: int) -> Dict[str, Any]:
        """Run the training loop until ``global_step`` reaches ``total_timesteps``.

        Args:
            total_timesteps: Step budget; the loop stops once ``global_step``
                is no longer below it.

        Returns:
            The dictionary produced by :meth:`get_summary`.
        """
        if self.rank == 0:
            print(f"Start training, total steps: {total_timesteps}")
        while self.global_step < total_timesteps:
            self._collect_rollout()
            losses = self.algorithm.update(self.buffer.get(flatten=False))
            self._log_train(losses)
            # NOTE: both triggers are checked once per rollout, so they fire
            # only when the step counter lands exactly on a multiple of the
            # configured frequency.
            if self.eval_env is not None and self._is_due(self.eval_freq):
                self._eval_once(num_episodes=self.num_eval_episodes)
            if self._is_due(self.save_freq):
                self.save_checkpoint()
        return self.get_summary()

    @torch.no_grad()
    def _collect_rollout(self):
        """Collect a rollout with the synchronous collector.

        Raises:
            RuntimeError: In distributed mode, when ``torch.distributed`` is
                unavailable or its process group is not initialized.
        """

        # Callback function for statistics and logging
        def on_step(tensordict: TensorDict, info: dict):
            """Callback called at each step during rollout collection."""
            reward = tensordict["reward"]
            done = tensordict["done"]

            # Episode stats
            self.curr_ret += reward
            self.curr_len += 1
            done_idx = torch.nonzero(done, as_tuple=False).squeeze(-1)
            if done_idx.numel() > 0:
                finished_ret = self.curr_ret[done_idx].detach().cpu().tolist()
                finished_len = self.curr_len[done_idx].detach().cpu().tolist()
                self.ret_window.extend(finished_ret)
                self.len_window.extend(finished_len)
                self.curr_ret[done_idx] = 0
                self.curr_len[done_idx] = 0

            # In distributed mode the step counter is advanced once per
            # rollout, after the all-reduce below.
            if not self.distributed:
                self.global_step += tensordict.batch_size[0]

            if self.rank == 0 and isinstance(info, dict):
                # Reduce every entry once and reuse the result for both
                # TensorBoard and wandb.
                log_dict = self._pack_log_dict("rewards", info.get("rewards"))
                log_dict.update(self._pack_log_dict("metrics", info.get("metrics")))
                self._write_scalars(log_dict)
                if log_dict and self.use_wandb:
                    wandb.log(log_dict, step=self.global_step)

        rollout = self.buffer.start_rollout()
        rollout = self.collector.collect(
            num_steps=self.buffer_size,
            rollout=rollout,
            on_step_callback=on_step,
        )
        self.buffer.add(rollout)

        # Sync global_step and episode stats across ranks in distributed mode
        if self.distributed:
            if not torch.distributed.is_available():
                raise RuntimeError(
                    "Distributed training was requested (distributed=True), "
                    "but torch.distributed is not available. "
                    "Please ensure PyTorch was built with distributed support or "
                    "set distributed=False."
                )
            if not torch.distributed.is_initialized():
                raise RuntimeError(
                    "Distributed training was requested (distributed=True), "
                    "but the torch.distributed process group is not initialized. "
                    "Call torch.distributed.init_process_group(...) before creating "
                    "or using Trainer(distributed=True, ...)."
                )
            local_delta = self.env.num_envs * self.buffer_size
            delta_tensor = torch.tensor(
                [local_delta], dtype=torch.int64, device=self.device
            )
            torch.distributed.all_reduce(
                delta_tensor, op=torch.distributed.ReduceOp.SUM
            )
            self.global_step += int(delta_tensor.item())
            self._sync_episode_stats()

    def _sync_episode_stats(self) -> None:
        """Sync ret_window and len_window across ranks; rank 0 gets merged stats.

        Every rank contributes a fixed-size, zero-padded buffer plus the count
        of valid entries. Rank 0 concatenates the valid parts in rank order and
        keeps the trailing ``_EPISODE_WINDOW`` values; the windows of the other
        ranks are left untouched.
        """
        if not self.distributed or not torch.distributed.is_initialized():
            return

        maxlen = self._EPISODE_WINDOW
        ret_list = list(self.ret_window)
        len_list = list(self.len_window)
        n = min(len(ret_list), maxlen)
        if n > 0:
            ret_list = ret_list[-n:]
            len_list = len_list[-n:]

        # all_gather needs identically shaped tensors on every rank, hence the
        # padding to ``maxlen`` and the separate count tensor.
        ret_tensor = torch.zeros(maxlen, dtype=torch.float32, device=self.device)
        len_tensor = torch.zeros(maxlen, dtype=torch.float32, device=self.device)
        if n > 0:
            ret_tensor[:n] = torch.tensor(
                ret_list, dtype=torch.float32, device=self.device
            )
            len_tensor[:n] = torch.tensor(
                len_list, dtype=torch.float32, device=self.device
            )
        count_tensor = torch.tensor([n], dtype=torch.int64, device=self.device)

        ret_list_all = [torch.zeros_like(ret_tensor) for _ in range(self.world_size)]
        len_list_all = [torch.zeros_like(len_tensor) for _ in range(self.world_size)]
        count_list_all = [
            torch.zeros_like(count_tensor) for _ in range(self.world_size)
        ]
        torch.distributed.all_gather(ret_list_all, ret_tensor)
        torch.distributed.all_gather(len_list_all, len_tensor)
        torch.distributed.all_gather(count_list_all, count_tensor)

        if self.rank == 0:
            all_ret = []
            all_len = []
            for r in range(self.world_size):
                c = int(count_list_all[r].item())
                if c > 0:
                    all_ret.extend(ret_list_all[r][:c].cpu().tolist())
                    all_len.extend(len_list_all[r][:c].cpu().tolist())
            self.ret_window.clear()
            self.len_window.clear()
            n_total = len(all_ret)
            start = max(0, n_total - maxlen)
            self.ret_window.extend(all_ret[start:])
            self.len_window.extend(all_len[start:])

    def _log_train(self, losses: Dict[str, float]):
        """Record and emit training metrics for the update that just finished.

        Appends an entry to ``train_history``, writes to the TensorBoard writer
        when one is configured, and on rank 0 prints a console line and mirrors
        the values to wandb.

        Args:
            losses: Mapping of loss name to value returned by the algorithm.
        """
        elapsed = max(1e-6, time.time() - self.start_time)
        sps = self.global_step / elapsed
        has_returns = len(self.ret_window) > 0
        has_lengths = len(self.len_window) > 0
        avgR = float(np.mean(self.ret_window)) if has_returns else float("nan")
        avgL = float(np.mean(self.len_window)) if has_lengths else float("nan")

        history_entry = {
            "global_step": float(self.global_step),
            "charts/SPS": float(sps),
            "charts/episode_reward_avg_100": avgR,
            "charts/episode_length_avg_100": avgL,
        }
        history_entry.update({f"train/{k}": float(v) for k, v in losses.items()})
        self.train_history.append(history_entry)
        self.last_train_metrics = history_entry

        if self.writer:
            for k, v in losses.items():
                self.writer.add_scalar(f"train/{k}", v, self.global_step)
            self.writer.add_scalar("charts/SPS", sps, self.global_step)
            if has_returns:
                self.writer.add_scalar(
                    "charts/episode_reward_avg_100", avgR, self.global_step
                )
            if has_lengths:
                self.writer.add_scalar(
                    "charts/episode_length_avg_100", avgL, self.global_step
                )

        # console and external logging are rank-0 only in distributed mode.
        if self.rank != 0:
            return
        print(
            f"[train] step={self.global_step} sps={sps:.0f} avgReward(100)={avgR:.3f} avgLength(100)={avgL:.1f}"
        )

        # wandb (mirror TB logs)
        if self.use_wandb:
            log_dict = {f"train/{k}": v for k, v in losses.items()}
            log_dict["charts/SPS"] = sps
            if not np.isnan(avgR):
                log_dict["charts/episode_reward_avg_100"] = avgR
            if not np.isnan(avgL):
                log_dict["charts/episode_length_avg_100"] = avgL
            wandb.log(log_dict, step=self.global_step)

    @torch.no_grad()
    def _eval_once(self, num_episodes: int = 5) -> Dict[str, float]:
        """Run evaluation for specified number of episodes.

        Each episode runs all parallel environments until completion, allowing
        environments to finish at different times. The policy is switched to
        eval mode for the duration of the call and put back into training mode
        afterwards if it was training before.

        Args:
            num_episodes: Number of episodes to evaluate

        Returns:
            Summary with ``global_step``, ``eval/avg_reward``,
            ``eval/avg_length``, ``eval/success_rate`` (NaN when no data was
            gathered) and one ``eval/metrics/<key>`` entry per reported metric.
        """
        # Remember the mode so evaluation does not leave the policy in eval
        # mode for the following rollouts and updates.
        was_training = bool(getattr(self.policy, "training", False))
        self.policy.eval()

        episode_returns = []
        episode_lengths = []
        episode_successes = []
        metric_values: dict[str, list[float]] = {}

        # The action manager does not change during evaluation; look it up once.
        am: ActionManager = getattr(self.eval_env, "action_manager", None)
        has_eval_events = hasattr(self, "eval_event_manager")

        try:
            # Evaluation does not consume the training rollout buffer; binding it here can
            # overflow the shared RL buffer when eval episodes are longer than buffer_size.
            for _ in range(num_episodes):
                # Reset and initialize episode tracking
                obs, _ = self.eval_env.reset()
                obs = flatten_dict_observation(obs)
                num_envs = obs.shape[0] if obs.ndim == 2 else 1

                done_mask = torch.zeros(num_envs, dtype=torch.bool, device=self.device)
                cumulative_reward = torch.zeros(
                    num_envs, dtype=torch.float32, device=self.device
                )
                step_count = torch.zeros(
                    num_envs, dtype=torch.int32, device=self.device
                )

                # Run episode until all environments complete
                while not done_mask.all():
                    # Get deterministic actions from policy
                    action_td = TensorDict(
                        {"obs": obs},
                        batch_size=[num_envs],
                        device=self.device,
                    )
                    action_td = self.policy.get_action(action_td, deterministic=True)
                    actions = action_td["action"]
                    if am is None:
                        action_in = actions
                    else:
                        action_in = am.convert_policy_action_to_env_action(actions)

                    # Environment step
                    obs, reward, terminated, truncated, info = self.eval_env.step(
                        action_in
                    )
                    obs = (
                        flatten_dict_observation(obs)
                        if isinstance(obs, TensorDict)
                        else obs
                    )

                    # Update statistics only for still-running environments
                    done = terminated | truncated
                    still_running = ~done_mask
                    cumulative_reward[still_running] += reward[still_running].float()
                    step_count[still_running] += 1

                    # Per-env extras are sampled only on the step an env
                    # finishes for the first time in this episode.
                    newly_done = done & still_running
                    if newly_done.any() and isinstance(info, dict):
                        if "success" in info:
                            successes = (
                                info["success"][newly_done].detach().cpu().tolist()
                            )
                            episode_successes.extend(float(v) for v in successes)
                        if "metrics" in info:
                            for key, value in info["metrics"].items():
                                values = value[newly_done].detach().cpu().tolist()
                                metric_values.setdefault(key, []).extend(
                                    float(v) for v in values
                                )
                    done_mask |= done

                    # Trigger evaluation events (e.g., video recording)
                    if (
                        has_eval_events
                        and "interval" in self.eval_event_manager.available_modes
                    ):
                        self.eval_event_manager.apply(mode="interval")

                # Collect episode statistics
                episode_returns.extend(cumulative_reward.cpu().tolist())
                episode_lengths.extend(step_count.cpu().tolist())

            self._finalize_eval_events()
        finally:
            if was_training:
                self.policy.train()

        avg_reward = self._nanmean(episode_returns)
        avg_length = self._nanmean(episode_lengths)
        success_rate = self._nanmean(episode_successes)

        # Log evaluation metrics
        if self.writer and episode_returns:
            self.writer.add_scalar("eval/avg_reward", avg_reward, self.global_step)
            self.writer.add_scalar("eval/avg_length", avg_length, self.global_step)
            if episode_successes:
                self.writer.add_scalar(
                    "eval/success_rate", success_rate, self.global_step
                )

        summary = {
            "global_step": float(self.global_step),
            "eval/avg_reward": avg_reward,
            "eval/avg_length": avg_length,
            "eval/success_rate": success_rate,
        }
        for key, values in metric_values.items():
            if values:
                summary[f"eval/metrics/{key}"] = float(np.mean(values))
        self.eval_history.append(summary)
        self.last_eval_metrics = summary

        if self.rank == 0 and self.use_wandb:
            log_dict = {
                key: value
                for key, value in summary.items()
                if key != "global_step" and not np.isnan(value)
            }
            if log_dict:
                wandb.log(log_dict, step=self.global_step)

        return summary

    def _finalize_eval_events(self) -> None:
        """Flush and finalize "interval" evaluation functors (e.g., video recording)."""
        if not hasattr(self, "eval_event_manager"):
            return
        for functor_cfg in self.eval_event_manager._mode_functor_cfgs.get(
            "interval", []
        ):
            functor = functor_cfg.func
            save_path = functor_cfg.params.get("save_path", "./outputs/videos/eval")
            if hasattr(functor, "flush"):
                functor.flush(save_path)
            if hasattr(functor, "finalize"):
                functor.finalize(save_path)

    def save_checkpoint(self) -> str | None:
        """Save a model-only checkpoint from rank 0.

        The checkpoint directory is created if it does not exist yet. A
        wrapped policy (one exposing ``module``) is unwrapped before its state
        dict is taken.

        Returns:
            Path of the written file, or ``None`` on non-zero ranks.
        """
        # minimal model-only checkpoint; trainer/algorithm states can be added
        if self.rank != 0:
            return None

        import os  # function-scope: the module does not import os at file level

        if self.checkpoint_dir:
            os.makedirs(self.checkpoint_dir, exist_ok=True)

        path = f"{self.checkpoint_dir}/{self.exp_name}_step_{self.global_step}.pt"
        policy_state = (
            self.policy.module.state_dict()
            if hasattr(self.policy, "module")
            else self.policy.state_dict()
        )
        torch.save(
            {
                "global_step": self.global_step,
                "policy": policy_state,
            },
            path,
        )
        self.latest_checkpoint_path = path
        print(f"Checkpoint saved: {path}")
        return path

    def get_summary(self) -> Dict[str, Any]:
        """Return a snapshot of training progress.

        Returns:
            Dictionary with the current step, elapsed wall-clock time, average
            steps per second, copies of the latest train/eval metrics and
            histories, and the path of the most recent checkpoint (or ``None``).
        """
        elapsed = max(1e-6, time.time() - self.start_time)
        return {
            "global_step": int(self.global_step),
            "elapsed_time_sec": float(elapsed),
            "training_fps": float(self.global_step / elapsed),
            "last_train_metrics": dict(self.last_train_metrics),
            "last_eval_metrics": dict(self.last_eval_metrics),
            "train_history": list(self.train_history),
            "eval_history": list(self.eval_history),
            "latest_checkpoint_path": self.latest_checkpoint_path,
        }