Pranjal Awasthi \Email[email protected]
\addrGoogle Research
and \NameNishanth Dikkala \Email[email protected]
\addrGoogle Research
and \NamePritish Kamath \Email[email protected]
\addrGoogle Research
and \NameRaghu Meka \Email[email protected]
\addrUniversity of California, Los Angeles
Learning Neural Networks with Sparse Activations
Abstract
A core component present in many successful neural network architectures is an MLP block of two fully connected layers with a non-linear activation in between. An intriguing phenomenon observed empirically, including in transformer architectures, is that, after training, the activations in the hidden layer of this MLP block tend to be extremely sparse on any given input. Unlike traditional forms of sparsity, where there are neurons/weights that can be deleted from the network, this form of dynamic activation sparsity appears harder to exploit to obtain more efficient networks.
Motivated by this, we initiate a formal study of PAC learnability of MLP layers that exhibit activation sparsity. We present a variety of results showing that such classes of functions do lead to provable computational and statistical advantages over their non-sparse counterparts. Our hope is that a better theoretical understanding of sparsely activated networks will lead to methods that can exploit activation sparsity in practice.
keywords:
Multilayer Perceptrons, PAC Learning, Activation Sparsity, Rademacher Complexity

1 Introduction
In recent years, transformer-based deep neural networks (Vaswani et al., 2017) and the subsequent development of large language models have marked a paradigm shift in the fields of natural language processing and computer vision (Brown et al., 2020; Chowdhery et al., 2022; Chen et al., 2022b; Dosovitskiy et al., 2020). These models have significantly improved performance across various tasks, setting new benchmarks and enabling previously unattainable breakthroughs. However, the computational cost of training and deploying these models, especially the largest variants, presents a significant challenge. A notable portion of these models’ computational and parameter overhead is attributed to the Multi-Layer Perceptron (MLP) layers. These layers are integral to the transformer architecture, playing a crucial role in its ability to solve many different tasks.
Despite their efficacy, the resource-intensive nature of these models has spurred a wave of research focused on enhancing their efficiency (Banner et al., 2019; Frankle and Carbin, 2018; Gholami et al., 2022; Hinton et al., 2015; Anil et al., 2018; Harutyunyan et al., 2023). Among the various strategies explored for improving the inference efficiency of large transformers, sparsifying the network is a promising approach.
A motivation for exploiting sparsity is rooted in an intriguing empirical observation made in recent works (Li et al., 2023) regarding the behavior of MLP layers within large transformer models. Post-training, these layers tend to exhibit a high degree of sparsity in their activations; often each input activates as few as 3% of the neurons in the MLP layers, suggesting a natural emergence of sparsity in activations. This leads to these MLP layers behaving like key-value lookups (Geva et al., 2020). The extreme sparsity (only 3% of neurons active per input) suggests that there might be significant room to sparsify the MLP layers, leading to both training and inference efficiency. In addition, such sparsity also helps with interpretability of transformers by disentangling neurons corresponding to distinct concepts (Elhage et al., 2022). Moreover, through extensive ablation studies, Li et al. (2023) observe that this phenomenon is highly prevalent: it occurs in convolutional networks (CNNs), as well as in vanilla fully connected feedforward networks.
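To make the quantity at issue concrete, the following is a minimal sketch (our own NumPy illustration; the weights, bias shift, and dimensions are synthetic assumptions, not trained transformer parameters) of how activation sparsity of an MLP block is measured: the fraction of hidden units with a positive pre-activation, per input.
\begin{verbatim}
import numpy as np

rng = np.random.default_rng(0)
n, m, batch = 512, 2048, 64           # input dim, hidden units, number of inputs
W = rng.normal(size=(m, n)) / np.sqrt(n)
b = rng.normal(size=m) - 2.0          # negative bias shift to induce sparsity (illustrative)

X = rng.choice([-1.0, 1.0], size=(batch, n))
pre_act = X @ W.T + b                 # (batch, m) pre-activations of the hidden layer
active = pre_act > 0                  # which hidden units are active on each input
per_input_sparsity = active.mean(axis=1)

print(f"mean fraction of active hidden units: {per_input_sparsity.mean():.3f}")
print(f"max over inputs:                      {per_input_sparsity.max():.3f}")
\end{verbatim}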
Despite the potential benefits, effectively harnessing dynamic sparsity has proven challenging. Although there have been many recent efforts (Li et al., 2023; Grimaldi et al., 2023; Liu et al., 2023; Dong et al., 2023; Csordás et al., 2023; Mirzadeh et al., 2023), they have led to limited success. None of the approaches achieve speedups (either in training or in inference) anywhere close to the potential factor of 33x that is suggested by 3% sparsity. Moreover, when sparsity is explicitly enforced via methods such as keeping only the top-$k$ activations, the quality of the model degrades in some cases.
A key reason why activation sparsity is hard to exploit is that this form of sparsity is dynamic in nature and input-dependent (i.e., not a fixed pattern). While each input example activates a small number of neurons, the overall sparsity pattern cannot be localized to a small subset of the model weights. For instance, the dynamic nature precludes the use of typical weight quantization or pruning-based methods to exploit sparsity empirically. On the other hand, having a non-localized sparsity pattern is crucial in ensuring that the model has rich expressiveness.
The above observations suggest that post-training, large transformer networks belong to an intriguing function class that is highly expressive yet exhibits high sparsity. Given the challenges in exploiting this behavior in practical settings, in this work, we initiate a theoretical study of the statistical and computational properties of such functions in the probably approximately correct (PAC) learning framework (Valiant, 1984).
We introduce the class of sparsely activated MLPs. We focus on the case of depth-$2$ MLPs with $n$ input units and $m$ hidden units with the standard ReLU activation. We define the class as the class of depth-$2$ ReLU networks in $n$ dimensions with the promise that on each input in the support of the data distribution, at most $k$ of the hidden units are active:
Definition 1.1 (Sparsely Activated Networks).
Let $\sigma$ denote the ReLU activation, namely $\sigma(t) = \max\{t, 0\}$. The class of $k$-sparsely activated networks consists of hypotheses of the form $f(x) = \sum_{i=1}^{m} a_i\, \sigma(\langle w_i, x \rangle + b_i)$ with the property that for all $x$ in the support of the distribution, it holds that $|\{i : \langle w_i, x \rangle + b_i > 0\}| \le k$.
Note that this sparsity differs from dead sparsity, where some neurons are never active on any of the inputs, and consequently, can be deleted from the network without impacting its functionality. The form of dynamic sparsity we study can be crucial for the networks to be more expressive. We provide a couple of examples of useful functions represented using sparsely activated networks here:
• Junta functions: The class of functions on $n$ variables that depend on only a $k$-sized subset ($k \ll n$) of the variables is known as $k$-junta functions. Sparse parities are a canonical example of junta functions. We show in Theorem 4.2 that any $k$-junta can be represented by a $1$-sparsely activated network with $2^k$ hidden units.
• Indexing function: Consider the indexing function on inputs $x = (y, z) \in \{-1,1\}^{\log m} \times \{-1,1\}^m$ that outputs the $i(y)$-th bit of $z$ (mapped to $\{0,1\}$), where $i(y)$ is the integer represented by the first $\log m$ bits of $x$ in binary, and $z$ is the vector of the remaining $m$ bits. This can be represented as a $1$-sparsely activated network of size $m$: the weight vector of hidden unit $j$ has the $\pm 1$ binary pattern of $j$ in its first $\log m$ coordinates and a $1$ in the $j$-th coordinate among the last $m$ coordinates. On input $x$, only the neuron corresponding to $i(y)$ is activated, and the output is precisely the selected bit.
In both the examples presented above, removing any of the neurons will change the functionality of the network. However, each weight vector is quite sparse. In Appendix A, we present an example of a sparsely activated network where even the weight vectors are not sparse. Hence, in general, it is not clear if sparsely activated networks can be represented with fewer neurons or sparse weight vectors.
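The indexing example admits a short executable check. The following is a minimal sketch (our own NumPy code); the specific bias $-(\log m - 1)$ and output weight $1/2$ are one concrete choice consistent with the construction described above, not necessarily the exact constants intended in the text.
\begin{verbatim}
import itertools
import numpy as np

def relu(t):
    return np.maximum(t, 0.0)

log_m = 3
m = 2 ** log_m

# Hidden unit j: weights (s_j, e_j) and bias -(log_m - 1), where s_j is the
# +/-1 binary pattern of j and e_j is the j-th standard basis vector in R^m.
S = np.array([[1.0 if (j >> t) & 1 else -1.0 for t in range(log_m)]
              for j in range(m)])
W = np.concatenate([S, np.eye(m)], axis=1)
b = -(log_m - 1.0) * np.ones(m)
a = 0.5 * np.ones(m)                                  # output weights

def network(x):
    h = relu(W @ x + b)
    assert np.sum(h > 0) <= 1                         # 1-sparse activation
    return a @ h

# Exhaustive check against direct indexing (output is the selected bit in {0,1}).
for y_bits in itertools.product([0, 1], repeat=log_m):
    j = sum(bit << t for t, bit in enumerate(y_bits))
    for z in itertools.product([-1.0, 1.0], repeat=m):
        x = np.array([1.0 if bit else -1.0 for bit in y_bits] + list(z))
        assert network(x) == (z[j] + 1) / 2
print("indexing realized by a 1-sparsely activated network")
\end{verbatim}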
In order to provide learning guarantees, we have to assume an upper bound on the scale of the $a_i$'s, $w_i$'s and $b_i$'s. We will use the following natural scaling throughout the paper:
Definition 1.2.
Let consisting of given as , satisfying and .
We then consider the problem of learning sparsely activated networks efficiently. We consider the domain to be the Boolean hypercube, as a natural first step and as a domain where sparsely activated networks can compute non-trivial functions. The Boolean hypercube provides a setting where the function can be sparse everywhere in the domain while maintaining expressiveness; this appears harder in the continuous setting. For instance, if the inputs are Gaussian over $\mathbb{R}^n$, one likely needs the biases in the ReLU units to be very large to enforce $k$-sparsity. This suggests that, in the continuous domain, more non-standard distributions are likely necessary to obtain a rich class of functions that are sparse everywhere in the domain. Hence, for theoretical simplicity, we focus on functions on the Boolean hypercube.
Even with the sparsity assumption, the class is likely hard to learn in polynomial time (or even quasi-polynomial time) under an arbitrary distribution on the hypercube. In particular, we show that parities on the hypercube can be computed by $1$-sparsely activated networks with coefficient vectors of small norm. Thus, learning the class requires super-polynomially many queries in the powerful Statistical Query (SQ) model (see \Cref{sec:lb-uniform} for details). We also show cryptographic hardness results for learning under generic distributions on the hypercube.
Theorem 1.3 (Informal; see \Crefsec:lb-uniform).
Any SQ algorithm for learning $1$-sparsely activated networks under arbitrary distributions over the hypercube either requires very small tolerance or super-polynomially many queries.
Assuming the hardness of the learning-with-rounding problem with polynomial modulus, there is no efficient algorithm to PAC learn $1$-sparsely activated networks.
Learning under uniform distribution.
Given the above hardness results, it is natural to consider distributional assumptions, as is often done for related classes in learning theory (e.g., Klivans et al. (2004); Kane (2014)). Our main result is that when the input distribution is uniform over the $n$-dimensional hypercube $\{-1,1\}^n$, the class of $k$-sparsely activated networks can be learned in quasi-polynomial time for constant sparsity $k$:
Theorem 1.4 (Informal; see \Crefthm:generalk-uniform-ub).
There exists a PAC learning algorithm for $k$-sparsely activated networks with respect to the uniform distribution over $\{-1,1\}^n$ whose sample complexity and run-time are quasi-polynomial in $n$ for constant sparsity $k$ (suppressing the dependence on the error parameter).
As our learning algorithm works by performing linear regression over the low-degree monomial basis (a.k.a. the low-degree algorithm), the guarantees hold even in the agnostic or non-realizable setting by standard arguments (e.g., Klivans et al. (2004)). For simplicity, we focus on the realizable setting, as the algorithm and analysis do not change for the agnostic case.
For sparsity $k = 1$, the above run-time is quasi-polynomial in $n$. As we showed above, $1$-sparsely activated networks with $m$ hidden units can simulate juntas on $\log m$ variables. Thus, a quasi-polynomial run-time is the best we can hope for under a widely believed conjecture on the hardness of learning juntas.
The guarantee above is in stark contrast to what is achievable for general one-hidden-layer ReLU networks of size $m$ under the uniform distribution over the hypercube. One-hidden-layer size-$m$ networks can simulate parities on order-$m$ many variables, and thus cannot be learned, even under the uniform distribution on the hypercube, by SQ algorithms making fewer than exponentially many (in $m$) queries. Further, even for non-SQ algorithms, as shown in Chen et al. (2022a), quasi-polynomial run-time with respect to the uniform distribution on the hypercube is impossible under widely studied cryptographic assumptions.
The proof of \Cref{thm:k-uniform-ub} is via Fourier analysis and the low-degree algorithm. The main ingredient is to show a bound on the average sensitivity of $k$-sparsely activated networks. We then use this to bound their noise sensitivity. The latter implies the existence of a low-degree approximation by a result of Klivans et al. (2004), which is enough to obtain the theorem. See \Cref{sec:ub-uniform} for details.
Learning under general distributions.
We also show that sparsely activated networks can be learned under general distributions with smaller sample complexity than would be required without the sparsity condition, in the regime where the sparsity $k$ is small compared to the number of hidden units $m$. In particular, we show the following.
Theorem 1.5 (Informal; see \Crefthm:general-dist-upper-bound).
There exists a PAC learning algorithm for $k$-sparsely activated networks over $\{-1,1\}^n$ whose sample complexity has only a sublinear dependence on the number of hidden units $m$ (suppressing the dependence on the error parameter).
By contrast, the class of size-$m$ networks without activation sparsity requires a sample complexity that grows linearly with $m$. To prove the above, we provide a bound on the Rademacher complexity of the class that has an improved dependence on $m$.
Taken together, our results demonstrate that leveraging dynamic activation sparsity is theoretically possible for both computational and statistical benefits. We hope that further theoretical study of the class of sparsely activated networks could pave the way for more efficient training and inference methods for deep architectures, including transformer-based models where these sparsely activated networks have been observed to arise in practice.
1.1 Related Work
Our work is motivated by recent empirical observations on the extreme sparsity observed in the MLP layers of trained transformer models (Li et al., 2023; Shen et al., 2023). The works of Li et al. (2023); Peng et al. (2023) propose theoretical explanations of why this phenomenon occurs. However, ours is the first work to formally study sparsely activated networks in the PAC learning setup and quantify their computational and statistical advantages. Motivated by the observation on sparsity, recent work has also studied the connections between the MLP layers and key-value memory lookups (Sukhbaatar et al., 2019; Lample et al., 2019; Geva et al., 2020).
There have also been recent works on designing networks with explicitly enforced sparsity structure. One such line of work concerns mixture-of-experts models (Shazeer et al., 2017; Fedus et al., 2022), where each input is independently routed to one or two MLP blocks among a set of experts. An alternate way to enforce sparsity is to introduce a top-$k$ operation after each MLP layer that zeros out most of the activations (Csordás et al., 2023; Li et al., 2023). In particular, Li et al. (2023) propose a top-$k$ transformer along these lines. However, due to the top-$k$ operation being relatively slow on accelerator hardware, this technique does not yield wall-clock speedup for either training or inference.
In another recent work, Liu et al. (2023) propose to train a small predictor network to predict the activated indices at each MLP layer. There has also been work exploring block-sparsity constraints and weight tying in the model weights themselves (Dong et al., 2023), as well as efforts to enforce static sparsity that is not input-dependent (Frantar and Alistarh, 2023). However, such methods have not been effective for language modeling via transformer models and have been much more successful in classification domains with a small number of output labels.
2 Preliminaries
We consider the problem of learning real-valued functions over the input space $\mathcal{X} = \{-1,1\}^n$ to small expected squared error; namely, for the underlying distribution $\mathcal{D}$ over $\mathcal{X} \times \mathbb{R}$, our goal is to minimize the population loss of a predictor $h$, given as $\mathcal{L}_{\mathcal{D}}(h) := \mathbb{E}_{(x,y) \sim \mathcal{D}}[(h(x) - y)^2]$. For any dataset $S = \{(x_i, y_i)\}_{i=1}^{N}$, we denote the empirical loss as $\mathcal{L}_{S}(h) := \frac{1}{N} \sum_{i=1}^{N} (h(x_i) - y_i)^2$.
For any hypothesis class $\mathcal{H}$, we say that $\mathcal{D}$ is $\mathcal{H}$-realizable if there exists $h^\star \in \mathcal{H}$ such that $y = h^\star(x)$ holds with probability $1$ for $(x, y) \sim \mathcal{D}$. Following the standard definition of probably approximately correct (PAC) learning (Valiant, 1984), we say that a learning algorithm $(\varepsilon, \delta)$-PAC learns $\mathcal{H}$ with sample complexity $N$ if, for all $\mathcal{H}$-realizable distributions $\mathcal{D}$ over $\mathcal{X} \times \mathbb{R}$ and for $S \sim \mathcal{D}^N$, it holds with probability at least $1 - \delta$ that the hypothesis $h_S$ returned by the algorithm satisfies $\mathcal{L}_{\mathcal{D}}(h_S) \le \varepsilon$. We say that a learning algorithm $(\varepsilon, \delta)$-PAC learns $\mathcal{H}$ under a distribution $\mathcal{D}_{\mathcal{X}}$ (over $\mathcal{X}$) if the learning guarantee holds for all $\mathcal{H}$-realizable $\mathcal{D}$ with the marginal over $\mathcal{X}$ being $\mathcal{D}_{\mathcal{X}}$. In particular, we use $\mathrm{Unif}(\{-1,1\}^n)$ to denote the uniform distribution over $\{-1,1\}^n$.
2.1 Fourier Analysis and the Low-Degree Algorithm
Any function $f : \{-1,1\}^n \to \mathbb{R}$ has a unique Fourier representation given as $f(x) = \sum_{S \subseteq [n]} \hat{f}(S)\, \chi_S(x)$, where $\chi_S(x) := \prod_{i \in S} x_i$. The degree of $f$, denoted $\deg(f)$, is the largest $d$ such that $\hat{f}(S) \neq 0$ for some $S$ with $|S| = d$. The norm of $f$ under the uniform distribution is defined as $\|f\|_2^2 := \mathbb{E}_{x \sim \mathrm{Unif}(\{-1,1\}^n)}[f(x)^2] = \sum_{S} \hat{f}(S)^2$ (O’Donnell, 2014).
We define the sensitivity of $f$ at $x$ as $s_f(x) := \sum_{i=1}^{n} \frac{(f(x) - f(x^{\oplus i}))^2}{4}$, where $x^{\oplus i}$ is $x$ with the $i$-th bit flipped; the scaling factor of $1/4$ means that for Boolean-valued $f$, the sensitivity can be interpreted as the number of coordinates whose flip changes the value of $f$ at $x$. The average sensitivity is defined as $\mathrm{AS}(f) := \mathbb{E}_{x \sim \mathrm{Unif}(\{-1,1\}^n)}[s_f(x)]$. For any $\delta \in (0, 1/2]$, let $N_\delta(x)$ denote the distribution obtained by flipping each coordinate of $x$ independently with probability $\delta$. The $\delta$-noise sensitivity of $f$ is $\mathrm{NS}_\delta(f) := \frac{1}{4}\, \mathbb{E}_{x \sim \mathrm{Unif}(\{-1,1\}^n),\, \tilde{x} \sim N_\delta(x)}[(f(x) - f(\tilde{x}))^2]$.
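As a sanity check on these definitions, the following minimal sketch (our own, specialized to the Boolean-valued case, where sensitivity counts pivotal coordinates and noise sensitivity is the probability that a $\delta$-noisy copy changes the value) computes both quantities for the 3-bit majority function; the average sensitivity comes out to exactly $1.5$.
\begin{verbatim}
import itertools
import random

n = 3
cube = list(itertools.product([-1, 1], repeat=n))

def majority3(x):                       # example Boolean function on 3 bits
    return 1 if sum(x) > 0 else -1

def sensitivity(f, x):
    # Boolean-valued case: number of coordinates whose flip changes f(x).
    count = 0
    for i in range(n):
        y = list(x); y[i] = -y[i]
        count += f(x) != f(tuple(y))
    return count

avg_sens = sum(sensitivity(majority3, x) for x in cube) / len(cube)

def noise_sensitivity(f, delta, trials=200_000):
    # Boolean-valued case: Pr[f(x) != f(y)] for y a delta-noisy copy of x.
    changed = 0
    for _ in range(trials):
        x = tuple(random.choice([-1, 1]) for _ in range(n))
        y = tuple(-xi if random.random() < delta else xi for xi in x)
        changed += f(x) != f(y)
    return changed / trials

print("average sensitivity of MAJ_3:", avg_sens)              # 1.5
print("NS_0.1(MAJ_3) approx:", round(noise_sensitivity(majority3, 0.1), 3))
\end{verbatim}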
A connection between noise sensitivity and Fourier concentration was first observed in Klivans et al. (2004). We state this connection below, along with other basic facts about Fourier coefficients.
Claim 1.
[See Klivans et al. (2004)] The following properties hold for all :
-
•
, and
-
•
, and hence .
We also need a bound on the average sensitivity of a single halfspace, which is known to be $O(\sqrt{n})$. We require a more fine-grained version from Kane (2014), which quantifies the dependence on the bias of the halfspace.
Lemma 2.1 (Kane (2014)).
Let be a halfspace: and . Then, .
Proof 2.2.
Without loss of generality, we can assume that the coefficients of are positive. This makes a monotone function which is non-decreasing in each coordinate. Now, for , and ,
where the second equality is due to the non-decreasing nature of and that takes values in . Therefore,
the claim now follows from Lemma 6 of Kane (2014).
Low-degree algorithm.
We recall the standard low-degree algorithm and its guarantees for learning hypothesis classes that exhibit low-degree Fourier concentration (see e.g., Klivans et al. (2004) for details). For any hypothesis class , let .
Lemma 2.3.
For a hypothesis class $\mathcal{H}$ such that, for every $h \in \mathcal{H}$, the Fourier mass of $h$ on monomials of degree larger than $d$ is at most $\varepsilon$, there exists a PAC learning algorithm for $\mathcal{H}$ with $n^{O(d)} \cdot \mathrm{poly}(1/\varepsilon)$ sample and time complexity.
The algorithm operates by performing polynomial regression, that is, linear regression in the basis of monomials of degree at most $d$. The algorithm achieves the desired error because the low-degree Fourier truncation of the target already has squared error at most $\varepsilon$, and hence there exists a good solution to the polynomial regression problem.
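The following is a minimal sketch of the low-degree algorithm just described (our own NumPy implementation; the toy degree-2 target, dimensions, and sample sizes are illustrative assumptions rather than the setting of the theorem): build the basis of all monomials of degree at most $d$ and solve the least-squares problem on uniform samples.
\begin{verbatim}
import itertools
import numpy as np

def monomial_features(X, d):
    # Columns chi_S(x) = prod_{i in S} x_i for all subsets S with |S| <= d.
    n = X.shape[1]
    cols = [np.ones(len(X))]                          # chi of the empty set
    for size in range(1, d + 1):
        for S in itertools.combinations(range(n), size):
            cols.append(np.prod(X[:, list(S)], axis=1))
    return np.stack(cols, axis=1)

def low_degree_learn(X_train, y_train, d):
    Phi = monomial_features(X_train, d)
    coef, *_ = np.linalg.lstsq(Phi, y_train, rcond=None)
    return lambda X: monomial_features(X, d) @ coef

rng = np.random.default_rng(1)
n, d, n_train, n_test = 10, 2, 3000, 1000
target = lambda X: X[:, 0] * X[:, 3] + 0.5 * X[:, 7]  # toy degree-2 target
X_tr = rng.choice([-1.0, 1.0], size=(n_train, n))
X_te = rng.choice([-1.0, 1.0], size=(n_test, n))
h = low_degree_learn(X_tr, target(X_tr), d)
print("test squared error:", np.mean((h(X_te) - target(X_te)) ** 2))
\end{verbatim}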
3 Learning over Uniform Distribution
In this section, we provide a learning algorithm for $k$-sparsely activated networks under the uniform distribution.
Theorem 3.1.
There exists a PAC learning algorithm for $k$-sparsely activated networks with respect to the uniform distribution over $\{-1,1\}^n$ whose sample complexity and run-time are quasi-polynomial in $n$ for constant sparsity $k$ (suppressing the dependence on the error parameter).
At a high level, we show that all hypotheses in the class exhibit low-degree Fourier concentration and hence can be learned over the uniform distribution using the low-degree algorithm (\Cref{lem:low-degree-alg}). To show Fourier concentration, we bound the noise sensitivity of sparsely activated networks by first showing a bound on the average sensitivity and then converting it into a bound on the noise sensitivity.
Lemma 3.2.
For all , it holds that .
Proof 3.3.
Consider given as . For any , let for and . Since and , it follows that and . For any , let be defined as . Since is -sparse, we have that and hence and . It is easy to see that for it holds that for all .
The average sensitivity of is given as
(U) | ||||
(V) |
We bound term (U) as,
(U) | |||
We bound term (V) as follows using the inequality ,
(V) | |||
For , we have that
Note that (by -sparsity), and hence for , we have that . From \Creflm:ashalfspace, we have that . Thus,
where we use concavity of for . For each with , we have by Hoeffding bound that for some sufficiently large and ,
Hence, in particular we have that
And for all , we also have that holds with probability . Thus, we can upper bound (V) as,
(V) | |||
Next, we can use the bound on average sensitivity to bound the noise sensitivity of functions in the class. To do so, we use an argument attributed to Peres for converting bounds on average sensitivity into bounds on noise sensitivity, allowing us to get better low-degree approximations.
Lemma 3.4.
For any ,
The proof of \Creflem:as-to-ns-generalk is provided in \appendixrefapx:as-to-ns.
[\theoremrefthm:generalk-uniform-ub] We combine \Creffact:fourier, \Creflem:low-degree-alg and \Creflem:as-to-ns-generalk. Fix an error parameter . Then, by \Creflem:as-to-ns-generalk, there is a constant , such that for
any , satisfies
Thus, we can choose a suitable , such that by \Creffact:fourier,
Finally, note that ; since at most neurons are active on any input, and each neuron can at most contribute . Thus, the theorem now follows from combining the above with \Creflem:low-degree-alg. The run-time and sample complexity will be where is as above.
Remark 3.5.
\Cref{thm:generalk-uniform-ub} can be extended to hold for the hypothesis class where $k$-sparsity need not hold for all inputs $x$, but only holds with high probability over the input distribution. This is by decomposing the average sensitivity into the terms (U), (V) and a third term handling the inputs for which the $k$-sparsity is violated.
4 Lower Bounds for Learning
Note that the previous section implies a quasi-polynomial time learning algorithm for the class of sparsely activated networks. We next show that a quasi-polynomial run-time is likely necessary for learning under the uniform distribution, and we show stronger lower bounds under arbitrary distributions.
Sparse Activations Can Simulate Juntas
We first show that our proposed learning algorithms for the case of the uniform distribution have near-optimal runtime under a widely believed conjecture on the hardness of learning juntas. Let the class of $r$-juntas denote the set of Boolean functions on $\{-1,1\}^n$ that depend on at most $r$ of the variables.
Conjecture 4.1 (Hardness of learning Juntas).
(see e.g. Mossel et al. (2003); Feldman et al. (2011)) There is no PAC learning algorithm for the class of $r$-juntas under the uniform distribution on the hypercube that runs in time $n^{o(r)}$.
The conjecture implies that a quasi-polynomial run-time is essentially unavoidable for learning $1$-sparsely activated networks.
Theorem 4.2.
Assuming \Cref{conj:junta-hardness}, there is no PAC learning algorithm for $1$-sparsely activated networks with $m$ hidden units over $\{-1,1\}^n$ that runs in time $n^{o(\log m)}$.
Proof 4.3.
We show that any $r$-junta $g$ can be expressed as a $1$-sparsely activated network with $m = 2^r$ hidden units. Suppose w.l.o.g. that $g$ depends only on the first $r$ coordinates. Let $w_1, \ldots, w_m$ be distinct vectors that take all possible values in $\{-1,1\}$ on the first $r$ coordinates, and are $0$ on the other coordinates. Then $\langle w_j, x \rangle = r$ if $x$ agrees with $w_j$ on the first $r$ coordinates, and $\langle w_j, x \rangle \le r - 2$ otherwise. Let $b_j = -(r-1)$ and $a_j = g(w_j)$ for all $j$. It is now easy to verify that for all $x \in \{-1,1\}^n$, at most one hidden unit is active and $\sum_{j=1}^{m} a_j\, \sigma(\langle w_j, x \rangle + b_j) = g(x)$.
Thus, the theorem follows under the assumption of \Crefconj:junta-hardness.
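For concreteness, the following minimal sketch (our own NumPy code; the example junta and the exhaustive-check dimensions are assumptions for illustration) implements the construction in the proof and verifies it on all inputs.
\begin{verbatim}
import itertools
import numpy as np

def relu(t):
    return np.maximum(t, 0.0)

n, r = 8, 3
junta = lambda u: u[0] * u[1] * u[2]        # example r-junta: parity of the relevant bits

patterns = np.array(list(itertools.product([-1.0, 1.0], repeat=r)))   # w_j on coords 1..r
W = np.concatenate([patterns, np.zeros((2 ** r, n - r))], axis=1)     # zero on irrelevant coords
b = -(r - 1.0) * np.ones(2 ** r)
a = np.array([junta(p) for p in patterns])                            # a_j = g(w_j)

def network(x):
    h = relu(W @ x + b)
    assert np.sum(h > 0) <= 1               # 1-sparse activation
    return a @ h

for x in itertools.product([-1.0, 1.0], repeat=n):
    x = np.array(x)
    assert network(x) == junta(x[:r])
print("r-junta realized by a 1-sparsely activated network with 2^r hidden units")
\end{verbatim}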
Hardness Under Arbitrary Distributions
We next show that $1$-sparsely activated networks (over a lifted domain) can simulate parities of size $r$. Fix an integer $r$, and for $x \in \{-1,1\}^n$, let $\mathrm{PAR}_r(x)$ be defined to be $1$ if and only if the number of $-1$'s among the first $r$ coordinates of $x$ is even. Now, we can use a simple identity that expresses $\mathrm{PAR}_r(x)$ as a signed sum of ReLUs applied to quadratic functions of $x$, such that on every input at most one of the quadratics is positive (similar identities were used for similar purposes, for example, in Klivans and Sherstov (2006)).
Note that for any $x$, at most one ReLU node is active. This is not quite enough to place the function in the class, as the functions inside the ReLUs are not linear. To fix this, we linearize the quadratic functions by increasing the dimension: for $x \in \{-1,1\}^n$, let $\phi(x)$ be obtained by appending to $x$ all pairwise products $x_i x_j$ for $i < j$, and identify the lifted domain with the hypercube in $n + \binom{n}{2}$ dimensions in the natural way. Any quadratic function of $x$ then becomes a linear function of $\phi(x)$, with coefficient vectors and biases of norm polynomial in $n$.
In summary, there exists a distribution on the lifted hypercube such that learning parities of size $r$ under the uniform distribution reduces to learning $1$-sparsely activated networks under that distribution. The first part of \Cref{th:lb-general} now follows from standard lower bounds for learning parities.
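The lifting argument admits a small executable illustration. The following is a minimal sketch (our own); the specific identity $\mathbb{1}[s = t] = \sigma(1 - (s - t)^2)$ with $s = \sum_{i \le r} x_i$ is our reconstruction of an identity of this form (not necessarily the exact one intended above), but it exhibits the same structure: a quadratic inside each ReLU, at most one ReLU active, and linearity after lifting to pairwise products.
\begin{verbatim}
import itertools
import numpy as np

def relu(t):
    return np.maximum(t, 0.0)

n, r = 6, 4
ts = list(range(-r, r + 1, 2))                # possible values of s = x_1 + ... + x_r

def parity(x):
    return float(np.prod(x[:r]))

def lift(x):                                  # phi(x) = (x, all pairwise products x_i x_j)
    pairs = [x[i] * x[j] for i, j in itertools.combinations(range(n), 2)]
    return np.concatenate([x, pairs])

def lifted_affine(t):
    # 1 - (s - t)^2 = (1 - r - t^2) + 2t * sum_{i<=r} x_i - 2 * sum_{i<j<=r} x_i x_j,
    # which is an affine function of phi(x).
    w = np.zeros(n + n * (n - 1) // 2)
    w[:r] = 2.0 * t
    for idx, (i, j) in enumerate(itertools.combinations(range(n), 2)):
        if i < r and j < r:
            w[n + idx] = -2.0
    return w, 1.0 - r - t ** 2

for x in itertools.product([-1.0, 1.0], repeat=n):
    x = np.array(x)
    phi = lift(x)
    acts = [relu(lifted_affine(t)[0] @ phi + lifted_affine(t)[1]) for t in ts]
    assert sum(act > 0 for act in acts) <= 1  # at most one ReLU is active
    out = sum((-1.0) ** ((r - t) // 2) * act for t, act in zip(ts, acts))
    assert out == parity(x)
print("parity of the first r bits realized 1-sparsely over the lifted input")
\end{verbatim}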
SQ Hardness
Consider a class of functions $\mathcal{C}$ mapping $\{-1,1\}^n$ to $\{-1,1\}$, and let $\mathcal{D}$ be a distribution over $\{-1,1\}^n$.
In the Statistical Query (SQ) model, as described by Kearns (1998), the learner interacts with the data through an SQ oracle. For a bounded query function $\psi : \{-1,1\}^n \times \{-1,1\} \to [-1,1]$ and a tolerance $\tau > 0$, the oracle can return any value $v$ such that $|v - \mathbb{E}_{x \sim \mathcal{D}}[\psi(x, c(x))]| \le \tau$, where $c \in \mathcal{C}$ is the unknown concept. The goal in SQ learning is to learn an approximation to the unknown concept using only a few queries of the above form with reasonable tolerance. We will use the following classical theorem:
Theorem 4.4 ((Blum et al., 1994)).
Any SQ algorithm for learning the class of parities over $n$ variables to error noticeably smaller than that of random guessing, under the uniform distribution over the hypercube, with tolerance $\tau$ requires $\Omega(2^n \tau^2)$ queries.
The first part of \Cref{th:lb-general} follows immediately from the above and the fact that parities can be computed by $1$-sparsely activated networks over the lifted domain, as described.
Cryptographic Hardness
We sketch the argument here. Following Chen et al. (2022a), our starting point will be the Learning with Rounding (LWR) problem (Banerjee et al., 2012):
Definition 4.5.
For moduli $q > p \ge 2$ and a secret $s \in \mathbb{Z}_q^n$, let $f_s : \mathbb{Z}_q^n \to \mathbb{Z}_p$ be defined by $f_s(a) := \lfloor \langle a, s \rangle \rceil_p$, where $\lfloor z \rceil_p := \lfloor (p/q) \cdot (z \bmod q) \rceil \bmod p$.
In the $\mathrm{LWR}_{n,q,p}$ problem, the secret $s$ is drawn uniformly at random and we are given samples of the form $(a, f_s(a))$ where $a$ is uniform over $\mathbb{Z}_q^n$. The goal is to output a hypothesis that achieves a small error in predicting the label $f_s(a)$ on a fresh sample $a$. It is conjectured that there is no efficient algorithm for $\mathrm{LWR}_{n,q,p}$ even when the modulus $q$ is polynomial in $n$.
Conjecture 4.6 (See Banerjee et al. (2012)).
There is no polynomial-time algorithm that solves $\mathrm{LWR}_{n,q,p}$ with non-trivial accuracy (over the random choice of $s$ and the samples).
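For concreteness, the following minimal sketch (our own NumPy code, using the standard LWR rounding map; the parameter choices are arbitrary illustrations) shows how LWR samples are generated: the learner sees pairs $(a, f_s(a))$ and must predict the label on a fresh $a$.
\begin{verbatim}
import numpy as np

rng = np.random.default_rng(0)
n, q, p = 16, 2 ** 10, 2 ** 4                 # dimension and moduli, q > p

def lwr_samples(s, num):
    A = rng.integers(0, q, size=(num, n))     # uniform a in Z_q^n
    labels = np.rint(((A @ s) % q) * p / q).astype(int) % p
    return A, labels

secret = rng.integers(0, q, size=n)
A, y = lwr_samples(secret, num=5)
print(A[0], y[0])
\end{verbatim}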
We show that an efficient algorithm for learning $1$-sparsely activated networks under arbitrary distributions on the hypercube would contradict this assumption.
Consider an instance of the problem. First, map to for by considering the binary representation of the integers in . Next, let be such that . Note that for every , we can find a vector such that . Then,
Now, observe that we can write
Note that in the conversion and . Further, for any input , only one of the ReLUs will be active. However, the above is not quite in as we have a quadratic function inside the ReLU. Just as we did for parities, we can fix this issue by linearizing the quadratic form. Let , and define by setting if and . Then, just as in our argument for parities, there exists a lifted weight vector and such that
In addition, it is easy to check that . In particular, we get that for every , there exists a function in such that for every ,
where is the embedding as defined above and in showing SQ hardness. The second part of \Crefth:lb-general now follows from the conjectured hardness of ; we omit the minor details.
5 Learning under General Distributions
We now show the statistical advantage associated with sparsely activated neural networks over general distributions. In particular, we show that
Theorem 5.1.
There exists a PAC learning algorithm for $k$-sparsely activated networks whose sample complexity has only a sublinear dependence on the number of hidden units $m$.
This result holds even in a more general setting where the input space is a norm-bounded subset of $\mathbb{R}^n$. To begin with, we will consider the class of $1$-sparsely activated networks; we will discuss extensions to general sparsity $k$ towards the end of the section.
We use Rademacher complexity to establish the bound in Theorem 5.1. Given a set of examples $S = \{x_1, \ldots, x_N\}$, the empirical Rademacher complexity of a class $\mathcal{H}$ (Shalev-Shwartz and Ben-David, 2014) is defined as $\mathcal{R}_S(\mathcal{H}) := \mathbb{E}_{\sigma}\big[\sup_{h \in \mathcal{H}} \frac{1}{N} \sum_{i=1}^{N} \sigma_i\, h(x_i)\big]$, where $\sigma_1, \ldots, \sigma_N$ are i.i.d. $\{-1,+1\}$-valued Rademacher random variables. For $N \in \mathbb{N}$, we write $\mathcal{R}_N(\mathcal{H})$ for the corresponding complexity of a sample of size $N$.
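The definition can be estimated numerically. The following is a minimal sketch (our own NumPy code): the expectation over $\sigma$ is replaced by an empirical average, and the supremum over the class by a maximum over a finite random pool of one-hidden-layer ReLU networks, so the result only lower-bounds the true quantity; the pool construction and all parameters are illustrative assumptions.
\begin{verbatim}
import numpy as np

rng = np.random.default_rng(0)
n, m, N, pool_size, sigma_draws = 20, 50, 200, 500, 200

X = rng.choice([-1.0, 1.0], size=(N, n))      # the sample S

def random_network():
    W = rng.normal(size=(m, n)) / np.sqrt(n)
    b = rng.normal(size=m)
    a = rng.uniform(-1, 1, size=m)
    return lambda X: np.maximum(X @ W.T + b, 0.0) @ a

preds = np.stack([h(X) for h in [random_network() for _ in range(pool_size)]])

estimates = []
for _ in range(sigma_draws):
    sigma = rng.choice([-1.0, 1.0], size=N)   # Rademacher signs
    estimates.append(np.max(preds @ sigma) / N)
print("Monte Carlo estimate of the empirical Rademacher complexity:",
      np.mean(estimates))
\end{verbatim}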
Lemma 5.2 (see Shalev-Shwartz and Ben-David (2014)).
For any class $\mathcal{H}$ mapping $\mathcal{X}$ to a bounded range, there exists a PAC learning algorithm for $\mathcal{H}$ with sample complexity equal to the smallest $N$ such that, for a large enough constant $C$, it holds that $C \cdot \mathcal{R}_N(\mathcal{H}) \le \varepsilon$.
\Cref{thm:general-dist-upper-bound} will follow from bounding the Rademacher complexity of the class. Recall that in the absence of any sparsity assumption, existing results (Anthony et al., 1999) on the Rademacher complexity of $1$-hidden-layer ReLU networks with input dimensionality $n$ and $m$ hidden units lead to a bound that scales linearly with $m$.¹ We will show that the main statistical advantage that comes from sparsity is that the dependence on the number of hidden units can be made sublinear, albeit at the expense of an explicit dependence on the input dimensionality $n$. In particular, we will prove the following theorem.
¹ Better bounds are possible under stronger assumptions on the network weights (Wei et al., 2019).
Theorem 5.3.
It holds that
(1) |
Proof 5.4.
For a given hypothesis and for any , let be the subset of the examples that activate neuron , i.e., . Since each is determined by a halfspace in dimensions, by the Sauer-Shelah lemma (Shalev-Shwartz and Ben-David, 2014) there can be at most such subsets.
Next, we have
(2) | ||||
(3) | ||||
(4) |
We will bound the above two terms separately via standard concentration inequalities. For the second term, note that for any fixed subset, the corresponding random variable is sub-Gaussian with bounded norm. Hence, for any fixed subset, the following holds (Vershynin, 2018)
(5) |
where is an absolute constant. Via the union bound we get that with probability at least , all sets simultaneously satisfy the above inequality.
Hence we get the following bound on the second term.
(6) |
From the fact that the activations are -sparse we get that . This implies that . Furthermore, using the fact that we get
(7) |
Setting we get that the second term is bounded by
(8) |
Similarly, we next bound the first term. Note that for any fixed , and any coordinate , sub-Gaussian concentration (Vershynin, 2018) implies that
(9) |
Via a union bound over all the coordinates and all possible subsets we get that with probability at least , all sets simultaneously satisfy
(10) |
Using the above we can bound the first term as
(11) | ||||
(12) | ||||
(13) |
Recall from above that . Furthermore, setting we get that the first term is bounded by
(14) |
Combining the bounds for the first and the second terms, we get the desired claim.
Generalization to -sparsely activated networks.
The above analysis extends in a straightforward manner to the class of $k$-sparsely activated networks, i.e., the class of networks where each input activates at most $k$ hidden units.
To extend the bound in Theorem 5.3 we note that using the fact that -sparsity implies that we get that
(15) |
Note that, in contrast to the classical bounds on the Rademacher complexity of general norm-bounded $1$-hidden-layer neural networks, the bound in Theorem 5.3 above has a sublinear dependence on the number of hidden units $m$. However, we incur an explicit dependence on the input dimensionality.
We suspect that this is a limitation of our proof technique, and we conjecture that the right bound should not have any explicit dependence on the input dimension $n$.
Conjecture 5.5.
The class of -sparsely activated neural networks satisfies
(16) |
6 Discussion & Future Directions
Motivated by the empirical phenomenon of activation sparsity in MLP layers of large transformer models, in this work we proposed and studied the problem of PAC learning the class of sparsely activated neural networks. This is a novel concept class with many interesting properties. The form of input-dependent sparsity present in this class of functions makes it distinct from the typical sparse function classes studied in literature. The main conceptual insight from our work is that despite the empirical challenges in leveraging sparsity, activation sparsity can provably provide both computational and statistical benefits.
Several open questions come out of our work. While we provide algorithms with near-optimal running time for the case of the uniform distribution, it would be interesting to design learning algorithms under arbitrary distributions that are provably better than the (much slower) algorithms that exist for general $1$-hidden-layer ReLU networks (Goel et al., 2020). As mentioned in \Cref{sec:rademacher}, we strongly suspect that the dependence on the input dimension in the Rademacher complexity bound of \Cref{thm:rademacher-1-sparse-bound-2} is suboptimal. While we primarily considered networks that are sparsely activated for all inputs, it might also be interesting to consider networks that are sparsely activated with high probability over the input distribution, as we briefly alluded to in \Cref{rem:sparse-with-high-prob}, although in that case the probability of not being sparsely activated was required to be very small. Finally, it would be interesting to explore practical algorithms for leveraging sparsity based on our theoretical insights.
We thank anonymous reviewers for their comments that helped improve the presentation.
References
- Anil et al. (2018) Rohan Anil, Gabriel Pereyra, Alexandre Passos, Robert Ormandi, George E Dahl, and Geoffrey E Hinton. Large scale distributed neural network training through online distillation. arXiv preprint arXiv:1804.03235, 2018.
- Anthony et al. (1999) Martin Anthony and Peter L Bartlett. Neural network learning: Theoretical foundations, volume 9. Cambridge University Press, Cambridge, 1999.
- Banerjee et al. (2012) Abhishek Banerjee, Chris Peikert, and Alon Rosen. Pseudorandom functions and lattices. In Annual International Conference on the Theory and Applications of Cryptographic Techniques, pages 719–737. Springer, 2012.
- Banner et al. (2019) Ron Banner, Yury Nahshan, and Daniel Soudry. Post training 4-bit quantization of convolutional networks for rapid-deployment. Advances in Neural Information Processing Systems, 32, 2019.
- Blum et al. (1994) Avrim Blum, Merrick Furst, Jeffrey Jackson, Michael Kearns, Yishay Mansour, and Steven Rudich. Weakly learning DNF and characterizing statistical query learning using Fourier analysis. In Proceedings of the Twenty-Sixth Annual ACM Symposium on Theory of Computing, pages 253–262, 1994.
- Brown et al. (2020) Tom Brown, Benjamin Mann, Nick Ryder, Melanie Subbiah, Jared D Kaplan, Prafulla Dhariwal, Arvind Neelakantan, Pranav Shyam, Girish Sastry, Amanda Askell, et al. Language models are few-shot learners. Advances in neural information processing systems, 33:1877–1901, 2020.
- Chen et al. (2022a) Sitan Chen, Aravind Gollakota, Adam R. Klivans, and Raghu Meka. Hardness of noise-free learning for two-hidden-layer neural networks. In Neural Information Processing Systems (NeurIPS), 2022a. URL http://papers.nips.cc/paper_files/paper/2022/hash/45a7ca247462d9e465ee88c8a302ca70-Abstract-Conference.html.
- Chen et al. (2022b) Xi Chen, Xiao Wang, Soravit Changpinyo, AJ Piergiovanni, Piotr Padlewski, Daniel Salz, Sebastian Goodman, Adam Grycner, Basil Mustafa, Lucas Beyer, et al. Pali: A jointly-scaled multilingual language-image model. arXiv preprint arXiv:2209.06794, 2022b.
- Choromanski et al. (2020) Krzysztof Choromanski, Valerii Likhosherstov, David Dohan, Xingyou Song, Andreea Gane, Tamas Sarlos, Peter Hawkins, Jared Davis, Afroz Mohiuddin, Lukasz Kaiser, et al. Rethinking attention with performers. arXiv preprint arXiv:2009.14794, 2020.
- Chowdhery et al. (2022) Aakanksha Chowdhery, Sharan Narang, Jacob Devlin, Maarten Bosma, Gaurav Mishra, Adam Roberts, Paul Barham, Hyung Won Chung, Charles Sutton, Sebastian Gehrmann, et al. Palm: Scaling language modeling with pathways. arXiv preprint arXiv:2204.02311, 2022.
- Csordás et al. (2023) Róbert Csordás, Kazuki Irie, and Jürgen Schmidhuber. Approximating two-layer feedforward networks for efficient transformers. arXiv preprint arXiv:2310.10837, 2023.
- Dong et al. (2023) Harry Dong, Beidi Chen, and Yuejie Chi. Towards structured sparsity in transformers for efficient inference. In Workshop on Efficient Systems for Foundation Models@ ICML2023, 2023.
- Dosovitskiy et al. (2020) Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, et al. An image is worth 16x16 words: Transformers for image recognition at scale. arXiv preprint arXiv:2010.11929, 2020.
- Elhage et al. (2022) Nelson Elhage, Tristan Hume, Catherine Olsson, Neel Nanda, Tom Henighan, Scott Johnston, Sheer ElShowk, Nicholas Joseph, Nova DasSarma, Ben Mann, Danny Hernandez, Amanda Askell, Kamal Ndousse, Andy Jones, Dawn Drain, Anna Chen, Yuntao Bai, Deep Ganguli, Liane Lovitt, Zac Hatfield-Dodds, Jackson Kernion, Tom Conerly, Shauna Kravec, Stanislav Fort, Saurav Kadavath, Josh Jacobson, Eli Tran-Johnson, Jared Kaplan, Jack Clark, Tom Brown, Sam McCandlish, Dario Amodei, and Christopher Olah. Softmax linear units. Transformer Circuits Thread, 2022. https://transformer-circuits.pub/2022/solu/index.html.
- Fedus et al. (2022) William Fedus, Barret Zoph, and Noam Shazeer. Switch transformers: Scaling to trillion parameter models with simple and efficient sparsity. The Journal of Machine Learning Research, 23(1):5232–5270, 2022.
- Feldman et al. (2011) Vitaly Feldman, Homin K. Lee, and Rocco A. Servedio. Lower bounds and hardness amplification for learning shallow monotone formulas. In Conference on Learning Theory (COLT), volume 19 of JMLR Proceedings, pages 273–292. JMLR.org, 2011. URL http://proceedings.mlr.press/v19/feldman11a/feldman11a.pdf.
- Frankle and Carbin (2018) Jonathan Frankle and Michael Carbin. The lottery ticket hypothesis: Finding sparse, trainable neural networks. arXiv preprint arXiv:1803.03635, 2018.
- Frantar and Alistarh (2023) Elias Frantar and Dan Alistarh. SparseGPT: Massive language models can be accurately pruned in one-shot. In International Conference on Machine Learning, pages 10323–10337. PMLR, 2023.
- Geva et al. (2020) Mor Geva, Roei Schuster, Jonathan Berant, and Omer Levy. Transformer feed-forward layers are key-value memories. arXiv preprint arXiv:2012.14913, 2020.
- Gholami et al. (2022) Amir Gholami, Sehoon Kim, Zhen Dong, Zhewei Yao, Michael W Mahoney, and Kurt Keutzer. A survey of quantization methods for efficient neural network inference. In Low-Power Computer Vision, pages 291–326. Chapman and Hall/CRC, 2022.
- Goel et al. (2020) Surbhi Goel, Aravind Gollakota, and Adam Klivans. Statistical-query lower bounds via functional gradients. Advances in Neural Information Processing Systems, 33:2147–2158, 2020.
- Grimaldi et al. (2023) Matteo Grimaldi, Darshan C Ganji, Ivan Lazarevich, and Sudhakar Sah. Accelerating deep neural networks via semi-structured activation sparsity. In Proceedings of the IEEE/CVF International Conference on Computer Vision, pages 1179–1188, 2023.
- Gu and Dao (2023) Albert Gu and Tri Dao. Mamba: Linear-time sequence modeling with selective state spaces. arXiv preprint arXiv:2312.00752, 2023.
- Harutyunyan et al. (2023) Hrayr Harutyunyan, Ankit Singh Rawat, Aditya Krishna Menon, Seungyeon Kim, and Sanjiv Kumar. Supervision complexity and its role in knowledge distillation. arXiv preprint arXiv:2301.12245, 2023.
- Hinton et al. (2015) Geoffrey Hinton, Oriol Vinyals, and Jeff Dean. Distilling the knowledge in a neural network. arXiv preprint arXiv:1503.02531, 2015.
- Kane (2014) Daniel M. Kane. The average sensitivity of an intersection of half spaces. In Symposium on Theory of Computing (STOC), pages 437–440. ACM, 2014. 10.1145/2591796.2591798. URL https://doi.org/10.1145/2591796.2591798.
- Kearns (1998) Michael J. Kearns. Efficient noise-tolerant learning from statistical queries. J. ACM, 45(6):983–1006, 1998. 10.1145/293347.293351. URL https://doi.org/10.1145/293347.293351.
- Klivans and Sherstov (2006) Adam R. Klivans and Alexander A. Sherstov. Cryptographic hardness for learning intersections of halfspaces. In 47th Annual IEEE Symposium on Foundations of Computer Science (FOCS 2006), 21-24 October 2006, Berkeley, California, USA, Proceedings, pages 553–562. IEEE Computer Society, 2006. 10.1109/FOCS.2006.24. URL https://doi.org/10.1109/FOCS.2006.24.
- Klivans et al. (2004) Adam R. Klivans, Ryan O’Donnell, and Rocco A. Servedio. Learning intersections and thresholds of halfspaces. J. Comput. Syst. Sci., 68(4):808–840, 2004. 10.1016/J.JCSS.2003.11.002. URL https://doi.org/10.1016/j.jcss.2003.11.002.
- Lample et al. (2019) Guillaume Lample, Alexandre Sablayrolles, Marc’Aurelio Ranzato, Ludovic Denoyer, and Hervé Jégou. Large memory layers with product keys. Advances in Neural Information Processing Systems, 32, 2019.
- Li et al. (2023) Zonglin Li, Chong You, Srinadh Bhojanapalli, Daliang Li, Ankit Singh Rawat, Sashank J. Reddi, Ke Ye, Felix Chern, Felix X. Yu, Ruiqi Guo, and Sanjiv Kumar. The lazy neuron phenomenon: On emergence of activation sparsity in transformers. In The Eleventh International Conference on Learning Representations, ICLR 2023, Kigali, Rwanda, May 1-5, 2023. OpenReview.net, 2023. URL https://openreview.net/pdf?id=TJ2nxciYCk-.
- Liu et al. (2023) Zichang Liu, Jue Wang, Tri Dao, Tianyi Zhou, Binhang Yuan, Zhao Song, Anshumali Shrivastava, Ce Zhang, Yuandong Tian, Christopher Re, and Beidi Chen. Deja vu: Contextual sparsity for efficient LLMs at inference time. In International Conference on Machine Learning (ICML), volume 202 of Proceedings of Machine Learning Research, pages 22137–22176. PMLR, 23–29 Jul 2023. URL https://proceedings.mlr.press/v202/liu23am.html.
- Mirzadeh et al. (2023) Iman Mirzadeh, Keivan Alizadeh, Sachin Mehta, Carlo C Del Mundo, Oncel Tuzel, Golnoosh Samei, Mohammad Rastegari, and Mehrdad Farajtabar. Relu strikes back: Exploiting activation sparsity in large language models. arXiv preprint arXiv:2310.04564, 2023.
- Mossel et al. (2003) Elchanan Mossel, Ryan O’Donnell, and Rocco A. Servedio. Learning juntas. In Symposium on Theory of Computing (STOC), pages 206–212. ACM, 2003. 10.1145/780542.780574. URL https://doi.org/10.1145/780542.780574.
- O’Donnell (2014) Ryan O’Donnell. Analysis of boolean functions. Cambridge University Press, 2014.
- Peng et al. (2023) Ze Peng, Lei Qi, Yinghuan Shi, and Yang Gao. Theoretical explanation of activation sparsity through flat minima and adversarial robustness. arXiv preprint arXiv:2309.03004, 2023.
- Shalev-Shwartz and Ben-David (2014) Shai Shalev-Shwartz and Shai Ben-David. Understanding machine learning: From theory to algorithms. Cambridge university press, 2014.
- Shazeer et al. (2017) Noam Shazeer, Azalia Mirhoseini, Krzysztof Maziarz, Andy Davis, Quoc Le, Geoffrey Hinton, and Jeff Dean. Outrageously large neural networks: The sparsely-gated mixture-of-experts layer. arXiv preprint arXiv:1701.06538, 2017.
- Shen et al. (2023) Kai Shen, Junliang Guo, Xu Tan, Siliang Tang, Rui Wang, and Jiang Bian. A study on relu and softmax in transformer. arXiv preprint arXiv:2302.06461, 2023.
- Sukhbaatar et al. (2019) Sainbayar Sukhbaatar, Edouard Grave, Guillaume Lample, Herve Jegou, and Armand Joulin. Augmenting self-attention with persistent memory. arXiv preprint arXiv:1907.01470, 2019.
- Valiant (1984) Leslie G Valiant. A theory of the learnable. Communications of the ACM, 27(11):1134–1142, 1984.
- Vaswani et al. (2017) Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Łukasz Kaiser, and Illia Polosukhin. Attention is all you need. Advances in Neural Information Processing Systems, 30, 2017.
- Vershynin (2018) Roman Vershynin. High-dimensional probability: An introduction with applications in data science, volume 47. Cambridge university press, 2018.
- Wang et al. (2020) Sinong Wang, Belinda Z Li, Madian Khabsa, Han Fang, and Hao Ma. Linformer: Self-attention with linear complexity. arXiv preprint arXiv:2006.04768, 2020.
- Wei et al. (2019) Colin Wei, Jason Lee, Qiang Liu, and Tengyu Ma. On the margin theory of feedforward neural networks, 2019. URL https://openreview.net/forum?id=HJGtFoC5Fm.
- Zaheer et al. (2020) Manzil Zaheer, Guru Guruganesh, Kumar Avinava Dubey, Joshua Ainslie, Chris Alberti, Santiago Ontanon, Philip Pham, Anirudh Ravula, Qifan Wang, Li Yang, et al. Big bird: Transformers for longer sequences. Advances in neural information processing systems, 33:17283–17297, 2020.
Appendix A Example of a Sparsely Activated Network without Weight Sparsity
There are interesting functions (beyond juntas/parities) that are sparsely activated but do not have weight sparsity. E.g.: suppose . Consider , , and look at , of the form , where the input is . When , this network is -sparsely activated for all inputs, and when , the function is -sparse with probability under the uniform distribution on . Remark 3.5 shows that our results continue to hold in such a setting. Intuitively, such functions are similar to Indexing; they return the function for all (or most) of the input space, where can depend arbitrarily on the part of the input.
Appendix B Proof of \Creflem:as-to-ns-generalk
[\Creflem:as-to-ns-generalk] Given a , let . We describe an alternate way to sample . First sample uniformly at random and partition the coordinates of into the buckets at random (each coordinate is included in exactly one of these buckets uniformly and independently). For each , sample uniformly at random. Multiply the coordinates of by and concatenate all the buckets to get . Choose one bucket at random and flip to get . Multiply the coordinates of by to get . Observe that are distributed exactly the same as . Now, given , define
where . Clearly . Hence,
(17) |
From \Creflem:avg-sens-generalk,
To bound we need to bound
For any , we have from measure concentration
Now we use that .
Combining with the fact that is always at most , we get that
Combining the above with (B), we get
The claim now follows.