# Source code for embodichain.agents.rl.models

# ----------------------------------------------------------------------------
# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ----------------------------------------------------------------------------

from __future__ import annotations

from typing import Dict, Type
import torch
from gymnasium import spaces

from .actor_critic import ActorCritic
from .actor_only import ActorOnly
from .policy import Policy
from .mlp import MLP

# In-module policy registry: maps a policy name to its Policy class.
# Entries are added by ``register_policy`` (names are stored exactly as given)
# and looked up by ``get_policy_class`` and ``build_policy``.
_POLICY_REGISTRY: Dict[str, Type[Policy]] = {}


def register_policy(name: str, policy_cls: Type[Policy]) -> None:
    """Add ``policy_cls`` to the module-level policy registry under ``name``.

    The name is stored exactly as passed in; no case normalisation is applied.

    Args:
        name: Key under which the policy class is stored.
        policy_cls: Policy class to associate with ``name``.

    Raises:
        ValueError: If ``name`` is already present in the registry. The
            existing entry is left untouched.
    """
    if name not in _POLICY_REGISTRY:
        _POLICY_REGISTRY[name] = policy_cls
        return
    # Never overwrite an existing registration silently.
    raise ValueError(f"Policy '{name}' is already registered")
def get_registered_policy_names() -> list[str]:
    """Return the names currently in the policy registry.

    Returns:
        A new list of registry keys, in the registry dict's iteration order.
        Mutating the returned list does not affect the registry.
    """
    # Iterating a dict yields its keys, so this is equivalent to listing .keys().
    return list(_POLICY_REGISTRY)
def get_policy_class(name: str) -> Type[Policy] | None:
    """Look up the policy class registered under ``name``.

    Args:
        name: Registry key to look up; matched exactly as given.

    Returns:
        The registered policy class, or ``None`` when ``name`` is not in the
        registry.
    """
    try:
        return _POLICY_REGISTRY[name]
    except KeyError:
        # Unknown names are reported with None rather than an exception.
        return None
def build_policy(
    policy_block: dict,
    obs_space: spaces.Space,
    action_space: spaces.Space,
    device: torch.device,
    actor: torch.nn.Module | None = None,
    critic: torch.nn.Module | None = None,
) -> Policy:
    """Build a registered policy from a json-like block ``{ name: ..., cfg: {...} }``.

    Only ``policy_block["name"]`` is read by this function; a ``"cfg"`` entry,
    if present, is not consulted here.

    The name is first looked up exactly as given, then lowercased. The exact
    lookup matters because ``register_policy`` stores names verbatim, so a
    policy registered under a name containing uppercase characters would
    otherwise be unreachable.

    Args:
        policy_block: Mapping with at least a ``"name"`` key.
        obs_space: Observation space forwarded to the policy constructor.
        action_space: Action space forwarded to the policy constructor.
        device: Device forwarded to the policy constructor.
        actor: Externally built actor module; required for ``"actor_critic"``
            and ``"actor_only"``.
        critic: Externally built critic module; required for
            ``"actor_critic"``.

    Returns:
        The constructed policy instance.

    Raises:
        KeyError: If ``policy_block`` has no ``"name"`` key.
        ValueError: If the name is not registered, or a module required by
            the selected built-in policy is missing.
    """
    requested = policy_block["name"]
    # Prefer an exact match; fall back to the historical lowercase lookup.
    name = requested if requested in _POLICY_REGISTRY else requested.lower()
    if name not in _POLICY_REGISTRY:
        available = ", ".join(get_registered_policy_names())
        raise ValueError(
            f"Policy '{name}' is not registered. Available policies: {available}"
        )

    policy_cls = _POLICY_REGISTRY[name]
    if name == "actor_critic":
        if actor is None or critic is None:
            raise ValueError(
                "ActorCritic policy requires external 'actor' and 'critic' modules."
            )
        return policy_cls(obs_space, action_space, device, actor=actor, critic=critic)
    if name == "actor_only":
        if actor is None:
            raise ValueError("ActorOnly policy requires external 'actor' module.")
        return policy_cls(obs_space, action_space, device, actor=actor)
    # Any other registered policy is built from the spaces and device alone.
    return policy_cls(obs_space, action_space, device)
def build_mlp_from_cfg(module_cfg: Dict, in_dim: int, out_dim: int) -> MLP:
    """Construct an MLP module from a minimal json-like config.

    Expected schema::

        module_cfg = {
            "type": "mlp",
            "network_cfg": {
                "hidden_sizes": [256, 256],
                "activation": "relu",
            },
        }

    Args:
        module_cfg: Config mapping following the schema above. ``"type"`` is
            compared case-insensitively.
        in_dim: Passed to ``MLP`` as its first positional argument.
        out_dim: Passed to ``MLP`` as its second positional argument.

    Returns:
        ``MLP(in_dim, out_dim, hidden_sizes, activation)``.

    Raises:
        ValueError: If ``"type"`` is missing or is anything other than
            ``"mlp"`` (including non-string values such as ``None``).
        KeyError: If ``"network_cfg"``, ``"hidden_sizes"`` or ``"activation"``
            is missing.
    """
    # str() keeps a non-string type (e.g. JSON null -> None) on the ValueError
    # path instead of failing with AttributeError on ``.lower()``.
    module_type = str(module_cfg.get("type", "")).lower()
    if module_type != "mlp":
        raise ValueError("Only 'mlp' type is supported for actor/critic in this setup.")

    network_cfg = module_cfg["network_cfg"]
    hidden_sizes = network_cfg["hidden_sizes"]
    activation = network_cfg["activation"]
    return MLP(in_dim, out_dim, hidden_sizes, activation)
# Built-in policies are registered when this module is imported.
register_policy("actor_critic", ActorCritic)
register_policy("actor_only", ActorOnly)

# Names exported by ``from <this module> import *``.
__all__ = [
    "ActorCritic",
    "ActorOnly",
    "register_policy",
    "get_registered_policy_names",
    "build_policy",
    "build_mlp_from_cfg",
    "get_policy_class",
    "Policy",
    "MLP",
]