
MLP

Bases: Module

Multi-Layer Perceptron (MLP) module.

This module consists of two linear layers with an activation function in between. It supports configuring the hidden size and activation function, optionally initializing the output layer weights to zero, and optionally recomputing the forward pass during backpropagation to save memory.
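
For orientation, the computation is linear2(activation(linear1(x))). Below is a minimal sketch of that structure in plain PyTorch; it is not this module's actual implementation, the GELU activation is just an example, and the checkpoint-based recompute is only one way to realize the documented recompute option.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class SketchMLP(nn.Module):
    # Illustrative, simplified stand-in for the documented MLP (hypothetical).
    def __init__(self, size, hidden_size, initialize_output_to_zero=False,
                 recompute=False, device=None, dtype=None):
        super().__init__()
        self.linear1 = nn.Linear(size, hidden_size, device=device, dtype=dtype)
        self.linear2 = nn.Linear(hidden_size, size, device=device, dtype=dtype)
        self.recompute = recompute
        if initialize_output_to_zero:
            # Only the output layer weights are zeroed, per the documented option.
            nn.init.zeros_(self.linear2.weight)

    def _mlp(self, x):
        # Two linear layers with an activation in between (GELU as an example).
        return self.linear2(torch.nn.functional.gelu(self.linear1(x)))

    def forward(self, x):
        if self.recompute:
            # Trade compute for memory: re-run the forward pass during backpropagation.
            return checkpoint(self._mlp, x, use_reentrant=False)
        return self._mlp(x)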

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| size | int | The input and output size of the MLP. | required |
| hidden_size | int | The size of the hidden layer. | required |
| activation | Union[Activation, str] | The activation function to use, either an Activation enum member or a string naming the activation. | required |
| device | device | The device to use for the linear layers. | required |
| dtype | dtype | The data type to use for the linear layers. | required |
| initialize_output_to_zero | bool | Whether to initialize the output layer weights to zero. | False |
| recompute | bool | Whether to recompute the forward pass during backpropagation. This can save memory but increases computation time. | False |

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| linear1 | Linear | The first linear layer. |
| linear2 | Linear | The second linear layer. |
| activation | Activation | The activation function to use. |

Methods:

| Name | Description |
| --- | --- |
| forward | Performs the forward pass of the MLP; its parameters are listed below. |

Parameters of forward:

- x (torch.Tensor): The input tensor.
- add_input (bool): Whether to add the input to the output. Default is False.
- allow_inplace (bool): Indicates that 'x' is not used after the call and its buffer can be reused for the output. The operation is not guaranteed to be in-place. Default is False.
- save_peak_mem_factor (Optional[int]): If provided, enables a memory-saving technique that reduces peak memory usage during the forward pass. This requires 'add_input' and 'allow_inplace' to be True. See the documentation of the decorator 'support_save_peak_mem_factor' for details. Default is None.

Example
>>> mlp = MLP(size=128, hidden_size=256, activation='gelu', device='cuda', dtype=torch.float32)
>>> x = torch.randn(32, 128, device='cuda', dtype=torch.float32)
>>> output = mlp(x)

__init__

__init__(
    size: int,
    hidden_size: int,
    activation: Union[Activation, str],
    device,
    dtype,
    initialize_output_to_zero: bool = False,
    recompute: bool = False,
)
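
For illustration, a construction that enables the optional flags; the values are arbitrary, and the 'gelu' activation name follows the example above.

>>> mlp = MLP(
...     size=128,
...     hidden_size=256,
...     activation='gelu',
...     device='cpu',
...     dtype=torch.float32,
...     initialize_output_to_zero=True,
...     recompute=True,
... )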

forward

forward(
    x: Tensor,
    add_input: bool = False,
    allow_inplace: bool = False,
    save_peak_mem_factor: Optional[int] = None,
) -> Tensor
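
As noted in the forward description, save_peak_mem_factor requires add_input and allow_inplace to be True. A usage sketch, continuing from the construction above (the factor value 8 is illustrative):

>>> x = torch.randn(32, 128, dtype=torch.float32)
>>> output = mlp(x, add_input=True, allow_inplace=True, save_peak_mem_factor=8)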