
Smoothing the Landscape Boosts the Signal for SGD
Optimal Sample Complexity for Learning Single Index Models

Alex Damian
Princeton University
[email protected]
   Eshaan Nichani
Princeton University
[email protected]
   Rong Ge
Duke University
[email protected]
   Jason D. Lee
Princeton University
[email protected]
Abstract

We focus on the task of learning a single index model σ(wx)\sigma(w^{\star}\cdot x) with respect to the isotropic Gaussian distribution in dd dimensions. Prior work has shown that the sample complexity of learning ww^{\star} is governed by the information exponent kk^{\star} of the link function σ\sigma, which is defined as the index of the first nonzero Hermite coefficient of σ\sigma. Ben Arous et al. [1] showed that ndk1n\gtrsim d^{k^{\star}-1} samples suffice for learning ww^{\star} and that this is tight for online SGD. However, the CSQ lower bound for gradient based methods only shows that ndk/2n\gtrsim d^{{k^{\star}}/2} samples are necessary. In this work, we close the gap between the upper and lower bounds by showing that online SGD on a smoothed loss learns ww^{\star} with ndk/2n\gtrsim d^{{k^{\star}}/2} samples. We also draw connections to statistical analyses of tensor PCA and to the implicit regularization effects of minibatch SGD on empirical losses.

1 Introduction

Gradient descent-based algorithms are a popular tool for deriving computational and statistical guarantees for a number of high-dimensional statistical learning problems [2, 3, 1, 4, 5, 6]. Even though the empirical loss is nonconvex and in the worst case computationally intractable to optimize, for many statistical learning tasks gradient-based methods still converge to good solutions with polynomial runtime and sample complexity. Analyses in these settings typically study properties of the empirical loss landscape [7], and in particular the number of samples needed for the signal from the population gradient to overpower the noise in a uniform sense. The sample complexity for learning with gradient descent is thus determined by the landscape of the empirical loss.

One setting in which the empirical loss landscape showcases rich behavior is that of learning a single-index model. Single index models are target functions of the form f(x)=σ(wx)f^{*}(x)=\sigma(w^{\star}\cdot x), where wSd1w^{\star}\in S^{d-1} is the unknown relevant direction and σ\sigma is the known link function. When the covariates are drawn from the standard dd-dimensional Gaussian distribution, the shape of the loss landscape is governed by the information exponent k{k^{\star}} of the link function σ\sigma, which characterizes the curvature of the loss landscape around the origin. Ben Arous et al. [1] show that online stochastic gradient descent on the empirical loss can recover ww^{*} with ndk1n\gtrsim d^{{k^{\star}}-1} samples; furthermore, they present a lower bound showing that for a class of online SGD algorithms, dk1d^{{k^{\star}}-1} samples are indeed necessary.

However, gradient descent can be suboptimal for various statistical learning problems, as it only relies on local information in the loss landscape and is thus prone to getting stuck in local minima. For learning a single index model, the Correlational Statistical Query (CSQ) lower bound only requires dk/2d^{{k^{\star}}/2} samples to recover ww^{\star} [6, 4], which is far fewer than the number of samples required by online SGD. This gap between gradient-based methods and the CSQ lower bound is also present in the Tensor PCA problem [8]; for recovering a rank 1 kk-tensor in dd dimensions, both gradient descent and the power method require dk1d^{k-1} samples, whereas more sophisticated spectral algorithms can match the computational lower bound of dk/2d^{k/2} samples.

In light of the lower bound from [1], it seems hopeless for a gradient-based algorithm to match the CSQ lower bound for learning single-index models. [1] considers the regime in which SGD is simply a discretization of gradient flow, in which case the poor properties of the loss landscape with insufficient samples imply a lower bound. However, recent work has shown that SGD is not just a discretization of gradient flow, but rather that it has an additional implicit regularization effect. Specifically, [9, 10, 11] show that over short periods of time, SGD converges to a quasi-stationary distribution N(θ,λS)N(\theta,\lambda S) where θ\theta is an initial reference point, SS is a matrix depending on the Hessian and the noise covariance, and λ=ηB\lambda=\frac{\eta}{B} measures the strength of the noise, where η\eta is the learning rate and BB is the batch size. The resulting long term dynamics therefore follow the smoothed gradient ~L(θ)=𝔼zN(0,S)[L(θ+λz)]\widetilde{\nabla}L(\theta)=\operatorname{\mathbb{E}}_{z\sim N(0,S)}[\nabla L(\theta+\lambda z)], which has the effect of regularizing the trace of the Hessian.

This implicit regularization effect of minibatch SGD has been shown to drastically improve generalization and reduce the number of samples necessary for supervised learning tasks [12, 13, 14]. However, the connection between the smoothed landscape and the resulting sample complexity is poorly understood. Towards closing this gap, we consider directly smoothing the loss landscape in order to efficiently learn single index models. Our main result, Theorem 1, shows that online SGD on the smoothed loss learns ww^{\star} in ndk/2n\gtrsim d^{{k^{\star}}/2} samples, which matches the correlation statistical query (CSQ) lower bound. This improves over the ndk1n\gtrsim d^{{k^{\star}}-1} lower bound for online SGD on the unsmoothed loss from Ben Arous et al. [1]. Key to our analysis is the observation that smoothing the loss landscape boosts the signal-to-noise ratio in a region around the initialization, which allows the iterates to avoid the poor local minima for the unsmoothed empirical loss. Our analysis is inspired by the implicit regularization effect of minibatch SGD, along with the partial trace algorithm for Tensor PCA which achieves the optimal dk/2d^{k/2} sample complexity for computationally efficient algorithms.

The outline of our paper is as follows. In Section 3 we formalize the specific statistical learning setup, define the information exponent k{k^{\star}}, and describe our algorithm. Section 4 contains our main theorem, and Section 5 presents a heuristic derivation for how smoothing the loss landscape increases the signal-to-noise ratio. We present empirical verification in Section 6, and in Section 7 we detail connections to tensor PCA and minibatch SGD.

2 Related Work

There is a rich literature on learning single index models. Kakade et al. [15] showed that gradient descent can learn single index models when the link function is Lipschitz and monotonic and designed an alternative algorithm to handle the case when the link function is unknown. Soltanolkotabi [16] focused on learning single index models where the link function is ReLU(x):=max(0,x)\operatorname{ReLU}(x):=\max(0,x) which has information exponent k=1{k^{\star}}=1. The phase-retrieval problem is a special case of the single index model in which the link function is σ(x)=x2\sigma(x)=x^{2} or σ(x)=|x|\sigma(x)=\absolutevalue{x}; this corresponds to k=2{k^{\star}}=2, and solving phase retrieval via gradient descent has been well studied [17, 18, 19]. Dudeja and Hsu [20] constructed an algorithm which explicitly uses the harmonic structure of Hermite polynomials to identify the information exponent. Ben Arous et al. [1] provided matching upper and lower bounds that show that ndk1n\gtrsim d^{{k^{\star}}-1} samples are necessary and sufficient for online SGD to recover ww^{\star}.

Going beyond gradient-based algorithms, Chen and Meka [21] provide an algorithm that can learn polynomials of few relevant dimensions with ndn\gtrsim d samples, including single index models with polynomial link functions. Their estimator is based on the structure of the filtered PCA matrix 𝔼x,y[𝟏|y|τxxT]\operatorname{\mathbb{E}}_{x,y}[\mathbf{1}_{|y|\geq\tau}xx^{T}], which relies on the heavy tails of polynomials. In particular, this upper bound does not apply to bounded link functions. Furthermore, while their result achieves the information-theoretically optimal dd dependence it is not a CSQ algorithm, whereas our Algorithm 1 achieves the optimal sample complexity over the class of CSQ algorithms (which contains gradient descent).

Recent work has also studied the ability of neural networks to learn single or multi-index models [5, 6, 22, 23, 4]. Bietti et al. [5] showed that two layer neural networks are able to adapt to unknown link functions with ndkn\gtrsim d^{{k^{\star}}} samples. Damian et al. [6] consider multi-index models with polynomial link function, and under a nondegeneracy assumption which corresponds to the k=2{k^{\star}}=2 case, show that SGD on a two-layer neural network requires nd2+rpn\gtrsim d^{2}+r^{p} samples, where rr denotes the number of relevant directions and pp the degree of the polynomial link function. Abbe et al. [23, 4] provide a generalization of the information exponent called the leap. They prove that in some settings, SGD can learn low dimensional target functions with ndLeap1n\gtrsim d^{\text{Leap}-1} samples. However, they conjecture that the optimal rate is ndLeap/2n\gtrsim d^{\text{Leap}/2} and that this can be achieved by ERM rather than online SGD.

The problem of learning single index models with information exponent kk is strongly related to the order kk Tensor PCA problem (see Section 7.1), which was introduced by Richard and Montanari [8]. They conjectured the existence of a computational-statistical gap for Tensor PCA as the information-theoretic threshold for the problem is ndn\gtrsim d, but all known computationally efficient algorithms require ndk/2n\gtrsim d^{k/2}. Furthermore, simple iterative estimators including tensor power method, gradient descent, and AMP are suboptimal and require ndk1n\gtrsim d^{k-1} samples. Hopkins et al. [24] introduced the partial trace estimator which succeeds with ndk/2n\gtrsim d^{\lceil k/2\rceil} samples. Anandkumar et al. [25] extended this result to show that gradient descent on a smoothed landscape could achieve dk/2d^{k/2} sample complexity when k=3k=3 and Biroli et al. [26] heuristically extended this result to larger kk. The success of smoothing the landscape for Tensor PCA is one of the inspirations for Algorithm 1.

3 Setting

3.1 Data distribution and target function

Our goal is to efficiently learn single index models of the form f(x)=σ(wx)f^{\star}(x)=\sigma(w^{\star}\cdot x) where wSd1w^{\star}\in S^{d-1}, the dd-dimensional unit sphere. We assume that σ\sigma is normalized so that 𝔼xN(0,1)[σ(x)2]=1\operatorname{\mathbb{E}}_{x\sim N(0,1)}[\sigma(x)^{2}]=1. We will also assume that σ\sigma is differentiable and that σ\sigma^{\prime} has polynomial tails:

Assumption 1.

There exist constants C1,C2C_{1},C_{2} such that |σ(x)|C1(1+x2)C2\absolutevalue{\sigma^{\prime}(x)}\leq C_{1}(1+x^{2})^{C_{2}}.

Our goal is to recover ww^{\star} given nn samples (x1,y1),,(xn,yn)(x_{1},y_{1}),\ldots,(x_{n},y_{n}) sampled i.i.d from

xiN(0,Id),yi=f(xi)+zi where ziN(0,ς2).\displaystyle x_{i}\sim N(0,I_{d}),\quad y_{i}=f^{\star}(x_{i})+z_{i}\mbox{\quad where\quad}z_{i}\sim N(0,\varsigma^{2}).

For simplicity of exposition, we assume that σ\sigma is known and we take our model class to be

f(w,x):=σ(wx) where wSd1.\displaystyle f(w,x):=\sigma\quantity(w\cdot x)\mbox{\quad where\quad}w\in S^{d-1}.
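For concreteness, here is a minimal sketch of this data model in NumPy; the link function (a normalized He_3), noise level, and dimension are illustrative choices rather than the settings used in our experiments.

```python
import numpy as np

def sample_single_index_data(n, d, sigma, w_star, noise_std=0.1, seed=0):
    """Draw n samples with x_i ~ N(0, I_d) and y_i = sigma(w_star . x_i) + noise."""
    rng = np.random.default_rng(seed)
    X = rng.standard_normal((n, d))
    y = sigma(X @ w_star) + noise_std * rng.standard_normal(n)
    return X, y

d = 64
rng = np.random.default_rng(0)
w_star = rng.standard_normal(d)
w_star /= np.linalg.norm(w_star)                 # w_star lies on the unit sphere S^{d-1}
sigma = lambda t: (t**3 - 3 * t) / np.sqrt(6)    # normalized He_3, information exponent 3
X, y = sample_single_index_data(10_000, d, sigma, w_star)
```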

3.2 Algorithm: online SGD on a smoothed landscape

As wSd1w\in S^{d-1} we will let w\nabla_{w} denote the spherical gradient with respect to ww. That is, for a function g:dg:\mathbb{R}^{d}\rightarrow\mathbb{R}, let \nabla_{w}g(w)=(I-ww^{T})\nabla g(z)\big|_{z=w} where \nabla is the standard Euclidean gradient.

To compute the loss on a sample (x,y)(x,y), we use the correlation loss:

L(w;x;y):=1f(w,x)y.\displaystyle L(w;x;y):=1-f(w,x)y.

Furthermore, when the sample is omitted we refer to the population loss:

L(w):=𝔼x,y[L(w;x;y)]\displaystyle L(w):=\operatorname{\mathbb{E}}_{x,y}[L(w;x;y)]

Our primary contribution is that SGD on a smoothed loss achieves the optimal sample complexity for this problem. First, we define the smoothing operator λ\mathcal{L}_{\lambda}:

Definition 1.

Let g:Sd1g:S^{d-1}\to\mathbb{R}. We define the smoothing operator λ\mathcal{L}_{\lambda} by

(λg)(w):=𝔼zμw[g(w+λzw+λz)]\displaystyle(\mathcal{L}_{\lambda}g)(w):=\operatorname{\mathbb{E}}_{z\sim\mu_{w}}\quantity[g\quantity(\frac{w+\lambda z}{\norm{w+\lambda z}})]

where μw\mu_{w} is the uniform distribution over Sd1S^{d-1} conditioned on being perpendicular to ww.

This choice of smoothing is natural for spherical gradient descent and can be directly related111 This is equivalent to the intrinsic definition (\mathcal{L}_{\lambda}g)(w):=\operatorname{\mathbb{E}}_{z\sim UT_{w}(S^{d-1})}[g(\exp_{w}(\theta z))] where θ=arctan(λ)\theta=\arctan(\lambda), UTw(Sd1)UT_{w}(S^{d-1}) is the unit sphere in Tw(Sd1)T_{w}(S^{d-1}), and exp\exp is the Riemannian exponential map. to the Riemannian exponential map on Sd1S^{d-1}. We will often abuse notation and write λ(g(w))\mathcal{L}_{\lambda}\quantity(g(w)) rather than (λg)(w)(\mathcal{L}_{\lambda}g)(w). The smoothed empirical loss Lλ(w;x;y)L_{\lambda}(w;x;y) and the smoothed population loss Lλ(w)L_{\lambda}(w) are defined by:

Lλ(w;x;y):=λ(L(w;x;y)) and Lλ(w):=λ(L(w)).\displaystyle L_{\lambda}(w;x;y):=\mathcal{L}_{\lambda}\quantity(L(w;x;y))\mbox{\quad and\quad}L_{\lambda}(w):=\mathcal{L}_{\lambda}\quantity(L(w)).
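Below is a Monte Carlo sketch of the smoothing operator from Definition 1, assuming access to the function gg as a Python callable; the number of smoothing samples is an arbitrary choice (the analysis works with the exact expectation).

```python
import numpy as np

def smooth(g, w, lam, n_mc=256, seed=0):
    """Monte Carlo estimate of (L_lambda g)(w) = E_{z ~ mu_w}[ g((w + lam z)/||w + lam z||) ],
    where mu_w is the uniform distribution on the sphere restricted to the complement of w."""
    rng = np.random.default_rng(seed)
    d = w.shape[0]
    total = 0.0
    for _ in range(n_mc):
        z = rng.standard_normal(d)
        z -= (z @ w) * w              # project out the w direction
        z /= np.linalg.norm(z)        # uniform on the unit sphere orthogonal to w
        u = w + lam * z
        total += g(u / np.linalg.norm(u))
    return total / n_mc
```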

Our algorithm is online SGD on the smoothed loss LλL_{\lambda}:

Input: learning rate schedule {ηt}\quantity{\eta_{t}}, smoothing schedule {λt}\quantity{\lambda_{t}}, steps TT
Sample w0Unif(Sd1)w_{0}\sim\operatorname{Unif}(S^{d-1})
for t=0t=0 to T1T-1 do
       Sample a fresh sample (xt,yt)(x_{t},y_{t})
       w^t+1wtηtwLλt(wt;xt;yt)\hat{w}_{t+1}\leftarrow w_{t}-\eta_{t}\nabla_{w}L_{\lambda_{t}}(w_{t};x_{t};y_{t})
       wt+1w^t+1/w^t+1w_{t+1}\leftarrow\hat{w}_{t+1}/\norm{\hat{w}_{t+1}}
end for
Algorithm 1 Smoothed Online SGD
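A minimal NumPy sketch of Algorithm 1 follows. As a simplification, each step evaluates the per-sample spherical gradient at a single randomly perturbed point (w+λz)/‖w+λz‖ with z ∼ μ_w, a one-sample surrogate for the gradient of the smoothed loss rather than the exact ∇_w L_λ(w;x;y), and the step size and smoothing level are held fixed rather than following the schedules used in our analysis.

```python
import numpy as np

def smoothed_online_sgd(X, Y, lam, eta, sigma_prime, seed=0):
    """Sketch of Algorithm 1 (smoothed online SGD) for the correlation loss
    L(w; x; y) = 1 - sigma(w.x) y, using one fresh sample per step."""
    rng = np.random.default_rng(seed)
    n, d = X.shape
    w = rng.standard_normal(d)
    w /= np.linalg.norm(w)                            # w_0 ~ Unif(S^{d-1})
    for t in range(n):
        z = rng.standard_normal(d)
        z -= (z @ w) * w                              # z ~ mu_w: uniform, orthogonal to w
        z /= np.linalg.norm(z)
        w_pert = (w + lam * z) / np.linalg.norm(w + lam * z)
        # Euclidean gradient of 1 - sigma(w_pert . x) y, projected to the sphere.
        g = -Y[t] * sigma_prime(w_pert @ X[t]) * X[t]
        g -= (g @ w_pert) * w_pert
        w = w - eta * g                               # gradient step ...
        w /= np.linalg.norm(w)                        # ... followed by renormalization
    return w
```

With the He_3 data from the earlier sketch one could call, for instance, `smoothed_online_sgd(X, y, lam=d**0.25, eta=1e-3, sigma_prime=lambda t: (3*t**2 - 3)/np.sqrt(6))` and track `w @ w_star`; the step size here is an arbitrary placeholder.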

3.3 Hermite polynomials and information exponent

The sample complexity of Algorithm 1 depends on the Hermite coefficients of σ\sigma:

Definition 2 (Hermite Polynomials).

The kkth Hermite polynomial Hek:He_{k}:\mathbb{R}\rightarrow\mathbb{R} is the degree kk, monic polynomial defined by

Hek(x)=(1)kkμ(x)μ(x),He_{k}(x)=(-1)^{k}\frac{\nabla^{k}\mu(x)}{\mu(x)},

where μ(x):=ex222π\mu(x):=\frac{e^{-\frac{x^{2}}{2}}}{\sqrt{2\pi}} is the PDF of a standard Gaussian.

The first few Hermite polynomials are He0(x)=1,He1(x)=x,He2(x)=x21,He3(x)=x33xHe_{0}(x)=1,He_{1}(x)=x,He_{2}(x)=x^{2}-1,He_{3}(x)=x^{3}-3x. For further discussion on the Hermite polynomials and their properties, refer to Section A.2. The Hermite polynomials form an orthogonal basis of L2(μ)L^{2}(\mu) so any function in L2(μ)L^{2}(\mu) admits a Hermite expansion. We let {ck}k0\quantity{c_{k}}_{k\geq 0} denote the Hermite coefficients of the link function σ\sigma:

Definition 3 (Hermite Expansion of σ\sigma).

Let {ck}k0\quantity{c_{k}}_{k\geq 0} be the Hermite coefficients of σ\sigma, i.e.

σ(x)=k0ckk!Hek(x) where ck=𝔼xN(0,1)[σ(x)Hek(x)].\displaystyle\sigma(x)=\sum_{k\geq 0}\frac{c_{k}}{k!}He_{k}(x)\mbox{\quad where\quad}c_{k}=\operatorname{\mathbb{E}}_{x\sim N(0,1)}[\sigma(x)He_{k}(x)].

The critical quantity of interest is the information exponent of σ\sigma:

Definition 4 (Information Exponent).

k=k(σ){k^{\star}}={k^{\star}}(\sigma) is the first index k1k\geq 1 such that ck0c_{k}\neq 0.

Example 1.

Below are some example link functions and their information exponents:

  • σ(x)=x\sigma(x)=x and σ(x)=ReLU(x):=max(0,x)\sigma(x)=\operatorname{ReLU}(x):=\max(0,x) have information exponents k=1{k^{\star}}=1.

  • σ(x)=x2\sigma(x)=x^{2} and σ(x)=|x|\sigma(x)=\absolutevalue{x} have information exponents k=2{k^{\star}}=2.

  • σ(x)=x33x\sigma(x)=x^{3}-3x has information exponent k=3{k^{\star}}=3. More generally, σ(x)=Hek(x)\sigma(x)=He_{k}(x) has information exponent k=k{k^{\star}}=k.

Throughout our main results we focus on the case k3{k^{\star}}\geq 3 as when k=1,2{k^{\star}}=1,2, online SGD without smoothing already achieves the optimal sample complexity of ndn\asymp d samples (up to log factors) [1].
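The Hermite coefficients ckc_{k}, and hence the information exponent, can be estimated numerically for a given link function. A sketch using Gauss–Hermite quadrature in the probabilists' convention (NumPy's `hermite_e` module) is below; the quadrature degree and the tolerance for declaring a coefficient zero are arbitrary choices.

```python
import numpy as np
from numpy.polynomial.hermite_e import hermegauss, hermeval

def hermite_coefficients(sigma, k_max=10, quad_deg=40):
    """c_k = E_{x ~ N(0,1)}[sigma(x) He_k(x)] via Gauss-Hermite quadrature
    for the weight exp(-x^2/2), renormalized to the standard Gaussian measure."""
    nodes, weights = hermegauss(quad_deg)
    weights = weights / np.sqrt(2 * np.pi)
    coeffs = []
    for k in range(k_max + 1):
        basis = np.zeros(k + 1)
        basis[k] = 1.0                     # coefficient vector selecting He_k
        coeffs.append(np.sum(weights * sigma(nodes) * hermeval(nodes, basis)))
    return np.array(coeffs)

def information_exponent(sigma, tol=1e-8):
    c = hermite_coefficients(sigma)
    nonzero = np.nonzero(np.abs(c[1:]) > tol)[0]
    return int(nonzero[0]) + 1 if nonzero.size else None

print(information_exponent(lambda x: x**3 - 3 * x))        # He_3: prints 3
print(information_exponent(lambda x: np.maximum(x, 0.0)))  # ReLU: prints 1
```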

4 Main Results

Our main result is a sample complexity guarantee for Algorithm 1:

Theorem 1.

Assume w0wd1/2w_{0}\cdot w^{\star}\gtrsim d^{-1/2} and λ[1,d1/4]\lambda\in[1,d^{1/4}]. Let T1=O~(dk1λ2k+4)T_{1}=\tilde{O}\quantity(d^{{k^{\star}}-1}\lambda^{-2{k^{\star}}+4}). For tT1t\leq T_{1} set λt=λ\lambda_{t}=\lambda and ηt=O~(dk/2λ2k2)\eta_{t}=\tilde{O}(d^{-{k^{\star}}/2}\lambda^{2{k^{\star}}-2}). For t>T1t>T_{1} set λt=0\lambda_{t}=0 and ηt=O((d+tT1)1)\eta_{t}=O\quantity((d+t-T_{1})^{-1}). Then if T=T1+T2T=T_{1}+T_{2}, with high probability the final iterate wTw_{T} of Algorithm 1 satisfies L(wT)O(dd+T2).L(w_{T})\leq O(\frac{d}{d+T_{2}}).

Theorem 1 uses large smoothing (up to λ=d1/4\lambda=d^{1/4}) to rapidly escape the regime in which wwd1/2w\cdot w^{\star}\asymp d^{-1/2}. This first stage continues until ww=1od(1)w\cdot w^{\star}=1-o_{d}(1) which takes T1=O~(dk/2)T_{1}=\tilde{O}(d^{{k^{\star}}/2}) steps when λ=d1/4\lambda=d^{1/4}. The second stage, in which λ=0\lambda=0 and the learning rate decays as ηt(d+tT1)1\eta_{t}\propto(d+t-T_{1})^{-1}, lasts for an additional T2=d/ϵT_{2}=d/\epsilon steps where ϵ\epsilon is the target accuracy. Because Algorithm 1 uses each sample exactly once, this gives the sample complexity

ndk1λ2k+4+d/ϵ\displaystyle n\gtrsim d^{{k^{\star}}-1}\lambda^{-2{k^{\star}}+4}+d/\epsilon

to reach population loss L(wT)ϵL(w_{T})\leq\epsilon. Setting λ=O(1)\lambda=O(1) is equivalent to zero smoothing and gives a sample complexity of ndk1+d/ϵn\gtrsim d^{{k^{\star}}-1}+d/\epsilon, which matches the results of Ben Arous et al. [1]. On the other hand, setting λ\lambda to the maximal allowable value of d1/4d^{1/4} gives:

ndk2CSQlower bound+d/ϵinformationlower bound\displaystyle n\gtrsim\underbrace{d^{\frac{{k^{\star}}}{2}}\vphantom{d/\epsilon}}_{\begin{subarray}{c}\text{CSQ}\\ \text{lower bound}\end{subarray}}+\underbrace{d/\epsilon}_{\begin{subarray}{c}\text{information}\\ \text{lower bound}\end{subarray}}

which matches the sum of the CSQ lower bound, which is dk2d^{\frac{{k^{\star}}}{2}}, and the information-theoretic lower bound, which is d/ϵd/\epsilon, up to poly-logarithmic factors.
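For reference, the substitution behind this rate is the elementary computation

\displaystyle d^{{k^{\star}}-1}\lambda^{-2{k^{\star}}+4}\Big|_{\lambda=d^{1/4}}=d^{{k^{\star}}-1}\cdot d^{-{k^{\star}}/2+1}=d^{{k^{\star}}/2},

so the smoothed phase alone uses ndk/2n\gtrsim d^{{k^{\star}}/2} samples and the second phase contributes the d/ϵd/\epsilon term.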

To complement Theorem 1, we replicate the CSQ lower bound in [6] for the specific function class σ(wx)\sigma(w\cdot x) where wSd1w\in S^{d-1}. Statistical query learners are a family of learners that can query values q(x,y)q(x,y) and receive outputs q^\hat{q} with |q^𝔼x,y[q(x,y)]|τ\absolutevalue{\hat{q}-\operatorname{\mathbb{E}}_{x,y}[q(x,y)]}\leq\tau where τ\tau denotes the query tolerance [27, 28]. An important class of statistical query learners is that of correlational/inner product statistical queries (CSQ) of the form q(x,y)=yh(x)q(x,y)=yh(x). This class contains a wide range of algorithms, including gradient descent with the square loss or the correlation loss.

Theorem 2 (CSQ Lower Bound).

Consider the function class σ:={σ(wx):w𝒮d1}\mathcal{F}_{\sigma}:=\{\sigma(w\cdot x):w\in\mathcal{S}^{d-1}\}. Any CSQ algorithm using qq queries requires a tolerance τ\tau of at most

τ(log(qd)d)k/4\displaystyle\tau\lesssim\quantity(\frac{\log(qd)}{d})^{{k^{\star}}/4}

to output an fσf\in\mathcal{F}_{\sigma} with population loss less than 1/21/2.

Using the standard τn1/2\tau\approx n^{-1/2} heuristic which comes from concentration, this implies that ndk2n\gtrsim d^{\frac{{k^{\star}}}{2}} samples are necessary to learn σ(wx)\sigma(w\cdot x) unless the algorithm makes exponentially many queries. In the context of gradient descent, this is equivalent to either requiring exponentially many parameters or exponentially many steps of gradient descent.
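Concretely, combining the tolerance bound of Theorem 2 with the τn1/2\tau\approx n^{-1/2} heuristic gives

\displaystyle n^{-1/2}\approx\tau\lesssim\quantity(\frac{\log(qd)}{d})^{{k^{\star}}/4}\quad\Longrightarrow\quad n\gtrsim\quantity(\frac{d}{\log(qd)})^{{k^{\star}}/2},

which is dk/2d^{{k^{\star}}/2} up to polylogarithmic factors whenever the number of queries qq is polynomial in dd.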

5 Proof Sketch

In this section we highlight the key ideas of the proof of Theorem 1. The full proof is deferred to Appendix B. The proof sketch is broken into three parts. First, we conduct a general analysis on online SGD to show how the signal-to-noise ratio (SNR) affects the sample complexity. Next, we compute the SNR for the unsmoothed objective (λ=0\lambda=0) to heuristically rederive the dk1d^{{k^{\star}}-1} sample complexity in Ben Arous et al. [1]. Finally, we show how smoothing boosts the SNR and leads to an improved sample complexity of dk/2d^{{k^{\star}}/2} when λ=d1/4\lambda=d^{1/4}.

5.1 Online SGD Analysis

To begin, we will analyze a single step of online SGD. We define αt:=wtw\alpha_{t}:=w_{t}\cdot w^{\star} so that αt[1,1]\alpha_{t}\in[-1,1] measures our current progress. Furthermore, let vt:=Lλt(wt;xt;yt)v_{t}:=-\nabla L_{\lambda_{t}}(w_{t};x_{t};y_{t}). Recall that the online SGD update is:

wt+1=wt+ηtvtwt+ηtvtαt+1=αt+ηt(vtw)wt+ηtvt.\displaystyle w_{t+1}=\frac{w_{t}+\eta_{t}v_{t}}{\norm{w_{t}+\eta_{t}v_{t}}}\implies\alpha_{t+1}=\frac{\alpha_{t}+\eta_{t}(v_{t}\cdot w^{\star})}{\norm{w_{t}+\eta_{t}v_{t}}}.

Using the fact that vwv\perp w and 11+x21x22\frac{1}{\sqrt{1+x^{2}}}\approx 1-\frac{x^{2}}{2} we can Taylor expand the update for αt+1\alpha_{t+1}:

αt+1=αt+ηt(vtw)1+ηt2vt2αt+ηt(vtw)ηt2vt2αt2+O(ηt3).\displaystyle\alpha_{t+1}=\frac{\alpha_{t}+\eta_{t}(v_{t}\cdot w^{\star})}{\sqrt{1+\eta_{t}^{2}\norm{v_{t}}^{2}}}\approx\alpha_{t}+\eta_{t}(v_{t}\cdot w^{\star})-\frac{\eta_{t}^{2}\norm{v_{t}}^{2}\alpha_{t}}{2}+O(\eta_{t}^{3}).

As in Ben Arous et al. [1], we decompose this update into a drift term and a martingale term. Let t=σ{(x0,y0),,(xt1,yt1)}\mathcal{F}_{t}=\sigma\quantity{(x_{0},y_{0}),\ldots,(x_{t-1},y_{t-1})} be the natural filtration. We focus on the drift term as the martingale term can be handled with standard concentration arguments. Taking expectations with respect to the fresh batch (xt,yt)(x_{t},y_{t}) gives:

𝔼[αt+1|t]αt+ηt𝔼[vtw|t]ηt2𝔼[vt2|t]αt/2\displaystyle\operatorname{\mathbb{E}}[\alpha_{t+1}|\mathcal{F}_{t}]\approx\alpha_{t}+\eta_{t}\operatorname{\mathbb{E}}[v_{t}\cdot w^{\star}|\mathcal{F}_{t}]-\eta_{t}^{2}\operatorname{\mathbb{E}}[\norm{v_{t}}^{2}|\mathcal{F}_{t}]\alpha_{t}/2

so to guarantee a positive drift, we need to set ηt2𝔼[vtw|t]𝔼[vt2|t]αt\eta_{t}\leq\frac{2\operatorname{\mathbb{E}}[v_{t}\cdot w^{\star}|\mathcal{F}_{t}]}{\operatorname{\mathbb{E}}[\norm{v_{t}}^{2}|\mathcal{F}_{t}]\alpha_{t}} which gives us the value of ηt\eta_{t} used in Theorem 1 for tT1t\leq T_{1}. However, to simplify the proof sketch we can assume knowledge of 𝔼[vtw|t]\operatorname{\mathbb{E}}[v_{t}\cdot w^{\star}|\mathcal{F}_{t}] and 𝔼[vt2|t]\operatorname{\mathbb{E}}[\norm{v_{t}}^{2}|\mathcal{F}_{t}] and optimize over ηt\eta_{t} to get a maximum drift of

𝔼[αt+1|wt]αt+12αt𝔼[vtw|t]2𝔼[vt2|t]SNR.\displaystyle\operatorname{\mathbb{E}}[\alpha_{t+1}|w_{t}]\approx\alpha_{t}+\frac{1}{2\alpha_{t}}\cdot\underbrace{\frac{\operatorname{\mathbb{E}}[v_{t}\cdot w^{\star}|\mathcal{F}_{t}]^{2}}{\operatorname{\mathbb{E}}[\norm{v_{t}}^{2}|\mathcal{F}_{t}]}}_{\text{SNR}}.

The numerator measures the correlation of the population gradient with ww^{\star} while the denominator measures the norm of the noisy gradient. Their ratio thus has a natural interpretation as the signal-to-noise ratio (SNR). Note that the SNR is a local property, i.e. the SNR can vary for different wtw_{t}. When the SNR can be written as a function of αt=wtw\alpha_{t}=w_{t}\cdot w^{\star}, the SNR directly dictates the rate of optimization through the ODE approximation: αSNR/α\alpha^{\prime}\approx\text{SNR}/\alpha. As online SGD uses each sample exactly once, the sample complexity for online SGD can be approximated by the time it takes this ODE to reach α1\alpha\approx 1 from α0d1/2\alpha_{0}\approx d^{-1/2}. The remainder of the proof sketch will therefore focus on analyzing the SNR of the minibatch gradient Lλ(w;x;y)\nabla L_{\lambda}(w;x;y).

5.2 Computing the Rate with Zero Smoothing

When λ=0\lambda=0, the signal and noise terms can easily be calculated. The key property we need is:

Property 1 (Orthogonality Property).

Let w,wSd1w,w^{\star}\in S^{d-1} and let α=ww\alpha=w\cdot w^{\star}. Then:

𝔼xN(0,Id)[Hej(wx)Hek(wx)]=δjkk!αk.\displaystyle\operatorname{\mathbb{E}}_{x\sim N(0,I_{d})}[He_{j}(w\cdot x)He_{k}(w^{\star}\cdot x)]=\delta_{jk}k!\alpha^{k}.

Using Property 1 and the Hermite expansion of σ\sigma (Definition 3) we can directly compute the population loss and gradient. Letting Pw:=IwwTP_{w}^{\perp}:=I-ww^{T} denote the projection onto the subspace orthogonal to ww we have:

\displaystyle L(w)=\sum_{k\geq 0}\frac{c_{k}^{2}}{k!}[1-\alpha^{k}]\mbox{\quad and\quad}\nabla L(w)=-(P_{w}^{\perp}w^{\star})\sum_{k\geq 1}\frac{c_{k}^{2}}{(k-1)!}\alpha^{k-1}.

As α1\alpha\ll 1 throughout most of the trajectory, the gradient is dominated by the first nonzero Hermite coefficient so up to constants: 𝔼[vw]=L(w)wαk1\operatorname{\mathbb{E}}[v\cdot w^{\star}]=-\nabla L(w)\cdot w^{\star}\approx\alpha^{{k^{\star}}-1}. Similarly, a standard concentration argument shows that because vv is a random vector in dd dimensions where each coordinate is O(1)O(1), \operatorname{\mathbb{E}}[\norm{v}^{2}]\approx d. Therefore the SNR is equal to α2(k1)/d\alpha^{2({k^{\star}}-1)}/d, so with an optimal learning rate schedule,

𝔼[αt+1|t]αt+αt2k3/d.\displaystyle\operatorname{\mathbb{E}}[\alpha_{t+1}|\mathcal{F}_{t}]\approx\alpha_{t}+\alpha_{t}^{2{k^{\star}}-3}/d.

This can be approximated by the ODE α=α2k3/d\alpha^{\prime}=\alpha^{2{k^{\star}}-3}/d. Solving this ODE with the initial condition α0d1/2\alpha_{0}\asymp d^{-1/2} gives that the escape time is proportional to d\alpha_{0}^{-(2{k^{\star}}-4)}=d^{{k^{\star}}-1}, which heuristically re-derives the result of Ben Arous et al. [1].
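In more detail, for k3{k^{\star}}\geq 3 separating variables in this ODE gives the escape time

\displaystyle T=d\int_{\alpha_{0}}^{1}\alpha^{-(2{k^{\star}}-3)}\,d\alpha\asymp\frac{d\,\alpha_{0}^{-(2{k^{\star}}-4)}}{2{k^{\star}}-4}\asymp d^{{k^{\star}}-1}\quad\text{when}\quad\alpha_{0}\asymp d^{-1/2}.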

Figure 1: The SNR as a function of α[d1/2,1]\alpha\in[d^{-1/2},1], comparing no smoothing with λ=d1/4\lambda=d^{1/4} smoothing (shown separately for k{k^{\star}} odd and k{k^{\star}} even). When λ=d1/4\lambda=d^{1/4}, smoothing increases the SNR until α=λd1/2=d1/4\alpha=\lambda d^{-1/2}=d^{-1/4}.

5.3 How Smoothing boosts the SNR

Smoothing improves the sample complexity of online SGD by boosting the SNR of the stochastic gradient L(w;x;y)\nabla L(w;x;y). Recall that the population loss was approximately equal to 1ck2k!αk1-\frac{c_{k^{\star}}^{2}}{{k^{\star}}!}\alpha^{k^{\star}} where k{k^{\star}} is the first nonzero Hermite coefficient of σ\sigma. Isolating the dominant αk\alpha^{k^{\star}} term and applying the smoothing operator λ\mathcal{L}_{\lambda}, we get:

λ(αk)=𝔼zμw[(w+λzw+λzw)k].\displaystyle\mathcal{L}_{\lambda}(\alpha^{k^{\star}})=\operatorname{\mathbb{E}}_{z\sim\mu_{w}}\quantity[\quantity(\frac{w+\lambda z}{\norm{w+\lambda z}}\cdot w^{\star})^{k^{\star}}].

Because zwz\perp w and z=1\norm{z}=1 we have that w+λz=1+λ2λ\norm{w+\lambda z}=\sqrt{1+\lambda^{2}}\approx\lambda. Therefore,

\displaystyle\mathcal{L}_{\lambda}(\alpha^{{k^{\star}}})\approx\lambda^{-{k^{\star}}}\operatorname{\mathbb{E}}_{z\sim\mu_{w}}\quantity[(\alpha+\lambda(z\cdot w^{\star}))^{k^{\star}}]=\lambda^{-{k^{\star}}}\sum_{j=0}^{{k^{\star}}}\binom{{k^{\star}}}{j}\alpha^{{k^{\star}}-j}\lambda^{j}\operatorname{\mathbb{E}}_{z\sim\mu_{w}}[(z\cdot w^{\star})^{j}].

Now because z=dzz\stackrel{{\scriptstyle d}}{{=}}-z, the terms where jj is odd disappear. Furthermore, for a random zz, |zw|=Θ(d1/2)\absolutevalue{z\cdot w^{\star}}=\Theta(d^{-1/2}). Therefore reindexing and ignoring all constants we have that

\displaystyle L_{\lambda}(w)\approx 1-\mathcal{L}_{\lambda}(\alpha^{{k^{\star}}})\approx 1-\lambda^{-{k^{\star}}}\sum_{j=0}^{\lfloor\frac{{k^{\star}}}{2}\rfloor}\alpha^{{k^{\star}}-2j}\quantity(\lambda^{2}/d)^{j}.

Differentiating gives that

𝔼[vw]wwλ(αk)\displaystyle\operatorname{\mathbb{E}}[v\cdot w^{\star}]\approx-w^{\star}\cdot\nabla_{w}\mathcal{L}_{\lambda}(\alpha^{{k^{\star}}}) λ1j=0k12(αλ)k1(λ2α2d)j.\displaystyle\approx\lambda^{-1}\sum_{j=0}^{\lfloor\frac{{k^{\star}}-1}{2}\rfloor}\quantity(\frac{\alpha}{\lambda})^{{k^{\star}}-1}\quantity(\frac{\lambda^{2}}{\alpha^{2}d})^{j}.

As this is a geometric series, it is either dominated by the first or the last term depending on whether αλd1/2\alpha\geq\lambda d^{-1/2} or αλd1/2\alpha\leq\lambda d^{-1/2}. Furthermore, the last term is either dk12d^{-\frac{{k^{\star}}-1}{2}} if k{k^{\star}} is odd or αλdk22\frac{\alpha}{\lambda}d^{-\frac{{k^{\star}}-2}{2}} if k{k^{\star}} is even. Therefore the signal term is:

𝔼[vw]λ1{(αλ)k1αλd1/2dk12αλd1/2 and k is oddαλdk22αλd1/2 and k is even.\displaystyle\operatorname{\mathbb{E}}[v\cdot w^{\star}]\approx\lambda^{-1}\begin{cases}(\frac{\alpha}{\lambda})^{{k^{\star}}-1}&\alpha\geq\lambda d^{-1/2}\\ d^{-\frac{{k^{\star}}-1}{2}}&\alpha\leq\lambda d^{-1/2}\text{ and ${k^{\star}}$ is odd}\\ \frac{\alpha}{\lambda}d^{-\frac{{k^{\star}}-2}{2}}&\alpha\leq\lambda d^{-1/2}\text{ and ${k^{\star}}$ is even}\end{cases}.

In addition, we can show that when λd1/4\lambda\leq d^{1/4}, the noise term satisfies 𝔼[v2]dλ2k\operatorname{\mathbb{E}}[\norm{v}^{2}]\leq d\lambda^{-2{k^{\star}}}. Note that in the high signal regime (αλd1/2\alpha\geq\lambda d^{-1/2}), both the signal and the noise are smaller by factors of λk\lambda^{{k^{\star}}} which cancel when computing the SNR. However, when αλd1/2\alpha\leq\lambda d^{-1/2} the smoothing shrinks the noise faster than it shrinks the signal, resulting in an overall larger SNR. Explicitly,

SNR:=𝔼[vw]2𝔼[v2]1d{α2(k1)αλd1/2(λ2/d)k1αλd1/2 and k is oddα2(λ2/d)k2αλd1/2 and k is even.\displaystyle\text{SNR}:=\frac{\operatorname{\mathbb{E}}[v\cdot w^{\star}]^{2}}{\operatorname{\mathbb{E}}[\norm{v}^{2}]}\approx\frac{1}{d}\begin{cases}\alpha^{2({k^{\star}}-1)}&\alpha\geq\lambda d^{-1/2}\\ (\lambda^{2}/d)^{{k^{\star}}-1}&\alpha\leq\lambda d^{-1/2}\text{ and ${k^{\star}}$ is odd}\\ \alpha^{2}(\lambda^{2}/d)^{{k^{\star}}-2}&\alpha\leq\lambda d^{-1/2}\text{ and ${k^{\star}}$ is even}\end{cases}.

For αλd1/2\alpha\geq\lambda d^{-1/2}, smoothing does not affect the SNR. However, when αλd1/2\alpha\leq\lambda d^{-1/2}, smoothing greatly increases the SNR (see Figure 1).

Solving the ODE: α=SNR/α\alpha^{\prime}=\text{SNR}/\alpha gives that it takes Tdk1λ2k+4T\approx d^{{k^{\star}}-1}\lambda^{-2{k^{\star}}+4} steps to converge to α1\alpha\approx 1 from αd1/2\alpha\approx d^{-1/2}. Once α1\alpha\approx 1, the problem is locally strongly convex, so we can decay the learning rate and use classical analysis of strongly-convex functions to show that α1ϵ\alpha\geq 1-\epsilon with an additional d/ϵd/\epsilon steps, from which Theorem 1 follows.
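For example, integrating the high-signal phase of this ODE (αλd1/2\alpha\geq\lambda d^{-1/2}, where α=α2k3/d\alpha^{\prime}=\alpha^{2{k^{\star}}-3}/d) from αλd1/2\alpha\asymp\lambda d^{-1/2} to α1\alpha\asymp 1 gives

\displaystyle T_{\text{high}}\asymp d\quantity(\frac{\sqrt{d}}{\lambda})^{2{k^{\star}}-4}=d^{{k^{\star}}-1}\lambda^{-2{k^{\star}}+4},

and an analogous computation in the low-signal regime, using the boosted SNR, yields the same order up to logarithmic factors, which is the T1T_{1} appearing in Theorem 1.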

6 Experiments

For k=3,4,5{k^{\star}}=3,4,5 and d=28,,213d=2^{8},\ldots,2^{13} we ran a minibatch variant of Algorithm 1 with batch size BB when σ(x)=\frac{He_{{k^{\star}}}(x)}{\sqrt{{k^{\star}}!}}, the normalized k{k^{\star}}th Hermite polynomial. We set:

λ=d1/4,η=Bdk/2(1+λ2)k11000k!,B=min(0.1dk/2(1+λ2)2k+4,8192).\displaystyle\lambda=d^{1/4},\quad\eta=\frac{Bd^{-{k^{\star}}/2}(1+\lambda^{2})^{{k^{\star}}-1}}{{1000{k^{\star}}!}},\quad B=\min\quantity(0.1d^{{k^{\star}}/2}(1+\lambda^{2})^{-2{k^{\star}}+4},8192).

We computed the number of samples required for Algorithm 1 to reach α2=0.5\alpha^{2}=0.5 from α=d1/2\alpha=d^{-1/2} and we report the min, mean, and max over 1010 random seeds. For each k{k^{\star}} we fit a power law of the form n=c1dc2n=c_{1}d^{c_{2}} in order to measure how the sample complexity scales with dd. For all values of k{k^{\star}}, we find that c2k/2c_{2}\approx{k^{\star}}/2 which matches Theorem 1. The results can be found in Figure 2 and additional experimental details can be found in Appendix E.

Figure 2: For k=3,4,5{k^{\star}}=3,4,5, Algorithm 1 finds ww^{\star} with ndk/2n\propto d^{{k^{\star}}/2} samples. The solid lines and the shaded areas represent the mean and min/max values over 1010 random seeds. For each curve, we also fit a power law n=c1dc2n=c_{1}d^{c_{2}} represented by the dashed lines. The value of c2c_{2} is reported in the legend.

7 Discussion

7.1 Tensor PCA

We next outline connections to the Tensor PCA problem. Introduced in [8], the goal of Tensor PCA is to recover the hidden direction w𝒮d1w^{*}\in\mathcal{S}^{d-1} from the noisy kk-tensor Tn(d)kT_{n}\in(\mathbb{R}^{d})^{\otimes k} given by222This normalization is equivalent to the original 1βdZ\frac{1}{\beta\sqrt{d}}Z normalization by setting n=β2dn=\beta^{2}d.

Tn=(w)k+1nZ,\displaystyle T_{n}=(w^{*})^{\otimes k}+\frac{1}{\sqrt{n}}Z,

where Z(d)kZ\in(\mathbb{R}^{d})^{\otimes k} is a Gaussian noise tensor with each entry drawn i.i.d from 𝒩(0,1)\mathcal{N}(0,1).

The Tensor PCA problem has garnered significant interest as it exhibits a statistical-computational gap. ww^{*} is information theoretically recoverable when ndn\gtrsim d. However, the best polynomial-time algorithms require ndk/2n\gtrsim d^{k/2}; this lower bound has been shown to be tight for various notions of hardness such as CSQ or SoS lower bounds [29, 24, 30, 31, 32, 33, 34]. Tensor PCA also exhibits a gap between spectral methods and iterative algorithms. Algorithms that work in the ndk/2n\asymp d^{k/2} regime rely on unfolding or contracting the tensor TnT_{n}, or on semidefinite programming relaxations [29, 24]. On the other hand, iterative algorithms including gradient descent, power method, and AMP require a much larger sample complexity of ndk1n\gtrsim d^{k-1} [35]. The suboptimality of iterative algorithms is believed to be due to bad properties of the landscape of the Tensor PCA objective in the region around the initialization. Specifically, [36, 37] argue that there are exponentially many local minima near the equator in the ndk1n\ll d^{k-1} regime. To overcome this, prior works have considered “smoothed" versions of gradient descent, and show that smoothing recovers the computationally optimal SNR in the k=3k=3 case [25] and heuristically for larger kk [26].

7.1.1 The Partial Trace Algorithm

The smoothing algorithms above are inspired by the following partial trace algorithm for Tensor PCA [24], which can be viewed as Algorithm 1 in the limit as λ\lambda\to\infty [25]. Let Tn=(w)k+1nZT_{n}=(w^{\star})^{\otimes k}+\frac{1}{\sqrt{n}}Z. Then we will consider iteratively contracting indices of TT until all that remains is a vector (if kk is odd) or a matrix (if kk is even). Explicitly, we define the partial trace tensor by

Mn:=Tn(Idk22){d×dk is evendk is odd.\displaystyle M_{n}:=T_{n}\quantity(I_{d}^{\otimes\lceil\frac{k-2}{2}\rceil})\in\begin{cases}\mathbb{R}^{d\times d}&\text{$k$ is even}\\ \mathbb{R}^{d}&\text{$k$ is odd}.\end{cases}

When kk is odd, we can directly return MnM_{n} as our estimate for ww^{\star} and when kk is even we return the top eigenvector of MnM_{n}. A standard concentration argument shows that this succeeds when ndk/2n\gtrsim d^{\lceil k/2\rceil} samples. Furthermore, this can be strengthened to dk/2d^{k/2} by using the partial trace vector as a warm start for gradient descent or tensor power method when kk is odd [25, 26].
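A sketch of the partial trace estimator in the odd case, for k=3k=3, assuming the noisy tensor is given as a dense NumPy array; the dimension and sample-size proxy nn below are illustrative.

```python
import numpy as np

def partial_trace_estimate_k3(T):
    """Partial trace estimator for order-3 tensor PCA:
    contract one pair of indices against the identity, v_i = sum_j T[i, j, j]."""
    v = np.einsum('ijj->i', T)
    return v / np.linalg.norm(v)

d, n = 50, 50_000                      # n on the order of d^{ceil(3/2)} = d^2
rng = np.random.default_rng(0)
w_star = rng.standard_normal(d)
w_star /= np.linalg.norm(w_star)
T = np.einsum('i,j,k->ijk', w_star, w_star, w_star)          # rank-one signal
T = T + rng.standard_normal((d, d, d)) / np.sqrt(n)          # plus Gaussian noise
w_hat = partial_trace_estimate_k3(T)
print(abs(w_hat @ w_star))             # close to 1 once n >> d^2
```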

7.1.2 The Connection Between Single Index Models and Tensor PCA

For both tensor PCA and learning single index models, gradient descent succeeds when the sample complexity is n=dk1n=d^{k-1} [35, 1]. On the other hand, the smoothing algorithms for Tensor PCA [26, 25] succeed with the computationally optimal sample complexity of n=dk/2n=d^{k/2}. Our Theorem 1 shows that this smoothing analysis can indeed be transferred to the single-index model setting.

In fact, one can make a direct connection between learning single-index models with Gaussian covariates and Tensor PCA. Consider learning a single-index model when σ(x)=Hek(x)k!\sigma(x)=\frac{He_{k}(x)}{\sqrt{k!}}, the normalized kkth Hermite polynomial. Then minimizing the correlation loss is equivalent to maximizing the loss function:

Ln(w)=1ni=1nyiHek(wxi)k!=wk,Tn where Tn:=1ni=1nyi𝐇𝐞k(xi)k!.\displaystyle L_{n}(w)=\frac{1}{n}\sum_{i=1}^{n}y_{i}\frac{He_{k}(w\cdot x_{i})}{\sqrt{k!}}=\expectationvalue{w^{\otimes k},T_{n}}\mbox{\quad where\quad}T_{n}:=\frac{1}{n}\sum_{i=1}^{n}y_{i}\frac{\mathbf{He}_{k}(x_{i})}{\sqrt{k!}}.

Here 𝐇𝐞k(xi)(d)k\mathbf{He}_{k}(x_{i})\in(\mathbb{R}^{d})^{\otimes k} denotes the kkth Hermite tensor (see Section A.2 for background on Hermite polynomials and Hermite tensors). In addition, by the orthogonality of the Hermite tensors, 𝔼x,y[Tn]=(w)k\operatorname{\mathbb{E}}_{x,y}[T_{n}]=(w^{\star})^{\otimes k} so we can decompose Tn=(w)k+ZnT_{n}=(w^{\star})^{\otimes k}+Z_{n} where by standard concentration, each entry of ZnZ_{n} is order n1/2n^{-1/2}. We can therefore directly apply algorithms for Tensor PCA to this problem. We remark that this connection is a heuristic, as the structure of the noise in Tensor PCA and our single index model setting are different.
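As a concrete sketch of this correspondence in the simplest case k=2k=2 (so σ=He2/2\sigma=He_{2}/\sqrt{2} and 𝐇𝐞2(x)=xxTI\mathbf{He}_{2}(x)=xx^{T}-I), the tensor TnT_{n} is a d×dd\times d matrix whose top eigenvector estimates ww^{\star}; the dimension and sample size below are illustrative and the labels are taken noiseless for simplicity.

```python
import numpy as np

d, n = 200, 10_000
rng = np.random.default_rng(0)
w_star = rng.standard_normal(d)
w_star /= np.linalg.norm(w_star)

X = rng.standard_normal((n, d))
y = ((X @ w_star) ** 2 - 1) / np.sqrt(2)            # sigma = He_2 / sqrt(2!), noiseless labels

# T_n = (1/n) sum_i y_i (x_i x_i^T - I) / sqrt(2!); in expectation this equals w_star w_star^T.
T_n = (X.T * y) @ X / (n * np.sqrt(2)) - (np.mean(y) / np.sqrt(2)) * np.eye(d)

eigvals, eigvecs = np.linalg.eigh(T_n)
w_hat = eigvecs[:, -1]                              # top eigenvector of the 2-tensor
print(abs(w_hat @ w_star))                          # close to 1 once n is a large multiple of d
```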

7.2 Empirical Risk Minimization on the Smoothed Landscape

Our main sample complexity guarantee, Theorem 1, is based on a tight analysis of online SGD (Algorithm 1) in which each sample is used exactly once. One might expect that if the algorithm were allowed to reuse samples, as is standard practice in deep learning, that the algorithm could succeed with fewer samples. In particular, Abbe et al. [4] conjectured that gradient descent on the empirical loss Ln(w):=1ni=1nL(w;xi;yi)L_{n}(w):=\frac{1}{n}\sum_{i=1}^{n}L(w;x_{i};y_{i}) would succeed with ndk/2n\gtrsim d^{{k^{\star}}/2} samples.

Our smoothing algorithm Algorithm 1 can be directly translated to the ERM setting to learn ww^{\star} with ndk/2n\gtrsim d^{{k^{\star}}/2} samples. We can then Taylor expand the smoothed loss in the large λ\lambda limit:

λ(Ln(w))𝔼zSd1[Ln(z)+λ1wLn(z)+λ22wT2Ln(z)w]+O(λ3).\displaystyle\mathcal{L}_{\lambda}\quantity(L_{n}(w))\approx\operatorname{\mathbb{E}}_{z\sim S^{d-1}}\quantity[L_{n}(z)+\lambda^{-1}w\cdot\nabla L_{n}(z)+\frac{\lambda^{-2}}{2}\cdot w^{T}\nabla^{2}L_{n}(z)w]+O(\lambda^{-3}).

As λ\lambda\to\infty, gradient descent on this smoothed loss will converge to 𝔼zSd1[Ln(z)]\operatorname{\mathbb{E}}_{z\sim S^{d-1}}[\nabla L_{n}(z)] which is equivalent to the partial trace estimator for k{k^{\star}} odd (see Section 7.1). If k{k^{\star}} is even, this first term is zero in expectation and gradient descent will converge to the top eigenvector of 𝔼zSd1[2Ln(z)]\operatorname{\mathbb{E}}_{z\sim S^{d-1}}[\nabla^{2}L_{n}(z)], which corresponds to the partial trace estimator for k{k^{\star}} even. Mirroring the calculation for the partial trace estimator, this succeeds with ndk/2n\gtrsim d^{\lceil{k^{\star}}/2\rceil} samples. When k{k^{\star}} is odd, this can be further improved to dk/2d^{{k^{\star}}/2} by using this estimator as a warm start from which to run gradient descent with λ=0\lambda=0 as in Anandkumar et al. [25], Biroli et al. [26].

7.3 Connection to Minibatch SGD

A recent line of works has studied the implicit regularization effect of stochastic gradient descent [9, 11, 10]. The key idea is that over short timescales, the iterates converge to a quasi-stationary distribution N(θ,λS)N(\theta,\lambda S) where SIS\approx I depends on the Hessian and the noise covariance at θ\theta and λ\lambda is proportional to the ratio of the learning rate and batch size. As a result, over longer periods of time SGD follows the smoothed gradient of the empirical loss:

L~n(w)=𝔼zN(0,S)[Ln(w+λz)].\displaystyle\widetilde{L}_{n}(w)=\operatorname{\mathbb{E}}_{z\sim N(0,S)}[L_{n}(w+\lambda z)].

We therefore conjecture that minibatch SGD is also able to achieve the optimal ndk/2n\gtrsim d^{{k^{\star}}/2} sample complexity without explicit smoothing if the learning rate and batch size are properly tuned.

8 Acknowledgements

AD acknowledges support from an NSF Graduate Research Fellowship. EN acknowledges support from a National Defense Science & Engineering Graduate Fellowship. RG is supported by NSF Award DMS-2031849, CCF-1845171 (CAREER), CCF-1934964 (Tripods) and a Sloan Research Fellowship. AD, EN, and JDL acknowledge support of the ARO under MURI Award W911NF-11-1-0304, the Sloan Research Fellowship, NSF CCF 2002272, NSF IIS 2107304, NSF CIF 2212262, ONR Young Investigator Award, and NSF CAREER Award 2144994.

References

  • Ben Arous et al. [2021] Gerard Ben Arous, Reza Gheissari, and Aukosh Jagannath. Online stochastic gradient descent on non-convex losses from high-dimensional inference. The Journal of Machine Learning Research, 22(1):4788–4838, 2021.
  • Ge et al. [2016] Rong Ge, Jason D Lee, and Tengyu Ma. Matrix completion has no spurious local minimum. Advances in neural information processing systems, 29, 2016.
  • Ma [2020] Tengyu Ma. Why do local methods solve nonconvex problems?, 2020.
  • Abbe et al. [2023] Emmanuel Abbe, Enric Boix-Adserà, and Theodor Misiakiewicz. Sgd learning on neural networks: leap complexity and saddle-to-saddle dynamics. arXiv, 2023. URL https://arxiv.org/abs/2302.11055.
  • Bietti et al. [2022] Alberto Bietti, Joan Bruna, Clayton Sanford, and Min Jae Song. Learning single-index models with shallow neural networks. In Advances in Neural Information Processing Systems (NeurIPS), 2022.
  • Damian et al. [2022] Alexandru Damian, Jason Lee, and Mahdi Soltanolkotabi. Neural networks can learn representations with gradient descent. In Conference on Learning Theory, pages 5413–5452. PMLR, 2022.
  • Mei et al. [2018] Song Mei, Yu Bai, and Andrea Montanari. The landscape of empirical risk for nonconvex losses. The Annals of Statistics, 46:2747–2774, 2018.
  • Richard and Montanari [2014] Emile Richard and Andrea Montanari. A statistical model for tensor pca. In Advances in Neural Information Processing Systems, pages 2897 – 2905, 2014.
  • Blanc et al. [2020] Guy Blanc, Neha Gupta, Gregory Valiant, and Paul Valiant. Implicit regularization for deep neural networks driven by an ornstein-uhlenbeck like process. In Conference on Learning Theory, pages 483–513, 2020.
  • Damian et al. [2021] Alex Damian, Tengyu Ma, and Jason D. Lee. Label noise SGD provably prefers flat global minimizers. In A. Beygelzimer, Y. Dauphin, P. Liang, and J. Wortman Vaughan, editors, Advances in Neural Information Processing Systems, 2021.
  • Li et al. [2022] Zhiyuan Li, Tianhao Wang, and Sanjeev Arora. What happens after SGD reaches zero loss? –a mathematical framework. In International Conference on Learning Representations, 2022.
  • Shallue et al. [2018] Christopher J Shallue, Jaehoon Lee, Joseph Antognini, Jascha Sohl-Dickstein, Roy Frostig, and George E Dahl. Measuring the effects of data parallelism on neural network training. arXiv preprint arXiv:1811.03600, 2018.
  • Szegedy et al. [2016] Christian Szegedy, Vincent Vanhoucke, Sergey Ioffe, Jon Shlens, and Zbigniew Wojna. Rethinking the inception architecture for computer vision. In Proceedings of the IEEE conference on computer vision and pattern recognition, pages 2818–2826, 2016.
  • Wen et al. [2019] Yeming Wen, Kevin Luk, Maxime Gazeau, Guodong Zhang, Harris Chan, and Jimmy Ba. Interplay between optimization and generalization of stochastic gradient descent with covariance noise. arXiv preprint arXiv:1902.08234, 2019.
  • Kakade et al. [2011] Sham M Kakade, Varun Kanade, Ohad Shamir, and Adam Kalai. Efficient learning of generalized linear and single index models with isotonic regression. Advances in Neural Information Processing Systems, 24, 2011.
  • Soltanolkotabi [2017] Mahdi Soltanolkotabi. Learning relus via gradient descent. In Advances in Neural Information Processing Systems (NeurIPS), 2017.
  • Candès et al. [2015] Emmanuel J. Candès, Xiaodong Li, and Mahdi Soltanolkotabi. Phase retrieval via wirtinger flow: Theory and algorithms. IEEE Transactions on Information Theory, 61(4):1985–2007, 2015. doi: 10.1109/TIT.2015.2399924.
  • Chen et al. [2019] Yuxin Chen, Yuejie Chi, Jianqing Fan, and Cong Ma. Gradient descent with random initialization: fast global convergence for nonconvex phase retrieval. Mathematical Programming, 176(1):5–37, 2019.
  • Sun et al. [2018] Ju Sun, Qing Qu, and John Wright. A geometric analysis of phase retrieval. Foundations of Computational Mathematics, 18(5):1131–1198, 2018.
  • Dudeja and Hsu [2018] Rishabh Dudeja and Daniel Hsu. Learning single-index models in gaussian space. In Sébastien Bubeck, Vianney Perchet, and Philippe Rigollet, editors, Proceedings of the 31st Conference On Learning Theory, volume 75 of Proceedings of Machine Learning Research, pages 1887–1930. PMLR, 06–09 Jul 2018. URL https://proceedings.mlr.press/v75/dudeja18a.html.
  • Chen and Meka [2020] Sitan Chen and Raghu Meka. Learning polynomials in few relevant dimensions. In Jacob Abernethy and Shivani Agarwal, editors, Proceedings of Thirty Third Conference on Learning Theory, volume 125 of Proceedings of Machine Learning Research, pages 1161–1227. PMLR, 09–12 Jul 2020. URL https://proceedings.mlr.press/v125/chen20a.html.
  • Ba et al. [2022] Jimmy Ba, Murat A. Erdogdu, Taiji Suzuki, Zhichao Wang, Denny Wu, and Greg Yang. High-dimensional asymptotics of feature learning: How one gradient step improves the representation. In Advances in Neural Information Processing Systems (NeurIPS), 2022.
  • Abbe et al. [2022] Emmanuel Abbe, Enric Boix Adsera, and Theodor Misiakiewicz. The merged-staircase property: a necessary and nearly sufficient condition for sgd learning of sparse functions on two-layer neural networks. In Conference on Learning Theory, pages 4782–4887. PMLR, 2022.
  • Hopkins et al. [2016] Samuel B. Hopkins, Tselil Schramm, Jonathan Shi, and David Steurer. Fast spectral algorithms from sum-of-squares proofs: Tensor decomposition and planted sparse vectors. In Proceedings of the Forty-Eighth Annual ACM Symposium on Theory of Computing, STOC ’16, page 178–191, New York, NY, USA, 2016. Association for Computing Machinery. ISBN 9781450341325. doi: 10.1145/2897518.2897529. URL https://doi.org/10.1145/2897518.2897529.
  • Anandkumar et al. [2017] Anima Anandkumar, Yuan Deng, Rong Ge, and Hossein Mobahi. Homotopy analysis for tensor pca. In Satyen Kale and Ohad Shamir, editors, Proceedings of the 2017 Conference on Learning Theory, volume 65 of Proceedings of Machine Learning Research, pages 79–104. PMLR, 07–10 Jul 2017. URL https://proceedings.mlr.press/v65/anandkumar17a.html.
  • Biroli et al. [2020] Giulio Biroli, Chiara Cammarota, and Federico Ricci-Tersenghi. How to iron out rough landscapes and get optimal performances: averaged gradient descent and its application to tensor pca. Journal of Physics A: Mathematical and Theoretical, 53(17):174003, apr 2020. doi: 10.1088/1751-8121/ab7b1f. URL https://dx.doi.org/10.1088/1751-8121/ab7b1f.
  • Goel et al. [2020] Surbhi Goel, Aravind Gollakota, Zhihan Jin, Sushrut Karmalkar, and Adam Klivans. Superpolynomial lower bounds for learning one-layer neural networks using gradient descent. arXiv preprint arXiv:2006.12011, 2020.
  • Diakonikolas et al. [2020] Ilias Diakonikolas, Daniel M Kane, Vasilis Kontonis, and Nikos Zarifis. Algorithms and sq lower bounds for pac learning one-hidden-layer relu networks. In Conference on Learning Theory, pages 1514–1539, 2020.
  • Hopkins et al. [2015] Samuel B. Hopkins, Jonathan Shi, and David Steurer. Tensor principal component analysis via sum-of-square proofs. In Peter Grünwald, Elad Hazan, and Satyen Kale, editors, Proceedings of The 28th Conference on Learning Theory, volume 40 of Proceedings of Machine Learning Research, pages 956–1006, Paris, France, 03–06 Jul 2015. PMLR. URL https://proceedings.mlr.press/v40/Hopkins15.html.
  • Kunisky et al. [2019] Dmitriy Kunisky, Alexander S. Wein, and Afonso S. Bandeira. Notes on computational hardness of hypothesis testing: Predictions using the low-degree likelihood ratio, 2019.
  • Bandeira et al. [2022] Afonso S Bandeira, Ahmed El Alaoui, Samuel Hopkins, Tselil Schramm, Alexander S Wein, and Ilias Zadik. The franz-parisi criterion and computational trade-offs in high dimensional statistics. In S. Koyejo, S. Mohamed, A. Agarwal, D. Belgrave, K. Cho, and A. Oh, editors, Advances in Neural Information Processing Systems, volume 35, pages 33831–33844. Curran Associates, Inc., 2022.
  • Brennan et al. [2021] Matthew S Brennan, Guy Bresler, Sam Hopkins, Jerry Li, and Tselil Schramm. Statistical query algorithms and low degree tests are almost equivalent. In Mikhail Belkin and Samory Kpotufe, editors, Proceedings of Thirty Fourth Conference on Learning Theory, volume 134 of Proceedings of Machine Learning Research, pages 774–774. PMLR, 15–19 Aug 2021. URL https://proceedings.mlr.press/v134/brennan21a.html.
  • Dudeja and Hsu [2021] Rishabh Dudeja and Daniel Hsu. Statistical query lower bounds for tensor pca. Journal of Machine Learning Research, 22(83):1–51, 2021. URL http://jmlr.org/papers/v22/20-837.html.
  • Dudeja and Hsu [2022] Rishabh Dudeja and Daniel Hsu. Statistical-computational trade-offs in tensor pca and related problems via communication complexity, 2022.
  • Ben Arous et al. [2020] Gérard Ben Arous, Reza Gheissari, and Aukosh Jagannath. Algorithmic thresholds for tensor PCA. The Annals of Probability, 48(4):2052 – 2087, 2020. doi: 10.1214/19-AOP1415. URL https://doi.org/10.1214/19-AOP1415.
  • Ros et al. [2019] Valentina Ros, Gerard Ben Arous, Giulio Biroli, and Chiara Cammarota. Complex energy landscapes in spiked-tensor and simple glassy models: Ruggedness, arrangements of local minima, and phase transitions. Phys. Rev. X, 9:011003, Jan 2019. doi: 10.1103/PhysRevX.9.011003. URL https://link.aps.org/doi/10.1103/PhysRevX.9.011003.
  • Ben Arous et al. [2019] Gérard Ben Arous, Song Mei, Andrea Montanari, and Mihai Nica. The landscape of the spiked tensor model. Communications on Pure and Applied Mathematics, 72(11):2282–2330, 2019. doi: https://doi.org/10.1002/cpa.21861. URL https://onlinelibrary.wiley.com/doi/abs/10.1002/cpa.21861.
  • Szörényi [2009] Balázs Szörényi. Characterizing statistical query learning: Simplified notions and proofs. In ALT, 2009.
  • Pinelis [1994] Iosif Pinelis. Optimum Bounds for the Distributions of Martingales in Banach Spaces. The Annals of Probability, 22(4):1679 – 1706, 1994. doi: 10.1214/aop/1176988477. URL https://doi.org/10.1214/aop/1176988477.
  • Bradbury et al. [2018] James Bradbury, Roy Frostig, Peter Hawkins, Matthew James Johnson, Chris Leary, Dougal Maclaurin, George Necula, Adam Paszke, Jake VanderPlas, Skye Wanderman-Milne, and Qiao Zhang. JAX: composable transformations of Python+NumPy programs, 2018. URL http://github.com/google/jax.
  • Biewald [2020] Lukas Biewald. Experiment tracking with weights and biases, 2020. URL https://www.wandb.com/. Software available from wandb.com.

Appendix A Background and Notation

A.1 Tensor Notation

Throughout this section let T(d)kT\in(\mathbb{R}^{d})^{\otimes k} be a kk-tensor.

Definition 5 (Tensor Action).

For a jj tensor A(d)jA\in(\mathbb{R}^{d})^{\otimes j} with jkj\leq k, we define the action T[A]T[A] of TT on AA by

(T[A])i1,,ikj:=Ti1,,ikAikj+1,,ik(d)(kj).\displaystyle(T[A])_{i_{1},\ldots,i_{k-j}}:=T_{i_{1},\ldots,i_{k}}A^{i_{k-j+1},\ldots,i_{k}}\in(\mathbb{R}^{d})^{\otimes(k-j)}.

We will also use T,A\expectationvalue{T,A} to denote T[A]=A[T]T[A]=A[T] when A,TA,T are both kk tensors. Note that this corresponds to the standard dot product after flattening A,TA,T.

Definition 6 (Permutation/Transposition).

Given a kk-tensor TT and a permutation πSk\pi\in S_{k}, we use π(T)\pi(T) to denote the result of permuting the axes of TT by the permutation π\pi, i.e.

π(T)i1,,ik:=Tiπ(1),,iπ(k).\displaystyle\pi(T)_{i_{1},\ldots,i_{k}}:=T_{i_{\pi(1)},\ldots,i_{\pi(k)}}.
Definition 7 (Symmetrization).

We define Symk(d)2k\operatorname{Sym}_{k}\in(\mathbb{R}^{d})^{\otimes 2k} by

(Symk)i1,,ik,j1,,jk=1k!πSkδiπ(1),j1δiπ(k),jk\displaystyle(\operatorname{Sym}_{k})_{i_{1},\ldots,i_{k},j_{1},\ldots,j_{k}}=\frac{1}{k!}\sum_{\pi\in S_{k}}\delta_{i_{\pi(1)},j_{1}}\cdots\delta_{i_{\pi(k)},j_{k}}

where SkS_{k} is the symmetric group on 1,,k1,\ldots,k. Note that Symk\operatorname{Sym}_{k} acts on kk tensors TT by

\displaystyle(\operatorname{Sym}_{k}[T])_{i_{1},\ldots,i_{k}}=\frac{1}{k!}\sum_{\pi\in S_{k}}\pi(T)_{i_{1},\ldots,i_{k}}.

i.e. Symk[T]\operatorname{Sym}_{k}[T] is the symmetrized version of TT.

We will also overload notation and use Sym\operatorname{Sym} to denote the symmetrization operator, i.e. if TT is a kk-tensor, Sym(T):=Symk[T]\operatorname{Sym}(T):=\operatorname{Sym}_{k}[T].

Definition 8 (Symmetric Tensor Product).

For a kk tensor TT and a jj tensor AA we define the symmetric tensor product of TT and AA by

T~A:=Sym(TA).\displaystyle T\operatorname{\widetilde{\otimes}}A:=\operatorname{Sym}(T\otimes A).
Lemma 1.

For any tensor TT,

Sym(T)FTF.\displaystyle\norm{\operatorname{Sym}(T)}_{F}\leq\norm{T}_{F}.
Proof.
Sym(T)F=1k!πSkπ(T)F1k!πSkπ(T)F=TF\displaystyle\norm{\operatorname{Sym}(T)}_{F}=\norm{\frac{1}{k!}\sum_{\pi\in S_{k}}\pi(T)}_{F}\leq\frac{1}{k!}\sum_{\pi\in S_{k}}\norm{\pi(T)}_{F}=\norm{T}_{F}

because permuting the indices of TT does not change the Frobenius norm. ∎

We will use the following two lemmas for tensor moments of the Gaussian distribution and the uniform distribution over the sphere:

Definition 9.

For integers k,d>0k,d>0, define the quantity νk(d)\nu_{k}^{(d)} as

νk(d):=(2k1)!!j=0k11d+2j=Θ(dk).\displaystyle\nu_{k}^{(d)}:=(2k-1)!!\prod_{j=0}^{k-1}\frac{1}{d+2j}=\Theta(d^{-k}).

Note that νk(d)=𝔼zSd1[z12k]\nu_{k}^{(d)}=\operatorname{\mathbb{E}}_{z\sim S^{d-1}}[z_{1}^{2k}].

Lemma 2 (Tensorized Moments).
𝔼xN(0,Id)[x2k]=(2k1)!!I~k and 𝔼zSd1[z2k]=νk(d)I~k.\displaystyle\operatorname{\mathbb{E}}_{x\sim N(0,I_{d})}[x^{\otimes 2k}]=(2k-1)!!I^{\operatorname{\widetilde{\otimes}}k}\mbox{\quad and\quad}\operatorname{\mathbb{E}}_{z\sim S^{d-1}}[z^{\otimes 2k}]=\nu_{k}^{(d)}I^{\operatorname{\widetilde{\otimes}}k}.
Proof.

For the Gaussian moment, see [6]. The spherical moment follows from the decomposition x=dzrx\stackrel{{\scriptstyle d}}{{=}}zr where zUnif(Sd1)z\sim\operatorname{Unif}(S^{d-1}) and rχ(d)r\sim\chi(d) are independent. ∎
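A quick Monte Carlo check of Definition 9 and the spherical part of Lemma 2 for the first two moments (ν1(d)=1/d\nu_{1}^{(d)}=1/d and ν2(d)=3/(d(d+2))\nu_{2}^{(d)}=3/(d(d+2))); the dimension and sample count are arbitrary.

```python
import numpy as np

d, n = 20, 400_000
rng = np.random.default_rng(0)
z = rng.standard_normal((n, d))
z /= np.linalg.norm(z, axis=1, keepdims=True)       # uniform samples on S^{d-1}

# nu_k^{(d)} = E[z_1^{2k}] should match (2k-1)!! * prod_{j<k} 1/(d+2j).
print(np.mean(z[:, 0] ** 2), 1 / d)                 # k = 1: nu_1 = 1/d
print(np.mean(z[:, 0] ** 4), 3 / (d * (d + 2)))     # k = 2: nu_2 = 3/(d(d+2))
```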

Lemma 3.
I~jF2=(νj(d))1\displaystyle\|I^{\operatorname{\widetilde{\otimes}}j}\|_{F}^{2}=(\nu_{j}^{(d)})^{-1}
Proof.

By Lemma 2 we have

\displaystyle 1=\operatorname{\mathbb{E}}_{z\sim S^{d-1}}[\|z\|^{2j}]=\langle\operatorname{\mathbb{E}}_{z\sim S^{d-1}}[z^{\otimes 2j}],I^{\operatorname{\widetilde{\otimes}}j}\rangle=\nu_{j}^{(d)}\|I^{\operatorname{\widetilde{\otimes}}j}\|_{F}^{2}.

A.2 Hermite Polynomials and Hermite Tensors

We provide a brief review of the properties of Hermite polynomials and Hermite tensors.

Definition 10.

We define the kkth Hermite polynomial Hek(x)He_{k}(x) by

Hek(x):=(1)kkμ(x)μ(x)\displaystyle He_{k}(x):=(-1)^{k}\frac{\nabla^{k}\mu(x)}{\mu(x)}

where μ(x):=ex22(2π)d/2\mu(x):=\frac{e^{-\frac{\norm{x}^{2}}{2}}}{(2\pi)^{d/2}} is the PDF of a standard Gaussian in dd dimensions. Note that when d=1d=1, this definition reduces to the standard univariate Hermite polynomials.

We begin with the classical properties of the scalar Hermite polynomials:

Lemma 4 (Properties of Hermite Polynomials).

When d=1d=1,

  • Orthogonality:

    𝔼xN(0,1)[Hej(x)Hek(x)]=k!δjk\displaystyle\operatorname{\mathbb{E}}_{x\sim N(0,1)}[He_{j}(x)He_{k}(x)]=k!\delta_{jk}
  • Derivatives:

    ddxHek(x)=kHek1(x)\displaystyle\frac{d}{dx}He_{k}(x)=kHe_{k-1}(x)
  • Correlations: If x,yN(0,1)x,y\sim N(0,1) are correlated Gaussians with correlation α:=𝔼[xy]\alpha:=\operatorname{\mathbb{E}}[xy],

    𝔼x,y[Hej(x)Hek(y)]=k!δjkαk.\displaystyle\operatorname{\mathbb{E}}_{x,y}[He_{j}(x)He_{k}(y)]=k!\delta_{jk}\alpha^{k}.
  • Hermite Expansion: If fL2(μ)f\in L^{2}(\mu) where μ\mu is the PDF of a standard Gaussian,

    f(x)=L2(μ)k0ckk!Hek(x) where ck=𝔼xN(0,1)[f(x)Hek(x)].\displaystyle f(x)\stackrel{{\scriptstyle L^{2}(\mu)}}{{=}}\sum_{k\geq 0}\frac{c_{k}}{k!}He_{k}(x)\mbox{\quad where\quad}c_{k}=\operatorname{\mathbb{E}}_{x\sim N(0,1)}[f(x)He_{k}(x)].

These properties also have tensor analogues:

Lemma 5 (Hermite Polynomials in Higher Dimensions).
  • Relationship to Univariate Hermite Polynomials: If w=1\norm{w}=1,

    Hek(wx)=Hek(x),wk\displaystyle He_{k}(w\cdot x)=\expectationvalue{He_{k}(x),w^{\otimes k}}
  • Orthogonality:

    𝔼xN(0,Id)[Hej(x)Hek(x)]=δjkk!Symk\displaystyle\operatorname{\mathbb{E}}_{x\sim N(0,I_{d})}\quantity[He_{j}(x)\otimes He_{k}(x)]=\delta_{jk}k!\operatorname{Sym}_{k}

    or equivalently, for any jj tensor AA and kk tensor BB:

    𝔼xN(0,Id)[Hej(x),AHek(x),B]=δjkk!Sym(A),Sym(B).\displaystyle\operatorname{\mathbb{E}}_{x\sim N(0,I_{d})}\quantity[\expectationvalue{He_{j}(x),A}\expectationvalue{He_{k}(x),B}]=\delta_{jk}k!\expectationvalue{\operatorname{Sym}(A),\operatorname{Sym}(B)}.
  • Hermite Expansions: If f:df:\mathbb{R}^{d}\to\mathbb{R} satisfies 𝔼xN(0,Id)[f(x)2]<\operatorname{\mathbb{E}}_{x\sim N(0,I_{d})}[f(x)^{2}]<\infty,

    f=k01k!Hek(x),Ck where Ck=𝔼xN(0,Id)[f(x)Hek(x)].\displaystyle f=\sum_{k\geq 0}\frac{1}{k!}\expectationvalue{He_{k}(x),C_{k}}\mbox{\quad where\quad}C_{k}=\operatorname{\mathbb{E}}_{x\sim N(0,I_{d})}[f(x)He_{k}(x)].

Appendix B Proof of Theorem 1

The proof of Theorem 1 is divided into four parts. First, Section B.1 introduces some notation that will be used throughout the proof. Next, Section B.2 computes matching upper and lower bounds for the gradient of the smoothed population loss. Section B.3 then establishes concentration of the empirical gradient of the smoothed loss. Finally, Section B.4 combines the bounds in Section B.2 and Section B.3 with a standard online SGD analysis to arrive at the final rate.

B.1 Additional Notation

Throughout the proof we will assume that wSd1w\in S^{d-1} so that w\nabla_{w} denotes the spherical gradient of ww. In particular, wg(w)w\nabla_{w}g(w)\perp w for any g:Sd1g:S^{d-1}\to\mathbb{R}. We will also use α\alpha to denote www\cdot w^{\star} so that we can write expressions such as:

wα=Pww=wαw.\displaystyle\nabla_{w}\alpha=P_{w}^{\perp}w^{\star}=w^{\star}-\alpha w.

We will use the following assumption on λ\lambda without reference throughout the proof:

Assumption 2.

λ2d/C\lambda^{2}\leq d/C for a sufficiently large constant CC.

We note that this is satisfied for the optimal choice of λ=d1/4\lambda=d^{1/4}.

We will use \tilde{O}(\cdot) to hide \operatorname{polylog}(d) dependencies. Explicitly, X=\tilde{O}(1) if there exist C_{1},C_{2}>0 such that \absolutevalue{X}\leq C_{1}\log(d)^{C_{2}}. We will also use the following shorthand for denoting high probability events:

Definition 11.

We say an event EE happens with high probability if for every k0k\geq 0 there exists d(k)d(k) such that for all dd(k)d\geq d(k), [E]1dk.\operatorname{\mathbb{P}}[E]\geq 1-d^{-k}.

Note that high probability events are closed under polynomially sized union bounds. As an example, if XN(0,1)X\sim N(0,1) then Xlog(d)X\leq\log(d) with high probability because

\displaystyle\operatorname{\mathbb{P}}[X>\log(d)]\leq\exp(-\log(d)^{2}/2)\ll d^{-k}

for sufficiently large dd. In general, Lemma 24 shows that if XX is mean zero and has polynomial tails, i.e. there exists CC such that E[Xp]1/ppCE[X^{p}]^{1/p}\leq p^{C}, then X=O~(1)X=\tilde{O}(1) with high probability.

B.2 Computing the Smoothed Population Gradient

Recall that

σ(x)=L2(μ)k0ckk!Hek(x) where ck:=𝔼xN(0,1)[σ(x)Hek(x)].\displaystyle\sigma(x)\stackrel{{\scriptstyle L^{2}(\mu)}}{{=}}\sum_{k\geq 0}\frac{c_{k}}{k!}He_{k}(x)\mbox{\quad where\quad}c_{k}:=\operatorname{\mathbb{E}}_{x\sim N(0,1)}[\sigma(x)He_{k}(x)].

In addition, because we assumed that 𝔼xN(0,1)[σ(x)2]=1\operatorname{\mathbb{E}}_{x\sim N(0,1)}[\sigma(x)^{2}]=1 we have Parseval’s identity:

1=𝔼xN(0,1)[σ(x)2]=k0ck2k!.\displaystyle 1=\operatorname{\mathbb{E}}_{x\sim N(0,1)}[\sigma(x)^{2}]=\sum_{k\geq 0}\frac{c_{k}^{2}}{k!}.

This Hermite decomposition immediately implies a closed form for the population loss:

Lemma 6 (Population Loss).

Let α=ww\alpha=w\cdot w^{\star}. Then,

L(w)=k0ck2k![1αk].\displaystyle L(w)=\sum_{k\geq 0}\frac{c_{k}^{2}}{k!}[1-\alpha^{k}].
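
As a simple illustration, if \sigma=He_{{k^{\star}}}/\sqrt{{k^{\star}}!} then c_{{k^{\star}}}=\sqrt{{k^{\star}}!}, all other Hermite coefficients vanish, Parseval's identity holds, and Lemma 6 reduces to

\displaystyle L(w)=1-\alpha^{{k^{\star}}},

so near \alpha=0 the population landscape is flat to order {k^{\star}}-1 in \alpha.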

Lemma 6 implies that to understand the smoothed population loss L_{\lambda}(w)=(\mathcal{L}_{\lambda}L)(w), it suffices to understand \mathcal{L}_{\lambda}(\alpha^{k}) for k\geq 0. First, we will show that the set of single index models is closed under the smoothing operator \mathcal{L}_{\lambda}:

Lemma 7.

Let g:[1,1]g:[-1,1]\to\mathbb{R} and let uSd1u\in S^{d-1}. Then

λ(g(wu))=gλ(wu)\displaystyle\mathcal{L}_{\lambda}\quantity(g(w\cdot u))=g_{\lambda}(w\cdot u)

where

gλ(α):=𝔼zSd2[g(α+λz11α21+λ2)].\displaystyle g_{\lambda}(\alpha):=\operatorname{\mathbb{E}}_{z\sim S^{d-2}}\quantity[g\quantity(\frac{\alpha+\lambda z_{1}\sqrt{1-\alpha^{2}}}{\sqrt{1+\lambda^{2}}})].
Proof.

Expanding the definition of λ\mathcal{L}_{\lambda} gives:

λ(g(wu))\displaystyle\mathcal{L}_{\lambda}\quantity(g(w\cdot u)) =𝔼zμw[g(w+λzw+λzu)]\displaystyle=\operatorname{\mathbb{E}}_{z\sim\mu_{w}}\quantity[g\quantity(\frac{w+\lambda z}{\norm{w+\lambda z}}\cdot u)]
=𝔼zμw[g(wu+λzu1+λ2)].\displaystyle=\operatorname{\mathbb{E}}_{z\sim\mu_{w}}\quantity[g\quantity(\frac{w\cdot u+\lambda z\cdot u}{\sqrt{1+\lambda^{2}}})].

We now claim that when z\sim\mu_{w}, we have z\cdot u\stackrel{{\scriptstyle d}}{{=}}z^{\prime}_{1}\sqrt{1-(w\cdot u)^{2}} where z^{\prime}\sim S^{d-2}; this completes the proof. To see this, note that we can decompose \mathbb{R}^{d} into \operatorname{span}\{w\}\oplus\operatorname{span}\{w\}^{\perp}. Under this decomposition we have the polyspherical decomposition z\stackrel{{\scriptstyle d}}{{=}}(0,z^{\prime}) where z^{\prime}\sim S^{d-2}. Then

\displaystyle z\cdot u=z^{\prime}\cdot P_{w}^{\perp}u\stackrel{{\scriptstyle d}}{{=}}z^{\prime}_{1}\|P_{w}^{\perp}u\|=z^{\prime}_{1}\sqrt{1-(w\cdot u)^{2}}. ∎
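
As a quick sanity check of Lemma 7, taking g(\alpha)=\alpha gives g_{\lambda}(\alpha)=\frac{\alpha}{\sqrt{1+\lambda^{2}}} since \operatorname{\mathbb{E}}_{z\sim S^{d-2}}[z_{1}]=0, while g(\alpha)=\alpha^{2} gives

\displaystyle g_{\lambda}(\alpha)=\frac{\alpha^{2}+\frac{\lambda^{2}}{d-1}(1-\alpha^{2})}{1+\lambda^{2}}

using \operatorname{\mathbb{E}}_{z\sim S^{d-2}}[z_{1}^{2}]=\frac{1}{d-1}. In particular, up to the (1+\lambda^{2})^{-1} prefactor, smoothing replaces \alpha^{2} by roughly \max(\alpha^{2},\lambda^{2}/d), previewing the thresholds in Definition 12 below.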

Of central interest are the quantities λ(αk)\mathcal{L}_{\lambda}(\alpha^{k}) as these terms show up when smoothing the population loss (see Lemma 6). We begin by defining the quantity sk(α;λ)s_{k}(\alpha;\lambda) which will provide matching upper and lower bounds on λ(αk)\mathcal{L}_{\lambda}(\alpha^{k}) when α0\alpha\geq 0:

Definition 12.

We define sk(α;λ)s_{k}(\alpha;\lambda) by

sk(α;λ):=1(1+λ2)k/2{αkα2λ2d(λ2d)k2α2λ2d and k is evenα(λ2d)k12α2λ2d and k is odd.\displaystyle s_{k}(\alpha;\lambda):=\frac{1}{(1+\lambda^{2})^{k/2}}\begin{cases}\alpha^{k}&\alpha^{2}\geq\frac{\lambda^{2}}{d}\\ \quantity(\frac{\lambda^{2}}{d})^{\frac{k}{2}}&\alpha^{2}\leq\frac{\lambda^{2}}{d}\text{ and $k$ is even}\\ \alpha\quantity(\frac{\lambda^{2}}{d})^{\frac{k-1}{2}}&\alpha^{2}\leq\frac{\lambda^{2}}{d}\text{ and $k$ is odd}\end{cases}.
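
Note that the branches of s_{k}(\alpha;\lambda) agree at the crossover \alpha^{2}=\frac{\lambda^{2}}{d}, where \alpha^{k}=\quantity(\frac{\lambda^{2}}{d})^{\frac{k}{2}}=\alpha\quantity(\frac{\lambda^{2}}{d})^{\frac{k-1}{2}}, so s_{k}(\alpha;\lambda) is continuous in \alpha. Moreover, for \alpha^{2}\leq\frac{\lambda^{2}}{d} we have s_{k}(\alpha;\lambda)\geq\frac{\alpha^{k}}{(1+\lambda^{2})^{k/2}}.
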
Lemma 8.

For all k0k\geq 0 and α0\alpha\geq 0, there exist constants c(k),C(k)c(k),C(k) such that

c(k)sk(α;λ)λ(αk)C(k)sk(α;λ).\displaystyle c(k)s_{k}(\alpha;\lambda)\leq\mathcal{L}_{\lambda}(\alpha^{k})\leq C(k)s_{k}(\alpha;\lambda).
Proof.

Using Lemma 7 we have that

λ(αk)\displaystyle\mathcal{L}_{\lambda}(\alpha^{k}) =𝔼zSd2[(α+λz11α21+λ2)k]\displaystyle=\operatorname{\mathbb{E}}_{z\sim S^{d-2}}\quantity[\quantity(\frac{\alpha+\lambda z_{1}\sqrt{1-\alpha^{2}}}{\sqrt{1+\lambda^{2}}})^{k}]
=(1+λ2)k/2j=0k(kj)αkjλj(1α2)j/2𝔼zSd2[z1j].\displaystyle=(1+\lambda^{2})^{-k/2}\sum_{j=0}^{k}\binom{k}{j}\alpha^{k-j}\lambda^{j}(1-\alpha^{2})^{j/2}\operatorname{\mathbb{E}}_{z\sim S^{d-2}}[z_{1}^{j}].

Now note that when jj is odd, 𝔼zSd2[z1j]=0\operatorname{\mathbb{E}}_{z\sim S^{d-2}}[z_{1}^{j}]=0 so we can re-index this sum to get

λ(αk)\displaystyle\mathcal{L}_{\lambda}(\alpha^{k}) =(1+λ2)k/2j=0k2(k2j)αk2jλ2j(1α2)j𝔼zSd2[z12j]\displaystyle=(1+\lambda^{2})^{-k/2}\sum_{j=0}^{\lfloor\frac{k}{2}\rfloor}\binom{k}{2j}\alpha^{k-2j}\lambda^{2j}(1-\alpha^{2})^{j}\operatorname{\mathbb{E}}_{z\sim S^{d-2}}[z_{1}^{2j}]
=(1+λ2)k/2j=0k2(k2j)αk2jλ2j(1α2)jνj(d1).\displaystyle=(1+\lambda^{2})^{-k/2}\sum_{j=0}^{\lfloor\frac{k}{2}\rfloor}\binom{k}{2j}\alpha^{k-2j}\lambda^{2j}(1-\alpha^{2})^{j}\nu_{j}^{(d-1)}.

Note that every term in this sum is non-negative. Now we can ignore constants depending on kk and use that νj(d1)dj\nu_{j}^{(d-1)}\asymp d^{-j} to get

λ(αk)(α1+λ2)kj=0k2(λ2(1α2)α2d)j.\displaystyle\mathcal{L}_{\lambda}(\alpha^{k})\asymp\quantity(\frac{\alpha}{\sqrt{1+\lambda^{2}}})^{k}\sum_{j=0}^{\lfloor\frac{k}{2}\rfloor}\quantity(\frac{\lambda^{2}(1-\alpha^{2})}{\alpha^{2}d})^{j}.

Now when \alpha^{2}\geq\frac{\lambda^{2}}{d}, this is a decreasing geometric series which is dominated by its first term, so \mathcal{L}_{\lambda}(\alpha^{k})\asymp\quantity(\frac{\alpha}{\sqrt{1+\lambda^{2}}})^{k}. Next, when \alpha^{2}\leq\frac{\lambda^{2}}{d} we have by Assumption 2 that \alpha\leq C^{-1/2}, so 1-\alpha^{2} is bounded away from 0. Therefore the geometric series is dominated by its last term, which is

\displaystyle\frac{1}{(1+\lambda^{2})^{k/2}}\begin{cases}\quantity(\frac{\lambda^{2}}{d})^{\frac{k}{2}}&\text{$k$ is even}\\ \alpha\quantity(\frac{\lambda^{2}}{d})^{\frac{k-1}{2}}&\text{$k$ is odd}\end{cases}

which completes the proof. ∎

Next, in order to understand the population gradient, we need to understand how the smoothing operator λ\mathcal{L}_{\lambda} commutes with differentiation. We note that these do not directly commute because the smoothing distribution μw\mu_{w} depends on ww so this term must be differentiated as well. However, smoothing and differentiation almost commute, which is described in the following lemma:

Lemma 9.

Define the dimension-dependent univariate smoothing operator by:

λ(d)g(α):=𝔼zSd2[g(α+λz11α21+λ2)].\displaystyle\mathcal{L}_{\lambda}^{(d)}g(\alpha):=\operatorname{\mathbb{E}}_{z\sim S^{d-2}}\quantity[g\quantity(\frac{\alpha+\lambda z_{1}\sqrt{1-\alpha^{2}}}{\sqrt{1+\lambda^{2}}})].

Then,

ddαλ(d)(g(α))=λ(d)(g(α))1+λ2λ2α(1+λ2)(d1)λ(d+2)(g′′(α)).\displaystyle\frac{d}{d\alpha}\mathcal{L}_{\lambda}^{(d)}(g(\alpha))=\frac{\mathcal{L}_{\lambda}^{(d)}(g^{\prime}(\alpha))}{\sqrt{1+\lambda^{2}}}-\frac{\lambda^{2}\alpha}{(1+\lambda^{2})(d-1)}\mathcal{L}_{\lambda}^{(d+2)}(g^{\prime\prime}(\alpha)).
Proof.

Directly differentiating the definition for λ(d)\mathcal{L}_{\lambda}^{(d)} gives

\displaystyle\frac{d}{d\alpha}\mathcal{L}_{\lambda}^{(d)}(g(\alpha)) =\frac{\mathcal{L}_{\lambda}^{(d)}(g^{\prime}(\alpha))}{\sqrt{1+\lambda^{2}}}-\operatorname{\mathbb{E}}_{z\sim S^{d-2}}\quantity[\frac{\alpha\lambda z_{1}}{\sqrt{(1+\lambda^{2})(1-\alpha^{2})}}g^{\prime}\quantity(\frac{\alpha+\lambda z_{1}\sqrt{1-\alpha^{2}}}{\sqrt{1+\lambda^{2}}})].

By Lemma 25, this is equal to

\displaystyle\frac{d}{d\alpha}\mathcal{L}_{\lambda}^{(d)}(g(\alpha)) =\frac{\mathcal{L}_{\lambda}^{(d)}(g^{\prime}(\alpha))}{\sqrt{1+\lambda^{2}}}-\frac{\lambda^{2}\alpha}{(1+\lambda^{2})(d-1)}\operatorname{\mathbb{E}}_{z\sim S^{d}}\quantity[g^{\prime\prime}\quantity(\frac{\alpha+\lambda z_{1}\sqrt{1-\alpha^{2}}}{\sqrt{1+\lambda^{2}}})]
=λ(d)(g(α))1+λ2λ2α(1+λ2)(d1)λ(d+2)(g′′(α)).\displaystyle=\frac{\mathcal{L}_{\lambda}^{(d)}(g^{\prime}(\alpha))}{\sqrt{1+\lambda^{2}}}-\frac{\lambda^{2}\alpha}{(1+\lambda^{2})(d-1)}\mathcal{L}_{\lambda}^{(d+2)}(g^{\prime\prime}(\alpha)).

Now we are ready to analyze the population gradient:

Lemma 10.
wLλ(w)=(wαw)cλ(α)\displaystyle\nabla_{w}L_{\lambda}(w)=-(w^{\star}-\alpha w)c_{\lambda}(\alpha)

where for αC1/4d1/2\alpha\geq C^{-1/4}d^{-1/2},

cλ(α)sk1(α;λ)1+λ2.\displaystyle c_{\lambda}(\alpha)\asymp\frac{s_{{k^{\star}}-1}(\alpha;\lambda)}{\sqrt{1+\lambda^{2}}}.
Proof.

Recall that

L(w)=1k0ck2k!αk.\displaystyle L(w)=1-\sum_{k\geq 0}\frac{c_{k}^{2}}{k!}\alpha^{k}.

Because k{k^{\star}} is the index of the first nonzero Hermite coefficient, we can start this sum at k=kk={k^{\star}}. Smoothing and differentiating gives:

\displaystyle\nabla_{w}L_{\lambda}(w)=-(w^{\star}-\alpha w)c_{\lambda}(\alpha)\mbox{\quad where\quad}c_{\lambda}(\alpha):=\sum_{k\geq{k^{\star}}}\frac{c_{k}^{2}}{k!}\frac{d}{d\alpha}\mathcal{L}_{\lambda}(\alpha^{k}).

We will break this into the k=kk={k^{\star}} term and the k>kk>{k^{\star}} tail. First when k=kk={k^{\star}} we can use Lemma 9 and Lemma 8 to get:

ck2(k)!ddαλ(αk)=ck2(k)![kλ(αk1)1+λ2k(k1)λ2αλ(d+2)(αk2)(1+λ2)(d1)].\displaystyle\frac{c_{{k^{\star}}}^{2}}{({k^{\star}})!}\frac{d}{d\alpha}\mathcal{L}_{\lambda}(\alpha^{k^{\star}})=\frac{c_{{k^{\star}}}^{2}}{({k^{\star}})!}\quantity[\frac{{k^{\star}}\mathcal{L}_{\lambda}(\alpha^{{k^{\star}}-1})}{\sqrt{1+\lambda^{2}}}-\frac{{k^{\star}}({k^{\star}}-1)\lambda^{2}\alpha\mathcal{L}_{\lambda}^{(d+2)}(\alpha^{{k^{\star}}-2})}{(1+\lambda^{2})(d-1)}].

The first term is equal up to constants to sk1(α;λ)1+λ2\frac{s_{{k^{\star}}-1}(\alpha;\lambda)}{\sqrt{1+\lambda^{2}}} while the second term is equal up to constants to λ2α(1+λ2)dsk2(α;λ)\frac{\lambda^{2}\alpha}{(1+\lambda^{2})d}s_{{k^{\star}}-2}(\alpha;\lambda). However, we have that

λ2α(1+λ2)dsk2(α;λ)sk1(α;λ)1+λ2={λ2dα2λ2dλ2dα2λ2d and k is evenα2α2λ2d and k is oddλ2d1C.\displaystyle\frac{\frac{\lambda^{2}\alpha}{(1+\lambda^{2})d}s_{{k^{\star}}-2}(\alpha;\lambda)}{\frac{s_{{k^{\star}}-1}(\alpha;\lambda)}{\sqrt{1+\lambda^{2}}}}=\begin{cases}\frac{\lambda^{2}}{d}&\alpha^{2}\geq\frac{\lambda^{2}}{d}\\ \frac{\lambda^{2}}{d}&\alpha^{2}\leq\frac{\lambda^{2}}{d}\text{ and ${k^{\star}}$ is even}\\ \alpha^{2}&\alpha^{2}\leq\frac{\lambda^{2}}{d}\text{ and ${k^{\star}}$ is odd}\end{cases}\leq\frac{\lambda^{2}}{d}\leq\frac{1}{C}.

Therefore the k=kk={k^{\star}} term in cλ(α)c_{\lambda}(\alpha) is equal up to constants to sk1(α;λ)1+λ2\frac{s_{{k^{\star}}-1}(\alpha;\lambda)}{\sqrt{1+\lambda^{2}}}.

Next, we handle the k>kk>{k^{\star}} tail. By Lemma 9 this is equal to

k>kck2k![kλ(αk1)1+λ2k(k1)λ2αλ(d+2)(αk2)(1+λ2)(d1)].\displaystyle\sum_{k>{k^{\star}}}\frac{c_{k}^{2}}{k!}\quantity[\frac{k\mathcal{L}_{\lambda}(\alpha^{k-1})}{\sqrt{1+\lambda^{2}}}-\frac{k(k-1)\lambda^{2}\alpha\mathcal{L}_{\lambda}^{(d+2)}(\alpha^{k-2})}{(1+\lambda^{2})(d-1)}].

Now recall that from Lemma 8, λ(αk)\mathcal{L}_{\lambda}(\alpha^{k}) is always non-negative so we can use ck2k!c_{k}^{2}\leq k! to bound this tail in absolute value by

k>kkλ(αk1)1+λ2+k(k1)λ2αλ(d+2)(αk2)(1+λ2)(d1)\displaystyle\sum_{k>{k^{\star}}}\frac{k\mathcal{L}_{\lambda}(\alpha^{k-1})}{\sqrt{1+\lambda^{2}}}+\frac{k(k-1)\lambda^{2}\alpha\mathcal{L}_{\lambda}^{(d+2)}(\alpha^{k-2})}{(1+\lambda^{2})(d-1)}
11+λ2λ(k>kkαk1)+λ2α(1+λ2)dλ(d+2)(k>kk(k1)αk2)\displaystyle\lesssim\frac{1}{\sqrt{1+\lambda^{2}}}\mathcal{L}_{\lambda}\quantity(\sum_{k>{k^{\star}}}k\alpha^{k-1})+\frac{\lambda^{2}\alpha}{(1+\lambda^{2})d}\mathcal{L}_{\lambda}^{(d+2)}\quantity(\sum_{k>{k^{\star}}}k(k-1)\alpha^{k-2})
11+λ2λ(αk(1α)2)+λ2α(1+λ2)dλ(d+2)(αk1(1α)3).\displaystyle\lesssim\frac{1}{\sqrt{1+\lambda^{2}}}\mathcal{L}_{\lambda}\quantity(\frac{\alpha^{{k^{\star}}}}{(1-\alpha)^{2}})+\frac{\lambda^{2}\alpha}{(1+\lambda^{2})d}\mathcal{L}_{\lambda}^{(d+2)}\quantity(\frac{\alpha^{{k^{\star}}-1}}{(1-\alpha)^{3}}).

Now by Corollary 3, this is bounded for d5d\geq 5 by

sk(α;λ)1+λ2+λ2α(1+λ2)dsk1(α;λ).\displaystyle\frac{s_{{k^{\star}}}(\alpha;\lambda)}{\sqrt{1+\lambda^{2}}}+\frac{\lambda^{2}\alpha}{(1+\lambda^{2})d}s_{{k^{\star}}-1}(\alpha;\lambda).

For the first term, we have

sk(α;λ)sk1(α;λ)={αλα2λ2dλdαα2λ2d and k is evenαλα2λ2d and k is oddC1/4.\displaystyle\frac{s_{{k^{\star}}}(\alpha;\lambda)}{s_{{k^{\star}}-1}(\alpha;\lambda)}=\begin{cases}\frac{\alpha}{\lambda}&\alpha^{2}\geq\frac{\lambda^{2}}{d}\\ \frac{\lambda}{d\alpha}&\alpha^{2}\leq\frac{\lambda^{2}}{d}\text{ and ${k^{\star}}$ is even}\\ \frac{\alpha}{\lambda}&\alpha^{2}\leq\frac{\lambda^{2}}{d}\text{ and ${k^{\star}}$ is odd}\end{cases}\leq C^{-1/4}.

The second term is trivially bounded by

λ2α(1+λ2)dλ2(1+λ2)d1C(1+λ2)1C1+λ2\displaystyle\frac{\lambda^{2}\alpha}{(1+\lambda^{2})d}\leq\frac{\lambda^{2}}{(1+\lambda^{2})d}\leq\frac{1}{C(1+\lambda^{2})}\leq\frac{1}{C\sqrt{1+\lambda^{2}}}

which completes the proof. ∎

B.3 Concentrating the Empirical Gradient

We cannot directly apply Lemma 7 to σ(wx)\sigma(w\cdot x) as x1\norm{x}\neq 1. Instead, we will use the properties of the Hermite tensors to directly smooth σ(wx)\sigma(w\cdot x).

Lemma 11.
λ(Hek(wx))=Hek(x),Tk(w)\displaystyle\mathcal{L}_{\lambda}(He_{k}(w\cdot x))=\expectationvalue{He_{k}(x),T_{k}(w)}

where

Tk(w)=(1+λ2)k2j=0k2(k2j)w(k2j)~(Pw)~jλ2jνj(d1).\displaystyle T_{k}(w)=(1+\lambda^{2})^{-\frac{k}{2}}\sum_{j=0}^{\lfloor\frac{k}{2}\rfloor}\binom{k}{2j}w^{\otimes(k-2j)}\operatorname{\widetilde{\otimes}}(P_{w}^{\perp})^{\operatorname{\widetilde{\otimes}}j}\lambda^{2j}\nu_{j}^{(d-1)}.
Proof.

Using Lemma 5, we can write

λ(Hek(wx))=Hek(x),λ(wk).\displaystyle\mathcal{L}_{\lambda}\quantity(He_{k}(w\cdot x))=\expectationvalue{He_{k}(x),\mathcal{L}_{\lambda}\quantity(w^{\otimes k})}.

Now

Tk(w)\displaystyle T_{k}(w) =λ(wk)\displaystyle=\mathcal{L}_{\lambda}(w^{\otimes k})
=𝔼zμw[(w+λz1+λ2)k]\displaystyle=\operatorname{\mathbb{E}}_{z\sim\mu_{w}}\quantity[\quantity(\frac{w+\lambda z}{\sqrt{1+\lambda^{2}}})^{\otimes k}]
=(1+λ2)k2j=0k(kj)w(kj)~𝔼zμw[zj]λj.\displaystyle=(1+\lambda^{2})^{-\frac{k}{2}}\sum_{j=0}^{k}\binom{k}{j}w^{\otimes(k-j)}\operatorname{\widetilde{\otimes}}\operatorname{\mathbb{E}}_{z\sim\mu_{w}}[z^{\otimes j}]\lambda^{j}.

Now by Lemma 2, this is equal to

Tk(w)=(1+λ2)k2j=0k2(k2j)w(k2j)~(Pw)~jλ2jνj(d1)\displaystyle T_{k}(w)=(1+\lambda^{2})^{-\frac{k}{2}}\sum_{j=0}^{\lfloor\frac{k}{2}\rfloor}\binom{k}{2j}w^{\otimes(k-2j)}\operatorname{\widetilde{\otimes}}(P_{w}^{\perp})^{\operatorname{\widetilde{\otimes}}j}\lambda^{2j}\nu_{j}^{(d-1)}

which completes the proof. ∎
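
For example, when k=2 the formula reduces to

\displaystyle T_{2}(w)=\frac{1}{1+\lambda^{2}}\quantity(ww^{\top}+\frac{\lambda^{2}}{d-1}P_{w}^{\perp}),

so smoothing shrinks the rank one component ww^{\top} by a factor of 1+\lambda^{2} and adds a small isotropic component in the directions orthogonal to w.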

Lemma 12.

For any uSd1u\in S^{d-1} with uwu\perp w,

𝔼xN(0,Id)[(uwλ(Hek(wx)))2]\displaystyle\operatorname{\mathbb{E}}_{x\sim N(0,I_{d})}\quantity[\quantity(u\cdot\nabla_{w}\mathcal{L}_{\lambda}(He_{k}(w\cdot x)))^{2}]
k!(k21+λ2λ(αk1)+λ4k4(1+λ2)2d2λ(d+2)(αk2))|α=11+λ2.\displaystyle\qquad\lesssim k!\quantity(\frac{k^{2}}{1+\lambda^{2}}\mathcal{L}_{\lambda}\quantity(\alpha^{k-1})+\frac{\lambda^{4}k^{4}}{(1+\lambda^{2})^{2}d^{2}}\mathcal{L}_{\lambda}^{(d+2)}\quantity(\alpha^{k-2}))\evaluated{}_{\alpha=\frac{1}{\sqrt{1+\lambda^{2}}}}.
Proof.

Recall that by Lemma 11 we have

λ(Hek(wx))=Hek(x),Tk(w)\displaystyle\mathcal{L}_{\lambda}(He_{k}(w\cdot x))=\expectationvalue{He_{k}(x),T_{k}(w)}

where

Tk(w):=(1+λ2)k2j=0k2(k2j)w(k2j)~(Pw)~jλ2jνj(d1).\displaystyle T_{k}(w):=(1+\lambda^{2})^{-\frac{k}{2}}\sum_{j=0}^{\lfloor\frac{k}{2}\rfloor}\binom{k}{2j}w^{\otimes(k-2j)}\operatorname{\widetilde{\otimes}}(P_{w}^{\perp})^{\operatorname{\widetilde{\otimes}}j}\lambda^{2j}\nu_{j}^{(d-1)}.

Differentiating this with respect to ww gives

uwλ(Hek(wx))=Hek(x),wTk(w)[u].\displaystyle u\cdot\nabla_{w}\mathcal{L}_{\lambda}(He_{k}(w\cdot x))=\expectationvalue{He_{k}(x),\nabla_{w}T_{k}(w)[u]}.

Now note that by Lemma 5:

𝔼xN(0,Id)[(uwλ(Hek(wx)))2]\displaystyle\operatorname{\mathbb{E}}_{x\sim N(0,I_{d})}\quantity[\quantity(u\cdot\nabla_{w}\mathcal{L}_{\lambda}(He_{k}(w\cdot x)))^{2}] =𝔼xN(0,Id)[Hek(x),wTk(w)[u]2]\displaystyle=\operatorname{\mathbb{E}}_{x\sim N(0,I_{d})}\quantity[\expectationvalue{He_{k}(x),\nabla_{w}T_{k}(w)[u]}^{2}]
=k!wTk(w)[u]F2.\displaystyle=k!\norm{\nabla_{w}T_{k}(w)[u]}_{F}^{2}.

Therefore it suffices to compute the Frobenius norm of wTk(w)[u]\nabla_{w}T_{k}(w)[u]. We first explicitly differentiate Tk(w)T_{k}(w):

wTk(w)[u]\displaystyle\nabla_{w}T_{k}(w)[u] =(1+λ2)k2j=0k2(k2j)(k2j)u~wk2j1~(Pw)~jλ2jνj(d1)\displaystyle=(1+\lambda^{2})^{-\frac{k}{2}}\sum_{j=0}^{\lfloor\frac{k}{2}\rfloor}\binom{k}{2j}(k-2j)u\operatorname{\widetilde{\otimes}}w^{\otimes k-2j-1}\operatorname{\widetilde{\otimes}}(P_{w}^{\perp})^{\operatorname{\widetilde{\otimes}}j}\lambda^{2j}\nu_{j}^{(d-1)}
(1+λ2)k2j=0k2(k2j)(2j)u~wk2j+1~(Pw)~(j1)λ2jνj(d1)\displaystyle\qquad-(1+\lambda^{2})^{-\frac{k}{2}}\sum_{j=0}^{\lfloor\frac{k}{2}\rfloor}\binom{k}{2j}(2j)u\operatorname{\widetilde{\otimes}}w^{\otimes k-2j+1}\operatorname{\widetilde{\otimes}}(P_{w}^{\perp})^{\operatorname{\widetilde{\otimes}}(j-1)}\lambda^{2j}\nu_{j}^{(d-1)}
=k(1+λ2)k2j=0k12(k12j)u~wk12j~(Pw)~jλ2jνj(d1)\displaystyle=\frac{k}{(1+\lambda^{2})^{\frac{k}{2}}}\sum_{j=0}^{\lfloor\frac{k-1}{2}\rfloor}\binom{k-1}{2j}u\operatorname{\widetilde{\otimes}}w^{\otimes k-1-2j}\operatorname{\widetilde{\otimes}}(P_{w}^{\perp})^{\operatorname{\widetilde{\otimes}}j}\lambda^{2j}\nu_{j}^{(d-1)}
λ2k(k1)(d1)(1+λ2)k2j=0k22(k22j)u~wk12j~(Pw)~jλ2jνj(d+1).\displaystyle\qquad-\frac{\lambda^{2}k(k-1)}{(d-1)(1+\lambda^{2})^{\frac{k}{2}}}\sum_{j=0}^{\lfloor\frac{k-2}{2}\rfloor}\binom{k-2}{2j}u\operatorname{\widetilde{\otimes}}w^{\otimes k-1-2j}\operatorname{\widetilde{\otimes}}(P_{w}^{\perp})^{\operatorname{\widetilde{\otimes}}j}\lambda^{2j}\nu_{j}^{(d+1)}.

Taking Frobenius norms gives

wTk(w)[u]F2\displaystyle\norm{\nabla_{w}T_{k}(w)[u]}_{F}^{2}
k2(1+λ2)kj=0k12(k12j)u~wk12j~(Pw)~jλ2jνj(d1)F2\displaystyle\lesssim\frac{k^{2}}{(1+\lambda^{2})^{k}}\norm{\sum_{j=0}^{\lfloor\frac{k-1}{2}\rfloor}\binom{k-1}{2j}u\operatorname{\widetilde{\otimes}}w^{\otimes k-1-2j}\operatorname{\widetilde{\otimes}}(P_{w}^{\perp})^{\operatorname{\widetilde{\otimes}}j}\lambda^{2j}\nu_{j}^{(d-1)}}_{F}^{2}
+λ4k4(d1)2(1+λ2)kj=0k22(k22j)u~wk12j~(Pw)~jλ2jνj(d+1)F2.\displaystyle\qquad+\frac{\lambda^{4}k^{4}}{(d-1)^{2}(1+\lambda^{2})^{k}}\norm{\sum_{j=0}^{\lfloor\frac{k-2}{2}\rfloor}\binom{k-2}{2j}u\operatorname{\widetilde{\otimes}}w^{\otimes k-1-2j}\operatorname{\widetilde{\otimes}}(P_{w}^{\perp})^{\operatorname{\widetilde{\otimes}}j}\lambda^{2j}\nu_{j}^{(d+1)}}_{F}^{2}.

Now we can use Lemma 1 to pull out uu and get:

wTk(w)[u]F2\displaystyle\norm{\nabla_{w}T_{k}(w)[u]}_{F}^{2}
k2(1+λ2)kj=0k12(k12j)wk12j~(Pw)~jλ2jνj(d1)F2\displaystyle\lesssim\frac{k^{2}}{(1+\lambda^{2})^{k}}\norm{\sum_{j=0}^{\lfloor\frac{k-1}{2}\rfloor}\binom{k-1}{2j}w^{\otimes k-1-2j}\operatorname{\widetilde{\otimes}}(P_{w}^{\perp})^{\operatorname{\widetilde{\otimes}}j}\lambda^{2j}\nu_{j}^{(d-1)}}_{F}^{2}
+λ4k4d2(1+λ2)kj=0k22(k22j)wk12j~(Pw)~jλ2jνj(d+1)F2.\displaystyle\qquad+\frac{\lambda^{4}k^{4}}{d^{2}(1+\lambda^{2})^{k}}\norm{\sum_{j=0}^{\lfloor\frac{k-2}{2}\rfloor}\binom{k-2}{2j}w^{\otimes k-1-2j}\operatorname{\widetilde{\otimes}}(P_{w}^{\perp})^{\operatorname{\widetilde{\otimes}}j}\lambda^{2j}\nu_{j}^{(d+1)}}_{F}^{2}.

Now note that the terms in each sum are orthogonal as at least one ww will need to be contracted with a PwP_{w}^{\perp}. Therefore this is equivalent to:

wTk(w)[u]F2\displaystyle\norm{\nabla_{w}T_{k}(w)[u]}_{F}^{2}
k2(1+λ2)kj=0k12(k12j)2λ4j(νj(d1))2wk12j~(Pw)~jF2\displaystyle\lesssim\frac{k^{2}}{(1+\lambda^{2})^{k}}\sum_{j=0}^{\lfloor\frac{k-1}{2}\rfloor}\binom{k-1}{2j}^{2}\lambda^{4j}(\nu_{j}^{(d-1)})^{2}\norm{w^{\otimes k-1-2j}\operatorname{\widetilde{\otimes}}(P_{w}^{\perp})^{\operatorname{\widetilde{\otimes}}j}}_{F}^{2}
+λ4k4d2(1+λ2)kj=0k22(k22j)2λ4j(νj(d+1))2wk12j~(Pw)~jF2.\displaystyle\qquad+\frac{\lambda^{4}k^{4}}{d^{2}(1+\lambda^{2})^{k}}\sum_{j=0}^{\lfloor\frac{k-2}{2}\rfloor}\binom{k-2}{2j}^{2}\lambda^{4j}(\nu_{j}^{(d+1)})^{2}\norm{w^{\otimes k-1-2j}\operatorname{\widetilde{\otimes}}(P_{w}^{\perp})^{\operatorname{\widetilde{\otimes}}j}}_{F}^{2}.

Next, note that for any kk tensor AA, Sym(A)F2=1k!πSkA,π(A)\norm{\operatorname{Sym}(A)}_{F}^{2}=\frac{1}{k!}\sum_{\pi\in S_{k}}\expectationvalue{A,\pi(A)}. When A=wk2j~(Pw)~jA=w^{\otimes k-2j}\operatorname{\widetilde{\otimes}}(P_{w}^{\perp})^{\operatorname{\widetilde{\otimes}}j}, the only permutations that don’t give 0 are the ones which pair up all of the wws of which there are (2j)!(k2j)!(2j)!(k-2j)!. Therefore, by Lemma 3,

wk2j~(Pw)~jF2=1(k2j)(Pw)~jF2=1νj(d1)(k2j).\displaystyle\norm{w^{\otimes k-2j}\operatorname{\widetilde{\otimes}}(P_{w}^{\perp})^{\operatorname{\widetilde{\otimes}}j}}_{F}^{2}=\frac{1}{\binom{k}{2j}}\norm{(P_{w}^{\perp})^{\operatorname{\widetilde{\otimes}}j}}_{F}^{2}=\frac{1}{\nu_{j}^{(d-1)}\binom{k}{2j}}.

Plugging this in gives:

wTk(w)[u]F2\displaystyle\norm{\nabla_{w}T_{k}(w)[u]}_{F}^{2}
k2(1+λ2)kj=0k12(k12j)λ4jνj(d1)\displaystyle\lesssim\frac{k^{2}}{(1+\lambda^{2})^{k}}\sum_{j=0}^{\lfloor\frac{k-1}{2}\rfloor}\binom{k-1}{2j}\lambda^{4j}\nu_{j}^{(d-1)}
+λ4k4d2(1+λ2)kj=0k22(k22j)λ4jνj(d+1).\displaystyle\qquad+\frac{\lambda^{4}k^{4}}{d^{2}(1+\lambda^{2})^{k}}\sum_{j=0}^{\lfloor\frac{k-2}{2}\rfloor}\binom{k-2}{2j}\lambda^{4j}\nu_{j}^{(d+1)}.

Now note that

\displaystyle\mathcal{L}_{\lambda}(\alpha^{k})\evaluated{}_{\alpha=\frac{1}{\sqrt{1+\lambda^{2}}}}=\frac{1}{(1+\lambda^{2})^{k}}\sum_{j=0}^{\lfloor\frac{k}{2}\rfloor}\binom{k}{2j}\lambda^{4j}\nu_{j}^{(d-1)}

which completes the proof. ∎

Corollary 1.

For any uSd1u\in S^{d-1} with uwu\perp w,

𝔼xN(0,Id)[(uwλ(σ(wx)))2]min(1+λ2,d)(k1)1+λ2.\displaystyle\operatorname{\mathbb{E}}_{x\sim N(0,I_{d})}\quantity[\quantity(u\cdot\nabla_{w}\mathcal{L}_{\lambda}\quantity(\sigma(w\cdot x)))^{2}]\lesssim\frac{\min\quantity(1+\lambda^{2},\sqrt{d})^{-({k^{\star}}-1)}}{1+\lambda^{2}}.
Proof.

Note that

λ(σ(wx))=kckk!Hek(x),Tk(w)\displaystyle\mathcal{L}_{\lambda}\quantity(\sigma(w\cdot x))=\sum_{k}\frac{c_{k}}{k!}\expectationvalue{He_{k}(x),T_{k}(w)}

so

uwλ(σ(wx))=kckk!Hek(x),wTk(w)[u].\displaystyle u\cdot\nabla_{w}\mathcal{L}_{\lambda}\quantity(\sigma(w\cdot x))=\sum_{k}\frac{c_{k}}{k!}\expectationvalue{He_{k}(x),\nabla_{w}T_{k}(w)[u]}.

By Lemma 4, these terms are orthogonal so by Lemma 12,

𝔼xN(0,Id)[(uwλ(σ(wx)))2]\displaystyle\operatorname{\mathbb{E}}_{x\sim N(0,I_{d})}\quantity[\quantity(u\cdot\nabla_{w}\mathcal{L}_{\lambda}\quantity(\sigma(w\cdot x)))^{2}]
kck2k!(k21+λ2λ(αk1)+λ4k4(1+λ2)2d2λ(d+2)(αk2))|α=11+λ2\displaystyle\lesssim\sum_{k}\frac{c_{k}^{2}}{k!}\quantity(\frac{k^{2}}{1+\lambda^{2}}\mathcal{L}_{\lambda}\quantity(\alpha^{k-1})+\frac{\lambda^{4}k^{4}}{(1+\lambda^{2})^{2}d^{2}}\mathcal{L}_{\lambda}^{(d+2)}\quantity(\alpha^{k-2}))\evaluated{}_{\alpha=\frac{1}{\sqrt{1+\lambda^{2}}}}
11+λ2λ(αk1(1α)3)+λ4(1+λ2)2d2λ(d+2)(αk2(1α)5)\displaystyle\lesssim\frac{1}{1+\lambda^{2}}\mathcal{L}_{\lambda}\quantity(\frac{\alpha^{{k^{\star}}-1}}{(1-\alpha)^{3}})+\frac{\lambda^{4}}{(1+\lambda^{2})^{2}d^{2}}\mathcal{L}_{\lambda}^{(d+2)}\quantity(\frac{\alpha^{{k^{\star}}-2}}{(1-\alpha)^{5}})
11+λ2sk1(11+λ2;λ)+λ4(1+λ2)2dsk2(11+λ2;λ)\displaystyle\lesssim\frac{1}{1+\lambda^{2}}s_{{k^{\star}}-1}\quantity(\frac{1}{\sqrt{1+\lambda^{2}}};\lambda)+\frac{\lambda^{4}}{(1+\lambda^{2})^{2}d}s_{{k^{\star}}-2}\quantity(\frac{1}{\sqrt{1+\lambda^{2}}};\lambda)
11+λ2sk1(λ1;λ)+λ4(1+λ2)2dsk2(λ1;λ).\displaystyle\lesssim\frac{1}{1+\lambda^{2}}s_{{k^{\star}}-1}\quantity(\lambda^{-1};\lambda)+\frac{\lambda^{4}}{(1+\lambda^{2})^{2}d}s_{{k^{\star}}-2}\quantity(\lambda^{-1};\lambda).

Now plugging in the formula for sk(α;λ)s_{k}(\alpha;\lambda) gives:

𝔼xN(0,Id)[(uwλ(σ(wx)))2]\displaystyle\operatorname{\mathbb{E}}_{x\sim N(0,I_{d})}\quantity[\quantity(u\cdot\nabla_{w}\mathcal{L}_{\lambda}\quantity(\sigma(w\cdot x)))^{2}] (1+λ2)k+12{(1+λ2)k121+λ2d(1+λ2d)k121+λ2d\displaystyle\lesssim(1+\lambda^{2})^{-\frac{{k^{\star}}+1}{2}}\begin{cases}(1+\lambda^{2})^{-\frac{{k^{\star}}-1}{2}}&1+\lambda^{2}\leq\sqrt{d}\\ \quantity(\frac{1+\lambda^{2}}{d})^{\frac{{k^{\star}}-1}{2}}&1+\lambda^{2}\geq\sqrt{d}\end{cases}
min(1+λ2,d)(k1)1+λ2.\displaystyle\lesssim\frac{\min\quantity(1+\lambda^{2},\sqrt{d})^{-({k^{\star}}-1)}}{1+\lambda^{2}}.
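
For example, for the choice \lambda=d^{1/4} we have 1+\lambda^{2}\asymp\sqrt{d}, so the right hand side is \lesssim d^{-{k^{\star}}/2}, whereas for \lambda=0 it is only O(1).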

The following lemma shows that wλ(σ(wx))\nabla_{w}\mathcal{L}_{\lambda}\quantity(\sigma(w\cdot x)) inherits polynomial tails from σ\sigma^{\prime}:

Lemma 13.

There exists an absolute constant CC such that for any uSd1u\in S^{d-1} with uwu\perp w and any p[0,d/C]p\in[0,d/C],

𝔼xN(0,Id)[(uwλ(σ(wx)))p]1/ppC1+λ2.\displaystyle\operatorname{\mathbb{E}}_{x\sim N(0,I_{d})}\quantity[\quantity(u\cdot\nabla_{w}\mathcal{L}_{\lambda}\quantity(\sigma(w\cdot x)))^{p}]^{1/p}\lesssim\frac{p^{C}}{\sqrt{1+\lambda^{2}}}.
Proof.

Following the proof of Lemma 9 we have

uwλ(σ(wx))\displaystyle u\cdot\nabla_{w}\mathcal{L}_{\lambda}\quantity(\sigma(w\cdot x))
=(ux)𝔼z1Sd2[σ(wx+λz1Pwx1+λ2)(11+λ2λz1(wx)Pwx1+λ2)].\displaystyle=(u\cdot x)\operatorname{\mathbb{E}}_{z_{1}\sim S^{d-2}}\quantity[\sigma^{\prime}\quantity(\frac{w\cdot x+\lambda z_{1}\norm{P_{w}^{\perp}x}}{\sqrt{1+\lambda^{2}}})\quantity(\frac{1}{\sqrt{1+\lambda^{2}}}-\frac{\lambda z_{1}(w\cdot x)}{\norm{P_{w}^{\perp}x}\sqrt{1+\lambda^{2}}})].

First, we consider the first term. By Cauchy–Schwarz, its p norm is bounded by

\displaystyle\frac{1}{\sqrt{1+\lambda^{2}}}\operatorname{\mathbb{E}}_{x}\quantity[(u\cdot x)^{2p}]^{\frac{1}{2p}}\operatorname{\mathbb{E}}_{x}\quantity[\operatorname{\mathbb{E}}_{z_{1}\sim S^{d-2}}\quantity[\sigma^{\prime}\quantity(\frac{w\cdot x+\lambda z_{1}\norm{P_{w}^{\perp}x}}{\sqrt{1+\lambda^{2}}})]^{2p}]^{\frac{1}{2p}}.

By Jensen's inequality we can pull the expectation over z_{1} outside and use Assumption 1 to get

\displaystyle\frac{1}{\sqrt{1+\lambda^{2}}}\operatorname{\mathbb{E}}_{x}[(u\cdot x)^{2p}]^{\frac{1}{2p}}\operatorname{\mathbb{E}}_{x\sim N(0,I_{d}),z_{1}\sim S^{d-2}}\quantity[\sigma^{\prime}\quantity(\frac{w\cdot x+\lambda z_{1}\norm{P_{w}^{\perp}x}}{\sqrt{1+\lambda^{2}}})^{2p}]^{\frac{1}{2p}}
=\frac{1}{\sqrt{1+\lambda^{2}}}\operatorname{\mathbb{E}}_{x\sim N(0,1)}[x^{2p}]^{\frac{1}{2p}}\operatorname{\mathbb{E}}_{x\sim N(0,1)}\quantity[\sigma^{\prime}(x)^{2p}]^{\frac{1}{2p}}
\lesssim\frac{\operatorname{poly}(p)}{\sqrt{1+\lambda^{2}}}.

Similarly, the pp norm of the second term is bounded by

\displaystyle\frac{\lambda}{\sqrt{1+\lambda^{2}}}\cdot\operatorname{poly}(p)\cdot\operatorname{\mathbb{E}}_{x,z_{1}}\quantity[\quantity(\frac{z_{1}(x\cdot w)}{\norm{P_{w}^{\perp}x}})^{2p}]^{\frac{1}{2p}}\lesssim\frac{\lambda}{d\sqrt{1+\lambda^{2}}}\operatorname{poly}(p)\ll\frac{\operatorname{poly}(p)}{\sqrt{1+\lambda^{2}}}.

Finally, we can use Corollary 1 and Lemma 13 to bound the pp norms of the gradient:

Lemma 14.

Let (x,y)(x,y) be a fresh sample and let v=Lλ(w;x;y)v=-\nabla L_{\lambda}(w;x;y). Then there exists a constant CC such that for any uSd1u\in S^{d-1} with uwu\perp w, any λd1/4\lambda\leq d^{1/4} and all 2pd/C2\leq p\leq d/C,

𝔼[(uv)p]1/ppoly(p)O~((1+λ2)12k1p).\displaystyle\operatorname{\mathbb{E}}_{\mathcal{B}}\quantity[(u\cdot v)^{p}]^{1/p}\lesssim\operatorname{poly}(p)\cdot\tilde{O}\quantity((1+\lambda^{2})^{-\frac{1}{2}-\frac{{k^{\star}}-1}{p}}).
Proof.

First,

𝔼[(uv)p]1/p\displaystyle\operatorname{\mathbb{E}}_{\mathcal{B}}[(u\cdot v)^{p}]^{1/p} =𝔼[(uLλ(w;))p]1/p\displaystyle=\operatorname{\mathbb{E}}_{\mathcal{B}}[(u\cdot\nabla L_{\lambda}(w;\mathcal{B}))^{p}]^{1/p}
=𝔼x,y[yp(uwλ(σ(wx))p]1/p.\displaystyle=\operatorname{\mathbb{E}}_{x,y}[y^{p}(u\cdot\nabla_{w}\mathcal{L}_{\lambda}\quantity(\sigma(w\cdot x))^{p}]^{1/p}.

Applying Lemma 23 with X=(uwλ(σ(wx)))2X=(u\cdot\nabla_{w}\mathcal{L}_{\lambda}(\sigma(w\cdot x)))^{2} and Y=yp(uwλ(σ(wx)))p2Y=y^{p}(u\cdot\nabla_{w}\mathcal{L}_{\lambda}(\sigma(w\cdot x)))^{p-2} gives:

𝔼[(uv)p]1/p\displaystyle\operatorname{\mathbb{E}}_{\mathcal{B}}[(u\cdot v)^{p}]^{1/p} poly(p)O~(min(1+λ2,d)k1p1+λ2)\displaystyle\lesssim\operatorname{poly}(p)\tilde{O}\quantity(\frac{\min\quantity(1+\lambda^{2},\sqrt{d})^{-\frac{{k^{\star}}-1}{p}}}{\sqrt{1+\lambda^{2}}})
poly(p)O~((1+λ2)12k1p)\displaystyle\lesssim\operatorname{poly}(p)\cdot\tilde{O}\quantity((1+\lambda^{2})^{-\frac{1}{2}-\frac{{k^{\star}}-1}{p}})

which completes the proof. ∎

Corollary 2.

Let v be as in Lemma 14. Then for all 2\leq p\leq d/C,

𝔼[v2p]1/ppoly(p)dO~((1+λ2)1k1p).\displaystyle\operatorname{\mathbb{E}}_{\mathcal{B}}[\norm{v}^{2p}]^{1/p}\lesssim\operatorname{poly}(p)\cdot d\cdot\tilde{O}\quantity((1+\lambda^{2})^{-1-\frac{{k^{\star}}-1}{p}}).
Proof.

By Jensen’s inequality,

\displaystyle\operatorname{\mathbb{E}}\quantity[\norm{v}^{2p}]=\operatorname{\mathbb{E}}\quantity[\quantity(\sum_{i=1}^{d}(v\cdot e_{i})^{2})^{p}]\lesssim d^{p-1}\operatorname{\mathbb{E}}\quantity[\sum_{i=1}^{d}(v\cdot e_{i})^{2p}]\lesssim d^{p}\max_{i}\operatorname{\mathbb{E}}[(v\cdot e_{i})^{2p}].

Taking ppth roots and using Lemma 14 finishes the proof. ∎

B.4 Analyzing the Dynamics

Throughout this section we will assume 1λd1/41\leq\lambda\leq d^{1/4}. The proof of the dynamics is split into three stages.

In the first stage, we analyze the regime \alpha\in[\alpha_{0},\lambda d^{-1/2}]. In this regime, the dynamics are driven by the signal coming from the smoothing.

In the second stage, we analyze the regime \alpha\in[\lambda d^{-1/2},1-o_{d}(1)]. This stage is similar to the analysis in Ben Arous et al. [1] and could equivalently be carried out with \lambda=0.

Finally, in the third stage, we decay the learning rate as \eta_{t}\propto\frac{1}{d+t} to achieve the optimal rate

ndk2+dϵ.\displaystyle n\gtrsim d^{\frac{{k^{\star}}}{2}}+\frac{d}{\epsilon}.

All three stages will use the following progress lemma:

Lemma 15.

Let wSd1w\in S^{d-1} and let α:=ww\alpha:=w\cdot w^{\star}. Let (x,y)(x,y) be a fresh batch and define

\displaystyle v:=-\nabla_{w}L_{\lambda}(w;x;y),\quad z:=v-\operatorname{\mathbb{E}}_{x,y}[v],\quad w^{\prime}=\frac{w+\eta v}{\norm{w+\eta v}}\mbox{\quad and\quad}\alpha^{\prime}:=w^{\prime}\cdot w^{\star}.

Then if ηα1+λ2\eta\lesssim\alpha\sqrt{1+\lambda^{2}},

α=α+η(1α2)cλ(α)+Z+O~(η2dα(1+λ2)k).\displaystyle\alpha^{\prime}=\alpha+\eta(1-\alpha^{2})c_{\lambda}(\alpha)+Z+\tilde{O}\quantity(\frac{\eta^{2}d\alpha}{(1+\lambda^{2})^{k^{\star}}}).

where 𝔼x,y[Z]=0\operatorname{\mathbb{E}}_{x,y}[Z]=0 and for all 2pd/C2\leq p\leq d/C,

𝔼x,y[Zp]1/pO~(poly(p))[η(1+λ2)12(k1)p][1α2+ηdα1+λ2].\displaystyle\operatorname{\mathbb{E}}_{x,y}[Z^{p}]^{1/p}\leq\tilde{O}(\operatorname{poly}(p))\quantity[\eta(1+\lambda^{2})^{-\frac{1}{2}-\frac{({k^{\star}}-1)}{p}}]\quantity[\sqrt{1-\alpha^{2}}+\frac{\eta d\alpha}{\sqrt{1+\lambda^{2}}}].

Furthermore, if λ=O(1)\lambda=O(1) the O~()\tilde{O}(\cdot) can be replaced with O()O(\cdot).

Proof.

Because vwv\perp w and 111+x21x221\geq\frac{1}{\sqrt{1+x^{2}}}\geq 1-\frac{x^{2}}{2},

α=α+η(vw)1+η2v2=α+η(vw)+r\displaystyle\alpha^{\prime}=\frac{\alpha+\eta(v\cdot w^{\star})}{\sqrt{1+\eta^{2}\norm{v}^{2}}}=\alpha+\eta(v\cdot w^{\star})+r

where |r|η22v2[α+η|vw|]\absolutevalue{r}\leq\frac{\eta^{2}}{2}\norm{v}^{2}\quantity[\alpha+\eta\absolutevalue{v\cdot w^{\star}}]. Note that by Lemma 14, η(vw)\eta(v\cdot w^{\star}) has moments bounded by ηλpoly(p)αpoly(p)\frac{\eta}{\lambda}\operatorname{poly}(p)\lesssim\alpha\operatorname{poly}(p). Therefore by Lemma 23 with X=v2X=\norm{v}^{2} and Y=α+η|vw|Y=\alpha+\eta\absolutevalue{v\cdot w^{\star}},

𝔼x,y[r]O~(η2𝔼[v2]α).\displaystyle\operatorname{\mathbb{E}}_{x,y}[r]\leq\tilde{O}\quantity(\eta^{2}\operatorname{\mathbb{E}}[\norm{v}^{2}]\alpha).

Plugging in the bound on 𝔼[v2]\operatorname{\mathbb{E}}[\norm{v}^{2}] from Corollary 2 gives

𝔼x,y[α]=α+η(1α2)cλ(α)+O~(η2dα(1+λ2)k).\displaystyle\operatorname{\mathbb{E}}_{x,y}[\alpha^{\prime}]=\alpha+\eta(1-\alpha^{2})c_{\lambda}(\alpha)+\tilde{O}\quantity(\eta^{2}d\alpha(1+\lambda^{2})^{-{k^{\star}}}).

In addition, by Lemma 14,

𝔼x,y[|η(vw)𝔼x,y[η(vw)]|p]1/p\displaystyle\operatorname{\mathbb{E}}_{x,y}\quantity[\absolutevalue{\eta(v\cdot w^{\star})-\operatorname{\mathbb{E}}_{x,y}[\eta(v\cdot w^{\star})]}^{p}]^{1/p} poly(p)ηPwwO~((1+λ2)12k1p)\displaystyle\lesssim\operatorname{poly}(p)\cdot\eta\cdot\norm{P_{w}^{\perp}w^{\star}}\cdot\tilde{O}\quantity((1+\lambda^{2})^{-\frac{1}{2}-\frac{{k^{\star}}-1}{p}})
=poly(p)η1α2O~((1+λ2)12k1p).\displaystyle=\operatorname{poly}(p)\cdot\eta\cdot\sqrt{1-\alpha^{2}}\cdot\tilde{O}\quantity((1+\lambda^{2})^{-\frac{1}{2}-\frac{{k^{\star}}-1}{p}}).

Similarly, by Lemma 23 with X=v2X=\norm{v}^{2} and Y=v2(p1)[α+η|vw|]pY=\norm{v}^{2(p-1)}[\alpha+\eta\absolutevalue{v\cdot w}]^{p}, Lemma 14, and Corollary 2,

𝔼x,y[|rt𝔼x,y[rt]|p]1/p\displaystyle\operatorname{\mathbb{E}}_{x,y}\quantity[\absolutevalue{r_{t}-\operatorname{\mathbb{E}}_{x,y}[r_{t}]}^{p}]^{1/p} 𝔼x,y[|rt|p]1/p\displaystyle\lesssim\operatorname{\mathbb{E}}_{x,y}\quantity[\absolutevalue{r_{t}}^{p}]^{1/p}
η2αpoly(p)O~((d(1+λ2)k)1/p(d1+λ2)p1p)\displaystyle\lesssim\eta^{2}\alpha\operatorname{poly}(p)\tilde{O}\quantity(\quantity(\frac{d}{(1+\lambda^{2})^{k^{\star}}})^{1/p}\quantity(\frac{d}{1+\lambda^{2}})^{\frac{p-1}{p}})
=poly(p)η2dαO~((1+λ2)1k1p).\displaystyle=\operatorname{poly}(p)\cdot\eta^{2}d\alpha\cdot\tilde{O}\quantity((1+\lambda^{2})^{-1-\frac{{k^{\star}}-1}{p}}).

We can now analyze the first stage in which α[d1/2,λd1/2]\alpha\in[d^{-1/2},\lambda\cdot d^{-1/2}]. This stage is dominated by the signal from the smoothing.

Lemma 16 (Stage 1).

Assume that λ1\lambda\geq 1 and α01Cd1/2\alpha_{0}\geq\frac{1}{C}d^{-1/2}. Set

η=dk2(1+λ2)k1log(d)C and T1=C(1+λ2)dk22log(d)η=O~(dk1λ2k+4)\displaystyle\eta=\frac{d^{-\frac{{k^{\star}}}{2}}(1+\lambda^{2})^{{k^{\star}}-1}}{\log(d)^{C}}\mbox{\quad and\quad}T_{1}=\frac{C(1+\lambda^{2})d^{\frac{{k^{\star}}-2}{2}}\log(d)}{\eta}=\tilde{O}\quantity(d^{{k^{\star}}-1}\lambda^{-2{k^{\star}}+4})

for a sufficiently large constant CC. Then with high probability, there exists tT1t\leq T_{1} such that αtλd1/2\alpha_{t}\geq\lambda d^{-1/2}.

Proof.

Let τ\tau be the hitting time for ατλd1/2\alpha_{\tau}\geq\lambda d^{-1/2}. For tT1t\leq T_{1}, let EtE_{t} be the event that

αt12[α0+ηj=0t1cλ(αj)].\displaystyle\alpha_{t}\geq\frac{1}{2}\quantity[\alpha_{0}+\eta\sum_{j=0}^{t-1}c_{\lambda}(\alpha_{j})].

We will prove by induction that for any tT1t\leq T_{1}, the event: {Et or tτ}\quantity{E_{t}\text{ or }t\geq\tau} happens with high probability. The base case of t=0t=0 is trivial so let t0t\geq 0 and assume the result for all s<ts<t. Note that η/λd1/2Cαj\eta/\lambda\ll\frac{d^{-1/2}}{C}\leq\alpha_{j} so by Lemma 15 and the fact that λ1\lambda\geq 1,

αt=α0+j=0t1[η(1αj2)cλ(αj)+Zj+O~(η2dαjλ2k)].\displaystyle\alpha_{t}=\alpha_{0}+\sum_{j=0}^{t-1}\quantity[\eta(1-\alpha_{j}^{2})c_{\lambda}(\alpha_{j})+Z_{j}+\tilde{O}\quantity(\eta^{2}d\alpha_{j}\lambda^{-2{k^{\star}}})].

Now note that \operatorname{\mathbb{P}}[E_{t}\text{ or }t\geq\tau]=1-\operatorname{\mathbb{P}}\quantity[\neg E_{t}\text{ and }t<\tau], so let us condition on the event t<\tau. Then by the induction hypothesis, with high probability we have \alpha_{s}\in[\frac{\alpha_{0}}{2},\lambda d^{-1/2}] for all s<t. Plugging in the value of \eta gives:

η(1αj2)cλ(αj)+O~(η2dαjλ2k)\displaystyle\eta(1-\alpha_{j}^{2})c_{\lambda}(\alpha_{j})+\tilde{O}\quantity(\eta^{2}d\alpha_{j}\lambda^{-2{k^{\star}}})
η(1αj2)cλ(αj)ηdk22λ2αjC\displaystyle\geq\eta(1-\alpha_{j}^{2})c_{\lambda}(\alpha_{j})-\frac{\eta d^{-\frac{{k^{\star}}-2}{2}}\lambda^{-2}\alpha_{j}}{C}
ηcλ(αj)2.\displaystyle\geq\frac{\eta c_{\lambda}(\alpha_{j})}{2}.

Similarly, because j=0t1Zj\sum_{j=0}^{t-1}Z_{j} is a martingale we have by Lemma 22 and Lemma 24 that with high probability,

j=0t1Zj\displaystyle\sum_{j=0}^{t-1}Z_{j} O~([T1ηλk+ηλ1][1+maxj<tηdαjλ])\displaystyle\lesssim\tilde{O}\quantity(\quantity[\sqrt{T_{1}}\cdot\eta\lambda^{-{k^{\star}}}+\eta\lambda^{-1}]\quantity[1+\max_{j<t}\frac{\eta d\alpha_{j}}{\lambda}])
O~(T1ηλk+ηλ1)\displaystyle\lesssim\tilde{O}\quantity(\sqrt{T_{1}}\cdot\eta\lambda^{-{k^{\star}}}+\eta\lambda^{-1})
d1/2C.\displaystyle\leq\frac{d^{-1/2}}{C}.

where we used that \eta d\alpha_{j}/\lambda\leq\eta\sqrt{d}\ll 1. Therefore, conditioned on t<\tau, we have with high probability that for all s\leq t:

\displaystyle\alpha_{s}\geq\frac{1}{2}\quantity[\alpha_{0}+\eta\sum_{j=0}^{s-1}c_{\lambda}(\alpha_{j})].

Now we split into two cases depending on the parity of k{k^{\star}}. First, if k{k^{\star}} is odd we have that with high probability, for all tT1t\leq T_{1}:

αtα0+ηtλ1dk12 or tτ.\displaystyle\alpha_{t}\gtrsim\alpha_{0}+\eta t\lambda^{-1}d^{-\frac{{k^{\star}}-1}{2}}\mbox{\quad or\quad}t\geq\tau.

Now let t=T1t=T_{1}. Then we have that with high probability,

αtλd1/2 or τT1\displaystyle\alpha_{t}\geq\lambda d^{-1/2}\mbox{\quad or\quad}\tau\leq T_{1}

which implies that τT1\tau\leq T_{1} with high probability. Next, if k{k^{\star}} is even we have that with high probability

αtα0+ηdk22λ2s=0t1αs or tτ.\displaystyle\alpha_{t}\gtrsim\alpha_{0}+\frac{\eta\cdot d^{-\frac{{k^{\star}}-2}{2}}}{\lambda^{2}}\sum_{s=0}^{t-1}\alpha_{s}\mbox{\quad or\quad}t\geq\tau.

As above, by Lemma 27 the first event implies that αT1λd1/2\alpha_{T_{1}}\geq\lambda d^{-1/2} so we must have τT1\tau\leq T_{1} with high probability. ∎
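
Note that for the choice \lambda=d^{1/4}, the number of iterations used in this stage is T_{1}=\tilde{O}\quantity(d^{{k^{\star}}-1}\cdot d^{-\frac{{k^{\star}}-2}{2}})=\tilde{O}\quantity(d^{{k^{\star}}/2}).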

Next, we consider what happens when αλd1/2\alpha\geq\lambda d^{-1/2}. The analysis in this stage is similar to the online SGD analysis in [1].

Lemma 17 (Stage 2).

Assume that α0λd1/2\alpha_{0}\geq\lambda d^{-1/2}. Set η,T1\eta,T_{1} as in Lemma 16. Then with high probability, αT11d1/4\alpha_{T_{1}}\geq 1-d^{-1/4}.

Proof.

The proof is almost identical to Lemma 16. We again have from Lemma 15

αtα0+j=0t1[η(1αj2)cλ(αj)+ZjO~(η2dαjλ2k)].\displaystyle\alpha_{t}\geq\alpha_{0}+\sum_{j=0}^{t-1}\quantity[\eta(1-\alpha_{j}^{2})c_{\lambda}(\alpha_{j})+Z_{j}-\tilde{O}\quantity(\eta^{2}d\alpha_{j}\lambda^{-2{k^{\star}}})].

First, from martingale concentration we have that

j=0t1Zj\displaystyle\sum_{j=0}^{t-1}Z_{j} O~([T1ηλk+ηλ1][1+ηdλ])\displaystyle\lesssim\tilde{O}\quantity(\quantity[\sqrt{T_{1}}\cdot\eta\lambda^{-{k^{\star}}}+\eta\lambda^{-1}]\quantity[1+\frac{\eta d}{\lambda}])
O~([T1ηλk+ηλ1]λ)\displaystyle\lesssim\tilde{O}\quantity(\quantity[\sqrt{T_{1}}\cdot\eta\lambda^{-{k^{\star}}}+\eta\lambda^{-1}]\cdot\lambda)
λd1/2C\displaystyle\lesssim\frac{\lambda d^{-1/2}}{C}

where we used that ηλ2d\eta\ll\frac{\lambda^{2}}{d}. Therefore with high probability,

αt\displaystyle\alpha_{t} α02+j=0t1[η(1αj2)cλ(αj)O~(η2dαjλ2k)]\displaystyle\geq\frac{\alpha_{0}}{2}+\sum_{j=0}^{t-1}\quantity[\eta(1-\alpha_{j}^{2})c_{\lambda}(\alpha_{j})-\tilde{O}\quantity(\eta^{2}d\alpha_{j}\lambda^{-2{k^{\star}}})]
α02+ηj=0t1[(1αj2)cλ(αj)αjdk22Cλ2].\displaystyle\geq\frac{\alpha_{0}}{2}+\eta\sum_{j=0}^{t-1}\quantity[(1-\alpha_{j}^{2})c_{\lambda}(\alpha_{j})-\frac{\alpha_{j}d^{-\frac{{k^{\star}}-2}{2}}}{C\lambda^{2}}].

Therefore while αt11k\alpha_{t}\leq 1-\frac{1}{{k^{\star}}}, for sufficiently large CC we have

αtα02+ηC1/2λkj=0t1αjk1.\displaystyle\alpha_{t}\geq\frac{\alpha_{0}}{2}+\frac{\eta}{C^{1/2}\lambda^{{k^{\star}}}}\sum_{j=0}^{t-1}\alpha_{j}^{{k^{\star}}-1}.

Therefore by Lemma 27, we have that there exists tT1/2t\leq T_{1}/2 such that αt11k\alpha_{t}\geq 1-\frac{1}{{k^{\star}}}. Next, let pt=1αtp_{t}=1-\alpha_{t}. Then applying Lemma 15 to ptp_{t} and using (11k)k1/e(1-\frac{1}{{k^{\star}}})^{k^{\star}}\gtrsim 1/e gives that if

r:=dk22Cλ2 and c:=1C1/2λk\displaystyle r:=\frac{d^{-\frac{{k^{\star}}-2}{2}}}{C\lambda^{2}}\mbox{\quad and\quad}c:=\frac{1}{C^{1/2}\lambda^{{k^{\star}}}}

then

pt+1\displaystyle p_{t+1} ptηcpt+ηr+Zt\displaystyle\leq p_{t}-\eta cp_{t}+\eta r+Z_{t}
=(1ηc)pt+ηr+Zt.\displaystyle=(1-\eta c)p_{t}+\eta r+Z_{t}.

Therefore,

pt+s(1ηc)spt+r/c+i=0s1(1ηc)iZt+s1i.\displaystyle p_{t+s}\leq(1-\eta c)^{s}p_{t}+r/c+\sum_{i=0}^{s-1}(1-\eta c)^{i}Z_{t+s-1-i}.

With high probability, the martingale term is bounded by λd1/2/C\lambda d^{-1/2}/C as before as long as sT1s\leq T_{1}, so for s[Clog(d)ηc,T1]s\in[\frac{C\log(d)}{\eta c},T_{1}] we have that pt+sC1/2((λ2d)k22+λd1/2)C1/2d1/4p_{t+s}\lesssim C^{-1/2}\quantity(\quantity(\frac{\lambda^{2}}{d})^{\frac{{k^{\star}}-2}{2}}+\lambda d^{-1/2})\lesssim C^{-1/2}d^{-1/4}. Setting s=T1ts=T_{1}-t and choosing CC appropriately yields pT1d1/4p_{T_{1}}\leq d^{-1/4}, which completes the proof. ∎

Finally, the third stage provides not only a hitting time but also a last-iterate guarantee. It also achieves the optimal sample complexity in terms of the target accuracy \epsilon:

Lemma 18 (Stage 3).

Assume that α01d1/4\alpha_{0}\geq 1-d^{-1/4}. Set λ=0\lambda=0 and

ηt=CC4d+t.\displaystyle\eta_{t}=\frac{C}{C^{4}d+t}.

for a sufficiently large constant CC. Then for any texp(d1/C)t\leq\exp(d^{1/C}), we have that with high probability,

αt1O(dd+t).\displaystyle\alpha_{t}\geq 1-O\quantity(\frac{d}{d+t}).
Proof.

Let pt=1αtp_{t}=1-\alpha_{t}. By Lemma 15, while pt1/kp_{t}\leq 1/{k^{\star}}:

pt+1\displaystyle p_{t+1} ptηtpt2C+Cηt2d+ηtptWt+ηt2dZt\displaystyle\leq p_{t}-\frac{\eta_{t}p_{t}}{2C}+C\eta_{t}^{2}d+\eta_{t}\sqrt{p_{t}}\cdot W_{t}+\eta_{t}^{2}d\cdot Z_{t}
=(C4d+t2C4d+t)pt+Cηt2d+ηtptWt+ηt2dZt.\displaystyle=\quantity(\frac{C^{4}d+t-2}{C^{4}d+t})p_{t}+C\eta_{t}^{2}d+\eta_{t}\sqrt{p_{t}}\cdot W_{t}+\eta_{t}^{2}d\cdot Z_{t}.

where the moments of Wt,ZtW_{t},Z_{t} are each bounded by poly(p)\operatorname{poly}(p). We will prove by induction that with probability at least 1texp(Cd1/C/e)1-t\exp(-Cd^{1/C}/e), we have for all sts\leq t:

ps2C3dC4d+s2C1k.\displaystyle p_{s}\leq\frac{2C^{3}d}{C^{4}d+s}\leq\frac{2}{C}\leq\frac{1}{{k^{\star}}}.

The base case is clear so assume the result for all sts\leq t. Then from the recurrence above,

\displaystyle p_{t+1}\leq p_{0}\frac{C^{8}d^{2}}{(C^{4}d+t)^{2}}+\frac{1}{(C^{4}d+t)^{2}}\sum_{j=0}^{t}(C^{4}d+j)^{2}\quantity[C\eta_{j}^{2}d+\eta_{j}\sqrt{p_{j}}\cdot W_{j}+\eta_{j}^{2}d\cdot Z_{j}].

First, because p0d1/4p_{0}\leq d^{-1/4},

C8p0d2(C4d+t)2C4p0dC4d+tdC4d+t.\displaystyle\frac{C^{8}p_{0}d^{2}}{(C^{4}d+t)^{2}}\leq\frac{C^{4}p_{0}d}{C^{4}d+t}\ll\frac{d}{C^{4}d+t}.

Next,

1(C4d+t)2j=0t(C4d+j)2Cηj2d\displaystyle\frac{1}{(C^{4}d+t)^{2}}\sum_{j=0}^{t}(C^{4}d+j)^{2}C\eta_{j}^{2}d =1(C4d+t)2j=0tC3d\displaystyle=\frac{1}{(C^{4}d+t)^{2}}\sum_{j=0}^{t}C^{3}d
=C3dt(C4d+t)2\displaystyle=\frac{C^{3}dt}{(C^{4}d+t)^{2}}
C3dC4d+t.\displaystyle\leq\frac{C^{3}d}{C^{4}d+t}.

The next error term is:

\displaystyle\frac{1}{(C^{4}d+t)^{2}}\sum_{j=0}^{t}(C^{4}d+j)^{2}\eta_{j}\sqrt{p_{j}}\cdot W_{j}.

Fix p=d1/Cep=\frac{d^{1/C}}{e}. Then we will bound the ppth moment of ptp_{t}:

\displaystyle\operatorname{\mathbb{E}}[p_{t}^{p}] \leq\quantity(\frac{2C^{3}d}{C^{4}d+t})^{p}+\operatorname{\mathbb{E}}\quantity[p_{t}^{p}\mathbf{1}_{p_{t}\geq\frac{2C^{3}d}{C^{4}d+t}}]
(2C3dC4d+t)p+2p[pt2C3dC4d+t]\displaystyle\leq\quantity(\frac{2C^{3}d}{C^{4}d+t})^{p}+2^{p}\operatorname{\mathbb{P}}\quantity[p_{t}\geq\frac{2C^{3}d}{C^{4}d+t}]
(2C3dC4d+t)p+2ptexp(Cd1/C).\displaystyle\leq\quantity(\frac{2C^{3}d}{C^{4}d+t})^{p}+2^{p}t\exp(-Cd^{1/C}).

Now note that because texp(d1/C)t\leq\exp(d^{1/C}),

log(texp(Cd1/C))=log(t)Cd1/Clog(t)(C1)d1/Cplog(t).\displaystyle\log(t\exp(-Cd^{1/C}))=\log(t)-Cd^{1/C}\leq-\log(t)(C-1)d^{1/C}\leq-p\log(t).

Therefore \operatorname{\mathbb{E}}[p_{t}^{p}]^{1/p}\leq\frac{4C^{3}d}{C^{4}d+t}. It follows that the p norm of the predictable quadratic variation of this error term is bounded by:

\displaystyle\operatorname{poly}(p)\sum_{j=0}^{t}(C^{4}d+j)^{4}\eta_{j}^{2}\operatorname{\mathbb{E}}[p_{j}^{p}]^{1/p} \leq\operatorname{poly}(p)\sum_{j=0}^{t}C^{5}d(C^{4}d+j)
poly(p)C5dt(C4d+t).\displaystyle\lesssim\operatorname{poly}(p)C^{5}dt(C^{4}d+t).

In addition, the pp norm of the largest term in this sum is bounded by

poly(p)C5d(C4d+t).\displaystyle\operatorname{poly}(p)\sqrt{C^{5}d(C^{4}d+t)}.

Therefore by Lemma 22 and Lemma 24, with probability at least 1-\exp(-Cd^{1/C}/e), this term is bounded by

C5dt(C4d+t)3/2d1/2C3dC4d+t.\displaystyle\frac{\sqrt{C^{5}dt}}{(C^{4}d+t)^{3/2}}\cdot d^{1/2}\leq\frac{C^{3}d}{C^{4}d+t}.

Finally, the last term is similarly bounded with probability at least 1-\exp(-Cd^{1/C}/e) by

C2dt(C4d+t)2d1/2C3dC4d+t\displaystyle\frac{C^{2}d\sqrt{t}}{(C^{4}d+t)^{2}}\cdot d^{1/2}\ll\frac{C^{3}d}{C^{4}d+t}

which completes the induction. ∎

We can now combine the above lemmas to prove Theorem 1:

Proof of Theorem 1.

By Lemmas 16, 17 and 18, if T=T1+T2T=T_{1}+T_{2} we have with high probability for all T2exp(d1/C)T_{2}\leq\exp(d^{1/C}):

αT1O(dd+T2).\displaystyle\alpha_{T}\geq 1-O\quantity(\frac{d}{d+T_{2}}).

Next, note that by Bernoulli’s inequality ((1+x)n1+nx(1+x)^{n}\geq 1+nx), we have that 1αkk(1α)1-\alpha^{k}\leq k(1-\alpha). Therefore,

L(wT)\displaystyle L(w_{T}) =k0ck2k![1αTk]\displaystyle=\sum_{k\geq 0}\frac{c_{k}^{2}}{k!}[1-\alpha_{T}^{k}]
\displaystyle\leq(1-\alpha_{T})\sum_{k\geq 1}\frac{c_{k}^{2}}{(k-1)!}
=(1αT)𝔼xN(0,1)[σ(x)2]\displaystyle=(1-\alpha_{T})\operatorname{\mathbb{E}}_{x\sim N(0,1)}[\sigma^{\prime}(x)^{2}]
dd+T2\displaystyle\lesssim\frac{d}{d+T_{2}}

which completes the proof of Theorem 1. ∎

B.5 Proof of Theorem 2

We directly follow the proof of Theorem 2 in Damian et al. [6] which is reproduced here for completeness. We begin with the following general CSQ lemma which can be found in Szörényi [38], Damian et al. [6]:

Lemma 19.

Let \mathcal{F} be a class of functions and 𝒟\mathcal{D} be a data distribution such that

𝔼x𝒟[f(x)2]=1 and |𝔼x𝒟[f(x)g(x)]|ϵfg.\displaystyle\operatorname{\mathbb{E}}_{x\sim\mathcal{D}}[f(x)^{2}]=1\mbox{\quad and\quad}\absolutevalue{\operatorname{\mathbb{E}}_{x\sim\mathcal{D}}[f(x)g(x)]}\leq\epsilon\qquad\forall f\neq g\in\mathcal{F}.

Then any correlational statistical query learner requires at least ||(τ2ϵ)2\frac{\absolutevalue{\mathcal{F}}(\tau^{2}-\epsilon)}{2} queries of tolerance τ\tau to output a function in \mathcal{F} with L2(𝒟)L^{2}(\mathcal{D}) loss at most 22ϵ2-2\epsilon.

First, we will construct a function class from a subset of :={σ(wx):wSd1}\mathcal{F}:=\quantity{\sigma(w\cdot x)~{}:~{}w\in S^{d-1}}. By [6, Lemma 3], for any ϵ\epsilon there exist 12ecϵ2d\frac{1}{2}e^{c\epsilon^{2}d} unit vectors w1,,wsw_{1},\ldots,w_{s} such that their pairwise inner products are all bounded by ϵ\epsilon. Let ^:={σ(wix):i[s]}\widehat{\mathcal{F}}:=\quantity{\sigma(w_{i}\cdot x)~{}:~{}i\in[s]}. Then for iji\neq j,

|𝔼xN(0,Id)[σ(wix)σ(wjx)]|=|k0ck2k!(wiwj)k||wiwj|kϵk.\displaystyle\absolutevalue{\operatorname{\mathbb{E}}_{x\sim N(0,I_{d})}[\sigma(w_{i}\cdot x)\sigma(w_{j}\cdot x)]}=\absolutevalue{\sum_{k\geq 0}\frac{c_{k}^{2}}{k!}(w_{i}\cdot w_{j})^{k}}\leq\absolutevalue{w_{i}\cdot w_{j}}^{{k^{\star}}}\leq\epsilon^{{k^{\star}}}.

Therefore by Lemma 19, if the learner makes m correlational queries of tolerance \tau, then

4mecϵ2d(τ2ϵk).\displaystyle 4m\geq e^{c\epsilon^{2}d}(\tau^{2}-\epsilon^{{k^{\star}}}).

Now set

ϵ=log(4m(cd)k/2)cd\displaystyle\epsilon=\sqrt{\frac{\log\quantity(4m(cd)^{{k^{\star}}/2})}{cd}}

which gives

\displaystyle\tau^{2}\leq\frac{1+\log^{{k^{\star}}/2}\quantity(4m(cd)^{{k^{\star}}/2})}{(cd)^{{k^{\star}}/2}}\lesssim\frac{\log^{{k^{\star}}/2}(md)}{d^{{k^{\star}}/2}}.
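
Under the standard heuristic that a correlational query of tolerance \tau can be simulated by averaging roughly \tau^{-2} i.i.d. samples, this tolerance bound corresponds to a sample complexity of order d^{{k^{\star}}/2} up to polylogarithmic factors.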

Appendix C Concentration Inequalities

Lemma 20 (Rosenthal-Burkholder-Pinelis Inequality [39]).

Let {Yi}i=0n\{Y_{i}\}_{i=0}^{n} be a martingale with martingale difference sequence {Xi}i=1n\{X_{i}\}_{i=1}^{n} where Xi=YiYi1X_{i}=Y_{i}-Y_{i-1}. Let

Y=i=1n𝔼[Xi2|i1]\displaystyle\expectationvalue{Y}=\sum_{i=1}^{n}\operatorname{\mathbb{E}}[\norm{X_{i}}^{2}|\mathcal{F}_{i-1}]

denote the predictable quadratic variation. Then there exists an absolute constant CC such that for all pp,

YnpC[pYp/2+pmaxiXip].\displaystyle\norm{Y_{n}}_{p}\leq C\quantity[\sqrt{p\norm{\expectationvalue{Y}}_{p/2}}+p~{}\norm{\max_{i}\norm{X_{i}}}_{p}].

The above inequality is found in Pinelis [39, Theorem 4.1]. It is often combined with the following simple lemma:

Lemma 21.

For any random variables X1,,XnX_{1},\ldots,X_{n},

maxiXip(i=1nXipp)1/p.\displaystyle\norm{\max_{i}\norm{X_{i}}}_{p}\leq\quantity(\sum_{i=1}^{n}\norm{X_{i}}_{p}^{p})^{1/p}.

This has the immediate corollary:

Lemma 22.

Let {Yi}i=0n\{Y_{i}\}_{i=0}^{n} be a martingale with martingale difference sequence {Xi}i=1n\{X_{i}\}_{i=1}^{n} where Xi=YiYi1X_{i}=Y_{i}-Y_{i-1}. Let Y=i=1n𝔼[Xi2|i1]\expectationvalue{Y}=\sum_{i=1}^{n}\operatorname{\mathbb{E}}[\norm{X_{i}}^{2}|\mathcal{F}_{i-1}] denote the predictable quadratic variation. Then there exists an absolute constant CC such that for all pp,

YnpC[pYp/2+pn1/pmaxiXip].\displaystyle\norm{Y_{n}}_{p}\leq C\quantity[\sqrt{p\norm{\expectationvalue{Y}}_{p/2}}+pn^{1/p}~{}\max_{i}\norm{X_{i}}_{p}].

We will often use the following corollary of Hölder's inequality to bound the expectation of a product of two random variables when one has polynomial tails:

Lemma 23.

Let X,YX,Y be random variables with YpσYpC\norm{Y}_{p}\leq\sigma_{Y}p^{C}. Then,

𝔼[XY]X1σY(2e)Cmax(1,1Clog(X2X1))C.\displaystyle\operatorname{\mathbb{E}}[XY]\leq\norm{X}_{1}\cdot\sigma_{Y}\cdot(2e)^{C}\cdot\max\quantity(1,\frac{1}{C}\log\quantity(\frac{\norm{X}_{2}}{\norm{X}_{1}}))^{C}.
Proof.

Fix \epsilon\in[0,1]. Then using Hölder's inequality with the split 1=(1-\epsilon)+\frac{\epsilon}{2}+\frac{\epsilon}{2} gives:

𝔼[XY]=𝔼[X1ϵXϵY]X11ϵX2ϵY2/ϵ.\displaystyle\operatorname{\mathbb{E}}[XY]=\operatorname{\mathbb{E}}[X^{1-\epsilon}X^{\epsilon}Y]\leq\norm{X}_{1}^{1-\epsilon}\norm{X}_{2}^{\epsilon}\norm{Y}_{2/\epsilon}.

Using the fact that Y has polynomial tails, we can bound this by

𝔼[XY]=𝔼[X1ϵXϵY]X11ϵX2ϵσY(2/ϵ)C.\displaystyle\operatorname{\mathbb{E}}[XY]=\operatorname{\mathbb{E}}[X^{1-\epsilon}X^{\epsilon}Y]\leq\norm{X}_{1}^{1-\epsilon}\norm{X}_{2}^{\epsilon}\sigma_{Y}(2/\epsilon)^{C}.

First, if X2eCX1\norm{X}_{2}\geq e^{C}\norm{X}_{1}, we can set ϵ=Clog(X2X1)\epsilon=\frac{C}{\log\quantity(\frac{\norm{X}_{2}}{\norm{X}_{1}})} which gives

𝔼[XY]X1σY(2eClog(X2X1))C.\displaystyle\operatorname{\mathbb{E}}[XY]\leq\norm{X}_{1}\cdot\sigma_{Y}\cdot\quantity(\frac{2e}{C}\log\quantity(\frac{\norm{X}_{2}}{\norm{X}_{1}}))^{C}.

Next, if X2eCX1\norm{X}_{2}\leq e^{C}\norm{X}_{1} we can set ϵ=1\epsilon=1 which gives

𝔼[XY]X2Y2X1σY(2e)C\displaystyle\operatorname{\mathbb{E}}[XY]\leq\norm{X}_{2}\norm{Y}_{2}\leq\norm{X}_{1}\sigma_{Y}(2e)^{C}

which completes the proof. ∎

Finally, the following basic lemma will allow us to easily convert between p-norm bounds and concentration inequalities:

Lemma 24.

Let \delta\in(0,1) and let X be a mean zero random variable satisfying

𝔼[|X|p]1/pσXpC for p=log(1/δ)C\displaystyle\operatorname{\mathbb{E}}[\absolutevalue{X}^{p}]^{1/p}\leq\sigma_{X}p^{C}\mbox{\quad for\quad}p=\frac{\log(1/\delta)}{C}

for some CC. Then with probability at least 1δ1-\delta, |X|σX(ep)C\absolutevalue{X}\leq\sigma_{X}(ep)^{C}.

Proof.

Let ϵ=σX(ep)C\epsilon=\sigma_{X}(ep)^{C}. Then,

[|X|ϵ]\displaystyle\operatorname{\mathbb{P}}[\absolutevalue{X}\geq\epsilon] =[|X|pϵp]\displaystyle=\operatorname{\mathbb{P}}[\absolutevalue{X}^{p}\geq\epsilon^{p}]
𝔼[|X|p]ϵp\displaystyle\leq\frac{\operatorname{\mathbb{E}}[\absolutevalue{X}^{p}]}{\epsilon^{p}}
(σX)pppCϵp\displaystyle\leq\frac{(\sigma_{X})^{p}p^{pC}}{\epsilon^{p}}
=eCp\displaystyle=e^{-Cp}
=δ.\displaystyle=\delta.

Appendix D Additional Technical Lemmas

The following lemma extends Stein's lemma (\operatorname{\mathbb{E}}_{x\sim N(0,1)}[xg(x)]=\operatorname{\mathbb{E}}_{x\sim N(0,1)}[g^{\prime}(x)]) to the ultraspherical distribution \mu^{(d)}, where \mu^{(d)} is the distribution of z_{1} when z\sim S^{d-1}:

Lemma 25 (Spherical Stein’s Lemma).

For any gL2(μ(d))g\in L^{2}(\mu^{(d)}),

𝔼zSd1[z1g(z1)]=𝔼zSd+1[g(z1)]d.\displaystyle\operatorname{\mathbb{E}}_{z\sim S^{d-1}}[z_{1}g(z_{1})]=\frac{\operatorname{\mathbb{E}}_{z\sim S^{d+1}}[g^{\prime}(z_{1})]}{d}.
Proof.

Recall that the density of z1z_{1} is equal to

(1x2)d32C(d) where C(d):=πΓ(d12)Γ(d2).\displaystyle\frac{(1-x^{2})^{\frac{d-3}{2}}}{C(d)}\mbox{\quad where\quad}C(d):=\frac{\sqrt{\pi}\cdot\Gamma(\frac{d-1}{2})}{\Gamma(\frac{d}{2})}.

Therefore,

\displaystyle\operatorname{\mathbb{E}}_{z\sim S^{d-1}}[z_{1}g(z_{1})]=\frac{1}{C(d)}\int_{-1}^{1}z_{1}g(z_{1})(1-z_{1}^{2})^{\frac{d-3}{2}}dz_{1}.

Now we can integrate by parts to get

\displaystyle\operatorname{\mathbb{E}}_{z\sim S^{d-1}}[z_{1}g(z_{1})] =\frac{1}{C(d)}\int_{-1}^{1}\frac{g^{\prime}(z_{1})(1-z_{1}^{2})^{\frac{d-1}{2}}}{d-1}dz_{1}
=C(d+2)C(d)(d1)𝔼zSd+1[g(z1)]\displaystyle=\frac{C(d+2)}{C(d)(d-1)}\operatorname{\mathbb{E}}_{z\sim S^{d+1}}[g^{\prime}(z_{1})]
=1d𝔼zSd+1[g(z1)].\displaystyle=\frac{1}{d}\operatorname{\mathbb{E}}_{z\sim S^{d+1}}[g^{\prime}(z_{1})].
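
As a quick check, taking g(z_{1})=z_{1} recovers \operatorname{\mathbb{E}}_{z\sim S^{d-1}}[z_{1}^{2}]=\frac{1}{d}, i.e. \nu_{1}^{(d)}=\frac{1}{d}, consistent with Lemma 2.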

Lemma 26.

For jd/4j\leq d/4,

𝔼zSd1(z1k(1z12)j)𝔼zSd2j1(z1k).\displaystyle\operatorname{\mathbb{E}}_{z\sim S^{d-1}}\quantity(\frac{z_{1}^{k}}{(1-z_{1}^{2})^{j}})\lesssim\operatorname{\mathbb{E}}_{z\sim S^{d-2j-1}}\quantity(z_{1}^{k}).
Proof.

Recall that the PDF of μ(d)\mu^{(d)} is

(1x2)d32C(d) where C(d):=πΓ(d12)Γ(d2).\displaystyle\frac{(1-x^{2})^{\frac{d-3}{2}}}{C(d)}\mbox{\quad where\quad}C(d):=\frac{\sqrt{\pi}\cdot\Gamma(\frac{d-1}{2})}{\Gamma(\frac{d}{2})}.

Using this we have that:

𝔼zSd1[z1k(1z12)j]\displaystyle\operatorname{\mathbb{E}}_{z\sim S^{d-1}}\quantity[\frac{z_{1}^{k}}{(1-z_{1}^{2})^{j}}] =1C(d)11xk(1x2)j(1x2)d32𝑑x\displaystyle=\frac{1}{C(d)}\int_{-1}^{1}\frac{x^{k}}{(1-x^{2})^{j}}(1-x^{2})^{\frac{d-3}{2}}dx
=1C(d)11xk(1x2)d2j32𝑑x\displaystyle=\frac{1}{C(d)}\int_{-1}^{1}x^{k}(1-x^{2})^{\frac{d-2j-3}{2}}dx
=C(d2j)C(d)𝔼z1μ(d2j)[z1k]\displaystyle=\frac{C(d-2j)}{C(d)}\operatorname{\mathbb{E}}_{z_{1}\sim\mu^{(d-2j)}}[z_{1}^{k}]
=Γ(d2)Γ(d2j12)Γ(d12)Γ(d2j2)𝔼zSd2j1[z1k]\displaystyle=\frac{\Gamma(\frac{d}{2})\Gamma(\frac{d-2j-1}{2})}{\Gamma(\frac{d-1}{2})\Gamma(\frac{d-2j}{2})}\operatorname{\mathbb{E}}_{z\sim S^{d-2j-1}}[z_{1}^{k}]
𝔼zSd2j1[z1k].\displaystyle\lesssim\operatorname{\mathbb{E}}_{z\sim S^{d-2j-1}}[z_{1}^{k}].

We have the following generalization of Lemma 8:

Corollary 3.

For any k,j\geq 0 with d\geq 2j+1 and \alpha\geq C^{-1/4}d^{-1/2}, there exists a constant C(j,k) such that

λ(αk(1α)j)C(j,k)sk(α;λ).\displaystyle\mathcal{L}_{\lambda}\quantity(\frac{\alpha^{k}}{(1-\alpha)^{j}})\leq C(j,k)s_{k}(\alpha;\lambda).
Proof.

Expanding the definition of λ\mathcal{L}_{\lambda} gives:

λ(αk(1α)j)\displaystyle\mathcal{L}_{\lambda}\quantity(\frac{\alpha^{k}}{(1-\alpha)^{j}}) =𝔼zSd2[(α+λz11α21+λ2)k(1α+λz11α21+λ2)j].\displaystyle=\operatorname{\mathbb{E}}_{z\sim S^{d-2}}\quantity[\frac{\quantity(\frac{\alpha+\lambda z_{1}\sqrt{1-\alpha^{2}}}{\sqrt{1+\lambda^{2}}})^{k}}{\quantity(1-\frac{\alpha+\lambda z_{1}\sqrt{1-\alpha^{2}}}{\sqrt{1+\lambda^{2}}})^{j}}].

Now let X=λ1α21+λ2(1α1+λ2)1X=\frac{\lambda\sqrt{1-\alpha^{2}}}{\sqrt{1+\lambda^{2}}}\cdot\quantity(1-\frac{\alpha}{\sqrt{1+\lambda^{2}}})^{-1} and note that by Cauchy-Schwarz, X1X\leq 1. Then,

λ(αk(1α)j)\displaystyle\mathcal{L}_{\lambda}\quantity(\frac{\alpha^{k}}{(1-\alpha)^{j}}) =1(1α1+λ2)j𝔼zSd2[(α+λz11α21+λ2)k(1Xz1)j]\displaystyle=\frac{1}{\quantity(1-\frac{\alpha}{\sqrt{1+\lambda^{2}}})^{j}}\operatorname{\mathbb{E}}_{z\sim S^{d-2}}\quantity[\frac{\quantity(\frac{\alpha+\lambda z_{1}\sqrt{1-\alpha^{2}}}{\sqrt{1+\lambda^{2}}})^{k}}{(1-Xz_{1})^{j}}]
𝔼zSd2[(1+Xz1)j(α+λz11α21+λ2)k(1X2z12)j].\displaystyle\asymp\operatorname{\mathbb{E}}_{z\sim S^{d-2}}\quantity[\frac{(1+Xz_{1})^{j}\quantity(\frac{\alpha+\lambda z_{1}\sqrt{1-\alpha^{2}}}{\sqrt{1+\lambda^{2}}})^{k}}{(1-X^{2}z_{1}^{2})^{j}}].

Now we can use the binomial theorem to expand this. Ignoring constants depending only on j,kj,k:

λ(αk(1α)j)\displaystyle\mathcal{L}_{\lambda}\quantity(\frac{\alpha^{k}}{(1-\alpha)^{j}}) =1(1α1+λ2)j𝔼zSd2[(α+λz11α21+λ2)k(1Xz1)j]\displaystyle=\frac{1}{\quantity(1-\frac{\alpha}{\sqrt{1+\lambda^{2}}})^{j}}\operatorname{\mathbb{E}}_{z\sim S^{d-2}}\quantity[\frac{\quantity(\frac{\alpha+\lambda z_{1}\sqrt{1-\alpha^{2}}}{\sqrt{1+\lambda^{2}}})^{k}}{(1-Xz_{1})^{j}}]
λki=0kαkiλi(1α2)i/2𝔼zSd2[(1+Xz1)jz1i(1X2z12)j]\displaystyle\asymp\lambda^{-k}\sum_{i=0}^{k}\alpha^{k-i}\lambda^{i}(1-\alpha^{2})^{i/2}\operatorname{\mathbb{E}}_{z\sim S^{d-2}}\quantity[\frac{(1+Xz_{1})^{j}z_{1}^{i}}{(1-X^{2}z_{1}^{2})^{j}}]
λki=0kαkiλi(1α2)i/2𝔼zSd2[(1+Xz1)jz1i(1z12)j].\displaystyle\leq\lambda^{-k}\sum_{i=0}^{k}\alpha^{k-i}\lambda^{i}(1-\alpha^{2})^{i/2}\operatorname{\mathbb{E}}_{z\sim S^{d-2}}\quantity[\frac{(1+Xz_{1})^{j}z_{1}^{i}}{(1-z_{1}^{2})^{j}}].

By Lemma 26, the expectation over zz is bounded by di2d^{-\frac{i}{2}} when ii is even and by Xdi+12Xd^{-\frac{i+1}{2}} when ii is odd: for odd ii the z1iz_{1}^{i} moment vanishes, so the leading surviving contribution comes from the odd part of (1+Xz1)j(1+Xz_{1})^{j}, which supplies the extra factor of XX. Therefore this expression is bounded by

(αλ)ki=0k2(λ2(1α2)α2d)i+(αλ)k1i=0k12α2iλ2i(1α2)id(i+1)\displaystyle\quantity(\frac{\alpha}{\lambda})^{k}\sum_{i=0}^{\lfloor\frac{k}{2}\rfloor}\quantity(\frac{\lambda^{2}(1-\alpha^{2})}{\alpha^{2}d})^{i}+\quantity(\frac{\alpha}{\lambda})^{k-1}\sum_{i=0}^{\lfloor\frac{k-1}{2}\rfloor}\alpha^{-2i}\lambda^{2i}(1-\alpha^{2})^{i}d^{-(i+1)}
sk(α;λ)+1dsk1(α;λ).\displaystyle\asymp s_{k}(\alpha;\lambda)+\frac{1}{d}s_{k-1}(\alpha;\lambda).

Now note that

1dsk1(α;λ)sk(α;λ)={λdαα2λ2dαλα2λ2d and k is evenλdαα2λ2d and k is oddC1/4.\displaystyle\frac{\frac{1}{d}s_{k-1}(\alpha;\lambda)}{s_{k}(\alpha;\lambda)}=\begin{cases}\frac{\lambda}{d\alpha}&\alpha^{2}\geq\frac{\lambda^{2}}{d}\\ \frac{\alpha}{\lambda}&\alpha^{2}\leq\frac{\lambda^{2}}{d}\text{ and $k$ is even}\\ \frac{\lambda}{d\alpha}&\alpha^{2}\leq\frac{\lambda^{2}}{d}\text{ and $k$ is odd}\end{cases}\leq C^{-1/4}.

Therefore, sk(α;λ)s_{k}(\alpha;\lambda) is the dominant term, which completes the proof. ∎

Lemma 27 (Adapted from Abbe et al. [4]).

Let η,a0>0\eta,a_{0}>0 be positive constants, and let utu_{t} be a sequence satisfying

uta0+ηs=0t1usk.\displaystyle u_{t}\geq a_{0}+\eta\sum_{s=0}^{t-1}u_{s}^{k}.

Then, if max0st1ηusk1log2k\max_{0\leq s\leq t-1}\eta u_{s}^{k-1}\leq\frac{\log 2}{k}, we have the lower bound

ut(a0(k1)12η(k1)t)1k1.\displaystyle u_{t}\geq\quantity(a_{0}^{-(k-1)}-\frac{1}{2}\eta(k-1)t)^{-\frac{1}{k-1}}.
Proof.

Consider the auxiliary sequence wt=a0+ηs=0t1wskw_{t}=a_{0}+\eta\sum_{s=0}^{t-1}w_{s}^{k}. By induction, utwtu_{t}\geq w_{t}. To lower bound wtw_{t}, we have that

η=wtwt1wt1k\displaystyle\eta=\frac{w_{t}-w_{t-1}}{w_{t-1}^{k}} =wtwt1wtkwtkwt1k\displaystyle=\frac{w_{t}-w_{t-1}}{w_{t}^{k}}\cdot\frac{w_{t}^{k}}{w_{t-1}^{k}}
wtkwt1kwt1wt1xk𝑑x\displaystyle\leq\frac{w_{t}^{k}}{w_{t-1}^{k}}\int_{w_{t-1}}^{w_{t}}\frac{1}{x^{k}}dx
=wtkwt1k(k1)(wt1(k1)wt(k1))\displaystyle=\frac{w_{t}^{k}}{w_{t-1}^{k}(k-1)}\quantity(w_{t-1}^{-(k-1)}-w_{t}^{-(k-1)})
(1+ηwt1k1)k(k1)(wt1(k1)wt(k1))\displaystyle\leq\frac{(1+\eta w_{t-1}^{k-1})^{k}}{(k-1)}\quantity(w_{t-1}^{-(k-1)}-w_{t}^{-(k-1)})
(1+log2k)k(k1)(wt1(k1)wt(k1))\displaystyle\leq\frac{(1+\frac{\log 2}{k})^{k}}{(k-1)}\quantity(w_{t-1}^{-(k-1)}-w_{t}^{-(k-1)})
2k1(wt1(k1)wt(k1)).\displaystyle\leq\frac{2}{k-1}\quantity(w_{t-1}^{-(k-1)}-w_{t}^{-(k-1)}).

Therefore

wt(k1)wt1(k1)12η(k1).\displaystyle w_{t}^{-(k-1)}\leq w_{t-1}^{-(k-1)}-\frac{1}{2}\eta(k-1).

Altogether, we get

wt(k1)a0(k1)12η(k1)t,\displaystyle w_{t}^{-(k-1)}\leq a_{0}^{-(k-1)}-\frac{1}{2}\eta(k-1)t,

or

utwt(a0(k1)12η(k1)t)1k1,\displaystyle u_{t}\geq w_{t}\geq\quantity(a_{0}^{-(k-1)}-\frac{1}{2}\eta(k-1)t)^{-\frac{1}{k-1}},

as desired. ∎
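
As an illustration of Lemma 27 (a standalone simulation with arbitrary parameters, not part of the paper's experiments), the sketch below runs the recursion with equality, u_t = a_0 + η Σ_{s<t} u_s^k, verifies that the step-size condition holds along the trajectory, and checks the closed-form lower bound at every step.

```python
# Run u_t = a0 + eta * sum_{s < t} u_s^k with equality and compare against the
# lower bound (a0^{-(k-1)} - eta (k-1) t / 2)^{-1/(k-1)} from Lemma 27.
# The parameters below are arbitrary illustrative choices.
import jax.numpy as jnp

k, a0, eta, T = 3, 0.1, 1e-2, 3000

traj = [a0]                      # u_0 = a0 (the sum is empty at t = 0)
running_sum = 0.0
for _ in range(T):
    running_sum += traj[-1] ** k
    traj.append(a0 + eta * running_sum)

u = jnp.array(traj)
t = jnp.arange(T + 1)
lower = (a0 ** (-(k - 1)) - 0.5 * eta * (k - 1) * t) ** (-1.0 / (k - 1))

# Step-size condition of Lemma 27 along the trajectory.
print(bool(jnp.all(eta * u[:-1] ** (k - 1) <= jnp.log(2.0) / k)))  # True
# Lower bound; at t = 0 both sides equal a0, so compare the later steps.
print(bool(jnp.all(u[1:] >= lower[1:])))                           # True
```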

Appendix E Additional Experimental Details

To compute the smoothed loss Lλ(w;x;y)L_{\lambda}(w;x;y), we used the closed form for λ(Hek(wx))\mathcal{L}_{\lambda}\quantity(He_{k}(w\cdot x)) (see Section B.3). Experiments were run on 8 NVIDIA A6000 GPUs. Our code is written in JAX [40], and we used Weights and Biases [41] for experiment tracking.
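
Although the experiments use this closed form, the smoothing operator itself can also be approximated by Monte Carlo. The sketch below is our own illustration rather than the experimental code: it assumes a unit-norm w, a generic per-sample or minibatch loss loss_fn(w, x, y), and the perturbation w ↦ (w + λz)/sqrt(1 + λ²) with z uniform on the unit sphere orthogonal to w, matching the change of variables used in the proof of Corollary 3; the names smoothed_loss and n_mc are hypothetical.

```python
# Minimal Monte Carlo sketch of the smoothing operator (illustrative only; the
# experiments instead use the closed form from Section B.3). Assumes w has unit
# norm and loss_fn(w, x, y) is any per-sample (or minibatch) loss.
import jax
import jax.numpy as jnp

def smoothed_loss(loss_fn, w, x, y, lam, key, n_mc=128):
    d = w.shape[0]
    z = jax.random.normal(key, (n_mc, d))
    z = z - (z @ w)[:, None] * w[None, :]                # project out the w direction
    z = z / jnp.linalg.norm(z, axis=1, keepdims=True)    # uniform on the sphere orthogonal to w
    w_pert = (w[None, :] + lam * z) / jnp.sqrt(1 + lam ** 2)  # unit-norm perturbed weights
    return jnp.mean(jax.vmap(lambda wp: loss_fn(wp, x, y))(w_pert))
```

Increasing n_mc reduces the Monte Carlo variance at a proportional cost, and since the estimator is differentiable in w whenever loss_fn is, jax.grad(smoothed_loss, argnums=1) yields a stochastic gradient of the smoothed loss.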