# Source code for embodichain.agents.rl.models

# ----------------------------------------------------------------------------
# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ----------------------------------------------------------------------------

from __future__ import annotations

import inspect
import numbers
from typing import Dict, Type

from gymnasium import spaces
import torch

from .actor_critic import ActorCritic
from .actor_only import ActorOnly
from .policy import Policy
from .mlp import MLP

# Module-level registry mapping a policy name to its class. It is populated
# by `register_policy` and read by `build_policy` / `get_policy_class`.
_POLICY_REGISTRY: Dict[str, Type[Policy]] = dict()


def register_policy(name: str, policy_cls: Type[Policy]) -> None:
    """Add ``policy_cls`` to the module-level policy registry under ``name``.

    Args:
        name: Registry key. It is stored exactly as given, with no case
            normalisation. ``build_policy`` lowercases the configured name
            before looking it up, so lowercase keys are the safe choice.
        policy_cls: Policy class to associate with ``name``.

    Raises:
        ValueError: If ``name`` already has a registered policy class.
    """
    if name not in _POLICY_REGISTRY:
        _POLICY_REGISTRY[name] = policy_cls
        return
    raise ValueError(f"Policy '{name}' is already registered")
def get_registered_policy_names() -> list[str]:
    """Return a new list of every registered policy name, in insertion order."""
    return [*_POLICY_REGISTRY]
def get_policy_class(name: str) -> Type[Policy] | None:
    """Look up the policy class registered under ``name``.

    The lookup is exact, with no case normalisation.

    Returns:
        The registered class, or ``None`` when ``name`` is unknown.
    """
    try:
        return _POLICY_REGISTRY[name]
    except KeyError:
        return None
def _resolve_space_dim(space_or_dim: spaces.Space | int, name: str) -> int:
    """Resolve a flattened feature dimension from an integer or simple Box space.

    Args:
        space_or_dim: Either an integer dimension or a ``gymnasium.spaces.Box``
            with at least one axis. The integer may be a Python ``int`` or any
            ``numbers.Integral``, such as a NumPy integer scalar.
        name: Argument name, used only in the error message.

    Returns:
        The dimension as a builtin ``int``. For a Box this is the size of
        its LAST axis (``shape[-1]``), not the product of all axes.

    Raises:
        TypeError: For any other input, including ``bool`` values and
            zero-rank Box spaces.
    """
    # bool subclasses int but is never a meaningful dimension, so exclude it.
    if isinstance(space_or_dim, numbers.Integral) and not isinstance(
        space_or_dim, bool
    ):
        # NumPy integer scalars register as numbers.Integral. Normalise them to
        # a builtin int so downstream constructors always receive a plain int.
        return int(space_or_dim)
    if isinstance(space_or_dim, spaces.Box) and len(space_or_dim.shape) > 0:
        # NOTE(review): only the last axis is used. For a multi-axis Box
        # (possibly a batched (num_envs, feat) space) this is the per-env
        # feature size -- confirm this is the intended contract.
        return int(space_or_dim.shape[-1])
    raise TypeError(
        f"{name} must be an int or a flat Box space for MLP-based policies, got {type(space_or_dim)!r}."
    )
def build_policy(
    policy_block: dict,
    obs_space: spaces.Space | int,
    action_space: spaces.Space | int,
    device: torch.device,
    actor: torch.nn.Module | None = None,
    critic: torch.nn.Module | None = None,
) -> Policy:
    """Build a policy from config using spaces for extensibility.

    Built-in MLP policies still resolve flattened `obs_dim` / `action_dim`,
    while custom policies may accept richer `obs_space` / `action_space` inputs.

    Args:
        policy_block: Policy config. Its ``"name"`` entry selects the
            registered policy class and is matched case-insensitively.
        obs_space: Observation space, or an already-resolved integer dim.
        action_space: Action space, or an already-resolved integer dim.
        device: Device forwarded to the policy constructor.
        actor: External actor module. Required by ``"actor_critic"`` and
            ``"actor_only"``; forwarded to custom policies whose ``__init__``
            accepts an ``actor`` parameter.
        critic: External critic module. Required by ``"actor_critic"``;
            forwarded to custom policies whose ``__init__`` accepts ``critic``.

    Returns:
        The constructed policy instance.

    Raises:
        KeyError: If ``policy_block`` has no ``"name"`` entry.
        ValueError: If the name is not registered, or a required external
            module is missing.
        TypeError: If a flattened dim must be resolved from an unsupported
            space (raised by ``_resolve_space_dim``).
    """
    name = policy_block["name"].lower()
    if name in _POLICY_REGISTRY:
        policy_cls = _POLICY_REGISTRY[name]
    else:
        # register_policy stores names verbatim, so a policy registered with
        # mixed case (e.g. "MyPolicy") would never match the lowercased config
        # name. Fall back to a case-insensitive match over registered keys.
        policy_cls = next(
            (
                registered_cls
                for registered_name, registered_cls in _POLICY_REGISTRY.items()
                if isinstance(registered_name, str)
                and registered_name.lower() == name
            ),
            None,
        )
        if policy_cls is None:
            available = ", ".join(get_registered_policy_names())
            raise ValueError(
                f"Policy '{name}' is not registered. Available policies: {available}"
            )

    # Built-in policies: validate the required external modules, then pass
    # flattened integer dims.
    if name == "actor_critic":
        if actor is None or critic is None:
            raise ValueError(
                "ActorCritic policy requires external 'actor' and 'critic' modules."
            )
        obs_dim = _resolve_space_dim(obs_space, "obs_space")
        action_dim = _resolve_space_dim(action_space, "action_space")
        return policy_cls(
            obs_dim=obs_dim,
            action_dim=action_dim,
            device=device,
            actor=actor,
            critic=critic,
        )
    elif name == "actor_only":
        if actor is None:
            raise ValueError("ActorOnly policy requires external 'actor' module.")
        obs_dim = _resolve_space_dim(obs_space, "obs_space")
        action_dim = _resolve_space_dim(action_space, "action_space")
        return policy_cls(
            obs_dim=obs_dim,
            action_dim=action_dim,
            device=device,
            actor=actor,
        )

    # Custom policies: inspect the constructor so each policy receives either
    # the raw space (if it declares *_space) or the flattened dim (if it
    # declares *_dim). The *_space form takes precedence when both exist.
    init_params = inspect.signature(policy_cls.__init__).parameters
    build_kwargs: dict[str, object] = {"device": device}
    if "obs_space" in init_params:
        build_kwargs["obs_space"] = obs_space
    elif "obs_dim" in init_params:
        build_kwargs["obs_dim"] = _resolve_space_dim(obs_space, "obs_space")
    if "action_space" in init_params:
        build_kwargs["action_space"] = action_space
    elif "action_dim" in init_params:
        build_kwargs["action_dim"] = _resolve_space_dim(action_space, "action_space")
    # Forward external modules only when supplied and accepted, so the
    # constructor's own defaults apply otherwise.
    if "actor" in init_params and actor is not None:
        build_kwargs["actor"] = actor
    if "critic" in init_params and critic is not None:
        build_kwargs["critic"] = critic
    return policy_cls(**build_kwargs)
def build_mlp_from_cfg(module_cfg: Dict, in_dim: int, out_dim: int) -> MLP:
    """Construct an MLP module from a minimal json-like config.

    Expected schema, with hyperparameters nested under ``network_cfg``::

        module_cfg = {
            "type": "mlp",
            "network_cfg": {
                "hidden_sizes": [256, 256],
                "activation": "relu",
            },
        }

    When ``network_cfg`` is absent, ``hidden_sizes`` and ``activation`` are
    read from the top level of ``module_cfg`` instead. This flat form is the
    schema this docstring previously documented.

    Args:
        module_cfg: Module config as described above. ``"type"`` is
            compared case-insensitively.
        in_dim: Input feature size passed to ``MLP``.
        out_dim: Output feature size passed to ``MLP``.

    Returns:
        ``MLP(in_dim, out_dim, hidden_sizes, activation)``.

    Raises:
        ValueError: If ``"type"`` is missing, ``None``, or not ``"mlp"``.
        KeyError: If ``hidden_sizes`` or ``activation`` is missing.
    """
    # Treat a missing or None "type" as empty so it reports the intended
    # ValueError rather than crashing on `.lower()` of None.
    module_type = module_cfg.get("type") or ""
    if str(module_type).lower() != "mlp":
        raise ValueError("Only 'mlp' type is supported for actor/critic in this setup.")
    network_cfg = module_cfg.get("network_cfg")
    if network_cfg is None:
        network_cfg = module_cfg
    hidden_sizes = network_cfg["hidden_sizes"]
    activation = network_cfg["activation"]
    return MLP(in_dim, out_dim, hidden_sizes, activation)
# Register the built-in policies that ship with this package, in a fixed order.
for _builtin_name, _builtin_cls in (
    ("actor_critic", ActorCritic),
    ("actor_only", ActorOnly),
):
    register_policy(_builtin_name, _builtin_cls)
# Drop the loop variables so they do not leak into the module namespace.
del _builtin_name, _builtin_cls

__all__ = [
    "ActorCritic",
    "ActorOnly",
    "register_policy",
    "get_registered_policy_names",
    "build_policy",
    "build_mlp_from_cfg",
    "get_policy_class",
    "Policy",
    "MLP",
]