
When Foresight Pruning Meets Zeroth-Order Optimization: Efficient Federated Learning for Low-Memory Devices

Pengyu Zhang, Yingjie Liu, Yingbo Zhou, Xiao Du, Xian Wei, Ting Wang, Mingsong Chen
East China Normal University
Abstract.

Although Federated Learning (FL) enables collaborative learning in Artificial Intelligence of Things (AIoT) design, it fails to work on low-memory AIoT devices due to its heavy memory usage. To address this problem, various federated pruning methods have been proposed to reduce memory usage during inference. However, few of them can substantially mitigate the memory burdens during pruning and training. As an alternative, zeroth-order or backpropagation-free (BP-Free) methods can partially alleviate the memory consumption, but they scale poorly and incur large computation overheads, since the gradient estimation error and floating point operations (FLOPs) increase as the dimensionality of the model parameters grows. In this paper, we propose a federated foresight pruning method based on the Neural Tangent Kernel (NTK), which can seamlessly integrate with federated BP-Free training frameworks. We present an approximation to the computation of the federated NTK by using the local NTK matrices. Moreover, we demonstrate that the data-free property of our method substantially reduces the approximation error in extreme data heterogeneity scenarios. Since our approach improves the performance of the vanilla BP-Free method with fewer FLOPs and truly alleviates memory pressure during training and inference, it makes FL more friendly to low-memory devices. Comprehensive experimental results obtained from simulation- and real test-bed-based platforms show that our federated foresight-pruning method not only preserves the ability of the dense model with a memory reduction of up to $9\times$ but also boosts the performance of the vanilla BP-Free method with dramatically fewer FLOPs.

Federated learning, memory efficiency, model pruning, zeroth-order optimization.
CCS Concepts: • Computing methodologies → Distributed artificial intelligence; Artificial intelligence.

1. Introduction

As a promising collaborative learning paradigm in Artificial Intelligence of Things (AIoT) (Zhang and Tao, 2021; Shukla et al., 2023; Chandrasekaran et al., 2022) design, Federated Learning (FL) (Khan et al., 2021) enables knowledge sharing among AIoT devices without compromising their data privacy. Specifically, FL lets each device perform local training and transmit gradients instead of private data. The server then aggregates these local gradients to update the global model (Yang and Sun, 2023) and redistributes it to a selected subset of devices for the next round of training. Unlike centralized training deployed on powerful hardware platforms, FL is restricted by stringent local training resources (Khan et al., 2021; Cui et al., 2023), reflected in the shortage of adequate computation and memory capacity. The high communication costs during training further limit the practical deployment of FL (Yang and Sun, 2022; Cao et al., 2021). Moreover, the increasing computation overhead of Deep Neural Networks (DNNs) presents a major barrier to running them on edge devices (Sarkar et al., 2023). Prior work has primarily concentrated on model sparsity methods (Liang et al., 2021; Wu et al., 2024) to address these challenges in a one-shot manner. By introducing neural network pruning or sparse training approaches (Bibikar et al., 2022), the local computation overhead and communication costs are considerably relieved. Yet, since local devices privately own the data, standard neural network pruning methods cannot be applied to FL without modification. The notorious data heterogeneity problem further prevents local devices from learning a common sparse model (Xia et al., 2022).

Table 1. Comparison of peak memory required by FL devices.

Method                           Pruning   BP-Free   Peak Device Memory Cost
FedAvg (McMahan et al., 2017a)   No        No        1x
ZeroFL (Qiu et al., 2022)        Yes       No        ≥1x
PruneFL (Jiang et al., 2023)     Yes       No        ≥1x
FedDST (Bibikar et al., 2022)    Yes       No        ≥1x
FedTiny (Huang et al., 2023)     Yes       No        ≥1x
BAFFLE (Feng et al., 2023)       No        Yes       <<1x
Ours                             Yes       Yes       <<1x

Existing works perform pruning operations under various federated settings (Li et al., 2021; Jiang et al., 2023; Zhu et al., 2022; Liao et al., 2023) to bridge the abovementioned gap. Although these methods reduce inference costs, their training costs in terms of memory consumption are seldom investigated. Since pruning-during-training methods need to maintain the backpropagation structure (Paszke et al., 2019; Abadi et al., 2016), their memory consumption is at least the same as that of vanilla training. This heavy dependence on the backpropagation framework prevents them from truly supporting a lightweight training paradigm for low-memory devices. As a result, existing federated pruning methods are not efficient in terms of memory usage. To eliminate the memory-inefficient auto-differentiation framework in FL, recent work (Feng et al., 2023; Li and Chen, 2021; Yi et al., 2022) proposed backpropagation-free (BP-Free) FL methods, which estimate the actual gradients using zeroth-order optimization. Integrating federated pruning methods with BP-Free FL methods has enormous potential to reduce memory consumption, but it suffers from compatibility gaps and harms performance. Existing federated pruning steps are entangled with gradient computations, making them highly sensitive to the precision of the estimated gradients when a BP-Free strategy is used.

To mitigate the problems mentioned above, we seek better solutions from the perspective of foresight federated pruning, which disentangles the pruning and training steps. Inspired by (Feng et al., 2023) and (Wang et al., 2023), we propose a federated and BP-Free foresight pruning method that enables lightweight federated pruning. Additionally, we leverage the sparse structure of the pruned model to greatly reduce the local computational overhead and boost the performance of Stein's Identity, a classic zeroth-order optimization method, in FL settings. We compare the peak device memory cost of representative federated learning methods in Table 1. Note that our approach explores the synergy between federated pruning and BP-Free training, aiming to enable FL on AIoT devices with extremely limited memory. In summary, this paper makes the following three major contributions:

  • We propose a novel memory-efficient foresight pruning method for FL, which is resilient to various data heterogeneities.

  • We propose an approximation to the federated NTK matrix and show that the data-free property of our method can effectively reduce the approximation error.

  • We implement our proposed foresight pruning method in conjunction with BP-Free training and conduct comprehensive experiments on both simulation- and real test-bed-based platforms to demonstrate the effectiveness of our approach.

2. Related Work

Federated Neural Network Pruning. A two-step pruning method was proposed in PruneFL (Jiang et al., 2023) to reduce the computation cost on the device side. It first completes coarse pruning on a selected device; finer pruning is then conducted on the server side to reduce computation expenses. ZeroFL (Qiu et al., 2022) separates model parameters into active and non-active groups to improve the efficiency of both forward and backward passes. Although the aforementioned methods leave the finer pruning to the server side, they still require a large local memory footprint to record the updated importance scores of all parameters to guide pruning. FedDST (Bibikar et al., 2022), instead, adopts sparse training, which consists of mask adjustment and regrowth on the device side. Once the local sparse training is finished, the server conducts sparse aggregation and magnitude pruning procedures to obtain a global model. Even with the sparse training framework, the high memory requirements are not relieved. Worse still, the local mask adjustment and regrowth processes inevitably need additional training epochs, which contradicts the principle of minimizing local training steps in FL (Mendieta et al., 2022). To reduce the cost of the memory- and computation-intensive pruning process for extremely resource-constrained devices, FedTiny (Huang et al., 2023) adaptively searches for finer-pruned specialized models by progressive pruning based on a coarsely pruned model. Effective as it is in the low-density regime for inference, its training step is hindered by significant communication and memory burdens, making it less suitable for devices with limited memory capacity.

BP-Free Federated Learning. Gradient-based optimization techniques are commonly used for training DNNs in FL settings. However, backpropagation (BP) operations heavily depend on an auto-differentiation framework (e.g., PyTorch (Paszke et al., 2019)). The additional static and dynamic memory requirements are unaffordable for low-memory devices. Recently, advances in zeroth-order optimization methods (Li and Chen, 2021; Yi et al., 2022; Shu et al., 2023) have demonstrated potential for federated training, particularly in cases where backward passes are expensive in terms of computation and memory. (Feng et al., 2023) proposes to utilize Stein's Identity (Stein, 1981) as the gradient estimation method to develop a BP-Free federated learning framework. Yet, the Floating Point Operations (FLOPs) per communication round increase to an extremely high level during training due to the nature of BP-Free gradient estimation: numerous forward passes are required for a single gradient estimate to decrease the estimation error (Adamczak et al., 2011; Feng et al., 2023). The degradation of learning performance is another issue, since BP-Free training comes at the cost of adding noise (Li et al., 2020; HaoChen et al., 2021) to the gradients. Therefore, we propose using sparsity techniques to reduce the FLOPs and boost learning performance in a one-shot manner.

3. Preliminaries

Let $\mathcal{D}=\{\mathcal{X},\mathcal{Y}\}$ denote the entire training dataset and $N$ the total number of data points. Hence, $\mathcal{X}=\{\mathbf{x}_{1},\cdots,\mathbf{x}_{N}\}$ and $\mathcal{Y}=\{\mathbf{y}_{1},\cdots,\mathbf{y}_{N}\}$ denote the inputs and labels, respectively.

Federated Learning. FL aims to collaboratively learn a global model parameterized by $\boldsymbol{W}\in\mathbb{R}^{n}$ while keeping local data private. Given that $m$ devices are involved in each round of local training, where the local dataset on the $i$-th device is denoted as $\mathcal{D}_{i}=\{(\mathbf{x}_{j},\mathbf{y}_{j})\}_{j=1}^{N_{i}}$ with $N_{i}$ representing the number of local data points, the objective is defined by

(1) \min_{\boldsymbol{W}}f(\boldsymbol{W})=\sum_{i=1}^{m}\frac{1}{N_{i}}\sum_{j=1}^{N_{i}}\mathcal{L}(\boldsymbol{W};(\mathbf{x}_{j},\mathbf{y}_{j})),

where $\mathcal{L}(\boldsymbol{W};\mathcal{D}_{i})$ is the specified loss function on the local dataset $\mathcal{D}_{i}$.

Neural Tangent Kernel. The Neural Tangent Kernel (NTK) analyzes the training dynamics of neural networks (Jacot et al., 2018). Given an arbitrary DNN $f$ initialized by $\boldsymbol{W}_{0}$, the NTK at the initial state is defined as

\boldsymbol{\theta}_{0}=\langle\nabla_{\boldsymbol{W}_{0}}f(\mathcal{X};\boldsymbol{W}_{0}),\nabla_{\boldsymbol{W}_{0}}f(\mathcal{X};\boldsymbol{W}_{0})\rangle.

It has been proven that the NTK stays asymptotically constant during training if the DNN is sufficiently large in terms of parameters. Therefore, the NTK at initialization (i.e., $\boldsymbol{\theta}_{0}$) can characterize the training dynamics.
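
To make the definition concrete, the following minimal PyTorch sketch (our illustration, not code from the paper) computes an empirical NTK Gram matrix at initialization for a toy model; for brevity it reduces each sample's output to a scalar by summing over classes.

```python
import torch

def empirical_ntk(model, X):
    # theta_0[i, j] = <grad_W f(x_i; W_0), grad_W f(x_j; W_0)> at initialization.
    params = [p for p in model.parameters() if p.requires_grad]
    per_sample_grads = []
    for x in X:
        out = model(x.unsqueeze(0)).sum()   # reduce each sample's output to a scalar
        grads = torch.autograd.grad(out, params)
        per_sample_grads.append(torch.cat([g.flatten() for g in grads]))
    J = torch.stack(per_sample_grads)       # shape: (N, #parameters)
    return J @ J.T                          # (N, N) NTK Gram matrix

# Toy usage with a random model and random inputs.
model = torch.nn.Sequential(torch.nn.Linear(32, 64), torch.nn.ReLU(), torch.nn.Linear(64, 10))
theta_0 = empirical_ntk(model, torch.randn(8, 32))
```

Under this scalar-output simplification, the trace of this Gram matrix equals the squared Frobenius norm of the per-sample gradients, which is exactly the quantity used in Eq. 4 below.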

Foresight Pruning Based on NTK. The general objective function of foresight pruning is formulated as

(2) \min_{\mathbf{m}}\mathcal{L}(\mathcal{A}(\boldsymbol{W}_{0},\mathbf{m});\mathcal{D})\;\text{s.t.}\;\mathbf{m}\in\{0,1\}^{n},\;\|\mathbf{m}\|_{0}/n\leq d,

where $\mathcal{A}$ denotes the model architecture determined by the binary mask $\mathbf{m}$, and $\boldsymbol{W}_{0}$ represents the model parameters at initialization. We aim to find the binary mask $\mathbf{m}$ that minimizes the loss function while satisfying the target density $d$. To make Eq. 2 tractable, existing foresight pruning methods propose a saliency measurement function, defined as

(3) S(\mathbf{m}^{j})=S(\boldsymbol{W}_{0}^{j})=\frac{\partial\mathcal{I}}{\partial\mathbf{m}^{j}}=\frac{\partial\mathcal{I}}{\partial\boldsymbol{W}_{0}^{j}}\cdot\boldsymbol{W}_{0}^{j},

where $\mathcal{I}$ represents a function of the model parameters and the mask $\mathbf{m}$, and $\mathbf{m}^{j}$ denotes the $j$-th element of the mask $\mathbf{m}$, which is a scalar. $S(\mathbf{m}^{j})$ is the saliency score that measures the impact of deactivating $\mathbf{m}^{j}$ (i.e., setting $\mathbf{m}^{j}$ to 0). After the saliency scores are computed, we keep the top-$d$ fraction of elements in the mask: we set $\mathbf{m}^{j}=1$ if it belongs to the set of top-$d$ elements and $\mathbf{m}^{j}=0$ otherwise. To analyze the property of $\boldsymbol{\theta}_{0}$, we follow (Wang et al., 2023) and compute its spectrum, formulated as

(4) ||\boldsymbol{\theta}_{0}||_{*}=||\boldsymbol{\theta}_{0}||_{\mathbf{tr}}=||\nabla_{\boldsymbol{W}_{0}}f(\mathcal{X};\boldsymbol{W}_{0})||_{F}^{2},

where $||\cdot||_{*}$ is the nuclear norm. The trace norm $||\boldsymbol{\theta}_{0}||_{\mathbf{tr}}$ is equivalent to the nuclear norm since the NTK matrix is symmetric and positive semi-definite. To further reduce the cost of computing the NTK matrix, we rewrite its trace norm as the Frobenius norm of the gradients, $||\nabla_{\boldsymbol{W}_{0}}f(\mathcal{X};\boldsymbol{W}_{0})||_{F}^{2}$. We aim to find a sparse model with the same training dynamics as the dense model. Therefore, the objective function of NTK-based foresight pruning is formulated as

(5) S(\mathbf{m}^{j})=\left|\frac{\partial||\nabla_{\boldsymbol{W}_{0}}f(\mathcal{X};\boldsymbol{W}_{0}\odot\mathbf{m})||_{F}^{2}}{\partial\mathbf{m}^{j}}\right|.

Note that, to achieve a BP-Free method, the Frobenius norm of the gradients can be approximated by

(6) \left\|f(\mathcal{X};\boldsymbol{W}_{0}\odot\mathbf{m})-f(\mathcal{X};(\boldsymbol{W}_{0}+\Delta\boldsymbol{W})\odot\mathbf{m})\right\|_{2}^{2},

which saves the expensive computation of Eq. 5.
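
To see why Eq. 6 is a valid surrogate for Eq. 4, note that a first-order Taylor expansion under a small Gaussian weight perturbation recovers the Frobenius norm of the gradients up to a constant factor. A one-line sketch of this reasoning (our addition), assuming $\Delta\boldsymbol{W}\sim\mathcal{N}(\mathbf{0},\epsilon\mathbf{I})$ with small $\epsilon$, is

\mathbb{E}_{\Delta\boldsymbol{W}}\left\|f(\mathcal{X};\boldsymbol{W}_{0}\odot\mathbf{m})-f(\mathcal{X};(\boldsymbol{W}_{0}+\Delta\boldsymbol{W})\odot\mathbf{m})\right\|_{2}^{2}\approx\mathbb{E}_{\Delta\boldsymbol{W}}\left\|\nabla_{\boldsymbol{W}_{0}}f(\mathcal{X};\boldsymbol{W}_{0}\odot\mathbf{m})\,\Delta\boldsymbol{W}\right\|_{2}^{2}=\epsilon\left\|\nabla_{\boldsymbol{W}_{0}}f(\mathcal{X};\boldsymbol{W}_{0}\odot\mathbf{m})\right\|_{F}^{2},

so averaging Eq. 6 over a few sampled perturbations estimates Eq. 4 for the masked model without any backward pass.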

Finite Difference and Stein's Identity. Derived from the definition of derivatives, the Finite Difference (FD) method can be extended to multivariate cases via Taylor expansion. Given the loss function $\mathcal{L}(\boldsymbol{W};\mathcal{D})$ and a small perturbation $\boldsymbol{\delta}\in\mathbb{R}^{n}$, the FD method is defined as

(7) \mathcal{L}(\boldsymbol{W}+\boldsymbol{\delta};\mathcal{D})-\mathcal{L}(\boldsymbol{W};\mathcal{D})=\boldsymbol{\delta}^{T}\nabla_{\boldsymbol{W}}\mathcal{L}(\boldsymbol{W})+O(||\boldsymbol{\delta}||_{2}^{2}).

Assuming the loss function is continuously differentiable w.r.t. the model parameters $\boldsymbol{W}$, the gradient of $\mathcal{L}$ w.r.t. $\boldsymbol{W}$ can be expressed via Stein's Identity as

(8) \nabla_{\boldsymbol{W}}\mathcal{L}(\boldsymbol{W})=\mathbb{E}_{\boldsymbol{\delta}\sim\mathcal{N}(0,\sigma^{2}\mathbf{I})}\left[\frac{\boldsymbol{\delta}}{\sigma^{2}}\Delta\mathcal{L}(\boldsymbol{W};\mathcal{D})\right],

where $\Delta\mathcal{L}(\boldsymbol{W};\mathcal{D})=\mathcal{L}(\boldsymbol{W}+\boldsymbol{\delta};\mathcal{D})-\mathcal{L}(\boldsymbol{W};\mathcal{D})$, and $\boldsymbol{\delta}$ follows a Gaussian distribution with zero mean and covariance $\sigma^{2}\mathbf{I}$. We use the Monte Carlo method to sample $K$ perturbations $\boldsymbol{\delta}_{k}$, thereby obtaining a stochastic version of the estimator:

(9) \widehat{\nabla_{\boldsymbol{W}}}\mathcal{L}(\boldsymbol{W})=\frac{1}{K}\sum_{k=1}^{K}\left[\frac{\boldsymbol{\delta}_{k}}{\sigma^{2}}\Delta\mathcal{L}(\boldsymbol{W};\mathcal{D})\right].
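
As a concrete illustration (our sketch, not the authors' implementation), the estimator in Eq. 9 can be realized with forward passes only; the toy quadratic loss below is a stand-in for the device-side loss function:

```python
import torch

def stein_grad_estimate(loss_fn, w, K=100, sigma=1e-3):
    """Zeroth-order gradient estimate via Stein's Identity (Eq. 9):
    average of delta_k / sigma^2 * [L(w + delta_k) - L(w)] over K Gaussian samples."""
    base = loss_fn(w)
    grad = torch.zeros_like(w)
    for _ in range(K):
        delta = sigma * torch.randn_like(w)              # delta_k ~ N(0, sigma^2 I)
        grad += delta / sigma**2 * (loss_fn(w + delta) - base)
    return grad / K

# Toy quadratic loss; the true gradient is 2 * (w - 1).
loss_fn = lambda w: ((w - 1.0) ** 2).sum()
w = torch.zeros(10)
print(stein_grad_estimate(loss_fn, w, K=2000))           # approx. -2 everywhere
```

Only function evaluations are used, so no activations or computation graph need to be stored, which is the source of the memory savings discussed next.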

4. Methodology

4.1. Estimation Error of Stein’s Identity

Though the sampling method in Eq. 9 is friendly to low-memory devices, its computation overhead is extremely heavy since $K$ must be large for an accurate estimation. Following (Feng et al., 2023), the estimation error is formally characterized by the following theorem:

Theorem 1.

(Estimation error (Feng et al., 2023)) Let $\hat{\boldsymbol{\delta}}=\frac{1}{K}\sum_{k=1}^{K}\boldsymbol{\delta}_{k}$ and let the empirical covariance matrix be $\widehat{\boldsymbol{\Sigma}}=\frac{1}{K\sigma^{2}}\sum_{k=1}^{K}\boldsymbol{\delta}_{k}\boldsymbol{\delta}_{k}^{T}$. Let $n$ be the dimension of the trainable parameters $\boldsymbol{W}$ of the DNN. The discrepancy between the true gradient $\nabla_{\boldsymbol{W}}\mathcal{L}(\boldsymbol{W};\mathcal{D})$ and the estimated gradient $\widehat{\nabla_{\boldsymbol{W}}}\mathcal{L}(\boldsymbol{W};\mathcal{D})$ is formulated as

\widehat{\nabla_{\boldsymbol{W}}}\mathcal{L}(\boldsymbol{W};\mathcal{D})=\widehat{\boldsymbol{\Sigma}}\,\nabla_{\boldsymbol{W}}\mathcal{L}(\boldsymbol{W};\mathcal{D})+o(\widehat{\boldsymbol{\delta}})\;\;\text{ s.t. }\;\;\mathbb{E}[\widehat{\boldsymbol{\Sigma}}]=\mathbf{I},\;\mathbb{E}[\widehat{\boldsymbol{\delta}}]=\mathbf{0}.

Neglecting the trivial term $o(\widehat{\boldsymbol{\delta}})$, the estimation error measured by the Mean Square Error (MSE) is fully controlled by the term $||\widehat{\boldsymbol{\Sigma}}-\mathbf{I}||_{2}^{2}$. Based on the proof in (Adamczak et al., 2011), we have $||\widehat{\boldsymbol{\Sigma}}-\mathbf{I}||_{2}\leq\sqrt{\frac{n}{K}}$. In conclusion, we can expect a more accurate estimation if we increase the number of Monte Carlo steps, decrease the dimensionality of the model parameters $\boldsymbol{W}$, or both. In practice, $K$ determines the total FLOPs consumed to estimate the true gradient, and $n$ reflects the memory required for a device to hold the model. Developing more advanced zeroth-order algorithms that lower the value of $K$ needed to achieve the same performance is one promising way to obtain more accurate estimations. Another way is to decrease the value of $n$. To achieve the latter, we introduce model pruning, which reduces the computation overhead and increases the estimation precision because it finds a sparse structure of the model while still maintaining its performance.
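
The dependence of the covariance error on $n$ and $K$ can be checked empirically; the following small sketch (ours, for illustration only) samples Gaussian perturbations and measures $||\widehat{\boldsymbol{\Sigma}}-\mathbf{I}||_{2}$, which shrinks as $K$ grows relative to $n$:

```python
import torch

def covariance_spectral_error(n, K, sigma=1.0):
    # ||Sigma_hat - I||_2 for Sigma_hat = (1 / (K sigma^2)) * sum_k delta_k delta_k^T.
    deltas = sigma * torch.randn(K, n)
    sigma_hat = deltas.T @ deltas / (K * sigma**2)
    return torch.linalg.matrix_norm(sigma_hat - torch.eye(n), ord=2).item()

for K in (500, 2000, 8000):
    print(f"n=100, K={K}: spectral error ~ {covariance_spectral_error(100, K):.3f}")
```

Pruning attacks the same error term from the other side by shrinking the effective $n$ instead of enlarging $K$.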

Algorithm 1 Pruning and BP-Free Training

Input: 1) $\boldsymbol{W}_{0}$, a randomly initialized global model; 2) $C$, a pool of all participants; 3) $T_{p}$, # of pruning rounds; 4) $T_{t}$, # of training rounds; 5) $G_{p}$, # of participants for pruning; 6) $G_{t}$, # of participants for training; 7) $K$, # of perturbations to estimate gradients; 8) $d$, target density; 9) $\mathbf{m}$, mask for parameters; 10) $\eta$, learning rate.

NTK Foresight Pruning:

1:  for $t=1,\dots,T_{p}$ do
2:     $C_{t}\leftarrow\text{Random Sample}(C,G_{p})$;
3:     $G_{p}\leftarrow$ number of elements in $C_{t}$;
4:     if use real data then
5:        for $i=1,\dots,G_{p}$ in parallel do
6:           Get $F_{i}$ based on Eq. 13 at the devices;
7:           Send $F_{i}$ to the server;
8:        end for
9:     else if use random data then
10:        Sample $\mathbf{x}_{i}$ from the standard Gaussian distribution;
11:        Get $F_{i}$ based on Eq. 13 at the server;
12:     end if
13:     Get $\mathcal{I}$ based on Eq. 15;
14:     Compute $\mathcal{S}(\mathbf{m}^{j})$ based on Eq. 14;
15:     Get threshold $\tau$ as the $(1-d^{\frac{t}{T_{p}}})$ percentile of $\mathcal{S}(\mathbf{m}^{j})$;
16:     $\mathbf{m}\leftarrow\mathbf{m}\odot\mathbb{1}[\mathcal{S}(\mathbf{m}^{j})\geq\tau]$;
17:  end for
18:  Return a pruned model parameterized by $\boldsymbol{W}_{0}\odot\mathbf{m}$

BP-Free Training:

1:  $\boldsymbol{W}\leftarrow\boldsymbol{W}_{0}\odot\mathbf{m}$;
2:  for $t=1,\dots,T_{t}$ do
3:     $C_{t}\leftarrow\text{Random Sample}(C,G_{t})$;
4:     $G_{t}\leftarrow$ number of elements in $C_{t}$;
5:     for $i=1,\dots,G_{t}$ in parallel do
6:        Set $\boldsymbol{W}_{i}\leftarrow\boldsymbol{W}$;
7:        Get $\widehat{\nabla_{\boldsymbol{W}_{i}}}\mathcal{L}(\boldsymbol{W}_{i})$ based on Eq. 9;
8:     end for
9:     Devices send $\Delta\mathcal{L}(\boldsymbol{W}_{i};\mathcal{D}_{i})$ and the corresponding random seeds to the server;
10:     The server reproduces $\boldsymbol{\delta}$ from the received random seeds;
11:     Get $\widehat{\nabla_{\boldsymbol{W}_{agg}}}\mathcal{L}(\boldsymbol{W})$ based on Eq. 17;
12:     $\boldsymbol{W}\leftarrow\boldsymbol{W}-\eta\widehat{\nabla_{\boldsymbol{W}_{agg}}}\mathcal{L}(\boldsymbol{W})$;
13:  end for
14:  Return a global model parameterized by $\boldsymbol{W}$

Current federated pruning methods are highly integrated with the training process. Consequently, the noisy gradients negatively affect the pruning step, decreasing training performance. To mitigate the impact of noisy gradients on the pruning step, we propose decoupling the pruning processes from the training processes, which is one of the fundamental motivations of foresight pruning. Further, we dedicate efforts to developing a federated foresight-pruning method based on the NTK spectrum, ensuring full compatibility with the BP-Free training framework.

4.2. NTK-based Foresight Pruning for FL

Since the symmetry of $\boldsymbol{\theta}_{0}$ only holds for centralized training, we cannot directly apply it to federated training. Motivated by (Huang et al., 2021), we show that the NTK in federated learning can be approximated by the local NTK matrices.

Definition 0.

(Asymmetric FL-NTK) Let $\boldsymbol{\theta}_{0}^{i}\in\mathbb{R}^{N_{i}\times N_{i}}$ represent the local NTK of the $i$-th device based on its dataset. Let $N_{max}$ denote $\max\{N_{1},N_{2},...,N_{m}\}$ and $N_{sum}$ denote $\sum_{i=1}^{m}N_{i}$. The asymmetric FL-NTK is defined as $\boldsymbol{\theta}_{0}^{fl}\in\mathbb{R}^{N_{max}\times N_{sum}}$. It is formed by placing the $N_{i}$ columns of $\boldsymbol{\theta}_{0}^{i}$ for each $i\in[m]$ side by side and padding the remaining elements with zeros.

The asymmetry of the FL-NTK prevents us from computing its nuclear norm via the Frobenius norm of gradients on each local dataset. However, we show that the summation of the local symmetric NTK matrices can be used to approximate the FL-NTK and to conduct the Frobenius norm computation efficiently.

Proposition 0.

Since $\boldsymbol{\theta}_{0}^{fl}$ is the horizontal concatenation of the $\boldsymbol{\theta}_{0}^{i}$ for $i\in[m]$, we decompose it as the summation of $m$ sparse NTK matrices. Let $\boldsymbol{\theta}_{0,sp}^{i}\in\mathbb{R}^{N_{max}\times N_{sum}}$ represent the sparse NTK matrix for the $i$-th device. We can reformulate the FL-NTK matrix as

(10) \boldsymbol{\theta}_{0}^{fl}=\sum_{i=1}^{m}\boldsymbol{\theta}_{0,sp}^{i},
(11) \boldsymbol{\theta}_{0,sp}^{i}\Big[0:N_{i},\;\sum_{j=1}^{i-1}N_{j}:\sum_{j=1}^{i-1}N_{j}+N_{i}\Big]=\boldsymbol{\theta}_{0}^{i},

where the non-filled elements of $\boldsymbol{\theta}_{0,sp}^{i}$ are all zeros.

Since the nuclear norm satisfies the triangle inequality, we further formulate the computation of 𝜽0fl\|\boldsymbol{\theta}_{0}^{fl}\|_{*} as

(12) \|\boldsymbol{\theta}_{0}^{fl}\|_{*}=\Big\|\sum_{i=1}^{m}\boldsymbol{\theta}_{0,sp}^{i}\Big\|_{*}\leq\sum_{i=1}^{m}\|\boldsymbol{\theta}_{0,sp}^{i}\|_{*}=\sum_{i=1}^{m}\|\boldsymbol{\theta}_{0}^{i}\|_{*}.

Therefore, the complex FL-NTK is upper-bounded by the summation of the individual local NTK matrices. Since the sparse model preserves the same training dynamics as the dense model and the gradients estimated by Stein's Identity are unbiased, the overall convergence rate with respect to the number of training rounds $T$ is $O(1/\sqrt{T})$, the same as that of the standard FedAvg (Li and Lyu, 2023) algorithm under non-convex settings.
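
As a sanity check of this decomposition and of Eq. 12 (our own numeric sketch with hypothetical local dataset sizes), one can place random symmetric PSD local NTK blocks into the zero-padded FL-NTK and verify that its nuclear norm never exceeds the sum of the local nuclear norms:

```python
import torch

# Hypothetical local dataset sizes for three devices.
Ns = [4, 3, 5]
local_ntks = []
for n in Ns:
    g = torch.randn(n, n)
    local_ntks.append(g @ g.T)                 # symmetric PSD surrogate for a local NTK

Nmax, Nsum = max(Ns), sum(Ns)
theta_fl = torch.zeros(Nmax, Nsum)
col = 0
for n, t in zip(Ns, local_ntks):
    theta_fl[:n, col:col + n] = t              # Eq. 11: place local columns, zero-pad the rest
    col += n

lhs = torch.linalg.matrix_norm(theta_fl, ord="nuc")
rhs = sum(torch.linalg.matrix_norm(t, ord="nuc") for t in local_ntks)
print(lhs.item(), rhs.item())                  # Eq. 12: lhs <= rhs
```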

Unlike centralized pruning, in FL the entire dataset $\mathcal{D}$ is distributed across multiple devices. According to the form of Eq. 3, computing the saliency score requires at least one gradient computation, i.e., $\partial\mathcal{I}/\partial\boldsymbol{W}_{0}^{j}$, even without considering the function $\mathcal{I}$ itself. We can therefore let the local devices compute the function $\mathcal{I}$ and leave the derivative and multiplication operations in Eq. 15 to the server. In this way, local datasets contribute equally to the generation of the binary mask, yielding an unbiased global sparse model, and the memory required to support the auto-differentiation framework is shifted to the server. Things might be different when considering the specific form of the function $\mathcal{I}$. As mentioned above, the function $\mathcal{I}$ defined in Eq. 4 can be approximated by either Eq. 6 or the squared norm of the loss gradient w.r.t. the weights, $\left\|\nabla_{\boldsymbol{W}_{0}}\mathcal{L}\right\|_{2}^{2}$. The latter introduces extra gradient computations, which inevitably rely on the auto-differentiation framework. In contrast, Eq. 6 is independent of the auto-differentiation framework. Regardless, the computation of the saliency score in Eq. 3 still involves a derivative (i.e., $\partial\mathcal{I}/\partial\boldsymbol{W}_{0}^{j}$). We could introduce another zeroth-order method to estimate this derivative, but it would introduce additional estimation error. As a result, we choose Eq. 6 as the function $\mathcal{I}$ so that the local devices are free of the memory-inefficient backpropagation framework. The federated NTK foresight pruning is organized as

(13) F_{i}=\|f(\mathbf{x}_{i};\boldsymbol{W}_{0,i}\odot\mathbf{m})-f(\mathbf{x}_{i};(\boldsymbol{W}_{0,i}+\Delta\boldsymbol{W}_{i})\odot\mathbf{m})\|_{2}^{2},
(14) S(\mathbf{m}^{j})=\left|\frac{\partial\mathcal{I}}{\partial\mathbf{m}^{j}}\right|=\left|\frac{\partial\mathcal{I}}{\partial\boldsymbol{W}_{0}^{j}}\cdot\boldsymbol{W}_{0}^{j}\right|,
(15) \mathcal{I}=\frac{1}{N}\sum_{i=1}^{N}\mathbb{E}_{\Delta\boldsymbol{W}\sim\mathcal{N}(\mathbf{0},\epsilon\mathbf{I})}[F_{i}],

where $i$ indexes the local dataset owned by the corresponding device, $\boldsymbol{W}_{0}^{j}=(\boldsymbol{W}_{0,i}^{j})_{i=1}^{m}$ denotes the $j$-th element of the initial global model, and $\Delta\boldsymbol{W}_{i}$ is sampled from a Gaussian distribution with zero mean and variance $\epsilon\mathbf{I}$. The symbol $\mathbb{E}$ denotes the expectation over $\Delta\boldsymbol{W}_{i}$, which can be approximated by Monte Carlo sampling. To avoid layer collapse, we use an exponential schedule and set the pruning threshold $\tau$ to the $(1-d^{t/T_{p}})$ percentile of the saliency scores, where $t$ denotes the current pruning round, $d$ is the target density, and $T_{p}$ is the maximum number of pruning rounds. Parameters whose saliency scores fall below the threshold $\tau$ are pruned. Note that the proposed NTK pruning introduces $T_{p}$ extra computations on the devices, and the multiple-round pruning schedule requires communication between the server and all devices. To resolve both problems in one shot, we fully exploit the intrinsic property of NTK pruning. Motivated by (Shu et al., 2022), we show that under data heterogeneity the difference between any two local NTK matrices satisfies $\Big|\left\|\boldsymbol{\theta}_{0}^{i}(P^{i})\right\|_{*}-\left\|\boldsymbol{\theta}_{0}^{j}(P^{j})\right\|_{*}\Big|\leq n_{0}^{-1}Z$, where $P^{i}$ and $P^{j}$ represent the local data distributions, $Z=\int\left\|P^{i}(\mathbf{X})-P^{j}(\mathbf{X})\right\|\mathrm{d}\mathbf{X}$, and $n_{0}$ is the input dimension. In our cases, $n_{0}$ is $1024$ and $0<Z<2$. Therefore, the approximation error induced by the summation in Eq. 12 can be characterized by this difference and grows with the degree of heterogeneity. To alleviate the approximation error, we constrain the difference between $P^{i}$ and $P^{j}$ by sampling the data $\mathbf{x}_{i}$ in Eq. 13 from the same standard Gaussian distribution, which diminishes the difference between any two local NTK matrices and makes our method resilient to data heterogeneity. Moreover, the computation of Eq. 13 can then be fully conducted by the server, so our federated NTK pruning degrades to its centralized version, entirely saving the communication overhead of the pruning phase.
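
For concreteness, the following PyTorch sketch (our illustration under simplifying assumptions: a toy two-layer MLP, biases omitted, a single pruning round) realizes the data-free variant of Eqs. 13-15: random Gaussian inputs and weight perturbations produce $\mathcal{I}$, and the mask gradients needed in Eq. 14 are computed on the server with automatic differentiation.

```python
import torch
import torch.nn.functional as F

def forward(x, w):
    # Minimal two-layer MLP used purely for illustration; w holds the weight tensors.
    h = F.relu(F.linear(x, w["fc1"]))
    return F.linear(h, w["fc2"])

def ntk_saliency_masks(w0, density, eps=1e-2, n_samples=4, in_dim=1024, batch=32):
    """Data-free NTK foresight pruning sketch (Eqs. 13-15): score each weight by
    |dI/dm * W0|, where I is the expected squared output perturbation under
    Gaussian weight noise, evaluated on random Gaussian inputs."""
    masks = {k: torch.ones_like(v, requires_grad=True) for k, v in w0.items()}
    x = torch.randn(batch, in_dim)                     # data-free: random inputs
    I = 0.0
    for _ in range(n_samples):                         # Monte Carlo over Delta W
        dw = {k: eps * torch.randn_like(v) for k, v in w0.items()}
        y0 = forward(x, {k: w0[k] * masks[k] for k in w0})
        y1 = forward(x, {k: (w0[k] + dw[k]) * masks[k] for k in w0})
        I = I + ((y0 - y1) ** 2).sum() / batch
    I = I / n_samples
    grads = torch.autograd.grad(I, [masks[k] for k in w0])   # server-side autodiff
    scores = {k: (g * w0[k]).abs() for k, g in zip(w0, grads)}
    flat = torch.cat([s.flatten() for s in scores.values()])
    tau = torch.quantile(flat, 1.0 - density)          # prune scores below the threshold
    return {k: (s >= tau).float() for k, s in scores.items()}

# Example: prune a random initialization to 20% density.
w0 = {"fc1": torch.randn(128, 1024) * 0.02, "fc2": torch.randn(10, 128) * 0.02}
masks = ntk_saliency_masks(w0, density=0.2)
```

In the multi-round schedule of Algorithm 1, the thresholding step would instead use the $(1-d^{t/T_{p}})$ percentile at round $t$.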

4.3. Backpropagation-Free Training for FL

We further integrate the proposed NTK foresight pruning with Stein’s Identity method to collaboratively build a memory-friendly federated training framework. To leverage Stein’s Identity in the context of FL, we rewrite the estimation of the aggregated gradient in the form of FedAvg (McMahan et al., 2017b) as

(16) \widehat{\nabla_{\boldsymbol{W}_{agg}}}\mathcal{L}(\boldsymbol{W})=\sum_{i=1}^{m}\frac{N_{i}}{KN}\sum_{k=1}^{K}\left[\frac{\boldsymbol{\delta}_{k}}{\sigma^{2}}\Delta\mathcal{L}(\boldsymbol{W}_{i};\mathcal{D}_{i})\right].

Note that $\boldsymbol{\delta}\in\mathbb{R}^{n}$, where $n$ is the dimensionality of the gradients we want to estimate. To further decrease the communication burden, we use the Random Seed Trick. This technique leverages two properties of the development environment: i) $\boldsymbol{\delta}$ has the same dimensionality as the gradients we want to estimate and is generated by sampling from a Gaussian distribution; ii) the Gaussian sampling is driven by a random seed at the software level. Therefore, the server can reproduce exactly the same samples as the local devices by sharing the same random seed and running environment. In detail, Eq. 16 is split between the server and devices as:

(17) \widehat{\nabla_{\boldsymbol{W}_{agg}}}\mathcal{L}(\boldsymbol{W})=\sum_{i=1}^{m}\frac{N_{i}}{KN}\sum_{k=1}^{K}\Bigg[\frac{\overbrace{\boldsymbol{\delta}_{k}}^{\text{server produces}}}{\sigma^{2}}\underbrace{\Delta\mathcal{L}(\boldsymbol{W}_{i};\mathcal{D}_{i})}_{\text{computed by devices}}\Bigg].

In practice, the devices compute $\Delta\mathcal{L}(\boldsymbol{W}_{i};\mathcal{D}_{i})$ and the server regenerates $\boldsymbol{\delta}$. Consequently, the upload cost is significantly reduced from high-dimensional vectors to scalars, since the vector $\boldsymbol{\delta}$ never needs to be transmitted. The overall training process is shown in Algorithm 1.
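
A minimal sketch of this device/server split (our illustration; the loss_fn callable, the per-device seeds, and the hyperparameter values are assumptions, not the authors' code) is shown below. Each device uploads only $K$ loss-difference scalars plus its seed, and the server regenerates the identical perturbations from that seed to assemble Eq. 17:

```python
import torch

def device_round(seed, w, loss_fn, data, K=50, sigma=1e-3):
    """Device side: K forward-only loss differences under seeded Gaussian perturbations."""
    gen = torch.Generator().manual_seed(seed)
    base = loss_fn(w, data)
    diffs = []
    for _ in range(K):
        delta = {k: sigma * torch.randn(v.shape, generator=gen) for k, v in w.items()}
        diffs.append(loss_fn({k: v + delta[k] for k, v in w.items()}, data) - base)
    return seed, torch.stack(diffs)              # upload: one seed + K scalars

def server_aggregate(w, reports, sizes, K=50, sigma=1e-3):
    """Server side: regenerate each device's deltas from its seed and form Eq. 17."""
    grad = {k: torch.zeros_like(v) for k, v in w.items()}
    N = sum(sizes)
    for (seed, diffs), n_i in zip(reports, sizes):
        gen = torch.Generator().manual_seed(seed)    # same seed -> identical perturbations
        for k_idx in range(K):
            delta = {k: sigma * torch.randn(v.shape, generator=gen) for k, v in w.items()}
            for k in grad:
                grad[k] += (n_i / (K * N)) * (delta[k] / sigma**2) * diffs[k_idx]
    return grad
```

Crucially, both sides must draw the perturbations in the same order from the same generator state, which is exactly what makes transmitting $\boldsymbol{\delta}$ unnecessary.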

5. Experiments

In this section, we conducted federated pruning and federated zeroth-order experiments to answer the following two pivotal Research Questions: RQ1: What are the advantages of our NTK-based federated foresight pruning? RQ2: What improvements do the NTK-based foresight federated pruning bring to the standard federated zeroth-order method?

5.1. Experimental Settings

Dataset. We evaluated the performance on two image classification tasks, i.e., CIFAR-10 and CIFAR-100 (Krizhevsky, 2009). Following the heterogeneity configurations in (Li et al., 2022), we consider a Dirichlet distribution throughout this paper, parameterized by a coefficient $\beta$ and denoted as $Dir(\beta)$. $\beta$ determines the degree of data heterogeneity: the smaller $\beta$ is, the more heterogeneous the data are. We consider one non-IID scenario by setting $\beta$ to $0.1$.

Models. To show the scalability of the proposed foresight pruning to various types of models, we utilized LeNet-5 (LeCun et al., 1998) and ResNet-20 (He et al., 2016) with batch normalization. For the federated pruning evaluation, we used LeNet-5 for CIFAR-10 and ResNet-20 for CIFAR-100. For the experiments on BP-Free federated training, we used LeNet-5 and ResNet-20 on both the CIFAR-10 and CIFAR-100 datasets.

Baselines. For the pruning evaluation, we included FedDST and PruneFL to compare with the proposed NTK pruning. FedDST utilizes a sparse training method and conducts neuron growth and pruning at the local devices. PruneFL first conducts local coarse pruning at a randomly selected device and then performs server-side pruning based on gradients collected from the local devices. We excluded federated pruning methods that are inapplicable to memory-constrained FL (i.e., ZeroFL and LotteryFL). To show the improvements our NTK pruning brings to BP-Free federated learning, we compared it with the state-of-the-art method BAFFLE (Feng et al., 2023), which uses the vanilla Stein's Identity method. Note that the FedAvg results are obtained with backpropagation-based training.

Hyperparameter Settings. The pruning rounds $T_{p}$ for LeNet and ResNet-20 are set to 50 and 20, respectively, with the corresponding learning rate and batch size set to 0.001 and 256. Following the source code of (Bibikar et al., 2022), we set the number of local training epochs to 5 unless otherwise specified. We randomly select 10 out of 100 devices to perform local training in each round, and the training batch size is set to 32. For the experiments on CIFAR-10, we set the learning rate, momentum, and weight decay to 0.01, 0.9, and 1e-3, respectively. For CIFAR-100, the momentum and weight decay are set to 0 and 1e-3 for both models; specifically, we set the learning rate to 0.01 for LeNet and 0.1 for ResNet. The learning rate decays to 99.8% of its value after each round in all experiments. For all experiments, we set $K$, the number of Monte Carlo steps used to estimate the true gradient, to 50, 100, and 200, respectively. For a fair comparison, we used a sparsity level of 80% and 5 training epochs for FedDST, following the original paper.

5.2. Experimental Results for Federated Pruning

Accuracy and FLOPs Analysis. Table 2 shows the maximum accuracy and FLOPs for all federated pruning baselines under an extreme non-IID scenario where $\beta=0.1$. The training curves for CIFAR-10 are shown in Figure 1. Note that "NTK-Ori" denotes the result obtained on real data and "NTK-Rand" the result on randomly generated data; we abbreviate epoch as "ep" in tables. Since a time coefficient is required to implement PruneFL and the source code of FedDST and PruneFL does not provide one for ResNet-20, we omit PruneFL for the ResNet-20 model. The sparsity of all experiments is initialized to 80%. We use FLOPs to evaluate whether local devices suffer from intensive computation overhead, thus reflecting the efficiency of each federated learning algorithm.

For the FedAvg method, the accuracy after convergence is 53.66%. Since it is conducted on a dense model, the FLOPs are not reduced, making it inefficient for resource-constrained devices. For the FedDST method, the FLOPs are significantly reduced to 20% of the dense model, but the accuracy drop is not negligible. In addition, its convergence is slower than FedAvg due to the repeated neuron growth and pruning steps. For the PruneFL approach, the accuracy drop after convergence is smaller than that of FedDST. Note that PruneFL in our experiment ends up with an all-one mask, which makes the model dense again at convergence; therefore, the FLOPs of PruneFL remain the same as FedAvg. Since our NTK method collects pruning information from all local devices, the accuracy drop is reduced to 0.42%. Moreover, the NTK method avoids model sparsity readjustment during training. As a result, the training stability shown in Figure 1 is improved compared with FedDST and PruneFL. Although the FLOPs for a single forward pass of NTK are not as small as those of FedDST, the sparsity readjustment in FedDST inevitably brings additional computation expenses during training. For the CIFAR-100 dataset, we observe that FedDST outperforms the other methods in terms of accuracy. Yet, FedDST requires multiple local training epochs (5 epochs in the original paper) to realize the sparsity readjustment, resulting in a heavy computation burden. We further report the results of NTK pruning with 1 local training epoch: the accuracy outperforms 5-epoch FedDST by 0.02%, and the computation expense in FLOPs is $5\times$ cheaper than FedDST. Since the accuracy advantage brought by the single local epoch aligns with the observation that fewer local steps enhance global consistency (Liu et al., 2023), we conclude that our proposed pruning method enjoys both computation efficiency and better performance.

Figure 1. Test accuracy comparison on CIFAR-10.

Figure 2. Test accuracy comparison on LeNet for CIFAR-10. $K$ is set to 200 and the number of communication rounds to 3000.

Communication Cost Analysis. Assume that the target density is $d$ and the number of model parameters is $n$. PruneFL requires devices to send full gradients to the server every $\Delta R$ rounds, resulting in an average upload cost of $(32d+\frac{32}{\Delta R})n$ bits per device per round, where $\Delta R$ is the interval of sparsity readjustment; its maximum upload cost amounts to $(32d+32)n$ bits. In contrast, FedDST has an average communication cost of $(32d+\frac{1}{\Delta R})n$ bits per device per round before the sparsity readjustment is completed; once the readjustment finishes, the communication cost decreases to $32dn$. In our method, the worst-case communication expense is $32T_{p}n+32dn$ bits per device per round.

Data-free Foresight Federated-Pruning Analysis. Figure 1 shows that the result based on randomly generated data is superior to that on real data. In addition, the data- and label-agnostic properties avoid redundant upload and download costs during the pruning step. Using randomly generated data, the communication expense of foresight pruning drops to 0, leading to a communication expense of $32dn$ bits per device per round throughout training. As a result, the proposed NTK pruning demonstrates significant advantages in reducing transmission burdens.
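
To put these formulas side by side, the following back-of-the-envelope sketch uses hypothetical values for $n$, $d$, $\Delta R$, and $T_{p}$ (not numbers reported in the paper) to compare the per-device upload costs:

```python
# Hypothetical settings: n parameters, target density d, readjustment interval dR,
# and Tp pruning rounds; all costs in bits per device.
n, d, dR, Tp = 1_000_000, 0.2, 10, 20

prunefl_avg_per_round = (32 * d + 32 / dR) * n     # PruneFL average upload
feddst_avg_per_round  = (32 * d + 1 / dR) * n      # FedDST before readjustment ends
ours_worst_case       = 32 * Tp * n + 32 * d * n   # ours, real-data pruning (worst case)
ours_data_free        = 32 * d * n                 # ours, data-free pruning

for name, bits in [("PruneFL", prunefl_avg_per_round), ("FedDST", feddst_avg_per_round),
                   ("Ours (worst case)", ours_worst_case), ("Ours (data-free)", ours_data_free)]:
    print(f"{name}: {bits / 8e6:.1f} MB")
```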

Figure 3. Test accuracy for CIFAR-10 with different $K$: (a)-(c) $K=50$, $100$, $200$ on LeNet (top row); (d)-(f) $K=50$, $100$, $200$ on ResNet-20 (bottom row).

5.3. Experimental Results for BAFFLE

Table 2. Classification accuracy (%), consumed FLOPs, and peak memory for CIFAR-10/100.

Model      Method             Max Acc.   FLOPs    Peak Memory
LeNet-5    FedAvg             54.07      1x       1x
           FedDST             46.43      0.20x    1x
           PruneFL            52.86      1x       1x
           NTK-Rand           53.65      0.72x    0.10x
ResNet-20  FedAvg             31.69      1x       1x
           FedDST             34.88      0.43x    1x
           NTK-Rand           33.57      0.43x    1x
           FedAvg (1-ep)      34.19      1x       1x
           NTK-Rand (1-ep)    34.50      0.43x    0.20x

Classification Accuracy Improvements. Since the results in Table 2 and Figure 1 show that 1 local epoch and "NTK-Rand" both maximize the global performance, we set the local training epoch to 1 and use "NTK-Rand" for the following experiments. We denote the vanilla BP-Free method of (Feng et al., 2023) as Vanilla-BAFFLE and our method as NTK-BAFFLE. It is worth mentioning that BP-Free training for FL is currently not competitive with backpropagation-based FL, as the estimated gradients are not as accurate as the true gradients; we therefore focus on how much improvement our method achieves. Table 3 displays the maximum accuracy comparison between NTK-BAFFLE and Vanilla-BAFFLE on the CIFAR-10 dataset with a sparsity of 90%. Note that the number of local training epochs is set to 1. On both datasets, NTK-BAFFLE consistently outperforms Vanilla-BAFFLE in classification accuracy. For instance, the maximum accuracy is boosted by 4.62%, 6.35%, and 3.8% on LeNet when setting $K$ to 50, 100, and 200, respectively. Figure 2 shows the learning curves against communication rounds for different values of $K$: NTK-BAFFLE performs better than Vanilla-BAFFLE across all settings of $K$. Figure 3 presents the learning curves from the perspective of accuracy versus FLOPs; the first row shows the results on LeNet, and the second row shows those on ResNet-20.

When utilizing ResNet-20, the improvements in maximum accuracy are 1.00%, 1.93%, and 1.46%, respectively. This is mainly because the estimation over a more complex model is not as stable as that over a relatively simple one. Although the maximum accuracy improvements on ResNet-20 are not as prominent as those on LeNet, the learning curves against FLOPs in Figure 3 convincingly demonstrate that NTK-BAFFLE achieves higher accuracy when consuming the same FLOPs.

Table 3. Maximum accuracy (%) comparison on CIFAR-10.

Model      Method           K=50    K=100   K=200
LeNet-5    NTK-BAFFLE       39.68   42.70   46.39
           Vanilla-BAFFLE   35.06   36.35   42.59
           FedAvg           53.20 (BP-based, independent of K)
ResNet-20  NTK-BAFFLE       35.74   38.58   45.27
           Vanilla-BAFFLE   34.74   36.65   43.81
           FedAvg           48.26 (BP-based, independent of K)
Table 4. Maximum accuracy (%) comparison on CIFAR-100.

Model      Method           K=50    K=100   K=200
LeNet-5    NTK-BAFFLE       5.59    6.80    8.35
           Vanilla-BAFFLE   4.40    4.94    5.95
           FedAvg           23.52 (BP-based, independent of K)
ResNet-20  NTK-BAFFLE       12.04   14.84   17.73
           Vanilla-BAFFLE   11.54   14.41   15.96
           FedAvg           34.19 (BP-based, independent of K)

Table 4 presents the maximum accuracy comparison between NTK-BAFFLE and Vanilla-BAFFLE on the CIFAR-100 dataset. Unlike the 90% sparsity setting used for CIFAR-10, we uniformly set the sparsity level to 80% for the CIFAR-100 dataset. Figure 5 exhibits the learning curves regarding accuracy versus FLOPs; as in Figure 3, the first row presents the results on LeNet and the second row those on ResNet-20. The maximum accuracy improvements of NTK-BAFFLE over Vanilla-BAFFLE are consistently observed across all settings. For the experiments conducted on LeNet, the maximum accuracy is increased by 1.19%, 1.86%, and 2.4%, respectively. For the experiments on ResNet-20, the accuracy is improved by 0.50%, 0.43%, and 1.77% when setting $K$ to 50, 100, and 200, respectively. From Figure 3 and Figure 5, we notice that our proposed pruning method consistently enhances the performance of the original BAFFLE method by achieving higher accuracy while consuming fewer FLOPs.

Table 5. Platform configuration.

Type    Device             Configuration          Number
Device  Jetson Nano        128-core Maxwell GPU   7
Device  Jetson Xavier AGX  512-core NVIDIA GPU    3
Server  Workstation        NVIDIA RTX 3080 GPU    1

Real Test-bed Results. We conducted real test-bed experiments on CIFAR-10 under $\beta=0.1$ heterogeneity using the LeNet-5 model. The training hyperparameters stay the same, and the number of Monte Carlo steps $K$ is set to 50. The platform configuration is listed in Table 5: it is composed of 10 edge devices and one cloud server, and 2 devices are randomly selected in each training round. Figure 4 shows the devices of our test-bed and the comparison of learning curves. Our method consistently beats Vanilla-BAFFLE in classification accuracy.

Figure 4. Learning performance on the real test-bed: (a) devices of our real test-bed; (b) learning curves.
Figure 5. Test accuracy for CIFAR-100 with different $K$: (a)-(c) $K=50$, $100$, $200$ on LeNet (top row); (d)-(f) $K=50$, $100$, $200$ on ResNet-20 (bottom row).

Memory Analysis. Vanilla-BAFFLE and NTK-BAFFLE offer an effective solution for minimizing memory usage on edge devices, combining the benefits of static and dynamic memory efficiency. Running backpropagation on deep networks with an auto-differentiation framework requires additional static memory, which places a significant strain on memory-constrained devices. Moreover, backpropagation relies on storing intermediate activations, leading to substantial dynamic memory requirements. In contrast, the additional memory needed to store the perturbation in each forward pass and to estimate the local gradient in the BAFFLE-based methods is $O(n)$. To maximize the benefits of the BP-Free method, we perform layer-by-layer computations, effectively partitioning the forward computation graph into smaller segments. Table 6 shows the peak GPU memory usage (MB) during the forward pass. For LeNet, the peak memory usage of Vanilla-BAFFLE and NTK-BAFFLE is 10.26% of that of the BP operation; for ResNet-20, the peak memory of the BAFFLE-related methods is 19.74% of that of the BP process. Since NTK-BAFFLE implements sparse training, additional memory is needed to store the binary mask. However, modern frameworks (e.g., (Elsen et al., 2020)) can efficiently store and process sparse matrices, and the binary mask can be stored with a 1-bit datatype. As a result, the extra memory usage associated with NTK-BAFFLE is negligible.

Table 6. The peak GPU memory cost (in MB, including feature maps and parameters) of vanilla backpropagation (BP), Vanilla-BAFFLE, and NTK-BAFFLE.

Model      BP     Vanilla-BAFFLE  NTK-BAFFLE
LeNet      1617   166             166
ResNet-20  1671   330             330

The peak GPU memory during the forward pass is measured with a batch of 32 RGB images of size $32\times 32$ from the CIFAR-10 dataset. For LeNet and ResNet-20, the peak memory usage of Vanilla-BAFFLE is significantly smaller than that of backpropagation-based training: 166 MB vs. 1617 MB and 330 MB vs. 1671 MB, respectively. Additionally, the proposed NTK-BAFFLE has fewer parameters: under an 80% pruning rate, the peak memory attributed to parameters in NTK-BAFFLE is only 7.44% (LeNet) and 3.88% (ResNet-20) of that in Vanilla-BAFFLE.

Communication Overhead Analysis. In a BP-based FL system, the per-round communication cost, composed of uploading and downloading, is twice the number of model parameters. In Vanilla-BAFFLE and NTK-BAFFLE, each device only uploads a $K$-dimensional vector to the server for aggregation. Since $K$ is significantly smaller than the number of parameters $n$, both Vanilla-BAFFLE and NTK-BAFFLE greatly reduce the upload cost compared to conventional backpropagation-based FL. For downloading, NTK-BAFFLE requires less data transmission than Vanilla-BAFFLE due to its sparsity. Overall, NTK-BAFFLE has the lowest communication costs.

6. Conclusion

Although Federated Learning (FL) is becoming increasingly popular in AIoT design, it struggles on memory-constrained AIoT devices due to its heavy memory usage. To address this issue, in this paper we proposed a memory-efficient federated foresight pruning method based on NTK theory. Since our foresight pruning can seamlessly integrate into BP-Free training frameworks, it significantly reduces the memory footprint during FL training. Meanwhile, by combining the proposed foresight pruning with the finite difference method used in BP-Free training, the performance of the vanilla BP-Free method is improved, thus significantly reducing the overall FLOPs of FL. Note that, thanks to the data-free approach used to diminish the error of approximating the federated NTK matrix with local NTK matrices, our approach is resilient to various extreme data heterogeneity scenarios. Comprehensive experimental results obtained from simulation- and real test-bed-based platforms with various DNN models demonstrate the effectiveness of our method in terms of memory usage and computation burden.

References

  • Abadi et al. (2016) Martín Abadi, Paul Barham, Jianmin Chen, Zhifeng Chen, Andy Davis, Jeffrey Dean, Matthieu Devin, Sanjay Ghemawat, Geoffrey Irving, Michael Isard, Manjunath Kudlur, Josh Levenberg, Rajat Monga, Sherry Moore, Derek G. Murray, Benoit Steiner, Paul Tucker, Vijay Vasudevan, Pete Warden, Martin Wicke, Yuan Yu, and Xiaoqiang Zheng. 2016. TensorFlow: A system for large-scale machine learning. In Proceedings of USENIX Symposium on Operating Systems Design and Implementation (OSDI). 265–283.
  • Adamczak et al. (2011) Radosław Adamczak, Alexander E Litvak, Alain Pajor, and Nicole Tomczak-Jaegermann. 2011. Sharp bounds on the rate of convergence of the empirical covariance matrix. Comptes Rendus. Mathématique 349, 3-4 (2011), 195–200.
  • Bibikar et al. (2022) Sameer Bibikar, Haris Vikalo, Zhangyang Wang, and Xiaohan Chen. 2022. Federated dynamic sparse training: Computing less, communicating less, yet learning better. In Proceedings of the AAAI Conference on Artificial Intelligence (AAAI). 6080–6088.
  • Cao et al. (2021) Jing Cao, Zirui Lian, Weihong Liu, Zongwei Zhu, and Cheng Ji. 2021. HADFL: Heterogeneity-aware Decentralized Federated Learning Framework. In Proceedings of the Design Automation Conference (DAC). 1–6.
  • Chandrasekaran et al. (2022) Rishikanth Chandrasekaran, Kazim Ergun, Jihyun Lee, Dhanush Nanjunda, Jaeyoung Kang, and Tajana Rosing. 2022. FHDnn: communication efficient and robust federated learning for AIoT networks. In Proceedings of the Design Automation Conference (DAC). 37–42.
  • Cui et al. (2023) Yangguang Cui, Kun Cao, Junlong Zhou, and Tongquan Wei. 2023. Optimizing Training Efficiency and Cost of Hierarchical Federated Learning in Heterogeneous Mobile-Edge Cloud Computing. IEEE Transactions on Computer-Aided Design of Integrated Circuits and Systems (TCAD) 42, 5 (2023), 1518–1531.
  • Elsen et al. (2020) Erich Elsen, Marat Dukhan, Trevor Gale, and Karen Simonyan. 2020. Fast sparse convnets. In Proceedings of the Conference on Computer Vision and Pattern Recognition (CVPR). 14629–14638.
  • Feng et al. (2023) Haozhe Feng, Tianyu Pang, Chao Du, Wei Chen, Shuicheng Yan, and Min Lin. 2023. Does Federated Learning Really Need Backpropagation? arXiv preprint arXiv:2301.12195 (2023).
  • HaoChen et al. (2021) Jeff Z. HaoChen, Colin Wei, Jason D. Lee, and Tengyu Ma. 2021. Shape Matters: Understanding the Implicit Bias of the Noise Covariance. In Proceedings of Annual Conference Computational Learning Theory (COLT). 2315–2357.
  • He et al. (2016) Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. 2016. Deep residual learning for image recognition. In Proceedings of the Conference on Computer Vision and Pattern Recognition (CVPR). 770–778.
  • Huang et al. (2021) Baihe Huang, Xiaoxiao Li, Zhao Song, and Xin Yang. 2021. Fl-ntk: A neural tangent kernel-based framework for federated learning analysis. In Proceedings of International Conference on Machine Learning (ICML). 4423–4434.
  • Huang et al. (2023) Hong Huang, Lan Zhang, Chaoyue Sun, Ruogu. Fang, Xiaoyong Yuan, and Dapeng Wu. 2023. Distributed Pruning Towards Tiny Neural Networks in Federated Learning. In Proceedings of the International Conference on Distributed Computing Systems (ICDCS). 190–201.
  • Jacot et al. (2018) Arthur Jacot, Franck Gabriel, and Clément Hongler. 2018. Neural Tangent Kernel: Convergence and Generalization in Neural Networks. In Proceedings of Advances in Neural Information Processing Systems (NeurIPS). 8580–8589.
  • Jiang et al. (2023) Yuang Jiang, Shiqiang Wang, Víctor Valls, Bong Jun Ko, Wei-Han Lee, Kin K. Leung, and Leandros Tassiulas. 2023. Model Pruning Enables Efficient Federated Learning on Edge Devices. IEEE Transactions on Neural Networks and Learning Systems (TNNLS) 34, 12 (2023), 10374–10386.
  • Khan et al. (2021) Latif U. Khan, Walid Saad, Zhu Han, Ekram Hossain, and Choong Seon Hong. 2021. Federated Learning for Internet of Things: Recent Advances, Taxonomy, and Open Challenges. IEEE Communications Surveys and Tutorials 23, 3 (2021), 1759–1799.
  • Krizhevsky (2009) A. Krizhevsky. 2009. Learning Multiple Layers of Features from Tiny Images. Master's thesis, University of Toronto (2009).
  • LeCun et al. (1998) Yann LeCun, Léon Bottou, Yoshua Bengio, and Patrick Haffner. 1998. Gradient-based learning applied to document recognition. Proc. IEEE 86, 11 (1998), 2278–2324.
  • Li et al. (2021) Ang Li, Jingwei Sun, Binghui Wang, Lin Duan, Sicheng Li, Yiran Chen, and Hai Li. 2021. LotteryFL: Empower Edge Intelligence with Personalized and Communication-Efficient Federated Learning. In Proceedings of IEEE/ACM Symposium on Edge Computing (SEC). 68–79.
  • Li et al. (2020) Jian Li, Xuanyuan Luo, and Mingda Qiao. 2020. On generalization error bounds of noisy gradient methods for non-convex learning. In Proceedings of International Conference on Learning Representations (ICLR).
  • Li et al. (2022) Qinbin Li, Yiqun Diao, Quan Chen, and Bingsheng He. 2022. Federated learning on non-iid data silos: An experimental study. In Proceedings of the International Conference on Data Engineering (ICDE). 965–978.
  • Li and Lyu (2023) Yipeng Li and Xinchen Lyu. 2023. Convergence Analysis of Sequential Federated Learning on Heterogeneous Data. In Proceedings of Annual Conference on Neural Information Processing Systems (NeurIPS).
  • Li and Chen (2021) Zan Li and Li Chen. 2021. Communication-efficient decentralized zeroth-order method on heterogeneous data. In Proceedings of the International Conference on Wireless Communications and Signal Processing (WCSP). 1–6.
  • Liang et al. (2021) Tailin Liang, John Glossner, Lei Wang, Shaobo Shi, and Xiaotong Zhang. 2021. Pruning and quantization for deep neural network acceleration: A survey. Neurocomputing 461 (2021), 370–403.
  • Liao et al. (2023) Dongping Liao, Xitong Gao, Yiren Zhao, and Cheng-Zhong Xu. 2023. Adaptive Channel Sparsity for Federated Learning Under System Heterogeneity. In Proceedings of the Conference on Computer Vision and Pattern Recognition (CVPR). 20432–20441.
  • Liu et al. (2023) Yixing Liu, Yan Sun, Zhengtao Ding, Li Shen, Bo Liu, and Dacheng Tao. 2023. Enhance local consistency in federated learning: A multi-step inertial momentum approach. arXiv preprint arXiv:2302.05726 (2023).
  • McMahan et al. (2017a) Brendan McMahan, Eider Moore, Daniel Ramage, Seth Hampson, and Blaise Aguera y Arcas. 2017a. Communication-efficient learning of deep networks from decentralized data. In Proceedings of Artificial intelligence and statistics (AISTATS). 1273–1282.
  • McMahan et al. (2017b) Brendan McMahan, Eider Moore, Daniel Ramage, Seth Hampson, and Blaise Aguera y Arcas. 2017b. Communication-efficient learning of deep networks from decentralized data. In Proceedings of Artificial Intelligence and Statistics (AISTATS). 1273–1282.
  • Mendieta et al. (2022) Matias Mendieta, Taojiannan Yang, Pu Wang, Minwoo Lee, Zhengming Ding, and Chen Chen. 2022. Local learning matters: Rethinking data heterogeneity in federated learning. In Proceedings of the Conference on Computer Vision and Pattern Recognition (CVPR). 8397–8406.
  • Paszke et al. (2019) Adam Paszke, Sam Gross, Francisco Massa, Adam Lerer, James Bradbury, Gregory Chanan, Trevor Killeen, Zeming Lin, Natalia Gimelshein, Luca Antiga, Alban Desmaison, Andreas Köpf, Edward Yang, Zach DeVito, Martin Raison, Alykhan Tejani, Sasank Chilamkurthy, Benoit Steiner, Lu Fang, Junjie Bai, and Soumith Chintala. 2019. PyTorch: An Imperative Style, High-Performance Deep Learning Library. In Proceedings of Annual Conference on Neural Information Processing Systems (NeurIPS). 8024–8035.
  • Qiu et al. (2022) Xinchi Qiu, Javier Fernandez-Marques, Pedro PB Gusmao, Yan Gao, Titouan Parcollet, and Nicholas Donald Lane. 2022. ZeroFL: Efficient On-Device Training for Federated Learning with Local Sparsity. In Proceedings of the International Conference on Learning Representations (ICLR).
  • Sarkar et al. (2023) Rishov Sarkar, Hanxue Liang, Zhiwen Fan, Zhangyang Wang, and Cong Hao. 2023. Edge-moe: Memory-efficient multi-task vision transformer architecture with task-level sparsity via mixture-of-experts. In Proceedings of the International Conference on Computer-Aided Design (ICCAD). 1–9.
  • Shu et al. (2022) Yao Shu, Shaofeng Cai, Zhongxiang Dai, Beng Chin Ooi, and Bryan Kian Hsiang Low. 2022. NASI: Label- and Data-agnostic Neural Architecture Search at Initialization. In Proceedings of the International Conference on Learning Representations (ICLR).
  • Shu et al. (2023) Yao Shu, Zhongxiang Dai, Weicong Sng, Arun Verma, Patrick Jaillet, and Bryan Kian Hsiang Low. 2023. Zeroth-Order Optimization with Trajectory-Informed Derivative Estimation. In Proceedings of the International Conference on Learning Representations (ICLR).
  • Shukla et al. (2023) Sanket Shukla, Setareh Rafatirad, Houman Homayoun, and Sai Manoj Pudukottai Dinakarrao. 2023. Federated Learning with Heterogeneous Models for On-device Malware Detection in IoT Networks. In Proceedings of Design, Automation & Test in Europe Conference & Exhibition (DATE). 1–6.
  • Stein (1981) Charles M. Stein. 1981. Estimation of the Mean of a Multivariate Normal Distribution. The Annals of Statistics 9, 6 (1981), 1135–1151.
  • Wang et al. (2023) Yite Wang, Dawei Li, and Ruoyu Sun. 2023. NTK-SAP: Improving neural network pruning by aligning training dynamics. In Proceedings of the International Conference on Learning Representations (ICLR).
  • Wu et al. (2024) Donglei Wu, Weihao Yang, Haoyu Jin, Xiangyu Zou, Wen Xia, and Binxing Fang. 2024. FedComp: A Federated Learning Compression Framework for Resource-Constrained Edge Computing Devices. IEEE Transactions on Computer-Aided Design of Integrated Circuits and Systems (TCAD) 43, 1 (2024), 230–243.
  • Xia et al. (2022) Jun Xia, Tian Liu, Zhiwei Ling, Ting Wang, Xin Fu, and Mingsong Chen. 2022. PervasiveFL: Pervasive Federated Learning for Heterogeneous IoT Systems. IEEE Transactions on Computer-Aided Design of Integrated Circuits and Systems (TCAD) 41, 11 (2022), 4100–4111.
  • Yang and Sun (2022) Zhao Yang and Qingshuang Sun. 2022. Personalized Heterogeneity-Aware Federated Search Towards Better Accuracy and Energy Efficiency. In Proceedings of the International Conference on Computer-Aided Design (ICCAD). 59:1–59:9.
  • Yang and Sun (2023) Zhao Yang and Qingshuang Sun. 2023. Mitigating Heterogeneities in Federated Edge Learning with Resource-independence Aggregation. In Proceedings of the Design, Automation & Test in Europe Conference & Exhibition (DATE). 1–2.
  • Yi et al. (2022) Xinlei Yi, Shengjun Zhang, Tao Yang, and Karl H. Johansson. 2022. Zeroth-Order Algorithms for Stochastic Distributed Nonconvex Optimization. Automatica 142 (2022), 110353.
  • Zhang and Tao (2021) Jing Zhang and Dacheng Tao. 2021. Empowering Things with Intelligence: A Survey of the Progress, Challenges, and Opportunities in Artificial Intelligence of Things. IEEE Internet of Things Journal 8, 10 (2021), 7789–7817.
  • Zhu et al. (2022) Zhuangdi Zhu, Junyuan Hong, Steve Drew, and Jiayu Zhou. 2022. Resilient and Communication Efficient Learning for Heterogeneous Federated Systems. In Proceedings of the International Conference on Machine Learning (ICML). 27504–27526.