Model Stacking in PyTorch Tabular

This page demonstrates how to use the model stacking functionality in PyTorch Tabular to combine several models into a single predictor, which can improve predictions over any one of them.

Setup and Imports

import warnings
warnings.filterwarnings("ignore")
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from pytorch_tabular import TabularModel
from pytorch_tabular.models import (
    CategoryEmbeddingModelConfig,
    FTTransformerConfig,
    TabNetModelConfig,
)
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
from pytorch_tabular.models.stacking import StackingModelConfig
from pytorch_tabular.utils import make_mixed_dataset

Create a synthetic classification dataset and split it into train, validation, and test sets

data, cat_col_names, num_col_names = make_mixed_dataset(
    task="classification", n_samples=3000, n_features=7, n_categories=4
)

train, test = train_test_split(data, random_state=42)
train, valid = train_test_split(train, random_state=42)
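
Neither call to train_test_split passes test_size, so scikit-learn's default of 0.25 applies to both splits. The sketch below reproduces the resulting set sizes with plain arithmetic. split_sizes is a hypothetical helper written for this illustration, not part of scikit-learn or PyTorch Tabular, and it assumes the default rounding, where the held-out count is rounded up.

```python
import math


def split_sizes(n_samples, test_size=0.25):
    """Return (n_train, n_test) for a fractional hold-out split.

    Assumption: mirrors scikit-learn's default behaviour, where the
    held-out part is ceil(test_size * n_samples) and the rest is kept.
    """
    n_test = math.ceil(test_size * n_samples)
    return n_samples - n_test, n_test


# First split: 3000 rows -> (train + valid) / test
n_rest, n_test = split_sizes(3000)
# Second split: remaining rows -> train / valid
n_train, n_valid = split_sizes(n_rest)

print(n_train, n_valid, n_test)  # 1687 563 750
```

With 750 test rows, a single misclassified row moves the test accuracy by about 0.13 percentage points, which is worth keeping in mind when comparing models later.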

Common configurations

data_config = DataConfig(
    target=["target"],
    continuous_cols=num_col_names,
    categorical_cols=cat_col_names,
)
trainer_config = TrainerConfig(
    batch_size=1024,
    max_epochs=20,
    early_stopping="valid_accuracy",
    early_stopping_mode="max",
    early_stopping_patience=3,
    checkpoints="valid_accuracy",
    load_best=True,
)
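
The early-stopping settings above say: watch valid_accuracy, treat higher as better, and stop once three consecutive epochs fail to improve on the best value. The following is a minimal sketch of that rule in plain Python, not the actual PyTorch Lightning callback. early_stopping_epoch and the accuracy values are hypothetical, and details such as a minimum improvement threshold are left out.

```python
def early_stopping_epoch(scores, patience=3, mode="max"):
    """Return (stop_epoch, best_epoch, best_score) for per-epoch scores.

    Simplified rule: an epoch counts as an improvement only if it is
    strictly better than the best score seen so far (no minimum delta).
    """
    sign = 1.0 if mode == "max" else -1.0
    best_score, best_epoch, wait = None, None, 0
    for epoch, score in enumerate(scores):
        if best_score is None or sign * score > sign * best_score:
            best_score, best_epoch, wait = score, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                # Patience exhausted: stop training here
                return epoch, best_epoch, best_score
    # Never ran out of patience: training ends at the last epoch
    return len(scores) - 1, best_epoch, best_score


# Hypothetical validation accuracies, one per epoch
val_acc = [0.50, 0.55, 0.54, 0.55, 0.53, 0.60]
stop, best, score = early_stopping_epoch(val_acc, patience=3, mode="max")
print(stop, best, score)  # 4 1 0.55
```

In this toy run training stops at epoch 4, so the later value of 0.60 is never seen. With load_best=True, the checkpoint from the best epoch is restored after training.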
optimizer_config = OptimizerConfig()

Configure individual models

model_config_1 = CategoryEmbeddingModelConfig(
    task="classification",
    layers="128-64-32",
    activation="ReLU",
    learning_rate=1e-3
)
model_config_2 = FTTransformerConfig(
    task="classification",
    input_embed_dim=32,
    num_attn_blocks=2,
    num_heads=4,
    learning_rate=1e-3
)
model_config_3 = TabNetModelConfig(
    task="classification",
    n_d=8,
    n_a=8,
    n_steps=3,
    learning_rate=1e-3
)
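
The layers string of the Category Embedding model lists the hidden-layer widths separated by dashes. As a rough illustration of what "128-64-32" describes, the sketch below parses the string and counts the weights and biases of a chain of fully connected layers. parse_layers and dense_param_count are hypothetical helpers, not PyTorch Tabular functions. The count ignores embeddings, batch normalization, and dropout, and the two-class output is an assumption.

```python
def parse_layers(layers):
    """Turn a dash-separated string such as "128-64-32" into [128, 64, 32]."""
    return [int(width) for width in layers.split("-")]


def dense_param_count(in_dim, widths):
    """Count weights and biases of a chain of fully connected layers."""
    total = 0
    for out_dim in widths:
        total += in_dim * out_dim + out_dim  # weight matrix + bias vector
        in_dim = out_dim
    return total


hidden = parse_layers("128-64-32")
print(hidden)  # [128, 64, 32]

# A linear head mapping the last hidden width (32) to 2 classes (assumed)
print(dense_param_count(hidden[-1], [2]))  # 66
```

The 66 parameters agree with what the training log further down reports for the LinearHead of the Category Embedding model.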

Configure Stacking Model

Now let's set up the stacking configuration that will combine these models:

stacking_config = StackingModelConfig(
    task="classification",
    model_configs=[
        model_config_1,
        model_config_2,
        model_config_3
    ],
    head="LinearHead",
    head_config={
        "layers": "64",
        "activation": "ReLU",
        "dropout": 0.1
    },
    learning_rate=1e-3
)
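
Conceptually, the stacking model passes each input row through every base model's backbone and hands the combined features to a shared head. The sketch below illustrates the idea in plain Python under two assumptions: that the backbone outputs are concatenated, and that the head is a single linear layer (the head configured above also has a hidden layer of width 64, ReLU, and dropout). All names and numbers are hypothetical toy values, not PyTorch Tabular internals.

```python
def concat_outputs(backbone_outputs):
    """Concatenate per-model feature vectors into one long vector."""
    return [value for output in backbone_outputs for value in output]


def linear(x, weights, bias):
    """Plain fully connected layer: one output per row of `weights`."""
    return [sum(w * v for w, v in zip(row, x)) + b for row, b in zip(weights, bias)]


# Hypothetical backbone outputs for one input row (toy sizes)
out_category_embedding = [1.0, 2.0]
out_ft_transformer = [0.5]
out_tabnet = [3.0, -1.0]

stacked = concat_outputs([out_category_embedding, out_ft_transformer, out_tabnet])
print(stacked)  # [1.0, 2.0, 0.5, 3.0, -1.0]

# Toy head: 5 stacked features -> 2 class logits
weights = [[1.0, 0.0, 0.0, 0.0, 0.0],
           [0.0, 1.0, 2.0, 0.0, 1.0]]
bias = [0.5, 0.0]
logits = linear(stacked, weights, bias)
print(logits)  # [1.5, 2.0]
```

Because the head sees the features of all base models at once, it can learn to weight them differently, which is why the head's input width grows with the number and size of the base models.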

Train Stacking Model

stacking_model = TabularModel(
    data_config=data_config,
    model_config=stacking_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config,
)
stacking_model.fit(
    train=train,
    validation=valid
)
2024-12-12 00:02:35,338 - {pytorch_tabular.tabular_model:147} - INFO - Experiment Tracking is turned off           
Seed set to 42

2024-12-12 00:02:35,388 - {pytorch_tabular.tabular_model:549} - INFO - Preparing the DataLoaders                   
2024-12-12 00:02:35,394 - {pytorch_tabular.tabular_datamodule:527} - INFO - Setting up the datamodule for          
classification task                                                                                                
2024-12-12 00:02:35,462 - {pytorch_tabular.tabular_model:600} - INFO - Preparing the Model: StackingModel          
2024-12-12 00:02:35,516 - {pytorch_tabular.tabular_model:343} - INFO - Preparing the Trainer                       
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

2024-12-12 00:02:35,813 - {pytorch_tabular.tabular_model:679} - INFO - Training Started                            
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━┓
┃   ┃ Name             ┃ Type                   ┃ Params ┃ Mode  ┃
┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━┩
│ 0 │ _backbone        │ StackingBackbone       │ 77.2 K │ train │
│ 1 │ _embedding_layer │ StackingEmbeddingLayer │    917 │ train │
│ 2 │ _head            │ LinearHead             │ 12.5 K │ train │
│ 3 │ loss             │ CrossEntropyLoss       │      0 │ train │
└───┴──────────────────┴────────────────────────┴────────┴───────┘
Trainable params: 90.6 K                                                                                           
Non-trainable params: 0                                                                                            
Total params: 90.6 K                                                                                               
Total estimated model params size (MB): 0                                                                          
Modules in train mode: 188                                                                                         
Modules in eval mode: 0                                                                                            

2024-12-12 00:02:39,304 - {pytorch_tabular.tabular_model:692} - INFO - Training the model completed                
2024-12-12 00:02:39,307 - {pytorch_tabular.tabular_model:1533} - INFO - Loading the best model                     
<pytorch_lightning.trainer.trainer.Trainer at 0x7fb1a508d420>

Evaluate Results

predictions = stacking_model.predict(test)
stacking_metrics = stacking_model.evaluate(test)[0]
stacking_acc = stacking_metrics["test_accuracy"]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric        ┃       DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       test_accuracy       │    0.5960000157356262     │
│         test_loss         │    0.7419928312301636     │
│        test_loss_0        │    0.7419928312301636     │
└───────────────────────────┴───────────────────────────┘

Compare with individual models

def train_and_evaluate_model(model_config, name):
    """Train one base model on its own and return its test metrics."""
    model = TabularModel(
        data_config=data_config,
        model_config=model_config,
        optimizer_config=optimizer_config,
        trainer_config=trainer_config,
    )
    model.fit(train=train, validation=valid)
    metrics = model.evaluate(test)
    print(f"\n{name} Metrics:")
    print(metrics)
    return metrics


# evaluate() returns a list of metric dicts; take the first entry
ce_metrics = train_and_evaluate_model(model_config_1, "Category Embedding")[0]
ft_metrics = train_and_evaluate_model(model_config_2, "FT Transformer")[0]
tab_metrics = train_and_evaluate_model(model_config_3, "TabNet")[0]

ce_acc = ce_metrics["test_accuracy"]
ft_acc = ft_metrics["test_accuracy"]
tab_acc = tab_metrics["test_accuracy"]
2024-12-12 00:09:01,257 - {pytorch_tabular.tabular_model:147} - INFO - Experiment Tracking is turned off           
Seed set to 42

2024-12-12 00:09:01,320 - {pytorch_tabular.tabular_model:549} - INFO - Preparing the DataLoaders                   
2024-12-12 00:09:01,340 - {pytorch_tabular.tabular_datamodule:527} - INFO - Setting up the datamodule for          
classification task                                                                                                
2024-12-12 00:09:01,376 - {pytorch_tabular.tabular_model:600} - INFO - Preparing the Model: CategoryEmbeddingModel 
2024-12-12 00:09:01,411 - {pytorch_tabular.tabular_model:343} - INFO - Preparing the Trainer                       
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

2024-12-12 00:09:01,638 - {pytorch_tabular.tabular_model:679} - INFO - Training Started                            
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━┓
┃   ┃ Name             ┃ Type                      ┃ Params ┃ Mode  ┃
┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━┩
│ 0 │ _backbone        │ CategoryEmbeddingBackbone │ 12.1 K │ train │
│ 1 │ _embedding_layer │ Embedding1dLayer          │     53 │ train │
│ 2 │ head             │ LinearHead                │     66 │ train │
│ 3 │ loss             │ CrossEntropyLoss          │      0 │ train │
└───┴──────────────────┴───────────────────────────┴────────┴───────┘
Trainable params: 12.2 K                                                                                           
Non-trainable params: 0                                                                                            
Total params: 12.2 K                                                                                               
Total estimated model params size (MB): 0                                                                          
Modules in train mode: 19                                                                                          
Modules in eval mode: 0                                                                                            
`Trainer.fit` stopped: `max_epochs=20` reached.


2024-12-12 00:09:04,935 - {pytorch_tabular.tabular_model:692} - INFO - Training the model completed                
2024-12-12 00:09:04,938 - {pytorch_tabular.tabular_model:1533} - INFO - Loading the best model                     
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric        ┃       DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       test_accuracy       │    0.4586666524410248     │
│         test_loss         │    0.8828091025352478     │
│        test_loss_0        │    0.8828091025352478     │
└───────────────────────────┴───────────────────────────┘


Category Embedding Metrics:
[{'test_loss_0': 0.8828091025352478, 'test_loss': 0.8828091025352478, 'test_accuracy': 0.4586666524410248}]

2024-12-12 00:09:05,183 - {pytorch_tabular.tabular_model:147} - INFO - Experiment Tracking is turned off           
Seed set to 42

2024-12-12 00:09:05,263 - {pytorch_tabular.tabular_model:549} - INFO - Preparing the DataLoaders                   
2024-12-12 00:09:05,272 - {pytorch_tabular.tabular_datamodule:527} - INFO - Setting up the datamodule for          
classification task                                                                                                
2024-12-12 00:09:05,294 - {pytorch_tabular.tabular_model:600} - INFO - Preparing the Model: FTTransformerModel     
2024-12-12 00:09:05,323 - {pytorch_tabular.tabular_model:343} - INFO - Preparing the Trainer                       
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

2024-12-12 00:09:05,623 - {pytorch_tabular.tabular_model:679} - INFO - Training Started                            
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━┓
┃   ┃ Name             ┃ Type                  ┃ Params ┃ Mode  ┃
┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━┩
│ 0 │ _backbone        │ FTTransformerBackbone │ 57.7 K │ train │
│ 1 │ _embedding_layer │ Embedding2dLayer      │    864 │ train │
│ 2 │ _head            │ LinearHead            │     66 │ train │
│ 3 │ loss             │ CrossEntropyLoss      │      0 │ train │
└───┴──────────────────┴───────────────────────┴────────┴───────┘
Trainable params: 58.6 K                                                                                           
Non-trainable params: 0                                                                                            
Total params: 58.6 K                                                                                               
Total estimated model params size (MB): 0                                                                          
Modules in train mode: 56                                                                                          
Modules in eval mode: 0                                                                                            

2024-12-12 00:09:07,482 - {pytorch_tabular.tabular_model:692} - INFO - Training the model completed                
2024-12-12 00:09:07,488 - {pytorch_tabular.tabular_model:1533} - INFO - Loading the best model                     
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric        ┃       DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       test_accuracy       │    0.5546666383743286     │
│         test_loss         │    0.6846821904182434     │
│        test_loss_0        │    0.6846821904182434     │
└───────────────────────────┴───────────────────────────┘


FT Transformer Metrics:
[{'test_loss_0': 0.6846821904182434, 'test_loss': 0.6846821904182434, 'test_accuracy': 0.5546666383743286}]

2024-12-12 00:09:07,824 - {pytorch_tabular.tabular_model:147} - INFO - Experiment Tracking is turned off           
Seed set to 42

2024-12-12 00:09:07,863 - {pytorch_tabular.tabular_model:549} - INFO - Preparing the DataLoaders                   
2024-12-12 00:09:07,870 - {pytorch_tabular.tabular_datamodule:527} - INFO - Setting up the datamodule for          
classification task                                                                                                
2024-12-12 00:09:07,900 - {pytorch_tabular.tabular_model:600} - INFO - Preparing the Model: TabNetModel            
2024-12-12 00:09:07,965 - {pytorch_tabular.tabular_model:343} - INFO - Preparing the Trainer                       
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

2024-12-12 00:09:08,200 - {pytorch_tabular.tabular_model:679} - INFO - Training Started                            
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━┓
┃   ┃ Name             ┃ Type             ┃ Params ┃ Mode  ┃
┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━┩
│ 0 │ _embedding_layer │ Identity         │      0 │ train │
│ 1 │ _backbone        │ TabNetBackbone   │  6.4 K │ train │
│ 2 │ _head            │ Identity         │      0 │ train │
│ 3 │ loss             │ CrossEntropyLoss │      0 │ train │
└───┴──────────────────┴──────────────────┴────────┴───────┘
Trainable params: 6.4 K                                                                                            
Non-trainable params: 0                                                                                            
Total params: 6.4 K                                                                                                
Total estimated model params size (MB): 0                                                                          
Modules in train mode: 111                                                                                         
Modules in eval mode: 0                                                                                            

2024-12-12 00:09:09,766 - {pytorch_tabular.tabular_model:692} - INFO - Training the model completed                
2024-12-12 00:09:09,767 - {pytorch_tabular.tabular_model:1533} - INFO - Loading the best model                     
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric        ┃       DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       test_accuracy       │    0.4346666634082794     │
│         test_loss         │    1.1570961475372314     │
│        test_loss_0        │    1.1570961475372314     │
└───────────────────────────┴───────────────────────────┘


TabNet Metrics:
[{'test_loss_0': 1.1570961475372314, 'test_loss': 1.1570961475372314, 'test_accuracy': 0.4346666634082794}]

print("Stacking Model Test Accuracy: {}".format(stacking_acc))
print("Category Embedding Model Test Accuracy: {}".format(ce_acc))
print("FT Transformer Model Test Accuracy: {}".format(ft_acc))
print("TabNet Model Test Accuracy: {}".format(tab_acc))
Stacking Model Test Accuracy: 0.5960000157356262
Category Embedding Model Test Accuracy: 0.4586666524410248
FT Transformer Model Test Accuracy: 0.5546666383743286
TabNet Model Test Accuracy: 0.4346666634082794
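
To put the four numbers side by side, the short sketch below finds the best base model and the margin by which the stacking model beats it. The accuracies are copied from the output above. The variable names are introduced here for illustration only. Because the figures come from a single run with a single seed, the margin is indicative rather than a robust estimate.

```python
# Test accuracies copied from the output above (single run, seed 42)
accuracies = {
    "Stacking": 0.5960000157356262,
    "Category Embedding": 0.4586666524410248,
    "FT Transformer": 0.5546666383743286,
    "TabNet": 0.4346666634082794,
}

# Compare the stacking model against the strongest individual model
base = {name: acc for name, acc in accuracies.items() if name != "Stacking"}
best_base = max(base, key=base.get)
gain = accuracies["Stacking"] - base[best_base]

print(best_base)       # FT Transformer
print(round(gain, 4))  # 0.0413
```

Here the stacking model is about 4.1 percentage points ahead of the best base model.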

Save the stacking model & load it

stacking_model.save_model("stacking_model")
2024-12-12 00:00:31,524 - {pytorch_tabular.tabular_model:1579} - WARNING - Directory is not empty. Overwriting the 
contents.                                                                                                          
loaded_model = TabularModel.load_model("stacking_model")
2024-12-12 00:00:32,437 - {pytorch_tabular.tabular_model:172} - INFO - Experiment Tracking is turned off           
2024-12-12 00:00:32,452 - {pytorch_tabular.tabular_model:343} - INFO - Preparing the Trainer                       
Trainer already configured with model summary callbacks: [<class 'pytorch_lightning.callbacks.rich_model_summary.RichModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

Key Points About Stacking

  1. The stacking model combines predictions from multiple base models into a final prediction
  2. Each base model can have its own architecture and hyperparameters
  3. The head layer combines the outputs from all base models
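  4. Base models are trained simultaneously, in the same fit() call as the head
  5. The stacking model can often achieve better performance than individual models; in the run above it reached a test accuracy of 0.596, against 0.555 for the best base model (FT Transformer)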

Tips for Better Stacking Results

  1. Use diverse base models that capture different aspects of the data
  2. Experiment with different head architectures
  3. Consider using cross-validation for more robust stacking
  4. Balance model complexity with training time
  5. Monitor individual model performances to ensure they contribute meaningfully

This example demonstrates basic stacking functionality. For production use cases, you may want to:

  - Use cross-validation
  - Implement more sophisticated ensemble techniques
  - Add custom metrics
  - Tune hyperparameters for both base models and stacking head
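
Cross-validation, recommended above, replaces the single train/validation split with several, so that every row is used for validation exactly once. The sketch below shows only the index bookkeeping for k folds in plain Python. kfold_indices is a hypothetical helper, not a PyTorch Tabular or scikit-learn function, and it does not shuffle or stratify, which a real setup for classification would usually want.

```python
def kfold_indices(n_samples, k):
    """Yield (train_indices, valid_indices) for k contiguous folds.

    The first n_samples % k folds get one extra row, so fold sizes
    differ by at most one.
    """
    indices = list(range(n_samples))
    start = 0
    for fold in range(k):
        size = n_samples // k + (1 if fold < n_samples % k else 0)
        valid = indices[start:start + size]
        train = indices[:start] + indices[start + size:]
        yield train, valid
        start += size


# Toy example: 10 rows, 3 folds
for train_idx, valid_idx in kfold_indices(10, 3):
    print(len(train_idx), valid_idx)
# 6 [0, 1, 2, 3]
# 7 [4, 5, 6]
# 7 [7, 8, 9]
```

Each pair of index lists could then select the rows for one fit and one evaluation of the stacking model, and the per-fold accuracies can be averaged for a more stable estimate than a single split provides.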