embodichain.agents.rl.algo#
Overview#
Algorithm registry and algorithm-construction helpers for RL training.
Functions
build_algo(name, cfg_kwargs, policy, device, *)
Classes:
BaseAlgorithm | Base class for RL algorithms. |
GRPO | Group Relative Policy Optimization on top of TensorDict rollouts. |
GRPOCfg | Configuration for GRPO. |
PPO | PPO algorithm consuming TensorDict rollouts. |
PPOCfg | Configuration for the PPO algorithm. |
Functions:
build_algo | |
compute_gae | Compute GAE over a rollout stored as [num_envs, time + 1]. |
- class embodichain.agents.rl.algo.BaseAlgorithm[source]#
Bases: object
Base class for RL algorithms.
Algorithms only implement policy updates over collected rollouts.
Attributes:
device | |
Methods:
update(rollout) | Update policy using collected data and return training losses. |
-
device:
device#
- class embodichain.agents.rl.algo.GRPO[source]#
Bases: BaseAlgorithm
Group Relative Policy Optimization on top of TensorDict rollouts.
Methods:
- class embodichain.agents.rl.algo.GRPOCfg[source]#
Bases: AlgorithmCfg
Configuration for GRPO.
Methods:
__init__([device, learning_rate, ...]) | |
copy(**kwargs) | Return a new object replacing specified fields with new values. |
replace(**kwargs) | Return a new object replacing specified fields with new values. |
to_dict() | Convert an object into a dictionary recursively. |
validate([prefix]) | Check the validity of a configclass object. |
Attributes:
- __init__(device=<factory>, learning_rate=<factory>, batch_size=<factory>, gamma=<factory>, gae_lambda=<factory>, max_grad_norm=<factory>, n_epochs=<factory>, clip_coef=<factory>, ent_coef=<factory>, kl_coef=<factory>, group_size=<factory>, eps=<factory>, reset_every_rollout=<factory>, truncate_at_first_done=<factory>)#
-
batch_size:
int#
-
clip_coef:
float#
- copy(**kwargs)#
Return a new object replacing specified fields with new values.
This is especially useful for frozen classes. Example usage:
    @configclass(frozen=True)
    class C:
        x: int
        y: int

    c = C(1, 2)
    c1 = c.replace(x=3)
    assert c1.x == 3 and c1.y == 2
- Parameters:
obj (object) – The object to replace.
**kwargs – The fields to replace and their new values.
- Return type:
object
- Returns:
The new object.
-
device:
str#
-
ent_coef:
float#
-
eps:
float#
-
gae_lambda:
float#
-
gamma:
float#
-
group_size:
int#
-
kl_coef:
float#
-
learning_rate:
float#
-
max_grad_norm:
float#
-
n_epochs:
int#
- replace(**kwargs)#
Return a new object replacing specified fields with new values.
This is especially useful for frozen classes. Example usage:
    @configclass(frozen=True)
    class C:
        x: int
        y: int

    c = C(1, 2)
    c1 = c.replace(x=3)
    assert c1.x == 3 and c1.y == 2
- Parameters:
obj (object) – The object to replace.
**kwargs – The fields to replace and their new values.
- Return type:
object
- Returns:
The new object.
-
reset_every_rollout:
bool#
- to_dict()#
Convert an object into a dictionary recursively.
Note
Ignores all names starting with “__” (i.e. built-in methods).
- Parameters:
obj (object) – An instance of a class to convert.
- Raises:
ValueError – When input argument is not an object.
- Return type:
dict[str, Any]
- Returns:
Converted dictionary mapping.
-
truncate_at_first_done:
bool#
- validate(prefix='')#
Check the validity of a configclass object.
This function checks whether the object is a valid configclass object. A valid configclass object contains no MISSING entries.
- Parameters:
obj (object) – The object to check.
prefix (str) – The prefix to add to the missing fields. Defaults to ‘’.
- Return type:
list[str]
- Returns:
A list of missing fields.
- Raises:
TypeError – When the object is not a valid configuration object.
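The GRPOCfg fields group_size and eps suggest that returns are normalized within groups of rollouts, as in the usual description of Group Relative Policy Optimization. The following is a minimal sketch of that idea in plain Python, not the library's implementation: the function name group_relative_advantages and the flat, consecutive grouping of returns are assumptions made for illustration only.

```python
from statistics import fmean, pstdev


def group_relative_advantages(returns, group_size, eps=1e-8):
    """Normalize each return against the mean/std of its own group.

    Assumption: `returns` is a flat list in which every consecutive run of
    `group_size` entries forms one group. This layout is hypothetical and
    is not taken from the library.
    """
    if len(returns) % group_size != 0:
        raise ValueError("len(returns) must be a multiple of group_size")
    advantages = []
    for start in range(0, len(returns), group_size):
        group = returns[start:start + group_size]
        mean = fmean(group)
        std = pstdev(group)  # population standard deviation within the group
        # eps keeps the division finite when all returns in a group are equal
        advantages.extend((r - mean) / (std + eps) for r in group)
    return advantages


adv = group_relative_advantages([1.0, 3.0, 2.0, 2.0], group_size=2)
```

In this sketch the second group has identical returns, so its advantages are zero rather than undefined, which is the role eps plays in the denominator.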
- class embodichain.agents.rl.algo.PPO[source]#
Bases: BaseAlgorithm
PPO algorithm consuming TensorDict rollouts.
Methods:
- class embodichain.agents.rl.algo.PPOCfg[source]#
Bases: AlgorithmCfg
Configuration for the PPO algorithm.
Methods:
__init__([device, learning_rate, ...]) | |
copy(**kwargs) | Return a new object replacing specified fields with new values. |
replace(**kwargs) | Return a new object replacing specified fields with new values. |
to_dict() | Convert an object into a dictionary recursively. |
validate([prefix]) | Check the validity of a configclass object. |
Attributes:
- __init__(device=<factory>, learning_rate=<factory>, batch_size=<factory>, gamma=<factory>, gae_lambda=<factory>, max_grad_norm=<factory>, n_epochs=<factory>, clip_coef=<factory>, ent_coef=<factory>, vf_coef=<factory>)#
-
batch_size:
int#
-
clip_coef:
float#
- copy(**kwargs)#
Return a new object replacing specified fields with new values.
This is especially useful for frozen classes. Example usage:
    @configclass(frozen=True)
    class C:
        x: int
        y: int

    c = C(1, 2)
    c1 = c.replace(x=3)
    assert c1.x == 3 and c1.y == 2
- Parameters:
obj (object) – The object to replace.
**kwargs – The fields to replace and their new values.
- Return type:
object
- Returns:
The new object.
-
device:
str#
-
ent_coef:
float#
-
gae_lambda:
float#
-
gamma:
float#
-
learning_rate:
float#
-
max_grad_norm:
float#
-
n_epochs:
int#
- replace(**kwargs)#
Return a new object replacing specified fields with new values.
This is especially useful for frozen classes. Example usage:
    @configclass(frozen=True)
    class C:
        x: int
        y: int

    c = C(1, 2)
    c1 = c.replace(x=3)
    assert c1.x == 3 and c1.y == 2
- Parameters:
obj (object) – The object to replace.
**kwargs – The fields to replace and their new values.
- Return type:
object
- Returns:
The new object.
- to_dict()#
Convert an object into a dictionary recursively.
Note
Ignores all names starting with “__” (i.e. built-in methods).
- Parameters:
obj (object) – An instance of a class to convert.
- Raises:
ValueError – When input argument is not an object.
- Return type:
dict[str, Any]
- Returns:
Converted dictionary mapping.
- validate(prefix='')#
Check the validity of a configclass object.
This function checks whether the object is a valid configclass object. A valid configclass object contains no MISSING entries.
- Parameters:
obj (object) – The object to check.
prefix (str) – The prefix to add to the missing fields. Defaults to ‘’.
- Return type:
list[str]
- Returns:
A list of missing fields.
- Raises:
TypeError – When the object is not a valid configuration object.
-
vf_coef:
float#
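The PPOCfg fields clip_coef, vf_coef and ent_coef correspond to the terms of the standard PPO objective. The sketch below shows how such coefficients are commonly combined. It is an illustration in plain Python under stated assumptions, not the PPO class's actual update code: the function names ppo_clipped_loss and total_loss and the per-sample list inputs are hypothetical.

```python
import math


def ppo_clipped_loss(new_logp, old_logp, advantages, clip_coef):
    """Mean clipped-surrogate policy loss over a batch of samples."""
    total = 0.0
    for nl, ol, adv in zip(new_logp, old_logp, advantages):
        ratio = math.exp(nl - ol)  # probability ratio pi_new / pi_old
        clipped = min(max(ratio, 1.0 - clip_coef), 1.0 + clip_coef)
        # Pessimistic (minimum) surrogate, negated so it can be minimized
        total += -min(ratio * adv, clipped * adv)
    return total / len(advantages)


def total_loss(policy_loss, value_loss, entropy, vf_coef, ent_coef):
    """Assumed combination: policy term + weighted value term - entropy bonus."""
    return policy_loss + vf_coef * value_loss - ent_coef * entropy


# With identical log-probabilities the ratio is 1 and nothing is clipped.
unclipped = ppo_clipped_loss([0.0], [0.0], [2.0], clip_coef=0.2)
# A ratio of about 2 with a positive advantage is clipped to 1 + clip_coef.
clipped = ppo_clipped_loss([math.log(2.0)], [0.0], [1.0], clip_coef=0.2)
```

Clipping bounds how far a single update can move the policy away from the one that collected the rollout, which is why a larger clip_coef permits more aggressive updates.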
- embodichain.agents.rl.algo.build_algo(name, cfg_kwargs, policy, device, *, distributed=False)[source]#
- embodichain.agents.rl.algo.compute_gae(rollout, gamma, gae_lambda)[source]#
Compute GAE over a rollout stored as [num_envs, time + 1].
- Parameters:
rollout (TensorDict) – Rollout TensorDict where value[:, -1] stores the bootstrap value for the final observation and transition-only fields reserve their last slot as padding.
gamma (float) – Discount factor.
gae_lambda (float) – GAE lambda coefficient.
- Return type:
tuple[Tensor, Tensor]
- Returns:
Tuple of (advantages, returns), both shaped [num_envs, time].
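To make the [num_envs, time + 1] layout concrete, here is a minimal sketch of the GAE recursion using nested Python lists instead of a TensorDict. It is not the library's code: the function name gae_sketch, the reward and done argument names, and the way a done flag cuts the bootstrap are assumptions. Only the value[:, -1] bootstrap slot and the padded last slot of transition-only fields come from the description above.

```python
def gae_sketch(value, reward, done, gamma, gae_lambda):
    """Hypothetical stand-in for compute_gae on nested lists.

    Every argument is shaped [num_envs][time + 1]. value[e][-1] is the
    bootstrap value; the last slot of reward and done is padding and is
    never read.
    """
    advantages, returns = [], []
    for v, r, d in zip(value, reward, done):
        steps = len(v) - 1  # number of real transitions for this env
        adv = [0.0] * steps
        running = 0.0
        for t in reversed(range(steps)):
            # Assumption: a done flag stops bootstrapping past step t
            not_done = 0.0 if d[t] else 1.0
            delta = r[t] + gamma * v[t + 1] * not_done - v[t]
            running = delta + gamma * gae_lambda * not_done * running
            adv[t] = running
        advantages.append(adv)
        returns.append([a + v_t for a, v_t in zip(adv, v[:steps])])
    return advantages, returns


# One environment, two transitions, plus the bootstrap/padding slot.
value = [[0.5, 0.5, 0.5]]
reward = [[1.0, 1.0, 0.0]]
done = [[False, False, False]]
adv, ret = gae_sketch(value, reward, done, gamma=1.0, gae_lambda=1.0)
# adv == [[2.0, 1.0]] and ret == [[2.5, 1.5]]
```

With gamma and gae_lambda both set to 1, each return equals the remaining rewards plus the bootstrap value, which is an easy way to check the indexing of the extra time slot.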