Explainability/Interpretability
The explainability features in PyTorch Tabular allow users to interpret and understand the predictions made by a tabular deep learning model. They provide insight into the model's decision-making process and help identify the most influential features. Some of these features are built into the models themselves; many others are based on the Captum library.
Native Feature Importance
One of the much-loved features of GBDT models is feature importance: it helps us understand which features matter most to the model. PyTorch Tabular provides a similar feature for some of its models - GANDALF, GATE, and FTTransformer - which natively support the extraction of feature importance.
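To make the idea concrete, here is a toy illustration of what a feature-importance score captures. This is not PyTorch Tabular's actual mechanism (which is model-specific); it is a generic sketch: for a linear model on standardized inputs, the absolute value of each learned weight can serve as a crude per-feature importance, normalized so the scores sum to 1. The feature names and weights below are made up.

```python
# Toy illustration of feature importance, NOT PyTorch Tabular's internals:
# for a linear model on standardized inputs, the absolute value of each
# learned weight is a crude per-feature importance score.
weights = {"age": 0.8, "income": -1.5, "tenure": 0.1}

def importance(weights):
    """Normalize absolute weights so the importances sum to 1."""
    total = sum(abs(w) for w in weights.values())
    return {name: abs(w) / total for name, w in weights.items()}

imp = importance(weights)
# Rank features from most to least important
ranking = sorted(imp, key=imp.get, reverse=True)
print(ranking)  # income first, then age, then tenure
```

The native importances returned by the supported models are computed differently per architecture, but they answer the same question: which inputs drive the predictions most.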
Local Feature Attributions/Explanations
Local feature attributions/explanations help us understand the contribution of each feature to the prediction for a particular sample. PyTorch Tabular provides this feature for all models except TabTransformer, TabNet, and Mixed Density Networks. It is built on the Captum library, which implements many algorithms for computing feature attributions; PyTorch Tabular provides a wrapper around the library to make them easy to use. The following algorithms are supported:
- GradientShap: https://captum.ai/api/gradient_shap.html
- IntegratedGradients: https://captum.ai/api/integrated_gradients.html
- DeepLift: https://captum.ai/api/deep_lift.html
- DeepLiftShap: https://captum.ai/api/deep_lift_shap.html
- InputXGradient: https://captum.ai/api/input_x_gradient.html
- FeaturePermutation: https://captum.ai/api/feature_permutation.html
- FeatureAblation: https://captum.ai/api/feature_ablation.html
- KernelShap: https://captum.ai/api/kernel_shap.html
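As a rough intuition for how ablation-style methods in this list work, here is a simplified sketch (not Captum's implementation): each feature is replaced with its baseline value in turn, and the resulting change in the model's output is taken as that feature's attribution. The stand-in model below is a fixed linear function chosen only for illustration.

```python
# Simplified sketch of feature ablation (not Captum's implementation):
# replace one feature at a time with its baseline value and record
# how much the model's output changes.

def model(x):
    # Stand-in model: a fixed linear function of three features.
    return 2.0 * x[0] - 1.0 * x[1] + 0.5 * x[2]

def ablation_attributions(model, x, baseline):
    attributions = []
    for i in range(len(x)):
        ablated = list(x)
        ablated[i] = baseline[i]  # knock out feature i
        attributions.append(model(x) - model(ablated))
    return attributions

x = [1.0, 2.0, 4.0]
baseline = [0.0, 0.0, 0.0]
print(ablation_attributions(model, x, baseline))  # [2.0, -2.0, 2.0]
```

Gradient-based methods such as IntegratedGradients follow a different strategy (accumulating gradients along a path from the baseline to the input), but produce attributions of the same shape.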
PyTorch Tabular supports explaining single instances as well as batches of instances, though larger datasets will take longer to explain. The exceptions are the FeaturePermutation and FeatureAblation methods, which are only meaningful for large batches of instances.
Most of these explainability methods require a baseline, against which the attributions of the input are compared. The baseline can be a scalar value, a tensor of the same shape as the input, or a special string like "b|10000", which means 10000 samples from the training data. If no baseline is provided, the default baseline (zero) is used.
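The "b|10000" string convention can be pictured with a small sketch. The helper below is hypothetical (PyTorch Tabular parses the string internally); it shows how a scalar passes through unchanged while a "b|N" string resolves to N rows drawn from the training data.

```python
import random

def resolve_baselines(baselines, train_rows):
    """Hypothetical helper sketching how the `baselines` argument resolves
    (PyTorch Tabular handles this internally).

    - scalar        -> used as-is for every feature
    - "b|N" string  -> N rows sampled from the training data
    """
    if isinstance(baselines, str) and baselines.startswith("b|"):
        n = int(baselines.split("|", 1)[1])
        # Sample with replacement in case n exceeds the training-set size.
        return [random.choice(train_rows) for _ in range(n)]
    return baselines

train_rows = [[0.1, 1.0], [0.3, 2.0], [0.5, 3.0]]
print(len(resolve_baselines("b|4", train_rows)))  # 4
print(resolve_baselines(0, train_rows))           # 0
```

Sampling many baselines and averaging attributions over them (as GradientShap does) makes the explanation less sensitive to any single reference point.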
```python
# tabular_model is the trained model of a supported model

# Explain a single instance using the GradientShap method, with 10000 samples
# from the training data as the baseline
tabular_model.explain(test.head(1), method="GradientShap", baselines="b|10000")

# Explain a batch of instances using the IntegratedGradients method, with 0 as the baseline
tabular_model.explain(test.head(10), method="IntegratedGradients", baselines=0)
```
Check out the Captum documentation for more details on the algorithms, and the Explainability Tutorial for example usage.
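Since `explain` returns a DataFrame of shape (num_samples, num_features), a common follow-up is to aggregate the per-sample attributions into a global ranking, for example by taking the mean absolute attribution per feature. A minimal sketch with plain lists (the feature names and attribution values are made up):

```python
# Attributions for 3 samples over 2 features ("age", "income"); values made up.
features = ["age", "income"]
attributions = [
    [0.2, -0.7],
    [-0.1, 0.9],
    [0.3, -0.8],
]

def global_importance(features, attributions):
    """Mean absolute attribution per feature, as a global importance proxy."""
    n = len(attributions)
    return {
        name: sum(abs(row[j]) for row in attributions) / n
        for j, name in enumerate(features)
    }

gi = global_importance(features, attributions)
print({k: round(v, 6) for k, v in gi.items()})  # {'age': 0.2, 'income': 0.8}
```

Taking absolute values before averaging matters: positive and negative attributions would otherwise cancel and understate a feature's influence.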
API Reference
pytorch_tabular.TabularModel.explain(data, method='GradientShap', method_args={}, baselines=None, **kwargs)

Returns the feature attributions/explanations of the model as a pandas DataFrame. The shape of the returned DataFrame is (num_samples, num_features).
Parameters:

Name | Type | Description | Default
---|---|---|---
`data` | DataFrame | The dataframe to be explained | *required*
`method` | str | The method to be used for explaining the model. It should be one of the supported methods listed above. For more details, refer to https://captum.ai/api/attribution.html | `'GradientShap'`
`method_args` | Optional[Dict] | The arguments to be passed to the initialization of the Captum method | `{}`
`baselines` | Union[float, tensor, str] | The baselines to be used for the explanation. If a scalar is provided, that value is used as the baseline for all the features. If a tensor is provided, that tensor is used as the baseline for all the features. If a string like "b\|10000" is provided, that many samples from the training data are used as the baseline. | `None`
`**kwargs` | | Additional keyword arguments to be passed to the Captum method | `{}`
Returns:

Name | Type | Description
---|---|---
DataFrame | DataFrame | The dataframe with the feature importance
Source code in src/pytorch_tabular/tabular_model.py