
After defining all the configs, we need to put them together, and this is where TabularModel comes in. TabularModel is the core workhorse that orchestrates and sets everything up.

TabularModel parses the configs and:

  1. initializes the model
  2. sets up the experiment tracking framework
  3. initializes and sets up the TabularDatamodule which handles all the data transformations and preparation of the DataLoaders
  4. sets up the callbacks and the PyTorch Lightning Trainer
  5. enables you to train, save, load, and predict

Initializing Tabular Model

Basic Usage:

  • data_config: DataConfig: DataConfig object or path to the yaml file.
  • model_config: ModelConfig: A subclass of ModelConfig or path to the yaml file. Determines which model to run from the type of config.
  • optimizer_config: OptimizerConfig: OptimizerConfig object or path to the yaml file.
  • trainer_config: TrainerConfig: TrainerConfig object or path to the yaml file.
  • experiment_config: ExperimentConfig: ExperimentConfig object or path to the yaml file.

Usage Example

tabular_model = TabularModel(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config,
    experiment_config=experiment_config,
)

Model Sweep

PyTorch Tabular also provides an easy way to check the performance of different models and configurations on a given dataset. This is done through the model_sweep function. It takes in a list of model configs or one of the presets defined in pytorch_tabular.MODEL_SWEEP_PRESETS and trains them on the data. It then ranks the models on the given metric and returns the sweep results along with the best model.

These are the major args:

  • task: The type of prediction task. Either 'classification' or 'regression'.
  • train: The training data.
  • test: The test data on which performance is evaluated.
  • All the config objects can be passed as either the object or the path to the yaml file.
  • model_list: The list of models to compare. This can be one of the presets defined in pytorch_tabular.MODEL_SWEEP_PRESETS or a list of ModelConfig objects.
  • metrics: The list of metrics to track during training. Each metric should be one of the functional metrics implemented in torchmetrics. By default, it is accuracy for classification and mean_squared_error for regression.
  • metrics_prob_input: A mandatory parameter for classification metrics defined in the config. It defines whether the input to the metric function is the probability or the class. Its length should be the same as the number of metrics. Defaults to None.
  • metrics_params: The parameters to be passed to the metrics function.
  • rank_metric: The metric to use for ranking the models, given as a tuple. The first element is the metric name and the second element is the direction. Defaults to ('loss', "lower_is_better").
  • return_best_model: If True, will return the best model. Defaults to True.

Usage Example

sweep_df, best_model = model_sweep(
    task="classification",  # One of "classification", "regression"
    train=train,
    test=test,
    data_config=data_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config,
    model_list="lite",  # One of the presets defined in pytorch_tabular.MODEL_SWEEP_PRESETS
    common_model_args=dict(head="LinearHead", head_config=head_config),
    metrics=['accuracy', "f1_score"], # The metrics to track during training
    metrics_params=[{}, {"average": "weighted"}],
    metrics_prob_input=[False, True],
    rank_metric=("accuracy", "higher_is_better"), # The metric to use for ranking the models. 
    progress_bar=True, # If True, will show a progress bar
    verbose=False # If True, will print the results of each model
)

For more examples, check out the Model Sweep tutorial notebook.
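To make the rank_metric tuple concrete, here is a minimal, dependency-free sketch of how a (metric name, direction) pair can be used to order results. It is not the library's implementation: the rank_models helper and the result rows are hypothetical and invented purely for illustration.

```python
from typing import Any, Dict, List, Tuple


def rank_models(
    results: List[Dict[str, Any]],
    rank_metric: Tuple[str, str] = ("loss", "lower_is_better"),
) -> List[Dict[str, Any]]:
    """Return the result rows ordered from best to worst (hypothetical helper)."""
    metric_name, direction = rank_metric
    if direction not in ("lower_is_better", "higher_is_better"):
        raise ValueError(f"Unknown direction: {direction!r}")
    # Higher-is-better metrics are sorted in descending order.
    descending = direction == "higher_is_better"
    return sorted(results, key=lambda row: row[metric_name], reverse=descending)


# Hypothetical sweep results, made up for this example.
results = [
    {"model": "model_a", "accuracy": 0.81, "loss": 0.52},
    {"model": "model_b", "accuracy": 0.86, "loss": 0.47},
    {"model": "model_c", "accuracy": 0.84, "loss": 0.41},
]

best_by_accuracy = rank_models(results, ("accuracy", "higher_is_better"))[0]["model"]
best_by_loss = rank_models(results)[0]["model"]
```

With these made-up numbers, ranking by accuracy and ranking by the default loss pick different models, which is why the direction has to be stated alongside the metric name.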

Advanced Usage

  • config: DictConfig: Another way of initializing TabularModel is with a DictConfig from omegaconf. Although not recommended, you can create a normal dictionary with all the parameters dumped into it, create a DictConfig from it using omegaconf, and pass it here. The downside is that you'll be skipping all the validation (both type validation and logical validation). This is primarily used internally to load a saved model from a checkpoint.
  • model_callable: Optional[Callable]: Usually, the model callable and parameters are inferred from the ModelConfig. But in special cases, like when working with a custom model, you can pass the class (not the initialized object) to this parameter and override the config-based initialization.

Training API (Supervised Learning)

There are two APIs for training or 'fit'-ing a model.

  1. High-level API
  2. Low-level API

The low-level API is more flexible and allows you to customize the training loop. The high-level API is easier to use and is recommended for most use cases.

High-Level API

pytorch_tabular.TabularModel.fit(train, validation=None, loss=None, metrics=None, metrics_prob_inputs=None, optimizer=None, optimizer_params=None, train_sampler=None, target_transform=None, max_epochs=None, min_epochs=None, seed=42, callbacks=None, datamodule=None, cache_data='memory', handle_oom=True)

The fit method which takes in the data and triggers the training.

Parameters:

  • train: DataFrame: Training dataframe. Required.
  • validation: Optional[DataFrame]: If provided, this dataframe is used as the validation set while training, for early stopping and logging. If left empty, 20% of the training data is used as validation. Defaults to None.
  • loss: Optional[Module]: Custom loss function that is not in the standard PyTorch library. Defaults to None.
  • metrics: Optional[List[Callable]]: Custom metric functions (callables) that have the signature metric_fn(y_hat, y) and work on torch tensor inputs. y_hat is expected to be of shape (batch_size, num_classes) for classification and (batch_size, 1) for regression, and y is expected to be of shape (batch_size, 1). Defaults to None.
  • metrics_prob_inputs: Optional[List[bool]]: This is a mandatory parameter for classification metrics. If the metric function requires probabilities as inputs, set this to True. The length of the list should be equal to the number of metrics. Defaults to None.
  • optimizer: Optional[Optimizer]: Custom optimizer that is a drop-in replacement for the standard PyTorch optimizers. This should be the class and not the initialized object. Defaults to None.
  • optimizer_params: Optional[Dict]: The parameters to initialize the custom optimizer. Defaults to None.
  • train_sampler: Optional[Sampler]: Custom PyTorch batch sampler that is passed to the DataLoaders. Useful for dealing with imbalanced data and other custom batching strategies. Defaults to None.
  • target_transform: Optional[Union[TransformerMixin, Tuple(Callable)]]: If provided, applies the transform to the target before modelling and inverts the transform during prediction. The parameter can either be an sklearn transformer that has an inverse_transform method, or a tuple of callables (transform_func, inverse_transform_func). Defaults to None.
  • max_epochs: Optional[int]: Overwrites the maximum number of epochs to be run. Defaults to None.
  • min_epochs: Optional[int]: Overwrites the minimum number of epochs to be run. Defaults to None.
  • seed: Optional[int]: Random seed for reproducibility. Defaults to 42.
  • callbacks: Optional[List[Callback]]: List of callbacks to be used during training. Defaults to None.
  • datamodule: Optional[TabularDatamodule]: The datamodule. If provided, will ignore the rest of the parameters like train, test etc. and use the datamodule. Defaults to None.
  • cache_data: str: Decides how to cache the data in the dataloader. If set to "memory", will cache in memory. If set to a valid path, will cache in that path. Defaults to "memory".
  • handle_oom: bool: If True, will try to handle OOM errors elegantly. Defaults to True.

Returns:

  • pl.Trainer: The PyTorch Lightning Trainer instance

Source code in src/pytorch_tabular/tabular_model.py
def fit(
    self,
    train: Optional[DataFrame],
    validation: Optional[DataFrame] = None,
    loss: Optional[torch.nn.Module] = None,
    metrics: Optional[List[Callable]] = None,
    metrics_prob_inputs: Optional[List[bool]] = None,
    optimizer: Optional[torch.optim.Optimizer] = None,
    optimizer_params: Dict = None,
    train_sampler: Optional[torch.utils.data.Sampler] = None,
    target_transform: Optional[Union[TransformerMixin, Tuple]] = None,
    max_epochs: Optional[int] = None,
    min_epochs: Optional[int] = None,
    seed: Optional[int] = 42,
    callbacks: Optional[List[pl.Callback]] = None,
    datamodule: Optional[TabularDatamodule] = None,
    cache_data: str = "memory",
    handle_oom: bool = True,
) -> pl.Trainer:
    """The fit method which takes in the data and triggers the training.

    Args:
        train (DataFrame): Training Dataframe

        validation (Optional[DataFrame], optional):
            If provided, will use this dataframe as the validation while training.
            Used in Early Stopping and Logging. If left empty, will use 20% of Train data as validation.
            Defaults to None.

        loss (Optional[torch.nn.Module], optional): Custom Loss functions which are not in standard pytorch library

        metrics (Optional[List[Callable]], optional): Custom metric functions(Callable) which has the
            signature metric_fn(y_hat, y) and works on torch tensor inputs. y_hat is expected to be of shape
            (batch_size, num_classes) for classification and (batch_size, 1) for regression and y is expected to be
            of shape (batch_size, 1)

        metrics_prob_inputs (Optional[List[bool]], optional): This is a mandatory parameter for
            classification metrics. If the metric function requires probabilities as inputs, set this to True.
            The length of the list should be equal to the number of metrics. Defaults to None.

        optimizer (Optional[torch.optim.Optimizer], optional):
            Custom optimizers which are a drop in replacements for
            standard PyTorch optimizers. This should be the Class and not the initialized object

        optimizer_params (Optional[Dict], optional): The parameters to initialize the custom optimizer.

        train_sampler (Optional[torch.utils.data.Sampler], optional):
            Custom PyTorch batch samplers which will be passed
            to the DataLoaders. Useful for dealing with imbalanced data and other custom batching strategies

        target_transform (Optional[Union[TransformerMixin, Tuple(Callable)]], optional):
            If provided, applies the transform to the target before modelling and inverse the transform during
            prediction. The parameter can either be a sklearn Transformer
            which has an inverse_transform method, or a tuple of callables (transform_func, inverse_transform_func)

        max_epochs (Optional[int]): Overwrite maximum number of epochs to be run. Defaults to None.

        min_epochs (Optional[int]): Overwrite minimum number of epochs to be run. Defaults to None.

        seed: (int): Random seed for reproducibility. Defaults to 42.

        callbacks (Optional[List[pl.Callback]], optional):
            List of callbacks to be used during training. Defaults to None.

        datamodule (Optional[TabularDatamodule], optional): The datamodule.
            If provided, will ignore the rest of the parameters like train, test etc and use the datamodule.
            Defaults to None.

        cache_data (str): Decides how to cache the data in the dataloader. If set to
            "memory", will cache in memory. If set to a valid path, will cache in that path. Defaults to "memory".

        handle_oom (bool): If True, will try to handle OOM errors elegantly. Defaults to True.

    Returns:
        pl.Trainer: The PyTorch Lightning Trainer instance

    """
    assert self.config.task != "ssl", (
        "`fit` is not valid for SSL task. Please use `pretrain` for" " semi-supervised learning"
    )
    if metrics is not None:
        assert len(metrics) == len(
            metrics_prob_inputs or []
        ), "The length of `metrics` and `metrics_prob_inputs` should be equal"
    seed = seed or self.config.seed
    if seed:
        seed_everything(seed)
    if datamodule is None:
        datamodule = self.prepare_dataloader(
            train,
            validation,
            train_sampler,
            target_transform,
            seed,
            cache_data,
        )
    else:
        if train is not None:
            warnings.warn(
                "train data and datamodule is provided."
                " Ignoring the train data and using the datamodule."
                " Set either one of them to None to avoid this warning."
            )
    model = self.prepare_model(
        datamodule,
        loss,
        metrics,
        metrics_prob_inputs,
        optimizer,
        optimizer_params or {},
    )

    return self.train(model, datamodule, callbacks, max_epochs, min_epochs, handle_oom)
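
The length check at the top of fit means that custom metrics must always be accompanied by a metrics_prob_inputs list of the same length. The sketch below isolates that check in a standalone helper so its behaviour can be seen without training anything. The helper name check_custom_metrics and the placeholder metric are hypothetical; only the assertion itself mirrors the source above.

```python
from typing import Callable, List, Optional


def check_custom_metrics(
    metrics: Optional[List[Callable]],
    metrics_prob_inputs: Optional[List[bool]],
) -> None:
    """Standalone copy of the length check performed at the start of fit."""
    if metrics is not None:
        assert len(metrics) == len(
            metrics_prob_inputs or []
        ), "The length of `metrics` and `metrics_prob_inputs` should be equal"


def placeholder_metric(y_hat, y):
    """Hypothetical metric with the expected metric_fn(y_hat, y) signature."""
    return 0.0


# No custom metrics: nothing to validate.
check_custom_metrics(None, None)

# One flag per metric: passes.
check_custom_metrics([placeholder_metric, placeholder_metric], [False, True])

# Metrics without flags: the check fails.
try:
    check_custom_metrics([placeholder_metric], None)
    error_message = None
except AssertionError as err:
    error_message = str(err)
```

In other words, leaving metrics_prob_inputs at its default of None is only valid when metrics is also None.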

pytorch_tabular.TabularModel.cross_validate(cv, train, metric=None, return_oof=False, groups=None, verbose=True, reset_datamodule=True, handle_oom=True, **kwargs)

Cross validate the model.

Parameters:

  • cv: Optional[Union[int, Iterable, BaseCrossValidator]]: Determines the cross-validation splitting strategy. Required. Possible inputs for cv are:
      • None, to use the default 5-fold cross validation (KFold for regression and StratifiedKFold for classification),
      • an integer, to specify the number of folds in a (Stratified)KFold,
      • an iterable yielding (train, test) splits as arrays of indices,
      • a scikit-learn CV splitter.
  • train: DataFrame: The training data with labels. Required.
  • metric: Optional[Union[str, Callable]]: The metric to be used for evaluation. If None, will use the first metric in the config. If a str is provided, will use that metric from the defined ones. If a callable is provided, will use that function as the metric. The callable is expected to be of the form metric(y_true, y_pred). For classification problems, y_pred is a dataframe with the probabilities for each class (<class>_probability) and a final prediction (prediction). For regression, it is a dataframe with a final prediction (<target>_prediction). Defaults to None.
  • return_oof: bool: If True, will return the out-of-fold predictions along with the cross validation results. Defaults to False.
  • groups: Optional[Union[str, ndarray]]: Group labels for the samples used while splitting. If provided, will be used as the groups argument for the split method of the cross validator. If the input is a str, will use the column in the input dataframe with that name as the group labels. If the input is array-like, will use that as the group. The only constraint is that the group labels should have the same size as the number of rows in the input dataframe. Defaults to None.
  • verbose: bool: If True, will log the results. Defaults to True.
  • reset_datamodule: bool: If True, will reset the datamodule for each iteration. It will be slower because the transformations are fit again for each fold. If False, we take the approximation that once the transformations are fit on the first fold, they are valid for all the other folds. Defaults to True.
  • handle_oom: bool: If True, will handle out of memory errors elegantly. Defaults to True.
  • **kwargs: Additional keyword arguments to be passed to the fit method of the model. Defaults to {}.

Returns:

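  • The docstring describes the return value as a DataFrame with the cross validation results. Note that the source listed below returns the tuple (cv_metrics, oof_preds): the list of per-fold scores and the list of out-of-fold predictions, the latter being populated only when return_oof is True or a callable metric is used.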

Source code in src/pytorch_tabular/tabular_model.py
def cross_validate(
    self,
    cv: Optional[Union[int, Iterable, BaseCrossValidator]],
    train: DataFrame,
    metric: Optional[Union[str, Callable]] = None,
    return_oof: bool = False,
    groups: Optional[Union[str, np.ndarray]] = None,
    verbose: bool = True,
    reset_datamodule: bool = True,
    handle_oom: bool = True,
    **kwargs,
):
    """Cross validate the model.

    Args:
        cv (Optional[Union[int, Iterable, BaseCrossValidator]]): Determines the cross-validation splitting strategy.
            Possible inputs for cv are:

            - None, to use the default 5-fold cross validation (KFold for
            Regression and StratifiedKFold for Classification),
            - integer, to specify the number of folds in a (Stratified)KFold,
            - An iterable yielding (train, test) splits as arrays of indices.
            - A scikit-learn CV splitter.

        train (DataFrame): The training data with labels

        metric (Optional[Union[str, Callable]], optional): The metrics to be used for evaluation.
            If None, will use the first metric in the config. If str is provided, will use that
            metric from the defined ones. If callable is provided, will use that function as the
            metric. We expect callable to be of the form `metric(y_true, y_pred)`. For classification
            problems, The `y_pred` is a dataframe with the probabilities for each class
            (<class>_probability) and a final prediction(prediction). And for Regression, it is a
            dataframe with a final prediction (<target>_prediction).
            Defaults to None.

        return_oof (bool, optional): If True, will return the out-of-fold predictions
            along with the cross validation results. Defaults to False.

        groups (Optional[Union[str, np.ndarray]], optional): Group labels for
            the samples used while splitting. If provided, will be used as the
            `groups` argument for the `split` method of the cross validator.
            If input is str, will use the column in the input dataframe with that
            name as the group labels. If input is array-like, will use that as the
            group. The only constraint is that the group labels should have the
            same size as the number of rows in the input dataframe. Defaults to None.

        verbose (bool, optional): If True, will log the results. Defaults to True.

        reset_datamodule (bool, optional): If True, will reset the datamodule for each iteration.
            It will be slower because we will be fitting the transformations for each fold.
            If False, we take an approximation that once the transformations are fit on the first
            fold, they will be valid for all the other folds. Defaults to True.

        handle_oom (bool, optional): If True, will handle out of memory errors elegantly
        **kwargs: Additional keyword arguments to be passed to the `fit` method of the model.

    Returns:
        DataFrame: The dataframe with the cross validation results

    """
    cv = self._check_cv(cv)
    prep_dl_kwargs, prep_model_kwargs, train_kwargs = self._split_kwargs(kwargs)
    is_callable_metric = False
    if metric is None:
        metric = "test_" + self.config.metrics[0]
    elif isinstance(metric, str):
        metric = metric if metric.startswith("test_") else "test_" + metric
    elif callable(metric):
        is_callable_metric = True

    if isinstance(cv, BaseCrossValidator):
        it = enumerate(cv.split(train, y=train[self.config.target], groups=groups))
    else:
        # when iterable is directly passed
        it = enumerate(cv)
    cv_metrics = []
    datamodule = None
    model = None
    oof_preds = []
    for fold, (train_idx, val_idx) in it:
        if verbose:
            logger.info(f"Running Fold {fold+1}/{cv.get_n_splits()}")
        # train_fold = train.iloc[train_idx]
        # val_fold = train.iloc[val_idx]
        if reset_datamodule:
            datamodule = None
        if datamodule is None:
            # Initialize datamodule and model in the first fold
            # uses train data from this fold to fit all transformers
            datamodule = self.prepare_dataloader(
                train=train.iloc[train_idx], validation=train.iloc[val_idx], seed=42, **prep_dl_kwargs
            )
            model = self.prepare_model(datamodule, **prep_model_kwargs)
        else:
            # Preprocess the current fold data using the fitted transformers and save in datamodule
            datamodule.train, _ = datamodule.preprocess_data(train.iloc[train_idx], stage="inference")
            datamodule.validation, _ = datamodule.preprocess_data(train.iloc[val_idx], stage="inference")

        # Train the model
        handle_oom = train_kwargs.pop("handle_oom", handle_oom)
        self.train(model, datamodule, handle_oom=handle_oom, **train_kwargs)
        if return_oof or is_callable_metric:
            preds = self.predict(train.iloc[val_idx], include_input_features=False)
            oof_preds.append(preds)
        if is_callable_metric:
            cv_metrics.append(metric(train.iloc[val_idx][self.config.target], preds))
        else:
            result = self.evaluate(train.iloc[val_idx], verbose=False)
            cv_metrics.append(result[0][metric])
        if verbose:
            logger.info(f"Fold {fold+1}/{cv.get_n_splits()} score: {cv_metrics[-1]}")
        self.model.reset_weights()
    return cv_metrics, oof_preds
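
As a sketch of the callable form metric(y_true, y_pred) described above, the function below computes a root mean squared error for a regression target. It assumes a hypothetical target column called price, and it assumes that y_true can be indexed by the target column name and y_pred by <target>_prediction. Plain dictionaries stand in for the dataframes here so that the example has no dependencies.

```python
from math import sqrt

TARGET = "price"  # hypothetical target column name


def rmse_metric(y_true, y_pred) -> float:
    """Custom metric in the metric(y_true, y_pred) form expected by cross_validate."""
    actual = list(y_true[TARGET])
    predicted = list(y_pred[f"{TARGET}_prediction"])
    squared_errors = [(a - p) ** 2 for a, p in zip(actual, predicted)]
    return sqrt(sum(squared_errors) / len(squared_errors))


# Dictionaries standing in for the target and prediction dataframes of one fold.
y_true = {"price": [10.0, 12.0, 14.0]}
y_pred = {"price_prediction": [11.0, 12.0, 13.0]}

score = rmse_metric(y_true, y_pred)
```

Such a function would then be passed as the metric argument, for example cross_validate(cv=5, train=train, metric=rmse_metric), where train is your own dataframe.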

Low-Level API

The low-level API is more flexible and allows you to write more complicated logic like cross validation, ensembling, etc. It is more verbose and requires you to write more code, but it gives the user more control.

The fit method is split into three sub-methods:

  1. prepare_dataloader

  2. prepare_model

  3. train
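
The three steps above can be sketched as a small wrapper that calls them in the same order as fit does. The wrapper name fit_in_steps is hypothetical, and RecordingModel is a stub that only records the calls so that the example runs without the library. With a real TabularModel you would pass that object instead.

```python
def fit_in_steps(tabular_model, train, validation=None, seed=42):
    """Run the low-level steps in order: dataloader, model, train."""
    # Step 1: fit the data transformations and build the datamodule.
    datamodule = tabular_model.prepare_dataloader(train=train, validation=validation, seed=seed)
    # Step 2: build the model, using what was inferred from the data.
    model = tabular_model.prepare_model(datamodule)
    # Step 3: train the prepared model on the prepared data.
    return tabular_model.train(model, datamodule)


class RecordingModel:
    """Hypothetical stand-in that records which methods were called."""

    def __init__(self):
        self.calls = []

    def prepare_dataloader(self, train, validation=None, seed=42):
        self.calls.append("prepare_dataloader")
        return "datamodule"

    def prepare_model(self, datamodule):
        self.calls.append("prepare_model")
        return "model"

    def train(self, model, datamodule):
        self.calls.append("train")
        return "trainer"


stub = RecordingModel()
result = fit_in_steps(stub, train=[[1.0, 0], [2.0, 1]])
```

Splitting the steps this way is what makes it possible to reuse one datamodule across several models, or to inspect the model before training starts.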

prepare_dataloader

This method is responsible for setting up the TabularDatamodule and returns the object. You can save this object using save_dataloader and load it later using load_datamodule to skip the data preparation step. This is useful when you are doing cross validation or ensembling.

pytorch_tabular.TabularModel.prepare_dataloader(train, validation=None, train_sampler=None, target_transform=None, seed=42, cache_data='memory')

Prepares the dataloaders for training and validation.

Parameters:

  • train: DataFrame: Training dataframe. Required.
  • validation: Optional[DataFrame]: If provided, this dataframe is used as the validation set while training, for early stopping and logging. If left empty, 20% of the training data is used as validation. Defaults to None.
  • train_sampler: Optional[Sampler]: Custom PyTorch batch sampler that is passed to the DataLoaders. Useful for dealing with imbalanced data and other custom batching strategies. Defaults to None.
  • target_transform: Optional[Union[TransformerMixin, Tuple(Callable)]]: If provided, applies the transform to the target before modelling and inverts the transform during prediction. The parameter can either be an sklearn transformer that has an inverse_transform method, or a tuple of callables (transform_func, inverse_transform_func). Defaults to None.
  • seed: Optional[int]: Random seed for reproducibility. Defaults to 42.
  • cache_data: str: Decides how to cache the data in the dataloader. If set to "memory", will cache in memory. If set to a valid path, will cache in that path. Defaults to "memory".

Returns: TabularDatamodule: The prepared datamodule

Source code in src/pytorch_tabular/tabular_model.py
def prepare_dataloader(
    self,
    train: DataFrame,
    validation: Optional[DataFrame] = None,
    train_sampler: Optional[torch.utils.data.Sampler] = None,
    target_transform: Optional[Union[TransformerMixin, Tuple]] = None,
    seed: Optional[int] = 42,
    cache_data: str = "memory",
) -> TabularDatamodule:
    """Prepares the dataloaders for training and validation.

    Args:
        train (DataFrame): Training Dataframe

        validation (Optional[DataFrame], optional):
            If provided, will use this dataframe as the validation while training.
            Used in Early Stopping and Logging. If left empty, will use 20% of Train data as validation.
            Defaults to None.

        train_sampler (Optional[torch.utils.data.Sampler], optional):
            Custom PyTorch batch samplers which will be passed to the DataLoaders.
            Useful for dealing with imbalanced data and other custom batching strategies

        target_transform (Optional[Union[TransformerMixin, Tuple(Callable)]], optional):
            If provided, applies the transform to the target before modelling and inverse the transform during
            prediction. The parameter can either be a sklearn Transformer which has an inverse_transform method, or
            a tuple of callables (transform_func, inverse_transform_func)

        seed (Optional[int], optional): Random seed for reproducibility. Defaults to 42.

        cache_data (str): Decides how to cache the data in the dataloader. If set to
            "memory", will cache in memory. If set to a valid path, will cache in that path. Defaults to "memory".
    Returns:
        TabularDatamodule: The prepared datamodule

    """
    if self.verbose:
        logger.info("Preparing the DataLoaders")
    target_transform = self._check_and_set_target_transform(target_transform)

    datamodule = TabularDatamodule(
        train=train,
        validation=validation,
        config=self.config,
        target_transform=target_transform,
        train_sampler=train_sampler,
        seed=seed,
        cache_data=cache_data,
        verbose=self.verbose,
    )
    datamodule.prepare_data()
    datamodule.setup("fit")
    return datamodule
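
The tuple form of target_transform pairs a forward function with its inverse. The sketch below shows the idea with a log transform and checks that the pair round-trips. Scalar functions from the math module are used only to keep the example dependency-free. With real data you would most likely pass functions that operate on whole arrays (that is an assumption, not something this page states). The round_trip helper is hypothetical.

```python
import math

# (transform_func, inverse_transform_func)
target_transform = (math.log1p, math.expm1)


def round_trip(values, transform_pair):
    """Apply the forward transform, then the inverse, to each value."""
    transform, inverse = transform_pair
    transformed = [transform(v) for v in values]
    restored = [inverse(t) for t in transformed]
    return transformed, restored


# Made-up target values for illustration.
targets = [0.0, 9.0, 99.0]
transformed, restored = round_trip(targets, target_transform)

# The inverse must undo the forward transform, otherwise predictions
# would be returned on the wrong scale.
max_error = max(abs(r - t) for r, t in zip(restored, targets))
```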

prepare_model

This method is responsible for setting up and initializing the model and takes in the prepared datamodule as an input. It returns the model instance.

pytorch_tabular.TabularModel.prepare_model(datamodule, loss=None, metrics=None, metrics_prob_inputs=None, optimizer=None, optimizer_params=None)

Prepares the model for training.

Parameters:

  • datamodule: TabularDatamodule: The datamodule. Required.
  • loss: Optional[Module]: Custom loss function that is not in the standard PyTorch library. Defaults to None.
  • metrics: Optional[List[Callable]]: Custom metric functions (callables) that have the signature metric_fn(y_hat, y) and work on torch tensor inputs. Defaults to None.
  • metrics_prob_inputs: Optional[List[bool]]: This is a mandatory parameter for classification metrics. If the metric function requires probabilities as inputs, set this to True. The length of the list should be equal to the number of metrics. Defaults to None.
  • optimizer: Optional[Optimizer]: Custom optimizer that is a drop-in replacement for the standard PyTorch optimizers. This should be the class and not the initialized object. Defaults to None.
  • optimizer_params: Optional[Dict]: The parameters to initialize the custom optimizer. Defaults to None.

Returns:

  • BaseModel: The prepared model

Source code in src/pytorch_tabular/tabular_model.py
def prepare_model(
    self,
    datamodule: TabularDatamodule,
    loss: Optional[torch.nn.Module] = None,
    metrics: Optional[List[Callable]] = None,
    metrics_prob_inputs: Optional[List[bool]] = None,
    optimizer: Optional[torch.optim.Optimizer] = None,
    optimizer_params: Dict = None,
) -> BaseModel:
    """Prepares the model for training.

    Args:
        datamodule (TabularDatamodule): The datamodule

        loss (Optional[torch.nn.Module], optional): Custom Loss functions which are not in standard pytorch library

        metrics (Optional[List[Callable]], optional): Custom metric functions(Callable) which has the
            signature metric_fn(y_hat, y) and works on torch tensor inputs

        metrics_prob_inputs (Optional[List[bool]], optional): This is a mandatory parameter for
            classification metrics. If the metric function requires probabilities as inputs, set this to True.
            The length of the list should be equal to the number of metrics. Defaults to None.

        optimizer (Optional[torch.optim.Optimizer], optional):
            Custom optimizers which are a drop in replacements for standard PyTorch optimizers.
            This should be the Class and not the initialized object

        optimizer_params (Optional[Dict], optional): The parameters to initialize the custom optimizer.

    Returns:
        BaseModel: The prepared model

    """
    if self.verbose:
        logger.info(f"Preparing the Model: {self.config._model_name}")
    # Fetching the config as some data specific configs have been added in the datamodule
    self.inferred_config = self._read_parse_config(datamodule.update_config(self.config), InferredConfig)
    model = self.model_callable(
        self.config,
        custom_loss=loss,  # Unused in SSL tasks
        custom_metrics=metrics,  # Unused in SSL tasks
        custom_metrics_prob_inputs=metrics_prob_inputs,  # Unused in SSL tasks
        custom_optimizer=optimizer,
        custom_optimizer_params=optimizer_params or {},
        inferred_config=self.inferred_config,
    )
    # Data Aware Initialization(for the models that need it)
    model.data_aware_initialization(datamodule)
    if self.model_state_dict_path is not None:
        self._load_weights(model, self.model_state_dict_path)
    if self.track_experiment and self.config.log_target == "wandb":
        self.logger.watch(model, log=self.config.exp_watch, log_freq=self.config.exp_log_freq)
    return model

train

This method is responsible for training the model and takes in the prepared datamodule and model as inputs. It returns the PyTorch Lightning Trainer instance.

pytorch_tabular.TabularModel.train(model, datamodule, callbacks=None, max_epochs=None, min_epochs=None, handle_oom=True)

Trains the model.

Parameters:

  • model: LightningModule: The PyTorch Lightning model to be trained. Required.
  • datamodule: TabularDatamodule: The datamodule. Required.
  • callbacks: Optional[List[Callback]]: List of callbacks to be used during training. Defaults to None.
  • max_epochs: Optional[int]: Overwrites the maximum number of epochs to be run. Defaults to None.
  • min_epochs: Optional[int]: Overwrites the minimum number of epochs to be run. Defaults to None.
  • handle_oom: bool: If True, will try to handle OOM errors elegantly. Defaults to True.

Returns:

  • pl.Trainer: The PyTorch Lightning Trainer instance

Source code in src/pytorch_tabular/tabular_model.py
def train(
    self,
    model: pl.LightningModule,
    datamodule: TabularDatamodule,
    callbacks: Optional[List[pl.Callback]] = None,
    max_epochs: int = None,
    min_epochs: int = None,
    handle_oom: bool = True,
) -> pl.Trainer:
    """Trains the model.

    Args:
        model (pl.LightningModule): The PyTorch Lightning model to be trained.

        datamodule (TabularDatamodule): The datamodule

        callbacks (Optional[List[pl.Callback]], optional):
            List of callbacks to be used during training. Defaults to None.

        max_epochs (Optional[int]): Overwrite maximum number of epochs to be run. Defaults to None.

        min_epochs (Optional[int]): Overwrite minimum number of epochs to be run. Defaults to None.

        handle_oom (bool): If True, will try to handle OOM errors elegantly. Defaults to True.

    Returns:
        pl.Trainer: The PyTorch Lightning Trainer instance

    """
    self._prepare_for_training(model, datamodule, callbacks, max_epochs, min_epochs)
    train_loader, val_loader = (
        self.datamodule.train_dataloader(),
        self.datamodule.val_dataloader(),
    )
    self.model.train()
    if self.config.auto_lr_find and (not self.config.fast_dev_run):
        if self.verbose:
            logger.info("Auto LR Find Started")
        with OutOfMemoryHandler(handle_oom=handle_oom) as oom_handler:
            result = Tuner(self.trainer).lr_find(
                self.model,
                train_dataloaders=train_loader,
                val_dataloaders=val_loader,
            )
        if oom_handler.oom_triggered:
            raise OOMException(
                "OOM detected during LR Find. Try reducing your batch_size or the"
                " model parameters." + "/n" + "Original Error: " + oom_handler.oom_msg
            )
        if self.verbose:
            logger.info(
                f"Suggested LR: {result.suggestion()}. For plot and detailed"
                " analysis, use `find_learning_rate` method."
            )
        self.model.reset_weights()
        # Parameters in models needs to be initialized again after LR find
        self.model.data_aware_initialization(self.datamodule)
    self.model.train()
    if self.verbose:
        logger.info("Training Started")
    with OutOfMemoryHandler(handle_oom=handle_oom) as oom_handler:
        self.trainer.fit(self.model, train_loader, val_loader)
    if oom_handler.oom_triggered:
        raise OOMException(
            "OOM detected during Training. Try reducing your batch_size or the"
            " model parameters."
            "\n" + "Original Error: " + oom_handler.oom_msg
        )
    self._is_fitted = True
    if self.verbose:
        logger.info("Training the model completed")
    if self.config.load_best:
        self.load_best_model()
    return self.trainer
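
The train source above wraps both the LR finder and trainer.fit in an OutOfMemoryHandler context manager and checks oom_triggered once the block exits. The following is a minimal sketch of that pattern, not the library's implementation: SketchOOMHandler, SketchOOMException, and fit_with_oom_guard are hypothetical names, and a plain MemoryError stands in for a CUDA out-of-memory error.

```python
class SketchOOMException(Exception):
    """Hypothetical stand-in for the library's OOMException."""


class SketchOOMHandler:
    """Hypothetical stand-in for OutOfMemoryHandler (assumed behaviour)."""

    def __init__(self, handle_oom: bool = True):
        self.handle_oom = handle_oom
        self.oom_triggered = False
        self.oom_msg = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Swallow the error only when it is an OOM and handling is enabled.
        if self.handle_oom and exc_type is not None and issubclass(exc_type, MemoryError):
            self.oom_triggered = True
            self.oom_msg = str(exc_value)
            return True  # suppress, so the caller can raise a friendlier error
        return False  # anything else propagates unchanged


def fit_with_oom_guard(fit_fn, handle_oom: bool = True) -> None:
    """Run fit_fn and turn a swallowed OOM into a descriptive exception."""
    with SketchOOMHandler(handle_oom=handle_oom) as oom_handler:
        fit_fn()
    if oom_handler.oom_triggered:
        raise SketchOOMException(
            "OOM detected during Training. Try reducing your batch_size or the"
            " model parameters.\nOriginal Error: " + oom_handler.oom_msg
        )


def failing_fit():
    # Simulated training step that runs out of memory.
    raise MemoryError("out of memory (simulated)")


try:
    fit_with_oom_guard(failing_fit)
except SketchOOMException as err:
    message = str(err)
```

With handle_oom=False the sketch lets the original MemoryError propagate, which is the behaviour one would expect when the caller wants to deal with the raw error.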

Training API (Self-Supervised Learning)

Self-supervised learning uses a separate API because the process runs in stages: pretraining, creating a finetune model, and finetuning.

  1. pytorch_tabular.TabularModel.pretrain: This method is responsible for pretraining the model. It takes in the input dataframes and other parameters to pretrain on the provided data.
  2. pytorch_tabular.TabularModel.create_finetune_model: To use the pretrained model for finetuning, we need to create a new model with the pretrained weights. This method creates that finetune model. It takes in the pretrained model and returns a finetune model. The returned object is a separate instance of TabularModel and can be used to finetune the model.
  3. pytorch_tabular.TabularModel.finetune: This method is responsible for finetuning the model and can only be used with a model created through create_finetune_model. It takes in the input dataframes and other parameters to finetune on the provided data.

Note

The dataframes passed to pretrain need not have the target column; even if a target column is defined in DataConfig, it is ignored during pretraining. The dataframes passed to finetune, however, must have the target column.
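
The three steps above can be sketched as one small function. This is an illustrative sketch, not an official recipe: run_ssl_workflow and the dataframe argument names are hypothetical, ssl_model is assumed to be a TabularModel built with a self-supervised model config, and only method and parameter names from the signatures on this page are used. An empty head_config relies on the documented fallback to the default linear head.

```python
def run_ssl_workflow(ssl_model, unlabeled_df, labeled_train_df, labeled_val_df):
    """Pretrain, build a finetune model, and finetune it.

    `ssl_model` is assumed to be a TabularModel configured for the "ssl" task.
    """
    # Stage 1: pretraining; the target column is not needed here.
    ssl_model.pretrain(train=unlabeled_df)

    # Stage 2: wrap the pretrained backbone in a new TabularModel with a task head.
    finetune_model = ssl_model.create_finetune_model(
        task="classification",
        head="LinearHead",
        head_config={},  # empty dict: documented to fall back to the default linear head
        train=labeled_train_df,
        validation=labeled_val_df,
    )

    # Stage 3: finetune; freeze_backbone=True keeps the pretrained weights fixed.
    finetune_model.finetune(freeze_backbone=True)
    return finetune_model
```

If the target column was not set in the DataConfig used for pretraining, the target argument of create_finetune_model would also need to be passed.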

pytorch_tabular.TabularModel.pretrain(train, validation=None, optimizer=None, optimizer_params=None, max_epochs=None, min_epochs=None, seed=42, callbacks=None, datamodule=None, cache_data='memory')

The pretraining method, which takes in the data and triggers the training.

Parameters:

Name Type Description Default
train DataFrame

Training Dataframe

required
validation Optional[DataFrame]

If provided, will use this dataframe as the validation while training. Used in Early Stopping and Logging. If left empty, will use 20% of Train data as validation. Defaults to None.

None
optimizer Optional[Optimizer]

Custom optimizers which are a drop in replacements for standard PyTorch optimizers. This should be the Class and not the initialized object

None
optimizer_params Optional[Dict]

The parameters to initialize the custom optimizer.

None
max_epochs Optional[int]

Overwrite maximum number of epochs to be run. Defaults to None.

None
min_epochs Optional[int]

Overwrite minimum number of epochs to be run. Defaults to None.

None
seed Optional[int]

Random seed for reproducibility. Defaults to 42.

42
callbacks Optional[List[Callback]]

List of callbacks to be used during training. Defaults to None.

None
datamodule Optional[TabularDatamodule]

The datamodule. If provided, will ignore the rest of the parameters like train, test etc. and use the datamodule. Defaults to None.

None
cache_data str

Decides how to cache the data in the dataloader. If set to "memory", will cache in memory. If set to a valid path, will cache in that path. Defaults to "memory".

'memory'

Returns: pl.Trainer: The PyTorch Lightning Trainer instance

Source code in src/pytorch_tabular/tabular_model.py
def pretrain(
    self,
    train: Optional[DataFrame],
    validation: Optional[DataFrame] = None,
    optimizer: Optional[torch.optim.Optimizer] = None,
    optimizer_params: Dict = None,
    # train_sampler: Optional[torch.utils.data.Sampler] = None,
    max_epochs: Optional[int] = None,
    min_epochs: Optional[int] = None,
    seed: Optional[int] = 42,
    callbacks: Optional[List[pl.Callback]] = None,
    datamodule: Optional[TabularDatamodule] = None,
    cache_data: str = "memory",
) -> pl.Trainer:
    """The pretrained method which takes in the data and triggers the training.

    Args:
        train (DataFrame): Training Dataframe

        validation (Optional[DataFrame], optional): If provided, will use this dataframe as the validation while
            training. Used in Early Stopping and Logging. If left empty, will use 20% of Train data as validation.
            Defaults to None.

        optimizer (Optional[torch.optim.Optimizer], optional): Custom optimizers which are a drop in replacements
            for standard PyTorch optimizers. This should be the Class and not the initialized object

        optimizer_params (Optional[Dict], optional): The parameters to initialize the custom optimizer.

        max_epochs (Optional[int]): Overwrite maximum number of epochs to be run. Defaults to None.

        min_epochs (Optional[int]): Overwrite minimum number of epochs to be run. Defaults to None.

        seed: (int): Random seed for reproducibility. Defaults to 42.

        callbacks (Optional[List[pl.Callback]], optional): List of callbacks to be used during training.
            Defaults to None.

        datamodule (Optional[TabularDatamodule], optional): The datamodule. If provided, will ignore the rest of the
            parameters like train, test etc. and use the datamodule. Defaults to None.

        cache_data (str): Decides how to cache the data in the dataloader. If set to
            "memory", will cache in memory. If set to a valid path, will cache in that path. Defaults to "memory".
    Returns:
        pl.Trainer: The PyTorch Lightning Trainer instance

    """
    assert self.config.task == "ssl", (
        f"`pretrain` is not valid for {self.config.task} task. Please use `fit`" " instead."
    )
    seed = seed or self.config.seed
    if seed:
        seed_everything(seed)
    if datamodule is None:
        datamodule = self.prepare_dataloader(
            train,
            validation,
            train_sampler=None,
            target_transform=None,
            seed=seed,
            cache_data=cache_data,
        )
    else:
        if train is not None:
            warnings.warn(
                "train data and datamodule is provided."
                " Ignoring the train data and using the datamodule."
                " Set either one of them to None to avoid this warning."
            )
    model = self.prepare_model(
        datamodule,
        optimizer,
        optimizer_params or {},
    )

    return self.train(model, datamodule, callbacks, max_epochs, min_epochs)
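
One detail of the source above is worth spelling out: seed = seed or self.config.seed falls back to the configured seed whenever the passed value is falsy, which includes 0 as well as None. A minimal sketch of that expression follows; resolve_seed is a hypothetical helper and config_seed stands in for self.config.seed.

```python
from typing import Optional


def resolve_seed(seed: Optional[int], config_seed: Optional[int]) -> Optional[int]:
    # Mirrors `seed = seed or self.config.seed` from the source above.
    return seed or config_seed


explicit = resolve_seed(7, 42)      # a truthy seed is used as given
from_none = resolve_seed(None, 42)  # None falls back to the config seed
from_zero = resolve_seed(0, 42)     # 0 is falsy, so it also falls back
```

Going by the listing, passing seed=0 therefore does not seed with 0, and if both values are falsy the subsequent `if seed:` check skips seed_everything altogether.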

pytorch_tabular.TabularModel.create_finetune_model(task, head, head_config, train, validation=None, train_sampler=None, target_transform=None, target=None, optimizer_config=None, trainer_config=None, experiment_config=None, loss=None, metrics=None, metrics_prob_input=None, metrics_params=None, optimizer=None, optimizer_params=None, learning_rate=None, target_range=None, seed=42)

Creates a new TabularModel model using the pretrained weights and the new task and head.

Parameters:

Name Type Description Default
task str

The task to be performed. One of "regression", "classification"

required
head str

The head to be used for the model. Should be one of the heads defined in pytorch_tabular.models.common.heads. Defaults to LinearHead. Choices are: [None,LinearHead,MixtureDensityHead].

required
head_config Dict

The config as a dict which defines the head. If left empty, will be initialized as default linear head.

required
train DataFrame

The training data with labels

required
validation Optional[DataFrame]

The validation data with labels. Defaults to None.

None
train_sampler Optional[Sampler]

If provided, will be used as a batch sampler for training. Defaults to None.

None
target_transform Optional[Union[TransformerMixin, Tuple]]

If provided, will be used to transform the target before training and inverse transform the predictions.

None
target Optional[str]

The target column name(s), if not provided in the initial pretraining stage. Note that the source listing for this method asserts that the value is a list of strings. Defaults to None.

None
optimizer_config Optional[OptimizerConfig]

If provided, will redefine the optimizer for fine-tuning stage. Defaults to None.

None
trainer_config Optional[TrainerConfig]

If provided, will redefine the trainer for fine-tuning stage. Defaults to None.

None
experiment_config Optional[ExperimentConfig]

If provided, will redefine the experiment for fine-tuning stage. Defaults to None.

None
loss Optional[Module]

If provided, will be used as the loss function for the fine-tuning. By default, it is MSELoss for regression and CrossEntropyLoss for classification.

None
metrics Optional[List[Callable]]

List of metrics (either callables or str) to be used for the fine-tuning stage. If str, it should be one of the functional metrics implemented in torchmetrics.functional. Defaults to None.

None
metrics_prob_input Optional[List[bool]]

Mandatory for classification metrics. Defines whether the input to the metric function is the probability or the class. Length should be the same as the number of metrics. Defaults to None.

None
metrics_params Optional[Dict]

The parameters for the metrics, in the same order as metrics. For example, f1_score for multi-class needs the parameter average to fully define the metric. Defaults to None.

None
optimizer Optional[Optimizer]

Custom optimizers which are a drop in replacements for standard PyTorch optimizers. If provided, the OptimizerConfig is ignored in favor of this. Defaults to None.

None
optimizer_params Dict

The parameters for the optimizer. Defaults to {}.

None
learning_rate Optional[float]

The learning rate to be used. Defaults to 1e-3.

None
target_range Optional[Tuple[float, float]]

The target range for the regression task. Is ignored for classification. Defaults to None.

None
seed Optional[int]

Random seed for reproducibility. Defaults to 42.

42

Returns: TabularModel (TabularModel): The new TabularModel model for fine-tuning

Source code in src/pytorch_tabular/tabular_model.py
def create_finetune_model(
    self,
    task: str,
    head: str,
    head_config: Dict,
    train: DataFrame,
    validation: Optional[DataFrame] = None,
    train_sampler: Optional[torch.utils.data.Sampler] = None,
    target_transform: Optional[Union[TransformerMixin, Tuple]] = None,
    target: Optional[str] = None,
    optimizer_config: Optional[OptimizerConfig] = None,
    trainer_config: Optional[TrainerConfig] = None,
    experiment_config: Optional[ExperimentConfig] = None,
    loss: Optional[torch.nn.Module] = None,
    metrics: Optional[List[Union[Callable, str]]] = None,
    metrics_prob_input: Optional[List[bool]] = None,
    metrics_params: Optional[Dict] = None,
    optimizer: Optional[torch.optim.Optimizer] = None,
    optimizer_params: Dict = None,
    learning_rate: Optional[float] = None,
    target_range: Optional[Tuple[float, float]] = None,
    seed: Optional[int] = 42,
):
    """Creates a new TabularModel model using the pretrained weights and the new task and head.

    Args:
        task (str): The task to be performed. One of "regression", "classification"

        head (str): The head to be used for the model. Should be one of the heads defined
            in `pytorch_tabular.models.common.heads`. Defaults to  LinearHead. Choices are:
            [`None`,`LinearHead`,`MixtureDensityHead`].

        head_config (Dict): The config as a dict which defines the head. If left empty,
            will be initialized as default linear head.

        train (DataFrame): The training data with labels

        validation (Optional[DataFrame], optional): The validation data with labels. Defaults to None.

        train_sampler (Optional[torch.utils.data.Sampler], optional): If provided, will be used as a batch sampler
            for training. Defaults to None.

        target_transform (Optional[Union[TransformerMixin, Tuple]], optional): If provided, will be used
            to transform the target before training and inverse transform the predictions.

        target (Optional[str], optional): The target column name if not provided in the initial pretraining stage.
            Defaults to None.

        optimizer_config (Optional[OptimizerConfig], optional):
            If provided, will redefine the optimizer for fine-tuning stage. Defaults to None.

        trainer_config (Optional[TrainerConfig], optional):
            If provided, will redefine the trainer for fine-tuning stage. Defaults to None.

        experiment_config (Optional[ExperimentConfig], optional):
            If provided, will redefine the experiment for fine-tuning stage. Defaults to None.

        loss (Optional[torch.nn.Module], optional):
            If provided, will be used as the loss function for the fine-tuning.
            By default, it is MSELoss for regression and CrossEntropyLoss for classification.

        metrics (Optional[List[Callable]], optional): List of metrics (either callables or str) to be used for the
            fine-tuning stage. If str, it should be one of the functional metrics implemented in
            ``torchmetrics.functional``. Defaults to None.

        metrics_prob_input (Optional[List[bool]], optional): Is a mandatory parameter for classification metrics
            This defines whether the input to the metric function is the probability or the class.
            Length should be same as the number of metrics. Defaults to None.

        metrics_params (Optional[Dict], optional): The parameters for the metrics in the same order as metrics.
            For eg. f1_score for multi-class needs a parameter `average` to fully define the metric.
            Defaults to None.

        optimizer (Optional[torch.optim.Optimizer], optional):
            Custom optimizers which are a drop in replacements for standard PyTorch optimizers. If provided,
            the OptimizerConfig is ignored in favor of this. Defaults to None.

        optimizer_params (Dict, optional): The parameters for the optimizer. Defaults to {}.

        learning_rate (Optional[float], optional): The learning rate to be used. Defaults to 1e-3.

        target_range (Optional[Tuple[float, float]], optional): The target range for the regression task.
            Is ignored for classification. Defaults to None.

        seed (Optional[int], optional): Random seed for reproducibility. Defaults to 42.
    Returns:
        TabularModel (TabularModel): The new TabularModel model for fine-tuning

    """
    config = self.config
    optimizer_params = optimizer_params or {}
    if target is None:
        assert (
            hasattr(config, "target") and config.target is not None
        ), "`target` cannot be None if it was not set in the initial `DataConfig`"
    else:
        assert isinstance(target, list), "`target` should be a list of strings"
        config.target = target
    config.task = task
    # Add code to update configs with newly provided ones
    if optimizer_config is not None:
        for key, value in optimizer_config.__dict__.items():
            config[key] = value
        if len(optimizer_params) > 0:
            config.optimizer_params = optimizer_params
        else:
            config.optimizer_params = {}
    if trainer_config is not None:
        for key, value in trainer_config.__dict__.items():
            config[key] = value
    if experiment_config is not None:
        for key, value in experiment_config.__dict__.items():
            config[key] = value
    else:
        if self.track_experiment:
            # Renaming the experiment run so that a different log is created for finetuning
            if self.verbose:
                logger.info("Renaming the experiment run for finetuning as" f" {config['run_name'] + '_finetuned'}")
            config["run_name"] = config["run_name"] + "_finetuned"

    datamodule = self.datamodule.copy(
        train=train,
        validation=validation,
        target_transform=target_transform,
        train_sampler=train_sampler,
        seed=seed,
        config_override={"target": target} if target is not None else {},
    )
    model_callable = _GenericModel
    inferred_config = OmegaConf.structured(datamodule._inferred_config)
    # Adding dummy attributes for compatibility. Not used because custom metrics are provided
    if not hasattr(config, "metrics"):
        config.metrics = "dummy"
    if not hasattr(config, "metrics_params"):
        config.metrics_params = {}
    if not hasattr(config, "metrics_prob_input"):
        config.metrics_prob_input = metrics_prob_input or [False]
    if metrics is not None:
        assert len(metrics) == len(metrics_params), "Number of metrics and metrics_params should be same"
        assert len(metrics) == len(metrics_prob_input), "Number of metrics and metrics_prob_input should be same"
        metrics = [getattr(torchmetrics.functional, m) if isinstance(m, str) else m for m in metrics]
    if task == "regression":
        loss = loss or torch.nn.MSELoss()
        if metrics is None:
            metrics = [torchmetrics.functional.mean_squared_error]
            metrics_params = [{}]
    elif task == "classification":
        loss = loss or torch.nn.CrossEntropyLoss()
        if metrics is None:
            metrics = [torchmetrics.functional.accuracy]
            metrics_params = [
                {
                    "task": "multiclass",
                    "num_classes": inferred_config.output_dim,
                    "top_k": 1,
                }
            ]
            metrics_prob_input = [False]
        else:
            for i, mp in enumerate(metrics_params):
                # For classification task, output_dim == number of classses
                metrics_params[i]["task"] = mp.get("task", "multiclass")
                metrics_params[i]["num_classes"] = mp.get("num_classes", inferred_config.output_dim)
                metrics_params[i]["top_k"] = mp.get("top_k", 1)
    else:
        raise ValueError(f"Task {task} not supported")
    # Forming partial callables using metrics and metric params
    metrics = [partial(m, **mp) for m, mp in zip(metrics, metrics_params)]
    self.model.mode = "finetune"
    if learning_rate is not None:
        config.learning_rate = learning_rate
    config.target_range = target_range
    model_args = {
        "backbone": self.model,
        "head": head,
        "head_config": head_config,
        "config": config,
        "inferred_config": inferred_config,
        "custom_loss": loss,
        "custom_metrics": metrics,
        "custom_metrics_prob_inputs": metrics_prob_input,
        "custom_optimizer": optimizer,
        "custom_optimizer_params": optimizer_params,
    }
    # Initializing with default metrics, losses, and optimizers. Will revert once initialized
    model = model_callable(
        **model_args,
    )
    tabular_model = TabularModel(config=config, verbose=self.verbose)
    tabular_model.model = model
    tabular_model.datamodule = datamodule
    # Setting a flag to identify this as a fine-tune model
    tabular_model._is_finetune_model = True
    return tabular_model
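
The classification branch of the source above fills in default keyword arguments for each metric and then binds them with functools.partial. The sketch below reproduces that step in isolation under stated assumptions: toy_accuracy is a hypothetical metric with a torchmetrics-like keyword signature, build_classification_metrics is a hypothetical helper, and the sketch copies each parameter dict instead of mutating it in place as the source does.

```python
from functools import partial
from typing import Callable, Dict, List


def toy_accuracy(preds, target, task="binary", num_classes=None, top_k=1):
    """Hypothetical metric; the extra keywords are accepted but unused here."""
    correct = sum(int(p == t) for p, t in zip(preds, target))
    return correct / len(target)


def build_classification_metrics(
    metrics: List[Callable], metrics_params: List[Dict], num_classes: int
) -> List[Callable]:
    filled = []
    for mp in metrics_params:
        mp = dict(mp)  # copy, so the caller's dicts stay untouched
        # Same defaults the source applies for classification tasks.
        mp.setdefault("task", "multiclass")
        mp.setdefault("num_classes", num_classes)
        mp.setdefault("top_k", 1)
        filled.append(mp)
    # One partial per metric, with its keyword arguments already bound.
    return [partial(m, **mp) for m, mp in zip(metrics, filled)]


metric_fns = build_classification_metrics([toy_accuracy], [{}], num_classes=3)
score = metric_fns[0]([0, 2, 1, 1], [0, 2, 2, 1])  # 3 of 4 predictions match
```

Because the source calls len(metrics_params) and len(metrics_prob_input) whenever metrics is given, all three should be passed together as lists of equal length; although metrics_params is annotated as a Dict, the listing iterates over it as a list of dicts.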

pytorch_tabular.TabularModel.finetune(max_epochs=None, min_epochs=None, callbacks=None, freeze_backbone=False)

Finetunes the model on the provided data.

Parameters:

Name Type Description Default
max_epochs Optional[int]

The maximum number of epochs to train for. Defaults to None.

None
min_epochs Optional[int]

The minimum number of epochs to train for. Defaults to None.

None
callbacks Optional[List[Callback]]

If provided, will be added to the callbacks for Trainer. Defaults to None.

None
freeze_backbone bool

If True, will freeze the backbone by turning off gradients. Defaults to False, which means the pretrained weights are also further tuned during fine-tuning.

False

Returns:

Type Description
Trainer

pl.Trainer: The trainer object

Source code in src/pytorch_tabular/tabular_model.py
def finetune(
    self,
    max_epochs: Optional[int] = None,
    min_epochs: Optional[int] = None,
    callbacks: Optional[List[pl.Callback]] = None,
    freeze_backbone: bool = False,
) -> pl.Trainer:
    """Finetunes the model on the provided data.

    Args:
        max_epochs (Optional[int], optional): The maximum number of epochs to train for. Defaults to None.

        min_epochs (Optional[int], optional): The minimum number of epochs to train for. Defaults to None.

        callbacks (Optional[List[pl.Callback]], optional): If provided, will be added to the callbacks for Trainer.
            Defaults to None.

        freeze_backbone (bool, optional): If True, will freeze the backbone by turning off gradients.
            Defaults to False, which means the pretrained weights are also further tuned during fine-tuning.

    Returns:
        pl.Trainer: The trainer object

    """
    assert self._is_finetune_model, (
        "finetune() can only be called on a finetune model created using" " `TabularModel.create_finetune_model()`"
    )
    seed_everything(self.config.seed)
    if freeze_backbone:
        for param in self.model.backbone.parameters():
            param.requires_grad = False
    return self.train(
        self.model,
        self.datamodule,
        callbacks=callbacks,
        max_epochs=max_epochs,
        min_epochs=min_epochs,
    )

Model Evaluation

pytorch_tabular.TabularModel.predict(test, quantiles=[0.25, 0.5, 0.75], n_samples=100, ret_logits=False, include_input_features=False, device=None, progress_bar=None, test_time_augmentation=False, num_tta=5, alpha_tta=0.1, aggregate_tta='mean', tta_seed=42)

Uses the trained model to predict on new data and return as a dataframe.

Parameters:

Name Type Description Default
test DataFrame

The new dataframe with the features defined during training

required
quantiles Optional[List]

For probabilistic models like Mixture Density Networks, this specifies the different quantiles to be extracted apart from the central_tendency and added to the dataframe. For other models it is ignored. Defaults to [0.25, 0.5, 0.75]

[0.25, 0.5, 0.75]
n_samples Optional[int]

Number of samples to draw from the posterior to estimate the quantiles. Ignored for non-probabilistic models. Defaults to 100

100
ret_logits bool

Flag to return raw model outputs/logits except the backbone features along with the dataframe. Defaults to False

False
include_input_features bool

DEPRECATED: Flag to include the input features in the returned dataframe. Defaults to False

False
progress_bar Optional[str]

choose progress bar for tracking the progress. "rich" or "tqdm" will set the respective progress bars. If None, no progress bar will be shown.

None
test_time_augmentation bool

If True, will use test time augmentation to generate predictions. The approach is very similar to what is described at https://kozodoi.me/blog/20210908/tta-tabular, but noise is added to the embedded inputs so that categorical features are handled as well: \(x_{aug} = x_{orig} + \alpha * \epsilon\), where \(\epsilon \sim \mathcal{N}(0, 1)\). Defaults to False

False
num_tta float

The number of augmentations to run TTA for. Defaults to 5

5
alpha_tta float

The standard deviation of the gaussian noise to be added to the input features

0.1
aggregate_tta Union[str, Callable]

The function to be used to aggregate the predictions from each augmentation. If str, should be one of "mean", "median", "min", or "max" for regression. For classification, the previous options are applied to the confidence scores (soft voting) and then converted to the final prediction. An additional option "hard_voting" is available for classification. If callable, should be a function that takes in a list of 3D arrays (num_samples, num_cv, num_targets) and returns a 2D array of final probabilities (num_samples, num_targets). Defaults to "mean".

'mean'
tta_seed int

The random seed to be used for the noise added in TTA. Defaults to 42.

42

Returns:

Name Type Description
DataFrame DataFrame

Returns a dataframe with predictions and features (if include_input_features=True). If classification, it returns probabilities and final prediction

Source code in src/pytorch_tabular/tabular_model.py
def predict(
    self,
    test: DataFrame,
    quantiles: Optional[List] = [0.25, 0.5, 0.75],
    n_samples: Optional[int] = 100,
    ret_logits=False,
    include_input_features: bool = False,
    device: Optional[torch.device] = None,
    progress_bar: Optional[str] = None,
    test_time_augmentation: Optional[bool] = False,
    num_tta: Optional[float] = 5,
    alpha_tta: Optional[float] = 0.1,
    aggregate_tta: Optional[str] = "mean",
    tta_seed: Optional[int] = 42,
) -> DataFrame:
    """Uses the trained model to predict on new data and return as a dataframe.

    Args:
        test (DataFrame): The new dataframe with the features defined during training

        quantiles (Optional[List]): For probabilistic models like Mixture Density Networks, this specifies
            the different quantiles to be extracted apart from the `central_tendency` and added to the dataframe.
            For other models it is ignored. Defaults to [0.25, 0.5, 0.75]

        n_samples (Optional[int]): Number of samples to draw from the posterior to estimate the quantiles.
            Ignored for non-probabilistic models. Defaults to 100

        ret_logits (bool): Flag to return raw model outputs/logits except the backbone features along
            with the dataframe. Defaults to False

        include_input_features (bool): DEPRECATED: Flag to include the input features in the returned dataframe.
            Defaults to False

        progress_bar: choose progress bar for tracking the progress. "rich" or "tqdm" will set the respective
            progress bars. If None, no progress bar will be shown.

        test_time_augmentation (bool): If True, will use test time augmentation to generate predictions.
            The approach is very similar to what is described [here](https://kozodoi.me/blog/20210908/tta-tabular)
            But, we add noise to the embedded inputs to handle categorical features as well.\
            \\(x_{aug} = x_{orig} + \\alpha * \\epsilon\\) where \\(\\epsilon \\sim \\mathcal{N}(0, 1)\\)
            Defaults to False
        num_tta (float): The number of augmentations to run TTA for. Defaults to 5

        alpha_tta (float): The standard deviation of the gaussian noise to be added to the input features

        aggregate_tta (Union[str, Callable], optional): The function to be used to aggregate the
            predictions from each augmentation. If str, should be one of "mean", "median", "min", or "max"
            for regression. For classification, the previous options are applied to the confidence
            scores (soft voting) and then converted to final prediction. An additional option
            "hard_voting" is available for classification.
            If callable, should be a function that takes in a list of 3D arrays (num_samples, num_cv, num_targets)
            and returns a 2D array of final probabilities (num_samples, num_targets). Defaults to "mean".

        tta_seed (int): The random seed to be used for the noise added in TTA. Defaults to 42.

    Returns:
        DataFrame: Returns a dataframe with predictions and features (if `include_input_features=True`).
            If classification, it returns probabilities and final prediction

    """
    warnings.warn(
        "`include_input_features` will be deprecated in the next release."
        " Please add index columns to the test dataframe if you want to"
        " retain some features like the key or id",
        DeprecationWarning,
    )
    if test_time_augmentation:
        assert num_tta > 0, "num_tta should be greater than 0"
        assert alpha_tta > 0, "alpha_tta should be greater than 0"
        assert include_input_features is False, "include_input_features cannot be True for TTA."
        if not callable(aggregate_tta):
            assert aggregate_tta in [
                "mean",
                "median",
                "min",
                "max",
                "hard_voting",
            ], "aggregate should be one of 'mean', 'median', 'min', 'max', or" " 'hard_voting'"
        if self.config.task == "regression":
            assert aggregate_tta != "hard_voting", "hard_voting is only available for classification"

        torch.manual_seed(tta_seed)

        def add_noise(module, input, output):
            return output + alpha_tta * torch.randn_like(output, memory_format=torch.contiguous_format)

        # Register the hook to the embedding_layer
        handle = self.model.embedding_layer.register_forward_hook(add_noise)
        pred_prob_l = []
        for _ in range(num_tta):
            pred_df = self._predict(
                test,
                quantiles,
                n_samples,
                ret_logits,
                include_input_features=False,
                device=device,
                progress_bar=progress_bar or "None",
            )
            pred_idx = pred_df.index
            if self.config.task == "classification":
                pred_prob_l.append(pred_df.values[:, : -len(self.config.target)])
            elif self.config.task == "regression":
                pred_prob_l.append(pred_df.values)
        pred_df = self._combine_predictions(pred_prob_l, pred_idx, aggregate_tta, None)
        # Remove the hook
        handle.remove()
    else:
        pred_df = self._predict(
            test,
            quantiles,
            n_samples,
            ret_logits,
            include_input_features,
            device,
            progress_bar,
        )
    return pred_df
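
The test-time-augmentation branch above runs num_tta noisy prediction passes and aggregates them. The following is a simplified, library-free sketch of that idea, under stated assumptions: tta_predict and toy_model are hypothetical, the noise is added to raw numeric features rather than to the output of an embedding layer as in the source, and only the "mean", "median", "min", and "max" aggregations for a single regression output are covered.

```python
import random
import statistics
from typing import Callable, List, Sequence


def tta_predict(
    predict_fn: Callable[[Sequence[float]], float],
    rows: List[List[float]],
    num_tta: int = 5,
    alpha_tta: float = 0.1,
    aggregate: str = "mean",
    seed: int = 42,
) -> List[float]:
    aggregators = {
        "mean": statistics.mean,
        "median": statistics.median,
        "min": min,
        "max": max,
    }
    agg = aggregators[aggregate]
    rng = random.Random(seed)  # fixed seed, so the noise is reproducible
    all_preds = []
    for _ in range(num_tta):
        # x_aug = x_orig + alpha * epsilon, with epsilon ~ N(0, 1)
        noisy = [[x + alpha_tta * rng.gauss(0.0, 1.0) for x in row] for row in rows]
        all_preds.append([predict_fn(row) for row in noisy])
    # Aggregate the num_tta predictions obtained for each row.
    return [agg(per_row) for per_row in zip(*all_preds)]


def toy_model(row: Sequence[float]) -> float:
    # Hypothetical stand-in for a trained regression model.
    return 2.0 * row[0] + row[1]


rows = [[1.0, 2.0], [3.0, 4.0]]
plain = [toy_model(r) for r in rows]
# alpha_tta=0.0 is only a sanity check here; the library asserts alpha_tta > 0.
no_noise = tta_predict(toy_model, rows, alpha_tta=0.0)
with_noise = tta_predict(toy_model, rows, alpha_tta=0.1)
```

With zero noise the aggregated output matches the plain predictions, which is a quick way to check that the aggregation step itself is not changing the result.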

pytorch_tabular.TabularModel.evaluate(test=None, test_loader=None, ckpt_path=None, verbose=True)

Evaluates the dataframe using the loss and metrics already set in config.

Parameters:

Name Type Description Default
test Optional[DataFrame]

The dataframe to be evaluated. The source listing for this method asserts that either test or test_loader is provided, so one of the two must be passed.

None
test_loader Optional[DataLoader]

The dataloader to be used for evaluation. If provided, will use the dataloader instead of the test dataframe or the test data provided during fit. Defaults to None.

None
ckpt_path Optional[Union[str, Path]]

The path to the checkpoint to be loaded. If not provided, will try to use the best checkpoint during training.

None
verbose bool

If true, will print the results. Defaults to True.

True

Returns: The final test result dictionary.

Source code in src/pytorch_tabular/tabular_model.py
def evaluate(
    self,
    test: Optional[DataFrame] = None,
    test_loader: Optional[torch.utils.data.DataLoader] = None,
    ckpt_path: Optional[Union[str, Path]] = None,
    verbose: bool = True,
) -> Union[dict, list]:
    """Evaluates the dataframe using the loss and metrics already set in config.

    Args:
        test (Optional[DataFrame]): The dataframe to be evaluated. If not provided, will try to use the
            test provided during fit. If that was also not provided will return an empty dictionary

        test_loader (Optional[torch.utils.data.DataLoader], optional): The dataloader to be used for evaluation.
            If provided, will use the dataloader instead of the test dataframe or the test data provided during fit.
            Defaults to None.

        ckpt_path (Optional[Union[str, Path]], optional): The path to the checkpoint to be loaded. If not provided,
            will try to use the best checkpoint during training.

        verbose (bool, optional): If true, will print the results. Defaults to True.
    Returns:
        The final test result dictionary.

    """
    assert not (test_loader is None and test is None), (
        "Either `test_loader` or `test` should be provided."
        " If `test_loader` is not provided, `test` should be provided."
    )
    if test_loader is None:
        test_loader = self.datamodule.prepare_inference_dataloader(test)
    result = self.trainer.test(
        model=self.model,
        dataloaders=test_loader,
        ckpt_path=ckpt_path,
        verbose=verbose,
    )
    return result
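
The return annotation `Union[dict, list]` reflects that the value comes straight from the trainer's `test` call, and `cross_validate` in this same file reads it as `result[0][metric]`. A minimal sketch of reading such a result, assuming a single test dataloader and metric keys prefixed with `test_`. The numbers are made up for illustration, and `get_test_metric` is a hypothetical helper, not part of PyTorch Tabular.

```python
from typing import Dict, List, Union


def get_test_metric(result: Union[dict, List[Dict[str, float]]], name: str) -> float:
    """Look up one metric in the output of `evaluate`."""
    # The trainer returns one dict per test dataloader; take the first one.
    metrics = result[0] if isinstance(result, list) else result
    # Accept both "accuracy" and "test_accuracy" as the metric name.
    key = name if name.startswith("test_") else "test_" + name
    return metrics[key]


# Hypothetical values, shaped like the list that `evaluate` returns.
example_result = [{"test_loss": 0.41, "test_accuracy": 0.87}]
print(get_test_metric(example_result, "accuracy"))
```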

pytorch_tabular.TabularModel.cross_validate(cv, train, metric=None, return_oof=False, groups=None, verbose=True, reset_datamodule=True, handle_oom=True, **kwargs)

Cross validate the model.

Parameters:

Name Type Description Default
cv Optional[Union[int, Iterable, BaseCrossValidator]]

Determines the cross-validation splitting strategy. Possible inputs for cv are:

  • None, to use the default 5-fold cross validation (KFold for Regression and StratifiedKFold for Classification),
  • integer, to specify the number of folds in a (Stratified)KFold,
  • An iterable yielding (train, test) splits as arrays of indices.
  • A scikit-learn CV splitter.
required
train DataFrame

The training data with labels

required
metric Optional[Union[str, Callable]]

The metric to be used for evaluation. If None, will use the first metric in the config. If a str is provided, will use that metric from the defined ones. If a callable is provided, will use that function as the metric. The callable is expected to have the form metric(y_true, y_pred). For classification problems, y_pred is a dataframe with the probabilities for each class (<class>_probability) and a final prediction (prediction). For regression, it is a dataframe with a final prediction (<target>_prediction). Defaults to None.

None
return_oof bool

If True, will return the out-of-fold predictions along with the cross validation results. Defaults to False.

False
groups Optional[Union[str, ndarray]]

Group labels for the samples used while splitting. If provided, will be used as the groups argument for the split method of the cross validator. If input is str, will use the column in the input dataframe with that name as the group labels. If input is array-like, will use that as the group. The only constraint is that the group labels should have the same size as the number of rows in the input dataframe. Defaults to None.

None
verbose bool

If True, will log the results. Defaults to True.

True
reset_datamodule bool

If True, will reset the datamodule for each iteration. It will be slower because we will be fitting the transformations for each fold. If False, we take an approximation that once the transformations are fit on the first fold, they will be valid for all the other folds. Defaults to True.

True
handle_oom bool

If True, will handle out of memory errors elegantly

True
**kwargs

Additional keyword arguments to be passed to the fit method of the model.

{}

Returns:

Name Type Description
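Tuple[List, List]

A tuple (cv_metrics, oof_preds), as returned by the source below: the list of per-fold metric values and the list of out-of-fold prediction dataframes. oof_preds stays empty unless return_oof is True or metric is a callable.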
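
Besides an integer or a scikit-learn splitter, `cv` accepts an iterable of `(train, test)` index splits. A minimal sketch of building such splits by hand in plain Python, assuming contiguous, unshuffled folds. `make_contiguous_folds` is a hypothetical helper, not a library function.

```python
from typing import List, Tuple


def make_contiguous_folds(n_rows: int, n_folds: int) -> List[Tuple[List[int], List[int]]]:
    """Split row positions 0..n_rows-1 into contiguous (train, test) index lists."""
    if not 2 <= n_folds <= n_rows:
        raise ValueError("n_folds must be between 2 and n_rows")
    base, extra = divmod(n_rows, n_folds)
    splits, start = [], 0
    for fold in range(n_folds):
        # The first `extra` folds get one additional row.
        stop = start + base + (1 if fold < extra else 0)
        test_idx = list(range(start, stop))
        train_idx = list(range(0, start)) + list(range(stop, n_rows))
        splits.append((train_idx, test_idx))
        start = stop
    return splits


folds = make_contiguous_folds(n_rows=10, n_folds=3)
for train_idx, test_idx in folds:
    print(len(train_idx), test_idx)
```

Each tuple holds positional indices, which is what the method feeds to `train.iloc[...]`. The source listing that follows calls `cv.get_n_splits()` when `verbose=True`, so whether a plain list works with logging enabled depends on what `_check_cv` returns, which is not shown here.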

Source code in src/pytorch_tabular/tabular_model.py
def cross_validate(
    self,
    cv: Optional[Union[int, Iterable, BaseCrossValidator]],
    train: DataFrame,
    metric: Optional[Union[str, Callable]] = None,
    return_oof: bool = False,
    groups: Optional[Union[str, np.ndarray]] = None,
    verbose: bool = True,
    reset_datamodule: bool = True,
    handle_oom: bool = True,
    **kwargs,
):
    """Cross validate the model.

    Args:
        cv (Optional[Union[int, Iterable, BaseCrossValidator]]): Determines the cross-validation splitting strategy.
            Possible inputs for cv are:

            - None, to use the default 5-fold cross validation (KFold for
            Regression and StratifiedKFold for Classification),
            - integer, to specify the number of folds in a (Stratified)KFold,
            - An iterable yielding (train, test) splits as arrays of indices.
            - A scikit-learn CV splitter.

        train (DataFrame): The training data with labels

        metric (Optional[Union[str, Callable]], optional): The metrics to be used for evaluation.
            If None, will use the first metric in the config. If str is provided, will use that
            metric from the defined ones. If callable is provided, will use that function as the
            metric. We expect callable to be of the form `metric(y_true, y_pred)`. For classification
            problems, The `y_pred` is a dataframe with the probabilities for each class
            (<class>_probability) and a final prediction(prediction). And for Regression, it is a
            dataframe with a final prediction (<target>_prediction).
            Defaults to None.

        return_oof (bool, optional): If True, will return the out-of-fold predictions
            along with the cross validation results. Defaults to False.

        groups (Optional[Union[str, np.ndarray]], optional): Group labels for
            the samples used while splitting. If provided, will be used as the
            `groups` argument for the `split` method of the cross validator.
            If input is str, will use the column in the input dataframe with that
            name as the group labels. If input is array-like, will use that as the
            group. The only constraint is that the group labels should have the
            same size as the number of rows in the input dataframe. Defaults to None.

        verbose (bool, optional): If True, will log the results. Defaults to True.

        reset_datamodule (bool, optional): If True, will reset the datamodule for each iteration.
            It will be slower because we will be fitting the transformations for each fold.
            If False, we take an approximation that once the transformations are fit on the first
            fold, they will be valid for all the other folds. Defaults to True.

        handle_oom (bool, optional): If True, will handle out of memory errors elegantly
        **kwargs: Additional keyword arguments to be passed to the `fit` method of the model.

    Returns:
        DataFrame: The dataframe with the cross validation results

    """
    cv = self._check_cv(cv)
    prep_dl_kwargs, prep_model_kwargs, train_kwargs = self._split_kwargs(kwargs)
    is_callable_metric = False
    if metric is None:
        metric = "test_" + self.config.metrics[0]
    elif isinstance(metric, str):
        metric = metric if metric.startswith("test_") else "test_" + metric
    elif callable(metric):
        is_callable_metric = True

    if isinstance(cv, BaseCrossValidator):
        it = enumerate(cv.split(train, y=train[self.config.target], groups=groups))
    else:
        # when iterable is directly passed
        it = enumerate(cv)
    cv_metrics = []
    datamodule = None
    model = None
    oof_preds = []
    for fold, (train_idx, val_idx) in it:
        if verbose:
            logger.info(f"Running Fold {fold+1}/{cv.get_n_splits()}")
        # train_fold = train.iloc[train_idx]
        # val_fold = train.iloc[val_idx]
        if reset_datamodule:
            datamodule = None
        if datamodule is None:
            # Initialize datamodule and model in the first fold
            # uses train data from this fold to fit all transformers
            datamodule = self.prepare_dataloader(
                train=train.iloc[train_idx], validation=train.iloc[val_idx], seed=42, **prep_dl_kwargs
            )
            model = self.prepare_model(datamodule, **prep_model_kwargs)
        else:
            # Preprocess the current fold data using the fitted transformers and save in datamodule
            datamodule.train, _ = datamodule.preprocess_data(train.iloc[train_idx], stage="inference")
            datamodule.validation, _ = datamodule.preprocess_data(train.iloc[val_idx], stage="inference")

        # Train the model
        handle_oom = train_kwargs.pop("handle_oom", handle_oom)
        self.train(model, datamodule, handle_oom=handle_oom, **train_kwargs)
        if return_oof or is_callable_metric:
            preds = self.predict(train.iloc[val_idx], include_input_features=False)
            oof_preds.append(preds)
        if is_callable_metric:
            cv_metrics.append(metric(train.iloc[val_idx][self.config.target], preds))
        else:
            result = self.evaluate(train.iloc[val_idx], verbose=False)
            cv_metrics.append(result[0][metric])
        if verbose:
            logger.info(f"Fold {fold+1}/{cv.get_n_splits()} score: {cv_metrics[-1]}")
        self.model.reset_weights()
    return cv_metrics, oof_preds
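
The way `metric` is normalized at the top of the method can be isolated into a small function. This sketch mirrors the three branches above; the final `TypeError` is an addition for illustration (the source has no such branch), and `resolve_metric` is a hypothetical name.

```python
from typing import Callable, List, Tuple, Union


def resolve_metric(
    metric: Union[None, str, Callable], config_metrics: List[str]
) -> Tuple[Union[str, Callable], bool]:
    """Return (metric, is_callable) following the branches in cross_validate."""
    if metric is None:
        # Fall back to the first metric in the config, with the test prefix.
        return "test_" + config_metrics[0], False
    if isinstance(metric, str):
        # Add the "test_" prefix only when it is missing.
        return (metric if metric.startswith("test_") else "test_" + metric), False
    if callable(metric):
        return metric, True
    raise TypeError("metric must be None, a string, or a callable")


print(resolve_metric(None, ["accuracy", "f1_score"]))
print(resolve_metric("f1_score", ["accuracy", "f1_score"]))
```

String metrics are looked up in the dictionary returned by `evaluate`, while a callable is applied to the fold's targets and the predictions from `predict`.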

pytorch_tabular.TabularModel.bagging_predict(cv, train, test, groups=None, verbose=True, reset_datamodule=True, return_raw_predictions=False, aggregate='mean', weights=None, handle_oom=True, **kwargs)

Bagging predict on the test data.

Parameters:

Name Type Description Default
cv Optional[Union[int, Iterable, BaseCrossValidator]]

Determines the cross-validation splitting strategy. Possible inputs for cv are:

  • None, to use the default 5-fold cross validation (KFold for Regression and StratifiedKFold for Classification),
  • integer, to specify the number of folds in a (Stratified)KFold,
  • An iterable yielding (train, test) splits as arrays of indices.
  • A scikit-learn CV splitter.
required
train DataFrame

The training data with labels

required
test DataFrame

The test data to be predicted

required
groups Optional[Union[str, ndarray]]

Group labels for the samples used while splitting. If provided, will be used as the groups argument for the split method of the cross validator. If input is str, will use the column in the input dataframe with that name as the group labels. If input is array-like, will use that as the group. The only constraint is that the group labels should have the same size as the number of rows in the input dataframe. Defaults to None.

None
verbose bool

If True, will log the results. Defaults to True.

True
reset_datamodule bool

If True, will reset the datamodule for each iteration. It will be slower because we will be fitting the transformations for each fold. If False, we take an approximation that once the transformations are fit on the first fold, they will be valid for all the other folds. Defaults to True.

True
return_raw_predictions bool

If True, will return the raw predictions from each fold. Defaults to False.

False
aggregate Union[str, Callable]

The function to be used to aggregate the predictions from each fold. If str, should be one of "mean", "median", "min", or "max" for regression. For classification, the previous options are applied to the confidence scores (soft voting) and then converted to final prediction. An additional option "hard_voting" is available for classification. If callable, should be a function that takes in a list of 3D arrays (num_samples, num_cv, num_targets) and returns a 2D array of final probabilities (num_samples, num_targets). Defaults to "mean".

'mean'
weights Optional[List[float]]

The weights to be used for aggregating the predictions from each fold. If None, will use equal weights. This is only used when aggregate is "mean". Defaults to None.

None
handle_oom bool

If True, will handle out of memory errors elegantly

True
**kwargs

Additional keyword arguments to be passed to the fit method of the model.

{}

Returns:

Name Type Description
DataFrame

The dataframe with the bagged predictions.
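
The `reset_datamodule` trade-off can be illustrated with a toy standardizer in plain Python. This only contrasts refitting statistics on every fold with reusing the statistics from the first fold; it does not use the library's datamodule, and the numbers are made up.

```python
from statistics import mean, pstdev
from typing import List, Tuple


def fit_scaler(values: List[float]) -> Tuple[float, float]:
    """Return the (mean, std) that a standard scaler would learn."""
    return mean(values), pstdev(values)


def transform(values: List[float], stats: Tuple[float, float]) -> List[float]:
    """Standardize values with previously fitted statistics."""
    mu, sigma = stats
    return [(v - mu) / sigma for v in values]


fold_1_train = [1.0, 2.0, 3.0, 4.0]
fold_2_train = [3.0, 4.0, 5.0, 6.0]

# Like reset_datamodule=True: statistics are refit on each fold's training data.
refit = transform(fold_2_train, fit_scaler(fold_2_train))
# Like reset_datamodule=False: statistics from the first fold are reused.
reused = transform(fold_2_train, fit_scaler(fold_1_train))

print(round(mean(refit), 6), round(mean(reused), 6))
```

When the folds are drawn from the same distribution, the two results tend to be close, which is the approximation the faster setting relies on.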

Source code in src/pytorch_tabular/tabular_model.py
def bagging_predict(
    self,
    cv: Optional[Union[int, Iterable, BaseCrossValidator]],
    train: DataFrame,
    test: DataFrame,
    groups: Optional[Union[str, np.ndarray]] = None,
    verbose: bool = True,
    reset_datamodule: bool = True,
    return_raw_predictions: bool = False,
    aggregate: Union[str, Callable] = "mean",
    weights: Optional[List[float]] = None,
    handle_oom: bool = True,
    **kwargs,
):
    """Bagging predict on the test data.

    Args:
        cv (Optional[Union[int, Iterable, BaseCrossValidator]]): Determines the cross-validation splitting strategy.
            Possible inputs for cv are:

            - None, to use the default 5-fold cross validation (KFold for
            Regression and StratifiedKFold for Classification),
            - integer, to specify the number of folds in a (Stratified)KFold,
            - An iterable yielding (train, test) splits as arrays of indices.
            - A scikit-learn CV splitter.

        train (DataFrame): The training data with labels

        test (DataFrame): The test data to be predicted

        groups (Optional[Union[str, np.ndarray]], optional): Group labels for
            the samples used while splitting. If provided, will be used as the
            `groups` argument for the `split` method of the cross validator.
            If input is str, will use the column in the input dataframe with that
            name as the group labels. If input is array-like, will use that as the
            group. The only constraint is that the group labels should have the
            same size as the number of rows in the input dataframe. Defaults to None.

        verbose (bool, optional): If True, will log the results. Defaults to True.

        reset_datamodule (bool, optional): If True, will reset the datamodule for each iteration.
            It will be slower because we will be fitting the transformations for each fold.
            If False, we take an approximation that once the transformations are fit on the first
            fold, they will be valid for all the other folds. Defaults to True.

        return_raw_predictions (bool, optional): If True, will return the raw predictions
            from each fold. Defaults to False.

        aggregate (Union[str, Callable], optional): The function to be used to aggregate the
            predictions from each fold. If str, should be one of "mean", "median", "min", or "max"
            for regression. For classification, the previous options are applied to the confidence
            scores (soft voting) and then converted to final prediction. An additional option
            "hard_voting" is available for classification.
            If callable, should be a function that takes in a list of 3D arrays (num_samples, num_cv, num_targets)
            and returns a 2D array of final probabilities (num_samples, num_targets). Defaults to "mean".

        weights (Optional[List[float]], optional): The weights to be used for aggregating the predictions
            from each fold. If None, will use equal weights. This is only used when `aggregate` is "mean".
            Defaults to None.

        handle_oom (bool, optional): If True, will handle out of memory errors elegantly

        **kwargs: Additional keyword arguments to be passed to the `fit` method of the model.

    Returns:
        DataFrame: The dataframe with the bagged predictions.

    """
    if weights is not None:
        assert len(weights) == cv.n_splits, "Number of weights should be equal to the number of folds"
    assert self.config.task in [
        "classification",
        "regression",
    ], "Bagging is only available for classification and regression"
    if not callable(aggregate):
        assert aggregate in ["mean", "median", "min", "max", "hard_voting"], (
            "aggregate should be one of 'mean', 'median', 'min', 'max', or 'hard_voting'"
        )
    if self.config.task == "regression":
        assert aggregate != "hard_voting", "hard_voting is only available for classification"
    cv = self._check_cv(cv)
    prep_dl_kwargs, prep_model_kwargs, train_kwargs = self._split_kwargs(kwargs)
    pred_prob_l = []
    datamodule = None
    model = None
    for fold, (train_idx, val_idx) in enumerate(cv.split(train, y=train[self.config.target], groups=groups)):
        if verbose:
            logger.info(f"Running Fold {fold+1}/{cv.get_n_splits()}")
        train_fold = train.iloc[train_idx]
        val_fold = train.iloc[val_idx]
        if reset_datamodule:
            datamodule = None
        if datamodule is None:
            # Initialize datamodule and model in the first fold
            # uses train data from this fold to fit all transformers
            datamodule = self.prepare_dataloader(train=train_fold, validation=val_fold, seed=42, **prep_dl_kwargs)
            model = self.prepare_model(datamodule, **prep_model_kwargs)
        else:
            # Preprocess the current fold data using the fitted transformers and save in datamodule
            datamodule.train, _ = datamodule.preprocess_data(train_fold, stage="inference")
            datamodule.validation, _ = datamodule.preprocess_data(val_fold, stage="inference")

        # Train the model
        handle_oom = train_kwargs.pop("handle_oom", handle_oom)
        self.train(model, datamodule, handle_oom=handle_oom, **train_kwargs)
        fold_preds = self.predict(test, include_input_features=False)
        pred_idx = fold_preds.index
        if self.config.task == "classification":
            pred_prob_l.append(fold_preds.values[:, : -len(self.config.target)])
        elif self.config.task == "regression":
            pred_prob_l.append(fold_preds.values)
        if verbose:
            logger.info(f"Fold {fold+1}/{cv.get_n_splits()} prediction done")
        self.model.reset_weights()
    pred_df = self._combine_predictions(pred_prob_l, pred_idx, aggregate, weights)
    if return_raw_predictions:
        return pred_df, pred_prob_l
    else:
        return pred_df
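
The difference between the soft-voting options and `"hard_voting"` can be sketched in plain Python for a single test row. This assumes each fold contributes one probability vector per row; it is not the library's `_combine_predictions`, whose implementation is not shown here, and the names and numbers are hypothetical.

```python
from collections import Counter
from typing import List, Optional


def soft_vote(fold_probs: List[List[float]], weights: Optional[List[float]] = None) -> int:
    """Average class probabilities across folds, then pick the best class."""
    weights = weights or [1.0] * len(fold_probs)
    total = sum(weights)
    n_classes = len(fold_probs[0])
    averaged = [
        sum(w * probs[c] for w, probs in zip(weights, fold_probs)) / total
        for c in range(n_classes)
    ]
    return max(range(n_classes), key=averaged.__getitem__)


def hard_vote(fold_probs: List[List[float]]) -> int:
    """Let every fold vote for its own best class, then take the majority."""
    votes = [max(range(len(p)), key=p.__getitem__) for p in fold_probs]
    return Counter(votes).most_common(1)[0][0]


# Three folds, two classes: two folds lean slightly to class 1, one strongly to class 0.
row = [[0.45, 0.55], [0.48, 0.52], [0.95, 0.05]]
print(soft_vote(row), hard_vote(row))
```

One confident fold can outweigh two hesitant ones under soft voting, whereas hard voting counts each fold once. The optional `weights` argument here plays the role described for `aggregate="mean"`.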

Artifact Saving and Loading

Saving the Model, Datamodule, and Configs

pytorch_tabular.TabularModel.save_config(dir)

Saves the config in the specified directory.

Source code in src/pytorch_tabular/tabular_model.py
def save_config(self, dir: str) -> None:
    """Saves the config in the specified directory."""
    with open(os.path.join(dir, "config.yml"), "w") as fp:
        OmegaConf.save(self.config, fp, resolve=True)

pytorch_tabular.TabularModel.save_datamodule(dir, inference_only=False)

Saves the datamodule in the specified directory.

Parameters:

Name Type Description Default
dir str

The path to the directory to save the datamodule

required
inference_only bool

If True, will only save the inference datamodule without data. This cannot be used for further training, but can be used for inference. Defaults to False.

False
Source code in src/pytorch_tabular/tabular_model.py
def save_datamodule(self, dir: str, inference_only: bool = False) -> None:
    """Saves the datamodule in the specified directory.

    Args:
        dir (str): The path to the directory to save the datamodule
        inference_only (bool): If True, will only save the inference datamodule
            without data. This cannot be used for further training, but can be
            used for inference. Defaults to False.

    """
    if inference_only:
        dm = self.datamodule.inference_only_copy()
    else:
        dm = self.datamodule

    joblib.dump(dm, os.path.join(dir, "datamodule.sav"))

pytorch_tabular.TabularModel.save_model(dir, inference_only=False)

Saves the model and checkpoints in the specified directory.

Parameters:

Name Type Description Default
dir str

The path to the directory to save the model

required
inference_only bool

If True, will save only the inference-only version of the datamodule (without data).

False
Source code in src/pytorch_tabular/tabular_model.py
def save_model(self, dir: str, inference_only: bool = False) -> None:
    """Saves the model and checkpoints in the specified directory.

    Args:
        dir (str): The path to the directory to save the model
        inference_only (bool): If True, will only save the inference
            only version of the datamodule

    """
    if os.path.exists(dir) and (os.listdir(dir)):
        logger.warning("Directory is not empty. Overwriting the contents.")
        for f in os.listdir(dir):
            os.remove(os.path.join(dir, f))
    os.makedirs(dir, exist_ok=True)
    self.save_config(dir)
    self.save_datamodule(dir, inference_only=inference_only)
    if hasattr(self.config, "log_target") and self.config.log_target is not None:
        joblib.dump(self.logger, os.path.join(dir, "exp_logger.sav"))
    if hasattr(self, "callbacks"):
        joblib.dump(self.callbacks, os.path.join(dir, "callbacks.sav"))
    self.trainer.save_checkpoint(os.path.join(dir, "model.ckpt"))
    custom_params = {}
    custom_params["custom_loss"] = getattr(self.model, "custom_loss", None)
    custom_params["custom_metrics"] = getattr(self.model, "custom_metrics", None)
    custom_params["custom_metrics_prob_inputs"] = getattr(self.model, "custom_metrics_prob_inputs", None)
    custom_params["custom_optimizer"] = getattr(self.model, "custom_optimizer", None)
    custom_params["custom_optimizer_params"] = getattr(self.model, "custom_optimizer_params", None)
    joblib.dump(custom_params, os.path.join(dir, "custom_params.sav"))
    if self.custom_model:
        joblib.dump(self.model_callable, os.path.join(dir, "custom_model_callable.sav"))
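
Going by the source above, `save_model` always writes `config.yml`, `datamodule.sav`, `model.ckpt`, and `custom_params.sav`, and writes `exp_logger.sav`, `callbacks.sav`, and `custom_model_callable.sav` only under the conditions shown. A small sketch for checking a directory before handing it to `load_model`; `missing_artifacts` is a hypothetical helper, not part of the library.

```python
import tempfile
from pathlib import Path
from typing import List

# File names taken from the save_model source listing.
REQUIRED = ("config.yml", "datamodule.sav", "model.ckpt", "custom_params.sav")
OPTIONAL = ("exp_logger.sav", "callbacks.sav", "custom_model_callable.sav")


def missing_artifacts(model_dir) -> List[str]:
    """Return the required file names that are absent from `model_dir`."""
    root = Path(model_dir)
    return [name for name in REQUIRED if not (root / name).is_file()]


# Demonstration on a throwaway directory with empty placeholder files.
with tempfile.TemporaryDirectory() as tmp:
    for name in ("config.yml", "datamodule.sav"):
        (Path(tmp) / name).touch()
    print(missing_artifacts(tmp))
```

Since `save_model` removes the existing files in a non-empty target directory, it is safer to point it at a directory dedicated to one model.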

pytorch_tabular.TabularModel.save_model_for_inference(path, kind='pytorch', onnx_export_params={'opset_version': 12})

Saves the model for inference.

Parameters:

Name Type Description Default
path Union[str, Path]

path to save the model

required
kind str

"pytorch" or "onnx" (Experimental)

'pytorch'
onnx_export_params Dict

parameters for onnx export to be passed to torch.onnx.export

{'opset_version': 12}

Returns:

Name Type Description
bool bool

True if the model was saved successfully

Source code in src/pytorch_tabular/tabular_model.py
def save_model_for_inference(
    self,
    path: Union[str, Path],
    kind: str = "pytorch",
    onnx_export_params: Dict = {"opset_version": 12},
) -> bool:
    """Saves the model for inference.

    Args:
        path (Union[str, Path]): path to save the model
        kind (str): "pytorch" or "onnx" (Experimental)
        onnx_export_params (Dict): parameters for onnx export to be
            passed to torch.onnx.export

    Returns:
        bool: True if the model was saved successfully

    """
    if kind == "pytorch":
        torch.save(self.model, str(path))
        return True
    elif kind == "onnx":
        # Export the model
        onnx_export_params["input_names"] = ["categorical", "continuous"]
        onnx_export_params["output_names"] = onnx_export_params.get("output_names", ["output"])
        onnx_export_params["dynamic_axes"] = {
            onnx_export_params["input_names"][0]: {0: "batch_size"},
            onnx_export_params["output_names"][0]: {0: "batch_size"},
        }
        cat = torch.zeros(
            self.config.batch_size,
            len(self.config.categorical_cols),
            dtype=torch.int,
        )
        cont = torch.randn(
            self.config.batch_size,
            len(self.config.continuous_cols),
            requires_grad=True,
        )
        x = {"continuous": cont, "categorical": cat}
        torch.onnx.export(self.model, x, str(path), **onnx_export_params)
        return True
    else:
        raise ValueError("`kind` must be either pytorch or onnx")
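
The ONNX branch writes `input_names`, `output_names`, and `dynamic_axes` into the dictionary it receives, which by default is the shared default argument. A sketch of the same key-filling logic applied to a copy, so the caller's dictionary stays untouched. `build_onnx_export_params` is a hypothetical helper for illustration, not part of the library.

```python
from typing import Any, Dict, Optional


def build_onnx_export_params(user_params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Fill in the export keys the same way the ONNX branch does, on a copy."""
    params = dict(user_params) if user_params is not None else {"opset_version": 12}
    params["input_names"] = ["categorical", "continuous"]
    params["output_names"] = params.get("output_names", ["output"])
    # Only the first input and the first output get a dynamic batch axis.
    params["dynamic_axes"] = {
        params["input_names"][0]: {0: "batch_size"},
        params["output_names"][0]: {0: "batch_size"},
    }
    return params


mine = {"opset_version": 13, "output_names": ["logits"]}
print(build_onnx_export_params(mine)["dynamic_axes"])
```

The result shows which keyword arguments end up being passed to `torch.onnx.export`: a user-supplied `output_names` is kept, while `input_names` and `dynamic_axes` are always overwritten.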

pytorch_tabular.TabularModel.save_weights(path)

Saves the model weights (the model's state_dict) to the specified file.

Parameters:

Name Type Description Default
path str

The path to the file to save the model

required
Source code in src/pytorch_tabular/tabular_model.py
def save_weights(self, path: Union[str, Path]) -> None:
    """Saves the model weights (the model's state_dict) to the specified file.

    Args:
        path (str): The path to the file to save the model

    """
    torch.save(self.model.state_dict(), path)

Loading the Model and Datamodule

pytorch_tabular.TabularModel.load_best_model()

Loads the best model after training is done.

Source code in src/pytorch_tabular/tabular_model.py
def load_best_model(self) -> None:
    """Loads the best model after training is done."""
    if self.trainer.checkpoint_callback is not None:
        if self.verbose:
            logger.info("Loading the best model")
        ckpt_path = self.trainer.checkpoint_callback.best_model_path
        if ckpt_path != "":
            if self.verbose:
                logger.debug(f"Model Checkpoint: {ckpt_path}")
            ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage)
            self.model.load_state_dict(ckpt["state_dict"])
        else:
            logger.warning("No best model available to load. Did you run it more than 1 epoch?...")
    else:
        logger.warning(
            "No best model available to load. Checkpoint Callback needs to be enabled for this to work"
        )

pytorch_tabular.TabularModel.load_model(dir, map_location=None, strict=True) classmethod

Loads a saved model from the directory.

Parameters:

Name Type Description Default
dir str

The directory where the model was saved, along with the checkpoints

required
map_location Union[Dict[str, str], str, device, int, Callable, None]

If your checkpoint saved a GPU model and you now load on CPUs or a different number of GPUs, use this to map to the new setup. The behaviour is the same as in torch.load()

None
strict bool

Whether to strictly enforce that the keys in checkpoint_path match the keys returned by this module's state dict. Default: True.

True

Returns:

Name Type Description
TabularModel TabularModel

The saved TabularModel

Source code in src/pytorch_tabular/tabular_model.py
@classmethod
def load_model(cls, dir: str, map_location=None, strict=True):
    """Loads a saved model from the directory.

    Args:
        dir (str): The directory where the model was saved, along with the checkpoints
        map_location (Union[Dict[str, str], str, device, int, Callable, None]) : If your checkpoint
            saved a GPU model and you now load on CPUs or a different number of GPUs, use this to map
            to the new setup. The behaviour is the same as in torch.load()
        strict (bool) : Whether to strictly enforce that the keys in checkpoint_path match the keys
            returned by this module's state dict. Default: True.

    Returns:
        TabularModel (TabularModel): The saved TabularModel

    """
    config = OmegaConf.load(os.path.join(dir, "config.yml"))
    datamodule = joblib.load(os.path.join(dir, "datamodule.sav"))
    if (
        hasattr(config, "log_target")
        and (config.log_target is not None)
        and os.path.exists(os.path.join(dir, "exp_logger.sav"))
    ):
        logger = joblib.load(os.path.join(dir, "exp_logger.sav"))
    else:
        logger = None
    if os.path.exists(os.path.join(dir, "callbacks.sav")):
        callbacks = joblib.load(os.path.join(dir, "callbacks.sav"))
        # Excluding Gradient Accumulation Scheduler Callback as we are creating
        # a new one in trainer
        callbacks = [c for c in callbacks if not isinstance(c, GradientAccumulationScheduler)]
    else:
        callbacks = []
    if os.path.exists(os.path.join(dir, "custom_model_callable.sav")):
        model_callable = joblib.load(os.path.join(dir, "custom_model_callable.sav"))
        custom_model = True
    else:
        model_callable = getattr_nested(config._module_src, config._model_name)
        # model_callable = getattr(
        #     getattr(models, config._module_src), config._model_name
        # )
        custom_model = False
    inferred_config = datamodule.update_config(config)
    inferred_config = OmegaConf.structured(inferred_config)
    model_args = {
        "config": config,
        "inferred_config": inferred_config,
    }
    custom_params = joblib.load(os.path.join(dir, "custom_params.sav"))
    if custom_params.get("custom_loss") is not None:
        model_args["loss"] = "MSELoss"  # For compatibility. Not Used
    if custom_params.get("custom_metrics") is not None:
        model_args["metrics"] = ["mean_squared_error"]  # For compatibility. Not Used
        model_args["metrics_params"] = [{}]  # For compatibility. Not Used
        model_args["metrics_prob_inputs"] = [False]  # For compatibility. Not Used
    if custom_params.get("custom_optimizer") is not None:
        model_args["optimizer"] = "Adam"  # For compatibility. Not Used
    if custom_params.get("custom_optimizer_params") is not None:
        model_args["optimizer_params"] = {}  # For compatibility. Not Used

    # Initializing with default metrics, losses, and optimizers. Will revert once initialized
    try:
        model = model_callable.load_from_checkpoint(
            checkpoint_path=os.path.join(dir, "model.ckpt"),
            map_location=map_location,
            strict=strict,
            **model_args,
        )
    except RuntimeError as e:
        if (
            "Unexpected key(s) in state_dict" in str(e)
            and "loss.weight" in str(e)
            and "custom_loss.weight" in str(e)
        ):
            # Custom loss will be loaded after the model is initialized
            # continuing with strict=False
            model = model_callable.load_from_checkpoint(
                checkpoint_path=os.path.join(dir, "model.ckpt"),
                map_location=map_location,
                strict=False,
                **model_args,
            )
        else:
            raise e
    if custom_params.get("custom_optimizer") is not None:
        model.custom_optimizer = custom_params["custom_optimizer"]
    if custom_params.get("custom_optimizer_params") is not None:
        model.custom_optimizer_params = custom_params["custom_optimizer_params"]
    if custom_params.get("custom_loss") is not None:
        model.loss = custom_params["custom_loss"]
    if custom_params.get("custom_metrics") is not None:
        model.custom_metrics = custom_params.get("custom_metrics")
        model.hparams.metrics = [m.__name__ for m in custom_params.get("custom_metrics")]
        model.hparams.metrics_params = [{}]
        model.hparams.metrics_prob_input = custom_params.get("custom_metrics_prob_inputs")
    model._setup_loss()
    model._setup_metrics()
    tabular_model = cls(config=config, model_callable=model_callable)
    tabular_model.model = model
    tabular_model.custom_model = custom_model
    tabular_model.datamodule = datamodule
    tabular_model.callbacks = callbacks
    tabular_model.trainer = tabular_model._prepare_trainer(callbacks=callbacks)
    # tabular_model.trainer.model = model
    tabular_model.logger = logger
    return tabular_model
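
The fallback to `strict=False` only triggers for one specific kind of `RuntimeError`. The condition can be isolated as a predicate on the error text; this is a sketch that mirrors the three substring checks above, with a hypothetical function name and a made-up error message.

```python
def should_retry_non_strict(error_message: str) -> bool:
    """Mirror the condition under which load_model retries with strict=False."""
    return (
        "Unexpected key(s) in state_dict" in error_message
        and "loss.weight" in error_message
        and "custom_loss.weight" in error_message
    )


# Hypothetical error text, shaped like a state_dict key mismatch message.
msg = 'Unexpected key(s) in state_dict: "loss.weight", "custom_loss.weight".'
print(should_retry_non_strict(msg))
```

Because `"loss.weight"` is a substring of `"custom_loss.weight"`, a message that mentions only the latter also satisfies the check. Any other `RuntimeError` is re-raised unchanged.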

pytorch_tabular.TabularModel.load_weights(path)

Loads the model weights from the specified file.

Parameters:

Name Type Description Default
path str

The path to the file to load the model from

required
Source code in src/pytorch_tabular/tabular_model.py
def load_weights(self, path: Union[str, Path]) -> None:
    """Loads the model weights from the specified file.

    Args:
        path (str): The path to the file to load the model from

    """
    self._load_weights(self.model, path)

Other Functions

pytorch_tabular.TabularModel.find_learning_rate(model, datamodule, min_lr=1e-08, max_lr=1, num_training=100, mode='exponential', early_stop_threshold=4.0, plot=True, callbacks=None)

Enables the user to do a range test of good initial learning rates, to reduce the amount of guesswork in picking a good starting learning rate.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| model | LightningModule | The PyTorch Lightning model to be trained. | required |
| datamodule | TabularDatamodule | The datamodule | required |
| min_lr | Optional[float] | minimum learning rate to investigate | 1e-08 |
| max_lr | Optional[float] | maximum learning rate to investigate | 1 |
| num_training | Optional[int] | number of learning rates to test | 100 |
| mode | Optional[str] | search strategy, either 'linear' or 'exponential'. If set to 'linear' the learning rate will be searched by linearly increasing after each batch. If set to 'exponential', will increase learning rate exponentially. | 'exponential' |
| early_stop_threshold | Optional[float] | threshold for stopping the search. If the loss at any point is larger than early_stop_threshold*best_loss then the search is stopped. To disable, set to None. | 4.0 |
| plot | bool | If true, will plot using matplotlib | True |
| callbacks | Optional[List] | If provided, will be added to the callbacks for Trainer. | None |

Returns:

| Type | Description |
|---|---|
| Tuple[float, DataFrame] | The suggested learning rate and the learning rate finder results |
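
The `mode` and `early_stop_threshold` parameters described above can be illustrated with a small pure-Python sketch. It is not the code PyTorch Lightning runs internally; the helper names `lr_schedule` and `should_stop` are hypothetical, and the interpolation formulas are one plausible reading of "linearly increasing" and "exponentially increasing" between `min_lr` and `max_lr`.

```python
from typing import List, Optional


def lr_schedule(min_lr: float, max_lr: float, num_training: int, mode: str = "exponential") -> List[float]:
    """Return the learning rates an illustrative range test would try, one per batch."""
    if num_training < 2:
        raise ValueError("num_training must be at least 2")
    if mode not in ("linear", "exponential"):
        raise ValueError("mode must be 'linear' or 'exponential'")
    rates = []
    for step in range(num_training):
        frac = step / (num_training - 1)  # progress through the test, from 0.0 to 1.0
        if mode == "linear":
            # Equal additive steps between min_lr and max_lr.
            rates.append(min_lr + frac * (max_lr - min_lr))
        else:
            # Equal multiplicative steps between min_lr and max_lr.
            rates.append(min_lr * (max_lr / min_lr) ** frac)
    return rates


def should_stop(loss: float, best_loss: float, early_stop_threshold: Optional[float]) -> bool:
    """Stop once the loss exceeds early_stop_threshold * best_loss; None disables the check."""
    if early_stop_threshold is None:
        return False
    return loss > early_stop_threshold * best_loss


linear = lr_schedule(1e-8, 1.0, 5, mode="linear")
exponential = lr_schedule(1e-8, 1.0, 5, mode="exponential")
```

With the default range of 1e-08 to 1, the exponential schedule spends most of its steps on small learning rates, while the linear schedule reaches values near `max_lr` almost immediately. That difference is a likely reason `'exponential'` is the default.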

Source code in src/pytorch_tabular/tabular_model.py
def find_learning_rate(
    self,
    model: pl.LightningModule,
    datamodule: TabularDatamodule,
    min_lr: float = 1e-8,
    max_lr: float = 1,
    num_training: int = 100,
    mode: str = "exponential",
    early_stop_threshold: Optional[float] = 4.0,
    plot: bool = True,
    callbacks: Optional[List] = None,
) -> Tuple[float, DataFrame]:
    """Enables the user to do a range test of good initial learning rates, to reduce the amount of guesswork in
    picking a good starting learning rate.

    Args:
        model (pl.LightningModule): The PyTorch Lightning model to be trained.

        datamodule (TabularDatamodule): The datamodule

        min_lr (Optional[float], optional): minimum learning rate to investigate

        max_lr (Optional[float], optional): maximum learning rate to investigate

        num_training (Optional[int], optional): number of learning rates to test

        mode (Optional[str], optional): search strategy, either 'linear' or 'exponential'. If set to
            'linear' the learning rate will be searched by linearly increasing
            after each batch. If set to 'exponential', will increase learning
            rate exponentially.

        early_stop_threshold (Optional[float], optional): threshold for stopping the search. If the
            loss at any point is larger than early_stop_threshold*best_loss
            then the search is stopped. To disable, set to None.

        plot (bool, optional): If true, will plot using matplotlib

        callbacks (Optional[List], optional): If provided, will be added to the callbacks for Trainer.

    Returns:
        The suggested learning rate and the learning rate finder results

    """
    self._prepare_for_training(model, datamodule, callbacks, max_epochs=None, min_epochs=None)
    train_loader, _ = datamodule.train_dataloader(), datamodule.val_dataloader()
    lr_finder = Tuner(self.trainer).lr_find(
        model=self.model,
        train_dataloaders=train_loader,
        val_dataloaders=None,
        min_lr=min_lr,
        max_lr=max_lr,
        num_training=num_training,
        mode=mode,
        early_stop_threshold=early_stop_threshold,
    )
    if plot:
        fig = lr_finder.plot(suggest=True)
        fig.show()
    new_lr = lr_finder.suggestion()
    # cancelling the model and trainer that was loaded
    self.model = None
    self.trainer = None
    self.datamodule = None
    self.callbacks = None
    return new_lr, DataFrame(lr_finder.results)
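
The second element of the returned tuple can be inspected to see why a particular learning rate was suggested. The sketch below shows one common heuristic for a range test: pick the learning rate at which the loss falls most steeply. It assumes the results hold parallel `"lr"` and `"loss"` sequences. The function name `steepest_descent_lr` and the sample numbers are hypothetical, and the rule Lightning applies in `suggestion()` may differ in its details.

```python
from typing import Dict, List


def steepest_descent_lr(results: Dict[str, List[float]]) -> float:
    """Pick the learning rate at which the recorded loss drops fastest."""
    lrs, losses = results["lr"], results["loss"]
    if len(lrs) != len(losses) or len(lrs) < 2:
        raise ValueError("need at least two (lr, loss) pairs of equal length")
    # Change in loss between each pair of consecutive learning rates.
    slopes = [losses[i + 1] - losses[i] for i in range(len(losses) - 1)]
    # The most negative slope marks the steepest improvement.
    steepest = min(range(len(slopes)), key=slopes.__getitem__)
    return lrs[steepest]


# Made-up results in the assumed {"lr": [...], "loss": [...]} layout.
sample_results = {
    "lr": [1e-4, 1e-3, 1e-2, 1e-1, 1.0],
    "loss": [1.0, 0.9, 0.5, 0.45, 3.0],
}
suggested = steepest_descent_lr(sample_results)
```

In this made-up run the loss falls fastest between 1e-3 and 1e-2, so the sketch returns 1e-3. The sharp rise at the last point is the kind of divergence that `early_stop_threshold` is meant to cut short.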

pytorch_tabular.TabularModel.summary(model=None, max_depth=-1)

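Prints a summary of the model. If model is passed, that model is summarized; otherwise the model attached to the TabularModel is used. If no model has been initialized yet, the config is printed instead.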

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| max_depth | int | The maximum depth of modules to traverse and display in the summary. Defaults to -1, which displays all modules. | -1 |

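
The effect of `max_depth` can be pictured with a depth-limited walk over a nested module tree. This is a sketch only: the function `list_modules`, the dictionary layout, and the module names are hypothetical stand-ins, and the real summary is produced by Lightning's `summarize`, not by this code.

```python
from typing import Dict, List


def list_modules(tree: Dict[str, dict], max_depth: int = -1, prefix: str = "", depth: int = 1) -> List[str]:
    """Return dotted module names down to max_depth levels; -1 means no limit."""
    names: List[str] = []
    if max_depth != -1 and depth > max_depth:
        return names  # deeper than requested, so stop descending
    for name, children in tree.items():
        full_name = f"{prefix}.{name}" if prefix else name
        names.append(full_name)
        names.extend(list_modules(children, max_depth, full_name, depth + 1))
    return names


# Hypothetical model layout: each key is a submodule, each value its children.
tree = {"backbone": {"linear_layers": {"0": {}, "1": {}}}, "head": {}}

top_level = list_modules(tree, max_depth=1)
everything = list_modules(tree)  # default of -1 walks the whole tree
```

Here `top_level` contains only `backbone` and `head`, while `everything` also includes the nested layers.
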
Source code in src/pytorch_tabular/tabular_model.py
def summary(self, model=None, max_depth: int = -1) -> None:
    """Prints a summary of the model.

    Args:
        max_depth (int): The maximum depth of modules to traverse and display in the summary.
            Defaults to -1, which displays all modules.

    """
    if model is not None:
        print(summarize(model, max_depth=max_depth))
    elif self.has_model:
        print(summarize(self.model, max_depth=max_depth))
    else:
        rich_print(f"[bold green]{self.__class__.__name__}[/bold green]")
        rich_print("-" * 100)
        rich_print("[bold yellow]Config[/bold yellow]")
        rich_print("-" * 100)
        pprint(self.config.__dict__["_content"])
        rich_print(
            ":triangular_flag:[bold red]Full Model Summary once model has "
            "been initialized or passed in as an argument[/bold red]"
        )

pytorch_tabular.TabularModel.feature_importance()

Returns the feature importance of the model as a pandas DataFrame.

Source code in src/pytorch_tabular/tabular_model.py
def feature_importance(self) -> DataFrame:
    """Returns the feature importance of the model as a pandas DataFrame."""
    return self.model.feature_importance()