Exploring accuracy and interpretability trade-off in tabular learning with novel attention-based models

  • Original Article
  • Published:
Neural Computing and Applications

Abstract

Apart from high accuracy, what interests many researchers and practitioners in real-life tabular learning problems (e.g., fraud detection and credit scoring) is uncovering hidden patterns in the data and/or providing meaningful justification of the decisions made by machine learning models. In this regard, an important question arises: should one use inherently interpretable models, or explain full-complexity models such as XGBoost or Random Forest with post hoc tools? Opting for the second choice is typically supported by the accuracy metric, but it is not always evident that the performance gap is sufficiently significant, especially considering the current trend of accurate and inherently interpretable models, as well as other real-life evaluation criteria such as faithfulness, stability, and the computational cost of explanations. In this work, we show through benchmarking on 45 datasets that the relative accuracy loss is less than 4% on average when using intelligible models such as the explainable boosting machine. Furthermore, we propose a simple use of model ensembling to improve the expressiveness of TabSRALinear, a novel attention-based inherently interpretable solution, and demonstrate both theoretically and empirically that it is a viable option for (1) generating stable or robust explanations and (2) incorporating human knowledge during the training phase. Source code is available at https://github.com/anselmeamekoe/TabSRA.



Notes

  1. The neural networks (NNs) described in Sect. 2.2.2 include an embedding layer designed to handle categorical features, as outlined in the work by Gorishniy et al. [6]. CatBoost also offers a native approach for handling categorical features through a combination of target encoding (referred to as “Borders”) and frequency encoding (referred to as “Counter”).

  2. Gaussian quantile transformation was used in Grinsztajn et al. [16]; however, as this transformation is not bijective/linear, we only use scaling to preserve interpretability in the initial feature space.

  3. For the sake of brevity, we used datasets containing only numerical features, as the notion of neighborhood and the Lipschitz estimate are not trivial for categorical and discrete features. For more details, see Alvarez-Melis and Jaakkola [33].

  4. This dataset is highly imbalanced; therefore, AUCPR is a more appropriate metric than AUCROC [39].

  5. One may want to recode NumOfProducts as a binary feature, i.e., one product versus more than one product. This helps maintain the test AUCROC of XGBoost at approximately 0.828 for this dataset.

  6. The explanations obtained using TreeSHAP are not actually monotonically decreasing as they are supposed to be (Fig. 6b). That is, the effect of NumOfProducts for some customers with three products is higher than that for some customers with two products.

  7. In general, these explanations do not represent causality, but rather a potential description of the association between the observed feature and the target variable.

  8. https://github.com/anselmeamekoe/TabSRA.
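Note 4's preference for AUCPR under heavy class imbalance can be illustrated with a small sketch (our synthetic data and scikit-learn metrics, not the paper's experimental code): on rare-positive data, AUCROC can look comfortable while AUCPR stays close to the base rate.

```python
import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score

# Synthetic, highly imbalanced labels (about 1% positives) -- illustrative only.
rng = np.random.default_rng(0)
y_true = (rng.random(10_000) < 0.01).astype(int)
# A mediocre scorer: positives receive slightly higher scores on average.
y_score = rng.random(10_000) + 0.3 * y_true

auc_roc = roc_auc_score(y_true, y_score)
auc_pr = average_precision_score(y_true, y_score)  # AUCPR
# AUCROC looks reasonable, but AUCPR reveals how weak the ranking really is
# relative to the 1% base rate -- the effect Note 4 points to.
print(f"AUCROC={auc_roc:.3f}  AUCPR={auc_pr:.3f}")
```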

References

  1. Vaswani A, Shazeer N, Parmar N, Uszkoreit J, Jones L, Gomez AN, Kaiser Ł, Polosukhin I (2017) Attention is all you need. Advances in neural information processing systems 30

  2. Chen T, Guestrin C (2016) XGBoost: A scalable tree boosting system. In: Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, pp 785–794

  3. Somepalli G, Goldblum M, Schwarzschild A, Bruss CB, Goldstein T (2021) Saint: Improved neural networks for tabular data via row attention and contrastive pre-training. arXiv preprint arXiv:2106.01342

  4. Kossen J, Band N, Lyle C, Gomez AN, Rainforth T, Gal Y (2021) Self-attention between datapoints: going beyond individual input-output pairs in deep learning. Adv Neural Inf Process Syst 34:28742–28756


  5. Huang X, Khetan A, Cvitkovic M, Karnin Z (2020) Tabtransformer: Tabular data modeling using contextual embeddings. arXiv preprint arXiv:2012.06678

  6. Gorishniy Y, Rubachev I, Khrulkov V, Babenko A (2021) Revisiting deep learning models for tabular data. Adv Neural Inf Process Syst 34:18932–18943


  7. Lundberg SM, Lee S-I (2017) A unified approach to interpreting model predictions. Advances in neural information processing systems 30

  8. Ribeiro MT, Singh S, Guestrin C (2016) "Why should I trust you?": Explaining the predictions of any classifier. In: Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, pp 1135–1144

  9. Lundberg SM, Erion G, Chen H, DeGrave A, Prutkin JM, Nair B, Katz R, Himmelfarb J, Bansal N, Lee S-I (2020) From local explanations to global understanding with explainable AI for trees. Nat Mach Intell 2(1):56–67


  10. Amoukou SI, Salaün T, Brunel N (2022) Accurate shapley values for explaining tree-based models. In: International Conference on Artificial Intelligence and Statistics. PMLR, pp 2448–2465

  11. Kumar IE, Venkatasubramanian S, Scheidegger C, Friedler S (2020) Problems with shapley-value-based explanations as feature importance measures. In: International Conference on Machine Learning. PMLR, pp 5491–5500

  12. Rudin C (2019) Stop explaining black box machine learning models for high stakes decisions and use interpretable models instead. Nat Mach Intell 1(5):206–215


  13. Nori H, Jenkins S, Koch P, Caruana R (2019) Interpretml: A unified framework for machine learning interpretability. arXiv preprint arXiv:1909.09223

  14. Agarwal R, Melnick L, Frosst N, Zhang X, Lengerich B, Caruana R, Hinton GE (2021) Neural additive models: interpretable machine learning with neural nets. Adv Neural Inf Process Syst 34:4699–4711


  15. Chang C-H, Caruana R, Goldenberg A (2021) Node-gam: Neural generalized additive model for interpretable deep learning. arXiv preprint arXiv:2106.01613

  16. Grinsztajn L, Oyallon E, Varoquaux G (2022) Why do tree-based models still outperform deep learning on typical tabular data? Adv Neural Inf Process Syst 35:507–520


  17. Amekoe KM, Azzag H, Lebbah M, Dagdia ZC, Jaffre G (2023) A new class of intelligible models for tabular learning. In: The 5th International Workshop on eXplainable Knowledge Discovery in Data Mining (PKDD)-ECML-PKDD

  18. Amekoe KM, Dilmi MD, Azzag H, Dagdia ZC, Lebbah M, Jaffre G (2023) Tabsra: An attention based self-explainable model for tabular learning. In: The 31th European Symposium on Artificial Neural Networks, Computational Intelligence and Machine Learning (ESANN)

  19. Breiman L (2001) Random forests. Mach Learn 45(1):5–32


  20. Ke G, Meng Q, Finley T, Wang T, Chen W, Ma W, Ye Q, Liu T-Y (2017) Lightgbm: a highly efficient gradient boosting decision tree. Advances in neural information processing systems 30

  21. Prokhorenkova L, Gusev G, Vorobev A, Dorogush AV, Gulin A (2018) Catboost: unbiased boosting with categorical features. Advances in neural information processing systems 31

  22. Chen K-Y, Chiang P-H, Chou H-R, Chen T-W, Chang T-H (2023) Trompt: Towards a better deep neural network for tabular data. arXiv preprint arXiv:2305.18446

  23. Borisov V, Leemann T, Seßler K, Haug J, Pawelczyk M, Kasneci G (2021) Deep neural networks and tabular data: A survey. arXiv preprint arXiv:2110.01889

  24. McElfresh D, Khandagale S, Valverde J, Prasad C V, Ramakrishnan G, Goldblum M, White C (2023) When do neural nets outperform boosted trees on tabular data? arXiv e-prints, 2305

  25. Huang X, Marques-Silva J (2023) The inadequacy of shapley values for explainability. arXiv preprint arXiv:2302.08160

  26. Ribeiro MT, Singh S, Guestrin C (2018) Anchors: High-precision model-agnostic explanations. In: Proceedings of the AAAI Conference on Artificial Intelligence, vol 32

  27. Marques-Silva J, Ignatiev A (2022) Delivering trustworthy ai through formal xai. In: Proceedings of the AAAI Conference on Artificial Intelligence, vol 36. pp 12342–12350

  28. Chen H, Covert IC, Lundberg SM, Lee S-I (2023) Algorithms to estimate shapley value feature attributions. Nat Mach Intell 5(6):590–601


  29. Yang Z, Zhang A, Sudjianto A (2021) Gami-net: an explainable neural network based on generalized additive models with structured interactions. Pattern Recogn 120:108192


  30. Lou Y, Caruana R, Gehrke J, Hooker G (2013) Accurate intelligible models with pairwise interactions. In: Proceedings of the 19th ACM SIGKDD International Conference on Knowledge Discovery and Data Mining. pp 623–631

  31. Popov S, Morozov S, Babenko A (2019) Neural oblivious decision ensembles for deep learning on tabular data. arXiv preprint arXiv:1909.06312

  32. Chen Z, Tan S, Nori H, Inkpen K, Lou Y, Caruana R (2021) Using explainable boosting machines (ebms) to detect common flaws in data. In: Joint European Conference on Machine Learning and Knowledge Discovery in Databases. Springer, pp 534–551

  33. Alvarez-Melis D, Jaakkola TS (2018) On the robustness of interpretability methods. arXiv preprint arXiv:1806.08049

  34. Buitinck L, Louppe G, Blondel M, Pedregosa F, Mueller A, Grisel O, Niculae V, Prettenhofer P, Gramfort A, Grobler J et al (2013) Api design for machine learning software: experiences from the scikit-learn project. arXiv preprint arXiv:1309.0238

  35. Ignatiev A, Izza Y, Stuckey PJ, Marques-Silva J (2022) Using maxsat for efficient explanations of tree ensembles. In: AAAI

  36. Shrikumar A, Greenside P, Kundaje A (2017) Learning important features through propagating activation differences. In: International Conference on Machine Learning. PMLR, pp 3145–3153

  37. Alvarez Melis D, Jaakkola T (2018) Towards robust interpretability with self-explaining neural networks. In: Bengio S, Wallach H, Larochelle H, Grauman K, Cesa-Bianchi N, Garnett R (eds) Advances in Neural Information Processing Systems, vol 31. Curran Associates, Inc

  38. Agarwal C, Johnson N, Pawelczyk M, Krishna S, Saxena E, Zitnik M, Lakkaraju H (2022) Rethinking Stability for Attribution-based Explanations

  39. Davis J, Goadrich M (2006) The relationship between precision-recall and roc curves. In: Proceedings of the 23rd International Conference on Machine Learning. pp 233–240

  40. Wistuba M, Schilling N, Schmidt-Thieme L (2015) Learning hyperparameter optimization initializations. In: 2015 IEEE International Conference on Data Science and Advanced Analytics (DSAA). IEEE, pp 1–10

  41. Gorishniy Y, Rubachev I, Babenko A (2022) On embeddings for numerical features in tabular deep learning. Adv Neural Inf Process Syst 35:24991–25004


  42. Lengerich B, Tan S, Chang C-H, Hooker G, Caruana R (2020) Purifying interaction effects with the functional anova: An efficient algorithm for recovering identifiable additive models. In: International Conference on Artificial Intelligence and Statistics. PMLR, pp 2402–2412

  43. Müller S, Toborek V, Beckh K, Jakobs M, Bauckhage C, Welke P (2023) An Empirical Evaluation of the Rashomon Effect in Explainable Machine Learning

  44. Kim H, Papamakarios G, Mnih A (2021) The lipschitz constant of self-attention. In: International Conference on Machine Learning. PMLR, pp 5562–5571

  45. Ultsch A (2005) Clustering with SOM: U*C. In: Proceedings of the Workshop on Self-Organizing Maps

  46. Paszke A, Gross S, Massa F, Lerer A, Bradbury J, Chanan G, Killeen T, Lin Z, Gimelshein N, Antiga L, et al (2019) Pytorch: an imperative style, high-performance deep learning library. Advances in neural information processing systems 32

  47. Biewald L (2020) Experiment Tracking with Weights and Biases. Software available from wandb.com. https://www.wandb.com/


Author information

Corresponding author

Correspondence to Kodjo Mawuena Amekoe.

Ethics declarations

Conflict of interest

The authors have no conflicts of interest to declare that are relevant to the content of this article.

Supplementary information

All data supporting the findings of this study are available within the paper and its Supplementary Information.

Additional information

Publisher's Note

Springer Nature remains neutral with regard to jurisdictional claims in published maps and institutional affiliations.

Appendices

Appendix A: additional theoretical results

1.1 On the Lipschitz estimate of TabSRALinear ensemble

Theorem 2

The feature attributions produced by the TabSRALinear ensemble (Eq. 9) are locally stable in the sense of Lipschitz; that is, for all \(\textbf{x} \in {\mathbbm {R}}^{p}\), there exist \(\delta >0\) and a finite \(L_{\textbf{x}}\ge 0\) such that:

$$\begin{aligned} \Vert \textbf{x} - \textbf{x}' \Vert _1 < \delta \implies \Big \Vert \sum _{h=1}^{H}\varvec{\beta }^h \odot a^h(\textbf{x}) \odot \textbf{x} - \sum _{h=1}^{H}\varvec{\beta }^h \odot a^h(\textbf{x}') \odot \textbf{x}' \Big \Vert _1 \le L_{\textbf{x}}\Vert \textbf{x} - \textbf{x}' \Vert _1 \end{aligned}$$
(15)

With \(L_{\textbf{x}} = \sum _{h=1}^{H}\Vert \varvec{\beta }^h \Vert _{\infty } \left[ \Vert \textbf{a}^h \Vert _{\infty } + L_{{\textbf{a}}}^h(\Vert \textbf{x} \Vert _{\infty }+\delta )\right]\) and \(L_{{\textbf{a}}}^h\ge 0\) the Lipschitz constant of the h-th SRA block.

Theorem 2 demonstrates that the Lipschitz estimate of the TabSRALinear ensemble is theoretically additive with respect to the number of learners in the ensemble, denoted H. Hence, a large value of H may lead to less stable explanations. Our experimental results on the default benchmark and the ablation study (Table 6) demonstrate that \(H=2\) is typically sufficient to achieve strong predictive performance.

The proof of Theorem 2 follows readily from two facts: (1) the feature attributions of each individual TabSRALinear in the ensemble are locally stable in the sense of Lipschitz, in accordance with Theorem 1; and (2) a finite sum of Lipschitz continuous functions is Lipschitz continuous, as stated in Lemma 3.
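As a numerical illustration of Theorem 2, the sketch below (our assumptions: NumPy, a hypothetical sigmoid-of-linear-map attention per learner standing in for the SRA block, and random weights) checks that the attribution change of an \(H=2\) ensemble never exceeds the stated bound \(L_{\textbf{x}}\Vert \textbf{x}-\textbf{x}'\Vert_1\) within an \(\ell_1\)-ball of radius \(\delta\).

```python
import numpy as np

rng = np.random.default_rng(0)
p, H, delta = 5, 2, 0.1

# Hypothetical SRA-like attention per learner: componentwise sigmoid of a
# linear map, so each a^h maps R^p into [0, 1]^p and is Lipschitz.
W = [rng.normal(size=(p, p)) for _ in range(H)]
beta = [rng.normal(size=p) for _ in range(H)]

sigmoid = lambda z: 1.0 / (1.0 + np.exp(-z))
attention = lambda h, x: sigmoid(W[h] @ x)

def attributions(x):
    """Ensemble feature attributions: sum_h beta^h * a^h(x) * x (Eq. 9)."""
    return sum(beta[h] * attention(h, x) * x for h in range(H))

# Lipschitz constant of a^h w.r.t. ||.||_1: sigmoid' <= 1/4, so
# ||a^h(x) - a^h(x')||_1 <= (1/4) * sum_i max_j |W_ij| * ||x - x'||_1.
L_a = [0.25 * np.abs(W[h]).max(axis=1).sum() for h in range(H)]

x = rng.normal(size=p)
# Theorem 2 bound, with ||a^h||_inf = 1 and ||D||_inf <= ||x||_inf + delta:
L_x = sum(
    np.abs(beta[h]).max() * (1.0 + L_a[h] * (np.abs(x).max() + delta))
    for h in range(H)
)

# Empirical check: the attribution change never exceeds L_x * ||x - x'||_1.
ratios = []
for _ in range(1000):
    d = rng.normal(size=p)
    d *= delta * rng.random() / np.abs(d).sum()   # enforce ||d||_1 < delta
    ratios.append(
        np.abs(attributions(x + d) - attributions(x)).sum() / np.abs(d).sum()
    )
print(max(ratios) <= L_x)
```

The bound is loose in practice (it uses worst-case sigmoid slopes), so the empirical ratio sits well below \(L_{\textbf{x}}\).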

1.2 Proof of Theorem 1

Before providing a proof of the theorem, we consider the following lemmas:

Lemma 3

  1. (1)

    The sum of two Lipschitz continuous functions is Lipschitz continuous

  2. (2)

    The product of two Lipschitz continuous and bounded functions is Lipschitz continuous

Proof

Let \(\theta\) and \(\psi\) be two functions from \({\mathbbm {R}}^{p}\) to \({\mathbbm {R}}^{p}\) that are \(L_\theta\)- and \(L_\psi\)-Lipschitz, respectively. For all \(\textbf{x}, \textbf{x}' \in {\mathbbm {R}}^p\) we have: (1)

$$\begin{aligned} \Vert (\theta +\psi )(\textbf{x})- (\theta +\psi )(\textbf{x}') \Vert _1 &= \Vert (\theta (\textbf{x})-\theta (\textbf{x}')) + (\psi (\textbf{x})-\psi (\textbf{x}'))\Vert _1 \\ &\le \Vert \theta (\textbf{x})-\theta (\textbf{x}')\Vert _1 +\Vert \psi (\textbf{x})-\psi (\textbf{x}') \Vert _1 \quad \text {(Minkowski's inequality)}\\ &\le L_\theta \Vert \textbf{x} - \textbf{x}' \Vert _1+ L_\psi \Vert \textbf{x} - \textbf{x}' \Vert _1\\ &= ( L_\theta +L_\psi )\Vert \textbf{x} - \textbf{x}' \Vert _1 \end{aligned}$$
(16)

(2) Moreover, assume that \(\theta\) and \(\psi\) are bounded, i.e., \(\Vert \theta \Vert _{\infty }<{\infty }\) and \(\Vert \psi \Vert _{\infty }<{\infty }\). Then we have:

$$\begin{aligned} \Vert (\theta \odot \psi )(\textbf{x})- (\theta \odot \psi )(\textbf{x}') \Vert _1 &= \Vert \theta (\textbf{x})\odot ( \psi (\textbf{x})-\psi (\textbf{x}')) + \psi (\textbf{x}')\odot (\theta (\textbf{x})-\theta (\textbf{x}'))\Vert _1 \\ &\le \Vert \theta (\textbf{x})\odot ( \psi (\textbf{x})-\psi (\textbf{x}')) \Vert _1 + \Vert \psi (\textbf{x}')\odot (\theta (\textbf{x})-\theta (\textbf{x}')) \Vert _1 \\ &\le \Vert \theta (\textbf{x}) \Vert _{\infty }\Vert \psi (\textbf{x})-\psi (\textbf{x}') \Vert _1+\Vert \psi (\textbf{x}') \Vert _{\infty }\Vert \theta (\textbf{x})-\theta (\textbf{x}') \Vert _1 \quad \text {(Hölder's inequality)}\\ &\le L_\psi \Vert \theta (\textbf{x}) \Vert _{\infty }\Vert \textbf{x} - \textbf{x}' \Vert _1+ L_\theta \Vert \psi (\textbf{x}') \Vert _{\infty }\Vert \textbf{x} - \textbf{x}' \Vert _1\\ &\le ( L_\theta \Vert \psi \Vert _{\infty }+ L_\psi \Vert \theta \Vert _{\infty })\Vert \textbf{x} - \textbf{x}' \Vert _1 \end{aligned}$$
(17)

\(\square\)

Lemma 4

The attention vector output by the SRA block is stable in the sense of Lipschitz, i.e., for all \(\textbf{x}, \textbf{x}' \in {\mathbbm {R}}^p\), there exists a finite constant \(L_{{\textbf{a}}}\ge 0\) such that:

$$\begin{aligned} \begin{aligned}&\Vert a(\textbf{x})- a(\textbf{x}') \Vert _1 \le L_{{\textbf{a}}}\Vert \textbf{x} - \textbf{x}' \Vert _1 \end{aligned} \end{aligned}$$
(18)

Moreover, \(\Vert a \Vert _{\infty }=1\).

Proof

Each component \(k_i^j\) of the key matrix K (resp. \(q_i^j\) of Q) is stable in the sense of Lipschitz, being the output of a fully connected layer [44] (a linear transformation followed by a common activation such as ReLU or Sigmoid), and is bounded in [0, 1] thanks to the Sigmoid activation. Hence, for all \(\textbf{x}, \textbf{x}' \in {\mathbbm {R}}^p\) there exist finite \(\alpha _{k_i^j}>0\) and \(\alpha _{q_i^j}>0\) such that \(|k_i^j(\textbf{x}) - k_i^j(\textbf{x}') |\le \alpha _{k_i^j}\Vert \textbf{x} - \textbf{x}' \Vert _1\) and \(|q_i^j(\textbf{x}) - q_i^j(\textbf{x}') |\le \alpha _{q_i^j}\Vert \textbf{x} - \textbf{x}' \Vert _1\).

Then, each component of the attention vector is also Lipschitz since:

$$\begin{aligned} |a_i(\textbf{x})- a_i(\textbf{x}') |&= \left|\frac{1}{d_k} \left(\sum _{j=1}^{d_k}{k_i^j(\textbf{x})q_i^j(\textbf{x})}-\sum _{j=1}^{d_k}{k_i^j(\textbf{x}')q_i^j(\textbf{x}')} \right)\right|\\ &= \left|\frac{1}{d_k} \sum _{j=1}^{d_k}\left({k_i^j(\textbf{x})q_i^j(\textbf{x})-{k_i^j(\textbf{x}')q_i^j(\textbf{x}')}} \right)\right|\\ &\le \frac{1}{d_k} \sum _{j=1}^{d_k} \alpha _{i}^j\Vert \textbf{x} - \textbf{x}' \Vert _1 \quad \text {(product and sum of Lipschitz functions)} \\ &= L_{a_i}\Vert \textbf{x} - \textbf{x}' \Vert _1 \end{aligned}$$
(19)

with \(L_{a_i} = \frac{1}{d_k} \sum _{j=1}^{d_k} \alpha _{i}^j\).

Finally, we have:

$$\begin{aligned} \Vert a(\textbf{x})- a(\textbf{x}') \Vert _1 &= \sum _{i=1}^{p}{|a_i(\textbf{x})- a_i(\textbf{x}') |}\\ &\le \sum _{i=1}^{p}{ L_{a_i}\Vert \textbf{x} - \textbf{x}' \Vert _1} \quad \text {(using Eq.~19)} \\ &= L_{{\textbf{a}}}\Vert \textbf{x} - \textbf{x}' \Vert _1 \end{aligned}$$
(20)

with \(L_{{\textbf{a}}} = \sum _{i=1}^{p}{L_{a_i}}\).

Since every \(a_i \in [0,1]\) we have \(\Vert {\textbf{a}} \Vert _{\infty }=1\). \(\square\)
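The boundedness claim \(a_i \in [0,1]\) used above can be sketched directly: with keys and queries produced by sigmoid-activated layers, the averaged inner product \(a_i = \frac{1}{d_k}\sum_j k_i^j q_i^j\) necessarily stays in [0, 1]. The sketch below uses hypothetical random single-layer encoders (NumPy, not the paper's PyTorch implementation).

```python
import numpy as np

rng = np.random.default_rng(1)
p, d_k = 4, 8

# Hypothetical key/query encoders: one dense layer plus sigmoid each, so every
# k_i^j(x) and q_i^j(x) lies in [0, 1], as assumed in Lemma 4.
Wk = rng.normal(size=(p, d_k, p))
Wq = rng.normal(size=(p, d_k, p))
sigmoid = lambda z: 1.0 / (1.0 + np.exp(-z))

def sra_attention(x):
    K = sigmoid(Wk @ x)          # shape (p, d_k): keys per feature
    Q = sigmoid(Wq @ x)          # shape (p, d_k): queries per feature
    return (K * Q).mean(axis=1)  # a_i = (1/d_k) sum_j k_i^j q_i^j

a = sra_attention(rng.normal(size=p))
print(a.min() >= 0.0 and a.max() <= 1.0)  # components stay in [0, 1]
```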

Proof of Theorem 1

$$\begin{aligned} |\varvec{\beta } \cdot (a(\textbf{x}) \odot \textbf{x}) - \varvec{\beta } \cdot (a(\textbf{x}') \odot \textbf{x}') |&= |\varvec{\beta } \cdot (a(\textbf{x}) \odot \textbf{x} - a(\textbf{x}') \odot \textbf{x}') |\\ &\le \Vert \varvec{\beta } \Vert _{\infty } \Vert a(\textbf{x}) \odot \textbf{x} - a(\textbf{x}') \odot \textbf{x}' \Vert _1 \end{aligned}$$
(21)

Using Eq. 17 and taking \(\psi (\textbf{x}) = a(\textbf{x})\) and \(\theta (\textbf{x}) = \textbf{x}\), we have \(\Vert \psi \Vert _{\infty } = \Vert a \Vert _{\infty } =1\) and \(\Vert \theta \Vert _{\infty } = \Vert \mathcal {D} \Vert _{\infty } =\max _{\textbf{x} \in \mathcal {D}}\Vert \textbf{x} \Vert _{\infty }\), which is the overall maximal observable feature value.

Therefore:

$$\begin{aligned} |\varvec{\beta } \cdot (a(\textbf{x}) \odot \textbf{x}) - \varvec{\beta } \cdot (a(\textbf{x}') \odot \textbf{x}') | \le \Vert \varvec{\beta } \Vert _{\infty } \left( \Vert {\textbf{a}} \Vert _{\infty } +L_{{\textbf{a}}}\Vert \mathcal {D} \Vert _{\infty } \right) \Vert \textbf{x} - \textbf{x}' \Vert _1 \end{aligned}$$
(22)

In the formulation of TabSRALinear, we are not necessarily interested in global stability or a uniform Lipschitz constant, as in Eq. 22, but rather in regional or local stability around a given target or anchor data point.

With this consideration, given the target data point \(\textbf{x}\), we can restrict \(\mathcal {D}\) to its neighborhood i.e.,

$$\begin{aligned} \mathcal {D} = \{\textbf{x}' \in {\mathbbm {R}}^{p} : \Vert \textbf{x} - \textbf{x}' \Vert _1 < \delta \} \end{aligned}$$
(23)

therefore \(\Vert \mathcal {D} \Vert _{\infty } \le \Vert \textbf{x} \Vert _{\infty } +\delta\).

Using Eq. 22, it follows that for every \(\textbf{x} \in {\mathbbm {R}}^{p}\) and every \(\delta > 0\), \(\Vert \textbf{x} - \textbf{x}' \Vert _1 < \delta\) implies:

$$\begin{aligned} \begin{aligned}&|\varvec{\beta } \cdot (a(\textbf{x}) \odot \textbf{x}) - \varvec{\beta } \cdot (a(\textbf{x}') \odot \textbf{x}') |\\& \quad \le \Vert \varvec{\beta } \Vert _{\infty } \left( 1 + L_{\textbf{a}}(\Vert \textbf{x} \Vert _{\infty } +\delta )\right) \Vert \textbf{x} - \textbf{x}' \Vert _1 \end{aligned} \end{aligned}$$
(24)

\(\square\)

Appendix B: additional empirical information

1.1 Datasets

1.1.1 Middle-scale benchmark

This benchmark of 45 datasets (59 binary classification and regression tasks) was introduced in the paper titled Why do tree-based models still outperform deep learning on typical tabular data? [16]. The main goal of this benchmark was to identify certain meta-features or inductive biases that explain the superior predictive performance of tree-based models over NNs in tabular learning. We take a step forward by incorporating inherently interpretable models in the assessment of predictive performance. We provide essential details about the datasets and refer the interested reader to the original paper [16] for further information. The main criteria for selecting datasets are:

  • The datasets contain heterogeneous features (this excludes images and signal datasets).

  • The datasets are not high dimensional. That is, the ratio p/n is below 1/10 and \(p<500\), with p the number of features and n the number of observations.

  • The data are I.I.D. Stream-like datasets or time series are removed.

  • They are real-world data.

  • The datasets are not too easy and not too small, i.e., \(p\ge 4\) and \(n \ge 3000\).

  • They are not deterministic datasets, i.e., datasets where the target is a deterministic function of the predictors.

Furthermore, in order to keep the learning as homogeneous as possible, some subproblems that deserve their own analysis are excluded. That is:

  • The size of the datasets. In the middle-scale regime, the training set is truncated to 10,000 and the test set to 50,000.

  • There is no missing data. All data points with missing values are removed.

  • Balanced classes. For classification, if there are several classes, the target is binarized by taking the two most numerous classes, and we keep half of the samples in each class.

  • Categorical features with more than 20 items are removed.

  • Numerical features with less than 10 unique values are removed.

Finally, every algorithm and hyperparameter combination is evaluated on the same random-seed Train/Validation/Test split, or fold. More precisely, 70% of the samples form the train set (or the percentage corresponding to a maximum of 10,000 samples if 70% is too high). Of the remaining 30%, 30% are used for the validation set (truncated to a maximum of 50,000 samples) and 70% for the test set (also truncated to a maximum of 50,000 samples). Depending on the size of the test set, several random-seed splits/folds are used (cross-validation). That is:

  • If the test set has more than 6000 samples, we evaluate our algorithms on one fold.

  • If the test set has between 3000 and 6000 samples, we evaluate our algorithms on two folds.

  • If the test set has between 1000 and 3000 samples, we evaluate our algorithms on three folds.

  • If the test set has fewer than 1000 samples, we evaluate our algorithms on five folds.
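The fold-selection rule can be written as a small helper. Boundary handling at exactly 1000, 3000, and 6000 samples is our assumption, as the text leaves those cases open.

```python
def n_folds(test_size: int) -> int:
    """Number of evaluation folds as a function of the test-set size."""
    if test_size > 6000:
        return 1
    if test_size > 3000:
        return 2
    if test_size > 1000:
        return 3
    return 5
```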

1.1.2 Default benchmark

  • Bank Churn: This dataset contains details of a bank’s customers; the target is a binary variable indicating whether the customer left the bank (closed their account) or remains a customer. We drop the CustomerId and Surname columns. We also drop the HasCrCard column, which is not really informative. https://www.kaggle.com/datasets/shrutimechlearn/churn-modelling.

  • Credit Default: The goal is to predict the default probability (for the next month) of credit card clients in Taiwan using historical data. https://archive.ics.uci.edu/ml/datasets/default+of+credit+card+clients. We drop the SEX column as generally, the gender information is not used in credit scoring.

  • Credit Card Fraud: For this dataset, the aim is to predict whether credit card transactions are fraudulent or genuine. The original dataset is PCA-transformed for privacy purposes. We drop the time information in our study. https://www.kaggle.com/mlg-ulb/creditcardfraud

  • Heloc Fico: For this dataset, the target variable to predict is a binary variable called RiskPerformance. The value “Bad” indicates that a consumer was 90 days past due or worse at least once over a period of 24 months from when the credit account was opened. The value “Good” indicates that they have made their payments without ever being more than 90 days overdue. https://community.fico.com/s/explainable-machine-learning-challenge.
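For the Bank Churn dataset, the column handling described above, combined with the optional NumOfProducts binarization from Note 5, might look like the following sketch. Column names follow the linked Kaggle file; the function itself is illustrative, not the paper's exact pipeline.

```python
import pandas as pd

def prepare_bank_churn(df: pd.DataFrame, binarize_products: bool = True) -> pd.DataFrame:
    # Identifier and weakly informative columns are dropped, as described above.
    df = df.drop(columns=["CustomerId", "Surname", "HasCrCard"])
    if binarize_products:
        # Note 5's variant: one product versus more than one product.
        df["NumOfProducts"] = (df["NumOfProducts"] > 1).astype(int)
    return df
```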

1.2 Additional results for TabSRAs: visualization

How raw data are reinforced using the SRA block

In addition to Sect. 3.1, we include two more examples, the 2D ChainLink [45] and the noisy two moons, as depicted in Figs. 10 and 11. By applying SRA coefficients to these datasets, we obtain a new data representation that enables easy separation of the classes, as shown in Fig. 10b. Even without knowledge of the true data-generating process, it is apparent that all observations have been moved strategically so that a simple rule can effectively isolate nearly all yellow observations of interest.

Fig. 10

ChainLink 2D: 1000 data points

Fig. 11

Noisy two moons: 10,000 data points

1.3 Additional results on the applicative case studies

1.3.1 Bank churn modeling

Fig. 12

Bank churn modeling: interaction between the Age and IsActiveMember features. b XGBoost refers to the XGBoost+TreeSHAP feature attribution solution. c Shows the main effect of Age on the churn score for both EBM_S and EBM, and d shows the pairwise interaction for the latter


As shown in Fig. 12, the effect of the Age feature on the churn risk is bell-shaped according to TabSRALinear, XGBoost+TreeSHAP, and EBMs. Moreover, there is a strong interaction between the Age and IsActiveMember features, highlighted by the important influence of age (around 60) on the churn score for non-active members.

To handle this interaction, the GAM-based inherently interpretable solution EBM needs to break it down into a main effect and an interaction effect, which is not necessary with TabSRALinear. Overall, these findings explain the poor predictive performance of logistic regression (LR) highlighted in Sect. 4.4.1.

1.3.2 Credit card default


In Sect. 4.4.2, we demonstrated on the credit card default dataset that TabSRALinear can generate more concise explanations compared to XGBoost+TreeSHAP. As depicted in Fig. 13, EBMs also spread contributions among correlated features. As a result, their feature attributions are less sparse than those of TabSRALinear, particularly for EBMs with pairwise interactions, which may assign nonzero feature attributions to interaction terms as well as main effects.

Fig. 13

Individual prediction understanding for the credit card default dataset

1.4 Additional results for the robustness study


The output of piecewise-constant approximators (for instance, EBMs and XGBoost) is generally more sensitive to input perturbations than that of linear models (LR) and TabSRALinear, as depicted in Fig. 14. We argue that this is due to their flexibility in producing discontinuities or learning irregular functions [16, 24].
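The intuition that discontinuities amplify input perturbations can be illustrated with a toy 1-D sketch (ours, not the paper's experimental protocol): a linear function and a piecewise-constant step with the same output range, perturbed with small Gaussian noise.

```python
import numpy as np

rng = np.random.default_rng(0)

# Two toy approximators sharing the output range [0, 1]:
linear = lambda x: 0.5 * (x + 1.0)            # continuous, 0.5-Lipschitz
step = lambda x: np.where(x < 0.0, 0.0, 1.0)  # piecewise constant, jump at 0

x = rng.uniform(-1.0, 1.0, size=10_000)
eps = rng.normal(scale=0.05, size=x.shape)    # small input perturbation

# Worst-case prediction change: bounded by 0.5 * max|eps| for the linear
# model, but equal to the full jump size whenever the threshold is crossed.
worst_linear = np.abs(linear(x + eps) - linear(x)).max()
worst_step = np.abs(step(x + eps) - step(x)).max()
print(worst_step > worst_linear)
```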

Fig. 14

Change in predictions under input perturbations (Sect. 4.2.3). LR = Logistic Regression, SRA = TabSRALinear, XGB_SHAP = XGBoost+TreeSHAP

1.5 Implementation details for the predictive performance evaluation

1.5.1 Search space for hyperparameters

For full-complexity models, the hyperparameter search spaces are derived from the previous study [16]. For tree-based models, the number of estimators (trees) is not tuned but set to a high value: 250 for Random Forest (RF) and 1000 for XGBoost and CatBoost. For the latter, early stopping is employed with a patience of 20. Default parameters for tree-based models are the Scikit-Learn/XGBoost/CatBoost defaults.
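The fixed tree counts above translate into instantiations like the following sketch. Only the scikit-learn model is shown executed; the commented lines assume the standard xgboost/catboost Python APIs and are not the paper's exact configuration.

```python
from sklearn.ensemble import RandomForestClassifier

# The number of trees is fixed rather than tuned; other settings are defaults.
rf = RandomForestClassifier(n_estimators=250, random_state=0)

# The boosted models would be configured analogously, e.g.:
#   xgboost.XGBClassifier(n_estimators=1000, ...)
#   catboost.CatBoostClassifier(iterations=1000, early_stopping_rounds=20, ...)
print(rf.n_estimators)
```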

For neural nets (NNs), the maximal number of epochs is set to 300 with early stopping and checkpointing (the best model on the validation set is kept). The early-stopping patience is 40 for MLP, ResNet, TabSRALinear, and Linear, 10 for SAINT, and 50 for EBMs. Note that for models using early stopping, 20% of the training dataset is used as a validation set (distinct from the validation part of the Train/Validation/Test split).

For NNs, we used the PyTorch library [46] and the Weights & Biases platform [47] for hyperparameter optimization and experiment tracking (Tables 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18 and 19).

Table 8 Decision Tree (DT)
Table 9 Linear Models (LR)
Table 10 TabSRALinear
Table 11 EBM_S
Table 12 EBM
Table 13 Random Forest (RF)
Table 14 XGBoost
Table 15 CatBoost
Table 16 MLP
Table 17 ResNet
Table 18 FT-Transformer
Table 19 SAINT

1.6 Additional results about the predictive performance

1.6.1 Results of hyperparameter optimization using all iterations

In this part, we present the results of hyperparameter tuning using all iterations instead of bootstrapping, as described in Sect. 4.1.3.

As indicated in Table 20, utilizing all iterations once instead of employing bootstrapping does not change the predictive ranking of the algorithms or our conclusions (Tables 21, 22, 23, 24, 25, 26, and 27).

Table 20 Predictive performance of models across 59 tasks (45 datasets) using all random hyperparameters search iterations

1.6.2 Performance per dataset

In this section, we provide detailed individual results for each task in the middle-scale benchmark.

\(R^2\) is reported as the predictive performance metric for regression tasks, and accuracy is reported for classification tasks.

Table 21 Regression tasks with numerical features only
Table 22 Regression tasks with numerical features only
Table 23 Regression tasks with heterogeneous features
Table 24 Regression tasks with heterogeneous features
Table 25 Classification tasks with numerical features only
Table 26 Classification tasks with numerical features only
Table 27 Classification tasks with heterogeneous features

Rights and permissions

Springer Nature or its licensor (e.g. a society or other partner) holds exclusive rights to this article under a publishing agreement with the author(s) or other rightsholder(s); author self-archiving of the accepted manuscript version of this article is solely governed by the terms of such publishing agreement and applicable law.


About this article


Cite this article

Amekoe, K.M., Azzag, H., Dagdia, Z.C. et al. Exploring accuracy and interpretability trade-off in tabular learning with novel attention-based models. Neural Comput & Applic 36, 18583–18611 (2024). https://doi.org/10.1007/s00521-024-10163-9

Download citation


  • DOI: https://doi.org/10.1007/s00521-024-10163-9
