from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score
import random
import numpy as np
import pandas as pd
import os
import wandb
from pathlib import Path
np.random.seed(42)
random.seed(42)
# %load_ext autoreload
# %autoreload 2

Utility Functions

def print_metrics(y_true, y_pred, tag):
    if isinstance(y_true, (pd.DataFrame, pd.Series)):
        y_true = y_true.values
    if isinstance(y_pred, (pd.DataFrame, pd.Series)):
        y_pred = y_pred.values
    if y_true.ndim > 1:
        y_true = y_true.ravel()
    if y_pred.ndim > 1:
        y_pred = y_pred.ravel()
    val_acc = accuracy_score(y_true, y_pred)
    # Covertype is a multi-class target, so use macro-averaged F1
    val_f1 = f1_score(y_true, y_pred, average="macro")
    print(f"{tag} Acc: {val_acc} | {tag} F1: {val_f1}")

Loading CoverType Dataset

Predicting forest cover type from cartographic variables only (no remotely sensed data). The actual forest cover type for a given observation (30 x 30 meter cell) was determined from US Forest Service (USFS) Region 2 Resource Information System (RIS) data. Independent variables were derived from data originally obtained from US Geological Survey (USGS) and USFS data. Data is in raw form (not scaled) and contains binary (0 or 1) columns of data for qualitative independent variables (wilderness areas and soil types).

This study area includes four wilderness areas located in the Roosevelt National Forest of northern Colorado. These areas represent forests with minimal human-caused disturbances, so that existing forest cover types are more a result of ecological processes rather than forest management practices.

The dataset is available from the UCI ML Repository.

BASE_DIR = Path.home().joinpath("data")
datafile = BASE_DIR.joinpath("covtype.data.gz")
datafile.parent.mkdir(parents=True, exist_ok=True)
url = "https://archive.ics.uci.edu/ml/machine-learning-databases/covtype/covtype.data.gz"
if not datafile.exists():
    import wget

    wget.download(url, datafile.as_posix())

target_name = "Covertype"

cat_col_names = [
    "Wilderness_Area1",
    "Wilderness_Area2",
    "Wilderness_Area3",
    "Wilderness_Area4",
    "Soil_Type1",
    "Soil_Type2",
    "Soil_Type3",
    "Soil_Type4",
    "Soil_Type5",
    "Soil_Type6",
    "Soil_Type7",
    "Soil_Type8",
    "Soil_Type9",
    "Soil_Type10",
    "Soil_Type11",
    "Soil_Type12",
    "Soil_Type13",
    "Soil_Type14",
    "Soil_Type15",
    "Soil_Type16",
    "Soil_Type17",
    "Soil_Type18",
    "Soil_Type19",
    "Soil_Type20",
    "Soil_Type21",
    "Soil_Type22",
    "Soil_Type23",
    "Soil_Type24",
    "Soil_Type25",
    "Soil_Type26",
    "Soil_Type27",
    "Soil_Type28",
    "Soil_Type29",
    "Soil_Type30",
    "Soil_Type31",
    "Soil_Type32",
    "Soil_Type33",
    "Soil_Type34",
    "Soil_Type35",
    "Soil_Type36",
    "Soil_Type37",
    "Soil_Type38",
    "Soil_Type39",
    "Soil_Type40",
]

num_col_names = [
    "Elevation",
    "Aspect",
    "Slope",
    "Horizontal_Distance_To_Hydrology",
    "Vertical_Distance_To_Hydrology",
    "Horizontal_Distance_To_Roadways",
    "Hillshade_9am",
    "Hillshade_Noon",
    "Hillshade_3pm",
    "Horizontal_Distance_To_Fire_Points",
]

feature_columns = num_col_names + cat_col_names + [target_name]

data = pd.read_csv(datafile, header=None, names=feature_columns)
# Dropping NA rows
data.dropna(inplace=True)
# Splitting data into data for SSL and data for finetuning
ssl, finetune = train_test_split(data, random_state=42, test_size=0.01)
# Train and val splits
ssl_train, ssl_val = train_test_split(ssl, random_state=42)
finetune_train, finetune_test = train_test_split(finetune, random_state=42)
finetune_train, finetune_val = train_test_split(finetune_train, random_state=42)
print(f"Unlabelled Data: {ssl.shape[0]} rows | Labelled Data: {finetune.shape[0]}")
Unlabelled Data: 575201 rows | Labelled Data: 5811

Importing the Library

from pytorch_tabular import TabularModel
from pytorch_tabular.models import CategoryEmbeddingModelConfig
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig, ExperimentConfig
from pytorch_tabular.models.common.heads import LinearHeadConfig
from pytorch_tabular.ssl_models.dae import DenoisingAutoEncoderConfig

Self-Supervised Learning

An excerpt from the article by Yann LeCun and Ishan Misra of Meta will serve as a good introduction here:

> Supervised learning is a bottleneck for building more intelligent generalist models that can do multiple tasks and acquire new skills without massive amounts of labeled data. Practically speaking, it’s impossible to label everything in the world. There are also some tasks for which there’s simply not enough labeled data, such as training translation systems for low-resource languages.

> As babies, we learn how the world works largely by observation. We form generalized predictive models about objects in the world by learning concepts such as object permanence and gravity. Later in life, we observe the world, act on it, observe again, and build hypotheses to explain how our actions change our environment by trial and error.

> Common sense helps people learn new skills without requiring massive amounts of teaching for every single task. For example, if we show just a few drawings of cows to small children, they’ll eventually be able to recognize any cow they see. By contrast, AI systems trained with supervised learning require many examples of cow images and might still fail to classify cows in unusual situations, such as lying on a beach. How is it that humans can learn to drive a car in about 20 hours of practice with very little supervision, while fully autonomous driving still eludes our best AI systems trained with thousands of hours of data from human drivers? The short answer is that humans rely on their previously acquired background knowledge of how the world works.

> How do we get machines to do the same?

> We believe that self-supervised learning (SSL) is one of the most promising ways to build such background knowledge and approximate a form of common sense in AI systems.

The full article is a very interesting read.

SSL has been used with great success in NLP (all the large language models that create the magic are trained with SSL), and with some success in Computer Vision. But can we do the same with tabular data? The answer is yes.

There are many research papers which talk about such models:

  1. TabNet talks about how it can be used for SSL
  2. VIME: Extending the Success of Self- and Semi-supervised Learning to Tabular Domain proposes another SSL model for tabular data
  3. And so does SubTab: Subsetting Features of Tabular Data for Self-Supervised Representation Learning

And before all of these, there was the Denoising AutoEncoder, which was used in winning solutions in many Tabular Playground competitions.

PyTorch Tabular provides a Denoising AutoEncoder implementation inspired by https://github.com/ryancheunggit/tabular_dae. Let's see how we can use it.

A Denoising AutoEncoder has the architecture shown below (Source):

[Figure: Denoising AutoEncoder — a corrupted input is passed through an encoder to a compressed code, and a decoder reconstructs the original, denoised input from that code]

We corrupt the input on the left and ask the model to predict the original, denoised input. Through this process, the network is forced to learn a compressed bottleneck (labelled code) that captures most of the characteristics of the input data, i.e. a robust representation of the input.
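
To make the idea concrete, here is a minimal, generic denoising autoencoder in plain PyTorch (the dimensions and the zero-style corruption are illustrative only; this is not PyTorch Tabular's implementation):

import torch
import torch.nn as nn

class TinyDAE(nn.Module):
    def __init__(self, n_features: int, code_dim: int = 16):
        super().__init__()
        # Encoder compresses the (corrupted) input into a small bottleneck code
        self.encoder = nn.Sequential(nn.Linear(n_features, 64), nn.ReLU(), nn.Linear(64, code_dim))
        # Decoder reconstructs the clean input from the code
        self.decoder = nn.Sequential(nn.Linear(code_dim, 64), nn.ReLU(), nn.Linear(64, n_features))

    def forward(self, x):
        return self.decoder(self.encoder(x))

x = torch.randn(32, 10)                 # clean input batch
mask = torch.rand_like(x) < 0.7         # 'zero'-style corruption of ~70% of the values
x_corrupted = x.masked_fill(mask, 0.0)
model = TinyDAE(n_features=10)
loss = nn.functional.mse_loss(model(x_corrupted), x)  # reconstruct the *clean* input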

In the Denoising AutoEncoder implementation in PyTorch Tabular, noise is introduced in two ways:

  1. swap - noise is introduced by replacing a value in a feature with another value of the same feature, randomly sampled from the rest of the rows.
  2. zero - noise is introduced by simply replacing the value with zero.

In addition, we can set noise_probabilities, which defines the probability with which noise is introduced to each feature. We can set this parameter as a dictionary of the form {feature_name: noise_probability}, or set a single probability for all features using default_noise_probability.
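
For instance, a minimal sketch of per-feature noise configuration (the probabilities and the tiny encoder/decoder layer sizes here are purely illustrative; the real configs are built in the Pre-training section below):

illustrative_dae_config = DenoisingAutoEncoderConfig(
    noise_strategy="swap",
    noise_probabilities={"Elevation": 0.9, "Slope": 0.5},  # per-feature overrides
    default_noise_probability=0.7,  # applied to every feature not listed above
    encoder_config=CategoryEmbeddingModelConfig(task="backbone", layers="128-64", head=None),
    decoder_config=CategoryEmbeddingModelConfig(task="backbone", layers="64-128", head=None),
)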

Once we have this robust representation, we can use it in downstream tasks like regression or classification.

A typical SSL workflow would have a large dataset without labels, and a smaller dataset with labels for finetuning.

  1. We start with pre-training using the unlabelled data
  2. Then we use the pre-trained model (the code in the diagram) for a downstream task like regression or classification
    i. We create a new model with the pretrained model as the backbone and a head for prediction
    ii. We train the new model (finetune) on the small labelled data

This approach would typically work better than a purely supervised model using just the small labelled dataset.
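
To keep the full pipeline in view before diving into the details, here is a minimal sketch of the three calls involved; the tiny layer sizes and one-epoch trainer are illustrative only, and the properly configured versions follow in the sections below:

sketch_encoder = CategoryEmbeddingModelConfig(task="backbone", layers="64-32", head=None)
sketch_decoder = CategoryEmbeddingModelConfig(task="backbone", layers="32-64", head=None)
sketch_model = TabularModel(
    data_config=DataConfig(
        target=None,  # no labels are needed for pre-training
        continuous_cols=num_col_names,
        categorical_cols=cat_col_names,
        handle_missing_values=False,
        handle_unknown_categories=False,
    ),
    model_config=DenoisingAutoEncoderConfig(
        encoder_config=sketch_encoder, decoder_config=sketch_decoder
    ),
    optimizer_config=OptimizerConfig(),
    trainer_config=TrainerConfig(max_epochs=1, early_stopping=None, checkpoints=None),
)
# Step 1: pre-train on the unlabelled data
sketch_model.pretrain(train=ssl_train, validation=ssl_val)
# Step 2.i: new model with the pretrained backbone and a prediction head
sketch_clf = sketch_model.create_finetune_model(
    task="classification",
    target=[target_name],
    head="LinearHead",
    head_config={"layers": ""},
)
# Step 2.ii: finetune on the small labelled data
sketch_clf.finetune(train=finetune_train, validation=finetune_val, freeze_backbone=True)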

Fully Supervised Model

results = []

Let's first train a fully supervised model using just the ~5k rows of labelled data. This can be used as a baseline.

data_config = DataConfig(
    target=[target_name],
    continuous_cols=num_col_names,
    categorical_cols=cat_col_names,
    normalize_continuous_features=True,
)

trainer_config = TrainerConfig(
    batch_size=2048,
    max_epochs=1000,
    early_stopping=None, # Turning off Early Stopping
    checkpoints="valid_loss", # Save best checkpoint monitoring val_loss
    load_best=True, # After training, load the best checkpoint
)

optimizer_config = OptimizerConfig()

model_config = CategoryEmbeddingModelConfig(
    task="classification",
    layers="2000-1000",
    activation="ReLU",
    dropout=0.1,
    initialization="kaiming",
    head="LinearHead",
    head_config={
        "layers": "",
        "activation": "ReLU",
    },
    learning_rate=1e-3,
)

tabular_model = TabularModel(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config,
)
tabular_model.fit(
    train=finetune_train, 
    validation=finetune_val
)
Global seed set to 42
Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:604: UserWarning: Checkpoint directory /home/manujosephv/pytorch_tabular/docs/tutorials/saved_models exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓
┃    Name              Type                       Params ┃
┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩
│ 0 │ _backbone        │ CategoryEmbeddingBackbone │  2.2 M │
│ 1 │ _embedding_layer │ Embedding1dLayer          │    276 │
│ 2 │ head             │ LinearHead                │  7.0 K │
│ 3 │ loss             │ CrossEntropyLoss          │      0 │
└───┴──────────────────┴───────────────────────────┴────────┘
Trainable params: 2.2 M                                                                                            
Non-trainable params: 0                                                                                            
Total params: 2.2 M                                                                                                
Total estimated model params size (MB): 8                                                                          

Output()
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
`Trainer.fit` stopped: `max_epochs=1000` reached.


<pytorch_lightning.trainer.trainer.Trainer at 0x7fe8f0c4cd90>
result = tabular_model.evaluate(finetune_test)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, test_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(


Output()
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric               DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       test_accuracy           0.7247074842453003     │
│         test_loss             0.6507382988929749     │
└───────────────────────────┴───────────────────────────┘
Testing ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1/1 0:00:00 • 0:00:00 0.00it/s  

result = result[0]
result['Type'] = "Supervised"
results.append(result)

Denoising AutoEncoder

Pre-training

Now we use the ~575k rows of unlabelled data for self-supervised pre-training.

batch_size = 2048
steps_per_epoch = int(ssl_train.shape[0]/batch_size)
epochs = 100
ssl_data_config = DataConfig(
    target=None, #Setting target as None because we don't need the target for SSL
    continuous_cols=num_col_names,
    categorical_cols=cat_col_names,
    normalize_continuous_features=True,
    handle_missing_values=False, # For SSL tasks, missing values and unknown categories will not be handled automatically
    handle_unknown_categories=False, # Not setting these configs to False will throw an error when initializing TabularModel
)

ssl_trainer_config = TrainerConfig(
    batch_size=batch_size,
    max_epochs=epochs,
    early_stopping=None, # Turning off Early Stopping
    checkpoints="valid_loss", # Save best checkpoint monitoring val_loss
    load_best=True, # After training, load the best checkpoint
)

# Setting OneCycleLR schedule
ssl_optimizer_config = OptimizerConfig(
    lr_scheduler="OneCycleLR",
    lr_scheduler_params={
        "max_lr":1e-2, 
        "epochs": epochs, 
        "steps_per_epoch":steps_per_epoch
    }
)

# Setting the encoder config
encoder_config = CategoryEmbeddingModelConfig(
    task="backbone",
    layers="4000-2000-1000-512",  # Number of nodes in each layer
    activation="ReLU",  # Activation between each layers
    head=None, # If not set to None, it will throw a warning
)

# Setting the decoder config.
# NOTE: the last dimension in encoder layers should be first dimension in decoder layers
# i.e. last encoder layer dim = 512, first decoder layer dim = 512
decoder_config = CategoryEmbeddingModelConfig(
    task="backbone",
    layers="512-2048-4096",  # Number of nodes in each layer
    activation="ReLU",  # Activation between each layers
    head=None, # If not set to None, it will throw a warning
)

# DAE Config. No need to set task because it is hardcoded to SSL
# Can't set any loss or metrics as well because for the SSL task
# (especially for DAE), the loss and metrics are fixed.
ssl_model_config = DenoisingAutoEncoderConfig(
    noise_strategy="swap", # Can be 'zero' as well. Defines if noise is swapping features or making features zero
    default_noise_probability=0.7, # Probability of corruption by noise
    encoder_config=encoder_config, 
    decoder_config=decoder_config, 
    learning_rate=1e-3)

ssl_tabular_model = TabularModel(
    data_config=ssl_data_config,
    model_config=ssl_model_config,
    optimizer_config=ssl_optimizer_config,
    trainer_config=ssl_trainer_config,
)
# Pretraining
ssl_tabular_model.pretrain(train=ssl_train, validation=ssl_val)
Global seed set to 42
Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:604: UserWarning: Checkpoint directory /home/manujosephv/pytorch_tabular/docs/tutorials/saved_models exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

┏━━━┳━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓
┃    Name                 Type                            Params ┃
┡━━━╇━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩
│ 0 │ encoder             │ CategoryEmbeddingBackbone      │ 11.1 M │
│ 1 │ decoder             │ CategoryEmbeddingBackbone      │  9.7 M │
│ 2 │ _featurizer         │ DenoisingAutoEncoderFeaturizer │ 11.1 M │
│ 3 │ _embedding          │ MixedEmbedding1dLayer          │     20 │
│ 4 │ reconstruction      │ MultiTaskHead                  │  581 K │
│ 5 │ mask_reconstruction │ Linear                         │  581 K │
└───┴─────────────────────┴────────────────────────────────┴────────┘
Trainable params: 22.0 M                                                                                           
Non-trainable params: 0                                                                                            
Total params: 22.0 M                                                                                               
Total estimated model params size (MB): 87                                                                         

Output()
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
`Trainer.fit` stopped: `max_epochs=100` reached.


<pytorch_lightning.trainer.trainer.Trainer at 0x7fe8ce2f5720>

Fine-Tuning

Now we create a finetune model using the pretrained weights, and fine-tune it for the classification task using the ~5k rows of labelled data.

batch_size = 2048
steps_per_epoch = max(1, finetune_train.shape[0] // batch_size)  # base the OneCycleLR schedule on the finetuning data, not the SSL data
epochs = 1000
ft_trainer_config = TrainerConfig(
    batch_size=batch_size,
    max_epochs=epochs,
    early_stopping=None, # Turning off Early Stopping
    checkpoints="valid_loss", # Save best checkpoint monitoring val_loss
    load_best=True, # After training, load the best checkpoint
)
from torch_optimizer import QHAdam
ft_optimizer_config = OptimizerConfig(
    lr_scheduler="OneCycleLR",
    lr_scheduler_params={
        "max_lr":1e-3, 
        "epochs": epochs, 
        "steps_per_epoch":steps_per_epoch
    }
)

finetune_model = ssl_tabular_model.create_finetune_model(
    task="classification",
    target=[target_name], #Provide the column name of the target as a list
    head="LinearHead",
    head_config={
        "layers": "512-256-512-64",
        "activation": "ReLU",
    },
    trainer_config=ft_trainer_config,# Overriding previous trainer config
    optimizer_config=ft_optimizer_config,
    optimizer=QHAdam, # Using a custom optimizer
    optimizer_params={"nus": (0.7, 1.0), "betas": (0.95, 0.998)}
)
# Check if the new model has the pretrained weights
import torch
assert torch.equal(
    ssl_tabular_model.model.encoder.linear_layers[0].weight,
    finetune_model.model._backbone.encoder.linear_layers[0].weight,
)
finetune_model.finetune(
    train=finetune_train, 
    validation=finetune_val, 
    freeze_backbone=True)
Global seed set to 42
Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:604: UserWarning: Checkpoint directory /home/manujosephv/pytorch_tabular/docs/tutorials/saved_models exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

┏━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓
┃    Name         Type                       Params ┃
┡━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩
│ 0 │ custom_loss │ CrossEntropyLoss          │      0 │
│ 1 │ _backbone   │ DenoisingAutoEncoderModel │ 22.0 M │
│ 2 │ _head       │ LinearHead                │  558 K │
└───┴─────────────┴───────────────────────────┴────────┘
Trainable params: 558 K                                                                                            
Non-trainable params: 22.0 M                                                                                       
Total params: 22.5 M                                                                                               
Total estimated model params size (MB): 90                                                                         

Output()
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
`Trainer.fit` stopped: `max_epochs=1000` reached.


<pytorch_lightning.trainer.trainer.Trainer at 0x7fe8f07b7850>
result = finetune_model.evaluate(finetune_test)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, test_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(


Output()
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric               DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       test_accuracy            0.732278048992157     │
│         test_loss              0.644106924533844     │
└───────────────────────────┴───────────────────────────┘
Testing ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1/1 0:00:00 • 0:00:00 0.00it/s  

result = result[0]
result['Type'] = "Self-Supervised"
results.append(result)

We can see that the self-supervised approach has slightly better accuracy and lower loss. This improvement is not guaranteed; it depends on how well the representation was learned during pre-training.

pd.DataFrame(results)
   test_loss  test_accuracy             Type
0   0.650738       0.724707       Supervised
1   0.644107       0.732278  Self-Supervised
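
As a final sketch, the predictions of the fine-tuned model can be inspected with the print_metrics helper defined earlier. The prediction column name below assumes pytorch_tabular's usual {target}_prediction naming convention; verify it against your installed version:

pred_df = finetune_model.predict(finetune_test)
print_metrics(
    finetune_test[target_name],
    pred_df[f"{target_name}_prediction"],  # column name assumed; check your version
    tag="Holdout",
)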