Skip to content
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score
import random
import numpy as np
import pandas as pd
import os
import wandb
import logging

# Silence PyTorch Lightning below ERROR level; its logger is looked up by name.
logging.getLogger(name="pytorch_lightning").setLevel(level=logging.ERROR)
# PyTorch Tabular takes its log level from the PT_LOGLEVEL environment
# variable, so this must be set before any pytorch_tabular module is imported.
os.environ.update({"PT_LOGLEVEL": "CRITICAL"})
# %load_ext autoreload
# %autoreload 2

Utility Functions

def make_mixed_classification(n_samples, n_features, n_categories, n_informative=5, random_state=42):
    """Generate a synthetic binary-classification DataFrame with mixed feature types.

    A dataset is drawn with ``make_classification``. Then ``n_categories``
    *distinct* feature columns are picked at random and discretised into
    quartile codes with ``pd.qcut(..., q=4)``, which turns them into
    categorical features.

    Args:
        n_samples: Number of rows to generate.
        n_features: Total number of feature columns.
        n_categories: Number of feature columns to convert to categorical.
            Must be between 0 and ``n_features``.
        n_informative: Forwarded to ``make_classification`` (previously fixed at 5).
        random_state: Integer seed (or ``None`` for a fresh random dataset). It
            is used both by ``make_classification`` and for choosing the
            categorical columns, so the whole dataset is reproducible.

    Returns:
        A tuple ``(data, cat_col_names, num_col_names)``. ``data`` holds the
        feature columns (named ``cat_col_<i>`` / ``num_col_<i>``) plus a
        ``target`` column. The two lists name the categorical and numerical
        feature columns in column order.

    Raises:
        ValueError: If ``n_categories`` is negative or larger than ``n_features``
            (raised by ``random.Random.sample``).
    """
    X, y = make_classification(
        n_samples=n_samples,
        n_features=n_features,
        random_state=random_state,
        n_informative=n_informative,
    )
    # Sample WITHOUT replacement: random.choices could return the same column
    # twice, silently yielding fewer than n_categories categorical features.
    # A dedicated, seeded Random keeps the choice reproducible without touching
    # the global `random` state.
    rng = random.Random(random_state)
    cat_cols = set(rng.sample(range(X.shape[-1]), k=n_categories))
    for col in sorted(cat_cols):
        X[:, col] = pd.qcut(X[:, col], q=4).codes.astype(int)

    col_names, num_col_names, cat_col_names = [], [], []
    for i in range(X.shape[-1]):
        if i in cat_cols:
            name = f"cat_col_{i}"
            cat_col_names.append(name)
        else:
            name = f"num_col_{i}"
            num_col_names.append(name)
        col_names.append(name)

    X = pd.DataFrame(X, columns=col_names)
    y = pd.Series(y, name="target")
    data = X.join(y)
    return data, cat_col_names, num_col_names


def print_metrics(y_true, y_pred, tag, average="binary"):
    """Print accuracy and F1 score for a set of label predictions.

    Args:
        y_true: Ground-truth labels. Any array-like accepted by ``np.asarray``:
            list, NumPy array, pandas Series, or single-column DataFrame.
        y_pred: Predicted labels, aligned with ``y_true``.
        tag: Prefix printed in front of each metric name (e.g. ``"Holdout"``).
        average: Forwarded to ``f1_score``. The default ``"binary"`` keeps the
            original behaviour; pass e.g. ``"macro"`` for multiclass labels.
    """
    # np.asarray treats lists, arrays, Series and DataFrames uniformly (the old
    # isinstance/.ndim checks crashed on plain Python lists), and ravel
    # flattens (n, 1) column vectors to 1-D as the metrics expect.
    y_true = np.ravel(np.asarray(y_true))
    y_pred = np.ravel(np.asarray(y_pred))
    acc = accuracy_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred, average=average)
    print(f"{tag} Acc: {acc} | {tag} F1: {f1}")

Generate Synthetic Data

First, let's create a synthetic dataset that mixes numerical and categorical features.

# Demo dataset: 10k rows and 20 features, of which 4 are requested as categorical.
data, cat_col_names, num_col_names = make_mixed_classification(
    n_samples=10_000,
    n_features=20,
    n_categories=4,
)

Importing the Library

from pytorch_tabular import TabularModel
from pytorch_tabular.models import (
    CategoryEmbeddingModelConfig,
    NodeConfig,
    TabNetModelConfig,
    GatedAdditiveTreeEnsembleConfig,
)
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig, ExperimentConfig
from pytorch_tabular.models.common.heads import LinearHeadConfig

Cross Validation

train, test = train_test_split(data, random_state=42)

# Column roles. `target` must always be a list; multiple targets are only
# supported for regression (multi-task classification is not implemented).
data_config = DataConfig(
    target=["target"],
    continuous_cols=num_col_names,
    categorical_cols=cat_col_names,
)

# Training-loop settings used by the models below.
trainer_config = TrainerConfig(
    batch_size=1024,
    max_epochs=100,
    # Stop once valid_loss (lower is better) has not improved for 5 epochs...
    early_stopping="valid_loss",
    early_stopping_mode="min",
    early_stopping_patience=5,
    # ...checkpoint on the same metric and reload the best checkpoint afterwards.
    checkpoints="valid_loss",
    load_best=True,
    # Keep notebook output quiet: no progress bar and no model summary table.
    progress_bar="none",
    trainer_kwargs={"enable_model_summary": False},
)
optimizer_config = OptimizerConfig()

# Head with no extra layers, only the mapping to output_dim. It is passed as a
# plain dict because the model config (OmegaConf) does not accept objects.
head_config = vars(LinearHeadConfig(layers="", dropout=0.1, initialization="kaiming"))

model_config = CategoryEmbeddingModelConfig(
    task="classification",
    layers="1024-512-512",  # hidden units per backbone layer
    activation="LeakyReLU",  # activation between layers
    learning_rate=1e-3,
    head="LinearHead",
    head_config=head_config,
)

tabular_model = TabularModel(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config,
)

Using High-Level API

# Cross-validation loop using sklearn's KFold (High-Level API: a fresh
# TabularModel is created and fitted for every fold).
from sklearn.model_selection import KFold

kf = KFold(n_splits=5, shuffle=True, random_state=42)
metrics = []  # per-fold "test_loss" measured on the held-out fold
for fold, (train_idx, val_idx) in enumerate(kf.split(train)):
    print(f"Fold: {fold}")
    train_fold = train.iloc[train_idx]
    val_fold = train.iloc[val_idx]
    # trainer_config early-stops and selects the best checkpoint on the
    # validation set (valid_loss, load_best=True). If that set were val_fold,
    # the fold being scored would also steer training and the CV loss would be
    # optimistically biased, so a separate early-stopping split is carved out
    # of the training fold instead.
    fit_fold, early_stop_fold = train_test_split(
        train_fold,
        test_size=0.2,
        random_state=42,
        stratify=train_fold["target"],
    )
    # Initialize a fresh tabular model for this fold
    tabular_model = TabularModel(
        data_config=data_config,
        model_config=model_config,
        optimizer_config=optimizer_config,
        trainer_config=trainer_config,
    )
    # Fit on the reduced training split, early-stopping on the inner split
    tabular_model.fit(
        train=fit_fold,
        validation=early_stop_fold,
    )
    # Score on the fold that played no part in training
    result = tabular_model.evaluate(val_fold, verbose=False)
    metrics.append(result[0]["test_loss"])
Global seed set to 42
Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:604: UserWarning: Checkpoint directory /home/manujosephv/pytorch_tabular/docs/tutorials/saved_models exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

Fold: 0

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, test_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
Global seed set to 42
Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:604: UserWarning: Checkpoint directory /home/manujosephv/pytorch_tabular/docs/tutorials/saved_models exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(

Fold: 1

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, test_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
Global seed set to 42
Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

Fold: 2

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:604: UserWarning: Checkpoint directory /home/manujosephv/pytorch_tabular/docs/tutorials/saved_models exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, test_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
Global seed set to 42
Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:604: UserWarning: Checkpoint directory /home/manujosephv/pytorch_tabular/docs/tutorials/saved_models exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

Fold: 3

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, test_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
Global seed set to 42
Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:604: UserWarning: Checkpoint directory /home/manujosephv/pytorch_tabular/docs/tutorials/saved_models exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(

Fold: 4

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, test_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(

# Summarise the per-fold losses; np.std uses its default ddof=0 (population SD).
kfold_mean, kfold_sd = np.mean(metrics), np.std(metrics)
print(f"KFold Mean: {kfold_mean} | KFold SD: {kfold_sd}")
KFold Mean: 0.24372529685497285 | KFold SD: 0.01793774135386692

Using Low-Level API

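Sometimes fitting the datamodule is an expensive operation. If the dataset is sufficiently large, we can accept an approximation: prepare the TabularDatamodule once and reuse it for the other folds. This is only an approximation because the transformers are fitted on the first fold's training data. That data contains the validation rows of every later fold, so those folds are preprocessed by transformers that have already seen their validation data.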

P.S. The loop can easily be modified to do bagging: predict on the test data with the model from each fold and average the predictions.

# Low-Level API variant of the 5-fold CV: one TabularModel, datamodule and
# model are prepared in the first fold and reused for all later folds, so the
# datamodule's transformers are fitted only once.
kf = KFold(n_splits=5, shuffle=True, random_state=42)
# Initialize the tabular model once
tabular_model = TabularModel(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config,
)
metrics = []  # per-fold "test_loss" values returned by evaluate()
datamodule = None
model = None
for fold, (train_idx, val_idx) in enumerate(kf.split(train)):
    print(f"Fold: {fold}")
    train_fold = train.iloc[train_idx]
    val_fold = train.iloc[val_idx]
    if datamodule is None:
        # Initialize datamodule and model in the first fold
        # uses train data from this fold to fit all transformers.
        # NOTE: fold 0's training rows include the validation rows of every
        # later fold, so later folds are scored with transformers that have
        # already seen their validation data (the approximation this section
        # trades for speed).
        datamodule = tabular_model.prepare_dataloader(train=train_fold, validation=val_fold, seed=42)
        model = tabular_model.prepare_model(datamodule)
    else:
        # Preprocess the current fold data using the fitted transformers and save in datamodule
        datamodule.train, _ = datamodule.preprocess_data(train_fold, stage="inference")
        datamodule.validation, _ = datamodule.preprocess_data(val_fold, stage="inference")
    # Train the model
    # NOTE(review): val_fold is both the validation set used during training
    # (early stopping / best-checkpoint selection per trainer_config) and the
    # set scored below, so the CV loss is likely optimistic. Consider holding
    # out part of train_fold for validation instead.
    tabular_model.train(model, datamodule)
    result = tabular_model.evaluate(val_fold, verbose=False)
    metrics.append(result[0]["test_loss"])
    # Reset the trained weights before next fold
    # NOTE(review): assumes tabular_model.model is the same object as `model`
    # passed to train() above. If training (e.g. load_best) swaps in a different
    # object, `model` would start the next fold with trained weights. Confirm.
    tabular_model.model.reset_weights()
Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:604: UserWarning: Checkpoint directory /home/manujosephv/pytorch_tabular/docs/tutorials/saved_models exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

Fold: 0

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, test_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:604: UserWarning: Checkpoint directory /home/manujosephv/pytorch_tabular/docs/tutorials/saved_models exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(

Fold: 1

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, test_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:604: UserWarning: Checkpoint directory /home/manujosephv/pytorch_tabular/docs/tutorials/saved_models exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(

Fold: 2

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, test_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:604: UserWarning: Checkpoint directory /home/manujosephv/pytorch_tabular/docs/tutorials/saved_models exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(

Fold: 3

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, test_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:604: UserWarning: Checkpoint directory /home/manujosephv/pytorch_tabular/docs/tutorials/saved_models exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(

Fold: 4

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, test_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(

# Summarise the per-fold losses; np.std uses its default ddof=0 (population SD).
kfold_mean, kfold_sd = np.mean(metrics), np.std(metrics)
print(f"KFold Mean: {kfold_mean} | KFold SD: {kfold_sd}")
KFold Mean: 0.23143928349018097 | KFold SD: 0.008452922690941587

Evaluating Multiple models without re-fitting DataModules

Using the Low-level API, we can also train and evaluate multiple models without re-fitting a datamodule.

# Fresh train/val/test split: test is held out for the final comparison and
# val is the validation set used during training.
train, test = train_test_split(data, random_state=42)
train, val = train_test_split(train, random_state=42)
results = []  # one metrics dict per evaluated model

data_config = DataConfig(
    target=["target"],  # always a list; multiple targets only for regression
    continuous_cols=num_col_names,
    categorical_cols=cat_col_names,
)
trainer_config = TrainerConfig(
    batch_size=1024,
    max_epochs=100,
    # Early stopping and checkpointing both track valid_loss (lower is better).
    # Training stops after 5 epochs without improvement, and the best
    # checkpoint is loaded after training.
    early_stopping="valid_loss",
    early_stopping_mode="min",
    early_stopping_patience=5,
    checkpoints="valid_loss",
    load_best=True,
    # Progress bar and model summary stay on here; pass progress_bar="none"
    # and trainer_kwargs=dict(enable_model_summary=False) to silence them.
)
optimizer_config = OptimizerConfig()

# Shared head with no extra layers, only the mapping to output_dim. It is
# passed as a plain dict because OmegaConf does not accept objects.
head_config = vars(LinearHeadConfig(layers="", dropout=0.1, initialization="kaiming"))

# NOTE(review): model1_config is identical to model_config below and is not
# used in the visible cells; kept in case later cells refer to it.
model1_config = CategoryEmbeddingModelConfig(
    task="classification",
    layers="1024-512-512",  # hidden units per backbone layer
    activation="LeakyReLU",  # activation between layers
    learning_rate=1e-3,
    head="LinearHead",
    head_config=head_config,
)

model_config = CategoryEmbeddingModelConfig(
    task="classification",
    layers="1024-512-512",  # hidden units per backbone layer
    activation="LeakyReLU",  # activation between layers
    learning_rate=1e-3,
    head="LinearHead",
    head_config=head_config,
)

# Second candidate: gated additive tree ensemble with the same linear head.
alt_model_config = GatedAdditiveTreeEnsembleConfig(
    task="classification",
    learning_rate=1e-3,
    head="LinearHead",
    head_config=head_config,
)

tabular_model = TabularModel(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config,
)
# Fit the datamodule once on train/val; later models reuse it unchanged.
datamodule = tabular_model.prepare_dataloader(train=train, validation=val, seed=42)
model = tabular_model.prepare_model(datamodule)
tabular_model.train(model, datamodule)
Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:604: UserWarning: Checkpoint directory /home/manujosephv/pytorch_tabular/docs/tutorials/saved_models exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓
┃    Name              Type                       Params ┃
┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩
│ 0 │ _backbone        │ CategoryEmbeddingBackbone │  817 K │
│ 1 │ _embedding_layer │ Embedding1dLayer          │     92 │
│ 2 │ head             │ LinearHead                │  1.0 K │
│ 3 │ loss             │ CrossEntropyLoss          │      0 │
└───┴──────────────────┴───────────────────────────┴────────┘
Trainable params: 818 K                                                                                            
Non-trainable params: 0                                                                                            
Total params: 818 K                                                                                                
Total estimated model params size (MB): 3                                                                          

Output()
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(

<pytorch_lightning.trainer.trainer.Trainer at 0x7fd4491ea350>
# Score the CategoryEmbedding model on the held-out test split and tag the
# first metrics dict with the model name before collecting it.
result = tabular_model.evaluate(test)[0]
result["Model"] = "CategoryEmbedding"
results.append(result)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, test_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(


Output()
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric               DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       test_accuracy           0.9028000235557556     │
│         test_loss             0.25517141819000244    │
└───────────────────────────┴───────────────────────────┘
Testing ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 3/3 0:00:00 • 0:00:00 53.16it/s  

# Second model: identical data/optimizer/trainer settings, only the model config
# changes (alt_model_config). It is trained on the same, already prepared
# datamodule so both models see exactly the same splits and preprocessing.
alt_tabular_model = TabularModel(
    model_config=alt_model_config,
    data_config=data_config,
    trainer_config=trainer_config,
    optimizer_config=optimizer_config,
)
alt_model = alt_tabular_model.prepare_model(datamodule)
alt_tabular_model.train(alt_model, datamodule)
Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:604: UserWarning: Checkpoint directory /home/manujosephv/pytorch_tabular/docs/tutorials/saved_models exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓
┃    Name              Type                        Params ┃
┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩
│ 0 │ _backbone        │ GatedAdditiveTreesBackbone │  1.1 M │
│ 1 │ _embedding_layer │ Embedding1dLayer           │     92 │
│ 2 │ _head            │ CustomHead                 │     86 │
│ 3 │ loss             │ CrossEntropyLoss           │      0 │
└───┴──────────────────┴────────────────────────────┴────────┘
Trainable params: 1.1 M                                                                                            
Non-trainable params: 0                                                                                            
Total params: 1.1 M                                                                                                
Total estimated model params size (MB): 4                                                                          

Output()
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(

<pytorch_lightning.trainer.trainer.Trainer at 0x7fd41a0eb160>
# Score the alternative (GATE) model on the same test split and store its metrics.
# evaluate() returns a list; its first element is the metrics dict, which is
# tagged with the model name before being appended to the comparison list.
result = alt_tabular_model.evaluate(test)[0]
result["Model"] = "GATE"
results.append(result)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, test_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(


Output()
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric               DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       test_accuracy           0.9120000004768372     │
│         test_loss             0.24496853351593018    │
└───────────────────────────┴───────────────────────────┘
Testing ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 3/3 0:00:17 • 0:00:00 0.21it/s  

# One row per evaluated model (list of metric dicts -> table) for side-by-side comparison
pd.DataFrame.from_records(results)
test_loss test_accuracy Model
0 0.255171 0.9028 CategoryEmbedding
1 0.244969 0.9120 GATE

Hyperparameter Tuning

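Using the low-level API, we can also implement hyperparameter tuning. Below, we run a simple grid search over model and optimizer settings while reusing the same datamodule.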

# Start a new results list for the tuning comparison.
results = []
# Hold out a test split, then carve a validation split out of the remaining rows.
train, test = train_test_split(data, random_state=42)
train, val = train_test_split(train, random_state=42)
data_config = DataConfig(
    # `target` must always be a list. Multiple targets are only supported for
    # regression; multi-task classification is not implemented.
    target=["target"],
    continuous_cols=num_col_names,
    categorical_cols=cat_col_names,
)
trainer_config = TrainerConfig(
    batch_size=1024,
    max_epochs=100,
    # Stop once valid_loss (lower is better) fails to improve for 5 epochs.
    early_stopping="valid_loss",
    early_stopping_mode="min",
    early_stopping_patience=5,
    # Checkpoint on the best valid_loss and reload that checkpoint after training.
    checkpoints="valid_loss",
    load_best=True,
    # Keep the notebook output quiet: no progress bar, no model summary.
    progress_bar="none",
    trainer_kwargs=dict(enable_model_summary=False),
)
optimizer_config = OptimizerConfig()

# Linear head with no extra hidden layers, i.e. only a mapping layer to output_dim.
# Stored as a plain dict (via __dict__) because OmegaConf does not accept objects.
head_config = LinearHeadConfig(
    layers="",
    dropout=0.1,
    initialization="kaiming",
).__dict__

model_config = CategoryEmbeddingModelConfig(
    task="classification",
    layers="1024-512-512",  # hidden units per layer
    activation="LeakyReLU",  # activation applied between layers
    learning_rate=1e-3,
    head="LinearHead",
    head_config=head_config,
)

tabular_model = TabularModel(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config,
)
datamodule = tabular_model.prepare_dataloader(train=train, validation=val, seed=42)
model = tabular_model.prepare_model(datamodule)
tabular_model.train(model, datamodule)
Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:604: UserWarning: Checkpoint directory /home/manujosephv/pytorch_tabular/docs/tutorials/saved_models exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(

<pytorch_lightning.trainer.trainer.Trainer at 0x7f87f18f8460>
# Baseline for the tuning comparison: score the untuned model on the test split.
# evaluate() returns a list whose first element is the metrics dict; label it
# as the untuned run before collecting it.
result = tabular_model.evaluate(test)[0]
result["Type"] = "UnTuned"
results.append(result)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, test_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(

┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric               DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       test_accuracy           0.9300000071525574     │
│         test_loss             0.19813215732574463    │
└───────────────────────────┴───────────────────────────┘

Note: For demonstration, we use the test split for tuning. In real problems, use a separate validation set for tuning; otherwise you will overfit to the test set and get overly optimistic performance estimates.

import warnings
from copy import deepcopy
from sklearn.model_selection import ParameterGrid
from rich.progress import Progress
# Define the Grid.
# Each key is "<root>__<attribute>": <root> selects which config receives the value.
# The special root "model_config~head_config" writes into the model config's
# head_config dict instead of setting an attribute.
param_grid = {
    "model_config__layers": ["1024-512-512", "512-128-64-32-16"],
    "model_config__activation": ["LeakyReLU", "ReLU"],
    "model_config__learning_rate": [1e-3, 1e-4],
    "model_config~head_config__dropout": [0.1, 0.2],
    "optimizer_config__optimizer": ["Adam", "SGD"],
}


def _apply_params(params, model_cfg, trainer_cfg, optimizer_cfg):
    """Write one grid point into per-trial config copies, in place.

    Parameters
    ----------
    params : dict
        One combination produced by ``ParameterGrid``; keys follow the
        ``"<root>__<attribute>"`` convention described above.
    model_cfg, trainer_cfg, optimizer_cfg
        Per-trial copies of the base configs; they are mutated in place.

    Keys that are malformed, use an unknown root, or name an attribute/head
    option the target does not have are reported with ``warnings.warn`` and
    skipped, so a typo in the grid cannot silently produce a trial that
    ignores the intended setting.
    """
    attr_targets = {
        "model_config": model_cfg,
        "trainer_config": trainer_cfg,
        "optimizer_config": optimizer_cfg,
    }
    for name, value in params.items():
        # partition() never raises; a key without "__" yields an empty separator.
        root, sep, attr = name.partition("__")
        if not sep or not attr:
            warnings.warn(f"Unknown parameter defined. Ignoring {name}")
        elif root == "model_config~head_config":
            if attr in model_cfg.head_config:
                model_cfg.head_config[attr] = value
            else:
                warnings.warn(f"head_config has no option '{attr}'. Ignoring {name}")
        elif root in attr_targets:
            target = attr_targets[root]
            # setattr would happily create a new, unused attribute on a typo.
            if hasattr(target, attr):
                setattr(target, attr, value)
            else:
                warnings.warn(f"{root} has no attribute '{attr}'. Ignoring {name}")
        else:
            warnings.warn(f"Unknown parameter defined. Ignoring {name}")


grid = ParameterGrid(param_grid)
trials = []
with Progress() as progress:
    task = progress.add_task("[green]GridSearch...", total=len(grid))
    for params in grid:
        # Copy the base configs so every trial starts from the same defaults.
        # Anything that should apply to all trials must already be set on the originals.
        # data_config is not tuned here: changing it would require re-fitting the
        # datamodule for every trial.
        trainer_config_t = deepcopy(trainer_config)
        optimizer_config_t = deepcopy(optimizer_config)
        # deepcopy also copies model_config.head_config, so head edits stay per-trial.
        model_config_t = deepcopy(model_config)
        _apply_params(params, model_config_t, trainer_config_t, optimizer_config_t)

        # Initialize Tabular model using the new configs
        tabular_model_t = TabularModel(
            data_config=data_config,
            model_config=model_config_t,
            optimizer_config=optimizer_config_t,
            trainer_config=trainer_config_t,
        )
        # Only the model is rebuilt; the previously prepared datamodule is reused.
        model_t = tabular_model_t.prepare_model(datamodule)
        tabular_model_t.train(model_t, datamodule)
        # Scored on the test split for demonstration only; in real use, tune
        # against a separate validation split. Any custom metric could go here too.
        result_t = tabular_model_t.evaluate(test)[0]
        # Record hyperparameters and metrics together without mutating `params`.
        trials.append({**params, **result_t})
        progress.update(task, advance=1)

Output()
Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_c
heckpoint.py:604: UserWarning: Checkpoint directory /home/manujosephv/pytorch_tabular/docs/tutorials/saved_models 
exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━   0% -:--:--
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━   0% -:--:--
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━   0% -:--:--
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, test_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━   0% -:--:--
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric               DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       test_accuracy            0.920799970626831     │
│         test_loss             0.2067311555147171     │
└───────────────────────────┴───────────────────────────┘
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━   0% -:--:--
Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_c
heckpoint.py:604: UserWarning: Checkpoint directory /home/manujosephv/pytorch_tabular/docs/tutorials/saved_models 
exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
GridSearch... ╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━   3% -:--:--
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━   3% -:--:--
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━   3% -:--:--
`Trainer.fit` stopped: `max_epochs=100` reached.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, test_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━   3% -:--:--
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric               DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       test_accuracy           0.8172000050544739     │
│         test_loss             0.4020683467388153     │
└───────────────────────────┴───────────────────────────┘
GridSearch... ╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━   3% -:--:--
Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_c
heckpoint.py:604: UserWarning: Checkpoint directory /home/manujosephv/pytorch_tabular/docs/tutorials/saved_models 
exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
GridSearch... ━━╸━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━   6% 0:06:56
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━╸━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━   6% 0:06:56
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━╸━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━   6% 0:06:56
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, test_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━╸━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━   6% 0:06:56
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric               DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       test_accuracy           0.9272000193595886     │
│         test_loss             0.19424475729465485    │
└───────────────────────────┴───────────────────────────┘
GridSearch... ━━╸━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━   6% 0:06:56
Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_c
heckpoint.py:604: UserWarning: Checkpoint directory /home/manujosephv/pytorch_tabular/docs/tutorials/saved_models 
exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
GridSearch... ━━━╸━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━   9% 0:04:18
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━╸━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━   9% 0:04:18
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━╸━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━   9% 0:04:18
`Trainer.fit` stopped: `max_epochs=100` reached.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, test_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━╸━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━   9% 0:04:18
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric               DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       test_accuracy           0.8148000240325928     │
│         test_loss             0.4120400846004486     │
└───────────────────────────┴───────────────────────────┘
GridSearch... ━━━╸━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━   9% 0:04:18
Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_c
heckpoint.py:604: UserWarning: Checkpoint directory /home/manujosephv/pytorch_tabular/docs/tutorials/saved_models 
exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
GridSearch... ━━━━━╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━  12% 0:04:22
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━  12% 0:04:22
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━  12% 0:04:22
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, test_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━  12% 0:04:22
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric               DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       test_accuracy           0.9088000059127808     │
│         test_loss             0.23200805485248566    │
└───────────────────────────┴───────────────────────────┘
GridSearch... ━━━━━╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━  12% 0:04:22
Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_c
heckpoint.py:604: UserWarning: Checkpoint directory /home/manujosephv/pytorch_tabular/docs/tutorials/saved_models 
exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
GridSearch... ━━━━━━╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━  16% 0:03:36
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━  16% 0:03:36
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━  16% 0:03:36
`Trainer.fit` stopped: `max_epochs=100` reached.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, test_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━  16% 0:03:36
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric               DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       test_accuracy            0.623199999332428     │
│         test_loss             0.6852937340736389     │
└───────────────────────────┴───────────────────────────┘
GridSearch... ━━━━━━╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━  16% 0:03:36
Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_c
heckpoint.py:604: UserWarning: Checkpoint directory /home/manujosephv/pytorch_tabular/docs/tutorials/saved_models 
exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
GridSearch... ━━━━━━━╸━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━  19% 0:04:15
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━╸━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━  19% 0:04:15
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━╸━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━  19% 0:04:15
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, test_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━╸━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━  19% 0:04:15
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric               DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       test_accuracy           0.9083999991416931     │
│         test_loss             0.2381352037191391     │
└───────────────────────────┴───────────────────────────┘
GridSearch... ━━━━━━━╸━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━  19% 0:04:15
Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_c
heckpoint.py:604: UserWarning: Checkpoint directory /home/manujosephv/pytorch_tabular/docs/tutorials/saved_models 
exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
GridSearch... ━━━━━━━━╸━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━  22% 0:03:43
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━╸━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━  22% 0:03:43
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━╸━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━  22% 0:03:43
`Trainer.fit` stopped: `max_epochs=100` reached.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, test_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━╸━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━  22% 0:03:43
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric               DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       test_accuracy           0.6624000072479248     │
│         test_loss             0.6230977773666382     │
└───────────────────────────┴───────────────────────────┘
GridSearch... ━━━━━━━━╸━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━  22% 0:03:43
Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_c
heckpoint.py:604: UserWarning: Checkpoint directory /home/manujosephv/pytorch_tabular/docs/tutorials/saved_models 
exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
GridSearch... ━━━━━━━━━━╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━  25% 0:04:16
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━  25% 0:04:16
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━  25% 0:04:16
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, test_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━  25% 0:04:16
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric               DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       test_accuracy           0.9211999773979187     │
│         test_loss             0.20694784820079803    │
└───────────────────────────┴───────────────────────────┘
GridSearch... ━━━━━━━━━━╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━  25% 0:04:16
Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_c
heckpoint.py:604: UserWarning: Checkpoint directory /home/manujosephv/pytorch_tabular/docs/tutorials/saved_models 
exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
GridSearch... ━━━━━━━━━━━╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━  28% 0:03:02
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━  28% 0:03:02
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━  28% 0:03:02
`Trainer.fit` stopped: `max_epochs=100` reached.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, test_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━  28% 0:03:02
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric               DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       test_accuracy           0.7839999794960022     │
│         test_loss             0.47276991605758667    │
└───────────────────────────┴───────────────────────────┘
GridSearch... ━━━━━━━━━━━╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━  28% 0:03:02
Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_c
heckpoint.py:604: UserWarning: Checkpoint directory /home/manujosephv/pytorch_tabular/docs/tutorials/saved_models 
exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
GridSearch... ━━━━━━━━━━━━╸━━━━━━━━━━━━━━━━━━━━━━━━━━━  31% 0:02:55
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━╸━━━━━━━━━━━━━━━━━━━━━━━━━━━  31% 0:02:55
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━╸━━━━━━━━━━━━━━━━━━━━━━━━━━━  31% 0:02:55
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, test_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━╸━━━━━━━━━━━━━━━━━━━━━━━━━━━  31% 0:02:55
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric               DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       test_accuracy           0.9192000031471252     │
│         test_loss             0.21727892756462097    │
└───────────────────────────┴───────────────────────────┘
GridSearch... ━━━━━━━━━━━━╸━━━━━━━━━━━━━━━━━━━━━━━━━━━  31% 0:02:55
Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_c
heckpoint.py:604: UserWarning: Checkpoint directory /home/manujosephv/pytorch_tabular/docs/tutorials/saved_models 
exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
GridSearch... ━━━━━━━━━━━━━╸━━━━━━━━━━━━━━━━━━━━━━━━━━  34% 0:02:15
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━━╸━━━━━━━━━━━━━━━━━━━━━━━━━━  34% 0:02:15
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━━╸━━━━━━━━━━━━━━━━━━━━━━━━━━  34% 0:02:15
`Trainer.fit` stopped: `max_epochs=100` reached.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, test_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━━╸━━━━━━━━━━━━━━━━━━━━━━━━━━  34% 0:02:15
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric               DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       test_accuracy           0.7411999702453613     │
│         test_loss             0.5252864956855774     │
└───────────────────────────┴───────────────────────────┘
GridSearch... ━━━━━━━━━━━━━╸━━━━━━━━━━━━━━━━━━━━━━━━━━  34% 0:02:15
Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_c
heckpoint.py:604: UserWarning: Checkpoint directory /home/manujosephv/pytorch_tabular/docs/tutorials/saved_models 
exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
GridSearch... ━━━━━━━━━━━━━━━╺━━━━━━━━━━━━━━━━━━━━━━━━  38% 0:02:49
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━━━━╺━━━━━━━━━━━━━━━━━━━━━━━━  38% 0:02:49
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━━━━╺━━━━━━━━━━━━━━━━━━━━━━━━  38% 0:02:49
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, test_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━━━━╺━━━━━━━━━━━━━━━━━━━━━━━━  38% 0:02:49
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric               DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       test_accuracy           0.9115999937057495     │
│         test_loss             0.22225087881088257    │
└───────────────────────────┴───────────────────────────┘
GridSearch... ━━━━━━━━━━━━━━━╺━━━━━━━━━━━━━━━━━━━━━━━━  38% 0:02:49
Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_c
heckpoint.py:604: UserWarning: Checkpoint directory /home/manujosephv/pytorch_tabular/docs/tutorials/saved_models 
exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
GridSearch... ━━━━━━━━━━━━━━━━╺━━━━━━━━━━━━━━━━━━━━━━━  41% 0:02:44
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━━━━━╺━━━━━━━━━━━━━━━━━━━━━━━  41% 0:02:44
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━━━━━╺━━━━━━━━━━━━━━━━━━━━━━━  41% 0:02:44
`Trainer.fit` stopped: `max_epochs=100` reached.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, test_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━━━━━╺━━━━━━━━━━━━━━━━━━━━━━━  41% 0:02:44
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric               DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       test_accuracy           0.5004000067710876     │
│         test_loss             0.7720578908920288     │
└───────────────────────────┴───────────────────────────┘
GridSearch... ━━━━━━━━━━━━━━━━╺━━━━━━━━━━━━━━━━━━━━━━━  41% 0:02:44
Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_c
heckpoint.py:604: UserWarning: Checkpoint directory /home/manujosephv/pytorch_tabular/docs/tutorials/saved_models 
exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
GridSearch... ━━━━━━━━━━━━━━━━━╸━━━━━━━━━━━━━━━━━━━━━━  44% 0:03:19
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━━━━━━╸━━━━━━━━━━━━━━━━━━━━━━  44% 0:03:19
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━━━━━━╸━━━━━━━━━━━━━━━━━━━━━━  44% 0:03:19
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, test_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━━━━━━╸━━━━━━━━━━━━━━━━━━━━━━  44% 0:03:19
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric               DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       test_accuracy           0.9103999733924866     │
│         test_loss             0.2252570241689682     │
└───────────────────────────┴───────────────────────────┘
GridSearch... ━━━━━━━━━━━━━━━━━╸━━━━━━━━━━━━━━━━━━━━━━  44% 0:03:19
Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_c
heckpoint.py:604: UserWarning: Checkpoint directory /home/manujosephv/pytorch_tabular/docs/tutorials/saved_models 
exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
GridSearch... ━━━━━━━━━━━━━━━━━━╸━━━━━━━━━━━━━━━━━━━━━  47% 0:03:10
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━━━━━━━╸━━━━━━━━━━━━━━━━━━━━━  47% 0:03:10
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━━━━━━━╸━━━━━━━━━━━━━━━━━━━━━  47% 0:03:10
`Trainer.fit` stopped: `max_epochs=100` reached.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, test_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━━━━━━━╸━━━━━━━━━━━━━━━━━━━━━  47% 0:03:10
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric               DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       test_accuracy           0.5587999820709229     │
│         test_loss             0.7378883957862854     │
└───────────────────────────┴───────────────────────────┘
GridSearch... ━━━━━━━━━━━━━━━━━━╸━━━━━━━━━━━━━━━━━━━━━  47% 0:03:10
Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_c
heckpoint.py:604: UserWarning: Checkpoint directory /home/manujosephv/pytorch_tabular/docs/tutorials/saved_models 
exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
GridSearch... ━━━━━━━━━━━━━━━━━━━━╺━━━━━━━━━━━━━━━━━━━  50% 0:03:07
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━━━━━━━━━╺━━━━━━━━━━━━━━━━━━━  50% 0:03:07
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━━━━━━━━━╺━━━━━━━━━━━━━━━━━━━  50% 0:03:07
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, test_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━━━━━━━━━╺━━━━━━━━━━━━━━━━━━━  50% 0:03:07
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric               DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       test_accuracy           0.9259999990463257     │
│         test_loss             0.1996007263660431     │
└───────────────────────────┴───────────────────────────┘
GridSearch... ━━━━━━━━━━━━━━━━━━━━╺━━━━━━━━━━━━━━━━━━━  50% 0:03:07
Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_c
heckpoint.py:604: UserWarning: Checkpoint directory /home/manujosephv/pytorch_tabular/docs/tutorials/saved_models 
exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
GridSearch... ━━━━━━━━━━━━━━━━━━━━━╺━━━━━━━━━━━━━━━━━━  53% 0:02:18
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━━━━━━━━━━╺━━━━━━━━━━━━━━━━━━  53% 0:02:18
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━━━━━━━━━━╺━━━━━━━━━━━━━━━━━━  53% 0:02:18
`Trainer.fit` stopped: `max_epochs=100` reached.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, test_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━━━━━━━━━━╺━━━━━━━━━━━━━━━━━━  53% 0:02:18
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric               DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       test_accuracy           0.8223999738693237     │
│         test_loss             0.4129759669303894     │
└───────────────────────────┴───────────────────────────┘
GridSearch... ━━━━━━━━━━━━━━━━━━━━━╺━━━━━━━━━━━━━━━━━━  53% 0:02:18
Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_c
heckpoint.py:604: UserWarning: Checkpoint directory /home/manujosephv/pytorch_tabular/docs/tutorials/saved_models 
exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━╸━━━━━━━━━━━━━━━━━  56% 0:02:16
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━╸━━━━━━━━━━━━━━━━━  56% 0:02:16
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━╸━━━━━━━━━━━━━━━━━  56% 0:02:16
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, test_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━╸━━━━━━━━━━━━━━━━━  56% 0:02:16
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric               DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       test_accuracy           0.9236000180244446     │
│         test_loss             0.19273507595062256    │
└───────────────────────────┴───────────────────────────┘
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━╸━━━━━━━━━━━━━━━━━  56% 0:02:16
Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_c
heckpoint.py:604: UserWarning: Checkpoint directory /home/manujosephv/pytorch_tabular/docs/tutorials/saved_models 
exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━╸━━━━━━━━━━━━━━━━  59% 0:01:40
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━╸━━━━━━━━━━━━━━━━  59% 0:01:40
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━╸━━━━━━━━━━━━━━━━  59% 0:01:40
`Trainer.fit` stopped: `max_epochs=100` reached.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, test_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━╸━━━━━━━━━━━━━━━━  59% 0:01:40
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric               DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       test_accuracy           0.8044000267982483     │
│         test_loss             0.4212886691093445     │
└───────────────────────────┴───────────────────────────┘
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━╸━━━━━━━━━━━━━━━━  59% 0:01:40
Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_c
heckpoint.py:604: UserWarning: Checkpoint directory /home/manujosephv/pytorch_tabular/docs/tutorials/saved_models 
exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━━━╺━━━━━━━━━━━━━━  62% 0:01:52
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━━━╺━━━━━━━━━━━━━━  62% 0:01:52
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━━━╺━━━━━━━━━━━━━━  62% 0:01:52
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, test_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━━━╺━━━━━━━━━━━━━━  62% 0:01:52
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric               DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       test_accuracy            0.909600019454956     │
│         test_loss             0.23710620403289795    │
└───────────────────────────┴───────────────────────────┘
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━━━╺━━━━━━━━━━━━━━  62% 0:01:52
Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_c
heckpoint.py:604: UserWarning: Checkpoint directory /home/manujosephv/pytorch_tabular/docs/tutorials/saved_models 
exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━━━━╺━━━━━━━━━━━━━  66% 0:01:31
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━━━━╺━━━━━━━━━━━━━  66% 0:01:31
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━━━━╺━━━━━━━━━━━━━  66% 0:01:31
`Trainer.fit` stopped: `max_epochs=100` reached.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, test_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━━━━╺━━━━━━━━━━━━━  66% 0:01:31
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric               DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       test_accuracy           0.6639999747276306     │
│         test_loss             0.6387022137641907     │
└───────────────────────────┴───────────────────────────┘
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━━━━╺━━━━━━━━━━━━━  66% 0:01:31
Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_c
heckpoint.py:604: UserWarning: Checkpoint directory /home/manujosephv/pytorch_tabular/docs/tutorials/saved_models 
exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━━━━━╸━━━━━━━━━━━━  69% 0:01:46
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━━━━━╸━━━━━━━━━━━━  69% 0:01:46
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━━━━━╸━━━━━━━━━━━━  69% 0:01:46
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, test_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━━━━━╸━━━━━━━━━━━━  69% 0:01:46
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric               DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       test_accuracy           0.9064000248908997     │
│         test_loss             0.2324029803276062     │
└───────────────────────────┴───────────────────────────┘
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━━━━━╸━━━━━━━━━━━━  69% 0:01:46
Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_c
heckpoint.py:604: UserWarning: Checkpoint directory /home/manujosephv/pytorch_tabular/docs/tutorials/saved_models 
exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸━━━━━━━━━━━  72% 0:01:22
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸━━━━━━━━━━━  72% 0:01:22
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸━━━━━━━━━━━  72% 0:01:22
`Trainer.fit` stopped: `max_epochs=100` reached.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, test_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸━━━━━━━━━━━  72% 0:01:22
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric               DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       test_accuracy           0.6772000193595886     │
│         test_loss             0.5996285080909729     │
└───────────────────────────┴───────────────────────────┘
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸━━━━━━━━━━━  72% 0:01:22
Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_c
heckpoint.py:604: UserWarning: Checkpoint directory /home/manujosephv/pytorch_tabular/docs/tutorials/saved_models 
exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╺━━━━━━━━━  75% 0:01:22
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╺━━━━━━━━━  75% 0:01:22
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╺━━━━━━━━━  75% 0:01:22
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, test_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╺━━━━━━━━━  75% 0:01:22
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric               DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       test_accuracy           0.9168000221252441     │
│         test_loss             0.21080686151981354    │
└───────────────────────────┴───────────────────────────┘
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╺━━━━━━━━━  75% 0:01:22
Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_c
heckpoint.py:604: UserWarning: Checkpoint directory /home/manujosephv/pytorch_tabular/docs/tutorials/saved_models 
exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╺━━━━━━━━  78% 0:00:52
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╺━━━━━━━━  78% 0:00:52
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╺━━━━━━━━  78% 0:00:52
`Trainer.fit` stopped: `max_epochs=100` reached.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, test_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╺━━━━━━━━  78% 0:00:52
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric               DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       test_accuracy           0.7232000231742859     │
│         test_loss             0.5564718842506409     │
└───────────────────────────┴───────────────────────────┘
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╺━━━━━━━━  78% 0:00:52
Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_c
heckpoint.py:604: UserWarning: Checkpoint directory /home/manujosephv/pytorch_tabular/docs/tutorials/saved_models 
exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╺━━━━━━━━  78% 0:00:52
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸━━━━━━━  81% 0:00:58
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸━━━━━━━  81% 0:00:58
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, test_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸━━━━━━━  81% 0:00:58
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric               DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       test_accuracy           0.9279999732971191     │
│         test_loss             0.1926078200340271     │
└───────────────────────────┴───────────────────────────┘
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸━━━━━━━  81% 0:00:58
Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_c
heckpoint.py:604: UserWarning: Checkpoint directory /home/manujosephv/pytorch_tabular/docs/tutorials/saved_models 
exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸━━━━━━  84% 0:00:28
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸━━━━━━  84% 0:00:28
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸━━━━━━  84% 0:00:28
`Trainer.fit` stopped: `max_epochs=100` reached.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, test_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸━━━━━━  84% 0:00:28
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric               DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       test_accuracy           0.7432000041007996     │
│         test_loss             0.5229927897453308     │
└───────────────────────────┴───────────────────────────┘
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸━━━━━━  84% 0:00:28
Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_c
heckpoint.py:604: UserWarning: Checkpoint directory /home/manujosephv/pytorch_tabular/docs/tutorials/saved_models 
exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╺━━━━  88% 0:00:39
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╺━━━━  88% 0:00:39
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╺━━━━  88% 0:00:39
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, test_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╺━━━━  88% 0:00:39
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric               DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       test_accuracy           0.9092000126838684     │
│         test_loss             0.22894182801246643    │
└───────────────────────────┴───────────────────────────┘
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╺━━━━  88% 0:00:39
Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_c
heckpoint.py:604: UserWarning: Checkpoint directory /home/manujosephv/pytorch_tabular/docs/tutorials/saved_models 
exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╺━━━  91% 0:00:28
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╺━━━  91% 0:00:28
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╺━━━  91% 0:00:28
`Trainer.fit` stopped: `max_epochs=100` reached.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, test_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╺━━━  91% 0:00:28
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric               DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       test_accuracy           0.5636000037193298     │
│         test_loss             0.7135388255119324     │
└───────────────────────────┴───────────────────────────┘
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╺━━━  91% 0:00:28
Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_c
heckpoint.py:604: UserWarning: Checkpoint directory /home/manujosephv/pytorch_tabular/docs/tutorials/saved_models 
exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸━━  94% 0:00:26
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸━━  94% 0:00:26
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸━━  94% 0:00:26
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, test_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸━━  94% 0:00:26
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric               DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       test_accuracy           0.9204000234603882     │
│         test_loss             0.20439837872982025    │
└───────────────────────────┴───────────────────────────┘
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸━━  94% 0:00:26
Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_c
heckpoint.py:604: UserWarning: Checkpoint directory /home/manujosephv/pytorch_tabular/docs/tutorials/saved_models 
exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸  97% 0:00:14
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸  97% 0:00:14
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸  97% 0:00:14
`Trainer.fit` stopped: `max_epochs=100` reached.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connector
s/data_connector.py:224: PossibleUserWarning: The dataloader, test_dataloader 0, does not have many workers which 
may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of 
cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸  97% 0:00:14
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric               DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       test_accuracy           0.5171999931335449     │
│         test_loss             0.7117891311645508     │
└───────────────────────────┴───────────────────────────┘
GridSearch... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸  97% 0:00:14

# Gather the grid-search trial records (hyperparameter values plus the
# test_loss / test_accuracy each combination achieved) into a DataFrame.
# NOTE(review): `trials` is produced by an earlier cell; its exact structure
# (list of dicts vs dict of lists) is not visible here — confirm upstream.
trials_df = pd.DataFrame(trials)
# Preview the first few trials.
trials_df.head()
model_config__activation model_config__layers model_config__learning_rate model_config~head_config__dropout optimizer_config__optimizer test_loss test_accuracy
0 LeakyReLU 1024-512-512 0.0010 0.1 Adam 0.206731 0.9208
1 LeakyReLU 1024-512-512 0.0010 0.1 SGD 0.402068 0.8172
2 LeakyReLU 1024-512-512 0.0010 0.2 Adam 0.194245 0.9272
3 LeakyReLU 1024-512-512 0.0010 0.2 SGD 0.412040 0.8148
4 LeakyReLU 1024-512-512 0.0001 0.1 Adam 0.232008 0.9088
# Hyperparameter combination whose trial reached the smallest test_loss.
# NOTE(review): picking hyperparameters by a test-set metric makes that
# metric optimistically biased; a separate validation split is preferable.
trials_df.loc[trials_df["test_loss"].idxmin()]
model_config__activation                         ReLU
model_config__layers                 512-128-64-32-16
model_config__learning_rate                     0.001
model_config~head_config__dropout                 0.2
optimizer_config__optimizer                      Adam
test_loss                                    0.192608
test_accuracy                                   0.928
Name: 26, dtype: object
# Hyperparameter combination whose trial reached the largest test_accuracy.
# NOTE(review): as with the loss-based pick, selecting on test-set accuracy
# inflates the reported score; prefer a held-out validation split.
trials_df.loc[trials_df["test_accuracy"].idxmax()]
model_config__activation                         ReLU
model_config__layers                 512-128-64-32-16
model_config__learning_rate                     0.001
model_config~head_config__dropout                 0.2
optimizer_config__optimizer                      Adam
test_loss                                    0.192608
test_accuracy                                   0.928
Name: 26, dtype: object