import os
import random

import numpy as np
import pandas as pd
from sklearn.metrics import mean_absolute_error, mean_squared_error
from sklearn.model_selection import train_test_split

from pytorch_tabular.utils import make_mixed_dataset, print_metrics

# %load_ext autoreload
# %autoreload 2
data, cat_col_names, num_col_names = make_mixed_dataset(
    task="regression", n_samples=10000, n_features=20, n_categories=4
)
train, test = train_test_split(data, random_state=42)
train, val = train_test_split(train, random_state=42)

Importing the Library

from pytorch_tabular import TabularModel
from pytorch_tabular.models import CategoryEmbeddingModelConfig
from pytorch_tabular.config import (
    DataConfig,
    OptimizerConfig,
    TrainerConfig,
    ExperimentConfig,
    ModelConfig,
)
from pytorch_tabular.models import BaseModel
from pytorch_tabular.models.common.layers import Embedding1dLayer
from pytorch_tabular.models.common.heads import LinearHeadConfig

Defining a Custom Model

import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import DictConfig
from typing import Dict
from dataclasses import dataclass, field

PyTorch Tabular is very easy to extend and infinitely customizable. All the models implemented in PyTorch Tabular inherit from an abstract class, BaseModel, which is in fact a PyTorch Lightning model.

It handles the major plumbing, like decoding the config params and setting up the loss and metrics. It also computes the loss and metrics and feeds them back to the PyTorch Lightning Trainer, which performs the back-propagation.

If we look at the anatomy of a PyTorch Tabular model, there are three main components:

  1. Embedding Layer
  2. Backbone
  3. Head

The Embedding Layer takes the input from the dataloader, which is a dictionary with categorical and continuous tensors under those keys, and converts it into a single tensor, handling the categorical tensors the right way. There are two embedding layers already implemented: Embedding1dLayer and Embedding2dLayer.
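
As a quick illustration of this contract, the following sketch (with made-up dimensions) builds an Embedding1dLayer and passes a dummy batch dictionary through it. The constructor arguments mirror the ones used later in this tutorial; the layer is assumed to accept the batch dictionary directly.

# A minimal sketch of the embedding-layer contract; the dimensions here are made up.
# Assumes Embedding1dLayer's forward accepts the batch dictionary directly.
import torch
from pytorch_tabular.models.common.layers import Embedding1dLayer

emb = Embedding1dLayer(
    continuous_dim=3,  # three continuous features (hypothetical)
    categorical_embedding_dims=[(5, 2), (10, 4)],  # (cardinality, embedding dim) per categorical column
)
dummy_batch = {
    "continuous": torch.randn(32, 3),  # float features
    "categorical": torch.randint(0, 5, (32, 2)),  # integer category codes
}
out = emb(dummy_batch)  # a single tensor of shape (32, 3 + 2 + 4)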

Backbone is the main model architecture.

Head is the linear layer (or layers) that takes the output of the backbone and converts it into the output we desire.

To encapsulate and enforce this structure, BaseModel requires us to define three property methods:

  1. def embedding_layer(self)
  2. def backbone(self)
  3. def head(self)

There is another method, def _build_network(self), which must also be defined. We use this method to define the embedding layer, backbone, head, and whatever other layers or components the model needs.
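
Putting these requirements together, the bare structure of a custom model looks roughly like the sketch below. This is illustrative only (the class name and placeholder attributes are not part of the tutorial); the full working model is defined further down.

# A structural sketch of what BaseModel expects from a custom model; placeholders only.
class SkeletonModel(BaseModel):
    def _build_network(self):
        # Save every component that forward needs to self
        self._embedding_layer = ...  # converts the input dictionary into a single tensor
        self._backbone = ...  # the main architecture
        self._head = ...  # maps the backbone output to the desired output

    @property
    def embedding_layer(self):
        return self._embedding_layer

    @property
    def backbone(self):
        return self._backbone

    @property
    def head(self):
        return self._head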

For a standard feed-forward head, we can also use a handy method in BaseModel called _get_head_from_config, which uses the head and head_config you have set in the ModelConfig to initialize the right head automatically for you.

An example of using it:

self._head = self._get_head_from_config()

For standard flows, the forward method that is already defined would be enough.

def forward(self, x: Dict):
    # Embeds the input dictionary and returns a single tensor
    x = self.embed_input(x)
    # Takes the tensor input and does representation learning
    x = self.compute_backbone(x)
    # Transforms the backbone output to the desired output, applies the target range
    # if necessary, and packs the output in the expected format
    return self.compute_head(x)

What this allows us to do is to define any standard PyTorch model and use it as the backbone, and then use the rest of the PyTorch Tabular machinery to train, evaluate and log the model.

While this is the bare minimum, you can redefine or use any of the standard PyTorch Lightning methods to tweak your model and training to your liking. The real beauty of this setup is that how much customization you do is entirely up to you. For standard workflows, you can change the bare minimum; for highly unusual models, you can redefine any of the methods in BaseModel (as long as the interfaces remain unchanged).

In addition to the model, you will also need to define a config. Configs are Python dataclasses that inherit from ModelConfig and therefore carry all of ModelConfig's parameters by default. Any additional parameters should be defined in the dataclass.

Key things to note:

  1. All the parameters in the different configs (like TrainerConfig, OptimizerConfig, etc.) are available in config before calling super() and in self.hparams after.
  2. The input batch to the forward method is a dictionary with the keys continuous and categorical (see the sketch right after this list).
  3. In the _build_network method, save every component that you want to access in forward to self.
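
For point 2, here is a rough sketch of what such a batch looks like; the tensors are random stand-ins, and the shapes simply mirror the dataset and batch size used in this tutorial.

# A rough sketch (illustrative only) of the batch dictionary the dataloader passes to forward.
dummy_batch = {
    "continuous": torch.randn(1024, len(num_col_names)),  # float features
    "categorical": torch.randint(0, 4, (1024, len(cat_col_names))),  # integer category codes (cardinality here is a placeholder)
}
# If you bypass the standard forward flow, unpack the dictionary yourself:
continuous, categorical = dummy_batch["continuous"], dummy_batch["categorical"]
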
# Define a Config class to hold the hyperparameters of the model
# We need to inherit ModelConfig so that default parameters like learning rate, head, head_config, etc.
# are present in the config
@dataclass
class MyAwesomeModelConfig(ModelConfig):
    use_batch_norm: bool = True


# Define the core model as a pure PyTorch model
# The forward method takes in a tensor and outputs a tensor


class MyAwesomeRegressionModel(nn.Module):
    def __init__(self, hparams: DictConfig, **kwargs):
        super().__init__()
        # Save the config so it can be accessed elsewhere in the model
        self.hparams = hparams
        self._build_network()

    def _build_network(self):
        # Continuous and Categorical Dimensions are precalculated and stored in the config (InferredConfig)
        inp_dim = self.hparams.embedded_cat_dim + self.hparams.continuous_dim
        self.linear_layer_1 = nn.Linear(inp_dim, 200)
        self.linear_layer_2 = nn.Linear(inp_dim + 200, 70)
        self.linear_layer_3 = nn.Linear(inp_dim + 70, 70)
        self.input_batch_norm = nn.BatchNorm1d(self.hparams.continuous_dim)
        if self.hparams.use_batch_norm:
            self.batch_norm_2 = nn.BatchNorm1d(inp_dim + 200)
            self.batch_norm_3 = nn.BatchNorm1d(inp_dim + 70)
        self.dropout = nn.Dropout(0.3)
        self.output_dim = 70

    def forward(self, x: torch.Tensor):
        inp = x
        x = F.relu(self.linear_layer_1(x))
        x = self.dropout(x)
        x = torch.cat([x, inp], 1)
        if self.hparams.use_batch_norm:
            x = self.batch_norm_2(x)
        x = F.relu(self.linear_layer_2(x))
        x = self.dropout(x)
        x = torch.cat([x, inp], 1)
        if self.hparams.use_batch_norm:
            x = self.batch_norm_3(x)
        x = self.linear_layer_3(x)
        return x

    # It is also good practice to bundle the embedding scheme along with the core model.
    # This way the model encapsulates its own requirement of how to embed the input
    # dictionary of categorical and continuous tensors and how to combine them into a single tensor.
    # We can either use one of the predefined embedding layers in PyTorch Tabular
    # or define a custom embedding layer. Check the Embedding1dLayer implementation
    # for an example of how to implement one.
    def _build_embedding_layer(self):
        return Embedding1dLayer(
            continuous_dim=self.hparams.continuous_dim,
            categorical_embedding_dims=self.hparams.embedding_dims,
            embedding_dropout=self.hparams.embedding_dropout,
            batch_norm_continuous_input=self.hparams.batch_norm_continuous_input,
        )


# Define the PyTorch Tabular model by inheriting BaseModel


class MyAwesomeRegressionPTModel(BaseModel):
    def __init__(self, config: DictConfig, **kwargs):
        # Save any attribute that you need in _build_network before calling super()
        # After the super() call, the config will be saved as `hparams`
        super().__init__(config, **kwargs)

    def _build_network(self):
        # Backbone - Initialize the PyTorch model we defined earlier
        self._backbone = MyAwesomeRegressionModel(self.hparams)
        # Initializing Embedding Layer using the method we defined in the backbone
        self._embedding_layer = self._backbone._build_embedding_layer()
        # Head - we can use the helper method to initialize a standard linear head
        # (saved as _head so the head property defined below can find it)
        self._head = self._get_head_from_config()

    # Now define the property methods which the BaseModel requires you to override
    @property
    def backbone(self):
        return self._backbone

    @property
    def embedding_layer(self):
        return self._embedding_layer

    @property
    def head(self):
        return self._head

    # For more customizations, we can override forward, compute_backbone, compute_head etc.

Define the Configs

There is one deviation from the norm when we create the TabularModel object with the configs. Earlier, the model was inferred from the config and initialized automatically. Here, we have to use the model_callable parameter of TabularModel and pass in the model class (not an initialized object).

data_config = DataConfig(
    target=[
        "target"
    ],  # target should always be a list. Multi-targets are only supported for regression. Multi-Task Classification is not implemented
    continuous_cols=num_col_names,
    categorical_cols=cat_col_names,
)
trainer_config = TrainerConfig(
    auto_lr_find=True,  # Runs the LRFinder to automatically derive a learning rate
    batch_size=1024,
    max_epochs=100,
    accelerator="auto",  # can be 'cpu','gpu', 'tpu', or 'ipu'
    devices=-1,  # -1 means use all available
)
optimizer_config = OptimizerConfig()

head_config = LinearHeadConfig(
    layers="32",  # Number of nodes in each layer
    activation="ReLU",  # Activation between each layers
    dropout=0.1,
    initialization="kaiming",
).__dict__  # Convert to dict to pass to the model config (OmegaConf doesn't accept objects)

model_config = MyAwesomeModelConfig(
    task="regression",
    use_batch_norm=False,
    head="LinearHead",  # Linear Head
    head_config=head_config,  # Linear Head Config
    learning_rate=1e-3,
)

tabular_model = TabularModel(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config,
    model_callable=MyAwesomeRegressionPTModel,  # When using custom model, we need to use the model_callable parameter
)
2023-12-26 16:43:07,659 - {pytorch_tabular.tabular_model:134} - INFO - Experiment Tracking is turned off           

Training the Model

The rest of the process is business-as-usual

tabular_model.fit(train=train, validation=val)
Seed set to 42

2023-12-26 16:43:10,447 - {pytorch_tabular.tabular_model:506} - INFO - Preparing the DataLoaders                   
2023-12-26 16:43:10,455 - {pytorch_tabular.tabular_datamodule:431} - INFO - Setting up the datamodule for          
regression task                                                                                                    
2023-12-26 16:43:10,490 - {pytorch_tabular.tabular_model:556} - INFO - Preparing the Model: Model                  
2023-12-26 16:43:10,541 - {pytorch_tabular.tabular_model:322} - INFO - Preparing the Trainer                       
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

2023-12-26 16:43:10,722 - {pytorch_tabular.tabular_model:612} - INFO - Auto LR Find Started                        
You are using a CUDA device ('NVIDIA GeForce RTX 3060 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:639: Checkpoint directory saved_models exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/loops/fit_loop.py:293: The number of training batches (6) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.

Finding best initial lr:   0%|          | 0/100 [00:00<?, ?it/s]
LR finder stopped early after 99 steps due to diverging loss.
Learning rate set to 0.008317637711026709
Restoring states from the checkpoint path at /home/manujosephv/pytorch_tabular/docs/tutorials/.lr_find_c0a3e124-98c0-4ff0-80c4-0287c2f67b80.ckpt
Restored all states from the checkpoint at /home/manujosephv/pytorch_tabular/docs/tutorials/.lr_find_c0a3e124-98c0-4ff0-80c4-0287c2f67b80.ckpt

2023-12-26 16:43:14,083 - {pytorch_tabular.tabular_model:625} - INFO - Suggested LR: 0.008317637711026709. For plot
and detailed analysis, use `find_learning_rate` method.                                                            
2023-12-26 16:43:14,085 - {pytorch_tabular.tabular_model:634} - INFO - Training Started                            
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓
┃   ┃ Name             ┃ Type                     ┃ Params ┃
┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩
│ 0 │ _backbone        │ MyAwesomeRegressionModel │ 28.8 K │
│ 1 │ _embedding_layer │ Embedding1dLayer         │     92 │
│ 2 │ head             │ LinearHead               │  2.3 K │
│ 3 │ loss             │ MSELoss                  │      0 │
└───┴──────────────────┴──────────────────────────┴────────┘
Trainable params: 31.2 K                                                                                           
Non-trainable params: 0                                                                                            
Total params: 31.2 K                                                                                               
Total estimated model params size (MB): 0                                                                          


2023-12-26 16:43:16,563 - {pytorch_tabular.tabular_model:645} - INFO - Training the model completed                
2023-12-26 16:43:16,564 - {pytorch_tabular.tabular_model:1491} - INFO - Loading the best model                     
<pytorch_lightning.trainer.trainer.Trainer at 0x7f08ebab7e90>
result = tabular_model.evaluate(test)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric        ┃       DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│         test_loss         │    509.21295166015625     │
│  test_mean_squared_error  │    509.21295166015625     │
└───────────────────────────┴───────────────────────────┘
/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.



pred_df = tabular_model.predict(test)
pred_df.head()
      target_prediction
6252        -111.886528
4684        -169.181137
1731         272.100281
4742          12.167077
4521          56.998451
print_metrics(
    [(mean_squared_error, "MSE", {}), (mean_absolute_error, "MAE", {})],
    test["target"],
    pred_df["target_prediction"],
    tag="Holdout",
)
Holdout MSE: 509.21296869752325 | Holdout MAE: 17.057807657741225