# ----------------------------------------------------------------------------
# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ----------------------------------------------------------------------------
from __future__ import annotations
import time
import torch
import multiprocessing as mp
from multiprocessing.sharedctypes import Synchronized, SynchronizedArray
from multiprocessing.synchronize import Event as MpEvent
from tensordict import TensorDict
from tqdm import tqdm
from embodichain.utils.logger import log_info, log_error
from embodichain.utils import configclass
# [docs]  (stray Sphinx cross-reference artifact from documentation scrape)
@configclass
class OnlineDataEngineCfg:
    """Configuration for the online data engine and its shared rollout buffer."""

    buffer_size: int = 16
    """Number of episodes (environment trajectories) that can be stored in the shared buffer at once.
    Must be ≥ num_envs and ideally a multiple of num_envs."""

    max_episode_steps: int = 300
    """Maximum number of timesteps per episode. Must be ≥ chunk_size used by OnlineDataset."""

    # TODO: This param maybe changed to more general format.
    state_dim: int = 14
    """Dimensionality of the state space."""

    buffer_device: str = "cpu"
    """Device on which the shared buffer is allocated."""

    # TODO: We may support multiple envs in the future.
    gym_config: dict = dict()
    """Gym environment configuration dictionary (already loaded, not a file path).
    The contents depend on the specific environment being used. Defaults to an empty dict."""

    action_config: dict = dict()
    """Action configuration dictionary. The contents depend on the specific environment and robot being used."""

    refill_threshold: int = 50
    """Total number of samples (refill_threshold * buffer_size) drawn from the shared buffer before a refill is triggered.
    Accumulates across all calls to :meth:`OnlineDataEngine.sample_batch`. When this threshold
    is exceeded the engine signals the simulation subprocess to regenerate the entire buffer,
    amortising the cost of environment simulation over many training steps.
    """
# ---------------------------------------------------------------------------
# Subprocess entry point (module-level so it can be pickled by multiprocessing)
# ---------------------------------------------------------------------------
def _sim_worker_fn(
    cfg: OnlineDataEngineCfg,
    shared_buffer: TensorDict,
    lock_index: SynchronizedArray,
    fill_signal: MpEvent,
    init_signal: MpEvent,
    close_signal: MpEvent,
) -> None:
    """Simulation subprocess entry point.

    Builds the gym environment, then waits on *fill_signal*. Each time the
    signal is raised the subprocess runs enough rollouts to overwrite every
    slot in *shared_buffer* with fresh demonstration data, and advances
    *lock_index* so the main process can avoid sampling from the slot
    currently being written. After the **first** fill completes *init_signal*
    is set exactly once so the main process knows the buffer contains valid
    data.

    Args:
        cfg: Engine configuration (picklable dataclass).
        shared_buffer: Shared-memory TensorDict of shape
            ``[buffer_size, max_episode_steps, ...]``.
        lock_index: Two-element shared integer array ``[write_start, write_end)``
            indicating which buffer rows are currently being overwritten.
            ``[-1, -1]`` means no row is locked.
        fill_signal: Event set by the main process to request a refill.
        init_signal: Event set by this worker after the first fill completes.
            Remains set permanently thereafter.
        close_signal: Event set by the main process to request a graceful shutdown.
    """
    # Heavy imports are done inside the subprocess so the parent process does
    # not need to import the simulator stack.
    import gymnasium as gym
    from embodichain.lab.gym.utils.gym_utils import (
        config_to_cfg,
        DEFAULT_MANAGER_MODULES,
    )
    from embodichain.lab.sim import SimulationManagerCfg
    from embodichain.utils.logger import log_info, log_warning, log_error

    gym_config: dict = cfg.gym_config
    action_config: dict = cfg.action_config

    # Build env config from the gym configuration dictionary.
    env_cfg = config_to_cfg(gym_config, manager_modules=DEFAULT_MANAGER_MODULES)
    env_cfg.filter_dataset_saving = True
    env_cfg.init_rollout_buffer = False
    env_cfg.sim_cfg = SimulationManagerCfg(
        headless=gym_config.get("headless", True),
        sim_device=gym_config.get("device", "cpu"),
        enable_rt=gym_config.get("enable_rt", True),
        gpu_id=gym_config.get("gpu_id", 0),
    )

    num_envs: int = env_cfg.num_envs
    buffer_size: int = shared_buffer.batch_size[0]
    if buffer_size % num_envs != 0:
        log_warning(
            f"[Simulation Process] buffer_size ({buffer_size}) is not evenly divisible by "
            f"num_envs ({num_envs}). This may lead to inefficient buffer usage and should ideally be fixed by adjusting "
            "the OnlineDataEngineCfg.",
        )
    # Ceil division: one extra rollout covers the final (partial) slice when
    # buffer_size is not a multiple of num_envs.
    num_rollouts_per_fill: int = -(-buffer_size // num_envs)

    # Maximum number of *consecutive* rollouts with an empty demo action list
    # tolerated before the worker gives up. Without this bound, an env that can
    # never produce a demo would make the inner loop spin forever (the original
    # `continue` never advanced rollout_idx).
    max_empty_retries: int = 5

    # --- Build the environment and attach the initial tmp_buffer slice ------
    env = gym.make(id=gym_config["id"], cfg=env_cfg, **action_config)
    log_info("[Simulation Process] Environment created.", color="cyan")

    # --- Main loop: wait for fill signal, then fill the entire buffer -------
    try:
        while True:
            fill_signal.wait()
            fill_signal.clear()
            if close_signal.is_set():
                log_info(
                    "[Simulation Process] Close signal received. Shutting down.",
                    color="cyan",
                )
                break
            log_info(
                "[Simulation Process] Fill signal received. Starting full buffer fill.",
                color="cyan",
            )
            # Reset write cursor to the beginning of the buffer.
            lock_index[0] = 0
            lock_index[1] = num_envs
            rollout_idx = 0
            empty_retries = 0
            while rollout_idx < num_rollouts_per_fill:
                if close_signal.is_set():
                    return
                # Attach the current write window as the env's rollout buffer so
                # env.step() writes directly into shared memory.
                tmp_buffer = shared_buffer[lock_index[0] : lock_index[1], :]
                env.get_wrapper_attr("set_rollout_buffer")(tmp_buffer)
                _, _ = env.reset()
                action_list = env.get_wrapper_attr("create_demo_action_list")()
                if action_list is None or len(action_list) == 0:
                    empty_retries += 1
                    log_warning(
                        f"[Simulation Process] Rollout {rollout_idx + 1}/{num_rollouts_per_fill}: "
                        "action list is empty, skipping episode."
                    )
                    if empty_retries >= max_empty_retries:
                        # Bail out instead of retrying the same slot forever.
                        log_error(
                            f"[Simulation Process] {empty_retries} consecutive empty "
                            "action lists; aborting buffer fill."
                        )
                        return
                    continue
                empty_retries = 0
                for action in tqdm(
                    action_list,
                    desc=f"[Sim] rollout {rollout_idx + 1}/{num_rollouts_per_fill}",
                    unit="step",
                    leave=False,
                ):
                    if close_signal.is_set():
                        return
                    env.step(action)
                rollout_idx += 1
                log_info(
                    f"[Simulation Process] Rollout {rollout_idx}/{num_rollouts_per_fill} done. "
                    f"lock_index=[{lock_index[0]}, {lock_index[1]}], ",
                    color="cyan",
                )
                # Advance lock_index to the next write slice, wrapping or
                # clamping at the end of the buffer.
                next_start = lock_index[0] + num_envs
                next_end = lock_index[1] + num_envs
                if next_start >= buffer_size:
                    # Wrap around to the start of the buffer.
                    next_start = 0
                    next_end = num_envs
                elif next_end > buffer_size:
                    # Final partial slice: back off so the window stays num_envs wide.
                    next_end = buffer_size
                    next_start = buffer_size - num_envs
                lock_index[0] = next_start
                lock_index[1] = next_end
            # Signal that the buffer contains valid data for the first time.
            # is_set() is checked so subsequent refills do not redundantly set it.
            if not init_signal.is_set():
                init_signal.set()
                log_info(
                    "[Simulation Process] Initial buffer fill complete. Engine is ready.",
                    color="cyan",
                )
            # The entire buffer now holds fresh data; [-1, -1] unlocks every
            # row for sampling.
            lock_index[0] = -1
            lock_index[1] = -1
    except KeyboardInterrupt:
        log_warning("[Simulation Process] Stopping (KeyboardInterrupt).")
    except Exception as e:
        log_error(f"[Simulation Process] Unhandled error: {e}")
    finally:
        env.close()
# ---------------------------------------------------------------------------
# OnlineDataEngine
# ---------------------------------------------------------------------------
# [docs]  (stray Sphinx cross-reference artifact from documentation scrape)
class OnlineDataEngine:
    """Engine for managing Online Data Streaming (ODS) and environment rollouts.

    Creates a shared rollout buffer in CPU shared memory, spawns a dedicated
    simulation subprocess that fills the buffer with demonstration trajectories,
    and exposes a :meth:`sample_batch` method for the training process to draw
    batches of trajectory chunks.

    **Subprocess lifecycle**

    The simulation subprocess is started in :meth:`start` and immediately
    receives a fill signal so the buffer is populated before the first call to
    :meth:`sample_batch`. The subprocess loops indefinitely: it waits for
    *fill_signal*, runs ``buffer_size // num_envs`` rollouts to overwrite every
    buffer slot, then goes back to waiting.

    **Concurrency and lock protection**

    :attr:`_lock_index` ``[write_start, write_end)`` is updated by the
    subprocess after each rollout so that :meth:`sample_batch` can skip the
    slot currently being written to, preventing partial reads.

    **Refill criterion**

    :meth:`sample_batch` accumulates the total number of individual trajectory
    samples drawn into :attr:`_sample_count`. When this counter exceeds
    :attr:`~OnlineDataEngineCfg.refill_threshold` the fill signal is raised
    and the counter resets to zero. This amortises the cost of GPU-accelerated
    simulation across many training iterations.

    **Initialisation barrier**

    The :attr:`is_init` property returns ``False`` until the subprocess
    completes the very first full buffer fill, after which it becomes
    permanently ``True``. Training code should wait on this flag before
    calling :meth:`sample_batch` to avoid drawing all-zero data.

    Args:
        cfg: Engine configuration.

    Attributes:
        shared_buffer: Shared-memory TensorDict of shape
            ``[buffer_size, max_episode_steps, ...]``.
        buffer_size: Total number of trajectory slots in the shared buffer.
        device: Device of the shared buffer.
        is_init: ``True`` once the buffer has been populated at least once.
    """

    def __init__(self, cfg: OnlineDataEngineCfg) -> None:
        self.cfg = cfg
        # Allocate the shared buffer (shape: [buffer_size, max_episode_steps, ...]).
        self.shared_buffer: TensorDict = self._create_buffer()
        self.buffer_size: int = self.shared_buffer.batch_size[0]
        self.device = self.shared_buffer.device
        num_envs: int = cfg.gym_config.get("num_envs", 1)
        if num_envs > self.buffer_size:
            log_error(
                f"num_envs ({num_envs}) exceeds buffer_size ({self.buffer_size}). "
                "Increase buffer_size in OnlineDataEngineCfg.",
                error_type=ValueError,
            )
        # -------------------------------------------------------------------
        # Shared interprocess state
        # -------------------------------------------------------------------
        # Use a forkserver context to avoid forking unsafe runtime state.
        self._mp_ctx = mp.get_context("forkserver")
        # Current write window: subprocess updates these after each rollout.
        # Shape: [write_start, write_end) (exclusive upper bound).
        self._lock_index: SynchronizedArray = self._mp_ctx.Array("i", [0, num_envs])
        # Raised by the main process to request a full buffer refill.
        self._fill_signal: MpEvent = self._mp_ctx.Event()
        # Set by the subprocess once the first complete buffer fill finishes.
        # Used by the :attr:`is_init` property to let callers wait for readiness.
        self._init_signal: MpEvent = self._mp_ctx.Event()
        # Set by the main process to request the simulation subprocess to stop.
        self._close_signal: MpEvent = self._mp_ctx.Event()
        # Accumulated sample count used by the refill criterion.
        self._sample_count: Synchronized = self._mp_ctx.Value("i", 0)
        # Handle to the simulation subprocess, set in start() and used in stop().
        self._sim_process: mp.Process | None = None

    def start(self) -> None:
        """Spawn the simulation subprocess and block until the buffer is ready.

        Raises:
            RuntimeError: If the subprocess exits before completing the
                initial buffer fill.
        """
        self._sim_process: mp.Process = self._mp_ctx.Process(
            target=_sim_worker_fn,
            args=(
                self.cfg,
                self.shared_buffer,
                self._lock_index,
                self._fill_signal,
                self._init_signal,
                self._close_signal,
            ),
            daemon=True,
        )
        self._sim_process.start()
        log_info(
            f"[OnlineDataEngine] Simulation subprocess started (PID={self._sim_process.pid}).",
            color="green",
        )
        # Trigger the initial fill so data is ready before the first sample.
        self._fill_signal.set()
        # Wait for readiness, but fail fast if the worker died: otherwise this
        # loop would spin forever on a subprocess that can never set the flag.
        while not self.is_init:
            if not self._sim_process.is_alive():
                log_error(
                    "[OnlineDataEngine] Simulation subprocess exited before the "
                    "initial buffer fill completed.",
                    error_type=RuntimeError,
                )
            time.sleep(0.5)

    # -----------------------------------------------------------------------
    # Buffer initialisation
    # -----------------------------------------------------------------------
    def _create_buffer(self) -> TensorDict:
        """Allocate the shared rollout buffer.

        The buffer has shape ``[buffer_size, max_episode_steps, ...]`` and is
        placed in CPU shared memory so it can be safely accessed from both the
        main process and the simulation subprocess.

        Returns:
            TensorDict in shared memory.
        """
        from embodichain.lab.gym.utils.gym_utils import init_rollout_buffer_from_config

        gym_config: dict = self.cfg.gym_config
        # gym_config may override the configured episode length.
        max_episode_steps: int = gym_config.get(
            "max_episode_steps", self.cfg.max_episode_steps
        )
        shared_td = init_rollout_buffer_from_config(
            gym_config,
            device=self.cfg.buffer_device,
            batch_size=self.cfg.buffer_size,
            max_episode_steps=max_episode_steps,
            state_dim=self.cfg.state_dim,
        )
        # Only CPU tensors can be moved into POSIX shared memory; CUDA tensors
        # are already shareable via their own IPC mechanism.
        if shared_td.device.type == "cpu":
            shared_td.share_memory_()
        return shared_td

    # -----------------------------------------------------------------------
    # Status
    # -----------------------------------------------------------------------
    @property
    def is_init(self) -> bool:
        """Whether the shared buffer has been fully populated at least once.

        Returns ``True`` after the simulation subprocess completes its first
        full buffer fill, ``False`` while that initial fill is still in
        progress. Callers that must not sample stale (all-zero) data can
        poll or block on this property before entering their training loop::

            while not engine.is_init:
                time.sleep(0.5)

        Returns:
            ``True`` once the buffer contains valid trajectory data.
        """
        return self._init_signal.is_set()

    # -----------------------------------------------------------------------
    # Sampling
    # -----------------------------------------------------------------------
    def sample_batch(self, batch_size: int, chunk_size: int) -> TensorDict:
        """Sample a batch of trajectory chunks from the shared rollout buffer.

        Randomly draws *batch_size* environment trajectories from the portion
        of the buffer that has been written at least once, skipping any rows
        currently being overwritten by the simulation subprocess. For each
        selected trajectory a contiguous window of *chunk_size* timesteps is
        chosen at a uniformly random offset.

        After sampling the internal :attr:`_sample_count` is incremented by
        *batch_size*; if the count exceeds
        :attr:`~OnlineDataEngineCfg.refill_threshold` a buffer refill is
        triggered automatically.

        Args:
            batch_size: Number of trajectory chunks to include in the batch.
            chunk_size: Number of consecutive timesteps in each chunk.

        Returns:
            TensorDict with batch size ``[batch_size, chunk_size]``.

        Raises:
            ValueError: If ``chunk_size`` exceeds ``max_episode_steps``.
        """
        max_steps: int = self.shared_buffer.batch_size[1]
        if chunk_size > max_steps:
            log_error(
                f"chunk_size ({chunk_size}) exceeds max_episode_steps ({max_steps}).",
                error_type=ValueError,
            )
        # Snapshot both ends of the write window under the array's lock so we
        # never see a half-updated [start, end) pair from the subprocess.
        with self._lock_index.get_lock():
            lock_start: int = self._lock_index[0]
            lock_end: int = self._lock_index[1]
        # Build the set of rows that are safe to sample from: all valid rows
        # minus the slice currently being written by the subprocess.
        # (lock_start == lock_end == -1 means nothing is locked.)
        all_valid = torch.arange(self.buffer_size)
        is_locked = (all_valid >= lock_start) & (all_valid < lock_end)
        available = all_valid[~is_locked]
        if len(available) == 0:
            # Edge case: the entire valid region is locked. Sampling a batch
            # is not possible in this state and will result in a hard failure.
            log_error(
                "[OnlineDataEngine] All valid buffer rows are currently locked. "
                "Cannot sample a batch at this time; sampling fails because no "
                "unlocked rows are available.",
                error_type=RuntimeError,
            )
        # Sample row indices and chunk start offsets.
        row_sample_idx = torch.randint(0, len(available), (batch_size,))
        row_indices = available[row_sample_idx]
        max_start = max_steps - chunk_size
        start_indices = torch.randint(0, max_start + 1, (batch_size,))
        time_offsets = torch.arange(chunk_size)
        # [batch_size, chunk_size] index grids for advanced indexing.
        time_indices = start_indices[:, None] + time_offsets[None, :]
        result = self.shared_buffer[row_indices[:, None], time_indices]
        # Update sample count and conditionally trigger a refill.
        self._trigger_refill_if_needed(batch_size)
        return result

    # -----------------------------------------------------------------------
    # Refill criterion
    # -----------------------------------------------------------------------
    def _trigger_refill_if_needed(self, count: int = 1) -> None:
        """Accumulate sample count and trigger a buffer refill when the threshold is reached.

        This method is called by :meth:`sample_batch` after every batch. The
        refill is only requested when the fill signal is not already pending
        (i.e. the subprocess has finished the previous refill).

        Args:
            count: Number of individual trajectory samples drawn in the latest
                call to :meth:`sample_batch` (typically equal to *batch_size*).
        """
        # Decide and reset under the lock; raise the signal and log outside it
        # so the lock is held only for the counter update.
        with self._sample_count.get_lock():
            self._sample_count.value += count
            should_refill = (
                self._sample_count.value >= self.cfg.refill_threshold * self.buffer_size
                and not self._fill_signal.is_set()
            )
            if should_refill:
                self._sample_count.value = 0
        if should_refill:
            self._fill_signal.set()
            log_info(
                f"[OnlineDataEngine] Sample count reached refill threshold (refill_threshold * buffer_size) "
                f"({self.cfg.refill_threshold * self.buffer_size}). Signalling subprocess to refill the buffer.",
                color="cyan",
            )

    # -----------------------------------------------------------------------
    # Lifecycle
    # -----------------------------------------------------------------------
    def stop(self) -> None:
        """Terminate the simulation subprocess and release resources.

        Sets the close signal and waits briefly for the subprocess to exit
        gracefully (it checks the signal between rollout steps). If the
        subprocess is still alive after the grace period it is force-terminated.

        Safe to call multiple times — subsequent calls are no-ops if the
        subprocess has already been terminated.
        """
        if self._sim_process is None or not self._sim_process.is_alive():
            return
        # Ask the subprocess to stop and unblock it if it is waiting on fill_signal.
        self._close_signal.set()
        self._fill_signal.set()
        # Allow time for a graceful exit (close_signal is checked between steps).
        self._sim_process.join(timeout=5.0)
        if self._sim_process.is_alive():
            self._sim_process.terminate()
            self._sim_process.join(timeout=3.0)
        log_info("[OnlineDataEngine] Simulation subprocess terminated.", color="green")

    def __del__(self) -> None:
        # Best-effort cleanup: never raise from a destructor, e.g. on a
        # partially-initialised instance or during interpreter shutdown.
        try:
            self.stop()
        except Exception:
            pass