C3PS: Context-aware Conditional Cross Pseudo Supervision for Semi-supervised Medical Image Segmentation
Abstract
Semi-supervised learning (SSL) methods, which can leverage a large amount of unlabeled data for improved performance, has attracted increasing attention recently. In this paper, we introduce a novel Context-aware Conditional Cross Pseudo Supervision method (referred as C3PS) for semi-supervised medical image segmentation. Unlike previously published Cross Pseudo Supervision (CPS) works, this paper introduces a novel Conditional Cross Pseudo Supervision (CCPS) mechanism where the cross pseudo supervision is conditioned on a given class label. Context-awareness is further introduced in the CCPS to improve the quality of pseudo-labels for cross pseudo supervision. The proposed method has the additional advantage that in the later training stage, it can focus on the learning of hard organs. Validated on two typical yet challenging medical image segmentation tasks, our method demonstrates superior performance over the state-of-the-art methods.
Keywords:
Semi-supervised learning · Medical image segmentation · Cross Pseudo Supervision · Context-aware
1 Introduction
Medical Image Segmentation (MIS) is an essential step in many clinical applications. Recent years have witnessed remarkable progress in applying supervised deep learning (DL)-based methods to medical image segmentation. However, supervised DL-based methods typically require a large amount of expert-level, densely-annotated data for training, which is laborious and costly to collect. To this end, various annotation-efficient methods have been introduced [1, 2, 3]. Among them, Semi-Supervised MIS (SSMIS) methods have attracted increasing attention as they require only a limited number of labeled data while leveraging a large amount of unlabeled data for improved performance. As an SSL-based image segmentation method, Cross Pseudo Supervision (CPS) [4] has set the new state of the art (SOTA) in semi-supervised semantic segmentation of natural images [5]. CPS works by imposing consistency between two segmentation networks, perturbed with different initializations, for the same input image. The pseudo one-hot label map generated by one network is used to supervise the training of the other network, and vice versa. The idea behind the CPS consistency is that it encourages high similarity between the predictions of the two perturbed networks for the same input image, while expanding the training data by using the unlabeled data with pseudo labels.
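To make the mechanism concrete, below is a minimal sketch of the CPS consistency term in PyTorch; this is not the authors' released code, and the function name and tensor shapes are our own assumptions.

```python
import torch
import torch.nn.functional as F

def cps_loss(logits_a: torch.Tensor, logits_b: torch.Tensor) -> torch.Tensor:
    """Cross pseudo supervision between two differently initialized networks.

    logits_a, logits_b: (B, C, ...) predictions of the two networks for the
    same input image.
    """
    # Hard one-hot pseudo labels; detach() stops gradients on the label branch.
    pseudo_a = logits_a.argmax(dim=1).detach()
    pseudo_b = logits_b.argmax(dim=1).detach()
    # Each network is supervised by the other's pseudo labels, and vice versa.
    return F.cross_entropy(logits_a, pseudo_b) + F.cross_entropy(logits_b, pseudo_a)
```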
Inspired by the original idea introduced by Chen et al. [4], various CPS-based approaches have been proposed in the literature for SSMIS [6, 7, 8]. For example, Luo et al. [6] introduced cross teaching between a Convolutional Neural Network (CNN) and a Transformer for SSMIS. Lin et al. [7] proposed a CPS-based framework that used a class-aware weighted loss, probability-aware random cropping, and dual uncertainty-aware sampling supervision for semi-supervised segmentation of knee joint MR images. Liu et al. [8] proposed a CPS-based framework that generated anatomically plausible predictions using shape awareness and local context constraints. Despite this progress, however, there is still room for further improvement. In particular, how to design a better cross pseudo supervision strategy remains an open problem.
In this paper, we propose a Context-aware Conditional Cross Pseudo Supervision method, referred to as C3PS, for SSMIS. Unlike previous works [4, 6, 7, 8], we propose a novel Conditional Cross Pseudo Supervision (CCPS) mechanism where the cross pseudo supervision is conditioned on a given class label. Inspired by [9], context-awareness is further introduced in the CCPS to improve the quality of pseudo-labels for cross pseudo supervision. The proposed method has the additional advantage that, in the later training stage, it can focus on the learning of hard organs. Our contributions can be summarized as follows:
• We propose a context-aware conditional cross pseudo supervision method, referred to as C3PS, for semi-supervised medical image segmentation. C3PS is based on cross teaching between a regular CNN (RNet) and a conditional CNN (CNet), where the RNet is a multi-class segmentation network while the CNet is a binary segmentation network conditioned on a given class label.
• Based on this network design, we further introduce a novel CCPS mechanism where the cross pseudo supervision is conditioned on a given class label. The CCPS mechanism has the additional advantage that, in the later training stage, it can focus on the learning of hard organs. Context-awareness is further introduced in the CCPS to improve the quality of pseudo-label generation.
• We validate C3PS on two typical yet challenging SSMIS tasks.
2 Method
As illustrated in Fig. 1-(a), C3PS consists of two networks: the RNet and the CNet. The CNet is a binary segmentation network that generates a segmentation for a given conditional class label. We leverage unlabeled data by implementing CCPS between these two networks. Context-awareness is further introduced in the CCPS to improve the quality of pseudo-label generation. Formally, we denote the RNet as $f_r(\cdot; \theta_r)$ with weights $\theta_r$ and the CNet as $f_c(\cdot, \cdot; \theta_c)$ with weights $\theta_c$. Let $\mathcal{D}^l = \{(x_i^l, y_i^l)\}_{i=1}^{N_l}$ denote the labeled dataset and $\mathcal{D}^u = \{x_i^u\}_{i=1}^{N_u}$ the unlabeled dataset. We train our model with both $\mathcal{D}^l$ and $\mathcal{D}^u$.
Figure 1: (a) Overview of the proposed C3PS framework. (b) Architecture of the conditional 3D U-Net used as the CNet.
2.1 Supervised learning
We sample a labeled patch $x^l$ with its label map $y^l$ from $\mathcal{D}^l$. Given a conditional label $c$ (details on generating $c$ will be described below), we can get predictions from $f_r$ and $f_c$ as follows:

$$p^{l,r} = f_r(x^l; \theta_r), \qquad p^{l,c} = f_c(x^l, c; \theta_c) \quad (1)$$
The supervised loss $\mathcal{L}_{sup}^{r}$ for the RNet and $\mathcal{L}_{sup}^{c}$ for the CNet can be defined respectively as follows:

$$\mathcal{L}_{sup}^{r} = \tfrac{1}{2}\left(\mathcal{L}_{ce}(p^{l,r}, y^l) + \mathcal{L}_{dice}(p^{l,r}, y^l)\right) \quad (2)$$

$$\mathcal{L}_{sup}^{c} = \tfrac{1}{2}\left(\mathcal{L}_{ce}(p^{l,c}, \mathbb{1}(y^l = c)) + \mathcal{L}_{dice}(p^{l,c}, \mathbb{1}(y^l = c))\right) \quad (3)$$

where $\mathcal{L}_{ce}$ and $\mathcal{L}_{dice}$ denote respectively the cross-entropy loss and the Dice loss, and $\mathbb{1}(y^l = c)$ generates a binary mask depending on whether the label of a voxel is equal to $c$ or not. Finally, the overall supervised loss can be defined as follows:

$$\mathcal{L}_{sup} = \mathcal{L}_{sup}^{r} + \mathcal{L}_{sup}^{c} \quad (4)$$
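As a concrete illustration of Eqs. (2)-(4), the sketch below implements the two supervised losses under the reconstructed notation; the equal weighting of cross entropy and Dice, the soft Dice formulation, and all names are our assumptions, not the authors' implementation.

```python
import torch
import torch.nn.functional as F

def soft_dice_loss(probs: torch.Tensor, target_onehot: torch.Tensor,
                   eps: float = 1e-5) -> torch.Tensor:
    """Soft Dice loss over (B, C, D, H, W) probabilities and one-hot targets."""
    dims = tuple(range(2, probs.dim()))
    inter = (probs * target_onehot).sum(dims)
    union = probs.sum(dims) + target_onehot.sum(dims)
    return 1.0 - ((2.0 * inter + eps) / (union + eps)).mean()

def supervised_loss(rnet_logits, cnet_logits, y, c, num_classes):
    """Eqs. (2)-(4): CE + Dice for the multi-class RNet and the binary CNet."""
    # Eq. (2): RNet is supervised by the full multi-class label map y.
    y_onehot = F.one_hot(y, num_classes).movedim(-1, 1).float()
    l_r = 0.5 * (F.cross_entropy(rnet_logits, y)
                 + soft_dice_loss(rnet_logits.softmax(1), y_onehot))
    # Eq. (3): CNet is supervised by the binary mask 1(y == c).
    y_bin = (y == c).long()
    y_bin_onehot = F.one_hot(y_bin, 2).movedim(-1, 1).float()
    l_c = 0.5 * (F.cross_entropy(cnet_logits, y_bin)
                 + soft_dice_loss(cnet_logits.softmax(1), y_bin_onehot))
    return l_r + l_c  # Eq. (4)
```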
2.2 Conditional Cross Pseudo Supervision
2.2.1 CCPS loss between two networks.
We sample two unlabeled patches $x_1^u$ and $x_2^u$ from $\mathcal{D}^u$, where $x_1^u \neq x_2^u$. Given a conditional label $c$ (details on generating $c$ will be described below), we can get predictions for $x_i^u$ ($i \in \{1, 2\}$) from $f_r$ and $f_c$ as follows:

$$p_i^{u,r} = f_r(x_i^u; \theta_r), \qquad p_i^{u,c} = f_c(x_i^u, c; \theta_c) \quad (5)$$

We further obtain the pseudo labels generated by these two networks as follows:

$$\hat{y}_i^{u,r} = \operatorname{arg\,max}_k\, p_{i,k}^{u,r}, \qquad \hat{y}_i^{u,c} = \mathbb{1}\!\left(p_i^{u,c} > 0.5\right) \quad (6)$$

where $\hat{y}_i^{u,c}$ is a binary segmentation mask.
We then conduct cross pseudo supervision between the RNet and the CNet; for notational simplicity, we drop the patch index $i$ below. Concretely, we first use $\hat{y}^{u,c}$ as pseudo labels to supervise the learning of the RNet:

$$\mathcal{L}_{c \to r} = -\frac{1}{|\Omega|} \sum_{j \in \Omega} \mathbb{1}\!\left(p_{j}^{u,c} > \tau_c\right) \log p_{j,c}^{u,r} \quad (7)$$

where $p_j^{u,c}$ denotes the predicted probability from the CNet for the conditional class $c$ at the $j$-th voxel of image $x^u$; the indicator $\mathbb{1}(p_j^{u,c} > \tau_c)$ is used to ensure the quality of the pseudo label at the $j$-th voxel of image $x^u$; $\tau_c$ is a confidence threshold; $p_{j,c}^{u,r}$ denotes the predicted probability from the RNet for the conditional class label $c$ at the $j$-th voxel of image $x^u$; and $\Omega$ denotes the set of voxels in the patch. Please note that when we use the output from the CNet to supervise the learning of the RNet, we only compute the loss for the foreground class, since the predicted background class of the CNet may contain other organs.
Then we use $\hat{y}^{u,r}$ as pseudo labels to supervise the learning of the CNet:

$$\mathcal{L}_{r \to c} = -\frac{1}{|\Omega|} \sum_{j \in \Omega} \mathbb{1}\!\left(\max_k p_{j,k}^{u,r} > \tau_r\right) \left[\mathbb{1}(\hat{y}_j^{u,r} = c) \log p_{j,1}^{u,c} + \mathbb{1}(\hat{y}_j^{u,r} \neq c) \log p_{j,0}^{u,c}\right] \quad (8)$$

where $p_{j,k}^{u,r}$ denotes the predicted probability from the RNet for class $k$ at the $j$-th voxel of image $x^u$; $\hat{y}_j^{u,r}$ and $\mathbb{1}(\max_k p_{j,k}^{u,r} > \tau_r)$ are used to ensure the quality of the pseudo label at the $j$-th voxel of image $x^u$; $\tau_r$ is a confidence threshold; and $p_{j,1}^{u,c}$ and $p_{j,0}^{u,c}$ denote the predicted probabilities from the CNet for the conditional class $c$ and for all other classes at the $j$-th voxel of image $x^u$, respectively.
Our CCPS loss for the unlabeled data is then defined as:

$$\mathcal{L}_{ccps} = \mathcal{L}_{c \to r} + \mathcal{L}_{r \to c} \quad (9)$$
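The following sketch shows how the two conditional terms of Eqs. (7)-(9) can be computed for a single unlabeled patch; the normalization over confident voxels and the tensor layout are assumptions consistent with the description above, not the authors' code.

```python
import torch
import torch.nn.functional as F

def ccps_loss(rnet_logits, cnet_logits, c, tau_r=0.95, tau_c=0.9):
    """Eqs. (7)-(9): conditional cross pseudo supervision (a sketch).

    rnet_logits: (B, C, D, H, W) multi-class logits from the RNet.
    cnet_logits: (B, 2, D, H, W) binary logits from the CNet for class c.
    """
    p_r = rnet_logits.softmax(1)
    p_c = cnet_logits.softmax(1)[:, 1]            # CNet foreground probability

    # Eq. (7): CNet -> RNet on confident foreground voxels only; the CNet's
    # background class may contain other organs and is therefore ignored.
    fg = (p_c > tau_c).float().detach()
    l_c2r = -(fg * torch.log(p_r[:, c].clamp_min(1e-7))).sum() / fg.sum().clamp_min(1.0)

    # Eq. (8): RNet -> CNet using the RNet's hard pseudo label binarized
    # against c, gated by the RNet's maximum class probability.
    conf, y_hat = p_r.max(dim=1)
    keep = (conf > tau_r).float().detach()
    ce = F.cross_entropy(cnet_logits, (y_hat == c).long().detach(), reduction="none")
    l_r2c = (keep * ce).sum() / keep.sum().clamp_min(1.0)

    return l_c2r + l_r2c                           # Eq. (9)
```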
2.2.2 Strategy for generating conditional labels.
Since we sample patches for model training, each sampled patch may contain only a subset of the overall label set $\mathcal{C}$. To make our training process more efficient, we design a strategy to generate conditional labels for a labeled patch $x^l$ and an unlabeled patch $x^u$. For the labeled patch $x^l$, we first get the unique label set $\mathcal{C}^l$ from its ground-truth label $y^l$: $\mathcal{C}^l = \text{unique}(y^l)$, where $\text{unique}(\cdot)$ returns the unique values in a label map. Then we randomly select $c$ from $\mathcal{C}^l$. For the unlabeled patch $x^u$, we first get the unique label set $\mathcal{C}^u$ from the RNet's pseudo label: $\mathcal{C}^u = \text{unique}(\hat{y}^{u,r})$. Then we randomly select $c$ from $\mathcal{C}^u$, as implemented in the sketch below.
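A few lines suffice to implement this sampling strategy; torch.unique plays the role of unique(·), while the exclusion of the background class and all names are our own assumptions.

```python
import torch
from typing import Optional

def sample_conditional_label(label_map: torch.Tensor,
                             candidates: Optional[torch.Tensor] = None) -> int:
    """Randomly pick a conditional label c present in a (pseudo) label map.

    label_map: ground-truth labels y^l for a labeled patch, or the RNet's
    hard pseudo labels for an unlabeled patch.
    candidates: optional restriction of the label set (used for HOL below).
    """
    present = torch.unique(label_map)
    present = present[present > 0]         # exclude the background class (0)
    if candidates is not None:             # intersect with the allowed subset
        present = present[torch.isin(present, candidates)]
    assert len(present) > 0, "patch is assumed to contain a foreground class"
    return int(present[torch.randint(len(present), (1,))])
```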
2.2.3 Hard Organ Learning (HOL) with CCPS.
Since the CNet can learn with a given conditional label, in the later training stage, when both the RNet and the CNet can generate stable segmentation results, we can feed a label subset to the CNet and let the CNet focus on the classes that are hard to segment (e.g., in abdominal organ segmentation, the pancreas and kidneys are deemed more difficult due to their relatively small sizes). Concretely, we first define the label subset $\mathcal{C}_{hard}$ for hard organs. Then, for labeled data, we randomly select a conditional label $c$ from the label set $\mathcal{C}_s \cap \mathcal{C}^l$. Similarly, for unlabeled data, we randomly select a conditional label $c$ from the label set $\mathcal{C}_s \cap \mathcal{C}^u$. We set $\mathcal{C}_s$ to the full label set $\mathcal{C}$ in the first $t_h$ iterations; after that, we set $\mathcal{C}_s$ to the hard organ label subset $\mathcal{C}_{hard}$ (see the sketch after this paragraph).
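Continuing the sampling sketch above, HOL only changes which candidate set is handed to the sampler; the iteration threshold t_h and the organ IDs below are placeholders, since the actual integer labels depend on the dataset's convention.

```python
import torch

# Hypothetical hard-organ label IDs (e.g., left kidney, right kidney and
# pancreas on BCV); the real IDs depend on the dataset's label convention.
HARD_ORGANS = torch.tensor([2, 3, 4])

def candidate_labels(iteration: int, t_h: int,
                     full_label_set: torch.Tensor) -> torch.Tensor:
    """Full label set during the first t_h iterations, hard organs afterwards."""
    return full_label_set if iteration < t_h else HARD_ORGANS
```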
2.3 Context-awareness for pseudo-label generation
When a patch-based strategy is used to train a network, the predictions for a voxel may differ when its context differs (e.g., when it appears in different patches). Liu et al. [10] sampled two overlapping patches in each training iteration to make the model more robust to context changes. Following this work, we incorporate context-awareness to improve the quality of pseudo-label generation. Specifically, in each iteration we require that the two sampled patches $x_1^u$ and $x_2^u$ have an overlapping region $\Omega_o$ (shown in Fig. 1-(a)). We further require that the CCPS loss defined above is computed only in the region $\Omega_o$. Then we feed $x_1^u$ and $x_2^u$ to $f_r$ and $f_c$ and get predictions using Eq. (5).
The predictions of the overlapped region in $x_1^u$ and $x_2^u$ from the RNet are $p_{1,\Omega_o}^{u,r}$ and $p_{2,\Omega_o}^{u,r}$, respectively. Similarly, the predictions of the overlapped region in $x_1^u$ and $x_2^u$ from the CNet are $p_{1,\Omega_o}^{u,c}$ and $p_{2,\Omega_o}^{u,c}$, respectively. We generate pseudo labels for the voxels in $\Omega_o$ by combining the two predictions obtained with different contexts as follows:

$$\hat{y}_{\Omega_o}^{u,r} = \operatorname{arg\,max}_k\, \tfrac{1}{2}\left(p_{1,\Omega_o,k}^{u,r} + p_{2,\Omega_o,k}^{u,r}\right) \quad (10)$$

$$\hat{y}_{\Omega_o}^{u,c} = \mathbb{1}\!\left(\tfrac{1}{2}\left(p_{1,\Omega_o}^{u,c} + p_{2,\Omega_o}^{u,c}\right) > 0.5\right) \quad (11)$$
Finally, the loss for the unlabeled data after incorporating context-awareness is computed as follows:

$$\mathcal{L}_{unsup} = \frac{1}{2} \sum_{i=1}^{2} \mathcal{L}_{ccps}\!\left(x_i^u; \hat{y}_{\Omega_o}^{u,r}, \hat{y}_{\Omega_o}^{u,c}\right) \quad (12)$$

where $\mathcal{L}_{ccps}(x_i^u; \cdot, \cdot)$ is the CCPS loss of Eq. (9) computed only over the voxels in $\Omega_o$, with the combined pseudo labels used in place of the single-patch ones.
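A sketch of Eqs. (10)-(11): the two context-dependent predictions of the shared region are averaged before the hard pseudo labels are taken. How the overlap region is cropped out of each patch is bookkeeping we leave out, and the shapes below are assumptions.

```python
import torch

def context_aware_pseudo_labels(p_r1, p_r2, p_c1, p_c2):
    """Combine two context-dependent predictions of the overlap region.

    p_r1, p_r2: (B, C, d, h, w) RNet probabilities of the overlap region,
    cropped from the predictions of patches x1 and x2, respectively.
    p_c1, p_c2: (B, d, h, w) CNet foreground probabilities of the same region.
    """
    # Eq. (10): average the RNet predictions, then take the hard label.
    y_hat_r = ((p_r1 + p_r2) / 2).argmax(dim=1).detach()
    # Eq. (11): average the CNet predictions, then binarize at 0.5.
    y_hat_c = (((p_c1 + p_c2) / 2) > 0.5).long().detach()
    return y_hat_r, y_hat_c
```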
2.4 Overall loss function
The overall training objective of our proposed approach is:

$$\mathcal{L} = \mathcal{L}_{sup} + \lambda\, \mathcal{L}_{unsup} \quad (13)$$

where $\lambda$ is a weight factor defined by the Gaussian ramp-up $\lambda(t) = \lambda_{max} \cdot e^{-5 (1 - t/t_{max})^2}$, where $t$ denotes the current iteration and $t_{max}$ is the total iteration number.
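The weight schedule can be implemented as the Gaussian ramp-up below; the maximum weight lambda_max is not fixed by the text and is a placeholder here.

```python
import math

def ccps_weight(t: int, t_max: int, lambda_max: float = 0.1) -> float:
    """Gaussian ramp-up lambda(t) = lambda_max * exp(-5 (1 - t/t_max)^2)."""
    return lambda_max * math.exp(-5.0 * (1.0 - t / t_max) ** 2)
```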
Table 1: Quantitative comparison with SOTA SSMIS methods on the BCV and MMWHS datasets. L and U denote the numbers of labeled and unlabeled training samples, respectively.

| Method | BCV L | BCV U | BCV DSC (%) | BCV ASD (voxel) | MMWHS L | MMWHS U | MMWHS DSC (%) | MMWHS ASD (voxel) |
|---|---|---|---|---|---|---|---|---|
| Baseline | 4 | 0 | 63.7 | 11.19 | 2 | 0 | 75.6 | 28.66 |
| Upper Bound | 24 | 0 | 86.9 | 10.31 | 14 | 0 | 91.2 | 4.13 |
| DAN [11] | 4 | 20 | 64.2 | 11.92 | 2 | 12 | 75.4 | 37.50 |
| MT [12] | 4 | 20 | 65.2 | 11.15 | 2 | 12 | 78.1 | 28.26 |
| UAMT [13] | 4 | 20 | 67.6 | 13.59 | 2 | 12 | 77.7 | 29.01 |
| URPC [14] | 4 | 20 | 65.7 | 15.08 | 2 | 12 | 72.9 | 36.66 |
| McNet [15] | 4 | 20 | 60.4 | 10.49 | 2 | 12 | 67.7 | 21.38 |
| CPS [4] | 4 | 20 | 67.2 | 11.45 | 2 | 12 | 78.9 | 27.23 |
| Ours | 4 | 20 | 72.4 | 9.18 | 2 | 12 | 81.7 | 11.06 |
Table 2: Ablation study of each component on the BCV dataset.

| CPS | CCPS | Context-awareness | HOL | DSC (%) | ASD (voxel) |
|---|---|---|---|---|---|
| ✓ | | | | 67.2 | 11.45 |
| ✓ | ✓ | | | 67.8 | 9.75 |
| ✓ | | ✓ | | 68.5 | 9.80 |
| ✓ | ✓ | ✓ | | 71.6 | 10.91 |
| ✓ | ✓ | ✓ | ✓ | 72.4 | 9.18 |
3 Experiments
3.1 Datasets and implementation details
Datasets. We evaluated our method on two public datasets: the Beyond the Cranial Vault (BCV) dataset [16] and the MICCAI'17 Multi-Modality Whole Heart Segmentation challenge (MMWHS) dataset [17]. The BCV dataset contains 30 CT images with annotations for 13 organs. We chose to segment five abdominal organs: liver, spleen, pancreas, left kidney, and right kidney. We used 24 samples (4 labeled and 20 unlabeled) for training and the remaining 6 samples for testing. The MMWHS dataset consists of 20 cardiac CT samples with annotations for seven structures: left ventricle (LV), right ventricle (RV), left atrium (LA), right atrium (RA), pulmonary artery (PA), myocardium (MYO), and ascending aorta (AA). We took 14 samples (2 labeled and 12 unlabeled) for training and used the remaining 6 samples for testing.
Implementation details. We chose 3D U-Net [18] as the RNet and the conditional 3D U-Net (see Fig. 1-(b) for an illustration) as the CNet [19]. Please note that, for a fair comparison with other SOTA SSMIS methods, we replaced the nnU-Net used in [19] with 3D U-Net. We used a patch size of 160 × 160 × 96 for training. In total, we trained our networks for 20,000 iterations. We used the SGD optimizer and set the initial learning rate to 0.01; at each iteration, the initial learning rate was multiplied by $(1 - t/t_{max})^{0.9}$, where $t$ is the current iteration. We implemented our method with the PyTorch framework and conducted the evaluation on an NVIDIA Tesla V100 GPU. Starting from the $t_h$-th iteration, we conducted hard organ learning. On the BCV dataset, we chose the left kidney, right kidney, and pancreas as the hard organs, while on the MMWHS dataset, we chose MYO, RA, AA, and RV as the hard organs. We empirically set the thresholds $\tau_r$ and $\tau_c$ to 0.95 and 0.9, respectively. We use the Dice similarity coefficient (DSC; %) and the Average Surface Distance (ASD; voxel) as the evaluation metrics. A paired t-test is used to check whether a difference is statistically significant; we set the significance level to 0.05.
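For completeness, a sketch of the polynomial learning-rate decay described above; the exponent 0.9 is the common "poly" choice and an assumption here.

```python
def poly_lr(initial_lr: float, t: int, t_max: int, power: float = 0.9) -> float:
    """Polynomial decay: lr(t) = initial_lr * (1 - t / t_max) ** power."""
    return initial_lr * (1.0 - t / t_max) ** power
```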
Figure 2: Visual comparison of segmentation results obtained by different methods on the BCV and MMWHS datasets.
3.2 Results
We compared our method with six SOTA SSMIS methods [11, 12, 13, 14, 4, 15]. The Baseline and Upper Bound results were obtained by training the RNet in a fully supervised manner using only the labeled data.
BCV dataset. Table 1 presents the results on the BCV dataset. The proposed C3PS method achieves better DSC performance than the other six SOTA methods by a large margin when 4 labeled samples were used. From this table, one can also find that our method outperforms CPS by a large margin in terms of DSC (on average, an increase of 5.2%). A paired t-test showed that the difference between our method and the CPS method [4] was statistically significant ($p$ = 0.011).
MMWHS dataset. Table 1 also presents the results on the MMWHS dataset. The proposed method achieved superior performance over the other SOTA methods, with an average DSC of 81.7%. A paired t-test showed that the difference between our method and the CPS method [4] was statistically significant ($p$ = 1e-4).
Fig. 2 shows a visual comparison of the top-5 competing methods and our own method on these two datasets. From this figure, one can see that our method produces segmentation results that are more consistent with the ground truth than those of the other methods.
Ablation Study. We conducted ablation studies on the BCV dataset to show the effectiveness of each individual component. As shown in Table 2, each component contributes to the superior performance of the proposed approach. We further investigated the effectiveness of HOL; the results are presented in Table 3. With HOL, we observed improved performance for the left and right kidneys as well as the pancreas.
Table 3: Effectiveness of HOL: per-organ DSC (%) on the BCV dataset.

| Method | Spleen | Right Kidney | Left Kidney | Liver | Pancreas | All |
|---|---|---|---|---|---|---|
| Ours (w/o HOL) | 79.9 | 78.2 | 73.4 | 82.3 | 44.3 | 71.6 |
| Ours (w/ HOL) | 78.8 | 78.6 | 74.9 | 83.5 | 46.3 | 72.4 |
4 Conclusion
In this paper, we proposed a context-aware conditional cross pseudo supervision method, referred to as C3PS, for semi-supervised medical image segmentation. We introduced a novel CCPS mechanism where the cross pseudo supervision is conditioned on a given class label. The CCPS mechanism has the additional advantage that, in the later training stage, it can focus on the learning of hard organs. Context-awareness was further introduced in the CCPS to improve the quality of pseudo-label generation. Results from experiments conducted on two public datasets demonstrate the efficacy of the proposed approach.
References
- [1] Yao, Q., Xiao, L., Liu, P., Zhou, S.K.: Label-free segmentation of COVID-19 lesions in lung CT. IEEE Transactions on Medical Imaging 40(10), 2808–2819 (2021)
- [2] Zhang, L., et al.: Generalizing deep learning for medical image segmentation to unseen domains via deep stacked transformation. IEEE Transactions on Medical Imaging 39(7), 2531–2540 (2020)
- [3] Zhang, Y., et al.: Exploiting shared knowledge from non-COVID lesions for annotation-efficient COVID-19 CT lung infection segmentation. IEEE Journal of Biomedical and Health Informatics 25(11), 4152–4162 (2021)
- [4] Chen, X., Yuan, Y., Zeng, G., Wang, J.: Semi-supervised semantic segmentation with cross pseudo supervision. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. pp. 2613–2622 (2021)
- [5] Xu, M., et al.: Learning morphological feature perturbations for calibrated semi-supervised segmentation. In: Medical Imaging with Deep Learning (2021)
- [6] Luo, X., et al.: Semi-supervised medical image segmentation via cross teaching between CNN and Transformer. arXiv preprint arXiv:2112.04894 (2021)
- [7] Lin, Y., et al.: Calibrating label distribution for class-imbalanced barely-supervised knee segmentation. In: MICCAI (8). Lecture Notes in Computer Science, vol. 13438, pp. 109–118. Springer (2022)
- [8] Liu, J., Desrosiers, C., Zhou, Y.: Semi-supervised medical image segmentation using cross-model pseudo-supervision with shape awareness and local context constraints. In: Medical Image Computing and Computer Assisted Intervention–MICCAI 2022: 25th International Conference, Singapore, September 18–22, 2022, Proceedings, Part VIII. pp. 140–150. Springer (2022)
- [9] Lai, X., Tian, Z., Jiang, L., Liu, S., Zhao, H., Wang, L., Jia, J.: Semi-supervised semantic segmentation with directional context-aware consistency. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. pp. 1205–1214 (2021)
- [10] Liu, P., Zheng, G.: Context-aware voxel-wise contrastive learning for label efficient multi-organ segmentation. In: Medical Image Computing and Computer Assisted Intervention–MICCAI 2022: 25th International Conference, Singapore, September 18–22, 2022, Proceedings, Part IV. pp. 653–662. Springer (2022)
- [11] Zhang, Y., et al.: Deep adversarial networks for biomedical image segmentation utilizing unannotated images. In: International Conference on Medical Image Computing and Computer-Assisted Intervention. pp. 408–416. Springer (2017)
- [12] Tarvainen, A., Valpola, H.: Mean teachers are better role models: Weight-averaged consistency targets improve semi-supervised deep learning results. Advances in Neural Information Processing Systems 30 (2017)
- [13] Yu, L., et al.: Uncertainty-aware self-ensembling model for semi-supervised 3d left atrium segmentation. In: International Conference on Medical Image Computing and Computer-Assisted Intervention. pp. 605–613. Springer (2019)
- [14] Luo, X., et al.: Efficient semi-supervised gross target volume of nasopharyngeal carcinoma segmentation via uncertainty rectified pyramid consistency. In: International Conference on Medical Image Computing and Computer-Assisted Intervention. pp. 318–329. Springer (2021)
- [15] Wu, Y., et al.: Mutual consistency learning for semi-supervised medical image segmentation. Medical Image Analysis 81, 102530 (2022)
- [16] Gibson, E., et al.: Automatic multi-organ segmentation on abdominal CT with dense V-networks. IEEE Transactions on Medical Imaging 37(8), 1822–1834 (2018)
- [17] Zhuang, X., Shen, J.: Multi-scale patch and multi-modality atlases for whole heart segmentation of MRI. Medical Image Analysis 31, 77–87 (2016)
- [18] Ronneberger, O., Fischer, P., Brox, T.: U-Net: Convolutional networks for biomedical image segmentation. In: International Conference on Medical Image Computing and Computer-Assisted Intervention. pp. 234–241. Springer (2015)
- [19] Zhang, G., Yang, Z., Huo, B., Chai, S., Jiang, S.: Multiorgan segmentation from partially labeled datasets with conditional nnU-Net. Computers in Biology and Medicine 136, 104658 (2021)