
Fine-grained Attention I/O Complexity: Comprehensive Analysis for Backward Passes

Xiaoyu Li (Stevens Institute of Technology)    Yingyu Liang (The University of Hong Kong; University of Wisconsin-Madison)    Zhenmei Shi (University of Wisconsin-Madison)    Zhao Song (The Simons Institute for the Theory of Computing at the University of California, Berkeley)    Yufa Zhou (University of Pennsylvania)

Large Language Models (LLMs) have demonstrated remarkable capabilities in processing long-context information. However, the quadratic complexity of attention computation with respect to sequence length poses significant computational challenges, and I/O-aware algorithms have been proposed. This paper presents a comprehensive analysis of the I/O complexity of attention mechanisms, focusing on backward passes and categorizing the analysis into small-cache and large-cache scenarios. Using the red-blue pebble game framework, we establish tight bounds on I/O complexity across all cache sizes. We confirm that the de facto standard I/O-aware algorithm FlashAttention is optimal for both forward and backward passes in the large-cache scenario. For small cache sizes, we provide an algorithm that improves over existing methods and achieves the tight bounds. Additionally, we extend our analysis to sparse attention, a mainstream acceleration approach, deriving fine-grained lower bounds for both forward and backward passes and for both small and large caches. Our findings complete the theoretical foundation for I/O complexity in attention mechanisms, offering insights for designing efficient algorithms for LLM training and inference.

1 Introduction

Large Language Models (LLMs), such as GPT-4 [2], Claude [7], Llama [73], and more recently o1 [83] from OpenAI, have demonstrated immense potential to enhance various aspects of our daily lives, including conversational AI [62], AI agents [114, 24], search AI [83], AI assistants [60, 37], and many others. One of the key emergent abilities of LLMs is handling long-context information, which is crucial for processing materials such as academic papers, official reports, and legal documents. LLMs have proven adept at tackling long-context tasks, such as zero-shot summarization [17, 125] and maintaining very long-term conversations [115, 75]. OpenAI's o1 model [83] serves as a significant advancement in this area: it leverages Chain-of-Thought (CoT) reasoning [113, 59] and employs Retrieval Augmented Generation (RAG) [67, 44] to exhibit PhD-level abilities, and both techniques require long-context inputs for generation. This proficiency underscores the necessity of developing long-context modeling capabilities within LLMs.

LLMs are primarily based on the Transformer architecture [108], whose core component is the self-attention mechanism. However, the quadratic complexity of attention computation with respect to sequence length dominates the computational FLOPs during long-context training and inference. To address this issue, FlashAttention [26, 25, 91] accelerates attention computation and has become the de facto standard in the industry for LLM training and inference deployment. The success of FlashAttention lies in its I/O awareness [11], accounting for reads and writes between different levels of fast cache (e.g., GPU on-chip SRAM) and slow memory (e.g., GPU high-bandwidth memory) within the hardware hierarchy. By efficiently leveraging the hardware design of modern GPUs, e.g., NVIDIA A100 and H100, FlashAttention has become a go-to method for LLM training and inference.

For the I/O complexity of exact attention forward computation (in this work, we only consider exact attention computation without any approximation), the theoretical analysis of FlashAttention in DFE+ [26] only provides upper and lower bounds when the cache size $M\in[d,nd]$. Their bounds are only tight in the range $M=\Theta(nd)$, where $n$ is the input sequence length and $d$ is the hidden dimension. By fine-grained analysis, a recent work [102] provides matching upper and lower I/O complexity bounds of the attention forward passes for any cache size $M$. For the I/O complexity of attention backward passes, existing work only provides an upper bound for FlashAttention for the cache size $M\in[d,nd]$ [26], without known lower bounds. Thus, the tight bounds for the I/O complexity of attention backward passes are lacking. This raises a natural question:

What is the optimal I/O complexity of attention backward computations for any cache size?

In this paper, we address this question and provide matching upper and lower I/O complexity bounds for backward passes of exact attention computation for all cache sizes, completing the picture of I/O complexity for the attention mechanism.

1.1 Our Contributions

In this work, we analyze the I/O complexity in the same setting as the existing work of FlashAttention [26] and SY [102]. We consider a two-level memory hierarchy consisting of a small but fast layer called the cache and a large but slower layer referred to as memory. The I/O complexity quantifies the data transfer between these two layers, which can be formally defined via the red-blue pebble game [49] as in Definition 3.4. We study exact attention computation using standard matrix multiplication, as in existing work. (Note that there are many fast matrix multiplication methods. We do not study them, as they are hard to parallelize; standard matrix multiplication remains the most popular implementation on GPUs, e.g., in PyTorch. We refer readers to Section 3 for more details.) We focus on backward gradient computation. We establish matching I/O complexity upper and lower bounds for attention backward computation (formalized in Theorem 1.1 and illustrated in Fig. 1). Combined with the attention forward results from SY [102], this completes the theory of I/O complexity for the attention mechanism.

Our main result is stated as follows:

Theorem 1.1 (Main result).

Let $n$ be the sequence length, $d$ the head dimension, and $M$ the cache size. The I/O complexity of attention backward computation under standard matrix multiplication is

\Theta\left(\min\left\{\frac{n^{2}d^{2}+nd^{3}}{M},\frac{n^{2}d+nd^{2}}{\sqrt{M}}\right\}\right).

To interpret our main result, we categorize the cache size $M$ into two cases: the small cache case where $M=o(d^{2})$ and the large cache case where $M=\Omega(d^{2})$ (see Fig. 1 for an illustration).
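As a simple illustration of Theorem 1.1, the following Python sketch evaluates the two branches of the bound (ignoring constants) and shows how the minimum switches branch at the crossover point $M=d^{2}$; the concrete values of $n$, $d$, and the cache sizes are arbitrary choices for illustration.

```python
# Sketch: evaluate Theta(min{(n^2 d^2 + n d^3)/M, (n^2 d + n d^2)/sqrt(M)}) from Theorem 1.1.
# Constants are ignored; n, d, and the cache sizes M are arbitrary illustrative values.
import math

def io_bound(n: int, d: int, M: int) -> float:
    large_cache_term = (n**2 * d**2 + n * d**3) / M           # FlashAttention-style bound
    small_cache_term = (n**2 * d + n * d**2) / math.sqrt(M)   # matrix-multiplication-style bound
    return min(large_cache_term, small_cache_term)

n, d = 8192, 128
for M in [d**2 // 16, d**2, 16 * d**2]:
    large = (n**2 * d**2 + n * d**3) / M
    small = (n**2 * d + n * d**2) / math.sqrt(M)
    branch = "large-cache" if large <= small else "small-cache"
    print(f"M = {M:>8}: bound = {io_bound(n, d, M):.3e} ({branch} branch)")
```

At $M=d^{2}$ both expressions evaluate to $n^{2}+nd$, in line with the crossover point in Fig. 1.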

In the small cache scenario, $M=o(d^{2})$, using the computation graph in Fig. 2 and Algorithm 6, we show that the upper bound of the I/O complexity is $O(\frac{n^{2}d+nd^{2}}{\sqrt{M}})$. In detail, Algorithm 6 explicitly reads/writes the $n\times n$ attention matrix and other $n\times d$ intermediate matrices from/to memory. Note that, when $M=o(d^{2})$, our Algorithm 6 has a better upper bound than FlashAttention, whose upper bound is $O(\frac{n^{2}d^{2}+nd^{3}}{M})$. Furthermore, to establish a lower bound on the I/O complexity, we show that the I/O complexity of attention backward computation is equivalent to that of matrix multiplication when $M=o(d^{2})$, which matches the upper bound of Algorithm 6.

Figure 1: Attention backward I/O complexity comparison. The $x$-axis is the cache size, and the $y$-axis is the I/O complexity. The red line represents our tight upper/lower bound (Theorem 1.1), and the blue dashed line denotes the upper bound for FlashAttention [26]. The crossing point is $M=\Theta(d^{2})$, the dividing point between the large cache and small cache settings. The results show that FlashAttention is optimal when $M=\Omega(d^{2})$.

In the more practical large cache case, $M=\Omega(d^{2})$, we prove an upper bound $O(\frac{n^{2}d^{2}+nd^{3}}{M})$ on the I/O complexity of the attention backward algorithm (Algorithm 9), which matches that of FlashAttention [26, 25, 91]. We prove that this upper bound is tight by providing a matching lower bound for the I/O complexity of attention backward passes using the red-blue pebble game analysis framework from HK [49].

Therefore, we provide the optimal bounds and algorithms for backward passes for all cache sizes. This fully characterizes the I/O complexity of attention forward/backward passes when combined with existing results on forward passes [102]. Notably, we confirm that FlashAttention is optimal for both the forward and backward passes when the cache size is large enough, i.e., $M=\Omega(d^{2})$.

Moreover, in recent years, sparse attention has become another mainstream method for speeding up the training process of transformer-based models [18, 122, 16]. These approaches mainly focus on techniques for sparsifying the attention matrix, thereby reducing the quadratic bottleneck in running time. However, it remains unknown whether this method can be integrated with I/O-aware algorithms like FlashAttention. Consequently, we further analyze the I/O complexity of sparse attention to provide theoretical guarantees, offering fine-grained lower bounds.

Theorem 1.2 (Lower bound for sparse attention forward and backward, informal version of Theorem 4.5).

Let $Z_{\mathrm{input}}$ and $Z_{\mathrm{QK}}$ be the number of nonzero entries of the input matrix and the key-query matrix, respectively. Then any algorithm for both attention forward and backward computation using sparse semi-ring matrix multiplication has I/O complexity

\Omega\left(\min\left\{\frac{Z_{\mathrm{input}}^{2}}{M},\frac{Z_{\mathrm{input}}\sqrt{Z_{\mathrm{QK}}}}{\sqrt{M}}\right\}\right).

Our I/O complexity lower bound for sparse attention recovers the lower bound for both attention forward and backward passes when the matrices involved in attention computation are dense, i.e., $Z_{\mathrm{input}}=\Omega(nd)$ and $Z_{\mathrm{QK}}=\Omega(n^{2})$. In this case, our lower bound reads as $\Omega(\min\{\frac{n^{2}d^{2}}{M},\frac{n^{2}d}{\sqrt{M}}\})$, matching Theorem 1.1.
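The dense-case specialization can also be checked symbolically; the following short sketch (using SymPy purely as an illustration) substitutes $Z_{\mathrm{input}}=nd$ and $Z_{\mathrm{QK}}=n^{2}$ into the two terms of the sparse bound.

```python
# Sketch: substitute Z_input = n*d and Z_QK = n^2 into the sparse lower bound of Theorem 1.2
# and confirm it reduces to Omega(min{n^2 d^2 / M, n^2 d / sqrt(M)}).
import sympy as sp

n, d, M = sp.symbols('n d M', positive=True)
Z_input, Z_QK = n * d, n**2

term_large = sp.simplify(Z_input**2 / M)                        # -> d**2*n**2/M
term_small = sp.simplify(Z_input * sp.sqrt(Z_QK) / sp.sqrt(M))  # -> d*n**2/sqrt(M)
print(term_large, term_small)
```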

Table 1: Summary of our contributions. We categorize the cache size $M$ into two cases: (1) large cache $M=\Omega(d^{2})$; (2) small cache $M=o(d^{2})$. Assume $n\geq d$. We list our contributions for general and sparse attention below. $Z_{\mathrm{input}}$ and $Z_{\mathrm{QK}}$ denote the number of nonzero entries of the input matrix and the key-query matrix, respectively.
Attention | Algorithm | Large Cache | Reference | Small Cache | Reference
General | Forward Upper | $O(n^{2}d^{2}/M)$ | DFE+ [26] | $O(n^{2}d/\sqrt{M})$ | SY [102]
General | Forward Lower | $\Omega(n^{2}d^{2}/M)$ | SY [102] | $\Omega(n^{2}d/\sqrt{M})$ | SY [102]
General | Backward Upper | $O(n^{2}d^{2}/M)$ | DFE+ [26] | $O(n^{2}d/\sqrt{M})$ | Theorem 4.3
General | Backward Lower | $\Omega(n^{2}d^{2}/M)$ | Theorem 4.2 | $\Omega(n^{2}d/\sqrt{M})$ | Theorem 4.4
Sparse | Forward Lower | $\Omega(Z_{\mathrm{input}}^{2}/M)$ | Theorem 4.5 | $\Omega(Z_{\mathrm{input}}\sqrt{Z_{\mathrm{QK}}}/\sqrt{M})$ | Theorem 4.5
Sparse | Backward Lower | $\Omega(Z_{\mathrm{input}}^{2}/M)$ | Theorem 4.5 | $\Omega(Z_{\mathrm{input}}\sqrt{Z_{\mathrm{QK}}}/\sqrt{M})$ | Theorem 4.5

We summarize our contributions in Table 1 and also conclude as follows:

  • For small cache sizes $M=o(d^{2})$ in the backward pass, we present optimal upper and lower bounds and propose an algorithm achieving the optimal bound (Algorithm 6). Notably, FlashAttention is not optimal in this setting, and our algorithm outperforms it.

  • For large cache sizes $M=\Omega(d^{2})$ in the backward pass, we establish an optimal lower bound that matches the existing upper bound. We also prove the optimal upper bound and introduce an optimal algorithm (Algorithm 9), matching the existing results for FlashAttention but providing a different analysis.

  • For sparse attention, we offer fine-grained lower bounds for both forward and backward passes and across all cache sizes (Theorem 4.5).

Roadmap.

In Section 2, we review related literature. In Section 3, we introduce the definitions and background necessary for our study. We present our main results in Section 4 and discuss the techniques we employed in Section 5. Section 6 concludes our paper.

2 Related Work

Attention Computation Acceleration.

The quadratic time complexity of attention computation with respect to the length of the input sequence [108] poses significant computational challenges, especially for long sequences. Consequently, accelerating attention computation has become a crucial research area. From a theoretical standpoint, numerous works focus on approximating the attention matrix to accelerate computation [48, 8, 9, 68, 10, 71]. Experimental approaches involve modifying model architectures and optimizing implementations to accelerate inference. Methods such as Mamba [39, 27], Linearizing Transformers [120, 80], Hopfield Models [55, 110, 51, 116, 109, 46, 47], and PolySketchFormer [123, 61] aim to improve model performance and inference speed. System-level optimizations, such as FlashAttention [26, 25, 91] and block-wise parallel decoding [96], address bottlenecks in attention mechanisms and enhance inference speed through efficient implementation strategies. In addition, the work in [95] accelerates inference by compressing the input text. Collectively, these advancements contribute to making attention mechanisms more scalable and efficient, facilitating the deployment of large-scale language models.

Learning with Bounded Memory and I/O Complexity.

A common memory model in computational systems is the two-level memory hierarchy. In this model, there are two layers of memory: a small but fast layer called the cache, and a large but slower layer called the memory. The I/O (input/output) complexity of an algorithm measures its efficiency based on the number of data transfer operations it performs between the cache and the memory. The early work of HK [49] formulated the I/O complexity mathematically using the language of graph theory. Learning with bounded memory has been studied in various fields in machine learning such as online learning [101, 84, 86], convex optimization [78, 20], active learning [50], attention computation [5], and continual learning [21, 36].

Sparse Attention.

Over the past few years, there has been extensive research on sparse Transformer/Attention models with weight pruning and input pruning, aimed at accelerating computation and training [119, 94, 16, 104, 43, 98, 92, 66, 34, 19]. In practice, the attention matrix is sparse, significantly reducing computational costs. Theoretical studies, such as [118], have demonstrated that sparse transformers are expressive enough and can achieve universal approximation properties.

Figure 2: The computational graph for attention forward and backward passes. The blue boxes are input matrices, the gray boxes are intermediate matrices, the green box is the forward output, and the orange box is the final gradient matrix. Here, $A_{1},A_{2},A_{3}$ denote the previous inputs, $\mathrm{d}O$ denotes the upstream gradient, and $X,Y$ denote the attention weights. More detailed definitions of each variable can be found in Sections 3 and B.

3 Preliminary

In this work, we consider using the standard algorithm for matrix multiplication, which means that for any two matrices $A\in\mathbb{R}^{n_{1}\times d}$ and $B\in\mathbb{R}^{d\times n_{2}}$, each entry of $AB$ is computed by $(AB)_{i,j}=\sum_{k=1}^{d}A_{i,k}B_{k,j}$ for $i\in[n_{1}],j\in[n_{2}]$. Note that this setting is also used in FlashAttention [26] and SY [102]. Then, we introduce some key concepts needed for this paper.
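For concreteness, a minimal NumPy sketch of this entry-wise standard matrix multiplication is given below; each output entry is formed only as the sum of products $A_{i,k}B_{k,j}$, with no Strassen-style recombination.

```python
# Sketch: standard (entry-wise) matrix multiplication, (AB)_{i,j} = sum_k A_{i,k} B_{k,j}.
import numpy as np

def standard_matmul(A: np.ndarray, B: np.ndarray) -> np.ndarray:
    n1, d = A.shape
    d2, n2 = B.shape
    assert d == d2
    C = np.zeros((n1, n2))
    for i in range(n1):
        for j in range(n2):
            for k in range(d):          # each entry is a plain sum of d products
                C[i, j] += A[i, k] * B[k, j]
    return C
```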

3.1 Key Concept of Attention

Before formally stating our results, we begin by precisely defining the problems we study. We define the following computation of the general Softmax attention forward layer.

Definition 3.1 (Attention forward computation).

Let $n$ be the input length and $d$ be the head dimension. Let $A_{1},A_{2},A_{3}\in\mathbb{R}^{n\times d}$ be the inputs from the previous layer. Given the query, key, and value weight matrices $W_{Q},W_{K},W_{V}\in\mathbb{R}^{d\times d}$, the Softmax attention forward computation is

\mathsf{Attn}(A_{1},A_{2},A_{3}):=D^{-1}\exp(A_{1}W_{Q}W_{K}^{\top}A_{2}^{\top})A_{3}W_{V},

where (1) $D:=\operatorname{diag}(\exp(A_{1}W_{Q}W_{K}^{\top}A_{2}^{\top})\cdot{\bf 1}_{n})$, (2) $\exp$ denotes the exponential function and is applied entry-wise, (3) the $\operatorname{diag}(\cdot)$ operation takes a vector and outputs a diagonal matrix with the entries of that vector, and (4) ${\bf 1}_{n}$ denotes the length-$n$ all-ones vector.

To simplify and focus more clearly on the core computational aspects of the problem, we set $X=W_{Q}W_{K}^{\top}\in\mathbb{R}^{d\times d}$ and $Y=W_{V}\in\mathbb{R}^{d\times d}$.

Note that $\mathsf{Softmax}(A_{1}XA_{2}^{\top})=D^{-1}\exp(A_{1}XA_{2}^{\top})\in\mathbb{R}^{n\times n}$, which we usually call the attention matrix. The above definition is general and encompasses both the self-attention and cross-attention mechanisms in Transformer architectures. Specifically, self-attention occurs when $A_{1}=A_{2}=A_{3}$, meaning that the queries, keys, and values are all derived from the same source. In contrast, cross-attention happens when $A_{2}=A_{3}$, indicating that the keys and values come from one source while the queries come from another.
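A minimal NumPy sketch of Definition 3.1 with the reparameterization $X=W_{Q}W_{K}^{\top}$ and $Y=W_{V}$ is given below; it is written for clarity rather than I/O efficiency, and the dimensions and random inputs are arbitrary.

```python
# Sketch: Attn(A1, A2, A3) = D^{-1} exp(A1 X A2^T) A3 Y from Definition 3.1.
import numpy as np

def attn_forward(A1, A2, A3, X, Y):
    S = A1 @ X @ A2.T                              # n x n logits A1 X A2^T
    A = np.exp(S)                                  # entry-wise exponential
    D_inv = 1.0 / A.sum(axis=1, keepdims=True)     # D^{-1} applied row by row
    f = D_inv * A                                  # Softmax(A1 X A2^T), the attention matrix
    return f @ (A3 @ Y)                            # n x d output

n, d = 16, 4
rng = np.random.default_rng(0)
A1, A2, A3 = (rng.standard_normal((n, d)) for _ in range(3))
X, Y = rng.standard_normal((d, d)), rng.standard_normal((d, d))
O = attn_forward(A1, A2, A3, X, Y)
```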

Notably, FlashAttention [26, 25, 91] and SY [102] consider $Q,K,V\in\mathbb{R}^{n\times d}$ obtained after applying the linear layer to the previous inputs, while we consider the more detailed structure $Q=A_{1}W_{Q},K=A_{2}W_{K},V=A_{3}W_{V}$ (Definition 3.1), explicitly calculating module-wise gradients with respect to the attention weights. This explains why our I/O complexity bound $\Theta(\min\{\frac{n^{2}d^{2}+nd^{3}}{M},\frac{n^{2}d+nd^{2}}{\sqrt{M}}\})$ in Theorem 1.1 has an additional term $nd^{2}$ in the small cache case and $nd^{3}$ in the large cache case. When $n\geq d$, these additional terms are asymptotically dominated and disappear from the bounds.

Mathematically, optimizing the attention computation involves adjusting the attention weight matrices $X$ and $Y$. Using the previous results on attention gradients from AS24a [9] and LSS+ [68], we have the following definition of the attention gradient:

Definition 3.2 (Attention backward gradient).

Let $A_{1},A_{2}\in\mathbb{R}^{n\times d}$. Let $p(X)\in\mathbb{R}^{n\times n}$ be defined as in Definition B.9 (see Fig. 2 for an illustration). Let $L(X)$ be some loss function. The attention backward gradient for $X\in\mathbb{R}^{d\times d}$ is:

\frac{\mathrm{d}L(X)}{\mathrm{d}X}=A_{1}^{\top}p(X)A_{2}.
Remark 3.3.

Since the attention module depends only linearly on $Y$, it is straightforward to incorporate it into an algorithm, and it is not a complexity bottleneck. Thus, we focus on the case where $X$ is variable and $Y$ is a fixed input.
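As a sanity-check sketch (not the algorithms analyzed in this paper), the gradient $\frac{\mathrm{d}L(X)}{\mathrm{d}X}$ of Definition 3.2 can be obtained by automatic differentiation; the closed form $A_{1}^{\top}p(X)A_{2}$ with $p(X)$ from Definition B.9 is not reproduced here, and the squared loss below is an arbitrary choice of $L(X)$.

```python
# Sketch: dL(X)/dX via autograd for the attention module of Definition 3.1 with fixed Y.
import torch

n, d = 16, 4
torch.manual_seed(0)
A1, A2, A3 = (torch.randn(n, d) for _ in range(3))
Y = torch.randn(d, d)
X = torch.randn(d, d, requires_grad=True)

f = torch.softmax(A1 @ X @ A2.T, dim=1)   # attention matrix Softmax(A1 X A2^T)
O = f @ (A3 @ Y)                          # forward output
L = 0.5 * (O ** 2).sum()                  # example loss L(X)
L.backward()
g = X.grad                                # d x d gradient dL(X)/dX
```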

Figure 3: This diagram shows a summation tree with $d=2$ in the computational graph for the backward passes of attention using standard matrix multiplication. The orange and green nodes represent the input nodes of the level-1 summation tree. The brown nodes, along with the blue nodes (outputs of the level-1 summation tree), serve as inputs for the level-2 summation tree. The purple nodes represent the target output. When $d$ gets larger, the summation tree expands with additional layers, where each new layer introduces intermediate nodes that represent the sums of pairs of nodes from the previous layer, i.e., there are $1+\log_{2}d$ layers in total.

3.2 Summation Tree

In this subsection, we introduce the computational graph of the attention backward gradient, which is the key concept in our I/O complexity analysis.

In the computational graph shown in Fig. 2, we can first compute $A_{1}X$ and then compute $(A_{1}X)A_{2}^{\top}$, or first compute $XA_{2}^{\top}$ and then compute $A_{1}(XA_{2}^{\top})$. In either case, we perform two matrix multiplications: one between an $n\times d$ matrix and a $d\times d$ matrix, and the other between an $n\times d$ matrix and a $d\times n$ matrix. Without loss of generality for illustration, we consider the first case. To compute $A_{1}X$, we need to calculate the products $\{(A_{1})_{i,k}X_{k,j}\}$ for all $i\in[n],\,k\in[d],\,j\in[d]$. Each entry $(A_{1}X)_{i,j}$ is then obtained by summing these products over $k$: $(A_{1}X)_{i,j}=\sum_{k=1}^{d}(A_{1})_{i,k}X_{k,j}$. In the computational graph, this summation is represented by a summation tree that connects the product nodes $(A_{1})_{i,k}X_{k,j}$ to the sum node $(A_{1}X)_{i,j}$. We define the product nodes $(A_{1})_{i,k}X_{k,j}$, the nodes corresponding to the sums $(A_{1}X)_{i,j}$, and all intermediate nodes in the summation trees as level-1 nodes. Similarly, we define level-2 nodes as the nodes in the summation trees involved in computing $(A_{1}X)A_{2}^{\top}$. We give an example of the summation tree with $d=2$ in Fig. 3.
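The level-1 summation tree for a single entry $(A_{1}X)_{i,j}$ can be sketched as follows; for simplicity the code assumes $d$ is a power of two, as in the $d=2$ example of Fig. 3, and the numeric values are arbitrary.

```python
# Sketch: build the summation tree for (A1 X)_{i,j} = sum_k (A1)_{i,k} X_{k,j}.
# Leaves are the d product nodes; each layer sums pairs, giving 1 + log2(d) layers.
def summation_tree(products):
    layers = [products]                   # leaf layer: the d product nodes
    while len(layers[-1]) > 1:
        prev = layers[-1]
        layers.append([prev[t] + prev[t + 1] for t in range(0, len(prev), 2)])
    return layers

d = 4
A1_row, X_col = [1.0, 2.0, 3.0, 4.0], [0.5, 0.5, 0.5, 0.5]
products = [A1_row[k] * X_col[k] for k in range(d)]
layers = summation_tree(products)
print(len(layers), layers[-1][0])         # 3 layers in total; the root equals (A1 X)_{i,j} = 5.0
```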

3.3 I/O Complexity

There are various ways to define the two-level memory hierarchy and the I/O complexity. We state the definition in HK [49], which formulates the two-level memory hierarchy as a red-blue pebble game played on a computational graph. Very recently, SY [102] proved that the I/O complexity of forward computation of FlashAttention is optimal by analyzing the red-blue pebble game on an attention forward computational graph.

Definition 3.4 (Red-blue pebble game [49]).

Consider a game played on a directed acyclic graph, with a limited number of red pebbles and an unlimited number of blue pebbles. Initially, each input node (a node with no parents) is marked with a blue pebble, while all other nodes have no pebbles. The player is allowed to perform the following operations:

  • Input: Replace a blue pebble on a node with a red pebble.

  • Output: Replace a red pebble on a node with a blue pebble.

  • Compute: Place a red pebble on a node if all its parent nodes have red pebbles.

  • Delete: Remove a pebble from a node.

The objective of the game is to place blue pebbles on all output nodes (i.e., nodes with no children) while minimizing the total number of input and output operations used throughout the process.

In the red-blue pebble game, each node represents a computational task. A red pebble denotes a unit in the small but fast layer known as cache, while a blue pebble represents a unit in the large but slower layer called memory. A task can only be computed once all its dependent tasks are completed. All computations are assumed to occur within the cache. Hence, efficient use of cache plays a critical role in reducing the I/O operations of an algorithm to minimize the cost associated with data transfer between memory and cache. We can define the I/O complexity by using the red-blue pebble game.
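To make the game concrete, the following toy sketch tracks the red and blue pebble sets and counts input/output operations on a tiny computational graph under a cache bound of $M$ red pebbles; the graph and the schedule are illustrative choices only.

```python
# Sketch: the red-blue pebble game of Definition 3.4 on a toy DAG with cache bound M.
class RedBluePebbleGame:
    def __init__(self, parents, M):
        self.parents, self.M = parents, M
        self.red = set()
        self.blue = {v for v, ps in parents.items() if not ps}   # inputs start blue
        self.io = 0

    def load(self, v):      # Input: replace a blue pebble with a red pebble
        assert v in self.blue and len(self.red) < self.M
        self.blue.discard(v); self.red.add(v); self.io += 1

    def store(self, v):     # Output: replace a red pebble with a blue pebble
        assert v in self.red
        self.red.discard(v); self.blue.add(v); self.io += 1

    def compute(self, v):   # Compute: all parents must hold red pebbles
        assert all(p in self.red for p in self.parents[v]) and len(self.red) < self.M
        self.red.add(v)

    def delete(self, v):    # Delete: remove a pebble
        self.red.discard(v); self.blue.discard(v)

# Compute c = a + b with cache size M = 3: two inputs and one output, so 3 I/O operations.
g = RedBluePebbleGame({'a': [], 'b': [], 'c': ['a', 'b']}, M=3)
g.load('a'); g.load('b'); g.compute('c'); g.store('c')
print(g.io)   # 3
```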

Definition 3.5 (I/O complexity [49]).

Consider the red-blue pebble game played on a directed acyclic graph $G$. Let $M$ be a positive integer. The I/O complexity, denoted $Q(G,M)$, is the minimum number of input and output operations needed to complete the objective of the game with the restriction that no more than $M$ red pebbles are present on the graph at any time. We omit $G$ when it is clear from the context.

The red-blue pebble game provides insight into cache management by modeling the limited cache size through the number of red pebbles. The maximum number of red pebbles corresponds to the size of the cache, which means that there can be at most $M$ items in the cache at any given time.

4 Main Results

In Theorem 1.1, we provide matching upper and lower bounds for the I/O complexity of attention gradient computation in the backward passes. In detail, Theorem 1.1 states that the I/O complexity of the attention gradient computation is $\Theta(\min\{\frac{n^{2}d^{2}+nd^{3}}{M},\frac{n^{2}d+nd^{2}}{\sqrt{M}}\})$, which splits the cache size into two cases: (1) small cache $M=o(d^{2})$; (2) large cache $M=\Omega(d^{2})$. At the crossover point $M=d^{2}$, we have $\frac{n^{2}d^{2}+nd^{3}}{M}=\frac{n^{2}d+nd^{2}}{\sqrt{M}}=n^{2}+nd$. An intuitive illustration of the asymptotic I/O complexity is shown in Fig. 1.

Here we discuss the implications of Theorem 1.1. First, through the fine-grained analysis, our result identifies a critical point at $M=d^{2}$, where the I/O complexity changes its behavior. For $M=o(d^{2})$, we establish better upper and lower bounds compared to existing results, demonstrating that FlashAttention is not optimal in this regime. Moreover, when $M=\Omega(d^{2})$, Theorem 1.1 provides a tighter lower bound than existing work using the red-blue pebble game (Definition 3.4), offering insights for algorithm design.

Second, by combining the results of SY [102] with our findings, we provide a more general and tighter I/O complexity characterization of FlashAttention 1/2 [26, 25]. In the large cache scenario where $M=\Omega(d^{2})$, the attention forward I/O complexity is $\Theta(\frac{n^{2}d^{2}}{M})$, as discussed in Theorem 5.1 of SY [102]. Combining this result with our attention backward I/O complexity $\Theta(\frac{n^{2}d^{2}+nd^{3}}{M})$ (Theorem 1.1), we conclude that the overall complexity is $\Theta(\frac{n^{2}d^{2}+nd^{3}}{M})$. Thus, given that the cache size is sufficiently large, i.e., $M=\Omega(d^{2})$, the I/O complexity of the forward and backward computation of FlashAttention 1/2 is optimal.

Our main result, Theorem 1.1, is a summary of our results for different cache sizes (Theorems 4.1, 4.2, 4.3, and 4.4), which are discussed in the following subsections.

4.1 Large Cache

The large cache scenario is more interesting and practical. We now prove an upper bound below.

Theorem 4.1 (Large cache upper bound, informal version of Theorem D.5).

Suppose $n$ is the input length, $d$ is the head dimension, and $M=\Omega(d^{2})$ is the cache size. There is an algorithm (see Algorithm 9) that outputs a $d\times d$ matrix $g=\frac{\mathrm{d}L(X)}{\mathrm{d}X}$ (Definition 3.2) with I/O complexity $O(\frac{n^{2}d^{2}+nd^{3}}{M})$.

We then demonstrate that this upper bound is tight by providing a matching lower bound for the I/O complexity of the attention backward passes. To achieve this, we employ the framework developed in HK [49], which shows that executing an algorithm on a machine with a two-level memory hierarchy can be modeled as a red-blue pebble game (Definition 3.4) on a directed acyclic graph. We present the large cache lower bound below, which shows that as long as the cache size satisfies $M=\Omega(d^{2})$, the I/O complexity is at least $\Omega(\frac{n^{2}d^{2}+nd^{3}}{M})$.

Theorem 4.2 (Large cache lower bound, informal version of Theorem E.9).

Suppose $n$ is the input length and $d$ is the head dimension. Suppose the cache size $M=\Omega(d^{2})$. Then the I/O complexity of attention gradient computation using standard matrix multiplication is always $\Omega(\frac{n^{2}d^{2}+nd^{3}}{M})$.

4.2 Small Cache

In the small cache case, we provide an upper bound below. Notice that it is better than the I/O complexity of FlashAttention, which is $O(\frac{n^{2}d^{2}+nd^{3}}{M})>O(\frac{n^{2}d+nd^{2}}{\sqrt{M}})$ when $M=o(d^{2})$.

Theorem 4.3 (Small cache upper bound, informal version of Theorem C.12).

Suppose $n$ is the input length, $d$ is the head dimension, and $M=o(d^{2})$ is the cache size. There is an algorithm (see Algorithm 6) that outputs a $d\times d$ matrix $g=\frac{\mathrm{d}L(X)}{\mathrm{d}X}$ (Definition 3.2) with I/O complexity $O(\frac{n^{2}d+nd^{2}}{\sqrt{M}})$, time complexity $O(n^{2}d+nd^{2})$, and space complexity $O(n^{2}+d^{2})$.

Furthermore, we show that attention gradient computation can be reduced to matrix multiplication, establishing a matching lower bound.

Theorem 4.4 (Small cache lower bound, informal version of Theorem E.10).

Suppose $n$ is the input length and $d$ is the head dimension. Suppose the cache size $M=o(d^{2})$. Then the I/O complexity of attention gradient computation using standard matrix multiplication is always $\Omega(\frac{n^{2}d+nd^{2}}{\sqrt{M}})$.

4.3 Lower Bound of Sparse Attention Forward and Backward Passes

Sparse attention is a generalization of standard attention and has been popular in practical applications. We refer readers to Section 2 for more discussion. To state our results, we first introduce some notation. For any matrix $A$, we use $\operatorname{nnz}(A)$ to denote the number of non-zero entries in the matrix $A$. We assume that sparse matrices are stored by listing only their non-zero entries along with their coordinates. We assume sparse semi-ring matrix multiplication, which restricts operations to addition and multiplication of these entries. Each output entry $(AB)_{i,j}$ can only be computed as the sum of products $\sum_{k}A_{i,k}B_{k,j}$.
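A minimal sketch of this storage format and of sparse semi-ring multiplication is given below; the coordinate-list representation and the toy matrices are illustrative choices.

```python
# Sketch: sparse semi-ring matrix multiplication over matrices stored as
# {(row, col): value} dictionaries of their non-zero entries.
from collections import defaultdict

def sparse_matmul(A_nnz, B_nnz):
    B_by_row = defaultdict(list)
    for (k, j), b in B_nnz.items():
        B_by_row[k].append((j, b))
    C = defaultdict(float)
    for (i, k), a in A_nnz.items():
        for j, b in B_by_row[k]:
            C[(i, j)] += a * b            # only additions/multiplications of stored entries
    return dict(C)

A = {(0, 0): 2.0, (1, 1): 3.0}            # toy sparse inputs
B = {(0, 1): 1.0, (1, 0): 4.0}
print(sparse_matmul(A, B))                # {(0, 1): 2.0, (1, 0): 12.0}
```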

Theorem 4.5 (Lower bound for sparse attention forward and backward, formal version of Theorem 1.2).

Suppose $n$ is the input length, $d$ is the head dimension, and $M$ is the cache size. Let $Z_{A}:=\min\{\operatorname{nnz}(A_{1}),\operatorname{nnz}(A_{2})\}$, $Z_{X}:=\operatorname{nnz}(X)$, $Z_{AX}:=\min\{\operatorname{nnz}(A_{1}X),\operatorname{nnz}(XA_{2}^{\top})\}$, and $Z_{AXA}:=\operatorname{nnz}(A_{1}XA_{2}^{\top})$. Then any algorithm for both attention forward and backward computation using sparse semi-ring matrix multiplication has I/O complexity

\Omega\left(\min\left\{\frac{Z_{A}^{2}+Z_{A}Z_{X}}{M},\frac{Z_{A}\sqrt{Z_{AXA}}+\sqrt{Z_{A}Z_{X}Z_{AXA}}}{\sqrt{M}}\right\}\right).
Remark 4.6.

When the matrices involved in attention computation are dense, we have $Z_{A}=\Omega(nd)$, $Z_{X}=\Omega(d^{2})$, $Z_{AX}=\Omega(nd)$, and $Z_{AXA}=\Omega(n^{2})$. In this case, our lower bound reads as $\Omega(\min\{\frac{n^{2}d^{2}+nd^{3}}{M},\frac{n^{2}d+nd^{2}}{\sqrt{M}}\})$. Hence, it matches the lower bounds in the dense case.

5 Technical Overview

Upper Bound of Small Cache. In Section C, we present algorithms for the backward passes of attention in the small cache case, where $M=o(d^{2})$. We observe that when $M=o(d^{2})$, we have $\frac{n^{2}d^{2}+nd^{3}}{M}>\frac{n^{2}d+nd^{2}}{\sqrt{M}}>n^{2}+nd$. We exploit this to design an algorithm with I/O complexity better than $\frac{n^{2}d^{2}+nd^{3}}{M}$ by reading/writing the $n\times n$ attention matrix and other $n\times d$ intermediate matrices from/to memory. In detail, our small cache algorithm (Algorithm 6) follows the computational graph in Figure 2 and is divided into four phases. In Phase 1 (Algorithm 2), we compute the attention matrix $f$ (Definition B.5) and write it to memory. In Phase 2 (Algorithm 3), we compute $q$ (Definition B.8), incorporating the information from the upstream gradient $\mathrm{d}O$. Phase 3 (Algorithm 4) computes the gradient component matrix $p$ (Definition B.9). Finally, in Phase 4 (Algorithm 5), we compute the final gradient $g=A_{1}^{\top}pA_{2}$ (Definition 3.2). At a high level, our algorithm splits the input and output matrices into blocks of size $\sqrt{M}\times\sqrt{M}$. In contrast, FlashAttention divides the $n\times d$ input matrices into multiple $k\times d$ matrices, where $k<n$. Compared to our upper bound, we can see that FlashAttention is not optimal in this case. Following the computational graph in Figure 2, we perform the backward passes of attention using each $\sqrt{M}\times\sqrt{M}$ block as a basic element in standard matrix multiplication. Compared to the forward passes, the computational graph of the backward passes is more complicated and requires a more fine-grained analysis, e.g., the four phases mentioned above. Through a detailed analysis of Algorithm 6, we establish Theorem 4.3.
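The block structure underlying the small-cache analysis can be illustrated by the following sketch of a $\sqrt{M}\times\sqrt{M}$-tiled matrix multiplication; this is an illustration of the tiling idea only, not Algorithm 6 itself, and the matrix sizes and the constant factor in the block side length are arbitrary choices.

```python
# Sketch: multiply A (n1 x d) by B (d x n2) using sqrt(M) x sqrt(M) blocks, so that one
# output block plus two input blocks fit in a cache of size M at any time.
import numpy as np

def blocked_matmul(A, B, M):
    b = max(1, int(np.sqrt(M)) // 2)          # block side length, about sqrt(M)
    n1, d = A.shape
    _, n2 = B.shape
    C = np.zeros((n1, n2))
    for i in range(0, n1, b):
        for j in range(0, n2, b):
            acc = np.zeros((min(b, n1 - i), min(b, n2 - j)))   # output block kept in cache
            for k in range(0, d, b):
                acc += A[i:i+b, k:k+b] @ B[k:k+b, j:j+b]       # read two blocks per step
            C[i:i+b, j:j+b] = acc                              # write the block once
    return C

rng = np.random.default_rng(0)
A, B = rng.standard_normal((64, 16)), rng.standard_normal((16, 64))
assert np.allclose(blocked_matmul(A, B, M=64), A @ B)
```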

Upper Bound of Large Cache. In Section D, we present algorithms for attention backward passes in the large cache case, where $M=\Omega(d^{2})$. Similar to FlashAttention, the $n\times n$ attention matrix $f$ (Definition B.5) cannot be directly loaded into cache, even though it has been computed and can be stored in memory. The overall algorithm (Algorithm 9) consists of two phases. In Phase 1 (Algorithm 7), we compute $S=A_{1}X$ and $h=A_{3}Y$, and these two matrices are then passed to Phase 2. In Phase 2 (Algorithm 8), the inputs are the matrices $A_{1},A_{2},S,h,O,\mathrm{d}O\in\mathbb{R}^{n\times d}$ (Definitions 3.1, B.6, B.7, and B.8) and the vector $l\in\mathbb{R}^{n}$ (Definition B.4). We vertically divide the inputs into row block matrices of size $B_{r}\times d$ or $B_{c}\times d$, where $B_{r}=\min\{\lceil M/4d\rceil,d\}$ and $B_{c}=\lceil M/4d\rceil$. Using these row block matrices as computation units, we follow the computational graph (Fig. 2) and FlashAttention's procedure. After accounting for the reads and writes of the overall algorithm (Algorithm 9), we prove Theorem 4.1. Furthermore, when the cache size is as large as $\Theta(nd)$, the I/O complexity can be reduced to $O(nd+d^{2})$, which corresponds to the size of the input and output of the algorithm.
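For intuition, the following sketch computes the row-block sizes $B_{r}$ and $B_{c}$ used above and the resulting number of outer-loop row blocks; the values of $n$, $d$, and $M$ are arbitrary large-cache examples.

```python
# Sketch: row-block sizes B_r = min(ceil(M/4d), d) and B_c = ceil(M/4d) in the large-cache
# backward pass, so that a constant number of B_r x d and B_c x d blocks fit in the cache.
import math

def row_block_sizes(M: int, d: int):
    Bc = math.ceil(M / (4 * d))
    Br = min(Bc, d)
    return Br, Bc

n, d = 8192, 128
M = 4 * d * d                            # a large-cache example, M = Omega(d^2)
Br, Bc = row_block_sizes(M, d)
num_row_blocks = math.ceil(n / Br)       # outer-loop iterations over row blocks
print(Br, Bc, num_row_blocks)            # 128 128 64
```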

Lower Bound of Large Cache and Small Cache. In Section E, we establish the lower bounds for the I/O complexity of attention gradient computation in both the large and small cache cases. Following Definitions 3.4 and 3.5, we analyze the red-blue pebble game on the computational graph of any attention backward algorithm that uses standard matrix multiplication. More specifically, the key concept is the $M$-partition, which decomposes the graph into subgraphs, ensuring that each subgraph satisfies conditions related to dominator and minimum sets (Definitions E.1, E.2, E.3, E.4, and E.5). Our proofs of the lower bound for backward passes build upon Lemmas E.7 and E.8, which provide the foundation for relating the number of subgraphs to the required I/O operations. For the large cache scenario, $M=\Omega(d^{2})$, we demonstrate that the I/O complexity scales with the need to compute matrix products efficiently. In the small cache case, $M=o(d^{2})$, we show that higher I/O complexity is unavoidable due to the data transfers between cache and memory, by reducing to standard matrix multiplication. These analyses are formally established in the proofs of Theorems E.9 and E.10. In particular, Theorem E.10, the small cache lower bound, requires a new analysis that deviates from existing approaches.

Remark 5.1.

The Softmax in Definition 3.1 can be replaced by other non-linear activation functions and our lower bound still holds, because non-linear attention must still compute a matrix multiplication between an $n\times d$ matrix and a $d\times n$ matrix. However, for linear attention, i.e., $A_{1}XA_{2}^{\top}A_{3}Y$, our lower bound is loose, since we can compute $\underbrace{A_{2}^{\top}}_{d\times n}\underbrace{A_{3}}_{n\times d}$ first, and then compute $\underbrace{A_{1}}_{n\times d}\underbrace{X}_{d\times d}\underbrace{A_{2}^{\top}A_{3}}_{d\times d}\underbrace{Y}_{d\times d}$.
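The benefit of reordering in the linear case can be made concrete with a simple FLOP count; the sketch below compares the two multiplication orders under the usual $2mkp$ cost model for multiplying an $m\times k$ matrix by a $k\times p$ matrix, with arbitrary example values of $n$ and $d$.

```python
# Sketch: FLOP counts for linear attention A1 X A2^T A3 Y under two multiplication orders.
def matmul_flops(m, k, p):
    return 2 * m * k * p   # standard cost of multiplying an m x k matrix by a k x p matrix

n, d = 8192, 128
# Order 1: ((A1 X) A2^T) (A3 Y) -- materializes an n x n intermediate matrix.
order1 = (matmul_flops(n, d, d) + matmul_flops(n, d, n)
          + matmul_flops(n, d, d) + matmul_flops(n, n, d))
# Order 2: A1 (X ((A2^T A3) Y)) -- every intermediate is at most n x d or d x d.
order2 = (matmul_flops(d, n, d) + matmul_flops(d, d, d)
          + matmul_flops(d, d, d) + matmul_flops(n, d, d))
print(f"order 1: {order1:.2e} FLOPs, order 2: {order2:.2e} FLOPs")
```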

Lower Bound of Sparse Attention Forward and Backward Passes. In Section F, we establish lower bounds on the I/O complexity of sparse attention computation for both forward and backward passes. Sparse matrix multiplication is considered, where only non-zero entries are stored and used in computations. We derive I/O complexity bounds based on the non-zero counts of the input matrices and the I/O operations required for sparse matrix multiplication (Lemma F.1). We further extend these bounds to the matrix products involved in the attention mechanism (Lemma F.2), which requires analyzing multiple sparse matrix multiplications. We analyze scenarios where matrices are stored in cache or require intermediate I/Os during computation to obtain the I/O complexity bounds for both forward and backward passes (Theorems F.3 and F.4), and Theorem 4.5 follows directly as a consequence.

6 Conclusion

This work provided a comprehensive analysis of the I/O complexity of attention mechanisms, focusing on backward passes. We established tight bounds on I/O complexity for both small and large caches. Our results confirm that FlashAttention is optimal for both forward and backward passes when the cache is large. For small cache sizes, we provided improved upper and lower bounds compared to existing methods. Additionally, we derived lower bounds for sparse attention for both forward and backward passes across all cache sizes. Our findings complete the theoretical foundation for I/O complexity in attention mechanisms, offering insights for efficient LLM training and inference. We leave exploring practical implementations that leverage these theoretical insights and investigating the I/O complexity of other emerging attention variants as future work.

Acknowledgement

Research is partially supported by the National Science Foundation (NSF) Grants 2023239-DMS, CCF-2046710, and Air Force Grant FA9550-18-1-0166.

References

  • AA [22] Amol Aggarwal and Josh Alman. Optimal-degree polynomial approximations for exponentials and gaussian kernel density estimation. In Proceedings of the 37th Computational Complexity Conference, pages 1–23, 2022.
  • AAA+ [23] Josh Achiam, Steven Adler, Sandhini Agarwal, Lama Ahmad, Ilge Akkaya, Florencia Leoni Aleman, Diogo Almeida, Janko Altenschmidt, Sam Altman, Shyamal Anadkat, et al. Gpt-4 technical report. arXiv preprint arXiv:2303.08774, 2023.
  • ABIS [19] Jayadev Acharya, Sourbh Bhadane, Piotr Indyk, and Ziteng Sun. Estimating entropy of distributions in constant space. Advances in Neural Information Processing Systems, 32, 2019.
  • AJA+ [24] Marah Abdin, Sam Ade Jacobs, Ammar Ahmad Awan, Jyoti Aneja, Ahmed Awadallah, Hany Awadalla, Nguyen Bach, Amit Bahree, Arash Bakhtiari, Harkirat Behl, et al. Phi-3 technical report: A highly capable language model locally on your phone. arXiv preprint arXiv:2404.14219, 2024.
  • ALSY [23] Raghav Addanki, Chenyang Li, Zhao Song, and Chiwun Yang. One pass streaming algorithm for super long token attention approximation in sublinear space. arXiv preprint arXiv:2311.14652, 2023.
  • AMNW [22] Maryam Aliakbarpour, Andrew McGregor, Jelani Nelson, and Erik Waingarten. Estimation of entropy in constant space with improved sample complexity. Advances in Neural Information Processing Systems, 35:32474–32486, 2022.
  • Ant [24] Anthropic. The claude 3 model family: Opus, sonnet, haiku, 2024. https://www-cdn.anthropic.com/de8ba9b01c9ab7cbabf5c33b80b7bbc618857627/Model_Card_Claude_3.pdf.
  • AS [23] Josh Alman and Zhao Song. Fast attention requires bounded entries. Advances in Neural Information Processing Systems, 36, 2023.
  • [9] Josh Alman and Zhao Song. The fine-grained complexity of gradient computation for training large language models. arXiv preprint arXiv:2402.04497, 2024.
  • [10] Josh Alman and Zhao Song. How to capture higher-order correlations? generalizing matrix softmax attention to kronecker computation. In The Twelfth International Conference on Learning Representations, 2024.
  • AV [88] Alok Aggarwal and Jeffrey S. Vitter. The input/output complexity of sorting and related problems. Communications of the ACM, 31(9):1116–1127, 1988.
  • BBS [22] Gavin Brown, Mark Bun, and Adam Smith. Strong memory lower bounds for learning natural models. In Conference on Learning Theory, pages 4989–5029. PMLR, 2022.
  • BCC+ [16] Michael A Bender, Rezaul Chowdhury, Alexander Conway, Martin Farach-Colton, Pramod Ganapathi, Rob Johnson, Samuel McCauley, Bertrand Simon, and Shikha Singh. The i/o complexity of computing prime tables. In LATIN 2016: Theoretical Informatics: 12th Latin American Symposium, Ensenada, Mexico, April 11-15, 2016, Proceedings 12, pages 192–206. Springer, 2016.
  • BCE+ [23] Sébastien Bubeck, Varun Chandrasekaran, Ronen Eldan, Johannes Gehrke, Eric Horvitz, Ece Kamar, Peter Lee, Yin Tat Lee, Yuanzhi Li, Scott Lundberg, Harsha Nori, Hamid Palangi, Marco Tulio Ribeiro, and Yi Zhang. Sparks of artificial general intelligence: Early experiments with gpt-4. arXiv preprint arXiv:2303.12712, 2023.
  • BDS [19] Gianfranco Bilardi and Lorenzo De Stefani. The i/o complexity of toom-cook integer multiplication. In Proceedings of the Thirtieth Annual ACM-SIAM Symposium on Discrete Algorithms, pages 2034–2052. SIAM, 2019.
  • BPC [20] Iz Beltagy, Matthew E Peters, and Arman Cohan. Longformer: The long-document transformer. arXiv preprint arXiv:2004.05150, 2020.
  • CAM [24] Anshuman Chhabra, Hadi Askari, and Prasant Mohapatra. Revisiting zero-shot abstractive summarization in the era of large language models from the perspective of position bias. arXiv preprint arXiv:2401.01989, 2024.
  • CGRS [19] Rewon Child, Scott Gray, Alec Radford, and Ilya Sutskever. Generating long sequences with sparse transformers. arXiv preprint arXiv:1904.10509, 2019.
  • CLS+ [24] Bo Chen, Yingyu Liang, Zhizhou Sha, Zhenmei Shi, and Zhao Song. Hsr-enhanced sparse attention acceleration, 2024.
  • CP [23] Xi Chen and Binghui Peng. Memory-query tradeoffs for randomized convex optimization. In 2023 IEEE 64th Annual Symposium on Foundations of Computer Science (FOCS), pages 1400–1413. IEEE, 2023.
  • CPP [22] Xi Chen, Christos Papadimitriou, and Binghui Peng. Memory bounds for continual learning. In 2022 IEEE 63rd Annual Symposium on Foundations of Computer Science (FOCS), pages 519–530. IEEE, 2022.
  • CWW+ [24] Yupeng Chang, Xu Wang, Jindong Wang, Yuan Wu, Linyi Yang, Kaijie Zhu, Hao Chen, Xiaoyuan Yi, Cunxiang Wang, Yidong Wang, et al. A survey on evaluation of large language models. ACM Transactions on Intelligent Systems and Technology, 15(3):1–45, 2024.
  • CXCL [20] Yi Cui, Di Xiao, Daren BH Cline, and Dmitri Loguinov. Improving i/o complexity of triangle enumeration. IEEE Transactions on Knowledge and Data Engineering, 34(4):1815–1828, 2020.
  • CYL+ [24] Weize Chen, Ziming You, Ran Li, Yitong Guan, Chen Qian, Chenyang Zhao, Cheng Yang, Ruobing Xie, Zhiyuan Liu, and Maosong Sun. Internet of agents: Weaving a web of heterogeneous agents for collaborative intelligence. arXiv preprint arXiv:2407.07061, 2024.
  • Dao [23] Tri Dao. Flashattention-2: Faster attention with better parallelism and work partitioning. arXiv preprint arXiv:2307.08691, 2023.
  • DFE+ [22] Tri Dao, Dan Fu, Stefano Ermon, Atri Rudra, and Christopher Ré. Flashattention: Fast and memory-efficient exact attention with io-awareness. Advances in Neural Information Processing Systems, 35:16344–16359, 2022.
  • DG [24] Tri Dao and Albert Gu. Transformers are ssms: Generalized models and efficient algorithms through structured state space duality. arXiv preprint arXiv:2405.21060, 2024.
  • DL [18] Erik D Demaine and Quanquan C Liu. Red-blue pebble game: Complexity of computing the trade-off between cache size and memory transfers. In Proceedings of the 30th on Symposium on Parallelism in Algorithms and Architectures, pages 195–204, 2018.
  • DLL+ [17] Erik D Demaine, Andrea Lincoln, Quanquan C Liu, Jayson Lynch, and Virginia Vassilevska Williams. Fine-grained i/o complexity via reductions: New lower bounds, faster algorithms, and a time hierarchy. arXiv preprint arXiv:1711.07960, 2017.
  • DLS+ [24] Nouha Dziri, Ximing Lu, Melanie Sclar, Xiang Lorraine Li, Liwei Jiang, Bill Yuchen Lin, Sean Welleck, Peter West, Chandra Bhagavatula, Ronan Le Bras, et al. Faith and fate: Limits of transformers on compositionality. Advances in Neural Information Processing Systems, 36, 2024.
  • [31] Lorenzo De Stefani. The i/o complexity of hybrid algorithms for square matrix multiplication. arXiv preprint arXiv:1904.12804, 2019.
  • [32] Lorenzo De Stefani. On the i/o complexity of hybrid algorithms for integer multiplication. arXiv preprint arXiv:1912.08045, 2019.
  • DSWZ [23] Yichuan Deng, Zhao Song, Zifan Wang, and Han Zhang. Streaming kernel pca algorithm with small space. arXiv preprint arXiv:2303.04555, 2023.
  • DSY [24] Yichuan Deng, Zhao Song, and Chiwun Yang. Attention is naturally sparse with gaussian distributed input. arXiv preprint arXiv:2404.02690, 2024.
  • DT [24] Shiyuan Deng and Yufei Tao. Subgraph enumeration in optimal i/o complexity. In 27th International Conference on Database Theory (ICDT 2024). Schloss Dagstuhl–Leibniz-Zentrum für Informatik, 2024.
  • EZW+ [22] Beyza Ermis, Giovanni Zappella, Martin Wistuba, Aditya Rawal, and Cedric Archambeau. Memory efficient continual learning with transformers. Advances in Neural Information Processing Systems, 35:10629–10642, 2022.
  • FJL+ [24] Tao Feng, Chuanyang Jin, Jingyu Liu, Kunlun Zhu, Haoqin Tu, Zirui Cheng, Guanyu Lin, and Jiaxuan You. How far are we from agi. arXiv preprint arXiv:2405.10313, 2024.
  • FTH+ [24] Leo Feng, Frederick Tung, Hossein Hajimirsadeghi, Mohamed Osama Ahmed, Yoshua Bengio, and Greg Mori. Attention as an rnn. arXiv preprint arXiv:2405.13956, 2024.
  • GD [23] Albert Gu and Tri Dao. Mamba: Linear-time sequence modeling with selective state spaces. arXiv preprint arXiv:2312.00752, 2023.
  • GHTL [14] William Gropp, Torsten Hoefler, Rajeev Thakur, and Ewing Lusk. Using advanced MPI: Modern features of the message-passing interface. MIT Press, 2014.
  • GLM [20] Alon Gonen, Shachar Lovett, and Michal Moshkovitz. Towards a combinatorial characterization of bounded-memory learning. Advances in Neural Information Processing Systems, 33:9028–9038, 2020.
  • GRT [18] Sumegha Garg, Ran Raz, and Avishay Tal. Extractor-based time-space lower bounds for learning. In Proceedings of the 50th Annual ACM SIGACT Symposium on Theory of Computing, pages 990–1002, 2018.
  • GXD+ [23] Daya Guo, Canwen Xu, Nan Duan, Jian Yin, and Julian McAuley. Longcoder: A long-range pre-trained language model for code completion. In International Conference on Machine Learning, pages 12098–12107. PMLR, 2023.
  • GXG+ [23] Yunfan Gao, Yun Xiong, Xinyu Gao, Kangxiang Jia, Jinliu Pan, Yuxi Bi, Yi Dai, Jiawei Sun, and Haofen Wang. Retrieval-augmented generation for large language models: A survey. arXiv preprint arXiv:2312.10997, 2023.
  • GZL+ [23] Suyu Ge, Yunan Zhang, Liyuan Liu, Minjia Zhang, Jiawei Han, and Jianfeng Gao. Model tells you what to discard: Adaptive kv cache compression for llms. arXiv preprint arXiv:2310.01801, 2023.
  • HCL+ [24] Jerry Yao-Chieh Hu, Pei-Hsuan Chang, Haozheng Luo, Hong-Yu Chen, Weijian Li, Wei-Po Wang, and Han Liu. Outlier-efficient hopfield layers for large transformer-based models. In Forty-first International Conference on Machine Learning (ICML), 2024.
  • HCW+ [24] Jerry Yao-Chieh Hu, Bo-Yu Chen, Dennis Wu, Feng Ruan, and Han Liu. Nonparametric modern hopfield models. arXiv preprint arXiv:2404.03900, 2024.
  • HJK+ [24] Insu Han, Rajesh Jayaram, Amin Karbasi, Vahab Mirrokni, David Woodruff, and Amir Zandieh. Hyperattention: Long-context attention in near-linear time. In The Twelfth International Conference on Learning Representations, 2024.
  • HK [81] Jia-Wei Hong and Hsiang-Tsung Kung. I/o complexity: The red-blue pebble game. In Proceedings of the thirteenth annual ACM symposium on Theory of computing, pages 326–333, 1981.
  • HKLM [21] Max Hopkins, Daniel Kane, Shachar Lovett, and Michal Moshkovitz. Bounded memory active learning through enriched queries. In Conference on Learning Theory, pages 2358–2387. PMLR, 2021.
  • HLSL [24] Jerry Yao-Chieh Hu, Thomas Lin, Zhao Song, and Han Liu. On computational limits of modern hopfield models: A fine-grained complexity analysis. In Forty-first International Conference on Machine Learning (ICML), 2024.
  • HSK+ [24] Jerry Yao-Chieh Hu, Maojiang Su, En-Jui Kuo, Zhao Song, and Han Liu. Computational limits of low-rank adaptation (lora) for transformer-based models. arXiv preprint arXiv:2406.03136, 2024.
  • HSW+ [22] Edward J Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, and Weizhu Chen. LoRA: Low-rank adaptation of large language models. In International Conference on Learning Representations, 2022.
  • HWL [21] Weihua He, Yongyun Wu, and Xiaohua Li. Attention mechanism for neural machine translation: a survey. In 2021 IEEE 5th Information Technology, Networking, Electronic and Automation Control Conference (ITNEC), volume 5, pages 1485–1489. IEEE, 2021.
  • HYW+ [23] Jerry Yao-Chieh Hu, Donglin Yang, Dennis Wu, Chenwei Xu, Bo-Yu Chen, and Han Liu. On sparse modern hopfield model. In Thirty-seventh Conference on Neural Information Processing Systems (NeurIPS), 2023.
  • JHC [21] Yuli Jiang, Xin Huang, and Hong Cheng. I/o efficient k-truss community search in massive graphs. The VLDB Journal, 30(5):713–738, 2021.
  • JSM+ [23] Albert Q Jiang, Alexandre Sablayrolles, Arthur Mensch, Chris Bamford, Devendra Singh Chaplot, Diego de las Casas, Florian Bressand, Gianna Lengyel, Guillaume Lample, Lucile Saulnier, et al. Mistral 7b. arXiv preprint arXiv:2310.06825, 2023.
  • JZ [20] Saachi Jain and Matei Zaharia. Spectral lower bounds on the i/o complexity of computation graphs. In Proceedings of the 32nd ACM Symposium on Parallelism in Algorithms and Architectures, pages 329–338, 2020.
  • KGR+ [22] Takeshi Kojima, Shixiang Shane Gu, Machel Reid, Yutaka Matsuo, and Yusuke Iwasawa. Large language models are zero-shot reasoners. Advances in neural information processing systems, 35:22199–22213, 2022.
  • KHC+ [24] Tzu-Sheng Kuo, Aaron Lee Halfaker, Zirui Cheng, Jiwoo Kim, Meng-Hsin Wu, Tongshuang Wu, Kenneth Holstein, and Haiyi Zhu. Wikibench: Community-driven data curation for ai evaluation on wikipedia. In Proceedings of the CHI Conference on Human Factors in Computing Systems, pages 1–24, 2024.
  • KMZ [23] Praneeth Kacham, Vahab Mirrokni, and Peilin Zhong. Polysketchformer: Fast transformers via sketches for polynomial kernels. arXiv preprint arXiv:2310.01655, 2023.
  • LCT+ [24] Na Liu, Liangyu Chen, Xiaoyu Tian, Wei Zou, Kaijiang Chen, and Ming Cui. From llm to conversational agent: A memory enhanced architecture with fine-tuning of large language models. arXiv preprint arXiv:2401.02777, 2024.
  • LL [21] Xiang Lisa Li and Percy Liang. Prefix-tuning: Optimizing continuous prompts for generation. arXiv preprint arXiv:2101.00190, 2021.
  • [64] Yingyu Liang, Heshan Liu, Zhenmei Shi, Zhao Song, Zhuoyan Xu, and Junze Yin. Conv-basis: A new paradigm for efficient attention inference and gradient computation in transformers. arXiv preprint arXiv:2405.05219, 2024.
  • [65] Yingyu Liang, Jiangxuan Long, Zhenmei Shi, Zhao Song, and Yufa Zhou. Beyond linear approximations: A novel pruning approach for attention matrix, 2024.
  • LLSS [24] Xiaoyu Li, Yingyu Liang, Zhenmei Shi, and Zhao Song. A tighter complexity analysis of sparsegpt. arXiv preprint arXiv:2408.12151, 2024.
  • LPP+ [20] Patrick Lewis, Ethan Perez, Aleksandra Piktus, Fabio Petroni, Vladimir Karpukhin, Naman Goyal, Heinrich Küttler, Mike Lewis, Wen-tau Yih, Tim Rocktäschel, et al. Retrieval-augmented generation for knowledge-intensive nlp tasks. Advances in Neural Information Processing Systems, 33:9459–9474, 2020.
  • LSS+ [24] Yingyu Liang, Zhizhou Sha, Zhenmei Shi, Zhao Song, and Yufa Zhou. Multi-layer transformers gradient can be approximated in almost linear time. arXiv preprint arXiv:2408.13233, 2024.
  • LSSY [24] Yingyu Liang, Zhenmei Shi, Zhao Song, and Chiwun Yang. Toward infinite-long prefix in transformer. arXiv preprint arXiv:2406.14036, 2024.
  • [70] Yingyu Liang, Zhenmei Shi, Zhao Song, and Yufa Zhou. Differential privacy of cross-attention with provable guarantee. arXiv preprint arXiv:2407.14717, 2024.
  • [71] Yingyu Liang, Zhenmei Shi, Zhao Song, and Yufa Zhou. Tensor attention training: Provably efficient learning of higher-order transformers. arXiv preprint arXiv:2405.16411, 2024.
  • LSZ+ [20] S Cliff Liu, Zhao Song, Hengjie Zhang, Lichen Zhang, and Tianyi Zhou. Space-efficient interior point method, with applications to linear programming and maximum weight bipartite matching. arXiv preprint arXiv:2009.06106, 2020.
  • LT [24] AI @ Meta Llama Team. The llama 3 herd of models. arXiv preprint arXiv:2407.21783, 2024.
  • MLH+ [22] Sewon Min, Xinxi Lyu, Ari Holtzman, Mikel Artetxe, Mike Lewis, Hannaneh Hajishirzi, and Luke Zettlemoyer. Rethinking the role of demonstrations: What makes in-context learning work? arXiv preprint arXiv:2202.12837, 2022.
  • MLT+ [24] Adyasha Maharana, Dong-Ho Lee, Sergey Tulyakov, Mohit Bansal, Francesco Barbieri, and Yuwei Fang. Evaluating very long-term conversational memory of llm agents. arXiv preprint arXiv:2402.17753, 2024.
  • MMS+ [19] Louis Martin, Benjamin Muller, Pedro Javier Ortiz Suárez, Yoann Dupont, Laurent Romary, Éric Villemonte de La Clergerie, Djamé Seddah, and Benoit Sagot. Camembert: a tasty french language model. arXiv preprint arXiv:1911.03894, 2019.
  • MPK [21] Arnab Maiti, Vishakha Patil, and Arindam Khan. Multi-armed bandits with bounded arm-memory: Near-optimal guarantees for best-arm identification and regret minimization. Advances in Neural Information Processing Systems, 34:19553–19565, 2021.
  • MSSV [22] Annie Marsden, Vatsal Sharan, Aaron Sidford, and Gregory Valiant. Efficient convex optimization requires superlinear memory. In Conference on Learning Theory, pages 2390–2430. PMLR, 2022.
  • MT [17] Michal Moshkovitz and Naftali Tishby. A general memory-bounded learning algorithm. arXiv preprint arXiv:1712.03524, 2017.
  • MVK+ [24] Jean Mercat, Igor Vasiljevic, Sedrick Keh, Kushal Arora, Achal Dave, Adrien Gaidon, and Thomas Kollar. Linearizing large language models. arXiv preprint arXiv:2405.06640, 2024.
  • NS [19] Roy Nissim and Oded Schwartz. Revisiting the i/o-complexity of fast matrix multiplication with recomputations. In 2019 IEEE International Parallel and Distributed Processing Symposium (IPDPS), pages 482–490. IEEE, 2019.
  • OEN+ [22] Catherine Olsson, Nelson Elhage, Neel Nanda, Nicholas Joseph, Nova DasSarma, Tom Henighan, Ben Mann, Amanda Askell, Yuntao Bai, Anna Chen, et al. In-context learning and induction heads. arXiv preprint arXiv:2209.11895, 2022.
  • Ope [24] OpenAI. Introducing openai o1-preview. https://openai.com/index/introducing-openai-o1-preview/, 2024. Accessed: September 12.
  • PR [23] Binghui Peng and Aviad Rubinstein. Near optimal memory-regret tradeoff for online learning. In 2023 IEEE 64th Annual Symposium on Foundations of Computer Science (FOCS), pages 1171–1194. IEEE, 2023.
  • PS [14] Rasmus Pagh and Morten Stöckel. The input/output complexity of sparse matrix multiplication. In European Symposium on Algorithms, pages 750–761. Springer, 2014.
  • PZ [23] Binghui Peng and Fred Zhang. Online prediction in sub-linear space. In Proceedings of the 2023 Annual ACM-SIAM Symposium on Discrete Algorithms (SODA), pages 1611–1634. SIAM, 2023.
  • Raz [17] Ran Raz. A time-space lower bound for a large class of learning problems. In 2017 IEEE 58th Annual Symposium on Foundations of Computer Science (FOCS), pages 732–742. IEEE, 2017.
  • Raz [18] Ran Raz. Fast learning requires good memory: A time-space lower bound for parity learning. Journal of the ACM (JACM), 66(1):1–18, 2018.
  • RST+ [24] Machel Reid, Nikolay Savinov, Denis Teplyashin, Dmitry Lepikhin, Timothy Lillicrap, Jean-baptiste Alayrac, Radu Soricut, Angeliki Lazaridou, Orhan Firat, Julian Schrittwieser, et al. Gemini 1.5: Unlocking multimodal understanding across millions of tokens of context. arXiv preprint arXiv:2403.05530, 2024.
  • SAMB [24] Tanmay Singh, Harshvardhan Aditya, Vijay K Madisetti, and Arshdeep Bahga. Whispered tuning: Data privacy preservation in fine-tuning llms through differential privacy. Journal of Software Engineering and Applications, 17(1):1–22, 2024.
  • SBZ+ [24] Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, and Tri Dao. Flashattention-3: Fast and accurate attention with asynchrony and low-precision. arXiv preprint arXiv:2407.08608, 2024.
  • SCY+ [24] Hanshi Sun, Zhuoming Chen, Xinyu Yang, Yuandong Tian, and Beidi Chen. Triforce: Lossless acceleration of long sequence generation with hierarchical speculative decoding. arXiv preprint arXiv:2404.11912, 2024.
  • SD [15] Jacob Steinhardt and John Duchi. Minimax rates for memory-bounded sparse linear regression. In Conference on Learning Theory, pages 1564–1587. PMLR, 2015.
  • SGBJ [19] Sainbayar Sukhbaatar, Edouard Grave, Piotr Bojanowski, and Armand Joulin. Adaptive attention span in transformers. arXiv preprint arXiv:1905.07799, 2019.
  • SMN+ [24] Zhenmei Shi, Yifei Ming, Xuan-Phi Nguyen, Yingyu Liang, and Shafiq Joty. Discovering the gems in early layers: Accelerating long-context llms with 1000x input token reduction. arXiv preprint arXiv:2409.17422, 2024.
  • SSU [18] Mitchell Stern, Noam Shazeer, and Jakob Uszkoreit. Blockwise parallel decoding for deep autoregressive models. Advances in Neural Information Processing Systems, 31, 2018.
  • SSV [19] Vatsal Sharan, Aaron Sidford, and Gregory Valiant. Memory-sample tradeoffs for linear regression with small error. In Proceedings of the 51st Annual ACM SIGACT Symposium on Theory of Computing, pages 890–901, 2019.
  • SVV+ [23] Hamed Shirzad, Ameya Velingker, Balaji Venkatachalam, Danica J Sutherland, and Ali Kemal Sinop. Exphormer: Sparse transformers for graphs. In International Conference on Machine Learning, pages 31613–31632. PMLR, 2023.
  • SVW [16] Jacob Steinhardt, Gregory Valiant, and Stefan Wager. Memory, communication, and statistical queries. In Conference on Learning Theory, pages 1490–1516. PMLR, 2016.
  • SWXL [24] Zhenmei Shi, Junyi Wei, Zhuoyan Xu, and Yingyu Liang. Why larger language models do in-context learning differently? In Forty-first International Conference on Machine Learning, 2024.
  • SWXZ [22] Vaidehi Srinivas, David P Woodruff, Ziyu Xu, and Samson Zhou. Memory bounds for the experts problem. In Proceedings of the 54th Annual ACM SIGACT Symposium on Theory of Computing, pages 1158–1171, 2022.
  • SY [24] Barna Saha and Christopher Ye. I/o complexity of attention, or how optimal is flashattention? In Forty-first International Conference on Machine Learning, 2024.
  • SYZ [23] Zhao Song, Mingquan Ye, and Lichen Zhang. Streaming semidefinite programs: o(n)o(\sqrt{n}) passes, small space and fast runtime. arXiv preprint arXiv:2309.05135, 2023.
  • TBY+ [20] Yi Tay, Dara Bahri, Liu Yang, Donald Metzler, and Da-Cheng Juan. Sparse sinkhorn attention. In International Conference on Machine Learning, pages 9438–9447. PMLR, 2020.
  • TKRR [16] Yael Tauman Kalai, Ran Raz, and Oded Regev. On the space complexity of linear programming with preprocessing. In Proceedings of the 2016 ACM Conference on Innovations in Theoretical Computer Science, pages 293–300, 2016.
  • UAS+ [20] Mohd Usama, Belal Ahmad, Enmin Song, M Shamim Hossain, Mubarak Alrashoud, and Ghulam Muhammad. Attention-based sentiment analysis using convolutional and recurrent neural network. Future Generation Computer Systems, 113:571–578, 2020.
  • Vit [01] Jeffrey Scott Vitter. External memory algorithms and data structures: Dealing with massive data. ACM Computing Surveys (CSUR), 33(2):209–271, 2001.
  • VSP+ [17] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Łukasz Kaiser, and Illia Polosukhin. Attention is all you need. Advances in neural information processing systems, 30, 2017.
  • WHHL [24] Dennis Wu, Jerry Yao-Chieh Hu, Teng-Yun Hsiao, and Han Liu. Uniform memory retrieval with larger capacity for modern hopfield models. In Forty-first International Conference on Machine Learning (ICML), 2024.
  • WHL+ [24] Dennis Wu, Jerry Yao-Chieh Hu, Weijian Li, Bo-Yu Chen, and Han Liu. STanhop: Sparse tandem hopfield model for memory-enhanced time series prediction. In The Twelfth International Conference on Learning Representations (ICLR), 2024.
  • WMS+ [24] Jiayu Wang, Yifei Ming, Zhenmei Shi, Vibhav Vineet, Xin Wang, and Neel Joshi. Is a picture worth a thousand words? delving into spatial reasoning for vision language models. arXiv preprint arXiv:2406.14852, 2024.
  • WS [19] Blake Woodworth and Nathan Srebro. Open problem: The oracle complexity of convex optimization with limited memory. In Conference on Learning Theory, pages 3202–3210. PMLR, 2019.
  • WWS+ [22] Jason Wei, Xuezhi Wang, Dale Schuurmans, Maarten Bosma, Fei Xia, Ed Chi, Quoc V Le, and Denny Zhou. Chain-of-thought prompting elicits reasoning in large language models. Advances in neural information processing systems, 35:24824–24837, 2022.
  • XCG+ [23] Zhiheng Xi, Wenxiang Chen, Xin Guo, Wei He, Yiwen Ding, Boyang Hong, Ming Zhang, Junzhe Wang, Senjie Jin, Enyu Zhou, et al. The rise and potential of large language model based agents: A survey. arXiv preprint arXiv:2309.07864, 2023.
  • XGW+ [22] Xinchao Xu, Zhibin Gou, Wenquan Wu, Zheng-Yu Niu, Hua Wu, Haifeng Wang, and Shihang Wang. Long time no see! open-domain conversation with long-term persona memory. arXiv preprint arXiv:2203.05797, 2022.
  • XHH+ [24] Chenwei Xu, Yu-Chao Huang, Jerry Yao-Chieh Hu, Weijian Li, Ammar Gilani, Hsi-Sheng Goan, and Han Liu. Bishop: Bi-directional cellular learning for tabular data with generalized sparse modern hopfield model. In Forty-first International Conference on Machine Learning (ICML), 2024.
  • XSL [24] Zhuoyan Xu, Zhenmei Shi, and Yingyu Liang. Do large language models have compositional ability? an investigation into limitations and scalability. In First Conference on Language Modeling, 2024.
  • YCB+ [20] Chulhee Yun, Yin-Wen Chang, Srinadh Bhojanapalli, Ankit Singh Rawat, Sashank Reddi, and Sanjiv Kumar. O (n) connections are expressive enough: Universal approximability of sparse transformers. Advances in Neural Information Processing Systems, 33:13783–13794, 2020.
  • YGG+ [19] Zihao Ye, Qipeng Guo, Quan Gan, Xipeng Qiu, and Zheng Zhang. Bp-transformer: Modelling long-range context via binary partitioning. arXiv preprint arXiv:1911.04070, 2019.
  • ZBKR [24] Michael Zhang, Kush Bhatia, Hermann Kumbong, and Christopher Ré. The hedgehog & the porcupine: Expressive linear attentions with softmax mimicry. arXiv preprint arXiv:2402.04347, 2024.
  • ZCO+ [15] Hao Zhang, Gang Chen, Beng Chin Ooi, Kian-Lee Tan, and Meihui Zhang. In-memory big data management and processing: A survey. IEEE Transactions on Knowledge and Data Engineering, 27(7):1920–1948, 2015.
  • ZGD+ [20] Manzil Zaheer, Guru Guruganesh, Kumar Avinava Dubey, Joshua Ainslie, Chris Alberti, Santiago Ontanon, Philip Pham, Anirudh Ravula, Qifan Wang, Li Yang, et al. Big bird: Transformers for longer sequences. Advances in neural information processing systems, 33:17283–17297, 2020.
  • ZHDK [23] Amir Zandieh, Insu Han, Majid Daliri, and Amin Karbasi. Kdeformer: Accelerating transformers via kernel density estimation. In International Conference on Machine Learning (ICML). arXiv preprint arXiv:2302.02451, 2023.
  • ZHJL [24] Jingyi Zhang, Jiaxing Huang, Sheng Jin, and Shijian Lu. Vision-language models for vision tasks: A survey. IEEE Transactions on Pattern Analysis and Machine Intelligence, 2024.
  • ZJV+ [24] Chenyang Zhao, Xueying Jia, Vijay Viswanathan, Tongshuang Wu, and Graham Neubig. Self-guide: Better task-specific instruction following via self-synthetic finetuning. arXiv preprint arXiv:2407.12874, 2024.
  • ZL [24] Yuchen Zeng and Kangwook Lee. The expressive power of low-rank adaptation. In The Twelfth International Conference on Learning Representations, 2024.

Appendix

Roadmap.

In Section A, we present a more comprehensive overview of related work pertinent to our study. In Section B, we introduce additional preliminaries, including notations and definitions of intermediate variables. Section C provides algorithms and establishes an upper bound theorem for the attention backward pass in the small cache case M=o(d2)M=o(d^{2}). In Section D, we offer algorithms and an upper bound theorem for the attention backward pass in the large cache case M=Ω(d2)M=\Omega(d^{2}). In Section E, we provide proofs for our attention backward I/O complexity lower bound results. In Section F, we prove the I/O complexity lower bounds for sparse attention.

Appendix A More Related Work

Large Language Models.

The exceptional success of generative large language models (LLMs), such as GPT-4 [2], Claude 3 [7], Gemini 1.5 [89], Llama 3.1 [73], Mistral Nemo [57], and Phi 3.5 [4], all of which support input lengths of at least 128k tokens, is fundamentally attributed to the transformer architecture introduced by VSP+ [108]. The transformer architecture and its self-attention mechanism have become indispensable in leading natural language processing (NLP) models [22], demonstrating remarkable capabilities across a diverse array of applications, including language translation [54], sentiment analysis [106], language modeling [76], the integration of differential privacy [90, 70], and multi-modal tasks [124, 71, 111]. Transformers’ emergent compositional abilities [30, 117] and proficiency in in-context learning [82, 74, 100] have led some to consider them as early indicators of Artificial General Intelligence (AGI) [14]. As such, the transformer architecture continues to play a pivotal role in advancing the field of AI.

More about Attention Computation Acceleration.

The quadratic time complexity of attention computation with respect to the length of the input sequence [108] poses significant computational challenges, especially for long sequences. Consequently, accelerating attention computation has become a crucial research area, with approaches broadly divided into two categories: (1) theoretical optimization of computational complexity [8, 9], and (2) experimental improvements to model performance [26, 25, 91, 45, 38].

From a theoretical standpoint, numerous works focus on approximating the attention matrix to accelerate computation. For example, AS [8], AS24a [9] utilize polynomial kernel approximation techniques [1] to speed up both training and inference of a single attention layer, achieving almost linear time complexity, and extend this approach to multi-layer transformers [68] and tensor attention [10, 71]. Other theoretical contributions include the conv-basis method introduced by LLS+24a [64] and a near-linear time algorithm proposed by HJK+ [48] under the assumptions of uniform softmax column norms and sparsity.

Experimental approaches involve modifying model architectures and optimizing implementations to accelerate inference. Methods such as Mamba [39, 27], Linearizing Transformers [120, 80], PolySketchFormer [123, 61], and various implementations of the Hopfield Model [47, 46, 109, 116, 51, 110, 55] aim to improve model performance and inference speed. Additionally, specific techniques like weight pruning [65, 66] have been developed to accelerate LLM generation. Some other techniques are introduced for efficient adaptation, such as LoRA [53, 126, 52] and prefix tuning [63, 69]. System-level optimizations, such as FlashAttention [26, 25, 91] and block-wise parallel decoding [96], address bottlenecks in attention mechanisms and enhance inference speed through efficient implementation strategies. Collectively, these advancements contribute to making attention mechanisms more scalable and efficient, facilitating the deployment of large-scale language models.

More about Learning with Bounded Memory and I/O Complexity.

Learning with bounded memory has been studied in various fields of machine learning, such as online learning [77, 101, 84, 86], parity learning [99, 87, 88, 42], convex optimization [112, 78, 20], active learning [50], learning linear classifiers [12], attention computation [5], linear regression [93, 97, 12], linear programming [105, 72], semi-definite programming [103], principal component analysis [33], continual learning [21, 36], entropy estimation [3, 6] and others [79, 41].

A common memory model in computational systems is the two-level memory hierarchy. In this model, there are two layers of memory: a small but fast layer called the cache, and a large but slower layer called the memory. The I/O (input/output) complexity of an algorithm measures its efficiency based on the number of data transfer operations it performs between the cache and the memory. In domains such as big data analytics and database management, these data transfers can become significant performance bottlenecks because massive datasets cannot be entirely accommodated in the cache, and thus optimizing I/O is essential for fast data retrieval and storage, directly impacting query performance and system scalability [40, 121]. The early work of HK [49] formulated the I/O complexity mathematically using the language of graph theory. Vit [107] provides a comprehensive survey of the I/O complexity of various batched and online problems. There exists a substantial body of work on the I/O complexity of numerous problems, including sorting [11], graph algorithms [23, 58, 56, 35], fine-grained I/O complexity [29], computational trade-off in data transfers [28], computing prime tables [13], attention computation [102], integer multiplication [15, 32], and matrix multiplication [31, 81].

Appendix B Preliminary

In Section B.1, we define some basic notation we will use. In Section B.2, we introduce the memory hierarchy we consider. In Section B.3, we state important facts related to fast matrix multiplication. In Section B.4, we define several intermediate functions which will arise in our algorithms.

B.1 Notations

For any positive integer nn, we define [n]:={1,2,,n}[n]:=\{1,2,\dots,n\}. For two vectors xx and yy of the same length, we use x,y\langle x,y\rangle to denote their inner product, i.e., x,y=i=1nxiyi\langle x,y\rangle=\sum_{i=1}^{n}x_{i}y_{i}. We use \circ to denote the Hadamard product, i.e., the (i,j)(i,j)-th entry of ABA\circ B is Ai,jBi,jA_{i,j}B_{i,j}. We use xyx\circ y to denote the vector whose ii-th entry is xiyix_{i}y_{i}. Let 𝟏n{\bf 1}_{n} denote the length-nn all-ones vector. It is not hard to see that xy,𝟏n=x,y\langle x\circ y,{\bf 1}_{n}\rangle=\langle x,y\rangle. For a vector xx, we use xx^{\top} to denote its transpose, and for a matrix AA, we use AA^{\top} to denote the transpose of AA. For a matrix AA, we use exp(A)\exp(A) to denote the matrix whose (i,j)(i,j)-th entry is exp(Ai,j)\exp(A_{i,j}).

Given a matrix An×mA\in\mathbb{R}^{n\times m}, we index an individual entry as A[i,j]A[i,j]. The ii-th row is denoted A[i]A[i] while the jj-th column is denoted A[,j]A[*,j]. A[i1:i2,j1:j2]A[i_{1}:i_{2},j_{1}:j_{2}] denotes a block of AA consisting of entries (i,j)(i,j) where i[i1,i2]i\in[i_{1},i_{2}] and j[j1,j2]j\in[j_{1},j_{2}]. Given a block size BB, the block A[(i1)B+1:iB,(j1)B+1:jB]A[(i-1)\cdot B+1:i\cdot B,(j-1)\cdot B+1:j\cdot B] is denoted A(B)[i,j]A^{(B)}[i,j].

For a vector vnv\in\mathbb{R}^{n}, we similarly denote entries v[i]v[i], a contiguous block of entries as v[i1:i2]v[i_{1}:i_{2}], and the ii-th block of size BB as v(B)[i]v^{(B)}[i]. Let diag(v)\operatorname{diag}(v) denote the matrix Dn×nD\in\mathbb{R}^{n\times n} with D[i,i]=v[i]D[i,i]=v[i].
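
To make the block notation concrete, the following hypothetical NumPy helper (the function name block is ours, not the paper's) extracts the (i, j)-th B×B block A^{(B)}[i, j] under the 1-indexed convention above.

```python
import numpy as np

# Hypothetical helper mirroring the block notation A^{(B)}[i, j] from Section B.1:
# the (i, j)-th B x B block of A, with 1-indexed block coordinates as in the text.
def block(A, B, i, j):
    return A[(i - 1) * B : i * B, (j - 1) * B : j * B]

A = np.arange(36).reshape(6, 6)
print(block(A, 3, 2, 1))  # rows 4-6 and columns 1-3 of A, i.e., A[3:6, 0:3]
```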

B.2 Memory Hierarchy

In this study, we consider a two-level memory hierarchy composed of a small but fast layer called the cache and a large, slower layer referred to as the memory. We assume that the memory has unlimited capacity, while the cache is constrained by a finite size MM. Moreover, all computations are performed exclusively within the cache.

B.3 Matrix Multiplication

We define matrix multiplication notation and state some well-known facts here.

Definition B.1.

Let n1,n2,n3n_{1},n_{2},n_{3} denote any three positive integers. We use 𝒯mat(n1,n2,n3){\cal T}_{\mathrm{mat}}(n_{1},n_{2},n_{3}) to denote the time of multiplying an n1×n2n_{1}\times n_{2} matrix with an n2×n3n_{2}\times n_{3} matrix.

Then, we introduce a well-known fact.

Fact B.2.

Let n1,n2,n3n_{1},n_{2},n_{3} denote any three positive integers. 𝒯mat(n1,n2,n3)=O(𝒯mat(n1,n3,n2))=O(𝒯mat(n2,n1,n3))=O(𝒯mat(n2,n3,n1))=O(𝒯mat(n3,n1,n2))=O(𝒯mat(n3,n2,n1)){\cal T}_{\mathrm{mat}}(n_{1},n_{2},n_{3})=O({\cal T}_{\mathrm{mat}}(n_{1},n_{3},n_{2}))=O({\cal T}_{\mathrm{mat}}(n_{2},n_{1},n_{3}))=O({\cal T}_{\mathrm{mat}}(n_{2},n_{3},n_{1}))=O({\cal T}_{\mathrm{mat}}(n_{3},n_{1},n_{2}))=O({\cal T}_{\mathrm{mat}}(n_{3},n_{2},n_{1})).

B.4 Definitions of Intermediate Variables

We start with some definitions involving Xd×dX\in\mathbb{R}^{d\times d}.

Definition B.3 (Definition 3.4 in AS24a [9]).

Let A1,A2n×dA_{1},A_{2}\in\mathbb{R}^{n\times d} be two matrices. Let Xd×dX\in\mathbb{R}^{d\times d}.

Let us define function A(X)A(X) to be:

A(X):=exp(A1XA2)n×n.\displaystyle A(X):=\underbrace{\exp(A_{1}XA_{2}^{\top})}_{n\times n}.
Definition B.4 (Definition 3.5 in AS24a [9]).

For A(X)n×nA(X)\in\mathbb{R}^{n\times n} defined in Definition B.3, we define the softmax normalizing vector l(X)nl(X)\in\mathbb{R}^{n} to be

l(X):=A(X)n×n𝟏nn×1.\displaystyle l(X):=\underbrace{A(X)}_{n\times n}\cdot\underbrace{{\bf 1}_{n}}_{n\times 1}.
Definition B.5 (Definition 3.6 in AS24a [9]).

Suppose that l(X)nl(X)\in\mathbb{R}^{n} is defined as in Definition B.4. Let A(X)n×nA(X)\in\mathbb{R}^{n\times n} be defined as in Definition B.3. For a fixed j0[n]j_{0}\in[n], let us consider f(X)j0f(X)_{j_{0}}

f(X)j0:=l(X)j01scalarA(X)j0n×1.\displaystyle f(X)_{j_{0}}:=\underbrace{l(X)_{j_{0}}^{-1}}_{\mathrm{scalar}}\underbrace{A(X)_{j_{0}}}_{n\times 1}.

Let f(X)n×nf(X)\in\mathbb{R}^{n\times n} denote the matrix whose j0j_{0}-th row is (f(X)j0)(f(X)_{j_{0}})^{\top}.

Furthermore, the matrix form of f(X)f(X) is

f(X)=diag(l(X))1A(X)\displaystyle f(X)=\operatorname{diag}(l(X))^{-1}A(X)

We then define h(Y)h(Y) related to Yd×dY\in\mathbb{R}^{d\times d}.

Definition B.6 (Definition 3.7 in AS24a [9]).

For A3n×dA_{3}\in\mathbb{R}^{n\times d} and Yd×dY\in\mathbb{R}^{d\times d}, we define h(Y)n×dh(Y)\in\mathbb{R}^{n\times d} as:

h(Y):=A3n×dYd×d.\displaystyle h(Y):=\underbrace{A_{3}}_{n\times d}\underbrace{Y}_{d\times d}.

Let us define the forward output matrix OO.

Definition B.7.

Let f(X),h(Y)f(X),h(Y) be defined in Definition B.5 and B.6. We define the output of attention as:

O:=f(X)n×nh(Y)n×d\displaystyle O:=\underbrace{f(X)}_{n\times n}\underbrace{h(Y)}_{n\times d}

where On×dO\in\mathbb{R}^{n\times d} is the output matrix of attention forward computation.

Now, we define qq, which incorporates the information from the upstream gradient.

Definition B.8 (Definition C.10 in LSS+ [68]).

Let dOn×d\mathrm{d}O\in\mathbb{R}^{n\times d} be the upstream gradient, the matrix resulting from the application of the chain rule. Define h(Y)n×dh(Y)\in\mathbb{R}^{n\times d} as in Definition B.6.

We define q(Y)n×nq(Y)\in\mathbb{R}^{n\times n} as

q(Y):=dOn×dh(Y)d×n\displaystyle q(Y):=\underbrace{\mathrm{d}O}_{n\times d}\underbrace{h(Y)^{\top}}_{d\times n}

Then we use q(Y)j0q(Y)_{j_{0}}^{\top} to denote the j0j_{0}-th row of q(Y)n×nq(Y)\in\mathbb{R}^{n\times n}.

Finally, we define the gradient component matrix pp.

Definition B.9 (Definition C.5 in AS24a [9]).

For every index j0[n]j_{0}\in[n], we define p(X)j0np(X)_{j_{0}}\in\mathbb{R}^{n} as

p(X)j0:=(diag(f(X)j0)f(X)j0f(X)j0)q(Y)j0.\displaystyle p(X)_{j_{0}}:=(\operatorname{diag}(f(X)_{j_{0}})-f(X)_{j_{0}}f(X)_{j_{0}}^{\top})q(Y)_{j_{0}}.

We define p(X)n×np(X)\in\mathbb{R}^{n\times n} in the sense that p(X)j0p(X)_{j_{0}}^{\top} is the j0j_{0}-th row of p(X)p(X). Additionally, p(X)p(X) has matrix form as

p(X)=\displaystyle p(X)= f(X)q(Y)diag((f(X)q(Y))𝟏n)f(X)\displaystyle~{}f(X)\circ q(Y)-\operatorname{diag}((f(X)\circ q(Y))\cdot{\bf 1}_{n})f(X)
=\displaystyle= f(X)q(Y)diag((OdO)𝟏n)f(X)\displaystyle~{}f(X)\circ q(Y)-\operatorname{diag}((O\circ\mathrm{d}O)\cdot{\bf 1}_{n})f(X)

where f(X),Of(X),O are defined in Definition B.5 and B.7, and q(Y),dOq(Y),\mathrm{d}O are defined in Definition B.8.
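
The definitions above can be instantiated directly in a few lines of dense linear algebra. The following NumPy sketch (ours, for illustration only, with no cache modeling) computes A(X), l(X), f(X), h(Y), O, q(Y), and p(X), together with the gradient g = A1^T p A2 that Algorithm 1 in Section C.1 assembles from them.

```python
import numpy as np

def attention_backward_reference(A1, A2, A3, dO, X, Y):
    """Dense reference for Definitions B.3-B.9; A1, A2, A3, dO are n x d, X and Y are d x d."""
    n = A1.shape[0]
    A = np.exp(A1 @ X @ A2.T)              # A(X) = exp(A1 X A2^T), n x n
    l = A @ np.ones(n)                     # l(X) = A(X) 1_n
    f = A / l[:, None]                     # f(X) = diag(l(X))^{-1} A(X)
    h = A3 @ Y                             # h(Y) = A3 Y
    O = f @ h                              # forward output, Definition B.7
    q = dO @ h.T                           # q(Y) = dO h(Y)^T
    v = (O * dO) @ np.ones(O.shape[1])     # (O o dO) 1_n, which equals (f o q) 1_n
    p = f * q - v[:, None] * f             # p(X), Definition B.9
    g = A1.T @ p @ A2                      # gradient dL(X)/dX, as assembled in Algorithm 1
    return O, g

# Example on random inputs (shapes only; the values are arbitrary).
rng = np.random.default_rng(0)
n, d = 8, 4
A1, A2, A3, dO = (rng.standard_normal((n, d)) for _ in range(4))
X, Y = rng.standard_normal((d, d)), rng.standard_normal((d, d))
O, g = attention_backward_reference(A1, A2, A3, dO, X, Y)
assert O.shape == (n, d) and g.shape == (d, d)
```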

Appendix C I/O Complexity Upper Bound for Small Cache

In this section, we prove the I/O complexity upper bound (Theorem C.12) for the small cache case M=o(d2)M=o(d^{2}). Specifically, in Section C.1, we introduce an algorithm for attention gradient computation without cache to guide our algorithm design. Section C.2 presents algorithms and analyses for attention gradient computation in the small cache setting. Finally, Section C.3 provides the upper bound theorem for the small cache case.

C.1 Algorithm for Attention Backward Without Cache

Using results from AS24a [9], we can compute the gradient in 𝒯mat(n,d,n)+𝒯mat(n,d,d){\cal T}_{\mathrm{mat}}(n,d,n)+{\cal T}_{\mathrm{mat}}(n,d,d) time.

Lemma C.1 (Attention gradient computation, Lemma C.8 in AS24a [9]).

If it holds that

  • Let A1,A2,A3,dOn×dA_{1},A_{2},A_{3},\mathrm{d}O\in\mathbb{R}^{n\times d} be several fixed input matrices.

  • Let X,Yd×dX,Y\in\mathbb{R}^{d\times d} denote matrix variables (we will compute gradient with respect to XX).

  • Let g=dL(X)dXd×dg=\frac{\mathrm{d}L(X)}{\mathrm{d}X}\in\mathbb{R}^{d\times d} (Definition 3.2).

Then, gradient gd×dg\in\mathbb{R}^{d\times d} can be computed in 𝒯mat(n,d,n)+𝒯mat(n,d,d){\cal T}_{\mathrm{mat}}(n,d,n)+{\cal T}_{\mathrm{mat}}(n,d,d) time.

We first give a naive algorithm that does not utilize the cache to compute the gradient (Algorithm 1).

Algorithm 1 Attention gradient computation without cache. See more details in Section B and C of AS24a [9] and Section F of LSS+ [68].
1:procedure AttentionGradientNoCache(A1,A2,A3,dOn×dA_{1},A_{2},A_{3},\mathrm{d}O\in\mathbb{R}^{n\times d}, X,Yd×dX,Y\in\mathbb{R}^{d\times d}) \triangleright Lemma C.2, Lemma C.3
2:     Read A1,A2,XA_{1},A_{2},X, initialize A0n×nA\leftarrow 0^{n\times n}, compute AA+A1XA2A\leftarrow A+A_{1}XA_{2}^{\top}, and delete XX
3:     Compute Aexp(A)A\leftarrow\exp(A), initialize l0nl\leftarrow 0^{n}, and compute ll+A𝟏l\leftarrow l+A\cdot{\bf 1}
4:     Initialize f0n×nf\leftarrow 0^{n\times n}, compute ff+diag(l)1Af\leftarrow f+\operatorname{diag}(l)^{-1}A, and delete A,lA,l
5:     Read A3,YA_{3},Y, initialize h0n×dh\leftarrow 0^{n\times d}, compute hh+A3Yh\leftarrow h+A_{3}Y, and delete A3,YA_{3},Y
6:     Read dO\mathrm{d}O, initialize q0n×nq\leftarrow 0^{n\times n}, compute qq+dOhq\leftarrow q+\mathrm{d}Oh^{\top}, and delete dO,h\mathrm{d}O,h
7:     Initialize p0n×np\leftarrow 0^{n\times n}, compute pp+fqdiag((fq)𝟏)fp\leftarrow p+f\circ q-\operatorname{diag}((f\circ q)\cdot{\bf 1})f, and delete f,qf,q
8:     Initialize g0d×dg\leftarrow 0^{d\times d}, compute gg+A1pA2g\leftarrow g+A_{1}^{\top}pA_{2}, and delete A1,A2,pA_{1},A_{2},p
9:     return gg\triangleright g=dL(X)dXd×dg=\frac{\mathrm{d}L(X)}{\mathrm{d}X}\in\mathbb{R}^{d\times d}, see Definition 3.2
10:end procedure
Lemma C.2 (Correctness).

The AttentionGradientNoCache (Algorithm 1) outputs a d×dd\times d matrix dL(X)dX\frac{\mathrm{d}L(X)}{\mathrm{d}X} defined in Definition 3.2.

Proof.

From Lemma C.1, we know this holds. ∎

Lemma C.3 (Time/space complexity).

There exists an algorithm (see Algorithm 1) that can compute the exact gradient in Definition 3.2 in 𝒯mat(n,d,n)+𝒯mat(n,d,d){\cal T}_{\mathrm{mat}}(n,d,n)+{\cal T}_{\mathrm{mat}}(n,d,d) time and O(n2+d2)O(n^{2}+d^{2}) space.

Proof.

From Lemma C.1, we can prove the time complexity. Since the stored matrices have three sizes, namely n×dn\times d, n×nn\times n, d×dd\times d, the space complexity is O(n2+nd+d2)=O(n2+d2)O(n^{2}+nd+d^{2})=O(n^{2}+d^{2}). ∎

C.2 Algorithms for Attention Backward in Small Cache

We now give algorithms that achieve the upper bound for the small cache case M=o(d2)M=o(d^{2}) in attention backward computation.

First, we give the algorithm and analysis for Phase 1 (see Algorithm 2) to compute ff defined in Definition B.5.

Lemma C.4 (Correctness of Phase 1).

The AttentionGradientCachePhase1 (Algorithm 2) outputs an n×nn\times n matrix ff defined in Definition B.5.

Proof.

The algorithm first computes S=A1XS=A_{1}X. Then it computes A=SA2A=SA_{2}^{\top}, A=exp(A)A=\exp(A), and l=A𝟏l=A\cdot{\bf 1}. Finally, it outputs f=diag(l)1Af=\operatorname{diag}(l)^{-1}A which is ff defined in Definition B.5. ∎

Lemma C.5 (I/O complexity of Phase 1).

The I/O complexity of AttentionGradientCachePhase1 (Algorithm 2) is O(n2d+nd2M)O(\frac{n^{2}d+nd^{2}}{\sqrt{M}}).

Proof.

In Phase 1 (Algorithm 2) the number of items in cache is at most 3B2+B4B2M3B^{2}+B\leq 4B^{2}\leq M. For each iteration in computing S=A1XS=A_{1}X and A=SA2A=SA_{2}^{\top}, the algorithm reads O(B2)O(B^{2}) from memory into cache. This is the dominating factor of the I/O complexity of the algorithm. Thus, the I/O complexity of Phase 1 is O(n2dB3B2)+O(nd2B3B2)=O(n2d+nd2B)=O(n2d+nd2M)O(\frac{n^{2}d}{B^{3}}B^{2})+O(\frac{nd^{2}}{B^{3}}B^{2})=O(\frac{n^{2}d+nd^{2}}{B})=O(\frac{n^{2}d+nd^{2}}{\sqrt{M}}). ∎

Algorithm 2 Attention gradient computation with cache phase 1. Compute ff.
1:procedure AttentionGradientCachePhase1(A1,A2n×dA_{1},A_{2}\in\mathbb{R}^{n\times d}, Xd×dX\in\mathbb{R}^{d\times d}, M+M\in\mathbb{N}_{+}) \triangleright Lemma C.4, Lemma C.5
2:     BM/4B\leftarrow\lfloor\sqrt{M/4}\rfloor
3:     /*Phase 1: Compute ff*/
4:     for 1in/B1\leq i\leq\lceil n/B\rceil do
5:         for 1jd/B1\leq j\leq\lceil d/B\rceil do
6:              Initialize S(B)[i,j]0B×BS^{(B)}[i,j]\leftarrow 0^{B\times B} in cache
7:              for 1kd/B1\leq k\leq\lceil d/B\rceil do
8:                  Read A1(B)[i,k]A_{1}^{(B)}[i,k] and X(B)[k,j]X^{(B)}[k,j] into cache
9:                  Compute S(B)[i,j]S(B)[i,j]+A1(B)[i,k]X(B)[k,j]S^{(B)}[i,j]\leftarrow S^{(B)}[i,j]+A_{1}^{(B)}[i,k]X^{(B)}[k,j] in cache \triangleright S=A1XS=A_{1}X
10:                  Delete A1(B)[i,k]A_{1}^{(B)}[i,k] and X(B)[k,j]X^{(B)}[k,j] from cache
11:              end for
12:              Write S(B)[i,j]S^{(B)}[i,j] into memory, and delete S(B)[i,j]S^{(B)}[i,j] from cache
13:         end for
14:     end for
15:     for 1in/B1\leq i\leq\lceil n/B\rceil do
16:         Initialize l(B)[i]0Bl^{(B)}[i]\leftarrow 0^{B} in cache
17:         for 1jn/B1\leq j\leq\lceil n/B\rceil do
18:              Initialize A(B)[i,j]0B×BA^{(B)}[i,j]\leftarrow 0^{B\times B} in cache
19:              for 1kd/B1\leq k\leq\lceil d/B\rceil do
20:                  Read S(B)[i,k]S^{(B)}[i,k] and (A2)(B)[k,j](A_{2}^{\top})^{(B)}[k,j] into cache
21:                  Compute A(B)[i,j]A(B)[i,j]+S(B)[i,k](A2)(B)[k,j]A^{(B)}[i,j]\leftarrow A^{(B)}[i,j]+S^{(B)}[i,k](A_{2}^{\top})^{(B)}[k,j] in cache \triangleright A=SA2A=SA_{2}^{\top}
22:                  Delete S(B)[i,k]S^{(B)}[i,k] and (A2)(B)[k,j](A_{2}^{\top})^{(B)}[k,j] from cache
23:              end for
24:              Compute A(B)[i,j]exp(A(B)[i,j])A^{(B)}[i,j]\leftarrow\exp(A^{(B)}[i,j]) in cache, and write A(B)[i,j]A^{(B)}[i,j] into memory
25:              Compute l(B)[i]l(B)[i]+A(B)[i,j]𝟏l^{(B)}[i]\leftarrow l^{(B)}[i]+A^{(B)}[i,j]\cdot\mathbf{1} in cache \triangleright l=A𝟏l=A\cdot{\bf 1}
26:              Delete A(B)[i,j]A^{(B)}[i,j] from cache
27:         end for
28:         for 1jn/B1\leq j\leq\lceil n/B\rceil do
29:              Initialize f(B)[i,j]0B×Bf^{(B)}[i,j]\leftarrow 0^{B\times B} in cache
30:              Read A(B)[i,j]A^{(B)}[i,j] into cache
31:              Compute f(B)[i,j]f(B)[i,j]+diag(l(B)[i])1A(B)[i,j]f^{(B)}[i,j]\leftarrow f^{(B)}[i,j]+\operatorname{diag}(l^{(B)}[i])^{-1}A^{(B)}[i,j]
32:              Write f(B)[i,j]f^{(B)}[i,j] into memory, and delete A(B)[i,j]A^{(B)}[i,j] and f(B)[i,j]f^{(B)}[i,j] from cache
33:         end for
34:         Delete l(B)[i]l^{(B)}[i] from cache
35:     end for
36:     return ff \triangleright fn×nf\in\mathbb{R}^{n\times n}, where ff is defined in Definition B.5
37:end procedure
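
As a concrete model of the tiling in Algorithm 2, the following NumPy sketch (ours; it assumes for readability that B divides n and d exactly, and the word counter is an illustrative bookkeeping device rather than the paper's cost model) reproduces the two blocked matrix products and the per-row-block normalization; its counter totals 2(n^2 d + n d^2)/B words, matching the dominant term in Lemma C.5. The same tiling pattern is reused essentially verbatim in Phases 2-4 (Algorithms 3-5).

```python
import math
import numpy as np

def phase1_small_cache(A1, A2, X, M):
    """Tiled computation of f = diag(l)^{-1} exp(A1 X A2^T) with block size B = floor(sqrt(M/4))."""
    n, d = A1.shape
    B = max(1, math.floor(math.sqrt(M / 4)))
    words_read = 0

    S = np.zeros((n, d))                     # S = A1 X, written back tile by tile
    for i in range(0, n, B):
        for j in range(0, d, B):
            for k in range(0, d, B):
                words_read += 2 * B * B      # one tile of A1 and one tile of X
                S[i:i+B, j:j+B] += A1[i:i+B, k:k+B] @ X[k:k+B, j:j+B]

    A = np.zeros((n, n))                     # A = exp(S A2^T), materialized in memory
    l = np.zeros(n)
    for i in range(0, n, B):
        for j in range(0, n, B):
            for k in range(0, d, B):
                words_read += 2 * B * B      # one tile of S and one tile of A2^T
                A[i:i+B, j:j+B] += S[i:i+B, k:k+B] @ A2.T[k:k+B, j:j+B]
            A[i:i+B, j:j+B] = np.exp(A[i:i+B, j:j+B])
            l[i:i+B] += A[i:i+B, j:j+B].sum(axis=1)

    f = A / l[:, None]                       # f = diag(l)^{-1} A; Algorithm 2 makes one more pass over A for this
    return f, words_read
```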

Second, we give the algorithm and analysis for Phase 2 (see Algorithm 3) to compute qq defined in Definition B.8.

Lemma C.6 (Correctness of Phase 2).

The AttentionGradientCachePhase2 (Algorithm 3) outputs an n×nn\times n matrix qq defined in Definition B.8.

Proof.

The algorithm first computes h=A3Yh=A_{3}Y. Then, it outputs q=dOhq=\mathrm{d}Oh^{\top} which is exactly the same as qq defined in Definition B.8. ∎

Lemma C.7 (I/O complexity of Phase 2).

The I/O complexity of AttentionGradientCachePhase2 (Algorithm 3) is O(n2d+nd2M)O(\frac{n^{2}d+nd^{2}}{\sqrt{M}}).

Proof.

In Phase 2 (Algorithm 3) the number of items in cache is at most 3B24B2M3B^{2}\leq 4B^{2}\leq M. For each iteration in computing h=A3Yh=A_{3}Y and q=dOhq=\mathrm{d}Oh^{\top}, the algorithm reads O(B2)O(B^{2}) from memory into cache. This is the dominating factor of the I/O complexity of the algorithm. Thus, the I/O complexity of Phase 2 is O(n2dB3B2)+O(nd2B3B2)=O(n2d+nd2B)=O(n2d+nd2M)O(\frac{n^{2}d}{B^{3}}B^{2})+O(\frac{nd^{2}}{B^{3}}B^{2})=O(\frac{n^{2}d+nd^{2}}{B})=O(\frac{n^{2}d+nd^{2}}{\sqrt{M}}). ∎

Algorithm 3 Attention gradient computation with cache phase 2. Compute qq.
1:procedure AttentionGradientCachePhase2(A3,dOn×dA_{3},\mathrm{d}O\in\mathbb{R}^{n\times d}, fn×nf\in\mathbb{R}^{n\times n}, Yd×dY\in\mathbb{R}^{d\times d}, M+M\in\mathbb{N}_{+}) \triangleright Lemma C.6, Lemma C.7
2:     BM/4B\leftarrow\lfloor\sqrt{M/4}\rfloor
3:     /* Phase 2: Compute qq */
4:     for 1in/B1\leq i\leq\lceil n/B\rceil do
5:         for 1jd/B1\leq j\leq\lceil d/B\rceil do
6:              Initialize h(B)[i,j]0B×Bh^{(B)}[i,j]\leftarrow 0^{B\times B} in cache
7:              for 1kd/B1\leq k\leq\lceil d/B\rceil do
8:                  Read A3(B)[i,k]A_{3}^{(B)}[i,k] and Y(B)[k,j]Y^{(B)}[k,j] into cache
9:                  Compute h(B)[i,j]h(B)[i,j]+A3(B)[i,k]Y(B)[k,j]h^{(B)}[i,j]\leftarrow h^{(B)}[i,j]+A_{3}^{(B)}[i,k]Y^{(B)}[k,j] in cache
10:                  Delete A3(B)[i,k]A_{3}^{(B)}[i,k] and Y(B)[k,j]Y^{(B)}[k,j] from cache
11:              end for
12:              Write h(B)[i,j]h^{(B)}[i,j] into memory, and delete h(B)[i,j]h^{(B)}[i,j] from cache
13:         end for
14:     end for
15:     for 1in/B1\leq i\leq\lceil n/B\rceil do
16:         for 1jn/B1\leq j\leq\lceil n/B\rceil do
17:              Initialize q(B)[i,j]0B×Bq^{(B)}[i,j]\leftarrow 0^{B\times B} in cache
18:              for 1kd/B1\leq k\leq\lceil d/B\rceil do
19:                  Read dO(B)[i,k]\mathrm{d}O^{(B)}[i,k] and (h)(B)[k,j](h^{\top})^{(B)}[k,j] into cache
20:                  Compute q(B)[i,j]q(B)[i,j]+dO(B)[i,k](h)(B)[k,j]q^{(B)}[i,j]\leftarrow q^{(B)}[i,j]+\mathrm{d}O^{(B)}[i,k](h^{\top})^{(B)}[k,j] in cache
21:                  Delete dO(B)[i,k]\mathrm{d}O^{(B)}[i,k] and (h)(B)[k,j](h^{\top})^{(B)}[k,j] from cache
22:              end for
23:              Write q(B)[i,j]q^{(B)}[i,j] into memory, and delete q(B)[i,j]q^{(B)}[i,j] from cache
24:         end for
25:     end for
26:     return qq \triangleright qn×nq\in\mathbb{R}^{n\times n}, where qq is defined in Definition B.8
27:end procedure

Then, we give the algorithm and analysis for Phase 3 (see Algorithm 4) to compute pp defined in Definition B.9.

Lemma C.8 (Correctness of Phase 3).

The AttentionGradientCachePhase3 (Algorithm 4) outputs an n×nn\times n matrix pp defined in Definition B.9.

Proof.

The algorithm first computes v=(fq)𝟏v=(f\circ q)\cdot{\bf 1}. Then it outputs p=fqdiag(v)fp=f\circ q-\operatorname{diag}(v)f. ∎

Lemma C.9 (I/O complexity of Phase 3).

The I/O complexity of AttentionGradientCachePhase3 (Algorithm 4) is O(n2)O(n^{2}).

Proof.

In Phase 3 (Algorithm 4) the number of items in cache is at most 3B2+B4B2M3B^{2}+B\leq 4B^{2}\leq M. There are O(n2B2)O(\frac{n^{2}}{B^{2}}) iterations in computing v=(fq)𝟏v=(f\circ q)\cdot{\bf 1} and p=fqdiag(v)fp=f\circ q-\operatorname{diag}(v)f, and in each iteration the algorithm reads O(B2)O(B^{2}) entries from memory into cache. Thus, the I/O complexity of Phase 3 is O(n2B2B2)=O(n2)O(\frac{n^{2}}{B^{2}}B^{2})=O(n^{2}). ∎

Algorithm 4 Attention gradient computation with cache phase 3. Compute pp.
1:procedure AttentionGradientCachePhase3(qn×nq\in\mathbb{R}^{n\times n}, fn×nf\in\mathbb{R}^{n\times n}, M+M\in\mathbb{N}_{+}) \triangleright Lemma C.8, Lemma C.9
2:     BM/4B\leftarrow\lfloor\sqrt{M/4}\rfloor
3:     /* Phase 3: Compute pp */
4:     for 1in/B1\leq i\leq\lceil n/B\rceil do
5:         Initialize v(B)[i]0Bv^{(B)}[i]\leftarrow 0^{B} in cache
6:         for 1jn/B1\leq j\leq\lceil n/B\rceil do
7:              Read f(B)[i,j]f^{(B)}[i,j] and q(B)[i,j]q^{(B)}[i,j] into cache
8:              Compute v(B)[i]v(B)[i]+(f(B)[i,j]q(B)[i,j])𝟏v^{(B)}[i]\leftarrow v^{(B)}[i]+(f^{(B)}[i,j]\circ q^{(B)}[i,j])\cdot{\bf 1} \triangleright v=(fq)𝟏v=(f\circ q)\cdot{\bf 1}
9:              Delete f(B)[i,j]f^{(B)}[i,j] and q(B)[i,j]q^{(B)}[i,j] from cache
10:         end for
11:         for 1jn/B1\leq j\leq\lceil n/B\rceil do
12:              Initialize p(B)[i,j]0B×Bp^{(B)}[i,j]\leftarrow 0^{B\times B} in cache
13:              Read f(B)[i,j]f^{(B)}[i,j] and q(B)[i,j]q^{(B)}[i,j] into cache
14:              Compute p(B)[i,j]p(B)[i,j]+f(B)[i,j]q(B)[i,j]diag(v(B)[i])f(B)[i,j]p^{(B)}[i,j]\leftarrow p^{(B)}[i,j]+f^{(B)}[i,j]\circ q^{(B)}[i,j]-\operatorname{diag}(v^{(B)}[i])f^{(B)}[i,j]
15:              Delete f(B)[i,j]f^{(B)}[i,j] and q(B)[i,j]q^{(B)}[i,j] from cache
16:              Write p(B)[i,j]p^{(B)}[i,j] into memory, and delete p(B)[i,j]p^{(B)}[i,j] from cache
17:         end for
18:         Delete v(B)[i]v^{(B)}[i] from cache
19:     end for
20:     return pp \triangleright pn×np\in\mathbb{R}^{n\times n}, where pp is defined in Definition B.9
21:end procedure

Lastly, we give the algorithm and analysis for Phase 4 (see Algorithm 5) to compute dL(X)dX\frac{\mathrm{d}L(X)}{\mathrm{d}X}.

Lemma C.10 (Correctness of Phase 4).

The AttentionGradientCachePhase4 (Algorithm 5) outputs a d×dd\times d matrix g=dL(X)dXg=\frac{\mathrm{d}L(X)}{\mathrm{d}X} (Definition 3.2).

Proof.

The algorithm first computes T=A1pT=A_{1}^{\top}p. Then it outputs g=TA2g=TA_{2}. ∎

Lemma C.11 (I/O complexity of Phase 4).

The I/O complexity of AttentionGradientCachePhase4 (Algorithm 5) is O(n2d+nd2M)O(\frac{n^{2}d+nd^{2}}{\sqrt{M}}).

Proof.

In Phase 4 (Algorithm 5) the number of items in cache is at most 3B24B2M3B^{2}\leq 4B^{2}\leq M. For each iteration in computing T=A1pT=A_{1}^{\top}p and g=TA2g=TA_{2}, the algorithm reads O(B2)O(B^{2}) from memory into cache. This is the dominating factor of the I/O complexity of the algorithm. Thus, the I/O complexity of Phase 4 is O(n2dB3B2)+O(nd2B3B2)=O(n2d+nd2B)=O(n2d+nd2M)O(\frac{n^{2}d}{B^{3}}B^{2})+O(\frac{nd^{2}}{B^{3}}B^{2})=O(\frac{n^{2}d+nd^{2}}{B})=O(\frac{n^{2}d+nd^{2}}{\sqrt{M}}). ∎

Algorithm 5 Attention gradient computation with cache phase 4. Compute dL(X)dX\frac{\mathrm{d}L(X)}{\mathrm{d}X}.
1:procedure AttentionGradientCachePhase4(A1,A2n×dA_{1},A_{2}\in\mathbb{R}^{n\times d}, pn×np\in\mathbb{R}^{n\times n}, M+M\in\mathbb{N}_{+}) \triangleright Lemma C.10, Lemma C.11
2:     BM/4B\leftarrow\lfloor\sqrt{M/4}\rfloor
3:     /* Phase 4: Compute dL(X)dX\frac{\mathrm{d}L(X)}{\mathrm{d}X} */
4:     for 1id/B1\leq i\leq\lceil d/B\rceil do
5:         for 1jn/B1\leq j\leq\lceil n/B\rceil do
6:              Initialize T(B)[i,j]0B×BT^{(B)}[i,j]\leftarrow 0^{B\times B} in cache
7:              for 1kn/B1\leq k\leq\lceil n/B\rceil do
8:                  Read (A1)(B)[i,k](A_{1}^{\top})^{(B)}[i,k] and p(B)[k,j]p^{(B)}[k,j] into cache
9:                  Compute T(B)[i,j]T(B)[i,j]+(A1)(B)[i,k]p(B)[k,j]T^{(B)}[i,j]\leftarrow T^{(B)}[i,j]+(A_{1}^{\top})^{(B)}[i,k]p^{(B)}[k,j] in cache \triangleright T=A1pT=A_{1}^{\top}p
10:                  Delete (A1)(B)[i,k](A_{1}^{\top})^{(B)}[i,k] and p(B)[k,j]p^{(B)}[k,j] from cache
11:              end for
12:              Write T(B)[i,j]T^{(B)}[i,j] into memory, and delete T(B)[i,j]T^{(B)}[i,j] from cache
13:         end for
14:     end for
15:     for 1id/B1\leq i\leq\lceil d/B\rceil do
16:         for 1jd/B1\leq j\leq\lceil d/B\rceil do
17:              Initialize g(B)[i,j]0B×Bg^{(B)}[i,j]\leftarrow 0^{B\times B} in cache
18:              for 1kn/B1\leq k\leq\lceil n/B\rceil do
19:                  Read T(B)[i,k]T^{(B)}[i,k] and A2(B)[k,j]A_{2}^{(B)}[k,j] into cache
20:                  Compute g(B)[i,j]g(B)[i,j]+T(B)[i,k]A2(B)[k,j]g^{(B)}[i,j]\leftarrow g^{(B)}[i,j]+T^{(B)}[i,k]A_{2}^{(B)}[k,j] in cache \triangleright g=TA2g=TA_{2}
21:                  Delete T(B)[i,k]T^{(B)}[i,k] and A2(B)[k,j]A_{2}^{(B)}[k,j] from cache
22:              end for
23:              Write g(B)[i,j]g^{(B)}[i,j] into memory, and delete g(B)[i,j]g^{(B)}[i,j] from cache
24:         end for
25:     end for
26:     return gg \triangleright g=dL(X)dXd×dg=\frac{\mathrm{d}L(X)}{\mathrm{d}X}\in\mathbb{R}^{d\times d}, see Definition 3.2
27:end procedure

C.3 Upper Bound for Attention Backward in Small Cache M=o(d2)M=o(d^{2})

When the cache size is small, i.e., M=o(d2)M=o(d^{2}), the I/O behavior of the attention backward pass matches that of standard matrix multiplication, which yields an O(n2d+nd2M)O(\frac{n^{2}d+nd^{2}}{\sqrt{M}}) bound on the I/O complexity.

We state the upper bound theorem below for the overall algorithm (see Algorithm 6) that solves the attention backward pass in the small cache case.

Theorem C.12 (Small cache upper bound, formal version of Theorem 4.3).

Suppose nn is the input length, dd is the head dimension, and M=o(d2)M=o(d^{2}) is the cache size. There is an algorithm (see Algorithm 6) that outputs a d×dd\times d matrix g=dL(X)dXg=\frac{\mathrm{d}L(X)}{\mathrm{d}X} (Definition 3.2) with I/O complexity O(n2d+nd2M)O(\frac{n^{2}d+nd^{2}}{\sqrt{M}}), time complexity 𝒯mat(n,d,n)+𝒯mat(n,d,d){\cal T}_{\mathrm{mat}}(n,d,n)+{\cal T}_{\mathrm{mat}}(n,d,d), and space complexity O(n2+d2)O(n^{2}+d^{2}).

Proof.

Time/space complexity.

First, we notice that Algorithm 6 computes the same gradient as Algorithm 1, except that the former utilizes the cache to speed up the computation and carries out the standard matrix multiplications block by block in cache. Thus, the overall time complexity 𝒯mat(n,d,n)+𝒯mat(n,d,d){\cal T}_{\mathrm{mat}}(n,d,n)+{\cal T}_{\mathrm{mat}}(n,d,d) and space complexity O(n2+d2)O(n^{2}+d^{2}) are the same as in Lemma C.3.

I/O complexity.

From Lemma C.5, C.7, C.9, and C.11, the overall I/O complexity is O(n2d+nd2M)+O(n2)O(\frac{n^{2}d+nd^{2}}{\sqrt{M}})+O(n^{2}). Since M=o(d2)M=o(d^{2}) implies n2=O(n2dM)n^{2}=O(\frac{n^{2}d}{\sqrt{M}}), the overall I/O complexity is O(n2d+nd2M)O(\frac{n^{2}d+nd^{2}}{\sqrt{M}}).

Correctness.

From Lemma C.4, C.6, C.8, and C.10, the algorithm computes the correct dL(X)dX\frac{\mathrm{d}L(X)}{\mathrm{d}X}. ∎

Algorithm 6 Attention gradient computation with small cache.
1:procedure AttentionGradientCache(A1,A2,A3,dOn×dA_{1},A_{2},A_{3},\mathrm{d}O\in\mathbb{R}^{n\times d}, X,Yd×dX,Y\in\mathbb{R}^{d\times d}, M+M\in\mathbb{N}_{+}) \triangleright Theorem C.12
2:     fAttentionGradientCachePhase1(A1,A2,X,M)f\leftarrow\textsc{AttentionGradientCachePhase1}(A_{1},A_{2},X,M) \triangleright see Algorithm 2
3:     qAttentionGradientCachePhase2(A3,dO,f,Y,M)q\leftarrow\textsc{AttentionGradientCachePhase2}(A_{3},\mathrm{d}O,f,Y,M) \triangleright see Algorithm 3
4:     pAttentionGradientCachePhase3(q,f,M)p\leftarrow\textsc{AttentionGradientCachePhase3}(q,f,M) \triangleright see Algorithm 4
5:     gAttentionGradientCachePhase4(A1,A2,p,M)g\leftarrow\textsc{AttentionGradientCachePhase4}(A_{1},A_{2},p,M) \triangleright see Algorithm 5
6:     return gg \triangleright g=dL(X)dXd×dg=\frac{\mathrm{d}L(X)}{\mathrm{d}X}\in\mathbb{R}^{d\times d}, see Definition 3.2
7:end procedure

Appendix D I/O Complexity Upper Bound for Large Cache

In this section, we establish the upper bound (Theorem D.5) for the I/O complexity in the case where the cache size is large, specifically when M=Ω(d2)M=\Omega(d^{2}). Section D.1 presents algorithms and analyses for attention gradient computation in the large cache setting. Section D.2 provides the upper bound theorem for the large cache case.

Since our goal is to compute the backward pass of the attention mechanism, and the forward pass has already been performed, it is natural to assume that we have access to the softmax normalizing vector l:=A𝟏nl:=A\cdot{\bf 1}\in\mathbb{R}^{n} (Definition B.4) and the final attention forward output O=diag(l)1Ah(Y)n×dO=\operatorname{diag}(l)^{-1}Ah(Y)\in\mathbb{R}^{n\times d} (Definition B.7), where A=exp(A1XA2)A=\exp(A_{1}XA_{2}^{\top}) (Definition B.3) and h(Y)=A3Yh(Y)=A_{3}Y (Definition B.6).

By utilizing these precomputed quantities from the forward pass, we can efficiently proceed with the backward computation while optimizing the I/O operations required.

D.1 Algorithms for Attention Backward in Large Cache

We first give Algorithm 7 and its analysis in the large cache case for computing the intermediate variables S,hS,h.

Algorithm 7 Attention gradient computation large cache phase 1. Compute S,hS,h.
1:procedure AttentionGradientLargeCachePhase1(A1,A3n×dA_{1},A_{3}\in\mathbb{R}^{n\times d}, X,Yd×dX,Y\in\mathbb{R}^{d\times d}, M+M\in\mathbb{N}_{+}) \triangleright Lemma D.1, Lemma D.2
2:     Brmin{M4d,d}B_{r}\leftarrow\min\{\lceil\frac{M}{4d}\rceil,d\} and BcM4dB_{c}\leftarrow\lceil\frac{M}{4d}\rceil
3:     Vertically divide A1A_{1} into Tr=nBrT_{r}=\lceil\frac{n}{B_{r}}\rceil blocks A1,1,,A1,TrA_{1,1},\dots,A_{1,T_{r}} of size Br×dB_{r}\times d each, and horizontally divide XX into Tc=dBcT_{c}=\lceil\frac{d}{B_{c}}\rceil blocks X,1,,X,TcX_{*,1},\dots,X_{*,T_{c}} of size d×Bcd\times B_{c} each
4:     Vertically divide A3A_{3} into Tr=nBrT_{r}=\lceil\frac{n}{B_{r}}\rceil blocks A3,1,,A3,TrA_{3,1},\dots,A_{3,T_{r}} of size Br×dB_{r}\times d each, and horizontally divide YY into Tc=dBcT_{c}=\lceil\frac{d}{B_{c}}\rceil blocks Y,1,,Y,TcY_{*,1},\dots,Y_{*,T_{c}} of size d×Bcd\times B_{c} each
5:     \triangleright Here A1,i,A3,iBr×dA_{1,i},A_{3,i}\in\mathbb{R}^{B_{r}\times d} means the ii-th row block of A1,A3A_{1},A_{3} for i[Tr]i\in[T_{r}], and X,j,Y,jd×BcX_{*,j},Y_{*,j}\in\mathbb{R}^{d\times B_{c}} means jj-th column block of X,YX,Y for j[Tc]j\in[T_{c}]
6:     for 1iTr1\leq i\leq T_{r} do
7:         Read A1,i,A3,iBr×dA_{1,i},A_{3,i}\in\mathbb{R}^{B_{r}\times d} into cache
8:         for 1jTc1\leq j\leq T_{c} do
9:              Read X,jd×BcX_{*,j}\in\mathbb{R}^{d\times B_{c}} into cache, and initialize Si,j0Br×BcS_{i,j}\leftarrow 0^{B_{r}\times B_{c}} in cache
10:              Compute Si,jSi,j+A1,iX,jS_{i,j}\leftarrow S_{i,j}+A_{1,i}X_{*,j} in cache \triangleright S=A1XS=A_{1}X
11:              Write Si,jS_{i,j} to memory, and delete Si,j,X,jS_{i,j},X_{*,j} from cache
12:              Read Y,jd×BcY_{*,j}\in\mathbb{R}^{d\times B_{c}} into cache, and initialize hi,j0Br×Bch_{i,j}\leftarrow 0^{B_{r}\times B_{c}} in cache
13:              Compute hi,jhi,j+A3,iY,jh_{i,j}\leftarrow h_{i,j}+A_{3,i}Y_{*,j} in cache \triangleright h=A3Yh=A_{3}Y
14:              Write hi,jh_{i,j} to memory, and delete hi,j,Y,jh_{i,j},Y_{*,j} from cache
15:         end for
16:         Delete A1,i,A3,iA_{1,i},A_{3,i} from cache
17:     end for
18:     return S,hS,h \triangleright S,hn×dS,h\in\mathbb{R}^{n\times d}
19:end procedure
Lemma D.1 (Correctness of Phase 1).

The AttentionGradientLargeCachePhase1 (Algorithm 7) outputs two n×dn\times d matrices S=A1XS=A_{1}X (Definition 3.1) and h=A3Yh=A_{3}Y (Definition B.6).

Proof.

The algorithm first divides A1,A3,X,YA_{1},A_{3},X,Y into row/column blocks of size Br×dB_{r}\times d or d×Bcd\times B_{c}. Then it reads the row/column blocks to compute the corresponding small blocks of S,hS,h by standard matrix multiplication. Thus, it computes the exact values of S,hS,h.

Lemma D.2 (I/O complexity of Phase 1).

Suppose the cache size satisfies ndMdnd\geq M\geq d. The I/O complexity of AttentionGradientLargeCachePhase1 (Algorithm 7) is O(n2d2M+nd3M)O(\frac{n^{2}d^{2}}{M}+\frac{nd^{3}}{M}).

Proof.

Why such conditions for Br,BcB_{r},B_{c}.

The cache size has three constraints, because we need matrices A1,i,A3,iBr×dA_{1,i},A_{3,i}\in\mathbb{R}^{B_{r}\times d}, X,j,Y,jd×BcX_{*,j},Y_{*,j}\in\mathbb{R}^{d\times B_{c}}, and Si,j,hi,jBr×BcS_{i,j},h_{i,j}\in\mathbb{R}^{B_{r}\times B_{c}} to fit into cache. Thus, we have

Brd=\displaystyle B_{r}d= O(M)\displaystyle~{}O(M)
Bcd=\displaystyle B_{c}d= O(M)\displaystyle~{}O(M)
BrBc=\displaystyle B_{r}B_{c}= O(M)\displaystyle~{}O(M)

Then, we need

Br=\displaystyle B_{r}= O(M/d)\displaystyle~{}O(M/d)
Bc=\displaystyle B_{c}= O(M/d)\displaystyle~{}O(M/d)

By setting Bc=Θ(M/d)B_{c}=\Theta(M/d), we have

Br=\displaystyle B_{r}= Θ(min{M/d,M/Bc})\displaystyle~{}\Theta(\min\{M/d,M/B_{c}\})
=\displaystyle= Θ(min{M/d,d})\displaystyle~{}\Theta(\min\{M/d,d\})

I/O complexity. We know Brmin{M4d,d}B_{r}\leftarrow\min\{\lceil\frac{M}{4d}\rceil,d\} and BcM4dB_{c}\leftarrow\lceil\frac{M}{4d}\rceil, also Tr=nBrT_{r}=\lceil\frac{n}{B_{r}}\rceil and Tc=dBcT_{c}=\lceil\frac{d}{B_{c}}\rceil. Substituting BrB_{r} into TrT_{r}, we get Tr=O(ndM)T_{r}=O(\frac{nd}{M}). Observe that TrBr=O(n)T_{r}B_{r}=O(n) and TcBc=O(d)T_{c}B_{c}=O(d).

The I/O complexity can be computed by:

Tr(Brd+Tc(dBc))=\displaystyle T_{r}(B_{r}d+T_{c}(dB_{c}))= O(nd)+Trd2\displaystyle~{}O(nd)+T_{r}d^{2}
=\displaystyle= O(nd)+O(ndMd2)\displaystyle~{}O(nd)+O(\frac{nd}{M}d^{2})
=\displaystyle= O(nd+nd3M)\displaystyle~{}O(nd+\frac{nd^{3}}{M})

where the first step follows from TrBr=O(n)T_{r}B_{r}=O(n) and TcBc=O(d)T_{c}B_{c}=O(d), the second step follows from Tr=O(ndM)T_{r}=O(\frac{nd}{M}), and the last step follows from simple algebra.

Because MndM\leq nd, we have

O(nd+nd3M)=\displaystyle O(nd+\frac{nd^{3}}{M})= O(ndMM+nd3M)\displaystyle~{}O(\frac{ndM}{M}+\frac{nd^{3}}{M})
=\displaystyle= O(n2d2M+nd3M)\displaystyle~{}O(\frac{n^{2}d^{2}}{M}+\frac{nd^{3}}{M})

Thus, the total I/O complexity is O(n2d2M+nd3M)O(\frac{n^{2}d^{2}}{M}+\frac{nd^{3}}{M}). ∎
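
For reference, the streaming pattern of Algorithm 7 is compact enough to spell out; the following NumPy sketch (ours, assuming for readability that Br divides n and Bc divides d, and using floors instead of ceilings for the block sizes) streams Br×d row strips of A1 and A3 against d×Bc column strips of X and Y.

```python
import numpy as np

def phase1_large_cache(A1, A3, X, Y, M):
    """Compute S = A1 X and h = A3 Y by streaming row strips against column strips."""
    n, d = A1.shape
    Bc = max(1, M // (4 * d))          # column block width, Theta(M / d)
    Br = min(Bc, d)                    # row block height, Theta(min{M / d, d})
    S = np.zeros((n, d))
    h = np.zeros((n, d))
    for i in range(0, n, Br):          # one Br x d strip of A1 and of A3 stays in cache
        for j in range(0, d, Bc):      # one d x Bc strip of X (resp. Y) is read per step
            S[i:i+Br, j:j+Bc] = A1[i:i+Br] @ X[:, j:j+Bc]
            h[i:i+Br, j:j+Bc] = A3[i:i+Br] @ Y[:, j:j+Bc]
    return S, h
```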

Algorithm 8 Attention gradient computation large cache phase 2. Compute gg.
1:procedure AttentionGradientLargeCachePhase2(A1,A2,S,h,O,dOn×dA_{1},A_{2},S,h,O,\mathrm{d}O\in\mathbb{R}^{n\times d}, lnl\in\mathbb{R}^{n}, M+M\in\mathbb{N}_{+}) \triangleright Lemma D.3, Lemma D.4
2:     Brmin{M4d,d}B_{r}\leftarrow\min\{\lceil\frac{M}{4d}\rceil,d\} and BcM4dB_{c}\leftarrow\lceil\frac{M}{4d}\rceil
3:     Vertically divide SS into Tr=nBrT_{r}=\lceil\frac{n}{B_{r}}\rceil blocks S1,,STrS_{1},\dots,S_{T_{r}} of size Br×dB_{r}\times d each, vertically divide A2A_{2} into Tc=nBcT_{c}=\lceil\frac{n}{B_{c}}\rceil blocks A2,1,,A2,TcA_{2,1},\dots,A_{2,T_{c}} of size Bc×dB_{c}\times d each, and vertically divide ll into Tr=nBrT_{r}=\lceil\frac{n}{B_{r}}\rceil blocks l1,,lTrl_{1},\dots,l_{T_{r}} of size BrB_{r} each
4:     Vertically divide OO into Tr=nBrT_{r}=\lceil\frac{n}{B_{r}}\rceil blocks O1,,OTrO_{1},\dots,O_{T_{r}} of size Br×dB_{r}\times d each, vertically divide dO\mathrm{d}O into Tr=nBrT_{r}=\lceil\frac{n}{B_{r}}\rceil blocks dO1,,dOTr\mathrm{d}O_{1},\dots,\mathrm{d}O_{T_{r}} of size Br×dB_{r}\times d each, vertically divide hh into Tc=nBcT_{c}=\lceil\frac{n}{B_{c}}\rceil blocks h1,,hTch_{1},\dots,h_{T_{c}} of size Bc×dB_{c}\times d each, and vertically divide A1A_{1} into Tr=nBrT_{r}=\lceil\frac{n}{B_{r}}\rceil blocks A1,1,,A1,TrA_{1,1},\dots,A_{1,T_{r}} of size Br×dB_{r}\times d each
5:     Initialize g0d×dg\leftarrow 0^{d\times d} in cache
6:     for 1iTr1\leq i\leq T_{r} do
7:         Read Si,Oi,dOi,A1,iBr×dS_{i},O_{i},\mathrm{d}O_{i},A_{1,i}\in\mathbb{R}^{B_{r}\times d} and liBrl_{i}\in\mathbb{R}^{B_{r}} into cache
8:         Initialize vi0Brv_{i}\leftarrow 0^{B_{r}} and compute vivi+(dOiOi)𝟏v_{i}\leftarrow v_{i}+(\mathrm{d}O_{i}\circ O_{i})\cdot{\bf 1} in cache \triangleright v=(dOO)𝟏v=(\mathrm{d}O\circ O)\cdot{\bf 1}
9:         Delete OiO_{i} from cache
10:         for 1jTc1\leq j\leq T_{c} do
11:              Read hjBc×dh_{j}\in\mathbb{R}^{B_{c}\times d} and initialize qi,j0Br×Bcq_{i,j}\leftarrow 0^{B_{r}\times B_{c}} in cache
12:              Compute qi,jdOihjq_{i,j}\leftarrow\mathrm{d}O_{i}h_{j}^{\top} in cache \triangleright q=dOhq=\mathrm{d}Oh^{\top}
13:              Read A2,jBc×dA_{2,j}\in\mathbb{R}^{B_{c}\times d} into cache, and initialize Ai,j0Br×BcA_{i,j}\leftarrow 0^{B_{r}\times B_{c}} in cache
14:              Compute Ai,jAi,j+SiA2,jA_{i,j}\leftarrow A_{i,j}+S_{i}A_{2,j}^{\top} in cache \triangleright A=SA2A=SA_{2}^{\top}
15:              Compute Ai,jexp(Ai,j)A_{i,j}\leftarrow\exp(A_{i,j}) in cache, and initialize fi,j0Br×Bcf_{i,j}\leftarrow 0^{B_{r}\times B_{c}} in cache
16:              Compute fi,jfi,j+diag(li)1Ai,jf_{i,j}\leftarrow f_{i,j}+\operatorname{diag}(l_{i})^{-1}A_{i,j} in cache \triangleright f=diag(l)1Af=\operatorname{diag}(l)^{-1}A
17:              Delete Ai,jA_{i,j} from cache, and initialize pi,j0Br×Bcp_{i,j}\leftarrow 0^{B_{r}\times B_{c}} in cache
18:              Compute pi,jpi,j+fi,jqi,jdiag(vi)fi,jp_{i,j}\leftarrow p_{i,j}+f_{i,j}\circ q_{i,j}-\operatorname{diag}(v_{i})f_{i,j} in cache \triangleright p=fqdiag(v)fp=f\circ q-\operatorname{diag}(v)f
19:              Delete fi,j,qi,jf_{i,j},q_{i,j} in cache, and initialize T,j0d×BcT_{*,j}\leftarrow 0^{d\times B_{c}} in cache
20:              Compute T,jT,j+A1,ipi,jT_{*,j}\leftarrow T_{*,j}+A_{1,i}^{\top}p_{i,j} in cache \triangleright T=A1pT=A_{1}^{\top}p
21:              Compute gg+T,jA2,jg\leftarrow g+T_{*,j}A_{2,j} \triangleright g=TA2g=TA_{2}
22:              Delete T,j,A2,jT_{*,j},A_{2,j} from cache
23:         end for
24:         Delete Si,A1,i,dOi,li,viS_{i},A_{1,i},\mathrm{d}O_{i},l_{i},v_{i} from cache
25:     end for
26:     Write gg into memory
27:     return gg \triangleright g=dL(X)dXd×dg=\frac{\mathrm{d}L(X)}{\mathrm{d}X}\in\mathbb{R}^{d\times d}, see Definition 3.2
28:end procedure
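
The essential point of Algorithm 8 is that only Br×Bc tiles of the n×n intermediates A, f, q, p ever exist, while the d×d accumulator g stays resident in cache (hence the requirement M ≥ d^2 discussed below). The following NumPy sketch (ours, assuming for readability that Br and Bc divide n, and that S, h, O, l are available from Phase 1 and the forward pass) mirrors this streaming accumulation.

```python
import numpy as np

def phase2_large_cache(A1, A2, S, h, O, dO, l, Br, Bc):
    """Accumulate g = A1^T p A2 tile by tile, never materializing an n x n matrix."""
    n, d = A1.shape
    g = np.zeros((d, d))                                  # resident accumulator, needs M >= d^2
    for i in range(0, n, Br):
        v_i = (dO[i:i+Br] * O[i:i+Br]).sum(axis=1)        # v = (dO o O) 1, one row block
        for j in range(0, n, Bc):
            q_ij = dO[i:i+Br] @ h[j:j+Bc].T               # tile of q = dO h^T
            A_ij = np.exp(S[i:i+Br] @ A2[j:j+Bc].T)       # tile of A = exp(S A2^T)
            f_ij = A_ij / l[i:i+Br, None]                 # tile of f = diag(l)^{-1} A
            p_ij = f_ij * q_ij - v_i[:, None] * f_ij      # tile of p = f o q - diag(v) f
            g += A1[i:i+Br].T @ p_ij @ A2[j:j+Bc]         # g += A1_i^T p_{i,j} A2_j
    return g
```

By bilinearity, summing the tile contributions A1,i^T p_{i,j} A2,j over all (i, j) recovers exactly the product A1^T p A2 computed by Algorithm 1.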

We then give Algorithm 8 along with its analysis for computing the gradient gg.

Lemma D.3 (Correctness of Phase 2).

The AttentionGradientLargeCachePhase2 (Algorithm 8) outputs a d×dd\times d matrix gg (Definition 3.2).

Proof.

The algorithm first vertically divides the matrices SS, A2A_{2}, ll, OO, dO\mathrm{d}O, hh, and A1A_{1} into row blocks of size Br×dB_{r}\times d or Bc×dB_{c}\times d. Following the computational graph (Fig. 2) and the no-cache algorithm (Algorithm 1), we compute the gradient gg exactly. It is important to note that, in algorithm design, we need to avoid reading the attention matrix fn×nf\in\mathbb{R}^{n\times n} directly—even though it has been computed during the forward pass—or any matrices of size Br×nB_{r}\times n or Bc×nB_{c}\times n. Doing so would result in an O(n2)O(n^{2}) I/O complexity, which cannot be improved through caching. ∎

Lemma D.4 (I/O complexity of Phase 2).

Suppose the cache size satisfies ndMd2nd\geq M\geq d^{2}. The I/O complexity of AttentionGradientLargeCachePhase2 (Algorithm 8) is O(n2d2M+nd3M)O(\frac{n^{2}d^{2}}{M}+\frac{nd^{3}}{M}).

Proof.

The reasoning behind the conditions on Br,BcB_{r},B_{c} is the same as in the proof of Lemma D.2. However, it is important to note that keeping the gradient gg resident in the cache requires a cache size of Md2M\geq d^{2}. This is necessary because we fuse the key and query weight matrices into a single matrix Xd×dX\in\mathbb{R}^{d\times d}, and the corresponding gradient gg is updated in cache through the block outer-product representation of the matrix product, as shown in Line 21 of Algorithm 8.

Next we show the I/O complexity. Since Brmin{M4d,d}B_{r}\leftarrow\min\{\lceil\frac{M}{4d}\rceil,d\} and BcM4dB_{c}\leftarrow\lceil\frac{M}{4d}\rceil, also Tr=nBrT_{r}=\lceil\frac{n}{B_{r}}\rceil and Tc=nBcT_{c}=\lceil\frac{n}{B_{c}}\rceil, we get Tr=O(ndM)T_{r}=O(\frac{nd}{M}). Also, we observe that TrBr=O(n)T_{r}B_{r}=O(n) and TcBc=O(n)T_{c}B_{c}=O(n).

The I/O complexity can be computed by:

Tr(Brd+TcBcd)+d2=\displaystyle T_{r}(B_{r}d+T_{c}B_{c}d)+d^{2}= O(nd)+Trnd+d2\displaystyle~{}O(nd)+T_{r}nd+d^{2}
=\displaystyle= O(Trnd)+d2\displaystyle~{}O(T_{r}nd)+d^{2}
=\displaystyle= O(n2d2M)+d2\displaystyle~{}O(\frac{n^{2}d^{2}}{M})+d^{2}

where the first step follows from TrBr=O(n)T_{r}B_{r}=O(n) and TcBc=O(n)T_{c}B_{c}=O(n), the second step follows from Tr1T_{r}\geq 1, and the last step follows from Tr=O(ndM)T_{r}=O(\frac{nd}{M}).

Then, because MndM\leq nd, we can show

O(d2+n2d2M)=\displaystyle O(d^{2}+\frac{n^{2}d^{2}}{M})= O(d2MM+n2d2M)\displaystyle~{}O(\frac{d^{2}M}{M}+\frac{n^{2}d^{2}}{M})
=\displaystyle= O(nd3M+n2d2M)\displaystyle~{}O(\frac{nd^{3}}{M}+\frac{n^{2}d^{2}}{M})

Thus, the total I/O complexity is O(n2d2M+nd3M)O(\frac{n^{2}d^{2}}{M}+\frac{nd^{3}}{M}). ∎

D.2 Upper Bound for Attention Backward in Large Cache M=Ω(d2)M=\Omega(d^{2})

In the large cache scenario, while it is feasible to precompute and store the n×nn\times n attention matrix, reading it will result in an unavoidable O(n2)O(n^{2}) I/O complexity. Inspired by FlashAttention [26, 25, 91], we present the following theorem, which provides an upper bound O(n2d2+nd3M)O(\frac{n^{2}d^{2}+nd^{3}}{M}) on the I/O complexity of the attention gradient algorithm in the large cache case (Algorithm 9).

Theorem D.5 (Large cache upper bound, formal version of Theorem 4.1).

Suppose nn is the input length, dd is the head dimension, and ndMd2nd\geq M\geq d^{2} is the cache size. There is an algorithm (see Algorithm 9) that outputs a d×dd\times d matrix g=dL(X)dXg=\frac{\mathrm{d}L(X)}{\mathrm{d}X} (Definition 3.2) with I/O complexity O(n2d2+nd3M)O(\frac{n^{2}d^{2}+nd^{3}}{M}).

Proof.

Correctness. Combining Lemma D.1 and D.3, we finish the proof.

I/O complexity. Combining Lemma D.2 and D.4, we finish the proof. ∎
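
As a quick end-to-end sanity check (ours, reusing the hypothetical sketches attention_backward_reference, phase1_large_cache, and phase2_large_cache given earlier), the streamed gradient agrees with the dense reference on a small random instance satisfying nd ≥ M ≥ d^2.

```python
import numpy as np

rng = np.random.default_rng(1)
n, d, M = 16, 4, 64                                   # nd = 64 >= M = 64 >= d^2 = 16
A1, A2, A3, dO = (rng.standard_normal((n, d)) for _ in range(4))
X, Y = rng.standard_normal((d, d)), rng.standard_normal((d, d))

O_ref, g_ref = attention_backward_reference(A1, A2, A3, dO, X, Y)
S, h = phase1_large_cache(A1, A3, X, Y, M)            # Algorithm 7 sketch
l = np.exp(A1 @ X @ A2.T) @ np.ones(n)                # softmax normalizer from the forward pass
g = phase2_large_cache(A1, A2, S, h, O_ref, dO, l, Br=4, Bc=4)   # Algorithm 8 sketch
assert np.allclose(g, g_ref)                          # streamed gradient matches the dense gradient
```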

Algorithm 9 Attention gradient computation with large cache.
1:procedure AttentionGradientLargeCache(A1,A2,A3,O,dOn×dA_{1},A_{2},A_{3},O,\mathrm{d}O\in\mathbb{R}^{n\times d}, X,Yd×dX,Y\in\mathbb{R}^{d\times d}, lnl\in\mathbb{R}^{n}, M+M\in\mathbb{N}_{+}) \triangleright Theorem D.5
2:     S,hAttentionGradientLargeCachePhase1(A1,A3,X,Y,M)S,h\leftarrow\textsc{AttentionGradientLargeCachePhase1}(A_{1},A_{3},X,Y,M) \triangleright see Algorithm 7
3:     gAttentionGradientLargeCachePhase2(A1,A2,S,h,O,dO,l,M)g\leftarrow\textsc{AttentionGradientLargeCachePhase2}(A_{1},A_{2},S,h,O,\mathrm{d}O,l,M) \triangleright see Algorithm 8
4:     return gg \triangleright g=dL(X)dXd×dg=\frac{\mathrm{d}L(X)}{\mathrm{d}X}\in\mathbb{R}^{d\times d}, see Definition 3.2
5:end procedure

Appendix E Lower Bound for Attention Backward Computation

In this section, we prove the lower bound for attention gradient computation. In Section E.1, we state some definitions from graph theory needed to set up the framework of [49] for analyzing I/O complexity. In Section E.2, we state some tools from previous work on the I/O complexity of standard matrix multiplication and attention forward computation. In Section E.3, we establish our I/O complexity lower bounds for attention backward passes in both the large cache case and the small cache case.

E.1 Basic Definition in Graph Theory

HK [49] introduces a method for analyzing I/O complexity using the concept of an MM-partition on a graph. Before we define it, we first provide some definitions from graph theory.

Definition E.1 (Dominator set).

Let G=(V,E)G=(V,E) be a directed acyclic graph and SVS\subseteq V. We define a set DVD\subseteq V as a dominator set of SS if, for every path in GG from an input node to any node in SS, there exists at least one node in DD on that path.

Definition E.2 (Minimum set).

Let G=(V,E)G=(V,E) be a directed acyclic graph and SVS\subseteq V. We say that a set MSM\subseteq S is a minimum set of SS if MM contains all nodes in SS that have no children in SS.

Definition E.3 (Vertex subset dependence).

Let G=(V,E)G=(V,E) be a directed acyclic graph. Let V1,V2VV_{1},V_{2}\subseteq V be two disjoint subsets. We say that V2V_{2} depends on V1V_{1} if there is a directed edge from a node in V1V_{1} to a node in V2V_{2}.

Definition E.4 (Cyclic dependence).

Let G=(V,E)G=(V,E) be a directed acyclic graph. Let V1,,VhVV_{1},\ldots,V_{h}\subseteq V be hh disjoint subsets of VV. We say that there is a cyclic dependence among {V1,,Vh}\{V_{1},\ldots,V_{h}\} if there exists a permutation (i1,,ih)(i_{1},\ldots,i_{h}) of [h][h] such that Vi1V_{i_{1}} depends on VihV_{i_{h}}, and for every j{2,,h}j\in\{2,\ldots,h\}, VijV_{i_{j}} depends on Vij1V_{i_{j-1}}.

Now, we are ready to define MM-partitions. In fact, the minimum number of sets in any MM-partition provides a lower bound on the I/O complexity.

Definition E.5 (MM-partition [49]).

Let G=(V,E)G=(V,E) be a directed acyclic graph. Let V1,,VhVV_{1},\ldots,V_{h}\subseteq V be hh disjoint subsets of VV. We say that {V1,,Vh}\{V_{1},\ldots,V_{h}\} is an MM-partition of GG if the following conditions are satisfied:

  • {V1,,Vh}\{V_{1},\ldots,V_{h}\} is a partition of VV, i.e., V1,,VhV_{1},\ldots,V_{h} are disjoint and V=i=1hViV=\bigcup_{i=1}^{h}V_{i}.

  • For each ViV_{i}, there exists a dominator set DiD_{i} of ViV_{i} such that DiD_{i} has at most MM nodes.

  • For each ViV_{i}, there exists a minimum set MiM_{i} of ViV_{i} such that MiM_{i} has at most MM nodes.

  • There is no cyclic dependence among {V1,,Vh}\{V_{1},\ldots,V_{h}\}.

We use P(G,M)P(G,M) to denote the minimum number of sets in any MM-partition of GG.
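
As a toy illustration of the dominator-set and minimum-set conditions in Definition E.5 (a hypothetical example, not part of the lower-bound machinery), the following brute-force Python check verifies them on a 5-node DAG given as an adjacency dictionary.

```python
from itertools import chain

G = {0: [2], 1: [2], 2: [3, 4], 3: [], 4: []}   # a toy DAG; input nodes are 0 and 1

def inputs(G):
    has_parent = set(chain.from_iterable(G.values()))
    return [v for v in G if v not in has_parent]

def reachable(G, sources, blocked):
    seen, stack = set(), [s for s in sources if s not in blocked]
    while stack:
        v = stack.pop()
        if v not in seen:
            seen.add(v)
            stack.extend(w for w in G[v] if w not in blocked)
    return seen

def is_dominator(G, S, D):
    # D dominates S iff, once D is removed, no node of S is reachable from an input node.
    return not (reachable(G, inputs(G), set(D)) & set(S))

def minimum_set(G, S):
    # The nodes of S that have no children inside S (Definition E.2).
    return {v for v in S if not set(G[v]) & set(S)}

S = {2, 3, 4}
assert is_dominator(G, S, {2})       # every path from 0 or 1 into S passes through node 2
assert not is_dominator(G, S, {3})   # the path 0 -> 2 reaches S while avoiding node 3
assert minimum_set(G, S) == {3, 4}   # nodes 3 and 4 have no children in S
```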

E.2 Previous Tools for I/O Complexity

Now, we are ready to introduce some tools for I/O complexity from HK [49] by using an MM-partition on a graph.

Lemma E.6 (Lemma 3.1 of HK [49]).

For any directed acyclic graph GG and any positive integer MM, we have

Q(G,M)M(P(G,2M)1).\displaystyle Q(G,M)\geq M\cdot(P(G,2M)-1).

We omit GG when it is clear in the context.

We state two useful lemmas from previous works as follows.

Lemma E.7 (Lemma 3.3 of SY [102]).

Suppose that M=Ω(d2)M=\Omega(d^{2}) and An1×d,Bd×n2A\in\mathbb{R}^{n_{1}\times d},B\in\mathbb{R}^{d\times n_{2}}. Let 𝒫\mathcal{P} be an MM-partition of the computational graph of any algorithm that computes ABAB using standard matrix multiplication. Then each V𝒫V^{\prime}\in\mathcal{P} contains at most O(M2d)O(\frac{M^{2}}{d}) nodes among the product nodes Ai,kBk,jA_{i,k}B_{k,j}, the sum nodes (AB)i,j(AB)_{i,j}, and the intermediate nodes in the summation trees.

In SY [102], the matrices AA and BB in the above lemma are of sizes n×dn\times d and d×nd\times n, respectively. We note that with slight modifications to the proofs, the result also holds when AA and BB have different sizes, specifically n1×dn_{1}\times d and d×n2d\times n_{2}.
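
To preview how Lemma E.6 and Lemma E.7 are combined in Section E.3, note that the computational graph of the standard multiplication of an n_{1}\times d matrix by a d\times n_{2} matrix contains n_{1}dn_{2} product nodes; since by Lemma E.7 each part covers at most O(M^{2}/d) of these nodes, plugging into Lemma E.6 gives, under the assumption M=\Omega(d^{2}),

\displaystyle P(G,2M)\geq\Omega\Big(\frac{n_{1}n_{2}d^{2}}{M^{2}}\Big),\qquad Q(M)\geq M\cdot(P(G,2M)-1)=\Omega\Big(\frac{n_{1}n_{2}d^{2}}{M}\Big).

Instantiating (n_{1},n_{2})=(n,d) and (n_{1},n_{2})=(n,n) yields exactly the two terms \Omega(\frac{nd^{3}}{M}) and \Omega(\frac{n^{2}d^{2}}{M}) used in the proof of Theorem E.9 below.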

The next lemma gives the lower bound on the I/O complexity of standard matrix multiplication.

Lemma E.8 (Corollary 6.2 of HK [49]).

Let An1×d,Bd×n2A\in\mathbb{R}^{n_{1}\times d},B\in\mathbb{R}^{d\times n_{2}}. The standard matrix multiplication algorithm computing ABAB has I/O complexity Q(M)=Ω(n1dn2M)Q(M)=\Omega(\frac{n_{1}dn_{2}}{\sqrt{M}}).
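
For instance, in the square case n_{1}=n_{2}=d=n, Lemma E.8 recovers the classical \Omega(\frac{n^{3}}{\sqrt{M}}) bound of HK [49] for n\times n matrix multiplication, while the rectangular cases used below give

\displaystyle Q(M)=\Omega\Big(\frac{nd^{2}}{\sqrt{M}}\Big)\ \text{for }(n_{1},n_{2})=(n,d),\qquad Q(M)=\Omega\Big(\frac{n^{2}d}{\sqrt{M}}\Big)\ \text{for }(n_{1},n_{2})=(n,n),

which are the two terms appearing in Theorem E.10 below.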

E.3 Proof of Our Lower Bound

We establish the lower bounds on the I/O complexity of the attention gradient computation in both the large cache case and the small cache case. We first give the lower bound in the large cache case, i.e., the cache size M=Ω(d2)M=\Omega(d^{2}).

Theorem E.9 (Large cache lower bound, formal version of Theorem 4.2).

Suppose nn is the input length and dd is the head dimension. Suppose the cache size M=Ω(d2)M=\Omega(d^{2}). Then the I/O complexity of attention gradient computation using standard matrix multiplication is Ω(n2d2+nd3M)\Omega(\frac{n^{2}d^{2}+nd^{3}}{M}).

Proof.

Any algorithm that computes the attention gradient needs to compute the matrix product A1XA2A_{1}XA_{2}^{\top} using standard matrix multiplication. Note that we compute A1XA2A_{1}XA_{2}^{\top} using standard matrix multiplication, so we either first compute A1XA_{1}X and then compute (A1X)A2(A_{1}X)A_{2}^{\top}, or first compute XA2XA_{2}^{\top} and then compute A1(XA2)A_{1}(XA_{2}^{\top}). In either case, we perform two matrix multiplications: one between an n×dn\times d matrix and a d×dd\times d matrix, and another between an n×dn\times d matrix and a d×nd\times n matrix. Without loss of generality, we assume the first case where we first compute A1XA_{1}X.

Recall that the level-1 nodes are the product nodes (A1)i,kXk,j(A_{1})_{i,k}X_{k,j}, the sum nodes (A1X)i,j(A_{1}X)_{i,j}, and all intermediate nodes in the summation trees. For every VV^{\prime} in an MM-partition 𝒫\mathcal{P}, by Lemma E.7, there are at most O(M2d)O(\frac{M^{2}}{d}) level-1 nodes in VV^{\prime}. Since the number of product nodes (A1)i,kXk,j(A_{1})_{i,k}X_{k,j} is nd2nd^{2}, the number of parts in the MM-partition 𝒫\mathcal{P} is at least Ω(nd3M2)\Omega(\frac{nd^{3}}{M^{2}}). By Lemma E.6, the I/O complexity for computing A1XA_{1}X is Ω(nd3M)\Omega(\frac{nd^{3}}{M}).

Similarly, we recall that level-2 nodes are the product nodes (A1X)i,k(A2)k,j(A_{1}X)_{i,k}(A_{2}^{\top})_{k,j}, the sum nodes ((A1X)A2)i,j((A_{1}X)A_{2}^{\top})_{i,j}, and all intermediate nodes in the summation trees. For every VV^{\prime} in an MM-partition 𝒫\mathcal{P}, by Lemma E.7, there are at most O(M2d)O(\frac{M^{2}}{d}) level-2 nodes in VV^{\prime}. Since the number of product nodes (A1X)i,k(A2)k,j(A_{1}X)_{i,k}(A_{2}^{\top})_{k,j} is n2dn^{2}d, the number of parts in the MM-partition 𝒫\mathcal{P} is at least Ω(n2d2M2)\Omega(\frac{n^{2}d^{2}}{M^{2}}). By Lemma E.6, the I/O complexity for computing (A1X)A2(A_{1}X)A_{2}^{\top} is Ω(n2d2M)\Omega(\frac{n^{2}d^{2}}{M}).

Therefore, the I/O complexity of attention gradient computation is at least Ω(nd3+n2d2M)\Omega(\frac{nd^{3}+n^{2}d^{2}}{M}). ∎
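
For a concrete sense of scale (an illustrative calculation with parameter values chosen by us, not an experiment), take n=2^{13}=8192, d=2^{7}=128, and M=2^{17}\geq d^{2}=2^{14}. Then the bound evaluates to

\displaystyle\frac{n^{2}d^{2}+nd^{3}}{M}=\frac{2^{40}+2^{34}}{2^{17}}\approx 2^{23}\approx 8.4\times 10^{6}

word transfers between cache and memory for a single backward pass through one attention head.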

Next, we give the lower bound in the small cache case, i.e., the cache size M=o(d2)M=o(d^{2}).

Theorem E.10 (Small cache lower bound, formal version of Theorem 4.4).

Suppose nn is the input length and dd is the head dimension. Suppose the cache size M=o(d2)M=o(d^{2}). Then the I/O complexity of attention gradient computation using standard matrix multiplication is Ω(n2d+nd2M)\Omega(\frac{n^{2}d+nd^{2}}{\sqrt{M}}).

Proof.

We show that when M=o(d2)M=o(d^{2}), computing the matrix product A1XA2A_{1}XA_{2}^{\top} can be reduced to the attention gradient computation, so the lower bound for the former transfers to the latter. Note that we compute A1XA2A_{1}XA_{2}^{\top} using standard matrix multiplication, so we either compute A1XA_{1}X first and then compute (A1X)A2(A_{1}X)A_{2}^{\top}, or we first compute XA2XA_{2}^{\top} and then A1(XA2)A_{1}(XA_{2}^{\top}). Both cases require performing one matrix multiplication between an n×dn\times d matrix and a d×dd\times d matrix, and one matrix multiplication between an n×dn\times d matrix and a d×nd\times n matrix. Hence, without loss of generality, we assume that A1XA_{1}X is computed first. By Lemma E.8, the I/O complexity of computing A1XA_{1}X is Ω(nd2M)\Omega(\frac{nd^{2}}{\sqrt{M}}), and the I/O complexity of computing (A1X)A2(A_{1}X)A_{2}^{\top} is Ω(n2dM)\Omega(\frac{n^{2}d}{\sqrt{M}}). Hence, the total I/O complexity of computing A1XA2A_{1}XA_{2}^{\top} is Ω(n2d+nd2M)\Omega(\frac{n^{2}d+nd^{2}}{\sqrt{M}}).

Suppose for contradiction that there is an algorithm 𝒜\mathcal{A} for the attention gradient computation with I/O complexity o(n2d+nd2M)o(\frac{n^{2}d+nd^{2}}{\sqrt{M}}). We construct an algorithm \mathcal{B} that computes the matrix product A1XA2A_{1}XA_{2}^{\top} with I/O complexity o(n2d+nd2M)o(\frac{n^{2}d+nd^{2}}{\sqrt{M}}). In the language of the red-blue pebble game, algorithm \mathcal{B} simulates algorithm 𝒜\mathcal{A} with two modifications: whenever algorithm 𝒜\mathcal{A} deletes a blue pebble from a node in (A1X)A2(A_{1}X)A_{2}^{\top}, algorithm \mathcal{B} does not delete it; whenever algorithm 𝒜\mathcal{A} places a red pebble on a node in (A1X)A2(A_{1}X)A_{2}^{\top}, algorithm \mathcal{B} also places a blue pebble on it. In this way, algorithm \mathcal{B} transfers all n2n^{2} entries of the matrix product (A1X)A2(A_{1}X)A_{2}^{\top} from cache to memory, using at most n2n^{2} additional I/O operations. Since M=o(d2)M=o(d^{2}) implies M=o(d)\sqrt{M}=o(d), we have n2=o(n2dM)n^{2}=o(\frac{n^{2}d}{\sqrt{M}}), so the overall I/O complexity of \mathcal{B} is still o(n2d+nd2M)o(\frac{n^{2}d+nd^{2}}{\sqrt{M}}). However, this contradicts the fact that the I/O complexity of computing A1XA2A_{1}XA_{2}^{\top} is Ω(n2d+nd2M)\Omega(\frac{n^{2}d+nd^{2}}{\sqrt{M}}). Therefore, the I/O complexity of attention gradient computation using standard matrix multiplication is Ω(n2d+nd2M)\Omega(\frac{n^{2}d+nd^{2}}{\sqrt{M}}). ∎
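
The reduction above can also be phrased operationally: \mathcal{B} replays the I/O behavior of \mathcal{A} and forces every entry of (A_{1}X)A_{2}^{\top} to be written to memory. The following Python sketch mimics this bookkeeping in a toy trace model of the red-blue pebble game; the trace format and operation names are our own simplifications and are not part of the formal model.

def wrap_trace(trace, output_nodes):
    # trace: list of (op, node) pairs emitted by algorithm A in a toy
    # red-blue pebble model, where op is one of
    #   "load"        (copy a value from memory into cache),
    #   "store"       (copy a value from cache into memory),
    #   "compute"     (place a red pebble, i.e., compute a value in cache),
    #   "delete_slow" (remove a blue pebble, i.e., discard a value in memory).
    # Algorithm B keeps every output node in memory: deletions of output
    # nodes are dropped, and each output node is stored right after it is
    # first computed.
    outputs = set(output_nodes)
    stored = set()
    new_trace = []
    for op, node in trace:
        if op == "delete_slow" and node in outputs:
            continue  # B never discards an output entry from memory
        new_trace.append((op, node))
        if op == "compute" and node in outputs and node not in stored:
            new_trace.append(("store", node))  # at most one extra I/O per output
            stored.add(node)
    return new_trace

def io_cost(trace):
    # Only transfers between cache and memory count toward I/O complexity.
    return sum(1 for op, _ in trace if op in ("load", "store"))

Under these toy conventions, io_cost(wrap_trace(trace, outputs)) is at most io_cost(trace) + len(outputs), i.e., at most n^{2} extra I/O operations, which is exactly the accounting used in the proof.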

Appendix F Sparse Attention Computation

In this section, we provide the lower bounds for sparse attention computation for both forward and backward passes. In Section F.1, we state previous tools for the I/O complexity of sparse matrix multiplication. In Section F.2, we provide the proofs of the lower bounds for sparse attention.

F.1 Previous Tools for I/O Complexity of Sparse Matrix Multiplication

We assume that sparse matrices are stored by listing only their non-zero entries along with their coordinates. Sparse semi-ring matrix multiplication restricts operations to addition and multiplication of these entries, which means that each output entry (AB)i,j(AB)_{i,j} can only be computed as the sum of products given by kAi,kBk,j\sum_{k}A_{i,k}B_{k,j}.
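
As a concrete reference implementation of this computational model (an illustrative sketch written by us; it is not I/O-aware and makes no optimality claim), the following Python routine multiplies two matrices stored in coordinate form using only additions and multiplications of non-zero entries, so that each output entry (AB)_{i,j} is formed exactly as \sum_{k}A_{i,k}B_{k,j}.

from collections import defaultdict

def sparse_semiring_matmul(A_coo, B_coo):
    # A_coo, B_coo: lists of (row, col, value) triples with non-zero values.
    # Returns the product AB in the same coordinate format, using only
    # semi-ring operations (addition and multiplication of stored entries).
    B_by_row = defaultdict(list)
    for k, j, b in B_coo:
        B_by_row[k].append((j, b))
    C = defaultdict(float)
    for i, k, a in A_coo:
        for j, b in B_by_row[k]:  # pair A_{i,k} with every non-zero B_{k,j}
            C[(i, j)] += a * b    # elementary product contributing to (AB)_{i,j}
    return [(i, j, v) for (i, j), v in C.items() if v != 0]

For example, sparse_semiring_matmul([(0, 1, 2.0)], [(1, 0, 3.0)]) returns [(0, 0, 6.0)].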

Lemma F.1 (Theorem 2 of [85]).

Let An1×dA\in\mathbb{R}^{n_{1}\times d} and Bd×n2B\in\mathbb{R}^{d\times n_{2}} be two matrices such that R1:=nnz(A)+nnz(B)R_{1}:=\operatorname{nnz}(A)+\operatorname{nnz}(B) and R2:=nnz(AB)R_{2}:=\operatorname{nnz}(AB). The sparse semi-ring matrix multiplication that computes ABAB has I/O complexity Ω(min{R12M,R1R2M})\Omega(\min\{\frac{R_{1}^{2}}{M},\frac{R_{1}\sqrt{R_{2}}}{\sqrt{M}}\}).

Note that in this statement, the I/O complexity also separates into the large cache case and the small cache case, but the dividing point may not be d2d^{2}. It depends on whether all the necessary values for computing each output entry can be stored in the cache during the computation.
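
As a sanity check (an instantiation we add for intuition, not a new result), consider fully dense A\in\mathbb{R}^{n\times d} and B\in\mathbb{R}^{d\times n}, so that R_{1}=2nd and, generically, R_{2}=n^{2}. Lemma F.1 then gives

\displaystyle\Omega\Big(\min\Big\{\frac{(2nd)^{2}}{M},\ \frac{2nd\cdot n}{\sqrt{M}}\Big\}\Big)=\Omega\Big(\min\Big\{\frac{n^{2}d^{2}}{M},\ \frac{n^{2}d}{\sqrt{M}}\Big\}\Big),

whose second branch matches the dense bound of Lemma E.8 with n_{1}=n_{2}=n.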

F.2 Our Lower Bounds for Sparse Attention Computation

We first prove a useful lemma which states the lower bound on the I/O complexity of computing the attention matrix.

Lemma F.2.

Let A1n×d,Xd×d,A2d×nA_{1}\in\mathbb{R}^{n\times d},X\in\mathbb{R}^{d\times d},A_{2}\in\mathbb{R}^{d\times n} be three matrices. Let ZA:=min{nnz(A1),nnz(A2)},ZX:=nnz(X),ZAX=min{nnz(A1X),nnz(XA2)},ZAXA:=nnz(A1XA2)Z_{A}:=\min\{\operatorname{nnz}(A_{1}),\operatorname{nnz}(A_{2})\},Z_{X}:=\operatorname{nnz}(X),Z_{AX}=\min\{\operatorname{nnz}(A_{1}X),\operatorname{nnz}(XA_{2}^{\top})\},Z_{AXA}:=\operatorname{nnz}(A_{1}XA_{2}^{\top}). Then the sparse semi-ring matrix multiplication that computes A1XA2A_{1}XA_{2}^{\top} has I/O complexity Ω(min{ZA2+ZAZXM,ZAZAXA+ZAZXZAXM})\Omega(\min\{\frac{Z_{A}^{2}+Z_{A}Z_{X}}{M},\frac{Z_{A}\sqrt{Z_{AXA}}+\sqrt{Z_{A}Z_{X}Z_{AX}}}{\sqrt{M}}\}).

Proof.

We first consider the case where all the necessary values for computing each output entry can be stored in the cache during the computation. Suppose that A1XA_{1}X is computed first. By Lemma F.1, computing A1XA_{1}X has I/O complexity

Ω((nnz(A1)+nnz(X))2M)=\displaystyle\Omega(\frac{(\operatorname{nnz}(A_{1})+\operatorname{nnz}(X))^{2}}{M})= Ω(nnz(A1)2+2nnz(A1)nnz(X)+nnz(X)2M)\displaystyle~{}\Omega(\frac{\operatorname{nnz}(A_{1})^{2}+2\operatorname{nnz}(A_{1})\operatorname{nnz}(X)+\operatorname{nnz}(X)^{2}}{M})
\displaystyle\geq Ω(ZA2+2ZAZX+ZX2M)\displaystyle~{}\Omega(\frac{Z_{A}^{2}+2Z_{A}Z_{X}+Z_{X}^{2}}{M})
\displaystyle\geq Ω(ZA2+2ZAZXM)\displaystyle~{}\Omega(\frac{Z_{A}^{2}+2Z_{A}Z_{X}}{M})

where the first step follows from basic algebra, the second step uses the definitions of ZAZ_{A} and ZXZ_{X}, and the last step follows from basic algebra. Then we compute the product (A1X)A2(A_{1}X)A_{2}^{\top}. By Lemma F.1, computing (A1X)A2(A_{1}X)A_{2}^{\top} has I/O complexity

Ω((nnz(A1X)+nnz(A2))2M)=\displaystyle\Omega(\frac{(\operatorname{nnz}(A_{1}X)+\operatorname{nnz}(A_{2}))^{2}}{M})= Ω(nnz(A1X)2+2nnz(A1X)nnz(A2)+nnz(A2)2M)\displaystyle~{}\Omega(\frac{\operatorname{nnz}(A_{1}X)^{2}+2\operatorname{nnz}(A_{1}X)\operatorname{nnz}(A_{2})+\operatorname{nnz}(A_{2})^{2}}{M})
\displaystyle\geq Ω(nnz(A2)2M)\displaystyle~{}\Omega(\frac{\operatorname{nnz}(A_{2})^{2}}{M})
=\displaystyle= Ω(ZA2M)\displaystyle~{}\Omega(\frac{Z_{A}^{2}}{M})

where the first and second steps follow from basic algebra, and the last step uses the definition of ZAZ_{A}. Therefore, computing A1XA2A_{1}XA_{2}^{\top} in this way has I/O complexity Ω(2ZA2+2ZAZXM)=Ω(ZA2+ZAZXM).\Omega(\frac{2Z_{A}^{2}+2Z_{A}Z_{X}}{M})=\Omega(\frac{Z_{A}^{2}+Z_{A}Z_{X}}{M}). Similarly, suppose that XA2XA_{2}^{\top} is computed first. Then we also get the I/O complexity Ω(ZA2+ZAZXM)\Omega(\frac{Z_{A}^{2}+Z_{A}Z_{X}}{M}).

Next, we consider the case where some elementary products of the matrix multiplication need to be written to memory during the computation. Suppose that A1XA_{1}X is computed first, and then (A1X)A2(A_{1}X)A_{2}^{\top} is computed. By Lemma F.1, computing A1XA_{1}X has I/O complexity

Ω((nnz(A1)+nnz(X))nnz(A1X)M)\displaystyle\Omega(\frac{(\operatorname{nnz}(A_{1})+\operatorname{nnz}(X))\sqrt{\operatorname{nnz}(A_{1}X)}}{\sqrt{M}})\geq Ω(2nnz(A1)nnz(X)nnz(A1X)M)\displaystyle~{}\Omega(\frac{2\sqrt{\operatorname{nnz}(A_{1})\operatorname{nnz}(X)}\sqrt{\operatorname{nnz}(A_{1}X)}}{\sqrt{M}})
\displaystyle\geq Ω(2ZAZXZAXM)\displaystyle~{}\Omega(\frac{2\sqrt{Z_{A}Z_{X}Z_{AX}}}{\sqrt{M}})

where the first step uses the AM-GM inequality, and the second step uses the definitions of ZAZ_{A}, ZXZ_{X}, and ZAXZ_{AX}.

By Lemma F.1, computing (A1X)A2(A_{1}X)A_{2}^{\top} has I/O complexity

Ω((nnz(A1X)+nnz(A2))nnz(A1XA2)M)\displaystyle\Omega(\frac{(\operatorname{nnz}(A_{1}X)+\operatorname{nnz}(A_{2}))\sqrt{\operatorname{nnz}(A_{1}XA_{2}^{\top})}}{\sqrt{M}})\geq Ω(nnz(A2)nnz(A1XA2)M)\displaystyle~{}\Omega(\frac{\operatorname{nnz}(A_{2})\sqrt{\operatorname{nnz}(A_{1}XA_{2}^{\top})}}{\sqrt{M}})
\displaystyle\geq Ω(ZAZAXAM).\displaystyle~{}\Omega(\frac{Z_{A}\sqrt{Z_{AXA}}}{\sqrt{M}}).

where the first step follows from basic algebra, and the second step uses the definitions of ZAZ_{A} and ZAXAZ_{AXA}. Therefore, computing A1XA2A_{1}XA_{2}^{\top} in this way has I/O complexity Ω(ZAZAXA+ZAZXZAXM)\Omega(\frac{Z_{A}\sqrt{Z_{AXA}}+\sqrt{Z_{A}Z_{X}Z_{AX}}}{\sqrt{M}}). Similarly, suppose that XA2XA_{2}^{\top} is computed first. Then we also get the I/O complexity Ω(ZAZAXA+ZAZXZAXM)\Omega(\frac{Z_{A}\sqrt{Z_{AXA}}+\sqrt{Z_{A}Z_{X}Z_{AX}}}{\sqrt{M}}).

Therefore, the sparse semi-ring matrix multiplication that computes A1XA2A_{1}XA_{2}^{\top} has I/O complexity Ω(min{ZA2+ZAZXM,ZAZAXA+ZAZXZAXM})\Omega(\min\{\frac{Z_{A}^{2}+Z_{A}Z_{X}}{M},\frac{Z_{A}\sqrt{Z_{AXA}}+\sqrt{Z_{A}Z_{X}Z_{AX}}}{\sqrt{M}}\}). ∎

Next, we can apply Lemma F.2 to get the lower bound of sparse attention forward and backward passes.

Theorem F.3 (Lower bound for sparse attention forward).

Suppose nn is the input length, dd is the head dimension, and MM is the cache size. Let ZA:=min{nnz(A1),nnz(A2)},ZX:=nnz(X),ZAX=min{nnz(A1X),nnz(XA2)},ZAXA:=nnz(A1XA2)Z_{A}:=\min\{\operatorname{nnz}(A_{1}),\operatorname{nnz}(A_{2})\},Z_{X}:=\operatorname{nnz}(X),Z_{AX}=\min\{\operatorname{nnz}(A_{1}X),\operatorname{nnz}(XA_{2}^{\top})\},Z_{AXA}:=\operatorname{nnz}(A_{1}XA_{2}^{\top}). Then any algorithm for attention forward computation using sparse semi-ring matrix multiplication has I/O complexity Ω(min{ZA2+ZAZXM,ZAZAXA+ZAZXZAXM})\Omega(\min\{\frac{Z_{A}^{2}+Z_{A}Z_{X}}{M},\frac{Z_{A}\sqrt{Z_{AXA}}+\sqrt{Z_{A}Z_{X}Z_{AX}}}{\sqrt{M}}\}).

Proof.

Any algorithm for attention forward computation needs to compute the matrix product A1XA2A_{1}XA_{2}^{\top} to obtain the attention matrix. Thus by applying Lemma F.2, we complete the proof. ∎

Theorem F.4 (Lower bound for sparse attention backward).

Suppose nn is the input length, dd is the head dimension, and MM is the cache size. Let ZA:=min{nnz(A1),nnz(A2)},ZX:=nnz(X),ZAX=min{nnz(A1X),nnz(XA2)},ZAXA:=nnz(A1XA2)Z_{A}:=\min\{\operatorname{nnz}(A_{1}),\operatorname{nnz}(A_{2})\},Z_{X}:=\operatorname{nnz}(X),Z_{AX}=\min\{\operatorname{nnz}(A_{1}X),\operatorname{nnz}(XA_{2}^{\top})\},Z_{AXA}:=\operatorname{nnz}(A_{1}XA_{2}^{\top}). Then any algorithm for attention backward computation using sparse semi-ring matrix multiplication has I/O complexity Ω(min{ZA2+ZAZXM,ZAZAXA+ZAZXZAXM})\Omega(\min\{\frac{Z_{A}^{2}+Z_{A}Z_{X}}{M},\frac{Z_{A}\sqrt{Z_{AXA}}+\sqrt{Z_{A}Z_{X}Z_{AX}}}{\sqrt{M}}\}).

Proof.

Any algorithm for attention backward computation needs to compute the matrix product A1XA2A_{1}XA_{2}^{\top} to obtain the attention matrix. Thus by applying Lemma F.2, we complete the proof. ∎
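
As a consistency check (an instantiation added for intuition), in the fully dense case we have Z_{A}=nd, Z_{X}=d^{2}, Z_{AX}=nd, and Z_{AXA}=n^{2}, so the bounds of Theorems F.3 and F.4 both become

\displaystyle\Omega\Big(\min\Big\{\frac{n^{2}d^{2}+nd^{3}}{M},\ \frac{n^{2}d+nd^{2}}{\sqrt{M}}\Big\}\Big),

which coincides with the dense lower bounds of Theorems E.9 and E.10 in the large cache and small cache regimes, respectively.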