from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score
import random
import numpy as np
import pandas as pd
import os
import wandb
import logging

# %load_ext autoreload
# %autoreload 2

Utility Functions

def make_mixed_classification(n_samples, n_features, n_categories):
    X, y = make_classification(n_samples=n_samples, n_features=n_features, random_state=42, n_informative=5)
    # Sample without replacement so we get exactly `n_categories` distinct categorical columns
    cat_cols = random.sample(list(range(X.shape[-1])), k=n_categories)
    num_cols = [i for i in range(X.shape[-1]) if i not in cat_cols]
    for col in cat_cols:
        # Bin the chosen columns into quartiles to turn them into categorical features
        X[:, col] = pd.qcut(X[:, col], q=4).codes.astype(int)
    col_names = []
    num_col_names = []
    cat_col_names = []
    for i in range(X.shape[-1]):
        if i in cat_cols:
            col_names.append(f"cat_col_{i}")
            cat_col_names.append(f"cat_col_{i}")
        if i in num_cols:
            col_names.append(f"num_col_{i}")
            num_col_names.append(f"num_col_{i}")
    X = pd.DataFrame(X, columns=col_names)
    y = pd.Series(y, name="target")
    data = X.join(y)
    return data, cat_col_names, num_col_names


def print_metrics(y_true, y_pred, tag):
    if isinstance(y_true, pd.DataFrame) or isinstance(y_true, pd.Series):
        y_true = y_true.values
    if isinstance(y_pred, pd.DataFrame) or isinstance(y_pred, pd.Series):
        y_pred = y_pred.values
    if y_true.ndim > 1:
        y_true = y_true.ravel()
    if y_pred.ndim > 1:
        y_pred = y_pred.ravel()
    val_acc = accuracy_score(y_true, y_pred)
    val_f1 = f1_score(y_true, y_pred)
    print(f"{tag} Acc: {val_acc} | {tag} F1: {val_f1}")

Generate Synthetic Data

First, let's create a synthetic dataset that is a mix of numerical and categorical features.

data, cat_col_names, num_col_names = make_mixed_classification(n_samples=10000, n_features=20, n_categories=4)
train, test = train_test_split(data, random_state=42)
train, val = train_test_split(train, random_state=42)
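
With the default test_size of 0.25, the two splits above leave roughly 5,625 rows for training, 1,875 for validation, and 2,500 for testing; a quick check:

print(f"Train: {train.shape} | Val: {val.shape} | Test: {test.shape}")
# Train: (5625, 21) | Val: (1875, 21) | Test: (2500, 21)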

Importing the Library

from pytorch_tabular import TabularModel
from pytorch_tabular.models import (
    CategoryEmbeddingModelConfig,
    NodeConfig,
    TabNetModelConfig,
    GatedAdditiveTreeEnsembleConfig,
)
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig, ExperimentConfig
from pytorch_tabular.models.common.heads import LinearHeadConfig

Test-Time Augmentation

> Test time augmentation (TTA) is a popular technique in computer vision. TTA aims at boosting the model accuracy by using data augmentation on the inference stage. The idea behind TTA is simple: for each test image, we create multiple versions that are a little different from the original (e.g., cropped or flipped). Next, we predict labels for the test images and created copies and average model predictions over multiple versions of each image. This usually helps to improve the accuracy irrespective of the underlying model. - from Test-Time Augmentation for Tabular Data

Using the low-level API, we can also implement test-time augmentation for tabular data.
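
Conceptually, the recipe boils down to averaging predicted probabilities over the original test set and a few noisy copies of it. A minimal, generic sketch of that idea (predict_fn and augment_fn are hypothetical placeholders; the cells below implement the real thing with PyTorch Tabular):

def tta_average(predict_fn, augment_fn, df, num_tta=4):
    # predict_fn: returns class probabilities (array/frame) for a dataframe
    # augment_fn: returns a lightly perturbed copy of a dataframe
    probs = [predict_fn(df)]  # predictions on the clean data
    probs += [predict_fn(augment_fn(df.copy())) for _ in range(num_tta)]
    return sum(probs) / len(probs)  # element-wise mean across all versions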

results = []
data_config = DataConfig(
    target=[
        "target"
    ],  # target should always be a list. Multi-targets are only supported for regression. Multi-Task Classification is not implemented
    continuous_cols=num_col_names,
    categorical_cols=cat_col_names,
)
trainer_config = TrainerConfig(
    batch_size=1024,
    max_epochs=100,
    early_stopping="valid_loss",  # Monitor valid_loss for early stopping
    early_stopping_mode="min",  # Set the mode as min because for val_loss, lower is better
    early_stopping_patience=5,  # No. of epochs of degradation training will wait before terminating
    checkpoints="valid_loss",  # Save best checkpoint monitoring val_loss
    load_best=True,  # After training, load the best checkpoint
    #     progress_bar="none", # Turning off Progress bar
    #     trainer_kwargs=dict(
    #         enable_model_summary=False # Turning off model summary
    #     )
)
optimizer_config = OptimizerConfig()

head_config = LinearHeadConfig(
    layers="", dropout=0.1, initialization="kaiming"  # No additional layer in head, just a mapping layer to output_dim
).__dict__  # Convert to dict to pass to the model config (OmegaConf doesn't accept objects)

model_config = CategoryEmbeddingModelConfig(
    task="classification",
    layers="1024-512-512",  # Number of nodes in each layer
    activation="LeakyReLU",  # Activation between each layers
    learning_rate=1e-3,
    head="LinearHead",  # Linear Head
    head_config=head_config,  # Linear Head Config
)
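
Only CategoryEmbeddingModelConfig is used in this tutorial, but since the rest of the pipeline depends solely on model_config, any of the other imported configs can be dropped in the same way. For example (an illustrative sketch relying on TabNet's default hyperparameters; this config is not used further here):

# Hypothetical alternative: swap the MLP backbone for TabNet while keeping the
# same data, trainer, and optimizer configs (TabNet defaults assumed for
# everything not set explicitly)
alt_model_config = TabNetModelConfig(task="classification", learning_rate=1e-3)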

tabular_model = TabularModel(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config,
)
datamodule = tabular_model.prepare_dataloader(train=train, validation=val, seed=42)
model = tabular_model.prepare_model(datamodule)
tabular_model.train(model, datamodule)
Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:604: UserWarning: Checkpoint directory /home/manujosephv/pytorch_tabular/docs/tutorials/saved_models exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓
┃   ┃ Name             ┃ Type                      ┃ Params ┃
┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩
│ 0 │ _backbone        │ CategoryEmbeddingBackbone │  817 K │
│ 1 │ _embedding_layer │ Embedding1dLayer          │     92 │
│ 2 │ head             │ LinearHead                │  1.0 K │
│ 3 │ loss             │ CrossEntropyLoss          │      0 │
└───┴──────────────────┴───────────────────────────┴────────┘
Trainable params: 818 K                                                                                            
Non-trainable params: 0                                                                                            
Total params: 818 K                                                                                                
Total estimated model params size (MB): 3                                                                          

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(

<pytorch_lightning.trainer.trainer.Trainer at 0x7f0c8a1f6080>
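The three low-level calls above (prepare_dataloader, prepare_model, train) are what the high-level API chains together internally; when the intermediate datamodule and model objects are not needed, a single fit call should be roughly equivalent:

# High-level equivalent of the three calls above (left commented out so the
# model is not trained a second time)
# tabular_model.fit(train=train, validation=val)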
# Taking original predictions
orig_pred_df = tabular_model.predict(test)
pred_df = orig_pred_df[["0_probability", "1_probability"]].copy()
pred_df.columns = ["0_probability_0", "1_probability_0"]
num_tta = 4  # number of augmentations
noise_scale = 0.05  # Scale of the noise. Higher means larger perturbation of features
categorical_swap_proba = 0.1
from pytorch_tabular.ssl_models.common.noise_generators import SwapNoiseCorrupter
from sklearn.metrics import accuracy_score
import torch

for i in range(num_tta):
    noised_test = test.copy()
    corrupter = SwapNoiseCorrupter(
        probas=[categorical_swap_proba for _ in range(len(datamodule.config.categorical_cols))]
    )
    for col in datamodule.config.continuous_cols:
        # Adding gaussian noise to continuous columns
        noised_test[col] = (
            noised_test[col] + noise_scale * np.random.normal(0, 1, size=len(noised_test)) * noised_test[col].std()
        )
    # For categorical we can swap values
    # for this we can reuse the SwapNoiseCorrupter in PyTorch Tabular
    cat_cols = torch.from_numpy(noised_test[datamodule.config.categorical_cols].values)
    cat_cols, _ = corrupter(cat_cols)
    noised_test.loc[:, datamodule.config.categorical_cols] = cat_cols.numpy()

    pred_df[[f"0_probability_{i+1}", f"1_probability_{i+1}"]] = 0
    pred_df[[f"0_probability_{i+1}", f"1_probability_{i+1}"]] = tabular_model.predict(noised_test)[
        ["0_probability", "1_probability"]
    ]
# Calculating the mean probability across the different augmentations
pred_df["0_prob"] = pred_df[[col for col in pred_df.columns if col.startswith("0_probability")]].mean(axis=1)
pred_df["1_prob"] = pred_df[[col for col in pred_df.columns if col.startswith("1_probability")]].mean(axis=1)
# Calculating final predictions from the probabilities using a 0.5 threshold
pred_df["original_pred"] = pred_df["1_probability_0"] > 0.5
pred_df["tta_pred"] = pred_df["1_prob"] > 0.5
# Calculating metrics
orig_acc = accuracy_score(test["target"].values, pred_df["original_pred"].values)
tta_acc = accuracy_score(test["target"].values, pred_df["tta_pred"].values)
print(f"Original Accuracy: {orig_acc} | TTA Accuracy: {tta_acc}")
Original Accuracy: 0.924 | TTA Accuracy: 0.928
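
Accuracy alone can hide shifts in the precision/recall balance, so it's also worth comparing F1 using the print_metrics helper defined earlier (this reuses only columns created above):

print_metrics(test["target"], pred_df["original_pred"].astype(int), tag="Original")
print_metrics(test["target"], pred_df["tta_pred"].astype(int), tag="TTA")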

noise_scale and categorical_swap_proba control how much noise is injected and should be tuned to find the right amount for the dataset at hand.
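
One way to do that tuning is to wrap the augmentation loop above into a helper and sweep a small grid on the validation split so the test set stays untouched. A rough sketch, reusing the objects already defined in this notebook and with illustrative grid values:

def tta_probability(model, df, num_tta, noise_scale, swap_proba):
    # Class-1 probability averaged over the original frame and num_tta noisy copies
    probs = [model.predict(df)["1_probability"].values]
    corrupter = SwapNoiseCorrupter(probas=[swap_proba] * len(datamodule.config.categorical_cols))
    for _ in range(num_tta):
        noised = df.copy()
        for col in datamodule.config.continuous_cols:
            noised[col] = noised[col] + noise_scale * np.random.normal(0, 1, size=len(noised)) * noised[col].std()
        cat_vals = torch.from_numpy(noised[datamodule.config.categorical_cols].values)
        cat_vals, _ = corrupter(cat_vals)
        noised.loc[:, datamodule.config.categorical_cols] = cat_vals.numpy()
        probs.append(model.predict(noised)["1_probability"].values)
    return np.mean(probs, axis=0)

# Illustrative grid; note that val was also used for early stopping, so treat these numbers as a rough guide
for ns in [0.01, 0.05, 0.1]:
    for sp in [0.05, 0.1, 0.2]:
        val_prob = tta_probability(tabular_model, val, num_tta=4, noise_scale=ns, swap_proba=sp)
        acc = accuracy_score(val["target"].values, (val_prob > 0.5).astype(int))
        print(f"noise_scale={ns} | swap_proba={sp} | val acc={acc:.4f}")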