import warnings

import numpy as np
from sklearn.metrics import accuracy_score, roc_auc_score
from sklearn.model_selection import train_test_split

from pytorch_tabular.utils import make_mixed_dataset
data, cat_col_names, num_col_names = make_mixed_dataset(
    task="classification", n_samples=10000, n_features=20, n_categories=4
)

Importing the Library

from pytorch_tabular import TabularModel
from pytorch_tabular.models import (
    CategoryEmbeddingModelConfig,
)
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
from pytorch_tabular.models.common.heads import LinearHeadConfig
train, test = train_test_split(data, random_state=42)

Cross Validation

data_config = DataConfig(
    target=[
        "target"
    ],  # target should always be a list. Multi-targets are only supported for regression. Multi-Task Classification is not implemented
    continuous_cols=num_col_names,
    categorical_cols=cat_col_names,
)
trainer_config = TrainerConfig(
    batch_size=1024,
    max_epochs=100,
    early_stopping="valid_loss",  # Monitor valid_loss for early stopping
    early_stopping_mode="min",  # Set the mode to min because for valid_loss, lower is better
    early_stopping_patience=5,  # Number of epochs without improvement to wait before stopping training
    checkpoints="valid_loss",  # Save the best checkpoint, monitoring valid_loss
    load_best=True,  # After training, load the best checkpoint
    progress_bar="none",  # Turning off Progress bar
    trainer_kwargs=dict(enable_model_summary=False),  # Turning off model summary
)
optimizer_config = OptimizerConfig()

head_config = LinearHeadConfig(
    layers="",  # No additional layer in head, just a mapping layer to output_dim
    dropout=0.1,
    initialization="kaiming",
).__dict__  # Convert to dict to pass to the model config (OmegaConf doesn't accept objects)

model_config = CategoryEmbeddingModelConfig(
    task="classification",
    layers="1024-512-512",  # Number of nodes in each layer
    activation="LeakyReLU",  # Activation between layers
    learning_rate=1e-3,
    head="LinearHead",  # Linear Head
    head_config=head_config,  # Linear Head Config
)

tabular_model = TabularModel(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config,
    verbose=False,
)

Using High-Level API

We can use the high-level method cross_validate of TabularModel.

The arguments are as follows:
- cv can either be an integer or a KFold object. If it is an integer, it is treated as the number of folds in a KFold; for classification problems, a StratifiedKFold is used instead. If it is a KFold object, it is used as is.
- metric is the metric to be used for evaluation. It can either be a string (the name of the metric) or a callable. A callable should take two arguments, the targets and the predictions, as in _accuracy(y_true, y_pred) in the example below. The predictions are the dataframe output from model.predict, and the targets can be a series or an array.
- train is the training dataset.
- return_oof is a boolean. If set to True, the out-of-fold predictions for the training dataset are returned as well. This is useful for stacking models.
- reset_datamodule is a boolean. If set to True, the datamodule is reset after each fold, which is the right way of doing cross validation. If set to False, the datamodule is not reset; this is faster but introduces a small amount of data leakage. It is useful when working with huge datasets where saving time matters.
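To see where the leakage with reset_datamodule=False comes from, the sketch below uses scikit-learn's StandardScaler as a stand-in for the datamodule's transformers. This is an assumption made for illustration and not PyTorch Tabular's internal code. The point is that transformers fitted on the first fold's training rows have already seen the rows that later folds use for validation.

```python
import numpy as np
from sklearn.model_selection import KFold
from sklearn.preprocessing import StandardScaler

# Toy feature matrix standing in for the continuous columns (hypothetical data)
rng = np.random.default_rng(42)
X = rng.normal(loc=5.0, scale=2.0, size=(100, 3))

kf = KFold(n_splits=5, shuffle=True, random_state=42)
splits = list(kf.split(X))

# Fit once on the first fold's training rows, mimicking reset_datamodule=False
first_train_idx, _ = splits[0]
shared_scaler = StandardScaler().fit(X[first_train_idx])

leaked_rows = []
for fold, (train_idx, val_idx) in enumerate(splits):
    # Validation rows of this fold that the shared scaler was fitted on
    seen = np.intersect1d(val_idx, first_train_idx)
    leaked_rows.append(len(seen))
    # Refit per fold, mimicking reset_datamodule=True: no validation row is seen
    fresh_scaler = StandardScaler().fit(X[train_idx])
    drift = np.abs(fresh_scaler.mean_ - shared_scaler.mean_).max()
    print(
        f"Fold {fold}: {len(seen)}/{len(val_idx)} validation rows seen by the"
        f" shared scaler | max mean drift: {drift:.4f}"
    )
```

Only summary statistics such as means and scales leak in this sketch, which is consistent with the leakage being described as small.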

# cross validation loop using sklearn
from sklearn.model_selection import KFold
from sklearn.metrics import accuracy_score, f1_score

kf = KFold(n_splits=5, shuffle=True, random_state=42)


def _accuracy(y_true, y_pred):
    return accuracy_score(y_true, y_pred["prediction"].values)


with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    cv_scores, oof_predictions = tabular_model.cross_validate(
        cv=2, train=train, metric=_accuracy, return_oof=True, reset_datamodule=False
    )
2023-12-31 13:08:18,468 - {pytorch_tabular.tabular_model:1925} - INFO - Running Fold 1/2                           
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3060 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

2023-12-31 13:08:22,376 - {pytorch_tabular.tabular_model:1952} - INFO - Fold 1/2 score: 0.908                      
2023-12-31 13:08:22,383 - {pytorch_tabular.tabular_model:1925} - INFO - Running Fold 2/2                           
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

2023-12-31 13:08:24,704 - {pytorch_tabular.tabular_model:1952} - INFO - Fold 2/2 score: 0.9517333333333333         
print(f"KFold Mean: {np.mean(cv_scores)} | KFold SD: {np.std(cv_scores)}")
KFold Mean: 0.9298666666666666 | KFold SD: 0.021866666666666645
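Note that np.std defaults to ddof=0, i.e. the population standard deviation. With only two folds, the sample standard deviation (ddof=1) is noticeably larger. The short sketch below compares the two, using the fold scores copied from the log output above; which one to report is a judgment call.

```python
import numpy as np

# Fold scores copied from the log output above
cv_scores = [0.908, 0.9517333333333333]

# ddof=0 divides by n_folds; this is what the print statement above reports
population_sd = np.std(cv_scores)
# ddof=1 divides by (n_folds - 1), the usual estimate from a small sample
sample_sd = np.std(cv_scores, ddof=1)

print(f"Population SD: {population_sd:.4f} | Sample SD: {sample_sd:.4f}")
```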

Using Low-Level API

Sometimes, we might want to do something more than a plain, vanilla cross validation. For example, we might want to cross-validate with multiple metrics, or with a custom metric that relies on something other than the target and predictions. In such cases, we can use the low-level API.
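One way to organize several metrics is a dictionary of callables that all share the metric(y_true, y_pred) shape, where y_pred is the prediction dataframe. The sketch below is only illustrative: the evaluate helper and the hand-made pred_df are hypothetical, and the column names prediction and class_1_probability are the ones used by the metric functions in this tutorial.

```python
import pandas as pd
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score

# Each metric takes the true labels and the prediction DataFrame
METRICS = {
    "accuracy": lambda y_true, y_pred: accuracy_score(y_true, y_pred["prediction"]),
    "f1": lambda y_true, y_pred: f1_score(y_true, y_pred["prediction"]),
    "roc_auc": lambda y_true, y_pred: roc_auc_score(
        y_true, y_pred["class_1_probability"]
    ),
}


def evaluate(y_true, pred_df, metrics=METRICS):
    """Apply every metric to one fold's predictions; returns a name -> score dict."""
    return {name: fn(y_true, pred_df) for name, fn in metrics.items()}


# Hand-made stand-in for the output of tabular_model.predict (hypothetical values)
pred_df = pd.DataFrame(
    {
        "class_1_probability": [0.1, 0.8, 0.4, 0.9],
        "prediction": [0, 1, 0, 1],
    }
)
y_true = pd.Series([0, 1, 1, 1])

scores = evaluate(y_true, pred_df)
print(scores)
```

Inside a fold loop, the returned dictionary can be appended to a list and turned into a DataFrame at the end to get per-metric means and standard deviations.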

from rich import print
def _accuracy(y_true, y_pred):
    return accuracy_score(y_true, y_pred["prediction"].values)


def _roc_auc_score(y_true, y_pred):
    return roc_auc_score(y_true, y_pred["class_1_probability"].values)


kf = KFold(n_splits=5, shuffle=True, random_state=42)
# Initialize the tabular model once
tabular_model = TabularModel(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config,
    verbose=False,
)
acc_metrics = []
roc_metrics = []
preds = []
datamodule = None
model = None
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    for fold, (train_idx, val_idx) in enumerate(kf.split(train)):
        train_fold = train.iloc[train_idx]
        val_fold = train.iloc[val_idx]
        if datamodule is None:
            # Initialize datamodule and model in the first fold
            # uses train data from this fold to fit all transformers
            datamodule = tabular_model.prepare_dataloader(
                train=train_fold, validation=val_fold, seed=42
            )
            model = tabular_model.prepare_model(datamodule)
        else:
            # Creates a copy of the datamodule with same transformers but different train and validation data
            datamodule = datamodule.copy(train=train_fold, validation=val_fold)
        # Train the model
        tabular_model.train(model, datamodule)
        pred_df = tabular_model.predict(val_fold)
        acc_metrics.append(_accuracy(val_fold["target"], pred_df))
        roc_metrics.append(_roc_auc_score(val_fold["target"], pred_df))
        print(
            f"[bold red]Fold:[/bold red] {fold} | [bold green]Accuracy:[/bold green]"
            f" {acc_metrics[-1]} | [bold green]AUC:[/bold green] {roc_metrics[-1]}"
        )
        # Reset the trained weights before next fold
        tabular_model.model.reset_weights()
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

Fold: 0 | Accuracy: 0.9293333333333333 | AUC: 0.9807391279599271
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

Fold: 1 | Accuracy: 0.9146666666666666 | AUC: 0.9736274684219891
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

Fold: 2 | Accuracy: 0.924 | AUC: 0.9730588808512757
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

Fold: 3 | Accuracy: 0.922 | AUC: 0.9757440627005844
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

Fold: 4 | Accuracy: 0.9166666666666666 | AUC: 0.9743804540010267
print(
    f"KFold Accuracy Mean: {np.mean(acc_metrics)} | KFold Accuracy SD:"
    f" {np.std(acc_metrics)}"
)
print(f"KFold AUC Mean: {np.mean(roc_metrics)} | KFold AUC SD: {np.std(roc_metrics)}")
KFold Accuracy Mean: 0.9213333333333333 | KFold Accuracy SD: 0.005249338582674566
KFold AUC Mean: 0.9755099987869607 | KFold AUC SD: 0.002765008099674828
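The low-level loop above uses a plain KFold, whereas the high-level API is described as using a StratifiedKFold for classification when cv is an integer. A stratified splitter can be dropped into the same loop. The sketch below shows only the splitting step on a toy, imbalanced dataframe (hypothetical data), assuming the target column is named target as in this tutorial.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedKFold

# Toy imbalanced frame standing in for `train` (hypothetical data): 20% positives
toy_train = pd.DataFrame(
    {
        "feature": np.arange(100, dtype=float),
        "target": [1] * 20 + [0] * 80,
    }
)

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
positive_rates = []
# Unlike KFold, StratifiedKFold.split also needs the labels to stratify on
for fold, (train_idx, val_idx) in enumerate(
    skf.split(toy_train, toy_train["target"])
):
    val_fold = toy_train.iloc[val_idx]
    positive_rates.append(val_fold["target"].mean())
    print(
        f"Fold {fold}: validation size {len(val_fold)} | positive rate"
        f" {positive_rates[-1]:.2f}"
    )
```

Each validation fold keeps the overall class ratio, which tends to make per-fold metrics such as AUC more comparable when classes are imbalanced.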