
Configurations

Core Configuration

Data configuration.

Parameters:

Name Type Description Default
target Optional[List[str]]

A list of strings with the names of the target column(s). It is mandatory for all except SSL tasks.

None
continuous_cols List

Column names of the numeric fields. Defaults to []

list()
categorical_cols List

Column names of the categorical fields to treat differently. Defaults to []

list()
date_columns List

(Column name, Freq, Format) tuples of the date fields. For example, a field named introduction_date with a monthly frequency like "2023-12" should have an entry ('introduction_date','M','%Y-%m')

list()
encode_date_columns bool

Whether to encode the derived variables from date

True
validation_split Optional[float]

Fraction of training rows to keep aside as validation. Used only if validation data is not given separately

0.2
continuous_feature_transform Optional[str]

Whether to transform the features before modelling. By default, it is turned off. Choices are: [None,yeo-johnson,box-cox, quantile_normal,quantile_uniform].

None
normalize_continuous_features bool

Flag to normalize the input features (continuous)

True
quantile_noise int

NOT IMPLEMENTED. If specified, fits QuantileTransformer on data with added Gaussian noise with std = quantile_noise * data.std; this makes discrete values more separable. Note that this transformation does NOT apply Gaussian noise to the resulting data; the noise is only applied when fitting the QuantileTransformer

0
num_workers Optional[int]

The number of workers used for data loading. On Windows, it is always set to 0

0
pin_memory bool

Whether to pin memory for data loading.

True
handle_unknown_categories bool

Whether to handle unknown or new values in categorical columns as unknown

True
handle_missing_values bool

Whether to handle missing values in categorical columns as unknown

True
dataloader_kwargs Dict[str, Any]

Additional kwargs to be passed to PyTorch DataLoader. See https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader

dict()
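The Format element of each date_columns tuple uses strptime-style directives, the same ones Python's datetime module understands. The sketch below only illustrates what the example entry means by parsing a raw value with that format; it assumes the library parses the column with an equivalent format, and `parse_date_value` is a hypothetical helper, not part of PyTorch Tabular.

```python
from datetime import datetime
from typing import Tuple

# Example date_columns entry: (column name, frequency, format).
DateSpec = Tuple[str, str, str]
date_spec: DateSpec = ("introduction_date", "M", "%Y-%m")


def parse_date_value(value: str, spec: DateSpec) -> datetime:
    """Hypothetical helper: parse one raw value of a date column with the spec's format."""
    _name, _freq, fmt = spec
    return datetime.strptime(value, fmt)


parsed = parse_date_value("2023-12", date_spec)
print(parsed)  # 2023-12-01 00:00:00 (a month-only format resolves to the first day)
```

The frequency element ('M' here) is a pandas-style frequency string describing the granularity of the column; the sketch does not use it.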
Source code in src/pytorch_tabular/config/config.py
@dataclass
class DataConfig:
    """Data configuration.

    Args:
        target (Optional[List[str]]): A list of strings with the names of the target column(s). It is
                mandatory for all except SSL tasks.

        continuous_cols (List): Column names of the numeric fields. Defaults to []

        categorical_cols (List): Column names of the categorical fields to treat differently. Defaults to
                []

        date_columns (List): (Column name, Freq, Format) tuples of the date fields. For example, a field
                named introduction_date with a monthly frequency like "2023-12" should have
                an entry ('introduction_date','M','%Y-%m')

        encode_date_columns (bool): Whether to encode the derived variables from date

        validation_split (Optional[float]): Fraction of training rows to keep aside as validation. Used
                only if validation data is not given separately

        continuous_feature_transform (Optional[str]): Whether to transform the features before
                modelling. By default, it is turned off. Choices are: [`None`,`yeo-johnson`,`box-cox`,
                `quantile_normal`,`quantile_uniform`].

        normalize_continuous_features (bool): Flag to normalize the input features (continuous)

        quantile_noise (int): NOT IMPLEMENTED. If specified, fits QuantileTransformer on data with added
                Gaussian noise with std = quantile_noise * data.std; this makes discrete values more
                separable. Note that this transformation does NOT apply Gaussian noise to the resulting
                data; the noise is only applied when fitting the QuantileTransformer

        num_workers (Optional[int]): The number of workers used for data loading. On Windows, it is always
                set to 0

        pin_memory (bool): Whether to pin memory for data loading.

        handle_unknown_categories (bool): Whether to handle unknown or new values in categorical
                columns as unknown

        handle_missing_values (bool): Whether to handle missing values in categorical columns as
                unknown

        dataloader_kwargs (Dict[str, Any]): Additional kwargs to be passed to PyTorch DataLoader. See
                https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader

    """

    target: Optional[List[str]] = field(
        default=None,
        metadata={
            "help": "A list of strings with the names of the target column(s)."
            " It is mandatory for all except SSL tasks."
        },
    )
    continuous_cols: List = field(
        default_factory=list,
        metadata={"help": "Column names of the numeric fields. Defaults to []"},
    )
    categorical_cols: List = field(
        default_factory=list,
        metadata={"help": "Column names of the categorical fields to treat differently. Defaults to []"},
    )
    date_columns: List = field(
        default_factory=list,
        metadata={
            "help": "(Column name, Freq, Format) tuples of the date fields. For example, a field named"
            " introduction_date with a monthly frequency like '2023-12' should have"
            " an entry ('introduction_date','M','%Y-%m')"
        },
    )

    encode_date_columns: bool = field(
        default=True,
        metadata={"help": "Whether or not to encode the derived variables from date"},
    )
    validation_split: Optional[float] = field(
        default=0.2,
        metadata={
            "help": "Fraction of training rows to keep aside as validation."
            " Used only if validation data is not given separately"
        },
    )
    continuous_feature_transform: Optional[str] = field(
        default=None,
        metadata={
            "help": "Whether or not to transform the features before modelling. By default it is turned off.",
            "choices": [
                None,
                "yeo-johnson",
                "box-cox",
                "quantile_normal",
                "quantile_uniform",
            ],
        },
    )
    normalize_continuous_features: bool = field(
        default=True,
        metadata={"help": "Flag to normalize the input features (continuous)"},
    )
    quantile_noise: int = field(
        default=0,
        metadata={
            "help": "NOT IMPLEMENTED. If specified fits QuantileTransformer on data with added gaussian noise"
            " with std = :quantile_noise: * data.std ; this will cause discrete values to be more separable."
            " Please note that this transformation does NOT apply gaussian noise to the resulting data,"
            " the noise is only applied for QuantileTransformer"
        },
    )
    num_workers: Optional[int] = field(
        default=0,
        metadata={"help": "The number of workers used for data loading. For windows always set to 0"},
    )
    pin_memory: bool = field(
        default=True,
        metadata={"help": "Whether or not to pin memory for data loading."},
    )
    handle_unknown_categories: bool = field(
        default=True,
        metadata={"help": "Whether or not to handle unknown or new values in categorical columns as unknown"},
    )
    handle_missing_values: bool = field(
        default=True,
        metadata={"help": "Whether or not to handle missing values in categorical columns as unknown"},
    )

    dataloader_kwargs: Dict[str, Any] = field(
        default_factory=dict,
        metadata={"help": "Additional kwargs to be passed to PyTorch DataLoader."},
    )

    def __post_init__(self):
        assert (
            len(self.categorical_cols) + len(self.continuous_cols) + len(self.date_columns) > 0
        ), "There should be at least one feature defined in categorical, continuous, or date columns"
        _validate_choices(self)
        if os.name == "nt" and self.num_workers != 0:
            print("Windows does not support num_workers > 0. Setting num_workers to 0")
            self.num_workers = 0
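The two checks in `__post_init__` can be reproduced on a trimmed-down dataclass to see when they fire. `MiniDataConfig` below is a hypothetical stand-in, not the library class; it raises `ValueError` instead of using `assert` and omits `_validate_choices`.

```python
import os
from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass
class MiniDataConfig:
    """Hypothetical, trimmed-down stand-in for DataConfig (not the library class)."""

    target: List[str] = field(default_factory=list)
    continuous_cols: List[str] = field(default_factory=list)
    categorical_cols: List[str] = field(default_factory=list)
    date_columns: List[Tuple[str, str, str]] = field(default_factory=list)
    num_workers: int = 0

    def __post_init__(self):
        # Same condition as DataConfig.__post_init__: at least one feature column must be given.
        n_features = len(self.categorical_cols) + len(self.continuous_cols) + len(self.date_columns)
        if n_features == 0:
            raise ValueError(
                "There should be at least one feature defined in categorical, continuous, or date columns"
            )
        # Same Windows override: data loading then runs in the main process.
        if os.name == "nt" and self.num_workers != 0:
            print("Windows does not support num_workers > 0. Setting num_workers to 0")
            self.num_workers = 0


cfg = MiniDataConfig(target=["y"], continuous_cols=["x1"], num_workers=4)
print(cfg.num_workers)  # 4 on Linux/macOS, 0 on Windows
```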

Base Model configuration.

Parameters:

Name Type Description Default
task str

Specify whether the problem is regression or classification. backbone is a task which considers the model as a backbone to generate features. Mostly used internally for SSL and related tasks. Choices are: [regression,classification,backbone].

required
head Optional[str]

The head to be used for the model. Should be one of the heads defined in pytorch_tabular.models.common.heads. Defaults to LinearHead. Choices are: [None,LinearHead,MixtureDensityHead].

'LinearHead'
head_config Optional[Dict]

The config as a dict which defines the head. If left empty, will be initialized as default linear head.

{'layers': ''}
embedding_dims Optional[List]

The dimensions of the embedding for each categorical column as a list of tuples (cardinality, embedding_dim). If left empty, it is inferred from the cardinality x of each categorical column using the rule min(50, (x + 1) // 2)

None
embedding_dropout float

Dropout to be applied to the Categorical Embedding. Defaults to 0.0

0.0
batch_norm_continuous_input bool

If True, we will normalize the continuous layer by passing it through a BatchNorm layer.

True
virtual_batch_size Optional[int]

If not None, all BatchNorms will be converted to GhostBatchNorms with the specified virtual batch size. Defaults to None

None
learning_rate float

The learning rate of the model. Defaults to 1e-3.

0.001
loss Optional[str]

The loss function to be applied. By default, it is MSELoss for regression and CrossEntropyLoss for classification. Unless you are sure what you are doing, leave it at MSELoss or L1Loss for regression and CrossEntropyLoss for classification

None
metrics Optional[List[str]]

The list of metrics you need to track during training. The metrics should be one of the functional metrics implemented in torchmetrics. By default, it is accuracy for classification and mean_squared_error for regression

None
metrics_prob_input Optional[List[bool]]

Mandatory for classification metrics defined in the config. It defines whether the input to each metric function is the probability or the class. Its length should be the same as the number of metrics. Defaults to None.

None
metrics_params Optional[List]

The parameters to be passed to the metrics function. task is forced to be multiclass because the multiclass version can handle binary as well and for simplicity we are only using multiclass.

None
target_range Optional[List]

The range in which we should limit the output variable. Currently ignored for multi-target regression. Typically used for Regression problems. If left empty, will not apply any restrictions

None
seed int

The seed for reproducibility. Defaults to 42

42
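When embedding_dims is left empty, the default size follows the rule quoted above. A small sketch of that rule; `default_embedding_dims` is a hypothetical helper, and it takes the cardinalities as given (how the library counts a column's cardinality is not covered here).

```python
from typing import List, Tuple


def default_embedding_dims(cardinalities: List[int]) -> List[Tuple[int, int]]:
    """Hypothetical helper applying the documented rule min(50, (x + 1) // 2) per column."""
    # Roughly half the cardinality (rounded up), capped at 50 dimensions.
    return [(x, min(50, (x + 1) // 2)) for x in cardinalities]


print(default_embedding_dims([2, 10, 37, 500]))
# [(2, 1), (10, 5), (37, 19), (500, 50)]
```

The output uses the same (cardinality, embedding_dim) tuple layout that embedding_dims expects when set explicitly.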
Source code in src/pytorch_tabular/config/config.py
@dataclass
class ModelConfig:
    """Base Model configuration.

    Args:
        task (str): Specify whether the problem is regression or classification. `backbone` is a task which
                considers the model as a backbone to generate features. Mostly used internally for SSL and related
                tasks. Choices are: [`regression`,`classification`,`backbone`].

        head (Optional[str]): The head to be used for the model. Should be one of the heads defined in
                `pytorch_tabular.models.common.heads`. Defaults to LinearHead. Choices are:
                [`None`,`LinearHead`,`MixtureDensityHead`].

        head_config (Optional[Dict]): The config as a dict which defines the head. If left empty, will be
                initialized as default linear head.

        embedding_dims (Optional[List]): The dimensions of the embedding for each categorical column as a
                list of tuples (cardinality, embedding_dim). If left empty, it is inferred from the cardinality x
                of each categorical column using the rule min(50, (x + 1) // 2)

        embedding_dropout (float): Dropout to be applied to the Categorical Embedding. Defaults to 0.0

        batch_norm_continuous_input (bool): If True, we will normalize the continuous layer by passing it
                through a BatchNorm layer.

        virtual_batch_size (Optional[int]): If not None, all BatchNorms will be converted to GhostBatchNorms
                with the specified virtual batch size. Defaults to None

        learning_rate (float): The learning rate of the model. Defaults to 1e-3.

        loss (Optional[str]): The loss function to be applied. By default, it is MSELoss for regression and
                CrossEntropyLoss for classification. Unless you are sure what you are doing, leave it at MSELoss
                or L1Loss for regression and CrossEntropyLoss for classification

        metrics (Optional[List[str]]): The list of metrics you need to track during training. The metrics
                should be one of the functional metrics implemented in ``torchmetrics``. By default, it is
                accuracy for classification and mean_squared_error for regression

        metrics_prob_input (Optional[List[bool]]): Mandatory for classification metrics defined in the
                config. It defines whether the input to each metric function is the probability or the class.
                Its length should be the same as the number of metrics. Defaults to None.

        metrics_params (Optional[List]): The parameters to be passed to the metrics function. `task` is forced to
                be `multiclass` because the multiclass version can handle binary as well and for simplicity we are
                only using `multiclass`.

        target_range (Optional[List]): The range in which we should limit the output variable. Currently
                ignored for multi-target regression. Typically used for Regression problems. If left empty, will
                not apply any restrictions

        seed (int): The seed for reproducibility. Defaults to 42

    """

    task: str = field(
        metadata={
            "help": "Specify whether the problem is regression or classification."
            " `backbone` is a task which considers the model as a backbone to generate features."
            " Mostly used internally for SSL and related tasks.",
            "choices": ["regression", "classification", "backbone"],
        }
    )

    head: Optional[str] = field(
        default="LinearHead",
        metadata={
            "help": "The head to be used for the model. Should be one of the heads defined"
            " in `pytorch_tabular.models.common.heads`. Defaults to LinearHead",
            "choices": [None, "LinearHead", "MixtureDensityHead"],
        },
    )

    head_config: Optional[Dict] = field(
        default_factory=lambda: {"layers": ""},
        metadata={
            "help": "The config as a dict which defines the head."
            " If left empty, will be initialized as default linear head."
        },
    )
    embedding_dims: Optional[List] = field(
        default=None,
        metadata={
            "help": "The dimensions of the embedding for each categorical column as a list of tuples "
            "(cardinality, embedding_dim). If left empty, it is inferred from the cardinality x of each "
            "categorical column using the rule min(50, (x + 1) // 2)"
        },
    )
    embedding_dropout: float = field(
        default=0.0,
        metadata={"help": "Dropout to be applied to the Categorical Embedding. Defaults to 0.0"},
    )
    batch_norm_continuous_input: bool = field(
        default=True,
        metadata={"help": "If True, we will normalize the continuous layer by passing it through a BatchNorm layer."},
    )

    learning_rate: float = field(
        default=1e-3,
        metadata={"help": "The learning rate of the model. Defaults to 1e-3."},
    )
    loss: Optional[str] = field(
        default=None,
        metadata={
            "help": "The loss function to be applied. By default, it is MSELoss for regression "
            "and CrossEntropyLoss for classification. Unless you are sure what you are doing, "
            "leave it at MSELoss or L1Loss for regression and CrossEntropyLoss for classification"
        },
    )
    metrics: Optional[List[str]] = field(
        default=None,
        metadata={
            "help": "The list of metrics you need to track during training. The metrics should be one "
            "of the functional metrics implemented in ``torchmetrics``. To use your own metric, please "
            "use the `metrics` param in the `fit` method. By default, it is accuracy for classification "
            "and mean_squared_error for regression"
        },
    )
    metrics_prob_input: Optional[List[bool]] = field(
        default=None,
        metadata={
            "help": "Mandatory for classification metrics defined in the config. This defines "
            "whether the input to each metric function is the probability or the class. Its length should be "
            "the same as the number of metrics. Defaults to None."
        },
    )
    metrics_params: Optional[List] = field(
        default=None,
        metadata={
            "help": "The parameters to be passed to the metrics function. `task` is forced to be `multiclass` "
            "because the multiclass version can handle binary as well and for simplicity we are only using "
            "`multiclass`."
        },
    )
    target_range: Optional[List] = field(
        default=None,
        metadata={
            "help": "The range in which we should limit the output variable. "
            "Currently ignored for multi-target regression. Typically used for Regression problems. "
            "If left empty, will not apply any restrictions"
        },
    )

    virtual_batch_size: Optional[int] = field(
        default=None,
        metadata={
            "help": "If not None, all BatchNorms will be converted to GhostBatchNorms"
            " with this virtual batch size. Defaults to None"
        },
    )

    seed: int = field(
        default=42,
        metadata={"help": "The seed for reproducibility. Defaults to 42"},
    )

    _module_src: str = field(default="models")
    _model_name: str = field(default="Model")
    _backbone_name: str = field(default="Backbone")
    _config_name: str = field(default="Config")

    def __post_init__(self):
        if self.task == "regression":
            self.loss = self.loss or "MSELoss"
            self.metrics = self.metrics or ["mean_squared_error"]
            self.metrics_params = [{} for _ in self.metrics] if self.metrics_params is None else self.metrics_params
            self.metrics_prob_input = [False for _ in self.metrics]  # not used in Regression. just for compatibility
        elif self.task == "classification":
            self.loss = self.loss or "CrossEntropyLoss"
            self.metrics = self.metrics or ["accuracy"]
            self.metrics_params = [{} for _ in self.metrics] if self.metrics_params is None else self.metrics_params
            self.metrics_prob_input = (
                [False for _ in self.metrics] if self.metrics_prob_input is None else self.metrics_prob_input
            )
        elif self.task == "backbone":
            self.loss = None
            self.metrics = None
            self.metrics_params = None
            if self.head is not None:
                logger.warning("`head` is not a valid parameter for backbone task. Making `head=None`")
                self.head = None
                self.head_config = None
        else:
            raise NotImplementedError(
                f"{self.task} is not a valid task. Should be one of "
                f"{self.__dataclass_fields__['task'].metadata['choices']}"
            )
        if self.metrics is not None:
            assert len(self.metrics) == len(self.metrics_params), "metrics and metrics_params should have same length"

        if self.task != "backbone":
            assert self.head in dir(heads.blocks), f"{self.head} is not a valid head"
            if hasattr(self, "_config_name") and self._config_name != "MDNConfig":
                assert self.head != "MixtureDensityHead", (
                    "MixtureDensityHead is not supported as a head for regular "
                    "models. Use `MDNConfig` instead. Please see the Probabilistic Regression with MDN "
                    "How-to-Guide in the documentation for the right usage."
                )
            _head_callable = getattr(heads.blocks, self.head)
            ideal_head_config = _head_callable._config_template
            invalid_keys = set(self.head_config.keys()) - set(ideal_head_config.__dict__.keys())
            assert len(invalid_keys) == 0, f"`head_config` has some invalid keys: {invalid_keys}"

        # For Custom models, setting these values for compatibility
        if not hasattr(self, "_config_name"):
            self._config_name = type(self).__name__
        if not hasattr(self, "_model_name"):
            self._model_name = re.sub("[Cc]onfig", "Model", self._config_name)
        if not hasattr(self, "_backbone_name"):
            self._backbone_name = re.sub("[Cc]onfig", "Backbone", self._config_name)
        _validate_choices(self)
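The task-dependent defaults resolved in `__post_init__` can be summarized as a plain function. `resolve_task_defaults` is a hypothetical sketch of that logic only; it omits `metrics_prob_input`, the head checks, and `_validate_choices`.

```python
from typing import Dict, List, Optional, Tuple


def resolve_task_defaults(
    task: str,
    loss: Optional[str] = None,
    metrics: Optional[List[str]] = None,
    metrics_params: Optional[List[Dict]] = None,
) -> Tuple[Optional[str], Optional[List[str]], Optional[List[Dict]]]:
    """Hypothetical helper mirroring the loss/metric defaults filled in by ModelConfig.__post_init__."""
    if task == "regression":
        loss = loss or "MSELoss"
        metrics = metrics or ["mean_squared_error"]
    elif task == "classification":
        loss = loss or "CrossEntropyLoss"
        metrics = metrics or ["accuracy"]
    elif task == "backbone":
        # A backbone only produces features, so loss and metrics are cleared.
        return None, None, None
    else:
        raise NotImplementedError(f"{task} is not a valid task")
    # One (possibly empty) params dict per metric keeps the two lists aligned.
    if metrics_params is None:
        metrics_params = [{} for _ in metrics]
    assert len(metrics) == len(metrics_params), "metrics and metrics_params should have same length"
    return loss, metrics, metrics_params


print(resolve_task_defaults("regression"))
# ('MSELoss', ['mean_squared_error'], [{}])
```

An explicitly passed loss or metric list always wins over the defaults, because of the `x or default` pattern.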

Base SSLModel Configuration.

Parameters:

Name Type Description Default
encoder_config Optional[ModelConfig]

The config of the encoder to be used for the model. Should be one of the model configs defined in PyTorch Tabular

None
decoder_config Optional[ModelConfig]

The config of the decoder to be used for the model. Should be one of the model configs defined in PyTorch Tabular. Defaults to nn.Identity

None
embedding_dims Optional[List]

The dimensions of the embedding for each categorical column as a list of tuples (cardinality, embedding_dim). If left empty, it is inferred from the cardinality x of each categorical column using the rule min(50, (x + 1) // 2)

None
embedding_dropout float

Dropout to be applied to the Categorical Embedding. Defaults to 0.1

0.1
batch_norm_continuous_input bool

If True, we will normalize the continuous layer by passing it through a BatchNorm layer.

True
virtual_batch_size Optional[int]

If not None, all BatchNorms will be converted to GhostBatchNorms with the specified virtual batch size. Defaults to None

None
learning_rate float

The learning rate of the model. Defaults to 1e-3

0.001
seed int

The seed for reproducibility. Defaults to 42

42
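GhostBatchNorm, which virtual_batch_size switches on, normalizes each chunk of virtual_batch_size rows with that chunk's own statistics instead of the statistics of the full batch. Below is a simplified pure-Python sketch for a single feature; it leaves out the learnable scale/shift and the running statistics a real BatchNorm layer keeps, and it is not PyTorch Tabular's implementation. `ghost_batch_norm` is a hypothetical name.

```python
import math
from typing import List


def ghost_batch_norm(values: List[float], virtual_batch_size: int, eps: float = 1e-5) -> List[float]:
    """Normalize each virtual batch (chunk of rows) of one feature with its own mean and variance."""
    out: List[float] = []
    for start in range(0, len(values), virtual_batch_size):
        chunk = values[start:start + virtual_batch_size]
        mean = sum(chunk) / len(chunk)
        var = sum((v - mean) ** 2 for v in chunk) / len(chunk)  # biased variance, as in batch norm
        out.extend((v - mean) / math.sqrt(var + eps) for v in chunk)
    return out


batch = [1.0, 2.0, 3.0, 4.0, 10.0, 20.0, 30.0, 40.0]
print([round(v, 3) for v in ghost_batch_norm(batch, virtual_batch_size=4)])
# [-1.342, -0.447, 0.447, 1.342, -1.342, -0.447, 0.447, 1.342]
```

Because each half of the batch is normalized separately, the two very different scales map to the same normalized values; a single full-batch BatchNorm would not do that.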
Source code in src/pytorch_tabular/config/config.py
@dataclass
class SSLModelConfig:
    """Base SSLModel Configuration.

    Args:
        encoder_config (Optional[ModelConfig]): The config of the encoder to be used for the
                model. Should be one of the model configs defined in PyTorch Tabular

        decoder_config (Optional[ModelConfig]): The config of the decoder to be used for the model.
                Should be one of the model configs defined in PyTorch Tabular. Defaults to nn.Identity

        embedding_dims (Optional[List]): The dimensions of the embedding for each categorical column as a
                list of tuples (cardinality, embedding_dim). If left empty, it is inferred from the cardinality x
                of each categorical column using the rule min(50, (x + 1) // 2)

        embedding_dropout (float): Dropout to be applied to the Categorical Embedding. Defaults to 0.1

        batch_norm_continuous_input (bool): If True, we will normalize the continuous layer by passing it
                through a BatchNorm layer.

        virtual_batch_size (Optional[int]): If not None, all BatchNorms will be converted to GhostBatchNorms
                with the specified virtual batch size. Defaults to None

        learning_rate (float): The learning rate of the model. Defaults to 1e-3

        seed (int): The seed for reproducibility. Defaults to 42

    """

    task: str = field(init=False, default="ssl")

    encoder_config: Optional[ModelConfig] = field(
        default=None,
        metadata={
            "help": "The config of the encoder to be used for the model."
            " Should be one of the model configs defined in PyTorch Tabular",
        },
    )

    decoder_config: Optional[ModelConfig] = field(
        default=None,
        metadata={
            "help": "The config of the decoder to be used for the model."
            " Should be one of the model configs defined in PyTorch Tabular. Defaults to nn.Identity",
        },
    )

    embedding_dims: Optional[List] = field(
        default=None,
        metadata={
            "help": "The dimensions of the embedding for each categorical column as a list of tuples "
            "(cardinality, embedding_dim). If left empty, it is inferred from the cardinality x of each "
            "categorical column using the rule min(50, (x + 1) // 2)"
        },
    )
    embedding_dropout: float = field(
        default=0.1,
        metadata={"help": "Dropout to be applied to the Categorical Embedding. Defaults to 0.1"},
    )
    batch_norm_continuous_input: bool = field(
        default=True,
        metadata={"help": "If True, we will normalize the continuous layer by passing it through a BatchNorm layer."},
    )
    virtual_batch_size: Optional[int] = field(
        default=None,
        metadata={
            "help": "If not None, all BatchNorms will be converted to GhostBatchNorms"
            " with this virtual batch size. Defaults to None"
        },
    )
    learning_rate: float = field(
        default=1e-3,
        metadata={"help": "The learning rate of the model. Defaults to 1e-3"},
    )
    seed: int = field(
        default=42,
        metadata={"help": "The seed for reproducibility. Defaults to 42"},
    )

    _module_src: str = field(default="models")
    _model_name: str = field(default="Model")
    _config_name: str = field(default="Config")

    def __post_init__(self):
        assert self.task == "ssl", f"task should be ssl, got {self.task}"
        # For Custom models, setting these values for compatibility
        if not hasattr(self, "_config_name"):
            self._config_name = type(self).__name__
        if not hasattr(self, "_model_name"):
            self._model_name = re.sub("[Cc]onfig", "Model", self._config_name)
        _validate_choices(self)
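Two details of SSLModelConfig are easy to miss: `task` is declared with `field(init=False, default="ssl")`, so it cannot be passed to the constructor, and the model name is derived from the config class name with `re.sub`. A minimal sketch on a hypothetical `MyEncoderConfig` dataclass (not a PyTorch Tabular class) showing both behaviors:

```python
import re
from dataclasses import dataclass, field


@dataclass
class MyEncoderConfig:
    """Hypothetical config used only to illustrate init=False and the naming rule."""

    # init=False keeps `task` out of __init__, so it always holds its default.
    task: str = field(init=False, default="ssl")
    learning_rate: float = 1e-3

    def __post_init__(self):
        assert self.task == "ssl", f"task should be ssl, got {self.task}"
        self._config_name = type(self).__name__
        # Same substitution used in SSLModelConfig.__post_init__: "Config"/"config" -> "Model".
        self._model_name = re.sub("[Cc]onfig", "Model", self._config_name)


cfg = MyEncoderConfig(learning_rate=1e-2)
print(cfg.task, cfg._model_name)  # ssl MyEncoderModel
```

Passing `task=...` to such a constructor raises a `TypeError`, which is why the assertion in `__post_init__` only guards against the attribute being changed in a subclass default.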

Trainer configuration.

Parameters:

Name Type Description Default
batch_size int

Number of samples in each batch of training

64
data_aware_init_batch_size int

Number of samples in each batch of training for the data-aware initialization, when applicable. Defaults to 2000

2000
fast_dev_run bool

Runs n batch(es) of train, val and test if set to n (int), or 1 batch if set to True, to find any bugs (i.e., a sort of unit test).

False
max_epochs int

Maximum number of epochs to be run

10
min_epochs Optional[int]

Force training for at least this many epochs. 1 by default

1
max_time Optional[int]

Stop training after this amount of time has passed. Disabled by default (None)

None
accelerator Optional[str]

The accelerator to use for training. Can be one of 'cpu', 'gpu', 'tpu', 'ipu', 'mps', 'auto'. Defaults to 'auto'. Choices are: [cpu,gpu,tpu,ipu,mps,auto].

'auto'
devices Optional[int]

Number of devices to train on (int). -1 uses all available devices. By default, uses all available devices (-1)

-1
devices_list Optional[List[int]]

List of devices to train on (list). If specified, takes precedence over devices argument. Defaults to None

None
accumulate_grad_batches int

Accumulates grads every k batches or as set up in the dict. Trainer also calls optimizer.step() for the last indivisible step number.

1
auto_lr_find bool

Runs a learning rate finder algorithm when calling trainer.tune(), to find optimal initial learning rate.

False
auto_select_gpus bool

If enabled and devices is an integer, pick available GPUs automatically. This is especially useful when GPUs are configured to be in 'exclusive mode', such that only one process at a time can access them.

True
check_val_every_n_epoch int

Run validation every n training epochs.

1
gradient_clip_val float

Gradient clipping value

0.0
overfit_batches float

Uses this much data of the training set. If nonzero, will use the same training set for validation and testing. If the training dataloaders have shuffle=True, Lightning will automatically disable it. Useful for quickly debugging or trying to overfit on purpose.

0.0
deterministic bool

If True, enables cudnn.deterministic. Might make your system slower, but ensures reproducibility.

False
profiler Optional[str]

Profiler used to profile individual steps during training and assist in identifying bottlenecks. Choices are: [None,simple,advanced,pytorch].

None
early_stopping Optional[str]

The loss/metric that needs to be monitored for early stopping. If None, there will be no early stopping

'valid_loss'
early_stopping_min_delta float

The minimum delta in the loss/metric which qualifies as an improvement in early stopping

0.001
early_stopping_mode str

The direction in which the loss/metric should be optimized. Choices are: [max,min].

'min'
early_stopping_patience int

The number of epochs with no further improvement in the loss/metric to wait before stopping

3
early_stopping_kwargs Optional[Dict]

Additional keyword arguments for the early stopping callback. See the documentation for the PyTorch Lightning EarlyStopping callback for more details.

{}
checkpoints Optional[str]

The loss/metric that needs to be monitored for checkpoints. If None, there will be no checkpoints

'valid_loss'
checkpoints_path str

The path where the saved models will be stored

'saved_models'
checkpoints_every_n_epochs int

Number of training epochs between checkpoints

1
checkpoints_name Optional[str]

The name under which the models will be saved. If left blank, first it will look for run_name in experiment_config and if that is also None then it will use a generic name like task_version.

None
checkpoints_mode str

The direction in which the loss/metric should be optimized

'min'
checkpoints_save_top_k int

The number of best models to save

1
checkpoints_kwargs Optional[Dict]

Additional keyword arguments for the checkpoints callback. See the documentation for the PyTorch Lightning ModelCheckpoint callback for more details.

{}
load_best bool

Flag to load the best model saved during training

True
track_grad_norm int

Track and log gradient norms in the logger. -1 (the default) means no tracking; 1 for the L1 norm, 2 for the L2 norm, etc.

-1
progress_bar str

Progress bar type. Can be one of: none, simple, rich. Defaults to rich.

'rich'
precision int

Precision of the model. Can be one of: 32, 16, 64. Defaults to 32. Choices are: [32,16,64].

32
seed int

Seed for random number generators. Defaults to 42

42
trainer_kwargs Dict[str, Any]

Additional kwargs to be passed to PyTorch Lightning Trainer. See https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.trainer.html#pytorch_lightning.trainer.Trainer

dict()
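early_stopping_min_delta, early_stopping_mode and early_stopping_patience interact as follows: a new value counts as an improvement only if it beats the best value so far by more than min_delta in the chosen direction, and training stops once patience consecutive validation checks (epochs, with the default check_val_every_n_epoch=1) pass without one. A simplified pure-Python sketch of that rule; the actual stopping is done by PyTorch Lightning's EarlyStopping callback, which has additional options, and `early_stop_epoch` is a hypothetical helper.

```python
from typing import List, Optional


def early_stop_epoch(
    history: List[float],
    min_delta: float = 0.001,
    patience: int = 3,
    mode: str = "min",
) -> Optional[int]:
    """Hypothetical helper: 0-based index of the check at which training stops, or None."""
    best: Optional[float] = None
    wait = 0
    for epoch, value in enumerate(history):
        if best is None:
            improved = True  # the first value always becomes the reference
        elif mode == "min":
            improved = value < best - min_delta
        else:  # mode == "max"
            improved = value > best + min_delta
        if improved:
            best, wait = value, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


# Defaults mirror the table: min_delta=0.001, patience=3, mode="min".
valid_loss = [0.90, 0.70, 0.6995, 0.69, 0.6899, 0.6898, 0.6897]
print(early_stop_epoch(valid_loss))  # 6
```

Note that the small decreases after 0.69 do not count, because none of them beats the best value by more than min_delta.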
Source code in src/pytorch_tabular/config/config.py
@dataclass
class TrainerConfig:
    """Trainer configuration.

    Args:
        batch_size (int): Number of samples in each batch of training

        data_aware_init_batch_size (int): Number of samples in each batch of training for the data-aware initialization,
            when applicable. Defaults to 2000

        fast_dev_run (bool): Runs n batch(es) of train, val and test if set to ``n`` (int), or 1 batch if
                set to ``True``, to find any bugs (i.e., a sort of unit test).

        max_epochs (int): Maximum number of epochs to be run

        min_epochs (Optional[int]): Force training for at least this many epochs. 1 by default

        max_time (Optional[int]): Stop training after this amount of time has passed. Disabled by default
                (None)

        accelerator (Optional[str]): The accelerator to use for training. Can be one of
                'cpu', 'gpu', 'tpu', 'ipu', 'mps', 'auto'. Defaults to 'auto'.
                Choices are: [`cpu`,`gpu`,`tpu`,`ipu`,`mps`,`auto`].

        devices (Optional[int]): Number of devices to train on (int). -1 uses all available devices. By
                default, uses all available devices (-1)

        devices_list (Optional[List[int]]): List of devices to train on (list). If specified, takes
                precedence over `devices` argument. Defaults to None

        accumulate_grad_batches (int): Accumulates grads every k batches or as set up in the dict. Trainer
                also calls optimizer.step() for the last indivisible step number.

        auto_lr_find (bool): Runs a learning rate finder algorithm when calling
                trainer.tune(), to find optimal initial learning rate.

        auto_select_gpus (bool): If enabled and `devices` is an integer, pick available gpus automatically.
                This is especially useful when GPUs are configured to be in 'exclusive mode', such that only one
                process at a time can access them.

        check_val_every_n_epoch (int): Check val every n train epochs.

        gradient_clip_val (float): Gradient clipping value

        overfit_batches (float): Uses this much data of the training set. If nonzero, will use the same
                training set for validation and testing. If the training dataloaders have shuffle=True, Lightning
                will automatically disable it. Useful for quickly debugging or trying to overfit on purpose.

        deterministic (bool): If true enables cudnn.deterministic. Might make your system slower, but
                ensures reproducibility.

        profiler (Optional[str]): Profiles individual steps during training to help identify
                bottlenecks. One of None, simple, advanced or pytorch. Choices are:
                [`None`,`simple`,`advanced`,`pytorch`].

        early_stopping (Optional[str]): The loss/metric to monitor for early stopping. If None, there will
                be no early stopping

        early_stopping_min_delta (float): The minimum delta in the loss/metric which qualifies as an
                improvement in early stopping

        early_stopping_mode (str): The direction in which the loss/metric should be optimized. Choices are:
                [`max`,`min`].

        early_stopping_patience (int): The number of epochs with no further improvement in the loss/metric
                to wait before stopping

        early_stopping_kwargs (Optional[Dict]): Additional keyword arguments for the early stopping callback.
                See the documentation for the PyTorch Lightning EarlyStopping callback for more details.

        checkpoints (Optional[str]): The loss/metric to monitor for checkpoints. If None, there will be no
                checkpoints

        checkpoints_path (str): The path where the models will be saved

        checkpoints_every_n_epochs (int): Number of training epochs between checkpoints

        checkpoints_name (Optional[str]): The name under which the models will be saved. If left blank,
                first it will look for `run_name` in experiment_config and if that is also None then it will use a
                generic name like task_version.

        checkpoints_mode (str): The direction in which the loss/metric should be optimized

        checkpoints_save_top_k (int): The number of best models to save

        checkpoints_kwargs (Optional[Dict]): Additional keyword arguments for the checkpoints callback.
                See the documentation for the PyTorch Lightning ModelCheckpoint callback for more details.

        load_best (bool): Flag to load the best model saved during training

        track_grad_norm (int): Track and Log Gradient Norms in the logger. -1 by default means no tracking.
                1 for the L1 norm, 2 for L2 norm, etc.

        progress_bar (str): Progress bar type. Can be one of: `none`, `simple`, `rich`. Defaults to `rich`.

        precision (int): Precision of the model. Can be one of: `32`, `16`, `64`. Defaults to `32`.
                Choices are: [`32`,`16`,`64`].

        seed (int): Seed for random number generators. Defaults to 42

        trainer_kwargs (Dict[str, Any]): Additional kwargs to be passed to PyTorch Lightning Trainer. See
                https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.trainer.html#pytorch_lightning.trainer.Trainer

    """

    batch_size: int = field(default=64, metadata={"help": "Number of samples in each batch of training"})
    data_aware_init_batch_size: int = field(
        default=2000,
        metadata={
            "help": "Number of samples in each batch of training for the data-aware initialization,"
            " when applicable. Defaults to 2000"
        },
    )
    fast_dev_run: bool = field(
        default=False,
        metadata={
            "help": "Runs ``n`` batch(es) of train, val and test if set to ``n`` (int), or 1 batch if set to"
            " ``True``, to find any bugs (i.e., a sort of unit test)."
        },
    )
    max_epochs: int = field(default=10, metadata={"help": "Maximum number of epochs to be run"})
    min_epochs: Optional[int] = field(
        default=1,
        metadata={"help": "Force training for at least this many epochs. 1 by default"},
    )
    max_time: Optional[int] = field(
        default=None,
        metadata={"help": "Stop training after this amount of time has passed. Disabled by default (None)"},
    )
    accelerator: Optional[str] = field(
        default="auto",
        metadata={
            "help": "The accelerator to use for training. Can be one of 'cpu','gpu','tpu','ipu','mps','auto'."
            " Defaults to 'auto'",
            "choices": ["cpu", "gpu", "tpu", "ipu", "mps", "auto"],
        },
    )
    devices: Optional[int] = field(
        default=-1,
        metadata={
            "help": "Number of devices to train on. -1 uses all available devices."
            " By default uses all available devices (-1)",
        },
    )
    devices_list: Optional[List[int]] = field(
        default=None,
        metadata={
            "help": "List of devices to train on (list). If specified, takes precedence over `devices` argument."
            " Defaults to None",
        },
    )

    accumulate_grad_batches: int = field(
        default=1,
        metadata={
            "help": "Accumulates grads every k batches or as set up in the dict."
            " Trainer also calls optimizer.step() for the last indivisible step number."
        },
    )
    auto_lr_find: bool = field(
        default=False,
        metadata={
            "help": "Runs a learning rate finder algorithm when calling trainer.tune(),"
            " to find optimal initial learning rate."
        },
    )
    auto_select_gpus: bool = field(
        default=True,
        metadata={
            "help": "If enabled and `devices` is an integer, pick available gpus automatically."
            " This is especially useful when GPUs are configured to be in 'exclusive mode',"
            " such that only one process at a time can access them."
        },
    )
    check_val_every_n_epoch: int = field(default=1, metadata={"help": "Check val every n train epochs."})
    gradient_clip_val: float = field(default=0.0, metadata={"help": "Gradient clipping value"})
    overfit_batches: float = field(
        default=0.0,
        metadata={
            "help": "Uses this much data of the training set. If nonzero, will use the same training set"
            " for validation and testing. If the training dataloaders have shuffle=True,"
            " Lightning will automatically disable it."
            " Useful for quickly debugging or trying to overfit on purpose."
        },
    )
    deterministic: bool = field(
        default=False,
        metadata={
            "help": "If true enables cudnn.deterministic. Might make your system slower, but ensures reproducibility."
        },
    )
    profiler: Optional[str] = field(
        default=None,
        metadata={
            "help": "Profiles individual steps during training to help identify bottlenecks."
            " One of None, simple, advanced or pytorch",
            "choices": [None, "simple", "advanced", "pytorch"],
        },
    )
    early_stopping: Optional[str] = field(
        default="valid_loss",
        metadata={
            "help": "The loss/metric to monitor for early stopping."
            " If None, there will be no early stopping"
        },
    )
    early_stopping_min_delta: float = field(
        default=0.001,
        metadata={"help": "The minimum delta in the loss/metric which qualifies as an improvement in early stopping"},
    )
    early_stopping_mode: str = field(
        default="min",
        metadata={
            "help": "The direction in which the loss/metric should be optimized",
            "choices": ["max", "min"],
        },
    )
    early_stopping_patience: int = field(
        default=3,
        metadata={"help": "The number of epochs with no further improvement in the loss/metric to wait before stopping"},
    )
    early_stopping_kwargs: Optional[Dict[str, Any]] = field(
        default_factory=lambda: {},
        metadata={
            "help": "Additional keyword arguments for the early stopping callback."
            " See the documentation for the PyTorch Lightning EarlyStopping callback for more details."
        },
    )
    checkpoints: Optional[str] = field(
        default="valid_loss",
        metadata={
            "help": "The loss/metric to monitor for checkpoints. If None, there will be no checkpoints"
        },
    )
    checkpoints_path: str = field(
        default="saved_models",
        metadata={"help": "The path where the models will be saved"},
    )
    checkpoints_every_n_epochs: int = field(
        default=1,
        metadata={"help": "Number of training epochs between checkpoints"},
    )
    checkpoints_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "The name under which the models will be saved. If left blank,"
            " first it will look for `run_name` in experiment_config and if that is also None"
            " then it will use a generic name like task_version."
        },
    )
    checkpoints_mode: str = field(
        default="min",
        metadata={"help": "The direction in which the loss/metric should be optimized"},
    )
    checkpoints_save_top_k: int = field(
        default=1,
        metadata={"help": "The number of best models to save"},
    )
    checkpoints_kwargs: Optional[Dict[str, Any]] = field(
        default_factory=lambda: {},
        metadata={
            "help": "Additional keyword arguments for the checkpoints callback. See the documentation"
            " for the PyTorch Lightning ModelCheckpoint callback for more details."
        },
    )
    load_best: bool = field(
        default=True,
        metadata={"help": "Flag to load the best model saved during training"},
    )
    track_grad_norm: int = field(
        default=-1,
        metadata={
            "help": "Track and Log Gradient Norms in the logger. -1 by default means no tracking. "
            "1 for the L1 norm, 2 for L2 norm, etc."
        },
    )
    progress_bar: str = field(
        default="rich",
        metadata={"help": "Progress bar type. Can be one of: `none`, `simple`, `rich`. Defaults to `rich`."},
    )
    precision: int = field(
        default=32,
        metadata={
            "help": "Precision of the model. Can be one of: `32`, `16`, `64`. Defaults to `32`.",
            "choices": [32, 16, 64],
        },
    )
    seed: int = field(
        default=42,
        metadata={"help": "Seed for random number generators. Defaults to 42"},
    )
    trainer_kwargs: Dict[str, Any] = field(
        default_factory=dict,
        metadata={"help": "Additional kwargs to be passed to PyTorch Lightning Trainer."},
    )

    def __post_init__(self):
        _validate_choices(self)
        if self.accelerator is None:
            self.accelerator = "cpu"
        if self.devices_list is not None:
            self.devices = self.devices_list
        delattr(self, "devices_list")
        for key in self.early_stopping_kwargs.keys():
            if key in ["min_delta", "mode", "patience"]:
                raise ValueError(
                    f"Cannot override {key} in early_stopping_kwargs."
                    f" Please use the appropriate argument in `TrainerConfig`"
                )
        for key in self.checkpoints_kwargs.keys():
            if key in ["dirpath", "filename", "monitor", "save_top_k", "mode", "every_n_epochs"]:
                raise ValueError(
                    f"Cannot override {key} in checkpoints_kwargs."
                    f" Please use the appropriate argument in `TrainerConfig`"
                )
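The four early-stopping arguments (`early_stopping`, `early_stopping_min_delta`, `early_stopping_mode`, `early_stopping_patience`) interact roughly as follows. This is a plain-Python sketch modelled loosely on PyTorch Lightning's `EarlyStopping` callback, which actually counts validation checks rather than strictly epochs; it is not the library's code, and `early_stop_epoch` and the loss history are hypothetical. The defaults mirror the `TrainerConfig` defaults above.

```python
from typing import List, Optional


def early_stop_epoch(
    history: List[float],
    min_delta: float = 0.001,
    mode: str = "min",
    patience: int = 3,
) -> Optional[int]:
    """Return the 0-based epoch at which training would stop, or None if it never stops.

    A value counts as an improvement only if it beats the best value so far by more
    than `min_delta` in the direction given by `mode`.
    """
    if mode not in ("min", "max"):
        raise ValueError("mode must be 'min' or 'max'")
    best: Optional[float] = None
    epochs_without_improvement = 0
    for epoch, value in enumerate(history):
        if best is None:
            improved = True  # the first observation always becomes the baseline
        elif mode == "min":
            improved = value < best - min_delta
        else:
            improved = value > best + min_delta
        if improved:
            best = value
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return None


# Hypothetical valid_loss history: after epoch 1 the loss only creeps down by less than min_delta
print(early_stop_epoch([1.0, 0.8, 0.7995, 0.7992, 0.7991]))  # 4
```

Note that the tiny decreases after epoch 1 still count as "no improvement" because each is smaller than `min_delta`.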

Experiment configuration. Experiment Tracking with WandB and Tensorboard.

Parameters:

Name Type Description Default
project_name str

The name of the project under which all runs will be logged. For Tensorboard this defines the folder under which the logs will be saved and for W&B it defines the project name

MISSING
run_name Optional[str]

The name of the run; a specific identifier to recognize the run. If left blank, will be assigned an auto-generated name

None
exp_watch Optional[str]

The level of logging required. Can be gradients, parameters, all or None. Defaults to None. Choices are: [gradients,parameters,all,None].

None
log_target str

Determines where logging happens - Tensorboard or W&B. Choices are: [wandb,tensorboard].

'tensorboard'
log_logits bool

Turn this on to log the logits as a histogram in W&B

False
exp_log_freq int

step count between logging of gradients and parameters.

100
Source code in src/pytorch_tabular/config/config.py
@dataclass
class ExperimentConfig:
    """Experiment configuration. Experiment Tracking with WandB and Tensorboard.

    Args:
        project_name (str): The name of the project under which all runs will be logged. For Tensorboard
                this defines the folder under which the logs will be saved and for W&B it defines the project name

        run_name (Optional[str]): The name of the run; a specific identifier to recognize the run. If left
                blank, will be assigned an auto-generated name

        exp_watch (Optional[str]): The level of logging required.  Can be `gradients`, `parameters`, `all`
                or `None`. Defaults to None. Choices are: [`gradients`,`parameters`,`all`,`None`].

        log_target (str): Determines where logging happens - Tensorboard or W&B. Choices are:
                [`wandb`,`tensorboard`].

        log_logits (bool): Turn this on to log the logits as a histogram in W&B

        exp_log_freq (int): step count between logging of gradients and parameters.

    """

    project_name: str = field(
        default=MISSING,
        metadata={
            "help": "The name of the project under which all runs will be logged."
            " For Tensorboard this defines the folder under which the logs will be saved"
            " and for W&B it defines the project name"
        },
    )

    run_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "The name of the run; a specific identifier to recognize the run."
            " If left blank, will be assigned an auto-generated name"
        },
    )
    exp_watch: Optional[str] = field(
        default=None,
        metadata={
            "help": "The level of logging required.  Can be `gradients`, `parameters`, `all` or `None`."
            " Defaults to None",
            "choices": ["gradients", "parameters", "all", None],
        },
    )

    log_target: str = field(
        default="tensorboard",
        metadata={
            "help": "Determines where logging happens - Tensorboard or W&B",
            "choices": ["wandb", "tensorboard"],
        },
    )
    log_logits: bool = field(
        default=False,
        metadata={"help": "Turn this on to log the logits as a histogram in W&B"},
    )

    exp_log_freq: int = field(
        default=100,
        metadata={"help": "step count between logging of gradients and parameters."},
    )

    def __post_init__(self):
        _validate_choices(self)
        if self.log_target == "wandb":
            try:
                import wandb  # noqa: F401
            except ImportError:
                raise ImportError(
                    "No W&B installation detected. `pip install wandb` to install W&B if you set log_target as `wandb`"
                )
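`_validate_choices` is called at the start of `__post_init__` in these config classes, but its body is not shown on this page. Assuming it compares each field's value against the `choices` list stored in that field's metadata, a hypothetical stand-in could look like this; `validate_choices` and `ToyExperimentConfig` are illustrative names, not part of the library.

```python
from dataclasses import dataclass, field, fields
from typing import Optional


def validate_choices(obj) -> None:
    """Raise ValueError if any dataclass field holds a value outside its metadata["choices"]."""
    for f in fields(obj):
        choices = f.metadata.get("choices")
        value = getattr(obj, f.name)
        if choices is not None and value not in choices:
            raise ValueError(f"{f.name} must be one of {choices}, got {value!r}")


@dataclass
class ToyExperimentConfig:
    # Hypothetical two-field stand-in for ExperimentConfig
    log_target: str = field(default="tensorboard", metadata={"choices": ["wandb", "tensorboard"]})
    exp_watch: Optional[str] = field(
        default=None, metadata={"choices": ["gradients", "parameters", "all", None]}
    )

    def __post_init__(self):
        validate_choices(self)


ToyExperimentConfig(log_target="wandb")  # accepted
try:
    ToyExperimentConfig(log_target="mlflow")
except ValueError as err:
    print(err)  # log_target must be one of ['wandb', 'tensorboard'], got 'mlflow'
```

Keeping the allowed values in field metadata means the same list can drive both validation and the generated "Choices are" documentation.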

Optimizer and Learning Rate Scheduler configuration.

Parameters:

Name Type Description Default
optimizer str

Any of the standard optimizers from torch.optim or provide full python path, for example "torch_optimizer.RAdam".

'Adam'
optimizer_params Dict

The parameters for the optimizer. If left blank, will use default parameters.

{}
lr_scheduler Optional[str]

The name of the LearningRateScheduler to use, if any, from torch.optim.lr_scheduler. If None, will not use any scheduler. Defaults to None

None
lr_scheduler_params Optional[Dict]

The parameters for the LearningRateScheduler. If left blank, will use default parameters.

{}
lr_scheduler_monitor_metric Optional[str]

Used with ReduceLROnPlateau, where the plateau is decided based on this metric

'valid_loss'
Source code in src/pytorch_tabular/config/config.py
@dataclass
class OptimizerConfig:
    """Optimizer and Learning Rate Scheduler configuration.

    Args:
        optimizer (str): Any of the standard optimizers from
                [torch.optim](https://pytorch.org/docs/stable/optim.html#algorithms) or provide full python path,
                for example "torch_optimizer.RAdam".

        optimizer_params (Dict): The parameters for the optimizer. If left blank, will use default
                parameters.

        lr_scheduler (Optional[str]): The name of the LearningRateScheduler to use, if any, from
                [torch.optim.lr_scheduler](https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate).
                If None, will not use any scheduler. Defaults to `None`

        lr_scheduler_params (Optional[Dict]): The parameters for the LearningRateScheduler. If left blank,
                will use default parameters.

        lr_scheduler_monitor_metric (Optional[str]): Used with ReduceLROnPlateau, where the plateau is
                decided based on this metric

    """

    optimizer: str = field(
        default="Adam",
        metadata={
            "help": "Any of the standard optimizers from"
            " [torch.optim](https://pytorch.org/docs/stable/optim.html#algorithms) or provide full python path,"
            " for example 'torch_optimizer.RAdam'."
        },
    )
    optimizer_params: Dict = field(
        default_factory=lambda: {},
        metadata={"help": "The parameters for the optimizer. If left blank, will use default parameters."},
    )
    lr_scheduler: Optional[str] = field(
        default=None,
        metadata={
            "help": "The name of the LearningRateScheduler to use, if any, from"
            " https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate."
            " If None, will not use any scheduler. Defaults to `None`",
        },
    )
    lr_scheduler_params: Optional[Dict] = field(
        default_factory=lambda: {},
        metadata={"help": "The parameters for the LearningRateScheduler. If left blank, will use default parameters."},
    )

    lr_scheduler_monitor_metric: Optional[str] = field(
        default="valid_loss",
        metadata={"help": "Used with ReduceLROnPlateau, where the plateau is decided based on this metric"},
    )

    @staticmethod
    def read_from_yaml(filename: str = "config/optimizer_config.yml"):
        config = _read_yaml(filename)
        if config["lr_scheduler_params"] is None:
            config["lr_scheduler_params"] = {}
        return OptimizerConfig(**config)
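The `optimizer` field accepts either a bare name from `torch.optim` or a full Python path such as "torch_optimizer.RAdam". The resolution code is not shown on this page; a minimal sketch of the general pattern, assuming a bare name is looked up in a default module and a dotted path is imported, is below. `resolve_class` is hypothetical, and standard-library modules stand in for `torch.optim` so the sketch runs without PyTorch.

```python
import collections
import importlib
from types import ModuleType


def resolve_class(name: str, default_namespace: ModuleType):
    """Look up a bare name in `default_namespace`, or import a full dotted path."""
    if "." in name:
        module_path, _, attr = name.rpartition(".")
        return getattr(importlib.import_module(module_path), attr)
    return getattr(default_namespace, name)


# Standard-library stand-ins so the sketch runs without torch installed
print(resolve_class("OrderedDict", collections))  # bare name, looked up in the namespace
print(resolve_class("collections.abc.Mapping", collections))  # full path, imported
```

With PyTorch installed, the same pattern would be called as `resolve_class("Adam", torch.optim)`, or as `resolve_class("torch_optimizer.RAdam", torch.optim)` when the third-party `torch_optimizer` package is available.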
Source code in src/pytorch_tabular/config/config.py
class ExperimentRunManager:
    def __init__(
        self,
        exp_version_manager: str = ".pt_tmp/exp_version_manager.yml",
    ) -> None:
        """This class manages the versions of the experiments based on the name, using a simple dictionary (YAML)
        based lookup. Its primary purpose is to avoid overwriting saved models when training is rerun without
        changing the experiment name.

        Args:
            exp_version_manager (str, optional): The path of the yml file which acts as version control.
                Defaults to ".pt_tmp/exp_version_manager.yml".

        """
        super().__init__()
        self._exp_version_manager = exp_version_manager
        if os.path.exists(exp_version_manager):
            self.exp_version_manager = OmegaConf.load(exp_version_manager)
        else:
            self.exp_version_manager = OmegaConf.create({})
            os.makedirs(os.path.split(exp_version_manager)[0], exist_ok=True)
            with open(self._exp_version_manager, "w") as file:
                OmegaConf.save(config=self.exp_version_manager, f=file)

    def update_versions(self, name):
        if name in self.exp_version_manager.keys():
            uid = self.exp_version_manager[name] + 1
        else:
            uid = 1
        self.exp_version_manager[name] = uid
        with open(self._exp_version_manager, "w") as file:
            OmegaConf.save(config=self.exp_version_manager, f=file)
        return uid
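The versioning scheme above is a persisted per-name counter. A hypothetical stand-in that uses JSON instead of OmegaConf/YAML (a deliberate swap so it needs only the standard library) shows how repeated runs under the same name get increasing version numbers, even across restarts:

```python
import json
import os
import tempfile


class ToyRunVersions:
    """Per-name run counter persisted to a JSON file (hypothetical stand-in, not the library class)."""

    def __init__(self, path: str) -> None:
        self.path = path
        if os.path.exists(path):
            with open(path) as f:
                self.versions = json.load(f)
        else:
            self.versions = {}
            # "or '.'" guards against a bare filename, whose directory part is ""
            os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
            self._save()

    def _save(self) -> None:
        with open(self.path, "w") as f:
            json.dump(self.versions, f)

    def update_versions(self, name: str) -> int:
        uid = self.versions.get(name, 0) + 1  # first run of a name gets version 1
        self.versions[name] = uid
        self._save()
        return uid


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, ".pt_tmp", "versions.json")
    mgr = ToyRunVersions(path)
    first = mgr.update_versions("my_run")  # 1
    second = mgr.update_versions("my_run")  # 2
    reloaded = ToyRunVersions(path).update_versions("my_run")  # 3: the count survives a reload
```

The `or "."` guard is an addition: with a bare filename, `os.path.split(...)[0]` in the `__init__` above yields an empty string, and `os.makedirs("")` raises `FileNotFoundError`.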

__init__(exp_version_manager='.pt_tmp/exp_version_manager.yml')

This class manages the versions of the experiments based on the name, using a simple dictionary (YAML) based lookup. Its primary purpose is to avoid overwriting saved models when training is rerun without changing the experiment name.

Parameters:

Name Type Description Default
exp_version_manager str

The path of the yml file which acts as version control. Defaults to ".pt_tmp/exp_version_manager.yml".

'.pt_tmp/exp_version_manager.yml'

Head Configuration

In addition to these core classes, there are also config classes for the heads.

A model class for Linear Head configuration; serves as a template and documentation. The models take a dictionary as input, but if there are keys which are not present in this model class, it'll throw an exception.

Parameters:

Name Type Description Default
layers str

Hyphen-separated number of layers and units in the classification/regression head, e.g. 32-64-32. The default (an empty string) is a direct mapping from the input dimension to the output dimension.

''
activation str

The activation function used in the classification head. Any of the standard PyTorch activations, such as ReLU, Tanh, LeakyReLU, etc. https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity

'ReLU'
dropout float

Dropout probability, i.e. the probability of an element in the classification head being zeroed.

0.0
use_batch_norm bool

Flag to include a BatchNorm layer after each Linear Layer+DropOut

False
initialization str

Initialization scheme for the linear layers. Defaults to kaiming. Choices are: [kaiming,xavier,random].

'kaiming'
Source code in src/pytorch_tabular/models/common/heads/config.py
@dataclass
class LinearHeadConfig:
    """A model class for Linear Head configuration; serves as a template and documentation. The models take a
    dictionary as input, but if there are keys which are not present in this model class, it'll throw an exception.

    Args:
        layers (str): Hyphen-separated number of layers and units in the classification/regression head.
                E.g. 32-64-32. Default is just a mapping from input dimension to output dimension

        activation (str): The activation function used in the classification head. Any of the standard
                PyTorch activations, such as ReLU, Tanh, LeakyReLU, etc.
                https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity

        dropout (float): Dropout probability, i.e. the probability of an element in the classification head
                being zeroed.

        use_batch_norm (bool): Flag to include a BatchNorm layer after each Linear Layer+DropOut

        initialization (str): Initialization scheme for the linear layers. Defaults to `kaiming`. Choices
                are: [`kaiming`,`xavier`,`random`].

    """

    layers: str = field(
        default="",
        metadata={
            "help": "Hyphen-separated number of layers and units in the classification/regression head. e.g. 32-64-32."
            " Default is just a mapping from input dimension to output dimension"
        },
    )
    activation: str = field(
        default="ReLU",
        metadata={
            "help": "The activation function used in the classification head. Any of the standard PyTorch"
            " activations, such as ReLU, Tanh, LeakyReLU, etc."
            " https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity"
        },
    )
    dropout: float = field(
        default=0.0,
        metadata={"help": "Dropout probability, i.e. the probability of an element in the classification head being zeroed."},
    )
    use_batch_norm: bool = field(
        default=False,
        metadata={"help": "Flag to include a BatchNorm layer after each Linear Layer+DropOut"},
    )
    initialization: str = field(
        default="kaiming",
        metadata={
            "help": "Initialization scheme for the linear layers. Defaults to `kaiming`",
            "choices": ["kaiming", "xavier", "random"],
        },
    )
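A minimal sketch of how a `layers` string such as `32-64-32` can be turned into linear-layer shapes. It assumes the head appends a final projection to the output dimension after the hidden layers, which is consistent with an empty string meaning a direct input-to-output mapping. `head_layer_shapes` is hypothetical, and per the parameters above the real head also places the activation, dropout and optional BatchNorm around these layers.

```python
from typing import List, Tuple


def head_layer_shapes(layers: str, input_dim: int, output_dim: int) -> List[Tuple[int, int]]:
    """Return (in_features, out_features) for each linear layer implied by `layers`."""
    # "32-64-32" -> [32, 64, 32]; "" -> [] (no hidden layers)
    hidden = [int(units) for units in layers.split("-") if units.strip()]
    dims = [input_dim] + hidden + [output_dim]
    # Consecutive pairs of dimensions give each layer's input and output size
    return list(zip(dims[:-1], dims[1:]))


print(head_layer_shapes("32-64-32", input_dim=16, output_dim=1))
# [(16, 32), (32, 64), (64, 32), (32, 1)]
print(head_layer_shapes("", input_dim=16, output_dim=1))
# [(16, 1)]
```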

MixtureDensityHead configuration.

Parameters:

Name Type Description Default
num_gaussian int

Number of Gaussian Distributions in the mixture model. Defaults to 1

1
sigma_bias_flag bool

Whether to have a bias term in the sigma layer. Defaults to False

False
mu_bias_init Optional[List]

To initialize the bias parameter of the mu layer to predefined cluster centers. Should be a list with the same length as number of gaussians in the mixture model. It is highly recommended to set the parameter to combat mode collapse. Defaults to None

None
weight_regularization Optional[int]

Whether to apply L1 or L2 Norm to the MDN layers. Defaults to L2. Choices are: [1,2].

2
lambda_sigma Optional[float]

The regularization constant for weight regularization of sigma layer. Defaults to 0.1

0.1
lambda_pi Optional[float]

The regularization constant for weight regularization of pi layer. Defaults to 0.1

0.1
lambda_mu Optional[float]

The regularization constant for weight regularization of mu layer. Defaults to 0

0
softmax_temperature Optional[float]

The temperature to be used in the Gumbel softmax of the mixing coefficients. Values less than one lead to a sharper transition between the multiple components. Defaults to 1

1
n_samples int

Number of samples to draw from the posterior to get prediction. Defaults to 100

100
central_tendency str

Which measure to use to get the point prediction. Defaults to mean. Choices are: [mean,median].

'mean'
speedup_training bool

Turning on this parameter does away with sampling during training which speeds up training, but also doesn't give you visibility on train metrics. Defaults to False

False
log_debug_plot bool

Turning on this parameter plots histograms of the mu, sigma, and pi layers in addition to the logits (if log_logits is turned on in the experiment config). Defaults to False

False
input_dim int

The input dimensions to the head. This will be automatically filled in while initializing from the backbone.output_dim

None
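To see why `softmax_temperature` values below one sharpen the mixture, here is a plain temperature-scaled softmax over some hypothetical mixing-coefficient logits. This sketch deliberately omits the Gumbel noise, so it shows only the temperature effect and is not the head's actual computation; `tempered_softmax` is a hypothetical helper.

```python
import math
from typing import List


def tempered_softmax(logits: List[float], temperature: float = 1.0) -> List[float]:
    """Softmax of logits / temperature; no Gumbel noise is added."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max before exp() for numerical stability
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.0]  # hypothetical mixing-coefficient logits for 3 components
for t in (2.0, 1.0, 0.5):
    print(t, [round(p, 3) for p in tempered_softmax(logits, t)])
```

As the temperature drops, the weight of the largest logit grows and the others shrink, so the mixture behaves more like a hard choice of one component.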
Source code in src/pytorch_tabular/models/common/heads/config.py
@dataclass
class MixtureDensityHeadConfig:
    """MixtureDensityHead configuration.

    Args:
        num_gaussian (int): Number of Gaussian Distributions in the mixture model. Defaults to 1

        sigma_bias_flag (bool): Whether to have a bias term in the sigma layer. Defaults to False

        mu_bias_init (Optional[List]): To initialize the bias parameter of the mu layer to predefined
                cluster centers. Should be a list with the same length as number of gaussians in the mixture
                model. It is highly recommended to set the parameter to combat mode collapse. Defaults to None

        weight_regularization (Optional[int]): Whether to apply L1 or L2 Norm to the MDN layers. Defaults
                to L2. Choices are: [`1`,`2`].

        lambda_sigma (Optional[float]): The regularization constant for weight regularization of sigma
                layer. Defaults to 0.1

        lambda_pi (Optional[float]): The regularization constant for weight regularization of pi layer.
                Defaults to 0.1

        lambda_mu (Optional[float]): The regularization constant for weight regularization of mu layer.
                Defaults to 0

        softmax_temperature (Optional[float]): The temperature to be used in the gumbel softmax of the
                mixing coefficients. Values less than one lead to a sharper transition between the multiple
                components. Defaults to 1

        n_samples (int): Number of samples to draw from the posterior to get prediction. Defaults to 100

        central_tendency (str): Which measure to use to get the point prediction. Defaults to mean. Choices
                are: [`mean`,`median`].

        speedup_training (bool): Turning on this parameter does away with sampling during training which
                speeds up training, but also doesn't give you visibility on train metrics. Defaults to False

        log_debug_plot (bool): Turning on this parameter plots histograms of the mu, sigma, and pi layers
                in addition to the logits (if log_logits is turned on in the experiment config). Defaults to False

        input_dim (int): The input dimensions to the head. This will be automatically filled in while
                initializing from the `backbone.output_dim`

    """

    num_gaussian: int = field(
        default=1,
        metadata={
            "help": "Number of Gaussian Distributions in the mixture model. Defaults to 1",
        },
    )
    sigma_bias_flag: bool = field(
        default=False,
        metadata={
            "help": "Whether to have a bias term in the sigma layer. Defaults to False",
        },
    )
    mu_bias_init: Optional[List] = field(
        default=None,
        metadata={
            "help": "To initialize the bias parameter of the mu layer to predefined cluster centers."
            " Should be a list with the same length as number of gaussians in the mixture model."
            " It is highly recommended to set the parameter to combat mode collapse. Defaults to None",
        },
    )

    weight_regularization: Optional[int] = field(
        default=2,
        metadata={
            "help": "Whether to apply L1 or L2 Norm to the MDN layers. Defaults to L2",
            "choices": [1, 2],
        },
    )

    lambda_sigma: Optional[float] = field(
        default=0.1,
        metadata={
            "help": "The regularization constant for weight regularization of sigma layer. Defaults to 0.1",
        },
    )
    lambda_pi: Optional[float] = field(
        default=0.1,
        metadata={
            "help": "The regularization constant for weight regularization of the pi layer. Defaults to 0.1",
        },
    )
    lambda_mu: Optional[float] = field(
        default=0,
        metadata={
            "help": "The regularization constant for weight regularization of the mu layer. Defaults to 0",
        },
    )
    softmax_temperature: Optional[float] = field(
        default=1,
        metadata={
            "help": "The temperature used in the Gumbel softmax of the mixing coefficients."
            " Values less than one lead to sharper transitions between the mixture components. Defaults to 1",
        },
    )
    n_samples: int = field(
        default=100,
        metadata={
            "help": "Number of samples to draw from the posterior to get prediction. Defaults to 100",
        },
    )
    central_tendency: str = field(
        default="mean",
        metadata={
            "help": "Which measure to use to get the point prediction. Defaults to mean",
            "choices": ["mean", "median"],
        },
    )
    speedup_training: bool = field(
        default=False,
        metadata={
            "help": "Turning on this parameter skips sampling during training, which speeds up training"
            " but also removes visibility into train metrics. Defaults to False",
        },
    )
    log_debug_plot: bool = field(
        default=False,
        metadata={
            "help": "Turning on this parameter plots histograms of the mu, sigma, and pi layers in addition"
            " to the logits (if log_logits is turned on in the experiment config). Defaults to False",
        },
    )
    input_dim: Optional[int] = field(
        default=None,
        metadata={
            "help": "The input dimension of the head. This is filled in automatically during initialization"
            " from `backbone.output_dim`",
        },
    )
    _probabilistic: bool = field(default=True)
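The fields above govern how a mixture density head turns its mixture parameters into a point prediction. The following is a minimal plain-Python sketch of that flow under stated assumptions, not the head's actual implementation. It uses an ordinary temperature-scaled softmax in place of the Gumbel softmax named in `softmax_temperature`. It assumes the point prediction is the mean or median of `n_samples` posterior draws, as the `n_samples` and `central_tendency` descriptions suggest. It also assumes a penalty of the form `lambda * sum(|w| ** p)` for `weight_regularization`. The helper names `softmax_with_temperature`, `sample_mixture`, `point_prediction`, and `weight_penalty` are hypothetical and not part of this library's API.

```python
import math
import random
import statistics
from typing import List, Sequence


def softmax_with_temperature(logits: Sequence[float], temperature: float = 1.0) -> List[float]:
    # Hypothetical helper: plain temperature-scaled softmax (no Gumbel noise, unlike the real head).
    # temperature < 1 sharpens the mixing weights; temperature > 1 flattens them.
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [x / temperature for x in logits]
    top = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - top) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def sample_mixture(pi: Sequence[float], mu: Sequence[float], sigma: Sequence[float],
                   n_samples: int, rng: random.Random) -> List[float]:
    # For each draw, pick a component according to the mixing weights pi,
    # then sample from that component's Gaussian N(mu[k], sigma[k]).
    components = rng.choices(range(len(pi)), weights=pi, k=n_samples)
    return [rng.gauss(mu[k], sigma[k]) for k in components]


def point_prediction(samples: Sequence[float], central_tendency: str = "mean") -> float:
    # Collapse the posterior samples into a single number.
    if central_tendency == "mean":
        return statistics.fmean(samples)
    if central_tendency == "median":
        return statistics.median(samples)
    raise ValueError("central_tendency must be 'mean' or 'median'")


def weight_penalty(weights: Sequence[float], lam: float, p: int = 2) -> float:
    # Assumed penalty form: lam * sum(|w| ** p); p=1 is L1, p=2 is squared L2.
    if p not in (1, 2):
        raise ValueError("p must be 1 or 2")
    return lam * sum(abs(w) ** p for w in weights)


if __name__ == "__main__":
    rng = random.Random(42)
    # Two-component mixture: mixing weights come from logits via the softmax.
    pi = softmax_with_temperature([2.0, 1.0], temperature=1.0)
    samples = sample_mixture(pi, mu=[0.0, 10.0], sigma=[1.0, 1.0], n_samples=100, rng=rng)
    print("weights:", pi)
    print("mean:", point_prediction(samples, "mean"))
    print("median:", point_prediction(samples, "median"))
    print("L2 penalty:", weight_penalty([0.5, -1.0], lam=0.1, p=2))
```

With a bimodal mixture like the one in the usage block, the mean can land between the two modes while the median stays closer to the heavier component, which is one reason `median` may be preferable for multimodal targets.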