Over-the-Air Federated Learning Over MIMO Channels: A Sparse-Coded Multiplexing Approach
Abstract
The communication bottleneck of over-the-air federated learning (OA-FL) lies in uploading the gradients of local learning models. In this paper, we study the reduction of the communication overhead in gradient uploading by using the multiple-input multiple-output (MIMO) technique. We propose a novel sparse-coded multiplexing (SCoM) approach that employs sparse-coding compression and MIMO multiplexing to balance the communication overhead and the learning performance of the FL model. We derive an upper bound on the learning performance loss of the SCoM-based MIMO OA-FL scheme by quantitatively characterizing the gradient aggregation error. Based on the analysis results, we show that the optimal number of multiplexed data streams to minimize the upper bound on the FL learning performance loss is given by the minimum of the numbers of transmit and receive antennas. We then formulate an optimization problem for the design of precoding and post-processing matrices to minimize the gradient aggregation error. To solve this problem, we develop a low-complexity algorithm based on alternating optimization (AO) and the alternating direction method of multipliers (ADMM), which effectively mitigates the impact of the gradient aggregation error. Numerical results demonstrate the superior performance of the proposed SCoM approach.
Index Terms:
Over-the-air federated learning, multiple-input multiple-output access channel, multiplexing, turbo compressed sensing.

I Introduction
Sixth-generation (6G) wireless communications, which are expected to support connection densities of up to millions of wireless devices per square kilometer, will provide a solid foundation to fulfil the vision of ubiquitous intelligence [zhou2020service]. To develop a powerful intelligence model, it is necessary to exploit the diversity of data distributed over a large number of edge devices. A straightforward paradigm is to require edge devices to send local data to a central parameter server (PS) for training the model centrally. However, sending raw data to the PS incurs a huge communication overhead and may compromise user privacy. To overcome these drawbacks, federated learning (FL) is a promising alternative that allows edge devices to collaborate on training a machine learning (ML) model without sharing their local data with others [mcmahan2017communication]. Instead of uploading raw data, in an FL training round, each edge device sends its local gradient to the PS, and the PS aggregates the local gradients, updates the global model, and sends it back to the devices.
Gradient uploading is a critical bottleneck of deploying FL on a wireless network, since it is difficult to support the communication demands of massive numbers of edge devices with limited communication resources (e.g., time, bandwidth, and space). For example, the dimension of recent ML models is extremely large: ResNet-152 has 60 million parameters [he2016deep], while GPT-3 has 175 billion parameters [brown2020language]. Yet, the available channel resources are typically scarce due to bandwidth and latency limitations; e.g., one LTE frame with 5 MHz bandwidth and 10 ms duration can carry only a limited number of complex symbols. Fortunately, in FL, the PS does not need the individual local gradient of each device but only the aggregated gradient, usually the mean of all the local gradients. Based on this property, over-the-air FL (OA-FL) is proposed in [nazer2007computation, zhu2020broadband, amiri2020federated], where edge devices share the same wireless resources to upload their local gradients. Thanks to the analog superposition of electromagnetic waves, the local gradients are aggregated over the air in the process of uploading. Compared with traditional orthogonal multiple access (OMA) approaches [yang2021energy, vu2020cell, vu2022joint], OA-FL does not require more communication resources as the number of devices increases [amiri2020federated], which greatly alleviates the communication bottleneck of gradient uploading. Pioneering studies demonstrate that OA-FL also exhibits strong noise tolerance and significant latency improvement [zhu2020broadband, liu2020reconfigurable].
Due to the appealing features of OA-FL, much research effort has been devoted to the design of efficient OA-FL systems. For example, ref. [lin2017deep] pointed out that the local gradients can be sparsified, compressed, and quantized to reduce the communication overhead without causing substantial losses in accuracy. Ref. [amiri2020federated] proposed an efficient scheme where the local gradients are sparsified and compressed by linear coding before being uploaded, and the aggregated gradient is recovered at the PS via compressed sensing methods. The authors in [ma2022over] used partial-orthogonal compressing matrices and turbo compressed sensing (Turbo-CS) [ma2014turbo], achieving a lower-complexity sparse-coding scheme. It has been shown in [ma2022over] that sparse coding enables the OA-FL system to achieve a lower communication overhead and a faster convergence rate.
The above schemes are all based on single-input single-output (SISO) systems. Multiple-input multiple-output (MIMO) with array signal processing has been widely recognized as a powerful technique to enhance the system capacity. MIMO multiplexing significantly reduces the number of channel uses by transmitting multiple spatial data streams in parallel through antenna arrays [spencer2004zero]. However, MIMO multiplexing causes inter-stream interference, which corrupts the aggregated gradient and degrades the test accuracy of OA-FL. There have been some preliminary attempts to alleviate the impact of inter-stream interference by designing the precoding matrices at the devices and the post-processing matrix at the PS. For instance, ref. [zhu2018mimo] set the precoding matrix to the pseudo-inverse of the channel matrix, and derived a closed-form equalizer as the post-processing matrix using differential geometry. The scheme proposed in [chen2018over] also uses the pseudo-inverse of the channel matrix as the precoding matrix, and computes the post-processing matrix based on receive antenna selection. However, both methods are based on channel inversion, which may significantly amplify noise and hence exacerbate the gradient aggregation error, especially when some devices suffer from deep channel fading [zhu2020broadband, liu2020reconfigurable, zhong2022over].
In this paper, we consider an over-the-air federated learning (OA-FL) network where the local gradients are uploaded over a MIMO multiple access (MAC) channel. The network comprises a central PS with multiple antennas and several edge devices, each with multiple antennas. We propose a novel Sparse-Coded Multiplexing (SCoM) approach that integrates sparse coding with MIMO multiplexing in gradient uploading. By combining the two techniques, SCoM achieves a strikingly better balance between the communication overhead and the learning performance. On one hand, sparse-coding utilizes the sparsity of the gradient to compress the gradient, reducing the communication overhead. On the other hand, MIMO multiplexing reduces the number of channel uses by transmitting multiple streams in parallel, and suppresses the gradient aggregation error through precoding and post-processing matrices. The main contributions are summarized as follows.
• We propose a novel SCoM approach for gradient uploading in MIMO OA-FL. We derive an upper bound on the learning performance loss of the SCoM-based MIMO OA-FL scheme by quantitatively characterizing the gradient aggregation error.
• Based on the analytical result, we formulate a joint precoding and post-processing matrix optimization problem for suppressing the gradient aggregation error. We design a low-complexity algorithm that employs alternating optimization (AO) and the alternating direction method of multipliers (ADMM) to jointly optimize the precoding and post-processing matrices.
• We derive the optimal number of multiplexed data streams for SCoM to balance the communication overhead and the gradient aggregation performance. More specifically, the optimal number of multiplexed data streams to minimize the upper bound of the learning performance loss is given by $\min\{N_{\mathrm{t}}, N_{\mathrm{r}}\}$, where $N_{\mathrm{t}}$ denotes the number of transmit antennas, and $N_{\mathrm{r}}$ denotes the number of receive antennas.
Numerical results demonstrate that our proposed SCoM approach achieves the same test accuracy with much lower communication overhead than other existing approaches, which indicates the superior performance of the SCoM approach.
The rest of this paper is structured as follows. Section II introduces the FL model and the MIMO MAC channel. Section III presents the proposed SCoM approach. The analysis of the learning performance of the SCoM approach is presented in Section IV. In Section V, we present the optimization problem to minimize the gradient aggregation error, and propose a low-complexity algorithm to jointly optimize precoding and post-processing matrices. Section VI presents numerical results to evaluate the SCoM approach and Section VII concludes the paper.
Notation: $\mathbb{R}$ and $\mathbb{C}$ denote the sets of real and complex numbers, respectively. $\mathrm{tr}(\cdot)$, $\mathrm{rank}(\cdot)$, $(\cdot)^{*}$, $(\cdot)^{\mathsf{T}}$, and $(\cdot)^{\mathsf{H}}$ are used to denote the trace, the rank, the conjugate, the transpose, and the conjugate transpose of a matrix, respectively. $[N]$ denotes the set $\{1, \ldots, N\}$. $\mathbf{a}[i:j]$ denotes a sub-vector of $\mathbf{a}$ that contains the entries from index $i$ to index $j$. The expectation operator is denoted by $\mathbb{E}[\cdot]$. We use $\mathbf{I}_{N}$ and $\mathbf{0}_{M \times N}$ to denote the identity matrix of size $N \times N$ and the zero matrix of size $M \times N$, respectively. We use $\|\cdot\|_{2}$ and $\|\cdot\|_{F}$ to denote the $\ell_2$-norm and the Frobenius norm, respectively. $\mathcal{CN}(\mu, \sigma^2)$ denotes the circularly-symmetric complex Gaussian (CSCG) distribution that has a mean of $\mu$ and a covariance of $\sigma^2$.
II System Model
II-A Federated Learning
We start with the description of the FL task deployed on a wireless communication system, where the system consists of one central PS and $M$ edge devices. We assume that the training data of the FL task are all distributedly stored on the edge devices. Let $\mathcal{D}_m$ denote the local dataset of the $m$-th device, and let $|\mathcal{D}_m|$ denote the cardinality of $\mathcal{D}_m$. $|\mathcal{D}| = \sum_{m=1}^{M} |\mathcal{D}_m|$ is the total number of training data samples for the FL task. $\boldsymbol{\theta} \in \mathbb{R}^{D}$ is the model parameter vector, with $D$ being the total length of the model parameter. The target of the FL task is to minimize an empirical loss function based on the local datasets $\{\mathcal{D}_m\}_{m=1}^{M}$, given by
$$F(\boldsymbol{\theta}) = \sum_{m=1}^{M} \frac{|\mathcal{D}_m|}{|\mathcal{D}|}\, F_m(\boldsymbol{\theta}), \qquad (1)$$

where $F_m(\boldsymbol{\theta}) = \frac{1}{|\mathcal{D}_m|} \sum_{i=1}^{|\mathcal{D}_m|} f(\boldsymbol{\theta}; \boldsymbol{\xi}_{m,i})$ is the local loss function of device $m$, and $f(\boldsymbol{\theta}; \boldsymbol{\xi}_{m,i})$ is the sample-wise loss function for the $i$-th training sample $\boldsymbol{\xi}_{m,i}$ in $\mathcal{D}_m$.
To minimize the empirical loss function in (1), the FL training involves $T$ communication rounds between the edge devices and the PS for the model $\boldsymbol{\theta}$ to reach convergence. Specifically, each communication round $t$ consists of four steps (a code sketch follows the list):
• Global model download: The PS sends the global model $\boldsymbol{\theta}^{(t)}$ to each edge device.
• Local gradients computation: The local gradient is computed by device $m$ based on its own data and the global model, given by

$$\mathbf{g}_m^{(t)} = \nabla F_m\big(\boldsymbol{\theta}^{(t)}\big), \quad m \in [M]. \qquad (2)$$
• Local gradients upload: The edge devices send the local gradients to the PS through wireless channels.
• Global model aggregation: The local gradients are aggregated as

$$\mathbf{g}^{(t)} = \sum_{m=1}^{M} \rho_m\, \mathbf{g}_m^{(t)}, \qquad (3)$$

where $\mathbf{g}^{(t)}$ denotes the aggregated gradient, and $\rho_m = |\mathcal{D}_m|/|\mathcal{D}|$. The global model is updated by

$$\boldsymbol{\theta}^{(t+1)} = \boldsymbol{\theta}^{(t)} - \eta\, \mathbf{g}^{(t)}, \qquad (4)$$

where $\eta$ denotes the learning rate.
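To make the four steps concrete, the following minimal Python sketch simulates one error-free communication round implementing (2)-(4); the device representation and all names (fl_round, make_device) are illustrative assumptions rather than notation from this paper.

```python
import numpy as np

def fl_round(theta, devices, eta=0.1):
    """One FL communication round: download, local gradients, aggregation, update.

    theta:   global model parameter vector of length D
    devices: list of (local_grad_fn, n_samples); local_grad_fn returns the
             local gradient in (2) -- an illustrative stand-in for training
             on the local dataset D_m.
    """
    n_total = sum(n for _, n in devices)           # |D|
    g = np.zeros_like(theta)
    for grad_fn, n in devices:                     # each device m
        g_m = grad_fn(theta)                       # local gradient, step (2)
        g += (n / n_total) * g_m                   # weighted aggregation (3)
    return theta - eta * g                         # global update (4)

# Toy usage: two devices fitting a shared least-squares model.
rng = np.random.default_rng(0)
def make_device(n):
    X, y = rng.normal(size=(n, 4)), rng.normal(size=n)
    return (lambda t: 2 * X.T @ (X @ t - y) / n, n)

theta = np.zeros(4)
devices = [make_device(50), make_device(150)]
for _ in range(100):
    theta = fl_round(theta, devices)
```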
II-B MIMO Channel Model
Figure 1: System model of the considered MIMO OA-FL network.
We now introduce the wireless multi-user multiple-input multiple-output (MIMO) channel for the above FL system. As depicted in Fig. 1, the considered MIMO OA-FL system consists of a PS with $N_{\mathrm{r}}$ antennas and $M$ edge devices, each equipped with $N_{\mathrm{t}}$ antennas. As in previous studies on OA-FL [amiri2020federated, sery2021over, cao2022transmission], we make two assumptions: that the download of the global model is through error-free links (in practice, channel noise causes communication errors in the model download; this issue can be addressed by the schemes proposed in [vu2020cell, vu2022joint], and is beyond the scope of this paper), and that the devices upload the local gradients to the PS synchronously (existing techniques in 4G Long Term Evolution, e.g., the timing advance (TA) mechanism, can synchronize the gradient symbols among the edge devices [3gpp.38.213]). We now focus on the process of local gradient uploading. We consider a block-fading channel, where the channel state information (CSI) remains constant during the gradient uploading. Let $\mathbf{H}_m^{(t)} \in \mathbb{C}^{N_{\mathrm{r}} \times N_{\mathrm{t}}}$ denote the CSI matrix between the $m$-th device and the PS at the $t$-th round. We assume that the PS has perfect knowledge of the CSI of the wireless channels between the devices and the PS (approaches to CSI estimation over MIMO MAC channels can be found in [nguyen2013compressive, wen2014channel, vu2020cell, vu2022joint]). Thus, at the PS, the receive signal matrix from the above MIMO multiple access (MAC) channel is given by
$$\mathbf{Y}^{(t)} = \sum_{m=1}^{M} \mathbf{H}_m^{(t)} \mathbf{X}_m^{(t)} + \mathbf{W}^{(t)}, \qquad (5)$$

where $Q^{(t)}$ denotes the number of channel uses at the $t$-th round; $\mathbf{X}_m^{(t)} \in \mathbb{C}^{N_{\mathrm{t}} \times Q^{(t)}}$ denotes the transmit data matrix for the $m$-th device; and $\mathbf{W}^{(t)} \in \mathbb{C}^{N_{\mathrm{r}} \times Q^{(t)}}$ is an additive white Gaussian noise (AWGN) matrix, with the entries independently drawn from $\mathcal{CN}(0, \sigma_n^2)$. Let $\mathbf{x}_{m,q}^{(t)}$ denote the $q$-th column of $\mathbf{X}_m^{(t)}$. Here, we consider the following transmit power constraint:

$$\frac{1}{Q^{(t)}} \sum_{q=1}^{Q^{(t)}} \big\|\mathbf{x}_{m,q}^{(t)}\big\|_2^2 \le P_0, \quad \forall m \in [M], \qquad (6)$$

where $P_0$ is the transmit power budget.
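As a quick sanity check of (5) and (6), the following sketch (a minimal simulation, assuming i.i.d. complex Gaussian fading and illustrative dimensions) generates the receive matrix while scaling each device's transmit matrix to meet the power budget.

```python
import numpy as np

rng = np.random.default_rng(1)
M, Nt, Nr, Q = 4, 8, 8, 64             # devices, tx/rx antennas, channel uses
P0, sigma_n2 = 1.0, 0.01               # power budget and noise variance (assumed)

H = (rng.normal(size=(M, Nr, Nt)) + 1j * rng.normal(size=(M, Nr, Nt))) / np.sqrt(2)
X = rng.normal(size=(M, Nt, Q)) + 1j * rng.normal(size=(M, Nt, Q))

# Scale each device's transmit matrix to satisfy (6): average per-use power <= P0.
for m in range(M):
    p = np.sum(np.abs(X[m])**2) / Q          # (1/Q) * sum_q ||x_{m,q}||^2
    X[m] *= np.sqrt(P0 / p)

W = np.sqrt(sigma_n2 / 2) * (rng.normal(size=(Nr, Q)) + 1j * rng.normal(size=(Nr, Q)))
Y = sum(H[m] @ X[m] for m in range(M)) + W   # receive matrix in (5)
```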
What remains is to map the local gradients to the transmit matrices $\{\mathbf{X}_m^{(t)}\}$ at the edge devices, and to recover the aggregated gradient from the receive signal matrix $\mathbf{Y}^{(t)}$. These issues are discussed in detail in the next section.
III Proposed Sparse-Coded Multiplexing Approach
With the development of deep learning, model sizes keep increasing. To upload the high-dimensional FL local gradients over the aforementioned MIMO channel, the key challenge is the heavy communication burden. Although transmitting the data streams in parallel with MIMO multiplexing efficiently reduces the communication overhead, MIMO multiplexing also causes interference between the data streams, resulting in a loss of FL learning performance.
To address these challenges, we propose a novel transmission scheme, i.e., the Sparse-Coded Multiplexing (SCoM) approach, for the above MIMO OA-FL system, as shown in Fig. 2. The SCoM approach employs two techniques: sparse-coding and MIMO multiplexing. On one hand, sparse-coding utilizes the sparsity of the gradient to compress the gradient, reducing the communication overhead. Meanwhile, sparse-coding leverages the compression matrix to encode the data streams, thereby reducing the correlation between data streams and suppressing inter-stream interference. On the other hand, MIMO multiplexing reduces the number of channel uses by transmitting multi-stream data through antenna arrays, and suppresses inter-stream interference through precoding and post-processing matrices. The details of the SCoM approach are given as follows.
Figure 2: Flow graph of the proposed SCoM approach.
III-A Processing on Devices
To support local gradient uploading in the MIMO OA-FL system, pre-processing operations are first conducted on the edge devices, including gradient sparsification [amiri2020federated] and gradient compression [ma2014turbo].
To be specific, for the $m$-th edge device, the local gradient $\mathbf{g}_m^{(t)} \in \mathbb{R}^{D}$ at the $t$-th round is first complexified to fully utilize the spectral efficiency of complex channels. The complexified gradient is given by

$$\tilde{\mathbf{g}}_m^{(t)} = \mathbf{g}_m^{(t)}[1:D/2] + j\, \mathbf{g}_m^{(t)}[D/2+1:D] \in \mathbb{C}^{D/2}, \qquad (7)$$

where $j$ denotes the imaginary unit ($D$ is assumed even). Then, the accumulated gradient is obtained by

$$\mathbf{g}_{\mathrm{ac},m}^{(t)} = \tilde{\mathbf{g}}_m^{(t)} + \boldsymbol{\Delta}_m^{(t)}, \qquad (8)$$

where $\boldsymbol{\Delta}_m^{(t)}$ denotes the error accumulation vector of the $m$-th device at the $t$-th round, with $\boldsymbol{\Delta}_m^{(0)}$ initialized as $\mathbf{0}$. Having calculated the accumulated gradient, in the sparsification, the $m$-th device obtains the sparsified gradient via

$$\mathbf{g}_{\mathrm{sp},m}^{(t)} = \mathrm{sp}_{\lambda}\big(\mathbf{g}_{\mathrm{ac},m}^{(t)}\big), \qquad (9)$$

where $\lambda \in (0,1]$ denotes the sparsity ratio. The operator $\mathrm{sp}_{\lambda}(\cdot)$ retains the $\lambda D/2$ entries of $\mathbf{g}_{\mathrm{ac},m}^{(t)}$ with the largest magnitudes, and sets the remaining entries to $0$. The error accumulation vector is updated via

$$\boldsymbol{\Delta}_m^{(t+1)} = \mathbf{g}_{\mathrm{ac},m}^{(t)} - \mathbf{g}_{\mathrm{sp},m}^{(t)}. \qquad (10)$$

Then $\mathbf{g}_{\mathrm{sp},m}^{(t)}$ is normalized to $\mathbf{g}_{\mathrm{nm},m}^{(t)}$ by

$$\mathbf{g}_{\mathrm{nm},m}^{(t)} = \frac{1}{\sqrt{v_m^{(t)}}}\, \mathbf{f} \odot \mathbf{g}_{\mathrm{sp},m}^{(t)}, \qquad (11)$$

where $\mathbf{f}$ is a random flipping vector, with each entry of $\mathbf{f}$ being independent and identically distributed (i.i.d.) drawn from $\{1, -1\}$ with equal probability; and $v_m^{(t)} = \frac{2}{D} \sum_{d=1}^{D/2} \big|g_{\mathrm{sp},m,d}^{(t)}\big|^2$ denotes the variance of $\mathbf{g}_{\mathrm{sp},m}^{(t)}$, with $g_{\mathrm{sp},m,d}^{(t)}$ being the $d$-th entry of $\mathbf{g}_{\mathrm{sp},m}^{(t)}$. From (11), the entries of $\mathbf{g}_{\mathrm{nm},m}^{(t)}$ have zero mean and unit variance.
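The device-side chain (7)-(11) can be summarized in a few lines of numpy; the helper name preprocess and the use of an empirical variance over the sparsified entries are illustrative assumptions.

```python
import numpy as np

def preprocess(g, delta, f, lam=0.1):
    """Complexify, error-accumulate, sparsify, and normalize a local gradient.

    g:     real local gradient of length D (D even)  -- eq. (7)
    delta: error accumulation vector of length D/2   -- eqs. (8), (10)
    f:     random +/-1 flipping vector of length D/2 -- eq. (11)
    lam:   sparsity ratio lambda
    """
    D = g.size
    g_c = g[:D // 2] + 1j * g[D // 2:]          # complexification (7)
    g_ac = g_c + delta                          # error accumulation (8)
    k = int(lam * D / 2)                        # number of retained entries
    g_sp = np.zeros_like(g_ac)
    idx = np.argsort(np.abs(g_ac))[-k:]         # top-k magnitudes
    g_sp[idx] = g_ac[idx]                       # sparsification (9)
    delta_next = g_ac - g_sp                    # error update (10)
    v = np.mean(np.abs(g_sp) ** 2)              # empirical variance
    g_nm = f * g_sp / np.sqrt(v)                # flip and normalize (11)
    return g_nm, delta_next

rng = np.random.default_rng(2)
D = 1024
f = rng.choice([1.0, -1.0], size=D // 2)
g_nm, delta = preprocess(rng.normal(size=D), np.zeros(D // 2, complex), f)
```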
Each device then compresses $\mathbf{g}_{\mathrm{nm},m}^{(t)}$ into a low-dimensional vector via a common compressing matrix. Specifically, the compressed gradient is given by

$$\mathbf{g}_{\mathrm{cp},m}^{(t)} = \mathbf{A}\, \mathbf{g}_{\mathrm{nm},m}^{(t)}, \qquad (12)$$

where $\mathbf{A} \in \mathbb{C}^{C \times D/2}$ denotes the common compressing matrix (we assume that both the compression matrix $\mathbf{A}$ and the flipping vector $\mathbf{f}$ remain invariant throughout the FL training process, and are shared among the devices prior to the FL training), $C$ denotes the length of the compressed gradient, and $\kappa = C/(D/2)$ denotes the compression ratio. Inspired by [ma2014turbo, ma2022over], we employ a partial DFT matrix as the compressing matrix, given by $\mathbf{A} = \mathbf{S}\mathbf{F}$. $\mathbf{S} \in \{0,1\}^{C \times D/2}$ is a selection matrix consisting of $C$ randomly selected and reordered rows of the identity matrix $\mathbf{I}_{D/2}$; and $\mathbf{F}$ is a unitary DFT matrix, where the $(m,n)$-th entry of $\mathbf{F}$ is given by $[\mathbf{F}]_{m,n} = \frac{1}{\sqrt{D/2}}\, e^{-2\pi j m n/(D/2)}$. We note that the partial DFT matrix has lower computational complexity and better performance, compared with other types of compressing matrices such as the i.i.d. Gaussian matrix [ma2015performance].
To transmit the data with multiple streams, device $m$ then reshapes the compressed gradient $\mathbf{g}_{\mathrm{cp},m}^{(t)}$ into a matrix $\mathbf{G}_m^{(t)} \in \mathbb{C}^{N_{\mathrm{s}} \times Q^{(t)}}$ as

$$\mathbf{G}_m^{(t)} = \mathrm{reshape}\big(\mathbf{g}_{\mathrm{cp},m}^{(t)},\ N_{\mathrm{s}} \times Q^{(t)}\big), \qquad (13)$$

where $N_{\mathrm{s}}$ denotes the number of data streams, $N_{\mathrm{s}} \le \min\{N_{\mathrm{t}}, N_{\mathrm{r}}\}$. Naturally, the number of channel uses satisfies the following equation:

$$Q^{(t)} = \frac{C}{N_{\mathrm{s}}}. \qquad (14)$$
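Since $\mathbf{A} = \mathbf{S}\mathbf{F}$ is a partial DFT matrix, the product in (12) can be computed with an FFT in $O(D \log D)$ time instead of forming $\mathbf{A}$ explicitly. The sketch below (illustrative sizes; it assumes $N_{\mathrm{s}}$ divides $C$) implements (12)-(14).

```python
import numpy as np

rng = np.random.default_rng(3)
D, kappa, Ns = 1024, 0.5, 4            # model size, compression ratio, streams
N = D // 2                             # complexified length
C = int(kappa * N)                     # compressed length
sel = rng.permutation(N)[:C]           # random row selection/reordering (S)

def compress(g_nm):
    """Partial-DFT compression (12): A = S F, with F the unitary DFT."""
    Fg = np.fft.fft(g_nm, norm="ortho")   # unitary DFT of g_nm
    return Fg[sel]                        # keep C randomly chosen rows

g_nm = rng.normal(size=N) + 1j * rng.normal(size=N)
g_cp = compress(g_nm)
Q = C // Ns                            # channel uses per round, eq. (14)
G = g_cp.reshape(Ns, Q)                # multiplexed stream matrix, eq. (13)
```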
We are now ready to describe the design of the transmit matrix $\mathbf{X}_m^{(t)}$. The transmit matrix is given by $\mathbf{X}_m^{(t)} = \mathbf{B}_m^{(t)} \mathbf{G}_m^{(t)}$, where $\mathbf{B}_m^{(t)} \in \mathbb{C}^{N_{\mathrm{t}} \times N_{\mathrm{s}}}$ denotes the precoding matrix for the $m$-th device. Let $\mathbf{g}_{m,q}^{(t)}$ denote the $q$-th column of $\mathbf{G}_m^{(t)}$. We have $\mathbf{x}_{m,q}^{(t)} = \mathbf{B}_m^{(t)} \mathbf{g}_{m,q}^{(t)}$. From the transmit power constraint in (6), we have

$$\frac{1}{Q^{(t)}} \sum_{q=1}^{Q^{(t)}} \mathbb{E}\big[\big\|\mathbf{B}_m^{(t)} \mathbf{g}_{m,q}^{(t)}\big\|_2^2\big] \overset{(a)}{=} \mathrm{tr}\big(\mathbf{B}_m^{(t)} \big(\mathbf{B}_m^{(t)}\big)^{\mathsf{H}}\big) \le P_0, \qquad (15)$$

where step (a) is due to the normalization in (11).
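A short numerical check of (15): when the stream entries have zero mean and unit variance, the average per-use transmit power equals tr(B B^H), so scaling the precoder enforces the budget. All values below are illustrative.

```python
import numpy as np

rng = np.random.default_rng(4)
Nt, Ns, Q, P0 = 8, 4, 10000, 1.0
B = rng.normal(size=(Nt, Ns)) + 1j * rng.normal(size=(Nt, Ns))
B *= np.sqrt(P0 / np.trace(B @ B.conj().T).real)   # enforce tr(B B^H) = P0

# Unit-variance complex stream entries, as guaranteed by the normalization (11).
G = (rng.normal(size=(Ns, Q)) + 1j * rng.normal(size=(Ns, Q))) / np.sqrt(2)
X = B @ G                                          # transmit matrix
print(np.sum(np.abs(X)**2) / Q)                    # ~ tr(B B^H) = P0
```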
III-B Processing on the PS
We now describe the processing operations on the PS. In the following, we first transform the recovery of the aggregated gradient into a compressed sensing problem, and then adopt the turbo compressed sensing (Turbo-CS) algorithm [ma2014turbo] to solve this problem.
At the PS, the receive signal matrix $\mathbf{Y}^{(t)}$ in (5) is first processed through the post-processing matrix $\mathbf{C}^{(t)} \in \mathbb{C}^{N_{\mathrm{s}} \times N_{\mathrm{r}}}$, and the post-processed matrix is given by

$$\hat{\mathbf{R}}^{(t)} = \mathbf{C}^{(t)} \mathbf{Y}^{(t)} = \sum_{m=1}^{M} \mathbf{C}^{(t)} \mathbf{H}_m^{(t)} \mathbf{B}_m^{(t)} \mathbf{G}_m^{(t)} + \mathbf{C}^{(t)} \mathbf{W}^{(t)}. \qquad (16)$$

$\hat{\mathbf{R}}^{(t)}$ is an approximation to the compressed gradient aggregation matrix $\mathbf{R}^{(t)} = \sum_{m=1}^{M} \mathbf{G}_m^{(t)}$, and the residual error is given by

$$\mathbf{E}^{(t)} = \hat{\mathbf{R}}^{(t)} - \mathbf{R}^{(t)} = \sum_{m=1}^{M} \big(\mathbf{C}^{(t)} \mathbf{H}_m^{(t)} \mathbf{B}_m^{(t)} - \mathbf{I}_{N_{\mathrm{s}}}\big) \mathbf{G}_m^{(t)} + \mathbf{C}^{(t)} \mathbf{W}^{(t)}. \qquad (17)$$

Based on (17), the PS converts the post-processed matrix into an equivalent vector form:

$$\hat{\mathbf{r}}^{(t)} = \mathrm{vec}\big(\hat{\mathbf{R}}^{(t)}\big) \overset{(a)}{=} \mathbf{A}\, \bar{\mathbf{g}}^{(t)} + \mathbf{w}^{(t)}, \qquad (18)$$

where step (a) is from (12) and (13), together with the definitions $\bar{\mathbf{g}}^{(t)} = \sum_{m=1}^{M} \mathbf{g}_{\mathrm{nm},m}^{(t)}$ and $\mathbf{w}^{(t)} = \mathrm{vec}\big(\mathbf{E}^{(t)}\big)$. We assume that the entries of $\bar{\mathbf{g}}^{(t)}$ are independently drawn from a Bernoulli-Gaussian distribution:

$$\bar{g}_d^{(t)} \sim \begin{cases} 0, & \text{with probability } 1 - \pi^{(t)}, \\ \mathcal{CN}\big(0, v_g^{(t)}\big), & \text{with probability } \pi^{(t)}, \end{cases} \qquad (19)$$

where $\bar{g}_d^{(t)}$ is the $d$-th entry of $\bar{\mathbf{g}}^{(t)}$; $\pi^{(t)}$ is the sparsity ratio of $\bar{\mathbf{g}}^{(t)}$, which can be estimated by the Expectation-Maximization algorithm [vila2013expectation]; and $v_g^{(t)}$ is the variance of the nonzero entries in $\bar{\mathbf{g}}^{(t)}$. Moreover, we assume that the entries of $\mathbf{w}^{(t)}$ are i.i.d. drawn from $\mathcal{CN}\big(0, \sigma_w^{(t)}\big)$, where $\sigma_w^{(t)}$ represents the mean square error (MSE) of the compressed gradient aggregation, given by

$$\sigma_w^{(t)} = \frac{1}{N_{\mathrm{s}} Q^{(t)}}\, \mathbb{E}\big[\big\|\mathbf{E}^{(t)}\big\|_F^2\big]. \qquad (20)$$
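To illustrate how the equivalent noise in (17)-(20) arises, the following sketch computes the residual error and estimates sigma_w empirically; the pseudo-inverse precoders and all dimensions are illustrative choices for demonstration, not the optimized design of Section V.

```python
import numpy as np

rng = np.random.default_rng(6)
M, Nt, Nr, Ns, Q = 4, 8, 8, 4, 128
sigma_n2 = 0.01

H = (rng.normal(size=(M, Nr, Nt)) + 1j * rng.normal(size=(M, Nr, Nt))) / np.sqrt(2)
C_mat = (rng.normal(size=(Ns, Nr)) + 1j * rng.normal(size=(Ns, Nr))) / np.sqrt(2 * Nr)
# Illustrative per-device precoders: pseudo-inverse of the effective channel
# C H_m, so that C H_m B_m = I when Ns <= min(Nt, Nr) (power scaling omitted).
B = np.stack([np.linalg.pinv(C_mat @ H[m]) for m in range(M)])

G = (rng.normal(size=(M, Ns, Q)) + 1j * rng.normal(size=(M, Ns, Q))) / np.sqrt(2)
W = np.sqrt(sigma_n2 / 2) * (rng.normal(size=(Nr, Q)) + 1j * rng.normal(size=(Nr, Q)))

# Residual error (17): inter-stream interference term plus filtered noise.
E = sum((C_mat @ H[m] @ B[m] - np.eye(Ns)) @ G[m] for m in range(M)) + C_mat @ W
sigma_w = np.sum(np.abs(E)**2) / (Ns * Q)   # empirical per-entry MSE, eq. (20)
print(sigma_w)
```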
The recovery of $\bar{\mathbf{g}}^{(t)}$ from $\hat{\mathbf{r}}^{(t)}$ in (18) is a compressed sensing problem, where the local gradients are compressed by the partial DFT matrix $\mathbf{A}$ on the edge devices. From [ma2014turbo, ma2022over], the Turbo-CS algorithm is the state of the art for solving compressed sensing problems with partial-orthogonal sensing matrices. Thus, we employ the Turbo-CS algorithm to recover $\bar{\mathbf{g}}^{(t)}$:

$$\hat{\bar{\mathbf{g}}}^{(t)} = \mathrm{TurboCS}\big(\hat{\mathbf{r}}^{(t)}\big). \qquad (21)$$

The details of the Turbo-CS algorithm are presented in the next subsection. The performance of Turbo-CS in recovering the aggregated gradient depends on the variance $\sigma_w^{(t)}$ of the noise $\mathbf{w}^{(t)}$: a smaller $\sigma_w^{(t)}$ leads to better recovery performance and a smaller loss of learning accuracy, as discussed in Section IV.
III-C Turbo-CS Algorithm
As shown in Fig. 3, Turbo-CS iterates between modules A and B until convergence. Module A is a linear minimum mean-squared error (LMMSE) estimator handling the linear constraint in (18), and module B is a minimum mean-squared error (MMSE) denoiser exploiting the sparsity in (19). We next describe the operations of Turbo-CS in detail.
Figure 3: Block diagram of the Turbo-CS algorithm.
The iterative process begins with module A. The inputs of the LMMSE estimator in module A are the a priori mean $\mathbf{g}_A^{\mathrm{pri}(t)}$, the a priori variance $v_A^{\mathrm{pri}(t)}$, and the observed vector $\hat{\mathbf{r}}^{(t)}$.
With given $\mathbf{g}_A^{\mathrm{pri}(t)}$, $v_A^{\mathrm{pri}(t)}$, and $\hat{\mathbf{r}}^{(t)}$, the a posteriori mean and variance of the LMMSE estimator are given by [kay1993fundamentals]

$$\mathbf{g}_A^{\mathrm{post}(t)} = \mathbf{g}_A^{\mathrm{pri}(t)} + \frac{v_A^{\mathrm{pri}(t)}}{v_A^{\mathrm{pri}(t)} + \sigma_w^{(t)}}\, \mathbf{A}^{\mathsf{H}} \big(\hat{\mathbf{r}}^{(t)} - \mathbf{A}\, \mathbf{g}_A^{\mathrm{pri}(t)}\big),$$

$$v_A^{\mathrm{post}(t)} = v_A^{\mathrm{pri}(t)} - \kappa \cdot \frac{\big(v_A^{\mathrm{pri}(t)}\big)^2}{v_A^{\mathrm{pri}(t)} + \sigma_w^{(t)}},$$

where $\kappa$ is the compression ratio defined below (12).
Then the extrinsic mean and variance of the LMMSE estimator are given by

$$\mathbf{g}_A^{\mathrm{ext}(t)} = v_A^{\mathrm{ext}(t)} \left( \frac{\mathbf{g}_A^{\mathrm{post}(t)}}{v_A^{\mathrm{post}(t)}} - \frac{\mathbf{g}_A^{\mathrm{pri}(t)}}{v_A^{\mathrm{pri}(t)}} \right),$$

$$v_A^{\mathrm{ext}(t)} = \left( \frac{1}{v_A^{\mathrm{post}(t)}} - \frac{1}{v_A^{\mathrm{pri}(t)}} \right)^{-1}.$$

The extrinsic messages are used to update the a priori mean and the a priori variance as $\mathbf{g}_B^{\mathrm{pri}(t)} = \mathbf{g}_A^{\mathrm{ext}(t)}$ and $v_B^{\mathrm{pri}(t)} = v_A^{\mathrm{ext}(t)}$. Both are the inputs of the MMSE denoiser in module B.
In module B, we model the a priori mean as an observation of $\bar{\mathbf{g}}^{(t)}$ corrupted by additive noise: $\mathbf{g}_B^{\mathrm{pri}(t)} = \bar{\mathbf{g}}^{(t)} + \mathbf{n}^{(t)}$, where the entries of $\mathbf{n}^{(t)}$ are i.i.d. drawn from $\mathcal{CN}\big(0, v_B^{\mathrm{pri}(t)}\big)$. The a posteriori mean and variance of the MMSE denoiser are given by

$$\mathbf{g}_B^{\mathrm{post}(t)} = \mathbb{E}\big[\bar{\mathbf{g}}^{(t)} \,\big|\, \mathbf{g}_B^{\mathrm{pri}(t)}\big],$$

$$v_B^{\mathrm{post}(t)} = \frac{1}{D/2} \sum_{d=1}^{D/2} \mathrm{var}\big[\bar{g}_d^{(t)} \,\big|\, g_{B,d}^{\mathrm{pri}(t)}\big],$$

where the expectation and variance are taken with respect to the Bernoulli-Gaussian prior in (19), and $g_{B,d}^{\mathrm{pri}(t)}$ is the $d$-th entry of $\mathbf{g}_B^{\mathrm{pri}(t)}$.
Then the extrinsic messages of the MMSE denoiser are given by

$$\mathbf{g}_B^{\mathrm{ext}(t)} = v_B^{\mathrm{ext}(t)} \left( \frac{\mathbf{g}_B^{\mathrm{post}(t)}}{v_B^{\mathrm{post}(t)}} - \frac{\mathbf{g}_B^{\mathrm{pri}(t)}}{v_B^{\mathrm{pri}(t)}} \right),$$

$$v_B^{\mathrm{ext}(t)} = \left( \frac{1}{v_B^{\mathrm{post}(t)}} - \frac{1}{v_B^{\mathrm{pri}(t)}} \right)^{-1}.$$

The extrinsic messages are used to update the a priori mean and the a priori variance as $\mathbf{g}_A^{\mathrm{pri}(t)} = \mathbf{g}_B^{\mathrm{ext}(t)}$ and $v_A^{\mathrm{pri}(t)} = v_B^{\mathrm{ext}(t)}$. Both are the inputs of the LMMSE estimator in module A.
At the end of the iterative process, the final estimate is based on the a posteriori output of module B, i.e., $\hat{\bar{\mathbf{g}}}^{(t)} = \mathbf{g}_B^{\mathrm{post}(t)}$.
Then, based on the output of Turbo-CS, the aggregated gradient is given by

$$\hat{\mathbf{g}}^{(t)} = \alpha^{(t)} \Big[\Re\big(\mathbf{f} \odot \hat{\bar{\mathbf{g}}}^{(t)}\big)^{\mathsf{T}},\ \Im\big(\mathbf{f} \odot \hat{\bar{\mathbf{g}}}^{(t)}\big)^{\mathsf{T}}\Big]^{\mathsf{T}}, \qquad (22)$$

where multiplying by $\mathbf{f}$ again reverses the random flipping in (11), stacking the real and imaginary parts reverses the complexification in (7), and $\alpha^{(t)}$ is a descaling factor that reverses the normalization in (11). The PS then updates the model by

$$\boldsymbol{\theta}^{(t+1)} = \boldsymbol{\theta}^{(t)} - \eta\, \hat{\mathbf{g}}^{(t)}. \qquad (23)$$
To sum up, at each round $t$, the receive matrix $\mathbf{Y}^{(t)}$ is first post-processed and reshaped into the vector form in (18). Given the observed vector $\hat{\mathbf{r}}^{(t)}$ and the initial values (e.g., $\mathbf{g}_A^{\mathrm{pri}(t)} = \mathbf{0}$ and $v_A^{\mathrm{pri}(t)} = \pi^{(t)} v_g^{(t)}$), Turbo-CS iteratively calculates the module-A and module-B updates above until a certain termination criterion is met. Finally, the output is descaled as in (22) for the model update in (23).
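The following self-contained numpy sketch implements the Turbo-CS iteration above for a partial-DFT sensing matrix, with a Bernoulli-Gaussian MMSE denoiser as module B; the parameter values, initialization, and fixed iteration count are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(5)
N, kappa = 1024, 0.5                 # signal length D/2 and compression ratio
pi_, vg, sigw = 0.1, 10.0, 1e-3      # BG sparsity, nonzero variance, noise var
C = int(kappa * N)
sel = rng.permutation(N)[:C]

A_op = lambda x: np.fft.fft(x, norm="ortho")[sel]     # A = S F (partial DFT)
def AH_op(y):                                         # A^H: zero-fill + IDFT
    z = np.zeros(N, complex); z[sel] = y
    return np.fft.ifft(z, norm="ortho")

# Ground-truth Bernoulli-Gaussian signal (19) and noisy measurements (18).
supp = rng.random(N) < pi_
x = supp * (rng.normal(size=N) + 1j * rng.normal(size=N)) * np.sqrt(vg / 2)
r = A_op(x) + np.sqrt(sigw / 2) * (rng.normal(size=C) + 1j * rng.normal(size=C))

def bg_denoise(u, v):
    """MMSE denoiser for the Bernoulli-Gaussian prior (module B)."""
    cn = lambda s: np.exp(-np.abs(u)**2 / s) / (np.pi * s)   # CN(u; 0, s)
    p = pi_ * cn(vg + v) / ((1 - pi_) * cn(v) + pi_ * cn(vg + v))
    m1 = vg / (vg + v) * u                   # posterior mean if entry nonzero
    c1 = vg * v / (vg + v)                   # posterior var if entry nonzero
    mean = p * m1
    var = np.mean(p * (c1 + np.abs(m1)**2) - np.abs(mean)**2)
    return mean, var

g_pri, v_pri = np.zeros(N, complex), pi_ * vg       # module-A initialization
for _ in range(30):
    # Module A: LMMSE estimator and its extrinsic message.
    g_post = g_pri + v_pri / (v_pri + sigw) * AH_op(r - A_op(g_pri))
    v_post = v_pri - kappa * v_pri**2 / (v_pri + sigw)
    v_B = 1 / (1 / v_post - 1 / v_pri)
    g_B = v_B * (g_post / v_post - g_pri / v_pri)
    # Module B: MMSE denoiser and its extrinsic message.
    g_post, v_post = bg_denoise(g_B, v_B)
    v_pri = 1 / (1 / v_post - 1 / v_B)
    g_pri = v_pri * (g_post / v_post - g_B / v_B)

print(np.linalg.norm(g_post - x)**2 / np.linalg.norm(x)**2)  # recovery NMSE
```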
III-D Overall Scheme
The proposed SCoM approach to local gradient uploading is summarized in Algorithm 1 below, where lines 3-5 and 11-17 are executed at the PS, and lines 6-10 are executed at the devices.
Algorithm 1: Proposed SCoM-Based MIMO OA-FL Scheme

Require: $T$, $\eta$, $\lambda$, $\kappa$, $N_{\mathrm{s}}$, $P_0$, $\mathbf{A}$, and $\mathbf{f}$.
1:  Initialization: $\boldsymbol{\Delta}_m^{(0)} = \mathbf{0}$, the global model $\boldsymbol{\theta}^{(0)}$.
2:  for $t = 0, 1, \ldots, T-1$ do
3:    PS does:
4:    Estimate the CSI matrices $\{\mathbf{H}_m^{(t)}\}$, and optimize $\{\mathbf{B}_m^{(t)}\}$ and $\mathbf{C}^{(t)}$;
5:    Send $\boldsymbol{\theta}^{(t)}$ and $\mathbf{B}_m^{(t)}$ to the edge devices through error-free links;
6:    Each device $m$ does in parallel:
7:    Compute $\mathbf{g}_m^{(t)}$ based on $\mathcal{D}_m$ and $\boldsymbol{\theta}^{(t)}$ via (2);
8:    Compute $\mathbf{G}_m^{(t)}$ via (7)-(13);
9:    Update $\boldsymbol{\Delta}_m^{(t+1)}$ via (10);
10:   Send $\mathbf{G}_m^{(t)}$ to the PS with the precoding matrix $\mathbf{B}_m^{(t)}$ over the MIMO channel in (5);
11:   PS does:
12:   Reshape $\mathbf{Y}^{(t)}$ into $\hat{\mathbf{r}}^{(t)}$ via (18);