Self-Supervised Models

Configuration Classes

pytorch_tabular.ssl_models.DenoisingAutoEncoderConfig dataclass

Bases: SSLModelConfig

DeNoising AutoEncoder configuration.

PARAMETER DESCRIPTION
noise_strategy

The kind of noise used to corrupt samples. swap: values of a feature are replaced with values from a random permutation of the same feature. zero: values of a feature are replaced with zeros. Choices are: [swap, zero]. Defaults to swap.

TYPE: str DEFAULT: 'swap'

noise_probabilities

Dict of per-feature probabilities of corrupting an input feature with swap/zero noise. Keys are feature names; any feature missing from the dict uses default_noise_probability. Defaults to an empty dict.

TYPE: Dict[str, float] DEFAULT: {}

default_noise_probability

Probability of corrupting an input feature with swap/zero noise, used for any feature that noise_probabilities does not cover. Defaults to 0.8.

TYPE: float DEFAULT: 0.8

loss_type_weights

Weights applied to the loss terms, in the order [binary, categorical, numerical]. If None, default weights are computed by formula, e.g. the default binary weight is n_binary/n_features. Defaults to None.

TYPE: Optional[List[float]] DEFAULT: None

mask_loss_weight

Weight applied to the loss term for the masked features. Defaults to 1.0.

TYPE: float DEFAULT: 1.0

max_onehot_cardinality

Maximum cardinality for one-hot encoding. Any categorical feature with cardinality > max_onehot_cardinality is embedded in a learned embedding space; the others are converted to a one-hot representation. If set to 0, the embedding strategy is used for all categorical features. Defaults to 4.

TYPE: int DEFAULT: 4

encoder_config

The config of the encoder to be used for the model. Should be one of the model configs defined in PyTorch Tabular

TYPE: Optional[pytorch_tabular.config.config.ModelConfig] DEFAULT: None

decoder_config

The config of the decoder to be used for the model. Should be one of the model configs defined in PyTorch Tabular. If left as None, nn.Identity is used as the decoder.

TYPE: Optional[pytorch_tabular.config.config.ModelConfig] DEFAULT: None

embedding_dims

The dimensions of the embedding for each categorical column, as a list of tuples (cardinality, embedding_dim). If left empty, the embedding dimension is inferred from the cardinality of each categorical column using the rule min(50, (x + 1) // 2), where x is the cardinality.

TYPE: Optional[List] DEFAULT: None

embedding_dropout

Dropout to be applied to the Categorical Embedding. Defaults to 0.1

TYPE: float DEFAULT: 0.1

batch_norm_continuous_input

If True, we will normalize the continuous layer by passing it through a BatchNorm layer. DEPRECATED - Use head and head_config instead

TYPE: bool DEFAULT: True

learning_rate

The learning rate of the model. Defaults to 1e-3

TYPE: float DEFAULT: 0.001

seed

The seed for reproducibility. Defaults to 42

TYPE: int DEFAULT: 42
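
The three noise parameters in the list above work together: each feature is corrupted with its own probability from noise_probabilities, falling back to default_noise_probability, and noise_strategy decides what a corrupted value becomes. The sketch below illustrates this on plain Python lists. The helper names resolve_noise_probability and corrupt_column are hypothetical and not part of PyTorch Tabular, and sampling the probability independently per cell is an assumption; the library works on tensors and may differ in detail.

```python
import random


def resolve_noise_probability(feature, noise_probabilities, default_noise_probability=0.8):
    """Per-feature probability, falling back to the default (hypothetical helper)."""
    return noise_probabilities.get(feature, default_noise_probability)


def corrupt_column(values, strategy, probability, rng):
    """Return (corrupted_values, mask) for one feature column (hypothetical helper).

    mask[i] is 1 where row i was selected for corruption, else 0.
    """
    if strategy == "swap":
        # Swap noise: replacements come from a random permutation of the same column.
        replacements = list(values)
        rng.shuffle(replacements)
    elif strategy == "zero":
        # Zero noise: every selected value becomes zero.
        replacements = [0] * len(values)
    else:
        raise ValueError("strategy must be 'swap' or 'zero'")

    corrupted, mask = [], []
    for original, replacement in zip(values, replacements):
        # Assumption: each cell is selected independently with the given probability.
        selected = rng.random() < probability
        corrupted.append(replacement if selected else original)
        mask.append(int(selected))
    return corrupted, mask


data = {"age": [23, 35, 41, 58, 62], "income": [30.0, 52.5, 48.0, 75.0, 61.0]}
noise_probabilities = {"age": 0.3}  # "income" falls back to the default of 0.8
rng = random.Random(42)

for feature, column in data.items():
    p = resolve_noise_probability(feature, noise_probabilities)
    corrupted, mask = corrupt_column(column, "swap", p, rng)
    print(feature, p, corrupted, mask)
```

With swap noise, a selected row can still end up with its original value, either because the permutation maps the row to itself or because another row holds the same value.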

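For loss_type_weights, the description only spells out the default for the binary term (n_binary/n_features). The sketch below assumes, by analogy, that the categorical and numerical terms follow the same count-over-total pattern; that extension and the function name default_loss_type_weights are assumptions, not confirmed library behavior.

```python
def default_loss_type_weights(n_binary, n_categorical, n_numerical):
    """Default weights in the order [binary, categorical, numerical].

    Only the binary formula is stated in the parameter description; the other
    two are assumed here to follow the same n_type / n_features pattern.
    """
    n_features = n_binary + n_categorical + n_numerical
    if n_features == 0:
        raise ValueError("at least one feature is required")
    return [
        n_binary / n_features,
        n_categorical / n_features,
        n_numerical / n_features,
    ]


weights = default_loss_type_weights(n_binary=2, n_categorical=3, n_numerical=5)
print(weights)  # [0.2, 0.3, 0.5]
```

Passing an explicit list such as [1.0, 1.0, 1.0] for loss_type_weights would override whatever defaults the library computes.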
Model Classes

pytorch_tabular.ssl_models.DenoisingAutoEncoderModel(config, **kwargs)

Bases: SSLBaseModel

Source code in src/pytorch_tabular/ssl_models/dae/dae.py
def __init__(self, config: DictConfig, **kwargs):
    encoded_cat_dims = 0
    inferred_config = kwargs.get("inferred_config")
    for card, embd_dim in inferred_config.embedding_dims:
        if card == 2:
            encoded_cat_dims += 1
        elif card <= config.max_onehot_cardinality:
            encoded_cat_dims += card
        else:
            encoded_cat_dims += embd_dim
    config.encoder_config._backbone_input_dim = encoded_cat_dims + len(config.continuous_cols)
    assert (
        config.encoder_config._config_name in self.ALLOWED_MODELS
    ), "Encoder must be one of the following: " + ", ".join(self.ALLOWED_MODELS)
    if config.decoder_config is not None:
        assert (
            config.decoder_config._config_name in self.ALLOWED_MODELS
        ), "Decoder must be one of the following: " + ", ".join(self.ALLOWED_MODELS)
        if "-" in config.encoder_config.layers:
            config.decoder_config._backbone_input_dim = int(config.encoder_config.layers.split("-")[-1])
        else:
            config.decoder_config._backbone_input_dim = int(config.encoder_config.layers)
    super().__init__(config, **kwargs)
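
The arithmetic in this constructor can be traced by hand. The sketch below mirrors the loop on plain (cardinality, embedding_dim) tuples and uses the min(50, (x + 1) // 2) rule from the embedding_dims description to fill in the embedding sizes. The function names are hypothetical helpers for illustration, not library API, and the cardinalities are made-up example values.

```python
def infer_embedding_dim(cardinality):
    """Embedding size rule quoted in the embedding_dims description."""
    return min(50, (cardinality + 1) // 2)


def encoded_categorical_width(embedding_dims, max_onehot_cardinality=4):
    """Mirror the constructor's loop on plain (cardinality, embedding_dim) tuples."""
    width = 0
    for cardinality, embedding_dim in embedding_dims:
        if cardinality == 2:
            width += 1  # binary feature: a single column
        elif cardinality <= max_onehot_cardinality:
            width += cardinality  # low cardinality: one-hot encoded
        else:
            width += embedding_dim  # high cardinality: learned embedding
    return width


def last_layer_width(layers):
    """Width of the final encoder layer from a spec such as "512-256-64".

    Splitting on "-" also covers a single-layer spec such as "128".
    """
    return int(layers.split("-")[-1])


cardinalities = [2, 3, 10, 200]  # example values only
embedding_dims = [(c, infer_embedding_dim(c)) for c in cardinalities]
n_continuous = 5

cat_width = encoded_categorical_width(embedding_dims)
print(embedding_dims)                  # [(2, 1), (3, 2), (10, 5), (200, 50)]
print(cat_width)                       # 1 + 3 + 5 + 50 = 59
print(cat_width + n_continuous)        # encoder input width: 64
print(last_layer_width("512-256-64"))  # decoder input width: 64
```

The last value matters only when a decoder_config is given: the decoder's input width is set to the width of the encoder's final layer.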

Base Model Class

pytorch_tabular.ssl_models.SSLBaseModel(config, mode='pretrain', encoder=None, decoder=None, custom_optimizer=None, custom_optimizer_params={}, **kwargs)

Bases: pl.LightningModule

Source code in src/pytorch_tabular/ssl_models/base_model.py
def __init__(
    self,
    config: DictConfig,
    mode: str = "pretrain",
    encoder: Optional[nn.Module] = None,
    decoder: Optional[nn.Module] = None,
    custom_optimizer: Optional[torch.optim.Optimizer] = None,
    custom_optimizer_params: Dict = {},
    **kwargs,
):
    super().__init__()
    assert "inferred_config" in kwargs, "inferred_config not found in initialization arguments"
    inferred_config = kwargs["inferred_config"]
    # Merging the config and inferred config
    config = safe_merge_config(config, inferred_config)

    self._setup_encoder_decoder(
        encoder,
        config.encoder_config,
        decoder,
        config.decoder_config,
        inferred_config,
    )
    self.custom_optimizer = custom_optimizer
    self.custom_optimizer_params = custom_optimizer_params
    # Updating config with custom parameters for experiment tracking
    if self.custom_optimizer is not None:
        config.optimizer = str(self.custom_optimizer.__class__.__name__)
    if len(self.custom_optimizer_params) > 0:
        config.optimizer_params = self.custom_optimizer_params
    self.mode = mode
    self._check_and_verify()
    self.save_hyperparameters(config)
    self._build_network()
    self._setup_loss()
    self._setup_metrics()
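
The constructor above calls a fixed sequence of hook methods. The sketch below reproduces only that ordering in a plain Python class so it can be run without PyTorch Lightning. SketchSSLModel is a hypothetical stand-in, the dict merge is a simplification of safe_merge_config, and the config keys are made up for illustration; the save_hyperparameters call is omitted.

```python
class SketchSSLModel:
    """Plain-Python stand-in (not a LightningModule) that records the hook order."""

    def __init__(self, config, mode="pretrain", **kwargs):
        if "inferred_config" not in kwargs:
            raise ValueError("inferred_config not found in initialization arguments")
        # Stand-in for safe_merge_config: a plain dict merge, for illustration only.
        self.config = {**config, **kwargs["inferred_config"]}
        self.mode = mode
        self.calls = []
        # Same order as the constructor shown above.
        self._setup_encoder_decoder()
        self._check_and_verify()
        self._build_network()
        self._setup_loss()
        self._setup_metrics()

    def _setup_encoder_decoder(self):
        self.calls.append("_setup_encoder_decoder")

    def _check_and_verify(self):
        self.calls.append("_check_and_verify")

    def _build_network(self):
        self.calls.append("_build_network")

    def _setup_loss(self):
        self.calls.append("_setup_loss")

    def _setup_metrics(self):
        self.calls.append("_setup_metrics")


# "learning_rate" and "n_features" are example keys, not a required schema.
model = SketchSSLModel({"learning_rate": 1e-3}, inferred_config={"n_features": 8})
print(model.calls)
```

Because the encoder and decoder are set up and verified before the network is built, a subclass hook such as _build_network can rely on them already being in place.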