
Efficient SGD Neural Network Training via Sublinear Activated Neuron Identification

Lianke Qin [email protected]. UCSB.    Zhao Song [email protected]. Adobe Research.    Yuanyuan Yang [email protected]. The University of Washington.

Deep learning has been widely used in many fields, but the model training process usually consumes massive computational resources and time. Designing an efficient neural network training method with a provable convergence guarantee is therefore a fundamental and important research question. In this paper, we present a static half-space report data structure over the weights of a fully connected two-layer neural network with shifted ReLU activation, which enables activated neuron identification in sublinear time via geometric search. We also prove that our algorithm converges in O(M^{2}/\epsilon^{2}) time with network size quadratic in the coefficient norm upper bound M and the error term \epsilon.

1 Introduction

Deep learning is widely used in computer vision [50, 51, 69, 48, 43], natural language processing [19, 49], game playing [68, 71], and beyond. Training deep learning models often takes an enormous amount of computational resources. A fundamental challenge in this line of research is therefore to design an efficient neural network training method that provably converges. Existing work with provable convergence suffers from an over-parameterized network structure [29, 52, 2, 11, 9, 24, 17, 79, 90, 93, 61, 94, 13, 38].

The study by [21] established that SGD can learn polynomials with bounded coefficients, as well as particular kernel spaces, using a depth-two neural network of near-optimal size. In particular, for the class of even polynomials of bounded degree whose coefficient vector norm does not exceed M, and for an input distribution on the unit sphere, \widetilde{O}(M^{2}/\epsilon^{2}) neurons and O(M^{2}/\epsilon^{2}) iterations suffice to output a predictor with error \leq\epsilon via depth-two neural networks. However, their algorithm still has a per-iteration cost of O(mbd), where b is the batch size, m is the width of the neural tangent kernel, and d is the dimension of the input data points. The linear dependence on m comes from computing inner products between the gradient of the loss function with respect to the weights and the gradient of the weights with respect to the queried points. This seems to be a natural barrier. One natural question to ask is,

Is there an algorithm that only requires o(mbd) cost per iteration?

We provide a positive response to the preceding question and summarize our contributions as follows:

  • We propose a static half-space report data structure for a fully connected two-layer neural network with neural tangent kernel and shifted ReLU activation. Specifically, we build a half-space report data structure over the weights with batched SGD updates. At every iteration, our algorithm identifies the weights that are fired by the current data points and propagates them efficiently. Additionally, we show that at any given iteration, the number of activated neurons for each input data point is upper bounded by o(m). Thus, via geometric search, our algorithm identifies those activated neurons in time sublinear in m.

  • We show that our algorithm converges in O(M^{2}/\epsilon^{2}) time with network size quadratic in the coefficient norm upper bound M and the error term \epsilon.

1.1 Our Results

To formally introduce our main results, we first present the definitions regarding the input data distribution. In this paper, we assume that the input data distribution is supported on the unit sphere \mathbb{S}^{d-1}. On top of that, we present the definition of an R-bounded distribution.

Definition 1.1 (R-bounded distribution).

A distribution {\cal D} on \mathbb{S}^{d-1} is said to be R-bounded if, for every u\in\mathbb{S}^{d-1}, the expectation \operatorname*{{\mathbb{E}}}_{x\sim{\cal D}}[\langle u,x\rangle^{2}] is at most \frac{R^{2}}{d}.

In practice, any distribution {\cal D} on \mathbb{S}^{d-1} is \sqrt{d}-bounded. Additionally, many commonly used distributions are O(1)-bounded, or even (1+o(1))-bounded. Examples include the uniform distribution on the unit sphere \mathbb{S}^{d-1}, on the discrete cube \{\pm\frac{1}{\sqrt{d}}\}^{d}, and on \Omega(d) randomly selected points.
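As a quick sanity check on Definition 1.1 (a sketch of our own, not part of the formal development), the following Python snippet empirically estimates \max_{u}\operatorname*{{\mathbb{E}}}_{x\sim{\cal D}}[\langle u,x\rangle^{2}] over a few random directions u for the uniform distribution on \mathbb{S}^{d-1}; by symmetry the true value is exactly 1/d, so the recovered R is close to 1. All function and variable names are illustrative.

import numpy as np

def r_bound_estimate(d, num_samples=200_000, num_directions=20, seed=0):
    # Estimate max_u E_{x ~ Unif(S^{d-1})}[<u, x>^2] over random directions u.
    # For the uniform sphere distribution the true value is 1/d, so R ~ 1.
    rng = np.random.default_rng(seed)
    x = rng.standard_normal((num_samples, d))
    x /= np.linalg.norm(x, axis=1, keepdims=True)      # x uniform on the sphere
    u = rng.standard_normal((num_directions, d))
    u /= np.linalg.norm(u, axis=1, keepdims=True)      # unit test directions
    second_moments = ((x @ u.T) ** 2).mean(axis=0)     # E[<u, x>^2] per direction
    return float(np.sqrt(second_moments.max() * d))    # R with E[<u, x>^2] <= R^2/d

print(r_bound_estimate(d=50))   # approximately 1.0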

Then, we present the formal definition of the neural network structure considered in this paper. We focus on fully connected, depth-two neural networks with 2m hidden neurons, a shifted ReLU activation function \sigma_{b_{0}}:\mathbb{R}\rightarrow\mathbb{R}, and the \ell_{2} loss. More formally, we define the prediction and loss functions as follows:

Definition 1.2 (Prediction and Loss function).

Given b_{0}\in\mathbb{R}, x\in\mathbb{R}^{d}, W\in\mathbb{R}^{d\times 2m} and a\in\mathbb{R}^{2m}, we say a prediction function f is 2\mathsf{NN}(2m,b_{0}) if:

f(W,x,a):=\frac{1}{\sqrt{2m}}\sum_{r=1}^{2m}a_{r}\sigma_{b_{0}}(\langle w_{r},x\rangle)
l(W):=\frac{1}{2}\sum_{i=1}^{n}(f(W,x_{i},a)-y_{i})^{2}

where \sigma_{b_{0}} denotes the shifted ReLU function \sigma_{b_{0}}(x)=\max\{0,x-b_{0}\}.
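To make Definition 1.2 concrete, here is a minimal NumPy sketch (our own, with illustrative names) of the prediction function and the \ell_{2} loss; we store the weight vectors w_{r} as the rows of a (2m)\times d array.

import numpy as np

def shifted_relu(z, b0):
    # sigma_{b0}(z) = max{0, z - b0}, applied entrywise
    return np.maximum(0.0, z - b0)

def predict(W, x, a, b0):
    # f(W, x, a) = (1/sqrt(2m)) * sum_{r=1}^{2m} a_r * sigma_{b0}(<w_r, x>)
    two_m = a.shape[0]
    return float(a @ shifted_relu(W @ x, b0)) / np.sqrt(two_m)

def squared_loss(W, X, a, y, b0):
    # l(W) = (1/2) * sum_{i=1}^{n} (f(W, x_i, a) - y_i)^2, with X of shape (n, d)
    preds = np.array([predict(W, x, a, b0) for x in X])
    return 0.5 * float(np.sum((preds - y) ** 2))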

The shifted ReLU activation is frequently employed in practice as well as in theoretical investigations [94, 82]. For the neural network weights, we leverage the Xavier initialization with zero outputs, as described in [32]. This involves organizing the neurons into pairs, where the two neurons in each pair are initialized identically and their output weights differ only by a sign. We denote {\cal N}_{d,m}^{\sigma}:=\{h_{W}(x)=\langle u,\sigma(W_{0}x)\rangle\}, and write W:=(W_{0},u) for the aggregation of all weights. More formally, we present our weight initialization as follows:

Definition 1.3 (Weights at Initialization).

Given input dimension d, number of neurons 2m, and a constant B>0, we say the weight W=(W_{0},u) with W_{0}\in\mathbb{R}^{2m\times d} is initialized according to the distribution {\cal I}(d,m,B) if:

  • For each r\in[m], we sample w_{r}(0)\sim{\cal N}(0,I_{d}), and set w_{m+r}(0)=w_{r}(0).

  • For each r\in[m], we sample a_{r} from \{+B,-B\} uniformly at random, and set a_{m+r}=-a_{r}.

Note that if W\sim{\cal I}(d,m,B), then with probability 1, h_{W}(x)=0 for all x.
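The paired initialization of Definition 1.3 can be sketched as follows (again an illustrative sketch of ours); the final lines verify the remark above that h_{W}(x)=0 at initialization, which relies on the duplicated weight vectors being paired with output weights of opposite sign.

import numpy as np

def init_weights(d, m, B, seed=0):
    # Sample (W0, a) ~ I(d, m, B): 2m neurons in m identical pairs,
    # with opposite-sign output weights so the network outputs zero.
    rng = np.random.default_rng(seed)
    w_first = rng.standard_normal((m, d))       # w_r(0) ~ N(0, I_d) for r in [m]
    W0 = np.vstack([w_first, w_first])          # w_{m+r}(0) = w_r(0)
    a_first = rng.choice([B, -B], size=m)       # a_r uniform on {+B, -B}
    a = np.concatenate([a_first, -a_first])     # a_{m+r} = -a_r
    return W0, a

d, m, B = 10, 64, 1.0
W0, a = init_weights(d, m, B)
x = np.random.default_rng(1).standard_normal(d)
x /= np.linalg.norm(x)
b0 = np.sqrt(0.4 * np.log(2 * m))
h = float(a @ np.maximum(0.0, W0 @ x - b0)) / np.sqrt(2 * m)
print(h)   # 0.0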

Next, we present a benchmark of our algorithm:

Definition 1.4 (Even Polynomial with Bounded Coefficient Norm).

We denote by {\cal P}_{c}^{M} the class of even polynomials of degree at most c with coefficient norm bounded by M. More formally,

{\cal P}_{c}^{M}:=\Big\{p(x)=\sum_{|\alpha|\text{ even},\,|\alpha|\leq c}a_{\alpha}x^{\alpha}:\sum_{|\alpha|\text{ even},\,|\alpha|\leq c}a_{\alpha}^{2}\leq M^{2}\Big\}

We are now ready to state our main theorem (a combination of Theorem 5.1 and Lemma 6.9): Algorithm 1 is capable of learning even polynomials of bounded norm, i.e., the class {\cal P}_{c}^{M}, with near-optimal sample complexity and network size, and with per-iteration time sublinear in m:

Theorem 1.5 (Main theorem).

Given the following:

  • a constant c>0, an accuracy parameter \epsilon, and positive constants B>0 and \eta>0,

  • a choice of parameters m=\widetilde{O}(d^{-1}\epsilon^{-2}M^{2}R^{2}), T=O(\epsilon^{-2}M^{2}),

  • sample access to an R-bounded distribution {\cal D} (Definition 1.1),

  • input dimension d and coefficient norm bound M (Definition 1.4),

  • {\cal L}_{\cal D}(h):=\operatorname*{{\mathbb{E}}}_{(x,y)\sim{\cal D}}l(h(x),y), the expected loss of a predictor h on the input distribution {\cal D},

there exists an algorithm (Algorithm 1) such that running Stochastic Gradient Descent

  • with a batch size of b,

  • on 2\mathsf{NN}(2m,b_{0}=\sqrt{0.4\log(2m)}) (Definition 1.2),

returns a function h that satisfies

    \operatorname*{{\mathbb{E}}}[{\cal L}_{\cal D}(h)]\leq{\cal L}_{\cal D}({\cal P}^{M}_{c})+\epsilon

with expected per-iteration running time

    \widetilde{O}(m^{1-\Theta(1/d)}bd).

Roadmap.

We first present a technical overview of our paper in Section 2. We then introduce notation and preliminaries in Section 3 and discuss related work in Section 4. We present our main algorithm and prove its correctness in Section 5. We analyze the running time of our algorithm in Section 6. We conclude in Section 7.

2 Technical Overview

Our work consists of two results. The first result proves the learnability of shifted ReLU activation on a two-layer neural network with SGD updates. The second result is an efficient half-space reporting data structure that exploits the weight sparsity induced by the shifted ReLU activation, which gives a per-iteration time sublinear in m.

We prove the first result via reduction-based techniques. In order to prove the learnability of shifted ReLU activation on a two-layer neural network (Theorem 5.1), we prove a more general statement, Theorem 5.2: for a general activation function \sigma, the expected loss of the two-layer neural network is optimal up to an additive \epsilon error with m=\widetilde{O}(M^{2}/\epsilon^{2}) neurons and T=O(M^{2}/\epsilon^{2}) updates. Next, we show that training the two-layer network with a general activation is equivalent to neural tangent kernel training (Lemma 5.3) when the weight vectors of the network are initialized with a large enough B (Definition 1.3). Specifically, when B is large enough, SGD with a general activation function on a two-layer neural network follows a lazy update, in which each weight vector only moves within a small ball. In this case, the first-order approximation of the network function at the initial weights is approximately equivalent to the original network function. In this light, neural tangent kernel learning suffices to approximate the general two-layer neural network training with SGD updates. Then, we analyze neural tangent kernel training in the language of (vector) random feature schemes, which completes the proof of the first result.

The second result is based on the observation that, for a shifted ReLU activation on a two-layer neural network with SGD updates, the number of fired neurons for each data point is sublinear in the network size 2m. In this light, we adapt a half-space reporting data structure hsr to speed up the per-iteration running time by preprocessing the network weights. More precisely, the algorithm initializes the half-space reporting data structure hsr with the initial weight vectors; then, at every batched SGD iteration, the algorithm queries hsr for the weight vectors that fire for the points in the batch and updates the fired neuron sets in \widetilde{O}(bm^{1-\Theta(1/d)}d) time.

3 Preliminaries

Notations

We use \sigma to denote the descent activation function and {\cal D} to denote the input distribution. We denote the i-th row of a matrix W by w_{i}. We use \|x\|_{p}=(\sum_{i=1}^{d}|x_{i}|^{p})^{\frac{1}{p}} to denote the \ell_{p}-norm of x\in\mathbb{R}^{d}. Given a matrix W, we use \|W\| to denote its spectral norm, defined as \|W\|=\max\{\|Wx\|:\|x\|=1\}. We adhere to the standard convention \|x\|=\|x\|_{2}. For a distribution {\cal D} on a space {\cal X}, p\geq 1, and f:{\cal X}\rightarrow\mathbb{R}, we write \|f\|_{p,{\cal D}}=(\operatorname*{{\mathbb{E}}}_{x\sim{\cal D}}|f(x)|^{p})^{\frac{1}{p}}. For any function f, we use \widetilde{O}(f) to denote O(f\cdot\operatorname{poly}\log(f)). For an integer n, we use [n] to denote the set \{1,\ldots,n\}. For a function \sigma, we use \sigma^{\prime} and \sigma^{\prime\prime} to denote its first- and second-order derivatives, respectively.

3.1 Definitions

In this section, we present some definitions of properties of a function l. To begin with, we present the definition of a convex function.

Definition 3.1 (Convexity).

We say a function l is convex if for any x_{1},x_{2}, we have:

l(x_{1})\geq l(x_{2})+l^{\prime}(x_{2})(x_{1}-x_{2})

Next, we present the definition of Lipschitzness of a function.

Definition 3.2 (L-Lipschitz).

We say a function l is L-Lipschitz with respect to a norm \|\cdot\| if for any x_{1},x_{2}, we have:

|l(x_{1})-l(x_{2})|\leq L\|x_{1}-x_{2}\|

3.2 Neural Network Training

We now introduce some basic definitions and notation regarding neural network training for supervised learning:

Definition 3.3 (Supervised Learning).

The objective of supervised learning is to learn a mapping from an input space {\cal X} to an output space {\cal Y}, using a sample set S=\{(x_{i},y_{i})\}_{i\in[n]}. These samples are drawn i.i.d. from a distribution {\cal D} over {\cal X}\times{\cal Y}.

A special case of supervised learning is binary classification, where the prediction is a binary label.

Definition 3.4 (Binary Classification).

The binary classification problem is characterized by the label set {\cal Y}=\{\pm 1\}. Specifically, given a loss function l:\mathbb{R}\times{\cal Y}\rightarrow[0,\infty), the aim is to identify a predictor h:{\cal X}\rightarrow\mathbb{R} whose loss {\cal L}_{\cal D}(h):=\operatorname*{{\mathbb{E}}}_{(x,y)\sim{\cal D}}l(h(x),y) is small.

Moreover, when a function h is defined by a parameter vector w, we write {\cal L}_{\cal D}(w):={\cal L}_{\cal D}(h) and l_{(x,y)}(w):=l(h(x),y). For a class {\cal H} of predictors from {\cal X} to \mathbb{R}, we denote

{\cal L}_{\cal D}({\cal H}):=\inf_{h\in{\cal H}}{\cal L}_{\cal D}(h).

For classification problems, the properties of their loss function are defined as follows:

Definition 3.5 (Properties of Loss Function).

A loss function l is L-Lipschitz if

  • for all y\in{\cal Y}, the function l_{y}(\widehat{y}):=l(\widehat{y},y) is L-Lipschitz (Definition 3.2).

Similarly, l is convex if l_{y} is convex (Definition 3.1) for each y\in{\cal Y}. The loss is L-descent if, for every y\in{\cal Y}, l_{y} is convex, L-Lipschitz, and twice differentiable except at a finite number of points.

Note that the shifted ReLU is a descent activation function. Furthermore, we present the definition of the empirical loss on m samples:

Definition 3.6 (Empirical Loss).

The empirical loss for a set of m points is:

{\cal L}_{S}(h):=\frac{1}{m}\sum_{i=1}^{m}l(h(x_{i}),y_{i})

Furthermore, when the function h is defined by a vector of parameters w, we write {\cal L}_{S}(w):={\cal L}_{S}(h). For a class {\cal H} of predictors {\cal X}\rightarrow\mathbb{R}, we denote {\cal L}_{S}({\cal H})=\inf_{h\in{\cal H}}{\cal L}_{S}(h).

In the remainder of the paper, we write NeuralNetworkTraining(\sigma,d,m,l,\eta,b,T,B) for neural network training with activation \sigma, input dimension d, weight dimension m, loss l, learning rate \eta, SGD batch size b, initialization parameter B, and number of iterations T. This optimization process initializes the weights as W\sim{\cal I}(d,m,B) (Definition 1.3).

3.3 Kernel Spaces

In this section, we provide some definitions regarding kernel and kernel spaces.

Definition 3.7 (Kernel).

Let {\cal X} be a given set. A kernel is a function \mathsf{K}:{\cal X}\times{\cal X}\rightarrow\mathbb{R} such that, for all x_{1},\ldots,x_{n}\in{\cal X}, the matrix \{\mathsf{K}(x_{i},x_{j})\}_{i,j} is positive semi-definite. A kernel space is a Hilbert space {\cal H} of functions on {\cal X} in which, for every x\in{\cal X}, the evaluation map f\mapsto f(x) is bounded.

The following theorem states a one-to-one correspondence between kernels and kernel spaces.

Theorem 3.8 (Kernel versus Kernel Spaces).

For each kernel \mathsf{K}, there exists a unique kernel space {\cal H}_{\mathsf{K}} such that for all x_{1},x_{2}\in{\cal X}, \mathsf{K}(x_{1},x_{2})=\langle\mathsf{K}(\cdot,x_{1}),\mathsf{K}(\cdot,x_{2})\rangle_{{\cal H}_{\mathsf{K}}}. Conversely, for every kernel space {\cal H}, there is a kernel \mathsf{K} such that {\cal H}={\cal H}_{\mathsf{K}}.

Within {\cal H}_{\mathsf{K}}, the norm and inner product are denoted by \|\cdot\|_{\mathsf{K}} and \langle\cdot,\cdot\rangle_{\mathsf{K}}, respectively. The following theorem elucidates the close relationship between kernels and embeddings of {\cal X} into Hilbert spaces.

Theorem 3.9 (Kernel versus Embedding).

A function \mathsf{K}:{\cal X}\times{\cal X}\rightarrow\mathbb{R} is a kernel if and only if there exists a mapping \Psi:{\cal X}\rightarrow{\cal H} to some Hilbert space such that

\mathsf{K}(x_{1},x_{2})=\langle\Psi(x_{1}),\Psi(x_{2})\rangle_{\cal H}.

In this situation, we have:

{\cal H}_{\mathsf{K}}=\{f_{\Psi,v}\;|\;v\in{\cal H}\}\text{, with }f_{\Psi,v}(x)=\langle v,\Psi(x)\rangle_{\cal H}.

Furthermore, we write \|f\|_{\mathsf{K}}:=\min\{\|v\|_{\cal H}:f=f_{\Psi,v}\}, and the minimizer is unique.

We will leverage a particular family of kernels known as inner product kernels. These are kernels \mathsf{K}:\mathbb{S}^{d-1}\times\mathbb{S}^{d-1}\rightarrow\mathbb{R} given by \mathsf{K}(x,y)=\sum_{n=0}^{\infty}b_{n}\langle x,y\rangle^{n}, where b_{n}>0 are scalars satisfying \sum_{n=0}^{\infty}b_{n}<\infty. It is well known that any such series defines a kernel. The following lemma outlines a few properties of inner product kernels.

Lemma 3.10 (Characteristics of Inner Product Kernel, [21]).

Let \mathsf{K} be the inner product kernel \mathsf{K}(x,y)=\sum_{n=1}^{\infty}b_{n}\langle x,y\rangle^{n}. Assuming that b_{n}>0,

  • if p(x)=\sum_{|\alpha|=n}a_{\alpha}x^{\alpha}, then p\in{\cal H}_{\mathsf{K}} and \|p\|_{\mathsf{K}}^{2}\leq\frac{1}{b_{n}}\sum_{|\alpha|=n}a_{\alpha}^{2};

  • for every u\in\mathbb{S}^{d-1}, the function f(x)=\langle u,x\rangle^{n} lies in {\cal H}_{\mathsf{K}} and \|f\|_{\mathsf{K}}^{2}=\frac{1}{b_{n}}.

For a kernel \mathsf{K} and M>0, we write {\cal H}_{\mathsf{K}}^{M}:=\{h\in{\cal H}_{\mathsf{K}}:\|h\|_{\mathsf{K}}\leq M\}. The class {\cal H}_{\mathsf{K}}^{M} is a natural benchmark for learning algorithms. Next, we present the definition of Hermite polynomials and the dual activation.

Definition 3.11 (Hermite Polynomials and the Dual Activation).

The Hermite polynomials h_{0},h_{1},h_{2},\ldots are the sequence of orthonormal polynomials associated with the standard Gaussian measure on \mathbb{R}. Given an activation \sigma:\mathbb{R}\rightarrow\mathbb{R}, we define the dual activation of \sigma as follows:

  • \widehat{\sigma}(\rho):=\operatorname*{{\mathbb{E}}}_{X,Y\sim{\cal D}_{\rho}}[\sigma(X)\sigma(Y)], where {\cal D}_{\rho} denotes a pair of \rho-correlated standard Gaussians.

  • Furthermore, if \sigma=\sum_{n=0}^{\infty}a_{n}h_{n}, then \widehat{\sigma}(\rho)=\sum_{n=0}^{\infty}a_{n}^{2}\rho^{n}.

In particular, \mathsf{K}_{\sigma}(x,y):=\widehat{\sigma}(\langle x,y\rangle) is an inner product kernel.
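As an illustration of Definition 3.11 (our own sketch), the dual activation can be estimated by Monte Carlo over \rho-correlated Gaussian pairs; for the identity activation \sigma(t)=t, which equals the Hermite polynomial h_{1}, the power series gives \widehat{\sigma}(\rho)=\rho, and the estimate below matches it.

import numpy as np

def dual_activation_mc(sigma, rho, num_samples=1_000_000, seed=0):
    # sigma_hat(rho) = E[sigma(X) sigma(Y)] for rho-correlated standard Gaussians (X, Y).
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(num_samples)
    z = rng.standard_normal(num_samples)
    y = rho * x + np.sqrt(1.0 - rho**2) * z     # Corr(X, Y) = rho
    return float(np.mean(sigma(x) * sigma(y)))

for rho in (0.0, 0.3, 0.9):
    print(rho, dual_activation_mc(lambda t: t, rho))   # close to rho in each case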

4 Related Work

Sketching

Sketching is a well-known technique to improve performance or memory complexity [18]. It has wide applications in linear algebra, such as linear regression and low-rank approximation [18, 60, 58, 67, 75, 39, 6, 76, 77, 23], training over-parameterized neural networks [82, 83, 91], empirical risk minimization [56, 65], linear programming [56, 47, 80], distributed problems [85, 15], clustering [31], generative adversarial networks [88], kernel density estimation [63], tensor decomposition [78], trace estimation [46], projected gradient descent [40, 87], matrix sensing [27, 64], John Ellipsoid computation [16, 81], semi-definite programming [35], kernel methods [1, 4, 20, 74], adversarial training [34], the cutting plane method [45], discrepancy [92], federated learning [66], reinforcement learning [5, 86, 72], and relational databases [62].

Over-parameterization in Training Neural Networks.

The investigation of the geometry and convergence behavior of various optimization methods on over-parameterized neural networks has become a significant focus in deep learning theory [52, 44, 30, 10, 12, 25, 79, 89, 61, 53, 13, 82, 38, 83, 42, 55, 7, 59, 92]. The ground-breaking work of [44] introduced the neural tangent kernel (NTK), a critical analytical tool in deep learning theory. When the network width is sufficiently large (m\geq\Omega(n^{2})), the training dynamics of a neural network can be shown to closely mirror those of the NTK.

5 Proof of Correctness

First, we present our algorithm (Algorithm 1) for shifted ReLU activation on a two-layer neural network with SGD updates.

Algorithm 1 Neural Network Training Via Building a Data Structure of Weights
1:procedure NeuralNetworkTrainingViaPreprocessingWeights(d,m,B,\eta,b,T)
2:     Network parameters d and m
3:     Initialization parameter B>0
4:     Learning rate \eta>0
5:     Batch size b
6:     Number of steps T>0
7:     Access to samples from a distribution {\cal D}
8:     Sample W\sim{\cal I}(d,m,B) \triangleright Definition 1.3
9:     b_{0}\leftarrow\sqrt{0.4\log 2m}
10:     HalfSpaceReport hsr \triangleright Algorithm 2
11:     hsr.Init(\{w_{r}(0)\}_{r\in[2m]},2m,d) \triangleright This step takes {\cal T}_{\mathsf{init}}(2m,d) time.
12:
13:     for t=1\to T do
14:         Obtain a mini-batch S_{t}=\{(x_{i}^{t},y_{i}^{t})\}_{i=1}^{b}\sim{\cal D}^{b}
15:         for i=1\to b do
16:              S_{i,\mathrm{fire}}\leftarrow hsr.Query(x_{i}^{t},b_{0}) \triangleright This step takes {\cal T}_{\mathsf{query}}(2m,d,k_{i,t}) time.
17:              u(t)_{i}\leftarrow\frac{1}{\sqrt{2m}}\sum_{r\in S_{i,\mathrm{fire}}}a_{r}\cdot\sigma_{b_{0}}(w_{r}(t)^{\top}x_{i}^{t}) \triangleright This step takes O(d\cdot k_{i,t}) time.
18:         end for
19:         P\leftarrow 0^{b\times 2m}
20:         for i=1\to b do
21:              for r\in S_{i,\mathrm{fire}} do
22:                  P_{i,r}\leftarrow\frac{1}{\sqrt{2m}}a_{r}\cdot\sigma_{b_{0}}^{\prime}(w_{r}(t)^{\top}x_{i}^{t})
23:              end for
24:         end for
25:         M\leftarrow X\operatorname{diag}(y-u(t)) \triangleright M\in\mathbb{R}^{d\times b}, takes O(bd) time.
26:         \Delta W\leftarrow MP \triangleright This step takes O(d\cdot\operatorname{nnz}(P)) time, where \operatorname{nnz}(P)=O(bm^{4/5}).
27:         W(t+1)\leftarrow W(t)-\eta\cdot\Delta W \triangleright Backward computation.
28:         Let Q\subset[2m] be the set of indices r such that \Delta W_{*,r} is not all zeros \triangleright |Q|\leq O(bm^{4/5})
29:         for r\in Q do
30:              hsr.Delete(w_{r}(t))
31:              hsr.Insert(w_{r}(t+1)) \triangleright Update the network weight.
32:         end for
33:     end for
34:     Choose t\in[T] uniformly at random and return W(t).
35:end procedure

Next, we give the proof of correctness, showing that Algorithm 1 is capable of learning even polynomials of bounded norm {\cal P}_{c}^{M} (Definition 1.4) with nearly optimal sample complexity and network size.

Theorem 5.1 (Neural Network Learning with Shifted ReLU Activation).

Given the following conditions:

  • a fixed constant c>0 and a threshold b_{0},

  • d, M>0, R>0, \epsilon>0,

  • the network function 2\mathsf{NN}(2m,b_{0}) with initialization as in Definition 1.3,

there is a choice of

m=\widetilde{O}(d^{-1}\epsilon^{-2}M^{2}R^{2}),\quad T=O(\epsilon^{-2}M^{2}),

and positive values of B and \eta, such that for every R-bounded distribution {\cal D} (Definition 1.1) and batch size b, the function h obtained by Algorithm 1 satisfies

\operatorname*{{\mathbb{E}}}[{\cal L}_{\cal D}(h)]\leq{\cal L}_{\cal D}({\cal P}^{M}_{c})+\epsilon.

From [22], for the shifted ReLU activation \sigma_{b_{0}}, it holds that for every constant c, {\cal P}_{c}^{M}\subset{\cal H}_{\mathrm{tk}^{h}_{\sigma}}^{O(M)}. As a result, the following theorem implies the theorem above.

Theorem 5.2 (Neural Network Learning).

Given d, M>0, R>0 and \epsilon>0, there exists a choice of

m=\widetilde{O}\Big(\frac{M^{2}R^{2}}{d\epsilon^{2}}\Big),\quad T=O\Big(\frac{M^{2}}{\epsilon^{2}}\Big),

along with B>0 and \eta>0, such that for any batch size b and R-bounded distribution {\cal D} (Definition 1.1), the function h obtained by NeuralNetworkTraining(\sigma,d,m,l,\eta,b,T,B) satisfies:

\operatorname*{{\mathbb{E}}}[{\cal L}_{\cal D}(h)]\leq{\cal L}_{\cal D}({\cal H}^{M}_{\mathrm{tk}^{h}_{\sigma}})+\epsilon.

We prove this theorem by a reduction to the neural tangent kernel space at the initialized weights. First, we use \psi_{W}(x) to denote the gradient of the function W\mapsto h_{W}(x) with respect to the hidden weights, i.e.,

\psi_{W}(x):=(u_{1}\sigma^{\prime}(\langle w_{1},x\rangle)x,\ldots,u_{2m}\sigma^{\prime}(\langle w_{2m},x\rangle)x)\in\mathbb{R}^{2m\times d},

where \sigma^{\prime} denotes the first-order derivative of the activation \sigma. Moreover, we write f_{\psi_{W},V}(x):=\langle V,\psi_{W}(x)\rangle.

Next, we show that NeuralNetworkTraining(\sigma,d,m,l,\eta,b,T,B) is equivalent to NeuralTangentKernelTraining(\sigma,d,m,l,\eta,b,T) when the initialization scale of the weights on the neurons is large enough. We defer the proof of the following lemma to Section C.1.

Lemma 5.3 (Equivalence of NNT and NTKT).

Assume the following condition holds:

  • fix a descent activation \sigma as well as a convex descent loss l (Definitions 3.1 and 3.5).

Then there is a choice B=\operatorname{poly}(d,m,1/\eta,T,1/\epsilon) such that for every input distribution the following holds. Let h_{1} and h_{2} be the functions returned by NeuralNetworkTraining(\sigma,d,m,l,\eta/B^{2},b,T,B) and NeuralTangentKernelTraining(\sigma,d,m,l,\eta,b,T), respectively.

Then, we have

  • |\operatorname*{{\mathbb{E}}}[{\cal L}_{\cal D}(h_{1})]-\operatorname*{{\mathbb{E}}}[{\cal L}_{\cal D}(h_{2})]|<\epsilon.

By the above lemma, it suffices to analyze NeuralTangentKernelTraining in order to prove Theorem 5.2. Specifically, Theorem 5.2 follows from the following theorem:

Theorem 5.4.

Given d, M>0, R>0 and \epsilon>0, there is a choice of m=\widetilde{O}(d^{-1}\epsilon^{-2}M^{2}R^{2}), T=O(\epsilon^{-2}M^{2}), and \eta>0 such that for every R-bounded distribution {\cal D} (Definition 1.1) and batch size b, the function h obtained by NeuralTangentKernelTraining(\sigma,d,m,l,\eta,b,T) satisfies

\operatorname*{{\mathbb{E}}}[{\cal L}_{\cal D}(h)]\leq{\cal L}_{\cal D}({\cal H}^{M}_{\mathrm{tk}^{h}_{\sigma}})+\epsilon.

In order to prove the above theorem, we prove an equivalent theorem (Theorem 5.6) in the next section, where we rephrase everything in the language of the vector random feature scheme.

5.1 Vector Random Feature Schemes

We note that NeuralTangentKernelTraining(\sigma,d,m,l,\eta,b,T) is SGD on top of a random embedding consisting of m i.i.d. random mappings:

\psi_{W}(x)=(\sigma^{\prime}(\langle W,x\rangle)x,-\sigma^{\prime}(\langle W,x\rangle)x),

where W\in\mathbb{R}^{d} follows a standard Gaussian distribution. For simplicity, we replace the training process by SGD on the i.i.d. random mappings \psi_{W}(x)=\sigma^{\prime}(\langle W,x\rangle)x. This mapping preserves the inner products between examples up to a multiplicative factor, and since the SGD update depends solely on these inner products, it suffices to analyze the learning process within the corresponding random feature scheme framework. To begin with, we use the random mapping \Psi_{W}(x) to denote the random m-embedding generated from \psi:

\Psi_{W}(x):=\frac{1}{\sqrt{m}}\cdot(\psi(w_{1},x),\ldots,\psi(w_{m},x))

where w_{1},\ldots,w_{m} are i.i.d. Next, we consider SGDRFS(\psi,m,l,\eta,b,T) for learning the class {\cal H}_{\mathsf{K}} by running SGD on these random features. For the remainder of this section, we fix a C-bounded Random Feature Scheme (RFS) \psi for a kernel \mathsf{K} and a random m-embedding \psi_{w}, and abbreviate the RFS as \psi. The Neural Tangent Kernel (NTK) RFS is the mapping \psi:\mathbb{R}^{d}\times\mathbb{S}^{d-1}\rightarrow\mathbb{R}^{d} defined by

\psi(w,x):=\sigma^{\prime}(\langle w,x\rangle)x.
Definition 5.5 (m-Kernel and m-Kernel Space).

For a Random Feature Scheme (RFS) \psi of a kernel \mathsf{K}, and a randomly chosen m-embedding \psi_{w}, the random m-kernel with respect to \psi_{w} is

\mathsf{K}_{w}(x_{1},x_{2})=\langle\psi_{w}(x_{1}),\psi_{w}(x_{2})\rangle

Similarly, the random m-kernel space corresponding to \psi_{w} is {\cal H}_{\mathsf{K}_{w}}. For every x_{1},x_{2}\in{\cal X}, we can write

\mathsf{K}_{w}(x_{1},x_{2})=\frac{1}{m}\sum_{i=1}^{m}\langle\psi(w_{i},x_{1}),\psi(w_{i},x_{2})\rangle

as the average of m independent random variables whose expectation is \mathsf{K}(x_{1},x_{2}).
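The following sketch (ours, with illustrative names) instantiates the NTK random feature map \psi(w,x)=\sigma^{\prime}(\langle w,x\rangle)x for the shifted ReLU, whose derivative is the indicator \mathbf{1}\{\langle w,x\rangle>b_{0}\}, and shows the random m-kernel \mathsf{K}_{w}(x_{1},x_{2}) stabilizing around its expectation as m grows.

import numpy as np

def ntk_features(W, x, b0):
    # Psi_W(x): row i is sigma'_{b0}(<w_i, x>) * x, scaled by 1/sqrt(m).
    gate = (W @ x > b0).astype(float)            # sigma'_{b0}(z) = 1{z > b0}
    return gate[:, None] * x[None, :] / np.sqrt(W.shape[0])

def m_kernel(W, x1, x2, b0):
    # K_w(x1, x2) = <Psi_W(x1), Psi_W(x2)> = (1/m) sum_i <psi(w_i, x1), psi(w_i, x2)>
    return float(np.sum(ntk_features(W, x1, b0) * ntk_features(W, x2, b0)))

rng = np.random.default_rng(0)
d, b0 = 20, 0.5
x1, x2 = rng.standard_normal(d), rng.standard_normal(d)
x1, x2 = x1 / np.linalg.norm(x1), x2 / np.linalg.norm(x2)
for m in (100, 10_000, 100_000):
    W = rng.standard_normal((m, d))              # w_i ~ N(0, I_d), i.i.d.
    print(m, m_kernel(W, x1, x2, b0))            # concentrates as m increases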

With all the definitions in place, we are now ready to state the correctness guarantee for training the random feature scheme with SGD updates. We defer the proof of Theorem 5.6 to Section C.2.

Theorem 5.6.

Assume that

  • \psi is a factorized (Definition B.3), C-bounded Random Feature Scheme (RFS) (Definition B.1) for \mathsf{K},

  • l is convex (Definition 3.1) and L-Lipschitz (Definition 3.2),

  • {\cal D} has an R-bounded marginal (Definition 1.1).

Let f be the function returned by SGDRFS(\psi,m,l,\eta,b,T), and fix a function f^{*}\in{\cal H}_{\mathsf{K}}. Then:

\operatorname*{{\mathbb{E}}}[{\cal L}_{\cal D}(f)]\leq{\cal L}_{\cal D}(f^{*})+\frac{LRC\|f^{*}\|_{\mathsf{K}}}{\sqrt{md}}+\frac{\|f^{*}\|_{\mathsf{K}}^{2}}{2\eta T}+\frac{\eta L^{2}C^{2}}{2}

In particular, if \|f^{*}\|_{\mathsf{K}}\leq M and \eta=\frac{M}{\sqrt{T}LC}, we have:

\operatorname*{{\mathbb{E}}}[{\cal L}_{\cal D}(f)]\leq{\cal L}_{\cal D}(f^{*})+\frac{LRCM}{\sqrt{md}}+\frac{LCM}{\sqrt{T}}

6 Proof of Running Time

In this section, we first present some definitions and properties regarding the activated neurons at each iteration. Then, we present the problem definition, data structure, and runtime guarantees for half-space reporting. Finally, we present the runtime analysis showing that our algorithm has a cost per iteration sublinear in the number of neurons 2m.

6.1 Sparsity Characterization

In this section, we show that with high probability, at every time t\in[T], the number of neurons activated by the shifted ReLU for each data point is sublinear in m. To begin with, we present a definition of the set of neurons that fire at time t.

Definition 6.1 (Fire Set).

For each index i\in[n] and each timestep t\in[T], we define S_{i,\mathrm{fire}}(t)\subset[m] as the set of neurons that "fire" at time t. Formally,

S_{i,\mathrm{fire}}(t):=\{r\in[m]:\langle w_{r}(t),x_{i}\rangle>b_{0}\}

We also define k_{i,t} to be the size of this set, i.e., k_{i,t}:=|S_{i,\mathrm{fire}}(t)| for all t\in[T].
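Concretely, the fire set can be computed by the brute-force linear scan below (our own sketch; this O(md) scan is exactly the computation that the half-space reporting data structure of Section 6.2 replaces with a sublinear-time query). With the threshold of Remark 6.3, the resulting k_{i,t} is well below m^{4/5}.

import numpy as np

def fire_set(W_t, x_i, b0):
    # S_{i,fire}(t) = { r : <w_r(t), x_i> > b0 }, with W_t of shape (m, d)
    return np.flatnonzero(W_t @ x_i > b0)

rng = np.random.default_rng(0)
m, d = 100_000, 30
b0 = np.sqrt(0.4 * np.log(m))
W_t = rng.standard_normal((m, d))                # rows w_r(t) ~ N(0, I_d)
x_i = rng.standard_normal(d)
x_i /= np.linalg.norm(x_i)
k_it = fire_set(W_t, x_i, b0).size
print(k_it, int(m ** 0.8))                       # k_{i,t} stays below m^{4/5} = m * exp(-b0^2/2)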

We subsequently introduce a "sparsity lemma" which demonstrates that the activation function \sigma_{b_{0}} results in the required sparsity.

Lemma 6.2 (Sparsity After Initialization, [82]).

Given parameters b_{0}, m, and the network structure \mathsf{2NN}(2m,b_{0}) (Definition 1.2), after weight initialization, with probability at least

1-n\cdot\exp(-\Omega(m\cdot\exp(-b_{0}^{2}/2))),

the number of fired neurons for every input x_{i} is upper bounded by k_{i,0}=O(m\cdot\exp(-b_{0}^{2}/2)).

Next, we present a choice of the threshold b_{0} that gives a fired neuron set sublinear in the network size 2m:

Remark 6.3.

Let b_{0}=\sqrt{0.4\log m}; then k_{0}=m^{4/5}. For t=m^{4/5}, Lemma 6.2 implies that:

\Pr[|{\cal S}_{i,\mathrm{fire}}(0)|>2m^{4/5}]\leq\exp(-\min\{mR,O(m^{4/5})\}).
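For completeness, the arithmetic behind Remark 6.3 is the following one-line calculation, which plugs the threshold into the bound of Lemma 6.2:

b_{0}=\sqrt{0.4\log m}
\;\Longrightarrow\;
m\cdot\exp(-b_{0}^{2}/2)=m\cdot\exp(-0.2\log m)=m\cdot m^{-0.2}=m^{4/5},

so the bound O(m\cdot\exp(-b_{0}^{2}/2)) on the number of fired neurons becomes O(m^{4/5}), which is sublinear in the network size.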

In the forthcoming discussion, our aim is to establish that at any time t\in[T], the number of activated neurons is sublinear in m. We first define the set of neurons that flip at time t:

Definition 6.4 (Flip Set).

For each index i\in[n] and each time t\in[T], we define {\cal S}_{i,\mathrm{flip}}(t)\subset[m] as the set of neurons that switch their state at time t. More formally,

{\cal S}_{i,\mathrm{flip}}(t):=\{r\in[m]:\mathrm{sgn}(\langle w_{r}(t),x_{i}\rangle-b_{0})\neq\mathrm{sgn}(\langle w_{r}(t-1),x_{i}\rangle-b_{0})\}

In contrast, there exist neurons that remain in the same state throughout the entire training process:

Definition 6.5 (Nonflip Set).

For each index i\in[n], we define S_{i}\subset[m] as the set of neurons that never switch their state during the whole training process. Specifically,

S_{i}:=\{r\in[m]:\forall t\in[T],~\mathrm{sgn}(\langle w_{r}(t),x_{i}\rangle-b_{0})=\mathrm{sgn}(\langle w_{r}(0),x_{i}\rangle-b_{0})\}.

Then, we introduce a lemma showing that the number of fired neurons k_{i,t} is small for all i\in[n], t\in[T] with high probability.

Lemma 6.6 (Upper Bound on Fired Neurons per Iteration, [82]).

Let

  • b_{0}\geq 0 be a parameter,

  • \sigma_{b_{0}}(x)=\max\{0,x-b_{0}\} be the activation function.

For each i\in[n] and t\in[T], let k_{i,t} be the number of activated neurons at the t-th iteration. For 0<t<T, with probability at least

1-n\cdot\exp(-\Omega(m)\cdot\min\{R,\exp(-b_{0}^{2}/2)\}),

k_{i,t} is at most O(m\exp(-b_{0}^{2}/2)) for all i\in[n].

6.2 Data Structure for Half-Space Reporting

In this section, we introduce the problem formulation, the data structure, and the time complexity guarantees for half-space reporting. The goal of half-space reporting is to maintain a data structure for a set S such that, for any query half-space H, the data structure can quickly report the points of S that fall within H:

Definition 6.7 (Half-space Range Reporting).

Given a set S of n points in \mathbb{R}^{d}, a half-space range reporting data structure supports two fundamental operations:

  • Query(H): Given a half-space H\subset\mathbb{R}^{d}, return all points of S contained in H, i.e., output S\cap H.

  • Update: Modify the set S by inserting a new point or removing an existing one:

    • Insert(q): Insert a point q into S.

    • Delete(q): Delete the point q from S.

Moreover, we denote by {\cal T}_{\mathsf{init}}, {\cal T}_{\mathsf{query}}, and {\cal T}_{\mathsf{update}} the preprocessing time, per-query time, and per-update time of the data structure, respectively.

Next, we present the formal data structure for half-space reporting.

Algorithm 2 Half Space Report Data Structure
1:data structure: HalfSpaceReport
2:    procedures:
3:       Init(S,n,d) \triangleright Initialize the data structure with a set S of n points in \mathbb{R}^{d}
4:       Query(a,b) \triangleright a\in\mathbb{R}^{d}, b\in\mathbb{R}. Output the set \{x\in S:\mathrm{sgn}(\langle a,x\rangle-b)\geq 0\}
5:       Add(x) \triangleright Add point x\in\mathbb{R}^{d} to S
6:       Delete(x) \triangleright Delete point x\in\mathbb{R}^{d} from S
7:end data structure
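For intuition, here is a naive brute-force Python realization of the HalfSpaceReport interface above (our own sketch). It answers Query(a,b) by an O(nd) linear scan and returns the indices of the stored points, matching how Algorithm 1 consumes the result; the actual sublinear-query structure from [3] assumed by our analysis is considerably more involved.

import numpy as np

class BruteForceHalfSpaceReport:
    # Naive stand-in for HalfSpaceReport (Algorithm 2): all operations are linear scans.

    def init(self, S, n, d):
        # Store the n points of S in R^d, keyed by their index.
        self.points = {i: np.asarray(x, dtype=float) for i, x in enumerate(S)}
        self.next_id = n

    def query(self, a, b):
        # Return indices i with sgn(<a, x_i> - b) >= 0, i.e. <a, x_i> >= b.
        return [i for i, x in self.points.items() if float(a @ x) >= b]

    def add(self, x):
        self.points[self.next_id] = np.asarray(x, dtype=float)
        self.next_id += 1

    def delete(self, x):
        # Remove one stored point equal to x, if present.
        x = np.asarray(x, dtype=float)
        for i, p in list(self.points.items()):
            if np.array_equal(p, x):
                del self.points[i]
                return

# Usage: report which initial weight rows fire for a unit input with threshold b = 0.5.
rng = np.random.default_rng(0)
W0 = rng.standard_normal((8, 3))
hsr = BruteForceHalfSpaceReport()
hsr.init(list(W0), n=8, d=3)
query_dir = rng.standard_normal(3)
query_dir /= np.linalg.norm(query_dir)
print(hsr.query(query_dir, b=0.5))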

From [3], this problem can be solved with sublinear query time:

Corollary 6.8 ([3]).

Given a set of n points in \mathbb{R}^{d}, the half-space reporting problem can be solved with:

  • query time {\cal T}_{\mathsf{query}}(n,d,k)=O_{d}(n^{1-1/\lfloor d/2\rfloor}+k), where k=|S\cap H| is the number of points in the intersection of the set S and the half-space H;

  • amortized update time {\cal T}_{\mathsf{update}}=O_{d}(\log^{2}(n)).

6.3 Cost per iteration

In this section, we analyze the time complexity per iteration of Algorithm 1.

Lemma 6.9 (Running time of Algorithm 1).

Given the following:

  • sample access to a distribution {\cal D} over \mathbb{R}^{d},

  • running the stochastic gradient descent algorithm (Algorithm 1) on 2\mathsf{NN}(2m,b_{0}=\sqrt{0.4\log(2m)}) (Definition 1.2) with batch size b,

the expected cost per iteration of this algorithm is

\widetilde{O}(m^{1-\Theta(1/d)}bd)

We defer the proof of Lemma 6.9 to Section C.3.

7 Conclusion

Deep learning is widely employed in many domains, but its model training procedure often consumes a large amount of computational resources and time. In this paper, we design an efficient neural network training method with SGD updates that has a provable convergence guarantee. By leveraging a static half-space report data structure in the optimization of a fully connected two-layer neural network with neural tangent kernel and shifted ReLU activation, our algorithm supports sublinear-time activated neuron identification via geometric search. In addition, we prove that our algorithm converges in O(M^{2}/\epsilon^{2}) time with network size quadratic in the coefficient norm upper bound M and the error term \epsilon. As far as we are aware, our work does not have negative societal impacts. One limitation of our work is that it only considers the shifted ReLU; studying other activation functions is left for future work.

References

  • ACW [17] Haim Avron, Kenneth L Clarkson, and David P Woodruff. Faster kernel ridge regression using sketching and preconditioning. SIAM Journal on Matrix Analysis and Applications, 38(4):1116–1138, 2017.
  • ADH+ [19] Sanjeev Arora, Simon Du, Wei Hu, Zhiyuan Li, and Ruosong Wang. Fine-grained analysis of optimization and generalization for overparameterized two-layer neural networks. In International Conference on Machine Learning, pages 322–332. PMLR, 2019.
  • AEM [92] Pankaj K Agarwal, David Eppstein, and Jiri Matousek. Dynamic half-space reporting, geometric optimization, and minimum spanning trees. In Annual Symposium on Foundations of Computer Science, volume 33, pages 80–80. IEEE Computer Society Press, 1992.
  • AKK+ [20] Thomas D Ahle, Michael Kapralov, Jakob BT Knudsen, Rasmus Pagh, Ameya Velingker, David P Woodruff, and Amir Zandieh. Oblivious sketching of high-degree polynomial kernels. In Proceedings of the Fourteenth Annual ACM-SIAM Symposium on Discrete Algorithms, pages 141–160. SIAM, 2020.
  • AKL [17] Jacob Andreas, Dan Klein, and Sergey Levine. Modular multitask reinforcement learning with policy sketches. In International Conference on Machine Learning, pages 166–175. PMLR, 2017.
  • ALS+ [18] Alexandr Andoni, Chengyu Lin, Ying Sheng, Peilin Zhong, and Ruiqi Zhong. Subspace embedding and linear regression with orlicz norm. In International Conference on Machine Learning (ICML), pages 224–233. PMLR, 2018.
  • ALS+ [22] Josh Alman, Jiehao Liang, Zhao Song, Ruizhe Zhang, and Danyang Zhuo. Bypass exponential time preprocessing: Fast neural network training via weight-data correlation preprocessing. arXiv preprint arXiv:2211.14227, 2022.
  • AS [23] Josh Alman and Zhao Song. Fast attention requires bounded entries. arXiv preprint arXiv:2302.13214, 2023.
  • [9] Zeyuan Allen-Zhu, Yuanzhi Li, and Zhao Song. A convergence theory for deep learning via over-parameterization. In International Conference on Machine Learning, pages 242–252. PMLR, 2019.
  • [10] Zeyuan Allen-Zhu, Yuanzhi Li, and Zhao Song. A convergence theory for deep learning via over-parameterization. In ICML, 2019.
  • [11] Zeyuan Allen-Zhu, Yuanzhi Li, and Zhao Song. On the convergence rate of training recurrent neural networks. Advances in neural information processing systems, 32, 2019.
  • [12] Zeyuan Allen-Zhu, Yuanzhi Li, and Zhao Song. On the convergence rate of training recurrent neural networks. In NeurIPS, 2019.
  • BPSW [21] Jan van den Brand, Binghui Peng, Zhao Song, and Omri Weinstein. Training (overparametrized) neural networks in near-linear time. In 12th Innovations in Theoretical Computer Science Conference, ITCS 2021, January 6-8, 2021, Virtual Conference, volume 185 of LIPIcs, pages 63:1–63:15. Schloss Dagstuhl - Leibniz-Zentrum für Informatik, 2021.
  • BSZ [23] Jan van den Brand, Zhao Song, and Tianyi Zhou. Algorithm and hardness for dynamic attention maintenance in large language models. arXiv e-prints, pages arXiv–2304, 2023.
  • BWZ [16] Christos Boutsidis, David P Woodruff, and Peilin Zhong. Optimal principal component analysis in distributed and streaming models. In Proceedings of the forty-eighth annual ACM symposium on Theory of Computing (STOC), pages 236–249, 2016.
  • CCLY [19] Michael B Cohen, Ben Cousins, Yin Tat Lee, and Xin Yang. A near-optimal algorithm for approximating the john ellipsoid. In Conference on Learning Theory, pages 849–873. PMLR, 2019.
  • CG [19] Yuan Cao and Quanquan Gu. Generalization bounds of stochastic gradient descent for wide and deep neural networks. Advances in neural information processing systems, 32, 2019.
  • CW [13] Kenneth L. Clarkson and David P. Woodruff. Low rank approximation and regression in input sparsity time. In Proceedings of the 45th Annual ACM Symposium on Theory of Computing (STOC), 2013.
  • CWB+ [11] Ronan Collobert, Jason Weston, Léon Bottou, Michael Karlen, Koray Kavukcuoglu, and Pavel Kuksa. Natural language processing (almost) from scratch. Journal of machine learning research, 12(ARTICLE):2493–2537, 2011.
  • CY [21] Yifan Chen and Yun Yang. Accumulations of projections—a unified framework for random sketches in kernel ridge regression. In International Conference on Artificial Intelligence and Statistics, pages 2953–2961. PMLR, 2021.
  • Dan [20] Amit Daniely. Neural networks learning and memorization with (almost) no over-parameterization. Advances in Neural Information Processing Systems, 33:9007–9016, 2020.
  • DFS [16] Amit Daniely, Roy Frostig, and Yoram Singer. Toward deeper understanding of neural networks: The power of initialization and a dual view on expressivity. Advances In Neural Information Processing Systems, 29, 2016.
  • DJS+ [19] Huaian Diao, Rajesh Jayaram, Zhao Song, Wen Sun, and David Woodruff. Optimal sketching for kronecker product regression and low rank approximation. Advances in neural information processing systems, 32, 2019.
  • [24] Simon Du, Jason Lee, Haochuan Li, Liwei Wang, and Xiyu Zhai. Gradient descent finds global minima of deep neural networks. In International conference on machine learning, pages 1675–1685. PMLR, 2019.
  • [25] Simon S Du, Jason D Lee, Haochuan Li, Liwei Wang, and Xiyu Zhai. Gradient descent finds global minima of deep neural networks. In International Conference on Machine Learning (ICML), 2019.
  • [26] Yichuan Deng, Zhihang Li, and Zhao Song. Attention scheme inspired softmax regression. arXiv preprint arXiv:2304.10411, 2023.
  • [27] Yichuan Deng, Zhihang Li, and Zhao Song. An improved sample complexity for rank-1 matrix sensing. arXiv preprint arXiv:2303.06895, 2023.
  • DMS [23] Yichuan Deng, Sridhar Mahadevan, and Zhao Song. Randomized and deterministic attention sparsification algorithms for over-parameterized feature dimension. arXiv preprint arXiv:2304.04397, 2023.
  • DZPS [18] Simon S Du, Xiyu Zhai, Barnabas Poczos, and Aarti Singh. Gradient descent provably optimizes over-parameterized neural networks. In International Conference on Learning Representations, 2018.
  • DZPS [19] Simon S Du, Xiyu Zhai, Barnabas Poczos, and Aarti Singh. Gradient descent provably optimizes over-parameterized neural networks. In ICLR. arXiv preprint arXiv:1810.02054, 2019.
  • EMZ [21] Hossein Esfandiari, Vahab Mirrokni, and Peilin Zhong. Almost linear time density level set estimation via dbscan. In AAAI, 2021.
  • GB [10] Xavier Glorot and Yoshua Bengio. Understanding the difficulty of training deep feedforward neural networks. In Proceedings of the thirteenth international conference on artificial intelligence and statistics, pages 249–256. JMLR Workshop and Conference Proceedings, 2010.
  • GMS [23] Yeqi Gao, Sridhar Mahadevan, and Zhao Song. An over-parameterized exponential regression. arXiv preprint arXiv:2303.16504, 2023.
  • GQSW [22] Yeqi Gao, Lianke Qin, Zhao Song, and Yitan Wang. A sublinear adversarial training algorithm. arXiv preprint arXiv:2208.05395, 2022.
  • GS [22] Yuzhou Gu and Zhao Song. A faster small treewidth sdp solver. arXiv preprint arXiv:2211.06033, 2022.
  • GSX [23] Yeqi Gao, Zhao Song, and Shenghao Xie. In-context learning for attention scheme: from single softmax regression to multiple softmax regression via a tensor trick. arXiv preprint arXiv:2307.02419, 2023.
  • GSY [23] Yeqi Gao, Zhao Song, and Junze Yin. An iterative algorithm for rescaled hyperbolic functions regression. arXiv preprint arXiv:2305.00660, 2023.
  • HLSY [21] Baihe Huang, Xiaoxiao Li, Zhao Song, and Xin Yang. Fl-ntk: A neural tangent kernel-based framework for federated learning analysis. In International Conference on Machine Learning, pages 4423–4434. PMLR, 2021.
  • HLW [17] Jarvis Haupt, Xingguo Li, and David P Woodruff. Near optimal sketching of low-rank tensor regression. In ICML, 2017.
  • HMR [18] Filip Hanzely, Konstantin Mishchenko, and Peter Richtárik. Sega: Variance reduction via gradient sketching. Advances in Neural Information Processing Systems, 31, 2018.
  • Hoe [63] Wassily Hoeffding. Probability inequalities for sums of bounded random variables. Journal of the American Statistical Association, 58(301):13–30, 1963.
  • HSWZ [22] Hang Hu, Zhao Song, Omri Weinstein, and Danyang Zhuo. Training overparametrized neural networks in sublinear time. arXiv preprint arXiv:2208.04508, 2022.
  • HZRS [16] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition, pages 770–778, 2016.
  • JGH [18] Arthur Jacot, Franck Gabriel, and Clément Hongler. Neural tangent kernel: Convergence and generalization in neural networks. Advances in neural information processing systems, 31, 2018.
  • JLSW [20] Haotian Jiang, Yin Tat Lee, Zhao Song, and Sam Chiu-wai Wong. An improved cutting plane method for convex optimization, convex-concave games and its applications. In STOC, 2020.
  • JPWZ [21] Shuli Jiang, Hai Pham, David Woodruff, and Richard Zhang. Optimal sketching for trace estimation. Advances in Neural Information Processing Systems, 34, 2021.
  • JSWZ [21] Shunhua Jiang, Zhao Song, Omri Weinstein, and Hengjie Zhang. Faster dynamic matrix inverse for faster lps. In STOC. arXiv preprint arXiv:2004.07470, 2021.
  • KSH [12] Alex Krizhevsky, Ilya Sutskever, and Geoffrey E Hinton. Imagenet classification with deep convolutional neural networks. Advances in neural information processing systems, 25, 2012.
  • KT [19] Jacob Devlin Ming-Wei Chang Kenton and Lee Kristina Toutanova. Bert: Pre-training of deep bidirectional transformers for language understanding. In Proceedings of NAACL-HLT, pages 4171–4186, 2019.
  • LBBH [98] Yann LeCun, Léon Bottou, Yoshua Bengio, and Patrick Haffner. Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86(11):2278–2324, 1998.
  • LHBB [99] Yann LeCun, Patrick Haffner, Léon Bottou, and Yoshua Bengio. Object recognition with gradient-based learning. In Shape, contour and grouping in computer vision, pages 319–345. Springer, 1999.
  • LL [18] Yuanzhi Li and Yingyu Liang. Learning overparameterized neural networks via stochastic gradient descent on structured data. Advances in neural information processing systems, 31, 2018.
  • LSS+ [20] Jason D Lee, Ruoqi Shen, Zhao Song, Mengdi Wang, and Zheng Yu. Generalized leverage score sampling for neural networks. In NeurIPS, 2020.
  • LSX+ [23] Shuai Li, Zhao Song, Yu Xia, Tong Yu, and Tianyi Zhou. The closeness of in-context learning and weight shifting for softmax regression. arXiv preprint arXiv:2304.13276, 2023.
  • LSY [23] Xiaoxiao Li, Zhao Song, and Jiaming Yang. Federated adversarial learning: A framework with convergence analysis. In International Conference on Machine Learning, pages 19932–19959. PMLR, 2023.
  • LSZ [19] Yin Tat Lee, Zhao Song, and Qiuyi Zhang. Solving empirical risk minimization in the current matrix multiplication time. In COLT, 2019.
  • LSZ [23] Zhihang Li, Zhao Song, and Tianyi Zhou. Solving regularized exp, cosh and sinh regression problems. arXiv preprint arXiv:2303.15725, 2023.
  • MM [13] Xiangrui Meng and Michael W Mahoney. Low-distortion subspace embeddings in input-sparsity time and applications to robust linear regression. In Proceedings of the forty-fifth annual ACM symposium on Theory of computing (STOC), pages 91–100, 2013.
  • MOSW [22] Alexander Munteanu, Simon Omlor, Zhao Song, and David Woodruff. Bounding the width of neural networks via coupled initialization a worst case analysis. In International Conference on Machine Learning, pages 16083–16122. PMLR, 2022.
  • NN [13] Jelani Nelson and Huy L Nguyên. Osnap: Faster numerical linear algebra algorithms via sparser subspace embeddings. In Proceedings of the 54th Annual IEEE Symposium on Foundations of Computer Science (FOCS), 2013.
  • OS [20] Samet Oymak and Mahdi Soltanolkotabi. Toward moderate overparameterization: Global convergence guarantees for training shallow neural networks. IEEE Journal on Selected Areas in Information Theory, 1(1):84–105, 2020.
  • QJS+ [22] Lianke Qin, Rajesh Jayaram, Elaine Shi, Zhao Song, Danyang Zhuo, and Shumo Chu. Adore: Differentially oblivious relational database operators. In VLDB, 2022.
  • QRS+ [22] Lianke Qin, Aravind Reddy, Zhao Song, Zhaozhuo Xu, and Danyang Zhuo. Adaptive and dynamic multi-resolution hashing for pairwise summations. In BigData, 2022.
  • QSZ [23] Lianke Qin, Zhao Song, and Ruizhe Zhang. A general algorithm for solving rank-one matrix sensing. arXiv preprint arXiv:2303.12298, 2023.
  • QSZZ [23] Lianke Qin, Zhao Song, Lichen Zhang, and Danyang Zhuo. An online and unified algorithm for projection matrix vector multiplication with application to empirical risk minimization. In International Conference on Artificial Intelligence and Statistics, pages 101–156. PMLR, 2023.
  • RPU+ [20] Daniel Rothchild, Ashwinee Panda, Enayat Ullah, Nikita Ivkin, Ion Stoica, Vladimir Braverman, Joseph Gonzalez, and Raman Arora. Fetchsgd: Communication-efficient federated learning with sketching. In International Conference on Machine Learning, pages 8253–8265. PMLR, 2020.
  • RSW [16] Ilya Razenshteyn, Zhao Song, and David P. Woodruff. Weighted low rank approximations with provable guarantees. In Proceedings of the Forty-Eighth Annual ACM Symposium on Theory of Computing, STOC ’16, page 250–263, 2016.
  • SHM+ [16] David Silver, Aja Huang, Chris J Maddison, Arthur Guez, Laurent Sifre, George Van Den Driessche, Julian Schrittwieser, Ioannis Antonoglou, Veda Panneershelvam, Marc Lanctot, et al. Mastering the game of go with deep neural networks and tree search. nature, 529(7587):484–489, 2016.
  • SLJ+ [15] Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed, Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, and Andrew Rabinovich. Going deeper with convolutions. In Proceedings of the IEEE conference on computer vision and pattern recognition, pages 1–9, 2015.
  • SSBD [14] Shai Shalev-Shwartz and Shai Ben-David. Understanding machine learning: From theory to algorithms. Cambridge university press, 2014.
  • SSS+ [17] David Silver, Julian Schrittwieser, Karen Simonyan, Ioannis Antonoglou, Aja Huang, Arthur Guez, Thomas Hubert, Lucas Baker, Matthew Lai, Adrian Bolton, et al. Mastering the game of go without human knowledge. nature, 550(7676):354–359, 2017.
  • SSX [23] Anshumali Shrivastava, Zhao Song, and Zhaozhuo Xu. A tale of two efficient value iteration algorithms for solving linear mdps with large action space. In AISTATS, 2023.
  • SSZ [23] Ritwik Sinha, Zhao Song, and Tianyi Zhou. A mathematical abstraction for balancing the trade-off between creativity and reality in large language models. arXiv preprint arXiv:2306.02295, 2023.
  • SWYZ [21] Zhao Song, David Woodruff, Zheng Yu, and Lichen Zhang. Fast sketching of polynomial kernels of polynomial degree. In International Conference on Machine Learning, pages 9812–9823. PMLR, 2021.
  • SWZ [17] Zhao Song, David P Woodruff, and Peilin Zhong. Low rank approximation with entrywise 1\ell_{1}-norm error. In Proceedings of the 49th Annual Symposium on the Theory of Computing (STOC), 2017.
  • [76] Zhao Song, David Woodruff, and Peilin Zhong. Average case column subset selection for entrywise 1\ell_{1}-norm loss. Advances in Neural Information Processing Systems (NeurIPS), 32:10111–10121, 2019.
  • [77] Zhao Song, David Woodruff, and Peilin Zhong. Towards a zero-one law for column subset selection. Advances in Neural Information Processing Systems, 32:6123–6134, 2019.
  • [78] Zhao Song, David P Woodruff, and Peilin Zhong. Relative error tensor low rank approximation. In SODA. arXiv preprint arXiv:1704.08246, 2019.
  • SY [19] Zhao Song and Xin Yang. Quadratic suffices for over-parametrization via matrix chernoff bound. arXiv preprint arXiv:1906.03593, 2019.
  • SY [21] Zhao Song and Zheng Yu. Oblivious sketching-based central path method for solving linear programming problems. In 38th International Conference on Machine Learning (ICML), 2021.
  • SYYZ [22] Zhao Song, Xin Yang, Yuanyuan Yang, and Tianyi Zhou. Faster algorithm for structured john ellipsoid computation. arXiv preprint arXiv:2211.14407, 2022.
  • SYZ [21] Zhao Song, Shuo Yang, and Ruizhe Zhang. Does preprocessing help training over-parameterized neural networks? Advances in Neural Information Processing Systems (NeurIPS), 34, 2021.
  • SZZ [21] Zhao Song, Lichen Zhang, and Ruizhe Zhang. Training multi-layer over-parametrized neural network in subquadratic time. arXiv preprint arXiv:2112.07628, 2021.
  • WYW+ [23] Junda Wu, Tong Yu, Rui Wang, Zhao Song, Ruiyi Zhang, Handong Zhao, Chaochao Lu, Shuai Li, and Ricardo Henao. Infoprompt: Information-theoretic soft prompt tuning for natural language understanding. arXiv preprint arXiv:2306.04933, 2023.
  • WZ [16] David P Woodruff and Peilin Zhong. Distributed low rank approximation of implicit functions of a matrix. In 2016 IEEE 32nd International Conference on Data Engineering (ICDE), pages 847–858. IEEE, 2016.
  • WZD+ [20] Ruosong Wang, Peilin Zhong, Simon S Du, Russ R Salakhutdinov, and Lin F Yang. Planning with general objective functions: Going beyond total rewards. In Annual Conference on Neural Information Processing Systems (NeurIPS), 2020.
  • XSS [21] Zhaozhuo Xu, Zhao Song, and Anshumali Shrivastava. Breaking the linear iteration cost barrier for some well-known conditional gradient methods using maxip data-structures. Advances in Neural Information Processing Systems (NeurIPS), 34, 2021.
  • XZZ [18] Chang Xiao, Peilin Zhong, and Changxi Zheng. Bourgan: generative networks with metric embeddings. In Proceedings of the 32nd International Conference on Neural Information Processing Systems (NeurIPS), pages 2275–2286, 2018.
  • ZCZG [20] Difan Zou, Yuan Cao, Dongruo Zhou, and Quanquan Gu. Gradient descent optimizes over-parameterized deep ReLU networks. Machine Learning, 2020.
  • ZG [19] Difan Zou and Quanquan Gu. An improved analysis of training over-parameterized deep neural networks. Advances in neural information processing systems, 32, 2019.
  • ZHA+ [21] Amir Zandieh, Insu Han, Haim Avron, Neta Shoham, Chaewon Kim, and Jinwoo Shin. Scaling neural tangent kernels via sketching and random features. Advances in Neural Information Processing Systems, 34, 2021.
  • Zha [22] Lichen Zhang. Speeding up optimizations via data structures: Faster search, sample and maintenance. Master’s thesis, Carnegie Mellon University, 2022.
  • ZMG [19] Guodong Zhang, James Martens, and Roger B Grosse. Fast convergence of natural gradient descent for over-parameterized neural networks. Advances in Neural Information Processing Systems, 32, 2019.
  • ZPD+ [20] Yi Zhang, Orestis Plevrakis, Simon S Du, Xingguo Li, Zhao Song, and Sanjeev Arora. Over-parameterized adversarial training: An analysis overcoming the curse of dimensionality. Advances in Neural Information Processing Systems, 33:679–688, 2020.
  • ZSZ+ [23] Zhenyu Zhang, Ying Sheng, Tianyi Zhou, Tianlong Chen, Lianmin Zheng, Ruisi Cai, Zhao Song, Yuandong Tian, Christopher Ré, Clark Barrett, et al. H2O: Heavy-hitter oracle for efficient generative inference of large language models. arXiv preprint arXiv:2306.14048, 2023.

Appendix

Roadmap.

We present probabilistic tools in Section A, additional preliminaries in Section B, and the missing proofs in Section C.

Appendix A Probabilistic Inequalities

Here we present the Hoeffding bound, which bounds the probability that the sum of independent bounded random variables deviates from its mean by a given amount.

Lemma A.1 (Hoeffding bound ([41])).

Let X1,,XnX_{1},\cdots,X_{n} denote nn independent bounded variables in [ai,bi][a_{i},b_{i}]. Let X=i=1nXiX=\sum_{i=1}^{n}X_{i}, then we have

Pr[|X𝔼[X]|t]2exp(2t2i=1n(biai)2).\displaystyle\Pr[|X-\operatorname*{{\mathbb{E}}}[X]|\geq t]\leq 2\exp\left(-\frac{2t^{2}}{\sum_{i=1}^{n}(b_{i}-a_{i})^{2}}\right).
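
As a quick numerical sanity check of this bound (a minimal Python sketch with Bernoulli variables and constants chosen by us purely for illustration), one can compare the empirical deviation frequency with the right-hand side:

import numpy as np

rng = np.random.default_rng(0)
n, t, trials = 100, 10.0, 50000

# X_i ~ Bernoulli(1/2) independently, so a_i = 0 and b_i = 1 for every i.
X = rng.integers(0, 2, size=(trials, n))
deviation = np.abs(X.sum(axis=1) - n / 2)        # |X - E[X]| in each trial

empirical = np.mean(deviation >= t)
hoeffding = 2 * np.exp(-2 * t**2 / n)            # here sum_i (b_i - a_i)^2 = n
print(empirical, hoeffding)                      # the empirical frequency stays below the bound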

We present Jensen’s inequality, which bounds the value of a convex function at a convex combination of inputs by the corresponding convex combination of the function values at those inputs.

Lemma A.2 (Jensen’s Inequality).

Let ff be a convex function, and αn\alpha\in\mathbb{R}^{n}, such that i=1nαi=1,αi[0,1]\sum_{i=1}^{n}\alpha_{i}=1,\alpha_{i}\in[0,1]. Then for all x1,,xnx_{1},\ldots,x_{n}, we have:

f(i=1nαixi)i=1nαif(xi)\displaystyle f(\sum_{i=1}^{n}\alpha_{i}x_{i})\leq\sum_{i=1}^{n}\alpha_{i}f(x_{i})
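
As a small worked instance (our own choice of ff and weights), take the convex function f(x)=x2f(x)=x^{2}, weights α1=α2=12\alpha_{1}=\alpha_{2}=\frac{1}{2}, and inputs x1=0,x2=2x_{1}=0,x_{2}=2; then

\displaystyle f\Big{(}\frac{1}{2}\cdot 0+\frac{1}{2}\cdot 2\Big{)}=1\leq 2=\frac{1}{2}f(0)+\frac{1}{2}f(2).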

Appendix B More preliminaries

In Section B.1, we introduce the necessary preliminaries related to neural tangent kernel training. In Section B.2, we discuss some fundamental aspects of the random feature scheme.

B.1 The Neural Tangent Kernel Training

Given fixed network parameters σ,d,m\sigma,d,m, and BB, the neural tangent kernel (as defined in [44]) for weights W2m×dW\in\mathbb{R}^{2m\times d} is defined as follows:

tkW(x,y):=12mB2WhW(x),WhW(y)\displaystyle\mathrm{tk}_{W}(x,y):=\frac{1}{2mB^{2}}\cdot\langle\nabla_{W}h_{W}(x),\nabla_{W}h_{W}(y)\rangle

The Neural Tangent Kernel space, denoted as tkW{\cal H}_{\mathrm{tk}_{W}}, linearly approximates the change in the function value hWh_{W} induced by small perturbations of the network weights WW. More formally, htkWh\in{\cal H}_{\mathrm{tk}_{W}} iff there is a UU such that:

h(x)=limϵ01ϵ(hW+ϵU(x)hW(x)),x𝕊d1\displaystyle h(x)=\lim_{\epsilon\rightarrow 0}\frac{1}{\epsilon}\cdot(h_{W+\epsilon U}(x)-h_{W}(x)),~{}~{}~{}~{}\forall x\in\mathbb{S}^{d-1} (1)

Note that mBhtkW\sqrt{m}B\cdot\|h\|_{\mathrm{tk}_{W}} is the minimal Euclidean norm of UU that satisfies Eq. (1). For simplicity, we use tkσ,B(x,y)\mathrm{tk}_{\sigma,B}(x,y) to denote tkσ,d,m,B(x,y)\mathrm{tk}_{\sigma,d,m,B}(x,y), which is the expected initial neural tangent kernel:

tkσ,d,m,B(x,y):=𝔼W(d,m,B)[tkW(x,y)]\displaystyle\mathrm{tk}_{\sigma,d,m,B}(x,y):=\operatorname*{{\mathbb{E}}}_{W\sim{\cal I}(d,m,B)}[\mathrm{tk}_{W}(x,y)]

In the rest of the paper, we use NeuralTangentKernelTraining(σ,d,m,l,η,b,T)(\sigma,d,m,l,\eta,b,T) as a shorthand for neural tangent kernel training with parameters: activation function σ\sigma, input dimension dd, weight dimension mm, loss function ll, learning rate η\eta, SGD batch size bb, and iteration number TT. Moreover, in this optimization process, we initialize the weight vector following W(d,m,1)W\sim{\cal I}(d,m,1) and set the initial kernel weight as V1=02m×dV^{1}=0\in\mathbb{R}^{2m\times d}.
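
To make the two definitions above concrete, the following numpy sketch estimates tkW(x,y)\mathrm{tk}_{W}(x,y) for a two-layer shifted-ReLU network. The parameterization hW(x)=r=12murmax(wr,xb0,0)h_{W}(x)=\sum_{r=1}^{2m}u_{r}\max(\langle w_{r},x\rangle-b_{0},0), the ±1\pm 1 output weights, and the constants below are illustrative assumptions of ours rather than the exact setup of our algorithm; the sketch only serves to show the inner-product form of the kernel and the sparsity of activated neurons.

import numpy as np

rng = np.random.default_rng(0)
d, m, B = 8, 512, 1.0
b0 = np.sqrt(0.4 * np.log(2 * m))            # shifted-ReLU threshold (illustrative)

W = rng.standard_normal((2 * m, d))          # hidden-layer weights
u = rng.choice([-1.0, 1.0], size=2 * m)      # output weights (illustrative)

def grad_h(x):
    # For h_W(x) = sum_r u_r * max(<w_r, x> - b0, 0), the gradient w.r.t. w_r is
    # u_r * 1[<w_r, x> >= b0] * x, so grad_W h_W(x) is a (2m x d) matrix.
    fired = (W @ x >= b0).astype(float)      # indicator of activated neurons
    return (u * fired)[:, None] * x[None, :]

def ntk(x, y):
    # tk_W(x, y) = <grad_W h_W(x), grad_W h_W(y)> / (2 m B^2)
    return float(np.sum(grad_h(x) * grad_h(y)) / (2 * m * B**2))

x = rng.standard_normal(d); x /= np.linalg.norm(x)
y = rng.standard_normal(d); y /= np.linalg.norm(y)
print(ntk(x, y))
print((W @ x >= b0).mean())                  # only a small fraction of neurons fire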

B.2 Random Feature Scheme

In this section, we present some essential background on random feature schemes. We begin with the definition of a random feature scheme with respect to a kernel.

Definition B.1 (Random Feature Scheme).

Let 𝒳\cal X be a measurable space and let 𝖪:𝒳×𝒳\mathsf{K}:{\cal X}\times{\cal X}\rightarrow\mathbb{R} be a kernel. A random feature scheme (RFS) for 𝖪\mathsf{K} is a pair (ψ,μ)(\psi,\mu) where μ\mu is a probability measure on a measurable space Ω\Omega, and ψ:Ω×𝒳d\psi:\Omega\times{\cal X}\rightarrow\mathbb{R}^{d} is a measurable function, such that:

𝖪(x1,x2)=𝔼wμ[ψ(w,x1),ψ(w,x2)],x1,x2𝒳.\displaystyle\mathsf{K}(x_{1},x_{2})=\operatorname*{{\mathbb{E}}}_{w\sim\mu}[\langle\psi(w,x_{1}),\psi(w,x_{2})\rangle],~{}\forall x_{1},x_{2}\in{\cal X}. (2)

In particular, taking μ\mu to be the standard Gaussian measure on d\mathbb{R}^{d} yields an RFS for the kernel tkσh\mathrm{tk}_{\sigma}^{h}.

Next, we present the definition of CC-bounded RFS. For activation function σ\sigma, the NTK RFS is CC-bounded for C=σC=\|\sigma^{\prime}\|_{\infty}.

Definition B.2 (CC-bounded RFS).

We say that ψ\psi is CC-bounded if ψ(w,x)C\|\psi(w,x)\|\leq C for every wΩw\in\Omega and x𝒳x\in{\cal X}.

Furthermore, we present the definition of a factorized RFS; note that the NTK RFS is factorized.

Definition B.3 (Factorized RFS).

We say that an RFS ψ:Ω×𝕊d1d\psi:\Omega\times\mathbb{S}^{d-1}\rightarrow\mathbb{R}^{d} is factorized if there exists a function ψ1:Ω×𝕊d1\psi_{1}:\Omega\times\mathbb{S}^{d-1}\rightarrow\mathbb{R} such that ψ(w,x)=ψ1(w,x)x\psi(w,x)=\psi_{1}(w,x)x.

For the remainder of this paper, we use SGDRFS(ψ,m,l,η,b,T)(\psi,m,l,\eta,b,T) as a shorthand for stochastic gradient descent in the random feature space with the following parameters: the random feature scheme (RFS) ψ\psi, the number of random features mm, the loss function ll, the learning rate η\eta, the SGD batch size bb, and the number of iterations TT. Additionally, the optimization process initializes the weight vector as v1=0m×dv^{1}=0\in\mathbb{R}^{m\times d}.
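
To make the SGDRFS procedure concrete, we give a minimal Python sketch below. The hinge-style loss, the synthetic data generator, and the predictor form fv(x)=1mi=1mvi,ψ(wi,x)f_{v}(x)=\frac{1}{\sqrt{m}}\sum_{i=1}^{m}\langle v_{i},\psi(w_{i},x)\rangle are illustrative assumptions of ours; only the initialization v1=0v^{1}=0 and the batched SGD update mirror the description above.

import numpy as np

rng = np.random.default_rng(0)
d, m, b, T, eta = 8, 256, 16, 200, 0.1

W = rng.standard_normal((m, d))              # random features w_1, ..., w_m
v = np.zeros((m, d))                         # v^1 = 0, as in the text

def psi(x):
    # ReLU-NTK feature map psi(w, x) = x * 1[<w, x> >= 0]  (Section B.3)
    return (W @ x >= 0)[:, None] * x[None, :]          # shape (m, d)

def predict(x):
    # assumed predictor form f_v(x) = (1/sqrt(m)) * sum_i <v_i, psi(w_i, x)>
    return np.sum(v * psi(x)) / np.sqrt(m)

def sample_batch(size):
    # placeholder data: x uniform on the sphere, label = sign of the first coordinate
    X = rng.standard_normal((size, d))
    X /= np.linalg.norm(X, axis=1, keepdims=True)
    return X, np.sign(X[:, 0])

for _ in range(T):
    X, Y = sample_batch(b)
    grad = np.zeros_like(v)
    for x, y in zip(X, Y):
        if y * predict(x) < 1:               # subgradient of the hinge loss max(0, 1 - y f_v(x))
            grad += -y * psi(x) / np.sqrt(m)
    v -= eta * grad / b                      # batched SGD step

X, Y = sample_batch(1000)
print(np.mean([np.sign(predict(x)) == y for x, y in zip(X, Y)]))   # held-out accuracy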

B.3 Several Instances of ψ\psi

Suppose we consider the following neural network function f:df:\mathbb{R}^{d}\rightarrow\mathbb{R}

f(x)=r=1marσ(wr,x).\displaystyle f(x)=\sum_{r=1}^{m}a_{r}\sigma(\langle w_{r},x\rangle).

For the ReLU activation function σ\sigma (see [30, 79]), we have

σ(z):=max{z,0}\displaystyle\sigma(z):=\max\{z,0\}

and

ψ(w,x)=\displaystyle\psi(w,x)= xσ(z)|z=w,x\displaystyle~{}x\cdot\sigma(z)^{\prime}|_{z=\langle w,x\rangle}
=\displaystyle= x𝟏w,x0\displaystyle~{}x\cdot{\bf 1}_{\langle w,x\rangle\geq 0}

Due to the recent trend of large language models, a number of works study exponential- or softmax-based objective functions [8, 14, 57, 26, 54, 73, 84, 37, 95, 28, 36]. Thus, we can also consider the exponential activation function σ\sigma (see [33])

σ(z):=exp(z),\displaystyle\sigma(z):=\exp(z),

and

ψ(w,x)=\displaystyle\psi(w,x)= xσ(z)|z=w,x\displaystyle~{}x\cdot\sigma(z)^{\prime}|_{z=\langle w,x\rangle}
=\displaystyle= xexp(w,x)\displaystyle~{}x\cdot\exp(\langle w,x\rangle)
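
The following numpy sketch (with arbitrary sample sizes, for illustration only) instantiates both feature maps above and estimates the induced kernel 𝖪(x1,x2)=𝔼w[ψ(w,x1),ψ(w,x2)]\mathsf{K}(x_{1},x_{2})=\operatorname*{{\mathbb{E}}}_{w}[\langle\psi(w,x_{1}),\psi(w,x_{2})\rangle] of Eq. (2) by Monte Carlo over standard Gaussian ww:

import numpy as np

rng = np.random.default_rng(0)
d, n_samples = 8, 200000

def psi_relu(W, x):
    # psi(w, x) = x * 1[<w, x> >= 0]
    return (W @ x >= 0)[:, None] * x[None, :]

def psi_exp(W, x):
    # psi(w, x) = x * exp(<w, x>)
    return np.exp(W @ x)[:, None] * x[None, :]

def kernel_mc(psi, x1, x2):
    # Monte Carlo estimate of K(x1, x2) = E_w[<psi(w, x1), psi(w, x2)>], w ~ N(0, I_d)
    W = rng.standard_normal((n_samples, d))
    return np.mean(np.sum(psi(W, x1) * psi(W, x2), axis=1))

x1 = rng.standard_normal(d); x1 /= np.linalg.norm(x1)
x2 = rng.standard_normal(d); x2 /= np.linalg.norm(x2)

print(kernel_mc(psi_relu, x1, x2))   # ReLU feature map
print(kernel_mc(psi_exp, x1, x2))    # exponential feature map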

Appendix C Missing Proofs

In Section C.1 we present the proof of the Equivalence for NNT and NTKT. In Section C.2, we present the proof of Theorem 5.6. In Section C.3, we present the proof of running time for our algorithm.

C.1 Proof of Equivalence for NNT and NTKT

Lemma C.1 (Equivalence for NNT and NTKT, restatement of Lemma 5.3).

If the following conditions hold

  • Fix a descent activation σ\sigma as well as a convex descent loss ll (Definition 3.1).

There is a choice B=poly(d,m,1/η,T,1/ϵ)B=\operatorname{poly}(d,m,1/\eta,T,1/\epsilon), such that for every input distribution the following holds: let h1,h2h_{1},h_{2} be the functions returned by NeuralNetworkTraining(σ,d,m,l,ηB2,b,T,B)(\sigma,d,m,l,\frac{\eta}{B^{2}},b,T,B) and NeuralTangentKernelTraining(σ,d,m,l,η,b,T)(\sigma,d,m,l,\eta,b,T), respectively.

Then, we have

  • |𝔼[𝒟(h1)]𝔼[𝒟(h2)]|<ϵ|\operatorname*{{\mathbb{E}}}[{\cal L}_{\cal D}(h_{1})]-\operatorname*{{\mathbb{E}}}[{\cal L}_{\cal D}(h_{2})]|<\epsilon.

Proof.

For simplicity, we first prove the statement under the assumption that the activation function σ\sigma is twice differentiable and satisfies σ,σ′′<M\|\sigma^{\prime}\|_{\infty},\|\sigma^{\prime\prime}\|_{\infty}<M. Then, at the end of this proof, we show how this implies the result for the case where the activation function is MM-descent.

We analyze two different implementations of the NeuralNetworkTraining()(\cdot) algorithm: The first implementation initiates with weights W1=(W,u)W_{1}=(W,u) sampled from distribution (d,m,1){\cal I}(d,m,1) and adopts learning rate η1\eta_{1}. The second implementation utilizes the exact same mini-batches and hyperparameters as the first one, with the exception that the output weights are scaled by a factor BB and the learning rate is divided by B2B^{2}, i.e., W2=(W,Bu)W_{2}=(W,Bu) and η2=η1/B2\eta_{2}=\eta_{1}/B^{2}. Essentially, this transforms the network function from hW(x)h_{W}(x) to h~W(x):=BhW(x)\widetilde{h}_{W}(x):=Bh_{W}(x).

Next, we show that the second implementation of NeuralNetworkTraining approximates NeuralTangentKernelTraining(σ,d,m,l,η,b,T)(\sigma,d,m,l,\eta,b,T). Compared to the first implementation, the gradient of the hidden layer becomes BB times larger, while the gradient of the output layer remains unchanged, i.e.,

WhW2(x)=BWhW1(x)\displaystyle\nabla_{W}h_{W_{2}}(x)=B\cdot\nabla_{W}h_{W_{1}}(x)
uhW2(x)=uhW1(x)\displaystyle\nabla_{u}h_{W_{2}}(x)=\nabla_{u}h_{W_{1}}(x)

Consequently, the overall weight movement is scaled down by a factor of BB. Thus, the optimization process operates within a sphere of radius R/BR/B around WW, where RR is a polynomial in M,d,m,1/η,T,1/ϵM,d,m,1/\eta,T,1/\epsilon.

Next, we examine the first-order approximation of h~W\widetilde{h}_{W} around the initial weight, specifically,

|h~W+V(x)Bhw(x)BWhW(x),V|\displaystyle|\widetilde{h}_{W+V}(x)-Bh_{w}(x)-B\langle\nabla_{W}h_{W}(x),V\rangle|\leq H2V2\displaystyle~{}\frac{H}{2}\|V\|^{2}
|h~W+V(x)BWhW(x),V|\displaystyle|\widetilde{h}_{W+V}(x)-B\langle\nabla_{W}h_{W}(x),V\rangle|\leq H2V2\displaystyle~{}\frac{H}{2}\|V\|^{2}

The first inequality is a second-order Taylor bound, where HH denotes a uniform bound on the Hessian of hW(x)h_{W}(x), obtained from the fact that σ,σ′′<M\|\sigma^{\prime}\|_{\infty},\|\sigma^{\prime\prime}\|_{\infty}<M; the second inequality then follows since hW(x)=0h_{W}(x)=0 at the initial weight WW. Since RR does not depend on BB and VR/B\|V\|\leq R/B, the quadratic term H2V2\frac{H}{2}\|V\|^{2} becomes negligible for a sufficiently large BB. Thus, we only need to consider the scenario where the optimization is conducted over the linear function BWhW(x),VB\langle\nabla_{W}h_{W}(x),V\rangle with a learning rate of η/B2\eta/B^{2} and starting at 0. This is equivalent to NeuralTangentKernelTraining(σ,d,m,l,η,b,T)(\sigma,d,m,l,\eta,b,T), which optimizes over the linear function WhW(x),V\langle\nabla_{W}h_{W}(x),V\rangle with a learning rate of η\eta and starting at 0.

It is important to note that any MM-descent activation function locally ensures σ,σ′′<M\|\sigma^{\prime}\|_{\infty},\|\sigma^{\prime\prime}\|_{\infty}<M. Additionally, if BB is sufficiently large, the pre-activation output of the hidden layer remains largely stable throughout the entire optimization process. Consequently, for every sample in the mini-batches we stay within a region where σ,σ′′<M\|\sigma^{\prime}\|_{\infty},\|\sigma^{\prime\prime}\|_{\infty}<M holds, which completes the proof. ∎

C.2 Proof of Theorem 5.6

In this section, we first present the correctness theorem for SGDRFS and its general proof by Lemma C.5. Then, we present some definitions and lemmas to prove Lemma C.5. Finally, we present the proof of Lemma C.5.

Theorem C.2 (Restatement of Theorem 5.6).

Assume that

  • ψ\psi is a factorized (Definition B.3) and CC-bounded (Definition B.2) RFS (Definition B.1) for 𝖪\mathsf{K},

  • ll is convex (Definition 3.1) and LL-Lipschitz (Definition 3.2),

  • 𝒟{\cal D} has RR-bounded marginal (Definition 1.1).

Let ff be the function returned by SGDRFS(ψ,m,l,η,b,T)(\psi,m,l,\eta,b,T). Fix a function f𝖪f^{*}\in{\cal H}_{\mathsf{K}}. Then:

𝔼[𝒟(f)]𝒟(f)+LRCf𝖪md+f𝖪22ηT+ηL2C22\displaystyle\operatorname*{{\mathbb{E}}}[{\cal L}_{\cal D}(f)]\leq{\cal L}_{\cal D}(f^{*})+\frac{LRC\|f^{*}\|_{\mathsf{K}}}{\sqrt{md}}+\frac{\|f^{*}\|_{\mathsf{K}}^{2}}{2\eta T}+\frac{\eta L^{2}C^{2}}{2}

In particular, if f𝖪M\|f^{*}\|_{\mathsf{K}}\leq M and η=MTLC\eta=\frac{M}{\sqrt{T}LC}, we have:

𝔼[𝒟(f)]L𝒟(f)+LRCMmd+LCMT\displaystyle\operatorname*{{\mathbb{E}}}[{\cal L}_{\cal D}(f)]\leq L_{\cal D}(f^{*})+\frac{LRCM}{\sqrt{md}}+\frac{LCM}{\sqrt{T}}
Proof.

First, by Hoeffding’s bound (Lemma A.1), for any m2C4ϵ2log(2/δ)m\geq 2C^{4}\epsilon^{-2}\log({2/\delta}) and every x1,x2𝒳x_{1},x_{2}\in{\cal X}, we have:

Pr[|𝖪w(x1,x2)𝖪(x1,x2)|ϵ]δ\displaystyle\Pr[|\mathsf{K}_{w}(x_{1},x_{2})-\mathsf{K}(x_{1},x_{2})|\geq\epsilon]\leq\delta (3)

Next, we will explore how to approximate functions in 𝖪{\cal H}_{\mathsf{K}} using functions from 𝖪w{\cal H}_{\mathsf{K}_{w}}. For this purpose, we consider the following embedding:

xΨx,Ψx:=ψ(,x)L2(Ω,d)\displaystyle x\mapsto\Psi^{x},\quad\Psi^{x}:=\psi(\cdot,x)\in L^{2}(\Omega,\mathbb{R}^{d}) (4)

From Equation (2), we have that for any x1,x2𝒳x_{1},x_{2}\in{\cal X}, 𝖪(x1,x2)=Ψx1,Ψx2L2(Ω)\mathsf{K}(x_{1},x_{2})=\langle\Psi^{x_{1}},\Psi^{x_{2}}\rangle_{L^{2}(\Omega)}. In particular, according to Theorem 3.9, for every f𝖪f\in{\cal H}_{\mathsf{K}}, there exists a unique function f~L2(Ω,d)\widetilde{f}\in L^{2}(\Omega,\mathbb{R}^{d}) such that:

f~L2(Ω)=f𝖪\displaystyle\|\widetilde{f}\|_{L^{2}(\Omega)}=\|f\|_{\mathsf{K}} (5)

and for every x𝒳x\in{\cal X},

f(x)=f~,ΨxL2(Ω,d)=𝔼wμ[f~(w),ψ(w,x)].\displaystyle f(x)=\langle\widetilde{f},\Psi^{x}\rangle_{L^{2}(\Omega,\mathbb{R}^{d})}=\operatorname*{{\mathbb{E}}}_{w\sim\mu}[\langle\widetilde{f}(w),\psi(w,x)\rangle]. (6)

We denote v:=1m(f~(w1),,f~(wm))dmv^{*}:=\frac{1}{\sqrt{m}}(\widetilde{f}^{*}(w_{1}),\ldots,\widetilde{f}^{*}(w_{m}))\in\mathbb{R}^{dm}. Then, by standard results on SGD (e.g., [70]), given ww we have

𝒟(f)𝒟(fw)+12ηTv2+ηL2C22\displaystyle{\cal L}_{\cal D}(f)\leq{\cal L}_{\cal D}(f^{*}_{w})+\frac{1}{2\eta T}\|v^{*}\|^{2}+\frac{\eta L^{2}C^{2}}{2}

Applying the expectation over the selection of ww, and employing Lemma C.5 along with Eq. (5), we obtain:

𝒟(f)𝒟(f)+LRCf𝖪md+f𝖪22ηT+ηL2C22\displaystyle{\cal L}_{\cal D}(f)\leq{\cal L}_{\cal D}(f^{*})+\frac{LRC\|f^{*}\|_{\mathsf{K}}}{\sqrt{md}}+\frac{\|f^{*}\|_{\mathsf{K}}^{2}}{2\eta T}+\frac{\eta L^{2}C^{2}}{2}
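
For the second bound of the theorem, substituting f𝖪M\|f^{*}\|_{\mathsf{K}}\leq M and η=MTLC\eta=\frac{M}{\sqrt{T}LC} into the last two terms gives

\displaystyle\frac{\|f^{*}\|_{\mathsf{K}}^{2}}{2\eta T}+\frac{\eta L^{2}C^{2}}{2}\leq\frac{M^{2}}{2\eta T}+\frac{\eta L^{2}C^{2}}{2}=\frac{MLC}{2\sqrt{T}}+\frac{MLC}{2\sqrt{T}}=\frac{MLC}{\sqrt{T}},

which yields the claimed bound and completes the proof. ∎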

For ease of exposition, we introduce the definition of fw(x)f_{w}(x):

Definition C.3 (fw(x)f_{w}(x)).

Given m,xm,x and function ff, we denote fw(x)f_{w}(x) as follows:

fw(x):=1mi=1mf~(wi),ψ(wi,x).\displaystyle f_{w}(x):=\frac{1}{m}\sum_{i=1}^{m}\langle\widetilde{f}(w_{i}),\psi(w_{i},x)\rangle.
Corollary C.4 (Function Approximation).

Assume that ψ\psi is CC-bounded (Definition B.2). Then the following two bounds hold:

  • for all x𝒳x\in{\cal X}, 𝔼w[|f(x)fw(x)|2]C2f𝖪2m\operatorname*{{\mathbb{E}}}_{w}[|f(x)-f_{w}(x)|^{2}]\leq\frac{C^{2}\|f\|_{\mathsf{K}}^{2}}{m},

  • moreover, if 𝒟{\cal D} is a distribution on 𝒳\cal X, then:

𝔼w[ffw2,𝒟]Cf𝖪m\displaystyle\operatorname*{{\mathbb{E}}}_{w}[\|f-f_{w}\|_{2,{\cal D}}]\leq\frac{C\|f\|_{\mathsf{K}}}{\sqrt{m}}
Proof.

From Equation (6), we find that 𝔼w[fw(x)]=f(x)\operatorname*{{\mathbb{E}}}_{w}[f_{w}(x)]=f(x). Additionally, for every xx, the variance of fw(x)f_{w}(x) can be bounded as follows:

1m𝔼wμ[|f~(w),ψ(w,x)|2]\displaystyle\frac{1}{m}\operatorname*{{\mathbb{E}}}_{w\sim\mu}[|\langle\widetilde{f}(w),\psi(w,x)\rangle|^{2}]\leq C2m𝔼wμ[|f~(w)|2]\displaystyle~{}\frac{C^{2}}{m}\operatorname*{{\mathbb{E}}}_{w\sim\mu}[|\widetilde{f}(w)|^{2}]
=\displaystyle= C2mf𝖪2\displaystyle~{}\frac{C^{2}}{m}\|f\|^{2}_{\mathsf{K}}

The initial step is derived from the fact that ψ\psi is CC-bounded (see Definition B.2), while the concluding step is drawn from Eq. (5). Consequently, it directly leads us to:

𝔼w[|f(x)fw(x)|2]C2mf𝖪2.\displaystyle\operatorname*{{\mathbb{E}}}_{w}[|f(x)-f_{w}(x)|^{2}]\leq\frac{C^{2}}{m}\cdot\|f\|_{\mathsf{K}}^{2}. (7)

Additionally, when 𝒟{\cal D} represents a distribution on the set 𝒳\cal X, we can derive that:

𝔼w[ffw2,𝒟]\displaystyle\operatorname*{{\mathbb{E}}}_{w}[\|f-f_{w}\|_{2,{\cal D}}]\leq 𝔼w[ffw2,𝒟2]\displaystyle~{}\sqrt{\operatorname*{{\mathbb{E}}}_{w}[\|f-f_{w}\|_{2,{\cal D}}^{2}]}
=\displaystyle= 𝔼w[𝔼x𝒟[|f(x)fw(x)|2]]\displaystyle~{}\sqrt{\operatorname*{{\mathbb{E}}}_{w}[\operatorname*{{\mathbb{E}}}_{x\sim{\cal D}}[|f(x)-f_{w}(x)|^{2}]]}
=\displaystyle= 𝔼x[𝔼w[|f(x)fw(x)|2]]\displaystyle~{}\sqrt{\operatorname*{{\mathbb{E}}}_{x}[\operatorname*{{\mathbb{E}}}_{w}[|f(x)-f_{w}(x)|^{2}]]}
\displaystyle\leq Cf𝖪m\displaystyle~{}\frac{C\|f\|_{\mathsf{K}}}{\sqrt{m}}

where the first step follows from Jensen’s inequality(Lemma A.2), the second step follows from plugging in the definition of 2,𝒟\|\cdot\|_{2,{\cal D}} , the third step follows from exchanging the order of expectation, and the last step follows from using Eq. (7). ∎
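
As a numerical illustration of the 1/m1/m rate in Eq. (7), the following sketch (with our own arbitrary choice of f~\widetilde{f} and the ReLU feature map ψ\psi from Section B.3, purely for illustration) estimates 𝔼w[|f(x)fw(x)|2]\operatorname*{{\mathbb{E}}}_{w}[|f(x)-f_{w}(x)|^{2}] for increasing mm:

import numpy as np

rng = np.random.default_rng(0)
d = 16

def contrib(W, x):
    # <f_tilde(w), psi(w, x)> with psi(w, x) = x * 1[<w, x> >= 0]
    # and the (arbitrary) choice f_tilde(w) = w / ||w||.
    z = W @ x
    return (z / np.linalg.norm(W, axis=1)) * (z >= 0)

x = rng.standard_normal(d); x /= np.linalg.norm(x)

# reference value f(x) = E_w[<f_tilde(w), psi(w, x)>] from a large sample
f_x = contrib(rng.standard_normal((500000, d)), x).mean()

for m in (100, 400, 1600):
    sq_err = [(contrib(rng.standard_normal((m, d)), x).mean() - f_x) ** 2
              for _ in range(300)]
    print(m, np.mean(sq_err))        # decays roughly like 1/m, consistent with Eq. (7)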

Thus, O(f𝖪2ϵ2)O(\frac{\|f\|_{\mathsf{K}}^{2}}{\epsilon^{2}}) random features are adequate to ensure an expected L2L^{2} distance of no more than ϵ\epsilon. Next, we present a situation where a dd-dimensional random feature is as effective as dd one-dimensional random features. Specifically, O(f𝖪2dϵ2)O(\frac{\|f\|_{\mathsf{K}}^{2}}{d\epsilon^{2}}) random features are sufficient to guarantee an expected L2L^{2} distance of at most ϵ\epsilon.

Lemma C.5 (Closeness of ff and fwf_{w}).

Assume that

  • ψ:Ω×𝕊d1d\psi:\Omega\times\mathbb{S}^{d-1}\rightarrow\mathbb{R}^{d} is a factorized (Definition B.3) and CC-bounded (Definition B.2) RFS,

  • 𝒟{\cal D} is RR-bounded distribution (Definition 1.1).

Then,

𝔼w[ffw2,𝒟]𝔼w[ffw2,𝒟2]RCmdf𝖪.\displaystyle\operatorname*{{\mathbb{E}}}_{w}[\|f-f_{w}\|_{2,{\cal D}}]\leq\sqrt{\operatorname*{{\mathbb{E}}}_{w}[\|f-f_{w}\|_{2,{\cal D}}^{2}]}\leq\frac{RC}{\sqrt{md}}\cdot\|f\|_{\mathsf{K}}.

Additionally, if l:𝕊d1×Y[0,)l:\mathbb{S}^{d-1}\times Y\rightarrow[0,\infty) is an LL-Lipschitz loss (as per Definition 3.2), and if 𝒟1{\cal D}_{1} is a distribution over 𝕊d1×Y\mathbb{S}^{d-1}\times Y with an RR-bounded marginal (according to Definition 1.1) then:

𝔼w[𝒟1(fw)]𝒟1(f)+LRCmdf𝖪\displaystyle\operatorname*{{\mathbb{E}}}_{w}[{\cal L}_{{\cal D}_{1}}(f_{w})]\leq{\cal L}_{{\cal D}_{1}}(f)+\frac{LRC}{\sqrt{md}}\cdot\|f\|_{\mathsf{K}}
Proof.

Let us have x𝒟x\sim{\cal D} and wμw\sim\mu. We have:

𝔼w[ffw2,𝒟]2\displaystyle\operatorname*{{\mathbb{E}}}_{w}[\|f-f_{w}\|_{2,{\cal D}}]^{2}\leq 𝔼w[ffw2,𝒟2]\displaystyle~{}\operatorname*{{\mathbb{E}}}_{w}[\|f-f_{w}\|_{2,{\cal D}}^{2}]
=\displaystyle= 𝔼w[𝔼x[|f(x)fw(x)|2]]\displaystyle~{}\operatorname*{{\mathbb{E}}}_{w}[\operatorname*{{\mathbb{E}}}_{x}[|f(x)-f_{w}(x)|^{2}]]
=\displaystyle= 𝔼x[𝔼w[|f(x)fw(x)|2]]\displaystyle~{}\operatorname*{{\mathbb{E}}}_{x}[\operatorname*{{\mathbb{E}}}_{w}[|f(x)-f_{w}(x)|^{2}]]
=\displaystyle= 1m𝔼x[𝔼wμ[|f~(w),ψ(w,x)f(x)|2]]\displaystyle~{}\frac{1}{m}\cdot\operatorname*{{\mathbb{E}}}_{x}[\operatorname*{{\mathbb{E}}}_{w\sim\mu}[|\langle\widetilde{f}(w),\psi(w,x)\rangle-f(x)|^{2}]]
\displaystyle\leq 1m𝔼x[𝔼wμ[|f~(w),ψ(w,x)|2]]\displaystyle~{}\frac{1}{m}\cdot\operatorname*{{\mathbb{E}}}_{x}[\operatorname*{{\mathbb{E}}}_{w\sim\mu}[|\langle\widetilde{f}(w),\psi(w,x)\rangle|^{2}]]
=\displaystyle= 1m𝔼wμ[𝔼x[|f~(w),ψ1(w,x)x|2]]\displaystyle~{}\frac{1}{m}\cdot\operatorname*{{\mathbb{E}}}_{w\sim\mu}[\operatorname*{{\mathbb{E}}}_{x}[|\langle\widetilde{f}(w),\psi_{1}(w,x)x\rangle|^{2}]]
\displaystyle\leq C2m𝔼wμ[𝔼x[|f~(w),x|2]]\displaystyle~{}\frac{C^{2}}{m}\cdot\operatorname*{{\mathbb{E}}}_{w\sim\mu}[\operatorname*{{\mathbb{E}}}_{x}[|\langle\widetilde{f}(w),x\rangle|^{2}]]
\displaystyle\leq C2R2md𝔼wμ[f~(w)2]\displaystyle~{}\frac{C^{2}R^{2}}{md}\operatorname*{{\mathbb{E}}}_{w\sim\mu}[\|\widetilde{f}(w)\|^{2}]
=\displaystyle= C2R2mdf𝖪2\displaystyle~{}\frac{C^{2}R^{2}}{md}\cdot\|f\|_{\mathsf{K}}^{2}

where the first step follows from Jensen’s inequality (Lemma A.2), the second step follows from plugging in the definition of 2,𝒟\|\cdot\|_{2,{\cal D}}, the third step follows from exchanging the order of expectations, the fourth step follows from Theorem 3.9 and Eq. (6), the fifth step follows from the fact that the variance is bounded by the squared L2L^{2}-norm, the sixth step follows from the fact that the RFS is factorized (Definition B.3), the seventh step follows from the fact that ψ\psi and ψ1\psi_{1} are CC-bounded (Definition B.2), the eighth step follows from the fact that 𝒟{\cal D} is RR-bounded (Definition 1.1), and the final step follows from Eq. (5). Taking the square root of both sides yields:

𝔼w[ffw2,𝒟]CRf𝖪md.\displaystyle\operatorname*{{\mathbb{E}}}_{w}[\|f-f_{w}\|_{2,{\cal D}}]\leq\frac{CR\|f\|_{\mathsf{K}}}{\sqrt{md}}.

Finally, for an LL-Lipschitz loss ll (Definition 3.2) and (x,y)𝒟1(x,y)\sim{\cal D}_{1}, we have:

𝔼w[𝒟1(fw)]=\displaystyle\operatorname*{{\mathbb{E}}}_{w}[{\cal L}_{{\cal D}_{1}}(f_{w})]= 𝔼w[𝔼x,y[l(fw(x),y)]]\displaystyle~{}\operatorname*{{\mathbb{E}}}_{w}[\operatorname*{{\mathbb{E}}}_{x,y}[l(f_{w}(x),y)]]
\displaystyle\leq 𝔼w[𝔼x,y[l(f(x),y)]]+L𝔼w[𝔼x[|f(x)fw(x)|]]\displaystyle~{}\operatorname*{{\mathbb{E}}}_{w}[\operatorname*{{\mathbb{E}}}_{x,y}[l(f(x),y)]]+L\operatorname*{{\mathbb{E}}}_{w}[\operatorname*{{\mathbb{E}}}_{x}[|f(x)-f_{w}(x)|]]
=\displaystyle= 𝔼x,y[l(f(x),y)]+L𝔼w[𝔼x[|f(x)fw(x)|]]\displaystyle~{}\operatorname*{{\mathbb{E}}}_{x,y}[l(f(x),y)]+L\operatorname*{{\mathbb{E}}}_{w}[\operatorname*{{\mathbb{E}}}_{x}[|f(x)-f_{w}(x)|]]
=\displaystyle= 𝒟1(f)+L𝔼w[𝔼x[|f(x)fw(x)|]]\displaystyle~{}{\cal L}_{{\cal D}_{1}}(f)+L\operatorname*{{\mathbb{E}}}_{w}[\operatorname*{{\mathbb{E}}}_{x}[|f(x)-f_{w}(x)|]]
\displaystyle\leq 𝒟1(f)+L𝔼w[𝔼x[|f(x)fw(x)|2]]\displaystyle~{}{\cal L}_{{\cal D}_{1}}(f)+L\operatorname*{{\mathbb{E}}}_{w}[\sqrt{\operatorname*{{\mathbb{E}}}_{x}[|f(x)-f_{w}(x)|^{2}]}]
\displaystyle\leq 𝒟1(f)+LCRf𝖪md\displaystyle~{}{\cal L}_{{\cal D}_{1}}(f)+\frac{LCR\|f\|_{\mathsf{K}}}{\sqrt{md}}

where the first step follows from the definition of {\cal L} (Definition 3.4), the second step follows from the LL-Lipschitzness of the function ll (Definition 3.2), the third step holds because l(f(x),y)l(f(x),y) does not depend on ww, the fourth step follows from the definition of 𝒟1{\cal L}_{{\cal D}_{1}} (Definition 3.4), the fifth step follows from the fact that the first moment is bounded by the square root of the second moment (Jensen’s inequality), and the sixth step follows from the bound 𝔼w[ffw2,𝒟]RCmdf𝖪\operatorname*{{\mathbb{E}}}_{w}[\|f-f_{w}\|_{2,{\cal D}}]\leq\frac{RC}{\sqrt{md}}\cdot\|f\|_{\mathsf{K}} established in the first part of this lemma. ∎

C.3 Proof of Running Time

Lemma C.6 (Running time of Algorithm 1, restatement of Lemma 6.9).

Given the following:

  • Sample access to distribution 𝒟d{\cal D}\in\mathbb{R}^{d},

  • Running stochastic gradient descent algorithm (Algorithm 1) on 2𝖭𝖭(2m,b0=0.4log(2m))2\mathsf{NN}(2m,b_{0}=\sqrt{0.4\log(2m)}) (Definition 1.2) with batch size bb,

then the expected per-iteration cost of this algorithm is

O~(m1Θ(1/d)bd)\displaystyle\widetilde{O}(m^{1-\Theta(1/d)}bd)
Proof.

The per-iteration complexity can be decomposed as follows:

  • Querying the active neuron set for xiStx_{i}\in S_{t} takes O~(m1Θ(1/d)bd)\widetilde{O}(m^{1-\Theta(1/d)}bd) time.

    i=1b𝒯𝗊𝗎𝖾𝗋𝗒(2m,d,ki,t(St))=\displaystyle\sum_{i=1}^{b}{\cal T}_{\mathsf{query}}(2m,d,k_{i,t}(S_{t}))= bO~(m1Θ(1/d)d)\displaystyle~{}b\widetilde{O}(m^{1-\Theta(1/d)}d)
    =\displaystyle= O~(m1Θ(1/d)bd)\displaystyle~{}\widetilde{O}(m^{1-\Theta(1/d)}bd)

    where the first step follows from Corollary 6.8, and the final step follows from calculation.

  • Forward computation takes O(bdm4/5)O(bdm^{4/5}) time.

    i[b]O(dki,t)=\displaystyle\sum_{i\in[b]}O(d\cdot k_{i,t})= O(bdm4/5)\displaystyle~{}O(bdm^{4/5})

    where the last step is due to Lemma 6.6

  • Backward computation takes O(m4/5bd)O(m^{4/5}bd) time.

    • Computing MM takes O(bd)O(bd) time.

    • Computing gradient ΔW\Delta W and updating W(t+1)W(t+1) takes O(m4/5bd)O(m^{4/5}bd) time.

      O(dnnz(P))=O(dbm4/5)\displaystyle O(d\cdot\mathrm{nnz}(P))=O(d\cdot bm^{4/5})

      where the last step is because of Lemma 6.6.

  • Updating the weight vectors in HalfSpaceReport data structure takes O(bm4/5log2(2m))O(bm^{4/5}\log^{2}(2m)) time.

    𝒯𝗎𝗉𝖽𝖺𝗍𝖾(|Si,fire(t)|+|Si,fire(t+1)|)=\displaystyle{\cal T}_{\mathsf{update}}\cdot(|S_{i,\mathrm{fire}}(t)|+|S_{i,\mathrm{fire}}(t+1)|)= O(log2(2m))(i[b]ki,t+ki,t+1)\displaystyle~{}O(\log^{2}(2m))\cdot(\sum_{i\in[b]}k_{i,t}+k_{i,t+1})
    =\displaystyle= O(log2(2m))O(bm4/5)\displaystyle~{}O(\log^{2}(2m))\cdot O(bm^{4/5})
    =\displaystyle= O(bm4/5log2(2m))\displaystyle~{}O(bm^{4/5}\log^{2}(2m))

Summing over all the above terms gives us the per iteration running time O~(m1Θ(1/d)bd)\widetilde{O}(m^{1-\Theta(1/d)}bd). ∎
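
To connect the cost decomposition above with the structure of the algorithm, we give a Python-style sketch of one SGD iteration. The HalfSpaceReport interface (replaced here by a brute-force placeholder), the shifted-ReLU prediction, and the abstract loss gradient are illustrative assumptions of ours; in Algorithm 1 the query step is answered in O~(m1Θ(1/d)d)\widetilde{O}(m^{1-\Theta(1/d)}d) time per point rather than by the linear scan below.

import numpy as np

class BruteForceHSR:
    # Placeholder for the HalfSpaceReport data structure (linear-time scan); the real
    # structure supports the same query/update interface with sublinear query time.
    def __init__(self, W):
        self.W = W.copy()
    def query(self, x, b0):
        return np.nonzero(self.W @ x >= b0)[0]          # indices of activated neurons
    def update(self, r, w_new):
        self.W[r] = w_new

def sgd_iteration(hsr, W, u, batch, b0, eta, loss_grad):
    # One batched SGD iteration driven by activated-neuron identification.
    dW = np.zeros_like(W)
    for x, y in batch:
        fired = hsr.query(x, b0)                        # (i) query the fired neurons
        pred = sum(u[r] * (W[r] @ x - b0) for r in fired)   # (ii) sparse forward pass
        g = loss_grad(pred, y)
        for r in fired:                                 # (iii) gradient supported on fired neurons
            dW[r] += g * u[r] * x
    W_new = W - (eta / len(batch)) * dW
    for r in np.nonzero(np.any(dW != 0, axis=1))[0]:
        hsr.update(r, W_new[r])                         # (iv) update only the changed rows
    return W_new

# example usage with the squared loss l(p, y) = (p - y)^2 / 2, so loss_grad(p, y) = p - y
rng = np.random.default_rng(0)
d, m = 8, 256
W = rng.standard_normal((2 * m, d))
u = rng.choice([-1.0, 1.0], size=2 * m)
b0 = np.sqrt(0.4 * np.log(2 * m))
hsr = BruteForceHSR(W)
batch = [(x / np.linalg.norm(x), 1.0) for x in rng.standard_normal((16, d))]
W = sgd_iteration(hsr, W, u, batch, b0, eta=0.01, loss_grad=lambda p, y: p - y)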