
Re-Think and Re-Design Graph Neural Networks in Spaces of Continuous Graph Diffusion Functionals

Tingting Dan1, Jiaqi Ding2, Ziquan Wei1, Shahar Z Kovalsky3, Minjeong Kim4
Won Hwa Kim5, Guorong Wu1,2,∗
1 Department of Psychiatry, University of North Carolina at Chapel Hill
2 Department of Computer Science, University of North Carolina at Chapel Hill
3 Department of Mathematics, University of North Carolina at Chapel Hill
4 Department of Computer Science, University of North Carolina at Greensboro
5 Computer Science and Engineering / Graduate School of AI
[email protected]
Abstract

Graphs are ubiquitous in various domains, such as social networks and biological systems. Despite the great successes of graph neural networks (GNNs) in modeling and analyzing complex graph data, the inductive bias of the locality assumption, which involves exchanging information only within neighboring connected nodes, restricts GNNs from capturing long-range dependencies and global patterns in graphs. Inspired by the classic Brachistochrone problem, we ask how to devise a new inductive bias for cutting-edge graph applications and present a general framework through the lens of variational analysis. The backbone of our framework is a two-way mapping between the discrete GNN model and a continuous diffusion functional, which allows us to design application-specific objective functions in the continuous domain and engineer discrete deep models with mathematical guarantees. First, we address over-smoothing in current GNNs. Specifically, our analysis reveals that the existing layer-by-layer models of graph embedding learning are equivalent to an $\ell_{2}$-norm integral functional of graph gradients, which is the underlying cause of the over-smoothing problem. Similar to edge-preserving filters in image denoising, we introduce total variation (TV) to promote alignment of the graph diffusion pattern with the global information present in community topologies. On top of this, we devise a new selective mechanism for inductive bias that can be easily integrated into existing GNNs and effectively addresses the trade-off between model depth and over-smoothing. Second, we devise a novel generative adversarial network (GAN) to predict the spreading flows in a graph through a neural transport equation. To avoid the potential issue of vanishing flows, we tailor the objective function to minimize the transportation within each community while maximizing the inter-community flows. Our new GNN models achieve state-of-the-art (SOTA) performance on graph learning benchmarks such as Cora, Citeseer, and Pubmed.

1 Introduction

The graph is a fundamental data structure that arises in various domains, including social network analysis [48], natural language processing [42], computer vision [34], recommender systems [43], and knowledge graphs [25], among others. Tremendous efforts have been made to operate machine learning on graph data (called graph neural networks, or GNNs) at the node [27], link [47], and graph level [31]. The common inductive bias used in GNNs is the homophily assumption that nodes connected in a graph are more likely to have similar features or labels. In this context, most GNN models deploy a collection of fully-connected layers to progressively learn graph embeddings by aggregating the nodal feature representations from each node's topologically-connected neighbors throughout the graph [21].

Figure 1: Demonstration of the root cause of over-smoothing in GNNs. Nodes #1 and #2 are located along the boundary of two communities. The locality assumption in GNNs steers the learning of graph representations by constraining the information exchange to node-to-node connections. However, such a link-wise inductive bias tends to neutralize the contrast of node embeddings between nodes #1 and #2, which might undermine the node classification accuracy. Our research framework yields a solution to the over-smoothing issue by enabling heat-kernel diffusion within each community while penalizing excessive community-to-community information exchange (highlighted in red).

Under the hood of GNNs, the graph representation learning process is achieved by various learnable operations, such as message passing [5] or graph convolution [27]. However, because information is exchanged only within a local graph neighborhood, it is challenging to capture global graph representations, which go beyond node-to-node relationships, by simply deepening the GNN architecture without overly smoothing the feature representations of closely-connected nodes. Fig. 1 demonstrates the root cause of the over-smoothing issue in current GNNs, where node color denotes the group label (no color means unlabeled) and edge thickness indicates connection strength. It is clear that nodes #1 and #2 are located at the boundary of two communities. The inductive bias of GNNs (i.e., the locality assumption) enforces the node embedding vectors of nodes #1 and #2 to become similar because they are strongly connected (highlighted in red), even though the global topology suggests that their node embeddings should be distinct. As additional layers are added to GNNs, the node embeddings become capable of capturing global feature representations that underlie the entire graph topology. However, this comes at the cost of over-smoothing node embeddings across graph nodes due to (1) an increased number of node-to-node information exchanges, and (2) a greater degree of common topology within larger graph neighborhoods. In this regard, current GNNs only deploy a few layers (typically two or three) [30], which might be insufficient to characterize the complex feature representations on the graph.

It is evident that mitigating the over-smoothing problem in GNNs will enable training deeper models. From a network architecture perspective, skip connections [18; 45], residual connections [29; 20], and graph attention mechanisms [39; 37] have been proposed to alleviate the information loss in GNNs, by either preserving the local feature representation or making information exchange adaptive to the importance of nodes in the graph. Although these techniques are effective at patching the over-smoothing issue in some applications, the lack of an in-depth understanding of the root cause of the problem makes it challenging to find a generalized solution that can be scaled up to current graph learning applications.

Inspired by the success of neural ordinary differential equations in computer vision [10], research focus has recently shifted to linking the discrete models in GNNs with partial differential equation (PDE) based numerical recipes [44; 33; 6; 15]. For example, Graph Neural Diffusion (GRAND) formulates GNNs as a continuous diffusion process [6]. In their framework, the layer structure of GNNs corresponds to a specific discretization choice of temporal operators. Since PDE-based models do not revolutionize the underlying inductive bias of current GNNs, they are still unable to prevent the excessive information exchange between adjacent nodes such as nodes #1 and #2 in Fig. 1. In this regard, more advanced PDE solvers can only provide marginal improvements in numerical stability over the corresponding discrete GNN models, while the additional computational cost, even in the feed-forward scenario, could limit the practical applicability of PDE-based methods for large-scale graph learning tasks.

In this regard, pioneering work on continuous approaches has prompted a re-thinking of GNNs as a graph diffusion process governed by the Euler-Lagrange (E-L) equation of the heat kernel. This formulation is reminiscent of the Brachistochrone problem (a classic physics problem that involves finding the curve down which a bead sliding under the influence of gravity will travel in the least amount of time between two points), which emerged over 400 years ago and established the mathematical framework of classical mechanics. The powerful calculus of variations allows us to generate solutions for various mechanics questions (e.g., the slope that yields the fastest ball sliding down the curve is given by a cycloid) through the lens of the E-L equation, as shown in Fig. 2 (top).

In a similar vein, the question that arises in the context of community detection is: What graph diffusion pattern is best suited for preserving community organizations? The question for graph classification would be: What graph diffusion pattern works best for capturing the system-level characteristics of graph topology? Following the spirit of the Brachistochrone problem, we present a general research framework to customize application-specific GNNs in a continuous space of graph diffusion functionals. As shown in Fig. 2 (bottom), we have established a fundamental structure for our framework that involves a two-way mapping between a discrete GNN model and a continuous graph diffusion functional. This allows us to develop application-specific objective functions (with an explainable regularization term) in the continuous domain and construct a discrete deep model with mathematical guarantees. We demonstrate two novel GNN models, one for addressing over-smoothing and one for predicting the flows from longitudinal nodal features, both achieving state-of-the-art performance (Cora: 85.6%, Citeseer: 73.9%, Pubmed: 80.10%, even with 128 network layers).

We have made four major contributions. (1) We establish a connection between the discrete model of GNNs and the continuous functional of inductive bias in graph learning by using the E-L equation as a stepping stone to bridge the discrete and continuous domains. (2) We introduce a general framework to re-think and re-design new GNNs that is less “black-box”. (3) We devise a novel selective mechanism upon inductive bias to address the over-smoothing issue in current GNNs and achieve state-of-the-art performance on graph learning benchmarks. (4) We construct a novel GNN in the form of a generative adversarial network (GAN) to predict the flow dynamics in the graph by a neural transport equation.

Figure 2: Top: The Brachistochrone problem played a pivotal role in the development of classical mechanics and the powerful mathematical tool known as the calculus of variations. Bottom: We introduce a general framework to answer "Brachistochrone" problems regarding diffusion patterns on graphs that allows us to re-think and re-design application-specific deep model of GNNs with enhanced mathematical interpretability.

2 Methods

In the following, we first elucidate the relationship between GNN, PDE, and calculus of variations (COV), which sets the stage for the GNN-PDE-COV framework for new GNN models in Section 2.2.

2.1 Re-think GNNs: Connecting dots across graph neural networks, graph diffusion process, Euler-Lagrange equation, and Lagrangian mechanics

Graph diffusion process. Given graph data $\mathcal{G}=(V,W)$ with $N$ nodes $V=\{v_{i}\,|\,i=1,\ldots,N\}$, the adjacency matrix $W=[w_{ij}]_{i,j=1}^{N}\in\mathbb{R}^{N\times N}$ describes the connectivity strength between any two nodes. For each node $v_{i}$, we have a graph embedding vector $x_{i}\in\mathbb{R}^{m}$. In the context of graph topology, the graph gradient $(\nabla_{\mathcal{G}}x)_{ij}=w_{ij}(x_{i}-x_{j})$ indicates the feature difference between $v_{i}$ and $v_{j}$ weighted by the connectivity strength $w_{ij}$, where $\nabla_{\mathcal{G}}$ is an $\mathbb{R}^{N}\rightarrow\mathbb{R}^{N\times N}$ operator. Thus, the graph diffusion process can be formulated as $\frac{\partial x(t)}{\partial t}=\mathrm{div}(\nabla_{\mathcal{G}}x(t))$, where the evolution of the embedding vectors $x=[x_{i}]_{i=1}^{N}$ is steered by the graph divergence operator.
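To make the notation concrete, the following minimal NumPy sketch (the graph weights and features are made-up toy values) builds the graph gradient $(\nabla_{\mathcal{G}}x)_{ij}=w_{ij}(x_{i}-x_{j})$ and verifies that collecting the gradient field at each node reproduces the graph Laplacian that drives the diffusion process.

import numpy as np

# Toy graph with N = 4 nodes; symmetric, made-up connectivity weights.
W = np.array([[0.0, 1.0, 0.5, 0.0],
              [1.0, 0.0, 0.0, 0.2],
              [0.5, 0.0, 0.0, 1.0],
              [0.0, 0.2, 1.0, 0.0]])
x = np.array([1.0, 0.8, -0.5, -0.7])       # one scalar embedding per node (m = 1)

grad_x = W * (x[:, None] - x[None, :])     # graph gradient: grad_x[i, j] = w_ij * (x_i - x_j)
agg = grad_x.sum(axis=1)                   # node-wise collection of the gradient field

L = np.diag(W.sum(axis=1)) - W             # combinatorial graph Laplacian
print(np.allclose(agg, L @ x))             # True: sum_j w_ij (x_i - x_j) = (L x)_i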

Connecting GNN to graph diffusion. In the regime of GNNs, the regularization in the loss function often measures the smoothness of the embeddings $x$ over the graph by $x^{T}\Delta x$, where $\Delta=\mathrm{div}(\nabla_{\mathcal{G}})$ is the graph Laplacian operator [27]. To that end, the graph smoothness penalty encourages two connected nodes to have similar embeddings through information exchange in each GNN layer. Specifically, the new graph embedding $x^{l}$ in the $l^{th}$ layer is essentially a linear combination of the graph embedding $x^{l-1}$ in the previous layer, i.e., $x^{l}=A_{W,\Theta}x^{l-1}$, where the matrix $A$ depends on the graph adjacency matrix $W$ and the trainable GNN parameters $\Theta$. After rewriting $x^{l}=Ax^{l-1}$ as $x^{l}-x^{l-1}=(A-I)x^{l-1}$, updating the graph embeddings $x$ in a GNN falls into a discrete graph diffusion process, where the time parameter $t$ acts as a continuous analog of the layers in the spirit of Neural ODEs [10]. It has been shown in [6] that running the graph diffusion process for multiple iterations is equivalent to applying a GNN layer multiple times.
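A quick numerical illustration of this equivalence (and a preview of the over-smoothing issue discussed next) is sketched below in NumPy, using a made-up six-node graph and random features: repeatedly applying a row-stochastic (random-walk) propagation matrix, a simplified stand-in for $A_{W,\Theta}$ without the learnable part, drives the spread of the node embeddings toward zero as the effective number of layers grows.

import numpy as np

# A small connected toy graph (made-up weights); self-loops keep the walk aperiodic.
W = np.array([[0, 1, 1, 0, 0, 0],
              [1, 0, 1, 0, 0, 0],
              [1, 1, 0, 1, 0, 0],
              [0, 0, 1, 0, 1, 1],
              [0, 0, 0, 1, 0, 1],
              [0, 0, 0, 1, 1, 0]], dtype=float)
A_hat = W + np.eye(6)
P = A_hat / A_hat.sum(axis=1, keepdims=True)    # row-stochastic propagation matrix

rng = np.random.default_rng(0)
x = rng.standard_normal((6, 3))                 # initial node embeddings (6 nodes, 3 features)
for l in [1, 2, 4, 8, 32, 128]:
    x_l = np.linalg.matrix_power(P, l) @ x
    print(l, np.round(x_l.std(axis=0), 4))      # spread across nodes shrinks toward 0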

GNN is a discrete model of Lagrangian mechanics via the E-L equation. The diffusion process $\frac{\partial x(t)}{\partial t}=\mathrm{div}(\nabla_{\mathcal{G}}x(t))$ was heavily studied in image processing decades ago; it is the E-L equation of the functional $\min_{x}\int_{\Omega}|\nabla x|^{2}dx$. By replacing the 1D gradient operator defined in the Euclidean space $\Omega$ with the graph gradient $(\nabla_{\mathcal{G}}x)_{ij}$, it is straightforward to find that the equation governing the graph diffusion process $\frac{\partial x(t)}{\partial t}=\mathrm{div}(\nabla_{\mathcal{G}}x(t))$ is the E-L equation of the functional $\min_{x}\int_{\mathcal{G}}|\nabla_{\mathcal{G}}x|^{2}dx$ over the graph topology. Since heat-kernel diffusion is essentially the mathematical description of the inductive bias in current GNNs, we have established a mapping between the mechanics of GNN models and the functional of graph diffusion patterns in a continuous domain.
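As a sanity check on this correspondence, a short and standard calculus-of-variations computation (stated here for the Euclidean case, assuming the boundary terms vanish) recovers the diffusion equation: perturbing $x\to x+\varepsilon a$ in the quadratic functional gives

$\frac{d}{d\varepsilon}\int_{\Omega}|\nabla(x+\varepsilon a)|^{2}dx\Big|_{\varepsilon=0}=2\int_{\Omega}\nabla x\cdot\nabla a\,dx=-2\int_{\Omega}\mathrm{div}(\nabla x)\,a\,dx,$

so the first variation is proportional to $-\mathrm{div}(\nabla x)$, and the corresponding gradient-flow (steepest-descent) dynamics are exactly $\frac{\partial x}{\partial t}=\mathrm{div}(\nabla x)$; replacing $\nabla$ with $\nabla_{\mathcal{G}}$ yields the graph version used above.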

Tracing the smoking gun of over-smoothing in GNNs. In Fig. 1, we observed that the inductive bias of link-wise propagation is the major reason for excessive information exchange, which contributes to the over-smoothing problem in GNNs. An intuitive approach is to align the diffusion process with high-level properties associated with graph topology, such as network communities. After connecting the GNN inductive bias to the functional of the graph diffusion process, we postulate that the root cause of over-smoothing is the isotropic regularization mechanism encoded by the $\ell_{2}$-norm. More importantly, connecting GNNs to the calculus of variations offers a more principled way to design new deep models with mathematical guarantees and mechanistic explainability.

2.2 Re-design GNNs: Revolutionize inductive bias, derive new E-L equation, and construct deeper GNN

The general roadmap for re-designing GNNs typically involves three major steps: (1) formulating inductive bias into the functional of graph diffusion patterns; (2) deriving the corresponding E-L equation; and then (3) developing a new deep model of GNN based on the finite difference solution of E-L equation. Since the graph diffusion functional is application-specific, we demonstrate the construction of new GNN models in the following two graph learning applications.

2.2.1 Develop VERY deep GNNs with a selective mechanism for link-adaptive inductive bias

Problem formulation. Taking the feature learning component (learnable parameters $\Theta$) out of GNNs, the graph embeddings $x^{L}$ (output of an $L$-layer GNN) can be regarded as the output of an iterative smoothing process ($L$ times) underlying the graph topology $\mathcal{G}$, constrained by the data fidelity term $\|x^{L}-x^{0}\|_{2}^{2}$ (w.r.t. the initial graph embeddings $x^{0}$) and the graph smoothness term $\int_{\mathcal{G}}|\nabla_{\mathcal{G}}x|^{2}dx$. Inspired by the great success of total variation (TV) in preserving edges in image denoising [35], reconstruction [41] and restoration [8], we propose to address the over-smoothing issue in current GNNs by replacing the quadratic Laplacian regularizer with TV on graph gradients, i.e., $\mathcal{J}_{TV}(x)=\int|\nabla_{\mathcal{G}}x|dx$. Thus, the TV-based objective function for graph diffusion becomes: $\min_{x}\left(\|x-x^{0}\|_{2}^{2}+\mathcal{J}_{TV}(x)\right)$.
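To see why TV is less eager than the quadratic regularizer to erase sharp boundaries, the toy NumPy computation below (a made-up two-community graph with a piecewise-constant embedding) compares the two penalties: the single inter-community edge is charged quadratically in the jump size by the $\ell_{2}$ functional but only linearly by TV, so keeping the jump is much cheaper under TV.

import numpy as np

# Two communities {0,1,2} and {3,4,5} connected by a single cross-community edge (2,3).
W = np.zeros((6, 6))
for i, j in [(0, 1), (0, 2), (1, 2), (3, 4), (3, 5), (4, 5), (2, 3)]:
    W[i, j] = W[j, i] = 1.0

x = np.array([1.0, 1.0, 1.0, -1.0, -1.0, -1.0])   # piecewise-constant node embedding
grad = W * (x[:, None] - x[None, :])              # graph gradient (grad)_ij = w_ij (x_i - x_j)

l2_energy = 0.5 * np.sum(grad ** 2)               # quadratic (Dirichlet) functional
tv_energy = 0.5 * np.sum(np.abs(grad))            # total-variation functional
print(l2_energy, tv_energy)                       # 4.0 vs 2.0; doubling the jump gives 16.0 vs 4.0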

However, the $\ell_{1}$-norm function, denoted by $|\cdot|$ in the definition of the total variation functional $\mathcal{J}_{TV}$, is not differentiable at zero. Following the dual-optimization schema [4; 7], we introduce a latent auxiliary matrix $z\in\mathbb{R}^{N\times N}$ and reformulate the TV-based functional as $\mathcal{J}_{TV}(x,z)=\max_{z}\min_{x}\int(z\otimes\nabla_{\mathcal{G}}x)dx$, subject to $|z|\leq\mathbf{1}^{N\times N}$, where $\otimes$ is the Hadamard (element-wise) product between two matrices. Furthermore, we use the engineering trick of the element-wise operation $z_{ij}(\nabla_{\mathcal{G}}x)_{ij}$ to keep each term non-negative (the same as taking the absolute value), which makes the problem solvable. In the end, we reformulate the minimization of $\mathcal{J}_{TV}(x)$ into a dual min-max functional $\mathcal{J}_{TV}(x,z)$, where we maximize $z$ ($z\rightarrow\mathbf{1}^{N\times N}$) such that $\mathcal{J}_{TV}(x,z)$ is close enough to $\mathcal{J}_{TV}(x)$. Therefore, the new objective function is reformulated as:

$\mathcal{J}(x,z)=\max\limits_{z}\min\limits_{x}\left\|x-x^{0}\right\|_{2}^{2}+\lambda\int(z\nabla_{\mathcal{G}}x)\,dx,$ (1)

where $\lambda$ is a scalar balancing the data fidelity term and the regularization term. Essentially, Eq. 1 is the dual formulation with the min-max property for the TV distillation problem [50].
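For completeness, the dual reformulation rests on the elementary identity, stated here entry-wise for a single graph gradient:

$|(\nabla_{\mathcal{G}}x)_{ij}|=\max\limits_{|z_{ij}|\leq 1}z_{ij}(\nabla_{\mathcal{G}}x)_{ij},$

with the maximum attained at $z_{ij}=\mathrm{sign}\big((\nabla_{\mathcal{G}}x)_{ij}\big)$; summing over all links recovers the regularization term of Eq. 1 from $\mathcal{J}_{TV}(x)$.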

Constructing E-L equations. To solve Eq. 1, we present the following two-step alternating optimization schema. First, the inner minimization problem (solving for $x_{i}$) in Eq. 1 can be solved by letting $\frac{\partial}{\partial x_{i}}\mathcal{J}(x_{i},z_{i})=0$:

$\frac{\partial}{\partial x_{i}}\mathcal{J}(x_{i},z_{i})=2(x_{i}-x_{i}^{0})+\lambda z_{i}\nabla_{\mathcal{G}}x_{i}=0\;\;\Rightarrow\;\;\hat{x}_{i}=x_{i}^{0}-\frac{\lambda}{2}z_{i}\nabla_{\mathcal{G}}x_{i}$ (2)

Replacing $(\nabla_{\mathcal{G}}x)_{ij}$ with $w_{ij}(x_{i}-x_{j})$, the intuition of Eq. 2 is that each element in $\hat{x}_{i}$ is essentially a combination of the corresponding initial value in $x_{i}^{0}$ and the overall graph gradients $z_{i}\nabla_{\mathcal{G}}x_{i}=\sum_{j\in\mathcal{N}_{i}}w_{ij}(x_{i}-x_{j})z_{i}$ within its graph neighborhood $\mathcal{N}_{i}$. In this regard, Eq. 2 characterizes the dynamic information exchange on the graph, which is not only steered by the graph topology but also moderated by the attenuation factor $z_{i}$ at each node.

Second, by substituting Eq. 2 back into Eq. 1, the objective function of $z_{i}$ becomes $\mathcal{J}(z_{i})=\max_{|z_{i}|\leq\mathbf{1}}\left\|\frac{\lambda}{2}z_{i}\nabla_{\mathcal{G}}x_{i}\right\|_{2}^{2}+\lambda z_{i}\nabla_{\mathcal{G}}(x_{i}^{0}-\frac{\lambda}{2}z_{i}\nabla_{\mathcal{G}}x_{i})$. After simplification (Eq. 8 to Eq. 10 of the Supplementary), the optimization of each $z_{i}$ is achieved by $\arg\min_{|z_{i}|\leq\mathbf{1}}z_{i}\nabla_{\mathcal{G}}x_{i}z_{i}\nabla_{\mathcal{G}}x_{i}-\frac{4}{\lambda}z_{i}\nabla_{\mathcal{G}}x_{i}^{0}$. Specifically, we employ the majorization-minimization (MM) method [16] to optimize $z_{i}$ by solving this constrained minimization problem (the detailed derivation is given in Eq. 8 to Eq. 19 of Section 5.1 of the Supplementary), where $z_{i}$ can be iteratively refined by:

$z_{i}^{l}=\mathrm{clip}\Big(\underbrace{z_{i}^{l-1}+\frac{2}{\beta\lambda}\nabla_{\mathcal{G}}x_{i}}_{b},\,1\Big)=\begin{cases}b&|b|\leq 1\\ 1&b>1\\ -1&b<-1\end{cases}$ (3)

$\beta$ is a hyper-parameter that is required to be no less than the largest eigenvalue of $(\nabla_{\mathcal{G}}x_{i})(\nabla_{\mathcal{G}}x_{i})^{\intercal}$.
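The two closed-form updates translate directly into a few lines of array code. The sketch below (plain NumPy; the graph, embeddings, and the hyper-parameters lam and beta are made-up placeholders, and $z$ is reduced to one attenuation scalar per node for simplicity, which is only one possible discretization) applies the diffusion step of Eq. 2 followed by the clip step of Eq. 3; this is the operation the DC layer in Fig. 3 performs after each FC layer.

import numpy as np

def dc_layer(x_fc, x_anchor, z_prev, W, lam=1.0, beta=10.0):
    # x_fc     : (N, m) output of the preceding FC layer, playing the role of x^(l)
    # x_anchor : (N, m) embeddings from the previous layer, used as the fidelity anchor x^0
    # z_prev   : (N, 1) per-node attenuation variables kept in [-1, 1]
    # W        : (N, N) connectivity weights; lam and beta are placeholder hyper-parameters
    deg = W.sum(axis=1, keepdims=True)
    agg_grad = deg * x_fc - W @ x_fc                      # sum_j w_ij (x_i - x_j): aggregated graph gradient
    x_new = x_anchor - 0.5 * lam * z_prev * agg_grad      # diffusion step (Eq. 2)
    # clip step (Eq. 3); the per-node gradient is summarized by its mean over feature channels
    z_new = np.clip(z_prev + (2.0 / (beta * lam)) * agg_grad.mean(axis=1, keepdims=True), -1.0, 1.0)
    return x_new, z_new

# toy usage with made-up data
rng = np.random.default_rng(0)
W = rng.random((5, 5)); W = (W + W.T) / 2; np.fill_diagonal(W, 0.0)
x = rng.standard_normal((5, 3))
x_out, z_out = dc_layer(x, x, np.zeros((5, 1)), W)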

Figure 3: Our new deep model integrates a novel diffusion-clip (DC) layer (for selective graph diffusion) after the conventional fully-connected (FC) layer (for graph representation learning).

Develop a new GNN network architecture with a selective inductive bias. The building block in the vanilla GNN [27] is an FC (fully-connected) layer whose input is the embedding vectors after isotropic graph diffusion (in $\ell_{2}$-norm). Since the estimation of the graph embeddings $x$ in Eq. 2 depends on the latest estimate of $z^{(l)}$, this recursive min-max solution of Eq. 1 allows us to devise a new network architecture that disentangles the building block of the vanilla GNN into feature representation learning and graph diffusion underlying TV. As shown in Fig. 3, we first deploy an FC layer to update the graph embeddings $x^{(l)}$. After that, we concatenate a diffusion-clip (DC) layer for selective graph diffusion, which sequentially applies (1) node-adaptive graph diffusion (blue arrow in Fig. 3) on $x^{(l)}$ by Eq. 2 (since the optimization schema has been switched to a layer-by-layer manner, the initialization $x^{0}$ becomes $x^{(l-1)}$ from the previous layer), and (2) a clip operation (purple arrow in Fig. 3) on each $x_{i}^{(l)}$ by Eq. 3.

Remarks. Eq. 3 indicates that a larger connective degree results in a larger value of $z$. Thus, the DC layer shifts the diffusion patterns by penalizing inter-community information exchange (due to strong connections) while retaining the heat-kernel diffusion within each community. The preference for such link-adaptive diffusion can be adjusted by the hyper-parameter $\lambda$ in Eq. 1, which can be either pre-defined or learned from the data. Recalling our intuitive solution to the over-smoothing problem in Fig. 1, the DC layer offers exactly the global insight of graph topology needed to keep the node embeddings of nodes #1 and #2 distinct. We demonstrate the effect of the DC layer on real-world graph data in Fig. 8 of the Supplementary document.

Figure 4: (a) The illustration of the computational challenge for estimating the spreading flow. (b) The GAN architecture for min-max optimization in the network.

2.2.2 Predict flow dynamics through graph neural transport equation

Problem formulation. We live in a world of complex systems, where everything is intricately connected in multiple ways. A holistic insight into how a system's components interact with each other, and how changes in one part of the system can affect the behavior of the whole, sheds new light on the dynamic behaviors of these complex systems over time. However, this is oftentimes an ill-posed problem. Taking the toy system in Fig. 4(a) as an example, while it is simple to calculate the future focal patterns based on the focal patterns at the current time point and the node-to-node flow information, determining the flow dynamics from longitudinal nodal observations is computationally hard since the solution is not unique.

The naïve solution to predicting the spreading flow is to (1) train a GNN to learn the intrinsic node embeddings and (2) predict the flow based on the difference of the learned embeddings. However, this two-step approach might suffer from vanishing flows due to over-smoothing in GNNs. Following the spirit of the Brachistochrone problem, we ask the question: "What flow field $f(t)=[f_{ij}(t)]_{i,j=1}^{N}$ underlies the system mechanics to the extent that it is able to predict the behaviors in the future?"

In this paper, we focus on the conservative system of energy transportation [2]. The system mechanics is formulated as:

$\frac{dx}{dt}+\mathrm{div}(q)=0$ (4)

where $q=[q_{ij}]_{i,j=1}^{N}$ is the flux field which propagates the potential energy $u(t)=[u_{i}(t)]_{i=1}^{N}$ (a conserved quantity) over time. Similar to a gravity field driving water flow, the intuition of Eq. 4 is that the change of the energy density $u$ (we assume there is a non-linear mapping $\phi$ from the external force $x$ to $u$, i.e., $u_{i}=\phi(x_{i})$) leads to energy transport throughout the entire graph. As the flux is closely related to the energy difference $\nabla_{\mathcal{G}}u$ underlying the graph topology, we assume the energy flux $q$ is regulated by the potential energy field $\nabla_{\mathcal{G}}u$, i.e., $q=\alpha\otimes\nabla_{\mathcal{G}}u$, where $\alpha=[\alpha_{ij}]_{i,j=1}^{N}$ is a learnable matrix characterizing the link-wise contribution of each energy potential $\nabla_{\mathcal{G}}u_{ij}$ to the potential energy flux $q_{ij}$.
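The conservation law in Eq. 4 is easy to sanity-check numerically: with an antisymmetric link-wise flux $q_{ij}=\alpha_{ij}(\nabla_{\mathcal{G}}u)_{ij}$, whatever leaves one node enters another, so the total conserved quantity stays constant. The minimal NumPy sketch below (made-up graph, $\alpha$, and $u$; the non-linear mapping $\phi$ is dropped and a simple explicit Euler step is used) illustrates this.

import numpy as np

rng = np.random.default_rng(1)
N = 5
W = rng.random((N, N)); W = (W + W.T) / 2; np.fill_diagonal(W, 0.0)
alpha = rng.random((N, N)) + 0.1; alpha = (alpha + alpha.T) / 2   # symmetric link-wise coefficients
u = rng.random(N)                                                  # potential energy per node

dt = 0.05
total_before = u.sum()
for _ in range(100):
    grad_u = W * (u[:, None] - u[None, :])   # graph gradient of u
    q = alpha * grad_u                       # link-wise flux; antisymmetric because grad_u is
    u = u - dt * q.sum(axis=1)               # explicit Euler step of du/dt = -div(q), node-wise
print(round(total_before, 6), round(u.sum(), 6))   # total energy is conserved up to float error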

By plugging $q=\alpha\otimes\nabla_{\mathcal{G}}u$ into Eq. 4, the energy transport process can be reformulated as:

$\frac{\partial u}{\partial t}=-\phi^{-1}(\alpha\otimes\mathrm{div}(\nabla_{\mathcal{G}}u))=-\phi^{-1}(\alpha\otimes\Delta u),$ (5)

where $\Delta=\mathrm{div}(\nabla_{\mathcal{G}})$ is the graph Laplacian operator. Since the PDE in Eq. 5 is equivalent to the E-L equation of the quadratic functional $\mathcal{J}(u)=\min_{u}\int_{\mathcal{G}}\alpha\otimes|\nabla_{\mathcal{G}}u|^{2}du$ (after taking $\phi$ away), a major issue is the over-smoothness in $u$ that might result in vanishing flows. In this context, we propose to replace the $\ell_{2}$-norm integral functional $\mathcal{J}(u)$ with its TV-based counterpart $\mathcal{J}_{TV}(u)=\min_{u}\int_{\mathcal{G}}\alpha\otimes|\nabla_{\mathcal{G}}u|du$.

Renovate new E-L equation – graph neural transport equations. Following the min-max optimization schema in Eqs. 1-3, we introduce an auxiliary matrix $f$ to lift the non-differentiability barrier. After that, the minimization of $\mathcal{J}_{TV}(u)$ boils down to a dual min-max functional $\mathcal{J}_{TV}(u,f)=\min_{u}\max_{f}\int_{\mathcal{G}}\alpha\otimes f(\nabla_{\mathcal{G}}u)du$. Since $u(t)$ is a time series, it is difficult to derive deterministic solutions (as in Eqs. 2-3) by letting $\frac{\partial}{\partial u}\mathcal{J}_{TV}=0$ and $\frac{\partial}{\partial f}\mathcal{J}_{TV}=0$. Instead, we use Gâteaux variations to optimize $\mathcal{J}_{TV}(u,f)$ via the following two coupled time-dependent PDEs (please see Section 5.2, Eq. 21 to Eq. 26, in the Supplementary for details):

$\left\{\begin{array}{l}\max\limits_{f}\ \frac{df}{dt}=\alpha\otimes\nabla_{\mathcal{G}}u\\ \min\limits_{u}\ \frac{du}{dt}=\alpha\otimes\mathrm{div}(f)\end{array}\right.$ (6)

Remarks. The solution to Eq. 6 is known as continuous max-flow and constitutes a continuous version of a graph-cut [1]. Since $\alpha$ is a latent variable and the potential energy $u$ is given, the maximization of $f$ opts towards maximizing the spreading flow through the lens of $\alpha$. In this regard, the mechanistic role of the auxiliary matrix $f$ is essentially the latent (maximized) spreading flows that satisfy $u_{i}(t+1)=u_{i}(t)+\sum_{j=1}^{N}f_{ij}(t)$. The potential energy $\hat{u}$ can be solved via a wave equation ($u_{tt}=\mathrm{div}(f_{t})=\alpha^{2}\otimes\Delta u$), where the system dynamics is predominated by the adjusted Lagrangian mechanics $\alpha^{2}\otimes\Delta u$. By determining $\alpha$ at the granularity of graph links, we devise a novel GAN model to predict the spreading flows $f$, which not only offers explainability underlying the min-max optimization mechanism in Eq. 6 but also sets the stage for understanding system dynamics through machine learning.
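To convey the structure of this coupled system, the sketch below discretizes Eq. 6 with two alternating explicit updates (NumPy; the graph, $\alpha$, $u$, and the pseudo-time step are made-up placeholders, the divergence uses the adjoint convention $\mathrm{div}(f)=-\nabla_{\mathcal{G}}^{\intercal}f$ mentioned in the Supplementary, and $\alpha$ is applied only on the flow update for simplicity). It is meant as a structural illustration of the continuous max-flow iteration, not as the trained model.

import numpy as np

rng = np.random.default_rng(2)
N = 5
W = rng.random((N, N)); W = (W + W.T) / 2; np.fill_diagonal(W, 0.0)
alpha = 0.5 * np.ones((N, N))                 # link-wise coefficients (placeholder constant)
u = rng.random(N)                             # potential energy at the current time point
f = np.zeros((N, N))                          # latent spreading flows, to be maximized

grad = lambda u: W * (u[:, None] - u[None, :])        # (grad u)_ij = w_ij (u_i - u_j)
div = lambda f: (W * (f.T - f)).sum(axis=1)           # div(f) = -grad^T f (adjoint convention)

tau = 0.1                                     # pseudo-time step shared by both updates
for _ in range(100):
    f = f + tau * alpha * grad(u)             # ascent on the flow:    df/dt = alpha * grad(u)
    u = u + tau * div(f)                      # descent on the energy: du/dt = div(f)
# f now plays the role of the latent spreading flows discussed in the remark above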

Develop a GAN model of flow prediction with TV-based Lagrangian mechanics. The overall network architecture is shown in Fig. 4(b), which consists of a generator (red solid box) and a discriminator module (blue solid box). Specifically, the generator ($G$) consists of (1) a GCN component [15] to optimize $\hat{u}$ through the wave equation $u_{tt}=\alpha^{2}\otimes\Delta u$ and (2) an FC layer to characterize the non-linear mapping function $\hat{x}(t+1)=\phi^{-1}(\hat{u}(t))$. In contrast, the discriminator ($D$) is designed to (1) synthesize $\alpha$ and (2) construct the future $\tilde{u}_{t+1}$ based on the current $u_{t}$ and the current estimate of the spreading flow $f=\alpha\otimes\nabla_{\mathcal{G}}u$ (orange dashed box). To keep the network architecture consistent between the generator and discriminator modules, we include another GCN to map the synthesized $\tilde{u}(t+1)$ to the external force $\tilde{x}(t+1)$. Note that, since the working mechanism of this adversarial model underlies the min-max optimization in the energy transport equation, the nature of the predicted spreading flows is carved by the characteristics of max-flow.

The driving force of our network is to minimize (1) the MSE (mean square error) between the output of the generator $\hat{x}_{t+1}$ and the observed regional features, and (2) the distance between the synthesized regional features $\tilde{x}_{t+1}$ (from the discriminator) and the output $\hat{x}_{t+1}$ of the generator. In the spirit of probabilistic GAN [49], we use one loss function $\mathcal{L}_{D}$ to train the discriminator ($D$) and another $\mathcal{L}_{G}$ to train the generator ($G$):

$\left\{\begin{array}{l}\mathcal{L}_{D}=D(x_{t+1})+\left[\xi-D(G(x_{t}))\right]^{+}\\ \mathcal{L}_{G}=D(G(x_{t}))\end{array}\right.$ (7)

where $\xi$ denotes the positive margin and the operator $[\cdot]^{+}=\max(0,\cdot)$. Minimizing $\mathcal{L}_{G}$ is similar to maximizing the second term of $\mathcal{L}_{D}$, except that it has a non-zero gradient when $D(G(x_{t}))\geq\xi$.
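For concreteness, Eq. 7 amounts to a margin-based (EBGAN-style) objective pair, which can be written in a few lines. The sketch below treats $D$ and $G$ as placeholder callables returning an energy-like scalar and a synthesized sample, respectively; it reflects our reading of Eq. 7 rather than the authors' exact implementation.

import numpy as np

def discriminator_loss(D, G, x_t, x_t1, xi=1.0):
    # L_D = D(x_{t+1}) + [xi - D(G(x_t))]^+     (Eq. 7, top)
    return D(x_t1) + max(0.0, xi - D(G(x_t)))

def generator_loss(D, G, x_t):
    # L_G = D(G(x_t))                            (Eq. 7, bottom)
    return D(G(x_t))

# toy usage with made-up stand-ins for D and G
D = lambda x: float(np.mean(x ** 2))   # "energy" assigned to a sample
G = lambda x: x + 0.1                  # fake one-step predictor
x_t, x_t1 = np.ones(4), 1.2 * np.ones(4)
print(discriminator_loss(D, G, x_t, x_t1), generator_loss(D, G, x_t))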

3 Experiments

In this section, we evaluate the performance of the proposed GNN-PDE-COV framework with comparison to six graph learning benchmark methods on a wide variety of open graph datasets [36], as well as a proof-of-concept application of uncovering the propagation pathway of pathological events in Alzheimer’s disease (AD) from the longitudinal neuroimages.

3.1 Datasets and experimental setup

Dataset and benchmark methods. We evaluate the new GNN models derived from our proposed GNN framework in two different applications. First, we use three standard citation networks, namely Cora, Citeseer, and Pubmed [36], for node classification (the detailed data statistics are shown in Table 3 of the Supplementary). We adopt the public fixed split [46] to separate these datasets into training, validation, and test sets. We follow the experimental setup of [9] for a fair comparison with six benchmark GNN models (vanilla GCN [27], GAT [39], GCNII [9], ResGCN [29], DenseGCN [29], GRAND [6]). Since our DC-layer can be seamlessly integrated into existing GNNs as a plug-in, the corresponding new GNN models (with the DC-layer) are denoted GCN+, GAT+, GCNII+, ResGCN+, DenseGCN+, and GRAND+, respectively.

Second, we apply the GAN model in Section 2.2.2 to predict the concentration level of AD-related pathological burdens and their spreading pathways from longitudinal neuroimages. Currently, there is no in-vivo imaging technique that can directly measure the flow of information across brain regions. Here, our computational approach holds great clinical value for understanding the pathophysiological mechanism involved in disease progression [26]. Specifically, we parcellate each brain into 148 cortical surface regions and 12 sub-cortical regions using the Destrieux atlas [13]. The wiring topology of these 160 brain regions is measured by diffusion-weighted imaging [3] and tractography techniques [19]. The regional concentration levels of AD pathology, including amyloid, tau, and fluorodeoxyglucose (FDG), as well as cortical thickness (CoTh), are measured from PET (positron emission tomography) and MRI (magnetic resonance imaging) scans [23]. We use a total of $M=1{,}291$ subjects from ADNI [32], each having longitudinal imaging data (2-5 time points). The details of image statistics and pre-processing are shown in Sec. 6.1.2. Since we apply the flow prediction model to each modality separately, we differentiate them with X-FlowNet (X stands for amyloid, tau, FDG, or CoTh).

Experimental setup. In the node classification task, we verify the effectiveness and generality of the DC-layer with various numbers of layers ($L=2,4,8,16,32,64,128$). All baselines use their own default parameter settings. Evaluation metrics include accuracy, precision and F1-score. To validate the performance of X-FlowNet, we examine (1) the prediction accuracy (MAE) of the follow-up concentration level, (2) the prediction of the risk of developing AD using the baseline scan, and (3) the node-to-node spreading flows of neuropathologies, which shed light on the propagation mechanism in AD.

The main results of graph node classification and flow prediction are demonstrated in Sections 3.2 and 3.3, respectively. Other supporting results, such as the ablation study and hyper-parameter settings, are shown in Section 6 of the Supplementary document.

Table 1: Test accuracies (%) on citation networks. We show the mean value and the amount of increase (↑)/decrease (↓) after adding the DC layer. Statistical significance is determined from 50 resampling tests: '*' means statistical significance with $p\leq 0.05$, '**' denotes $p\leq 0.01$. The missing results (—) are due to the huge consumption of GPU memory for large graphs (DenseGCN), gradient explosions (GAT), or non-convergence (GRAND). The best performance of each model after adding the DC layer is marked with †.
Dataset | Model | L=2 | L=4 | L=8 | L=16 | L=32 | L=64 | L=128
Cora | GCN | 81.30 | 79.90 | 75.70 | 25.20 | 20.00 | 31.80 | 20.80
Cora | GCN+ | 82.70 (1.40↑)**† | 82.70 (2.80↑)**† | 82.30 (6.60↑)** | 70.60 (45.4↑)** | 67.80 (47.8↑)** | 66.60 (34.8↑)** | 59.90 (39.1↑)**
Cora | GAT | 82.40 | 80.30 | 57.90 | 31.90 | — | — | —
Cora | GAT+ | 82.60 (0.20↑)† | 80.50 (0.20↑) | 69.70 (11.8↑)** | 66.00 (34.1↑)** | 63.60 (63.6↑)** | 54.60 (54.6↑)** | 45.70 (45.7↑)**
Cora | GRAND | 80.00 | 82.64 | 82.74 | 83.45 | 81.83 | 80.81 | 79.19
Cora | GRAND+ | 81.93 (1.93↑)** | 83.45 (0.81↑)** | 82.95 (0.20↑) | 84.27 (1.32↑)**† | 83.15 (0.71↑)* | 81.52 (0.71↑)** | 80.10 (0.91↑)**
Cora | ResGCN | 76.30 | 77.30 | 76.20 | 77.60 | 73.30 | 31.90 | 31.90
Cora | ResGCN+ | 77.80 (1.50↑)** | 78.70 (1.40↑)** | 78.80 (2.60↑)**† | 78.60 (1.00↑)** | 76.90 (3.60↑)** | 76.80 (44.9↑)** | 33.60 (1.70↑)**
Cora | DenseGCN | 76.60 | 78.50 | 76.00 | — | — | — | —
Cora | DenseGCN+ | 78.00 (1.40↑)** | 78.70 (0.20↑)† | 76.90 (1.40↑)** | — | — | — | —
Cora | GCNII | 76.40 | 81.90 | 81.50 | 84.80 | 84.60 | 85.50 | 85.30
Cora | GCNII+ | 84.70 (8.30↑)** | 84.80 (2.90↑)** | 84.70 (3.20↑)** | 85.20 (0.40↑)** | 85.40 (0.80↑)** | 86.30 (0.80↑)*† | 85.60 (0.30↑)
Citeseer | GCN | 70.20 | 62.50 | 62.90 | 21.00 | 17.90 | 22.90 | 19.80
Citeseer | GCN+ | 72.90 (2.70↑)**† | 67.30 (4.80↑)** | 72.00 (9.10↑)** | 54.70 (33.7↑)** | 50.30 (32.4↑)** | 48.40 (25.5↑)** | 46.60 (26.8↑)**
Citeseer | GAT | 71.70 | 58.60 | 26.60 | 18.10 | — | — | —
Citeseer | GAT+ | 73.00 (1.30↑)**† | 69.50 (10.9↑)** | 47.60 (21.0↑)** | 31.80 (13.7↑)** | 31.30 (31.3↑)** | 30.60 (30.6↑)** | 29.30 (29.3↑)**
Citeseer | GRAND | 71.94 | 72.58 | 73.87 | 75.00 | 75.16 | 72.90 | 69.52
Citeseer | GRAND+ | 72.26 (0.32↑)* | 73.55 (0.97↑)** | 75.16 (1.29↑)** | 75.65 (0.65↑)*† | 75.52 (0.36↑)* | 74.52 (1.62↑)* | 72.26 (2.74↑)**
Citeseer | ResGCN | 67.10 | 66.00 | 63.60 | 65.50 | 62.30 | 18.80 | 18.10
Citeseer | ResGCN+ | 68.00 (0.90↑)**† | 67.60 (1.60↑)** | 66.00 (2.40↑)** | 66.00 (0.50↑)* | 65.80 (3.50↑)** | 24.00 (5.20↑)** | 24.30 (6.20↑)**
Citeseer | DenseGCN | 67.40 | 64.00 | 62.20 | — | — | — | —
Citeseer | DenseGCN+ | 67.80 (0.40↑)*† | 66.60 (2.60↑)** | 64.70 (2.50↑)** | — | — | — | —
Citeseer | GCNII | 66.50 | 67.80 | 69.30 | 71.60 | 73.10 | 71.40 | 70.20
Citeseer | GCNII+ | 72.40 (5.90↑)** | 73.30 (5.5↑)** | 73.80 (4.50↑)** | 73.40 (1.80↑)** | 73.80 (0.70↑)** | 74.60 (3.20↑)**† | 73.90 (3.70↑)**
Pubmed | GCN | 78.50 | 76.50 | 77.30 | 40.90 | 38.20 | 38.10 | 38.70
Pubmed | GCN+ | 79.80 (1.30↑)**† | 79.10 (2.60↑)** | 78.20 (0.90↑)** | 77.40 (36.5↑)** | 76.20 (38.0↑)** | 75.10 (37.0↑)** | 73.00 (34.3↑)**
Pubmed | GAT | 77.40 | 72.20 | 77.80 | 40.70 | — | — | —
Pubmed | GAT+ | 77.90 (0.50↑)* | 77.30 (5.10↑)** | 78.50 (0.70↑)*† | 73.50 (32.8↑)** | 68.20 (68.2↑)** | 66.80 (66.8↑)** | 63.50 (63.5↑)**
Pubmed | GRAND | 77.07 | 77.94 | 78.29 | 79.93 | 79.12 | — | —
Pubmed | GRAND+ | 78.03 (0.96↑)** | 78.34 (0.40↑)* | 80.21 (1.92↑)**† | 80.08 (0.15↑)** | 79.32 (0.20↑) | — | —
Pubmed | ResGCN | 76.30 | 77.30 | 76.20 | 77.60 | 73.30 | 31.90 | 31.90
Pubmed | ResGCN+ | 77.80 (1.50↑)** | 78.70 (1.40↑)** | 78.80 (2.60↑)*† | 78.60 (1.00↑)** | 76.90 (3.60↑)** | 76.80 (44.90↑)** | 32.00 (0.10↑)
Pubmed | DenseGCN | 75.80 | 76.10 | 75.80 | — | — | — | —
Pubmed | DenseGCN+ | 76.10 (0.30↑) | 76.70 (0.60↑)* | 77.50 (1.70↑)**† | — | — | — | —
Pubmed | GCNII | 77.30 | 78.80 | 79.50 | 79.70 | 79.90 | 79.80 | 79.70
Pubmed | GCNII+ | 78.40 (1.10↑)** | 80.10 (1.30↑)**† | 80.00 (0.60↑)* | 80.10 (0.30↑)† | 80.00 (0.20↑) | 80.00 (0.20↑) | 80.10 (0.40↑)*†

3.2 Experimental results on graph node classification

We postulate that by mitigating the over-smoothing issue, we can leverage the depth of GNN models to effectively capture complex feature representations in graph data. As shown in Table 1, we investigate the graph node classification accuracy as we increase the number of GNN layers for six benchmark GNN models and their corresponding plug-in models with the DC-layer (indicated by '+' at the end of each GNN model name). The results demonstrate that: (1) the new GNN models generated from the GNN-PDE-COV framework achieve SOTA on Cora (86.30% by GCNII+), Citeseer (75.65% by GRAND+), and Pubmed (80.10% by GCNII+); (2) all of the new GNN models outperform their original counterparts with significant improvements in accuracy; (3) the new GNN models exhibit less sensitivity to the increase of model depth compared to current GNN models; and (4) the new GNN models are also effective in resolving the gradient explosion problem [30] (e.g., gradient explosion occurs when training GAT with more than 16 layers on all involved datasets, while our GAT+ maintains reasonable learning performance even at 128 layers).

It is important to note that due to the nature of the graph diffusion process, graph embeddings from all GNN models (including ours) will eventually become identical after a sufficiently large number of layers [11]. However, the selective diffusion mechanism (i.e., penalizing excessive diffusion across communities) provided by our GNN-PDE-COV framework allows us to control the diffusion patterns and optimize them for specific graph learning applications.

3.3 Application for uncovering the propagation mechanism of pathological events in AD

First, we evaluate the prediction accuracy between the ground truth and the estimated concentration values by our X-FlowNet and six benchmark GNN methods. The statistics of MAE (mean absolute error) by X-FlowNet, GCN, GAT, GRAND, ResGCN, DenseGCN and GCNII, at different noise levels on the observed concentration levels, are shown in Fig. 5 (a). It is clear that our X-FlowNet consistently outperforms the other GCN-based models in all imaging modalities.

Second, we have evaluated the potential for disease risk prediction and presented the results in Table 4 of the Supplementary document, where our GNN-PDE-COV model not only achieved the highest diagnostic accuracy but also demonstrated a significant improvement (paired t-test, $p<0.001$) in disease risk prediction compared to other methods. These results suggest that our approach holds great clinical value for early disease diagnosis.

Third, we examine the spreading flows of tau aggregates in the CN (cognitively normal) and AD groups. As shown by the inward and outward flows in Fig. 5(b), there is a significantly larger amount of tau spreading between sub-cortical regions and the entorhinal cortex in CN (an early sign of AD onset), while the volume of subcortical-entorhinal tau spreading is greatly reduced in the late stage of AD. This is consistent with current clinical findings that tau pathology starts from sub-cortical regions and then switches to cortical-cortical propagation as the disease progresses [28]. However, our Tau-FlowNet offers a fine-granularity brain mapping of region-to-region spreading flows over time, which provides a new window to understand the tau propagation mechanism in AD etiology [14].

Figure 5: (a) Prediction accuracy of X-FlowNet and six benchmark GNN models w.r.t. various noise levels. (b) The subcortical→cortical tau flows are profound in CN. But in AD, there is a diminished extent of such flows.

4 Conclusion

In this work, we present the GNN-PDE-COV framework to re-think and re-design GNN models with great mathematical insight. On top of this, we devise a selective inductive bias to address the over-smoothing problem in GNNs and develop a new GNN model to predict pathology flows in-vivo from longitudinal neuroimages. Future work may involve exploring innovative graph regularization techniques and conducting further validation on a broader range of graph-based learning tasks.

Supplementary

5 Solving variational problems: From objective functional to E-L equations

5.1 Step-by-step derivation of min-max optimization in Section 2.2.1

By substituting Eq. 2 into Eq. 1 in the main manuscript, we can obtain the objective function of $z$ (we temporarily drop the subscript $i$ for clarity):

$\mathcal{J}(z)=\max\limits_{|z|\leq\mathbf{1}}\left\|\frac{\lambda}{2}z\nabla_{\mathcal{G}}x\right\|_{2}^{2}+\lambda z\nabla_{\mathcal{G}}\Big(x^{0}-\frac{\lambda}{2}z\nabla_{\mathcal{G}}x\Big)$ (8)
$\;\;\;=\max\limits_{|z|\leq\mathbf{1}}-\frac{\lambda^{2}}{4}z\nabla_{\mathcal{G}}x\,z\nabla_{\mathcal{G}}x+\lambda z\nabla_{\mathcal{G}}x^{0}$ (9)

Next, we convert Eq. 9 into a minimization problem as follows:

$z=\arg\min\limits_{|z|\leq\mathbf{1}}\;z\nabla_{\mathcal{G}}x\,z\nabla_{\mathcal{G}}x-\frac{4}{\lambda}z\nabla_{\mathcal{G}}x^{0}$ (10)

By setting the derivative with respect to $z$ to zero, we have the following equation:

$\nabla_{\mathcal{G}}x\,z\nabla_{\mathcal{G}}x=\frac{4}{\lambda}\nabla_{\mathcal{G}}x^{0}$ (11)

Since $z$ might be in a high-dimensional space, solving such a large system of linear equations under the constraint $|z|\leq 1$ is oftentimes computationally challenging. In order to find a practical solution for $z$ that satisfies the constrained minimization problem in Eq. 10, we resort to the majorization-minimization (MM) method [16]. First, we define:

$M(z)=z\nabla_{\mathcal{G}}x\,z\nabla_{\mathcal{G}}x-\frac{4}{\lambda}z\nabla_{\mathcal{G}}x^{0}$ (12)

By setting $z^{l}$ as the point of coincidence, we can find a separable majorizer of $M(z)$ by adding the non-negative function

$(z-z^{l})^{\intercal}(\beta I-\nabla_{\mathcal{G}}x\nabla_{\mathcal{G}}x^{\intercal})(z-z^{l})$ (13)

to $M(z)$, where $\beta$ is greater than or equal to the maximum eigenvalue of $\nabla_{\mathcal{G}}x\nabla_{\mathcal{G}}x^{\intercal}$. Note that, to unify the format, we use the matrix transpose property in Eq. 13. Therefore, a majorizer of $M(z)$ is given by:

$M(z)+(z-z^{l})^{\intercal}(\beta I-\nabla_{\mathcal{G}}x\nabla_{\mathcal{G}}x^{\intercal})(z-z^{l})$ (14)

And, using the MM approach, we can obtain the update equation for $z$ as follows:

$z^{l+1}=\arg\min\limits_{|z|\leq 1}\Big(M(z)+(z-z^{l})^{\intercal}(\beta I-\nabla_{\mathcal{G}}x\nabla_{\mathcal{G}}x^{\intercal})(z-z^{l})\Big)$
$\;\;\;\;\;=\arg\min\limits_{|z|\leq 1}\Big(\beta z^{\intercal}z-2\big(\nabla_{\mathcal{G}}(\tfrac{2}{\lambda}x^{0}-\nabla_{\mathcal{G}}xz^{l})+\beta z^{l}\big)^{\intercal}z\Big)$
$\;\;\;\;\;=\arg\min\limits_{|z|\leq 1}\Big(z^{\intercal}z-2\big(\tfrac{1}{\beta}\nabla_{\mathcal{G}}(\tfrac{2}{\lambda}x^{0}-\nabla_{\mathcal{G}}xz^{l})+z^{l}\big)^{\intercal}z\Big)$
$\;\;\;\;\;=\arg\min\limits_{|z|\leq 1}\big(z^{\intercal}z-2b^{\intercal}z\big)$ (15)

where $b=z^{l}+\frac{1}{\beta}\nabla_{\mathcal{G}}\big(\frac{2}{\lambda}x^{0}-\nabla_{\mathcal{G}}xz^{l}\big)$.

Then, the next step is to find $z\in\mathbb{R}^{N}$ that minimizes $z^{\intercal}z-2b^{\intercal}z$ subject to the constraint $|z|\leq 1$. Let us first consider the simplest case where $z$ is a scalar:

$\arg\min\limits_{|z|\leq 1}\;z^{2}-2bz$ (16)

The minimum of $z^{2}-2bz$ is at $z=b$. If $|b|\leq 1$, then the solution is $z=b$. If $|b|\geq 1$, then the solution is $z=\mathrm{sign}(b)$. We can define the clipping function as:

$\mathrm{clip}(b,1):=\begin{cases}b&|b|\leq 1\\ \mathrm{sign}(b)&|b|\geq 1\end{cases}$ (17)

as illustrated in the middle of Fig. 3 in the main text. We can then write the solution to Eq. 16 as $z=\mathrm{clip}(b,1)$.

Note that the vector case in Eq. 15 is separable – the elements of $z$ are uncoupled, so the constrained minimization can be performed element-wise. Therefore, an update equation for $z$ is given by:

$z^{l+1}=\mathrm{clip}\Big(z^{l}+\frac{1}{\beta}\nabla_{\mathcal{G}}\big(\frac{2}{\lambda}x^{0}-\nabla_{\mathcal{G}}xz^{l}\big),\,1\Big)$ (18)

where $l$ denotes the index of the network layer; the representation at the $(l+1)^{th}$ layer is given by Eq. 1 in the main manuscript. Because the optimization problem is convex, the iteration will converge from any initialization. We may choose, say, $z^{0}=0$. We call this the iterative diffusion-clip (DC) algorithm.

This algorithm can also be written as

$x^{l+1}=x^{0}-\frac{\lambda}{2}\nabla_{\mathcal{G}}^{\intercal}z^{l}$
$z^{l+1}=\mathrm{clip}\Big(z^{l}+\frac{2}{\beta\lambda}\nabla_{\mathcal{G}}x^{l+1},\,1\Big).$ (19)

By scaling $z$ with a factor of $\lambda/2$, we have the following equivalent formulation:

$x^{l+1}=x^{0}-\nabla_{\mathcal{G}}^{\intercal}z^{l}$
$z^{l+1}=\mathrm{clip}\Big(z^{l}+\frac{1}{\beta}\nabla_{\mathcal{G}}x^{l+1},\,\frac{\lambda}{2}\Big)$ (20)

We summarize the process of the diffusion-clip (DC) layer in Algorithm 1 (it is similar to the iterative shrinkage threshold algorithm [12]):

Algorithm 1 DC layer process

The objective function:

$\mathcal{J}(x)=\min\limits_{x}\left(\left\|x-x^{(0)}\right\|_{2}^{2}+\mathcal{J}_{TV}(x)\right)$

can be minimized by alternating the following two steps:

$x^{l}=x^{0}-\nabla_{\mathcal{G}}^{\intercal}z^{l-1}$

$z^{l}=\mathrm{clip}\left(z^{l-1}+\frac{1}{\beta}\nabla_{\mathcal{G}}x^{l},\frac{\lambda}{2}\right)=\mathrm{clip}\left(z^{l-1}+\frac{2}{\beta\lambda}\nabla_{\mathcal{G}}x^{l},1\right)$

for $l\geq 1$ with $z^{0}=0$ and $\beta\geq\mathrm{maxeig}(\nabla_{\mathcal{G}}x^{\intercal}\nabla_{\mathcal{G}}x)$
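A direct NumPy transcription of Algorithm 1 might look as follows (toy graph and scalar node features, all made up; $\nabla_{\mathcal{G}}$ and its transpose are realized explicitly so the two alternating steps can be written verbatim, and $\beta$ is set from the largest eigenvalue of $\nabla_{\mathcal{G}}^{\intercal}\nabla_{\mathcal{G}}$ as one conservative choice).

import numpy as np

rng = np.random.default_rng(3)
N = 6
W = rng.random((N, N)); W = (W + W.T) / 2; np.fill_diagonal(W, 0.0)
x0 = rng.standard_normal(N)                        # initial node features (one scalar per node)

grad = lambda x: W * (x[:, None] - x[None, :])     # (grad x)_ij = w_ij (x_i - x_j)
grad_T = lambda z: (W * (z - z.T)).sum(axis=1)     # adjoint grad^T applied to a link field

lam = 0.5
# beta must dominate the largest eigenvalue of grad^T grad; that operator equals
# 2 * (diag(sum_j w_ij^2) - W∘W) for this discretization
beta = np.linalg.eigvalsh(2.0 * (np.diag((W ** 2).sum(axis=1)) - W ** 2)).max() + 1e-6

x, z = x0.copy(), np.zeros((N, N))
for _ in range(50):                                # alternate the two steps of Algorithm 1
    x = x0 - grad_T(z)                             # x^l = x^0 - grad^T z^{l-1}
    z = np.clip(z + grad(x) / beta, -lam / 2, lam / 2)   # z^l = clip(z^{l-1} + grad(x^l)/beta, lam/2)
print(np.round(x, 3))                              # TV-regularized, edge-preserving smoothing of x0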

5.2 The step-by-step derivation of min-max optimization schema in Section 2.2.2

According to the introduction of Section 2.2.2 (Eq. 4 and Eq. 5) in the main manuscript, we summarize the following equations:

$\left\{\begin{array}{l}\frac{dx}{dt}+\mathrm{div}(q)=0\\ u_{i}=\phi(x_{i})\\ q=\alpha\otimes\nabla u\\ \Delta u=\mathrm{div}(\nabla u)\end{array}\right.\;\xrightarrow{\ \text{derive}\ }\;\left\{\begin{array}{l}\frac{dx}{dt}=-\mathrm{div}(q)\\ \frac{du}{dt}=-\phi^{-1}\mathrm{div}(q)\\ \frac{du}{dt}=-\phi^{-1}\mathrm{div}(\alpha\otimes q)\\ \frac{du}{dt}=-\phi^{-1}(\alpha\otimes\Delta u)\end{array}\right.$ (21)

Since the PDE in Eq. 5 in the main manuscript is equivalent to the E-L equation of the quadratic functional $\mathcal{J}(u)=\min_{u}\int_{\mathcal{G}}\alpha\otimes|\nabla_{\mathcal{G}}u|^{2}du$ (after taking $\phi$ away), we propose to replace the $\ell_{2}$-norm integral functional $\mathcal{J}(u)$ with the TV-based counterpart

$\mathcal{J}_{TV}(u)=\min\limits_{u}\int_{\mathcal{G}}\alpha\otimes|\nabla_{\mathcal{G}}u|\,du$ (22)

We then introduce an auxiliary matrix $f$ to lift the non-differentiability barrier, and reformulate the TV-based functional as a dual min-max functional

$\mathcal{J}_{TV}(u,f)=\min\limits_{u}\max\limits_{f}\int_{\mathcal{G}}\alpha\otimes f(\nabla_{\mathcal{G}}u)\,du$ (23)

where we maximize $f$ such that $\mathcal{J}_{TV}(u,f)$ is close enough to $\mathcal{J}_{TV}(u)$. Using Gâteaux variations, we assume $u\to u+\varepsilon a$ and $f\to f+\varepsilon b$, with the directional derivatives in the directions $a$ and $b$ defined as $\frac{d\mathcal{J}}{d\varepsilon}(u+\varepsilon a)\big|_{\varepsilon\to 0}$ and $\frac{d\mathcal{J}}{d\varepsilon}(f+\varepsilon b)\big|_{\varepsilon\to 0}$. Given the functional $\mathcal{J}_{TV}(u,f)$, its Gâteaux variation is formulated by:

$\mathcal{J}_{TV}(u+\varepsilon a,f+\varepsilon b)=\int\alpha\otimes[(f+\varepsilon b)\cdot(\nabla u+\varepsilon\nabla a)]\,du$
$\Rightarrow\;\frac{\partial\mathcal{J}}{\partial\varepsilon}\Big|_{\varepsilon\to 0}=\int\alpha\otimes[(f\cdot\nabla a)+(\nabla u\cdot b)]\,du$
$\Rightarrow\;\frac{\partial\mathcal{J}}{\partial\varepsilon}\Big|_{\varepsilon\to 0}=\alpha\otimes f\cdot a-\int\alpha\otimes(a\cdot\nabla f)\,du+\int\alpha\otimes(b\cdot\nabla u)\,du$ (24)

Since we assume $u$ is given at the boundary (Dirichlet boundary condition), the boundary term $\alpha\otimes f\cdot a$ can be dropped. After that, the derivative of $\mathcal{J}_{TV}(u,f)$ becomes:

$\frac{\partial\mathcal{J}}{\partial\varepsilon}\Big|_{\varepsilon\to 0}=-\int\alpha\otimes(\nabla f\cdot a+\nabla u\cdot b)$ (25)

Since the dummy functions $a$ and $b$ are related to $u$ and $f$ respectively, the E-L equation from the Gâteaux variations in Eq. 25 leads to two coupled PDEs:

$\left\{\begin{array}{l}\max\limits_{f}\ \frac{df}{dt}=\alpha\otimes\nabla_{\mathcal{G}}u\\ \min\limits_{u}\ \frac{du}{dt}=\alpha\otimes\mathrm{div}(f)\end{array}\right.$ (26)

Note that we use the adjoint relation $\mathrm{div}(f)=-\nabla_{\mathcal{G}}^{\intercal}f$ to approximate the discretization of $\nabla f$ [22], which allows us to link the minimization of $u$ to the classic graph diffusion process.
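This adjoint relation is easy to verify numerically on a toy graph: with $(\nabla_{\mathcal{G}}u)_{ij}=w_{ij}(u_{i}-u_{j})$ and the natural inner products over links and nodes, $\langle\nabla_{\mathcal{G}}u,f\rangle=\langle u,\nabla_{\mathcal{G}}^{\intercal}f\rangle=\langle u,-\mathrm{div}(f)\rangle$. A small NumPy check with made-up data follows.

import numpy as np

rng = np.random.default_rng(4)
N = 5
W = rng.random((N, N)); W = (W + W.T) / 2; np.fill_diagonal(W, 0.0)
u = rng.standard_normal(N)
f = rng.standard_normal((N, N))

grad_u = W * (u[:, None] - u[None, :])        # (grad u)_ij = w_ij (u_i - u_j)
grad_T_f = (W * (f - f.T)).sum(axis=1)        # (grad^T f)_i = sum_j w_ij (f_ij - f_ji)

lhs = np.sum(grad_u * f)                      # inner product over links
rhs = np.sum(u * grad_T_f)                    # inner product over nodes
print(np.isclose(lhs, rhs))                   # True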

6 Experimental details

6.1 Implementation details

6.1.1 Hyperparameters & training details

Table 2 lists the detailed parameter setting for several GNN-based models, including X-FlowNet, PDENet, GCN, GAT, ResGCN, DenseGCN and GCNII.

In the node classification experiments, we set the output dimension to be the number of classes. We adopt the public fixed split [46] to separate these datasets into training, validation, and test sets. We use the accuracy, precision and F1-score of node classification as the evaluation metrics.

For the ADNI dataset prediction experiment, we set the input and output dimensions to be the same, as the number of brain nodes cannot be altered. We use 5-fold cross-validation to evaluate the performance of different methods and measure their prediction accuracy using the mean absolute error (MAE). We also conduct an ablation study using a two-step approach. First, we train a model (MLP+GNN), shown in the left panel of Fig. 4(b) in the main manuscript, to predict the potential energy field (PEF) based on the transport equation, then compute the flows using Eq. 6, followed by a GCN-based model to predict the future concentration level of AD-related pathological burdens. Since the deep model in this two-step approach is also formalized from the PDE, we refer to this degraded version as PDENet.

In addition, we conduct a prediction of the risk of developing AD using the baseline scan, which can be regarded as a graph classification experiment. This experiment only uses 2 GCN layers with a hidden dimension of 64 for all methods, while the remaining parameters follow the node classification experiment (Table 2, top).

In this work, all experiments are conducted on a server: Intel(R) Xeon(R) Gold 5220R CPU @ 2.20GHz, NVIDIA RTX A5000. The source code is open on anonymous GitHub (https://anonymous.4open.science/r/GNN-PDE-COV-FBBD/) for the sake of reproducibility.

6.1.2 Data pre-processing on ADNI dataset.

In total, 1,291 subjects are selected from the ADNI [32] dataset, each having diffusion-weighted imaging (DWI) scans and longitudinal amyloid, FDG, cortical thickness (CoTh) and tau PET scans (2-5 time points). The neuroimage processing consists of the following major steps:

  • We segment the T1-weighted image into white matter, gray matter, and cerebrospinal fluid using the FSL software [24]. On top of the tissue probability map, we parcellate the cortical surface into 148 cortical regions (frontal, insula, temporal, occipital, parietal, and limbic lobes) and 12 sub-cortical regions (left and right hippocampus, caudate, thalamus, amygdala, globus pallidus, and putamen) using the Destrieux atlas [13] (yellow arrows in Fig. 6). We then convert each DWI scan to a diffusion tensor image (DTI) [38].

  • We apply surface seed-based probabilistic fiber tractography [17] to the DTI data, producing a $160\times 160$ anatomical connectivity matrix (white arrows in Fig. 6). Note that the weight of the anatomical connectivity is defined as the number of fibers linking two brain regions normalized by the total number of fibers in the whole brain ($\Delta$ for graph diffusion in X-FlowNet); a minimal sketch of this normalization is given after this list.

  • Following the region parcellation, we calculate the regional concentration level (with the cerebellum as the reference) of the amyloid, FDG, CoTh, and tau pathologies for each brain region (red arrows in Fig. 6), yielding the input $x\in\mathcal{R}^{160}$ for training X-FlowNet.
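As referenced in the tractography step above, here is a minimal sketch of the connectivity weighting, assuming `fiber_count` is a hypothetical $160\times 160$ matrix of tractography fiber counts; the convention of dividing by the total whole-brain fiber count follows the description above.

```python
import numpy as np

def normalized_connectivity(fiber_count):
    """Edge weight = fiber count between two regions / total fiber count in the brain."""
    W = np.asarray(fiber_count, dtype=float)
    total = W.sum()                       # total number of fibers in the whole brain
    return W / total if total > 0 else W  # guard against an empty tractogram
```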

Following the clinical outcomes, we partition the subjects into cognitively normal (CN), early-stage mild cognitive impairment (EMCI), late-stage mild cognitive impairment (LMCI), and AD groups. To facilitate population counts, we regard CN and EMCI as the "CN-like" group and LMCI and AD as the "AD-like" group. Table 3 summarizes the statistics of the two datasets.

Table 2: Parameter settings on the citation networks (top) and the ADNI data (bottom). $M$ denotes the feature dimension and $C$ the number of classes. For the Cora dataset, we set $i=4$ when the number of layers $L=2$, $i=8$ when $L=4$, and $i=10$ when $L=8,16,32,64,128$. For the Citeseer dataset, we set $i=4$ when $L=2$, $i=8$ when $L=4$, and $i=11$ when $L=8,16,32,64,128$. For the Pubmed dataset, we set $i=4$ when $L=2$ and $i=8$ when $L=4,8,16,32,64,128$. The hidden dimension of the $(l-1)^{th}$ layer is twice that of the $l^{th}$ layer, i.e., the width is halved layer by layer. Taking Cora with 8 layers as an example, the layer dimensions are: 1433 → 1024 → 512 → 256 → 128 → 64 → 32 → 16 → 7. After exceeding 8 layers, the number of hidden layers is doubled according to the total number of network layers.
Citation networks (top):
Algorithm | Optimizer | Learning rate | Weight decay | Hidden layer | Dropout | Epoch
GCN | Adam | 0.01 | $5\times 10^{-4}$ | $M\to 2^{i}\to\dots\to 2^{4}\to C$ | 0.5 | 1500
GAT | Adam | 0.001 | $5\times 10^{-4}$ | head=8, $M\to 2^{i}\to\dots\to C$ | 0.6 | 2000
RGCN | Adam | 0.005 | $5\times 10^{-4}$ | hidden dimension=64 | 0.1 | 2500
DGCN | Adam | 0.001 | $5\times 10^{-4}$ | hidden dimension=64 | 0.1 | 2500
GRAND | Adam | 0.01 | $5\times 10^{-4}$ | hidden dimension=16 | 0.5 | 200
GCNII | Adam | 0.005 | $5\times 10^{-4}$ | hidden dimension=128 | 0.6 | 2000

ADNI data (bottom):
Algorithm | Optimizer | Learning rate | Weight decay | Hidden layer | Dropout | Epoch
GCN | Adam | 0.001 | $5\times 10^{-4}$ | hidden dimension=16 | 0.2 | 500
GAT | Adam | 0.001 | $5\times 10^{-4}$ | head=8, hidden dimension=4 | 0.5 | 800
RGCN | Adam | 0.001 | $5\times 10^{-4}$ | hidden dimension=16 | 0.1 | 500
DGCN | Adam | 0.01 | $5\times 10^{-4}$ | hidden dimension=8 | 0.1 | 500
GCNII | Adam | 0.005 | $5\times 10^{-4}$ | hidden dimension=16 | 0.6 | 1500
GRAND | Adam | 0.01 | $5\times 10^{-4}$ | hidden dimension=16 | 0.5 | 500
X-FlowNet | Adam | 1e-4 / 3e-3 | $1\times 10^{-5}$ | hidden dimension=16 | 0.5 | 500
PDENet | Adam | 0.01 | $1\times 10^{-5}$ | hidden dimension=16 | 0.5 | 500
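To make the layer-width schedule described in the Table 2 caption concrete, here is a minimal sketch for the Cora 8-layer setting (feature dimension 1433, $i=10$, 7 classes); the function name is illustrative.

```python
def hidden_dims(feature_dim=1433, i=10, num_classes=7):
    """Reproduce the halving width schedule from the Table 2 caption (Cora, 8 layers)."""
    dims = [feature_dim]
    width = 2 ** i                  # first hidden width is 2^i
    while width >= 2 ** 4:          # halve until reaching 2^4
        dims.append(width)
        width //= 2
    dims.append(num_classes)        # final classification layer
    return dims

print(hidden_dims())                # [1433, 1024, 512, 256, 128, 64, 32, 16, 7]
```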


Figure 6: General workflows for processing T1-weighted image (yellow arrows), diffusion-weighted image (white arrows), and PET images (red arrows). The output is shown at the bottom right, including the brain network, and regional concentration level of amyloid, FDG, CoTh and tau aggregates.
Table 3: Dataset statistics.
Node classification (Citation):
Dataset | Classes | Nodes | Edges | Features
Cora | 7 | 2708 | 5429 | 1433
Citeseer | 6 | 3327 | 4732 | 3703
Pubmed | 3 | 19717 | 44338 | 500

Application on flow prediction (ADNI):
Features | # of subjects (CN/AD)
Amyloid (160) | 304/83
Tau (160) | 124/37
FDG (160) | 211/63
Cortical thickness (160) | 359/110

6.2 Experiments on node classification

Fig. 7 presents the performance (accuracy, precision, and F1-score) across different network depths for node classification, comparing each benchmark GNN model (dashed lines) with its counterpart derived from our GNN-PDE-COV framework (solid lines); each row corresponds to a specific GNN instance. Our GNN-PDE-COV models consistently outperform their counterparts across different depths, with clear gains in accuracy, precision, and F1-score, and achieve the highest accuracy on all three datasets. Overall, these results demonstrate the state-of-the-art performance of our GNN-PDE-COV framework in graph node classification.

Figure 7: The performance of node classification with respect to various GNN layers (horizontal axis) on six models. Note: dotted line: baseline, solid line: GNN-PDE-COV, blue: Cora, purple: Citeseer, red: Pubmed.

The effect of anti-smoothing by the clip operation is shown in Fig. 8. To set the stage, we focus on links that connect two nodes with different categorical labels, called inter-class links; 2,006 such links are selected from Cora, 2,408 from Citeseer, and 17,518 from Pubmed. For each inter-class link, we calculate the node-to-node similarity as Pearson's correlation between the two associated graph embedding vectors (i.e., the learned feature representations for node classification), for the benchmark methods (in red) and the counterpart GNN models derived from our GNN-PDE-COV framework (in green). We find that (1) more than 70% of nodes are associated with inter-class links, which confirms the over-smoothing hypothesis in Fig. 1 of the main manuscript; and (2) our novel GNN models learn feature representations that better preserve the discriminative power for node classification, as indicated by the distribution of node-to-node similarities shifting toward anti-correlation.
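A minimal sketch of this inter-class similarity analysis, assuming hypothetical arrays `emb` (node embeddings), `labels` (class labels), and `edge_index` (a 2 x E array of links); histogramming the returned values gives distributions of the kind shown in Fig. 8.

```python
import numpy as np

def inter_class_similarities(emb, labels, edge_index):
    """Pearson's correlation between embeddings of the endpoints of inter-class links."""
    sims = []
    for i, j in edge_index.T:
        if labels[i] != labels[j]:                    # keep inter-class links only
            a = emb[i] - emb[i].mean()
            b = emb[j] - emb[j].mean()
            denom = np.linalg.norm(a) * np.linalg.norm(b)
            if denom > 0:
                sims.append(float(a @ b / denom))     # centered cosine = Pearson's r
    return np.asarray(sims)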

Figure 8: The distribution of the node-to-node similarities (measured by Pearson’s correlation between embedding vectors) by Benchmark methods (in red) and our GNN-PDE-COV (in green) in Cora (left), Citeseer (middle), and Pubmed (right).

6.3 Application on uncovering the propagation mechanism of pathological events in AD

First, we examine the prediction accuracy for each modality of concentration (tau, amyloid, FDG, CoTh) at different noise levels. Specifically, to evaluate the robustness of our X-FlowNet model to noise, we add uncorrelated additive Gaussian noise with standard deviations ranging from 0.02 to 1 to the observed concentration levels of tau, amyloid, FDG, and CoTh, and evaluate the prediction accuracy (MAE) using 5-fold cross-validation. The results in Fig. 9 indicate that X-FlowNet is less sensitive to noise in the imaging features than all other GNN counterparts.
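A minimal sketch of this noise-robustness protocol, where `train_and_predict` is a hypothetical stand-in for fitting X-FlowNet (or a baseline GNN) on one fold; the specific noise levels listed are illustrative points within the stated range.

```python
import numpy as np
from sklearn.model_selection import KFold

def noisy_cv_mae(X, Y, train_and_predict, noise_stds=(0.02, 0.1, 0.5, 1.0), seed=0):
    """5-fold cross-validated MAE after adding uncorrelated Gaussian noise to the inputs."""
    rng = np.random.default_rng(seed)
    results = {}
    for std in noise_stds:
        X_noisy = X + rng.normal(0.0, std, size=X.shape)   # uncorrelated additive noise
        maes = []
        for tr, te in KFold(n_splits=5, shuffle=True, random_state=seed).split(X_noisy):
            pred = train_and_predict(X_noisy[tr], Y[tr], X_noisy[te])
            maes.append(np.mean(np.abs(pred - Y[te])))      # mean absolute error per fold
        results[std] = float(np.mean(maes))
    return results
```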

Figure 9: The performance of 5-fold cross-validation for predicting the future concentration (top to bottom: Tau, Amyloid, FDG and CoTh) level by (1) GCN, (2) GAT, (3) GCNII, (4) ResGCN, (5) DenseGCN, (6) GRAND, (7) PDENet (used in ablation study), and (8) our X-FlowNet.

Second, we conduct an ablation study comparing X-FlowNet with PDENet (marked as #7 in Fig. 9). Our model, built on a GAN architecture with a TV constraint to avoid over-smoothing, integrates the two steps of estimating the PEF and uncovering the spreading flows into a unified neural network, resulting in significantly improved prediction accuracy compared to PDENet.

Third, we perform a disease-risk prediction experiment, which can be regarded as a graph classification problem. Given the baseline amyloid, tau, FDG, and CoTh scans, we evaluate the accuracy, precision, and F1-score of various models in forecasting the risk of developing AD. We consider two dichotomous settings: AD vs. CN, and AD/LMCI vs. CN/EMCI. The mean results of 5-fold cross-validation are shown in Table 4. Our GNN-PDE-COV outperforms all other methods in accuracy, precision, and F1-score, with significant improvements indicated by an asterisk ('*') at the 0.001 significance level.
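A minimal sketch of the significance test behind the asterisks in Table 4: a paired t-test on per-fold scores of a baseline versus its GNN-PDE-COV counterpart. The fold values below are placeholders, not results from the paper.

```python
from scipy.stats import ttest_rel

baseline_f1 = [0.83, 0.84, 0.82, 0.85, 0.83]   # e.g., 5-fold F1 of a baseline (placeholder)
ours_f1     = [0.88, 0.89, 0.88, 0.90, 0.87]   # e.g., 5-fold F1 of its '+' counterpart (placeholder)

t_stat, p_value = ttest_rel(ours_f1, baseline_f1)   # paired t-test across folds
significant = p_value < 0.001                       # criterion used for '*' in Table 4
```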

Fourth, we examine the propagation pattern of tau spreading flows on an individual basis (Fig. 10). First, we visualize the top flows (ranked by flow volume) uncovered in a CN subject (Fig. 10(a)). Subcortex-to-cortex flows are the predominant pattern: most tau aggregates spread from subcortical regions (globus pallidus, hippocampus, and putamen) to the temporal, limbic, parietal, and insula lobes. Notably, the inferior temporal gyrus ($t_6$) and entorhinal cortex ($t_8$) are actively involved in the subcortex-to-cortex flows, which are footprints of early-stage tau propagation frequently reported in pathology studies [28; 40]. Second, we visualize the top flows uncovered in an AD subject (Fig. 10(b)). Here the propagation of tau is restricted to the brain cortex, spreading mainly from temporal lobe regions to other regions (such as the frontal, limbic, and occipital lobes), which aligns with clinical and pathology findings that the predominant amount of tau aggregates propagates throughout the brain cortex in the late stage of AD.

Table 4: The performance of disease risk prediction. Note: RGCN denotes ResGCN and DGCN denotes DenseGCN; '*' denotes a significant improvement (paired t-test: $p<0.001$). Blue: Tau, red: amyloid, orange: FDG, green: CoTh.
Columns (all values in %): GCN | GCN+ | GAT | GAT+ | GCNII | GCNII+ | RGCN | RGCN+ | DGCN | DGCN+ | GRAND | GRAND+

Tau
AD/LMCI vs. CN/EMCI
Precision 80.15 90.03(*) 69.91 86.18(*) 83.93 90.03(*) 84.64 89.46(*) 84.03 91.58(*) 87.95 88.22(*)
Accuracy 82.30 88.74(*) 81.05 87.50(*) 83.79 88.75(*) 86.03 90.00(*) 85.54 91.25(*) 88.75 90.12(*)
F1-score 75.55 84.49(*) 72.87 84.72(*) 78.82 84.45(*) 83.15 88.54(*) 82.45 91.39(*) 88.14 89.44(*)
AD vs. CN
Precision 89.29 91.92(*) 87.26 90.13(*) 83.65 88.52(*) 92.61 95.72(*) 92.61 95.91(*) 91.77 95.76(*)
Accuracy 86.64 90.91(*) 84.86 88.41(*) 76.84 86.36(*) 91.07 95.45(*) 91.07 95.65(*) 90.91 95.45(*)
F1-score 85.64 90.26(*) 83.99 87.16(*) 71.51 84.68(*) 90.45 95.32(*) 90.45 95.55(*) 88.86 95.38(*)

Amyloid
AD/LMCI vs. CN/EMCI
Precision 76.36 83.78(*) 67.73 71.79(*) 60.87 60.01() 72.53 83.21(*) 74.92 60.17(*) 79.00 79.93(*)
Accuracy 76.40 79.44(*) 75.43 77.57(*) 74.31 76.64(*) 75.99 78.50(*) 76.92 77.57(*) 80.37 81.31(*)
F1-score 70.33 72.58(*) 67.66 69.39(*) 63.57 67.31(*) 70.66 70.67(*) 72.68 67.77(*) 79.25 79.63(*)
AD vs. CN
Precision 81.58 88.37(*) 81.54 87.98(*) 70.59 79.98(*) 85.75 93.09(*) 83.87 90.56(*) 65.53 89.62(*)
Accuracy 80.77 88.10(*) 80.78 88.10(*) 75.02 81.24(*) 85.56 92.86(*) 85.29 90.48(*) 80.95 87.80(*)
F1-score 78.14 87.68(*) 78.07 87.98(*) 65.87 77.42(*) 85.34 92.92(*) 82.30 90.27(*) 72.43 88.22(*)

FDG
AD/LMCI vs. CN/EMCI
Precision 68.43 67.20(*) 55.86 59.29(*) 60.08 70.94(*) 50.45 55.14(*) 50.45 55.14(*) 51.38 56.25(*)
Accuracy 73.17 76.00(*) 72.17 77.00(*) 71.78 74.54(*) 70.98 74.26(*) 70.98 74.26(*) 71.10 75.00(*)
F1-score 63.94 68.15(*) 62.15 66.99(*) 61.02 69.07(*) 58.96 63.29(*) 58.96 63.29(*) 59.42 64.29(*)
AD vs. CN
Precision 81.11 87.25(*) 61.90 62.33(*) 74.31 81.06(*) 59.77 80.57(*) 66.84 81.77(*) 70.91 72.24(*)
Accuracy 82.17 84.62(*) 72.82 78.95(*) 79.55 82.05(*) 73.35 79.58(*) 73.87 80.11(*) 84.21 86.32(*)
F1-score 79.40 82.04(*) 64.23 69.66(*) 73.88 80.99(*) 62.77 75.98(*) 63.92 76.58(*) 76.99 78.06(*)

CoTh
AD/LMCI vs. CN/EMCI
Precision 74.85 74.71(*) 62.63 67.15(*) 62.63 74.71(*) 62.63 68.77(*) 62.63 64.59(*) 63.81 68.77(*)
Accuracy 80.68 82.32(*) 79.10 79.34(*) 79.10 80.37(*) 79.10 82.93(*) 79.10 80.37(*) 79.88 82.93(*)
F1-score 73.55 75.93(*) 69.89 70.44(*) 69.89 72.72(*) 69.89 75.19(*) 69.89 71.62(*) 70.94 75.19(*)
AD vs. CN
Precision 83.45 85.77(*) 71.24 72.80(*) 76.14 79.84(*) 65.04 80.62(*) 65.04 81.52(*) 71.24 74.37(*)
Accuracy 84.79 87.16(*) 81.50 85.32(*) 81.50 83.49(*) 80.59 83.52(*) 80.59 82.42(*) 84.40 86.24(*)
F1-score 82.02 83.69(*) 75.06 78.56(*) 74.17 81.07(*) 71.95 78.48(*) 71.95 76.13(*) 77.27 79.87(*)
Figure 10: Visualization of tau spreading flows in an individual cognitive normal subject (a) and an individual Alzheimer’s disease subject (b). For clarity, we only display the top-ranked flows sorted by the flow volume.

6.4 Discussion and limitations

Discussion. In our experiments, we found that adding a DC layer right after every FC layer usually does not yield the best performance. Instead, we empirically add the DC layer after one of the first several FC layers. For example, we add the DC layer after the $3^{rd}$ FC layer in an 8-layer GNN model, after the $5^{th}$ FC layer in a 16-layer GNN model, and after the $8^{th}$ FC layer in GNN models with more than 16 layers. One possible explanation is that the clip operation in the DC layer depends on a good estimation of the cap $b$ in Eq. 3 (in the main manuscript). Since the estimation of $b$ may be unstable during the early stages of graph learning, it can be advantageous, from an engineering perspective, to postpone the clip operation. However, delaying the DC layer too much misses the opportunity to address over-smoothing.
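For illustration only, a heavily hedged sketch of how a clip-style DC layer could be interleaved with FC layers: here the clip simply caps edge-wise embedding differences at a cap $b$, which is a stand-in for the actual DC layer and the estimation of $b$ defined by Eq. 3 in the main manuscript.

```python
import torch

def dc_clip(x, edge_index, b):
    """Cap the magnitude of graph gradients (edge-wise embedding differences) at b (sketch)."""
    src, dst = edge_index                                # edge_index has shape (2, E)
    diff = x[dst] - x[src]                               # graph gradient along each edge
    clipped = torch.clamp(diff, min=-b, max=b)           # clip to the cap b
    correction = 0.5 * (clipped - diff)                  # split the adjustment between endpoints
    x = x.clone()
    x.index_add_(0, dst, correction)                     # pull the endpoints toward the cap
    x.index_add_(0, src, -correction)
    return x
```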

Regarding computational time, we record the additional training time incurred by our DC layer on each dataset: 2.2 ms/epoch on Cora, 9.8 ms/epoch on Citeseer, 7.8 ms/epoch on Pubmed, and 0.3 ms/epoch on ADNI (dataset descriptions are listed in Table 3). The TV-based constraint thus addresses the over-smoothing issue in GNNs without imposing a significant computational burden.

Limitations. Our current graph learning experiments are limited to citation networks. In the future, we will evaluate our GNN-PDE-COV framework on other graph datasets, such as drug and protein networks.

Societal impact. Our major contribution to the machine learning field is a novel research framework that allows us to develop new GNN models with a system-level understanding. We provide a new approach to address the common over-smoothing issue in GNNs with a mathematical guarantee. From the application perspective, the new deep model for uncovering in-vivo propagation flows has great potential to establish new underpinnings of disease progression and to disentangle the heterogeneity of diverse neurodegeneration trajectories.

References

  • [1] Appleton, B., Talbot, H.: Globally minimal surfaces by continuous maximal flows. IEEE transactions on pattern analysis and machine intelligence 28(1), 106–118 (2005)
  • [2] Arnol’d, V.I.: Mathematical methods of classical mechanics, vol. 60. Springer Science & Business Media (2013)
  • [3] Bammer, R.: Basic principles of diffusion-weighted imaging. European journal of radiology 45(3), 169–184 (2003)
  • [4] Beck, A., Teboulle, M.: A fast iterative shrinkage-thresholding algorithm for linear inverse problems. SIAM Journal on Imaging Sciences 2(1), 183–202 (2009)
  • [5] Bronstein, M.: Beyond message passing: a physics-inspired paradigm for graph neural networks. The Gradient (2022)
  • [6] Chamberlain, B., Rowbottom, J., Gorinova, M.I., Bronstein, M., Webb, S., Rossi, E.: Grand: Graph neural diffusion. In: International Conference on Machine Learning. pp. 1407–1418. PMLR (2021)
  • [7] Chambolle, A., Pock, T.: A first-order primal-dual algorithm for convex problems with applications to imaging. Journal of Mathematical Imaging and Vision 40(1), 120–145 (2011)
  • [8] Chan, T., Esedoglu, S., Park, F., Yip, A.: Total variation image restoration: Overview and recent developments. Handbook of mathematical models in computer vision pp. 17–31 (2006)
  • [9] Chen, M., Wei, Z., Huang, Z., Ding, B., Li, Y.: Simple and deep graph convolutional networks. In: International conference on machine learning. pp. 1725–1735. PMLR (2020)
  • [10] Chen, R.T., Rubanova, Y., Bettencourt, J., Duvenaud, D.K.: Neural ordinary differential equations. Advances in neural information processing systems 31 (2018)
  • [11] Coifman, R.R., Lafon, S.: Diffusion maps. Applied and computational harmonic analysis 21(1), 5–30 (2006)
  • [12] Daubechies, I., Defrise, M., De Mol, C.: An iterative thresholding algorithm for linear inverse problems with a sparsity constraint. Communications on Pure and Applied Mathematics: A Journal Issued by the Courant Institute of Mathematical Sciences 57(11), 1413–1457 (2004)
  • [13] Destrieux, C., Fischl, B., Dale, A., Halgren, E.: Automatic parcellation of human cortical gyri and sulci using standard anatomical nomenclature. Neuroimage 53(1), 1–15 (2010)
  • [14] Dujardin, S., Fernandes, A., Bannon, R., Commins, C., De Los Santos, M., Kamath, T.V., Hayashi, M., Hyman, B.T.: Tau propagation is dependent on the genetic background of mouse strains. Brain Communications 4(2), fcac048 (2022)
  • [15] Eliasof, M., Haber, E., Treister, E.: Pde-gcn: novel architectures for graph neural networks motivated by partial differential equations. Advances in neural information processing systems 34, 3836–3849 (2021)
  • [16] Figueiredo, M.A., Bioucas-Dias, J.M., Nowak, R.D.: Majorization–minimization algorithms for wavelet-based image restoration. IEEE Transactions on Image processing 16(12), 2980–2991 (2007)
  • [17] Fillard, P., Descoteaux, M., Goh, A., Gouttard, S., Jeurissen, B., Malcolm, J., Ramirez-Manzanares, A., Reisert, M., Sakaie, K., Tensaouti, F., et al.: Quantitative evaluation of 10 tractography algorithms on a realistic diffusion mr phantom. Neuroimage 56(1), 220–234 (2011)
  • [18] Gao, H., Ji, S.: Graph u-nets. In: international conference on machine learning. pp. 2083–2092. PMLR (2019)
  • [19] Glasser, M.F., Rilling, J.K.: Dti tractography of the human brain’s language pathways. Cerebral cortex 18(11), 2471–2482 (2008)
  • [20] Guo, Z., Zhang, Y., Teng, Z., Lu, W.: Densely connected graph convolutional networks for graph-to-sequence learning. Transactions of the Association for Computational Linguistics 7, 297–312 (2019)
  • [21] Hamilton, W., Ying, Z., Leskovec, J.: Inductive representation learning on large graphs. Advances in neural information processing systems 30 (2017)
  • [22] Hyman, J.M., Shashkov, M.: Adjoint operators for the natural discretizations of the divergence, gradient and curl on logically rectangular grids. Applied Numerical Mathematics 25(4), 413–442 (1997)
  • [23] Jack Jr, C.R., Bennett, D.A., Blennow, K., Carrillo, M.C., Dunn, B., Haeberlein, S.B., Holtzman, D.M., Jagust, W., Jessen, F., Karlawish, J., et al.: Nia-aa research framework: toward a biological definition of alzheimer’s disease. Alzheimer’s & Dementia 14(4), 535–562 (2018)
  • [24] Jenkinson, M., Beckmann, C.F., Behrens, T.E., Woolrich, M.W., Smith, S.M.: Fsl. Neuroimage 62(2), 782–790 (2012)
  • [25] Ji, S., Pan, S., Cambria, E., Marttinen, P., Philip, S.Y.: A survey on knowledge graphs: Representation, acquisition, and applications. IEEE transactions on neural networks and learning systems 33(2), 494–514 (2021)
  • [26] Jones, D.T., Graff-Radford, J., Lowe, V.J., Wiste, H.J., Gunter, J.L., Senjem, M.L., Botha, H., Kantarci, K., Boeve, B.F., Knopman, D.S., et al.: Tau, amyloid, and cascading network failure across the alzheimer’s disease spectrum. Cortex 97, 143–159 (2017)
  • [27] Kipf, T.N., Welling, M.: Semi-supervised classification with graph convolutional networks. arXiv preprint arXiv:1609.02907 (2016)
  • [28] Lee, W.J., Brown, J.A., Kim, H.R., La Joie, R., Cho, H., Lyoo, C.H., Rabinovici, G.D., Seong, J.K., Seeley, W.W.: Regional A$\beta$-tau interactions promote onset and acceleration of alzheimer’s disease tau spreading. Neuron 110(12), 1932–1943 (2022)
  • [29] Li, G., Muller, M., Thabet, A., Ghanem, B.: Deepgcns: Can gcns go as deep as cnns? In: Proceedings of the IEEE/CVF international conference on computer vision. pp. 9267–9276 (2019)
  • [30] Li, Q., Han, Z., Wu, X.M.: Deeper insights into graph convolutional networks for semi-supervised learning. In: Proceedings of the AAAI conference on artificial intelligence. vol. 32 (2018)
  • [31] Morris, C.: Graph neural networks: Graph classification. In: Wu, L., Cui, P., Pei, J., Zhao, L. (eds.) Graph Neural Networks: Foundations, Frontiers, and Applications, pp. 179–193. Springer Singapore, Singapore (2022)
  • [32] Petersen, R.C., Aisen, P.S., Beckett, L.A., Donohue, M.C., Gamst, A.C., Harvey, D.J., Jack, C.R., Jagust, W.J., Shaw, L.M., Toga, A.W., Trojanowski, J.Q., Weiner, M.W.: Alzheimer’s disease neuroimaging initiative (adni). Neurology 74(3), 201–209 (2010)
  • [33] Poli, M., Massaroli, S., Park, J., Yamashita, A., Asama, H., Park, J.: Graph neural ordinary differential equations. arXiv preprint arXiv:1911.07532 (2019)
  • [34] Pradhyumna, P., Shreya, G., et al.: Graph neural network (gnn) in image and video understanding using deep learning for computer vision applications. In: 2021 Second International Conference on Electronics and Sustainable Communication Systems (ICESC). pp. 1183–1189. IEEE (2021)
  • [35] Rudin, L.I., Osher, S., Fatemi, E.: Nonlinear total variation based noise removal algorithms. Physica D: nonlinear phenomena 60(1-4), 259–268 (1992)
  • [36] Sen, P., Namata, G., Bilgic, M., Getoor, L., Galligher, B., Eliassi-Rad, T.: Collective classification in network data. AI magazine 29(3), 93–93 (2008)
  • [37] Thekumparampil, K.K., Wang, C., Oh, S., Li, L.J.: Attention-based graph neural network for semi-supervised learning. arXiv preprint arXiv:1803.03735 (2018)
  • [38] Tristán-Vega, A., Aja-Fernández, S.: Dwi filtering using joint information for dti and hardi. Medical image analysis 14(2), 205–218 (2010)
  • [39] Veličković, P., Cucurull, G., Casanova, A., Romero, A., Lio, P., Bengio, Y.: Graph attention networks. arXiv preprint arXiv:1710.10903 (2017)
  • [40] Vogel, J.W., Iturria-Medina, Y., Strandberg, O.T., Smith, R., Levitis, E., Evans, A.C., Hansson, O.: Spread of pathological tau proteins through communicating neurons in human alzheimer’s disease. Nature communications 11(1),  2612 (2020)
  • [41] Wang, Y., Yang, J., Yin, W., Zhang, Y.: A new alternating minimization algorithm for total variation image reconstruction. SIAM Journal on Imaging Sciences 1(3), 248–272 (2008)
  • [42] Wu, L., Chen, Y., Shen, K., Guo, X., Gao, H., Li, S., Pei, J., Long, B., et al.: Graph neural networks for natural language processing: A survey. Foundations and Trends® in Machine Learning 16(2), 119–328 (2023)
  • [43] Wu, S., Sun, F., Zhang, W., Xie, X., Cui, B.: Graph neural networks in recommender systems: a survey. ACM Computing Surveys 55(5), 1–37 (2022)
  • [44] Xhonneux, L.P., Qu, M., Tang, J.: Continuous graph neural networks. In: International Conference on Machine Learning. pp. 10432–10441. PMLR (2020)
  • [45] Xu, K., Li, C., Tian, Y., Sonobe, T., Kawarabayashi, K.i., Jegelka, S.: Representation learning on graphs with jumping knowledge networks. In: International conference on machine learning. pp. 5453–5462. PMLR (2018)
  • [46] Yang, Z., Cohen, W., Salakhudinov, R.: Revisiting semi-supervised learning with graph embeddings. In: International conference on machine learning. pp. 40–48. PMLR (2016)
  • [47] Zhang, M., Chen, Y.: Link prediction based on graph neural networks. Advances in neural information processing systems 31 (2018)
  • [48] Zhang, Y., Gao, S., Pei, J., Huang, H.: Improving social network embedding via new second-order continuous graph neural networks. In: Proceedings of the 28th ACM SIGKDD Conference on Knowledge Discovery and Data Mining. pp. 2515–2523 (2022)
  • [49] Zhao, J., Mathieu, M., LeCun, Y.: Energy-based generative adversarial network. arXiv preprint arXiv:1609.03126 (2016)
  • [50] Zhu, M., Wright, S.J., Chan, T.F., et al.: Duality-based algorithms for total-variation-regularized image restoration. Computational Optimization and Applications 47(3),  377 (2010)