School of Information Science and Engineering, Hebei North University, Zhangjiakou, China
{zs1997,zhfzhu,zhzliu,zhyguo,19112005,yzhao}@bjtu.edu.cn
Multi-modal Graph Learning for Disease Prediction
Abstract
Benefiting from the powerful expressive capability of graphs, graph-based approaches have achieved impressive performance in various biomedical applications. Most existing methods tend to define the adjacency matrix among samples manually based on meta-features, and then obtain the node embeddings for downstream tasks by Graph Representation Learning (GRL). However, it is not easy for these approaches to generalize to unseen samples, and the complex correlations between modalities are ignored. As a result, these factors inevitably limit the ability to provide valid information about the patient's condition for a reliable diagnosis. In this paper, we propose an end-to-end Multi-modal Graph Learning framework (MMGL) for disease prediction. To effectively exploit the rich information across the multiple modalities associated with diseases, a modal-attentional multi-modal fusion is proposed to integrate the features of each modality by leveraging the correlation and complementarity between modalities. Furthermore, instead of defining the adjacency matrix manually as in existing methods, the latent graph structure is captured through a novel adaptive graph learning mechanism, which can be jointly optimized with the prediction model and thus reveals the intrinsic connections among samples. Unlike previous transductive methods, our model is also applicable to inductive learning on unseen data. An extensive group of experiments on two disease prediction problems is carefully designed and presented, demonstrating that MMGL obtains more favorable performance. In addition, we visualize and analyze the learned graph structure to provide more reliable decision support for doctors in real medical applications and inspiration for disease research.
Keywords: Multi-modal · Disease prediction · Graph learning · Feature fusion

1 Introduction
A large amount of relational data exists in the biomedical field. As a general description of relational data, graphs are well suited to model a variety of biomedical scenarios [21]. Recently, inspired by the excellent performance of graph-based methods in machine learning, particularly graph convolutional networks (GCNs) [17, 24, 9], such methods have also been applied to handle relational data in various Computer Aided Diagnosis (CADx) tasks, such as Alzheimer's disease prediction [18], autism prediction [5], and cancer prognosis prediction [20].

Most existing methods construct patient relationship graphs from existing features through pre-defined similarity measures, then apply GCNs to aggregate patient features over local neighborhoods and give the prediction results. These methods can be broadly classified into two categories: single-graph-based methods and multi-graph-based methods. Parisot et al. [19] proposed to compute patient similarities from a set of meta-features such as age and sex, thus constructing the adjacency matrix for applying GCNs. [13, 27] employed the same graph construction rules and initially explored the effect of graph structure on disease prediction performance by setting different neighborhood sizes. These methods simply combine the imaging and non-imaging modalities through graph construction and GCNs; however, they fail to effectively mine the intrinsic information of each modality. Thus, several recent works have proposed to construct multiple graphs in parallel, in which each graph is built from a different modality, and then integrate the embeddings learned from the different graphs for prediction. As shown in Fig. 1(a), [8] concatenated the node embeddings directly, while both [12] and [22] adopted an attention-based fusion mechanism as in Fig. 1(b).
Although the above methods have achieved remarkable performance, three key issues remain to be considered regarding the application of GCNs to disease prediction and other biomedical tasks:
(1) Insufficient inter-modal relationship mining. Each modality provides different information for the diagnosis of a disease, which is both complementary and potentially redundant. However, the concatenation [4, 8, 10] or intra-modal attention mechanisms [12, 22] adopted in previous studies struggle to capture the latent inter-modal correlation, which may bias the learned features towards a single modality. To solve this issue, we propose a modal-attentional multi-modal fusion to mine the intrinsic relevance between modalities while preserving the individuality of each modality.
(2) Defining the graph adjacency matrix manually. Existing single-graph-based [19, 13, 25] and multi-graph-based methods [14, 8, 22] both construct the graph through hand-designed similarity measures, which require careful tuning and are thus difficult to generalize to downstream tasks. A better approach is to learn the graph in an end-to-end manner, but little attention has been paid to graph structure learning [26, 11], especially in the field of medicine [4]. Meanwhile, the sigmoid-based graph learning mechanism in [4] is prone to vanishing gradients, which makes model training unstable. Thus, we propose a learning-based adaptive approach to learn the graph structure dynamically. It provides a more feasible way to build graphs for downstream tasks and reveals the latent connections among samples.
(3) Hardly applicable to inductive learning. For approaches based on spectral graph convolution such as [10, 13, 19], it is hard to generalize to unseen samples. Besides, to accommodate inductive learning, it is essential but cumbersome for multi-graph-based methods [14, 8, 22] to measure the relationships of unseen samples on each graph. Unlike these approaches, our MMGL can be flexibly extended to the inductive learning scenario.
In this paper, we propose a novel Multi-modal Graph Learning framework (MMGL) for disease prediction, and the main contributions can be highlighted in the following aspects:
• As a flexible modular inductive learning framework, the proposed MMGL provides substantial improvements and inspirations for the application of GCNs to disease prediction tasks.
• Considering the correlation and complementarity between modalities, we propose a modal-attentional feature fusion approach (MaFF) that exploits inter-modal relevance to integrate multi-modal features.
• To reveal the intrinsic connections among samples, a novel adaptive graph learning mechanism (AGL) is proposed to achieve learnable graph construction, thus obtaining a latent robust graph for downstream tasks.
• The comparable or even significant improvements over state-of-the-art methods indicate the advantages of our MMGL for disease prediction tasks.
2 Methodology
2.1 Problem Formulation
Let $X = [x_1, \dots, x_N]^{\top} \in \mathbb{R}^{N \times d}$ denote the $d$-dimensional multi-modal features of $N$ patients and $Y = \{y_1, \dots, y_N\}$ the corresponding labels. Each patient is represented by $M$ modalities, so the multi-modal features can also be denoted as $X = [X^{(1)}, \dots, X^{(M)}]$ and $x_i = [x_i^{(1)}, \dots, x_i^{(M)}]$, where $X^{(m)} \in \mathbb{R}^{N \times d_m}$ represents the $d_m$-dimensional features of the $m$-th modality and $\sum_{m=1}^{M} d_m = d$. Given the multi-modal features $X$, the issue we consider in this paper for multi-modal graph learning is to infer an optimal latent graph structure based on multi-modal feature fusion that accounts for inter-modal correlation, thus providing reliable graph support to GCNs for disease prediction or other biomedical tasks.
As illustrated in Fig. 2, the overall framework of MMGL consists of three modules. In the feature fusion phase, the multi-modal features $X$ are integrated into a fused feature matrix $\hat{X}$ by the modal-attentional feature fusion module (MaFF). Then, in the graph learning phase, the adjacency matrix $A$ characterizing the latent graph structure is captured through a novel adaptive graph learning mechanism (AGL) based on the modal-fused features. Finally, according to the learned adjacency matrix $A$ and the fused feature matrix $\hat{X}$, the disease prediction results are obtained by a GCN.

2.2 Modal-attentional Feature Fusion
In a real diagnostic scenario, medical experts always need to analyze various multi-modal data of the patient to make a reliable decision, since single-modal data cannot provide enough information for an accurate diagnosis. Similarly, reliable Computer Aided Diagnosis (CADx) also requires the capability of leveraging the complementarity of multi-modal data [15]. To capture inter-modal relevance and complementarity for efficient information integration, we propose a transformer-style [23] modal-attentional feature fusion module named MaFF, which fuses the features of each modality while preserving rich information.
Given the query vectors of the current modality $m$ and the key-value pairs of the other modalities, the inter-modal attention scores can be calculated. In practice, considering space-efficiency and parallelization, the scaled dot-product is chosen as the attention function [23]. Since each modal feature has a different dimension, as shown in Fig. 1(c), we first use projection matrices (i.e., $W_Q^{(m)}$, $W_K^{(m)}$, $W_V^{(m)}$) with a fixed output dimension $d_k$ to translate each $x_i^{(m)}$ into a query, key, and value vector, which facilitates the computation of attention scores in the same dimensional space. For a specific patient $i$, we can obtain the corresponding query matrix $Q_i \in \mathbb{R}^{M \times d_k}$, key matrix $K_i \in \mathbb{R}^{M \times d_k}$, and value matrix $V_i \in \mathbb{R}^{M \times d_k}$. Then, the inter-modal attention score map $S_i$ for patient $i$ is computed as:
$S_i = \mathrm{softmax}\left( Q_i K_i^{\top} / \sqrt{d_k} \right)$ (1)
where $S_i^{(p,q)}$ denotes the attention score between modality $p$ and modality $q$, and $\sqrt{d_k}$ is the scaling factor that controls the hardness of the attention, set as in [23]. After that, we perform cross-modal aggregation of the value vectors of each modality based on the calculated inter-modal correlations $S_i$, so the integrated feature $\hat{x}_i$ of patient $i$ can be computed as:
$\hat{x}_i = f_2\left( f_1\left( S_i V_i + V_i \right) \right)$ (2)
where $f_1$ and $f_2$ are projection layers. Besides, the residual connection between $S_i V_i$ and the initial value matrix $V_i$ effectively alleviates the vanishing gradient problem during training.
Compared to the fusion modules in [12, 4, 8], the modal-attentional fusion focuses on inter-modal correlation through a multi-modal attention mechanism. As a consequence, it tends to optimally combine the complementary information from different modalities. It is worth noting that the multi-modal attention map is patient-specific, which also makes it applicable to inductive learning, and a global-level inter-modal correlation map can be obtained by averaging over all patients. Besides, the proposed multi-modal attention mechanism can easily be scaled to a multi-head version.
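To make the fusion step concrete, below is a minimal PyTorch sketch of MaFF under the formulation above; the class interface, the ReLU between the two projection layers, and the flattening of the per-modality fused features are our assumptions rather than the authors' released implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MaFF(nn.Module):
    """Modal-attentional feature fusion (sketch).

    Each modality is projected into a shared d_k-dimensional space; a
    patient-specific M x M inter-modal attention map (Eq. 1) then mixes
    the value vectors, with a residual connection before the two
    projection layers (Eq. 2).
    """

    def __init__(self, modal_dims, d_k):
        super().__init__()
        # Per-modality Q/K/V projections, since modality dimensions differ.
        self.w_q = nn.ModuleList([nn.Linear(d, d_k) for d in modal_dims])
        self.w_k = nn.ModuleList([nn.Linear(d, d_k) for d in modal_dims])
        self.w_v = nn.ModuleList([nn.Linear(d, d_k) for d in modal_dims])
        # Two projection layers f1, f2 applied after the residual connection.
        self.proj = nn.Sequential(nn.Linear(d_k, d_k), nn.ReLU(),
                                  nn.Linear(d_k, d_k))
        self.d_k = d_k

    def forward(self, feats):
        # feats: list of M tensors, each of shape (N, d_m).
        Q = torch.stack([w(x) for w, x in zip(self.w_q, feats)], dim=1)  # (N, M, d_k)
        K = torch.stack([w(x) for w, x in zip(self.w_k, feats)], dim=1)
        V = torch.stack([w(x) for w, x in zip(self.w_v, feats)], dim=1)
        # Patient-specific inter-modal attention map S_i (Eq. 1).
        S = F.softmax(Q @ K.transpose(1, 2) / self.d_k ** 0.5, dim=-1)  # (N, M, M)
        # Cross-modal aggregation with residual connection (Eq. 2).
        fused = self.proj(S @ V + V)      # (N, M, d_k)
        return fused.flatten(1)           # (N, M*d_k) fused features
```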
2.3 Adaptive Graph Structure Learning
Based on existing graph structures, the GCNs in [17, 24, 6] learn node representations for downstream tasks through neighborhood aggregation or spectral convolution. However, it is non-trivial to obtain an available graph for some specific tasks in the biomedical field. Therefore, the graph learning problem often needs to be considered when applying GCNs to biomedical tasks.
Graph learning is usually modeled in one of two forms: (i) learning a joint discrete probability distribution on the edges of the graph [7], or (ii) learning a similarity metric between nodes. Since the former is non-differentiable and hardly applicable to inductive learning, we consider the graph learning problem from the perspective of similarity metric learning between nodes. Some previous methods have adopted the radial basis function (RBF) kernel [22, 27], cosine similarity [10], or threshold-based metrics (for discrete features) [19, 13] as the similarity metric. However, these approaches still require careful manual tuning to construct a meaningful graph structure for downstream GCNs. Therefore, as illustrated in Fig. 2, we propose a simple but effective learnable metric function, which can be jointly optimized with the downstream GCNs:
$A_{ij} = \cos\left( W \hat{x}_i, W \hat{x}_j \right) = \dfrac{(W \hat{x}_i)^{\top} (W \hat{x}_j)}{\lVert W \hat{x}_i \rVert_2 \, \lVert W \hat{x}_j \rVert_2}$ (3)
where $W$ is a learnable weight matrix and $A_{ij}$ is computed as the weighted cosine similarity between patients $i$ and $j$. Since there are few uni-directional effects between patients except for epidemics, the learned adjacency matrix $A$ is symmetric, which is also in accordance with the expectations for a realistic population graph of patients.
Commonly, a realistic adjacency matrix is non-negative and sparse. However, since $A$ describes a fully connected graph, which is computationally expensive, and its elements range in $[-1, 1]$, we capture a non-negative sparse graph by applying the ReLU function, i.e., setting the negative elements of $A$ to zero. Finally, the latent graph $A$ is obtained for downstream tasks.
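A minimal sketch of AGL under Eq. 3, assuming the fused features from MaFF as input; the module name and hidden dimension are illustrative:

```python
class AGL(nn.Module):
    """Adaptive graph learning (sketch): weighted cosine similarity (Eq. 3)
    sparsified with ReLU."""

    def __init__(self, in_dim, hid_dim):
        super().__init__()
        self.W = nn.Linear(in_dim, hid_dim, bias=False)  # learnable metric weights

    def forward(self, x):
        # x: (N, in_dim) fused patient features from MaFF.
        h = F.normalize(self.W(x), p=2, dim=1)  # unit rows -> cosine similarity
        A = h @ h.t()                           # symmetric, entries in [-1, 1]
        return F.relu(A)                        # zero out negatives -> sparse, non-negative
```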
2.4 Model Optimization
Graph Regularization. Due to the sensitivity of GCNs to the graph structure, graph learning has a significant impact on the performance of GCNs in downstream tasks. Furthermore, constraints on the sparsity, connectivity, and smoothness of the learned graph are also important for adaptive graph learning [3]. Here, the Dirichlet energy is used to measure the smoothness of the set of graph signals $\hat{X} = \{\hat{x}_1, \dots, \hat{x}_N\}$ on the learned graph $A$:
$\mathcal{L}_{smooth} = \dfrac{1}{2N^2} \sum_{i,j=1}^{N} A_{ij} \left\lVert \hat{x}_i - \hat{x}_j \right\rVert_2^2$ (4)
It can be seen from Eq. 4 that the smaller the distance between $\hat{x}_i$ and $\hat{x}_j$, the larger $A_{ij}$ is allowed to be when minimizing $\mathcal{L}_{smooth}$. Hence, the smooth loss encourages connections between similar nodes, i.e., it enforces smoothness of the graph signals on the learned graph $A$. Essentially, $\mathcal{L}_{smooth}$ simultaneously serves to control the sparsity of $A$ [11]. However, utilizing only $\mathcal{L}_{smooth}$ may lead to the trivial solution (i.e., $A = \mathbf{0}$). To avoid it, two additional regularization terms following [11] are imposed on $A$:
$\mathcal{L}_{con} = -\dfrac{1}{N} \mathbf{1}^{\top} \log(A \mathbf{1}), \qquad \mathcal{L}_{sparse} = \dfrac{1}{N^2} \lVert A \rVert_F^2$ (5)
where the first term uses a logarithmic barrier on the node degrees to control the connectivity of $A$, and the second term is a regularization term to avoid the excessive sparseness caused by $\mathcal{L}_{smooth}$. The total graph regularization loss is defined as $\mathcal{L}_{graph} = \lambda_1 \mathcal{L}_{smooth} + \lambda_2 \mathcal{L}_{con} + \lambda_3 \mathcal{L}_{sparse}$.
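The regularization terms of Eqs. 4-5 can be computed directly from the learned adjacency and the fused features; the following is a hedged sketch in which the weights lam1-lam3 stand in for the hyper-parameters $\lambda_1$-$\lambda_3$:

```python
def graph_regularization(A, x, lam1=1.0, lam2=1.0, lam3=1.0):
    """Sketch of the graph regularization loss (Eqs. 4-5);
    A: (N, N) learned adjacency, x: (N, d) fused features."""
    N = A.size(0)
    # Dirichlet energy (Eq. 4): penalizes large weights between distant nodes.
    dist = torch.cdist(x, x).pow(2)                   # (N, N) squared distances
    l_smooth = (A * dist).sum() / (2 * N ** 2)
    # Log-barrier on node degrees (Eq. 5, first term): encourages connectivity;
    # the small epsilon guards against zero degrees after ReLU sparsification.
    l_con = -torch.log(A.sum(dim=1) + 1e-8).sum() / N
    # Frobenius term (Eq. 5, second term): counteracts excessive sparseness.
    l_sparse = A.pow(2).sum() / N ** 2
    return lam1 * l_smooth + lam2 * l_con + lam3 * l_sparse
```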
2.4.1 Loss Function
Based on the integrated feature matrix $\hat{X}$ and the learned sparse graph $A$, we use a GCN to give the predicted results for the patients. Unlike [4, 10], which optimize the graph structure directly based on the task-aware prediction loss alone, we use a joint loss function to guide the optimization of all three modules of MMGL simultaneously:
$\mathcal{L} = \mathcal{L}_{task} + \mathcal{L}_{graph}$ (6)
where $\mathcal{L}_{task}$ and $\mathcal{L}_{graph}$ denote the task-aware loss and the graph regularization loss respectively, and $\lambda_1$, $\lambda_2$, and $\lambda_3$ within $\mathcal{L}_{graph}$ are hyper-parameters to balance the three regularization terms. For disease prediction tasks treated as classification problems, $\mathcal{L}_{task}$ is set to the cross-entropy loss.
3 Experimental Results and Analysis
In this section, we evaluate the performance of MMGL on two biomedical datasets. We first detail our experimental protocol, and then present the comparison results of MMGL with the state-of-the-art methods in disease prediction tasks.
3.1 Datasets and Preprocessing
TADPOLE [18]: As a subset of the Alzheimer's Disease Neuroimaging Initiative (ADNI) database (adni.loni.usc.edu), the TADPOLE dataset contains features extracted from multiple modalities, including MR, PET, cognitive tests, cerebrospinal fluid (CSF) biomarkers, risk factors, clinical examinations, and demographic information. For Alzheimer's disease prediction, we select 685 patients with 366-dimensional multi-modal features from TADPOLE, divided into 245 normal, 360 Mild Cognitive Impairment (MCI), and 80 Alzheimer's Disease (AD) patients. We then divide the features according to the corresponding modalities and fill missing values with the feature means.
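As an illustration of this preprocessing step, the snippet below splits the 366-dimensional feature matrix by modality and imputes missing values with the per-feature mean; the function name and the modal_dims layout are assumptions, since the exact feature ordering is not specified here:

```python
import numpy as np

def split_and_impute(X, modal_dims):
    """X: (N, 366) array with NaNs marking missing entries;
    modal_dims: per-modality dimensions summing to 366 (assumed layout)."""
    col_means = np.nanmean(X, axis=0)         # per-feature means, ignoring NaNs
    X = np.where(np.isnan(X), col_means, X)   # mean imputation
    splits = np.cumsum(modal_dims)[:-1]
    return np.split(X, splits, axis=1)        # list of (N, d_m) modality blocks
```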
ABIDE [1]: The Autism Brain Imaging Data Exchange (ABIDE) collected resting-state functional magnetic resonance imaging (R-fMRI) data and corresponding phenotypic data of about 1000 subjects from 20 different sites. For autism prediction, we select 871 patients, divided into 468 normal and 403 Autism Spectrum Disorder (ASD) patients. For a fair comparison, we follow the preprocessing steps in [19].
3.1.1 Implementation details
Without loss of generality, the standard GCN [17] is adopted as the prediction module, though it can be replaced by other GCNs. To stabilize the training process, we use a modular iterative training strategy, in which a complete training epoch jointly trains MaFF and AGL once, and then jointly trains AGL and the GCN once. The model uses Adam [16] as the optimizer and is implemented on the PyTorch platform. For hyper-parameter tuning, we use 4 attention heads in MaFF, and the other hyper-parameters are tuned with hyperopt [2].
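A sketch of one epoch under this modular iterative strategy is given below; the module interfaces (e.g., gcn(x, A)), the optimizer split, and the reuse of graph_regularization from the earlier sketch are assumptions for illustration:

```python
def train_epoch(maff, agl, gcn, feats, labels, opt_fusion, opt_graph):
    # Phase 1: jointly update MaFF and AGL (opt_fusion holds their parameters).
    x = maff(feats)
    A = agl(x)
    loss = F.cross_entropy(gcn(x, A), labels) + graph_regularization(A, x)
    opt_fusion.zero_grad(); loss.backward(); opt_fusion.step()

    # Phase 2: jointly update AGL and the GCN on the refreshed fused features.
    x = maff(feats).detach()   # MaFF is frozen in this phase
    A = agl(x)
    loss = F.cross_entropy(gcn(x, A), labels) + graph_regularization(A, x)
    opt_graph.zero_grad(); loss.backward(); opt_graph.step()
    return loss.item()
```

Here opt_fusion would be an Adam optimizer over the MaFF and AGL parameters, and opt_graph one over the AGL and GCN parameters, mirroring the alternation described above.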
3.2 Performance Comparisons
3.2.1 Baselines.
To evaluate the performance of MMGL, we compare with several baselines, especially those that have recently achieved state-of-the-art results in disease prediction tasks. PopGCN [19] and InceptionGCN [13] are single-graph-based methods and two of the earliest works to use GCNs for disease prediction. Multi-GCN [12] is a multi-graph-based method. In addition, we also compare with LGL [4] and EV-GCN [10], the most closely related state-of-the-art works in disease prediction.
Table 1. Comparison with state-of-the-art methods on TADPOLE and ABIDE (ACC and AUC, mean ± standard error in %).

| Methods | TADPOLE ACC | TADPOLE AUC | ABIDE ACC | ABIDE AUC |
| --- | --- | --- | --- | --- |
| PopGCN [19] | | | | |
| InceptionGCN [13] | | | | |
| Multi-GCN¹ [12] | | | | |
| EV-GCN [10] | | | | |
| LGL¹ [4] | | 93.92±1.61 | | |
| MMGL (ours) | 93.73±1.70 | 93.81±1.80 | 86.95±3.88 | 86.74±3.77 |

¹Since the source code is not publicly available, the reported results are our re-implementation of the original algorithm.
3.2.2 Quantitative Results.
We evaluate MMGL and the other baselines on both the TADPOLE and ABIDE datasets using a 10-fold stratified cross-validation strategy, and report the mean scores and standard errors of accuracy (ACC) and Area Under Curve (AUC). As shown in Table 1, it can be concluded that: (i) compared to single-graph-based methods, which simply use meta-features and imaging features, the further use of multi-modal features can effectively improve the performance of the model; (ii) our MMGL outperforms Multi-GCN [12] by about 10.0% and 17.5% on the TADPOLE and ABIDE datasets respectively, which verifies the effectiveness of MaFF in integrating multi-modal features; (iii) our MMGL achieves an approximate 2.2% improvement on the ABIDE dataset over LGL, the state-of-the-art method, which demonstrates that our AGL may be more effective than the graph learning in LGL.
Fig. 3. Panels (a)-(d): the kNN graphs constructed from the fused features and the corresponding graph structures learned by MMGL.
3.2.3 Qualitative Results.
We visualize the kNN graphs and the graph structures learned by MMGL, where the kNN graphs are constructed from the fused features. As illustrated in Fig. 3, the kNN graph and the learned graph are similar overall and form subgraphs corresponding to the patient classes, reflecting the differences between classes and demonstrating the effectiveness of MaFF. Furthermore, compared to the kNN graph, the graph structure learned by AGL is sparser between different classes of patients, indicating that AGL is better able to learn intra-class similarity while capturing inter-class differences. Thus, although a ground-truth graph does not exist, we can still see the superiority of AGL over kNN graph construction.
3.3 Ablation study
To validate the effectiveness of modal-attentional feature fusion (MaFF) and the adaptive graph learning mechanism (AGL), we replace MaFF with an MLP and a concatenation operation respectively, and replace AGL with kNN graph construction based on the RBF kernel and with the construction method of PopGCN respectively. Table 2 shows the ablation results for the different modules of our model. Specifically, the graph constructed by PopGCN performs worst, especially on the ABIDE dataset, indicating that hand-constructed graphs are indeed not a desirable choice. Besides, AGL achieves favorable performance even without MaFF, which again validates the effectiveness of adaptive graph learning. More importantly, the combination of MaFF and AGL achieves performance that far exceeds the other combinations.
Table 2. Ablation study on the TADPOLE and ABIDE datasets (ACC and AUC, mean ± standard error in %).

| Methods | TADPOLE ACC | TADPOLE AUC | ABIDE ACC | ABIDE AUC |
| --- | --- | --- | --- | --- |
| MMGL | 93.73±1.70 | 93.81±1.80 | 86.95±3.88 | 86.74±3.77 |
| MLP+AGL | 89.42±2.22 | 90.30±2.34 | 85.07±4.68 | 83.36±4.64 |
| Concat+AGL | 85.97±2.06 | 88.17±2.48 | 83.70±2.64 | 83.70±2.68 |
| MaFF+PopGCN | 85.39±3.37 | 87.91±2.78 | 69.27±4.45 | 69.30±4.33 |
| MaFF+kNN | 87.97±3.69 | 89.67±2.65 | 84.19±3.83 | 83.91±4.07 |
4 Conclusion
In this paper, we propose a novel multi-modal graph learning framework named MMGL for disease prediction. For better integration of the complementary information across modalities, we propose a modal-attentional feature fusion module (MaFF) that achieves feature fusion considering inter-modal correlations. Furthermore, based on the fused features, a lightweight adaptive graph learning mechanism is proposed to reveal the intrinsic connections among samples, constructing an optimal graph structure for downstream tasks. We have carried out extensive experiments on two disease prediction problems, and the results demonstrate the clear superiority of our MMGL over currently available alternatives. More importantly, as a highly modular inductive framework, MMGL provides a baseline upon which different variants can be easily implemented to perform scenario-specific multi-modal adaptive graph learning. Our ongoing work will extend MMGL to adaptive unified graph learning for more biomedical tasks.
References
- [1] Abraham, A., Milham, M.P., Di Martino, A., Craddock, R.C., Samaras, D., Thirion, B., Varoquaux, G.: Deriving reproducible biomarkers from multi-site resting-state data: An autism-based example. NeuroImage 147, 736–745 (2017)
- [2] Bergstra, J., Yamins, D., Cox, D.: Making a science of model search: Hyperparameter optimization in hundreds of dimensions for vision architectures. In: International Conference on Machine Learning (ICML). pp. 115–123 (2013)
- [3] Chen, Y., Wu, L., Zaki, M.: Iterative deep graph learning for graph neural networks: Better and robust node embeddings. Advances in Neural Information Processing Systems (NeurIPS) 33 (2020)
- [4] Cosmo, L., Kazi, A., Ahmadi, S.A., Navab, N., Bronstein, M.: Latent-graph learning for disease prediction. In: International Conference on Medical Image Computing and Computer-Assisted Intervention (MICCAI). pp. 643–653 (2020)
- [5] Craddock, C., Benhajali, Y., Chu, C., Chouinard, F., Evans, A., Jakab, A., Khundrakpam, B.S., Lewis, J.D., Li, Q., Milham, M., et al.: The Neuro Bureau preprocessing initiative: open sharing of preprocessed neuroimaging data and derivatives. Frontiers in Neuroinformatics 7 (2013)
- [6] Defferrard, M., Bresson, X., Vandergheynst, P.: Convolutional neural networks on graphs with fast localized spectral filtering. In: Advances in Neural Information Processing Systems (NeurIPS). pp. 3844–3852 (2016)
- [7] Franceschi, L., Niepert, M., Pontil, M., He, X.: Learning discrete structures for graph neural networks. In: International Conference on Machine Learning (ICML). pp. 1972–1982 (2019)
- [8] Gao, J., Lyu, T., Xiong, F., Wang, J., Ke, W., Li, Z.: Mgnn: A multimodal graph neural network for predicting the survival of cancer patients. In: International ACM SIGIR Conference on Research and Development in Information Retrieval (SIGIR). pp. 1697–1700 (2020)
- [9] Hamilton, W., Ying, Z., Leskovec, J.: Inductive representation learning on large graphs. In: Advances in Neural Information Processing Systems (NeurIPS). pp. 1024–1034 (2017)
- [10] Huang, Y., Chung, A.C.: Edge-variational graph convolutional networks for uncertainty-aware disease prediction. In: International Conference on Medical Image Computing and Computer-Assisted Intervention (MICCAI). pp. 562–572 (2020)
- [11] Kalofolias, V.: How to learn a graph from smooth signals. In: Artificial Intelligence and Statistics. pp. 920–929 (2016)
- [12] Kazi, A., Shekarforoush, S., Kortuem, K., Albarqouni, S., Navab, N., et al.: Self-attention equipped graph convolutions for disease prediction. In: 2019 IEEE 16th International Symposium on Biomedical Imaging (ISBI). pp. 1896–1899 (2019)
- [13] Kazi, A., Shekarforoush, S., Krishna, S.A., Burwinkel, H., Vivar, G., Kortüm, K., Ahmadi, S.A., Albarqouni, S., Navab, N.: Inceptiongcn: receptive field aware graph convolutional network for disease prediction. In: International Conference on Information Processing in Medical Imaging (IPMI). pp. 73–85 (2019)
- [14] Kazi, A., Shekarforoush, S., Krishna, S.A., Burwinkel, H., Vivar, G., Wiestler, B., Kortüm, K., Ahmadi, S.A., Albarqouni, S., Navab, N.: Graph convolution based attention model for personalized disease prediction. In: International Conference on Medical Image Computing and Computer-Assisted Intervention (MICCAI). pp. 122–130 (2019)
- [15] Khosravan, N., Celik, H., Turkbey, B., Jones, E.C., Wood, B., Bagci, U.: A collaborative computer aided diagnosis (c-cad) system with eye-tracking, sparse attentional model, and deep learning. Medical Image Analysis 51, 101–115 (2019)
- [16] Kingma, D.P., Ba, J.: Adam: A method for stochastic optimization. Preprint arXiv:1412.6980 (2014)
- [17] Kipf, T.N., Welling, M.: Semi-supervised classification with graph convolutional networks. Preprint arXiv:1609.02907 (2016)
- [18] Marinescu, R.V., Oxtoby, N.P., Young, A.L., Bron, E.E., Toga, A.W., Weiner, M.W., Barkhof, F., Fox, N.C., Klein, S., Alexander, D.C., et al.: Tadpole challenge: Prediction of longitudinal evolution in alzheimer’s disease. Preprint arXiv:1805.03909 (2018)
- [19] Parisot, S., Ktena, S.I., Ferrante, E., Lee, M., Moreno, R.G., Glocker, B., Rueckert, D.: Spectral graph convolutions for population-based disease prediction. In: International Conference on Medical Image Computing and Computer-Assisted Intervention (MICCAI). pp. 177–185 (2017)
- [20] Rhee, S., Seo, S., Kim, S.: Hybrid approach of relation network and localized graph convolutional filtering for breast cancer subtype classification. Preprint arXiv:1711.05859 (2017)
- [21] Su, C., Tong, J., Zhu, Y., Cui, P., Wang, F.: Network embedding in biomedical data science. Briefings in Bioinformatics 21(1), 182–197 (2020)
- [22] Valenchon, J., Coates, M.: Multiple-graph recurrent graph convolutional neural network architectures for predicting disease outcomes. In: IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP). pp. 3157–3161 (2019)
- [23] Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A.N., Kaiser, L., Polosukhin, I.: Attention is all you need. Preprint arXiv:1706.03762 (2017)
- [24] Veličković, P., Cucurull, G., Casanova, A., Romero, A., Liò, P., Bengio, Y.: Graph attention networks. In: International Conference on Learning Representations (ICLR) (2018)
- [25] Yang, H., Li, X., Wu, Y., Li, S., Lu, S., Duncan, J.S., Gee, J.C., Gu, S.: Interpretable multimodality embedding of cerebral cortex using attention graph network for identifying bipolar disorder. In: International Conference on Medical Image Computing and Computer-Assisted Intervention (MICCAI). pp. 799–807 (2019)
- [26] Zhan, K., Chang, X., Guan, J., Chen, L., Ma, Z., Yang, Y.: Adaptive structure discovery for multimedia analysis using multiple features. IEEE Transactions on Cybernetics 49(5), 1826–1834 (2018)
- [27] Zhu, Q., Du, B., Yan, P.: Multi-hop convolutions on weighted graphs. Preprint arXiv:1911.04978 (2019)