Source code for embodichain.agents.rl.models.mlp

# ----------------------------------------------------------------------------
# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ----------------------------------------------------------------------------

from __future__ import annotations

from functools import reduce
from typing import List, Sequence, Union

import torch
import torch.nn as nn


ActivationName = Union[str, None]


def _resolve_activation(name: ActivationName) -> nn.Module:
    if name is None:
        return nn.Identity()
    name_l = str(name).lower()
    if name_l == "relu":
        return nn.ReLU()
    if name_l == "elu":
        return nn.ELU()
    if name_l == "tanh":
        return nn.Tanh()
    if name_l == "gelu":
        return nn.GELU()
    if name_l in ("silu", "swish"):
        return nn.SiLU()
    # Unknown names fall back to ReLU rather than raising.
    return nn.ReLU()
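
# A quick self-check sketch (illustrative, not part of the module): unknown
# names are silently absorbed by the ReLU fallback rather than raising.
#
#     >>> _resolve_activation("silu")
#     SiLU()
#     >>> _resolve_activation("rellu")   # typo -> ReLU fallback
#     ReLU()
#     >>> _resolve_activation(None)
#     Identity()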


class MLP(nn.Sequential):
    """General MLP supporting a custom last activation, orthogonal init, and output reshape.

    Args:
        input_dim: Input dimension.
        output_dim: Output dimension (an int, or a shape tuple/list).
        hidden_dims: Hidden layer sizes, e.g. ``[256, 256]``.
        activation: Hidden-layer activation name (relu/elu/tanh/gelu/silu).
        last_activation: Last-layer activation name, or None for a linear output.
        use_layernorm: Whether to add LayerNorm after each hidden linear layer.
        dropout_p: Dropout probability for hidden layers (0 disables dropout).
    """
    def __init__(
        self,
        input_dim: int,
        output_dim: Union[int, Sequence[int]],
        hidden_dims: Sequence[int],
        activation: ActivationName = "elu",
        last_activation: ActivationName = None,
        use_layernorm: bool = False,
        dropout_p: float = 0.0,
    ) -> None:
        super().__init__()

        layers: List[nn.Module] = []
        dims = [input_dim] + list(hidden_dims)

        # Hidden stack: Linear -> (LayerNorm) -> activation -> (Dropout).
        for in_d, out_d in zip(dims[:-1], dims[1:]):
            layers.append(nn.Linear(in_d, out_d))
            if use_layernorm:
                layers.append(nn.LayerNorm(out_d))
            layers.append(_resolve_activation(activation))
            if dropout_p and dropout_p > 0.0:
                layers.append(nn.Dropout(p=dropout_p))

        # Output layer: an int maps to a plain Linear; a shape tuple maps to a
        # Linear over the flattened size followed by an Unflatten back to shape.
        if isinstance(output_dim, int):
            layers.append(nn.Linear(dims[-1], output_dim))
        else:
            total_out = int(reduce(lambda a, b: a * b, output_dim))
            layers.append(nn.Linear(dims[-1], total_out))
            layers.append(nn.Unflatten(dim=-1, unflattened_size=tuple(output_dim)))

        if last_activation is not None:
            layers.append(_resolve_activation(last_activation))

        for idx, layer in enumerate(layers):
            self.add_module(str(idx), layer)
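
    # A quick shape check (an illustrative sketch, not part of the class).
    # With a tuple output_dim, the trailing Unflatten reshapes the flat
    # features, so a (B, input_dim) input maps to (B, *output_dim):
    #
    #     >>> net = MLP(input_dim=10, output_dim=(4, 3), hidden_dims=[64, 64])
    #     >>> x = torch.randn(2, 10)
    #     >>> net(x).shape
    #     torch.Size([2, 4, 3])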
    def init_orthogonal(self, scales: Union[float, Sequence[float]] = 1.0) -> None:
        """Orthogonally initialize every linear layer and zero its bias.

        Args:
            scales: A single gain value, or a sequence whose length equals the
                number of linear layers.
        """

        def get_scale(i: int) -> float:
            if isinstance(scales, (list, tuple)):
                return float(scales[i])
            return float(scales)

        lin_idx = 0
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.orthogonal_(m.weight, gain=get_scale(lin_idx))
                nn.init.zeros_(m.bias)
                lin_idx += 1
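
# Example wiring (a hedged sketch; the gain values below are illustrative
# choices, not prescribed by this module). This network has three Linear
# layers, so init_orthogonal accepts three gains, one per layer:
#
#     >>> policy = MLP(10, 6, hidden_dims=[64, 64], activation="tanh")
#     >>> policy.init_orthogonal(scales=[2.0**0.5, 2.0**0.5, 0.01])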