The Fine-Grained Complexity of Gradient Computation for Training Large Language Models
Large language models (LLMs) have made fundamental contributions over the last few years. To train an LLM, one needs to alternately run ‘forward’ computations and ‘backward’ computations. The forward computation can be viewed as attention function evaluation, and the backward computation can be viewed as a gradient computation. In previous work by [Alman and Song, NeurIPS 2023], it was proved that the forward step can be performed in almost-linear time in certain parameter regimes, but that there is no truly sub-quadratic time algorithm in the remaining parameter regimes unless the popular Strong Exponential Time Hypothesis (SETH) is false. In this work, we show nearly identical results for the harder-seeming problem of computing the gradient of the loss function of a one-layer attention network, and thus for the entire process of LLM training. This completely characterizes the fine-grained complexity of every step of LLM training.
1 Introduction
Large language models (LLMs) have emerged as popular technologies, driving breakthroughs across many applications in natural language processing, computer vision, translation, and many other areas [47, 15, 35, 51, 9, 54, 14, 45, 46, 30, 36, 44, 50, 49]. The training of these models is a computationally intensive process, characterized by alternating between two primary operations: forward computation and backward computation. Forward computation, or function evaluation, involves the propagation of input data through the network to generate predictions. Conversely, backward computation, or gradient computation, is the process of calculating the gradient of the loss function with respect to the model’s parameters, facilitating the optimization of these parameters during training.
The efficiency of these computations directly impacts the feasibility and scalability of training LLMs, particularly as models grow in size and complexity. Recent work by [4, 5] has carefully studied the forward computation step. They demonstrated a sharp computational boundary, showing that the speed at which the forward steps can be performed depends critically on how large the entries of the matrices defining the model parameters are. They showed a near-linear time algorithm when these entries are small, and also proved that when the entries are large, there is no algorithm much faster than the trivial algorithm, contingent upon the Strong Exponential Time Hypothesis (SETH) [31] holding true. This finding underscores a fundamental limitation in accelerating the training of LLMs, raising pivotal questions about the inherent computational complexity of these models.
The Strong Exponential Time Hypothesis (SETH) was introduced by Impagliazzo and Paturi [31] over 20 years ago. It is a strengthening of the $\mathsf{P} \neq \mathsf{NP}$ conjecture, and asserts that our current best algorithms for the Boolean satisfiability problem are roughly optimal (for a detailed statement, see Hypothesis 3.3 below). SETH is a popular conjecture from fine-grained complexity theory which has been used to prove lower bounds for a wide variety of algorithmic problems. See, for instance, the survey [48].
In other words, in some parameter regimes, the algorithm of [4] performs the forward steps about as quickly as one could hope for, whereas in other regimes, assuming SETH, it is impossible to design a nontrivially fast algorithm. However, this leaves open many important questions about LLM training. In the case when forward computation can be done quickly, can the same be said for backward computation? If not, then the entire training process would still be slow. Relatedly, in parameter regimes where forward computation is known to be hard, is backward computation also hard? If not, perhaps heuristic tricks could be used, or other details of the model could be modified, to speed up the overall training. As we will see shortly, the backward step is defined in a much more complicated way than the forward step, and it is not evident that algorithms or lower bounds for one extend to the other.
Our study aims to resolve these questions and determine the fine-grained complexity of the backward computation phase. Our main result (which we state more formally shortly) shows that the same computational threshold from forward computation also arises for the backward problem, and that the problems are easy (or hard) in the exact same parameter regimes. Thus, the forward algorithm of [4] can be combined with our novel backward algorithm to perform each training step for LLMs in near-linear time when the parameter matrix entries are small enough, whereas when the entries are not small enough, neither step can be performed quickly.
In addition to characterizing the fine-grained complexity of LLM training, our result for gradient computation is novel for a few reasons.
• There has been previous work on algorithms for backward/gradient computation [10, 42, 17, 3, 23, 43]. That said, most of these works focus on backward computation in other settings. The only previous work we are aware of that studies the optimization of attention layers (for LLMs) is [24], which uses a Newton method that relies on Hessian computation. However, Hessian computation is substantially more expensive than gradient computation; our results apply to gradient computation and get around the Hessian “barrier”, allowing for faster algorithms in some parameter regimes, and more powerful lower bounds in others.
1.1 Problem Definition
Before formally stating our results, we begin by precisely defining the problems we study. The first is the general problem of computing an attention forward layer.
Definition 1.1 ($i$-th layer forward computation).
Given weights $Q_i, K_i, V_i \in \mathbb{R}^{d \times d}$, and letting $X_i \in \mathbb{R}^{n \times d}$ denote the $i$-th layer input, the next layer $X_{i+1}$ is defined recursively as
$$X_{i+1} := D_i^{-1} \exp(X_i Q_i K_i^\top X_i^\top) X_i V_i,$$
where
• $D_i := \operatorname{diag}(\exp(X_i Q_i K_i^\top X_i^\top) \mathbf{1}_n)$,
• $\exp$ denotes the exponential function applied entry-wise, i.e., $\exp(A)_{i,j} := \exp(A_{i,j})$ for all matrices $A$,
• the $\operatorname{diag}$ operation takes a vector as input and generates a diagonal matrix with the entries of that vector, and
• $\mathbf{1}_n$ denotes the length-$n$ all-ones vector.
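To make the recursion concrete, here is a minimal NumPy sketch of a single forward layer as we read Definition 1.1; the variable names, and the absence of any scaling inside the exponential, are our own simplifications rather than the paper's exact conventions.

```python
import numpy as np

def attention_layer(X, Q, K, V):
    """One forward layer: D^{-1} exp(X Q K^T X^T) X V, with exp applied entry-wise."""
    A = np.exp(X @ Q @ K.T @ X.T)          # n x n matrix, entry-wise exponential
    D = A.sum(axis=1, keepdims=True)       # diag(A 1_n), stored as the column of row sums
    return (A / D) @ (X @ V)               # row-normalize, then apply the value projection

rng = np.random.default_rng(0)
n, d = 8, 4
X0 = rng.standard_normal((n, d))
Q, K, V = (rng.standard_normal((d, d)) for _ in range(3))
X1 = attention_layer(X0, Q, K, V)          # the next layer's input
print(X1.shape)                            # (8, 4)
```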
In mathematical terms, after renaming the matrices for convenience, optimization in the context of attention computation is described as follows:
Definition 1.2 (Attention optimization).
Given four $n \times d$ size matrices $A_1, A_2, A_3$ and $E$. Suppose that a $d \times d$ size square matrix $Y$ is also given. The attention optimization problem is formulated as:
$$\min_{X \in \mathbb{R}^{d \times d}} L(X) := \big\| D(X)^{-1} \exp(A_1 X A_2^\top) A_3 Y - E \big\|_F^2 .$$
Here $D(X)$ is
$$D(X) := \operatorname{diag}\big(\exp(A_1 X A_2^\top) \mathbf{1}_n\big),$$
and $\| \cdot \|_F^2$ denotes the squared Frobenius norm, i.e., $\|M\|_F^2 := \sum_{i,j} M_{i,j}^2$.
Remark 1.3.
In principle, the loss function above, and the resulting gradients below, should depend on both $X$ and $Y$. However, since the final matrix computed in the norm in $L$ depends only linearly on $Y$, it is straightforward to incorporate it into either an algorithm or a lower bound. Thus, in this work, we focus on the case where $X$ is a variable and $Y$ is a fixed input to simplify some arguments.
We thus define Approximate Attention Loss function Gradient Computation problem as follows:
Definition 1.4 (Approximate Attention Loss Gradient Computation).
Given four $n \times d$ size matrices $A_1, A_2, A_3, E$, and a square matrix $Y \in \mathbb{R}^{d \times d}$, all treated as fixed input matrices. Assume that $d = O(\log n)$ and that the magnitudes of all entries are bounded by $B$. Assume all numbers (in the matrices) are in the $O(\log n)$-bit model. Let $L(X)$ be defined as in Definition 1.2. Let $\frac{\mathrm{d} L(X)}{\mathrm{d} X}$ denote the gradient of the loss function $L(X)$.
The goal is to output a vector $\widetilde{g} \in \mathbb{R}^{d^2}$ such that
$$\Big\| \widetilde{g} - \operatorname{vec}\Big(\frac{\mathrm{d} L(X)}{\mathrm{d} X}\Big) \Big\|_\infty \le \epsilon .$$
Here, for a matrix (or vector) $M$, $\|M\|_\infty := \max_{i,j} |M_{i,j}|$.
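The sketch below spells out, in NumPy, our reading of the objects in Definitions 1.2 and 1.4: the attention loss as a function of the variable matrix, a slow finite-difference reference gradient, and the entry-wise ($\ell_\infty$) acceptance test. All names are ours, we drop any rescaling inside the exponential, and we take the fixed square matrix $Y$ to be the identity for simplicity.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 6, 3
A1, A2, A3, E = (rng.standard_normal((n, d)) for _ in range(4))   # fixed inputs
X = rng.standard_normal((d, d))                                   # the variable

def attention_loss(X):
    A = np.exp(A1 @ X @ A2.T)                     # entry-wise exponential, n x n
    F = A / A.sum(axis=1, keepdims=True)          # D(X)^{-1} A (row-wise normalization)
    return np.sum((F @ A3 - E) ** 2)              # squared Frobenius norm (Y = identity)

def finite_difference_grad(loss, X, eps=1e-6):
    """Slow but trustworthy reference gradient, one loss evaluation pair per entry."""
    g = np.zeros_like(X)
    for i in range(X.shape[0]):
        for j in range(X.shape[1]):
            Xp, Xm = X.copy(), X.copy()
            Xp[i, j] += eps
            Xm[i, j] -= eps
            g[i, j] = (loss(Xp) - loss(Xm)) / (2 * eps)
    return g

def accepted(g_approx, g_true, eps):
    """The Definition 1.4 guarantee, as we read it: entry-wise error at most eps."""
    return np.max(np.abs(g_approx - g_true)) <= eps

g_ref = finite_difference_grad(attention_loss, X)
print(accepted(g_ref, g_ref, 1e-8))               # trivially True for the reference itself
```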
1.2 Main Results
Our main results show that there is a threshold in the computational complexity of Approximate Attention Loss Gradient Computation depending on the bound $B$ on the magnitudes of the entries. When $B = o(\sqrt{\log n})$ we give a new near-linear-time algorithm, and when $B = \Theta(\sqrt{\log n})$ (or larger), we show that such an algorithm is impossible assuming SETH. This matches the results of [4], where a nearly identical threshold at around $B = \Theta(\sqrt{\log n})$ was also observed. Our results therefore imply that the entire LLM training process has this computational threshold.
Theorem 1.5 (Main result, Lower bound, informal version of Theorem 5.5).
Assuming SETH, when $B = \Theta(\sqrt{\log n})$, there is no algorithm running in time $O(n^{2-\delta})$ for any constant $\delta > 0$ for Approximate Attention Loss Gradient Computation (see Definition 1.4).
Theorem 1.6 (Main result, Upper bound, informal version of Theorem D.6).
Assuming all entries are bounded by $B = o(\sqrt{\log n})$, there is an $n^{1+o(1)}$-time algorithm to solve Approximate Attention Loss Gradient Computation (see Definition 1.4) up to $1/\operatorname{poly}(n)$ accuracy.
Our new algorithm (Theorem 1.6) builds on a low-rank approximation for the attention matrix from prior work [1, 4]. Incorporating these approximations into the gradient computation is not straightforward; in the forward problem, one simply multiplies the attention matrix by an input value matrix, but in the backward problem, it is combined with other matrices in an intricate (non-linear) way. We ultimately use tools from tensor algebra to get a handle on the entry-wise products and high-rank sparse matrices which arise in the gradient computation but do not typically preserve the needed low-rank structure.
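As a warm-up for why such low-rank structure helps, the sketch below contrasts the naive forward computation with the re-associated one once a factorization $U V^\top$ of the attention matrix is available. Here $U, V$ are random positive stand-ins for the factors produced by the polynomial method of [1, 4]; only the associativity trick is illustrated, not the construction of the factors.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, k = 2000, 16, 32
W = rng.standard_normal((n, d))                  # the matrix the attention matrix multiplies
U = rng.random((n, k)) + 0.1                     # positive stand-in low-rank factors:
V = rng.random((n, k)) + 0.1                     #   U @ V.T plays the role of exp(Q K^T)

# naive forward: materialize the n x n matrix, normalize rows, multiply -- O(n^2 d)
M = U @ V.T
out_naive = (M / (M @ np.ones(n))[:, None]) @ W

# re-associated forward: only n x k and k x d products appear -- O(n k (k + d))
row_sums = U @ (V.T @ np.ones(n))                # (U V^T) 1_n without forming U V^T
out_fast = (U @ (V.T @ W)) / row_sums[:, None]

print(np.max(np.abs(out_fast - out_naive)))      # agreement up to floating-point roundoff
```

The gradient computation needs more than this single re-association, which is exactly where the tensor-algebra tools mentioned above come in.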
Our new lower bound (Theorem 1.5) comes from a careful reduction from a special case of the forward problem (where hardness is known from prior work) to the backward problem. Reducing from computing a function to computing its gradient in general is quite challenging or impossible without control over how quickly the gradient may be growing or changing, and in general, the gradient of the forward (attention) computation can behave quite erratically (which is likely necessary for the expressive power of attention units). Nonetheless, in the special case of the inputs for which attention computation is known to be hard from prior work, we are able to reasonably control the growth of these gradients and successfully perform our reduction.
Roadmap. We discuss other related works in Section 2. In Section 3, we provide the basic notation, definitions, backgrounds, and facts which we will use. In Section 4, we provide the proof sketch of our algorithm and defer the details to the Appendix. In Section 5, we provide our main lower bound result. In Section 6, we briefly conclude our paper.
2 Related Work
Fine-grained Complexity
Numerous algorithmic techniques have been used in theory and in practice for attention computations. The first algorithm with provable guarantees, by Zandieh, Han, Daliri, and Karbasi [53], used locality sensitive hashing (LSH) techniques [12], while later work by Alman and Song [4] used polynomial approximation methods [2, 1]. We particularly focus here on the latter technique, which is the only algorithm we’re aware of which achieves near-linear running time.
Keles, Wijewardena, and Hegde [34] established the first lower bound on attention computation under the assumption of SETH. Their findings demonstrated that when $d = \omega(\log n)$, it is not possible to execute forward computations in subquadratic time. The later lower bound of [4] further incorporated the magnitudes of the input entries into the lower bound to tightly match the aforementioned algorithms. Both use the high-level technique of [7] from kernel density estimation, and build on methods derived from fine-grained complexity associated with approximate nearest neighbor search [40] and the polynomial method [1].
Fast Attention Computation
Optimizing the computation of attention mechanisms in pre-trained LLMs, given their extensive parameter sets, has been a focal point of recent research. Various studies have explored the application of locality sensitive hashing (LSH) techniques to approximate attention mechanisms. [32] introduced two methods to enhance computational efficiency, including the use of LSH to replace dot product attention and a reversible residual layer to substitute the standard residual layer. [13] refined this approximation, noting that LSH’s efficiency does not require constant parameter updates. [53] proposed an innovative estimator based on Kernel Density Estimation (KDE) to speed up the softmax function and matrix multiplication computations. Some recent works [29, 33] have specifically used sketching techniques to avoid large entries in the attention matrix. [38] developed techniques utilizing a transformer within a transformer (TinT) model to simulate the transformer’s forward and backward passes, significantly increasing parameter efficiency. [37] tackled the challenge of fine-tuning LLMs with high memory demands by improving the classical ZO-SCD optimizer, creating a memory-efficient gradient estimator that requires only a forward pass. [11] provided insights into dynamic attention problems, giving both an algorithm and a hardness result for the dynamic setting of the attention problem. [28] introduced a quantum algorithm for attention computation, opening new avenues for efficiency improvements. [26] provided a result for computing the attention matrix differentially privately. [20] introduced randomized and deterministic attention sparsification algorithms for over-parameterized feature dimensions. [19] provided a zeroth-order method to accelerate the computation of attention.
Transformer Training
Transformer architectures (the backbone of LLMs) have been trained with alternating steps of forward and backward computations since their introduction [47, 15, 35, 51, 9, 54]. In Appendix B below, we perform computations to verify that our stated problems are the same as the forward and backward steps from the literature.
3 Preliminary
In Section 3.1, we define some basic notation we will use. In Section 3.2, we state important facts related to fast matrix multiplication. In Section 3.3, we provide the formal definition of the Strong Exponential Time Hypothesis. In Section 3.4, we define several intermediate functions related to softmax and exponential which will arise in our algorithms. In Section 3.5, we define the loss function. In Section 3.6, we provide standard tensor tricks which we will use. In Section 3.7, we show how to reformulate the loss function for our purposes.
3.1 Notation
For any positive integer $n$, we define $[n] := \{1, 2, \ldots, n\}$. For two vectors $a$ and $b$ of the same length, we use $\langle a, b \rangle$ to denote their inner product, i.e., $\langle a, b \rangle := \sum_i a_i b_i$. We use $a \circ b$ to denote the vector whose $i$-th entry is $a_i b_i$. Let $\mathbf{1}_n$ denote the length-$n$ all-ones vector. It is not hard to see that $\langle a \circ b, \mathbf{1}_n \rangle = \langle a, b \rangle$. For a vector $x$, we use $x^\top$ to denote the transpose of $x$. For a matrix $A$, we use $A^\top$ to denote the transpose of matrix $A$. For a vector $x$, we use $\exp(x)$ to denote the vector whose $i$-th coordinate is $\exp(x_i)$. For a matrix $A$, we use $\exp(A)$ to denote the matrix whose $(i,j)$-th coordinate is $\exp(A_{i,j})$. For a function $f$, we use $\widetilde{O}(f)$ to denote $f \cdot \operatorname{poly}(\log f)$. Let $n_1, n_2, d_1, d_2$ be positive integers. Let $A \in \mathbb{R}^{n_1 \times d_1}$ and $B \in \mathbb{R}^{n_2 \times d_2}$. We define the Kronecker product between matrices $A$ and $B$, denoted $A \otimes B \in \mathbb{R}^{n_1 n_2 \times d_1 d_2}$, as $(A \otimes B)_{(i_1 - 1) n_2 + i_2, (j_1 - 1) d_2 + j_2}$ is equal to $A_{i_1, j_1} B_{i_2, j_2}$, where $i_1 \in [n_1], i_2 \in [n_2], j_1 \in [d_1], j_2 \in [d_2]$.
3.2 Matrix Multiplication
We define matrix multiplication notation and state some well-known facts here.
Definition 3.1.
Let $a, b, c$ denote any three positive integers. We use $\mathcal{T}_{\mathrm{mat}}(a, b, c)$ to denote the time of multiplying an $a \times b$ matrix with another $b \times c$ matrix.
It is well-known that $\mathcal{T}_{\mathrm{mat}}(a, b, c) = O(\mathcal{T}_{\mathrm{mat}}(a, c, b)) = O(\mathcal{T}_{\mathrm{mat}}(b, a, c)) = O(\mathcal{T}_{\mathrm{mat}}(b, c, a)) = O(\mathcal{T}_{\mathrm{mat}}(c, a, b)) = O(\mathcal{T}_{\mathrm{mat}}(c, b, a))$.
3.3 Backgrounds on Complexity
Over 20 years ago, Impagliazzo and Paturi [31] introduced the Strong Exponential Time Hypothesis (SETH), a strengthening of the $\mathsf{P} \neq \mathsf{NP}$ conjecture. It posits that the existing algorithms for solving satisfiability problems are essentially as efficient as possible:
Hypothesis 3.3 (Strong Exponential Time Hypothesis (SETH)).
For every $\epsilon > 0$, there exists a positive integer $k \ge 3$ for which solving $k$-$\mathsf{SAT}$ problems with $n$ variables in $O(2^{(1-\epsilon) n})$ time is impossible, including with the use of randomized algorithms.
SETH, a widely recognized conjecture, has been instrumental in establishing fine-grained lower bounds across a broad spectrum of algorithmic challenges, as highlighted in the survey [48].
3.4 Definitions related with Softmax
We now give some definitions related to softmax which will be helpful. Let $\operatorname{vec}(X)$ denote the vectorization of a matrix $X$.
Definition 3.4.
Let be two matrices. Suppose that . We define be a size sub-block from . Note that there such sub-blocks.
For every , let us define function to be:
Definition 3.5.
Suppose that there are two size matrices . We define be a size sub-block from . (Recall that .)
For every index , we consider a function, as:
Definition 3.6.
Suppose that is defined as in Definition 3.5.
Recall is defined as in Definition 3.4.
For a fixed , let us consider function
Let denote the matrix where -th row is .
Definition 3.7.
For every , we define as:
Here let denote the matrix representation of . Let matrix where column is .
3.5 Loss Functions
In this section, we introduce some helpful definitions related to the loss function.
Definition 3.8.
For every , we use to denote the normalized vector defined by Definition 3.6. For every , we let to be defined in Definition 3.7.
Consider every , every . Let us consider as follows:
Here is the -th coordinate/location of for . This is equivalent to .
Definition 3.9.
For every , for every . Let us define to be .
3.6 Tensor Trick
We state the well-known tensor trick. It has been widely used in the literature on linear algebra and tensor computations [41, 21, 18, 5, 25, 52, 39, 27, 22, 16].
Fact 3.10 (Tensor trick).
For two matrices $A_1 \in \mathbb{R}^{n \times d}$ and $A_2 \in \mathbb{R}^{n \times d}$, define $\mathsf{A} := A_1 \otimes A_2$. Let $X \in \mathbb{R}^{d \times d}$. Let $x := \operatorname{vec}(X) \in \mathbb{R}^{d^2}$ denote the vector representation of $X$. Then we have $\exp(\mathsf{A} x) = \operatorname{vec}(\exp(A_1 X A_2^\top)) \in \mathbb{R}^{n^2}$.
Using the above tensor-trick, it is easy to observe that
Fact 3.11.
For two matrices $A_1$ and $A_2$, denote $\mathsf{A} := A_1 \otimes A_2$. Let $X \in \mathbb{R}^{d \times d}$. Let $\mathsf{A}_{j_0} \in \mathbb{R}^{n \times d^2}$ be a submatrix of $\mathsf{A}$ (obtained by properly selecting rows of $\mathsf{A}$). Let $x := \operatorname{vec}(X)$ denote the vector representation of $X$. Then, we have
• $\exp(\mathsf{A} x) = \operatorname{vec}(\exp(A_1 X A_2^\top)) \in \mathbb{R}^{n^2}$, and
• $\exp(\mathsf{A}_{j_0} x) = \exp((A_1)_{j_0, *} X A_2^\top)^\top \in \mathbb{R}^{n}$.
Here $(A_1)_{j_0, *}$ is the $j_0$-th row of matrix $A_1$.
Proof.
This follows directly from the definitions and the tensor trick in Fact 3.10. ∎
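A quick numerical check of the tensor trick is shown below; it uses NumPy's row-major (C-order) vectorization, which is the convention under which $\operatorname{vec}(A_1 X A_2^\top) = (A_1 \otimes A_2)\operatorname{vec}(X)$ holds.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 5, 4
A1, A2 = rng.standard_normal((n, d)), rng.standard_normal((n, d))
X = rng.standard_normal((d, d))

lhs = np.exp(np.kron(A1, A2) @ X.flatten())        # exp((A1 (x) A2) vec(X)), length n^2
rhs = np.exp(A1 @ X @ A2.T).flatten()              # vec(exp(A1 X A2^T))
print(np.max(np.abs(lhs - rhs)))                   # tiny: the two sides agree entry-wise
```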
3.7 Reshape the Loss function via Tensor Trick
Lemma 3.12.
Given the below requirements:
• there are three matrices $A_1, A_2, A_3 \in \mathbb{R}^{n \times d}$;
• let $\mathsf{A} := A_1 \otimes A_2$ be the Kronecker product of the two matrices $A_1$ and $A_2$;
  – for every $j_0 \in [n]$, define $\mathsf{A}_{j_0} \in \mathbb{R}^{n \times d^2}$ to be an $n \times d^2$-sized block in the matrix $\mathsf{A}$;
• let $E \in \mathbb{R}^{n \times d}$ be a matrix, and define $E_{j_0, i_0}$ as the $(j_0, i_0)$-th coordinate/location of $E$ for every pair of $j_0 \in [n]$ and $i_0 \in [d]$;
• there are two square matrices $X, Y \in \mathbb{R}^{d \times d}$; let $x := \operatorname{vec}(X)$;
• let $L(X)$ be defined as in Definition 1.2;
• for every pair of $j_0 \in [n]$, $i_0 \in [d]$, recall that the definition of $c(x)_{j_0, i_0}$ can be found in Definition 3.9.
Then, we have
$$L(X) = \sum_{j_0 = 1}^{n} \sum_{i_0 = 1}^{d} c(x)_{j_0, i_0}^2 .$$
4 Proof Sketch for General Upper Bound
The most straightforward way to compute the gradient would explicitly write down the $n \times n$ attention matrix and manipulate it separately for each coordinate of the gradient, which takes more than quadratic time. By using fast matrix multiplication and regrouping the entries, we can obtain our first intermediate algorithm, which computes the gradient in quadratic time.
Lemma 4.1 (Attention gradient computation, informal version of Lemma C.8).
If the following conditions hold:
• define four $n \times d$ size matrices $A_1, A_2, A_3, E$ and two $d \times d$ square matrices to be input fixed matrices;
• let $X$ and $Y$ denote matrix variables (we will compute the gradient with respect to $X$);
  – for ease of writing, we also use vector variables $x := \operatorname{vec}(X) \in \mathbb{R}^{d^2}$ and $y := \operatorname{vec}(Y) \in \mathbb{R}^{d^2}$;
• let $L(X)$ be as in Definition 1.2 (we abuse notation and treat $L(X)$ and $L(x)$ as the same thing);
then we can show that the gradient $\frac{\mathrm{d} L}{\mathrm{d} x} \in \mathbb{R}^{d^2}$ can be calculated in $O(\mathcal{T}_{\mathrm{mat}}(n, d, n) + \mathcal{T}_{\mathrm{mat}}(n, n, d))$ time.
Next, we will show how to improve the running time of computing the gradient from quadratic time to almost-linear time $n^{1+o(1)}$.
Note that by linearity of derivative, we can show that
Based on calculations we perform in Section B, Section C, and several linear algebra facts, we can show that
For any fixed $j_0 \in [n]$, consider this quantity. Since this expression involves an $n \times n$ matrix, the most straightforward way to calculate it would take $O(n^2)$ time, and so summing over all $j_0$ would lead to a cubic-time algorithm. It is not too difficult to improve this: the matrix
is easily decomposed into a low-rank part and a sparse part (each of which has size $n \times n$), which reduces the calculation for each $j_0$ to only $O(n)$ time, and the total running time to $O(n^2)$ time.
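A prototypical instance of such a "diagonal plus rank-one" decomposition is the softmax-Jacobian-type matrix $\operatorname{diag}(f) - f f^\top$: each part can be applied to a vector in $O(n)$ time, so the $n \times n$ matrix never needs to be formed. A small sketch (our own toy check, not the paper's exact matrix):

```python
import numpy as np

rng = np.random.default_rng(0)
n = 1000
f = np.exp(rng.standard_normal(n)); f /= f.sum()   # a softmax-normalized vector
w = rng.standard_normal(n)

fast = f * w - f * (f @ w)                         # O(n): diagonal part + rank-one part
dense = (np.diag(f) - np.outer(f, f)) @ w          # O(n^2): forms the full matrix
print(np.max(np.abs(fast - dense)))                # agreement up to roundoff
```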
However, we are aiming for an almost-linear time algorithm, and it is not possible to achieve this by treating the different $j_0$ separately, since even a single $j_0$ must take $\Omega(n)$ time to process. Instead, we use tensor techniques related to low-rank approximations to simultaneously compute all the $j_0$ together and sum them in almost-linear time.
In order to do that, we create several extra artificial or intermediate matrices (see Section C). We will show that the gradient can finally be constructed using a simple chaining technique (see Section D for more details) that passes through these intermediate matrices one at a time. Intuitively, the chaining shows that a low-rank representation for one intermediate matrix yields one for the next, and so on.
In particular, we obtain that the gradient can be written in a form
which in fact notably removes the summation over $j_0$ from $1$ to $n$. With this reformulation, the only expensive object left to compute is a single $n \times n$ matrix sandwiched between $A_1^\top$ and $A_2$. Thus, as long as this matrix has a low-rank representation, we can solve the problem in almost-linear time (see Section D for more details). In particular, we will find that it is the entry-wise product of two matrices with low-rank representations from prior work, which we can combine using a column-wise Kronecker product to approximate the matrix itself.
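The sketch below (our own notation; the paper's intermediate matrices may differ in details) combines the two ingredients just described: the entry-wise product of two matrices given in low-rank form is again low-rank via the row-wise (face-splitting) Kronecker product of the factors, after which the sandwich between $A_1^\top$ and $A_2$ is computed by re-associating, with no $n \times n$ matrix ever materialized.

```python
import numpy as np

def row_kron(A, B):
    """Row-wise Kronecker product: row i of the result is kron(A[i], B[i])."""
    return np.einsum('ij,ik->ijk', A, B).reshape(A.shape[0], -1)

rng = np.random.default_rng(0)
n, d, k1, k2 = 1500, 16, 20, 8
A1, A2 = rng.standard_normal((n, d)), rng.standard_normal((n, d))
U1, V1 = rng.standard_normal((n, k1)), rng.standard_normal((n, k1))   # M1 = U1 V1^T
U2, V2 = rng.standard_normal((n, k2)), rng.standard_normal((n, k2))   # M2 = U2 V2^T

# slow reference: form the n x n matrices explicitly
M = (U1 @ V1.T) * (U2 @ V2.T)                      # entry-wise (Hadamard) product
slow = A1.T @ M @ A2

# fast path: factors of the entry-wise product, then re-associate the sandwich
U, V = row_kron(U1, U2), row_kron(V1, V2)          # M = U @ V.T, rank at most k1*k2
fast = (A1.T @ U) @ (V.T @ A2)                     # (d x k1k2) times (k1k2 x d)
print(np.max(np.abs(fast - slow)))                 # agreement up to roundoff
```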
5 General Lower Bound
We will critically make use of the known hardness result for attention computation itself, which we state now.
Definition 5.1 (Attention Computation).
Given as input matrices $Q, K, V \in \mathbb{R}^{n \times d}$ and a parameter $\epsilon > 0$, compute a matrix $T \in \mathbb{R}^{n \times d}$ satisfying
$$\| T - D^{-1} A V \|_\infty \le \epsilon,$$
where $A := \exp(Q K^\top) \in \mathbb{R}^{n \times n}$ and $D := \operatorname{diag}(A \mathbf{1}_n)$.
Lemma 5.2 (Lemma 4.7 in [4]).
Assuming SETH, there is no algorithm running in time $O(n^{2-\delta})$ for any constant $\delta > 0$ that solves Attention Computation (Definition 5.1), even when the inputs satisfy the following constraints, for any parameter $B = \Theta(\sqrt{\log n})$:
• $d = O(\log n)$,
• ,
• there is a value such that every entry of is in the interval and at least half the entries in each row of are equal to ,
• moreover , and
• .
Next, we show that the attention optimization problem behaves particularly well when given matrices constrained as in Lemma 5.2:
Lemma 5.3.
Let be a fixed matrix whose entries are real numbers in the interval , and such that in each row of , at least half the entries are equal to . Let be any matrix whose entries are all in . For , define the matrix , where is applied entry-wise. Define the function by
Then, for all we have
• ,
• .
Proof.
Let denote the matrix . For , we calculate that and so
For , let be the set of s in column of , i.e., . Hence, for and , the entry of the matrix is given by
where the first step follows from definition, the second step follows from simple algebra.
We thus get an explicit expression for :
We define
and then we define
Combining the above three equations, we can obtain
Since, for each row of , at least half the entries equal , and all the entries are in the interval , we can bound
(1)
Furthermore, since the derivative of with respect to is , we can bound
(2)
We may similarly bound
(3)
and
(4)
We can thus bound the derivative of (where here, all the ′ notation means derivative with respect to ):
where the 1st step follows from the definition, the 2nd step follows from simple algebra, the 3rd step follows from cancelling , the 4th step uses Eq. (1) (for ) and Eq. (4) (for ), the 5th step follows from simple algebra, and the last step follows from simple algebra.
Similarly, we can provide a lower bound,
where the 1st step follows from the definition, the 2nd step follows from simple algebra, the 3rd step follows from Eq. (2) (for ) and Eq. (3) (for ), the 4th step follows from simple algebra, and the last step follows from simple algebra.
Finally, letting , we have again by the quotient rule that is equal to
which we similarly bound in magnitude by . ∎
We recall a simple approximation from calculus:
Lemma 5.4.
Let be a twice-differentiable function such that for all . For any positive integer , define the sum
Then,
Proof.
If two have , then from our bound on , we know that . We can thus bound the difference
by
and
Thus, we complete the proof. ∎
Finally, we are ready for our main result:
Theorem 5.5 (Formal version of Theorem 1.5).
Let $B = B(n)$ be any function with and . Assuming SETH, there is no algorithm running in time $O(n^{2-\delta})$ for any constant $\delta > 0$ for Approximate Attention Loss Gradient Computation (Definition 1.4), even in the case where $d = O(\log n)$ and the input matrices satisfy , , , for some scalar , and .
Proof.
Suppose there were such an algorithm. We call it times to refute Lemma 5.2 (with parameter ). Let be the input matrices to Lemma 5.2, and set , , , , and for a parameter . Suppose the function is in Lemma 5.3 where is the matrix , so that is the matrix . It follows from Lemma 5.3 that
We can compute in time since then is the all-1s matrix, and our goal is to output .
6 Conclusion
Our results give a complete fine-grained analysis of the running time needed to train LLMs. We show that there is a threshold depending on the parameter $B$, the magnitude of the parameter matrix entries. In settings where $B$ is small, a near-linear-time algorithm for LLM training is possible by using our novel algorithm for backward computation. In settings where $B$ is large, not only does our algorithm not apply, but we show it is impossible to design a nontrivially-fast algorithm (barring a breakthrough in satisfiability algorithms that would refute the popular SETH).
These insights can guide LLM designers toward more efficient algorithms. When $B$ can be made small, this would lead to substantial savings in the computational resources needed for training. When $B$ must be large (perhaps to achieve high expressiveness), our lower bounds show that one may as well use straightforward algorithms and focus on other aspects of algorithmic speedup such as parallelization. The magnitude of $B$ needed has been studied more recently (e.g., [5]), and the need for fast training algorithms may further motivate this direction of research.
Appendix
Roadmap.
In Section A, we provide basic notation and facts. In Section B, we provide details about gradient computations. In Section C, we explain the computation time for the gradient of attention loss. In Section D, we show how to further improve the gradient computation from quadratic time to almost linear time.
Appendix A Preliminaries
In Section A.1, we define some basic notation. In Section A.2, we state several facts which we will use.
A.1 Notation
For any positive integer $n$, we define $[n] := \{1, 2, \ldots, n\}$.
For two vectors $a$ and $b$ of the same length, we use $\langle a, b \rangle$ to denote their inner product, i.e., $\langle a, b \rangle := \sum_i a_i b_i$. We use $a \circ b$ to denote the vector whose $i$-th entry is $a_i b_i$. Let $\mathbf{1}_n$ denote the length-$n$ all-ones vector. It is not hard to see that $\langle a \circ b, \mathbf{1}_n \rangle = \langle a, b \rangle$.
For a vector $x$, we use $x^\top$ to denote the transpose of $x$. For a matrix $A$, we use $A^\top$ to denote the transpose of matrix $A$.
For a vector $x$, we use $\exp(x)$ to denote the vector whose $i$-th coordinate is $\exp(x_i)$. For a matrix $A$, we use $\exp(A)$ to denote the matrix whose $(i,j)$-th coordinate is $\exp(A_{i,j})$.
We define the Kronecker product between matrices $A \in \mathbb{R}^{n_1 \times d_1}$ and $B \in \mathbb{R}^{n_2 \times d_2}$, denoted $A \otimes B \in \mathbb{R}^{n_1 n_2 \times d_1 d_2}$, as $(A \otimes B)_{(i_1 - 1) n_2 + i_2, (j_1 - 1) d_2 + j_2}$ is equal to $A_{i_1, j_1} B_{i_2, j_2}$, where $i_1 \in [n_1], i_2 \in [n_2], j_1 \in [d_1], j_2 \in [d_2]$.
For positive integers $a, b, c$, we use $\mathcal{T}_{\mathrm{mat}}(a, b, c)$ to denote the time of multiplying an $a \times b$ matrix with another $b \times c$ matrix.
A.2 Basic Facts
Fact A.1.
Let . Then we have
• .
• .
Fact A.2 (Folklore).
Let $U_1, V_1 \in \mathbb{R}^{n \times k_1}$. Let $U_2, V_2 \in \mathbb{R}^{n \times k_2}$. Then we have
$$(U_1 \oslash U_2)(V_1 \oslash V_2)^\top = (U_1 V_1^\top) \circ (U_2 V_2^\top).$$
Here, given $A \in \mathbb{R}^{n \times k_1}$ and $B \in \mathbb{R}^{n \times k_2}$, the matrix $A \oslash B \in \mathbb{R}^{n \times k_1 k_2}$ is the row-wise Kronecker product, i.e., $(A \oslash B)_{i, (j_1 - 1) k_2 + j_2} := A_{i, j_1} B_{i, j_2}$ for all $i \in [n]$, $j_1 \in [k_1]$, $j_2 \in [k_2]$, and $\circ$ denotes the entry-wise (Hadamard) product of matrices.
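A short numerical check of this identity (with the row-wise Kronecker product implemented explicitly; names are ours):

```python
import numpy as np

def row_kron(A, B):
    """Row i of the result is kron(A[i], B[i])."""
    return np.einsum('ij,ik->ijk', A, B).reshape(A.shape[0], -1)

rng = np.random.default_rng(0)
n, k1, k2 = 7, 3, 4
U1, V1 = rng.standard_normal((n, k1)), rng.standard_normal((n, k1))
U2, V2 = rng.standard_normal((n, k2)), rng.standard_normal((n, k2))

lhs = row_kron(U1, U2) @ row_kron(V1, V2).T        # (U1 (/) U2) (V1 (/) V2)^T
rhs = (U1 @ V1.T) * (U2 @ V2.T)                    # (U1 V1^T) o (U2 V2^T), entry-wise
print(np.max(np.abs(lhs - rhs)))                   # ~1e-15
```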
Appendix B More Details about Gradient Computation
In this section, we provide details and calculations to assist with gradient and derivative computations. We remark that, in this section, for convenience of computing a closed form for the gradient, we ignore the rescaling factor in the functions below. Since it is only a rescaling factor, it won't affect how we compute these matrices in general.
Lemma B.1 (The gradient computation for several different functions with respect to ).
For every , define to be the -th column for . . The scalar function , column function , scalar function and scalar function are defined as in Definitions 3.4, 3.5, 3.6, 3.8 and 3.9 respectively.
Then, for each , we have
• Part 1.
• Part 2. For each ,
• Part 3. For each ,
• Part 4. For each ,
• Part 5. For each ,
• Part 6. For each , for each ,
• Part 7. For each , for every ,
• Part 8. For each , for each ,
Proof.
Proof of Part 1. We have
Proof of Part 2. We have
Proof of Part 3.
We can show
where the 3rd step follows from Part 2, the last step follows from definition of .
Proof of Part 4.
For simplicity of writing proofs, we use to denote .
We can show
where the 1st step follows from definition of , the 2nd step follows from Part 3, the 3rd step follows from Fact A.1.
Proof of Part 5. For simplicity of writing proofs, we use to denote .
We can show that
For the first term, we have
where the 1st step follows from Part 3, the 2nd step follows from definition of .
For the second term, we have
where the 1st step follows from basic calculus, the 2nd step follows from Part 4, the 3rd step follows from definition of .
Using all of the results above, it holds that
Proof of Part 6. It follows from Part 5 directly.
Proof of Part 7. For simplicity of writing proofs, we use to denote .
Following the definition of in Definition 3.8, it holds that
(5)
Thus it holds that
where the 1st step is because of Eq. (5), the 2nd step is from , and the 3rd step follows from Part 4.
Proof of Part 8. For simplicity of writing proofs, we use to denote . Following the definition of in Definition 3.9, it holds that
(6)
Thus, we have
where the 1st step follows from Eq. (6), the 2nd step is due to the chain rule, and the last step follows from Part 5.
∎
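Several of the parts above reduce to differentiating a softmax-normalized vector. As a sanity check of that building block (our own toy instance and naming, not the paper's exact functions), the sketch below compares the closed form $\frac{\partial f}{\partial x} = (\operatorname{diag}(f) - f f^\top) A$ for $f(x) = \operatorname{softmax}(A x)$ against finite differences:

```python
import numpy as np

rng = np.random.default_rng(0)
n, m = 6, 4                         # m plays the role of the vectorized-variable dimension
A = rng.standard_normal((n, m))
x = rng.standard_normal(m)

def f(x):
    u = np.exp(A @ x)
    return u / u.sum()              # the normalized (softmax) vector

fx = f(x)
closed_form = (np.diag(fx) - np.outer(fx, fx)) @ A        # n x m Jacobian

eps = 1e-6
numeric = np.zeros((n, m))
for j in range(m):
    xp, xm = x.copy(), x.copy()
    xp[j] += eps; xm[j] -= eps
    numeric[:, j] = (f(xp) - f(xm)) / (2 * eps)

print(np.max(np.abs(closed_form - numeric)))              # ~1e-10
```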
Appendix C Time for Computation
In Section C.1, we show the calculation of (similarly to Section B, we still ignore the rescaling factor here) and . In Section C.2, we show how we calculate in a straightforward way. In Section C.3 and Section C.4, we define two artificial functions and , and show how to compute them. In Section C.5, we provide a way to rewrite the gradient in an elegant way. In Section C.6, we finally put these all together and find the running time of our algorithm.
C.1 Compute and
Lemma C.1 (Computing and ).
Proof.
Note that
and
We first compute , which takes time of and .
Then we can compute , which takes time.
Then we can compute , this takes time.
Thus, the overall time is
Note that computing takes time of .
Thus, the proof is completed. ∎
C.2 Compute
Lemma C.2 (Computing ).
Suppose the following objects are given:
• ,
• is given,
• is given,
Then one can compute in time.
Proof.
Based on the definition of , which is
Computing takes time of , and calculating takes time of .
Thus, the overall time is
∎
C.3 Computation for
We will define , and then explain how to calculate .
Definition C.3.
We define as
Then we use to denote the -th row of .
Lemma C.4.
If it holds that
• Suppose is given
• Suppose is given
Then, we can compute in the time of .
Proof.
Recall that . Thus it takes time of . ∎
C.4 Computation for
Let us first define , and then we can show how to construct it.
Definition C.5.
For every index , we define as
We define in the sense that is the -th row of .
Lemma C.6.
If the below requirements hold:
• Suppose is given
• Suppose is given
Then, we can compute in time.
Proof.
Since is a diagonal matrix and is a rank-one matrix, we know that can be computed in , for each . Thus we can construct matrix in time in total. ∎
C.5 Analyze the closed form of gradient
Lemma C.7.
Proof.
From the Lemma statement, we have
(7)
Recall the way we define (see Definition C.3).
(9)
Recall that is defined as in Definition C.5,
(10)
It holds that
where the 1st step is because of Definition 1.2, the 2nd step is based on Eq. (C.5), the 3rd step follows from Eq. (9), the 4th step is due to Eq. (10), and the last step uses the tensor trick.
∎
C.6 Putting it together
Lemma C.8 (Attention gradient computation, formal version of Lemma 4.1).
If it holds that
• Define . Define to be several input fixed matrices.
• Let denote matrix variables (we will compute the gradient with respect to ).
  – For ease of writing, we also use vector variables and , i.e., .
• Let (where is defined as in Definition 1.2).
Then we can show that the gradient can be computed in time.
Proof.
Step 1. We compute and . This takes time due to Lemma C.1.
Step 2. We compute . This takes time due to Lemma C.2.
Step 3. We compute . This takes time due to Lemma C.4.
Step 4. We compute . This takes time due to Lemma C.6.
Step 5. Using Lemma C.7, we know that the gradient is equivalent to . Given , it can be calculated in time.
Thus, the overall running time for computing the gradient is
time. ∎
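To make the quadratic-time baseline concrete, here is a hedged NumPy sketch (our own derivation and naming, using the same toy loss as the sketch after Definition 1.4, with $Y$ taken to be the identity). It assembles the gradient from the softmax Jacobian row by row, materializing the $n \times n$ attention matrix, and checks the result against finite differences:

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 6, 3
A1, A2, A3, E = (rng.standard_normal((n, d)) for _ in range(4))
X = rng.standard_normal((d, d))

def softmax_rows(S):
    P = np.exp(S - S.max(axis=1, keepdims=True))
    return P / P.sum(axis=1, keepdims=True)

def loss(X):
    F = softmax_rows(A1 @ X @ A2.T)
    return np.sum((F @ A3 - E) ** 2)

def grad(X):                                    # O(n^2 d) time, O(n^2) memory
    S = A1 @ X @ A2.T                           # the n x n matrix is written down explicitly
    F = softmax_rows(S)
    G = 2.0 * (F @ A3 - E)                      # dLoss / dOutput
    dS = np.empty((n, n))
    for j in range(n):
        w = A3 @ G[j]                           # dLoss / dF[j, :]
        fj = F[j]
        dS[j] = fj * w - fj * (fj @ w)          # softmax Jacobian: (diag(f_j) - f_j f_j^T) w
    return A1.T @ dS @ A2                       # chain rule through S = A1 X A2^T

num = np.zeros_like(X); eps = 1e-6
for i in range(d):
    for k in range(d):
        Xp, Xm = X.copy(), X.copy()
        Xp[i, k] += eps; Xm[i, k] -= eps
        num[i, k] = (loss(Xp) - loss(Xm)) / (2 * eps)
print(np.max(np.abs(grad(X) - num)))            # ~1e-9: closed form matches finite differences
```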
Appendix D Fast Running Time via Polynomial Method
Recall that in the previous section, for convenience of computing the derivative, we ignored the rescaling factor in . That factor doesn't impact the running time of our algorithms since it is just a rescaling factor. To apply the tools from previous work [4], we will now reconsider the factor in . In Section D.1, we will show how to efficiently and explicitly construct a low-rank representation for . In Section D.2, we show how to create a low-rank construction for . In Section D.3, Section D.4 and Section D.5, we further give low-rank representations for . In Section D.6, we prove our final algorithmic result by putting everything together.
D.1 Low rank representation to
Using [4]’s polynomial method result, we are able to obtain the following low-rank representation result,
Lemma D.1 (Section 3 of [4]).
For any $B = o(\sqrt{\log n})$, there exists a $k_1 = n^{o(1)}$ such that: Let $A_1, A_2 \in \mathbb{R}^{n \times d}$ be two matrices and $X \in \mathbb{R}^{d \times d}$ be a square matrix. If it holds that $\|A_1 X A_2^\top\|_\infty \le B^2$, then there are two matrices $U_1, V_1 \in \mathbb{R}^{n \times k_1}$ such that $\|U_1 V_1^\top - A\|_\infty \le \epsilon / \operatorname{poly}(n)$. Here $A := \exp(A_1 X A_2^\top) \in \mathbb{R}^{n \times n}$ and we define $\|M\|_\infty := \max_{i,j} |M_{i,j}|$. Moreover, these matrices $U_1, V_1$ can be explicitly constructed in $n^{1+o(1)}$ time.
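To illustrate the mechanism behind this lemma (not its exact construction), the sketch below replaces the optimized polynomial of [1] with a plain degree-$t$ Taylor truncation of $\exp$; expanding the monomials of the inner product yields explicit factors whose product approximates $\exp(Q K^\top)$ entry-wise when the entries are small. All names are ours.

```python
import numpy as np
from math import factorial
from itertools import product

def monomial_features(M, t):
    """Row i gets prod_s M[i, idx_s] / sqrt(j!) over all index tuples of length j <= t,
    so that <phi(q), phi(k)> = sum_{j<=t} (q . k)^j / j!  (the degree-t Taylor series of exp)."""
    n, d = M.shape
    cols = []
    for j in range(t + 1):
        scale = 1.0 / np.sqrt(factorial(j))
        for idx in product(range(d), repeat=j):
            col = np.full(n, scale)
            for i in idx:
                col = col * M[:, i]
            cols.append(col)
    return np.stack(cols, axis=1)

rng = np.random.default_rng(0)
n, d, t = 10, 3, 6
Q = 0.3 * rng.standard_normal((n, d))        # small entries: the bounded-entry regime
K = 0.3 * rng.standard_normal((n, d))

U1, V1 = monomial_features(Q, t), monomial_features(K, t)   # width 1 + d + ... + d^t
approx = U1 @ V1.T
exact = np.exp(Q @ K.T)
print(U1.shape[1], np.max(np.abs(approx - exact)))          # small error at modest width
```

Roughly speaking, the refined polynomial in the actual lemma keeps the number of features at $n^{o(1)}$ even when the inner products grow like $o(\log n)$, which is what yields the $n^{1+o(1)}$ construction time.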
D.2 Low rank representation to
Lemma D.2.
Let . Assume that each number in the matrices and can be written using bits. Let the matrix be defined as in Definition 3.8. Then, there are two matrices such that .
Proof.
We can show that
where the first step follows from .
∎
D.3 Low rank representation to
Lemma D.3.
Proof.
We define to be the approximation of .
From Lemma D.2, we know that is a good approximation to .
Then we pick in the following way: .
Now, let us turn to the low-rank representation.
We can first compute , which only takes time. Then, since all the low-rank matrices are known, we can explicitly construct , where .
For controlling the error, we can show
Thus, we complete the proof. ∎
D.4 Low rank representation to
Lemma D.4.
Let . Let . Assume that . Assume approximates the such that . Assume approximates the such that . Then there are matrices such that . The matrices can be explicitly constructed in time.
Proof.
We choose and . This can be computed in time.
For ease of writing proofs, we call and .
Using Fact A.2, we know that
where the 1st step follows from the way we define , the 2nd step follows from the way we define and , the 3rd step follows from Fact A.2, the 4th step follows from the way we define and , the 5th step follows from simple algebra, the 6th step follows from the triangle inequality, and the last step follows from the entries being bounded together with the two assumptions of the Lemma.
∎
D.5 Low rank representation
Lemma D.5.
Let . Let . Let . Assume that is an matrix whose -th column is for each . Assume approximates such that . Assume approximates such that . Then there are matrices such that . The matrices can be explicitly constructed in time.
Proof.
We define a local vector function where is . Let denote the approximation of .
Note that is a good approximation to .
Note that is a good approximation to .
Let .
For the computation side, we first compute . This takes time.
Next, we have
Once the are pre-computed, the above step only takes time. Since there are coordinates, the overall time is still .
Let denote the approximation of . Then we just use and to approximate in the following sense: let . Since has a low-rank representation and is a diagonal matrix, it is straightforward to construct and . Basically, and .
Now, we need to control the error. We have
where the 2nd step follows from the definition of and .
For the first term, we have
For the second term, we have
Using the three equations we obtained above, the proof is completed. ∎
D.6 Fast Computation in Almost Linear Time
Theorem D.6 (Main result, formal version of Theorem 1.6).
Assuming the entries of the input matrices are represented using $O(\log n)$ bits and bounded by $B = o(\sqrt{\log n})$ in magnitude, there is an $n^{1+o(1)}$-time algorithm to solve Approximate Attention Loss Gradient Computation (see Definition 1.4) up to $1/\operatorname{poly}(n)$ accuracy. In particular, our algorithm outputs a gradient vector $\widetilde{g}$ such that $\|\widetilde{g} - \operatorname{vec}(\frac{\mathrm{d} L(X)}{\mathrm{d} X})\|_\infty \le 1/\operatorname{poly}(n)$.
Proof.
Recalling the definitions of the matrices (Definition C.5), (see Lemma D.5) and (Lemma D.4), it is straightforward to see that
Using Lemma D.1, Lemma D.2, and Lemma D.3, we know that the assumptions in Lemma D.4 and Lemma D.5 hold, so that we can use Lemma D.4 and Lemma D.5 to obtain the following:
• has an approximate low-rank representation ; let denote .
• has an approximate low-rank representation ; let denote .
All of Lemmas D.1, D.2, D.3, D.4 and D.5 take time.
According to the proof of Lemma C.7, we have that
Thus, we first compute :
• we compute , which takes time;
• we compute , which takes time;
• we compute , which takes time.
Second, we can compute :
• we compute , which takes time;
• we compute , which takes time;
• we compute , which takes time.
So, the overall running time is still .
We have
where the 4th step follows from the triangle inequality, and the last step follows from the fact that the entries in are bounded, and , .
Picking , the proof is complete. ∎
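Putting the pieces of this appendix together, the following self-contained sketch illustrates the kind of low-rank pipeline built above. It is our own derivation of the gradient's algebraic structure (consistent with the sketch after Lemma C.8), with random low-rank, row-normalized factors standing in for the Lemma D.1 construction and the fixed matrix $Y$ taken to be the identity; the paper's intermediate matrices may differ in details. It evaluates the gradient-shaped expression $A_1^\top (F \circ W - \operatorname{diag}(c) F) A_2$ using only thin matrices and checks the result against the dense computation.

```python
import numpy as np

def row_kron(A, B):
    """Row-wise Kronecker product: row i of the result is kron(A[i], B[i])."""
    return np.einsum('ij,ik->ijk', A, B).reshape(A.shape[0], -1)

rng = np.random.default_rng(0)
n, d, k = 2000, 8, 16
A1, A2, A3, E = (rng.standard_normal((n, d)) for _ in range(4))

# Toy low-rank, row-stochastic stand-in for the attention matrix: F = Uf @ Vf.T
Uf, Vf = rng.random((n, k)) + 0.1, rng.random((n, k)) + 0.1
Uf = Uf / (Uf @ (Vf.T @ np.ones(n)))[:, None]        # fold the row normalization into Uf

G = 2.0 * (Uf @ (Vf.T @ A3) - E)                     # dLoss/dOutput, no n x n matrix needed

# dense O(n^2) reference:  A1^T (F o W - diag(c) F) A2,  with  W = G A3^T
F = Uf @ Vf.T
W = G @ A3.T
c = (F * W).sum(axis=1)
grad_dense = A1.T @ (F * W - c[:, None] * F) @ A2

# thin-matrix evaluation: every intermediate is n x (k*d) or smaller
PL, PR = row_kron(Uf, G), row_kron(Vf, A3)           # F o W = PL @ PR^T  (Fact A.2)
c_fast = PL @ (PR.T @ np.ones(n))                    # row sums of F o W
term1 = (A1.T @ PL) @ (PR.T @ A2)                    # A1^T (F o W) A2
term2 = ((A1 * c_fast[:, None]).T @ Uf) @ (Vf.T @ A2)   # A1^T diag(c) F A2
grad_fast = term1 - term2

print(np.max(np.abs(grad_fast - grad_dense)))        # agreement up to roundoff
```

With $k = n^{o(1)}$ and $d = O(\log n)$, every step on the thin path runs in $n^{1+o(1)}$ time, which is the shape of the final algorithm's running time.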
Acknowledgments
The authors would like to thank Yichuan Deng for helpful discussions.
References
- AA [22] Amol Aggarwal and Josh Alman. Optimal-degree polynomial approximations for exponentials and gaussian kernel density estimation. In 37th Computational Complexity Conference (CCC 2022). Schloss Dagstuhl-Leibniz-Zentrum für Informatik, 2022.
- ACSS [20] Josh Alman, Timothy Chu, Aaron Schild, and Zhao Song. Algorithms and hardness for linear algebra on geometric graphs. In 2020 IEEE 61st Annual Symposium on Foundations of Computer Science (FOCS), pages 541–552. IEEE, 2020.
- ALS+ [23] Josh Alman, Jiehao Liang, Zhao Song, Ruizhe Zhang, and Danyang Zhuo. Bypass exponential time preprocessing: Fast neural network training via weight-data correlation preprocessing. In NeurIPS. arXiv preprint arXiv:2211.14227, 2023.
- AS [23] Josh Alman and Zhao Song. Fast attention requires bounded entries. In NeurIPS, 2023.
- AS [24] Josh Alman and Zhao Song. How to capture higher-order correlations? generalizing matrix softmax attention to kronecker computation. In ICLR, 2024.
- BCS [97] Peter Bürgisser, Michael Clausen, and Mohammad A Shokrollahi. Algebraic complexity theory, volume 315. Springer Science & Business Media, 1997.
- BIS [17] Arturs Backurs, Piotr Indyk, and Ludwig Schmidt. On the fine-grained complexity of empirical risk minimization: Kernel methods and neural networks. Advances in Neural Information Processing Systems (NeurIPS), 30, 2017.
- Blä [13] Markus Bläser. Fast matrix multiplication. Theory of Computing, pages 1–60, 2013.
- BMR+ [20] Tom Brown, Benjamin Mann, Nick Ryder, Melanie Subbiah, Jared D Kaplan, Prafulla Dhariwal, Arvind Neelakantan, Pranav Shyam, Girish Sastry, Amanda Askell, et al. Language models are few-shot learners. Advances in neural information processing systems, 33:1877–1901, 2020.
- BPSW [21] Jan van den Brand, Binghui Peng, Zhao Song, and Omri Weinstein. Training (over- parametrized) neural networks in near-linear time. 12th Innovations in Theoretical Computer Science Conference (ITCS), 2021.
- BSZ [23] Jan van den Brand, Zhao Song, and Tianyi Zhou. Algorithm and hardness for dynamic attention maintenance in large language models. arXiv preprint arXiv:2304.02207, 2023.
- CKNS [20] Moses Charikar, Michael Kapralov, Navid Nouri, and Paris Siminelakis. Kernel density estimation through density constrained near neighbor search. In 2020 IEEE 61st Annual Symposium on Foundations of Computer Science (FOCS), pages 172–183. IEEE, 2020.
- CLP+ [21] Beidi Chen, Zichang Liu, Binghui Peng, Zhaozhuo Xu, Jonathan Lingjie Li, Tri Dao, Zhao Song, Anshumali Shrivastava, and Christopher Ré. Mongoose: A learnable lsh framework for efficient neural network training. International Conference on Learning Representations, 2021.
- CND+ [22] Aakanksha Chowdhery, Sharan Narang, Jacob Devlin, Maarten Bosma, Gaurav Mishra, Adam Roberts, Paul Barham, Hyung Won Chung, Charles Sutton, Sebastian Gehrmann, et al. Palm: Scaling language modeling with pathways. arXiv preprint arXiv:2204.02311, 2022.
- DCLT [18] Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. Bert: Pre-training of deep bidirectional transformers for language understanding. arXiv preprint arXiv:1810.04805, 2018.
- DGS [23] Yichuan Deng, Yeqi Gao, and Zhao Song. Solving tensor low cycle rank approximation. arXiv preprint arXiv:2304.06594, 2023.
- DHS+ [22] Yichuan Deng, Hang Hu, Zhao Song, Omri Weinstein, and Danyang Zhuo. Training overparametrized neural networks in sublinear time. arXiv preprint arXiv:2208.04508, 2022.
- DJS+ [19] Huaian Diao, Rajesh Jayaram, Zhao Song, Wen Sun, and David Woodruff. Optimal sketching for kronecker product regression and low rank approximation. Advances in neural information processing systems, 32, 2019.
- DLMS [23] Yichuan Deng, Zhihang Li, Sridhar Mahadevan, and Zhao Song. Zero-th order algorithm for softmax attention optimization. arXiv preprint arXiv:2307.08352, 2023.
- DMS [23] Yichuan Deng, Sridhar Mahadevan, and Zhao Song. Randomized and deterministic attention sparsification algorithms for over-parameterized feature dimension. arXiv preprint arXiv:2304.04397, 2023.
- DSSW [18] Huaian Diao, Zhao Song, Wen Sun, and David Woodruff. Sketching for kronecker product regression and p-splines. In International Conference on Artificial Intelligence and Statistics, pages 1299–1308. PMLR, 2018.
- DSY [23] Yichuan Deng, Zhao Song, and Junze Yin. Faster robust tensor power method for arbitrary order. arXiv preprint arXiv:2306.00406, 2023.
- GQSW [24] Yeqi Gao, Lianke Qin, Zhao Song, and Yitan Wang. A sublinear adversarial training algorithm. In ICLR. arXiv preprint arXiv:2208.05395, 2024.
- GSWY [23] Yeqi Gao, Zhao Song, Weixin Wang, and Junze Yin. A fast optimization view: Reformulating single layer attention in llm based on tensor and svm trick, and solving it in matrix multiplication time. arXiv preprint arXiv:2309.07418, 2023.
- GSX [23] Yeqi Gao, Zhao Song, and Shenghao Xie. In-context learning for attention scheme: from single softmax regression to multiple softmax regression via a tensor trick. arXiv preprint arXiv:2307.02419, 2023.
- [26] Yeqi Gao, Zhao Song, and Xin Yang. Differentially private attention computation. arXiv preprint arXiv:2305.04701, 2023.
- [27] Yeqi Gao, Zhao Song, and Junze Yin. Gradientcoin: A peer-to-peer decentralized large language models. arXiv preprint arXiv:2308.10502, 2023.
- GSYZ [23] Yeqi Gao, Zhao Song, Xin Yang, and Ruizhe Zhang. Fast quantum algorithm for attention computation. arXiv preprint arXiv:2307.08045, 2023.
- HJK+ [23] Insu Han, Rajesh Jayaram, Amin Karbasi, Vahab Mirrokni, David P. Woodruff, and Amir Zandieh. Hyperattention: Long-context attention in near-linear time. arXiv preprint arXiv:2310.05869, 2023.
- Inc [23] Adobe Inc. Adobe firefly. In Adobe. https://www.adobe.com/sensei/generative-ai/firefly.html, 2023.
- IP [01] Russell Impagliazzo and Ramamohan Paturi. On the complexity of k-sat. Journal of Computer and System Sciences, 62(2):367–375, 2001.
- KKL [20] Nikita Kitaev, Łukasz Kaiser, and Anselm Levskaya. Reformer: The efficient transformer. arXiv preprint arXiv:2001.04451, 2020.
- KMZ [23] Praneeth Kacham, Vahab Mirrokni, and Peilin Zhong. Polysketchformer: Fast transformers via sketches for polynomial kernels. arXiv preprint arXiv:2310.01655, 2023.
- KWH [23] Feyza Duman Keles, Pruthuvi Mahesakya Wijewardena, and Chinmay Hegde. On the computational complexity of self-attention. In International Conference on Algorithmic Learning Theory, pages 597–619. PMLR, 2023.
- LOG+ [19] Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, and Veselin Stoyanov. Roberta: A robustly optimized bert pretraining approach. arXiv preprint arXiv:1907.11692, 2019.
- Man [23] James Manyika. An overview of bard: an early experiment with generative ai. Technical report, Google AI, 2023.
- MGN+ [23] Sadhika Malladi, Tianyu Gao, Eshaan Nichani, Alex Damian, Jason D Lee, Danqi Chen, and Sanjeev Arora. Fine-tuning language models with just forward passes. arXiv preprint arXiv:2305.17333, 2023.
- PMXA [23] Abhishek Panigrahi, Sadhika Malladi, Mengzhou Xia, and Sanjeev Arora. Trainable transformer in transformer. arXiv preprint arXiv:2307.01189, 2023.
- RSZ [22] Aravind Reddy, Zhao Song, and Lichen Zhang. Dynamic tensor product regression. In NeurIPS, 2022.
- Rub [18] Aviad Rubinstein. Hardness of approximate nearest neighbor search. In Proceedings of the 50th annual ACM SIGACT symposium on theory of computing (STOC), pages 1260–1268, 2018.
- SWZ [19] Zhao Song, David P Woodruff, and Peilin Zhong. Relative error tensor low rank approximation. In SODA. arXiv preprint arXiv:1704.08246, 2019.
- SYZ [21] Zhao Song, Shuo Yang, and Ruizhe Zhang. Does preprocessing help training over-parameterized neural networks? 35th Conference on Neural Information Processing Systems, 2021.
- SZZ [24] Zhao Song, Lichen Zhang, and Ruizhe Zhang. Training multi-layer over-parametrized neural network in subquadratic time. In ITCS. arXiv preprint arXiv:2112.07628, 2024.
- TDFH+ [22] Romal Thoppilan, Daniel De Freitas, Jamie Hall, Noam Shazeer, Apoorv Kulshreshtha, Heng-Tze Cheng, Alicia Jin, Taylor Bos, Leslie Baker, Yu Du, et al. Lamda: Language models for dialog applications. arXiv preprint arXiv:2201.08239, 2022.
- TLI+ [23] Hugo Touvron, Thibaut Lavril, Gautier Izacard, Xavier Martinet, Marie-Anne Lachaux, Timothée Lacroix, Baptiste Rozière, Naman Goyal, Eric Hambro, Faisal Azhar, Aurelien Rodriguez, Armand Joulin, Edouard Grave, and Guillaume Lample. Llama: Open and efficient foundation language models. arXiv preprint arXiv:2302.13971, 2023.
- TMS+ [23] Hugo Touvron, Louis Martin, Kevin Stone, Peter Albert, Amjad Almahairi, Yasmine Babaei, Nikolay Bashlykov, Soumya Batra, Prajjwal Bhargava, Shruti Bhosale, et al. Llama 2: Open foundation and fine-tuned chat models. arXiv preprint arXiv:2307.09288, 2023.
- VSP+ [17] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Łukasz Kaiser, and Illia Polosukhin. Attention is all you need. Advances in neural information processing systems, 30, 2017.
- Wil [18] Virginia Vassilevska Williams. On some fine-grained questions in algorithms and complexity. In Proceedings of the international congress of mathematicians: Rio de janeiro 2018, pages 3447–3487. World Scientific, 2018.
- WTB+ [22] Jason Wei, Yi Tay, Rishi Bommasani, Colin Raffel, Barret Zoph, Sebastian Borgeaud, Dani Yogatama, Maarten Bosma, Denny Zhou, Donald Metzler, et al. Emergent abilities of large language models. arXiv preprint arXiv:2206.07682, 2022.
- YCRI [22] Ann Yuan, Andy Coenen, Emily Reif, and Daphne Ippolito. Wordcraft: story writing with large language models. In 27th International Conference on Intelligent User Interfaces, pages 841–852, 2022.
- YDY+ [19] Zhilin Yang, Zihang Dai, Yiming Yang, Jaime Carbonell, Russ R Salakhutdinov, and Quoc V Le. Xlnet: Generalized autoregressive pretraining for language understanding. Advances in neural information processing systems, 32, 2019.
- Zha [22] Lichen Zhang. Speeding up optimizations via data structures: Faster search, sample and maintenance. Master’s thesis, Carnegie Mellon University, 2022.
- ZHDK [23] Amir Zandieh, Insu Han, Majid Daliri, and Amin Karbasi. Kdeformer: Accelerating transformers via kernel density estimation. In ICML. arXiv preprint arXiv:2302.02451, 2023.
- ZRG+ [22] Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen, Christopher Dewan, Mona Diab, Xian Li, Xi Victoria Lin, et al. Opt: Open pre-trained transformer language models. arXiv preprint arXiv:2205.01068, 2022.