Patch-level Routing in Mixture-of-Experts is Provably Sample-efficient for Convolutional Neural Networks
Abstract
In deep learning, mixture-of-experts (MoE) activates one or a few experts (sub-networks) on a per-sample or per-token basis, resulting in significant computation reduction. The recently proposed patch-level routing in MoE (pMoE) divides each input into n patches (or tokens) and sends l patches (l ≪ n) to each expert through prioritized routing. pMoE has demonstrated great empirical success in reducing training and inference costs while maintaining test accuracy. However, the theoretical explanation of pMoE and the general MoE remains elusive. Focusing on a supervised classification task using a mixture of two-layer convolutional neural networks (CNNs), we show for the first time that pMoE provably reduces the required number of training samples to achieve desirable generalization (referred to as the sample complexity) by a factor in the polynomial order of n/l, and outperforms its single-expert counterpart of the same or even larger capacity. The advantage results from the discriminative routing property, which we justify in both theory and practice: pMoE routers can filter label-irrelevant patches and route similar class-discriminative patches to the same expert. Our experimental results on MNIST, CIFAR-10, and CelebA support our theoretical findings on pMoE's generalization and show that pMoE can avoid learning spurious correlations.
1 Introduction
Deep learning has demonstrated exceptional empirical success in many applications at the cost of high computational and data requirements. To address this issue, mixture-of-experts (MoE) activates only part of a neural network for each data point and significantly reduces the computational complexity of deep learning without hurting the performance in applications such as machine translation and natural image classification (Shazeer et al., 2017; Yang et al., 2019).

A conventional MoE model contains multiple experts (subnetworks of the backbone architecture) and one learnable router that routes each input sample to a few, but not all, of the experts (Ramachandran & Le, 2018). Position-wise MoE has been introduced in language models (Shazeer et al., 2017; Lepikhin et al., 2020; Fedus et al., 2022), where the routing decisions are made separately on embeddings of different positions of the input rather than on the entire text input. Riquelme et al. (2021) extended it to vision models, where the routing decisions are made on image patches. Zhou et al. (2022) further extended this design so that the MoE layer has one router for each expert, and each router selects a subset of patches for the corresponding expert and discards the remaining patches. We term this routing mode patch-level routing and the MoE layer a patch-level MoE (pMoE) layer (see Figure 1 for an illustration of a pMoE). Notably, pMoE achieves the same test accuracy in vision tasks with 20% less training compute and 50% less inference compute compared to its single-expert counterpart (i.e., one expert receiving all the patches of an input) of the same capacity (Riquelme et al., 2021).
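To make expert-choice patch-level routing concrete, the sketch below shows a minimal pMoE layer in PyTorch. The module, tensor shapes, two-layer ReLU experts, and softmax gating are illustrative assumptions rather than the architecture of (Riquelme et al., 2021; Zhou et al., 2022): each expert has its own gating kernel, scores every patch, keeps only its top-l patches, and combines their outputs.

```python
import torch
import torch.nn as nn

class PatchLevelMoE(nn.Module):
    """Illustrative pMoE layer with expert-choice routing: each expert's router
    scores every patch and keeps only its own top-l patches."""
    def __init__(self, patch_dim, hidden_dim, num_experts, patches_per_expert):
        super().__init__()
        self.l = patches_per_expert
        # One gating kernel per expert (expert-choice routing).
        self.gates = nn.Parameter(torch.randn(num_experts, patch_dim) / patch_dim ** 0.5)
        # Each expert is a small two-layer ReLU network applied patch-wise.
        self.experts = nn.ModuleList([
            nn.Sequential(nn.Linear(patch_dim, hidden_dim), nn.ReLU(),
                          nn.Linear(hidden_dim, 1))
            for _ in range(num_experts)
        ])

    def forward(self, x):                       # x: (batch, n_patches, patch_dim)
        scores = torch.einsum('bnd,kd->bkn', x, self.gates)         # routing values
        out = 0.0
        for s, expert in enumerate(self.experts):
            top_vals, top_idx = scores[:, s].topk(self.l, dim=-1)   # top-l patches per sample
            gating = torch.softmax(top_vals, dim=-1)                # joint-training-style gating
            chosen = torch.gather(x, 1, top_idx.unsqueeze(-1).expand(-1, -1, x.size(-1)))
            out = out + (gating * expert(chosen).squeeze(-1)).sum(dim=-1)
        return out                              # one scalar output per sample
```

A single-expert counterpart corresponds to `num_experts=1` with `patches_per_expert` equal to the number of patches, in which case every patch is processed by the lone expert.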
Despite the empirical success of MoE, it remains theoretically elusive why MoE can maintain test accuracy while significantly reducing the amount of computation. To the best of our knowledge, only one recent work by Chen et al. (2022) shows theoretically that a conventional sample-wise MoE achieves higher test accuracy than a convolutional neural network (CNN) in a special setup of a binary classification task on data from linearly separable clusters. However, the sample-wise analyses by Chen et al. (2022) do not extend to patch-level MoE, which employs a different routing strategy than conventional MoE, and their data model might not characterize some practical datasets. This paper addresses the following question theoretically:
How much computational resource does pMoE save over its single-expert counterpart while maintaining the same generalization guarantee?
In this paper, we consider a supervised binary classification task where each input sample consists of n equal-sized patches, including class-discriminative patterns that determine the labels and class-irrelevant patterns that do not affect the labels. The neural network contains a pMoE layer and multiple experts, each of which is a two-layer CNN of the same architecture. (In practice, pMoEs are usually placed in the last layers of deep models; our analysis can be extended to this case as long as the input to the pMoE layer satisfies the data model in Section 4.2. We consider CNN experts due to their wide applications, especially in vision tasks; moreover, the pMoE in (Riquelme et al., 2021; Zhou et al., 2022) uses two-layer multi-layer perceptrons (MLPs) as experts in the vision transformer (ViT), which operates on image patches, so those MLPs are effectively non-overlapping CNNs.) The router sends l (l ≤ n) patches to each expert. Although we consider a simplified neural network model to facilitate the formal analysis of pMoE, the insights are applicable to more general setups. Our major results include:
1. To the best of our knowledge, this paper provides the first theoretical generalization analysis of pMoE. Our analysis reveals that pMoE with two-layer CNNs as experts can achieve the same generalization performance as a conventional CNN while reducing the sample complexity (the required number of training samples to learn a proper model) and model complexity. Specifically, we prove that as long as l is larger than a certain threshold, pMoE reduces the sample complexity and model complexity by a factor polynomial in n/l, indicating an improved generalization with a smaller l.
2. Characterization of the desired property of the pMoE router. We show that a desired pMoE router can dispatch the same class-discriminative patterns to the same expert and discard some class-irrelevant patterns. This discriminative property allows the experts to learn the class-discriminative patterns with reduced interference from irrelevant patterns, which in turn reduces the sample complexity and model complexity. We also prove theoretically that a separately trained pMoE router has the desired property and empirically verify this property on practical pMoE routers.
3. Experimental demonstration of reduced sample complexity by pMoE in deep CNN models. In addition to verifying our theoretical findings on synthetic data prepared from the MNIST dataset (LeCun et al., 2010), we demonstrate the sample efficiency of pMoE in learning some benchmark vision datasets (e.g., CIFAR-10 (Krizhevsky, 2009) and CelebA (Liu et al., 2015)) by replacing the last convolutional layer of a ten-layer wide residual network (WRN) (Zagoruyko & Komodakis, 2016) with a pMoE layer. These experiments not only verify our theoretical findings but also demonstrate the applicability of pMoE in reducing sample complexity in deep-CNN-based vision models, complementing the existing empirical success of pMoE with vision transformers.
2 Related Works
Mixture-of-Experts. MoE was first introduced in the 1990s with dense sample-wise routing, i.e., each input sample is routed to all the experts (Jacobs et al., 1991; Jordan & Jacobs, 1994; Chen et al., 1999; Tresp, 2000; Rasmussen & Ghahramani, 2001). Sparse sample-wise routing was later introduced (Bengio et al., 2013; Eigen et al., 2013), where each input sample activates only a few of the experts in an MoE layer, both for joint training (Ramachandran & Le, 2018; Yang et al., 2019) and for separate training of the router and experts (Collobert et al., 2001, 2003; Ahmed et al., 2016; Gross et al., 2017). Position/patch-wise MoE (i.e., pMoE) recently demonstrated success in large language and vision models (Shazeer et al., 2017; Lepikhin et al., 2020; Riquelme et al., 2021; Fedus et al., 2022). To solve the issue of load imbalance (Lewis et al., 2021), Zhou et al. (2022) introduce expert-choice routing in pMoE, where each expert uses one router to select a fixed number of patches from the input. This paper analyzes sparse patch-level MoE with expert-choice routing under both joint-training and separate-training setups.
Optimization and generalization analyses of neural networks (NN). Due to the significant nonconvexity of the deep learning problem, existing generalization analyses are limited to linearized or shallow neural networks. The neural-tangent-kernel (NTK) approach (Jacot et al., 2018; Lee et al., 2019; Du et al., 2019; Allen-Zhu et al., 2019b; Zou et al., 2020; Chizat et al., 2019; Ghorbani et al., 2021) considers strong over-parameterization and approximates the neural network by its first-order Taylor expansion. The NTK results are independent of the input data, and performance gaps in representation power and generalization ability exist between practical NNs and the NTK results (Yehudai & Shamir, 2019; Ghorbani et al., 2019, 2020; Li et al., 2020; Malach et al., 2021). Nonlinear neural networks have recently been analyzed through higher-order Taylor expansions (Allen-Zhu et al., 2019a; Bai & Lee, 2019; Arora et al., 2019; Ji & Telgarsky, 2019) or through a model-estimation approach with Gaussian input data (Zhong et al., 2017b, a; Zhang et al., 2020b, a; Fu et al., 2020; Li et al., 2022b), but these results are limited to two-layer networks, with few papers on three-layer networks (Allen-Zhu et al., 2019a; Allen-Zhu & Li, 2019, 2020a; Li et al., 2022a).
The above works consider arbitrary input data or Gaussian input. To better characterize the practical generalization performance, some recent works analyze structured data models using approaches such as feature mapping (Li & Liang, 2018), where some of the initial model weights are close to data features, and feature learning (Daniely & Malach, 2020; Shalev-Shwartz et al., 2020; Shi et al., 2021; Allen-Zhu & Li, 2022; Li et al., 2023), where some weights gradually learn features during training. Among them, Allen-Zhu & Li (2020b); Brutzkus & Globerson (2021); Karp et al. (2021) analyze CNN on learning structured data composed of class-discriminative patterns that determine the labels and other label-irrelevant patterns. This paper extends the data models in Allen-Zhu & Li (2020b); Brutzkus & Globerson (2021); Karp et al. (2021) to a more general setup, and our analytical approach is a combination of feature learning in routers and feature mapping in experts for pMoE.
3 Problem Formulation
This paper considers a supervised binary classification problem (our results can be extended to multiclass classification problems; see Section M in the Appendix for details) where, given N i.i.d. training samples generated by an unknown distribution D, the objective is to learn a neural network model that maps x to y for any (x, y) sampled from D. Here, the input x has n disjoint patches, where x^(j) denotes the j-th patch of x, and y ∈ {+1, -1} denotes the corresponding label.
3.1 Neural Network Models
We consider a pMoE architecture that includes k experts and the corresponding routers. Each router selects l out of n (l ≤ n) patches for each expert separately. Specifically, the router for expert s (s ∈ [k]) contains a trainable gating kernel w_s. Given a sample x, the router computes a routing value ⟨w_s, x^(j)⟩ for each patch j. Let J_s(w_s, x) denote the index set of the top-l routing values among all the patches. Only patches with indices in J_s are routed to expert s, multiplied by a gating value G_{j,s}, which is selected differently in different pMoE models.
Each expert is a two-layer CNN with the same architecture. Let m denote the total number of neurons in all the experts; then each expert contains m/k neurons. Let θ_{s,r} and a_{s,r} denote the hidden-layer and output-layer weights for neuron r (r ∈ [m/k]) in expert s (s ∈ [k]), respectively. The activation function is the rectified linear unit (ReLU), where ReLU(z) = max(0, z).
Let θ include all the trainable weights. The pMoE model, denoted as f_M(θ, x), is defined as follows:
(1) |
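One form of (1) consistent with the description above is the following; the notation and normalization are assumed rather than quoted:

\[
f_M(\theta, x) \;=\; \sum_{s=1}^{k} \sum_{r=1}^{m/k} a_{s,r} \sum_{j \in J_s(w_s, x)} G_{j,s}\, \mathrm{ReLU}\!\left(\langle \theta_{s,r},\, x^{(j)} \rangle\right),
\]

where J_s(w_s, x) is the top-l index set of expert s and G_{j,s} is the gating value defined in (3) or (4).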

The learning problem solves the following empirical risk minimization problem with the logistic loss function,
(2) |
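With the logistic loss, a standard form of the empirical risk minimization in (2), consistent with the surrounding description (the averaging convention is assumed), is:

\[
\min_{\theta}\; L(\theta) \;=\; \frac{1}{N} \sum_{i=1}^{N} \log\!\Big(1 + \exp\big(-y_i\, f_M(\theta, x_i)\big)\Big).
\]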
We consider two different training modes of pMoE: separate training and joint training of the routers and the experts. We also consider the conventional CNN architecture for comparison.
(I) Separate-training pMoE: Under the setup of the so-called hard mixtures of experts (Collobert et al., 2003; Ahmed et al., 2016; Gross et al., 2017), the router weights are trained first and then fixed when training the weights of the experts. In this case, the gating values are set as
(3) |
We select k = 2 experts in this case to simplify the analysis.
(II) Joint-training pMoE: The routers and the experts are learned jointly, see, e.g., (Lepikhin et al., 2020; Riquelme et al., 2021; Fedus et al., 2022). Here, the gating values are softmax functions with
(4) |
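In this notation, the two gating rules described in (3) and (4) can be written as follows (an assumed reconstruction consistent with the text):

\[
\text{Separate-training: } G_{j,s} = 1 \;\text{ for } j \in J_s(w_s, x); \qquad
\text{Joint-training: } G_{j,s} = \frac{\exp\!\big(\langle w_s, x^{(j)} \rangle\big)}{\sum_{j' \in J_s(w_s, x)} \exp\!\big(\langle w_s, x^{(j')} \rangle\big)}.
\]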
(III) CNN single-expert counterpart: The conventional two-layer CNN with m neurons satisfies
(5) |
Eq. (5) can be viewed as a special case of (1) with only one expert (k = 1) to which all the patches are sent (l = n) with gating values equal to 1.
The learned model parameters are obtained by solving the empirical risk minimization problem (2). The predicted label for a test sample x is the sign of the learned model's output, and the generalization accuracy is the fraction of test samples that are correctly predicted. This paper studies both separate and joint training of pMoE and compares their performance with CNN from the perspective of the sample complexity required to achieve a desirable generalization accuracy.
3.2 Training Algorithms
In the following algorithms, we fix the output-layer weights at their initial values, randomly sampled from the standard Gaussian distribution, and do not update them during training. This is a typical simplification when analyzing NNs, as used in (Li & Liang, 2018; Brutzkus et al., 2018; Allen-Zhu et al., 2019a; Arora et al., 2019).
(I) Separate-training pMoE: The routers are trained separately using a subset of the training samples, taken without loss of generality to be the first samples of the training set. The gating kernels w_1 and w_2 are obtained by solving the following minimization problem:
(6) |
To solve (6), we implement mini-batch SGD with a given batch size for a given number of iterations, starting from the random initialization as follows:
(7) |
where, .
After learning the routers, we train the hidden-layer weights of the experts by solving (2) while fixing the gating kernels and the output-layer weights. We implement mini-batch SGD with a given batch size for a given number of iterations, starting from the initialization
(8) |
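For concreteness, a minimal sketch of the separate-training procedure (routers first, then experts with frozen routers and fixed output-layer weights) is given below. The `router` and `experts` interfaces, the margin-based router objective used as a stand-in for (6), and all hyperparameter names are illustrative assumptions, not the exact algorithm analyzed here.

```python
import torch
import torch.nn.functional as F

def train_separate_pmoe(router, experts, loader, router_steps, expert_steps, lr):
    """Two-phase separate training: learn the gating kernels first, then freeze
    them and train only the experts' hidden-layer weights.
    `loader` is assumed to yield (patches, labels) batches with labels in {+1, -1}."""
    # Phase 1: train the routers alone; the loss below is only a stand-in for the
    # router objective (6): it pushes expert 1's best routing value up on class +1
    # samples and expert 2's best routing value up on class -1 samples.
    opt_r = torch.optim.SGD(router.parameters(), lr=lr)
    for _, (x, y) in zip(range(router_steps), loader):
        scores = router(x)                              # (batch, num_experts, n_patches)
        margin = scores[:, 0].max(-1).values - scores[:, 1].max(-1).values
        loss = F.softplus(-y * margin).mean()           # log(1 + exp(-y * margin))
        opt_r.zero_grad(); loss.backward(); opt_r.step()

    for p in router.parameters():                       # freeze the routers
        p.requires_grad_(False)

    # Phase 2: train the hidden-layer weights with the logistic loss in (2);
    # `experts` is assumed to keep its output-layer weights at initialization
    # (requires_grad=False), matching the fixed-output-layer simplification.
    opt_e = torch.optim.SGD([p for p in experts.parameters() if p.requires_grad], lr=lr)
    for _, (x, y) in zip(range(expert_steps), loader):
        logits = experts(x, router(x))                  # experts consume the routed patches
        loss = F.softplus(-y * logits).mean()
        opt_e.zero_grad(); loss.backward(); opt_e.step()
```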
4 Theoretical Results
4.1 Key Findings At-a-glance
Before defining the data model assumptions and their rationale in Section 4.2 and presenting the formal results in Section 4.3, we first summarize our key findings. We assume that the data patches are sampled either from class-discriminative patterns that determine the labels or from a possibly infinite number of class-irrelevant patterns that have no impact on the label. The parameter δ (defined in (9)) is inversely related to the separation among patterns, i.e., δ decreases when (i) the separation among class-discriminative patterns increases, and/or (ii) the separation between class-discriminative and class-irrelevant patterns increases. The key findings are as follows.
(I). A properly trained patch-level router sends class-discriminative patches of one class to the same expert while dropping some class-irrelevant patches. We prove that separate-training pMoE routes the class-discriminative patches of the class with label +1 (or the class with label -1) to expert 1 (or expert 2), respectively, and that class-irrelevant patterns sufficiently far away from the class-discriminative patterns are not routed to any expert (Lemma 4.1). This discriminative routing property is also verified empirically for joint-training pMoE (see Section 5.1). Therefore, pMoE effectively reduces the interference from irrelevant patches when each expert learns the class-discriminative patterns. Moreover, we show empirically that pMoE can remove class-irrelevant patches that are spuriously correlated with class labels and thus can avoid learning from spuriously correlated features of the data.
(II). Both the sample complexity and the required number of hidden nodes of pMoE are reduced by a polynomial factor of n/l over CNN. We prove that as long as l, the number of patches per expert, is greater than a threshold (which decreases as the separation between class-discriminative and class-irrelevant patterns increases), the sample complexity and the required number of neurons of pMoE are polynomial in l. In contrast, the sample and model complexities of the CNN are polynomial in n, indicating improved generalization by pMoE.
(III). Larger separation among class-discriminative and class-irrelevant patterns reduces the sample complexity and model complexity of pMoE. Both the sample complexity and the required number of neurons of pMoE are polynomial in δ, which decreases when the separation among patterns increases.
4.2 Data Model Assumptions and Rationale
Each input comprises one class-discriminative pattern and n − 1 class-irrelevant patterns, and the label is determined by the class-discriminative pattern only.
Distributions of class-discriminative patterns: The unit vectors o_1 and o_2 denote the class-discriminative patterns that determine the labels, and the separation between o_1 and o_2 enters the definition of δ in (9). o_1 and o_2 are equally distributed among the samples, and each sample contains exactly one of them. If x contains o_1 (or o_2), then y is +1 (or -1).
Distributions of class-irrelevant patterns: Class-irrelevant patterns are unit vectors belonging to disjoint pattern sets, and these patterns are distributed identically in both classes. The separation between class-discriminative and class-irrelevant patterns is measured by the distance from o_1 and o_2 to these pattern sets, and each pattern set is contained in a ball of bounded diameter. Note that NO separation among the class-irrelevant patterns themselves is required.
The rationale of our data model: The data distribution captures the locality of label-defining features in image data. It is motivated by and extends the data distributions in recent theoretical frameworks (Yu et al., 2019; Brutzkus & Globerson, 2021; Karp et al., 2021; Chen et al., 2022). Specifically, Yu et al. (2019) and Brutzkus & Globerson (2021) require the patterns to be orthogonal to each other and allow only a fixed number of non-discriminative patterns. Karp et al. (2021) and Chen et al. (2022) assume orthogonal class-discriminative patterns and a possibly infinite number of patterns drawn from a zero-mean Gaussian distribution. In our model, the separation between the class-discriminative patterns can take a range of values, and the class-irrelevant patterns can be drawn from pattern sets that contain an infinite number of patterns that are not necessarily Gaussian or orthogonal.
Define
(9) |
δ decreases if (1) o_1 and o_2 are more separated from each other, and (2) both o_1 and o_2 are more separated from every class-irrelevant pattern set. We also define an integer l* that measures the maximum number of class-irrelevant patterns per sample that are sufficiently closer to o_1 than to o_2, or vice versa. Specifically, a class-irrelevant pattern is called closer to o_1 than to o_2 (with a given margin) if it is closer to o_1 than to o_2 by at least that margin, and similarly for o_2. Then l* is the maximum number of class-irrelevant patches in any x sampled from D that are closer to o_1 than to o_2, or vice versa, by the given margin. l* depends on the margin and the data distribution D. When the margin is fixed, a larger separation between the class-discriminative and class-irrelevant patterns leads to a smaller l*. In contrast to the linearly separable data in (Yu et al., 2019; Brutzkus et al., 2018; Chen et al., 2022), our data model is NOT linearly separable under a mild condition (see Section K in the Appendix for the proof).
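To make the data model concrete, the following toy generator samples inputs with exactly one class-discriminative patch per sample and fills the remaining patches with class-irrelevant patterns drawn from fixed pattern sets. All dimensions and parameter values are illustrative assumptions, not those used in the analysis.

```python
import numpy as np

def make_dataset(num_samples, n_patches=16, dim=50, n_irrelevant_sets=8, set_radius=0.05, seed=0):
    rng = np.random.default_rng(seed)
    def unit(v): return v / np.linalg.norm(v)
    # Class-discriminative unit vectors o1, o2 and centers of class-irrelevant pattern sets.
    o1, o2 = unit(rng.normal(size=dim)), unit(rng.normal(size=dim))
    centers = [unit(rng.normal(size=dim)) for _ in range(n_irrelevant_sets)]

    X = np.zeros((num_samples, n_patches, dim))
    y = rng.choice([1, -1], size=num_samples)              # labels equally likely
    for i in range(num_samples):
        # Every patch starts as a class-irrelevant pattern: a small perturbation
        # of a randomly chosen set center (each set lies in a small ball).
        for j in range(n_patches):
            c = centers[rng.integers(n_irrelevant_sets)]
            X[i, j] = unit(c + set_radius * rng.normal(size=dim))
        # Exactly one patch carries the class-discriminative pattern.
        X[i, rng.integers(n_patches)] = o1 if y[i] == 1 else o2
    return X, y, (o1, o2)
```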
4.3 Main Theoretical Results
4.3.1 Generalization Guarantee of Separate-training pMoE
Lemma 4.1 shows that as long as l, the number of patches per expert, is greater than the threshold l*, the routers learned separately by solving (6) always send o_1 to expert 1 and o_2 to expert 2. Based on this discriminative property of the learned routers, Theorem 4.2 then quantifies the sample complexity and network size of separate-training pMoE needed to achieve a desired generalization error. Theorem 4.3 quantifies the sample and model complexities of CNN for comparison.
Lemma 4.1 (Discriminative Property of Separately Trained Routers).
For every , w.h.p. over the random initialization defined in (7), after performing mini-batch SGD with batch size and learning rate for iterations, the returned gating kernels w_1 and w_2 satisfy
i.e., the learned routers always send o_1 to expert 1 and o_2 to expert 2.
The main idea in proving Lemma 4.1 is to show that the gradient in each iteration has a large component along the directions of o_1 and o_2. Then, after enough iterations, the inner product of w_1 and o_1 (similarly, of w_2 and o_2) is sufficiently large. The intuition behind requiring l > l* is that, because there are at most l* class-irrelevant patches sufficiently closer to o_1 than to o_2 (or vice versa), sending l patches to one expert ensures that one of them is o_1 (or o_2). Note that the batch size and the number of iterations depend on the separation between o_1 and o_2 but are independent of the separation between class-discriminative and class-irrelevant patterns.
We then show that the separate-training pMoE reduces both the sample complexity and the required model size (Theorem 4.2) compared to the CNN (Theorem 4.3).
Theorem 4.2 (Generalization guarantee of separate-training pMoE).
For every and , for every with at least training samples, after performing minibatch SGD with the batch size and the learning rate for iterations, it holds w.h.p. that
Theorem 4.2 implies that, to achieve the target generalization error with separate-training pMoE, the required number of training samples and the required number of hidden nodes both increase only polynomially with l, the number of patches sent to each expert. Moreover, both quantities are polynomial in δ defined in (9), indicating an improved generalization performance with stronger separation among patterns.
The proof of Theorem 4.2 is inspired by Li & Liang (2018), which analyzes the generalization performance of fully connected neural networks (FCNs) on structured data, but we make new technical contributions in analyzing pMoE models. In addition to analyzing the pMoE routers (Lemma 4.1), which do not appear in the FCN analysis, our analysis also significantly relaxes the separation requirement on the data compared with that of Li & Liang (2018). For example, Li & Liang (2018) requires the separation between the two classes, measured by the smallest distance between two points in different classes, to be large in order to obtain a polynomial sample-complexity bound for the binary classification task. In contrast, the separation between the two classes in our data model can be much smaller than that required by Li & Liang (2018).
Theorem 4.3 (Generalization guarantee of CNN).
For every , for every with at least training samples, after performing minibatch SGD with the batch size and the learning rate for iterations, it holds w.h.p. that
Theorem 4.3 implies that, to achieve the same generalization error using the CNN in (5), the required numbers of training samples and neurons are polynomial in n.
Sample-complexity gap between a single CNN and a mixture of CNNs. From Theorems 4.2 and 4.3, the sample complexity of the CNN exceeds that of separate-training pMoE by a factor polynomial in n/l. Similarly, the required number of neurons is reduced by a factor polynomial in n/l in separate-training pMoE. (The bounds for the sample complexity and model size in Theorems 4.2 and 4.3 are sufficient but not necessary; rigorously speaking, one cannot compare sufficient conditions only. In our analysis, however, the bounds for MoE and CNN are derived with exactly the same technique, with the only difference being the handling of the routers. Therefore, it is fair to compare these two bounds to show the advantage of pMoE.)
4.3.2 Generalization Guarantee of Joint-training pMoE with Proper Routers
Theorem 4.5 characterizes the generalization performance of joint-training pMoE, assuming the routers are properly trained in the sense that, after some SGD iterations, for each class at least one of the experts receives all class-discriminative patches of that class with the largest gating value (see Assumption 4.4).
Assumption 4.4.
There exists an integer such that for all , it holds that:
where () denotes the index of the class-discriminative pattern (), is the gating output of patch of sample for expert at the iteration , and is the gating kernel for expert at iteration .
Assumption 4.4 is required to prove Theorem 4.5 because of the difficulty of tracking the dynamics of the routers in joint-training pMoE. Assumption 4.4 is verified empirically in Section 5.1, while its theoretical proof is left for future work.
Table 1: End-to-end computational complexity comparison of pMoE and CNN. For each model (separate-training pMoE, split into router and expert; joint-training pMoE; and CNN), the table reports the complexity per iteration (Complx/Iter) of the forward and backward passes and the number of iterations T required to converge with the target error; the complexity to achieve that error is Complx/Iter × T.
Theorem 4.5 (Generalization guarantee of joint-training pMoE).
Suppose Assumption 4.4 holds. Then for every , for every with at least training samples, after performing minibatch SGD with the batch size and the learning rate for iterations, it holds w.h.p. that
Theorem 4.5 indicates that, with proper routers, joint-training pMoE needs numbers of training samples and neurons that are polynomial in l to achieve the target generalization error. Compared with the CNN in Theorem 4.3, joint-training pMoE thus reduces the sample complexity and model size by factors polynomial in n/l. With more experts (a larger k), it is easier to satisfy Assumption 4.4 and learn proper routers, but larger sample and model complexities are required. When the number of samples is fixed, the sample-complexity expression also indicates that the generalization error scales with l and δ, corresponding to an improved generalization when l and δ decrease.
We provide an end-to-end computational complexity comparison between the analyzed pMoE models and the general CNN model in Table 1 (see Section N in the Appendix for details). The results in Table 1 indicate that the computational complexity of joint-training pMoE is reduced compared with CNN, and separate-training pMoE achieves a similar reduction.
5 Experimental Results
5.1 pMoE of Two-layer CNN
Dataset: We verify our theoretical findings about the model in (1) on synthetic data prepared from the MNIST (LeCun et al., 2010) dataset. Each sample contains n patches, each of which is an image drawn from the MNIST dataset. See Figure 4 for an example. We treat the digits "1" and "0" as the class-discriminative patterns o_1 and o_2, respectively. Each of the digits from "2" to "9" represents a class-irrelevant pattern set.
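The construction can be sketched as follows; the grid size, digit placement, and torchvision loading details are assumptions about a reasonable implementation rather than the exact pipeline used for the experiments.

```python
import numpy as np
from torchvision import datasets

def build_patch_mnist(num_samples, grid=(2, 2), seed=0):
    """Tile MNIST digits into a grid; '1' => o_1 (label +1), '0' => o_2 (label -1)."""
    rng = np.random.default_rng(seed)
    mnist = datasets.MNIST(root="./data", train=True, download=True)
    imgs = mnist.data.numpy()                                   # (60000, 28, 28)
    labels = mnist.targets.numpy()
    disc = {1: imgs[labels == 1], -1: imgs[labels == 0]}        # class-discriminative digits
    irrel = imgs[labels >= 2]                                   # digits 2-9: class-irrelevant

    H, W = grid
    X = np.zeros((num_samples, H * 28, W * 28), dtype=np.float32)
    y = rng.choice([1, -1], size=num_samples)
    for i in range(num_samples):
        patches = [irrel[rng.integers(len(irrel))] for _ in range(H * W)]
        patches[rng.integers(H * W)] = disc[y[i]][rng.integers(len(disc[y[i]]))]
        for p, patch in enumerate(patches):                     # paste patches into the grid
            r, c = divmod(p, W)
            X[i, r*28:(r+1)*28, c*28:(c+1)*28] = patch / 255.0
    return X, y
```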





Setup: We compare separate-training pMoE, joint-training pMoE, and CNN with similar model sizes. The separate-training pMoE contains two experts, the joint-training pMoE has eight experts with five hidden nodes per expert, and the CNN has a comparable total number of hidden nodes. All models are trained using SGD until reaching zero training error; pMoE converges in much fewer epochs than CNN. Before training the experts in the separate-training pMoE, we train the router for several epochs. The models are evaluated on held-out test samples.
Generalization performance: Figure 4 compares the test accuracy of the three models, where the value of l is set separately for separate-training and joint-training pMoE. The error bars show the mean plus/minus one standard deviation over five independent experiments. pMoE outperforms CNN with the same number of training samples, and pMoE requires only 60% of the training samples needed by CNN to achieve the same test accuracy.
Figure 5 shows the sample complexity of separate-training pMoE with respect to l. Each block represents 20 independent trials: a white block indicates that all trials succeeded, and a black block indicates that all trials failed. The sample complexity is polynomial in l, verifying Theorem 4.2. Figures 7 and 6 show the test accuracy of joint-training pMoE with a fixed sample size when l and k change, respectively. When l is greater than the threshold shown in Figure 7, the test accuracy matches our predicted order. Similarly, the dependence on k also matches our prediction when k is large enough to make Assumption 4.4 hold.
Router performance: Figure 8 verifies the discriminative property of the separately trained routers (Lemma 4.1) by showing the percentage of test data that have the class-discriminative patterns (o_1 and o_2) among the top-l patches of the separately trained router. With very few training samples, one can already learn a proper router that places the discriminative patterns in the top-l patches for 95% of the data. Figure 9 verifies the discriminative property of the jointly trained routers (Assumption 4.4). With relatively few training samples, the jointly trained router dispatches o_1 with the largest gating value to a particular expert for 95% of class-1 data, and similarly dispatches o_2 for 92% of class-2 data.
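The router statistics above can be computed with a simple check of the kind sketched below; the tensor layout and function signature are assumptions.

```python
import torch

def discriminative_rate(gate_w, patches, disc_idx, l):
    """Fraction of samples whose class-discriminative patch is among the top-l
    routing values of the expert with gating kernel `gate_w`.

    patches:  (num_samples, n_patches, patch_dim) flattened patches
    disc_idx: (num_samples,) index of the class-discriminative patch in each sample
    """
    scores = patches @ gate_w                     # (num_samples, n_patches) routing values
    top_idx = scores.topk(l, dim=-1).indices      # indices of the top-l patches per sample
    hits = (top_idx == disc_idx.unsqueeze(-1)).any(dim=-1)
    return hits.float().mean().item()
```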


5.2 pMoE of Wide Residual Networks (WRNs)
Neural network model: We employ the 10-layer WRN (Zagoruyko & Komodakis, 2016) with a widening factor of 10 as the expert. We construct a patch-level MoE counterpart of WRN, referred to as WRN-pMoE, by replacing the last convolutional layer of WRN with a pMoE layer with an equal number of trainable parameters (see Figure 18 in the Appendix for an illustration). WRN-pMoE is trained with the joint-training method (code is available at https://github.com/nowazrabbani/pMoE_CNN). All the results are averaged over five independent experiments.
Datasets: We consider both the CelebA (Liu et al., 2015) and CIFAR-10 datasets; the experiments on CIFAR-10 are deferred to the Appendix (see Section A). We down-sample the CelebA images, and the last convolutional layer of WRN receives the resulting feature map. In WRN-pMoE, this feature map is divided into patches that are routed by the pMoE layer, with fixed values of k and l.
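A rough sketch of how the final feature map can be split into patches and routed through a pMoE layer is shown below; the backbone interface, patch size, and the reuse of the PatchLevelMoE sketch from Section 1 are assumptions, not the exact WRN-pMoE architecture.

```python
import torch
import torch.nn as nn

class WRNWithPMoEHead(nn.Module):
    """Backbone up to the last conv block, then a pMoE layer over feature-map patches."""
    def __init__(self, backbone, pmoe_layer, patch_size=2):
        super().__init__()
        self.backbone = backbone          # assumed to output (batch, C, H, W) feature maps
        self.patch_size = patch_size
        self.pmoe = pmoe_layer            # e.g., the PatchLevelMoE sketch above

    def forward(self, x):
        fmap = self.backbone(x)                               # (B, C, H, W)
        B, C, H, W = fmap.shape
        p = self.patch_size
        # Cut the feature map into non-overlapping p x p patches and flatten each one.
        patches = fmap.unfold(2, p, p).unfold(3, p, p)        # (B, C, H/p, W/p, p, p)
        patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(B, -1, C * p * p)
        return self.pmoe(patches)                             # routed through the experts
```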



Table 2: Convergence time (sec) and training FLOPs of WRN and WRN-pMoE for different numbers of training samples (multiclass classification on CelebA).
Performance Comparison: Figure 10 shows the test accuracy for the binary classification problem on the attribute "smiling." WRN-pMoE requires less than one-fifth of the training samples needed by WRN to achieve 86% accuracy. Figure 11 shows the performance when the training data contain spurious correlations, with hair color as the spurious attribute. Specifically, 95% of the training images with the attribute "smiling" also have the attribute "black hair," while 95% of the training images with the attribute "not-smiling" have the attribute "blond hair." The models may learn the hair-color attribute rather than "smiling" due to the spurious correlation, and thus the test accuracies in Figure 11 are lower than those in Figure 10. Nevertheless, WRN-pMoE outperforms WRN and reduces the sample complexity needed to achieve the same accuracy.
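The spuriously correlated training split described above can be constructed by attribute filtering along the following lines; the attribute names follow the CelebA annotation file, while the sampling logic is an illustrative assumption.

```python
import numpy as np

def spurious_split(attrs, num_samples, corr=0.95, seed=0):
    """Select training indices so that `corr` of 'Smiling' images are 'Black_Hair'
    and `corr` of non-smiling images are 'Blond_Hair'.

    attrs: dict of boolean numpy arrays keyed by CelebA attribute name.
    """
    rng = np.random.default_rng(seed)
    smiling, black, blond = attrs["Smiling"], attrs["Black_Hair"], attrs["Blond_Hair"]
    half = num_samples // 2

    def pick(mask, n):
        idx = np.flatnonzero(mask)
        return rng.choice(idx, size=n, replace=False)

    pos = np.concatenate([pick(smiling & black, int(corr * half)),
                          pick(smiling & blond, half - int(corr * half))])
    neg = np.concatenate([pick(~smiling & blond, int(corr * half)),
                          pick(~smiling & black, half - int(corr * half))])
    return np.concatenate([pos, neg])
```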
Figure 12 shows the test accuracy of multiclass classification (four classes with class attributes: "Not smiling, Eyeglass," "Smiling, Eyeglass," "Smiling, No eyeglass," and "Not smiling, No eyeglass") on CelebA. The results are consistent with the binary classification results. Furthermore, Table 2 empirically verifies the computational efficiency of WRN-pMoE over WRN for multiclass classification on CelebA (an NVIDIA RTX 4500 GPU was used to run the experiments). Even with the same number of training samples, WRN-pMoE is still more computationally efficient than WRN, because WRN-pMoE requires fewer iterations to converge and has a lower per-iteration cost.
6 Conclusion
MoE reduces computational costs significantly without hurting generalization performance in various empirical studies, but the theoretical explanation has remained mostly elusive. This paper provides the first theoretical analysis of patch-level MoE and quantitatively proves its savings in sample complexity and model size compared with the single-expert counterpart. Although centered on a classification task using a mixture of two-layer CNNs, our theoretical insights are verified empirically on deep architectures and multiple datasets. Future work includes analyzing other MoE architectures, such as MoE in the Vision Transformer (ViT), and connecting MoE with other sparsification methods to further reduce computation.
Acknowledgements
This work was supported by AFOSR FA9550-20-1-0122, NSF 1932196 and the Rensselaer-IBM AI Research Collaboration (http://airc.rpi.edu), part of the IBM AI Horizons Network (http://ibm.biz/AIHorizons). We thank Yihua Zhang at Michigan State University for the help in experiments with CelebA dataset. We thank all anonymous reviewers.
References
- Ahmed et al. (2016) Ahmed, K., Baig, M. H., and Torresani, L. Network of experts for large-scale image categorization. In European Conference on Computer Vision, pp. 516–532. Springer, 2016.
- Allen-Zhu & Li (2019) Allen-Zhu, Z. and Li, Y. What can resnet learn efficiently, going beyond kernels? Advances in Neural Information Processing Systems, 32, 2019.
- Allen-Zhu & Li (2020a) Allen-Zhu, Z. and Li, Y. Backward feature correction: How deep learning performs deep learning. arXiv preprint arXiv:2001.04413, 2020a.
- Allen-Zhu & Li (2020b) Allen-Zhu, Z. and Li, Y. Towards understanding ensemble, knowledge distillation and self-distillation in deep learning. arXiv preprint arXiv:2012.09816, 2020b.
- Allen-Zhu & Li (2022) Allen-Zhu, Z. and Li, Y. Feature purification: How adversarial training performs robust deep learning. In 2021 IEEE 62nd Annual Symposium on Foundations of Computer Science (FOCS), pp. 977–988. IEEE, 2022.
- Allen-Zhu et al. (2019a) Allen-Zhu, Z., Li, Y., and Liang, Y. Learning and generalization in overparameterized neural networks, going beyond two layers. Advances in neural information processing systems, 32, 2019a.
- Allen-Zhu et al. (2019b) Allen-Zhu, Z., Li, Y., and Song, Z. A convergence theory for deep learning via over-parameterization. In International Conference on Machine Learning, pp. 242–252. PMLR, 2019b.
- Arora et al. (2019) Arora, S., Du, S., Hu, W., Li, Z., and Wang, R. Fine-grained analysis of optimization and generalization for overparameterized two-layer neural networks. In International Conference on Machine Learning, pp. 322–332. PMLR, 2019.
- Bai & Lee (2019) Bai, Y. and Lee, J. D. Beyond linearization: On quadratic and higher-order approximation of wide neural networks. In International Conference on Learning Representations, 2019.
- Bengio et al. (2013) Bengio, Y., Léonard, N., and Courville, A. Estimating or propagating gradients through stochastic neurons for conditional computation. arXiv preprint arXiv:1308.3432, 2013.
- Brutzkus & Globerson (2021) Brutzkus, A. and Globerson, A. An optimization and generalization analysis for max-pooling networks. In Uncertainty in Artificial Intelligence, pp. 1650–1660. PMLR, 2021.
- Brutzkus et al. (2018) Brutzkus, A., Globerson, A., Malach, E., and Shalev-Shwartz, S. SGD learns over-parameterized networks that provably generalize on linearly separable data. In International Conference on Learning Representations, 2018.
- Chen et al. (1999) Chen, K., Xu, L., and Chi, H. Improved learning algorithms for mixture of experts in multiclass classification. Neural networks, 12(9):1229–1252, 1999.
- Chen et al. (2022) Chen, Z., Deng, Y., Wu, Y., Gu, Q., and Li, Y. Towards understanding mixture of experts in deep learning. arXiv preprint arXiv:2208.02813, 2022.
- Chizat et al. (2019) Chizat, L., Oyallon, E., and Bach, F. On lazy training in differentiable programming. Advances in Neural Information Processing Systems, 32, 2019.
- Collobert et al. (2001) Collobert, R., Bengio, S., and Bengio, Y. A parallel mixture of SVMs for very large scale problems. Advances in Neural Information Processing Systems, 14, 2001.
- Collobert et al. (2003) Collobert, R., Bengio, Y., and Bengio, S. Scaling large learning problems with hard parallel mixtures. International Journal of pattern recognition and artificial intelligence, 17(03):349–365, 2003.
- Daniely & Malach (2020) Daniely, A. and Malach, E. Learning parities with neural networks. Advances in Neural Information Processing Systems, 33:20356–20365, 2020.
- Du et al. (2019) Du, S., Lee, J., Li, H., Wang, L., and Zhai, X. Gradient descent finds global minima of deep neural networks. In International conference on machine learning, pp. 1675–1685. PMLR, 2019.
- Eigen et al. (2013) Eigen, D., Ranzato, M., and Sutskever, I. Learning factored representations in a deep mixture of experts. arXiv preprint arXiv:1312.4314, 2013.
- Fedus et al. (2022) Fedus, W., Zoph, B., and Shazeer, N. Switch transformers: Scaling to trillion parameter models with simple and efficient sparsity. Journal of Machine Learning Research, 23(120):1–39, 2022.
- Fu et al. (2020) Fu, H., Chi, Y., and Liang, Y. Guaranteed recovery of one-hidden-layer neural networks via cross entropy. IEEE transactions on signal processing, 68:3225–3235, 2020.
- Ghorbani et al. (2019) Ghorbani, B., Mei, S., Misiakiewicz, T., and Montanari, A. Limitations of lazy training of two-layers neural network. Advances in Neural Information Processing Systems, 32, 2019.
- Ghorbani et al. (2020) Ghorbani, B., Mei, S., Misiakiewicz, T., and Montanari, A. When do neural networks outperform kernel methods? Advances in Neural Information Processing Systems, 33:14820–14830, 2020.
- Ghorbani et al. (2021) Ghorbani, B., Mei, S., Misiakiewicz, T., and Montanari, A. Linearized two-layers neural networks in high dimension. The Annals of Statistics, 49(2):1029–1054, 2021.
- Gross et al. (2017) Gross, S., Ranzato, M., and Szlam, A. Hard mixtures of experts for large scale weakly supervised vision. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 6865–6873, 2017.
- Jacobs et al. (1991) Jacobs, R. A., Jordan, M. I., Nowlan, S. J., and Hinton, G. E. Adaptive mixtures of local experts. Neural computation, 3(1):79–87, 1991.
- Jacot et al. (2018) Jacot, A., Gabriel, F., and Hongler, C. Neural tangent kernel: Convergence and generalization in neural networks. Advances in neural information processing systems, 31, 2018.
- Ji & Telgarsky (2019) Ji, Z. and Telgarsky, M. Polylogarithmic width suffices for gradient descent to achieve arbitrarily small test error with shallow relu networks. In International Conference on Learning Representations, 2019.
- Jordan & Jacobs (1994) Jordan, M. I. and Jacobs, R. A. Hierarchical mixtures of experts and the em algorithm. Neural computation, 6(2):181–214, 1994.
- Karp et al. (2021) Karp, S., Winston, E., Li, Y., and Singh, A. Local signal adaptivity: Provable feature learning in neural networks beyond kernels. Advances in Neural Information Processing Systems, 34:24883–24897, 2021.
- Krizhevsky (2009) Krizhevsky, A. Learning multiple layers of features from tiny images. Technical report, Canadian Institute For Advanced Research, 2009.
- LeCun et al. (2010) LeCun, Y., Cortes, C., and Burges, C. MNIST handwritten digit database. ATT Labs [Online]. Available: http://yann.lecun.com/exdb/mnist, 2010.
- Lee et al. (2019) Lee, J., Xiao, L., Schoenholz, S., Bahri, Y., Novak, R., Sohl-Dickstein, J., and Pennington, J. Wide neural networks of any depth evolve as linear models under gradient descent. Advances in neural information processing systems, 32, 2019.
- Lepikhin et al. (2020) Lepikhin, D., Lee, H., Xu, Y., Chen, D., Firat, O., Huang, Y., Krikun, M., Shazeer, N., and Chen, Z. Gshard: Scaling giant models with conditional computation and automatic sharding. In International Conference on Learning Representations, 2020.
- Lewis et al. (2021) Lewis, M., Bhosale, S., Dettmers, T., Goyal, N., and Zettlemoyer, L. Base layers: Simplifying training of large, sparse models. In International Conference on Machine Learning, pp. 6265–6274. PMLR, 2021.
- Li et al. (2022a) Li, H., Wang, M., Liu, S., Chen, P.-Y., and Xiong, J. Generalization guarantee of training graph convolutional networks with graph topology sampling. In International Conference on Machine Learning, pp. 13014–13051. PMLR, 2022a.
- Li et al. (2022b) Li, H., Zhang, S., and Wang, M. Learning and generalization of one-hidden-layer neural networks, going beyond standard gaussian data. In 2022 56th Annual Conference on Information Sciences and Systems (CISS), pp. 37–42. IEEE, 2022b.
- Li et al. (2023) Li, H., Wang, M., Liu, S., and Chen, P.-Y. A theoretical understanding of shallow vision transformers: Learning, generalization, and sample complexity. In The Eleventh International Conference on Learning Representations, 2023. URL https://openreview.net/forum?id=jClGv3Qjhb.
- Li & Liang (2018) Li, Y. and Liang, Y. Learning overparameterized neural networks via stochastic gradient descent on structured data. Advances in neural information processing systems, 31, 2018.
- Li et al. (2020) Li, Y., Ma, T., and Zhang, H. R. Learning over-parametrized two-layer neural networks beyond NTK. In Conference on learning theory, pp. 2613–2682. PMLR, 2020.
- Liu et al. (2015) Liu, Z., Luo, P., Wang, X., and Tang, X. Deep learning face attributes in the wild. In Proceedings of the IEEE international conference on computer vision, pp. 3730–3738, 2015.
- Malach et al. (2021) Malach, E., Kamath, P., Abbe, E., and Srebro, N. Quantifying the benefit of using differentiable learning over tangent kernels. In International Conference on Machine Learning, pp. 7379–7389. PMLR, 2021.
- Ramachandran & Le (2018) Ramachandran, P. and Le, Q. V. Diversity and depth in per-example routing models. In International Conference on Learning Representations, 2018.
- Rasmussen & Ghahramani (2001) Rasmussen, C. and Ghahramani, Z. Infinite mixtures of gaussian process experts. Advances in neural information processing systems, 14, 2001.
- Riquelme et al. (2021) Riquelme, C., Puigcerver, J., Mustafa, B., Neumann, M., Jenatton, R., Susano Pinto, A., Keysers, D., and Houlsby, N. Scaling vision with sparse mixture of experts. Advances in Neural Information Processing Systems, 34:8583–8595, 2021.
- Shalev-Shwartz et al. (2020) Shalev-Shwartz, S. et al. Computational separation between convolutional and fully-connected networks. In International Conference on Learning Representations, 2020.
- Shazeer et al. (2017) Shazeer, N., Mirhoseini, A., Maziarz, K., Davis, A., Le, Q. V., Hinton, G. E., and Dean, J. Outrageously large neural networks: The sparsely-gated mixture-of-experts layer. In International Conference on Learning Representations, 2017.
- Shi et al. (2021) Shi, Z., Wei, J., and Liang, Y. A theoretical analysis on feature learning in neural networks: Emergence from inputs and advantage over fixed features. In International Conference on Learning Representations, 2021.
- Tresp (2000) Tresp, V. Mixtures of gaussian processes. In Leen, T., Dietterich, T., and Tresp, V. (eds.), Advances in Neural Information Processing Systems, volume 13. MIT Press, 2000. URL https://proceedings.neurips.cc/paper/2000/file/9fdb62f932adf55af2c0e09e55861964-Paper.pdf.
- Vaswani et al. (2017) Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, Ł., and Polosukhin, I. Attention is all you need. Advances in neural information processing systems, 30, 2017.
- Yang et al. (2019) Yang, B., Bender, G., Le, Q. V., and Ngiam, J. Condconv: Conditionally parameterized convolutions for efficient inference. Advances in Neural Information Processing Systems, 32, 2019.
- Yehudai & Shamir (2019) Yehudai, G. and Shamir, O. On the power and limitations of random features for understanding neural networks. Advances in Neural Information Processing Systems, 32, 2019.
- Yu et al. (2019) Yu, B., Zhang, J., and Zhu, Z. On the learning dynamics of two-layer nonlinear convolutional neural networks. arXiv preprint arXiv:1905.10157, 2019.
- Zagoruyko & Komodakis (2016) Zagoruyko, S. and Komodakis, N. Wide residual networks. arXiv preprint arXiv:1605.07146, 2016.
- Zhang et al. (2020a) Zhang, S., Wang, M., Liu, S., Chen, P.-Y., and Xiong, J. Fast learning of graph neural networks with guaranteed generalizability: one-hidden-layer case. In International Conference on Machine Learning, pp. 11268–11277. PMLR, 2020a.
- Zhang et al. (2020b) Zhang, S., Wang, M., Xiong, J., Liu, S., and Chen, P.-Y. Improved linear convergence of training CNNs with generalizability guarantees: A one-hidden-layer case. IEEE Transactions on Neural Networks and Learning Systems, 32(6):2622–2635, 2020b.
- Zhong et al. (2017a) Zhong, K., Song, Z., and Dhillon, I. S. Learning non-overlapping convolutional neural networks with multiple kernels. arXiv preprint arXiv:1711.03440, 2017a.
- Zhong et al. (2017b) Zhong, K., Song, Z., Jain, P., Bartlett, P. L., and Dhillon, I. S. Recovery guarantees for one-hidden-layer neural networks. In International conference on machine learning, pp. 4140–4149. PMLR, 2017b.
- Zhou et al. (2022) Zhou, Y., Lei, T., Liu, H., Du, N., Huang, Y., Zhao, V. Y., Dai, A. M., Chen, Z., Le, Q. V., and Laudon, J. Mixture-of-experts with expert choice routing. In Oh, A. H., Agarwal, A., Belgrave, D., and Cho, K. (eds.), Advances in Neural Information Processing Systems, 2022. URL https://openreview.net/forum?id=jdJo1HIVinI.
- Zou et al. (2020) Zou, D., Cao, Y., Zhou, D., and Gu, Q. Gradient descent optimizes over-parameterized deep relu networks. Machine learning, 109(3):467–492, 2020.
Appendix A Experiments on CIFAR-10 Datasets
We also compare WRN and WRN-pMoE on CIFAR-10-based datasets. To better reflect local features, in addition to the original CIFAR-10, we adopt the techniques of Karp et al. (2021) to generate two datasets based on CIFAR-10 (construction sketches for both follow the list):
1. CIFAR-10 with ImageNet noise. Each CIFAR-10 image is down-sampled and placed at a random location of a background image chosen from the ImageNet "Plants" synset. Figure 14(c) shows an example image of this dataset.
2. CIFAR-Vehicles. Each vehicle image of CIFAR-10 is down-sampled and placed in a random quadrant of an image, where the other quadrants are randomly filled with down-sampled animal images from CIFAR-10. See Figure 14(b) for a sample image.
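The two constructions above can be sketched as follows; the down-sampled size, canvas handling, and quadrant placement are assumptions chosen to match the verbal description rather than the exact preprocessing code.

```python
import numpy as np

def place_on_background(cifar_img, background, small=16, seed=0):
    """CIFAR-10 with ImageNet noise: shrink a CIFAR image and paste it at a
    random location of an ImageNet 'Plants' background (both HxWx3 uint8 arrays)."""
    rng = np.random.default_rng(seed)
    # Naive nearest-neighbor down-sampling of the 32x32 CIFAR image to small x small.
    idx = np.arange(small) * cifar_img.shape[0] // small
    patch = cifar_img[idx][:, idx]
    canvas = background.copy()
    r = rng.integers(canvas.shape[0] - small)
    c = rng.integers(canvas.shape[1] - small)
    canvas[r:r + small, c:c + small] = patch
    return canvas

def cifar_vehicles_sample(vehicle_img, animal_imgs, small=16, seed=0):
    """CIFAR-Vehicles: put the vehicle in one random quadrant, animals in the others."""
    rng = np.random.default_rng(seed)
    idx = np.arange(small) * 32 // small
    quads = [animal_imgs[i][idx][:, idx] for i in rng.choice(len(animal_imgs), 3, replace=False)]
    quads.insert(rng.integers(4), vehicle_img[idx][:, idx])
    top = np.concatenate(quads[:2], axis=1)
    bottom = np.concatenate(quads[2:], axis=1)
    return np.concatenate([top, bottom], axis=0)
```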
The last convolutional layer of WRN receives a feature map, which WRN-pMoE divides into patches. The MoE layer of WRN-pMoE contains multiple experts, with each expert receiving a fixed number of patches.




Figures 14, 15, and 16 compare the test accuracy of WRN and WRN-pMoE for the ten-class classification problem on CIFAR-10 and on CIFAR-10 with ImageNet noise, and for the four-class classification problem on CIFAR-Vehicles, respectively. WRN-pMoE outperforms WRN on all these datasets, indicating reduced sample complexity due to the pMoE layer. The performance gap is more significant on the two constructed datasets than on the original CIFAR-10 dataset, because the constructed datasets contain local features, which the pMoE layer has a clear advantage in learning effectively.
Appendix B Preliminaries
The loss function for SGD at iteration with minibatch :
(10) |
For the router-training in separate-training pMoE, the loss function of SGD at iteration with minibatch :
(11) |
Notations:
- 1.
- 2.
- 3.
We denote, such that the expert initialization, .
The training algorithms for separate-training and joint-training pMoE are given in Algorithm 1 and Algorithm 2, respectively:
Appendix C Proof Sketch
The proof of the generalization guarantees for pMoE (i.e., Theorems 4.2 and 4.5) can be outlined as follows (the proof for the single CNN follows a simpler version of the outline below):
Step 1. (Feature learning in the router) For separate-training pMoE, we first show that the batch gradient of the router loss with respect to the gating kernels w_1 and w_2 has a large component along the class-discriminative patterns o_1 and o_2, respectively. Then, by selecting a suitable learning rate (which guarantees a loss reduction per step) and training for sufficiently many iterations, we can show that w_1 and w_2 are sufficiently aligned with o_1 and o_2, respectively, to guarantee the selection of these class-discriminative patterns among the top-l patches when l > l* (see Lemma D.4 for the exact statement).
Step 2. (Coupling the experts to pseudo experts) When the experts of pMoE are sufficiently overparameterized, w.h.p. the experts can be coupled to a smooth pseudo network of experts (the pseudo network is defined as the network whose activation pattern does not change from initialization, i.e., the sign of the pre-activation output of the hidden nodes does not change from its sign at initialization; see (Li & Liang, 2018) for details): for every sample drawn from the distribution, the activation pattern of a large fraction of hidden nodes in each expert does not change from initialization throughout the analyzed iterations (see Lemma G.1 or H.1 for the exact statement). This allows us to couple that fraction of hidden nodes of each expert to the corresponding pseudo experts for those iterations.
Step 3. (Large error implies large gradient) We can now analyze the pseudo network of experts corresponding to the separate-training pMoE and show that, at any iteration, the magnitude of the expected gradient for any expert of the pseudo network is lower-bounded by a quantity that characterizes the class-conditional expected error of the class assigned to that expert (see Lemma G.3 for the exact statement). Similarly, for joint-training pMoE we lower-bound the magnitude of the expected gradient, but this time the bound is characterized by the maximum of the class-conditional expected errors over the samples for which the given expert receives the class-discriminative patches from the router (see Lemma H.3 for the exact statement).
Step 4. (Convergence) For separate-training pMoE, by selecting a sufficiently large batch size at each iteration, we can couple the empirical batch gradient of each expert of the true network to the expected gradient of the corresponding expert of the pseudo network. Because the pseudo network is smooth, we can show that SGD decreases the expected loss of the true network at each iteration (see Lemma G.4 for the exact statement). Similarly, for joint-training pMoE, with an appropriate batch size and learning rate, SGD decreases the expected loss of the true network at each iteration (see Lemma H.4 for the exact statement). As the loss of the true network is bounded at initialization, the network eventually converges.
Step 5. (Generalization) We show that, to ensure at most the target generalization error after training, it suffices that the class-conditional expected errors of the class with label +1 and the class with label -1 are both sufficiently small. Since the router in separate-training pMoE dispatches the class-discriminative patches of all samples labeled +1 to expert 1 and those of all samples labeled -1 to expert 2 from the beginning of expert training, the bound on each expert's expected error translates into bounds on the class-conditional errors. For joint-training pMoE, since we assume that the router dispatches all class-discriminative patches of a class to a particular expert before the convergence of the model, and that the gating value of the class-discriminative patch is the largest among all the patches sent to that expert, the same translation holds. Hence, for separate-training pMoE, with a sufficiently large number of samples and hidden nodes and sufficiently many iterations, we can guarantee that the generalization error is below the target level (see Theorem F.3 for the exact statement); similarly for joint-training pMoE (see Theorem F.5 for the exact statement).
Appendix D Proof of the Lemma 4.1
Definition D.1.
(-closer class-irrelevant patterns) For any , a class-irrelevant pattern is -closer to than , if for any . Similarly, a class-irrelevant pattern is -closer to than if .
Definition D.2.
(Set of -closer class-irrelevant patterns, ) For any , define the set of -closer class-irrelevant patterns, denoted as such that: .
Definition D.3.
(Threshold, )
Define the threshold such that:
Lemma D.4.
Appendix E Lemmas Used to Prove the Lemma 4.1
We denote,
where for all .
Lemma E.1.
At any iteration of the Step-2 of Algorithm 1,
, and
Proof.
As, ,
and
Therefore,
where the last equality comes from the fact that class-irrelevant patterns are distributed identically in both classes. Using a similar line of argument, we can show the second claim of the lemma. ∎
Lemma E.2.
With probability (i.e., w.h.p.) over the random initialization of the gating kernels defined in (7), ;
Proof.
Let us denote the -th element of the vector as where .
Then according to the random initialization of and using a Gaussian tail-bound (i.e., for ): .
Let us denote the event .
Therefore, .
Now, conditioned on the event .
Therefore,
∎
Lemma E.3.
Proof.
Let, at -th iteration of Step-2 of Algorithm 1, for all
Also let us denote, for all
Therefore, after -th iteration of SGD and using Lemma E.1,
Similarly, .
Now, . Hence, w.h.p. over a randomly sampled batch of size , using Hoeffding’s concentration,
.
Now,
On the other hand, ,
From Lemma E.2, w.h.p. over the random initialization: .
Therefore, selecting and , we need iterations to achieve ,
A similar line of argument can be made to show that, with batch size and learning rate , after iterations, , .
∎
Appendix F Proofs of the Theorem 4.2, 4.3 and 4.5
Definition F.1.
At any iteration of the minibatch SGD,
-
1.
Define the value function, . It is easy to show that for any , . The function captures the prediction error, i.e., a larger indicates a larger prediction error.
-
2.
Define, the class-conditional expected value function, and . Here, captures the expected error for the class with label and captures the expected error for the class with label .
Definition F.2.
At any iteration of the minibatch SGD,
-
1.
For any sample , we define the reduction of loss at the -th iteration of SGD as,
where, is the single-sample loss function.
-
2.
Define the expected reduction of loss at the -th iteration of SGD as,
Theorem F.3.
(Full version of Theorem 4.2) For every and , for every with at least training samples, after performing minibatch SGD with the batch size and the learning rate for iterations, it holds w.h.p. that
Proof.
First we will show that for any , if , then .
Now for any and , if , i.e., the prediction is correct.
Now if , then using Markov’s inequality which implies for any , .
Similarly, if , for any , .
Therefore, for any , if , then .
Now, if then , which implies after a proper number of iterations if then .
Let, . Then by using Lemma G.4 for every , with and , at least for we have,
(12) |
Now, as with , and . Therefore, w.h.p. which implies . Now as , (12) can happen at most iterations. Now as , we need iterations to ensure that .
On the other hand, to ensure (12) hold for iterations, we need,
which implies we need . ∎
Now, for any and , let us denote the index of the class-discriminative patterns i.e., and as and , respectively.
Definition F.4.
At any iteration of minibatch SGD of the joint-training pMoE (i.e., Step-2 of Algorithm 2),
-
1.
For any and the expert , define the event that in Top- as, . Similarly, for any define the event that in Top- as, .
-
2.
For any expert , define the probability of the event that in Top- as, and the probability of the event that in Top- as,
-
3.
For any expert define, and where and denote the gating value for the class-discriminative patterns and conditioned on and , respectively.
Theorem F.5.
Proof.
From the argument of the proof of Theorem F.3, we know that for any , if , then where and
Now, we will consider the case when where is defined in Assumption 4.4.
Now, if the expert satisfies Assumption 4.4 for , then and for any . Therefore, .
Similarly, if the expert satisfies Assumption 4.4 for , then .
Now for any expert , let us define
Now, if , then and .
This implies, and .
Therefore, and which implies and .
In that case, .
Therefore, by taking , using the results of Lemma H.4, and following the same procedure as in Theorem F.3, we can complete the proof.
∎
Theorem F.6.
(Full version of the Theorem 4.3) For every , for every with at least training samples, after performing minibatch SGD with the batch size and the learning rate for iterations, it holds w.h.p. that
Appendix G Lemmas Used to Prove the Theorem 4.2
For any iteration of the Step-3 of Algorithm 1, recall the loss function for a single-sample generated by the distribution , . The gradient of the loss for a single sample with respect to the hidden nodes of the experts:
(13) |
We define the corresponding pseudo-gradient as:
(14) |
Therefore, the expected pseudo-gradient:
Here,
Lemma G.1.
Proof.
Recall the gradient of the loss for single-sample w.r.t. the hidden node of the expert :
and the corresponding pseudo-gradient:
Now, . Hence, using the concentration bound of Gaussian random variable (i.e., for ) and as hides factor we get:
Now as and , w.h.p. so as the mini-batch gradient, .
Now, from the update rule of the Step-3 of Algorithm 1,
Therefore, using the property of Telescoping series,
Therefore, where we denote by
Now, for every consider the set
Now, for every ,
Which implies for every , and ,
Therefore, for every , and ,
and hence,
Now, we will find the lower bound of
As,
Hence,
Now as ,
Therefore,
∎
Using the following two lemmas, we show that when is large, the expected pseudo-gradient of the loss function w.r.t. the hidden nodes of expert 1 is large; a similar result holds for expert 2 when is large.
We prove the first of these two lemmas for a fixed set which does not depend on the random initialization of the hidden nodes of the experts (i.e., on ). In the second of these two lemmas, we remove the dependency on the fixed set by means of a sampling trick introduced in (Li & Liang, 2018), taking a union bound over an epsilon-net on the set .
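The union-bound structure behind this sampling trick can be summarized generically (the set, net, and bad event below are placeholders, not the paper’s exact objects): if a “bad” event $B(S)$ occurs with probability at most $p$ for each fixed set $S$ in a finite $\epsilon$-net $\mathcal{N}$, then
\[ \Pr\Big[\exists\, S \in \mathcal{N}: B(S)\Big] \;\le\; |\mathcal{N}| \cdot p, \]
and the bound extends from the net to all admissible sets via the Lipschitz-type continuity established in the lemma.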
Lemma G.2.
For any possible fixed set (that does not depend on ) such that for and for we have for every :
Proof.
WLOG, let us assume . Now,
Then,
Now, let us decompose , where
Then,
where,
and,
Note that and are both convex functions.
Now for , using Lemma D.4, we can express as follows:
Now, for any class-irrelevant pattern set where , let us define such that . Also, let us define the set,
Now let us define the event
Now, as , for every
Now, . Hence,
Therefore, .
Picking, gives, .
On the other hand, . Therefore,
Now, s.t. ,
where the last inequality comes from the bound of the diameter of the pattern sets and the fact that for any .
Therefore, using Markov’s inequality s.t.
Now,
s.t.
Now, conditioned on the event , for a fixed and is the only random variable,
s.t.
which is a linear function of with probability at least and, which is a linear function of with probability .
Now, let us define and as the sets of sub-gradients at the point for and , respectively, such that , , and .
Then, using the above argument, conditioned on the event , .
On the other hand, .
Now using Lemma J.1, conditioned on the event , .
Now, for , conditioned on , the density , which implies that,
(15)
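For reference, the Gaussian anti-concentration fact typically used in a density argument of this kind (stated generically, with illustrative symbols) is: for $g \sim \mathcal{N}(0, \sigma^2)$ and any $t > 0$,
\[ \Pr\big[\,|g| \le t\,\big] \;\le\; \frac{2t}{\sqrt{2\pi}\,\sigma}, \]
since the Gaussian density is bounded by $1/(\sqrt{2\pi}\,\sigma)$.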
Now, as does not depend on , .
Now, using a concentration bound for a Gaussian random variable (i.e., ),
(16)
Now, taking in (16) we get,
(17)
On the other hand, picking and plugging it into (15) gives,
(18)
Lemma G.3.
Let for and for . Then, for every , for , for every possible set (that depends on ), there exist at least fraction of of the expert such that for every ,
Proof.
Let us pick samples to form with many samples from and many samples from . Let us denote the subset of samples with as and the subset of samples with as . Therefore, . Let us denote the corresponding value function of the -th sample of S as . Since each , using Hoeffding’s inequality we know that w.h.p. :
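For reference, the generic form of Hoeffding’s inequality used here (the paper instantiates it with its own bounded quantities) is: for i.i.d. random variables $X_1, \dots, X_N$ taking values in $[a, b]$,
\[ \Pr\left[\,\left|\frac{1}{N}\sum_{i=1}^{N} X_i - \mathbb{E}[X_1]\right| \ge t\,\right] \;\le\; 2\exp\!\left(-\frac{2N t^2}{(b-a)^2}\right). \]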
This implies that, as long as , we will have that,
Now, the average pseudo-gradient over the set S,
where,
Now as ,
Now for a fixed set as long as , for every using Lemma G.2,
Hence, for a fixed set , the probability that there are less than fraction of such that is is no more than where, .
Moreover, for every , for two different , such that , , since w.h.p. ,
which implies that we can take an -net with .
Thus, the probability that there exists such that there are no more than fraction of with is no more than,
.
Hence, for with , w.h.p. for every possible choice of , there are at least fraction of such that,
Now, we consider the difference between the sample gradient and the expected gradient. Since , by Hoeffding’s inequality, we know that for every :
This implies that as long as and hence for , such also have:
∎
Lemma G.4.
Let us define , where for and for ; . Then, by selecting the learning rate and the batch size , at each iteration of Step-3 of Algorithm 1 such that , w.h.p. we can ensure that for every ,
Proof.
For every , from Lemma G.3, for at least fraction of of expert :
Now w.h.p., . Therefore, w.h.p. over a randomly sampled batch from at iteration denoted as of size :
This implies, by selecting batch-size of , for these fraction of of expert we can ensure that:
Now using Lemma G.1, for a fixed , by selecting we have fraction of of the expert :
Therefore, at least fraction of of the expert :
Recall our definition of loss-function for SGD at iteration with mini-batch , and the corresponding batch-gradient at iteration , . Therefore, there are at least fraction of of the expert :
Now for any , according to Lemma G.1, w.h.p. there are at least fraction of of the expert such that . Let us denote the set of these ’s of as . Therefore, on the set , the loss function is -smooth and -Lipschitz smooth.
On the other hand, the update rule of SGD at the iteration is,
Therefore, using Lemma J.2,
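A generic smoothness-based descent bound of the type provided by such lemmas (stated with illustrative symbols, not the paper’s) is: if $L$ is $\beta$-smooth, then for an update $w^{+} = w - \eta g$,
\[ L(w^{+}) \;\le\; L(w) - \eta\,\langle \nabla L(w),\, g\rangle + \frac{\beta \eta^2}{2}\,\|g\|^2, \]
which, once the stochastic gradient $g$ is shown to be close to $\nabla L(w)$, yields an expected per-iteration decrease of the loss.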
Let us denote the event,
Then, (i.e., w.h.p.) and hence
Also, let us define the event,
Then, and hence
Now, the expected gradient at iteration ,
Therefore, conditioned on ,
which implies,
Again, conditioned on ,
Now, w.h.p.
Therefore,
Therefore,
Now, selecting , , and hence for , we get,
∎
Appendix H Lemmas Used to Prove Theorem 4.5
In joint-training pMoE, i.e., for any iteration of Step-2 of Algorithm 2, the gradient of the loss for a single sample with respect to the hidden nodes of the experts is:
(19)
and the corresponding pseudo-gradient:
(20)
and the expected pseudo-gradient:
with,
Lemma H.1.
Proof.
Lemma H.2.
For the expert and any possible fixed set (that does not depend on ) such that , we have:
Proof.
We know that,
Therefore,
Now, decomposing with we get,
where,
and
Now, as and are both convex functions, using the same procedure as in Lemma G.2 we can complete the proof. ∎
Lemma H.3.
Let . Then, for every , for , for every possible set (that depends on ), there exist at least fraction of of the expert such that,
Proof.
Let us pick samples to form with many samples from such that many of them satisfy the event , and many samples from such that many of them satisfy the event . We denote the subset of S satisfying the event by and the subset of S satisfying the event by . Therefore, and . Now, w.h.p. :
This implies that, as long as , we will have that,
Now using the same procedure as in Lemma G.3 and using Lemma H.2 we can show that, for a fixed set as long as , the probability that there are less than fraction of such that is is no more than where, .
Now, for every , for two different , such that , , w.h.p.,
Therefore, taking an -net with , we can show that the probability that there exists such that there are no more than fraction of with is no more than,
.
Hence, for with , w.h.p. for every possible choice of , there are at least fraction of such that,
Now as , using the same procedure as in Lemma G.3 we can complete the proof which gives us . ∎
Lemma H.4.
Let us define , where for all ; . Then, by selecting the learning rate and the batch size , at each iteration of Step-2 of Algorithm 2 such that , w.h.p. we can ensure that,
Proof.
As w.h.p. , for a randomly sampled batch of size , by selecting in Lemma H.1 and using the same procedure as in Lemma G.4, we can show that for at least fraction of of expert :
Now, for any , from Lemma H.1 we know that for at least fraction of of any expert , the loss function is -Lipschitz smooth and also -smooth.
Therefore, using the same procedure as in Lemma G.4, we can complete the proof.
∎
Appendix I Lemmas Used to Prove Theorem 4.3
For the single CNN model, as all the patches of an input are sent to the model (i.e., there is no router), the gradient of the single-sample loss function w.r.t. hidden node is,
(21)
the corresponding pseudo-gradient,
and the expected pseudo-gradient,
where,
Lemma I.1.
W.h.p. over the random initialization, for every and for every , for every iteration of the minibatch SGD, we have that for at least fraction of :
and
Proof.
Recall, and .
Lemma I.2.
For any possible fixed set (that does not depend on ) such that , we have:
Proof.
We know that,
Therefore,
Now, decomposing with we get,
where,
and
Now, as and are both convex functions, using the same procedure as in Lemma G.2 we can complete the proof. ∎
Lemma I.3.
Let . Then, for every , for , for every possible set (that depends on ), there exist at least fraction of such that,
Proof.
As in the proof of Lemma G.3, by picking samples from the distribution to form the set , with many samples from (denoting the subset by ) and many samples from (denoting the subset by ), we can show that w.h.p.,
This implies that, as long as we have,
Now, using Lemma I.2 and following a procedure similar to that of Lemma G.3, we can complete the proof. ∎
Lemma I.4.
With and , by selecting learning rate and batch-size , for iterations of SGD, w.h.p. we can ensure that,
Appendix J Auxiliary Lemmas
Lemma J.1.
(Li & Liang, 2018) Let and be convex functions. Let and be the sets of sub-gradients of and at , respectively, such that , , and . Then for any such that ,
Lemma J.2.
(Li & Liang, 2018) Suppose that, for any , the function is -Lipschitz smooth, and that there exists such that for all the function is also -smooth. Furthermore, let us assume that the function is both -Lipschitz smooth and -smooth. Define where such that . Then, for every such that with , we have:
Appendix K Proof of the Non-linear Separability of the Data-model
Lemma K.1.
As long as , the distribution is NOT linearly separable.
Proof.
We will prove the Lemma by contradiction.
Now, if the distribution, is linearly separable, then there exists a hyperplane with (here, represents the -th patch of the hyperplane for ) such that,
(22)
Now, as the class-discriminative patterns and can occur at any position , .
Now, , we can decompose as .
Then, , as .
Now,
Now,
Now,
Therefore, for , there is a contradiction with (22). ∎
Appendix L WRN and WRN-pMoE Architectures Implemented in the Experiments
[Figures: the WRN and WRN-pMoE architectures implemented in the experiments.]
Appendix M Extension to Multi-class Classification
Let us consider a -class classification problem where . Then, we have , where for the multi-class distribution .
The multi-class data model:
Now, according to the data model presented in Section 4.2, we have as the class-discriminative pattern set. For every such that , we define . We further define . Then,
The multi-class pMoE model:
The pMoE model for the multi-class case is given by,
(23)
For the multi-class case, we replace the logistic loss function with the softmax loss function (also known as the cross-entropy loss). For the training dataset , we solve the following empirical risk minimization problem:
(24)
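For reference, one standard form of the softmax (cross-entropy) loss referred to above, for a $K$-class output $f(x) \in \mathbb{R}^{K}$ and label $y \in \{1, \dots, K\}$, is:
\[ \ell\big(f(x), y\big) \;=\; -\log \frac{e^{f_y(x)}}{\sum_{c=1}^{K} e^{f_c(x)}}. \]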
M.1 The Multi-class Separate-training pMoE
Number of experts: For the multi-class separate-training pMoE, we take , i.e., the number of experts is equal to the number of classes.
Training algorithm:
Input: Training data , learning rates and , number of iterations and , batch-sizes and
Step-1: Initialize according to (7) and (8)
Step-2: (Pair-wise router training) We train the router, i.e., the gating-kernels , using the pair-wise training described below (see the code sketch after Step-3):
1. First, we separate the training set of samples into disjoint subsets according to the class labels.
2. Now, we prepare pairs of training sets (here WLOG we assume that is even).
3. Under each pair , we re-define the labels as and for the classes and , respectively, and train the gating-kernels and by minimizing (6) for iterations.
4. After the end of the pair-wise training for all the pairs , we obtain as the learned gating-kernels.
Step-3: (Expert training) Using the gating-kernels learned in Step-2 and the same procedure as in Step-3 of Algorithm 1, we train the experts.
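The following is a minimal Python sketch of Steps 2 and 3 above, assuming the inner optimization routines are supplied by the caller; the names `train_pair_router` and `train_experts` are illustrative placeholders, not routines from the paper's implementation.

from collections import defaultdict

def separate_training_pmoe(dataset, num_classes, train_pair_router, train_experts):
    """Sketch of the multi-class separate-training procedure (Step-2 and Step-3).

    `dataset` is an iterable of (input, label) pairs with labels in
    {0, ..., num_classes - 1}; `train_pair_router(pair_data)` returns the two
    gating kernels of a relabeled +1/-1 class pair, and `train_experts` trains
    the experts with the gating kernels frozen (as in Step-3 of Algorithm 1).
    """
    # Step-2.1: separate the training set into disjoint subsets by class label.
    by_class = defaultdict(list)
    for x, y in dataset:
        by_class[y].append(x)

    # Step-2.2: form num_classes / 2 disjoint pairs of classes
    # (num_classes is assumed even, as in the text).
    pairs = [(2 * j, 2 * j + 1) for j in range(num_classes // 2)]

    gating_kernels = {}
    for c_pos, c_neg in pairs:
        # Step-2.3: re-label the pair as +1 / -1 and train this pair's
        # two gating kernels by minimizing the binary routing objective.
        pair_data = [(x, +1) for x in by_class[c_pos]] + \
                    [(x, -1) for x in by_class[c_neg]]
        gating_kernels[c_pos], gating_kernels[c_neg] = train_pair_router(pair_data)

    # Step-3: train the experts with the learned gating kernels held fixed.
    experts = train_experts(dataset, gating_kernels)
    return gating_kernels, experts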
The multi-class counterpart of Lemma 4.1:
Now, using the same proof techniques as for Lemma 4.1 (i.e., following the same procedures as in Sections D and E), we can show that we need training samples to ensure,
The multi-class counterpart of Theorem 4.2:
We redefine the value-function for each class as,
(25)
Now, using similar techniques as in the proof of Theorem 4.2 (i.e., following the same procedure as in the proof of Theorem F.3 and Section G), we can show that for every , we need a number of hidden nodes , a batch-size for iterations (i.e., ) to ensure,
M.2 The Multi-class Joint-training pMoE
Training algorithm: Same as Algorithm 2, except that for the multi-class case the loss function is the softmax loss instead of the logistic loss.
The multi-class counterpart of Theorem 4.5:
Using the value-function defined in (25), and as long as Assumption 4.4 is satisfied for all the classes , following similar techniques as in the proof of Theorem 4.5 (i.e., following the same procedure as in the proof of Theorem F.5 and Section H), we can show that for every , we need a number of hidden nodes , a batch-size for iterations (i.e., ) to ensure,
Appendix N Details of the Results in Table 1
Complexity in forward pass. The computational complexity of a non-overlapping convolution operation by a filter of dimension on an input sample of patches (of the same dimension as the filter) is (Vaswani et al., 2017). Therefore, the complexity of the forward pass of a batch of size through a convolution layer of neurons is . Similarly, the forward pass complexity of the batch through the experts (with the same total number of neurons as in the convolution layer) of a pMoE layer is . The operations in a pMoE router include a convolution (with complexity ), a softmax operation (with complexity ), and a TOP- operation (with complexity when ). Therefore, the overall forward pass complexity of a pMoE router with experts is .
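As a rough numerical illustration of this decomposition, the sketch below compares forward-pass operation counts of a single convolution layer and a pMoE layer with the same total number of neurons. The closed-form expressions are generic reconstructions of the per-component costs described above (batch size `N`, `m` total hidden neurons, `n` patches per sample of dimension `d`, `k` experts each receiving `l` patches), not values copied from Table 1.

def cnn_forward_ops(N, m, n, d):
    # Each of the m neurons convolves all n patches of dimension d, per sample.
    return N * m * n * d

def pmoe_forward_ops(N, m, n, d, k, l):
    # Experts: m neurons in total, but each expert only processes l <= n patches.
    expert_ops = N * m * l * d
    # Router (one per expert): gating convolution over n patches, softmax over
    # the n gating values, and a Top-l selection (assumed linear in n).
    router_ops = N * k * (n * d + n + n)
    return expert_ops + router_ops

if __name__ == "__main__":
    # Hypothetical sizes, chosen only to illustrate the n -> l reduction.
    N, m, n, d, k, l = 128, 512, 64, 192, 8, 8
    print("CNN  forward ops ~", cnn_forward_ops(N, m, n, d))
    print("pMoE forward ops ~", pmoe_forward_ops(N, m, n, d, k, l))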
Complexity in backward pass. The gradient of the neurons in the convolution layer for an input sample is given in (21), which implies that the complexity of the gradient calculation is (an addition of vectors of dimension ), and hence the backward pass complexity of the CNN is . Similarly, the backward pass complexity of the pMoE experts is . Now, the gradient of the gating kernels in the pMoE router is given in (26), which implies that the complexity of the gradient calculation is (an addition of vectors of dimension ), and hence the backward pass complexity of the pMoE router is .
(26)
Complexity to achieve generalization error. From Theorem 4.5, to achieve generalization error we need iterations of training in pMoE, which implies that the computational complexity to achieve error in pMoE is . Similarly, using the results from Theorem 4.3, the corresponding complexity in CNN is .