
Using Model Sweep as an initial Model Selection Tool

Prerequisites: Basic knowledge of Deep Learning and of tabular problems like regression and classification. You should also go through the Approaching Any Tabular Problem with PyTorch Tabular tutorial first.
Level: Intermediate

In this tutorial, we will look at an easy way to assess the performance of different Deep Learning models in PyTorch Tabular on a dataset, similar in spirit to a PyCaret-style model comparison. In PyTorch Tabular, we call this Model Sweep.

from rich import print
from rich.pretty import pprint

Data

We will use the Covertype dataset from the UCI ML Repository and split it into train and test sets. We could also carve out a validation set, but even if we don't, PyTorch Tabular will automatically split one off from the train set.

from pytorch_tabular.utils import load_covertype_dataset
from sklearn.model_selection import train_test_split

data, cat_col_names, num_col_names, target_col = load_covertype_dataset()
train, test = train_test_split(data, random_state=42, test_size=0.2)
print(f"Train Shape: {train.shape} | Test Shape: {test.shape}")
Train Shape: (464809, 13) | Test Shape: (116203, 13)
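The printed shapes follow directly from the 20% test split. The sketch below reproduces the arithmetic, assuming (as we understand scikit-learn's `train_test_split` to behave for a float `test_size`) that the test count is rounded up and the train set receives the remaining rows. The helper `split_sizes` is hypothetical and only illustrates the rounding.

```python
import math


def split_sizes(n_samples: int, test_size: float) -> tuple:
    """Return (n_train, n_test) for a float test_size.

    Assumption: the test count is rounded up and the train set gets the rest.
    """
    n_test = math.ceil(test_size * n_samples)  # 0.2 * 581012 = 116202.4 -> 116203
    n_train = n_samples - n_test
    return n_train, n_test


n_total = 464809 + 116203  # total rows, reconstructed from the shapes printed above
print(n_total)                    # 581012
print(split_sizes(n_total, 0.2))  # (464809, 116203)
```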

Defining the Config

As you saw in the basic tutorial, we need to define a set of configs. Even for a model sweep, we need to define all configs except the ModelConfig. We will keep most of them at their defaults, but set a few options to control the training process:
  • Automatic learning rate finding
  • Batch size
  • Max epochs
  • Checkpointing on the validation loss and loading the best checkpoint after training
  • Turning off the progress bar and model summary so that they don't clutter the output

from pytorch_tabular.config import (
    DataConfig,
    OptimizerConfig,
    TrainerConfig,
)
from pytorch_tabular.models.common.heads import LinearHeadConfig

data_config = DataConfig(
    target=[target_col],
    continuous_cols=num_col_names,
    categorical_cols=cat_col_names,
)
trainer_config = TrainerConfig(
    batch_size=1024,
    max_epochs=25,
    auto_lr_find=True,
    early_stopping=None,  # Early stopping disabled; set to "valid_loss" to monitor the validation loss
    # early_stopping_mode="min",  # Set the mode as min because for val_loss, lower is better
    # early_stopping_patience=5,  # No. of epochs of degradation training will wait before terminating
    checkpoints="valid_loss",  # Save best checkpoint monitoring val_loss
    load_best=True,  # After training, load the best checkpoint
    progress_bar="none",  # Turning off Progress bar
    trainer_kwargs=dict(enable_model_summary=False),  # Turning off model summary
    accelerator="cpu",
)
optimizer_config = OptimizerConfig()

head_config = LinearHeadConfig(
    layers="",  # No additional layer in head, just a mapping layer to output_dim
    dropout=0.1,
    initialization="kaiming",
).__dict__  # Convert to dict to pass to the model config (OmegaConf doesn't accept objects)
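To see what `.__dict__` does here, the sketch below uses a hypothetical dataclass, `HeadConfigSketch`, as a stand-in for the head config. It assumes `LinearHeadConfig` behaves like a standard Python dataclass; the field names are illustrative only. Note that `__dict__` is a live view of the instance's attributes, while `dataclasses.asdict` returns an independent copy.

```python
from dataclasses import asdict, dataclass


@dataclass
class HeadConfigSketch:
    """Hypothetical stand-in for a head config; NOT pytorch_tabular's LinearHeadConfig."""
    layers: str = ""
    dropout: float = 0.0
    initialization: str = "kaiming"


cfg = HeadConfigSketch(layers="", dropout=0.1, initialization="kaiming")

# __dict__ exposes the fields as a plain dict (a live view, not a copy)
as_plain_dict = cfg.__dict__
print(as_plain_dict)  # {'layers': '', 'dropout': 0.1, 'initialization': 'kaiming'}

# dataclasses.asdict builds an independent copy instead
copied = asdict(cfg)
copied["dropout"] = 0.5
print(cfg.dropout)  # 0.1, the original config is untouched
```

A plain dict like this is something a config system such as OmegaConf can store, which is the reason the tutorial converts the object.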

Model Sweep

The model sweep enables you to quickly sweep through different models and configurations. It takes either a list of model configs or the name of one of the presets defined in pytorch_tabular.MODEL_SWEEP_PRESETS, trains each model on the data, ranks the models by the metric you provide, and returns the results along with the best model.
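Conceptually, a sweep is a loop: fit every candidate, score each on the test set, rank, and return the results together with the best fitted model. The plain-Python sketch below illustrates only that idea and is not PyTorch Tabular's implementation; the toy data and the candidate "models" (`fit_majority`, `fit_threshold`) are hypothetical.

```python
from typing import Callable, Dict, List, Tuple

Data = List[Tuple[float, int]]  # (feature, label) pairs
Predictor = Callable[[float], int]

# Hypothetical toy data for a 1-D binary classification problem
TRAIN: Data = [(0.1, 0), (0.2, 0), (0.3, 0), (0.4, 0), (0.7, 1), (0.8, 1), (0.9, 1)]
TEST: Data = [(0.15, 0), (0.35, 0), (0.45, 0), (0.6, 1), (0.85, 1)]


def fit_majority(train: Data) -> Predictor:
    """Baseline: always predict the most frequent training label."""
    labels = [y for _, y in train]
    majority = max(set(labels), key=labels.count)
    return lambda x: majority


def fit_threshold(train: Data) -> Predictor:
    """Predict 1 when x lies above the midpoint of the two class means."""
    zeros = [x for x, y in train if y == 0]
    ones = [x for x, y in train if y == 1]
    cut = (sum(zeros) / len(zeros) + sum(ones) / len(ones)) / 2
    return lambda x: int(x > cut)


def accuracy(model: Predictor, data: Data) -> float:
    return sum(model(x) == y for x, y in data) / len(data)


def sweep(candidates: Dict[str, Callable[[Data], Predictor]], train: Data, test: Data):
    """Fit every candidate, score it on the test set, rank, and return the best."""
    rows = []
    for name, fit in candidates.items():
        model = fit(train)  # train on the training data only
        rows.append({"model": name, "test_accuracy": accuracy(model, test), "_fitted": model})
    rows.sort(key=lambda r: r["test_accuracy"], reverse=True)  # higher is better
    best_model = rows[0]["_fitted"]
    results = [{k: v for k, v in r.items() if k != "_fitted"} for r in rows]
    return results, best_model


results, best = sweep({"majority": fit_majority, "threshold": fit_threshold}, TRAIN, TEST)
print(results)  # [{'model': 'threshold', 'test_accuracy': 1.0}, {'model': 'majority', 'test_accuracy': 0.6}]
```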

These are the major arguments to the model_sweep function:
  • task: The type of prediction task, either "classification" or "regression".
  • train: The training data.
  • test: The test data on which performance is evaluated.
  • Configs: All the config objects (data_config, optimizer_config, trainer_config) can be passed either as objects or as paths to YAML files.
  • model_list: The models to compare. This can be the name of one of the presets defined in pytorch_tabular.MODEL_SWEEP_PRESETS or a list of ModelConfig objects.

There are four presets defined in pytorch_tabular.MODEL_SWEEP_PRESETS:

from pytorch_tabular import MODEL_SWEEP_PRESETS

print(list(MODEL_SWEEP_PRESETS.keys()))
['lite', 'standard', 'full', 'high_memory']
  1. lite: A set of models that are fast to train. This is the default value for model_list. The models and their hyperparameters are chosen so that they have a comparable number of parameters, train relatively fast, and give good results. The models included are:
pprint(MODEL_SWEEP_PRESETS["lite"])
(
('CategoryEmbeddingModelConfig', {'layers': '256-128-64'}),
('GANDALFConfig', {'gflu_stages': 6}),
('TabNetModelConfig', {'n_d': 32, 'n_a': 32, 'n_steps': 3, 'gamma': 1.5, 'n_independent': 1, 'n_shared': 2})
)
  2. standard: A set of models with roughly 100 thousand learnable parameters or fewer, so that memory requirements stay modest. All the models from the lite preset are also included. The models and their hyperparameters are chosen so that they have a comparable number of parameters and give good results. The models included are:
pprint(MODEL_SWEEP_PRESETS["standard"])
(
('CategoryEmbeddingModelConfig', {'layers': '256-128-64'}),
('CategoryEmbeddingModelConfig', {'layers': '512-128-64'}),
('GANDALFConfig', {'gflu_stages': 6}),
('GANDALFConfig', {'gflu_stages': 15}),
('TabNetModelConfig', {'n_d': 32, 'n_a': 32, 'n_steps': 3, 'gamma': 1.5, 'n_independent': 1, 'n_shared': 2}),
('TabNetModelConfig', {'n_d': 32, 'n_a': 32, 'n_steps': 5, 'gamma': 1.5, 'n_independent': 2, 'n_shared': 3}),
('FTTransformerConfig', {'num_heads': 4, 'num_attn_blocks': 4})
)
  3. full: A full sweep, with default hyperparameters, of the models implemented in PyTorch Tabular, except for Mixture Density Networks (a specialized model for probabilistic regression) and NODE (which requires a lot of compute and memory). The models included are:
pprint(list(MODEL_SWEEP_PRESETS["full"]))
[
'AutoIntConfig',
'CategoryEmbeddingModelConfig',
'DANetConfig',
'FTTransformerConfig',
'GANDALFConfig',
'GatedAdditiveTreeEnsembleConfig',
'TabNetModelConfig',
'TabTransformerConfig'
]
  4. high_memory: A full sweep, with default hyperparameters, of the models implemented in PyTorch Tabular, except for Mixture Density Networks (a specialized model for probabilistic regression). This option is only recommended if you have ample memory to hold the model and data on your CPU/GPU. The models included are:
pprint(list(MODEL_SWEEP_PRESETS["high_memory"]))
[
'AutoIntConfig',
'CategoryEmbeddingModelConfig',
'DANetConfig',
'FTTransformerConfig',
'GANDALFConfig',
'GatedAdditiveTreeEnsembleConfig',
'NodeConfig',
'TabNetModelConfig',
'TabTransformerConfig'
]
  • metrics, metrics_params, metrics_prob_input: The metrics to use for evaluation. These parameters have the same meaning as in the ModelConfig.
  • rank_metric: The metric used to rank the models. This is a tuple whose first element is the metric name and whose second element is the direction, either "lower_is_better" or "higher_is_better". Defaults to ("loss", "lower_is_better").
  • return_best_model: If True, the best model is returned. Defaults to True.
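The sketch below shows how a (metric_name, direction) tuple like rank_metric can drive the ordering, using the test numbers from the lite sweep reported later in this tutorial. The function `rank_by` is a hypothetical illustration, not the library's code; it assumes metric columns are named test_<metric>, as described for sweep_df.

```python
# Test-set numbers for the "lite" sweep, copied from the results table in this tutorial
lite_results = [
    {"model": "CategoryEmbeddingModel", "test_loss": 0.302084, "test_accuracy": 0.878024},
    {"model": "GANDALFModel", "test_loss": 0.189933, "test_accuracy": 0.924494},
    {"model": "TabNetModel", "test_loss": 0.259448, "test_accuracy": 0.895175},
]

DIRECTIONS = ("lower_is_better", "higher_is_better")


def rank_by(results, rank_metric=("loss", "lower_is_better")):
    """Order result rows by a (metric_name, direction) tuple, best first."""
    metric, direction = rank_metric
    if direction not in DIRECTIONS:
        raise ValueError(f"direction must be one of {DIRECTIONS}, got {direction!r}")
    column = f"test_{metric}"  # assumed naming: metrics stored as test_<metric>
    return sorted(results, key=lambda row: row[column], reverse=(direction == "higher_is_better"))


print([r["model"] for r in rank_by(lite_results, ("accuracy", "higher_is_better"))])
print([r["model"] for r in rank_by(lite_results)])  # default: lowest test_loss first
```

On this data both rankings agree, which matches the observation below that one model wins on accuracy and loss alike.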

Now let's try and run the sweep on the Covertype dataset, using the lite preset.

%%time
from pytorch_tabular import model_sweep
import warnings

# Filtering out the warnings
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    sweep_df, best_model = model_sweep(
        task="classification",  # One of "classification", "regression"
        train=train,
        test=test,
        data_config=data_config,
        optimizer_config=optimizer_config,
        trainer_config=trainer_config,
        model_list="lite",
        common_model_args=dict(head="LinearHead", head_config=head_config),
        metrics=["accuracy", "f1_score"],
        metrics_params=[{}, {"average": "macro"}],
        metrics_prob_input=[False, True],
        rank_metric=("accuracy", "higher_is_better"),
        progress_bar=True,
        verbose=False,
        suppress_lightning_logger=True,
    )

CPU times: user 2h 29min 42s, sys: 15.8 s, total: 2h 29min 58s
Wall time: 16min 37s

The output, sweep_df, is a pandas DataFrame with the following columns:
  • model: The name of the model
  • # Params: The number of trainable parameters in the model (T = thousands)
  • test_loss: The loss on the test set
  • test_<metric>: The metric value on the test set
  • time_taken: The time taken to train the model
  • epochs: The number of epochs trained
  • time_taken_per_epoch: The time taken per epoch
  • params: The config used to train the model
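As a rough illustration of where counts in the # Params column come from, the sketch below counts the weights and biases of the dense layers described by a layers string like those in the presets. It assumes a hypothetical input width of 50 and ignores embeddings, batch norm, and the output head, so it will not reproduce the table's numbers; it only shows how widening the first hidden layer from 256 to 512 grows the count. `dense_params` is a hypothetical helper.

```python
def dense_params(in_features: int, layers: str) -> int:
    """Count weights and biases of fully connected layers given as e.g. "256-128-64"."""
    widths = [in_features] + [int(w) for w in layers.split("-")]
    # Each Linear(w_in, w_out) layer has w_in * w_out weights and w_out biases
    return sum(w_in * w_out + w_out for w_in, w_out in zip(widths, widths[1:]))


IN_FEATURES = 50  # hypothetical input width; not derived from the Covertype data
for layers in ("256-128-64", "512-128-64"):
    print(layers, dense_params(IN_FEATURES, layers))
# 256-128-64 54208
# 512-128-64 100032
```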

Let's check which model performed the best.

sweep_df.drop(columns=["params", "time_taken", "epochs"]).style.background_gradient(
    subset=["test_accuracy", "test_f1_score"], cmap="RdYlGn"
).background_gradient(subset=["time_taken_per_epoch", "test_loss"], cmap="RdYlGn_r")
   model                   # Params  test_loss  test_accuracy  test_f1_score  time_taken_per_epoch
1  GANDALFModel            43 T      0.189933   0.924494       0.924418       10.985013
2  TabNetModel             50 T      0.259448   0.895175       0.894817       19.809555
0  CategoryEmbeddingModel  51 T      0.302084   0.878024       0.876729       7.634541

We have trained three fast models on the dataset in about 17 minutes of wall time on a CPU. That is pretty fast. The GANDALF model performed the best in terms of accuracy, loss, and F1 score, and its training time per epoch is comparable to that of a regular MLP (CategoryEmbeddingModel). A natural next step would be to tune the model a bit more and find the best parameters.

Or, if you have more time, access to a decent-sized GPU, and want to try out more models, you can try the standard preset. Even on a CPU, it should finish within a couple of hours, and it will give you a good idea of the performance of different models.

Let's run the standard preset.

%%time
# Filtering out the warnings
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    sweep_df, best_model = model_sweep(
        task="classification",  # One of "classification", "regression"
        train=train,
        test=test,
        data_config=data_config,
        optimizer_config=optimizer_config,
        trainer_config=trainer_config,
        model_list="standard",
        common_model_args=dict(head="LinearHead", head_config=head_config),
        metrics=["accuracy", "f1_score"],
        metrics_params=[{}, {"average": "macro"}],
        metrics_prob_input=[False, True],
        rank_metric=("accuracy", "higher_is_better"),
        progress_bar=True,
        verbose=False,
        suppress_lightning_logger=True,
    )


CPU times: user 10h 11min 4s, sys: 2min 16s, total: 10h 13min 20s
Wall time: 1h 6min 18s
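The %%time outputs report both total CPU time and wall time. The sketch below converts them to seconds and takes the ratio for the lite and standard sweeps; a hypothetical helper, `to_seconds`, parses only the whole-second formats shown here. A ratio of about nine suggests that roughly nine cores were kept busy on average, since user CPU time is summed across threads.

```python
import re

SECONDS_PER_UNIT = {"h": 3600, "min": 60, "s": 1}


def to_seconds(text: str) -> int:
    """Convert a duration like '2h 29min 58s' (as printed by %%time) to whole seconds."""
    return sum(
        int(value) * SECONDS_PER_UNIT[unit]
        for value, unit in re.findall(r"(\d+)\s*(h|min|s)\b", text)
    )


# (total CPU time, wall time) for the two preset sweeps in this tutorial
timings = {
    "lite": ("2h 29min 58s", "16min 37s"),
    "standard": ("10h 13min 20s", "1h 6min 18s"),
}
for preset, (cpu, wall) in timings.items():
    ratio = to_seconds(cpu) / to_seconds(wall)
    print(f"{preset}: CPU {to_seconds(cpu)} s, wall {to_seconds(wall)} s, ratio {ratio:.1f}")
# lite: CPU 8998 s, wall 997 s, ratio 9.0
# standard: CPU 36800 s, wall 3978 s, ratio 9.3
```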

sweep_df.drop(columns=["params", "time_taken", "epochs"]).style.background_gradient(
    subset=["test_accuracy", "test_f1_score"], cmap="RdYlGn"
).background_gradient(subset=["time_taken_per_epoch", "test_loss"], cmap="RdYlGn_r")
   model                   # Params  test_loss  test_accuracy  test_f1_score  time_taken_per_epoch
3  GANDALFModel            107 T     0.163602   0.935071       0.935061       15.870558
1  CategoryEmbeddingModel  93 T      0.233573   0.906560       0.905311       9.128509
6  FTTransformerModel      117 T     0.243499   0.900330       0.900065       63.771070
2  GANDALFModel            43 T      0.257583   0.898075       0.897640       10.899241
4  TabNetModel             50 T      0.260693   0.894461       0.894012       18.629878
0  CategoryEmbeddingModel  51 T      0.263826   0.893875       0.894207       7.868230
5  TabNetModel             129 T     0.534261   0.766813       0.760403       32.926586

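The larger GANDALF model (107 T parameters) performed the best in terms of accuracy, loss, and F1 score. Although its training time per epoch (about 16 s) is higher than that of the similarly sized MLP (about 9 s), it is still fast, and far quicker than FTTransformer (about 64 s per epoch).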
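One way to read this accuracy versus speed trade-off is a Pareto front: keep only the models that no other model beats on both test accuracy (higher) and time per epoch (lower). This is our own illustration, not a PyTorch Tabular feature; the numbers are copied from the standard sweep table, and `pareto_front` is a hypothetical helper.

```python
# (model, test_accuracy, time_taken_per_epoch in seconds), copied from the "standard" table
standard_results = [
    ("GANDALFModel (107 T)", 0.935071, 15.870558),
    ("CategoryEmbeddingModel (93 T)", 0.906560, 9.128509),
    ("FTTransformerModel (117 T)", 0.900330, 63.771070),
    ("GANDALFModel (43 T)", 0.898075, 10.899241),
    ("TabNetModel (50 T)", 0.894461, 18.629878),
    ("CategoryEmbeddingModel (51 T)", 0.893875, 7.868230),
    ("TabNetModel (129 T)", 0.766813, 32.926586),
]


def pareto_front(rows):
    """Keep models that no other model beats on both accuracy (higher) and time (lower)."""
    front = []
    for i, (name, acc, sec) in enumerate(rows):
        dominated = any(
            acc2 >= acc and sec2 <= sec and (acc2 > acc or sec2 < sec)
            for j, (_, acc2, sec2) in enumerate(rows)
            if j != i
        )
        if not dominated:
            front.append(name)
    return front


print(pareto_front(standard_results))
# ['GANDALFModel (107 T)', 'CategoryEmbeddingModel (93 T)', 'CategoryEmbeddingModel (51 T)']
```

Every other model is dominated by at least one of these three, so they are the natural shortlist if both accuracy and training speed matter.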

Now, apart from using the presets, you can also pass a list of ModelConfig objects. Let's try and run a sweep with a list of ModelConfig objects.

from pytorch_tabular.models import CategoryEmbeddingModelConfig, GANDALFConfig
common_params = {
    "task": "classification",
    "head": "LinearHead",
    "head_config": head_config,
}
model_list = [
    CategoryEmbeddingModelConfig(layers="1024-512-256", **common_params),
    GANDALFConfig(gflu_stages=2, **common_params),
    GANDALFConfig(gflu_stages=6, learnable_sparsity=False, **common_params),
]

# Filtering out the warnings
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    sweep_df, best_model = model_sweep(
        task="classification",  # One of "classification", "regression"
        train=train,
        test=test,
        data_config=data_config,
        optimizer_config=optimizer_config,
        trainer_config=trainer_config,
        model_list=model_list,
        metrics=["accuracy", "f1_score"],
        metrics_params=[{}, {"average": "macro"}],
        metrics_prob_input=[False, True],
        rank_metric=("accuracy", "higher_is_better"),
        progress_bar=True,
        verbose=False,
        suppress_lightning_logger=True,
    )


sweep_df.drop(columns=["params", "time_taken", "epochs"]).style.background_gradient(
    subset=["test_accuracy", "test_f1_score"], cmap="RdYlGn"
).background_gradient(subset=["time_taken_per_epoch", "test_loss"], cmap="RdYlGn_r")
   model                   # Params  test_loss  test_accuracy  test_f1_score  time_taken_per_epoch
0  CategoryEmbeddingModel  694 T     0.276405   0.888075       0.795560       14.553613
1  GANDALFModel            15 T      0.284878   0.885967       0.797202       8.369561
2  GANDALFModel            43 T      0.287677   0.884142       0.793678       10.864214

Although we chose somewhat arbitrary hyperparameters, the GANDALF models performed very close to the MLP, with a fraction of the parameters (15 T and 43 T versus 694 T) and lower training time per epoch.

Congrats! You have learned how to use Model Sweep in PyTorch Tabular to compare multiple models on a single dataset. This is a useful first step when deciding which models to use for your problem.
Now try using it on your own dataset. You can also try the `full` preset and see how it performs.