Explainability/Interpretability
The explainability features in PyTorch Tabular allow users to interpret and understand the predictions made by a tabular deep learning model. They provide insight into the model's decision-making process and help identify the most influential features. Some of these features are built into the models themselves, while many others are based on the Captum library.
Native Feature Importance¶
One of the most-loved features of GBDT models is feature importance, which helps us understand which features matter most to the model. PyTorch Tabular provides a similar feature for some of its models - GANDALF, GATE, and FTTransformer - which natively support the extraction of feature importance.
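For example, the importances can be pulled from a trained model directly. A minimal sketch, assuming tabular_model is a trained model of one of these architectures (the exact columns of the returned DataFrame may vary by version):
# Native feature importance from a model that supports it
# (assumes tabular_model is a trained GANDALF, GATE, or FTTransformer model)
importance_df = tabular_model.feature_importance()
print(importance_df)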
Local Feature Attributions/Explanations¶
Local feature attributions/explanations help us understand the contribution of each feature towards the prediction for a particular sample. PyTorch Tabular provides this feature for all models except TabTransformer, TabNet, and Mixture Density Networks. It is built on the Captum library, which offers many algorithms for computing feature attributions; PyTorch Tabular provides a thin wrapper to make them easy to use. The following algorithms are supported:
- GradientShap: https://captum.ai/api/gradient_shap.html
- IntegratedGradients: https://captum.ai/api/integrated_gradients.html
- DeepLift: https://captum.ai/api/deep_lift.html
- DeepLiftShap: https://captum.ai/api/deep_lift_shap.html
- InputXGradient: https://captum.ai/api/input_x_gradient.html
- FeaturePermutation: https://captum.ai/api/feature_permutation.html
- FeatureAblation: https://captum.ai/api/feature_ablation.html
- KernelShap: https://captum.ai/api/kernel_shap.html
PyTorch Tabular supports explaining single instances as well as batches of instances, though larger batches take longer to explain. The exceptions are the FeaturePermutation and FeatureAblation methods, which are only meaningful for large batches of instances.
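For instance, a perturbation-based method would typically be applied to a larger batch. A sketch, reusing the trained tabular_model and test dataframe from the examples below:
# Perturbation-based methods like FeatureAblation are only meaningful
# when averaged over many instances, so explain a larger batch
tabular_model.explain(test.head(100), method="FeatureAblation")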
Most of these explainability methods require a baseline: a reference input against which the attributions for the actual input are computed. The baseline can be a scalar value, a tensor of the same shape as the input, or a special string like "b|10000", which means 10000 samples from the training data. If no baseline is provided, the default baseline (zero) is used.
# tabular_model is the trained model of a supported model
# Explain a single instance using the GradientShap method and baseline as 10000 samples from the training data
tabular_model.explain(test.head(1), method="GradientShap", baselines="b|10000")
# Explain a batch of instances using the IntegratedGradients method and baseline as 0
tabular_model.explain(test.head(10), method="IntegratedGradients", baselines=0)
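Because the result is a plain pandas DataFrame of shape (num_samples, num_features), local attributions can also be aggregated into a rough global ranking. A sketch:
# Aggregate local attributions into a rough global importance ranking
exp_df = tabular_model.explain(test.head(100), method="GradientShap", baselines="b|10000")
global_importance = exp_df.abs().mean().sort_values(ascending=False)
print(global_importance.head(10))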
Check out the Captum documentation for more details on the algorithms, and the Explainability Tutorial for example usage.
API Reference¶
pytorch_tabular.TabularModel.explain(data, method='GradientShap', method_args={}, baselines=None, **kwargs)¶
Returns the feature attributions/explanations of the model as a pandas DataFrame. The shape of the returned dataframe is (num_samples, num_features).
Parameters:

Name | Type | Description | Default |
---|---|---|---|
data | DataFrame | The dataframe to be explained | required |
method | str | The method to be used for explaining the model. It should be one of the supported methods listed above. Defaults to "GradientShap". For more details, refer to https://captum.ai/api/attribution.html | 'GradientShap' |
method_args | Optional[Dict] | The arguments to be passed to the initialization of the Captum method. | {} |
baselines | Union[float, tensor, str] | The baselines to be used for the explanation. If a scalar is provided, it is used as the baseline for all features. If a tensor is provided, it is used as the baseline for all features. If a string like "b\|10000" is provided, 10000 samples from the training data are used as the baseline. | None |
**kwargs | | Additional keyword arguments to be passed to the Captum method | {} |
Returns:

Name | Type | Description |
---|---|---|
DataFrame | DataFrame | The dataframe with the feature importance |
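As a usage note, method_args is forwarded to the constructor of the chosen Captum attribution class. A sketch (multiply_by_inputs is a constructor argument of Captum's IntegratedGradients):
# method_args is passed to the Captum attribution class constructor
tabular_model.explain(
    test.head(10),
    method="IntegratedGradients",
    method_args={"multiply_by_inputs": False},
)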
Source code in src/pytorch_tabular/tabular_model.py