
Dealing with heterogeneous 3D MR Knee images: A Federated Few-Shot Learning method with dual knowledge distillation

Abstract

Federated Learning has gained popularity among medical institutions since it enables collaborative training between clients (e.g., hospitals) without aggregating data. However, due to the high cost of creating annotations, especially for large 3D image datasets, clinical institutions often do not have enough labeled data for local training. Thus, the performance of the collaborative model is subpar under limited supervision. On the other hand, large institutions have the resources to compile data repositories with high-resolution images and labels. Therefore, individual clients can utilize the knowledge acquired from public data repositories to mitigate the shortage of private annotated images. In this paper, we propose a federated few-shot learning method with dual knowledge distillation. This method allows joint training with limited annotations across clients without jeopardizing privacy. The supervised learning of the proposed method extracts features from the limited labeled data in each client, while the unlabeled data is used to distill both feature-based and response-based knowledge from a national data repository to further improve the accuracy of the collaborative model and reduce the communication cost. Extensive evaluations are conducted on 3D magnetic resonance knee images from a private clinical dataset. Our proposed method shows superior performance and requires less training time than other semi-supervised federated learning methods. Code and additional visualization results are available at https://github.com/hexiaoxiao-cs/fedml-knee.

Index Terms—  Federated Learning, Few-shot Learning, Knowledge Distillation

1 Introduction

Fig. 1: Overview of the proposed federated few-shot learning framework. We adopt the FL workflow from [1]. The dual knowledge acquired from the Osteoarthritis Initiative (OAI) training repository (top) is distributed to each client during initialization. Each client trains with local low-resolution support images (bottom) and returns the trained student model to the server. Both OAI and local knee MR data contain red, green, and blue annotations for femoral cartilage (FC), tibial cartilage (TC), and patellar cartilage (PC), respectively.

Federated Learning (FL) [1] was introduced as a privacy-aware framework that utilizes isolated datasets without aggregating data into one location. It uses a client-server architecture where each client performs training on its local dataset, and the server aggregates the models submitted by the clients, as indicated in Fig. 1. This framework is suitable for medical analysis applications [2, 3], since it enables medical institutions to cooperate in training the same model while preserving privacy. Although FL is beneficial for ensuring privacy during collaboration, each client model still needs to be accurate for the joint model to become useful in clinical applications [4, 5, 6]. However, medical institutions usually lack labeled data due to insufficient annotation resources, especially for 3D images. For example, our private dataset only contains 20 labeled images per client. Previously, Zhang et al. [7] tackled the lack of annotated data with semi-supervised training on all clients, maximizing the expectation of pseudo labels over unlabeled data. Yang et al. [8] employed the FixMatch consistency regularization on unlabeled data for better training. However, these semi-supervised methods struggle to produce an accurate segmentation network due to limited annotations. They also require more communication between clients and the server, which is challenging for institutions in remote areas. Therefore, more supervision during training is needed to reduce data traffic and improve accuracy.

Few-shot learning (FSL) utilizes prior knowledge acquired from a training set with sufficient annotations to guide training on a support set with limited labels, improving the accuracy of the client model. In our federated FSL approach, the support set located on each client comes from our private dataset with limited supervision. The training set on the server is the OAI repository, which differs from the support set in resolution and imaging parameters but contains significantly more annotated images. Therefore, high-quality OAI data can be used to improve the accuracy of clients on their local datasets. Instead of distributing the data repository to each client for training, a model pre-trained on OAI data is created at the server and sent to the clients to reduce data traffic. As the 3D images in Fig. 1 show, the knee cartilages are thin tissues, which poses a challenge to local training with limited supervision. Also, the image resolution of our private dataset (bottom) is significantly lower than that of the data repository (top), because coarsely scanned images are more commonly used in clinical applications. Such heterogeneity in resolution between the clinical dataset and the repository prevents the model trained on the OAI repository from being applied directly to the local dataset.

Instead of directly applying the pre-trained model, we can distill the knowledge of knee cartilages from the OAI repository to accelerate collaborative training. As illustrated in Fig. 2, our few-shot learning method uses a teacher-student architecture [9] to distill the dual knowledge, which consists of the response-based and feature-based [10] knowledge of the target, from the pre-trained teacher model to the local student models. The response-based knowledge refers to the soft label created by the teacher network. The representation of knee cartilages produced by the encoder of the teacher network is used as feature-based knowledge. Offline distillation is then used to transfer the dual knowledge from the pre-trained model to the client-side model through unlabeled local data. The distillation process helps each client extract more general features that are not bound to the quality of the data [11, 12], and it reduces the training time and the data transfer between clients and server. In parallel, supervised learning adapts the collaborative model to the local dataset through the labeled data.

In this paper, we propose an FL-based few-shot learning framework with dual knowledge distillation for improved segmentation of knee cartilages from 3D MR data. Our contributions are: (i) we identify the problem of limited local annotations among medical institutions and propose a few-shot learning method that utilizes prior knowledge from a well-annotated open data repository to train a collaborative deep learning model with few local annotations, (ii) we address the data heterogeneity problem caused by Non-IID [13] sources and a large disparity in imaging parameters between the repository and the local clinical data, and (iii) we identify the problem of massive data transfer associated with utilizing a data repository in FL settings and solve it through prior feature extraction from the data repository. We compare our method with two state-of-the-art methods, and it shows superior performance.

2 Methods

Fig. 2: Our few-shot learning method with dual knowledge distillation in each client. The unsupervised images are used to gain response and feature-based knowledge from the teacher network and distill it to the student network. The student network is also optimized on the supervised images.

Few-shot Learning with Dual Knowledge Distillation. Fig. 2 shows our client-side few-shot learning framework with supervised and dual knowledge distillation loss terms. The supervised loss, defined as $L_{S}(\mathbf{S})$, consists of the cross-entropy loss and the Dice loss between the segmentation of the student network $\mathbf{S}$ and the ground truth to provide feedback.

Since the private dataset is heavily unlabeled, our method guides the training on unlabeled images in the support set by exploiting the dual knowledge acquired from the training set. This helps alleviate the label shortage during training of the student network with the local images. We define the dual knowledge distillation loss $L_{DKD}$ using the response-based distillation loss $L_{R}$ and the feature-based distillation loss $L_{F}$.

The response-based knowledge is extracted using the soft label produced by the teacher model $\mathbf{T}$. The soft label contains the class probability distribution of each voxel. Thus, it carries more information than the binary-encoded hard label. However, the teacher may produce incorrect labels because it is trained on the OAI repository, which differs from the private data. Such incorrect segmentation needs to be identified and excluded. Inspired by [9], we estimate the uncertainty of the teacher-produced soft label by running multiple stochastic forward passes, with random dropout and noise added to the unlabeled data, to obtain multiple soft labels. The predictive entropy of these soft labels is then used to assess the uncertainty of the teacher network. This score is used to filter out unreliable predictions and select confident labels for the student to learn from. This process constitutes the response-based distillation loss $L_{R}(\mathbf{T},\mathbf{S})$.
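The filtering step described above can be sketched as follows. This is a minimal illustration, not the authors' implementation: the stochastic teacher passes are assumed to be given as precomputed probability vectors, the masked mean squared error and the entropy threshold are assumptions (the text does not specify the exact form of the loss or the threshold), and all function names are hypothetical.

```python
import math


def predictive_entropy(probs):
    """Entropy (in nats) of one class-probability vector."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


def mean_soft_label(passes):
    """Average the class probabilities of several stochastic teacher passes."""
    n = len(passes)
    return [sum(p[k] for p in passes) / n for k in range(len(passes[0]))]


def response_loss(teacher_passes, student_probs, threshold):
    """Masked MSE between student outputs and confident teacher soft labels.

    teacher_passes[v] holds one probability vector per stochastic pass for
    voxel v; student_probs[v] is the student's probability vector for v.
    Voxels whose predictive entropy reaches the threshold are excluded.
    Returns (loss, number of voxels kept).
    """
    total, kept = 0.0, 0
    for passes, student in zip(teacher_passes, student_probs):
        soft = mean_soft_label(passes)
        if predictive_entropy(soft) >= threshold:
            continue  # assumption: uncertain teacher voxels are simply skipped
        total += sum((a - b) ** 2 for a, b in zip(soft, student)) / len(soft)
        kept += 1
    return (total / kept if kept else 0.0), kept


# Toy example with two voxels and two classes (values are illustrative only).
teacher_passes = [
    [[0.90, 0.10], [0.95, 0.05], [0.85, 0.15]],  # passes agree: confident
    [[0.90, 0.10], [0.10, 0.90], [0.50, 0.50]],  # passes disagree: uncertain
]
student_probs = [[0.80, 0.20], [0.60, 0.40]]
loss, kept = response_loss(teacher_passes, student_probs, threshold=0.5)
```

In this toy example only the first voxel contributes to the loss, because the averaged soft label of the second voxel is close to uniform and its entropy exceeds the threshold.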

To further accelerate training and improve the accuracy of the collaborative model, feature-based knowledge of the teacher model is used. The idea is to capture the high-level representation of the target from the teacher network trained on the data repository. Since the teacher network is a pre-trained model, its feature maps are more informative than those of the randomly initialized student model. Therefore, the goal is to let the student network produce a set of feature vectors similar to that of the teacher network for the same cartilage to expedite the training process. With this in mind, we distill the feature-based knowledge by utilizing the KL divergence on the latent codes produced by the encoders of the teacher and student networks. Let $\text{E}_{\mathbf{T}}$ and $\text{E}_{\mathbf{S}}$ be the encoders of the teacher and student networks. Then the feature-based distillation loss $L_{F}(\mathbf{T},\mathbf{S})$ with input $x$ is:

\begin{equation}
L_{F}(\mathbf{T},\mathbf{S})=\sum_{j\in x}\text{E}_{\mathbf{T}}(j)\log\frac{\text{E}_{\mathbf{T}}(j)}{\text{E}_{\mathbf{S}}(j)} \tag{1}
\end{equation}

Therefore, the loss function $L$ of our few-shot learning is

\begin{align}
L(\mathbf{T},\mathbf{S}) &= L_{S}(\mathbf{S})+\lambda L_{DKD}(\mathbf{T},\mathbf{S}) \tag{2}\\
&= L_{S}(\mathbf{S})+\lambda\left(L_{R}(\mathbf{T},\mathbf{S})+L_{F}(\mathbf{T},\mathbf{S})\right) \tag{3}
\end{align}

where $L_{S}$, $L_{DKD}$, $L_{R}$, and $L_{F}$ stand for the supervised loss, dual knowledge distillation loss, response-based distillation loss, and feature-based distillation loss, respectively. The weight $\lambda$ balances the supervised loss against the dual knowledge distillation loss.
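As a small numerical illustration of Eqs. (1) to (3), the sketch below computes a KL-based feature loss on two latent codes and combines the loss terms. It assumes that each latent code is turned into a probability distribution with a softmax before the KL divergence is taken; the text does not state how the codes are normalized, so this step and the function names are assumptions.

```python
import math


def softmax(values):
    """Numerically stable softmax over a flat list of numbers."""
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    z = sum(exps)
    return [e / z for e in exps]


def feature_kd_loss(teacher_code, student_code):
    """KL(E_T || E_S) as in Eq. (1), on softmax-normalized latent codes.

    The softmax normalization is an assumption made for this sketch.
    """
    p = softmax(teacher_code)
    q = softmax(student_code)
    return sum(pj * math.log(pj / qj) for pj, qj in zip(p, q))


def total_loss(l_s, l_r, l_f, lam):
    """Eq. (3): L = L_S + lambda * (L_R + L_F)."""
    return l_s + lam * (l_r + l_f)


# Illustrative values only; they are not taken from the experiments.
l_f_same = feature_kd_loss([2.0, 0.0, -1.0], [2.0, 0.0, -1.0])
l_f_diff = feature_kd_loss([2.0, 0.0, -1.0], [0.5, 0.2, 0.0])
combined = total_loss(l_s=1.0, l_r=0.2, l_f=0.3, lam=0.5)
```

The loss vanishes when the student code matches the teacher code and is positive otherwise, which is the property that pushes the student encoder toward the teacher's representation.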

Although the teacher network provides valuable guidance in the first few rounds of training, the benefit of knowledge distillation diminishes as the student model improves on the private dataset, and a fixed teacher would eventually hold back the performance of the student. The teacher therefore needs to be updated from the student. However, during the first few communication rounds, the student model barely contains any knowledge of the morphology of the cartilages, so updating the teacher from the student at that stage would undermine the accuracy of the teacher. Thus, a delayed exponential moving average (EMA) update from student to teacher is applied.
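A delayed EMA update of this kind can be sketched as below. The starting round of 6 follows the condition used in Alg. 1; the decay value of 0.99, the representation of model weights as a dictionary of floats, and the function names are assumptions made for illustration.

```python
def update_ema(teacher, student, decay=0.99):
    """Exponential moving average: T <- decay * T + (1 - decay) * S."""
    return {
        name: decay * teacher[name] + (1.0 - decay) * student[name]
        for name in teacher
    }


def maybe_update_teacher(teacher, student, round_idx, start_round=6, decay=0.99):
    """Leave the teacher untouched until start_round, then apply the EMA update."""
    if round_idx < start_round:
        return dict(teacher)  # early rounds: the student is not yet reliable
    return update_ema(teacher, student, decay)


# Toy weights (hypothetical): one parameter per model.
teacher = {"w": 1.0}
student = {"w": 0.0}
early = maybe_update_teacher(teacher, student, round_idx=3)
late = maybe_update_teacher(teacher, student, round_idx=6)
```

With a decay close to 1, the teacher changes slowly even after the delay, so a few noisy student updates cannot erase the knowledge distilled from the repository.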

Inter-institutional Federated Learning. To share the knowledge gained by each client, we integrate our proposed few-shot learning method into the federated learning framework. No patient data is transferred in any part of the training process. We implement the federated learning framework similarly to FedAvg [1], as indicated in Fig. 1. The federated learning process is outlined in Alg. 1, where $\mathbf{S}^{c}_{t}$ denotes the student model weights from client $c\in C$ in synchronization round $t$.

Algorithm 1 The cluster contains $N=|C|$ clients in total, where $C$ denotes the set of all clients. Each client uses a learning rate of $\eta$. The communication interval is denoted as $E$, and $p_{c}$ is the aggregation weight of client $c$.
Server executes:
    Initialize student model with random weights $\mathbf{S}_{0}$
    Load and distribute the pre-trained teacher model $\mathbf{T}$
    for each communication round $t\in\{1,\dots,rounds\}$ do
        for each client $c\in C$ in parallel do
            $\mathbf{S}_{t}^{c}\leftarrow$ TrainLocally$(c,\mathbf{S}_{t})$ # Collect models
        end for
        $\mathbf{S}_{t+1}\leftarrow\sum_{c\in C}p_{c}\mathbf{S}^{c}_{t}$ # Aggregate client models
    end for
TrainLocally$(c,\mathbf{S}_{0})$:
    for each client iteration $e\in\{1,\dots,E\}$ do
        $\mathbf{S}_{e}\leftarrow\mathbf{S}_{e-1}-\eta\nabla L(\mathbf{T},\mathbf{S}_{e-1})$ # Perform local training
    end for
    if $t\geq 6$ then
        $\mathbf{T}\leftarrow$ UpdateEMA$(\mathbf{T},\mathbf{S}_{E})$ # Delayed EMA update
    end if
    return $\mathbf{S}_{E}$
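The server-side part of Alg. 1 can be illustrated with the following sketch of one communication round. It assumes the FedAvg convention $p_{c}=n_{c}/\sum_{c'}n_{c'}$ from [1], where $n_{c}$ is the number of local samples; the algorithm above does not define $p_{c}$ explicitly, so this choice, the dictionary-of-floats model representation, and the `train_locally` callable are assumptions.

```python
def aggregate(client_models, client_sizes):
    """Weighted average of client weights with p_c = n_c / sum of all n_c."""
    total = float(sum(client_sizes))
    return {
        name: sum(m[name] * (n / total) for m, n in zip(client_models, client_sizes))
        for name in client_models[0]
    }


def run_round(global_model, clients, train_locally):
    """One communication round: local training on each client, then aggregation."""
    # Each client starts from its own copy of the current global student model.
    local_models = [train_locally(c, dict(global_model)) for c in clients]
    sizes = [c["n_samples"] for c in clients]
    return aggregate(local_models, sizes)


# Toy stand-in for local training (hypothetical): each client shifts the weight.
def toy_train_locally(client, model):
    return {"w": model["w"] + client["shift"]}


clients = [
    {"n_samples": 20, "shift": 1.0},
    {"n_samples": 60, "shift": 3.0},
]
new_global = run_round({"w": 0.0}, clients, toy_train_locally)
```

When every client holds the same number of images, as in the experiments reported below, this weighting reduces to a plain average of the client models.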

3 Experiments

Method    | All Cartilages        | Femoral Cartilage     | Tibial Cartilage      | Patellar Cartilage
          | DSC    VOE     ASSD   | DSC    VOE     ASSD   | DSC    VOE     ASSD   | DSC    VOE     ASSD
Local     | 0.713  43.981  1.228  | 0.727  42.193  1.297  | 0.710  44.647  0.822  | 0.566  58.029  2.646
Semi      | 0.744  40.199  1.182  | 0.767  37.397  1.287  | 0.757  38.803  0.794  | 0.622  52.560  2.367
SSFL      | 0.746  39.973  1.100  | 0.740  40.344  1.608  | 0.754  39.236  0.631  | 0.654  50.152  1.464
Fed-Semi  | 0.762  38.310  0.902  | 0.763  38.125  1.022  | 0.756  39.063  0.629  | 0.673  48.158  1.500
Ours      | 0.789  34.529  0.643  | 0.796  33.386  0.632  | 0.777  36.309  0.523  | 0.690  45.464  1.441
Table 1: Quantitative comparison of methods on our private dataset. The best result in every column is achieved by our method (Ours).
Fig. 3: Visual results of subject 1. (a) and (f) show the GT labels; (b) and (g) are from local training; (c) and (h) are from SSFL; (d) and (i) are from Fed-Semi; (e) and (j) are from our proposed method, in 3D and sagittal views, respectively.

Experiment settings. We evaluate our method on a private dataset, which contains 20 labeled and 1000 unlabeled 3D MR knee images in each of the 4 clients. The voxel size (mm) of the images ranges from $(0.303, 0.303, 3.5)$ to $(0.3125, 0.3125, 4.5)$. The dual knowledge is extracted from the OAI repository with a voxel size of $(0.365, 0.365, 0.7)$. All images are resized to $352\times 288\times 16$, and their pixel intensity has been normalized to $[0,1]$. All datasets have been split into $6:2:2$ for training, validation, and testing. We utilize U-net [14] as the segmentation network.
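The intensity normalization and the 6:2:2 split can be sketched as follows. Min-max scaling is assumed here because the text only states that intensities are normalized to $[0,1]$; the shuffling seed, the flat-list representation of a volume, and the image identifiers are likewise hypothetical, and resizing is omitted.

```python
import random


def normalize_intensity(voxels):
    """Min-max scale a flat list of intensities to [0, 1] (assumed scheme)."""
    lo, hi = min(voxels), max(voxels)
    if hi == lo:
        return [0.0 for _ in voxels]  # constant volume: map everything to 0
    return [(v - lo) / (hi - lo) for v in voxels]


def split_622(items, seed=0):
    """Shuffle and split items into training, validation, and test parts (6:2:2)."""
    shuffled = list(items)
    random.Random(seed).shuffle(shuffled)
    n_train = len(shuffled) * 6 // 10  # integer arithmetic avoids float rounding
    n_val = len(shuffled) * 2 // 10
    train = shuffled[:n_train]
    val = shuffled[n_train:n_train + n_val]
    test = shuffled[n_train + n_val:]
    return train, val, test


scaled = normalize_intensity([10, 20, 30])
train_ids, val_ids, test_ids = split_622(range(20))  # e.g., 20 labeled images
```

For the 20 labeled images of one client, this split yields 12 training, 4 validation, and 4 test images.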

Two state-of-the-art federated semi-supervised segmentation approaches, SSFL [7] and Fed-Semi [8], are compared. We also evaluate the performance without federated learning (Local) and without knowledge distillation (Semi). To measure accuracy and spatial correctness, the Dice similarity coefficient (DSC), volumetric overlap error (VOE) (%), and average symmetric surface distance (ASSD) (mm) between the GT labels and the segmentation are reported.
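For reference, the two overlap metrics can be computed on binary masks as in the sketch below. It uses the common definitions DSC $=2|A\cap B|/(|A|+|B|)$ and VOE $=100\,(1-|A\cap B|/|A\cup B|)$, the latter expressed as a percentage; these definitions are assumed rather than quoted from the paper. ASSD is omitted because it additionally requires surface extraction and the voxel spacing.

```python
def dice(pred, gt):
    """Dice similarity coefficient between two flat binary masks."""
    inter = sum(1 for p, g in zip(pred, gt) if p and g)
    size = sum(1 for p in pred if p) + sum(1 for g in gt if g)
    return 2.0 * inter / size if size else 1.0  # two empty masks agree fully


def voe(pred, gt):
    """Volumetric overlap error in percent: 100 * (1 - |A and B| / |A or B|)."""
    inter = sum(1 for p, g in zip(pred, gt) if p and g)
    union = sum(1 for p, g in zip(pred, gt) if p or g)
    return 100.0 * (1.0 - inter / union) if union else 0.0


# Toy masks over six voxels (illustrative only).
pred = [1, 1, 1, 0, 0, 0]
gt = [0, 1, 1, 1, 0, 0]
dsc_value = dice(pred, gt)
voe_value = voe(pred, gt)
```

In this toy case two of the four voxels in the union overlap, so the DSC is 2/3 and the VOE is 50%.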

Fig. 4: Visual result for subject 2. (a) SSFL and (b) Ours.

Experiment results. As observed in Table 1, our method outperforms both SSFL and Fed-Semi, with a DSC increase of 5.7% and 3.6%, respectively. This indicates that the additional knowledge improves training in the FL scenario. Furthermore, our 3D surface is more accurate, since the ASSD is reduced by 41% relative to SSFL. Compared to local training, our federated method improves DSC by 10.6% for all cartilages. Prior knowledge is advantageous in small-target segmentation since it provides additional information about the target. For example, PC is a small cartilage, and segmenting PC is challenging since fewer voxels represent it. Our method shows an increase of 11% in DSC on PC compared to training without prior knowledge (Semi). In terms of efficiency, our method needs 19 communication rounds compared to the 28 rounds needed by the other methods, which amounts to a 32% decrease in data transfer and training time.
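The relative changes quoted above can be recomputed from Table 1 as sketched below. The inputs are the rounded table entries, so the last digit may differ slightly from the figures in the text; the helper name is hypothetical.

```python
def relative_change(new, old):
    """Relative change of new with respect to old, in percent."""
    return 100.0 * (new - old) / old


# Values copied from Table 1 ("All Cartilages" unless stated otherwise).
dsc = {"Local": 0.713, "SSFL": 0.746, "Fed-Semi": 0.762, "Ours": 0.789}
assd = {"SSFL": 1.100, "Ours": 0.643}
dsc_pc = {"Semi": 0.622, "Ours": 0.690}  # patellar cartilage

changes = {
    "DSC vs SSFL": relative_change(dsc["Ours"], dsc["SSFL"]),
    "DSC vs Fed-Semi": relative_change(dsc["Ours"], dsc["Fed-Semi"]),
    "DSC vs Local": relative_change(dsc["Ours"], dsc["Local"]),
    "ASSD vs SSFL": relative_change(assd["Ours"], assd["SSFL"]),
    "PC DSC vs Semi": relative_change(dsc_pc["Ours"], dsc_pc["Semi"]),
    "Rounds (19 vs 28)": relative_change(19, 28),
}

for name, value in changes.items():
    print(f"{name}: {value:+.1f}%")
```

Negative values indicate a reduction, which is the desired direction for ASSD and for the number of communication rounds.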

Fig. 3 shows the three cartilages of one subject segmented by the above-mentioned methods. In the white circled region, local training (b) under-segments PC compared to our method (e), which confirms that the dual knowledge helps the segmentation of small cartilage. Meanwhile, as indicated by the yellow circled area of Fig. 3, our method produces the most accurate result, whereas all other methods fail to discover parts of FC. In particular, about one-third of the FC produced by local training (g) is missing. This provides evidence that clients with limited labels cannot train a usable network on their own, and that collaboration between medical institutions is needed. In addition, both SSFL (h) and Fed-Semi (i) show discontinuous FC labels, indicating insufficient knowledge of the cartilage shape compared to the proposed method. To show the stability of our method, Fig. 4 presents a hard case from our dataset. The label produced by SSFL misses a significant portion of FC, which is not acceptable for medical applications. In contrast, our method maintains high accuracy throughout the test cases because of the additional knowledge distilled from the OAI repository.

4 Conclusion

In this work, we proposed a few-shot FL framework with dual knowledge distillation. The dual knowledge, comprising the response-based and feature-based knowledge extracted from the data repository on the server side, is used to accelerate and guide the local training of the student model on the private dataset. Our few-shot learning reduces the annotation requirements in each client, and knowledge distillation mitigates the dissimilarity in imaging resolution and parameters between the training and support sets. In our experiments on 3D MR knee images, the proposed method outperformed the compared methods.

5 Compliance with Ethical Standards

This research study was conducted retrospectively using human subject data, including an open-access dataset made available by the National Institutes of Health through the Osteoarthritis Initiative and a private dataset from Shanghai Sixth People’s Hospital. The studies involving human participants were reviewed and approved by the Ethics Committee of Shanghai Sixth People’s Hospital. Written informed consent to participate in this study was provided by the participants’ legal guardian/next of kin.

References

  • [1] Brendan McMahan, Eider Moore, Daniel Ramage, Seth Hampson, and Blaise Aguera y Arcas, “Communication-efficient learning of deep networks from decentralized data,” in Artificial intelligence and statistics. PMLR, 2017, pp. 1273–1282.
  • [2] Ittai Dayan et al., “Federated learning for predicting clinical outcomes in patients with covid-19,” Nature Medicine, vol. 27, no. 10, pp. 1735–1743, 2021.
  • [3] Qi Dou et al., “Federated deep learning for detecting covid-19 lung abnormalities in ct: a privacy-preserving multinational validation study,” npj Digital Medicine, vol. 4, no. 1, pp. 60, 2021.
  • [4] Qi Chang, Zhennan Yan, Mu Zhou, Di Liu, Khalid Sawalha, Meng Ye, Qilong Zhangli, Mikael Kanski, Subhi Al’Aref, Leon Axel, et al., “Deeprecon: Joint 2d cardiac segmentation and 3d volume reconstruction via a structure-specific generative method,” in International Conference on Medical Image Computing and Computer-Assisted Intervention. Springer, 2022, pp. 567–577.
  • [5] Di Liu, Yunhe Gao, Qilong Zhangli, Ligong Han, Xiaoxiao He, Zhaoyang Xia, Song Wen, Qi Chang, Zhennan Yan, Mu Zhou, et al., “Transfusion: multi-view divergent fusion for medical image segmentation with transformers,” in International Conference on Medical Image Computing and Computer-Assisted Intervention. Springer, 2022, pp. 485–495.
  • [6] Qilong Zhangli, Jingru Yi, Di Liu, Xiaoxiao He, Zhaoyang Xia, Qi Chang, Ligong Han, Yunhe Gao, Song Wen, Haiming Tang, et al., “Region proposal rectification towards robust instance segmentation of biological images,” in International Conference on Medical Image Computing and Computer-Assisted Intervention. Springer, 2022, pp. 129–139.
  • [7] Zhengming Zhang, Yaoqing Yang, Zhewei Yao, Yujun Yan, Joseph E Gonzalez, Kannan Ramchandran, and Michael W Mahoney, “Improving semi-supervised federated learning by reducing the gradient diversity of models,” IEEE International Conference on Big Data (Big Data), 2021.
  • [8] Dong Yang et al., “Federated semi-supervised learning for covid region segmentation in chest ct using multi-national data from china, italy, japan,” Medical Image Analysis, vol. 70, pp. 101992, 2021.
  • [9] Lequan Yu, Shujun Wang, Xiaomeng Li, Chi-Wing Fu, and Pheng-Ann Heng, “Uncertainty-aware self-ensembling model for semi-supervised 3d left atrium segmentation,” in International Conference on Medical Image Computing and Computer-Assisted Intervention. Springer, 2019, pp. 605–613.
  • [10] Jianping Gou, Baosheng Yu, Stephen J Maybank, and Dacheng Tao, “Knowledge distillation: A survey,” International Journal of Computer Vision, vol. 129, no. 6, pp. 1789–1819, 2021.
  • [11] T. V. Nguyen, M. A. Dakka, S. M. Diakiw, M. D. VerMilyea, M. Perugini, J. M. M. Hall, and D. Perugini, “A novel decentralized federated learning approach to train on globally distributed, poor quality, and protected private medical data,” Scientific Reports, vol. 12, no. 1, pp. 8888, 2022.
  • [12] Siwei Mai, Qian Li, Qi Zhao, and Mingchen Gao, “Few-shot transfer learning for hereditary retinal diseases recognition,” in International Conference on Medical Image Computing and Computer-Assisted Intervention. Springer, 2021, pp. 97–107.
  • [13] Hangyu Zhu, Jinjin Xu, Shiqing Liu, and Yaochu Jin, “Federated learning on non-iid data: A survey,” Neurocomputing, vol. 465, pp. 371–390, 2021.
  • [14] Xiaoxiao He, Chaowei Tan, Yuting Qiao, Virak Tan, Dimitris Metaxas, and Kang Li, “Effective 3D humerus and scapula extraction using low-contrast and high-shape-variability MR data,” in Medical Imaging 2019: Biomedical Applications in Molecular, Structural, and Functional Imaging, Barjor Gimi and Andrzej Krol, Eds. International Society for Optics and Photonics, 2019, vol. 10953, pp. 118–124, SPIE.