Two-Stage Multi-task Self-Supervised Learning for Medical Image Segmentation
Abstract
Medical image segmentation has been significantly advanced by deep learning (DL) techniques, though the data scarcity inherent in medical applications poses a great challenge to DL-based segmentation methods. Self-supervised learning offers a solution by creating auxiliary learning tasks from the available dataset and then leveraging the knowledge acquired from solving auxiliary tasks to help better solve the target segmentation task. Different auxiliary tasks may have different properties and thus can help the target task to different extents. It is therefore desirable to leverage their complementary advantages to enhance the overall assistance to the target task. To achieve this, existing methods often adopt a joint training paradigm, which co-solves segmentation and auxiliary tasks by integrating their losses or intermediate gradients. However, direct coupling of losses or intermediate gradients risks undesirable interference because the knowledge acquired from solving each auxiliary task at every training step may not always benefit the target task. To address this issue, we propose a two-stage training approach. In the first stage, the target segmentation task is independently co-solved with each auxiliary task in both joint training and pre-training modes, with the better model selected via validation performance. In the second stage, the models obtained with respect to each auxiliary task are converted into a single model using an ensemble knowledge distillation method. Our approach makes the best use of each auxiliary task to create multiple elite segmentation models and then combines them into an even more powerful model. We employed five auxiliary tasks with different properties in our approach and applied it to train the U-Net model on an X-ray pneumothorax segmentation dataset. Experimental results demonstrate the superiority of our approach over several existing methods.
Index Terms:
medical image segmentation, deep learning, self-supervised learning, multi-task learning, knowledge distillation
I Introduction
Medical image segmentation (MIS) aims to pixel-wise delineate the targets of interest in medical images, serving as the foundational step in numerous medical image analysis pipelines. Automatic MIS has received much attention, with deep learning (DL) showcasing remarkable success on a broad spectrum of tasks [1]. The success of DL typically thrives on abundant annotated data. However, because of high annotation costs in the medical domain, data scarcity poses a significant challenge that limits model performance. An MIS dataset often contains as few as hundreds or tens of data samples. Supervised by such a small dataset, the trained model risks learning limited knowledge and overfitting training data, leading to poor generalisation performance when applied to real-world applications.
Self-supervised learning (SSL) provides a promising way to tackle data scarcity by creating auxiliary tasks to facilitate solving the segmentation task. Through solving auxiliary tasks, the extra knowledge acquired can be transferred to the segmentation model to enrich its knowledge representation and thus boost generalisation performance. Without additional annotation efforts, auxiliary tasks can be created based on the segmentation dataset in various ways. First, they can be created solely based on the images, which can be realised by many unsupervised tasks, e.g., auto-encoding and contrastive learning. Second, auxiliary tasks can be created based on the segmentation annotations, yielding tasks such as surface distance map prediction [2] and contour prediction [3], which sometimes can be regarded as alternative formulations of the segmentation task. Different auxiliary tasks with different mechanisms can transfer different knowledge to boost performance from distinct aspects. It is thus desired to combine their complementary advantages by involving multiple auxiliary tasks to further assist the target segmentation task.
To leverage multiple auxiliary tasks, existing approaches [4, 5, 6] commonly adopt a multi-task learning formulation and a joint training paradigm, where the segmentation and auxiliary tasks are solved simultaneously by integrating their respective losses or intermediate gradients at each iteration step. Such a paradigm often requires homogeneous tasks with highly related learning objectives to enable effective knowledge transfer, whereas auxiliary tasks can be quite heterogeneous in their formulations and thus in the knowledge they learn. Consequently, the knowledge acquired from solving different auxiliary tasks at every training step may interfere with each other and may not always benefit the target task. Specifically, the flaws are twofold. First, potential negative transfer from certain auxiliary tasks can hamper solving the segmentation task, and this becomes more severe as more auxiliary tasks are involved. Second, for certain auxiliary tasks, interference from other tasks may impede the learning of valuable task-specific knowledge, so joint training cannot leverage these tasks effectively. For such tasks, a more suitable alternative to joint training is to pre-train the model with the auxiliary task and then fine-tune it on the segmentation task. As will be demonstrated in Section IV-C, different auxiliary tasks favour different training modes, and using the wrong mode yields suboptimal performance. In a nutshell, when facing heterogeneous auxiliary tasks, the interference among tasks should be alleviated, and both joint training and pre-training modes need to be supported to take full advantage of the auxiliary tasks.
To this end, we propose to harness multiple auxiliary tasks with a two-stage training method. In the first stage, the target segmentation task will be independently co-solved with each auxiliary task in both joint training and pre-training modes, and the better model is selected via validation performance. This results in multiple elite segmentation models, each containing distinct knowledge transferred from an auxiliary task with the appropriate training mode. In the second stage, we employ an ensemble knowledge distillation method [7] to convert the models obtained with respect to each auxiliary task into a single model. As such, the diverse knowledge obtained from various auxiliary tasks is combined into an even more powerful model for inference.
We employ five representative auxiliary tasks in our method, including three unsupervised tasks, i.e., Rubik’s Cube [8], MoCo-v3 [9], and VICReg [10] and two surface distance prediction tasks created based on segmentation annotations. The implemented method is used to train U-Net [11] models on an X-ray pneumothorax segmentation dataset (SIIM-ACR Pneumothorax Segmentation Challenge) [12]. Experimental results demonstrate the superiority of the proposed method in leveraging multiple auxiliary tasks to boost segmentation performance.
II Related Work
II-A Medical Image Segmentation
Medical image segmentation (MIS) aims at accurately delineating the target of interest (e.g., organs and lesions) in medical images such as X-ray, CT, and MRI. So far, MIS has been significantly advanced by DL, as demonstrated by state-of-the-art performance in a broad spectrum of tasks. Due to their over-parameterised nature, DL models are data-hungry and typically require abundant annotated data to ensure good performance. Because the annotation process is expert-demanding and labour-intensive, MIS annotations are expensive and rare. As a result, learning from small annotated datasets has been an active research topic. Many endeavours have focused on improving the model architecture [13, 14, 15] to extract generalisable features more efficiently. Other works have explored learning in settings with no or cheaper annotations. For instance, semi-supervised learning leverages abundant unlabelled imaging data to facilitate learning. Self-supervised learning [16] can train models solely with images in an unsupervised manner. Weakly supervised learning enables learning from weak but cheap annotations. Our work focuses on the fully supervised setting where only an annotated dataset is available, and there are no additional data or annotations.
II-B Self-Supervised Learning
Self-supervised learning (SSL) generally aims at feature representation learning from unlabelled data. In a common SSL paradigm, a model is pre-trained on large unlabeled data with carefully designed pretext tasks to serve as a general initialisation. It is expected that the knowledge learnt for solving the pretext tasks can be transferred to help solve various downstream tasks. Recent advances in SSL have achieved tremendous success in training DL-based models. So far, SSL has not only been widely demonstrated as an effective tool to combat data scarcity and boost performance but has also become a foundation technique for training large models [9]. The art of SSL is to design pretext tasks considering both the data domain and the downstream task property so that rich and useful knowledge can be extracted from the unlabeled upstream data to benefit downstream task performance. With abundant unlabelled data existing in the medical field, SSL for MIS has also been actively explored, most of which adapts the natural image counterparts to the medical image domain [8, 17]. The prosperity of SSL provides diverse unsupervised tasks that can serve as auxiliary tasks in our supervised setting.
II-C Multi-Task Learning
Multi-task learning (MTL) is a machine learning paradigm that learns multiple related tasks simultaneously. MTL can enrich overall feature representation and boost task performance by sharing complementary knowledge obtained from solving different tasks or letting tasks act as regularisers of one another [18]. MTL techniques have been extensively explored by improving model architecture or optimisation process, aiming to better capture shared information, avoid conflicts among tasks, or balance the performance of different tasks. While MTL typically relies on a dataset with multiple annotations, model training with a single-task dataset can still take advantage of MTL by involving auxiliary tasks. This setting, sometimes known as auxiliary learning, is a special case of MTL, which focuses on only the performance of the target task. Existing auxiliary learning approaches commonly adopt the MTL formulation. Based on the joint training paradigm, different ways to adaptively adjust auxiliary task contributions have been explored. Most methods focus on task-level adaptiveness via selecting auxiliary tasks to learn from [4, 19] or reweighting the contribution of different tasks in loss [6, 20] or gradient aggregation [4]. To achieve this, some methods rely on heuristic rules to assess task relationships [4, 19], and others are driven by feedback from validation set [6, 20].
II-D Knowledge Distillation
Knowledge distillation (KD) [7] is a popular model training technique in DL that enables knowledge transfer among models. KD was initially proposed for model compression [7, 21], where a large pre-trained “teacher” model trains a small “student” model by supervising the student with the teacher’s predictions. Through soft probabilistic labels, rich information can be transferred to the student model to facilitate its learning process [22, 23]. KD techniques have been actively explored because they can transfer knowledge solely based on model outputs, posing no restriction on the architecture or the number of teacher models. KD operations have been further developed to distil knowledge with different representations, such as intermediate feature maps and inter-class relationships. The utility of KD has also been extended to transfer knowledge from multiple teacher models [24], a teacher trained together with the student [25], or even the student itself [26].
III Method
III-A Preliminary
For conciseness, we use single-class segmentation on 2D single-channel images to derive the notation and formulations used in the following sections. Note that our method is naturally compatible with 3D, multi-channel, multi-modality, and multi-class segmentation settings. We use $x \in \mathbb{R}^{H \times W}$ to denote an image, where $H$ and $W$ represent its height and width, respectively. The binary segmentation mask annotation for each image is denoted as $y \in \{0, 1\}^{H \times W}$. A model for image segmentation can be represented as $f(\cdot; \theta)$, where $\theta$ denotes its parameters.
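In a typical training process, a model is trained by optimising its parameters to solve the segmentation task represented by a training set $\mathcal{D}_{train}$, and the training process can be formulated as:
$$\theta^{*} = \arg\min_{\theta} \frac{1}{|\mathcal{D}_{train}|} \sum_{(x, y) \in \mathcal{D}_{train}} \mathcal{L}\big(f(x; \theta), y\big), \tag{1}$$
where $\theta^{*}$ denotes the optimised parameters, and $\mathcal{L}$ represents the loss function, which is often the Dice loss in the MIS domain. After training, the generalisation performance of $f(\cdot; \theta^{*})$ is evaluated on a test set $\mathcal{D}_{test}$ which does not overlap with $\mathcal{D}_{train}$.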
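As a concrete reference for the Dice loss mentioned above, the following is a minimal sketch of a soft Dice loss on small 2D maps stored as nested lists. The function name `soft_dice_loss` and the smoothing constant `eps` are illustrative assumptions; the exact variant used in the experiments is not specified in the text.

```python
def soft_dice_loss(pred, target, eps=1e-6):
    """Soft Dice loss between two equally sized 2D maps with values in [0, 1].

    `eps` is a smoothing constant (an assumption, not taken from the paper)
    that keeps the ratio defined when both maps are empty.
    """
    p = [v for row in pred for v in row]      # flatten prediction
    t = [v for row in target for v in row]    # flatten annotation
    intersection = sum(a * b for a, b in zip(p, t))
    denominator = sum(p) + sum(t)
    return 1.0 - (2.0 * intersection + eps) / (denominator + eps)


mask = [[0, 1], [1, 1]]
perfect = [[0.0, 1.0], [1.0, 1.0]]   # matches the mask exactly
poor = [[1.0, 0.0], [0.0, 0.0]]      # predicts the complement of the mask
print(soft_dice_loss(perfect, mask))  # close to 0
print(soft_dice_loss(poor, mask))     # close to 1
```

The loss approaches 0 when prediction and annotation overlap fully and approaches 1 when they do not overlap at all, which is why it is commonly preferred for small foreground regions.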
Generally, we can denote the segmentation task as $\mathcal{T}_{seg}$. An auxiliary task $\mathcal{T}_{k}$ aims to train the model to enrich its knowledge representation in $\theta$ by bringing in additional knowledge that cannot be learnt solely from $\mathcal{D}_{train}$ and thus improve the model’s generalisation performance. The dataset $\mathcal{D}_{k}$ used by auxiliary task $\mathcal{T}_{k}$ is created based on $\mathcal{D}_{train}$, using either the images and labels in $\mathcal{D}_{train}$ or solely the images. Different tasks have different formulations and can be heterogeneous in terms of their input and output spaces. For example, a segmentation task operates on whole images, whereas a contrastive learning task operates on image patches, so the two differ in both their input and output spaces. To differentiate the inputs and outputs among tasks, we use $x_{seg}$ and $y_{seg}$ to denote the image and label for segmentation and use $x_{k}$ and $y_{k}$ to denote the image and label for auxiliary task $\mathcal{T}_{k}$. Our work involves multiple auxiliary tasks $\mathcal{T}_{1}, \dots, \mathcal{T}_{K}$ to improve model training, where $K$ denotes the total number of auxiliary tasks. To enable knowledge transfer, every auxiliary task shares a certain part of the model with the segmentation task, parameterised by $\theta_{sh}$, and every task, including the segmentation task, has its task-specific head. We use $\theta_{seg}$ to denote the segmentation-specific parameters, and $\theta_{k}$ to denote the parameters specific to auxiliary task $\mathcal{T}_{k}$.
III-B Framework

Given a target MIS dataset and multiple auxiliary tasks, the proposed method aims to obtain a model with good performance on the MIS dataset by leveraging the auxiliary tasks. As shown in Figure 1, the method is composed of two training stages. First, the target segmentation task will be independently co-solved with each auxiliary task in both joint training and pre-training modes, and the better model obtained by the two modes is selected via validation performance. As a result, multiple elite segmentation models are obtained, each containing useful and distinct knowledge transferred from an auxiliary task. Second, these models are treated as teacher models and are compressed into a single model via ensemble knowledge distillation [7]. In this way, the diverse knowledge obtained from multiple auxiliary tasks is aggregated to obtain an even more powerful student model, which is output for inference. The following sections describe the two training stages in detail.
III-C Task-Specific Teacher Training
This stage aims to obtain a segmentation model independently facilitated by each auxiliary task. In this way, the interference among auxiliary tasks can be avoided, and the most appropriate training mode for each auxiliary task can be independently applied to make the best use of it. For auxiliary task $\mathcal{T}_{k}$, we train two models to co-solve $\mathcal{T}_{k}$ and segmentation in two modes: joint training and pre-training, respectively.
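In the joint training mode, the model is trained to concurrently solve the segmentation and auxiliary tasks in a multi-task fashion. At each iteration step, we sample a batch of samples for the segmentation task and the auxiliary task, respectively, i.e., $(x_{seg}, y_{seg})$ and $(x_{k}, y_{k})$. Then the inputs are fed to the model, and the training loss is calculated as
$$\mathcal{L}_{joint} = \mathcal{L}_{seg}\big(f(x_{seg}; \theta_{sh}, \theta_{seg}), y_{seg}\big) + \alpha_{k}\, \mathcal{L}_{k}\big(f(x_{k}; \theta_{sh}, \theta_{k}), y_{k}\big), \tag{2}$$
where $\alpha_{k}$ denotes a scalar coefficient to weight the loss contribution of task $\mathcal{T}_{k}$. A larger $\alpha_{k}$ encourages the trained model to learn more from the auxiliary task and vice versa. All the model parameters $\theta_{sh}$, $\theta_{seg}$, and $\theta_{k}$ are jointly updated to minimise $\mathcal{L}_{joint}$.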
The pre-training mode involves two training steps. First, the model is pre-trained to solve the auxiliary task via the task-specific loss and data:
$$\mathcal{L}_{k}\big(f(x_{k}; \theta_{sh}, \theta_{k}), y_{k}\big) \tag{3}$$
to train $\theta_{sh}$ and $\theta_{k}$. Afterwards, the shared parameters $\theta_{sh}$ are transferred to initialise the segmentation model in the fine-tuning step, where the model is trained to solve the segmentation task via
$$\mathcal{L}_{seg}\big(f(x_{seg}; \theta_{sh}, \theta_{seg}), y_{seg}\big). \tag{4}$$
In the fine-tuning step, the pre-trained $\theta_{sh}$ are continuously trained along with randomly initialised $\theta_{seg}$.
After the two models are trained, we select the better one based on its performance on the validation set $\mathcal{D}_{val}$. By selecting the most appropriate training mode for leveraging each auxiliary task, the chance of negative transfer is reduced. Also, because each auxiliary task takes effect independently, the conflicts between auxiliary tasks are avoided. Another advantage of independent training is that this step can be naturally distributed to different machines to accelerate training.
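The selection step described above can be sketched as follows. This is only an illustration under stated assumptions: `select_teacher`, the mode names, and the validation scores are hypothetical stand-ins, and the "models" are plain labels rather than trained networks.

```python
def select_teacher(candidates, validate):
    """Return the (mode, model) pair with the highest validation score.

    candidates: dict mapping a training-mode name to a trained model.
    validate:   callable returning a validation Dice score for a model.
    """
    best_mode = max(candidates, key=lambda mode: validate(candidates[mode]))
    return best_mode, candidates[best_mode]


# Hypothetical stand-ins: the "models" are labels and the scores are made up.
fake_val_dice = {"model_joint": 0.41, "model_pretrain": 0.44}
candidates = {"joint": "model_joint", "pretrain": "model_pretrain"}
mode, teacher = select_teacher(candidates, fake_val_dice.get)
print(mode, teacher)  # pretrain model_pretrain
```

Because the selection only needs a validation score per candidate, it can run independently for every auxiliary task, which is consistent with distributing this stage across machines.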
III-D Multi-Teacher Ensemble Knowledge Distillation
This stage aims to compress the models obtained in the previous stage, denoted as $f_{1}, \dots, f_{K}$, into a single model for inference. In this way, the diverse knowledge obtained from multiple auxiliary tasks is aggregated to obtain an even more powerful model. To achieve this, we employ ensemble knowledge distillation [7]. The previously obtained models are treated as teachers to distil their knowledge into the student model $f_{S}$. For each input image $x$, the average of all the teacher models’ probabilistic prediction maps, i.e., their ensemble prediction $\bar{p} = \frac{1}{K} \sum_{k=1}^{K} f_{k}(x)$, is used as the pseudo annotation to train the student model. Because the pseudo annotation inevitably carries erroneous predictions from the teacher models, supervising the student with $\bar{p}$ alone can mislead it. To address this issue, following common practice, we define the loss for student training as a linear combination of the loss w.r.t. the ground-truth annotation and the KD loss w.r.t. the pseudo annotation. Putting these together, the total loss for training the student model is formulated as
$$\mathcal{L}_{student} = (1 - \lambda)\, \mathcal{L}_{seg}\big(f_{S}(x), y\big) + \lambda\, \mathcal{L}_{KD}\big(f_{S}(x), \bar{p}\big), \tag{5}$$
where $\lambda$ is a coefficient balancing the impact of ground-truth annotations and pseudo annotations. We set $\lambda = K/(K+1)$ (i.e., $5/6 \approx 0.83$ with five teachers) intuitively to treat every pseudo annotation from a teacher model as equally important as the ground-truth annotation.
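The distillation objective described above can be illustrated with a small sketch on flattened probability maps. Several parts are assumptions made for illustration: the KD term is taken to be a soft Dice loss against the ensemble prediction (the text does not state its form), the combination is written as a convex mix controlled by `lam`, and the default `lam = K / (K + 1)` is inferred from the reported value of 0.83 with five teachers.

```python
def dice_loss(pred, target, eps=1e-6):
    """Soft Dice loss between two flat lists of values in [0, 1]."""
    intersection = sum(p * t for p, t in zip(pred, target))
    return 1.0 - (2.0 * intersection + eps) / (sum(pred) + sum(target) + eps)


def ensemble_prediction(teacher_probs):
    """Pixel-wise mean of the teachers' probability maps (flat lists)."""
    return [sum(pixel) / len(teacher_probs) for pixel in zip(*teacher_probs)]


def default_lambda(num_teachers):
    """K / (K + 1): an inferred default, about 0.83 for five teachers."""
    return num_teachers / (num_teachers + 1)


def student_loss(student_prob, ground_truth, teacher_probs, lam=None):
    """(1 - lam) * ground-truth loss + lam * distillation loss."""
    if lam is None:
        lam = default_lambda(len(teacher_probs))
    pseudo = ensemble_prediction(teacher_probs)
    seg_term = dice_loss(student_prob, ground_truth)
    kd_term = dice_loss(student_prob, pseudo)  # assumed form of the KD loss
    return (1.0 - lam) * seg_term + lam * kd_term


teachers = [[0.9, 0.1, 0.8], [0.7, 0.3, 1.0]]  # two toy teachers, three pixels
student = [0.8, 0.2, 0.7]
truth = [1, 0, 1]
print(ensemble_prediction(teachers))
print(student_loss(student, truth, teachers))
print(default_lambda(5))  # five teachers, as in the experiments
```

Setting `lam=0.0` reduces the objective to plain supervised training, while `lam=1.0` trains the student on the teachers' ensemble alone.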
IV Experiments
We first describe the dataset used in this study in Section IV-A and then introduce the experimental setup in Section IV-B. In Section IV-C, we show the crucial impact of the training mode for leveraging an auxiliary task. Based on the results in that section, we determine the training mode for each auxiliary task in the implementation of our method used in the subsequent experiments. In Section IV-D, we compare our proposed method with several existing methods for leveraging auxiliary tasks to demonstrate the superiority of our method. In Section IV-E, we measure how much training data our method saves compared to the conventional training method. Finally, we provide further analysis of the auxiliary task contributions in Section IV-F and the impact of the KD loss coefficient in Section IV-G.
IV-A Dataset
PNE (SIIM-ACR Pneumothorax Segmentation Challenge) [12] is the largest public dataset for pneumothorax segmentation. The images are a subset of the ChestX-ray14 dataset [27] released by the National Institutes of Health (NIH) and include both positive and negative pneumothorax cases. For the positive cases, pixel-wise delineations of pneumothorax are provided. In this study, we use all 2669 positive samples for experiments. In cases where multiple annotations are available, we merge the annotation masks by a union operation to include all the positive regions. All images are resized to . We randomly split the dataset into 1601/267/801 samples for training/validation/testing.
IV-B Experimental Setup
Common setup. To simulate a data-scarce scenario, we randomly select 200 samples for training, which is a typical dataset size in practical applications. We train 2D U-Net [11] models with a ResNet-18 [28] backbone, which is commonly adopted for MIS tasks. To construct the U-Net architecture, we adopt group normalisation layers [29] and ReLU activation layers. For data augmentation, we use random translation, zooming, rotation, Gaussian noise, Gaussian blur, brightness jittering, contrast jittering, and gamma jittering following [1]. We use a batch size of 16 and the Dice loss function to train every segmentation model. We adopt the Rectified Adam (RAdam) optimiser [30] for model training, where the learning rate is initialised to 0.001 and decayed by the poly annealing scheduler with a power of 0.9 throughout the training process.
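For reference, the poly annealing schedule mentioned above is commonly written as the initial rate multiplied by (1 - step / total_steps) raised to the power. The sketch below assumes that standard form; `poly_lr` and the total step count of 1000 are illustrative choices, since the number of training iterations is not stated in the text.

```python
def poly_lr(step, total_steps, base_lr=0.001, power=0.9):
    """Polynomial decay: base_lr * (1 - step / total_steps) ** power."""
    if not 0 <= step <= total_steps:
        raise ValueError("step must lie in [0, total_steps]")
    return base_lr * (1.0 - step / total_steps) ** power


# total_steps=1000 is a hypothetical training length used only for illustration.
for step in (0, 500, 1000):
    print(step, poly_lr(step, total_steps=1000))
```

The rate starts at 0.001, decays slightly slower than linearly because the power is below 1, and reaches 0 at the final step.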
Auxiliary Tasks. We employ the following five representative auxiliary tasks, covering both segmentation annotation-based tasks (SDM-in and SDM-out) and unsupervised tasks (RKB, MoCo, and VICReg), to implement our method:
- SDM-in and SDM-out: Surface distance map (SDM) is an alternative representation of the segmentation mask, with a rigorous mapping between them through the Euclidean distance transform. SDM prediction can help produce smooth segmentation boundaries and reduce spatially isolated errors [2]. In this study, we consider predicting two kinds of SDMs as auxiliary tasks. SDM-in is obtained by replacing the intensity of each pixel in the foreground with the distance to its closest background pixel and setting the background regions to 0. Conversely, SDM-out is obtained by replacing the intensity of each pixel in the background with the distance to its closest foreground pixel and setting the foreground regions to 0.
- RKB (Rubik’s Cube) [8] is a predictive SSL task that trains a model to predict the correct order of a set of shuffled image patches cropped from the same image.
- MoCo (Momentum Contrast-v3) [9] is a contrastive SSL task that trains a model to differentiate the image patches cropped from different images. The task is designed to contrast the latent feature representations of different image patches so that the features of patches cropped from the same image (positive pair) are pulled together, and the features of patches cropped from different images (negative pair) are pushed apart.
- VICReg (Variance-Invariance-Covariance Regularization) [10] is another contrastive SSL task. It has a similar goal to MoCo but employs a different contrastive learning formulation via a variance loss term and an invariance loss term. In addition, it promotes feature diversity via a covariance loss term.
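The SDM-in and SDM-out definitions given in the list above can be made concrete with a brute-force sketch on a tiny binary mask. The helper names are illustrative, and the quadratic search is chosen only for readability; in practice a library routine such as `scipy.ndimage.distance_transform_edt` would typically be used for the Euclidean distance transform.

```python
import math


def distance_to_nearest(mask, row, col, value):
    """Euclidean distance from (row, col) to the closest pixel equal to `value`.

    Returns math.inf if no such pixel exists in the mask.
    """
    best = math.inf
    for r, line in enumerate(mask):
        for c, pixel in enumerate(line):
            if pixel == value:
                best = min(best, math.hypot(r - row, c - col))
    return best


def sdm_in(mask):
    """Foreground pixels store the distance to the closest background pixel; background is 0."""
    return [[distance_to_nearest(mask, r, c, 0) if pixel == 1 else 0.0
             for c, pixel in enumerate(line)]
            for r, line in enumerate(mask)]


def sdm_out(mask):
    """Background pixels store the distance to the closest foreground pixel; foreground is 0."""
    return [[distance_to_nearest(mask, r, c, 1) if pixel == 0 else 0.0
             for c, pixel in enumerate(line)]
            for r, line in enumerate(mask)]


mask = [[0, 1, 1, 1, 0]]
print(sdm_in(mask))   # [[0.0, 1.0, 2.0, 1.0, 0.0]]
print(sdm_out(mask))  # [[1.0, 0.0, 0.0, 0.0, 1.0]]
```

The two maps are complementary: SDM-in encodes how deep a pixel lies inside the target, whereas SDM-out encodes how far a background pixel is from it.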
IV-C Impact of Training Mode to Leverage Auxiliary Tasks
Table I: Impact of the training mode used to leverage each auxiliary task.

Aux. Task | Training Mode | Dice Score (%) ↑ |
---|---|---|
None | / | 38.16±0.60 |
SDM-in | Joint Training | 38.65±0.37 |
SDM-in | Pre-Training | 37.56±0.61 |
SDM-out | Joint Training | 38.36±0.23 |
SDM-out | Pre-Training | 37.17±0.74 |
RKB | Joint Training | 38.45±0.17 |
RKB | Pre-Training | 39.23±0.27 |
MoCo | Joint Training | 38.13±0.56 |
MoCo | Pre-Training | 39.23±0.42 |
VICReg | Joint Training | 38.47±0.33 |
VICReg | Pre-Training | 38.67±0.67 |

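Table II: Comparison of training methods for leveraging auxiliary tasks (✓ marks the auxiliary tasks used).

Training Method | SDM-in | SDM-out | MoCo | RKB | VICReg | Dice Score (%) ↑ |
---|---|---|---|---|---|---|
Conventional | | | | | | 38.16±0.60 |
Single Aux. Task (Joint Training / Pre-Training) | ✓ | | | | | 38.65±0.37 / 37.56±0.61 |
Single Aux. Task (Joint Training / Pre-Training) | | ✓ | | | | 38.36±0.23 / 37.17±0.74 |
Single Aux. Task (Joint Training / Pre-Training) | | | ✓ | | | 38.13±0.56 / 39.23±0.42 |
Single Aux. Task (Joint Training / Pre-Training) | | | | ✓ | | 38.45±0.17 / 39.23±0.27 |
Single Aux. Task (Joint Training / Pre-Training) | | | | | ✓ | 38.47±0.33 / 38.67±0.67 |
Multi Aux. Tasks: Joint Training | ✓ | ✓ | ✓ | ✓ | ✓ | 38.47±0.36 |
Multi Aux. Tasks: Multi-Task Pre-Training | ✓ | ✓ | ✓ | ✓ | ✓ | 38.35±0.37 |
Multi Aux. Tasks: Ensemble | ✓ | ✓ | ✓ | ✓ | ✓ | 40.33±0.34 |
Multi Aux. Tasks: GCS [4] | ✓ | ✓ | ✓ | ✓ | ✓ | 38.55±0.42 |
Multi Aux. Tasks: PCGrad [31] | ✓ | ✓ | ✓ | ✓ | ✓ | 38.90±0.23 |
Multi Aux. Tasks: OL-AUX [5] | ✓ | ✓ | ✓ | ✓ | ✓ | 38.10±0.33 |
Multi Aux. Tasks: AMAL [32] | ✓ | ✓ | ✓ | ✓ | ✓ | 38.60±0.24 |
Multi Aux. Tasks: ours | ✓ | ✓ | ✓ | ✓ | ✓ | 42.05±0.20 |
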
Table III: Conventional training with different numbers of training samples.

# Training Data | Dice Score (%) ↑ |
---|---|
200 | 38.16±0.60 |
400 | 41.27±0.14 |
500 | 42.42±0.38 |
600 | 43.46±0.61 |
800 | 45.14±0.10 |
1600 | 47.71±0.23 |

For each auxiliary task, we independently use it to conduct two experiments, each using the auxiliary task to facilitate training a segmentation model with either joint training or pre-training following the respective procedures described in Section III-C. The experimental results are reported in Table I. It can be observed that different auxiliary tasks work well with different training modes. SDM-in and SDM-out work well only via joint training while drastically degrading the segmentation performance with pre-training. By contrast, RKB, MoCo, and VICReg work better with pre-training while offering no or limited improvements when using joint training. These results suggest the appropriate training mode differs for each auxiliary task, and no single training mode works well for all the auxiliary tasks. Furthermore, to take full advantage of different auxiliary tasks, both joint training and pre-training modes need to be supported, which motivates the design of our proposed method.
IV-D Overall Comparison
This section compares the proposed method with several existing methods that leverage multiple auxiliary tasks to boost model performance. We compare the following methods, including intuitive methods for multi-task learning or pre-training, and existing methods proposed for leveraging multiple auxiliary tasks to boost model performance under scarce training data:
- Joint Training is the most commonly used way to leverage auxiliary tasks. The segmentation task and the auxiliary tasks are used together to train the model via a weighted combination of their respective losses. Finally, the segmentation model is returned.
- Multi-Task Pre-Training is a common way to perform unsupervised model pre-training with multi-task SSL, which is adapted to our supervised setting. It first pre-trains a backbone model on all the auxiliary tasks using Joint Training. Then, the pre-trained backbone is transferred to initialise the target segmentation model, which is fine-tuned on the segmentation dataset.
- Ensemble is an intuitive way to aggregate the knowledge learnt from multiple auxiliary tasks. Similar to the first stage of the proposed method, it first uses each auxiliary task independently to facilitate the training of a segmentation model. Then, it directly uses the ensemble of all the obtained models for inference.
- GCS (Gradient Cosine Similarity) [4] is an existing method that utilizes the similarity between the gradients from each auxiliary task and those from the primary task to determine whether to learn from each auxiliary task at every iteration step.
- PCGrad (Projecting Conflicting Gradients) [31] is an MTL method adapted to our setting. At every iteration step, for each auxiliary task, the gradients that conflict with the segmentation task are removed.
- OL-AUX (Online Learning for Auxiliary losses) [5] adaptively changes the loss weight based on the gradient inner product w.r.t. the main task to decrease the long-term value of the main task loss.
- AMAL (Adaptive mixing of Auxiliary Losses) [32] adaptively changes the loss weight based on gradient feedback from the validation set.
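To illustrate the two gradient-based baselines in the list above, the following sketch operates on gradients represented as flat lists. It follows the one-sentence descriptions given here rather than the reference implementations of [4] and [31]: the function names are illustrative, the GCS gate is assumed to be a simple sign test on the cosine similarity, and the PCGrad step is shown only for an auxiliary gradient projected against the main-task gradient.

```python
import math


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def cosine_similarity(u, v):
    norm = math.sqrt(dot(u, u) * dot(v, v))
    return dot(u, v) / norm if norm > 0 else 0.0


def gcs_update(main_grad, aux_grad):
    """GCS-style gate: add the auxiliary gradient only if it does not oppose the main one."""
    if cosine_similarity(main_grad, aux_grad) >= 0:
        return [m + a for m, a in zip(main_grad, aux_grad)]
    return list(main_grad)


def pcgrad_project(aux_grad, main_grad):
    """PCGrad-style projection: remove the component of aux_grad that conflicts with main_grad."""
    d = dot(aux_grad, main_grad)
    if d >= 0:
        return list(aux_grad)          # no conflict, keep as is
    scale = d / dot(main_grad, main_grad)
    return [a - scale * m for a, m in zip(aux_grad, main_grad)]


main = [1.0, 0.0]
aux = [-1.0, 1.0]                      # points partly against the main gradient
print(gcs_update(main, aux))           # [1.0, 0.0]
print(pcgrad_project(aux, main))       # [0.0, 1.0]
```

In this toy case GCS discards the conflicting auxiliary gradient entirely, whereas the projection keeps its non-conflicting component.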
The results are reported in Table II. It can be seen that existing methods based on the joint training paradigm (all except Multi-Task Pre-Training, Ensemble, and ours) can hardly integrate the advantages of multiple auxiliary tasks, as evidenced by their equal or only marginally better performance compared to the baselines of conventional training and joint training with a single auxiliary task. Among these methods, only PCGrad, which removes the gradients that hamper solving the segmentation task, surpasses all these baselines. This implies severe interference among heterogeneous tasks during learning, which cannot be well tackled within the joint training paradigm by designs that try to filter out negative transfer (GCS and PCGrad) or adaptively integrate the contributions of different auxiliary tasks (AMAL). Also, these compared methods cannot match the performance obtained by pre-training with a single auxiliary task of MoCo or RKB (39.23%), showing the shortcoming of joint training in leveraging these tasks. By contrast, the methods that determine the training mode for each auxiliary task (Ensemble and ours) significantly improve the performance over all the baselines and compared methods. This demonstrates the necessity of adaptively determining the training mode for different auxiliary tasks and the superiority of using each auxiliary task independently to avoid interference. In addition, our method outperforms Ensemble with a single model, indicating the effectiveness of using ensemble knowledge distillation to aggregate the knowledge from various auxiliary tasks.
IV-E Analysis on Saving Training Data
We probe the amount of training data that can be saved by our method. We train segmentation models using conventional training with different numbers of training samples, ranging from 200 to 1600, and report the corresponding segmentation performance in Table III. It can be seen that the performance of our method with only 200 training samples (42.05%) falls between those obtained by conventional training with 400 and 500 training samples (41.27% and 42.42%), indicating that our method can effectively boost the segmentation performance when the training data is limited.
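As a rough worked example of the comparison above, one can linearly interpolate the conventional-training results in Table III to estimate how many samples would be needed to reach the Dice score of 42.05% reported for our method. The linear interpolation and the function name `equivalent_sample_count` are assumptions made for illustration; the text itself only states the 400 to 500 range.

```python
def equivalent_sample_count(target_dice, curve):
    """Linearly interpolate the number of samples at which `curve` reaches target_dice.

    curve: list of (num_samples, dice) pairs sorted by num_samples.
    """
    for (n_lo, d_lo), (n_hi, d_hi) in zip(curve, curve[1:]):
        if d_lo <= target_dice <= d_hi:
            return n_lo + (n_hi - n_lo) * (target_dice - d_lo) / (d_hi - d_lo)
    raise ValueError("target_dice lies outside the measured range")


# Mean Dice scores of conventional training, copied from Table III.
table_iii = [(200, 38.16), (400, 41.27), (500, 42.42),
             (600, 43.46), (800, 45.14), (1600, 47.71)]
estimate = equivalent_sample_count(42.05, table_iii)
print(round(estimate))  # 468
```

Under this simple interpolation, training with 200 samples using our method corresponds to roughly 468 conventionally used samples, i.e., more than twice the annotated data.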
IV-F Analysis on Auxiliary Task Contribution
Aux. Tasks Involved | Dice Score (%) ↑ | |||||
# Tasks | MoCo | RKB | SDM-in | SDM-out | VICReg | |
5 | ✓ | ✓ | ✓ | ✓ | ✓ | 42.05 |
4 | ✓ | ✓ | ✓ | ✓ | 41.78 | |
✓ | ✓ | ✓ | ✓ | 41.57 | ||
✓ | ✓ | ✓ | ✓ | 42.00 | ||
✓ | ✓ | ✓ | ✓ | 41.90 | ||
✓ | ✓ | ✓ | ✓ | 42.12 | ||
3 | ✓ | ✓ | ✓ | 41.81 | ||
✓ | ✓ | ✓ | 41.60 | |||
✓ | ✓ | ✓ | 42.10 | |||
✓ | ✓ | ✓ | 42.10 | |||
2 | ✓ | ✓ | 41.96 | |||
✓ | ✓ | 41.77 | ||||
✓ | ✓ | 42.04 | ||||
1 | ✓ | 41.16 | ||||
✓ | 41.68 | |||||
✓ | 40.18 | |||||
✓ | 40.27 | |||||
✓ | 41.06 |
As shown in Table IV, we start by using all 5 auxiliary tasks as a baseline and ablate each auxiliary task one by one to obtain the results with 4 auxiliary tasks. Then the best performer with 4 auxiliary tasks is set as the new baseline, from which one more auxiliary task is ablated to obtain the results with 3 auxiliary tasks. This process is repeated until only 2 auxiliary tasks are left. Finally, the results obtained using a single auxiliary task in our method are listed for reference. It can be observed that the performance obtained using all 5 tasks is among the best, whereas using fewer than 5 auxiliary tasks cannot guarantee optimal performance. Besides, no single auxiliary task alone can obtain good performance. These results indicate the effectiveness of using all the auxiliary tasks to integrate their complementary advantages and ensure good performance. It is also shown that certain auxiliary tasks (RKB and MoCo in our case) play more important roles than the others, as evidenced by the inferior results obtained without these two tasks.
IV-G Analysis on KD loss coefficient
Table V: Impact of the KD loss coefficient $\lambda$.

$\lambda$ | Dice Score (%) ↑ |
---|---|
0 | 36.61±0.76 |
0.1 | 37.88±0.47 |
0.2 | 38.33±0.28 |
0.3 | 39.37±0.44 |
0.4 | 40.45±0.13 |
0.5 | 41.50±0.26 |
0.6 | 41.95±0.26 |
0.7 | 42.07±0.10 |
0.83 (ours) | 42.05±0.20 |
0.9 | 41.95±0.26 |
1 | 41.64±0.12 |

We analyse the impact of the coefficient $\lambda$ in the KD loss by setting it to certain values within $[0, 1]$. $\lambda = 0$ corresponds to only using the segmentation loss for training, and $\lambda = 1$ corresponds to only using the distillation loss. Recall that our method is implemented by setting $\lambda = 0.83$. As shown in Table V, the best result is obtained at $\lambda = 0.7$, and our setting achieves the second best. The inferior performance obtained with a $\lambda$ of 0 or 1 indicates the importance of incorporating the supervision from both the ground-truth annotations and pseudo annotations for good performance. The results also show that the KD method in Stage 2 is not sensitive to the choice of $\lambda$, as values within $[0.6, 0.9]$ yield similar results.
V Conclusions and Future Work
We proposed a two-stage multi-task self-supervised learning approach to address the issue of data scarcity in medical image segmentation. Specifically, given the target segmentation task and the available training set, multiple different auxiliary tasks are created in a self-supervised manner. The target task is then independently co-solved with each auxiliary task in both joint training and pre-training modes, and for each auxiliary task the better of the two resulting models is selected based on validation performance. Next, the models obtained with respect to each auxiliary task are used to produce a single model via an ensemble knowledge distillation method. This two-stage learning approach allows various auxiliary tasks to be leveraged more effectively to help solve the target segmentation task. Experiments validated the superior performance of our approach compared to several existing methods that utilize multiple auxiliary tasks.
In the future, we plan to investigate ways to reduce the computational cost incurred by training a model with every auxiliary task in both modes. We will also study how to enable more effective and efficient knowledge transfer from auxiliary tasks, for which evolutionary multi-task optimization-based approaches [33, 34] are a potential solution. Further, we will explore the relationships between the models generated with different auxiliary tasks, e.g., via graph matching [35] and learning vector quantization [36] techniques, to increase the diversity of auxiliary tasks and thereby make them more helpful to the target task.
References
- [1] F. Isensee, P. F. Jaeger, S. A. Kohl, J. Petersen, and K. H. Maier-Hein, “nnu-net: a self-configuring method for deep learning-based biomedical image segmentation,” Nature methods, vol. 18, no. 2, pp. 203–211, 2021.
- [2] C. Tan, L. Zhao, Z. Yan, K. Li, D. Metaxas, and Y. Zhan, “Deep multi-task and task-specific feature learning network for robust shape preserved organ segmentation,” in 2018 IEEE 15th International Symposium on Biomedical Imaging (ISBI 2018), pp. 1221–1224, IEEE, 2018.
- [3] H. Chen, X. Qi, L. Yu, and P.-A. Heng, “Dcan: deep contour-aware networks for accurate gland segmentation,” in Proceedings of the IEEE conference on Computer Vision and Pattern Recognition, pp. 2487–2496, 2016.
- [4] Y. Du, W. M. Czarnecki, S. M. Jayakumar, M. Farajtabar, R. Pascanu, and B. Lakshminarayanan, “Adapting auxiliary losses using gradient similarity,” arXiv preprint arXiv:1812.02224, 2018.
- [5] X. Lin, H. Baweja, G. Kantor, and D. Held, “Adaptive auxiliary task weighting for reinforcement learning,” Advances in neural information processing systems, vol. 32, 2019.
- [6] B. Shi, J. Hoffman, K. Saenko, T. Darrell, and H. Xu, “Auxiliary task reweighting for minimum-data learning,” Advances in Neural Information Processing Systems, vol. 33, pp. 7148–7160, 2020.
- [7] G. Hinton, O. Vinyals, and J. Dean, “Distilling the knowledge in a neural network,” arXiv preprint arXiv:1503.02531, 2015.
- [8] X. Zhuang, Y. Li, Y. Hu, K. Ma, Y. Yang, and Y. Zheng, “Self-supervised feature learning for 3d medical images by playing a rubik’s cube,” in Medical Image Computing and Computer Assisted Intervention–MICCAI 2019: 22nd International Conference, Shenzhen, China, October 13–17, 2019, Proceedings, Part IV 22, pp. 420–428, Springer, 2019.
- [9] X. Chen, S. Xie, and K. He, “An empirical study of training self-supervised vision transformers,” 2021.
- [10] A. Bardes, J. Ponce, and Y. LeCun, “Vicreg: Variance-invariance-covariance regularization for self-supervised learning,” arXiv preprint arXiv:2105.04906, 2021.
- [11] O. Ronneberger, P. Fischer, and T. Brox, “U-net: Convolutional networks for biomedical image segmentation,” in Medical Image Computing and Computer-Assisted Intervention–MICCAI 2015: 18th International Conference, Munich, Germany, October 5-9, 2015, Proceedings, Part III 18, pp. 234–241, Springer, 2015.
- [12] Z. Anna, W. Carol, S. George, E. Julia, F. Mikhail, H. Mohannad, Paras Lakhani, C. Phil, and B. Shunxing, “Siim-acr pneumothorax segmentation,” 2019.
- [13] O. Oktay, J. Schlemper, L. L. Folgoc, M. Lee, M. Heinrich, K. Misawa, K. Mori, S. McDonagh, N. Y. Hammerla, B. Kainz, et al., “Attention u-net: Learning where to look for the pancreas,” arXiv preprint arXiv:1804.03999, 2018.
- [14] Z. Zhou, M. M. R. Siddiquee, N. Tajbakhsh, and J. Liang, “Unet++: Redesigning skip connections to exploit multiscale features in image segmentation,” IEEE transactions on medical imaging, vol. 39, no. 6, pp. 1856–1867, 2019.
- [15] H. Cao, Y. Wang, J. Chen, D. Jiang, X. Zhang, Q. Tian, and M. Wang, “Swin-unet: Unet-like pure transformer for medical image segmentation,” in European conference on computer vision, pp. 205–218, Springer, 2022.
- [16] Z. Zhou, V. Sodha, J. Pang, M. B. Gotway, and J. Liang, “Models genesis,” Medical image analysis, vol. 67, p. 101840, 2021.
- [17] Z. Zhou, V. Sodha, M. M. Rahman Siddiquee, R. Feng, N. Tajbakhsh, M. B. Gotway, and J. Liang, “Models genesis: Generic autodidactic models for 3d medical image analysis,” in Medical Image Computing and Computer Assisted Intervention–MICCAI 2019: 22nd International Conference, Shenzhen, China, October 13–17, 2019, Proceedings, Part IV 22, pp. 384–393, Springer, 2019.
- [18] S. Vandenhende, S. Georgoulis, W. Van Gansbeke, M. Proesmans, D. Dai, and L. Van Gool, “Multi-task learning for dense prediction tasks: A survey,” IEEE transactions on pattern analysis and machine intelligence, vol. 44, no. 7, pp. 3614–3633, 2021.
- [19] P.-N. Kung, S.-S. Yin, Y.-C. Chen, T.-H. Yang, and Y.-N. Chen, “Efficient multi-task auxiliary learning: selecting auxiliary data by feature similarity,” in Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing, pp. 416–428, 2021.
- [20] A. Shamsian, A. Navon, N. Glazer, K. Kawaguchi, G. Chechik, and E. Fetaya, “Auxiliary learning as an asymmetric bargaining game,” arXiv preprint arXiv:2301.13501, 2023.
- [21] B. B. Sau and V. N. Balasubramanian, “Deep model compression: Distilling knowledge from noisy teachers,” arXiv preprint arXiv:1610.09650, 2016.
- [22] H. Bagherinezhad, M. Horton, M. Rastegari, and A. Farhadi, “Label refinery: Improving imagenet classification through label progression,” arXiv preprint arXiv:1805.02641, 2018.
- [23] X. Cheng, Z. Rao, Y. Chen, and Q. Zhang, “Explaining knowledge distillation by quantifying the knowledge,” in Proceedings of the IEEE/CVF conference on computer vision and pattern recognition, pp. 12925–12935, 2020.
- [24] Z. Shen, Z. He, and X. Xue, “Meal: Multi-model ensemble via adversarial learning,” in Proceedings of the AAAI Conference on Artificial Intelligence, vol. 33, pp. 4886–4893, 2019.
- [25] Y. Zhang, T. Xiang, T. M. Hospedales, and H. Lu, “Deep mutual learning,” in Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 4320–4328, 2018.
- [26] L. Zhang, J. Song, A. Gao, J. Chen, C. Bao, and K. Ma, “Be your own teacher: Improve the performance of convolutional neural networks via self distillation,” in Proceedings of the IEEE/CVF international conference on computer vision, pp. 3713–3722, 2019.
- [27] X. Wang, Y. Peng, L. Lu, Z. Lu, M. Bagheri, and R. M. Summers, “Chestx-ray8: Hospital-scale chest x-ray database and benchmarks on weakly-supervised classification and localization of common thorax diseases,” in Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 2097–2106, 2017.
- [28] K. He, X. Zhang, S. Ren, and J. Sun, “Deep residual learning for image recognition,” in Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 770–778, 2016.
- [29] Y. Wu and K. He, “Group normalization,” in Proceedings of the European conference on computer vision (ECCV), pp. 3–19, 2018.
- [30] L. Liu, H. Jiang, P. He, W. Chen, X. Liu, J. Gao, and J. Han, “On the variance of the adaptive learning rate and beyond,” arXiv preprint arXiv:1908.03265, 2019.
- [31] T. Yu, S. Kumar, A. Gupta, S. Levine, K. Hausman, and C. Finn, “Gradient surgery for multi-task learning,” Advances in Neural Information Processing Systems, vol. 33, pp. 5824–5836, 2020.
- [32] D. Sivasubramanian, A. Maheshwari, A. Prathosh, P. Shenoy, and G. Ramakrishnan, “Adaptive mixing of auxiliary losses in supervised learning,” in Proceedings of the AAAI Conference on Artificial Intelligence, vol. 37, pp. 9855–9863, 2023.
- [33] H. Song, A. K. Qin, P.-W. Tsai, and J. J. Liang, “Multitasking multi-swarm optimization,” in 2019 IEEE Congress on Evolutionary Computation (CEC), pp. 1937–1944, IEEE, 2019.
- [34] Y. Wu, H. Ding, M. Gong, A. K. Qin, W. Ma, Q. Miao, and K. C. Tan, “Evolutionary multiform optimization with two-stage bidirectional knowledge transfer strategy for point cloud registration,” IEEE Transactions on Evolutionary Computation, vol. 28, no. 1, pp. 62–76, 2024.
- [35] M. Gong, Y. Wu, Q. Cai, W. Ma, A. K. Qin, Z. Wang, and L. Jiao, “Discrete particle swarm optimization for high-order graph matching,” Information Sciences, vol. 328, pp. 158–171, 2016.
- [36] A. K. Qin and P. N. Suganthan, “Initialization insensitive LVQ algorithm based on cost-function adaptation,” Pattern Recognition, vol. 38, no. 5, pp. 773–776, 2005.