embodichain.agents#

Submodules

Datasets#

online_data

sampler

Classes:

ChunkSizeSampler

Abstract base class for chunk-size samplers.

GMMChunkSampler

Gaussian Mixture Model chunk-size sampler.

OnlineDataset

Infinite IterableDataset backed by a live OnlineDataEngine shared buffer.

UniformChunkSampler

Discrete-uniform chunk-size sampler over [low, high].

class embodichain.agents.datasets.ChunkSizeSampler[source]#

Bases: ABC

Abstract base class for chunk-size samplers.

Subclasses implement __call__() to return an integer chunk size on demand. A sampler is called once per OnlineDataset.__iter__() step, so consecutive samples / batches may have different time dimensions.

When used in batch mode the same chunk size is drawn once and applied to every trajectory in the batch so that the resulting TensorDict has a consistent shape [batch_size, chunk_size].
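The contract is simply a no-argument __call__() returning an int. As an illustration, a custom sampler could look like the sketch below; GeometricChunkSampler and its parameters are hypothetical, and it is written standalone (in real code it would subclass ChunkSizeSampler) so the example is self-contained:

```python
import random

# Hypothetical sampler following the ChunkSizeSampler contract: a no-argument
# __call__ returning an int. In practice this would subclass
# embodichain.agents.datasets.ChunkSizeSampler; shown standalone here.
class GeometricChunkSampler:
    """Draw chunk sizes from a truncated geometric distribution."""

    def __init__(self, p=0.05, low=8, high=96):
        if not 0 < p < 1:
            raise ValueError("p must be in (0, 1)")
        if low < 1 or high < low:
            raise ValueError("need 1 <= low <= high")
        self.p, self.low, self.high = p, low, high

    def __call__(self):
        # Geometric draw via repeated Bernoulli trials, clamped to [low, high].
        k = 1
        while random.random() > self.p and k < self.high:
            k += 1
        return max(self.low, min(k, self.high))

sampler = GeometricChunkSampler()
sizes = [sampler() for _ in range(500)]
```

Any such object can then be passed wherever a chunk_size sampler is accepted.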

class embodichain.agents.datasets.GMMChunkSampler[source]#

Bases: ChunkSizeSampler

Gaussian Mixture Model chunk-size sampler.

Selects a mixture component according to weights, samples a value from the corresponding Normal(mean, std) distribution, rounds to the nearest integer, and optionally clamps the result to [low, high].

Parameters:
  • means (List[float]) – Mean of each Gaussian component (number of elements = K).

  • stds (List[float]) – Standard deviation of each component (must be > 0, same length as means).

  • weights (Optional[List[float]]) – Unnormalised mixture weights (same length as means). Defaults to a uniform distribution over all components.

  • low (Optional[int]) – Optional lower bound for clamping the sampled value (inclusive, must be ≥ 1 if provided).

  • high (Optional[int]) – Optional upper bound for clamping the sampled value (inclusive, must be ≥ low if both are provided).

Raises:

ValueError – If means, stds, or weights have mismatched lengths, if any std ≤ 0, or if the bounds are inconsistent.

Example — two-component mixture favouring short and long chunks:

sampler = GMMChunkSampler(
    means=[16.0, 64.0],
    stds=[4.0, 8.0],
    weights=[0.6, 0.4],
    low=8,
    high=96,
)
chunk_size = sampler()  # e.g. 18

Methods:

__init__(means, stds[, weights, low, high])

__init__(means, stds, weights=None, low=None, high=None)[source]#
class embodichain.agents.datasets.OnlineDataset[source]#

Bases: IterableDataset

Infinite IterableDataset backed by a live OnlineDataEngine shared buffer.

Two sampling modes are supported depending on the batch_size argument:

Item mode (batch_size=None, default)

__iter__ yields one TensorDict of shape [chunk_size] per step. Use with a standard DataLoader(dataset, batch_size=B) so the DataLoader handles collation and worker sharding.

Batch mode (batch_size=N)

__iter__ yields one pre-batched TensorDict of shape [N, chunk_size] per step by calling engine.sample_batch(N, chunk_size) directly. Use with DataLoader(dataset, batch_size=None) to skip DataLoader collation and leverage the engine’s bulk-sampling efficiency.

Dynamic chunk sizes

Pass a ChunkSizeSampler as chunk_size to draw a fresh chunk length on every iteration step. In batch mode the size is sampled once per step and applied uniformly to all trajectories in the batch, ensuring a consistent [batch_size, chunk_size] shape. Two built-in samplers are provided: UniformChunkSampler and GMMChunkSampler.

Note

__len__ is intentionally absent — IterableDataset does not require it and the stream is infinite.

Note

Multi-worker DataLoader: each worker gets its own iterator; since sampling is independent random draws from shared memory, this is safe.

Parameters:
  • engine (OnlineDataEngine) – A started OnlineDataEngine whose shared buffer is used for sampling.

  • chunk_size (int | ChunkSizeSampler) – Fixed number of consecutive timesteps per chunk (int), or a ChunkSizeSampler that returns a fresh size on every iteration step.

  • batch_size (Optional[int]) – If None, yield single chunks of shape [chunk_size] (item mode). If an int, yield pre-batched TensorDicts of shape [batch_size, chunk_size] (batch mode).

  • transform (Optional[Callable[[TensorDict], TensorDict]]) – Optional (TensorDict) -> TensorDict applied to each yielded item/batch before returning.

Example — fixed chunk size, item mode:

dataset = OnlineDataset(engine, chunk_size=64)
loader  = DataLoader(dataset, batch_size=32, num_workers=4,
                     collate_fn=OnlineDataset.collate_fn)
for batch in loader:
    # batch has shape [32, 64, ...]
    train_step(batch)

Example — fixed chunk size, batch mode:

dataset = OnlineDataset(engine, chunk_size=64, batch_size=32)
loader  = DataLoader(dataset, batch_size=None,
                     collate_fn=OnlineDataset.passthrough_collate_fn)
for batch in loader:
    # batch has shape [32, 64, ...]
    train_step(batch)

Example — dynamic chunk size with uniform sampler:

sampler = UniformChunkSampler(low=16, high=64)
dataset = OnlineDataset(engine, chunk_size=sampler)
loader  = DataLoader(dataset, batch_size=32)
for batch in loader:
    # chunk dimension varies each batch
    train_step(batch)

Example — dynamic chunk size with GMM sampler:

sampler = GMMChunkSampler(
    means=[16.0, 64.0], stds=[4.0, 8.0], weights=[0.6, 0.4],
    low=8, high=96,
)
dataset = OnlineDataset(engine, chunk_size=sampler, batch_size=32)
loader  = DataLoader(dataset, batch_size=None)
for batch in loader:
    train_step(batch)

Methods:

__init__(engine, chunk_size[, batch_size, ...])

collate_fn(batch)

Collate a list of TensorDicts into a single batched TensorDict.

passthrough_collate_fn(batch)

Collate function for batch-mode DataLoaders.

__init__(engine, chunk_size, batch_size=None, transform=None)[source]#
static collate_fn(batch)[source]#

Collate a list of TensorDicts into a single batched TensorDict.

Pass this as collate_fn to DataLoader when using item mode (the DataLoader's batch_size is not None) to avoid the default collation failure with TensorDict objects.

Parameters:

batch (List[TensorDict]) – List of TensorDicts, each of shape [chunk_size, ...].

Return type:

TensorDict

Returns:

Stacked TensorDict of shape [len(batch), chunk_size, ...].

static passthrough_collate_fn(batch)[source]#

Collate function for batch-mode DataLoaders.

When the dataset is in batch mode it already yields pre-batched TensorDicts. With batch_size=None, PyTorch’s DataLoader skips auto-batching and passes each item directly to collate_fn as-is (not wrapped in a list). This function returns the TensorDict unchanged.

Pass this as collate_fn to DataLoader when using batch mode (batch_size=None on the DataLoader side) to avoid the default collation failure with TensorDict objects.

Parameters:

batch (TensorDict) – A pre-batched TensorDict of shape [batch_size, chunk_size, ...] passed directly by the DataLoader.

Return type:

TensorDict

Returns:

The pre-batched TensorDict unchanged.

class embodichain.agents.datasets.UniformChunkSampler[source]#

Bases: ChunkSizeSampler

Discrete-uniform chunk-size sampler over [low, high].

Draws an integer uniformly at random from the closed interval [low, high] on every call.

Parameters:
  • low (int) – Minimum chunk size (inclusive, must be ≥ 1).

  • high (int) – Maximum chunk size (inclusive, must be ≥ low).

Raises:

ValueError – If low < 1 or high < low.

Example:

sampler = UniformChunkSampler(low=16, high=64)
chunk_size = sampler()  # e.g. 37

Methods:

__init__(low, high)

__init__(low, high)[source]#

Online Data Engine#

data

Classes:

OnlineDataEngine

Engine for managing Online Data Streaming (ODS) and environment rollouts.

OnlineDataEngineCfg

OnlineDataEngineCfg(buffer_size: 'int' = <factory>, max_episode_steps: 'int' = <factory>, state_dim: 'int' = <factory>, buffer_device: 'str' = <factory>, gym_config: 'dict' = <factory>, action_config: 'dict' = <factory>, refill_threshold: 'int' = <factory>)

class embodichain.agents.engine.OnlineDataEngine[source]#

Bases: object

Engine for managing Online Data Streaming (ODS) and environment rollouts.

Creates a shared rollout buffer in CPU shared memory, spawns a dedicated simulation subprocess that fills the buffer with demonstration trajectories, and exposes a sample_batch() method for the training process to draw batches of trajectory chunks.

Subprocess lifecycle

The simulation subprocess is started in start() and immediately receives a fill signal so the buffer is populated before the first call to sample_batch(). The subprocess loops indefinitely: it waits for fill_signal, runs buffer_size // num_envs rollouts to overwrite every buffer slot, then goes back to waiting.

Concurrency and lock protection

A lock-protected index range _lock_index = [write_start, write_end) is updated by the subprocess after each rollout so that sample_batch() can skip the slot currently being written, preventing partial reads.

Refill criterion

sample_batch() accumulates the total number of individual trajectory samples drawn into _sample_count. When this counter exceeds refill_threshold the fill signal is raised and the counter resets to zero. This amortises the cost of GPU-accelerated simulation across many training iterations.

Initialisation barrier

The is_init property returns False until the subprocess completes the very first full buffer fill, after which it becomes permanently True. Training code should wait on this flag before calling sample_batch() to avoid drawing all-zero data.
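Taken together, the documented lifecycle is start → wait on is_init → sample_batch → stop. The runnable sketch below drives that call order against a stub with the same surface; the stub itself is an assumption made so the example needs no simulator, and the real class is OnlineDataEngine:

```python
import random
import time

# Minimal stub mirroring the documented OnlineDataEngine surface
# (start / is_init / sample_batch / stop). The real engine fills a
# shared-memory buffer from a simulation subprocess; this stub only
# reproduces the call order and return shape.
class StubEngine:
    def __init__(self):
        self._started = False
        self._init_at = None

    def start(self):
        # Real engine: spawn the subprocess and raise the first fill signal.
        self._started = True
        self._init_at = time.monotonic()

    @property
    def is_init(self):
        # True once the first full buffer fill has completed.
        return self._started and time.monotonic() >= self._init_at

    def sample_batch(self, batch_size, chunk_size):
        # Real engine returns a TensorDict with batch size
        # [batch_size, chunk_size]; the stub returns nested lists.
        return [[random.random() for _ in range(chunk_size)]
                for _ in range(batch_size)]

    def stop(self):
        self._started = False

# Documented lifecycle: start, block on the initialisation barrier, sample, stop.
engine = StubEngine()
engine.start()
while not engine.is_init:
    time.sleep(0.1)
batch = engine.sample_batch(batch_size=32, chunk_size=64)
engine.stop()
```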

Parameters:

cfg (OnlineDataEngineCfg) – Engine configuration.

shared_buffer#

Shared-memory TensorDict of shape [buffer_size, max_episode_steps, ...].

buffer_size#

Total number of trajectory slots in the shared buffer.

device#

Device of the shared buffer.

is_init#

True once the buffer has been populated at least once.

Methods:

__init__(cfg)

sample_batch(batch_size, chunk_size)

Sample a batch of trajectory chunks from the shared rollout buffer.

start()

stop()

Terminate the simulation subprocess and release resources.

Attributes:

is_init

Whether the shared buffer has been fully populated at least once.

__init__(cfg)[source]#
property is_init: bool#

Whether the shared buffer has been fully populated at least once.

Returns True after the simulation subprocess completes its first full buffer fill, False while that initial fill is still in progress. Callers that must not sample stale (all-zero) data can poll or block on this property before entering their training loop:

while not engine.is_init:
    time.sleep(0.5)
Returns:

True once the buffer contains valid trajectory data.

sample_batch(batch_size, chunk_size)[source]#

Sample a batch of trajectory chunks from the shared rollout buffer.

Randomly draws batch_size environment trajectories from the portion of the buffer that has been written at least once, skipping any rows currently being overwritten by the simulation subprocess. For each selected trajectory a contiguous window of chunk_size timesteps is chosen at a uniformly random offset.

After sampling the internal _sample_count is incremented by batch_size; if the count exceeds refill_threshold a buffer refill is triggered automatically.

Parameters:
  • batch_size (int) – Number of trajectory chunks to include in the batch.

  • chunk_size (int) – Number of consecutive timesteps in each chunk.

Return type:

TensorDict

Returns:

TensorDict with batch size [batch_size, chunk_size].

Raises:

ValueError – If chunk_size exceeds max_episode_steps.
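As an illustration of the scheme described above (random trajectory rows drawn outside the locked write range, then a uniformly random contiguous window per row), here is a standalone sketch; sample_chunks and the list-of-lists buffer are stand-ins, not the engine's actual implementation:

```python
import random

# Illustrative sketch of the documented sampling scheme. The buffer is a
# plain list-of-lists stand-in for the shared TensorDict; write_start and
# write_end name the locked slot range per the docs.
def sample_chunks(buffer, batch_size, chunk_size, write_start, write_end):
    max_steps = len(buffer[0])
    if chunk_size > max_steps:
        raise ValueError("chunk_size exceeds max_episode_steps")
    # Rows safe to read: everything outside the locked [write_start, write_end).
    readable = [i for i in range(len(buffer))
                if not (write_start <= i < write_end)]
    rows = random.choices(readable, k=batch_size)  # sample with replacement
    batch = []
    for r in rows:
        # Contiguous window of chunk_size timesteps at a uniform offset.
        off = random.randint(0, max_steps - chunk_size)
        batch.append(buffer[r][off:off + chunk_size])
    return batch

# Each buffer element records (row, timestep) so windows are checkable.
buffer = [[(r, t) for t in range(100)] for r in range(16)]
batch = sample_chunks(buffer, batch_size=4, chunk_size=20,
                      write_start=3, write_end=5)
```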

start()[source]#
Return type:

None

stop()[source]#

Terminate the simulation subprocess and release resources.

Sets the close signal and waits briefly for the subprocess to exit gracefully (it checks the signal between rollout steps). If the subprocess is still alive after the grace period it is force-terminated.

Safe to call multiple times — subsequent calls are no-ops if the subprocess has already been terminated.

Return type:

None

class embodichain.agents.engine.OnlineDataEngineCfg[source]#

Bases: object

OnlineDataEngineCfg(buffer_size: 'int' = <factory>, max_episode_steps: 'int' = <factory>, state_dim: 'int' = <factory>, buffer_device: 'str' = <factory>, gym_config: 'dict' = <factory>, action_config: 'dict' = <factory>, refill_threshold: 'int' = <factory>)

Methods:

__init__([buffer_size, max_episode_steps, ...])

copy(**kwargs)

Return a new object replacing specified fields with new values.

replace(**kwargs)

Return a new object replacing specified fields with new values.

to_dict()

Convert an object into dictionary recursively.

validate([prefix])

Check the validity of configclass object.

Attributes:

action_config

Action configuration dictionary.

buffer_device

Device on which the shared buffer is allocated.

buffer_size

Number of episodes (environment trajectories) that can be stored in the shared buffer at once.

gym_config

Gym environment configuration dictionary (already loaded, not a file path).

max_episode_steps

Maximum number of timesteps per episode.

refill_threshold

Refill multiplier: a buffer refill is triggered after refill_threshold * buffer_size samples have been drawn from the shared buffer.

state_dim

Dimensionality of the state space.

__init__(buffer_size=<factory>, max_episode_steps=<factory>, state_dim=<factory>, buffer_device=<factory>, gym_config=<factory>, action_config=<factory>, refill_threshold=<factory>)#
action_config: dict#

Action configuration dictionary. The contents depend on the specific environment and robot being used.

buffer_device: str#

Device on which the shared buffer is allocated.

buffer_size: int#

Number of episodes (environment trajectories) that can be stored in the shared buffer at once. Must be ≥ num_envs and ideally a multiple of num_envs.

copy(**kwargs)#

Return a new object replacing specified fields with new values.

This is especially useful for frozen classes. Example usage:

@configclass(frozen=True)
class C:
    x: int
    y: int

c = C(1, 2)
c1 = c.replace(x=3)
assert c1.x == 3 and c1.y == 2
Parameters:
  • obj (object) – The object to replace.

  • **kwargs – The fields to replace and their new values.

Return type:

object

Returns:

The new object.

gym_config: dict#

Gym environment configuration dictionary (already loaded, not a file path). The contents depend on the specific environment being used. Default is None.

max_episode_steps: int#

Maximum number of timesteps per episode. Must be ≥ chunk_size used by OnlineDataset.

refill_threshold: int#

Refill multiplier: a buffer refill is triggered after refill_threshold * buffer_size samples have been drawn from the shared buffer. The sample count accumulates across all calls to OnlineDataEngine.sample_batch(); when the threshold is exceeded the engine signals the simulation subprocess to regenerate the entire buffer, amortising the cost of environment simulation over many training steps.
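For concreteness, with the hypothetical values buffer_size=128 and refill_threshold=4 (neither is a documented default), a refill would fire after 512 samples:

```python
# Hypothetical configuration values, chosen only for illustration.
buffer_size = 128
refill_threshold = 4

# Per the docs, a refill triggers after refill_threshold * buffer_size samples.
samples_before_refill = refill_threshold * buffer_size
```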

replace(**kwargs)#

Return a new object replacing specified fields with new values.

This is especially useful for frozen classes. Example usage:

@configclass(frozen=True)
class C:
    x: int
    y: int

c = C(1, 2)
c1 = c.replace(x=3)
assert c1.x == 3 and c1.y == 2
Parameters:
  • obj (object) – The object to replace.

  • **kwargs – The fields to replace and their new values.

Return type:

object

Returns:

The new object.

state_dim: int#

Dimensionality of the state space.

to_dict()#

Convert an object into dictionary recursively.

Note

Ignores all names starting with “__” (i.e. built-in methods).

Parameters:

obj (object) – An instance of a class to convert.

Raises:

ValueError – When input argument is not an object.

Return type:

dict[str, Any]

Returns:

Converted dictionary mapping.

validate(prefix='')#

Check the validity of configclass object.

This function checks if the object is a valid configclass object. A valid configclass object contains no MISSING entries.

Parameters:
  • obj (object) – The object to check.

  • prefix (str) – The prefix to add to the missing fields. Defaults to ‘’.

Return type:

list[str]

Returns:

A list of missing fields.

Raises:

TypeError – When the object is not a valid configuration object.

Reinforcement Learning#