import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

from pytorch_tabular.utils import make_mixed_dataset
%load_ext autoreload
%autoreload 2
data, cat_col_names, num_col_names = make_mixed_dataset(
    task="classification",
    n_samples=10000,
    n_features=8,
    n_categories=4,
    weights=[0.8],  # ~80% of samples in class_0, making the target imbalanced
    random_state=42,
)
train, test = train_test_split(data, random_state=42)
train, val = train_test_split(train, random_state=42)
data.target.value_counts(normalize=True)
target
class_0    0.7968
class_1    0.2032
Name: proportion, dtype: float64
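
The target is imbalanced at roughly 80/20. Plain random splits can let the class ratio drift slightly between partitions; if you want each partition to preserve the ratio, you can replace the splits above with stratified ones (the stratify argument is standard scikit-learn, not something this tutorial requires):

# Optional variant: stratified splits that preserve the ~80/20 class ratio
train, test = train_test_split(data, random_state=42, stratify=data["target"])
train, val = train_test_split(train, random_state=42, stratify=train["target"])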

Importing the Library

from pytorch_tabular import TabularModel
from pytorch_tabular.models import CategoryEmbeddingModelConfig
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig, ExperimentConfig
from pytorch_tabular.models.common.heads import LinearHeadConfig
results = []

Define the Configs

data_config = DataConfig(
    target=["target"],  # target should always be a list; multi-target is only supported for regression (multi-task classification is not implemented)
    continuous_cols=num_col_names,
    categorical_cols=cat_col_names,
)
trainer_config = TrainerConfig(
    auto_lr_find=True,  # Runs the LR Finder to automatically derive a learning rate
    batch_size=1024,
    max_epochs=100,
    early_stopping="valid_loss",  # Monitor valid_loss for early stopping
    early_stopping_mode="min",  # Set the mode to min because lower valid_loss is better
    early_stopping_patience=5,  # Number of epochs without improvement to wait before terminating training
    checkpoints="valid_loss",  # Save the best checkpoint, monitored on valid_loss
    load_best=True,  # After training, load the best checkpoint
    # accelerator="cpu"
)
optimizer_config = OptimizerConfig()

head_config = LinearHeadConfig(
    layers="", # No additional layer in head, just a mapping layer to output_dim
    dropout=0.1,
    initialization="kaiming"
).__dict__ # Convert to dict to pass to the model config (OmegaConf doesn't accept objects)

model_config = CategoryEmbeddingModelConfig(
    task="classification",
    layers="1024-512-512",  # Number of nodes in each layer
    activation="LeakyReLU",  # Activation between the layers
    head="LinearHead",  # Linear head
    head_config=head_config,  # Linear head config
    learning_rate=1e-3,
    metrics=["f1_score", "accuracy"],
    metrics_params=[{"num_classes": 2}, {}],  # f1_score needs num_classes
    metrics_prob_input=[True, False],  # f1_score gets probability scores, while accuracy gets class labels
)
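
For reference, these three metric settings roughly correspond to direct torchmetrics calls like the sketch below. This is only an illustration of how metrics, metrics_params, and metrics_prob_input fit together, not the library's exact internals:

import torch
from torchmetrics.functional import accuracy, f1_score

probs = torch.tensor([[0.9, 0.1], [0.3, 0.7]])  # metrics_prob_input=True -> metric sees probabilities
preds = probs.argmax(dim=-1)                    # metrics_prob_input=False -> metric sees class labels
target = torch.tensor([0, 1])

f1 = f1_score(probs, target, task="multiclass", num_classes=2)   # num_classes comes from metrics_params
acc = accuracy(preds, target, task="multiclass", num_classes=2)  # task/num_classes here are just what the functional API needs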

Training the Model

tabular_model = TabularModel(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config,
)
2023-12-27 15:10:25,372 - {pytorch_tabular.tabular_model:134} - INFO - Experiment Tracking is turned off           
tabular_model.fit(train=train, validation=val)
Seed set to 42

2023-12-27 15:10:26,887 - {pytorch_tabular.tabular_model:506} - INFO - Preparing the DataLoaders                   
2023-12-27 15:10:26,892 - {pytorch_tabular.tabular_datamodule:431} - INFO - Setting up the datamodule for          
classification task                                                                                                
2023-12-27 15:10:26,918 - {pytorch_tabular.tabular_model:556} - INFO - Preparing the Model: CategoryEmbeddingModel 
2023-12-27 15:10:26,951 - {pytorch_tabular.tabular_model:322} - INFO - Preparing the Trainer                       
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

2023-12-27 15:10:27,101 - {pytorch_tabular.tabular_model:612} - INFO - Auto LR Find Started                        
You are using a CUDA device ('NVIDIA GeForce RTX 3060 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:639: Checkpoint directory saved_models exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/loops/fit_loop.py:293: The number of training batches (6) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.

Finding best initial lr:   0%|          | 0/100 [00:00<?, ?it/s]
LR finder stopped early after 93 steps due to diverging loss.
Learning rate set to 0.0013182567385564075
Restoring states from the checkpoint path at /home/manujosephv/pytorch_tabular/docs/tutorials/.lr_find_9ed58e8d-3195-4f96-ac24-2a0f0e44afcc.ckpt
Restored all states from the checkpoint at /home/manujosephv/pytorch_tabular/docs/tutorials/.lr_find_9ed58e8d-3195-4f96-ac24-2a0f0e44afcc.ckpt

2023-12-27 15:10:29,919 - {pytorch_tabular.tabular_model:625} - INFO - Suggested LR: 0.0013182567385564075. For    
plot and detailed analysis, use `find_learning_rate` method.                                                       
2023-12-27 15:10:29,925 - {pytorch_tabular.tabular_model:634} - INFO - Training Started                            
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓
┃   ┃ Name             ┃ Type                      ┃ Params ┃
┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩
│ 0 │ _backbone        │ CategoryEmbeddingBackbone │  802 K │
│ 1 │ _embedding_layer │ Embedding1dLayer          │     55 │
│ 2 │ head             │ LinearHead                │  1.0 K │
│ 3 │ loss             │ CrossEntropyLoss          │      0 │
└───┴──────────────────┴───────────────────────────┴────────┘
Trainable params: 803 K                                                                                            
Non-trainable params: 0                                                                                            
Total params: 803 K                                                                                                
Total estimated model params size (MB): 3                                                                          


2023-12-27 15:10:33,645 - {pytorch_tabular.tabular_model:645} - INFO - Training the model completed                
2023-12-27 15:10:33,646 - {pytorch_tabular.tabular_model:1491} - INFO - Loading the best model                     
<pytorch_lightning.trainer.trainer.Trainer at 0x7fa60824df10>
result = tabular_model.evaluate(test)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric        ┃       DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       test_accuracy       │    0.9332000017166138     │
│       test_f1_score       │    0.9332000017166138     │
│         test_loss         │    0.1947803646326065     │
└───────────────────────────┴───────────────────────────┘
/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.



result = {k: float(v) for k, v in result[0].items()}
result["mode"] = "Normal"

results.append(result)

Custom Sampler

PyTorch Tabular also supports custom batching strategies through custom samplers, which come in handy when working with imbalanced data.

Although you can use any sampler, PyTorch Tabular has a handy utility function that takes in the target array and builds a WeightedRandomSampler using inverse frequency sampling to combat the imbalance. This is analogous to preprocessing techniques like under- or over-sampling in traditional ML systems. The sketch below illustrates the idea before we use the built-in utility.
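
A minimal sketch of inverse frequency sampling with PyTorch's built-in WeightedRandomSampler; the exact internals of get_balanced_sampler may differ:

# Weight each sample by the inverse frequency of its class, so rare
# classes are drawn more often when sampling batches with replacement.
import numpy as np
import torch
from torch.utils.data import WeightedRandomSampler

y = train["target"].values.ravel()
classes, counts = np.unique(y, return_counts=True)
class_weight = {c: 1.0 / n for c, n in zip(classes, counts)}  # rarer class -> higher weight
sample_weights = torch.as_tensor([class_weight[label] for label in y], dtype=torch.double)
sketch_sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)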

from pytorch_tabular.utils import get_balanced_sampler, get_class_weighted_cross_entropy
tabular_model = TabularModel(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config,
    verbose=False
)
sampler = get_balanced_sampler(train['target'].values.ravel())

tabular_model.fit(train=train, validation=val, train_sampler=sampler)
Seed set to 42


Finding best initial lr:   0%|          | 0/100 [00:00<?, ?it/s]
LR finder stopped early after 88 steps due to diverging loss.
Learning rate set to 0.00017378008287493763
Restoring states from the checkpoint path at /home/manujosephv/pytorch_tabular/docs/tutorials/.lr_find_423bc877-8996-4138-8885-7dad6f5f2f36.ckpt
Restored all states from the checkpoint at /home/manujosephv/pytorch_tabular/docs/tutorials/.lr_find_423bc877-8996-4138-8885-7dad6f5f2f36.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓
┃   ┃ Name             ┃ Type                      ┃ Params ┃
┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩
│ 0 │ _backbone        │ CategoryEmbeddingBackbone │  802 K │
│ 1 │ _embedding_layer │ Embedding1dLayer          │     55 │
│ 2 │ head             │ LinearHead                │  1.0 K │
│ 3 │ loss             │ CrossEntropyLoss          │      0 │
└───┴──────────────────┴───────────────────────────┴────────┘
Trainable params: 803 K                                                                                            
Non-trainable params: 0                                                                                            
Total params: 803 K                                                                                                
Total estimated model params size (MB): 3                                                                          


<pytorch_lightning.trainer.trainer.Trainer at 0x7fa60066a8d0>
result = tabular_model.evaluate(test)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric        ┃       DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       test_accuracy       │     0.921999990940094     │
│       test_f1_score       │     0.921999990940094     │
│         test_loss         │    0.2218063920736313     │
└───────────────────────────┴───────────────────────────┘



result = {k: float(v) for k, v in result[0].items()}
result["mode"] = "Balanced Sampler"

results.append(result)

Custom Weighted Loss

If samplers are the analogue of over/under-sampling, a custom weighted loss is the analogue of class_weights. Depending on the problem, one of the two may help with the imbalance. You can easily calculate the class weights and provide them to CrossEntropyLoss through its weight parameter. To make this easier, PyTorch Tabular has a handy utility method that calculates smoothed class weights and initializes a weighted loss. Once you have that loss, it's just a matter of passing it to the fit method using the loss parameter. The sketch below illustrates the idea before we use the utility.
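
A minimal sketch of the idea: smoothed inverse-frequency class weights passed to CrossEntropyLoss via its weight argument. The log-smoothing with mu shown here is an illustrative assumption; the exact formula inside get_class_weighted_cross_entropy may differ:

import numpy as np
import torch
import torch.nn as nn

y = train["target"].values.ravel()
classes, counts = np.unique(y, return_counts=True)
mu = 0.1
# Illustrative smoothing: rare classes get weights above 1, common classes stay at 1
weights = np.maximum(np.log(mu * counts.sum() / counts), 1.0)
sketch_loss = nn.CrossEntropyLoss(weight=torch.tensor(weights, dtype=torch.float))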

tabular_model = TabularModel(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config,
    verbose=False
)
weighted_loss = get_class_weighted_cross_entropy(train["target"].values.ravel(), mu=0.1)

tabular_model.fit(train=train, validation=val, loss=weighted_loss)
Seed set to 42

Finding best initial lr:   0%|          | 0/100 [00:00<?, ?it/s]
LR finder stopped early after 93 steps due to diverging loss.
Learning rate set to 0.0013182567385564075
Restoring states from the checkpoint path at /home/manujosephv/pytorch_tabular/docs/tutorials/.lr_find_c0a72040-20ef-497a-a00b-650c9f9d4dd2.ckpt
Restored all states from the checkpoint at /home/manujosephv/pytorch_tabular/docs/tutorials/.lr_find_c0a72040-20ef-497a-a00b-650c9f9d4dd2.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓
┃   ┃ Name             ┃ Type                      ┃ Params ┃
┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩
│ 0 │ custom_loss      │ CrossEntropyLoss          │      0 │
│ 1 │ _backbone        │ CategoryEmbeddingBackbone │  802 K │
│ 2 │ _embedding_layer │ Embedding1dLayer          │     55 │
│ 3 │ head             │ LinearHead                │  1.0 K │
└───┴──────────────────┴───────────────────────────┴────────┘
Trainable params: 803 K                                                                                            
Non-trainable params: 0                                                                                            
Total params: 803 K                                                                                                
Total estimated model params size (MB): 3                                                                          


<pytorch_lightning.trainer.trainer.Trainer at 0x7fa60002ef90>
result = tabular_model.evaluate(test)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric        ┃       DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       test_accuracy       │    0.9332000017166138     │
│       test_f1_score       │    0.9332000017166138     │
│         test_loss         │    0.1947803646326065     │
└───────────────────────────┴───────────────────────────┘



result = {k: float(v) for k, v in result[0].items()}
result["mode"] = "Class Weights"

results.append(result)
# Pivot the collected results into a metrics-by-mode comparison table
res_df = pd.DataFrame(results).T
res_df.columns = res_df.iloc[-1]  # use the "mode" row as the column headers
res_df = res_df.iloc[:-1].astype(float)
res_df.style.highlight_min(color="lightgreen", axis=1)
mode           Normal    Balanced Sampler  Class Weights
test_loss      0.194780  0.221806          0.194780
test_f1_score  0.933200  0.922000          0.933200
test_accuracy  0.933200  0.922000          0.933200