One-shot Collaborative Data Distillation

William Holland, Chandra Thapa, Sarah Ali Siddiqui, Wei Shao, and Seyit Camtepe
CSIRO Data61, Sydney, Australia
Corresponding Author. Email: [email protected].
Abstract

Large machine-learning training datasets can be distilled into small collections of informative synthetic data samples. These synthetic sets support efficient model learning and reduce the communication cost of data sharing. Thus, high-fidelity distilled data can support the efficient deployment of machine learning applications in distributed network environments. A naive way to construct a synthetic set in a distributed environment is to allow each client to perform local data distillation and to merge local distillations at a central server. However, the quality of the resulting set is impaired by heterogeneity in the distributions of the local data held by clients. To overcome this challenge, we introduce the first collaborative data distillation technique, called CollabDM, which captures the global distribution of the data and requires only a single round of communication between client and server. Our method outperforms the state-of-the-art one-shot learning method on skewed data in distributed learning environments. We also show the promising practical benefits of our method when applied to attack detection in 5G networks.

1 Introduction

Machine learning models trained on massive datasets are susceptible to high training times, slow research iteration, and poor eco-sustainability. To overcome these problems, and to increase the scalability of machine learning applications, large datasets can be distilled into small collections of informative synthetic data samples [37]. If the distilled data effectively captures the original dataset, machine learning models can be trained efficiently on the synthetic data with accuracy comparable to models trained on the original data.

In addition to computational efficiency, data distillation has the benefit of both reducing the communication cost of data sharing and, as only synthetic samples are shared, providing privacy to data owners [3]. These benefits are notable in applications such as 5G mobile networks, where massive volumes of data are generated from diverse sets of sources. In this setting, distilled data can be shared, in a safe and communication efficient manner, across heterogeneous domains and utilized for robust model training.

However, distributed learning is impaired by heterogeneity between the local distributions of the data held by clients [16, 17]. Further, sharing locally distilled datasets for global model training can amplify the adverse impacts of data heterogeneity [10]. This motivates the creation of data distillation techniques that synthesize a global synthetic dataset through the collaboration of clients. Collaboration allows diverse data sources to participate in a global distillation process without sharing local data. The resulting global synthetic set can be shared across parties and utilized for applications such as neural network architecture search [42], (global) model training, and continual learning [26].

Standard data distillation techniques [2, 37, 44] operate in a centralized and static setting in which the whole dataset is available in a single location. Adapting these methods efficiently to a distributed learning environment is non-trivial and remains an open challenge. For example, Pi et al. propose a federated learning framework that performs global data distillation [25]. The synthetic data is optimized such that the parameter trajectory of the model trained on it matches the parameter trajectory of the model trained through standard federated learning. However, this procedure operates on a single model initialization, whereas distillation algorithms that match training trajectories typically generate trajectories across a large number of random model initializations [2]. Thus, the authors only implement a fraction of the full algorithm.

The challenge with adapting data distillation techniques to collaborative settings is that most involve many iterations of model training and would therefore require many rounds of communication and parameter sharing. Consequently, they incur large communication overheads in a distributed setting and negate many of the benefits they promise. To overcome this limitation, we present a collaborative data distillation algorithm based on distribution matching [36, 43]. In distribution matching, no model training is required. Instead, the synthetic data are optimized to match the distribution of real data in a family of random lower-dimensional embedding spaces. As the embedding spaces are randomly initialized, they can be distributed to clients as random seeds, mitigating the communication burden of transmitting and training model parameters.

Further, in distribution matching, the mean of the embeddings on real data is required to compute the loss functions for synthetic data optimization. Thus, with the random seeds for these embeddings initialized a priori, the clients can compute all means (one for each iteration of synthetic data training) in a single batch and transmit them to the server in a single round of communication. Consequently, synthetic data is distilled collaboratively with a small communication overhead.

Prior work has utilized data distillation for communication reduction in federated learning scenarios [5, 9, 18, 31, 39, 47]. Here, clients distill their data independently of each other and upload the distilled data to the server. The global model is then updated with the information distilled in the synthetic data. For large models, the distilled data is often more compact than the model parameters. Therefore, synthetic data can offer lower per-round communication and faster model convergence than standard approaches that aggregate local model parameters into a global model. These methods provide a heuristic for reducing the communication cost of federated learning. However, they do not optimize the synthetic data over the global data distribution. The significance of a global synthetic dataset is that it provides distributed settings with efficient methods for additional applications, such as neural architecture search and continual learning.

Motivating Application

To motivate the field of collaborative data distillation, we provide a target application for the technique: 5G mobile networks. Next-generation mobile networks are built on edge networks, where network resources are placed close to end-users and are often spread across multiple tenants and domains. This creates a landscape where large volumes of data are generated at a growing number of locations, often within specified trust boundaries. The generated data can be used for a burgeoning range of 5G machine learning applications [11]. The data can be heterogeneous, motivating the need for globally trained models that generalize well. However, the generated data can be both large in size and private, preventing its transmission to a central point that orchestrates the machine learning applications. This challenge can be overcome with a compact global synthetic set, which can be easily shared among edge networks to support the relevant machine learning applications.

In our studies, along with standard benchmark datasets, we have considered attack detection in network traffic. In this setting, traffic at different points in the network can be monitored by a device with a general CPU or GPU. The device maintains a neural network to classify incoming traffic as benign or anomalous. If multiple points in the network contribute to training a global synthetic set, robust model training can be performed to capture the global dynamics of data generated across the network.

Contributions
  • We provide the first distributed data distillation algorithm, CollabDM, that captures the dynamics of the global data distribution in a single round of communication.

  • The algorithm is tested against benchmark datasets. Results indicate that our technique outperforms the state-of-the-art one-shot learning method DENSE [41] on heterogeneous data partitions. The global synthetic sets generated by CollabDM are remarkably robust to the underlying data distribution, with only very small reductions in performance when the level of skew in the data distribution increases.

  • The algorithm is tested in a target distributed learning environment: 5G networks. This represents a new application for data distillation techniques. Results demonstrate that data distillation provides a promising direction for supporting machine learning applications in 5G networks.

2 Related Work

Data Distillation

Data distillation methods aim to synthesize small and high-fidelity data summaries, which distill the most important information from a target dataset [27]. The summaries can serve as effective drop-in replacements for the original dataset in machine-learning applications. Data distillation methods can be categorized into three types: meta-learning, parameter matching, and distribution matching.

Meta-learning techniques [24, 37] aim to minimize the expected loss incurred on real data for models trained on the synthetic set. This involves a bi-level optimization, where an inner loop trains a model with respect to the synthetic dataset, and the outer loop updates the synthetic set (considered as a hyperparameter) based on the loss that the trained model incurs on real data. Parameter matching techniques allow the synthetic data to imitate the influence of the target dataset on model training. For example, synthetic data can be distilled to match the training gradients [44] or parameter trajectories [2] observed during training on real data.

In distribution matching [36, 43], the synthetic data are optimized to match the distribution of real data in a family of lower-dimensional embedding spaces. In contrast to prior approaches, distribution matching involves a single-level optimization. It is, therefore, considered less computationally intensive and more scalable.

Virtual Learning

Federated learning involves building a local surrogate function to approximate the local training objective. By sending local surrogate functions to the server, the server can build a global surrogate around the current solution. The aim is to build local surrogates that are informative and succinct. Local synthetic data can be constructed to capture information about the local update at the client and build local surrogate functions [39]. For example, locally distilled data can be used to approximate gradient updates [5, 18], minimize the difference between models trained on real and synthetic data [9] or communicate local approximations in the loss landscape [35].

Huang et al. propose an iterative method that utilizes local and global distillation [10]. They iteratively refine local and global synthetic data. The global virtual data is used as an anchor on the server side for model training. Similarly, Liu et al. attempt to distill synthetic data with global dynamics [20]. The distilled data is optimized to mimic the parameter trajectories of the global model under the standard FedAvg [21] algorithm. The authors observe that the update dynamics of the global model contain knowledge about the global data distribution. This knowledge is transferred to a synthetic dataset at the server. However, data distillation with trajectory matching typically requires training a large number of randomly initialized models. Therefore, the actual data distillation algorithm is only partially implemented.

Note that the above methods are all multi-shot; that is, they require multiple rounds of communication.

One Shot Federated Learning

One-shot federated learning involves completing a federated learning objective in a single round of communication. Single-round communication is in high demand for practical applications [32] and has advantages such as reducing the risk of being attacked [41]. Most one-shot federated learning methods are based on knowledge distillation [8] or data distillation [37].

Methods based on knowledge distillation utilize the local models as teachers for the global model [6, 41]. Guha et al. propose a method where each client trains a model to completion and ensemble methods are used to train a global model [6]. This approach involves a public dataset. Zhang et al. propose a two-stage method that trains a global model through a data generation stage and a model distillation stage [41]. The first stage uses ensemble models obtained from clients to train a global data generator. The knowledge from ensemble models is distilled in the data generator and used to train a global model.

For methods based on data distillation, clients distill synthetic data locally (and independently of each other) and send the summaries to the server [31, 47], constituting a single round of communication. The server then trains the model on aggregated synthetic data. Our method adopts this template. However, our approach differs in that clients send additional computations, allowing the server to refine the synthetic data according to a global loss function. Thus, our approach is able to better combat data heterogeneity observed across the clients.

3 Preliminaries

The first part of this section covers key notation and the problem definition. The second part introduces the main data distillation frameworks [15]: meta-learning, parameter matching, and distribution matching. The meta-learning and parameter-matching frameworks will help demonstrate the challenges of distributed data distillation. Our approach is based on the distribution matching framework, which supports a collaborative algorithm that overcomes these challenges.

3.1 Notation

Let $\mathcal{D} \overset{\Delta}{=} \{(x_i, y_i)\}_{i=1}^{|\mathcal{D}|}$ be the dataset that needs to be distilled, where $x_i \in \mathcal{X}$ denotes the input features and $y_i \in \mathcal{Y}$ is the label for $x_i$. Throughout, the notation $d \sim \mathcal{D}$ refers to a data point $d$ selected uniformly at random from the set $\mathcal{D}$. Given a data budget $n \in \mathbb{Z}^{+}$, a data distillation technique aims to synthesize a high-fidelity summary $\mathcal{S} \overset{\Delta}{=} \{(\tilde{x}_i, \tilde{y}_i)\}_{i=1}^{n}$ such that $n \ll |\mathcal{D}|$. The small distilled dataset should achieve a comparable generalization performance to the large original dataset.

For a given learning algorithm $\Phi_{\theta} : \mathcal{X} \rightarrow \mathcal{Y}$, with parameterization $\theta$, the empirical risk $\mathcal{R}$ on parameterization $\theta$ and input data $\mathcal{D}$ is defined as

\[ \mathcal{R}(\mathcal{D}; \theta) = \sum_{i=1}^{|\mathcal{D}|} l(\Phi_{\theta}(x_i), y_i), \]

where $l$ is a loss function. A training algorithm for $\Phi$ attempts to find $\theta$ that minimizes $\mathcal{R}$.

3.2 Problem Definition

Definition 1 (Data Distillation ([27]))

Given a learning algorithm $\Phi$, let $\theta^{\mathcal{D}}$ and $\theta^{\mathcal{S}}$ represent the optimal sets of parameters for $\Phi$ on $\mathcal{D}$ and $\mathcal{S}$, respectively. Data distillation is defined as the optimization of the following:

\[ \arg\min_{\mathcal{S}} \left( \sup \left\{ \left| l(\Phi_{\theta^{\mathcal{D}}}(x), y) - l(\Phi_{\theta^{\mathcal{S}}}(x), y) \right| \right\}_{\substack{x \sim \mathcal{X} \\ y \sim \mathcal{Y}}} \right). \tag{1} \]

Thus, the objective is to extract the knowledge from $\mathcal{D}$ and transfer it to the synthetic set $\mathcal{S}$, such that the model trained on $\mathcal{S}$ achieves comparable generalization to the model trained on $\mathcal{D}$.

Problem 1 (Collaborative Data Distillation)

The dataset $\mathcal{D}$ is split over $K$ disjoint clients that can communicate with a central server. Let $\mathcal{D}_i$ be the data stored at client $i$. For $\mathcal{D} = \cup_{i=1}^{K} \mathcal{D}_i$, collaborative data distillation aims to solve the objective of Equation (1) under the conditions that

  1. The server cannot observe $\mathcal{D}$.

  2. Client $i$ cannot observe $\mathcal{D} \setminus \mathcal{D}_i$.

With compact synthetic sets, collaborative data distillation aims to reduce the communication overhead in distributed machine-learning applications at a minimal cost in terms of fidelity.

3.3 Data Distillation with Meta-Learning

Meta-learning based methods [1, 33, 37, 46] treat $\mathcal{S}$ as a hyperparameter that is updated by a meta (outer) algorithm, while a base (inner) algorithm solves a conventional learning problem with respect to the synthetic dataset. The objective can thus be formulated as a bi-level optimization:

\[ \mathcal{S}^{*} = \arg\min_{\mathcal{S}} \mathcal{R}(\mathcal{D}; \theta^{\mathcal{S}}) \tag{2} \]

subject to

\[ \theta^{\mathcal{S}} = \arg\min_{\theta} \mathcal{R}(\mathcal{S}; \theta). \]

The inner loop, which optimizes parameters on the synthetic data, can be realized through gradient descent or kernel regression. The objective function can be defined as the meta-loss $\mathcal{L}(\mathcal{S}) = \mathcal{R}(\mathcal{D}; \theta^{\mathcal{S}})$. Consequently, the synthetic data can be updated as $\mathcal{S} = \mathcal{S} - \alpha \nabla_{\mathcal{S}} \mathcal{L}(\mathcal{S})$ for learning rate $\alpha$.
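To make the bi-level structure concrete, the following is a minimal sketch on a toy problem, not the method of any cited work. It assumes a one-parameter linear model, a squared loss, and a single synthetic point whose input is fixed at 1, so that the inner problem has a closed-form solution. All names (`real_data`, `inner_solve`, `meta_loss`, `meta_grad`, `distill`) and all numeric values are illustrative assumptions.

```python
# Toy bi-level distillation: model Phi_theta(x) = theta * x, squared loss,
# and one synthetic point (x_tilde = 1, y_tilde). Only y_tilde is learned.
# Every name and value below is an illustrative assumption.

real_data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)]  # D, generated by y = 2x


def inner_solve(y_tilde):
    """Inner problem: argmin_theta (theta * 1 - y_tilde)^2, which is y_tilde."""
    return y_tilde


def meta_loss(y_tilde):
    """Outer objective L(S) = R(D; theta^S): the risk on real data."""
    theta = inner_solve(y_tilde)
    return sum((theta * x - y) ** 2 for x, y in real_data)


def meta_grad(y_tilde):
    """dL/dy_tilde; here d(theta^S)/d(y_tilde) = 1, so the chain rule is trivial."""
    theta = inner_solve(y_tilde)
    return sum(2.0 * x * (theta * x - y) for x, y in real_data)


def distill(steps=200, alpha=0.01, y_tilde=0.0):
    """Outer loop: S <- S - alpha * grad_S L(S)."""
    for _ in range(steps):
        y_tilde -= alpha * meta_grad(y_tilde)
    return y_tilde


y_star = distill()
```

In this toy case the single synthetic label converges to the least-squares slope of the real data, so a model fitted to the synthetic point alone reproduces the model fitted to all of $\mathcal{D}$. With a neural network, the inner solve would be several gradient steps and the outer gradient would be obtained by differentiating through them.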

3.4 Data Distillation with Parameter Matching

Parameter matching [2, 44] aims to align the byproducts of model training on real and synthetic data. The synthetic data learns to mimic the influence of real data on model training. The objective function of parameter matching can be summarized as follows:

\[ \mathcal{L}(\mathcal{S}) = \sum_{k=0}^{T} Q\left(\phi(\mathcal{D}, \theta^{(k)}), \phi(\mathcal{S}, \theta^{(k)})\right) \]

subject to

\[ \theta^{(k+1)} = \theta^{(k)} - \eta \nabla_{\theta^{(k)}} \mathcal{R}(\mathcal{S}; \theta^{(k)}), \]

where $Q$ is a distance function, $\eta$ is the learning rate, and $\phi$ maps a dataset to informative spaces such as gradient, parameter, and feature spaces. For example, the map $\phi(\mathcal{S}, \theta) = \nabla_{\theta} \mathcal{R}(\mathcal{S}; \theta)$ [44] equates to matching the gradients, with respect to $\theta$, of the observed empirical risk, and $\mathcal{S}$ is optimized to mimic the gradients observed on $\mathcal{D}$ during model training. In practice, the full dataset might be replaced with a batch to save memory and facilitate faster convergence.
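As an illustration of the gradient-matching instance of this objective, the sketch below evaluates $\mathcal{L}(\mathcal{S})$ for a one-parameter linear model with squared loss, taking $Q$ to be the squared difference. The toy data, the helper names (`risk_grad`, `matching_loss`) and the candidate synthetic sets (`poor`, `good`) are assumptions made for this example only.

```python
import math

real_data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)]  # D (toy data, an assumption)


def risk_grad(data, theta):
    """phi(data, theta): gradient w.r.t. theta of the summed squared-error risk."""
    return sum(2.0 * x * (theta * x - y) for x, y in data)


def matching_loss(synthetic, theta0=0.0, eta=0.01, T=10):
    """Sum over k = 0..T of Q(phi(D, theta_k), phi(S, theta_k)).

    Q is the squared difference; theta_k follows gradient descent on the
    synthetic risk, as required by the constraint of the objective.
    """
    theta, total = theta0, 0.0
    for _ in range(T + 1):
        g_real = risk_grad(real_data, theta)
        g_syn = risk_grad(synthetic, theta)
        total += (g_real - g_syn) ** 2
        theta -= eta * g_syn
    return total


poor = [(1.0, 0.0)]                                   # arbitrary synthetic point
good = [(math.sqrt(14.0), 2.0 * math.sqrt(14.0))]     # reproduces the real gradient
```

On this toy problem the real gradient is $28\theta - 56$, and the single synthetic point in `good` yields the same expression, so its matching loss vanishes up to rounding, whereas `poor` incurs a large loss. A distillation method would minimize `matching_loss` over the synthetic set rather than evaluate hand-picked candidates.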

3.5 Data Distillation with Distribution Matching

The underlying bi-level optimization of prior approaches is often expensive in terms of computation and memory. To mitigate these costs, distribution matching [36, 43] aims to solve a correlated proxy task that restricts optimization to a single level and improves scalability. Instead of matching the quality of the models generated by $\mathcal{D}$ and $\mathcal{S}$, distribution matching attempts to match the underlying distributions of $\mathcal{D}$ and $\mathcal{S}$. The assumption here is that datasets with the same distribution will lead to similarly trained models.

Distribution matching uses a collection of random parametric encoders to embed data into low-dimensional latent spaces. Distance metrics can then be used to compute the distribution mismatch between $\mathcal{D}$ and $\mathcal{S}$. Formally, given a set of encoders $\mathcal{E} = \{\psi_i : \mathcal{X} \rightarrow \mathcal{X}_i\}$, the optimization objective, under maximum mean discrepancy, is:

\[ \arg\min_{\mathcal{S}} \mathbb{E}_{\substack{\psi \sim \mathcal{E} \\ y \sim \mathcal{Y}}} \left[ \left\lVert \mathbb{E}_{x \sim \mathcal{D}^{y}}[\psi(x)] - \mathbb{E}_{x \sim \mathcal{S}^{y}}[\psi(x)] \right\rVert^{2} \right], \]

where $\mathcal{D}^{y} = \{x \mid (x, y) \in \mathcal{D}\}$. This objective, for a given $\psi \in \mathcal{E}$, can be approximated with the following empirical loss:

\[ \mathcal{L} = \sum_{y \in \mathcal{Y}} \left\lVert \frac{1}{|D^{y}|} \sum_{x \in D^{y}} \psi_{\theta}(x) - \frac{1}{|S^{y}|} \sum_{x \in S^{y}} \psi_{\theta}(x) \right\rVert^{2}, \tag{3} \]

where $D^{y} \subset \mathcal{D}^{y}$ denotes a batch of real data and $S^{y} \subset \mathcal{S}^{y}$ denotes a batch of synthetic data. Typically $\mathcal{E}$ is generated randomly and each $\psi \in \mathcal{E}$ has the same network architecture.
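The empirical loss of Equation (3) can be sketched in a few lines. The example below is illustrative only: it assumes that each encoder is a seeded random linear map followed by a ReLU, which is far simpler than the networks used in practice, and the helper names (`make_encoder`, `mean_embedding`, `dm_loss`) and toy data are hypothetical.

```python
import random


def make_encoder(seed, in_dim, out_dim):
    """Random encoder psi: seeded linear map followed by ReLU (an assumption)."""
    rng = random.Random(seed)
    W = [[rng.gauss(0.0, 1.0) for _ in range(in_dim)] for _ in range(out_dim)]
    return lambda x: [max(0.0, sum(w * v for w, v in zip(row, x))) for row in W]


def mean_embedding(psi, batch):
    """(1 / |batch|) * sum of psi(x) over the batch."""
    embs = [psi(x) for x in batch]
    return [sum(col) / len(embs) for col in zip(*embs)]


def dm_loss(psi, real_by_class, syn_by_class):
    """Equation (3): per-class squared distance between mean embeddings, summed."""
    loss = 0.0
    for y, real_batch in real_by_class.items():
        mu_real = mean_embedding(psi, real_batch)
        mu_syn = mean_embedding(psi, syn_by_class[y])
        loss += sum((a - b) ** 2 for a, b in zip(mu_real, mu_syn))
    return loss


psi = make_encoder(seed=0, in_dim=2, out_dim=16)
real = {0: [[0.0, 1.0], [0.2, 0.8]], 1: [[1.0, 0.0], [0.8, 0.2]]}
matched = {0: [[0.0, 1.0], [0.2, 0.8]], 1: [[1.0, 0.0], [0.8, 0.2]]}
swapped = {0: real[1], 1: real[0]}   # class labels exchanged
```

Here `dm_loss(psi, real, matched)` is zero because the synthetic batches coincide with the real ones, while `dm_loss(psi, real, swapped)` is positive because the class-conditional means differ. No model is trained at any point; only forward passes through a randomly initialized encoder are needed.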

4 Collaborative Data Distillation

In the collaborative setting, the dataset $\mathcal{D} = \cup_{i=1}^{K} \mathcal{D}_i$ is split over $K$ disjoint clients that can communicate with a central server. The goal of collaborative data distillation is to produce a synthetic dataset $\mathcal{S}$ at the server that achieves similar generalization performance to $\cup_{i=1}^{K} \mathcal{D}_i$. As a starting point, a straightforward solution would be to get each client $i$ to independently distill their own synthetic dataset $\mathcal{S}_i$ (using any data distillation method) and set $\mathcal{S} = \cup_{i=1}^{K} \mathcal{S}_i$. However, under the influence of data heterogeneity, the locally distilled data could be biased and, consequently, produce a distillation that does not capture the global data distribution.

Alternatively, the global dynamics of the data could be captured by adapting a full data distillation algorithm to a federated learning setting. In the following subsection, we provide a framework for the adaptation of the meta-learning and parameter-matching algorithms. This framework will act as a strawman baseline for our collaborative distillation algorithm based on distribution matching.

4.1 Strawman Collaborative Distillation

The collaborative distillation process begins with the server initializing a set of synthetic data $\mathcal{S}$. This can be achieved with a random initialization or by asking clients to transmit local data distillations: $\mathcal{S} = \cup_{i=1}^{K} \mathcal{S}_i$. Once initialized, the synthetic data is updated iteratively. For the meta-learning and parameter-matching frameworks, each iteration has the following steps:

  1. The server retrieves a network $\theta$. The network parameters could be generated randomly or retrieved from a cache.

  2. The network $\theta$ is updated. The update could be based on either $\mathcal{R}(\mathcal{D}; \theta)$ or $\mathcal{R}(\mathcal{S}; \theta)$. For the former, clients conduct federated learning. The latter objective can be executed on the server.

  3. $\theta$ is broadcast to clients.

  4. The loss function $\mathcal{L}(\mathcal{S})$ is computed. The clients collaborate to compute $\Phi_{\theta}(D)$, for batch $D \subset \mathcal{D}$.

  5. Synthetic data is updated based on the gradient of the loss function.

At each iteration, model parameters are broadcast to clients, and clients send losses incurred on real data to the server. For large models, this can create a communication overhead that compromises the benefits of collaborative distillation. In other words, similar to federated learning, the framework involves multiple rounds of communication that broadcast model parameters to clients.

Figure 1: Overview of CollabDM. In a single round of communication, the server sends seeds to initialize learning models. The client then distills local data and computes embeddings on the seeded models. Locally distilled data and computed embeddings are then sent to the server. The server uses the embeddings to refine the distilled data to reflect the global data distribution.

This limitation can be overcome with distribution matching. As distribution matching does not require model training, Step 2 from the framework can be removed. Further, as embeddings are initialized randomly, they can be broadcast with a random seed. Thus, distribution matching can be implemented in a collaborative learning environment without explicitly sharing network parameters.
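The seed-based broadcast can be illustrated with a short sketch. It assumes that the server and the clients run the same pseudo-random generator implementation and agree on the encoder architecture in advance; the helper `encoder_weights`, the dimensions and the seed value are hypothetical.

```python
import random


def encoder_weights(seed, in_dim, out_dim):
    """Rebuild the weights of a random encoder from a seed (hypothetical helper)."""
    rng = random.Random(seed)
    return [[rng.gauss(0.0, 1.0) for _ in range(in_dim)] for _ in range(out_dim)]


seed = 1234   # the only value the server needs to broadcast for this encoder

# Server and client each rebuild the encoder locally from the same seed.
server_W = encoder_weights(seed, in_dim=64, out_dim=32)
client_W = encoder_weights(seed, in_dim=64, out_dim=32)

# Number of parameters that never have to be transmitted.
num_weights = sum(len(row) for row in server_W)
```

Both parties obtain identical weights, so a single integer stands in for the 2,048 parameters of this toy encoder. The saving grows with the size of the embedding network, which is why the seed-based scheme keeps the broadcast cost independent of model size.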

4.2 Collaborative Distribution Matching

Input: Number of clients $K$, number of global iterations $T$, proportion $\varepsilon$ of clients participating in each iteration, learning rate $\eta$, and batch size $B$ for real data
Output: Global distilled dataset $\mathcal{S}$
define ServerDM():
    for $t \in \{1, \ldots, T\}$ do
        $\alpha_t \leftarrow$ random seed
        Generate random subset of clients $Z_t \subset [K]$ with $|Z_t| = \varepsilon K$
    for $k \in \{1, \ldots, K\}$ do
        $A_k \leftarrow \{\alpha_t \mid k \in Z_t\}$
        $\mathcal{S}_k, L_k \leftarrow \textsf{ClientDM}(A_k)$
    // Begin server data distillation
    $\mathcal{S} \leftarrow \cup_{k=1}^{K} \mathcal{S}_k$
    for $t \in \{1, \ldots, T\}$ do
        for $y \in \mathcal{Y}$ do
            Generate random batch $S_t^y \subset \mathcal{S}^y$ of synthetic data
            $\mathcal{L}^y \leftarrow \left\lVert \frac{1}{\varepsilon K} \sum_{k \in Z_t} L_{t,k}^y - \frac{1}{|S_t^y|} \sum_{x \in S_t^y} \psi_{\alpha_t}(x) \right\rVert^2$
        // Compute loss according to Eq. (3)
        $\mathcal{L} \leftarrow \sum_{y \in \mathcal{Y}} \mathcal{L}^y$
        $\mathcal{S} \leftarrow \mathcal{S} - \eta \nabla_{\mathcal{S}} \mathcal{L}$
    return $\mathcal{S}$
define ClientDM($A_k$):
    $\mathcal{S}_k \leftarrow$ Compute data distillation on $\mathcal{D}_k$
    $L_k \leftarrow \varnothing$
    for $\alpha_t \in A_k$ do
        for $y \in \mathcal{Y}$ do
            Generate random batch $D_{t,k}^y \subset \mathcal{D}_k^y$ of real data with $|D_{t,k}^y| = B$
            $L_{t,k}^y \leftarrow \frac{1}{B} \sum_{x \in D_{t,k}^y} \psi_{\alpha_t}(x)$
        $L_k \leftarrow L_k \cup \{L_{t,k}^y\}_{y \in \mathcal{Y}}$
    return $\mathcal{S}_k, L_k$
Algorithm 1: Collaborative Distribution Matching (CollabDM)

The goal of Collaborative Distribution Matching (CollabDM) is to compute the loss function in Equation (3) for each embedding $\psi \in \mathcal{E}$. The gradient of the loss is used to update the synthetic dataset stored on the server. As the loss function is calculated over the global dataset $\mathcal{D} = \cup_{i=1}^{K} \mathcal{D}_i$, the updates are able to capture the global dynamics of the data. Equation (3) can be split into two components: the embeddings on real data, which are computed at clients, and the embeddings on synthetic data, which are computed at the server. As $\mathcal{E}$ is fixed, it can be broadcast to clients prior to the distillation process. This allows each client to compute the mean embeddings on real data, one for each iteration of server training, in a single batch and complete their share of the collaboration in a single round of communication. A high-level overview of the procedure is presented in Figure 1. We now outline the full algorithm (provided in Algorithm 1) in more detail.

Let $|\mathcal{E}| = T$ denote the number of rounds required to distill synthetic data through distribution matching. In CollabDM, for each future training round $t \in \{1, \ldots, T\}$, the server begins by selecting a random seed $\alpha_t$ to encode a lower-dimensional embedding and selecting a subset of clients $Z_t \subset \{1, \ldots, K\}$ to participate in the round. A batch of seeds $A_k = \{\alpha_t \mid k \in Z_t\}$ is then broadcast to each client $k$. Once the clients receive embedding seeds from the server, client training begins. Each client has two roles. First, the client performs a local data distillation to produce $\mathcal{S}_k$. Any data distillation technique could be used. Local distillations will be used to initialize the synthetic data at the server. Second, the client computes their contribution to each objective function. For each embedding seed $\alpha_t \in A_k$ and label $y \in \mathcal{Y}$, the client selects a batch of real data $D_{t,k}^{y} \subset \mathcal{D}_k^{y}$, of size $B = |D_{t,k}^{y}|$, and computes the mean of the embeddings on the batch:

\[ L_{t,k}^{y} = B^{-1} \sum_{x \in D_{t,k}^{y}} \psi_{\alpha_t}(x). \]

The collection of means

\[ L_k = \bigcup_{t : \alpha_t \in A_k} \bigcup_{y \in \mathcal{Y}} \{L_{t,k}^{y}\} \]

is then sent to the server, along with $\mathcal{S}_k$. This concludes the client's role in CollabDM. Thus, in a single round of communication, the client receives $A_k$ and, subsequently, transmits $(L_k, \mathcal{S}_k)$.

The server can now complete data distillation through distribution matching. The synthetic data is initialized through the local distillations $\mathcal{S} = \cup_{k=1}^{K} \mathcal{S}_k$. The server then iterates through the embeddings in $\mathcal{E}$. For each embedding seed $\alpha_t$, using the client computations on real data $\cup_{k \in Z_t} L_{t,k}$, the loss function $\mathcal{L}$ of Equation (3) is evaluated. The synthetic data is then updated with the gradient of the loss with respect to $\mathcal{S}$:

\[ \mathcal{S} = \mathcal{S} - \eta \nabla_{\mathcal{S}} \mathcal{L}. \]

As the embeddings on real data are constant with respect to $\mathcal{S}$, the gradients can be computed at the server. Therefore, a global data distillation is achieved without further communication with clients.
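The client and server roles described above can be sketched end to end. The example below is a heavily simplified illustration rather than the released implementation. It assumes that every client participates in every round ($\varepsilon = 1$), that each client holds samples of every class, that encoders are random linear maps (so the gradient has a closed form and only class means are matched), that each real batch is the full local class, and that local distillation is replaced by a random placeholder point. All function names and data values are hypothetical.

```python
import random


def encoder(seed, in_dim=2, out_dim=4):
    """Random linear encoder rebuilt from a seed (simplifying assumption)."""
    rng = random.Random(seed)
    return [[rng.gauss(0.0, 1.0) for _ in range(in_dim)] for _ in range(out_dim)]


def embed(W, x):
    return [sum(w * v for w, v in zip(row, x)) for row in W]


def mean(vectors):
    return [sum(col) / len(vectors) for col in zip(*vectors)]


def client_dm(local_data, seeds, rng):
    """ClientDM: return a local synthetic set and one mean embedding per (t, y)."""
    # Placeholder for local distillation: one random point per class (assumption).
    local_syn = {y: [[rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)]] for y in local_data}
    means = {}
    for t, seed in enumerate(seeds):
        W = encoder(seed)
        for y, xs in local_data.items():
            # The batch is the whole local class, to keep the sketch short.
            means[(t, y)] = mean([embed(W, x) for x in xs])
    return local_syn, means


def server_dm(clients, T=200, eta=0.05):
    rng = random.Random(0)
    seeds = [rng.randrange(2**31) for _ in range(T)]           # alpha_t
    # Single communication round: seeds go out, (S_k, L_k) come back.
    replies = [client_dm(d, seeds, random.Random(k)) for k, d in enumerate(clients)]
    syn = {}
    for local_syn, _ in replies:                                # S = union of S_k
        for y, points in local_syn.items():
            syn.setdefault(y, []).extend(points)
    for t, seed in enumerate(seeds):
        W = encoder(seed)
        for y, points in syn.items():
            target = mean([m[(t, y)] for _, m in replies])      # average of L_{t,k}^y
            current = mean([embed(W, s) for s in points])
            resid = [a - b for a, b in zip(target, current)]
            # Linear psi: gradient of ||resid||^2 w.r.t. each s_j is -(2/m) W^T resid.
            grad = [-2.0 / len(points) * sum(W[i][d] * resid[i] for i in range(len(W)))
                    for d in range(len(W[0]))]
            for s in points:
                for d in range(len(s)):
                    s[d] -= eta * grad[d]
    return syn


# Two clients whose class-0 data is skewed; every client holds both classes.
clients = [
    {0: [[0.0, 0.0], [0.2, 0.0]], 1: [[5.0, 5.0], [5.0, 5.2]]},
    {0: [[2.0, 2.0], [2.2, 2.0]], 1: [[5.0, 5.0], [5.2, 5.0]]},
]
distilled = server_dm(clients)
```

In this toy run the mean of the refined synthetic points for each class approaches the mean of that class over both clients, even though the server only ever sees mean embeddings. Note that the server averages client means with equal weight, so the match to the global mean relies on the clients contributing equally sized batches, as in Algorithm 1.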

4.3 Parameter Optimization

There are a number of optimizations that can be applied to the distribution matching objective to improve the utility of the synthetic data [19, 29, 36, 45]. These optimizations can be adapted to CollabDM. Notably, the synthetic data variables can be parameterized in a more efficient manner to allow for the distillation of more instances (for the same memory budget) and an enhanced representation of the real data. To achieve this, we adopt a technique called partition and expand [43]. For partition parameter $l$, each image $s \in \mathcal{S}$ is partitioned into $l \times l$ mini-samples, and each mini-sample is then expanded to the input data dimensions using differentiable augmentation:

\[ s \xrightarrow{\text{partition}} \begin{bmatrix} s_{1,1} & \dots & s_{1,l} \\ \vdots & \ddots & \vdots \\ s_{l,1} & \dots & s_{l,l} \end{bmatrix} \xrightarrow{\text{up-sample}} s'_{1}, s'_{2}, \ldots, s'_{l} \]

Thus, the number of features extracted from $\mathcal{S}$ is increased without changing the storage budget.
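The partition and up-sample steps can be sketched as follows. This is an illustrative example only: it uses nearest-neighbour pixel replication as a stand-in for the differentiable augmentation mentioned above, represents an image as a nested list, and the function name `partition_and_expand` is hypothetical.

```python
def partition_and_expand(image, l):
    """Split a square image into l x l tiles and up-sample each tile to full size.

    Up-sampling is nearest-neighbour replication here, a stand-in (assumption)
    for the differentiable augmentation named in the text.
    """
    n = len(image)
    assert n % l == 0, "image side must be divisible by l"
    tile = n // l
    expanded = []
    for r in range(l):
        for c in range(l):
            # Cut out the (r, c) mini-sample of shape tile x tile.
            mini = [row[c * tile:(c + 1) * tile]
                    for row in image[r * tile:(r + 1) * tile]]
            # Repeat every pixel l times along both axes: tile x tile -> n x n.
            up = [[mini[i // l][j // l] for j in range(n)] for i in range(n)]
            expanded.append(up)
    return expanded


image = [[4 * i + j for j in range(4)] for i in range(4)]   # 4 x 4 toy image
samples = partition_and_expand(image, l=2)
```

The function returns one expanded sample per mini-sample, each with the original input dimensions, while the stored synthetic image keeps its original size. In a real pipeline the up-sampling would need to be differentiable so that gradients flow back to the stored pixels.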

5 Experiments

We now evaluate the classification performance of deep networks that are trained on the synthetic data generated by our method. A key parameter in CollabDM is the number of global iterations $T$. As $T$ increases, we would expect a higher fidelity synthetic set, as we are able to expose the loss function to a greater number of random models. However, increasing $T$ also increases the bandwidth overhead of the algorithm, as clients are required to send more random embeddings. Therefore, we are interested in the trade-off between classification accuracy and the amount of data transferred.
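A rough estimate of the per-client upload, derived from the structure of Algorithm 1, helps to see how the bandwidth grows with $T$. The sketch below is a back-of-the-envelope calculation under stated assumptions: a client is selected in a fraction $\varepsilon$ of the rounds, sends one mean embedding per selected round and class, and sends a fixed number of synthetic samples per class. The function name, the embedding dimension and every numeric setting are hypothetical and are not taken from the experiments.

```python
def upload_values_per_client(T, eps, num_classes, embed_dim, ipc, sample_dim):
    """Expected number of scalars one client uploads (rough estimate).

    Mean embeddings: one embed_dim vector per selected round and class.
    Local distillation: ipc synthetic samples per class, sample_dim values each.
    """
    rounds = eps * T                                  # expected size of A_k
    embeddings = rounds * num_classes * embed_dim
    local_syn = ipc * num_classes * sample_dim
    return embeddings + local_syn


# Hypothetical settings chosen only to illustrate the scaling with T.
small = upload_values_per_client(T=100, eps=1.0, num_classes=10, embed_dim=128,
                                 ipc=10, sample_dim=32 * 32 * 3)
large = upload_values_per_client(T=1000, eps=1.0, num_classes=10, embed_dim=128,
                                 ipc=10, sample_dim=32 * 32 * 3)
```

Under these assumed settings the upload grows from 435,200 to 1,587,200 scalars when $T$ rises from 100 to 1,000: the cost of the local distillation is fixed, and the embedding term scales linearly with $T$.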

Experiments are split across two settings. First, we evaluate our approach against standard benchmark image classification datasets. This allows for robust comparison with existing art. Second, we provide an evaluation for a target application: attack detection in 5G mobile networks. This target application extends the use of data distillation techniques to network traffic data and provides further motivation for collaborative data distillation (Problem 1). Programs (code is available at https://github.com/rayneholland/CollabDM/tree/main) were executed on a Dell laptop with Intel Core i5-8350U CPU, 8GB RAM, x64-based processor, and NVIDIA Quadro P5000M, 16 GB.

5.1 Training and Evaluation Setup

| Flow Type | Benign | HTTPFlood | SlowrateDoS | UDPFlood | Anomalous |
| --- | --- | --- | --- | --- | --- |
| # of images | 5081 | 1497 | 777 | 4865 | 706 |

Table 1: Summary of 5G network traffic images.

Figure 2: Testing accuracy vs. data transmitted per client across different parameter settings. Panels: (a) IPC = 10, # of clients = 5; (b) IPC = 10, # of clients = 20; (c) IPC = 50, # of clients = 5; (d) IPC = 50, # of clients = 20. The dashed red line corresponds to the classification accuracy of partition-and-expand distribution matching in the central model.

Datasets

For benchmark testing, we conduct experiments on four standard image classification datasets: MNIST [14], FMNIST [38], CIFAR10 [12] and SVHN [23]. MNIST consists of 60,000 grayscale images of handwritten digits in 10 classes. FMNIST is a dataset of 70,000 grayscale images of fashion products with 10 classes. CIFAR10 is a selection of 50,000 small color images separated into 10 classes. SVHN is a dataset of 600,000 color images of house numbers, taken from street-level photographs, with 10 classes.

For attack detection on 5G mobile networks, we adopt the 5G-NIDD dataset [28], a comprehensive benchmark dataset for 5G attack detection. The dataset is labeled and constructed through the creation of benign and malicious traffic profiles on a functional 5G test network. The data was collected in an environment comprised of 2 base stations connected to an attacker node and a set of benign traffic-generating devices. Malicious traffic is generated through either DoS or port scan attacks. The network captures are processed into CSV files containing the network flows and their associated features. Each flow is classified as either benign or as belonging to one of 8 different attacks. The classes of attack include:

  • DoS: HTTP Flood, ICMP Flood, SYN Flood, Slowrate DoS and UDP Flood.

  • Port Scan: UDP Scan, SYN Scan, and TCP Connect Scan.

Each flow is marked with 44 features. Rows from the CSV file are combined into batches and transformed into 64×64 black-and-white images. In total, 12,295 images were created. As the classes ICMP Flood, SYN Flood, SYN Scan, TCP Connect Scan, and UDP Scan contained an insufficient number of images, they were combined into an umbrella class of ‘anomalous’ traffic. The resulting dataset is summarized in Table 1.

| Dataset | MNIST | MNIST | MNIST | FMNIST | FMNIST | FMNIST | CIFAR10 | CIFAR10 | CIFAR10 | SVHN | SVHN | SVHN |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Skew | $\beta=0.1$ | $\beta=0.3$ | $\beta=0.5$ | $\beta=0.1$ | $\beta=0.3$ | $\beta=0.5$ | $\beta=0.1$ | $\beta=0.3$ | $\beta=0.5$ | $\beta=0.1$ | $\beta=0.3$ | $\beta=0.5$ |
| FedD3 [31] | 86.70 | 88.91 | 88.36 | 64.35 | 74.13 | 75.15 | 38.54 | 39.37 | 40.53 | 65.11 | 66.00 | 65.62 |
| DOSFL [47] | 77.97 | 83.24 | 91.73 | 64.25 | 72.01 | 80.83 | 43.92 | 47.08 | 56.62 | 67.94 | 69.05 | 69.91 |
| DENSE [41] | 66.61 | 76.48 | 95.82 | 50.29 | 83.96 | 85.94 | 50.26 | 59.76 | 62.19 | 55.34 | 79.59 | 80.03 |
| LocalDM | 96.10 | 96.93 | 97.17 | 84.18 | 84.50 | 84.24 | 52.93 | 52.17 | 54.22 | 70.04 | 70.69 | 71.49 |
| CollabDM | 97.72 | 97.82 | 97.83 | 85.43 | 86.51 | 86.71 | 57.97 | 59.36 | 60.21 | 74.35 | 74.66 | 75.57 |
| CollabDM-pae | 97.78 | 97.80 | 98.07 | 86.19 | 86.31 | 86.91 | 63.91 | 64.67 | 64.50 | 85.83 | 86.44 | 86.53 |

Table 2: Accuracy (%) of different methods across $\beta\in\{0.1,0.3,0.5\}$ on different datasets. IPC = 50 for the distillation methods.

Data Partition

To simulate real-world applications and distributed learning environments, we use the Dirichlet distribution to generate a non-IID data partition among clients [40, 41]. In particular, we sample $p_{k}\sim\mathsf{Dir}(\beta)$ and allocate a $p_{k}^{i}$ proportion of the data of class $k$ to client $i$. We can change the degree of imbalance by varying the parameter $\beta$. A small $\beta$ generates a highly skewed partition.
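The partitioning scheme can be sketched as follows, using only the Python standard library. This is a minimal sketch rather than the code used in our experiments: the names `sample_dirichlet` and `dirichlet_partition` are hypothetical, a symmetric Dirichlet is drawn by normalising Gamma variates, and rounding the cumulative proportions to obtain split points is one of several reasonable choices.

```python
import random


def sample_dirichlet(beta, n_clients, rng):
    """Draw p ~ Dir(beta, ..., beta) by normalising Gamma(beta, 1) variates."""
    draws = [rng.gammavariate(beta, 1.0) for _ in range(n_clients)]
    total = sum(draws)
    if total == 0.0:
        # Guard against numerical underflow for very small beta.
        return [1.0 / n_clients] * n_clients
    return [d / total for d in draws]


def dirichlet_partition(labels, n_clients, beta, seed=0):
    """Return one list of sample indices per client."""
    rng = random.Random(seed)
    clients = [[] for _ in range(n_clients)]
    for k in sorted(set(labels)):
        idx = [i for i, y in enumerate(labels) if y == k]
        rng.shuffle(idx)
        p = sample_dirichlet(beta, n_clients, rng)
        start, cum = 0, 0.0
        for c in range(n_clients):
            cum += p[c]
            if c == n_clients - 1:
                end = len(idx)  # the last client takes the remainder
            else:
                end = min(len(idx), round(cum * len(idx)))
            clients[c].extend(idx[start:end])
            start = end
    return clients


# Toy example: 1000 samples over 10 classes, split across 5 clients.
labels = [i % 10 for i in range(1000)]
parts = dirichlet_partition(labels, n_clients=5, beta=0.1, seed=1)
```

With a small value such as `beta=0.1`, most of each class typically lands on one or two clients, whereas larger values give splits that are closer to uniform.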

Model Architecture

Unless otherwise specified, all synthetic data are distilled using embeddings from the convolutional neural network (ConvNet) architecture used by Zhao et al. [43]. For classification accuracy, the learned synthetic sets are used to train randomly initialized ConvNets, which are subsequently used to perform classification tasks on real data. The default ConvNet includes three repeated convolutional blocks, and each block involves a 128-kernel convolution layer, instance normalization layer [34], ReLU activation function [22], and average pooling. In each experiment, we learn one synthetic set and use it to train 20 randomly initialized networks. We repeat each experiment 5 times and report the mean testing accuracy of the 100 trained networks. In addition, to test the transferability of the synthetic data, we conduct cross-architecture experiments in Section 5.3. Following Zhao et al. [43], we evaluate our method on four different architectures: ConvNet, AlexNet [13], VGG11 [30] and ResNet18 [7]. In this setting, we learn the synthetic set on one network architecture and use the resulting set to train networks with different architectures. The ability to train different network architectures on the same synthetic set is an advantage collaborative data distillation has over traditional federated learning.
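To give a feel for the size of the embeddings produced by such a network, the short sketch below computes the flattened feature dimension after three blocks. It rests on assumptions that are not stated above: the convolutions are taken to preserve spatial size, and each average pooling layer is taken to be a non-overlapping 2×2 pool. The helper name `embedding_dim` is hypothetical.

```python
def embedding_dim(side, blocks=3, channels=128, pool=2):
    """Flattened feature size for a square input of width `side`.

    Assumes shape-preserving convolutions and non-overlapping
    pool x pool average pooling in every block (floor division).
    """
    for _ in range(blocks):
        side //= pool
    return channels * side * side


# Input widths of the datasets used in this section.
dims = {side: embedding_dim(side) for side in (28, 32, 64)}
```

Under these assumptions a 32×32 input gives a 2048-dimensional embedding, and a 64×64 input gives an 8192-dimensional one.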

Comparison Methods

We evaluate two versions of CollabDM. These include the standard version, outlined in Algorithm 1, and the optimized version, outlined in Section 4.3, that utilizes the partition and expand technique [45]. For the optimized version, denoted CollabDM-pae, clients also employ partition-and-expand for the local distillation step.

For benchmarking, we compare CollabDM against four baselines. The first, named LocalDM, is based on naive virtual learning, where clients distill their data independently of each other, and the server uses this data for model training without refining it over a global objective. By using distribution matching as the local objective, we can assess the gain in classification accuracy achieved through the additional steps in CollabDM to refine global synthetic data. We also compare against two prior works (FedD3 [31] and DOSFL [47]) that use naive virtual learning with the meta-learning distillation objective. DOSFL is implemented with an optimization from Feng et al. [4] to update the method.

In addition, we evaluate CollabDM against DENSE [41], a state-of-the-art technique for one-shot federated learning. DENSE creates a global data generator and uses the data generator to train models on the server. For a fair comparison, we only include data synthesizing techniques. Thus, one-shot averaging methods, such as [32], are not included. These methods can provide a good classification model for training a single network. However, unlike data synthesizing techniques, they cannot be used for additional applications such as data sharing, neural architecture search, and continual learning.

Training parameters

Following Zhao et al. [43], the learning rate for local distillation is set to 1.0. The number of training iterations is set to 1000, compared to 20,000 in prior work. This reduces resource consumption on the client side, with each client taking 2-3 minutes to complete its portion of the algorithm. We also use a larger batch size of 512 for embeddings on real data. This supports faster convergence at the server and reduces the amount of data transmitted over the network. In addition, the learning rate for synthetic data at the server is set to 10. Again, this is designed to encourage faster convergence. During the experiments, we measure classification accuracy every 50 iterations. This allows the trade-off between data transmission and classification accuracy to be observed.

5.2 Benchmark Image Data

Parameters

We begin by looking at the impact of key parameters on classification performance. The parameters under consideration are images-per-class, the number of clients, and the number of iterations performed for global distillation. This experiment will also provide a proof-of-concept. That is, it will demonstrate that the global distillation steps improve classification accuracy within a distributed setting. For this section, the dataset is CIFAR10, and data is distributed IID across clients.

The results of the experiment are displayed in Figure 2. The number of iterations is expressed as the amount of data transmitted across the channel. The results demonstrate that increasing the amount of information transmitted increases the classification accuracy of the synthetic set. However, diminishing returns are experienced, and the largest increases in accuracy occur during the early iterations. As expected, testing accuracy decreases as the number of clients grows. For example, as observed in Figures 2(a) and 2(b), the classification accuracy for CollabDM-pae drops from 59% to 53% as the number of clients increases from 5 to 20. Increasing the number of images-per-class has only a small impact on the amount of data transmitted, while significantly increasing testing accuracy. For example, for CollabDM-pae, as observed in Figures 2(a) and 2(c), at 10 images-per-class 4.3MB of data achieves 57% accuracy and at 50 images-per-class 3.27MB of data achieves 65% accuracy. Notably, across all four settings, the partition-and-expand technique provides a significant increase in classification accuracy.

Figure 3: The impact of images-per-class on testing accuracy for 5G network traffic data.

Heterogeneous One-Shot Learning

The evaluation of CollabDM in heterogeneous one-shot learning is listed in Table 2. The number of iterations at the server is set to 200. This is equivalent to at most 9MB of data being sent per client. The level of heterogeneity is controlled by the parameter $\beta$. Most notably, all three distribution matching techniques demonstrate remarkable robustness to data heterogeneity, with all three outperforming the state-of-the-art method for $\beta=0.1$. The surprising performance of LocalDM provides evidence that distribution matching techniques are well suited to collaborative data distillation and distributed learning settings.

CollabDM significantly outperforms both FedD3 and DOSFL, the two techniques that perform local meta-learning distillation. Song et al. originally test FedD3 in a setting with a small dataset and a large number of clients (50,000 data points spread across 500 clients) [31]. In their experiments, they rely on 500 clients distilling 10 images each, which leads to a synthetic set at the server of 5,000 images. This is a brute-force approach that is effective at reducing communication when the scope is limited to just federated learning. However, under our objective, Problem 1, the method performs poorly when asked to distill a compact synthetic set of 50 images per class. For completeness, as CollabDM sends additional data, we also compared FedD3 and CollabDM when the volume of communication is equal. With each client limited to 9MB (which equates to 600 images per client for FedD3), CollabDM outperforms FedD3 on all tests, including an improvement of 5 percentage points on CIFAR10.

DOSFL uses soft resets, where each training model is sampled from the parameters of the server’s model, to overcome data heterogeneity [47]. While soft resets outperform traditional random initializations on non-IID data, they are not robust against varying degrees of skew. This demonstrates the need for a global distillation objective to regularize learning.

CollabDM-pae outperforms the state-of-the-art DENSE in all experiments, with notable improvements for highly skewed data partitions ($\beta=0.1$). For example, on the SVHN dataset, for $\beta=0.1$, CollabDM-pae improves over DENSE by 30 percentage points.

5.3 5G Attack Detection

Attack detection on 5G network data is a motivating application for our technique. 5G networks are decentralized by design and cater to a range of verticals. It is a setting in which data generation is innately distributed and heterogeneous. In our test setting, data collection is split across two base stations, which act as the clients in CollabDM. For all experiments, the amount of data transmitted is limited to 9MB per client.

We first look at the impact of the number of images per class on attack classification. The results are presented in Figure 3. Remarkably, at just 1 image-per-class, CollabDM-pae achieves 89% testing accuracy. This represents the distillation of 12,995 images of network traffic into 5 informative images, which, with the partition-and-expand technique, contain enough information to allow the network to not only distinguish between benign and malicious flows but also classify concrete attacks. This suggests that different attacks have highly distinct profiles. In addition, at just 10 images-per-class, CollabDM-pae achieves peak testing accuracy at 99%.

To verify the generalizability of the global synthetic sets, we conduct cross-architecture experiments. The results are presented in Table 3. Synthetic data is learned on one architecture and evaluated on a separate architecture. Each synthetic set contains 10 images per class. Results indicate that the synthetic sets generalize well, with, at worst, only a small drop in accuracy when moving to new architectures. These results promote the use of data distillation for data sharing in 5G networks, with very small global synthetic sets available for machine learning applications at different locations in the network.

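| Train \ Test | ConvNet | AlexNet | VGG11 | ResNet |
| --- | --- | --- | --- | --- |
| ConvNet | 98.84 | 96.49 | 98.18 | 97.60 |
| AlexNet | 96.55 | 94.21 | 95.87 | 94.19 |
| VGG11 | 89.95 | 86.31 | 89.33 | 91.30 |
| ResNet | 95.71 | 93.22 | 95.05 | 94.16 |

Table 3: Cross-architecture evaluation for 5G network traffic data. Rows give the architecture used to learn the synthetic set and columns give the architecture trained on it.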

6 Conclusion

We have presented a novel algorithm for data distillation in distributed settings. The algorithm supports the distillation of a synthetic set that matches the global data distribution and requires only a single round of communication between clients and the central server. Experiments demonstrate that learned synthetic sets are robust to heterogeneous data partitions and comfortably outperform the state-of-the-art approach. In addition, our work is motivated by a new application for data distillation: attack detection in 5G mobile networks. Experiments show that distillation techniques effectively capture the information in both benign and malicious traffic profiles.

7 Acknowledgments

This research work is partially conducted as part of the 6G Security Research and Development Project, as led by the Commonwealth Scientific and Industrial Research Organisation (CSIRO) through funding appropriated by the Australian Government’s Department of Home Affairs. This paper does not reflect any Australian Government policy position. For more information regarding this Project, please refer to https://research.csiro.au/6gsecurity/.

References

  • [1] Bohdal, O., Yang, Y., and Hospedales, T. Flexible dataset distillation: Learn labels instead of images. arXiv preprint arXiv:2006.08572 (2020).
  • [2] Cazenavette, G., Wang, T., Torralba, A., Efros, A. A., and Zhu, J.-Y. Dataset distillation by matching training trajectories. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (2022), pp. 4750–4759.
  • [3] Chen, D., Kerkouche, R., and Fritz, M. Private set generation with discriminative information. Advances in Neural Information Processing Systems 35 (2022), 14678–14690.
  • [4] Feng, Y., Vedantam, S. R., and Kempe, J. Embarrassingly simple dataset distillation. In The Twelfth International Conference on Learning Representations (2023).
  • [5] Goetz, J., and Tewari, A. Federated learning via synthetic data. arXiv preprint arXiv:2008.04489 (2020).
  • [6] Guha, N., Talwalkar, A., and Smith, V. One-shot federated learning. arXiv preprint arXiv:1902.11175 (2019).
  • [7] He, K., Zhang, X., Ren, S., and Sun, J. Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition (2016), pp. 770–778.
  • [8] Hinton, G., Vinyals, O., and Dean, J. Distilling the knowledge in a neural network. arXiv preprint arXiv:1503.02531 (2015).
  • [9] Hu, S., Goetz, J., Malik, K., Zhan, H., Liu, Z., and Liu, Y. Fedsynth: Gradient compression via synthetic data in federated learning. arXiv preprint arXiv:2204.01273 (2022).
  • [10] Huang, C.-Y., Jin, R., Zhao, C., Xu, D., and Li, X. Federated virtual learning on heterogeneous data with local-global distillation. arXiv preprint arXiv:2303.02278 (2023).
  • [11] Kaur, J., Khan, M. A., Iftikhar, M., Imran, M., and Haq, Q. E. U. Machine learning techniques for 5g and beyond. IEEE Access 9 (2021), 23472–23488.
  • [12] Krizhevsky, A., Hinton, G., et al. Learning multiple layers of features from tiny images.
  • [13] Krizhevsky, A., Sutskever, I., and Hinton, G. E. Imagenet classification with deep convolutional neural networks. Advances in neural information processing systems 25 (2012).
  • [14] LeCun, Y., Bottou, L., Bengio, Y., and Haffner, P. Gradient-based learning applied to document recognition. Proceedings of the IEEE 86, 11 (1998), 2278–2324.
  • [15] Lei, S., and Tao, D. A comprehensive survey to dataset distillation. arXiv preprint arXiv:2301.05603 (2023).
  • [16] Li, T., Sahu, A. K., Zaheer, M., Sanjabi, M., Talwalkar, A., and Smith, V. Federated optimization in heterogeneous networks. Proceedings of Machine learning and systems 2 (2020), 429–450.
  • [17] Li, X., Jiang, M., Zhang, X., Kamp, M., and Dou, Q. Fedbn: Federated learning on non-iid features via local batch normalization. arXiv preprint arXiv:2102.07623 (2021).
  • [18] Liu, P., Yu, X., and Zhou, J. T. Meta knowledge condensation for federated learning. arXiv preprint arXiv:2209.14851 (2022).
  • [19] Liu, S., Wang, K., Yang, X., Ye, J., and Wang, X. Dataset distillation via factorization. Advances in Neural Information Processing Systems 35 (2022), 1100–1113.
  • [20] Liu, S., and Wang, X. Few-shot dataset distillation via translative pre-training. In Proceedings of the IEEE/CVF International Conference on Computer Vision (2023), pp. 18654–18664.
  • [21] McMahan, B., Moore, E., Ramage, D., Hampson, S., and y Arcas, B. A. Communication-efficient learning of deep networks from decentralized data. In Artificial intelligence and statistics (2017), PMLR, pp. 1273–1282.
  • [22] Nair, V., and Hinton, G. E. Rectified linear units improve restricted boltzmann machines. In Proceedings of the 27th international conference on machine learning (ICML-10) (2010), pp. 807–814.
  • [23] Netzer, Y., Wang, T., Coates, A., Bissacco, A., Wu, B., and Ng, A. Y. Reading digits in natural images with unsupervised feature learning.
  • [24] Nguyen, T., Chen, Z., and Lee, J. Dataset meta-learning from kernel ridge-regression. arXiv preprint arXiv:2011.00050 (2020).
  • [25] Pi, R., Zhang, W., Xie, Y., Gao, J., Wang, X., Kim, S., and Chen, Q. Dynafed: Tackling client data heterogeneity with global dynamics. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (2023), pp. 12177–12186.
  • [26] Rosasco, A., Carta, A., Cossu, A., Lomonaco, V., and Bacciu, D. Distilled replay: Overcoming forgetting through synthetic samples. In International Workshop on Continual Semi-Supervised Learning (2021), Springer, pp. 104–117.
  • [27] Sachdeva, N., and McAuley, J. Data distillation: A survey. arXiv preprint arXiv:2301.04272 (2023).
  • [28] Samarakoon, S., Siriwardhana, Y., Porambage, P., Liyanage, M., Chang, S.-Y., Kim, J., Kim, J., and Ylianttila, M. 5g-nidd: A comprehensive network intrusion detection dataset generated over 5g wireless network. arXiv preprint arXiv:2212.01298 (2022).
  • [29] Shin, D., Shin, S., and Moon, I.-C. Frequency domain-based dataset distillation. arXiv preprint arXiv:2311.08819 (2023).
  • [30] Simonyan, K., and Zisserman, A. Very deep convolutional networks for large-scale image recognition. arXiv preprint arXiv:1409.1556 (2014).
  • [31] Song, R., Liu, D., Chen, D. Z., Festag, A., Trinitis, C., Schulz, M., and Knoll, A. Federated learning via decentralized dataset distillation in resource-constrained edge environments. In 2023 International Joint Conference on Neural Networks (IJCNN) (2023), IEEE, pp. 1–10.
  • [32] Su, S., Li, B., and Xue, X. One-shot federated learning without server-side training. Neural Networks 164 (2023), 203–215.
  • [33] Sucholutsky, I., and Schonlau, M. Soft-label dataset distillation and text dataset distillation. In 2021 International Joint Conference on Neural Networks (IJCNN) (2021), IEEE, pp. 1–8.
  • [34] Ulyanov, D., Vedaldi, A., and Lempitsky, V. Instance normalization: The missing ingredient for fast stylization. arXiv preprint arXiv:1607.08022 (2016).
  • [35] Wang, H., Yurochkin, M., Sun, Y., Papailiopoulos, D., and Khazaeni, Y. Federated learning with matched averaging. arXiv preprint arXiv:2002.06440 (2020).
  • [36] Wang, K., Zhao, B., Peng, X., Zhu, Z., Yang, S., Wang, S., Huang, G., Bilen, H., Wang, X., and You, Y. Cafe: Learning to condense dataset by aligning features. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (2022), pp. 12196–12205.
  • [37] Wang, T., Zhu, J.-Y., Torralba, A., and Efros, A. A. Dataset distillation. arXiv preprint arXiv:1811.10959 (2018).
  • [38] Xiao, H., Rasul, K., and Vollgraf, R. Fashion-mnist: a novel image dataset for benchmarking machine learning algorithms. arXiv preprint arXiv:1708.07747 (2017).
  • [39] Xiong, Y., Wang, R., Cheng, M., Yu, F., and Hsieh, C.-J. Feddm: Iterative distribution matching for communication-efficient federated learning. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (2023), pp. 16323–16332.
  • [40] Yurochkin, M., Agarwal, M., Ghosh, S., Greenewald, K., Hoang, N., and Khazaeni, Y. Bayesian nonparametric federated learning of neural networks. In International conference on machine learning (2019), PMLR, pp. 7252–7261.
  • [41] Zhang, J., Chen, C., Li, B., Lyu, L., Wu, S., Ding, S., Shen, C., and Wu, C. Dense: Data-free one-shot federated learning. Advances in Neural Information Processing Systems 35 (2022), 21414–21428.
  • [42] Zhao, B., and Bilen, H. Dataset condensation with differentiable siamese augmentation. In International Conference on Machine Learning (2021), PMLR, pp. 12674–12685.
  • [43] Zhao, B., and Bilen, H. Dataset condensation with distribution matching. In Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision (2023), pp. 6514–6523.
  • [44] Zhao, B., Mopuri, K. R., and Bilen, H. Dataset condensation with gradient matching. arXiv preprint arXiv:2006.05929 (2020).
  • [45] Zhao, G., Li, G., Qin, Y., and Yu, Y. Improved distribution matching for dataset condensation. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (2023), pp. 7856–7865.
  • [46] Zhou, Y., Nezhadarya, E., and Ba, J. Dataset distillation using neural feature regression. Advances in Neural Information Processing Systems 35 (2022), 9813–9827.
  • [47] Zhou, Y., Pu, G., Ma, X., Li, X., and Wu, D. Distilled one-shot federated learning. arXiv preprint arXiv:2009.07999 (2020).