embodichain.agents.rl.algo

Overview

Algorithm registry and algorithm-construction helpers for RL training.

Classes:

BaseAlgorithm

Base class for RL algorithms.

GRPO

Group Relative Policy Optimization on top of TensorDict rollouts.

GRPOCfg

Configuration for GRPO.

PPO

Proximal Policy Optimization (PPO) algorithm consuming TensorDict rollouts.

PPOCfg

Configuration for the PPO algorithm.

Functions:

build_algo(name, cfg_kwargs, policy, device, *[, distributed])

compute_gae(rollout, gamma, gae_lambda)

Compute Generalized Advantage Estimation (GAE) over a rollout stored with shape [num_envs, time + 1].

get_registered_algo_names()

class embodichain.agents.rl.algo.BaseAlgorithm

Bases: object

Base class for RL algorithms.

Algorithms only implement policy updates over collected rollouts.

Attributes:

device

Methods:

update(rollout)

Update the policy using collected data and return the training losses.

device: device
update(rollout)

Update the policy using collected data and return the training losses.

Return type:

Dict[str, float]
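The update contract above (an algorithm receives an already-collected rollout and returns a mapping of loss names to Python floats) can be sketched as follows. To stay self-contained the snippet does not import embodichain: `AlgorithmLike` is a hypothetical stand-in that mirrors only the documented surface of BaseAlgorithm (a `device` attribute and `update(rollout)`), and `MeanRewardAlgo` with its dict-of-lists rollout is illustrative, not a library type. A real GRPO or PPO instance receives a TensorDict rollout instead.

```python
from abc import ABC, abstractmethod
from typing import Dict, List


class AlgorithmLike(ABC):
    """Hypothetical stand-in mirroring the documented BaseAlgorithm surface."""

    device: str

    @abstractmethod
    def update(self, rollout) -> Dict[str, float]:
        """Update the policy from a collected rollout and return scalar losses."""


class MeanRewardAlgo(AlgorithmLike):
    """Toy algorithm: performs no learning and only reports a statistic."""

    def __init__(self, device: str = "cpu") -> None:
        self.device = device

    def update(self, rollout: Dict[str, List[float]]) -> Dict[str, float]:
        rewards = rollout["reward"]
        # Plain Python floats keep the returned dict easy to log.
        return {"mean_reward": sum(rewards) / len(rewards)}


algo = MeanRewardAlgo()
losses = algo.update({"reward": [1.0, 2.0, 3.0]})
for name, value in losses.items():
    print(f"{name}: {value:.3f}")
```

Returning plain floats rather than tensors lets a training loop log the dict directly without detaching or moving data off the device.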

class embodichain.agents.rl.algo.GRPO

Bases: BaseAlgorithm

Group Relative Policy Optimization on top of TensorDict rollouts.

Methods:

__init__(cfg, policy)

update(rollout)

Update the policy using collected data and return the training losses.

__init__(cfg, policy)
update(rollout)

Update the policy using collected data and return the training losses.

Return type:

Dict[str, float]
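GRPO replaces a learned value baseline with a per-group baseline: each sample's return is compared with the other samples in its group. The NumPy sketch below shows the standard group-relative normalization, A_i = (R_i - mean(group)) / (std(group) + eps). It assumes that consecutive blocks of `group_size` returns form one group and that `eps` guards the division, by analogy with the `group_size` and `eps` fields of GRPOCfg; how GRPO.update actually forms groups from a TensorDict rollout is not documented here, and `group_relative_advantages` is a hypothetical helper.

```python
import numpy as np


def group_relative_advantages(returns, group_size: int, eps: float = 1e-8) -> np.ndarray:
    """Hypothetical helper: normalize returns within consecutive groups."""
    returns = np.asarray(returns, dtype=np.float64)
    if returns.size % group_size != 0:
        raise ValueError("number of returns must be a multiple of group_size")
    grouped = returns.reshape(-1, group_size)      # [num_groups, group_size]
    mean = grouped.mean(axis=1, keepdims=True)
    std = grouped.std(axis=1, keepdims=True)       # population std (ddof=0)
    advantages = (grouped - mean) / (std + eps)
    return advantages.reshape(returns.shape)


adv = group_relative_advantages([1.0, 3.0, 5.0, 5.0], group_size=2)
print(adv)
```

A group whose members all receive the same return gets zero advantage and therefore contributes no policy-gradient signal. Implementations differ on population versus sample standard deviation; `np.std` defaults to the population form.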

class embodichain.agents.rl.algo.GRPOCfg

Bases: AlgorithmCfg

Configuration for GRPO.

Methods:

__init__([device, learning_rate, ...])

copy(**kwargs)

Return a new object replacing specified fields with new values.

replace(**kwargs)

Return a new object replacing specified fields with new values.

to_dict()

Convert an object into a dictionary recursively.

validate([prefix])

Check the validity of a configclass object.

Attributes:

batch_size
clip_coef
device
ent_coef
eps
gae_lambda
gamma
group_size
kl_coef
learning_rate
max_grad_norm
n_epochs
reset_every_rollout
truncate_at_first_done

__init__(device=<factory>, learning_rate=<factory>, batch_size=<factory>, gamma=<factory>, gae_lambda=<factory>, max_grad_norm=<factory>, n_epochs=<factory>, clip_coef=<factory>, ent_coef=<factory>, kl_coef=<factory>, group_size=<factory>, eps=<factory>, reset_every_rollout=<factory>, truncate_at_first_done=<factory>)
batch_size: int
clip_coef: float
copy(**kwargs)

Return a new object replacing specified fields with new values.

This is especially useful for frozen classes. Example usage:

@configclass(frozen=True)
class C:
    x: int
    y: int

c = C(1, 2)
c1 = c.replace(x=3)
assert c1.x == 3 and c1.y == 2

Parameters:
  • obj (object) – The object to replace, i.e. the instance on which the method is called.

  • **kwargs – The fields to replace and their new values.

Return type:

object

Returns:

The new object.

device: str
ent_coef: float
eps: float
gae_lambda: float
gamma: float
group_size: int
kl_coef: float
learning_rate: float
max_grad_norm: float
n_epochs: int
replace(**kwargs)

Return a new object replacing specified fields with new values.

This is especially useful for frozen classes. Example usage:

@configclass(frozen=True)
class C:
    x: int
    y: int

c = C(1, 2)
c1 = c.replace(x=3)
assert c1.x == 3 and c1.y == 2

Parameters:
  • obj (object) – The object to replace, i.e. the instance on which the method is called.

  • **kwargs – The fields to replace and their new values.

Return type:

object

Returns:

The new object.

reset_every_rollout: bool
to_dict()

Convert an object into a dictionary recursively.

Note

Ignores all names starting with “__” (i.e. built-in methods).

Parameters:

obj (object) – An instance of a class to convert.

Raises:

ValueError – When the input argument is not an object.

Return type:

dict[str, Any]

Returns:

Converted dictionary mapping.

truncate_at_first_done: bool
validate(prefix='')

Check the validity of a configclass object.

This function checks if the object is a valid configclass object. A valid configclass object contains no MISSING entries.

Parameters:
  • obj (object) – The object to check.

  • prefix (str) – The prefix to add to the missing fields. Defaults to ‘’.

Return type:

list[str]

Returns:

A list of missing fields.

Raises:

TypeError – When the object is not a valid configuration object.
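The role of `prefix` is easiest to see in a small re-implementation. The sketch below is hypothetical and is not the library's code: it uses standard-library dataclasses, treats `dataclasses.MISSING` as the placeholder for an unset value, walks nested configs to build dotted paths such as `optim.lr`, and assumes that "not a valid configuration object" means "contains MISSING entries", so only the outermost call (empty prefix) raises TypeError. `validate_sketch`, `OptimCfg` and `TrainCfg` are illustrative names.

```python
import dataclasses
from dataclasses import MISSING, dataclass, field
from typing import Any, List


def validate_sketch(obj: Any, prefix: str = "") -> List[str]:
    """Hypothetical re-implementation: collect dotted paths of MISSING values."""
    if obj is MISSING:
        return [prefix]
    missing: List[str] = []
    if dataclasses.is_dataclass(obj):
        for f in dataclasses.fields(obj):
            path = f"{prefix}.{f.name}" if prefix else f.name
            missing.extend(validate_sketch(getattr(obj, f.name), path))
    # Only the outermost call raises, once every missing path is known.
    if prefix == "" and missing:
        raise TypeError(f"missing values for fields: {missing}")
    return missing


@dataclass
class OptimCfg:
    lr: Any = 3e-4


@dataclass
class TrainCfg:
    device: Any = "cpu"
    optim: OptimCfg = field(default_factory=OptimCfg)


cfg = TrainCfg(device=MISSING, optim=OptimCfg(lr=MISSING))
print(validate_sketch(cfg.optim, prefix="optim"))
try:
    validate_sketch(cfg)
except TypeError as err:
    print(err)
```

Passing a non-empty prefix is what lets the recursion report nested fields by their full path while deferring the error to the top-level call.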

class embodichain.agents.rl.algo.PPO

Bases: BaseAlgorithm

Proximal Policy Optimization (PPO) algorithm consuming TensorDict rollouts.

Methods:

__init__(cfg, policy)

update(rollout)

Update the policy using a collected rollout.

__init__(cfg, policy)
update(rollout)

Update the policy using a collected rollout.

Return type:

Dict[str, float]
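For reference, the NumPy sketch below computes the standard PPO clipped-surrogate loss together with a value loss and an entropy bonus, weighted by coefficients named after the `clip_coef`, `vf_coef` and `ent_coef` fields of PPOCfg. It illustrates the technique under those assumptions and is not the body of PPO.update: the 0.5 factor on the squared-error value loss, the unclipped value loss, the default coefficient values and the returned key names are all choices of this sketch.

```python
from typing import Dict

import numpy as np


def ppo_loss_sketch(
    new_logp: np.ndarray,
    old_logp: np.ndarray,
    advantages: np.ndarray,
    values: np.ndarray,
    returns: np.ndarray,
    entropy: np.ndarray,
    clip_coef: float = 0.2,   # sketch defaults, not PPOCfg's defaults
    vf_coef: float = 0.5,
    ent_coef: float = 0.01,
) -> Dict[str, float]:
    """Clipped-surrogate PPO loss; returns scalar terms as Python floats."""
    ratio = np.exp(new_logp - old_logp)                    # pi_new / pi_old
    unclipped = ratio * advantages
    clipped = np.clip(ratio, 1.0 - clip_coef, 1.0 + clip_coef) * advantages
    policy_loss = -np.minimum(unclipped, clipped).mean()   # pessimistic bound
    value_loss = 0.5 * np.mean((returns - values) ** 2)
    entropy_mean = entropy.mean()
    total = policy_loss + vf_coef * value_loss - ent_coef * entropy_mean
    return {
        "policy_loss": float(policy_loss),
        "value_loss": float(value_loss),
        "entropy": float(entropy_mean),
        "total_loss": float(total),
    }


losses = ppo_loss_sketch(
    new_logp=np.log(np.array([2.0])),
    old_logp=np.zeros(1),
    advantages=np.array([2.0]),
    values=np.array([1.0]),
    returns=np.array([1.0]),
    entropy=np.array([0.0]),
)
print(losses)
```

In the example the probability ratio is 2, so with a positive advantage the clipped term (1.2 × 2) is the smaller one and caps how far a single update can push the policy.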

class embodichain.agents.rl.algo.PPOCfg

Bases: AlgorithmCfg

Configuration for the PPO algorithm.

Methods:

__init__([device, learning_rate, ...])

copy(**kwargs)

Return a new object replacing specified fields with new values.

replace(**kwargs)

Return a new object replacing specified fields with new values.

to_dict()

Convert an object into a dictionary recursively.

validate([prefix])

Check the validity of a configclass object.

Attributes:

batch_size
clip_coef
device
ent_coef
gae_lambda
gamma
learning_rate
max_grad_norm
n_epochs
vf_coef

__init__(device=<factory>, learning_rate=<factory>, batch_size=<factory>, gamma=<factory>, gae_lambda=<factory>, max_grad_norm=<factory>, n_epochs=<factory>, clip_coef=<factory>, ent_coef=<factory>, vf_coef=<factory>)
batch_size: int
clip_coef: float
copy(**kwargs)

Return a new object replacing specified fields with new values.

This is especially useful for frozen classes. Example usage:

@configclass(frozen=True)
class C:
    x: int
    y: int

c = C(1, 2)
c1 = c.replace(x=3)
assert c1.x == 3 and c1.y == 2

Parameters:
  • obj (object) – The object to replace, i.e. the instance on which the method is called.

  • **kwargs – The fields to replace and their new values.

Return type:

object

Returns:

The new object.

device: str
ent_coef: float
gae_lambda: float
gamma: float
learning_rate: float
max_grad_norm: float
n_epochs: int
replace(**kwargs)

Return a new object replacing specified fields with new values.

This is especially useful for frozen classes. Example usage:

@configclass(frozen=True)
class C:
    x: int
    y: int

c = C(1, 2)
c1 = c.replace(x=3)
assert c1.x == 3 and c1.y == 2

Parameters:
  • obj (object) – The object to replace, i.e. the instance on which the method is called.

  • **kwargs – The fields to replace and their new values.

Return type:

object

Returns:

The new object.

to_dict()

Convert an object into a dictionary recursively.

Note

Ignores all names starting with “__” (i.e. built-in methods).

Parameters:

obj (object) – An instance of a class to convert.

Raises:

ValueError – When the input argument is not an object.

Return type:

dict[str, Any]

Returns:

Converted dictionary mapping.

validate(prefix='')

Check the validity of a configclass object.

This function checks if the object is a valid configclass object. A valid configclass object contains no MISSING entries.

Parameters:
  • obj (object) – The object to check.

  • prefix (str) – The prefix to add to the missing fields. Defaults to ‘’.

Return type:

list[str]

Returns:

A list of missing fields.

Raises:

TypeError – When the object is not a valid configuration object.

vf_coef: float

embodichain.agents.rl.algo.build_algo(name, cfg_kwargs, policy, device, *, distributed=False)

embodichain.agents.rl.algo.compute_gae(rollout, gamma, gae_lambda)

Compute Generalized Advantage Estimation (GAE) over a rollout stored with shape [num_envs, time + 1].

Parameters:
  • rollout (TensorDict) – Rollout TensorDict in which value[:, -1] stores the bootstrap value for the final observation, and transition-only fields reserve their last slot as padding.

  • gamma (float) – Discount factor.

  • gae_lambda (float) – GAE lambda coefficient.

Return type:

tuple[Tensor, Tensor]

Returns:

Tuple of (advantages, returns), both shaped [num_envs, time].
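The storage layout above can be made concrete with a NumPy sketch of the standard GAE recursion: delta_t = r_t + gamma * V[t+1] * (1 - done_t) - V[t], then A_t = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}, with returns = advantages + V[:, :T]. The sketch takes plain arrays rather than a TensorDict; the argument names `rewards`, `values` and `dones`, and the convention that `dones[:, t] = 1` means the episode ended at step t, are assumptions of this sketch rather than the rollout's documented keys.

```python
import numpy as np


def compute_gae_sketch(rewards, values, dones, gamma: float, gae_lambda: float):
    """GAE over arrays shaped [num_envs, time + 1].

    values[:, -1] is the bootstrap value; the last slot of rewards and
    dones is padding and is never read.
    """
    rewards = np.asarray(rewards, dtype=np.float64)
    values = np.asarray(values, dtype=np.float64)
    dones = np.asarray(dones, dtype=np.float64)
    num_envs, horizon = values.shape[0], values.shape[1] - 1

    advantages = np.zeros((num_envs, horizon))
    next_adv = np.zeros(num_envs)
    for t in reversed(range(horizon)):
        not_done = 1.0 - dones[:, t]  # stop bootstrapping across episode ends
        delta = rewards[:, t] + gamma * values[:, t + 1] * not_done - values[:, t]
        next_adv = delta + gamma * gae_lambda * not_done * next_adv
        advantages[:, t] = next_adv
    returns = advantages + values[:, :horizon]
    return advantages, returns


# One env, two real steps plus one padding slot; bootstrap value 5.0.
adv, ret = compute_gae_sketch(
    rewards=[[0.0, 0.0, 0.0]],
    values=[[0.0, 0.0, 5.0]],
    dones=[[0.0, 0.0, 0.0]],
    gamma=0.5,
    gae_lambda=1.0,
)
print(adv.shape, ret)
```

With gae_lambda = 1 and zero rewards, the first return reduces to the discounted bootstrap value, 0.5^2 * 5.0 = 1.25, which is a quick sanity check that the bootstrap slot is being used.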

embodichain.agents.rl.algo.get_registered_algo_names()

Return type:

list[str]
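How a name-based registry with helpers like build_algo and get_registered_algo_names typically fits together can be shown with a self-contained sketch. Everything in it is hypothetical: the `_REGISTRY` dict, the `register` decorator, `build`, `registered_names`, the `ToyAlgo`/`ToyCfg` classes and the `"toy"` key are illustrations. From the signatures above, the sketch only assumes that an algorithm is constructed as `Algo(cfg, policy)` from a config whose `device` field is filled from the `device` argument. The `distributed` keyword of build_algo is not documented here, so the sketch omits it.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Tuple

# name -> (algorithm class, config class); hypothetical registry layout
_REGISTRY: Dict[str, Tuple[type, type]] = {}


def register(name: str, cfg_cls: type) -> Callable[[type], type]:
    """Class decorator that records an algorithm under `name`."""
    def wrap(algo_cls: type) -> type:
        if name in _REGISTRY:
            raise KeyError(f"algorithm {name!r} is already registered")
        _REGISTRY[name] = (algo_cls, cfg_cls)
        return algo_cls
    return wrap


def registered_names() -> List[str]:
    return sorted(_REGISTRY)


def build(name: str, cfg_kwargs: Dict[str, Any], policy: Any, device: str):
    try:
        algo_cls, cfg_cls = _REGISTRY[name]
    except KeyError:
        raise ValueError(f"unknown algorithm {name!r}; choose from {registered_names()}") from None
    # The explicit device argument overrides any "device" key in cfg_kwargs
    # (a choice of this sketch).
    cfg = cfg_cls(**{**cfg_kwargs, "device": device})
    return algo_cls(cfg, policy)


@dataclass
class ToyCfg:
    device: str = "cpu"
    learning_rate: float = 1e-3


@register("toy", ToyCfg)
class ToyAlgo:
    def __init__(self, cfg: ToyCfg, policy: Any) -> None:
        self.cfg, self.policy = cfg, policy


algo = build("toy", {"learning_rate": 3e-4}, policy=None, device="cpu")
print(registered_names(), algo.cfg)
```

Keeping the config class next to the algorithm class in the registry lets a caller pass plain keyword arguments (for example, parsed from a config file) without importing the specific config type.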