embodichain.agents.rl#

Submodules

Algorithms#

Classes:

BaseAlgorithm

Base class for RL algorithms.

GRPO

Group Relative Policy Optimization on top of TensorDict rollouts.

GRPOCfg

Configuration for GRPO.

PPO

PPO algorithm consuming TensorDict rollouts.

PPOCfg

Configuration for the PPO algorithm.

Functions:

build_algo(name, cfg_kwargs, policy, device, *)

compute_gae(rollout, gamma, gae_lambda)

Compute GAE over a rollout stored as [num_envs, time + 1].

get_registered_algo_names()

class embodichain.agents.rl.algo.BaseAlgorithm[source]#

Bases: object

Base class for RL algorithms.

Algorithms only implement policy updates over collected rollouts.

Attributes:

device

Methods:

update(rollout)

Update policy using collected data and return training losses.

device: device#
update(rollout)[source]#

Update policy using collected data and return training losses.

Return type:

Dict[str, float]

class embodichain.agents.rl.algo.GRPO[source]#

Bases: BaseAlgorithm

Group Relative Policy Optimization on top of TensorDict rollouts.

Methods:

__init__(cfg, policy)

update(rollout)

Update policy using collected data and return training losses.

__init__(cfg, policy)[source]#
update(rollout)[source]#

Update policy using collected data and return training losses.

Return type:

Dict[str, float]

class embodichain.agents.rl.algo.GRPOCfg[source]#

Bases: AlgorithmCfg

Configuration for GRPO.

Methods:

__init__([device, learning_rate, ...])

copy(**kwargs)

Return a new object replacing specified fields with new values.

replace(**kwargs)

Return a new object replacing specified fields with new values.

to_dict()

Convert an object into dictionary recursively.

validate([prefix])

Check the validity of configclass object.

Attributes:

__init__(device=<factory>, learning_rate=<factory>, batch_size=<factory>, gamma=<factory>, gae_lambda=<factory>, max_grad_norm=<factory>, n_epochs=<factory>, clip_coef=<factory>, ent_coef=<factory>, kl_coef=<factory>, group_size=<factory>, eps=<factory>, reset_every_rollout=<factory>, truncate_at_first_done=<factory>)#
batch_size: int#
clip_coef: float#
copy(**kwargs)#

Return a new object replacing specified fields with new values.

This is especially useful for frozen classes. Example usage:

@configclass(frozen=True)
class C:
    x: int
    y: int

c = C(1, 2)
c1 = c.replace(x=3)
assert c1.x == 3 and c1.y == 2
Parameters:
  • obj (object) – The object to replace.

  • **kwargs – The fields to replace and their new values.

Return type:

object

Returns:

The new object.

device: str#
ent_coef: float#
eps: float#
gae_lambda: float#
gamma: float#
group_size: int#
kl_coef: float#
learning_rate: float#
max_grad_norm: float#
n_epochs: int#
replace(**kwargs)#

Return a new object replacing specified fields with new values.

This is especially useful for frozen classes. Example usage:

@configclass(frozen=True)
class C:
    x: int
    y: int

c = C(1, 2)
c1 = c.replace(x=3)
assert c1.x == 3 and c1.y == 2
Parameters:
  • obj (object) – The object to replace.

  • **kwargs – The fields to replace and their new values.

Return type:

object

Returns:

The new object.

reset_every_rollout: bool#
to_dict()#

Convert an object into dictionary recursively.

Note

Ignores all names starting with “__” (i.e. built-in methods).

Parameters:

obj (object) – An instance of a class to convert.

Raises:

ValueError – When input argument is not an object.

Return type:

dict[str, Any]

Returns:

Converted dictionary mapping.

truncate_at_first_done: bool#
validate(prefix='')#

Check the validity of configclass object.

This function checks if the object is a valid configclass object. A valid configclass object contains no MISSING entries.

Parameters:
  • obj (object) – The object to check.

  • prefix (str) – The prefix to add to the missing fields. Defaults to ‘’.

Return type:

list[str]

Returns:

A list of missing fields.

Raises:

TypeError – When the object is not a valid configuration object.

class embodichain.agents.rl.algo.PPO[source]#

Bases: BaseAlgorithm

PPO algorithm consuming TensorDict rollouts.

Methods:

__init__(cfg, policy)

update(rollout)

Update the policy using a collected rollout.

__init__(cfg, policy)[source]#
update(rollout)[source]#

Update the policy using a collected rollout.

Return type:

Dict[str, float]

class embodichain.agents.rl.algo.PPOCfg[source]#

Bases: AlgorithmCfg

Configuration for the PPO algorithm.

Methods:

__init__([device, learning_rate, ...])

copy(**kwargs)

Return a new object replacing specified fields with new values.

replace(**kwargs)

Return a new object replacing specified fields with new values.

to_dict()

Convert an object into dictionary recursively.

validate([prefix])

Check the validity of configclass object.

Attributes:

__init__(device=<factory>, learning_rate=<factory>, batch_size=<factory>, gamma=<factory>, gae_lambda=<factory>, max_grad_norm=<factory>, n_epochs=<factory>, clip_coef=<factory>, ent_coef=<factory>, vf_coef=<factory>)#
batch_size: int#
clip_coef: float#
copy(**kwargs)#

Return a new object replacing specified fields with new values.

This is especially useful for frozen classes. Example usage:

@configclass(frozen=True)
class C:
    x: int
    y: int

c = C(1, 2)
c1 = c.replace(x=3)
assert c1.x == 3 and c1.y == 2
Parameters:
  • obj (object) – The object to replace.

  • **kwargs – The fields to replace and their new values.

Return type:

object

Returns:

The new object.

device: str#
ent_coef: float#
gae_lambda: float#
gamma: float#
learning_rate: float#
max_grad_norm: float#
n_epochs: int#
replace(**kwargs)#

Return a new object replacing specified fields with new values.

This is especially useful for frozen classes. Example usage:

@configclass(frozen=True)
class C:
    x: int
    y: int

c = C(1, 2)
c1 = c.replace(x=3)
assert c1.x == 3 and c1.y == 2
Parameters:
  • obj (object) – The object to replace.

  • **kwargs – The fields to replace and their new values.

Return type:

object

Returns:

The new object.

to_dict()#

Convert an object into dictionary recursively.

Note

Ignores all names starting with “__” (i.e. built-in methods).

Parameters:

obj (object) – An instance of a class to convert.

Raises:

ValueError – When input argument is not an object.

Return type:

dict[str, Any]

Returns:

Converted dictionary mapping.

validate(prefix='')#

Check the validity of configclass object.

This function checks if the object is a valid configclass object. A valid configclass object contains no MISSING entries.

Parameters:
  • obj (object) – The object to check.

  • prefix (str) – The prefix to add to the missing fields. Defaults to ‘’.

Return type:

list[str]

Returns:

A list of missing fields.

Raises:

TypeError – When the object is not a valid configuration object.

vf_coef: float#
embodichain.agents.rl.algo.build_algo(name, cfg_kwargs, policy, device, *, distributed=False)[source]#
embodichain.agents.rl.algo.compute_gae(rollout, gamma, gae_lambda)[source]#

Compute GAE over a rollout stored as [num_envs, time + 1].

Parameters:
  • rollout (TensorDict) – Rollout TensorDict where value[:, -1] stores the bootstrap value for the final observation and transition-only fields reserve their last slot as padding.

  • gamma (float) – Discount factor.

  • gae_lambda (float) – GAE lambda coefficient.

Return type:

tuple[Tensor, Tensor]

Returns:

Tuple of (advantages, returns), both shaped [num_envs, time].
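The layout above can be made concrete with a standalone sketch. This is an illustrative GAE recursion over the same `[num_envs, time + 1]` value layout, written with plain numpy arrays rather than a TensorDict; the function name and exact loop structure are assumptions, not the library's implementation.

```python
import numpy as np

def compute_gae_sketch(values, rewards, dones, gamma=0.99, gae_lambda=0.95):
    """Illustrative GAE sketch over a [num_envs, time + 1] rollout layout.

    values: [num_envs, time + 1], where values[:, -1] is the bootstrap value.
    rewards, dones: [num_envs, time] (the padded final slot already dropped).
    """
    num_envs, time = rewards.shape
    advantages = np.zeros((num_envs, time))
    last_gae = np.zeros(num_envs)
    # Classic backward recursion: delta_t + gamma * lambda * A_{t+1}.
    for t in reversed(range(time)):
        not_done = 1.0 - dones[:, t]
        delta = rewards[:, t] + gamma * values[:, t + 1] * not_done - values[:, t]
        last_gae = delta + gamma * gae_lambda * not_done * last_gae
        advantages[:, t] = last_gae
    returns = advantages + values[:, :time]
    return advantages, returns
```

With `gamma = gae_lambda = 1` and zero values, advantages reduce to reward-to-go, which is a quick sanity check on the recursion.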

embodichain.agents.rl.algo.get_registered_algo_names()[source]#
Return type:

list[str]

Rollout Buffer#

Classes:

RolloutBuffer

Single-rollout buffer backed by a preallocated TensorDict.

Functions:

iterate_minibatches(rollout, batch_size, device)

Yield shuffled minibatches from a flattened rollout.

transition_view(rollout[, flatten])

Build a transition-aligned TensorDict from a rollout.

class embodichain.agents.rl.buffer.RolloutBuffer[source]#

Bases: object

Single-rollout buffer backed by a preallocated TensorDict.

The shared rollout uses a uniform [num_envs, time + 1] layout. For transition-only fields such as action, reward, and done, the final time index is reused as padding so the collector, environment, and algorithms can share a single TensorDict batch shape.

Methods:

__init__(num_envs, rollout_len, obs_dim, ...)

add(rollout)

Mark the shared rollout as ready for consumption.

get([flatten])

Return the stored rollout and clear the buffer.

is_full()

Return whether a rollout is waiting to be consumed.

start_rollout()

Return the shared rollout TensorDict for collector write-in.

Attributes:

__init__(num_envs, rollout_len, obs_dim, action_dim, device)[source]#
add(rollout)[source]#

Mark the shared rollout as ready for consumption.

Return type:

None

property buffer: TensorDict#
get(flatten=True)[source]#

Return the stored rollout and clear the buffer.

When flatten is True, the rollout is first converted to a transition view that drops the padded final slot from transition-only fields.

Return type:

TensorDict

is_full()[source]#

Return whether a rollout is waiting to be consumed.

Return type:

bool

start_rollout()[source]#

Return the shared rollout TensorDict for collector write-in.

Return type:

TensorDict

embodichain.agents.rl.buffer.iterate_minibatches(rollout, batch_size, device)[source]#

Yield shuffled minibatches from a flattened rollout.

Return type:

Iterator[TensorDict]
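The shuffling behavior can be sketched with a plain dict of arrays standing in for the flattened TensorDict; the helper name and the dict-based rollout are illustrative assumptions, not the real signature.

```python
import numpy as np

def iterate_minibatches_sketch(rollout, batch_size, rng=None):
    """Yield shuffled minibatches from a flattened rollout.

    `rollout` is modeled as a dict of arrays sharing a leading batch
    dimension (the real helper consumes a TensorDict on a device).
    """
    rng = rng or np.random.default_rng(0)
    n = next(iter(rollout.values())).shape[0]
    perm = rng.permutation(n)  # one shuffle per epoch over all transitions
    for start in range(0, n, batch_size):
        idx = perm[start:start + batch_size]
        yield {key: value[idx] for key, value in rollout.items()}
```

Note the final minibatch may be smaller than `batch_size` when the rollout length is not divisible by it.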

embodichain.agents.rl.buffer.transition_view(rollout, flatten=False)[source]#

Build a transition-aligned TensorDict from a rollout.

The shared rollout uses a uniform [num_envs, time + 1] layout. For transition-only fields such as action, reward, and done, the final slot is reserved as padding so that all rollout fields share the same batch shape. This helper drops that padded slot and exposes the valid transition slices as a TensorDict with batch shape [num_envs, time].

Parameters:
  • rollout (TensorDict) – Rollout TensorDict with root batch shape [num_envs, time + 1].

  • flatten (bool) – If True, return a flattened [num_envs * time] view.

Return type:

TensorDict

Returns:

TensorDict containing transition-aligned fields.
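The padded layout and the two views it produces can be sketched with plain numpy arrays; a TensorDict behaves analogously, and the field names here follow the description above.

```python
import numpy as np

num_envs, time = 2, 5

# Shared rollout layout: every field has leading shape [num_envs, time + 1].
# Observation-like fields use all time + 1 slots; transition-only fields
# (action, reward, done) reserve the final slot as padding.
rollout = {
    "obs": np.zeros((num_envs, time + 1, 8)),
    "action": np.zeros((num_envs, time + 1, 3)),
    "reward": np.zeros((num_envs, time + 1)),
}

# Transition view: drop the padded final slot -> batch shape [num_envs, time].
view = {key: value[:, :time] for key, value in rollout.items()}
assert view["reward"].shape == (num_envs, time)

# Flattened view (flatten=True): merge env and time axes -> [num_envs * time, ...].
flat = {k: v.reshape(num_envs * time, *v.shape[2:]) for k, v in view.items()}
assert flat["obs"].shape == (num_envs * time, 8)
```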

Policy Models#

Classes:

ActorCritic

Actor-Critic with learnable log_std for Gaussian policy.

ActorOnly

Actor-only policy for algorithms that do not use a value function (e.g., GRPO).

MLP

General MLP supporting custom last activation, orthogonal init, and output reshape.

Policy

Abstract base class that all RL policies must implement.

Functions:

build_mlp_from_cfg(module_cfg, in_dim, out_dim)

Construct an MLP module from a minimal json-like config.

build_policy(policy_block, obs_space, ...[, ...])

Build a policy from config using spaces for extensibility.

get_policy_class(name)

get_registered_policy_names()

register_policy(name, policy_cls)

class embodichain.agents.rl.models.ActorCritic[source]#

Bases: Policy

Actor-Critic with learnable log_std for Gaussian policy.

This is a placeholder implementation of the Policy interface that:

  • Encapsulates the MLP networks (actor + critic) that need to be trained by RL algorithms.

  • Handles internal computation: MLP output → mean + learnable log_std → Normal distribution.

  • Provides a uniform interface for RL algorithms (PPO, SAC, etc.).

This allows seamless swapping with other policy implementations (e.g., VLAPolicy) without modifying RL algorithm code.

Implements TensorDict-native interfaces while preserving get_action() compatibility for evaluation and legacy call-sites.

Methods:

__init__(obs_dim, action_dim, device, actor, ...)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

evaluate_actions(tensordict)

Evaluate actions and return current policy outputs.

forward(tensordict[, deterministic])

Write sampled actions and value estimates into the TensorDict.

get_value(tensordict)

Write value estimate for the given observations into the TensorDict.

__init__(obs_dim, action_dim, device, actor, critic)[source]#

Initialize internal Module state, shared by both nn.Module and ScriptModule.

evaluate_actions(tensordict)[source]#

Evaluate actions and return current policy outputs.

Parameters:

tensordict (TensorDict) – TensorDict containing obs and action.

Return type:

TensorDict

Returns:

A new TensorDict containing sample_log_prob, entropy, and value.

forward(tensordict, deterministic=False)[source]#

Write sampled actions and value estimates into the TensorDict.

Return type:

TensorDict

get_value(tensordict)[source]#

Write value estimate for the given observations into the TensorDict.

Parameters:

tensordict (TensorDict) – Input TensorDict containing obs.

Return type:

TensorDict

Returns:

TensorDict with value populated.

class embodichain.agents.rl.models.ActorOnly[source]#

Bases: Policy

Actor-only policy for algorithms that do not use a value function (e.g., GRPO).

Same interface as ActorCritic: get_action and evaluate_actions return (action, log_prob, value), but value is always zeros since no critic is used.

Methods:

__init__(obs_dim, action_dim, device, actor)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

evaluate_actions(tensordict)

Evaluate actions and return current policy outputs.

forward(tensordict[, deterministic])

Write sampled actions and value estimates into the TensorDict.

get_value(tensordict)

Write value estimate for the given observations into the TensorDict.

__init__(obs_dim, action_dim, device, actor)[source]#

Initialize internal Module state, shared by both nn.Module and ScriptModule.

evaluate_actions(tensordict)[source]#

Evaluate actions and return current policy outputs.

Parameters:

tensordict (TensorDict) – TensorDict containing obs and action.

Return type:

TensorDict

Returns:

A new TensorDict containing sample_log_prob, entropy, and value.

forward(tensordict, deterministic=False)[source]#

Write sampled actions and value estimates into the TensorDict.

Return type:

TensorDict

get_value(tensordict)[source]#

Write value estimate for the given observations into the TensorDict.

Parameters:

tensordict (TensorDict) – Input TensorDict containing obs.

Return type:

TensorDict

Returns:

TensorDict with value populated.

class embodichain.agents.rl.models.MLP[source]#

Bases: Sequential

General MLP supporting custom last activation, orthogonal init, and output reshape.

Parameters:
  • input_dim – input dimension

  • output_dim – output dimension (int or shape tuple/list)

  • hidden_dims – hidden layer sizes, e.g. [256, 256]

  • activation – hidden layer activation name (relu/elu/tanh/gelu/silu)

  • last_activation – last-layer activation name, or None for a linear last layer

  • use_layernorm – whether to add LayerNorm after each hidden linear layer

  • dropout_p – dropout probability for hidden layers (0 disables)

Methods:

__init__(input_dim, output_dim, hidden_dims)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

init_orthogonal([scales])

Orthogonal-initialize linear layers and zero the bias.

__init__(input_dim, output_dim, hidden_dims, activation='elu', last_activation=None, use_layernorm=False, dropout_p=0.0)[source]#

Initialize internal Module state, shared by both nn.Module and ScriptModule.

init_orthogonal(scales=1.0)[source]#

Orthogonal-initialize linear layers and zero the bias.

scales: single gain value or a sequence with length equal to the number of linear layers.

Return type:

None
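The shape bookkeeping implied by `MLP(input_dim, output_dim, hidden_dims)` can be sketched without torch; this numpy version is an illustration of the layer layout only (ELU hidden activations, linear last layer), not the module's implementation.

```python
import numpy as np

rng = np.random.default_rng(0)

def make_mlp(input_dim, output_dim, hidden_dims):
    """Allocate weight/bias pairs for the stack [input_dim, *hidden_dims, output_dim]."""
    dims = [input_dim, *hidden_dims, output_dim]
    weights = [rng.standard_normal((a, b)) * 0.1 for a, b in zip(dims[:-1], dims[1:])]
    biases = [np.zeros(b) for b in dims[1:]]
    return weights, biases

def mlp_forward(x, weights, biases):
    """ELU on hidden layers; linear last layer (last_activation=None)."""
    def elu(z):
        return np.where(z > 0, z, np.exp(z) - 1.0)
    for w, b in zip(weights[:-1], biases[:-1]):
        x = elu(x @ w + b)
    return x @ weights[-1] + biases[-1]
```

For example, `make_mlp(8, 3, [256, 256])` yields three linear layers mapping 8 → 256 → 256 → 3, matching the `hidden_dims` convention above.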

class embodichain.agents.rl.models.Policy[source]#

Bases: Module, ABC

Abstract base class that all RL policies must implement.

A Policy:

  • Encapsulates neural networks that are trained by RL algorithms.

  • Handles internal computations (e.g., network output → distribution).

  • Provides a uniform interface for algorithms (PPO, SAC, etc.).

Methods:

__init__()

Initialize internal Module state, shared by both nn.Module and ScriptModule.

evaluate_actions(tensordict)

Evaluate actions and return current policy outputs.

forward(tensordict[, deterministic])

Write sampled actions and value estimates into the TensorDict.

get_action(tensordict[, deterministic])

Sample actions into the provided TensorDict without gradients.

get_value(tensordict)

Write value estimate for the given observations into the TensorDict.

Attributes:

device

Device where the policy parameters are located.

__init__()[source]#

Initialize internal Module state, shared by both nn.Module and ScriptModule.

device: torch.device#

Device where the policy parameters are located.

abstract evaluate_actions(tensordict)[source]#

Evaluate actions and return current policy outputs.

Parameters:

tensordict (TensorDict) – TensorDict containing obs and action.

Return type:

TensorDict

Returns:

A new TensorDict containing sample_log_prob, entropy, and value.

abstract forward(tensordict, deterministic=False)[source]#

Write sampled actions and value estimates into the TensorDict.

Return type:

TensorDict

get_action(tensordict, deterministic=False)[source]#

Sample actions into the provided TensorDict without gradients.

Parameters:
  • tensordict (TensorDict) – Input TensorDict containing obs.

  • deterministic (bool) – If True, return the mean action; otherwise, sample from the policy distribution.

Return type:

TensorDict

Returns:

TensorDict with action, sample_log_prob, and value populated.

abstract get_value(tensordict)[source]#

Write value estimate for the given observations into the TensorDict.

Parameters:

tensordict (TensorDict) – Input TensorDict containing obs.

Return type:

TensorDict

Returns:

TensorDict with value populated.

training: bool#
embodichain.agents.rl.models.build_mlp_from_cfg(module_cfg, in_dim, out_dim)[source]#

Construct an MLP module from a minimal json-like config.

Return type:

MLP

Expected schema:

module_cfg = {
    "type": "mlp",
    "hidden_sizes": [256, 256],
    "activation": "relu",
}

embodichain.agents.rl.models.build_policy(policy_block, obs_space, action_space, device, actor=None, critic=None)[source]#

Build a policy from config using spaces for extensibility.

Built-in MLP policies still resolve flattened obs_dim / action_dim, while custom policies may accept richer obs_space / action_space inputs.

Return type:

Policy

embodichain.agents.rl.models.get_policy_class(name)[source]#
Return type:

Optional[Type[Policy]]

embodichain.agents.rl.models.get_registered_policy_names()[source]#
Return type:

list[str]

embodichain.agents.rl.models.register_policy(name, policy_cls)[source]#
Return type:

None
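The trio `register_policy` / `get_policy_class` / `get_registered_policy_names` follows a standard name-to-class registry pattern. This is a minimal self-contained sketch of that pattern; the `_sketch` names and the `MyPolicy` class are hypothetical, and the real functions operate on `Policy` subclasses.

```python
# Module-level registry mapping policy names to classes.
_POLICY_REGISTRY: dict[str, type] = {}

def register_policy_sketch(name: str, policy_cls: type) -> None:
    """Register a policy class under a string name."""
    _POLICY_REGISTRY[name] = policy_cls

def get_policy_class_sketch(name: str):
    """Look up a registered class; None when the name is unknown."""
    return _POLICY_REGISTRY.get(name)

def get_registered_policy_names_sketch() -> list[str]:
    """List registered names (sorted here for determinism)."""
    return sorted(_POLICY_REGISTRY)

class MyPolicy:  # hypothetical custom policy class
    pass

register_policy_sketch("my_policy", MyPolicy)
```

Registering a custom class this way is what lets `build_policy` resolve it later from a config string.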

Training#

Functions:

main()

Main entry point for command-line training.

parse_args()

Parse command line arguments.

train_from_config(config_path[, distributed])

Run training from a config file path.

embodichain.agents.rl.train.main()[source]#

Main entry point for command-line training.

embodichain.agents.rl.train.parse_args()[source]#

Parse command line arguments.

embodichain.agents.rl.train.train_from_config(config_path, distributed=None)[source]#

Run training from a config file path.

Parameters:
  • config_path (str) – Path to the JSON config file

  • distributed (Optional[bool]) – If True, run multi-GPU distributed training. If None, use trainer.distributed from config.

Utilities#

Classes:

AlgorithmCfg

Minimal algorithm configuration shared across RL algorithms.

Functions:

dict_to_tensordict(obs_dict, device)

Convert an environment observation mapping into a TensorDict.

flatten_dict_observation(obs)

Flatten a hierarchical observation TensorDict into a 2D tensor.

class embodichain.agents.rl.utils.AlgorithmCfg[source]#

Bases: object

Minimal algorithm configuration shared across RL algorithms.

Methods:

__init__([device, learning_rate, ...])

copy(**kwargs)

Return a new object replacing specified fields with new values.

replace(**kwargs)

Return a new object replacing specified fields with new values.

to_dict()

Convert an object into dictionary recursively.

validate([prefix])

Check the validity of configclass object.

Attributes:

__init__(device=<factory>, learning_rate=<factory>, batch_size=<factory>, gamma=<factory>, gae_lambda=<factory>, max_grad_norm=<factory>)#
batch_size: int#
copy(**kwargs)#

Return a new object replacing specified fields with new values.

This is especially useful for frozen classes. Example usage:

@configclass(frozen=True)
class C:
    x: int
    y: int

c = C(1, 2)
c1 = c.replace(x=3)
assert c1.x == 3 and c1.y == 2
Parameters:
  • obj (object) – The object to replace.

  • **kwargs – The fields to replace and their new values.

Return type:

object

Returns:

The new object.

device: str#
gae_lambda: float#
gamma: float#
learning_rate: float#
max_grad_norm: float#
replace(**kwargs)#

Return a new object replacing specified fields with new values.

This is especially useful for frozen classes. Example usage:

@configclass(frozen=True)
class C:
    x: int
    y: int

c = C(1, 2)
c1 = c.replace(x=3)
assert c1.x == 3 and c1.y == 2
Parameters:
  • obj (object) – The object to replace.

  • **kwargs – The fields to replace and their new values.

Return type:

object

Returns:

The new object.

to_dict()#

Convert an object into dictionary recursively.

Note

Ignores all names starting with “__” (i.e. built-in methods).

Parameters:

obj (object) – An instance of a class to convert.

Raises:

ValueError – When input argument is not an object.

Return type:

dict[str, Any]

Returns:

Converted dictionary mapping.

validate(prefix='')#

Check the validity of configclass object.

This function checks if the object is a valid configclass object. A valid configclass object contains no MISSING entries.

Parameters:
  • obj (object) – The object to check.

  • prefix (str) – The prefix to add to the missing fields. Defaults to ‘’.

Return type:

list[str]

Returns:

A list of missing fields.

Raises:

TypeError – When the object is not a valid configuration object.

embodichain.agents.rl.utils.dict_to_tensordict(obs_dict, device)[source]#

Convert an environment observation mapping into a TensorDict.

Parameters:
  • obs_dict (TensorDict | Mapping[str, Any]) – Environment observation returned by reset() or step().

  • device (device | str) – Target device for the resulting TensorDict.

Return type:

TensorDict

Returns:

Observation TensorDict moved onto the target device.

embodichain.agents.rl.utils.flatten_dict_observation(obs)[source]#

Flatten a hierarchical observation TensorDict into a 2D tensor.

Parameters:

obs (TensorDict) – Observation TensorDict with batch dimension [num_envs].

Return type:

Tensor

Returns:

Flattened observation tensor of shape [num_envs, obs_dim].
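The flattening described above can be sketched with a nested dict of numpy arrays standing in for the observation TensorDict. The sorted-key traversal order is an assumption made for determinism in this sketch; the real helper's ordering follows the TensorDict structure.

```python
import numpy as np

def flatten_dict_observation_sketch(obs, num_envs):
    """Flatten a (possibly nested) dict of arrays with leading dim [num_envs]
    into a single [num_envs, obs_dim] array."""
    parts = []
    for key in sorted(obs):  # deterministic key order (sketch-only assumption)
        value = obs[key]
        if isinstance(value, dict):
            parts.append(flatten_dict_observation_sketch(value, num_envs))
        else:
            # Collapse all trailing dims into one feature axis.
            parts.append(np.asarray(value).reshape(num_envs, -1))
    return np.concatenate(parts, axis=1)
```

For example, an observation with a `[2, 3]` proprioception field and a nested `[2, 2, 2]` sensor field flattens to shape `[2, 7]`.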