
Searching for the Best Architecture and Hyperparameters

Sometimes (or often) we do not know in advance which architecture is best for our data. In deep learning for tabular data, it is common for an architecture to be the best choice on one dataset and mediocre on another. To help find a good solution, this notebook uses two main functions in PyTorch Tabular. The first is model_sweep, which runs every architecture available in PyTorch Tabular with default hyperparameters to find the most promising architectures for our data. Afterwards, we use TabularModelTuner to search for the best hyperparameters of the top architectures found in the sweep.

import warnings
warnings.filterwarnings("ignore")

from sklearn.model_selection import train_test_split

from pytorch_tabular.utils import make_mixed_dataset
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig

Data

First of all, let's create a synthetic dataset that mixes numerical and categorical features, with a single classification target. That means there is one column, "target", which we need to predict from the same set of features.

data, cat_col_names, num_col_names = make_mixed_dataset(
    task="classification", n_samples=3000, n_features=7, n_categories=4
)

train, test = train_test_split(data, random_state=42)
train, valid = train_test_split(train, random_state=42)
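
Before configuring anything, it is worth a quick look at what make_mixed_dataset produced. A small inspection snippet (not part of the original notebook; the exact feature column names are generated by the utility):

# Inspect the generated data and the train/valid/test split sizes
print(data.shape)                        # expect 3000 rows: 7 feature columns plus "target"
print(cat_col_names, num_col_names)      # which columns are categorical vs. continuous
print(len(train), len(valid), len(test))
print(data["target"].value_counts())     # class balance of the target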

Common Configs

data_config = DataConfig(
    target=[
        "target"
    ],
    continuous_cols=num_col_names,
    categorical_cols=cat_col_names,
)
trainer_config = TrainerConfig(
    batch_size=32,
    max_epochs=50,
    early_stopping="valid_accuracy",  # monitor validation accuracy for early stopping
    early_stopping_mode="max",        # higher accuracy is better
    early_stopping_patience=3,
    checkpoints="valid_accuracy",     # checkpoint on the same metric
    checkpoints_mode="max",           # checkpoint mode defaults to "min", which would keep the worst model
    load_best=True,                   # restore the best checkpoint after training
    progress_bar="none",
)
optimizer_config = OptimizerConfig()
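
OptimizerConfig() with no arguments keeps the library defaults. If you want a different starting point, the same fields that the tuner later overrides through the "optimizer_config__" prefix can be set here directly. A hedged sketch (optimizer_config_custom is an illustrative name; check the OptimizerConfig docs for the full field list):

# Sketch: a non-default optimizer setup
optimizer_config_custom = OptimizerConfig(
    optimizer="AdamW",                        # any torch.optim optimizer name
    optimizer_params={"weight_decay": 1e-5},  # passed through to the optimizer
)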

Model Sweep

https://pytorch-tabular.readthedocs.io/en/latest/apidocs_coreclasses/#pytorch_tabular.model_sweep

Let's train all available models (the "high_memory" preset). If some of them come back as "OOM", you do not have enough memory to run them at the current batch_size; you can skip those models or reduce the batch_size in TrainerConfig (see the sketch after the sweep output below).

from pytorch_tabular import model_sweep

sweep_df, best_model = model_sweep(
    task="classification",
    train=train,
    test=valid,  # the sweep reports its "test" metrics on this split
    data_config=data_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config,
    model_list="high_memory",
    verbose=False,  # set to True to log metrics and params for each trial
)
2024-07-20 12:47:01,862 - {pytorch_tabular.models.node.node_model:73} - INFO - Data Aware Initialization of NODE using a forward pass with 2000 batch size....
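
If any model comes back as "OOM", a lower-memory variant of the sweep could look like the sketch below. This is not from the original notebook: "lite" is a smaller model preset than "high_memory", and small_trainer_config / sweep_df_lite are illustrative names.

# Sketch: sweep a lighter model preset with a smaller batch size
small_trainer_config = TrainerConfig(
    batch_size=16,                    # half the batch size used above
    max_epochs=50,
    early_stopping="valid_accuracy",
    early_stopping_mode="max",
    early_stopping_patience=3,
    checkpoints="valid_accuracy",
    checkpoints_mode="max",
    load_best=True,
    progress_bar="none",
)
sweep_df_lite, best_model_lite = model_sweep(
    task="classification",
    train=train,
    test=valid,
    data_config=data_config,
    optimizer_config=optimizer_config,
    trainer_config=small_trainer_config,
    model_list="lite",                # smaller preset than "high_memory"
    verbose=False,
)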


best_model.evaluate(test)
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric        ┃       DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       test_accuracy       │    0.8053333163261414     │
│         test_loss         │    0.44678735733032227    │
└───────────────────────────┴───────────────────────────┘
[{'test_loss': 0.44678735733032227, 'test_accuracy': 0.8053333163261414}]

In the following table, we can see how the models (with default hyperparameters) rank on our dataset. But we are not satisfied yet, so we will take the top two models and use the tuner to search for better hyperparameters and better results.

PS: Each time you run the notebook the results may change a little, so you might see different top models from the ones used in the next section.
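
Since the ranking can shift between runs, it is safer to pick the top models programmatically rather than hard-coding names. A small pandas snippet (top_models is just an illustrative name):

# Pick the two best architectures from the sweep results by test accuracy
top_models = (
    sweep_df.sort_values("test_accuracy", ascending=False)["model"].head(2).tolist()
)
print(top_models)  # e.g. ['CategoryEmbeddingModel', 'FTTransformerModel']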

sweep_df.drop(columns=["params", "time_taken", "epochs"]).sort_values(
    "test_accuracy", ascending=False
).style.background_gradient(
    subset=["test_accuracy"], cmap="RdYlGn"
).background_gradient(subset=["time_taken_per_epoch", "test_loss"], cmap="RdYlGn_r")
    model                            # Params   test_loss   test_accuracy   time_taken_per_epoch
1   CategoryEmbeddingModel           12 T       0.458506    0.797513        0.190966
3   FTTransformerModel               272 T      0.486184    0.770870        0.529126
4   GANDALFModel                     8 T        0.562945    0.705151        0.341467
8   TabTransformerModel              272 T      0.547346    0.696270        0.470920
0   AutoIntModel                     14 T       0.580009    0.689165        0.360073
5   GatedAdditiveTreeEnsembleModel   79 T       0.673274    0.660746        3.624957
2   DANetModel                       431 T      0.692986    0.644760        2.104359
6   NODEModel                        864 T      0.676671    0.626998        1.497243
7   TabNetModel                      6 T        0.708919    0.538188        0.484836

Model Tuner

https://pytorch-tabular.readthedocs.io/en/latest/apidocs_coreclasses/#pytorch_tabular.TabularModelTuner

Perfect!! Now that we know the best models, let's take the top two and play with their hyperparameters to try to find better results.

from pytorch_tabular.models import (
    CategoryEmbeddingModelConfig,
    FTTransformerConfig
)   

We can use two main strategies:

- grid_search: trains on every combination of the hyperparameters that were defined. Remember that each new field you add multiplies the total training time considerably: configuring 4 optimizers, 4 layer settings, 2 activations, and 2 dropouts means 64 (4 * 4 * 2 * 2) trainings. You can estimate this cost up front, as shown in the sketch after the search-space definitions below.
- random_search: randomly samples "n_trials" hyperparameter settings from each search space that has been defined. It is useful for faster experiments, but remember that it will not test all combinations.

The keys follow a scikit-learn-style convention: the prefix before the double underscore ("optimizer_config__" or "model_config__") names the config object the field belongs to, and the suffix is the field to vary.

For all hyperparameter options: https://pytorch-tabular.readthedocs.io/en/latest/apidocs_model/

More information on how the hyperparameter spaces work: https://pytorch-tabular.readthedocs.io/en/latest/tutorials/10-Hyperparameter%20Tuning/#define-the-hyperparameter-space

Let's define some hyperparameters.

PS: This notebook exemplifies the functions; it does not mean these are the best hyperparameters to try.

search_space_category_embedding = {
    "optimizer_config__optimizer": ["Adam", "SGD"],
    "model_config__layers": ["128-64-32", "1024-512-256", "32-64-128", "256-512-1024"],
    "model_config__activation": ["ReLU", "LeakyReLU"],
    "model_config__embedding_dropout": [0.0, 0.2],
}
model_config_category_embedding = CategoryEmbeddingModelConfig(task="classification")
search_space_ft_transformer = {
    "optimizer_config__optimizer": ["Adam", "SGD"],
    "model_config__input_embed_dim": [32, 64],
    "model_config__num_attn_blocks": [3, 6, 8],
    "model_config__ff_hidden_multiplier": [4, 8],
    "model_config__transformer_activation": ["GEGLU", "LeakyReLU"],
    "model_config__embedding_dropout": [0.0, 0.2],
}
model_config_ft_transformer = FTTransformerConfig(task="classification")
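
Before launching the grid search, we can estimate its cost: the number of trainings per model is the product of the option counts in its search space (math.prod is used here just for the arithmetic):

import math

# Grid size = product of the number of options per field, per search space
n_ce = math.prod(len(v) for v in search_space_category_embedding.values())  # 2*4*2*2 = 32
n_ft = math.prod(len(v) for v in search_space_ft_transformer.values())      # 2*2*3*2*2*2 = 96
print(n_ce + n_ft)  # 128 trainings in total, matching the trial count in the results below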

Let's put all the search spaces and model configs into lists.

Important: they must be in the same order and have the same length.

search_spaces = [search_space_category_embedding, search_space_ft_transformer]
model_configs = [model_config_category_embedding, model_config_ft_transformer]
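
A quick sanity check (illustrative, not from the original notebook) that each search space is paired with the right config:

# The tuner matches search spaces to model configs by position
assert len(search_spaces) == len(model_configs)
for space, config in zip(search_spaces, model_configs):
    print(type(config).__name__, "<-", sorted(space.keys()))
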
from pytorch_tabular.tabular_model_tuner import TabularModelTuner
tuner = TabularModelTuner(
    data_config=data_config,
    model_config=model_configs,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config
)
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    tuner_df = tuner.tune(
        train=train,
        validation=valid,
        search_space=search_spaces,
        strategy="grid_search",  # random_search
        # n_trials=5,
        metric="accuracy",
        mode="max",
        progress_bar=True,
        verbose=False,  # set to True to log metrics and params for each trial
    )
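
For a quicker pass, the same call can switch to random sampling. A sketch based on the commented-out n_trials above (random_df is an illustrative name):

# Sketch: random search samples n_trials settings from each search space
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    random_df = tuner.tune(
        train=train,
        validation=valid,
        search_space=search_spaces,
        strategy="random_search",
        n_trials=5,
        metric="accuracy",
        mode="max",
        progress_bar=True,
        verbose=False,
    )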


Nice!!! We now know the best architecture and promising hyperparameters for our dataset. Even if the result is not yet good enough, it at least narrows the options: with these results we know which hyperparameters are worth exploring further and which no longer make sense to keep using.

It is also a good idea to read the architecture's paper; it may guide you further towards the best hyperparameters.

tuner_df.trials_df.sort_values("accuracy", ascending=False).style.background_gradient(
    subset=["accuracy"], cmap="RdYlGn"
).background_gradient(subset=["loss"], cmap="RdYlGn_r")
  trial_id model model_config__activation model_config__embedding_dropout model_config__layers optimizer_config__optimizer loss accuracy model_config__ff_hidden_multiplier model_config__input_embed_dim model_config__num_attn_blocks model_config__transformer_activation
22 22 0-CategoryEmbeddingModelConfig LeakyReLU 0.000000 256-512-1024 Adam 0.339012 0.857904 nan nan nan nan
26 26 0-CategoryEmbeddingModelConfig LeakyReLU 0.200000 1024-512-256 Adam 0.375515 0.817052 nan nan nan nan
20 20 0-CategoryEmbeddingModelConfig LeakyReLU 0.000000 32-64-128 Adam 0.368664 0.815275 nan nan nan nan
2 2 0-CategoryEmbeddingModelConfig ReLU 0.000000 1024-512-256 Adam 0.407023 0.813499 nan nan nan nan
6 6 0-CategoryEmbeddingModelConfig ReLU 0.000000 256-512-1024 Adam 0.445294 0.811723 nan nan nan nan
10 10 0-CategoryEmbeddingModelConfig ReLU 0.200000 1024-512-256 Adam 0.446737 0.811723 nan nan nan nan
18 18 0-CategoryEmbeddingModelConfig LeakyReLU 0.000000 1024-512-256 Adam 0.444420 0.808170 nan nan nan nan
30 30 0-CategoryEmbeddingModelConfig LeakyReLU 0.200000 256-512-1024 Adam 0.398530 0.797513 nan nan nan nan
14 14 0-CategoryEmbeddingModelConfig ReLU 0.200000 256-512-1024 Adam 0.455243 0.781528 nan nan nan nan
72 40 1-FTTransformerConfig nan 0.000000 nan Adam 0.445089 0.779751 8.000000 64.000000 6.000000 GEGLU
8 8 0-CategoryEmbeddingModelConfig ReLU 0.200000 128-64-32 Adam 0.486341 0.776199 nan nan nan nan
16 16 0-CategoryEmbeddingModelConfig LeakyReLU 0.000000 128-64-32 Adam 0.458817 0.776199 nan nan nan nan
116 84 1-FTTransformerConfig nan 0.200000 nan Adam 0.471312 0.776199 8.000000 64.000000 3.000000 GEGLU
62 30 1-FTTransformerConfig nan 0.000000 nan Adam 0.475959 0.774423 8.000000 32.000000 6.000000 LeakyReLU
36 4 1-FTTransformerConfig nan 0.000000 nan Adam 0.506062 0.772647 4.000000 32.000000 6.000000 GEGLU
28 28 0-CategoryEmbeddingModelConfig LeakyReLU 0.200000 32-64-128 Adam 0.503373 0.769094 nan nan nan nan
0 0 0-CategoryEmbeddingModelConfig ReLU 0.000000 128-64-32 Adam 0.482425 0.769094 nan nan nan nan
60 28 1-FTTransformerConfig nan 0.000000 nan Adam 0.495479 0.767318 8.000000 32.000000 6.000000 GEGLU
56 24 1-FTTransformerConfig nan 0.000000 nan Adam 0.519672 0.767318 8.000000 32.000000 3.000000 GEGLU
80 48 1-FTTransformerConfig nan 0.200000 nan Adam 0.518865 0.765542 4.000000 32.000000 3.000000 GEGLU
74 42 1-FTTransformerConfig nan 0.000000 nan Adam 0.483879 0.763766 8.000000 64.000000 6.000000 LeakyReLU
64 32 1-FTTransformerConfig nan 0.000000 nan Adam 0.575869 0.763766 8.000000 32.000000 8.000000 GEGLU
94 62 1-FTTransformerConfig nan 0.200000 nan Adam 0.484891 0.761989 4.000000 64.000000 3.000000 LeakyReLU
66 34 1-FTTransformerConfig nan 0.000000 nan Adam 0.506116 0.761989 8.000000 32.000000 8.000000 LeakyReLU
52 20 1-FTTransformerConfig nan 0.000000 nan Adam 0.511868 0.761989 4.000000 64.000000 8.000000 GEGLU
96 64 1-FTTransformerConfig nan 0.200000 nan Adam 0.482814 0.760213 4.000000 64.000000 6.000000 GEGLU
110 78 1-FTTransformerConfig nan 0.200000 nan Adam 0.479574 0.758437 8.000000 32.000000 6.000000 LeakyReLU
19 19 0-CategoryEmbeddingModelConfig LeakyReLU 0.000000 1024-512-256 SGD 0.532006 0.756661 nan nan nan nan
124 92 1-FTTransformerConfig nan 0.200000 nan Adam 0.532167 0.756661 8.000000 64.000000 8.000000 GEGLU
86 54 1-FTTransformerConfig nan 0.200000 nan Adam 0.462083 0.754885 4.000000 32.000000 6.000000 LeakyReLU
50 18 1-FTTransformerConfig nan 0.000000 nan Adam 0.503736 0.753108 4.000000 64.000000 6.000000 LeakyReLU
42 10 1-FTTransformerConfig nan 0.000000 nan Adam 0.470982 0.753108 4.000000 32.000000 8.000000 LeakyReLU
34 2 1-FTTransformerConfig nan 0.000000 nan Adam 0.503541 0.751332 4.000000 32.000000 3.000000 LeakyReLU
106 74 1-FTTransformerConfig nan 0.200000 nan Adam 0.504346 0.747780 8.000000 32.000000 3.000000 LeakyReLU
46 14 1-FTTransformerConfig nan 0.000000 nan Adam 0.488356 0.747780 4.000000 64.000000 3.000000 LeakyReLU
54 22 1-FTTransformerConfig nan 0.000000 nan Adam 0.561371 0.740675 4.000000 64.000000 8.000000 LeakyReLU
58 26 1-FTTransformerConfig nan 0.000000 nan Adam 0.494664 0.740675 8.000000 32.000000 3.000000 LeakyReLU
88 56 1-FTTransformerConfig nan 0.200000 nan Adam 0.527474 0.738899 4.000000 32.000000 8.000000 GEGLU
84 52 1-FTTransformerConfig nan 0.200000 nan Adam 0.508179 0.731794 4.000000 32.000000 6.000000 GEGLU
118 86 1-FTTransformerConfig nan 0.200000 nan Adam 0.511033 0.731794 8.000000 64.000000 3.000000 LeakyReLU
120 88 1-FTTransformerConfig nan 0.200000 nan Adam 0.473721 0.731794 8.000000 64.000000 6.000000 GEGLU
98 66 1-FTTransformerConfig nan 0.200000 nan Adam 0.518997 0.731794 4.000000 64.000000 6.000000 LeakyReLU
31 31 0-CategoryEmbeddingModelConfig LeakyReLU 0.200000 256-512-1024 SGD 0.538754 0.731794 nan nan nan nan
40 8 1-FTTransformerConfig nan 0.000000 nan Adam 0.546107 0.731794 4.000000 32.000000 8.000000 GEGLU
4 4 0-CategoryEmbeddingModelConfig ReLU 0.000000 32-64-128 Adam 0.533960 0.728242 nan nan nan nan
70 38 1-FTTransformerConfig nan 0.000000 nan Adam 0.579302 0.726465 8.000000 64.000000 3.000000 LeakyReLU
12 12 0-CategoryEmbeddingModelConfig ReLU 0.200000 32-64-128 Adam 0.508314 0.724689 nan nan nan nan
38 6 1-FTTransformerConfig nan 0.000000 nan Adam 0.538916 0.721137 4.000000 32.000000 6.000000 LeakyReLU
82 50 1-FTTransformerConfig nan 0.200000 nan Adam 0.537538 0.721137 4.000000 32.000000 3.000000 LeakyReLU
122 90 1-FTTransformerConfig nan 0.200000 nan Adam 0.522755 0.719361 8.000000 64.000000 6.000000 LeakyReLU
48 16 1-FTTransformerConfig nan 0.000000 nan Adam 0.471181 0.715808 4.000000 64.000000 6.000000 GEGLU
32 0 1-FTTransformerConfig nan 0.000000 nan Adam 0.550226 0.714032 4.000000 32.000000 3.000000 GEGLU
108 76 1-FTTransformerConfig nan 0.200000 nan Adam 0.523274 0.714032 8.000000 32.000000 6.000000 GEGLU
63 31 1-FTTransformerConfig nan 0.000000 nan SGD 0.591639 0.712256 8.000000 32.000000 6.000000 LeakyReLU
104 72 1-FTTransformerConfig nan 0.200000 nan Adam 0.508801 0.710480 8.000000 32.000000 3.000000 GEGLU
24 24 0-CategoryEmbeddingModelConfig LeakyReLU 0.200000 128-64-32 Adam 0.519161 0.710480 nan nan nan nan
68 36 1-FTTransformerConfig nan 0.000000 nan Adam 0.572089 0.706927 8.000000 64.000000 3.000000 GEGLU
92 60 1-FTTransformerConfig nan 0.200000 nan Adam 0.575852 0.706927 4.000000 64.000000 3.000000 GEGLU
126 94 1-FTTransformerConfig nan 0.200000 nan Adam 0.570989 0.706927 8.000000 64.000000 8.000000 LeakyReLU
44 12 1-FTTransformerConfig nan 0.000000 nan Adam 0.577062 0.705151 4.000000 64.000000 3.000000 GEGLU
79 47 1-FTTransformerConfig nan 0.000000 nan SGD 0.557485 0.703375 8.000000 64.000000 8.000000 LeakyReLU
51 19 1-FTTransformerConfig nan 0.000000 nan SGD 0.550771 0.703375 4.000000 64.000000 6.000000 LeakyReLU
11 11 0-CategoryEmbeddingModelConfig ReLU 0.200000 1024-512-256 SGD 0.555238 0.703375 nan nan nan nan
114 82 1-FTTransformerConfig nan 0.200000 nan Adam 0.487832 0.701599 8.000000 32.000000 8.000000 LeakyReLU
90 58 1-FTTransformerConfig nan 0.200000 nan Adam 0.579668 0.699822 4.000000 32.000000 8.000000 LeakyReLU
3 3 0-CategoryEmbeddingModelConfig ReLU 0.000000 1024-512-256 SGD 0.572410 0.696270 nan nan nan nan
112 80 1-FTTransformerConfig nan 0.200000 nan Adam 0.553881 0.692718 8.000000 32.000000 8.000000 GEGLU
15 15 0-CategoryEmbeddingModelConfig ReLU 0.200000 256-512-1024 SGD 0.562511 0.685613 nan nan nan nan
35 3 1-FTTransformerConfig nan 0.000000 nan SGD 0.581403 0.685613 4.000000 32.000000 3.000000 LeakyReLU
45 13 1-FTTransformerConfig nan 0.000000 nan SGD 0.597738 0.685613 4.000000 64.000000 3.000000 GEGLU
100 68 1-FTTransformerConfig nan 0.200000 nan Adam 0.584579 0.683837 4.000000 64.000000 8.000000 GEGLU
83 51 1-FTTransformerConfig nan 0.200000 nan SGD 0.662541 0.680284 4.000000 32.000000 3.000000 LeakyReLU
127 95 1-FTTransformerConfig nan 0.200000 nan SGD 0.614641 0.676732 8.000000 64.000000 8.000000 LeakyReLU
101 69 1-FTTransformerConfig nan 0.200000 nan SGD 0.579955 0.674956 4.000000 64.000000 8.000000 GEGLU
102 70 1-FTTransformerConfig nan 0.200000 nan Adam 0.585392 0.671403 4.000000 64.000000 8.000000 LeakyReLU
27 27 0-CategoryEmbeddingModelConfig LeakyReLU 0.200000 1024-512-256 SGD 0.594700 0.667851 nan nan nan nan
7 7 0-CategoryEmbeddingModelConfig ReLU 0.000000 256-512-1024 SGD 0.598617 0.666075 nan nan nan nan
121 89 1-FTTransformerConfig nan 0.200000 nan SGD 0.632152 0.666075 8.000000 64.000000 6.000000 GEGLU
76 44 1-FTTransformerConfig nan 0.000000 nan Adam 0.641684 0.666075 8.000000 64.000000 8.000000 GEGLU
103 71 1-FTTransformerConfig nan 0.200000 nan SGD 0.616750 0.666075 4.000000 64.000000 8.000000 LeakyReLU
91 59 1-FTTransformerConfig nan 0.200000 nan SGD 0.634522 0.664298 4.000000 32.000000 8.000000 LeakyReLU
59 27 1-FTTransformerConfig nan 0.000000 nan SGD 0.624750 0.664298 8.000000 32.000000 3.000000 LeakyReLU
107 75 1-FTTransformerConfig nan 0.200000 nan SGD 0.637458 0.657194 8.000000 32.000000 3.000000 LeakyReLU
69 37 1-FTTransformerConfig nan 0.000000 nan SGD 0.636728 0.657194 8.000000 64.000000 3.000000 GEGLU
55 23 1-FTTransformerConfig nan 0.000000 nan SGD 0.613378 0.657194 4.000000 64.000000 8.000000 LeakyReLU
5 5 0-CategoryEmbeddingModelConfig ReLU 0.000000 32-64-128 SGD 0.670955 0.655417 nan nan nan nan
117 85 1-FTTransformerConfig nan 0.200000 nan SGD 0.629454 0.655417 8.000000 64.000000 3.000000 GEGLU
97 65 1-FTTransformerConfig nan 0.200000 nan SGD 0.645757 0.655417 4.000000 64.000000 6.000000 GEGLU
87 55 1-FTTransformerConfig nan 0.200000 nan SGD 0.646177 0.651865 4.000000 32.000000 6.000000 LeakyReLU
23 23 0-CategoryEmbeddingModelConfig LeakyReLU 0.000000 256-512-1024 SGD 0.639443 0.650089 nan nan nan nan
39 7 1-FTTransformerConfig nan 0.000000 nan SGD 0.651099 0.646536 4.000000 32.000000 6.000000 LeakyReLU
125 93 1-FTTransformerConfig nan 0.200000 nan SGD 0.624359 0.646536 8.000000 64.000000 8.000000 GEGLU
65 33 1-FTTransformerConfig nan 0.000000 nan SGD 0.597288 0.644760 8.000000 32.000000 8.000000 GEGLU
47 15 1-FTTransformerConfig nan 0.000000 nan SGD 0.666151 0.644760 4.000000 64.000000 3.000000 LeakyReLU
49 17 1-FTTransformerConfig nan 0.000000 nan SGD 0.639839 0.642984 4.000000 64.000000 6.000000 GEGLU
123 91 1-FTTransformerConfig nan 0.200000 nan SGD 0.628552 0.641208 8.000000 64.000000 6.000000 LeakyReLU
75 43 1-FTTransformerConfig nan 0.000000 nan SGD 0.619922 0.641208 8.000000 64.000000 6.000000 LeakyReLU
85 53 1-FTTransformerConfig nan 0.200000 nan SGD 0.655388 0.641208 4.000000 32.000000 6.000000 GEGLU
89 57 1-FTTransformerConfig nan 0.200000 nan SGD 0.635567 0.637655 4.000000 32.000000 8.000000 GEGLU
93 61 1-FTTransformerConfig nan 0.200000 nan SGD 0.658716 0.635879 4.000000 64.000000 3.000000 GEGLU
71 39 1-FTTransformerConfig nan 0.000000 nan SGD 0.646253 0.634103 8.000000 64.000000 3.000000 LeakyReLU
67 35 1-FTTransformerConfig nan 0.000000 nan SGD 0.667418 0.632327 8.000000 32.000000 8.000000 LeakyReLU
81 49 1-FTTransformerConfig nan 0.200000 nan SGD 0.664191 0.628774 4.000000 32.000000 3.000000 GEGLU
119 87 1-FTTransformerConfig nan 0.200000 nan SGD 0.665687 0.628774 8.000000 64.000000 3.000000 LeakyReLU
113 81 1-FTTransformerConfig nan 0.200000 nan SGD 0.651145 0.628774 8.000000 32.000000 8.000000 GEGLU
53 21 1-FTTransformerConfig nan 0.000000 nan SGD 0.672665 0.628774 4.000000 64.000000 8.000000 GEGLU
25 25 0-CategoryEmbeddingModelConfig LeakyReLU 0.200000 128-64-32 SGD 0.662720 0.625222 nan nan nan nan
111 79 1-FTTransformerConfig nan 0.200000 nan SGD 0.633410 0.623446 8.000000 32.000000 6.000000 LeakyReLU
43 11 1-FTTransformerConfig nan 0.000000 nan SGD 0.635329 0.621670 4.000000 32.000000 8.000000 LeakyReLU
99 67 1-FTTransformerConfig nan 0.200000 nan SGD 0.628636 0.616341 4.000000 64.000000 6.000000 LeakyReLU
41 9 1-FTTransformerConfig nan 0.000000 nan SGD 0.639925 0.616341 4.000000 32.000000 8.000000 GEGLU
109 77 1-FTTransformerConfig nan 0.200000 nan SGD 0.651506 0.614565 8.000000 32.000000 6.000000 GEGLU
1 1 0-CategoryEmbeddingModelConfig ReLU 0.000000 128-64-32 SGD 0.665929 0.612789 nan nan nan nan
105 73 1-FTTransformerConfig nan 0.200000 nan SGD 0.658312 0.605684 8.000000 32.000000 3.000000 GEGLU
37 5 1-FTTransformerConfig nan 0.000000 nan SGD 0.652759 0.605684 4.000000 32.000000 6.000000 GEGLU
73 41 1-FTTransformerConfig nan 0.000000 nan SGD 0.659291 0.598579 8.000000 64.000000 6.000000 GEGLU
33 1 1-FTTransformerConfig nan 0.000000 nan SGD 0.648887 0.596803 4.000000 32.000000 3.000000 GEGLU
17 17 0-CategoryEmbeddingModelConfig LeakyReLU 0.000000 128-64-32 SGD 0.710187 0.589698 nan nan nan nan
77 45 1-FTTransformerConfig nan 0.000000 nan SGD 0.648749 0.589698 8.000000 64.000000 8.000000 GEGLU
29 29 0-CategoryEmbeddingModelConfig LeakyReLU 0.200000 32-64-128 SGD 0.719664 0.582593 nan nan nan nan
13 13 0-CategoryEmbeddingModelConfig ReLU 0.200000 32-64-128 SGD 0.778426 0.555950 nan nan nan nan
61 29 1-FTTransformerConfig nan 0.000000 nan SGD 0.689890 0.552398 8.000000 32.000000 6.000000 GEGLU
57 25 1-FTTransformerConfig nan 0.000000 nan SGD 0.690501 0.539964 8.000000 32.000000 3.000000 GEGLU
21 21 0-CategoryEmbeddingModelConfig LeakyReLU 0.000000 32-64-128 SGD 0.726256 0.539964 nan nan nan nan
95 63 1-FTTransformerConfig nan 0.200000 nan SGD 0.701837 0.502664 4.000000 64.000000 3.000000 LeakyReLU
115 83 1-FTTransformerConfig nan 0.200000 nan SGD 0.680208 0.502664 8.000000 32.000000 8.000000 LeakyReLU
78 46 1-FTTransformerConfig nan 0.000000 nan Adam 0.693390 0.493783 8.000000 64.000000 8.000000 LeakyReLU
9 9 0-CategoryEmbeddingModelConfig ReLU 0.200000 128-64-32 SGD 0.781076 0.433393 nan nan nan nan
tuner_df.best_model.evaluate(test)
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric        ┃       DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       test_accuracy       │    0.8173333406448364     │
│         test_loss         │    0.38250666856765747    │
└───────────────────────────┴───────────────────────────┘
[{'test_loss': 0.38250666856765747, 'test_accuracy': 0.8173333406448364}]

After tuning, the best model is returned in the output as "best_model". So if you like the result and wish to use the model in the future, you can persist it by calling "save_model".

tuner_df.best_model.save_model("best_model", inference_only=True)
2024-07-20 12:58:01,015 - {pytorch_tabular.tabular_model:1572} - WARNING - Directory is not empty. Overwriting the contents.
# Load the saved model back and verify it on the test split
from pytorch_tabular import TabularModel

loaded_model = TabularModel.load_model("best_model")
loaded_model.evaluate(test)