import os
import random
import numpy as np
import pandas as pd
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
import plotly.express as px
import plotly.graph_objs as go
from plotly.offline import init_notebook_mode
import requests
# %load_ext autoreload
# %autoreload 2
from IPython.display import Math
from pytorch_tabular import TabularModel
from pytorch_tabular.models import (
    CategoryEmbeddingModelConfig,
    GatedAdditiveTreeEnsembleConfig,
    MDNConfig
)
from pytorch_tabular.config import (
    DataConfig,
    OptimizerConfig,
    TrainerConfig,
    ExperimentConfig,
)
# from pytorch_tabular.categorical_encoders import CategoricalEmbeddingTransformer
from pytorch_tabular.models.common.heads import LinearHeadConfig, MixtureDensityHeadConfig
np.random.seed(42)

Utility Functions

def generate_linear_example(samples=int(1e5)):
    x_data = np.random.sample(samples)[:, np.newaxis].astype(np.float32)
    y_data = np.add(5*x_data, np.multiply((x_data)**2, np.random.standard_normal(x_data.shape)))

    x_train, x_valid, y_train, y_valid = train_test_split(x_data, y_data, test_size=0.5, random_state=42)
    x_test = np.linspace(0.,1.,int(1e3))[:, np.newaxis].astype(np.float32)
    df_train = pd.DataFrame({"col1": x_train.ravel(), "target": y_train.ravel()})
    df_valid = pd.DataFrame({"col1": x_valid.ravel(), "target": y_valid.ravel()})
    # test = sorted(df_valid.col1.round(3).unique())
    # df_test = pd.DataFrame({"col1": test})
    df_test = pd.DataFrame({"col1": x_test.ravel()})
    return (df_train, df_valid, df_test, ["target"])

def generate_non_linear_example(samples=int(1e5)):
    x_data = np.float32(np.random.uniform(-10, 10, (1, samples)))
    r_data = np.array([np.random.normal(scale=np.abs(i)) for i in x_data])
    y_data = np.float32(np.square(x_data)+r_data*2.0)

    x_data2 = np.float32(np.random.uniform(-10, 10, (1, samples)))
    r_data2 = np.array([np.random.normal(scale=np.abs(i)) for i in x_data2])
    y_data2 = np.float32(-np.square(x_data2)+r_data2*2.0)

    x_data = np.concatenate((x_data,x_data2),axis=1).T
    y_data = np.concatenate((y_data,y_data2),axis=1).T

    min_max_scaler = MinMaxScaler()
    y_data = min_max_scaler.fit_transform(y_data)

    x_train, x_valid, y_train, y_valid = train_test_split(x_data, y_data, test_size=0.5, random_state=42, shuffle=True)
    x_test = np.linspace(-10,10,int(1e3))[:, np.newaxis].astype(np.float32)
    df_train = pd.DataFrame({"col1": x_train.ravel(), "target": y_train.ravel()})
    df_valid = pd.DataFrame({"col1": x_valid.ravel(), "target": y_valid.ravel()})
    # test = sorted(df_valid.col1.round(3).unique())
    # df_test = pd.DataFrame({"col1": test})
    df_test = pd.DataFrame({"col1": x_test.ravel()})
    return (df_train, df_valid, df_test, ["target"])

def generate_step_linear_example(samples=int(1e5)):
    x_data = np.random.sample(samples)[:, np.newaxis].astype(np.float32)
    y_data = np.zeros(x_data.shape)
    mask = x_data<0.5
    y_data[mask] = np.add(5*x_data[mask], np.multiply((x_data[mask])**2, np.random.standard_normal(x_data[mask].shape)))
    y_data[~mask] = np.add(100*x_data[~mask]+x_data[~mask]**2 , np.multiply((x_data[~mask])**2, np.random.standard_normal(x_data[~mask].shape)))
    min_max_scaler = MinMaxScaler()
    y_data = min_max_scaler.fit_transform(y_data)

    x_train, x_valid, y_train, y_valid = train_test_split(x_data, y_data, test_size=0.5, random_state=42, shuffle=True)
    x_test = np.linspace(0.,1.,int(1e3))[:, np.newaxis].astype(np.float32)
    df_train = pd.DataFrame({"col1": x_train.ravel(), "target": y_train.ravel()})
    df_valid = pd.DataFrame({"col1": x_valid.ravel(), "target": y_valid.ravel()})
    # test = sorted(df_valid.col1.round(3).unique())
    # df_test = pd.DataFrame({"col1": test})
    df_test = pd.DataFrame({"col1": x_test.ravel()})
    return (df_train, df_valid, df_test, ["target"])

def generate_gaussian_mixture(samples=int(1e5)):
    x_data = np.random.sample(samples)[:, np.newaxis].astype(np.float32)
    pi = np.sin(x_data)+3*x_data*np.cos(x_data)
    pi = pi/pi.max()
    # g1 = np.random.sample(samples)*4*x_data.squeeze()
    # g2 = np.random.sample(samples)*15*x_data.squeeze()
    g1 = 2*x_data.squeeze() + 0.5*np.random.sample(samples)
    g2 = 8*x_data.squeeze() + 0.5*np.random.sample(samples)

    y_data = pi.round().squeeze()*g1 + (1-pi.round().squeeze())*g2
    y_data = y_data.reshape(-1,1)
    x_train, x_valid, y_train, y_valid = train_test_split(x_data, y_data, test_size=0.5, random_state=42)
    x_test = np.linspace(0.,1.,int(1e3))[:, np.newaxis].astype(np.float32)
    df_train = pd.DataFrame({"col1": x_train.ravel(), "target": y_train.ravel()})
    df_valid = pd.DataFrame({"col1": x_valid.ravel(), "target": y_valid.ravel()})
    # test = sorted(df_valid.col1.round(3).unique())
    # df_test = pd.DataFrame({"col1": test})
    df_test = pd.DataFrame({"col1": x_test.ravel()})
    return (df_train, df_valid, df_test, ["target"])

def latex_to_png(formula, file):
    # Render a LaTeX formula to a PNG file via the codecogs web service
    r = requests.get(r'http://latex.codecogs.com/png.latex?\dpi{300} \huge %s' % formula)
    with open(file, 'wb') as f:
        f.write(r.content)
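
For example, we could render the first target equation to a PNG (the formula string and output path here are only illustrative):

latex_to_png(r"y = 5x + (x^2 \epsilon)", "imgs/eq_linear.png")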

Linear Example

df_train, df_valid, df_test, target_col = generate_linear_example()

Plot

# display(Math(r"$y = 5x + (x^2 * \epsilon)$"+"\n"+r"$\epsilon \sim \mathcal{N}(0,1)$"))
fig = px.scatter(df_train, x="col1", y="target", title=r"$y = 5x + (x^2 * \epsilon)$"+"\n"+r"$\epsilon \sim \mathcal{N}(0,1)$")
fig.update_layout(
    title={
        'y':0.9,
        'x':0.5,
        'xanchor': 'center',
        'yanchor': 'top'})
fig.write_image("imgs/prob_reg_fig_1.png")

fig = px.histogram(df_train, x="target", title="Histogram")
fig.write_image("imgs/prob_reg_hist_1.png")

Training the MDN

Define the Configs

epochs = 15
batch_size = 128
steps_per_epoch = int((len(df_train)//batch_size)*0.9)
data_config = DataConfig(
    target=['target'],
    continuous_cols=['col1'],
    categorical_cols=[],
#         continuous_feature_transform="quantile_uniform"
)
trainer_config = TrainerConfig(
    auto_lr_find=True, # Runs the LRFinder to automatically derive a learning rate
    batch_size=batch_size,
    max_epochs=epochs,
    early_stopping="valid_loss",
    early_stopping_patience=5,
    checkpoints="valid_loss"
)

optimizer_config = OptimizerConfig(lr_scheduler="ReduceLROnPlateau", lr_scheduler_params={"patience":3})

mdn_head_config = MixtureDensityHeadConfig(num_gaussian=1).__dict__

backbone_config_class = "CategoryEmbeddingModelConfig"
backbone_config = dict(
    task="backbone",
    layers="128-64",  # Number of nodes in each layer
    activation="ReLU",  # Activation between each layers
    initialization="kaiming"
)

model_config = MDNConfig(
    task="regression",
    backbone_config_class=backbone_config_class,
    backbone_config_params=backbone_config,
    head_config=mdn_head_config,
    learning_rate=1e-3,
)
tabular_model = TabularModel(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config
)

Training the Model

tabular_model.fit(train=df_train, validation=df_valid)
Global seed set to 42
`head` is not a valid parameter for backbone task. Making `head=None`
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

`Trainer.fit` stopped: `max_steps=98` reached.
LR finder stopped early after 98 steps due to diverging loss.
Learning rate set to 0.0009120108393559097
Restoring states from the checkpoint path at /home/manujosephv/pytorch_tabular/docs/tutorials/.lr_find_fe815ad0-1801-45e6-9700-6b960f3736da.ckpt
Restored all states from the checkpoint file at /home/manujosephv/pytorch_tabular/docs/tutorials/.lr_find_fe815ad0-1801-45e6-9700-6b960f3736da.ckpt

  | Name             | Type                      | Params
---------------------------------------------------------------
0 | _backbone        | CategoryEmbeddingBackbone | 8.5 K 
1 | _embedding_layer | Embedding1dLayer          | 2     
2 | _head            | MixtureDensityHead        | 194   
3 | loss             | MSELoss                   | 0     
---------------------------------------------------------------
8.7 K     Trainable params
0         Non-trainable params
8.7 K     Total params
0.035     Total estimated model params size (MB)

`Trainer.fit` stopped: `max_epochs=15` reached.

<pytorch_lightning.trainer.trainer.Trainer at 0x7f00c9e47f10>

Predictions and Visualization

pred_df = tabular_model.predict(df_test, quantiles=[0.25,0.5,0.75], n_samples=100)
pred_df.head()
col1 target_prediction target_q25 target_q50 target_q75
0 0.000000 0.426740 0.149168 0.443037 0.678223
1 0.001001 0.403517 0.169063 0.382718 0.627589
2 0.002002 0.500595 0.189833 0.531668 0.795271
3 0.003003 0.433998 0.172423 0.431257 0.682992
4 0.004004 0.463966 0.220691 0.469996 0.699911
fig = go.Figure([
    go.Scatter(
        name='Mean',
        x=pred_df['col1'],
        y=pred_df['target_prediction'],
        mode='lines',
        line=dict(color='rgba(28,53,94,1)'),
    ),
    go.Scatter(
        name='Upper Bound',
        x=pred_df['col1'],
        y=pred_df['target_q75'],
        mode='lines',
        marker=dict(color='rgba(0,147,201,0.3)'),
        line=dict(width=0),
        showlegend=False
    ),
    go.Scatter(
        name='Lower Bound',
        x=pred_df['col1'],
        y=pred_df['target_q25'],
        marker=dict(color="#444"),
        line=dict(width=0),
        mode='lines',
        fillcolor='rgba(0,147,201,0.3)',
        fill='tonexty',
        showlegend=False
    )
])
fig.update_layout(
    yaxis_title='y',
    title='Mixture Density Network Prediction',
    hovermode="x"
)
# fig.show()
fig.write_image("imgs/prob_reg_mdn_1.png")

Non-Linear Example

df_train, df_valid, df_test, target_col = generate_non_linear_example()

Plot

fig = px.scatter(df_train, x="col1", y="target", title=r"$y = \pm x^2 + \epsilon$"+"\n"+r"$\epsilon\sim\mathcal{N}(0,|x|)$")
fig.update_layout(
    title={
        'y':0.9,
        'x':0.5,
        'xanchor': 'center',
        'yanchor': 'top'})
fig.write_image("imgs/prob_reg_fig_2.png")

fig = px.histogram(df_train, x="target", title="Histogram")
fig.write_image("imgs/prob_reg_hist_2.png")

Training a FeedForward

Define the Configs

epochs = 200
batch_size = 2048
steps_per_epoch = int((len(df_train)//batch_size)*0.9)
data_config = DataConfig(
    target=['target'],
    continuous_cols=['col1'],
    categorical_cols=[],
#         continuous_feature_transform="quantile_uniform"
)
trainer_config = TrainerConfig(
    auto_lr_find=True, # Runs the LRFinder to automatically derive a learning rate
    batch_size=batch_size,
    max_epochs=epochs,
    early_stopping="valid_loss",
    early_stopping_patience=5,
    checkpoints="valid_loss"
)
optimizer_config = OptimizerConfig(lr_scheduler="ReduceLROnPlateau", lr_scheduler_params={"patience":3})
model_config = CategoryEmbeddingModelConfig(
    task="regression",
    layers="16-8",  # Number of nodes in each layer
    activation="ReLU",  # Activation between each layers
    head="LinearHead",
    learning_rate=1e-3,
)
tabular_model = TabularModel(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config
)

Training the Model

tabular_model.fit(train=df_train, validation=df_valid)
Global seed set to 42
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer.fit` stopped: `max_steps=100` reached.
Learning rate set to 0.012022644346174132
Restoring states from the checkpoint path at /home/manujosephv/pytorch_tabular/docs/tutorials/.lr_find_50589bdd-1fbd-4a3c-810e-22c399ce6d75.ckpt
Restored all states from the checkpoint file at /home/manujosephv/pytorch_tabular/docs/tutorials/.lr_find_50589bdd-1fbd-4a3c-810e-22c399ce6d75.ckpt

  | Name             | Type                      | Params
---------------------------------------------------------------
0 | _backbone        | CategoryEmbeddingBackbone | 168   
1 | _embedding_layer | Embedding1dLayer          | 2     
2 | head             | LinearHead                | 9     
3 | loss             | MSELoss                   | 0     
---------------------------------------------------------------
179       Trainable params
0         Non-trainable params
179       Total params
0.001     Total estimated model params size (MB)

Detected KeyboardInterrupt, attempting graceful shutdown...

<pytorch_lightning.trainer.trainer.Trainer at 0x7f01afe8a590>

Predictions and Visualization

pred_df = tabular_model.predict(df_valid.sample(1000).sort_values("col1"))
pred_df.head()
col1 target target_prediction
52696 -9.971139 0.234386 0.505665
84854 -9.920759 0.198878 0.505669
58540 -9.911808 0.726360 0.505670
9330 -9.907802 0.805257 0.505671
54741 -9.893517 0.807486 0.505672
fig = go.Figure([
    go.Scatter(
        name='Prediction',
        x=pred_df['col1'],
        y=pred_df['target_prediction'],
        mode='lines',
        line=dict(color='rgba(28,53,94,1)'),
    ),
    go.Scatter(
        name='Actual',
        x=pred_df['col1'],
        y=pred_df['target'],
        mode='markers',
        line=dict(color='rgba(60,180,229,1)'),
    ),
])
fig.update_layout(
    yaxis_title='y',
    title='Category Embedding Prediction',
    hovermode="x"
)
# fig.show()
fig.write_image("imgs/prob_reg_non_mdn_2.png")

Training the MDN

Define the Configs

epochs = 200
batch_size = 2048
steps_per_epoch = int((len(df_train)//batch_size)*0.9)
data_config = DataConfig(
    target=['target'],
    continuous_cols=['col1'],
    categorical_cols=[],
#         continuous_feature_transform="quantile_uniform"
)
trainer_config = TrainerConfig(
    auto_lr_find=False, # LRFinder disabled; the learning_rate set in the model config is used as-is
    batch_size=batch_size,
    max_epochs=epochs,
    early_stopping="valid_loss",
    early_stopping_patience=5,
    checkpoints="valid_loss"
)
optimizer_config = OptimizerConfig(lr_scheduler="ReduceLROnPlateau", lr_scheduler_params={"patience":3})

mdn_head_config = MixtureDensityHeadConfig(num_gaussian=2, weight_regularization=2, lambda_mu=10, lambda_pi=5).__dict__

backbone_config_class = "CategoryEmbeddingModelConfig"
backbone_config = dict(
    task="backbone",
    layers="128-64",  # Number of nodes in each layer
    activation="ReLU",  # Activation between each layers
    head=None,
)

model_config = MDNConfig(
    task="regression",
    backbone_config_class=backbone_config_class,
    backbone_config_params=backbone_config,
    head_config=mdn_head_config,
    learning_rate=1e-3,
)

tabular_model = TabularModel(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config
)

Training the Model

tabular_model.fit(train=df_train, validation=df_valid)
Global seed set to 42
Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name             | Type                      | Params
---------------------------------------------------------------
0 | _backbone        | CategoryEmbeddingBackbone | 8.5 K 
1 | _embedding_layer | Embedding1dLayer          | 2     
2 | _head            | MixtureDensityHead        | 388   
3 | loss             | MSELoss                   | 0     
---------------------------------------------------------------
8.9 K     Trainable params
0         Non-trainable params
8.9 K     Total params
0.036     Total estimated model params size (MB)

Detected KeyboardInterrupt, attempting graceful shutdown...

<pytorch_lightning.trainer.trainer.Trainer at 0x7f7fc50ef8e0>

Predictions and Visualization

pred_df = tabular_model.predict(df_test, quantiles=[0.25,0.5,0.75], n_samples=100, ret_logits=True)
pred_df.head()
col1 target_prediction target_q25 target_q50 target_q75 pi_0 pi_1 sigma_0 sigma_1 mu_0 ... backbone_features_54 backbone_features_55 backbone_features_56 backbone_features_57 backbone_features_58 backbone_features_59 backbone_features_60 backbone_features_61 backbone_features_62 backbone_features_63
0 -10.000000 0.329563 0.227722 0.277360 0.316893 0.000037 0.000020 0.055194 0.054383 0.255731 ... 0.0 0.0 4.998535 1.226815 0.0 3.516335 0.0 1.714176 2.625350 4.461626
1 -9.979980 0.652835 0.648949 0.704733 0.751973 0.000037 0.000020 0.055109 0.054314 0.256374 ... 0.0 0.0 4.999834 1.226926 0.0 3.506327 0.0 1.722667 2.624132 4.448001
2 -9.959960 0.671895 0.659679 0.706704 0.753247 0.000037 0.000019 0.055025 0.054245 0.257018 ... 0.0 0.0 5.001133 1.227038 0.0 3.496319 0.0 1.731157 2.622915 4.434377
3 -9.939939 0.488492 0.255757 0.619482 0.704288 0.000037 0.000019 0.054941 0.054176 0.257662 ... 0.0 0.0 5.002433 1.227150 0.0 3.486310 0.0 1.739647 2.621696 4.420753
4 -9.919920 0.696155 0.677998 0.725608 0.762295 0.000037 0.000019 0.054858 0.054107 0.258305 ... 0.0 0.0 5.003733 1.227261 0.0 3.476303 0.0 1.748137 2.620478 4.407130

5 rows × 75 columns

df = df_valid.sample(10000)
fig = go.Figure([
    go.Scatter(
        name='Ground Truth',
        x=df['col1'],
        y=df['target'],
        mode='markers',
        line=dict(color='rgba(153, 115, 142, 0.2)'),
    ),
    go.Scatter(
        name='Component 1',
        x=pred_df['col1'],
        y=pred_df['mu_0'],
        mode='lines',
        line=dict(color='rgba(36, 37, 130, 1)'),
    ),
    go.Scatter(
        name='Component 2',
        x=pred_df['col1'],
        y=pred_df['mu_1'],
        mode='lines',
        line=dict(color='rgba(246, 76, 114, 1)'),
    ),
    go.Scatter(
        name='Upper Bound 1',
        x=pred_df['col1'],
        y=pred_df['mu_0']+pred_df['sigma_0'],
        mode='lines',
        marker=dict(color='rgba(47, 47, 162, 0.5)'),
#         line=dict(width=0),
        showlegend=False
    ),
    go.Scatter(
        name='Lower Bound 1',
        x=pred_df['col1'],
        y=pred_df['mu_0']-pred_df['sigma_0'],
        marker=dict(color="#444"),
        line=dict(width=0),
        mode='lines',
        fillcolor='rgba(47, 47, 162, 0.5)',
        fill='tonexty',
        showlegend=False
    ),
    go.Scatter(
        name='Upper Bound 2',
        x=pred_df['col1'],
        y=pred_df['mu_1']+pred_df['sigma_1'],
        mode='lines',
        marker=dict(color='rgba(250, 152, 174, 0.5)'),
#         line=dict(width=0),
        showlegend=False
    ),
    go.Scatter(
        name='Lower Bound 2',
        x=pred_df['col1'],
        y=pred_df['mu_1']-pred_df['sigma_1'],
        marker=dict(color="#444"),
        line=dict(width=0),
        mode='lines',
        fillcolor='rgba(250, 152, 174, 0.5)',
        fill='tonexty',
        showlegend=False
    ),
])
fig.update_layout(
    yaxis_title='y',
    title='Mixture Density Network Prediction',
    hovermode="x"
)
# fig.show()
fig.write_image("imgs/prob_reg_mdn_2.png")

Gaussian Mixture

df_train, df_valid, df_test, target_col = generate_gaussian_mixture()

Plot

from IPython.display import display, Math, Latex
eqn = r'$\pi = \frac{\sin(x) + 3x\cos(x)}{\max\left(\sin(x) + 3x\cos(x)\right)} \\ g_1 = 2x + 0.5\epsilon,\quad \epsilon \sim \mathcal{U}(0,1) \\ g_2 = 8x + 0.5\epsilon,\quad \epsilon \sim \mathcal{U}(0,1) \\ p = \mathrm{round}(\pi) \rightarrow \text{selects one of the two components based on the value of } \pi \\ y = p \times g_1 + (1-p) \times g_2$'
display(Math(eqn))
fig = px.scatter(df_train, x="col1", y="target")
fig.update_layout(
    title={
        'y':0.9,
        'x':0.5,
        'xanchor': 'center',
        'yanchor': 'top'})
fig.write_image("imgs/prob_reg_fig_3.png")
$\displaystyle \pi = \frac{\sin(x) + 3x\cos(x)}{\max\left(\sin(x) + 3x\cos(x)\right)} \\ g_1 = 2x + 0.5\epsilon,\quad \epsilon \sim \mathcal{U}(0,1) \\ g_2 = 8x + 0.5\epsilon,\quad \epsilon \sim \mathcal{U}(0,1) \\ p = \mathrm{round}(\pi) \rightarrow \text{selects one of the two components based on the value of } \pi \\ y = p \times g_1 + (1-p) \times g_2$

fig = px.histogram(df_train, x="target", title="Histogram")
fig.write_image("imgs/prob_reg_hist_3.png")

Training a FeedForward

Define the Configs

epochs = 200
batch_size = 2048
steps_per_epoch = int((len(df_train)//batch_size)*0.9)
data_config = DataConfig(
    target=['target'],
    continuous_cols=['col1'],
    categorical_cols=[],
#         continuous_feature_transform="quantile_uniform"
)
trainer_config = TrainerConfig(
    auto_lr_find=True, # Runs the LRFinder to automatically derive a learning rate
    batch_size=batch_size,
    max_epochs=epochs,
    early_stopping="valid_loss",
    early_stopping_patience=5,
    checkpoints="valid_loss"
)

optimizer_config = OptimizerConfig(lr_scheduler="ReduceLROnPlateau", lr_scheduler_params={"patience":3})

model_config = CategoryEmbeddingModelConfig(
    task="regression",
    layers="16-8",  # Number of nodes in each layer
    activation="ReLU",  # Activation between each layers
    head="LinearHead",
    learning_rate=1e-3,
)

tabular_model = TabularModel(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config
)

Training the Model

tabular_model.fit(train=df_train, validation=df_valid)
Global seed set to 42
Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
`Trainer.fit` stopped: `max_steps=100` reached.
Learning rate set to 0.01
Restoring states from the checkpoint path at /home/manujosephv/pytorch_tabular/docs/tutorials/.lr_find_37f44922-c147-4423-999a-61166be4c5a9.ckpt
Restored all states from the checkpoint file at /home/manujosephv/pytorch_tabular/docs/tutorials/.lr_find_37f44922-c147-4423-999a-61166be4c5a9.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name             | Type                      | Params
---------------------------------------------------------------
0 | _backbone        | CategoryEmbeddingBackbone | 168   
1 | _embedding_layer | Embedding1dLayer          | 2     
2 | head             | LinearHead                | 9     
3 | loss             | MSELoss                   | 0     
---------------------------------------------------------------
179       Trainable params
0         Non-trainable params
179       Total params
0.001     Total estimated model params size (MB)

`Trainer.fit` stopped: `max_epochs=200` reached.

<pytorch_lightning.trainer.trainer.Trainer at 0x7fd8160a98d0>

Predictions and Visualization

pred_df = tabular_model.predict(df_valid.sample(1000).sort_values("col1"))
pred_df.head()
col1 target target_prediction
46976 0.001606 0.310145 0.587298
13609 0.001873 0.502399 0.589231
672 0.004710 0.052382 0.609763
37622 0.005097 0.538130 0.612562
8230 0.007761 0.390704 0.631844
fig = go.Figure([
    go.Scatter(
        name='Prediction',
        x=pred_df['col1'],
        y=pred_df['target_prediction'],
        mode='lines',
        line=dict(color='rgba(28,53,94,1)'),
    ),
    go.Scatter(
        name='Actual',
        x=pred_df['col1'],
        y=pred_df['target'],
        mode='markers',
        line=dict(color='rgba(60,180,229,1)'),
    ),
])
fig.update_layout(
    yaxis_title='y',
    title='Category Embedding Network Prediction',
    hovermode="x"
)
fig.write_image("imgs/prob_reg_non_mdn_3.png")

Training the MDN

Define the Configs

epochs = 200
batch_size = 2048
steps_per_epoch = int((len(df_train)//batch_size)*0.9)
data_config = DataConfig(
    target=['target'],
    continuous_cols=['col1'],
    categorical_cols=[],
#         continuous_feature_transform="quantile_uniform"
)
trainer_config = TrainerConfig(
    auto_lr_find=True, # Runs the LRFinder to automatically derive a learning rate
    batch_size=batch_size,
    max_epochs=epochs,
    early_stopping_patience = 5,
    early_stopping="valid_loss",
    checkpoints="valid_loss",
    load_best=True
)

optimizer_config = OptimizerConfig(lr_scheduler="ReduceLROnPlateau", lr_scheduler_params={"patience":3})

mdn_head_config = MixtureDensityHeadConfig(num_gaussian=2, weight_regularization=2).__dict__

backbone_config_class = "CategoryEmbeddingModelConfig"
backbone_config = dict(
    task="backbone",
    layers="128-64",  # Number of nodes in each layer
    activation="ReLU",  # Activation between each layers
    head=None,
)

model_config = MDNConfig(
    task="regression",
    backbone_config_class=backbone_config_class,
    backbone_config_params=backbone_config,
    head_config=mdn_head_config,
    learning_rate=1e-3,
)

tabular_model = TabularModel(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config
)

Training the Model

tabular_model.fit(train=df_train, validation=df_valid)
Global seed set to 42
Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
`Trainer.fit` stopped: `max_steps=100` reached.
Learning rate set to 0.0007585775750291836
Restoring states from the checkpoint path at /home/manujosephv/pytorch_tabular/docs/tutorials/.lr_find_653e468e-c0c6-4fa7-922a-59c28c09d328.ckpt
Restored all states from the checkpoint file at /home/manujosephv/pytorch_tabular/docs/tutorials/.lr_find_653e468e-c0c6-4fa7-922a-59c28c09d328.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name             | Type                      | Params
---------------------------------------------------------------
0 | _backbone        | CategoryEmbeddingBackbone | 8.5 K 
1 | _embedding_layer | Embedding1dLayer          | 2     
2 | _head            | MixtureDensityHead        | 388   
3 | loss             | MSELoss                   | 0     
---------------------------------------------------------------
8.9 K     Trainable params
0         Non-trainable params
8.9 K     Total params
0.036     Total estimated model params size (MB)

Detected KeyboardInterrupt, attempting graceful shutdown...

<pytorch_lightning.trainer.trainer.Trainer at 0x7fd6ad1e9f90>

Predictions and Visualization

pred_df = tabular_model.predict(df_test, quantiles=[0.25,0.5,0.75], n_samples=100, ret_logits=True)
pred_df.head()
col1 target_prediction target_q25 target_q50 target_q75 pi_0 pi_1 sigma_0 sigma_1 mu_0 ... backbone_features_54 backbone_features_55 backbone_features_56 backbone_features_57 backbone_features_58 backbone_features_59 backbone_features_60 backbone_features_61 backbone_features_62 backbone_features_63
0 0.000000 0.643981 0.512064 0.660999 0.767942 15.547802 0.401295 0.199566 0.020728 0.651822 ... 1.046951 0.0 1.101677 1.591893 0.0 1.616410 0.0 0.0 2.758485 2.189235
1 0.001001 0.695453 0.584647 0.673015 0.819583 15.507115 0.406832 0.199811 0.020848 0.659085 ... 1.045177 0.0 1.099991 1.589368 0.0 1.613189 0.0 0.0 2.754719 2.183303
2 0.002002 0.681246 0.554822 0.703955 0.802527 15.466434 0.412371 0.200056 0.020968 0.666350 ... 1.043402 0.0 1.098305 1.586841 0.0 1.609968 0.0 0.0 2.750953 2.177371
3 0.003003 0.667499 0.537470 0.662630 0.798900 15.425751 0.417906 0.200301 0.021090 0.673611 ... 1.041627 0.0 1.096619 1.584315 0.0 1.606747 0.0 0.0 2.747187 2.171439
4 0.004004 0.658052 0.521760 0.633820 0.838818 15.385064 0.423444 0.200547 0.021212 0.680872 ... 1.039852 0.0 1.094933 1.581789 0.0 1.603526 0.0 0.0 2.743420 2.165506

5 rows × 75 columns

df = df_valid.sample(10000)
fig = go.Figure([
    go.Scatter(
        name='Ground Truth',
        x=df['col1'],
        y=df['target'],
        mode='markers',
        line=dict(color='rgba(153, 115, 142, 0.2)'),
    ),
    go.Scatter(
        name='Component 1',
        x=pred_df['col1'],
        y=pred_df['mu_0'],
        mode='lines',
        line=dict(color='rgba(90, 92, 237, 1)'),
    ),
    go.Scatter(
        name='Component 2',
        x=pred_df['col1'],
        y=pred_df['mu_1'],
        mode='lines',
        line=dict(color='rgba(246, 76, 114, 1)'),
    ),
    go.Scatter(
        name='Upper Bound 1',
        x=pred_df['col1'],
        y=pred_df['mu_0']+pred_df['sigma_0'],
        mode='lines',
        marker=dict(color='rgba(47, 47, 162, 0.5)'),
#         line=dict(width=0),
        showlegend=False
    ),
    go.Scatter(
        name='Lower Bound 1',
        x=pred_df['col1'],
        y=pred_df['mu_0']-pred_df['sigma_0'],
        marker=dict(color="#444"),
        line=dict(width=0),
        mode='lines',
        fillcolor='rgba(47, 47, 162, 0.5)',
        fill='tonexty',
        showlegend=False
    ),
    go.Scatter(
        name='Upper Bound 2',
        x=pred_df['col1'],
        y=pred_df['mu_1']+pred_df['sigma_1'],
        mode='lines',
        marker=dict(color='rgba(250, 152, 174, 0.5)'),
#         line=dict(width=0),
        showlegend=False
    ),
    go.Scatter(
        name='Lower Bound 2',
        x=pred_df['col1'],
        y=pred_df['mu_1']-pred_df['sigma_1'],
        marker=dict(color="#444"),
        line=dict(width=0),
        mode='lines',
        fillcolor='rgba(250, 152, 174, 0.5)',
        fill='tonexty',
        showlegend=False
    ),
])
fig.update_layout(
    yaxis_title='y',
#     yaxis_range=[0,1],
    title='Mixture Density Network Prediction',
    hovermode="x",
    yaxis_range=[df['target'].min()*0.85, df['target'].max()*1.15]
)
fig.write_image("imgs/prob_reg_mdn_3.png")

fig = go.Figure([
    go.Scatter(
        name='Ground Truth',
        x=df['col1'],
        y=df['target'],
        mode='markers',
        marker=dict(color='rgba(153, 115, 142, 0.2)'),
    ),
    go.Scatter(
        name='Component 1',
        x=pred_df['col1'],
        y=pred_df['mu_0'],
        mode='lines',
        line=dict(color='rgba(90, 92, 237, 1)'),
    ),
    go.Scatter(
        name='Mixing Coefficient 1',
        x=pred_df['col1'],
        y=pred_df['pi_0'],
        mode='lines',
        line=dict(color='rgba(255, 216, 117, 1)'),
    ),

    go.Scatter(
        name='Upper Bound 1',
        x=pred_df['col1'],
        y=pred_df['mu_0']+pred_df['sigma_0'],
        mode='lines',
        line=dict(color='rgba(47, 47, 162, 0.5)'),
        showlegend=False
    ),
    go.Scatter(
        name='Lower Bound 1',
        x=pred_df['col1'],
        y=pred_df['mu_0']-pred_df['sigma_0'],
        marker=dict(color="#444"),
        line=dict(width=0),
        mode='lines',
        fillcolor='rgba(47, 47, 162, 0.5)',
        fill='tonexty',
        showlegend=False
    ),

])
fig.update_layout(
    yaxis_title='y',
#     yaxis_range=[-0.2,1],
    title='Mixture Density Network Prediction',
    hovermode="x",
    yaxis_range=[df['target'].min()*0.85, df['target'].max()*1.15]
)
fig.write_image("imgs/prob_reg_mixing1_3.png")

fig = go.Figure([
    go.Scatter(
        name='Ground Truth',
        x=df['col1'],
        y=df['target'],
        mode='markers',
        marker=dict(color='rgba(153, 115, 142, 0.2)'),
    ),

    go.Scatter(
        name='Component 2',
        x=pred_df['col1'],
        y=pred_df['mu_1'],
        mode='lines',
        line=dict(color='rgba(246, 76, 114, 1)'),
    ),

    go.Scatter(
        name='Mixing Coefficient 2',
        x=pred_df['col1'],
        y=pred_df['pi_1'],
        mode='lines',
        line=dict(color='rgba(255, 216, 117, 1)'),
    ),

    go.Scatter(
        name='Upper Bound 2',
        x=pred_df['col1'],
        y=pred_df['mu_1']+pred_df['sigma_1'],
        mode='lines',
        line=dict(color='rgba(250, 152, 174, 0.5)'),
        showlegend=False
    ),
    go.Scatter(
        name='Lower Bound 2',
        x=pred_df['col1'],
        y=pred_df['mu_1']-pred_df['sigma_1'],
        marker=dict(color="#444"),
        line=dict(width=0),
        mode='lines',
        fillcolor='rgba(250, 152, 174, 0.5)',
        fill='tonexty',
        showlegend=False
    ),
])
fig.update_layout(
    yaxis_title='y',
#     yaxis_range=[-0.2,1],
    title='Mixture Density Network Prediction',
    hovermode="x",
    yaxis_range=[df['target'].min()*0.85, df['target'].max()*1.15]
)
fig.write_image("imgs/prob_reg_mixing2_3.png")

from scipy.special import softmax
# with ret_logits=True the pi_* columns are raw mixing logits;
# softmax turns them into mixing probabilities that sum to 1
pred_df[['pi_0','pi_1']] = softmax(pred_df[['pi_0','pi_1']].values, axis=-1)
fig = px.line(pred_df, x='col1', y=['pi_0','pi_1'])
fig.write_image("imgs/prob_reg_mixing12_3.png")

California Housing Dataset

from sklearn.datasets import fetch_california_housing
target_col = "target"
data = fetch_california_housing(return_X_y=False)
X = pd.DataFrame(data['data'], columns=data['feature_names'])
cont_cols = X.columns.tolist()
cat_cols = []
y = data['target']
X[target_col] = y
df_train, df_test = train_test_split(X, test_size=0.2, random_state=42)
df_train, df_valid = train_test_split(df_train, test_size=0.2, random_state=42)

Plot

The target (median house value, in units of $100k) is right-skewed, with a spike at the 5.0 cap; this truncated, irregular shape is what makes a mixture output attractive here.

fig = px.histogram(df_train, x="target", title="Histogram")
fig.write_image("imgs/prob_reg_hist_4.png")

Training the MDN

Define the Configs

Let's use a nifty utility function in the package to figure out the centers of the possible Gaussian components. Internally it runs KMeans on the target and returns the cluster centroids, which we can then set as the bias initialization for the component means.

from pytorch_tabular.utils import get_gaussian_centers

mu_init = get_gaussian_centers(df_train[target_col], n_components=4)
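Conceptually, get_gaussian_centers clusters the target values and hands back the centroids. A rough sketch of the idea (gaussian_centers_sketch is a hypothetical stand-in, not the library's exact implementation):

from sklearn.cluster import KMeans

def gaussian_centers_sketch(y, n_components=4, seed=42):
    # cluster the 1-D target and use the centroids as initial component means
    km = KMeans(n_clusters=n_components, n_init=10, random_state=seed)
    km.fit(np.asarray(y).reshape(-1, 1))
    return sorted(km.cluster_centers_.ravel().tolist())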

epochs = 1000
batch_size = 2048
steps_per_epoch = int((len(df_train)//batch_size)*0.9)
data_config = DataConfig(
    target=['target'],
    continuous_cols=cont_cols,
    categorical_cols=cat_cols,
#         continuous_feature_transform="quantile_uniform"
)
trainer_config = TrainerConfig(
    auto_lr_find=True, # Runs the LRFinder to automatically derive a learning rate
    batch_size=batch_size,
    max_epochs=epochs,
    early_stopping="valid_loss",
    early_stopping_patience=5,
    checkpoints="valid_loss",
    load_best=True
)

optimizer_config = OptimizerConfig(lr_scheduler="ReduceLROnPlateau", lr_scheduler_params={"patience":3})

mdn_head_config = MixtureDensityHeadConfig(
    num_gaussian=4,
    weight_regularization=2,  # L2 regularization on the head weights
    mu_bias_init=mu_init,
    # lambda_pi=10,
    # lambda_sigma=1,
).__dict__

backbone_config_class = "CategoryEmbeddingModelConfig"
backbone_config = dict(
    task="backbone",
    layers="200-100",  # Number of nodes in each layer
    activation="ReLU",  # Activation between each layers
    head=None,
)

model_config = MDNConfig(
    task="regression",
    backbone_config_class=backbone_config_class,
    backbone_config_params=backbone_config,
    head_config=mdn_head_config,
    learning_rate=1e-3,
)


tabular_model = TabularModel(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config
)

Training the Model

tabular_model.fit(train=df_train, validation=df_valid)
Global seed set to 42
Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:604: UserWarning:

Checkpoint directory /home/manujosephv/pytorch_tabular/docs/tutorials/checkpoints exists and is not empty.

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning:

The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.

/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning:

The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 20 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.


`Trainer.fit` stopped: `max_steps=100` reached.
Learning rate set to 0.0009120108393559097
Restoring states from the checkpoint path at /home/manujosephv/pytorch_tabular/docs/tutorials/.lr_find_c8d4ff90-1aad-4dcb-986b-2548e08389cc.ckpt
Restored all states from the checkpoint file at /home/manujosephv/pytorch_tabular/docs/tutorials/.lr_find_c8d4ff90-1aad-4dcb-986b-2548e08389cc.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name             | Type                      | Params
---------------------------------------------------------------
0 | _backbone        | CategoryEmbeddingBackbone | 21.9 K
1 | _embedding_layer | Embedding1dLayer          | 16    
2 | _head            | MixtureDensityHead        | 1.2 K 
3 | loss             | MSELoss                   | 0     
---------------------------------------------------------------
23.1 K    Trainable params
0         Non-trainable params
23.1 K    Total params
0.092     Total estimated model params size (MB)

Detected KeyboardInterrupt, attempting graceful shutdown...

<pytorch_lightning.trainer.trainer.Trainer at 0x7fd6a27173d0>

Predictions and Visualization

# n_samples draws from the predicted mixture are used to estimate the quantiles;
# ret_logits=True also returns the raw pi/sigma/mu parameters and backbone features
pred_df = tabular_model.predict(df_test, quantiles=[0.25,0.5,0.75], n_samples=100, ret_logits=True)
pred_df.head()
/home/manujosephv/pytorch_tabular/.env/tabular_env/lib/python3.10/site-packages/pytorch_tabular/tabular_model.py:1126: PerformanceWarning:

DataFrame is highly fragmented.  This is usually the result of calling `frame.insert` many times, which has poor performance.  Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`

(warning repeated for each inserted prediction column)

       MedInc  HouseAge  AveRooms  AveBedrms  Population  AveOccup  Latitude  Longitude   target  target_prediction  ...
20046  1.6812      25.0  4.192201   1.022284      1392.0  3.877437     36.06    -119.01  0.47700           0.921853  ...
3024   2.5313      30.0  5.039384   1.193493      1565.0  2.679795     35.14    -119.46  0.45800           1.190664  ...
15663  3.4801      52.0  3.977155   1.185877      1310.0  1.360332     37.80    -122.44  5.00001           3.117064  ...
20484  5.7376      17.0  6.163636   1.020202      1705.0  3.444444     34.28    -118.72  2.18600           2.724108  ...
9814   3.7250      34.0  5.492991   1.028037      1063.0  2.483645     36.62    -121.93  2.78000           2.599437  ...

5 rows × 125 columns (quantile, pi/sigma/mu, and backbone_features_0–99 columns truncated)
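Since the test split keeps the true target, we can sanity-check the point predictions. A minimal sketch using the already-imported mean_squared_error (illustrative only; the run above was interrupted before convergence, so the exact number will vary):

# RMSE of the MDN point predictions on the held-out split
rmse = mean_squared_error(pred_df['target'], pred_df['target_prediction'], squared=False)
print(f"Holdout RMSE: {rmse:.3f}")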

import scipy.stats as ss

def plot_normal(x_range, mu=0, sigma=1, cdf=False):
    '''
    Returns (x, y) for the normal pdf over the given x range
    (or the cdf if cdf=True). If mu and sigma are not provided,
    the standard normal is used.
    '''
    x = x_range
    if cdf:
        y = ss.norm.cdf(x, mu, sigma)
    else:
        y = ss.norm.pdf(x, mu, sigma)
    return x, y
import torch
from torch import nn
from torch.distributions import Categorical
def get_pdf(idx):
    # raw mixture parameters for one test row (the pi_* values are logits)
    row = pred_df.iloc[idx]
    pi = torch.from_numpy(row[['pi_0','pi_1','pi_2','pi_3']].values).unsqueeze(0)
    mu = torch.from_numpy(row[['mu_0','mu_1','mu_2','mu_3']].values).unsqueeze(0)
    sigma = torch.from_numpy(row[['sigma_0','sigma_1','sigma_2','sigma_3']].values).unsqueeze(0)
    # sample one component index from the (Gumbel-)softmaxed mixing logits
    softmax_pi = nn.functional.gumbel_softmax(pi, tau=1, dim=-1)
    categorical = Categorical(softmax_pi)
    pis = categorical.sample().unsqueeze(1)
    sigma = sigma.gather(1, pis).item()
    mu = mu.gather(1, pis).item()
    x = np.linspace(row['target_prediction'].item()*0.1, row['target_prediction'].item()*1.9, 5000)
    return plot_normal(x, mu=mu, sigma=sigma)
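Note that get_pdf draws a single component per row (via a Gumbel-softmax sample of the mixing logits), so repeated calls can return different curves for the same house. The full predictive density is the softmax(pi)-weighted sum of all four component pdfs; a deterministic sketch (get_mixture_pdf is a hypothetical helper):

from scipy.special import softmax

def get_mixture_pdf(idx):
    # full mixture density: pi-weighted sum of the component normal pdfs
    row = pred_df.iloc[idx]
    pi = softmax(row[['pi_0','pi_1','pi_2','pi_3']].values.astype(float))
    mus = row[['mu_0','mu_1','mu_2','mu_3']].values.astype(float)
    sigmas = row[['sigma_0','sigma_1','sigma_2','sigma_3']].values.astype(float)
    x = np.linspace(row['target_prediction']*0.1, row['target_prediction']*1.9, 5000)
    y = sum(w * ss.norm.pdf(x, m, s) for w, m, s in zip(pi, mus, sigmas))
    return x, y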

idxs = [2, 23, 564, 365]
traces = []
for idx in idxs:
    x,y = get_pdf(idx)
    trace = go.Scatter(
            name=f'House_{idx}',
            x=x,
            y=y,
            mode='lines',
    #         line=dict(color='rgba(246, 76, 114, 1)'),
        )
    traces.append(trace)

fig = go.Figure(traces)
fig.update_layout(
    yaxis_title='P(MedHouseVal)',
    xaxis_title='MedHouseVal',
#     yaxis_range=[-0.2,1],
    title='PDFs of different Houses',
    hovermode="x"
)
fig.write_image("imgs/prob_reg_pdfs_4.png")