² The Hong Kong University of Science and Technology, Hong Kong, China
³ Institute of High Performance Computing, Agency for Science, Technology and Research, Singapore
⁴ University of Cambridge, UK
DiffMIC: Dual-Guidance Diffusion Network for Medical Image Classification
Abstract
Diffusion Probabilistic Models have recently shown remarkable performance in generative image modeling, attracting significant attention in the computer vision community. However, while a substantial amount of diffusion-based research has focused on generative tasks, few studies have applied diffusion models to general medical image classification. In this paper, we propose the first diffusion-based model (named DiffMIC) to address general medical image classification by eliminating unexpected noise and perturbations in medical images and robustly capturing semantic representations. To achieve this goal, we devise a dual conditional guidance strategy that conditions each diffusion step on multiple granularities to improve step-wise regional attention. Furthermore, we propose learning the mutual information in each granularity by enforcing Maximum-Mean Discrepancy regularization during the diffusion forward process. We evaluate the effectiveness of DiffMIC on three medical classification tasks with different image modalities: placental maturity grading on ultrasound images, skin lesion classification on dermatoscopic images, and diabetic retinopathy grading on fundus images. Our experimental results demonstrate that DiffMIC outperforms state-of-the-art methods by a significant margin, indicating the universality and effectiveness of the proposed model. Our code is publicly available at https://github.com/scott-yjyang/DiffMIC.
Keywords: diffusion probabilistic model · medical image classification · placental maturity · skin lesion · diabetic retinopathy
1 Introduction
Medical image analysis plays an indispensable role in clinical therapy because of the importance of digital medical imaging in modern healthcare [5]. Medical image classification, a fundamental step in medical image analysis, aims to assign medical images of different modalities to diagnostic categories according to given criteria. An automatic and reliable classification system can help doctors interpret medical images quickly and accurately. Numerous solutions for medical image classification have been developed over the past decades, most of which adopt deep learning models ranging from popular CNNs to vision transformers [8, 9, 22, 23]. These methods have the potential to reduce the time and effort required for manual classification and to improve the accuracy and consistency of results. However, medical images of diverse modalities, such as ultrasound (US), dermatoscopic, and fundus images, still challenge existing methods because they contain various ambiguous lesions and fine-grained tissues. Moreover, acquiring medical images under hardware limitations can introduce noise and blur, which degrade image quality and thus demand more effective feature representation modeling for robust classification.
Recently, Denoising Diffusion Probabilistic Models (DDPM) [14] have achieved excellent results in image generation and synthesis tasks [21, 2, 6, 26] by iteratively refining a noisy sample into a high-quality image. Specifically, DDPM is a generative model based on a Markov chain: a forward diffusion process gradually corrupts the data with Gaussian noise until it approaches a simple target distribution, and a learned reverse process removes the noise step by step to model the data distribution. Although a few pioneering works have adopted diffusion models for image segmentation and object detection tasks [1, 29, 4, 12], their potential for high-level vision has yet to be fully explored.
Motivated by the achievements of diffusion probabilistic models in generative image modeling, 1) we present a novel denoising diffusion-based model named DiffMIC for accurate classification of diverse medical image modalities. To the best of our knowledge, we are the first to propose a diffusion-based model for general medical image classification. Our method can appropriately eliminate undesirable noise in medical images because the diffusion process is stochastic at each sampling step. 2) In particular, we introduce a Dual-granularity Conditional Guidance (DCG) strategy to guide the denoising procedure, conditioning each step on both global and local priors in the diffusion process. By conducting the diffusion process on smaller patches, our method can distinguish critical tissues at a fine-grained level. 3) Moreover, we introduce Condition-specific Maximum-Mean Discrepancy (MMD) regularization to learn the mutual information in the latent space for each granularity, enabling the network to model a robust feature representation shared by the whole image and its patches. 4) We evaluate the effectiveness of DiffMIC on three 2D medical image classification tasks: placental maturity grading, skin lesion classification, and diabetic retinopathy grading. The experimental results demonstrate that our diffusion-based classification method consistently and significantly surpasses state-of-the-art methods on all three tasks.
2 Method

Figure 1 shows the schematic illustration of our network for medical image classification. Given an input medical image $x$, we pass it to an image encoder to obtain the image feature embedding $\rho(x)$, and to a dual-granularity conditional guidance (DCG) model to produce the global prior $\hat{y}_g$ and the local prior $\hat{y}_l$. At the training stage, we apply the diffusion process to the ground truth $y_0$ with different priors to generate three noisy variables $y_t^g$, $y_t^l$, and $y_t$ (conditioned on the global prior, the local prior, and the dual priors, respectively). Then, we concatenate each of the three noisy variables $y_t^g$, $y_t^l$, and $y_t$ with its respective prior(s) and project the result into a latent space. We further integrate each of the three projected embeddings with the image feature embedding $\rho(x)$ in the denoising U-Net and predict the noise sampled for $y_t^g$, $y_t^l$, and $y_t$. We devise a condition-specific maximum-mean discrepancy (MMD) regularization loss on the predicted noise of $y_t^g$ and $y_t^l$, and employ a mean squared error (MSE) noise estimation loss on the predicted noise of $y_t$; together, these losses train our DiffMIC network.
Diffusion Model. Following DDPM [14], our diffusion model has two stages: a forward diffusion stage (training) and a reverse diffusion stage (inference). In the forward process, Gaussian noise is added to the ground-truth response variable $y_0$ according to a time step $t$ sampled from a uniform distribution over $[1, T]$, and the resulting noisy variables are denoted as $\{y_1, \dots, y_t, \dots, y_T\}$. Following the standard implementation of DDPM, we adopt a UNet $\epsilon_\theta$ as the denoising network to parameterize the reverse diffusion process and learn the noise distribution of the forward process. In the reverse diffusion process, the trained UNet generates the final prediction $\hat{y}_0$ by transforming the noisy variable distribution $p_\theta(y_T)$ into the ground-truth distribution $p_\theta(y_0)$:
$$p_\theta(y_{0:T-1}\mid y_T,\rho(x))=\prod_{t=1}^{T}p_\theta(y_{t-1}\mid y_t,\rho(x)),\qquad p_\theta(y_T)=\mathcal{N}\left(\frac{\hat{y}_g+\hat{y}_l}{2},\mathbb{I}\right), \tag{1}$$
where $\theta$ denotes the parameters of the denoising UNet, $\mathcal{N}(\cdot,\cdot)$ denotes the Gaussian distribution, and $\mathbb{I}$ is the identity matrix.
2.1 Dual-granularity Conditional Guidance (DCG) Strategy
DCG Model. In most conditional DDPMs, the condition is a single piece of given information. However, medical image classification is particularly challenging due to the ambiguity of objects. It is difficult to differentiate lesions and tissues from the background, especially in low-contrast image modalities such as ultrasound. Moreover, unexpected noise or blur may exist in regions of interest (ROIs), hindering the understanding of high-level semantics. Taking only the raw image as the condition at each diffusion step is therefore insufficient for robustly learning fine-grained information and degrades classification performance.
To alleviate this issue, we design a Dual-granularity Conditional Guidance (DCG) strategy for encoding each diffusion step. Specifically, we introduce a DCG model to compute the global and local conditional priors for the diffusion process. Similar to the diagnostic process of a radiologist, the model obtains a holistic understanding from the global prior and concentrates on lesion areas through the local prior while removing the negative effects of noise. As shown in Figure 1 (c), for the global stream, the raw image $x$ is fed into the global encoder and then a convolutional layer to generate a saliency map of the whole image. The global prior $\hat{y}_g$ is then predicted from this saliency map by averaging its responses. For the local stream, we crop the ROIs whose responses in the saliency map of the whole image are significant. Each ROI is fed into the local encoder to obtain a feature vector. We then leverage the gated attention mechanism [15] to fuse the feature vectors of all ROIs into a weighted vector, from which a single linear layer computes the local prior $\hat{y}_l$.
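The two DCG streams can be illustrated with a minimal plain-Python sketch. It rests on assumptions that the text does not state: the saliency map has one channel per class (as in GMIC [24]); ROIs are the k positions with the largest class-wise maximum response, with no overlap handling; ROI feature vectors are given directly (the local encoder is omitted); and the gating follows the formulation of Ilse et al. [15]. All function names and weights are hypothetical.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def matvec(w, v):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(wi * vi for wi, vi in zip(row, v)) for row in w]


def global_prior(saliency):
    """Average the responses of each class channel of a C x H x W saliency map."""
    return [sum(sum(row) for row in ch) / (len(ch) * len(ch[0])) for ch in saliency]


def top_k_locations(saliency, k):
    """Return the (row, col) positions with the k largest class-wise maximum responses."""
    h, w = len(saliency[0]), len(saliency[0][0])
    scored = [(max(ch[i][j] for ch in saliency), (i, j)) for i in range(h) for j in range(w)]
    scored.sort(key=lambda s: s[0], reverse=True)
    return [pos for _, pos in scored[:k]]


def gated_attention_pool(feats, V, U, w):
    """Gated attention pooling: a_k = softmax_k(w^T (tanh(V h_k) * sigmoid(U h_k)))."""
    scores = []
    for h in feats:
        gate = [math.tanh(a) * sigmoid(b) for a, b in zip(matvec(V, h), matvec(U, h))]
        scores.append(sum(wi * gi for wi, gi in zip(w, gate)))
    m = max(scores)  # subtract the max for a numerically stable softmax
    exps = [math.exp(s - m) for s in scores]
    z = sum(exps)
    attn = [e / z for e in exps]
    # Attention-weighted average of the ROI feature vectors.
    pooled = [sum(a * h[d] for a, h in zip(attn, feats)) for d in range(len(feats[0]))]
    return pooled, attn


def local_prior(pooled, W, b):
    """One linear layer mapping the pooled ROI vector to class scores."""
    return [s + bi for s, bi in zip(matvec(W, pooled), b)]
```

In this sketch, `top_k_locations(saliency, 6)` would pick the six ROI centers used in the experiments, and the attention weights returned by `gated_attention_pool` sum to one, so the pooled vector is a convex combination of the ROI features.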
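Denoising Model. The noisy variable $y_t$ is sampled in the diffusion process based on the global and local priors computed by the DCG model as follows:
$$y_t=\sqrt{\bar{\alpha}_t}\,y_0+\sqrt{1-\bar{\alpha}_t}\,\epsilon+\left(1-\sqrt{\bar{\alpha}_t}\right)\frac{\hat{y}_g+\hat{y}_l}{2}, \tag{2}$$
where $\epsilon\sim\mathcal{N}(0,\mathbb{I})$, $\bar{\alpha}_t=\prod_{s=1}^{t}\alpha_s$, and $\alpha_t=1-\beta_t$ with a linear noise schedule $\{\beta_t\}_{t=1}^{T}\in(0,1)$. After that, we feed the concatenation of the noisy variable $y_t$ and the dual priors into our denoising UNet to estimate the noise, which is formulated as:
$$\epsilon_\theta(\rho(x),y_t,\hat{y}_g,\hat{y}_l,t)=D\big(E\big(f([y_t,\hat{y}_g,\hat{y}_l]),\rho(x),t\big),t\big), \tag{3}$$
where $f(\cdot)$ denotes the projection layer to the latent space, $[\cdot]$ is the concatenation operation, and $E(\cdot)$ and $D(\cdot)$ are the encoder and decoder of the UNet. Note that the image feature embedding $\rho(x)$ is further integrated with the projected noisy embedding in the UNet so that the model focuses on high-level semantics and thus obtains more robust feature representations. In the forward process, we seek to minimize the noise estimation loss $\mathcal{L}_\epsilon$:
$$\mathcal{L}_\epsilon=\left\|\epsilon-\epsilon_\theta(\rho(x),y_t,\hat{y}_g,\hat{y}_l,t)\right\|^2. \tag{4}$$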
Our method improves on the vanilla diffusion model by conditioning the noise estimation function at each step on priors that combine information derived from the raw image and its ROIs.
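The forward sampling of Eq. (2) and the noise estimation loss of Eq. (4) can be illustrated for a single sample with the plain-Python sketch below. Assumptions not taken from the text: the linear β schedule runs from 10⁻⁴ to 0.02 (the defaults of the original DDPM [14]), the combined dual prior is the mean (ŷ_g + ŷ_l)/2 so that y_t drifts toward the prior as t approaches T, and `eps_model` is a hypothetical stand-in for the denoising UNet.

```python
import math
import random


def linear_alpha_bars(T, beta_1=1e-4, beta_T=0.02):
    """alpha_bar_t = prod_{s<=t} (1 - beta_s) for a linear beta schedule (requires T >= 2)."""
    out, prod = [], 1.0
    for i in range(T):
        beta = beta_1 + (beta_T - beta_1) * i / (T - 1)
        prod *= 1.0 - beta
        out.append(prod)
    return out


def q_sample(y0, prior, t, alpha_bars, eps):
    """Eq. (2): y_t = sqrt(ab) y0 + sqrt(1 - ab) eps + (1 - sqrt(ab)) prior, with t in 1..T."""
    ab = alpha_bars[t - 1]
    return [math.sqrt(ab) * y + math.sqrt(1.0 - ab) * e + (1.0 - math.sqrt(ab)) * p
            for y, e, p in zip(y0, eps, prior)]


def noise_loss(eps, eps_pred):
    """Eq. (4): squared L2 distance between the sampled and the predicted noise."""
    return sum((e - p) ** 2 for e, p in zip(eps, eps_pred))


def training_step(y0, y_g, y_l, eps_model, alpha_bars, rng):
    """One forward-diffusion step and its noise estimation loss for a single sample."""
    T = len(alpha_bars)
    t = rng.randint(1, T)                              # t ~ Uniform{1, ..., T}
    eps = [rng.gauss(0.0, 1.0) for _ in y0]            # eps ~ N(0, I)
    prior = [(g + l) / 2 for g, l in zip(y_g, y_l)]    # combined dual prior (assumed mean)
    y_t = q_sample(y0, prior, t, alpha_bars, eps)
    return noise_loss(eps, eps_model(y_t, y_g, y_l, t))
```

Here `y0` is a one-hot label vector, so every quantity lives in the small space of class scores rather than image space, which is what keeps the diffusion process cheap in this setting.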
2.2 Condition-specific MMD Regularization
Maximum-Mean Discrepancy (MMD) measures the similarity between two distributions by comparing all of their moments [11, 17], which can be achieved efficiently with a kernel function. Inspired by InfoVAE [31], we introduce an additional pair of condition-specific MMD regularization losses to learn the mutual information between the sampled noise distribution and the Gaussian distribution. Specifically, we sample the noisy variable $y_t^g$ from the diffusion process at time step $t$ conditioned only on the global prior $\hat{y}_g$, and then compute an MMD-regularization loss as:
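$$\mathcal{L}_{\mathrm{MMD}}^{g}(n\,\|\,m)=\mathbb{E}\big[K(n,n')\big]-2\,\mathbb{E}\big[K(m,n)\big]+\mathbb{E}\big[K(m,m')\big], \tag{5}$$
with $n\sim\epsilon$ and $m=\epsilon_\theta\big(\rho(x),y_t^g,\hat{y}_g,t\big)$, where $y_t^g=\sqrt{\bar{\alpha}_t}\,y_0+\sqrt{1-\bar{\alpha}_t}\,\epsilon+\big(1-\sqrt{\bar{\alpha}_t}\big)\,\hat{y}_g$. Here, $n'$ and $m'$ are independent copies of $n$ and $m$, and $K(\cdot,\cdot)$ is a positive-definite kernel that embeds distributions in a reproducing kernel Hilbert space. The condition-specific MMD regularization $\mathcal{L}_{\mathrm{MMD}}^{l}$ is also applied to the local prior, as shown in Figure 1 (a). While the general noise estimation loss $\mathcal{L}_\epsilon$ captures the complementary information from both priors, the condition-specific MMD regularization maintains the mutual information between each prior and the target distribution. This also helps the network better model the robust feature representation shared by the dual priors and converge faster and more stably.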
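An empirical version of Eq. (5) is sketched below. It is a minimal illustration under our own assumptions: a Gaussian (RBF) kernel with a fixed bandwidth σ (the text only requires a positive-definite kernel), and the expectations replaced by averages over a mini-batch of noise vectors, which gives the biased (V-statistic) estimate that is always non-negative. The helper names are hypothetical.

```python
import math
import random


def rbf(a, b, sigma=1.0):
    """Gaussian kernel K(a, b) = exp(-||a - b||^2 / (2 sigma^2))."""
    return math.exp(-sum((x - y) ** 2 for x, y in zip(a, b)) / (2.0 * sigma ** 2))


def mean_kernel(xs, ys, sigma):
    """Average kernel value over all pairs (x, y), including x == y pairs."""
    return sum(rbf(x, y, sigma) for x in xs for y in ys) / (len(xs) * len(ys))


def mmd2(n_samples, m_samples, sigma=1.0):
    """Biased estimate of E[K(n, n')] - 2 E[K(m, n)] + E[K(m, m')]."""
    return (mean_kernel(n_samples, n_samples, sigma)
            - 2.0 * mean_kernel(m_samples, n_samples, sigma)
            + mean_kernel(m_samples, m_samples, sigma))
```

With `n_samples` drawn from N(0, I) and `m_samples` being the predicted noise of the global (or local) branch, the estimate stays close to zero when the predictions follow a standard Gaussian and grows when they are shifted or rescaled.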
2.3 Training and Inference Scheme
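Total loss. By adding the noise estimation loss and the MMD-regularization losses, we compute the total loss $\mathcal{L}_{\mathrm{diff}}$ of our denoising network as follows:
$$\mathcal{L}_{\mathrm{diff}}=\mathcal{L}_\epsilon+\lambda\left(\mathcal{L}_{\mathrm{MMD}}^{g}+\mathcal{L}_{\mathrm{MMD}}^{l}\right), \tag{6}$$
where $\lambda$ is a balancing hyper-parameter, empirically set to $\lambda=0.5$.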
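How Eq. (6) combines the three diffusion branches over a mini-batch can be sketched as follows. This is a plain-Python illustration, not the released implementation, and several details are assumptions: `eps_model` is a hypothetical callable standing in for the denoising network (taking the noisy variable, a tuple of conditioning priors, and t); the dual-prior branch uses the mean of the two priors; each sample shares one t and one ε across its three branches; and the MMD uses a Gaussian kernel with unit bandwidth.

```python
import math
import random


def rbf(a, b):
    """Gaussian kernel with unit bandwidth."""
    return math.exp(-0.5 * sum((x - y) ** 2 for x, y in zip(a, b)))


def mmd2(xs, ys):
    """Biased MMD^2 estimate between two sets of vectors."""
    def k(us, vs):
        return sum(rbf(u, v) for u in us for v in vs) / (len(us) * len(vs))
    return k(xs, xs) - 2.0 * k(ys, xs) + k(ys, ys)


def q_sample(y0, prior, ab, eps):
    """Forward diffusion toward a given prior, as in Eq. (2)."""
    return [math.sqrt(ab) * y + math.sqrt(1 - ab) * e + (1 - math.sqrt(ab)) * p
            for y, e, p in zip(y0, eps, prior)]


def diffmic_loss(batch, eps_model, alpha_bars, rng, lam=0.5):
    """batch: list of (y0, y_g, y_l) triples. Returns L_diff and its components."""
    T = len(alpha_bars)
    eps_all, pred_dual, pred_g, pred_l = [], [], [], []
    for y0, y_g, y_l in batch:
        t = rng.randint(1, T)
        ab = alpha_bars[t - 1]
        eps = [rng.gauss(0.0, 1.0) for _ in y0]
        y_dual = [(g + l) / 2 for g, l in zip(y_g, y_l)]
        eps_all.append(eps)
        # Three branches: dual priors, global prior only, local prior only.
        pred_dual.append(eps_model(q_sample(y0, y_dual, ab, eps), (y_g, y_l), t))
        pred_g.append(eps_model(q_sample(y0, y_g, ab, eps), (y_g,), t))
        pred_l.append(eps_model(q_sample(y0, y_l, ab, eps), (y_l,), t))
    # MSE noise estimation loss on the dual-prior branch, averaged over the batch.
    l_eps = sum(sum((e - p) ** 2 for e, p in zip(es, ps))
                for es, ps in zip(eps_all, pred_dual)) / len(batch)
    # Condition-specific MMD terms on the single-prior branches.
    l_g, l_l = mmd2(eps_all, pred_g), mmd2(eps_all, pred_l)
    return l_eps + lam * (l_g + l_l), {"eps": l_eps, "mmd_g": l_g, "mmd_l": l_l}
```

The MSE term supervises each sample individually, whereas the MMD terms compare whole sets of predicted noise vectors against the Gaussian samples, which is why they are computed once per mini-batch.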
Training details. The diffusion model in this study follows the standard DDPM training process, where the diffusion time step $t$ is sampled from a uniform distribution over $[1, T]$ and the noise is scheduled linearly from $\beta_1$ to $\beta_T$. We adopt ResNet18 as the image encoder $\rho(\cdot)$. Following [12], we concatenate $y_t$, $\hat{y}_g$, and $\hat{y}_l$ and apply a linear layer with an output dimension of 6144 to obtain the fused vector in the latent space. To condition the response embedding on the time step, we perform a Hadamard product between the fused vector and a timestep embedding. We then integrate the image feature embedding and the response embedding by performing another Hadamard product between them. The output vector is sent through two consecutive fully-connected layers, each followed by a Hadamard product with a timestep embedding. Finally, a fully-connected layer predicts the noise, with an output dimension equal to the number of classes. All fully-connected layers except the output layer are accompanied by a batch normalization layer and a Softplus non-linearity. For the DCG model, the backbone of both the global and the local stream is ResNet. We adopt the standard cross-entropy loss as the objective of the DCG model. After pretraining the DCG model for 10 warm-up epochs, we jointly train the denoising diffusion model and the DCG model, resulting in an end-to-end DiffMIC for medical image classification.
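The fully-connected noise predictor described above can be sketched in plain Python. Beyond the text, the sketch assumes: learned per-timestep embedding tables (as in CARD [12]), here randomly initialized; the order linear layer, then timestep Hadamard product, then Softplus; batch normalization omitted because the sketch processes one sample at a time; a small hidden width in place of 6144; and an image feature embedding already projected to that width. Class and argument names are hypothetical.

```python
import math
import random


def softplus(z):
    """Numerically stable log(1 + exp(z))."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


class Linear:
    """Fully-connected layer with weights stored as nested lists."""

    def __init__(self, n_in, n_out, rng):
        bound = 1.0 / math.sqrt(n_in)
        self.w = [[rng.uniform(-bound, bound) for _ in range(n_in)] for _ in range(n_out)]
        self.b = [0.0] * n_out

    def __call__(self, v):
        return [sum(wi * vi for wi, vi in zip(row, v)) + b for row, b in zip(self.w, self.b)]


class TimeConditionedLinear:
    """Linear layer whose output is multiplied elementwise by a per-timestep embedding."""

    def __init__(self, n_in, n_out, n_steps, rng):
        self.lin = Linear(n_in, n_out, rng)
        self.t_embed = [[rng.uniform(0.5, 1.5) for _ in range(n_out)] for _ in range(n_steps)]

    def __call__(self, v, t):  # t in 1..n_steps
        return [a * g for a, g in zip(self.lin(v), self.t_embed[t - 1])]


class DenoiserHead:
    """Hypothetical sketch of the noise predictor: fuse, gate by t and rho(x), two FCs, output."""

    def __init__(self, num_classes, hidden, n_steps, seed=0):
        rng = random.Random(seed)
        self.fuse = TimeConditionedLinear(3 * num_classes, hidden, n_steps, rng)
        self.fc1 = TimeConditionedLinear(hidden, hidden, n_steps, rng)
        self.fc2 = TimeConditionedLinear(hidden, hidden, n_steps, rng)
        self.out = Linear(hidden, num_classes, rng)

    def __call__(self, y_t, prior_g, prior_l, img_feat, t):
        # Concatenate [y_t, y_g, y_l] (list concatenation), project, gate by timestep.
        h = [softplus(a) for a in self.fuse(y_t + prior_g + prior_l, t)]
        # Hadamard product with the image feature embedding rho(x).
        h = [a * f for a, f in zip(h, img_feat)]
        # Two more time-conditioned fully-connected layers.
        h = [softplus(a) for a in self.fc1(h, t)]
        h = [softplus(a) for a in self.fc2(h, t)]
        # Output layer: one noise value per class, no non-linearity.
        return self.out(h)
```

Because every intermediate vector has the same width, both the timestep embedding and ρ(x) act as elementwise gates rather than extra inputs, which keeps the head small relative to the image encoder.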
Inference stage. As displayed in Figure 1 (b), given an input image $x$, we first feed it into the DCG model to obtain the dual priors $\hat{y}_g$ and $\hat{y}_l$. Then, following the DDPM pipeline, the final prediction $\hat{y}_0$ is iteratively denoised from a random prediction $y_T$ using the trained UNet conditioned on the dual priors and the image feature embedding $\rho(x)$.
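The inference loop can be sketched in plain Python using the CARD-style [12] posterior, in which the prior appears both in the initialization and in every reverse step. Assumptions not stated in the text: y_T is drawn from N(ŷ, I) with ŷ = (ŷ_g + ŷ_l)/2 passed in as `prior`, the linear β schedule uses DDPM's default endpoints, the predicted class is the argmax of the final ŷ_0, and `eps_model` is a hypothetical callable wrapping the trained UNet together with ρ(x) and the priors.

```python
import math
import random


def make_schedule(T, beta_1=1e-4, beta_T=0.02):
    """Linear betas and their cumulative products alpha_bar_t (requires T >= 2)."""
    betas = [beta_1 + (beta_T - beta_1) * i / (T - 1) for i in range(T)]
    alpha_bars, prod = [], 1.0
    for b in betas:
        prod *= 1.0 - b
        alpha_bars.append(prod)
    return betas, alpha_bars


def sample(eps_model, prior, T, rng):
    """Reverse diffusion from y_T ~ N(prior, I) down to an estimate of y_0."""
    betas, abar = make_schedule(T)
    y = [p + rng.gauss(0.0, 1.0) for p in prior]
    for t in range(T, 0, -1):
        beta_t, ab_t = betas[t - 1], abar[t - 1]
        eps = eps_model(y, t)
        # Invert the forward sampling to estimate y_0 from y_t and the predicted noise.
        y0_hat = [(yi - (1 - math.sqrt(ab_t)) * p - math.sqrt(1 - ab_t) * e) / math.sqrt(ab_t)
                  for yi, p, e in zip(y, prior, eps)]
        if t == 1:
            return y0_hat
        ab_prev = abar[t - 2]
        # Posterior q(y_{t-1} | y_t, y_0, prior) has mean g0*y_0 + g1*y_t + g2*prior.
        g0 = beta_t * math.sqrt(ab_prev) / (1 - ab_t)
        g1 = (1 - ab_prev) * math.sqrt(1 - beta_t) / (1 - ab_t)
        g2 = 1 - g0 - g1
        std = math.sqrt((1 - ab_prev) * beta_t / (1 - ab_t))
        y = [g0 * a + g1 * b + g2 * p + std * rng.gauss(0.0, 1.0)
             for a, b, p in zip(y0_hat, y, prior)]


def predict_class(y0_hat):
    """Index of the largest entry of the denoised response vector."""
    return max(range(len(y0_hat)), key=lambda c: y0_hat[c])
```

With an `eps_model` that recovered the true noise exactly, every step would reproduce the ground-truth response, which makes such an oracle a convenient sanity check for the posterior coefficients.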
3 Experiments
Datasets and Evaluation: We evaluate the effectiveness of our network on an in-house dataset and two public datasets, i.e., PMG2000, HAM10000 [27], and APTOS2019 [16]. (a) PMG2000. We collect and annotate a benchmark dataset (denoted as PMG2000) for placental maturity grading (PMG) with four categories.¹ PMG2000 is composed of 2,098 ultrasound images, and we randomly divide the entire dataset into a training set and a test set at an 8:2 ratio. (b) HAM10000. HAM10000 [27] is from the Skin Lesion Analysis Toward Melanoma Detection 2018 challenge and contains 10,015 skin lesion images in 7 predefined categories. (c) APTOS2019. In APTOS2019 [16], a total of 3,662 fundus images are labeled to grade diabetic retinopathy into five categories. Following the same protocol as [10], we split HAM10000 and APTOS2019 into training and test sets at a 7:3 ratio. These three datasets cover different medical image modalities: PMG2000 consists of class-balanced gray-scale ultrasound images; HAM10000 consists of color but class-imbalanced dermatoscopic images; and APTOS2019 is another class-imbalanced dataset of color fundus images. We adopt two widely used metrics, Accuracy and F1-score, to quantitatively compare our framework against existing SOTA methods.
¹ Our data collection is approved by the Institutional Review Board (IRB).
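The two metrics can be computed as in the sketch below. The text does not state how the F1-score is averaged over classes; the sketch assumes macro averaging (the unweighted mean of per-class F1 scores), which weighs rare classes equally and is a common choice for class-imbalanced data such as HAM10000 and APTOS2019. Function names are hypothetical.

```python
def accuracy(y_true, y_pred):
    """Fraction of samples whose predicted class matches the label."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def per_class_f1(y_true, y_pred, c):
    """F1 = 2TP / (2TP + FP + FN) for class c (0 if the class never occurs)."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


def macro_f1(y_true, y_pred, num_classes):
    """Unweighted mean of the per-class F1 scores."""
    return sum(per_class_f1(y_true, y_pred, c) for c in range(num_classes)) / num_classes
```

For example, with labels `[0, 0, 1, 1, 2, 2]` and predictions `[0, 1, 1, 1, 2, 0]`, accuracy is 4/6 while the per-class F1 scores are 0.5, 0.8, and 2/3, so a single minority-class error lowers macro F1 more than it lowers accuracy.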
Implementation Details: Our framework is implemented in PyTorch on one NVIDIA RTX 3090 GPU. We center-crop each image and then resize the cropped image to a fixed spatial resolution. Random flipping and rotation are applied for data augmentation during training. In all experiments, we extract six ROI patches from each image. We train our network end-to-end with a batch size of 32 and the Adam optimizer. When training the entire network, the initial learning rate is set to 1×10^-3 for the denoising U-Net and to 2×10^-4 for the DCG model (see Section 2.1). Following [20], the number of training epochs is set to 1,000 for all three datasets. At inference, we empirically set the total number of diffusion time steps to 100 for PMG2000, 250 for HAM10000, and 60 for APTOS2019, which is much smaller than in most existing works [14, 12]. The average running time of DiffMIC is about 0.056 seconds per image at this resolution.
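For reference, the hyperparameters stated in this section and in Section 2.3 can be collected into a single configuration object. The dataclass and its field names are hypothetical; the input resolution is omitted because it is not recoverable from the text.

```python
from dataclasses import dataclass, field


@dataclass(frozen=True)
class DiffMICConfig:
    batch_size: int = 32
    optimizer: str = "Adam"
    lr_denoiser: float = 1e-3        # denoising U-Net
    lr_dcg: float = 2e-4             # DCG model
    epochs: int = 1000
    dcg_warmup_epochs: int = 10      # DCG pretraining before joint training
    num_rois: int = 6                # ROI patches extracted per image
    mmd_weight: float = 0.5          # lambda in Eq. (6)
    latent_dim: int = 6144           # output size of the fusion linear layer
    # Total diffusion time steps used at inference, per dataset.
    inference_steps: dict = field(default_factory=lambda: {
        "PMG2000": 100, "HAM10000": 250, "APTOS2019": 60,
    })
```

Freezing the dataclass prevents accidental reassignment of a field after an experiment has started; per-run changes are made by constructing a new instance, e.g. `DiffMICConfig(batch_size=16)`.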
Comparison with long-tailed medical image classification methods on HAM10000 and APTOS2019.

| Dataset | Metric | LDAM [3] | OHEM [25] | MTL [18] | DANIL [10] | CL [20] | ProCo [30] | Our DiffMIC |
|---|---|---|---|---|---|---|---|---|
| HAM10000 | Accuracy | 0.857 | 0.818 | 0.811 | 0.825 | 0.865 | 0.887 | 0.906 |
| HAM10000 | F1-score | 0.734 | 0.660 | 0.667 | 0.674 | 0.739 | 0.763 | 0.816 |
| APTOS2019 | Accuracy | 0.813 | 0.813 | 0.813 | 0.825 | 0.825 | 0.837 | 0.858 |
| APTOS2019 | F1-score | 0.620 | 0.631 | 0.632 | 0.660 | 0.652 | 0.674 | 0.716 |
Comparison with State-of-the-art Methods: In Table 2(a), we compare DiffMIC on PMG2000 against many state-of-the-art CNN and transformer-based networks, including ResNet [13], Vision Transformer (ViT) [7], Swin Transformer (Swin) [19], Pyramid Vision Transformer (PVT) [28], and a medical image classification method, GMIC [24]. Among these methods, PVT achieves the highest Accuracy (0.907) and the highest F1-score (0.902). Our method further outperforms PVT, improving the Accuracy from 0.907 to 0.931 and the F1-score from 0.902 to 0.926.
Note that both HAM10000 and APTOS2019 suffer from class imbalance. Hence, we compare DiffMIC against state-of-the-art long-tailed medical image classification methods and report the results in Table 2(b). On HAM10000, our method improves over the second-best method, ProCo, by 0.019 in Accuracy and 0.053 in F1-score. On APTOS2019, it improves over ProCo by 0.021 in Accuracy and 0.042 in F1-score.
Ablation study on PMG2000.

| Method | Diffusion | DCG | MMD-reg | Accuracy | F1-score |
|---|---|---|---|---|---|
| basic | - | - | - | 0.879 | 0.881 |
| C1 | ✓ | - | - | 0.906 | 0.899 |
| C2 | ✓ | ✓ | - | 0.920 | 0.914 |
| Our method | ✓ | ✓ | ✓ | 0.931 | 0.926 |
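[Figure 2: t-SNE visualization of the denoised feature embeddings at consecutive reverse time steps. Panels: PMG2000, HAM10000, APTOS2019.]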
Ablation Study: We conduct extensive experiments to evaluate the effectiveness of the major modules of our network. To do so, we build three baseline networks from our method. The first baseline (denoted as “basic”) removes all diffusion operations and the MMD regularization loss from our network, which makes “basic” equivalent to the classical ResNet18. Then, we apply the vanilla diffusion process to “basic” to construct a second baseline (denoted as “C1”), and further add our dual-granularity conditional guidance to the diffusion process to build a third baseline, denoted as “C2”. Hence, “C2” is equivalent to our network without the MMD regularization loss. Table 2 reports the Accuracy and F1-score of our method and the three baseline networks on our PMG2000 dataset. Compared to “basic”, “C1” improves Accuracy by 0.027 and F1-score by 0.018, which indicates that the diffusion mechanism can learn more discriminative features for medical image classification and thereby improve PMG performance. Moreover, the better Accuracy and F1-score of “C2” over “C1” demonstrate that introducing our dual-granularity conditional guidance into the vanilla diffusion process benefits PMG performance. Furthermore, our method outperforms “C2” in both Accuracy and F1-score, which indicates that the MMD regularization loss in the diffusion process further enhances the PMG results.
Visualization of our Diffusion Procedure: To illustrate the reverse diffusion process guided by our dual-granularity conditional encoding, we use t-SNE to visualize the denoised feature embeddings at consecutive time steps. Figure 2 presents the results on all three datasets. As the reverse process advances through the time steps, the denoising diffusion model gradually removes noise from the feature representation, so that the class clusters become increasingly distinct, starting from the initial Gaussian distribution. The total number of time steps required for inference depends on the complexity of the dataset.
4 Conclusion
This work presents a diffusion-based network (DiffMIC) to boost medical image classification. The main idea of our DiffMIC is to introduce dual-granularity conditional guidance over vanilla DDPM, and enforce condition-specific MMD regularization to improve classification performance. Experimental results on three medical image classification datasets with diverse image modalities show the superior performance of our network over state-of-the-art methods. As the first diffusion-based model for general medical image classification, our DiffMIC has the potential to serve as an essential baseline for future research in this area.
Acknowledgments: This research is supported by Guangzhou Municipal Science and Technology Project (Grant No. 2023A03J0671), the National Research Foundation, Singapore under its AI Singapore Programme (AISG Award No: AISG2-TC-2021-003), A*STAR AME Programmatic Funding Scheme Under Project A20H4b0141, and A*STAR Central Research Fund.
References
- [1] Amit, T., Nachmani, E., Shaharbany, T., Wolf, L.: SegDiff: Image segmentation with diffusion probabilistic models. arXiv preprint arXiv:2112.00390 (2021)
- [2] Batzolis, G., Stanczuk, J., Schönlieb, C.B., Etmann, C.: Conditional image generation with score-based diffusion models. arXiv preprint arXiv:2111.13606 (2021)
- [3] Cao, K., Wei, C., Gaidon, A., Arechiga, N., Ma, T.: Learning imbalanced datasets with label-distribution-aware margin loss. Advances in Neural Information Processing Systems 32 (2019)
- [4] Chen, S., Sun, P., Song, Y., Luo, P.: DiffusionDet: Diffusion model for object detection. arXiv preprint arXiv:2211.09788 (2022)
- [5] De Bruijne, M.: Machine learning approaches in medical image analysis: From detection to diagnosis (2016)
- [6] Dhariwal, P., Nichol, A.: Diffusion models beat GANs on image synthesis. Advances in Neural Information Processing Systems 34, 8780–8794 (2021)
- [7] Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner, T., Dehghani, M., Minderer, M., Heigold, G., Gelly, S., et al.: An image is worth 16x16 words: Transformers for image recognition at scale. arXiv preprint arXiv:2010.11929 (2020)
- [8] Esteva, A., Kuprel, B., Novoa, R.A., Ko, J., Swetter, S.M., Blau, H.M., Thrun, S.: Dermatologist-level classification of skin cancer with deep neural networks. Nature 542(7639), 115–118 (Feb 2017)
- [9] Esteva, A., Robicquet, A., Ramsundar, B., Kuleshov, V., DePristo, M., Chou, K., Cui, C., Corrado, G., Thrun, S., Dean, J.: A guide to deep learning in healthcare. Nature Medicine 25(1), 24–29 (Jan 2019)
- [10] Gong, L., Ma, K., Zheng, Y.: Distractor-aware neuron intrinsic learning for generic 2D medical image classifications. In: International Conference on Medical Image Computing and Computer-Assisted Intervention. pp. 591–601. Springer (2020)
- [11] Gretton, A., Borgwardt, K., Rasch, M., Schölkopf, B., Smola, A.: A kernel method for the two-sample-problem. Advances in Neural Information Processing Systems 19 (2006)
- [12] Han, X., Zheng, H., Zhou, M.: CARD: Classification and regression diffusion models. arXiv preprint arXiv:2206.07275 (2022)
- [13] He, K., Zhang, X., Ren, S., Sun, J.: Deep residual learning for image recognition. In: Proceedings of the IEEE conference on computer vision and pattern recognition. pp. 770–778 (2016)
- [14] Ho, J., Jain, A., Abbeel, P.: Denoising diffusion probabilistic models. Advances in Neural Information Processing Systems 33, 6840–6851 (2020)
- [15] Ilse, M., Tomczak, J., Welling, M.: Attention-based deep multiple instance learning. In: International conference on machine learning. pp. 2127–2136. PMLR (2018)
- [16] Karthik, Maggie, S.D.: APTOS 2019 blindness detection (2019), https://kaggle.com/competitions/aptos2019-blindness-detection
- [17] Li, Y., Swersky, K., Zemel, R.: Generative moment matching networks. In: International conference on machine learning. pp. 1718–1727. PMLR (2015)
- [18] Liao, H., Luo, J.: A deep multi-task learning approach to skin lesion classification. arXiv preprint arXiv:1812.03527 (2018)
- [19] Liu, Z., Lin, Y., Cao, Y., Hu, H., Wei, Y., Zhang, Z., Lin, S., Guo, B.: Swin transformer: Hierarchical vision transformer using shifted windows. In: Proceedings of the IEEE/CVF international conference on computer vision. pp. 10012–10022 (2021)
- [20] Marrakchi, Y., Makansi, O., Brox, T.: Fighting class imbalance with contrastive learning. In: Medical Image Computing and Computer Assisted Intervention–MICCAI 2021: 24th International Conference, Strasbourg, France, September 27–October 1, 2021, Proceedings, Part III 24. pp. 466–476. Springer (2021)
- [21] Nichol, A.Q., Dhariwal, P.: Improved denoising diffusion probabilistic models. In: Meila, M., Zhang, T. (eds.) Proceedings of the 38th International Conference on Machine Learning. Proceedings of Machine Learning Research, vol. 139, pp. 8162–8171. PMLR (18–24 Jul 2021)
- [22] Rajpurkar, P., Chen, E., Banerjee, O., Topol, E.J.: AI in health and medicine. Nature Medicine 28(1), 31–38 (Jan 2022)
- [23] Shamshad, F., Khan, S., Zamir, S.W., Khan, M.H., Hayat, M., Khan, F.S., Fu, H.: Transformers in Medical Imaging: A Survey. arXiv (Jan 2022)
- [24] Shen, Y., Wu, N., Phang, J., Park, J., Liu, K., Tyagi, S., Heacock, L., Kim, S.G., Moy, L., Cho, K., et al.: An interpretable classifier for high-resolution breast cancer screening images utilizing weakly supervised localization. Medical image analysis 68, 101908 (2021)
- [25] Shrivastava, A., Gupta, A., Girshick, R.: Training region-based object detectors with online hard example mining. In: Proceedings of the IEEE conference on computer vision and pattern recognition. pp. 761–769 (2016)
- [26] Singh, J., Gould, S., Zheng, L.: High-fidelity guided image synthesis with latent diffusion models. arXiv preprint arXiv:2211.17084 (2022)
- [27] Tschandl, P., Rosendahl, C., Kittler, H.: The HAM10000 dataset, a large collection of multi-source dermatoscopic images of common pigmented skin lesions. Scientific Data 5(1), 1–9 (2018)
- [28] Wang, W., Xie, E., Li, X., Fan, D.P., Song, K., Liang, D., Lu, T., Luo, P., Shao, L.: Pyramid vision transformer: A versatile backbone for dense prediction without convolutions. In: Proceedings of the IEEE/CVF international conference on computer vision. pp. 568–578 (2021)
- [29] Wolleb, J., Sandkühler, R., Bieder, F., Valmaggia, P., Cattin, P.C.: Diffusion models for implicit image segmentation ensembles. In: International Conference on Medical Imaging with Deep Learning. pp. 1336–1348. PMLR (2022)
- [30] Yang, Z., Pan, J., Yang, Y., Shi, X., Zhou, H.Y., Zhang, Z., Bian, C.: ProCo: Prototype-aware contrastive learning for long-tailed medical image classification. In: Medical Image Computing and Computer Assisted Intervention–MICCAI 2022: 25th International Conference, Singapore, September 18–22, 2022, Proceedings, Part VIII. pp. 173–182. Springer (2022)
- [31] Zhao, S., Song, J., Ermon, S.: InfoVAE: Information maximizing variational autoencoders. arXiv preprint arXiv:1706.02262 (2017)