Smoothing the Landscape Boosts the Signal for SGD
Optimal Sample Complexity for Learning Single Index Models
Abstract
We focus on the task of learning a single index model $\sigma(w^\star \cdot x)$ with respect to the isotropic Gaussian distribution in $d$ dimensions. Prior work has shown that the sample complexity of learning $w^\star$ is governed by the information exponent $k^\star$ of the link function $\sigma$, which is defined as the index of the first nonzero Hermite coefficient of $\sigma$. Ben Arous et al. [1] showed that $n \gtrsim d^{k^\star - 1}$ samples suffice for learning $w^\star$ and that this is tight for online SGD. However, the CSQ lower bound for gradient-based methods only shows that $n \gtrsim d^{k^\star/2}$ samples are necessary. In this work, we close the gap between the upper and lower bounds by showing that online SGD on a smoothed loss learns $w^\star$ with $n \gtrsim d^{k^\star/2}$ samples. We also draw connections to statistical analyses of tensor PCA and to the implicit regularization effects of minibatch SGD on empirical losses.
1 Introduction
Gradient descent-based algorithms are popular for deriving computational and statistical guarantees for a number of high-dimensional statistical learning problems [2, 3, 1, 4, 5, 6]. Despite the fact that the empirical loss is nonconvex and in the worst case computationally intractable to optimize, for a number of statistical learning tasks gradient-based methods still converge to good solutions with polynomial runtime and sample complexity. Analyses in these settings typically study properties of the empirical loss landscape [7], and in particular the number of samples needed for the signal of the gradient arising from the population loss to overpower the noise in some uniform sense. In this sense, the sample complexity of learning with gradient descent is determined by the geometry of the empirical loss landscape.
One setting in which the empirical loss landscape showcases rich behavior is that of learning a single-index model. Single index models are target functions of the form $f^\star(x) = \sigma(w^\star \cdot x)$, where $w^\star \in S^{d-1}$ is the unknown relevant direction and $\sigma$ is the known link function. When the covariates are drawn from the standard $d$-dimensional Gaussian distribution, the shape of the loss landscape is governed by the information exponent $k^\star$ of the link function $\sigma$, which characterizes the curvature of the loss landscape around the origin. Ben Arous et al. [1] show that online stochastic gradient descent on the empirical loss can recover $w^\star$ with $n \gtrsim d^{k^\star - 1}$ samples; furthermore, they present a lower bound showing that for a class of online SGD algorithms, $n \gtrsim d^{k^\star - 1}$ samples are indeed necessary.
However, gradient descent can be suboptimal for various statistical learning problems, as it only relies on local information in the loss landscape and is thus prone to getting stuck in local minima. For learning a single index model, the Correlational Statistical Query (CSQ) lower bound only requires $n \gtrsim d^{k^\star/2}$ samples to recover $w^\star$ [6, 4], which is far fewer than the number of samples required by online SGD. This gap between gradient-based methods and the CSQ lower bound is also present in the Tensor PCA problem [8]; for recovering a rank-1 $k$-tensor in $d$ dimensions, both gradient descent and the power method require $n \gtrsim d^{k-1}$ samples, whereas more sophisticated spectral algorithms can match the computational lower bound of $n \gtrsim d^{k/2}$ samples.
In light of the lower bound from [1], it seems hopeless for a gradient-based algorithm to match the CSQ lower bound for learning single-index models. [1] considers the regime in which SGD is simply a discretization of gradient flow, in which case the poor properties of the loss landscape with insufficient samples imply a lower bound. However, recent work has shown that SGD is not just a discretization of gradient flow, but rather that it has an additional implicit regularization effect. Specifically, [9, 10, 11] show that over short periods of time, SGD converges to a quasi-stationary distribution $N(\theta, \lambda S)$ where $\theta$ is an initial reference point, $S$ is a matrix depending on the Hessian and the noise covariance, and $\lambda = \frac{\eta}{B}$ measures the strength of the noise, where $\eta$ is the learning rate and $B$ is the batch size. The resulting long-term dynamics therefore follow the smoothed gradient $\mathbb{E}_{v \sim N(0, \lambda S)}[\nabla \hat{L}(\theta + v)]$, which has the effect of regularizing the trace of the Hessian.
This implicit regularization effect of minibatch SGD has been shown to drastically improve generalization and reduce the number of samples necessary for supervised learning tasks [12, 13, 14]. However, the connection between the smoothed landscape and the resulting sample complexity is poorly understood. Towards closing this gap, we consider directly smoothing the loss landscape in order to efficiently learn single index models. Our main result, Theorem 1, shows that online SGD on the smoothed loss learns $w^\star$ in $n = \tilde{O}(d^{k^\star/2})$ samples, which matches the CSQ lower bound. This improves over the $d^{k^\star - 1}$ lower bound for online SGD on the unsmoothed loss from Ben Arous et al. [1]. Key to our analysis is the observation that smoothing the loss landscape boosts the signal-to-noise ratio in a region around the initialization, which allows the iterates to avoid the poor local minima of the unsmoothed empirical loss. Our analysis is inspired by the implicit regularization effect of minibatch SGD, along with the partial trace algorithm for Tensor PCA, which achieves the optimal sample complexity among computationally efficient algorithms.
The outline of our paper is as follows. In Section 3 we formalize the specific statistical learning setup, define the information exponent $k^\star$, and describe our algorithm. Section 4 contains our main theorem, and Section 5 presents a heuristic derivation of how smoothing the loss landscape increases the signal-to-noise ratio. We present empirical verification in Section 6, and in Section 7 we detail connections to tensor PCA and minibatch SGD.
2 Related Work
There is a rich literature on learning single index models. Kakade et al. [15] showed that gradient descent can learn single index models when the link function is Lipschitz and monotonic, and designed an alternative algorithm to handle the case when the link function is unknown. Soltanolkotabi [16] focused on learning single index models where the link function is $\mathrm{ReLU}(x) := \max(0, x)$, which has information exponent $k^\star = 1$. The phase retrieval problem is a special case of the single index model in which the link function is $\sigma(x) = x^2$ or $\sigma(x) = |x|$; this corresponds to $k^\star = 2$, and solving phase retrieval via gradient descent has been well studied [17, 18, 19]. Dudeja and Hsu [20] constructed an algorithm which explicitly uses the harmonic structure of Hermite polynomials to identify the information exponent. Ben Arous et al. [1] provided matching upper and lower bounds that show that $n = \tilde{\Theta}(d^{k^\star - 1})$ samples are necessary and sufficient for online SGD to recover $w^\star$.
Going beyond gradient-based algorithms, Chen and Meka [21] provide an algorithm that can learn polynomials of few relevant dimensions with $n = \tilde{O}(d)$ samples, including single index models with polynomial link functions. Their estimator is based on the structure of the filtered PCA matrix $\mathbb{E}_{x,y}\big[\mathbf{1}_{\{|y| \ge \tau\}}\, x x^\top\big]$, which relies on the heavy tails of polynomials. In particular, this upper bound does not apply to bounded link functions. Furthermore, while their result achieves the information-theoretically optimal $d$ dependence, it is not a CSQ algorithm, whereas our Algorithm 1 achieves the optimal sample complexity over the class of CSQ algorithms (which contains gradient descent).
Recent work has also studied the ability of neural networks to learn single- or multi-index models [5, 6, 22, 23, 4]. Bietti et al. [5] showed that two-layer neural networks are able to adapt to unknown link functions with $n = \tilde{O}(d^{k^\star})$ samples. Damian et al. [6] consider multi-index models with polynomial link functions and, under a nondegeneracy assumption which corresponds to the $k^\star = 2$ case, show that SGD on a two-layer neural network learns with $n = \tilde{O}(d^2)$ samples. Abbe et al. [23, 4] provide a generalization of the information exponent called the leap. They prove that in some settings, SGD can learn low-dimensional target functions with $n = \tilde{O}(d^{\mathrm{Leap}-1})$ samples. However, they conjecture that the optimal rate is $\tilde{\Theta}(d^{\mathrm{Leap}/2})$ and that this can be achieved by ERM rather than online SGD.
The problem of learning single index models with information exponent $k^\star$ is strongly related to the order-$k^\star$ Tensor PCA problem (see Section 7.1), which was introduced by Richard and Montanari [8]. They conjectured the existence of a computational-statistical gap for Tensor PCA, as the information-theoretic threshold for the problem is $n \asymp d$, but all known computationally efficient algorithms require $n \gtrsim d^{k/2}$. Furthermore, simple iterative estimators including the tensor power method, gradient descent, and AMP are suboptimal and require $n \gtrsim d^{k-1}$ samples. Hopkins et al. [24] introduced the partial trace estimator, which succeeds with $n \gtrsim d^{\lceil k/2 \rceil}$ samples. Anandkumar et al. [25] extended this result to show that gradient descent on a smoothed landscape could achieve the $n \gtrsim d^{k/2}$ sample complexity when $k = 3$, and Biroli et al. [26] heuristically extended this result to larger $k$. The success of smoothing the landscape for Tensor PCA is one of the inspirations for Algorithm 1.
3 Setting
3.1 Data distribution and target function
Our goal is to efficiently learn single index models of the form $f^\star(x) = \sigma(w^\star \cdot x)$, where $w^\star \in S^{d-1}$, the unit sphere in $\mathbb{R}^d$. We assume that $\sigma$ is normalized so that $\mathbb{E}_{x \sim N(0,1)}[\sigma(x)^2] = 1$. We will also assume that $\sigma$ is differentiable and that $\sigma'$ has polynomial tails:
Assumption 1.
There exist constants $C_1, C_2 > 0$ such that $|\sigma'(x)| \le C_1 (1 + |x|)^{C_2}$ for all $x \in \mathbb{R}$.
Our goal is to recover $w^\star$ given $n$ samples $(x_1, y_1), \ldots, (x_n, y_n)$ sampled i.i.d. from $x \sim N(0, I_d)$, $y = \sigma(w^\star \cdot x)$.
For simplicity of exposition, we assume that $\sigma$ is known and we take our model class to be $\{f_w(x) := \sigma(w \cdot x) : w \in S^{d-1}\}$.
3.2 Algorithm: online SGD on a smoothed landscape
As $w \in S^{d-1}$, we will let $\nabla$ denote the spherical gradient with respect to $w$. That is, for a function $g : \mathbb{R}^d \to \mathbb{R}$, let $\nabla g(w) := (I - ww^\top)\,\partial g(w)$, where $\partial g$ is the standard Euclidean gradient.
To compute the loss on a sample $(x, y)$, we use the correlation loss:
$$\ell(w; x, y) := 1 - f_w(x)\, y.$$
Furthermore, when the sample is omitted we refer to the population loss:
$$L(w) := \mathbb{E}_{x \sim N(0, I_d),\; y = \sigma(w^\star \cdot x)}\big[\ell(w; x, y)\big].$$
Our primary contribution is that SGD on a smoothed loss achieves the optimal sample complexity for this problem. First, we define the smoothing operator $\mathcal{L}_\lambda$:
Definition 1.
Let $w \in S^{d-1}$. We define the smoothing operator $\mathcal{L}_\lambda$ by
$$(\mathcal{L}_\lambda g)(w) := \mathbb{E}_{z \sim \mu_w}\left[g\!\left(\frac{w + \lambda z}{\|w + \lambda z\|}\right)\right],$$
where $\mu_w$ is the uniform distribution over $S^{d-1}$ conditioned on being perpendicular to $w$.
This choice of smoothing is natural for spherical gradient descent and can be directly related to the Riemannian exponential map on $S^{d-1}$: it is equivalent to the intrinsic definition $(\mathcal{L}_\lambda g)(w) = \mathbb{E}_{z \sim \mu_w}[g(\exp_w(\theta z))]$, where $\theta = \arctan(\lambda)$, $z$ ranges over the unit sphere in the tangent space at $w$, and $\exp_w$ is the Riemannian exponential map. We will often abuse notation and write $\mathcal{L}_\lambda g(w)$ rather than $(\mathcal{L}_\lambda g)(w)$. The smoothed empirical loss and population loss are defined by:
$$\ell_\lambda(w; x, y) := (\mathcal{L}_\lambda \ell)(w; x, y) \qquad \text{and} \qquad L_\lambda(w) := (\mathcal{L}_\lambda L)(w).$$
Our algorithm, Algorithm 1, is online SGD on the smoothed loss $\ell_\lambda$: starting from a uniformly random $w_0 \in S^{d-1}$, at each step $t$ we draw a fresh sample $(x_t, y_t)$, take a spherical gradient step on $\ell_{\lambda_t}(\cdot\,; x_t, y_t)$ with learning rate $\eta_t$, and project back to the sphere.
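Since Algorithm 1 is standard projected online SGD, the following is a minimal NumPy sketch of it. The Monte Carlo approximation of the smoothing expectation (the analysis uses the exact expectation over $z \sim \mu_w$), the evaluation of the gradient at the perturbed point (smoothing and differentiation only approximately commute; see Lemma 9), and the constant hyperparameters are all illustrative assumptions rather than the tuned schedule of Theorem 1.

```python
import numpy as np

def project_tangent(g, w):
    """Project a Euclidean gradient g onto the tangent space of the sphere at w."""
    return g - (g @ w) * w

def smoothed_grad(w, x, y, sigma_prime, lam, rng, n_mc=32):
    """Monte Carlo estimate of the spherical gradient of the smoothed
    correlation loss ell(w; x, y) = 1 - y * sigma(w . x)."""
    d = w.shape[0]
    total = np.zeros(d)
    for _ in range(n_mc):
        z = project_tangent(rng.standard_normal(d), w)
        z /= np.linalg.norm(z)                        # z ~ Unif(S^{d-1} intersect w-perp)
        w_pert = (w + lam * z) / np.sqrt(1 + lam**2)  # ||w + lam*z|| = sqrt(1 + lam^2)
        g = -y * sigma_prime(w_pert @ x) * x          # Euclidean gradient of the loss
        total += project_tangent(g, w_pert)           # heuristic: gradient at the perturbed point
    return total / n_mc

def online_sgd(sample_stream, d, sigma_prime, lam, eta, n_steps, seed=0):
    """Online SGD on the smoothed loss: each sample is used exactly once."""
    rng = np.random.default_rng(seed)
    w = rng.standard_normal(d)
    w /= np.linalg.norm(w)                            # uniform initialization on the sphere
    for _ in range(n_steps):
        x, y = next(sample_stream)                    # fresh sample at every step
        w = w - eta * smoothed_grad(w, x, y, sigma_prime, lam, rng)
        w /= np.linalg.norm(w)                        # retract back to S^{d-1}
    return w
```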
3.3 Hermite polynomials and information exponent
The sample complexity of Algorithm 1 depends on the Hermite coefficients of the link function $\sigma$:
Definition 2 (Hermite Polynomials).
The $k$th Hermite polynomial $h_k$ is the degree-$k$, monic polynomial defined by
$$h_k(x) := (-1)^k\,\frac{\mu^{(k)}(x)}{\mu(x)},$$
where $\mu(x) = \frac{1}{\sqrt{2\pi}}\,e^{-x^2/2}$ is the PDF of a standard Gaussian.
The first few Hermite polynomials are $h_0(x) = 1$, $h_1(x) = x$, $h_2(x) = x^2 - 1$, and $h_3(x) = x^3 - 3x$. For further discussion of the Hermite polynomials and their properties, refer to Section A.2. The Hermite polynomials form an orthogonal basis of $L^2(\mu)$, so any function in $L^2(\mu)$ admits a Hermite expansion. We let $\{c_k\}_{k \ge 0}$ denote the Hermite coefficients of the link function $\sigma$:
Definition 3 (Hermite Expansion of $\sigma$).
Let $\{c_k\}_{k \ge 0}$ be the Hermite coefficients of $\sigma$, i.e.
$$\sigma(x) = \sum_{k \ge 0} \frac{c_k}{k!}\,h_k(x) \qquad \text{where} \qquad c_k := \mathbb{E}_{x \sim N(0,1)}\big[\sigma(x)\,h_k(x)\big].$$
The critical quantity of interest is the information exponent $k^\star$ of $\sigma$:
Definition 4 (Information Exponent).
The information exponent $k^\star = k^\star(\sigma)$ is the first index $k \ge 1$ such that $c_k \neq 0$.
Example 1.
Below are some example link functions and their information exponents:
- $\sigma(x) = x$ and $\sigma(x) = \mathrm{ReLU}(x) := \max(0, x)$ have information exponents $k^\star = 1$.
- $\sigma(x) = x^2$ and $\sigma(x) = |x|$ have information exponents $k^\star = 2$.
- $\sigma(x) = x^3 - 3x$ has information exponent $k^\star = 3$. More generally, $\sigma(x) = h_k(x)$ has information exponent $k^\star = k$.
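With the normalization $\mathbb{E}[h_j h_k] = k!\,\delta_{jk}$ implied by Definition 2, the coefficients $c_k = \mathbb{E}_{x \sim N(0,1)}[\sigma(x) h_k(x)]$ can be computed numerically by Gauss–Hermite quadrature, giving a quick way to read off the information exponent of a candidate link function. A sketch (the cutoff `tol` for declaring a coefficient nonzero is an arbitrary heuristic choice):

```python
import numpy as np
from numpy.polynomial import hermite_e as He  # probabilists' (monic) Hermite basis

def hermite_coeffs(sigma, k_max=10, n_quad=100):
    """c_k = E_{x ~ N(0,1)}[sigma(x) h_k(x)] via Gauss-Hermite quadrature.
    hermegauss integrates against the weight exp(-x^2/2), so dividing by
    sqrt(2*pi) turns the quadrature sum into a standard Gaussian expectation."""
    xs, ws = He.hermegauss(n_quad)
    coeffs = []
    for k in range(k_max + 1):
        hk = He.hermeval(xs, [0] * k + [1])   # h_k evaluated at the quadrature nodes
        coeffs.append(np.sum(ws * sigma(xs) * hk) / np.sqrt(2 * np.pi))
    return np.array(coeffs)

def information_exponent(sigma, tol=1e-8):
    c = hermite_coeffs(sigma)
    return next(k for k in range(1, len(c)) if abs(c[k]) > tol)

print(information_exponent(lambda x: np.maximum(x, 0)))  # ReLU: 1
print(information_exponent(np.abs))                      # |x|: 2
print(information_exponent(lambda x: x**3 - 3 * x))      # h_3: 3
```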
Throughout our main results we focus on the case $k^\star \ge 3$, as when $k^\star \le 2$, online SGD without smoothing already achieves the optimal sample complexity of $\Theta(d)$ samples (up to log factors) [1].
4 Main Results
Our main result is a sample complexity guarantee for Algorithm 1:
Theorem 1.
Assume that $\sigma$ satisfies Assumption 1 and that $k^\star \ge 3$, and let $\lambda \le O(d^{1/4})$. For $t \le T_1$ set $\lambda_t = \lambda$ and $\eta_t = \eta$, and for $t > T_1$ set $\lambda_t = 0$ and decay the learning rate as $\eta_t \propto 1/(t - T_1)$. Then for suitable choices of $\eta$ and $T_1$, if
$$T = \tilde{O}\!\left(d^{k^\star - 1}\,(1 + \lambda^2)^{-(k^\star - 2)} + \frac{d}{\epsilon}\right),$$
with high probability the final iterate $w_T$ of Algorithm 1 satisfies $L(w_T) \le \epsilon$.
Theorem 1 uses large smoothing (up to $\lambda = \Theta(d^{1/4})$) to rapidly escape the regime in which $\alpha_t := w_t \cdot w^\star \approx d^{-1/2}$. This first stage continues until $\alpha_t = \Theta(1)$, which takes $\tilde{O}(d^{k^\star/2})$ steps when $\lambda = \Theta(d^{1/4})$. The second stage, in which $\lambda_t = 0$ and the learning rate decays linearly, lasts for an additional $\tilde{O}(d/\epsilon)$ steps where $\epsilon$ is the target accuracy. Because Algorithm 1 uses each sample exactly once, this gives the sample complexity
$$n = \tilde{O}\!\left(d^{k^\star - 1}\,(1 + \lambda^2)^{-(k^\star - 2)} + \frac{d}{\epsilon}\right)$$
to reach population loss $\epsilon$. Setting $\lambda = 0$ is equivalent to zero smoothing and gives a sample complexity of $\tilde{O}(d^{k^\star - 1})$, which matches the results of Ben Arous et al. [1]. On the other hand, setting $\lambda$ to the maximal allowable value of $\Theta(d^{1/4})$ gives:
$$n = \tilde{O}\!\left(d^{k^\star/2} + \frac{d}{\epsilon}\right),$$
which matches the sum of the CSQ lower bound, which is $d^{k^\star/2}$, and the information-theoretic lower bound, which is $d/\epsilon$, up to poly-logarithmic factors.
To complement Theorem 1, we replicate the CSQ lower bound in [6] for the specific function class $\mathcal{F}_k := \{x \mapsto h_k(w \cdot x)/\sqrt{k!} : w \in S^{d-1}\}$, where $h_k$ is the $k$th Hermite polynomial. Statistical query learners are a family of learners that can query values $\mathbb{E}_{x,y}[q(x, y)]$ and receive outputs $v$ with $|v - \mathbb{E}_{x,y}[q(x, y)]| \le \tau$, where $\tau$ denotes the query tolerance [27, 28]. An important class of statistical query learners is that of correlational/inner product statistical queries (CSQ), of the form $q(x, y) = y\,h(x)$. This includes a wide class of algorithms, including gradient descent with the square loss and the correlation loss.
Theorem 2 (CSQ Lower Bound).
Consider the function class $\mathcal{F}_k$. Any CSQ algorithm using $q$ queries requires a tolerance of at most
$$\tau \lesssim \left(\frac{\log(q d)}{d}\right)^{k/4}$$
to output an $f \in \mathcal{F}_k$ with population loss less than $\frac{1}{2}$.
Using the standard heuristic $\tau \approx n^{-1/2}$, which comes from concentration, this implies that $n \gtrsim d^{k/2}$ samples are necessary to learn $\mathcal{F}_k$ unless the algorithm makes exponentially many queries. In the context of gradient descent, this is equivalent to requiring either exponentially many parameters or exponentially many steps of gradient descent.
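To spell the heuristic out: a correlational query $q(x, y) = y\,h(x)$ answered from $n$ i.i.d. samples concentrates as
$$\left|\frac{1}{n}\sum_{i=1}^n y_i\,h(x_i) - \mathbb{E}_{x,y}\big[y\,h(x)\big]\right| \lesssim \frac{1}{\sqrt{n}} \quad \text{with high probability},$$
so a tolerance-$\tau$ query can be simulated with roughly $n \approx \tau^{-2}$ samples. Combined with the tolerance bound of Theorem 2, any CSQ learner making polynomially many queries needs $n \gtrsim \tau^{-2} \gtrsim d^{k/2}$ samples, up to logarithmic factors. (This is a back-of-the-envelope translation between tolerance and sample size, not a formal reduction.)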
5 Proof Sketch
In this section we highlight the key ideas of the proof of Theorem 1. The full proof is deferred to Appendix B. The proof sketch is broken into three parts. First, we conduct a general analysis of online SGD to show how the signal-to-noise ratio (SNR) affects the sample complexity. Next, we compute the SNR for the unsmoothed objective ($\lambda = 0$) to heuristically re-derive the $d^{k^\star - 1}$ sample complexity of Ben Arous et al. [1]. Finally, we show how smoothing boosts the SNR and leads to an improved sample complexity of $\tilde{O}(d^{k^\star/2})$ when $\lambda = \Theta(d^{1/4})$.
5.1 Online SGD Analysis
To begin, we will analyze a single step of online SGD. We define $\alpha_t := w_t \cdot w^\star$ so that $\alpha_t$ measures our current progress. Furthermore, let $\ell_t(w) := \ell_\lambda(w; x_t, y_t)$ denote the smoothed loss on the fresh sample $(x_t, y_t)$. Recall that the online SGD update is:
$$w_{t+1} = \frac{w_t - \eta\,\nabla \ell_t(w_t)}{\|w_t - \eta\,\nabla \ell_t(w_t)\|}.$$
Using the fact that $\nabla \ell_t(w_t) \perp w_t$ and $\|w_t\| = 1$, we can Taylor expand the update for $\alpha_{t+1}$:
$$\alpha_{t+1} \approx \alpha_t - \eta\,\nabla \ell_t(w_t) \cdot w^\star - \frac{\eta^2}{2}\,\alpha_t\,\|\nabla \ell_t(w_t)\|^2 + \cdots.$$
As in Ben Arous et al. [1], we decompose this update into a drift term and a martingale term. Let $\mathcal{F}_t$ be the natural filtration generated by the samples up to time $t$. We focus on the drift term, as the martingale term can be handled with standard concentration arguments. Taking expectations with respect to the fresh batch $(x_t, y_t)$ gives:
$$\mathbb{E}[\alpha_{t+1} \mid \mathcal{F}_t] \approx \alpha_t - \eta\,\nabla L_\lambda(w_t) \cdot w^\star - \frac{\eta^2}{2}\,\alpha_t\,\mathbb{E}\big[\|\nabla \ell_t(w_t)\|^2 \,\big|\, \mathcal{F}_t\big],$$
so to guarantee a positive drift, we need to set $\eta \lesssim \frac{-\nabla L_\lambda(w_t) \cdot w^\star}{\alpha_t\, \mathbb{E}[\|\nabla \ell_t(w_t)\|^2]}$, which gives us the value of $\eta$ used in Theorem 1 for $t \le T_1$. However, to simplify the proof sketch we can assume knowledge of $\alpha_t$ and $\lambda$ and optimize over $\eta$ to get a maximum drift of
$$\frac{\big(\nabla L_\lambda(w_t) \cdot w^\star\big)^2}{2\,\alpha_t\,\mathbb{E}\big[\|\nabla \ell_t(w_t)\|^2\big]}.$$
The numerator measures the correlation of the population gradient with $w^\star$, while the denominator measures the norm of the noisy gradient. Their ratio thus has a natural interpretation as the signal-to-noise ratio (SNR). Note that the SNR is a local property, i.e. the SNR can vary for different $w$. When the SNR can be written as a function of $\alpha$, it directly dictates the rate of optimization through the ODE approximation $\frac{d\alpha}{dt} = \mathrm{SNR}(\alpha)$. As online SGD uses each sample exactly once, the sample complexity for online SGD can be approximated by the time it takes this ODE to reach $\alpha \approx 1$ from $\alpha_0 \approx d^{-1/2}$. The remainder of the proof sketch will therefore focus on analyzing the SNR of the minibatch gradient $\nabla \ell_t(w_t)$.
5.2 Computing the Rate with Zero Smoothing
When $\lambda = 0$, the signal and noise terms can easily be calculated. The key property we need is:
Property 1 (Orthogonality Property).
Let $w, w' \in S^{d-1}$ and let $x \sim N(0, I_d)$. Then:
$$\mathbb{E}_x\big[h_j(w \cdot x)\,h_k(w' \cdot x)\big] = \delta_{jk}\,k!\,(w \cdot w')^k.$$
Using Property 1 and the Hermite expansion of $\sigma$ (Definition 3), we can directly compute the population loss and gradient. Letting $P_w^\perp := I - ww^\top$ denote the projection onto the subspace orthogonal to $w$, we have:
$$L(w) = 1 - \sum_{k \ge k^\star} \frac{c_k^2}{k!}\,\alpha^k \qquad \text{and} \qquad -\nabla L(w) = \sum_{k \ge k^\star} \frac{c_k^2}{(k-1)!}\,\alpha^{k-1}\,P_w^\perp w^\star.$$
As $\alpha \ll 1$ throughout most of the trajectory, the gradient is dominated by the first nonzero Hermite coefficient, so up to constants $-\nabla L(w) \approx \alpha^{k^\star - 1}\,P_w^\perp w^\star$. Similarly, a standard concentration argument shows that because $\nabla \ell(w; x, y)$ is a random vector in $d$ dimensions where each coordinate is $O(1)$, $\mathbb{E}\|\nabla \ell(w; x, y)\|^2 \asymp d$. Therefore the SNR is equal to $\frac{\alpha^{2(k^\star - 1)}}{d\,\alpha}$, so with an optimal learning rate schedule,
$$\mathbb{E}[\alpha_{t+1} \mid \mathcal{F}_t] \approx \alpha_t + c\,\frac{\alpha_t^{2k^\star - 3}}{d}.$$
This can be approximated by the ODE $\frac{d\alpha}{dt} = c\,\frac{\alpha^{2k^\star - 3}}{d}$. Solving this ODE with the initial condition $\alpha_0 \asymp d^{-1/2}$ gives that the escape time is proportional to $d\,\alpha_0^{-(2k^\star - 4)} = d^{k^\star - 1}$, which heuristically re-derives the result of Ben Arous et al. [1].
5.3 How Smoothing Boosts the SNR
Smoothing improves the sample complexity of online SGD by boosting the SNR of the stochastic gradient $\nabla \ell_\lambda(w; x, y)$. Recall that the population loss was approximately equal to $L(w) \approx 1 - \frac{c_{k^\star}^2}{k^\star!}\,\alpha^{k^\star}$, where $c_{k^\star}$ is the first nonzero Hermite coefficient of $\sigma$. Isolating the dominant term $\alpha^{k^\star}$ and applying the smoothing operator $\mathcal{L}_\lambda$, we get:
$$(\mathcal{L}_\lambda\,\alpha^{k^\star})(w) = \mathbb{E}_{z \sim \mu_w}\left[\left(\frac{(w + \lambda z) \cdot w^\star}{\|w + \lambda z\|}\right)^{k^\star}\right].$$
Because $z \perp w$ and $\|z\| = \|w\| = 1$, we have that $\|w + \lambda z\| = \sqrt{1 + \lambda^2}$. Therefore,
$$(\mathcal{L}_\lambda\,\alpha^{k^\star})(w) = (1 + \lambda^2)^{-k^\star/2}\sum_{j=0}^{k^\star}\binom{k^\star}{j}\,\alpha^{k^\star - j}\,\lambda^j\,\mathbb{E}_{z \sim \mu_w}\big[(z \cdot w^\star)^j\big].$$
Now because $\mu_w$ is symmetric, the terms where $j$ is odd disappear. Furthermore, for a random unit vector $z \perp w$, $\mathbb{E}[(z \cdot w^\star)^{2j}] \approx d^{-j}$. Therefore, reindexing and ignoring all constants, we have that
$$(\mathcal{L}_\lambda\,\alpha^{k^\star})(w) \approx (1 + \lambda^2)^{-k^\star/2}\sum_{0 \le j \le k^\star/2}\alpha^{k^\star - 2j}\left(\frac{\lambda^2}{d}\right)^{j}.$$
Differentiating gives that
$$\frac{d}{d\alpha}\,(\mathcal{L}_\lambda\,\alpha^{k^\star}) \approx (1 + \lambda^2)^{-k^\star/2}\sum_{0 \le j < k^\star/2}\alpha^{k^\star - 2j - 1}\left(\frac{\lambda^2}{d}\right)^{j}.$$
As this is a geometric series, it is either dominated by the first or the last term, depending on whether $\alpha \gtrsim \lambda/\sqrt{d}$ or $\alpha \lesssim \lambda/\sqrt{d}$. Furthermore, the last term is either $\big(\frac{\lambda^2}{d}\big)^{\frac{k^\star - 1}{2}}$ if $k^\star$ is odd or $\alpha\,\big(\frac{\lambda^2}{d}\big)^{\frac{k^\star - 2}{2}}$ if $k^\star$ is even. Therefore the signal term is:
$$\big|\nabla L_\lambda(w) \cdot w^\star\big| \asymp (1 + \lambda^2)^{-k^\star/2}\left[\alpha^{k^\star - 1} \vee \left(\frac{\lambda^2}{d}\right)^{\frac{k^\star - 1}{2}}\right] \quad (k^\star \text{ odd}),$$
with the second term replaced by $\alpha\,(\lambda^2/d)^{\frac{k^\star - 2}{2}}$ when $k^\star$ is even.
In addition, we can show that when $\lambda \gtrsim 1$, the noise term satisfies $\mathbb{E}\|\nabla \ell_\lambda(w; x, y)\|^2 \lesssim d\,(1 + \lambda^2)^{-k^\star}$. Note that in the high signal regime ($\alpha \gtrsim \lambda/\sqrt{d}$), both the signal and the noise are smaller by matching powers of $(1 + \lambda^2)$, which cancel when computing the SNR. However, when $\alpha \lesssim \lambda/\sqrt{d}$, the smoothing shrinks the noise faster than it shrinks the signal, resulting in an overall larger SNR. Explicitly, for $k^\star$ odd,
$$\mathrm{SNR}(\alpha) \asymp \begin{cases} \dfrac{\alpha^{2k^\star - 3}}{d} & \alpha \gtrsim \lambda/\sqrt{d}, \\[2mm] \dfrac{(1 + \lambda^2)^{k^\star - 1}}{d^{k^\star}\,\alpha} & \alpha \lesssim \lambda/\sqrt{d}. \end{cases}$$
For $\lambda \lesssim 1$, smoothing does not affect the SNR. However, when $\lambda \gg 1$, smoothing greatly increases the SNR in the low-correlation regime (see Figure 1).
Solving the ODE $\frac{d\alpha}{dt} = \mathrm{SNR}(\alpha)$ gives that it takes $T_1 \approx d^{k^\star - 1}(1 + \lambda^2)^{-(k^\star - 2)}$ steps to converge to $\alpha = \Theta(1)$ from $\alpha_0 \asymp d^{-1/2}$. Once $\alpha = \Theta(1)$, the problem is locally strongly convex, so we can decay the learning rate and use the classical analysis of SGD on strongly convex functions to show that with an additional $\tilde{O}(d/\epsilon)$ steps, $L(w_T) \le \epsilon$, from which Theorem 1 follows.
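As a numerical sanity check on this heuristic, the sketch below integrates the ODE using the two-regime SNR above (odd $k^\star$, all constants set to one). The dimensions, the choice $k^\star = 5$, and the adaptive step rule are illustrative assumptions; this checks the predicted scaling only, not the formal proof.

```python
import numpy as np

def escape_time(d, k, lam=0.0, growth=1e-3):
    """Integrate d(alpha)/dt = SNR(alpha) from alpha = d^{-1/2} to alpha = 1/2,
    using the two-regime heuristic SNR (odd k, all constants set to one):
      alpha^{2k-3} / d                   for alpha >= lam / sqrt(d)
      (1 + lam^2)^{k-1} / (d^k * alpha)  for alpha <  lam / sqrt(d)
    """
    alpha, t = 1.0 / np.sqrt(d), 0.0
    while alpha < 0.5:
        if alpha >= lam / np.sqrt(d):
            snr = alpha ** (2 * k - 3) / d
        else:
            snr = (1 + lam**2) ** (k - 1) / (d**k * alpha)
        dt = growth * alpha / snr   # step chosen so alpha grows by 0.1% per step
        alpha += snr * dt
        t += dt
    return t

k = 5
for d in [200, 400, 800]:
    print(d, escape_time(d, k),                 # no smoothing: scales like d^{k-1}
             escape_time(d, k, lam=d ** 0.25))  # lam = d^{1/4}: scales like d^{k/2}
```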
6 Experiments
We ran a minibatch variant of Algorithm 1 with batch size $B$ for a range of dimensions $d$ and information exponents $k^\star$, taking $\sigma = h_{k^\star}/\sqrt{k^\star!}$, the normalized $k^\star$th Hermite polynomial, and setting the learning rate and smoothing parameter as in Theorem 1.
We computed the number of samples $n$ required for Algorithm 1 to reach a fixed correlation threshold from a random initialization, and we report the min, mean, and max over random seeds. For each $k^\star$ we fit a power law of the form $n = c\,d^{a}$ in order to measure how the sample complexity scales with $d$. For all values of $k^\star$, we find that $a \approx k^\star/2$, which matches Theorem 1. The results can be found in Figure 2 and additional experimental details can be found in Appendix E.
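The power-law fit is ordinary least squares in log–log space. A minimal sketch, assuming the measured sample complexities are available as arrays (the numbers below are placeholders, not experimental results):

```python
import numpy as np

# Placeholder data: dimensions swept in the experiment and the measured
# mean sample complexity n(d) for one value of k*; replace with real logs.
ds = np.array([256, 512, 1024, 2048])
ns = np.array([2.1e5, 8.3e5, 3.4e6, 1.3e7])

# Fit n = c * d^a  <=>  log n = a log d + log c.
a, log_c = np.polyfit(np.log(ds), np.log(ns), 1)
print(f"fitted exponent a = {a:.2f}")   # Theorem 1 predicts a ~ k*/2
```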

7 Discussion
7.1 Tensor PCA
We next outline connections to the Tensor PCA problem. Introduced in [8], the goal of Tensor PCA is to recover the hidden direction $w^\star \in S^{d-1}$ from the noisy $k$-tensor given by
$$T_n := (w^\star)^{\otimes k} + \frac{1}{\sqrt{n}}\,Z$$
(this normalization is equivalent to the original normalization $T = \beta\,(w^\star)^{\otimes k} + Z$ by setting $\beta = \sqrt{n}$), where $Z$ is a Gaussian noise tensor with each entry drawn i.i.d. from $N(0, 1)$.
The Tensor PCA problem has garnered significant interest as it exhibits a statistical-computational gap. $w^\star$ is information theoretically recoverable when $n \gtrsim d$. However, the best polynomial-time algorithms require $n \gtrsim d^{k/2}$; this lower bound has been shown to be tight for various notions of hardness such as CSQ or SoS lower bounds [29, 24, 30, 31, 32, 33, 34]. Tensor PCA also exhibits a gap between spectral methods and iterative algorithms. Algorithms that work in the regime $n \asymp d^{k/2}$ rely on unfolding or contracting the tensor $T_n$, or on semidefinite programming relaxations [29, 24]. On the other hand, iterative algorithms including gradient descent, power method, and AMP require a much larger sample complexity of $n \gtrsim d^{k-1}$ [35]. The suboptimality of iterative algorithms is believed to be due to bad properties of the landscape of the Tensor PCA objective in the region around the initialization. Specifically, [36, 37] argue that there are exponentially many local minima near the equator when $n \ll d^{k-1}$. To overcome this, prior works have considered "smoothed" versions of gradient descent, and show that smoothing recovers the computationally optimal sample complexity $n \asymp d^{k/2}$ in the case $k = 3$ [25] and heuristically for larger $k$ [26].
7.1.1 The Partial Trace Algorithm
The smoothing algorithms above are inspired by the following partial trace algorithm for Tensor PCA [24], which can be viewed as Algorithm 1 in the limit $\lambda \to \infty$ [25]. Let $T_n = (w^\star)^{\otimes k} + \frac{1}{\sqrt{n}} Z$. Then we will consider iteratively contracting pairs of indices of $T_n$ with the identity until all that remains is a vector (if $k$ is odd) or a matrix (if $k$ is even). Explicitly, we define the partial trace tensor by
$$M := T_n\big(\underbrace{I_d, \ldots, I_d}_{\lfloor (k-1)/2 \rfloor \text{ times}}\big) \in \begin{cases} \mathbb{R}^d & k \text{ odd,} \\ \mathbb{R}^{d \times d} & k \text{ even.} \end{cases}$$
When $k$ is odd, we can directly return $M/\|M\|$ as our estimate for $w^\star$, and when $k$ is even we return the top eigenvector of $M$. A standard concentration argument shows that this succeeds when $n \gtrsim d^{\lceil k/2 \rceil}$. Furthermore, this can be strengthened to $n \gtrsim d^{k/2}$ by using the partial trace vector as a warm start for gradient descent or tensor power method when $k$ is odd [25, 26].
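A minimal sketch of the partial trace estimator under the normalization $T_n = (w^\star)^{\otimes k} + Z/\sqrt{n}$ above, shown for $k = 3$; the dimension, sample size, and seed are arbitrary illustrative choices. `np.einsum('ii...->...')` performs one index contraction against the identity:

```python
import numpy as np

def partial_trace_estimate(T):
    """Contract pairs of indices of a k-tensor with the identity until a
    vector (k odd) or matrix (k even) remains, then extract the estimate."""
    while T.ndim > 2:
        T = np.einsum('ii...->...', T)          # contract the first two indices
    if T.ndim == 1:                             # k odd: normalize the vector
        return T / np.linalg.norm(T)
    eigvals, eigvecs = np.linalg.eigh(T)        # k even: top eigenvector
    return eigvecs[:, np.argmax(eigvals)]

# Example: k = 3 spiked tensor in d dimensions.
d, n = 50, 20 * 50**2                            # n ~ d^2 = d^{ceil(k/2)}
rng = np.random.default_rng(0)
w_star = rng.standard_normal(d)
w_star /= np.linalg.norm(w_star)
T = np.einsum('i,j,k->ijk', w_star, w_star, w_star)
T += rng.standard_normal((d, d, d)) / np.sqrt(n)
w_hat = partial_trace_estimate(T)
print(abs(w_hat @ w_star))                       # close to 1 when n >> d^2
```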
7.1.2 The Connection Between Single Index Models and Tensor PCA
For both tensor PCA and learning single index models, gradient descent succeeds when the sample complexity is $n \gtrsim d^{k-1}$ (respectively $d^{k^\star - 1}$) [35, 1]. On the other hand, the smoothing algorithms for Tensor PCA [26, 25] succeed with the computationally optimal sample complexity of $n \asymp d^{k/2}$. Our Theorem 1 shows that this smoothing analysis can indeed be transferred to the single-index model setting.
In fact, one can make a direct connection between learning single-index models with Gaussian covariates and Tensor PCA. Consider learning a single-index model when $\sigma = h_k/\sqrt{k!}$, the normalized $k$th Hermite polynomial. Then minimizing the correlation loss is equivalent to maximizing the loss function:
$$\frac{1}{n}\sum_{i=1}^n y_i\,f_w(x_i) = \big\langle \hat{T},\; w^{\otimes k}\big\rangle \qquad \text{where} \qquad \hat{T} := \frac{1}{n\sqrt{k!}}\sum_{i=1}^n y_i\,\mathrm{He}_k(x_i).$$
Here $\mathrm{He}_k(x)$ denotes the $k$th Hermite tensor (see Section A.2 for background on Hermite polynomials and Hermite tensors). In addition, by the orthogonality of the Hermite tensors, $\mathbb{E}[\hat{T}] = (w^\star)^{\otimes k}$, so we can decompose $\hat{T} = (w^\star)^{\otimes k} + Z$ where, by standard concentration, each entry of $Z$ is of order $n^{-1/2}$. We can therefore directly apply algorithms for Tensor PCA to this problem. We remark that this connection is a heuristic, as the structures of the noise in Tensor PCA and in our single index model setting are different.
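To make the reduction concrete for $k = 3$: $\mathrm{He}_3(x)_{abc} = x_a x_b x_c - (x_a\delta_{bc} + x_b\delta_{ac} + x_c\delta_{ab})$, so the empirical tensor $\hat{T}$ above can be formed directly and handed to any Tensor PCA routine, such as the partial trace contraction. A sketch under the same normalization (sizes are illustrative):

```python
import numpy as np

def he3(x):
    """Third Hermite tensor: He_3(x)_{abc} = x_a x_b x_c - (x_a d_bc + x_b d_ac + x_c d_ab)."""
    I = np.eye(x.shape[0])
    return (np.einsum('a,b,c->abc', x, x, x)
            - np.einsum('a,bc->abc', x, I)
            - np.einsum('b,ac->abc', x, I)
            - np.einsum('c,ab->abc', x, I))

d, n = 30, 40 * 30**2
rng = np.random.default_rng(0)
w_star = rng.standard_normal(d)
w_star /= np.linalg.norm(w_star)

X = rng.standard_normal((n, d))
s = X @ w_star
y = (s**3 - 3 * s) / np.sqrt(6.0)           # sigma = h_3 / sqrt(3!)

T_hat = np.zeros((d, d, d))
for xi, yi in zip(X, y):
    T_hat += yi * he3(xi)
T_hat /= n * np.sqrt(6.0)                   # E[T_hat] = (w*)^{tensor 3}

v = np.einsum('iij->j', T_hat)              # partial trace down to a vector
print(abs(v @ w_star) / np.linalg.norm(v))  # close to 1 for n >> d^2
```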
7.2 Empirical Risk Minimization on the Smoothed Landscape
Our main sample complexity guarantee, Theorem 1, is based on a tight analysis of online SGD (Algorithm 1) in which each sample is used exactly once. One might expect that if the algorithm were allowed to reuse samples, as is standard practice in deep learning, the algorithm could succeed with fewer samples. In particular, Abbe et al. [4] conjectured that gradient descent on the empirical loss would succeed with $n = \tilde{\Theta}(d^{k^\star/2})$ samples.
Our smoothing algorithm, Algorithm 1, can be directly translated to the ERM setting to learn with $n \asymp d^{k^\star/2}$ samples. Taylor expanding the smoothed empirical loss in the large-$\lambda$ limit shows that the dominant term is the correlation of $w$ with the partial trace of the empirical Hermite tensor $\hat{T}$ defined in Section 7.1.2.
As $\lambda \to \infty$, gradient descent on this smoothed loss will converge to the direction of the partial trace vector, which is equivalent to the partial trace estimator for odd $k$ (see Section 7.1). If $k$ is even, this first term is zero in expectation and gradient descent will converge to the top eigenvector of the partial trace matrix, which corresponds to the partial trace estimator for even $k$. Mirroring the calculation for the partial trace estimator, this succeeds with $n \gtrsim d^{\lceil k/2 \rceil}$ samples. When $k$ is odd, this can be further improved to $n \gtrsim d^{k/2}$ by using this estimator as a warm start from which to run gradient descent with $\lambda = 0$, as in Anandkumar et al. [25], Biroli et al. [26].
7.3 Connection to Minibatch SGD
A recent line of work has studied the implicit regularization effect of stochastic gradient descent [9, 11, 10]. The key idea is that over short timescales, the iterates converge to a quasi-stationary distribution $N(\theta, \lambda S)$, where $S$ depends on the Hessian and the noise covariance at $\theta$ and $\lambda$ is proportional to the ratio of the learning rate and batch size. As a result, over longer periods of time SGD follows the smoothed gradient of the empirical loss:
$$\mathbb{E}_{v \sim N(0, \lambda S)}\big[\nabla \hat{L}(\theta + v)\big].$$
We therefore conjecture that minibatch SGD is also able to achieve the optimal sample complexity without explicit smoothing if the learning rate and batch size are properly tuned.
8 Acknowledgements
AD acknowledges support from an NSF Graduate Research Fellowship. EN acknowledges support from a National Defense Science & Engineering Graduate Fellowship. RG is supported by NSF Award DMS-2031849, CCF-1845171 (CAREER), CCF-1934964 (Tripods) and a Sloan Research Fellowship. AD, EN, and JDL acknowledge support of the ARO under MURI Award W911NF-11-1-0304, the Sloan Research Fellowship, NSF CCF 2002272, NSF IIS 2107304, NSF CIF 2212262, ONR Young Investigator Award, and NSF CAREER Award 2144994.
References
- Ben Arous et al. [2021] Gerard Ben Arous, Reza Gheissari, and Aukosh Jagannath. Online stochastic gradient descent on non-convex losses from high-dimensional inference. The Journal of Machine Learning Research, 22(1):4788–4838, 2021.
- Ge et al. [2016] Rong Ge, Jason D Lee, and Tengyu Ma. Matrix completion has no spurious local minimum. Advances in neural information processing systems, 29, 2016.
- Ma [2020] Tengyu Ma. Why do local methods solve nonconvex problems?, 2020.
- Abbe et al. [2023] Emmanuel Abbe, Enric Boix-Adserà, and Theodor Misiakiewicz. Sgd learning on neural networks: leap complexity and saddle-to-saddle dynamics. arXiv, 2023. URL https://arxiv.org/abs/2302.11055.
- Bietti et al. [2022] Alberto Bietti, Joan Bruna, Clayton Sanford, and Min Jae Song. Learning single-index models with shallow neural networks. In Advances in Neural Information Processing Systems (NeurIPS), 2022.
- Damian et al. [2022] Alexandru Damian, Jason Lee, and Mahdi Soltanolkotabi. Neural networks can learn representations with gradient descent. In Conference on Learning Theory, pages 5413–5452. PMLR, 2022.
- Mei et al. [2018] Song Mei, Yu Bai, and Andrea Montanari. The landscape of empirical risk for nonconvex losses. The Annals of Statistics, 46:2747–2774, 2018.
- Richard and Montanari [2014] Emile Richard and Andrea Montanari. A statistical model for tensor pca. In Advances in Neural Information Processing Systems, pages 2897–2905, 2014.
- Blanc et al. [2020] Guy Blanc, Neha Gupta, Gregory Valiant, and Paul Valiant. Implicit regularization for deep neural networks driven by an ornstein-uhlenbeck like process. In Conference on Learning Theory, pages 483–513, 2020.
- Damian et al. [2021] Alex Damian, Tengyu Ma, and Jason D. Lee. Label noise SGD provably prefers flat global minimizers. In A. Beygelzimer, Y. Dauphin, P. Liang, and J. Wortman Vaughan, editors, Advances in Neural Information Processing Systems, 2021.
- Li et al. [2022] Zhiyuan Li, Tianhao Wang, and Sanjeev Arora. What happens after SGD reaches zero loss? –a mathematical framework. In International Conference on Learning Representations, 2022.
- Shallue et al. [2018] Christopher J Shallue, Jaehoon Lee, Joseph Antognini, Jascha Sohl-Dickstein, Roy Frostig, and George E Dahl. Measuring the effects of data parallelism on neural network training. arXiv preprint arXiv:1811.03600, 2018.
- Szegedy et al. [2016] Christian Szegedy, Vincent Vanhoucke, Sergey Ioffe, Jon Shlens, and Zbigniew Wojna. Rethinking the inception architecture for computer vision. In Proceedings of the IEEE conference on computer vision and pattern recognition, pages 2818–2826, 2016.
- Wen et al. [2019] Yeming Wen, Kevin Luk, Maxime Gazeau, Guodong Zhang, Harris Chan, and Jimmy Ba. Interplay between optimization and generalization of stochastic gradient descent with covariance noise. arXiv preprint arXiv:1902.08234, 2019.
- Kakade et al. [2011] Sham M Kakade, Varun Kanade, Ohad Shamir, and Adam Kalai. Efficient learning of generalized linear and single index models with isotonic regression. Advances in Neural Information Processing Systems, 24, 2011.
- Soltanolkotabi [2017] Mahdi Soltanolkotabi. Learning relus via gradient descent. In Advances in Neural Information Processing Systems (NeurIPS), 2017.
- Candès et al. [2015] Emmanuel J. Candès, Xiaodong Li, and Mahdi Soltanolkotabi. Phase retrieval via wirtinger flow: Theory and algorithms. IEEE Transactions on Information Theory, 61(4):1985–2007, 2015. doi: 10.1109/TIT.2015.2399924.
- Chen et al. [2019] Yuxin Chen, Yuejie Chi, Jianqing Fan, and Cong Ma. Gradient descent with random initialization: fast global convergence for nonconvex phase retrieval. Mathematical Programming, 176(1):5–37, 2019.
- Sun et al. [2018] Ju Sun, Qing Qu, and John Wright. A geometric analysis of phase retrieval. Foundations of Computational Mathematics, 18(5):1131–1198, 2018.
- Dudeja and Hsu [2018] Rishabh Dudeja and Daniel Hsu. Learning single-index models in gaussian space. In Sébastien Bubeck, Vianney Perchet, and Philippe Rigollet, editors, Proceedings of the 31st Conference On Learning Theory, volume 75 of Proceedings of Machine Learning Research, pages 1887–1930. PMLR, 06–09 Jul 2018. URL https://proceedings.mlr.press/v75/dudeja18a.html.
- Chen and Meka [2020] Sitan Chen and Raghu Meka. Learning polynomials in few relevant dimensions. In Jacob Abernethy and Shivani Agarwal, editors, Proceedings of Thirty Third Conference on Learning Theory, volume 125 of Proceedings of Machine Learning Research, pages 1161–1227. PMLR, 09–12 Jul 2020. URL https://proceedings.mlr.press/v125/chen20a.html.
- Ba et al. [2022] Jimmy Ba, Murat A. Erdogdu, Taiji Suzuki, Zhichao Wang, Denny Wu, and Greg Yang. High-dimensional asymptotics of feature learning: How one gradient step improves the representation. In Advances in Neural Information Processing Systems (NeurIPS), 2022.
- Abbe et al. [2022] Emmanuel Abbe, Enric Boix Adsera, and Theodor Misiakiewicz. The merged-staircase property: a necessary and nearly sufficient condition for sgd learning of sparse functions on two-layer neural networks. In Conference on Learning Theory, pages 4782–4887. PMLR, 2022.
- Hopkins et al. [2016] Samuel B. Hopkins, Tselil Schramm, Jonathan Shi, and David Steurer. Fast spectral algorithms from sum-of-squares proofs: Tensor decomposition and planted sparse vectors. In Proceedings of the Forty-Eighth Annual ACM Symposium on Theory of Computing, STOC ’16, page 178–191, New York, NY, USA, 2016. Association for Computing Machinery. ISBN 9781450341325. doi: 10.1145/2897518.2897529. URL https://doi.org/10.1145/2897518.2897529.
- Anandkumar et al. [2017] Anima Anandkumar, Yuan Deng, Rong Ge, and Hossein Mobahi. Homotopy analysis for tensor pca. In Satyen Kale and Ohad Shamir, editors, Proceedings of the 2017 Conference on Learning Theory, volume 65 of Proceedings of Machine Learning Research, pages 79–104. PMLR, 07–10 Jul 2017. URL https://proceedings.mlr.press/v65/anandkumar17a.html.
- Biroli et al. [2020] Giulio Biroli, Chiara Cammarota, and Federico Ricci-Tersenghi. How to iron out rough landscapes and get optimal performances: averaged gradient descent and its application to tensor pca. Journal of Physics A: Mathematical and Theoretical, 53(17):174003, apr 2020. doi: 10.1088/1751-8121/ab7b1f. URL https://dx.doi.org/10.1088/1751-8121/ab7b1f.
- Goel et al. [2020] Surbhi Goel, Aravind Gollakota, Zhihan Jin, Sushrut Karmalkar, and Adam Klivans. Superpolynomial lower bounds for learning one-layer neural networks using gradient descent. arXiv preprint arXiv:2006.12011, 2020.
- Diakonikolas et al. [2020] Ilias Diakonikolas, Daniel M Kane, Vasilis Kontonis, and Nikos Zarifis. Algorithms and sq lower bounds for pac learning one-hidden-layer relu networks. In Conference on Learning Theory, pages 1514–1539, 2020.
- Hopkins et al. [2015] Samuel B. Hopkins, Jonathan Shi, and David Steurer. Tensor principal component analysis via sum-of-square proofs. In Peter Grünwald, Elad Hazan, and Satyen Kale, editors, Proceedings of The 28th Conference on Learning Theory, volume 40 of Proceedings of Machine Learning Research, pages 956–1006, Paris, France, 03–06 Jul 2015. PMLR. URL https://proceedings.mlr.press/v40/Hopkins15.html.
- Kunisky et al. [2019] Dmitriy Kunisky, Alexander S. Wein, and Afonso S. Bandeira. Notes on computational hardness of hypothesis testing: Predictions using the low-degree likelihood ratio, 2019.
- Bandeira et al. [2022] Afonso S Bandeira, Ahmed El Alaoui, Samuel Hopkins, Tselil Schramm, Alexander S Wein, and Ilias Zadik. The franz-parisi criterion and computational trade-offs in high dimensional statistics. In S. Koyejo, S. Mohamed, A. Agarwal, D. Belgrave, K. Cho, and A. Oh, editors, Advances in Neural Information Processing Systems, volume 35, pages 33831–33844. Curran Associates, Inc., 2022.
- Brennan et al. [2021] Matthew S Brennan, Guy Bresler, Sam Hopkins, Jerry Li, and Tselil Schramm. Statistical query algorithms and low degree tests are almost equivalent. In Mikhail Belkin and Samory Kpotufe, editors, Proceedings of Thirty Fourth Conference on Learning Theory, volume 134 of Proceedings of Machine Learning Research, pages 774–774. PMLR, 15–19 Aug 2021. URL https://proceedings.mlr.press/v134/brennan21a.html.
- Dudeja and Hsu [2021] Rishabh Dudeja and Daniel Hsu. Statistical query lower bounds for tensor pca. Journal of Machine Learning Research, 22(83):1–51, 2021. URL http://jmlr.org/papers/v22/20-837.html.
- Dudeja and Hsu [2022] Rishabh Dudeja and Daniel Hsu. Statistical-computational trade-offs in tensor pca and related problems via communication complexity, 2022.
- Ben Arous et al. [2020] Gérard Ben Arous, Reza Gheissari, and Aukosh Jagannath. Algorithmic thresholds for tensor PCA. The Annals of Probability, 48(4):2052–2087, 2020. doi: 10.1214/19-AOP1415. URL https://doi.org/10.1214/19-AOP1415.
- Ros et al. [2019] Valentina Ros, Gerard Ben Arous, Giulio Biroli, and Chiara Cammarota. Complex energy landscapes in spiked-tensor and simple glassy models: Ruggedness, arrangements of local minima, and phase transitions. Phys. Rev. X, 9:011003, Jan 2019. doi: 10.1103/PhysRevX.9.011003. URL https://link.aps.org/doi/10.1103/PhysRevX.9.011003.
- Ben Arous et al. [2019] Gérard Ben Arous, Song Mei, Andrea Montanari, and Mihai Nica. The landscape of the spiked tensor model. Communications on Pure and Applied Mathematics, 72(11):2282–2330, 2019. doi: https://doi.org/10.1002/cpa.21861. URL https://onlinelibrary.wiley.com/doi/abs/10.1002/cpa.21861.
- Szörényi [2009] Balázs Szörényi. Characterizing statistical query learning: Simplified notions and proofs. In ALT, 2009.
- Pinelis [1994] Iosif Pinelis. Optimum Bounds for the Distributions of Martingales in Banach Spaces. The Annals of Probability, 22(4):1679 – 1706, 1994. doi: 10.1214/aop/1176988477. URL https://doi.org/10.1214/aop/1176988477.
- Bradbury et al. [2018] James Bradbury, Roy Frostig, Peter Hawkins, Matthew James Johnson, Chris Leary, Dougal Maclaurin, George Necula, Adam Paszke, Jake VanderPlas, Skye Wanderman-Milne, and Qiao Zhang. JAX: composable transformations of Python+NumPy programs, 2018. URL http://github.com/google/jax.
- Biewald [2020] Lukas Biewald. Experiment tracking with weights and biases, 2020. URL https://www.wandb.com/. Software available from wandb.com.
Appendix A Background and Notation
A.1 Tensor Notation
Throughout this section let $T \in (\mathbb{R}^d)^{\otimes k}$ be a $k$-tensor.
Definition 5 (Tensor Action).
For a tensor $A \in (\mathbb{R}^d)^{\otimes j}$ with $j \le k$, we define the action of $T$ on $A$ by
$$T(A)_{i_{j+1}, \ldots, i_k} := \sum_{i_1, \ldots, i_j} T_{i_1, \ldots, i_k}\, A_{i_1, \ldots, i_j}.$$
We will also use $\langle T, A \rangle$ to denote $T(A)$ when $T, A$ are both $k$-tensors. Note that this corresponds to the standard dot product after flattening $T$ and $A$.
Definition 6 (Permutation/Transposition).
Given a $k$-tensor $T$ and a permutation $\pi \in S_k$, we use $T^\pi$ to denote the result of permuting the axes of $T$ by the permutation $\pi$, i.e.
$$(T^\pi)_{i_1, \ldots, i_k} := T_{i_{\pi(1)}, \ldots, i_{\pi(k)}}.$$
Definition 7 (Symmetrization).
We define $\mathrm{Sym}_k := \frac{1}{k!}\sum_{\pi \in S_k} \pi$,
where $S_k$ is the symmetric group on $k$ elements. Note that $\mathrm{Sym}_k$ acts on tensors by
$$\mathrm{Sym}_k(T) = \frac{1}{k!}\sum_{\pi \in S_k} T^\pi,$$
i.e. $\mathrm{Sym}_k(T)$ is the symmetrized version of $T$.
We will also overload notation and use $\mathrm{Sym}$ to denote the symmetrization operator, i.e. if $T$ is a $k$-tensor, $\mathrm{Sym}(T) := \mathrm{Sym}_k(T)$.
Definition 8 (Symmetric Tensor Product).
For a $j$-tensor $A$ and a $k$-tensor $B$ we define the symmetric tensor product of $A$ and $B$ by
$$A \,\widetilde{\otimes}\, B := \mathrm{Sym}(A \otimes B).$$
Lemma 1.
For any tensor $T$, $\|\mathrm{Sym}(T)\|_F \le \|T\|_F$.
Proof.
$$\|\mathrm{Sym}(T)\|_F = \Big\|\frac{1}{k!}\sum_{\pi \in S_k} T^\pi\Big\|_F \le \frac{1}{k!}\sum_{\pi \in S_k}\|T^\pi\|_F = \|T\|_F,$$
because permuting the indices of $T$ does not change the Frobenius norm. ∎
We will use the following two lemmas for tensor moments of the Gaussian distribution and the uniform distribution over the sphere:
Definition 9.
For an integer $k \ge 1$, define the quantity $d_k$ as
$$d_k := d(d+2)\cdots(d + 2(k-1)).$$
Note that $d_k \asymp d^k$.
Lemma 2 (Tensorized Moments).
$$\mathbb{E}_{x \sim N(0, I_d)}\big[x^{\otimes 2k}\big] = (2k-1)!!\,\mathrm{Sym}\big(I_d^{\otimes k}\big) \qquad \text{and} \qquad \mathbb{E}_{z \sim \mathrm{Unif}(S^{d-1})}\big[z^{\otimes 2k}\big] = \frac{(2k-1)!!}{d_k}\,\mathrm{Sym}\big(I_d^{\otimes k}\big).$$
Proof.
For the Gaussian moment, see [6]. The spherical moment follows from the decomposition $x = rz$ where $r := \|x\|$ is independent of $z := x/\|x\|$ and $\mathbb{E}[r^{2k}] = d_k$. ∎
Lemma 3.
Proof.
A.2 Hermite Polynomials and Hermite Tensors
We provide a brief review of the properties of Hermite polynomials and Hermite tensors.
Definition 10.
We define the $k$th Hermite tensor $\mathrm{He}_k : \mathbb{R}^d \to (\mathbb{R}^d)^{\otimes k}$ by
$$\mathrm{He}_k(x) := (-1)^k\,\frac{\nabla^k \mu_d(x)}{\mu_d(x)},$$
where $\mu_d(x) = (2\pi)^{-d/2}\,e^{-\|x\|^2/2}$ is the PDF of a standard Gaussian in $d$ dimensions. Note that when $d = 1$, this definition reduces to the standard univariate Hermite polynomials.
We begin with the classical properties of the scalar Hermite polynomials:
Lemma 4 (Properties of Hermite Polynomials).
For all $j, k \ge 0$:
- Orthogonality: $\mathbb{E}_{x \sim N(0,1)}[h_j(x)\,h_k(x)] = k!\,\delta_{jk}$.
- Derivatives: $h_k'(x) = k\,h_{k-1}(x)$.
- Correlations: If $(x, y)$ are correlated standard Gaussians with correlation $\rho$, then $\mathbb{E}[h_j(x)\,h_k(y)] = k!\,\rho^k\,\delta_{jk}$.
- Hermite Expansion: If $f \in L^2(\mu)$ where $\mu$ is the PDF of a standard Gaussian, then
$$f(x) = \sum_{k \ge 0} \frac{\hat{f}_k}{k!}\,h_k(x) \qquad \text{where} \qquad \hat{f}_k := \mathbb{E}_{x \sim N(0,1)}\big[f(x)\,h_k(x)\big].$$
These properties also have tensor analogues:
Lemma 5 (Hermite Polynomials in Higher Dimensions).
- Relationship to Univariate Hermite Polynomials: If $w \in S^{d-1}$, then $\langle \mathrm{He}_k(x),\, w^{\otimes k}\rangle = h_k(w \cdot x)$.
- Orthogonality: For any $j$-tensor $A$ and $k$-tensor $B$:
$$\mathbb{E}_{x \sim N(0, I_d)}\big[\langle \mathrm{He}_j(x), A\rangle\,\langle \mathrm{He}_k(x), B\rangle\big] = \delta_{jk}\,k!\,\big\langle \mathrm{Sym}(A),\, \mathrm{Sym}(B)\big\rangle.$$
- Hermite Expansions: If $f \in L^2(N(0, I_d))$, then
$$f(x) = \sum_{k \ge 0}\frac{1}{k!}\,\big\langle C_k,\, \mathrm{He}_k(x)\big\rangle \qquad \text{where} \qquad C_k := \mathbb{E}_{x \sim N(0, I_d)}\big[f(x)\,\mathrm{He}_k(x)\big].$$
Appendix B Proof of Theorem 1
The proof of Theorem 1 is divided into four parts. First, Section B.1 introduces some notation that will be used throughout the proof. Next, Section B.2 computes matching upper and lower bounds for the gradient of the smoothed population loss. Similarly, Section B.3 concentrates the empirical gradient of the smoothed loss. Finally, Section B.4 combines the bounds in Section B.2 and Section B.3 with a standard online SGD analysis to arrive at the final rate.
B.1 Additional Notation
Throughout the proof we will assume that $w \in S^{d-1}$, so that $\nabla$ denotes the spherical gradient with respect to $w$. In particular, $\nabla g(w) \perp w$ for any differentiable $g$. We will also use $\partial$ to denote the standard Euclidean gradient so that we can write expressions such as $\nabla g(w) = (I - ww^\top)\,\partial g(w)$.
We will use the following assumption on without reference throughout the proof:
Assumption 2.
for a sufficiently large constant .
We note that this is satisfied for the optimal choice of .
We will use $\tilde{O}(\cdot)$ to hide polylogarithmic dependencies. Explicitly, $f = \tilde{O}(g)$ if there exist constants $C, c$ such that $f \le C\,g\,\log^c(d)$. We will also use the following shorthand for denoting high probability events:
Definition 11.
We say an event $E$ happens with high probability if for every $q > 0$ there exists a constant $C_q$ such that for all $d$, $\mathbb{P}[E] \ge 1 - C_q\,d^{-q}$.
Note that high probability events are closed under polynomially sized union bounds. As an example, if then with high probability because
for sufficiently large . In general, Lemma 24 shows that if is mean zero and has polynomial tails, i.e. there exists such that , then with high probability.
B.2 Computing the Smoothed Population Gradient
Recall that
In addition, because we assumed that $\mathbb{E}_{x \sim N(0,1)}[\sigma(x)^2] = 1$, we have Parseval's identity: $\sum_{k \ge 0} \frac{c_k^2}{k!} = 1$.
This Hermite decomposition immediately implies a closed form for the population loss:
Lemma 6 (Population Loss).
Let $\alpha = w \cdot w^\star$. Then,
$$L_\lambda(w) = 1 - \sum_{k \ge k^\star} \frac{c_k^2}{k!}\,\big(\mathcal{L}_\lambda\,\alpha^k\big)(w).$$
Lemma 6 implies that to understand the smoothed population loss $L_\lambda$, it suffices to understand $\mathcal{L}_\lambda(\alpha^k)$ for $k \ge k^\star$. First, we will show that the set of single index models is closed under the smoothing operator $\mathcal{L}_\lambda$:
Lemma 7.
Let and let . Then
where
Proof.
Expanding the definition of gives:
We now claim that when , where which would complete the proof. To see this, note that we can decompose into . Under this decomposition we have the polyspherical decomposition where . Then
∎
Of central interest are the quantities $\mathcal{L}_\lambda(\alpha^k)$, as these terms show up when smoothing the population loss (see Lemma 6). We begin by defining the quantity which will provide matching upper and lower bounds on $\mathcal{L}_\lambda(\alpha^k)$ when $\alpha \ge 0$:
Definition 12.
We define by
Lemma 8.
For all and , there exist constants such that
Proof.
Using Lemma 7 we have that
Now note that when is odd, so we can re-index this sum to get
Note that every term in this sum is non-negative. Now we can ignore constants depending on and use that to get
Now when , this is a decreasing geometric series which is dominated by the first term so . Next, when we have by 2 that so is bounded away from . Therefore the geometric series is dominated by the last term which is
which completes the proof. ∎
Next, in order to understand the population gradient, we need to understand how the smoothing operator commutes with differentiation. We note that these do not directly commute because the smoothing distribution depends on so this term must be differentiated as well. However, smoothing and differentiation almost commute, which is described in the following lemma:
Lemma 9.
Define the dimension-dependent univariate smoothing operator by:
Then,
Proof.
Now we are ready to analyze the population gradient:
Lemma 10.
where for ,
Proof.
Recall that
Because is the index of the first nonzero Hermite coefficient, we can start this sum at . Smoothing and differentiating gives:
We will break this into the term and the tail. First when we can use Lemma 9 and Lemma 8 to get:
The first term is equal up to constants to while the second term is equal up to constants to . However, we have that
Therefore the term in is equal up to constants to .
Next, we handle the tail. By Lemma 9 this is equal to
Now recall that from Lemma 8, is always non-negative so we can use to bound this tail in absolute value by
Now by Corollary 3, this is bounded for by
For the first term, we have
The second term is trivially bounded by
which completes the proof. ∎
B.3 Concentrating the Empirical Gradient
We cannot directly apply Lemma 7 to as . Instead, we will use the properties of the Hermite tensors to directly smooth .
Lemma 11.
where
Proof.
Lemma 12.
For any with ,
Proof.
Recall that by Lemma 11 we have
where
Differentiating this with respect to gives
Now note that by Lemma 5:
Therefore it suffices to compute the Frobenius norm of . We first explicitly differentiate :
Taking Frobenius norms gives
Now we can use Lemma 1 to pull out and get:
Now note that the terms in each sum are orthogonal as at least one will need to be contracted with a . Therefore this is equivalent to:
Next, note that for any tensor , . When , the only permutations that don’t give are the ones which pair up all of the s of which there are . Therefore, by Lemma 3,
Plugging this in gives:
Now note that
which completes the proof. ∎
Corollary 1.
For any with ,
Proof.
The following lemma shows that inherits polynomial tails from :
Lemma 13.
There exists an absolute constant such that for any with and any ,
Proof.
Finally, we can use Corollary 1 and Lemma 13 to bound the norms of the gradient:
Lemma 14.
Let be a fresh sample and let . Then there exists a constant such that for any with , any and all ,
Proof.
Corollary 2.
Let be as in Lemma 14. Then for all ,
Proof.
B.4 Analyzing the Dynamics
Throughout this section we will assume $k^\star \ge 3$. The proof of the dynamics is split into three stages.
In the first stage, we analyze the regime $\alpha_t \lesssim \lambda/\sqrt{d}$. In this regime, the signal is dominated by the smoothing.
In the second stage, we analyze the regime $\alpha_t \gtrsim \lambda/\sqrt{d}$. This analysis is similar to the analysis in Ben Arous et al. [1] and could be equivalently carried out with $\lambda = 0$.
Finally, in the third stage, we decay the learning rate linearly to achieve the optimal rate.
All three stages will use the following progress lemma:
Lemma 15.
Let and let . Let be a fresh batch and define
Then if ,
where and for all ,
Furthermore, if the can be replaced with .
Proof.
Because and ,
where . Note that by Lemma 14, has moments bounded by . Therefore by Lemma 23 with and ,
Plugging in the bound on from Corollary 2 gives
In addition, by Lemma 14,
Similarly, by Lemma 23 with and , Lemma 14, and Corollary 2,
∎
We can now analyze the first stage in which . This stage is dominated by the signal from the smoothing.
Lemma 16 (Stage 1).
Assume that and . Set
for a sufficiently large constant . Then with high probability, there exists such that .
Proof.
Let be the hitting time for . For , let be the event that
We will prove by induction that for any , the event: happens with high probability. The base case of is trivial so let and assume the result for all . Note that so by Lemma 15 and the fact that ,
Now note that so let us condition on the event . Then by the induction hypothesis, with high probability we have for all . Plugging in the value of gives:
Similarly, because is a martingale we have by Lemma 22 and Lemma 24 that with high probability,
where we used that . Therefore conditioned on we have with high probability that for all :
Now we split into two cases depending on the parity of . First, if is odd we have that with high probability, for all :
Now let . Then we have that with high probability,
which implies that with high probability. Next, if is even we have that with high probability
As above, by Lemma 27 the first event implies that so we must have with high probability. ∎
Next, we consider what happens when . The analysis in this stage is similar to the online SGD analysis in [1].
Lemma 17 (Stage 2).
Assume that . Set as in Lemma 16. Then with high probability, .
Proof.
The proof is almost identical to Lemma 16. We again have from Lemma 15
First, from martingale concentration we have that
where we used that . Therefore with high probability,
Therefore while , for sufficiently large we have
Therefore by Lemma 27, we have that there exists such that . Next, let . Then applying Lemma 15 to and using gives that if
then
Therefore,
With high probability, the martingale term is bounded by as before as long as , so for we have that . Setting and choosing appropriately yields , which completes the proof. ∎
Finally, the third stage guarantees not only a hitting time but a last iterate guarantee. It also achieves the optimal sample complexity in terms of the target accuracy :
Lemma 18 (Stage 3).
Assume that . Set and
for a sufficiently large constant . Then for any , we have that with high probability,
Proof.
Let . By Lemma 15, while :
where the moments of are each bounded by . We will prove by induction that with probability at least , we have for all :
The base case is clear so assume the result for all . Then from the recurrence above,
First, because ,
Next,
The next error term is:
Fix . Then we will bound the th moment of :
Now note that because ,
Therefore the norm of the predictable quadratic variation of the next error term is bounded by:
In addition, the norm of the largest term in this sum is bounded by
Therefore by Lemma 22 and Lemma 24, we have with probability at least , this term is bounded by
Finally, the last term is similarly bounded with probability at least by
which completes the induction. ∎
We can now combine the above lemmas to prove Theorem 1:
B.5 Proof of Theorem 2
We directly follow the proof of Theorem 2 in Damian et al. [6] which is reproduced here for completeness. We begin with the following general CSQ lemma which can be found in Szörényi [38], Damian et al. [6]:
Lemma 19.
Let be a class of functions and be a data distribution such that
Then any correlational statistical query learner requires at least queries of tolerance to output a function in with loss at most .
Appendix C Concentration Inequalities
Lemma 20 (Rosenthal-Burkholder-Pinelis Inequality [39]).
Let be a martingale with martingale difference sequence where . Let
denote the predictable quadratic variation. Then there exists an absolute constant such that for all ,
The above inequality is found in Pinelis [39, Theorem 4.1]. It is often combined with the following simple lemma:
Lemma 21.
For any random variables ,
This has the immediate corollary:
Lemma 22.
Let be a martingale with martingale difference sequence where . Let denote the predictable quadratic variation. Then there exists an absolute constant such that for all ,
We will often use the following corollary of Hölder's inequality to bound the operator norm of a product of two random variables when one has polynomial tails:
Lemma 23.
Let be random variables with . Then,
Proof.
Fix . Then using Holder’s inequality with gives:
Using the fact that have polynomial tails we can bound this by
First, if , we can set which gives
Next, if we can set which gives
which completes the proof. ∎
Finally, the following basic lemma will allow us to easily convert between $L^q$-norm bounds and concentration inequalities:
Lemma 24.
Let and let be a mean zero random variable satisfying
for some . Then with probability at least , .
Proof.
Let . Then,
∎
Appendix D Additional Technical Lemmas
The following lemma extends Steins’s lemma () to the ultraspherical distribution where is the distribution of when :
Lemma 25 (Spherical Stein’s Lemma).
For any ,
Proof.
Recall that the density of the ultraspherical distribution is proportional to $(1 - t^2)^{\frac{d-3}{2}}$ on $[-1, 1]$.
Therefore,
Now we can integrate by parts to get
∎
Lemma 26.
For ,
Proof.
Recall that the PDF of the ultraspherical distribution is proportional to $(1 - t^2)^{\frac{d-3}{2}}$ on $[-1, 1]$.
Using this we have that:
∎
We have the following generalization of Lemma 8:
Corollary 3.
For any with and , there exist such that
Proof.
Expanding the definition of gives:
Now let and note that by Cauchy-Schwarz, . Then,
Now we can use the binomial theorem to expand this. Ignoring constants only depending on :
By Lemma 26, the term is bounded by when is even and when is odd. Therefore this expression is bounded by
Now note that
Therefore, is the dominant term which completes the proof. ∎
Lemma 27 (Adapted from Abbe et al. [4]).
Let be positive constants, and let be a sequence satisfying
Then, if , we have the lower bound
Proof.
Consider the auxiliary sequence . By induction, . To lower bound , we have that
Therefore
Altogether, we get
or
as desired. ∎
Appendix E Additional Experimental Details
To compute the smoothed loss we used the closed form for $\mathcal{L}_\lambda f_w$ (see Section B.3). Experiments were run on 8 NVIDIA A6000 GPUs. Our code is written in JAX [40] and we used Weights and Biases [41] for experiment tracking.