embodichain.agents.rl.utils

Overview

The utils package contains helper utilities for RL configuration, data conversion, and training orchestration.

Submodules

Configuration Helpers

Classes:

AlgorithmCfg

Minimal algorithm configuration shared across RL algorithms.

class embodichain.agents.rl.utils.config.AlgorithmCfg

Bases: object

Minimal algorithm configuration shared across RL algorithms.

Methods:

__init__([device, learning_rate, ...])

copy(**kwargs)

Return a new object replacing specified fields with new values.

replace(**kwargs)

Return a new object replacing specified fields with new values.

to_dict()

Convert an object into dictionary recursively.

validate([prefix])

Check the validity of configclass object.

Attributes:

__init__(device=<factory>, learning_rate=<factory>, batch_size=<factory>, gamma=<factory>, gae_lambda=<factory>, max_grad_norm=<factory>)
batch_size: int
copy(**kwargs)

Return a new object replacing specified fields with new values.

This is especially useful for frozen classes. Example usage:

@configclass(frozen=True)
class C:
    x: int
    y: int

c = C(1, 2)
c1 = c.replace(x=3)
assert c1.x == 3 and c1.y == 2

Parameters:
  • obj (object) – The object to replace. When called as a method, this is the instance itself and is not passed explicitly.

  • **kwargs – The fields to replace and their new values.

Return type:

object

Returns:

The new object.

device: str
gae_lambda: float
gamma: float
learning_rate: float
max_grad_norm: float
replace(**kwargs)

Return a new object replacing specified fields with new values.

This is especially useful for frozen classes. Example usage:

@configclass(frozen=True)
class C:
    x: int
    y: int

c = C(1, 2)
c1 = c.replace(x=3)
assert c1.x == 3 and c1.y == 2

Parameters:
  • obj (object) – The object to replace. When called as a method, this is the instance itself and is not passed explicitly.

  • **kwargs – The fields to replace and their new values.

Return type:

object

Returns:

The new object.
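
The copy-with-overrides pattern described above can be sketched with the standard library alone. This is a stand-in under stated assumptions, not the embodichain implementation: `AlgoCfgSketch` is a hypothetical frozen dataclass that mirrors the documented `AlgorithmCfg` field names with made-up default values, and `dataclasses.replace` plays the role of the `replace()` method.

```python
from dataclasses import FrozenInstanceError, dataclass, replace


@dataclass(frozen=True)
class AlgoCfgSketch:
    """Hypothetical stand-in for AlgorithmCfg; defaults are illustrative only."""

    device: str = "cpu"
    learning_rate: float = 3e-4
    batch_size: int = 64
    gamma: float = 0.99
    gae_lambda: float = 0.95
    max_grad_norm: float = 0.5


base = AlgoCfgSketch()

# replace() builds a new instance; the original is left untouched.
tuned = replace(base, learning_rate=1e-4, batch_size=128)

try:
    base.batch_size = 256  # frozen instances reject in-place mutation
    mutated = True
except FrozenInstanceError:
    mutated = False
```

Because a frozen config cannot be edited in place, deriving variants this way keeps a shared base configuration safe to reuse across experiments.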

to_dict()

Convert an object into dictionary recursively.

Note

Ignores all names starting with “__” (i.e. built-in methods).

Parameters:

obj (object) – An instance of a class to convert. When called as a method, this is the instance itself and is not passed explicitly.

Raises:

ValueError – When input argument is not an object.

Return type:

dict[str, Any]

Returns:

Converted dictionary mapping.
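
A minimal sketch of what such a recursive conversion might look like, written against plain Python objects. The function `to_dict_sketch` and the classes `OptimCfgSketch` and `RunCfgSketch` are hypothetical names introduced for illustration. The real method may treat containers, class-level attributes, and callables differently.

```python
from typing import Any


def to_dict_sketch(obj: object) -> dict[str, Any]:
    """Recursively convert an object's instance attributes into a plain dict."""
    if not hasattr(obj, "__dict__"):
        raise ValueError(f"Expected an object with attributes, got {type(obj).__name__}.")
    result: dict[str, Any] = {}
    for name, value in vars(obj).items():
        if name.startswith("__"):
            continue  # skip dunder names, as the note above describes
        # Recurse into nested objects; keep plain values as they are.
        result[name] = to_dict_sketch(value) if hasattr(value, "__dict__") else value
    return result


class OptimCfgSketch:
    def __init__(self) -> None:
        self.learning_rate = 3e-4
        self.max_grad_norm = 0.5


class RunCfgSketch:
    def __init__(self) -> None:
        self.device = "cpu"
        self.optim = OptimCfgSketch()


as_dict = to_dict_sketch(RunCfgSketch())
```

A plain nested dictionary like `as_dict` is convenient for logging or serializing a configuration alongside training results.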

validate(prefix='')

Check the validity of configclass object.

This function checks if the object is a valid configclass object. A valid configclass object contains no MISSING entries.

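Parameters:
  • obj (object) – The object to check. When called as a method, this is the instance itself and is not passed explicitly.

  • prefix (str) – The prefix to add to the missing fields. Defaults to the empty string ('').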

Return type:

list[str]

Returns:

A list of missing fields.

Raises:

TypeError – When the object is not a valid configuration object.
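
The check for missing entries can be illustrated with a small sketch built on standard dataclasses. Everything here is an assumption made for illustration: `MISSING_SKETCH` is a hypothetical sentinel, `validate_sketch` is not the library function, and raising `TypeError` for non-dataclass input is only one possible reading of the documented behavior.

```python
from dataclasses import dataclass, field, fields, is_dataclass

# Hypothetical sentinel marking a value the user must still supply.
MISSING_SKETCH = object()


def validate_sketch(obj: object, prefix: str = "") -> list[str]:
    """Return the prefixed, dotted paths of all fields still set to the sentinel."""
    if isinstance(obj, type) or not is_dataclass(obj):
        raise TypeError(f"Expected a dataclass instance, got {type(obj).__name__}.")
    missing: list[str] = []
    for f in fields(obj):
        value = getattr(obj, f.name)
        path = f"{prefix}{f.name}"
        if value is MISSING_SKETCH:
            missing.append(path)
        elif is_dataclass(value) and not isinstance(value, type):
            # Recurse into nested configs, extending the prefix with the field name.
            missing.extend(validate_sketch(value, prefix=f"{path}."))
    return missing


@dataclass
class OptimSketch:
    learning_rate: object = MISSING_SKETCH
    max_grad_norm: float = 0.5


@dataclass
class CfgSketch:
    device: object = MISSING_SKETCH
    optim: OptimSketch = field(default_factory=OptimSketch)


missing = validate_sketch(CfgSketch(), prefix="cfg.")
# missing == ["cfg.device", "cfg.optim.learning_rate"]
```

The `prefix` argument makes the reported paths unambiguous when the same check runs over several nested configurations.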

General Helpers

Functions:

dict_to_tensordict(obs_dict, device)

Convert an environment observation mapping into a TensorDict.

flatten_dict_observation(obs)

Flatten a hierarchical observation TensorDict into a 2D tensor.

embodichain.agents.rl.utils.helper.dict_to_tensordict(obs_dict, device)

Convert an environment observation mapping into a TensorDict.

Parameters:
  • obs_dict (TensorDict | Mapping[str, Any]) – Environment observation returned by reset() or step().

  • device (device | str) – Target device for the resulting TensorDict.

Return type:

TensorDict

Returns:

Observation TensorDict moved onto the target device.

embodichain.agents.rl.utils.helper.flatten_dict_observation(obs)

Flatten a hierarchical observation TensorDict into a 2D tensor.

Parameters:

obs (TensorDict) – Observation TensorDict with batch dimension [num_envs].

Return type:

Tensor

Returns:

Flattened observation tensor of shape [num_envs, obs_dim].
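
To make the shape transformation concrete, here is a pure-Python stand-in that uses nested lists in place of tensors. It is only a sketch: `flatten_observation_sketch` is a hypothetical helper, the real function operates on a TensorDict and returns a Tensor, and the sorted key order used here is an assumption, since the ordering of the real function is not documented above.

```python
from typing import Any, Mapping


def flatten_observation_sketch(obs: Mapping[str, Any], num_envs: int) -> list[list[float]]:
    """Concatenate per-environment features from a nested mapping into rows."""
    rows: list[list[float]] = [[] for _ in range(num_envs)]
    for key in sorted(obs):  # assumed ordering: keys sorted alphabetically
        value = obs[key]
        # Nested mappings are flattened first; leaves are [num_envs][feature_dim] lists.
        leaf = flatten_observation_sketch(value, num_envs) if isinstance(value, Mapping) else value
        if len(leaf) != num_envs:
            raise ValueError(f"Entry '{key}' does not have batch size {num_envs}.")
        for row, features in zip(rows, leaf):
            row.extend(features)
    return rows


# Two environments; three leaves with feature sizes 1, 2, and 1.
obs = {
    "robot": {"qpos": [[0.1, 0.2], [0.3, 0.4]], "qvel": [[1.0], [2.0]]},
    "goal": [[9.0], [8.0]],
}
flat = flatten_observation_sketch(obs, num_envs=2)
# flat == [[9.0, 0.1, 0.2, 1.0], [8.0, 0.3, 0.4, 2.0]], i.e. shape [2, 4]
```

Each row corresponds to one environment, and `obs_dim` is the sum of the leaf feature sizes.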

Trainer Utilities

Classes:

Trainer

Algorithm-agnostic trainer that coordinates training loop, logging, and evaluation.

class embodichain.agents.rl.utils.trainer.Trainer

Bases: object

Algorithm-agnostic trainer that coordinates training loop, logging, and evaluation.

Methods:

__init__(policy, env, algorithm, ...[, ...])

get_summary()

save_checkpoint()

train(total_timesteps)

__init__(policy, env, algorithm, buffer_size, batch_size, writer, eval_freq, save_freq, checkpoint_dir, exp_name, use_wandb=True, eval_env=None, event_cfg=None, eval_event_cfg=None, num_eval_episodes=5, distributed=False, rank=0, world_size=1)
get_summary()
Return type:

Dict[str, Any]

save_checkpoint()
Return type:

str | None

train(total_timesteps)
Return type:

Dict[str, Any]