
Patch-level Routing in Mixture-of-Experts is Provably Sample-efficient for Convolutional Neural Networks

Mohammed Nowaz Rabbani Chowdhury    Shuai Zhang    Meng Wang    Sijia Liu    Pin-Yu Chen
Abstract

In deep learning, mixture-of-experts (MoE) activates one or a few experts (sub-networks) on a per-sample or per-token basis, resulting in significant computation reduction. The recently proposed patch-level routing in MoE (pMoE) divides each input into $n$ patches (or tokens) and sends $l$ patches ($l\ll n$) to each expert through prioritized routing. pMoE has demonstrated great empirical success in reducing training and inference costs while maintaining test accuracy. However, the theoretical explanation of pMoE and the general MoE remains elusive. Focusing on a supervised classification task using a mixture of two-layer convolutional neural networks (CNNs), we show for the first time that pMoE provably reduces the required number of training samples to achieve desirable generalization (referred to as the sample complexity) by a factor polynomial in $n/l$, and outperforms its single-expert counterpart of the same or even larger capacity. The advantage results from the discriminative routing property, which is justified in both theory and practice: pMoE routers can filter label-irrelevant patches and route similar class-discriminative patches to the same expert. Our experimental results on MNIST, CIFAR-10, and CelebA support our theoretical findings on pMoE's generalization and show that pMoE can avoid learning spurious correlations.


1 Introduction

Deep learning has demonstrated exceptional empirical success in many applications at the cost of high computational and data requirements. To address this issue, mixture-of-experts (MoE) only activates partial regions of a neural network for each data point and significantly reduces the computational complexity of deep learning without hurting the performance in applications such as machine translation and natural image classification (Shazeer et al., 2017; Yang et al., 2019).

Figure 1: An illustration of pMoE. The image is divided into 20 patches while the router selects 4 of them for each expert.

A conventional MoE model contains multiple experts (subnetworks of the backbone architecture) and one learnable router that routes each input sample to a few but not all of the experts (Ramachandran & Le, 2018). Position-wise MoE has been introduced in language models (Shazeer et al., 2017; Lepikhin et al., 2020; Fedus et al., 2022), where routing decisions are made separately on the embeddings of different positions of the input rather than on the entire text input. Riquelme et al. (2021) extended it to vision models, where routing decisions are made on image patches. Zhou et al. (2022) further extended it such that the MoE layer has one router per expert, and each router selects a subset of patches for the corresponding expert and discards the remaining patches. We term this routing mode patch-level routing and the MoE layer a patch-level MoE (pMoE) layer (see Figure 1 for an illustration). Notably, pMoE achieves the same test accuracy in vision tasks with 20% less training compute and 50% less inference compute compared to its single-expert counterpart (i.e., one expert receiving all patches of an input) of the same capacity (Riquelme et al., 2021).

Despite the empirical success of MoE, it remains elusive in theory why MoE can maintain test accuracy while significantly reducing the amount of computation. To the best of our knowledge, only one recent work by Chen et al. (2022) shows theoretically that a conventional sample-wise MoE achieves higher test accuracy than convolutional neural networks (CNN) in a special setup of a binary classification task on data from linearly separable clusters. However, the sample-wise analyses by Chen et al. (2022) do not extend to patch-level MoE, which employs a different routing strategy than conventional MoE, and their data model might not characterize some practical datasets. This paper addresses the following question theoretically:

How much computational resource does pMoE save compared with its single-expert counterpart while maintaining the same generalization guarantee?

In this paper, we consider a supervised binary classification task where each input sample consists of $n$ equal-sized patches, including class-discriminative patterns that determine the labels and class-irrelevant patterns that do not affect the labels. The neural network contains a pMoE layer (in practice, pMoEs are usually placed in the last layers of deep models; our analysis can be extended to this case as long as the input to the pMoE layer satisfies our data model, see Section 4.2) and multiple experts, each of which is a two-layer CNN of the same architecture (we consider CNN experts due to their wide applications, especially in vision tasks; moreover, the pMoE in (Riquelme et al., 2021; Zhou et al., 2022) uses two-layer multi-layer perceptrons (MLPs) as experts in vision transformers (ViT), which operate on image patches, so those MLPs are effectively non-overlapping CNNs). The router sends $l$ ($l\ll n$) patches to each expert. Although we consider a simplified neural network model to facilitate the formal analysis of pMoE, the insights are applicable to more general setups. Our major results include:

1. To the best of our knowledge, this paper provides the first theoretical generalization analysis of pMoE. Our analysis reveals that pMoE with two-layer CNNs as experts can achieve the same generalization performance as a conventional CNN while reducing the sample complexity (the required number of training samples to learn a proper model) and model complexity. Specifically, we prove that as long as $l$ is larger than a certain threshold, pMoE reduces the sample complexity and model complexity by a factor polynomial in $n/l$, indicating improved generalization with a smaller $l$.

2. Characterization of the desired property of the pMoE router. We show that a desired pMoE router can dispatch the same class-discriminative patterns to the same expert and discard some class-irrelevant patterns. This discriminative property allows the experts to learn the class-discriminative patterns with reduced interference from irrelevant patterns, which in turn reduces the sample complexity and model complexity. We also prove theoretically that a separately trained pMoE router has the desired property and empirically verify this property on practical pMoE routers.

3. Experimental demonstration of reduced sample complexity by pMoE in deep CNN models. In addition to verifying our theoretical findings on synthetic data prepared from the MNIST dataset (LeCun et al., 2010), we demonstrate the sample efficiency of pMoE in learning some benchmark vision datasets (e.g., CIFAR-10 (Krizhevsky, 2009) and CelebA (Liu et al., 2015)) by replacing the last convolutional layer of a ten-layer wide residual network (WRN) (Zagoruyko & Komodakis, 2016) with a pMoE layer. These experiments not only verify our theoretical findings but also demonstrate the applicability of pMoE in reducing sample complexity in deep-CNN-based vision models, complementing the existing empirical success of pMoE with vision transformers.

2 Related Works

Mixture-of-Experts. MoE was first introduced in the 1990s with dense sample-wise routing, i.e., each input sample is routed to all the experts (Jacobs et al., 1991; Jordan & Jacobs, 1994; Chen et al., 1999; Tresp, 2000; Rasmussen & Ghahramani, 2001). Sparse sample-wise routing was introduced later (Bengio et al., 2013; Eigen et al., 2013), where each input sample activates only a few of the experts in an MoE layer, both under joint training (Ramachandran & Le, 2018; Yang et al., 2019) and under separate training of the router and experts (Collobert et al., 2001, 2003; Ahmed et al., 2016; Gross et al., 2017). Position/patch-wise MoE (i.e., pMoE) has recently demonstrated success in large language and vision models (Shazeer et al., 2017; Lepikhin et al., 2020; Riquelme et al., 2021; Fedus et al., 2022). To address the issue of load imbalance (Lewis et al., 2021), Zhou et al. (2022) introduced expert-choice routing in pMoE, where each expert uses one router to select a fixed number of patches from the input. This paper analyzes sparse patch-level MoE with expert-choice routing under both joint-training and separate-training setups.

Optimization and generalization analyses of neural networks (NN). Due to the significant nonconvexity of the deep learning problem, existing generalization analyses are limited to linearized or shallow neural networks. The Neural-Tangent-Kernel (NTK) approach (Jacot et al., 2018; Lee et al., 2019; Du et al., 2019; Allen-Zhu et al., 2019b; Zou et al., 2020; Chizat et al., 2019; Ghorbani et al., 2021) considers strong over-parameterization and approximates the neural network by its first-order Taylor expansion. The NTK results are independent of the input data, and gaps in representation power and generalization ability exist between practical NNs and the NTK results (Yehudai & Shamir, 2019; Ghorbani et al., 2019, 2020; Li et al., 2020; Malach et al., 2021). Nonlinear neural networks have recently been analyzed through higher-order Taylor expansions (Allen-Zhu et al., 2019a; Bai & Lee, 2019; Arora et al., 2019; Ji & Telgarsky, 2019) or through a model-estimation approach with Gaussian input data (Zhong et al., 2017b, a; Zhang et al., 2020b, a; Fu et al., 2020; Li et al., 2022b), but these results are limited to two-layer networks, with few papers on three-layer networks (Allen-Zhu et al., 2019a; Allen-Zhu & Li, 2019, 2020a; Li et al., 2022a).

The above works consider arbitrary input data or Gaussian input. To better characterize the practical generalization performance, some recent works analyze structured data models using approaches such as feature mapping (Li & Liang, 2018), where some of the initial model weights are close to data features, and feature learning (Daniely & Malach, 2020; Shalev-Shwartz et al., 2020; Shi et al., 2021; Allen-Zhu & Li, 2022; Li et al., 2023), where some weights gradually learn features during training. Among them, Allen-Zhu & Li (2020b); Brutzkus & Globerson (2021); Karp et al. (2021) analyze CNN on learning structured data composed of class-discriminative patterns that determine the labels and other label-irrelevant patterns. This paper extends the data models in Allen-Zhu & Li (2020b); Brutzkus & Globerson (2021); Karp et al. (2021) to a more general setup, and our analytical approach is a combination of feature learning in routers and feature mapping in experts for pMoE.

3 Problem Formulation

This paper considers the supervised binary classification problem (our results can be extended to multiclass classification problems; see Section M in the Appendix for details) where, given $N$ i.i.d. training samples $\{(x_i,y_i)\}_{i=1}^{N}$ generated by an unknown distribution $\mathcal{D}$, the objective is to learn a neural network model that maps $x$ to $y$ for any $(x,y)$ sampled from $\mathcal{D}$. Here, the input $x\in\mathbb{R}^{nd}$ has $n$ disjoint patches, i.e., $x^{\intercal}=[x^{(1)\intercal},x^{(2)\intercal},\ldots,x^{(n)\intercal}]$, where $x^{(j)}\in\mathbb{R}^{d}$ denotes the $j$-th patch of $x$, and $y\in\{+1,-1\}$ denotes the corresponding label.

3.1 Neural Network Models

We consider a pMoE architecture that includes $k$ experts and the corresponding $k$ routers. Each router selects $l$ out of $n$ ($l<n$) patches for each expert separately. Specifically, the router for expert $s$ ($s\in[k]$) contains a trainable gating kernel $w_s\in\mathbb{R}^{d}$. Given a sample $x$, the router computes a routing value $g_{j,s}(x)=\langle w_s,x^{(j)}\rangle$ for each patch $j$. Let $J_s(x)$ denote the index set of the top-$l$ values of $g_{j,s}$ among all patches $j\in[n]$. Only patches with indices in $J_s(x)$ are routed to expert $s$, each multiplied by a gating value $G_{j,s}(x)$, which is selected differently in different pMoE models.

Each expert is a two-layer CNN with the same architecture. Let $m$ denote the total number of neurons in all the experts; then each expert contains $m/k$ neurons. Let $w_{r,s}\in\mathbb{R}^{d}$ and $a_{r,s}\in\mathbb{R}$ denote the hidden-layer and output-layer weights for neuron $r$ ($r\in[m/k]$) in expert $s$ ($s\in[k]$), respectively. The activation function is the rectified linear unit (ReLU), where $\mathrm{ReLU}(z)=\max(0,z)$.

Let $\theta=\{a_{r,s},w_{r,s},w_s:\forall s\in[k],\forall r\in[m/k]\}$ include all the trainable weights. The pMoE model, denoted as $f_M$, is defined as follows:

$$f_M(\theta,x)=\sum_{s=1}^{k}\sum_{r=1}^{m/k}\frac{a_{r,s}}{l}\sum_{j\in J_s(w_s,x)}\mathrm{ReLU}\big(\langle w_{r,s},x^{(j)}\rangle\big)\,G_{j,s}(w_s,x) \qquad (1)$$

An illustration of (1) is given in Figure 2.

Figure 2: An illustration of the pMoE model in (1) with $k=3$, $m=6$, $n=6$, and $l=2$.
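To make (1) concrete, here is a minimal NumPy sketch of the pMoE forward pass for one sample (our own illustration, not the paper's implementation; the names `x_patches`, `W_gate`, `W_hidden`, and `a_out` are ours). The `joint` flag switches between the constant gating values used under separate training and the softmax gating used under joint training, both defined later in this section.

```python
import numpy as np

def relu(z):
    return np.maximum(z, 0.0)

def pmoe_forward(x_patches, W_gate, W_hidden, a_out, l, joint=True):
    """Sketch of the pMoE output f_M(theta, x) in (1) for one sample.

    x_patches: (n, d) array, the n patches of the sample x.
    W_gate:    (k, d) gating kernels w_s.
    W_hidden:  (k, m//k, d) hidden-layer weights w_{r,s}.
    a_out:     (k, m//k) output-layer weights a_{r,s}.
    l:         number of patches routed to each expert.
    joint:     softmax gating (joint training) vs. gating fixed to 1 (separate training).
    """
    k = W_gate.shape[0]
    out = 0.0
    for s in range(k):
        g = x_patches @ W_gate[s]               # routing values g_{j,s}(x), shape (n,)
        J = np.argsort(g)[-l:]                  # index set J_s(x): top-l routing values
        if joint:
            e = np.exp(g[J] - g[J].max())       # softmax gating over the selected patches
            G = e / e.sum()
        else:
            G = np.ones(l)                      # separate training: G_{j,s} = 1
        h = relu(W_hidden[s] @ x_patches[J].T)  # (m//k, l) expert activations
        out += (a_out[s] @ h) @ G / l           # (1/l) * sum_r a_{r,s} sum_j ReLU(<w_{r,s}, x^(j)>) G_{j,s}
    return out

# Example usage with random weights (dimensions follow the paper's notation).
rng = np.random.default_rng(0)
n, d, k, m, l = 6, 10, 3, 6, 2
x = rng.normal(size=(n, d))
f = pmoe_forward(x, rng.normal(size=(k, d)),
                 rng.normal(size=(k, m // k, d)) / np.sqrt(m),   # scale roughly as in (8)
                 rng.normal(size=(k, m // k)), l)
```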

The learning problem solves the following empirical risk minimization problem with the logistic loss function,

$$\min_{\theta}:\quad L(\theta)=\frac{1}{N}\sum_{i=1}^{N}\log\big(1+e^{-y_i f_M(\theta,x_i)}\big) \qquad (2)$$
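As a small sketch of how (2) is evaluated (our own illustration; `f_out` collects the model outputs $f_M(\theta,x_i)$ and `y` the labels in $\{+1,-1\}$):

```python
import numpy as np

def empirical_risk(f_out, y):
    """Logistic loss L(theta) in (2), averaged over the training samples."""
    return np.mean(np.log1p(np.exp(-y * f_out)))
```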

We consider two different training modes of pMoE: separate training and joint training of the routers and the experts. We also consider the conventional CNN architecture for comparison.

(I) Separate-training pMoE: Under the setup of so-called hard mixtures of experts (Collobert et al., 2003; Ahmed et al., 2016; Gross et al., 2017), the router weights $w_s$ are trained first and then fixed when training the weights of the experts. In this case, the gating values are set as

$$G_{j,s}(w_s,x)\equiv 1,\quad \forall j,s,x \qquad (3)$$

We select $k=2$ in this case to simplify the analysis.

(II) Joint-training pMoE: The routers and the experts are learned jointly, see, e.g., (Lepikhin et al., 2020; Riquelme et al., 2021; Fedus et al., 2022). Here, the gating values are softmax functions with

$$G_{j,s}(w_s,x)=e^{g_{j,s}(x)}\Big/\sum_{i\in J_s(x)}e^{g_{i,s}(x)} \qquad (4)$$

(III) CNN single-expert counterpart: The conventional two-layer CNN with $m$ neurons, denoted as $f_C$, satisfies

$$f_C(\theta,x)=\sum_{r=1}^{m}a_r\left(\frac{1}{n}\sum_{j=1}^{n}\mathrm{ReLU}\big(\langle w_r,x^{(j)}\rangle\big)\right) \qquad (5)$$

Eq. (5) can be viewed as a special case of (1) in which there is only one expert ($k=1$) and all the patches are sent to that expert ($l=n$) with gating values $G_{j,s}\equiv 1$.

Let $\tilde{\theta}$ denote the parameters of the learned model obtained by solving (2). The predicted label for a test sample $x$ by the learned model is $\mathrm{sign}(f_M(\tilde{\theta},x))$. The generalization accuracy, i.e., the fraction of correct predictions over all test samples, equals $\mathbb{P}_{(x,y)\sim\mathcal{D}}\big[y f_M(\tilde{\theta},x)>0\big]$. This paper studies both separate and joint training of pMoE and compares their performance with CNN from the perspective of the sample complexity needed to achieve a desirable generalization accuracy.

3.2 Training Algorithms

In the following algorithms, we fix the output-layer weights $a_{r,s}$ and $a_r$ at their initial values, randomly sampled from the standard Gaussian distribution $\mathcal{N}(0,1)$, and do not update them during training. This is a typical simplification when analyzing NNs, as used in (Li & Liang, 2018; Brutzkus et al., 2018; Allen-Zhu et al., 2019a; Arora et al., 2019).

(I) Separate-training pMoE: The routers are trained separately using $N_r$ training samples ($N_r<N$), denoted by $\{(x_i,y_i)\}_{i=1}^{N_r}$ without loss of generality. The gating kernels $w_1$ and $w_2$ are obtained by solving the following minimization problem:

$$\min_{w_1,w_2}:\quad l_r(w_1,w_2)=-\frac{1}{N_r}\sum_{i=1}^{N_r}y_i\Big\langle w_1-w_2,\sum_{j=1}^{n}x_i^{(j)}\Big\rangle \qquad (6)$$

To solve (6), we implement mini-batch SGD with batch size $B_r$ for $T_r=N_r/B_r$ iterations, starting from the random initialization

$$w_s^{(0)}\sim\mathcal{N}(0,\sigma_r^2\mathbb{I}_{d\times d}),\quad\forall s\in[2] \qquad (7)$$

where $\sigma_r=\Theta\big(1/(n^2\log(\mathrm{poly}(n))\sqrt{d})\big)$.
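Since the router objective (6) is linear in $w_1-w_2$, its mini-batch SGD updates have a closed-form gradient. The following sketch (ours, purely illustrative; `X` stacks the patches of the $N_r$ router-training samples) follows the initialization (7):

```python
import numpy as np

def train_routers(X, y, sigma_r, eta_r, Br, Tr, seed=0):
    """Sketch of the separate router training, Eq. (6)-(7).

    X: (Nr, n, d) patches of the router-training samples; y: (Nr,) labels in {+1, -1}.
    Returns the gating kernels w_1, w_2.
    """
    rng = np.random.default_rng(seed)
    d = X.shape[-1]
    w1 = rng.normal(0.0, sigma_r, size=d)      # initialization (7)
    w2 = rng.normal(0.0, sigma_r, size=d)
    patch_sums = X.sum(axis=1)                 # sum_j x_i^{(j)}, shape (Nr, d)
    for _ in range(Tr):
        idx = rng.choice(len(y), size=Br, replace=False)
        # gradient of l_r w.r.t. w1 is -(1/Br) sum_i y_i sum_j x_i^{(j)}; w.r.t. w2 it is the negative
        grad_w1 = -(y[idx, None] * patch_sums[idx]).mean(axis=0)
        w1 -= eta_r * grad_w1
        w2 -= eta_r * (-grad_w1)
    return w1, w2
```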

After learning the routers, we train the hidden-layer weights $w_{r,s}$ by solving (2) while fixing $w_1$ and $w_2$. We implement mini-batch SGD with batch size $B$ for $T=N/B$ iterations, starting from the initialization

$$w_{r,s}^{(0)}\sim\mathcal{N}\big(0,\tfrac{1}{m}\mathbb{I}_{d\times d}\big),\quad\forall s\in[2],\ \forall r\in[m/2] \qquad (8)$$

(II) Joint-training pMoE: $w_s$ and $w_{r,s}$ in (1) are updated simultaneously by mini-batch SGD with batch size $B$ for $T=N/B$ iterations, starting from the initializations in (7) and (8).

(III) CNN: $w_r$ in (5) are updated by mini-batch SGD with batch size $B$ for $T=N/B$ iterations, starting from the initialization in (8).

4 Theoretical Results

4.1 Key Findings At-a-glance

Before presenting the data model assumptions and rationale in Section 4.2 and the formal results in Section 4.3, we first summarize our key findings. We assume that the data patches are sampled from either class-discriminative patterns that determine the labels or a possibly infinite number of class-irrelevant patterns that have no impact on the label. The parameter $\delta$ (defined in (9)) is inversely related to the separation among patterns, i.e., $\delta$ decreases when (i) the separation among class-discriminative patterns increases, and/or (ii) the separation between class-discriminative and class-irrelevant patterns increases. The key findings are as follows.

(I). A properly trained patch-level router sends class-discriminative patches of one class to the same expert while dropping some class-irrelevant patches. We prove that separate-training pMoE routes class-discriminative patches of the class with label $y=+1$ (or the class with label $y=-1$) to expert 1 (or expert 2), respectively, and that class-irrelevant patterns that are sufficiently far from the class-discriminative patterns are not routed to any expert (Lemma 4.1). This discriminative routing property is also verified empirically for joint-training pMoE (see Section 5.1). Therefore, pMoE effectively reduces the interference from irrelevant patches when each expert learns the class-discriminative patterns. Moreover, we show empirically that pMoE can remove class-irrelevant patches that are spuriously correlated with class labels and thus can avoid learning from spuriously correlated features of the data.

(II). Both the sample complexity and the required number of hidden nodes of pMoE are reduced by a polynomial factor of $n/l$ over CNN. We prove that as long as $l$, the number of patches per expert, is greater than a threshold (which decreases as the separation between class-discriminative and class-irrelevant patterns increases), the sample complexity and the required number of neurons of learning pMoE are $\Omega(l^8)$ and $\Omega(l^{10})$, respectively. In contrast, the sample and model complexities of the CNN are $\Omega(n^8)$ and $\Omega(n^{10})$, respectively, indicating improved generalization by pMoE.

(III). Larger separation between class-discriminative and class-irrelevant patterns reduces the sample complexity and model complexity of pMoE. Both the sample complexity and the required number of neurons of pMoE are polynomial in $\delta$, which decreases when the separation among patterns increases.

4.2 Data Model Assumptions and Rationale

The input $x$ consists of one class-discriminative pattern and $n-1$ class-irrelevant patterns, and the label $y$ is determined by the class-discriminative pattern only.

Distributions of class-discriminative patterns: The unit vectors $o_1, o_2\in\mathbb{R}^{d}$ denote the class-discriminative patterns that determine the labels. The separation between $o_1$ and $o_2$ is measured as $\delta_d:=\langle o_1,o_2\rangle\in(-1,1)$. $o_1$ and $o_2$ are equally distributed in the samples, and each sample contains exactly one of them. If $x$ contains $o_1$ (or $o_2$), then $y$ is $+1$ (or $-1$).

Distributions of class-irrelevant patterns: Class-irrelevant patterns are unit vectors in $\mathbb{R}^{d}$ belonging to $p$ disjoint pattern sets $S_1,S_2,\ldots,S_p$, and these patterns are distributed equally across both classes. $\delta_r$ measures the separation between class-discriminative and class-irrelevant patterns, where $|\langle o_i,q\rangle|\leq\delta_r$ for all $i\in[2]$, $q\in S_j$, $j=1,\ldots,p$. Each $S_j$ belongs to a ball with a diameter of $\Theta\big(\sqrt{(1-\delta_r^2)/(dp^2)}\big)$. Note that NO separation among the class-irrelevant patterns themselves is required.
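For intuition, the two class-discriminative patterns can be instantiated numerically. The sketch below (our own construction, not part of the paper) produces two unit vectors in $\mathbb{R}^d$ with a prescribed inner product $\delta_d$:

```python
import numpy as np

def make_discriminative_patterns(d, delta_d, seed=0):
    """Two unit vectors o1, o2 in R^d with <o1, o2> = delta_d, where delta_d is in (-1, 1)."""
    rng = np.random.default_rng(seed)
    o1 = rng.normal(size=d)
    o1 /= np.linalg.norm(o1)
    u = rng.normal(size=d)
    u -= (u @ o1) * o1                 # project out the o1 component
    u /= np.linalg.norm(u)
    o2 = delta_d * o1 + np.sqrt(1.0 - delta_d**2) * u
    return o1, o2
```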

The rationale of our data model: The data distribution $\mathcal{D}$ captures the locality of the label-defining features in image data. It is motivated by and extends the data distributions in recent theoretical frameworks (Yu et al., 2019; Brutzkus & Globerson, 2021; Karp et al., 2021; Chen et al., 2022). Specifically, Yu et al. (2019) and Brutzkus & Globerson (2021) require orthogonal patterns, i.e., $\delta_r$ and $\delta_d$ are both 0, and only a fixed number of non-discriminative patterns. Karp et al. (2021) and Chen et al. (2022) assume that $\delta_d=-1$ and a possibly infinite number of patterns drawn from a zero-mean Gaussian distribution. In our model, $\delta_d$ takes any value in $(-1,1)$, and the class-irrelevant patterns can be drawn from $p$ pattern sets containing an infinite number of patterns that are not necessarily Gaussian or orthogonal.

Define

$$\delta=1\big/\big(1-\max(\delta_d^2,\delta_r^2)\big) \qquad (9)$$

$\delta$ decreases if (1) $o_1$ and $o_2$ are more separated from each other, and (2) both $o_1$ and $o_2$ are more separated from every set $S_i$, $i\in[p]$. We also define an integer $l^*$ ($l^*\leq n$) that measures the maximum number of class-irrelevant patterns per sample that are sufficiently closer to $o_1$ than to $o_2$, or vice versa. Specifically, a class-irrelevant pattern $q$ is called $\delta'$-closer ($\delta'>0$) to $o_1$ than to $o_2$ if $\langle o_1-o_2,q\rangle>\delta'$ holds. Similarly, $q$ is $\delta'$-closer to $o_2$ than to $o_1$ if $\langle o_2-o_1,q\rangle>\delta'$. Then, let $l^*-1$ be the maximum number of class-irrelevant patches that are either $\delta'$-closer to $o_1$ than to $o_2$ or vice versa, with $\delta'=\Theta(1-\delta_d)$, in any $x$ sampled from $\mathcal{D}$. $l^*$ depends on $\mathcal{D}$ and $\delta_d$. When $\mathcal{D}$ is fixed, a smaller $\delta_d$ corresponds to a larger separation between $o_1$ and $o_2$ and leads to a smaller $l^*$. In contrast to the linearly separable data in (Yu et al., 2019; Brutzkus et al., 2018; Chen et al., 2022), our data model is NOT linearly separable as long as $l^*=\Omega(1)$ (see Section K in the Appendix for the proof).
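The separation parameter in (9) and the count underlying $l^*$ can be computed directly for a given sample; a small illustrative sketch (ours, with `irrelevant_patches` denoting the class-irrelevant patches of one sample) is:

```python
import numpy as np

def delta_param(delta_d, delta_r):
    """The separation parameter delta in (9)."""
    return 1.0 / (1.0 - max(delta_d**2, delta_r**2))

def num_delta_closer(irrelevant_patches, o1, o2, delta_prime):
    """Number of class-irrelevant patches in one sample that are delta'-closer to o1
    than to o2, or vice versa; the maximum of this count over the support of D is l* - 1."""
    diff = irrelevant_patches @ (o1 - o2)   # <o1 - o2, q> for each irrelevant patch q
    return int(np.sum(np.abs(diff) > delta_prime))
```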

4.3 Main Theoretical Results

4.3.1 Generalization Guarantee of Separate-training pMoE

Lemma 4.1 shows that as long as the number of patches per expert, $l$, is greater than $l^*$, the separately learned routers obtained by solving (6) always send $o_1$ to expert 1 and $o_2$ to expert 2. Based on this discriminative property of the learned routers, Theorem 4.2 then quantifies the sample complexity and network size of separate-training pMoE needed to achieve a desired generalization error $\epsilon$. Theorem 4.3 quantifies the sample and model complexities of CNN for comparison.

Lemma 4.1 (Discriminative Property of Separately Trained Routers).

For every $l\geq l^*$, w.h.p. over the random initialization defined in (7), after running mini-batch SGD with batch size $B_r=\Omega\big(n^2/(1-\delta_d)^2\big)$ and learning rate $\eta_r=\Theta(1/n)$ for $T_r=\Omega\big(1/(1-\delta_d)\big)$ iterations, the returned $w_1$ and $w_2$ satisfy

$$\arg_{j\in[n]}\big(x^{(j)}=o_1\big)\in J_1(w_1,x),\quad\forall (x,y=+1)\sim\mathcal{D}$$
$$\arg_{j\in[n]}\big(x^{(j)}=o_2\big)\in J_2(w_2,x),\quad\forall (x,y=-1)\sim\mathcal{D}$$

i.e., the learned routers always send $o_1$ to expert 1 and $o_2$ to expert 2.

The main idea in proving Lemma 4.1 is to show that the gradient in each iteration has a large component along the directions of $o_1$ and $o_2$. Then, after enough iterations, the inner product of $w_1$ and $o_1$ (similarly, $w_2$ and $o_2$) is sufficiently large. The intuition behind requiring $l\geq l^*$ is that, because there are at most $l^*-1$ class-irrelevant patches sufficiently closer to $o_1$ than to $o_2$ (or vice versa), sending $l\geq l^*$ patches to one expert ensures that one of them is $o_1$ (or $o_2$). Note that the batch size $B_r$ and the number of iterations $T_r$ depend on $\delta_d$, the separation between $o_1$ and $o_2$, but are independent of the separation between class-discriminative and class-irrelevant patterns.

We then show that the separate-training pMoE reduces both the sample complexity and the required model size (Theorem 4.2) compared to the CNN (Theorem 4.3).

Theorem 4.2 (Generalization guarantee of separate-training pMoE).

For every $\epsilon>0$ and $l\geq l^*$, for every $m\geq M_S=\Omega\big(l^{10}p^{12}\delta^6/\epsilon^{16}\big)$, with at least $N_S=\Omega(l^8 p^{12}\delta^6/\epsilon^{16})$ training samples, after performing mini-batch SGD with batch size $B=\Omega\big(l^4 p^6\delta^3/\epsilon^8\big)$ and learning rate $\eta=O\big(1/(m\,\mathrm{poly}(l,p,\delta,1/\epsilon,\log m))\big)$ for $T=O\big(l^4 p^6\delta^3/\epsilon^8\big)$ iterations, it holds w.h.p. that

$$\mathbb{P}_{(x,y)\sim\mathcal{D}}\big[y f_M(\theta^{(T)},x)>0\big]\geq 1-\epsilon$$

Theorem 4.2 implies that to achieve a generalization error $\epsilon$ with separate-training pMoE, we need $N_S=\Omega(l^8 p^{12}\delta^6/\epsilon^{16})$ training samples and $M_S=\Omega\big(l^{10}p^{12}\delta^6/\epsilon^{16}\big)$ hidden nodes. Therefore, both $N_S$ and $M_S$ increase polynomially with the number of patches $l$ sent to each expert. Moreover, both $N_S$ and $M_S$ are polynomial in $\delta$ defined in (9), indicating improved generalization performance with stronger separation among patterns.

The proof of Theorem 4.2 is inspired by Li & Liang (2018), which analyzes the generalization performance of fully connected neural networks (FCN) on structured data, but we make new technical contributions in analyzing pMoE models. In addition to analyzing the pMoE routers (Lemma 4.1), which do not appear in the FCN analysis, our analysis also significantly relaxes the separation requirement on the data compared with that of Li & Liang (2018). For example, Li & Liang (2018) requires the separation between the two classes, measured by the smallest $\ell_2$-norm distance between two points in different classes, to be $\Omega(n)$ in order to obtain a sample complexity bound of $\mathrm{poly}(n)$ for the binary classification task. In contrast, the separation between the two classes in our data model is $\min\{\sqrt{2(1-\delta_d)},\,2\sqrt{1-\delta_r}\}$, much less than the $\Omega(n)$ required by Li & Liang (2018).

Theorem 4.3 (Generalization guarantee of CNN).

For every $\epsilon>0$, for every $m\geq M_C=\Omega\big(n^{10}p^{12}\delta^6/\epsilon^{16}\big)$, with at least $N_C=\Omega(n^8 p^{12}\delta^6/\epsilon^{16})$ training samples, after performing mini-batch SGD with batch size $B=\Omega\big(n^4 p^6\delta^3/\epsilon^8\big)$ and learning rate $\eta=O\big(1/(m\,\mathrm{poly}(n,p,\delta,1/\epsilon,\log m))\big)$ for $T=O\big(n^4 p^6\delta^3/\epsilon^8\big)$ iterations, it holds w.h.p. that

$$\mathbb{P}_{(x,y)\sim\mathcal{D}}\big[y f_C(\theta^{(T)},x)>0\big]\geq 1-\epsilon$$

Theorem 4.3 implies that to achieve a generalization error $\epsilon$ using the CNN in (5), we need $N_C=\Omega(n^8 p^{12}\delta^6/\epsilon^{16})$ training samples and $M_C=\Omega\big(n^{10}p^{12}\delta^6/\epsilon^{16}\big)$ neurons.

Sample-complexity gap between single CNN and mixture of CNNs: From Theorems 4.2 and 4.3, the sample-complexity ratio of the CNN to the separate-training pMoE is $N_C/N_S=\Theta\big((n/l)^8\big)$. Similarly, the required number of neurons is reduced by a factor of $M_C/M_S=\Theta\big((n/l)^{10}\big)$ in separate-training pMoE. (The bounds for the sample complexity and model size in Theorems 4.2 and 4.3 are sufficient but not necessary; rigorously speaking, one cannot compare sufficient conditions only. In our analysis, however, the bounds for MoE and CNN are derived with exactly the same technique, with the only difference being the handling of the routers. Therefore, it is fair to compare these two bounds to show the advantage of pMoE.)

4.3.2 Generalization Guarantee of Joint-training pMoE with Proper Routers

Theorem 4.5 characterizes the generalization performance of joint-training pMoE assuming the routers are properly trained, in the sense that after some SGD iterations, for each class at least one of the $k$ experts receives all class-discriminative patches of that class with the largest gating value (see Assumption 4.4).

Assumption 4.4.

There exists an integer $T'<T$ such that for all $t\geq T'$, it holds that:

$$\text{There exists an expert } s\in[k] \text{ s.t. } \forall (x,y=+1)\sim\mathcal{D}:\quad j_{o_1}\in J_s(w_s^{(t)},x) \text{ and } G_{j_{o_1},s}^{(t)}(x)\geq G_{j,s}^{(t)}(x),$$
$$\text{and an expert } s\in[k] \text{ s.t. } \forall (x,y=-1)\sim\mathcal{D}:\quad j_{o_2}\in J_s(w_s^{(t)},x) \text{ and } G_{j_{o_2},s}^{(t)}(x)\geq G_{j,s}^{(t)}(x),$$

where $j_{o_1}$ ($j_{o_2}$) denotes the index of the class-discriminative pattern $o_1$ ($o_2$), $G_{j,s}^{(t)}(x)$ is the gating output of patch $j\in J_s(w_s^{(t)},x)$ of sample $x$ for expert $s$ at iteration $t$, and $w_s^{(t)}$ is the gating kernel for expert $s$ at iteration $t$.

Assumption 4.4 is required in proving Theorem 4.5 because of the difficulty of tracking the dynamics of the routers in joint-training pMoE. Assumption 4.4 is verified empirically in Section 5.1, while its theoretical proof is left for future work.

Table 1: Computational complexity of pMoE and CNN.
Complexity to achieve $\epsilon$ error (Complx/Iter $\times$ T): separate-training pMoE $O(Bml^5 d/\epsilon^8)$; joint-training pMoE $O(Bmk^2 l^3 d/\epsilon^8)$; CNN $O(Bmn^5 d/\epsilon^8)$
Complexity per iteration (Complx/Iter): separate-training pMoE $O(Bmld)$; joint-training pMoE: router $O(Bknd)$ (forward pass) and $O(Bkl^2 d)$ (backward pass), expert $O(Bmld)$; CNN $O(Bmnd)$
Iterations required to converge with $\epsilon$ error (T): separate-training pMoE $O(l^4/\epsilon^8)$; joint-training pMoE $O(k^2 l^2/\epsilon^8)$; CNN $O(n^4/\epsilon^8)$
Theorem 4.5 (Generalization guarantee of joint-training pMoE).

Suppose Assumption 4.4 holds. Then for every $\epsilon>0$, for every $m\geq M_J=\Omega\big(k^3 n^2 l^6 p^{12}\delta^6/\epsilon^{16}\big)$, with at least $N_J=\Omega(k^4 l^6 p^{12}\delta^6/\epsilon^{16})$ training samples, after performing mini-batch SGD with batch size $B=\Omega\big(k^2 l^4 p^6\delta^3/\epsilon^8\big)$ and learning rate $\eta=O\big(1/(m\,\mathrm{poly}(l,p,\delta,1/\epsilon,\log m))\big)$ for $T=O\big(k^2 l^2 p^6\delta^3/\epsilon^8\big)$ iterations, it holds w.h.p. that

$$\mathbb{P}_{(x,y)\sim\mathcal{D}}\big[y f_M(\theta^{(T)},x)>0\big]\geq 1-\epsilon$$

Theorem 4.5 indicates that, with proper routers, joint-training pMoE needs $N_J=\Omega(k^4 l^6 p^{12}\delta^6/\epsilon^{16})$ training samples and $M_J=\Omega\big(k^3 n^2 l^6 p^{12}\delta^6/\epsilon^{16}\big)$ neurons to achieve an $\epsilon$ generalization error. Compared with the CNN in Theorem 4.3, joint-training pMoE reduces the sample complexity and model size by factors of $\Theta(n^8/(k^4 l^6))$ and $\Theta(n^{10}/(k^3 l^6))$, respectively. With more experts (a larger $k$), it is easier to satisfy Assumption 4.4 and learn proper routers, but larger sample and model complexities are required. When the number of samples is fixed, the expression of $N_J$ also indicates that $\epsilon$ scales as $k^{1/4} l^{3/8}$, corresponding to improved generalization when $k$ and $l$ decrease.

We provide an end-to-end computational complexity comparison between the analyzed pMoE models and the general CNN model in Table 1 (see Section N in the Appendix for details). The results in Table 1 indicate that the computational complexity of joint-training pMoE is reduced by a factor of $O(n^5/(k^2 l^3))$ compared with CNN. Similarly, the reduction in computational complexity of separate-training pMoE is $O(n^5/l^5)$.
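As a rough, purely illustrative calculation (ignoring all hidden constants and plugging in the two-layer settings of Section 5.1, i.e., $n=16$ with $l=2$ for separate training and $k=8$, $l=6$ for joint training), these factors evaluate to
$$\frac{n^5}{l^5}=\Big(\frac{16}{2}\Big)^5=32768 \quad\text{and}\quad \frac{n^5}{k^2 l^3}=\frac{16^5}{8^2\cdot 6^3}\approx 76,$$
which should be read only as order-wise comparisons of the bounds in Table 1, not as measured speedups.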

5 Experimental Results

5.1 pMoE of Two-layer CNN

Dataset: We verify our theoretical findings about the model in (1) on synthetic data prepared from the MNIST (LeCun et al., 2010) dataset. Each sample contains $n=16$ patches with patch size $d=28\times 28$. Each patch is drawn from the MNIST dataset. See Figure 3 for an example. We treat the digits "1" and "0" as the class-discriminative patterns $o_1$ and $o_2$, respectively. Each of the digits from "2" to "9" represents a class-irrelevant pattern set.
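A minimal sketch of this construction (our own code, assuming flattened MNIST images `mnist_images` of shape (N, 784) with digit labels `mnist_labels` already loaded) is:

```python
import numpy as np

def make_synthetic_sample(mnist_images, mnist_labels, n=16, seed=0):
    """One synthetic sample: n MNIST patches, exactly one of which is the
    class-discriminative digit ('1' for y=+1, '0' for y=-1); the remaining
    n-1 patches are drawn from the class-irrelevant digits '2'-'9'."""
    rng = np.random.default_rng(seed)
    y = int(rng.choice([+1, -1]))
    disc_digit = 1 if y == +1 else 0
    disc_pool = np.flatnonzero(mnist_labels == disc_digit)
    irr_pool = np.flatnonzero(mnist_labels >= 2)
    patches = list(mnist_images[rng.choice(irr_pool, size=n - 1, replace=False)])
    patches.insert(rng.integers(n), mnist_images[rng.choice(disc_pool)])
    return np.stack(patches), y        # (n, 784) patches and the label
```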

Figure 3: Sample image of the synthetic data from MNIST. Class label is "1".
Figure 4: Generalization performance of pMoE and CNN with a similar model size.
Figure 5: Phase transition of sample complexity with $l$ in separate-training pMoE.
Figure 6: Change of test accuracy in joint-training pMoE with $k$ for fixed sample sizes.
Figure 7: Change of test accuracy in joint-training pMoE with $l$ for fixed sample sizes.

Setup: We compare separate-training pMoE, joint-training pMoE, and CNN with similar model sizes. The separate-training pMoE contains two experts with 20 hidden nodes in each expert. The joint-training pMoE has eight experts with five hidden nodes per expert. The CNN has 40 hidden nodes. All are trained using SGD with $\eta=0.2$ until zero training error. pMoE converges much faster than CNN, which takes 150 epochs. Before training the experts in the separate-training pMoE, we train the router for 100 epochs. The models are evaluated on 1000 test samples.

Generalization performance: Figure 4 compares the test accuracy of the three models, where $l=2$ and $l=6$ for separate-training and joint-training pMoE, respectively. The error bars show the mean plus/minus one standard deviation over five independent experiments. pMoE outperforms CNN with the same number of training samples. pMoE requires only 60% of the training samples needed by CNN to achieve 95% test accuracy.

Figure 5 shows the sample complexity of separate-training pMoE with respect to $l$. Each block represents 20 independent trials. A white block indicates all successes, and a black block indicates all failures. The sample complexity is polynomial in $l$, verifying Theorem 4.2. Figures 7 and 6 show the test accuracy of joint-training pMoE with a fixed sample size when $l$ and $k$ change, respectively. When $l$ is greater than $l^*$, which is 6 in Figure 7, the test accuracy matches our predicted order. Similarly, the dependence on $k$ also matches our prediction when $k$ is large enough to make Assumption 4.4 hold.

Router performance: Figure 8 verifies the discriminative property of separately trained routers (Lemma 4.1) by showing the percentage of test data whose class-discriminative patterns ($o_1$ and $o_2$) appear among the top-$l$ patches of the separately trained router. With very few training samples (such as 300), one can already learn a proper router that has the discriminative patterns among its top-4 patches for 95% of the data. Figure 9 verifies the discriminative property of jointly trained routers (Assumption 4.4). With only 300 training samples, the jointly trained router dispatches $o_1$ with the largest gating value to a particular expert for 95% of class-1 data, and similarly dispatches $o_2$ for 92% of class-2 data.

Figure 8: Percentage of properly routed discriminative patterns by a separately trained router.
Figure 9: Percentage of properly routed discriminative patterns by a jointly trained router. $l=6$.

5.2 pMoE of Wide Residual Networks (WRNs)

Neural network model: We employ the 10-layer WRN (Zagoruyko & Komodakis, 2016) with a widening factor of 10 as the expert. We construct a patch-level MoE counterpart of WRN, referred to as WRN-pMoE, by replacing the last convolutional layer of WRN with a pMoE layer with an equal number of trainable parameters (see Figure 18 in the Appendix for an illustration). WRN-pMoE is trained with the joint-training method (code is available at https://github.com/nowazrabbani/pMoE_CNN). All results are averaged over five independent experiments.

Datasets: We consider both the CelebA (Liu et al., 2015) and CIFAR-10 datasets. The experiments on CIFAR-10 are deferred to the Appendix (see Section A). We down-sample the images of CelebA to $64\times 64$. The last convolutional layer of WRN receives a $16\times 16\times 640$ feature map. In WRN-pMoE, the feature map is divided into 16 patches of size $4\times 4\times 640$. We set $k=8$ and $l=2$ for the pMoE layer.
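The patch construction for the pMoE layer is a non-overlapping split of the feature map; a minimal sketch (ours, assuming a channels-last feature map) is:

```python
import numpy as np

def feature_map_to_patches(fmap, patch=4):
    """Split an (H, W, C) feature map into non-overlapping (patch, patch, C) patches.
    For the CelebA setup this turns the 16 x 16 x 640 map into 16 patches of 4 x 4 x 640."""
    H, W, C = fmap.shape
    blocks = fmap.reshape(H // patch, patch, W // patch, patch, C)
    return blocks.transpose(0, 2, 1, 3, 4).reshape(-1, patch, patch, C)

patches = feature_map_to_patches(np.zeros((16, 16, 640)))   # shape (16, 4, 4, 640)
```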

Figure 10: Classification accuracy of WRN-pMoE and WRN on "smiling" in CelebA.
Figure 11: Classification accuracy of WRN-pMoE and WRN on "smiling" when spuriously correlated with "black hair" in CelebA.
Figure 12: Classification accuracy of WRN-pMoE and WRN on multiclass classification in CelebA.
Table 2: Comparison of training compute of WRN and WRN-pMoE.
No. of training samples | Convergence time (sec): WRN / WRN-pMoE | Training FLOPs ($\times 10^{15}$): WRN / WRN-pMoE
4000 | 260 / 156 | 6 / 3.5
8000 | 324 / 192 | 7.5 / 4.4
12000 | 468 / 280 | 11 / 6.4
16000 | 630 / 368 | 15 / 8.5

Performance Comparison: Figure 10 shows the test accuracy of the binary classification problem on the attribute "smiling." WRN-pMoE requires less than one-fifth of the training samples needed by WRN to achieve 86% accuracy. Figure 11 shows the performance when the training data contain spurious correlations, with hair color as the spurious attribute. Specifically, 95% of the training images with the attribute "smiling" also have the attribute "black hair," while 95% of the training images with the attribute "not smiling" have the attribute "blond hair." The models may learn the hair-color attribute rather than "smiling" due to the spurious correlation; thus, the test accuracies in Figure 11 are lower than those in Figure 10. Nevertheless, WRN-pMoE outperforms WRN and reduces the sample complexity needed to achieve the same accuracy.

Figure 12 shows the test accuracy of multiclass classification (four classes with class attributes "Not smiling, Eyeglass," "Smiling, Eyeglass," "Smiling, No eyeglass," and "Not smiling, No eyeglass") in CelebA. The results are consistent with the binary classification results. Furthermore, Table 2 empirically verifies the computational efficiency of WRN-pMoE over WRN on multiclass classification in CelebA (an NVIDIA RTX 4500 GPU was used to run the experiments; training FLOPs are calculated as Training FLOPs = Training time (seconds) $\times$ Number of GPUs $\times$ peak FLOP/s $\times$ GPU utilization rate). Even with the same number of training samples, WRN-pMoE is still more computationally efficient than WRN, because WRN-pMoE requires fewer iterations to converge and has a lower per-iteration cost.

6 Conclusion

MoE reduces computational costs significantly without hurting generalization performance in various empirical studies, but the theoretical explanation has remained mostly elusive. This paper provides the first theoretical analysis of patch-level MoE and quantitatively proves its savings in sample complexity and model size compared with the single-expert counterpart. Although centered on a classification task using a mixture of two-layer CNNs, our theoretical insights are verified empirically on deep architectures and multiple datasets. Future work includes analyzing other MoE architectures, such as MoE in Vision Transformers (ViT), and connecting MoE with other sparsification methods to further reduce computation.

Acknowledgements

This work was supported by AFOSR FA9550-20-1-0122, NSF 1932196 and the Rensselaer-IBM AI Research Collaboration (http://airc.rpi.edu), part of the IBM AI Horizons Network (http://ibm.biz/AIHorizons). We thank Yihua Zhang at Michigan State University for the help in experiments with CelebA dataset. We thank all anonymous reviewers.

References

  • Ahmed et al. (2016) Ahmed, K., Baig, M. H., and Torresani, L. Network of experts for large-scale image categorization. In European Conference on Computer Vision, pp.  516–532. Springer, 2016.
  • Allen-Zhu & Li (2019) Allen-Zhu, Z. and Li, Y. What can resnet learn efficiently, going beyond kernels? Advances in Neural Information Processing Systems, 32, 2019.
  • Allen-Zhu & Li (2020a) Allen-Zhu, Z. and Li, Y. Backward feature correction: How deep learning performs deep learning. arXiv preprint arXiv:2001.04413, 2020a.
  • Allen-Zhu & Li (2020b) Allen-Zhu, Z. and Li, Y. Towards understanding ensemble, knowledge distillation and self-distillation in deep learning. arXiv preprint arXiv:2012.09816, 2020b.
  • Allen-Zhu & Li (2022) Allen-Zhu, Z. and Li, Y. Feature purification: How adversarial training performs robust deep learning. In 2021 IEEE 62nd Annual Symposium on Foundations of Computer Science (FOCS), pp.  977–988. IEEE, 2022.
  • Allen-Zhu et al. (2019a) Allen-Zhu, Z., Li, Y., and Liang, Y. Learning and generalization in overparameterized neural networks, going beyond two layers. Advances in neural information processing systems, 32, 2019a.
  • Allen-Zhu et al. (2019b) Allen-Zhu, Z., Li, Y., and Song, Z. A convergence theory for deep learning via over-parameterization. In International Conference on Machine Learning, pp. 242–252. PMLR, 2019b.
  • Arora et al. (2019) Arora, S., Du, S., Hu, W., Li, Z., and Wang, R. Fine-grained analysis of optimization and generalization for overparameterized two-layer neural networks. In International Conference on Machine Learning, pp. 322–332. PMLR, 2019.
  • Bai & Lee (2019) Bai, Y. and Lee, J. D. Beyond linearization: On quadratic and higher-order approximation of wide neural networks. In International Conference on Learning Representations, 2019.
  • Bengio et al. (2013) Bengio, Y., Léonard, N., and Courville, A. Estimating or propagating gradients through stochastic neurons for conditional computation. arXiv preprint arXiv:1308.3432, 2013.
  • Brutzkus & Globerson (2021) Brutzkus, A. and Globerson, A. An optimization and generalization analysis for max-pooling networks. In Uncertainty in Artificial Intelligence, pp.  1650–1660. PMLR, 2021.
  • Brutzkus et al. (2018) Brutzkus, A., Globerson, A., Malach, E., and Shalev-Shwartz, S. SGD learns over-parameterized networks that provably generalize on linearly separable data. In International Conference on Learning Representations, 2018.
  • Chen et al. (1999) Chen, K., Xu, L., and Chi, H. Improved learning algorithms for mixture of experts in multiclass classification. Neural networks, 12(9):1229–1252, 1999.
  • Chen et al. (2022) Chen, Z., Deng, Y., Wu, Y., Gu, Q., and Li, Y. Towards understanding mixture of experts in deep learning. arXiv preprint arXiv:2208.02813, 2022.
  • Chizat et al. (2019) Chizat, L., Oyallon, E., and Bach, F. On lazy training in differentiable programming. Advances in Neural Information Processing Systems, 32, 2019.
  • Collobert et al. (2001) Collobert, R., Bengio, S., and Bengio, Y. A parallel mixture of SVMs for very large scale problems. Advances in Neural Information Processing Systems, 14, 2001.
  • Collobert et al. (2003) Collobert, R., Bengio, Y., and Bengio, S. Scaling large learning problems with hard parallel mixtures. International Journal of pattern recognition and artificial intelligence, 17(03):349–365, 2003.
  • Daniely & Malach (2020) Daniely, A. and Malach, E. Learning parities with neural networks. Advances in Neural Information Processing Systems, 33:20356–20365, 2020.
  • Du et al. (2019) Du, S., Lee, J., Li, H., Wang, L., and Zhai, X. Gradient descent finds global minima of deep neural networks. In International conference on machine learning, pp. 1675–1685. PMLR, 2019.
  • Eigen et al. (2013) Eigen, D., Ranzato, M., and Sutskever, I. Learning factored representations in a deep mixture of experts. arXiv preprint arXiv:1312.4314, 2013.
  • Fedus et al. (2022) Fedus, W., Zoph, B., and Shazeer, N. Switch transformers: Scaling to trillion parameter models with simple and efficient sparsity. Journal of Machine Learning Research, 23(120):1–39, 2022.
  • Fu et al. (2020) Fu, H., Chi, Y., and Liang, Y. Guaranteed recovery of one-hidden-layer neural networks via cross entropy. IEEE transactions on signal processing, 68:3225–3235, 2020.
  • Ghorbani et al. (2019) Ghorbani, B., Mei, S., Misiakiewicz, T., and Montanari, A. Limitations of lazy training of two-layers neural network. Advances in Neural Information Processing Systems, 32, 2019.
  • Ghorbani et al. (2020) Ghorbani, B., Mei, S., Misiakiewicz, T., and Montanari, A. When do neural networks outperform kernel methods? Advances in Neural Information Processing Systems, 33:14820–14830, 2020.
  • Ghorbani et al. (2021) Ghorbani, B., Mei, S., Misiakiewicz, T., and Montanari, A. Linearized two-layers neural networks in high dimension. The Annals of Statistics, 49(2):1029–1054, 2021.
  • Gross et al. (2017) Gross, S., Ranzato, M., and Szlam, A. Hard mixtures of experts for large scale weakly supervised vision. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp.  6865–6873, 2017.
  • Jacobs et al. (1991) Jacobs, R. A., Jordan, M. I., Nowlan, S. J., and Hinton, G. E. Adaptive mixtures of local experts. Neural computation, 3(1):79–87, 1991.
  • Jacot et al. (2018) Jacot, A., Gabriel, F., and Hongler, C. Neural tangent kernel: Convergence and generalization in neural networks. Advances in neural information processing systems, 31, 2018.
  • Ji & Telgarsky (2019) Ji, Z. and Telgarsky, M. Polylogarithmic width suffices for gradient descent to achieve arbitrarily small test error with shallow relu networks. In International Conference on Learning Representations, 2019.
  • Jordan & Jacobs (1994) Jordan, M. I. and Jacobs, R. A. Hierarchical mixtures of experts and the em algorithm. Neural computation, 6(2):181–214, 1994.
  • Karp et al. (2021) Karp, S., Winston, E., Li, Y., and Singh, A. Local signal adaptivity: Provable feature learning in neural networks beyond kernels. Advances in Neural Information Processing Systems, 34:24883–24897, 2021.
  • Krizhevsky (2009) Krizhevsky, A. Learning multiple layers of features from tiny images. Technical report, Canadian Institute For Advanced Research, 2009.
  • LeCun et al. (2010) LeCun, Y., Cortes, C., and Burges, C. MNIST handwritten digit database. AT&T Labs [online]. Available: http://yann.lecun.com/exdb/mnist, 2010.
  • Lee et al. (2019) Lee, J., Xiao, L., Schoenholz, S., Bahri, Y., Novak, R., Sohl-Dickstein, J., and Pennington, J. Wide neural networks of any depth evolve as linear models under gradient descent. Advances in neural information processing systems, 32, 2019.
  • Lepikhin et al. (2020) Lepikhin, D., Lee, H., Xu, Y., Chen, D., Firat, O., Huang, Y., Krikun, M., Shazeer, N., and Chen, Z. Gshard: Scaling giant models with conditional computation and automatic sharding. In International Conference on Learning Representations, 2020.
  • Lewis et al. (2021) Lewis, M., Bhosale, S., Dettmers, T., Goyal, N., and Zettlemoyer, L. Base layers: Simplifying training of large, sparse models. In International Conference on Machine Learning, pp. 6265–6274. PMLR, 2021.
  • Li et al. (2022a) Li, H., Wang, M., Liu, S., Chen, P.-Y., and Xiong, J. Generalization guarantee of training graph convolutional networks with graph topology sampling. In International Conference on Machine Learning, pp. 13014–13051. PMLR, 2022a.
  • Li et al. (2022b) Li, H., Zhang, S., and Wang, M. Learning and generalization of one-hidden-layer neural networks, going beyond standard gaussian data. In 2022 56th Annual Conference on Information Sciences and Systems (CISS), pp.  37–42. IEEE, 2022b.
  • Li et al. (2023) Li, H., Wang, M., Liu, S., and Chen, P.-Y. A theoretical understanding of shallow vision transformers: Learning, generalization, and sample complexity. In The Eleventh International Conference on Learning Representations, 2023. URL https://openreview.net/forum?id=jClGv3Qjhb.
  • Li & Liang (2018) Li, Y. and Liang, Y. Learning overparameterized neural networks via stochastic gradient descent on structured data. Advances in neural information processing systems, 31, 2018.
  • Li et al. (2020) Li, Y., Ma, T., and Zhang, H. R. Learning over-parametrized two-layer neural networks beyond NTK. In Conference on learning theory, pp.  2613–2682. PMLR, 2020.
  • Liu et al. (2015) Liu, Z., Luo, P., Wang, X., and Tang, X. Deep learning face attributes in the wild. In Proceedings of the IEEE international conference on computer vision, pp.  3730–3738, 2015.
  • Malach et al. (2021) Malach, E., Kamath, P., Abbe, E., and Srebro, N. Quantifying the benefit of using differentiable learning over tangent kernels. In International Conference on Machine Learning, pp. 7379–7389. PMLR, 2021.
  • Ramachandran & Le (2018) Ramachandran, P. and Le, Q. V. Diversity and depth in per-example routing models. In International Conference on Learning Representations, 2018.
  • Rasmussen & Ghahramani (2001) Rasmussen, C. and Ghahramani, Z. Infinite mixtures of gaussian process experts. Advances in neural information processing systems, 14, 2001.
  • Riquelme et al. (2021) Riquelme, C., Puigcerver, J., Mustafa, B., Neumann, M., Jenatton, R., Susano Pinto, A., Keysers, D., and Houlsby, N. Scaling vision with sparse mixture of experts. Advances in Neural Information Processing Systems, 34:8583–8595, 2021.
  • Shalev-Shwartz et al. (2020) Shalev-Shwartz, S. et al. Computational separation between convolutional and fully-connected networks. In International Conference on Learning Representations, 2020.
  • Shazeer et al. (2017) Shazeer, N., Mirhoseini, A., Maziarz, K., Davis, A., Le, Q. V., Hinton, G. E., and Dean, J. Outrageously large neural networks: The sparsely-gated mixture-of-experts layer. In International Conference on Learning Representations, 2017.
  • Shi et al. (2021) Shi, Z., Wei, J., and Liang, Y. A theoretical analysis on feature learning in neural networks: Emergence from inputs and advantage over fixed features. In International Conference on Learning Representations, 2021.
  • Tresp (2000) Tresp, V. Mixtures of gaussian processes. In Leen, T., Dietterich, T., and Tresp, V. (eds.), Advances in Neural Information Processing Systems, volume 13. MIT Press, 2000. URL https://proceedings.neurips.cc/paper/2000/file/9fdb62f932adf55af2c0e09e55861964-Paper.pdf.
  • Vaswani et al. (2017) Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, Ł., and Polosukhin, I. Attention is all you need. Advances in neural information processing systems, 30, 2017.
  • Yang et al. (2019) Yang, B., Bender, G., Le, Q. V., and Ngiam, J. Condconv: Conditionally parameterized convolutions for efficient inference. Advances in Neural Information Processing Systems, 32, 2019.
  • Yehudai & Shamir (2019) Yehudai, G. and Shamir, O. On the power and limitations of random features for understanding neural networks. Advances in Neural Information Processing Systems, 32, 2019.
  • Yu et al. (2019) Yu, B., Zhang, J., and Zhu, Z. On the learning dynamics of two-layer nonlinear convolutional neural networks. arXiv preprint arXiv:1905.10157, 2019.
  • Zagoruyko & Komodakis (2016) Zagoruyko, S. and Komodakis, N. Wide residual networks. arXiv preprint arXiv:1605.07146, 2016.
  • Zhang et al. (2020a) Zhang, S., Wang, M., Liu, S., Chen, P.-Y., and Xiong, J. Fast learning of graph neural networks with guaranteed generalizability: one-hidden-layer case. In International Conference on Machine Learning, pp. 11268–11277. PMLR, 2020a.
  • Zhang et al. (2020b) Zhang, S., Wang, M., Xiong, J., Liu, S., and Chen, P.-Y. Improved linear convergence of training CNNs with generalizability guarantees: A one-hidden-layer case. IEEE Transactions on Neural Networks and Learning Systems, 32(6):2622–2635, 2020b.
  • Zhong et al. (2017a) Zhong, K., Song, Z., and Dhillon, I. S. Learning non-overlapping convolutional neural networks with multiple kernels. arXiv preprint arXiv:1711.03440, 2017a.
  • Zhong et al. (2017b) Zhong, K., Song, Z., Jain, P., Bartlett, P. L., and Dhillon, I. S. Recovery guarantees for one-hidden-layer neural networks. In International conference on machine learning, pp. 4140–4149. PMLR, 2017b.
  • Zhou et al. (2022) Zhou, Y., Lei, T., Liu, H., Du, N., Huang, Y., Zhao, V. Y., Dai, A. M., Chen, Z., Le, Q. V., and Laudon, J. Mixture-of-experts with expert choice routing. In Oh, A. H., Agarwal, A., Belgrave, D., and Cho, K. (eds.), Advances in Neural Information Processing Systems, 2022. URL https://openreview.net/forum?id=jdJo1HIVinI.
  • Zou et al. (2020) Zou, D., Cao, Y., Zhou, D., and Gu, Q. Gradient descent optimizes over-parameterized deep relu networks. Machine learning, 109(3):467–492, 2020.

Appendix A Experiments on CIFAR-10 Datasets

We also compare WRN and WRN-pMoE on CIFAR-10-based datasets. To better reflect local features, in addition to the original CIFAR-10, we adopt techniques of Karp et al. (2021) to generate two datasets based on CIFAR-10:

1. CIFAR-10 with ImageNet noise. Each CIFAR-10 image is down-sampled to size $16\times 16$ and placed at a random location of a background image chosen from the ImageNet Plants synset. Figure 13(c) shows an example image of this dataset.

2. CIFAR-Vehicles. Each vehicle image of CIFAR-10 is down-sampled to size $16\times 16$ and placed in a random quadrant of an image, while the other quadrants are randomly filled with down-sampled animal images from CIFAR-10. See Figure 13(b) for a sample image.

The last convolutional layer of WRN receives an $8\times 8\times 640$ feature map. In WRN-pMoE, we divide this feature map into 64 patches of size $1\times 1\times 640$. The MoE layer of WRN-pMoE contains $k=4$ experts, with each expert receiving $l=16$ patches.

Figure 13: Example images from (a) CIFAR-10, (b) CIFAR-Vehicles, and (c) CIFAR-10 with ImageNet noise
Figure 14: Ten-classification accuracy of WRN and WRN-pMoE on CIFAR-10
Figure 15: Ten-classification accuracy of WRN and WRN-pMoE on CIFAR-10 with ImageNet noise
Figure 16: Four-classification accuracy of WRN and WRN-pMoE on CIFAR-Vehicles

Figures 14, 15, and 16 compare the test accuracy of WRN and WRN-pMoE for the ten-classification problems on CIFAR-10 and CIFAR-10 with ImageNet noise, and the four-classification problem on CIFAR-Vehicles, respectively. WRN-pMoE outperforms WRN on all three datasets, indicating a reduced sample complexity from the pMoE layer. The performance gap is more significant on the two constructed datasets than on the original CIFAR-10, because the constructed datasets contain explicit local features, which the pMoE layer learns more effectively.

Appendix B Preliminaries

The loss function for SGD at iteration tt with minibatch t\mathcal{B}_{t}:

(θ(t)):=1B(x,y)tlog(1+eyfM(θ(t),x))\mathcal{L}(\theta^{(t)}):=\cfrac{1}{B}\sum_{(x,y)\in\mathcal{B}_{t}}\log{(1+e^{-yf_{M}(\theta^{(t)},x)})} (10)

For the router-training in separate-training pMoE, the loss function of SGD at iteration tt with minibatch tr\mathcal{B}_{t}^{r}:

r(w1(t),w2(t)):=1Br(x,y)tryw1(t)w2(t),j=1nx(j)\ell_{r}(w_{1}^{(t)},w_{2}^{(t)}):=-\cfrac{1}{B_{r}}\hskip 2.84544pt\sum_{(x,y)\in\mathcal{B}_{t}^{r}}y\langle w^{(t)}_{1}-w_{2}^{(t)},\sum_{j=1}^{n}x^{(j)}\rangle (11)
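As a sanity check of the two objectives, the following is a minimal numpy sketch of (10) and (11) on a toy minibatch; the synthetic data generator and the placeholder array standing in for $f_{M}(\theta^{(t)},x)$ are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
B, n, d = 8, 6, 10                        # batch size, patches per sample, patch dimension
X = rng.normal(size=(B, n, d))            # x_i^(j): the j-th patch of the i-th sample
y = rng.choice([-1.0, 1.0], size=B)       # labels

def expert_loss(f_vals, y):
    # (10): (1/B) * sum_i log(1 + exp(-y_i * f_M(theta, x_i)))
    return np.mean(np.log1p(np.exp(-y * f_vals)))

def router_loss(w1, w2, X, y):
    # (11): -(1/B_r) * sum_i y_i * <w1 - w2, sum_j x_i^(j)>
    patch_sums = X.sum(axis=1)            # sum_j x_i^(j), shape (B, d)
    return -np.mean(y * (patch_sums @ (w1 - w2)))

f_vals = rng.normal(size=B)               # placeholder for the pMoE outputs f_M(theta, x_i)
w1, w2 = rng.normal(size=d), rng.normal(size=d)
print(expert_loss(f_vals, y), router_loss(w1, w2, X, y))
```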

Notations:

  1. 1.

Generally, $\tilde{O}(\cdot)$ and $\tilde{\Omega}(\cdot)$ hide a factor of $\log(\text{poly}(m,n,p,\delta,\frac{1}{\epsilon}))$. In Lemmas E.3 and D.4, $\tilde{\Omega}(\cdot)$ hides a factor of $\log(\text{poly}(n))$.

  2. 2.

"With high probability" (abbreviated as w.h.p.) generally means with probability $1-\frac{1}{\text{poly}(m,n,p,\delta,\frac{1}{\epsilon})}$, where $\text{poly}(\cdot)$ denotes a sufficiently large polynomial. In Lemmas E.2, E.3, and D.4, "w.h.p." means with probability $1-\frac{1}{\text{poly}(n)}$.

  3. 3.

We denote $\sigma=\frac{1}{\sqrt{m}}$, so that the experts are initialized as $w_{r,s}^{(0)}\sim\mathcal{N}(0,\sigma^{2}\mathbb{I}_{d\times d})$ for all $s\in[k]$ and $r\in[m/k]$.

The training algorithms for separate-training and joint-training pMoE are given in Algorithm 1 and Algorithm 2, respectively:

Algorithm 1 Two-phase SGD for separate-training pMoE

Input : Training data $\{(x_{i},y_{i})\}_{i=1}^{N}$, learning rates $\eta_{r}$ and $\eta$, numbers of iterations $T_{r}$ and $T$, batch sizes $B_{r}$ and $B$
Step-1: Initialize ws(0),wr,s(0),ar,s,s{1,2},r[m/k]w_{s}^{(0)},w_{r,s}^{(0)},a_{r,s},\forall s\in\{1,2\},r\in[m/k] according to (7) and (8)
Step-2: for t=0,1,,Tr1t=0,1,...,T_{r}-1 do:

ws(t+1)=ws(t)ηrr(w1(t),w2(t))ws(t),s{1,2}w_{s}^{(t+1)}=w_{s}^{(t)}-\eta_{r}\cfrac{\partial\ell_{r}(w_{1}^{(t)},w_{2}^{(t)})}{\partial w_{s}^{(t)}},\forall s\in\{1,2\}

Step-3: for t=0,1,,T1t=0,1,...,T-1 do:

wr,s(t+1)=wr,s(t)η(θ(t))wr,s(t),r[m/k],s{1,2}w_{r,s}^{(t+1)}=w_{r,s}^{(t)}-\eta\cfrac{\partial\mathcal{L}(\theta^{(t)})}{\partial w_{r,s}^{(t)}},\hskip 5.69046pt\forall r\in[m/k],s\in\{1,2\}

Algorithm 2 SGD for joint-training pMoE

Input : Training data $\{(x_{i},y_{i})\}_{i=1}^{N}$, learning rate $\eta$, number of iterations $T$, batch size $B$
Step-1: Initialize ws(0),wr,s(0),ar,s,s[k],r[m/k]w_{s}^{(0)},w_{r,s}^{(0)},a_{r,s},\forall s\in[k],r\in[m/k] according to (7) and (8)
Step-2: for t=0,1,,T1t=0,1,...,T-1 do:

ws(t+1)=ws(t)η(θ(t))ws(t),s[k]w_{s}^{(t+1)}=w_{s}^{(t)}-\eta\cfrac{\partial\mathcal{L}(\theta^{(t)})}{\partial w_{s}^{(t)}},\forall s\in[k]

wr,s(t+1)=wr,s(t)η(θ(t))wr,s(t),r[m/k],s[k]w_{r,s}^{(t+1)}=w_{r,s}^{(t)}-\eta\cfrac{\partial\mathcal{L}(\theta^{(t)})}{\partial w_{r,s}^{(t)}},\hskip 5.69046pt\forall r\in[m/k],s\in[k]
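A compact, runnable sketch of the two training procedures on toy data is given below (Algorithm 1: first train the router with (11), then train the experts with (10); Algorithm 2 would instead update the gating kernels and the experts jointly with (10)). The simplified pMoE output $f_{M}$ used here, the toy dimensions, and the initialization scales are our own illustrative assumptions, not the exact model and initialization in (7) and (8).

```python
import torch

# Toy setup: n patches of dimension d per sample, k = 2 experts, m hidden nodes in total.
N, B, n, d, k, m, l = 200, 16, 6, 10, 2, 8, 2
X = torch.randn(N, n, d)
y = torch.randint(0, 2, (N,)).float() * 2 - 1                  # labels in {-1, +1}

w_gate = torch.zeros(k, d, requires_grad=True)                 # gating kernels w_1, w_2
W = (torch.randn(k, m // k, d) / m ** 0.5).requires_grad_()    # expert hidden weights w_{r,s}
a = torch.randn(k, m // k)                                     # fixed output weights a_{r,s}

def f_M(xb):
    # Simplified pMoE output: each expert averages ReLU responses over its TOP-l patches.
    scores = torch.einsum('bnd,kd->bkn', xb, w_gate)
    _, idx = scores.topk(l, dim=-1)
    picked = torch.gather(xb.unsqueeze(1).expand(-1, k, -1, -1), 2,
                          idx.unsqueeze(-1).expand(-1, -1, -1, d))
    h = torch.relu(torch.einsum('bkld,krd->bkrl', picked, W))
    return torch.einsum('bkrl,kr->b', h, a) / l

# Algorithm 1, Step-2: train only the router with the linear loss (11).
opt_r = torch.optim.SGD([w_gate], lr=0.1)
for _ in range(50):
    batch = torch.randint(0, N, (B,))
    loss_r = -(y[batch] * (X[batch].sum(1) @ (w_gate[0] - w_gate[1]))).mean()
    opt_r.zero_grad(); loss_r.backward(); opt_r.step()

# Algorithm 1, Step-3: freeze the router and train the experts with the logistic loss (10).
opt_e = torch.optim.SGD([W], lr=0.1)
for _ in range(50):
    batch = torch.randint(0, N, (B,))
    loss = torch.log1p(torch.exp(-y[batch] * f_M(X[batch]))).mean()
    opt_e.zero_grad(); loss.backward(); opt_e.step()
# Algorithm 2 (joint training) would instead run one loop updating [w_gate, W] with the loss (10).
```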

Appendix C Proof Sketch

The proofs of the generalization guarantees for pMoE (i.e., Theorems 4.2 and 4.5) can be outlined as follows (the proof for the single CNN follows a simpler version of this outline):

Step 1. (Feature learning in the router) For separate-training pMoE, we first show that the batch gradient of the router loss (i.e., $\ell_{r}(w_{1}^{(t)},w_{2}^{(t)})$) w.r.t. the gating kernels (i.e., $w_{1}^{(t)}$ and $w_{2}^{(t)}$) has a large component (of size $\frac{1-\delta_{d}}{2}-\Omega\left(\frac{n}{\sqrt{B_{r}}}\right)$) along the class-discriminative patterns $o_{1}$ and $o_{2}$, respectively. Then, by selecting $B_{r}=\Omega\left(\frac{n^{2}}{(1-\delta_{d})^{2}}\right)$ (which yields an $\Omega(1)$ loss reduction per step) and training for $\Omega\left(\frac{1}{1-\delta_{d}}\right)$ iterations, we can show that $w_{1}$ and $w_{2}$ are sufficiently aligned with $o_{1}$ and $o_{2}$, respectively, to guarantee that these class-discriminative patterns are selected among the TOP-$l$ patches whenever $l\geq l^{*}$ (see Lemma D.4 for the exact statement).

Step 2. (Coupling the experts to pseudo experts) When the experts of pMoE are sufficiently overparameterized, w.h.p. the experts can be coupled to a smooth pseudo network of experts (the pseudo network is the network whose activation pattern does not change from initialization, i.e., the sign of the pre-activation output of each hidden node stays at its sign at initialization; see (Li & Liang, 2018) for details): for every sample drawn from the distribution $\mathcal{D}$ and every $\tau>0$, the activation pattern of a $1-\Omega\left(\frac{\tau l}{\sigma}\right)$ (for separate-training pMoE) or $1-\Omega\left(\frac{\tau n}{\sigma}\right)$ (for joint-training pMoE) fraction of hidden nodes in each expert does not change from initialization for $O(\frac{\tau}{\eta})$ iterations (see Lemma G.1 or H.1 for the exact statement). This indicates that with $\tau=O\left(\frac{\sigma}{l}\right)$ (for separate-training pMoE) or $\tau=O\left(\frac{\sigma}{n}\right)$ (for joint-training pMoE), $\eta=\Omega\left(\frac{1}{ml}\right)$ (for separate-training pMoE) or $\eta=\Omega\left(\frac{1}{mn}\right)$ (for joint-training pMoE), and $\sigma=O\left(\frac{1}{\sqrt{m}}\right)$, we can couple an $\Omega(1)$ fraction of hidden nodes of each expert to the corresponding pseudo expert for $O(\sqrt{m})$ iterations.

Step 3. (Large error implies large gradient) We then analyze the pseudo network of experts corresponding to the separate-training pMoE to show that, at any iteration $t$, the magnitude of the expected gradient for any expert $s\in\{1,2\}$ of the pseudo network is $\Omega\left(\frac{v_{s}^{(t)}}{l}\right)$, where $v_{s}^{(t)}$ characterizes the class-conditional expected error over samples with $y=+1$ for $s=1$ and $y=-1$ for $s=2$ (see Lemma G.3 for the exact statement). Similarly, for joint-training pMoE we show that the magnitude of the expected gradient is $\Omega\left(\frac{v_{s}^{(t)}}{l}\right)$, but this time $v_{s}^{(t)}$ characterizes the maximum of the class-conditional expected errors over the samples for which expert $s$ receives class-discriminative patches from the router (see Lemma H.3 for the exact statement).

Step 4. (Convergence) Now define $v^{(t)}=\sqrt{\sum_{s\in[k]}(v_{s}^{(t)})^{2}}$. For separate-training pMoE, by selecting the batch size $B_{t}=\Omega(\frac{l^{4}}{(v^{(t)})^{4}})$ at iteration $t$, $\eta=\Omega(\frac{(v^{(t)})^{2}}{ml^{2}})$, and $\tau=O(\frac{\sigma(v^{(t)})^{2}}{l^{3}})$, we can couple the empirical batch gradient of each expert of the true network for that batch to the expected gradient of the corresponding expert of the pseudo network. Because the pseudo network is smooth, SGD then decreases the expected loss of the true network by $\Omega(\frac{\eta m(v^{(t)})^{2}}{l^{2}})$ at each iteration for $t=O(\frac{\sigma(v^{(t)})^{2}}{\eta l^{3}})$ iterations (see Lemma G.4 for the exact statement). Similarly, for joint-training pMoE, by selecting $B_{t}=\Omega(\frac{k^{2}}{(v^{(t)})^{4}})$ and $\eta=\Omega(\frac{(v^{(t)})^{2}l^{3}}{mk^{2}})$, SGD decreases the expected loss of the true network by $\Omega(\frac{\eta m(v^{(t)})^{2}}{l^{2}})$ for $t=O(\frac{\sigma(v^{(t)})^{2}l^{2}}{\eta nk})$ iterations (see Lemma H.4 for the exact statement). Since the loss of the true network is $O(1)$ at initialization, the training eventually converges.

Step 5. (Generalization) We show that to ensure at most $\epsilon$ generalization error after any iteration $t$, we need $\max\{v_{1}^{(t)},v_{2}^{(t)}\}<\epsilon^{2}$, where $v_{1}^{(t)}$ and $v_{2}^{(t)}$ are the class-conditional expected errors of the classes with $y=+1$ and $y=-1$, respectively. Since the router in the separate-training pMoE dispatches the class-discriminative patches of all samples labeled $y=+1$ to the expert indexed by $s=1$ and those of all samples labeled $y=-1$ to the expert indexed by $s=2$ from the beginning of expert training, $v^{(t)}<\epsilon^{2}$ ensures $\max\{v_{1}^{(t)},v_{2}^{(t)}\}<\epsilon^{2}$. On the other hand, for the joint-training pMoE, since we assume that, before the convergence of the model, the router dispatches all the class-discriminative patches of a class to a particular expert and that the gating value of such a patch is the largest among all the patches sent to that expert, $v^{(t)}<\frac{\epsilon^{2}}{l}$ implies $\max\{v_{1}^{(t)},v_{2}^{(t)}\}<\epsilon^{2}$. Hence, for separate-training pMoE, by setting $v^{(t)}\geq\epsilon^{2}$ we show that with $B=\Omega(l^{4}/\epsilon^{8})$ and $\eta=\Omega(1/(m\,\text{poly}(l,1/\epsilon)))$ for $T=O(l^{4}/\epsilon^{8})$ iterations, we can guarantee that the generalization error is less than $\epsilon$ (see Theorem F.3 for the exact statement). Similarly, for joint-training pMoE, by setting $v^{(t)}\geq\frac{\epsilon^{2}}{l}$ and choosing $B=\Omega(k^{2}l^{4}/\epsilon^{8})$ and $\eta=\Omega(1/(m\,\text{poly}(l,1/\epsilon)))$ for $T=O(k^{2}l^{2}/\epsilon^{8})$ iterations, we can guarantee that the generalization error is less than $\epsilon$ (see Theorem F.5 for the exact statement).

Appendix D Proof of the Lemma 4.1

Definition D.1.

($\delta^{\prime}$-closer class-irrelevant patterns) For any $\delta^{\prime}>0$, a class-irrelevant pattern $q$ is $\delta^{\prime}$-closer to $o_{1}$ than $o_{2}$ if $\langle o_{1},q\rangle-\langle o_{2},q\rangle>\delta^{\prime}$. Similarly, a class-irrelevant pattern $q$ is $\delta^{\prime}$-closer to $o_{2}$ than $o_{1}$ if $\langle o_{2},q\rangle-\langle o_{1},q\rangle>\delta^{\prime}$.

Definition D.2.

(Set of $\delta^{\prime}$-closer class-irrelevant patterns, $\mathcal{S}_{c}(\delta^{\prime})$) For any $\delta^{\prime}>0$, define the set of $\delta^{\prime}$-closer class-irrelevant patterns $\mathcal{S}_{c}(\delta^{\prime})\subset\bigcup_{i=1}^{p}S_{i}$ such that $\forall q\in\mathcal{S}_{c}(\delta^{\prime}),\ |\langle o_{1}-o_{2},q\rangle|>\delta^{\prime}$.

Definition D.3.

(Threshold, ll^{*}) Define the threshold ll^{*} such that:
                     (x,y)𝒟,|{j[n]:x(j)o1 and x(j)Sc(1δd2)}|l1\forall(x,y)\sim\mathcal{D},\left|\{j\in[n]:x^{(j)}\neq o_{1}\text{ and }x^{(j)}\in S_{c}\left(\frac{1-\delta_{d}}{2}\right)\}\right|\leq l^{*}-1

Lemma D.4.

(Full version of Lemma 4.1) For every $l\geq l^{*}$, w.h.p. over the random initialization defined in (7), after completing Step-2 of Algorithm 1 with batch size $B_{r}=\tilde{\Omega}\left(\frac{n^{2}}{(1-\delta_{d})^{2}}\right)$ and learning rate $\eta_{r}=\Theta\left(\frac{1}{n}\right)$ for $T_{r}=\Omega\left(\frac{1}{1-\delta_{d}}\right)$ iterations, the returned $w_{1}^{(T_{r})}$ and $w_{2}^{(T_{r})}$ satisfy

argj[n](x(j)=o1)J1(w1(Tr),x),(x,y=+1)𝒟\underset{j\in[n]}{\text{arg}}(x^{(j)}=o_{1})\in J_{1}(w_{1}^{(T_{r})},x),\quad\forall(x,y=+1)\sim\mathcal{D}
argj[n](x(j)=o2)J2(w2(Tr),x),(x,y=1)𝒟\underset{j\in[n]}{\text{arg}}(x^{(j)}=o_{2})\in J_{2}(w_{2}^{(T_{r})},x),\quad\forall(x,y=-1)\sim\mathcal{D}
Proof.

The proof follows directly from Definition D.3 and Lemma E.3. ∎

Appendix E Lemmas Used to Prove the Lemma 4.1

We denote,
ws(t)𝔼[r(w1,w2)]:=𝔼𝒟[r(w1(t),w2(t))ws(t)]\nabla_{w_{s}^{(t)}}\mathbb{E}[\ell_{r}(w_{1},w_{2})]:=\mathbb{E}_{\mathcal{D}}\left[\cfrac{\partial\ell_{r}(w^{(t)}_{1},w_{2}^{(t)})}{\partial w_{s}^{(t)}}\right] where ws(t){w1(t),w2(t)}w_{s}^{(t)}\in\{w_{1}^{(t)},w_{2}^{(t)}\} for all t[Tr]t\in[T_{r}].

Lemma E.1.

At any iteration tTrt\leq T_{r} of the Step-2 of Algorithm 1,

$\nabla_{w_{1}^{(t)}}\mathbb{E}[\ell_{r}(w_{1},w_{2})]=-\frac{1}{2}\left(o_{1}-o_{2}\right)$, and $\nabla_{w_{2}^{(t)}}\mathbb{E}[\ell_{r}(w_{1},w_{2})]=-\frac{1}{2}\left(o_{2}-o_{1}\right)$

Proof.

As, r(w1(t),w2(t))=1Br(x,y)tryw1(t)w2(t),j=1nx(j)\ell_{r}(w_{1}^{(t)},w_{2}^{(t)})=-\cfrac{1}{B_{r}}\hskip 2.84544pt\sum_{(x,y)\in\mathcal{B}_{t}^{r}}y\langle w^{(t)}_{1}-w_{2}^{(t)},\sum_{j=1}^{n}x^{(j)}\rangle,

w1(t)𝔼[lr(w1,w2)]=𝔼𝒟[yj=1nx(j)]\nabla_{w_{1}^{(t)}}\mathbb{E}[l_{r}(w_{1},w_{2})]=-\mathbb{E}_{\mathcal{D}}\left[y\sum_{j=1}^{n}x^{(j)}\right] and w2(t)𝔼[lr(w1,w2)]=𝔼𝒟[yj=1nx(j)]\nabla_{w_{2}^{(t)}}\mathbb{E}[l_{r}(w_{1},w_{2})]=\mathbb{E}_{\mathcal{D}}\left[y\sum_{j=1}^{n}x^{(j)}\right]

Therefore,

w1(t)𝔼[lr(w1,w2)]=12𝔼𝒟|y=+1[j=1nx(j)|y=+1]+12𝔼𝒟|y=1[j=1nx(j)|y=1]\displaystyle\nabla_{w_{1}^{(t)}}\mathbb{E}[l_{r}(w_{1},w_{2})]=-\frac{1}{2}\mathbb{E}_{\mathcal{D}|y=+1}\left[\sum_{j=1}^{n}x^{(j)}|y=+1\right]+\frac{1}{2}\mathbb{E}_{\mathcal{D}|y=-1}\left[\sum_{j=1}^{n}x^{(j)}|y=-1\right]
=12𝔼𝒟|y=+1[j[n]/arg 𝑗x(j)=o1x(j)|y=+1]+12𝔼𝒟|y=1[j[n]/arg 𝑗x(j)=o2x(j)|y=1]\displaystyle=-\frac{1}{2}\mathbb{E}_{\mathcal{D}|y=+1}\left[\sum_{j\in[n]/\underset{j}{\text{arg }}x^{(j)}=o_{1}}x^{(j)}|y=+1\right]+\frac{1}{2}\mathbb{E}_{\mathcal{D}|y=-1}\left[\sum_{j\in[n]/\underset{j}{\text{arg }}x^{(j)}=o_{2}}x^{(j)}|y=-1\right]
12(o1o2)\displaystyle\hskip 9.95863pt-\frac{1}{2}\left(o_{1}-o_{2}\right)
=12(o1o2)\displaystyle=-\frac{1}{2}\left(o_{1}-o_{2}\right)

where the last equality comes from the fact that class-irrelevant patterns are distributed identically in both classes. Using a similar line of argument, we can show that $\nabla_{w_{2}^{(t)}}\mathbb{E}[\ell_{r}(w_{1},w_{2})]=-\frac{1}{2}\left(o_{2}-o_{1}\right)$. ∎
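The identity of Lemma E.1 can be verified numerically under the data model it relies on (equiprobable classes, one class-discriminative patch per sample, class-irrelevant patches drawn identically for both classes); the generator below is a hedged illustration of such a distribution, not the paper's exact $\mathcal{D}$.

```python
import numpy as np

rng = np.random.default_rng(1)
d, n, N = 16, 5, 100_000

# Orthonormal o_1, o_2 and a pool of class-irrelevant patterns shared by both classes.
Q, _ = np.linalg.qr(rng.normal(size=(d, d)))
o1, o2, irrelevant = Q[:, 0], Q[:, 1], Q[:, 2:8].T      # 6 class-irrelevant unit patterns

y = rng.choice([1.0, -1.0], size=N)                     # equiprobable labels
disc = np.where(y[:, None] > 0, o1, o2)                 # the class-discriminative patch
irr = irrelevant[rng.integers(0, 6, size=(N, n - 1))]   # (n - 1) class-irrelevant patches
grad_w1 = (-y[:, None] * (disc + irr.sum(axis=1))).mean(axis=0)  # Monte-Carlo gradient of (11) w.r.t. w_1

# Should be close to -(o_1 - o_2)/2, since the class-irrelevant parts cancel in expectation.
print(np.linalg.norm(grad_w1 + 0.5 * (o1 - o2)))
```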

Lemma E.2.

With probability 11poly(n)1-\frac{1}{poly(n)} (i.e., w.h.p.) over the random initialization of the gating kernels defined in (7), ws(0)1n2\left|\left|w_{s}^{(0)}\right|\right|\leq\frac{1}{n^{2}}; s{1,2}\forall s\in\{1,2\}

Proof.

Let us denote the ii-th element of the vector ws(0)w_{s}^{(0)} as wsi(0)w_{s_{i}}^{(0)} where i[d]i\in[d].
Then according to the random initialization of ws(0)w_{s}^{(0)} and using a Gaussian tail-bound (i.e., for X𝒩(0,σ2):Pr[|X|t]2et2/2σ2X\sim\mathcal{N}(0,\sigma^{2}):Pr[|X|\geq t]\leq 2e^{-t^{2}/2\sigma^{2}}): [|wsi(0)|1n2d]1poly(n)\mathbb{P}\left[\left|w_{s_{i}}^{(0)}\right|\geq\frac{1}{n^{2}\sqrt{d}}\right]\leq\frac{1}{poly(n)}.
Let us denote the event $\mathcal{E}:\forall i\in[d],\left|w_{s_{i}}^{(0)}\right|\leq\frac{1}{n^{2}\sqrt{d}}$. Therefore, by a union bound over the $d$ coordinates, $\mathbb{P}\left[\mathcal{E}\right]\geq 1-\frac{1}{\text{poly}(n)}$.
Now, conditioned on the event ,ws(0)1n2\mathcal{E},\left|\left|w_{s}^{(0)}\right|\right|\leq\frac{1}{n^{2}}.
Therefore, $\mathbb{P}\left[\left|\left|w_{s}^{(0)}\right|\right|\leq\frac{1}{n^{2}}\right]\geq\mathbb{P}\left[\left|\left|w_{s}^{(0)}\right|\right|\leq\frac{1}{n^{2}}\,\Big{|}\,\mathcal{E}\right]\mathbb{P}\left[\mathcal{E}\right]=\mathbb{P}\left[\mathcal{E}\right]\geq 1-\frac{1}{\text{poly}(n)}$.

Lemma E.3.

W.h.p. over the random initialization of the gating-kernels defined in (7) and randomly selected batch of batch-size Br=Ω~(n2(1δd)2)B_{r}=\tilde{\Omega}\left(\cfrac{n^{2}}{(1-\delta_{d})^{2}}\right) at each iteration, after Tr=Ω(11δd)T_{r}=\Omega\left(\cfrac{1}{1-\delta_{d}}\right) iterations of Step-2 of Algorithm 1 with learning rate ηr=Θ(1n)\eta_{r}=\Theta\left(\frac{1}{n}\right), (x,y)𝒟,j[n]:x(j)𝒮c(1δd2)\forall(x,y)\sim\mathcal{D},j\in[n]:x^{(j)}\not\in\mathcal{S}_{c}(\frac{1-\delta_{d}}{2}), w1(Tr),o1>w1(Tr),x(j)\langle w_{1}^{(T_{r})},o_{1}\rangle>\langle w_{1}^{(T_{r})},x^{(j)}\rangle and w2(Tr),o2>w2(Tr),x(j)\langle w_{2}^{(T_{r})},o_{2}\rangle>\langle w_{2}^{(T_{r})},x^{(j)}\rangle.

Proof.

Let, at tt-th iteration of Step-2 of Algorithm 1, ~ws(t)=r(w1(t),w2(t))ws(t)\tilde{\nabla}_{w_{s}}^{(t)}=\cfrac{\partial\ell_{r}(w_{1}^{(t)},w_{2}^{(t)})}{\partial w_{s}^{(t)}} for all s{1,2}s\in\{1,2\}

Also let us denote, ws(t)𝔼[r(w1(t),w2(t))]=ws(t)\nabla_{w_{s}^{(t)}}\mathbb{E}\left[\ell_{r}(w_{1}^{(t)},w_{2}^{(t)})\right]=\nabla_{w_{s}}^{(t)} for all s{1,2}s\in\{1,2\}

Therefore, after TrT_{r}-th iteration of SGD and using Lemma E.1,

w1(Tr)\displaystyle w_{1}^{(T_{r})} =w1(0)ηrt=0Tr1~w1(t)\displaystyle=w_{1}^{(0)}-\eta_{r}\overset{T_{r}-1}{\underset{t=0}{\sum}}\tilde{\nabla}_{w_{1}}^{(t)}
=w1(0)+ηrTr2(o1o2)ηrt=0Tr1(~w1(t)w1(t))\displaystyle=w_{1}^{(0)}+\cfrac{\eta_{r}T_{r}}{2}\left(o_{1}-o_{2}\right)-\eta_{r}\overset{T_{r}-1}{\underset{t=0}{\sum}}\left(\tilde{\nabla}_{w_{1}}^{(t)}-\nabla_{w_{1}}^{(t)}\right)

Similarly, w2(Tr)=w2(0)+ηrTr2(o2o1)ηrt=0Tr1(~w2(t)w2(t))w_{2}^{(T_{r})}=w_{2}^{(0)}+\cfrac{\eta_{r}T_{r}}{2}\left(o_{2}-o_{1}\right)-\eta_{r}\overset{T_{r}-1}{\underset{t=0}{\sum}}\left(\tilde{\nabla}_{w_{2}}^{(t)}-\nabla_{w_{2}}^{(t)}\right).

Now, ~ws(t)=O(n)||\tilde{\nabla}_{w_{s}}^{(t)}||=O(n). Hence, w.h.p. over a randomly sampled batch of size BrB_{r}, using Hoeffding’s concentration,

~ws(t)ws(t)=O~(nBr);s{1,2}||\tilde{\nabla}_{w_{s}}^{(t)}-\nabla_{w_{s}}^{(t)}||=\tilde{O}\left(\cfrac{n}{\sqrt{B_{r}}}\right);\forall s\in\{1,2\}.

Now,

w1(Tr),o1\displaystyle\langle w_{1}^{(T_{r})},o_{1}\rangle =w1(0),o1+ηrTr2(1o1,o2)ηrt=0Tr1~w1(t)w1(t),o1\displaystyle=\langle w_{1}^{(0)},o_{1}\rangle+\cfrac{\eta_{r}T_{r}}{2}\left(1-\langle o_{1},o_{2}\rangle\right)-\eta_{r}\overset{T_{r}-1}{\underset{t=0}{\sum}}\langle\tilde{\nabla}_{w_{1}}^{(t)}-\nabla_{w_{1}}^{(t)},o_{1}\rangle
ηrTr2(1δd)ηrTrO~(nBr)w1(0)\displaystyle\geq\cfrac{\eta_{r}T_{r}}{2}\left(1-\delta_{d}\right)-\eta_{r}T_{r}\tilde{O}\left(\cfrac{n}{\sqrt{B_{r}}}\right)-\left|\left|w_{1}^{(0)}\right|\right|

On the other hand, (x,y)𝒟,j[n]:x(j)𝒮c(1δd2)\forall(x,y)\sim\mathcal{D},\forall j\in[n]:x^{(j)}\not\in\mathcal{S}_{c}\left(\frac{1-\delta_{d}}{2}\right),

w1(Tr),x(j)\displaystyle\langle w_{1}^{(T_{r})},x^{(j)}\rangle =w1(0),x(j)+ηrTr2(o1,x(j)o2,x(j))ηrt=0Tr1~w1(t)w1,x(j)\displaystyle=\langle w_{1}^{(0)},x^{(j)}\rangle+\cfrac{\eta_{r}T_{r}}{2}\left(\langle o_{1},x^{(j)}\rangle-\langle o_{2},x^{(j)}\rangle\right)-\eta_{r}\overset{T_{r}-1}{\underset{t=0}{\sum}}\langle\tilde{\nabla}_{w_{1}}^{(t)}-\nabla_{w_{1}},x^{(j)}\rangle
ηrTr4(1δd)+ηrTrO~(nBr)+w1(0)\displaystyle\leq\cfrac{\eta_{r}T_{r}}{4}(1-\delta_{d})+\eta_{r}T_{r}\tilde{O}\left(\cfrac{n}{\sqrt{B_{r}}}\right)+\left|\left|w_{1}^{(0)}\right|\right|

From Lemma E.2, w.h.p. over the random initialization: w1(0)1n2\left|\left|w_{1}^{(0)}\right|\right|\leq\cfrac{1}{n^{2}}.

Therefore, selecting $B_{r}=\tilde{\Omega}\left(\frac{n^{2}}{(1-\delta_{d})^{2}}\right)$ and $\eta_{r}=\Theta\left(\frac{1}{n}\right)$, we need $T_{r}=\Omega\left(\frac{1}{1-\delta_{d}}\right)$ iterations to achieve $\langle w_{1}^{(T_{r})},o_{1}\rangle>\langle w_{1}^{(T_{r})},x^{(j)}\rangle$ for all $j\in[n]$ such that $x^{(j)}\not\in\mathcal{S}_{c}\left(\frac{1-\delta_{d}}{2}\right)$.

A similar line of argument shows that, with batch size $B_{r}=\tilde{\Omega}\left(\frac{n^{2}}{(1-\delta_{d})^{2}}\right)$ and learning rate $\eta_{r}=\Theta\left(\frac{1}{n}\right)$, after $T_{r}=\Omega\left(\frac{1}{1-\delta_{d}}\right)$ iterations, $\langle w_{2}^{(T_{r})},o_{2}\rangle>\langle w_{2}^{(T_{r})},x^{(j)}\rangle$ for all $j\in[n]$ such that $x^{(j)}\not\in\mathcal{S}_{c}\left(\frac{1-\delta_{d}}{2}\right)$. ∎

Appendix F Proofs of the Theorem 4.2, 4.3 and 4.5

Definition F.1.

At any iteration tt of the minibatch SGD,

  1. 1.

    Define the value function, v(t)(θ(t),x,y):=11+eyfM(θ(t),x)\displaystyle v^{(t)}(\theta^{(t)},x,y):=\frac{1}{1+e^{yf_{M}(\theta^{(t)},x)}}. It is easy to show that for any (x,y)𝒟(x,y)\sim\mathcal{D}, 0v(t)(θ(t),x,y)10\leq v^{(t)}(\theta^{(t)},x,y)\leq 1. The function captures the prediction error, i.e., a larger v(t)v^{(t)} indicates a larger prediction error.

  2. 2.

    Define, the class-conditional expected value function, v1(t):=𝔼𝒟|y=+1[v(t)(θ(t),x,y)|y=+1]v_{1}^{(t)}:=\mathbb{E}_{\mathcal{D}|y=+1}[v^{(t)}(\theta^{(t)},x,y)|y=+1] and v2(t):=𝔼𝒟|y=1[v(t)(θ(t),x,y)|y=1]v_{2}^{(t)}:=\mathbb{E}_{\mathcal{D}|y=-1}[v^{(t)}(\theta^{(t)},x,y)|y=-1]. Here, v1(t)v_{1}^{(t)} captures the expected error for the class with label y=+1y=+1 and v2(t)v_{2}^{(t)} captures the expected error for the class with label y=1y=-1.
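A small numpy sketch of these two definitions, assuming the model outputs $f_{M}(\theta^{(t)},x_{i})$ are already available in an array (the placeholder values and names below are illustrative):

```python
import numpy as np

def value_fn(f, y):
    # v(theta, x, y) = 1 / (1 + exp(y * f_M(theta, x))); lies in [0, 1], larger value = larger error.
    return 1.0 / (1.0 + np.exp(y * f))

f = np.array([2.3, -0.4, 1.1, -3.0])     # placeholder outputs f_M(theta^(t), x_i)
y = np.array([1.0, 1.0, -1.0, -1.0])

v = value_fn(f, y)
v1 = v[y == 1.0].mean()                  # class-conditional expected error for y = +1
v2 = v[y == -1.0].mean()                 # class-conditional expected error for y = -1
print(v, v1, v2)
```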

Definition F.2.

At any iteration tt of the minibatch SGD,

  1. 1.

    For any sample (x,y)𝒟(x,y)\sim\mathcal{D}, we define the reduction of loss at the tt-th iteration of SGD as,

    ΔL(θ(t),θ(t+1),x,y):=(θ(t),x,y)(θ(t+1),x,y)\displaystyle\Delta L(\theta^{(t)},\theta^{(t+1)},x,y):=\mathcal{L}(\theta^{(t)},x,y)-\mathcal{L}(\theta^{(t+1)},x,y)

    where, (θ(t),x,y):=log(1+eyfM(θ(t),x))\mathcal{L}(\theta^{(t)},x,y):=\log(1+e^{-yf_{M}(\theta^{(t)},x)}) is the single-sample loss function.

  2. 2.

    Define the expected reduction of loss at the tt-th iteration of SGD as,

    ΔL(θ(t),θ(t+1)):=𝔼𝒟[(θ(t),x,y)(θ(t+1),x,y)]\displaystyle\Delta L(\theta^{(t)},\theta^{(t+1)}):=\mathbb{E}_{\mathcal{D}}\left[\mathcal{L}(\theta^{(t)},x,y)-\mathcal{L}(\theta^{(t+1)},x,y)\right]
Theorem F.3.

(Full version of Theorem 4.2) For every ϵ>0\epsilon>0 and lll\geq l^{*}, for every mMS=Ω~(l10p12δ6/ϵ16)m\geq M_{S}=\tilde{\Omega}\left(l^{10}p^{12}\delta^{6}\big{/}\epsilon^{16}\right) with at least NS=Ω~(l8p12δ6/ϵ16)N_{S}=\tilde{\Omega}(l^{8}p^{12}\delta^{6}/\epsilon^{16}) training samples, after performing minibatch SGD with the batch size B=Ω~(l4p6δ3/ϵ8)B=\tilde{\Omega}\left(l^{4}p^{6}\delta^{3}\big{/}\epsilon^{8}\right) and the learning rate η=O~(1/mpoly(l,p,δ,1/ϵ,logm))\eta=\tilde{O}\big{(}1\big{/}m\textrm{poly}(l,p,\delta,1/\epsilon,\log m)\big{)} for T=O~(l4p6δ3/ϵ8)T=\tilde{O}\left(l^{4}p^{6}\delta^{3}\big{/}\epsilon^{8}\right) iterations, it holds w.h.p. that

(x,y)𝒟[yfM(θ(T),x)>0]1ϵ\underset{(x,y)\sim\mathcal{D}}{\mathbb{P}}\left[yf_{M}(\theta^{(T)},x)>0\right]\geq 1-\epsilon

Proof.

First we will show that for any ϵ<12\epsilon<\frac{1}{2}, if (x,y)𝒟[yf(θ(t),x)>0]1ϵ\mathbb{P}_{(x,y)\sim\mathcal{D}}\left[yf(\theta^{(t)},x)>0\right]\leq 1-\epsilon, then max{v1(t),v2(t)}ϵ2\max\{v_{1}^{(t)},v_{2}^{(t)}\}\geq\epsilon^{2}.

Now, for any $(x,y)\sim\mathcal{D}$ and $\epsilon<\frac{1}{2}$, if $v^{(t)}(\theta^{(t)},x,y)\leq\epsilon$, then $yf_{M}(\theta^{(t)},x)>0$, i.e., the prediction is correct.

Now if v1(t)=𝔼𝒟|y=+1[v(t)(θ(t),x,y)|y=+1]ϵ2v_{1}^{(t)}=\mathbb{E}_{\mathcal{D}|y=+1}\left[v^{(t)}(\theta^{(t)},x,y)\big{|}y=+1\right]\leq\epsilon^{2}, then using Markov’s inequality 𝒟|y=+1[v(t)(θ(t),x,y)ϵ]1ϵ\mathbb{P}_{\mathcal{D}|y=+1}\left[v^{(t)}(\theta^{(t)},x,y)\leq\epsilon\right]\geq 1-\epsilon which implies for any ϵ<12\epsilon<\frac{1}{2}, 𝒟|y=+1[yf(θ(t),x)>0]1ϵ\mathbb{P}_{\mathcal{D}|y=+1}\left[yf(\theta^{(t)},x)>0\right]\geq 1-\epsilon.

Similarly, if v2(t)=𝔼𝒟|y=1[v(t)(θ(t),x,y)|y=1]ϵ2v_{2}^{(t)}=\mathbb{E}_{\mathcal{D}|y=-1}\left[v^{(t)}(\theta^{(t)},x,y)\big{|}y=-1\right]\leq\epsilon^{2}, for any ϵ<12\epsilon<\frac{1}{2}, 𝒟|y=1[yf(θ(t),x)>0]1ϵ\mathbb{P}_{\mathcal{D}|y=-1}\left[yf(\theta^{(t)},x)>0\right]\geq 1-\epsilon.

Therefore, for any ϵ<12\epsilon<\frac{1}{2}, if (x,y)𝒟[yf(θ(t),x)>0]1ϵ\mathbb{P}_{(x,y)\sim\mathcal{D}}\left[yf(\theta^{(t)},x)>0\right]\leq 1-\epsilon, then max{v1(t),v2(t)}ϵ2\max\{v_{1}^{(t)},v_{2}^{(t)}\}\geq\epsilon^{2}.

Now, if $v^{(t)}:=\sqrt{\sum_{s\in\{1,2\}}(v_{s}^{(t)})^{2}}\leq\epsilon^{2}$, then $\max\{v_{1}^{(t)},v_{2}^{(t)}\}\leq\epsilon^{2}$. Hence, if $v^{(T)}\leq\epsilon^{2}$ after a suitable number of iterations $T$, then $\underset{(x,y)\sim\mathcal{D}}{\mathbb{P}}\left[yf_{M}(\theta^{(T)},x)>0\right]\geq 1-\epsilon$.

Suppose $v^{(t)}\geq\epsilon^{2}$. Then, using Lemma G.4, for every $l\geq l^{*}$, with $\eta=\tilde{O}\left(\frac{\epsilon^{4}}{ml^{2}p^{3}\delta^{3/2}}\right)$ and $B=\tilde{\Omega}\left(\frac{l^{4}p^{6}\delta^{3}}{\epsilon^{8}}\right)$, for at least $t=\tilde{O}\left(\frac{\sigma\epsilon^{4}}{\eta l^{3}p^{3}\delta^{3/2}}\right)$ iterations we have

ΔL(θ(t),θ(t+1))=Ω~(ηmϵ4l2p3δ3/2)\Delta L(\theta^{(t)},\theta^{(t+1)})=\tilde{\Omega}\left(\cfrac{\eta m\epsilon^{4}}{l^{2}p^{3}\delta^{3/2}}\right) (12)

Now, as $w_{r,s}^{(0)}\sim\mathcal{N}(0,\sigma^{2}\mathbb{I}_{d\times d})$ with $\sigma=\frac{1}{\sqrt{m}}$, $\langle w_{r,s}^{(0)},x^{(j)}\rangle\sim\mathcal{N}(0,\sigma^{2})$ for all $j\in J_{s}(w_{s}^{(0)},x)$ and all $(x,y)\sim\mathcal{D}$. Therefore, w.h.p. $\left|f_{M}(\theta^{(0)},x)\right|=\tilde{O}(1)$, which implies $\mathcal{L}(\theta^{(0)},x,y)=\tilde{O}(1)$. Since $\mathcal{L}(\theta^{(t)},x,y)>0$, (12) can hold for at most $\tilde{O}\left(\frac{l^{2}p^{3}\delta^{3/2}}{\eta m\epsilon^{4}}\right)$ iterations. As $\eta m=\tilde{O}\left(\frac{\epsilon^{4}}{l^{2}p^{3}\delta^{3/2}}\right)$, we need $T=\tilde{O}\left(\frac{l^{4}p^{6}\delta^{3}}{\epsilon^{8}}\right)$ iterations to ensure that $v^{(t)}\leq\epsilon^{2}$.

On the other hand, to ensure (12) hold for TT iterations, we need,

σϵ4ηl3p3δ3/2=Ω~(l2p3δ3/2ηmϵ4)\displaystyle\cfrac{\sigma\epsilon^{4}}{\eta l^{3}p^{3}\delta^{3/2}}=\tilde{\Omega}\left(\cfrac{l^{2}p^{3}\delta^{3/2}}{\eta m\epsilon^{4}}\right)

which, using $\sigma=\frac{1}{\sqrt{m}}$, gives $\sqrt{m}=\tilde{\Omega}\left(\frac{l^{5}p^{6}\delta^{3}}{\epsilon^{8}}\right)$, i.e., $m=\tilde{\Omega}\left(\frac{l^{10}p^{12}\delta^{6}}{\epsilon^{16}}\right)$. ∎

Now, for any $(x,y=+1)\sim\mathcal{D}$ and $(x,y=-1)\sim\mathcal{D}$, let us denote the indices of the class-discriminative patterns $o_{1}$ and $o_{2}$ by $j_{o_{1}}$ and $j_{o_{2}}$, respectively.

Definition F.4.

At any iteration tt of minibatch SGD of the joint-training pMoE (i.e., Step-2 of Algorithm 2),

  1. 1.

    For any (x,y=+1)𝒟(x,y=+1)\sim\mathcal{D} and the expert s[k]s\in[k], define the event that o1o_{1} in Top-ll as, 1,s(t):jo1Js(ws(t),x)\mathcal{E}_{1,s}^{(t)}:j_{o_{1}}\in J_{s}(w_{s}^{(t)},x). Similarly, for any (x,y=1)𝒟(x,y=-1)\sim\mathcal{D} define the event that o2o_{2} in Top-ll as, 2,s(t):jo2Js(ws(t),x)\mathcal{E}_{2,s}^{(t)}:j_{o_{2}}\in J_{s}(w_{s}^{(t)},x).

  2. 2.

    For any expert s[k]s\in[k], define the probability of the event that o1o_{1} in Top-ll as, p1,s(t):=𝒟|y=+1[1,s(t)|y=+1]p_{1,s}^{(t)}:=\mathbb{P}_{\mathcal{D}|y=+1}\left[\mathcal{E}_{1,s}^{(t)}\big{|}y=+1\right] and the probability of the event that o2o_{2} in Top-ll as, p2,s(t):=𝒟|y=1[2,s(t)|y=1]p_{2,s}^{(t)}:=\mathbb{P}_{\mathcal{D}|y=-1}\left[\mathcal{E}_{2,s}^{(t)}\big{|}y=-1\right]

  3. 3.

    For any expert s[k]s\in[k] define, v1,s(t):=𝔼𝒟|y=+1,1,s(t)[p1,s(t)Gjo1,s(t)(x)v(t)(θ(t),x,y)|y=+1,1,s(t)]v_{1,s}^{(t)}:=\mathbb{E}_{\mathcal{D}|y=+1,\mathcal{E}_{1,s}^{(t)}}\left[p_{1,s}^{(t)}G_{j_{o_{1}},s}^{(t)}(x)v^{(t)}(\theta^{(t)},x,y)\big{|}y=+1,\mathcal{E}_{1,s}^{(t)}\right] and v2,s(t):=𝔼𝒟|y=1,2,s(t)[p2,s(t)Gjo2,s(t)(x)v(t)(θ(t),x,y)|y=1,2,s(t)]v_{2,s}^{(t)}:=\mathbb{E}_{\mathcal{D}|y=-1,\mathcal{E}_{2,s}^{(t)}}\left[p_{2,s}^{(t)}G_{j_{o_{2}},s}^{(t)}(x)v^{(t)}(\theta^{(t)},x,y)\big{|}y=-1,\mathcal{E}_{2,s}^{(t)}\right] where Gjo1,s(t)(x)G_{j_{o_{1}},s}^{(t)}(x) and Gjo2,s(t)(x)G_{j_{o_{2}},s}^{(t)}(x) denote the gating value for the class-discriminative patterns o1o_{1} and o2o_{2} conditioned on 1,s(t)\mathcal{E}_{1,s}^{(t)} and 2,s(t)\mathcal{E}_{2,s}^{(t)}, respectively.

Theorem F.5.

(Full version of Theorem 4.5) Suppose Assumption 4.4 holds. Then for every $\epsilon>0$, for every $m\geq M_{J}=\tilde{\Omega}\left(k^{3}n^{2}l^{6}p^{12}\delta^{6}/\epsilon^{16}\right)$, with at least $N_{J}=\tilde{\Omega}(k^{4}l^{6}p^{12}\delta^{6}/\epsilon^{16})$ training samples, after performing minibatch SGD with batch size $B=\tilde{\Omega}\left(k^{2}l^{4}p^{6}\delta^{3}/\epsilon^{8}\right)$ and learning rate $\eta=\tilde{O}\left(1/(m\,\text{poly}(l,p,\delta,1/\epsilon,\log m))\right)$ for $T=\tilde{O}\left(k^{2}l^{2}p^{6}\delta^{3}/\epsilon^{8}\right)$ iterations, it holds w.h.p. that

(x,y)𝒟[yfM(θ(T),x)>0]1ϵ\underset{(x,y)\sim\mathcal{D}}{\mathbb{P}}\left[yf_{M}(\theta^{(T)},x)>0\right]\geq 1-\epsilon

Proof.

From the argument of the proof of Theorem F.3, we know that for any ϵ<12\epsilon<\frac{1}{2}, if (x,y)𝒟[yf(θ(t),x)>0]1ϵ\mathbb{P}_{(x,y)\sim\mathcal{D}}\left[yf(\theta^{(t)},x)>0\right]\leq 1-\epsilon, then max{v1(t),v2(t)}ϵ2\max\{v_{1}^{(t)},v_{2}^{(t)}\}\geq\epsilon^{2} where v1(t):=𝔼𝒟|y=+1[v(t)(θ(t),x,y)|y=+1]v_{1}^{(t)}:=\mathbb{E}_{\mathcal{D}|y=+1}[v^{(t)}(\theta^{(t)},x,y)|y=+1] and v2(t):=𝔼𝒟|y=1[v(t)(θ(t),x,y)|y=1]v_{2}^{(t)}:=\mathbb{E}_{\mathcal{D}|y=-1}[v^{(t)}(\theta^{(t)},x,y)|y=-1]

Now, we will consider the case when tTt\geq T^{\prime} where TT^{\prime} is defined in Assumption 4.4.

Now, if the expert s1[k]s_{1}\in[k] satisfies Assumption 4.4 for y=+1y=+1, then p1,s1(t)=1p_{1,s_{1}}^{(t)}=1 and Gjo1,s1(t)(x)1lG_{j_{o_{1}},s_{1}}^{(t)}(x)\geq\cfrac{1}{l} for any (x,y=+1)𝒟(x,y=+1)\sim\mathcal{D}. Therefore, v1,s1(t)v1(t)lv_{1,s_{1}}^{(t)}\geq\cfrac{v_{1}^{(t)}}{l}.

Similarly, if the expert s2[k]s_{2}\in[k] satisfies Assumption 4.4 for y=1y=-1, then v2,s2(t)v2(t)lv_{2,s_{2}}^{(t)}\geq\cfrac{v_{2}^{(t)}}{l}.

Now for any expert s[k]s\in[k], let us define vs(t):=max{v1,s(t),v2,s(t)}v_{s}^{(t)}:=\max\{v_{1,s}^{(t)},v_{2,s}^{(t)}\}

Now, if v(t):=s[k](vs(t))2ϵ2lv^{(t)}:=\sqrt{\underset{s\in[k]}{\sum}(v_{s}^{(t)})^{2}}\leq\cfrac{\epsilon^{2}}{l}, then vs1(t)ϵ2lv_{s_{1}}^{(t)}\leq\cfrac{\epsilon^{2}}{l} and vs2(t)ϵ2lv_{s_{2}}^{(t)}\leq\cfrac{\epsilon^{2}}{l}.

This implies, max{v1,s1(t),v2,s1(t)}ϵ2l\max\{v_{1,s_{1}}^{(t)},v_{2,s_{1}}^{(t)}\}\leq\cfrac{\epsilon^{2}}{l} and max{v1,s2(t),v2,s2(t)}ϵ2l\max\{v_{1,s_{2}}^{(t)},v_{2,s_{2}}^{(t)}\}\leq\cfrac{\epsilon^{2}}{l}.

Therefore, v1,s1(t)ϵ2lv_{1,s_{1}}^{(t)}\leq\cfrac{\epsilon^{2}}{l} and v2,s2(t)ϵ2lv_{2,s_{2}}^{(t)}\leq\cfrac{\epsilon^{2}}{l} which implies v1(t)ϵ2v_{1}^{(t)}\leq\epsilon^{2} and v2(t)ϵ2v_{2}^{(t)}\leq\epsilon^{2}.

In that case, max{v1(t),v2(t)}ϵ2\max\{v_{1}^{(t)},v_{2}^{(t)}\}\leq\epsilon^{2}.

Therefore, by taking $v^{(t)}\geq\frac{\epsilon^{2}}{l}$, using the results of Lemma H.4, and following the same procedure as in the proof of Theorem F.3, we can complete the proof. ∎

Theorem F.6.

(Full version of the Theorem 4.3) For every ϵ>0\epsilon>0, for every mMC=Ω~(n10p12δ6/ϵ16)m\geq M_{C}=\tilde{\Omega}\left(n^{10}p^{12}\delta^{6}\big{/}\epsilon^{16}\right) with at least NC=Ω~(n8p12δ6/ϵ16)N_{C}=\tilde{\Omega}(n^{8}p^{12}\delta^{6}/\epsilon^{16}) training samples, after performing minibatch SGD with the batch size B=Ω~(n4p6δ3/ϵ8)B=\tilde{\Omega}\left(n^{4}p^{6}\delta^{3}\big{/}\epsilon^{8}\right) and the learning rate η=O~(1/mpoly(n,p,δ,1/ϵ,logm))\eta=\tilde{O}\big{(}1\big{/}m\textrm{poly}(n,p,\delta,1/\epsilon,\log m)\big{)} for T=O~(n4p6δ3/ϵ8)T=\tilde{O}\left(n^{4}p^{6}\delta^{3}\big{/}\epsilon^{8}\right) iterations, it holds w.h.p. that

(x,y)𝒟[yfC(θ(T),x)>0]1ϵ\underset{(x,y)\sim\mathcal{D}}{\mathbb{P}}\left[yf_{C}(\theta^{(T)},x)>0\right]\geq 1-\epsilon

Proof.

From the argument of the proof of Theorem F.3, we know that for any ϵ<12\epsilon<\frac{1}{2}, if (x,y)𝒟[yf(θ(t),x)>0]1ϵ\mathbb{P}_{(x,y)\sim\mathcal{D}}\left[yf(\theta^{(t)},x)>0\right]\leq 1-\epsilon, then v(t):=max{v1(t),v2(t)}ϵ2v^{(t)}:=\max\{v_{1}^{(t)},v_{2}^{(t)}\}\geq\epsilon^{2} where v1(t):=𝔼𝒟|y=+1[v(t)(θ(t),x,y)|y=+1]v_{1}^{(t)}:=\mathbb{E}_{\mathcal{D}|y=+1}[v^{(t)}(\theta^{(t)},x,y)|y=+1] and v2(t):=𝔼𝒟|y=1[v(t)(θ(t),x,y)|y=1]v_{2}^{(t)}:=\mathbb{E}_{\mathcal{D}|y=-1}[v^{(t)}(\theta^{(t)},x,y)|y=-1].

Therefore, taking $v^{(t)}\geq\epsilon^{2}$, using the results of Lemma I.3, and following a similar procedure as in the proof of Theorem F.3, we can complete the proof. ∎

Appendix G Lemmas Used to Prove the Theorem 4.2

For any iteration $t$ of Step-3 of Algorithm 1, recall the loss function for a single sample drawn from the distribution $\mathcal{D}$: $\mathcal{L}(\theta^{(t)},x,y):=\log(1+e^{-yf_{M}(\theta^{(t)},x)})$. The gradient of this loss with respect to the hidden nodes of the experts is:

(θ(t),x,y)wr,s(t)=yar,sv(t)(θ(t),x,y)(1ljJs(ws(t),x)x(j)1wr,s(t),x(j)0)\frac{\partial\mathcal{L}(\theta^{(t)},x,y)}{\partial w_{r,s}^{(t)}}=-ya_{r,s}v^{(t)}(\theta^{(t)},x,y)\left(\cfrac{1}{l}\underset{j\in J_{s}(w_{s}^{(t)},x)}{\sum}x^{(j)}1_{\langle w_{r,s}^{(t)},x^{(j)}\rangle\geq 0}\right) (13)
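For intuition, the following is a minimal numpy sketch of the per-sample gradient in (13); the patches, weights, and scalars are synthetic placeholders. The pseudo-gradient defined next in (14) is obtained by evaluating the ReLU indicator at the initialization $w_{r,s}^{(0)}$ instead of the current $w_{r,s}^{(t)}$.

```python
import numpy as np

rng = np.random.default_rng(0)
d, l = 16, 4
x_top = rng.normal(size=(l, d))                    # the l patches in J_s(w_s^(t), x), already routed
x_top /= np.linalg.norm(x_top, axis=1, keepdims=True)

w_rs_t = rng.normal(size=d) * 0.1                  # current hidden weight w_{r,s}^(t)
a_rs, y, v = 1.0, 1.0, 0.3                         # output weight a_{r,s}, label y, value function v^(t)

def hidden_grad(w_for_indicator):
    # -y * a_{r,s} * v^(t) * (1/l) * sum_{j in J_s} x^(j) * 1[<w, x^(j)> >= 0]
    active = (x_top @ w_for_indicator >= 0).astype(float)
    return -y * a_rs * v * (x_top * active[:, None]).mean(axis=0)

grad_13 = hidden_grad(w_rs_t)   # (13): indicator evaluated at the current weight w_{r,s}^(t)
# The pseudo-gradient (14) evaluates the same indicator at w_{r,s}^(0) instead.
```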

We define the corresponding pseudo-gradient as:

(θ(t),x,y)wr,s(t)=yar,sv(t)(θ(t),x,y)(1ljJs(ws(t),x)x(j)1wr,s(0),x(j)0)\frac{\overset{\sim}{\partial}\mathcal{L}(\theta^{(t)},x,y)}{\partial w_{r,s}^{(t)}}=-ya_{r,s}v^{(t)}(\theta^{(t)},x,y)\left(\cfrac{1}{l}\underset{j\in J_{s}(w_{s}^{(t)},x)}{\sum}x^{(j)}1_{\langle w_{r,s}^{(0)},x^{(j)}\rangle\geq 0}\right) (14)

Therefore, the expected pseudo-gradient:

^(θ(t))wr,s(t)\displaystyle\frac{\overset{\sim}{\partial}\hat{\mathcal{L}}(\theta^{(t)})}{\partial w_{r,s}^{(t)}} =𝔼𝒟[(θ(t),x,y)wr,s(t)]\displaystyle=\mathbb{E}_{\mathcal{D}}\left[\frac{\overset{\sim}{\partial}\mathcal{L}(\theta^{(t)},x,y)}{\partial w_{r,s}^{(t)}}\right]
=ar,s2(𝔼𝒟|y=+1[v(t)(θ(t),x,y)(1ljJs(ws(t),x)x(j)1wr,s(0),x(j)0)|y=+1]\displaystyle=-\cfrac{a_{r,s}}{2}\left(\mathbb{E}_{\mathcal{D}|y=+1}\left[v^{(t)}(\theta^{(t)},x,y)\left(\cfrac{1}{l}\underset{j\in J_{s}(w_{s}^{(t)},x)}{\sum}x^{(j)}1_{\langle w_{r,s}^{(0)},x^{(j)}\rangle\geq 0}\right)\Big{|}y=+1\right]\right.
𝔼𝒟|y=1[v(t)(θ(t),x,y)(1ljJs(ws(t),x)x(j)1wr,s(0),Pjx0)|y=1])\displaystyle\left.\hskip 42.67912pt-\mathbb{E}_{\mathcal{D}|y=-1}\left[v^{(t)}(\theta^{(t)},x,y)\left(\cfrac{1}{l}\underset{j\in J_{s}(w_{s}^{(t)},x)}{\sum}x^{(j)}1_{\langle w_{r,s}^{(0)},P_{j}x\rangle\geq 0}\right)\Big{|}y=-1\right]\right)
=ar,s2Pr,s(t)\displaystyle=-\cfrac{a_{r,s}}{2}P_{r,s}^{(t)}

Here,

Pr,s(t)\displaystyle P_{r,s}^{(t)} =𝔼𝒟|y=+1[v(t)(θ(t),x,y)(1ljJs(ws(t),x)x(j)1wr,s(0),x(j)0)|y=+1]\displaystyle=\mathbb{E}_{\mathcal{D}|y=+1}\left[v^{(t)}(\theta^{(t)},x,y)\left(\cfrac{1}{l}\underset{j\in J_{s}(w_{s}^{(t)},x)}{\sum}x^{(j)}1_{\langle w_{r,s}^{(0)},x^{(j)}\rangle\geq 0}\right)\Big{|}y=+1\right]
𝔼𝒟|y=1[v(t)(θ(t),x,y)(1ljJs(ws(t),x)x(j)1wr,s(0),x(j)0)|y=1]\displaystyle\hskip 42.67912pt-\mathbb{E}_{\mathcal{D}|y=-1}\left[v^{(t)}(\theta^{(t)},x,y)\left(\cfrac{1}{l}\underset{j\in J_{s}(w_{s}^{(t)},x)}{\sum}x^{(j)}1_{\langle w_{r,s}^{(0)},x^{(j)}\rangle\geq 0}\right)\Big{|}y=-1\right]
Lemma G.1.

W.h.p. over the random initialization of the hidden nodes of the experts defined in 8, for every (x,y)𝒟(x,y)\sim\mathcal{D} and for every τ>0\tau>0, for every t=O~(τη)t=\tilde{O}\left(\cfrac{\tau}{\eta}\right) of the Step-3 of Algorithm 1, we have that for at least (12eτlσ)\left(1-\cfrac{2e\tau l}{\sigma}\right) fraction of r[m/2]r\in[m/2] of the expert s{1,2}s\in\{1,2\}:

(θ(t),x,y)wr,s(t)=(θ(t),x,y)wr,s(t)\cfrac{\partial\mathcal{L}(\theta^{(t)},x,y)}{\partial w_{r,s}^{(t)}}=\cfrac{\overset{\sim}{\partial}\mathcal{L}(\theta^{(t)},x,y)}{\partial w_{r,s}^{(t)}} and |wr,s(t),x(j)|τ,jJs(ws(t),x)|\langle w_{r,s}^{(t)},x^{(j)}\rangle|\geq\tau,\forall j\in J_{s}(w_{s}^{(t)},x)



Proof.

Recall the gradient of the loss for single-sample (x,y)𝒟(x,y)\sim\mathcal{D} w.r.t. the hidden node r[m/2]r\in[m/2] of the expert s{1,2}s\in\{1,2\}:

(θ(t),x,y)wr,s(t)=yar,sv(t)(θ(t),x,y)(1ljJs(ws(t),x)x(j)1wr,s(t),x(j)0)\displaystyle\frac{\partial\mathcal{L}(\theta^{(t)},x,y)}{\partial w_{r,s}^{(t)}}=-ya_{r,s}v^{(t)}(\theta^{(t)},x,y)\left(\cfrac{1}{l}\underset{j\in J_{s}(w_{s}^{(t)},x)}{\sum}x^{(j)}1_{\langle w_{r,s}^{(t)},x^{(j)}\rangle\geq 0}\right)

and the corresponding pseudo-gradient:

(θ(t),x,y)wr,s(t)=yar,sv(t)(θ(t),x,y)(1ljJs(ws(t),x)x(j)1wr,s(0),x(j)0)\displaystyle\frac{\overset{\sim}{\partial}\mathcal{L}(\theta^{(t)},x,y)}{\partial w_{r,s}^{(t)}}=-ya_{r,s}v^{(t)}(\theta^{(t)},x,y)\left(\cfrac{1}{l}\underset{j\in J_{s}(w_{s}^{(t)},x)}{\sum}x^{(j)}1_{\langle w_{r,s}^{(0)},x^{(j)}\rangle\geq 0}\right)

Now, ar,s𝒩(0,1)a_{r,s}\sim\mathcal{N}(0,1). Hence, using the concentration bound of Gaussian random variable (i.e., for X𝒩(0,σ2):Pr[|X|t]12et2/2σ2X\sim\mathcal{N}(0,\sigma^{2}):Pr[|X|\leq t]\geq 1-2e^{-t^{2}/2\sigma^{2}}) and as O~(.)\tilde{O}(.) hides factor log(poly(m,n,p,δ,1ϵ))\log\left(poly(m,n,p,\delta,\frac{1}{\epsilon})\right) we get:

[|ar,s|=O~(1)]11poly(m,n,p,δ,1ϵ) (i.e., w.h.p.)\displaystyle\mathbb{P}\left[|a_{r,s}|=\tilde{O}(1)\right]\geq 1-\frac{1}{poly(m,n,p,\delta,\frac{1}{\epsilon})}\text{ (i.e., w.h.p.)}

Now, as $v^{(t)}(\theta^{(t)},x,y)\leq 1$ and $||x^{(j)}||=1$, w.h.p. $\left|\left|\frac{\partial\mathcal{L}(\theta^{(t)},x,y)}{\partial w_{r,s}^{(t)}}\right|\right|=\tilde{O}(1)$, and so is the mini-batch gradient: $\left|\left|\frac{\partial\mathcal{L}(\theta^{(t)})}{\partial w_{r,s}^{(t)}}\right|\right|=\tilde{O}(1)$.
Now, from the update rule of the Step-3 of Algorithm 1, wr,s(t)wr,s(t+1)=η(θ(t))wr,s(t)w_{r,s}^{(t)}-w_{r,s}^{(t+1)}=\eta\frac{\partial\mathcal{L}(\theta^{(t)})}{\partial w_{r,s}^{(t)}}

Therefore, telescoping the updates, $w_{r,s}^{(0)}-w_{r,s}^{(t)}=\eta\sum_{i=1}^{t}\frac{\partial\mathcal{L}(\theta^{(i-1)})}{\partial w_{r,s}^{(i-1)}}$.
Therefore, $\left|\left|w_{r,s}^{(t)}-w_{r,s}^{(0)}\right|\right|\leq\Upsilon\eta t$, where $\Upsilon=\tilde{O}(1)$ denotes the bound on the mini-batch gradient norm.

Now, for every τ>0,\tau>0, consider the set s:={r[m/2]:jJs(ws(t),x),|wr,s(0),x(j)|2τ}\mathcal{H}_{s}:=\left\{r\in[m/2]:\forall j\in J_{s}(w_{s}^{(t)},x),|\langle w_{r,s}^{(0)},x^{(j)}\rangle|\geq 2\tau\right\}

Now, for every $t\leq\frac{\tau}{\Upsilon\eta}$, $|\langle w_{r,s}^{(t)}-w_{r,s}^{(0)},x^{(j)}\rangle|\leq\tau$, which implies that for every $r\in\mathcal{H}_{s}$, $t\leq\frac{\tau}{\Upsilon\eta}$, and $j\in J_{s}(w_{s}^{(t)},x)$, $|\langle w_{r,s}^{(t)},x^{(j)}\rangle|\geq\tau$.

Therefore, for every rsr\in\mathcal{H}_{s}, t=O~(τη)t=\tilde{O}(\frac{\tau}{\eta}) and jJs(ws(t),x)j\in J_{s}(w_{s}^{(t)},x), 1wr,s(t),x(j)0=1wr,s(0),x(j)01_{\langle w_{r,s}^{(t)},x^{(j)}\rangle\geq 0}=1_{\langle w_{r,s}^{(0)},x^{(j)}\rangle\geq 0} and hence, (θ(t),x,y)wr,s(t)=(θ(t),x,y)wr,s(t)\cfrac{\partial\mathcal{L}(\theta^{(t)},x,y)}{\partial w_{r,s}^{(t)}}=\cfrac{\overset{\sim}{\partial}\mathcal{L}(\theta^{(t)},x,y)}{\partial w_{r,s}^{(t)}}

Now, we will find the lower bound of |s|:|\mathcal{H}_{s}|:

As, wr,s(0)𝒩(0,σ2𝕀d×d),jJs(ws(t),x),wr,s(0),x(j)𝒩(0,σ2)w_{r,s}^{(0)}\sim\mathcal{N}(0,\sigma^{2}\mathbb{I}_{d\times d}),\forall j\in J_{s}(w_{s}^{(t)},x),\langle w_{r,s}^{(0)},x^{(j)}\rangle\sim\mathcal{N}(0,\sigma^{2})

Hence, [|wr,s(0),x(j)|2τ]2eτσ\mathbb{P}[|\langle w_{r,s}^{(0)},x^{(j)}\rangle|\leq 2\tau]\leq\cfrac{2e\tau}{\sigma}
Now as |Js(ws(t),x)|=l|J_{s}(w_{s}^{(t)},x)|=l, [jJs(ws(t),x),|wr,s(0),x(j)|2τ]12eτlσ\mathbb{P}[\forall j\in J_{s}(w_{s}^{(t)},x),|\langle w_{r,s}^{(0)},x^{(j)}\rangle|\geq 2\tau]\geq 1-\cfrac{2e\tau l}{\sigma}

Therefore, |s|(12eτlσ)m2|\mathcal{H}_{s}|\geq\left(1-\cfrac{2e\tau l}{\sigma}\right)\frac{m}{2}

In the following two lemmas, we show that when $v_{1}^{(t)}$ is large, the expected pseudo-gradient of the loss function w.r.t. the hidden nodes of expert 1 is large; the same holds for expert 2 when $v_{2}^{(t)}$ is large. We prove the first of these two lemmas for a fixed set $\{v^{(t)}(\theta^{(t)},x,y):(x,y)\sim\mathcal{D}\}$ that does not depend on the random initialization of the hidden nodes of the experts (i.e., on $\{w_{r,s}^{(0)}\}$). In the second lemma, we remove the dependence on the fixed set by means of a sampling trick introduced in (Li & Liang, 2018), taking a union bound over an epsilon-net on the set $\{v^{(t)}(\theta^{(t)},x,y):(x,y)\sim\mathcal{D}\}$.

Lemma G.2.

For any possible fixed set {v(t)(θ(t),x,y):(x,y)𝒟}\{v^{(t)}(\theta^{(t)},x,y):(x,y)\sim\mathcal{D}\} (that does not depend on wr,s(0)w_{r,s}^{(0)}) such that vs(t)=v1(t)v_{s}^{(t)}=v_{1}^{(t)} for s=1s=1 and vs(t)=v2(t)v_{s}^{(t)}=v_{2}^{(t)} for s=2s=2 we have for every lll\geq l^{*}:

[Pr,s(t)=Ω(vs(t)lpδ)]=Ω(1pδ)\mathbb{P}\left[||P_{r,s}^{(t)}||=\overset{\sim}{\Omega}\left(\cfrac{v_{s}^{(t)}}{lp\sqrt{\delta}}\right)\right]=\Omega\left(\cfrac{1}{p\sqrt{\delta}}\right)

Proof.

Without loss of generality, assume $s=1$. Now,

Pr,1(t)\displaystyle P_{r,1}^{(t)} =𝔼𝒟|y=+1[v(t)(θ(t),x,y)(1ljJ1(w1(t),x)x(j)1wr,1(0),x(j)0)|y=+1]\displaystyle=\mathbb{E}_{\mathcal{D}|y=+1}\left[v^{(t)}(\theta^{(t)},x,y)\left(\cfrac{1}{l}\underset{j\in J_{1}(w_{1}^{(t)},x)}{\sum}x^{(j)}1_{\langle w_{r,1}^{(0)},x^{(j)}\rangle\geq 0}\right)\Big{|}y=+1\right]
𝔼𝒟|y=1[v(t)(θ(t),x,y)(1ljJ1(w1(t),x)x(j)1wr,1(0),x(j)0)|y=1]\displaystyle\hskip 42.67912pt-\mathbb{E}_{\mathcal{D}|y=-1}\left[v^{(t)}(\theta^{(t)},x,y)\left(\cfrac{1}{l}\underset{j\in J_{1}(w_{1}^{(t)},x)}{\sum}x^{(j)}1_{\langle w_{r,1}^{(0)},x^{(j)}\rangle\geq 0}\right)\Big{|}y=-1\right]

Then,

h(wr,1(0))\displaystyle h(w_{r,1}^{(0)}) :=Pr,1,wr,1(0)\displaystyle:=\langle P_{r,1},w_{r,1}^{(0)}\rangle
=𝔼𝒟|y=+1[v(t)(θ(t),x,y)(1ljJ1(w1(t),x)ReLU(wr,1(0),x(j)))|y=+1]\displaystyle=\mathbb{E}_{\mathcal{D}|y=+1}\left[v^{(t)}(\theta^{(t)},x,y)\left(\cfrac{1}{l}\underset{j\in J_{1}(w_{1}^{(t)},x)}{\sum}\textbf{ReLU}\left(\langle w_{r,1}^{(0)},x^{(j)}\rangle\right)\right)\Big{|}y=+1\right]
𝔼𝒟|y=1[v(t)(θ(t),x,y)(1ljJ1(w1(t),x)ReLU(wr,1(0),x(j)))|y=1]\displaystyle\hskip 34.14322pt-\mathbb{E}_{\mathcal{D}|y=-1}\left[v^{(t)}(\theta^{(t)},x,y)\left(\cfrac{1}{l}\underset{j\in J_{1}(w_{1}^{(t)},x)}{\sum}\textbf{ReLU}\left(\langle w_{r,1}^{(0)},x^{(j)}\rangle\right)\right)\Big{|}y=-1\right]

Now, let us decompose wr,1(0)=αo1+βw_{r,1}^{(0)}=\alpha o_{1}+\beta, where βo1\beta\perp o_{1}

Then,

h(wr,1(0))\displaystyle h(w_{r,1}^{(0)}) =𝔼𝒟|y=+1[v(t)(θ(t),x,y)(1ljJ1(w1(t),x)ReLU(αo1,x(j)+β,x(j)))|y=+1]\displaystyle=\mathbb{E}_{\mathcal{D}|y=+1}\left[v^{(t)}(\theta^{(t)},x,y)\left(\cfrac{1}{l}\underset{j\in J_{1}(w_{1}^{(t)},x)}{\sum}\textbf{ReLU}\left(\alpha\langle o_{1},x^{(j)}\rangle+\langle\beta,x^{(j)}\rangle\right)\right)\Big{|}y=+1\right]
𝔼𝒟|y=1[v(t)(θ(t),x,y)(1ljJ1(w1(t),x)ReLU(αo1,x(j)+β,x(j)))|y=1]\displaystyle\hskip 5.69046pt-\mathbb{E}_{\mathcal{D}|y=-1}\left[v^{(t)}(\theta^{(t)},x,y)\left(\cfrac{1}{l}\underset{j\in J_{1}(w_{1}^{(t)},x)}{\sum}\textbf{ReLU}\left(\alpha\langle o_{1},x^{(j)}\rangle+\langle\beta,x^{(j)}\rangle\right)\right)\Big{|}y=-1\right]
=ϕ(α)l(α)\displaystyle=\phi(\alpha)-l(\alpha)

Where,

ϕ(α):=\displaystyle\phi(\alpha):= 𝔼𝒟|y=+1[v(t)(θ(t),x,y)(1ljJ1(w1(t),x)ReLU(αo1,x(j)+β,x(j)))|y=+1]\displaystyle\mathbb{E}_{\mathcal{D}|y=+1}\left[v^{(t)}(\theta^{(t)},x,y)\left(\cfrac{1}{l}\underset{j\in J_{1}(w_{1}^{(t)},x)}{\sum}\textbf{ReLU}\left(\alpha\langle o_{1},x^{(j)}\rangle+\langle\beta,x^{(j)}\rangle\right)\right)\Big{|}y=+1\right]

and,

l(α):=𝔼𝒟|y=1[v(t)(θ(t),x,y)(1ljJ1(w1(t),x)ReLU(αo1,x(j)+β,x(j)))|y=1]\displaystyle l(\alpha):=\mathbb{E}_{\mathcal{D}|y=-1}\left[v^{(t)}(\theta^{(t)},x,y)\left(\cfrac{1}{l}\underset{j\in J_{1}(w_{1}^{(t)},x)}{\sum}\textbf{ReLU}\left(\alpha\langle o_{1},x^{(j)}\rangle+\langle\beta,x^{(j)}\rangle\right)\right)\Big{|}y=-1\right]

Note that $\phi(\alpha)$ and $l(\alpha)$ are both convex functions of $\alpha$.

Now for lll\geq l^{*}, using Lemma D.4, we can express ϕ(α)\phi(\alpha) as follows:

ϕ(α)=v1(t)lReLU(α)\displaystyle\phi(\alpha)=\cfrac{v_{1}^{(t)}}{l}\textbf{ReLU}\left(\alpha\right)
+𝔼𝒟|y=+1[v(t)(θ(t),x,y)(1ljJ1(x)/argjJ1(x)x(j)=o1ReLU(αo1,x(j)+β,x(j)))|y=+1]\displaystyle\hskip 5.69046pt+\mathbb{E}_{\mathcal{D}|y=+1}\left[v^{(t)}(\theta^{(t)},x,y)\left(\cfrac{1}{l}\underset{j\in J_{1}(x)/\underset{j\in J_{1}(x)}{\text{arg}}x^{(j)}=o_{1}}{\sum}\textbf{ReLU}\left(\alpha\langle o_{1},x^{(j)}\rangle+\langle\beta,x^{(j)}\rangle\right)\right)\Big{|}y=+1\right]

Now, for any class-irrelevant pattern set SiS_{i} where i[p]i\in[p], let us define qiSiq_{i}^{*}\in S_{i} such that qi=𝔼Si[q]𝔼Si[q]q_{i}^{*}=\cfrac{\mathbb{E}_{S_{i}}[q]}{||\mathbb{E}_{S_{i}}[q]||}. Also, let us define the set, :={qi:i[p]}{o2}\mathcal{H}:=\{q_{i}^{*}:i\in[p]\}\cup\{o_{2}\}

Now let us define the event τ:(i)|α|τ;(ii)q:|β,q|4τ\mathcal{E}_{\tau}:(i)\hskip 2.84544pt|\alpha|\leq\tau;\hskip 2.84544pt(ii)\hskip 2.84544pt\forall q^{\prime}\in\mathcal{H}:|\langle\beta,q^{\prime}\rangle|\geq 4\tau

Now, as α𝒩(0,σ2)\alpha\sim\mathcal{N}(0,\sigma^{2}), for every q,β,q𝒩(0,(1o1,q2)σ2)q^{\prime}\in\mathcal{H},\langle\beta,q^{\prime}\rangle\sim\mathcal{N}\left(0,\left(1-\langle o_{1},q^{\prime}\rangle^{2}\right)\sigma^{2}\right)

Now, 1o1,q21δ1-\langle o_{1},q^{\prime}\rangle^{2}\geq\frac{1}{\delta}. Hence, [q:|β,q|4τ]4eτpδσ\mathbb{P}[\exists q^{\prime}\in\mathcal{H}:|\langle\beta,q^{\prime}\rangle|\leq 4\tau]\leq\cfrac{4e\tau p\sqrt{\delta}}{\sigma}

Therefore, [q:|β,q|4τ]14eτpδσ\mathbb{P}[\forall q^{\prime}\in\mathcal{H}:|\langle\beta,q^{\prime}\rangle|\geq 4\tau]\geq 1-\cfrac{4e\tau p\sqrt{\delta}}{\sigma}.

Picking, τσ8epδ\tau\leq\cfrac{\sigma}{8ep\sqrt{\delta}} gives, [q:|β,q|4τ]12\mathbb{P}[\forall q^{\prime}\in\mathcal{H}:|\langle\beta,q^{\prime}\rangle|\geq 4\tau]\geq\cfrac{1}{2}.

On the other hand, [|α|τ]=Ω(τσ)\mathbb{P}[|\alpha|\leq\tau]=\Omega\left(\cfrac{\tau}{\sigma}\right). Therefore, [τ]=Ω(τσ)\mathbb{P}[\mathcal{E}_{\tau}]=\Omega\left(\cfrac{\tau}{\sigma}\right)

Now, for every $i\in[p]$ and every $q\in S_{i}$, $\mathbb{E}[|\langle w_{r,1}^{(0)},q-q_{i}^{*}\rangle|]\leq\mathbb{E}_{\mathcal{N}(0,\sigma^{2}\mathbb{I}_{d\times d})}[||w_{r,1}^{(0)}||]\,\mathbb{E}_{S_{i}}[||q-q_{i}^{*}||]\leq\tau$, where the last inequality comes from the bound on the diameter of the pattern sets and the fact that for any $X\sim\mathcal{N}(0,\sigma^{2}\mathbb{I}_{d\times d})$, $\mathbb{E}[||X||]\leq 4\sigma\sqrt{d}$.

Therefore, using Markov's inequality, for every $i\in[p]$ and every $q\in S_{i}$, $\mathbb{P}[|\langle w_{r,1}^{(0)},q-q_{i}^{*}\rangle|\leq 2\tau]\geq\frac{1}{2}$.

Now,
i[p],\forall i\in[p], s.t. qSi,ReLU(αo1,q+β,q)=ReLU(αo1,qi+β,qi+wr,1(0),qqi)q\in S_{i},\textbf{ReLU}\left(\alpha\langle o_{1},q\rangle+\langle\beta,q\rangle\right)=\textbf{ReLU}\left(\alpha\langle o_{1},q_{i}^{*}\rangle+\langle\beta,q_{i}^{*}\rangle+\langle w_{r,1}^{(0)},q-q_{i}^{*}\rangle\right)

Now, conditioned on the event $\mathcal{E}_{\tau}$, fixing $\beta$ and viewing $\alpha$ as the only random variable, for every $i\in[p]$ and $q\in S_{i}$, $\textbf{ReLU}\left(\alpha\langle o_{1},q\rangle+\langle\beta,q\rangle\right)=\left(\alpha\langle o_{1},q\rangle+\langle\beta,q\rangle\right)1_{\langle\beta,q_{i}^{*}\rangle\geq 0}$, which is a linear function of $\alpha\in[-\tau,\tau]$ with probability at least $\frac{1}{2}$; moreover, $\textbf{ReLU}\left(\alpha\langle o_{1},o_{2}\rangle+\langle\beta,o_{2}\rangle\right)=\left(\alpha\langle o_{1},o_{2}\rangle+\langle\beta,o_{2}\rangle\right)1_{\langle\beta,o_{2}\rangle\geq 0}$, which is a linear function of $\alpha\in[-\tau,\tau]$ with probability 1.

Now, let us define {l(α)}\{\partial l(\alpha)\} and {ϕ(α)}\{\partial\phi(\alpha)\} as the set of sub-gradient at the point α\alpha for l(α)l(\alpha) and ϕ(α)\phi(\alpha) respectively such that maxl(α)=max{l(α)}\partial_{\text{max}}l(\alpha)=\text{max}\{\partial l(\alpha)\}, maxϕ(α)=max{ϕ(α)}\partial_{\text{max}}\phi(\alpha)=\text{max}\{\partial\phi(\alpha)\}, minl(α)=min{l(α)}\partial_{\text{min}}l(\alpha)=\text{min}\{\partial l(\alpha)\} and minϕ(α)=min{ϕ(α)}\partial_{\text{min}}\phi(\alpha)=\text{min}\{\partial\phi(\alpha)\}.

Then, using the above argument, conditioned on the event τ\mathcal{E}_{\tau}, maxl(τ)minl(τ)=0\partial_{\text{max}}l(\tau)-\partial_{\text{min}}l(-\tau)=0.
On the other hand, maxϕ(τ/2)minϕ(τ/2)=v1(t)l\partial_{\text{max}}\phi(\tau/2)-\partial_{\text{min}}\phi(-\tau/2)=\cfrac{v_{1}^{(t)}}{l}.

Now using Lemma J.1, conditioned on the event τ\mathcal{E}_{\tau}, αU(τ,τ)[|ϕ(α)l(α)|v1(t)τ512l]164\underset{\alpha\sim U(-\tau,\tau)}{\mathbb{P}}\left[|\phi(\alpha)-l(\alpha)|\geq\cfrac{v_{1}^{(t)}\tau}{512l}\right]\geq\cfrac{1}{64}.

Now, for τσ8epδ\tau\leq\cfrac{\sigma}{8ep\sqrt{\delta}}, conditioned on τ\mathcal{E}_{\tau}, the density p(α)[1eτ,eτ]p(\alpha)\in\left[\cfrac{1}{e\tau},\cfrac{e}{\tau}\right], which implies that,

[h(wr,1(0))v1(t)τ128l][h(wr,1(0))v1(t)τ128l|τ][τ]=Ω(τσ)\mathbb{P}\left[h(w_{r,1}^{(0)})\geq\cfrac{v_{1}^{(t)}\tau}{128l}\right]\geq\mathbb{P}\left[h(w_{r,1}^{(0)})\geq\cfrac{v_{1}^{(t)}\tau}{128l}\big{|}\mathcal{E}_{\tau}\right]\mathbb{P}\left[\mathcal{E}_{\tau}\right]=\Omega\left(\cfrac{\tau}{\sigma}\right) (15)

Now, as $v_{1}^{(t)}$ does not depend on $w_{r,1}^{(0)}$, $\langle P_{r,1}^{(t)},w_{r,1}^{(0)}\rangle\sim\mathcal{N}(0,\sigma^{2}||P_{r,1}^{(t)}||^{2})$.

Now, using a concentration bound of Gaussian RV (i.e., [Xσx]ex2/2\mathbb{P}[X\geq\sigma x]\leq e^{-x^{2}/2}),

[Pr,1(t),wr,1(0)(σPr,1(t))10c]e50c2; here c>10.\mathbb{P}[\langle P_{r,1}^{(t)},w_{r,1}^{(0)}\rangle\geq(\sigma||P_{r,1}^{(t)}||)10c]\leq e^{-50c^{2}};\text{ here }c>10. (16)

Now, taking c=100logpδσc=100\sqrt{\log{\cfrac{p\sqrt{\delta}}{\sigma}}} in (16) we get,

[Pr,1(t),wr,1(0)=Ω~(σPr,1(t))]=o(1)\mathbb{P}[\langle P_{r,1}^{(t)},w_{r,1}^{(0)}\rangle=\tilde{\Omega}(\sigma||P_{r,1}^{(t)}||)]=o(1) (17)

On the other hand, picking $\tau=\Theta(\frac{\sigma}{p\sqrt{\delta}})$ and plugging into (15) gives,

[Pr,1(t),wr,1(0)=Ω(σv1(t)lpδ)]=Ω(1pδ)\mathbb{P}\left[\langle P_{r,1}^{(t)},w_{r,1}^{(0)}\rangle=\Omega\left(\sigma\cfrac{v_{1}^{(t)}}{lp\sqrt{\delta}}\right)\right]=\Omega\left(\frac{1}{p\sqrt{\delta}}\right) (18)

Comparing (17) and (18) we get, [Pr,1(t)=Ω~(v1(t)lpδ)]=Ω(1pδ)\mathbb{P}\left[||P_{r,1}^{(t)}||=\tilde{\Omega}\left(\cfrac{v_{1}^{(t)}}{lp\sqrt{\delta}}\right)\right]=\Omega\left(\cfrac{1}{p\sqrt{\delta}}\right)

Lemma G.3.

Let vs(t)=v1(t)v_{s}^{(t)}=v_{1}^{(t)} for s=1s=1 and vs(t)=v2(t)v_{s}^{(t)}=v_{2}^{(t)} for s=2s=2. Then, for every vs(t)>0v_{s}^{(t)}>0, for m=Ω~(l2p3δ3/2(vs(t))2)m=\tilde{\Omega}\left(\cfrac{l^{2}p^{3}\delta^{3/2}}{(v_{s}^{(t)})^{2}}\right), for every possible set {v(t)(θ(t),x,y):(x,y)𝒟}\{v^{(t)}(\theta^{(t)},x,y):(x,y)\sim\mathcal{D}\} (that depends on wr,s(0)w_{r,s}^{(0)}), there exist at least Ω(1pδ)\Omega\left(\frac{1}{p\sqrt{\delta}}\right) fraction of r[m/2]r\in[m/2] of the expert s{1,2}s\in\{1,2\} such that for every lll\geq l^{*},

^(θ(t))wr,s(t)=Ω~(vs(t)lpδ)\left|\left|\cfrac{\overset{\sim}{\partial}\hat{\mathcal{L}}(\theta^{(t)})}{\partial w_{r,s}^{(t)}}\right|\right|=\tilde{\Omega}\left(\cfrac{v_{s}^{(t)}}{lp\sqrt{\delta}}\right)



Proof.

Let us pick SS samples to form S={(xi,yi)}i=1S\textbf{S}=\{(x_{i},y_{i})\}_{i=1}^{S} with S/2S/2 many samples from y=+1y=+1 and S/2S/2 many samples from y=1y=-1. Let us denote the subset of samples with y=+1y=+1 as S1\textbf{S}_{1} and the subset of samples with y=1y=-1 as S2\textbf{S}_{2}. Therefore, |S1|=|S2|=S/2|\textbf{S}_{1}|=|\textbf{S}_{2}|=S/2. Let us denote the corresponding value function of ii-th sample of S as v(t)(θ(t),xi,yi)v^{(t)}(\theta^{(t)},x_{i},y_{i}). Since, each v(t)(θ(t),xi,yi)[0,1]v^{(t)}(\theta^{(t)},x_{i},y_{i})\in[0,1]\, using Hoeffding’s inequality we know that w.h.p. :

|vs(t)1S/2(xi,yi)Ssv(t)(θ(t),xi,yi)|=O~(1S)\displaystyle\left|v_{s}^{(t)}-\cfrac{1}{S/2}\underset{(x_{i},y_{i})\in\textbf{S}_{s}}{\sum}v^{(t)}(\theta^{(t)},x_{i},y_{i})\right|=\tilde{O}\left(\cfrac{1}{\sqrt{S}}\right)

This implies that, as long as S=Ω~(1(vs(t))2)S=\tilde{\Omega}\left(\cfrac{1}{(v_{s}^{(t)})^{2}}\right), we will have that,

1S/2(xi,yi)Ssv(t)(θ(t),xi,yi)[12vs(t),32vs(t)]\displaystyle\cfrac{1}{S/2}\underset{(x_{i},y_{i})\in\textbf{S}_{s}}{\sum}v^{(t)}(\theta^{(t)},x_{i},y_{i})\in\left[\cfrac{1}{2}v_{s}^{(t)},\cfrac{3}{2}v_{s}^{(t)}\right]

Now, the average pseudo-gradient over the set S,

1S(xi,yi)S(θ(t),xi,yi)wr,s(t)\displaystyle\displaystyle\cfrac{1}{S}\underset{(x_{i},y_{i})\in\textbf{S}}{\sum}\cfrac{\overset{\sim}{\partial}\mathcal{L}(\theta^{(t)},x_{i},y_{i})}{\partial w_{r,s}^{(t)}} =1S(xi,yi)Syar,sv(t)(θ(t),xi,yi)(1ljJs(ws(t),xi)xi(j)1wr,s(0),xi(j)0)\displaystyle=\cfrac{1}{S}\underset{(x_{i},y_{i})\in\textbf{S}}{\sum}-ya_{r,s}v^{(t)}(\theta^{(t)},x_{i},y_{i})\left(\cfrac{1}{l}\underset{j\in J_{s}(w_{s}^{(t)},x_{i})}{\sum}x_{i}^{(j)}1_{\langle w_{r,s}^{(0)},x_{i}^{(j)}\rangle\geq 0}\right)
=ar,s2Pr,s(t)(S)\displaystyle=-\cfrac{a_{r,s}}{2}P_{r,s}^{(t)}(\textbf{S})

where,

Pr,s(t)(S)\displaystyle P_{r,s}^{(t)}(\textbf{S}) =1S/2(xi,yi)S1v(t)(θ(t),xi,yi)(1ljJs(ws(t),xi)xi(j)1wr,s(0),xi(j)0)\displaystyle=\cfrac{1}{S/2}\underset{(x_{i},y_{i})\in\textbf{S}_{1}}{\sum}v^{(t)}(\theta^{(t)},x_{i},y_{i})\left(\cfrac{1}{l}\underset{j\in J_{s}(w_{s}^{(t)},x_{i})}{\sum}x_{i}^{(j)}1_{\langle w_{r,s}^{(0)},x_{i}^{(j)}\rangle\geq 0}\right)
1S/2(xi,yi)S2v(t)(θ(t),xi,yi)(1ljJs(ws(t),xi)xi(j)1wr,s(0),xi(j)0)\displaystyle\hskip 28.45274pt-\cfrac{1}{S/2}\underset{(x_{i},y_{i})\in\textbf{S}_{2}}{\sum}v^{(t)}(\theta^{(t)},x_{i},y_{i})\left(\cfrac{1}{l}\underset{j\in J_{s}(w_{s}^{(t)},x_{i})}{\sum}x_{i}^{(j)}1_{\langle w_{r,s}^{(0)},x_{i}^{(j)}\rangle\geq 0}\right)

Now as ar,s𝒩(0,1)a_{r,s}\sim\mathcal{N}(0,1), [1S(xi,yi)S(θ(t),xi,yi)wr,s(t)=ar,s2Pr,s(t)(S)12Pr,s(t)(S)]1e\mathbb{P}\left[\left|\left|\cfrac{1}{S}\underset{(x_{i},y_{i})\in\textbf{S}}{\sum}\cfrac{\overset{\sim}{\partial}\mathcal{L}(\theta^{(t)},x_{i},y_{i})}{\partial w_{r,s}^{(t)}}\right|\right|=\left|\left|\cfrac{a_{r,s}}{2}P_{r,s}^{(t)}(\textbf{S})\right|\right|\geq\cfrac{1}{2}\left|\left|P_{r,s}^{(t)}(\textbf{S})\right|\right|\right]\geq\cfrac{1}{e}

Now for a fixed set {v(t)(θ(t),xi,yi):(xi,yi)S}\{v^{(t)}(\theta^{(t)},x_{i},y_{i}):(x_{i},y_{i})\in\textbf{S}\} as long as S=Ω~(1vs2(t))S=\tilde{\Omega}\left(\cfrac{1}{v_{s}^{2}{(t)}}\right), for every lll\geq l^{*} using Lemma G.2,

[Pr,s(t)(S)=Ω~(vs(t)lpδ)]=Ω(1pδ)\displaystyle\mathbb{P}\left[||P_{r,s}^{(t)}(\textbf{S})||=\tilde{\Omega}\left(\cfrac{v_{s}^{(t)}}{lp\sqrt{\delta}}\right)\right]=\Omega\left(\cfrac{1}{p\sqrt{\delta}}\right)

Hence, for a fixed set {v(t)(θ(t),xi,yi):(xi,yi)S}\{v^{(t)}(\theta^{(t)},x_{i},y_{i}):(x_{i},y_{i})\in\textbf{S}\}, the probability that there are less than O(1pδ)O\left(\frac{1}{p\sqrt{\delta}}\right) fraction of r[m/2]r\in[m/2] such that 1S(xi,yi)S(θ(t),xi,yi)wr,s(t)\left|\left|\cfrac{1}{S}\underset{(x_{i},y_{i})\in\textbf{S}}{\sum}\cfrac{\overset{\sim}{\partial}\mathcal{L}(\theta^{(t)},x_{i},y_{i})}{\partial w_{r,s}^{(t)}}\right|\right| is Ω~(vs(t)lpδ)\tilde{\Omega}\left(\cfrac{v_{s}^{(t)}}{lp\sqrt{\delta}}\right) is no more than pfixp_{\text{fix}} where, pfixexp(Ω(mpδ))p_{\text{fix}}\leq\exp{\left(-\Omega\left(\frac{m}{p\sqrt{\delta}}\right)\right)}.

Moreover, for every ε¯>0\bar{\varepsilon}>0, for two different {v(t)(θ(t),xi,yi):(xi,yi)S}\{v^{(t)}(\theta^{(t)},x_{i},y_{i}):(x_{i},y_{i})\in\textbf{S}\}, {v(t)(θ(t),xi,yi):(xi,yi)S}\{v^{\prime(t)}(\theta^{(t)},x_{i},y_{i}):(x_{i},y_{i})\in\textbf{S}\} such that (xi,yi)S\forall(x_{i},y_{i})\in\textbf{S}, |v(t)(θ(t),xi,yi)v(t)(θ(t),xi,yi)|ε¯|v^{(t)}(\theta^{(t)},x_{i},y_{i})-v^{\prime(t)}(\theta^{(t)},x_{i},y_{i})|\leq\bar{\varepsilon}, since w.h.p. |ar,s|=O~(1)|a_{r,s}|=\tilde{O}(1),

1S(xi,yi)Syar,s(v(t)(θ(t),xi,yi)v(t)(θ(t),xi,yi))(1ljJs(ws(t),xi)xi(j)1wr,s(0),xi(j)0)\displaystyle\left|\left|\cfrac{1}{S}\underset{(x_{i},y_{i})\in\textbf{S}}{\sum}-ya_{r,s}(v^{(t)}(\theta^{(t)},x_{i},y_{i})-v^{\prime(t)}(\theta^{(t)},x_{i},y_{i}))\left(\cfrac{1}{l}\underset{j\in J_{s}(w_{s}^{(t)},x_{i})}{\sum}x_{i}^{(j)}1_{\langle w_{r,s}^{(0)},x_{i}^{(j)}\rangle\geq 0}\right)\right|\right|
=O~(ε¯)\displaystyle=\tilde{O}(\bar{\varepsilon})

which implies that we can take ε¯\bar{\varepsilon}-net with ε¯=Θ~(vs(t)lpδ)\bar{\varepsilon}=\tilde{\Theta}\left(\cfrac{v_{s}^{(t)}}{lp\sqrt{\delta}}\right).

Thus, the probability that there exists a set \{v^{(t)}(\theta^{(t)},x_{i},y_{i}):(x_{i},y_{i})\in\textbf{S}\} for which no more than an O\left(\frac{1}{p\sqrt{\delta}}\right) fraction of r\in[m/2] satisfy \left|\left|\cfrac{1}{S}\underset{(x_{i},y_{i})\in\textbf{S}}{\sum}\cfrac{\overset{\sim}{\partial}\mathcal{L}(\theta^{(t)},x_{i},y_{i})}{\partial w_{r,s}^{(t)}}\right|\right|=\tilde{\Omega}\left(\cfrac{v_{s}^{(t)}}{lp\sqrt{\delta}}\right) is at most p\leq p_{\text{fix}}\left(\cfrac{v_{s}^{(t)}}{\bar{\varepsilon}}\right)^{S}\leq\exp{\left(-\Omega\left(\frac{m}{p\sqrt{\delta}}\right)+S\log{\left(\cfrac{v_{s}^{(t)}}{\bar{\varepsilon}}\right)}\right)}.
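
To make the next step explicit (a brief verification added here for readability): with \bar{\varepsilon}=\tilde{\Theta}\left(\cfrac{v_{s}^{(t)}}{lp\sqrt{\delta}}\right) the factor \log{\left(\cfrac{v_{s}^{(t)}}{\bar{\varepsilon}}\right)}=\tilde{O}(1), so this failure probability is exponentially small as soon as

\displaystyle\Omega\left(\frac{m}{p\sqrt{\delta}}\right)\geq 2S\log{\left(\cfrac{v_{s}^{(t)}}{\bar{\varepsilon}}\right)}=\tilde{O}\left(S\right),\hskip 14.22636pt\text{i.e., }m=\tilde{\Omega}\left(Sp\sqrt{\delta}\right).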

Hence, for m=\overset{\sim}{\Omega}\left(Sp\sqrt{\delta}\right) with S=\tilde{\Omega}\left(\cfrac{1}{(v_{s}^{(t)})^{2}}\right), w.h.p. for every possible choice of \{v^{(t)}(\theta^{(t)},x_{i},y_{i}):(x_{i},y_{i})\in\textbf{S}\}, there are at least \Omega\left(\frac{1}{p\sqrt{\delta}}\right) fraction of r\in[m/2] such that,

1S(xi,yi)S(θ(t),xi,yi)wr,s(t)=Ω~(vs(t)lpδ)\displaystyle\left|\left|\cfrac{1}{S}\underset{(x_{i},y_{i})\in\textbf{S}}{\sum}\cfrac{\overset{\sim}{\partial}\mathcal{L}(\theta^{(t)},x_{i},y_{i})}{\partial w_{r,s}^{(t)}}\right|\right|=\tilde{\Omega}\left(\cfrac{v_{s}^{(t)}}{lp\sqrt{\delta}}\right)

Now, we consider the difference between the sample gradient and the expected gradient. Since \left|\left|\cfrac{\overset{\sim}{\partial}\mathcal{L}(\theta^{(t)},x_{i},y_{i})}{\partial w_{r,s}^{(t)}}\right|\right|=\overset{\sim}{O}(1), by Hoeffding's inequality we know that for every r\in[m/2]:

1S(xi,yi)S(θ(t),xi,yi)wr,s(t)^(θ(t))wr,s(t)=O~(1S)\displaystyle\left|\left|\cfrac{1}{S}\underset{(x_{i},y_{i})\in S}{\sum}\cfrac{\overset{\sim}{\partial}\mathcal{L}(\theta^{(t)},x_{i},y_{i})}{\partial w_{r,s}^{(t)}}-\cfrac{\overset{\sim}{\partial}\hat{\mathcal{L}}(\theta^{(t)})}{\partial w_{r,s}^{(t)}}\right|\right|=\tilde{O}\left(\cfrac{1}{\sqrt{S}}\right)
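
To spell out the next implication (a short verification added for readability, with all quantities as defined above), the triangle inequality combined with the lower bound on the averaged sample pseudo-gradient gives

\displaystyle\left|\left|\cfrac{\overset{\sim}{\partial}\hat{\mathcal{L}}(\theta^{(t)})}{\partial w_{r,s}^{(t)}}\right|\right|\geq\left|\left|\cfrac{1}{S}\underset{(x_{i},y_{i})\in\textbf{S}}{\sum}\cfrac{\overset{\sim}{\partial}\mathcal{L}(\theta^{(t)},x_{i},y_{i})}{\partial w_{r,s}^{(t)}}\right|\right|-\tilde{O}\left(\cfrac{1}{\sqrt{S}}\right)=\tilde{\Omega}\left(\cfrac{v_{s}^{(t)}}{lp\sqrt{\delta}}\right)-\tilde{O}\left(\cfrac{1}{\sqrt{S}}\right),

which remains of order \tilde{\Omega}\left(\cfrac{v_{s}^{(t)}}{lp\sqrt{\delta}}\right) once \cfrac{1}{\sqrt{S}}=\tilde{O}\left(\cfrac{v_{s}^{(t)}}{lp\sqrt{\delta}}\right).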

This implies that, as long as S=\tilde{\Omega}\left(\left(\cfrac{lp\sqrt{\delta}}{v_{s}^{(t)}}\right)^{2}\right) (and hence m=\tilde{\Omega}\left(\cfrac{l^{2}p^{3}\delta^{3/2}}{(v_{s}^{(t)})^{2}}\right)), such r\in[m/2] also satisfy:

^(θ(t))wr,s(t)=Ω~(vs(t)lpδ)\displaystyle\left|\left|\cfrac{\overset{\sim}{\partial}\hat{\mathcal{L}}(\theta^{(t)})}{\partial w_{r,s}^{(t)}}\right|\right|=\tilde{\Omega}\left(\cfrac{v_{s}^{(t)}}{lp\sqrt{\delta}}\right)

Lemma G.4.

Let us define v(t):=s{1,2}(vs(t))2v^{(t)}:=\sqrt{\underset{s\in\{1,2\}}{\sum}(v_{s}^{(t)})^{2}} where vs(t)=v1(t)v_{s}^{(t)}=v_{1}^{(t)} for s=1s=1 and vs(t)=v2(t)v_{s}^{(t)}=v_{2}^{(t)} for s=2s=2; γ:=Ω(1pδ)\gamma:=\Omega\left(\frac{1}{p\sqrt{\delta}}\right). Then, by selecting learning rate η=O~(γ3(v(t))2ml2)\eta=\tilde{O}\left(\cfrac{\gamma^{3}(v^{(t)})^{2}}{ml^{2}}\right) and batch size B=Ω~(l4γ6(v(t))4)B=\tilde{\Omega}\left(\cfrac{l^{4}}{\gamma^{6}(v^{(t)})^{4}}\right), at each iteration tt of the Step-3 of Algorithm 1 such that t=O~(σγ3(v(t))2ηl3)t=\tilde{O}\left(\cfrac{\sigma\gamma^{3}(v^{(t)})^{2}}{\eta l^{3}}\right), w.h.p. we can ensure that for every lll\geq l^{*},

ΔL(θ(t),θ(t+1))ηmγ3l2Ω~((v(t))2)\displaystyle\Delta L(\theta^{(t)},\theta^{(t+1)})\geq\cfrac{\eta m\gamma^{3}}{l^{2}}\tilde{\Omega}\left((v^{(t)})^{2}\right)
Proof.

For every lll\geq l^{*}, from Lemma G.3, for at least γ\gamma fraction of r[m/2]r\in[m/2] of expert ss:

^(θ(t))wr,s(t)=Ω~(vs(t)lpδ)\displaystyle\left|\left|\cfrac{\overset{\sim}{\partial}\hat{\mathcal{L}}(\theta^{(t)})}{\partial w_{r,s}^{(t)}}\right|\right|=\tilde{\Omega}\left(\cfrac{v_{s}^{(t)}}{lp\sqrt{\delta}}\right)

Now, w.h.p., \left|\left|\cfrac{\tilde{\partial}\mathcal{L}(\theta^{(t)},x,y)}{\partial w_{r,s}^{(t)}}\right|\right|=\overset{\sim}{O}(1). Therefore, w.h.p. over a batch \mathcal{B}_{t} of size B randomly sampled from \mathcal{D} at iteration t:

1B(x,y)t(θ(t),x,y)wr,s(t)^(θ(t))wr,s(t)=O~(1B)\displaystyle\left|\left|\cfrac{1}{B}\underset{(x,y)\in\mathcal{B}_{t}}{\sum}\cfrac{\overset{\sim}{\partial}\mathcal{L}(\theta^{(t)},x,y)}{\partial w_{r,s}^{(t)}}-\cfrac{\overset{\sim}{\partial}\hat{\mathcal{L}}(\theta^{(t)})}{\partial w_{r,s}^{(t)}}\right|\right|=\tilde{O}\left(\cfrac{1}{\sqrt{B}}\right)

This implies that, by selecting a batch size B=\Omega\left(\cfrac{l^{2}p^{2}\delta}{(v_{s}^{(t)})^{2}}\right), for this \gamma fraction of r\in[m/2] of expert s we can ensure that:

1B(xi,yi)t(θ(t),x,y)wr,s(t)=Ω~(vs(t)lpδ)\displaystyle\left|\left|\cfrac{1}{B}\underset{(x_{i},y_{i})\in\mathcal{B}_{t}}{\sum}\cfrac{\overset{\sim}{\partial}\mathcal{L}(\theta^{(t)},x,y)}{\partial w_{r,s}^{(t)}}\right|\right|=\tilde{\Omega}\left(\cfrac{v_{s}^{(t)}}{lp\sqrt{\delta}}\right)

Now, using Lemma G.1, for a fixed (x,y)\in\mathcal{B}_{t}, by selecting \tau=\cfrac{\sigma\gamma}{4elB} we have that for at least a \left(1-\cfrac{\gamma}{2B}\right) fraction of r\in[m/2] of the expert s:

(θ(t),x,y)wr,s(t)=(θ(t),x,y)wr,s(t)\displaystyle\cfrac{\partial\mathcal{L}(\theta^{(t)},x,y)}{\partial w_{r,s}^{(t)}}=\cfrac{\overset{\sim}{\partial}\mathcal{L}(\theta^{(t)},x,y)}{\partial w_{r,s}^{(t)}}

Therefore, for at least a (1-\gamma/2) fraction of r\in[m/2] of the expert s:

(θ(t),x,y)wr,s(t)=(θ(t),x,y)wr,s(t)(x,y)t\displaystyle\cfrac{\partial\mathcal{L}(\theta^{(t)},x,y)}{\partial w_{r,s}^{(t)}}=\cfrac{\overset{\sim}{\partial}\mathcal{L}(\theta^{(t)},x,y)}{\partial w_{r,s}^{(t)}}\hskip 28.45274pt\forall(x,y)\in\mathcal{B}_{t}

Recall our definition of the loss function for SGD at iteration t with mini-batch \mathcal{B}_{t}, \mathcal{L}(\theta^{(t)})=\cfrac{1}{B}\sum_{(x,y)\in\mathcal{B}_{t}}\log{(1+e^{-yf_{M}(\theta^{(t)},x)})}=\cfrac{1}{B}\sum_{(x,y)\in\mathcal{B}_{t}}\mathcal{L}(\theta^{(t)},x,y), and the corresponding batch gradient at iteration t, \cfrac{\partial\mathcal{L}(\theta^{(t)})}{\partial w_{r,s}^{(t)}}=\cfrac{1}{B}\sum_{(x,y)\in\mathcal{B}_{t}}\cfrac{\partial\mathcal{L}(\theta^{(t)},x,y)}{\partial w_{r,s}^{(t)}}. Since at least a \gamma fraction of r\in[m/2] satisfies the mini-batch pseudo-gradient lower bound above, while at most a \gamma/2 fraction violates the equality between the gradient and the pseudo-gradient on \mathcal{B}_{t}, there is at least a \gamma/2 fraction of r\in[m/2] of the expert s with:

(θ(t))wr,s(t)=Ω~(vs(t)lpδ)\displaystyle\left|\left|\cfrac{\partial\mathcal{L}(\theta^{(t)})}{\partial w_{r,s}^{(t)}}\right|\right|=\tilde{\Omega}\left(\cfrac{v_{s}^{(t)}}{lp\sqrt{\delta}}\right)

Now, for any (x^{\prime},y^{\prime})\sim\mathcal{D}, according to Lemma G.1, w.h.p. there is at least a 1-\cfrac{2e\tau l}{\sigma} fraction of r\in[m/2] of the expert s such that \forall j\in J_{s}(w_{s}^{(t)},x^{\prime}),|\langle w_{r,s}^{(t)},x^{\prime(j)}\rangle|\geq\tau. Let us denote the set of these r's of expert s as \mathcal{S}_{r,s}. Therefore, on the set \underset{s\in\{1,2\}}{\bigcup}\mathcal{S}_{r,s}, the loss function \mathcal{L}(\theta^{(t)},x^{\prime},y^{\prime}) is \overset{\sim}{O}(1)-smooth and \overset{\sim}{O}(1)-Lipschitz smooth.

On the other hand, the update rule of SGD at iteration t is \theta^{(t+1)}=\theta^{(t)}-\eta\cfrac{\partial\mathcal{L}(\theta^{(t)})}{\partial w_{r,s}^{(t)}}.

Therefore, using Lemma J.2,

ΔL(θ(t),θ(t+1),x,y):=(θ(t),x,y)(θ(t+1),x,y)\displaystyle\Delta L(\theta^{(t)},\theta^{(t+1)},x^{\prime},y^{\prime}):=\mathcal{L}(\theta^{(t)},x^{\prime},y^{\prime})-\mathcal{L}({\theta^{(t+1)},x^{\prime},y^{\prime}})
ηrs[2]𝒮r,s(θ(t))wr,s(t),(θ(t),x,y)wr,s(t)s[2],r[m/2]\s[2]𝒮r,sO(η)O(η2m2)\displaystyle\geq\eta\underset{r\in\underset{s\in[2]}{\bigcup}\mathcal{S}_{r,s}}{\sum}\left\langle\cfrac{\partial\mathcal{L}(\theta^{(t)})}{\partial w_{r,s}^{(t)}},\cfrac{\partial\mathcal{L}(\theta^{(t)},x^{\prime},y^{\prime})}{\partial w_{r,s}^{(t)}}\right\rangle-\underset{s\in[2],r\in[m/2]\backslash\underset{s\in[2]}{\cup}\mathcal{S}_{r,s}}{\sum}\overset{\sim}{O}(\eta)-\overset{\sim}{O}\left(\eta^{2}m^{2}\right)
ηr[m/2],s[2](θ(t))wr,s(t),(θ(t),x,y)wr,s(t)O(ητlmσ)O(η2m2)\displaystyle\geq\eta\underset{r\in[m/2],s\in[2]}{\sum}\left\langle\cfrac{\partial\mathcal{L}(\theta^{(t)})}{\partial w_{r,s}^{(t)}},\cfrac{\partial\mathcal{L}(\theta^{(t)},x^{\prime},y^{\prime})}{\partial w_{r,s}^{(t)}}\right\rangle-\overset{\sim}{O}\left(\cfrac{\eta\tau lm}{\sigma}\right)-\overset{\sim}{O}\left(\eta^{2}m^{2}\right)

Let us denote the event,

0:\displaystyle\mathcal{E}_{0}:
ΔL(θ(t),θ(t+1),x,y)\displaystyle\Delta L(\theta^{(t)},\theta^{(t+1)},x^{\prime},y^{\prime})
ηr[m/2],s[2](θ(t))wr,s(t),(θ(t),x,y)wr,s(t)O(ητlmσ)O(η2m2)\displaystyle\geq\eta\underset{r\in[m/2],s\in[2]}{\sum}\left\langle\cfrac{\partial\mathcal{L}(\theta^{(t)})}{\partial w_{r,s}^{(t)}},\cfrac{\partial\mathcal{L}(\theta^{(t)},x^{\prime},y^{\prime})}{\partial w_{r,s}^{(t)}}\right\rangle-\overset{\sim}{O}\left(\cfrac{\eta\tau lm}{\sigma}\right)-\overset{\sim}{O}\left(\eta^{2}m^{2}\right)

Then, [0]11poly(m,n,p,δ,1ϵ)\mathbb{P}\left[\mathcal{E}_{0}\right]\geq 1-\cfrac{1}{poly(m,n,p,\delta,\frac{1}{\epsilon})} (i.e., w.h.p.) and hence [¬0]1poly(m,n,p,δ,1ϵ)\mathbb{P}\left[\neg\mathcal{E}_{0}\right]\leq\cfrac{1}{poly(m,n,p,\delta,\frac{1}{\epsilon})}

Also, let us define the event,

1:\displaystyle\mathcal{E}_{1}:
|(θ(t),x,y)|=O~(m),(θ(t),x,y)wr,s(t)=O~(1) and (θ(t))wr,s(t)=O~(1)\displaystyle\left|\mathcal{L}(\theta^{(t)},x^{\prime},y^{\prime})\right|=\tilde{O}(m),\left|\left|\cfrac{\partial\mathcal{L}(\theta^{(t)},x^{\prime},y^{\prime})}{\partial w_{r,s}^{(t)}}\right|\right|=\tilde{O}(1)\text{ and }\left|\left|\cfrac{\partial\mathcal{L}(\theta^{(t)})}{\partial w_{r,s}^{(t)}}\right|\right|=\tilde{O}(1)

Then, [1]11poly(m,n,p,δ,1ϵ)\mathbb{P}\left[\mathcal{E}_{1}\right]\geq 1-\cfrac{1}{poly(m,n,p,\delta,\frac{1}{\epsilon})} and hence [¬1]1poly(m,n,p,δ,1ϵ)\mathbb{P}\left[\neg\mathcal{E}_{1}\right]\leq\cfrac{1}{poly(m,n,p,\delta,\frac{1}{\epsilon})}

Now, the expected gradient at iteration tt, ^(θ(t))wr,s(t):=𝔼𝒟[(θ(t),x,y)wr,s(t)]\cfrac{{\partial}\hat{\mathcal{L}}(\theta^{(t)})}{\partial w_{r,s}^{(t)}}:=\mathbb{E}_{\mathcal{D}}\left[\cfrac{\partial\mathcal{L}(\theta^{(t)},x^{\prime},y^{\prime})}{\partial w_{r,s}^{(t)}}\right]

Therefore, conditioned on \mathcal{E}_{1},

^(θ(t))wr,s(t)\displaystyle\cfrac{{\partial}\hat{\mathcal{L}}(\theta^{(t)})}{\partial w_{r,s}^{(t)}} =𝔼𝒟[(θ(t),x,y)wr,s(t)|1]\displaystyle=\mathbb{E}_{\mathcal{D}}\left[\cfrac{\partial\mathcal{L}(\theta^{(t)},x^{\prime},y^{\prime})}{\partial w_{r,s}^{(t)}}\big{|}\mathcal{E}_{1}\right]
=𝔼𝒟[(θ(t),x,y)wr,s(t)|0,1][0|1]+𝔼𝒟[(θ(t),x,y)wr,s(t)|¬0,1][¬0|1]\displaystyle=\mathbb{E}_{\mathcal{D}}\left[\cfrac{\partial\mathcal{L}(\theta^{(t)},x^{\prime},y^{\prime})}{\partial w_{r,s}^{(t)}}\big{|}\mathcal{E}_{0},\mathcal{E}_{1}\right]\mathbb{P}\left[\mathcal{E}_{0}\big{|}\mathcal{E}_{1}\right]+\mathbb{E}_{\mathcal{D}}\left[\cfrac{\partial\mathcal{L}(\theta^{(t)},x^{\prime},y^{\prime})}{\partial w_{r,s}^{(t)}}\big{|}\neg\mathcal{E}_{0},\mathcal{E}_{1}\right]\mathbb{P}\left[\neg\mathcal{E}_{0}\big{|}\mathcal{E}_{1}\right]

This implies,

||^(θ(t))wr,s(t)𝔼𝒟[(θ(t),x,y)wr,s(t)|0,1]||O~(1)poly(m,n,p,δ,1ϵ)\displaystyle\left|\left|\cfrac{{\partial}\hat{\mathcal{L}}(\theta^{(t)})}{\partial w_{r,s}^{(t)}}-\mathbb{E}_{\mathcal{D}}\left[\cfrac{\partial\mathcal{L}(\theta^{(t)},x^{\prime},y^{\prime})}{\partial w_{r,s}^{(t)}}\big{|}\mathcal{E}_{0},\mathcal{E}_{1}\right]\right|\right|\leq\cfrac{\tilde{O}(1)}{poly(m,n,p,\delta,\frac{1}{\epsilon})}

Again, conditioned on \mathcal{E}_{1},

ΔL(θ(t),θ(t+1)):=𝔼𝒟[(θ(t),x,y)(θ(t+1),x,y)|1]\displaystyle\Delta L(\theta^{(t)},\theta^{(t+1)}):=\mathbb{E}_{\mathcal{D}}\left[\mathcal{L}(\theta^{(t)},x^{\prime},y^{\prime})-\mathcal{L}(\theta^{(t+1)},x^{\prime},y^{\prime})\big{|}\mathcal{E}_{1}\right]
=𝔼𝒟[(θ(t),x,y)(θ(t+1),x,y)|0,1][0|1]\displaystyle=\mathbb{E}_{\mathcal{D}}\left[\mathcal{L}(\theta^{(t)},x^{\prime},y^{\prime})-\mathcal{L}(\theta^{(t+1)},x^{\prime},y^{\prime})\big{|}\mathcal{E}_{0},\mathcal{E}_{1}\right]\mathbb{P}\left[\mathcal{E}_{0}\big{|}\mathcal{E}_{1}\right]
+𝔼𝒟[(θ(t),x,y)(θ(t+1),x,y)|¬0,1][¬0|1]\displaystyle\hskip 28.45274pt+\mathbb{E}_{\mathcal{D}}\left[\mathcal{L}(\theta^{(t)},x^{\prime},y^{\prime})-\mathcal{L}(\theta^{(t+1)},x^{\prime},y^{\prime})\big{|}\neg\mathcal{E}_{0},\mathcal{E}_{1}\right]\mathbb{P}\left[\neg\mathcal{E}_{0}\big{|}\mathcal{E}_{1}\right]
ηr[m/2],s[2](θ(t))wr,s(t),𝔼𝒟[(θ(t),x,y)wr,s(t)|0,1]O(ητlmσ)O(η2m2)\displaystyle\geq\eta\underset{r\in\left[m/2\right],s\in[2]}{\sum}\left\langle\cfrac{\partial\mathcal{L}(\theta^{(t)})}{\partial w_{r,s}^{(t)}},\mathbb{E}_{\mathcal{D}}\left[\cfrac{\partial\mathcal{L}(\theta^{(t)},x^{\prime},y^{\prime})}{\partial w_{r,s}^{(t)}}\big{|}\mathcal{E}_{0},\mathcal{E}_{1}\right]\right\rangle-\overset{\sim}{O}\left(\cfrac{\eta\tau lm}{\sigma}\right)-\overset{\sim}{O}\left(\eta^{2}m^{2}\right)
O~(m)poly(m,n,p,δ,1ϵ)\displaystyle\hskip 28.45274pt-\cfrac{\tilde{O}(m)}{poly(m,n,p,\delta,\frac{1}{\epsilon})}
ηr[m/2],s[2](θ(t))wr,s(t),^(θ(t))wr,s(t)O(ητlmσ)O(η2m2)O~(m)poly(m,n,p,δ,1ϵ)\displaystyle\geq\eta\underset{r\in\left[m/2\right],s\in[2]}{\sum}\left\langle\cfrac{\partial\mathcal{L}(\theta^{(t)})}{\partial w_{r,s}^{(t)}},\cfrac{{\partial}\hat{\mathcal{L}}(\theta^{(t)})}{\partial w_{r,s}^{(t)}}\right\rangle-\overset{\sim}{O}\left(\cfrac{\eta\tau lm}{\sigma}\right)-\overset{\sim}{O}\left(\eta^{2}m^{2}\right)-\cfrac{\tilde{O}(m)}{poly(m,n,p,\delta,\frac{1}{\epsilon})}
O~(ηm)poly(m,n,p,δ,1ϵ)\displaystyle\hskip 28.45274pt-\cfrac{\tilde{O}(\eta m)}{poly(m,n,p,\delta,\frac{1}{\epsilon})}
ηr[m/2],s[2](θ(t))wr,s(t),^(θ(t))wr,s(t)O(ητlmσ)O(η2m2)\displaystyle\geq\eta\underset{r\in\left[m/2\right],s\in[2]}{\sum}\left\langle\cfrac{\partial\mathcal{L}(\theta^{(t)})}{\partial w_{r,s}^{(t)}},\cfrac{{\partial}\hat{\mathcal{L}}(\theta^{(t)})}{\partial w_{r,s}^{(t)}}\right\rangle-\overset{\sim}{O}\left(\cfrac{\eta\tau lm}{\sigma}\right)-\overset{\sim}{O}\left(\eta^{2}m^{2}\right)

Now, w.h.p.

(θ(t))wr,s(t)^(θ(t))wr,s(t)=O~(1B)\displaystyle\left|\left|\cfrac{\partial\mathcal{L}(\theta^{(t)})}{\partial w_{r,s}^{(t)}}-\cfrac{{\partial}\hat{\mathcal{L}}(\theta^{(t)})}{\partial w_{r,s}^{(t)}}\right|\right|=\tilde{O}\left(\cfrac{1}{\sqrt{B}}\right)

Therefore,

(θ(t))wr,s(t),^(θ(t))wr,s(t)(θ(t))wr,s(t)2O~(1B)\displaystyle\left\langle\cfrac{\partial\mathcal{L}(\theta^{(t)})}{\partial w_{r,s}^{(t)}},\cfrac{{\partial}\hat{\mathcal{L}}(\theta^{(t)})}{\partial w_{r,s}^{(t)}}\right\rangle\geq\left|\left|\cfrac{\partial\mathcal{L}(\theta^{(t)})}{\partial w_{r,s}^{(t)}}\right|\right|^{2}-\tilde{O}\left(\cfrac{1}{\sqrt{B}}\right)

Therefore,

ΔL(θ(t),θ(t+1))\displaystyle\Delta L(\theta^{(t)},\theta^{(t+1)}) ηr[m],s[2](θ(t))wr,s(t)2O~(ητlmσ)O~(η2m2)ηO~(mB)\displaystyle\geq\eta\underset{r\in\left[m\right],s\in[2]}{\sum}\left|\left|\cfrac{\partial\mathcal{L}(\theta^{(t)})}{\partial w_{r,s}^{(t)}}\right|\right|^{2}-\tilde{O}\left(\cfrac{\eta\tau lm}{\sigma}\right)-\tilde{O}\left(\eta^{2}m^{2}\right)-\eta\tilde{O}\left(\cfrac{m}{\sqrt{B}}\right)
ηmγ3l2Ω~(s[2](vs(t))2)O~(ητlmσ)O~(η2m2)ηO~(mB)\displaystyle\geq\cfrac{\eta m\gamma^{3}}{l^{2}}\tilde{\Omega}\left(\underset{s\in[2]}{\sum}(v_{s}^{(t)})^{2}\right)-\tilde{O}\left(\cfrac{\eta\tau lm}{\sigma}\right)-\tilde{O}\left(\eta^{2}m^{2}\right)-\eta\tilde{O}\left(\cfrac{m}{\sqrt{B}}\right)
ηmγ3l2Ω((v(t))2)O~(ητlmσ)O~(η2m2)ηO~(mB)\displaystyle\geq\cfrac{\eta m\gamma^{3}}{l^{2}}\overset{\sim}{\Omega}\left((v^{(t)})^{2}\right)-\tilde{O}\left(\cfrac{\eta\tau lm}{\sigma}\right)-\tilde{O}\left(\eta^{2}m^{2}\right)-\eta\tilde{O}\left(\cfrac{m}{\sqrt{B}}\right)
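
For readability, the parameter choices in the next step can be read off by requiring each error term above to be absorbed into the leading term \cfrac{\eta m\gamma^{3}}{l^{2}}(v^{(t)})^{2} (a short verification added here; all quantities are as defined above):

\displaystyle\tilde{O}\left(\eta^{2}m^{2}\right)=\tilde{O}\left(\cfrac{\eta m\gamma^{3}(v^{(t)})^{2}}{l^{2}}\right)\hskip 14.22636pt\text{whenever}\hskip 14.22636pt\eta=\tilde{O}\left(\cfrac{\gamma^{3}(v^{(t)})^{2}}{ml^{2}}\right),

\displaystyle\eta\tilde{O}\left(\cfrac{m}{\sqrt{B}}\right)=\tilde{O}\left(\cfrac{\eta m\gamma^{3}(v^{(t)})^{2}}{l^{2}}\right)\hskip 14.22636pt\text{whenever}\hskip 14.22636ptB=\tilde{\Omega}\left(\cfrac{l^{4}}{\gamma^{6}(v^{(t)})^{4}}\right),

\displaystyle\tilde{O}\left(\cfrac{\eta\tau lm}{\sigma}\right)=\tilde{O}\left(\cfrac{\eta m\gamma^{3}(v^{(t)})^{2}}{l^{2}}\right)\hskip 14.22636pt\text{whenever}\hskip 14.22636pt\tau=\tilde{O}\left(\cfrac{\sigma\gamma^{3}(v^{(t)})^{2}}{l^{3}}\right).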

Now, selecting \eta=\tilde{O}\left(\cfrac{\gamma^{3}(v^{(t)})^{2}}{ml^{2}}\right), B=\tilde{\Omega}\left(\cfrac{l^{4}}{\gamma^{6}(v^{(t)})^{4}}\right), and \tau=\tilde{O}\left(\cfrac{\sigma\gamma^{3}(v^{(t)})^{2}}{l^{3}}\right) (and hence for t=\tilde{O}\left(\cfrac{\sigma\gamma^{3}(v^{(t)})^{2}}{\eta l^{3}}\right)), we get,

ΔL(θ(t),θ(t+1))ηmγ3l2Ω~((v(t))2)\displaystyle\Delta L(\theta^{(t)},\theta^{(t+1)})\geq\cfrac{\eta m\gamma^{3}}{l^{2}}\tilde{\Omega}\left((v^{(t)})^{2}\right)

Appendix H Lemmas Used to Prove the Theorem 4.5

In joint-training pMoE, i.e., at any iteration t of Step-2 of Algorithm 2, the gradient of the single-sample loss with respect to the hidden nodes of the experts is:

(θ(t),x,y)wr,s(t)=yar,sv(t)(θ(t),x,y)(1ljJs(ws(t),x)Gj,s(ws(t),x)x(j)1wr,s(t),x(j)0)\frac{\partial\mathcal{L}(\theta^{(t)},x,y)}{\partial w_{r,s}^{(t)}}=-ya_{r,s}v^{(t)}(\theta^{(t)},x,y)\left(\cfrac{1}{l}\underset{j\in J_{s}(w_{s}^{(t)},x)}{\sum}G_{j,s}(w_{s}^{(t)},x)x^{(j)}1_{\langle w_{r,s}^{(t)},x^{(j)}\rangle\geq 0}\right) (19)

and the corresponding pseudo-gradient:

(θ(t),x,y)wr,s(t)=yar,sv(t)(θ(t),x,y)(1ljJs(ws(t),x)Gj,s(ws(t),x)x(j)1wr,s(0),x(j)0)\frac{\overset{\sim}{\partial}\mathcal{L}(\theta^{(t)},x,y)}{\partial w_{r,s}^{(t)}}=-ya_{r,s}v^{(t)}(\theta^{(t)},x,y)\left(\cfrac{1}{l}\underset{j\in J_{s}(w_{s}^{(t)},x)}{\sum}G_{j,s}(w_{s}^{(t)},x)x^{(j)}1_{\langle w_{r,s}^{(0)},x^{(j)}\rangle\geq 0}\right) (20)

and the expected pseudo-gradient:

^(θ(t))wr,s(t)=𝔼𝒟[(θ(t),x,y)wr,s(t)]\displaystyle\frac{\overset{\sim}{\partial}\hat{\mathcal{L}}(\theta^{(t)})}{\partial w_{r,s}^{(t)}}=\mathbb{E}_{\mathcal{D}}\left[\frac{\overset{\sim}{\partial}\mathcal{L}(\theta^{(t)},x,y)}{\partial w_{r,s}^{(t)}}\right]
=ar,s2(𝔼𝒟|y=+1[v(t)(θ(t),x,y)(1ljJs(ws(t),x)Gj,s(ws(t),x)x(j)1wr,s(0),x(j)0)|y=+1]\displaystyle=-\cfrac{a_{r,s}}{2}\left(\mathbb{E}_{\mathcal{D}|y=+1}\left[v^{(t)}(\theta^{(t)},x,y)\left(\cfrac{1}{l}\underset{j\in J_{s}(w_{s}^{(t)},x)}{\sum}G_{j,s}(w_{s}^{(t)},x)x^{(j)}1_{\langle w_{r,s}^{(0)},x^{(j)}\rangle\geq 0}\right)\Big{|}y=+1\right]\right.
𝔼𝒟|y=1[v(t)(θ(t),x,y)(1ljJs(ws(t),x)Gj,s(ws(t),x)x(j)1wr,s(0),Pjx0)|y=1])\displaystyle\left.\hskip 28.45274pt-\mathbb{E}_{\mathcal{D}|y=-1}\left[v^{(t)}(\theta^{(t)},x,y)\left(\cfrac{1}{l}\underset{j\in J_{s}(w_{s}^{(t)},x)}{\sum}G_{j,s}(w_{s}^{(t)},x)x^{(j)}1_{\langle w_{r,s}^{(0)},P_{j}x\rangle\geq 0}\right)\Big{|}y=-1\right]\right)
=ar,s2Pr,s(t)\displaystyle=-\cfrac{a_{r,s}}{2}P_{r,s}^{(t)}

with,

Pr,s(t)\displaystyle P_{r,s}^{(t)} =𝔼𝒟|y=+1[v(t)(θ(t),x,y)(1ljJs(ws(t),x)Gj,s(ws(t),x)x(j)1wr,s(0),x(j)0)|y=+1]\displaystyle=\mathbb{E}_{\mathcal{D}|y=+1}\left[v^{(t)}(\theta^{(t)},x,y)\left(\cfrac{1}{l}\underset{j\in J_{s}(w_{s}^{(t)},x)}{\sum}G_{j,s}(w_{s}^{(t)},x)x^{(j)}1_{\langle w_{r,s}^{(0)},x^{(j)}\rangle\geq 0}\right)\Big{|}y=+1\right]
𝔼𝒟|y=1[v(t)(θ(t),x,y)(1ljJs(ws(t),x)Gj,s(ws(t),x)x(j)1wr,s(0),x(j)0)|y=1]\displaystyle\hskip 28.45274pt-\mathbb{E}_{\mathcal{D}|y=-1}\left[v^{(t)}(\theta^{(t)},x,y)\left(\cfrac{1}{l}\underset{j\in J_{s}(w_{s}^{(t)},x)}{\sum}G_{j,s}(w_{s}^{(t)},x)x^{(j)}1_{\langle w_{r,s}^{(0)},x^{(j)}\rangle\geq 0}\right)\Big{|}y=-1\right]
Lemma H.1.

W.h.p. over the random initialization of the hidden nodes of the experts defined in (8), for every (x,y)𝒟(x,y)\sim\mathcal{D} and for every τ>0\tau>0, for every t=O~(τlη)t=\tilde{O}\left(\cfrac{\tau l}{\eta}\right) of the Step-2 of Algorithm 2, we have that for at least (12eτnσ)\left(1-\cfrac{2e\tau n}{\sigma}\right) fraction of r[m/k]r\in[m/k] of the expert s[k]s\in[k]:

(θ(t),x,y)wr,s(t)=(θ(t),x,y)wr,s(t)\cfrac{\partial\mathcal{L}(\theta^{(t)},x,y)}{\partial w_{r,s}^{(t)}}=\cfrac{\overset{\sim}{\partial}\mathcal{L}(\theta^{(t)},x,y)}{\partial w_{r,s}^{(t)}} and |wr,s(t),x(j)|τ,j[n]|\langle w_{r,s}^{(t)},x^{(j)}\rangle|\geq\tau,\forall j\in[n]

Proof.

Using a similar argument as in Lemma G.1, and since \sum_{j\in J_{s}(w_{s}^{(t)},x)}G_{j,s}(w_{s}^{(t)},x)=1, w.h.p. \left|\left|\frac{\partial\mathcal{L}(\theta^{(t)},x,y)}{\partial w_{r,s}^{(t)}}\right|\right|=\tilde{O}\left(\frac{1}{l}\right), and so is the mini-batch gradient: \left|\left|\frac{\partial\mathcal{L}(\theta^{(t)})}{\partial w_{r,s}^{(t)}}\right|\right|=\tilde{O}\left(\frac{1}{l}\right).

Therefore, wr,s(t)wr,s(0)=O~(ηtl)\left|\left|w_{r,s}^{(t)}-w_{r,s}^{(0)}\right|\right|=\tilde{O}\left(\frac{\eta t}{l}\right).

Now, for every τ>0,\tau>0, considering the set s:={r[m/k]:j[n],|wr,s(0),x(j)|2τ}\mathcal{H}_{s}:=\left\{r\in[m/k]:\forall j\in[n],|\langle w_{r,s}^{(0)},x^{(j)}\rangle|\geq 2\tau\right\} and following the same procedure as in Lemma G.1 we can complete the proof.

Lemma H.2.

For the expert s[k]s\in[k] and any possible fixed set {v(t)(θ(t),x,y)Gj,s(ws(t),x):(x,y)𝒟,jJs(ws(t),x)}\{v^{(t)}(\theta^{(t)},x,y)G_{j,s}(w_{s}^{(t)},x):(x,y)\sim\mathcal{D},j\in J_{s}(w_{s}^{(t)},x)\} (that does not depend on wr,s(0)w_{r,s}^{(0)}) such that vs(t)=v1,s(t)=max{v1,s(t),v2,s(t)}v_{s}^{(t)}=v_{1,s}^{(t)}=\max\{v_{1,s}^{(t)},v_{2,s}^{(t)}\}, we have:

[Pr,s(t)=Ω(vs(t)lpδ)]=Ω(1pδ)\mathbb{P}\left[||P_{r,s}^{(t)}||=\overset{\sim}{\Omega}\left(\cfrac{v_{s}^{(t)}}{lp\sqrt{\delta}}\right)\right]=\Omega\left(\cfrac{1}{p\sqrt{\delta}}\right)

Proof.

We know that,

Pr,s(t)\displaystyle P_{r,s}^{(t)} =𝔼𝒟|y=+1[v(t)(θ(t),x,y)(1ljJs(ws(t),x)Gj,s(ws(t),x)x(j)1wr,s(0),x(j)0)|y=+1]\displaystyle=\mathbb{E}_{\mathcal{D}|y=+1}\left[v^{(t)}(\theta^{(t)},x,y)\left(\cfrac{1}{l}\underset{j\in J_{s}(w_{s}^{(t)},x)}{\sum}G_{j,s}(w_{s}^{(t)},x)x^{(j)}1_{\langle w_{r,s}^{(0)},x^{(j)}\rangle\geq 0}\right)\Big{|}y=+1\right]
𝔼𝒟|y=1[v(t)(θ(t),x,y)(1ljJs(ws(t),x)Gj,s(ws(t),x)x(j)1wr,s(0),x(j)0)|y=1]\displaystyle\hskip 28.45274pt-\mathbb{E}_{\mathcal{D}|y=-1}\left[v^{(t)}(\theta^{(t)},x,y)\left(\cfrac{1}{l}\underset{j\in J_{s}(w_{s}^{(t)},x)}{\sum}G_{j,s}(w_{s}^{(t)},x)x^{(j)}1_{\langle w_{r,s}^{(0)},x^{(j)}\rangle\geq 0}\right)\Big{|}y=-1\right]

Therefore,

h(wr,s(0)):=Pr,s,wr,s(0)\displaystyle h(w_{r,s}^{(0)}):=\langle P_{r,s},w_{r,s}^{(0)}\rangle
=𝔼𝒟|y=+1[v(t)(θ(t),x,y)(1ljJs(ws(t),x)Gj,s(ws(t),x)ReLU(wr,s(0),x(j)))|y=+1]\displaystyle=\mathbb{E}_{\mathcal{D}|y=+1}\left[v^{(t)}(\theta^{(t)},x,y)\left(\cfrac{1}{l}\underset{j\in J_{s}(w_{s}^{(t)},x)}{\sum}G_{j,s}(w_{s}^{(t)},x)\textbf{ReLU}\left(\langle w_{r,s}^{(0)},x^{(j)}\rangle\right)\right)\Big{|}y=+1\right]
𝔼𝒟|y=1[v(t)(θ(t),x,y)(1ljJs(ws(t),x)Gj,s(ws(t),x)ReLU(wr,s(0),x(j)))|y=1]\displaystyle\hskip 28.45274pt-\mathbb{E}_{\mathcal{D}|y=-1}\left[v^{(t)}(\theta^{(t)},x,y)\left(\cfrac{1}{l}\underset{j\in J_{s}(w_{s}^{(t)},x)}{\sum}G_{j,s}(w_{s}^{(t)},x)\textbf{ReLU}\left(\langle w_{r,s}^{(0)},x^{(j)}\rangle\right)\right)\Big{|}y=-1\right]

Now, decomposing wr,s(0)=αo1+βw_{r,s}^{(0)}=\alpha o_{1}+\beta with βo1\beta\perp o_{1} we get,

h(wr,s(0))=v1,s(t)lReLU(α)\displaystyle h(w_{r,s}^{(0)})=\cfrac{v_{1,s}^{(t)}}{l}\textbf{ReLU}\left(\alpha\right)
+𝔼𝒟|y=+1,1,s(t)[p1,s(t)v(t)(θ(t),x,y)(1ljJs(x)/jo1Gj,sReLU(αo1,x(j)+β,x(j)))]\displaystyle+\mathbb{E}_{\mathcal{D}|y=+1,\mathcal{E}_{1,s}^{(t)}}\left[p_{1,s}^{(t)}v^{(t)}(\theta^{(t)},x,y)\left(\cfrac{1}{l}\underset{j\in J_{s}(x)/j_{o_{1}}}{\sum}G_{j,s}\textbf{ReLU}\left(\alpha\langle o_{1},x^{(j)}\rangle+\langle\beta,x^{(j)}\rangle\right)\right)\right]
+𝔼𝒟|y=+1,¬1,s(t)[(1p1,s(t))v(t)(θ(t),x,y)(1ljJs(x)Gj,sReLU(αo1,x(j)+β,x(j)))]\displaystyle+\mathbb{E}_{\mathcal{D}|y=+1,\neg\mathcal{E}_{1,s}^{(t)}}\left[(1-p_{1,s}^{(t)})v^{(t)}(\theta^{(t)},x,y)\left(\cfrac{1}{l}\underset{j\in J_{s}(x)}{\sum}G_{j,s}\textbf{ReLU}\left(\alpha\langle o_{1},x^{(j)}\rangle+\langle\beta,x^{(j)}\rangle\right)\right)\right]
𝔼𝒟|y=1[v(t)(θ(t),x,y)(1ljJs(x)Gj,sReLU(αo1,x(j)+β,x(j)))]\displaystyle-\mathbb{E}_{\mathcal{D}|y=-1}\left[v^{(t)}(\theta^{(t)},x,y)\left(\cfrac{1}{l}\underset{j\in J_{s}(x)}{\sum}G_{j,s}\textbf{ReLU}\left(\alpha\langle o_{1},x^{(j)}\rangle+\langle\beta,x^{(j)}\rangle\right)\right)\right]
=ϕ(α)l(α)\displaystyle=\phi(\alpha)-l(\alpha)

where,

ϕ(α):=v1,s(t)lReLU(α)\displaystyle\phi(\alpha):=\cfrac{v_{1,s}^{(t)}}{l}\textbf{ReLU}\left(\alpha\right)
+𝔼𝒟|y=+1,1,s(t)[p1,s(t)v(t)(θ(t),x,y)(1ljJs(x)/jo1Gj,sReLU(αo1,x(j)+β,x(j)))]\displaystyle+\mathbb{E}_{\mathcal{D}|y=+1,\mathcal{E}_{1,s}^{(t)}}\left[p_{1,s}^{(t)}v^{(t)}(\theta^{(t)},x,y)\left(\cfrac{1}{l}\underset{j\in J_{s}(x)/j_{o_{1}}}{\sum}G_{j,s}\textbf{ReLU}\left(\alpha\langle o_{1},x^{(j)}\rangle+\langle\beta,x^{(j)}\rangle\right)\right)\right]
+𝔼𝒟|y=+1,¬1,s(t)[(1p1,s(t))v(t)(θ(t),x,y)(1ljJs(x)Gj,sReLU(αo1,x(j)+β,x(j)))]\displaystyle+\mathbb{E}_{\mathcal{D}|y=+1,\neg\mathcal{E}_{1,s}^{(t)}}\left[(1-p_{1,s}^{(t)})v^{(t)}(\theta^{(t)},x,y)\left(\cfrac{1}{l}\underset{j\in J_{s}(x)}{\sum}G_{j,s}\textbf{ReLU}\left(\alpha\langle o_{1},x^{(j)}\rangle+\langle\beta,x^{(j)}\rangle\right)\right)\right]

and

l(α):=𝔼𝒟|y=1[v(t)(θ(t),x,y)(1ljJs(x)Gj,sReLU(αo1,x(j)+β,x(j)))]\displaystyle l(\alpha):=\mathbb{E}_{\mathcal{D}|y=-1}\left[v^{(t)}(\theta^{(t)},x,y)\left(\cfrac{1}{l}\underset{j\in J_{s}(x)}{\sum}G_{j,s}\textbf{ReLU}\left(\alpha\langle o_{1},x^{(j)}\rangle+\langle\beta,x^{(j)}\rangle\right)\right)\right]

Now as ϕ(α)\phi(\alpha) and l(α)l(\alpha) both are convex functions, using the same procedure as in Lemma G.1 we can complete the proof. ∎

Lemma H.3.

Let vs(t)=max{v1,s(t),v2,s(t)}v_{s}^{(t)}=\max\{v_{1,s}^{(t)},v_{2,s}^{(t)}\}. Then, for every vs(t)>0v_{s}^{(t)}>0, for m=Ω~(klp3δ3/2(vs(t))2)m=\tilde{\Omega}\left(\cfrac{klp^{3}\delta^{3/2}}{(v_{s}^{(t)})^{2}}\right), for every possible set {v(t)(θ(t),x,y)Gj,s(ws(t),x):(x,y)𝒟,jJs(ws(t),x)}\{v^{(t)}(\theta^{(t)},x,y)G_{j,s}(w_{s}^{(t)},x):(x,y)\sim\mathcal{D},j\in J_{s}(w_{s}^{(t)},x)\} (that depends on wr,s(0)w_{r,s}^{(0)}), there exist at least Ω(1pδ)\Omega\left(\frac{1}{p\sqrt{\delta}}\right) fraction of r[m/k]r\in[m/k] of the expert s[k]s\in[k] such that,

^(θ(t))wr,s(t)=Ω~(vs(t)lpδ)\left|\left|\cfrac{\overset{\sim}{\partial}\hat{\mathcal{L}}(\theta^{(t)})}{\partial w_{r,s}^{(t)}}\right|\right|=\tilde{\Omega}\left(\cfrac{v_{s}^{(t)}}{lp\sqrt{\delta}}\right)

Proof.

Let us pick SS samples to form S={(xi,yi)}i=1S\textbf{S}=\{(x_{i},y_{i})\}_{i=1}^{S} with S/2S/2 many samples from y=+1y=+1 such that 12p1,s(t)S\frac{1}{2}p_{1,s}^{(t)}S many samples of them satisfy the event 1,s(t)\mathcal{E}_{1,s}^{(t)} and S/2S/2 many samples from y=1y=-1 such that 12p2,s(t)S\frac{1}{2}p_{2,s}^{(t)}S many samples of them satisfy the event 2,s(t)\mathcal{E}_{2,s}^{(t)}. We denote the subset of S satisfying the event 1,s(t)\mathcal{E}_{1,s}^{(t)} by S1\textbf{S}_{1} and the subset of S satisfying the event 2,s(t)\mathcal{E}_{2,s}^{(t)} by S2\textbf{S}_{2}. Therefore, |S1|=12p1,s(t)S\left|\textbf{S}_{1}\right|=\frac{1}{2}p_{1,s}^{(t)}S and |S2|=12p2,s(t)S\left|\textbf{S}_{2}\right|=\frac{1}{2}p_{2,s}^{(t)}S. Now, w.h.p. :

|v1,s(t)2p1,s(t)S(xi,yi)S1p1,s(t)Gjo1,s(t)(xi)v(t)(θ(t),xi,yi)|=O~(1p1,s(t)S) and\displaystyle\left|v_{1,s}^{(t)}-\cfrac{2}{p_{1,s}^{(t)}S}\underset{(x_{i},y_{i})\in\textbf{S}_{1}}{\sum}p_{1,s}^{(t)}G_{j_{o_{1}},s}^{(t)}(x_{i})v^{(t)}(\theta^{(t)},x_{i},y_{i})\right|=\tilde{O}\left(\cfrac{1}{\sqrt{p_{1,s}^{(t)}S}}\right)\text{ and }
|v2,s(t)2p2,s(t)S(xi,yi)S2p2,s(t)Gjo2,s(t)(xi)v(t)(θ(t),xi,yi)|=O~(1p2,s(t)S)\displaystyle\left|v_{2,s}^{(t)}-\cfrac{2}{p_{2,s}^{(t)}S}\underset{(x_{i},y_{i})\in\textbf{S}_{2}}{\sum}p_{2,s}^{(t)}G_{j_{o_{2}},s}^{(t)}(x_{i})v^{(t)}(\theta^{(t)},x_{i},y_{i})\right|=\tilde{O}\left(\cfrac{1}{\sqrt{p_{2,s}^{(t)}S}}\right)

This implies that, as long as S=Ω~(1(vs(t))2)S=\tilde{\Omega}\left(\cfrac{1}{(v_{s}^{(t)})^{2}}\right), we will have that,

max{2p1,s(t)S(xi,yi)S1p1,s(t)Gjo1,s(t)(xi)v(t)(θ(t),xi,yi),\displaystyle\max\left\{\cfrac{2}{p_{1,s}^{(t)}S}\underset{(x_{i},y_{i})\in\textbf{S}_{1}}{\sum}p_{1,s}^{(t)}G_{j_{o_{1}},s}^{(t)}(x_{i})v^{(t)}(\theta^{(t)},x_{i},y_{i}),\right.
2p2,s(t)S(xi,yi)S2p2,s(t)Gjo2,s(t)(xi)v(t)(θ(t),xi,yi)}[12vs(t),32vs(t)]\displaystyle\left.\hskip 56.9055pt\cfrac{2}{p_{2,s}^{(t)}S}\underset{(x_{i},y_{i})\in\textbf{S}_{2}}{\sum}p_{2,s}^{(t)}G_{j_{o_{2}},s}^{(t)}(x_{i})v^{(t)}(\theta^{(t)},x_{i},y_{i})\right\}\in\left[\cfrac{1}{2}v_{s}^{(t)},\cfrac{3}{2}v_{s}^{(t)}\right]

Now using the same procedure as in Lemma G.3 and using Lemma H.2 we can show that, for a fixed set {v(t)(θ(t),xi,yi)Gj,s(ws(t),xi):(xi,yi)S,jJs(ws(t),xi)}\{v^{(t)}(\theta^{(t)},x_{i},y_{i})G_{j,s}(w_{s}^{(t)},x_{i}):(x_{i},y_{i})\in\textbf{S},j\in J_{s}(w_{s}^{(t)},x_{i})\} as long as S=Ω~(1(vs(t))2)S=\tilde{\Omega}\left(\cfrac{1}{(v_{s}^{(t)})^{2}}\right), the probability that there are less than O(1pδ)O\left(\frac{1}{p\sqrt{\delta}}\right) fraction of r[m/k]r\in[m/k] such that 1S(xi,yi)S(θ(t),xi,yi)wr,s(t)\left|\left|\cfrac{1}{S}\underset{(x_{i},y_{i})\in\textbf{S}}{\sum}\cfrac{\overset{\sim}{\partial}\mathcal{L}(\theta^{(t)},x_{i},y_{i})}{\partial w_{r,s}^{(t)}}\right|\right| is Ω~(vs(t)lpδ)\tilde{\Omega}\left(\cfrac{v_{s}^{(t)}}{lp\sqrt{\delta}}\right) is no more than pfixp_{\text{fix}} where, pfixexp(Ω(mkpδ))p_{\text{fix}}\leq\exp{\left(-\Omega\left(\frac{m}{kp\sqrt{\delta}}\right)\right)}.

Now, for every ε¯>0\bar{\varepsilon}>0, for two different {v(t)(θ(t),xi,yi)Gj,s(ws(t),xi):(xi,yi)S,jJs(ws(t),xi)}\{v^{(t)}(\theta^{(t)},x_{i},y_{i})G_{j,s}(w_{s}^{(t)},x_{i}):(x_{i},y_{i})\in\textbf{S},j\in J_{s}(w_{s}^{(t)},x_{i})\}, {v(t)(θ(t),xi,yi)Gj,s(ws(t),xi):(xi,yi)S,jJs(ws(t),xi)}\{v^{\prime(t)}(\theta^{(t)},x_{i},y_{i})G_{j,s}^{\prime}(w_{s}^{(t)},x_{i}):(x_{i},y_{i})\in\textbf{S},j\in J_{s}(w_{s}^{(t)},x_{i})\} such that (xi,yi)S,jJs(ws(t),xi)\forall(x_{i},y_{i})\in\textbf{S},j\in J_{s}(w_{s}^{(t)},x_{i}), |v(t)(θ(t),xi,yi)Gj,s(ws(t),xi)v(t)(θ(t),xi,yi)Gj,s(ws(t),xi)|ε¯|v^{(t)}(\theta^{(t)},x_{i},y_{i})G_{j,s}(w_{s}^{(t)},x_{i})-v^{\prime(t)}(\theta^{(t)},x_{i},y_{i})G_{j,s}^{\prime}(w_{s}^{(t)},x_{i})|\leq\bar{\varepsilon}, w.h.p.,

||1S(xi,yi)Syar,sljJs(ws(t),xi)(v(t)(θ(t),xi,yi)Gj,s(ws(t),xi)\displaystyle\left|\left|\cfrac{1}{S}\underset{(x_{i},y_{i})\in\textbf{S}}{\sum}\cfrac{-ya_{r,s}}{l}\underset{j\in J_{s}(w_{s}^{(t)},x_{i})}{\sum}\left(v^{(t)}(\theta^{(t)},x_{i},y_{i})G_{j,s}(w_{s}^{(t)},x_{i})\right.\right.\right.
v(t)(θ(t),xi,yi)Gj,s(ws(t),xi))xi(j)1wr,s(0),xi(j)0||=O~(ε¯)\displaystyle\left.\left.\left.\hskip 142.26378pt-v^{\prime(t)}(\theta^{(t)},x_{i},y_{i})G_{j,s}^{\prime}(w_{s}^{(t)},x_{i})\right)x_{i}^{(j)}1_{\langle w_{r,s}^{(0)},x_{i}^{(j)}\rangle\geq 0}\right|\right|=\tilde{O}(\bar{\varepsilon})

Therefore taking ε¯\bar{\varepsilon}-net with ε¯=Θ~(vs(t)lpδ)\bar{\varepsilon}=\tilde{\Theta}\left(\cfrac{v_{s}^{(t)}}{lp\sqrt{\delta}}\right) we can show that the probability that there exists {v(t)(θ(t),xi,yi)Gj,s(ws(t),xi):(xi,yi)S,jJs(ws(t),xi)}\{v^{(t)}(\theta^{(t)},x_{i},y_{i})G_{j,s}(w_{s}^{(t)},x_{i}):(x_{i},y_{i})\in\textbf{S},j\in J_{s}(w_{s}^{(t)},x_{i})\} such that there are no more than O(1pδ)O\left(\frac{1}{p\sqrt{\delta}}\right) fraction of r[m/k]r\in[m/k] with 1S(xi,yi)S(θ(t),xi,yi)wr,s(t)=Ω~(vs(t)lpδ)\left|\left|\cfrac{1}{S}\underset{(x_{i},y_{i})\in\textbf{S}}{\sum}\cfrac{\overset{\sim}{\partial}\mathcal{L}(\theta^{(t)},x_{i},y_{i})}{\partial w_{r,s}^{(t)}}\right|\right|=\tilde{\Omega}\left(\cfrac{v_{s}^{(t)}}{lp\sqrt{\delta}}\right) is no more than, ppfix(vs(t)ε¯)Slexp(Ω(mkpδ)+Sllog(vs(t)ε¯))p\leq p_{\text{fix}}\left(\cfrac{v_{s}^{(t)}}{\bar{\varepsilon}}\right)^{Sl}\leq\exp{\left(-\Omega\left(\frac{m}{kp\sqrt{\delta}}\right)+Sl\log{\left(\cfrac{v_{s}^{(t)}}{\bar{\varepsilon}}\right)}\right)}.

Hence, for m=Ω(kSlpδ)m=\overset{\sim}{\Omega}\left(kSlp\sqrt{\delta}\right) with S=Ω~(1(vs(t))2)S=\tilde{\Omega}\left(\cfrac{1}{(v_{s}^{(t)})^{2}}\right), w.h.p. for every possible choice of {v(t)(θ(t),xi,yi)Gj,s(ws(t),xi):(xi,yi)S,jJs(ws(t),xi)}\{v^{(t)}(\theta^{(t)},x_{i},y_{i})G_{j,s}(w_{s}^{(t)},x_{i}):(x_{i},y_{i})\in\textbf{S},j\in J_{s}(w_{s}^{(t)},x_{i})\}, there are at least Ω(1pδ)\Omega\left(\frac{1}{p\sqrt{\delta}}\right) fraction of r[m/k]r\in[m/k] such that,

1S(xi,yi)S(θ(t),xi,yi)wr,s(t)=Ω~(vs(t)lpδ)\displaystyle\left|\left|\cfrac{1}{S}\underset{(x_{i},y_{i})\in\textbf{S}}{\sum}\cfrac{\overset{\sim}{\partial}\mathcal{L}(\theta^{(t)},x_{i},y_{i})}{\partial w_{r,s}^{(t)}}\right|\right|=\tilde{\Omega}\left(\cfrac{v_{s}^{(t)}}{lp\sqrt{\delta}}\right)

Now, as \left|\left|\cfrac{\overset{\sim}{\partial}\mathcal{L}(\theta^{(t)},x_{i},y_{i})}{\partial w_{r,s}^{(t)}}\right|\right|=\overset{\sim}{O}(1/l), using the same procedure as in Lemma G.3 we can complete the proof, which gives m=\tilde{\Omega}\left(\cfrac{klp^{3}\delta^{3/2}}{(v_{s}^{(t)})^{2}}\right). ∎

Lemma H.4.

Let us define v(t):=s[k](vs(t))2v^{(t)}:=\sqrt{\underset{s\in[k]}{\sum}(v_{s}^{(t)})^{2}} where vs(t)=max{v1,s(t),v2,s(t)}v_{s}^{(t)}=\max\{v_{1,s}^{(t)},v_{2,s}^{(t)}\} for all s[k]s\in[k]; γ:=Ω(1pδ)\gamma:=\Omega\left(\frac{1}{p\sqrt{\delta}}\right). Then, by selecting learning rate η=O~(γ3(v(t))2l3mk2)\eta=\tilde{O}\left(\cfrac{\gamma^{3}(v^{(t)})^{2}l^{3}}{mk^{2}}\right) and batch size B=Ω~(k2γ6(v(t))4)B=\tilde{\Omega}\left(\cfrac{k^{2}}{\gamma^{6}(v^{(t)})^{4}}\right), at each iteration tt of the Step-2 of Algorithm 2 such that t=O~(σγ3(v(t))2l2ηnk)t=\tilde{O}\left(\cfrac{\sigma\gamma^{3}(v^{(t)})^{2}l^{2}}{\eta nk}\right), w.h.p. we can ensure that,

ΔL(θ(t),θ(t+1))ηmγ3l2Ω~((v(t))2)\displaystyle\Delta L(\theta^{(t)},\theta^{(t+1)})\geq\cfrac{\eta m\gamma^{3}}{l^{2}}\tilde{\Omega}\left((v^{(t)})^{2}\right)
Proof.

As w.h.p. ~(θ(t),x,y)wr,s(t)=O(1/l)\left|\left|\cfrac{\tilde{\partial}\mathcal{L}(\theta^{(t)},x,y)}{\partial w_{r,s}^{(t)}}\right|\right|=\overset{\sim}{O}(1/l), for a randomly sampled batch t\mathcal{B}_{t} of size BB, by selecting τ=σγ4enB\tau=\cfrac{\sigma\gamma}{4enB} in Lemma H.1 and using the same procedure as in Lemma G.4, we can show that for at least γ/2\gamma/2 fraction of r[m/k]r\in[m/k] of expert s[k]s\in[k]:

(θ(t))wr,s(t)=Ω~(vs(t)lpδ)\displaystyle\left|\left|\cfrac{\partial\mathcal{L}(\theta^{(t)})}{\partial w_{r,s}^{(t)}}\right|\right|=\tilde{\Omega}\left(\cfrac{v_{s}^{(t)}}{lp\sqrt{\delta}}\right)

Now, for any (x,y)𝒟(x^{\prime},y^{\prime})\sim\mathcal{D}, from Lemma H.1 we know that for at least 12eτnσ1-\cfrac{2e\tau n}{\sigma} fraction of r[m/k]r\in[m/k] of any expert s[k]s\in[k], the loss function is O~(1/l)\tilde{O}(1/l)-Lipschitz smooth and also O~(1/l)\tilde{O}(1/l)-smooth.

Therefore, using the same procedure as in Lemma G.4, we can complete the proof. ∎

Appendix I Lemmas Used to Prove the Theorem 4.3

For the single CNN model, since all the patches of an input (x,y)\sim\mathcal{D} are sent to the model (i.e., there is no router), the gradient of the single-sample loss function w.r.t. hidden node r\in[m] is

(θ(t),x,y)wr(t)=yarv(t)(θ(t),x,y)(1nj[n]x(j)1wr(t),x(j)0)\displaystyle\frac{\partial\mathcal{L}(\theta^{(t)},x,y)}{\partial w_{r}^{(t)}}=-ya_{r}v^{(t)}(\theta^{(t)},x,y)\left(\cfrac{1}{n}\underset{j\in[n]}{\sum}x^{(j)}1_{\langle w_{r}^{(t)},x^{(j)}\rangle\geq 0}\right) (21)

the corresponding pseudo-gradient,

\displaystyle\frac{\overset{\sim}{\partial}\mathcal{L}(\theta^{(t)},x,y)}{\partial w_{r}^{(t)}}=-ya_{r}v^{(t)}(\theta^{(t)},x,y)\left(\cfrac{1}{n}\underset{j\in[n]}{\sum}x^{(j)}1_{\langle w_{r}^{(0)},x^{(j)}\rangle\geq 0}\right)

and the expected pseudo-gradient,

^(θ(t))wr(t)\displaystyle\frac{\overset{\sim}{\partial}\hat{\mathcal{L}}(\theta^{(t)})}{\partial w_{r}^{(t)}} =𝔼𝒟[(θ(t),x,y)wr(t)]\displaystyle=\mathbb{E}_{\mathcal{D}}\left[\frac{\overset{\sim}{\partial}\mathcal{L}(\theta^{(t)},x,y)}{\partial w_{r}^{(t)}}\right]
=ar2(𝔼𝒟|y=+1[v(t)(θ(t),x,y)(1nj[n]x(j)1wr(0),x(j)0)|y=+1]\displaystyle=-\cfrac{a_{r}}{2}\left(\mathbb{E}_{\mathcal{D}|y=+1}\left[v^{(t)}(\theta^{(t)},x,y)\left(\cfrac{1}{n}\underset{j\in[n]}{\sum}x^{(j)}1_{\langle w_{r}^{(0)},x^{(j)}\rangle\geq 0}\right)\Big{|}y=+1\right]\right.
𝔼𝒟|y=1[v(t)(θ(t),x,y)(1nj[n]x(j)1wr(0),Pjx0)|y=1])\displaystyle\left.\hskip 42.67912pt-\mathbb{E}_{\mathcal{D}|y=-1}\left[v^{(t)}(\theta^{(t)},x,y)\left(\cfrac{1}{n}\underset{j\in[n]}{\sum}x^{(j)}1_{\langle w_{r}^{(0)},P_{j}x\rangle\geq 0}\right)\Big{|}y=-1\right]\right)
=ar2Pr(t)\displaystyle=-\cfrac{a_{r}}{2}P_{r}^{(t)}

where,

\displaystyle P_{r}^{(t)}=\mathbb{E}_{\mathcal{D}|y=+1}\left[v^{(t)}(\theta^{(t)},x,y)\left(\cfrac{1}{n}\underset{j\in[n]}{\sum}x^{(j)}1_{\langle w_{r}^{(0)},x^{(j)}\rangle\geq 0}\right)\Big{|}y=+1\right]
𝔼𝒟|y=1[v(t)(θ(t),x,y)(1nj[n]x(j)1wr(0),x(j)0)|y=1]\displaystyle\hskip 42.67912pt-\mathbb{E}_{\mathcal{D}|y=-1}\left[v^{(t)}(\theta^{(t)},x,y)\left(\cfrac{1}{n}\underset{j\in[n]}{\sum}x^{(j)}1_{\langle w_{r}^{(0)},x^{(j)}\rangle\geq 0}\right)\Big{|}y=-1\right]
Lemma I.1.

W.h.p. over the random initialization, for every (x,y)𝒟(x,y)\sim\mathcal{D} and for every τ>0\tau>0, for every iteration t=O~(τη)t=\tilde{O}\left(\cfrac{\tau}{\eta}\right) of the minibatch SGD, we have that for at least (12eτnσ)\left(1-\cfrac{2e\tau n}{\sigma}\right) fraction of r[m]r\in[m]:

(θ(t),x,y)wr(t)=(θ(t),x,y)wr(t)\cfrac{\partial\mathcal{L}(\theta^{(t)},x,y)}{\partial w_{r}^{(t)}}=\cfrac{\overset{\sim}{\partial}\mathcal{L}(\theta^{(t)},x,y)}{\partial w_{r}^{(t)}} and |wr(t),x(j)|τ,j[n]|\langle w_{r}^{(t)},x^{(j)}\rangle|\geq\tau,\forall j\in[n]

Proof.

Using a similar argument as in Lemma G.1, we can show that w.h.p. \left|\left|\frac{\partial\mathcal{L}(\theta^{(t)},x,y)}{\partial w_{r}^{(t)}}\right|\right|=\tilde{O}\left(1\right), and so is the mini-batch gradient: \left|\left|\frac{\partial\mathcal{L}(\theta^{(t)})}{\partial w_{r}^{(t)}}\right|\right|=\tilde{O}\left(1\right).

Therefore, \left|\left|w_{r}^{(t)}-w_{r}^{(0)}\right|\right|=\tilde{O}\left(\eta t\right).

Now, for every τ>0,\tau>0, considering the set :={r[m]:j[n],|wr(0),x(j)|2τ}\mathcal{H}:=\left\{r\in[m]:\forall j\in[n],|\langle w_{r}^{(0)},x^{(j)}\rangle|\geq 2\tau\right\} and following the same procedure as in Lemma G.1 we can complete the proof.

Recall, v1(t):=𝔼𝒟|y=+1[v(t)(θ(t),x,y)|y=+1]v_{1}^{(t)}:=\mathbb{E}_{\mathcal{D}|y=+1}[v^{(t)}(\theta^{(t)},x,y)|y=+1] and v2(t):=𝔼𝒟|y=1[v(t)(θ(t),x,y)|y=1]v_{2}^{(t)}:=\mathbb{E}_{\mathcal{D}|y=-1}[v^{(t)}(\theta^{(t)},x,y)|y=-1].

Lemma I.2.

For any possible fixed set {v(t)(θ(t),x,y):(x,y)𝒟}\{v^{(t)}(\theta^{(t)},x,y):(x,y)\sim\mathcal{D}\} (that does not depend on wr(0)w_{r}^{(0)}) such that v(t)=v1(t)=max{v1(t),v2(t)}v^{(t)}=v_{1}^{(t)}=\max\{v_{1}^{(t)},v_{2}^{(t)}\}, we have:

[Pr(t)=Ω(v(t)npδ)]=Ω(1pδ)\mathbb{P}\left[||P_{r}^{(t)}||=\overset{\sim}{\Omega}\left(\cfrac{v^{(t)}}{np\sqrt{\delta}}\right)\right]=\Omega\left(\cfrac{1}{p\sqrt{\delta}}\right)

Proof.

We know that,

Pr(t)\displaystyle P_{r}^{(t)} =𝔼𝒟|y=+1[v(t)(θ(t),x,y)(1nj[n]x(j)1wr(0),x(j)0)|y=+1]\displaystyle=\mathbb{E}_{\mathcal{D}|y=+1}\left[v^{(t)}(\theta^{(t)},x,y)\left(\cfrac{1}{n}\underset{j\in[n]}{\sum}x^{(j)}1_{\langle w_{r}^{(0)},x^{(j)}\rangle\geq 0}\right)\Big{|}y=+1\right]
𝔼𝒟|y=1[v(t)(θ(t),x,y)(1nj[n]x(j)1wr(0),x(j)0)|y=1]\displaystyle\hskip 28.45274pt-\mathbb{E}_{\mathcal{D}|y=-1}\left[v^{(t)}(\theta^{(t)},x,y)\left(\cfrac{1}{n}\underset{j\in[n]}{\sum}x^{(j)}1_{\langle w_{r}^{(0)},x^{(j)}\rangle\geq 0}\right)\Big{|}y=-1\right]

Therefore,

h(wr(0)):=Pr,wr(0)\displaystyle h(w_{r}^{(0)}):=\langle P_{r},w_{r}^{(0)}\rangle
=𝔼𝒟|y=+1[v(t)(θ(t),x,y)(1nj[n]ReLU(wr(0),x(j)))|y=+1]\displaystyle=\mathbb{E}_{\mathcal{D}|y=+1}\left[v^{(t)}(\theta^{(t)},x,y)\left(\cfrac{1}{n}\underset{j\in[n]}{\sum}\textbf{ReLU}\left(\langle w_{r}^{(0)},x^{(j)}\rangle\right)\right)\Big{|}y=+1\right]
𝔼𝒟|y=1[v(t)(θ(t),x,y)(1nj[n]ReLU(wr(0),x(j)))|y=1]\displaystyle\hskip 28.45274pt-\mathbb{E}_{\mathcal{D}|y=-1}\left[v^{(t)}(\theta^{(t)},x,y)\left(\cfrac{1}{n}\underset{j\in[n]}{\sum}\textbf{ReLU}\left(\langle w_{r}^{(0)},x^{(j)}\rangle\right)\right)\Big{|}y=-1\right]

Now, decomposing wr(0)=αo1+βw_{r}^{(0)}=\alpha o_{1}+\beta with βo1\beta\perp o_{1} we get,

h(wr(0))=v(t)nReLU(α)\displaystyle h(w_{r}^{(0)})=\cfrac{v^{(t)}}{n}\textbf{ReLU}\left(\alpha\right)
+𝔼𝒟|y=+1[v(t)(θ(t),x,y)(1nj[n]/jo1ReLU(αo1,x(j)+β,x(j)))]\displaystyle+\mathbb{E}_{\mathcal{D}|y=+1}\left[v^{(t)}(\theta^{(t)},x,y)\left(\cfrac{1}{n}\underset{j\in[n]/j_{o_{1}}}{\sum}\textbf{ReLU}\left(\alpha\langle o_{1},x^{(j)}\rangle+\langle\beta,x^{(j)}\rangle\right)\right)\right]
𝔼𝒟|y=1[v(t)(θ(t),x,y)(1nj[n]ReLU(αo1,x(j)+β,x(j)))]\displaystyle-\mathbb{E}_{\mathcal{D}|y=-1}\left[v^{(t)}(\theta^{(t)},x,y)\left(\cfrac{1}{n}\underset{j\in[n]}{\sum}\textbf{ReLU}\left(\alpha\langle o_{1},x^{(j)}\rangle+\langle\beta,x^{(j)}\rangle\right)\right)\right]
=ϕ(α)l(α)\displaystyle=\phi(\alpha)-l(\alpha)

where,

ϕ(α):=v(t)nReLU(α)\displaystyle\phi(\alpha):=\cfrac{v^{(t)}}{n}\textbf{ReLU}\left(\alpha\right)
+𝔼𝒟|y=+1[v(t)(θ(t),x,y)(1nj[n]/jo1ReLU(αo1,x(j)+β,x(j)))]\displaystyle\hskip 34.14322pt+\mathbb{E}_{\mathcal{D}|y=+1}\left[v^{(t)}(\theta^{(t)},x,y)\left(\cfrac{1}{n}\underset{j\in[n]/j_{o_{1}}}{\sum}\textbf{ReLU}\left(\alpha\langle o_{1},x^{(j)}\rangle+\langle\beta,x^{(j)}\rangle\right)\right)\right]

and

l(α):=𝔼𝒟|y=1[v(t)(θ(t),x,y)(1nj[n]ReLU(αo1,x(j)+β,x(j)))]\displaystyle l(\alpha):=\mathbb{E}_{\mathcal{D}|y=-1}\left[v^{(t)}(\theta^{(t)},x,y)\left(\cfrac{1}{n}\underset{j\in[n]}{\sum}\textbf{ReLU}\left(\alpha\langle o_{1},x^{(j)}\rangle+\langle\beta,x^{(j)}\rangle\right)\right)\right]

Now as ϕ(α)\phi(\alpha) and l(α)l(\alpha) both are convex functions, using the same procedure as in Lemma G.1 we can complete the proof. ∎

Lemma I.3.

Let v^{(t)}=\max\{v_{1}^{(t)},v_{2}^{(t)}\}. Then, for every v^{(t)}>0, for m=\tilde{\Omega}\left(\cfrac{n^{2}p^{3}\delta^{3/2}}{(v^{(t)})^{2}}\right), for every possible set \{v^{(t)}(\theta^{(t)},x,y):(x,y)\sim\mathcal{D}\} (that depends on w_{r}^{(0)}), there exist at least \Omega\left(\frac{1}{p\sqrt{\delta}}\right) fraction of r\in[m] such that,

^(θ(t))wr(t)=Ω~(v(t)npδ)\left|\left|\cfrac{\overset{\sim}{\partial}\hat{\mathcal{L}}(\theta^{(t)})}{\partial w_{r}^{(t)}}\right|\right|=\tilde{\Omega}\left(\cfrac{v^{(t)}}{np\sqrt{\delta}}\right)



Proof.

Similarly to the proof of Lemma G.3, by picking S samples from the distribution \mathcal{D} to form the set \textbf{S}=\{(x_{i},y_{i})\}_{i=1}^{S} with S/2 samples from y=+1 (denoting the subset by \textbf{S}_{+1}) and S/2 samples from y=-1 (denoting the subset by \textbf{S}_{-1}), we can show that w.h.p.,

|v1(t)1S/2(xi,yi)S+1v(t)(θ(t),xi,yi)|=O~(1S) and\displaystyle\left|v_{1}^{(t)}-\cfrac{1}{S/2}\underset{(x_{i},y_{i})\in\textbf{S}_{+1}}{\sum}v^{(t)}(\theta^{(t)},x_{i},y_{i})\right|=\tilde{O}\left(\cfrac{1}{\sqrt{S}}\right)\text{ and }
|v2(t)1S/2(xi,yi)S1v(t)(θ(t),xi,yi)|=O~(1S)\displaystyle\left|v_{2}^{(t)}-\cfrac{1}{S/2}\underset{(x_{i},y_{i})\in\textbf{S}_{-1}}{\sum}v^{(t)}(\theta^{(t)},x_{i},y_{i})\right|=\tilde{O}\left(\cfrac{1}{\sqrt{S}}\right)

This implies that, as long as S=Ω~(1(v(t))2)S=\tilde{\Omega}\left(\cfrac{1}{(v^{(t)})^{2}}\right) we have,

max{1S/2(xi,yi)S+1v(t)(θ(t),xi,yi),1S/2(xi,yi)S1v(t)(θ(t),xi,yi)}[12v(t),32v(t)]\displaystyle\max\left\{\cfrac{1}{S/2}\underset{(x_{i},y_{i})\in\textbf{S}_{+1}}{\sum}v^{(t)}(\theta^{(t)},x_{i},y_{i}),\cfrac{1}{S/2}\underset{(x_{i},y_{i})\in\textbf{S}_{-1}}{\sum}v^{(t)}(\theta^{(t)},x_{i},y_{i})\right\}\in\left[\cfrac{1}{2}v^{(t)},\cfrac{3}{2}v^{(t)}\right]

Now using Lemma I.2 and following similar procedure as in Lemma G.3 we can complete the proof. ∎

Lemma I.4.

With v(t)=max{v1(t),v2(t)}v^{(t)}=\max\{v_{1}^{(t)},v_{2}^{(t)}\} and γ=Ω(1pδ)\gamma=\Omega\left(\frac{1}{p\sqrt{\delta}}\right), by selecting learning rate η=O~(γ3(v(t))2mn2)\eta=\tilde{O}\left(\cfrac{\gamma^{3}(v^{(t)})^{2}}{mn^{2}}\right) and batch-size B=Ω~(n4γ6(v(t))4)B=\tilde{\Omega}\left(\cfrac{n^{4}}{\gamma^{6}(v^{(t)})^{4}}\right), for t=O~(σγ3(v(t))2ηn3)t=\tilde{O}\left(\cfrac{\sigma\gamma^{3}(v^{(t)})^{2}}{\eta n^{3}}\right) iterations of SGD, w.h.p. we can ensure that,

ΔL(θ(t),θ(t+1))ηmγ3n2Ω~((v(t))2)\displaystyle\Delta L(\theta^{(t)},\theta^{(t+1)})\geq\cfrac{\eta m\gamma^{3}}{n^{2}}\tilde{\Omega}\left((v^{(t)})^{2}\right)
Proof.

Using Lemmas I.1 and I.3 and following a similar technique as in Lemma G.4, the proof can be completed. ∎

Appendix J Auxiliary Lemmas

Lemma J.1.

(Li & Liang, 2018) Let \psi:\mathbb{R}\rightarrow\mathbb{R} and \zeta:\mathbb{R}\rightarrow\mathbb{R} be convex functions. Let \{\partial\psi(x)\} and \{\partial\zeta(x)\} be the sets of sub-gradients of \psi and \zeta at x respectively, with \partial_{\text{max}}\psi(x)=\text{max}\{\partial\psi(x)\}, \partial_{\text{max}}\zeta(x)=\text{max}\{\partial\zeta(x)\}, \partial_{\text{min}}\psi(x)=\text{min}\{\partial\psi(x)\} and \partial_{\text{min}}\zeta(x)=\text{min}\{\partial\zeta(x)\}. Then, for any \tau\geq 0 such that \gamma=(\partial_{\text{max}}\psi(\tau/2)-\partial_{\text{min}}\psi(-\tau/2))-(\partial_{\text{max}}\zeta(\tau/2)-\partial_{\text{min}}\zeta(-\tau/2)),

αU(τ,τ)[|ψ(α)ζ(α)|τγ512]164\mathbb{P}_{\alpha\sim U(-\tau,\tau)}\left[|\psi(\alpha)-\zeta(\alpha)|\geq\frac{\tau\gamma}{512}\right]\geq\frac{1}{64}

Lemma J.2.

(Li & Liang, 2018) For any i\in[m], let the function h_{i}:\mathbb{R}^{d}\rightarrow\mathbb{R} be L-Lipschitz smooth, and suppose there exists r\in[m] such that for all i\in[m-r] the function h_{i} is also L-smooth. Furthermore, assume that the function g:\mathbb{R}\rightarrow\mathbb{R} is both L-Lipschitz smooth and L-smooth. Define f(w):=g\left(\sum_{i\in[m]}h_{i}(w_{i})\right), where w\in\mathbb{R}^{dm} with w_{i}\in\mathbb{R}^{d}. Then for every \xi\in\mathbb{R}^{dm} such that \xi_{i}\in\mathbb{R}^{d} with \left|\left|\xi_{i}\right|\right|\leq\rho, we have:

g(i[m]hi(wi+ξi))g(i[m]hi(wi))i[mr]f(w)wi,ξi+L3m2ρ2+L2rρ\displaystyle g\left(\sum_{i\in[m]}h_{i}(w_{i}+\xi_{i})\right)-g\left(\sum_{i\in[m]}h_{i}(w_{i})\right)\leq\sum_{i\in[m-r]}\left\langle\cfrac{\partial f(w)}{\partial w_{i}},\xi_{i}\right\rangle+L^{3}m^{2}\rho^{2}+L^{2}r\rho

Appendix K Proof of the Non-linear Separability of the Data-model

Lemma K.1.

As long as l=Ω(1)l^{*}=\Omega(1), the distribution 𝒟\mathcal{D} is NOT linearly separable.

Proof.

We will prove the Lemma by contradiction.

Now, if the distribution \mathcal{D} is linearly separable, then there exists a hyperplane h=\left[h^{(1)T},h^{(2)T},...,h^{(n)T}\right] with ||h||=1 (here, h^{(j)} represents the j-th patch of the hyperplane for j\in[n]) such that,

(x1,y=+1)𝒟 and (x2,y=1)𝒟,x1Thx2Th0\forall(x_{1},y=+1)\sim\mathcal{D}\text{ and }(x_{2},y=-1)\sim\mathcal{D},x_{1}^{T}h-x_{2}^{T}h\geq 0 (22)

Now, as the class-discriminative patterns o1o_{1} and o2o_{2} can occur at any position j[n]j\in[n], h(j)2=Θ(1n);j[n]||h^{(j)}||^{2}=\Theta(\frac{1}{n});\hskip 2.84544pt\forall j\in[n].

Now, j[n]\forall j\in[n], we can decompose h(j)h^{(j)} as h(j)=ajo1+bjo2h^{(j)}=a_{j}o_{1}+b_{j}o_{2}.

Then, |aj|=|bj|=Θ(1n(1δd))|a_{j}|=|b_{j}|=\Theta\left(\cfrac{1}{\sqrt{n(1-\delta_{d})}}\right), j[n]\forall j\in[n] as o1=o2=1||o_{1}||=||o_{2}||=1.

Now,

x1Thx2Th=o1,h(jo1)o2,h(jo2)+j[n]/jo1x1(j),h(j)j[n]/jo2x2(j),h(j)\displaystyle x_{1}^{T}h-x_{2}^{T}h=\langle o_{1},h^{(j_{o_{1}})}\rangle-\langle o_{2},h^{(j_{o_{2}})}\rangle+\sum_{j\in[n]/j_{o_{1}}}\langle x_{1}^{(j)},h^{(j)}\rangle-\sum_{j\in[n]/j_{o_{2}}}\langle x_{2}^{(j)},h^{(j)}\rangle

Now,

o1,h(jo1)o2,h(jo2)=(ajo1bjo2)(ajo2bjo1)δd\displaystyle\langle o_{1},h^{(j_{o_{1}})}\rangle-\langle o_{2},h^{(j_{o_{2}})}\rangle=(a_{j_{o_{1}}}-b_{j_{o_{2}}})-(a_{j_{o_{2}}}-b_{j_{o_{1}}})\delta_{d}
|ajo1bjo2||ajo1bjo2|δd[WLOG, let assume δd<0]\displaystyle\leq|a_{j_{o_{1}}}-b_{j_{o_{2}}}|-|a_{j_{o_{1}}}-b_{j_{o_{2}}}|\delta_{d}\hskip 28.45274pt\text{[WLOG, let assume $\delta_{d}<0$]}
=O(1δdn)\displaystyle=O\left(\sqrt{\frac{1-\delta_{d}}{n}}\right)

Now,

j[n]/jo2x2(j),h(j)j[n]/jo1x1(j),h(j)=j[n]/jo2x2(j),ajo1+bjo2j[n]/jo1x1(j),ajo1+bjo2\displaystyle\sum_{j\in[n]/j_{o_{2}}}\langle x_{2}^{(j)},h^{(j)}\rangle-\sum_{j\in[n]/j_{o_{1}}}\langle x_{1}^{(j)},h^{(j)}\rangle=\sum_{j\in[n]/j_{o_{2}}}\langle x_{2}^{(j)},a_{j}o_{1}+b_{j}o_{2}\rangle-\sum_{j\in[n]/j_{o_{1}}}\langle x_{1}^{(j)},a_{j}o_{1}+b_{j}o_{2}\rangle
=O(l(1δd)n)\displaystyle=O\left(l^{*}\sqrt{\frac{(1-\delta_{d})}{n}}\right)

Therefore, for l^{*}=\Omega(1) there is a contradiction with (22). ∎

Appendix L WRN and WRN-pMoE Architectures Implemented in the Experiments

Figure 17: The WRN architecture implemented to learn the CelebA dataset.
Figure 18: The WRN-pMoE architecture implemented to learn the CelebA dataset.

Appendix M Extension to Multi-class Classification

Let us consider the c-class classification problem where c>2. Then, we have (x,y)\sim\mathcal{D}_{c}, where y\in\{1,2,...,c\}, for the multi-class distribution \mathcal{D}_{c}.

The multi-class data model:
Now, according to the data model presented in Section 4.2, we take \{o_{1},o_{2},...,o_{c}\} as the set of class-discriminative patterns. \forall j,j^{\prime}\in[c] such that j\neq j^{\prime}, we define \delta_{d_{j,j^{\prime}}}:=\langle o_{j},o_{j^{\prime}}\rangle. We further define \delta_{d}:=\max\{\delta_{d_{j,j^{\prime}}}\}. Then,

δ=1(1max{δdj,j2,δr2}j,j[c],jj)\displaystyle\delta=\cfrac{1}{(1-\max\{\delta^{2}_{d_{j,j^{\prime}}},\delta^{2}_{r}\}_{j,j^{\prime}\in[c],j\neq j^{\prime}})}

The multi-class pMoE model:
The pMoE model for multi-class case is given by,

i[c],fMi(θ,x)=s=1𝑘r=1mkar,s,iljJs(ws,x)ReLU(wr,s,x(j))Gj,s(ws,x)\forall i\in[c],f_{M_{i}}(\theta,x)=\overset{k}{\underset{s=1}{\sum}}\overset{\frac{m}{k}}{\underset{r=1}{\sum}}\cfrac{a_{r,s,i}}{l}\underset{j\in J_{s}(w_{s},x)}{\sum}\textbf{ReLU}(\langle w_{r,s},x^{(j)}\rangle)G_{j,s}(w_{s},x) (23)

An illustration of (23) is given in Figure 19.

Figure 19: An illustration of the pMoE model in (23) with c=4,k=4,m=8,n=6 and l=2.
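
To make (23) concrete, the following is a minimal NumPy sketch of the forward pass of the multi-class pMoE layer for a single input. The function and variable names (pmoe_forward, x_patches, W_gate, W_expert, A) are illustrative assumptions and not part of the implementation used in the experiments; the routine simply evaluates (23).

import numpy as np

def softmax(z):
    # numerically stable softmax over a 1-D array
    z = z - np.max(z)
    e = np.exp(z)
    return e / e.sum()

def pmoe_forward(x_patches, W_gate, W_expert, A, l):
    # x_patches: (n, d) patches x^{(j)};  W_gate: (k, d) gating kernels w_s
    # W_expert:  (k, m/k, d) hidden nodes w_{r,s};  A: (k, m/k, c) output weights a_{r,s,i}
    k = W_gate.shape[0]
    c = A.shape[-1]
    logits = np.zeros(c)
    for s in range(k):
        scores = x_patches @ W_gate[s]          # <w_s, x^{(j)}> for every patch j
        J_s = np.argsort(scores)[-l:]           # TOP-l patch indices J_s(w_s, x)
        G = softmax(scores[J_s])                # gating values G_{j,s} over the selected patches
        acts = np.maximum(W_expert[s] @ x_patches[J_s].T, 0.0)  # ReLU(<w_{r,s}, x^{(j)}>), shape (m/k, l)
        expert_out = acts @ G / l               # (1/l) * sum_j G_{j,s} ReLU(...), one value per hidden node r
        logits += A[s].T @ expert_out           # add sum_r a_{r,s,i} * (...) to every class logit i
    return logits                               # logits[i] = f_{M_i}(theta, x)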

For the multi-class case, we replace the logistic loss function with the softmax loss function (also known as the cross-entropy loss). For the training dataset \{x_{j},y_{j}\}_{j=1}^{N}, we solve the following empirical risk minimization problem:

min𝜃:L(θ)=1Nj=1𝑁logi=1cexp(fMi(θ,xj))exp(fMyj(θ,xj))\displaystyle\underset{\theta}{\text{min}}:\hskip 11.38092ptL(\theta)=\cfrac{1}{N}\overset{N}{\underset{j=1}{\sum}}\log{\cfrac{\sum_{i=1}^{c}\exp{(f_{M_{i}}(\theta,x_{j}))}}{\exp{(f_{M_{y_{j}}}(\theta,x_{j}))}}} (24)

M.1 The Multi-class Separate-training pMoE

Number of experts: For the multi-class separate-training pMoE, we take k=c, i.e., the number of experts equals the number of classes.

Training algorithm:
Input: Training data \{(x_{i},y_{i})\}_{i=1}^{N}, learning rates \eta_{r} and \eta, numbers of iterations T_{r} and T, batch-sizes B_{r} and B
Step-1: Initialize w_{s}^{(0)},w_{r,s}^{(0)},a_{r,s},\forall s\in[k],r\in[m/k] according to (7) and (8)
Step-2: (Pair-wise router training) We train the router, i.e., the gating kernels w_{1},w_{2},...,w_{c}, using the pair-wise training described below:

1. At first, we separate the training set of N_{r} samples into c disjoint subsets \{N_{r,1},N_{r,2},...,N_{r,c}\} according to the class labels.

2. Now, we prepare c/2 pairs of training sets \{(N_{r,1},N_{r,2}),(N_{r,3},N_{r,4}),...,(N_{r,c-1},N_{r,c})\} (here, WLOG, we assume that c is even).

3. Under each pair (N_{r,i},N_{r,i+1}), we re-define the labels as y=+1 and y=-1 for the classes i and i+1 respectively, and train the gating kernels w_{i} and w_{i+1} by minimizing (6) for T_{r} iterations.

4. After the pair-wise training of all the pairs \{(N_{r,1},N_{r,2}),(N_{r,3},N_{r,4}),...,(N_{r,c-1},N_{r,c})\} ends, we obtain w_{1}^{(T_{r})},w_{2}^{(T_{r})},...,w_{c}^{(T_{r})} as the learned gating kernels.

Step-3: (Expert training)
Using the gating kernels w_{1}^{(T_{r})},w_{2}^{(T_{r})},...,w_{c}^{(T_{r})} learned in Step-2, and following the same procedure as in Step-3 of Algorithm 1, we train the experts.
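
For concreteness, the pair-wise router training of Step-2 can be organized as in the following schematic Python sketch. It is only an outline under stated assumptions: train_pair_router is a hypothetical helper standing in for T_r iterations of the binary router training of Algorithm 1 on one pair of classes, and data_by_class is an assumed mapping from each label i to its subset N_{r,i}.

def pairwise_router_training(data_by_class, W_gate, train_pair_router, T_r):
    # data_by_class: dict mapping class label i in {1, ..., c} to its samples N_{r,i}
    # W_gate: list of c gating kernels [w_1, ..., w_c] (index 0 holds w_1)
    c = len(W_gate)
    for i in range(1, c + 1, 2):              # pairs (1,2), (3,4), ..., (c-1,c); c assumed even
        pos = data_by_class[i]                # samples of class i, re-labeled as y = +1
        neg = data_by_class[i + 1]            # samples of class i+1, re-labeled as y = -1
        W_gate[i - 1], W_gate[i] = train_pair_router(
            pos, neg, W_gate[i - 1], W_gate[i], num_iters=T_r)
    return W_gate                             # learned kernels w_1^{(T_r)}, ..., w_c^{(T_r)}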

The multi-class counterpart of Lemma 4.1:
Now, using the same proof techniques as for Lemma 4.1 (i.e., following the same procedures as in Sections D and E), we can show that we need N_{r}=\Omega(\cfrac{c^{2}n^{2}}{(1-\delta_{d})^{2}}) training samples to ensure,

argj[n](x(j)=oi)Ji(wi(Tr),x)(x,y=i)𝒟cand i[c]\displaystyle\text{arg}_{j\in[n]}(x^{(j)}=o_{i})\in J_{i}(w_{i}^{(T_{r})},x)\hskip 28.45274pt\forall(x,y=i)\sim\mathcal{D}_{c}\hskip 2.84544pt\text{and }\forall i\in[c]

The multi-class counterpart of Theorem 4.2:
We redefine the value-function for each class i[c]i\in[c] as,

vi,a(t)(θ(t),x,y=a):={jaefMj(θ(t),x)j=1𝑐efMj(θ(t),x);if, i=aefMi(θ(t),x)j=1𝑐efMj(θ(t),x);otherwise\displaystyle v_{i,a}^{(t)}(\theta^{(t)},x,y=a):=\begin{cases}\cfrac{\underset{j\neq a}{\sum}e^{f_{M_{j}}(\theta^{(t)},x)}}{\overset{c}{\underset{j=1}{\sum}}e^{f_{M_{j}}(\theta^{(t)},x)}};\hskip 5.69046pt\text{if, }i=a\\ \\ -\cfrac{e^{f_{M_{i}}(\theta^{(t)},x)}}{\overset{c}{\underset{j=1}{\sum}}e^{f_{M_{j}}(\theta^{(t)},x)}};\hskip 5.69046pt\text{otherwise}\end{cases} (25)
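
As a brief verification added here for readability: for a sample with label y=a, the value function in (25) is the negative derivative of the per-sample softmax loss \ell(\theta,x,a)=\log{\overset{c}{\underset{j=1}{\sum}}e^{f_{M_{j}}(\theta,x)}}-f_{M_{a}}(\theta,x) from (24) with respect to the i-th logit, since

\displaystyle\cfrac{\partial\ell(\theta,x,a)}{\partial f_{M_{i}}(\theta,x)}=\cfrac{e^{f_{M_{i}}(\theta,x)}}{\overset{c}{\underset{j=1}{\sum}}e^{f_{M_{j}}(\theta,x)}}-1_{i=a},\hskip 14.22636pt\text{so that}\hskip 14.22636ptv_{i,a}^{(t)}(\theta^{(t)},x,y=a)=-\cfrac{\partial\ell(\theta^{(t)},x,a)}{\partial f_{M_{i}}(\theta^{(t)},x)}.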

Now, using similar techniques as in the proof of Theorem 4.2 (i.e., following the same procedure as in the proof of Theorem F.3 and Section G), we can show that for every \epsilon>0, we need the number of hidden nodes m\geq M_{S}=\Omega\left(l^{10}p^{12}\delta^{6}c^{11}\big{/}\epsilon^{16}\right) and batch size B=\Omega\left(l^{4}p^{6}\delta^{3}c^{6}\big{/}\epsilon^{8}\right) for T=O\left(l^{4}p^{6}\delta^{3}c^{6}\big{/}\epsilon^{8}\right) iterations (i.e., N_{S}=\Omega(l^{8}p^{12}\delta^{6}c^{12}/\epsilon^{16})) to ensure,

(x,y)𝒟c[j[c],jy,fMy(θ(T),x)>fMj(θ(T),x)]1ϵ\displaystyle\underset{(x,y)\sim\mathcal{D}_{c}}{\mathbb{P}}\left[\forall j\in[c],j\neq y,f_{M_{y}}(\theta^{(T)},x)>f_{M_{j}}(\theta^{(T)},x)\right]\geq 1-\epsilon

M.2 The Multi-class Joint-training pMoE

Training algorithm: Same as Algorithm 2, except that for the multi-class case the loss function is the softmax loss instead of the logistic loss.
The multi-class counterpart of Theorem 4.5:
Using the value function defined in (25), and as long as Assumption 4.4 is satisfied for all the classes i\in[c], following similar techniques as in the proof of Theorem 4.5 (i.e., following the same procedure as in the proof of Theorem F.5 and Section H), we can show that for every \epsilon>0, we need the number of hidden nodes m\geq M_{J}=\Omega\left(k^{3}n^{2}l^{6}p^{12}\delta^{6}c^{8}\big{/}\epsilon^{16}\right) and batch size B=\Omega\left(k^{2}l^{4}p^{6}\delta^{3}c^{4}\big{/}\epsilon^{8}\right) for T=O\left(k^{2}l^{2}p^{6}\delta^{3}c^{4}\big{/}\epsilon^{8}\right) iterations (i.e., N_{J}=\Omega(k^{4}l^{6}p^{12}\delta^{6}c^{8}/\epsilon^{16})) to ensure,

(x,y)𝒟c[j[c],jy,fMy(θ(T),x)>fMj(θ(T),x)]1ϵ\displaystyle\underset{(x,y)\sim\mathcal{D}_{c}}{\mathbb{P}}\left[\forall j\in[c],j\neq y,f_{M_{y}}(\theta^{(T)},x)>f_{M_{j}}(\theta^{(T)},x)\right]\geq 1-\epsilon

Appendix N Details of the Results in Table 1

Complexity in forward pass. The computational complexity of a non-overlapping convolution operation by a filter of dimension d on an input sample of n patches (of the same dimension as the filter) is O(nd) (Vaswani et al., 2017). Therefore, the forward-pass complexity of a batch of size B through a convolution layer of m neurons is O(Bmnd). Similarly, the forward-pass complexity of the batch through the experts (with the same total number of neurons as in the convolution layer) of a pMoE layer is O(Bmld). The operations in a pMoE router include convolution (with complexity O(nd)), the softmax operation (with complexity O(1)) and the TOP-l operation (with complexity O(nl) when l\ll n). Therefore, the overall forward-pass complexity of a pMoE router with k experts is O(Bknd).
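
The accounting above can be summarized by the following back-of-the-envelope Python sketch (constant factors dropped); the function names and the example sizes are illustrative assumptions, not measurements from the experiments.

def cnn_forward_cost(B, m, n, d):
    return B * m * n * d                      # convolution layer with m neurons: O(Bmnd)

def pmoe_forward_cost(B, m, k, l, n, d):
    experts = B * m * l * d                   # each expert sees only l patches: O(Bmld)
    router = B * k * n * d                    # k gating convolutions over all n patches: O(Bknd)
    return experts + router

# Example with hypothetical sizes: B=128, m=512, k=8, l=4, n=64, d=192.
ratio = cnn_forward_cost(128, 512, 64, 192) / pmoe_forward_cost(128, 512, 8, 4, 64, 192)
print(ratio)   # equals m*n / (m*l + k*n), which exceeds 1 whenever l is much smaller than n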

Complexity in backward pass. The gradient of the neurons in the convolution layer for an input sample is given in (21), which implies that the complexity of the gradient calculation is O(nd) (addition of n vectors of dimension d), and hence the backward-pass complexity of the CNN is O(Bmnd). Similarly, the backward-pass complexity of the pMoE experts is O(Bmld). Now, the gradient of the gating kernels in the pMoE router is given in (26), which implies that the complexity of the gradient calculation is O(l^{2}d) (addition of l^{2} vectors of dimension d), and hence the backward-pass complexity of the pMoE router is O(Bkl^{2}d).

(θ(t),x,y)ws(t)=yv(t)(θ(t),x,y)(r[m]ar,s(1ljJs(x)ReLU(wr,s(t),x(j))Gj,s(iJs(x)/j(x(j)x(i))Gi,s)))\displaystyle\frac{\partial\mathcal{L}(\theta^{(t)},x,y)}{\partial w_{s}^{(t)}}=-yv^{(t)}(\theta^{(t)},x,y)\left(\sum_{r\in[m]}a_{r,s}\left(\cfrac{1}{l}\sum_{j\in J_{s}(x)}\textbf{ReLU}(\langle w_{r,s}^{(t)},x^{(j)}\rangle)G_{j,s}(\sum_{i\in J_{s}(x)/j}(x^{(j)}-x^{(i)})G_{i,s})\right)\right) (26)

Complexity to achieve ϵ\epsilon generalization error. From Theorem 4.5, to achieve ϵ\epsilon generalization error we need O(k2l2/ϵ8)O(k^{2}l^{2}/\epsilon^{8}) iterations of training in pMoE, which implies that the computational complexity to achieve ϵ\epsilon error in pMoE is O(Bmk2l3d/ϵ8)O(Bmk^{2}l^{3}d/\epsilon^{8}). Similarly, using the results from Theorem 4.3, the corresponding complexity in CNN is O(Bmn5d/ϵ8)O(Bmn^{5}d/\epsilon^{8}).
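
Similarly, the end-to-end comparison can be sketched as follows, using only the orders stated above (again a hypothetical illustration; all names and sample values are assumptions).

def pmoe_training_cost(B, m, k, l, d, eps):
    return B * m * k**2 * l**3 * d / eps**8   # O(B m k^2 l^3 d / eps^8), from Theorem 4.5

def cnn_training_cost(B, m, n, d, eps):
    return B * m * n**5 * d / eps**8          # O(B m n^5 d / eps^8), from Theorem 4.3

def speedup(n, k, l):
    # B, m, d and eps cancel in the ratio, leaving n^5 / (k^2 l^3)
    return n**5 / (k**2 * l**3)

print(speedup(n=64, k=8, l=4))                # large whenever k and l are much smaller than n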