When Foresight Pruning Meets Zeroth-Order Optimization: Efficient Federated Learning for Low-Memory Devices
Abstract.
Although Federated Learning (FL) enables collaborative learning in Artificial Intelligence of Things (AIoT) design, it fails to work on low-memory AIoT devices due to its heavy memory usage. To address this problem, various federated pruning methods have been proposed to reduce memory usage during inference. However, few of them can substantially mitigate the memory burdens during pruning and training. As an alternative, zeroth-order or backpropagation-free (BP-Free) methods can partially alleviate the memory consumption, but they scale poorly and incur large computation overheads, since the gradient estimation error and floating point operations (FLOPs) increase as the dimensionality of the model parameters grows. In this paper, we propose a federated foresight pruning method based on the Neural Tangent Kernel (NTK), which can seamlessly integrate with federated BP-Free training frameworks. We present an approximation to the computation of the federated NTK by using local NTK matrices. Moreover, we demonstrate that the data-free property of our method can substantially reduce the approximation error in extreme data heterogeneity scenarios. Since our approach improves the performance of the vanilla BP-Free method with fewer FLOPs and truly alleviates memory pressure during training and inference, it makes FL friendlier to low-memory devices. Comprehensive experimental results obtained from simulation- and real test-bed-based platforms show that our federated foresight-pruning method not only preserves the capability of the dense model with a substantial memory reduction but also boosts the performance of the vanilla BP-Free method with dramatically fewer FLOPs.
1. Introduction
As a promising collaborative learning paradigm in Artificial Intelligence of Things (AIoT) (Zhang and Tao, 2021; Shukla et al., 2023; Chandrasekaran et al., 2022) design, Federated Learning (FL) (Khan et al., 2021) enables knowledge sharing among AIoT devices without compromising their data privacy. Specifically, FL enables each device to perform local training and transmit gradients instead of private data. The server then aggregates these local gradients to update the global model (Yang and Sun, 2023) and redistributes it to a selected subset of devices for the next round of training. Unlike centralized training deployed on powerful hardware platforms, FL is restricted by stringent local training resource requirements (Khan et al., 2021; Cui et al., 2023), reflected in the lack of adequate computation and memory resources. The high communication costs during training further limit the practical deployment of FL (Yang and Sun, 2022; Cao et al., 2021). Moreover, the increasing computation overhead of Deep Neural Networks (DNNs) presents a major barrier to running them on edge devices (Sarkar et al., 2023). Prior work has primarily concentrated on model sparsity methods (Liang et al., 2021; Wu et al., 2024) to address these challenges in a one-shot manner. By introducing neural network pruning or sparse training approaches (Bibikar et al., 2022), the local computation overhead and communication costs are considerably reduced. Yet, since local devices privately own the data, standard neural network pruning methods cannot be applied to FL without modification. The notorious data heterogeneity problem further prohibits local devices from learning a common sparse model (Xia et al., 2022).
Table 1. Device peak memory cost and BP-Free support (✓/✗) of representative federated learning methods compared with our approach.
Existing works perform pruning operations under various federated settings (Li et al., 2021; Jiang et al., 2023; Zhu et al., 2022; Liao et al., 2023) to bridge the abovementioned gap. Although these methods reduce inference costs, their training costs in terms of memory consumption are seldom investigated. Since pruning-during-training methods need to maintain the backpropagation structure (Paszke et al., 2019; Abadi et al., 2016), their memory consumption is at least the same as that of vanilla training. Their heavy dependence on the backpropagation framework prevents them from truly supporting a lightweight training paradigm for low-memory devices. As a result, existing federated pruning methods are not efficient in terms of memory usage. To eliminate the memory-inefficient auto-differentiation framework in FL, recent work (Feng et al., 2023; Li and Chen, 2021; Yi et al., 2022) proposed backpropagation-free (BP-Free) FL methods, aiming to estimate the actual gradients using zeroth-order optimization. Integrating federated pruning methods with BP-Free FL methods has enormous potential to reduce memory consumption, but compatibility gaps remain and performance suffers. Existing federated pruning steps are entangled with gradient computations, making them highly sensitive to the precision of the estimated gradients when a BP-Free strategy is used.
To mitigate the problems mentioned above, we seek better solutions from the perspective of federated foresight pruning, which disentangles the pruning and training steps. Inspired by (Feng et al., 2023) and (Wang et al., 2023), we propose a federated, BP-Free foresight pruning method that enables lightweight federated pruning. Additionally, we leverage the sparse structure of the pruned model to greatly reduce the local computation overhead and boost the performance of Stein’s Identity, a classic zeroth-order optimization method, in FL settings. Table 1 compares the device peak memory costs of existing representative federated learning methods with ours. Note that our approach explores the synergy between federated pruning and BP-Free training, aiming to enable FL on AIoT devices with extremely low memory. In summary, this paper makes the following three major contributions:
- We propose a novel memory-efficient foresight pruning method for FL, which is resilient to various data heterogeneities.
- We propose an approximation to the federated NTK matrix and show that the data-free property of our method can effectively reduce the approximation error.
- We implement our proposed foresight pruning method in conjunction with BP-Free training and conduct comprehensive experiments on both simulation- and real test-bed-based platforms to demonstrate the effectiveness of our approach.
2. Related Work
Federated Neural Network Pruning. A two-step pruning method was proposed in PruneFL (Jiang et al., 2023) to reduce the computation cost on the device side. It first completes coarse pruning on a selected device; finer pruning is then conducted on the server side to reduce computation expenses. ZeroFL (Qiu et al., 2022) separates model parameters into active and non-active groups to improve the efficiency of both forward and backward passes. Although the aforementioned methods leave the finer pruning to the server side, they still require a large local memory footprint to record the updated importance scores of all parameters for the guidance of pruning. FedDST (Bibikar et al., 2022), instead, adopts sparse training, which consists of mask adjustment and regrowth on the device side. Once the local sparse training is finished, the server conducts sparse aggregation and magnitude pruning to obtain a global model. Even with the sparse training framework, the high memory requirements are not relieved. Worse still, the local mask adjustment and regrowth processes inevitably need additional training epochs, which contradicts the principle of minimizing local training steps in FL (Mendieta et al., 2022). To reduce the cost of the memory- and computation-intensive pruning process for extremely resource-constrained devices, FedTiny (Huang et al., 2023) adaptively searches for finer-pruned specialized models by progressive pruning based on a coarsely pruned model. Effective as it is in the low-density regime for inference, its training step is hindered by significant communication and memory burdens, making it less compatible with devices of limited memory capacity.
BP-Free Federated Learning. Gradient-based optimization techniques are commonly used for training DNNs in FL settings. However, backpropagation (BP) operations heavily depend on an auto-differentiation framework (e.g., PyTorch (Paszke et al., 2019)). The additional static and dynamic memory requirements are unaffordable for low-memory devices. Recently, advances in zeroth-order optimization methods (Li and Chen, 2021; Yi et al., 2022; Shu et al., 2023) have demonstrated potential for federated training, particularly in cases where backward passes are expensive in terms of computation and memory. (Feng et al., 2023) proposes to utilize Stein’s Identity (Stein, 1981) as the gradient estimation method to develop a BP-Free federated learning framework. Yet, the Floating Point Operations (FLOPs) per communication round increase to an extremely high level during training due to the nature of BP-Free gradient estimation: numerous forward passes are required for a single gradient estimation to decrease the estimation error (Adamczak et al., 2011; Feng et al., 2023). The degradation of learning performance is another issue, since BP-Free training comes at the cost of adding noise (Li et al., 2020; HaoChen et al., 2021) to the gradients. Therefore, we propose using sparsity techniques to reduce the FLOPs and boost learning performance in a one-shot manner.
3. Preliminaries
Let $\mathcal{D} = \{(x_i, y_i)\}_{i=1}^{n}$ denote the entire training dataset and let $n$ represent the total number of data points, where $x_i$ and $y_i$ denote inputs and labels, respectively.
Federated Learning. FL aims to collaboratively learn a global model parameterized by $w \in \mathbb{R}^{d}$ while keeping local data private. Given that $N$ devices are involved in each round of local training, where the local dataset on the $i$-th device is denoted as $\mathcal{D}_i$ with $n_i$ representing its number of data points, the objective is defined by
(1) $\min_{w} F(w) = \sum_{i=1}^{N} \frac{n_i}{n} F_i(w),$
where $F_i(w)$ is the specified loss function on the local dataset $\mathcal{D}_i$.
Neural Tangent Kernel. The Neural Tangent Kernel (NTK) analyzes the training dynamics of neural networks (Jacot et al., 2018). Given an arbitrary DNN $f(x; w)$ initialized by $w_0$, the NTK at the initial state is defined as
$\Theta_0(x, x') = \nabla_w f(x; w_0)\, \nabla_w f(x'; w_0)^\top.$
It has been proven that the NTK stays asymptotically constant during training if the DNN is sufficiently large in terms of parameters. Therefore, the NTK at initialization (i.e., $\Theta_0$) can characterize the training dynamics.
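To make the definition concrete, the following minimal PyTorch sketch computes an empirical NTK at initialization for a toy model by forming per-sample parameter gradients explicitly; the model, data, and sizes are illustrative assumptions rather than the paper's configuration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))
x = torch.randn(5, 8)  # 5 data points, input dimension 8

def flat_grad(output_scalar, params):
    # gradient of a scalar network output w.r.t. all parameters, flattened
    grads = torch.autograd.grad(output_scalar, params, retain_graph=True)
    return torch.cat([g.reshape(-1) for g in grads])

params = [p for p in model.parameters() if p.requires_grad]
jacobian = torch.stack(
    [flat_grad(model(x[i:i + 1]).squeeze(), params) for i in range(x.shape[0])]
)

ntk = jacobian @ jacobian.T  # 5 x 5 empirical NTK at initialization
print(ntk.shape, torch.linalg.matrix_norm(ntk, ord="nuc"))  # nuclear norm = trace (PSD)
```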
Foresight Pruning Based on NTK. The general objective function of foresight pruning is formulated as
(2) $\min_{\mathbf{m} \in \{0,1\}^{d}} \; \mathcal{L}\big(f(\mathcal{D}; w_0 \odot \mathbf{m})\big) \quad \text{s.t.} \quad \|\mathbf{m}\|_0 \le \rho \cdot d,$
where $w_0 \odot \mathbf{m}$ denotes the model architecture dominated by the binary mask $\mathbf{m}$, and $w_0$ represents the model parameters at initialization. We aim to find the mask $\mathbf{m}$, in which each element follows a binary distribution, that minimizes the loss function while being constrained by the target density $\rho$. To make Eq. 2 tractable, existing foresight pruning methods propose a saliency measurement function, defined as
(3) $S(m_j) = \frac{\partial g(w_0, \mathbf{m})}{\partial m_j},$
where $g(w_0, \mathbf{m})$ represents a function of the model parameters $w_0$ and the mask $\mathbf{m}$, $m_j$ denotes the $j$-th element of the mask $\mathbf{m}$, which is a scalar, and $S(m_j)$ represents the saliency score function that measures the impact of deactivating the $j$-th connection (i.e., setting $m_j$ to $0$). After the saliency scores are computed, we keep the top-$k$ elements of the mask. In detail, we set an element of the mask to $1$ ($m_j = 1$) if it belongs to the set of top-$k$ elements and to $0$ ($m_j = 0$) otherwise. To analyze the property of $g$, we follow the work proposed by (Wang et al., 2023) and compute its spectrum, which is formulated as
(4) $g(w_0, \mathbf{m}) = \big\| \Theta_0(w_0 \odot \mathbf{m}) \big\|_{*},$
where $\|\cdot\|_{*}$ is the nuclear norm operation. The nuclear norm is equivalent to the trace since the NTK matrix is symmetric positive semi-definite. To further reduce the cost of computing the NTK matrix, we rewrite the trace of the NTK matrix as the squared Frobenius norm of the model Jacobian, $\|\nabla_w f(\mathcal{D}; w_0 \odot \mathbf{m})\|_F^2$. We aim to find a sparse model with the same training dynamics as the dense model. Therefore, the objective function of NTK-based foresight pruning is formulated as
(5) $\min_{\mathbf{m}} \; \Big| \big\| \Theta_0(w_0 \odot \mathbf{m}) \big\|_{*} - \big\| \Theta_0(w_0) \big\|_{*} \Big| \quad \text{s.t.} \quad \|\mathbf{m}\|_0 \le \rho \cdot d.$
Note that to achieve a BP-Free method, the squared Frobenius norm of the Jacobian can be approximated by
(6) $\big\| \nabla_w f(\mathcal{D}; w_0 \odot \mathbf{m}) \big\|_F^2 \approx \mathbb{E}_{\Delta w \sim \mathcal{N}(0, \sigma^2 I_d)}\!\left[ \frac{\big\| f\big(\mathcal{D}; (w_0 + \Delta w) \odot \mathbf{m}\big) - f\big(\mathcal{D}; w_0 \odot \mathbf{m}\big) \big\|_2^2}{\sigma^2} \right],$
which saves the expensive computation of Eq. 5.
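The following sketch illustrates one way to realize the forward-only approximation in Eq. 6: the squared Frobenius norm of the model Jacobian is estimated by perturbing the weights with Gaussian noise and averaging the squared output differences. The model, noise scale, and sample count are assumptions for illustration.

```python
import copy
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
x = torch.randn(32, 8)            # a batch of (possibly random, data-free) inputs
sigma, n_samples = 1e-2, 64

@torch.no_grad()                  # forward passes only, no autograd graph needed
def jacobian_fro_sq_bpfree(model, x, sigma, n_samples):
    base = model(x)
    acc = 0.0
    for _ in range(n_samples):
        noisy = copy.deepcopy(model)
        for p in noisy.parameters():
            p.add_(torch.randn_like(p) * sigma)        # w + delta_w
        acc += ((noisy(x) - base) ** 2).sum().item() / sigma ** 2
    return acc / n_samples

print(jacobian_fro_sq_bpfree(model, x, sigma, n_samples))
```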
Finite Difference and Stein’s Identity. Derived from the definition of derivatives, the Finite Difference (FD) method can be extended to multivariate cases by using Taylor’s expansion. Given the loss function $L(w)$, a direction $v$, and a small perturbation scale $\sigma$, the FD method is defined as
(7) $\nabla_w L(w)^\top v \approx \frac{L(w + \sigma v) - L(w)}{\sigma}.$
Assuming the loss function $L$ is continuously differentiable w.r.t. the model parameter $w$, the gradient of $L$ estimated by Stein’s Identity is formulated as
(8) $\nabla_w L(w) = \mathbb{E}_{\delta \sim \mathcal{N}(0, \Sigma)}\!\left[ \frac{L(w + \delta) - L(w)}{\sigma^2}\, \delta \right],$
where $\delta \in \mathbb{R}^{d}$ follows a Gaussian distribution with zero mean and covariance $\Sigma = \sigma^2 I_d$. We utilize the Monte Carlo method to sample $K$ instances of $\delta$, thereby obtaining a stochastic version of the estimation as
(9) $\hat{\nabla}_w L(w) = \frac{1}{K} \sum_{k=1}^{K} \frac{L(w + \delta_k) - L(w)}{\sigma^2}\, \delta_k.$
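A minimal sketch of the Monte Carlo estimator in Eq. 9 on a toy quadratic loss is given below; the loss, dimensionality, and hyperparameters are illustrative assumptions, and in the paper this estimator wraps the local training loss of a DNN.

```python
import torch

torch.manual_seed(0)
d, K, sigma = 100, 200, 1e-2
w = torch.randn(d)
A = torch.randn(d, d) / d ** 0.5

def loss(v):                       # any black-box loss; here a smooth toy quadratic
    return 0.5 * (A @ v).pow(2).sum()

base = loss(w)
est = torch.zeros(d)
for _ in range(K):
    delta = torch.randn(d) * sigma             # delta ~ N(0, sigma^2 I)
    est += (loss(w + delta) - base) / sigma ** 2 * delta
est /= K

true = A.T @ (A @ w)                           # analytic gradient of the toy loss
cos = torch.dot(est, true) / (est.norm() * true.norm())
print(f"cosine similarity to true gradient: {cos:.3f}")
```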
4. Methodology
4.1. Estimation Error of Stein’s Identity
Though the sampling method in Eq. 9 is friendly to low-memory devices, the computation overhead is extremely heavy since $K$ should be large for an accurate estimation. Following what was proven in (Feng et al., 2023), the estimation error is formally defined by the following theorem:
Theorem 1.
(Estimation error (Feng et al., 2023)) Let $\delta_k \sim \mathcal{N}(0, \Sigma)$ with covariance matrix $\Sigma = \sigma^2 I_d$, and let $d$ be the dimension of the trainable parameters of the DNN. The discrepancy between the true gradient $\nabla_w L(w)$ and the estimated gradient $\hat{\nabla}_w L(w)$ is formulated as
$\hat{\nabla}_w L(w) = \Big( \frac{1}{K} \sum_{k=1}^{K} \frac{\delta_k \delta_k^\top}{\sigma^2} \Big) \nabla_w L(w) + \mathcal{O}(\sigma);$
$\mathbb{E}\, \big\| \hat{\nabla}_w L(w) - \nabla_w L(w) \big\|_2^2 \le \Big\| \frac{1}{K} \sum_{k=1}^{K} \frac{\delta_k \delta_k^\top}{\sigma^2} - I_d \Big\|_2^2 \, \big\| \nabla_w L(w) \big\|_2^2 + \mathcal{O}(\sigma^2).$
Neglecting the trivial term $\mathcal{O}(\sigma^2)$, the estimation error measured by the Mean Square Error (MSE) is fully controlled by the term $\big\| \frac{1}{K} \sum_{k=1}^{K} \delta_k \delta_k^\top / \sigma^2 - I_d \big\|_2$. Based on the proof in (Adamczak et al., 2011), we have $\big\| \frac{1}{K} \sum_{k=1}^{K} \delta_k \delta_k^\top / \sigma^2 - I_d \big\|_2 = \mathcal{O}(\sqrt{d/K})$. In conclusion, we can expect a more accurate estimation if we either increase the number of Monte Carlo steps $K$ or decrease the dimensionality $d$ of the model parameters, or both. In practice, $K$ determines the total FLOPs consumed to estimate the true gradient, and $d$ reflects the memory required for a device to host the model. Developing more advanced zeroth-order algorithms that lower the value of $K$ required to achieve the same performance is a promising way to obtain more accurate estimations. Another way is to put effort into decreasing the value of $d$. To achieve the latter, we introduce model pruning, which reduces the computation overhead and increases the estimation precision, as it aims to find a sparse structure of the model that still maintains its performance.
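The trade-off described above can be illustrated numerically: in the toy sweep below (our own illustration, not an experiment from the paper), the relative estimation error of the zeroth-order estimator shrinks as $K$ grows and grows with the dimensionality $d$.

```python
import torch

torch.manual_seed(0)

def rel_error_sq(d, K, sigma=1e-3, trials=10):
    """Average squared relative error of the Stein estimator on a toy quadratic loss."""
    A = torch.randn(d, d) / d ** 0.5
    w = torch.randn(d)
    true = A.T @ (A @ w)                              # analytic gradient
    loss = lambda v: 0.5 * (A @ v).pow(2).sum()
    base, err = loss(w), 0.0
    for _ in range(trials):
        deltas = torch.randn(K, d) * sigma            # K Gaussian perturbations
        vals = torch.stack([(loss(w + dlt) - base) / sigma ** 2 for dlt in deltas])
        est = (vals[:, None] * deltas).mean(dim=0)
        err += ((est - true).norm() / true.norm()).item() ** 2
    return err / trials

for d in (50, 200):
    print(d, [round(rel_error_sq(d, K), 2) for K in (10, 100, 1000)])
```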
Input: 1) $w_0$, a randomly initialized global model; 2) $\mathcal{P}$, a pool of all participants; 3) $R_p$, # of pruning rounds; 4) $R_t$, # of training rounds; 5) $N_p$, # of participants for pruning; 6) $N_t$, # of participants for training; 7) $K$, # of perturbations to estimate gradients; 8) $\rho$, target density; 9) $\mathbf{m}$, mask for parameters; 10) $\eta$, learning rate.
NTK Foresight Pruning:
BP-Free Training:
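Because the displayed pseudocode omits the procedure bodies, the following compact toy sketches how we read the two phases of Algorithm 1 on a tiny linear model: a data-free NTK saliency pruning step followed by BP-Free training with Stein's Identity and a shared random seed. All sizes, schedules, and subroutines here are simplified assumptions, not the paper's implementation.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
N_DEV, K, SIGMA, DENSITY, ROUNDS, LR = 4, 64, 1e-2, 0.3, 20, 0.1
loss_fn = nn.CrossEntropyLoss()
local_data = [(torch.randn(64, 16), torch.randint(0, 4, (64,))) for _ in range(N_DEV)]
w0 = torch.randn(4, 16) * 0.1                        # tiny linear "global model" (4 classes)

# ---- Phase 1: NTK foresight pruning (data-free variant, run on the server) ----
m = torch.ones_like(w0, requires_grad=True)          # differentiable mask for the saliency
x_probe = torch.randn(128, 16)                       # random, label-free probe inputs
base = x_probe @ (w0 * m).T
g = torch.zeros(())
for _ in range(8):                                   # Monte Carlo over weight perturbations
    dw = torch.randn_like(w0) * SIGMA
    g = g + (((x_probe @ ((w0 + dw) * m).T) - base) ** 2).sum() / (SIGMA ** 2 * 8)
saliency = torch.autograd.grad(g, m)[0].abs().reshape(-1)
mask = torch.zeros_like(saliency)
mask[saliency.topk(int(DENSITY * saliency.numel())).indices] = 1.0
mask = mask.reshape_as(w0)

# ---- Phase 2: BP-Free training with Stein's Identity and the shared-seed trick ----
w = (w0 * mask).clone()
for rnd in range(ROUNDS):
    gen = torch.Generator().manual_seed(1000 + rnd)  # seed shared by server and devices
    deltas = [torch.randn(w.shape, generator=gen) * SIGMA * mask for _ in range(K)]
    scalars = torch.zeros(N_DEV, K)                  # each device uploads K scalars only
    for i, (x, y) in enumerate(local_data):
        base_loss = loss_fn(x @ w.T, y)
        for k, d in enumerate(deltas):
            scalars[i, k] = (loss_fn(x @ (w + d).T, y) - base_loss) / SIGMA ** 2
    coeff = scalars.mean(dim=0)                      # FedAvg over devices (equal sizes here)
    grad_est = sum(c * d for c, d in zip(coeff, deltas)) / K
    w = w - LR * grad_est * mask                     # pruned entries stay zero
print("density of the trained sparse model:", (w != 0).float().mean().item())
```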
Current federated pruning methods are highly integrated with the training process. Consequently, the noisy gradients negatively affect the pruning step, decreasing training performance. To mitigate the impact of noisy gradients on the pruning step, we propose decoupling the pruning processes from the training processes, which is one of the fundamental motivations of foresight pruning. Further, we dedicate efforts to developing a federated foresight-pruning method based on the NTK spectrum, ensuring full compatibility with the BP-Free training framework.
4.2. NTK-based Foresight Pruning for FL
Since the symmetric property of the NTK matrix is only valid for centralized training, we cannot directly apply it to federated training. Motivated by (Huang et al., 2021), we prove that the NTK in federated learning can be approximated by local NTK matrices.
Definition 0.
(Asymmetric FL-NTK) Let $\Theta_i \in \mathbb{R}^{n \times n_i}$ represent the local NTK of the $i$-th device based on its dataset, i.e., the NTK columns associated with the inputs in $\mathcal{D}_i$. Let $\mathcal{D}$ represent $\bigcup_{i=1}^{N} \mathcal{D}_i$ and $n$ represent $\sum_{i=1}^{N} n_i$. The asymmetric FL-NTK is defined as $\Theta_{FL} = [\Theta_1, \Theta_2, \ldots, \Theta_N] \in \mathbb{R}^{n \times n}$. The matrix $\tilde{\Theta}_i \in \mathbb{R}^{n \times n}$ is formulated by keeping the columns of $\Theta_i$ in their corresponding positions for each $i$ and padding the rest of the elements with zeros.
The asymmetric property of FL-NTK prevents us from computing the nuclear norm by the Frobenius norm of gradients based on each local dataset. However, we prove that we can utilize the summation of local symmetric NTK matrices to approximate the FL-NTK and efficiently conduct the Frobenius norm computation.
Proposition 0.
Since $\Theta_{FL}$ is the horizontal concatenation of $\Theta_i$ for $i = 1, \ldots, N$, we decompose it as the summation of sparse NTK matrices. Let $\tilde{\Theta}_i$ represent the sparse NTK matrix for the $i$-th device. We can reformulate the FL-NTK matrix as
(10) $\Theta_{FL} = [\Theta_1, \Theta_2, \ldots, \Theta_N]$
(11) $= \sum_{i=1}^{N} \tilde{\Theta}_i,$
where the non-filled elements of each $\tilde{\Theta}_i$ are all zeros.
Since the nuclear norm satisfies the triangle inequality, we can further bound the computation of $\|\Theta_{FL}\|_{*}$ as
(12) $\big\| \Theta_{FL} \big\|_{*} = \Big\| \sum_{i=1}^{N} \tilde{\Theta}_i \Big\|_{*} \le \sum_{i=1}^{N} \big\| \tilde{\Theta}_i \big\|_{*} = \sum_{i=1}^{N} \big\| \Theta_i \big\|_{*}.$
Therefore, the complex FL-NTK is upper-bounded by the summation of the individual local NTK matrices. Since the sparse model preserves the same training dynamics as the dense model and the gradients estimated by Stein’s Identity are unbiased, the overall convergence rate with respect to the number of training rounds $T$ is $\mathcal{O}(1/\sqrt{T})$, the same as the standard FedAvg (Li and Lyu, 2023) algorithm under non-convex settings.
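As a sanity check of the triangle-inequality step in Eq. 12, the small numerical illustration below builds an FL-NTK from zero-padded per-device blocks (under our reading of the block structure, which is an assumption) and verifies that its nuclear norm is bounded by the sum of the per-device nuclear norms.

```python
import torch

torch.manual_seed(0)
N, n_i, p = 4, 20, 50                               # devices, samples per device, param dim
J_local = [torch.randn(n_i, p) for _ in range(N)]   # per-device Jacobians at initialization
J = torch.cat(J_local, dim=0)                       # Jacobian over the union of local data
n = N * n_i

blocks = [J @ Ji.T for Ji in J_local]               # n x n_i blocks: columns of device i
fl_ntk = torch.cat(blocks, dim=1)                   # horizontal concatenation of the blocks

padded = []
for i, B in enumerate(blocks):
    Z = torch.zeros(n, n)
    Z[:, i * n_i:(i + 1) * n_i] = B                 # keep device i's columns, zero-pad the rest
    padded.append(Z)

lhs = torch.linalg.matrix_norm(fl_ntk, ord="nuc")
rhs = sum(torch.linalg.matrix_norm(Z, ord="nuc") for Z in padded)
print(f"nuclear norm of FL-NTK: {lhs:.1f} <= sum of per-device nuclear norms: {rhs:.1f}")
```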
Unlike centralized pruning, in FL the entire dataset is distributed across multiple devices. According to the form of Eq. 3, the computation of the saliency score demands at least one gradient computation, i.e., $\partial g / \partial m_j$, if we do not consider the function $g$ itself. We can wisely let the local devices compute the function $g$ and leave the derivative and multiplication operations in Eq. 15 for the server to complete. In this way, local datasets contribute equally to the generation of the binary mask, yielding an unbiased global sparse model, and the memory required to support the auto-differentiation framework is shifted to the server. Things might be different when considering the specific form of the function $g$. As mentioned above, the function defined in Eq. 4 can be approximated either by Eq. 6 or by the loss gradient w.r.t. the weights. The latter introduces extra gradient computation, which inevitably relies on the auto-differentiation framework. In contrast, Eq. 6 benefits from being independent of the auto-differentiation framework. Regardless, the computation of the saliency score in Eq. 3 still intertwines with it (i.e., $\partial g / \partial m_j$). We might introduce another zeroth-order optimization method to estimate the derivatives, but doing so introduces more estimation errors. As a result, we choose Eq. 6 as the function $g$ so that the local devices are free of the memory-inefficient backpropagation framework. The federated NTK foresight pruning is organized as
(13) $g_i(w_0, \mathbf{m}) = \mathbb{E}_{\Delta w \sim \mathcal{N}(0, \sigma^2 I_d)}\!\left[ \frac{\big\| f\big(\mathcal{D}_i; (w_0 + \Delta w) \odot \mathbf{m}\big) - f\big(\mathcal{D}_i; w_0 \odot \mathbf{m}\big) \big\|_2^2}{\sigma^2} \right]$
(14) $g(w_0, \mathbf{m}) = \frac{1}{N} \sum_{i=1}^{N} g_i(w_0, \mathbf{m})$
(15) $S(m_j) = \left| \frac{\partial g(w_0, \mathbf{m})}{\partial m_j} \right|,$
where $\mathcal{D}_i$ represents the $i$-th dataset owned by the corresponding device, $w_{0,j}$ denotes the $j$-th element of the initial global model, and $\Delta w$ is sampled from a Gaussian distribution with zero mean and $\sigma^2$ variance. The symbol $\mathbb{E}$ denotes the expectation over $\Delta w$, which can be approximated by the Monte Carlo sampling method. To avoid layer collapse, we use an exponential decay schedule to compute the pruning threshold, keeping a density of $\rho_r = \rho^{r/R_p}$ at round $r$, where $r$ denotes the current pruning round, $\rho$ is the target density, and $R_p$ is the maximum pruning round. The parameters whose saliency scores fall under the threshold are pruned. It may be noted that the proposed NTK pruning method introduces extra computations on the devices. In addition, the multiple-round pruning schedule requires communication between the server and all devices. To resolve the two problems in one shot, we fully explore the intrinsic property of NTK pruning. Motivated by (Shu et al., 2022), we show that under data heterogeneity scenarios, the difference between any two local NTK matrices is governed by the discrepancy between the local data distributions $\mu_i$ and $\mu_j$ and by the input dimension. Therefore, the approximation error introduced by the summation operation in Eq. 12 can be represented by this difference and increases as the degree of heterogeneity grows. To alleviate the approximation error, we constrain the difference between $\mu_i$ and $\mu_j$ by sampling the data in Eq. 13 from the same standard Gaussian distribution, which diminishes the difference between any two NTK matrices and makes our method resilient to data heterogeneity. Moreover, the computation of Eq. 13 can then be fully conducted by the server, so our proposed federated NTK pruning method degrades to the centralized version, entirely saving the communication overhead.
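The sketch below illustrates the data-free, multi-round variant described above under our assumptions: standard-Gaussian probe inputs, the exponential density schedule $\rho_r = \rho^{r/R_p}$, and a saliency given by the magnitude of the derivative of the forward-only objective (Eq. 13) with respect to the mask.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
R, rho_target, sigma, mc = 5, 0.2, 1e-2, 8
w0 = nn.Linear(32, 8).weight.detach()              # toy layer standing in for the model

def saliency(w0, mask):
    m = mask.clone().requires_grad_(True)
    x = torch.randn(256, 32)                       # data-free, label-free Gaussian probes
    base = x @ (w0 * m).T
    g = torch.zeros(())
    for _ in range(mc):                            # Monte Carlo over weight perturbations
        dw = torch.randn_like(w0) * sigma
        g = g + (((x @ ((w0 + dw) * m).T) - base) ** 2).sum() / (sigma ** 2 * mc)
    return torch.autograd.grad(g, m)[0].abs()      # |dg/dm_j| as the saliency score

mask = torch.ones_like(w0)
for r in range(1, R + 1):
    rho_r = rho_target ** (r / R)                  # exponential decay toward target density
    scores = saliency(w0, mask) * mask             # already-pruned weights stay pruned
    k = int(rho_r * w0.numel())
    new_mask = torch.zeros_like(mask).view(-1)
    new_mask[scores.view(-1).topk(k).indices] = 1.0
    mask = new_mask.view_as(mask)
    print(f"round {r}: density = {mask.mean().item():.2f}")
```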
4.3. Backpropagation-free Training for FL
We further integrate the proposed NTK foresight pruning with Stein’s Identity method to collaboratively build a memory-friendly federated training framework. To leverage Stein’s Identity in the context of FL, we rewrite the estimation of the aggregated gradient in the form of FedAvg (McMahan et al., 2017b) as
(16) $\hat{\nabla}_w F(w) = \sum_{i=1}^{N} \frac{n_i}{n}\, \hat{\nabla}_w F_i(w) = \sum_{i=1}^{N} \frac{n_i}{n} \left( \frac{1}{K} \sum_{k=1}^{K} \frac{F_i(w + \delta_k) - F_i(w)}{\sigma^2}\, \delta_k \right).$
Note that $\delta_k \in \mathbb{R}^{d}$, where $d$ is the dimensionality of the gradients we want to estimate. To further decrease the communication burden, we use the Random Seed Trick. The Random Seed Trick smartly leverages the inherent properties of the development environment: i) $\delta_k$ has the same dimensionality as the gradients we want to estimate and is generated by sampling from a Gaussian distribution; ii) the sampling from the Gaussian distribution is driven by the random seed at the software level. Therefore, the server can produce the same sampling results as the local devices by maintaining the same random seed and running environment. In detail, Eq. 16 is performed by the server and the devices as:
(17) $s_{i,k} = \frac{F_i(w + \delta_k) - F_i(w)}{\sigma^2} \;\; \text{(device side)}, \qquad \hat{\nabla}_w F(w) = \sum_{i=1}^{N} \frac{n_i}{n} \cdot \frac{1}{K} \sum_{k=1}^{K} s_{i,k}\, \delta_k \;\; \text{(server side)}.$
In practice, we let the devices compute the scalars $s_{i,k}$ and leave the multiplications with $\delta_k$ to the server. Consequently, the communication cost is significantly reduced from $d$-dimensional vectors to $K$ scalars, since we do not need to transmit the vectors $\delta_k$. The overall training process is shown in Algorithm 1.
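The following sketch isolates the Random Seed Trick on a single device with a toy least-squares loss: the device returns only $K$ scalar coefficients, and the server regenerates the identical perturbations from the shared seed to rebuild the update. The message format and hyperparameters are illustrative assumptions.

```python
import torch

d, K, sigma, seed = 10_000, 100, 1e-2, 1234
w = torch.randn(d)
X, y = torch.randn(256, d), torch.randn(256)
loss = lambda v: ((X @ v - y) ** 2).mean()          # stand-in for the local training loss

def device_round(w, seed):
    """Runs on the device: forward passes only, returns K scalars."""
    gen = torch.Generator().manual_seed(seed)
    base = loss(w)
    coeffs = []
    for _ in range(K):
        delta = torch.randn(d, generator=gen) * sigma
        coeffs.append(((loss(w + delta) - base) / sigma ** 2).item())
    return coeffs                                   # K floats (~0.4 KB) vs. d floats (~40 KB)

def server_aggregate(w, coeffs, seed, lr=0.05):
    """Runs on the server: regenerates the same deltas and reconstructs the update."""
    gen = torch.Generator().manual_seed(seed)       # identical seed => identical deltas
    grad_est = torch.zeros(d)
    for c in coeffs:
        grad_est += c * torch.randn(d, generator=gen) * sigma
    return w - lr * grad_est / K

coeffs = device_round(w, seed)
w_new = server_aggregate(w, coeffs, seed)
print(f"uploaded {len(coeffs)} scalars instead of a {d}-dimensional vector")
```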
5. Experiments
In this section, we conducted federated pruning and federated zeroth-order experiments to answer the following two pivotal Research Questions: RQ1: What are the advantages of our NTK-based federated foresight pruning? RQ2: What improvements do the NTK-based foresight federated pruning bring to the standard federated zeroth-order method?
5.1. Experimental Settings
Dataset. We evaluated the performance on two image classification tasks, i.e., CIFAR-10 and CIFAR-100 (Krizhevsky, 2009). Following the heterogeneity configurations in (Li et al., 2022), we considered the Dirichlet distribution throughout this paper, which is parameterized by a coefficient $\alpha$ and denoted as $Dir(\alpha)$. $\alpha$ determines the degree of data heterogeneity: the smaller $\alpha$ is, the more heterogeneous the data will be. We only considered one non-IID scenario with a single fixed value of $\alpha$.
Models. To show the scalability of the proposed foresight pruning to various types of models, we utilized LeNet-5 (LeCun et al., 1998) and ResNet-20 (He et al., 2016) with batch normalization. For the federated pruning evaluation experiments, we utilized LeNet-5 for CIFAR-10 and ResNet-20 for CIFAR-100. For the experiments on BP-Free federated training, we used LeNet-5 and ResNet-20 for both the CIFAR-10 and CIFAR-100 datasets.
Baselines. For the pruning evaluation, we included FedDST and PruneFL for comparison with the proposed NTK pruning. FedDST utilizes a sparse training method and conducts neuron growth and pruning on local devices. PruneFL first conducts local coarse pruning on a randomly selected device and then performs server-side pruning based on gradients collected from local devices. We excluded federated pruning methods that are inapplicable to memory-constrained FL (i.e., ZeroFL, LotteryFL). To show the improvements brought by our NTK pruning to BP-Free federated learning, we compared it with the state-of-the-art method BAFFLE (Feng et al., 2023), which conducts the vanilla Stein’s Identity method. Note that the FedAvg results are obtained with backpropagation-based training.
Hyperparameter Settings. The pruning rounds for LeNet and ResNet-20 are set to and , respectively. The corresponding learning rate and batch size are set to and . Following the source code in (Bibikar et al., 2022), we set the number of local training epochs to if not specified. We randomly selected out of devices to perform local training in each round, and the training batch size is set to . For experiments on CIFAR-10, we set the learning rate, momentum, and weight decay to , , and , respectively. For CIFAR-100, the momentum and weight decay are set to and for both models. Specifically, we set the learning rate to for LeNet and for ResNet. The learning rate decays to after each round for all experiments. For all experiments, we set $K$, the number of Monte Carlo steps used to estimate the true gradient, to 50, 100, and 200, respectively. For a fair comparison, we used a sparsity level of % and 5-epoch local training based on the original FedDST paper.
5.2. Experimental Results for Federated Pruning
Accuracy and FLOPs Analysis. Table 2 shows the maximum accuracy and FLOPs for all federated pruning baselines under an extreme non-IID scenario. The training curves for CIFAR-10 are shown in Figure 1. Note that “NTK-Ori” represents the results on real data and “NTK-Rand” represents those on randomly generated data. We denote epoch as “ep” in the tables. Since a time coefficient is required to implement PruneFL, we omit PruneFL for the ResNet-20 model, as the source code of neither FedDST nor PruneFL provides the corresponding time coefficient. The sparsity of all experiments is initialized to . We use FLOPs to evaluate whether local devices suffer from intensive computation overhead, thus showing the efficiency of the federated learning algorithm.
For the FedAvg method, the accuracy after convergence is . Since it is conducted on a dense model, the FLOPs are not reduced, making it inefficient for resource-constrained devices. For the FedDST method, the FLOPs are significantly reduced to of the dense model. However, the accuracy drop is not negligible. In addition, the convergence speed is slower than that of FedAvg due to the constant neuron growth and pruning steps. For the PruneFL approach, the accuracy drop after convergence is smaller than that of FedDST. Note that PruneFL in our experiment ends up with an all-one mask, which makes the model dense again when it converges. Therefore, the FLOPs of PruneFL remain the same as those of FedAvg. Since our NTK method collects pruning information from all local devices, the accuracy drop is largely reduced to . Moreover, the NTK method avoids model sparsity readjustment during training. As a result, the training stability shown in Figure 1 is consolidated compared with FedDST and PruneFL. Though the FLOPs for a single forward pass of NTK are not as small as those of FedDST, the sparsity readjustment in FedDST inevitably brings more computation expenses during training. For the CIFAR-100 dataset, we can observe that FedDST outperforms the other methods in terms of accuracy. Yet, FedDST implements multiple local training epochs (5 epochs in the original paper) to realize the sparsity readjustment, resulting in a heavy computation burden. We further display the experimental results of NTK pruning given 1 local training epoch. Its accuracy outperforms 5-epoch FedDST by , and its computation expense in FLOPs is cheaper than that of FedDST. Since the accuracy advantage brought by the single local epoch aligns with the observation that fewer local steps enhance global consistency (Liu et al., 2023), we conclude that our proposed pruning method enjoys both computation efficiency and better performance.
Communication Cost Analysis. Assume that the target density is $\rho$ and the number of parameters of the model is $d$. PruneFL requires devices to send full gradients to the server every $\Delta R$ rounds, resulting in an average upload cost of bits per device per round, where $\Delta R$ is the interval of sparsity readjustment. Additionally, the maximum upload cost amounts to parameters. On the contrary, FedDST has an average communication cost of bits per device per round before the completion of sparsity readjustment. Once the readjustment is finished, the communication cost decreases to . In our method, the worst-case communication expense is bits per device per round.
Data-free Foresight Federated-Pruning Analysis. Figure 1 shows that the results based on randomly generated data are superior to those on real data. In addition, the data- and label-agnostic properties provide shortcuts to avoid redundant upload and download costs during the pruning step. Using randomly generated data, the communication expense during the foresight pruning is reduced to zero, so the pruning stage adds no transmission overhead to the overall training process. As a result, the proposed NTK pruning demonstrates significant advantages in reducing transmission burdens.
5.3. Experimental Results for BAFFLE
Table 2. Maximum accuracy and FLOPs of the federated pruning baselines (LeNet-5 on CIFAR-10, ResNet-20 on CIFAR-100).

Model | Method | Max Acc. | FLOPs
---|---|---|---
LeNet-5 | FedAvg | |
| FedDST | |
| PruneFL | |
| NTK-Rand | |
ResNet-20 | FedAvg | |
| FedDST | |
| NTK-Rand | |
| FedAvg (1-ep) | |
| NTK-Rand (1-ep) | |
Classification Accuracy Improvements. Since the results in Table 2 and Figure 1 show that a single local epoch and “NTK-Rand” both maximize the global performance, we set the local training epoch to 1 and utilize “NTK-Rand” for the following experiments. We denote the vanilla BP-Free method in (Feng et al., 2023) as Vanilla-BAFFLE and our method as NTK-BAFFLE. It is worth mentioning that BP-Free training for FL is currently not competitive with backpropagation-based FL, as the estimated gradients are not as accurate as the true gradients. Therefore, we focus on how much improvement our method can achieve. Table 3 displays the maximum accuracy comparison between NTK-BAFFLE and Vanilla-BAFFLE on the CIFAR-10 dataset under a fixed sparsity level. Note that the number of local training epochs is set to 1. For experiments on the two datasets, NTK-BAFFLE consistently outperforms Vanilla-BAFFLE in terms of classification accuracy. For instance, the maximum accuracy is boosted by , , and on LeNet when setting $K$ to 50, 100, and 200, respectively. Figure 2 shows the learning curves against communication rounds for different values of $K$. It is clear that NTK-BAFFLE performs better than Vanilla-BAFFLE under various settings of $K$. Figure 3 presents the learning curves from the perspective of accuracy versus FLOPs. The first row shows the results on LeNet, and the second row shows those on ResNet-20.
When utilizing ResNet-20, the improvements in the maximum accuracy are , , and , respectively. This is mainly because the gradient estimation over a more complex model is not as stable as that over a relatively simple model. Although the maximum accuracy improvements when using ResNet-20 are not as prominent as those on LeNet, the learning curves against FLOPs in Figure 3 convincingly demonstrate that NTK-BAFFLE achieves higher accuracy when consuming the same FLOPs.
Table 3. Maximum accuracy comparison on CIFAR-10 under different values of $K$ (FedAvg is BP-based and independent of $K$).

Model | Method | K = 50 | K = 100 | K = 200
---|---|---|---|---
LeNet-5 | NTK-BAFFLE | | |
| Vanilla-BAFFLE | | |
| FedAvg | 53.20 | |
ResNet-20 | NTK-BAFFLE | | |
| Vanilla-BAFFLE | | |
| FedAvg | 48.26 | |
Table 4. Maximum accuracy comparison on CIFAR-100 under different values of $K$ (FedAvg is BP-based and independent of $K$).

Model | Method | K = 50 | K = 100 | K = 200
---|---|---|---|---
LeNet-5 | NTK-BAFFLE | | |
| Vanilla-BAFFLE | | |
| FedAvg | 23.52 | |
ResNet-20 | NTK-BAFFLE | | |
| Vanilla-BAFFLE | | |
| FedAvg | 34.19 | |
Table 4 presents the maximum accuracy comparison between NTK-BAFFLE and Vanilla-BAFFLE on the CIFAR-100 dataset. Unlike the sparsity setting for the CIFAR-10 dataset, we uniformly set the sparsity level for the CIFAR-100 dataset. Figure 5 exhibits the learning curves regarding accuracy versus FLOPs. As in Figure 3, the first row in Figure 5 presents the results on LeNet, and the second row shows those on ResNet-20. The maximum accuracy improvements of NTK-BAFFLE over Vanilla-BAFFLE are consistently observed across all settings. For experiments conducted on LeNet, the maximum accuracy is increased by , , and , respectively. For experiments on ResNet-20, the accuracy is improved by , , and when setting $K$ to 50, 100, and 200, respectively. From Figure 3 and Figure 5, we can see that our proposed pruning method consistently enhances the performance of the original BAFFLE method by achieving higher accuracy while consuming fewer FLOPs.
Table 5. Configuration of the real test-bed platform.

Type | Device | Configuration | Number
---|---|---|---
Device | Jetson Nano | 128-core Maxwell GPU | 7
Device | Jetson Xavier AGX | 512-core NVIDIA GPU | 3
Server | Workstation | NVIDIA RTX 3080 GPU | 1
Real Test-bed Results. We conducted real test-bed experiments on CIFAR-10 under Dirichlet heterogeneity. The model we utilized is LeNet-5. The training hyperparameters stay the same, and the number of Monte Carlo steps is set to . The platform configuration is listed in Table 5. It is composed of 10 edge devices and one cloud server. A subset of devices is randomly selected in each training round. Figure 4 shows the edge devices and the comparison of learning curves. We can see that our method consistently beats Vanilla-BAFFLE in classification accuracy.
Memory Analysis. Vanilla-BAFFLE and NTK-BAFFLE offer an effective solution for minimizing memory usage on edge devices, combining the benefits of static and dynamic memory efficiency. Running backpropagation on deep networks with an auto-differentiation framework requires additional static memory, which places a significant strain on memory-constrained devices. Moreover, backpropagation relies on storing intermediate activations, leading to substantial dynamic memory requirements. In contrast, the additional memory needed by the BAFFLE-based methods to store the perturbation in each forward pass and to estimate the local gradient is small. To maximize the benefits of the BP-Free method, we perform layer-by-layer computations, effectively partitioning the forward computation graph into smaller segments. Table 6 shows the peak GPU memory usage (MB) during the forward pass. For LeNet, the peak memory usage of Vanilla-BAFFLE and NTK-BAFFLE is a fraction of that of the BP operation. For ResNet-20, the peak memory of the BAFFLE-related methods is likewise a fraction of that of the BP process. Since NTK-BAFFLE implements sparse training, additional memory is needed to store the binary mask. However, modern frameworks (e.g., (Elsen et al., 2020)) are capable of efficiently storing and processing sparse matrices. In addition, the binary mask can be efficiently stored with a 1-bit datatype. As a result, the extra memory usage associated with NTK-BAFFLE is negligible.
Table 6. Peak GPU memory usage (MB) during the forward pass.

Model | BP | Vanilla-BAFFLE | NTK-BAFFLE
---|---|---|---
LeNet | | |
ResNet-20 | | |
The additional memory required for storing the perturbation in each forward pass and estimating the local gradient for the BAFFLE-based method is small. It is worth noting that modern frameworks (e.g., (Elsen et al., 2020)) can efficiently store and process sparse matrices, making the memory needed to store the binary mask associated with NTK-BAFFLE negligible. We perform layer-by-layer computations, partitioning the forward computation graph into smaller segments, and measure the peak GPU memory usage during the forward pass when the input consists of 32×32 RGB images from the CIFAR-10 dataset. For LeNet and ResNet-20, the peak memory usage of Vanilla-BAFFLE is significantly smaller than that of the backpropagation-based methods. Additionally, the proposed NTK-BAFFLE has fewer parameters: under the pruning rate used above, its parameter-related peak memory usage is only a fraction of that of Vanilla-BAFFLE for both LeNet and ResNet-20.
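The layer-by-layer evaluation mentioned above can be sketched as follows: with gradients disabled, each activation can be released as soon as the next layer consumes it, so peak memory is roughly the largest single activation plus the (sparse) weights. The LeNet-style model below is an illustrative stand-in, not the paper's exact configuration.

```python
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Conv2d(3, 6, 5), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d(6, 16, 5), nn.ReLU(), nn.MaxPool2d(2),
    nn.Flatten(), nn.Linear(16 * 5 * 5, 120), nn.ReLU(), nn.Linear(120, 10),
)

@torch.no_grad()                     # BP-free: no autograd graph, no stored activations
def forward_layerwise(model, x):
    for layer in model:
        x = layer(x)                 # the previous activation is freed once overwritten
    return x

x = torch.randn(8, 3, 32, 32)        # a CIFAR-10-sized batch
print(forward_layerwise(model, x).shape)
```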
Communication Overhead Analysis. In a BP-based FL system, the communication cost, composed of uploading and downloading, is twice the number of model parameters. In Vanilla-BAFFLE and NTK-BAFFLE, each device only uploads a $K$-dimensional vector of scalars to the server for aggregation. Since $K$ is significantly smaller than the number of parameters $d$, both Vanilla-BAFFLE and NTK-BAFFLE greatly reduce the upload cost compared to conventional backpropagation-based FL. For the download, NTK-BAFFLE requires less data transmission than Vanilla-BAFFLE due to its sparsity. Overall, NTK-BAFFLE has the lowest communication cost.
6. Conclusion
Although Federated Learning (FL) is becoming increasingly popular in AIoT design, it greatly suffers from the problem of low inference performance when dealing with memory-constrained AIoT devices. To address this issue, in this paper we proposed a memory-efficient federated foresight pruning method based on the NTK theory. Since our proposed foresight-pruning method can seamlessly integrate into BP-Free training frameworks, it can significantly reduce the memory footprint during FL training. Meanwhile, by combining the proposed foresight pruning with the finite difference method used in BP-Free training, the performance of the vanilla BP-Free method is optimized, thus significantly reducing the overall FLOPs of FL. Note that, due to the introduction of a data-free approach to diminishing the error of using local NTK matrices to approximate the federated NTK matrix, our approach is resilient to various extreme data heterogeneity scenarios. Comprehensive experimental results obtained from simulation- and real test-bed-based platforms with various DNN models demonstrate the effectiveness of our method from the perspectives of memory usage and computation burden.
References
- Abadi et al. (2016) Martín Abadi, Paul Barham, Jianmin Chen, Zhifeng Chen, Andy Davis, Jeffrey Dean, Matthieu Devin, Sanjay Ghemawat, Geoffrey Irving, Michael Isard, Manjunath Kudlur, Josh Levenberg, Rajat Monga, Sherry Moore, Derek G. Murray, Benoit Steiner, Paul Tucker, Vijay Vasudevan, Pete Warden, Martin Wicke, Yuan Yu, and Xiaoqiang Zheng. 2016. TensorFlow: A system for large-scale machine learning. In Proceedings of USENIX Symposium on Operating Systems Design and Implementation (OSDI). 265–283.
- Adamczak et al. (2011) Radosław Adamczak, Alexander E Litvak, Alain Pajor, and Nicole Tomczak-Jaegermann. 2011. Sharp bounds on the rate of convergence of the empirical covariance matrix. Comptes Rendus. Mathématique 349, 3-4 (2011), 195–200.
- Bibikar et al. (2022) Sameer Bibikar, Haris Vikalo, Zhangyang Wang, and Xiaohan Chen. 2022. Federated dynamic sparse training: Computing less, communicating less, yet learning better. In Proceedings of the AAAI Conference on Artificial Intelligence (AAAI). 6080–6088.
- Cao et al. (2021) Jing Cao, Zirui Lian, Weihong Liu, Zongwei Zhu, and Cheng Ji. 2021. HADFL: Heterogeneity-aware Decentralized Federated Learning Framework. In Proceedings of the Design Automation Conference (DAC). 1–6.
- Chandrasekaran et al. (2022) Rishikanth Chandrasekaran, Kazim Ergun, Jihyun Lee, Dhanush Nanjunda, Jaeyoung Kang, and Tajana Rosing. 2022. FHDnn: communication efficient and robust federated learning for AIoT networks. In Proceedings of the Design Automation Conference (DAC). 37–42.
- Cui et al. (2023) Yangguang Cui, Kun Cao, Junlong Zhou, and Tongquan Wei. 2023. Optimizing Training Efficiency and Cost of Hierarchical Federated Learning in Heterogeneous Mobile-Edge Cloud Computing. IEEE Transactions on Computer-Aided Design of Integrated Circuits and Systems (TCAD) 42, 5 (2023), 1518–1531.
- Elsen et al. (2020) Erich Elsen, Marat Dukhan, Trevor Gale, and Karen Simonyan. 2020. Fast sparse convnets. In Proceedings of the Conference on Computer Vision and Pattern Recognition (CVPR). 14629–14638.
- Feng et al. (2023) Haozhe Feng, Tianyu Pang, Chao Du, Wei Chen, Shuicheng Yan, and Min Lin. 2023. Does Federated Learning Really Need Backpropagation? arXiv preprint arXiv:2301.12195 (2023).
- HaoChen et al. (2021) Jeff Z. HaoChen, Colin Wei, Jason D. Lee, and Tengyu Ma. 2021. Shape Matters: Understanding the Implicit Bias of the Noise Covariance. In Proceedings of Annual Conference Computational Learning Theory (COLT). 2315–2357.
- He et al. (2016) Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. 2016. Deep residual learning for image recognition. In Proceedings of the Conference on Computer Vision and Pattern Recognition (CVPR). 770–778.
- Huang et al. (2021) Baihe Huang, Xiaoxiao Li, Zhao Song, and Xin Yang. 2021. Fl-ntk: A neural tangent kernel-based framework for federated learning analysis. In Proceedings of International Conference on Machine Learning (ICML). 4423–4434.
- Huang et al. (2023) Hong Huang, Lan Zhang, Chaoyue Sun, Ruogu. Fang, Xiaoyong Yuan, and Dapeng Wu. 2023. Distributed Pruning Towards Tiny Neural Networks in Federated Learning. In Proceedings of the International Conference on Distributed Computing Systems (ICDCS). 190–201.
- Jacot et al. (2018) Arthur Jacot, Franck Gabriel, and Clément Hongler. 2018. Neural Tangent Kernel: Convergence and Generalization in Neural Networks. In Proceedings of Advances in Neural Information Processing Systems (NeurIPS). 8580–8589.
- Jiang et al. (2023) Yuang Jiang, Shiqiang Wang, Víctor Valls, Bong Jun Ko, Wei-Han Lee, Kin K. Leung, and Leandros Tassiulas. 2023. Model Pruning Enables Efficient Federated Learning on Edge Devices. IEEE Transactions on Neural Networks and Learning Systems (TNNLS) 34, 12 (2023), 10374–10386.
- Khan et al. (2021) Latif U. Khan, Walid Saad, Zhu Han, Ekram Hossain, and Choong Seon Hong. 2021. Federated Learning for Internet of Things: Recent Advances, Taxonomy, and Open Challenges. IEEE Communications Surveys and Tutorials 23, 3 (2021), 1759–1799.
- Krizhevsky (2009) A Krizhevsky. 2009. Learning Multiple Layers of Features from Tiny Images. Master’s thesis, University of Tront (2009).
- LeCun et al. (1998) Yann LeCun, Léon Bottou, Yoshua Bengio, and Patrick Haffner. 1998. Gradient-based learning applied to document recognition. Proc. IEEE 86, 11 (1998), 2278–2324.
- Li et al. (2021) Ang Li, Jingwei Sun, Binghui Wang, Lin Duan, Sicheng Li, Yiran Chen, and Hai Li. 2021. LotteryFL: Empower Edge Intelligence with Personalized and Communication-Efficient Federated Learning. In Proceedings of IEEE/ACM Symposium on Edge Computing (SEC). 68–79.
- Li et al. (2020) Jian Li, Xuanyuan Luo, and Mingda Qiao. 2020. On generalization error bounds of noisy gradient methods for non-convex learning. In Proceedings of International Conference on Learning Representations (ICLR).
- Li et al. (2022) Qinbin Li, Yiqun Diao, Quan Chen, and Bingsheng He. 2022. Federated learning on non-iid data silos: An experimental study. In Proceedings of the International Conference on Data Engineering (ICDE). 965–978.
- Li and Lyu (2023) Yipeng Li and Xinchen Lyu. 2023. Convergence Analysis of Sequential Federated Learning on Heterogeneous Data. In Proceedings of Annual Conference on Neural Information Processing Systems (NeurIPS).
- Li and Chen (2021) Zan Li and Li Chen. 2021. Communication-efficient decentralized zeroth-order method on heterogeneous data. In Proceedings of the International Conference on Wireless Communications and Signal Processing (WCSP). 1–6.
- Liang et al. (2021) Tailin Liang, John Glossner, Lei Wang, Shaobo Shi, and Xiaotong Zhang. 2021. Pruning and quantization for deep neural network acceleration: A survey. Neurocomputing 461 (2021), 370–403.
- Liao et al. (2023) Dongping Liao, Xitong Gao, Yiren Zhao, and Cheng-Zhong Xu. 2023. Adaptive Channel Sparsity for Federated Learning Under System Heterogeneity. In Proceedings of the Conference on Computer Vision and Pattern Recognition (CVPR). 20432–20441.
- Liu et al. (2023) Yixing Liu, Yan Sun, Zhengtao Ding, Li Shen, Bo Liu, and Dacheng Tao. 2023. Enhance local consistency in federated learning: A multi-step inertial momentum approach. arXiv preprint arXiv:2302.05726 (2023).
- McMahan et al. (2017a) Brendan McMahan, Eider Moore, Daniel Ramage, Seth Hampson, and Blaise Aguera y Arcas. 2017a. Communication-efficient learning of deep networks from decentralized data. In Proceedings of Artificial intelligence and statistics (AISTATS). 1273–1282.
- McMahan et al. (2017b) Brendan McMahan, Eider Moore, Daniel Ramage, Seth Hampson, and Blaise Aguera y Arcas. 2017b. Communication-efficient learning of deep networks from decentralized data. In Proceedings of Artificial Intelligence and Statistics (AISTATS). 1273–1282.
- Mendieta et al. (2022) Matias Mendieta, Taojiannan Yang, Pu Wang, Minwoo Lee, Zhengming Ding, and Chen Chen. 2022. Local learning matters: Rethinking data heterogeneity in federated learning. In Proceedings of the Conference on Computer Vision and Pattern Recognition (CVPR). 8397–8406.
- Paszke et al. (2019) Adam Paszke, Sam Gross, Francisco Massa, Adam Lerer, James Bradbury, Gregory Chanan, Trevor Killeen, Zeming Lin, Natalia Gimelshein, Luca Antiga, Alban Desmaison, Andreas Köpf, Edward Yang, Zach DeVito, Martin Raison, Alykhan Tejani, Sasank Chilamkurthy, Benoit Steiner, Lu Fang, Junjie Bai, and Soumith Chintala. 2019. PyTorch: An Imperative Style, High-Performance Deep Learning Library. In Proceedings of Annual Conference on Neural Information Processing Systems (NeurIPS). 8024–8035.
- Qiu et al. (2022) Xinchi Qiu, Javier Fernandez-Marques, Pedro PB Gusmao, Yan Gao, Titouan Parcollet, and Nicholas Donald Lane. 2022. ZeroFL: Efficient On-Device Training for Federated Learning with Local Sparsity. In Proceedings of the International Conference on Learning Representations (ICLR).
- Sarkar et al. (2023) Rishov Sarkar, Hanxue Liang, Zhiwen Fan, Zhangyang Wang, and Cong Hao. 2023. Edge-moe: Memory-efficient multi-task vision transformer architecture with task-level sparsity via mixture-of-experts. In Proceedings of the International Conference on Computer-Aided Design (ICCAD). 1–9.
- Shu et al. (2022) Yao Shu, Shaofeng Cai, Zhongxiang Dai, Beng Chin Ooi, and Bryan Kian Hsiang Low. 2022. NASI: Label- and Data-agnostic Neural Architecture Search at Initialization. In Proceedings of the International Conference on Learning Representations (ICLR).
- Shu et al. (2023) Yao Shu, Zhongxiang Dai, Weicong Sng, Arun Verma, Patrick Jaillet, and Bryan Kian Hsiang Low. 2023. Zeroth-Order Optimization with Trajectory-Informed Derivative Estimation. In Proceedings of the International Conference on Learning Representations (ICLR).
- Shukla et al. (2023) Sanket Shukla, Setareh Rafatirad, Houman Homayoun, and Sai Manoj Pudukottai Dinakarrao. 2023. Federated Learning with Heterogeneous Models for On-device Malware Detection in IoT Networks. In Proceedings of Design, Automation & Test in Europe Conference & Exhibition (DATE). 1–6.
- Stein (1981) Charles M. Stein. 1981. Estimation of the Mean of a Multivariate Normal Distribution. The Annals of Statistics 9, 6 (1981), 1135–1151.
- Wang et al. (2023) Yite Wang, Dawei Li, and Ruoyu Sun. 2023. NTK-SAP: Improving neural network pruning by aligning training dynamics. In Proceedings of the International Conference on Learning Representations (ICLR).
- Wu et al. (2024) Donglei Wu, Weihao Yang, Haoyu Jin, Xiangyu Zou, Wen Xia, and Binxing Fang. 2024. FedComp: A Federated Learning Compression Framework for Resource-Constrained Edge Computing Devices. IEEE Transactions on Computer-Aided Design of Integrated Circuits and Systems (TCAD) 43, 1 (2024), 230–243.
- Xia et al. (2022) Jun Xia, Tian Liu, Zhiwei Ling, Ting Wang, Xin Fu, and Mingsong Chen. 2022. PervasiveFL: Pervasive Federated Learning for Heterogeneous IoT Systems. IEEE Transactions on Computer-Aided Design of Integrated Circuits and Systems (TCAD) 41, 11 (2022), 4100–4111.
- Yang and Sun (2022) Zhao Yang and Qingshuang Sun. 2022. Personalized Heterogeneity-Aware Federated Search Towards Better Accuracy and Energy Efficiency. In Proceedings of the International Conference on Computer-Aided Design (ICCAD). 59:1–59:9.
- Yang and Sun (2023) Zhao Yang and Qingshuang Sun. 2023. Mitigating Heterogeneities in Federated Edge Learning with Resource-independence Aggregation. In Proceedings of the Design, Automation & Test in Europe Conference & Exhibition (DATE). 1–2.
- Yi et al. (2022) Xinlei Yi, Shengjun Zhang, Tao Yang, and Karl H. Johansson. 2022. Zeroth-Order Algorithms for Stochastic Distributed Nonconvex Optimization. Automatica 142 (2022), 110353.
- Zhang and Tao (2021) Jing Zhang and Dacheng Tao. 2021. Empowering Things with Intelligence: A Survey of the Progress, Challenges, and Opportunities in Artificial Intelligence of Things. IEEE Internet of Things Journal 8, 10 (2021), 7789–7817.
- Zhu et al. (2022) Zhuangdi Zhu, Junyuan Hong, Steve Drew, and Jiayu Zhou. 2022. Resilient and Communication Efficient Learning for Heterogeneous Federated Systems. In Proceedings of the International Conference on Machine Learning (ICML). 27504–27526.