Importing the Library¶
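The configs below assume the standard PyTorch Tabular imports. A minimal sketch (import paths as documented in the pytorch_tabular package):

```python
from pytorch_tabular import TabularModel
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
from pytorch_tabular.models import GANDALFConfig
from pytorch_tabular.models.common.heads import LinearHeadConfig
```

`num_col_names` and `cat_col_names` used below are lists of continuous and categorical column names you define from your own dataset.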
Define the Configs¶
data_config = DataConfig(
    target=[
        "target"
    ],  # target should always be a list. Multi-targets are only supported for regression; multi-task classification is not implemented
    continuous_cols=num_col_names,
    categorical_cols=cat_col_names,
)
trainer_config = TrainerConfig(
    auto_lr_find=True,  # Runs the LR Finder to automatically derive a learning rate
    batch_size=1024,
    max_epochs=100,
    accelerator="auto",  # can be "cpu", "gpu", "tpu", or "ipu"
)
optimizer_config = OptimizerConfig()
head_config = LinearHeadConfig(
    layers="",  # No additional layers in the head; just a mapping layer to output_dim
    dropout=0.1,
    initialization="kaiming",
).__dict__  # Convert to dict to pass to the model config (OmegaConf doesn't accept objects)
model_config = GANDALFConfig(
    task="classification",
    gflu_stages=3,  # Number of stages in the GFLU block
    gflu_dropout=0.0,  # Dropout in each GFLU stage
    gflu_feature_init_sparsity=0.1,  # Sparsity of the initial feature selection
    head="LinearHead",  # Linear Head
    head_config=head_config,  # Linear Head Config
    learning_rate=1e-3,
)
tabular_model = TabularModel(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config,
    verbose=False,
)
Training the Model¶
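A minimal training call might look like this, assuming `train` and `val` are pandas DataFrames (the split names are placeholders) containing the columns referenced in `data_config`:

```python
tabular_model.fit(train=train, validation=val)
```

If `validation` is omitted, PyTorch Tabular carves a validation split out of the training data automatically.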
Evaluating the Model¶
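A sketch of evaluation on a held-out set, assuming a DataFrame `test` with the same columns as the training data:

```python
# Computes the configured metrics on the test set
result = tabular_model.evaluate(test)
# Returns the test DataFrame with prediction columns appended
pred_df = tabular_model.predict(test)
```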
Native Global Feature Importance¶
Some models, like GANDALF, GATE, and FTTransformer, have native feature importance, similar to the feature importance you get with GBDTs.
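For models that support it, the importances can be pulled out as a DataFrame; a sketch, assuming the model above has been fit:

```python
# One row per feature, with its learned global importance
importance_df = tabular_model.feature_importance()
```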
Local Feature Attributions¶
We can also use techniques like SHAP to get local feature attributions. This is a powerful way to explain individual predictions of the model. We can use the explain method to get the local feature attributions for a given input. PyTorch Tabular supports these methods from captum for all models except TabNet, TabTransformer, and MDN:
- GradientShap: https://captum.ai/api/gradient_shap.html
- IntegratedGradients: https://captum.ai/api/integrated_gradients.html
- DeepLift: https://captum.ai/api/deep_lift.html
- DeepLiftShap: https://captum.ai/api/deep_lift_shap.html
- InputXGradient: https://captum.ai/api/input_x_gradient.html
- FeaturePermutation: https://captum.ai/api/feature_permutation.html
- FeatureAblation: https://captum.ai/api/feature_ablation.html
- KernelShap: https://captum.ai/api/kernel_shap.html
PyTorch Tabular supports explaining single instances as well as batches of instances, though larger datasets will take longer to explain. The exceptions are the FeaturePermutation and FeatureAblation methods, which are only meaningful for large batches of instances.
Most of these explainability methods require a baseline, which the attributions of the input are compared against. The baseline can be a scalar value, a tensor of the same shape as the input, or a special string like "b|100", which means 100 samples drawn from the training data. If no baseline is provided, the default baseline (zero) is used.
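For example, a sketch of calling `explain` with the training-data baseline string (`test` is a placeholder DataFrame; `method` and `baselines` are the documented argument names):

```python
exp = tabular_model.explain(test.head(1), method="GradientShap", baselines="b|100")
```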
Single Instance¶
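The plotting code below expects a DataFrame `exp` with `Features`, `GradientSHAP`, and `colors` columns. A self-contained sketch of how such a frame can be prepared from one instance's attributions (the feature names and values here are made up for illustration):

```python
import pandas as pd

# Hypothetical per-feature attributions for a single instance
attributions = {"age": 0.42, "fare": -0.17, "pclass": 0.08, "sibsp": -0.03}

exp = pd.DataFrame(
    {"Features": list(attributions.keys()), "GradientSHAP": list(attributions.values())}
)
# Sort so the diverging bars read from most negative to most positive
exp = exp.sort_values("GradientSHAP").reset_index(drop=True)
# Red for negative attributions, green for positive
exp["colors"] = ["red" if x < 0 else "green" for x in exp["GradientSHAP"]]
```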
import matplotlib.pyplot as plt

# Draw plot
plt.figure(figsize=(14, 10), dpi=80)
# Plot the horizontal lines
plt.hlines(
    y=exp.index,
    linewidth=5,
    xmin=0,
    xmax=exp.GradientSHAP,
    colors=exp.colors.values,
    alpha=0.5,
)
# Decorations
# Set the x-axis and y-axis labels
plt.gca().set(ylabel="Features", xlabel="GradientSHAP")
# Set the feature names on the y-axis
plt.yticks(exp.index, exp.Features, fontsize=12)
# Title of the chart
plt.title("GradientSHAP Local Explanation", fontdict={"size": 20})
# Optional grid layout
plt.grid(linestyle="--", alpha=0.5)
# Display the diverging bar chart
plt.show()
Multiple Instances¶
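When `explain` is run on a batch, each row of the result holds the attributions for one instance. A common way to condense them into a single `exp_agg` frame for plotting is the mean attribution per feature; a self-contained sketch with made-up attribution rows:

```python
import pandas as pd

# Hypothetical attributions: one row per instance, one column per feature
batch_exp = pd.DataFrame(
    {
        "age": [0.40, 0.35, 0.45],
        "fare": [-0.20, -0.10, -0.15],
        "pclass": [0.05, 0.10, 0.09],
    }
)

# Mean attribution per feature across the batch
exp_agg = batch_exp.mean().reset_index()
exp_agg.columns = ["Features", "GradientSHAP"]
exp_agg = exp_agg.sort_values("GradientSHAP").reset_index(drop=True)
# Red for negative attributions, green for positive
exp_agg["colors"] = ["red" if x < 0 else "green" for x in exp_agg["GradientSHAP"]]
```

Note that averaging signed attributions can cancel opposing effects; averaging absolute values (`batch_exp.abs().mean()`) is a common alternative when you only care about magnitude.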
# Draw plot
plt.figure(figsize=(14, 10), dpi=80)
# Plot the horizontal lines
plt.hlines(
    y=exp_agg.index,
    linewidth=5,
    xmin=0,
    xmax=exp_agg.GradientSHAP,
    colors=exp_agg.colors.values,
    alpha=0.5,
)
# Decorations
# Set the x-axis and y-axis labels
plt.gca().set(ylabel="Features", xlabel="GradientSHAP")
# Set the feature names on the y-axis
plt.yticks(exp_agg.index, exp_agg.Features, fontsize=12)
# Title of the chart
plt.title("GradientSHAP Global Explanation", fontdict={"size": 20})
# Optional grid layout
plt.grid(linestyle="--", alpha=0.5)
# Display the diverging bar chart
plt.show()