
Self-Supervised Models

Configuration Classes

DenoisingAutoEncoderConfig

Bases: SSLModelConfig

Denoising autoencoder configuration.

Parameters:

Name Type Description Default
noise_strategy str

The kind of noise used to corrupt samples. swap: swap noise replaces the values of a feature with a random permutation of the same feature. zero: zero noise replaces the values of a feature with zeros. Defaults to swap. Choices are: [swap, zero].

'swap'
noise_probabilities Dict[str, float]

Dict of per-feature probabilities for corrupting the input features with swap/zero noise. Keys are feature names; for any feature missing from the dict, default_noise_probability is used. Defaults to an empty dict.

{}
default_noise_probability float

Default probability of corrupting an input feature with swap/zero noise, used for features whose probability is not set in noise_probabilities. Default is 0.8

0.8
loss_type_weights Optional[List[float]]

Weights for the loss function, in the order [binary, categorical, numerical]. If None, default weights are computed from the feature counts; e.g. for binary, the default weight is n_binary/n_features. Defaults to None

None
mask_loss_weight float

Weight applied to the mask reconstruction loss (predicting which features were corrupted). Defaults to 2.0

2.0
max_onehot_cardinality int

Maximum cardinality for one-hot encoding of categorical features. Any categorical feature with cardinality > max_onehot_cardinality is embedded in a learned embedding space, and the others are converted to a one-hot representation. If set to 0, the embedding strategy is used for all categorical features. Default is 4. (In DenoisingAutoEncoderModel, features with cardinality 2 are counted as a single binary column.)

4
include_input_features_inference bool

If True, the input features are included along with the learned features during fine-tuning. Defaults to False

False
encoder_config Optional[ModelConfig]

The config of the encoder to be used for the model. Should be one of the model configs defined in PyTorch Tabular; DenoisingAutoEncoderModel currently accepts only CategoryEmbeddingModelConfig.

None
decoder_config Optional[ModelConfig]

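The config of the decoder to be used for the model. Should be one of the model configs defined in PyTorch Tabular (DenoisingAutoEncoderModel currently accepts only CategoryEmbeddingModelConfig). If None, nn.Identity is used as the decoder.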

None
embedding_dims Optional[List]

The dimensions of the embedding for each categorical column as a list of tuples (cardinality, embedding_dim). If left empty, they are inferred from the cardinality x of each categorical column using the rule min(50, (x + 1) // 2)

None
embedding_dropout float

Dropout to be applied to the Categorical Embedding. Defaults to 0.1

0.1
batch_norm_continuous_input bool

If True, the continuous input is normalized by passing it through a BatchNorm layer. Defaults to True

True
learning_rate float

The learning rate of the model. Defaults to 1e-3

0.001
seed int

The seed for reproducibility. Defaults to 42

42
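To make the noise settings concrete, here is a minimal pure-Python sketch of swap and zero noise with per-feature probabilities. It assumes a column-oriented dict of lists; `corrupt_columns` is a hypothetical helper, not part of PyTorch Tabular (which applies the noise to tensors inside the model's featurizer). The returned 0/1 mask marks the cells selected for corruption, which is the kind of target the model's mask-reconstruction head learns to predict.

```python
import random
from typing import Dict, List, Tuple

Columns = Dict[str, List[float]]


def corrupt_columns(
    columns: Columns,
    noise_probabilities: Dict[str, float],
    default_noise_probability: float = 0.8,
    noise_strategy: str = "swap",
    seed: int = 42,
) -> Tuple[Columns, Dict[str, List[int]]]:
    """Hypothetical helper: return (corrupted_columns, mask); mask is 1 where a cell was selected."""
    if noise_strategy not in ("swap", "zero"):
        raise ValueError("noise_strategy must be 'swap' or 'zero'")
    rng = random.Random(seed)
    corrupted: Columns = {}
    mask: Dict[str, List[int]] = {}
    for name, values in columns.items():
        # Per-feature probability, falling back to the default for unlisted features
        p = noise_probabilities.get(name, default_noise_probability)
        if noise_strategy == "swap":
            # Swap noise: replacements come from a random permutation of the same column
            replacement = values[:]
            rng.shuffle(replacement)
        else:
            # Zero noise: selected cells are replaced with zeros
            replacement = [0.0] * len(values)
        col_mask = [1 if rng.random() < p else 0 for _ in values]
        corrupted[name] = [r if m else v for v, r, m in zip(values, replacement, col_mask)]
        mask[name] = col_mask
    return corrupted, mask


cols = {"age": [21.0, 35.0, 47.0, 52.0], "income": [1.0, 2.0, 3.0, 4.0]}
noisy, mask = corrupt_columns(cols, {"age": 1.0, "income": 0.0}, noise_strategy="zero")
print(noisy["age"], noisy["income"])  # [0.0, 0.0, 0.0, 0.0] [1.0, 2.0, 3.0, 4.0]
```

With swap noise, a selected cell can occasionally receive its own value back when the permutation maps it to itself, so "selected" and "actually changed" are not always the same.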
Source code in src/pytorch_tabular/ssl_models/dae/config.py
@dataclass
class DenoisingAutoEncoderConfig(SSLModelConfig):
    """Denoising AutoEncoder configuration.

    Args:
        noise_strategy (str): Defines what kind of noise we are introducing to samples. `swap` - Swap noise
                is when we replace values of a feature with random permutations of the same feature. `zero` - Zero
                noise is when we replace values of a feature with zeros. Defaults to swap. Choices are:
                [`swap`,`zero`].

        noise_probabilities (Dict[str, float]): Dict of individual probabilities to corrupt the input
                features with swap/zero noise. Key should be the feature name and if any feature is missing, the
                default_noise_probability is used. Default is an empty dict()

        default_noise_probability (float): Default probability to corrupt the input features with swap/zero
                noise, used for features whose probability is not set in noise_probabilities. Default is 0.8

        loss_type_weights (Optional[List[float]]): Weights to be used for the loss function in the order
                [binary, categorical, numerical]. If None, default weights are computed from the feature
                counts, e.g. for binary, the default weight is n_binary/n_features. Defaults to None

        mask_loss_weight (float): Weight to be used for the mask reconstruction loss. Defaults
                to 2.0

        max_onehot_cardinality (int): Maximum cardinality of one-hot encoded categorical features. Any
                categorical feature with cardinality>max_onehot_cardinality will be embedded in a learned
                embedding space and others will be converted to a one hot representation. If set to 0, will use
                the embedding strategy for all categorical features. Default is 4

        include_input_features_inference (bool): If True, will include the input features along with the
                learned features while fine-tuning. Defaults to False

        encoder_config (Optional[pytorch_tabular.config.config.ModelConfig]): The config of the encoder to
                be used for the model. Should be one of the model configs defined in PyTorch Tabular

        decoder_config (Optional[pytorch_tabular.config.config.ModelConfig]): The config of the decoder to be
                used for the model. Should be one of the model configs defined in PyTorch Tabular. If None,
                nn.Identity is used as the decoder.

        embedding_dims (Optional[List]): The dimensions of the embedding for each categorical column as a
                list of tuples (cardinality, embedding_dim). If left empty, will infer using the cardinality of
                the categorical column using the rule min(50, (x + 1) // 2)

        embedding_dropout (float): Dropout to be applied to the Categorical Embedding. Defaults to 0.1

        batch_norm_continuous_input (bool): If True, we will normalize the continuous layer by passing it
                through a BatchNorm layer. Defaults to True

        learning_rate (float): The learning rate of the model. Defaults to 1e-3

        seed (int): The seed for reproducibility. Defaults to 42

    """

    noise_strategy: str = field(
        default="swap",
        metadata={
            "help": "Defines what kind of noise we are introducing to samples."
            " `swap` - Swap noise is when we replace values of a feature with random permutations"
            " of the same feature. `zero` - Zero noise is when we replace values of a feature with zeros."
            " Defaults to swap",
            "choices": ["swap", "zero"],
        },
    )
    # omegaconf does not support Union types, so this is typed as Dict[str, float]
    # rather than Union[float, Dict[str, float]]
    noise_probabilities: Dict[str, float] = field(
        default_factory=dict,
        metadata={
            "help": "Dict of individual probabilities to corrupt the input features with swap/zero noise."
            " Key should be the feature name and if any feature is missing,"
            " the default_noise_probability is used. Default is an empty dict()"
        },
    )
    default_noise_probability: float = field(
        default=0.8,
        metadata={
            "help": "Default probability to corrupt the input features with swap/zero noise."
            " Used for features whose probability is not set in noise_probabilities. Default is 0.8"
        },
    )
    loss_type_weights: Optional[List[float]] = field(
        default=None,
        metadata={
            "help": "Weights to be used for the loss function in the order [binary, categorical, numerical]."
            " If None, default weights are computed from the feature counts, e.g. for binary,"
            " default weight will be n_binary/n_features. Defaults to None"
        },
    )
    mask_loss_weight: float = field(
        default=2.0,
        metadata={"help": "Weight to be used for the mask reconstruction loss. Defaults to 2.0"},
    )
    max_onehot_cardinality: int = field(
        default=4,
        metadata={
            "help": "Maximum cardinality of one-hot encoded categorical features."
            " Any categorical feature with cardinality>max_onehot_cardinality will be embedded"
            " in a learned embedding space and others will be converted to a one hot representation."
            " If set to 0, will use the embedding strategy for all categorical features. Default is 4"
        },
    )
    include_input_features_inference: bool = field(
        default=False,
        metadata={
            "help": "If True, will include the input features along with the learned features"
            " while fine-tuning. Defaults to False"
        },
    )

    _module_src: str = field(default="ssl_models.dae")
    _model_name: str = field(default="DenoisingAutoEncoderModel")
    _config_name: str = field(default="DenoisingAutoEncoderConfig")

    def __post_init__(self):
        assert hasattr(self.encoder_config, "_backbone_name"), "encoder_config should have a _backbone_name attribute"
        if self.decoder_config is not None:
            assert hasattr(
                self.decoder_config, "_backbone_name"
            ), "decoder_config should have a _backbone_name attribute"
        super().__post_init__()
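noise_probabilities uses a default factory because dataclasses reject mutable defaults such as a bare `{}`. The trimmed-down `NoiseSettings` dataclass below is hypothetical (not the library's class) and shows that pattern, the per-feature lookup with a fallback to default_noise_probability, and a simple choices check. The validation in `__post_init__` is an illustrative assumption, not a description of how PyTorch Tabular validates its configs.

```python
from dataclasses import dataclass, field
from typing import Dict


@dataclass
class NoiseSettings:  # hypothetical stand-in, not DenoisingAutoEncoderConfig
    noise_strategy: str = "swap"
    # default_factory builds a fresh dict per instance; a bare `{}` default raises ValueError
    noise_probabilities: Dict[str, float] = field(default_factory=dict)
    default_noise_probability: float = 0.8

    def __post_init__(self):
        # Illustrative validation (assumption): enforce the documented choices and a [0, 1] range
        if self.noise_strategy not in ("swap", "zero"):
            raise ValueError(f"noise_strategy must be 'swap' or 'zero', got {self.noise_strategy!r}")
        for name, p in {**self.noise_probabilities, "<default>": self.default_noise_probability}.items():
            if not 0.0 <= p <= 1.0:
                raise ValueError(f"probability for {name} must be in [0, 1], got {p}")

    def probability_for(self, feature: str) -> float:
        # Features missing from noise_probabilities fall back to the default
        return self.noise_probabilities.get(feature, self.default_noise_probability)


a = NoiseSettings(noise_probabilities={"age": 0.3})
b = NoiseSettings()
print(a.probability_for("age"), a.probability_for("income"), b.noise_probabilities)  # 0.3 0.8 {}
```

The `default_factory` guarantees that mutating one instance's dict never leaks into another instance.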

Model Classes

DenoisingAutoEncoderModel

Bases: SSLBaseModel
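The constructor below sizes the encoder input from the encoded width of every feature: a cardinality-2 categorical becomes one binary column, one with cardinality up to max_onehot_cardinality becomes a one-hot block, and anything larger uses its learned embedding width; each continuous column adds one. The decoder input is the last hidden size in the encoder's `layers` string. The standalone functions here restate that arithmetic as a sketch; their names are hypothetical and they take plain values instead of the library's config objects.

```python
from typing import List, Tuple


def encoder_input_dim(
    embedding_dims: List[Tuple[int, int]], n_continuous: int, max_onehot_cardinality: int = 4
) -> int:
    """Hypothetical helper: encoded input width for (cardinality, embedding_dim) pairs."""
    encoded_cat_dims = 0
    for card, embd_dim in embedding_dims:
        if card == 2:
            encoded_cat_dims += 1  # binary feature: a single 0/1 column
        elif card <= max_onehot_cardinality:
            encoded_cat_dims += card  # one-hot: one column per category
        else:
            encoded_cat_dims += embd_dim  # learned embedding
    return encoded_cat_dims + n_continuous


def decoder_input_dim(encoder_layers: str) -> int:
    """Hypothetical helper: last hidden size of a spec such as "128-64-32"."""
    # str.split returns [spec] when there is no "-", so one expression covers both cases
    return int(encoder_layers.split("-")[-1])


# (2, 1) -> 1 binary column, (3, 2) -> 3 one-hot columns, (10, 5) -> 5 embedding dims, plus 4 continuous
print(encoder_input_dim([(2, 1), (3, 2), (10, 5)], n_continuous=4))  # 13
print(decoder_input_dim("128-64-32"), decoder_input_dim("64"))  # 32 64
```

Because `"64".split("-")` already yields `["64"]`, the explicit `"-" in layers` branch in the constructor is not strictly necessary.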

Source code in src/pytorch_tabular/ssl_models/dae/dae.py
class DenoisingAutoEncoderModel(SSLBaseModel):
    output_tuple = namedtuple("output_tuple", ["original", "reconstructed"])
    loss_weight_tuple = namedtuple("loss_weight_tuple", ["binary", "categorical", "continuous", "mask"])
    # fix for pickling
    # https://codefying.com/2019/05/04/dont-get-in-a-pickle-with-a-namedtuple/
    output_tuple.__qualname__ = "DenoisingAutoEncoderModel.output_tuple"
    loss_weight_tuple.__qualname__ = "DenoisingAutoEncoderModel.loss_weight_tuple"
    ALLOWED_MODELS = ["CategoryEmbeddingModelConfig"]

    def __init__(self, config: DictConfig, **kwargs):
        encoded_cat_dims = 0
        inferred_config = kwargs.get("inferred_config")
        for card, embd_dim in inferred_config.embedding_dims:
            if card == 2:
                encoded_cat_dims += 1
            elif card <= config.max_onehot_cardinality:
                encoded_cat_dims += card
            else:
                encoded_cat_dims += embd_dim
        config.encoder_config._backbone_input_dim = encoded_cat_dims + len(config.continuous_cols)
        assert config.encoder_config._config_name in self.ALLOWED_MODELS, (
            "Encoder must be one of the following: " + ", ".join(self.ALLOWED_MODELS)
        )
        if config.decoder_config is not None:
            assert config.decoder_config._config_name in self.ALLOWED_MODELS, (
                "Decoder must be one of the following: " + ", ".join(self.ALLOWED_MODELS)
            )
            if "-" in config.encoder_config.layers:
                config.decoder_config._backbone_input_dim = int(config.encoder_config.layers.split("-")[-1])
            else:
                config.decoder_config._backbone_input_dim = int(config.encoder_config.layers)
        super().__init__(config, **kwargs)

    def _get_noise_probability(self, name):
        return self.hparams.noise_probabilities.get(name, self.hparams.default_noise_probability)

    @property
    def embedding_layer(self):
        return self._embedding

    @property
    def featurizer(self):
        return self._featurizer

    def _build_network(self):
        self._featurizer = DenoisingAutoEncoderFeaturizer(self.encoder, self.hparams)
        self._embedding = self._featurizer._build_embedding_layer()
        self.reconstruction = MultiTaskHead(
            self.decoder.output_dim,
            n_binary=len(self._embedding._binary_feat_idx),
            n_categorical=len(self._embedding._onehot_feat_idx),
            n_numerical=self._embedding.embedded_cat_dim + len(self.hparams.continuous_cols),
            cardinality=[self._embedding.categorical_embedding_dims[i][0] for i in self._embedding._onehot_feat_idx],
        )
        self.mask_reconstruction = nn.Linear(self.decoder.output_dim, len(self._featurizer.swap_noise.probas))

    def _setup_loss(self):
        self.losses = {
            "binary": nn.BCEWithLogitsLoss(),
            "categorical": nn.CrossEntropyLoss(),
            "continuous": nn.MSELoss(),
            "mask": nn.BCEWithLogitsLoss(),
        }
        if self.hparams.loss_type_weights is None:
            self.loss_weights = self.loss_weight_tuple(*self._init_loss_weights())
        else:
            self.loss_weights = self.loss_weight_tuple(*self.hparams.loss_type_weights, self.hparams.mask_loss_weight)

    def _init_loss_weights(self):
        n_features = self.hparams.continuous_dim + len(self.hparams.embedding_dims)
        return [
            len(self.embedding_layer._binary_feat_idx) / n_features,
            len(self.embedding_layer._onehot_feat_idx) / n_features,
            (self.hparams.continuous_dim + len(self.embedding_layer._embedding_feat_idx)) / n_features,
            self.hparams.mask_loss_weight,
        ]

    def _setup_metrics(self):
        return None

    def forward(self, x: Dict):
        if self.mode == "pretrain":
            x = self.embedding_layer(x)
            # (B, N, E)
            features = self.featurizer(x, perturb=True)
            z, mask = features.features, features.mask
            # decoder
            z_hat = self.decoder(z)
            # reconstruction
            reconstructed_in = self.reconstruction(z_hat)
            # mask reconstruction
            reconstructed_mask = self.mask_reconstruction(z_hat)
            output_dict = {"mask": self.output_tuple(mask, reconstructed_mask)}
            if "continuous" in reconstructed_in.keys():
                output_dict["continuous"] = self.output_tuple(
                    torch.cat(
                        [
                            i
                            for i in [
                                x.get("continuous", None),
                                x.get("embedding", None),
                            ]
                            if i is not None
                        ],
                        1,
                    ),
                    reconstructed_in["continuous"],
                )
            if "categorical" in reconstructed_in.keys():
                output_dict["categorical"] = self.output_tuple(x["_categorical_orig"], reconstructed_in["categorical"])
            if "binary" in reconstructed_in.keys():
                output_dict["binary"] = self.output_tuple(x["binary"], reconstructed_in["binary"])
            return output_dict
        else:  # self.mode == "finetune"
            z, x = self.featurizer(x, perturb=False, return_input=True)
            if self.hparams.include_input_features_inference:
                return torch.cat([z.features, x], 1)
            else:
                return z.features

    def calculate_loss(self, output, tag):
        total_loss = 0
        for type_, out in output.items():
            if type_ == "categorical":
                loss = 0
                for i in range(out.original.size(-1)):
                    loss += self.losses[type_](out.reconstructed[i], out.original[:, i])
            elif type_ == "binary":
                # Casting output to float for BCEWithLogitsLoss
                loss = self.losses[type_](out.reconstructed, out.original.float())
            else:
                loss = self.losses[type_](out.reconstructed, out.original)
            loss *= getattr(self.loss_weights, type_)
            self.log(
                f"{tag}_{type_}_loss",
                loss.item(),
                on_epoch=True,
                on_step=False,
                logger=True,
                prog_bar=False,
            )
            total_loss += loss
        self.log(
            f"{tag}_loss",
            total_loss,
            on_epoch=(tag == "valid") or (tag == "test"),
            on_step=(tag == "train"),
            # on_step=False,
            logger=True,
            prog_bar=True,
        )
        return total_loss

    def calculate_metrics(self, output, tag):
        pass

    def featurize(self, x: Dict):
        x = self.embedding_layer(x)
        return self.featurizer(x, perturb=False).features

    @property
    def output_dim(self):
        if self.mode == "finetune" and self.hparams.include_input_features_inference:
            return self._featurizer.encoder.output_dim + self.hparams.encoder_config._backbone_input_dim
        else:
            return self._featurizer.encoder.output_dim
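When loss_type_weights is None, the binary, categorical and continuous weights are presumably meant to be each feature type's share of the total feature count (continuous features plus all categorical columns). Embedded categoricals count on the continuous side because their embeddings are reconstructed with MSE alongside the continuous columns, so the continuous numerator must be parenthesized before dividing. calculate_loss then scales each per-type loss by its weight and sums them. The sketch below restates both steps in plain Python with hypothetical function names, assuming mean reduction as in the PyTorch loss defaults; it is not the library's implementation.

```python
import math
from typing import Dict, Sequence


def default_loss_weights(
    n_binary: int, n_onehot: int, n_embedded: int, n_continuous: int, mask_loss_weight: float = 2.0
) -> Dict[str, float]:
    n_features = n_binary + n_onehot + n_embedded + n_continuous
    return {
        "binary": n_binary / n_features,
        "categorical": n_onehot / n_features,
        # Parenthesized numerator: (continuous + embedded) / total
        "continuous": (n_continuous + n_embedded) / n_features,
        "mask": mask_loss_weight,
    }


def mse(pred: Sequence[float], target: Sequence[float]) -> float:
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def bce_with_logits(logits: Sequence[float], target: Sequence[float]) -> float:
    # Numerically stable form: max(x, 0) - x * y + log(1 + exp(-|x|))
    return sum(max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x))) for x, y in zip(logits, target)) / len(logits)


def cross_entropy(logit_rows: Sequence[Sequence[float]], targets: Sequence[int]) -> float:
    # Cross-entropy for one categorical column (rows of logits, integer class targets)
    total = 0.0
    for row, t in zip(logit_rows, targets):
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        total += log_sum_exp - row[t]
    return total / len(targets)


def weighted_total_loss(per_type_loss: Dict[str, float], weights: Dict[str, float]) -> float:
    # Each output type contributes its loss scaled by its weight
    return sum(loss * weights[type_] for type_, loss in per_type_loss.items())


w = default_loss_weights(n_binary=1, n_onehot=2, n_embedded=1, n_continuous=4)
print(w)  # {'binary': 0.125, 'categorical': 0.25, 'continuous': 0.625, 'mask': 2.0}
losses = {
    "continuous": mse([1.0, 2.0], [1.0, 4.0]),
    # calculate_loss sums the cross-entropy over one-hot columns; a single column here
    "categorical": cross_entropy([[0.0, 0.0]], [0]),
    "mask": bce_with_logits([0.0], [1.0]),
}
print(weighted_total_loss(losses, w))
```

Without the parentheses, the same counts would give a continuous weight of 4 + 1/8 = 4.125 instead of 0.625, letting the continuous term dominate the total.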

Base Model Class

SSLBaseModel

Bases: LightningModule

Source code in src/pytorch_tabular/ssl_models/base_model.py
class SSLBaseModel(pl.LightningModule, metaclass=ABCMeta):
    def __init__(
        self,
        config: DictConfig,
        mode: str = "pretrain",
        encoder: Optional[nn.Module] = None,
        decoder: Optional[nn.Module] = None,
        custom_optimizer: Optional[torch.optim.Optimizer] = None,
        custom_optimizer_params: Dict = {},
        **kwargs,
    ):
        """Base Model for all SSL Models.

        Args:
            config (DictConfig): Configuration defined by the user
            mode (str, optional): Mode of the model, either "pretrain" or "finetune". Defaults to "pretrain".
            encoder (Optional[nn.Module], optional): Encoder of the model. Defaults to None.
            decoder (Optional[nn.Module], optional): Decoder of the model. Defaults to None.
            custom_optimizer (Optional[torch.optim.Optimizer], optional): Custom optimizer to use. Defaults to None.
            custom_optimizer_params (Dict, optional): Custom optimizer parameters to use. Defaults to {}.

        """
        super().__init__()
        assert "inferred_config" in kwargs, "inferred_config not found in initialization arguments"
        inferred_config = kwargs["inferred_config"]
        # Merging the config and inferred config
        config = safe_merge_config(config, inferred_config)

        self._setup_encoder_decoder(
            encoder,
            config.encoder_config,
            decoder,
            config.decoder_config,
            inferred_config,
        )
        self.custom_optimizer = custom_optimizer
        self.custom_optimizer_params = custom_optimizer_params
        # Updating config with custom parameters for experiment tracking
        if self.custom_optimizer is not None:
            # custom_optimizer is an optimizer class (it is instantiated in configure_optimizers)
            config.optimizer = self.custom_optimizer.__name__
        if len(self.custom_optimizer_params) > 0:
            config.optimizer_params = self.custom_optimizer_params
        self.mode = mode
        self._check_and_verify()
        self.save_hyperparameters(config)
        self._build_network()
        self._setup_loss()
        self._setup_metrics()

    def _setup_encoder_decoder(self, encoder, encoder_config, decoder, decoder_config, inferred_config):
        assert (encoder is not None) or (
            encoder_config is not None
        ), "Either encoder or encoder_config must be provided"
        # assert (decoder is not None) or (decoder_config is not None),
        # "Either decoder or decoder_config must be provided"
        if encoder is not None:
            self.encoder = encoder
            self._custom_encoder = True
        else:
            # Since encoder is not provided, we will use the encoder_config
            model_callable = getattr_nested(encoder_config._module_src, encoder_config._backbone_name)
            self.encoder = model_callable(
                safe_merge_config(encoder_config, inferred_config),
                # inferred_config=inferred_config,
            )
        if decoder is not None:
            self.decoder = decoder
            self._custom_decoder = True
        elif decoder_config is not None:
            # Since decoder is not provided, we will use the decoder_config
            model_callable = getattr_nested(decoder_config._module_src, decoder_config._backbone_name)
            self.decoder = model_callable(
                safe_merge_config(decoder_config, inferred_config),
                # inferred_config=inferred_config,
            )
        else:
            self.decoder = nn.Identity()

    def _check_and_verify(self):
        assert hasattr(self.encoder, "output_dim"), "An encoder backbone must have an output_dim attribute"
        if isinstance(self.decoder, nn.Identity):
            self.decoder.output_dim = self.encoder.output_dim
        assert hasattr(self.decoder, "output_dim"), "A decoder must have an output_dim attribute"

    @property
    def embedding_layer(self):
        raise NotImplementedError("`embedding_layer` property needs to be implemented by inheriting classes")

    @property
    def featurizer(self):
        raise NotImplementedError("`featurizer` property needs to be implemented by inheriting classes")

    @abstractmethod
    def _setup_loss(self):
        pass

    @abstractmethod
    def _setup_metrics(self):
        pass

    @abstractmethod
    def calculate_loss(self, output, tag):
        pass

    @abstractmethod
    def calculate_metrics(self, output, tag):
        pass

    @abstractmethod
    def forward(self, x: Dict):
        pass

    @abstractmethod
    def featurize(self, x: Dict):
        pass

    def predict(self, x: Dict, ret_model_output: bool = True):  # ret_model_output only for compatibility
        assert ret_model_output, "ret_model_output must be True in case of SSL predict"
        return self.featurize(x)

    def data_aware_initialization(self, datamodule):
        pass

    def training_step(self, batch, batch_idx):
        output = self.forward(batch)
        loss = self.calculate_loss(output, tag="train")
        self.calculate_metrics(output, tag="train")
        return loss

    def validation_step(self, batch, batch_idx):
        with torch.no_grad():
            output = self.forward(batch)
            self.calculate_loss(output, tag="valid")
            self.calculate_metrics(output, tag="valid")
        return output

    def test_step(self, batch, batch_idx):
        with torch.no_grad():
            output = self.forward(batch)
            self.calculate_loss(output, tag="test")
            self.calculate_metrics(output, tag="test")
        return output

    def on_validation_epoch_end(self) -> None:
        if hasattr(self.hparams, "log_logits") and self.hparams.log_logits:
            warnings.warn(
                "Logging Logits is disabled for SSL tasks. Set `log_logits` to False to turn off this warning"
            )
        super().on_validation_epoch_end()

    def configure_optimizers(self):
        if self.custom_optimizer is None:
            # Loading from the config
            try:
                self._optimizer = _create_optimizer(self.hparams.optimizer)
                opt = self._optimizer(
                    self.parameters(),
                    lr=self.hparams.learning_rate,
                    **self.hparams.optimizer_params,
                )
            except AttributeError as e:
                logger.error(f"{self.hparams.optimizer} is not a valid optimizer defined in the torch.optim module")
                raise e
        else:
            # Loading from custom fit arguments
            self._optimizer = self.custom_optimizer

            opt = self._optimizer(self.parameters(), lr=self.hparams.learning_rate, **self.custom_optimizer_params)
        if self.hparams.lr_scheduler is not None:
            try:
                self._lr_scheduler = getattr(torch.optim.lr_scheduler, self.hparams.lr_scheduler)
            except AttributeError as e:
                logger.error(
                    f"{self.hparams.lr_scheduler} is not a valid learning rate scheduler defined"
                    f" in the torch.optim.lr_scheduler module"
                )
                raise e
            # self._lr_scheduler holds a scheduler class, so check it with issubclass
            if issubclass(self._lr_scheduler, torch.optim.lr_scheduler._LRScheduler):
                return {
                    "optimizer": opt,
                    "lr_scheduler": self._lr_scheduler(opt, **self.hparams.lr_scheduler_params),
                }
            return {
                "optimizer": opt,
                "lr_scheduler": self._lr_scheduler(opt, **self.hparams.lr_scheduler_params),
                "monitor": self.hparams.lr_scheduler_monitor_metric,
            }
        else:
            return opt

    def reset_weights(self):
        reset_all_weights(self.featurizer)
        reset_all_weights(self.embedding_layer)
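configure_optimizers looks the scheduler up by name with getattr, so `self._lr_scheduler` holds a scheduler class, not an instance. Any type check on it therefore needs issubclass: `isinstance(SomeClass, Base)` asks whether the class object itself is an instance of Base, which is False for ordinary classes. The stand-in classes below are hypothetical (they are not the torch.optim.lr_scheduler classes) and only illustrate the two return shapes: a plain dict for regular schedulers and one with a "monitor" key for metric-driven ones.

```python
from typing import Any, Dict


class SchedulerBase:  # hypothetical stand-in for a scheduler base class
    def __init__(self, optimizer: Any):
        self.optimizer = optimizer


class StepLikeScheduler(SchedulerBase):  # hypothetical epoch-driven scheduler
    pass


class PlateauLikeScheduler:  # hypothetical metric-driven scheduler, not derived from SchedulerBase
    def __init__(self, optimizer: Any):
        self.optimizer = optimizer


def build_scheduler_config(scheduler_cls: type, optimizer: Any, monitor: str) -> Dict[str, Any]:
    # scheduler_cls is a class, so compare classes with issubclass
    if issubclass(scheduler_cls, SchedulerBase):
        return {"optimizer": optimizer, "lr_scheduler": scheduler_cls(optimizer)}
    return {"optimizer": optimizer, "lr_scheduler": scheduler_cls(optimizer), "monitor": monitor}


print(isinstance(StepLikeScheduler, SchedulerBase), issubclass(StepLikeScheduler, SchedulerBase))  # False True
print(sorted(build_scheduler_config(PlateauLikeScheduler, "opt", "valid_loss")))  # ['lr_scheduler', 'monitor', 'optimizer']
```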

__init__(config, mode='pretrain', encoder=None, decoder=None, custom_optimizer=None, custom_optimizer_params={}, **kwargs)

Base Model for all SSL Models.

Parameters:

Name Type Description Default
config DictConfig

Configuration defined by the user

required
mode str

Mode of the model, either "pretrain" or "finetune". Defaults to "pretrain".

'pretrain'
encoder Optional[Module]

Encoder of the model. Defaults to None.

None
decoder Optional[Module]

Decoder of the model. Defaults to None.

None
custom_optimizer Optional[Optimizer]

Custom optimizer class to use. It is instantiated with the model parameters and the learning rate. Defaults to None.

None
custom_optimizer_params Dict

Custom optimizer parameters to use. Defaults to {}.

{}
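In configure_optimizers, custom_optimizer is called with the model parameters and the learning rate, so in practice it is an optimizer class rather than an instance. When its name is recorded for experiment tracking, keep in mind that on a class `__class__.__name__` gives the metaclass name (usually "type"), whereas `__name__` gives the class's own name. The helper below is a hypothetical sketch, not part of PyTorch Tabular, that handles both classes and instances.

```python
import inspect
from typing import Any


class MyOptimizer:  # hypothetical stand-in for an optimizer class such as torch.optim.SGD
    def __init__(self, params, lr: float = 1e-3):
        self.params = list(params)
        self.lr = lr


def optimizer_name(optimizer: Any) -> str:
    # Classes carry their own name in __name__; instances need type(...) first
    return optimizer.__name__ if inspect.isclass(optimizer) else type(optimizer).__name__


print(MyOptimizer.__class__.__name__)  # type
print(optimizer_name(MyOptimizer), optimizer_name(MyOptimizer([], lr=0.1)))  # MyOptimizer MyOptimizer
```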