embodichain.agents.rl#

Submodules

Algorithms#

Classes:

- BaseAlgorithm: Base class for RL algorithms.
- GRPO: Group Relative Policy Optimization on top of RolloutBuffer.
- GRPOCfg: Configuration for GRPO.
- PPO: PPO algorithm operating via Policy and RolloutBuffer (algo-agnostic design).
- PPOCfg: Configuration for the PPO algorithm.
- class embodichain.agents.rl.algo.BaseAlgorithm[source]#
  Bases: object
  Base class for RL algorithms.
  Algorithms must implement buffer initialization, rollout collection, and policy update. The Trainer depends only on this interface to remain algorithm-agnostic.
  Methods:
  - collect_rollout(env, policy, obs, num_steps): Collect trajectories and return logging info (e.g., reward components).
  - initialize_buffer(num_steps, num_envs, ...): Initialize internal buffer(s) required by the algorithm.
  - update(): Update policy using collected data and return training losses.
  Attributes:
  - collect_rollout(env, policy, obs, num_steps, on_step_callback=None)[source]#
    Collect trajectories and return logging info (e.g., reward components).
    Return type: Dict[str, Any]
  - device: torch.device#
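The contract above is what keeps the trainer algorithm-agnostic: it calls only the three abstract methods, never anything PPO- or GRPO-specific. A hypothetical sketch of that separation; the class and function names here (`BaseAlgorithmSketch`, `CountingAlgo`, `train`) are illustrative stand-ins, not the library's actual implementation:

```python
# Illustrative stubs showing the BaseAlgorithm contract; not library code.
from abc import ABC, abstractmethod
from typing import Any, Dict


class BaseAlgorithmSketch(ABC):
    @abstractmethod
    def initialize_buffer(self, num_steps: int, num_envs: int) -> None: ...

    @abstractmethod
    def collect_rollout(self, env, policy, obs, num_steps: int) -> Dict[str, Any]: ...

    @abstractmethod
    def update(self) -> Dict[str, float]: ...


class CountingAlgo(BaseAlgorithmSketch):
    """Toy subclass that only counts collected steps, to show the contract."""

    def __init__(self):
        self.steps_collected = 0

    def initialize_buffer(self, num_steps, num_envs):
        self.buffer = []  # a real algorithm would allocate a RolloutBuffer here

    def collect_rollout(self, env, policy, obs, num_steps):
        self.steps_collected += num_steps
        return {"mean_reward": 0.0}  # logging info, per the interface

    def update(self):
        return {"loss": 0.0}  # training losses, per the interface


def train(algo: BaseAlgorithmSketch, iterations: int, num_steps: int):
    # The trainer touches only the three interface methods.
    algo.initialize_buffer(num_steps, num_envs=4)
    logs = []
    for _ in range(iterations):
        logs.append(algo.collect_rollout(env=None, policy=None, obs=None, num_steps=num_steps))
        logs.append(algo.update())
    return logs
```

Any subclass satisfying the three methods can be dropped into the same loop.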
- class embodichain.agents.rl.algo.GRPO[source]#
  Bases: BaseAlgorithm
  Group Relative Policy Optimization on top of RolloutBuffer.
  Methods:
  - __init__(cfg, policy)
  - collect_rollout(env, policy, obs, num_steps): Collect trajectories and return logging info (e.g., reward components).
  - initialize_buffer(num_steps, num_envs, ...): Initialize internal buffer(s) required by the algorithm.
  - update(): Update policy using collected data and return training losses.
  - collect_rollout(env, policy, obs, num_steps, on_step_callback=None)[source]#
    Collect trajectories and return logging info (e.g., reward components).
    Return type: Dict[str, Any]
- class embodichain.agents.rl.algo.GRPOCfg[source]#
  Bases: AlgorithmCfg
  Configuration for GRPO.
  Methods:
  - __init__([device, learning_rate, ...])
  - copy(**kwargs): Return a new object replacing specified fields with new values.
  - replace(**kwargs): Return a new object replacing specified fields with new values.
  - to_dict(): Convert an object into a dictionary recursively.
  - validate([prefix]): Check the validity of a configclass object.
  Attributes:
  - __init__(device=<factory>, learning_rate=<factory>, batch_size=<factory>, gamma=<factory>, gae_lambda=<factory>, max_grad_norm=<factory>, n_epochs=<factory>, clip_coef=<factory>, ent_coef=<factory>, kl_coef=<factory>, group_size=<factory>, eps=<factory>, reset_every_rollout=<factory>, truncate_at_first_done=<factory>)#
  - batch_size: int#
  - clip_coef: float#
  - copy(**kwargs)#
    Return a new object replacing specified fields with new values.
    This is especially useful for frozen classes. Example usage:

      @configclass(frozen=True)
      class C:
          x: int
          y: int

      c = C(1, 2)
      c1 = c.replace(x=3)
      assert c1.x == 3 and c1.y == 2

    Parameters:
    - obj (object): The object to replace.
    - **kwargs: The fields to replace and their new values.
    Return type: object
    Returns: The new object.
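The copy()/replace() pattern mirrors the standard library's `dataclasses.replace()` for frozen dataclasses; the docstring example can be reproduced with stdlib tools alone (this is an analogue, not the library's `configclass`):

```python
# Stdlib analogue of the configclass copy/replace example above.
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class C:
    x: int
    y: int


c = C(1, 2)
c1 = replace(c, x=3)  # returns a new frozen instance; c is untouched
assert c1.x == 3 and c1.y == 2
assert (c.x, c.y) == (1, 2)
```

Returning a fresh object instead of mutating in place is what makes the pattern safe for frozen (immutable) config classes.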
  - device: str#
  - ent_coef: float#
  - eps: float#
  - gae_lambda: float#
  - gamma: float#
  - group_size: int#
  - kl_coef: float#
  - learning_rate: float#
  - max_grad_norm: float#
  - n_epochs: int#
  - replace(**kwargs)#
    Return a new object replacing specified fields with new values.
    This is especially useful for frozen classes. Example usage:

      @configclass(frozen=True)
      class C:
          x: int
          y: int

      c = C(1, 2)
      c1 = c.replace(x=3)
      assert c1.x == 3 and c1.y == 2

    Parameters:
    - obj (object): The object to replace.
    - **kwargs: The fields to replace and their new values.
    Return type: object
    Returns: The new object.
  - reset_every_rollout: bool#
  - to_dict()#
    Convert an object into a dictionary recursively.
    Note: ignores all names starting with "__" (i.e. built-in methods).
    Parameters:
    - obj (object): An instance of a class to convert.
    Raises: ValueError when the input argument is not an object.
    Return type: dict[str, Any]
    Returns: Converted dictionary mapping.
  - truncate_at_first_done: bool#
  - validate(prefix='')#
    Check the validity of a configclass object.
    This function checks whether the object is a valid configclass object. A valid configclass object contains no MISSING entries.
    Parameters:
    - obj (object): The object to check.
    - prefix (str): The prefix to add to the missing fields. Defaults to ''.
    Return type: list[str]
    Returns: A list of missing fields.
    Raises: TypeError when the object is not a valid configuration object.
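The `group_size` and `eps` fields above correspond to GRPO's defining step: normalizing returns within each group of rollouts, with `eps` stabilizing the division by the group standard deviation. How this GRPO implementation consumes them exactly is not shown on this page, so the following is the standard formulation as a pure-Python sketch, not the library's code:

```python
# Hedged sketch of group-relative advantage normalization (standard GRPO
# formulation, not this library's implementation).
import math
from typing import List


def group_relative_advantages(returns: List[float], group_size: int,
                              eps: float = 1e-8) -> List[float]:
    """Normalize returns within consecutive groups of `group_size`."""
    advantages: List[float] = []
    for start in range(0, len(returns), group_size):
        group = returns[start:start + group_size]
        mean = sum(group) / len(group)
        var = sum((r - mean) ** 2 for r in group) / len(group)
        std = math.sqrt(var)
        # (return - group mean) / (group std + eps)
        advantages.extend((r - mean) / (std + eps) for r in group)
    return advantages
```

Because advantages are computed relative to the group mean, no learned value function is needed, which is why GRPO pairs with the ActorOnly policy below.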
- class embodichain.agents.rl.algo.PPO[source]#
  Bases: BaseAlgorithm
  PPO algorithm operating via Policy and RolloutBuffer (algo-agnostic design).
  Methods:
  - __init__(cfg, policy)
  - collect_rollout(env, policy, obs, num_steps): Collect a rollout.
  - initialize_buffer(num_steps, num_envs, ...): Initialize the rollout buffer.
  - update(): Update the policy using the collected rollout buffer.
  - collect_rollout(env, policy, obs, num_steps, on_step_callback=None)[source]#
    Collect a rollout. The algorithm controls the data collection process.
    Return type: Dict[str, Any]
- class embodichain.agents.rl.algo.PPOCfg[source]#
  Bases: AlgorithmCfg
  Configuration for the PPO algorithm.
  Methods:
  - __init__([device, learning_rate, ...])
  - copy(**kwargs): Return a new object replacing specified fields with new values.
  - replace(**kwargs): Return a new object replacing specified fields with new values.
  - to_dict(): Convert an object into a dictionary recursively.
  - validate([prefix]): Check the validity of a configclass object.
  Attributes:
  - __init__(device=<factory>, learning_rate=<factory>, batch_size=<factory>, gamma=<factory>, gae_lambda=<factory>, max_grad_norm=<factory>, n_epochs=<factory>, clip_coef=<factory>, ent_coef=<factory>, vf_coef=<factory>)#
  - batch_size: int#
  - clip_coef: float#
  - copy(**kwargs)#
    Return a new object replacing specified fields with new values.
    This is especially useful for frozen classes. Example usage:

      @configclass(frozen=True)
      class C:
          x: int
          y: int

      c = C(1, 2)
      c1 = c.replace(x=3)
      assert c1.x == 3 and c1.y == 2

    Parameters:
    - obj (object): The object to replace.
    - **kwargs: The fields to replace and their new values.
    Return type: object
    Returns: The new object.
  - device: str#
  - ent_coef: float#
  - gae_lambda: float#
  - gamma: float#
  - learning_rate: float#
  - max_grad_norm: float#
  - n_epochs: int#
  - replace(**kwargs)#
    Return a new object replacing specified fields with new values.
    This is especially useful for frozen classes. Example usage:

      @configclass(frozen=True)
      class C:
          x: int
          y: int

      c = C(1, 2)
      c1 = c.replace(x=3)
      assert c1.x == 3 and c1.y == 2

    Parameters:
    - obj (object): The object to replace.
    - **kwargs: The fields to replace and their new values.
    Return type: object
    Returns: The new object.
  - to_dict()#
    Convert an object into a dictionary recursively.
    Note: ignores all names starting with "__" (i.e. built-in methods).
    Parameters:
    - obj (object): An instance of a class to convert.
    Raises: ValueError when the input argument is not an object.
    Return type: dict[str, Any]
    Returns: Converted dictionary mapping.
  - validate(prefix='')#
    Check the validity of a configclass object.
    This function checks whether the object is a valid configclass object. A valid configclass object contains no MISSING entries.
    Parameters:
    - obj (object): The object to check.
    - prefix (str): The prefix to add to the missing fields. Defaults to ''.
    Return type: list[str]
    Returns: A list of missing fields.
    Raises: TypeError when the object is not a valid configuration object.
  - vf_coef: float#
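The `gamma` and `gae_lambda` fields shared by both configs parameterize generalized advantage estimation (GAE), which on-policy updates of this kind typically apply to the buffered (rewards, values, dones). A minimal single-environment, pure-Python sketch of the standard recursion (not the library's buffer code):

```python
# Standard GAE recursion, sketched for one environment; the library's
# RolloutBuffer presumably vectorizes this over num_envs with tensors.
from typing import List


def compute_gae(rewards: List[float], values: List[float], dones: List[bool],
                last_value: float, gamma: float, gae_lambda: float) -> List[float]:
    advantages = [0.0] * len(rewards)
    gae = 0.0
    # Walk backwards so each step can reuse the accumulated future advantage.
    for t in reversed(range(len(rewards))):
        next_value = last_value if t == len(rewards) - 1 else values[t + 1]
        next_nonterminal = 0.0 if dones[t] else 1.0  # cut bootstrap at episode end
        delta = rewards[t] + gamma * next_value * next_nonterminal - values[t]
        gae = delta + gamma * gae_lambda * next_nonterminal * gae
        advantages[t] = gae
    return advantages
```

With `gamma = gae_lambda = 1.0` and zero values this reduces to reward-to-go, which is a quick sanity check for any implementation.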
Rollout Buffer#

Classes:

- RolloutBuffer: On-device rollout buffer for on-policy algorithms.

- class embodichain.agents.rl.buffer.RolloutBuffer[source]#
  Bases: object
  On-device rollout buffer for on-policy algorithms.
  Stores (obs, actions, rewards, dones, values, logprobs) over time. After finalize(), exposes advantages/returns and minibatch iteration.
  Methods:
  - __init__(num_steps, num_envs, obs_dim, ...)
  - add(obs, action, reward, done, value, logprob)
  - iterate_minibatches(batch_size)
  - set_extras(extras): Attach algorithm-specific tensors (shape [T, N, ...]) for batching.
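A buffer of T steps over N environments holds T*N samples, and minibatch iteration plausibly amounts to shuffling those sample indices and yielding them in `batch_size` slices. A pure-Python sketch of that bookkeeping (the real `iterate_minibatches` yields tensors; this version yields index lists, and the `shuffle`/`seed` parameters are illustrative additions):

```python
# Index-level sketch of minibatch iteration over T*N flattened samples;
# shuffle/seed are hypothetical parameters for this illustration.
import random
from typing import Iterator, List, Optional


def iterate_minibatches(num_samples: int, batch_size: int,
                        shuffle: bool = True,
                        seed: Optional[int] = None) -> Iterator[List[int]]:
    indices = list(range(num_samples))
    if shuffle:
        random.Random(seed).shuffle(indices)  # fresh permutation each epoch
    for start in range(0, num_samples, batch_size):
        yield indices[start:start + batch_size]  # last batch may be short
```

Each epoch visits every sample exactly once, which is what PPO-style multi-epoch updates rely on.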
Policy Models#

Classes:

- ActorCritic: Actor-Critic with learnable log_std for Gaussian policy.
- ActorOnly: Actor-only policy for algorithms that do not use a value function (e.g., GRPO).
- MLP: General MLP supporting custom last activation, orthogonal init, and output reshape.
- Policy: Abstract base class that all RL policies must implement.

Functions:

- build_mlp_from_cfg: Construct an MLP module from a minimal json-like config.
- Build policy strictly from a json-like block: { name: ..., cfg: {...} }
- class embodichain.agents.rl.models.ActorCritic[source]#
  Bases: Policy
  Actor-Critic with learnable log_std for Gaussian policy.
  This is a placeholder implementation of the Policy interface that:
  - Encapsulates MLP networks (actor + critic) that need to be trained by RL algorithms
  - Handles internal computation: MLP output → mean + learnable log_std → Normal distribution
  - Provides a uniform interface for RL algorithms (PPO, SAC, etc.)
  This allows seamless swapping with other policy implementations (e.g., VLAPolicy) without modifying RL algorithm code.
  Implements:
  - get_action(obs, deterministic=False) -> (action, log_prob, value)
  - get_value(obs)
  - evaluate_actions(obs, actions) -> (log_prob, entropy, value)
  Methods:
  - __init__(obs_space, action_space, device, ...): Initialize internal Module state, shared by both nn.Module and ScriptModule.
  - evaluate_actions(obs, actions): Evaluate actions and compute log probabilities, entropy, and values.
  - get_action(obs[, deterministic]): Sample an action from the policy.
  - get_value(obs): Get value estimate for given observations.
  - __init__(obs_space, action_space, device, actor, critic)[source]#
    Initialize internal Module state, shared by both nn.Module and ScriptModule.
  - evaluate_actions(obs, actions)[source]#
    Evaluate actions and compute log probabilities, entropy, and values.
    Parameters:
    - obs (Tensor): Observation tensor of shape (batch_size, obs_dim)
    - actions (Tensor): Action tensor of shape (batch_size, action_dim)
    Returns:
    - log_prob: Log probability of actions, shape (batch_size,)
    - entropy: Entropy of the action distribution, shape (batch_size,)
    - value: Value estimate, shape (batch_size,)
    Return type: Tuple of (log_prob, entropy, value)
  - get_action(obs, deterministic=False)[source]#
    Sample an action from the policy.
    Parameters:
    - obs (Tensor): Observation tensor of shape (batch_size, obs_dim)
    - deterministic (bool): If True, return the mean action; otherwise sample
    Returns:
    - action: Sampled action tensor of shape (batch_size, action_dim)
    - log_prob: Log probability of the action, shape (batch_size,)
    - value: Value estimate, shape (batch_size,)
    Return type: Tuple of (action, log_prob, value)
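The log_prob and entropy values that get_action/evaluate_actions return come from the diagonal Gaussian the class builds out of the actor's mean and the learnable log_std. Their closed forms, written out in pure Python for a single action vector (the library computes these with torch distributions; this is only the math):

```python
# Closed-form diagonal-Gaussian log-probability and entropy, as used by
# Gaussian policies with a learnable log_std; pure-Python illustration.
import math
from typing import List


def gaussian_log_prob(action: List[float], mean: List[float],
                      log_std: List[float]) -> float:
    # Sum over action dims of log N(a_i | mu_i, sigma_i).
    total = 0.0
    for a, mu, ls in zip(action, mean, log_std):
        sigma = math.exp(ls)
        total += -0.5 * ((a - mu) / sigma) ** 2 - ls - 0.5 * math.log(2 * math.pi)
    return total


def gaussian_entropy(log_std: List[float]) -> float:
    # Sum over dims of 0.5 * log(2*pi*e*sigma^2) = 0.5*(1 + log(2*pi)) + log_std.
    return sum(0.5 * (1.0 + math.log(2 * math.pi)) + ls for ls in log_std)
```

Note that entropy depends only on log_std, which is why an entropy bonus (ent_coef) directly pressures the learnable log_std upward.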
- class embodichain.agents.rl.models.ActorOnly[source]#
  Bases: Policy
  Actor-only policy for algorithms that do not use a value function (e.g., GRPO).
  Same interface as ActorCritic: get_action and evaluate_actions return (action, log_prob, value), but value is always zeros since no critic is used.
  Methods:
  - __init__(obs_space, action_space, device, actor): Initialize internal Module state, shared by both nn.Module and ScriptModule.
  - evaluate_actions(obs, actions): Evaluate actions and compute log probabilities, entropy, and values.
  - get_action(obs[, deterministic]): Sample an action from the policy.
  - get_value(obs): Get value estimate for given observations.
  - __init__(obs_space, action_space, device, actor)[source]#
    Initialize internal Module state, shared by both nn.Module and ScriptModule.
  - evaluate_actions(obs, actions)[source]#
    Evaluate actions and compute log probabilities, entropy, and values.
    Parameters:
    - obs (Tensor): Observation tensor of shape (batch_size, obs_dim)
    - actions (Tensor): Action tensor of shape (batch_size, action_dim)
    Returns:
    - log_prob: Log probability of actions, shape (batch_size,)
    - entropy: Entropy of the action distribution, shape (batch_size,)
    - value: Value estimate, shape (batch_size,)
    Return type: Tuple of (log_prob, entropy, value)
  - get_action(obs, deterministic=False)[source]#
    Sample an action from the policy.
    Parameters:
    - obs (Tensor): Observation tensor of shape (batch_size, obs_dim)
    - deterministic (bool): If True, return the mean action; otherwise sample
    Returns:
    - action: Sampled action tensor of shape (batch_size, action_dim)
    - log_prob: Log probability of the action, shape (batch_size,)
    - value: Value estimate, shape (batch_size,)
    Return type: Tuple of (action, log_prob, value)
- class embodichain.agents.rl.models.MLP[source]#
  Bases: Sequential
  General MLP supporting custom last activation, orthogonal init, and output reshape.
  Parameters:
  - input_dim: input dimension
  - output_dim: output dimension (int or shape tuple/list)
  - hidden_dims: hidden layer sizes, e.g. [256, 256]
  - activation: hidden layer activation name (relu/elu/tanh/gelu/silu)
  - last_activation: last-layer activation name, or None for linear
  - use_layernorm: whether to add LayerNorm after each hidden linear layer
  - dropout_p: dropout probability for hidden layers (0 disables)
  Methods:
  - __init__(input_dim, output_dim, hidden_dims): Initialize internal Module state, shared by both nn.Module and ScriptModule.
  - init_orthogonal([scales]): Orthogonal-initialize linear layers and zero the biases.
- class embodichain.agents.rl.models.Policy[source]#
  Bases: Module, ABC
  Abstract base class that all RL policies must implement.
  A Policy:
  - Encapsulates neural networks that are trained by RL algorithms
  - Handles internal computations (e.g., network output → distribution)
  - Provides a uniform interface for algorithms (PPO, SAC, etc.)
  Methods:
  - __init__(): Initialize internal Module state, shared by both nn.Module and ScriptModule.
  - evaluate_actions(obs, actions): Evaluate actions and compute log probabilities, entropy, and values.
  - get_action(obs[, deterministic]): Sample an action from the policy.
  - get_value(obs): Get value estimate for given observations.
  Attributes:
  - device: Device where the policy parameters are located.
  - device: torch.device#
    Device where the policy parameters are located.
  - abstract evaluate_actions(obs, actions)[source]#
    Evaluate actions and compute log probabilities, entropy, and values.
    Parameters:
    - obs (Tensor): Observation tensor of shape (batch_size, obs_dim)
    - actions (Tensor): Action tensor of shape (batch_size, action_dim)
    Returns:
    - log_prob: Log probability of actions, shape (batch_size,)
    - entropy: Entropy of the action distribution, shape (batch_size,)
    - value: Value estimate, shape (batch_size,)
    Return type: Tuple of (log_prob, entropy, value)
  - abstract get_action(obs, deterministic=False)[source]#
    Sample an action from the policy.
    Parameters:
    - obs (Tensor): Observation tensor of shape (batch_size, obs_dim)
    - deterministic (bool): If True, return the mean action; otherwise sample
    Returns:
    - action: Sampled action tensor of shape (batch_size, action_dim)
    - log_prob: Log probability of the action, shape (batch_size,)
    - value: Value estimate, shape (batch_size,)
    Return type: Tuple of (action, log_prob, value)
  - abstract get_value(obs)[source]#
    Get value estimate for given observations.
    Parameters:
    - obs (Tensor): Observation tensor of shape (batch_size, obs_dim)
    Return type: Tensor
    Returns: Value estimate tensor of shape (batch_size,)
  - training: bool#
- embodichain.agents.rl.models.build_mlp_from_cfg(module_cfg, in_dim, out_dim)[source]#
  Construct an MLP module from a minimal json-like config.
  Return type: MLP
  Expected schema:

    module_cfg = {
        "type": "mlp",
        "hidden_sizes": [256, 256],
        "activation": "relu",
    }
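The core bookkeeping such a builder has to do is expanding `hidden_sizes` into (in, out) dimension pairs for each linear layer. The real function returns a torch module; the sketch below (the helper name `linear_layer_dims` is hypothetical) derives only the dimensions:

```python
# Hypothetical helper: derive the (in, out) pairs a builder like
# build_mlp_from_cfg would hand to its Linear layers.
from typing import Any, Dict, List, Tuple


def linear_layer_dims(module_cfg: Dict[str, Any], in_dim: int,
                      out_dim: int) -> List[Tuple[int, int]]:
    assert module_cfg.get("type") == "mlp", "only the 'mlp' type is sketched here"
    # in_dim -> hidden_sizes... -> out_dim, pairwise.
    sizes = [in_dim] + list(module_cfg.get("hidden_sizes", [])) + [out_dim]
    return list(zip(sizes[:-1], sizes[1:]))
```

For the schema above with in_dim=12 and out_dim=4 this yields three layers: (12, 256), (256, 256), (256, 4).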
Training#

Functions:

- Main entry point for command-line training.
- Parse command line arguments.
- Run training from a config file path.
Utilities#

Classes:

- AlgorithmCfg: Minimal algorithm configuration shared across RL algorithms.

Functions:

- flatten_dict_observation: Flatten hierarchical TensorDict observations from ObservationManager.

- class embodichain.agents.rl.utils.AlgorithmCfg[source]#
  Bases: object
  Minimal algorithm configuration shared across RL algorithms.
  Methods:
  - __init__([device, learning_rate, ...])
  - copy(**kwargs): Return a new object replacing specified fields with new values.
  - replace(**kwargs): Return a new object replacing specified fields with new values.
  - to_dict(): Convert an object into a dictionary recursively.
  - validate([prefix]): Check the validity of a configclass object.
  Attributes:
  - __init__(device=<factory>, learning_rate=<factory>, batch_size=<factory>, gamma=<factory>, gae_lambda=<factory>, max_grad_norm=<factory>)#
  - batch_size: int#
  - copy(**kwargs)#
    Return a new object replacing specified fields with new values.
    This is especially useful for frozen classes. Example usage:

      @configclass(frozen=True)
      class C:
          x: int
          y: int

      c = C(1, 2)
      c1 = c.replace(x=3)
      assert c1.x == 3 and c1.y == 2

    Parameters:
    - obj (object): The object to replace.
    - **kwargs: The fields to replace and their new values.
    Return type: object
    Returns: The new object.
  - device: str#
  - gae_lambda: float#
  - gamma: float#
  - learning_rate: float#
  - max_grad_norm: float#
  - replace(**kwargs)#
    Return a new object replacing specified fields with new values.
    This is especially useful for frozen classes. Example usage:

      @configclass(frozen=True)
      class C:
          x: int
          y: int

      c = C(1, 2)
      c1 = c.replace(x=3)
      assert c1.x == 3 and c1.y == 2

    Parameters:
    - obj (object): The object to replace.
    - **kwargs: The fields to replace and their new values.
    Return type: object
    Returns: The new object.
  - to_dict()#
    Convert an object into a dictionary recursively.
    Note: ignores all names starting with "__" (i.e. built-in methods).
    Parameters:
    - obj (object): An instance of a class to convert.
    Raises: ValueError when the input argument is not an object.
    Return type: dict[str, Any]
    Returns: Converted dictionary mapping.
  - validate(prefix='')#
    Check the validity of a configclass object.
    This function checks whether the object is a valid configclass object. A valid configclass object contains no MISSING entries.
    Parameters:
    - obj (object): The object to check.
    - prefix (str): The prefix to add to the missing fields. Defaults to ''.
    Return type: list[str]
    Returns: A list of missing fields.
    Raises: TypeError when the object is not a valid configuration object.
- embodichain.agents.rl.utils.flatten_dict_observation(obs)[source]#
  Flatten hierarchical TensorDict observations from ObservationManager.
  Recursively traverses nested TensorDicts, collects all tensor values, flattens each to (num_envs, -1), and concatenates in sorted key order.
  Parameters:
  - obs (TensorDict): Nested TensorDict structure, e.g. TensorDict(robot=TensorDict(qpos=..., qvel=...), ...)
  Return type: Tensor
  Returns: Concatenated flat tensor of shape (num_envs, total_dim)
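The contract above (recursive traversal, sorted key order, then concatenation) can be sketched in pure Python with nested dicts and flat lists standing in for TensorDicts and tensors; the real function operates on torch tensors per environment:

```python
# Pure-Python sketch of the flattening contract: recurse through a
# nested mapping in sorted key order and concatenate the leaf vectors.
from typing import Dict, List, Union

Nested = Union[Dict[str, "Nested"], List[float]]


def flatten_nested_observation(obs: Nested) -> List[float]:
    if isinstance(obs, dict):
        flat: List[float] = []
        for key in sorted(obs):  # sorted order makes the layout deterministic
            flat.extend(flatten_nested_observation(obs[key]))
        return flat
    return list(obs)  # leaf: already a flat vector
```

Sorting the keys matters: it guarantees every call lays the same observation fields out in the same order, so downstream MLP input dimensions stay consistent across runs.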