embodichain.agents.rl.collector

Overview

Collectors interact with vectorized environments and assemble the collected rollout data into a preallocated TensorDict layout.
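The following minimal sketch illustrates what a "preallocated layout" means. It assumes that a rollout is a set of per-field buffers indexed by `[step][env]`, allocated once and then filled in place. Nested Python lists stand in for the tensors a TensorDict would hold; in the tensordict library the analogous container would be a `TensorDict` with `batch_size=[num_steps, num_envs]`. The field names `obs`, `action`, `reward`, `done` and the helpers `preallocate_rollout` and `write_step` are hypothetical and are not part of embodichain.

```python
# Library-free sketch of a preallocated rollout layout. Nested lists of shape
# [num_steps][num_envs] stand in for TensorDict tensors; field names are hypothetical.

ROLLOUT_KEYS = ("obs", "action", "reward", "done")


def preallocate_rollout(num_steps, num_envs, keys=ROLLOUT_KEYS):
    """Allocate every buffer once, before any environment step runs."""
    return {key: [[None] * num_envs for _ in range(num_steps)] for key in keys}


def write_step(rollout, t, step_data):
    """Write one vectorized step (one value per env) into time index t, in place."""
    for key, values in step_data.items():
        row = rollout[key][t]
        if len(values) != len(row):
            raise ValueError(f"{key}: expected {len(row)} values, got {len(values)}")
        row[:] = values  # fill the existing row rather than allocating a new one


# Usage: 3 steps across 2 parallel environments.
rollout = preallocate_rollout(num_steps=3, num_envs=2)
for t in range(3):
    write_step(rollout, t, {
        "obs": [t, t],
        "action": [0, 1],
        "reward": [1.0, 0.5],
        "done": [False, t == 2],
    })
print(rollout["done"])  # [[False, False], [False, False], [False, True]]
```

Because storage is allocated once, the collection loop only writes into existing buffers. This is the usual motivation for a preallocated layout.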

Classes

BaseCollector: Base class for rollout collectors.

SyncCollector: Synchronously collect rollouts from a vectorized environment.

BaseCollector

class embodichain.agents.rl.collector.BaseCollector

Bases: ABC

Base class for rollout collectors.

Methods:

collect(num_steps[, rollout, on_step_callback]): Collect a rollout and return it as a TensorDict.

abstract collect(num_steps, rollout=None, on_step_callback=None)

Collect a rollout and return it as a TensorDict.

Return type:

TensorDict
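This sketch shows how a concrete collector could subclass the abstract interface above. So that it runs without embodichain installed, it defines a local stand-in `BaseCollector` with the documented `collect(num_steps, rollout=None, on_step_callback=None)` signature. It assumes the real base class needs no constructor arguments. `CountingCollector` is hypothetical. Two behaviors are assumptions rather than documented semantics: a caller-supplied `rollout` is filled in place, and `on_step_callback` is called as `(step_index, rollout)`. The real method returns a TensorDict, whereas this stand-in returns a plain dict.

```python
from abc import ABC, abstractmethod

# Local stand-in mirroring the documented interface so this sketch runs on its own.
# With embodichain installed you would instead use:
#   from embodichain.agents.rl.collector import BaseCollector


class BaseCollector(ABC):
    @abstractmethod
    def collect(self, num_steps, rollout=None, on_step_callback=None):
        """Collect a rollout and return it (a TensorDict in the real class)."""


class CountingCollector(BaseCollector):
    """Hypothetical subclass: each env observes a shared global step counter."""

    def __init__(self, num_envs):
        self.num_envs = num_envs
        self.t = 0  # persists across collect() calls

    def collect(self, num_steps, rollout=None, on_step_callback=None):
        # Assumption: a caller-supplied rollout is filled in place; otherwise allocate one.
        if rollout is None:
            rollout = {"obs": [[None] * self.num_envs for _ in range(num_steps)]}
        for step in range(num_steps):
            rollout["obs"][step][:] = [self.t] * self.num_envs
            self.t += 1
            # Assumption: the callback receives (step_index, rollout).
            if on_step_callback is not None:
                on_step_callback(step, rollout)
        return rollout


seen = []
collector = CountingCollector(num_envs=2)
first = collector.collect(3, on_step_callback=lambda step, ro: seen.append(step))
print(first["obs"], seen)  # [[0, 0], [1, 1], [2, 2]] [0, 1, 2]

buffer = {"obs": [[None, None], [None, None]]}
second = collector.collect(2, rollout=buffer)
print(second is buffer, second["obs"])  # True [[3, 3], [4, 4]]
```

Because `collect` is declared abstract, `BaseCollector` itself cannot be instantiated. Only subclasses that override `collect` can.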

SyncCollector

class embodichain.agents.rl.collector.SyncCollector

Bases: BaseCollector

Synchronously collect rollouts from a vectorized environment.

Methods:

__init__(env, policy, device[, ...])

collect(num_steps[, rollout, on_step_callback]): Collect a rollout and return it as a TensorDict.

__init__(env, policy, device, reset_every_rollout=False)

collect(num_steps, rollout=None, on_step_callback=None)

Collect a rollout and return it as a TensorDict.

Return type:

TensorDict
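The sketch below shows how a synchronous collection loop could use the documented constructor arguments (`env`, `policy`, `device`, `reset_every_rollout`) and `collect` parameters. It is an illustration, not the embodichain implementation. `ToyVecEnv` and `ToySyncCollector` are hypothetical, and several details are assumptions:

- The env exposes `reset()` and `step(actions)` returning `(obs, rewards, dones)`, with auto-reset of finished envs.
- The policy is a plain callable from observations to actions.
- `reset_every_rollout=True` resets the env at the start of every `collect` call. This is inferred from the parameter name.
- The callback receives `(step_index, rollout)`.

Lists stand in for tensors, so `device` is stored but unused.

```python
# Hypothetical stand-ins: ToyVecEnv, ToySyncCollector and the callback signature
# are illustrations, not embodichain APIs.

ROLLOUT_KEYS = ("obs", "action", "reward", "done")


class ToyVecEnv:
    """Vectorized toy env: each env accumulates its actions; an episode ends at 3."""

    def __init__(self, num_envs):
        self.num_envs = num_envs
        self.state = [0] * num_envs
        self.reset_calls = 0

    def reset(self):
        self.reset_calls += 1
        self.state = [0] * self.num_envs
        return list(self.state)

    def step(self, actions):
        self.state = [s + a for s, a in zip(self.state, actions)]
        rewards = [float(a) for a in actions]
        dones = [s >= 3 for s in self.state]
        # Finished envs auto-reset, so the returned obs starts the next episode.
        self.state = [0 if d else s for s, d in zip(self.state, dones)]
        return list(self.state), rewards, dones


class ToySyncCollector:
    """Synchronous collection loop: one step() call advances every env at once."""

    def __init__(self, env, policy, device=None, reset_every_rollout=False):
        self.env = env
        self.policy = policy
        self.device = device  # unused; kept only to mirror the documented signature
        self.reset_every_rollout = reset_every_rollout
        self._obs = None  # last observation, carried across collect() calls

    def collect(self, num_steps, rollout=None, on_step_callback=None):
        n = self.env.num_envs
        if rollout is None:
            rollout = {k: [[None] * n for _ in range(num_steps)] for k in ROLLOUT_KEYS}
        # Assumed semantics: reset on every call if reset_every_rollout, else only once.
        if self.reset_every_rollout or self._obs is None:
            self._obs = self.env.reset()
        for t in range(num_steps):
            actions = self.policy(self._obs)
            next_obs, rewards, dones = self.env.step(actions)
            step_data = {"obs": self._obs, "action": actions, "reward": rewards, "done": dones}
            for key, values in step_data.items():
                rollout[key][t][:] = values  # write into the preallocated row
            self._obs = next_obs
            if on_step_callback is not None:
                on_step_callback(t, rollout)
        return rollout


def always_one(obs):
    """Trivial policy: action 1 for every env."""
    return [1] * len(obs)


env = ToyVecEnv(num_envs=2)
collector = ToySyncCollector(env, always_one)
first = collector.collect(num_steps=4)
second = collector.collect(num_steps=2)
print(first["done"])  # [[False, False], [False, False], [True, True], [False, False]]
print(second["obs"], env.reset_calls)  # [[1, 1], [2, 2]] 1

env_r = ToyVecEnv(num_envs=2)
collector_r = ToySyncCollector(env_r, always_one, reset_every_rollout=True)
collector_r.collect(num_steps=4)
second_r = collector_r.collect(num_steps=2)
print(second_r["obs"], env_r.reset_calls)  # [[0, 0], [1, 1]] 2
```

With the default `reset_every_rollout=False`, the second rollout continues from where the first one stopped. With `True`, every rollout starts from a fresh reset, at the cost of truncating in-progress episodes.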