embodichain.agents#
Submodules
Datasets#
Classes:

ChunkSizeSampler – Abstract base class for chunk-size samplers.
GMMChunkSampler – Gaussian Mixture Model chunk-size sampler.
OnlineDataset – Infinite IterableDataset backed by a live OnlineDataEngine shared buffer.
UniformChunkSampler – Discrete-uniform chunk-size sampler over [low, high].
- class embodichain.agents.datasets.ChunkSizeSampler[source]#
Bases: ABC

Abstract base class for chunk-size samplers.

Subclasses implement __call__() to return an integer chunk size on demand. A sampler is called once per OnlineDataset.__iter__() step, so consecutive samples / batches may have different time dimensions.

When used in batch mode, the same chunk size is drawn once and applied to every trajectory in the batch so that the resulting TensorDict has a consistent shape [batch_size, chunk_size].
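For illustration, the sampler contract can be sketched self-containedly with the standard library. The class names below are hypothetical stand-ins, not the actual embodichain classes:

```python
# Hypothetical stand-alone sketch of the ChunkSizeSampler contract;
# names here are illustrative, not the embodichain implementation.
from abc import ABC, abstractmethod


class ChunkSizeSamplerSketch(ABC):
    """Returns an integer chunk size each time it is called."""

    @abstractmethod
    def __call__(self) -> int:
        ...


class FixedChunkSampler(ChunkSizeSamplerSketch):
    """Degenerate sampler that always returns the same chunk size."""

    def __init__(self, size: int):
        if size < 1:
            raise ValueError("chunk size must be >= 1")
        self.size = size

    def __call__(self) -> int:
        return self.size


sampler = FixedChunkSampler(32)
print(sampler())  # 32
```

Because the dataset only ever calls `sampler()`, any callable returning a positive int satisfies the interface; subclassing merely makes the contract explicit.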
- class embodichain.agents.datasets.GMMChunkSampler[source]#
Bases: ChunkSizeSampler

Gaussian Mixture Model chunk-size sampler.

Selects a mixture component according to weights, samples a value from the corresponding Normal(mean, std) distribution, rounds to the nearest integer, and optionally clamps the result to [low, high].

- Parameters:
  means (List[float]) – Mean of each Gaussian component (number of elements = K).
  stds (List[float]) – Standard deviation of each component (must be > 0, same length as means).
  weights (Optional[List[float]]) – Unnormalised mixture weights (same length as means). Defaults to a uniform distribution over all components.
  low (Optional[int]) – Optional lower bound for clamping the sampled value (inclusive, must be ≥ 1 if provided).
  high (Optional[int]) – Optional upper bound for clamping the sampled value (inclusive, must be ≥ low if both are provided).
- Raises:
  ValueError – If means, stds, or weights have mismatched lengths, if any std ≤ 0, or if the bounds are inconsistent.

Example — two-component mixture favouring short and long chunks:

    sampler = GMMChunkSampler(
        means=[16.0, 64.0],
        stds=[4.0, 8.0],
        weights=[0.6, 0.4],
        low=8,
        high=96,
    )
    chunk_size = sampler()  # e.g. 18
Methods:
__init__(means, stds[, weights, low, high])
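The select-sample-round-clamp mechanism described above can be sketched with the standard library alone. This is an illustration of the algorithm, not the actual GMMChunkSampler code:

```python
# Stand-alone sketch of GMM chunk-size sampling: pick a component by
# weight, draw from its Gaussian, round, then clamp to [low, high].
# Illustrative only; not the embodichain.GMMChunkSampler implementation.
import random


def gmm_chunk_sample(means, stds, weights=None, low=None, high=None):
    if weights is None:
        weights = [1.0] * len(means)  # uniform over all components
    if not (len(means) == len(stds) == len(weights)):
        raise ValueError("means, stds, and weights must have equal lengths")
    if any(s <= 0 for s in stds):
        raise ValueError("all stds must be > 0")
    # Choose a mixture component proportionally to its unnormalised weight.
    k = random.choices(range(len(means)), weights=weights, k=1)[0]
    value = round(random.gauss(means[k], stds[k]))
    if low is not None:
        value = max(low, value)
    if high is not None:
        value = min(high, value)
    return value


random.seed(0)
sizes = [gmm_chunk_sample([16.0, 64.0], [4.0, 8.0], [0.6, 0.4], low=8, high=96)
         for _ in range(1000)]
print(min(sizes) >= 8 and max(sizes) <= 96)  # True
```

Clamping after rounding guarantees every sample lands in the closed interval, which is what keeps downstream tensor shapes within the bounds the training loop expects.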
- class embodichain.agents.datasets.OnlineDataset[source]#
Bases: IterableDataset

Infinite IterableDataset backed by a live OnlineDataEngine shared buffer.

Two sampling modes are supported depending on the batch_size argument:

- Item mode (batch_size=None, default): __iter__ yields one TensorDict of shape [chunk_size] per step. Use with a standard DataLoader(dataset, batch_size=B) so the DataLoader handles collation and worker sharding.
- Batch mode (batch_size=N): __iter__ yields one pre-batched TensorDict of shape [N, chunk_size] per step by calling engine.sample_batch(N, chunk_size) directly. Use with DataLoader(dataset, batch_size=None) to skip DataLoader collation and leverage the engine's bulk-sampling efficiency.

Dynamic chunk sizes

Pass a ChunkSizeSampler as chunk_size to draw a fresh chunk length on every iteration step. In batch mode the size is sampled once per step and applied uniformly to all trajectories in the batch, ensuring a consistent [batch_size, chunk_size] shape. Two built-in samplers are provided:

- UniformChunkSampler – uniform discrete distribution over [low, high].
- GMMChunkSampler – Gaussian Mixture Model, useful for multi-modal chunk-length curricula.

Note

__len__ is intentionally absent: IterableDataset does not require it and the stream is infinite.

Note

Multi-worker DataLoader: each worker gets its own iterator; since sampling consists of independent random draws from shared memory, this is safe.

- Parameters:
  engine (OnlineDataEngine) – A started OnlineDataEngine whose shared buffer is used for sampling.
  chunk_size (int | ChunkSizeSampler) – Fixed number of consecutive timesteps per chunk (int), or a ChunkSizeSampler that returns a fresh size on every iteration step.
  batch_size (Optional[int]) – If None, yield single chunks of shape [chunk_size] (item mode). If an int, yield pre-batched TensorDicts of shape [batch_size, chunk_size] (batch mode).
  transform (Optional[Callable[[TensorDict], TensorDict]]) – Optional (TensorDict) -> TensorDict applied to each yielded item/batch before returning.

Example — fixed chunk size, item mode:

    dataset = OnlineDataset(engine, chunk_size=64)
    loader = DataLoader(
        dataset,
        batch_size=32,
        num_workers=4,
        collate_fn=OnlineDataset.collate_fn,
    )
    for batch in loader:
        # batch has shape [32, 64, ...]
        train_step(batch)

Example — fixed chunk size, batch mode:

    dataset = OnlineDataset(engine, chunk_size=64, batch_size=32)
    loader = DataLoader(
        dataset,
        batch_size=None,
        collate_fn=OnlineDataset.passthrough_collate_fn,
    )
    for batch in loader:
        # batch has shape [32, 64, ...]
        train_step(batch)

Example — dynamic chunk size with uniform sampler:

    sampler = UniformChunkSampler(low=16, high=64)
    dataset = OnlineDataset(engine, chunk_size=sampler)
    loader = DataLoader(dataset, batch_size=32)
    for batch in loader:
        # chunk dimension varies each batch
        train_step(batch)

Example — dynamic chunk size with GMM sampler:

    sampler = GMMChunkSampler(
        means=[16.0, 64.0],
        stds=[4.0, 8.0],
        weights=[0.6, 0.4],
        low=8,
        high=96,
    )
    dataset = OnlineDataset(engine, chunk_size=sampler, batch_size=32)
    loader = DataLoader(dataset, batch_size=None)
    for batch in loader:
        train_step(batch)
Methods:
__init__(engine, chunk_size[, batch_size, ...])
collate_fn(batch) – Collate a list of TensorDicts into a single batched TensorDict.
passthrough_collate_fn(batch) – Collate function for batch-mode DataLoaders.
- static collate_fn(batch)[source]#
Collate a list of TensorDicts into a single batched TensorDict.

Pass this as collate_fn to DataLoader when using item mode (batch_size not None on the DataLoader side) to avoid the default collation failure with TensorDict objects.

- Parameters:
  batch (List[TensorDict]) – List of TensorDicts, each of shape [chunk_size, ...].
- Return type:
  TensorDict
- Returns:
  Stacked TensorDict of shape [len(batch), chunk_size, ...].
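A minimal sketch of what this kind of collate does, using plain dicts of NumPy arrays as stand-ins for TensorDicts (the real implementation is assumed to use the tensordict library's stacking, which is not shown here):

```python
# Sketch: stack a list of per-chunk dicts along a new leading batch axis.
# Plain dicts of NumPy arrays stand in for TensorDict objects.
import numpy as np


def collate_dicts(batch):
    """Stack each key of a list of dicts into a single batched dict."""
    keys = batch[0].keys()
    return {k: np.stack([item[k] for item in batch], axis=0) for k in keys}


# Two "chunks", each of shape [chunk_size=4, feature_dim].
chunk_a = {"state": np.zeros((4, 3)), "action": np.zeros((4, 2))}
chunk_b = {"state": np.ones((4, 3)), "action": np.ones((4, 2))}

batched = collate_dicts([chunk_a, chunk_b])
print(batched["state"].shape)   # (2, 4, 3) -> [len(batch), chunk_size, ...]
print(batched["action"].shape)  # (2, 4, 2)
```

This is also why all items in one batch must share the same chunk_size: stacking requires identical per-item shapes, which the dataset guarantees by drawing the chunk size once per step.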
- static passthrough_collate_fn(batch)[source]#
Collate function for batch-mode DataLoaders.

When the dataset is in batch mode it already yields pre-batched TensorDicts. With batch_size=None, PyTorch's DataLoader skips auto-batching and passes each item directly to collate_fn as-is (not wrapped in a list). This function returns the TensorDict unchanged.

Pass this as collate_fn to DataLoader when using batch mode (batch_size=None on the DataLoader side) to avoid the default collation failure with TensorDict objects.

- Parameters:
  batch (TensorDict) – A pre-batched TensorDict of shape [batch_size, chunk_size, ...] passed directly by the DataLoader.
- Return type:
  TensorDict
- Returns:
  The pre-batched TensorDict unchanged.
- class embodichain.agents.datasets.UniformChunkSampler[source]#
Bases: ChunkSizeSampler

Discrete-uniform chunk-size sampler over [low, high].

Draws an integer uniformly at random from the closed interval [low, high] on every call.

- Parameters:
  low (int) – Minimum chunk size (inclusive, must be ≥ 1).
  high (int) – Maximum chunk size (inclusive, must be ≥ low).
- Raises:
  ValueError – If low < 1 or high < low.

Example:

    sampler = UniformChunkSampler(low=16, high=64)
    chunk_size = sampler()  # e.g. 37
Methods:
__init__(low, high)
Online Data Engine#
Classes:

OnlineDataEngine – Engine for managing Online Data Streaming (ODS) and environment rollouts.
OnlineDataEngineCfg – Configuration for OnlineDataEngine (buffer_size, max_episode_steps, state_dim, buffer_device, gym_config, action_config, refill_threshold).
- class embodichain.agents.engine.OnlineDataEngine[source]#
Bases: object

Engine for managing Online Data Streaming (ODS) and environment rollouts.

Creates a shared rollout buffer in CPU shared memory, spawns a dedicated simulation subprocess that fills the buffer with demonstration trajectories, and exposes a sample_batch() method for the training process to draw batches of trajectory chunks.

Subprocess lifecycle

The simulation subprocess is started in start() and immediately receives a fill signal so the buffer is populated before the first call to sample_batch(). The subprocess loops indefinitely: it waits for the fill signal, runs buffer_size // num_envs rollouts to overwrite every buffer slot, then goes back to waiting.

Concurrency and lock protection

The _lock_index window [write_start, write_end) is updated by the subprocess after each rollout so that sample_batch() can skip the slot currently being written to, preventing partial reads.

Refill criterion

sample_batch() accumulates the total number of individual trajectory samples drawn into _sample_count. When this counter exceeds refill_threshold, the fill signal is raised and the counter resets to zero. This amortises the cost of GPU-accelerated simulation across many training iterations.

Initialisation barrier

The is_init property returns False until the subprocess completes the very first full buffer fill, after which it becomes permanently True. Training code should wait on this flag before calling sample_batch() to avoid drawing all-zero data.

- Parameters:
  cfg (OnlineDataEngineCfg) – Engine configuration.

Shared-memory TensorDict of shape [buffer_size, max_episode_steps, ...].
- buffer_size#
Total number of trajectory slots in the shared buffer.
- device#
Device of the shared buffer.
- is_init#
True once the buffer has been populated at least once.
Methods:
__init__(cfg)
sample_batch(batch_size, chunk_size) – Sample a batch of trajectory chunks from the shared rollout buffer.
start()
stop() – Terminate the simulation subprocess and release resources.
Attributes:
Whether the shared buffer has been fully populated at least once.
- property is_init: bool#
Whether the shared buffer has been fully populated at least once.
Returns True after the simulation subprocess completes its first full buffer fill, and False while that initial fill is still in progress. Callers that must not sample stale (all-zero) data can poll or block on this property before entering their training loop:

    while not engine.is_init:
        time.sleep(0.5)

- Returns:
  True once the buffer contains valid trajectory data.
- sample_batch(batch_size, chunk_size)[source]#
Sample a batch of trajectory chunks from the shared rollout buffer.
Randomly draws batch_size environment trajectories from the portion of the buffer that has been written at least once, skipping any rows currently being overwritten by the simulation subprocess. For each selected trajectory a contiguous window of chunk_size timesteps is chosen at a uniformly random offset.
After sampling, the internal _sample_count is incremented by batch_size; if the count exceeds refill_threshold, a buffer refill is triggered automatically.

- Parameters:
  batch_size (int) – Number of trajectory chunks to include in the batch.
  chunk_size (int) – Number of consecutive timesteps in each chunk.
- Return type:
  TensorDict
- Returns:
  TensorDict with batch size [batch_size, chunk_size].
- Raises:
  ValueError – If chunk_size exceeds max_episode_steps.
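The chunking scheme described above (random trajectory rows, random contiguous windows, and a sample counter that triggers refills) can be sketched with NumPy. All names and shapes below are illustrative assumptions, not the engine's actual internals:

```python
# Sketch of chunked trajectory sampling with a refill counter.
# Illustrative stand-in for OnlineDataEngine.sample_batch(), not its code.
import numpy as np

rng = np.random.default_rng(0)

# Toy "shared buffer": 8 trajectories x 100 timesteps x 3 state dims.
buffer_size, max_episode_steps, state_dim = 8, 100, 3
buffer = rng.normal(size=(buffer_size, max_episode_steps, state_dim))

sample_count = 0
REFILL_THRESHOLD = 50  # hypothetical value of cfg.refill_threshold


def sample_batch(batch_size, chunk_size):
    """Draw random rows and random contiguous windows; track the counter."""
    global sample_count
    if chunk_size > max_episode_steps:
        raise ValueError("chunk_size exceeds max_episode_steps")
    rows = rng.integers(0, buffer_size, size=batch_size)
    offsets = rng.integers(0, max_episode_steps - chunk_size + 1, size=batch_size)
    chunks = np.stack([buffer[r, o:o + chunk_size] for r, o in zip(rows, offsets)])
    sample_count += batch_size
    if sample_count > REFILL_THRESHOLD:
        sample_count = 0  # the real engine would raise the fill signal here
    return chunks


batch = sample_batch(batch_size=4, chunk_size=16)
print(batch.shape)  # (4, 16, 3)
```

The real engine additionally skips the rows inside the subprocess's current write window; that bookkeeping is omitted here for brevity.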
- stop()[source]#
Terminate the simulation subprocess and release resources.
Sets the close signal and waits briefly for the subprocess to exit gracefully (it checks the signal between rollout steps). If the subprocess is still alive after the grace period it is force-terminated.
Safe to call multiple times — subsequent calls are no-ops if the subprocess has already been terminated.
- Return type:
None
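The graceful-then-forceful shutdown described in stop() is a common multiprocessing idiom. A self-contained sketch with the standard library (a stand-in for the engine's behaviour, not its actual code):

```python
# Sketch of graceful shutdown: set a close event, give the worker a grace
# period to exit on its own, then force-terminate. Illustrative only.
import multiprocessing as mp
import time


def worker(close_event):
    # Stand-in for the rollout loop: checks the close signal between steps.
    while not close_event.is_set():
        time.sleep(0.05)


def stop(process, close_event, grace_period=2.0):
    """Idempotent stop: safe to call even if the process already exited."""
    close_event.set()
    process.join(timeout=grace_period)
    if process.is_alive():
        process.terminate()  # fallback once the grace period elapses
        process.join()


if __name__ == "__main__":
    close_event = mp.Event()
    p = mp.Process(target=worker, args=(close_event,))
    p.start()
    stop(p, close_event)
    stop(p, close_event)  # second call is a no-op
    print(p.is_alive())  # False
```

Joining with a timeout before terminating gives the subprocess a chance to release shared-memory resources cleanly, which is why a grace period precedes the forced kill.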
- class embodichain.agents.engine.OnlineDataEngineCfg[source]#
Bases: object

OnlineDataEngineCfg(buffer_size: int = <factory>, max_episode_steps: int = <factory>, state_dim: int = <factory>, buffer_device: str = <factory>, gym_config: dict = <factory>, action_config: dict = <factory>, refill_threshold: int = <factory>)
Methods:
__init__([buffer_size, max_episode_steps, ...])
copy(**kwargs) – Return a new object replacing specified fields with new values.
replace(**kwargs) – Return a new object replacing specified fields with new values.
to_dict() – Convert an object into dictionary recursively.
validate([prefix]) – Check the validity of a configclass object.
Attributes:
action_config – Action configuration dictionary.
buffer_device – Device on which the shared buffer is allocated.
buffer_size – Number of episodes (environment trajectories) that can be stored in the shared buffer at once.
gym_config – Gym environment configuration dictionary (already loaded, not a file path).
max_episode_steps – Maximum number of timesteps per episode.
refill_threshold – Total number of samples (refill_threshold * buffer_size) drawn from the shared buffer before a refill is triggered.
state_dim – Dimensionality of the state space.
- __init__(buffer_size=<factory>, max_episode_steps=<factory>, state_dim=<factory>, buffer_device=<factory>, gym_config=<factory>, action_config=<factory>, refill_threshold=<factory>)#
- action_config: dict#
  Action configuration dictionary. The contents depend on the specific environment and robot being used.
- buffer_device: str#
  Device on which the shared buffer is allocated.
- buffer_size: int#
  Number of episodes (environment trajectories) that can be stored in the shared buffer at once. Must be ≥ num_envs and ideally a multiple of num_envs.
- copy(**kwargs)#
Return a new object replacing specified fields with new values.
This is especially useful for frozen classes. Example usage:

    @configclass(frozen=True)
    class C:
        x: int
        y: int

    c = C(1, 2)
    c1 = c.replace(x=3)
    assert c1.x == 3 and c1.y == 2

- Parameters:
  obj (object) – The object to replace.
  **kwargs – The fields to replace and their new values.
- Return type:
  object
- Returns:
  The new object.
- gym_config: dict#
  Gym environment configuration dictionary (already loaded, not a file path). The contents depend on the specific environment being used. Default is None.
- max_episode_steps: int#
  Maximum number of timesteps per episode. Must be ≥ the chunk_size used by OnlineDataset.
- refill_threshold: int#
  Total number of samples (refill_threshold * buffer_size) drawn from the shared buffer before a refill is triggered. Accumulates across all calls to OnlineDataEngine.sample_batch(). When this threshold is exceeded, the engine signals the simulation subprocess to regenerate the entire buffer, amortising the cost of environment simulation over many training steps.
- replace(**kwargs)#
Return a new object replacing specified fields with new values.
This is especially useful for frozen classes. Example usage:

    @configclass(frozen=True)
    class C:
        x: int
        y: int

    c = C(1, 2)
    c1 = c.replace(x=3)
    assert c1.x == 3 and c1.y == 2

- Parameters:
  obj (object) – The object to replace.
  **kwargs – The fields to replace and their new values.
- Return type:
  object
- Returns:
  The new object.
- state_dim: int#
  Dimensionality of the state space.
- to_dict()#
Convert an object into dictionary recursively.
Note
Ignores all names starting with “__” (i.e. built-in methods).
- Parameters:
  obj (object) – An instance of a class to convert.
- Raises:
  ValueError – When the input argument is not an object.
- Return type:
  dict[str, Any]
- Returns:
  Converted dictionary mapping.
- validate(prefix='')#
Check the validity of a configclass object.

This function checks whether the object is a valid configclass object. A valid configclass object contains no MISSING entries.
- Parameters:
obj (object) – The object to check.
prefix (str) – The prefix to add to the missing fields. Defaults to ''.
- Return type:
  list[str]
- Returns:
  A list of missing fields.
- Raises:
TypeError – When the object is not a valid configuration object.
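The MISSING-entry check can be sketched with a plain dataclass and a sentinel value. The configclass machinery itself is assumed; every name below is an illustrative stand-in:

```python
# Sketch of a validate()-style check: walk the fields and report any
# still set to a MISSING sentinel. Not the real configclass implementation.
import dataclasses

MISSING = object()  # stand-in for the library's MISSING marker


@dataclasses.dataclass
class EngineCfgSketch:
    buffer_size: object = MISSING
    state_dim: object = MISSING
    buffer_device: str = "cpu"


def validate(obj, prefix=""):
    """Return the names of fields still set to MISSING, with optional prefix."""
    missing = []
    for f in dataclasses.fields(obj):
        if getattr(obj, f.name) is MISSING:
            missing.append(prefix + f.name)
    return missing


cfg = EngineCfgSketch(buffer_size=64)
print(validate(cfg, prefix="cfg."))  # ['cfg.state_dim']
```

Returning the prefixed field names (rather than raising immediately) lets callers collect every unset field across nested configs in one pass before reporting.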