
CS Department, Missouri University of Science and Technology, Rolla, MO 65409, USA; email: {yam64,tluo}@mst.edu

Unmasking Dementia Detection by Masking Input Gradients:
A JSM Approach to Model Interpretability and Precision

Yasmine Mustafa (ORCID 0009-0002-3512-9659) and Tie Luo (corresponding author, ORCID 0000-0003-2947-3111)
Abstract

The evolution of deep learning and artificial intelligence has significantly reshaped technological landscapes. However, their effective application in crucial sectors such as medicine demands not only superior performance but also trustworthiness. While interpretability plays a pivotal role, existing explainable AI (XAI) approaches often fail to reveal Clever Hans behavior, where a model makes (ungeneralizable) correct predictions using spurious correlations or biases in data. Likewise, current post-hoc XAI methods are susceptible to generating unjustified counterfactual examples. In this paper, we approach XAI with an innovative model debugging methodology realized through the Jacobian Saliency Map (JSM). To cast the problem into a concrete context, we employ Alzheimer’s disease (AD) diagnosis as the use case, motivated by its significant impact on human lives and the formidable challenge of its early detection, which stems from the intricate nature of its progression. We introduce an interpretable, multimodal model for AD classification over its multi-stage progression, incorporating JSM as a modality-agnostic tool that provides insights into volumetric changes indicative of brain abnormalities. Our extensive evaluation, including an ablation study, demonstrates the efficacy of using JSM for model debugging and interpretation, while also significantly enhancing model accuracy.

Keywords:
Trustworthy AI, Interpretability, Explainability, Reliability, Jacobian saliency map, Alzheimer’s disease

1 Introduction

Despite the remarkable successes of deep learning and artificial intelligence across various technological domains, achieving a high-performing system does not automatically guarantee its practical deployment and use, particularly in the field of medicine. Given the profound implications of medical decisions on human lives, doctors and patients often approach AI diagnoses with skepticism, notwithstanding claims of high precision, due to concerns surrounding trustworthiness.

In this study, we delve into the concept of trustworthy medical AI, focusing on two essential aspects. First, explainability is crucial; without a clear explanation of the rationale behind a diagnosis, patients would be much less receptive to AI-driven decisions. Second, reliability is paramount, ensuring that AI models make predictions based on relevant patterns rather than exhibiting what is known as “Clever-Hans behavior.” This phenomenon occurs when a machine learning model seemingly performs well but makes decisions based on irrelevant factors such as biases or coincidental correlations in data.

These two aspects are interrelated: a thorough model explanation not only provides the foundation for decision-making but also reveals whether predictions are influenced by Clever-Hans behavior. By addressing both explainability and reliability, we aim to enhance trust in medical AI systems and pave the way for their responsible and effective integration into healthcare practices.

A wide variety of explainable AI (XAI) tools, referred to as post-hoc methods, have arisen to explain the predictions of trained black-box models. Although these methods have shown promise, they are vulnerable to generating unjustified counterfactual examples [12] and hence may not be reliable. It is worth noting the nuances between explainability and interpretability in this context, although the two terms are often used loosely and interchangeably. Explainability refers to the ability to explain model decisions after training (post-hoc), while the original model is not interpretable by itself. Interpretability, on the other hand, is an inherent property of a model and describes how easily and intuitively one can understand the model’s decision-making process. Examples of highly interpretable models include decision trees and linear regression, which provide easily traceable logic for the roles that features play in decision-making. However, such models are shallow and typically underperform deep neural networks (DNNs). Because post-hoc explanation, by its after-training nature, is often inadequate to unravel the full complexity of model behavior, we focus on designing interpretable models in this paper, taking a during-modeling approach.

In this paper, we introduce a novel approach that guides the decision-making process of neural networks during training, not only directing the model toward correct predictions but also penalizing any predictions (including correct ones) that are based on wrong cues. Cues are patterns, relationships, or features within the input data that the model treats as relevant to the task when making predictions or decisions. Misinterpreted or misidentified cues erode the trustworthiness and reliability of a model even if its predictions are correct, because such performance would not generalize to future unseen data.

For a concrete problem context, we take Alzheimer’s disease (AD) as the specific medical condition in this study, but we highlight that our approach can be extended to similar problems without a change of principle. AD is the predominant form of dementia and a major contributor to mortality. It affects brain areas responsible for thought, memory, and language, and it is hard to cure. Although changes in the brain can manifest long before AD symptoms appear, they are challenging for medical professionals to detect manually until AD reaches its late and severe stages, which are no longer reversible. Therefore, AI-based diagnosis of AD has been actively researched. However, to date, such newly developed models have rarely been adopted in clinical decision support systems (CDSS) for primary care, because they are almost exclusively based on DNNs, which lack explainability.

Another crucial aspect of AD is that its clinical presentation often involves multiple modalities, such as computed tomography (CT) and magnetic resonance imaging (MRI). While combining such data modalities could lead to better precision [28], it poses a further challenge: designing models whose decisions remain interpretable across all modalities.

In this paper, we propose an interpretable, multimodal model for classifying AD across its multi-stage progression, including early detection. We introduce a Jacobian-Augmented Loss function (JAL) that incorporates Jacobian Saliency Maps (JSM) as a model self-debugger. This approach is modality-agnostic: it is not limited to any specific imaging modality but can work with various integrated modalities, making it adaptable to different data sources. In this study, we focus on images, because other modalities such as the Mini-Mental State Examination (MMSE) can be assessed effectively by medical professionals, whereas image-based early detection of AD is challenging. Our methodology aligns with the broader concept of model debugging, which involves troubleshooting unwanted behaviors and examining a model’s predictions. Our goal is to ensure that a model not only learns to make accurate predictions but also avoids wrong cues, as in the Clever Hans phenomenon, thus enhancing model reliability. We achieve this using JSM, which is computed during image preprocessing. Besides debugging, JSM also allows us to enhance model precision by highlighting deformations in body tissues (e.g., the brain in the context of AD).

In summary, this paper makes the following contributions:

  • We design a novel loss function, JAL, which incorporates Jacobian saliency maps (JSM) to enable machine learning models to self-debug their decision-making process automatically. This approach encourages model predictions to be based on genuine patterns and cues, and makes the model’s decision-making process more interpretable through a during-modeling methodology.

  • We include multimodal data fusion into the process through two distinct fusion techniques, not only shedding light on their differences but also showcasing the adaptability of JAL to different modalities and fusion levels.

  • Our approach enables models to provide fine-grained classification into 4 classes: cognitively normal (CN) and three main AD stages, namely mild cognitive impairment (MCI), mild AD, and moderate to severe AD. In contrast, existing approaches provide only binary classification or combine two or more stages into one class. Furthermore, although coarse-grained approaches can more easily attain high accuracy because they have fewer classes, our fine-grained diagnosis achieves higher accuracy than they do.

  • Our comprehensive evaluation, including an ablation study, demonstrates the efficacy of using JSM as a model self-debugger for producing both reliable predictions and trustworthy interpretations.

2 Related Work

The primary goal of El-sappagh et al. [6] in explainability is to provide post-hoc explanations for the decisions made by a Random Forest (RF) classifying AD from multimodal data. Hence, the method does not try to explain the internal workings of the RF but aims to explain its decisions so that physicians can better understand them. The framework consists of two RF models. The first RF performs multi-class classification to categorize individuals as normal, MCI, or AD. The second RF comes into play only in the case of MCI, to classify whether the MCI is stable (sMCI) or progressive (pMCI). The authors used the FreeSurfer software to automatically label areas of the structural MRI scans in order to provide natural language explanations using the Fuzzy Unordered Rule Induction Algorithm (FURIA) proposed by [8]. As a result, the method produces a compact set of IF-THEN statements that are understandable by physicians. SHapley Additive exPlanations (SHAP) [14] was employed to show local and global feature contributions to the final decisions of the RF models. Finally, they used a visualization tool called RF explainer on the tree decisions to provide explanations about individual modalities.

Khare et al. [10] used the electroencephalogram (EEG), which is lower-cost and involves no ionizing radiation, as the only modality to perform binary classification of AD. After channel and feature analysis to extract the most important channels and features, the authors used the explainable boosting machine (XBM) model with three model-agnostic explainers: SHAP [14], Local Interpretable Model-agnostic Explanations (LIME) [21], and Morris Sensitivity (MS) [17]. The study presented topographic maps to elucidate feature importance in terms of EEG channels. Although XBM has shown promising performance, the dataset used is small, so the results are not conclusive enough to support EEG signals as a standalone diagnostic tool for AD. In fact, analyzing EEG is a challenging task [19], and it may not provide comprehensive information at a large scale. Nevertheless, we note that integrating EEG into a multimodal setup could be a valuable complementary aid.

Zhang et al. [30] proposed a 3D explainable residual attention network (3D ResAttNet), a deep convolutional neural network (CNN) augmented with self-attention residual blocks and Gradient-weighted Class Activation Mapping (Grad-CAM) [25]. The residual mechanism alleviates vanishing gradients in deep networks, while self-attention learns long-range dependencies. Grad-CAM is used to pinpoint relevant areas (regions associated with disease presence) in each brain scan by calculating the gradient of the predicted class score with respect to the activations of the feature maps in the last convolutional layer of the network. This gradient represents how sensitive the prediction is to changes in those activations; weighting the feature maps accordingly highlights the contribution of each brain area to the model’s decision. The authors used structural MRI (sMRI) as the single modality for two separate binary classifications: 1) AD vs. CN and 2) pMCI vs. sMCI.

Similarly, Yu et al. [29] also used only sMRI to perform binary diagnosis of AD (CN vs. AD). Their objective was to create higher-resolution brain heatmaps that capture fine-grained details. To that end, they developed a network named MAXNet that consists of a Dual Attention Module (DAM) and a Multi-resolution Fusion Module (MFM), which learn representations containing voxel-level information. Additionally, they introduced High-resolution Activation Mapping (HAM) as a visualization method to enhance the quality of the heatmap. Although the algorithm can identify precise small regions at the voxel level, it provides no way to validate whether its correct predictions are actually based on the correct cues.

Mulyadi et al. [18] tackled this issue by developing a method called eXplainable AD Likelihood Map Estimation (XADLiME), based on clinically guided prototype learning. They measured the similarity between the prototypes and the latent features of clinical information (a clinical label, MMSE score, and age), and thereby created a pseudo likelihood map representing the likelihood of AD across different stages. The AD likelihood map was estimated from sMRI using a feature extractor network, and a reference map for comparison was obtained by a neural network with a sigmoid activation function. The estimated likelihood map can be viewed from both clinical and morphological perspectives to interpret predictions as a diagnostic tool, helping to understand the likelihood of AD progression based on sMRI.

While most studies, including [10, 30, 29, 18], utilize a single modality, our approach leverages multiple modalities for interpretable AD diagnosis and achieves enhanced performance. A perhaps more important differentiator of our work lies in our interpretation approach, which is rooted in our JSM framework, leading to more trustworthy medical diagnoses.

3 Methods

3.1 Jacobian Saliency Map (JSM)

Jacobian Saliency Maps (JSM) emerged recently [1] as a highly effective tool for deciphering the decision-making mechanisms of a deep learning model. It accomplishes this goal by defining specific zones within an input image and measuring their volumetric changes, thereby providing a precise understanding of feature attribution. Feature attribution, which ascribes significance and influence to individual features, enables us to identify the particular aspects of an input that exert significant influence on the model’s output. Thus, JSM transforms data in a way that aligns with human intuition and enhances the model’s interpretability before the actual deep learning pipeline, making it a promising choice for a diagnosis model debugger.

Our approach provides interpretation by computing the gradients of the model’s log-probabilities with respect to the input, weighting them element-wise by the JSM, and optimizing them toward matching the patterns of deformation highlighted by the JSM. When manipulating input gradients, we can explore two directions: enhancing their significance in relevant brain areas or reducing their significance in irrelevant brain areas. However, because the appropriate magnitudes of gradients in relevant areas are typically unknown a priori, we choose to dampen the gradients in irrelevant areas, for which the normal regions provide a reliable reference.

We are inspired by the work done by Ross et al. [23] who introduced a method to regularize a model’s gradients with respect to input features based on a binary mask annotation matrix. However, the annotation term added to the model in [23] is a simple binary mask, which fails to capture correlations between different regions of a brain scan. This poses a significant limitation to the effectiveness because such correlations carry crucial diagnostic information. Additionally, the datasets used in their validation were small and synthetic, thus leaving considerable doubt on whether the method would perform well in real-world medical applications.

Accordingly, instead of the binary annotation matrix, we use a special weighted saliency map. The rationale is that we want a matrix that not only accurately represents the varying degrees of sensitivity of each feature to the diagnosis but also highlights brain regions that exhibit deformations, while preserving the normal regions for contrast. These properties cannot be achieved by a binary annotation matrix. In addition, using the saliency map to characterize deformations in the brain allows us to align the constraints on gradients (the penalty formulated later in (8) and (9)) with domain-specific knowledge of the medical problem. We perform both registration and Jacobian computation using Advanced Normalization Tools (ANTs) [3].

Intuition behind JSM. Ideally, the presence of a disease can be identified by comparing many scans of an individual taken over a long period of time. Since this is usually not feasible in practice, we can instead compare an individual’s scan to a standardized healthy brain template and assess its local changes using a deformation map. JSM is derived from non-linear image registration, a process designed both to minimize variations among individual subjects and to align all images with a standardized template. JSM then examines the deformations that map the anatomical structures of individuals onto a common standard space, from which we can deduce the relative volume differences between each individual and the template. This analysis aids in pinpointing statistically significant anatomical variations across diverse populations, such as distinguishing AD patients from healthy elderly individuals. In our investigation, we used the Montreal Neurological Institute’s 152 brain template (MNI152) as the reference. This process of image registration and JSM computation belongs to computational anatomy, specifically tensor-based morphometry [22], and has not been adequately explored in the context of dementia.

Formulation of JSM. In medical image processing, computing deformation involves first aligning a source image $M$ with a target image $F$. This is achieved using a transformation $\phi$ that maps points in $M$ to their corresponding points in $F$. The displacement between these corresponding points is then represented as a Deformation Vector Field:

$$\vec{v}(x,y,z)=\phi(x,y,z)-(x,y,z). \tag{1}$$

To transform a point $(x,y,z)$ to $\phi(x,y,z)$, a regularization constraint must be imposed to ensure that the deformation is smooth, one-to-one, and differentiable. This is framed as an optimization problem that minimizes the following cost function:

$$L(\phi,M,F)=-L_{sim}(\phi(M),F)+\alpha L_{Reg}(\phi) \tag{2}$$

where $L_{sim}$ is a similarity measure between two images, the transformed image $\phi(M)$ and $F$; $L_{Reg}$ is a regularization term that enforces the desired properties on the deformation; and $\alpha$ balances the two terms. We use Mattes Mutual Information (MI) [16] as our similarity measure, i.e.,

$$L_{sim}=MI(M^{\prime},F)=\sum_{m^{\prime}}\sum_{f}P(m^{\prime},f)\log\left(\frac{P(m^{\prime},f)}{Q_{1}(m^{\prime})\,Q_{2}(f)}\right) \tag{3}$$

where $M^{\prime}$ denotes $\phi(M)$ for simplicity, $P(m^{\prime},f)$ is the joint probability distribution of the intensity of voxel $m^{\prime}$ in image $M^{\prime}$ and that of voxel $f$ in image $F$, and $Q_{1}(m^{\prime})$ and $Q_{2}(f)$ are the marginal probability distributions of $m^{\prime}$ and $f$, respectively. For the regularizer, we use the B-spline regularization from [27]:

$$L_{Reg}=L_{\text{B-spline}}=\int\left|\nabla^{2}\phi(x,y,z)\right|^{2}dV \tag{4}$$

where $\nabla^{2}$ denotes the second-order derivative, $dV$ denotes the volume element at each voxel $(x,y,z)$, and the integral is taken over the entire spatial domain.
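As a rough illustration of the two terms in (2), the sketch below evaluates a histogram-based mutual information for (3) and a finite-difference bending energy for (4) with NumPy. It is a simplification, not the ANTs implementation used in this work: Mattes MI estimates the densities with B-spline Parzen windows rather than a plain joint histogram, and we assume that $|\nabla^{2}\phi|^{2}$ means the sum of squared second-order partial derivatives of each component of $\phi$. The function names and the default `bins` and `alpha` values are hypothetical.

```python
import numpy as np


def mutual_information(moving, fixed, bins=32):
    """Histogram estimate of Eq. (3) between the warped image M' and F.

    Simplification: a plain joint histogram replaces Mattes' B-spline
    Parzen windowing.
    """
    joint, _, _ = np.histogram2d(moving.ravel(), fixed.ravel(), bins=bins)
    p_joint = joint / joint.sum()                  # P(m', f)
    q_moving = p_joint.sum(axis=1, keepdims=True)  # Q1(m'), shape (bins, 1)
    q_fixed = p_joint.sum(axis=0, keepdims=True)   # Q2(f), shape (1, bins)
    outer = q_moving * q_fixed                     # Q1(m') * Q2(f) by broadcasting
    nz = p_joint > 0                               # treat 0 * log 0 as 0
    return float(np.sum(p_joint[nz] * np.log(p_joint[nz] / outer[nz])))


def bending_energy(phi, spacing=1.0):
    """Finite-difference approximation of Eq. (4).

    phi has shape (3, X, Y, Z); phi[c] holds the c-th output coordinate.
    Assumption: |grad^2 phi|^2 is the sum of all squared second partials.
    """
    energy = 0.0
    for component in phi:
        for first in np.gradient(component, spacing):    # d phi_c / d x_i
            for second in np.gradient(first, spacing):   # d^2 phi_c / d x_i d x_j
                energy += float(np.sum(second ** 2))
    return energy * spacing ** 3                          # multiply by dV


def registration_cost(warped_moving, fixed, phi, alpha=0.01, bins=32):
    """Eq. (2): negative similarity plus weighted regularization."""
    return -mutual_information(warped_moving, fixed, bins) + alpha * bending_energy(phi)
```

For an identity transform (an affine map in general), every second derivative vanishes, so `bending_energy` returns 0 and the cost reduces to the negative MI; curvature in $\phi$ is what the regularizer penalizes.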

After solving for the transformation $\phi$, we compute a Jacobian matrix $J$ from the deformation vector field $\vec{v}$ by calculating the first derivatives of $\vec{v}$ at each voxel, which encode local deformations including stretching, shearing, and rotation. That is,

$$J(\vec{v})=\begin{bmatrix}\frac{\partial v_{x}}{\partial x}&\frac{\partial v_{x}}{\partial y}&\frac{\partial v_{x}}{\partial z}\\ \frac{\partial v_{y}}{\partial x}&\frac{\partial v_{y}}{\partial y}&\frac{\partial v_{y}}{\partial z}\\ \frac{\partial v_{z}}{\partial x}&\frac{\partial v_{z}}{\partial y}&\frac{\partial v_{z}}{\partial z}\end{bmatrix} \tag{5}$$

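Then, denoting the Jacobian determinant by $Det(J)$, we calculate it for every voxel $(x,y,z)$. Since $\phi(x,y,z)=(x,y,z)+\vec{v}(x,y,z)$ by (1), the Jacobian of the full transformation is $I_{3}+J(\vec{v})$, where $I_{3}$ is the $3\times 3$ identity matrix, so that

$$Det(J)(x,y,z)=\det\big(I_{3}+J(\vec{v})(x,y,z)\big) \tag{6}$$

equals 1 wherever the deformation leaves the local volume unchanged. Collecting these values over all voxels forms what we call the Jacobian Saliency Map $JSM$ of the source image $M$: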

$$JSM(M)=\Big[\,Det(J)(x,y,z)\,\Big]_{x=1,\dots,W\ (\text{width});\ y=1,\dots,H\ (\text{height});\ z=1,\dots,D\ (\text{depth})},\qquad \text{at each voxel: }\begin{cases}\text{volume expansion} & \text{if } Det(J)>1\\ \text{no change} & \text{if } Det(J)=1\\ \text{volume compression} & \text{if } Det(J)<1\end{cases} \tag{7}$$

Volumetric changes refer to the alteration in volume at the level of individual voxels within a medical image. This JSM helps us to identify the volumetric ratio of the brain image at the voxel level before and after transformation ϕ\phi, which indicates the brain’s volume change.
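A minimal NumPy sketch of (5) to (7) follows. It assumes the displacement field $\vec{v}$ is already available as a dense array of shape (3, X, Y, Z) in voxel units (in this work it comes from ANTs registration), estimates the partial derivatives with central differences via `np.gradient`, and adds the identity matrix because $\phi=\mathrm{id}+\vec{v}$ by (1), so that an undeformed voxel has determinant 1. The function names and the tolerance `tol` used to decide "no change" on floating-point data are hypothetical.

```python
import numpy as np


def jacobian_saliency_map(displacement, spacing=1.0):
    """Voxel-wise determinant of the Jacobian of phi = id + v.

    displacement: array of shape (3, X, Y, Z) holding (v_x, v_y, v_z).
    Returns an (X, Y, Z) array: >1 expansion, 1 no change, <1 compression.
    """
    # grad_v[i, j] = d v_i / d x_j, estimated by central differences (Eq. (5))
    grad_v = np.stack([np.stack(np.gradient(v_i, spacing)) for v_i in displacement])
    # Move the two matrix axes to the end -> (X, Y, Z, 3, 3), then add I_3
    jac = np.moveaxis(grad_v, (0, 1), (-2, -1)) + np.eye(3)
    return np.linalg.det(jac)  # det over the trailing 3x3 axes of every voxel


def classify_volume_change(jsm, tol=1e-3):
    """Label voxels +1 (expansion), 0 (no change) or -1 (compression)."""
    labels = np.zeros(jsm.shape, dtype=int)
    labels[jsm > 1.0 + tol] = 1
    labels[jsm < 1.0 - tol] = -1
    return labels
```

For example, on a coordinate grid `coords` of shape (3, 5, 5, 5), the uniform stretch `0.1 * coords` yields a map of approximately 1.1³ ≈ 1.331 at every voxel (expansion), while `-0.1 * coords` yields about 0.729 (compression).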

3.2 Jacobian-Augmented Loss Function (JAL)

By breaking down an input image into distinct regions and measuring how they are transformed, JSM provides precise insight into feature attribution and thus model explainability. We therefore explore leveraging JSM as a model debugger by incorporating the JSM formulated above into the loss function $\mathcal{L}$ of a medical diagnostic model:

$$\mathcal{L}(\boldsymbol{x},\boldsymbol{y},JSM)=-\sum_{k=1}^{K}y_{k}\log(\hat{y}_{k})+\lambda\sum_{d=1}^{D}\sum_{p=1}^{P}\left(w_{dp}\,JSM_{dp}\,\frac{\partial}{\partial x_{dp}}\sum_{k=1}^{K}\log(\hat{y}_{k})\right)^{2} \tag{8}$$

Given input data $\boldsymbol{x}$ and label $\boldsymbol{y}$, the loss function $\mathcal{L}(\boldsymbol{x},\boldsymbol{y},JSM)$ consists of the training (cross-entropy) loss (first term) and a novel, JSM-based regularization term (second term) weighted by the hyperparameter $\lambda$. The second term emphasizes the anatomical changes captured by the Jacobian map values. Specifically, $JSM_{dp}$ is the JSM value associated with the $d$-th feature and the $p$-th spatial dimension (width, height, and depth in 3D), $w_{dp}$ is the corresponding element of the weight matrix $W$ defined below, $\log(\hat{y}_{k})$ is the natural logarithm of the predicted probability that $\boldsymbol{x}$ belongs to class $k$ out of $K$ classes, and $\frac{\partial}{\partial x_{dp}}$ is the partial derivative with respect to the $d$-th feature of $\boldsymbol{x}$ in the $p$-th spatial dimension. Thus, this regularization term aims not only to enhance interpretability but also to rectify predictions by mitigating the influence of irrelevant cues.
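To make (8) concrete, the sketch below evaluates it for a single sample under a deliberately simplified assumption: the classifier is a hypothetical linear-softmax model $\hat{\boldsymbol{y}}=\mathrm{softmax}(\Theta\boldsymbol{x}+\boldsymbol{b})$ rather than the CNN used in this work, so the input gradient of $\sum_k\log\hat{y}_k$ has the closed form $\Theta^{\top}(\mathbf{1}-K\hat{\boldsymbol{y}})$. For a CNN, a deep learning framework would obtain this term by automatic differentiation instead. The names `theta`, `bias`, `weights` (the matrix $W$) and the default `lam` are illustrative.

```python
import numpy as np


def log_softmax(z):
    z = z - z.max()                        # shift for numerical stability
    return z - np.log(np.sum(np.exp(z)))


def sum_log_prob_input_grad(x, theta, bias):
    """d/dx sum_k log y_hat_k for y_hat = softmax(theta @ x + bias).

    Closed form for a linear-softmax model: theta^T (1 - K * y_hat).
    """
    y_hat = np.exp(log_softmax(theta @ x.ravel() + bias))
    k = y_hat.size
    return (theta.T @ (1.0 - k * y_hat)).reshape(x.shape)


def jal_loss(x, y_onehot, jsm, weights, theta, bias, lam=1.0):
    """Eq. (8) for one sample: cross-entropy + lam * JSM-weighted gradient penalty."""
    log_y_hat = log_softmax(theta @ x.ravel() + bias)
    cross_entropy = -np.sum(y_onehot * log_y_hat)
    grad = sum_log_prob_input_grad(x, theta, bias)
    penalty = np.sum((weights * jsm * grad) ** 2)
    return float(cross_entropy + lam * penalty)
```

Setting `lam=0.0` recovers plain cross-entropy, which mirrors the without-JAL configuration of the ablation study.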

With reference to Equation (7), we add a weight matrix $W$ to give more importance to areas of the JSM that have volumetric changes (expansion or compression) and to discourage the input gradients from being significant in areas with no volumetric change (marked by $JSM_{dp}=1$). Hence, each element of $W$ is defined as follows:

$$w_{dp}=\begin{cases}\texttt{feature\_weight}, & \text{if } JSM_{dp}\neq 1\\ \texttt{debug\_weight}, & \text{otherwise}\end{cases} \tag{9}$$

Both feature_weight and debug_weight are hyperparameters that indicate the level of importance of each region. We set debug_weight to 0.2. This choice is deliberate: it down-weights potentially misleading features while still retaining them as a reference for contrasting relevant features. Since regions with volumetric changes are of higher importance, we set feature_weight to 0.8. Hence, we debug the model by weighting features according to their relevance. Note that we do not eliminate irrelevant areas by giving them a weight of zero, in order to preserve correlations across the brain volume. More rigorous tuning of these weights is left for future work.
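Eq. (9) translates directly into a vectorized selection. The sketch below assumes the JSM is a NumPy array; because a floating-point Jacobian determinant is rarely exactly 1, it adds a hypothetical tolerance `tol` (the default of 0 reproduces the exact test in (9)). The default weights are the values stated above; the function name is illustrative.

```python
import numpy as np


def jal_weights(jsm, feature_weight=0.8, debug_weight=0.2, tol=0.0):
    """Eq. (9): feature_weight where the JSM signals a volume change, else debug_weight."""
    unchanged = np.abs(np.asarray(jsm, dtype=float) - 1.0) <= tol
    return np.where(unchanged, debug_weight, feature_weight)
```

For instance, `jal_weights(np.array([1.0, 1.2, 0.7]))` returns `[0.2, 0.8, 0.8]`: the unchanged voxel receives the debug weight and both the expanded and compressed voxels receive the feature weight.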

Figure 1: Preprocessing pipelines for MRI and CT scans, involving bias field correction for MRI, contrast stretching for CT to enhance diagnostic values, BET for brain extraction, and registering CT and MRI to MNI152 brain template.

4 Experiments

4.1 Dataset

In our experiments, we used the recently published OASIS-3 dataset. This dataset includes multimodal data such as MRI, PET, and CT scans from a diverse group of 1377 participants. Among these participants, 755 were cognitively normal (CN) adults, while the remaining 622 showed different levels of cognitive decline. Participants’ ages ranged widely, from 42 to 95 years. CT imaging was used to detect whether certain areas of the brain were shrinking, which can be an indication of AD. MRI, on the other hand, provided detailed anatomical images and a clear view of progressive cerebral atrophy, which is most visible in T1-weighted volumetric sequences. Consideration of PET data was deferred due to its temporal nature, which makes it more suitable for future spatiotemporal analyses.

During clinical assessments and diagnoses, the participants’ clinical dementia rating (CDR) scores were used. The scores range from 0 to 3, with 0 indicating no AD dementia and 3 indicating severe AD dementia. The very mild stage of dementia (rating 0.5) is similar to the Mild Cognitive Impairment (MCI) stage of AD. As mentioned before, we combined moderate and severe AD due to the very low number of subjects with severe AD. Ultimately, we created four classes based on CDR scores: normal, MCI, mild AD, and moderate to severe AD. Having the same subjects in multiple sets (training and test sets) can cause the model to overfit to those individuals, potentially leading to subpar performance on new, unseen subjects. Hence, for patients who underwent multiple sessions, we kept the initial MRI session and selected the CT session closest in date to that MRI.
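The labeling and session-selection rules above can be sketched in a few lines of Python. The sketch assumes global CDR values of 0, 0.5, 1, 2, or 3 and session dates given as `datetime.date` objects; the class abbreviations follow Table 3, and the function names are hypothetical.

```python
from datetime import date

# CDR 2 (moderate) and 3 (severe) are merged into one class, as described above.
CDR_TO_CLASS = {0.0: "CN", 0.5: "MCI", 1.0: "MLD", 2.0: "SEV", 3.0: "SEV"}


def cdr_to_class(cdr: float) -> str:
    """Map a global CDR score to one of the four diagnostic classes."""
    return CDR_TO_CLASS[float(cdr)]


def select_sessions(mri_dates: list[date], ct_dates: list[date]) -> tuple[date, date]:
    """Keep the initial MRI session and the CT session closest to it in date."""
    first_mri = min(mri_dates)
    closest_ct = min(ct_dates, key=lambda d: abs((d - first_mri).days))
    return first_mri, closest_ct
```

With MRI sessions on 2008-01-01 and 2010-05-01 and CT sessions on 2007-12-01 and 2009-06-01, the function keeps the 2008 MRI and pairs it with the 2007 CT, which is only 31 days away.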

4.2 Preprocessing

Our preprocessing pipelines for MRI and CT scans are shown in Fig. 1. To minimize any spatially varying intensity bias that may result from factors such as magnetic field inhomogeneities and acquisition artifacts, we use the FMRIB’s Linear Image Registration Tool (FLIRT) [9] for bias field correction on MRI images. Meanwhile, we utilize a technique called contrast stretching on CT images to enhance their diagnostic value and visual perception. This process involves adjusting the pixel intensities to fully utilize the display’s dynamic range. For CT images, we followed the framework established by Kuijf et al. [11]. Also, we apply the Brain Extraction Tool (BET) [26] to eliminate non-brain portions from both MRI and CT images. Finally, we register both CT and MRI to the MNI152 brain template. Brain templates are typically created for MRI images, making it difficult to register a CT image to an MRI template. We overcame this by adhering to a method in [11] which involves identifying corresponding landmarks in the CT image and MRI template and then using these landmarks to align the images.
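Contrast stretching, as described above, linearly remaps intensities so that they span the full output range. The sketch below is a generic percentile-based version, not the exact procedure of Kuijf et al. [11] that we followed; the 2nd/98th percentile cut-offs, the [0, 1] output range, and the function name are illustrative assumptions.

```python
import numpy as np


def contrast_stretch(image, low_pct=2.0, high_pct=98.0):
    """Linearly map intensities between two percentiles onto [0, 1]."""
    lo, hi = np.percentile(image, [low_pct, high_pct])
    if hi <= lo:                               # constant image: nothing to stretch
        return np.zeros_like(image, dtype=float)
    stretched = (image.astype(float) - lo) / (hi - lo)
    return np.clip(stretched, 0.0, 1.0)        # saturate the outlying tails
```

Clipping at percentiles rather than at the raw minimum and maximum keeps a few extreme voxels (e.g., bone or metal in CT) from compressing the contrast of brain tissue.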

4.3 Multimodal Classification

Medical images, especially those of the brain, are typically 3D, which imposes a heavy computational burden on complex deep neural networks. By incorporating the guidance provided by JSM, we aim to develop lighter convolutional neural networks (CNNs) that alleviate this burden without compromising model performance. In view of the multimodal nature of our study, we incorporated two data fusion techniques: 1) late fusion and 2) early fusion. As shown in Fig. 2, late fusion adopts a dual-branch structure that treats each modality independently and then aggregates their predictions by averaging. This approach is particularly advantageous for JAL, since debugging is performed separately for each modality through its own JSM. Early fusion, on the other hand, concatenates the input images as well as their corresponding JSMs, which allows the model to learn correlations between the two modalities and to debug predictions for the combined input holistically.
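The two fusion strategies differ only in where the modalities are combined. In this minimal NumPy sketch, early fusion stacks the MRI and CT volumes (and, in parallel, their JSMs) along a new channel axis, which is an assumption since the concatenation axis is not specified here, and late fusion averages the per-branch class probabilities. The function names are hypothetical, and the two volumes are assumed to be co-registered to the same grid (e.g., MNI152).

```python
import numpy as np


def early_fusion_inputs(mri, ct, mri_jsm, ct_jsm):
    """Stack co-registered volumes and their JSMs along a channel axis.

    The JSM stack mirrors the image stack so that the JAL penalty of
    Eq. (8) can be applied element-wise to the fused input.
    """
    images = np.stack([mri, ct])          # (2, X, Y, Z)
    jsms = np.stack([mri_jsm, ct_jsm])    # (2, X, Y, Z)
    return images, jsms


def late_fusion_predict(prob_mri, prob_ct):
    """Average the class probabilities of the MRI and CT branches."""
    return (np.asarray(prob_mri) + np.asarray(prob_ct)) / 2.0
```

In the late-fusion case each branch keeps its own JSM penalty during training, and only the output probabilities are combined at inference.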

Table 1: CNN architecture for both branches

Layer       | Parameters                          | Output size
------------|-------------------------------------|--------------------------
Input       | Batch size 10                       | 10 × 1 × 182 × 256 × 512
Conv1       | Kernel 3×3×3, stride 1, padding 1   | 10 × 4 × 182 × 256 × 512
BatchNorm1  | Momentum 0.9                        | 10 × 4 × 182 × 256 × 512
Dropout1    | Dropout rate 0.5                    | 10 × 4 × 182 × 256 × 512
MaxPool1    | Kernel 2×2×2, stride 2              | 10 × 4 × 91 × 128 × 256
Conv2       | Kernel 3×3×3, stride 1, padding 1   | 10 × 8 × 91 × 128 × 256
BatchNorm2  | Momentum 0.9                        | 10 × 8 × 91 × 128 × 256
Dropout2    | Dropout rate 0.2                    | 10 × 8 × 91 × 128 × 256
MaxPool2    | Kernel 2×2×2, stride 2              | 10 × 8 × 45 × 64 × 128
Flatten     |                                     | 10 × 2949120
Full-Conn.  |                                     | 10 × 4

Our lightweight CNN model [20] contains two convolutional layers coupled with batch normalization, ReLU activation, dropout, and max-pooling operations, which capture spatial hierarchies and correlation patterns in the input data. The convolutional layers use a kernel size of 3×3×3, a stride of 1, and padding of 1, and are followed by dropout at rates of 0.5 and 0.2, respectively, for regularization. Each max-pooling operation uses a 2×2×2 kernel with stride 2, halving every spatial dimension. A flattening layer reshapes the 3D feature maps into a 1D tensor for the subsequent fully connected layer, which performs the overall feature integration and classification. All specifications are presented in Table 1. Finally, a softmax layer converts the predicted scores into probabilities. In the late fusion setup, the final prediction is computed by averaging the probabilities from both branches.
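The output sizes in Table 1 follow from the standard formula $\lfloor (n+2p-k)/s \rfloor + 1$ for a window of size $k$, stride $s$, and padding $p$. The short Python check below (helper names hypothetical) reproduces them: the padded 3×3×3 convolutions preserve each spatial size, and each 2×2×2 max-pooling with stride 2 halves it, rounding down (91 becomes 45). The flattened vector therefore has 8 × 45 × 64 × 128 = 2,949,120 features, so the final fully connected layer alone holds 2,949,120 × 4 = 11,796,480 weights plus 4 biases.

```python
def out_size(n, kernel, stride, padding=0):
    """Output length along one axis for a convolution or pooling window."""
    return (n + 2 * padding - kernel) // stride + 1


def trace_table1(spatial=(182, 256, 512), batch=10):
    """Propagate shapes through Conv1/Pool1/Conv2/Pool2 and the flatten layer.

    BatchNorm, ReLU, and dropout do not change shapes, so they are skipped.
    """
    dims = list(spatial)
    channels = 1
    for out_channels in (4, 8):                                   # Conv1, Conv2
        dims = [out_size(n, kernel=3, stride=1, padding=1) for n in dims]
        dims = [out_size(n, kernel=2, stride=2) for n in dims]    # MaxPool
        channels = out_channels
    flat = channels * dims[0] * dims[1] * dims[2]
    return (batch, channels, *dims), (batch, flat)


print(trace_table1())  # ((10, 8, 45, 64, 128), (10, 2949120))
```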

Figure 2: The complete pipeline. Model debugging using JSM is integrated into training and takes effect during backpropagation, for each modality (late fusion) or both (early fusion). Final predictions are interpreted by plotting elevated gradients overlaid on input images.

4.4 Performance Evaluation

We trained our model on a 40GB A100 GPU with batch size 10. The model converged within 20 epochs. To address class imbalance in AD, we used the Adaptive Synthetic (ADASYN) [7] oversampling algorithm to generate synthetic samples for minority classes during training. Another challenge is that OASIS-3 includes the same subjects across multiple sessions; if those sessions are spread across the training and test sets, the model may overfit by becoming excessively attuned to the characteristics of those particular subjects and sessions [2]. To address this, we ensure that all sessions of a given subject are assigned to either the training set or the test set, but not both.
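A subject-level split of the kind described above can be sketched as follows: sessions are grouped by subject ID before shuffling, so every session of a subject lands on the same side. The identifiers, `test_fraction`, and seed are hypothetical placeholders, not the exact split used in this work; oversampling such as ADASYN would then be applied to the training portion only, so that synthetic samples do not leak into evaluation.

```python
import random
from collections import defaultdict


def subject_level_split(sessions, test_fraction=0.2, seed=0):
    """Split (subject_id, session_id) pairs so no subject is in both sets."""
    by_subject = defaultdict(list)
    for subject_id, session_id in sessions:
        by_subject[subject_id].append(session_id)
    subjects = sorted(by_subject)                 # deterministic order before shuffling
    random.Random(seed).shuffle(subjects)
    n_test = max(1, round(test_fraction * len(subjects)))
    test_subjects = set(subjects[:n_test])
    train = [(s, x) for s, x in sessions if s not in test_subjects]
    test = [(s, x) for s, x in sessions if s in test_subjects]
    return train, test
```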

Table 2 compares our approach with the state of the art in terms of accuracy, sensitivity, and specificity. We report sensitivity and specificity rather than precision and recall. These measures are preferred in medical diagnosis because they directly express how well a test detects the disease (sensitivity, which is identical to recall) and how well it rules the disease out (specificity, the true-negative rate). Table 2 shows that our best four-class test accuracy (95.37%, late fusion) surpasses all baselines evaluated on the same dataset except Massalimova et al. [15], who report marginally higher sensitivity, specificity, and accuracy. However, [15] addresses only three classes, whereas our model distinguishes four, which makes high accuracy harder to achieve. In addition, [15] uses ResNet18, while our model is significantly lighter. In fact, our model is lighter than nearly all the baselines. Basheer et al. [4] combined MRI images with features such as CDR, MMSE, age, and gender, and found that age and gender had a substantial positive impact on performance. In contrast, we achieved superior performance using images alone, since we aimed to evaluate our model on spatial features.
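To make the metric definitions concrete, the sketch below derives one-vs-rest sensitivity, specificity, and accuracy for each class from a multi-class confusion matrix. It then macro-averages them. The paper does not spell out its exact computation, so treating the per-class accuracy in Table 3 as one-vs-rest accuracy is an assumption. The function names and the 2×2 matrix are hypothetical, and every class is assumed to appear at least once.

```python
def one_vs_rest_metrics(confusion):
    """Per-class sensitivity, specificity and accuracy from a square confusion matrix.

    confusion[i][j] = number of samples with true class i predicted as class j.
    """
    n = len(confusion)
    total = sum(sum(row) for row in confusion)
    metrics = []
    for k in range(n):
        tp = confusion[k][k]
        fn = sum(confusion[k]) - tp                       # class k predicted as other
        fp = sum(confusion[i][k] for i in range(n)) - tp  # other predicted as class k
        tn = total - tp - fn - fp
        metrics.append({
            "sensitivity": tp / (tp + fn),  # identical to recall
            "specificity": tn / (tn + fp),  # true-negative rate
            "accuracy": (tp + tn) / total,  # one-vs-rest accuracy
        })
    return metrics


def macro_average(metrics, key):
    """Unweighted mean of one metric over all classes."""
    return sum(m[key] for m in metrics) / len(metrics)


m = one_vs_rest_metrics([[8, 2], [1, 9]])
print(round(macro_average(m, "sensitivity"), 3))  # 0.85
```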

Ablation Study. To assess the impact of JAL on our model's performance in depth, we conducted an ablation study comparing models with and without JAL. The model without JAL is obtained by setting the JSM term in (8) to zero. Fig. 3 presents the results as histograms of model performance across multiple mini-batches of the test set. The overlapping region marks where the with-JAL and without-JAL distributions coincide. It indicates the extent to which the model's performance stays the same regardless of JAL. The histograms show that performance improves substantially on all three metrics (accuracy, sensitivity, and specificity) when JAL is incorporated. This is evidenced by the rightward shift of the blue distributions relative to the yellow ones, and it demonstrates the impact of JAL on the model's decisions. For a more detailed evaluation, Table 3 breaks the results down by class. It shows substantial improvements at every AD stage (CN, MCI, MLD, SEV), further affirming the efficacy of JAL. Table 3 also shows that performance on CN and SEV (severe) is relatively higher than on MCI and MLD. This is because the dementia patterns of MCI and MLD differ more subtly, which makes them harder to discern. Nevertheless, with JAL our model achieves consistently strong results across all stages. The scores in Table 2 are the macro averages of the per-class scores in Table 3.
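As a worked check of the macro-averaging statement, averaging the early-fusion w/ JAL rows of Table 3 over the four classes reproduces the early-fusion entries of Table 2 up to rounding (91.31, 95, and 92.72):

```latex
\begin{aligned}
\text{Accuracy}_{\text{macro}}    &= \tfrac{1}{4}\,(98.8 + 87.33 + 86.59 + 92.5) = \tfrac{365.22}{4} = 91.305\%,\\
\text{Specificity}_{\text{macro}} &= \tfrac{1}{4}\,(99.12 + 92.01 + 99 + 90)     = \tfrac{380.13}{4} = 95.0325\%,\\
\text{Sensitivity}_{\text{macro}} &= \tfrac{1}{4}\,(90.3 + 94.6 + 94 + 92)       = \tfrac{370.9}{4}  = 92.725\%.
\end{aligned}
```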

Table 2: Comparison with state-of-the-art results reported on the OASIS-3 dataset for AD classification
Model Modalities Classes Sensitivity (%) Specificity (%) Accuracy (%)
Salami et al. [24] MRI AD, CN 86.01 85.04 87.75
Massalimova et al. [15] MRI CN, MCI, AD 96 96 96
Lazli et al. [13] MRI, PET AD, CN 92.00 91.78 91.46
Basheer et al. [4] MRI, features AD, CN 82.3 *NP 92.3
Castellano et al. [5] PET AD, CN NP NP 80
Our work MRI, CT (Early) CN, MCI, MOD, SEV 92.72 95 91.31
MRI, CT (Late) CN, MCI, MOD, SEV 93.5 93.5 95.37
*NP: Not Provided
Figure 3: Ablation study on JAL in terms of performance histograms. (a-c): Early fusion, (d-f): Late fusion. Panels: (a) Accuracy, (b) Sensitivity, (c) Specificity, (d) Accuracy, (e) Sensitivity, (f) Specificity.
Table 3: Ablation study comparing model performance with and without JAL in Late and Early Fusion setups.
Fusion Loss function Sensitivity (%) Specificity (%) Accuracy (%)
CN MCI MLD SEV CN MCI MLD SEV CN MCI MLD SEV
Early w/o JAL 80.34 78.96 74.3 80.96 84.34 89.5 84.34 89 80 84.3 80.21 89
Early w/ JAL 90.3 94.6 94 92 99.12 92.01 99 90 98.8 87.33 86.59 92.5
Late w/o JAL 87.7 87.9 85.5 87.6 87.8 86.6 86.6 85.2 88.8 86.1 88.4 87.3
Late w/ JAL 95.3 93.3 93.3 91.5 99.6 92 92 92.2 99.8 92 93 96.7
Figure 4: Visualization of larger gradients in JSM-indicated deformation areas for MRI and CT modalities.

Interpretability. We examined how closely the volumetric changes characterized by JSM match the decision-making process of the neural network. When we plotted the gradients overlaid on their corresponding input images, we observed that they closely align with the patterns highlighted by the JSM. Fig. 4 illustrates this with samples from the dataset. It shows axial, sagittal, and coronal views of the MRI and CT images, juxtaposed with the corresponding JSM views that highlight brain deformations. We attribute this alignment to JAL, which incorporates the JSM into the model debugging process. The alignment in turn contributes to model trustworthiness by promoting transparency.
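The section does not specify how elevated gradients are selected or rendered for Fig. 4. The sketch below is therefore a hypothetical illustration only. It assumes precomputed input gradients for one 2D slice and min-max normalizes their magnitudes. It keeps values at or above a chosen percentile and alpha-blends them onto an intensity image scaled to [0, 1]. The percentile, the blending rule, and all function names are assumptions.

```python
def normalize_abs(grad):
    """Min-max normalize |gradient| values of a 2D slice to [0, 1]."""
    flat = [abs(g) for row in grad for g in row]
    lo, hi = min(flat), max(flat)
    span = (hi - lo) or 1.0  # avoid division by zero on a constant slice
    return [[(abs(g) - lo) / span for g in row] for row in grad]


def elevated_mask(saliency, percentile=90.0):
    """Mark pixels whose saliency is at or above the given percentile."""
    flat = sorted(v for row in saliency for v in row)
    idx = min(len(flat) - 1, int(len(flat) * percentile / 100))
    threshold = flat[idx]
    return [[v >= threshold for v in row] for row in saliency]


def overlay(image, saliency, mask, alpha=0.5):
    """Alpha-blend saliency onto a [0, 1] intensity image where the mask is set."""
    return [
        [(1 - alpha) * px + alpha * s if m else px
         for px, s, m in zip(img_row, sal_row, mask_row)]
        for img_row, sal_row, mask_row in zip(image, saliency, mask)
    ]


# Tiny hypothetical 2x2 slice with precomputed input gradients.
image = [[0.2, 0.4], [0.6, 0.8]]
grad = [[0.0, -1.0], [0.5, 2.0]]
sal = normalize_abs(grad)
mask = elevated_mask(sal, percentile=75)
blended = overlay(image, sal, mask)
```

For actual figures, one would more likely draw the slice and the masked saliency with two matplotlib `imshow` calls, the second using a colormap and an `alpha` value. The pure-Python blend above only shows the selection logic.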

Furthermore, incorporating JSM into JAL during model debugging enables the model to learn from and adapt to the highlighted deformations. JAL thus also serves as an effective tool for improving model performance (Table 3), supporting more accurate and better-informed decisions.

5 Conclusion

This paper introduces a new approach to trustworthy medical diagnosis that addresses two key challenges: model explainability and reliability. For explainability, we leverage Jacobian saliency maps (JSM) to provide an informative and interpretable guide for feature learning and to capture subtle morphological changes associated with the disease. For reliability, we incorporate JSM into the loss function as a self-debugger that directs the model to critical (disease-relevant) regions during training, thereby avoiding Clever Hans behavior. Our approach not only helps rectify erroneous predictions but also identifies, in a post-hoc manner, regions with elevated gradients to enhance interpretability. (Note that the post-hoc step serves visualization only; our main approach, JAL-based debugging, operates during training.) Our extensive evaluation underscores the success of JSM via a Jacobian-augmented loss function (JAL). JAL yields substantial accuracy improvements (by up to 10%) and greater interpretability in identifying the brain areas that drive diagnostic predictions. Our XAI approach also works seamlessly with our multimodal data fusion methods and provides explanations in both early and late fusion setups.

References

  • [1] Abbas, S.Q., et al.: Transformed domain convolutional neural network for Alzheimer's disease diagnosis using structural MRI. Pattern Recognition 133, 109031 (2023)
  • [2] Altay, F., et al.: Preclinical stage Alzheimer's disease detection using magnetic resonance image scans. In: Proceedings of the AAAI Conference on Artificial Intelligence. vol. 35, pp. 15088–15097 (2021)
  • [3] Avants, B.B., et al.: Advanced normalization tools (ANTS). Insight J 2(365), 1–35 (2009)
  • [4] Basheer, S., et al.: Computational modeling of dementia prediction using deep neural network: analysis on OASIS dataset. IEEE Access 9, 42449–42462 (2021)
  • [5] Castellano, G., et al.: Detection of dementia through 3D convolutional neural networks based on amyloid PET. In: 2021 IEEE Symposium Series on Computational Intelligence (SSCI). pp. 1–6. IEEE (2021)
  • [6] El-Sappagh, S., et al.: A multilayer multimodal detection and prediction model based on explainable artificial intelligence for Alzheimer's disease. Scientific Reports 11(1), 2660 (2021)
  • [7] He, H., et al.: ADASYN: Adaptive synthetic sampling approach for imbalanced learning. In: 2008 IEEE International Joint Conference on Neural Networks (IEEE World Congress on Computational Intelligence). pp. 1322–1328. IEEE (2008)
  • [8] Hühn, J., Hüllermeier, E.: FURIA: an algorithm for unordered fuzzy rule induction. Data Mining and Knowledge Discovery 19, 293–319 (2009)
  • [9] Jenkinson, M., et al.: Improved optimization for the robust and accurate linear registration and motion correction of brain images. NeuroImage 17(2), 825–841 (2002)
  • [10] Khare, S.K., et al.: Adazd-Net: Automated adaptive and explainable Alzheimer's disease detection system using EEG signals. Knowledge-Based Systems 278, 110858 (2023)
  • [11] Kuijf, H.J., et al.: Registration of brain CT images to an MRI template for the purpose of lesion-symptom mapping. In: Multimodal Brain Image Analysis: Third International Workshop, MBIA 2013, Held in Conjunction with MICCAI 2013, Japan, Proceedings 3. pp. 119–128. Springer (2013)
  • [12] Laugel, T., et al.: The dangers of post-hoc interpretability: Unjustified counterfactual explanations. arXiv preprint arXiv:1907.09294 (2019)
  • [13] Lazli, L., et al.: Computer-aided diagnosis system of Alzheimer's disease based on multimodal fusion: tissue quantification based on the hybrid fuzzy-genetic-possibilistic model and discriminative classification based on the SVDD model. Brain Sciences 9(10), 289 (2019)
  • [14] Lundberg, S.M., et al.: A unified approach to interpreting model predictions. Advances in Neural Information Processing Systems 30 (2017)
  • [15] Massalimova, A., et al.: Input agnostic deep learning for Alzheimer's disease classification using multimodal MRI images. In: 2021 43rd Annual International Conference of the IEEE Engineering in Medicine & Biology Society (EMBC). pp. 2875–2878. IEEE (2021)
  • [16] Mattes, D., et al.: PET-CT image registration in the chest using free-form deformations. IEEE Transactions on Medical Imaging 22(1), 120–128 (2003)
  • [17] Morris, M.D.: Factorial sampling plans for preliminary computational experiments. Technometrics 33(2), 161–174 (1991)
  • [18] Mulyadi, A.W., et al.: Estimating explainable Alzheimer's disease likelihood map via clinically-guided prototype learning. NeuroImage 273, 120073 (2023)
  • [19] Mustafa, Y., Elmahallawy, M., Luo, T., Eldawlatly, S.: A brain-computer interface augmented reality framework with auto-adaptive SSVEP recognition. In: 2023 IEEE International Conference on Metrology for eXtended Reality, Artificial Intelligence and Neural Engineering (MetroXRAINE). pp. 799–804. IEEE (2023)
  • [20] Mustafa, Y., Luo, T.: Diagnosing Alzheimer's disease using early-late multimodal data fusion with Jacobian maps. In: IEEE International Conference on E-health Networking, Application & Services (Healthcom) (2023)
  • [21] Ribeiro, M.T., et al.: "Why should I trust you?" Explaining the predictions of any classifier. In: Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining. pp. 1135–1144 (2016)
  • [22] Riyahi, S., et al.: Quantifying local tumor morphological changes with Jacobian map for prediction of pathologic tumor response to chemo-radiotherapy in locally advanced esophageal cancer. Physics in Medicine & Biology 63(14), 145020 (2018)
  • [23] Ross, A.S., et al.: Right for the right reasons: training differentiable models by constraining their explanations. In: Proceedings of the 26th International Joint Conference on Artificial Intelligence (IJCAI). pp. 2662–2670 (2017)
  • [24] Salami, F., et al.: Designing a clinical decision support system for Alzheimer's diagnosis on OASIS-3 data set. Biomedical Signal Processing and Control 74, 103527 (2022)
  • [25] Selvaraju, R.R., et al.: Grad-CAM: Visual explanations from deep networks via gradient-based localization. In: Proceedings of the IEEE International Conference on Computer Vision. pp. 618–626 (2017)
  • [26] Smith, S.M.: Fast robust automated brain extraction. Human Brain Mapping 17(3), 143–155 (2002)
  • [27] Tustison, N.J., et al.: Explicit B-spline regularization in diffeomorphic image registration. Frontiers in Neuroinformatics 7, 39 (2013)
  • [28] Venugopalan, J., et al.: Multimodal deep learning models for early detection of Alzheimer's disease stage. Scientific Reports 11(1), 3254 (2021)
  • [29] Yu, L., et al.: A novel explainable neural network for Alzheimer's disease diagnosis. Pattern Recognition 131, 108876 (2022)
  • [30] Zhang, X., et al.: An explainable 3D residual self-attention deep neural network for joint atrophy localization and Alzheimer's disease diagnosis using structural MRI. IEEE Journal of Biomedical and Health Informatics 26(11), 5289–5297 (2021)