Source code for embodichain.agents.datasets.online_data

# ----------------------------------------------------------------------------
# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ----------------------------------------------------------------------------

from __future__ import annotations

from typing import Callable, Iterator, List, Optional

from tensordict import TensorDict
from torch.utils.data import IterableDataset

from embodichain.agents.engine.data import OnlineDataEngine
from embodichain.agents.datasets.sampler import ChunkSizeSampler


# Public API of this module.
__all__ = ["OnlineDataset"]


class OnlineDataset(IterableDataset):
    """Infinite IterableDataset backed by a live OnlineDataEngine shared buffer.

    Two sampling modes are supported depending on the ``batch_size`` argument:

    **Item mode** (``batch_size=None``, default)
        ``__iter__`` yields one ``TensorDict`` of shape ``[chunk_size]`` per
        step. Use with a standard
        ``DataLoader(dataset, batch_size=B, collate_fn=OnlineDataset.collate_fn)``
        so the DataLoader handles collation and worker sharding.

    **Batch mode** (``batch_size=N``)
        ``__iter__`` yields one pre-batched ``TensorDict`` of shape
        ``[N, chunk_size]`` per step by calling
        ``engine.sample_batch(N, chunk_size)`` directly. Use with
        ``DataLoader(dataset, batch_size=None,
        collate_fn=OnlineDataset.passthrough_collate_fn)`` to skip DataLoader
        collation and leverage the engine's bulk-sampling efficiency.

    **Dynamic chunk sizes**
        Pass a :class:`ChunkSizeSampler` as ``chunk_size`` to draw chunk
        lengths from a distribution. When a fresh length is drawn depends on
        the mode:

        - *Batch mode*: a new size is sampled on every step and applied
          uniformly to all trajectories in that batch, ensuring a consistent
          ``[batch_size, chunk_size]`` shape.
        - *Item mode*: the size is sampled **once per iterator** (on the first
          ``next()``) and then held fixed. This keeps every item the
          DataLoader collates into one batch the same length. With several
          DataLoader workers, each worker iterates its own dataset replica
          and so draws its own size. Use batch mode if the chunk length must
          change on every step.

        Two built-in samplers are provided:

        - :class:`UniformChunkSampler` — uniform discrete distribution over
          ``[low, high]``.
        - :class:`GMMChunkSampler` — Gaussian Mixture Model, useful for
          multi-modal chunk-length curricula.

    .. note::
        ``__len__`` is intentionally absent — ``IterableDataset`` does not
        require it and the stream is infinite.

    .. note::
        Multi-worker DataLoader: each worker gets its own iterator; since
        sampling is independent random draws from shared memory, this is
        safe.

    Args:
        engine: A started OnlineDataEngine whose shared buffer is used for
            sampling.
        chunk_size: Fixed number of consecutive timesteps per chunk
            (``int`` ≥ 1), or a :class:`ChunkSizeSampler` that returns a chunk
            size each time it is called (see *Dynamic chunk sizes* for when
            that happens in each mode).
        batch_size: If ``None``, yield single chunks of shape
            ``[chunk_size]`` (item mode). If an int ≥ 1, yield pre-batched
            TensorDicts of shape ``[batch_size, chunk_size]`` (batch mode).
        transform: Optional ``(TensorDict) -> TensorDict`` applied to each
            yielded item/batch before returning.

    Raises:
        ValueError: If ``chunk_size`` is an int less than 1, or if
            ``batch_size`` is not ``None`` and is less than 1.
        TypeError: If ``chunk_size`` is neither an int nor a
            :class:`ChunkSizeSampler`.

    Example — fixed chunk size, item mode::

        dataset = OnlineDataset(engine, chunk_size=64)
        loader = DataLoader(dataset, batch_size=32, num_workers=4,
                            collate_fn=OnlineDataset.collate_fn)
        for batch in loader:
            # batch has shape [32, 64, ...]
            train_step(batch)

    Example — fixed chunk size, batch mode::

        dataset = OnlineDataset(engine, chunk_size=64, batch_size=32)
        loader = DataLoader(dataset, batch_size=None,
                            collate_fn=OnlineDataset.passthrough_collate_fn)
        for batch in loader:
            # batch has shape [32, 64, ...]
            train_step(batch)

    Example — dynamic chunk size with uniform sampler, item mode::

        sampler = UniformChunkSampler(low=16, high=64)
        dataset = OnlineDataset(engine, chunk_size=sampler)
        loader = DataLoader(dataset, batch_size=32,
                            collate_fn=OnlineDataset.collate_fn)
        for batch in loader:
            # the chunk length is drawn once per (worker) iterator and then
            # stays fixed; use batch mode for per-step variation
            train_step(batch)

    Example — dynamic chunk size with GMM sampler, batch mode::

        sampler = GMMChunkSampler(
            means=[16.0, 64.0], stds=[4.0, 8.0], weights=[0.6, 0.4],
            low=8, high=96,
        )
        dataset = OnlineDataset(engine, chunk_size=sampler, batch_size=32)
        loader = DataLoader(dataset, batch_size=None,
                            collate_fn=OnlineDataset.passthrough_collate_fn)
        for batch in loader:
            # the chunk dimension is re-sampled for every batch
            train_step(batch)
    """

    def __init__(
        self,
        engine: OnlineDataEngine,
        chunk_size: int | ChunkSizeSampler,
        batch_size: Optional[int] = None,
        transform: Optional[Callable[[TensorDict], TensorDict]] = None,
    ) -> None:
        if isinstance(chunk_size, int):
            if chunk_size < 1:
                raise ValueError(f"chunk_size must be ≥ 1, got {chunk_size}.")
        elif not isinstance(chunk_size, ChunkSizeSampler):
            raise TypeError(
                f"chunk_size must be an int or a ChunkSizeSampler, got {type(chunk_size).__name__}."
            )
        # Validate eagerly. Without this check, a non-positive batch size is
        # only passed on to ``engine.sample_batch`` later, at iteration time.
        if batch_size is not None and batch_size < 1:
            raise ValueError(f"batch_size must be ≥ 1 or None, got {batch_size}.")

        self._engine = engine
        self._chunk_size = chunk_size
        self._batch_size = batch_size
        self._transform = transform

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _next_chunk_size(self) -> int:
        """Return the chunk size for the current sampling step.

        For a fixed ``int`` this just returns the stored value. For a
        :class:`ChunkSizeSampler` the sampler is called to draw a fresh value,
        which is checked to be at least 1.

        Returns:
            Positive integer chunk size.

        Raises:
            ValueError: If the sampler returns a value less than 1.
        """
        if isinstance(self._chunk_size, int):
            return self._chunk_size
        size = self._chunk_size()
        # The sampler is user-supplied; fail loudly here rather than forward
        # a non-positive length to ``engine.sample_batch``.
        if size < 1:
            raise ValueError(
                f"ChunkSizeSampler returned chunk size {size}; expected ≥ 1."
            )
        return size

    # ------------------------------------------------------------------
    # IterableDataset interface
    # ------------------------------------------------------------------

    def __iter__(self) -> Iterator[TensorDict]:
        """Yield trajectory chunks indefinitely from the shared buffer.

        In item mode, each ``next()`` draws one chunk of shape
        ``[chunk_size]``. The chunk size is determined once, on the first
        ``next()`` of this iterator, and reused for every later item.

        In batch mode, each ``next()`` draws a full batch of shape
        ``[batch_size, chunk_size]``. The chunk size is determined anew for
        every batch.

        Yields:
            TensorDict sampled from the engine's shared buffer, optionally
            post-processed by ``transform``.
        """
        if self._batch_size is None:
            # Item mode: hold the chunk size fixed for the lifetime of this
            # iterator so that all items a DataLoader collates together share
            # one length.
            chunk_size = self._next_chunk_size()
            while True:
                # Draw a single trajectory and drop the outer batch dim.
                raw = self._engine.sample_batch(batch_size=1, chunk_size=chunk_size)
                sample: TensorDict = raw[0]
                if self._transform is not None:
                    sample = self._transform(sample)
                yield sample
        else:
            while True:
                # Batch mode: one chunk size per step, shared by every
                # trajectory in the batch.
                chunk_size = self._next_chunk_size()
                batch = self._engine.sample_batch(
                    batch_size=self._batch_size, chunk_size=chunk_size
                )
                if self._transform is not None:
                    batch = self._transform(batch)
                yield batch

    @staticmethod
    def collate_fn(batch: List[TensorDict]) -> TensorDict:
        """Collate a list of TensorDicts into a single batched TensorDict.

        Pass this as ``collate_fn`` to ``DataLoader`` when using item mode
        (``batch_size`` not None on the DataLoader side) to avoid the default
        collation failure with TensorDict objects.

        Args:
            batch: List of TensorDicts, each of shape ``[chunk_size, ...]``.
                Stacking requires every entry to have the same shape.

        Returns:
            Stacked TensorDict of shape ``[len(batch), chunk_size, ...]``.
        """
        import torch  # local import: torch is not imported at module level

        return torch.stack(batch)

    @staticmethod
    def passthrough_collate_fn(batch: TensorDict) -> TensorDict:
        """Collate function for batch-mode DataLoaders.

        In batch mode the dataset already yields pre-batched TensorDicts.
        With ``batch_size=None``, PyTorch's DataLoader skips auto-batching
        and passes each item to ``collate_fn`` directly, not wrapped in a
        list. This function returns that TensorDict unchanged.

        Pass this as ``collate_fn`` to ``DataLoader`` when using batch mode
        (``batch_size=None`` on the DataLoader side) to avoid the default
        collation failure with TensorDict objects.

        Args:
            batch: A pre-batched TensorDict of shape
                ``[batch_size, chunk_size, ...]`` passed directly by the
                DataLoader.

        Returns:
            The pre-batched TensorDict unchanged.
        """
        return batch