
Approaching any Tabular Problem using PyTorch Tabular

Pre-requisites: Basic knowledge of Machine Learning and Tabular Problems like Regression and Classification
Level: Beginner

In this tutorial, we will look at how to tackle any tabular machine learning problem (classification or regression) using PyTorch Tabular. We will use the Covertype dataset from the UCI repository, which contains 581,012 rows and 54 columns. It is a multi-class classification problem: the goal is to predict the forest cover type from cartographic variables only (no remotely sensed data).

In a typical machine learning workflow, we would do the following steps:

1. Load the dataset
2. Analyze the dataset
3. Split the dataset into train and test
4. Preprocess the dataset
5. Define the model
6. Train the model
7. Make predictions on new data
8. Evaluate the model

Let's see how we can do the same using PyTorch Tabular.

Step 1: Load the Data

Cover Type Dataset

Predicting forest cover type from cartographic variables only (no remotely sensed data). The actual forest cover type for a given observation (30 x 30 meter cell) was determined from US Forest Service (USFS) Region 2 Resource Information System (RIS) data. Independent variables were derived from data originally obtained from US Geological Survey (USGS) and USFS data. Data is in raw form (not scaled) and contains binary (0 or 1) columns of data for qualitative independent variables (wilderness areas and soil types).

This study area includes four wilderness areas located in the Roosevelt National Forest of northern Colorado. These areas represent forests with minimal human-caused disturbances, so that existing forest cover types are more a result of ecological processes rather than forest management practices.

There is a simple utility method in PyTorch Tabular to load this particular dataset. It downloads the data from the UCI ML Repository. The original dataset has two pieces of categorical information - Soil Type and Wilderness Area - but they come one-hot encoded. The utility method converts them back into categorical columns to make the dataset closer to the real-life datasets we encounter in the wild.

from pytorch_tabular.utils import load_covertype_dataset
data, _, _, _ = load_covertype_dataset()

Step 2: Analyze the dataset

In this step, we will explore the data to understand it better. Exploratory Data Analysis (EDA) can mean many things, depending on the data and the problem we are trying to solve, and it helps us understand the data and decide how to proceed. Here, we will restrict ourselves to the most basic analysis: just enough to identify the continuous and categorical columns, and to check for missing values.

from rich import print
# One of the easiest ways to identify categorical features is using the pandas select_dtypes function.
categorical_features = data.select_dtypes(include=['object'])
print(categorical_features.columns)
Index(['Wilderness_Area', 'Soil_Type'], dtype='object')

But this may not always be reliable. For example, if we have a column called month with values from 1 to 12, it is really a categorical column, but select_dtypes will treat it as a continuous one. So we need to be careful and use our judgement.
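A quick illustration with a made-up month column (it is not part of the Covertype data, so this is only a sketch of the pitfall):

import pandas as pd

# Illustrative only: a hypothetical integer-coded `month` column is really
# categorical, but shows up with a numeric dtype.
df = pd.DataFrame({"month": [1, 5, 12, 7]})
print(df.select_dtypes(include=["object"]).columns)  # month is not picked up
df["month"] = df["month"].astype("category")  # cast explicitly so it is treated as categorical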

# Another way is to use the unique values in each column.
for col in data.columns:
    print(col, len(data[col].unique()))
Elevation 1978
Aspect 361
Slope 67
Horizontal_Distance_To_Hydrology 551
Vertical_Distance_To_Hydrology 700
Horizontal_Distance_To_Roadways 5785
Hillshade_9am 207
Hillshade_Noon 185
Hillshade_3pm 255
Horizontal_Distance_To_Fire_Points 5827
Cover_Type 7
Wilderness_Area 4
Soil_Type 40

But this is not reliable either. For example, the Soil_Type column has 40 unique values. How do we decide whether it is a categorical column or a continuous one? We need to use our judgement here as well.

Reading the data description, understanding the domain, and using our judgement is the best way to decide whether a column is categorical or continuous.

Here we will consider Wilderness_Area and Soil_Type as categorical features. We know Cover_Type is the target column and that makes the rest of the columns continuous features.

# This separation has already been done for you while loading this particular dataset from `PyTorch Tabular`. Let's load the dataset the right way.
data, cat_col_names, num_col_names, target_col = load_covertype_dataset()
# Let's also print out a few details
print(f"Data Shape: {data.shape} | # of cat cols: {len(cat_col_names)} | # of num cols: {len(num_col_names)}")
print(f"[bold dodger_blue2] Features: {num_col_names + cat_col_names}[/bold dodger_blue2]")
print(f"[bold purple4]Target: {target_col}[/bold purple4]")
Data Shape: (581012, 13) | # of cat cols: 2 | # of num cols: 10
 Features: ['Elevation', 'Aspect', 'Slope', 'Horizontal_Distance_To_Hydrology', 'Vertical_Distance_To_Hydrology', 
'Horizontal_Distance_To_Roadways', 'Hillshade_9am', 'Hillshade_Noon', 'Hillshade_3pm', 
'Horizontal_Distance_To_Fire_Points', 'Wilderness_Area', 'Soil_Type']
Target: Cover_Type
Note: Supervised learning reduces to finding a function that maps inputs to outputs. The inputs are called features and the outputs are called targets. Features can be continuous or categorical, and so can targets: in classification the targets are categorical, and in regression they are continuous.
# Let's also check the data for missing values
print(data.isna().sum())
Elevation                             0
Aspect                                0
Slope                                 0
Horizontal_Distance_To_Hydrology      0
Vertical_Distance_To_Hydrology        0
Horizontal_Distance_To_Roadways       0
Hillshade_9am                         0
Hillshade_Noon                        0
Hillshade_3pm                         0
Horizontal_Distance_To_Fire_Points    0
Cover_Type                            0
Wilderness_Area                       0
Soil_Type                             0
dtype: int64

Great news! There are no missing values in the dataset. If there were any, we would need to handle them. Kaggle has a good tutorial on how to handle missing values. You can find it here.

Note: PyTorch Tabular can deal with missing values in categorical features natively, but missing values in numerical features need to be handled separately.
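If we did have to fill numerical gaps, a minimal sketch with scikit-learn's SimpleImputer could look like the following. It is shown on a copy, since this dataset needs no imputation, and in practice the imputer should be fit on the training split only to avoid leakage.

from sklearn.impute import SimpleImputer

# Illustrative only: median-impute the numerical columns on a copy of the data.
imputer = SimpleImputer(strategy="median")
imputed = data.copy()
imputed[num_col_names] = imputer.fit_transform(imputed[num_col_names])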

Step 3: Split the dataset into train and test

Now, in all tabular problems, when we apply machine learning we need a training set, a validation set, and a test set. We use the training set to train the model, the validation set to make modelling decisions (like the hyperparameters, or which kind of model to use), and the test set to evaluate the final model. Since the dataset doesn't come with a separate test set, we will split the data into training, validation and test sets.

from sklearn.model_selection import train_test_split
train, test = train_test_split(data, random_state=42, test_size=0.2)
train, val = train_test_split(train, random_state=42, test_size=0.2)
print(f"Train Shape: {train.shape} | Val Shape: {val.shape} | Test Shape: {test.shape}")
Train Shape: (371847, 13) | Val Shape: (92962, 13) | Test Shape: (116203, 13)
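A small aside: since Cover_Type is a classification target, a stratified split is often a safer choice because it preserves the class proportions across the splits. A sketch (not used in the rest of the tutorial):

# Optional variant: stratify the split on the target column.
train_s, test_s = train_test_split(data, random_state=42, test_size=0.2, stratify=data[target_col])
train_s, val_s = train_test_split(train_s, random_state=42, test_size=0.2, stratify=train_s[target_col])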

Step 4: Preprocess the dataset

In a typical machine learning project, this is the most time consuming step where we create new features, clean the data, handle missing values, handle outliers, scale the data, encode categorical features and so on.

In a scikit-learn based project, pseudocode for this step would look like this:

data = create_new_features(data)
data = clean_data(data)
data = handle_missing_values(data)
data = handle_outliers(data)
data, cat_encoder = encode_categorical_features(data)
data, scaler = scale_data(data)
X, y = split_features_target(data)

But one of the allures of deep learning is that we don't need to spend as much time on feature engineering; we can use the raw data and let the model figure out the best features to use. We still need some data preparation, though, and PyTorch Tabular takes care of many of these needs:

  • Missing values in categorical features are handled natively
  • Categorical features are encoded automatically using embeddings
  • Continuous features are scaled automatically using StandardScaler
  • Date features like month, day, year are extracted automatically
  • Target transformations like log, power, quantile, and box-cox can be enabled with a parameter. This will also handle the inverse transformation automatically.
  • Continuous features can be transformed using box-cox, quantile normal etc. with a parameter

While we have all these features, we can also choose to do any of these manually. For example, we can choose to encode categorical features using one hot encoding or target encoding and consider them as continuous features. We can also choose to scale the continuous features using MinMaxScaler or RobustScaler and turn off the automatic scaling.

So, here, we won't be doing any of these. We will just use the data as is and let PyTorch Tabular handle the rest.
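For reference, the manual route described above could look like the sketch below; it is not applied in this tutorial. The DataConfig switch `normalize_continuous_features` is assumed to control the automatic scaling, so check the DataConfig docs for your version.

from sklearn.preprocessing import MinMaxScaler

# Scale the continuous features ourselves (on copies, so the tutorial data stays untouched)
# and later ask PyTorch Tabular not to scale them again.
scaler = MinMaxScaler().fit(train[num_col_names])
train_scaled, val_scaled, test_scaled = train.copy(), val.copy(), test.copy()
for df in (train_scaled, val_scaled, test_scaled):
    df[num_col_names] = scaler.transform(df[num_col_names])
# ... and later: DataConfig(..., normalize_continuous_features=False)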

Step 5: Define the Model

Now, we will define the model. In a scikit-learn workflow, we would have done the following steps:

model = SomeModel(**parameters)

This is where PyTorch Tabular is different, because we need to define a few configs before we define the model. One of the reasons is that the PyTorch Tabular model handles a lot of things under the hood, so we need to tell it what kind of data we are dealing with. We also need to define the training dynamics, along with the model parameters. The configs we need to define are:

  1. DataConfig - This is where we define the data related configs like the target column, categorical columns, continuous columns, date columns, categorical embedding dimensions, etc. But the good news is that most of these are optional. If we don't define them, PyTorch Tabular will try to infer them from the data or have thumb rules to handle them. The bare minimum we need to define is the target column name, continuous columns and categorical columns. Categorical columns are embedded by default, numerical columns scaled by default and date columns are extracted by default.

  2. TrainerConfig - This is where we define the training related configs like the batch size, number of epochs, early stopping, etc. Again, all of these are optional. If we don't define them, PyTorch Tabular will use sensible default values. By default, PyTorch Tabular runs with a batch size of 64, early stopping with a patience of 3 epochs, and checkpointing enabled. This means the model is saved at the end of every epoch, the best model is kept, and training stops if the validation loss doesn't improve for 3 epochs. Although all of TrainerConfig is optional, it is also highly customizable: the entire PyTorch Lightning Trainer is exposed, either through explicit parameters in TrainerConfig or through the catch-all trainer_kwargs parameter.

  3. OptimizerConfig - This is where we define the optimizer related configs like the kind of optimizer, weight decay, learning rate schedulers, etc. Again, all of these are optional. If we don't define them, PyTorch Tabular will use some default values. By default, PyTorch Tabular uses the Adam optimizer and no learning rate decay. Although all of OptimizerConfig is optional, it is also customizable.

  4. ExperimentConfig - This is where we define how to track the experiment for logging and reproducibility. By default, PyTorch Tabular uses tensorboard for logging, but we can also use wandb. We can also choose to not log anything (although not recommended) by not defining an ExperimentConfig. A sketch of an ExperimentConfig is shown after this list.

  5. Model-specific config - This is where we define which model to use and the corresponding hyperparameters. In PyTorch Tabular, each implemented model has its own config class. For example, if we want to use TabNet, we need to define TabNetConfig; if we want to use GANDALF, we need to define GANDALFConfig, and so on. Each of these config classes has its own set of model-specific hyperparameters, as well as some common parameters like the loss function, metrics, learning rate, etc. Again, all of these are optional. If we don't define them, PyTorch Tabular will use some default values. The learning rate is set to 1e-3 by default. The loss function is set to CrossEntropyLoss for classification and MSELoss for regression. The metrics are set to Accuracy for classification and MSE for regression. And all the model-specific hyperparameters are set to the values suggested in their respective papers, or to defaults that work well in practice.
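As referenced in item 4, here is what an optional ExperimentConfig could look like. This tutorial skips experiment tracking, so the sketch below is only for orientation; the parameter names (project_name, run_name, log_target) follow the ExperimentConfig docs, but treat the exact arguments as assumptions and verify them for your version.

from pytorch_tabular.config import ExperimentConfig

experiment_config = ExperimentConfig(
    project_name="pytorch_tabular_covertype",  # hypothetical project name
    run_name="gandalf_baseline",  # hypothetical run name
    log_target="tensorboard",  # or "wandb"
)
# It would then be passed as TabularModel(..., experiment_config=experiment_config)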

Here, let's use the GANDALF model. We will define the configs as follows:

from pytorch_tabular.models import GANDALFConfig
from pytorch_tabular.config import (
    DataConfig,
    OptimizerConfig,
    TrainerConfig,
)

data_config = DataConfig(
    target=[
        target_col
    ],  # target should always be a list. Multi-targets are only supported for regression. Multi-Task Classification is not implemented
    continuous_cols=num_col_names,
    categorical_cols=cat_col_names,
)
trainer_config = TrainerConfig(
    batch_size=1024,
    max_epochs=100,
)
optimizer_config = OptimizerConfig()
model_config = GANDALFConfig(
    task="classification",
    gflu_stages=6,
    gflu_feature_init_sparsity=0.3,
    gflu_dropout=0.0,
    learning_rate=1e-3,
)
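Just to show how little changes when we switch architectures, here is a hedged sketch of an alternative model config (not used in the rest of the tutorial). The layers argument is an assumption based on the CategoryEmbeddingModel defaults; check the model docs for the exact hyperparameters.

from pytorch_tabular.models import CategoryEmbeddingModelConfig

# A simple MLP over embeddings instead of GANDALF.
alt_model_config = CategoryEmbeddingModelConfig(
    task="classification",
    layers="128-64-32",  # assumed format: layer sizes as a dash-separated string
    learning_rate=1e-3,
)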

Now that we have defined all the configs, we can define the TabularModel. Apart from the configs, there are some additional parameters we can pass to the model to control the verbosity of the model.

from pytorch_tabular import TabularModel

tabular_model = TabularModel(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config,
    verbose=True
)
2024-01-07 04:39:30,992 - {pytorch_tabular.tabular_model:140} - INFO - Experiment Tracking is turned off           

We can see that since we passed verbose=True, it has already logged that the Experiment Tracking is disabled.

Step 6: Train the model

Now, we can train the model. In the scikit-learn workflow, we would have done the following:

model.fit(X_train, y_train)

In PyTorch Tabular, there are two ways we can do this:

  • High-Level API - A fit method which is very similar to the scikit-learn API, but the fit method has a lot more parameters to control the training dynamics. This is the recommended way to train the model.
  • Low-Level API - A collection of methods - prepare_dataloader, prepare_model, and train. This is for advanced users who want to have more control over the training process.
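For orientation, the low-level route would look roughly like the sketch below. The method names come from the list above, but the exact signatures are an assumption, so check the low-level API docs before relying on them. It replaces the single fit call used later, so you would run one or the other, not both.

# Low-level alternative to `fit` (sketch; signatures assumed):
datamodule = tabular_model.prepare_dataloader(train=train, validation=val, seed=42)
model = tabular_model.prepare_model(datamodule)
tabular_model.train(model, datamodule)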

Let's stick to the high-level API in this introductory tutorial. We will use the fit method to train the model. The only compulsory parameter of fit is the training data. We can also pass the validation data explicitly; if we don't, 20% of the training data will be used as validation data. In addition, there are many other parameters like custom loss functions, metrics, custom optimizers, etc. which make the training process more customizable.
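For illustration, a more customized call might look like the following sketch. The argument names loss, optimizer, and optimizer_params are assumptions based on the fit API description above and may differ across versions; the tutorial itself uses the simple call right after.

import torch

# Sketch only (verify argument names against your version of PyTorch Tabular):
tabular_model.fit(
    train=train,
    validation=val,
    loss=torch.nn.CrossEntropyLoss(),  # custom loss function
    optimizer=torch.optim.AdamW,  # custom optimizer class
    optimizer_params={"weight_decay": 1e-5},  # extra optimizer arguments
)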

tabular_model.fit(train=train, validation=val)
Seed set to 42

2024-01-07 04:39:31,059 - {pytorch_tabular.tabular_model:524} - INFO - Preparing the DataLoaders                   
2024-01-07 04:39:31,500 - {pytorch_tabular.tabular_datamodule:499} - INFO - Setting up the datamodule for          
classification task                                                                                                
2024-01-07 04:39:32,358 - {pytorch_tabular.tabular_model:574} - INFO - Preparing the Model: GANDALFModel           
2024-01-07 04:39:32,591 - {pytorch_tabular.tabular_model:340} - INFO - Preparing the Trainer                       
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

2024-01-07 04:39:32,839 - {pytorch_tabular.tabular_model:630} - INFO - Auto LR Find Started                        
You are using a CUDA device ('NVIDIA GeForce RTX 3060 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:639: Checkpoint directory saved_models exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.

Finding best initial lr:   0%|          | 0/100 [00:00<?, ?it/s]
`Trainer.fit` stopped: `max_steps=100` reached.
Learning rate set to 0.02089296130854041
Restoring states from the checkpoint path at /home/manujosephv/pytorch_tabular/docs/tutorials/.lr_find_6d1c6109-882a-4b7f-939c-2d42ecd8ff06.ckpt
Restored all states from the checkpoint at /home/manujosephv/pytorch_tabular/docs/tutorials/.lr_find_6d1c6109-882a-4b7f-939c-2d42ecd8ff06.ckpt

2024-01-07 04:39:37,498 - {pytorch_tabular.tabular_model:643} - INFO - Suggested LR: 0.02089296130854041. For plot 
and detailed analysis, use `find_learning_rate` method.                                                            
2024-01-07 04:39:37,500 - {pytorch_tabular.tabular_model:652} - INFO - Training Started                            
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓
┃   ┃ Name             ┃ Type             ┃ Params ┃
┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩
│ 0 │ _backbone        │ GANDALFBackbone  │ 42.4 K │
│ 1 │ _embedding_layer │ Embedding1dLayer │    896 │
│ 2 │ _head            │ Sequential       │    252 │
│ 3 │ loss             │ CrossEntropyLoss │      0 │
└───┴──────────────────┴──────────────────┴────────┘
Trainable params: 43.6 K                                                                                           
Non-trainable params: 0                                                                                            
Total params: 43.6 K                                                                                               
Total estimated model params size (MB): 0                                                                          
Output()


2024-01-07 04:41:25,635 - {pytorch_tabular.tabular_model:663} - INFO - Training the model completed                
2024-01-07 04:41:25,636 - {pytorch_tabular.tabular_model:1487} - INFO - Loading the best model                     
<pytorch_lightning.trainer.trainer.Trainer at 0x7f1dbab82c10>

Step 7: Making predictions on new data

Now that we have trained the model, we can make predictions on new data. In a scikit-learn workflow, we would have done the following:

y_pred = model.predict(X_test)
y_pred_proba = model.predict_proba(X_test)

In PyTorch Tabular, we can do something very similar. We can use the predict method to make predictions on new data. This method returns the predictions as a pandas dataframe. For classification problems, it returns the class probabilities and the final predicted class based on a 0.5 threshold. All we have to do is pass in a dataframe with at least all the features that were used for training.

pred_df = tabular_model.predict(test)
pred_df.head()
1_probability 2_probability 3_probability 4_probability 5_probability 6_probability 7_probability prediction
250728 0.901409 0.001267 1.025811e-08 9.070018e-08 0.000040 3.519856e-08 9.728358e-02 1
246788 0.156802 0.843021 8.172029e-07 2.142834e-09 0.000171 2.749175e-07 4.734764e-06 2
407714 0.001035 0.969636 4.896594e-03 4.262948e-06 0.019907 4.521028e-03 8.038983e-07 2
25713 0.289917 0.709881 1.039616e-05 5.966012e-08 0.000152 3.714674e-05 1.749813e-06 2
21820 0.000729 0.870874 2.740137e-05 3.132881e-06 0.128357 9.515504e-06 9.939656e-09 2

Step 8: Evaluating the Model

Now, we can evaluate the model. In the scikit-learn workflow, we would have done the following:

pred_df = model.predict(X_test)
accuracy = accuracy_score(y_test, pred_df)

In PyTorch Tabular, there are two ways we can do this:

  • Get the predictions on the test set and calculate the metrics manually
  • Use the evaluate method, which will return the metrics (the same ones we have defined during training)

We will see the second way here. We can use the evaluate method to evaluate the model on the test set. This method returns a dictionary of metrics. (We will also show the manual route right after, for comparison.)

result = tabular_model.evaluate(test)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

Output()
/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.

┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric        ┃       DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       test_accuracy       │     0.878411054611206     │
│         test_loss         │    0.2998563051223755     │
└───────────────────────────┴───────────────────────────┘


result
[{'test_loss': 0.2998563051223755, 'test_accuracy': 0.878411054611206}]
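For completeness, the first option (manual metric calculation) is a one-liner with scikit-learn, reusing the pred_df from Step 7 and assuming its prediction column holds the predicted class labels as shown in that output.

from sklearn.metrics import accuracy_score

manual_accuracy = accuracy_score(test[target_col], pred_df["prediction"])
print(f"Manual Accuracy: {manual_accuracy}")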

Step 9: Saving and Loading the Model

After the model is trained, we can save the model and load it later to make predictions on new data. In a scikit-learn workflow, we would have done the following:

joblib.dump(model, "model.joblib")
model = joblib.load("model.joblib")

In PyTorch Tabular, we can do something very similar. We can use the save_model method to save the model. This method saves everything required to make predictions on new data. By default, it also saves the datamodule, which contains the training, validation, and test data. But we can choose not to save the datamodule by setting inference_only=True.

tabular_model.save_model("examples/basic")
2024-01-07 04:43:51,268 - {pytorch_tabular.tabular_model:1531} - WARNING - Directory is not empty. Overwriting the 
contents.                                                                                                          
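If all we need is a model for inference, the inference_only flag mentioned above gives a lighter artifact without the datamodule. The directory name below is just an example.

# Lighter-weight save without the datamodule (and the data cached inside it):
tabular_model.save_model("examples/basic_inference", inference_only=True)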

Now we can load the saved model using the load_model method. This method returns the model and the datamodule. We can use the model to make predictions on new data.

loaded_model = TabularModel.load_model("examples/basic")
2024-01-07 04:43:51,948 - {pytorch_tabular.tabular_model:165} - INFO - Experiment Tracking is turned off           
2024-01-07 04:43:51,953 - {pytorch_tabular.tabular_model:340} - INFO - Preparing the Trainer                       
Trainer already configured with model summary callbacks: [<class 'pytorch_lightning.callbacks.rich_model_summary.RichModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

# Let's check if we get the same result on test data using the loaded model
result = loaded_model.evaluate(test)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

Output()
/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.

┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric        ┃       DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       test_accuracy       │     0.878411054611206     │
│         test_loss         │    0.2998563051223755     │
└───────────────────────────┴───────────────────────────┘


Congrats! You have trained a SOTA deep learning model on tabular data. Things you can try:
  1. Check out the PyTorch Tabular Documentation to learn more about the library
  2. Use alternate models like TabNet, CategoryEmbedding, etc.
  3. Use different datasets and try out the workflow.
  4. Check out other tutorials and how-to guides in the documentation.
Now try to use these features in your own projects and Kaggle competitions. If you have any questions, please feel free to ask them in the GitHub Discussions.