from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_absolute_error, mean_squared_error
import random
import numpy as np
import pandas as pd
import os

Utility Functions

def make_mixed_regression(n_samples, n_features, n_categories):
    X, y = make_regression(n_samples=n_samples, n_features=n_features, random_state=42, n_informative=5, n_targets=1)
    # Pick distinct columns to turn into categoricals
    # (random.sample avoids the duplicate picks that random.choices can return)
    cat_cols = random.sample(range(X.shape[-1]), k=n_categories)
    num_cols = [i for i in range(X.shape[-1]) if i not in cat_cols]
    for col in cat_cols:
        # Bin the continuous column into 4 quantile-based categories
        X[:, col] = pd.qcut(X[:, col], q=4).codes.astype(int)
    col_names = [] 
    num_col_names=[]
    cat_col_names=[]
    for i in range(X.shape[-1]):
        if i in cat_cols:
            col_names.append(f"cat_col_{i}")
            cat_col_names.append(f"cat_col_{i}")
        if i in num_cols:
            col_names.append(f"num_col_{i}")
            num_col_names.append(f"num_col_{i}")
    X = pd.DataFrame(X, columns=col_names)
    y = pd.DataFrame(y, columns=["target"])
    data = X.join(y)
    return data, cat_col_names, num_col_names

def print_metrics(y_true, y_pred, tag):
    # Accept DataFrames/Series or numpy arrays and flatten to 1d
    if isinstance(y_true, (pd.DataFrame, pd.Series)):
        y_true = y_true.values
    if isinstance(y_pred, (pd.DataFrame, pd.Series)):
        y_pred = y_pred.values
    if y_true.ndim > 1:
        y_true = y_true.ravel()
    if y_pred.ndim > 1:
        y_pred = y_pred.ravel()
    mse = mean_squared_error(y_true, y_pred)
    mae = mean_absolute_error(y_true, y_pred)
    print(f"{tag} MSE: {mse} | {tag} MAE: {mae}")

Generate Synthetic Data

First of all, let's create a synthetic dataset with a mix of numerical and categorical features

data, cat_col_names, num_col_names = make_mixed_regression(n_samples=10000, n_features=20, n_categories=4)
train, test = train_test_split(data, random_state=42)
train, val = train_test_split(train, random_state=42)

Importing the Library

from pytorch_tabular import TabularModel
from pytorch_tabular.models import CategoryEmbeddingModelConfig
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig, ExperimentConfig, ModelConfig
from pytorch_tabular.models import BaseModel
from pytorch_tabular.models.common.layers import Embedding1dLayer
from pytorch_tabular.models.common.heads import LinearHeadConfig

Defining a Custom Model

import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import DictConfig
from typing import Dict
from dataclasses import dataclass, field

PyTorch Tabular is very easy to extend and infinitely customizable. All the models implemented in PyTorch Tabular inherit from an abstract class, BaseModel, which is in fact a PyTorch Lightning model.

It handles all the major functions like decoding the config params and setting up the loss and metrics. It also calculates the loss and metrics and feeds them back to the PyTorch Lightning Trainer, which performs the back-propagation.

If we look at the anatomy of a PyTorch Tabular model, there are three main components:

  1. Embedding Layer
  2. Backbone
  3. Head

The Embedding Layer takes the input from the dataloader, which is a dictionary with categorical and continuous tensors under those keys. The Embedding Layer converts this dictionary into a single tensor, handling the categorical tensors the right way. There are two embedding layers already implemented: Embedding1dLayer and Embedding2dLayer.
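To make the input format concrete, here is a small sketch (the dimensions are illustrative and not tied to the dataset above) of what a batch looks like and how Embedding1dLayer collapses it into a single tensor:

import torch
from pytorch_tabular.models.common.layers import Embedding1dLayer

# An illustrative batch: 32 rows, 16 continuous features, 4 categorical features
batch = {
    "continuous": torch.randn(32, 16),
    "categorical": torch.randint(0, 4, (32, 4)),  # 4 levels per categorical column
}
embedding = Embedding1dLayer(
    continuous_dim=16,
    categorical_embedding_dims=[(4, 2)] * 4,  # (cardinality, embedding dim) per column
    embedding_dropout=0.0,
    batch_norm_continuous_input=False,
)
print(embedding(batch).shape)  # torch.Size([32, 24]) = 16 continuous + 4 columns x 2 embedding dims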

The Backbone is the main model architecture.

The Head is the linear layer(s) that take the output of the backbone and convert it into the output we desire.

To encapsulate and enforce this structure, BaseModel requires us to define three property methods:

  1. def embedding_layer(self)
  2. def backbone(self)
  3. def head(self)

There is another method, def _build_network(self), which must also be defined. We can use this method to define the embedding layer, backbone, head, and whatever other layers or components we need in the model, as sketched below.
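Here is a minimal sketch of that contract (SkeletonModel and the ... placeholders are just stand-ins; the full working example follows later in this tutorial):

class SkeletonModel(BaseModel):
    def _build_network(self):
        self._embedding_layer = ...  # converts the input dict into a single tensor
        self._backbone = ...         # the main architecture: representation learning
        self._head = ...             # maps the backbone output to the desired output

    @property
    def embedding_layer(self):
        return self._embedding_layer

    @property
    def backbone(self):
        return self._backbone

    @property
    def head(self):
        return self._head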

For standard feed-forward layers as the head, we can also use a handy method in BaseModel called _get_head_from_config, which will use the head and head_config you have set in the ModelConfig to automatically initialize the right head for you.

An example of using it:

self._head = self._get_head_from_config()

For standard flows, the forward method that is already defined would be enough.

def forward(self, x: Dict):
    x = self.embed_input(x) # Embeds the input dictionary and returns a tensor
    x = self.compute_backbone(x) # Takes the tensor input and does representation learning
    return self.compute_head(x) # transforms the backbone output to desired output, applies target range if necessary, and packs the output in the desired format.

What this allows us to do is define any standard PyTorch model, use it as the backbone, and then use the rest of the PyTorch Tabular machinery to train, evaluate, and log the model.

While this is the bare minimum, you can redefine or use any of the PyTorch Lightning standard methods to tweak your model and training to your liking. The real beauty of this setup is that how much customization you do is really up to you. For standard workflows, you can change the bare minimum. And for highly unusual models, you can redefine any of the methods in BaseModel (as long as the interfaces remain unchanged).

In addition to the model, you will also need to define a config. Configs are Python dataclasses and should inherit from ModelConfig, so they have all the parameters of ModelConfig by default. Any additional parameters should be defined in the dataclass.

Key things to note:

  1. All the parameters from the different configs (like TrainerConfig, OptimizerConfig, etc.) are available in config before calling super() and in self.hparams after.
  2. The input batch at the forward method is a dictionary with keys continuous and categorical.
  3. In the _build_network method, save every component that you want to access in forward to self.
# Define a Config class to hold the hyperparameters of the model
# We need to inherit ModelConfig so that default parameters like learning rate, head, head_config, etc. 
# are present in the config
@dataclass
class MyAwesomeModelConfig(ModelConfig):
    use_batch_norm: bool = True

# Define the core model as a pure PyTorch model
# The forward method takes in a tensor and outputs a tensor

class MyAwesomeRegressionModel(nn.Module):
    def __init__(
        self,
        hparams: DictConfig,
        **kwargs
    ):
        super().__init__()
        self.hparams = hparams # Save the config to be accessed in the model elsewhere
        self._build_network()

    def _build_network(self):
        #Continuous and Categorical Dimensions are precalculated and stored in the config (InferredConfig)
        inp_dim = self.hparams.embedded_cat_dim + self.hparams.continuous_dim
        self.linear_layer_1 = nn.Linear(inp_dim, 200)
        self.linear_layer_2 = nn.Linear(inp_dim+200, 70)
        self.linear_layer_3 = nn.Linear(inp_dim+70, 70)
        self.input_batch_norm = nn.BatchNorm1d(self.hparams.continuous_dim)
        if self.hparams.use_batch_norm:
            self.batch_norm_2 = nn.BatchNorm1d(inp_dim+200)
            self.batch_norm_3 = nn.BatchNorm1d(inp_dim+70)
        self.dropout = nn.Dropout(0.3)
        self.output_dim = 70

    def forward(self, x: torch.Tensor):
        inp = x
        x = F.relu(self.linear_layer_1(x))
        x = self.dropout(x)
        x = torch.cat([x,inp], 1)
        if self.hparams.use_batch_norm:
            x = self.batch_norm_2(x)
        x = F.relu(self.linear_layer_2(x))
        x = self.dropout(x)
        x = torch.cat([x,inp], 1)
        if self.hparams.use_batch_norm:
            x = self.batch_norm_3(x)
        x = self.linear_layer_3(x)
        return x

    # It is also good practice to bundle the embedding scheme along with the core model
    # This is so that the model encapsulates its requirement of how to embed the input
    # dictionary of categorical and continuous tensors and how to combine them into a single tensor.
    # We can either use one of the predefined embedding layers in PyTorch Tabular
    # Or define a custom embedding layer as well. Check Embedding1dLayer implementation 
    # on how to implement an embedding layer
    def _build_embedding_layer(self):
        return Embedding1dLayer(
            continuous_dim=self.hparams.continuous_dim,
            categorical_embedding_dims=self.hparams.embedding_dims,
            embedding_dropout=self.hparams.embedding_dropout,
            batch_norm_continuous_input=self.hparams.batch_norm_continuous_input,
        )
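Because the backbone is a plain nn.Module, we can smoke-test it in isolation before wiring it into PyTorch Tabular. A quick sketch, where embedded_cat_dim and continuous_dim are made-up values standing in for what PyTorch Tabular would normally infer from the data:

from omegaconf import OmegaConf

# Hypothetical dimensions; in a real run these come from InferredConfig
dummy_hparams = OmegaConf.create(
    {"embedded_cat_dim": 8, "continuous_dim": 16, "use_batch_norm": True}
)
backbone = MyAwesomeRegressionModel(dummy_hparams)
x = torch.randn(32, 8 + 16)  # a batch of already-embedded inputs
print(backbone(x).shape)  # torch.Size([32, 70]) -- matches self.output_dim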

# Define the PyTorch Tabular model by inheriting BaseModel

class MyAwesomeRegressionPTModel(BaseModel):
    def __init__(
        self,
        config: DictConfig,
        **kwargs
    ):
        # Save any attribute that you need in _build_network before calling super()
        # After the super() call, the config will be saved as `hparams`
        super().__init__(config, **kwargs)

    def _build_network(self):
        # Backbone - Initialize the PyTorch model we defined earlier
        self._backbone = MyAwesomeRegressionModel(self.hparams)
        # Initializing Embedding Layer using the method we defined in the backbone
        self._embedding_layer = self._backbone._build_embedding_layer()
        # Head - we can use the helper method to initialize standard linear head
        self._head = self._get_head_from_config()

    # Now define the property methods which the BaseModel requires you to override
    @property
    def backbone(self):
        return self._backbone

    @property
    def embedding_layer(self):
        return self._embedding_layer

    @property
    def head(self):
        return self._head

    # For more customizations, we can override forward, compute_backbone, compute_head etc.

Define the Configs

There is one deviation from the normal when we create a TabularModel object with the configs. Previously, the model was inferred from the config and initialized automatically. But here, we have to use the model_callable parameter of TabularModel and pass in the model class (not an initialized object).

data_config = DataConfig(
    target=['target'], #target should always be a list. Multi-targets are only supported for regression. Multi-Task Classification is not implemented
    continuous_cols=num_col_names,
    categorical_cols=cat_col_names,
)
trainer_config = TrainerConfig(
    auto_lr_find=True, # Runs the LRFinder to automatically derive a learning rate
    batch_size=1024,
    max_epochs=100,
    accelerator="auto", # can be 'cpu','gpu', 'tpu', or 'ipu' 
    devices=-1 # -1 means use all available
)
optimizer_config = OptimizerConfig()

head_config = LinearHeadConfig(
    layers="32", # Number of nodes in each layer
    activation="ReLU", # Activation between each layers
    dropout=0.1,
    initialization="kaiming"
).__dict__ # Convert to dict to pass to the model config (OmegaConf doesn't accept objects)

model_config = MyAwesomeModelConfig(
    task="regression",
    use_batch_norm=False,
    head="LinearHead",  # Linear Head
    head_config=head_config,  # Linear Head Config
    learning_rate=1e-3,
)

tabular_model = TabularModel(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config,
    model_callable = MyAwesomeRegressionPTModel # When using custom model, we need to use the model_callable parameter
)

Training the Model

The rest of the process is business-as-usual

tabular_model.fit(train=train, validation=val)
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/utilities/seed.py:48: LightningDeprecationWarning: `pytorch_lightning.utilities.seed.seed_everything` has been deprecated in v1.8.0 and will be removed in v1.10.0. Please use `lightning_lite.utilities.seed.seed_everything` instead.
  rank_zero_deprecation(
Global seed set to 42
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_tabular/categorical_encoders.py:48: FutureWarning: In a future version, `df.iloc[:, i] = newvals` will attempt to set the values inplace instead of always setting a new array. To retain the old behavior, use either `df[df.columns[i]] = newvals` or, if columns are non-unique, `df.isetitem(i, newvals)`
  X_encoded.loc[:, col] = X_encoded[col].fillna(NAN_CATEGORY).map(mapping["value"])
Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:604: UserWarning: Checkpoint directory /home/manujosephv/pytorch_tabular/docs/checkpoints exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(

Finding best initial lr:   0%|          | 0/100 [00:00<?, ?it/s]
`Trainer.fit` stopped: `max_steps=98` reached.
LR finder stopped early after 98 steps due to diverging loss.
Learning rate set to 0.017378008287493765
Restoring states from the checkpoint path at /home/manujosephv/pytorch_tabular/docs/.lr_find_e7471621-02d6-41e1-b21f-a52d621d2e17.ckpt
Restored all states from the checkpoint file at /home/manujosephv/pytorch_tabular/docs/.lr_find_e7471621-02d6-41e1-b21f-a52d621d2e17.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name             | Type                     | Params
--------------------------------------------------------------
0 | _backbone        | MyAwesomeRegressionModel | 28.8 K
1 | _embedding_layer | Embedding1dLayer         | 92    
2 | _head            | LinearHead               | 2.3 K 
3 | loss             | MSELoss                  | 0     
--------------------------------------------------------------
31.2 K    Trainable params
0         Non-trainable params
31.2 K    Total params
0.125     Total estimated model params size (MB)

Sanity Checking: 0it [00:00, ?it/s]
Training: 0it [00:00, ?it/s]
Validation: 0it [00:00, ?it/s]
`Trainer.fit` stopped: `max_epochs=100` reached.
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/utilities/cloud_io.py:41: LightningDeprecationWarning: `pytorch_lightning.utilities.cloud_io.load` has been deprecated in v1.8.0 and will be removed in v1.10.0. This function is internal but you can copy over its implementation.
  rank_zero_deprecation(

<pytorch_lightning.trainer.trainer.Trainer at 0x7fd65349e2c0>
result = tabular_model.evaluate(test)
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_tabular/categorical_encoders.py:48: FutureWarning: In a future version, `df.iloc[:, i] = newvals` will attempt to set the values inplace instead of always setting a new array. To retain the old behavior, use either `df[df.columns[i]] = newvals` or, if columns are non-unique, `df.isetitem(i, newvals)`
  X_encoded.loc[:, col] = X_encoded[col].fillna(NAN_CATEGORY).map(mapping["value"])
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, test_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(

Testing: 0it [00:00, ?it/s]
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric               DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│         test_loss              478.0108642578125     │
│  test_mean_squared_error       478.0108642578125     │
└───────────────────────────┴───────────────────────────┘
pred_df = tabular_model.predict(test)
pred_df.head()
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_tabular/categorical_encoders.py:48: FutureWarning: In a future version, `df.iloc[:, i] = newvals` will attempt to set the values inplace instead of always setting a new array. To retain the old behavior, use either `df[df.columns[i]] = newvals` or, if columns are non-unique, `df.isetitem(i, newvals)`
  X_encoded.loc[:, col] = X_encoded[col].fillna(NAN_CATEGORY).map(mapping["value"])

Generating Predictions...:   0%|          | 0/3 [00:00<?, ?it/s]
num_col_0 cat_col_1 num_col_2 cat_col_3 num_col_4 num_col_5 num_col_6 cat_col_7 num_col_8 num_col_9 ... num_col_12 cat_col_13 num_col_14 num_col_15 num_col_16 num_col_17 num_col_18 num_col_19 target target_prediction
6252 0.321476 0.0 -0.836426 1.0 -1.372801 0.148776 1.607678 0.0 0.099704 2.494107 ... 0.543315 3.0 -0.719307 -0.416442 -0.843505 0.571309 0.150040 -0.636704 -119.618988 -102.397652
4684 0.291679 1.0 -0.213108 3.0 1.209858 -0.684209 0.065715 0.0 -2.164594 -1.212303 ... -0.053648 3.0 -1.304521 -1.007395 0.304803 0.262136 -0.638452 0.672491 -207.596232 -171.570618
1731 -1.547951 1.0 1.517188 1.0 2.356890 0.826815 -0.570187 1.0 0.787585 0.027579 ... 0.154850 1.0 0.221052 1.993319 0.028488 -0.347405 1.121574 -0.146075 272.098656 236.952728
4742 0.911628 3.0 0.089328 1.0 0.984190 -1.114405 0.594178 0.0 -0.994555 -0.379163 ... 0.426304 1.0 -0.775785 -1.001061 -0.725295 -1.177682 -0.511682 -0.721897 21.896867 21.157124
4521 0.087945 2.0 -0.320962 1.0 0.423397 -0.512270 -0.314670 1.0 -0.386701 0.966912 ... 1.021263 3.0 0.851912 1.296487 1.079245 1.077255 0.327339 -0.365532 46.346326 43.655518

5 rows × 22 columns

print_metrics(test['target'], pred_df["target_prediction"], tag="Holdout")
Holdout MSE: 478.0109113972704 | Holdout MAE: 17.182815680793645