
Common Modules

Embeddings

pytorch_tabular.models.common.layers.Embedding1dLayer(continuous_dim, categorical_embedding_dims, embedding_dropout=0.0, batch_norm_continuous_input=False)

Bases: nn.Module

Enables different values in a categorical feature to have different embeddings.

Source code in src/pytorch_tabular/models/common/layers/embeddings.py
def __init__(
    self,
    continuous_dim: int,
    categorical_embedding_dims: Tuple[int, int],
    embedding_dropout: float = 0.0,
    batch_norm_continuous_input: bool = False,
):
    super().__init__()
    self.continuous_dim = continuous_dim
    self.categorical_embedding_dims = categorical_embedding_dims
    self.batch_norm_continuous_input = batch_norm_continuous_input

    # Embedding layers
    self.cat_embedding_layers = nn.ModuleList([nn.Embedding(x, y) for x, y in categorical_embedding_dims])
    if embedding_dropout > 0:
        self.embd_dropout = nn.Dropout(embedding_dropout)
    else:
        self.embd_dropout = None
    # Continuous Layers
    if batch_norm_continuous_input:
        self.normalizing_batch_norm = nn.BatchNorm1d(continuous_dim)
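
A minimal usage sketch (the feature counts and embedding sizes below are illustrative, not from the library docs): the layer takes one (cardinality, embedding dimension) pair per categorical feature and builds an nn.Embedding for each.

from pytorch_tabular.models.common.layers import Embedding1dLayer

# Hypothetical setup: 3 continuous features plus two categorical features
# with cardinalities 5 and 10, embedded into 4 and 8 dimensions respectively.
embedding = Embedding1dLayer(
    continuous_dim=3,
    categorical_embedding_dims=[(5, 4), (10, 8)],
    embedding_dropout=0.1,
    batch_norm_continuous_input=True,
)
print(embedding.cat_embedding_layers)  # ModuleList with Embedding(5, 4) and Embedding(10, 8)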

pytorch_tabular.models.common.layers.Embedding2dLayer(continuous_dim, categorical_cardinality, embedding_dim, shared_embedding_strategy=None, frac_shared_embed=0.25, embedding_bias=False, batch_norm_continuous_input=False, embedding_dropout=0.0, initialization=None)

Bases: nn.Module

Embeds categorical and continuous features into a 2D tensor.

Parameters:

  • continuous_dim (int): number of continuous features
  • categorical_cardinality (List[int]): list of cardinalities of categorical features
  • embedding_dim (int): embedding dimension
  • shared_embedding_strategy (Optional[str], default None): strategy to use for shared embeddings
  • frac_shared_embed (float, default 0.25): fraction of embeddings to share
  • embedding_bias (bool, default False): whether to use bias in embedding layers
  • batch_norm_continuous_input (bool, default False): whether to use batch norm on continuous features
  • embedding_dropout (float, default 0.0): dropout to apply to embeddings
  • initialization (Optional[str], default None): initialization strategy to use for embedding layers

Source code in src/pytorch_tabular/models/common/layers/embeddings.py
def __init__(
    self,
    continuous_dim: int,
    categorical_cardinality: List[int],
    embedding_dim: int,
    shared_embedding_strategy: Optional[str] = None,
    frac_shared_embed: float = 0.25,
    embedding_bias: bool = False,
    batch_norm_continuous_input: bool = False,
    embedding_dropout: float = 0.0,
    initialization: Optional[str] = None,
):
    """
    Args:
        continuous_dim: number of continuous features
        categorical_cardinality: list of cardinalities of categorical features
        embedding_dim: embedding dimension
        shared_embedding_strategy: strategy to use for shared embeddings
        frac_shared_embed: fraction of embeddings to share
        embedding_bias: whether to use bias in embedding layers
        batch_norm_continuous_input: whether to use batch norm on continuous features
        embedding_dropout: dropout to apply to embeddings
        initialization: initialization strategy to use for embedding layers"""
    super().__init__()
    self.continuous_dim = continuous_dim
    self.categorical_cardinality = categorical_cardinality
    self.embedding_dim = embedding_dim
    self.batch_norm_continuous_input = batch_norm_continuous_input
    self.shared_embedding_strategy = shared_embedding_strategy
    self.frac_shared_embed = frac_shared_embed
    self.embedding_bias = embedding_bias
    self.initialization = initialization
    d_sqrt_inv = 1 / math.sqrt(embedding_dim)
    if initialization is not None:
        assert initialization in [
            "kaiming_uniform",
            "kaiming_normal",
        ], "initialization should be either of `kaiming` or `uniform`"
        self._do_kaiming_initialization = True
        self._initialize_kaiming = partial(
            _initialize_kaiming,
            initialization=initialization,
            d_sqrt_inv=d_sqrt_inv,
        )
    else:
        self._do_kaiming_initialization = False

    # cat Embedding layers
    if self.shared_embedding_strategy is not None:
        self.cat_embedding_layers = nn.ModuleList(
            [
                SharedEmbeddings(
                    c,
                    self.embedding_dim,
                    add_shared_embed=(self.shared_embedding_strategy == "add"),
                    frac_shared_embed=self.frac_shared_embed,
                )
                for c in categorical_cardinality
            ]
        )
        if self._do_kaiming_initialization:
            for embedding_layer in self.cat_embedding_layers:
                self._initialize_kaiming(embedding_layer.embed.weight)
                self._initialize_kaiming(embedding_layer.shared_embed)
    else:
        self.cat_embedding_layers = nn.ModuleList(
            [nn.Embedding(c, self.embedding_dim) for c in categorical_cardinality]
        )
        if self._do_kaiming_initialization:
            for embedding_layer in self.cat_embedding_layers:
                self._initialize_kaiming(embedding_layer.weight)
    if embedding_bias:
        self.cat_embedding_bias = nn.Parameter(torch.Tensor(len(self.categorical_cardinality), self.embedding_dim))
        if self._do_kaiming_initialization:
            self._initialize_kaiming(self.cat_embedding_bias)
    # Continuous Embedding Layer
    self.cont_embedding_layer = nn.Embedding(self.continuous_dim, self.embedding_dim)
    if self._do_kaiming_initialization:
        self._initialize_kaiming(self.cont_embedding_layer.weight)
    if embedding_bias:
        self.cont_embedding_bias = nn.Parameter(torch.Tensor(self.continuous_dim, self.embedding_dim))
        if self._do_kaiming_initialization:
            self._initialize_kaiming(self.cont_embedding_bias)
    if batch_norm_continuous_input:
        self.normalizing_batch_norm = nn.BatchNorm1d(continuous_dim)
    if embedding_dropout > 0:
        self.embd_dropout = nn.Dropout(embedding_dropout)
    else:
        self.embd_dropout = None
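
A hedged construction sketch with illustrative sizes. With shared_embedding_strategy="add" each SharedEmbeddings module adds a full-width shared vector to every embedding; any other non-None value makes it replace a frac_shared_embed portion of the embedding instead (see SharedEmbeddings below).

from pytorch_tabular.models.common.layers import Embedding2dLayer

# Hypothetical setup: 2 continuous features and two categorical features with
# cardinalities 4 and 7, all embedded into a common 16-dimensional space.
embedding = Embedding2dLayer(
    continuous_dim=2,
    categorical_cardinality=[4, 7],
    embedding_dim=16,
    shared_embedding_strategy="add",   # shared vector is added to every embedding
    frac_shared_embed=0.25,
    embedding_bias=True,
    batch_norm_continuous_input=True,
    embedding_dropout=0.1,
    initialization="kaiming_uniform",  # or "kaiming_normal", per the assert above
)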

pytorch_tabular.models.common.layers.PreEncoded1dLayer(continuous_dim, categorical_dim, embedding_dropout=0.0, batch_norm_continuous_input=False)

Bases: nn.Module

Takes in pre-encoded categorical variables and simply concatenates them with the continuous variables. There is no learnable component.

Source code in src/pytorch_tabular/models/common/layers/embeddings.py
def __init__(
    self,
    continuous_dim: int,
    categorical_dim: Tuple[int, int],
    embedding_dropout: float = 0.0,
    batch_norm_continuous_input: bool = False,
):
    super().__init__()
    self.continuous_dim = continuous_dim
    self.categorical_dim = categorical_dim
    self.batch_norm_continuous_input = batch_norm_continuous_input

    if embedding_dropout > 0:
        self.embd_dropout = nn.Dropout(embedding_dropout)
    else:
        self.embd_dropout = None
    # Continuous Layers
    if batch_norm_continuous_input:
        self.normalizing_batch_norm = nn.BatchNorm1d(continuous_dim)

pytorch_tabular.models.common.layers.SharedEmbeddings(num_embed, embed_dim, add_shared_embed=False, frac_shared_embed=0.25)

Bases: nn.Module

Enables different values in a categorical feature to share part of their embeddings.

Source code in src/pytorch_tabular/models/common/layers/embeddings.py
def __init__(
    self,
    num_embed: int,
    embed_dim: int,
    add_shared_embed: bool = False,
    frac_shared_embed: float = 0.25,
):
    super().__init__()
    assert frac_shared_embed < 1, "'frac_shared_embed' must be less than 1"

    self.add_shared_embed = add_shared_embed
    self.embed = nn.Embedding(num_embed, embed_dim, padding_idx=0)
    self.embed.weight.data.clamp_(-2, 2)
    if add_shared_embed:
        col_embed_dim = embed_dim
    else:
        col_embed_dim = int(embed_dim * frac_shared_embed)
    self.shared_embed = nn.Parameter(torch.empty(1, col_embed_dim).uniform_(-1, 1))
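
A short sketch of what the constructor above builds: in the default replace mode the shared block spans int(embed_dim * frac_shared_embed) columns, while add_shared_embed=True makes it span the full embedding width.

from pytorch_tabular.models.common.layers import SharedEmbeddings

shared_frac = SharedEmbeddings(num_embed=10, embed_dim=16, frac_shared_embed=0.25)
print(shared_frac.shared_embed.shape)  # torch.Size([1, 4]) -> int(16 * 0.25) shared columns

shared_add = SharedEmbeddings(num_embed=10, embed_dim=16, add_shared_embed=True)
print(shared_add.shared_embed.shape)   # torch.Size([1, 16]) -> added to every embedding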

Gated Units

pytorch_tabular.models.common.layers.GatedFeatureLearningUnit(n_features_in, n_stages, feature_mask_function=entmax15, feature_sparsity=0.3, learnable_sparsity=True, dropout=0.0)

Bases: nn.Module

Source code in src/pytorch_tabular/models/common/layers/gated_units.py
def __init__(
    self,
    n_features_in: int,
    n_stages: int,
    feature_mask_function: Callable = entmax15,
    feature_sparsity: float = 0.3,
    learnable_sparsity: bool = True,
    dropout: float = 0.0,
):
    super().__init__()
    self.n_features_in = n_features_in
    self.n_features_out = n_features_in
    self.feature_mask_function = feature_mask_function
    self._dropout = dropout
    self.n_stages = n_stages
    self.feature_sparsity = feature_sparsity
    self.learnable_sparsity = learnable_sparsity
    self._build_network()

pytorch_tabular.models.common.layers.GEGLU(d_model, d_ff, dropout=0.1)

Bases: nn.Module

GEGLU: a Gated Linear Unit variant that uses the GELU activation.

Parameters:

  • d_model (int): dimension of the model
  • d_ff (int): dimension of the feedforward layer
  • dropout (float, default 0.1): dropout probability

Source code in src/pytorch_tabular/models/common/layers/gated_units.py
def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
    """
    Args:
        d_model: dimension of the model
        d_ff: dimension of the feedforward layer
        dropout: dropout probability
    """
    super().__init__()
    self.ffn = PositionWiseFeedForward(d_model, d_ff, dropout, nn.GELU(), True, False, False, False)

pytorch_tabular.models.common.layers.ReGLU(d_model, d_ff, dropout=0.1)

Bases: nn.Module

ReGLU: a Gated Linear Unit variant that uses the ReLU activation.

Parameters:

  • d_model (int): dimension of the model
  • d_ff (int): dimension of the feedforward layer
  • dropout (float, default 0.1): dropout probability

Source code in src/pytorch_tabular/models/common/layers/gated_units.py
def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
    """
    Args:
        d_model: dimension of the model
        d_ff: dimension of the feedforward layer
        dropout: dropout probability
    """
    super().__init__()
    self.ffn = PositionWiseFeedForward(d_model, d_ff, dropout, nn.ReLU(), True, False, False, False)

pytorch_tabular.models.common.layers.SwiGLU(d_model, d_ff, dropout=0.1)

Bases: nn.Module

Parameters:

  • d_model (int): dimension of the model
  • d_ff (int): dimension of the feedforward layer
  • dropout (float, default 0.1): dropout probability

Source code in src/pytorch_tabular/models/common/layers/gated_units.py
def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
    """
    Args:
        d_model: dimension of the model
        d_ff: dimension of the feedforward layer
        dropout: dropout probability
    """
    super().__init__()
    self.ffn = PositionWiseFeedForward(d_model, d_ff, dropout, nn.SiLU(), True, False, False, False)

pytorch_tabular.models.common.layers.PositionWiseFeedForward(d_model, d_ff, dropout=0.1, activation=nn.ReLU(), is_gated=False, bias1=True, bias2=True, bias_gate=True)

Bases: nn.Module


Position-wise Feed-Forward Network (FFN)

This is a PyTorch implementation of the position-wise feedforward network used in transformers. The FFN consists of two fully connected layers. The number of dimensions in the hidden layer, $d_{ff}$, is generally set to around four times that of the token embedding, $d_{model}$, so it is sometimes also called the expand-and-contract network.

There is an activation at the hidden layer, usually ReLU (Rectified Linear Unit), $\max(0, x)$. That is, the FFN function is

$$FFN(x, W_1, W_2, b_1, b_2) = \max(0, x W_1 + b_1) W_2 + b_2$$

where $W_1$, $W_2$, $b_1$ and $b_2$ are learnable parameters. Sometimes the GELU (Gaussian Error Linear Unit) activation, $x \Phi(x)$ where $\Phi(x) = P(X \le x),\ X \sim \mathcal{N}(0, 1)$, is used instead of ReLU.

Gated Linear Units

This is a generic implementation that supports different variants including Gated Linear Units (GLU).

  • d_model is the number of features in a token embedding
  • d_ff is the number of features in the hidden layer of the FFN
  • dropout is dropout probability for the hidden layer
  • is_gated specifies whether the hidden layer is gated
  • bias1 specifies whether the first fully connected layer should have a learnable bias
  • bias2 specifies whether the second fully connected layer should have a learnable bias
  • bias_gate specifies whether the fully connected layer for the gate should have a learnable bias
Source code in src/pytorch_tabular/models/common/layers/gated_units.py
def __init__(
    self,
    d_model: int,
    d_ff: int,
    dropout: float = 0.1,
    activation=nn.ReLU(),
    is_gated: bool = False,
    bias1: bool = True,
    bias2: bool = True,
    bias_gate: bool = True,
):
    """
    * `d_model` is the number of features in a token embedding
    * `d_ff` is the number of features in the hidden layer of the FFN
    * `dropout` is dropout probability for the hidden layer
    * `is_gated` specifies whether the hidden layer is gated
    * `bias1` specified whether the first fully connected layer should have a learnable bias
    * `bias2` specified whether the second fully connected layer should have a learnable bias
    * `bias_gate` specified whether the fully connected layer for the gate should have a learnable bias
    """
    super().__init__()
    # Layer one parameterized by weight $W_1$ and bias $b_1$
    self.layer1 = nn.Linear(d_model, d_ff, bias=bias1)
    # Layer two parameterized by weight $W_2$ and bias $b_2$
    self.layer2 = nn.Linear(d_ff, d_model, bias=bias2)
    # Hidden layer dropout
    self.dropout = nn.Dropout(dropout)
    # Activation function $f$
    self.activation = activation
    # Whether there is a gate
    self.is_gated = is_gated
    if is_gated:
        # If there is a gate the linear layer to transform inputs to
        # be multiplied by the gate, parameterized by weight $V$ and bias $c$
        self.linear_v = nn.Linear(d_model, d_ff, bias=bias_gate)
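
To make the formulas above concrete, here is an illustrative sketch of the gated position-wise FFN computation that this constructor sets up. This is not the library's forward method, just the standard GLU-variant formula the parameters map onto: $FFN_{GLU}(x) = (f(x W_1 + b_1) \odot (x V + c)) W_2 + b_2$.

import torch
import torch.nn as nn

# Illustrative sketch only, not the library's forward pass.
d_model, d_ff = 32, 128
layer1 = nn.Linear(d_model, d_ff)     # W_1, b_1
linear_v = nn.Linear(d_model, d_ff)   # V, c (the gate)
layer2 = nn.Linear(d_ff, d_model)     # W_2, b_2
activation = nn.GELU()                # GELU here corresponds to the GEGLU variant

x = torch.randn(8, 16, d_model)                 # [batch, tokens, d_model]
hidden = activation(layer1(x)) * linear_v(x)    # gated hidden layer
out = layer2(hidden)                            # contract back to d_model
print(out.shape)                                # torch.Size([8, 16, 32])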

Soft Trees

pytorch_tabular.models.common.layers.NeuralDecisionTree(depth, n_features, dropout=0, binning_activation=entmax15, feature_mask_function=entmax15, feature_sparsity=0.8, learnable_sparsity=True)

Bases: nn.Module

Source code in src/pytorch_tabular/models/common/layers/soft_trees.py
def __init__(
    self,
    depth: int,
    n_features: int,
    dropout: float = 0,
    binning_activation: Callable = entmax15,
    feature_mask_function: Callable = entmax15,
    feature_sparsity: float = 0.8,
    learnable_sparsity: bool = True,
):
    super().__init__()
    self.depth = depth
    self._num_cutpoints = 1
    self.n_features = n_features
    self._dropout = dropout
    self.binning_activation = binning_activation
    self.feature_mask_function = feature_mask_function
    self.feature_sparsity = feature_sparsity
    self.learnable_sparsity = learnable_sparsity
    self._build_network()

pytorch_tabular.models.common.layers.ODST(in_features, num_trees, depth=6, tree_output_dim=1, flatten_output=True, choice_function=sparsemax, bin_function=sparsemoid, initialize_response_=nn.init.normal_, initialize_selection_logits_=nn.init.uniform_, threshold_init_beta=1.0, threshold_init_cutoff=1.0)

Bases: ModuleWithInit

Oblivious Differentiable Sparsemax Trees (see http://tinyurl.com/odst-readmore). This module can be used as a drop-in replacement for nn.Linear.

Parameters:

  • in_features: number of features in the input tensor
  • num_trees: number of trees in this layer
  • tree_output_dim: number of response channels in the response of an individual tree
  • depth: number of splits in every tree
  • flatten_output: if False, returns [..., num_trees, tree_dim]; by default returns [..., num_trees * tree_dim]
  • choice_function: f(tensor, dim) -> R_simplex, computes feature weights such that f(tensor, dim).sum(dim) == 1
  • bin_function: f(tensor) -> R[0, 1], computes tree leaf weights
  • initialize_response_: in-place initializer for the tree output tensor
  • initialize_selection_logits_: in-place initializer for the logits that select features for the tree; both thresholds and scales are initialized with a data-aware init (or .load_state_dict)
  • threshold_init_beta: initializes each threshold to a q-th quantile of the data points, where q ~ Beta(threshold_init_beta, threshold_init_beta). If set to 1, initial thresholds have the same distribution as the data points; if greater than 1 (e.g. 10), thresholds are closer to the median data value; if less than 1 (e.g. 0.1), thresholds approach the min/max data values.
  • threshold_init_cutoff: threshold log-temperature initializer, in (0, inf). By default (1.0), log-temperatures are initialized so that all bin selectors end up in the linear region of the sparse-sigmoid; the temperatures are then scaled by this parameter. Setting this value > 1.0 results in some margin between data points and the sparse-sigmoid cutoff; setting it < 1.0 causes a (1 - value) fraction of data points to end up in the flat sparse-sigmoid region (for instance, threshold_init_cutoff = 0.9 sets 10% of points equal to 0.0 or 1.0). All points will lie between (0.5 - 0.5 / threshold_init_cutoff) and (0.5 + 0.5 / threshold_init_cutoff).

Source code in src/pytorch_tabular/models/common/layers/soft_trees.py
def __init__(
    self,
    in_features,
    num_trees,
    depth=6,
    tree_output_dim=1,
    flatten_output=True,
    choice_function=sparsemax,
    bin_function=sparsemoid,
    initialize_response_=nn.init.normal_,
    initialize_selection_logits_=nn.init.uniform_,
    threshold_init_beta=1.0,
    threshold_init_cutoff=1.0,
):
    """Oblivious Differentiable Sparsemax Trees. http://tinyurl.com/odst-readmore One can drop (sic!) this module
    anywhere instead of nn.Linear.

    :param in_features: number of features in the input tensor
    :param num_trees: number of trees in this layer
    :param tree_dim: number of response channels in the response of individual tree
    :param depth: number of splits in every tree
    :param flatten_output: if False, returns [..., num_trees, tree_dim],
        by default returns [..., num_trees * tree_dim]
    :param choice_function: f(tensor, dim) -> R_simplex computes feature weights s.t. f(tensor, dim).sum(dim) == 1
    :param bin_function: f(tensor) -> R[0, 1], computes tree leaf weights

    :param initialize_response_: in-place initializer for tree output tensor
    :param initialize_selection_logits_: in-place initializer for logits that select features for the tree
    both thresholds and scales are initialized with data-aware init (or .load_state_dict)
    :param threshold_init_beta: initializes threshold to a q-th quantile of data points
        where q ~ Beta(:threshold_init_beta:, :threshold_init_beta:)
        If this param is set to 1, initial thresholds will have the same distribution as data points
        If greater than 1 (e.g. 10), thresholds will be closer to median data value
        If less than 1 (e.g. 0.1), thresholds will approach min/max data values.

    :param threshold_init_cutoff: threshold log-temperatures initializer, in (0, inf)
        By default(1.0), log-remperatures are initialized in such a way that all bin selectors
        end up in the linear region of sparse-sigmoid. The temperatures are then scaled by this parameter.
        Setting this value > 1.0 will result in some margin between data points and sparse-sigmoid cutoff value
        Setting this value < 1.0 will cause (1 - value) part of data points to end up in flat sparse-sigmoid region
        For instance, threshold_init_cutoff = 0.9 will set 10% points equal to 0.0 or 1.0
        Setting this value > 1.0 will result in a margin between data points and sparse-sigmoid cutoff value
        All points will be between (0.5 - 0.5 / threshold_init_cutoff) and (0.5 + 0.5 / threshold_init_cutoff)
    """
    super().__init__()
    self.depth, self.num_trees, self.tree_dim, self.flatten_output = (
        depth,
        num_trees,
        tree_output_dim,
        flatten_output,
    )
    self.choice_function, self.bin_function = choice_function, bin_function
    self.threshold_init_beta, self.threshold_init_cutoff = (
        threshold_init_beta,
        threshold_init_cutoff,
    )

    self.response = nn.Parameter(torch.zeros([num_trees, tree_output_dim, 2**depth]), requires_grad=True)
    initialize_response_(self.response)

    self.feature_selection_logits = nn.Parameter(torch.zeros([in_features, num_trees, depth]), requires_grad=True)
    initialize_selection_logits_(self.feature_selection_logits)

    self.feature_thresholds = nn.Parameter(
        torch.full([num_trees, depth], float("nan"), dtype=torch.float32),
        requires_grad=True,
    )  # nan values will be initialized on first batch (data-aware init)
    self.log_temperatures = nn.Parameter(
        torch.full([num_trees, depth], float("nan"), dtype=torch.float32),
        requires_grad=True,
    )

    # binary codes for mapping between 1-hot vectors and bin indices
    with torch.no_grad():
        indices = torch.arange(2**self.depth)
        offsets = 2 ** torch.arange(self.depth)
        bin_codes = (indices.view(1, -1) // offsets.view(-1, 1) % 2).to(torch.float32)
        bin_codes_1hot = torch.stack([bin_codes, 1.0 - bin_codes], dim=-1)
        self.bin_codes_1hot = nn.Parameter(bin_codes_1hot, requires_grad=False)
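
A hedged usage sketch following the drop-in-for-nn.Linear description above; the output shape is inferred from the flatten_output documentation, and the thresholds are data-aware initialized on the first batch.

import torch
from pytorch_tabular.models.common.layers import ODST

layer = ODST(in_features=32, num_trees=8, depth=4, tree_output_dim=3)
x = torch.randn(64, 32)
y = layer(x)    # first call triggers the data-aware initialization
print(y.shape)  # expected: torch.Size([64, 24]) = [batch, num_trees * tree_output_dim]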

Transformers

pytorch_tabular.models.common.layers.AddNorm(input_dim, dropout)

Bases: nn.Module

Applies LayerNorm, Dropout and adds to input.

The standard Add & Norm operation used in Transformers.

Source code in src/pytorch_tabular/models/common/layers/transformers.py
def __init__(self, input_dim: int, dropout: float):
    super().__init__()
    self.dropout = nn.Dropout(dropout)
    self.ln = nn.LayerNorm(input_dim)

pytorch_tabular.models.common.layers.AppendCLSToken(d_token, initialization)

Bases: nn.Module

Appends the [CLS] token for BERT-like inference.

Initialize self.

Source code in src/pytorch_tabular/models/common/layers/transformers.py
def __init__(self, d_token: int, initialization: str) -> None:
    """Initialize self."""
    super().__init__()
    self.weight = nn.Parameter(torch.Tensor(d_token))
    d_sqrt_inv = 1 / math.sqrt(d_token)
    _initialize_kaiming(self.weight, initialization, d_sqrt_inv)

forward(x)

Perform the forward pass.

Source code in src/pytorch_tabular/models/common/layers/transformers.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Perform the forward pass."""
    assert x.ndim == 3
    return torch.cat([x, self.weight.view(1, 1, -1).repeat(len(x), 1, 1)], dim=1)
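
A usage sketch based on the forward pass shown above: one learnable d_token vector is concatenated to every sequence in the batch (the initialization value below is illustrative).

import torch
from pytorch_tabular.models.common.layers import AppendCLSToken

append_cls = AppendCLSToken(d_token=16, initialization="kaiming_uniform")
x = torch.randn(4, 10, 16)   # [batch, n_tokens, d_token]
print(append_cls(x).shape)   # torch.Size([4, 11, 16]) -> one extra token appended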

pytorch_tabular.models.common.layers.MultiHeadedAttention(input_dim, num_heads=8, head_dim=16, dropout=0.1, keep_attn=True)

Bases: nn.Module

Multi Headed Attention Block in Transformers.

Source code in src/pytorch_tabular/models/common/layers/transformers.py
def __init__(
    self,
    input_dim: int,
    num_heads: int = 8,
    head_dim: int = 16,
    dropout: int = 0.1,
    keep_attn: bool = True,
):
    super().__init__()
    assert input_dim % num_heads == 0, "'input_dim' must be multiples of 'num_heads'"
    inner_dim = head_dim * num_heads
    self.n_heads = num_heads
    self.scale = head_dim**-0.5
    self.keep_attn = keep_attn

    self.to_qkv = nn.Linear(input_dim, inner_dim * 3, bias=False)
    self.to_out = nn.Linear(inner_dim, input_dim)

    self.dropout = nn.Dropout(dropout)
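
A hedged construction sketch with illustrative sizes; per the constructor above, queries, keys and values come from a single bias-free projection into 3 * num_heads * head_dim features.

from pytorch_tabular.models.common.layers import MultiHeadedAttention

mha = MultiHeadedAttention(input_dim=32, num_heads=8, head_dim=16, dropout=0.1)
print(mha.to_qkv)  # Linear(in_features=32, out_features=384, bias=False) -> 3 * 8 * 16
print(mha.to_out)  # Linear(in_features=128, out_features=32, bias=True)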

pytorch_tabular.models.common.layers.TransformerEncoderBlock(input_embed_dim, num_heads=8, ff_hidden_multiplier=4, ff_activation='GEGLU', attn_dropout=0.1, keep_attn=True, ff_dropout=0.1, add_norm_dropout=0.1, transformer_head_dim=None)

Bases: nn.Module

A single Transformer Encoder Block.

Parameters:

  • input_embed_dim (int): the input embedding dimension
  • num_heads (int, default 8): the number of attention heads
  • ff_hidden_multiplier (int, default 4): the hidden dimension multiplier for the position-wise feed-forward layer
  • ff_activation (str, default 'GEGLU'): the activation function for the position-wise feed-forward layer
  • attn_dropout (float, default 0.1): the dropout probability for the attention layer
  • keep_attn (bool, default True): whether to keep the attention weights
  • ff_dropout (float, default 0.1): the dropout probability for the position-wise feed-forward layer
  • add_norm_dropout (float, default 0.1): the dropout probability for the residual connections
  • transformer_head_dim (Optional[int], default None): the dimension of the attention heads; if None, defaults to input_embed_dim

Source code in src/pytorch_tabular/models/common/layers/transformers.py
def __init__(
    self,
    input_embed_dim: int,
    num_heads: int = 8,
    ff_hidden_multiplier: int = 4,
    ff_activation: str = "GEGLU",
    attn_dropout: float = 0.1,
    keep_attn: bool = True,
    ff_dropout: float = 0.1,
    add_norm_dropout: float = 0.1,
    transformer_head_dim: Optional[int] = None,
):
    """
    Args:
        input_embed_dim: The input embedding dimension
        num_heads: The number of attention heads
        ff_hidden_multiplier: The hidden dimension multiplier for the position-wise feed-forward layer
        ff_activation: The activation function for the position-wise feed-forward layer
        attn_dropout: The dropout probability for the attention layer
        keep_attn: Whether to keep the attention weights
        ff_dropout: The dropout probability for the position-wise feed-forward layer
        add_norm_dropout: The dropout probability for the residual connections
        transformer_head_dim: The dimension of the attention heads. If None, will default to input_embed_dim
    """
    super().__init__()
    self.mha = MultiHeadedAttention(
        input_embed_dim,
        num_heads,
        head_dim=input_embed_dim if transformer_head_dim is None else transformer_head_dim,
        dropout=attn_dropout,
        keep_attn=keep_attn,
    )

    try:
        self.pos_wise_ff = GATED_UNITS[ff_activation](
            d_model=input_embed_dim,
            d_ff=input_embed_dim * ff_hidden_multiplier,
            dropout=ff_dropout,
        )
    except AttributeError:
        self.pos_wise_ff = PositionWiseFeedForward(
            d_model=input_embed_dim,
            d_ff=input_embed_dim * ff_hidden_multiplier,
            dropout=ff_dropout,
            activation=getattr(nn, self.hparams.ff_activation),
        )
    self.attn_add_norm = AddNorm(input_embed_dim, add_norm_dropout)
    self.ff_add_norm = AddNorm(input_embed_dim, add_norm_dropout)
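
A hedged construction sketch with illustrative sizes; note that input_embed_dim must be divisible by num_heads, per the assertion in MultiHeadedAttention.

from pytorch_tabular.models.common.layers import TransformerEncoderBlock

block = TransformerEncoderBlock(
    input_embed_dim=32,      # 32 % 8 == 0, as required by the attention layer
    num_heads=8,
    ff_hidden_multiplier=4,  # GEGLU feed-forward with hidden size 4 * 32
    ff_activation="GEGLU",
    attn_dropout=0.1,
    ff_dropout=0.1,
    add_norm_dropout=0.1,
)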

Miscellaneous

pytorch_tabular.models.common.layers.Lambda(func)

Bases: nn.Module

A wrapper for a lambda function as a pytorch module.

Initialize the lambda module.

Parameters:

  • func (Callable): any function/callable

Source code in src/pytorch_tabular/models/common/layers/misc.py
def __init__(self, func: Callable):
    """Initialize lambda module
    Args:
        func: any function/callable
    """
    super().__init__()
    self.func = func
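
A minimal sketch, assuming the wrapped callable is applied to the module's input on the forward pass.

import torch
from pytorch_tabular.models.common.layers import Lambda

double = Lambda(lambda x: x * 2)
print(double(torch.tensor([1.0, 2.0])))  # tensor([2., 4.]), assuming forward simply calls self.func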

pytorch_tabular.models.common.layers.ModuleWithInit()

Bases: nn.Module

Base class for pytorch module with data-aware initializer on first batch.

Source code in src/pytorch_tabular/models/common/layers/misc.py
def __init__(self):
    super().__init__()
    self._is_initialized_tensor = nn.Parameter(torch.tensor(0, dtype=torch.uint8), requires_grad=False)
    self._is_initialized_bool = None

initialize(*args, **kwargs)

Initialize module tensors using first batch of data.

Source code in src/pytorch_tabular/models/common/layers/misc.py
def initialize(self, *args, **kwargs):
    """Initialize module tensors using first batch of data."""
    raise NotImplementedError("Please implement ")

pytorch_tabular.models.common.layers.Residual(fn)

Bases: nn.Module

Source code in src/pytorch_tabular/models/common/layers/misc.py
def __init__(self, fn):
    super().__init__()
    self.fn = fn

Activations

pytorch_tabular.models.common.layers.activations.Entmoid15

Bases: Function

A highly optimized equivalent of lambda x: Entmax15([x, 0]).

pytorch_tabular.models.common.layers.activations.entmoid15 = Entmoid15.apply module-attribute

pytorch_tabular.models.common.layers.activations.entmax15 = entmax15 module-attribute

pytorch_tabular.models.common.layers.activations.sparsemax = sparsemax module-attribute

pytorch_tabular.models.common.layers.activations.sparsemoid(input)

Source code in src/pytorch_tabular/models/common/layers/activations.py
def sparsemoid(input):
    return (0.5 * input + 0.5).clamp_(0, 1)

pytorch_tabular.models.common.layers.activations.t_softmax(input, t=None, dim=-1)

Source code in src/pytorch_tabular/models/common/layers/activations.py
def t_softmax(input: Tensor, t: Tensor = None, dim: int = -1) -> Tensor:
    if t is None:
        t = torch.tensor(0.5, device=input.device)
    assert (t >= 0.0).all()
    maxes = torch.max(input, dim=dim, keepdim=True).values
    input_minus_maxes = input - maxes

    w = torch.relu(input_minus_maxes + t) + 1e-8
    return torch.softmax(input_minus_maxes + torch.log(w), dim=dim)
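
A short example based on the source above: entries that fall more than t below the row maximum receive essentially zero weight, so smaller t gives sparser distributions and larger t approaches the regular softmax.

import torch
from pytorch_tabular.models.common.layers.activations import t_softmax

logits = torch.tensor([[3.0, 2.8, 0.0]])
print(t_softmax(logits, t=torch.tensor(0.5)))  # mass only on the first two entries
print(t_softmax(logits, t=torch.tensor(5.0)))  # close to torch.softmax(logits, dim=-1)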

pytorch_tabular.models.common.layers.activations.TSoftmax(dim=-1)

Bases: torch.nn.Module

Source code in src/pytorch_tabular/models/common/layers/activations.py
def __init__(self, dim: int = -1):
    super().__init__()
    self.dim = dim

pytorch_tabular.models.common.layers.activations.RSoftmax(dim=-1, eps=1e-08)

Bases: torch.nn.Module

Source code in src/pytorch_tabular/models/common/layers/activations.py
def __init__(self, dim: int = -1, eps: float = 1e-8):
    super().__init__()
    self.dim = dim
    self.eps = eps
    self.tsoftmax = TSoftmax(dim=dim)