
Federated Learning with Privacy-Preserving Ensemble Attention Distillation

Xuan Gong, Liangchen Song, Rishi Vedula, Abhishek Sharma, Meng Zheng, Benjamin Planche, Arun Innanje, Terrence Chen, Senior Member, IEEE, Junsong Yuan, Fellow, IEEE, David Doermann, Fellow, IEEE, and Ziyan Wu, Senior Member, IEEE

X. Gong, L. Song, R. Vedula, J. Yuan and D. Doermann are with the University at Buffalo, Buffalo, NY, USA ({xuangong, lsong8, rishisat, jsyuan, doermann}@buffalo.edu). A. Sharma, M. Zheng, B. Planche, A. Innanje, T. Chen and Z. Wu are with United Imaging Intelligence, Cambridge, MA, USA ({first.last}@uii-ai.com). This paper is primarily based on work done during X. Gong and L. Song's internships with United Imaging Intelligence. Corresponding author: Z. Wu.
Abstract

Federated Learning (FL) is a machine learning paradigm where many local nodes collaboratively train a central model while keeping the training data decentralized. This is particularly relevant for clinical applications, since patient data are usually not allowed to be transferred out of medical facilities, leading to the need for FL. Existing FL methods typically share model parameters or employ co-distillation to address the issue of unbalanced data distribution. However, they also require numerous rounds of synchronized communication and, more importantly, suffer from a privacy leakage risk. In this work, we propose a privacy-preserving FL framework that leverages unlabeled public data for one-way offline knowledge distillation. The central model is learned from local knowledge via ensemble attention distillation. Our technique uses decentralized and heterogeneous local data like existing FL approaches, but more importantly, it significantly reduces the risk of privacy leakage. Extensive experiments on image classification, segmentation, and reconstruction tasks demonstrate that our method achieves very competitive performance with more robust privacy preservation.

Index Terms: Privacy, Federated Learning, Distillation

1 Introduction

With increasing interest in topics such as edge computing [1], a new machine learning paradigm called federated learning (FL) is emerging. In FL, data samples need not all reside at one specific location; they remain distributed across local or edge-compute nodes. Instead, FL relies on model fusion/distillation techniques to train a single centralized model in a distributed, decentralized fashion.

Several vital challenges make FL markedly different from typical distributed learning. First, privacy is always the key concern with respect to protecting local data. Second, communication is a critical bottleneck, as the central training can easily be disrupted by network communication issues. Third, individual local data distributions can differ substantially, as distributed data centers tend to collect data in different settings. This inherent heterogeneity can manifest itself in various ways: different data sizes or domain distributions, different local model architectures, or simply the diversity in knowledge across local models.

Figure 1: Unlike traditional FL approaches (a) which exchange gradients or weights directly with local nodes, the proposed privacy-preserving FL framework (b) relies only on the exchange of products of inference with non-proprietary public data.

Mainstream federated learning methods train the central model using a recursive exchange of parameters/gradients between the central and local nodes [2, 3, 4, 5, 6, 7, 8]. Recent applications of FL in medical imaging share similar features. Sheller et al. [9] share local model gradients for multi-institution collaboration, and Yang et al. [10] employ a similar approach for semi-supervised COVID segmentation. Typically, such methods involve each local model sharing its gradients with a central server after each round of local training on its local data. The central server then aggregates the local model parameters [7, 11, 12]. Each local node then updates its local model with the latest global aggregation, and this process repeats until convergence. These parameter-based communication methods have many known security weaknesses and are limited to models with homogeneous architectures. While some methods have shown promise in protecting against data leakage in medical imaging [13, 14], recent works [15, 16] demonstrated that local private data can be fully recovered from publicly shared gradients, further highlighting the associated privacy risks in general, and in medical applications in particular.

Another approach to fusing local models into a single central model is to employ distillation [17, 18], where the central model is learned as a student with multiple local models as teachers. Distillation-based methods train the central model with aggregated locally-computed logits [19, 18, 20], thereby eliminating the requirement of identical architectures. Despite the known communication bottleneck in FL, existing distillation-based FL methods optimize the central and local models jointly by synchronizing locally inferred predictions, requiring a high degree of synchronization and numerous rounds of communication. Moreover, these techniques either assume that the public and private data are sampled from the same underlying distribution [17], or iteratively exchange local parameters beyond the products inferred on public data [20], both of which invariably expose private data to attack. In addition to the network communication and data security issues discussed above, these existing methods mainly rely on distilling the final predictions, e.g., logits, and completely ignore the structure knowledge that leads to these final predictions.

In contrast to previous FL methods that incrementally train the local models, update them synchronously, and exclusively share parameters or distill logits online, this paper ensembles stale local information, comprising both logits and feature-level knowledge, using a privacy-preserving, offline, federated distillation method under the heterogeneous FL setting, as illustrated in Figure 1. To protect the privacy of local data, we distill only the outputs on unlabeled, non-sensitive, cross-domain public data, without exchanging local model parameters or gradients. The proposed distillation is one-way, i.e., from local nodes to a central server, with the local training remaining asynchronous and independent. Our key insight is that training the local models to completion allows us to mine and ensemble their structure knowledge to capture each node's internal expertise. To coordinate local knowledge under FL heterogeneity, we propose federated attention distillation (FedAD) to fully exploit the consensus and diversity of attention maps across local models. We demonstrate the effectiveness and efficiency of our method via experiments with chest X-ray and brain tumor MRI datasets on image classification, segmentation, and reconstruction tasks. Our contributions are summarized as follows:

  • We propose an offline federated distillation framework to explicitly preserve local data privacy by only distilling model outputs on unlabeled, non-sensitive public data.

  • To deal with the heterogeneity in federated learning scenarios, we distill structure knowledge via novel attention-bound constraints for local ensembles with a trade-off between local consensus and diversity.

  • Our federated distillation pipeline is model-agnostic and highly flexible without any requirement w.r.t. online synchronization during communication.

  • We empirically show that our method can be extended to typical medical tasks such as image reconstruction. We simultaneously achieve competitive performance with a more robust privacy-preserving guarantee and superior communication efficiency.

This paper is an extended version of our previous work [21]. In particular, (a) we generalize our bound constraint from the top-down attention generated by Grad-CAM [22] to self-attention [23]; (b) we extend the applications beyond classification- and segmentation-related tasks to more general medical tasks, i.e., image reconstruction; (c) we theoretically analyze our method, deriving a performance bound for the central model when the private data across local nodes and the public data used for knowledge distillation come from different domains; and (d) we provide comprehensive empirical studies on chest X-ray image classification, brain tumor segmentation, and reconstruction tasks with cross-domain federated distillation. The experimental settings simulate realistic in-the-wild FL scenarios, including local data from different institutions, local data of various sizes, public data from a different domain than the local data, and public data with a modality different from the local ones.

2 Related Work

2.1 Federated Learning

2.1.1 Parameter-based FL

In parameter-based FL methods, each local model shares its parameters/gradients with the central server after each round of local training on its local data. The central server aggregates them by averaging [2]. The aggregated result is then shared with the local nodes, each updating its local model before proceeding with the next training round. This process is repeated until the stopping criterion is met. A variety of extensions of FedAvg [2, 7, 11, 12] employ improved aggregation schemes, such as adding momentum [4] and local weighting based on client loss [11] or client data size [12]. FedMA [7] aggregates local parameters layer-wise by matching and averaging hidden elements. FedProx [5] incorporates a proximal term to keep local updates close to the global model. SCAFFOLD [8] introduces control variates to correct the local updates.

Figure 2: Overall pipeline of the proposed FedAD framework.

2.1.2 Distillation-based FL

Federated distillation methods exchange model outputs rather than model parameters. Solutions that produce central models by distilling knowledge from private data raise concerns about leaking private local data [24, 25]. In contrast, some works [19, 17, 18] distill outputs on public data. FedMD [17] divides local training and joint distillation into two phases by adding labeled public data. Each local model is fully trained with public and private data in the first phase. In the collaboration phase, the central and local models are jointly optimized to reach a consensus via the distillation of predictions on the public data. Cronus [18] implements a similar two-phase system, but the local initialization only utilizes private data since the public data is unlabeled. Although model-agnostic, these methods select public data based on prior knowledge about the private data. The recently proposed FedDF [20] is more robust to the selection of distillation data, but it still has privacy issues due to the iterative exchange of models over hundreds of rounds. The methods mentioned above all require many rounds of back-and-forth communication, leading to bandwidth bottlenecks and other inefficiencies. For communication efficiency, Guha et al. [26] propose a preliminary investigation into offline federated learning, where all local models are trained independently without inter-institutional communication; knowledge is then distilled by averaging predictions on unlabeled data. Another similar work is PATE [27], where the central model is trained with hard pseudo-labels voted by the local models rather than the soft predictions we employ.

2.2 Knowledge Transfer

2.2.1 Knowledge Ensemble

Model ensembling [28] is a popular regularization technique that yields predictions more robust and generalizable than those of individual models. Classical ensemble approaches such as bagging [29], boosting [30], and Bayesian methods [31] are exploited to assemble partial knowledge for better prediction as a mixture of experts [32]. With the success of knowledge transfer [33], recent advancements in ensemble networks are dominated by the student-teacher learning paradigm [34]. Ensemble learning aggregates the knowledge of multiple teachers before distilling it into the student network. Supervised ensemble learning is dominated by gating mechanisms that learn the aggregation weights [34, 35, 36]. In semi-supervised and self-supervised scenarios, [37] and [38] exploit the relative similarity between samples. Furthermore, co-distillation extends one-way transfer to bidirectional collaborative learning [39, 40, 41].

2.2.2 Structure Knowledge

Beyond the use of softened labels for distillation [33], many improvements have been developed to transfer structure knowledge, such as intermediate representations [42]. Attention transfer relaxes feature-level knowledge to attention maps, derived either from bottom-up activations [43] or top-down gradients [44]. Yim et al. [45] utilize the Gram matrix, representing the flow of the solution procedure, for distillation. Other works match features via maximum mean discrepancy [46] or mutual information [47]. More recent attempts explore structured knowledge ensembles for knowledge distillation with labeled distillation data. FEED [48] can be seen as an extension of self-distillation [49] that accumulates and assembles feature-level knowledge to train itself recursively. Knowledge flow [50] enhances the student by adding transformed and scaled intermediate representations from teacher models. In contrast, our feature-level ensemble method is label-free, model-agnostic, and designed for heterogeneous federated distillation scenarios.

3 Methodology

3.1 Problem Definition

In FL scenarios, we consider $K$ local nodes, each hosting a private, local dataset $\mathcal{D}_{k}=\{(x_{k}^{i},y_{k}^{i})\,|\,i=1,\ldots,|\mathcal{D}_{k}|\}$, where $x_{k}^{i}$ and $y_{k}^{i}$ are the $i$-th paired data sample from the $k$-th local node. The public dataset $\mathcal{D}_{0}=\{x_{0}^{i}\,|\,i=1,\ldots,|\mathcal{D}_{0}|\}$ can be either labeled or unlabeled and is accessible by all local nodes.

As illustrated in Figure 2, in the first stage, the model at each local node $k$ is initialized by training over its own local, private data $\mathcal{D}_{k}$. Let $\theta_{k}$ denote the model parameters after this initial training. Please note that the proposed distillation-based approach is agnostic to neural network architecture. Hence, each local node can specify its unique architecture suited to the particular distribution of its local data. In the second stage, we ignore all private datasets and freeze the local models, hence reducing the risk of any data leakage through the exposure of the local models or data. Instead, the public dataset $\mathcal{D}_{0}$, hosted on the server and deployed at each local node, is used for a one-way knowledge distillation procedure, from the local nodes to the server. Each fully-trained local model $\theta_{k}$ and the server-based central model $\theta_{\mathrm{s}}$ constitute a teacher-student knowledge transfer setup. Overall, we consider an ensemble of multiple teachers, one at each local node, which only communicate products inferred on public data $\mathcal{D}_{0}$ to the server-based student.

3.2 Ensemble and Distillation

3.2.1 Conventional Ensemble Distillation

Let $z^{c}_{k}=f(x_{0},\theta_{k},c)$ be the logit of a public data sample $x_{0}$ corresponding to class $c\in\mathcal{C}_{k}$, produced by the model at local node $k$, and let the output of the central model be $\tilde{z}^{c}=f(x_{0},\theta_{\mathrm{s}},c)$, where $c\in\{1,\ldots,C\}$. The conventional ensemble $\hat{z}^{c}=\frac{1}{K}\sum_{k=1}^{K}z^{c}_{k}$ takes an average of all teachers' logits and then employs the softmax activation $\sigma(\cdot,\cdot)$ to represent the probability that the sample belongs to class $c$:

$$\sigma(z^{c},\tau)=\frac{\exp(z^{c}/\tau)}{\sum_{c}\exp(z^{c}/\tau)}, \qquad (1)$$

where $\tau$ is a temperature parameter. Taking the teachers' ensembled soft label as $\sigma(\hat{z}^{c},\tau)$ and the student soft label as $\sigma(\tilde{z}^{c},\tau)$, conventional knowledge distillation employs the Kullback-Leibler divergence to update the student model:

$$\mathcal{L}=\sum_{c}\sigma(\hat{z}^{c},\tau)\log\frac{\sigma(\hat{z}^{c},\tau)}{\sigma(\tilde{z}^{c},\tau)}. \qquad (2)$$

Hinton et al. [33] have shown that minimizing Eq. 2 with a high $\tau$ is equivalent to minimizing the $\ell_{2}$ error between the logits of teacher and student, thereby relating cross-entropy minimization to matching logits.
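To make Eqs. 1-2 concrete, the following is a minimal PyTorch-style sketch of temperature-softened ensemble distillation; the tensor shapes and names are illustrative assumptions rather than the exact implementation used in our experiments.

```python
import torch
import torch.nn.functional as F

def conventional_ensemble_distillation_loss(local_logits, central_logits, tau=3.0):
    """Minimal sketch of Eqs. 1-2: average the K teachers' logits, soften both
    sides with temperature tau, and compute the KL divergence.
    Assumed shapes: local_logits is a list of K tensors of shape (B, C);
    central_logits is a tensor of shape (B, C)."""
    z_hat = torch.stack(local_logits, dim=0).mean(dim=0)        # conventional ensemble
    teacher = F.softmax(z_hat / tau, dim=-1)                    # sigma(z_hat, tau)
    student_log = F.log_softmax(central_logits / tau, dim=-1)   # log sigma(z_tilde, tau)
    # KL(teacher || student), averaged over the batch
    return F.kl_div(student_log, teacher, reduction="batchmean")
```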

3.2.2 Importance-Weighted Ensemble Distillation

Let the global data distribution $p_{0}(x,y)$ of image $x$ and label $y$ be the target, while the local private data distribution is denoted as $p_{k}(x,y)$. Due to the imbalance of the data distribution among local nodes, the bias ratio of the local prediction is:

$$\hat{p}_{k}(x,y)=\frac{p_{k}(x,y)}{p_{0}(x,y)}=\frac{p_{k}(y)p_{k}(x|y)}{p_{0}(y)p_{0}(x|y)}\approx\frac{p_{k}(y)}{p_{0}(y)}, \qquad (3)$$

where we assume $p_{k}(x|y)\approx p_{0}(x|y)$, since the local differences in the conditional distribution $p_{k}(x|y)$ are minor.

To account for this, during the ensemble we introduce an importance weight $\omega$ for each local node to reflect the distribution of the local private data its model was initially trained with:

$$\hat{z}^{c}=\sum_{k}\omega_{k}^{c}z^{c}_{k}, \qquad \omega_{k}^{c}=\frac{N_{k}^{c}}{\sum_{k}N_{k}^{c}}, \qquad (4)$$

where the importance weight $\omega^{c}$ is class-specific: $N_{k}^{c}=\sum_{i=1}^{|\mathcal{D}_{k}|}\left(y_{k}^{i}(c)=1\right)$ is the number of samples labeled as class $c$ used for training the model of local node $k$. It reflects the distribution of local private data corresponding to each particular class $c$. Without loss of generality, we denote the local model output as $z_{k}=f(x_{0},\theta_{k})$ and the central model output as $\tilde{z}=f(x_{0},\theta_{\mathrm{s}})$. Specifically, for the classification task, we have $z_{k}=[z_{k}^{1},...,z_{k}^{C}]$ and $\tilde{z}=[\tilde{z}^{1},...,\tilde{z}^{C}]$ for the local and central models, respectively. Following the aforementioned $\ell_{2}$ observation from [33], we consider the case $\tau\rightarrow\infty$ and express the logit loss as:

$$\mathcal{L}_{\text{w}}({\tilde{z}},{\hat{z}})=\|{\tilde{z}}-{\hat{z}}\|. \qquad (5)$$
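A minimal sketch of the class-specific importance weighting in Eq. 4 and the logit loss of Eq. 5 could look as follows; the variable names, and the assumption that the server knows the per-class sample counts $N_{k}^{c}$, are illustrative.

```python
import torch

def importance_weighted_ensemble(local_logits, class_counts):
    """Sketch of Eq. 4. local_logits: tensor of shape (K, B, C), logits from K
    local models; class_counts: tensor of shape (K, C) holding N_k^c, the number
    of class-c samples used to train the model at node k."""
    counts = class_counts.float()
    w = counts / counts.sum(dim=0, keepdim=True)         # omega_k^c, shape (K, C)
    return (w.unsqueeze(1) * local_logits).sum(dim=0)    # weighted sum over k -> (B, C)

def weighted_logit_loss(z_tilde, z_hat):
    """Sketch of Eq. 5: l2 distance between central and ensembled teacher logits."""
    return (z_tilde - z_hat).norm(dim=-1).mean()
```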

3.2.3 Attention Bounded Ensemble Distillation

The above-mentioned distillation essentially captures the divergence between the final outputs of the teacher and student models. However, it provides little insight into the underlying structure knowledge or reasoning of the teacher models, which can be complementary and important to the final output, especially in the FL scenario with its high degree of heterogeneity in local data sources. Transferring such structural and more comprehensive knowledge, although intuitive, is challenging. Structural knowledge, such as intermediate feature representations, suffers from a high bandwidth burden and, in most cases, relies on identical network architectures for the central (student) and all local (teacher) models. We therefore turn to more concise attention interpretations to transfer knowledge in a more efficient (as opposed to full feature tensors) and effective (as opposed to output vectors only) way, without risking privacy leakage or posing additional communication/architecture requirements.

Specifically, we propose a bounded constraint for attention ensemble distillation to achieve consensus while maintaining the local nodes' inherent diversity. Let $\bm{A}_{k}\in\mathbb{R}^{HW}$ be the attention map produced by the $k$-th local model, where $H$ and $W$ represent the 2D size of the attention map. Given the set of local attention maps $\mathcal{A}=\{\bm{A}_{k}\,|\,k=1,...,K\}$, we take the spatial-wise minimum among them as the attention consensus $\bm{I}$, and the spatial-wise maximum as the attention diversity $\bm{U}$ among $\mathcal{A}$:

$$I^{hw}=\min_{k}A_{k}^{hw}, \qquad U^{hw}=\max_{k}A_{k}^{hw}, \qquad (6)$$

where $h=1,...,H$ and $w=1,...,W$. $\bm{I}$ denotes a consensus on the high-response regions among all the local attention maps, which have a high probability of being the true attention. $\bm{U}$, by considering all the high-response regions among the local attention maps, preserves the diversity of "expertise" among the local models.

For simplicity, we denote by $\tilde{\bm{A}}$ the attention map generated by the central model. Considering the attention consensus $\bm{I}$ as a lower bound, we enforce the response in $\tilde{\bm{A}}$ to explicitly activate at the region of consensus reached by all local nodes:

$$\mathcal{L}_{\text{low}}(\tilde{\bm{A}},\bm{I})=-\frac{1}{HW}\frac{\sum_{h,w}I^{hw}\cdot T(\tilde{A}^{hw})}{\sum_{h,w}I^{hw}}. \qquad (7)$$

$T(\cdot)$ is a soft-masking operation based on the sigmoid function [51]:

$$T({A})=\frac{1}{1+\exp(-\rho({A}-b))}. \qquad (8)$$

Considering the union of high-response regions among local nodes, $\bm{U}$, as the upper bound, we enforce the response of $\tilde{\bm{A}}$ to be explicitly lower than that of $\bm{U}$:

$$\mathcal{L}_{\text{up}}(\tilde{\bm{A}},\bm{U})=-\frac{1}{HW}\frac{\sum_{h,w}\tilde{A}^{hw}\cdot T(U^{hw})}{\sum_{h,w}\tilde{A}^{hw}}. \qquad (9)$$

The intuition here is that each high-response pixel in $\tilde{\bm{A}}$ should have support from at least one local model, thereby accounting for model diversity.

Compared to naive distillation, which enforces exactly the same attention strength as one aggregated attention map over diverse local nodes, our attention-bound constraint is a relaxed version, i.e., it tolerates incorrect or biased local attention maps. Its robustness to outliers enables it to handle the inherent heterogeneity among local nodes more effectively.
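The attention consensus/diversity ensemble (Eq. 6) and the bound losses (Eqs. 7-9) can be sketched as follows; the soft-mask parameters rho and b, and the small epsilon added for numerical stability, are assumed values not specified above.

```python
import torch

def soft_mask(A, rho=8.0, b=0.5):
    """Eq. 8: sigmoid-based soft masking T(A); rho and b are assumed values."""
    return torch.sigmoid(rho * (A - b))

def attention_bound_losses(A_locals, A_central, eps=1e-8):
    """Sketch of Eqs. 6, 7, and 9.
    A_locals: (K, H, W) normalized local attention maps; A_central: (H, W)."""
    I = A_locals.min(dim=0).values        # consensus (lower bound), Eq. 6
    U = A_locals.max(dim=0).values        # diversity (upper bound), Eq. 6
    HW = A_central.numel()
    # Lower-bound loss: the central map should activate wherever all locals agree.
    L_low = -(I * soft_mask(A_central)).sum() / (I.sum() + eps) / HW
    # Upper-bound loss: every central activation should be supported by some local.
    L_up = -(A_central * soft_mask(U)).sum() / (A_central.sum() + eps) / HW
    return L_low, L_up
```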

Figure 3: Illustration of how ensemble attention can effectively guide the central model to focus on the correct region. In ensemble attention visualization, we threshold the activation at 0.5 and color the activated area red, green, and blue for the three locals, respectively. We contour the boundary of attention consensus/diversity region with yellow/black.

3.3 Applications to Different Tasks

3.3.1 Image Classification / Segmentation

For the classification task, the heterogeneity in FL mainly arises from the more general scenario in which local nodes do not share the same set of target classes. We note that Eq. 4 can handle this issue efficiently.

We employ the top-down attention generated with Grad-CAM [22] to provide location cues for class activation reasoning. Feeding an image to the model yields a raw score $z^{c}$ (before the activation layer) for each class $c$. The gradient of the score $z^{c}$ is computed with respect to the feature maps $[\bm{F}_{1},\bm{F}_{2},...,\bm{F}_{J}]$ of a convolutional layer, where $\bm{F}_{j}\in\mathbb{R}^{H\times W}$, $J$ is the channel size, and $H$ and $W$ indicate the size of a 2D feature map. These gradients are globally averaged to obtain the neuron importance $\beta_{j}^{c}$ corresponding to $\bm{F}_{j}$:

$$\beta_{j}^{c}=\frac{1}{HW}\sum_{h}\sum_{w}\frac{\partial z^{c}}{\partial F_{j}^{hw}}. \qquad (10)$$

All $\bm{F}_{j}$, weighted by $\beta_{j}^{c}$, are combined and activated with ReLU to obtain the class-specific attention map $\bm{A}^{c}\in\mathbb{R}^{H\times W}$:

$$\bm{A}^{c}=\text{ReLU}\Big(\sum_{j}{\beta_{j}^{c}\cdot\bm{F}_{j}}\Big). \qquad (11)$$

We then normalize the attention maps so that all values lie between 0 and 1: $A_{hw}^{c}=\frac{A_{hw}^{c}}{\max_{hw}A_{hw}^{c}}$. When applying the attention-bound losses, Eq. 7 and Eq. 9, the class-specific attention maps are taken independently; thus the overall loss function for classification can be written as:

$$\mathcal{L}_{\text{cls}}=\frac{1}{C}\sum_{c}\mathcal{L}_{\text{w}}(\tilde{z}^{c},\hat{z}^{c})+\mathcal{L}_{\text{low}}(\tilde{\bm{A}}^{c},\bm{I}^{c})+\mathcal{L}_{\text{up}}(\tilde{\bm{A}}^{c},\bm{U}^{c}). \qquad (12)$$

Segmentation can be seen as a pixelwise/voxelwise classification task and differs from the above mainly in two aspects: 1) the model's prediction $\bm{z}^{c}\in\mathbb{R}^{HW}$ has the same spatial shape as its input (denoted as 2D here for simplicity); 2) we directly use $\bm{z}^{c}$ for the attention $\bm{A}^{c}=\sigma(\bm{z}^{c},\tau)$ in the attention-bound constraint.
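A minimal sketch of the Grad-CAM attention computation in Eqs. 10-11, followed by the [0, 1] normalization, is given below; it assumes the convolutional features were obtained with gradient tracking enabled.

```python
import torch
import torch.nn.functional as F

def grad_cam_attention(features, class_score, eps=1e-8):
    """Sketch of Eqs. 10-11. features: (J, H, W) feature maps of a chosen
    convolutional layer (part of the autograd graph); class_score: scalar raw
    score z^c for class c, taken before the activation layer."""
    # beta_j^c: globally averaged gradients of z^c w.r.t. each feature map (Eq. 10)
    grads = torch.autograd.grad(class_score, features, retain_graph=True)[0]  # (J, H, W)
    beta = grads.mean(dim=(1, 2))                                             # (J,)
    # Weighted combination of the feature maps, passed through ReLU (Eq. 11)
    A = F.relu((beta[:, None, None] * features).sum(dim=0))                   # (H, W)
    return A / (A.max() + eps)        # normalize so values lie in [0, 1]
```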

3.3.2 Image Reconstruction

For image reconstruction tasks, we rewrite the weighted ensemble as:

$$\hat{\bm{z}}=\sum_{k}{\omega_{k}\cdot\bm{z}_{k}}, \qquad \omega_{k}=\frac{|\mathcal{D}_{k}|}{\sum_{k}|\mathcal{D}_{k}|}, \qquad (13)$$

where $|\mathcal{D}_{k}|$ denotes the number of samples used to train the model at local node $k$, and the model output $\bm{z}$ has the 2D image size. We employ a non-local self-attention module [23] to capture spatial-wise dependencies of the 2D features. For one batch, given the feature maps $\bm{F}$ of size $J\times H\times W$, we reshape $\bm{F}$ to $\bar{\bm{F}}$ (of size $J\times HW$) and then compute the spatial-wise similarity $\bm{S}\in\mathbb{R}^{HW\times HW}$ via dot product (matrix multiplication): $\bm{S}=\bar{\bm{F}}^{\text{T}}\cdot\bar{\bm{F}}$. Then $\bm{S}$ is normalized into the spatial-wise attention $\bm{A}$ using a softmax along the first dimension:

$$A^{hw}=\frac{\exp(S^{hw})}{\sum_{h=1}^{HW}\exp(S^{hw})}. \qquad (14)$$

The normalized similarity is used to enhance the features $\bm{F}$:

$$\bm{F}=\bm{F}+\text{Reshape}\big((\bm{A}\cdot\bar{\bm{F}}^{\text{T}})^{\text{T}}\big), \qquad (15)$$

where $\text{Reshape}(\cdot)$ reshapes a $J\times HW$ tensor back to $J\times H\times W$. The overall loss function for image reconstruction can be written as:

$$\mathcal{L}_{\text{recon}}=\mathcal{L}_{\text{w}}(\tilde{\bm{z}},\hat{\bm{z}})+\mathcal{L}_{\text{low}}(\tilde{\bm{A}},\bm{I})+\mathcal{L}_{\text{up}}(\tilde{\bm{A}},\bm{U}). \qquad (16)$$
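For reference, the non-local self-attention of Eqs. 14-15 can be sketched as below for a single feature tensor; learned projections and the batch dimension of the actual residual non-local block are omitted.

```python
import torch

def non_local_self_attention(F_maps):
    """Sketch of Eqs. 14-15. F_maps: (J, H, W) feature maps; returns the
    residually enhanced features of the same shape."""
    J, H, W = F_maps.shape
    F_bar = F_maps.reshape(J, H * W)            # (J, HW)
    S = F_bar.t() @ F_bar                       # spatial-wise similarity, (HW, HW)
    A = torch.softmax(S, dim=0)                 # Eq. 14: softmax along the first dimension
    enhanced = (A @ F_bar.t()).t()              # (A . F_bar^T)^T -> (J, HW)
    return F_maps + enhanced.reshape(J, H, W)   # Eq. 15: residual enhancement
```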

The overall process is explained in Algorithm 1.

Algorithm 1 FedAD on classification / reconstruction
  Input: Labeled private datasets $\{\mathcal{D}_{k}\}$, unlabeled public data $\mathcal{D}_{0}$, central model $\theta_{\mathrm{s}}$, local models $\{\theta_{k}\}$, $T$ distillation steps.
  Local Training: Train each local model $\theta_{k}$ with $\mathcal{D}_{k}$
  for each distillation step $t=1,...,T$ do
     $\bm{x}_{0}$ $\leftarrow$ a batch of public data from $\mathcal{D}_{0}$
     for $k=1,...,K$ do
        $\bm{z}_{k},\bm{A}_{k}\leftarrow f(\bm{x}_{0},\theta_{k})$            ▷ Eq. 11 / Eq. 14
     end for
     $\hat{\bm{z}}\leftarrow$ ensemble $\{\bm{z}_{k}\}$                        ▷ Eq. 4 / Eq. 13
     $\bm{I},\bm{U}\leftarrow$ ensemble $\{\bm{A}_{k}\}$                       ▷ Eq. 6
     $\tilde{\bm{z}},\tilde{\bm{A}}\leftarrow f(\bm{x}_{0},\theta_{\mathrm{s}})$            ▷ Eq. 11 / Eq. 14
     Update: $\theta_{\mathrm{s}}\leftarrow\theta_{\mathrm{s}}-\nabla_{\theta_{\mathrm{s}}}\mathcal{L}$            ▷ Eq. 12 / Eq. 16
  end for
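A compact PyTorch-style sketch of the central distillation loop in Algorithm 1 follows; the helpers ensemble_logits, ensemble_attention, and loss_fn are placeholders standing in for Eqs. 4/13, Eq. 6, and Eqs. 12/16, and each model is assumed to return its output together with its attention map.

```python
import torch

def fedad_central_distillation(central_model, local_models, public_loader,
                               ensemble_logits, ensemble_attention, loss_fn,
                               steps=5000, lr=1e-2):
    """Sketch of the distillation loop in Algorithm 1: the local (teacher) models
    are frozen; only the central (student) model is updated from public-data batches."""
    opt = torch.optim.SGD(central_model.parameters(), lr=lr)
    public_iter = iter(public_loader)
    for _ in range(steps):
        try:
            x0 = next(public_iter)                    # batch of unlabeled public data
        except StopIteration:
            public_iter = iter(public_loader)
            x0 = next(public_iter)
        outs = [m(x0) for m in local_models]          # each assumed to return (z_k, A_k)
        z_hat = ensemble_logits([z.detach() for z, _ in outs])       # Eq. 4 / Eq. 13
        I, U = ensemble_attention([a.detach() for _, a in outs])     # Eq. 6
        z_tilde, a_tilde = central_model(x0)
        loss = loss_fn(z_tilde, z_hat, a_tilde, I, U)                # Eq. 12 / Eq. 16
        opt.zero_grad()
        loss.backward()
        opt.step()
```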
Table 1: Results on CXR14 and CheXpert with in-/cross-domain local nodes ($K_{d}$=3, $\alpha$=1). Models are tested on NIH CXR14 (12 classes) and CheXpert (8 classes). "Centralized" denotes the result of centralized training with all samples from one dataset. "Standalone" denotes the average performance of all distributed local models. "E.Cardio" abbreviates "Enlarged Cardiomediastinum". When training and testing samples are from different datasets, only the mean AUC over the six overlapping classes is listed (mAUC^overlap). "mAUC^total" denotes the mean AUC over all classes of the two test sets, i.e., 12 for CXR14 and 8 for CheXpert.
NIH CXR14 CheXpert Cross-domain
Centralized Standalone FedAD Centralized Standalone FedAD FedAD
Pathology test^CXR test^CheXpert test^CXR test^CheXpert test^CXR test^CheXpert test^CXR test^CheXpert test^CXR test^CheXpert test^CXR test^CheXpert test^CXR test^CheXpert
Cardiomegaly 86.18 82.72 81.34 64.73 78.4 52.22 84.18 77.94 71.81 62 67.91 72.14 82.35 68.91
Emphysema 89.21 - 84.29 - 81.75 - - - - - - - 82.13 -
Hernia 85.39 - 83.63 - 84.57 - - - - - - - 83.62 -
Infiltration 70.97 - 63.17 - 69.61 - - - - - - - 69.49 -
Mass 78.86 - 73.47 - 74.5 - - - - - - - 74.91 -
Nodule 77.38 - 69.23 - 70.56 - - - - - - - 69.99 -
Atelectasis 76.57 88.17 72.27 71.85 72.6 74.42 72.16 76.64 66.24 76.12 57.72 84.03 71.93 84.41
Pneumothorax 85.56 78.86 81.18 73.64 79.47 78.04 84.9 84.95 74.35 64.7 61.78 64.78 80.82 72.76
Pneumonia 70.77 75.57 68.72 72.07 68.26 79.83 70.54 84.91 65.06 75.9 64.98 90.32 68.4 93.27
Fibrosis 80.35 - 74.12 - 72.31 - - - - - - - 72.39 -
Edema 82.96 85.84 80.6 74.63 80.4 83.37 79.41 88.9 77.33 84.17 75.32 83.82 80.39 80.95
Consolidation 72.6 85.29 68.96 75.68 69.57 89.03 70.93 92.56 63.8 83.85 67.01 92.56 70.08 94.33
E.Cardio. - - - - - - - 61.21 - 57.72 - 83.66 - 79.64
Lung Opacity - - - - - - - 93.5 - 84.89 - 89.65 - 86.95
# class 12 6 12 6 12 6 6 8 6 8 6 8 12 8
mAUC^overlap 79.11 82.74 74.51 72.1 74.78 76.15 77.02 84.32 69.76 74.46 65.79 81.27 75.66 82.44
mAUC^total 79.73 - 75.08 - 75.17 - - 82.58 - 73.67 - 82.62 75.54 82.65
Table 2: Ablation study on $K$/$\alpha$ with single-domain local nodes for chest X-ray classification. "Centralized" denotes the result of centralized training with all samples from one dataset.
Attention Dropout Centralized α=1 α=0.1
Bound Rate K=3 K=5 K=3 K=5
NIH 0 79.73 73.94 73.39 64.26 66.89
0 75.17 75.12 67.08 69.30
CXR14 1/K 73.72 74.01 64.81 67.95
2/K 69.87 73.64 60.93 65.88
CheXpert 0 82.58 81.31 80.76 75.09 76.86
0 82.62 82.71 77.23 79.50
1/K 80.78 81.52 74.42 76.17
2/K 78.25 78.65 72.10 74.61

3.4 Cross-domain Analysis

Our method maintains generalizability while distilling knowledge from multiple locals with cross-domain public data. Built upon the theories from domain adaptation [52, 53, 54], this section gives a theoretical analysis with a performance bound for the aggregated central model.

We denote the input space by $\mathcal{X}$, and the source and target domains by $\mathcal{D}^{S}$ and $\mathcal{D}^{T}$, respectively. Given a hypothesis function $h$ and the ground-truth labeling function $g$, we can define the error as $\epsilon_{\mathcal{D}^{S}}(h,g)=\mathbb{E}_{x\sim\mathcal{D}^{S}}[|h(x)-g(x)|]$, where $\epsilon_{\mathcal{D}^{S}}$ and $\epsilon_{\mathcal{D}^{T}}$ denote the risk of $h$ on $\mathcal{D}^{S}$ and $\mathcal{D}^{T}$. To evaluate the distance between two domain distributions $\mathcal{U}$ and $\mathcal{U}^{\prime}$ on the hypothesis space $\mathcal{H}$, [52] introduces the $\mathcal{H}$-divergence $d_{\mathcal{H}}(\mathcal{U},\mathcal{U}^{\prime})=2\sup_{A\in\mathcal{A}_{\mathcal{H}}}|\operatorname{Pr}_{\mathcal{U}}(A)-\operatorname{Pr}_{\mathcal{U}^{\prime}}(A)|$, where $\mathcal{A}_{\mathcal{H}}$ denotes the collection of subsets of $\mathcal{X}$ that support the hypotheses in $\mathcal{H}$. The symmetric difference space is defined as $\mathcal{H}\Delta\mathcal{H}=\{h(x)\oplus h^{\prime}(x)\,|\,h,h^{\prime}\in\mathcal{H}\}$ ($\oplus$ denotes the XOR operation). We then have the following theorem on the generalization between two domains [53]:

Theorem 1 (Generalization bound). Let $\mathcal{H}$ be a hypothesis space with VC dimension $d$, and let $\mathcal{U}^{S}$ and $\mathcal{U}^{T}$ each be unlabeled samples of size $N$, drawn from $\mathcal{D}^{S}$ and $\mathcal{D}^{T}$ respectively. For any $h\in\mathcal{H}$ and $\delta\in(0,1)$, the following holds with probability at least $1-\delta$ (over the choice of the samples):

$$\epsilon_{\mathcal{D}^{T}}(h)\leq \epsilon_{\mathcal{D}^{S}}(h)+\frac{1}{2}d_{\mathcal{H}\Delta\mathcal{H}}(\mathcal{U}^{S},\mathcal{U}^{T})+4\sqrt{\frac{2d\log(2N)+\log(\frac{2}{\delta})}{N}}+\lambda, \qquad (17)$$

where $\lambda=\epsilon_{\mathcal{D}^{S}}(h^{*})+\epsilon_{\mathcal{D}^{T}}(h^{*})$ and $h^{*}$ is the ideal joint hypothesis minimizing the combined error: $h^{*}=\operatorname{argmin}_{h\in\mathcal{H}}\,\epsilon_{\mathcal{D}^{S}}(h)+\epsilon_{\mathcal{D}^{T}}(h)$.

In our case, $\mathcal{D}^{S}$ is the domain of the private data across the $K$ local nodes, $\mathcal{D}^{S}=\{\mathcal{D}^{k}\}$, and $\mathcal{D}^{T}=\mathcal{D}^{0}$ is the domain of the public data, where we assume $|\mathcal{D}^{0}|=N$ and $\sum_{k}|\mathcal{D}^{k}|=N$. Given a local model $h_{\mathcal{D}^{k}}$ trained on data $\mathcal{D}^{k}$, we learn a central model $h_{\mathcal{D}^{0}}$ with the unlabeled public data $\mathcal{D}^{0}$ through weighted aggregation: $h_{\mathcal{D}^{0}}=\sum_{k}\omega_{k}h_{\mathcal{D}^{k}}$, where $\sum_{k}\omega_{k}=1$. As proved in [55], the overall private data distribution is $\mathcal{U}^{S}=\sum_{k}\omega_{k}\mathcal{U}^{k}$, and $\frac{1}{2}d_{\mathcal{H}\Delta\mathcal{H}}(\sum_{k}\omega_{k}\mathcal{U}^{k},\mathcal{U}^{0})\leq\sum_{k}\omega_{k}\big(\frac{1}{2}d_{\mathcal{H}\Delta\mathcal{H}}(\mathcal{U}^{k},\mathcal{U}^{0})\big)$. We then rewrite Eq. 17 to obtain the weighted generalization bound of Eq. 18, which shows that the test error of the central model, $\epsilon_{\mathcal{D}^{0}}$, is bounded by the error of the local models $\epsilon_{\mathcal{D}^{k}}$, the domain gap between local and public data $d_{\mathcal{H}\Delta\mathcal{H}}(\mathcal{U}^{k},\mathcal{U}^{0})$, the VC dimension $d$ of the hypothesis space, and the data size $N$.

$$\begin{aligned}
\epsilon_{\mathcal{D}^{0}}(h_{\mathcal{D}^{0}}) &\leq \epsilon_{\mathcal{D}^{k}}\Big(\sum_{k}\omega_{k}h_{\mathcal{D}^{k}}\Big)+\frac{1}{2}d_{\mathcal{H}\Delta\mathcal{H}}\Big(\sum_{k}\omega_{k}\mathcal{U}^{k},\mathcal{U}^{0}\Big)+4\sqrt{\frac{2d\log(2N)+\log(\frac{2}{\delta})}{N}}+\lambda_{\omega} \\
&\leq \epsilon_{\mathcal{D}^{k}}\Big(\sum_{k}\omega_{k}h_{\mathcal{D}^{k}}\Big)+\sum_{k}\omega_{k}\Big(\frac{1}{2}d_{\mathcal{H}\Delta\mathcal{H}}(\mathcal{U}^{k},\mathcal{U}^{0})\Big)+4\sqrt{\frac{2d\log(2N)+\log(\frac{2}{\delta})}{N}}+\lambda_{\omega}.
\end{aligned} \qquad (18)$$

4 Experiments

For communication efficiency, we employ one-shot distillation for bandwidth-sensitive tasks such as segmentation and reconstruction: each local model transfers its predictions on the public data only once, and these local products are reused for numerous steps of central training.

4.1 Chest X-Ray Image Classification

4.1.1 Datasets

We evaluate our method on a multi-label classification task with standard chest X-ray datasets, using NIH ChestX-ray14 (NIH CXR14) [56] and CheXpert [57] as locally held private data. NIH CXR14 consists of 112,120 frontal-view X-ray images scanned from 32,717 patients and labeled with 14 diseases. CheXpert contains 224,316 chest radiographs (both frontal and lateral views) from 65,240 patients, labeled with 14 common observations. For public data, we use 26,684 X-ray images from the RSNA Pneumonia Detection Challenge dataset [58], without using their labels.

4.1.2 Implementation

We use NIH CXR14 and CheXpert as the domains from which private data are drawn. For samples with multiple positive labels, we randomly choose one, and we split each dataset across local nodes using the Dirichlet distribution as in most FL works [4], in which the value of $\alpha$ controls the degree of non-IID-ness; a smaller $\alpha$ indicates higher non-IID-ness. For a total of $K$ local nodes, each dataset is distributed to $K_{d}=K/2$ local nodes under the cross-domain scenario. Following the validation strategy in [59], for both datasets we randomly sample a fraction (10%) of the training data to form the validation set. For training, we use ResNet-34 with a batch size of 32 and the same data augmentation methods as in [59]. Each local model is trained individually with SGD and cosine annealing [60], with the learning rate decreasing from 1e-3 to 1e-6 over 15 epochs. For distillation, we use SGD and cosine annealing with the learning rate decreasing from 1e-2 to 1e-3 over 20 epochs.
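For clarity, a common recipe for the Dirichlet non-IID split described above is sketched below; it follows the usual practice of [4], and exact details (seeding, handling of empty shards) may differ from our implementation.

```python
import numpy as np

def dirichlet_partition(labels, num_nodes, alpha, seed=0):
    """Split sample indices across num_nodes local nodes with per-class
    Dirichlet(alpha) proportions; smaller alpha yields a more non-IID split.
    labels: 1D array with one class label per sample."""
    rng = np.random.default_rng(seed)
    node_indices = [[] for _ in range(num_nodes)]
    for c in np.unique(labels):
        idx_c = np.where(labels == c)[0]
        rng.shuffle(idx_c)
        proportions = rng.dirichlet(alpha * np.ones(num_nodes))   # per-node share of class c
        cut_points = (np.cumsum(proportions)[:-1] * len(idx_c)).astype(int)
        for node_id, shard in enumerate(np.split(idx_c, cut_points)):
            node_indices[node_id].extend(shard.tolist())
    return node_indices
```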

Table 3: Ablation study on $|\mathcal{D}_{0}|$ ($K=3$, $\alpha=1$) using CheXpert as local datasets.
# samples in $\mathcal{D}_{0}$
Centralized 1000 5000 10000 15000 20000 26684
82.58 81.08 81.23 81.36 81.57 82.37 82.62
Table 4: Comparison with parameter-based and distillation-based FL methods on the chest X-ray image classification task using CheXpert as local datasets ($K=3$, $\alpha=1$). "Centralized" denotes the result of centralized training with all samples from one dataset.
Centralized FedAvg [2] FedDF [20] FedMD [17] Ours
Distillation - N Y Y Y
Param. Trans. - Y Y N N
Privacy -
Asynchronous -
mAUC(%) 82.58 79.03 82.94 77.66 82.62

4.1.3 Results

We first study the cases where local samples come from one dataset. Table 2 shows results when local data are within a single domain, for a varying number of local nodes $K$ and non-IID-ness $\alpha$. Notably, our method outperforms centralized training with all local data on CheXpert when $K$=3 and $\alpha$=1. Table 3 shows training results with varying public dataset sizes. The results suggest that, although unlabeled, a more extensive public dataset improves performance. To compare with state-of-the-art FL methods, Table 4 indicates whether each method transfers parameters or employs distillation and summarizes its privacy guarantee and synchronization requirements. We can see that our method outperforms the counterparts with the best utility-privacy trade-off.

With both datasets as private data, we conduct cross-domain, cross-site evaluations with FedAD. We distribute the training datasets to $K$=6 local nodes ($K_{d}$=3 for each dataset). Table 1 shows the results of FedAD on cross-domain datasets with $\alpha$=1. On CheXpert, both single-domain and cross-domain distillation achieve better performance than centralized training: cross-domain learning with FedAD obtains the best mAUC of 82.65%, slightly better than FedAD trained with only data from CheXpert (82.62%). On NIH CXR14, while training with all data centrally yields the best result (79.73%), cross-domain trained FedAD still obtains better performance (75.54%) than FedAD trained with single-domain data (75.17%) and than the average AUC of local nodes (75.08%). Table 1 also reports results on the six overlapping classes in each test set. FedAD trained with cross-domain data obtains superior performance compared to the two FedAD models trained with single-domain data on both test datasets. We note that the FedAD model trained with cross-domain data can classify 14 classes in total (the 12 classes of NIH CXR14 and the 8 classes of CheXpert, of which 6 overlap), whereas the other models can only classify 8 or 12 classes.

Table 5: Comparisons on BraTS in terms of average Dice score over voxel-level annotations of “whole tumour”, “tumour core”, “enhancing tumour”, and communication efficiency attributes.
Method Average Communication Efficiency Transmit Privacy
Dice(%)\uparrow Bandwidth(GB)\downarrow Asynchronous Cost/Risk\downarrow
Li et al. [13] 84.33 64.37 Parameter ∞
81.28 ε1=1, ε3=1
80.01 ε1=1, ε3=0.01
FedMD [17] 75.71 2154.84 Distillation 0
Ours 77.85 13.36 Distillation 0
Standalone 73.38 ± 3.44 - - - -
Table 6: Ablation study on output ensemble scheme, attention lower/upper bound, and the modality of public data. T1-weighted images from B, F, I training set are used as local data, and T1-weighted images from B, F, I testing set are used as evaluation.
Ensemble scheme [20, 12] Eq.13 Eq.13 Eq.13 Eq.13 Eq.13
Non-local module (Eq.15)
Attention lower bound (Eq.7)
Attention upper bound (Eq.9)
Unlabeled public data $\mathcal{D}_{0}$ T1w T1w T1w T1w T1w T2w
SSIM \uparrow 0.8892 0.9097 0.9108 0.9147 0.9161 0.9112
PSNR \uparrow 32.91 33.20 33.24 33.30 33.38 33.05

4.2 3D Brain Tumor Segmentation

4.2.1 Dataset

The BraTS 2018 dataset [61] contains multi-parametric pre-operative MRI scans of 285 subjects with brain tumors. Each subject was scanned under the T1-weighted, T1-weighted with contrast enhancement, T2-weighted, and T2 fluid-attenuated inversion recovery (T2-FLAIR) modalities.

4.2.2 Implementation

Following the same protocol as in [13], we use 242 subjects for the training set and 43 subjects for the held-out test set. The training set is stratified into three federated local clients according to the originating institution. We use the unlabeled validation set of the BraTS 2020 dataset [62], comprising 125 subjects, as the public data, independent of the private datasets. We use the same model structure as [13, 63] but with half the channel numbers, and we only use its segmentation branch (no reconstruction branch). The training strategy is the same as [63]. The local training takes 20,000 iterations, and local-to-central distillation takes 5,000 iterations.

4.2.3 Results

Table 5 compares the segmentation performance on BraTS with the state-of-the-art parameter-based method [13] and the state-of-the-art distillation-based method [17]. Note that we report both the naive (non-private) and noisy (less private) versions of [13]. The comparison shows that our method achieves the best utility-privacy trade-off. We also observe that, compared with [17], our method achieves better results with much higher efficiency, i.e., lower communication bandwidth and no local synchronization requirement.

Table 7: Results on in-domain testing sets for MRI image reconstruction. “Standalone” denotes the locally trained model with individual private data. Under the FL setting, we compare our method with FL-MRCM [64] quantitatively in terms of SSIM, PSNR, and communication Bandwidth.
Privacy Flexibility Data T1-weighted T2-weighted Bandwidth
Train Test SSIM\uparrow PSNR\uparrow SSIM\uparrow PSNR\uparrow (GB)\downarrow
Standalone - - B B 0.9743 38.81 0.9694 36.53 -
F 0.7787 29.16 0.8028 27.35
I 0.8948 31.02 0.7692 27.47
F B 0.9125 33.82 0.9250 33.96
F 0.9360 33.83 0.9522 33.84
I 0.9146 31.26 0.9003 30.67
I B 0.9421 34.98 0.9111 31.76
F 0.8919 31.80 0.9092 30.61
I 0.9615 34.91 0.9336 32.14
Centralized - - B,F,I B 0.9557 37.19 0.9398 34.08 -
F 0.9335 34.58 0.9002 30.47
I 0.9451 33.55 0.8873 30.95
[64] Online B,F,I B 0.9577 36.88 0.9308 34.28 868.8
F 0.9023 33.63 0.8974 31.24
I 0.9362 33.29 0.8778 30.44
Ours Offline B,F,I B 0.9111 34.55 0.9199 34.06 866.0
F 0.9182 33.37 0.9374 32.76
I 0.9173 31.72 0.9058 30.93

4.3 Brain Magnetic Resonance Image Reconstruction

4.3.1 Datasets

Following the prior art [64], we use fastMRI [65], IXI [66], and BraTS [67] as private data distributed across local nodes and evaluate on the corresponding test sets (same data split as [64]). Guo et al. [64] report results on four brain MRI datasets, of which HPKS [68] is privately held and not available for use at the moment, so we experiment with the remaining publicly available sets as local data: fastMRI, IXI, and BraTS. In addition, we use OASIS-3 [69] as the public dataset.

FastMRI [65] (abbreviated as F): for fastMRI, we use T1-weighted images from 3,443 subjects, of which 2,583 are used for training and 860 for testing. We also use T2-weighted images from 3,832 subjects, of which 2,874 are used for training and the remaining 958 for testing. Each subject consists of approximately 15 axial cross-sectional images of brain tissues.

BraTS [67] (abbreviated as B): BraTS2020 is composed of 494 subjects available for both T1 and T2-weighted modalities. Of these, 369 subjects are used for training and 125 subjects for testing. Each subject includes approximately 120 axial cross-sectional images of brain tissues for both modalities.

IXI [66] (abbreviated as I): IXI dataset has 581 subjects available for the T1-weighted modality, among which 436, 55, and 90 subjects are used for training, validation, and testing respectively. For the T2-weighted modality, there are 578 subjects, of which data from 434 subjects are for training, 55 for validation, and 89 for testing. Approximately 150 and 130 axial cross-sectional images of brain tissues for T1 and T2-weighted, respectively, are provided for each subject.

OASIS-3 [69]: Open Access Series of Imaging Studies (OASIS-3) is a multi-modal dataset. It contains 3,388 subjects for T1w and 3,598 subjects for T2w. All sessions are collected with a 16-channel head coil on 1.5T scanners.

Figure 4: Qualitative results of different methods when training with T1/T2-weighted B, F, I as local data and testing on T1 B, T2 B, T1 F, T2 F, T1 I, T2 I test set respectively. The second column of each sub-figure is the error map (absolute difference) between the reconstructed images and the ground truth.
Table 8: Results on cross-domain testing sets for MRI image reconstruction.
Method Transmit Privacy Data T1-weighted T2-weighted
Train Test SSIM\uparrow PSNR\uparrow # test subjects wAverage SSIM\uparrow PSNR\uparrow # test subjects wAverage
SSIM\uparrow PSNR\uparrow SSIM\uparrow PSNR\uparrow
FedAvg[2] Parameter B,F I 0.9086 31.46 90 0.8533 31.49 0.8532 29.42 89 0.8608 30.36
F,I B 0.9268 34.69 125 0.9062 33.11 128
B,I F 0.8369 31.03 860 0.8556 30.09 958
FL-MRCM[64] Parameter B,F I 0.9157 31.74 90 0.8553 32.39 0.8354 29.28 89 0.8632 30.00
F,I B 0.9486 35.78 125 0.9041 33.15 128
B,I F 0.8354 31.96 860 0.8604 29.66 958
FedDF[20] Parameter B,F I 0.9209 32.11 90 0.8665 32.75 0.8781 30.20 89 0.8819 30.77
++ F,I B 0.9561 35.92 125 0.9154 33.96 128
Distillation B,I F 0.8479 32.36 860 0.8775 30.37 958
Ours Distillation B,F I 0.9141 31.26 90 0.8521 31.54 0.8883 29.92 89 0.8827 30.59
F,I B 0.9052 33.25 125 0.8964 33.24 128
B,I F 0.8533 31.62 860 0.8805 30.31 958
Centralized - - B,F I 0.9015 31.03 90 0.8827 32.76 0.8763 29.22 89 0.8846 30.56
F,I B 0.9246 34.75 125 0.9045 33.07 128
B,I F 0.8747 32.65 860 0.8827 30.33 958

4.3.2 Implementation

Following [64], we subsample the given k-space by multiplying it with a mask, where the acceleration factor (AF) is set to 4. The 2D MRI images are preprocessed with zero padding and then cropped to the size of $256\times 256$. We utilize the same U-Net [70] style encoder-decoder architecture for the reconstruction networks as the one provided in FL-MRCM [64]. A minor architectural difference with respect to [64] is the additional residual non-local block (Eq. 15) deployed on the bottom features of the U-Net before forwarding them into the sequence of up-sampling layers. The size $J\times H\times W$ of these features is $512\times 16\times 16$. For local training, the network is trained with an Adam optimizer using a constant learning rate of 1e-4 for 20 epochs. For distillation, the central model is trained with an RMSprop optimizer using a constant learning rate of 1e-4 for five epochs in one communication round.
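The k-space subsampling step described above can be sketched as follows; the binary sampling mask (which determines the acceleration factor, here AF = 4) is assumed to be generated elsewhere, e.g., by the standard Cartesian line selection used for fastMRI-style experiments.

```python
import torch

def undersample_kspace(image, mask):
    """Simulate accelerated acquisition: FFT the fully sampled image, keep only
    the k-space locations selected by the binary mask, and return the magnitude
    of the zero-filled inverse FFT used as the reconstruction network's input.
    image: real-valued (H, W) tensor; mask: binary (H, W) tensor."""
    kspace = torch.fft.fft2(image)              # fully sampled k-space
    kspace_us = kspace * mask                   # retain roughly 1/AF of the samples
    return torch.fft.ifft2(kspace_us).abs()     # aliased, zero-filled input image
```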

4.3.3 Results

We first perform ablation studies on the impact of the output ensemble scheme, the proposed attention distillation bounds, and the modality of the unlabeled public data. Here we use the T1-weighted images from BraTS (B), fastMRI (F), and IXI (I) as private datasets and unlabeled T1-weighted images from OASIS-3 as public data, and we report results on the aggregated T1-weighted test images from B, F, and I. From Table 6 we can see the superiority of the importance-weighted ensemble over the average ensemble typically used in previous FL works [20, 12]. Our proposed attention upper/lower bound constraints further improve the reconstruction performance, with higher SSIM and PSNR. In addition, Table 6 compares using T1-weighted and T2-weighted images of OASIS-3 as public data. The results demonstrate the robustness of our method to public data from a different domain (the locally held data and the data used for distillation all come from different datasets) and even a different modality (all three local nodes train on T1-weighted images while the public data is T2-weighted).

Second, we compare the performance with the prior art [64] when taking T1/T2-weighted images from B, F, and I as local data and unlabeled T1/T2-weighted images of OASIS-3 as public data. The results are reported on the corresponding test sets of the local data. From Table 7 we can see that, compared with the prior art [64], our method achieves very competitive reconstruction performance in terms of SSIM and PSNR with higher communication efficiency, while maintaining local data privacy by not sharing local model parameters/gradients or any product inferred from local private data. The counterpart [64] not only iteratively shares local model parameters but also shares the features inferred on each local private dataset. Besides its superior guarantees with respect to data privacy, our method demonstrates higher communication efficiency through lower bandwidth and higher flexibility (offline communication without any synchronization requirements on the local models). Notably, when testing on T2-weighted F and I, we achieve better performance than centralized training (collecting all local data together for training); e.g., on T2-weighted F, we achieve 0.9374/32.76 SSIM/PSNR versus the 0.9002/30.47 SSIM/PSNR of centralized training. The reason is that we utilize additional unlabeled, non-sensitive, cross-domain public data, which, we assume, are easily acquired in real-world clinical scenarios. Qualitative comparisons are shown in Figure 4.

We leverage cross-domain test sets (domains different from the local data) to evaluate generalizability. Table 8 compares our method with the state-of-the-art FL methods. It shows that our privacy-preserving method achieves generalizability comparable to the prior arts [2, 64, 20], which iteratively share local model parameters and therefore risk privacy leakage.

5 Conclusions

In this work, we propose a novel distillation-based federated learning framework (FedAD) that can, in principle, preserve local data privacy by using only unlabeled and domain-robust public data. To comprehensively address the communication bottleneck, we employ a one-way (offline) knowledge distillation process with an importance-weighted ensemble and attention-bound constraints. We demonstrate that our proposed attention ensemble scheme can balance the consensus and diversity across local nodes to handle the inherent heterogeneity in FL scenarios. Extensive experiments on various medical image analysis and imaging tasks, including classification, segmentation, and MR reconstruction, using cross-domain and heterogeneous data distributions, highlight the efficacy of FedAD and its preservation of local data privacy. With privacy being a critical topic for real-world medical applications, we believe our proposed FL framework can facilitate privacy-abiding learning across hospital sites and extend to other medical image applications such as object detection and instance segmentation. Future work includes further generalizing the FedAD framework to be more task-agnostic, and relaxing or eliminating the requirement of real data for distillation.

References

  • [1] H. Li, K. Ota, and M. Dong, “Learning iot in edge: Deep learning for the internet of things with edge computing,” IEEE Network, vol. 32, no. 1, pp. 96–101, 2018.
  • [2] B. McMahan, E. Moore, D. Ramage, S. Hampson, and B. A. y Arcas, “Communication-efficient learning of deep networks from decentralized data,” in Artificial Intelligence and Statistics.   PMLR, 2017, pp. 1273–1282.
  • [3] V. Smith, C.-K. Chiang, M. Sanjabi, and A. S. Talwalkar, “Federated multi-task learning,” in NeurIPS, 2017, pp. 4424–4434.
  • [4] T.-M. H. Hsu, H. Qi, and M. Brown, “Measuring the effects of non-identical data distribution for federated visual classification,” arXiv:1909.06335, 2019.
  • [5] T. Li, A. K. Sahu, M. Zaheer, M. Sanjabi, A. Talwalkar, and V. Smith, “Federated optimization in heterogeneous networks,” arXiv:1812.06127, 2018.
  • [6] Y. Zhao, M. Li, L. Lai, N. Suda, D. Civin, and V. Chandra, “Federated learning with non-iid data,” arXiv:1806.00582, 2018.
  • [7] H. Wang, M. Yurochkin, Y. Sun, D. Papailiopoulos, and Y. Khazaeni, “Federated learning with matched averaging,” ICLR, 2020.
  • [8] S. P. Karimireddy, S. Kale, M. Mohri, S. J. Reddi, S. U. Stich, and A. T. Suresh, “Scaffold: Stochastic controlled averaging for on-device federated learning,” ICML, 2020.
  • [9] M. J. Sheller, G. A. Reina, B. Edwards, J. Martin, and S. Bakas, “Multi-institutional deep learning modeling without sharing patient data: A feasibility study on brain tumor segmentation,” in MICCAI Brainlesion Workshop, 2018, pp. 92–104.
  • [10] D. Yang, Z. Xu, W. Li, A. Myronenko, H. R. Roth, S. Harmon, S. Xu, B. Turkbey, E. Turkbey, X. Wang et al., “Federated semi-supervised learning for COVID region segmentation in chest CT using multi-national data from China, Italy, Japan,” Medical Image Analysis, p. 101992, 2021.
  • [11] T. Li, M. Sanjabi, A. Beirami, and V. Smith, “Fair resource allocation in federated learning,” ICLR, 2020.
  • [12] T.-M. H. Hsu, H. Qi, and M. Brown, “Federated visual classification with real-world data distribution,” ECCV, 2020.
  • [13] W. Li, F. Milletarì, D. Xu, N. Rieke, J. Hancox, W. Zhu, M. Baust, Y. Cheng, S. Ourselin, M. J. Cardoso et al., “Privacy-preserving federated brain tumour segmentation,” in International Workshop on Machine Learning in Medical Imaging, 2019, pp. 133–141.
  • [14] X. Li, Y. Gu, N. Dvornek, L. H. Staib, P. Ventola, and J. S. Duncan, “Multi-site fmri analysis using privacy-preserving federated learning and domain adaptation: Abide results,” Medical Image Analysis, vol. 65, p. 101765, 2020.
  • [15] L. Zhu, Z. Liu, and S. Han, “Deep leakage from gradients,” in NeurIPS, 2019.
  • [16] J. Geiping, H. Bauermeister, H. Dröge, and M. Moeller, “Inverting gradients – how easy is it to break privacy in federated learning?” in NeurIPS, 2020.
  • [17] D. Li and J. Wang, “Fedmd: Heterogenous federated learning via model distillation,” arXiv:1910.03581, 2019.
  • [18] H. Chang, V. Shejwalkar, R. Shokri, and A. Houmansadr, “Cronus: Robust and heterogeneous collaborative learning with black-box knowledge transfer,” arXiv:1912.11279, 2019.
  • [19] E. Jeong, S. Oh, H. Kim, J. Park, M. Bennis, and S.-L. Kim, “Communication-efficient on-device machine learning: Federated distillation and augmentation under non-iid private data,” arXiv:1811.11479, 2018.
  • [20] T. Lin, L. Kong, S. U. Stich, and M. Jaggi, “Ensemble distillation for robust model fusion in federated learning,” NeurIPS, 2020.
  • [21] X. Gong, A. Sharma, S. Karanam, Z. Wu, T. Chen, D. Doermann, and A. Innanje, “Ensemble attention distillation for privacy-preserving federated learning,” in ICCV, 2021.
  • [22] R. R. Selvaraju, M. Cogswell, A. Das, R. Vedantam, D. Parikh, and D. Batra, “Grad-cam: Visual explanations from deep networks via gradient-based localization,” in ICCV, 2017, pp. 618–626.
  • [23] X. Wang, R. Girshick, A. Gupta, and K. He, “Non-local neural networks,” in CVPR, 2018, pp. 7794–7803.
  • [24] Y. Zhou, G. Pu, X. Ma, X. Li, and D. Wu, “Distilled one-shot federated learning,” arXiv:2009.07999, 2020.
  • [25] M. Shin, C. Hwang, J. Kim, J. Park, M. Bennis, and S.-L. Kim, “Xor mixup: Privacy-preserving data augmentation for one-shot federated learning,” arXiv:2006.05148, 2020.
  • [26] N. Guha, A. Talwalkar, and V. Smith, “One-shot federated learning,” arXiv:1902.11175, 2019.
  • [27] N. Papernot, M. Abadi, U. Erlingsson, I. Goodfellow, and K. Talwar, “Semi-supervised knowledge transfer for deep learning from private training data,” arXiv:1610.05755, 2016.
  • [28] Z.-H. Zhou, J. Wu, and W. Tang, “Ensembling neural networks: Many could be better than all,” Artificial Intelligence, vol. 137, no. 1-2, pp. 239–263, 2002.
  • [29] R. Caruana, A. Niculescu-Mizil, G. Crew, and A. Ksikes, “Ensemble selection from libraries of models,” in ICML, 2004, p. 18.
  • [30] R. Avnimelech and N. Intrator, “Boosted mixture of experts: An ensemble learning scheme,” Neural Computation, vol. 11, no. 2, pp. 483–497, 1999.
  • [31] C. M. Bishop and M. Svensén, “Bayesian hierarchical mixtures of experts,” arXiv:1212.2447, 2012.
  • [32] R. A. Jacobs, M. I. Jordan, S. J. Nowlan, and G. E. Hinton, “Adaptive mixtures of local experts,” Neural Computation, vol. 3, no. 1, pp. 79–87, 1991.
  • [33] G. Hinton, O. Vinyals, and J. Dean, “Distilling the knowledge in a neural network,” arXiv:1503.02531, 2015.
  • [34] N. Shazeer, A. Mirhoseini, K. Maziarz, A. Davis, Q. Le, G. Hinton, and J. Dean, “Outrageously large neural networks: The sparsely-gated mixture-of-experts layer,” ICLR, 2017.
  • [35] U. Asif, J. Tang, and S. Harrer, “Ensemble knowledge distillation for learning improved and efficient networks,” arXiv:1909.08097, 2019.
  • [36] L. Xiang, G. Ding, and J. Han, “Learning from multiple experts: Self-paced knowledge distillation for long-tailed classification,” in ECCV, 2020.
  • [37] A. Wu, W.-S. Zheng, X. Guo, and J.-H. Lai, “Distilled person re-identification: Towards a more scalable system,” in CVPR, 2019.
  • [38] S. You, C. Xu, C. Xu, and D. Tao, “Learning from multiple teacher networks,” in KDD, 2017.
  • [39] G. Song and W. Chai, “Collaborative learning for deep neural networks,” in NeurIPS, 2018, pp. 1832–1841.
  • [40] X. Zhu, S. Gong et al., “Knowledge distillation by on-the-fly native ensemble,” in NeurIPS, 2018, pp. 7517–7527.
  • [41] Q. Guo, X. Wang, Y. Wu, Z. Yu, D. Liang, X. Hu, and P. Luo, “Online knowledge distillation via collaborative learning,” in CVPR, 2020.
  • [42] A. Romero, N. Ballas, S. E. Kahou, A. Chassang, C. Gatta, and Y. Bengio, “FitNets: Hints for thin deep nets,” ICLR, 2015.
  • [43] S. Zagoruyko and N. Komodakis, “Paying more attention to attention: Improving the performance of convolutional neural networks via attention transfer,” ICLR, 2017.
  • [44] P. Dhar, R. V. Singh, K.-C. Peng, Z. Wu, and R. Chellappa, “Learning without memorizing,” in CVPR, 2019, pp. 5138–5146.
  • [45] J. Yim, D. Joo, J. Bae, and J. Kim, “A gift from knowledge distillation: Fast optimization, network minimization and transfer learning,” in CVPR, 2017, pp. 4133–4141.
  • [46] Z. Huang and N. Wang, “Like what you like: Knowledge distill via neuron selectivity transfer,” arXiv:1707.01219, 2017.
  • [47] N. Passalis, M. Tzelepi, and A. Tefas, “Heterogeneous knowledge distillation using information flow modeling,” in CVPR, 2020.
  • [48] S. Park and N. Kwak, “FEED: Feature-level ensemble for knowledge distillation,” AAAI, 2020.
  • [49] T. Furlanello, Z. C. Lipton, M. Tschannen, L. Itti, and A. Anandkumar, “Born again neural networks,” ICML, 2018.
  • [50] I.-J. Liu, J. Peng, and A. G. Schwing, “Knowledge flow: Improve upon your teachers,” ICLR, 2019.
  • [51] K. Li, Z. Wu, K.-C. Peng, J. Ernst, and Y. Fu, “Tell me where to look: Guided attention inference network,” in CVPR, 2018.
  • [52] S. Ben-David, J. Blitzer, K. Crammer, A. Kulesza, F. Pereira, and J. W. Vaughan, “A theory of learning from different domains,” Machine learning, vol. 79, no. 1, pp. 151–175, 2010.
  • [53] J. Blitzer, K. Crammer, A. Kulesza, F. Pereira, and J. Wortman, “Learning bounds for domain adaptation,” in NeurIPS, 2008.
  • [54] B. Planche, “Bridging the realism gap for CAD-based visual recognition,” Ph.D. dissertation, Universität Passau, 2020.
  • [55] X. Peng, Z. Huang, Y. Zhu, and K. Saenko, “Federated adversarial domain adaptation,” in ICLR, 2020.
  • [56] X. Wang, Y. Peng, L. Lu, Z. Lu, M. Bagheri, and R. M. Summers, “ChestX-ray8: Hospital-scale chest x-ray database and benchmarks on weakly-supervised classification and localization of common thorax diseases,” in CVPR, 2017, pp. 3462–3471.
  • [57] J. Irvin, P. Rajpurkar, M. Ko, Y. Yu, S. Ciurea-Ilcus, C. Chute, H. Marklund, B. Haghgoo, R. Ball, K. Shpanskaya et al., “CheXpert: A large chest radiograph dataset with uncertainty labels and expert comparison,” in AAAI, vol. 33, 2019, pp. 590–597.
  • [58] Radiological Society of North America, “RSNA Pneumonia Detection Challenge dataset,” 2018. [Online]. Available: https://www.kaggle.com/c/rsna-pneumonia-detection-challenge
  • [59] W. Ye, J. Yao, H. Xue, and Y. Li, “Weakly supervised lesion localization with Probabilistic-CAM pooling,” 2020.
  • [60] I. Loshchilov and F. Hutter, “SGDR: Stochastic gradient descent with warm restarts,” arXiv:1608.03983, 2016.
  • [61] S. Bakas, M. Reyes, A. Jakab, S. Bauer, M. Rempfler, A. Crimi, R. T. Shinohara, C. Berger, S. M. Ha, M. Rozycki et al., “Identifying the best machine learning algorithms for brain tumor segmentation, progression assessment, and overall survival prediction in the BraTS challenge,” arXiv:1811.02629, 2018.
  • [62] B. H. Menze, A. Jakab, S. Bauer, J. Kalpathy-Cramer, K. Farahani, J. Kirby, Y. Burren, N. Porz, J. Slotboom, R. Wiest et al., “The multimodal brain tumor image segmentation benchmark (BraTS),” IEEE Transactions on Medical Imaging, vol. 34, no. 10, pp. 1993–2024, 2014.
  • [63] A. Myronenko, “3D MRI brain tumor segmentation using autoencoder regularization,” in MICCAI Brainlesion Workshop, 2018, pp. 311–320.
  • [64] P. Guo, P. Wang, J. Zhou, S. Jiang, and V. M. Patel, “Multi-institutional collaborations for improving deep learning-based magnetic resonance image reconstruction using federated learning,” in CVPR, 2021.
  • [65] J. Zbontar, F. Knoll, A. Sriram, T. Murrell, Z. Huang, M. J. Muckley, A. Defazio, R. Stern, P. Johnson, M. Bruno, M. Parente, K. J. Geras, J. Katsnelson, H. Chandarana, Z. Zhang, M. Drozdzal, A. Romero, M. Rabbat, P. Vincent, N. Yakubova, J. Pinkerton, D. Wang, E. Owens, C. L. Zitnick, M. P. Recht, D. K. Sodickson, and Y. W. Lui, “fastMRI: An open dataset and benchmarks for accelerated MRI,” 2018.
  • [66] Imperial College London, “IXI dataset.” [Online]. Available: https://brain-development.org/
  • [67] S. S. Bakas, “BraTS MICCAI brain tumor dataset,” 2020. [Online]. Available: https://dx.doi.org/10.21227/hdtd-5j88
  • [68] S. Jiang, C. G. Eberhart, M. Lim, H.-Y. Heo, Y. Zhang, L. Blair, Z. Wen, M. Holdhoff, D. Lin, P. Huang et al., “Identifying recurrent malignant glioma after treatment using amide proton transfer-weighted MR imaging: A validation study with image-guided stereotactic biopsy,” Clinical Cancer Research, vol. 25, no. 2, pp. 552–561, 2019.
  • [69] P. J. LaMontagne, T. L. Benzinger, J. C. Morris, S. Keefe, R. Hornbeck, C. Xiong, E. Grant, J. Hassenstab, K. Moulder, A. G. Vlassenko et al., “OASIS-3: Longitudinal neuroimaging, clinical, and cognitive dataset for normal aging and Alzheimer disease,” medRxiv, 2019.
  • [70] O. Ronneberger, P. Fischer, and T. Brox, “U-net: Convolutional networks for biomedical image segmentation,” in MICCAI, 2015.