Email: [email protected]
The Hong Kong University of Science and Technology (Guangzhou), China
University of California, Santa Cruz, USA
Distribution Aligned Diffusion and Prototype-guided network for Unsupervised Domain Adaptive Segmentation
Abstract
The Diffusion Probabilistic Model (DPM) has emerged as a highly effective generative model in computer vision. Its intermediate latent vectors offer rich semantic information, making it an attractive option for downstream tasks such as segmentation and detection. To explore its potential further, we take a step forward and consider a more complex scenario in the medical image domain: unsupervised domain adaptation. To this end, we propose a Diffusion-based and Prototype-guided network (DP-Net) for unsupervised domain adaptive segmentation. Concretely, DP-Net consists of two stages: 1) Distribution Aligned Diffusion (DADiff), which trains a domain discriminator to minimize the difference between the intermediate features generated by the DPM, thereby aligning the inter-domain distribution; and 2) Prototype-guided Consistency Learning (PCL), which uses feature centroids as prototypes and applies a prototype-guided loss to ensure that the segmentor learns consistent content from both source and target domains. We evaluate our approach on fundus datasets through a series of experiments, which demonstrate that the proposed method is reliable and outperforms state-of-the-art methods. Our work presents a promising direction for using DPM in complex medical image scenarios, opening up new possibilities for further research in medical imaging. Code is available at https://github.com/haipengzhou856/DPNet
Keywords:
Unsupervised domain adaptation, fundus segmentation, diffusion model

1 Introduction
Automated segmentation of fundus images and privacy protection are important considerations in glaucoma screening and diagnosis. However, there are two significant challenges to address. Firstly, a model well-trained on one domain may suffer a performance decline when transferred to another due to the domain-shift problem [16]. Secondly, obtaining annotations is often difficult, especially in medical imaging, where privacy protection is a concern. To address these challenges, unsupervised domain adaptation (UDA) [28, 13, 15, 23, 22, 8] techniques have been developed. UDA approaches aim to align the distribution between the source and target domains, extracting domain-invariant features to alleviate the domain gap. Notably, UDA requires no access to annotations from the target domain, which reduces the manpower needed for data labeling and makes it more practical for medical applications.
There are two main UDA approaches, namely adversarial learning and intra-domain category rectification. Adversarial learning methods train a discriminator to minimize the domain discrepancy between different domains, typically at the input- [18, 4, 3], feature- [14], or output-level [23], enforcing the upstream extractor to learn domain-invariant features. While this approach has shown promise, conventional CNNs are limited in their ability to extract and align features across domains. Advanced architectures such as Transformers [20, 26] have shown potential for improving domain generalization. Additionally, the Diffusion Probabilistic Model (DPM) has emerged as an effective and robust approach for image generation and for downstream computer vision tasks such as segmentation [2] and detection [6]. However, its usage in domain adaptation remains unexplored. Another mainstream approach is intra-domain category rectification, which aims to rectify the category-level discrepancy within the target domain. Pseudo-labeling methods [29, 25, 4] are commonly used for this purpose, but they heavily rely on the quality of the generated pseudo-labels, which can lead to performance degradation when the labels are uncertain. To address this issue, uncertainty estimation techniques are widely used to denoise unreliable masks. For instance, DPL [4] estimates uncertainty via Monte Carlo Dropout for Bayesian approximation and filters out the noise in the pseudo-label. Domain invariants, such as prototypes, can also be used to regularize the training process and rectify the distribution discrepancy between different domains. ProDA [27] suggests that prototypes can act as anchors for intra-domain category rectification, aiding in extracting consistent content from source and target domains.
In this paper, we investigate the domain generalization capability of DPM and propose a novel Diffusion-based and Prototype-guided network, called DP-Net, for unsupervised domain adaptive segmentation. To the best of our knowledge, this is the first attempt to explore DPM for domain adaptation purposes. Our proposed DP-Net comprises two stages: Distribution Aligned Diffusion (DADiff) and Prototype-guided Consistency Learning (PCL) module. To align the inter-domain distribution, the DADiff aims to train a domain discriminator to minimize the difference between the intermediate features of the source and target images generated by the DPM. The PCL stage, on the other hand, promotes consistent learning across different domains by utilizing class-wise feature centroids as domain-invariant prototypes, encouraging the segmentor to learn consistent content. By leveraging the latest diffusion models, DP-Net can extract domain-generalized features more effectively, resulting in superior performance for unsupervised domain adaptive segmentation on fundus image data.
2 Method
Our proposed DP-Net comprises two stages. In stage 1, we introduce the Distribution Aligned Diffusion (DADiff) model that combines a diffusion model with a domain discriminator to extract generalized features from both source and target images. DADiff mitigates the domain gap and yields latent representations, denoted as $f_s$ and $f_t$ (Fig. 1). In stage 2, we design a Prototype-guided Consistency Learning (PCL) module, where $f_s$ is used to train the segmentor $G$ to predict the source domain's segmentation map. We generate a pseudo-label with $f_t$ and compute a prototype-based consistency loss, incorporating the segmentation outcomes of both domains to guide $G$ to prioritize consistent objectives (Fig. 2).
2.1 Distribution Aligned Diffusion (DADiff) model

The exploration of latent feature representations from generative models has been widely investigated in many dense prediction tasks [21, 24]. However, GANs are prone to collapse and out-of-distribution issues. On the other hand, the diffusion model has attracted much research attention and has demonstrated superior performance in many vision tasks due to its noise robustness.
Instead of taking labels as input as in the label-conditioned DPM of [1], we adopt the vanilla diffusion model [7], which is originally designed to generate images. This allows us to obtain latent feature representations from the intermediate activations of the DPM, which contain rich semantic information [2]. To align the feature distributions between the source and target domains, we propose Distribution Aligned Diffusion (DADiff), as shown in Fig. 1. The generated images are discarded, and only the activation features of the DPM are utilized. DADiff introduces a sequence of Gaussian noises to the input image in the forward pass and reconstructs the image in the reverse pass by predicting the noise. Specifically, given a sample $x_0$, the forward process is:
$$q(x_t \mid x_{t-1}) = \mathcal{N}\big(x_t;\, \sqrt{1-\beta_t}\, x_{t-1},\, \beta_t \mathbf{I}\big) \quad (1)$$
where $\beta_t$ denotes the variance schedule at step $t$, $\mathbf{I}$ is the identity matrix, and $\mathcal{N}$ is the Gaussian distribution. The noisy sample $x_t$ at time step $t$ can be obtained as:
$$x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1-\bar{\alpha}_t}\, \epsilon, \quad \epsilon \sim \mathcal{N}(0, \mathbf{I}) \quad (2)$$
where $\alpha_t = 1 - \beta_t$ and $\bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s$. Meanwhile, the reverse diffusion process is parameterized by $\theta$ and can be described as:
$$p_\theta(x_{t-1} \mid x_t) = \mathcal{N}\big(x_{t-1};\, \mu_\theta(x_t, t),\, \Sigma_\theta(x_t, t)\big) \quad (3)$$
Instead of directly obtaining the distribution of Eq. 3, our DADiff follows [12] to predict the noise via a UNet $\epsilon_\theta$:
$$p_\theta(x_{t-1} \mid x_t) = \mathcal{N}\Big(x_{t-1};\, \tfrac{1}{\sqrt{\alpha_t}}\big(x_t - \tfrac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\, \epsilon_\theta(x_t, t)\big),\, \Sigma_\theta(x_t, t)\Big) \quad (4)$$
where $\Sigma_\theta$ denotes the variance scheme learnt by the network, and the noise predictor $\epsilon_\theta$ is implemented as a UNet. We follow [7] to compute the noise estimation loss for training the UNet $\epsilon_\theta$.
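To make the training objective concrete, below is a minimal PyTorch sketch of the forward noising step of Eq. 2 and the noise-estimation loss for $\epsilon_\theta$; the total timestep count, the linear variance schedule, and the `eps_theta(x_t, t)` call signature are illustrative assumptions rather than the exact released configuration.

```python
# Minimal sketch (assumed hyper-parameters): forward noising per Eq. (2) and the
# noise-estimation loss used to train the diffusion UNet eps_theta.
import torch
import torch.nn.functional as F

T = 1000                                        # assumed total number of timesteps
betas = torch.linspace(1e-4, 0.02, T)           # assumed linear variance schedule beta_t
alphas = 1.0 - betas                            # alpha_t = 1 - beta_t
alpha_bars = torch.cumprod(alphas, dim=0)       # bar{alpha}_t = prod_{s<=t} alpha_s

def noise_estimation_loss(eps_theta, x0):
    """Sample a timestep, corrupt x0 with Gaussian noise, and regress that noise."""
    b = x0.size(0)
    t = torch.randint(0, T, (b,), device=x0.device)
    eps = torch.randn_like(x0)
    a_bar = alpha_bars.to(x0.device)[t].view(b, 1, 1, 1)
    x_t = a_bar.sqrt() * x0 + (1.0 - a_bar).sqrt() * eps   # Eq. (2)
    return F.mse_loss(eps_theta(x_t, t), eps)              # eps_theta predicts the noise
```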
We utilize the trained UNet $\epsilon_\theta$ to extract latent diffusion features $f_s$ from the source image $x_s$ and $f_t$ from the target image $x_t$. Specifically, in the reverse diffusion process, our DADiff follows the Markov steps to collect the latent activations $f_s$ and $f_t$, respectively. To reduce the discrepancy between these two domains, we implement a domain discriminator $D$ with a gradient reversal layer (GRL) [11] to ensure inter-domain indistinguishability, as shown in Fig. 1. Specifically, we devise the following loss to learn the domain discriminator:
$$\mathcal{L}_{adv} = -\sum_{h,w}\Big[(1-z)\log\big(D(f)^{(h,w)}\big) + z\log\big(1-D(f)^{(h,w)}\big)\Big] \quad (5)$$
where $f \in \{f_s, f_t\}$, and $z_s$ and $z_t$ denote the source and target domain labels (taken as 0 and 1), respectively. The coordinate in the diffusion feature map is denoted as $(h, w)$.
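A minimal PyTorch sketch of the gradient reversal layer and a patch-level domain discriminator trained with a loss of the form of Eq. 5; the discriminator architecture and the source-0 / target-1 label convention are assumptions for illustration, not a definitive implementation.

```python
# Sketch of the GRL and a small per-location domain discriminator (architecture
# and label convention are assumptions).
import torch
import torch.nn as nn

class GradReverse(torch.autograd.Function):
    """Identity in the forward pass; negates the gradient in the backward pass."""
    @staticmethod
    def forward(ctx, x):
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return -grad_output

class DomainDiscriminator(nn.Module):
    """Classifies each (h, w) location of a diffusion feature map as source or target."""
    def __init__(self, in_ch):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, 64, 3, padding=1), nn.LeakyReLU(0.2),
            nn.Conv2d(64, 1, 3, padding=1))

    def forward(self, f):
        return self.net(GradReverse.apply(f))    # GRL forces the extractor to confuse D

def adversarial_loss(disc, f_s, f_t):
    """Binary cross-entropy over all spatial locations (cf. Eq. (5))."""
    bce = nn.BCEWithLogitsLoss()
    d_s, d_t = disc(f_s), disc(f_t)
    return bce(d_s, torch.zeros_like(d_s)) + bce(d_t, torch.ones_like(d_t))
```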
2.2 Prototype-guided Consistency Learning (PCL) Module

Given the latent diffusion features generated by DADiff in the first stage, we introduce a Prototype-guided Consistency Learning (PCL) module in the second stage to further narrow the gap between the source and target domains; the details are shown in Fig. 2. In the UDA setting, we have access to the source domain features $f_s$ and their corresponding segmentation annotations $y_s$, as well as the target domain features $f_t$ without any label. To perform domain adaptation, we first pass the source domain features through a shared segmentation decoder $G$ to obtain segmentation results, on which we perform supervised learning to improve the segmentation performance on the source domain. Similarly, we generate pseudo-labels $\hat{y}_t$ for the target domain based on a probability threshold. Feeding $f_s$ and $f_t$ into $G$, we obtain the output features $F_s$ and $F_t$ from the last convolution layer of the segmentation decoder, and we compute the object prototype $P_s$ of the source domain and the object prototype $P_t$ of the target domain as:
$$P_s = \frac{1}{N_s}\sum_{i} F_s^{(i)}\, y_s^{(i)}, \qquad P_t = \frac{1}{N_t}\sum_{i} F_t^{(i)}\, \hat{y}_t^{(i)} \quad (6)$$
where $i$ denotes the pixel index and $N_s$ ($N_t$) is the number of pixels belonging to the object classes. However, the prototype of the target domain may be unreliable due to false predictions produced by $G$. To address this, we follow [4] and estimate the uncertainty to denoise the coarse pseudo-labels. Using Monte Carlo Dropout [10], we apply dropout in the decoder during several stochastic forward passes on each sample, generating a series of coarse pseudo-labels. We compute the uncertainty map by taking the standard deviation of these forward predictions, and then obtain a refined pseudo-label by applying an uncertainty threshold to generate a stable prediction, as sketched below.
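The pseudo-label refinement can be sketched as follows in PyTorch; the number of stochastic passes and the probability/uncertainty thresholds are placeholders, not the values used in our experiments.

```python
# Sketch of Monte Carlo Dropout pseudo-label refinement (thresholds are placeholders).
import torch

@torch.no_grad()
def refined_pseudo_label(decoder, f_t, n_passes=10, prob_thr=0.75, unc_thr=0.05):
    """Return a coarse pseudo-label and a reliability mask for the target features."""
    decoder.train()                              # keep dropout layers active
    probs = torch.stack([torch.sigmoid(decoder(f_t)) for _ in range(n_passes)])
    mean_prob = probs.mean(dim=0)
    uncertainty = probs.std(dim=0)               # per-pixel std over stochastic passes
    pseudo = (mean_prob > prob_thr).float()      # coarse pseudo-label
    reliable = (uncertainty < unc_thr).float()   # keep only low-uncertainty pixels
    return pseudo, reliable
```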
After refining the pseudo-labels in the target domain, we propose a prototype-guided consistency learning loss $\mathcal{L}_{pcl}$ to further narrow the gap between the source and target domains. This loss reduces the distance between the prototypes $P_s$ and $P_t$ from the source and target domains. The definition of $\mathcal{L}_{pcl}$ is given by:
$$\mathcal{L}_{pcl} = \left\| P_s - P_t \right\|_2^2 \quad (7)$$
By adding $\mathcal{L}_{pcl}$ to the supervised segmentation loss $\mathcal{L}_{seg}$ on the source domain, the total loss of stage 2 is computed as:
$$\mathcal{L}_{total} = \mathcal{L}_{seg} + \lambda\, \mathcal{L}_{pcl} \quad (8)$$
where the weight $\lambda$ is set empirically in our experiments.
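A compact sketch of the stage-2 objective, combining masked-average-pooling prototypes (Eq. 6), the prototype consistency loss (Eq. 7), and the total loss (Eq. 8); the binary cross-entropy form of $\mathcal{L}_{seg}$ and the weight value `lam` are assumptions.

```python
# Sketch of the stage-2 objective (L_seg form and lam value are assumptions).
import torch
import torch.nn.functional as F

def object_prototype(feat, mask):
    """feat: (B, C, H, W) decoder features; mask: (B, 1, H, W) binary object mask."""
    num = (feat * mask).sum(dim=(0, 2, 3))
    den = mask.sum(dim=(0, 2, 3)).clamp(min=1.0)
    return num / den                             # (C,) foreground centroid, Eq. (6)

def stage2_loss(pred_s, y_s, F_s, F_t, pseudo_t, reliable_t, lam=0.1):
    l_seg = F.binary_cross_entropy_with_logits(pred_s, y_s.float())  # source supervision
    p_s = object_prototype(F_s, y_s.float())
    p_t = object_prototype(F_t, pseudo_t * reliable_t)               # denoised target mask
    l_pcl = F.mse_loss(p_s, p_t)                 # pull the two prototypes together, Eq. (7)
    return l_seg + lam * l_pcl                   # total loss, Eq. (8)
```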
3 Experiments
3.0.1 Dataset and evaluation metrics.
We adopt three publicly available fundus datasets for our experiments, and each one is regarded as a distinct domain exhibiting a domain gap. In detail, we use the training set of REFUGE [17] as the source domain and test on the target domains RIM-ONE-r3 [9] and Drishti-GS [19]. The source domain includes 400 annotated training images. Following the data usage protocol in [22], we split 99/60 and 50/51 images for training/testing in RIM-ONE-r3 and Drishti-GS, respectively. We crop a 512×512 disc region of interest (RoI) as the network input. The source domain is augmented with random rotation, random flipping, elastic transformation, contrast adjustment, Gaussian noise, and random erasing, while no augmentation is applied to the target domain for the UDA setting. We adopt the Dice coefficient to compare different segmentation methods quantitatively.
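For reference, a simple NumPy sketch of the Dice coefficient used to evaluate a binary disc or cup mask.

```python
# Dice coefficient between two binary masks (one structure at a time).
import numpy as np

def dice_coefficient(pred, gt, eps=1e-7):
    pred, gt = pred.astype(bool), gt.astype(bool)
    inter = np.logical_and(pred, gt).sum()
    return (2.0 * inter + eps) / (pred.sum() + gt.sum() + eps)
```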
3.0.2 Implementation details.
For Stage 1 of our method, the diffusion UNet $\epsilon_\theta$ is trained with the Adam optimizer, and we utilize the representations from the middle blocks of the UNet decoder, collected during the reverse diffusion process, as the latent diffusion features. As for Stage 2 of our method, in order to make a fair comparison with existing methods, we adopt the decoder of DeepLabv3+ [5] as the segmentor $G$. Dropout is enabled during the stochastic forward passes for uncertainty estimation, and the pseudo-label probability threshold and uncertainty threshold are set empirically. Source and target domains share the same batch size. We implement our network using the PyTorch framework with Titan Xp GPUs.
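A rough sketch of how middle-block decoder activations can be collected with forward hooks during a denoising step; the `unet.decoder` attribute, the block indices, and the bilinear upsample-and-concatenate fusion are assumptions for illustration, not our exact implementation.

```python
# Sketch (assumed UNet layout): grab middle-block decoder activations via hooks.
import torch
import torch.nn.functional as F

@torch.no_grad()
def collect_middle_activations(unet, x_t, t, block_ids=(5, 6, 7)):
    feats, handles = [], []
    blocks = list(unet.decoder.children())        # assumed decoder module list
    for i in block_ids:
        handles.append(blocks[i].register_forward_hook(
            lambda module, inputs, output: feats.append(output)))
    unet(x_t, t)                                  # one denoising step of the reverse process
    for h in handles:
        h.remove()
    size = feats[0].shape[-2:]
    up = [F.interpolate(f, size=size, mode="bilinear", align_corners=False) for f in feats]
    return torch.cat(up, dim=1)                   # fused latent diffusion feature
```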
3.1 Comparison with state-of-the-art methods
Table 1. Dice scores of optic disc and cup segmentation on RIM-ONE-r3 [9] and Drishti-GS [19].

| Methods | RIM-ONE-r3 [9] Dice disc | RIM-ONE-r3 [9] Dice cup | Drishti-GS [19] Dice disc | Drishti-GS [19] Dice cup |
|---|---|---|---|---|
| Baseline (w/o DA) | 0.946 | 0.879 | 0.974 | 0.912 |
| TD-GAN [28] | 0.853 | 0.728 | 0.924 | 0.747 |
| Hoffman et al. [13] | 0.852 | 0.755 | 0.959 | 0.851 |
| Javanmardi et al. [15] | 0.853 | 0.779 | 0.961 | 0.849 |
| OSAL-pixel [23] | 0.854 | 0.778 | 0.962 | 0.851 |
| pOSAL [23] | 0.865 | 0.787 | 0.965 | 0.858 |
| BEAL [22] | 0.898 | 0.810 | 0.961 | 0.862 |
| CLR [8] | 0.905 | 0.841 | 0.966 | 0.892 |
| Ours | 0.913 | 0.852 | 0.966 | 0.884 |
Quantitative Comparison. We compare our proposed method against seven state-of-the-art UDA algorithms and report their Dice scores for optic disc and optic cup segmentation on the RIM-ONE-r3 and Drishti-GS datasets in Table 1. Here, we first report a supervised baseline result (w/o DA), which takes the vanilla DPM to extract features; this result serves as the upper bound for all UDA methods. For the RIM-ONE-r3 dataset, CLR [8] achieves the best Dice scores for optic disc and optic cup segmentation among the seven compared methods, namely 0.905 and 0.841. Our method further outperforms CLR on this dataset, improving the Dice score from 0.905 to 0.913 for optic disc segmentation and from 0.841 to 0.852 for optic cup segmentation. Regarding the Drishti-GS dataset, both our method and CLR [8] achieve the best Dice score of 0.966 for optic disc segmentation. Although the Dice score of our method for optic cup segmentation (0.884) ranks second, it is only slightly lower than the best one (0.892). We argue that this gap (0.892 vs. 0.884) arises because BEAL [22] and CLR [8] adopt a backbone pre-trained on ImageNet, while our diffusion model is trained only on the fundus image data, which is much smaller than ImageNet.

Visual Comparison. Fig. 3 visually compares the segmentation results produced by our network and state-of-the-art methods on input images from Drishti-GS and RIM-ONE-r3. Our method segments the optic disc (green) and optic cup (blue) regions more accurately, and our results are the most consistent with the ground truth; see the first column of Fig. 3.
3.2 Ablation study
Table 2. Ablation study of the major components of our framework on RIM-ONE-r3 [9] and Drishti-GS [19].

| Name | Diff | DA | +DPL | +PCL | RIM-ONE-r3 [9] Dice disc | RIM-ONE-r3 [9] Dice cup | Drishti-GS [19] Dice disc | Drishti-GS [19] Dice cup |
|---|---|---|---|---|---|---|---|---|
| M1 (basic) | ✓ | | | | 0.846 | 0.793 | 0.934 | 0.809 |
| M2 | ✓ | ✓ | | | 0.884 | 0.821 | 0.947 | 0.869 |
| M3 | ✓ | ✓ | ✓ | | 0.916 | 0.829 | 0.963 | 0.876 |
| M4 (Ours) | ✓ | ✓ | | ✓ | 0.913 | 0.852 | 0.966 | 0.884 |
We conduct ablation studies to investigate the major components of our framework. We first construct a baseline (denoted as "M1") that only extracts diffusion features to predict the segmentation results; it is equivalent to removing the domain discriminator of Eq. 5 and the prototype-guided consistency loss of Eq. 7 in the PCL module. Then, we add the domain discriminator to "M1" to construct another baseline network ("M2"). After that, we add a DPL [4] module to "M2" to build a baseline network ("M3"), and add the prototype-guided consistency loss to "M2" to build "M4". "M4" is equal to the full setting of our network.
Table 2 reports the Dice scores of optic disc and cup segmentation of our method and the baseline networks on the Drishti-GS and RIM-ONE-r3 datasets. "M2" has larger Dice scores than "M1" on optic disc and cup segmentation for both datasets, indicating that the domain discriminator of Eq. 5 on diffusion features can reduce the domain gap between the source and target domains, thereby improving the UDA performance. Then, "M3" and our method (i.e., "M4") outperform "M2" for optic disc and cup segmentation on the two datasets, demonstrating that exploiting the prototype information to compute a consistency loss can further enhance the domain adaptation performance. More importantly, our method achieves a superior Dice score over "M3", indicating that our PCL can better reduce the domain gap than DPL [4]. Specifically, compared to "M3", our method improves the Dice score of optic disc segmentation from 0.963 to 0.966 and the Dice score of optic cup segmentation from 0.876 to 0.884 on the Drishti-GS dataset. For the RIM-ONE-r3 dataset, our method improves the Dice score of optic cup segmentation from 0.829 to 0.852, while the Dice scores of optic disc segmentation for "M3" and our method are very close (0.916 and 0.913, respectively).
4 Conclusion
In this paper, we propose the Diffusion-based and Prototype-guided network (DP-Net) for Unsupervised Domain Adaptive Segmentation. Specifically, we investigate the effectiveness of the diffusion model for domain adaptation, and propose the Distribution Aligned Diffusion (DADiff) approach to extract generalized features. Additionally, we develop the Prototype-guided Consistency Learning (PCL) module that utilizes the foreground prototype to guide the segmentor in learning consistent objectives. Our experimental results on benchmark datasets demonstrate the superior performance of DP-Net over state-of-the-art methods, indicating its potential for other unsupervised domain adaptation tasks.
References
- [1] Amit, T., Nachmani, E., Shaharbany, T., Wolf, L.: Segdiff: Image segmentation with diffusion probabilistic models. arXiv preprint arXiv:2112.00390 (2021)
- [2] Baranchuk, D., Rubachev, I., Voynov, A., Khrulkov, V., Babenko, A.: Label-efficient semantic segmentation with diffusion models. arXiv preprint arXiv:2112.03126 (2021)
- [3] Bateson, M., Kervadec, H., Dolz, J., Lombaert, H., Ben Ayed, I.: Source-relaxed domain adaptation for image segmentation. In: Medical Image Computing and Computer Assisted Intervention–MICCAI 2020: 23rd International Conference, Lima, Peru, October 4–8, 2020, Proceedings, Part I 23. pp. 490–499. Springer (2020)
- [4] Chen, C., Liu, Q., Jin, Y., Dou, Q., Heng, P.A.: Source-free domain adaptive fundus image segmentation with denoised pseudo-labeling. In: Medical Image Computing and Computer Assisted Intervention–MICCAI 2021: 24th International Conference, Strasbourg, France, September 27–October 1, 2021, Proceedings, Part V 24. pp. 225–235. Springer (2021)
- [5] Chen, L.C., Zhu, Y., Papandreou, G., Schroff, F., Adam, H.: Encoder-decoder with atrous separable convolution for semantic image segmentation. In: Proceedings of the European conference on computer vision (ECCV). pp. 801–818 (2018)
- [6] Chen, S., Sun, P., Song, Y., Luo, P.: Diffusiondet: Diffusion model for object detection. arXiv preprint arXiv:2211.09788 (2022)
- [7] Dhariwal, P., Nichol, A.: Diffusion models beat gans on image synthesis. Advances in Neural Information Processing Systems 34, 8780–8794 (2021)
- [8] Feng, W., Wang, L., Ju, L., Zhao, X., Wang, X., Shi, X., Ge, Z.: Unsupervised domain adaptive fundus image segmentation with category-level regularization. In: Medical Image Computing and Computer Assisted Intervention–MICCAI 2022: 25th International Conference, Singapore, September 18–22, 2022, Proceedings, Part II. pp. 497–506. Springer (2022)
- [9] Fumero, F., Alayón, S., Sanchez, J.L., Sigut, J., Gonzalez-Hernandez, M.: Rim-one: An open retinal image database for optic nerve evaluation. In: international symposium on computer-based medical systems. pp. 1–6. IEEE (2011)
- [10] Gal, Y., Ghahramani, Z.: Dropout as a bayesian approximation: Representing model uncertainty in deep learning. In: international conference on machine learning. pp. 1050–1059. PMLR (2016)
- [11] Ganin, Y., Lempitsky, V.: Unsupervised domain adaptation by backpropagation. In: International conference on machine learning. pp. 1180–1189. PMLR (2015)
- [12] Ho, J., Jain, A., Abbeel, P.: Denoising diffusion probabilistic models. Advances in Neural Information Processing Systems 33, 6840–6851 (2020)
- [13] Hoffman, J., Wang, D., Yu, F., Darrell, T.: Fcns in the wild: Pixel-level adversarial and constraint-based adaptation. arXiv preprint arXiv:1612.02649 (2016)
- [14] Hsu, C.C., Tsai, Y.H., Lin, Y.Y., Yang, M.H.: Every pixel matters: Center-aware feature alignment for domain adaptive object detector. In: Computer Vision–ECCV 2020: 16th European Conference, Glasgow, UK, August 23–28, 2020, Proceedings, Part IX 16. pp. 733–748. Springer (2020)
- [15] Javanmardi, M., Tasdizen, T.: Domain adaptation for biomedical image segmentation using adversarial training. In: 2018 IEEE 15th International Symposium on Biomedical Imaging (ISBI 2018). pp. 554–558. IEEE (2018)
- [16] Moreno-Torres, J.G., Raeder, T., Alaiz-Rodríguez, R., Chawla, N.V., Herrera, F.: A unifying view on dataset shift in classification. Pattern recognition 45(1), 521–530 (2012)
- [17] Orlando, J.I., Fu, H., Breda, J.B., van Keer, K., Bathula, D.R., et al.: Refuge challenge: A unified framework for evaluating automated methods for glaucoma assessment from fundus photographs. Medical image analysis 59, 101570 (2020)
- [18] Shaban, M.T., Baur, C., Navab, N., Albarqouni, S.: Staingan: Stain style transfer for digital histological images. In: 2019 Ieee 16th international symposium on biomedical imaging (Isbi 2019). pp. 953–956. IEEE (2019)
- [19] Sivaswamy, J., Krishnadas, S., Chakravarty, A., Joshi, G., Tabish, A.S., et al.: A comprehensive retinal image dataset for the assessment of glaucoma from the optic nerve head analysis. JSM Biomedical Imaging Data Papers 2(1), 1004 (2015)
- [20] Sun, T., Lu, C., Zhang, T., Ling, H.: Safe self-refinement for transformer-based domain adaptation. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. pp. 7191–7200 (2022)
- [21] Tritrong, N., Rewatbowornwong, P., Suwajanakorn, S.: Repurposing gans for one-shot semantic part segmentation. In: Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. pp. 4475–4485 (2021)
- [22] Wang, S., Yu, L., Li, K., Yang, X., Fu, C.W., Heng, P.A.: Boundary and entropy-driven adversarial learning for fundus image segmentation. In: Medical Image Computing and Computer Assisted Intervention–MICCAI 2019: 22nd International Conference, Shenzhen, China, October 13–17, 2019, Proceedings, Part I 22. pp. 102–110. Springer (2019)
- [23] Wang, S., Yu, L., Yang, X., Fu, C.W., Heng, P.A.: Patch-based output space adversarial learning for joint optic disc and cup segmentation. IEEE transactions on medical imaging 38(11), 2485–2495 (2019)
- [24] Xu, Y., Shen, Y., Zhu, J., Yang, C., Zhou, B.: Generative hierarchical features from synthesizing images. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. pp. 4432–4442 (2021)
- [25] Yan, P., Wu, Z., Liu, M., Zeng, K., Lin, L., Li, G.: Unsupervised domain adaptive salient object detection through uncertainty-aware pseudo-label learning. In: Proceedings of the AAAI Conference on Artificial Intelligence. vol. 36, pp. 3000–3008 (2022)
- [26] Yang, J., Liu, J., Xu, N., Huang, J.: Tvt: Transferable vision transformer for unsupervised domain adaptation. In: Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision. pp. 520–530 (2023)
- [27] Zhang, P., Zhang, B., Zhang, T., Chen, D., Wang, Y., Wen, F.: Prototypical pseudo label denoising and target structure learning for domain adaptive semantic segmentation. In: Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. pp. 12414–12424 (2021)
- [28] Zhang, Y., Miao, S., Mansi, T., Liao, R.: Task driven generative modeling for unsupervised domain adaptation: Application to x-ray image segmentation. In: Medical Image Computing and Computer Assisted Intervention–MICCAI 2018: 21st International Conference, Granada, Spain, September 16-20, 2018, Proceedings, Part II. pp. 599–607. Springer (2018)
- [29] Zheng, Z., Yang, Y.: Rectifying pseudo label learning via uncertainty estimation for domain adaptive semantic segmentation. International Journal of Computer Vision 129(4), 1106–1120 (2021)