embodichain.agents.rl.utils#
Overview#
The utils package contains helper utilities for RL configuration,
data conversion, and training orchestration.
Configuration Helpers#
Classes:
AlgorithmCfg | Minimal algorithm configuration shared across RL algorithms.
- class embodichain.agents.rl.utils.config.AlgorithmCfg[source]#
Bases: object
Minimal algorithm configuration shared across RL algorithms.
Methods:
__init__([device, learning_rate, ...])
copy(**kwargs) | Return a new object replacing specified fields with new values.
replace(**kwargs) | Return a new object replacing specified fields with new values.
to_dict() | Convert an object into a dictionary recursively.
validate([prefix]) | Check the validity of a configclass object.
Attributes:
- __init__(device=<factory>, learning_rate=<factory>, batch_size=<factory>, gamma=<factory>, gae_lambda=<factory>, max_grad_norm=<factory>)#
- batch_size: int#
- copy(**kwargs)#
Return a new object replacing specified fields with new values.
This is especially useful for frozen classes. Example usage:
    @configclass(frozen=True)
    class C:
        x: int
        y: int

    c = C(1, 2)
    c1 = c.replace(x=3)
    assert c1.x == 3 and c1.y == 2
- Parameters:
obj (object) – The object to replace.
**kwargs – The fields to replace and their new values.
- Return type:
object
- Returns:
The new object.
- device: str#
- gae_lambda: float#
- gamma: float#
- learning_rate: float#
- max_grad_norm: float#
- replace(**kwargs)#
Return a new object replacing specified fields with new values.
This is especially useful for frozen classes. Example usage:
    @configclass(frozen=True)
    class C:
        x: int
        y: int

    c = C(1, 2)
    c1 = c.replace(x=3)
    assert c1.x == 3 and c1.y == 2
- Parameters:
obj (object) – The object to replace.
**kwargs – The fields to replace and their new values.
- Return type:
object
- Returns:
The new object.
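The copy-and-replace behaviour described above can be tried without the library installed. The sketch below is only an illustration: `StandInAlgorithmCfg` is a hypothetical stand-in built with the standard-library `dataclasses` module rather than `configclass`, its field names are taken from the `AlgorithmCfg` attributes listed on this page, and every default value is invented for the example.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class StandInAlgorithmCfg:
    """Hypothetical stand-in for AlgorithmCfg; defaults are invented."""

    device: str = "cpu"
    learning_rate: float = 3e-4
    batch_size: int = 64
    gamma: float = 0.99
    gae_lambda: float = 0.95
    max_grad_norm: float = 0.5


base = StandInAlgorithmCfg()

# dataclasses.replace builds a new instance with the given fields overridden;
# the frozen original is left untouched.
tuned = replace(base, learning_rate=1e-4, batch_size=256)
```

Because the class is frozen, assigning `base.learning_rate = 1e-4` directly would raise an error, which is why a replace-style method is the natural way to derive variants of a configuration.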
- to_dict()#
Convert an object into a dictionary recursively.
Note
Ignores all names starting with “__” (i.e. built-in methods).
- Parameters:
obj (object) – An instance of a class to convert.
- Raises:
ValueError – When input argument is not an object.
- Return type:
dict[str, Any]
- Returns:
Converted dictionary mapping.
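To make the recursive conversion concrete, here is a minimal sketch of the idea. It is not the library's implementation: `to_dict_sketch`, `Inner`, and `Outer` are hypothetical names, and treating "has a `__dict__`" as the test for "is an object" is an assumption made for brevity.

```python
from typing import Any


def to_dict_sketch(obj: Any) -> dict[str, Any]:
    """Recursively convert an instance's attributes into a plain dict."""
    if not hasattr(obj, "__dict__"):
        raise ValueError(f"Expected a class instance, got {type(obj).__name__}")
    result: dict[str, Any] = {}
    for name, value in vars(obj).items():
        # Skip names starting with "__", as the note above describes.
        if name.startswith("__"):
            continue
        # Recurse into nested instances; keep plain values as they are.
        result[name] = to_dict_sketch(value) if hasattr(value, "__dict__") else value
    return result


class Inner:
    def __init__(self) -> None:
        self.gamma = 0.99


class Outer:
    def __init__(self) -> None:
        self.device = "cpu"
        self.algo = Inner()


converted = to_dict_sketch(Outer())
```

Here `converted` is a nested dictionary with a `device` entry and an `algo` entry that is itself a dictionary.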
- validate(prefix='')#
Check the validity of a configclass object.
This function checks if the object is a valid configclass object. A valid configclass object contains no MISSING entries.
- Parameters:
obj (object) – The object to check.
prefix (str) – The prefix to add to the missing fields. Defaults to ''.
- Return type:
list[str]
- Returns:
A list of missing fields.
- Raises:
TypeError – When the object is not a valid configuration object.
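The check described above (walk the configuration, report every field still set to a missing-value sentinel) can be sketched as follows. This is an illustration under stated assumptions, not the library's code: `find_missing`, `OptimCfg`, and `TrainCfg` are hypothetical names, plain dataclasses stand in for configclass objects, and `MISSING` here is a local sentinel created for the example.

```python
from dataclasses import dataclass, field, fields, is_dataclass
from typing import Any

MISSING = object()  # hypothetical sentinel marking "value not provided"


def find_missing(obj: Any, prefix: str = "") -> list[str]:
    """Return the dotted paths of all fields still set to MISSING."""
    if not is_dataclass(obj) or isinstance(obj, type):
        raise TypeError(f"{type(obj).__name__} is not a configuration object")
    missing: list[str] = []
    for f in fields(obj):
        value = getattr(obj, f.name)
        path = f"{prefix}{f.name}"
        if value is MISSING:
            missing.append(path)
        elif is_dataclass(value) and not isinstance(value, type):
            # Recurse into nested configs, extending the prefix.
            missing.extend(find_missing(value, prefix=f"{path}."))
    return missing


@dataclass
class OptimCfg:
    learning_rate: Any = MISSING  # deliberately left unset
    gamma: float = 0.99


@dataclass
class TrainCfg:
    device: str = "cpu"
    optim: OptimCfg = field(default_factory=OptimCfg)


unset = find_missing(TrainCfg(), prefix="cfg.")
```

In this sketch `prefix` is prepended to each reported path, so the unset field is reported under `cfg.optim.`.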
General Helpers#
Functions:
dict_to_tensordict(obs_dict, device) | Convert an environment observation mapping into a TensorDict.
flatten_dict_observation(obs) | Flatten a hierarchical observation TensorDict into a 2D tensor.
- embodichain.agents.rl.utils.helper.dict_to_tensordict(obs_dict, device)[source]#
Convert an environment observation mapping into a TensorDict.
- Parameters:
obs_dict (TensorDict | Mapping[str, Any]) – Environment observation returned by reset() or step().
device (device | str) – Target device for the resulting TensorDict.
- Return type:
TensorDict
- Returns:
Observation TensorDict moved onto the target device.
- embodichain.agents.rl.utils.helper.flatten_dict_observation(obs)[source]#
Flatten a hierarchical observation TensorDict into a 2D tensor.
- Parameters:
obs (TensorDict) – Observation TensorDict with batch dimension [num_envs].
- Return type:
Tensor
- Returns:
Flattened observation tensor of shape [num_envs, obs_dim].
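The shape contract above ([num_envs] batch in, [num_envs, obs_dim] out) can be illustrated without tensors. The sketch below uses nested Python lists as a stand-in for a TensorDict and concatenates each leaf along the feature dimension. `flatten_observation_sketch` is a hypothetical helper, and visiting keys in sorted order is an assumption; the ordering used by `flatten_dict_observation` is not stated on this page.

```python
from typing import Any, Mapping


def flatten_observation_sketch(obs: Mapping[str, Any], num_envs: int) -> list[list[float]]:
    """Concatenate every leaf's per-env features into one row per environment."""
    rows: list[list[float]] = [[] for _ in range(num_envs)]
    for key in sorted(obs):  # assumed ordering, chosen to be deterministic
        value = obs[key]
        # Nested mappings are flattened first; leaves are [num_envs][features].
        leaf = flatten_observation_sketch(value, num_envs) if isinstance(value, Mapping) else value
        if len(leaf) != num_envs:
            raise ValueError(f"Leaf '{key}' does not have batch size {num_envs}")
        for row, features in zip(rows, leaf):
            row.extend(features)
    return rows


# Two environments; leaves contribute 2 + 1 + 2 = 5 features per environment.
obs = {
    "robot": {
        "qpos": [[0.1, 0.2], [0.3, 0.4]],
        "qvel": [[1.0], [2.0]],
    },
    "target": [[5.0, 6.0], [7.0, 8.0]],
}

flat = flatten_observation_sketch(obs, num_envs=2)
```

The result has one row per environment, and `obs_dim` is the sum of the leaf feature sizes.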
Trainer Utilities#
Classes:
Trainer | Algorithm-agnostic trainer that coordinates the training loop, logging, and evaluation.
- class embodichain.agents.rl.utils.trainer.Trainer[source]#
Bases: object
Algorithm-agnostic trainer that coordinates the training loop, logging, and evaluation.
Methods: