Skip to content

Other Features

Apart from training and using deep networks for tabular data, PyTorch Tabular also has some cool features that can help your classical ML/scikit-learn pipelines.

Categorical Embeddings

The CategoryEmbedding Model can also be used as a way to encode your categorical columns. Instead of using a one-hot encoder or a variant of target mean encoding, you can use a learned embedding to encode your categorical features. And all this can be done using a scikit-learn style Transformer.

Usage Example

# Build the transformer from an already trained model; the embeddings are
# extracted from the model inside the constructor.
transformer = CategoricalEmbeddingTransformer(tabular_model)
# Replace the categorical features defined in the trained tabular_model with
# their embedding columns. `fit` is a no-op, so this is the same as `transform`.
train_transformed = transformer.fit_transform(train)
# Apply the same extracted embeddings to a new dataframe.
val_transformed = transformer.transform(val)

pytorch_tabular.categorical_encoders.CategoricalEmbeddingTransformer

Bases: BaseEstimator, TransformerMixin

Source code in src/pytorch_tabular/categorical_encoders.py
class CategoricalEmbeddingTransformer(BaseEstimator, TransformerMixin):
    """Scikit-learn style transformer that replaces categorical columns with learned embeddings.

    The embedding vectors are read from the trained model once, in the constructor, and cached
    per column as ``{category: vector}``. ``fit`` is therefore a no-op.

    """

    # Key under which the embedding used for missing and unseen categories is stored
    # (taken from row 0 of each embedding weight matrix).
    NAN_CATEGORY = 0

    def __init__(self, tabular_model):
        """Initializes the Transformer and extracts the neural embeddings.

        Args:
            tabular_model (TabularModel): The trained TabularModel object

        Raises:
            ValueError: If the embedding layer cannot be extracted from the model.

        """
        self._categorical_encoder = tabular_model.datamodule.categorical_encoder
        self.cols = tabular_model.model.hparams.categorical_cols
        # dict {str: dict} column name --> {category: embedding vector (np.ndarray)}
        self._mapping = {}

        self._extract_embedding(tabular_model.model)

    def _extract_embedding(self, model):
        """Fills ``self._mapping`` with one embedding vector per category of every column.

        Args:
            model: The trained model; it must provide an ``extract_embedding`` method.

        Raises:
            ValueError: If ``extract_embedding`` raises ``ValueError`` or returns ``None``.

        """
        try:
            embedding_layer = model.extract_embedding()
        except ValueError as e:
            logger.error(
                f"Extracting embedding layer from model received this error: {e}."
                f" Some models do not support this feature."
            )
            embedding_layer = None
        if embedding_layer is None:
            raise ValueError("Passed model doesn't support this feature.")

        for i, col in enumerate(self.cols):
            weight = embedding_layer[i].weight
            encoder_mapping = self._categorical_encoder._mapping[col]
            self._mapping[col] = {}
            self._mapping[col][self.NAN_CATEGORY] = weight[0, :].detach().cpu().numpy().ravel()
            for key in encoder_mapping.index:
                # The encoder maps each category to the row of the embedding matrix to read.
                self._mapping[col][key] = weight[encoder_mapping.loc[key], :].detach().cpu().numpy().ravel()

    def fit(self, X, y=None):
        """Just for compatibility.

        Does not do anything

        """
        return self

    def transform(self, X: DataFrame, y=None) -> DataFrame:
        """Transforms the categorical columns specified to the trained neural embedding from the model.

        Every categorical column is replaced by ``{col}_embed_dim_{i}`` columns, one per embedding
        dimension. Missing values and categories without a stored embedding receive the
        ``NAN_CATEGORY`` embedding.

        Args:
            X (DataFrame): DataFrame of features, shape (n_samples, n_features). Must contain columns to encode.
            y ([type], optional): Only for compatibility. Not used. Defaults to None.

        Raises:
            ValueError: If no embeddings are stored in the transformer.
            AssertionError: If X lacks one of the categorical columns.

        Returns:
            DataFrame: The encoded dataframe

        """
        if not self._mapping:
            raise ValueError(
                "Passed model should either have an attribute `embeddng_layers`"
                " or a method `extract_embedding` defined for `transform`."
            )
        assert all(c in X.columns for c in self.cols), "X must contain every categorical column to encode"

        X_encoded = X.copy(deep=True)
        for col, mapping in track(
            self._mapping.items(),
            description="Encoding the data...",
            total=len(self._mapping),
        ):
            nan_embedding = mapping[self.NAN_CATEGORY]
            # Missing values share the embedding stored under NAN_CATEGORY.
            categories = X_encoded[col].fillna(self.NAN_CATEGORY)
            for dim in range(nan_embedding.shape[0]):
                encoded = categories.map({k: v[dim] for k, v in mapping.items()})
                # Unseen categories come out of `map` as NaN and also get the NAN embedding.
                # The filled Series is assigned back explicitly: an in-place `fillna` on
                # `X_encoded[name]` is a chained assignment and is a no-op under pandas
                # Copy-on-Write.
                X_encoded.loc[:, f"{col}_embed_dim_{dim}"] = encoded.fillna(nan_embedding[dim])
        X_encoded.drop(columns=self.cols, inplace=True)
        return X_encoded

    def fit_transform(self, X: DataFrame, y=None) -> DataFrame:
        """Encode given columns of X based on the learned embedding.

        Args:
            X (DataFrame): DataFrame of features, shape (n_samples, n_features). Must contain columns to encode.
            y ([type], optional): Only for compatibility. Not used. Defaults to None.

        Returns:
            DataFrame: The encoded dataframe

        """
        self.fit(X, y)
        return self.transform(X)

    def save_as_object_file(self, path):
        """Pickles the transformer's attributes (``self.__dict__``) to ``path``.

        Args:
            path (str): The path to save the file

        Raises:
            ValueError: If no embeddings are stored in the transformer.

        """
        if not self._mapping:
            raise ValueError("`fit` method must be called before `save_as_object_file`.")
        # Context manager so the handle is flushed and closed even if pickling fails.
        with open(path, "wb") as f:
            pickle.dump(self.__dict__, f)

    def load_from_object_file(self, path):
        """Restores the attributes previously written by ``save_as_object_file``.

        Args:
            path (str): The path to load the file from

        """
        # NOTE: unpickling can execute arbitrary code; only load files from a trusted source.
        with open(path, "rb") as f:
            state = pickle.load(f)
        for k, v in state.items():
            setattr(self, k, v)

__init__(tabular_model)

Initializes the Transformer and extracts the neural embeddings.

Parameters:

Name Type Description Default
tabular_model TabularModel

The trained TabularModel object

required
Source code in src/pytorch_tabular/categorical_encoders.py
def __init__(self, tabular_model):
    """Set up the transformer from a trained model and pull out its embeddings.

    Args:
        tabular_model (TabularModel): The trained TabularModel object

    """
    encoder = tabular_model.datamodule.categorical_encoder
    trained_model = tabular_model.model

    self._categorical_encoder = encoder
    self.cols = trained_model.hparams.categorical_cols
    # Per-column lookup {column name: {category: embedding vector}}; starts empty and is
    # populated by the `_extract_embedding` call below.
    self._mapping = {}

    self._extract_embedding(trained_model)

fit(X, y=None)

Just for compatibility.

Does not do anything

Source code in src/pytorch_tabular/categorical_encoders.py
def fit(self, X, y=None):
    """No-op kept for scikit-learn API compatibility.

    Both arguments are ignored; the transformer itself is returned.

    """
    return self

fit_transform(X, y=None)

Encode given columns of X based on the learned embedding.

Parameters:

Name Type Description Default
X DataFrame

DataFrame of features, shape (n_samples, n_features). Must contain columns to encode.

required
y [type]

Only for compatibility. Not used. Defaults to None.

None

Returns:

Name Type Description
DataFrame DataFrame

The encoded dataframe

Source code in src/pytorch_tabular/categorical_encoders.py
def fit_transform(self, X: DataFrame, y=None) -> DataFrame:
    """Call ``fit`` and then return ``transform(X)``.

    Args:
        X (DataFrame): DataFrame of features, shape (n_samples, n_features). Must contain columns to encode.
        y ([type], optional): Only for compatibility; forwarded to ``fit``. Defaults to None.

    Returns:
        DataFrame: The encoded dataframe

    """
    self.fit(X, y)
    encoded = self.transform(X)
    return encoded

transform(X, y=None)

Transforms the categorical columns specified to the trained neural embedding from the model.

Parameters:

Name Type Description Default
X DataFrame

DataFrame of features, shape (n_samples, n_features). Must contain columns to encode.

required
y [type]

Only for compatibility. Not used. Defaults to None.

None

Raises:

Type Description
ValueError

If the transformer holds no extracted embeddings (its internal mapping is empty).

Returns:

Name Type Description
DataFrame DataFrame

The encoded dataframe

Source code in src/pytorch_tabular/categorical_encoders.py
def transform(self, X: DataFrame, y=None) -> DataFrame:
    """Transforms the categorical columns specified to the trained neural embedding from the model.

    Every categorical column is replaced by ``{col}_embed_dim_{i}`` columns, one per embedding
    dimension. Missing values and categories without a stored embedding receive the
    ``NAN_CATEGORY`` embedding.

    Args:
        X (DataFrame): DataFrame of features, shape (n_samples, n_features). Must contain columns to encode.
        y ([type], optional): Only for compatibility. Not used. Defaults to None.

    Raises:
        ValueError: If no embeddings are stored in the transformer.
        AssertionError: If X lacks one of the categorical columns.

    Returns:
        DataFrame: The encoded dataframe

    """
    if not self._mapping:
        raise ValueError(
            "Passed model should either have an attribute `embeddng_layers`"
            " or a method `extract_embedding` defined for `transform`."
        )
    assert all(c in X.columns for c in self.cols), "X must contain every categorical column to encode"

    X_encoded = X.copy(deep=True)
    for col, mapping in track(
        self._mapping.items(),
        description="Encoding the data...",
        total=len(self._mapping),
    ):
        nan_embedding = mapping[self.NAN_CATEGORY]
        # Missing values share the embedding stored under NAN_CATEGORY.
        categories = X_encoded[col].fillna(self.NAN_CATEGORY)
        for dim in range(nan_embedding.shape[0]):
            encoded = categories.map({k: v[dim] for k, v in mapping.items()})
            # Unseen categories come out of `map` as NaN and also get the NAN embedding.
            # The filled Series is assigned back explicitly: an in-place `fillna` on
            # `X_encoded[name]` is a chained assignment and is a no-op under pandas
            # Copy-on-Write.
            X_encoded.loc[:, f"{col}_embed_dim_{dim}"] = encoded.fillna(nan_embedding[dim])
    X_encoded.drop(columns=self.cols, inplace=True)
    return X_encoded

Feature Extractor

What if you want to use the features learnt by the neural network in your ML model? PyTorch Tabular lets you do that as well, and with ease. Again, a scikit-learn style Transformer does the job for you.

Usage Example

# passing the trained model as an argument
dt = DeepFeatureExtractor(tabular_model)
# passing the train dataframe to extract the last layer features
# here `fit` is there only for compatibility and does not do anything
enc_df = dt.fit_transform(train)
# using the same feature extractor on a new dataframe
val_transformed = dt.transform(val)

pytorch_tabular.feature_extractor.DeepFeatureExtractor

Bases: BaseEstimator, TransformerMixin

Source code in src/pytorch_tabular/feature_extractor.py
class DeepFeatureExtractor(BaseEstimator, TransformerMixin):
    """Scikit-learn style transformer that turns rows into features produced by a trained model.

    The transformer holds no fitted state of its own; ``fit`` is a no-op and all the work
    happens in ``transform``.

    """

    def __init__(self, tabular_model, extract_keys=["backbone_features"], drop_original=True):
        """Initializes the Transformer and extracts the neural features.

        Args:
            tabular_model (TabularModel): The trained TabularModel object
            extract_keys (list, optional): The keys of the features to extract. Defaults to ["backbone_features"].
            drop_original (bool, optional): Whether to drop the original columns. Defaults to True.

        Raises:
            AssertionError: If the wrapped model is a NODE, TabNet or Mixture Density Network model.

        """
        assert not isinstance(
            tabular_model.model, (NODEModel, TabNetModel, MDNModel)
        ), "FeatureExtractor doesn't work for Mixture Density Networks, NODE Model, & Tabnet Model"
        self.tabular_model = tabular_model
        # NOTE: the default list object is shared between instances that rely on the default.
        # This class only iterates over it; callers should not mutate it in place.
        self.extract_keys = extract_keys
        self.drop_original = drop_original

    def fit(self, X, y=None):
        """Just for compatibility.

        Does not do anything

        """
        return self

    def transform(self, X: pd.DataFrame, y=None) -> pd.DataFrame:
        """Runs the trained model over X and returns the requested model outputs as columns.

        Each key in ``extract_keys`` that is present in the model output becomes one column
        named ``{key}`` (single feature) or several columns named ``{key}_{i}``.

        Args:
            X (pd.DataFrame): DataFrame of features, shape (n_samples, n_features). Must contain columns to encode.
            y ([type], optional): Only for compatibility. Not used. Defaults to None.

        Returns:
            pd.DataFrame: The encoded dataframe

        """

        X_encoded = X.copy(deep=True)
        orig_features = X_encoded.columns
        self.tabular_model.model.eval()
        inference_dataloader = self.tabular_model.datamodule.prepare_inference_dataloader(X_encoded)
        logits_predictions = defaultdict(list)
        for batch in track(inference_dataloader, description="Generating Features..."):
            for k, v in batch.items():
                if isinstance(v, list) and (len(v) == 0):
                    # Skipping empty list
                    continue
                batch[k] = v.to(self.tabular_model.model.device)
            if self.tabular_model.config.task == "ssl":
                # For the "ssl" task the whole return value of `predict` is stored as the backbone features.
                ret_value = {"backbone_features": self.tabular_model.model.predict(batch, ret_model_output=True)}
            else:
                _, ret_value = self.tabular_model.model.predict(batch, ret_model_output=True)
            for k in self.extract_keys:
                # Keys missing from the model output are skipped silently.
                if k in ret_value.keys():
                    logits_predictions[k].append(ret_value[k].detach().cpu())

        for k, v in logits_predictions.items():
            v = torch.cat(v, dim=0).numpy()
            if v.ndim == 1:
                v = v.reshape(-1, 1)
            for i in range(v.shape[-1]):
                if v.shape[-1] > 1:
                    X_encoded[f"{k}_{i}"] = v[:, i]
                else:
                    X_encoded[f"{k}"] = v[:, i]

        if self.drop_original:
            X_encoded.drop(columns=orig_features, inplace=True)
        return X_encoded

    def fit_transform(self, X: pd.DataFrame, y=None) -> pd.DataFrame:
        """Encode given columns of X based on the learned features.

        Args:
            X (pd.DataFrame): DataFrame of features, shape (n_samples, n_features). Must contain columns to encode.
            y ([type], optional): Only for compatibility. Not used. Defaults to None.

        Returns:
            pd.DataFrame: The encoded dataframe

        """
        self.fit(X, y)
        return self.transform(X)

    def save_as_object_file(self, path):
        """Saves the feature extractor as a pickle file.

        The extractor's attributes (``self.__dict__``) are pickled. There is no fitted state to
        check first: ``fit`` does nothing, and this class never defines a ``_mapping`` attribute,
        so the previous ``self._mapping`` guard always raised ``AttributeError``.

        Args:
            path (str): The path to save the file

        """
        # Context manager so the handle is flushed and closed even if pickling fails.
        with open(path, "wb") as f:
            pickle.dump(self.__dict__, f)

    def load_from_object_file(self, path):
        """Loads the feature extractor from a pickle file.

        Args:
            path (str): The path to load the file from

        """
        # NOTE: unpickling can execute arbitrary code; only load files from a trusted source.
        with open(path, "rb") as f:
            state = pickle.load(f)
        for k, v in state.items():
            setattr(self, k, v)

__init__(tabular_model, extract_keys=['backbone_features'], drop_original=True)

Initializes the Transformer and extracts the neural features.

Parameters:

Name Type Description Default
tabular_model TabularModel

The trained TabularModel object

required
extract_keys list

The keys of the features to extract. Defaults to ["backbone_features"].

['backbone_features']
drop_original bool

Whether to drop the original columns. Defaults to True.

True
Source code in src/pytorch_tabular/feature_extractor.py
def __init__(self, tabular_model, extract_keys=["backbone_features"], drop_original=True):
    """Store the trained model and the feature-extraction options.

    Args:
        tabular_model (TabularModel): The trained TabularModel object
        extract_keys (list, optional): The keys of the features to extract. Defaults to ["backbone_features"].
        drop_original (bool, optional): Whether to drop the original columns. Defaults to True.

    """
    unsupported_models = (NODEModel, TabNetModel, MDNModel)
    assert not isinstance(
        tabular_model.model, unsupported_models
    ), "FeatureExtractor doesn't work for Mixture Density Networks, NODE Model, & Tabnet Model"
    self.drop_original = drop_original
    self.extract_keys = extract_keys
    self.tabular_model = tabular_model

fit(X, y=None)

Just for compatibility.

Does not do anything

Source code in src/pytorch_tabular/feature_extractor.py
def fit(self, X, y=None):
    """No-op kept for scikit-learn API compatibility.

    Both arguments are ignored; the extractor itself is returned.

    """
    return self

fit_transform(X, y=None)

Encode given columns of X based on the learned features.

Parameters:

Name Type Description Default
X DataFrame

DataFrame of features, shape (n_samples, n_features). Must contain columns to encode.

required
y [type]

Only for compatibility. Not used. Defaults to None.

None

Returns:

Type Description
DataFrame

pd.DataFrame: The encoded dataframe

Source code in src/pytorch_tabular/feature_extractor.py
def fit_transform(self, X: pd.DataFrame, y=None) -> pd.DataFrame:
    """Call ``fit`` and then return ``transform(X)``.

    Args:
        X (pd.DataFrame): DataFrame of features, shape (n_samples, n_features). Must contain columns to encode.
        y ([type], optional): Only for compatibility; forwarded to ``fit``. Defaults to None.

    Returns:
        pd.DataFrame: The encoded dataframe

    """
    self.fit(X, y)
    features = self.transform(X)
    return features

load_from_object_file(path)

Loads the feature extractor from a pickle file.

Parameters:

Name Type Description Default
path str

The path to load the file from

required
Source code in src/pytorch_tabular/feature_extractor.py
def load_from_object_file(self, path):
    """Loads the feature extractor from a pickle file.

    Every key/value pair of the unpickled dictionary is set as an attribute on this object.

    Args:
        path (str): The path to load the file from

    """
    # NOTE: unpickling can execute arbitrary code; only load files from a trusted source.
    # The context manager closes the handle instead of leaving it to the garbage collector.
    with open(path, "rb") as f:
        state = pickle.load(f)
    for k, v in state.items():
        setattr(self, k, v)

save_as_object_file(path)

Saves the feature extractor as a pickle file.

Parameters:

Name Type Description Default
path str

The path to save the file

required
Source code in src/pytorch_tabular/feature_extractor.py
def save_as_object_file(self, path):
    """Saves the feature extractor as a pickle file.

    The extractor's attributes (``self.__dict__``) are pickled. There is no fitted state to
    check first: ``fit`` does nothing, and the feature extractor never defines a ``_mapping``
    attribute, so the previous ``self._mapping`` guard always raised ``AttributeError``.

    Args:
        path (str): The path to save the file

    """
    # Context manager so the handle is flushed and closed even if pickling fails.
    with open(path, "wb") as f:
        pickle.dump(self.__dict__, f)

transform(X, y=None)

Transforms the categorical columns specified to the trained neural features from the model.

Parameters:

Name Type Description Default
X DataFrame

DataFrame of features, shape (n_samples, n_features). Must contain columns to encode.

required
y [type]

Only for compatibility. Not used. Defaults to None.

None

Raises:

Type Description
ValueError

[description]

Returns:

Type Description
DataFrame

pd.DataFrame: The encoded dataframe

Source code in src/pytorch_tabular/feature_extractor.py
def transform(self, X: pd.DataFrame, y=None) -> pd.DataFrame:
    """Run the trained model over X and return the requested model outputs as columns.

    Each key in ``extract_keys`` that is present in the model output becomes one column
    named ``{key}`` (single feature) or several columns named ``{key}_{i}``.

    Args:
        X (pd.DataFrame): DataFrame of features, shape (n_samples, n_features). Must contain columns to encode.
        y ([type], optional): Only for compatibility. Not used. Defaults to None.

    Returns:
        pd.DataFrame: The encoded dataframe

    """
    X_encoded = X.copy(deep=True)
    orig_features = X_encoded.columns
    self.tabular_model.model.eval()
    loader = self.tabular_model.datamodule.prepare_inference_dataloader(X_encoded)

    collected = defaultdict(list)
    for batch in track(loader, description="Generating Features..."):
        # Move every batch entry to the model's device; empty lists are left untouched.
        for name, value in batch.items():
            if isinstance(value, list) and not value:
                continue
            batch[name] = value.to(self.tabular_model.model.device)

        if self.tabular_model.config.task == "ssl":
            # For the "ssl" task the whole return value of `predict` is stored as the backbone features.
            ret_value = {"backbone_features": self.tabular_model.model.predict(batch, ret_model_output=True)}
        else:
            _, ret_value = self.tabular_model.model.predict(batch, ret_model_output=True)

        for key in self.extract_keys:
            # Keys missing from the model output are skipped silently.
            if key not in ret_value.keys():
                continue
            collected[key].append(ret_value[key].detach().cpu())

    for key, chunks in collected.items():
        features = torch.cat(chunks, dim=0).numpy()
        if features.ndim == 1:
            features = features.reshape(-1, 1)
        width = features.shape[-1]
        for i in range(width):
            # A single feature keeps the bare key as its column name.
            column = f"{key}_{i}" if width > 1 else f"{key}"
            X_encoded[column] = features[:, i]

    if self.drop_original:
        X_encoded.drop(columns=orig_features, inplace=True)
    return X_encoded