Abstract
Apart from high accuracy, what interests many researchers and practitioners in real-life tabular learning problems (e.g., fraud detection and credit scoring) is uncovering hidden patterns in the data and/or providing meaningful justification of decisions made by machine learning models. In this concern, an important question arises: should one use inherently interpretable models or explain full-complexity models such as XGBoost and Random Forest with post hoc tools? Opting for the second choice is typically supported by the accuracy metric, but it is not always evident that the performance gap is sufficiently significant, especially considering the current trend of accurate and inherently interpretable models, as well as other real-life evaluation metrics such as faithfulness, stability, and the computational cost of explanations. In this work, we show through benchmarking on 45 datasets that the relative accuracy loss is less than 4% on average when using intelligible models such as the explainable boosting machine. Furthermore, we propose a simple use of model ensembling to improve the expressiveness of TabSRALinear, a novel attention-based inherently interpretable solution, and demonstrate both theoretically and empirically that it is a viable option for (1) generating stable or robust explanations and (2) incorporating human knowledge during the training phase. Source code is available at https://github.com/anselmeamekoe/TabSRA.
Notes
The neural networks (NNs) described in Sect. 2.2.2 include an embedding layer designed to manage categorical features, as outlined in the work by Gorishniy et al. [6]. CatBoost also offers a native approach for handling categorical features through a combination of target encoding (referred to as "Borders") and frequency encoding (referred to as "Counter").
Gaussian quantile transformation was used in Grinsztajn et al. [16]; however, as this transformation is not bijective/linear, we only use scaling to preserve interpretability in the initial feature space.
For the sake of brevity, we used datasets containing only numerical features, as the notions of neighborhood and Lipschitz estimate are not trivial for categorical and discrete features. For more details, see Alvarez-Melis and Jaakkola [33].
This dataset is highly imbalanced; therefore, AUCPR is a more appropriate metric than AUCROC [39].
One may want to convert NumOfProducts into a binary feature, i.e., one product versus more than one product. This helps maintain the test AUCROC of XGBoost at approximately 0.828 for this dataset.
The explanations obtained using TreeSHAP are not actually monotonically decreasing as they are supposed to be (Fig. 6b). That is, the effect of NumOfProducts for some customers with three products is higher than that for some customers with two products.
In general, these explanations do not represent causality, but rather a potential description of the association between the observed feature and the target variable.
References
Vaswani A, Shazeer N, Parmar N, Uszkoreit J, Jones L, Gomez AN, Kaiser Ł, Polosukhin I (2017) Attention is all you need. Adv Neural Inf Process Syst 30
Chen T, Guestrin C (2016) Xgboost: A scalable tree boosting system. In: Proceedings of the 22nd Acm Sigkdd International Conference on Knowledge Discovery and Data Mining, pp 785–794
Somepalli G, Goldblum M, Schwarzschild A, Bruss CB, Goldstein T (2021) Saint: Improved neural networks for tabular data via row attention and contrastive pre-training. arXiv preprint arXiv:2106.01342
Kossen J, Band N, Lyle C, Gomez AN, Rainforth T, Gal Y (2021) Self-attention between datapoints: going beyond individual input-output pairs in deep learning. Adv Neural Inf Process Syst 34:28742–28756
Huang X, Khetan A, Cvitkovic M, Karnin Z (2020) Tabtransformer: Tabular data modeling using contextual embeddings. arXiv preprint arXiv:2012.06678
Gorishniy Y, Rubachev I, Khrulkov V, Babenko A (2021) Revisiting deep learning models for tabular data. Adv Neural Inf Process Syst 34:18932–18943
Lundberg SM, Lee S-I (2017) A unified approach to interpreting model predictions. Adv Neural Inf Process Syst 30
Ribeiro MT, Singh S, Guestrin C (2016) "Why should I trust you?" Explaining the predictions of any classifier. In: Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, pp 1135–1144
Lundberg SM, Erion G, Chen H, DeGrave A, Prutkin JM, Nair B, Katz R, Himmelfarb J, Bansal N, Lee S-I (2020) From local explanations to global understanding with explainable AI for trees. Nat Mach Intell 2(1):56–67
Amoukou SI, Salaün T, Brunel N (2022) Accurate shapley values for explaining tree-based models. In: International Conference on Artificial Intelligence and Statistics. PMLR, pp 2448–2465
Kumar IE, Venkatasubramanian S, Scheidegger C, Friedler S (2020) Problems with shapley-value-based explanations as feature importance measures. In: International Conference on Machine Learning. PMLR, pp 5491–5500
Rudin C (2019) Stop explaining black box machine learning models for high stakes decisions and use interpretable models instead. Nat Mach Intell 1(5):206–215
Nori H, Jenkins S, Koch P, Caruana R (2019) Interpretml: A unified framework for machine learning interpretability. arXiv preprint arXiv:1909.09223
Agarwal R, Melnick L, Frosst N, Zhang X, Lengerich B, Caruana R, Hinton GE (2021) Neural additive models: interpretable machine learning with neural nets. Adv Neural Inf Process Syst 34:4699–4711
Chang C-H, Caruana R, Goldenberg A (2021) Node-gam: Neural generalized additive model for interpretable deep learning. arXiv preprint arXiv:2106.01613
Grinsztajn L, Oyallon E, Varoquaux G (2022) Why do tree-based models still outperform deep learning on typical tabular data? Adv Neural Inf Process Syst 35:507–520
Amekoe KM, Azzag H, Lebbah M, Dagdia ZC, Jaffre G (2023) A new class of intelligible models for tabular learning. In: The 5th International Workshop on eXplainable Knowledge Discovery in Data Mining (PKDD)-ECML-PKDD
Amekoe KM, Dilmi MD, Azzag H, Dagdia ZC, Lebbah M, Jaffre G (2023) Tabsra: An attention based self-explainable model for tabular learning. In: The 31th European Symposium on Artificial Neural Networks, Computational Intelligence and Machine Learning (ESANN)
Breiman L (2001) Random forests. Mach Learn 45(1):5–32
Ke G, Meng Q, Finley T, Wang T, Chen W, Ma W, Ye Q, Liu T-Y (2017) Lightgbm: a highly efficient gradient boosting decision tree. Adv Neural Inf Process Syst 30
Prokhorenkova L, Gusev G, Vorobev A, Dorogush AV, Gulin A (2018) Catboost: unbiased boosting with categorical features. Adv Neural Inf Process Syst 31
Chen K-Y, Chiang P-H, Chou H-R, Chen T-W, Chang T-H (2023) Trompt: Towards a better deep neural network for tabular data. arXiv preprint arXiv:2305.18446
Borisov V, Leemann T, Seßler K, Haug J, Pawelczyk M, Kasneci G (2021) Deep neural networks and tabular data: A survey. arXiv preprint arXiv:2110.01889
McElfresh D, Khandagale S, Valverde J, Prasad C V, Ramakrishnan G, Goldblum M, White C (2023) When do neural nets outperform boosted trees on tabular data? arXiv e-prints, 2305
Huang X, Marques-Silva J (2023) The inadequacy of shapley values for explainability. arXiv preprint arXiv:2302.08160
Ribeiro MT, Singh S, Guestrin C (2018) Anchors: High-precision model-agnostic explanations. In: Proceedings of the AAAI Conference on Artificial Intelligence, vol 32
Marques-Silva J, Ignatiev A (2022) Delivering trustworthy ai through formal xai. In: Proceedings of the AAAI Conference on Artificial Intelligence, vol 36. pp 12342–12350
Chen H, Covert IC, Lundberg SM, Lee S-I (2023) Algorithms to estimate shapley value feature attributions. Nat Mach Intell 5(6):590–601
Yang Z, Zhang A, Sudjianto A (2021) Gami-net: an explainable neural network based on generalized additive models with structured interactions. Pattern Recogn 120:108192
Lou Y, Caruana R, Gehrke J, Hooker G (2013) Accurate intelligible models with pairwise interactions. In: Proceedings of the 19th ACM SIGKDD International Conference on Knowledge Discovery and Data Mining. pp 623–631
Popov S, Morozov S, Babenko A (2019) Neural oblivious decision ensembles for deep learning on tabular data. arXiv preprint arXiv:1909.06312
Chen Z, Tan S, Nori H, Inkpen K, Lou Y, Caruana R (2021) Using explainable boosting machines (ebms) to detect common flaws in data. In: Joint European Conference on Machine Learning and Knowledge Discovery in Databases. Springer, pp 534–551
Alvarez-Melis D, Jaakkola TS (2018) On the robustness of interpretability methods. arXiv preprint arXiv:1806.08049
Buitinck L, Louppe G, Blondel M, Pedregosa F, Mueller A, Grisel O, Niculae V, Prettenhofer P, Gramfort A, Grobler J et al (2013) Api design for machine learning software: experiences from the scikit-learn project. arXiv preprint arXiv:1309.0238
Ignatiev A, Izza Y, Stuckey PJ, Marques-Silva J (2022) Using maxsat for efficient explanations of tree ensembles. In: AAAI
Shrikumar A, Greenside P, Kundaje A (2017) Learning important features through propagating activation differences. In: International Conference on Machine Learning. PMLR, pp 3145–3153
Alvarez Melis D, Jaakkola T (2018) Towards robust interpretability with self-explaining neural networks. In: Bengio S, Wallach H, Larochelle H, Grauman K, Cesa-Bianchi N, Garnett R (eds) Advances in Neural Information Processing Systems, vol 31. Curran Associates, Inc
Agarwal C, Johnson N, Pawelczyk M, Krishna S, Saxena E, Zitnik M, Lakkaraju H (2022) Rethinking Stability for Attribution-based Explanations
Davis J, Goadrich M (2006) The relationship between precision-recall and roc curves. In: Proceedings of the 23rd International Conference on Machine Learning. pp 233–240
Wistuba M, Schilling N, Schmidt-Thieme L (2015) Learning hyperparameter optimization initializations. In: 2015 IEEE International Conference on Data Science and Advanced Analytics (DSAA). IEEE, pp 1–10
Gorishniy Y, Rubachev I, Babenko A (2022) On embeddings for numerical features in tabular deep learning. Adv Neural Inf Process Syst 35:24991–25004
Lengerich B, Tan S, Chang C-H, Hooker G, Caruana R (2020) Purifying interaction effects with the functional anova: An efficient algorithm for recovering identifiable additive models. In: International Conference on Artificial Intelligence and Statistics. PMLR, pp 2402–2412
Müller S, Toborek V, Beckh K, Jakobs M, Bauckhage C, Welke P (2023) An Empirical Evaluation of the Rashomon Effect in Explainable Machine Learning
Kim H, Papamakarios G, Mnih A (2021) The lipschitz constant of self-attention. In: International Conference on Machine Learning. PMLR, pp 5562–5571
Ultsch A (2005) Clustering with SOM: U*C. In: Proceedings of the Workshop on Self-Organizing Maps
Paszke A, Gross S, Massa F, Lerer A, Bradbury J, Chanan G, Killeen T, Lin Z, Gimelshein N, Antiga L, et al (2019) Pytorch: an imperative style, high-performance deep learning library. Adv Neural Inf Process Syst 32
Biewald L (2020) Experiment Tracking with Weights and Biases. Software available from wandb.com. https://www.wandb.com/
Ethics declarations
Conflict of interest
The authors have no conflicts of interest that are relevant to the content of this article.
Supplementary information
All data supporting the findings of this study are available within the paper and its Supplementary Information.
Additional information
Publisher's Note
Springer Nature remains neutral with regard to jurisdictional claims in published maps and institutional affiliations.
Appendices
Appendix A: additional theoretical results
1.1 On the Lipschitz estimate of TabSRALinear ensemble
Theorem 2
The feature attributions produced by the TabSRALinear ensemble (Eq. 9) are locally stable in the sense of Lipschitz, that is, for all \(\textbf{x} \in {\mathbbm {R}}^{p}\), there exist \(\delta >0\) and \(L_{\textbf{x}}\ge 0\) finite such that \(\Vert \textbf{x} - \textbf{x}' \Vert _1 < \delta\) implies:

$$\Big \Vert \sum _{h=1}^{H}\varvec{\beta }^h \odot \textbf{a}^h(\textbf{x}) \odot \textbf{x} - \sum _{h=1}^{H}\varvec{\beta }^h \odot \textbf{a}^h(\textbf{x}') \odot \textbf{x}' \Big \Vert _1 \le L_{\textbf{x}} \Vert \textbf{x} - \textbf{x}' \Vert _1,$$

with \(L_{\textbf{x}} = \sum _{h=1}^{H}\Vert \varvec{\beta }^h \Vert _{\infty } \left[ \Vert \textbf{a}^h \Vert _{\infty } + L_{{\textbf{a}}}^h(\Vert \textbf{x} \Vert _{\infty }+\delta )\right]\) and \(L_{{\textbf{a}}}^h\ge 0\) the Lipschitz constant of the h-th SRA block.
Theorem 2 demonstrates that the Lipschitz estimate of the TabSRALinear ensemble is theoretically additive with respect to the number of learners in the ensemble, denoted H. Hence, a large value of H may lead to less stable explanations. Our experimental results on the default benchmark and the ablation study (Table 6) demonstrate that \(H=2\) is typically sufficient to achieve strong predictive performance.
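To illustrate Theorem 2 empirically, the sketch below builds a hypothetical ensemble of sigmoid-attention learners (the random weights and the product-of-sigmoids attention are our stand-ins, not the trained TabSRALinear architecture) and checks the local stability bound on random neighbors of a target point:

```python
import numpy as np

rng = np.random.default_rng(0)
p, H, delta = 4, 2, 0.1

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# Hypothetical learners: attribution phi^h(x) = beta^h * a^h(x) * x,
# with attention a^h(x) in [0, 1] built from sigmoid-activated linear maps.
learners = []
for _ in range(H):
    Wk, Wq = rng.normal(size=(p, p)), rng.normal(size=(p, p))
    beta = rng.normal(size=p)
    learners.append((Wk, Wq, beta))

def attribution(x):
    total = np.zeros(p)
    for Wk, Wq, beta in learners:
        a = sigmoid(Wk @ x) * sigmoid(Wq @ x)
        total += beta * a * x
    return total

x = rng.normal(size=p)

# Theoretical constant L_x = sum_h ||beta^h||_inf [||a^h||_inf + L_a^h (||x||_inf + delta)],
# with ||a^h||_inf <= 1 and L_a^h derived from the 1/4-Lipschitzness of the sigmoid.
L_x = 0.0
for Wk, Wq, beta in learners:
    L_a = (np.abs(Wk).sum(axis=1) / 4 + np.abs(Wq).sum(axis=1) / 4).sum()
    L_x += np.abs(beta).max() * (1.0 + L_a * (np.abs(x).max() + delta))

# Empirical check on random neighbors x' with ||x - x'||_1 < delta.
for _ in range(1000):
    xp = x + rng.uniform(-delta / p, delta / p, size=p)
    lhs = np.abs(attribution(x) - attribution(xp)).sum()
    assert lhs <= L_x * np.abs(x - xp).sum()
print("local stability bound holds on all sampled neighbors")
```

Since the theoretical constant sums over learners, growing H inflates \(L_{\textbf{x}}\), which is the additivity the theorem describes.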
The proof of Theorem 2 follows readily from two facts: (1) the feature attributions of each individual TabSRALinear in the ensemble are locally stable in the sense of Lipschitz, in accordance with Theorem 1; (2) a finite sum of Lipschitz continuous functions is Lipschitz continuous, as stated in Lemma 3.
1.2 Proof of Theorem 1
Before providing a proof of the theorem, we consider the following lemmas:
Lemma 3

(1) The sum of two Lipschitz continuous functions is Lipschitz continuous.

(2) The product of two Lipschitz continuous and bounded functions is Lipschitz continuous.
Proof
We consider two functions \(\theta\) and \(\psi\) from \({\mathbbm {R}}^{p}\) to \({\mathbbm {R}}^{p}\) that are \(L_\theta\)- and \(L_\psi\)-Lipschitz, respectively. For all \(\textbf{x}, \textbf{x}' \in {\mathbbm {R}}^p\) we have:

(1)
$$\Vert (\theta +\psi )(\textbf{x}) - (\theta +\psi )(\textbf{x}')\Vert _1 \le \Vert \theta (\textbf{x}) - \theta (\textbf{x}')\Vert _1 + \Vert \psi (\textbf{x}) - \psi (\textbf{x}')\Vert _1 \le (L_\theta + L_\psi )\Vert \textbf{x} - \textbf{x}' \Vert _1.$$

(2) Moreover, let us assume that \(\theta\) and \(\psi\) are bounded, i.e., \(\Vert \theta \Vert _{\infty }<{\infty }\) and \(\Vert \psi \Vert _{\infty }<{\infty }\); then we have:
$$\Vert (\theta \psi )(\textbf{x}) - (\theta \psi )(\textbf{x}')\Vert _1 \le \Vert \theta \Vert _{\infty }\Vert \psi (\textbf{x}) - \psi (\textbf{x}')\Vert _1 + \Vert \psi \Vert _{\infty }\Vert \theta (\textbf{x}) - \theta (\textbf{x}')\Vert _1 \le \left( \Vert \theta \Vert _{\infty } L_\psi + \Vert \psi \Vert _{\infty } L_\theta \right) \Vert \textbf{x} - \textbf{x}' \Vert _1.$$
\(\square\)
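As a numeric spot-check of Lemma 3 (an illustration, not a proof), take \(\theta = \sin\) and \(\psi = \cos\), both 1-Lipschitz and bounded by 1, so the sum bound is \(L_\theta + L_\psi = 2\) and the product bound is \(\Vert \theta \Vert _{\infty } L_\psi + \Vert \psi \Vert _{\infty } L_\theta = 2\):

```python
import numpy as np

rng = np.random.default_rng(1)
xs = rng.uniform(-5, 5, size=(5000, 2))
for x, xp in xs:
    d = abs(x - xp)
    # Sum rule: Lipschitz constant of sin + cos is at most 1 + 1 = 2.
    assert abs((np.sin(x) + np.cos(x)) - (np.sin(xp) + np.cos(xp))) <= 2 * d + 1e-12
    # Product rule: ||sin||_inf * 1 + ||cos||_inf * 1 = 2.
    assert abs(np.sin(x) * np.cos(x) - np.sin(xp) * np.cos(xp)) <= 2 * d + 1e-12
print("sum and product Lipschitz bounds hold on all sampled pairs")
```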
Lemma 4
The attention vector outputted by the SRA block is stable in the sense of Lipschitz, i.e., for all \(\textbf{x}, \textbf{x}' \in {\mathbbm {R}}^p\), there exists a finite constant \(L_{{\textbf{a}}}\ge 0\) such that:

$$\Vert {\textbf{a}}(\textbf{x}) - {\textbf{a}}(\textbf{x}')\Vert _1 \le L_{{\textbf{a}}} \Vert \textbf{x} - \textbf{x}' \Vert _1.$$

Moreover, \(\Vert {\textbf{a}} \Vert _{\infty }=1\).
Proof
Each component \(k_i^j\) of the keys matrix K (resp. \(q_i^j\) of Q) is stable in the sense of Lipschitz, as it is outputted by a fully connected layer [44] (linear transformations followed by a common activation such as ReLU or Sigmoid), and is bounded in [0, 1] thanks to the Sigmoid activation. Hence, for all \(\textbf{x}, \textbf{x}' \in {\mathbbm {R}}^p\), there exist finite \(\alpha _{k_i^j}>0\), \(\alpha _{q_i^j}>0\) such that \(|k_i^j(\textbf{x}) - k_i^j(\textbf{x}') |\le \alpha _{k_i^j}\Vert \textbf{x} - \textbf{x}' \Vert _1\) and \(|q_i^j(\textbf{x}) - q_i^j(\textbf{x}') |\le \alpha _{q_i^j}\Vert \textbf{x} - \textbf{x}' \Vert _1\).
Then, each component of the attention vector is also Lipschitz, since by Lemma 3 (the \(k_i^j\) and \(q_i^j\) being bounded by 1):

$$|a_i(\textbf{x}) - a_i(\textbf{x}')| \le \frac{1}{d_k} \sum _{j=1}^{d_k} \left( \alpha _{k_i^j} + \alpha _{q_i^j}\right) \Vert \textbf{x} - \textbf{x}' \Vert _1 = L_{a_i} \Vert \textbf{x} - \textbf{x}' \Vert _1,$$

with \(L_{a_i} = \frac{1}{d_k} \sum _{j=1}^{d_k} \alpha _{i}^j\), where \(\alpha _{i}^j = \alpha _{k_i^j} + \alpha _{q_i^j}\).
Finally, we have:

$$\Vert {\textbf{a}}(\textbf{x}) - {\textbf{a}}(\textbf{x}')\Vert _1 = \sum _{i=1}^{p} |a_i(\textbf{x}) - a_i(\textbf{x}')| \le L_{{\textbf{a}}} \Vert \textbf{x} - \textbf{x}' \Vert _1,$$

with \(L_{{\textbf{a}}} = \sum _{i=1}^{p}{L_{a_i}}\).
Since every \(a_i \in [0,1]\) we have \(\Vert {\textbf{a}} \Vert _{\infty }=1\). \(\square\)
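Lemma 4 can be checked numerically on a toy SRA-like block; the sigmoid-activated one-layer key/query maps and the averaged product \(a_i = \frac{1}{d_k}\sum_j k_i^j q_i^j\) below are our stand-ins for the actual architecture, and the shapes are assumptions:

```python
import numpy as np

rng = np.random.default_rng(2)
p, d_k = 3, 4

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# Hypothetical key/query weights: every k_i^j, q_i^j lies in [0, 1].
Wk = rng.normal(size=(p, d_k, p))
Wq = rng.normal(size=(p, d_k, p))

def attention(x):
    K = sigmoid(np.einsum('idp,p->id', Wk, x))
    Q = sigmoid(np.einsum('idp,p->id', Wq, x))
    return (K * Q).mean(axis=1)   # a_i = (1/d_k) * sum_j k_i^j q_i^j

# Sigmoid is 1/4-Lipschitz, so alpha <= ||W row||_1 / 4; the product of two
# [0,1]-bounded terms then gives L_{a_i} <= mean_j (alpha_k + alpha_q).
alpha_k = np.abs(Wk).sum(axis=2) / 4.0
alpha_q = np.abs(Wq).sum(axis=2) / 4.0
L_a = (alpha_k + alpha_q).mean(axis=1).sum()

x, xp = rng.normal(size=p), rng.normal(size=p)
lhs = np.abs(attention(x) - attention(xp)).sum()
rhs = L_a * np.abs(x - xp).sum()
assert lhs <= rhs   # the Lemma 4 bound holds for this pair
```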
Proof of Theorem 1
Using Eq. 17 and considering \(\psi (\textbf{x}) = {\textbf{a}}(\textbf{x})\) and \(\theta (\textbf{x}) = \textbf{x}\), we have \(\Vert \psi \Vert _{\infty } = \Vert {\textbf{a}} \Vert _{\infty } =1\) and \(\Vert \theta \Vert _{\infty } = \Vert \mathcal {D} \Vert _{\infty } =\max _{\textbf{x} \in \mathcal {D}}\Vert \textbf{x} \Vert _{\infty }\), which is the overall maximal observable feature value.
Therefore:

$$\Vert {\textbf{a}}(\textbf{x})\odot \textbf{x} - {\textbf{a}}(\textbf{x}')\odot \textbf{x}'\Vert _1 \le \left[ \Vert {\textbf{a}} \Vert _{\infty } + L_{{\textbf{a}}}\Vert \mathcal {D} \Vert _{\infty }\right] \Vert \textbf{x} - \textbf{x}' \Vert _1. \quad (22)$$
In the formulation of the TabSRALinear, we are not necessarily interested in global stability or a uniform Lipschitz constant, as in Eq. 22 but rather in regional or local stability around a given target or anchor data point.
With this consideration, given the target data point \(\textbf{x}\), we can restrict \(\mathcal {D}\) to its neighborhood, i.e., \(\mathcal {D} = \{\textbf{x}' \in {\mathbbm {R}}^{p}: \Vert \textbf{x} - \textbf{x}' \Vert _1 < \delta \}\), and therefore \(\Vert \mathcal {D} \Vert _{\infty } \le \Vert \textbf{x} \Vert _{\infty } +\delta\).

Using Eq. 22, it results that for every \(\textbf{x} \in {\mathbbm {R}}^{p}\), there exists a constant \(\delta > 0\) such that \(\Vert \textbf{x} - \textbf{x}' \Vert _1 < \delta\) implies:

$$\Vert \varvec{\beta } \odot {\textbf{a}}(\textbf{x}) \odot \textbf{x} - \varvec{\beta } \odot {\textbf{a}}(\textbf{x}') \odot \textbf{x}' \Vert _1 \le \Vert \varvec{\beta } \Vert _{\infty }\left[ \Vert {\textbf{a}} \Vert _{\infty } + L_{{\textbf{a}}}(\Vert \textbf{x} \Vert _{\infty }+\delta )\right] \Vert \textbf{x} - \textbf{x}' \Vert _1.$$
\(\square\)
Appendix B: additional empirical information
1.1 Datasets
1.1.1 Middle-scale benchmark
This benchmark of 45 datasets (59 binary classification and regression tasks) was introduced in the paper titled Why do tree-based models still outperform deep learning on typical tabular data? [16]. The main goal of this benchmark was to identify meta-features or inductive biases that explain the superior predictive performance of tree-based models over NNs in tabular learning. We take a step forward by incorporating inherently interpretable models in the assessment of predictive performance. We provide essential details about the datasets and refer the interested reader to the original paper [16] for further information. The main criteria for selecting datasets are:
- The datasets contain heterogeneous features (this excludes image and signal datasets).
- The datasets are not high dimensional: the ratio p/n is below 1/10 and \(p<500\), with p the number of features and n the number of observations.
- The data are I.I.D.; stream-like datasets and time series are removed.
- They are real-world data.
- The datasets are not too easy and not too small, i.e., \(p\ge 4\) and \(n \ge 3000\).
- They are not deterministic datasets, i.e., datasets where the target is a deterministic function of the predictors.
Furthermore, in order to keep the learning as homogeneous as possible, some subproblems that deserve their own analysis are excluded. That is:
- The size of the datasets: in the middle-scale regime, the training set is truncated to 10,000 samples and the test set to 50,000.
- There is no missing data: all data points with missing values are removed.
- Balanced classes: for classification, if there are several classes, the target is binarized by taking the two most numerous classes, and we keep half of the samples in each class.
- Categorical features with more than 20 items are removed.
- Numerical features with fewer than 10 unique values are removed.
Finally, every algorithm and hyperparameter combination is evaluated on the same random-seed Train/Validation/Test split or fold. More precisely, 70% of samples are used for the training set (or the percentage that corresponds to a maximum of 10,000 samples if 70% is too high). Of the remaining 30%, 30% are used for the validation set (truncated to a maximum of 50,000 samples) and 70% for the test set (also truncated to a maximum of 50,000 samples). Depending on the size of the test set, several random-seed splits/folds are used (cross-validation). That is:
- If the test set has more than 6000 samples, we evaluate our algorithms on one fold.
- If the test set has between 3000 and 6000 samples, we evaluate our algorithms on two folds.
- If the test set has between 1000 and 3000 samples, we evaluate our algorithms on three folds.
- If the test set has fewer than 1000 samples, we evaluate our algorithms on five folds.
1.1.2 Default benchmark
- Bank Churn: This dataset contains details of a bank's customers, and the target is a binary variable indicating whether the customer left the bank (closed their account) or remains a customer. We drop the CustomerId and Surname columns. We also drop the HasCrCard column, which is not really informative. https://www.kaggle.com/datasets/shrutimechlearn/churn-modelling
- Credit Default: The goal is to predict the default probability (for the next month) of credit card clients in Taiwan using historical data. We drop the SEX column, as gender information is generally not used in credit scoring. https://archive.ics.uci.edu/ml/datasets/default+of+credit+card+clients
- Credit Card Fraud: The aim is to predict whether credit card transactions are fraudulent or genuine. The original dataset is PCA-transformed for privacy purposes. We drop the time information in our study. https://www.kaggle.com/mlg-ulb/creditcardfraud
- Heloc Fico: The target variable to predict is a binary variable called RiskPerformance. The value "Bad" indicates that a consumer was 90 days past due or worse at least once over a period of 24 months from when the credit account was opened; the value "Good" indicates that they have made their payments without ever being more than 90 days overdue. https://community.fico.com/s/explainable-machine-learning-challenge
1.2 Additional results for TabSRAs: visualization
How raw data are reinforced using the SRA block
In addition to Sect. 3.1, we include two more examples, the 2D ChainLink [45] and the noisy two-moon dataset, as depicted in Figs. 10 and 11. By applying SRA coefficients to the ChainLink dataset, we obtain a new data representation that enables easy separation of the classes, as shown in Fig. 10b. Even without knowledge of the true data-generating process, it is apparent that all observations have been moved strategically so that a simple rule can effectively isolate nearly all yellow observations of interest.
1.3 Additional results on the applicative case studies
1.3.1 Bank churn modeling
As shown in Fig. 12, the effect of the Age feature on the churn risk is bell-shaped according to TabSRALinear, XGBoost+TreeSHAP, and EBMs. Moreover, there is a strong interaction between the Age and IsActiveMember features, highlighted by the important influence of Age (around 60) on the churn score for non-active members.
To handle this interaction, the GAM-based inherently interpretable solution EBM needs to break it down into a main effect and an interaction effect, which is not necessary with TabSRALinear. Overall, these findings explain the poor predictive performance of logistic regression (LR), as highlighted in Sect. 4.4.1.
1.3.2 Credit card default
In Sect. 4.4.2, we demonstrate using the credit card default dataset that TabSRALinear can generate more concise explanations compared to XGBoost+TreeSHAP. As depicted in Fig. 13, EBMs also spread contributions among correlated features. As a result, their feature attributions are less sparse than those of TabSRALinear, particularly for EBMs with pairwise interactions, which may assign nonzero feature attributions to interaction terms as well as main effects.
1.4 Additional results for the robustness study
The output of piecewise-constant approximators (for instance, EBMs and XGBoost) is generally more sensitive to input perturbations than that of linear models (LR) and TabSRALinear, as depicted in Fig. 14. We argue that this is due to their flexibility in producing discontinuities or learning irregular functions [16, 24].
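A toy illustration of this argument, with a single step function standing in for a tree split and a linear map standing in for LR (both are hypothetical stand-ins, not the fitted models from the study):

```python
import numpy as np

rng = np.random.default_rng(3)

step = lambda x: (x >= 0.5).astype(float)   # piecewise-constant predictor
linear = lambda x: 0.2 * x                  # linear predictor

# Points near the decision threshold, hit with small perturbations.
x = rng.uniform(0.45, 0.55, size=10000)
eps = rng.uniform(-0.02, 0.02, size=x.shape)

# Mean absolute output change under perturbation: the step function jumps by
# a full unit whenever the threshold is crossed, the linear map barely moves.
step_change = np.abs(step(x + eps) - step(x)).mean()
linear_change = np.abs(linear(x + eps) - linear(x)).mean()
print(step_change, linear_change)
```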
1.5 Implementation details for the predictive performance evaluation
1.5.1 Search space for hyperparameters
For full-complexity models, the hyperparameter spaces are derived from the previous study [16]. For tree-based models, the number of estimators (trees) is not tuned but rather set to a high value: 250 for Random Forest (RF) and 1000 for XGBoost and CatBoost. For the latter, early stopping is employed with a patience of 20. Default parameters for tree-based models are the Scikit-Learn/XGBoost/CatBoost defaults.
For neural networks (NNs), the maximal number of epochs is set to 300 with early stopping and checkpointing (the best model on the validation set is kept). The early-stopping patience is 40 for MLP, ResNet, TabSRALinear, and Linear, 10 for SAINT, and 50 for EBMs. Note that for models using early stopping, 20% of the training dataset is used as the validation set (different from the validation set, which is in the test part of the split).
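The early-stopping-with-checkpoint protocol can be sketched generically; the training and scoring callables below are placeholders for the real routines, and the toy validation curve is purely illustrative:

```python
def fit_with_early_stopping(train_one_epoch, eval_score, max_epochs=300, patience=40):
    # Generic early stopping with checkpointing: remember the epoch that
    # scored best on the validation set and stop once no improvement has
    # been seen for `patience` consecutive epochs.
    best_score, best_epoch = float('-inf'), -1
    for epoch in range(max_epochs):
        train_one_epoch(epoch)
        score = eval_score(epoch)
        if score > best_score:
            best_score, best_epoch = score, epoch   # checkpoint
        elif epoch - best_epoch >= patience:
            break
    return best_epoch, best_score

# Toy validation curve that peaks at epoch 60 and then degrades.
curve = lambda e: -abs(e - 60)
best_epoch, best_score = fit_with_early_stopping(lambda e: None, curve)
print(best_epoch, best_score)  # 60 0
```

Training stops at epoch 100 (40 epochs past the peak), and the checkpointed epoch-60 state is the one reported.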
For NNs, we used the PyTorch library [46] and the Weights & Biases platform [47] for hyperparameter optimization and experiment tracking (Tables 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18 and 19).
1.6 Additional results about the predictive performance
1.6.1 Results of hyperparameter optimization using all iterations
In this part, we present the results of hyperparameter tuning using all iterations instead of bootstrapping, as described in Sect. 4.1.3.
As indicated in Table 20, utilizing all iterations once instead of employing bootstrapping changes neither the predictive ranking of the algorithms nor our conclusions (Tables 21, 22, 23, 24, 25, 26, and 27).
1.6.2 Performance per dataset
In this section, we provide detailed individual results for each task in the middle-scale benchmark.
\(R^2\) is reported as the predictive performance metric for regression tasks, and accuracy is reported for classification tasks.
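For reference, the two reported metrics in their usual definitions (a minimal sketch, not the exact evaluation code used in the experiments):

```python
import numpy as np

def r2(y_true, y_pred):
    # Coefficient of determination: 1 - SS_res / SS_tot.
    ss_res = ((y_true - y_pred) ** 2).sum()
    ss_tot = ((y_true - y_true.mean()) ** 2).sum()
    return 1.0 - ss_res / ss_tot

def accuracy(y_true, y_pred):
    # Fraction of exactly matching predictions.
    return (y_true == y_pred).mean()

y = np.array([1.0, 2.0, 3.0, 4.0])
print(r2(y, np.array([1.1, 1.9, 3.2, 3.8])))  # close to 1 for a good fit
print(accuracy(np.array([0, 1, 1, 0]), np.array([0, 1, 0, 0])))  # 0.75
```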
Rights and permissions
Springer Nature or its licensor (e.g. a society or other partner) holds exclusive rights to this article under a publishing agreement with the author(s) or other rightsholder(s); author self-archiving of the accepted manuscript version of this article is solely governed by the terms of such publishing agreement and applicable law.
About this article
Cite this article
Amekoe, K.M., Azzag, H., Dagdia, Z.C. et al. Exploring accuracy and interpretability trade-off in tabular learning with novel attention-based models. Neural Comput & Applic 36, 18583–18611 (2024). https://doi.org/10.1007/s00521-024-10163-9