
CNN-Transformer Rectified Collaborative Learning for Medical Image Segmentation

Lanhu Wu, Miao Zhang, Yongri Piao, Zhenyan Yao,
Weibing Sun, Feng Tian, and Huchuan Lu
This work was supported by the National Natural Science Foundation of China under Grant 62172070, Grant 62372080, and Grant 62376050. (Corresponding author: Yongri Piao.) Lanhu Wu, Yongri Piao, Zhenyan Yao, and Huchuan Lu are with the School of Information and Communication Engineering, Dalian University of Technology, Dalian 116024, China (e-mail: [email protected]; [email protected]; [email protected]; [email protected]). Miao Zhang is with the Key Laboratory for Ubiquitous Network and Service Software of Liaoning Province, the DUT-RU International School of Information Science and Engineering, Dalian University of Technology, Dalian 116024, China (e-mail: [email protected]). Weibing Sun and Feng Tian are with the Affiliated Zhongshan Hospital of Dalian University, Dalian 116024, China (e-mail: [email protected]; [email protected]). ORCID iDs: 0009-0006-6420-5971, 0000-0002-7972-7047, 0000-0002-0860-252X, 0009-0008-2833-8325, 0000-0002-6668-9758.
Abstract

Automatic and precise medical image segmentation (MIS) is of vital importance for clinical diagnosis and analysis. Current MIS methods mainly rely on the convolutional neural network (CNN) or the self-attention mechanism (Transformer) for feature modeling. However, CNN-based methods suffer from inaccurate localization owing to their limited global dependency, while Transformer-based methods tend to produce coarse boundaries owing to their lack of local emphasis. Although some CNN-Transformer hybrid methods are designed to synthesize complementary local and global information for better performance, combining CNN and Transformer introduces numerous parameters and increases the computation cost. To this end, this paper proposes a CNN-Transformer rectified collaborative learning (CTRCL) framework that learns stronger CNN-based and Transformer-based models for MIS tasks via bi-directional knowledge transfer between them. Specifically, we propose a rectified logit-wise collaborative learning (RLCL) strategy, which introduces the ground truth to adaptively select and rectify the wrong regions in student soft labels for accurate knowledge transfer in the logit space. We also propose a class-aware feature-wise collaborative learning (CFCL) strategy to achieve effective knowledge transfer between CNN-based and Transformer-based models in the feature space by granting their intermediate features a similar capability of category perception. Extensive experiments on three popular MIS benchmarks demonstrate that our CTRCL outperforms most state-of-the-art collaborative learning methods under different evaluation metrics.

Index Terms:
Medical image segmentation, CNN, Transformer, collaborative learning.

I Introduction

Medical image segmentation (MIS) is a critical step in the pre-treatment diagnosis, treatment planning, and post-treatment assessment of various diseases. Traditional MIS methods mainly rely on template matching [1], edge detection [2], machine learning [3], etc. However, the difficulty of selecting discriminative features and appropriate hyper-parameters limits the development of these methods.

Figure 1: Visualizations of class activation maps generated by Grad-CAM [4] and segmentation results of ResNet-50 [5] (left) and MiT-B2 [6] (right). Current CNN-based models suffer from inaccurate localization (e.g., the missing stomach in (a) and (c)), while Transformer-based models produce coarse boundaries (e.g., the incomplete liver in (e) and (g)). Our CTRCL improves the CNN-based model with more accurate localization ((b), (d)) and the Transformer-based model with more elaborate boundaries ((f), (h)) via collaborative learning between the CNN-based and Transformer-based models.

In recent years, deep learning has been widely applied to MIS and has made breakthrough progress. Early deep learning based methods depend on the convolutional neural network (CNN) for hierarchical feature modeling. UNet [7] produces a high-resolution segmentation map by aggregating multi-stage features with skip connections and integrating them in a top-down manner. Owing to this effective encoder-decoder architecture, several variants of UNet [8], [9], [10], [11], [12] have demonstrated impressive performance in MIS tasks. Despite their effectiveness, CNN-based methods are limited in extracting global context information because of the confined receptive field of the convolution operation, leading to inaccurate localization of various organs and lesions (as shown in Fig. 1 (a), (c)). To overcome this problem, a series of Transformer-based methods [13], [14], [15], [16] have been proposed for MIS, which utilize the self-attention mechanism to model long-range dependencies with dynamic weights and a global receptive field. Nevertheless, since each element is computed and compared with all other elements equally in the self-attention process, local information is diluted within the global perspective, resulting in coarse boundaries for medical objects (as shown in Fig. 1 (e), (g)). Besides, Transformers are data-hungry models, whereas abundant, high-quality annotated medical data are difficult to acquire, which limits the potential of these methods. Thus, some methods [17], [18], [19] design CNN-Transformer hybrid networks to synthesize complementary local and global information for better MIS performance. Although these methods achieve significant improvements, combining CNN and Transformer introduces numerous parameters and increases the computational cost.
This raises a question: ‘why not enable CNN-based and Transformer-based models to learn complementary global and local information from each other, so that each model is enhanced?’

Our inspiration derives from the collaborative learning mechanism, which exchanges knowledge among multiple students synchronously during end-to-end training to improve each student. However, accurate and effective CNN-Transformer collaborative learning is challenging for two reasons. On the one hand, since the CNN-based and Transformer-based students tend to suffer from inaccurate localization and coarse boundaries, respectively, in various MIS tasks, their soft labels (predicted probabilities) are imprecise in the corresponding areas. Direct collaborative learning between the two students in these areas may drive them to converge in wrong directions, thereby degrading their performance. This impact is exacerbated in the initial phase of collaborative learning because neither student has yet been trained on the target dataset, so their soft labels are severely inaccurate. On the other hand, CNN and Transformer are heterogeneous networks with substantial intrinsic gaps between their intermediate features, which significantly increases the difficulty of knowledge transfer in the feature space. Forcibly aligning mismatched features would disrupt their respective characteristics and lead to incorrect predictions.

To address these problems, this article proposes a CNN-Transformer rectified collaborative learning (CTRCL) framework to learn stronger CNN-based and Transformer-based models for MIS. Specifically, the CTRCL framework incorporates a rectified logit-wise collaborative learning (RLCL) strategy and a class-aware feature-wise collaborative learning (CFCL) strategy to achieve bi-directional knowledge transfer between CNN-based and Transformer-based students in the logit and feature spaces, respectively. For accurate logit-wise knowledge transfer, the RLCL strategy introduces an adaptive rectification module (ARM), which employs the ground truth to select the wrong regions in student soft labels and to rectify them with dynamic weights. As such, RLCL can adaptively generate high-quality student soft labels for logit-wise mutual learning. For effective feature-wise knowledge transfer, the CFCL strategy transforms the intermediate features of each student into class-aware representations via a category perception module (CPM) and aligns them under loss supervision. By granting the student features a similar capability of category perception, CFCL manages to transfer feature-wise knowledge between heterogeneous networks. To verify the effectiveness of our CTRCL framework, we conduct experiments on three popular MIS benchmarks: Synapse Multi-organ [20], ACDC [21], and Kvasir-SEG [22]. The contributions of this work are summarized as follows:

  • To the best of our knowledge, our CTRCL framework is the first attempt to adopt the collaborative learning mechanism to learn stronger CNN-based and Transformer-based models for MIS tasks via bi-directional knowledge transfer between them in both the logit and feature spaces.

  • We propose an RLCL strategy for accurate logit-wise knowledge transfer, which introduces the ground truth to adaptively select and rectify the wrong regions in student soft labels with dynamic weights.

  • We propose a CFCL strategy for effective feature-wise knowledge transfer between heterogeneous networks by granting their intermediate features a similar capability of category perception.

  • Our CTRCL framework consistently achieves new state-of-the-art performance on three MIS benchmarks compared with other collaborative learning methods. In particular, our CTRCL reduces the MAE metric by 42.93% and 31.23% for ResNet-50 and MiT-B2, respectively, on the Kvasir-SEG [22] dataset.

Figure 2: The whole pipeline of our CTRCL framework, containing three parts: CNN-based student, Transformer-based student, and collaborative learning strategies (CFCL and RLCL). (a) Class-aware feature-wise collaborative learning (CFCL) focuses on effective feature-wise knowledge transfer by encouraging student features to possess similar class-aware representations. (b) Rectified logit-wise collaborative learning (RLCL) aims at accurate logit-wise knowledge transfer with student soft labels rectified by the ground truth. Please refer to Section III for details.

II Related Work

II-A Medical Image Segmentation

Medical image segmentation (MIS) is a dense prediction task that classifies the pixels of organs or lesions in medical images (e.g., CT, MRI, ultrasound, endoscopy). Traditional MIS methods typically rely on hand-crafted features such as edges, shapes, and textures, which carry a high risk of mis-segmentation owing to their inherent lack of high-level semantics.

With the development of deep learning, CNN-based MIS methods were proposed first and achieved impressive performance. Ronneberger et al. [7] proposed a symmetric U-shaped network that aggregates multi-stage features with skip connections and progressively integrates them into high-resolution segmentation maps. Later, Zhou et al. [8] designed a series of nested, dense skip connections to reduce the semantic gaps between feature maps and utilized deep supervision for multi-scale lesions. Jha et al. [23] introduced the residual connection [5], SE [24], and ASPP [25] to further improve performance. However, limited by the receptive field of CNNs, these methods are deficient in extracting global context information, leading to inaccurate localization. To this end, some Transformer-based MIS methods have been proposed that apply the self-attention mechanism for long-range modeling. Karimi et al. [26] first presented a convolution-free segmentation model by forwarding flattened image representations to a Transformer, whose outputs are then reorganized into 3D tensors to generate segmentation masks. Cao et al. [13] utilized the Swin Transformer to construct both the encoder and the decoder within the U-Net paradigm. Furthermore, Lin et al. [14] adopted a dual-scale Transformer encoder to attain coarse- and fine-grained features and employed the self-attention mechanism to establish global dependencies between them. Nonetheless, owing to local dilution in the self-attention process, pure Transformer-based methods usually produce coarse boundaries in their segmentation results. Besides, their heavy dependence on data limits their potential given the scarcity of medical data. Thus, some works couple CNN and Transformer to design hybrid networks for MIS. Chen et al. [17] appended a Transformer-based encoder after the CNN-based encoder for global feature extraction. Zhang et al. [18] placed a shallow CNN-based encoder and a Transformer-based segmentation network in parallel and fused the features from the two branches to jointly make predictions. Zhou et al. [27] constructed a CNN-Transformer hybrid backbone, nnFormer, with convolutional and Transformer blocks interleaved to take advantage of each other. Although these methods achieve better performance, their large model size and high computational cost hinder their practical application. Different from the aforementioned methods, our CTRCL enables a CNN-based model and a Transformer-based model to learn from each other to improve the performance of each model.

II-B Collaborative Learning

Collaborative learning derives from knowledge distillation (KD) [28], which transfers knowledge from a cumbersome pre-trained teacher model to a compact student model for model compression. In contrast, collaborative learning exchanges knowledge among a group of student models simultaneously in an end-to-end training procedure to enhance the performance of each model.

Current collaborative learning methods can be primarily categorized into two types: mutual learning [29], [30], [31], [32] and ensemble learning [33], [34], [35], [36]. The mutual learning strategy was first proposed by Zhang et al. [29] to let peer students learn from each other through the Kullback-Leibler (KL) divergence loss between each pair of student logits. Following this work, Chung et al. [37] extended the mutual learning mechanism to feature maps using an adversarial training paradigm. More recently, Zhu et al. [38] proposed a bidirectional selective distillation strategy to transfer reliable knowledge between student models in both the logit and feature spaces. Ensemble learning constructs a virtual teacher by assembling the outputs of all the students, which is then distilled back to foster each student. Lan et al. [33] constructed a multi-branch network and assembled the on-the-fly logit information from the branches to enhance the performance of the target network. Later, Guo et al. [39] extended ensemble learning to the outputs of peer students with different architectures to generate high-quality labels for supervision. Furthermore, Kim et al. [40] introduced a feature fusion module that integrates the feature representations of sub-networks to construct a teacher classifier, whose knowledge is delivered back to each sub-network. Unfortunately, current collaborative learning methods overlook the accuracy of soft labels during logit-wise knowledge transfer, even though neither a single student nor an ensemble teacher can completely reflect the distribution of the ground truth. Additionally, feature-wise collaborative learning between heterogeneous networks has not been fully explored in previous studies.
In contrast, we strive to ensure the accuracy of student soft labels with the help of the ground truth and to achieve effective feature-wise knowledge transfer between heterogeneous networks (CNN and Transformer) via the alignment of class-aware representations.

III Proposed Method

III-A Overview

An overview of our CTRCL framework is depicted in Fig. 2. It consists of three components: a CNN-based student $f(\theta^{C})$, a Transformer-based student $f(\theta^{T})$, and the proposed collaborative learning strategies (RLCL and CFCL). Given an input image set $\mathcal{X}$ and its label set $\mathcal{Y}$, our objective is to enable $f(X;\theta^{C})$ and $f(X;\theta^{T})$ to learn collaboratively, so that each assigns the pixel-wise label $l\in\{1,\dots,L\}$ in an image $X\in\mathcal{X}$ ($X\in\mathbb{R}^{H\times W\times D}$) more accurately than it would when trained alone, where $H$, $W$, and $D$ are the height, width, and depth of $X$, and $L$ is the number of categories. To achieve this goal, given a specific input $X$, we first obtain the segmentation predictions ($P^{C}$ and $P^{T}$) and feature representations ($F^{C}$ and $F^{T}$) from the two students $f(X;\theta^{C})$ and $f(X;\theta^{T})$, formulated as:

$$
\begin{aligned}
\left(P^{C},F^{C}\right) &= f\left(X;\theta^{C}\right),\\
\left(P^{T},F^{T}\right) &= f\left(X;\theta^{T}\right).
\end{aligned} \tag{1}
$$

Then the pixel-wise segmentation loss ($\mathcal{L}_{seg}$) is computed with the cross entropy $\mathrm{CE}(\cdot)$ against the ground truth $Y$:

$$
\begin{aligned}
\mathcal{L}_{seg}^{C} &= \frac{1}{H\times W}\sum_{h=1}^{H}\sum_{w=1}^{W}\mathrm{CE}\left(P^{C}_{(h,w)},Y_{(h,w)}\right),\\
\mathcal{L}_{seg}^{T} &= \frac{1}{H\times W}\sum_{h=1}^{H}\sum_{w=1}^{W}\mathrm{CE}\left(P^{T}_{(h,w)},Y_{(h,w)}\right),
\end{aligned} \tag{2}
$$

where $(h,w)$ indexes the spatial location of each pixel. Meanwhile, we propose a rectified logit-wise collaborative learning (RLCL) strategy and a class-aware feature-wise collaborative learning (CFCL) strategy to achieve bi-directional knowledge transfer between CNN-based and Transformer-based students in the logit and feature spaces, respectively. The details of the RLCL and CFCL strategies are described as follows.
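To make the pixel averaging in Eq. (2) concrete, the sketch below computes $\mathcal{L}_{seg}$ for one student with NumPy. It is a minimal illustration, not the authors' implementation. It assumes channel-last probability maps of shape (H, W, L) produced by a softmax and 0-indexed integer labels (the text uses $1,\dots,L$). The helper name `seg_loss` is hypothetical.

```python
import numpy as np

def seg_loss(P, Y, eps=1e-12):
    """Pixel-averaged cross entropy in the spirit of Eq. (2).

    P: (H, W, L) softmax probabilities of one student.
    Y: (H, W) integer ground-truth labels in [0, L).
    """
    H, W, _ = P.shape
    # Probability that the student assigns to the ground-truth class at each pixel.
    p_true = np.take_along_axis(P, Y[..., None], axis=-1)[..., 0]
    # CE with a one-hot target is -log p_true; average over all H * W pixels.
    return -np.log(p_true + eps).sum() / (H * W)
```

The same function serves both students; only the prediction passed in ($P^{C}$ or $P^{T}$) changes.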

III-B Rectified Logit-wise Collaborative Learning

Neither a CNN-based nor a Transformer-based student can perfectly reproduce the ground truth, which inevitably leaves mis-categorized pixels in their soft labels and harms the logit-wise collaborative learning between them. Consequently, we propose a rectified logit-wise collaborative learning (RLCL) strategy, which adopts the ground truth to adaptively rectify the wrong regions in student soft labels for accurate logit-wise knowledge transfer between CNN-based and Transformer-based students.

As shown in Fig. 2 (b), the RLCL strategy introduces an adaptive rectification module (ARM) to rectify the segmentation predictions of both students ($P^{C}$ and $P^{T}$) and obtain rectified soft labels ($P^{C}_{r}$ and $P^{T}_{r}$) under the guidance of the ground truth $Y$, which can be formulated as:

$$
\begin{aligned}
P_{r}^{C} &= \mathrm{ARM}\left(P^{C},Y\right),\\
P_{r}^{T} &= \mathrm{ARM}\left(P^{T},Y\right).
\end{aligned} \tag{3}
$$

The ARM is illustrated in Fig. 3. Given the segmentation prediction $P$ of a student, we first obtain its pseudo label via the Argmax function and compare it with the ground truth $Y$ by the XOR operation to produce an error mask $M$ as follows:

$$
M=\mathrm{XOR}\left(\mathrm{Argmax}\left(P\right),Y\right). \tag{4}
$$

Then we recover the one-hot label from the ground truth $Y$ and use the error mask $M$ to select the mis-segmented regions $Y^{m}$ and $P^{m}$ from $Y$ and $P$, respectively. After that, $P^{m}$ is rectified to $P_{r}^{m}$ by a weighted combination of $P^{m}$ and $Y^{m}$. The above process can be written as:

$$
\begin{aligned}
Y^{m} &= \mathrm{OHE}\left(Y\right)\cdot M,\\
P^{m} &= P\cdot M,\\
P^{m}_{r} &= \lambda\cdot P^{m}+\left(1-\lambda\right)\cdot Y^{m},
\end{aligned} \tag{5}
$$

where $\mathrm{OHE}(\cdot)$ denotes one-hot encoding and $\lambda$ is a matrix of dynamic rectification weights described below. Finally, we combine the rectified region $P_{r}^{m}$ and the correct region $P_{c}$ to generate the rectified soft label $P_{r}$, formulated as:

$$
\begin{aligned}
P_{c} &= P\cdot\left(1-M\right),\\
P_{r} &= P_{r}^{m}+P_{c}.
\end{aligned} \tag{6}
$$
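A minimal NumPy sketch of Eqs. (4)-(6) for a given weight map $\lambda$ follows. It assumes channel-last soft labels of shape (H, W, L) and 0-indexed integer labels, and it reads the XOR of Eq. (4) as "the predicted label differs from the ground truth", which is our interpretation for the multi-class case. The function name `rectify_with_weights` is hypothetical.

```python
import numpy as np

def rectify_with_weights(P, Y, lam):
    """Eqs. (4)-(6): rectify the mis-segmented pixels of a soft label.

    P:   (H, W, L) softmax prediction of one student.
    Y:   (H, W) integer ground truth (0-indexed).
    lam: scalar or (H, W) array of rectification weights in [0, 1].
    """
    L = P.shape[-1]
    # Eq. (4): multi-class "XOR" read as prediction != label; shape (H, W, 1).
    M = (P.argmax(axis=-1) != Y).astype(P.dtype)[..., None]
    Y_onehot = np.eye(L, dtype=P.dtype)[Y]  # OHE(Y), shape (H, W, L)
    lam = np.asarray(lam, dtype=P.dtype)
    if lam.ndim == 2:
        lam = lam[..., None]  # broadcast per-pixel weights over classes
    # Eq. (5): select wrong regions and blend prediction with the one-hot label.
    Y_m = Y_onehot * M
    P_m = P * M
    P_r_m = lam * P_m + (1.0 - lam) * Y_m
    # Eq. (6): keep correct pixels unchanged and merge the two parts.
    P_c = P * (1.0 - M)
    return P_r_m + P_c
```

Because each rectified pixel is a convex combination of a probability vector and a one-hot vector (for $0\le\lambda\le 1$), the rectified soft label still sums to one at every pixel.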

We now describe the computation of $\lambda$ and the rectification process in detail. Overall, $\lambda$ consists of an alignment factor $\lambda^{a}$, a similarity-based decay factor $\lambda^{s}$, and a certainty-based decay factor $\lambda^{c}$, formulated as:

$$
\lambda=\lambda^{a}\cdot\lambda^{s}\cdot\lambda^{c}. \tag{7}
$$

III-B1 Alignment Factor ($\lambda^{a}$)

We first use the one-hot label in $Y^{m}$ to align the probabilities of the mis-categorized class and the truth class in the prediction $P^{m}$. Denoting these two probabilities in $P^{m}$ as $p^{mis}$ and $p^{tru}$, the alignment process can be formulated as:

$$
\lambda^{a}\cdot p^{mis}+\left(1-\lambda^{a}\right)\cdot v^{mis}=\lambda^{a}\cdot p^{tru}+\left(1-\lambda^{a}\right)\cdot v^{tru}, \tag{8}
$$

where $v^{mis}$ and $v^{tru}$ are the values of the mis-categorized class and the truth class in $Y^{m}$, namely 0 and 1, respectively. Thus, the alignment factor $\lambda^{a}$ is obtained as:

$$
\lambda^{a}=\left(1+p^{mis}-p^{tru}\right)^{-1}. \tag{9}
$$

If we set $\lambda:=\lambda^{a}$, the mis-categorized class and the truth class share the same probability in $P^{m}_{r}$. However, ARM aims to correct the mis-categorized prediction, which requires a larger probability on the truth class than on the mis-categorized class. In other words, the equality in Eq. (8) should be turned into ‘<’. To achieve this, we additionally introduce two decay factors ($\lambda^{s}$ and $\lambda^{c}$) on top of $\lambda^{a}$.
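The following short derivation is our own reading of Eqs. (7)-(11), not a statement from the method description; it shows why the two decay factors yield the required inequality. Let $q^{mis}$ and $q^{tru}$ denote the probabilities of the mis-categorized and truth classes in $P^{m}_{r}$ at a mis-segmented pixel, and assume $\lambda$ is evaluated per pixel. Using $v^{mis}=0$, $v^{tru}=1$, and Eq. (9):

$$
\begin{aligned}
q^{tru}-q^{mis} &= \left[\lambda\, p^{tru}+\left(1-\lambda\right)\right]-\lambda\, p^{mis}\\
&= 1-\lambda\left(1+p^{mis}-p^{tru}\right)\\
&= 1-\frac{\lambda}{\lambda^{a}} = 1-\lambda^{s}\lambda^{c}.
\end{aligned}
$$

Because $p^{mis}$ is the arg-max probability, $p^{mis}\ge p^{tru}$ and hence $0<\lambda^{a}\le 1$. Both decay factors are exponentials of negated non-negative quantities, so $\lambda^{s},\lambda^{c}\in(0,1]$. Moreover, at a mis-segmented pixel $p^{tru}<1$, so the cross entropy against the one-hot label is strictly positive and $\lambda^{s}<1$. Hence $q^{tru}-q^{mis}=1-\lambda^{s}\lambda^{c}>0$, that is, the truth class wins after rectification.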

Figure 3: Illustration of the adaptive rectification module (ARM).

III-B2 Similarity-based Decay Factor ($\lambda^{s}$)

$\lambda^{s}$ is a decay factor that decreases the proportion of the prediction used for rectification and is positively correlated with the similarity between the prediction and the one-hot label. We adopt the cross entropy $\mathrm{CE}(\cdot)$ to measure this similarity and formulate $\lambda^{s}$ as follows:

$$
\lambda^{s}=\exp\left(-\mathrm{CE}\left(P^{m},Y^{m}\right)\right). \tag{10}
$$

III-B3 Certainty-based Decay Factor ($\lambda^{c}$)

Similarly, $\lambda^{c}$ is also a decay factor for the rectification process, positively correlated with the certainty of the prediction. We adopt the entropy $\mathrm{S}(\cdot)$ to evaluate the certainty of predictions and compute $\lambda^{c}$ by the following formula:

$$
\lambda^{c}=\exp\left(-\mathrm{S}\left(P^{m}\right)\right). \tag{11}
$$

In this way, ARM adaptively selects and rectifies the wrong regions in student soft labels under the guidance of the ground truth with dynamic weights, generating accurate soft labels for logit-wise collaborative learning.
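The weights of Eqs. (7) and (9)-(11) can be sketched per pixel as below. This is an illustrative NumPy version under our own assumptions: $\lambda$ is computed independently at every pixel, the mis-categorized class is the arg-max class, the cross entropy of Eq. (10) is evaluated against the one-hot label at that pixel (so it reduces to $-\log p^{tru}$), and the entropy in Eq. (11) uses the natural logarithm. The name `rectification_weights` is hypothetical.

```python
import numpy as np

def rectification_weights(P, Y, eps=1e-12):
    """Per-pixel lambda = lambda_a * lambda_s * lambda_c (Eqs. (7), (9)-(11)).

    P: (H, W, L) softmax prediction of one student.
    Y: (H, W) integer ground truth (0-indexed).
    The weights only matter where the error mask M is 1; elsewhere
    Eq. (5) multiplies them by zero.
    """
    pred = P.argmax(axis=-1)
    # Probabilities of the predicted (mis-categorized) class and of the true class.
    p_mis = np.take_along_axis(P, pred[..., None], axis=-1)[..., 0]
    p_tru = np.take_along_axis(P, Y[..., None], axis=-1)[..., 0]
    # Eq. (9): alignment factor; p_mis >= p_tru, so lambda_a lies in (0, 1].
    lam_a = 1.0 / (1.0 + p_mis - p_tru)
    # Eq. (10): per-pixel CE against the one-hot label equals -log p_tru.
    ce = -np.log(p_tru + eps)
    lam_s = np.exp(-ce)
    # Eq. (11): Shannon entropy (natural log) of the prediction at each pixel.
    entropy = -(P * np.log(P + eps)).sum(axis=-1)
    lam_c = np.exp(-entropy)
    # Eq. (7): combined dynamic weight map of shape (H, W).
    return lam_a * lam_s * lam_c
```

Under these assumptions $\lambda^{s}$ simply equals $p^{tru}$, so a prediction that already leans toward the truth class keeps more of its own distribution, while an uncertain (high-entropy) prediction is pulled harder toward the one-hot label through $\lambda^{c}$.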

With the rectified soft labels ($P^{C}_{r}$ and $P^{T}_{r}$) obtained, we follow previous approaches and adopt the KL divergence $\mathrm{KL}(\cdot\Vert\cdot)$ for logit-wise collaborative learning between the CNN-based and Transformer-based students:

$$
\begin{aligned}
\mathcal{L}_{rlcl}^{C} &= \frac{1}{H\times W}\sum_{h=1}^{H}\sum_{w=1}^{W}\mathrm{KL}\left(P^{C}_{(h,w)}\,\Vert\,{P^{T}_{r}}_{(h,w)}\right),\\
\mathcal{L}_{rlcl}^{T} &= \frac{1}{H\times W}\sum_{h=1}^{H}\sum_{w=1}^{W}\mathrm{KL}\left(P^{T}_{(h,w)}\,\Vert\,{P^{C}_{r}}_{(h,w)}\right).
\end{aligned} \tag{12}
$$
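Eq. (12) averages a per-pixel KL divergence whose first argument is the student's own prediction and whose second argument is the peer's rectified soft label. A minimal NumPy sketch follows, assuming (H, W, L) probability maps; the name `pixel_kl` and the small `eps` guard against log(0) are our additions. In a deep-learning framework one would typically stop gradients through the peer's rectified label, but that detail is an assumption rather than something stated here.

```python
import numpy as np

def pixel_kl(P, Q, eps=1e-12):
    """Mean over pixels of KL(P || Q), following the argument order of Eq. (12).

    P: (H, W, L) prediction of the student being trained (e.g. P^C).
    Q: (H, W, L) peer's rectified soft label (e.g. P_r^T).
    """
    H, W, _ = P.shape
    # Per-pixel KL divergence: sum_l P_l * (log P_l - log Q_l), shape (H, W).
    kl = (P * (np.log(P + eps) - np.log(Q + eps))).sum(axis=-1)
    return kl.sum() / (H * W)
```

For example, `pixel_kl(P_C, P_r_T)` plays the role of $\mathcal{L}_{rlcl}^{C}$ and `pixel_kl(P_T, P_r_C)` that of $\mathcal{L}_{rlcl}^{T}$.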

III-C Class-aware Feature-wise Collaborative Learning

The distinct modeling approaches of CNN and Transformer result in intrinsic gaps between their features. To this end, we propose a class-aware feature-wise collaborative learning (CFCL) strategy to achieve effective feature-wise knowledge transfer between the heterogeneous CNN-based and Transformer-based students by granting their intermediate features a similar capability of category perception.

Figure 4: Illustration of the category perception module (CPM).

As shown in Fig. 2 (a), the CFCL strategy introduces a category perception module (CPM) to extract the class-aware representations ($R^{C}$ and $R^{T}$) for the categories in the ground truth $Y$ from the intermediate features ($F^{C}$ and $F^{T}$), which can be formulated as:

$$
\begin{aligned}
R^{C} &= \mathrm{CPM}\left(F^{C},Y\right),\\
R^{T} &= \mathrm{CPM}\left(F^{T},Y\right).
\end{aligned} \tag{13}
$$

The structure of the CPM is illustrated in Fig. 4. Given a batch of input images $\{X^{n}\}_{n=1}^{N}$, where $N$ denotes the batch size, and their corresponding ground truths $\{Y^{n}\}_{n=1}^{N}$ with $K$ classes (including background) in total, let $F^{n}\in\mathbb{R}^{\hat{H}\times\hat{W}\times\hat{D}}$ be the feature map of a student network, where $\hat{H}$, $\hat{W}$, and $\hat{D}$ are the height, width, and depth of $F^{n}$. The prototype of class $l\in\mathcal{K}$ (with $|\mathcal{K}|=K$) is generated via masked average pooling followed by a batch mean:

$$
p_{l}=\frac{1}{N_{l}}\sum_{n}\frac{\sum_{\hat{h},\hat{w}}F^{n}_{(\hat{h},\hat{w})}\,\mathbbm{1}\left[Y^{n}_{(\hat{h},\hat{w})}=l\right]}{\sum_{\hat{h},\hat{w}}\mathbbm{1}\left[Y^{n}_{(\hat{h},\hat{w})}=l\right]+\epsilon}, \tag{14}
$$

where the ground truth $Y^{n}$ is down-sampled with nearest interpolation to match the spatial size of the feature map $F^{n}$; $(\hat{h},\hat{w})$ indexes the spatial location of each pixel; $\mathbbm{1}[\cdot]$ is the indicator function, which returns 1 when its argument is true and 0 otherwise; $\epsilon$ is a small constant that copes with the absence of some classes in an image; and $N_{l}$ is the number of images that include the $l$-th class.
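The masked average pooling and batch mean of Eq. (14) can be sketched in NumPy as follows. The sketch assumes channel-last features of shape (N, Ĥ, Ŵ, C), labels already brought to the feature resolution, and an integer down-sampling factor for the nearest-neighbour step (keeping every `factor`-th pixel). The returned `present` mask and the zero prototype for classes absent from the whole batch (where $N_l=0$ and Eq. (14) is undefined) are our own conventions, and both function names are hypothetical.

```python
import numpy as np

def downsample_nearest(Y, factor):
    """Nearest-neighbour down-sampling of (N, H, W) labels by an integer factor."""
    return Y[:, ::factor, ::factor]

def class_prototypes(F, Y, num_classes, eps=1e-6):
    """Eq. (14): per-image masked average pooling, then mean over images with the class.

    F: (N, H', W', C) features; Y: (N, H', W') integer labels at feature resolution.
    Returns (num_classes, C) prototypes and a boolean mask of classes seen in the batch.
    """
    _, _, _, C = F.shape
    protos = np.zeros((num_classes, C), dtype=F.dtype)
    present = np.zeros(num_classes, dtype=bool)
    for l in range(num_classes):
        mask = (Y == l).astype(F.dtype)                   # indicator 1[Y = l], (N, H', W')
        area = mask.sum(axis=(1, 2))                      # pixels of class l per image
        pooled = (F * mask[..., None]).sum(axis=(1, 2))   # masked feature sum, (N, C)
        per_image = pooled / (area[:, None] + eps)        # masked average pooling
        n_l = int((area > 0).sum())                       # N_l: images containing class l
        if n_l > 0:
            protos[l] = per_image.sum(axis=0) / n_l       # images without class l add 0
            present[l] = True
    return protos, present
```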

Since a prototype is a channel-wise summary of a class, we can obtain pixel-wise category perception by relating the feature map to the prototypes. To this end, we adopt a non-parametric metric learning mechanism. Concretely, we first compute the distance between the feature vector at each spatial location and every prototype. Then we apply a softmax operation over the negated, scaled distances to obtain the class-aware representation $R$. Let $\mathcal{P}^{n}=\{p_{l}\mid l\in\mathcal{K}^{n},\ |\mathcal{K}^{n}|=K^{n}\}$ be the prototypes for the $n$-th sample, where $K^{n}$ is the number of classes in the $n$-th sample. For each $p_{i}\in\mathcal{P}^{n}$, we have:

$$
{R^{n}_{i}}_{(\hat{h},\hat{w})}=\frac{\exp\left(-\alpha\, d\left(F^{n}_{(\hat{h},\hat{w})},p_{i}\right)\right)}{\sum_{p_{j}\in\mathcal{P}^{n}}\exp\left(-\alpha\, d\left(F^{n}_{(\hat{h},\hat{w})},p_{j}\right)\right)}, \tag{15}
$$

where $d(\cdot,\cdot)$ is the cosine distance, and the multiplier $\alpha$ is a scaling factor fixed at 20, as recommended by [41].
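A per-sample NumPy sketch of Eq. (15) follows, under these assumptions: the cosine distance is one minus the cosine similarity, $\alpha=20$ as stated above, and the prototypes from Eq. (14) are indexed by the classes present in the sample. Because $\exp(-\alpha(1-\cos)) = e^{-\alpha}\exp(\alpha\cos)$, the constant cancels in the softmax, so this is equivalent to a softmax over $\alpha$-scaled cosine similarities. The function name is hypothetical.

```python
import numpy as np

def class_aware_representation(F_n, protos, classes_n, alpha=20.0, eps=1e-8):
    """Eq. (15) for one sample: softmax over -alpha * cosine distance.

    F_n:       (H', W', C) feature map of the n-th sample.
    protos:    (K, C) class prototypes from Eq. (14).
    classes_n: indices of the K^n classes present in the n-th ground truth.
    Returns R^n with shape (H', W', K^n).
    """
    P_n = protos[np.asarray(classes_n)]                              # (K^n, C)
    f = F_n / (np.linalg.norm(F_n, axis=-1, keepdims=True) + eps)    # unit feature vectors
    p = P_n / (np.linalg.norm(P_n, axis=-1, keepdims=True) + eps)    # unit prototypes
    cos_dist = 1.0 - f @ p.T                                         # (H', W', K^n)
    logits = -alpha * cos_dist
    logits -= logits.max(axis=-1, keepdims=True)                     # numerical stability
    expl = np.exp(logits)
    return expl / expl.sum(axis=-1, keepdims=True)                   # softmax over prototypes
```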

To achieve the class-aware feature-wise collaborative learning between CNN-based and Transformer-based students, we propose to minimize the loss function as follows:

$$
\begin{aligned}
\mathcal{L}_{cfcl}^{C} &= \frac{1}{\hat{H}\times\hat{W}}\sum_{\hat{h}=1}^{\hat{H}}\sum_{\hat{w}=1}^{\hat{W}}\mathrm{KL}\left(R^{C}_{(\hat{h},\hat{w})}\,\Vert\,R^{T}_{(\hat{h},\hat{w})}\right),\\
\mathcal{L}_{cfcl}^{T} &= \frac{1}{\hat{H}\times\hat{W}}\sum_{\hat{h}=1}^{\hat{H}}\sum_{\hat{w}=1}^{\hat{W}}\mathrm{KL}\left(R^{T}_{(\hat{h},\hat{w})}\,\Vert\,R^{C}_{(\hat{h},\hat{w})}\right).
\end{aligned} \tag{16}
$$

Notably, we perform bi-directional CFCL on the output features of both the encoder and the decoder, i.e., $(F^{C}_{E},F^{T}_{E})$ and $(F^{C}_{D},F^{T}_{D})$, to exchange high-level semantic information and low-level details, respectively, between the CNN student and the Transformer student.

III-D Optimization

The overall loss functions of the CTRCL framework for CNN-based and Transformer-based students are given as:

$$
\begin{aligned}
\mathcal{L}^{C} &= \mathcal{L}_{seg}^{C}+\beta\,\mathcal{L}_{rlcl}^{C}+\gamma_{1}\mathcal{L}_{cfcl}^{E,C}+\gamma_{2}\mathcal{L}_{cfcl}^{D,C},\\
\mathcal{L}^{T} &= \mathcal{L}_{seg}^{T}+\beta\,\mathcal{L}_{rlcl}^{T}+\gamma_{1}\mathcal{L}_{cfcl}^{E,T}+\gamma_{2}\mathcal{L}_{cfcl}^{D,T},
\end{aligned} \tag{17}
$$

where $\mathcal{L}_{seg}$ is the segmentation loss between the student predictions and the ground truth; $\mathcal{L}_{rlcl}$ is the rectified logit-wise collaborative learning loss between the student predictions; $\mathcal{L}_{cfcl}^{E}$ and $\mathcal{L}_{cfcl}^{D}$ denote the class-aware feature-wise collaborative learning losses between the student encoder output features $(F^{C}_{E},F^{T}_{E})$ and decoder output features $(F^{C}_{D},F^{T}_{D})$, respectively; and $\beta$, $\gamma_{1}$, and $\gamma_{2}$ are hyperparameters that trade off these loss terms. The optimization details are summarized in Algorithm 1. At test time, we deploy both the CNN-based and the Transformer-based students, because our goal is to learn stronger CNN-based and Transformer-based models for MIS via the collaborative learning mechanism.
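Eq. (17) combines four terms per student. The small sketch below only makes the weighting explicit; the dictionary keys and the function name are hypothetical, and the weight values in the usage lines are placeholders, not the hyperparameters used in the experiments. Each student's objective depends on its own parameters through its own terms and is minimized with respect to those parameters in the update step of Algorithm 1.

```python
from typing import Mapping

def student_objective(losses: Mapping[str, float], beta: float,
                      gamma1: float, gamma2: float) -> float:
    """Eq. (17) for one student; `losses` holds that student's four loss terms."""
    return (losses["seg"]                      # L_seg
            + beta * losses["rlcl"]            # beta * L_rlcl
            + gamma1 * losses["cfcl_enc"]      # gamma_1 * L_cfcl^E
            + gamma2 * losses["cfcl_dec"])     # gamma_2 * L_cfcl^D

# Usage with placeholder values (hypothetical loss terms and weights):
cnn_terms = {"seg": 1.0, "rlcl": 0.5, "cfcl_enc": 0.2, "cfcl_dec": 0.4}
cnn_total = student_objective(cnn_terms, beta=2.0, gamma1=1.0, gamma2=0.5)
```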

Algorithm 1 Optimization of CTRCL
Input: Image set $\mathcal{X}$, label set $\mathcal{Y}$, CNN-based student $f(\theta^{C})$, and Transformer-based student $f(\theta^{T})$;
Output: Two trained student models;
1: Initialization: Initialize $\theta^{C}$ and $\theta^{T}$;
2: repeat
3:     Input a batch of images $\{X^{n},Y^{n}\}_{n=1}^{N}$;
4:     Obtain $(P^{C},P^{T})$ and $(F^{C},F^{T})$ with Eq. (1);
5:     Compute $\mathcal{L}_{seg}^{C}$ and $\mathcal{L}_{seg}^{T}$ with Eq. (2);
6:     Select $P^{m}$ in $P$ with Eqs. (4) and (5);
7:     Compute $\lambda^{a}$, $\lambda^{s}$, and $\lambda^{c}$ with Eqs. (9), (10), and (11), and obtain $\lambda$ with Eq. (7);
8:     Obtain $P_{r}$ with Eqs. (5) and (6);
9:     Compute $\mathcal{L}_{rlcl}^{C}$ and $\mathcal{L}_{rlcl}^{T}$ with Eq. (12);
10:    Compute the prototypes $p$ with Eq. (14);
11:    Obtain the class-aware representation $R$ with Eq. (15);
12:    Compute $\mathcal{L}_{cfcl}^{C}$ and $\mathcal{L}_{cfcl}^{T}$ with Eq. (16);
13:    Compute $\mathcal{L}^{C}$ and $\mathcal{L}^{T}$ with Eq. (17);
14:    Update $\theta^{C}$ and $\theta^{T}$, respectively;
15: until maximum iterations
16: return $\theta^{C}$ and $\theta^{T}$.
17: End.

IV Experiments

IV-A Datasets

IV-A1 Synapse Multi-organ

The Synapse [20] dataset contains 30 abdominal CT scans with 3779 axial contrast-enhanced abdominal CT images covering 8 anatomical structures, i.e., aorta, gallbladder (GB), left kidney (KL), right kidney (KR), liver, pancreas (PC), spleen (SP), and stomach (SM). Each CT scan consists of 85-198 slices of 512 × 512 pixels, with a voxel spatial resolution of [0.54-0.54] × [0.98-0.98] × [2.5-5.0] mm³. Following [13], [17], we divide the dataset into 18 scans (2211 axial slices) for training and 12 scans (1568 axial slices) for testing.

IV-A2 Automatic Cardiac Diagnosis Challenge

The ACDC [21] dataset contains 150 cardiac MRI scans annotated with three structures: the right ventricle (RV), myocardium (Myo), and left ventricle (LV). The cine MR images were acquired during breath holds, with a series of short-axis slices covering the heart from the base to the apex of the left ventricle and a slice thickness of 5 to 8 mm. The short-axis in-plane spatial resolution ranges from 0.83 to 1.75 mm²/pixel. We follow the official split, i.e., 100 cases (1902 axial slices) for training and 50 cases (1076 axial slices) for testing.

IV-A3 Kvasir-SEG

The Kvasir-SEG [22] dataset is one of the most commonly used endoscopic image datasets for polyp segmentation; its images were captured inside the gastrointestinal (GI) tract and collected by Vestre Viken Health Trust in Norway. It contains 1000 images related to endoscopic polyp removal, with resolutions ranging from 720 × 576 to 1920 × 1072 pixels, and can be used for computer-aided segmentation of gastrointestinal lesions. Consistent with [42], [43], we divide the dataset into 900 images for training and 100 images for testing.

TABLE I: Quantitative Comparisons of DSC (%), JAC (%) and HSD (mm) Scores on Synapse Multi-organ Dataset and ACDC Dataset. ↑ (↓) Denotes Higher Is Better (Lower Is Better). The Best Results are Shown in Boldface. Vanilla Means no Collaborative Learning. Δ (%) Denotes the Improvement Rate of Our Method over Vanilla.
Dataset Method MobileNetV2 MiT-B0 ResNet-50 MiT-B2
DSC↑ JAC↑ HSD↓ DSC↑ JAC↑ HSD↓ DSC↑ JAC↑ HSD↓ DSC↑ JAC↑ HSD↓
Synapse Vanilla 73.35 62.69 25.31 75.68 65.67 19.39 75.62 65.67 30.11 77.85 67.64 26.52
DML [29] 73.97 63.40 19.84 75.94 66.05 21.82 76.33 66.21 27.45 78.09 68.43 25.13
AFD [37] 74.38 63.86 18.86 76.46 66.29 12.45 76.69 66.78 26.12 78.52 68.69 23.62
CTCL [38] 75.01 64.64 22.04 77.00 67.18 17.23 77.39 67.52 23.73 79.44 70.17 21.57
KDCL [39] 74.12 63.25 18.05 76.32 66.02 18.61 76.55 66.36 26.87 78.25 68.56 24.63
FFL [40] 74.58 64.40 20.21 76.79 66.87 15.95 76.92 66.97 24.34 78.86 69.31 21.56
AFID [44] 74.91 64.36 20.85 77.16 67.34 15.05 77.16 67.24 23.96 79.09 69.62 22.13
CTRCL (Ours) 76.09 65.77 17.38 78.03 68.33 15.80 78.64 69.50 21.88 80.81 71.91 19.71
Δ (%) 3.74 4.91 31.33 3.11 4.05 18.51 3.99 5.83 27.33 3.80 6.31 25.68
ACDC Vanilla 86.48 76.88 1.96 88.09 79.40 1.33 88.39 79.82 1.85 89.22 81.12 1.46
DML [29] 86.95 77.68 1.48 87.87 79.03 1.67 88.94 80.64 1.25 89.12 80.99 1.55
AFD [37] 87.42 78.42 1.35 88.37 79.87 1.28 89.30 81.22 1.21 89.44 81.42 1.27
CTCL [38] 88.21 79.60 1.37 88.98 80.80 1.22 89.92 82.18 1.19 90.33 82.85 1.24
KDCL [39] 87.21 78.06 1.58 88.26 79.68 1.52 89.25 81.17 1.43 89.56 81.65 1.26
FFL [40] 87.78 78.98 1.28 88.64 80.21 1.52 89.62 81.75 1.24 89.75 81.94 1.21
AFID [44] 87.92 79.09 1.33 88.85 80.57 1.73 89.76 81.89 1.18 89.98 82.31 1.23
CTRCL (Ours) 89.01 80.80 1.23 89.84 82.10 1.17 90.52 83.19 1.20 91.19 84.25 1.11
Δ (%) 2.93 5.10 37.24 1.99 3.40 12.03 2.41 4.22 35.14 2.21 3.86 23.97
TABLE II: Quantitative Comparisons of DSC (%), JAC (%) and MAE (%) Scores on Kvasir Dataset. ↑ (↓) Denotes Higher Is Better (Lower Is Better). The Best Results are Shown in Boldface. Vanilla Means no Collaborative Learning. Δ (%) Denotes the Improvement Rate of Our Method over Vanilla.
Dataset Method MobileNetV2 MiT-B0 ResNet-50 MiT-B2
DSC↑ JAC↑ MAE↓ DSC↑ JAC↑ MAE↓ DSC↑ JAC↑ MAE↓ DSC↑ JAC↑ MAE↓
Kvasir Vanilla 86.39 79.64 3.68 89.17 83.64 3.24 87.23 81.12 3.82 89.96 85.02 2.85
DML [29] 87.26 81.20 3.52 88.94 83.01 3.25 88.43 82.51 4.10 89.72 84.39 2.88
AFD [37] 87.77 81.51 3.15 89.64 83.87 2.88 89.22 83.13 3.14 90.51 85.24 2.83
CTCL [38] 88.58 82.82 2.94 90.78 85.37 2.64 90.17 84.86 2.38 91.87 86.97 2.36
KDCL [39] 87.52 81.73 3.38 89.40 83.50 3.03 88.91 83.04 3.48 90.37 85.09 2.66
FFL [40] 87.94 81.97 3.17 89.97 84.14 2.88 89.57 84.00 2.61 90.92 85.86 2.44
AFID [44] 88.18 82.47 3.13 90.39 84.93 2.82 89.95 84.51 2.58 91.30 86.28 2.53
CTRCL (Ours) 89.66 84.34 2.60 91.81 86.60 2.50 91.68 86.55 2.18 93.12 88.24 1.96
Δ (%) 3.79 5.90 29.35 2.96 3.54 22.84 5.10 6.69 42.93 3.51 3.79 31.23
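The Δ (%) rows in Tables I and II can be reproduced from the table entries. Our reading, which matches the reported values, is that Δ is the change relative to Vanilla, with the sign flipped for the lower-is-better metrics (HSD and MAE); the function name below is our own.

```python
def improvement_rate(ours: float, vanilla: float, higher_is_better: bool = True) -> float:
    """Relative improvement (%) of a method over the vanilla student.

    For lower-is-better metrics (HSD, MAE) a decrease counts as an improvement.
    """
    gain = ours - vanilla if higher_is_better else vanilla - ours
    return 100.0 * gain / vanilla


# Synapse, MobileNetV2 column of Table I: DSC, JAC (higher is better), HSD (lower is better).
print(round(improvement_rate(76.09, 73.35), 2))                          # 3.74
print(round(improvement_rate(65.77, 62.69), 2))                          # 4.91
print(round(improvement_rate(17.38, 25.31, higher_is_better=False), 2))  # 31.33
```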

IV-B Evaluation Metrics

We adopt four evaluation metrics: the Dice similarity coefficient (DSC), the Jaccard index (JAC), the Hausdorff distance (HSD), and the mean absolute error (MAE). DSC measures the overlap between the predicted mask $A$ and the ground truth $B$ as $2|A\cap B|/(|A|+|B|)$. JAC, also known as intersection over union, is defined as $|A\cap B|/|A\cup B|$. HSD is the largest distance from a point in one set to its nearest point in the other set, taken over both directions, so it penalizes boundary outliers. MAE is the average absolute difference between all pixels of the predicted mask and the ground truth. Higher values are better for DSC and JAC, whereas lower values are better for HSD and MAE.
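The four metrics can be sketched in plain Python for 2-D masks as follows. This is an illustrative sketch, not the evaluation code behind the tables: the 0.5 binarization threshold, the empty-mask conventions, and the function names are our assumptions. The Hausdorff distance is computed over all foreground pixels in pixel units (multiplying by the pixel spacing would give mm), and this section does not state whether HSD is the maximum or a percentile variant such as HD95.

```python
import math
from typing import List, Set, Tuple

Mask = List[List[float]]  # H x W, values in [0, 1]; 1 = foreground


def _foreground(mask: Mask) -> Set[Tuple[int, int]]:
    """Coordinates of foreground pixels (value > 0.5, an assumed threshold)."""
    return {(i, j) for i, row in enumerate(mask) for j, v in enumerate(row) if v > 0.5}


def dice(pred: Mask, gt: Mask) -> float:
    a, b = _foreground(pred), _foreground(gt)
    if not a and not b:
        return 1.0  # assumed convention for two empty masks
    return 2 * len(a & b) / (len(a) + len(b))


def jaccard(pred: Mask, gt: Mask) -> float:
    a, b = _foreground(pred), _foreground(gt)
    if not a and not b:
        return 1.0  # assumed convention for two empty masks
    return len(a & b) / len(a | b)


def mae(pred: Mask, gt: Mask) -> float:
    """Mean absolute difference over all pixels (pred may be a soft map)."""
    diffs = [abs(p - g) for pr, gr in zip(pred, gt) for p, g in zip(pr, gr)]
    return sum(diffs) / len(diffs)


def hausdorff(pred: Mask, gt: Mask) -> float:
    """Symmetric Hausdorff distance in pixel units (brute force, O(|A||B|))."""
    a, b = _foreground(pred), _foreground(gt)
    if not a or not b:
        raise ValueError("Hausdorff distance is undefined for an empty mask")

    def directed(x, y):
        # For every point in x, distance to its nearest point in y; keep the worst case.
        return max(min(math.dist(p, q) for q in y) for p in x)

    return max(directed(a, b), directed(b, a))


pred = [[1, 1, 0], [0, 0, 0], [0, 0, 0]]
gt = [[1, 0, 0], [0, 0, 0], [0, 0, 1]]
print(dice(pred, gt), jaccard(pred, gt), mae(pred, gt), hausdorff(pred, gt))
```

In this toy example the isolated ground-truth pixel at the bottom-right corner barely affects DSC and JAC but dominates the Hausdorff distance, which is why HSD is reported alongside the overlap metrics.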

IV-C Implementation Details

We implement our method in PyTorch and conduct all experiments on a single NVIDIA RTX 4090 GPU with fixed random seeds. For the CNN-based students, we use the semantic FPN [45] decoder with MobileNetV2 [46] and ResNet-50 [5] encoders; for the Transformer-based students, we use the efficient SegFormer [6] decoder with MiT-B0 and MiT-B2 encoders, whose numbers of parameters are comparable to those of their respective CNN counterparts. The backbones are initialized with weights pre-trained on ImageNet-1k [47]. The crop size (cs), batch size (bs), learning rate (lr), number of training epochs (ep), and optimizer (opt) for each dataset are: (1) Synapse Multi-organ and ACDC: cs=224×224; bs=8; lr=0.0003; ep=200; opt=AdamW. (2) Kvasir: cs=352×352; bs=8; lr=0.00005; ep=300; opt=AdamW. The learning rate is adjusted with the polynomial strategy [48] with a power of 0.9. Random flipping and random rotation are used for data augmentation to reduce overfitting. For a robust and fair comparison, we take the checkpoint from the last epoch and report its results on the testing sets.
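The settings above can be collected into a small configuration together with the learning-rate schedule. This is a sketch under stated assumptions: we use the common form of the polynomial ("poly") policy, $\mathrm{lr}=\mathrm{lr}_{0}\,(1-t/T)^{0.9}$, and the text does not say whether $t$ counts iterations or epochs; the dictionary keys and the function name are hypothetical. In PyTorch, `torch.optim.lr_scheduler.PolynomialLR` with `total_iters=T` and `power=0.9` implements the same decay.

```python
# Hyperparameters as listed in Sec. IV-C; the dictionary layout and key names are our own.
TRAIN_CONFIGS = {
    "synapse": {"crop_size": (224, 224), "batch_size": 8, "lr": 3e-4, "epochs": 200, "optimizer": "AdamW"},
    "acdc": {"crop_size": (224, 224), "batch_size": 8, "lr": 3e-4, "epochs": 200, "optimizer": "AdamW"},
    "kvasir": {"crop_size": (352, 352), "batch_size": 8, "lr": 5e-5, "epochs": 300, "optimizer": "AdamW"},
}


def poly_lr(base_lr: float, step: int, max_steps: int, power: float = 0.9) -> float:
    """Polynomial decay: base_lr * (1 - step / max_steps) ** power."""
    if not 0 <= step <= max_steps:
        raise ValueError("step must lie in [0, max_steps]")
    return base_lr * (1.0 - step / max_steps) ** power


if __name__ == "__main__":
    cfg = TRAIN_CONFIGS["kvasir"]
    # Per-epoch decay is shown here purely for illustration.
    for epoch in (0, 150, 299):
        print(epoch, poly_lr(cfg["lr"], epoch, cfg["epochs"]))
```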

Figure 5: Visual comparisons on Synapse Multi-organ dataset. (a) MobileNetV2. (b) MiT-B0. (c) ResNet-50. (d) MiT-B2.
Figure 6: Visual comparisons on ACDC dataset. Green, red and blue denote RV, LV and Myo respectively. (a) MobileNetV2. (b) MiT-B0. (c) ResNet-50. (d) MiT-B2.
Figure 7: Visual comparisons on Kvasir dataset. The polyps are shown in white. (a) MobileNetV2. (b) MiT-B0. (c) ResNet-50. (d) MiT-B2.

IV-D Comparisons with State-of-the-Art Methods

We compare our method with six other state-of-the-art collaborative learning methods: DML [29], AFD [37], CTCL [38], KDCL [39], FFL [40], and AFID [44]. Among them, DML, AFD, and CTCL are based on mutual learning, while KDCL, FFL, and AFID are based on ensemble learning. For a fair comparison, we implement all methods on the same two student pairs: MobileNetV2 with MiT-B0, and ResNet-50 with MiT-B2. The experimental results and analyses are as follows.

IV-D1 Results on Synapse Multi-organ Dataset

We first evaluate the aforementioned methods on the Synapse Multi-organ dataset and report the quantitative results in Table I. Compared with the students trained without KD (Vanilla), all methods improve DSC and JAC to varying degrees, and our method achieves the best DSC and JAC on all four students and the best HSD on three of them. Specifically, our method improves MobileNetV2 by 3.74%, 4.91%, and 31.33% and MiT-B0 by 3.11%, 4.05%, and 18.51% in DSC, JAC, and HSD, respectively, clearly ahead of CTCL, the strongest competing mutual learning method, which improves MobileNetV2 by 2.26%, 3.11%, and 12.91% and MiT-B0 by 1.74%, 2.30%, and 11.14% in DSC, JAC, and HSD. Moreover, our method brings larger DSC and JAC gains to the more sophisticated students, improving ResNet-50 by 3.99%, 5.83%, and 27.33% and MiT-B2 by 3.80%, 6.31%, and 25.68% in DSC, JAC, and HSD, respectively. This suggests that reasonable knowledge transfer helps unlock the potential of complex student models, especially when training data are limited.

IV-D2 Results on ACDC Dataset

To further validate our method on a different imaging modality, we also conduct experiments on the ACDC dataset; the quantitative results are shown in Table I. As expected, our method again achieves the best DSC and JAC among the online KD methods on all students. Specifically, our method outperforms CTCL, the previous best mutual learning based method, by 0.91%, 1.51%, and 10.22% on MobileNetV2 and by 0.97%, 1.61%, and 4.10% on MiT-B0 in DSC, JAC, and HSD, respectively. In addition, compared with AFID, the best ensemble learning based method, our method improves DSC and JAC by 0.85% and 1.59% on ResNet-50 and by 1.34% and 2.36% on MiT-B2.

IV-D3 Results on Kvasir Dataset

To verify the generalization ability of our method, we conduct experiments on the Kvasir dataset and report the quantitative results in Table II. Unlike the Synapse Multi-organ and ACDC datasets, Kvasir contains only one class to be segmented, which provides limited category information. Moreover, unlike organs, which have strong position and shape priors, polyps in endoscopic images vary widely in position and shape, which makes accurate polyp segmentation more difficult. Notably, most online KD methods achieve larger improvements on both CNN and Transformer students here than on the Synapse Multi-organ and ACDC datasets, which we attribute to the complementary intrinsic properties of Transformers and CNNs. Meanwhile, our method still achieves the highest performance, improving over the previous best method, CTCL, by 1.22%, 1.84%, and 11.56% on MobileNetV2; 1.13%, 1.44%, and 5.30% on MiT-B0; 1.67%, 1.99%, and 8.40% on ResNet-50; and 1.36%, 1.46%, and 16.95% on MiT-B2 in DSC, JAC, and MAE, respectively. Together, the quantitative results on the three datasets substantiate the robustness of our proposed CTRCL.

IV-D4 Visual Comparison

We also show visual comparisons between our method and other state-of-the-art online KD methods on the Synapse Multi-organ, ACDC, and Kvasir datasets in Figs. 5, 6, and 7, respectively, to illustrate the advantages of our approach qualitatively. Compared with the vanilla students (Column 2), all online KD methods (Columns 3-5) produce visibly better segmentation maps. Among them, our method (Column 5) generates predictions that are more consistent with the ground truth (Column 6) than either the mutual learning based method (Column 3) or the ensemble learning based method (Column 4). Specifically, with our method, the CNN-based students MobileNetV2 (Row 1) and ResNet-50 (Row 3) locate the segmented objects more accurately, and the Transformer-based students MiT-B0 (Row 2) and MiT-B2 (Row 4) delineate their boundaries in more detail, which indicates the effectiveness of the bi-directional knowledge transfer between the CNN and Transformer students in our framework.

IV-E Ablation Studies

The proposed framework mainly consists of two key modules: rectified logit-wise collaborative learning (RLCL) and class-aware feature-wise collaborative learning (CFCL). In this section, we carry out a series of experiments under different settings to verify the contribution of each component and key module. All ablation studies are conducted on the Synapse Multi-organ dataset with ResNet-50 and MiT-B2, using DSC and HSD for evaluation.

IV-E1 The Effectiveness of Different Components

Table III reports the results of different components. Both RLCL and CFCL improve the performance of the CNN-based and Transformer-based students. Specifically, applying RLCL to the vanilla students improves DSC and HSD by 2.76% and 22.45% on ResNet-50 and by 2.53% and 19.08% on MiT-B2. Similarly, applying CFCL yields gains of 1.73% and 14.51% on ResNet-50 and 1.76% and 13.05% on MiT-B2 in DSC and HSD, respectively. The best results are achieved by combining RLCL and CFCL, which indicates that the two modules are complementary.

TABLE III: Ablation Analysis of Different Components. The Best Results are Shown in Boldface.
No. Vanilla RLCL CFCL ResNet-50 MiT-B2
DSC↑ HSD↓ DSC↑ HSD↓
1 ✓ – – 75.62 30.11 77.85 26.52
2 ✓ ✓ – 77.71 23.35 79.82 21.46
3 ✓ – ✓ 76.93 25.74 79.22 23.06
4 ✓ ✓ ✓ 78.64 21.88 80.81 19.71
Figure 8: t-SNE visualization of the extracted features from CNN-based and Transformer-based students with our proposed RLCL and CFCL. (a) Vanilla. (b) Vanilla + RLCL. (c) Vanilla + CFCL. (d) Vanilla + RLCL + CFCL.

In Fig. 8, we visualize with t-SNE the feature representations extracted from the two students under different component settings. With RLCL and CFCL, the features show clearer inter-class separation and tighter intra-class clusters, indicating that both modules help the student models learn discriminative features for each class, which is crucial for segmentation.

TABLE IV: Comparison of Our RLCL with Other Logit-wise Collaborative Learning Strategies. The Best Results are Shown in Boldface.
No. Methods ResNet-50 MiT-B2
DSC↑ HSD↓ DSC↑ HSD↓
1 Vanilla 75.62 30.11 77.85 26.52
2 + ML [29] 76.33 27.45 78.09 25.13
3 + EL [39] 76.55 26.87 78.25 24.63
4 + BSD [38] 76.89 25.62 78.93 23.15
5 + RLCL (Ours) 77.71 23.35 79.82 21.46
TABLE V: Ablation Analysis of the Rectification Weight $\lambda$. The Best Results are Shown in Boldface.
No. $\lambda$ ResNet-50 MiT-B2
DSC↑ HSD↓ DSC↑ HSD↓
1 $\lambda=1$ 75.62 30.11 77.85 26.52
2 $\lambda=\lambda^{a}$ 76.71 26.24 78.82 23.67
3 $\lambda=\lambda^{a}\cdot\lambda^{s}$ 77.42 24.68 79.57 21.94
4 $\lambda=\lambda^{a}\cdot\lambda^{c}$ 77.14 25.16 79.28 22.65
5 $\lambda=\lambda^{a}\cdot\lambda^{s}\cdot\lambda^{c}$ 77.71 23.35 79.82 21.46
TABLE VI: Ablation Analysis of CFCL on Different Stages. $\mathcal{L}_{cfcl}^{E}$ and $\mathcal{L}_{cfcl}^{D}$ Denote the Class-aware Feature-wise Distillation on the Output Features of the Encoder and Decoder, Respectively. The Best Results are Shown in Boldface.
No. Vanilla $\mathcal{L}_{cfcl}^{E}$ $\mathcal{L}_{cfcl}^{D}$ ResNet-50 MiT-B2
DSC↑ HSD↓ DSC↑ HSD↓
1 ✓ – – 75.62 30.11 77.85 26.52
2 ✓ ✓ – 76.54 26.67 78.24 25.14
3 ✓ – ✓ 76.18 28.32 78.71 24.21
4 ✓ ✓ ✓ 76.93 25.74 79.22 23.06

IV-E2 The Effectiveness of RLCL

In Table IV, we compare RLCL with three logit-wise collaborative learning strategies: mutual learning (ML) [29], ensemble learning (EL) [39], and bi-directional selective distillation (BSD) [38]. ML yields the smallest improvements on both students because it ignores whether the student soft labels are correct. EL achieves a larger gain owing to its enhanced ensemble label. BSD selects the reliable regions in the student soft labels for bi-directional distillation and attains a further improvement. In comparison, our RLCL achieves the best performance on both students, demonstrating the benefit of adaptively rectifying the wrong regions in student soft labels for accurate knowledge transfer in the logit space.

We also explore the influence of the rectification weight $\lambda$ under different configurations and report the quantitative results in Table V. Compared with plain logit-wise collaborative learning without rectification ($\lambda=1$), a clear improvement is achieved when we align the probabilities of the misclassified class and the ground-truth class in the wrong regions of the soft labels via $\lambda^{a}$. In addition, both the similarity-based decay factor $\lambda^{s}$ and the certainty-based decay factor $\lambda^{c}$ further improve the two students through adaptive rectification guided by the ground truth. Combining $\lambda^{a}$, $\lambda^{s}$, and $\lambda^{c}$ into the final rectification weight $\lambda$ gives the best results on both students through logit-wise collaborative learning with dynamically rectified soft labels.

IV-E3 The Effectiveness of CFCL

Table VI shows the effectiveness of CFCL on the output features of the encoder and decoder. Applying CFCL to either the encoder output features or the decoder output features improves both students. However, CFCL on the encoder output benefits the CNN-based student more than CFCL on the decoder output (DSC: +1.22% vs. +0.74%), whereas the opposite holds for the Transformer-based student (DSC: +0.50% vs. +1.10%). A plausible explanation is that the output feature of the Transformer encoder carries more global semantic information, which helps the CNN-based student localize objects accurately, while the output feature of the CNN decoder is richer in local detail, which encourages the Transformer-based student to predict finer boundaries.

IV-E4 The Sensitivity of $\beta$, $\gamma_{1}$ and $\gamma_{2}$

Table VII reports the DSC (%) of the students under different values of $\beta$, $\gamma_{1}$, and $\gamma_{2}$ on the Synapse Multi-organ dataset. Increasing $\beta$ up to 4 steadily improves the CNN-based student, owing to the accurate knowledge transfer from Transformer to CNN through RLCL. However, the Transformer-based student peaks at $\beta=3$ and degrades for larger values, falling below the vanilla baseline at $\beta=5$. Since our goal is to facilitate collaborative learning between the two students, we choose $\beta=3$, which gives the best trade-off (the largest $\Delta_{s}$). Similarly, we choose $\gamma_{1}=1$ and $\gamma_{2}=2$ for CFCL on the encoder and decoder output features, respectively. For convenience and robustness, we adopt $\beta=3$, $\gamma_{1}=1$, and $\gamma_{2}=2$ in all experiments with our method, i.e., for both pairs of CNN-Transformer students on all three medical image segmentation datasets.

TABLE VII: Ablation Analysis of $\beta$, $\gamma_{1}$ and $\gamma_{2}$. $\Delta_{s}$ is the Sum of Improvement in DSC on Both Students Compared with the Vanilla.
$\beta$ 0 1 2 3 4 5
ResNet-50 75.62 77.15 77.56 77.71 77.90 77.34
MiT-B2 77.85 78.84 79.09 79.82 78.57 77.64
$\Delta_{s}$ 0 +2.52 +3.18 +4.06 +3.00 +1.51
$\gamma_{1}$ 0 0.1 0.5 1 2 5
ResNet-50 75.62 75.87 76.17 76.54 76.63 76.04
MiT-B2 77.85 77.96 78.38 78.24 77.94 77.52
$\Delta_{s}$ 0 +0.36 +1.08 +1.31 +1.10 +0.09
$\gamma_{2}$ 0 1 2 3 4 5
ResNet-50 75.62 75.95 76.18 76.21 75.85 75.47
MiT-B2 77.85 78.56 78.71 78.34 78.33 78.22
$\Delta_{s}$ 0 +1.04 +1.42 +1.08 +0.71 +0.22
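The $\Delta_{s}$ row for $\beta$ can be recomputed directly from the two DSC rows above it. The sketch below does so and picks the $\beta$ with the largest combined gain, which matches the choice $\beta=3$ described in the text; the variable names are ours, and treating the maximum of $\Delta_{s}$ as the "best trade-off" criterion is our reading.

```python
# DSC values copied from the beta rows of Table VII.
BETAS = [0, 1, 2, 3, 4, 5]
RESNET50_DSC = [75.62, 77.15, 77.56, 77.71, 77.90, 77.34]
MITB2_DSC = [77.85, 78.84, 79.09, 79.82, 78.57, 77.64]


def delta_s(cnn: list, vit: list) -> list:
    """Sum of both students' DSC gains over the vanilla setting (index 0)."""
    return [round((c - cnn[0]) + (v - vit[0]), 2) for c, v in zip(cnn, vit)]


scores = delta_s(RESNET50_DSC, MITB2_DSC)
best_beta = BETAS[scores.index(max(scores))]
print(scores)      # [0.0, 2.52, 3.18, 4.06, 3.0, 1.51]
print(best_beta)   # 3
```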

IV-F Generalization Studies

In this section, we conduct a series of experiments to evaluate the generalization of our method. The results and analyses are as follows.

IV-F1 Generalization on Different Pairs of Students

Fig. 9 presents the results of our method on different pairs of students. Specifically, we apply our method to isomorphic students in Fig. 9 (a) and to heterogeneous students with mismatched numbers of parameters in Fig. 9 (b). Our method still achieves considerable DSC improvements for both students in each pair, demonstrating its generalization ability and robustness. Moreover, CNN-Transformer collaborative learning benefits more than CNN-CNN or Transformer-Transformer collaborative learning, owing to the global-local complementarity between CNN and Transformer.

IV-F2 Generalization on Knowledge Distillation

We also apply our method to knowledge distillation (KD), i.e., the CNN-based and Transformer-based students guide each other as teachers in an offline manner. The comparison with collaborative learning (CL) is shown in Fig. 10. Our method still improves the students in the offline KD paradigm, further verifying its generalization ability. Meanwhile, simultaneous collaborative learning achieves larger improvements than two-stage KD for both the CNN-based and Transformer-based students, because the teacher is progressively optimized during collaborative learning.

Figure 9: Results of our method on different pairs of students. (a) Isomorphic students (MobileNetV2 & ResNet-50; MiT-B0 & MiT-B2). (b) Heterogeneous students in different sizes (MobileNetV2 & MiT-B2; ResNet-50 & MiT-B0).
Figure 10: Results of our method on knowledge distillation (KD) and comparison with collaborative learning (CL).
Figure 11: Results of our method on existing CNN-based and Transformer-based MIS methods.

IV-F3 Generalization on Existing MIS Methods

In Fig. 11, we extend our framework to existing CNN-based and Transformer-based MIS methods, forming four pairs of CNN-Transformer students: U-Net [7] & Swin-UNet [13]; Att-UNet [49] & MISSFormer [15]; U-Net++ [8] & UCTransNet [50]; and DARR [51] & MT-UNet [52]. Our method generalizes well to these students and consistently improves all of them. These results indicate that our framework can improve existing networks through accurate and effective collaborative learning between them, suggesting that it may also benefit future MIS models.

V Conclusion

In this paper, we propose a CNN-Transformer rectified collaborative learning (CTRCL) framework that learns stronger CNN-based and Transformer-based models for MIS tasks through bi-directional knowledge transfer between them in both the logit and feature spaces. For accurate logit-wise knowledge transfer, we propose the RLCL strategy, which introduces the ground truth to adaptively rectify the wrong regions in the soft labels. For effective feature-wise knowledge transfer between heterogeneous CNNs and Transformers, we propose the CFCL strategy, which grants their intermediate features a similar category-perception ability. Experimental results on three popular benchmarks show that our method outperforms six state-of-the-art collaborative learning methods on nearly all metrics. Furthermore, extensive ablation and generalization studies demonstrate the effectiveness of each component as well as the generalization ability of our method.

References

  • [1] W. Chen, R. Smith, S.-Y. Ji, K. R. Ward, and K. Najarian, “Automated ventricular systems segmentation in brain CT images by combining low-level segmentation and high-level template matching,” BMC Med. Inf. Decis. Making, vol. 9, no. 1, p. S4, Dec. 2009.
  • [2] Z. Yu-Qian, G. Wei-Hua, C. Zhen-Cheng, T. Jing-Tian, and L. Ling-Yun, “Medical images edge detection based on mathematical morphology,” in Proc. 27th Annu. Int. Conf. IEEE Eng. Med. Biol. Soc. (EMBC), Jan. 2006, pp. 6492–6495.
  • [3] S. Li, T. Fevens, and A. Krzyżak, “A SVM-based framework for autonomous volumetric medical image segmentation using hierarchical and coupled level sets,” in Int. Congr. Ser., vol. 1268, pp. 207–212, Jun. 2004.
  • [4] R. R. Selvaraju, M. Cogswell, A. Das, R. Vedantam, D. Parikh, and D. Batra, “Grad-CAM: Visual explanations from deep networks via gradient-based localization,” in Proc. IEEE Int. Conf. Comput. Vis. (ICCV), Oct. 2017, pp. 618–626.
  • [5] K. He, X. Zhang, S. Ren, and J. Sun, “Deep residual learning for image recognition,” in Proc. IEEE Conf. Comput. Vis. Pattern Recognit. (CVPR), Jun. 2016, pp. 770–778.
  • [6] E. Xie, W. Wang, Z. Yu, A. Anandkumar, J. M. Alvarez, and P. Luo, “SegFormer: Simple and efficient design for semantic segmentation with transformers,” in Proc. Adv. Neural Inform. Process. Syst., 2021, pp. 12077–12090.
  • [7] O. Ronneberger, P. Fischer, and T. Brox, “U-Net: Convolutional networks for biomedical image segmentation,” in Proc. Int. Conf. Med. Image Comput. Comput.-Assist. Intervent., 2015, pp. 234–241.
  • [8] Z. Zhou, M. M. R. Siddiquee, N. Tajbakhsh, and J. Liang, “UNet++: Redesigning skip connections to exploit multiscale features in image segmentation,” IEEE Trans. Med. Imag., vol. 39, no. 6, pp. 1856–1867, Jun. 2020.
  • [9] Z. Gu et al., “CE-Net: Context encoder network for 2D medical image segmentation,” IEEE Trans. Med. Imag., vol. 38, no. 10, pp. 2281–2292, Oct. 2019.
  • [10] S. Feng et al., “CPFNet: Context pyramid fusion network for medical image segmentation,” IEEE Trans. Med. Imag., vol. 39, no. 10, pp. 3008–3018, Oct. 2020.
  • [11] Z. Wang, “Automatic localization and segmentation of the ventricles in magnetic resonance images,” IEEE Trans. Circuits Syst. Video Technol., vol. 31, no. 2, pp. 621–631, Feb. 2021.
  • [12] T. Zhou, Y. Zhou, G. Li, G. Chen, and J. Shen, “Uncertainty-aware hierarchical aggregation network for medical image segmentation,” IEEE Trans. Circuits Syst. Video Technol., early access, Feb. 26, 2024, doi: 10.1109/TCSVT.2024.3370685.
  • [13] H. Cao et al., “Swin-UNet: Unet-like pure transformer for medical image segmentation,” in Proc. Eur. Conf. Comput. Vis. (ECCV), 2022, pp. 205–218.
  • [14] A. Lin, B. Chen, J. Xu, Z. Zhang, G. Lu, and D. Zhang, “DS-TransUNet: Dual swin transformer U-Net for medical image segmentation,” IEEE Trans. Instrum. Meas., vol. 71, pp. 1–15, 2022.
  • [15] X. Huang, Z. Deng, D. Li, X. Yuan, and Y. Fu, “MISSFormer: an effective transformer for 2D medical image segmentation,” IEEE Trans. Med. Imag., vol. 42, no. 5, pp. 1484–1494, May 2023.
  • [16] H. Li, D.-H. Zhai, and Y. Xia, “ERDUnet: An efficient residual double-coding unet for medical image segmentation,” IEEE Trans. Circuits Syst. Video Technol., vol. 34, no. 4, pp. 2083–2096, Apr. 2024.
  • [17] J. Chen et al., “TransUNet: Transformers make strong encoders for medical image segmentation,” 2021, arXiv:2102.04306.
  • [18] Y. Zhang, H. Liu, and Q. Hu, “TransFuse: Fusing transformers and CNNs for medical image segmentation,” in Proc. Int. Conf. Med. Image Comput. Comput.-Assist. Intervent., 2021, pp. 14–24.
  • [19] F. Yuan, Z. Zhang, and Z. Fang, “An effective CNN and transformer complementary network for medical image segmentation,” Pattern Recognit., vol. 136, p. 109228, Apr. 2023.
  • [20] B. Landman, Z. Xu, J. Igelsias, M. Styner, T. Langerak, and A. Klein, “MICCAI multi-atlas labeling beyond the cranial vault–workshop and challenge,” in Proc. MICCAI Multi-Atlas Labeling Beyond Cranial Vault—Workshop Challenge, vol. 5, p. 12, 2015.
  • [21] O. Bernard et al., “Deep learning techniques for automatic MRI cardiac multi-structures segmentation and diagnosis: is the problem solved?,” IEEE Trans. Med. Imag., vol. 37, no. 11, pp. 2514–2525, Nov. 2018.
  • [22] D. Jha et al., “Kvasir-SEG: A segmented polyp dataset,” in Proc. Int. Conf. Multimedia Model. (MMM), 2020, pp. 451–462.
  • [23] D. Jha et al., “ResUNet++: An advanced architecture for medical image segmentation,” in Proc. IEEE Int. Symp. Multimedia (ISM), Dec. 2019, pp. 225–2255.
  • [24] J. Hu, L. Shen, and G. Sun, “Squeeze-and-excitation networks,” in Proc. IEEE/CVF Conf. Comput. Vis. Pattern Recognit. (CVPR), Jun. 2018, pp. 7132–7141.
  • [25] L.-C. Chen, G. Papandreou, I. Kokkinos, K. Murphy, and A. L. Yuille, “DeepLab: Semantic image segmentation with deep convolutional nets, atrous convolution, and fully connected CRFs,” IEEE Trans. Pattern Anal. Mach. Intell., vol. 40, no. 4, pp. 834–848, Apr. 2018.
  • [26] D. Karimi, S. D. Vasylechko, and A. Gholipour, “Convolution-free medical image segmentation using transformers,” in Proc. Int. Conf. Med. Image Comput. Comput.-Assist. Intervent., 2021, pp. 78–88.
  • [27] H.-Y. Zhou et al., “nnFormer: Volumetric medical image segmentation via a 3D transformer,” IEEE Trans. Image Process., vol. 32, pp. 4036–4045, 2023.
  • [28] G. Hinton, O. Vinyals, and J. Dean, “Distilling the knowledge in a neural network,” 2015, arXiv:1503.02531.
  • [29] Y. Zhang, T. Xiang, T. M. Hospedales, and H. Lu, “Deep mutual learning,” in Proc. IEEE/CVF Conf. Comput. Vis. Pattern Recognit. (CVPR), Jun. 2018, pp. 4320–4328.
  • [30] K. Zhang, C. Zhang, S. Li, D. Zeng, and S. Ge, “Student network learning via evolutionary knowledge distillation,” IEEE Trans. Circuits Syst. Video Technol., vol. 32, no. 4, pp. 2251–2263, Apr. 2022.
  • [31] L. Kong and J. Yang, “MDFlow: Unsupervised optical flow learning by reliable mutual knowledge distillation,” IEEE Trans. Circuits Syst. Video Technol., vol. 33, no. 2, pp. 677–688, Feb. 2023.
  • [32] C. Yang, Z. An, H. Zhou, F. Zhuang, Y. Xu, and Q. Zhang, “Online knowledge distillation via mutual contrastive learning for visual recognition,” IEEE Trans. Pattern Anal. Mach. Intell., vol. 45, no. 8, pp. 10212–10227, Aug. 2023.
  • [33] X. Zhu et al., “Knowledge distillation by on-the-fly native ensemble,” in Proc. Adv. Neural Inform. Process. Syst., 2018, pp. 7517–7527.
  • [34] G. Wu and S. Gong, “Peer collaborative learning for online knowledge distillation,” in Proc. AAAI Conf. Artif. Intell., 2021, pp. 10302–10310.
  • [35] S. Zhao, T. Xu, X.-J. Wu, and J. Kittler, “Distillation, ensemble and selection for building a better and faster siamese based tracker,” IEEE Trans. Circuits Syst. Video Technol., vol. 34, no. 1, pp. 182–194, Jan. 2024.
  • [36] T. Su et al., “Deep cross-layer collaborative learning network for online knowledge distillation,” IEEE Trans. Circuits Syst. Video Technol., vol. 33, no. 5, pp. 2075–2087, May 2023.
  • [37] I. Chung, S. Park, J. Kim, and N. Kwak, “Feature-map-level online adversarial knowledge distillation,” in Proc. Int. Conf. Mach. Learn. (ICML), 2020, pp. 2006–2015.
  • [38] J. Zhu, Y. Luo, X. Zheng, H. Wang, and L. Wang, “A good student is cooperative and reliable: CNN-transformer collaborative learning for semantic segmentation,” in Proc. IEEE/CVF Int. Conf. Comput. Vis. (ICCV), Oct. 2023, pp. 11720–11730.
  • [39] Q. Guo et al., “Online knowledge distillation via collaborative learning,” in Proc. IEEE/CVF Conf. Comput. Vis. Pattern Recognit. (CVPR), Jun. 2020, pp. 11020–11029.
  • [40] J. Kim, M. Hyun, I. Chung, and N. Kwak, “Feature fusion for online mutual knowledge distillation,” in Proc. Int. Conf. Pattern Recognit. (ICPR), Jan. 2021, pp. 4619–4625.
  • [41] K. Wang, J. H. Liew, Y. Zou, D. Zhou, and J. Feng, “PANet: Few-shot image semantic segmentation with prototype alignment,” in Proc. IEEE/CVF Int. Conf. Comput. Vis. (ICCV), Oct. 2019, pp. 9197–9206.
  • [42] D.-P. Fan et al., “PraNet: Parallel reverse attention network for polyp segmentation,” in Proc. Int. Conf. Med. Image Comput. Comput.-Assist. Intervent., 2020, pp. 263–273.
  • [43] J.-H. Shi, Q. Zhang, Y.-H. Tang, and Z.-Q. Zhang, “Polyp-mixer: An efficient context-aware MLP-based paradigm for polyp segmentation,” IEEE Trans. Circuits Syst. Video Technol., vol. 33, no. 1, pp. 30–42, Jan. 2023.
  • [44] T. Su, Q. Liang, J. Zhang, Z. Yu, G. Wang, and X. Liu, “Attention-based feature interaction for efficient online knowledge distillation,” in Proc. IEEE Int. Conf. Data Mining (ICDM), Dec. 2021, pp. 579–588.
  • [45] A. Kirillov, R. Girshick, K. He, and P. Dollár, “Panoptic feature pyramid networks,” in Proc. IEEE/CVF Conf. Comput. Vis. Pattern Recognit. (CVPR), Jun. 2019, pp. 6399–6408.
  • [46] M. Sandler, A. Howard, M. Zhu, A. Zhmoginov, and L.-C. Chen, “MobileNetV2: Inverted residuals and linear bottlenecks,” in Proc. IEEE Conf. Comput. Vis. Pattern Recognit. (CVPR), Jun. 2018, pp. 4510–4520.
  • [47] J. Deng, W. Dong, R. Socher, L.-J. Li, K. Li, and L. Fei-Fei, “ImageNet: A large-scale hierarchical image database,” in Proc. IEEE Conf. Comput. Vis. Pattern Recognit. (CVPR), Jun. 2009, pp. 248–255.
  • [48] W. Liu, A. Rabinovich, and A. C. Berg, “ParseNet: Looking wider to see better,” 2015, arXiv:1506.04579.
  • [49] J. Schlemper et al., “Attention gated networks: Learning to leverage salient regions in medical images,” Med. Image Anal., vol. 53, pp. 197–207, Apr. 2019.
  • [50] H. Wang, P. Cao, J. Wang, and O. R. Zaiane, “UCTransNet: Rethinking the skip connections in U-Net from a channel-wise perspective with transformer,” in Proc. AAAI Conf. Artif. Intell., 2022, pp. 2441–2449.
  • [51] S. Fu et al., “Domain adaptive relational reasoning for 3D multi-organ segmentation,” in Proc. Int. Conf. Med. Image Comput. Comput.-Assist. Intervent. (MICCAI), 2020, pp. 656–666.
  • [52] H. Wang et al., “Mixed Transformer U-Net for medical image segmentation,” in Proc. IEEE Int. Conf. Acoust., Speech Signal Process. (ICASSP), May 2022, pp. 2390–2394.
Lanhu Wu received the B.S. degree in communication engineering from Dalian University of Technology, Dalian, China, in 2020. He is currently with IIAU-OIP Lab, Dalian University of Technology, Dalian, China. His current research interests include computer vision, medical image segmentation, and model compression.
Miao Zhang (Member, IEEE) received the B.S. degree in computer science from the Memorial University of Newfoundland, St. John’s, NL, Canada, in 2005, and the Ph.D. degree in electronic engineering from Kwangwoon University, Seoul, South Korea, in 2012. From 2013 to 2015, she was an Assistant Professor with the Department of Game and Mobile Contents, Keimyung University, Daegu, South Korea, and an Adjunct Professor with the Department of Computer Science, DigiPen Institute of Technology, Redmond, WA, USA. She is currently an Associate Professor with the Key Laboratory for Ubiquitous Network and Service Software of Liaoning Province, DUT-RU International School of Information Science and Engineering, Dalian University of Technology, Dalian, China. Her research interests include computer vision, machine learning, and 3D imaging and visualization.
Yongri Piao (Member, IEEE) received the M.Sc. and Ph.D. degrees in information and communication engineering from Pukyong National University, Busan, South Korea, in 2005 and 2008, respectively. Since 2012, he has been an Associate Professor of information and communication engineering with the Dalian University of Technology, Dalian, China. His research interests include 3D computer vision and sensing, object detection and target recognition, and 3D reconstruction and visualization.
Zhenyan Yao received the B.S. degree in communication engineering from North China Electric Power University, Baoding, China, in 2022. He is currently with IIAU-OIP Lab, Dalian University of Technology, Dalian, China. His current research interests include computer vision, medical image segmentation, and semi-supervised learning.
Weibing Sun is a Chief Physician and Professor with the Department of Urology, Affiliated Zhongshan Hospital of Dalian University, Dalian, China.
Feng Tian is a Chief Physician and Professor with the Department of Urology, Affiliated Zhongshan Hospital of Dalian University, Dalian, China.
Huchuan Lu (Fellow, IEEE) received the M.S. degree in signal and information processing and the Ph.D. degree in system engineering from the Dalian University of Technology (DUT), Dalian, China, in 1998 and 2008, respectively. He joined the Faculty of the School of Information and Communication Engineering, DUT, in 1998, where he is currently a Full Professor. His research interests include computer vision and pattern recognition with a focus on visual tracking, saliency detection, and segmentation.