Skip to content

Common Modules

Embeddings

Bases: Module

Enables different values in categorical features to have different embeddings.

Source code in src/pytorch_tabular/models/common/layers/embeddings.py
class Embedding1dLayer(nn.Module):
    """Enables different values in categorical features to have different embeddings.

    The (optionally batch-normalized) continuous features and one learned embedding per
    categorical feature are concatenated along dim 1 into a single flat vector per row.
    """

    def __init__(
        self,
        continuous_dim: int,
        categorical_embedding_dims: List[Tuple[int, int]],
        embedding_dropout: float = 0.0,
        batch_norm_continuous_input: bool = False,
        virtual_batch_size: Optional[int] = None,
    ):
        """
        Args:
            continuous_dim: number of continuous features
            categorical_embedding_dims: one ``(cardinality, embedding_dim)`` pair per categorical
                feature; each pair becomes an ``nn.Embedding(cardinality, embedding_dim)``
            embedding_dropout: dropout probability applied to the concatenated output; no dropout
                layer is created when it is 0
            batch_norm_continuous_input: whether to batch-normalize the continuous features
            virtual_batch_size: virtual batch size forwarded to ``BatchNorm1d`` when
                ``batch_norm_continuous_input`` is True
        """
        super().__init__()
        self.continuous_dim = continuous_dim
        self.categorical_embedding_dims = categorical_embedding_dims
        self.batch_norm_continuous_input = batch_norm_continuous_input

        # One embedding table per categorical feature, in column order
        self.cat_embedding_layers = nn.ModuleList([nn.Embedding(x, y) for x, y in categorical_embedding_dims])
        if embedding_dropout > 0:
            self.embd_dropout = nn.Dropout(embedding_dropout)
        else:
            self.embd_dropout = None
        # Continuous Layers
        if batch_norm_continuous_input:
            self.normalizing_batch_norm = BatchNorm1d(continuous_dim, virtual_batch_size)

    def forward(self, x: Dict[str, Any]) -> torch.Tensor:
        """Concatenate continuous features with the categorical embeddings.

        Args:
            x: dict with an optional ``"continuous"`` tensor (one column per continuous feature)
                and an optional ``"categorical"`` tensor of integer codes (one column per
                categorical feature). At least one key must be present; a missing key is
                treated as a tensor with zero columns.

        Returns:
            Tensor of shape (B, continuous_dim + sum of categorical embedding dims).
        """
        assert "continuous" in x or "categorical" in x, "x must contain either continuous and categorical features"
        # (B, N)
        continuous_data, categorical_data = (
            x.get("continuous", torch.empty(0, 0)),
            x.get("categorical", torch.empty(0, 0)),
        )
        assert categorical_data.shape[1] == len(
            self.cat_embedding_layers
        ), "categorical_data must have same number of columns as categorical embedding layers"
        assert (
            continuous_data.shape[1] == self.continuous_dim
        ), "continuous_data must have same number of columns as continuous dim"
        embed = None
        if continuous_data.shape[1] > 0:
            if self.batch_norm_continuous_input:
                embed = self.normalizing_batch_norm(continuous_data)
            else:
                embed = continuous_data
            # (B, continuous_dim)
        if categorical_data.shape[1] > 0:
            # Each nn.Embedding maps the (B,) codes of one column to (B, embedding_dim);
            # concatenating along dim 1 gives (B, sum of embedding dims).
            categorical_embed = torch.cat(
                [
                    embedding_layer(categorical_data[:, i])
                    for i, embedding_layer in enumerate(self.cat_embedding_layers)
                ],
                dim=1,
            )
            # (B, continuous_dim + sum of embedding dims)
            if embed is None:
                embed = categorical_embed
            else:
                embed = torch.cat([embed, categorical_embed], dim=1)
        if self.embd_dropout is not None:
            embed = self.embd_dropout(embed)
        return embed

Bases: Module

Embeds categorical and continuous features into a 2D tensor.

Source code in src/pytorch_tabular/models/common/layers/embeddings.py
class Embedding2dLayer(nn.Module):
    """Embeds categorical and continuous features into a 2D tensor.

    Every feature of a row becomes one ``embedding_dim`` vector, so a batch is embedded as
    (B, continuous_dim + n_categorical, embedding_dim). Continuous feature ``j`` is its value
    times a learned vector; categorical features are looked up in per-feature tables.
    """

    def __init__(
        self,
        continuous_dim: int,
        categorical_cardinality: List[int],
        embedding_dim: int,
        shared_embedding_strategy: Optional[str] = None,
        frac_shared_embed: float = 0.25,
        embedding_bias: bool = False,
        batch_norm_continuous_input: bool = False,
        virtual_batch_size: Optional[int] = None,
        embedding_dropout: float = 0.0,
        initialization: Optional[str] = None,
    ):
        """
        Args:
            continuous_dim: number of continuous features
            categorical_cardinality: list of cardinalities of categorical features
            embedding_dim: embedding dimension
            shared_embedding_strategy: strategy to use for shared embeddings. ``None`` uses a plain
                ``nn.Embedding`` per feature; otherwise ``SharedEmbeddings`` is used, adding the
                shared vector when the strategy is ``"add"`` and overwriting the leading
                ``frac_shared_embed`` fraction of each embedding with it for any other value
            frac_shared_embed: fraction of embeddings to share
            embedding_bias: whether to add a learned per-feature bias to every embedding
            batch_norm_continuous_input: whether to use batch norm on continuous features
            virtual_batch_size: virtual batch size forwarded to ``BatchNorm1d`` when
                ``batch_norm_continuous_input`` is True
            embedding_dropout: dropout to apply to embeddings
            initialization: ``None``, ``"kaiming_uniform"`` or ``"kaiming_normal"``; when set, it is
                applied to every embedding table and bias. Biases start at zero otherwise.
        """
        super().__init__()
        self.continuous_dim = continuous_dim
        self.categorical_cardinality = categorical_cardinality
        self.embedding_dim = embedding_dim
        self.batch_norm_continuous_input = batch_norm_continuous_input
        self.shared_embedding_strategy = shared_embedding_strategy
        self.frac_shared_embed = frac_shared_embed
        self.embedding_bias = embedding_bias
        self.initialization = initialization
        d_sqrt_inv = 1 / math.sqrt(embedding_dim)
        if initialization is not None:
            assert initialization in [
                "kaiming_uniform",
                "kaiming_normal",
            ], "initialization should be either `kaiming_uniform` or `kaiming_normal`"
            self._do_kaiming_initialization = True
            self._initialize_kaiming = partial(
                _initialize_kaiming,
                initialization=initialization,
                d_sqrt_inv=d_sqrt_inv,
            )
        else:
            self._do_kaiming_initialization = False

        # cat Embedding layers
        if self.shared_embedding_strategy is not None:
            self.cat_embedding_layers = nn.ModuleList(
                [
                    SharedEmbeddings(
                        c,
                        self.embedding_dim,
                        add_shared_embed=(self.shared_embedding_strategy == "add"),
                        frac_shared_embed=self.frac_shared_embed,
                    )
                    for c in categorical_cardinality
                ]
            )
            if self._do_kaiming_initialization:
                for embedding_layer in self.cat_embedding_layers:
                    self._initialize_kaiming(embedding_layer.embed.weight)
                    self._initialize_kaiming(embedding_layer.shared_embed)
        else:
            self.cat_embedding_layers = nn.ModuleList(
                [nn.Embedding(c, self.embedding_dim) for c in categorical_cardinality]
            )
            if self._do_kaiming_initialization:
                for embedding_layer in self.cat_embedding_layers:
                    self._initialize_kaiming(embedding_layer.weight)
        if embedding_bias:
            # torch.Tensor(n, d) would be uninitialized memory that nothing overwrites when
            # `initialization` is None; start from zeros so the bias is always well-defined.
            self.cat_embedding_bias = nn.Parameter(torch.zeros(len(self.categorical_cardinality), self.embedding_dim))
            if self._do_kaiming_initialization:
                self._initialize_kaiming(self.cat_embedding_bias)
        # Continuous Embedding Layer: row j is the learned direction for continuous feature j
        self.cont_embedding_layer = nn.Embedding(self.continuous_dim, self.embedding_dim)
        if self._do_kaiming_initialization:
            self._initialize_kaiming(self.cont_embedding_layer.weight)
        if embedding_bias:
            # Same reasoning as the categorical bias above.
            self.cont_embedding_bias = nn.Parameter(torch.zeros(self.continuous_dim, self.embedding_dim))
            if self._do_kaiming_initialization:
                self._initialize_kaiming(self.cont_embedding_bias)
        if batch_norm_continuous_input:
            self.normalizing_batch_norm = BatchNorm1d(continuous_dim, virtual_batch_size)
        if embedding_dropout > 0:
            self.embd_dropout = nn.Dropout(embedding_dropout)
        else:
            self.embd_dropout = None

    def forward(self, x: Dict[str, Any]) -> torch.Tensor:
        """Embed every feature of every row.

        Args:
            x: dict with an optional ``"continuous"`` tensor (one column per continuous feature)
                and an optional ``"categorical"`` tensor of integer codes (one column per
                categorical feature). At least one key must be present.

        Returns:
            Tensor of shape (B, continuous_dim + n_categorical, embedding_dim).
        """
        assert "continuous" in x or "categorical" in x, "x must contain either continuous and categorical features"
        # (B, N)
        continuous_data, categorical_data = (
            x.get("continuous", torch.empty(0, 0)),
            x.get("categorical", torch.empty(0, 0)),
        )
        assert categorical_data.shape[1] == len(
            self.cat_embedding_layers
        ), "categorical_data must have same number of columns as categorical embedding layers"
        assert (
            continuous_data.shape[1] == self.continuous_dim
        ), "continuous_data must have same number of columns as continuous dim"
        embed = None
        if continuous_data.shape[1] > 0:
            # Index j for every continuous column, repeated for each row: (B, N)
            cont_idx = torch.arange(self.continuous_dim, device=continuous_data.device).expand(
                continuous_data.size(0), -1
            )
            if self.batch_norm_continuous_input:
                continuous_data = self.normalizing_batch_norm(continuous_data)
            # Scale each feature's learned vector by the feature value: (B, N, 1) * (B, N, C)
            embed = torch.mul(
                continuous_data.unsqueeze(2),
                self.cont_embedding_layer(cont_idx),
            )
            if self.embedding_bias:
                embed += self.cont_embedding_bias
            # (B, N, C)
        if categorical_data.shape[1] > 0:
            categorical_embed = torch.cat(
                [
                    embedding_layer(categorical_data[:, i]).unsqueeze(1)
                    for i, embedding_layer in enumerate(self.cat_embedding_layers)
                ],
                dim=1,
            )
            if self.embedding_bias:
                categorical_embed += self.cat_embedding_bias
            # (B, N_cont + N_cat, C)
            if embed is None:
                embed = categorical_embed
            else:
                embed = torch.cat([embed, categorical_embed], dim=1)
        if self.embd_dropout is not None:
            embed = self.embd_dropout(embed)
        return embed

__init__(continuous_dim, categorical_cardinality, embedding_dim, shared_embedding_strategy=None, frac_shared_embed=0.25, embedding_bias=False, batch_norm_continuous_input=False, virtual_batch_size=None, embedding_dropout=0.0, initialization=None)

Parameters:

Name Type Description Default
continuous_dim int

number of continuous features

required
categorical_cardinality List[int]

list of cardinalities of categorical features

required
embedding_dim int

embedding dimension

required
shared_embedding_strategy Optional[str]

strategy to use for shared embeddings

None
frac_shared_embed float

fraction of embeddings to share

0.25
embedding_bias bool

whether to use bias in embedding layers

False
batch_norm_continuous_input bool

whether to use batch norm on continuous features

False
embedding_dropout float

dropout to apply to embeddings

0.0
initialization Optional[str]

initialization strategy to use for embedding layers

None
Source code in src/pytorch_tabular/models/common/layers/embeddings.py
def __init__(
    self,
    continuous_dim: int,
    categorical_cardinality: List[int],
    embedding_dim: int,
    shared_embedding_strategy: Optional[str] = None,
    frac_shared_embed: float = 0.25,
    embedding_bias: bool = False,
    batch_norm_continuous_input: bool = False,
    virtual_batch_size: Optional[int] = None,
    embedding_dropout: float = 0.0,
    initialization: Optional[str] = None,
):
    """
    Args:
        continuous_dim: number of continuous features
        categorical_cardinality: list of cardinalities of categorical features
        embedding_dim: embedding dimension
        shared_embedding_strategy: strategy to use for shared embeddings; ``None`` uses a plain
            ``nn.Embedding`` per feature, otherwise ``SharedEmbeddings`` is used (``"add"`` adds
            the shared vector, any other value overwrites part of each embedding with it)
        frac_shared_embed: fraction of embeddings to share
        embedding_bias: whether to add a learned per-feature bias to every embedding
        batch_norm_continuous_input: whether to use batch norm on continuous features
        virtual_batch_size: virtual batch size forwarded to ``BatchNorm1d`` when
            ``batch_norm_continuous_input`` is True
        embedding_dropout: dropout to apply to embeddings
        initialization: ``None``, ``"kaiming_uniform"`` or ``"kaiming_normal"``; when set, it is
            applied to every embedding table and bias. Biases start at zero otherwise.
    """
    super().__init__()
    self.continuous_dim = continuous_dim
    self.categorical_cardinality = categorical_cardinality
    self.embedding_dim = embedding_dim
    self.batch_norm_continuous_input = batch_norm_continuous_input
    self.shared_embedding_strategy = shared_embedding_strategy
    self.frac_shared_embed = frac_shared_embed
    self.embedding_bias = embedding_bias
    self.initialization = initialization
    d_sqrt_inv = 1 / math.sqrt(embedding_dim)
    if initialization is not None:
        assert initialization in [
            "kaiming_uniform",
            "kaiming_normal",
        ], "initialization should be either `kaiming_uniform` or `kaiming_normal`"
        self._do_kaiming_initialization = True
        self._initialize_kaiming = partial(
            _initialize_kaiming,
            initialization=initialization,
            d_sqrt_inv=d_sqrt_inv,
        )
    else:
        self._do_kaiming_initialization = False

    # cat Embedding layers
    if self.shared_embedding_strategy is not None:
        self.cat_embedding_layers = nn.ModuleList(
            [
                SharedEmbeddings(
                    c,
                    self.embedding_dim,
                    add_shared_embed=(self.shared_embedding_strategy == "add"),
                    frac_shared_embed=self.frac_shared_embed,
                )
                for c in categorical_cardinality
            ]
        )
        if self._do_kaiming_initialization:
            for embedding_layer in self.cat_embedding_layers:
                self._initialize_kaiming(embedding_layer.embed.weight)
                self._initialize_kaiming(embedding_layer.shared_embed)
    else:
        self.cat_embedding_layers = nn.ModuleList(
            [nn.Embedding(c, self.embedding_dim) for c in categorical_cardinality]
        )
        if self._do_kaiming_initialization:
            for embedding_layer in self.cat_embedding_layers:
                self._initialize_kaiming(embedding_layer.weight)
    if embedding_bias:
        # torch.Tensor(n, d) would be uninitialized memory that nothing overwrites when
        # `initialization` is None; start from zeros so the bias is always well-defined.
        self.cat_embedding_bias = nn.Parameter(torch.zeros(len(self.categorical_cardinality), self.embedding_dim))
        if self._do_kaiming_initialization:
            self._initialize_kaiming(self.cat_embedding_bias)
    # Continuous Embedding Layer: row j is the learned direction for continuous feature j
    self.cont_embedding_layer = nn.Embedding(self.continuous_dim, self.embedding_dim)
    if self._do_kaiming_initialization:
        self._initialize_kaiming(self.cont_embedding_layer.weight)
    if embedding_bias:
        # Same reasoning as the categorical bias above.
        self.cont_embedding_bias = nn.Parameter(torch.zeros(self.continuous_dim, self.embedding_dim))
        if self._do_kaiming_initialization:
            self._initialize_kaiming(self.cont_embedding_bias)
    if batch_norm_continuous_input:
        self.normalizing_batch_norm = BatchNorm1d(continuous_dim, virtual_batch_size)
    if embedding_dropout > 0:
        self.embd_dropout = nn.Dropout(embedding_dropout)
    else:
        self.embd_dropout = None

Bases: Module

Takes in pre-encoded categorical variables and concatenates them with continuous variables. It has no learnable component.

Source code in src/pytorch_tabular/models/common/layers/embeddings.py
class PreEncoded1dLayer(nn.Module):
    """Takes in pre-encoded categorical variables and concatenates them with continuous variables.

    No learnable component: the categorical columns are used as-is, so they are expected to
    already be numeric encodings that can be concatenated with the continuous features.
    """

    def __init__(
        self,
        continuous_dim: int,
        categorical_dim: int,
        embedding_dropout: float = 0.0,
        batch_norm_continuous_input: bool = False,
        virtual_batch_size: Optional[int] = None,
    ):
        """
        Args:
            continuous_dim: number of continuous features
            categorical_dim: number of pre-encoded categorical columns
            embedding_dropout: dropout probability applied to the concatenated output; no dropout
                layer is created when it is 0
            batch_norm_continuous_input: whether to batch-normalize the continuous features
            virtual_batch_size: virtual batch size forwarded to ``BatchNorm1d`` when
                ``batch_norm_continuous_input`` is True
        """
        super().__init__()
        self.continuous_dim = continuous_dim
        self.categorical_dim = categorical_dim
        self.batch_norm_continuous_input = batch_norm_continuous_input

        if embedding_dropout > 0:
            self.embd_dropout = nn.Dropout(embedding_dropout)
        else:
            self.embd_dropout = None
        # Continuous Layers
        if batch_norm_continuous_input:
            self.normalizing_batch_norm = BatchNorm1d(continuous_dim, virtual_batch_size)

    def forward(self, x: Dict[str, Any]) -> torch.Tensor:
        """Concatenate continuous and pre-encoded categorical columns along dim 1.

        Args:
            x: dict with an optional ``"continuous"`` tensor (``continuous_dim`` columns) and an
                optional ``"categorical"`` tensor (``categorical_dim`` columns). At least one key
                must be present; a missing key is treated as a tensor with zero columns.

        Returns:
            Tensor of shape (B, continuous_dim + categorical_dim).
        """
        assert "continuous" in x or "categorical" in x, "x must contain either continuous and categorical features"
        # (B, N)
        continuous_data, categorical_data = (
            x.get("continuous", torch.empty(0, 0)),
            x.get("categorical", torch.empty(0, 0)),
        )
        assert (
            categorical_data.shape[1] == self.categorical_dim
        ), "categorical_data must have same number of columns as categorical embedding layers"
        assert (
            continuous_data.shape[1] == self.continuous_dim
        ), "continuous_data must have same number of columns as continuous dim"
        embed = None
        if continuous_data.shape[1] > 0:
            if self.batch_norm_continuous_input:
                embed = self.normalizing_batch_norm(continuous_data)
            else:
                embed = continuous_data
            # (B, continuous_dim)
        if categorical_data.shape[1] > 0:
            # (B, continuous_dim + categorical_dim)
            if embed is None:
                embed = categorical_data
            else:
                embed = torch.cat([embed, categorical_data], dim=1)
        if self.embd_dropout is not None:
            embed = self.embd_dropout(embed)
        return embed

Bases: Module

Enables different values in a categorical feature to share some embeddings across.

Source code in src/pytorch_tabular/models/common/layers/embeddings.py
class SharedEmbeddings(nn.Module):
    """Enables different values in a categorical feature to share some embeddings across.

    A single learnable vector (``shared_embed``) is common to every value of the feature.
    With ``add_shared_embed`` it is added to each value's embedding; otherwise it overwrites
    the first ``int(embed_dim * frac_shared_embed)`` columns of each value's embedding.
    """

    def __init__(
        self,
        num_embed: int,
        embed_dim: int,
        add_shared_embed: bool = False,
        frac_shared_embed: float = 0.25,
    ):
        """
        Args:
            num_embed: number of distinct values (cardinality) of the feature
            embed_dim: embedding dimension
            add_shared_embed: add the shared vector instead of overwriting part of the embedding
            frac_shared_embed: fraction of the embedding overwritten by the shared vector when
                ``add_shared_embed`` is False; must be < 1
        """
        super().__init__()
        assert frac_shared_embed < 1, "'frac_shared_embed' must be less than 1"

        self.add_shared_embed = add_shared_embed
        # Index 0 is the padding index: its row is zero and receives no gradient.
        self.embed = nn.Embedding(num_embed, embed_dim, padding_idx=0)
        self.embed.weight.data.clamp_(-2, 2)
        if add_shared_embed:
            col_embed_dim = embed_dim
        else:
            col_embed_dim = int(embed_dim * frac_shared_embed)
        self.shared_embed = nn.Parameter(torch.empty(1, col_embed_dim).uniform_(-1, 1))

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Embed a 1-D tensor of codes ``X`` of shape (B,) into (B, embed_dim)."""
        out = self.embed(X)
        shared_embed = self.shared_embed.expand(out.shape[0], -1)
        if self.add_shared_embed:
            out += shared_embed
        else:
            out[:, : shared_embed.shape[1]] = shared_embed
        return out

    @property
    def weight(self):
        """Full (num_embed, embed_dim) table with the shared component applied.

        Built on a copy of ``embed.weight``: ``detach()`` alone shares storage with the
        parameter, so the in-place ops below would otherwise silently overwrite the learned
        embeddings every time this property is read. The result stays connected to
        ``shared_embed`` in autograd, as before.
        """
        w = self.embed.weight.detach().clone()
        if self.add_shared_embed:
            w += self.shared_embed
        else:
            w[:, : self.shared_embed.shape[1]] = self.shared_embed
        return w

Gated Units

Bases: Module

Source code in src/pytorch_tabular/models/common/layers/gated_units.py
class GatedFeatureLearningUnit(nn.Module):
    """Stack of gated stages that learn soft feature selections.

    Each stage applies ``feature_mask_function`` to its own learnable mask row, multiplies the
    input by that mask, and updates a running hidden state using sigmoid update/reset gates
    and a tanh candidate. Input and output widths are both ``n_features_in``.
    """

    def __init__(
        self,
        n_features_in: int,
        n_stages: int,
        feature_mask_function: Callable = entmax15,
        feature_sparsity: float = 0.3,
        learnable_sparsity: bool = True,
        dropout: float = 0.0,
    ):
        super().__init__()
        # The unit preserves width: n_features_in -> n_features_in.
        self.n_features_in = n_features_in
        self.n_features_out = n_features_in
        self.n_stages = n_stages
        self.feature_mask_function = feature_mask_function
        self.feature_sparsity = feature_sparsity
        self.learnable_sparsity = learnable_sparsity
        self._dropout = dropout
        self._build_network()

    def _uses_t_softmax(self) -> bool:
        # t_softmax takes an extra per-stage temperature; other mask functions take only the logits.
        return self.feature_mask_function.__name__ == "t_softmax"

    def _create_feature_mask(self):
        rows = []
        for _ in range(self.n_stages):
            # Draw alpha before beta so python's `random` stream is consumed in a fixed order.
            alpha = torch.tensor([random.uniform(0.5, 10.0)])
            beta = torch.tensor([random.uniform(0.5, 10.0)])
            draws = torch.distributions.Beta(alpha, beta).sample((self.n_features_in,))
            rows.append(draws.squeeze(-1))
        # One mask row per stage: (n_stages, n_features_in)
        return nn.Parameter(torch.stack(rows), requires_grad=True)

    def _build_network(self):
        width = self.n_features_in
        # W_in maps [masked input, hidden] to both gates; W_out maps [gated hidden, input] to the candidate.
        self.W_in = nn.ModuleList([nn.Linear(2 * width, 2 * width) for _ in range(self.n_stages)])
        self.W_out = nn.ModuleList([nn.Linear(2 * width, width) for _ in range(self.n_stages)])

        self.feature_masks = self._create_feature_mask()
        if self._uses_t_softmax():
            sparsity = torch.tensor([self.feature_sparsity])
            t = RSoftmax.calculate_t(self.feature_masks, r=sparsity, dim=-1)
            self.t = nn.Parameter(t, requires_grad=self.learnable_sparsity)
        self.dropout = nn.Dropout(self._dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        width = self.n_features_in
        use_t = self._uses_t_softmax()
        temperatures = torch.relu(self.t) if use_t else None
        hidden = x
        for stage in range(self.n_stages):
            logits = self.feature_masks[stage]
            if use_t:
                mask = self.feature_mask_function(logits, temperatures[stage])
            else:
                mask = self.feature_mask_function(logits)
            selected = mask * x
            gates = self.W_in[stage](torch.cat([selected, hidden], dim=-1))
            update_gate = torch.sigmoid(gates[:, :width])
            reset_gate = torch.sigmoid(gates[:, width:])
            candidate = torch.tanh(self.W_out[stage](torch.cat([reset_gate * hidden, x], dim=-1)))
            hidden = self.dropout((1 - update_gate) * hidden + update_gate * candidate)
        return hidden

Bases: Module

GELU-Gated Linear Unit (GEGLU)

Source code in src/pytorch_tabular/models/common/layers/gated_units.py
class GEGLU(nn.Module):
    """GELU-Gated Linear Unit (GEGLU) feed-forward block.

    A thin wrapper around a gated ``PositionWiseFeedForward`` that uses GELU as the gate
    activation and disables all three biases, i.e. ``(GELU(x W_1) * (x V)) W_2`` with dropout
    applied to the gated hidden layer.
    """

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        """
        Args:
            d_model: dimension of the model
            d_ff: dimension of the feedforward layer
            dropout: dropout probability
        """
        super().__init__()
        self.ffn = PositionWiseFeedForward(
            d_model,
            d_ff,
            dropout,
            activation=nn.GELU(),
            is_gated=True,
            bias1=False,
            bias2=False,
            bias_gate=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated feed-forward block; input and output last dimension is ``d_model``."""
        return self.ffn(x)

__init__(d_model, d_ff, dropout=0.1)

Parameters:

Name Type Description Default
d_model int

dimension of the model

required
d_ff int

dimension of the feedforward layer

required
dropout float

dropout probability

0.1
Source code in src/pytorch_tabular/models/common/layers/gated_units.py
def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
    """
    Build a bias-free gated feed-forward block whose gate activation is GELU.

    Args:
        d_model: dimension of the model
        d_ff: dimension of the feedforward layer
        dropout: dropout probability
    """
    super().__init__()
    gate_activation = nn.GELU()
    self.ffn = PositionWiseFeedForward(
        d_model,
        d_ff,
        dropout,
        activation=gate_activation,
        is_gated=True,
        bias1=False,
        bias2=False,
        bias_gate=False,
    )

Bases: Module

ReLU-Gated Linear Unit (ReGLU).

Source code in src/pytorch_tabular/models/common/layers/gated_units.py
class ReGLU(nn.Module):
    """ReLU-Gated Linear Unit (ReGLU) feed-forward block.

    Delegates to a gated, bias-free ``PositionWiseFeedForward`` with ReLU as the gate activation.
    """

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        """
        Args:
            d_model: dimension of the model
            d_ff: dimension of the feedforward layer
            dropout: dropout probability
        """
        super().__init__()
        gate_activation = nn.ReLU()
        self.ffn = PositionWiseFeedForward(
            d_model,
            d_ff,
            dropout,
            activation=gate_activation,
            is_gated=True,
            bias1=False,
            bias2=False,
            bias_gate=False,
        )

    def forward(self, x: torch.Tensor):
        """Run the wrapped gated feed-forward network on ``x``."""
        result = self.ffn(x)
        return result

__init__(d_model, d_ff, dropout=0.1)

Parameters:

Name Type Description Default
d_model int

dimension of the model

required
d_ff int

dimension of the feedforward layer

required
dropout float

dropout probability

0.1
Source code in src/pytorch_tabular/models/common/layers/gated_units.py
def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
    """
    Build a bias-free gated feed-forward block whose gate activation is ReLU.

    Args:
        d_model: dimension of the model
        d_ff: dimension of the feedforward layer
        dropout: dropout probability
    """
    super().__init__()
    gate_activation = nn.ReLU()
    self.ffn = PositionWiseFeedForward(
        d_model,
        d_ff,
        dropout,
        activation=gate_activation,
        is_gated=True,
        bias1=False,
        bias2=False,
        bias_gate=False,
    )

Bases: Module

Source code in src/pytorch_tabular/models/common/layers/gated_units.py
class SwiGLU(nn.Module):
    """Swish/SiLU-Gated Linear Unit (SwiGLU) feed-forward block.

    Delegates to a gated, bias-free ``PositionWiseFeedForward`` with SiLU as the gate activation.
    """

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        """
        Args:
            d_model: dimension of the model
            d_ff: dimension of the feedforward layer
            dropout: dropout probability
        """
        super().__init__()
        gate_activation = nn.SiLU()
        self.ffn = PositionWiseFeedForward(
            d_model,
            d_ff,
            dropout,
            activation=gate_activation,
            is_gated=True,
            bias1=False,
            bias2=False,
            bias_gate=False,
        )

    def forward(self, x: torch.Tensor):
        """Run the wrapped gated feed-forward network on ``x``."""
        result = self.ffn(x)
        return result

__init__(d_model, d_ff, dropout=0.1)

Parameters:

Name Type Description Default
d_model int

dimension of the model

required
d_ff int

dimension of the feedforward layer

required
dropout float

dropout probability

0.1
Source code in src/pytorch_tabular/models/common/layers/gated_units.py
def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
    """
    Build a bias-free gated feed-forward block whose gate activation is SiLU (Swish).

    Args:
        d_model: dimension of the model
        d_ff: dimension of the feedforward layer
        dropout: dropout probability
    """
    super().__init__()
    gate_activation = nn.SiLU()
    self.ffn = PositionWiseFeedForward(
        d_model,
        d_ff,
        dropout,
        activation=gate_activation,
        is_gated=True,
        bias1=False,
        bias2=False,
        bias_gate=False,
    )

Bases: Module

Position-wise Feed-Forward Network (FFN): a documented, reusable implementation of the position-wise feed-forward network.

Position-wise Feed-Forward Network (FFN)

This is a PyTorch implementation of the position-wise feed-forward network used in transformers. The FFN consists of two fully connected layers. The number of dimensions in the hidden layer, $d_{ff}$, is generally set to around four times that of the token embedding, $d_{model}$, so it is sometimes also called the expand-and-contract network. There is an activation at the hidden layer, which is usually set to the ReLU (Rectified Linear Unit) activation, $$\max(0, x)$$ That is, the FFN function is $$FFN(x, W_1, W_2, b_1, b_2) = \max(0, x W_1 + b_1) W_2 + b_2$$ where $W_1$, $W_2$, $b_1$ and $b_2$ are learnable parameters. Sometimes the GELU (Gaussian Error Linear Unit) activation, $$x \Phi(x)$$ where $\Phi(x) = P(X \le x), X \sim \mathcal{N}(0,1)$, is used instead of ReLU.

Gated Linear Units

This is a generic implementation that supports different variants including Gated Linear Units (GLU).

Source code in src/pytorch_tabular/models/common/layers/gated_units.py
class PositionWiseFeedForward(nn.Module):
    r"""
    title: Position-wise Feed-Forward Network (FFN)
    summary: Documented reusable implementation of the position wise feedforward network.

    # Position-wise Feed-Forward Network (FFN)
    This is a [PyTorch](https://pytorch.org)  implementation
    of position-wise feedforward network used in transformer.
    FFN consists of two fully connected layers.
    Number of dimensions in the hidden layer $d_{ff}$, is generally set to around
    four times that of the token embedding $d_{model}$.
    So it is sometimes also called the expand-and-contract network.
    There is an activation at the hidden layer, which is
    usually set to ReLU (Rectified Linear Unit) activation, $$\\max(0, x)$$
    That is, the FFN function is,
    $$FFN(x, W_1, W_2, b_1, b_2) = \\max(0, x W_1 + b_1) W_2 + b_2$$
    where $W_1$, $W_2$, $b_1$ and $b_2$ are learnable parameters.
    Sometimes the
    GELU (Gaussian Error Linear Unit) activation is also used instead of ReLU.
    $$x \\Phi(x)$$ where $\\Phi(x) = P(X \\le x), X \\sim \\mathcal{N}(0,1)$
    ### Gated Linear Units
    This is a generic implementation that supports different variants including
    [Gated Linear Units](https://arxiv.org/abs/2002.05202) (GLU).
    """

    def __init__(
        self,
        d_model: int,
        d_ff: int,
        dropout: float = 0.1,
        activation=None,
        is_gated: bool = False,
        bias1: bool = True,
        bias2: bool = True,
        bias_gate: bool = True,
    ):
        """
        * `d_model` is the number of features in a token embedding
        * `d_ff` is the number of features in the hidden layer of the FFN
        * `dropout` is dropout probability for the hidden layer
        * `activation` is the activation module $f$ applied to the first layer's output;
          when `None` (the default) a fresh `nn.ReLU()` is created for this instance
        * `is_gated` specifies whether the hidden layer is gated
        * `bias1` specifies whether the first fully connected layer should have a learnable bias
        * `bias2` specifies whether the second fully connected layer should have a learnable bias
        * `bias_gate` specifies whether the fully connected layer for the gate should have a learnable bias
        """
        super().__init__()
        # Layer one parameterized by weight $W_1$ and bias $b_1$
        self.layer1 = nn.Linear(d_model, d_ff, bias=bias1)
        # Layer two parameterized by weight $W_2$ and bias $b_2$
        self.layer2 = nn.Linear(d_ff, d_model, bias=bias2)
        # Hidden layer dropout
        self.dropout = nn.Dropout(dropout)
        # Activation function $f$. Built per instance: a module default in the signature would be
        # a single nn.ReLU object registered as a child of every default-constructed FFN.
        self.activation = nn.ReLU() if activation is None else activation
        # Whether there is a gate
        self.is_gated = is_gated
        if is_gated:
            # If there is a gate the linear layer to transform inputs to
            # be multiplied by the gate, parameterized by weight $V$ and bias $c$
            self.linear_v = nn.Linear(d_model, d_ff, bias=bias_gate)

    def forward(self, x: torch.Tensor):
        # $f(x W_1 + b_1)$
        g = self.activation(self.layer1(x))
        # If gated, $f(x W_1 + b_1) \otimes (x V + c)$
        if self.is_gated:
            x = g * self.linear_v(x)
        # Otherwise
        else:
            x = g
        # Apply dropout
        x = self.dropout(x)
        # $(f(x W_1 + b_1) \otimes (x V + c)) W_2 + b_2$ or $f(x W_1 + b_1) W_2 + b_2$
        # depending on whether it is gated
        return self.layer2(x)

__init__(d_model, d_ff, dropout=0.1, activation=nn.ReLU(), is_gated=False, bias1=True, bias2=True, bias_gate=True)

  • d_model is the number of features in a token embedding
  • d_ff is the number of features in the hidden layer of the FFN
  • dropout is dropout probability for the hidden layer
  • is_gated specifies whether the hidden layer is gated
  • bias1 specifies whether the first fully connected layer should have a learnable bias
  • bias2 specifies whether the second fully connected layer should have a learnable bias
  • bias_gate specifies whether the fully connected layer for the gate should have a learnable bias
Source code in src/pytorch_tabular/models/common/layers/gated_units.py
def __init__(
    self,
    d_model: int,
    d_ff: int,
    dropout: float = 0.1,
    activation=None,
    is_gated: bool = False,
    bias1: bool = True,
    bias2: bool = True,
    bias_gate: bool = True,
):
    """
    * `d_model` is the number of features in a token embedding
    * `d_ff` is the number of features in the hidden layer of the FFN
    * `dropout` is dropout probability for the hidden layer
    * `activation` is the activation module $f$ applied to the first layer's output;
      when `None` (the default) a fresh `nn.ReLU()` is created for this instance
    * `is_gated` specifies whether the hidden layer is gated
    * `bias1` specifies whether the first fully connected layer should have a learnable bias
    * `bias2` specifies whether the second fully connected layer should have a learnable bias
    * `bias_gate` specifies whether the fully connected layer for the gate should have a learnable bias
    """
    super().__init__()
    # Layer one parameterized by weight $W_1$ and bias $b_1$
    self.layer1 = nn.Linear(d_model, d_ff, bias=bias1)
    # Layer two parameterized by weight $W_2$ and bias $b_2$
    self.layer2 = nn.Linear(d_ff, d_model, bias=bias2)
    # Hidden layer dropout
    self.dropout = nn.Dropout(dropout)
    # Activation function $f$. Built per instance: a module default in the signature would be
    # a single nn.ReLU object registered as a child of every default-constructed FFN.
    self.activation = nn.ReLU() if activation is None else activation
    # Whether there is a gate
    self.is_gated = is_gated
    if is_gated:
        # If there is a gate the linear layer to transform inputs to
        # be multiplied by the gate, parameterized by weight $V$ and bias $c$
        self.linear_v = nn.Linear(d_model, d_ff, bias=bias_gate)

Soft Trees

Bases: Module

Source code in src/pytorch_tabular/models/common/layers/soft_trees.py
class NeuralDecisionTree(nn.Module):
    """Soft decision tree made of levels of ``NeuralDecisionStump`` modules.

    Level ``d`` (0-based) holds ``2 ** d`` stumps, registered as
    ``decision_stump_{d}_{n}``. Level 0 receives the raw features. Each deeper
    level receives the raw features concatenated with the outputs of all stumps
    of the previous level.
    """

    def __init__(
        self,
        depth: int,
        n_features: int,
        dropout: float = 0,
        binning_activation: Callable = entmax15,
        feature_mask_function: Callable = entmax15,
        feature_sparsity: float = 0.8,
        learnable_sparsity: bool = True,
    ):
        """
        Args:
            depth: number of stump levels; must be at least 1
            n_features: number of raw input features
            dropout: dropout probability applied to the output of the last level
            binning_activation: forwarded to every ``NeuralDecisionStump``
            feature_mask_function: forwarded to every ``NeuralDecisionStump``
            feature_sparsity: forwarded to every ``NeuralDecisionStump``
            learnable_sparsity: forwarded to every ``NeuralDecisionStump``

        Raises:
            ValueError: if ``depth`` is less than 1 (``forward`` would have no level
                output to return)
        """
        super().__init__()
        if depth < 1:
            raise ValueError(f"depth must be at least 1, got {depth}")
        self.depth = depth
        self._num_cutpoints = 1
        self.n_features = n_features
        self._dropout = dropout
        self.binning_activation = binning_activation
        self.feature_mask_function = feature_mask_function
        self.feature_sparsity = feature_sparsity
        self.learnable_sparsity = learnable_sparsity
        self._build_network()

    def _build_network(self):
        """Register the per-level stumps and the output dropout."""
        for d in range(self.depth):
            # 2 ** d is always >= 1, so no lower bound is needed.
            for n in range(2**d):
                self.add_module(
                    f"decision_stump_{d}_{n}",
                    NeuralDecisionStump(
                        # Level 0 sees only the raw features. Deeper levels also see the
                        # previous level's concatenated outputs, sized here as 2 ** d
                        # columns (assumes 2 outputs per stump -- TODO confirm against
                        # NeuralDecisionStump).
                        self.n_features + (2**d if d > 0 else 0),
                        self.binning_activation,
                        self.feature_mask_function,
                        self.feature_sparsity,
                        self.learnable_sparsity,
                    ),
                )
        self.dropout = nn.Dropout(self._dropout)

    def forward(self, x: torch.Tensor) -> tuple:
        """Route ``x`` through every level of the tree.

        Returns:
            A tuple ``(leaf_nodes, feature_masks)``:
            ``leaf_nodes`` is the last level's stump outputs concatenated along
            dim 1, with dropout applied.
            ``feature_masks`` is a list with one entry per level, each a list of
            the feature masks returned by that level's stumps.
        """
        tree_input = x
        feature_masks = []
        for d in range(self.depth):
            layer_nodes = []
            layer_feature_masks = []
            for n in range(2**d):
                leaf_nodes, feature_mask = self._modules[f"decision_stump_{d}_{n}"](tree_input)
                layer_nodes.append(leaf_nodes)
                layer_feature_masks.append(feature_mask)
            layer_nodes = torch.cat(layer_nodes, dim=1)
            # The next level sees the raw features plus this level's outputs.
            tree_input = torch.cat([x, layer_nodes], dim=1)
            feature_masks.append(layer_feature_masks)
        return self.dropout(layer_nodes), feature_masks

Bases: ModuleWithInit

Source code in src/pytorch_tabular/models/common/layers/soft_trees.py
class ODST(ModuleWithInit):
    def __init__(
        self,
        in_features,
        num_trees,
        depth=6,
        tree_output_dim=1,
        flatten_output=True,
        choice_function=sparsemax,
        bin_function=sparsemoid,
        initialize_response_=nn.init.normal_,
        initialize_selection_logits_=nn.init.uniform_,
        threshold_init_beta=1.0,
        threshold_init_cutoff=1.0,
    ):
        """Oblivious Differentiable Sparsemax Trees. http://tinyurl.com/odst-readmore One can drop (sic!) this module
        anywhere instead of nn.Linear.

        :param in_features: number of features in the input tensor
        :param num_trees: number of trees in this layer
        :param depth: number of splits in every tree
        :param tree_output_dim: number of response channels in the response of individual tree
            (stored as ``self.tree_dim``)
        :param flatten_output: if False, returns [..., num_trees, tree_output_dim],
            by default returns [..., num_trees * tree_output_dim]
        :param choice_function: f(tensor, dim) -> R_simplex computes feature weights s.t. f(tensor, dim).sum(dim) == 1
        :param bin_function: f(tensor) -> R[0, 1], computes tree leaf weights

        :param initialize_response_: in-place initializer for tree output tensor
        :param initialize_selection_logits_: in-place initializer for logits that select features for the tree
        both thresholds and scales are initialized with data-aware init (or .load_state_dict)
        :param threshold_init_beta: initializes threshold to a q-th quantile of data points
            where q ~ Beta(:threshold_init_beta:, :threshold_init_beta:).
            If this param is set to 1, initial thresholds will have the same distribution as data points.
            If greater than 1 (e.g. 10), thresholds will be closer to median data value.
            If less than 1 (e.g. 0.1), thresholds will approach min/max data values.

        :param threshold_init_cutoff: threshold log-temperatures initializer, in (0, inf).
            By default (1.0), log-temperatures are initialized in such a way that all bin selectors
            end up in the linear region of sparse-sigmoid. The temperatures are then scaled by this parameter.
            Setting this value > 1.0 will result in some margin between data points and sparse-sigmoid cutoff value.
            Setting this value < 1.0 will cause (1 - value) part of data points to end up in flat sparse-sigmoid region.
            For instance, threshold_init_cutoff = 0.9 will set 10% points equal to 0.0 or 1.0.
            All points will be between (0.5 - 0.5 / threshold_init_cutoff) and (0.5 + 0.5 / threshold_init_cutoff).
        """
        super().__init__()
        self.depth, self.num_trees, self.tree_dim, self.flatten_output = (
            depth,
            num_trees,
            tree_output_dim,
            flatten_output,
        )
        self.choice_function, self.bin_function = choice_function, bin_function
        self.threshold_init_beta, self.threshold_init_cutoff = (
            threshold_init_beta,
            threshold_init_cutoff,
        )

        # Leaf responses: one tree_output_dim vector per leaf of every tree
        self.response = nn.Parameter(torch.zeros([num_trees, tree_output_dim, 2**depth]), requires_grad=True)
        initialize_response_(self.response)

        # Logits over input features, one distribution per (tree, depth level)
        self.feature_selection_logits = nn.Parameter(torch.zeros([in_features, num_trees, depth]), requires_grad=True)
        initialize_selection_logits_(self.feature_selection_logits)

        self.feature_thresholds = nn.Parameter(
            torch.full([num_trees, depth], float("nan"), dtype=torch.float32),
            requires_grad=True,
        )  # nan values will be initialized on first batch (data-aware init)
        self.log_temperatures = nn.Parameter(
            torch.full([num_trees, depth], float("nan"), dtype=torch.float32),
            requires_grad=True,
        )

        # binary codes for mapping between 1-hot vectors and bin indices:
        # bin_codes[d, c] is bit d of leaf index c; the last axis holds (bit, 1 - bit)
        with torch.no_grad():
            indices = torch.arange(2**self.depth)
            offsets = 2 ** torch.arange(self.depth)
            bin_codes = (indices.view(1, -1) // offsets.view(-1, 1) % 2).to(torch.float32)
            bin_codes_1hot = torch.stack([bin_codes, 1.0 - bin_codes], dim=-1)
            self.bin_codes_1hot = nn.Parameter(bin_codes_1hot, requires_grad=False)
            # ^-- [depth, 2 ** depth, 2]

    def forward(self, input):
        """Compute tree responses for ``input`` of shape [..., in_features].

        Returns [..., num_trees * tree_dim] if ``flatten_output`` else [..., num_trees, tree_dim].
        """
        assert len(input.shape) >= 2
        if len(input.shape) > 2:
            # Flatten leading dims, run the 2-D path, then restore them
            return self.forward(input.view(-1, input.shape[-1])).view(*input.shape[:-1], -1)
        # new input shape: [batch_size, in_features]

        feature_logits = self.feature_selection_logits
        feature_selectors = self.choice_function(feature_logits, dim=0)
        # ^--[in_features, num_trees, depth]

        feature_values = torch.einsum("bi,ind->bnd", input, feature_selectors)
        # ^--[batch_size, num_trees, depth]

        threshold_logits = (feature_values - self.feature_thresholds) * torch.exp(-self.log_temperatures)

        threshold_logits = torch.stack([-threshold_logits, threshold_logits], dim=-1)
        # ^--[batch_size, num_trees, depth, 2]

        bins = self.bin_function(threshold_logits)
        # ^--[batch_size, num_trees, depth, 2], approximately binary

        # Soft match between each split decision and bit d of every leaf index
        bin_matches = torch.einsum("btds,dcs->btdc", bins, self.bin_codes_1hot)
        # ^--[batch_size, num_trees, depth, 2 ** depth]

        # A leaf's weight is the product of its matches over all depth levels
        response_weights = torch.prod(bin_matches, dim=-2)
        # ^-- [batch_size, num_trees, 2 ** depth]

        response = torch.einsum("bnd,ncd->bnc", response_weights, self.response)
        # ^-- [batch_size, num_trees, tree_dim]

        return response.flatten(1, 2) if self.flatten_output else response

    def initialize(self, input, eps=1e-6):
        """Data-aware initialization of ``feature_thresholds`` and ``log_temperatures``.

        Invoked by ``ModuleWithInit`` on the first call. ``input`` must be 2-D
        [batch_size, in_features]. ``eps`` is added to the temperatures before taking the log.
        """
        # data-aware initializer
        assert len(input.shape) == 2
        if input.shape[0] < 1000:
            warn(
                "Data-aware initialization is performed on less than 1000 data points. This may cause instability. "
                "To avoid potential problems, run this model on a data batch with at least 1000 data samples. "
                "You can do so manually before training. Use with torch.no_grad() for memory efficiency."
            )
        with torch.no_grad():
            feature_selectors = self.choice_function(self.feature_selection_logits, dim=0)
            # ^--[in_features, num_trees, depth]

            feature_values = torch.einsum("bi,ind->bnd", input, feature_selectors)
            # ^--[batch_size, num_trees, depth]

            # initialize thresholds: sample random percentiles of data
            percentiles_q = 100 * np.random.beta(
                self.threshold_init_beta,
                self.threshold_init_beta,
                size=[self.num_trees, self.depth],
            )
            # Each (tree, depth) column of feature values gets its own sampled percentile
            self.feature_thresholds.data[...] = torch.as_tensor(
                list(
                    map(
                        np.percentile,
                        check_numpy(feature_values.flatten(1, 2).t()),
                        percentiles_q.flatten(),
                    )
                ),
                dtype=feature_values.dtype,
                device=feature_values.device,
            ).view(self.num_trees, self.depth)

            # init temperatures: make sure enough data points are in the linear region of sparse-sigmoid
            temperatures = np.percentile(
                check_numpy(abs(feature_values - self.feature_thresholds)),
                q=100 * min(1.0, self.threshold_init_cutoff),
                axis=0,
            )

            # if threshold_init_cutoff > 1, scale everything down by it
            temperatures /= max(1.0, self.threshold_init_cutoff)
            self.log_temperatures.data[...] = torch.log(torch.as_tensor(temperatures) + eps)

    def __repr__(self):
        return "{}(in_features={}, num_trees={}, depth={}, tree_dim={}, flatten_output={})".format(
            self.__class__.__name__,
            self.feature_selection_logits.shape[0],
            self.num_trees,
            self.depth,
            self.tree_dim,
            self.flatten_output,
        )

__init__(in_features, num_trees, depth=6, tree_output_dim=1, flatten_output=True, choice_function=sparsemax, bin_function=sparsemoid, initialize_response_=nn.init.normal_, initialize_selection_logits_=nn.init.uniform_, threshold_init_beta=1.0, threshold_init_cutoff=1.0)

Oblivious Differentiable Sparsemax Trees. http://tinyurl.com/odst-readmore One can drop (sic!) this module anywhere instead of nn.Linear.

:param in_features: number of features in the input tensor. :param num_trees: number of trees in this layer. :param tree_output_dim: number of response channels in the response of individual tree. :param depth: number of splits in every tree. :param flatten_output: if False, returns [..., num_trees, tree_output_dim], by default returns [..., num_trees * tree_output_dim]. :param choice_function: f(tensor, dim) -> R_simplex computes feature weights s.t. f(tensor, dim).sum(dim) == 1. :param bin_function: f(tensor) -> R[0, 1], computes tree leaf weights.

:param initialize_response_: in-place initializer for tree output tensor. :param initialize_selection_logits_: in-place initializer for logits that select features for the tree. Both thresholds and scales are initialized with data-aware init (or .load_state_dict). :param threshold_init_beta: initializes threshold to a q-th quantile of data points, where q ~ Beta(:threshold_init_beta:, :threshold_init_beta:). If this param is set to 1, initial thresholds will have the same distribution as data points. If greater than 1 (e.g. 10), thresholds will be closer to median data value. If less than 1 (e.g. 0.1), thresholds will approach min/max data values.

:param threshold_init_cutoff: threshold log-temperatures initializer, in (0, inf). By default (1.0), log-temperatures are initialized in such a way that all bin selectors end up in the linear region of sparse-sigmoid. The temperatures are then scaled by this parameter. Setting this value > 1.0 will result in some margin between data points and sparse-sigmoid cutoff value. Setting this value < 1.0 will cause (1 - value) part of data points to end up in flat sparse-sigmoid region. For instance, threshold_init_cutoff = 0.9 will set 10% points equal to 0.0 or 1.0. All points will be between (0.5 - 0.5 / threshold_init_cutoff) and (0.5 + 0.5 / threshold_init_cutoff).

Source code in src/pytorch_tabular/models/common/layers/soft_trees.py
def __init__(
    self,
    in_features,
    num_trees,
    depth=6,
    tree_output_dim=1,
    flatten_output=True,
    choice_function=sparsemax,
    bin_function=sparsemoid,
    initialize_response_=nn.init.normal_,
    initialize_selection_logits_=nn.init.uniform_,
    threshold_init_beta=1.0,
    threshold_init_cutoff=1.0,
):
    """Oblivious Differentiable Sparsemax Trees. http://tinyurl.com/odst-readmore One can drop (sic!) this module
    anywhere instead of nn.Linear.

    :param in_features: number of features in the input tensor
    :param num_trees: number of trees in this layer
    :param depth: number of splits in every tree
    :param tree_output_dim: number of response channels in the response of individual tree
        (stored as ``self.tree_dim``)
    :param flatten_output: if False, returns [..., num_trees, tree_output_dim],
        by default returns [..., num_trees * tree_output_dim]
    :param choice_function: f(tensor, dim) -> R_simplex computes feature weights s.t. f(tensor, dim).sum(dim) == 1
    :param bin_function: f(tensor) -> R[0, 1], computes tree leaf weights

    :param initialize_response_: in-place initializer for tree output tensor
    :param initialize_selection_logits_: in-place initializer for logits that select features for the tree
    both thresholds and scales are initialized with data-aware init (or .load_state_dict)
    :param threshold_init_beta: initializes threshold to a q-th quantile of data points
        where q ~ Beta(:threshold_init_beta:, :threshold_init_beta:).
        If this param is set to 1, initial thresholds will have the same distribution as data points.
        If greater than 1 (e.g. 10), thresholds will be closer to median data value.
        If less than 1 (e.g. 0.1), thresholds will approach min/max data values.

    :param threshold_init_cutoff: threshold log-temperatures initializer, in (0, inf).
        By default (1.0), log-temperatures are initialized in such a way that all bin selectors
        end up in the linear region of sparse-sigmoid. The temperatures are then scaled by this parameter.
        Setting this value > 1.0 will result in some margin between data points and sparse-sigmoid cutoff value.
        Setting this value < 1.0 will cause (1 - value) part of data points to end up in flat sparse-sigmoid region.
        For instance, threshold_init_cutoff = 0.9 will set 10% points equal to 0.0 or 1.0.
        All points will be between (0.5 - 0.5 / threshold_init_cutoff) and (0.5 + 0.5 / threshold_init_cutoff).
    """
    super().__init__()
    self.depth, self.num_trees, self.tree_dim, self.flatten_output = (
        depth,
        num_trees,
        tree_output_dim,
        flatten_output,
    )
    self.choice_function, self.bin_function = choice_function, bin_function
    self.threshold_init_beta, self.threshold_init_cutoff = (
        threshold_init_beta,
        threshold_init_cutoff,
    )

    # Leaf responses: one tree_output_dim vector per leaf of every tree
    self.response = nn.Parameter(torch.zeros([num_trees, tree_output_dim, 2**depth]), requires_grad=True)
    initialize_response_(self.response)

    # Logits over input features, one distribution per (tree, depth level)
    self.feature_selection_logits = nn.Parameter(torch.zeros([in_features, num_trees, depth]), requires_grad=True)
    initialize_selection_logits_(self.feature_selection_logits)

    self.feature_thresholds = nn.Parameter(
        torch.full([num_trees, depth], float("nan"), dtype=torch.float32),
        requires_grad=True,
    )  # nan values will be initialized on first batch (data-aware init)
    self.log_temperatures = nn.Parameter(
        torch.full([num_trees, depth], float("nan"), dtype=torch.float32),
        requires_grad=True,
    )

    # binary codes for mapping between 1-hot vectors and bin indices:
    # bin_codes[d, c] is bit d of leaf index c; the last axis holds (bit, 1 - bit)
    with torch.no_grad():
        indices = torch.arange(2**self.depth)
        offsets = 2 ** torch.arange(self.depth)
        bin_codes = (indices.view(1, -1) // offsets.view(-1, 1) % 2).to(torch.float32)
        bin_codes_1hot = torch.stack([bin_codes, 1.0 - bin_codes], dim=-1)
        self.bin_codes_1hot = nn.Parameter(bin_codes_1hot, requires_grad=False)
        # ^-- [depth, 2 ** depth, 2]

Transformers

Bases: Module

Applies Dropout to the sub-layer output, adds it to the input, and then applies LayerNorm.

Standard AddNorm operations in Transformers

Source code in src/pytorch_tabular/models/common/layers/transformers.py
class AddNorm(nn.Module):
    """Residual "add & norm" step used in Transformer blocks.

    Computes ``LayerNorm(X + Dropout(Y))``. Dropout is applied to the sub-layer
    output ``Y``, the result is added to the residual input ``X``, and the sum
    is layer-normalised over the last dimension (``input_dim``).
    """

    def __init__(self, input_dim: int, dropout: float):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(input_dim)

    def forward(self, X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
        residual_sum = X + self.dropout(Y)
        return self.ln(residual_sum)

Bases: Module

Appends the [CLS] token for BERT-like inference.

Source code in src/pytorch_tabular/models/common/layers/transformers.py
class AppendCLSToken(nn.Module):
    """Append a learnable [CLS] token to the end of every token sequence.

    The token is a single learnable vector of size ``d_token``. It is
    concatenated after the last token of each sample in the batch.
    """

    def __init__(self, d_token: int, initialization: str) -> None:
        """Allocate the [CLS] vector and initialise it in place.

        Args:
            d_token: size of each token embedding
            initialization: scheme name forwarded to ``_initialize_kaiming``
        """
        super().__init__()
        self.weight = nn.Parameter(torch.Tensor(d_token))
        _initialize_kaiming(self.weight, initialization, 1 / math.sqrt(d_token))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` with the [CLS] token appended along dim 1.

        ``x`` must be 3-D; the output has one more entry along dim 1 than ``x``.
        """
        assert x.ndim == 3
        batch_size = x.shape[0]
        cls_token = self.weight.view(1, 1, -1).repeat(batch_size, 1, 1)
        return torch.cat((x, cls_token), dim=1)

__init__(d_token, initialization)

Initialize self.

Source code in src/pytorch_tabular/models/common/layers/transformers.py
def __init__(self, d_token: int, initialization: str) -> None:
    """Allocate the learnable [CLS] vector and initialise it in place.

    Args:
        d_token: size of each token embedding
        initialization: scheme name forwarded to ``_initialize_kaiming``
    """
    super().__init__()
    self.weight = nn.Parameter(torch.Tensor(d_token))
    inv_sqrt_d = 1 / math.sqrt(d_token)
    _initialize_kaiming(self.weight, initialization, inv_sqrt_d)

forward(x)

Perform the forward pass.

Source code in src/pytorch_tabular/models/common/layers/transformers.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Append the [CLS] vector after the last token of every sample.

    ``x`` must be 3-D; the result has one more entry along dim 1.
    """
    assert x.ndim == 3
    cls_token = self.weight.view(1, 1, -1).repeat(x.shape[0], 1, 1)
    return torch.cat((x, cls_token), dim=1)

Bases: Module

Multi Headed Attention Block in Transformers.

Source code in src/pytorch_tabular/models/common/layers/transformers.py
class MultiHeadedAttention(nn.Module):
    """Multi Headed Attention Block in Transformers."""

    def __init__(
        self,
        input_dim: int,
        num_heads: int = 8,
        head_dim: int = 16,
        dropout: float = 0.1,
        keep_attn: bool = True,
    ):
        """
        Args:
            input_dim: size of the last dimension of the input (and of the output)
            num_heads: number of attention heads; ``input_dim`` must be divisible by it
            head_dim: size of each head's query/key/value projection
            dropout: dropout probability applied to the attention weights
            keep_attn: if True, ``forward`` stores the latest (post-dropout)
                attention weights in ``self.attn_weights``
        """
        super().__init__()
        assert input_dim % num_heads == 0, "'input_dim' must be multiples of 'num_heads'"
        inner_dim = head_dim * num_heads
        self.n_heads = num_heads
        # 1 / sqrt(head_dim) scaling of the query-key dot products
        self.scale = head_dim**-0.5
        self.keep_attn = keep_attn

        # One bias-free projection produces q, k and v together (split in forward)
        self.to_qkv = nn.Linear(input_dim, inner_dim * 3, bias=False)
        self.to_out = nn.Linear(inner_dim, input_dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Self-attention over ``x`` of shape (batch, tokens, input_dim); returns the same shape."""
        h = self.n_heads
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = (rearrange(t, "b n (h d) -> b h n d", h=h) for t in (q, k, v))
        sim = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale

        attn = sim.softmax(dim=-1)
        attn = self.dropout(attn)
        if self.keep_attn:
            self.attn_weights = attn
        out = einsum("b h i j, b h j d -> b h i d", attn, v)
        out = rearrange(out, "b h n d -> b n (h d)", h=h)
        return self.to_out(out)

Bases: Module

A single Transformer Encoder Block.

Source code in src/pytorch_tabular/models/common/layers/transformers.py
class TransformerEncoderBlock(nn.Module):
    """A single Transformer Encoder Block.

    Post-LN layout: ``x = AddNorm(x, MHA(x))`` followed by ``x = AddNorm(x, FFN(x))``.
    """

    def __init__(
        self,
        input_embed_dim: int,
        num_heads: int = 8,
        ff_hidden_multiplier: int = 4,
        ff_activation: str = "GEGLU",
        attn_dropout: float = 0.1,
        keep_attn: bool = True,
        ff_dropout: float = 0.1,
        add_norm_dropout: float = 0.1,
        transformer_head_dim: Optional[int] = None,
    ):
        """
        Args:
            input_embed_dim: The input embedding dimension
            num_heads: The number of attention heads
            ff_hidden_multiplier: The hidden dimension multiplier for the position-wise feed-forward layer
            ff_activation: The activation function for the position-wise feed-forward layer.
                Names found in ``GATED_UNITS`` select a gated unit; otherwise the ``torch.nn``
                class of that name is used in a ``PositionWiseFeedForward``.
            attn_dropout: The dropout probability for the attention layer
            keep_attn: Whether to keep the attention weights
            ff_dropout: The dropout probability for the position-wise feed-forward layer
            add_norm_dropout: The dropout probability for the residual connections
            transformer_head_dim: The dimension of the attention heads. If None, will default to input_embed_dim
        """
        super().__init__()
        self.mha = MultiHeadedAttention(
            input_embed_dim,
            num_heads,
            head_dim=input_embed_dim if transformer_head_dim is None else transformer_head_dim,
            dropout=attn_dropout,
            keep_attn=keep_attn,
        )

        try:
            self.pos_wise_ff = GATED_UNITS[ff_activation](
                d_model=input_embed_dim,
                d_ff=input_embed_dim * ff_hidden_multiplier,
                dropout=ff_dropout,
            )
        except (AttributeError, KeyError):
            self.pos_wise_ff = PositionWiseFeedForward(
                d_model=input_embed_dim,
                d_ff=input_embed_dim * ff_hidden_multiplier,
                dropout=ff_dropout,
                activation=getattr(nn, ff_activation)(),
            )
        self.attn_add_norm = AddNorm(input_embed_dim, add_norm_dropout)
        self.ff_add_norm = AddNorm(input_embed_dim, add_norm_dropout)

    def forward(self, x):
        """Apply the attention and feed-forward sub-layers, each with a residual AddNorm."""
        y = self.mha(x)
        x = self.attn_add_norm(x, y)
        # The FFN consumes the normalised residual stream x, not the raw attention
        # output y. Otherwise the block input would bypass the FFN entirely.
        y = self.pos_wise_ff(x)
        return self.ff_add_norm(x, y)

__init__(input_embed_dim, num_heads=8, ff_hidden_multiplier=4, ff_activation='GEGLU', attn_dropout=0.1, keep_attn=True, ff_dropout=0.1, add_norm_dropout=0.1, transformer_head_dim=None)

Parameters:

Name Type Description Default
input_embed_dim int

The input embedding dimension

required
num_heads int

The number of attention heads

8
ff_hidden_multiplier int

The hidden dimension multiplier for the position-wise feed-forward layer

4
ff_activation str

The activation function for the position-wise feed-forward layer

'GEGLU'
attn_dropout float

The dropout probability for the attention layer

0.1
keep_attn bool

Whether to keep the attention weights

True
ff_dropout float

The dropout probability for the position-wise feed-forward layer

0.1
add_norm_dropout float

The dropout probability for the residual connections

0.1
transformer_head_dim Optional[int]

The dimension of the attention heads. If None, will default to input_embed_dim

None
Source code in src/pytorch_tabular/models/common/layers/transformers.py
def __init__(
    self,
    input_embed_dim: int,
    num_heads: int = 8,
    ff_hidden_multiplier: int = 4,
    ff_activation: str = "GEGLU",
    attn_dropout: float = 0.1,
    keep_attn: bool = True,
    ff_dropout: float = 0.1,
    add_norm_dropout: float = 0.1,
    transformer_head_dim: Optional[int] = None,
):
    """Build the attention, feed-forward and residual-norm sub-layers.

    Args:
        input_embed_dim: The input embedding dimension
        num_heads: The number of attention heads
        ff_hidden_multiplier: The hidden dimension multiplier for the position-wise feed-forward layer
        ff_activation: The activation function for the position-wise feed-forward layer.
            Names found in ``GATED_UNITS`` select a gated unit; any other name is looked up
            on ``torch.nn`` and used inside a ``PositionWiseFeedForward``.
        attn_dropout: The dropout probability for the attention layer
        keep_attn: Whether to keep the attention weights
        ff_dropout: The dropout probability for the position-wise feed-forward layer
        add_norm_dropout: The dropout probability for the residual connections
        transformer_head_dim: The dimension of the attention heads. If None, will default to input_embed_dim
    """
    super().__init__()
    head_dim = transformer_head_dim if transformer_head_dim is not None else input_embed_dim
    self.mha = MultiHeadedAttention(
        input_embed_dim,
        num_heads,
        head_dim=head_dim,
        dropout=attn_dropout,
        keep_attn=keep_attn,
    )

    hidden_dim = input_embed_dim * ff_hidden_multiplier
    try:
        self.pos_wise_ff = GATED_UNITS[ff_activation](
            d_model=input_embed_dim,
            d_ff=hidden_dim,
            dropout=ff_dropout,
        )
    except (AttributeError, KeyError):
        # Not a gated-unit name: fall back to a plain FFN with the torch.nn activation of that name.
        self.pos_wise_ff = PositionWiseFeedForward(
            d_model=input_embed_dim,
            d_ff=hidden_dim,
            dropout=ff_dropout,
            activation=getattr(nn, ff_activation)(),
        )
    self.attn_add_norm = AddNorm(input_embed_dim, add_norm_dropout)
    self.ff_add_norm = AddNorm(input_embed_dim, add_norm_dropout)

Miscellaneous

Bases: Module

A wrapper for a lambda function as a pytorch module.

Source code in src/pytorch_tabular/models/common/layers/misc.py
class Lambda(nn.Module):
    """Expose an arbitrary callable as a ``torch.nn.Module``."""

    def __init__(self, func: Callable):
        """Store the callable that every forward pass delegates to.

        Args:
            func: any function/callable
        """
        super().__init__()
        self.func = func

    def forward(self, *args, **kwargs):
        """Call the wrapped callable with all arguments passed through unchanged."""
        wrapped = self.func
        return wrapped(*args, **kwargs)

__init__(func)

Initialize lambda module Args: func: any function/callable

Source code in src/pytorch_tabular/models/common/layers/misc.py
def __init__(self, func: Callable):
    """Wrap ``func`` so it can be used wherever an ``nn.Module`` is expected.

    Args:
        func: any function/callable; it is stored as ``self.func``
    """
    super().__init__()
    self.func = func

Bases: Module

Base class for pytorch module with data-aware initializer on first batch.

Source code in src/pytorch_tabular/models/common/layers/misc.py
class ModuleWithInit(nn.Module):
    """Base class for pytorch module with data-aware initializer on first batch.

    Subclasses implement :meth:`initialize`. On the first call it is invoked with
    the same arguments as that call, and later calls go straight to ``forward``.
    """

    def __init__(self):
        super().__init__()
        # Persistent flag (part of state_dict): set to 1 once initialize() has run.
        self._is_initialized_tensor = nn.Parameter(torch.tensor(0, dtype=torch.uint8), requires_grad=False)
        # Python-side cache of the flag; None means the tensor has not been read yet.
        self._is_initialized_bool = None
        # Note: two flags are kept on purpose, to achieve both
        # * persistence: _is_initialized_tensor is saved alongside the model in state_dict
        # * speed: _is_initialized_bool avoids calling .item() on the tensor on every call
        # please DO NOT use these flags in child modules

    def initialize(self, *args, **kwargs):
        """Initialize module tensors using first batch of data.

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError(f"{type(self).__name__} must implement initialize(*args, **kwargs)")

    def __call__(self, *args, **kwargs):
        if self._is_initialized_bool is None:
            self._is_initialized_bool = bool(self._is_initialized_tensor.item())
        if not self._is_initialized_bool:
            self.initialize(*args, **kwargs)
            self._is_initialized_tensor.data[...] = 1
            self._is_initialized_bool = True
        return super().__call__(*args, **kwargs)

initialize(*args, **kwargs)

Initialize module tensors using first batch of data.

Source code in src/pytorch_tabular/models/common/layers/misc.py
def initialize(self, *args, **kwargs):
    """Initialize module tensors using first batch of data.

    Raises:
        NotImplementedError: always; subclasses must override this method.
    """
    raise NotImplementedError(f"{type(self).__name__} must implement initialize(*args, **kwargs)")

Bases: Module

Source code in src/pytorch_tabular/models/common/layers/misc.py
class Residual(nn.Module):
    """Wrap a callable ``fn`` with a skip connection, computing ``fn(x, **kwargs) + x``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        """Return ``fn(x, **kwargs) + x``; keyword arguments go to ``fn`` only."""
        branch = self.fn(x, **kwargs)
        return branch + x

Activations

Bases: Function

A highly optimized equivalent of lambda x: Entmax15([x, 0])

Source code in src/pytorch_tabular/models/common/layers/activations.py
class Entmoid15(Function):
    """A highly optimized equivalent of lambda x: Entmax15([x, 0])

    Elementwise autograd ``Function``. The result is ``Entmoid15._forward(input)``,
    and only that output (not the input) is saved for the backward pass.
    """

    @staticmethod
    def forward(ctx, input):
        """Compute the elementwise transform and save its output for ``backward``."""
        output = Entmoid15._forward(input)
        ctx.save_for_backward(output)
        return output

    @staticmethod
    @script
    def _forward(input):
        # The right-hand side is evaluated first, so is_pos uses the sign of the
        # original input while `input` is rebound to its absolute value.
        input, is_pos = abs(input), input >= 0
        # Closed-form threshold for |x|; relu keeps the sqrt argument non-negative.
        tau = (input + torch.sqrt(F.relu(8 - input**2))) / 2
        # tau <= |x| holds when |x| >= 2 (up to rounding). There tau is set to 2.0, so the
        # relu below yields 0 and the output saturates at exactly 0 or 1.
        tau.masked_fill_(tau <= input, 2.0)
        # Value used for negative inputs; non-negative inputs get the mirror image 1 - y_neg.
        y_neg = 0.25 * F.relu(tau - input, inplace=True) ** 2
        return torch.where(is_pos, 1 - y_neg, y_neg)

    @staticmethod
    def backward(ctx, grad_output):
        """Gradient w.r.t. ``input``, computed from the output saved in ``forward``."""
        return Entmoid15._backward(ctx.saved_tensors[0], grad_output)

    @staticmethod
    @script
    def _backward(output, grad_output):
        # With p = output and g = grad_output this returns
        #   g * sqrt(p) - g * sqrt(p) * sqrt(p) / (sqrt(p) + sqrt(1 - p))
        gppr0, gppr1 = output.sqrt(), (1 - output).sqrt()
        grad_input = grad_output * gppr0
        q = grad_input / (gppr0 + gppr1)
        grad_input -= q * gppr0
        return grad_input

1.5-entmax: normalizing sparse transform (a la softmax).

Solves the optimization problem:

max_p <x, p> - H_1.5(p)    s.t.    p >= 0, sum(p) == 1.

where H_1.5(p) is the Tsallis alpha-entropy with alpha=1.5.

Parameters

X : torch.Tensor The input tensor.

dim : int

The dimension along which to apply 1.5-entmax.

k : int or None

number of largest elements to partial-sort over. For optimal performance, should be slightly bigger than the expected number of nonzeros in the solution. If the solution is more than k-sparse, this function is recursively called with a 2*k schedule. If None, full sorting is performed from the beginning.

Returns

P : torch tensor, same shape as X The projection result, such that P.sum(dim=dim) == 1 elementwise.

sparsemax: normalizing sparse transform (a la softmax).

Solves the projection:

min_p ||x - p||_2   s.t.    p >= 0, sum(p) == 1.

Parameters

X : torch.Tensor The input tensor.

dim : int

The dimension along which to apply sparsemax.

k : int or None

number of largest elements to partial-sort over. For optimal performance, should be slightly bigger than the expected number of nonzeros in the solution. If the solution is more than k-sparse, this function is recursively called with a 2*k schedule. If None, full sorting is performed from the beginning.

Returns

P : torch tensor, same shape as X The projection result, such that P.sum(dim=dim) == 1 elementwise.

Source code in src/pytorch_tabular/models/common/layers/activations.py
def sparsemoid(input):
    """Sparse sigmoid: the linear ramp ``0.5 * x + 0.5`` clipped to ``[0, 1]``.

    The input tensor itself is not modified.
    """
    shifted = 0.5 * input + 0.5
    return shifted.clamp(min=0, max=1)
Source code in src/pytorch_tabular/models/common/layers/activations.py
def t_softmax(input: Tensor, t: Tensor = None, dim: int = -1) -> Tensor:
    """Softmax with a cutoff ``t`` below the maximum.

    Computes ``softmax(x - max + log(relu(x - max + t) + 1e-8))`` along ``dim``.
    Entries more than ``t`` below the maximum get a weight of only 1e-8 and so
    receive (numerically) near-zero probability.

    Args:
        input: logits.
        t: non-negative cutoff, either a tensor broadcastable against ``input`` or a
            plain Python number. Defaults to 0.5.
        dim: dimension along which the softmax is computed.

    Returns:
        Probabilities along ``dim`` (shape of ``input`` broadcast with ``t``).
    """
    if t is None:
        t = torch.tensor(0.5, device=input.device)
    elif not isinstance(t, Tensor):
        # Accept plain numbers; the default float dtype matches the t=None path above.
        t = torch.tensor(float(t), device=input.device)
    assert (t >= 0.0).all()
    maxes = torch.max(input, dim=dim, keepdim=True).values
    # Shift by the max for numerical stability; the largest entry becomes 0.
    input_minus_maxes = input - maxes

    # The 1e-8 floor keeps log(w) finite for entries below the cutoff.
    w = torch.relu(input_minus_maxes + t) + 1e-8
    return torch.softmax(input_minus_maxes + torch.log(w), dim=dim)

Bases: Module

Source code in src/pytorch_tabular/models/common/layers/activations.py
class TSoftmax(torch.nn.Module):
    """Module form of ``t_softmax`` with the softmax dimension fixed at construction."""

    def __init__(self, dim: int = -1):
        super().__init__()
        self.dim = dim

    def forward(self, input: Tensor, t: Tensor) -> Tensor:
        """Apply ``t_softmax`` along ``self.dim`` with cutoff ``t``."""
        softmax_dim = self.dim
        return t_softmax(input, t=t, dim=softmax_dim)

Bases: Module

Source code in src/pytorch_tabular/models/common/layers/activations.py
class RSoftmax(torch.nn.Module):
    """Softmax variant whose sparsity is controlled by a target zero-fraction ``r``.

    For each call, a cutoff ``t`` is derived from ``r`` (see :meth:`calculate_t`)
    and passed to a ``TSoftmax`` over ``dim``.
    """

    def __init__(self, dim: int = -1, eps: float = 1e-8):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.tsoftmax = TSoftmax(dim=dim)

    @classmethod
    def calculate_t(cls, input: Tensor, r: Tensor, dim: int = -1, eps: float = 1e-8):
        """Compute the cutoff ``t`` from the target zero-fraction ``r``.

        Args:
            input: logits; entries may be ``-inf`` (e.g. masked positions).
            r: desired fraction of zero values, every entry in [0, 1].
            dim: dimension along which the fraction is measured.
            eps: small constant added to ``t``.

        Returns:
            Detached tensor ``t`` (with ``eps`` added).
        """
        # r represents what is the fraction of zero values that we want to have
        assert ((0.0 <= r) & (r <= 1.0)).all()

        maxes = torch.max(input, dim=dim, keepdim=True).values
        input_minus_maxes = input - maxes

        # Entries whose exp() underflows to exactly 0 already count as zeros.
        zeros_mask = torch.exp(input_minus_maxes) == 0.0
        zeros_frac = zeros_mask.sum(dim=dim, keepdim=True).float() / input_minus_maxes.shape[dim]

        # Fraction of the remaining (non-underflowed) entries that still has to be zeroed.
        q = torch.clamp((r - zeros_frac) / (1 - zeros_frac), min=0.0, max=1.0)
        # Replace underflowed entries by 0. masked_fill comes first because a plain
        # multiply by the 0/1 mask turns -inf logits into NaN (-inf * 0 == NaN), which
        # torch.quantile would propagate into t. The float mask multiply is kept only
        # so that dtype promotion (e.g. half -> float32) stays as before.
        x_minus_maxes = input_minus_maxes.masked_fill(zeros_mask, 0.0) * (~zeros_mask).float()
        if q.ndim > 1:
            # One quantile per slice: torch.quantile evaluates every q against every
            # slice and the diagonal keeps the matching pairs (assumes 2-D input with
            # dim=-1 and one q per row -- TODO confirm for other layouts).
            t = -torch.quantile(x_minus_maxes, q.view(-1), dim=dim, keepdim=True).detach()
            t = t.squeeze(dim).diagonal(dim1=-2, dim2=-1).unsqueeze(-1) + eps
        else:
            t = -torch.quantile(x_minus_maxes, q, dim=dim).detach() + eps
        return t

    def forward(self, input: Tensor, r: Tensor):
        t = RSoftmax.calculate_t(input, r, self.dim, self.eps)
        return self.tsoftmax(input, t)