Pranjal Awasthi, Google Research
Nishanth Dikkala, Google Research
Pritish Kamath, Google Research
Raghu Meka, University of California, Los Angeles

Learning Neural Networks with Sparse Activations

Abstract

A core component present in many successful neural network architectures is an MLP block of two fully connected layers with a non-linear activation in between. An intriguing phenomenon observed empirically, including in transformer architectures, is that after training, the activations in the hidden layer of this MLP block tend to be extremely sparse on any given input. Unlike traditional forms of sparsity, where there are neurons/weights that can be deleted from the network, this form of dynamic activation sparsity appears to be harder to exploit to obtain more efficient networks.

Motivated by this we initiate a formal study of PAC learnability of MLP layers that exhibit activation sparsity. We present a variety of results showing that such classes of functions do lead to provable computational and statistical advantages over their non-sparse counterparts. Our hope is that a better theoretical understanding of sparsely activated networks would lead to methods that can exploit activation sparsity in practice.

keywords:
Multilayer Perceptrons, PAC Learning, Activation Sparsity, Rademacher Complexity

1 Introduction

In recent years, transformer-based deep neural networks (Vaswani et al., 2017) and the subsequent development of large language models have marked a paradigm shift in the fields of natural language processing and computer vision (Brown et al., 2020; Chowdhery et al., 2022; Chen et al., 2022b; Dosovitskiy et al., 2020). These models have significantly improved performance across various tasks, setting new benchmarks and enabling previously unattainable breakthroughs. However, the computational cost of training and deploying these models, especially the largest variants, presents a significant challenge. A notable portion of these models' computational and parameter overhead is attributed to the Multi-Layer Perceptron (MLP) layers. These layers are integral to the transformer architecture, playing a crucial role in its ability to solve many different tasks.

Despite their efficacy, the resource-intensive nature of these models has spurred a wave of research focused on enhancing their efficiency (Banner et al., 2019; Frankle and Carbin, 2018; Gholami et al., 2022; Hinton et al., 2015; Anil et al., 2018; Harutyunyan et al., 2023). Among the various strategies explored for improving the inference efficiency of large transformers, attempting to sparsify the transformer is a promising approach.

A motivation for exploiting sparsity is rooted in an intriguing empirical observation made in recent works (Li et al., 2023) regarding the behavior of MLP layers within large transformer models. Post-training, these layers tend to exhibit a high degree of sparsity in their activations; often each input activates as few as 3% of the neurons in the MLP layers, suggesting a natural emergence of sparsity in activations. This leads to these MLP layers behaving like key-value lookups (Geva et al., 2020). The extremely small fraction of active neurons (3%) suggests that there might be significant room to sparsify the MLP layers, leading to both training and inference efficiency. In addition, such sparsity also helps with the interpretability of transformers by disentangling neurons corresponding to distinct concepts (Elhage et al., 2022). Moreover, through extensive ablation studies, Li et al. (2023) observe that this phenomenon is highly prevalent: it occurs in convolutional networks (CNNs) as well as in vanilla fully connected feedforward networks.

Despite the potential benefits, effectively harnessing dynamic sparsity has proven challenging. Although there have been many recent efforts (Li et al., 2023; Grimaldi et al., 2023; Liu et al., 2023; Dong et al., 2023; Csordás et al., 2023; Mirzadeh et al., 2023), they have led to limited success. None of the approaches achieve speedups (either in training or in inference) anywhere close to the potential factor of 33x suggested by 3% sparsity. Moreover, explicitly enforcing sparsity via methods such as keeping only the top-$k$ activations degrades the quality of the model in some cases.

A key reason for the hardness in exploiting activation sparsity is that this form of sparsity is dynamic in nature and is input-dependent (i.e., not a fixed pattern). While each input example activates a small number of neurons, the overall sparsity pattern cannot be localized to a small subset of the model weights. For instance, the dynamic nature precludes the use of typical weight quantization or pruning based methods to exploit sparsity empirically. On the other hand, having a non-localized sparsity pattern is crucial in ensuring the model has rich expressiveness.

The above observations suggest that post-training, large transformer networks belong to an intriguing function class that is highly expressive yet exhibits high sparsity. Given the challenges in exploiting this behavior in practical settings, in this work, we initiate a theoretical study of the statistical and computational properties of such functions in the probably approximately correct (PAC) learning framework (Valiant, 1984).

We introduce the class of sparsely activated MLPs. We focus on the case of depth-$1$ MLPs with $n$ input units and $s$ hidden units with the standard ReLU activations. We define the class $\mathcal{H}_{n,s,k}$ as the class of depth-$1$ ReLU networks in $n$ dimensions with the promise that on each input in the support of the data distribution, at most $k$ of the $s$ hidden units are active:

Definition 1.1 (Sparsely Activated Networks).

Let $\sigma(\cdot)$ denote the $\mathsf{ReLU}$ activation, namely $\sigma(z):=\max\{z,0\}$. The class $\mathcal{H}_{n,s,k}$ consists of hypotheses of the form $h(x)=\sum_{j=1}^{s}u_{j}\sigma(\langle w_{j},x\rangle-b_{j})$ with the property that for all $x$ in the support of the distribution, it holds that $|\{j:\langle w_{j},x\rangle-b_{j}>0\}|\leq k$.

Note that this sparsity differs from dead sparsity, where some neurons are never active on any input and can consequently be deleted from the network without impacting its functionality. The form of dynamic sparsity we study can be crucial for making the networks more expressive. We provide a couple of examples of useful functions represented by sparsely activated networks here:

  • Junta functions: The class of functions on $n$ variables that depend on only a $p$-sized subset ($p<n$) of the variables is known as $p$-junta functions. Sparse parities are a canonical example of junta functions. We show in Theorem 4.2 that we can represent $\log(s)$-juntas using $\mathcal{H}_{n,s,1}$.

  • Indexing function: Consider the function $\mathsf{Index}_{b}:\{-1,1\}^{b+2^{b}}\to\{0,1\}$, where $\mathsf{Index}_{b}(z)$ is the $x$-th bit of $y$ ($-1$ mapped to $0$), where $x$ is the integer represented by the first $b$ bits of $z$ in binary representation, and $y$ is the vector of remaining $2^{b}$ bits. This can be represented as a $1$-sparsely activated network of size $2^{b}$ (i.e., in $\mathcal{H}_{b+2^{b},2^{b},1}$): $\mathsf{Index}_{b}((x,y))=\sum_{\alpha\in\{-1,1\}^{b}}\sigma(\langle w_{\alpha},z\rangle-b+\frac{1}{2})$, where the first $b$ coordinates of $w_{\alpha}$ are $\alpha$ and the $\alpha$-th coordinate among the last $2^{b}$ coordinates is $\frac{1}{2}$. On input $z=(x,y)$, only the neuron corresponding to $\alpha=x$ is activated, and the output is precisely $\frac{1}{2}y_{x}+\frac{1}{2}$.

In both the examples presented above, removing any of the $s$ neurons will change the functionality of the network. However, each weight vector $w_{i}$ is quite sparse. In Appendix A, we present an example of a sparsely activated network where even the weight vectors $w_{i}$ are not sparse. Hence, in general, it is not clear if sparsely activated networks can be represented with fewer neurons or sparse weight vectors.
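
To make the $1$-sparse activation concrete, here is a small Python sketch (our own illustration, not code from the paper; the bit-indexing convention used is one arbitrary choice) that builds the $\mathsf{Index}_{b}$ network described above and checks that at most one hidden unit fires on every input.

```python
import itertools
import numpy as np

def build_index_network(b):
    """Weights for Index_b as a 1-sparsely activated ReLU network with 2^b hidden units.
    For each alpha in {-1,1}^b, w_alpha has its first b coordinates equal to alpha and
    a 1/2 in the alpha-th coordinate of the last 2^b coordinates; the bias is b - 1/2."""
    alphas = list(itertools.product([-1, 1], repeat=b))
    W = np.zeros((2 ** b, b + 2 ** b))
    for idx, alpha in enumerate(alphas):
        W[idx, :b] = alpha
        W[idx, b + idx] = 0.5
    return W, b - 0.5, alphas

def index_net(z, W, bias):
    pre = W @ z - bias                     # pre-activations of the 2^b hidden units
    return np.sum(np.maximum(pre, 0.0)), int(np.sum(pre > 0))

b = 3
W, bias, alphas = build_index_network(b)
rng = np.random.default_rng(0)
for _ in range(200):
    z = rng.choice([-1, 1], size=b + 2 ** b)
    out, num_active = index_net(z, W, bias)
    assert num_active <= 1                 # 1-sparse activation on every input
    # the output is (y_x + 1)/2, i.e. the looked-up bit mapped from {-1,1} to {0,1}
    x_idx = alphas.index(tuple(z[:b]))
    assert out == (z[b + x_idx] + 1) / 2
```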

In order to provide learning guarantees, we have to assume an upper bound on the scale of $u$, the $w_{j}$'s and the $b_{j}$'s. We will use the following natural scaling for the paper:

Definition 1.2.

Let $\mathcal{H}_{n,s,k}^{W,B}\subseteq\mathcal{H}_{n,s,k}$ consist of hypotheses $h$ given as $h(x)=\sum_{j=1}^{s}u_{j}\sigma(\langle w_{j},x\rangle-b_{j})$, satisfying $\|u\|_{\infty}\cdot\max_{j\in[s]}\|w_{j}\|_{2}\leq W$ and $\|u\|_{\infty}\cdot\max_{j\in[s]}|b_{j}|\leq B$.

We then consider the problem of learning sparsely activated networks efficiently. We consider the domain to be the Boolean hypercube $\mathcal{X}=\{1,-1\}^{n}$, both as a natural first step and as a domain where sparsely activated networks can compute non-trivial functions. The Boolean hypercube provides a setting where the function can be sparse everywhere in the domain while maintaining expressiveness; this appears harder in the continuous setting. For instance, if the inputs are Gaussian over $\mathbb{R}^{n}$, one likely needs the biases in the ReLU units to be very large to enforce $1$-sparsity. This suggests that, in the continuous domain, more non-standard distributions are likely necessary to obtain a rich class of functions that are sparse everywhere in the domain. Hence, for theoretical simplicity, we focus on functions on the Boolean hypercube.

Even with the sparsity assumption, the class $\mathcal{H}_{n,s,1}$ is likely hard to learn in polynomial time (or even quasi-polynomial time) under an arbitrary distribution on the hypercube. In particular, we show that parities on the hypercube on $k$ variables can be computed by $\mathcal{H}_{k^{2},2k,1}$ with coefficient vectors of norm at most $O(k)$. Thus, learning $\mathcal{H}_{n,O(\sqrt{n}),1}$ requires $2^{\Omega(\sqrt{n})}$ queries in the powerful Statistical Queries (SQ) model (see Section 4 for details). We also show cryptographic hardness results for learning $\mathcal{H}_{n,s,1}$ under generic distributions on the hypercube.

Theorem 1.3 (Informal; see Section 4).

Any SQ algorithm for learning $\mathcal{H}_{n,O(\sqrt{n}),1}^{O(n^{0.75}),O(n)}$ under arbitrary distributions over the hypercube either requires tolerance $2^{-\Omega(\sqrt{n})}$ or $2^{\Omega(\sqrt{n})}$ queries.

Assuming the hardness of the Learning with Rounding problem with polynomial modulus, there is no $\mathsf{poly}(n,s,W,B,1/\varepsilon)$ run-time algorithm to $(\varepsilon,\delta)$-PAC learn $\mathcal{H}_{n,s,1}^{W,B}$.

Learning under uniform distribution.

Given the above hardness results, it is natural to consider distributional assumptions, as is often done for related classes in learning theory (e.g., Klivans et al. (2004); Kane (2014)). Our main result is that when the input distribution is uniform over the $n$-dimensional hypercube $\{1,-1\}^{n}$, the class $\mathcal{H}^{W,B}_{n,s,k}$ can be learned in time $n^{\mathsf{poly}(k\log(ns))}$:

Theorem 1.4 (Informal; see Theorem 3.1).

There exists an $(\varepsilon,\delta)$-PAC learning algorithm for $\mathcal{H}_{n,s,k}^{W,B}$ with respect to the uniform distribution over $\{1,-1\}^{n}$ that has sample complexity and run-time $n^{\mathsf{poly}(k\log(ns))/\varepsilon^{2}}\log(1/\delta)/\varepsilon$ (suppressing dependence on $W,B$).

As our learning algorithm works by performing linear regression over the low-degree monomial basis (a.k.a. the low-degree algorithm), the guarantees hold even in the agnostic or non-realizable setting by standard arguments (e.g., Klivans et al. (2004)). For simplicity, we focus on the realizable setting, as the algorithm and analysis do not change for the agnostic case.

For sparsity $k=1$, the above run-time is $n^{O(\mathsf{poly}(\log(ns))/\varepsilon^{2})}$. As we showed above, $\mathcal{H}_{n,s,1}$ can simulate juntas of size $\log_{2}s$ over $n$ variables. Thus, a quasi-polynomial run-time is the best we can do under a widely believed conjecture on the hardness of learning juntas.

The guarantee above is in stark contrast to what is achievable for general one-layer size-$s$ ReLU networks under the uniform distribution over the hypercube. One-layer size-$s$ networks can simulate parities on $\min(n,s)$ variables. They thus cannot be learned, even under the uniform distribution on the hypercube, by SQ algorithms with fewer than $2^{\Omega(\min(n,s))}$ queries. Further, even for non-SQ algorithms, as shown in (Chen et al., 2022a), quasi-polynomial run-time with respect to the uniform distribution on the hypercube is impossible under widely studied cryptographic assumptions.

The proof of Theorem 1.4 is via Fourier analysis and the low-degree algorithm. The main ingredient is to show that the average sensitivity of functions in $\mathcal{H}_{n,s,k}$ is at most $O(k^{4}(\sqrt{n}\log(ns)))$. We then use this to bound the noise sensitivity of functions in $\mathcal{H}_{n,s,k}$. The latter implies the existence of a low-degree approximation by exploiting Klivans et al. (2004), which is enough to obtain the theorem. See Section 3 for details.

Learning under general distributions.

We also show that $\mathcal{H}_{n,s,k}^{W,B}$ can be learned under general distributions with smaller sample complexity than would be required without the sparsity condition, in the case when $s\gg kn$. In particular, we show the following.

Theorem 1.5 (Informal; see Theorem 5.1).

There exists an $(\varepsilon,\delta)$-PAC learning algorithm for $\mathcal{H}_{n,s,k}^{W,B}$ over $\{1,-1\}^{n}$ that has sample complexity $\widetilde{O}\left(ksn/\varepsilon^{2}\right)$ (suppressing dependence on $W,B,\delta$).

By contrast, the class $\mathcal{H}_{n,s,s}^{W,B}$ (that is, size-$s$ networks without activation sparsity) requires a sample complexity of $\Omega(s^{2}/\varepsilon^{2})$. To prove the above, we provide a bound on the Rademacher complexity of the class $\mathcal{H}_{n,s,k}^{W,B}$ that has an improved dependence on $s$.

Taken together, our results demonstrate that dynamic activation sparsity can, in theory, be leveraged for both computational and statistical benefits. We hope that further theoretical study of the class of sparsely activated networks could pave the way for more efficient training and inference methods for deep architectures, including the transformer-based models where these sparsely activated networks have been observed to arise in practice.

1.1 Related Work

Our work is motivated by recent empirical observations on the extreme sparsity observed in the MLP layers of trained transformer models (Li et al., 2023; Shen et al., 2023). The works of Li et al. (2023); Peng et al. (2023) propose theoretical explanations of why this phenomenon occurs. However, ours is the first work to formally study sparsely activated networks in the PAC learning setup and quantify their computational and statistical advantages. Motivated by the observation on sparsity, recent work has also studied the connections between the MLP layers and key-value memory lookups (Sukhbaatar et al., 2019; Lample et al., 2019; Geva et al., 2020).

There have also been recent works on designing networks with explicitly enforced sparsity structure. One such line of work concerns mixture-of-experts models (Shazeer et al., 2017; Fedus et al., 2022), where each input is independently routed to one or two MLP blocks among a set of experts. An alternate way to enforce sparsity is to introduce a top-$k$ operation after each MLP layer that zeros out most of the activations (Csordás et al., 2023; Li et al., 2023). In particular, Li et al. (2023) propose a top-$k$ transformer along these lines. However, due to the top-$k$ operation being relatively slow on accelerator hardware, this technique does not yield wall-clock speedups for either training or inference.

In another recent work, Liu et al. (2023) propose to train a small predictor network to predict the activated indices at each MLP layer. There has also been work exploring block-sparsity constraints and weight tying in the model weights themselves (Dong et al., 2023), as well as efforts to enforce static sparsity that is not input-dependent (Frantar and Alistarh, 2023). However, such methods have not been effective for language modeling via transformer models and have been much more successful in classification domains with a small number of output labels.

Significantly more attention has been given to sparsifying the attention layer computation (Zaheer et al., 2020; Choromanski et al., 2020; Wang et al., 2020; Gu and Dao, 2023). Our focus in this work, instead, is on understanding the sparsity behavior of the MLP layers.

2 Preliminaries

We consider the problem of learning real-valued functions over the input space $\mathcal{X}=\{-1,1\}^{n}$ to small expected squared $\ell_{2}$ error; namely, for the underlying distribution $\mathcal{D}$ over $(x,y)\in\mathcal{X}\times\mathbb{R}$, our goal is to minimize the population loss of a predictor $f:\mathcal{X}\to\mathbb{R}$ given as $\mathcal{L}_{\mathcal{D}}(f):=\mathbb{E}_{(x,y)\sim\mathcal{D}}\,\ell(f(x),y)$, where $\ell(\hat{y},y):=\frac{1}{2}(\hat{y}-y)^{2}$. For any dataset $S\in(\mathcal{X}\times\mathbb{R})^{*}$, we denote the empirical loss as $\mathcal{L}_{S}(f):=\frac{1}{|S|}\sum_{(x,y)\in S}\ell(f(x),y)$.

For any hypothesis class $\mathcal{H}\subseteq\mathbb{R}^{\mathcal{X}}$, we say that $\mathcal{D}$ is $\mathcal{H}$-realizable if there exists $h^{\star}\in\mathcal{H}$ such that $h^{\star}(x)=y$ holds with probability $1$ for $(x,y)\sim\mathcal{D}$. Following the standard definition of probably approximately correct (PAC) learning (Valiant, 1984), we say that a learning algorithm $\mathcal{A}$ $(\varepsilon,\delta)$-PAC learns $\mathcal{H}$ with sample complexity $m(\varepsilon,\delta)$ if for all $\mathcal{H}$-realizable distributions $\mathcal{D}$ over $\mathcal{X}\times\mathbb{R}$, and for $S\sim\mathcal{D}^{m(\varepsilon,\delta)}$, it holds with probability at least $1-\delta$ that $\mathcal{L}_{\mathcal{D}}(\mathcal{A}(S))\leq\varepsilon$. We say that a learning algorithm $\mathcal{A}$ $(\varepsilon,\delta)$-PAC learns $\mathcal{H}$ under a distribution $\mathcal{P}$ (over $\mathcal{X}$) if the learning guarantee holds for all $\mathcal{H}$-realizable $\mathcal{D}$ with marginal over $\mathcal{X}$ equal to $\mathcal{P}$. In particular, we use $\mathcal{U}$ to denote the uniform distribution over $\mathcal{X}$.

2.1 Fourier Analysis and the Low-Degree Algorithm

Any function $f:\{-1,1\}^{n}\to\mathbb{R}$ has a unique Fourier representation given as $\sum_{T\subseteq[n]}\hat{f}(T)\chi_{T}(x)$ where $\chi_{T}(x):=\prod_{j\in T}x_{j}$. The degree of $f$, denoted $\deg(f)$, is the largest $k$ such that $\hat{f}(T)\neq 0$ for some $T$ with $|T|=k$. The $\ell_{2}$ norm of $f$ under the uniform distribution is defined via $\|f\|_{2}^{2}:=\mathbb{E}_{x\sim\mathcal{U}}f(x)^{2}$ (O'Donnell, 2014).

We define the $\ell_{2}$ sensitivity of $f$ at $x$ as $\mathsf{sen}_{f}(x):=\frac{1}{4}\sum_{i\in[n]}(f(x)-f(x^{\oplus i}))^{2}$, where $x^{\oplus i}$ is $x$ with the $i$-th bit flipped; the scaling factor of $1/4$ means that for $f:\{-1,1\}^{n}\to\{-1,1\}$, sensitivity can be interpreted as $\mathsf{sen}_{f}(x)=|\{i:f(x)\neq f(x^{\oplus i})\}|$. The average $\ell_{2}^{2}$ sensitivity $\mathsf{AS}(f)$ is defined as $\mathbb{E}_{x\sim\mathcal{U}}\left[\mathsf{sen}_{f}(x)\right]$. For any $x$, let $N_{\rho}(x)$ denote the distribution obtained by flipping each coordinate of $x$ independently with probability $(1-\rho)/2$. The $\rho$-noise sensitivity of $f$ is $\mathsf{NS}_{\rho}(f):=\mathbb{E}_{x\sim\mathcal{U},y\sim N_{\rho}(x)}\frac{1}{4}(f(x)-f(y))^{2}$.
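
For intuition, the quantities above can also be estimated empirically; the following is a minimal Python sketch (an illustration of the definitions only, with helper names of our own choosing) that approximates $\mathsf{AS}(f)$ and $\mathsf{NS}_{\rho}(f)$ by sampling.

```python
import numpy as np

def avg_sensitivity(f, n, num_samples=2000, rng=None):
    """Monte Carlo estimate of AS(f) = E_x[ (1/4) * sum_i (f(x) - f(x^{flip i}))^2 ]."""
    rng = rng or np.random.default_rng(0)
    total = 0.0
    for _ in range(num_samples):
        x = rng.choice([-1, 1], size=n)
        fx = f(x)
        for i in range(n):
            x_flip = x.copy()
            x_flip[i] = -x_flip[i]
            total += 0.25 * (fx - f(x_flip)) ** 2
    return total / num_samples

def noise_sensitivity(f, n, rho, num_samples=2000, rng=None):
    """Monte Carlo estimate of NS_rho(f) = E[(1/4)(f(x) - f(y))^2],
    where y flips each coordinate of x independently with prob (1 - rho)/2."""
    rng = rng or np.random.default_rng(0)
    total = 0.0
    for _ in range(num_samples):
        x = rng.choice([-1, 1], size=n)
        flips = rng.random(n) < (1 - rho) / 2
        y = np.where(flips, -x, x)
        total += 0.25 * (f(x) - f(y)) ** 2
    return total / num_samples

# Example: the majority function on n = 9 bits.
maj = lambda x: float(np.sign(np.sum(x)))
print(avg_sensitivity(maj, n=9), noise_sensitivity(maj, n=9, rho=0.9))
```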

A connection between noise sensitivity and Fourier concentration was first observed in Klivans et al. (2004). We state this connection below, along with other basic facts about Fourier coefficients.

Claim 1.

[See Klivans et al. (2004)] The following properties hold for all $f:\{-1,1\}^{n}\to\mathbb{R}$:

  • $\|f\|_{2}^{2}=\sum_{T\subseteq[n]}\hat{f}(T)^{2}$, and

  • $\mathsf{NS}_{\rho}(f)=\sum_{T\subseteq[n]}\frac{1}{2}(1-\rho^{|T|})\hat{f}(T)^{2}$, and hence $\sum_{T:|T|>d}\hat{f}(T)^{2}\leq 2\cdot\mathsf{NS}_{\rho}(f)/(1-\rho^{d})$.

We also need a bound on the average sensitivity of a single halfspace, which is known to be $O(\sqrt{n})$. We require a more fine-grained version from Kane (2014), which quantifies the dependence on the bias of the halfspace.

Lemma 2.1 (Kane (2014)).

Let $g:\mathcal{X}\rightarrow\{0,1\}$ be a halfspace, $g(x)=\mathds{1}\{\langle w,x\rangle\leq b\}$, with $\mathbb{E}[g]=p$. Then, $\mathsf{AS}(g)=O(p\sqrt{n\log(1/p)})$.

Proof 2.2.

Without loss of generality, we can assume that the coefficients of $w$ are positive. This makes $g$ a monotone function which is non-decreasing in each coordinate. Now, for $i\in[n]$ and $x\sim\mathcal{U}$,

\[
\mathbb{E}[x_{i}g(x)]=\frac{1}{2}\sum_{x\in\mathcal{X}}\left(x_{i}g(x)-x_{i}g(x^{\oplus i})\right)=\mathbb{E}\left[(g(x)-g(x^{\oplus i}))^{2}\right],
\]

where the second equality is due to the non-decreasing nature of $g$ and the fact that $g(x)$ takes values in $\{0,1\}$. Therefore,

\[
\mathsf{AS}(g)=\frac{1}{4}\,\mathbb{E}_{x}\left[g(x)\sum_{i=1}^{n}x_{i}\right],
\]

the claim now follows from Lemma 6 of Kane (2014).

Low-degree algorithm.

We recall the standard low-degree algorithm and its guarantees for learning hypothesis classes that exhibit low-degree Fourier concentration (see e.g., Klivans et al. (2004) for details). For any hypothesis class $\mathcal{H}\subseteq(\mathcal{X}\to\mathbb{R})$, let $C_{\mathcal{H}}:=\sup_{h\in\mathcal{H},x\in\mathcal{X}}h(x)$.

Lemma 2.3.

For a hypothesis class $\mathcal{H}\subseteq(\mathcal{X}\to\mathbb{R})$ such that $\sum_{T:|T|>d}\hat{h}(T)^{2}\leq\varepsilon$ for all $h\in\mathcal{H}$, there exists an $(O(\varepsilon),\delta)$-PAC learning algorithm for $\mathcal{H}$ with $O(n^{d}C_{\mathcal{H}}^{2}\log(1/\delta)/\varepsilon)$ sample and time complexity.

The algorithm operates by performing polynomial regression, that is, linear regression in the basis of monomials of degree at most $d$. The algorithm achieves the desired error because $g(x):=\sum_{T:|T|\leq d}\hat{h}(T)\chi_{T}(x)$ satisfies $\|g-h\|_{2}^{2}=\sum_{T:|T|>d}\hat{h}(T)^{2}\leq\varepsilon$, and hence there exists a good solution to the polynomial regression problem.
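
The following is a minimal sketch of this low-degree algorithm (our own illustration of the procedure, not the exact implementation analyzed here): expand each example into the monomial features $\chi_{T}(x)$ for $|T|\leq d$ and solve a least-squares problem over that basis.

```python
import itertools
import numpy as np

def monomial_features(X, d):
    """Expand rows of X (entries in {-1, 1}) into all monomials chi_T(x) with |T| <= d."""
    m, n = X.shape
    feats = [np.ones(m)]  # chi_emptyset = 1
    for size in range(1, d + 1):
        for T in itertools.combinations(range(n), size):
            feats.append(np.prod(X[:, list(T)], axis=1))
    return np.stack(feats, axis=1)

def low_degree_fit(X, y, d):
    """Least-squares fit over the degree-<=d monomial basis; returns a predictor."""
    Phi = monomial_features(X, d)
    coef, *_ = np.linalg.lstsq(Phi, y, rcond=None)
    return lambda Xnew: monomial_features(Xnew, d) @ coef

# Toy usage: learn a 2-junta (x_0 * x_3) from uniform samples over the hypercube.
rng = np.random.default_rng(0)
X = rng.choice([-1, 1], size=(500, 8))
y = X[:, 0] * X[:, 3]
predictor = low_degree_fit(X, y, d=2)
X_test = rng.choice([-1, 1], size=(100, 8))
print(np.mean((predictor(X_test) - X_test[:, 0] * X_test[:, 3]) ** 2))  # ~0
```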

3 Learning over Uniform Distribution

In this section, we provide a learning algorithm for $k$-sparsely activated networks under the uniform distribution.

Theorem 3.1.

There exists an $(\varepsilon,\delta)$-PAC learning algorithm for $\mathcal{H}_{n,s,k}^{W,B}$ with respect to the uniform distribution over $\mathcal{X}$ that has sample complexity and run-time $O(n^{d}k^{2}(W\sqrt{n}+B)^{2}\log(1/\delta)/\varepsilon)$ for $d=\Theta((k^{8}W^{4}\log(ns)^{4}+k^{6}B^{4}\log s)/\varepsilon^{2})$.

At a high level, we show that all hypotheses in $\mathcal{H}_{n,s,k}^{W,B}$ exhibit low-degree Fourier concentration and hence can be learned over the uniform distribution using the low-degree algorithm (Lemma 2.3). To show Fourier concentration, we bound the noise sensitivity of sparsely activated networks by first showing a bound on the average sensitivity and then converting this to a bound on noise sensitivity.

Lemma 3.2.

For all $h\in\mathcal{H}_{n,s,k}^{W,B}$, it holds that $\mathsf{AS}(h)\leq O\left(k^{4}W^{2}\sqrt{n}\log(ns)+k^{3}B^{2}\sqrt{\log s}\right)$.

Proof 3.3.

Consider $h\in\mathcal{H}_{n,s,k}^{W,B}$ given as $h(x)=\sum_{j=1}^{s}u_{j}\sigma(\langle w_{j},x\rangle-b_{j})$. For any $R\subseteq[s]$, let $\ell_{R}(x)=\langle w^{R},x\rangle-b^{R}$ for $w^{R}:=\sum_{j\in R}u_{j}w_{j}$ and $b^{R}:=\sum_{j\in R}u_{j}b_{j}$. Since $\max_{j}|u_{j}|\cdot\max_{j}\|w_{j}\|\leq W$ and $\max_{j}|u_{j}|\cdot\max_{j}|b_{j}|\leq B$, it follows that $\|w^{R}\|\leq|R|\cdot W$ and $|b^{R}|\leq|R|\cdot B$. For any $x\in\mathcal{X}$, let $R_{x}\subseteq[s]$ be defined as $R_{x}:=\{j\in[s]:\langle w_{j},x\rangle>b_{j}\}$. Since $h$ is $k$-sparsely activated, we have that $|R_{x}|\leq k$ and hence $\|w^{R_{x}}\|\leq kW$ and $|b^{R_{x}}|\leq kB$. It is easy to see that for $h\in\mathcal{H}_{n,s,k}^{W,B}$ it holds that $h(x)=\ell_{R_{x}}(x)$ for all $x\in\mathcal{X}$.

The average sensitivity of $h$ is given as

\begin{align*}
\mathsf{AS}(h) &= \mathbb{E}_{x}\left[\sum_{i=1}^{n}\frac{1}{4}\left(h(x)-h(x^{\oplus i})\right)^{2}\right]\\
&= \mathbb{E}_{x}\left[\sum_{i=1}^{n}\frac{1}{4}\left(h(x)-h(x^{\oplus i})\right)^{2}\cdot\mathds{1}\{R_{x}=R_{x^{\oplus i}}\}\right] \tag{U}\\
&\quad+ \mathbb{E}_{x}\left[\sum_{i=1}^{n}\frac{1}{4}\left(h(x)-h(x^{\oplus i})\right)^{2}\cdot\mathds{1}\{R_{x}\neq R_{x^{\oplus i}}\}\right] \tag{V}
\end{align*}

We bound term (U) as,

\begin{align*}
\text{(U)} &= \mathbb{E}_{x}\left[\sum_{i=1}^{n}\frac{1}{4}\left(\ell_{R_{x}}(x)-\ell_{R_{x}}(x^{\oplus i})\right)^{2}\cdot\mathds{1}\{R_{x}=R_{x^{\oplus i}}\}\right]\\
&\leq \mathbb{E}_{x}\left[\sum_{i=1}^{n}\frac{1}{4}\left(w^{R_{x}}_{i}\right)^{2}\right] = \mathbb{E}_{x}\left[\frac{1}{4}\|w^{R_{x}}\|^{2}\right] \leq \frac{k^{2}W^{2}}{4}\,.
\end{align*}

We bound term (V) as follows, using the inequality $(a-b)^{2}\leq 2(a^{2}+b^{2})$:

\begin{align*}
\text{(V)} &= n\,\mathbb{E}_{x,i}\left[\frac{1}{4}\left(h(x)-h(x^{\oplus i})\right)^{2}\cdot\mathds{1}\{R_{x}\neq R_{x^{\oplus i}}\}\right]\\
&\leq n\,\mathbb{E}_{x,i}\left[\frac{1}{2}\left(h(x)^{2}+h(x^{\oplus i})^{2}\right)\cdot\mathds{1}\{R_{x}\neq R_{x^{\oplus i}}\}\right]\\
&= n\,\mathbb{E}_{x,i}\left[h(x)^{2}\cdot\mathds{1}\{R_{x}\neq R_{x^{\oplus i}}\}\right]\qquad\text{(by symmetry)}
\end{align*}

For $g_{j}(x):=\mathds{1}\{\langle w_{j},x\rangle>b_{j}\}$, we have that

\[
\Pr_{x,i}[R_{x}\neq R_{x^{\oplus i}}] \leq \frac{1}{n}\sum_{j=1}^{s}\sum_{i=1}^{n}\Pr_{x}[g_{j}(x)\neq g_{j}(x^{\oplus i})]=\frac{1}{n}\sum_{j=1}^{s}\mathsf{AS}(g_{j}).
\]

Note that $\sum_{j=1}^{s}g_{j}(x)\leq k$ (by $k$-sparsity), and hence for $p_{j}=\mathbb{E}_{x}[g_{j}(x)]$, we have that $\sum_{j=1}^{s}p_{j}\leq k$. From Lemma 2.1, we have that $\mathsf{AS}(g_{j})\leq p_{j}\sqrt{n\log(1/p_{j})}$. Thus,

\[
\Pr_{x,i}[R_{x}\neq R_{x^{\oplus i}}] \leq \frac{1}{n}\sum_{j=1}^{s}p_{j}\sqrt{n\log(1/p_{j})} \leq \frac{k\sqrt{\log(s/k)}}{\sqrt{n}},
\]

where we use the concavity of $p\sqrt{\log(1/p)}$ for $p\in(0,1)$. For each $R\subseteq[s]$ with $|R|\leq k$, we have by the Hoeffding bound that, for some sufficiently large $c$ and $t=ckW\sqrt{\log(n^{k}s)}+kB$,

\begin{align*}
&\Pr_{x\sim\mathcal{U}}\left[\exists R\subseteq[s]:|R|\leq k\ \text{ and }\ \left|\left\langle w^{R},x\right\rangle-b^{R}\right|>t\right]\\
&\quad\leq \Pr_{x\sim\mathcal{U}}\left[\exists R\subseteq[s]:|R|\leq k\ \text{ and }\ \left|\left\langle w^{R},x\right\rangle\right|>t-\left|b^{R}\right|\right]\\
&\quad\leq 2n^{k}\exp\left(\frac{-(t-|b^{R}|)^{2}}{2\|w^{R}\|^{2}}\right)\leq\frac{1}{(ns)^{4}},
\end{align*}

Hence, in particular we have that

\[
\Pr_{x}\left[|\ell_{R_{x}}(x)|\geq ck^{1.5}W\sqrt{\log(ns)}+kB\right] \leq \frac{1}{n^{4}s^{4}}.
\]

Moreover, $|\ell_{R_{x}}(x)|\leq kW\sqrt{n}+kB$ holds with probability $1$ for all $x$. Thus, we can upper bound (V) as

\begin{align*}
\text{(V)} &\leq n\cdot\left[\left(\frac{k\sqrt{\log(s/k)}}{\sqrt{n}}-\frac{1}{(ns)^{4}}\right)\left(ck^{1.5}W\sqrt{\log(ns)}+kB\right)^{2}+\frac{1}{(ns)^{4}}\cdot\left(kW\sqrt{n}+kB\right)^{2}\right]\\
&\leq O\left(k^{4}W^{2}\sqrt{n}\log(ns)+k^{3}B^{2}\sqrt{\log s}\right).
\end{align*}

Combining the bounds on (U) and (V) completes the proof.

Next, we can use the bound on average sensitivity to bound the noise sensitivity of functions in $\mathcal{H}_{n,s,k}^{W,B}$. To do so, we use an argument attributed to Peres for converting bounds on average sensitivity to bounds on noise sensitivity, allowing us to get better low-degree approximations.

Lemma 3.4.

For any $h\in\mathcal{H}_{n,s,k}^{W,B}$,

\[
\mathsf{NS}_{\rho}(h)=\sqrt{1-\rho}\cdot O\left(k^{4}W^{2}\log^{2}(ns/(1-\rho))+k^{3}B^{2}\sqrt{\log s}\right).
\]

The proof of Lemma 3.4 is provided in the appendix.

Proof of Theorem 3.1.

We combine Claim 1, Lemma 2.3 and Lemma 3.4. Fix an error parameter $\varepsilon$. Then, by Lemma 3.4, there is a constant $c>0$ such that for

\[
1-\rho = c\varepsilon^{2}\cdot\min\left\{\frac{\log(knsW^{2}/\varepsilon)}{k^{8}W^{4}\log^{4}(ns)},\ \frac{1}{k^{6}B^{4}\log s}\right\},
\]

any $h\in\mathcal{H}_{n,s,k}^{W,B}$ satisfies

\[
\mathsf{NS}_{\rho}(h)\leq\varepsilon/3.
\]

Thus, we can choose a suitable $d=\Theta((k^{8}W^{4}\log(ns)^{4}+k^{6}B^{4}\log s)/\varepsilon^{2})$ such that, by Claim 1,

\[
\sum_{T:|T|>d}\hat{h}(T)^{2}\leq\frac{\varepsilon}{3(1-\rho^{d})}\leq\frac{\varepsilon}{d(1-\rho)}\leq\varepsilon\,.
\]

Finally, note that $C_{\mathcal{H}_{n,s,k}^{W,B}}=k(W\sqrt{n}+B)$, since at most $k$ neurons are active on any input and each neuron can contribute at most $W\sqrt{n}+B$. The theorem now follows by combining the above with Lemma 2.3. The run-time and sample complexity will be $O(n^{d}\log(1/\delta)/\varepsilon)$ where $d$ is as above.

Remark 3.5.

Theorem 3.1 can be extended to the hypothesis class where $k$-sparsity need not hold for all inputs $x\in\mathcal{X}$, but holds with probability at least $1-1/\mathsf{poly}(n,s)$ over the input distribution, that is, $\Pr_{x\sim\mathcal{U}}[\#\{i\in[s]:\langle w_{i},x\rangle+b_{i}>0\}>k]\leq 1/\mathsf{poly}(n,s)$. This follows by decomposing $\mathsf{AS}(h)$ into (U), (V) and a third term handling the $x$ for which $k$-sparsity is violated.

4 Lower Bounds for Learning $\mathcal{H}_{n,s,1}$

Note that the previous section implies a quasi-polynomial time learning algorithm for the class $\mathcal{H}_{n,s,1}$ of $1$-sparsely activated networks. We next show that a quasi-polynomial run-time is likely necessary for learning $\mathcal{H}_{n,s,1}$ under the uniform distribution, and that stronger lower bounds hold under arbitrary distributions.

Sparse Activations Can Simulate Juntas

We first show that our proposed learning algorithms for the case of the uniform distribution have near-optimal runtime under a widely believed conjecture on the hardness of learning juntas. Let $\mathcal{J}_{n,p}$ denote the set of Boolean functions $f:\{1,-1\}^{n}\to\{-1,1\}$ that depend on at most $p$ variables.

Conjecture 4.1 (Hardness of learning Juntas).

(See e.g. Mossel et al. (2003); Feldman et al. (2011).) There is no $(\varepsilon,\delta)$-PAC learning algorithm for learning $\mathcal{J}_{n,p}$ under the uniform distribution on the hypercube that runs in time $n^{o(p)}$.

The conjecture implies that there is no learning algorithm for $\mathcal{H}_{n,s,1}$ that runs in $n^{o(\log s)}$ time.

Theorem 4.2.

Assuming Conjecture 4.1, there is no $(\varepsilon,\delta)$-PAC learning algorithm for $\mathcal{H}_{n,s,1}^{W,B}$ with $W=\sqrt{\log_{2}s}$ and $B=\log_{2}s$ over $\mathcal{U}$ that runs in $n^{o(\log s)}$ time.

Proof 4.3.

We show that $\mathcal{H}_{n,s,1}^{\sqrt{p},p}\supseteq\mathcal{J}_{n,p}$ for all $p\leq\lfloor\log_{2}s\rfloor$, that is, for $p\leq\lfloor\log_{2}s\rfloor$ any $p$-junta $f\in\mathcal{J}_{n,p}$ can be expressed as $\sum_{j\in[s]}u_{j}\sigma(\langle w_{j},x\rangle+b_{j})$ where $\|u\|_{\infty}\leq 1$ and $\|w_{j}\|_{2}\leq\sqrt{\log_{2}s}$. Suppose w.l.o.g. that $f$ depends on $x_{1},\ldots,x_{p}$. Let $w_{1},\ldots,w_{2^{p}}$ be distinct vectors that take all possible $\pm 1$ values in the first $p$ coordinates and are $0$ on the other coordinates. For $j\in[2^{p}]$, let $u_{j}=f(x)$ for any $x$ such that $x_{i}=w_{ji}$ for all $i\in[p]$. Let $w_{j}=\mathbf{0}$ and $u_{j}=0$ for all $j>2^{p}$. It is now easy to verify that for all $x\in\mathcal{X}$,

\[
f(x)=\sum_{j\in[2^{p}]}u_{j}\sigma(\langle w_{j},x\rangle-p+1),\quad\text{since }\sigma(\langle w_{j},x\rangle-p+1)=\mathds{1}\{x_{i}=w_{ji}\text{ for all }i\in[p]\}.
\]

Thus, the theorem follows under the assumption of Conjecture 4.1.

Hardness Under Arbitrary Distributions

We next show that $1$-sparsely activated networks over $\{1,-1\}^{n}$ can simulate parities of size $\Omega(\sqrt{n})$. Fix an integer $m$, and for $S\subseteq[m]$, let $\chi_{S}:\{1,-1\}^{m}\rightarrow\{0,1\}$ be defined by $\chi_{S}(y)=1$ if and only if $\sum_{i\in S}y_{i}$ is even. Now, we can use the following simple identity (similar identities were used for similar purposes, for example, in Klivans and Sherstov (2006)):

\[
\chi_{S}(y)=\sum_{\substack{a\in\{-m,\ldots,m\}\\ a\text{ even}}}2\,\sigma\left(\frac{1}{2}-\Big(\sum_{i\in S}y_{i}-a\Big)^{2}\right).
\]

Note that for any $y\in\{1,-1\}^{m}$, at most one ReLU node is active. This is not quite enough to capture $\mathcal{H}_{n,s,1}$, as the functions inside the ReLUs are not linear. To fix this, we linearize the quadratic function by increasing the dimension. For $y\in\{1,-1\}^{m}$, let $x(y)\in\{1,-1\}^{m\times m}$ be defined as follows:

\[
x(y)_{ij}=\begin{cases}y_{i}&\text{if }i=j,\\ y_{i}y_{j}&\text{if }i\neq j.\end{cases}
\]

Let $n=m^{2}$ and identify $\{1,-1\}^{n}$ with $\{1,-1\}^{m\times m}$ in the natural way. Observe that for any $S\subseteq[m]$ and $a\in[-m,m]$, there exist a vector $w_{S,a}\in\mathbb{R}^{n}$ and $b_{S,a}\in\mathbb{R}$ such that

\[
\frac{1}{2}-\Big(\sum_{i\in S}y_{i}-a\Big)^{2}=\langle w_{S,a},x(y)\rangle-b_{S,a}.
\]

In particular, we can take $b_{S,a}=|S|+a^{2}-1/2$, with $w_{S,a}[i,j]=-1$ for $i\neq j\in[m]$ and $w_{S,a}[i,i]=2a$. Note that $\|w_{S,a}\|_{2}=O(m^{1.5})=O(n^{3/4})$ and $|b_{S,a}|=O(m^{2})=O(n)$.

In summary, there exists a distribution $\mathcal{D}$ on $\{1,-1\}^{m\times m}$ such that learning parities over $\{1,-1\}^{m}$ under the uniform distribution is implied by learning $\mathcal{H}_{m^{2},2m,1}^{O(m^{1.5}),O(m^{2})}$ under the distribution $\mathcal{D}$. The first part of Theorem 1.3 now follows from standard lower bounds for learning parities.
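
As a quick numerical sanity check of the identity above (included only for illustration; it is not part of the formal argument), the following Python snippet evaluates the ReLU sum on random inputs and confirms that it matches the parity indicator and that at most one ReLU is active on every input.

```python
import numpy as np

def relu(z):
    return np.maximum(z, 0.0)

def parity_via_relus(y, S):
    """Evaluate sum over even a in {-m,...,m} of 2*sigma(1/2 - (sum_{i in S} y_i - a)^2),
    and report how many of the ReLU units are active on this input."""
    m = len(y)
    v = sum(y[i] for i in S)
    pre_acts = [0.5 - (v - a) ** 2 for a in range(-m, m + 1) if a % 2 == 0]
    value = sum(2.0 * relu(z) for z in pre_acts)
    num_active = sum(z > 0 for z in pre_acts)
    return value, num_active

rng = np.random.default_rng(0)
m = 8
for S in ([0, 2, 3], [1, 4, 5, 7]):
    for _ in range(500):
        y = rng.choice([-1, 1], size=m)
        value, num_active = parity_via_relus(y, S)
        assert num_active <= 1                                  # 1-sparse activation
        assert value == float(sum(y[i] for i in S) % 2 == 0)    # matches the parity indicator
```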

SQ Hardness

Consider a class $C$ of functions mapping $\mathbb{R}^{n}$ to $\mathbb{R}$, and let $\mathcal{D}$ be a distribution over $\mathbb{R}^{n}$.

In the Statistical Query (SQ) model, as described by Kearns (1998), the learner interacts with the data through an SQ oracle. For a bounded query function $\phi:\mathbb{R}^{n}\times\mathbb{R}\rightarrow[-1,1]$ and a tolerance $\tau>0$, the oracle can return any value $v$ such that $\left|v-\mathbb{E}_{x\sim\mathcal{D}}[\phi(x,f(x))]\right|\leq\tau$. The goal in SQ learning is to learn an approximation to the unknown concept using only a few queries as above with reasonable tolerance. We will use the following classical theorem:

Theorem 4.4 (Blum et al., 1994).

Any SQ algorithm for learning the class of parities over $\{1,-1\}^{m}$ to within error $1/3$ under the uniform distribution over the hypercube with tolerance $\tau$ requires $\Omega(2^{m}\tau^{2})$ queries.

The first part of Theorem 1.3 follows immediately from the above and the fact that parities on $m$ variables can be computed in $\mathcal{H}_{m^{2},O(m),1}^{O(m^{1.5}),O(m^{2})}$ as described above.

Cryptographic Hardness

We sketch the argument here. Following Chen et al. (2022a), our starting point will be the Learning with Rounding (LWR) problem (Banerjee et al., 2012):

Definition 4.5.

For moduli $p,q\in\mathbb{N}$ and $w\in\mathbb{Z}_{q}^{m}$, define $f_{w}:\mathbb{Z}_{q}^{m}\rightarrow\mathbb{Z}_{p}$ by $f_{w}(y):=(\langle w,y\rangle\bmod q)\bmod p$.

In the $\mathsf{LWR}_{p,q,m}$ problem, the secret $w\in\mathbb{Z}_{q}^{m}$ is drawn uniformly at random and we are given samples of the form $(y,f_{w}(y))$ where $y$ is uniform over $\mathbb{Z}_{q}^{m}$. The goal is to output a hypothesis that achieves small error in predicting the label $f_{w}(\cdot)$. It is conjectured that there is no $\mathsf{poly}(m,p,q)$ algorithm for $\mathsf{LWR}_{p,q,m}$.
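
For concreteness, a sample from this problem can be generated as follows (a minimal sketch of the definition above; the function and variable names are our own).

```python
import numpy as np

def lwr_sample(w, q, p, rng):
    """One LWR sample (y, f_w(y)) with f_w(y) = (<w, y> mod q) mod p, y uniform over Z_q^m."""
    y = rng.integers(0, q, size=w.shape[0])
    label = (int(np.dot(w, y)) % q) % p
    return y, label

rng = np.random.default_rng(0)
m, q, p = 16, 257, 4
w = rng.integers(0, q, size=m)   # the secret
y, label = lwr_sample(w, q, p, rng)
print(y, label)
```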

Conjecture 4.6 (See Banerjee et al. (2012)).

There is no $\mathsf{poly}(p,q,m)$ run-time algorithm to solve $\mathsf{LWR}_{p,q,m}$ with probability at least $2/3$ (over the random choice of $w$ and the samples).

We show that an efficient algorithm for learning $\mathcal{H}_{n,s,1}$ functions under arbitrary distributions on the hypercube would contradict this assumption.

Consider an instance of the $\mathsf{LWR}_{p,q,m}$ problem. First, map $y\in\mathbb{Z}_{q}^{m}$ to $z(y)\in\{1,-1\}^{r}$ for $r=O(m\log q)$ by considering the binary representation of the integers in $y$. Next, let $\lambda:[q^{2}m]\rightarrow[p]$ be such that $\lambda(i)=(i\bmod q)\bmod p$. Note that for every $w\in\mathbb{Z}_{q}^{m}$, we can find a vector $v(w)\in\mathbb{R}^{r}$ such that $\langle v(w),z(y)\rangle=\langle w,y\rangle$. Then,

\[
f_{w}(y)=\lambda(\langle v(w),z(y)\rangle).
\]

Now, observe that we can write

\[
\lambda(\langle v(w),z(y)\rangle)=\sum_{a\in[q^{2}m]}2\lambda(a)\,\sigma\left(\frac{1}{2}-\left(\langle v(w),z(y)\rangle-a\right)^{2}\right).
\]

Note that in this conversion $z(y)\in\{1,-1\}^{r}$ and $v(w)\in\mathbb{R}^{r}$. Further, for any input $y$, only one of the ReLUs will be active. However, the above is not quite in $\mathcal{H}_{n,s,1}$, as we have a quadratic function inside the ReLU. Just as we did for parities, we can fix this issue by linearizing the quadratic form. Let $n=r^{2}$, and define $x(y)\in\{1,-1\}^{r\times r}$ by setting $x(y)_{ij}=z(y)_{i}z(y)_{j}$ if $i\neq j$ and $x(y)_{ii}=z(y)_{i}$. Then, just as in our argument for parities, there exist a lifted weight vector $W_{w,a}\in\mathbb{R}^{n}$ and $b_{w,a}\in\mathbb{R}$ such that

\[
\frac{1}{2}-\left(\langle v(w),z(y)\rangle-a\right)^{2}=\langle W_{w,a},x(y)\rangle-b_{w,a}.
\]

In addition, it is easy to check that $\|W_{w,a}\|_{2},|b_{w,a}|=\mathsf{poly}(q,m)$. In particular, we get that for every $w\in\mathbb{Z}_{q}^{m}$, there exists a function $F_{w}$ in $\mathcal{H}_{r^{2},O(q^{2}m),1}$ such that for every $y\in\mathbb{Z}_{q}^{m}$,

\[
f_{w}(y)=F_{w}(x(y)),
\]

where $x(y)\in\{1,-1\}^{r^{2}}$ is the embedding defined above, the same one used in the SQ hardness argument. The second part of Theorem 1.3 now follows from the conjectured hardness of $\mathsf{LWR}_{p,q,m}$; we omit the minor details.

5 Learning under General Distributions

We now show the statistical advantage associated with sparsely activated neural networks over general distributions. In particular, we show the following.

Theorem 5.1.

There exists an $(\varepsilon,\delta)$-PAC learning algorithm for any $\mathcal{H}_{n,s,k}^{W,B}$ with sample complexity $m(\varepsilon,\delta)=O\Big(\frac{(WR+B)^{2}ksn\log(\frac{k(R+B)}{\varepsilon})+\log(\frac{1}{\delta})}{\varepsilon^{2}}\Big)$.

This result holds even in the more general setting where the input space $\mathcal{X}\subset\mathbb{R}^{n}$ and $\|x\|\leq R$ for all $x\in\mathcal{X}$. To begin with, we will again consider the class of $1$-sparsely activated networks, i.e., $\mathcal{H}^{W,B}_{n,s,1}$. We will discuss extensions to $\mathcal{H}^{W,B}_{n,s,k}$ towards the end of the section.

We use Rademacher complexity to establish the bound in Theorem 5.1. Given a set of examples $S=\{x_{1},x_{2},\ldots,x_{m}\}$, the empirical Rademacher complexity (Shalev-Shwartz and Ben-David, 2014) is defined as $\mathcal{R}_{\mathcal{H}}(S):=\mathbb{E}_{\zeta}\left[\max_{h\in\mathcal{H}}\frac{1}{m}\sum_{i=1}^{m}\zeta_{i}h(x_{i})\right]$, where $\zeta_{1},\ldots,\zeta_{m}$ are $\{-1,+1\}$-valued Rademacher random variables. As before, let $C_{\mathcal{H}}:=\sup_{h\in\mathcal{H},x\in\mathcal{X}}h(x)$.
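
As an illustration of this definition (not part of the analysis), the following sketch estimates the empirical Rademacher complexity of a finite collection of predictors by sampling $\zeta$; for an infinite class such as $\mathcal{H}^{W,B}_{n,s,1}$, maximizing over a finite random subfamily as done here only gives a lower-bound proxy for the true quantity.

```python
import numpy as np

def empirical_rademacher(X, predictors, num_zeta=500, rng=None):
    """Monte Carlo estimate of E_zeta[ max_h (1/m) sum_i zeta_i h(x_i) ]
    over a finite list of predictors h (each mapping a row of X to a real number)."""
    rng = rng or np.random.default_rng(0)
    m = X.shape[0]
    H = np.array([[h(x) for x in X] for h in predictors])   # |H| x m prediction matrix
    estimates = []
    for _ in range(num_zeta):
        zeta = rng.choice([-1.0, 1.0], size=m)
        estimates.append(np.max(H @ zeta) / m)
    return float(np.mean(estimates))

# Toy usage: a few random single-ReLU predictors on m random inputs in R^n.
rng = np.random.default_rng(1)
n, m = 10, 200
X = rng.standard_normal((m, n))
predictors = [
    (lambda x, w=rng.standard_normal(n), b=rng.standard_normal(): max(float(w @ x - b), 0.0))
    for _ in range(50)
]
print(empirical_rademacher(X, predictors))
```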

Lemma 5.2 (see Shalev-Shwartz and Ben-David (2014)).

For any class $\mathcal{H}$ mapping $\mathcal{X}$ to $\mathbb{R}$, there exists an $(\varepsilon,\delta)$-PAC learning algorithm for $\mathcal{H}$ with sample complexity $m(\varepsilon,\delta)$ equal to the smallest $m$ such that, for a large enough constant $c$, it holds that

\[
c\cdot\left(C_{\mathcal{H}}\,\mathbb{E}_{S}[\mathcal{R}_{\mathcal{H}}(S)]+\sqrt{\frac{\log(1/\delta)}{m}}\right)\leq\varepsilon\,.
\]
Theorem 5.1 will follow from bounding the Rademacher complexity $\mathcal{R}_{\mathcal{H}}(S)$. Recall that in the absence of any sparsity assumption, existing results (Anthony et al., 1999) on the Rademacher complexity of $1$-hidden-layer ReLU networks with input dimensionality $n$ and $s$ hidden units lead to a bound of $\frac{(WR+B)s}{\sqrt{m}}$. (Better bounds are possible under stronger assumptions on the network weights (Wei et al., 2019).) We will show that the main statistical advantage that comes from sparsity is that the dependence on the number of hidden units $s$ can be made sub-linear, albeit at the expense of an explicit dependence on the input dimensionality $n$. In particular, we will prove the following theorem.

Theorem 5.3.

It holds that

\[
\mathcal{R}_{\mathcal{H}^{W,B}_{n,s,1}}(S) \leq \frac{(WR+B)\sqrt{sn\log(m(R+B))}}{\sqrt{m}}. \tag{1}
\]

Proof 5.4.

For a given hypothesis $u,w_{1},\ldots,w_{s}\in\mathcal{H}^{W,B}_{n,s,1}$ and for any $j\in[s]$, let $I_{j}$ be the subset of the $m$ examples that activate neuron $j$, i.e., $I_{j}=\{i\in[m]:\langle w_{j},x_{i}\rangle-b_{j}\geq 0\}$. Since each $I_{j}$ is determined by a halfspace in $n$ dimensions, by the Sauer-Shelah lemma (Shalev-Shwartz and Ben-David, 2014) there can be at most $O(m^{n})$ such subsets.

Next, we have

\begin{align*}
\mathcal{R}_{\mathcal{H}^{W,B}_{n,s,1}}(S) &:= \mathbb{E}_{\zeta}\left[\max_{u,w_{1},\ldots,w_{s}\in\mathcal{H}^{W,B}_{n,s,1}}\frac{1}{m}\sum_{i=1}^{m}\zeta_{i}\sum_{j=1}^{s}u_{j}\sigma(\langle w_{j},x_{i}\rangle-b_{j})\right] \tag{2}\\
&= \mathbb{E}_{\zeta}\left[\max_{u,w_{1},\ldots,w_{s}\in\mathcal{H}^{W,B}_{n,s,1}}\frac{1}{m}\sum_{j=1}^{s}\sum_{i\in I_{j}}\zeta_{i}u_{j}(\langle w_{j},x_{i}\rangle-b_{j})\right] \tag{3}\\
&\leq \mathbb{E}_{\zeta}\left[\max_{u,w_{1},\ldots,w_{s}\in\mathcal{H}^{W,B}_{n,s,1}}\frac{1}{m}\sum_{j=1}^{s}\sum_{i\in I_{j}}\zeta_{i}u_{j}\langle w_{j},x_{i}\rangle\right]\\
&\quad+ \mathbb{E}_{\zeta}\left[\max_{u,w_{1},\ldots,w_{s}\in\mathcal{H}^{W,B}_{n,s,1}}\frac{1}{m}\sum_{j=1}^{s}\sum_{i\in I_{j}}\zeta_{i}u_{j}b_{j}\right] \tag{4}
\end{align*}

We will bound the above two terms separately via standard concentration inequalities. For the second term, note that for any fixed $I_{j}$, the random variable $\sum_{i\in I_{j}}\zeta_{i}$ is sub-Gaussian with norm $O(\sqrt{|I_{j}|})$. Hence, for any fixed $I_{j}$, the following holds (Vershynin, 2018):

\[
\mathbb{P}\left[\Big|\sum_{i\in I_{j}}\zeta_{i}\Big|>t\sqrt{|I_{j}|}\right]\leq 2e^{-\frac{t^{2}}{c}}, \tag{5}
\]

where $c>0$ is an absolute constant. Via the union bound, we get that with probability at least $1-O(m^{n}e^{-t^{2}/c})$, all sets $I_{j}$ simultaneously satisfy the above inequality.

Hence we get the following bound on the second term.

\[
\mathbb{E}_{\zeta}\left[\max_{u,w_{1},\ldots,w_{s}\in\mathcal{H}^{W,B}_{n,s,1}}\frac{1}{m}\sum_{j=1}^{s}\sum_{i\in I_{j}}\zeta_{i}u_{j}b_{j}\right] \leq \frac{1}{m}\sum_{j=1}^{s}t|u_{j}b_{j}|\sqrt{|I_{j}|}+O(m^{n}e^{-t^{2}/c})\cdot\frac{1}{m}\sum_{j=1}^{s}|u_{j}b_{j}||I_{j}|. \tag{6}
\]

From the fact that the activations are $1$-sparse, we get that $\sum_{j=1}^{s}|I_{j}|=m$. This implies that $\sum_{j=1}^{s}\sqrt{|I_{j}|}\leq\sqrt{sm}$. Furthermore, using the fact that $\max_{j}|u_{j}b_{j}|\leq B$, we get

\[
\mathbb{E}_{\zeta}\left[\max_{u,w_{1},\ldots,w_{s}\in\mathcal{H}^{W,B}_{n,s,1}}\frac{1}{m}\sum_{j=1}^{s}\sum_{i\in I_{j}}\zeta_{i}u_{j}b_{j}\right] \leq \frac{tB\sqrt{s}}{\sqrt{m}}+O(m^{n}e^{-t^{2}/c})B. \tag{7}
\]

Setting $t=2\sqrt{nc\log(mB)}$, we get that the second term is bounded by

\[
\mathbb{E}_{\zeta}\left[\max_{u,w_{1},\ldots,w_{s}\in\mathcal{H}^{W,B}_{n,s,1}}\frac{1}{m}\sum_{j=1}^{s}\sum_{i\in I_{j}}\zeta_{i}u_{j}b_{j}\right] \leq \frac{4B\sqrt{snc\log(mB)}}{\sqrt{m}}. \tag{8}
\]

Similarly, we next bound the first term. Note that for any fixed $I_{j}$ and any coordinate $p\in[n]$, sub-Gaussian concentration (Vershynin, 2018) implies that

\[
\mathbb{P}\left(\Big|\sum_{i\in I_{j}}\zeta_{i}x_{i,p}\Big|>t\sqrt{\sum_{i\in I_{j}}x^{2}_{i,p}}\right)\leq 2e^{-\frac{t^{2}}{c}}. \tag{9}
\]

Via a union bound over all $n$ coordinates and all possible subsets $I_{j}$, we get that with probability at least $1-2nm^{n}e^{-\frac{t^{2}}{c}}$, all sets $I_{j}$ simultaneously satisfy

\[
\Big\|\sum_{i\in I_{j}}\zeta_{i}x_{i}\Big\|\leq tR\sqrt{|I_{j}|}. \tag{10}
\]

Using the above we can bound the first term as

\begin{align*}
\mathbb{E}_{\zeta}\left[\max_{u,w_{1},\ldots,w_{s}\in\mathcal{H}^{W,B}_{n,s,1}}\frac{1}{m}\sum_{j=1}^{s}\sum_{i\in I_{j}}\zeta_{i}u_{j}\langle w_{j},x_{i}\rangle\right] &\leq \mathbb{E}_{\zeta}\left[\max_{u,w_{1},\ldots,w_{s}\in\mathcal{H}^{W,B}_{n,s,1}}\frac{1}{m}\sum_{j=1}^{s}\Big\langle u_{j}w_{j},\sum_{i\in I_{j}}\zeta_{i}x_{i}\Big\rangle\right] \tag{11}\\
&\leq \frac{1}{m}\sum_{j=1}^{s}|u_{j}|\,\|w_{j}\|\,\mathbb{E}_{\zeta}\left[\Big\|\sum_{i\in I_{j}}\zeta_{i}x_{i}\Big\|\right] \tag{12}\\
&\leq \frac{W}{m}\sum_{j=1}^{s}\Big(tR\sqrt{|I_{j}|}+2nm^{n}e^{-\frac{t^{2}}{c}}R|I_{j}|\Big). \tag{13}
\end{align*}

Recall from above that $\sum_{j=1}^{s}\sqrt{|I_{j}|}\leq\sqrt{sm}$. Furthermore, setting $t=2\sqrt{n\log(mR)}$, we get that the first term is bounded by

\[
\mathbb{E}_{\zeta}\left[\max_{u,w_{1},\ldots,w_{s}\in\mathcal{H}^{W,B}_{n,s,1}}\frac{1}{m}\sum_{j=1}^{s}\sum_{i\in I_{j}}\zeta_{i}u_{j}\langle w_{j},x_{i}\rangle\right] \leq \frac{4WR\sqrt{ns\log(mR)}}{\sqrt{m}}. \tag{14}
\]

Combining the bounds for the first and the second terms, we get the desired claim.

Generalization to $k$-sparsely activated networks.

The above analysis extends in a straightforward manner to the class $\mathcal{H}_{n,s,k}$, i.e., the class of networks where each input activates at most $k$ hidden units.

To extend the bound in Theorem 5.3, we note that $k$-sparsity implies $\sum_{j\in[s]}|I_{j}|\leq km$, which gives

\[
\mathcal{R}_{\mathcal{H}^{W,B}_{n,s,k}}(S) \leq \frac{(WR+B)\sqrt{snk\log(km(R+B))}}{\sqrt{m}}. \tag{15}
\]

Note that in contrast to the classical bounds on the Rademacher complexity of general norm-bounded $1$-layer neural networks, the bound in Theorem 5.3 above has a sub-linear dependence on $s$. However, we incur an explicit dependence on the input dimensionality.

We suspect that this is a limitation of our proof technique and conjecture that the right bound should have no explicit dependence on the input dimension $n$.

Conjecture 5.5.

The class $\mathcal{H}^{W,B}_{n,s,k}$ of $k$-sparsely activated neural networks satisfies

\[
\mathcal{R}_{\mathcal{H}^{W,B}_{n,s,k}}(S) \leq \frac{(WR+B)\sqrt{sk}}{\sqrt{m}}. \tag{16}
\]

6 Discussion & Future Directions

Motivated by the empirical phenomenon of activation sparsity in MLP layers of large transformer models, in this work we proposed and studied the problem of PAC learning the class of sparsely activated neural networks. This is a novel concept class with many interesting properties. The form of input-dependent sparsity present in this class of functions makes it distinct from the typical sparse function classes studied in literature. The main conceptual insight from our work is that despite the empirical challenges in leveraging sparsity, activation sparsity can provably provide both computational and statistical benefits.

Several open questions come out of our work. While we provide algorithms with near-optimal running time for the case of the uniform distribution, it would be interesting to design learning algorithms under arbitrary distributions that are provably better than the $O((ns)^{n})$-time algorithms that exist for general $1$-layer ReLU networks (Goel et al., 2020). As mentioned in Section 5, we strongly suspect that the dependence on the input dimension $n$ in the Rademacher complexity bound of Theorem 5.3 is suboptimal. While we primarily considered networks that are sparsely activated for all inputs, it might also be interesting to consider networks that are sparsely activated with high probability over the input distribution, as we briefly alluded to in Remark 3.5, although in that case the probability of not being sparsely activated was very small. Finally, it would be interesting to explore practical algorithms for leveraging sparsity based on our theoretical insights.

Acknowledgments

We thank anonymous reviewers for their comments that helped improve the presentation.

References

  • Anil et al. (2018) Rohan Anil, Gabriel Pereyra, Alexandre Passos, Robert Ormandi, George E Dahl, and Geoffrey E Hinton. Large scale distributed neural network training through online distillation. arXiv preprint arXiv:1804.03235, 2018.
  • Anthony et al. (1999) Martin Anthony and Peter L. Bartlett. Neural Network Learning: Theoretical Foundations. Cambridge University Press, 1999.
  • Banerjee et al. (2012) Abhishek Banerjee, Chris Peikert, and Alon Rosen. Pseudorandom functions and lattices. In Annual International Conference on the Theory and Applications of Cryptographic Techniques, pages 719–737. Springer, 2012.
  • Banner et al. (2019) Ron Banner, Yury Nahshan, and Daniel Soudry. Post training 4-bit quantization of convolutional networks for rapid-deployment. Advances in Neural Information Processing Systems, 32, 2019.
  • Blum et al. (1994) Avrim Blum, Merrick Furst, Jeffrey Jackson, Michael Kearns, Yishay Mansour, and Steven Rudich. Weakly learning DNF and characterizing statistical query learning using Fourier analysis. In Proceedings of the Twenty-Sixth Annual ACM Symposium on Theory of Computing, pages 253–262, 1994.
  • Brown et al. (2020) Tom Brown, Benjamin Mann, Nick Ryder, Melanie Subbiah, Jared D Kaplan, Prafulla Dhariwal, Arvind Neelakantan, Pranav Shyam, Girish Sastry, Amanda Askell, et al. Language models are few-shot learners. Advances in neural information processing systems, 33:1877–1901, 2020.
  • Chen et al. (2022a) Sitan Chen, Aravind Gollakota, Adam R. Klivans, and Raghu Meka. Hardness of noise-free learning for two-hidden-layer neural networks. In Neural Information Processing Systems (NeurIPS), 2022a. URL http://papers.nips.cc/paper_files/paper/2022/hash/45a7ca247462d9e465ee88c8a302ca70-Abstract-Conference.html.
  • Chen et al. (2022b) Xi Chen, Xiao Wang, Soravit Changpinyo, AJ Piergiovanni, Piotr Padlewski, Daniel Salz, Sebastian Goodman, Adam Grycner, Basil Mustafa, Lucas Beyer, et al. Pali: A jointly-scaled multilingual language-image model. arXiv preprint arXiv:2209.06794, 2022b.
  • Choromanski et al. (2020) Krzysztof Choromanski, Valerii Likhosherstov, David Dohan, Xingyou Song, Andreea Gane, Tamas Sarlos, Peter Hawkins, Jared Davis, Afroz Mohiuddin, Lukasz Kaiser, et al. Rethinking attention with performers. arXiv preprint arXiv:2009.14794, 2020.
  • Chowdhery et al. (2022) Aakanksha Chowdhery, Sharan Narang, Jacob Devlin, Maarten Bosma, Gaurav Mishra, Adam Roberts, Paul Barham, Hyung Won Chung, Charles Sutton, Sebastian Gehrmann, et al. Palm: Scaling language modeling with pathways. arXiv preprint arXiv:2204.02311, 2022.
  • Csordás et al. (2023) Róbert Csordás, Kazuki Irie, and Jürgen Schmidhuber. Approximating two-layer feedforward networks for efficient transformers. arXiv preprint arXiv:2310.10837, 2023.
  • Dong et al. (2023) Harry Dong, Beidi Chen, and Yuejie Chi. Towards structured sparsity in transformers for efficient inference. In Workshop on Efficient Systems for Foundation Models @ ICML 2023, 2023.
  • Dosovitskiy et al. (2020) Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, et al. An image is worth 16x16 words: Transformers for image recognition at scale. arXiv preprint arXiv:2010.11929, 2020.
  • Elhage et al. (2022) Nelson Elhage, Tristan Hume, Catherine Olsson, Neel Nanda, Tom Henighan, Scott Johnston, Sheer ElShowk, Nicholas Joseph, Nova DasSarma, Ben Mann, Danny Hernandez, Amanda Askell, Kamal Ndousse, Andy Jones, Dawn Drain, Anna Chen, Yuntao Bai, Deep Ganguli, Liane Lovitt, Zac Hatfield-Dodds, Jackson Kernion, Tom Conerly, Shauna Kravec, Stanislav Fort, Saurav Kadavath, Josh Jacobson, Eli Tran-Johnson, Jared Kaplan, Jack Clark, Tom Brown, Sam McCandlish, Dario Amodei, and Christopher Olah. Softmax linear units. Transformer Circuits Thread, 2022. https://transformer-circuits.pub/2022/solu/index.html.
  • Fedus et al. (2022) William Fedus, Barret Zoph, and Noam Shazeer. Switch transformers: Scaling to trillion parameter models with simple and efficient sparsity. The Journal of Machine Learning Research, 23(1):5232–5270, 2022.
  • Feldman et al. (2011) Vitaly Feldman, Homin K. Lee, and Rocco A. Servedio. Lower bounds and hardness amplification for learning shallow monotone formulas. In Conference on Learning Theory (COLT), volume 19 of JMLR Proceedings, pages 273–292. JMLR.org, 2011. URL http://proceedings.mlr.press/v19/feldman11a/feldman11a.pdf.
  • Frankle and Carbin (2018) Jonathan Frankle and Michael Carbin. The lottery ticket hypothesis: Finding sparse, trainable neural networks. arXiv preprint arXiv:1803.03635, 2018.
  • Frantar and Alistarh (2023) Elias Frantar and Dan Alistarh. Sparsegpt: Massive language models can be accurately pruned in one-shot. In International Conference on Machine Learning, pages 10323–10337. PMLR, 2023.
  • Geva et al. (2020) Mor Geva, Roei Schuster, Jonathan Berant, and Omer Levy. Transformer feed-forward layers are key-value memories. arXiv preprint arXiv:2012.14913, 2020.
  • Gholami et al. (2022) Amir Gholami, Sehoon Kim, Zhen Dong, Zhewei Yao, Michael W Mahoney, and Kurt Keutzer. A survey of quantization methods for efficient neural network inference. In Low-Power Computer Vision, pages 291–326. Chapman and Hall/CRC, 2022.
  • Goel et al. (2020) Surbhi Goel, Aravind Gollakota, and Adam Klivans. Statistical-query lower bounds via functional gradients. Advances in Neural Information Processing Systems, 33:2147–2158, 2020.
  • Grimaldi et al. (2023) Matteo Grimaldi, Darshan C Ganji, Ivan Lazarevich, and Sudhakar Sah. Accelerating deep neural networks via semi-structured activation sparsity. In Proceedings of the IEEE/CVF International Conference on Computer Vision, pages 1179–1188, 2023.
  • Gu and Dao (2023) Albert Gu and Tri Dao. Mamba: Linear-time sequence modeling with selective state spaces. arXiv preprint arXiv:2312.00752, 2023.
  • Harutyunyan et al. (2023) Hrayr Harutyunyan, Ankit Singh Rawat, Aditya Krishna Menon, Seungyeon Kim, and Sanjiv Kumar. Supervision complexity and its role in knowledge distillation. arXiv preprint arXiv:2301.12245, 2023.
  • Hinton et al. (2015) Geoffrey Hinton, Oriol Vinyals, and Jeff Dean. Distilling the knowledge in a neural network. arXiv preprint arXiv:1503.02531, 2015.
  • Kane (2014) Daniel M. Kane. The average sensitivity of an intersection of half spaces. In Symposium on Theory of Computing (STOC), pages 437–440. ACM, 2014. 10.1145/2591796.2591798. URL https://doi.org/10.1145/2591796.2591798.
  • Kearns (1998) Michael J. Kearns. Efficient noise-tolerant learning from statistical queries. J. ACM, 45(6):983–1006, 1998. 10.1145/293347.293351. URL https://doi.org/10.1145/293347.293351.
  • Klivans and Sherstov (2006) Adam R. Klivans and Alexander A. Sherstov. Cryptographic hardness for learning intersections of halfspaces. In 47th Annual IEEE Symposium on Foundations of Computer Science (FOCS 2006), 21-24 October 2006, Berkeley, California, USA, Proceedings, pages 553–562. IEEE Computer Society, 2006. 10.1109/FOCS.2006.24. URL https://doi.org/10.1109/FOCS.2006.24.
  • Klivans et al. (2004) Adam R. Klivans, Ryan O’Donnell, and Rocco A. Servedio. Learning intersections and thresholds of halfspaces. J. Comput. Syst. Sci., 68(4):808–840, 2004. 10.1016/J.JCSS.2003.11.002. URL https://doi.org/10.1016/j.jcss.2003.11.002.
  • Lample et al. (2019) Guillaume Lample, Alexandre Sablayrolles, Marc’Aurelio Ranzato, Ludovic Denoyer, and Hervé Jégou. Large memory layers with product keys. Advances in Neural Information Processing Systems, 32, 2019.
  • Li et al. (2023) Zonglin Li, Chong You, Srinadh Bhojanapalli, Daliang Li, Ankit Singh Rawat, Sashank J. Reddi, Ke Ye, Felix Chern, Felix X. Yu, Ruiqi Guo, and Sanjiv Kumar. The lazy neuron phenomenon: On emergence of activation sparsity in transformers. In The Eleventh International Conference on Learning Representations, ICLR 2023, Kigali, Rwanda, May 1-5, 2023. OpenReview.net, 2023. URL https://openreview.net/pdf?id=TJ2nxciYCk-.
  • Liu et al. (2023) Zichang Liu, Jue Wang, Tri Dao, Tianyi Zhou, Binhang Yuan, Zhao Song, Anshumali Shrivastava, Ce Zhang, Yuandong Tian, Christopher Re, and Beidi Chen. Deja vu: Contextual sparsity for efficient LLMs at inference time. In International Conference on Machine Learning (ICML), volume 202 of Proceedings of Machine Learning Research, pages 22137–22176. PMLR, 23–29 Jul 2023. URL https://proceedings.mlr.press/v202/liu23am.html.
  • Mirzadeh et al. (2023) Iman Mirzadeh, Keivan Alizadeh, Sachin Mehta, Carlo C Del Mundo, Oncel Tuzel, Golnoosh Samei, Mohammad Rastegari, and Mehrdad Farajtabar. Relu strikes back: Exploiting activation sparsity in large language models. arXiv preprint arXiv:2310.04564, 2023.
  • Mossel et al. (2003) Elchanan Mossel, Ryan O’Donnell, and Rocco A. Servedio. Learning juntas. In Symposium on Theory of Computing (STOC), pages 206–212. ACM, 2003. 10.1145/780542.780574. URL https://doi.org/10.1145/780542.780574.
  • O’Donnell (2014) Ryan O’Donnell. Analysis of boolean functions. Cambridge University Press, 2014.
  • Peng et al. (2023) Ze Peng, Lei Qi, Yinghuan Shi, and Yang Gao. Theoretical explanation of activation sparsity through flat minima and adversarial robustness. arXiv preprint arXiv:2309.03004, 2023.
  • Shalev-Shwartz and Ben-David (2014) Shai Shalev-Shwartz and Shai Ben-David. Understanding machine learning: From theory to algorithms. Cambridge university press, 2014.
  • Shazeer et al. (2017) Noam Shazeer, Azalia Mirhoseini, Krzysztof Maziarz, Andy Davis, Quoc Le, Geoffrey Hinton, and Jeff Dean. Outrageously large neural networks: The sparsely-gated mixture-of-experts layer. arXiv preprint arXiv:1701.06538, 2017.
  • Shen et al. (2023) Kai Shen, Junliang Guo, Xu Tan, Siliang Tang, Rui Wang, and Jiang Bian. A study on relu and softmax in transformer. arXiv preprint arXiv:2302.06461, 2023.
  • Sukhbaatar et al. (2019) Sainbayar Sukhbaatar, Edouard Grave, Guillaume Lample, Herve Jegou, and Armand Joulin. Augmenting self-attention with persistent memory. arXiv preprint arXiv:1907.01470, 2019.
  • Valiant (1984) Leslie G Valiant. A theory of the learnable. Communications of the ACM, 27(11):1134–1142, 1984.
  • Vaswani et al. (2017) Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Łukasz Kaiser, and Illia Polosukhin. Attention is all you need. Advances in neural information processing systems, 30, 2017.
  • Vershynin (2018) Roman Vershynin. High-dimensional probability: An introduction with applications in data science, volume 47. Cambridge university press, 2018.
  • Wang et al. (2020) Sinong Wang, Belinda Z Li, Madian Khabsa, Han Fang, and Hao Ma. Linformer: Self-attention with linear complexity. arXiv preprint arXiv:2006.04768, 2020.
  • Wei et al. (2019) Colin Wei, Jason Lee, Qiang Liu, and Tengyu Ma. On the margin theory of feedforward neural networks, 2019. URL https://openreview.net/forum?id=HJGtFoC5Fm.
  • Zaheer et al. (2020) Manzil Zaheer, Guru Guruganesh, Kumar Avinava Dubey, Joshua Ainslie, Chris Alberti, Santiago Ontanon, Philip Pham, Anirudh Ravula, Qifan Wang, Li Yang, et al. Big bird: Transformers for longer sequences. Advances in neural information processing systems, 33:17283–17297, 2020.

Appendix A Example of a Sparsely Activated Network without Weight Sparsity

There are interesting functions (beyond juntas/parities) that are sparsely activated but do not have weight sparsity. For example, suppose $\log_{2}s<n$. Set $b=\log_{2}s$ and $q=n-b$, and consider $F:\{-1,1\}^{b}\times\{-1,1\}^{q}\rightarrow\mathbb{R}$ of the form $\sum_{\alpha\in\{-1,1\}^{b}}\sigma(\langle w_{\alpha},y\rangle+\Gamma\cdot(\langle x,\alpha\rangle-b))$, where the input is $(x,y)$. When $\Gamma=\sqrt{q}$, this network is $1$-sparsely activated for all inputs, and when $\Gamma=\Theta(\sqrt{\log s})$, the function is $1$-sparse with probability $1-1/\mathsf{poly}(s)$ under the uniform distribution on $\{-1,1\}^{b+q}$. Remark 3.5 shows that our results continue to hold in such a setting. Intuitively, such functions are similar to Indexing: they return $\sigma(\langle w_{x},y\rangle)$ for all (or most) of the input space, where $w_{x}$ can depend arbitrarily on the $x$ part of the input.
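To make the construction concrete, here is a minimal numpy sketch (illustrative, not from the paper). It instantiates $F$ with $\Gamma=\sqrt{q}$ and unit-norm rows $w_{\alpha}$ — the unit-norm choice is an assumption made only for this illustration, ensuring $|\langle w_{\alpha},y\rangle|\leq\sqrt{q}<2\Gamma$ so that units with $\alpha\neq x$ stay off — and checks empirically that every input activates at most one hidden unit.

```python
import itertools
import numpy as np

# Minimal sketch of the Appendix A construction (illustrative; unit-norm w_alpha is an assumption).
rng = np.random.default_rng(0)
b, q = 3, 16                       # s = 2^b hidden units, n = b + q input bits
s = 2 ** b
Gamma = np.sqrt(q)

alphas = np.array(list(itertools.product([-1, 1], repeat=b)))   # all s sign patterns alpha
W = rng.normal(size=(s, q))
W /= np.linalg.norm(W, axis=1, keepdims=True)                   # unit-norm rows w_alpha

def hidden_activations(x, y):
    """ReLU activations sigma(<w_alpha, y> + Gamma * (<x, alpha> - b)) for every alpha."""
    pre = W @ y + Gamma * (alphas @ x - b)
    return np.maximum(pre, 0.0)

# For alpha == x the Gamma-term vanishes; for alpha != x it is <= -2*sqrt(q),
# which dominates |<w_alpha, y>| <= sqrt(q), so at most one unit can be active.
max_active = 0
for _ in range(10_000):
    x = rng.choice([-1, 1], size=b)
    y = rng.choice([-1, 1], size=q)
    max_active = max(max_active, int(np.count_nonzero(hidden_activations(x, y))))
print("max #active hidden units over sampled inputs:", max_active)   # expected: <= 1
```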

Appendix B Proof of \Creflem:as-to-ns-generalk

Proof of \Cref{lem:as-to-ns-generalk}. Given $\rho\in[-1,1]$, let $r=\lfloor 2/(1-\rho)\rfloor$. We describe an alternate way to sample $(x,N_{\rho}(x))$. First sample $z\in\{\pm 1\}^{n}$ uniformly at random and partition the $n$ coordinates of $z$ into $r$ buckets $\{A_{e}\subseteq[n]\}_{e=1}^{r}$ at random (each coordinate is placed in exactly one bucket, uniformly and independently). For each $A_{e}$, sample $v_{e}\in\{\pm 1\}$ uniformly at random. Multiply the coordinates of $A_{e}$ by $v_{e}$ and concatenate all the buckets to get $x$. Choose one bucket $b$ at random and flip $v_{b}$ to get $v^{\oplus b}$. Multiply the coordinates of $A_{e}$ by $v^{\oplus b}_{e}$ and concatenate the buckets to get $y$. Observe that $(x,y)$ is distributed exactly the same as $(x,N_{\rho}(x))$. Now, given $h(x)=\sum_{j=1}^{s}u_{j}\sigma(\langle w_{j},x\rangle-b_{j})$, define

\[
H_{z}(v)=\sum_{j=1}^{s}u_{j}\sigma(\langle w_{j}^{\prime},v\rangle-b_{j}),
\]

where $w_{je}^{\prime}=\sum_{l\in A_{e}}w_{jl}z_{l}$. Clearly $h(x)=H_{z}(v)$. Hence,

\[
\begin{aligned}
\mathsf{NS}_{\rho}(h) &= \mathbb{E}\big[(h(x)-h(y))^{2}\big] \\
&= \frac{1}{r}\,\mathbb{E}_{z,\{A_{e}\}}\Big(\sum_{b=1}^{r}\mathbb{E}_{v}\,(H_{z}(v)-H_{z}(v^{\oplus b}))^{2}\Big) \\
&= \frac{1}{r}\,\mathbb{E}_{z,\{A_{e}\}}[\mathsf{AS}(H_{z})]. \qquad (17)
\end{aligned}
\]
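The bucket-based resampling above is easy to simulate. The following sketch (illustrative, not the authors' code; the uniform bucket assignment and the single uniformly chosen flipped bucket follow the description above) implements it and checks that the empirical per-coordinate correlation $\mathbb{E}[x_{i}y_{i}]$ equals $1-2/r$, which matches $\rho$ up to the floor in $r=\lfloor 2/(1-\rho)\rfloor$.

```python
import numpy as np

# Illustrative sketch of the bucket-based resampling used in the proof.
rng = np.random.default_rng(1)

def sample_pair(n, r):
    """Sample (x, y): uniform z, random partition into r buckets, bucket signs v,
    and one uniformly chosen bucket whose sign is flipped to produce y."""
    z = rng.choice([-1, 1], size=n)
    buckets = rng.integers(r, size=n)          # coordinate i -> bucket A_{e(i)}
    v = rng.choice([-1, 1], size=r)            # sign v_e for each bucket
    x = z * v[buckets]
    v_flip = v.copy()
    v_flip[rng.integers(r)] *= -1              # flip the sign of one random bucket b
    y = z * v_flip[buckets]
    return x, y

n, rho, trials = 50, 0.9, 20_000
r = int(np.floor(2.0 / (1.0 - rho)))           # r = floor(2 / (1 - rho))
corr = np.mean([np.mean(x * y) for x, y in (sample_pair(n, r) for _ in range(trials))])
print(f"empirical E[x_i y_i] = {corr:.3f},  1 - 2/r = {1 - 2 / r:.3f},  rho = {rho}")
```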

From \Cref{lem:avg-sens-generalk},

\[
\mathsf{AS}(H_{z})\leq O\left(k^{4}W^{\prime 2}\sqrt{r}\log(rs)+k^{3}B^{2}\sqrt{\log s}\right),
\qquad\text{where } W^{\prime}:=\max_{j\in[s]}|u_{j}|\cdot\|w^{\prime}_{j}\|.
\]

To bound $W^{\prime}$ we need to bound

\[
\max_{j\in[s]}\|w_{j}^{\prime}\|^{2}_{2}=\max_{j\in[s]}\sum_{e=1}^{r}\Big(\sum_{i\in A_{e}}w_{ji}z_{i}\Big)^{2}.
\]

For any $j\in[s]$ and $e\in[r]$, we have from measure concentration (Hoeffding's inequality for Rademacher sums)

\[
\Pr_{z}\left[\Big|\sum_{i\in A_{e}}w_{ji}z_{i}\Big|>t\right]\leq 2\exp\left(-\frac{t^{2}}{4\sum_{i\in A_{e}}w_{ji}^{2}}\right)
\]
\[
\implies \Pr_{z}\left[\Big|\sum_{i\in A_{e}}w_{ji}z_{i}\Big|>2\sqrt{2\log(nsr)\textstyle\sum_{i\in A_{e}}w_{ji}^{2}}\right]\leq\frac{1}{(nsr)^{2}}.
\]

Now we use that $\sum_{e=1}^{r}\sum_{i\in A_{e}}w_{ji}^{2}=\|w_{j}\|_{2}^{2}$. Taking a union bound over the $r$ buckets, and then over $j\in[s]$, gives

\[
\Pr_{z}\left[\sum_{e=1}^{r}\Big(\sum_{i\in A_{e}}w_{ji}z_{i}\Big)^{2}>8\log(nsr)\|w_{j}\|_{2}^{2}\right]\leq\frac{1}{n^{2}s^{2}r}
\]
\[
\implies \Pr_{z}\left[\forall j\in[s],\;\|w^{\prime}_{j}\|_{2}^{2}\leq 8\log(nsr)\|w_{j}\|_{2}^{2}\right]\geq 1-\frac{1}{n^{2}sr}.
\]
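This concentration step can also be checked numerically. The sketch below (an informal check with arbitrarily chosen illustrative parameters) draws random signs $z$ and random bucket partitions and compares $\max_{j}\|w^{\prime}_{j}\|_{2}^{2}$ against $\log(nsr)\cdot\max_{j}\|w_{j}\|_{2}^{2}$; the bound above predicts a ratio of at most $8$ with high probability.

```python
import numpy as np

# Numerical check of the concentration bound on ||w'_j||^2 (illustrative parameters).
rng = np.random.default_rng(2)
n, s, r = 200, 32, 10
W = rng.normal(size=(s, n))                      # rows w_j

worst_ratio = 0.0
for _ in range(2_000):
    z = rng.choice([-1, 1], size=n)
    buckets = rng.integers(r, size=n)            # random partition into buckets A_1, ..., A_r
    # w'_{je} = sum_{l in A_e} w_{jl} z_l, for every neuron j and bucket e
    Wp = np.stack([(W * z)[:, buckets == e].sum(axis=1) for e in range(r)], axis=1)
    ratio = (np.linalg.norm(Wp, axis=1) ** 2).max() / (
        np.log(n * s * r) * (np.linalg.norm(W, axis=1) ** 2).max())
    worst_ratio = max(worst_ratio, ratio)
# The bound predicts this stays below 8 with high probability.
print("worst observed  max_j ||w'_j||^2 / (log(nsr) * max_j ||w_j||^2):", worst_ratio)
```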

Combining this with the fact that $\|w^{\prime}_{j}\|$ is always at most $W\sqrt{n}$, we get that

\[
\mathbb{E}_{z}\left[\max_{j\in[s]}\|w^{\prime}_{j}\|_{2}^{2}\right]\leq O(\log(nsr))\cdot\max_{j\in[s]}\|w_{j}\|_{2}^{2}.
\]

Combining the above with (17), we get

\[
\begin{aligned}
\mathsf{NS}_{\rho}(h) &= \frac{1}{r}\,\mathbb{E}_{z,\{A_{e}\}}[\mathsf{AS}(H_{z})]
\leq \frac{O\big(W^{2}k^{4}\log^{2}(nrs)+k^{3}B^{2}\sqrt{\log s}\big)}{\sqrt{r}} \\
&\leq \sqrt{1-\rho}\cdot O\big(k^{4}W^{2}\log^{2}(ns/(1-\rho))+k^{3}B^{2}\sqrt{\log s}\big),
\end{aligned}
\]

where the last step uses $r=\lfloor 2/(1-\rho)\rfloor$, so that $1/\sqrt{r}\leq\sqrt{1-\rho}$ and $\log(nrs)=O(\log(ns/(1-\rho)))$.

The claim now follows.