
The Fine-Grained Complexity of Gradient Computation for Training Large Language Models

Josh Alman [email protected]. Columbia University.    Zhao Song [email protected]. Adobe Research.

Large language models (LLMs) have made fundamental contributions over the last few years. To train an LLM, one needs to alternately run ‘forward’ computations and ‘backward’ computations. The forward computation can be viewed as attention function evaluation, and the backward computation can be viewed as a gradient computation. In previous work by [Alman and Song, NeurIPS 2023], it was proved that the forward step can be performed in almost-linear time in certain parameter regimes, but that there is no truly sub-quadratic time algorithm in the remaining parameter regimes unless the popular hypothesis 𝖲𝖤𝖳𝖧\operatorname{\mathsf{SETH}} is false. In this work, we show nearly identical results for the harder-seeming problem of computing the gradient of the loss function of a one-layer attention network, and thus for the entire process of LLM training. This completely characterizes the fine-grained complexity of every step of LLM training.

1 Introduction

Large language models (LLMs) have emerged as popular technologies, driving breakthroughs across many applications in natural language processing, computer vision, translation, and many other areas [47, 15, 35, 51, 9, 54, 14, 45, 46, 30, 36, 44, 50, 49]. The training of these models is a computationally intensive process, characterized by alternating between two primary operations: forward computation and backward computation. Forward computation, or function evaluation, involves the propagation of input data through the network to generate predictions. Conversely, backward computation, or gradient computation, is the process of calculating the gradient of the loss function with respect to the model’s parameters, facilitating the optimization of these parameters during training.

The efficiency of these computations directly impacts the feasibility and scalability of training LLMs, particularly as models grow in size and complexity. Recent work by [4, 5] has carefully studied the forward computation step. They demonstrated a sharp computational boundary, showing that how quickly the forward steps can be performed depends critically on the magnitude of the entries of the matrices which define the model parameters. They gave a near-linear time algorithm when these entries are small, and also proved that when the entries are large, there is no algorithm much faster than the trivial algorithm, contingent upon the Strong Exponential Time Hypothesis (𝖲𝖤𝖳𝖧\operatorname{\mathsf{SETH}}) [31] holding true. This finding underscores a fundamental limitation in accelerating the training of LLMs, raising pivotal questions about the inherent computational complexity of these models.

The Strong Exponential Time Hypothesis (𝖲𝖤𝖳𝖧\operatorname{\mathsf{SETH}}) was introduced by Impagliazzo and Paturi [31] over 20 years ago. It is a strengthening of the 𝖯𝖭𝖯\mathsf{P}\neq\mathsf{NP} conjecture, and asserts that our current best 𝖲𝖠𝖳\mathsf{SAT} algorithms are roughly optimal (for detailed statement, see Hypothesis 3.3 below). 𝖲𝖤𝖳𝖧\operatorname{\mathsf{SETH}} is a popular conjecture from fine-grained complexity theory which has been used to prove lower bounds for a wide variety of algorithmic problems. See, for instance, the survey [48].

In other words, in some parameter regimes, the algorithm of [4] performs the forward steps about as quickly as one could hope for, whereas in other regimes, assuming 𝖲𝖤𝖳𝖧\operatorname{\mathsf{SETH}}, it is impossible to design a nontrivially fast algorithm. However, this leaves open many important questions about LLM training. In the case when forward computation can be done quickly, can the same be said for backward computation? If not, then the entire training process would still be slow. Relatedly, in parameter regimes where forward computation is known to be hard, is backward computation also hard? If not, perhaps heuristic tricks could be used, or other details of the model could be modified, to speed up the overall training. As we will see shortly, the backward step is defined in a much more complicated way than the forward step, and it is not evident that algorithms or lower bounds for one extend to the other.

Our study aims to resolve these questions and determine the fine-grained complexity of the backward computation phase. Our main result (which we state more formally shortly) shows that the same computational threshold from forward computation also arises for the backward problem, and that the problems are easy (or hard) in the exact same parameter regimes. Thus, the forward algorithm of [4] can be combined with our novel backward algorithm to perform each training step for LLMs in near-linear time when the parameter matrix entries are small enough, whereas when the entries are not small enough, neither step can be performed quickly.

In addition to characterizing the fine-grained complexity of LLM training, our result for gradient computation is novel for a few reasons.

  • Previous work on computational lower bounds focuses only on forward computation; see [4, 34, 5]. To our knowledge, ours is the first work to prove hardness of a backward computation step for training an LLM or similar model.

  • There has been previous work on algorithms for backward/gradient computation [10, 42, 17, 3, 23, 43]. That said, most of these works focus on backward computation in other settings. The only previous work we’re aware of that studies the optimization of attention layers (for LLMs) is [24], which uses Newton’s method and relies on Hessian computation. However, Hessian computation is substantially more expensive than gradient computation; our results apply to gradient computation and get around the Hessian “barrier”, allowing for faster algorithms in some parameter regimes, and more powerful lower bounds in others.

1.1 Problem Definition

Before formally stating our results, we precisely define the problems we study. We begin with the following problem of computing a general attention forward layer.

Definition 1.1 (\ell-th layer forward computation).

Given weights Q,K,Vd×dQ,K,V\in\mathbb{R}^{d\times d}, and letting En×dE_{\ell}\in\mathbb{R}^{n\times d} denote the \ell-th layer input, the next layer E+1n×dE_{\ell+1}\in\mathbb{R}^{n\times d} is defined recursively as

E+1D1exp(EQKE/d)EV\displaystyle E_{\ell+1}\leftarrow D^{-1}\exp(E_{\ell}QK^{\top}E_{\ell}^{\top}/d)E_{\ell}V

where

  • D:=diag(exp(EQKE/d)𝟏n)D:=\mathrm{diag}(\exp(E_{\ell}QK^{\top}E_{\ell}^{\top}/d){\bf 1}_{n}).

  • exp\exp denotes the exponential function which is entry-wise, i.e., exp(A)i,j=exp(Ai,j)\exp(A)_{i,j}=\exp(A_{i,j}) for all matrices AA.

  • diag()\mathrm{diag}() operation takes a vector as input and generates a diagonal matrix with the entries of that vector.

  • 𝟏n{\bf 1}_{n} denotes the length-nn all ones vector.

In mathematical terms, optimization in the context of attention computation is described as follows (renaming QKd×dQK^{\top}\in\mathbb{R}^{d\times d} to Xd×dX\in\mathbb{R}^{d\times d} and Vd×dV\in\mathbb{R}^{d\times d} to Yd×dY\in\mathbb{R}^{d\times d}):

Definition 1.2 (Attention optimization).

Given four n×dn\times d size matrices A1,A2,A3A_{1},A_{2},A_{3} and En×dE\in\mathbb{R}^{n\times d}, suppose that a d×dd\times d size square matrix Yd×dY\in\mathbb{R}^{d\times d} is also given. The attention optimization problem is formulated as:

minXd×dL(X):=0.5D(X)1exp(A1XA2/d)A3YEF2.\displaystyle\min_{X\in\mathbb{R}^{d\times d}}L(X):=0.5\|D(X)^{-1}\exp(A_{1}XA_{2}^{\top}/d)A_{3}Y-E\|_{F}^{2}.

Here D(X)n×nD(X)\in\mathbb{R}^{n\times n} is

D(X):=diag(exp(A1XA2/d)𝟏n).\displaystyle D(X):=\mathrm{diag}(\exp(A_{1}XA_{2}^{\top}/d){\bf 1}_{n}).

and F2\|\cdot\|_{F}^{2} denotes the squared Frobenius norm, i.e., AF2:=i,jAi,j2\|A\|_{F}^{2}:=\sum_{i,j}A_{i,j}^{2}.

Remark 1.3.

In principle, the loss function above, and resulting gradients below, should depend on both XX and YY. However, since the final matrix computed in the norm in LL depends only linearly on YY, it is straightforward to incorporate it into either an algorithm or lower bound. Thus, in this work, we focus on the case where XX is variable and YY is a fixed input to simplify some arguments.

We thus define the Approximate Attention Loss Gradient Computation problem as follows:

Definition 1.4 (Approximate Attention Loss Gradient Computation (𝖠𝖠𝗍𝗍𝖫𝖦𝖢(n,d,ϵ)\mathsf{AAttLGC}(n,d,\epsilon))).

Let four n×dn\times d size matrices A1n×d,A2n×d,A3n×dA_{1}\in\mathbb{R}^{n\times d},A_{2}\in\mathbb{R}^{n\times d},A_{3}\in\mathbb{R}^{n\times d}, En×dE\in\mathbb{R}^{n\times d} and a square matrix Yd×dY\in\mathbb{R}^{d\times d} be fixed matrices. Assume that A1XB\|A_{1}X\|_{\infty}\leq B and A2B\|A_{2}\|_{\infty}\leq B. Assume all entries of the matrices are represented in the log(n)\log(n)-bit model. Let L(X)L(X) be defined as in Definition 1.2, and let dL(X)dX\frac{\mathrm{d}L(X)}{\mathrm{d}X} denote the gradient of the loss function L(X)L(X).

The goal is to output a matrix g~d×d\widetilde{g}\in\mathbb{R}^{d\times d} such that

g~dL(X)dXϵ.\displaystyle\|\widetilde{g}-\frac{\mathrm{d}L(X)}{\mathrm{d}X}\|_{\infty}\leq\epsilon.

Here for matrix AA, A:=maxi,j|Ai,j|\|A\|_{\infty}:=\max_{i,j}|A_{i,j}|.

1.2 Main Results

Our main results show that there is a threshold in the computational complexity of 𝖠𝖠𝗍𝗍𝖫𝖦𝖢(n,d=O(logn))\mathsf{AAttLGC}(n,d=O(\log n)) depending on the bound BB. When B=o(logn)B=o(\sqrt{\log n}) we give a new near-linear-time algorithm, and when B=ω(logn)B=\omega(\sqrt{\log n}), we show that such an algorithm is impossible assuming SETH. This matches the results of [4], where a nearly identical threshold at BB around logn\sqrt{\log n} was also observed. Our results therefore imply that the entire LLM training process has this computational threshold.

Theorem 1.5 (Main result, Lower bound, informal version of Theorem 5.5).

Assuming 𝖲𝖤𝖳𝖧\mathsf{SETH}, there is no algorithm running in time O(n2q)O(n^{2-q}) for any q>0q>0 for the 𝖠𝖠𝗍𝗍𝖫𝖦𝖢(n,d=O(logn),B=ω(logn))\mathsf{AAttLGC}(n,d=O(\log n),B=\omega(\sqrt{\log n})) (see Definition 1.4).

Theorem 1.6 (Main result, Upper bound, informal version of Theorem D.6).

Assuming entries are bounded, there is an n1+o(1)n^{1+o(1)} time algorithm to solve 𝖠𝖠𝗍𝗍𝖫𝖦𝖢(n,d=O(logn),B=o(logn))\mathsf{AAttLGC}(n,d=O(\log n),B=o(\sqrt{\log n})) (see Definition 1.4) up to 1/poly(n)1/\operatorname{poly}(n) accuracy.

Our new algorithm (Theorem 1.6) builds on a low-rank approximation for the attention matrix from prior work [1, 4]. Incorporating these approximations into the gradient computation is not straightforward; in the forward problem, one simply multiplies the attention matrix by an input value matrix, but in the backward problem, it is combined with other matrices in an intricate (non-linear) way. We ultimately use tools from tensor algebra to get a handle on the entry-wise products and high-rank sparse matrices which arise in the gradient computation but do not typically preserve the needed low-rank structure.

Our new lower bound (Theorem 1.5) comes from a careful reduction from a special case of the forward problem (where hardness is known from prior work) to the backward problem. Reducing from computing a function to computing its gradient in general is quite challenging or impossible without control over how quickly the gradient may be growing or changing, and in general, the gradient of the forward (attention) computation can behave quite erratically (which is likely necessary for the expressive power of attention units). Nonetheless, in the special case of the inputs for which attention computation is known to be hard from prior work, we are able to reasonably control the growth of these gradients and successfully perform our reduction.

Roadmap. We discuss other related works in Section 2. In Section 3, we provide the basic notation, definitions, backgrounds, and facts which we will use. In Section 4, we provide the proof sketch of our algorithm and defer the details to the Appendix. In Section 5, we provide our main lower bound result. In Section 6, we briefly conclude our paper.

2 Related Work

Fine-grained Complexity

Numerous algorithmic techniques have been used in theory and in practice for attention computations. The first algorithm with provable guarantees, by Zandieh, Han, Daliri, and Karbasi [53], used locality sensitive hashing (LSH) techniques [12], while later work by Alman and Song [4] used polynomial approximation methods [2, 1]. We particularly focus here on the latter technique, which is the only algorithm we’re aware of that achieves near-linear running time.

Keles, Wijewardena, and Hegde [34] established the first lower bound on attention computation under the assumption of 𝖲𝖤𝖳𝖧\mathsf{SETH}. Their findings demonstrated that when d=ω(logn)d=\omega(\log n), it is not possible to execute forward computations in subquadratic time. The later lower bound of [4] further incorporated the magnitudes of the input entries into the lower bound to tightly match the aforementioned algorithms. Both use the high-level technique of [7] from kernel density estimation, and build on methods derived from fine-grained complexity associated with approximate nearest neighbor search [40] and the polynomial method [1].

Fast Attention Computation

Optimizing the computation of attention mechanisms in pre-trained LLMs, given their extensive parameter sets, has been a focal point of recent research. Various studies have explored the application of locality sensitive hashing (LSH) techniques to approximate attention mechanisms. [32] introduced two methods to enhance computational efficiency, including the use of LSH to replace dot product attention and a reversible residual layer to substitute the standard residual layer. [13] refined this approximation, noting that LSH’s efficiency does not require constant parameter updates. [53] proposed an innovative estimator based on Kernel Density Estimation (KDE) to speed up the softmax function and matrix multiplication computations. Some recent works [29, 33] have specifically used sketching techniques to avoid large entries in the attention matrix. [38] developed techniques utilizing a transformer within a transformer (TinT) model to simulate the transformer’s forward and backward passes, significantly increasing parameter efficiency. [37] tackled the challenge of fine-tuning LLMs with high memory demands by improving the classical ZO-SCD optimizer, creating a memory-efficient gradient estimator that requires only a forward pass. [11] provided insights into dynamic attention problems, giving both an algorithm and a hardness result for the dynamic setting of the attention problem. [28] introduces a quantum algorithm for attention computation, opening new avenues for efficiency improvements. [26] provides a result for computing the attention matrix differentially privately. [20] introduces randomized and deterministic attention sparsification algorithms for over-parameterized feature dimensions. [19] provides a zeroth-order method to accelerate the computation of attention.

Transformer Training

Transformer architectures (the backbone of LLMs) have been trained with alternating steps of forward and backward computations since their introduction [47, 15, 35, 51, 9, 54]. In Appendix B below, we perform computations to verify that our stated problems are the same as the forward and backward steps from the literature.

3 Preliminary

In Section 3.1, we define some basic notation we will use. In Section 3.2, we state important facts related to fast matrix multiplication. In Section 3.3, we provide the formal definition of the Strong Exponential Time Hypothesis. In Section 3.4, we define several intermediate functions related to softmax and exponential which will arise in our algorithms. In Section 3.5, we define the loss function. In Section 3.6, we provide standard tensor tricks which we will use. In Section 3.7, we show how to reformulate the loss function for our purposes.

3.1 Notation

For any positive integer nn, we define [n]:={1,2,,n}[n]:=\{1,2,\dots,n\}. For two vectors xx and yy of the same length, we use x,y\langle x,y\rangle to denote the inner product between xx and yy, i.e., x,y=i=1nxiyi\langle x,y\rangle=\sum_{i=1}^{n}x_{i}y_{i}. We use xyx\circ y to denote the vector whose ii-th entry is xiyix_{i}y_{i}. Let 𝟏n{\bf 1}_{n} denote the length-nn all ones vector. It is not hard to see that xy,𝟏n=x,y\langle x\circ y,{\bf 1}_{n}\rangle=\langle x,y\rangle. For a vector xx, we use xx^{\top} to denote the transpose of xx. For a matrix MM, we use MM^{\top} to denote the transpose of matrix MM. For a vector zz, we use exp(z)\exp(z) to denote the vector whose ii-th coordinate is exp(zi)\exp(z_{i}). For a matrix MM, we use exp(M)\exp(M) to denote the matrix whose (i,j)(i,j)-th coordinate is exp(Mi,j)\exp(M_{i,j}). For a function ff, we use O~(f)\widetilde{O}(f) to denote fpoly(logf)f\cdot\operatorname{poly}(\log f). Let n0,n1,m0,m1n_{0},n_{1},m_{0},m_{1} be positive integers. Let Xn0×m0X\in\mathbb{R}^{n_{0}\times m_{0}} and Yn1×m1Y\in\mathbb{R}^{n_{1}\times m_{1}}. We define the Kronecker product between matrices XX and YY, denoted XYn0n1×m0m1X\otimes Y\in\mathbb{R}^{n_{0}n_{1}\times m_{0}m_{1}}, as (XY)(j01)n1+j1,(i01)m1+i1(X\otimes Y)_{(j_{0}-1)n_{1}+j_{1},(i_{0}-1)m_{1}+i_{1}} is equal to Xj0,i0Yj1,i1X_{j_{0},i_{0}}Y_{j_{1},i_{1}}, where j0[n0],i0[m0],j1[n1],i1[m1]j_{0}\in[n_{0}],i_{0}\in[m_{0}],j_{1}\in[n_{1}],i_{1}\in[m_{1}].

3.2 Matrix Multiplication

We define matrix multiplication notation and state some well-known facts here.

Definition 3.1.

Let n1,n2,n3n_{1},n_{2},n_{3} denote any three positive integers. We use 𝒯mat(n1,n2,n3){\cal T}_{\mathrm{mat}}(n_{1},n_{2},n_{3}) to denote the time of multiplying an n1×n2n_{1}\times n_{2} matrix with another n2×n3n_{2}\times n_{3} matrix.

It is well-known that

Fact 3.2 ([6, 8]).

Let n1,n2,n3n_{1},n_{2},n_{3}, denote any three positive integers. 𝒯mat(n1,n2,n3)=O(𝒯mat(n1,n3,n2))=O(𝒯mat(n2,n1,n3))=O(𝒯mat(n2,n3,n1))=O(𝒯mat(n3,n1,n2))=O(𝒯mat(n3,n2,n1)){\cal T}_{\mathrm{mat}}(n_{1},n_{2},n_{3})=O({\cal T}_{\mathrm{mat}}(n_{1},n_{3},n_{2}))=O({\cal T}_{\mathrm{mat}}(n_{2},n_{1},n_{3}))=O({\cal T}_{\mathrm{mat}}(n_{2},n_{3},n_{1}))=O({\cal T}_{\mathrm{mat}}(n_{3},n_{1},n_{2}))=O({\cal T}_{\mathrm{mat}}(n_{3},n_{2},n_{1})).

3.3 Backgrounds on Complexity

Over 20 years ago, Impagliazzo and Paturi [31] introduced the Strong Exponential Time Hypothesis (𝖲𝖤𝖳𝖧\operatorname{\mathsf{SETH}}), an enhancement of the 𝖯𝖭𝖯\mathsf{P}\neq\mathsf{NP} conjecture. It posits that the existing algorithms for solving 𝖲𝖠𝖳\mathsf{SAT} problems are essentially as efficient as possible:

Hypothesis 3.3 (Strong Exponential Time Hypothesis (𝖲𝖤𝖳𝖧\operatorname{\mathsf{SETH}})).

For any ϵ>0\epsilon>0, there exists a positive integer k3k\geq 3 for which solving kk-𝖲𝖠𝖳\mathsf{SAT} problems with nn variables in O(2(1ϵ)n)O(2^{(1-\epsilon)n}) time is impossible, even with the use of randomized algorithms.

SETH, a widely recognized conjecture, has been instrumental in establishing fine-grained lower bounds across a broad spectrum of algorithmic challenges, as highlighted in the survey [48].

3.4 Definitions related with Softmax

Now, we give some definitions involving Xd×dX\in\mathbb{R}^{d\times d} which will be helpful. Let xd2x\in\mathbb{R}^{d^{2}} denote the vectorization of XX.

Definition 3.4.

Let A1,A2n×dA_{1},A_{2}\in\mathbb{R}^{n\times d} be two matrices, and let 𝖠=A1A2n2×d2\operatorname{\mathsf{A}}=A_{1}\otimes A_{2}\in\mathbb{R}^{n^{2}\times d^{2}}. For every j0[n]j_{0}\in[n], we define 𝖠j0n×d2\operatorname{\mathsf{A}}_{j_{0}}\in\mathbb{R}^{n\times d^{2}} to be the j0j_{0}-th n×d2n\times d^{2} size sub-block of 𝖠\operatorname{\mathsf{A}}. Note that there are nn such sub-blocks.

For every j0[n]j_{0}\in[n], let us define function u(x)j0:d2nu(x)_{j_{0}}:\mathbb{R}^{d^{2}}\rightarrow\mathbb{R}^{n} to be:

u(x)j0:=exp(𝖠j0x)n×1.\displaystyle u(x)_{j_{0}}:=\underbrace{\exp(\operatorname{\mathsf{A}}_{j_{0}}x)}_{n\times 1}.
Definition 3.5.

Suppose that A1,A2n×dA_{1},A_{2}\in\mathbb{R}^{n\times d} are two n×dn\times d size matrices. We define 𝖠j0n×d2\operatorname{\mathsf{A}}_{j_{0}}\in\mathbb{R}^{n\times d^{2}} to be the j0j_{0}-th n×d2n\times d^{2} size sub-block of 𝖠\operatorname{\mathsf{A}}. (Recall that 𝖠=A1A2n2×d2\operatorname{\mathsf{A}}=A_{1}\otimes A_{2}\in\mathbb{R}^{n^{2}\times d^{2}}.)

For every index j0[n]j_{0}\in[n], we consider a function, α(x)j0:d2\alpha(x)_{j_{0}}:\mathbb{R}^{d^{2}}\rightarrow\mathbb{R} as:

α(x)j0:=exp(𝖠j0x)n×1,𝟏nn×1.\displaystyle\alpha(x)_{j_{0}}:=\langle\underbrace{\exp(\operatorname{\mathsf{A}}_{j_{0}}x)}_{n\times 1},\underbrace{{\bf 1}_{n}}_{n\times 1}\rangle.
Definition 3.6.

Suppose that α(x)j0\alpha(x)_{j_{0}}\in\mathbb{R} is defined as in Definition 3.5.

Recall u(x)j0nu(x)_{j_{0}}\in\mathbb{R}^{n} is defined as in Definition 3.4.

For a fixed j0[n]j_{0}\in[n], let us consider function f(x)j0:d2nf(x)_{j_{0}}:\mathbb{R}^{d^{2}}\rightarrow\mathbb{R}^{n}

f(x)j0:=α(x)j01scalaru(x)j0n×1.\displaystyle f(x)_{j_{0}}:=\underbrace{\alpha(x)_{j_{0}}^{-1}}_{\mathrm{scalar}}\underbrace{u(x)_{j_{0}}}_{n\times 1}.

Let f(x)n×nf(x)\in\mathbb{R}^{n\times n} denote the matrix where j0j_{0}-th row is (f(x)j0)(f(x)_{j_{0}})^{\top}.
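As a quick illustration (ours, ignoring the 1/d rescaling as in Appendix B; the function name f_of_x is illustrative), the rows of f(x) are exactly the row-wise softmax of A1 X A2^⊤:

import numpy as np

def f_of_x(A1, A2, X):
    # Row j0 of the returned matrix is f(x)_{j0} from Definition 3.6 (1/d factor omitted).
    U = np.exp(A1 @ X @ A2.T)      # row j0 is u(x)_{j0} = exp(A_{j0} x)      (Definition 3.4)
    alpha = U.sum(axis=1)          # alpha(x)_{j0} = <u(x)_{j0}, 1_n>         (Definition 3.5)
    return U / alpha[:, None]      # f(x)_{j0} = alpha(x)_{j0}^{-1} u(x)_{j0} (Definition 3.6)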

Definition 3.7.

For every i0[d]i_{0}\in[d], we define h(y)i0:d2nh(y)_{i_{0}}:\mathbb{R}^{d^{2}}\rightarrow\mathbb{R}^{n} as:

h(y)i0:=A3n×dY,i0d×1.\displaystyle h(y)_{i_{0}}:=\underbrace{A_{3}}_{n\times d}\underbrace{Y_{*,i_{0}}}_{d\times 1}.

Here Yd×dY\in\mathbb{R}^{d\times d} denotes the matrix representation of yd2y\in\mathbb{R}^{d^{2}}. Let h(y)n×dh(y)\in\mathbb{R}^{n\times d} denote the matrix whose i0i_{0}-th column is h(y)i0h(y)_{i_{0}}.

3.5 Loss Functions

In this section, we introduce some helpful definitions related to xd2x\in\mathbb{R}^{d^{2}} and yd2y\in\mathbb{R}^{d^{2}}.

Definition 3.8.

For every j0[n]j_{0}\in[n], we use f(x)j0nf(x)_{j_{0}}\in\mathbb{R}^{n} to denote the normalized vector defined by Definition 3.6. For every i0[d]i_{0}\in[d], we let h(y)i0h(y)_{i_{0}} to be defined in Definition 3.7.

For every j0[n]j_{0}\in[n] and every i0[d]i_{0}\in[d], let us define c(x)j0,i0:d2×d2c(x)_{j_{0},i_{0}}:\mathbb{R}^{d^{2}}\times\mathbb{R}^{d^{2}}\rightarrow\mathbb{R} as follows:

c(x)j0,i0:=f(x)j0,h(y)i0Ej0,i0.\displaystyle c(x)_{j_{0},i_{0}}:=\langle f(x)_{j_{0}},h(y)_{i_{0}}\rangle-E_{j_{0},i_{0}}.

Here Ej0,i0E_{j_{0},i_{0}} is the (j0,i0)(j_{0},i_{0})-th coordinate/location of En×dE\in\mathbb{R}^{n\times d} for j0[n],i0[d]j_{0}\in[n],i_{0}\in[d]. This is equivalent to c(x)n×d=f(x)n×nh(y)n×dEn×d\underbrace{c(x)}_{n\times d}=\underbrace{f(x)}_{n\times n}\underbrace{h(y)}_{n\times d}-\underbrace{E}_{n\times d}.

Definition 3.9.

For every j0[n]j_{0}\in[n] and every i0[d]i_{0}\in[d], let us define L(x)j0,i0:=0.5c(x)j0,i02L(x)_{j_{0},i_{0}}:=0.5c(x)_{j_{0},i_{0}}^{2}.

3.6 Tensor Trick

We state the well-known tensor trick. It has been widely used in the literature on linear algebra related to tensor computations [41, 21, 18, 5, 25, 52, 39, 27, 22, 16].

Fact 3.10 (Tensor trick).

For two matrices A1A_{1} and A2n×dA_{2}\in\mathbb{R}^{n\times d}, define 𝖠=A1A2\operatorname{\mathsf{A}}=A_{1}\otimes A_{2}. Let Xd×dX\in\mathbb{R}^{d\times d}. Let xd2x\in\mathbb{R}^{d^{2}} denote the vector representation of XX. Then we have vec(A1XA2)=𝖠x\operatorname{vec}(A_{1}XA_{2}^{\top})=\operatorname{\mathsf{A}}x.
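The identity can be checked numerically; the following sketch (ours) uses NumPy's np.kron together with row-major vectorization, matching the Kronecker indexing of Section 3.1.

import numpy as np

# Check of Fact 3.10: vec(A1 X A2^T) = (A1 ⊗ A2) vec(X), with row-major vec.
rng = np.random.default_rng(0)
n, d = 5, 3
A1 = rng.standard_normal((n, d))
A2 = rng.standard_normal((n, d))
X = rng.standard_normal((d, d))

lhs = (A1 @ X @ A2.T).reshape(-1)       # vec(A1 X A2^T), length n^2
rhs = np.kron(A1, A2) @ X.reshape(-1)   # A x, where A = A1 ⊗ A2 and x = vec(X)
assert np.allclose(lhs, rhs)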

Using the above tensor-trick, it is easy to observe that

Fact 3.11.

For two matrices A1A_{1} and A2n×dA_{2}\in\mathbb{R}^{n\times d}, denote 𝖠=A1A2\operatorname{\mathsf{A}}=A_{1}\otimes A_{2}. Let Xd×dX\in\mathbb{R}^{d\times d}. Let 𝖠j0n×d2\operatorname{\mathsf{A}}_{j_{0}}\in\mathbb{R}^{n\times d^{2}} be a submatrix of 𝖠\operatorname{\mathsf{A}} (obtained by properly selecting nn rows of 𝖠\operatorname{\mathsf{A}}). Let xd2x\in\mathbb{R}^{d^{2}} denote the vector representation of XX. Then, we have

  • vec(exp(A1XA2))=exp(𝖠x)\operatorname{vec}(\exp(A_{1}XA_{2}^{\top}))=\exp(\operatorname{\mathsf{A}}x)

  • (exp(A1XA2)j0,)=exp(𝖠j0x)(\exp(A_{1}XA_{2}^{\top})_{j_{0},*})^{\top}=\exp(\operatorname{\mathsf{A}}_{j_{0}}x),

Here exp(A1XA2)j0,\exp(A_{1}XA_{2}^{\top})_{j_{0},*} is the j0j_{0}-th row of n×nn\times n matrix exp(A1XA2)\exp(A_{1}XA_{2}^{\top}).

Proof.

This follows directly from Fact 3.10. ∎

3.7 Reshape the Loss function via Tensor Trick

Lemma 3.12.

Given the below requirements

  • Let A1n×dA_{1}\in\mathbb{R}^{n\times d}, A2n×dA_{2}\in\mathbb{R}^{n\times d}, and A3n×dA_{3}\in\mathbb{R}^{n\times d} be three matrices

  • Let 𝖠=A1A2n2×d2\mathsf{A}=A_{1}\otimes A_{2}\in\mathbb{R}^{n^{2}\times d^{2}} to be the Kronecker product of the two matrices A1A_{1} and A2A_{2}

    • For every j0[n]j_{0}\in[n], define 𝖠j0n×d2\operatorname{\mathsf{A}}_{j_{0}}\in\mathbb{R}^{n\times d^{2}} to be a n×d2n\times d^{2} sized block in the matrix 𝖠n2×d2\operatorname{\mathsf{A}}\in\mathbb{R}^{n^{2}\times d^{2}}

  • En×dE\in\mathbb{R}^{n\times d} be a matrix. Define Ej0,i0E_{j_{0},i_{0}} as the (j0,i0)(j_{0},i_{0})-th coordinate/location of En×dE\in\mathbb{R}^{n\times d} for every pair of j0[n]j_{0}\in[n] and i0[d]i_{0}\in[d]

  • Let Xd×dX\in\mathbb{R}^{d\times d} and Yd×dY\in\mathbb{R}^{d\times d} be two square matrices

  • Let L(X)L(X) be defined as Definition 1.2

  • For every pair of j0[n]j_{0}\in[n], i0[d]i_{0}\in[d], recall that the definition of L(x)j0,i0L(x)_{j_{0},i_{0}} can be found in Definition 3.9

Then, we have

L(X)=j0[n]i0[d]L(x)j0,i0.\displaystyle L(X)=\sum_{j_{0}\in[n]}\sum_{i_{0}\in[d]}L(x)_{j_{0},i_{0}}.
Proof.

We can show that

L(X)\displaystyle~{}L(X)
=\displaystyle= 0.5D(X)1n×nexp(A1XA2)n×nA3n×dYd×dEn×dF2\displaystyle~{}0.5\cdot\|\underbrace{D(X)^{-1}}_{n\times n}\underbrace{\exp(A_{1}XA_{2}^{\top})}_{n\times n}\underbrace{A_{3}}_{n\times d}\underbrace{Y}_{d\times d}-\underbrace{E}_{n\times d}\|_{F}^{2}
=\displaystyle= j0=1ni0=1d0.5\displaystyle~{}\sum_{j_{0}=1}^{n}\sum_{i_{0}=1}^{d}0.5\cdot
(exp(𝖠j0x),𝟏n1exp(𝖠j0x),A3Y,i0Ej0,i0)2\displaystyle~{}(\langle\langle\exp(\operatorname{\mathsf{A}}_{j_{0}}x),{\bf 1}_{n}\rangle^{-1}\cdot\exp(\operatorname{\mathsf{A}}_{j_{0}}x),A_{3}Y_{*,i_{0}}\rangle-E_{j_{0},i_{0}})^{2}
=\displaystyle= j0=1ni0=1d0.5(f(x)j0,h(y)i0Ej0,i0)2\displaystyle~{}\sum_{j_{0}=1}^{n}\sum_{i_{0}=1}^{d}0.5(\langle f(x)_{j_{0}},h(y)_{i_{0}}\rangle-E_{j_{0},i_{0}})^{2}
=\displaystyle= j0=1ni0=1dL(x)j0,i0\displaystyle~{}\sum_{j_{0}=1}^{n}\sum_{i_{0}=1}^{d}L(x)_{j_{0},i_{0}}

where the first step follows from definition, the second step follows from writing down the summation, the third step follows from definition of f(x)j0f(x)_{j_{0}} (recall the Definition 3.6) and h(y)i0h(y)_{i_{0}} (recall the Definition 3.7), and the last step follows from L(x)j0,i0L(x)_{j_{0},i_{0}} (see Definition 3.9). ∎

4 Proof Sketch for General Upper Bound

The most straightforward way to compute the gradient would take O(n2d2)O(n^{2}d^{2}) time in order to explicitly write down the matrix 𝖠\operatorname{\mathsf{A}}. By using fast matrix multiplication and regrouping the entries, we can obtain our first intermediate algorithm, which computes the gradient in quadratic time.

Lemma 4.1 (Attention gradient computation, informal version of Lemma C.8).

If the following conditions hold

  • Define four n×dn\times d size matrices E,A1,A2,A3E,A_{1},A_{2},A_{3} to be fixed input matrices, and let X,YX,Y be two d×dd\times d square matrices.

  • Let Xd×dX\in\mathbb{R}^{d\times d} and Yd×dY\in\mathbb{R}^{d\times d} denote matrix variables (we will compute gradient with respect to XX )

    • For ease of writing, we also use vector variables xd2×1x\in\mathbb{R}^{d^{2}\times 1} and yd2×1y\in\mathbb{R}^{d^{2}\times 1}

  • Let g=dL(X)dxd2g=\frac{\mathrm{d}L(X)}{\mathrm{d}x}\in\mathbb{R}^{d^{2}} (we abuse notation: L(x)L(x) and L(X)L(X) denote the same thing)

Then we can show that gradient gd2g\in\mathbb{R}^{d^{2}} can be calculated in O(𝒯mat(n,d,n)+𝒯mat(n,d,d))O({\cal T}_{\mathrm{mat}}(n,d,n)+{\cal T}_{\mathrm{mat}}(n,d,d)) time.

Next, we will show how to improve the running time of computing gradient from quadratic time (n2\geq n^{2}) to almost linear time n1+o(1)n^{1+o(1)}.

Note that by linearity of derivative, we can show that

dL(x)dx=j0=1ni0=1ddL(x)j0,i0dx\displaystyle\frac{\mathrm{d}L(x)}{\mathrm{d}x}=\sum_{j_{0}=1}^{n}\sum_{i_{0}=1}^{d}\frac{\mathrm{d}L(x)_{j_{0},i_{0}}}{\mathrm{d}x}

Based on calculations we perform in Section B, Section C, and several linear algebra facts, we can show that

dL(x)j0,i0dx\displaystyle~{}\frac{\mathrm{d}L(x)_{j_{0},i_{0}}}{\mathrm{d}x}
=\displaystyle= c(x)j0,i0scalar𝖠j0d2×n(diag(f(x)j0)f(x)j0f(x)j0)n×nh(y)i0n×1\displaystyle~{}\underbrace{c(x)_{j_{0},i_{0}}}_{\mathrm{scalar}}\cdot\underbrace{\operatorname{\mathsf{A}}_{j_{0}}^{\top}}_{d^{2}\times n}\underbrace{(\mathrm{diag}(f(x)_{j_{0}})-f(x)_{j_{0}}f(x)_{j_{0}}^{\top})}_{n\times n}\underbrace{h(y)_{i_{0}}}_{n\times 1}

For any fixed j0[n]j_{0}\in[n], consider this quantity. Since this expression involves an n×nn\times n matrix, the most straightforward way to calculate it would take Θ(n2)\Theta(n^{2}) time, and so summing over all j0[n]j_{0}\in[n] would lead to a cubic-time algorithm. It is not too difficult to improve this: the n×nn\times n matrix

(diag(f(x)j0)adiagonalmatrixf(x)j0f(x)j0arank1matrix)\displaystyle(\underbrace{\mathrm{diag}(f(x)_{j_{0}})}_{\mathrm{a~{}diagonal~{}matrix}}-\underbrace{f(x)_{j_{0}}f(x)_{j_{0}}^{\top}}_{\mathrm{a~{}rank~{}1~{}matrix}})

is easily decomposed into a low-rank part (f(x)j0f(x)j0f(x)_{j_{0}}f(x)_{j_{0}}^{\top} which has size n×nn\times n) and a sparse part (diag(f(x)j0)\mathrm{diag}(f(x)_{j_{0}}) which also has size n×nn\times n), which reduces the calculation of each part to only O~(n)\widetilde{O}(n) time, and the total running time to O~(n2)\widetilde{O}(n^{2}) time.
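Concretely, a minimal sketch (ours; the function name is illustrative) of this observation: applying (diag(f) − f f^⊤) to a vector q costs only O(n) and never forms the n×n matrix.

import numpy as np

def apply_diag_minus_rank_one(f, q):
    # (diag(f) - f f^T) q  =  f ∘ q  -  f <f, q>, computed in O(n) time.
    return f * q - f * np.dot(f, q)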

However, we are aiming for an almost-linear time algorithm, and it is not possible to achieve this by treating the different j0j_{0} separately, since a given j0j_{0} must take Ω(n)\Omega(n) time to process. Instead, we use tensor techniques related to low-rank approximations to simultaneously compute all j0j_{0} together and sum them in almost-linear time.

In order to do that, we create several extra artificial matrices q(x)n×nq(x)\in\mathbb{R}^{n\times n} (see Section C) and p(x)n×np(x)\in\mathbb{R}^{n\times n} (see Section C). We will show that the gradient can be constructed using a simple chaining technique (see Section D for more details), passing from ff, cc, and qq to p1p_{1} (which handles the diag(f(x)j0)\mathrm{diag}(f(x)_{j_{0}}) terms), p2p_{2} (which handles the f(x)j0f(x)j0f(x)_{j_{0}}f(x)_{j_{0}}^{\top} terms), and pp (where p=p1p2p=p_{1}-p_{2}), and finally to dLdx\frac{\mathrm{d}L}{\mathrm{d}x}. Intuitively, the chaining shows that a low-rank representation for ff yields one for cc, these in turn yield one for qq, and so on.

In particular, using q(x)q(x), we obtain that dL(x)dx\frac{\mathrm{d}L(x)}{\mathrm{d}x} can be written as

j0=1n𝖠j0(adiagonalmatrixdiag(f(x)j0)arank1matrixf(x)j0f(x)j0)acolumnvectorq(x)j0\displaystyle\sum_{j_{0}=1}^{n}\operatorname{\mathsf{A}}_{j_{0}}^{\top}(\underbrace{\mathrm{~{}a~{}diagonal~{}matrix}}_{\mathrm{diag}(f(x)_{j_{0}})}-\underbrace{\mathrm{~{}a~{}rank~{}1~{}matrix}}_{f(x)_{j_{0}}f(x)_{j_{0}}^{\top}})\underbrace{\mathrm{a~{}column~{}vector}}_{q(x)_{j_{0}}}

which notably removes the summation over i0i_{0} from 1 to dd. Using the notation of p(x)p(x), we find that we ultimately need to compute A1p(x)A2A_{1}^{\top}p(x)A_{2}. Thus, as long as p(x)p(x) has a low-rank representation, we can solve the problem in n1+o(1)n^{1+o(1)} time (see Section D for more details). In particular, we will find that p(x)p(x) is the entry-wise product of two matrices with low-rank representations from prior work, which we can combine using a column-wise Kronecker product to approximate p(x)p(x) itself.
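To illustrate the final step (an illustrative sketch of ours, not the paper's full algorithm; the function name is hypothetical): if p(x) is given implicitly by a low-rank factorization p(x) ≈ U V^⊤ with U, V ∈ ℝ^{n×k}, then A1^⊤ p(x) A2 can be assembled from three skinny matrix products without ever materializing p(x).

import numpy as np

def gradient_from_low_rank(A1, A2, U, V):
    # A1^T (U V^T) A2 = (A1^T U) (V^T A2); cost O(ndk + d^2 k) instead of O(n^2 d).
    return (A1.T @ U) @ (V.T @ A2)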

5 General Lower Bound

We will critically make use of the known hardness result for attention computation itself, which we state now.

Definition 5.1 (Attention Computation).

Given as input matrices Q,K,Vn×dQ,K,V\in\mathbb{R}^{n\times d} and a parameter ε>0\varepsilon>0, compute a matrix Tn×dT\in\mathbb{R}^{n\times d} satisfying

TD1AVε,\|T-D^{-1}AV\|_{\infty}\leq\varepsilon,

where A=exp(QK)A=\exp(QK^{\top}) and D=diag(A𝟏n)D=\mathrm{diag}(A{\bf 1}_{n}).

Lemma 5.2 (Lemma 4.7 in [4]).

Assuming 𝖲𝖤𝖳𝖧\operatorname{\mathsf{SETH}}, there is no algorithm running in time O(n2δ)O(n^{2-\delta}) for any constant δ>0\delta>0 that solves Attention Computation (Definition 5.1), even when the inputs satisfy the following constraints, for any parameter κ0\kappa\geq 0:

  • d=O(logn)d=O(\log n),

  • V{0,1}n×dV\in\{0,1\}^{n\times d},

  • There is a value BO(log2n(1+κ))B\leq O(\log^{2}n\cdot(1+\kappa)) such that every entry of QKQK^{\top} is in the interval [0,B][0,B] and at least half the entries in each row of QKQK^{\top} are equal to BB,

  • moreover Q,KO(logn(1+κ))\|Q\|_{\infty},\|K\|_{\infty}\leq O(\sqrt{\log n(1+\kappa)}), and

  • ε<nκO(1)\varepsilon<n^{\kappa-O(1)}.

Next, we show that the attention optimization problem behaves particularly well when given matrices constrained as in Lemma 5.2:

Lemma 5.3.

Let AA be a fixed n×nn\times n matrix whose entries are real numbers in the interval [0,B][0,B], and such that in each row of AA, at least half the entries are equal to BB. Let VV be any n×dn\times d matrix whose entries are all in {0,1}\{0,1\}. For λ\lambda\in\mathbb{R}, define the n×nn\times n matrix Mλ:=exp(λA)M_{\lambda}:=\exp(\lambda A), where exp\exp is applied entry-wise. Define the function f:f:\mathbb{R}\to\mathbb{R} by

f(λ):=diag(Mλ𝟏n)1MλVF2,\displaystyle f(\lambda):=\|\mathrm{diag}(M_{\lambda}{\bf 1}_{n})^{-1}M_{\lambda}V\|_{F}^{2},

Then, for all λ\lambda\in\mathbb{R} we have

  • |f(λ)|O(Bn)|f^{\prime}(\lambda)|\leq O(Bn),

  • |f′′(λ)|O(B2n)|f^{\prime\prime}(\lambda)|\leq O(B^{2}n).

Proof.

Let CC denote the n×nn\times n matrix C=diag(Mλ𝟏n)1MλC=\mathrm{diag}(M_{\lambda}{\bf 1}_{n})^{-1}M_{\lambda}. For i,j[n]i,j\in[n], we calculate that Mλ[i,j]=eλA[i,j]M_{\lambda}[i,j]=e^{\lambda A[i,j]} and so

C[i,j]=eλA[i,j]k=1neλA[i,k].\displaystyle C[i,j]=\frac{e^{\lambda A[i,j]}}{\sum_{k=1}^{n}e^{\lambda A[i,k]}}.

For [d]\ell\in[d], let S[n]S_{\ell}\subseteq[n] be the set of 11s in column \ell of VV, i.e., S={j[n]V[j,]=1}S_{\ell}=\{j\in[n]\mid V[j,\ell]=1\}. Hence, for i[n]i\in[n] and [d]\ell\in[d], the entry (i,)(i,\ell) of the matrix diag(Mλ𝟏n)1MλV\mathrm{diag}(M_{\lambda}{\bf 1}_{n})^{-1}M_{\lambda}V is given by

diag(Mλ𝟏n)1MλV[i,]\displaystyle\mathrm{diag}(M_{\lambda}{\bf 1}_{n})^{-1}M_{\lambda}V[i,\ell] =CV[i,]\displaystyle=CV[i,\ell]
=j=1nC[i,j]V[j,]\displaystyle=\sum_{j=1}^{n}C[i,j]V[j,\ell]
=jSC[i,j]\displaystyle=\sum_{j\in S_{\ell}}C[i,j]
=jSeλA[i,j]k=1neλA[i,k].\displaystyle=\frac{\sum_{j\in S_{\ell}}e^{\lambda A[i,j]}}{\sum_{k=1}^{n}e^{\lambda A[i,k]}}.

where the first step follows from the definition of CC, the second step follows from the definition of matrix multiplication, the third step follows from the definition of SS_{\ell}, and the last step follows from the expression for C[i,j]C[i,j].

We thus get an explicit expression for f(λ)f(\lambda):

f(λ)\displaystyle f(\lambda) =i=1n=1d(jSeλA[i,j])2(k=1neλA[i,k])2\displaystyle=\sum_{i=1}^{n}\frac{\sum_{\ell=1}^{d}\left(\sum_{j\in S_{\ell}}e^{\lambda A[i,j]}\right)^{2}}{\left(\sum_{k=1}^{n}e^{\lambda A[i,k]}\right)^{2}}
=i=1n=1dj1Snj2Sneλ(A[i,j1]+A[i,j2])k1=1nk2=1neλ(A[i,k1]+A[i,k2]).\displaystyle=\sum_{i=1}^{n}\frac{\sum_{\ell=1}^{d}\sum_{j_{1}\in S_{\ell}}^{n}\sum_{j_{2}\in S_{\ell}}^{n}e^{\lambda(A[i,j_{1}]+A[i,j_{2}])}}{\sum_{k_{1}=1}^{n}\sum_{k_{2}=1}^{n}e^{\lambda(A[i,k_{1}]+A[i,k_{2}])}}.

We define

a(λ,i):==1dj1Snj2Sneλ(A[i,j1]+A[i,j2])\displaystyle a(\lambda,i):=\sum_{\ell=1}^{d}\sum_{j_{1}\in S_{\ell}}^{n}\sum_{j_{2}\in S_{\ell}}^{n}e^{\lambda(A[i,j_{1}]+A[i,j_{2}])}

and then we define

b(λ,i):=k1=1nk2=1neλ(A[i,k1]+A[i,k2])\displaystyle b(\lambda,i):=\sum_{k_{1}=1}^{n}\sum_{k_{2}=1}^{n}e^{\lambda(A[i,k_{1}]+A[i,k_{2}])}

Combining the above three equations, we can obtain

f(λ)=i=1na(λ,i)/b(λ,i).\displaystyle f(\lambda)=\sum_{i=1}^{n}a(\lambda,i)/b(\lambda,i).

Since, for each row of AA, at least half the entries equal BB, and all the entries are in the interval [0,B][0,B], we can bound

(n2)2e2Bλb(λ,i)(n)2e2Bλ.\displaystyle\left(\frac{n}{2}\right)^{2}\cdot e^{2B\lambda}\leq b(\lambda,i)\leq\left(n\right)^{2}\cdot e^{2B\lambda}. (1)

Furthermore, since the derivative of eλ(A[i,k1]+A[i,k2])e^{\lambda(A[i,k_{1}]+A[i,k_{2}])} with respect to λ\lambda is (A[i,k1]+A[i,k2])eλ(A[i,k1]+A[i,k2])(A[i,k_{1}]+A[i,k_{2}])\cdot e^{\lambda(A[i,k_{1}]+A[i,k_{2}])}, we can bound

2b(λ,i)db(λ,i)dλ2Bb(λ,i).\displaystyle 2\cdot b(\lambda,i)\leq\frac{\mathrm{d}b(\lambda,i)}{\mathrm{d}\lambda}\leq 2B\cdot b(\lambda,i). (2)

We may similarly bound

0a(λ,i)n2e2Bλ,\displaystyle 0\leq a(\lambda,i)\leq n^{2}\cdot e^{2B\lambda}, (3)

and

2a(λ,i)da(λ,i)dλ2Ba(λ,i).\displaystyle 2\cdot a(\lambda,i)\leq\frac{\mathrm{d}a(\lambda,i)}{\mathrm{d}\lambda}\leq 2B\cdot a(\lambda,i). (4)

We can thus bound the derivative of ff (where here, all the notation means derivative with respect to λ\lambda):

f(λ)\displaystyle f^{\prime}(\lambda) =i=1na(λ,i)b(λ,i)a(λ,i)b(λ,i)(b(λ,i))2\displaystyle=\sum_{i=1}^{n}\frac{a^{\prime}(\lambda,i)\cdot b(\lambda,i)-a(\lambda,i)\cdot b^{\prime}(\lambda,i)}{(b(\lambda,i))^{2}}
i=1na(λ,i)b(λ,i)(b(λ,i))2\displaystyle\leq\sum_{i=1}^{n}\frac{a^{\prime}(\lambda,i)\cdot b(\lambda,i)}{(b(\lambda,i))^{2}}
=i=1na(λ,i)b(λ,i)\displaystyle=\sum_{i=1}^{n}\frac{a^{\prime}(\lambda,i)}{b(\lambda,i)}
i=1n2Bn2e2Bλ(n/2)2e2Bλ\displaystyle\leq\sum_{i=1}^{n}\frac{2B\cdot n^{2}e^{2B\lambda}}{(n/2)^{2}\cdot e^{2B\lambda}}
=i=1n8B\displaystyle=\sum_{i=1}^{n}8B
=8Bn.\displaystyle=8B\cdot n.

where the 1st step follows from definition, the 2nd step follows from simple algebra, the 3rd step follows from cancelling b(λ,i)b(\lambda,i), the 4th step is using Eq. (1) (for b(λ,i)b(\lambda,i)) and Eq. (4) (for a(λ,i)a^{\prime}(\lambda,i)), the 5th step follows from simple algebra, and the last step follows from simple algebra.

Similarly, we can provide a lower bound on f(λ)f^{\prime}(\lambda):

f(λ)\displaystyle f^{\prime}(\lambda) =i=1na(λ,i)b(λ,i)a(λ,i)b(λ,i)(b(λ,i))2\displaystyle=\sum_{i=1}^{n}\frac{a^{\prime}(\lambda,i)\cdot b(\lambda,i)-a(\lambda,i)\cdot b^{\prime}(\lambda,i)}{(b(\lambda,i))^{2}}
i=1na(λ,i)b(λ,i)(b(λ,i))2\displaystyle\geq-\sum_{i=1}^{n}\frac{a(\lambda,i)\cdot b^{\prime}(\lambda,i)}{(b(\lambda,i))^{2}}
i=1n(n2e2Bλ)(2Bb(λ,i))((n/2)2e2Bλ)(b(λ,i))\displaystyle\geq-\sum_{i=1}^{n}\frac{(n^{2}\cdot e^{2B\lambda})\cdot(2B\cdot b(\lambda,i))}{((n/2)^{2}\cdot e^{2B\lambda})\cdot(b(\lambda,i))}
=i=1n8B\displaystyle=-\sum_{i=1}^{n}8B
=8Bn.\displaystyle=-8B\cdot n.

where the 1st step follows from definition, the 2nd step follows from simple algebra, the 3rd step follows from Eq. (2) (for b(λ,i)b^{\prime}(\lambda,i)) and Eq. (3) (for a(λ,i)a(\lambda,i)), the 4th step follows from simple algebra, and the last step follows from simple algebra.

Finally, letting f(λ,i):=a(λ,i)/b(λ,i)f(\lambda,i):=a(\lambda,i)/b(\lambda,i), we have again by the quotient rule that f′′(λ)f^{\prime\prime}(\lambda) is equal to

i=1na′′(λ,i)b′′(λ,i)f(λ,i)2b(λ,i)f(λ,i)b(λ,i)\sum_{i=1}^{n}\frac{a^{\prime\prime}(\lambda,i)-b^{\prime\prime}(\lambda,i)\cdot f(\lambda,i)-2\cdot b^{\prime}(\lambda,i)\cdot f^{\prime}(\lambda,i)}{b(\lambda,i)}

which we similarly bound in magnitude by O(B2n)O(B^{2}n). ∎

We recall a simple approximation from calculus:

Lemma 5.4.

Let f:[0,1]f:[0,1]\to\mathbb{R} be a twice-differentiable function such that |f′′(λ)|b|f^{\prime\prime}(\lambda)|\leq b for all λ[0,1]\lambda\in[0,1]. For any positive integer mm, define the sum

tm:=i=0m1f(i/m)m.t_{m}:=\sum_{i=0}^{m-1}\frac{f^{\prime}(i/m)}{m}.

Then,

|tm(f(1)f(0))|b/m.|t_{m}-(f(1)-f(0))|\leq b/m.
Proof.

If two λ0,λ1[0,1]\lambda_{0},\lambda_{1}\in[0,1] have |λ0λ1|1/m|\lambda_{0}-\lambda_{1}|\leq 1/m, then from our bound on f′′(λ)f^{\prime\prime}(\lambda), we know that |f(λ1)f(λ0)|b/m|f^{\prime}(\lambda_{1})-f^{\prime}(\lambda_{0})|\leq b/m. We can thus bound the difference

f(1)f(0)=01f(λ)𝑑λ\displaystyle f(1)-f(0)=\int_{0}^{1}f^{\prime}(\lambda)d\lambda

by

f(1)f(0)i=0m1f(i/m)+(b/m)m=tm+b/m\displaystyle f(1)-f(0)\leq\sum_{i=0}^{m-1}\frac{f^{\prime}(i/m)+(b/m)}{m}=t_{m}+b/m

and

f(1)f(0)i=0m1f(i/m)(b/m)m=tmb/m.\displaystyle f(1)-f(0)\geq\sum_{i=0}^{m-1}\frac{f^{\prime}(i/m)-(b/m)}{m}=t_{m}-b/m.

Thus, we complete the proof. ∎
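As a quick numerical illustration of the lemma (ours, with an arbitrary test function): for f(x) = sin(3x) we have |f''| ≤ b = 9 on [0,1], and the left Riemann sum of f' recovers f(1) − f(0) up to b/m.

import numpy as np

f = lambda x: np.sin(3 * x)
fprime = lambda x: 3 * np.cos(3 * x)
m, b = 1000, 9.0
t_m = sum(fprime(i / m) / m for i in range(m))
assert abs(t_m - (f(1) - f(0))) <= b / m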

Finally, we are ready for our main result:

Theorem 5.5 (Formal version of Theorem 1.5).

Let κ:\kappa:\mathbb{N}\to\mathbb{N} be any function with κ(n)=ω(1)\kappa(n)=\omega(1) and κ(n)=o(logn)\kappa(n)=o(\log n). Assuming 𝖲𝖤𝖳𝖧\operatorname{\mathsf{SETH}}, there is no algorithm running in time O(n2δ)O(n^{2-\delta}) for any constant δ>0\delta>0 for Approximate Attention Loss Gradient Computation (Definition 1.4), even in the case where d=O(logn)d=O(\log n) and the input matrices satisfy A1,A2,A3O(lognκ(n))\|A_{1}\|_{\infty},\|A_{2}\|_{\infty},\|A_{3}\|_{\infty}\leq O(\sqrt{\log n}\cdot\kappa(n)), B=0B=0, Y=IY=I, X=λIX=\lambda I for some scalar λ[0,1]\lambda\in[0,1], and ε=O(1/(logn)4)\varepsilon=O(1/(\log n)^{4}).

Proof.

Suppose there were such an algorithm; we will call it O((logn)4)O((\log n)^{4}) times to refute Lemma 5.2 (with parameter κ=κ(n)\kappa=\kappa(n)). Let Q,K,VQ,K,V be the input matrices to Lemma 5.2, and set A1=QA_{1}=Q, A2=KA_{2}=K, A3=VA_{3}=V, Y=IY=I, and X=λIX=\lambda I for a parameter λ[0,1]\lambda\in[0,1]. Let f:[0,1]f:[0,1]\to\mathbb{R} be the function from Lemma 5.3 where AA is the matrix A1A2A_{1}A_{2}^{\top}, so that MλM_{\lambda} is the matrix exp(A1XA2)\exp(A_{1}XA_{2}^{\top}). It follows from Lemma 5.3 that

|f′′(λ)|O(nlog2n(κ(n))2).\displaystyle|f^{\prime\prime}(\lambda)|\leq O(n\log^{2}n\cdot(\kappa(n))^{2}).

We can compute f(0)f(0) in O~(n)\widetilde{O}(n) time since M0M_{0} is the all-1s matrix, and our goal is to output f(1)f(1).

Thus, by Lemma 5.4, it suffices to compute f(λ)f^{\prime}(\lambda) on O(log2(n)(κ(n))2)=O(log4n)O(\log^{2}(n)(\kappa(n))^{2})=O(\log^{4}n) points up to O(1/(logn)4)O(1/(\log n)^{4}) error, and return their average. But, since we have picked X=λIX=\lambda I, we can calculate f(λ)f^{\prime}(\lambda) from the gradient dL(X)dX\frac{\mathrm{d}L(X)}{\mathrm{d}X} (from Definition 1.4), which is approximated by our assumed algorithm. ∎
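To spell out this last step slightly more explicitly (an informal remark of ours; we suppress the 1/d1/d rescaling from Definition 1.2): since X=λIX=\lambda I, the chain rule gives

\displaystyle\frac{\mathrm{d}}{\mathrm{d}\lambda}L(\lambda I)=\Big\langle\frac{\mathrm{d}L(X)}{\mathrm{d}X}\Big|_{X=\lambda I},\;I\Big\rangle=\sum_{i=1}^{d}\Big[\frac{\mathrm{d}L(X)}{\mathrm{d}X}\Big]_{i,i},

so an entry-wise ε\varepsilon-approximation g~\widetilde{g} of the gradient approximates this derivative to within dεd\varepsilon, which is how the approximate values of f(λ)f^{\prime}(\lambda) are obtained above.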

6 Conclusion

Our results give a complete fine-grained analysis of the running time needed to train LLMs. We show that there is a threshold depending on the parameter BB, the magnitude of the parameter matrix entries. In settings where BB is small, a near-linear-time algorithm for LLM training is possible by using our novel algorithm for backward computation. In settings where BB is large, not only does our algorithm not apply, but we show it is impossible to design a nontrivially-fast algorithm (barring a breakthrough in satisfiability algorithms that would refute the popular 𝖲𝖤𝖳𝖧\operatorname{\mathsf{SETH}}).

These insights can guide LLM designers to more efficient algorithms. When BB can be made small, this would lead to substantial savings in the computational resources needed for training and expression. When BB must be large (perhaps to achieve high expressiveness), our lower bounds show that one may as well use straightforward algorithms and focus on other aspects of algorithmic speedup such as parallelization. The magnitude of BB needed has been studied more recently (e.g., [5]), and the need for fast training algorithms may further motivate this direction of research.

Appendix

Roadmap.

In Section A, we provide basic notation and facts. In Section B, we provide details about gradient computations. In Section C, we explain the computation time for the gradient of attention loss. In Section D, we show how to further improve the gradient computation from quadratic time to almost linear time.

Appendix A Preliminaries

In Section A.1, we define some basic notation. In Section A.2, we state several facts which we will use.

A.1 Notation

For any positive integer nn, we define [n]:={1,2,,n}[n]:=\{1,2,\dots,n\}.

For two vectors xx and yy of the same length, we use x,y\langle x,y\rangle to denote the inner product between xx and yy, i.e., x,y=i=1nxiyi\langle x,y\rangle=\sum_{i=1}^{n}x_{i}y_{i}. We use xyx\circ y to denote the vector whose ii-th entry is xiyix_{i}y_{i}. Let 𝟏n{\bf 1}_{n} denote the length-nn all ones vector. It is not hard to see that xy,𝟏n=x,y\langle x\circ y,{\bf 1}_{n}\rangle=\langle x,y\rangle.

For a vector uu, we use uu^{\top} to denote the transpose of uu. For a matrix MM, we use MM^{\top} to denote the transpose of matrix MM.

For a vector uu, we use exp(u)\exp(u) to denote the vector that ii-th coordinate is exp(ui)\exp(u_{i}). For a matrix AA, we use exp(A)\exp(A) to denote the matrix that (i,j)(i,j)-th coordinate is exp(Ai,j)\exp(A_{i,j}).

We define the Kronecker product between matrices XX and YY, denoted XYn0n1×m0m1X\otimes Y\in\mathbb{R}^{n_{0}n_{1}\times m_{0}m_{1}}, as (XY)(j01)n1+j1,(i01)m1+i1(X\otimes Y)_{(j_{0}-1)n_{1}+j_{1},(i_{0}-1)m_{1}+i_{1}} is equal to Xj0,i0Yj1,i1X_{j_{0},i_{0}}Y_{j_{1},i_{1}}, where j0[n0],i0[m0],j1[n1],i1[m1]j_{0}\in[n_{0}],i_{0}\in[m_{0}],j_{1}\in[n_{1}],i_{1}\in[m_{1}].

For each positive integers m1,m2,m3m_{1},m_{2},m_{3}, we use 𝒯mat(m1,m2,m3){\cal T}_{\mathrm{mat}}(m_{1},m_{2},m_{3}) to denote the time of multiplying m1×m2m_{1}\times m_{2} matrix with another m2×m3m_{2}\times m_{3} matrix.

A.2 Basic Facts

Fact A.1.

Let x,y,znx,y,z\in\mathbb{R}^{n}. Then we have

  • xy,z=xdiag(y)z\langle x\circ y,z\rangle=x^{\top}\mathrm{diag}(y)z.

  • x,y=xy,𝟏n\langle x,y\rangle=\langle x\circ y,{\bf 1}_{n}\rangle.

Fact A.2 (Folklore).

Let U1,V1n×k1U_{1},V_{1}\in\mathbb{R}^{n\times k_{1}}. Let U2,V2n×k2U_{2},V_{2}\in\mathbb{R}^{n\times k_{2}}. Then we have

(U1V1)(U2V2)=(U1U2)(V1V2)\displaystyle(U_{1}V_{1}^{\top})\circ(U_{2}V_{2}^{\top})=(U_{1}\oslash U_{2})(V_{1}\oslash V_{2})^{\top}

Here, given U1n×k1U_{1}\in\mathbb{R}^{n\times k_{1}} and U2n×k2U_{2}\in\mathbb{R}^{n\times k_{2}}, the U1U2n×k1k2U_{1}\oslash U_{2}\in\mathbb{R}^{n\times k_{1}k_{2}} is the row-wise Kronecker product, i.e., (U1U2)i,l1+(l21)k1:=(U1)i,l1(U2)i,l2(U_{1}\oslash U_{2})_{i,l_{1}+(l_{2}-1)k_{1}}:=(U_{1})_{i,l_{1}}(U_{2})_{i,l_{2}} for all i[n]i\in[n], l1[k1]l_{1}\in[k_{1}] and l2[k2]l_{2}\in[k_{2}].
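The identity is easy to verify numerically; the sketch below (ours; row_kron is an illustrative helper) implements the row-wise Kronecker product with einsum and checks the fact on random matrices.

import numpy as np

def row_kron(A, B):
    # (A ⊘ B)[i] is the Kronecker product of row i of A with row i of B.
    return np.einsum('ij,ik->ijk', A, B).reshape(A.shape[0], -1)

rng = np.random.default_rng(2)
n, k1, k2 = 6, 2, 3
U1, V1 = rng.standard_normal((n, k1)), rng.standard_normal((n, k1))
U2, V2 = rng.standard_normal((n, k2)), rng.standard_normal((n, k2))

lhs = (U1 @ V1.T) * (U2 @ V2.T)                  # entry-wise (Hadamard) product
rhs = row_kron(U1, U2) @ row_kron(V1, V2).T      # right-hand side of Fact A.2
assert np.allclose(lhs, rhs)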

Appendix B More Details about Gradient Computation

In this section, we provide details and calculations to assist with gradient and derivative computations. We remark that, in this section, for convenience of computing a closed form for the gradient, we ignore the 1/d1/d factor in function ff. Since it is only a rescaling factor, it won’t affect how we compute these matrices in general.

Lemma B.1 (The gradient computation for several different functions with respect to xix_{i}).

For every i[d2]i\in[d^{2}], define 𝖠j0,in\operatorname{\mathsf{A}}_{j_{0},i}\in\mathbb{R}^{n} to be the ii-th column of 𝖠j0n×d2\operatorname{\mathsf{A}}_{j_{0}}\in\mathbb{R}^{n\times d^{2}}. The column function u(x)j0nu(x)_{j_{0}}\in\mathbb{R}^{n}, the scalar function α(x)j0\alpha(x)_{j_{0}}\in\mathbb{R}, the column function f(x)j0nf(x)_{j_{0}}\in\mathbb{R}^{n}, the scalar function c(x)j0,i0c(x)_{j_{0},i_{0}}\in\mathbb{R} and the scalar function L(x)j0,i0L(x)_{j_{0},i_{0}}\in\mathbb{R} are defined as in Definitions 3.4, 3.5, 3.6, 3.8 and 3.9, respectively.

Then, for each i[d2]i\in[d^{2}], we have

  • Part 1.

    dxdxi=ei\displaystyle\frac{\mathrm{d}x}{\mathrm{d}x_{i}}=e_{i}
  • Part 2. For each j0[n]j_{0}\in[n],

    d𝖠j0xdxi=(𝖠j0)i\displaystyle\frac{\mathrm{d}\operatorname{\mathsf{A}}_{j_{0}}x}{\mathrm{d}x_{i}}=(\operatorname{\mathsf{A}}_{j_{0}})_{i}
  • Part 3. For each j0[n]j_{0}\in[n]

    du(x)j0dxi=𝖠j0,iu(x)j0\displaystyle\frac{\mathrm{d}u(x)_{j_{0}}}{\mathrm{d}x_{i}}=\operatorname{\mathsf{A}}_{j_{0},i}\circ u(x)_{j_{0}}
  • Part 4. For each j0[n]j_{0}\in[n],

    dα(x)j0dxi=𝖠j0,i,u(x)j0\displaystyle\frac{\mathrm{d}\alpha(x)_{j_{0}}}{\mathrm{d}x_{i}}=\langle\operatorname{\mathsf{A}}_{j_{0},i},u(x)_{j_{0}}\rangle
  • Part 5. For each j0[n]j_{0}\in[n],

    df(x)j0dxi=𝖠j0,if(x)j0𝖠j0,i,f(x)j0f(x)j0\displaystyle\frac{\mathrm{d}f(x)_{j_{0}}}{\mathrm{d}x_{i}}=\operatorname{\mathsf{A}}_{j_{0},i}\circ f(x)_{j_{0}}-\langle\operatorname{\mathsf{A}}_{j_{0},i},f(x)_{j_{0}}\rangle\cdot f(x)_{j_{0}}
  • Part 6. For each j0[n]j_{0}\in[n], for each i0[d]i_{0}\in[d],

    df(x)j0,h(y)i0dxi=h(y)i0,𝖠j0,if(x)j0h(y)i0,f(x)j0𝖠j0,i,f(x)j0\displaystyle\frac{\mathrm{d}\langle f(x)_{j_{0}},h(y)_{i_{0}}\rangle}{\mathrm{d}x_{i}}=\langle h(y)_{i_{0}},\operatorname{\mathsf{A}}_{j_{0},i}\circ f(x)_{j_{0}}\rangle-\langle h(y)_{i_{0}},f(x)_{j_{0}}\rangle\cdot\langle\operatorname{\mathsf{A}}_{j_{0},i},f(x)_{j_{0}}\rangle
  • Part 7. For each j0[n]j_{0}\in[n], for every i0[d]i_{0}\in[d]

    dc(x)j0,i0dxi=𝖠j0,if(x)j0,h(y)i0f(x)j0,h(y)i0𝖠j0,i,f(x)j0\displaystyle\frac{\mathrm{d}c(x)_{j_{0},i_{0}}}{\mathrm{d}x_{i}}=\langle\operatorname{\mathsf{A}}_{j_{0},i}\circ f(x)_{j_{0}},h(y)_{i_{0}}\rangle-\langle f(x)_{j_{0}},h(y)_{i_{0}}\rangle\cdot\langle\operatorname{\mathsf{A}}_{j_{0},i},f(x)_{j_{0}}\rangle
  • Part 8. For each j0[n]j_{0}\in[n], for each i0[d]i_{0}\in[d]

    dL(x)j0,i0dxi=(h(y)i0,𝖠j0,if(x)j0f(x)j0,𝖠j0,ih(y)i0,f(x)j0)c(x)j0,i0\displaystyle\frac{\mathrm{d}L(x)_{j_{0},i_{0}}}{\mathrm{d}x_{i}}=(\langle h(y)_{i_{0}},\operatorname{\mathsf{A}}_{j_{0},i}\circ f(x)_{j_{0}}\rangle-\langle f(x)_{j_{0}},\operatorname{\mathsf{A}}_{j_{0},i}\rangle\cdot\langle h(y)_{i_{0}},f(x)_{j_{0}}\rangle)\cdot c(x)_{j_{0},i_{0}}
Proof.

Proof of Part 1. We have

dxdxi=ei\displaystyle\frac{\mathrm{d}x}{\mathrm{d}x_{i}}=e_{i}

Proof of Part 2. We have

d𝖠j0xdxi=\displaystyle\frac{\mathrm{d}\operatorname{\mathsf{A}}_{j_{0}}x}{\mathrm{d}x_{i}}= 𝖠j0n×d2dxdxid2×1\displaystyle~{}\underbrace{\operatorname{\mathsf{A}}_{j_{0}}}_{n\times d^{2}}\underbrace{\frac{\mathrm{d}x}{\mathrm{d}x_{i}}}_{d^{2}\times 1}
=\displaystyle= 𝖠j0n×d2eid2×1\displaystyle~{}\underbrace{\operatorname{\mathsf{A}}_{j_{0}}}_{n\times d^{2}}\cdot\underbrace{e_{i}}_{d^{2}\times 1}
=\displaystyle= 𝖠j0,i\displaystyle~{}\operatorname{\mathsf{A}}_{j_{0},i}

Proof of Part 3.

We can show

du(x)j0dxi=\displaystyle\frac{\mathrm{d}u(x)_{j_{0}}}{\mathrm{d}x_{i}}= dexp(𝖠j0x)dxi\displaystyle~{}\frac{\mathrm{d}\exp(\operatorname{\mathsf{A}}_{j_{0}}x)}{\mathrm{d}x_{i}}
=\displaystyle= exp(𝖠j0x)d𝖠j0xdxi\displaystyle~{}\exp(\operatorname{\mathsf{A}}_{j_{0}}x)\circ\frac{\mathrm{d}\operatorname{\mathsf{A}}_{j_{0}}x}{\mathrm{d}x_{i}}
=\displaystyle= exp(𝖠j0x)𝖠j0,i\displaystyle~{}\exp(\operatorname{\mathsf{A}}_{j_{0}}x)\circ\operatorname{\mathsf{A}}_{j_{0},i}
=\displaystyle= u(x)j0𝖠j0,i\displaystyle~{}u(x)_{j_{0}}\circ\operatorname{\mathsf{A}}_{j_{0},i}

where the 1st step follows from the definition of u(x)j0u(x)_{j_{0}}, the 2nd step follows from the chain rule applied to the entry-wise exponential, the 3rd step follows from Part 2, and the last step follows from the definition of u(x)j0u(x)_{j_{0}}.

Proof of Part 4.

For simplicity of writing proofs, we use ()(\cdot) to denote (x)(x).

We can show

dα()j0dxi=\displaystyle\frac{\mathrm{d}\alpha(\cdot)_{j_{0}}}{\mathrm{d}x_{i}}= du()j0,𝟏ndxi\displaystyle~{}\frac{\mathrm{d}\langle u(\cdot)_{j_{0}},{\bf 1}_{n}\rangle}{\mathrm{d}x_{i}}
=\displaystyle= u()j0𝖠j0,i,𝟏n\displaystyle~{}\langle u(\cdot)_{j_{0}}\circ\operatorname{\mathsf{A}}_{j_{0},i},{\bf 1}_{n}\rangle
=\displaystyle= u()j0,𝖠j0,i\displaystyle~{}\langle u(\cdot)_{j_{0}},\operatorname{\mathsf{A}}_{j_{0},i}\rangle

where the 1st step follows from definition of α()\alpha(\cdot), the 2nd step follows from Part 3, the 3rd step follows from Fact A.1.

Proof of Part 5. For simplicity of writing proofs, we use ()(\cdot) to denote (x)(x).

We can show that

df()j0dxi=\displaystyle\frac{\mathrm{d}f(\cdot)_{j_{0}}}{\mathrm{d}x_{i}}= dα()j01u()j0dxi\displaystyle~{}\frac{\mathrm{d}\alpha(\cdot)_{j_{0}}^{-1}u(\cdot)_{j_{0}}}{\mathrm{d}x_{i}}
=\displaystyle= α()j01du()j0dxi+(dα()j01dxi)u()j0\displaystyle~{}\alpha(\cdot)_{j_{0}}^{-1}\frac{\mathrm{d}u(\cdot)_{j_{0}}}{\mathrm{d}x_{i}}+(\frac{\mathrm{d}\alpha(\cdot)_{j_{0}}^{-1}}{\mathrm{d}x_{i}})u(\cdot)_{j_{0}}

For the first term, we have

α()j01du()j0dxi=\displaystyle\alpha(\cdot)_{j_{0}}^{-1}\frac{\mathrm{d}u(\cdot)_{j_{0}}}{\mathrm{d}x_{i}}= α()j01u()j0𝖠j0,i\displaystyle~{}\alpha(\cdot)_{j_{0}}^{-1}u(\cdot)_{j_{0}}\circ\operatorname{\mathsf{A}}_{j_{0},i}
=\displaystyle= f()j0𝖠j0,i\displaystyle~{}f(\cdot)_{j_{0}}\circ\operatorname{\mathsf{A}}_{j_{0},i}

where the 1st step follows from Part 3, the 2nd step follows from definition of f()f(\cdot).

For the second term, we have

(dα()j01dxi)u()j0=\displaystyle(\frac{\mathrm{d}\alpha(\cdot)_{j_{0}}^{-1}}{\mathrm{d}x_{i}})u(\cdot)_{j_{0}}= α()j02dα()j0dxiu()j0\displaystyle~{}-\alpha(\cdot)_{j_{0}}^{-2}\frac{\mathrm{d}\alpha(\cdot)_{j_{0}}}{\mathrm{d}x_{i}}u(\cdot)_{j_{0}}
=\displaystyle= α()j02u()j0,𝖠j0,iu()j0\displaystyle~{}-\alpha(\cdot)_{j_{0}}^{-2}\cdot\langle u(\cdot)_{j_{0}},\operatorname{\mathsf{A}}_{j_{0},i}\rangle\cdot u(\cdot)_{j_{0}}
=\displaystyle= f()j0f()j0,𝖠j0,i\displaystyle~{}-f(\cdot)_{j_{0}}\cdot\langle f(\cdot)_{j_{0}},\operatorname{\mathsf{A}}_{j_{0},i}\rangle

where the 1st step follows from basic calculus, the 2nd step follows from Part 4, the 3rd step follows from definition of f()j0f(\cdot)_{j_{0}}.

Using all of the results above, it holds that

df()j0dxi=\displaystyle\frac{\mathrm{d}f(\cdot)_{j_{0}}}{\mathrm{d}x_{i}}= f()j0𝖠j0,if()j0f()j0,𝖠j0,i\displaystyle~{}f(\cdot)_{j_{0}}\circ\operatorname{\mathsf{A}}_{j_{0},i}-f(\cdot)_{j_{0}}\cdot\langle f(\cdot)_{j_{0}},\operatorname{\mathsf{A}}_{j_{0},i}\rangle

Proof of Part 6. It follows from Part 5 directly.

Proof of Part 7. For simplicity of writing proofs, we use ()(\cdot) to denote (x)(x).

Following the definition of cc in Definition 3.8, it holds that

c()j0,i0:=f()j0,vEj0,i0\displaystyle c(\cdot)_{j_{0},i_{0}}:=\langle f(\cdot)_{j_{0}},v\rangle-E_{j_{0},i_{0}} (5)

Thus it holds that

dc()j0,i0dxi=\displaystyle\frac{\mathrm{d}c(\cdot)_{j_{0},i_{0}}}{\mathrm{d}x_{i}}= d(f()j0,h(y)i0Ej0,i0)dxi\displaystyle~{}\frac{\mathrm{d}(\langle f(\cdot)_{j_{0}},h(y)_{i_{0}}\rangle-E_{j_{0},i_{0}})}{\mathrm{d}x_{i}}
=\displaystyle= df()j0,h(y)i0dxi\displaystyle~{}\frac{\mathrm{d}\langle f(\cdot)_{j_{0}},h(y)_{i_{0}}\rangle}{\mathrm{d}x_{i}}
=\displaystyle= f()j0𝖠j0,i,h(y)i0f()j0,h(y)i0f()j0,𝖠j0,i,\displaystyle~{}\langle f(\cdot)_{j_{0}}\circ\operatorname{\mathsf{A}}_{j_{0},i},h(y)_{i_{0}}\rangle-\langle f(\cdot)_{j_{0}},h(y)_{i_{0}}\rangle\cdot\langle f(\cdot)_{j_{0}},\operatorname{\mathsf{A}}_{j_{0},i}\rangle,

where the 1st step follows from Eq. (5), the 2nd step follows from dEj0,i0dxi=0\frac{\mathrm{d}E_{j_{0},i_{0}}}{\mathrm{d}x_{i}}=0, and the 3rd step follows from Part 6.

Proof of Part 8. For simplicity of writing proofs, we use ()(\cdot) to denote (x)(x). Following the definition of L()L(\cdot) in Definition 3.9, it holds that

L()j0,i0=0.5c()j0,i02\displaystyle L(\cdot)_{j_{0},i_{0}}=0.5c(\cdot)_{j_{0},i_{0}}^{2} (6)

Thus, we have

dL()j0,i0dxi=\displaystyle\frac{\mathrm{d}L(\cdot)_{j_{0},i_{0}}}{\mathrm{d}x_{i}}= d(0.5c()j0,i02)dxi\displaystyle~{}\frac{\mathrm{d}(0.5c(\cdot)_{j_{0},i_{0}}^{2})}{\mathrm{d}x_{i}}
=\displaystyle= c()j0,i0dc()dxi\displaystyle~{}c(\cdot)_{j_{0},i_{0}}\frac{\mathrm{d}c(\cdot)}{\mathrm{d}x_{i}}
=\displaystyle= c()j0,i0(f()j0𝖠j0,i,h(y)i0f()j0,h(y)i0f()j0,𝖠j0,i),\displaystyle~{}c(\cdot)_{j_{0},i_{0}}\cdot(\langle f(\cdot)_{j_{0}}\circ\operatorname{\mathsf{A}}_{j_{0},i},h(y)_{i_{0}}\rangle-\langle f(\cdot)_{j_{0}},h(y)_{i_{0}}\rangle\cdot\langle f(\cdot)_{j_{0}},\operatorname{\mathsf{A}}_{j_{0},i}\rangle),

where the 1st step follows from Eq. (6), the 2nd step is due to the chain rule, and the last step follows from Part 7. ∎
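As a sanity check of Part 5 (a sketch of ours; the setup and names are illustrative), one can compare the closed form against a finite-difference approximation:

import numpy as np

rng = np.random.default_rng(3)
n, d = 6, 2
A1, A2 = rng.standard_normal((n, d)), rng.standard_normal((n, d))
X = rng.standard_normal((d, d))
x = X.reshape(-1)
A = np.kron(A1, A2)                 # n^2 x d^2
j0, i = 2, 1                        # an arbitrary row block and coordinate
Aj0 = A[j0 * n:(j0 + 1) * n, :]     # the j0-th n x d^2 sub-block of A

def f_j0(xvec):
    u = np.exp(Aj0 @ xvec)          # u(x)_{j0}
    return u / u.sum()              # f(x)_{j0}

f = f_j0(x)
closed = Aj0[:, i] * f - np.dot(Aj0[:, i], f) * f     # Part 5 of Lemma B.1

eps = 1e-6
xp, xm = x.copy(), x.copy()
xp[i] += eps
xm[i] -= eps
numeric = (f_j0(xp) - f_j0(xm)) / (2 * eps)
assert np.allclose(closed, numeric, atol=1e-5)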

Appendix C Time for Computation

In Section C.1, we show the calculation of ff (as in Section B, we still ignore the 1/d1/d factor here) and hh. In Section C.2, we show how to calculate cc in a straightforward way. In Section C.3 and Section C.4, we define two artificial functions qq and pp, and show how to compute them. In Section C.5, we show how to rewrite the gradient in an elegant way. In Section C.6, we finally put these all together and find the running time of our algorithm.

C.1 Compute ff and hh

Lemma C.1 (Computing ff and hh).

Suppose the following objects are given

  • Let f(x)f(x) be defined as Definition 3.6

  • Let h(y)h(y) be defined as Definition 3.7

Then, we have

  • f(x)f(x) can be calculated in time of 𝒯mat(n,d,n)+𝒯mat(n,d,d){\cal T}_{\mathrm{mat}}(n,d,n)+{\cal T}_{\mathrm{mat}}(n,d,d)

  • h(y)h(y) can be calculated in time of 𝒯mat(n,d,d){\cal T}_{\mathrm{mat}}(n,d,d)

Proof.

Note that

f(x)=D1exp(A1XA2)\displaystyle f(x)=D^{-1}\exp(A_{1}XA_{2}^{\top})

and

D=diag(exp(A1XA2)𝟏n)\displaystyle D=\mathrm{diag}(\exp(A_{1}XA_{2}^{\top}){\bf 1}_{n})

We first compute exp(A1XA2)\exp(A_{1}XA_{2}^{\top}); this takes 𝒯mat(n,d,d){\cal T}_{\mathrm{mat}}(n,d,d) and 𝒯mat(n,d,n){\cal T}_{\mathrm{mat}}(n,d,n) time.

Then we can compute DD, which takes O(n2)O(n^{2}) time.

Then we can compute D1exp(A1XA2)D^{-1}\exp(A_{1}XA_{2}^{\top}), this takes O(n2)O(n^{2}) time.

Thus, the overall time is

𝒯mat(n,d,d)+𝒯mat(n,d,n)+O(n2)\displaystyle~{}{\cal T}_{\mathrm{mat}}(n,d,d)+{\cal T}_{\mathrm{mat}}(n,d,n)+O(n^{2})
=\displaystyle= O(𝒯mat(n,d,d)+𝒯mat(n,d,n))\displaystyle~{}O({\cal T}_{\mathrm{mat}}(n,d,d)+{\cal T}_{\mathrm{mat}}(n,d,n))

Note that h(y)=A_3Y, which can be computed in {\cal T}_{\mathrm{mat}}(n,d,d) time.

Thus, the proof is completed. ∎
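To make the two computations above concrete, the following is a minimal numpy sketch (our own illustration, not part of the formal argument): the matrices are small random stand-ins, and, as in the rest of this section, the 1/d rescaling inside the exponential is ignored.

```python
import numpy as np

n, d = 8, 4
rng = np.random.default_rng(0)
A1, A2, A3 = rng.standard_normal((3, n, d))
X, Y = rng.standard_normal((2, d, d))

# f(x) = D^{-1} exp(A1 X A2^T): one T_mat(n,d,d) product, one T_mat(n,d,n) product,
# then O(n^2) work for the normalization D = diag(exp(A1 X A2^T) 1_n).
M = np.exp((A1 @ X) @ A2.T)           # n x n
f = M / M.sum(axis=1, keepdims=True)  # D^{-1} M

# h(y) = A3 Y: a single T_mat(n,d,d) product.
h = A3 @ Y                            # n x d
```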

C.2 Compute cc

Lemma C.2 (Computing cc).

Suppose the following objects are given

  • E\in\mathbb{R}^{n\times d} is given

  • f(x)n×nf(x)\in\mathbb{R}^{n\times n} is given

  • h(y)n×dh(y)\in\mathbb{R}^{n\times d} is given,

Then one can compute c(x)\in\mathbb{R}^{n\times d} in O({\cal T}_{\mathrm{mat}}(n,n,d)) time.

Proof.

Based on the definition of c(x)\in\mathbb{R}^{n\times d} (Definition 3.8), which is

c(x)=f(x)h(y)E\displaystyle c(x)=f(x)h(y)-E

Computing f(x)h(y) takes {\cal T}_{\mathrm{mat}}(n,n,d) time, and computing f(x)h(y)-E takes O(nd) time.

Thus, the overall time is

\displaystyle{\cal T}_{\mathrm{mat}}(n,n,d)+O(nd)=O({\cal T}_{\mathrm{mat}}(n,n,d)). ∎
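In the same spirit, here is a short hedged sketch of Lemma C.2 (E and the other inputs are random stand-ins):

```python
import numpy as np

n, d = 8, 4
rng = np.random.default_rng(0)
f = rng.random((n, n)); f /= f.sum(axis=1, keepdims=True)  # stand-in for f(x)
h = rng.standard_normal((n, d))                            # stand-in for h(y)
E = rng.standard_normal((n, d))

# c(x) = f(x) h(y) - E: one T_mat(n,n,d) product plus O(nd) for the subtraction.
c = f @ h - E
```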

C.3 Computation for qq

We first define q, and then explain how to compute it.

Definition C.3.

Define c(x)n×dc(x)\in\mathbb{R}^{n\times d} as in Definition 3.8. Define h(y)n×dh(y)\in\mathbb{R}^{n\times d} as in Definition 3.7.

We define q(x)n×nq(x)\in\mathbb{R}^{n\times n} as

q(x):=c(x)n×dh(y)d×n\displaystyle q(x):=\underbrace{c(x)}_{n\times d}\underbrace{h(y)^{\top}}_{d\times n}

Then we use q(x)j0q(x)_{j_{0}}^{\top} to denote the j0j_{0}-th row of q(x)n×nq(x)\in\mathbb{R}^{n\times n}.

Lemma C.4.

Suppose that

  • c(x)\in\mathbb{R}^{n\times d} is given

  • h(y)\in\mathbb{R}^{n\times d} is given

Then, we can compute q(x) in O({\cal T}_{\mathrm{mat}}(n,n,d)) time.

Proof.

Recall that q(x)=c(x)h(y)^{\top}. Thus it takes {\cal T}_{\mathrm{mat}}(n,d,n)=O({\cal T}_{\mathrm{mat}}(n,n,d)) time. ∎

C.4 Computation for p(x)p(x)

Let us first define p, and then show how to compute it.

Definition C.5.

For every index j0[n]j_{0}\in[n], we define p(x)j0np(x)_{j_{0}}\in\mathbb{R}^{n} as

p(x)j0:=(diag(f(x)j0)f(x)j0f(x)j0)q(x)j0.\displaystyle p(x)_{j_{0}}:=(\mathrm{diag}(f(x)_{j_{0}})-f(x)_{j_{0}}f(x)_{j_{0}}^{\top})q(x)_{j_{0}}.

We define p(x)n×np(x)\in\mathbb{R}^{n\times n} in the sense that p(x)j0p(x)_{j_{0}}^{\top} is the j0j_{0}-th row of p(x)p(x).

Lemma C.6.

Suppose that

  • f(x)\in\mathbb{R}^{n\times n} is given

  • q(x)\in\mathbb{R}^{n\times n} is given

Then, we can compute p(x) in O(n^{2}) time.

Proof.

Since diag(f(x)j0)\mathrm{diag}(f(x)_{j_{0}}) is a diagonal matrix and f(x)j0f(x)j0f(x)_{j_{0}}f(x)_{j_{0}}^{\top} is a rank-one matrix, we know that p(x)j0np(x)_{j_{0}}\in\mathbb{R}^{n} can be computed in O(n)O(n), for each j0[n]j_{0}\in[n]. Thus we can construct matrix p(x)n×np(x)\in\mathbb{R}^{n\times n} in n×O(n)=O(n2)n\times O(n)=O(n^{2}) time in total. ∎
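The following hedged numpy sketch (names and toy sizes are ours) illustrates Definitions C.3 and C.5 and the O(n)-per-row argument above: each row of p(x) needs only a Hadamard product and an inner product once q(x) is available.

```python
import numpy as np

n, d = 8, 4
rng = np.random.default_rng(0)
f = rng.random((n, n)); f /= f.sum(axis=1, keepdims=True)  # stand-in for f(x)
h = rng.standard_normal((n, d))                            # stand-in for h(y)
c = rng.standard_normal((n, d))                            # stand-in for c(x)

q = c @ h.T                                  # q(x) = c(x) h(y)^T, one T_mat(n,d,n) product

# p(x)_{j0} = (diag(f_{j0}) - f_{j0} f_{j0}^T) q_{j0}
#           = f_{j0} o q_{j0} - <f_{j0}, q_{j0}> f_{j0},  which is O(n) per row.
p = np.empty((n, n))
for j0 in range(n):
    p[j0] = f[j0] * q[j0] - (f[j0] @ q[j0]) * f[j0]

# the same computation, vectorized over all rows
p_vec = f * q - (f * q).sum(axis=1, keepdims=True) * f
assert np.allclose(p, p_vec)
```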

C.5 Analyze the closed form of gradient

Lemma C.7.

Define the functions f(x)\in\mathbb{R}^{n\times n}, c(x)\in\mathbb{R}^{n\times d}, h(y)\in\mathbb{R}^{n\times d}, q(x)\in\mathbb{R}^{n\times n} and p(x)\in\mathbb{R}^{n\times n} as in Definitions 3.6, 3.8, 3.7, C.3 and C.5, respectively. Let A_1,A_2\in\mathbb{R}^{n\times d} be two given matrices, and define \operatorname{\mathsf{A}}=A_1\otimes A_2. Let L(x) be defined as in Definition 1.2, and let L(x)_{j_0,i_0} be defined as in Definition 3.9. Then, we can show that \frac{\mathrm{d}L(x)}{\mathrm{d}x}=\operatorname{vec}(A_1^{\top}p(x)A_2).

Proof.

From Part 8 of the lemma above, we have

dL(x,y)j0,i0dxi=c(x,y)j0,i0(f(x)j0𝖠j0,i,h(y)i0f(x)j0,h(y)i0f(x)j0,𝖠j0,i)\displaystyle\frac{\mathrm{d}L(x,y)_{j_{0},i_{0}}}{\mathrm{d}x_{i}}=c(x,y)_{j_{0},i_{0}}\cdot(\langle f(x)_{j_{0}}\circ\operatorname{\mathsf{A}}_{j_{0},i},h(y)_{i_{0}}\rangle-\langle f(x)_{j_{0}},h(y)_{i_{0}}\rangle\cdot\langle f(x)_{j_{0}},\operatorname{\mathsf{A}}_{j_{0},i}\rangle) (7)

Note that by Fact A.1, it holds that

f(x)j0𝖠j0,i,h(y)i0=𝖠j0,idiag(f(x)j0)h(y)i0\displaystyle\langle f(x)_{j_{0}}\circ\operatorname{\mathsf{A}}_{j_{0},i},h(y)_{i_{0}}\rangle=\operatorname{\mathsf{A}}_{j_{0},i}^{\top}\mathrm{diag}(f(x)_{j_{0}})h(y)_{i_{0}}

and

\displaystyle\langle f(x)_{j_{0}},h(y)_{i_{0}}\rangle\cdot\langle f(x)_{j_{0}},\operatorname{\mathsf{A}}_{j_{0},i}\rangle=\operatorname{\mathsf{A}}_{j_{0},i}^{\top}f(x)_{j_{0}}f(x)_{j_{0}}^{\top}h(y)_{i_{0}}

Therefore, Eq. (7) becomes

dL(x)j0,i0dxi=\displaystyle\frac{\mathrm{d}L(x)_{j_{0},i_{0}}}{\mathrm{d}x_{i}}= c(x,y)j0,i0(𝖠j0,idiag(f(x)j0)h(y)i0𝖠j0,if(x)j0f(x)j0h(y)i0)\displaystyle~{}c(x,y)_{j_{0},i_{0}}\cdot(\operatorname{\mathsf{A}}_{j_{0},i}^{\top}\mathrm{diag}(f(x)_{j_{0}})h(y)_{i_{0}}-\operatorname{\mathsf{A}}_{j_{0},i}^{\top}f(x)_{j_{0}}f(x)_{j_{0}}^{\top}h(y)_{i_{0}})
=\displaystyle= c(x,y)j0,i0𝖠j0,i(diag(f(x)j0)f(x)j0f(x)j0)h(y)i0,\displaystyle~{}c(x,y)_{j_{0},i_{0}}\cdot\operatorname{\mathsf{A}}_{j_{0},i}^{\top}(\mathrm{diag}(f(x)_{j_{0}})-f(x)_{j_{0}}f(x)_{j_{0}}^{\top})h(y)_{i_{0}}, (8)

where the 2nd step follows from simple algebra.

Recall the way we define q(x)j0q(x)_{j_{0}} (see Definition C.3).

q(x)j0:=i0=1dc(x)j0,i0h(y)i0.\displaystyle q(x)_{j_{0}}:=\sum_{i_{0}=1}^{d}c(x)_{j_{0},i_{0}}h(y)_{i_{0}}. (9)

Recall that p(x)_{j_{0}}\in\mathbb{R}^{n} is defined as in Definition C.5:

p(x)j0:=(diag(f(x)j0)f(x)j0f(x)j0)q(x)j0.\displaystyle p(x)_{j_{0}}:=(\mathrm{diag}(f(x)_{j_{0}})-f(x)_{j_{0}}f(x)_{j_{0}}^{\top})q(x)_{j_{0}}. (10)

It holds that

dL(x)dx\displaystyle~{}\frac{\mathrm{d}L(x)}{\mathrm{d}x}
=\displaystyle= j0=1ni0=1ddL(x)j0,i0dx\displaystyle~{}\sum_{j_{0}=1}^{n}\sum_{i_{0}=1}^{d}\frac{\mathrm{d}L(x)_{j_{0},i_{0}}}{\mathrm{d}x}
=\displaystyle= j0=1ni0=1dc(x)j0,i0scalar𝖠j0d2×n(diag(f(x)j0)f(x)j0f(x)j0)n×nh(y)i0n×1\displaystyle~{}\sum_{j_{0}=1}^{n}\sum_{i_{0}=1}^{d}\underbrace{c(x)_{j_{0},i_{0}}}_{\mathrm{scalar}}\cdot\underbrace{\operatorname{\mathsf{A}}_{j_{0}}^{\top}}_{d^{2}\times n}\underbrace{(\mathrm{diag}(f(x)_{j_{0}})-f(x)_{j_{0}}f(x)_{j_{0}}^{\top})}_{n\times n}\underbrace{h(y)_{i_{0}}}_{n\times 1}
=\displaystyle= j0=1n𝖠j0(diag(f(x)j0)f(x)j0f(x)j0)q(x)j0\displaystyle~{}\sum_{j_{0}=1}^{n}\operatorname{\mathsf{A}}_{j_{0}}^{\top}(\mathrm{diag}(f(x)_{j_{0}})-f(x)_{j_{0}}f(x)_{j_{0}}^{\top})q(x)_{j_{0}}
=\displaystyle= j0=1n𝖠j0p(x)j0\displaystyle~{}\sum_{j_{0}=1}^{n}\operatorname{\mathsf{A}}_{j_{0}}^{\top}p(x)_{j_{0}}
=\displaystyle= vec(A1p(x)A2)\displaystyle~{}\operatorname{vec}(A_{1}^{\top}p(x)A_{2})

where the 1st step follows from Definition 1.2, the 2nd step is based on Eq. (8), the 3rd step follows from Eq. (9), the 4th step is due to Eq. (10), and the last step uses the tensor trick. ∎
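As an informal sanity check on the closed form above (our own illustration, not part of the proof), one can compare \operatorname{vec}(A_1^{\top}p(x)A_2) against a numerical gradient of L on a tiny random instance; as in this section, the 1/d factor inside the exponential is ignored, and the helper names below are ours.

```python
import numpy as np

n, d = 6, 3
rng = np.random.default_rng(1)
A1, A2, A3 = rng.standard_normal((3, n, d)) / d
X, Y = rng.standard_normal((2, d, d))
E = rng.standard_normal((n, d))

def loss(Xmat):
    M = np.exp((A1 @ Xmat) @ A2.T)
    f = M / M.sum(axis=1, keepdims=True)
    return 0.5 * np.sum((f @ (A3 @ Y) - E) ** 2)   # L(X) = sum_{j0,i0} 0.5 c(x)_{j0,i0}^2

def closed_form_gradient(Xmat):
    M = np.exp((A1 @ Xmat) @ A2.T)
    f = M / M.sum(axis=1, keepdims=True)
    h = A3 @ Y
    q = (f @ h - E) @ h.T                          # q = c h^T
    p = f * q - (f * q).sum(axis=1, keepdims=True) * f
    return (A1.T @ p @ A2).reshape(-1)             # vec(A1^T p(x) A2)

# central-difference gradient with respect to x = vec(X)
eps, g_num = 1e-6, np.empty(d * d)
for i in range(d * d):
    e = np.zeros(d * d); e[i] = eps
    g_num[i] = (loss(X + e.reshape(d, d)) - loss(X - e.reshape(d, d))) / (2 * eps)

assert np.allclose(closed_form_gradient(X), g_num, atol=1e-4)
```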

C.6 Putting it together

Lemma C.8 (Attention gradient computation, formal version of Lemma 4.1).

Suppose that

  • A_1,A_2,A_3,E\in\mathbb{R}^{n\times d} are given fixed input matrices

  • Let X,Y\in\mathbb{R}^{d\times d} denote matrix variables (we will compute the gradient with respect to X)

    • For ease of writing, we also use vector variables x\in\mathbb{R}^{d^{2}\times 1} and y\in\mathbb{R}^{d^{2}\times 1}, i.e., \operatorname{vec}(X)=x and \operatorname{vec}(Y)=y.

  • Let g=\frac{\mathrm{d}L(X)}{\mathrm{d}x}\in\mathbb{R}^{d^{2}} (where L(X) is defined as in Definition 1.2)

Then we can show that the gradient g\in\mathbb{R}^{d^{2}} can be computed in O({\cal T}_{\mathrm{mat}}(n,d,n)+{\cal T}_{\mathrm{mat}}(n,d,d)) time.

Proof.

Step 1. We compute f(x) and h(y). This takes O({\cal T}_{\mathrm{mat}}(n,n,d)+{\cal T}_{\mathrm{mat}}(n,d,d)) time due to Lemma C.1.

Step 2. We compute c(x). This takes O({\cal T}_{\mathrm{mat}}(n,n,d)+{\cal T}_{\mathrm{mat}}(n,d,d)) time due to Lemma C.2.

Step 3. We compute q(x). This takes O({\cal T}_{\mathrm{mat}}(n,n,d)) time due to Lemma C.4.

Step 4. We compute p(x). This takes O(n^{2}) time due to Lemma C.6.

Step 5. Using Lemma C.7, we know that the gradient equals \operatorname{vec}(A_1^{\top}p(x)A_2). Given A_1^{\top}\in\mathbb{R}^{d\times n}, p(x)\in\mathbb{R}^{n\times n}, and A_2\in\mathbb{R}^{n\times d}, this product can be computed in O({\cal T}_{\mathrm{mat}}(n,n,d)+{\cal T}_{\mathrm{mat}}(n,d,d)) time.

Thus, the overall running time for computing the gradient is

O(𝒯mat(n,d,d)+𝒯mat(n,d,n))\displaystyle O({\cal T}_{\mathrm{mat}}(n,d,d)+{\cal T}_{\mathrm{mat}}(n,d,n))

time. ∎

Appendix D Fast Running Time via Polynomial Method

Recall that in the previous section, for convenience of computing the derivative, we ignored the 1/d factor in f. That factor does not impact the running time of our algorithms, since it is just a rescaling. To apply the tools from previous work [4], we will now take the 1/d factor in f into account. In Section D.1, we show how to efficiently and explicitly construct a low-rank representation for f. In Section D.2, we show how to obtain a low-rank representation for c(x). In Section D.3, Section D.4 and Section D.5, we further give low-rank representations for q(x), p_1(x), and p_2(x). In Section D.6, we prove our final algorithmic result by putting everything together.

D.1 Low rank representation of f

Using the polynomial method result of [4], we are able to obtain the following low-rank representation:

Lemma D.1 (Section 3 of [4]).

For any B=o(\sqrt{\log n}), there exists a k_1=n^{o(1)} such that the following holds. Let A_1,A_2\in\mathbb{R}^{n\times d} be two matrices and let X\in\mathbb{R}^{d\times d} be a square matrix. If \|A_1X\|_{\infty}\leq B and \|A_2\|_{\infty}\leq B, then there are two matrices U_1,V_1\in\mathbb{R}^{n\times k_1} such that \|U_1V_1^{\top}-f(x)\|_{\infty}\leq\epsilon/\operatorname{poly}(n). Here f(x)=D^{-1}\exp(A_1XA_2^{\top}/d) and D=\mathrm{diag}(\exp(A_1XA_2^{\top}/d){\bf 1}_n). Moreover, the matrices U_1,V_1 can be explicitly constructed in n^{1+o(1)} time.
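The actual construction in [4] relies on carefully optimized low-degree polynomials; as a heavily hedged illustration only, the sketch below substitutes a plain truncated Taylor expansion of exp to build explicit factors U_1, V_1 with U_1V_1^{\top}\approx f(x). The helper names, the degree, and the toy sizes are our own, and this crude truncation uses far more features than the rank n^{o(1)} guaranteed by the polynomial method; it only shows the shape of an explicit factorization.

```python
import numpy as np
from itertools import product
from math import factorial, sqrt

def taylor_features(M, degree):
    """Map each row m of M to the monomials m_{i1}*...*m_{is}/sqrt(s!) for s = 0..degree,
    so that <phi(a), phi(b)> = sum_s <a,b>^s / s!, a truncation of exp(<a,b>)."""
    n, d = M.shape
    blocks = [np.ones((n, 1))]
    for s in range(1, degree + 1):
        cols = [M[:, list(idx)].prod(axis=1) for idx in product(range(d), repeat=s)]
        blocks.append(np.stack(cols, axis=1) / sqrt(factorial(s)))
    return np.concatenate(blocks, axis=1)

def low_rank_f(A1, A2, X, degree=8):
    """Explicit factors U1, V1 with U1 @ V1.T ~= f(x) = D^{-1} exp(A1 X A2^T / d)."""
    d = A1.shape[1]
    U = taylor_features(A1 @ X / d, degree)   # "query"-side features
    V1 = taylor_features(A2, degree)          # "key"-side features
    row_sums = U @ V1.sum(axis=0)             # ~= exp(A1 X A2^T / d) 1_n
    return U / row_sums[:, None], V1          # fold D^{-1} into the left factor

# tiny sanity check against the dense matrix
n, d = 8, 3
rng = np.random.default_rng(0)
A1, A2 = rng.uniform(-1, 1, (2, n, d))
X = rng.uniform(-1, 1, (d, d))
M = np.exp((A1 @ X) @ A2.T / d)
f = M / M.sum(axis=1, keepdims=True)
U1, V1 = low_rank_f(A1, A2, X)
assert np.max(np.abs(U1 @ V1.T - f)) < 1e-2
```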

D.2 Low rank representation of c

Lemma D.2.

Let d=O(\log n). Assume that each entry of the n\times d matrices E and h(y) can be written using O(\log n) bits. Let the n\times d matrix c(x) be defined as in Definition 3.8, and let U_1,V_1\in\mathbb{R}^{n\times k_1} be the matrices from Lemma D.1. Then we have \|U_1V_1^{\top}h(y)-E-c(x)\|_{\infty}\leq\epsilon/\operatorname{poly}(n).

Proof.

We can show that

\displaystyle\|U_1V_1^{\top}h(y)-E-c(x)\|_{\infty}=~\|U_1V_1^{\top}h(y)-E-f(x)h(y)+E\|_{\infty}
\displaystyle=~\|(U_1V_1^{\top}-f(x))h(y)\|_{\infty}
\displaystyle\leq~d\cdot\|U_1V_1^{\top}-f(x)\|_{\infty}\cdot\|h(y)\|_{\infty}
\displaystyle\leq~\epsilon/\operatorname{poly}(n)

where the first step follows from c(x)=f(x)h(y)-E, the second step follows from simple algebra, the third step holds since each entry of (U_1V_1^{\top}-f(x))h(y) is a sum of d products, and the last step follows from Lemma D.1 together with the bounded entries of h(y). ∎
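A small hedged sketch of how this approximation is used algorithmically: given the factors U_1, V_1, one multiplies V_1^{\top}h(y) first, so the approximation to c(x) is formed without ever materializing an n\times n matrix (the random stand-ins below are ours).

```python
import numpy as np

n, d, k1 = 8, 4, 5
rng = np.random.default_rng(0)
U1, V1 = rng.standard_normal((2, n, k1))   # stand-ins for the factors of Lemma D.1
h = rng.standard_normal((n, d))            # stand-in for h(y)
E = rng.standard_normal((n, d))

# c~ = U1 (V1^T h) - E: roughly n*k1*d work twice, never an n x n product.
c_approx = U1 @ (V1.T @ h) - E
assert np.allclose(c_approx, (U1 @ V1.T) @ h - E)
```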

D.3 Low rank representation of q

Lemma D.3.

Let k_2=n^{o(1)}. Define c(x)\in\mathbb{R}^{n\times d} as in Definition 3.8 and h(y)\in\mathbb{R}^{n\times d} as in Definition 3.7. Let q(x):=c(x)h(y)^{\top}\in\mathbb{R}^{n\times n} be as in Definition C.3. Then there are two matrices U_2,V_2\in\mathbb{R}^{n\times k_2} such that \|U_2V_2^{\top}-q(x)\|_{\infty}\leq\epsilon/\operatorname{poly}(n). The matrices U_2,V_2 can be explicitly constructed in n^{1+o(1)} time.

Proof.

We define q~(x)\widetilde{q}(x) to be the approximation of q(x)q(x).

From Lemma D.2, we know that U1V1h(y)EU_{1}V_{1}^{\top}h(y)-E is a good approximation to c(x)c(x).

Accordingly, we pick \widetilde{q}(x)=(U_1V_1^{\top}h(y)-E)h(y)^{\top}.

Now, let us turn \widetilde{q}(x) into an explicit low-rank representation:

\displaystyle\widetilde{q}(x)=\underbrace{U_1}_{n\times k_1}\underbrace{V_1^{\top}h(y)}_{k_1\times d}\underbrace{h(y)^{\top}}_{d\times n}-\underbrace{E}_{n\times d}\underbrace{h(y)^{\top}}_{d\times n}

We can first compute V_1^{\top}h(y)\in\mathbb{R}^{k_1\times d}, which only takes n^{1+o(1)} time. Then, since all the factors above are known explicitly, we can take U_2:=[U_1,~E]\in\mathbb{R}^{n\times k_2} and V_2:=[h(y)(h(y)^{\top}V_1),~-h(y)]\in\mathbb{R}^{n\times k_2}, where k_2=k_1+d=n^{o(1)}, so that U_2V_2^{\top}=\widetilde{q}(x).

For controlling the error, we can show

\displaystyle\|\widetilde{q}(x)-q(x)\|_{\infty}=~\|(U_1V_1^{\top}h(y)-E)h(y)^{\top}-c(x)h(y)^{\top}\|_{\infty}
\displaystyle\leq~d\cdot\|U_1V_1^{\top}h(y)-E-c(x)\|_{\infty}\cdot\|h(y)\|_{\infty}
\displaystyle\leq~\epsilon/\operatorname{poly}(n),

where the first step follows from q(x)=c(x)h(y)^{\top}, the second step holds since each entry is a sum of d products, and the last step follows from Lemma D.2 together with the bounded entries of h(y).

Thus, we complete the proof. ∎
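A minimal numpy sketch of the factor construction above (random stand-ins for the inputs; the dense product is formed only to check the factors on a toy instance):

```python
import numpy as np

n, d, k1 = 8, 4, 5
rng = np.random.default_rng(0)
U1, V1 = rng.standard_normal((2, n, k1))   # stand-ins for the factors approximating f(x)
h = rng.standard_normal((n, d))            # stand-in for h(y)
E = rng.standard_normal((n, d))

# q~ = (U1 V1^T h - E) h^T, assembled as U2 V2^T with k2 = k1 + d columns.
U2 = np.hstack([U1, E])                    # n x (k1 + d)
V2 = np.hstack([h @ (h.T @ V1), -h])       # n x (k1 + d); h^T V1 is computed first
assert np.allclose(U2 @ V2.T, (U1 @ V1.T @ h - E) @ h.T)
```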

D.4 Low rank representation of p_1(x)

Lemma D.4.

Let k_1=n^{o(1)} and k_2=n^{o(1)}. Define p_1(x):=f(x)\circ q(x). Assume U_1,V_1\in\mathbb{R}^{n\times k_1} approximate f(x) in the sense that \|U_1V_1^{\top}-f(x)\|_{\infty}\leq\epsilon/\operatorname{poly}(n), and U_2,V_2\in\mathbb{R}^{n\times k_2} approximate q(x)\in\mathbb{R}^{n\times n} in the sense that \|U_2V_2^{\top}-q(x)\|_{\infty}\leq\epsilon/\operatorname{poly}(n). Then there are matrices U_3,V_3\in\mathbb{R}^{n\times k_3} with k_3=k_1k_2=n^{o(1)} such that \|U_3V_3^{\top}-p_1(x)\|_{\infty}\leq\epsilon/\operatorname{poly}(n). The matrices U_3,V_3 can be explicitly constructed in n^{1+o(1)} time.

Proof.

We choose U3=U1U2U_{3}=U_{1}\oslash U_{2} and V3=V1V2V_{3}=V_{1}\oslash V_{2}. This can be computed in n1+o(1)n^{1+o(1)} time.

For ease of writing, we let \widetilde{f}(x)=U_1V_1^{\top} and \widetilde{q}(x)=U_2V_2^{\top}.

Using Fact A.2, we know that

U3V3p1(x)\displaystyle\|U_{3}V_{3}^{\top}-p_{1}(x)\|_{\infty}\leq U3V3f(x)q(x)\displaystyle~{}\|U_{3}V_{3}^{\top}-f(x)\circ q(x)\|_{\infty}
=\displaystyle= (U1U2)(V1V2)f(x)q(x)\displaystyle~{}\|(U_{1}\oslash U_{2})(V_{1}\oslash V_{2})^{\top}-f(x)\circ q(x)\|_{\infty}
=\displaystyle= (U1V1)(U2V2)f(x)q(x)\displaystyle~{}\|(U_{1}V_{1}^{\top})\circ(U_{2}V_{2}^{\top})-f(x)\circ q(x)\|_{\infty}
=\displaystyle= f~(x)q~(x)f(x)q(x)\displaystyle~{}\|\widetilde{f}(x)\circ\widetilde{q}(x)-f(x)\circ q(x)\|_{\infty}
=\displaystyle= f~(x)q~(x)f~(x)q(x)+f~(x)q(x)f(x)q(x)\displaystyle~{}\|\widetilde{f}(x)\circ\widetilde{q}(x)-\widetilde{f}(x)\circ q(x)+\widetilde{f}(x)\circ q(x)-f(x)\circ q(x)\|_{\infty}
\displaystyle\leq f~(x)q~(x)f~(x)q(x)+f~(x)q(x)f(x)q(x)\displaystyle~{}\|\widetilde{f}(x)\circ\widetilde{q}(x)-\widetilde{f}(x)\circ q(x)\|_{\infty}+\|\widetilde{f}(x)\circ q(x)-f(x)\circ q(x)\|_{\infty}
\displaystyle\leq ϵ/poly(n)\displaystyle~{}\epsilon/\operatorname{poly}(n)

where the 1st step follows from the way we define p_1(x), the 2nd step follows from the way we define U_3 and V_3, the 3rd step follows from Fact A.2, the 4th step follows from the way we define \widetilde{f}(x) and \widetilde{q}(x), the 5th step follows from simple algebra, the 6th step follows from the triangle inequality, and the last step follows from the boundedness of the entries together with \|\widetilde{f}(x)-f(x)\|_{\infty}\leq\epsilon/\operatorname{poly}(n) and \|\widetilde{q}(x)-q(x)\|_{\infty}\leq\epsilon/\operatorname{poly}(n) (Lemma assumptions). ∎
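The identity used in the 3rd step, (U_1\oslash U_2)(V_1\oslash V_2)^{\top}=(U_1V_1^{\top})\circ(U_2V_2^{\top}) (Fact A.2), is easy to check numerically; the helper row_kron below is our own name for the row-wise Kronecker product \oslash.

```python
import numpy as np

def row_kron(U, W):
    """Row-wise Kronecker product: row j of the result is kron(U[j], W[j])."""
    return (U[:, :, None] * W[:, None, :]).reshape(U.shape[0], -1)

n, k1, k2 = 8, 3, 4
rng = np.random.default_rng(0)
U1, V1 = rng.standard_normal((2, n, k1))
U2, V2 = rng.standard_normal((2, n, k2))

lhs = row_kron(U1, U2) @ row_kron(V1, V2).T      # factors of width k1 * k2
rhs = (U1 @ V1.T) * (U2 @ V2.T)                  # Hadamard product of the two low-rank products
assert np.allclose(lhs, rhs)
```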

D.5 Low rank representation of p_2(x)

Lemma D.5.

Let k_1=n^{o(1)}, k_2=n^{o(1)}, and k_4=n^{o(1)}. Let p_2(x) be the n\times n matrix whose j_0-th row is p_2(x)_{j_0}^{\top}, where p_2(x)_{j_0}:=f(x)_{j_0}f(x)_{j_0}^{\top}q(x)_{j_0} for each j_0\in[n]. Assume U_1,V_1\in\mathbb{R}^{n\times k_1} approximate f(x) in the sense that \|U_1V_1^{\top}-f(x)\|_{\infty}\leq\epsilon/\operatorname{poly}(n), and U_2,V_2\in\mathbb{R}^{n\times k_2} approximate q(x)\in\mathbb{R}^{n\times n} in the sense that \|U_2V_2^{\top}-q(x)\|_{\infty}\leq\epsilon/\operatorname{poly}(n). Then there are matrices U_4,V_4\in\mathbb{R}^{n\times k_4} such that \|U_4V_4^{\top}-p_2(x)\|_{\infty}\leq\epsilon/\operatorname{poly}(n). The matrices U_4,V_4 can be explicitly constructed in n^{1+o(1)} time.

Proof.

We define a local vector function r(x)\in\mathbb{R}^{n} where r(x)_{j_0} is \langle f(x)_{j_0},q(x)_{j_0}\rangle. Let \widetilde{r}(x) denote the approximation of r(x).

Note that (U_1V_1^{\top})_{j_0,*}^{\top} is a good approximation to f(x)_{j_0}.

Note that (U_2V_2^{\top})_{j_0,*}^{\top} is a good approximation to q(x)_{j_0}.

Let \widetilde{r}(x)_{j_0}:=\langle\widetilde{f}(x)_{j_0},\widetilde{q}(x)_{j_0}\rangle=(U_1V_1^{\top})_{j_0,*}\cdot(U_2V_2^{\top})_{j_0,*}^{\top}.

For the computation side, we first compute V_1^{\top}V_2\in\mathbb{R}^{k_1\times k_2}. This takes n^{1+o(1)} time.

Next, we have

\displaystyle\widetilde{r}(x)_{j_{0}}=~(U_{1}V_{1}^{\top})_{j_{0},*}\cdot(U_{2}V_{2}^{\top})_{j_{0},*}^{\top}
\displaystyle=~\underbrace{(U_{1})_{j_{0},*}}_{1\times k_{1}}\underbrace{V_{1}^{\top}V_{2}}_{k_{1}\times k_{2}}\underbrace{(U_{2})_{j_{0},*}^{\top}}_{k_{2}\times 1}

Once V_1^{\top}V_2 is precomputed, the above step only takes O(k_1k_2) time. Since there are n coordinates, the overall time is still O(nk_1k_2)=n^{1+o(1)}.

Let \widetilde{f}(x)=U_1V_1^{\top} denote the approximation of f(x). Then we just use \widetilde{f}(x) and \widetilde{r}(x) to approximate p_2(x) in the following sense: let \widetilde{p}_2(x)=\mathrm{diag}(\widetilde{r}(x))\widetilde{f}(x). Since \widetilde{f}(x) has a low-rank representation and \mathrm{diag}(\widetilde{r}(x)) is a diagonal matrix, it is obvious how to construct U_4 and V_4: basically U_4=\mathrm{diag}(\widetilde{r}(x))U_1 and V_4=V_1.

Now, we need to control the error, we have

U4V4p2(x)=\displaystyle\|U_{4}V_{4}^{\top}-p_{2}(x)\|_{\infty}= p~2(x)p2(x)\displaystyle~{}\|\widetilde{p}_{2}(x)-p_{2}(x)\|_{\infty}
=\displaystyle= maxj0[n]f~(x)j0r~(x)j0f(x)j0r(x)j0\displaystyle~{}\max_{j_{0}\in[n]}\|\widetilde{f}(x)_{j_{0}}\widetilde{r}(x)_{j_{0}}-f(x)_{j_{0}}r(x)_{j_{0}}\|_{\infty}
=\displaystyle= maxj0[n]f~(x)j0r~(x)j0f~(x)j0r(x)j0+f~(x)j0r(x)j0f(x)j0r(x)j0\displaystyle~{}\max_{j_{0}\in[n]}\|\widetilde{f}(x)_{j_{0}}\widetilde{r}(x)_{j_{0}}-\widetilde{f}(x)_{j_{0}}r(x)_{j_{0}}+\widetilde{f}(x)_{j_{0}}r(x)_{j_{0}}-f(x)_{j_{0}}r(x)_{j_{0}}\|_{\infty}
\displaystyle\leq maxj0[n]f~(x)j0r~(x)j0f~(x)j0r(x)j0+f~(x)j0r(x)j0f(x)j0r(x)j0\displaystyle~{}\max_{j_{0}\in[n]}\|\widetilde{f}(x)_{j_{0}}\widetilde{r}(x)_{j_{0}}-\widetilde{f}(x)_{j_{0}}r(x)_{j_{0}}\|_{\infty}+\|\widetilde{f}(x)_{j_{0}}r(x)_{j_{0}}-f(x)_{j_{0}}r(x)_{j_{0}}\|_{\infty}

where the 2nd step follows from the definitions of p_2(x) and \widetilde{p}_2(x), the 3rd step follows from simple algebra, and the last step follows from the triangle inequality.

For the first term, we have

maxj0[n]f~(x)j0r~(x)j0f~(x)j0r(x)j0\displaystyle\max_{j_{0}\in[n]}\|\widetilde{f}(x)_{j_{0}}\widetilde{r}(x)_{j_{0}}-\widetilde{f}(x)_{j_{0}}r(x)_{j_{0}}\|_{\infty}\leq maxj0[n]f~(x)j0|r~(x)j0r(x)j0|\displaystyle~{}\max_{j_{0}\in[n]}\|\widetilde{f}(x)_{j_{0}}\|_{\infty}\cdot|\widetilde{r}(x)_{j_{0}}-r(x)_{j_{0}}|
\displaystyle\leq ϵ/poly(n)\displaystyle~{}\epsilon/\operatorname{poly}(n)

For the second term, we have

maxj0[n]f~(x)j0r(x)j0f(x)j0r(x)j0\displaystyle\max_{j_{0}\in[n]}\|\widetilde{f}(x)_{j_{0}}r(x)_{j_{0}}-f(x)_{j_{0}}r(x)_{j_{0}}\|_{\infty}\leq maxj0[n]f~(x)j0f(x)j0|r(x)j0|\displaystyle~{}\max_{j_{0}\in[n]}\|\widetilde{f}(x)_{j_{0}}-f(x)_{j_{0}}\|_{\infty}\cdot|r(x)_{j_{0}}|
\displaystyle\leq ϵ/poly(n)\displaystyle~{}\epsilon/\operatorname{poly}(n)

Using the three equations we obtained above, the proof is completed. ∎
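A minimal sketch of the construction in this proof (random stand-ins for the factors; the dense matrices appear only in the final checks): \widetilde{r} is computed row by row from the precomputed k_1\times k_2 matrix V_1^{\top}V_2, and then U_4=\mathrm{diag}(\widetilde{r})U_1, V_4=V_1.

```python
import numpy as np

n, k1, k2 = 8, 3, 4
rng = np.random.default_rng(0)
U1, V1 = rng.standard_normal((2, n, k1))   # stand-in factors approximating f(x)
U2, V2 = rng.standard_normal((2, n, k2))   # stand-in factors approximating q(x)

# r~_{j0} = <f~_{j0}, q~_{j0}> = (U1)_{j0,*} (V1^T V2) (U2)_{j0,*}^T, with V1^T V2 computed once.
G = V1.T @ V2                              # k1 x k2
r_tilde = np.einsum('jk,kl,jl->j', U1, G, U2)

# p~2 = diag(r~) f~ = (diag(r~) U1) V1^T, so U4 = diag(r~) U1 and V4 = V1.
U4, V4 = r_tilde[:, None] * U1, V1

f_tilde, q_tilde = U1 @ V1.T, U2 @ V2.T    # dense versions, for checking only
assert np.allclose(r_tilde, (f_tilde * q_tilde).sum(axis=1))
assert np.allclose(U4 @ V4.T, r_tilde[:, None] * f_tilde)
```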

D.6 Fast Computation in Almost Linear Time

Theorem D.6 (Main result, formal version of Theorem 1.6).

Assuming the entries of A1,A2,X,A3,Y,EA_{1},A_{2},X,A_{3},Y,E are represented using O(logn)O(\log n) bits, there is a n1+o(1)n^{1+o(1)} time algorithm to solve 𝖠𝖠𝗍𝗍𝖫𝖦𝖢(n,d=O(logn),B=o(logn))\mathsf{AAttLGC}(n,d=O(\log n),B=o(\sqrt{\log n})) (see Definition 1.4) up to 1/poly(n)1/\operatorname{poly}(n) accuracy. In particular, our algorithm outputs a gradient vector g~d2\widetilde{g}\in\mathbb{R}^{d^{2}} such that dLdxg~1/poly(n)\|\frac{\mathrm{d}L}{\mathrm{d}x}-\widetilde{g}\|_{\infty}\leq 1/\operatorname{poly}(n).

Proof.

Recall the definitions of the n\times n matrices p(x) (Definition C.5), p_1(x) (see Lemma D.4) and p_2(x) (see Lemma D.5). It is straightforward that

p(x)=p1(x)p2(x).\displaystyle p(x)=p_{1}(x)-p_{2}(x).

Using Lemmas D.1, D.2 and D.3, we know that the assumptions of Lemma D.4 and Lemma D.5 hold, so we can apply them to obtain that

  • p1(x)p_{1}(x) has approximate low rank representation U3,V3U_{3},V_{3}, let p~1(x)\widetilde{p}_{1}(x) denote U3V3U_{3}V_{3}^{\top}

  • p2(x)p_{2}(x) has approximate low rank representation U4,V4U_{4},V_{4}, let p~2(x)\widetilde{p}_{2}(x) denote U4V4U_{4}V_{4}^{\top}

All of the constructions in Lemmas D.1, D.2, D.3, D.4 and D.5 take n^{1+o(1)} time.

According to the proof of Lemma C.7, we have

\displaystyle\frac{\mathrm{d}L(X)}{\mathrm{d}x}=\operatorname{vec}(A_1^{\top}p(x)A_2)

Thus, we first compute A_1^{\top}U_3V_3^{\top}A_2:

  • We compute A1U3d×k3A_{1}^{\top}U_{3}\in\mathbb{R}^{d\times k_{3}}, this takes n1+o(1)n^{1+o(1)} time

  • We compute V3A2k3×dV_{3}^{\top}A_{2}\in\mathbb{R}^{k_{3}\times d}, this takes n1+o(1)n^{1+o(1)} time

  • Compute (A1U3)(V3A2)(A_{1}^{\top}U_{3})\cdot(V_{3}^{\top}A_{2}), this takes d2no(1)d^{2}n^{o(1)} time

Second, we compute A_1^{\top}U_4V_4^{\top}A_2:

  • We compute A1U4d×k4A_{1}^{\top}U_{4}\in\mathbb{R}^{d\times k_{4}}, this takes n1+o(1)n^{1+o(1)} time

  • We compute V4A2k4×dV_{4}^{\top}A_{2}\in\mathbb{R}^{k_{4}\times d}, this takes n1+o(1)n^{1+o(1)} time

  • Compute (A1U4)(V4A2)(A_{1}^{\top}U_{4})\cdot(V_{4}^{\top}A_{2}), this takes d2no(1)d^{2}n^{o(1)} time

So the overall running time is still n^{1+o(1)}. We output \widetilde{g}:=\operatorname{vec}(A_1^{\top}\widetilde{p}(x)A_2), where \widetilde{p}(x):=\widetilde{p}_1(x)-\widetilde{p}_2(x).

We have

dL(X)dxg~=\displaystyle\|\frac{\mathrm{d}L(X)}{\mathrm{d}x}-\widetilde{g}\|_{\infty}= vec(A1p(x)A2)vec(A1p~(x)A2)\displaystyle~{}\|\operatorname{vec}(A_{1}^{\top}p(x)A_{2})-\operatorname{vec}(A_{1}^{\top}\widetilde{p}(x)A_{2})\|_{\infty}
=\displaystyle= A1p(x)A2A1p~(x)A2\displaystyle~{}\|A_{1}^{\top}p(x)A_{2}-A_{1}^{\top}\widetilde{p}(x)A_{2}\|_{\infty}
=\displaystyle= A1(p1(x)p2(x))A2A1(p~1(x)p~2(x))A2\displaystyle~{}\|A_{1}^{\top}(p_{1}(x)-p_{2}(x))A_{2}-A_{1}^{\top}(\widetilde{p}_{1}(x)-\widetilde{p}_{2}(x))A_{2}\|_{\infty}
\displaystyle\leq A1(p1(x)p~1(x))A2+A1(p2(x)p~2(x))A2\displaystyle~{}\|A_{1}^{\top}(p_{1}(x)-\widetilde{p}_{1}(x))A_{2}\|_{\infty}+\|A_{1}^{\top}(p_{2}(x)-\widetilde{p}_{2}(x))A_{2}\|_{\infty}
\displaystyle\leq A1A2n2(p1(x)p~1(x)+p2(x)p~2(x))\displaystyle~{}\|A_{1}\|_{\infty}\|A_{2}\|_{\infty}\cdot n^{2}\cdot(\|p_{1}(x)-\widetilde{p}_{1}(x)\|_{\infty}+\|p_{2}(x)-\widetilde{p}_{2}(x)\|_{\infty})
\displaystyle\leq ϵ/poly(n)\displaystyle~{}\epsilon/\operatorname{poly}(n)

where the 4th step follows from the triangle inequality, the 5th step follows from the fact that each entry of A_1^{\top}MA_2 is a sum of at most n^{2} terms, and the last step follows from the bounded entries of A_1,A_2 together with \|p_1(x)-\widetilde{p}_1(x)\|_{\infty}\leq\epsilon/\operatorname{poly}(n) and \|p_2(x)-\widetilde{p}_2(x)\|_{\infty}\leq\epsilon/\operatorname{poly}(n).

Picking \epsilon=1/\operatorname{poly}(n) completes the proof. ∎
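To summarize the algorithmic part of the proof, the hedged sketch below (our own illustration, with random stand-in factors) shows how the approximate gradient is assembled: the thin factors are multiplied into A_1 and A_2 first, so no n\times n matrix is ever formed.

```python
import numpy as np

def row_kron(U, W):
    """Row-wise Kronecker product: row j of the result is kron(U[j], W[j])."""
    return (U[:, :, None] * W[:, None, :]).reshape(U.shape[0], -1)

n, d, k1, k2 = 8, 3, 4, 5
rng = np.random.default_rng(0)
A1, A2 = rng.standard_normal((2, n, d))
U1, V1 = rng.standard_normal((2, n, k1))   # stand-in factors for f(x)
U2, V2 = rng.standard_normal((2, n, k2))   # stand-in factors for q(x)

# p~1 = (U1 V1^T) o (U2 V2^T) via U3 = U1 (+) U2, V3 = V1 (+) V2   (Lemma D.4)
U3, V3 = row_kron(U1, U2), row_kron(V1, V2)
# p~2 = diag(r~) U1 V1^T via U4 = diag(r~) U1, V4 = V1              (Lemma D.5)
r_tilde = np.einsum('jk,kl,jl->j', U1, V1.T @ V2, U2)
U4, V4 = r_tilde[:, None] * U1, V1

# g~ = vec(A1^T (p~1 - p~2) A2), multiplying the thin matrices first
g_tilde = ((A1.T @ U3) @ (V3.T @ A2) - (A1.T @ U4) @ (V4.T @ A2)).reshape(-1)

# dense reference, only for checking on this toy instance
p_dense = (U1 @ V1.T) * (U2 @ V2.T) - r_tilde[:, None] * (U1 @ V1.T)
assert np.allclose(g_tilde, (A1.T @ p_dense @ A2).reshape(-1))
```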

Acknowledgments

The authors would like to thank Yichuan Deng for helpful discussions.

References

  • AA [22] Amol Aggarwal and Josh Alman. Optimal-degree polynomial approximations for exponentials and gaussian kernel density estimation. In 37th Computational Complexity Conference (CCC 2022). Schloss Dagstuhl-Leibniz-Zentrum für Informatik, 2022.
  • ACSS [20] Josh Alman, Timothy Chu, Aaron Schild, and Zhao Song. Algorithms and hardness for linear algebra on geometric graphs. In 2020 IEEE 61st Annual Symposium on Foundations of Computer Science (FOCS), pages 541–552. IEEE, 2020.
  • ALS+ [23] Josh Alman, Jiehao Liang, Zhao Song, Ruizhe Zhang, and Danyang Zhuo. Bypass exponential time preprocessing: Fast neural network training via weight-data correlation preprocessing. In NeurIPS. arXiv preprint arXiv:2211.14227, 2023.
  • AS [23] Josh Alman and Zhao Song. Fast attention requires bounded entries. In NeurIPS, 2023.
  • AS [24] Josh Alman and Zhao Song. How to capture higher-order correlations? generalizing matrix softmax attention to kronecker computation. In ICLR, 2024.
  • BCS [97] Peter Bürgisser, Michael Clausen, and Mohammad A Shokrollahi. Algebraic complexity theory, volume 315. Springer Science & Business Media, 1997.
  • BIS [17] Arturs Backurs, Piotr Indyk, and Ludwig Schmidt. On the fine-grained complexity of empirical risk minimization: Kernel methods and neural networks. Advances in Neural Information Processing Systems (NeurIPS), 30, 2017.
  • Blä [13] Markus Bläser. Fast matrix multiplication. Theory of Computing, pages 1–60, 2013.
  • BMR+ [20] Tom Brown, Benjamin Mann, Nick Ryder, Melanie Subbiah, Jared D Kaplan, Prafulla Dhariwal, Arvind Neelakantan, Pranav Shyam, Girish Sastry, Amanda Askell, et al. Language models are few-shot learners. Advances in neural information processing systems, 33:1877–1901, 2020.
  • BPSW [21] Jan van den Brand, Binghui Peng, Zhao Song, and Omri Weinstein. Training (over- parametrized) neural networks in near-linear time. 12th Innovations in Theoretical Computer Science Conference (ITCS), 2021.
  • BSZ [23] Jan van den Brand, Zhao Song, and Tianyi Zhou. Algorithm and hardness for dynamic attention maintenance in large language models. arXiv preprint arXiv:2304.02207, 2023.
  • CKNS [20] Moses Charikar, Michael Kapralov, Navid Nouri, and Paris Siminelakis. Kernel density estimation through density constrained near neighbor search. In 2020 IEEE 61st Annual Symposium on Foundations of Computer Science (FOCS), pages 172–183. IEEE, 2020.
  • CLP+ [21] Beidi Chen, Zichang Liu, Binghui Peng, Zhaozhuo Xu, Jonathan Lingjie Li, Tri Dao, Zhao Song, Anshumali Shrivastava, and Christopher Ré. Mongoose: A learnable LSH framework for efficient neural network training. International Conference on Learning Representations, 2021.
  • CND+ [22] Aakanksha Chowdhery, Sharan Narang, Jacob Devlin, Maarten Bosma, Gaurav Mishra, Adam Roberts, Paul Barham, Hyung Won Chung, Charles Sutton, Sebastian Gehrmann, et al. Palm: Scaling language modeling with pathways. arXiv preprint arXiv:2204.02311, 2022.
  • DCLT [18] Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. Bert: Pre-training of deep bidirectional transformers for language understanding. arXiv preprint arXiv:1810.04805, 2018.
  • DGS [23] Yichuan Deng, Yeqi Gao, and Zhao Song. Solving tensor low cycle rank approximation. arXiv preprint arXiv:2304.06594, 2023.
  • DHS+ [22] Yichuan Deng, Hang Hu, Zhao Song, Omri Weinstein, and Danyang Zhuo. Training overparametrized neural networks in sublinear time. arXiv preprint arXiv:2208.04508, 2022.
  • DJS+ [19] Huaian Diao, Rajesh Jayaram, Zhao Song, Wen Sun, and David Woodruff. Optimal sketching for kronecker product regression and low rank approximation. Advances in neural information processing systems, 32, 2019.
  • DLMS [23] Yichuan Deng, Zhihang Li, Sridhar Mahadevan, and Zhao Song. Zero-th order algorithm for softmax attention optimization. arXiv preprint arXiv:2307.08352, 2023.
  • DMS [23] Yichuan Deng, Sridhar Mahadevan, and Zhao Song. Randomized and deterministic attention sparsification algorithms for over-parameterized feature dimension. arXiv preprint arXiv:2304.04397, 2023.
  • DSSW [18] Huaian Diao, Zhao Song, Wen Sun, and David Woodruff. Sketching for kronecker product regression and p-splines. In International Conference on Artificial Intelligence and Statistics, pages 1299–1308. PMLR, 2018.
  • DSY [23] Yichuan Deng, Zhao Song, and Junze Yin. Faster robust tensor power method for arbitrary order. arXiv preprint arXiv:2306.00406, 2023.
  • GQSW [24] Yeqi Gao, Lianke Qin, Zhao Song, and Yitan Wang. A sublinear adversarial training algorithm. In ICLR. arXiv preprint arXiv:2208.05395, 2024.
  • GSWY [23] Yeqi Gao, Zhao Song, Weixin Wang, and Junze Yin. A fast optimization view: Reformulating single layer attention in llm based on tensor and svm trick, and solving it in matrix multiplication time. arXiv preprint arXiv:2309.07418, 2023.
  • GSX [23] Yeqi Gao, Zhao Song, and Shenghao Xie. In-context learning for attention scheme: from single softmax regression to multiple softmax regression via a tensor trick. arXiv preprint arXiv:2307.02419, 2023.
  • [26] Yeqi Gao, Zhao Song, and Xin Yang. Differentially private attention computation. arXiv preprint arXiv:2305.04701, 2023.
  • [27] Yeqi Gao, Zhao Song, and Junze Yin. Gradientcoin: A peer-to-peer decentralized large language models. arXiv preprint arXiv:2308.10502, 2023.
  • GSYZ [23] Yeqi Gao, Zhao Song, Xin Yang, and Ruizhe Zhang. Fast quantum algorithm for attention computation. arXiv preprint arXiv:2307.08045, 2023.
  • HJK+ [23] Insu Han, Rajesh Jayaram, Amin Karbasi, Vahab Mirrokni, David P. Woodruff, and Amir Zandieh. Hyperattention: Long-context attention in near-linear time. arXiv preprint arXiv:2310.05869, 2023.
  • Inc [23] Adobe Inc. Adobe firefly. In Adobe. https://www.adobe.com/sensei/generative-ai/firefly.html, 2023.
  • IP [01] Russell Impagliazzo and Ramamohan Paturi. On the complexity of k-sat. Journal of Computer and System Sciences, 62(2):367–375, 2001.
  • KKL [20] Nikita Kitaev, Łukasz Kaiser, and Anselm Levskaya. Reformer: The efficient transformer. arXiv preprint arXiv:2001.04451, 2020.
  • KMZ [23] Praneeth Kacham, Vahab Mirrokni, and Peilin Zhong. Polysketchformer: Fast transformers via sketches for polynomial kernels. arXiv preprint arXiv:2310.01655, 2023.
  • KWH [23] Feyza Duman Keles, Pruthuvi Mahesakya Wijewardena, and Chinmay Hegde. On the computational complexity of self-attention. In International Conference on Algorithmic Learning Theory, pages 597–619. PMLR, 2023.
  • LOG+ [19] Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, and Veselin Stoyanov. Roberta: A robustly optimized bert pretraining approach. arXiv preprint arXiv:1907.11692, 2019.
  • Man [23] James Manyika. An overview of bard: an early experiment with generative ai. Technical report, Google AI, 2023.
  • MGN+ [23] Sadhika Malladi, Tianyu Gao, Eshaan Nichani, Alex Damian, Jason D Lee, Danqi Chen, and Sanjeev Arora. Fine-tuning language models with just forward passes. arXiv preprint arXiv:2305.17333, 2023.
  • PMXA [23] Abhishek Panigrahi, Sadhika Malladi, Mengzhou Xia, and Sanjeev Arora. Trainable transformer in transformer. arXiv preprint arXiv:2307.01189, 2023.
  • RSZ [22] Aravind Reddy, Zhao Song, and Lichen Zhang. Dynamic tensor product regression. In NeurIPS, 2022.
  • Rub [18] Aviad Rubinstein. Hardness of approximate nearest neighbor search. In Proceedings of the 50th annual ACM SIGACT symposium on theory of computing (STOC), pages 1260–1268, 2018.
  • SWZ [19] Zhao Song, David P Woodruff, and Peilin Zhong. Relative error tensor low rank approximation. In SODA. arXiv preprint arXiv:1704.08246, 2019.
  • SYZ [21] Zhao Song, Shuo Yang, and Ruizhe Zhang. Does preprocessing help training over-parameterized neural networks? 35th Conference on Neural Information Processing Systems, 2021.
  • SZZ [24] Zhao Song, Lichen Zhang, and Ruizhe Zhang. Training multi-layer over-parametrized neural network in subquadratic time. In ITCS. arXiv preprint arXiv:2112.07628, 2024.
  • TDFH+ [22] Romal Thoppilan, Daniel De Freitas, Jamie Hall, Noam Shazeer, Apoorv Kulshreshtha, Heng-Tze Cheng, Alicia Jin, Taylor Bos, Leslie Baker, Yu Du, et al. Lamda: Language models for dialog applications. arXiv preprint arXiv:2201.08239, 2022.
  • TLI+ [23] Hugo Touvron, Thibaut Lavril, Gautier Izacard, Xavier Martinet, Marie-Anne Lachaux, Timothée Lacroix, Baptiste Rozière, Naman Goyal, Eric Hambro, Faisal Azhar, Aurelien Rodriguez, Armand Joulin, Edouard Grave, and Guillaume Lample. Llama: Open and efficient foundation language models. arXiv preprint arXiv:2302.13971, 2023.
  • TMS+ [23] Hugo Touvron, Louis Martin, Kevin Stone, Peter Albert, Amjad Almahairi, Yasmine Babaei, Nikolay Bashlykov, Soumya Batra, Prajjwal Bhargava, Shruti Bhosale, et al. Llama 2: Open foundation and fine-tuned chat models. arXiv preprint arXiv:2307.09288, 2023.
  • VSP+ [17] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Łukasz Kaiser, and Illia Polosukhin. Attention is all you need. Advances in neural information processing systems, 30, 2017.
  • Wil [18] Virginia Vassilevska Williams. On some fine-grained questions in algorithms and complexity. In Proceedings of the international congress of mathematicians: Rio de janeiro 2018, pages 3447–3487. World Scientific, 2018.
  • WTB+ [22] Jason Wei, Yi Tay, Rishi Bommasani, Colin Raffel, Barret Zoph, Sebastian Borgeaud, Dani Yogatama, Maarten Bosma, Denny Zhou, Donald Metzler, et al. Emergent abilities of large language models. arXiv preprint arXiv:2206.07682, 2022.
  • YCRI [22] Ann Yuan, Andy Coenen, Emily Reif, and Daphne Ippolito. Wordcraft: story writing with large language models. In 27th International Conference on Intelligent User Interfaces, pages 841–852, 2022.
  • YDY+ [19] Zhilin Yang, Zihang Dai, Yiming Yang, Jaime Carbonell, Russ R Salakhutdinov, and Quoc V Le. Xlnet: Generalized autoregressive pretraining for language understanding. Advances in neural information processing systems, 32, 2019.
  • Zha [22] Lichen Zhang. Speeding up optimizations via data structures: Faster search, sample and maintenance. Master’s thesis, Carnegie Mellon University, 2022.
  • ZHDK [23] Amir Zandieh, Insu Han, Majid Daliri, and Amin Karbasi. Kdeformer: Accelerating transformers via kernel density estimation. In ICML. arXiv preprint arXiv:2302.02451, 2023.
  • ZRG+ [22] Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen, Christopher Dewan, Mona Diab, Xian Li, Xi Victoria Lin, et al. Opt: Open pre-trained transformer language models. arXiv preprint arXiv:2205.01068, 2022.