Learning to Reweight for Generalizable Graph Neural Network

Authors

  • Zhengyu Chen, Institute of Artificial Intelligence, Zhejiang University; Shanghai Institute for Advanced Study, Zhejiang University
  • Teng Xiao, The Pennsylvania State University
  • Kun Kuang, Institute of Artificial Intelligence, Zhejiang University; Shanghai Institute for Advanced Study, Zhejiang University
  • Zheqi Lv, Institute of Artificial Intelligence, Zhejiang University; Shanghai Institute for Advanced Study, Zhejiang University
  • Min Zhang, Institute of Artificial Intelligence, Zhejiang University; Shanghai Institute for Advanced Study, Zhejiang University
  • Jinluan Yang, Institute of Artificial Intelligence, Zhejiang University; Shanghai Institute for Advanced Study, Zhejiang University
  • Chengqiang Lu, DAMA Academy, Alibaba Group
  • Hongxia Yang, DAMA Academy, Alibaba Group
  • Fei Wu, Institute of Artificial Intelligence, Zhejiang University; Shanghai Institute for Advanced Study, Zhejiang University

DOI:

https://doi.org/10.1609/aaai.v38i8.28673

Keywords:

DMKM: Graph Mining, Social Network Analysis & Community

Abstract

Graph Neural Networks (GNNs) show promising results for graph tasks. However, the generalization ability of existing GNNs degrades when there are distribution shifts between training and testing graph data. The fundamental reason for this severe degradation is that most GNNs are designed under the I.I.D. hypothesis. In such a setting, GNNs tend to exploit subtle statistical correlations in the training set for prediction, even when those correlations are spurious. In this paper, we study the generalization ability of GNNs in Out-Of-Distribution (OOD) settings. To address this problem, we propose Learning to Reweight for Generalizable Graph Neural Network (L2R-GNN), which enhances generalization so as to achieve satisfactory performance on unseen testing graphs whose distributions differ from those of the training graphs. We propose a novel nonlinear graph decorrelation method, which substantially improves out-of-distribution generalization and, compared with previous methods, better prevents the sample size from being over-reduced. The variables of the graph representation are clustered based on the stability of their correlations, and the graph decorrelation method learns weights that remove correlations between variables in different clusters, rather than between every pair of variables. In addition, we introduce an effective stochastic algorithm based on bi-level optimization for the L2R-GNN framework, which learns the optimal weights and the GNN parameters simultaneously and avoids over-fitting. Experiments show that L2R-GNN greatly outperforms baselines on various graph prediction benchmarks under distribution shifts.
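The cluster-restricted reweighting idea from the abstract can be illustrated with a minimal sketch. The goal is to learn positive sample weights that shrink the weighted covariance between variables in *different* clusters, while leaving same-cluster pairs unpenalized. This is not the paper's method. It rests on the following assumptions:

- `Z` stands in for fixed graph representations (one row per graph).
- The cluster labels are given, rather than derived from correlation stability.
- Decorrelation is measured with plain linear covariance instead of the paper's nonlinear measure.
- The bi-level outer loop that jointly updates GNN parameters is omitted.

All function names (`weighted_cov`, `cross_pairs`, `learn_weights`, and so on) are hypothetical.

```python
import math
import random


def weighted_mean(Z, w, j):
    """Weighted mean of column j; weights are kept at mean 1, so divide by n."""
    return sum(wi * z[j] for wi, z in zip(w, Z)) / len(Z)


def weighted_cov(Z, w, j, k):
    """Weighted covariance between columns j and k under sample weights w."""
    mu_j, mu_k = weighted_mean(Z, w, j), weighted_mean(Z, w, k)
    return sum(wi * (z[j] - mu_j) * (z[k] - mu_k) for wi, z in zip(w, Z)) / len(Z)


def cross_pairs(clusters):
    """Feature pairs (j, k), j < k, that belong to different clusters."""
    d = len(clusters)
    return [(j, k) for j in range(d) for k in range(j + 1, d)
            if clusters[j] != clusters[k]]


def softmax_weights(theta):
    """Positive sample weights that sum to n (mean weight 1)."""
    m = max(theta)
    e = [math.exp(t - m) for t in theta]
    s = sum(e)
    return [len(theta) * x / s for x in e]


def decorrelation_loss(Z, w, pairs):
    """Sum of squared weighted covariances over cross-cluster pairs only."""
    return sum(weighted_cov(Z, w, j, k) ** 2 for j, k in pairs)


def learn_weights(Z, clusters, steps=200, lr=1.0):
    """Gradient descent on softmax logits with a backtracking step size."""
    n = len(Z)
    pairs = cross_pairs(clusters)
    theta = [0.0] * n                      # start from uniform weights
    w = softmax_weights(theta)
    loss = decorrelation_loss(Z, w, pairs)
    for _ in range(steps):
        # dL/dw_i = sum over pairs of 2 * C_jk * (z_ij - mu_j) * (z_ik - mu_k) / n
        g = [0.0] * n
        for j, k in pairs:
            mu_j, mu_k = weighted_mean(Z, w, j), weighted_mean(Z, w, k)
            c = weighted_cov(Z, w, j, k)
            for i, z in enumerate(Z):
                g[i] += 2.0 * c * (z[j] - mu_j) * (z[k] - mu_k) / n
        # Chain rule through w = n * softmax(theta).
        p = [wi / n for wi in w]
        pg = sum(pi * gi for pi, gi in zip(p, g))
        grad = [n * pi * (gi - pg) for pi, gi in zip(p, g)]
        # Backtracking: accept a step only if it lowers the loss.
        step = lr
        while step > 1e-8:
            cand = [t - step * gt for t, gt in zip(theta, grad)]
            w_cand = softmax_weights(cand)
            loss_cand = decorrelation_loss(Z, w_cand, pairs)
            if loss_cand < loss:
                theta, w, loss = cand, w_cand, loss_cand
                break
            step /= 2
        else:
            break                          # no improving step: stop early
    return w, loss


# Toy data: feature 2 spuriously tracks feature 0, but they sit in different clusters.
rng = random.Random(0)
Z = []
for _ in range(40):
    a = rng.gauss(0, 1)
    Z.append([a, rng.gauss(0, 1), a + 0.5 * rng.gauss(0, 1), rng.gauss(0, 1)])
clusters = [0, 0, 1, 1]

loss_before = decorrelation_loss(Z, [1.0] * len(Z), cross_pairs(clusters))
weights, loss_after = learn_weights(Z, clusters)
print(f"cross-cluster loss: {loss_before:.4f} -> {loss_after:.4f}")
```

Penalizing only cross-cluster pairs places fewer constraints on the weights than decorrelating every pair of variables. This is one plausible reading of how the approach limits the over-reduction of sample size that the abstract mentions.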

Published

2024-03-24

How to Cite

Chen, Z., Xiao, T., Kuang, K., Lv, Z., Zhang, M., Yang, J., Lu, C., Yang, H., & Wu, F. (2024). Learning to Reweight for Generalizable Graph Neural Network. Proceedings of the AAAI Conference on Artificial Intelligence, 38(8), 8320-8328. https://doi.org/10.1609/aaai.v38i8.28673

Issue

Vol. 38 No. 8

Section

AAAI Technical Track on Data Mining & Knowledge Management