
Unraveling the Gradient Descent Dynamics of Transformers

Bingqing Song
University of Minnesota, Twin Cities
[email protected]
Boran Han
Amazon Web Services
[email protected]
Shuai Zhang
Amazon Web Services
[email protected]
Jie Ding
University of Minnesota, Twin Cities
[email protected]
Mingyi Hong
University of Minnesota, Twin Cities
[email protected]
The work of B. Song was partially done while interning at Amazon Web Services.
Abstract

While the Transformer architecture has achieved remarkable success across various domains, a thorough theoretical foundation explaining its optimization dynamics is yet to be fully developed. In this study, we aim to bridge this understanding gap by answering the following two core questions: (1) Which types of Transformer architectures allow Gradient Descent (GD) to achieve guaranteed convergence? and (2) Under what initial conditions and architectural specifics does the Transformer achieve rapid convergence during training? By analyzing the loss landscape of a single Transformer layer using Softmax and Gaussian attention kernels, our work provides concrete answers to these questions. Our findings demonstrate that, with appropriate weight initialization, GD can train a Transformer model (with either kernel type) to achieve a global optimal solution, especially when the input embedding dimension is large. Nonetheless, certain scenarios highlight potential pitfalls: training a Transformer using the Softmax attention kernel may sometimes lead to suboptimal local solutions. In contrast, the Gaussian attention kernel exhibits much more favorable behavior. Our empirical study further validates these theoretical findings.

1 Introduction

Transformer model architectures have become popular in machine learning, delivering remarkable performance across a wide array of tasks. From natural language processing (Vaswani et al., 2017; Beltagy et al., 2020) to computer vision (Dosovitskiy et al., 2020), these models have set new standards in performance and efficiency. Popular models include BERT (Devlin et al., 2018), RoBERTa (Liu et al., 2019), DeBERTa (He et al., 2020), GPT models (Radford et al., 2019; Brown et al., 2020) and ViT (Dosovitskiy et al., 2020). Despite their empirical success, a comprehensive understanding of their optimization process remains elusive. As highlighted in Liu et al. (2020), the training of large Transformers can sometimes result in deteriorated performance. It is therefore critical to develop theoretical insights for researchers and practitioners to better understand the practical performance of Transformers. However, the complexity of their architectures, coupled with the non-convex nature of the associated optimization problems, has made the theoretical analysis of these models very challenging.

The optimization landscape can be pivotal for understanding a certain type of neural network and providing practical guidance (Liu et al., 2020). Existing literature offers numerous studies on achieving zero-loss solutions in networks with ReLU activation. These studies encompass various network structures, including fully-connected, convolutional, and residual networks, as explored in (Jain et al., 2017), (Jin et al., 2021), and (Danilova et al., 2022). They delve into the analysis of network optimization landscapes and provide assurances of rapid global convergence when using gradient descent (GD) or stochastic gradient descent (SGD) algorithms. For instance, in Du et al. (2019), the authors focus on fully-connected networks and ResNets with smooth activation functions, and they have demonstrated that global convergence can be achieved using GD with a network size proportional to $\mathcal{O}\big{(}\text{poly}(N)\big{)}$, where $N$ is the sample size. Similarly, (Allen-Zhu et al., 2019) show that ReLU fully-connected networks with at least $\mathcal{O}\big{(}\text{poly}(N)\big{)}$ neurons can achieve global convergence using GD or SGD. From a statistical perspective, (Li et al., 2023) have shown that for two-layer ReLU neural networks (with input dimension $p$) that admit a sparse subnetwork representation, a sample size of $O(\log^{4}(p/\delta))$ can guarantee global convergence with probability at least $1-\delta$ using GD. Despite this extensive body of work on traditional architectures, it is not clear what conditions we need (e.g., network size, optimizer, initialization) to ensure that training Transformer models finds high-quality solutions.

Compared to traditional deep learning architectures, Transformers incorporate a unique level of intricacy through their attention kernel (Vaswani et al., 2017), which is designed to effectively handle sequence inputs. This mechanism applies a Softmax activation to the inner products of query and key vectors, and this inherently non-convex operation poses considerable challenges to theoretical analysis. Consequently, existing frameworks for analyzing the convergence of classical deep learning models are not directly applicable to Transformers. Further, many recent works have pointed out that the performance of Transformers depends on a number of factors such as the choice of kernel function, initialization, choice of optimizers, and forms of token embeddings (Huang et al., 2020; Pan and Li, 2023; Shazeer, 2020; Li et al., 2018; Tian et al., 2023). In deep learning, these factors have been studied in a line of works. For example, Li et al. (2018) show that good training performance is not universal; skip connections have the effect of smoothing the training landscape, and the Adam algorithm tends to follow a more direct trajectory towards optimal solutions compared to SGD. Therefore, it is imperative to understand what kind of conditions, including initialization, network structure, data properties, and optimizer choices, will lead to high-performing Transformers.

In this work, we will delve into the intricacies of attention kernels, discussing both their advantages and limitations in the context of model optimization. The main contributions of this work are threefold.

  • We derive the conditions under which a one-layer Softmax attention Transformer reaches global optimality with vanilla gradient descent. The convergence guarantee is largely attributed to the linear layer ($W^{V}$) in the attention mechanism.

  • We investigate the attention kernel’s effectiveness, revealing that Gaussian attention achieves zero training loss, while Softmax attention can lead to non-optimal stationary points.

  • Our experiments validate that Softmax attention Transformers converge more slowly and present more challenging training landscapes than their Gaussian counterparts, potentially leading to more suboptimal local solutions.

2 Related Work

A number of research works have focused on the theoretical analysis and interpretation of Transformer models, revealing crucial insights into their practical performance.

Liu et al. (2020) showed that heavy reliance on the residual branch in multi-layer Transformer models can lead to training instability, which amplifies small parameter perturbations, causing significant disturbances in the model’s output. In Bhojanapalli et al. (2020), the authors illustrated the existence of a low-rank bottleneck in Transformer models with sufficiently large embedding and hidden size ($D=d$). However, this work focuses on the representation ability of large-size attention, while falling short of analyzing Transformer models from an optimization perspective. In Noci et al. (2022), the authors explored rank collapse issues in token representations and their impact on training. The authors discussed the origin of the rank collapse phenomenon and proposed depth-dependent scaling of residual branches as a potential solution. They specifically investigated scenarios where the token rank equals one, which can hinder Transformer training. Their findings demonstrate the occurrence of the vanishing gradient issue; however, this work does not comprehensively characterize the vanishing gradient problem throughout the entire training process.

A recent work, Wu et al. (2024), analyzes the convergence behavior of shallow Transformers and shows that global convergence can be achieved with the GD algorithm. However, the focus of our paper is different from Wu et al. (2024). We not only derive a global convergence analysis (our Theorem 2), but also investigate the role of different variables in the optimization.

Some other works focus on improving the optimization of Transformers empirically. (Huang et al., 2020) have proposed an initialization strategy such that no warm-up or layer normalization is needed to train Transformers efficiently; in Shazeer (2020), the GLU variant of token embedding has been shown to outperform plain embedding in the optimization of Transformer models with the Softmax attention kernel. It is worth noting that the above works all primarily focus on empirical investigations into the training of Transformer models, lacking a comprehensive theoretical analysis of the underlying mechanisms.

Some recent research has focused on the convergence analysis of Transformer-based models within the in-context learning (ICL) framework. For instance, Huang et al. (2023); Zhang et al. (2023) explore the learning dynamics of a one-layer Transformer with Softmax attention trained via gradient descent to learn linear function classes in-context. However, this line of study primarily addresses the general convergence performance of Transformers within the ICL setting and does not delve into the role of individual variables.

3 Notations and Problem Description

In this section, we define the structure of the Transformer model and describe the training problem. We consider a one-layer attention Transformer model with multiple heads and a dataset with $N$ samples. Each data sample consists of $n$ discrete tokens, each with embedding dimension $D$. We denote the dataset as $\{(X_{i},y_{i})\}_{i=1}^{N}$, where $X_{i}\in\mathbb{R}^{n\times D}$, and $y_{i}\in\mathbb{R}^{n}$ is the label of the $i$-th sample. The output of the Transformer model is the prediction of the label. The Transformer structure is formulated as follows:

\operatorname{Attention}(W^{Q}_{h},W^{K}_{h},W^{V}_{h};X_{i}):=S(W^{Q}_{h},W^{K}_{h};X_{i})X_{i}W_{h}^{V} \quad (1)
{\sf MH}(W^{Q},W^{K},W^{V};X_{i}):=\operatorname{Concat}\left(\text{head}_{1},\ldots,\text{head}_{H}\right)\cdot W^{O},
\text{where }\operatorname{head}_{h}:=\operatorname{Attention}(W^{Q}_{h},W^{K}_{h},W^{V}_{h};X_{i}),\;h=1,\cdots,H. \quad (2)

In the above notation, $W^{Q}_{h},W^{K}_{h}\in\mathbb{R}^{D\times d}$ are the query weight matrix and key weight matrix, respectively; $W_{h}^{V}\in\mathbb{R}^{D\times d}$ is the value weight matrix; these matrices are the main optimization variables throughout the paper. Further, $W^{O}\in\mathbb{R}^{Hd\times 1}$ is a fixed matrix, representing the weight of the output layer; $H$ is the number of attention heads; $S(\cdot)$ is a kernel function of the variables $W^{Q},W^{K}$ and the input $X_{i}$; $\operatorname{Attention}(\cdot)$ is the attention head function; $\operatorname{MH}(\cdot)$ represents the multi-head attention function. For example, with the Softmax attention (Vaswani et al., 2017), $S(\cdot)$ can be written as:

S\left(W_{h}^{Q},W_{h}^{K};X_{i}\right):=\operatorname{Softmax}\left(\frac{X_{i}W_{h}^{Q}\left(X_{i}W_{h}^{K}\right)^{\top}}{\sqrt{d}}\right) \quad (3)

where for a given $n\times n$ matrix $Z$, $\operatorname{Softmax}(Z):=[\operatorname{Softmax}(Z_{1}),\cdots,\operatorname{Softmax}(Z_{n})]$. Throughout, we denote $S(\cdot)_{kj}$ as the element in the $k$-th row and $j$-th column of the matrix $S(\cdot)$. Let $X_{ik\cdot}\in\mathbb{R}^{D}$ denote the embedding of the $k$-th token in data $X_{i}$, which is the $k$-th row of the matrix $X_{i}$. The structure of the Transformer model can be found in Fig. 1, where we denote $S_{ih}:=S\left(W_{h}^{Q},W_{h}^{K};X_{i}\right)$.
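For concreteness, the following NumPy sketch (illustrative only; the sizes and random weights are our own choices, not the paper's experimental setting) implements one Softmax attention head as in Equations (1) and (3), and the multi-head map of Equation (2) with a fixed output weight $W^{O}$.

```python
import numpy as np

def softmax_rows(Z):
    Z = Z - Z.max(axis=-1, keepdims=True)   # row-wise Softmax, as in Eq. (3)
    E = np.exp(Z)
    return E / E.sum(axis=-1, keepdims=True)

def attention_head(X, WQ, WK, WV):
    """One head of Eq. (1): S(WQ, WK; X) X WV with S the Softmax kernel of Eq. (3)."""
    d = WQ.shape[1]
    S = softmax_rows(X @ WQ @ (X @ WK).T / np.sqrt(d))   # n x n attention matrix
    return S @ X @ WV                                    # n x d head output

def multi_head(X, WQs, WKs, WVs, WO):
    """Eq. (2): concatenate H heads and project with the fixed output weight W^O."""
    heads = [attention_head(X, WQs[h], WKs[h], WVs[h]) for h in range(len(WQs))]
    return np.concatenate(heads, axis=1) @ WO            # n x 1 prediction

# toy shapes: n tokens, embedding D, head dimension d, H heads (illustrative only)
n, D, d, H = 4, 8, 3, 2
rng = np.random.default_rng(0)
X = rng.normal(size=(n, D))
WQs, WKs, WVs = (rng.normal(size=(H, D, d)) for _ in range(3))
WO = rng.normal(size=(H * d, 1))
print(multi_head(X, WQs, WKs, WVs, WO).shape)   # (4, 1)
```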

Based on the above Transformer model, we consider minimizing the following empirical $\ell_{2}$ loss function over the entire dataset $\{(X_{i},y_{i})\}_{i=1}^{N}$:

\min\limits_{M}\frac{1}{2}\sum_{i=1}^{N}\|{\sf MH}(M;X_{i})-y_{i}\|^{2}, \quad (4)

where $M:=(W^{Q},W^{K},W^{V})$ is the set of variables that can be optimized.

For notational simplicity, we next define the vectorized version of the Transformer model given in Equation (1) for the entire dataset $\{(X_{i},y_{i})\}_{i=1}^{N}$. Towards this end, let $X\in\mathbb{R}^{Nn\times D}$ denote the column-stacked matrix of the individual data matrices $X_{i}$. Similarly, define the stacked label $y\in\mathbb{R}^{Nn}$. Then we can define:

{\sf MH}(M;X):=\begin{pmatrix}S_{11}X_{1}&\cdots&S_{1H}X_{1}\\ \vdots&\ddots&\vdots\\ S_{N1}X_{N}&\cdots&S_{NH}X_{N}\end{pmatrix}\cdot\operatorname{diag}(W^{V}_{1},\cdots,W^{V}_{H})\cdot W^{O}, \quad (5)

where we write $S_{ih}:=S\left(W_{h}^{Q},W_{h}^{K};X_{i}\right)$, $i=1,2,\cdots,N$, $h=1,2,\cdots,H$ for simplicity.

Figure 1: One head in Transformer architecture with Softmax Attention.

Let $B:=\begin{pmatrix}S_{11}X_{1}&\cdots&S_{1H}X_{1}\\ \vdots&\ddots&\vdots\\ S_{N1}X_{N}&\cdots&S_{NH}X_{N}\end{pmatrix}$, and let $W^{V}:=\operatorname{diag}(W^{V}_{1},\cdots,W^{V}_{H})\in\mathbb{R}^{HD\times Hd}$ denote the block-diagonal matrix that collects the value weight matrices of all attention heads. Using these definitions, we can simplify Equation (5) as

{\sf MH}(M;X)=B\cdot W^{V}\cdot W^{O}.

Thus the empirical loss function given in Equation (4) can be simplified as

\min\limits_{M}\frac{1}{2}\|{\sf MH}(M;X)-y\|^{2}. \quad (6)

In the following sections, we will use the subscript $t$ to denote variables at the $t$-th iteration, e.g., $M_{t}:=\{W^{Q}_{t},W^{K}_{t},W^{V}_{t}\}$. Similarly, we denote $B_{t}$ as the matrix $B$ at the $t$-th iteration.
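As a sanity check on the stacked formulation, the sketch below (a minimal NumPy illustration with toy sizes; the helper names `stacked_B` and `block_diag` are ours) builds the matrix $B$ of Equation (5) block by block and evaluates the loss in Equation (6) as $\frac{1}{2}\|BW^{V}W^{O}-y\|^{2}$.

```python
import numpy as np

def softmax_rows(Z):
    E = np.exp(Z - Z.max(axis=1, keepdims=True))
    return E / E.sum(axis=1, keepdims=True)

def block_diag(mats):
    """Build diag(W^V_1, ..., W^V_H): each block on the diagonal, zeros elsewhere."""
    out = np.zeros((sum(m.shape[0] for m in mats), sum(m.shape[1] for m in mats)))
    r = c = 0
    for m in mats:
        out[r:r + m.shape[0], c:c + m.shape[1]] = m
        r, c = r + m.shape[0], c + m.shape[1]
    return out

def stacked_B(Xs, WQs, WKs):
    """B of Eq. (5): row-blocks over samples i, column-blocks over heads h, block S_ih X_i."""
    d = WQs.shape[-1]
    return np.concatenate(
        [np.concatenate([softmax_rows(X @ WQ @ (X @ WK).T / np.sqrt(d)) @ X
                         for WQ, WK in zip(WQs, WKs)], axis=1) for X in Xs], axis=0)

# toy sizes (illustrative only)
N, n, D, d, H = 2, 3, 5, 4, 2
rng = np.random.default_rng(0)
Xs = rng.normal(size=(N, n, D))
y = rng.normal(size=(N * n, 1))
WQs, WKs, WVs = (rng.normal(size=(H, D, d)) for _ in range(3))
WO = rng.normal(size=(H * d, 1))

B = stacked_B(Xs, WQs, WKs)                    # Nn x HD
WV = block_diag(list(WVs))                     # HD x Hd
loss = 0.5 * np.sum((B @ WV @ WO - y) ** 2)    # Eq. (6)
print(B.shape, WV.shape, loss)
```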

It is important to note that, in the above description and throughout the paper, we model the Transformer training problem using a single-layer Transformer with a regression loss. In practice, Transformer models can exhibit greater complexity (different loss functions, multiple layers, etc.). For example, the text classification task appends a mean pooling layer to the output of the Transformer structure. Further, practical models usually contain downstream MLP modules. However, we choose to use the simplified version for the following reasons:

First, the primary objective of this work is to understand how different attention kernels affect the training dynamics of Transformers, so we do not include layer normalization in our model. In fact, many works in the literature that analyze popular network structures also do not consider layer normalization. For example, (Huang et al., 2023; Zhang et al., 2023) both analyze the convergence performance of Transformers without considering normalization.

Second, we do not include the downstream MLP module in our work since we are interested in the role of the self-attention layer in the convergence analysis, and the single-attention model is also the standard model used in (Huang et al., 2023; Zhang et al., 2023). Further, the analysis of MLPs is standard in the literature (Allen-Zhu et al., 2019; Du et al., 2019; Nguyen and Mondelli, 2020). It is also worth noting that our choice to focus on a one-layer Transformer is consistent with other works that similarly aim to investigate the core training dynamics of Transformers; e.g., in (Tian et al., 2023), a single-layer Transformer is considered as the basic model.

4 Convergence Analysis

In this section, we present our theoretical analysis for solving problem (6). We focus on the behavior of the vanilla GD algorithm for optimizing the variable set $M$, where $M\subset\{W^{Q},W^{K},W^{V}\}$. Below we summarize our results.

Common convergence conditions for Softmax and Gaussian attention: When the activation function $S(\cdot)$ is either the Softmax or the Gaussian function, and the embedding dimension $D$ is at least $\mathcal{O}(Nn)$, optimizing Equation (6) can achieve a global optimal solution when $M=\{W^{V}\}$ and when $M=\{W^{Q},W^{K},W^{V}\}$.

Different behavior between Softmax and Gaussian kernel attention: When $S(\cdot)$ is Gaussian and the embedding dimension $D$ is at least $\mathcal{O}(Nn)$, convergence to the global optimum is also ensured for $M=\{W^{Q}\}$. Interestingly, under the same condition of large $D$, convergence to the global optimum is not guaranteed when $S(\cdot)$ is Softmax.

In the subsequent sections, we will elaborate on these convergence results in detail, providing a deeper understanding of the nuances in Transformer behavior under varying configurations. To set up our analysis, we introduce $\underline{\lambda}^{V}$ as the smallest eigenvalue of $W^{V}_{0}$, $\underline{\lambda}^{B}$ as the smallest singular value of $B_{0}$, and $\bar{\lambda}_{h}^{Q},\bar{\lambda}_{h}^{K},\bar{\lambda}^{V}$ as the largest singular values of the matrices $W_{h,0}^{Q},W_{h,0}^{K},W^{V}$, respectively. We denote $\|\cdot\|_{2}$ as the $\ell_{2}$ norm and $\|\cdot\|_{F}$ as the Frobenius norm. Further, we denote $\sigma_{\max}(\cdot)$ and $\sigma_{\min}(\cdot)$ as the largest and smallest singular values of a matrix, respectively. For any vector $v$, let $\min(|v|)$ denote the smallest absolute value among the entries of $v$.

4.1 Convergence to global optimal

First, we examine the role of $W^{V}$ in the optimization of the multi-head attention network structure. Our analysis demonstrates that with hidden dimension $HD\geq Nn$ and proper initialization, the global optimal solution of (6) can be found using the vanilla gradient descent algorithm. The initialization requires that the matrix $B_{0}$ has full rank. Our first result shows that an overparameterized Transformer can be trained to a global optimal solution.

Theorem 1.

Consider problem (4) with $S(\cdot)$ instantiated as the Softmax kernel given in (3). Consider the following update for the variable $M=\{W^{V}\}$: $W^{V}_{t+1}=W^{V}_{t}-\eta\nabla_{W^{V}}f(M_{t};X)$, where $\eta>0$ is the stepsize.

Suppose $W^{Q}_{0}$ and $W^{K}_{0}$ are initialized such that $\underline{\lambda}^{B}>0$. Then we have:

f\left(M_{t};X\right)\leq\left(1-\eta\alpha\right)^{t}f\left(M_{0};X\right), \quad (7)

where $\alpha:=\|W^{O}\|^{2}(\underline{\lambda}^{B})^{2}>0$; $\eta>0$ is defined in Appendix 1.3 and chosen such that $\eta\alpha<1$.

Remark 1.

The aforementioned theorem focuses on the convergence behavior when only $W^{V}$ is being updated. We further elaborate on the initial conditions ensuring $\underline{\lambda}^{B}>0$.

Note that $\underline{\lambda}^{B}>0$ implies that the objective function $f$ exhibits a landscape that is nearly convex, which is crucial for optimization. By definition, this condition implies that $B_{0}$ has full rank, which can be fulfilled by selecting appropriate $W^{Q}_{0}$ and $W^{K}_{0}$, plus having a large enough embedding size satisfying $D\geq Nn/H$. We refer the readers to Appendix 1.3 for the derivation of this condition, which can be guaranteed by random initialization with high probability.

Furthermore, it is important to note that our work aligns with the existing literature on the subject of embedding size in Transformer models. For example, in (Bhojanapalli et al., 2020), the authors restrict their focus to the simplified case of $N=1,H=1$. They establish the necessary condition for Softmax attention to overcome its low-rank bottleneck, which requires $D\geq n$. In our analysis, we derive a similar necessary condition on the Transformer model size ($D\geq n\times(N/H)$) to guarantee global convergence when a Transformer model is trained with GD.
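To make the initialization condition concrete, the following NumPy sketch (single head, toy sizes of our own choosing; an illustration, not a proof) checks that a random initialization with $D\geq Nn/H$ yields $\sigma_{\min}(B_{0})>0$, and then runs the $W^{V}$-only GD update of Theorem 1, whose loss decreases geometrically in line with Equation (7).

```python
import numpy as np

rng = np.random.default_rng(0)
N, n, D, d = 2, 3, 8, 4                  # single head (H = 1), so D >= Nn/H = 6
Xs = rng.normal(size=(N, n, D))
y = rng.normal(size=(N * n, 1))
WQ, WK = 0.2 * rng.normal(size=(D, d)), 0.2 * rng.normal(size=(D, d))
WV = 0.1 * rng.normal(size=(D, d))
WO = rng.normal(size=(d, 1))

def softmax_rows(Z):
    E = np.exp(Z - Z.max(axis=1, keepdims=True))
    return E / E.sum(axis=1, keepdims=True)

# B_0 stacks S_i X_i over the samples; lambda_B is its smallest singular value
B = np.concatenate([softmax_rows(X @ WQ @ (X @ WK).T / np.sqrt(d)) @ X for X in Xs], axis=0)
print("sigma_min(B_0) =", np.linalg.svd(B, compute_uv=False).min())   # > 0 as required

# GD on W^V only (Theorem 1); with W^Q, W^K frozen, the loss is quadratic in W^V
eta = 0.5 / (np.linalg.norm(B, 2) ** 2 * np.linalg.norm(WO) ** 2)
for t in range(2001):
    resid = B @ WV @ WO - y
    if t % 400 == 0:
        print(t, 0.5 * float(resid.T @ resid))   # decreases geometrically, cf. Eq. (7)
    WV -= eta * B.T @ resid @ WO.T               # gradient B^T (MH - y) (W^O)^T, Lemma 1(1)
```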

In Theorem 1, we have illustrated that updating $W^{V}$ alone already leads to global convergence. However, in practice, all parameters $W^{V},W^{Q},W^{K}$ are updated. This case is more challenging to analyze due to the non-linearity introduced by the Softmax function. Next, we show that a result similar to Theorem 1 still holds when all the parameters are updated simultaneously.

Theorem 2.

Consider problem (4), with $S(\cdot)$ instantiated as the Softmax kernel. Consider the GD update where $M=\{W^{Q},W^{K},W^{V}\}$. Suppose $\underline{\lambda}^{B}>0$, and the initialization $M_{0}$ satisfies

\frac{n^{2}\sqrt{NH}\|X\|_{F}^{5}\sum\limits_{h=1}^{H}\left((\bar{\lambda}_{h}^{Q})^{2}+(\bar{\lambda}_{h}^{K})^{2}\right)\bar{\lambda}^{V}}{\|W^{O}\|_{2}\cdot(\underline{\lambda}^{B})^{2}\min\left(\bar{\lambda}^{Q}_{h},\bar{\lambda}^{K}_{h},\underline{\lambda}^{B}\right)}\times\|{\sf MH}(M_{0};X)-y\|_{2}\leq\nu. \quad (8)

Then there exists a stepsize $\eta>0$ such that

f\left(M_{t};X\right)\leq\left(1-\eta\beta\right)^{t}f\left(M_{0};X\right), \quad (9)

where $\beta:=\|W^{O}\|^{2}(\underline{\lambda}^{B})^{2}>0$, and the constants $\eta,\nu$ are defined in Appendix 1.3.

Remark 2.

In the stated theorem, we simplify our analysis by excluding the downstream MLP module of the typical Transformer model, since it is easy to combine the model in Equation (2) with downstream MLP layers. Further, it can be directly shown that the Transformer with MLP leads to the same convergence rate for the optimization problem as updating $W^{Q},W^{K},W^{V}$ only. To illustrate this, consider the following Transformer model:

G\left(W^{Q},W^{K},W^{V};X_{i}\right)={\sf MH}(W^{Q},W^{K},W^{V};X_{i})\cdot W^{1}W^{2}\cdots W^{L}, \quad (10)

where $W^{l}\in\mathbb{R}^{n_{l-1}\times n_{l}}$ and $n_{0}=d^{O}$. Based on the Transformer model defined in Equation (10), we have the following corollary.

Corollary 1.

Consider the problem $\min\limits_{M}\frac{1}{2}\|G(M;X)-y\|^{2}$, with $G(\cdot)$ defined in Equation (10) and $S(\cdot)$ instantiated as the Softmax kernel. Suppose that the MLP module satisfies:

n_{1}\geq n_{2}\geq\cdots\geq n_{L}.

Consider the GD update with $M=\{W^{Q},W^{K},W^{V},W^{1},\cdots,W^{L}\}$, and suppose $\underline{\lambda}^{B}>0$. Then there exist a stepsize $\eta>0$ and an initialization $M_{0}$ such that the loss function linearly converges to $0$.

Remark 3.

The above theorem and corollary describe the global convergence guarantee when $W^{Q},W^{K}$, and $W^{V}$ are all updated. This is in line with the insights gained from Theorem 1. However, the conditions on the initialization are more stringent, and the optimization landscape becomes inherently more complex due to the involvement of the Softmax attention through $W^{Q}$ and $W^{K}$.

To ensure the initial condition (8), we have two options: 1) initializing $M_{0}$ such that $\|{\sf MH}(M_{0};X)-y\|_{F}$ is small, which means the optimization starts in a region close to the global optimal solution; 2) balancing $W^{O}$ and $W^{V}$, in the sense that $\|W^{O}\|_{2}$ is large and $\bar{\lambda}^{V}$ is small. For a detailed account of these initialization strategies, please refer to Appendix 1.3.
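Option 2) can be realized by a simple rescaling, as sketched below (the balancing factor $c$ is a hypothetical choice of ours): replacing $(W^{V},W^{O})$ with $(W^{V}/c,\;cW^{O})$ leaves the product $W^{V}W^{O}$, and hence ${\sf MH}(M_{0};X)$, unchanged, while the factor $\bar{\lambda}^{V}/\|W^{O}\|_{2}$ on the left-hand side of Equation (8) shrinks by $c^{2}$.

```python
import numpy as np

rng = np.random.default_rng(1)
D, d, c = 8, 4, 10.0                       # c is a hypothetical balancing factor
WV = rng.normal(size=(D, d))
WO = rng.normal(size=(d, 1))

WV_bal, WO_bal = WV / c, WO * c            # rescale: W^V smaller, W^O larger
print(np.allclose(WV @ WO, WV_bal @ WO_bal))             # model output is unchanged
ratio = lambda V, O: np.linalg.svd(V, compute_uv=False).max() / np.linalg.norm(O)
print(ratio(WV, WO) / ratio(WV_bal, WO_bal))             # the factor in Eq. (8) shrinks by c^2
```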

Finally, we point out that for Transformers with Gaussian kernel attention, we can derive similar convergence results as long as the attention kernel maintains full rank and the weights are initialized appropriately. We do not include the formal statement here since it closely parallels the result for Softmax attention.

4.2 Softmax vs Gaussian kernel: Softmax attention Transformers may exhibit slower convergence.

In the previous section, we explored the global convergence of training Transformer models. However, Theorem 2 does not clarify what roles the matrices $W^{Q}$ and $W^{K}$ play in the convergence process, since Theorem 1 indicates that optimizing $W^{V}$ alone already ensures the desired convergence. Nevertheless, it is the matrices $W^{K}$ and $W^{Q}$ that truly represent the power of a Transformer model, because they are used to extract token correlations.

To study how well a Transformer model can extract token correlations, in this section we study the GD dynamics of Transformer models where only $W^{K}$ and $W^{Q}$ are optimized (while $W^{V}$ is fixed). If optimizing these two parameters alone can still achieve zero training loss, then we say that the input token correlation can be optimally extracted by the Transformer model.

4.2.1 Notations

To begin our study, let us define the Gaussian kernel to be an $n\times n$ matrix whose $k$-th row and $j$-th column entry is given by:

S\left(W_{h}^{Q},W_{h}^{K};X_{i}\right)_{kj}=\exp\left(-\frac{1}{\sqrt{d}}\left(X_{ik\cdot}W_{h}^{Q}-X_{ij\cdot}W_{h}^{K}\right)^{2}\right) \quad (11)

Since the training dynamics/gradients of the variables $W^{Q}$ and $W^{K}$ have the same properties in (3) and (11), we will concentrate on optimizing $W^{Q}$ only.

With some abuse of notation, we define a matrix $C$ for Softmax attention and for Gaussian kernel attention, respectively.

Softmax attention: $C_{ih}:=\frac{X_{i}W_{h}^{Q}\left(X_{i}W_{h}^{K}\right)^{\top}}{\sqrt{d}}\in\mathbb{R}^{n\times n}$.

Gaussian kernel attention: $C_{ih}\in\mathbb{R}^{n\times n}$, with $(C_{ih})_{kj}=-\frac{\left\|X_{ik\cdot}W_{h}^{Q}-X_{ij\cdot}W_{h}^{K}\right\|^{2}}{2\sqrt{d}}$.

For both Softmax attention and Gaussian kernel attention:

C_{i}=\left[C_{i1},C_{i2},\cdots,C_{iH}\right]\in\mathbb{R}^{n\times Hn};\quad C=\left[C_{1}^{\top},C_{2}^{\top},\cdots,C_{N}^{\top}\right]^{\top}\in\mathbb{R}^{Nn\times Hn}.

Using the above notation, the activation function $S(\cdot)$ in (3) and (11) is related to the matrices $C$ in the following manner:

\mbox{Softmax attention}:\;S_{ih}=\operatorname{Softmax}\left(C_{ih}\right);\qquad\mbox{Gaussian attention}:\;\left(S_{ih}\right)_{kj}=\exp\big{(}\left(C_{ih}\right)_{kj}\big{)}.

Additionally, note that $C$ is a function of the variables $M$. Therefore, we will sometimes write $C(M)$ when we need to emphasize the dependency of $C$ on $M$.
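The two kernels can be compared directly in code. The NumPy sketch below (toy shapes, single head; the helper names are ours) computes $S_{ih}$ for the Softmax kernel of Equation (3) and for the Gaussian kernel with the $1/(2\sqrt{d})$ scaling of the matrix $C$ defined above; note that only the Softmax kernel is row-stochastic.

```python
import numpy as np

def softmax_S(X, WQ, WK, d):
    """Softmax kernel: S_ih = Softmax(C_ih), C_ih = X W^Q (X W^K)^T / sqrt(d)."""
    C = X @ WQ @ (X @ WK).T / np.sqrt(d)
    E = np.exp(C - C.max(axis=1, keepdims=True))
    return E / E.sum(axis=1, keepdims=True)

def gaussian_S(X, WQ, WK, d):
    """Gaussian kernel: (S_ih)_kj = exp((C_ih)_kj), (C_ih)_kj = -||X_k W^Q - X_j W^K||^2 / (2 sqrt(d))."""
    Q, K = X @ WQ, X @ WK
    sq = ((Q[:, None, :] - K[None, :, :]) ** 2).sum(axis=-1)   # pairwise squared distances
    return np.exp(-sq / (2 * np.sqrt(d)))

n, D, d = 4, 6, 3
rng = np.random.default_rng(0)
X, WQ, WK = rng.normal(size=(n, D)), rng.normal(size=(D, d)), rng.normal(size=(D, d))
print(softmax_S(X, WQ, WK, d).sum(axis=1))   # each row sums to 1 (row-stochastic)
print(gaussian_S(X, WQ, WK, d))              # entries in (0, 1]; rows need not sum to 1
```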

4.2.2 Main Results

Next, we outline conditions under which GD can still successfully find global optimal solutions for Transformers with Gaussian kernel attention (when only $W^{Q}$ is updated), while under the same set of conditions, but with Softmax kernel attention, GD fails.

Theorem 3.

Solve problem (4) with the following GD update (with $M=\{W^{Q}\}$): $W^{Q}_{t+1}=W^{Q}_{t}-\eta\nabla_{W^{Q}}f(M_{t};X)$. Suppose $\delta_{h}:=\sigma_{\min}\left(\frac{\partial C(M_{0})}{\partial W_{h}^{Q}}\right)>0$ for all $h\in[1,2,\cdots,H]$, and the initialization further satisfies

\frac{n\|X\|_{F}^{5}\big{(}\bar{\lambda}_{h}^{Q}+\bar{\lambda}_{h}^{K}\big{)}\exp\big{(}\frac{9}{4}\|X\|_{F}^{2}\big{(}(\bar{\lambda}_{h}^{Q})^{2}+(\bar{\lambda}_{h}^{K})^{2}\big{)}\big{)}}{\big{(}\min(|V^{\prime}W^{O}|)\big{)}^{2}\cdot\min(\delta_{h},\bar{\lambda}_{h}^{Q})}\times\bar{\lambda}^{V}\|W^{O}\|_{2}\cdot\left\|{\sf MH}\left(M_{0};X\right)-y\right\|_{2}\leq\nu^{\prime}, \quad (12)

where $\nu^{\prime}$ is defined in Appendix 1.5.
(1) When $S(\cdot)$ is the Gaussian kernel function, there exist a stepsize $\eta$ and a positive constant $\gamma$ such that

f\left(M_{t};X\right)\leq\left(1-\eta\gamma\right)^{t}f\left(M_{0};X\right), \quad (13)

where $\gamma,\eta$ are defined in Appendix 1.5.
(2) When $S(\cdot)$ is the Softmax function, suppose $W^{Q}_{t}$ remains bounded during the training phase; then there exists a stepsize $\eta^{\prime}$ such that

f\left(M_{t};X\right)\leq f\left(M_{0};X\right)-\eta^{\prime}\sum\limits_{r=0}^{t-1}\|\nabla_{W^{Q}}f\left(M_{r};X\right)\|^{2}, \quad (14)

where $\eta^{\prime}$ is defined in Appendix 1.5.

Remark 4.

First, it is important to note that the parameter size must satisfy $Dd\geq Nn^{2}$ for $\delta_{h}>0$ to hold. It is crucial to emphasize the fundamental distinction in convergence outcomes between Transformers employing Gaussian kernel attention and those utilizing Softmax attention under these conditions: with equivalent initialization conditions, training Transformers equipped with Gaussian kernel attention achieves global convergence using gradient descent (GD). Second, the dimension requirement $Dd\geq Nn^{2}$ is similar to the findings of works that analyze the convergence of over-parameterized neural networks (Allen-Zhu et al., 2019; Du et al., 2019). The total number of tokens, consisting of $N$ samples each with $n$ tokens, is $Nn$, while the total feature dimension is $Dd$. The inequality implies that the width of the parameters is at least $\mathcal{O}(N)$, a relationship also illustrated in Nguyen and Mondelli (2020).

In part (2), we demonstrate that the PL condition does not hold. In particular, we identify an initial solution that satisfies all the conditions given in Theorem 3, yet fails to satisfy the PL condition. Therefore, in this case, GD leads to vanishing gradients without being able to find a global optimal solution. The details of this specific example are provided below.

Example: Consider a Transformer with Softmax attention, and $N=1,n=2,H=1$. Let us first write down the closed form of the gradient over $W_{1}^{Q}$:

\frac{\partial f\left(M_{0};X_{1}\right)}{\partial W_{1}^{Q}}=\frac{1}{\sqrt{d}}X_{1}^{\top}\frac{\partial f\left(M_{0};X_{1}\right)}{\partial C_{11}}X_{1}W_{1,0}^{K}

Next, we show that there exist $W^{O},W^{V},X_{1},W^{Q}_{1,0},W^{K}_{1,0}$ such that the loss function is non-zero and Equation (12) is satisfied, while

\frac{\partial f\left(M_{0};X_{1}\right)}{\partial C_{11}}=\mathbf{0}\in\mathbb{R}^{2\times 2}.

Denote $L:=\frac{\partial f\left(M_{0};X_{1}\right)}{\partial{\sf MH}(M_{0};X_{1})}\left(W^{O}\right)^{\top}\left(X_{1}W^{V}_{0}\right)^{\top}\in\mathbb{R}^{2\times 2}$. Then $\frac{\partial f\left(M_{0};X_{1}\right)}{\partial C_{11}}$ can be expressed as follows:

\left({\frac{\partial f\left(M_{0};X_{1}\right)}{\partial C_{11}}}\right)_{11}=\delta\cdot(L_{11}-L_{12}),\quad\left({\frac{\partial f\left(M_{0};X_{1}\right)}{\partial C_{11}}}\right)_{12}=\delta\cdot(L_{12}-L_{11}),\quad\text{where }\delta\text{ is some constant.}

Next, we give values of $W^{O},W^{V}_{0}$ that exhibit a case where GD leads to a vanishing gradient. Let $D=d=2$, $W^{O}=(\frac{1}{a},\frac{1}{a})^{\top}$, $X_{1}=\begin{pmatrix}1&0\\ 0&1\end{pmatrix}$, and $W^{V}_{0}=\begin{pmatrix}2a&a\\ a&2a\end{pmatrix}$, where $a$ is a constant. It is easy to show that there exist $W^{Q}_{1}$ and $W^{K}_{1}$ such that Equation (12) holds. Further, it is easy to verify that in this scenario the following holds:

L_{11}=L_{12},\quad L_{21}=L_{22}. \quad (15)

Next, we can easily deduce that $\left(\frac{\partial f(M_{0};X_{1})}{\partial C_{11}}\right)_{11}=\left(\frac{\partial f(M_{0};X_{1})}{\partial C_{11}}\right)_{12}=0$. Similarly, we can demonstrate that $\left(\frac{\partial f(M_{0};X_{1})}{\partial C_{11}}\right)_{21}=\left(\frac{\partial f(M_{0};X_{1})}{\partial C_{11}}\right)_{22}=0$. Consequently, we have $\frac{\partial f(M_{0};X_{1})}{\partial W^{Q}_{1}}=\mathbf{0}$. However, if $y_{1}$ satisfies $\frac{\partial f\left(M_{0};X_{1}\right)}{\partial{\sf MH}\left(M_{0};X_{1}\right)}\neq\mathbf{0}$, then $f(M_{0};X_{1})\neq 0$, which means $M_{0}$ is not a global optimal solution.
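This construction can be checked numerically. The sketch below (NumPy, with $a=2$, an arbitrary $y_{1}$, and random $W^{Q}_{1},W^{K}_{1}$; all concrete choices are ours for illustration) computes the loss and a finite-difference gradient with respect to $W^{Q}_{1}$: the loss stays strictly positive while the gradient is numerically zero, matching the vanishing-gradient behavior described above.

```python
import numpy as np

def softmax(Z):
    Z = Z - Z.max(axis=-1, keepdims=True)
    E = np.exp(Z)
    return E / E.sum(axis=-1, keepdims=True)

def loss(WQ, WK, WV, WO, X, y):
    # single head, single sample: MH = Softmax(X WQ (X WK)^T / sqrt(d)) X WV WO
    d = WQ.shape[1]
    S = softmax(X @ WQ @ (X @ WK).T / np.sqrt(d))
    return 0.5 * np.sum((S @ X @ WV @ WO - y) ** 2)

rng = np.random.default_rng(0)
a = 2.0
X = np.eye(2)                          # X_1 = I
WO = np.array([[1 / a], [1 / a]])      # W^O
WV = np.array([[2 * a, a], [a, 2 * a]])  # W^V_0
WQ, WK = rng.normal(size=(2, 2)), rng.normal(size=(2, 2))
y = np.array([[1.0], [0.0]])           # any y_1 with MH(M_0; X_1) != y_1

# central finite-difference gradient of the loss w.r.t. W^Q
grad, eps = np.zeros_like(WQ), 1e-5
for i in range(2):
    for j in range(2):
        Ep, Em = WQ.copy(), WQ.copy()
        Ep[i, j] += eps
        Em[i, j] -= eps
        grad[i, j] = (loss(Ep, WK, WV, WO, X, y) - loss(Em, WK, WV, WO, X, y)) / (2 * eps)

print("loss      :", loss(WQ, WK, WV, WO, X, y))   # strictly positive
print("|grad|_max:", np.abs(grad).max())           # ~ 0: W^Q receives no gradient
```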

5 Experiment: Softmax vs. Gaussian

Figure 2: Test performance on text classification task with different attention kernels

In this section, we present numerical results to illustrate the behaviors of Transformers models with Softmax attention and Gaussian kernel attention across various tasks.

5.1 Dataset

Figure 3: Test performance on pathfinder task with different attention kernels

We investigate two distinct tasks: Text Classification using the IMDb review dataset (Maas et al., 2011) and Pathfinder (Linsley et al., 2018). While both tasks involve processing long sequences, they exhibit different characteristics. Text Classification is a well-known NLP task that focuses on discerning relationships among token embeddings, while the Pathfinder task prioritizes capturing spatial information within the input pixels.

5.2 Model and Experiment Method

We follow the experiment setting in (Chen et al., 2021). For both tasks, we employ a 2-layer Transformer model with the following specifications: embedding dimension $D=64$, hidden dimension $d=128$, and number of attention heads $H=2$. To align the model with the classification task, we use an additional mean pooling layer as the final layer. We determine the batch size based on available memory constraints. Specifically, we set a batch size of 16 for the Text Classification task with a learning rate of $1\times 10^{-4}$, and a batch size of 128 for the Pathfinder task with a learning rate of $2\times 10^{-4}$. For optimization, we use Stochastic Gradient Descent (SGD) for the Text Classification task and Adam for the Pathfinder task. We conduct two types of experiments.

In the first experiment, we plot the test accuracy and test loss over the training steps with both Softmax and Gaussian kernel attention on both tasks. We repeat the training 10 times and produce shaded plots of the test performance.

Figure 4: The loss landscapes on the text classification task and the Pathfinder task. For both tasks, we use the two-stage training described in Section 5.2 with the same training hyperparameters; the only difference is the attention structure in the second training stage. The two axes represent the two directions $d_{1}$ and $d_{2}$ as defined in Section 5.2.

In our second experiment, the training process consists of two stages. In the first stage, we train the Transformer model equipped with Softmax attention (defined in Equation (3)) for 8,000 steps. In the second stage, we continue training from the pre-trained model for an additional 500 steps, using either the Softmax or the Gaussian kernel. To explore the optimization landscape around the trained model, we employ a technique inspired by Li et al. (2018). We select two parameter directions, specifically the $W^{Q}$ and $W^{K}$ matrices in the first Transformer layer. These two directions, denoted as $d_{1},d_{2}$, are centered at the trained model $M$ and represent the parameter spaces of $W^{Q}$ and $W^{K}$, respectively. We evaluate the loss function on the set $\{M+0.02(r-25)d_{1}+0.02(s-25)d_{2}\}$, where $r,s\in[1,2,\cdots,50]$. This set is a neighborhood of the trained model $M$; we choose the evaluation stepsize as $0.02$ along the two directions $d_{1},d_{2}$, with the total number of steps limited to $100$. Within this parameter space, we plot a 3-D surface representing the landscape around the trained model.
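A minimal sketch of this grid evaluation is given below (the arguments `loss_fn`, `M`, `d1`, `d2` are placeholders for the trained loss function, the trained weights, and the two chosen directions, stored as flat arrays; the plotting itself is omitted).

```python
import numpy as np

def landscape(loss_fn, M, d1, d2, steps=50, step_size=0.02):
    """Evaluate the loss on {M + 0.02 (r - 25) d1 + 0.02 (s - 25) d2}, r, s = 1..50."""
    grid = np.zeros((steps, steps))
    for r in range(1, steps + 1):
        for s in range(1, steps + 1):
            grid[r - 1, s - 1] = loss_fn(M + step_size * (r - 25) * d1
                                           + step_size * (s - 25) * d2)
    return grid   # plot as a 3-D surface, e.g. with matplotlib's plot_surface
```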

5.3 Results

5.3.1 Test Loss & Accuracy Curve comparison

To begin with, we present some observations from our first experiment. We plot the test performance on these two tasks for Transformers with the two different types of attention. From Fig. 2 and Fig. 3, we can conclude that in both tasks, Transformers with Gaussian kernel attention exhibit faster convergence and higher test accuracy than Softmax attention with the same model size and learning rate. In particular, training Transformers with Softmax attention on the Pathfinder task can lead to unstable performance, as indicated in Fig. 3: the test accuracy has significantly higher variance at the same training epoch. Further, the worst test accuracy after $20{,}000$ epochs is around $0.58$ for the Softmax attention Transformer, compared with $0.62$ for the Gaussian kernel Transformer. These observations align with the experimental results in (Chen et al., 2021) and (Tay et al., 2020), where Transformers with different attention kernels are trained with the same model size and learning rate, and Softmax attention Transformers show instability on a few tasks.

5.3.2 Optimization Landscape Comparison

In Figure 4, we present a comparison of the optimization landscape between Transformers with Softmax and Gaussian kernel attention. Notably, we observe distinct differences in the training landscapes of these two attention types for both tasks. We follow the visualization method described in Section 5.2 and visualize the optimization landscape around the trained models after the two-stage training process, with identical learning rates, network sizes, and training epochs. Keeping all other factors consistent, the disparity in the landscape directly reflects the difference that the attention structure makes during the optimization procedure. With Softmax attention, the landscape appears more complicated than with Gaussian kernel attention. This complexity can be interpreted as the presence of a greater number of local optima in the optimization landscape, suggesting that Transformers utilizing Softmax attention may encounter more challenges in reaching global optimal solutions. In contrast, the landscape with the Gaussian kernel is flatter. This observation aligns with our earlier findings in Figure 2 and Figure 3, where Softmax attention exhibited certain convergence issues. These observations also provide empirical evidence supporting Theorem 3, reflecting from a slightly different perspective the complicated optimization landscape induced by the Softmax kernel.

6 Conclusion and Future Work

In conclusion, our study addresses critical gaps in our understanding of why Transformer models perform exceptionally well in a variety of machine learning tasks. Our work also provides a nuanced understanding of the advantages and disadvantages of using classical Softmax attention in Transformers. We find that while shallow Softmax attention Transformers can achieve global convergence with overparameterization, there are scenarios where this attention structure can lead to local solutions. However, these issues can be mitigated by Gaussian kernel-based attention. In our work, we require strong initialization and a large embedding size, i.e., $HD\geq Nn$, to obtain global convergence, which leaves a gap with practical settings. In future work, we will investigate how to relax these assumptions.

7 Acknowledgment

The work of B. Song was partially done while interning at Amazon Web Services. M. Hong holds concurrent appointments as an Amazon Scholar and as a faculty at the University of Minnesota. This paper describes their work performed at Amazon. The work of Jie Ding was supported in part by the Army Research Office Early Career Program Award under grant number W911NF2310315.

References

  • Allen-Zhu et al. (2019) Z. Allen-Zhu, Y. Li, and Z. Song. A convergence theory for deep learning via over-parameterization. In International conference on machine learning, pages 242–252. PMLR, 2019.
  • Beltagy et al. (2020) I. Beltagy, M. E. Peters, and A. Cohan. Longformer: The long-document transformer. arXiv preprint arXiv:2004.05150, 2020.
  • Bhojanapalli et al. (2020) S. Bhojanapalli, C. Yun, A. S. Rawat, S. Reddi, and S. Kumar. Low-rank bottleneck in multi-head attention models. In International conference on machine learning, pages 864–873. PMLR, 2020.
  • Brown et al. (2020) T. Brown, B. Mann, N. Ryder, M. Subbiah, J. D. Kaplan, P. Dhariwal, A. Neelakantan, P. Shyam, G. Sastry, A. Askell, et al. Language models are few-shot learners. Advances in neural information processing systems, 33:1877–1901, 2020.
  • Chen et al. (2021) Y. Chen, Q. Zeng, H. Ji, and Y. Yang. Skyformer: Remodel self-attention with Gaussian kernel and Nyström method. Advances in Neural Information Processing Systems, 34:2122–2135, 2021.
  • Danilova et al. (2022) M. Danilova, P. Dvurechensky, A. Gasnikov, E. Gorbunov, S. Guminov, D. Kamzolov, and I. Shibaev. Recent theoretical advances in non-convex optimization. In High-Dimensional Optimization and Probability: With a View Towards Data Science, pages 79–163. Springer, 2022.
  • Devlin et al. (2018) J. Devlin, M.-W. Chang, K. Lee, and K. Toutanova. Bert: Pre-training of deep bidirectional transformers for language understanding. arXiv preprint arXiv:1810.04805, 2018.
  • Dosovitskiy et al. (2020) A. Dosovitskiy, L. Beyer, A. Kolesnikov, D. Weissenborn, X. Zhai, T. Unterthiner, M. Dehghani, M. Minderer, G. Heigold, S. Gelly, et al. An image is worth 16x16 words: Transformers for image recognition at scale. arXiv preprint arXiv:2010.11929, 2020.
  • Du et al. (2019) S. Du, J. Lee, H. Li, L. Wang, and X. Zhai. Gradient descent finds global minima of deep neural networks. In International conference on machine learning, pages 1675–1685. PMLR, 2019.
  • He et al. (2020) P. He, X. Liu, J. Gao, and W. Chen. Deberta: Decoding-enhanced bert with disentangled attention. arXiv preprint arXiv:2006.03654, 2020.
  • Huang et al. (2020) X. S. Huang, F. Perez, J. Ba, and M. Volkovs. Improving transformer optimization through better initialization. In International Conference on Machine Learning, pages 4475–4483. PMLR, 2020.
  • Huang et al. (2023) Y. Huang, Y. Cheng, and Y. Liang. In-context convergence of transformers. arXiv preprint arXiv:2310.05249, 2023.
  • Jain et al. (2017) P. Jain, P. Kar, et al. Non-convex optimization for machine learning. Foundations and Trends® in Machine Learning, 10(3-4):142–363, 2017.
  • Jin et al. (2021) C. Jin, P. Netrapalli, R. Ge, S. M. Kakade, and M. I. Jordan. On nonconvex optimization for machine learning: Gradients, stochasticity, and saddle points. Journal of the ACM (JACM), 68(2):1–29, 2021.
  • Li et al. (2023) G. Li, G. Wang, and J. Ding. Provable identifiability of two-layer relu neural networks via lasso regularization. IEEE Transactions on Information Theory, 2023.
  • Li et al. (2018) H. Li, Z. Xu, G. Taylor, C. Studer, and T. Goldstein. Visualizing the loss landscape of neural nets. Advances in neural information processing systems, 31, 2018.
  • Linsley et al. (2018) D. Linsley, J. Kim, V. Veerabadran, C. Windolf, and T. Serre. Learning long-range spatial dependencies with horizontal gated recurrent units. Advances in neural information processing systems, 31, 2018.
  • Liu et al. (2020) L. Liu, X. Liu, J. Gao, W. Chen, and J. Han. Understanding the difficulty of training transformers. arXiv preprint arXiv:2004.08249, 2020.
  • Liu et al. (2019) Y. Liu, M. Ott, N. Goyal, J. Du, M. Joshi, D. Chen, O. Levy, M. Lewis, L. Zettlemoyer, and V. Stoyanov. Roberta: A robustly optimized bert pretraining approach. arXiv preprint arXiv:1907.11692, 2019.
  • Maas et al. (2011) A. Maas, R. E. Daly, P. T. Pham, D. Huang, A. Y. Ng, and C. Potts. Learning word vectors for sentiment analysis. In Proceedings of the 49th annual meeting of the association for computational linguistics: Human language technologies, pages 142–150, 2011.
  • Nguyen and Mondelli (2020) Q. N. Nguyen and M. Mondelli. Global convergence of deep networks with one wide layer followed by pyramidal topology. Advances in Neural Information Processing Systems, 33:11961–11972, 2020.
  • Noci et al. (2022) L. Noci, S. Anagnostidis, L. Biggio, A. Orvieto, S. P. Singh, and A. Lucchi. Signal propagation in transformers: Theoretical perspectives and the role of rank collapse. Advances in Neural Information Processing Systems, 35:27198–27211, 2022.
  • Pan and Li (2023) Y. Pan and Y. Li. Toward understanding why adam converges faster than sgd for transformers. arXiv preprint arXiv:2306.00204, 2023.
  • Radford et al. (2019) A. Radford, J. Wu, R. Child, D. Luan, D. Amodei, I. Sutskever, et al. Language models are unsupervised multitask learners. OpenAI blog, 1(8):9, 2019.
  • Shazeer (2020) N. Shazeer. Glu variants improve transformer. arXiv preprint arXiv:2002.05202, 2020.
  • Tay et al. (2020) Y. Tay, M. Dehghani, S. Abnar, Y. Shen, D. Bahri, P. Pham, J. Rao, L. Yang, S. Ruder, and D. Metzler. Long range arena: A benchmark for efficient transformers. arXiv preprint arXiv:2011.04006, 2020.
  • Tian et al. (2023) Y. Tian, Y. Wang, B. Chen, and S. Du. Scan and snap: Understanding training dynamics and token composition in 1-layer transformer. arXiv preprint arXiv:2305.16380, 2023.
  • Vaswani et al. (2017) A. Vaswani, N. Shazeer, N. Parmar, J. Uszkoreit, L. Jones, A. N. Gomez, Ł. Kaiser, and I. Polosukhin. Attention is all you need. Advances in neural information processing systems, 30, 2017.
  • Wu et al. (2024) Y. Wu, F. Liu, G. Chrysos, and V. Cevher. On the convergence of encoder-only shallow transformers. Advances in Neural Information Processing Systems, 36, 2024.
  • Zhang et al. (2023) R. Zhang, S. Frei, and P. L. Bartlett. Trained transformers learn linear models in-context. arXiv preprint arXiv:2306.09927, 2023.

1 Appendix

1.1 Notations

Recall that we have defined the structure of a single Transformer model in Equation (1) and Equation (5). We will further define a few notations before we introduce a few useful lemmas that are needed in our proof.
(1) Operators: Denote $\operatorname{vec}(\cdot)$ as the vectorization operator on a matrix, $\otimes$ as the Kronecker product, and $\odot$ as the element-wise product. Denote $\Upsilon(\cdot)$ as a matrix operator such that for any matrix $X$ without zero elements,

\Upsilon\left(X_{m\times n}\right)=\left[\begin{array}{ccc}1/x_{11}&\cdots&1/x_{1n}\\ \vdots&\ddots&\vdots\\ 1/x_{m1}&\cdots&1/x_{mn}\end{array}\right]_{m\times n} \quad (19)

(2) Matrices: Denote $\mathbb{I}$ as the identity matrix. Define the matrices $\mathbb{E}$ and $E$ as follows:

\mathbb{E}=\left[\begin{array}{lll}E&&\\ &\ddots&\\ &&E\end{array}\right]_{Hn\times Hn},\quad E=\left[\begin{array}{ccc}1&\cdots&1\\ \vdots&\ddots&\vdots\\ 1&\cdots&1\end{array}\right]_{n\times n}.

Define the matrix $\mathbb{P}_{h}$ as follows: $\mathbb{P}_{h}=\left(\ldots,E_{n\times n}^{h},\ldots\right)$, $h=1,\cdots,H$.
(3) Matrices in the Transformer: Define the following matrices $C$ related to the attention layer.

Softmax kernel:
C_{ih}=\frac{X_{i}W_{h}^{Q}\left(X_{i}W_{h}^{K}\right)^{\top}}{\sqrt{d}},\quad S_{ih}=\operatorname{Softmax}(C_{ih}) \quad (20)
Gaussian kernel:
(C_{ih})_{kj}=-\frac{\left\|X_{ik\cdot}W_{h}^{Q}-X_{ij\cdot}W_{h}^{K}\right\|^{2}}{2\sqrt{d}},\quad(S_{ih})_{kj}=\exp\big{(}(C_{ih})_{kj}\big{)} \quad (21)
C_{i}=[C_{i1},\cdots,C_{iH}],\quad S_{i}=[S_{i1},\cdots,S_{iH}] \quad (22)

Define the matrix $V^{\prime}_{i}$ for each data sample $X_{i}$:

V_{i}^{\prime}=\left[\begin{array}{ccc}X_{i}W_{1}^{V}&&\\ &\ddots&\\ &&X_{i}W_{H}^{V}\end{array}\right]_{Hn\times Hd},\quad V^{\prime}=[V_{1}^{\prime\top},\cdots,V_{N}^{\prime\top}]^{\top}. \quad (26)

Next, let us introduce several useful lemmas that lead to Theorem 2.

1.2 Lemmas of Theorem 2

Lemma 1.
(1)\ \frac{\partial f(M;X)}{\partial W^{V}}=B^{\top}\left({\sf MH}(M;X)-y\right)\left(W^{O}\right)^{\top} \quad (27)
(2)\ \operatorname{vec}\left(\frac{\partial f(M;X)}{\partial W^{V}}\right)=\left\langle(W^{O})^{\top}\otimes B,\operatorname{vec}({\sf MH}(M;X)-y)\right\rangle
\quad=\left(\mathbb{I}_{Hd}\otimes B^{\top}\right)\cdot\left(W^{O}\otimes\mathbb{I}_{N}\right)\cdot\left({\sf MH}(M;X)-y\right) \quad (28)
(3)\ \frac{\partial f(M;X)}{\partial W_{h}^{Q}}=\frac{1}{\sqrt{d}}X^{\top}\mathbb{P}_{h}\frac{\partial f(M;X)}{\partial C}XW_{h}^{K}=\sum_{i=1}^{N}\frac{1}{\sqrt{d}}X_{i}^{\top}\mathbb{P}_{h}\frac{\partial f(M;X_{i})}{\partial C_{i}}X_{i}W_{h}^{K} \quad (29)
(4)\ \frac{\partial f(M;X_{i})}{\partial C_{i}}=\left(({\sf MH}(M;X_{i})-y_{i})\left(W^{O}\right)^{\top}\left(V^{\prime}_{i}\right)^{\top}\right)\odot S_{i} \quad (30)
\quad-\left(\left(\left(({\sf MH}(M;X_{i})-y_{i})\left(W^{O}\right)^{\top}\left(V^{\prime}_{i}\right)^{\top}\right)\odot S_{i}\odot\Upsilon\big{(}(\exp C_{i})\mathbb{E}\big{)}\right)\mathbb{E}^{\top}\right)\odot\exp C_{i} \quad (31)
(5)\ \frac{\partial f(M;X)}{\partial C}=\operatorname{diag}\left(\frac{\partial f(M;X_{1})}{\partial C_{1}},\cdots,\frac{\partial f(M;X_{N})}{\partial C_{N}}\right) \quad (32)

Remark: The above lemma derives the closed form of the gradient of the objective with respect to $W^{V}$ and $W^{Q}$. Note that the derivative with respect to $W^{K}$ can be derived in the same way as for $W^{Q}$ due to symmetry, so we do not include that derivation here. Some of the derivations here refer to https://say-hello2y.github.io/2022-09-07/attention-gradient
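As a quick numerical check of part (1), the sketch below (single head, toy sizes of our own choosing) compares the closed-form gradient $B^{\top}({\sf MH}(M;X)-y)(W^{O})^{\top}$ against a central finite-difference approximation; the two agree up to discretization error.

```python
import numpy as np

rng = np.random.default_rng(0)
N, n, D, d = 2, 3, 6, 4
Xs = rng.normal(size=(N, n, D))
y = rng.normal(size=(N * n, 1))
WQ, WK, WV, WO = (rng.normal(size=s) for s in [(D, d), (D, d), (D, d), (d, 1)])

def softmax_rows(Z):
    E = np.exp(Z - Z.max(axis=1, keepdims=True))
    return E / E.sum(axis=1, keepdims=True)

# single-head B (stack of S_i X_i); the loss as a function of W^V only
B = np.concatenate([softmax_rows(X @ WQ @ (X @ WK).T / np.sqrt(d)) @ X for X in Xs], axis=0)
f = lambda W: 0.5 * np.sum((B @ W @ WO - y) ** 2)

closed = B.T @ (B @ WV @ WO - y) @ WO.T          # Lemma 1(1): B^T (MH - y) (W^O)^T
num, eps = np.zeros_like(WV), 1e-6
for i in range(D):
    for j in range(d):
        E1, E2 = WV.copy(), WV.copy()
        E1[i, j] += eps
        E2[i, j] -= eps
        num[i, j] = (f(E1) - f(E2)) / (2 * eps)
print(np.abs(closed - num).max())                # small: the two gradients agree
```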

Lemma 2.

Consider updating $W^{Q},W^{K},W^{V}$ at iteration $t$. Suppose $\sigma_{\max}(W^{Q})$, $\sigma_{\max}(W^{K})$, $\sigma_{\max}(W^{V})$ are bounded during the optimization phase. Then we have the following conclusions:

(1)\;\|d(S_{i})\|_{F}\leq\phi_{i}\|d(W^{Q})\|_{F},\quad\text{where }\phi_{i}=\frac{n}{\sqrt{d}}\left\|X_{i}\right\|_{F}^{2}\sqrt{\sum_{h=1}^{H}\sigma_{\max}^{2}\left(W_{h}^{K}\right)} \quad (33)
(2)\;\|d(S_{i})\|_{F}\leq\psi_{i}\|d(W^{K})\|_{F},\quad\text{where }\psi_{i}=\frac{n}{\sqrt{d}}\left\|X_{i}\right\|_{F}^{2}\sqrt{\sum_{h=1}^{H}\sigma_{\max}^{2}\left(W_{h}^{Q}\right)} \quad (34)
(3)\;\|d(S_{i})\|_{F}\leq\sqrt{\phi_{i}^{2}+\psi_{i}^{2}}\cdot\|d(W^{Q}),d(W^{K})\|_{F} \quad (35)
(4)\;\left\|\frac{\partial f(M;X_{i})}{\partial W^{Q}}\right\|_{F}\leq Q_{i}\|{\sf MH}\left(M;X_{i}\right)-y_{i}\|_{F},\quad\text{where }Q_{i}=n\sqrt{H}\left\|X_{i}\right\|_{F}^{3}\left\|W^{O}\right\|_{2}\sqrt{\sum_{h=1}^{H}\sigma_{\max}^{2}\left(W_{h}^{K}\right)}\cdot\sigma_{\max}\left(W^{V}\right) \quad (36)
(5)\;\left\|\frac{\partial f(M;X_{i})}{\partial W^{K}}\right\|_{F}\leq K_{i}\|{\sf MH}\left(M;X_{i}\right)-y_{i}\|_{F},\quad\text{where }K_{i}=n\sqrt{H}\left\|X_{i}\right\|_{F}^{3}\left\|W^{O}\right\|_{2}\sqrt{\sum_{h=1}^{H}\sigma_{\max}^{2}\left(W_{h}^{Q}\right)}\cdot\sigma_{\max}\left(W^{V}\right) \quad (37)
Lemma 3.

Consider updating $W^{Q},W^{K},W^{V}$ at iteration $t$. Suppose $\sigma_{\max}(W^{Q})$, $\sigma_{\max}(W^{K})$, $\sigma_{\max}(W^{V})$ are bounded during the optimization phase. Then we have the following conclusions:

(1)\;\|{\sf MH}(M_{t+1};X)-{\sf MH}(M_{t};X)\|_{F}\leq Z\|M_{t+1}-M_{t}\|_{F},\text{ where }Z\text{ is some positive constant.} \quad (38)
(2)\;\left\|\nabla f\left(M_{t+1};X\right)-\nabla f\left(M_{t};X\right)\right\|_{2}\leq G\|M_{t+1}-M_{t}\|_{F},\text{ where }G\text{ is some positive constant.} \quad (39)
Lemma 4.

Let $f:\mathbb{R}^{n}\rightarrow\mathbb{R}$ be a second-order differentiable function. Let $x,y\in\mathbb{R}^{n}$ be given, and assume that $\|\nabla f(z)-\nabla f(x)\|_{2}\leq C^{\prime}\|z-x\|_{2}$ for every $z=x+t(y-x)$ with $t\in[0,1]$. Then,

f(y)\leq f(x)+\langle\nabla f(x),y-x\rangle+\frac{C^{\prime}}{2}\|x-y\|^{2}.
Lemma 5.

For matrices $A\in\mathbb{R}^{k\times l},B\in\mathbb{R}^{l\times m},C\in\mathbb{R}^{m\times n}$:

\operatorname{vec}(ABC)=\left(I_{n}\otimes AB\right)\operatorname{vec}(C)=\left(C^{\mathrm{T}}B^{\mathrm{T}}\otimes I_{k}\right)\operatorname{vec}(A)
\operatorname{vec}(AB)=\left(I_{m}\otimes A\right)\operatorname{vec}(B)=\left(B^{\mathrm{T}}\otimes I_{k}\right)\operatorname{vec}(A)
\operatorname{vec}(A\odot B)=\operatorname{vec}(A)\odot\operatorname{vec}(B).

1.3 Proof of Theorem 2

Proof Sketch of Theorem 2:
The main idea of the proof follows [Nguyen and Mondelli, 2020]. Let us first recall a few notations: $\bar{\lambda}^{V}:=\frac{2}{3}\big{(}1+\sigma_{\max}(W^{V}_{0})\big{)}$, $\underline{\lambda}^{B}:=\sigma_{\min}(B_{0})$. Using the GD update rule, we aim to iteratively show

\left\{\begin{array}{l}\sigma_{\max}(W_{r}^{V})\leq\frac{3}{2}\bar{\lambda}^{V},\;r\in\{0,\ldots,t\},\\ \sigma_{\max}(W^{Q}_{r})\leq\frac{3}{2}\bar{\lambda}^{Q},\;r\in\{0,\ldots,t\},\\ \sigma_{\max}(W^{K}_{r})\leq\frac{3}{2}\bar{\lambda}^{K},\;r\in\{0,\ldots,t\},\\ \sigma_{\min}\left(B_{r}\right)\geq\frac{1}{2}\underline{\lambda}^{B},\;r\in\{0,\ldots,t\},\\ f\left(M_{r};X\right)\leq(1-\eta\mu)^{r}f\left(M_{0},X\right),\;r\in\{0,\ldots,t\}.\end{array}\right. \quad (45)

Denote $\mu:=\frac{1}{4}(\underline{\lambda}^{B})^{2}\|W^{O}\|_{2}^{2}$. Let us discuss the value of $\mu$. We know $W^{O}\in\mathbb{R}^{Hd\times 1}$ and $B_{0}^{\top}\in\mathbb{R}^{HD\times Nn}$. We require $\mu>0$, i.e., $\underline{\lambda}^{B}>0$, which implies that $B$ has full row rank. For simplicity, let us consider the $H=1$ case. Recall the definition of $B$:

B:=\left(\begin{array}{c}S_{11}X_{1}\\ \vdots\\ S_{N1}X_{N}\end{array}\right)

Suppose we initialize $W^{Q}_{1},W^{K}_{1}$ such that each $S_{i1}\in\mathbb{R}^{n\times n}$ is full rank; then we can easily show that $\operatorname{rank}(S_{i1}X_{i})=n$ if $X_{i}$ has full row rank. If the embedding dimension $D$ is large, then under certain assumptions on $X$ we can show that $B$ has full row rank. For example, if each $X_{i}$ follows the standard Gaussian distribution with $D\gg N$, then $\operatorname{rank}(B)=Nn$ with probability $1$ if we initialize $W_{1}^{Q},W_{1}^{K}$ such that $S_{i1}$ is full rank.

Further, let us assume that $\sum\limits_{h=1}^{H}(\bar{\lambda}_{h}^{Q})^{2}>1$, $\sum\limits_{h=1}^{H}(\bar{\lambda}_{h}^{K})^{2}>1$, and the initialization condition satisfies:

\frac{54n^{2}\sqrt{NH}\|X\|_{F}^{6}\bar{\lambda}^{V}\big{(}\sum\limits_{h=1}^{H}(\bar{\lambda}^{Q}_{h})^{2}+(\bar{\lambda}^{K}_{h})^{2}\big{)}}{(\underline{\lambda}^{B})^{2}\|W^{O}\|_{2}\min\big{(}\bar{\lambda}_{h}^{Q},\bar{\lambda}_{h}^{K},1,\underline{\lambda}^{B}\big{)}}\leq 1 \quad (46)
Remark 5.

The initialization condition can be satisfied if $\|W^{O}\|_{2}$ is large and $\sigma_{\max}(W^{V})$ is small. The constant $\nu$ in Equation (8) equals $\frac{1}{54}$.

It is clear that Equation (45) holds when t=0. Suppose it holds at iteration t; we now prove that it holds at iteration t+1.

\displaystyle\left\|W_{r+1}^{V}-W_{0}^{V}\right\|_{F}\stackrel{{\scriptstyle(i)}}{{\leq}}\sum_{s=0}^{r}\left\|W_{s+1}^{V}-W_{s}^{V}\right\|_{F}=\eta\sum_{s=0}^{r}\left\|\nabla_{W^{V}}f\left(M_{s};X\right)\right\|_{F}
\stackrel{{\scriptstyle(ii)}}{{\leq}}\eta\sum_{s=0}^{r}\|B_{s}\|_{F}\|W^{O}\|_{2}\left\|{\sf MH}(M_{s};X)-y\right\|_{2}\stackrel{{\scriptstyle(iii)}}{{\leq}}\eta\max_{0\leq s\leq r}\|B_{s}\|_{F}\|W^{O}\|_{2}\sum_{s=0}^{r}\left(1-\eta\mu\right)^{s/2}\left\|{\sf MH}(M_{0};X)-y\right\|_{2},

where (i) uses the triangle inequality; (ii) plugs in the expression of \nabla_{W^{V}}f\left(M_{s};X\right) and uses the Cauchy-Schwarz inequality; (iii) uses the induction hypothesis that the loss function f(\cdot) decreases linearly up to the t-th iteration. Let u=\sqrt{1-\eta\mu}. So we have

\displaystyle\eta\max_{0\leq s\leq r}\left\|B_{s}\right\|_{F}\left\|W^{O}\right\|_{2}\sum_{s=0}^{r}(1-\eta\mu)^{s/2}\left\|{\sf MH}\left(M_{0};X\right)-y\right\|_{2}
\leq\frac{1}{\mu}\max_{0\leq s\leq r}\|B_{s}\|_{F}\|W^{O}\|_{2}\frac{1-u^{r+1}}{1-u}(1-u^{2})\left\|{\sf MH}(M_{0};X)-y\right\|_{2}
=\frac{1}{\mu}\max_{0\leq s\leq r}\|\left[S_{s,1}X_{1},\cdots,S_{s,N}X_{N}\right]\|_{F}\|W^{O}\|_{2}\frac{1-u^{r+1}}{1-u}(1-u^{2})\left\|{\sf MH}(M_{0};X)-y\right\|_{2} (47)
\stackrel{{\scriptstyle(i)}}{{\leq}}\frac{2n\sqrt{HN}}{\mu}\|X\|_{F}\|W^{O}\|_{2}\|{\sf MH}(M_{0};X)-y\|_{F}\stackrel{{\scriptstyle(ii)}}{{\leq}}1, (48)

where (i) holds because each element of S_{s,i} has magnitude at most 1, so \|S_{s,i}\|_{F}\leq n\sqrt{H}; by the Cauchy-Schwarz inequality this gives \|B_{s}\|_{F}\leq n\sqrt{HN}\|X\|_{F}, and moreover (1-u^{r+1})(1+u)\leq 2; (ii) is due to the initialization condition. Then, by Weyl's inequality,

σmax(Wr+1V)σmax(W0V)+1=32λ¯V.\sigma_{\max}\left(W_{r+1}^{V}\right)\leq\sigma_{\max}(W_{0}^{V})+1=\frac{3}{2}\bar{\lambda}^{V}.
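The geometric-series step used in (47)-(48), namely \eta\sum_{s=0}^{r}(1-\eta\mu)^{s/2}=\eta\frac{1-u^{r+1}}{1-u}\leq\frac{2}{\mu}, can be checked directly; the values of \eta, \mu, and r below are arbitrary illustrative choices with \eta\mu<1.

import numpy as np

eta, mu, r = 0.05, 3.0, 200
u = np.sqrt(1 - eta * mu)
partial_sum = eta * sum(u ** s for s in range(r + 1))
closed_form = eta * (1 - u ** (r + 1)) / (1 - u)
assert np.isclose(partial_sum, closed_form)
assert partial_sum <= 2.0 / mu               # eta / (1 - u) = eta (1 + u) / (eta mu) <= 2 / mu
print(partial_sum, 2.0 / mu)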

Similarly, let us derive the upper bound for σmax(Wh,rQ)\sigma_{\max}(W^{Q}_{h,r}).

\displaystyle\left\|W_{h,r+1}^{Q}-W_{h,0}^{Q}\right\|_{F}\stackrel{{\scriptstyle(i)}}{{\leq}}\sum_{s=0}^{r}\left\|W_{h,s+1}^{Q}-W_{h,s}^{Q}\right\|_{F}=\eta\sum_{s=0}^{r}\left\|\nabla_{W^{Q}_{h}}f\left(M_{s};X\right)\right\|_{F}
\stackrel{{\scriptstyle(ii)}}{{\leq}}\eta\sum_{s=0}^{r}\sqrt{\sum\limits_{i=1}^{N}Q_{i}^{2}}\left\|{\sf MH}\left(M_{s};X\right)-y\right\|_{2}\stackrel{{\scriptstyle(iii)}}{{\leq}}\eta\sqrt{\sum\limits_{i=1}^{N}Q_{i}^{2}}\sum_{s=0}^{r}(1-\eta\mu)^{s/2}\left\|{\sf MH}\left(M_{0};X\right)-y\right\|_{2}
i=1NQi2μ1ur+11u(1u2)𝖬𝖧(M0;X)y2\displaystyle\leq\frac{\sqrt{\sum\limits_{i=1}^{N}Q_{i}^{2}}}{\mu}\frac{1-u^{r+1}}{1-u}\left(1-u^{2}\right)\left\|{\sf MH}\left(M_{0};X\right)-y\right\|_{2}
2i=1NQi2μ𝖬𝖧(M0;X)y2(iv)12λ¯hQ,\displaystyle\leq\frac{2\sqrt{\sum\limits_{i=1}^{N}Q_{i}^{2}}}{\mu}\left\|{\sf MH}\left(M_{0};X\right)-y\right\|_{2}\stackrel{{\scriptstyle(iv)}}{{\leq}}\frac{1}{2}\bar{\lambda}_{h}^{Q},

where (i) uses the triangle inequality; (ii) uses Lemma 2 (4); (iii) comes from the induction hypothesis that the loss function f(\cdot) decreases linearly up to the t-th iteration; (iv) is due to the initialization condition in Equation (46). Similarly, we can show

ηi=1NKi2s=0r(1ημ)s/2𝖬𝖧(M0;X)y2\displaystyle\eta\sqrt{\sum\limits_{i=1}^{N}K_{i}^{2}}\sum_{s=0}^{r}(1-\eta\mu)^{s/2}\left\|{\sf MH}\left(M_{0};X\right)-y\right\|_{2}
2i=1NKi2μ𝖬𝖧(M0;X)y212λ¯hK.\displaystyle\leq\frac{2\sqrt{\sum\limits_{i=1}^{N}K_{i}^{2}}}{\mu}\left\|{\sf MH}\left(M_{0};X\right)-y\right\|_{2}\leq\frac{1}{2}\bar{\lambda}_{h}^{K}. (49)

Then, by Weyl's inequality, we have

σmax(Wh,t+1Q)σmax(Wh,0Q)+12λ¯hQ=32λ¯hQ;σmax(Wh,t+1K)σmax(Wh,0K)+12λ¯hK=32λ¯hK.\sigma_{\max}\left(W_{h,t+1}^{Q}\right)\leq\sigma_{\max}(W_{h,0}^{Q})+\frac{1}{2}\bar{\lambda}_{h}^{Q}=\frac{3}{2}\bar{\lambda}_{h}^{Q};\;\sigma_{\max}\left(W_{h,t+1}^{K}\right)\leq\sigma_{\max}(W_{h,0}^{K})+\frac{1}{2}\bar{\lambda}_{h}^{K}=\frac{3}{2}\bar{\lambda}_{h}^{K}.

Now we aim to bound the smallest singular value of B_{r+1}.

\begin{aligned} &\left\|B_{r+1}-B_{0}\right\|_{F}\leq\sum_{s=0}^{r}\left\|B_{s+1}-B_{s}\right\|_{F}=\sum_{i=1}^{N}\sum_{s=0}^{r}\|S_{i,s+1}X_{i}-S_{i,s}X_{i}\|_{F}\\ &\stackrel{{\scriptstyle(i)}}{{\leq}}\sum_{i=1}^{N}\sum_{s=0}^{r}\|X_{i}\|_{F}\|S_{i,s+1}-S_{i,s}\|_{F}\stackrel{{\scriptstyle(ii)}}{{\leq}}\sum_{i=1}^{N}\sum_{s=0}^{r}\|X_{i}\|_{F}\sum_{h=1}^{H}\|S_{ih,s+1}-S_{ih,s}\|_{F}\\ &\stackrel{{\scriptstyle(iii)}}{{\leq}}\eta\sum_{i=1}^{N}\sum_{s=0}^{r}\|X_{i}\|_{F}\cdot\sqrt{\phi_{i}^{2}+\psi_{i}^{2}}\cdot\|\left(\nabla_{W^{Q}}f(M_{s};X_{i}),\nabla_{W^{K}}f(M_{s};X_{i})\right)\|_{F}\\ &\stackrel{{\scriptstyle(iv)}}{{\leq}}\eta\sum_{i=1}^{N}\sum_{s=0}^{r}\|X_{i}\|_{F}\cdot\sqrt{\phi_{i}^{2}+\psi_{i}^{2}}\cdot\sqrt{Q_{i}^{2}+K_{i}^{2}}\cdot\|{\sf MH}\left(M_{s};X_{i}\right)-y_{i}\|_{F}\\ &\stackrel{{\scriptstyle(v)}}{{\leq}}\eta\sum_{s=0}^{r}\sqrt{\sum_{i=1}^{N}\|X_{i}\|_{F}^{2}(\phi_{i}^{2}+\psi_{i}^{2})(Q_{i}^{2}+K_{i}^{2})}(1-\eta\mu)^{s/2}\left\|{\sf MH}\left(M_{0};X\right)-y\right\|_{2}\end{aligned},

where (i) and (ii) use the triangle and Cauchy-Schwarz inequalities; (iii) comes from Lemma 2 (5); (iv) uses Lemma 2 and the Cauchy-Schwarz inequality; (v) comes from the Cauchy-Schwarz inequality. Together with our initialization condition, we have

Br+1B0F1μi=1NXiF2(ϕi2+ψi2)(Qi2+Ki2)1ur+11u𝖬𝖧(M0;X)y2\displaystyle\left\|B_{r+1}-B_{0}\right\|_{F}\leq\frac{1}{\mu}\sqrt{\sum_{i=1}^{N}\|X_{i}\|_{F}^{2}(\phi_{i}^{2}+\psi_{i}^{2})(Q_{i}^{2}+K_{i}^{2})}\cdot\frac{1-u^{r+1}}{1-u}\left\|{\sf MH}\left(M_{0};X\right)-y\right\|_{2}
2μi=1NXiF2(ϕi2+ψi2)(Qi2+Ki2)𝖬𝖧(M0;X)y2(i)12λ¯B,\displaystyle\leq\frac{2}{\mu}\sqrt{\sum_{i=1}^{N}\|X_{i}\|_{F}^{2}(\phi_{i}^{2}+\psi_{i}^{2})(Q_{i}^{2}+K_{i}^{2})}\left\|{\sf MH}\left(M_{0};X\right)-y\right\|_{2}\stackrel{{\scriptstyle(i)}}{{\leq}}\frac{1}{2}\underline{\lambda}^{B},

where (i) comes from the initialization condition (46). By Weyl's inequality, we can derive the bound for the smallest singular value of B_{r+1}:

σmin(Br+1)σmin(B0)Br+1B0F12λ¯B.\displaystyle\sigma_{\min}(B_{r+1})\geq\sigma_{\min}(B_{0})-\left\|B_{r+1}-B_{0}\right\|_{F}\geq\frac{1}{2}\underline{\lambda}^{B}.
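The Weyl-inequality step above, which says that a perturbation of B_{0} with Frobenius norm at most \frac{1}{2}\underline{\lambda}^{B} decreases the smallest singular value by at most that amount, can be illustrated as follows; the matrix sizes and random perturbation are hypothetical.

import numpy as np

rng = np.random.default_rng(3)
B0 = rng.normal(size=(12, 64))                          # stand-in for B_0 with full row rank
sigma_min_0 = np.linalg.svd(B0, compute_uv=False)[-1]
Delta = rng.normal(size=B0.shape)
Delta *= (0.5 * sigma_min_0) / np.linalg.norm(Delta)    # scale so that ||Delta||_F = sigma_min(B_0) / 2
sigma_min_1 = np.linalg.svd(B0 + Delta, compute_uv=False)[-1]
# Weyl: sigma_min(B_0 + Delta) >= sigma_min(B_0) - ||Delta||_2 >= sigma_min(B_0) - ||Delta||_F
assert sigma_min_1 >= 0.5 * sigma_min_0 - 1e-12
print(sigma_min_0, sigma_min_1)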

The final step is to show that the last inequality in (45) holds. Since we have already shown that \sigma_{\max}(W^{Q}_{h}),\sigma_{\max}(W^{K}_{h}),\sigma_{\max}(W^{V}_{h}) are bounded, by Lemma 3 (2) we can conclude that:

\displaystyle\left\|\nabla f\left(M_{t+1};X\right)-\nabla f\left(M_{t};X\right)\right\|_{2}\leq G\|M_{t+1}-M_{t}\|_{F}

Thus, by Lemma 4, choosing \eta<\frac{1}{2G}, the following holds:

f(Mt+1;X)\displaystyle f\left(M_{t+1};X\right) =f(Mtηf(Mt;X);X)\displaystyle=f\left(M_{t}-\eta\nabla f(M_{t};X);X\right)
(i)f(Mt;X)ηf(Mt;X)2+G2η2f(Mt;X)2\displaystyle\stackrel{{\scriptstyle(i)}}{{\leq}}f\left(M_{t};X\right)-\eta\left\|\nabla f\left(M_{t};X\right)\right\|^{2}+\frac{G}{2}\eta^{2}\left\|\nabla f\left(M_{t};X\right)\right\|^{2}
(ii)f(Mt;X)12ηf(Mt;X)2\displaystyle\stackrel{{\scriptstyle(ii)}}{{\leq}}f\left(M_{t};X\right)-\frac{1}{2}\eta\left\|\nabla f\left(M_{t};X\right)\right\|^{2}
(iii)f(Mt;X)12ηf(Mt;X)WV2\displaystyle\stackrel{{\scriptstyle(iii)}}{{\leq}}f\left(M_{t};X\right)-\frac{1}{2}\eta\left\|\frac{\partial f\left(M_{t};X\right)}{\partial W^{V}}\right\|^{2}
(iv)f(Mt;X)12ηWOBt(vec(𝖬𝖧(Mt;X)y))2\displaystyle\stackrel{{\scriptstyle(iv)}}{{\leq}}f\left(M_{t};X\right)-\frac{1}{2}\eta\|W^{O}\otimes B_{t}^{\top}\left(\operatorname{vec}({\sf MH}(M_{t};X)-y)\right)\|^{2}
\displaystyle\stackrel{{\scriptstyle(v)}}{{\leq}}f\left(M_{t};X\right)-\frac{1}{4}\eta\|W^{O}\|_{2}^{2}(\underline{\lambda}^{B})^{2}\cdot f(M_{t};X)
\displaystyle=(1-\frac{1}{4}\eta\|W^{O}\|_{2}^{2}(\underline{\lambda}^{B})^{2})\cdot f(M_{t};X)
=(vi)(1ημ)f(Mt;X),\displaystyle\stackrel{{\scriptstyle(vi)}}{{=}}(1-\eta\mu)f(M_{t};X),

where (i) uses Lemma 4; (ii) is because we set \eta<\frac{1}{2G}; (iii) keeps only the gradient with respect to W^{V}; (iv) plugs in the closed-form gradient from Lemma 1; (v) uses the lower bound on the smallest singular value of B_{t} from the induction assumption; (vi) comes from the definition of \mu.
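The contraction argument above combines smoothness (Lemma 4) with a lower bound on the gradient norm in terms of the loss. The same mechanism yields linear convergence of GD on an under-determined least-squares problem, which the sketch below uses as a simplified stand-in; it is only an analogy and does not implement the Transformer objective.

import numpy as np

rng = np.random.default_rng(4)
m, p = 10, 30                                # under-determined least squares, a crude analogue of overparameterization
A, y = rng.normal(size=(m, p)), rng.normal(size=m)
f = lambda w: 0.5 * np.linalg.norm(A @ w - y) ** 2
grad = lambda w: A.T @ (A @ w - y)

svals = np.linalg.svd(A, compute_uv=False)
G = svals[0] ** 2                            # smoothness constant of f
mu = svals[-1] ** 2                          # here ||grad f(w)||^2 >= 2 mu f(w) since A has full row rank
eta = 1.0 / (2.0 * G)

w = rng.normal(size=p)
for _ in range(50):
    f_prev = f(w)
    w = w - eta * grad(w)
    assert f(w) <= (1 - eta * mu) * f_prev + 1e-12   # per-step geometric decrease, as in (45)
print(f(w))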

1.4 Lemma for Theorem 3

The following lemmas all consider Transformers with the Gaussian attention kernel defined in Equation (11).

Lemma 6.
(1)f(M;X)WV=B(𝖬𝖧(M;X)y)(WO)\displaystyle(1)\leavevmode\nobreak\ \frac{\partial f(M;X)}{\partial W^{V}}=B^{\top}\left({\sf MH}(M;X\right)-y)\left(W^{O}\right)^{\top} (50)
\displaystyle(2)\leavevmode\nobreak\ \operatorname{vec}\left(\frac{\partial f(M;X)}{\partial W^{V}}\right)=\left(W^{O}\otimes B^{\top}\right)\operatorname{vec}({\sf MH}(M;X)-y)
\displaystyle\quad=\left(\mathbb{I}_{Hd}\otimes B^{\top}\right)\cdot\left(W^{O}\otimes\mathbb{I}_{Nn}\right)\cdot\left({\sf MH}(M;X)-y\right) (51)
(3)f(M;X)WhQ=f(M;X)CCWhQ=i=1Nf(M;Xi)CiCiWhQ\displaystyle(3)\leavevmode\nobreak\ \frac{\partial f(M;X)}{\partial W_{h}^{Q}}=\frac{\partial f(M;X)}{\partial C}\cdot\frac{\partial C}{\partial W_{h}^{Q}}=\sum_{i=1}^{N}\frac{\partial f(M;X_{i})}{\partial C_{i}}\cdot\frac{\partial C_{i}}{\partial W_{h}^{Q}} (52)
(4)f(M;Xi)Ci=((𝖬𝖧(M;Xi)yi)(WO)(Vi))Si\displaystyle(4)\leavevmode\nobreak\ \frac{\partial f(M;X_{i})}{\partial C_{i}}=\left(({\sf MH}(M;X_{i})-y_{i})\left(W^{O}\right)^{\top}\left(V^{\prime}_{i}\right)^{\top}\right)\odot S_{i} (53)
(5)f(M;X)C=[f(M;X1)C1,,f(M;XN)CN]\displaystyle(5)\leavevmode\nobreak\ \frac{\partial f(M;X)}{\partial C}=\left[\frac{\partial f(M;X_{1})}{\partial C_{1}},\cdots,\frac{\partial f(M;X_{N})}{\partial C_{N}}\right]^{\top} (54)
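The gradient formula in Lemma 6 (1) can be checked by finite differences. The sketch below assumes, purely for illustration, a single-head model of the form {\sf MH}(M;X_{i})=S_{i}X_{i}W^{V}W^{O} with the Gaussian attention kernel and the squared loss f=\frac{1}{2}\|{\sf MH}-y\|^{2}; this simplified readout is an assumption of the sketch, meant only to mirror the shape of the formula rather than reproduce the exact multi-head architecture.

import numpy as np

rng = np.random.default_rng(8)
n, D, d = 4, 6, 3
Xi = rng.normal(size=(n, D))
yi = rng.normal(size=(n, 1))
WQ, WK = rng.normal(size=(D, d)), rng.normal(size=(D, d))
WV, WO = rng.normal(size=(D, d)), rng.normal(size=(d, 1))

def gaussian_S():
    Q, K = Xi @ WQ, Xi @ WK
    return np.exp(-((Q[:, None, :] - K[None, :, :]) ** 2).sum(-1) / (2 * np.sqrt(d)))

def loss(WVm):
    B = gaussian_S() @ Xi                    # B = S_i X_i for this single head
    return 0.5 * np.linalg.norm(B @ WVm @ WO - yi) ** 2

B = gaussian_S() @ Xi
analytic = B.T @ (B @ WV @ WO - yi) @ WO.T   # B^T (MH - y) (W^O)^T, as in Lemma 6 (1)

eps, numeric = 1e-6, np.zeros_like(WV)
for a in range(D):
    for b in range(d):
        E = np.zeros_like(WV); E[a, b] = eps
        numeric[a, b] = (loss(WV + E) - loss(WV - E)) / (2 * eps)
print(np.max(np.abs(analytic - numeric)))    # should be of order 1e-7 or smaller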
Lemma 7.

Consider updating W^{Q},W^{K},W^{V} at iteration t. Suppose \sigma_{\max}(W^{Q}), \sigma_{\max}(W^{K}), \sigma_{\max}(W^{V}) are bounded during the optimization phase; then we have the following conclusions:

(1)d(Cih)F2ndXiF2σmax2(WhQ)+σmax2(WhK)d(WhQ)F\displaystyle(1)\|d(C_{ih})\|_{F}\leq\sqrt{\frac{2n}{d}}\left\|X_{i}\right\|_{F}^{2}\sqrt{\sigma_{\max}^{2}\left(W_{h}^{Q}\right)+\sigma_{\max}^{2}\left(W_{h}^{K}\right)}\|d(W^{Q}_{h})\|_{F} (55)
(2)d(CihWhQ)FnXiF2d(WhQ)F.\displaystyle(2)\left\|d\left(\frac{\partial C_{ih}}{\partial W_{h}^{Q}}\right)\right\|_{F}\leq\sqrt{n}\|X_{i}\|_{F}^{2}\cdot\|d(W_{h}^{Q})\|_{F}. (56)
\displaystyle(3)\|\frac{\partial f(M;X_{i})}{\partial C_{i}}\|_{F}\geq\min|V_{i}^{\prime}W^{O}|\cdot\min S_{i}\cdot\|{\sf MH}\left(M;X_{i}\right)-y_{i}\|_{2}, (57)
where Ri=(𝖬𝖧(M;Xi)yi)(WO)(Vi′′).\displaystyle\quad\text{where }R_{i}=\left({\sf MH}\left(M;X_{i}\right)-y_{i}\right)\left(W^{O}\right)^{\top}\left(V_{i}^{\prime\prime}\right)^{\top}. (58)
(4)f(M;Xi)WhQFQi𝖬𝖧(M;Xi)yi2,\displaystyle(4)\left\|\frac{\partial f(M;X_{i})}{\partial W_{h}^{Q}}\right\|_{F}\leq Q_{i}^{\prime}\|{\sf MH}(M;X_{i})-y_{i}\|_{2}, (59)
Qi=2ndXiF3WO2σmax(WV)σmax2(WhQ)+σmax2(WhK)\displaystyle\quad Q_{i}^{\prime}=\sqrt{\frac{2n}{d}}\left\|X_{i}\right\|_{F}^{3}\|W^{O}\|_{2}\sigma_{\max}(W^{V})\sqrt{\sigma_{\max}^{2}\left(W_{h}^{Q}\right)+\sigma_{\max}^{2}\left(W_{h}^{K}\right)} (60)

where \min|V^{\prime}W^{O}| denotes the smallest absolute value among the elements of the vector V^{\prime}W^{O}, and \min S denotes the smallest element of the matrix S.

Lemma 8.

Consider updating W^{Q},W^{K},W^{V} at iteration t. Suppose \sigma_{\max}(W^{Q}), \sigma_{\max}(W^{K}), \sigma_{\max}(W^{V}) are bounded during the optimization phase; then we have the following conclusions:

𝖬𝖧(Mt+1;X)𝖬𝖧(Mt;X)FZMt+1MtF,where Z is some positive constant.\displaystyle\|{\sf MH}(M_{t+1};X)-{\sf MH}(M_{t};X)\|_{F}\leq Z^{\prime}\|M_{t+1}-M_{t}\|_{F},\text{where }Z^{\prime}\text{ is some positive constant.} (61)
f(Mt+1;X)f(Mt;X)2GMt+1MtF,where G is some positive constant.\displaystyle\left\|\nabla f\left(M_{t+1};X\right)-\nabla f\left(M_{t};X\right)\right\|_{2}\leq G^{\prime}\|M_{t+1}-M_{t}\|_{F},\text{where }G^{\prime}\text{ is some positive constant.} (62)

1.5 Proof Sketch of Theorem 3.

(1) Using the GD update rule, we aim to show iteratively that

{σmax(WrQ)32λ¯Q,r{0,,t},σmin(Ch(Mr)WhQ)12δ,r{0,,t},minSrκ,r{0,,t},f(Mr;X)(1ηγ)rf(M0,X),r{0,,t}\displaystyle\left\{\begin{array}[]{l}\sigma_{\max}(W^{Q}_{r})\leq\frac{3}{2}\bar{\lambda}^{Q},r\in\{0,\ldots,t\},\\ \sigma_{\min}\left(\frac{\partial C_{h}\left(M_{r}\right)}{\partial W_{h}^{Q}}\right)\geq\frac{1}{2}\delta,r\in\{0,\ldots,t\},\\ \min S_{r}\geq\kappa,r\in\{0,\ldots,t\},\\ f\left(M_{r};X\right)\leq(1-\eta\gamma)^{r}f\left(M_{0},X\right),\;r\in\{0,\ldots,t\}\end{array}\right. (67)

Denote \gamma:=\frac{1}{2}\delta^{2}\kappa^{2}\left(\min\left|V^{\prime}W^{O}\right|\right)^{2}. Let us discuss the value of \gamma. We know W^{O}\in\mathbb{R}^{Hd^{V}\times 1} and B_{0}^{\top}\in\mathbb{R}^{HD\times Nn}, where Hd>1,HD>Nn. We require \gamma>0, i.e., \delta>0,\kappa>0,\min\left|V^{\prime}W^{O}\right|>0. It is clear that \kappa>0 holds as long as W_{h}^{Q} is bounded, and it is easy to show that if X_{i}\neq\mathbf{0}, we can always choose W^{V} and W^{O} such that \min\left|V^{\prime}W^{O}\right|>0. Since \frac{\partial C_{h}\left(M\right)}{\partial W_{h}^{Q}}\in\mathbb{R}^{Nn^{2}\times Dd}, suppose we initialize W_{h}^{Q},W^{K}_{h} such that \operatorname{rank}(\frac{\partial C_{h}\left(M_{0}\right)}{\partial W_{h}^{Q}})=Nn^{2}; then \sigma_{\min}\left(\frac{\partial C_{h}\left(M_{0}\right)}{\partial W_{h}^{Q}}\right)\geq\delta for some positive constant \delta. Further, we assume the initialization condition satisfies:

8nXF5WO2λ¯V(λ¯hQ+λ¯hK)exp(94XF2((λ¯hQ)2+(λ¯hK)2))δ2(min(|VWO|))2min(δ,λ¯hQ)MH(M0;X)y21\displaystyle\frac{8n\|X\|_{F}^{5}\|W^{O}\|_{2}\bar{\lambda}^{V}(\bar{\lambda}_{h}^{Q}+\bar{\lambda}_{h}^{K})\exp\left(\frac{9}{4}\|X\|_{F}^{2}\left(\left(\bar{\lambda}_{h}^{Q}\right)^{2}+\left(\bar{\lambda}_{h}^{K}\right)^{2}\right)\right)}{\delta^{2}\left(\min\left(\left|V^{\prime}W^{O}\right|\right)\right)^{2}\cdot\min\left(\delta,\bar{\lambda}_{h}^{Q}\right)}\left\|\mathrm{MH}\left(M_{0};X\right)-y\right\|_{2}\leq 1 (68)
Remark 6.

The initialization condition can be satisfied if WO2\|W^{O}\|_{2} is large and σmax(WV)\sigma_{\max}(W^{V}) is small. ν\nu^{\prime} in Equation (12) is 18\frac{1}{8}.
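The requirement \operatorname{rank}\big{(}\frac{\partial C_{h}(M_{0})}{\partial W_{h}^{Q}}\big{)}=Nn^{2} (and hence \delta>0) can be probed numerically with a finite-difference Jacobian of the Gaussian-kernel scores with respect to W_{h}^{Q}; the sizes below (chosen so that Nn^{2}\leq Dd) and the Gaussian initialization are hypothetical.

import numpy as np

rng = np.random.default_rng(2)
N, n, D, d = 2, 3, 8, 6                      # hypothetical sizes with N * n**2 <= D * d
X = rng.normal(size=(N, n, D))
WQ, WK = rng.normal(size=(D, d)), rng.normal(size=(D, d))

def C_of(WQ_flat):
    WQm = WQ_flat.reshape(D, d)
    out = []
    for i in range(N):
        Q, K = X[i] @ WQm, X[i] @ WK
        out.append((-0.5 / np.sqrt(d)) * ((Q[:, None, :] - K[None, :, :]) ** 2).sum(-1))
    return np.concatenate([c.ravel() for c in out])      # length N * n**2

eps = 1e-6
J = np.stack([(C_of(WQ.ravel() + eps * e) - C_of(WQ.ravel() - eps * e)) / (2 * eps)
              for e in np.eye(D * d)], axis=1)           # Jacobian of shape (N n^2, D d)
print(np.linalg.matrix_rank(J), N * n * n)               # full row rank implies sigma_min > 0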

Similar to the proof of Theorem 2, we use induction to prove the theorem. Equation (67) holds when t=0. Suppose it holds at iteration t; we now prove that it holds at iteration t+1.

\displaystyle\left\|W_{h,r+1}^{Q}-W_{h,0}^{Q}\right\|_{F}\stackrel{{\scriptstyle(i)}}{{\leq}}\sum_{s=0}^{r}\left\|W_{h,s+1}^{Q}-W_{h,s}^{Q}\right\|_{F}=\eta\sum_{s=0}^{r}\left\|\nabla_{W^{Q}_{h}}f\left(M_{s};X\right)\right\|_{F}
\stackrel{{\scriptstyle(ii)}}{{\leq}}\eta\sum_{s=0}^{r}\sqrt{\sum\limits_{i=1}^{N}{Q^{\prime}_{i}}^{2}}\left\|{\sf MH}(M_{s};X)-y\right\|_{2}\stackrel{{\scriptstyle(iii)}}{{\leq}}\eta\sqrt{\sum\limits_{i=1}^{N}{Q^{\prime}_{i}}^{2}}\sum_{s=0}^{r}\left(1-\eta\gamma\right)^{s/2}\left\|{\sf MH}(M_{0};X)-y\right\|_{2},

where (i) uses the triangle inequality; (ii) comes from Lemma 7 and the Cauchy-Schwarz inequality; (iii) is from the induction assumption that the loss function f(\cdot) decreases linearly up to the t-th iteration. Let u=\sqrt{1-\eta\gamma}. So we have

Wh,r+1QWh,0QFηi=1NQi2s=0r(1ηγ)s/2𝖬𝖧(M0;X)y2\displaystyle\left\|W_{h,r+1}^{Q}-W_{h,0}^{Q}\right\|_{F}\leq\eta\sqrt{\sum_{i=1}^{N}{Q_{i}^{\prime}}^{2}}\sum_{s=0}^{r}(1-\eta\gamma)^{s/2}\left\|{\sf MH}\left(M_{0};X\right)-y\right\|_{2} (69)
1γi=1NQi21ur+11u(1u2)𝖬𝖧(M0;X)y2\displaystyle\leq\frac{1}{\gamma}\sqrt{\sum\limits_{i=1}^{N}{Q^{\prime}_{i}}^{2}}\frac{1-u^{r+1}}{1-u}(1-u^{2})\left\|{\sf MH}(M_{0};X)-y\right\|_{2}
2i=1NQi2γ𝖬𝖧(M0;X)yF(i)12λ¯hQ,\displaystyle\leq\frac{2\sqrt{\sum\limits_{i=1}^{N}{Q^{\prime}_{i}}^{2}}}{\gamma}\|{\sf MH}(M_{0};X)-y\|_{F}\stackrel{{\scriptstyle(i)}}{{\leq}}\frac{1}{2}\bar{\lambda}^{Q}_{h}, (70)

where (i) comes from the initialization condition. Then, by Weyl's inequality, we have

σmax(Wh,t+1Q)σmax(Wh,0Q)+12λ¯hQ=32λ¯hQ.\sigma_{\max}\left(W_{h,t+1}^{Q}\right)\leq\sigma_{\max}(W_{h,0}^{Q})+\frac{1}{2}\bar{\lambda}^{Q}_{h}=\frac{3}{2}\bar{\lambda}^{Q}_{h}.
Next, we bound the smallest singular value of \frac{\partial C_{h}(M_{r+1})}{\partial W_{h}^{Q}}. We have

\displaystyle\left\|\frac{\partial C_{h}\left(M_{r+1}\right)}{\partial W_{h}^{Q}}-\frac{\partial C_{h}\left(M_{0}\right)}{\partial W_{h}^{Q}}\right\|_{F}\stackrel{{\scriptstyle(i)}}{{\leq}}\sum\limits_{s=0}^{r}\left\|\frac{\partial C_{h}\left(M_{s+1}\right)}{\partial W_{h}^{Q}}-\frac{\partial C_{h}\left(M_{s}\right)}{\partial W_{h}^{Q}}\right\|_{F}
(ii)ηnXF2s=0rWhQf(Ms;X)F\displaystyle\stackrel{{\scriptstyle(ii)}}{{\leq}}\eta\sqrt{n}\|X\|_{F}^{2}\sum_{s=0}^{r}\left\|\nabla_{W_{h}^{Q}}f\left(M_{s};X\right)\right\|_{F}
(iii)ηnXF2s=0ri=1NQi2𝖬𝖧(Ms;X)y2\displaystyle\stackrel{{\scriptstyle(iii)}}{{\leq}}\eta\sqrt{n}\|X\|_{F}^{2}\sum_{s=0}^{r}\sqrt{\sum_{i=1}^{N}{Q_{i}^{\prime}}^{2}}\left\|{\sf MH}\left(M_{s};X\right)-y\right\|_{2}
(iv)ηnXF2i=1NQi2s=0r(1ηγ)s/2𝖬𝖧(M0;X)y2,\displaystyle\stackrel{{\scriptstyle(iv)}}{{\leq}}\eta\sqrt{n}\|X\|_{F}^{2}\sqrt{\sum_{i=1}^{N}{Q_{i}^{\prime}}^{2}}\sum_{s=0}^{r}(1-\eta\gamma)^{s/2}\left\|{\sf MH}\left(M_{0};X\right)-y\right\|_{2},
2γnXF2i=1NQi2𝖬𝖧(M0;X)y2\displaystyle\leq\frac{2}{\gamma}\sqrt{n}\|X\|_{F}^{2}\sqrt{\sum_{i=1}^{N}{Q_{i}^{\prime}}^{2}}\|{\sf MH}\left(M_{0};X\right)-y\|_{2}
(v)12δ,\displaystyle\stackrel{{\scriptstyle(v)}}{{\leq}}\frac{1}{2}\delta,

where (i) uses the triangle inequality; (ii) applies Lemma 7 (2) and the Cauchy-Schwarz inequality; (iii) uses Lemma 7 (4); (iv) applies the induction assumption that the loss function f(\cdot) decreases linearly up to the t-th iteration; (v) comes from the initialization condition. Then, by Weyl's inequality, we have

\sigma_{\min}\left(\frac{\partial C_{h}\left(M_{t+1}\right)}{\partial W_{h}^{Q}}\right)\geq\sigma_{\min}\left(\frac{\partial C_{h}\left(M_{0}\right)}{\partial W_{h}^{Q}}\right)-\frac{1}{2}\delta\geq\frac{1}{2}\delta.

For each element of S_{ih}, we have the closed form

S(WhQ,WhK;Xi)kj=exp(12dXikWhQXijWhK2)\displaystyle S\left(W_{h}^{Q},W_{h}^{K};X_{i}\right)_{kj}=\exp\left(-\frac{1}{2\sqrt{d}}\left\|X_{ik}\cdot W_{h}^{Q}-X_{ij}\cdot W_{h}^{K}\right\|^{2}\right)

Since we have already shown that \sigma_{\max}\left(W_{h,r}^{Q}\right)\leq\frac{3}{2}\bar{\lambda}^{Q}_{h}, it follows directly that each element of the matrix S_{t} is lower bounded by some constant \kappa for any t. Now we derive the expression of \kappa:

exp(12dXikWh,tQXijWhK2)\displaystyle\exp\left(-\frac{1}{2\sqrt{d}}\left\|X_{ik}\cdot W_{h,t}^{Q}-X_{ij}\cdot W_{h}^{K}\right\|^{2}\right)
(i)exp(1d(XikWh,tQ2+XijWhK2))\displaystyle\stackrel{{\scriptstyle(i)}}{{\geq}}\exp\left(-\frac{1}{\sqrt{d}}\big{(}\|X_{ik}\cdot W_{h,t}^{Q}\|^{2}+\|X_{ij}\cdot W_{h}^{K}\|^{2}\big{)}\right)
(ii)exp(1d(94(λ¯hQ)2Xik2+(λ¯hK)2Xij2))\displaystyle\stackrel{{\scriptstyle(ii)}}{{\geq}}\exp\left(-\frac{1}{\sqrt{d}}\big{(}\frac{9}{4}(\bar{\lambda}_{h}^{Q})^{2}\|X_{ik\cdot}\|^{2}+(\bar{\lambda}_{h}^{K})^{2}\|X_{ij\cdot}\|^{2}\big{)}\right)
(iii)exp(94XF2((λ¯hQ)2+(λ¯hK)2))\displaystyle\stackrel{{\scriptstyle(iii)}}{{\geq}}\exp\left(-\frac{9}{4}\|X\|_{F}^{2}\big{(}(\bar{\lambda}_{h}^{Q})^{2}+(\bar{\lambda}_{h}^{K})^{2}\big{)}\right)
:=κ,\displaystyle:=\kappa,

where (i) uses the Cauchy-Schwarz inequality; (ii) applies the induction assumption \sigma_{\max}(W_{h,t}^{Q})\leq\frac{3}{2}\bar{\lambda}_{h}^{Q} and the property of singular values; (iii) is because d\geq 1. Thus, we have \min S_{t}\geq\kappa. Finally, we aim to show f\left(M_{t+1};X\right)\leq(1-\eta\gamma)f\left(M_{t};X\right). By Lemma 8, since we have shown that \sigma_{\max}(W^{Q}_{h}) is bounded, we can directly derive that

f(Mt+1;X)f(Mt;X)2\displaystyle\left\|\nabla f\left(M_{t+1};X\right)-\nabla f\left(M_{t};X\right)\right\|_{2}
=WhQf(Mt+1;X)WhQf(Mt;X)2\displaystyle=\left\|\nabla_{W^{Q}_{h}}f\left(M_{t+1};X\right)-\nabla_{W^{Q}_{h}}f\left(M_{t};X\right)\right\|_{2}
GMt+1MtF\displaystyle\leq G^{\prime}\|M_{t+1}-M_{t}\|_{F}
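Before completing the descent step, the kernel lower bound \kappa derived above can be illustrated numerically: the sketch below computes the Gaussian attention kernel for one input and checks \min S\geq\exp\big{(}-\frac{9}{4}\|X\|_{F}^{2}(\sigma_{\max}^{2}(W^{Q})+\sigma_{\max}^{2}(W^{K}))\big{)}, where the current \sigma_{\max} values play the role of the \bar{\lambda} bounds; the data scaling and weight magnitudes are illustrative assumptions.

import numpy as np

rng = np.random.default_rng(5)
n, D, d = 4, 6, 3
Xi = rng.normal(size=(n, D)) / np.sqrt(n * D)            # keep ||X||_F moderate so kappa is not vanishingly small
WQ, WK = 0.3 * rng.normal(size=(D, d)), 0.3 * rng.normal(size=(D, d))

Q, K = Xi @ WQ, Xi @ WK
S = np.exp(-((Q[:, None, :] - K[None, :, :]) ** 2).sum(-1) / (2 * np.sqrt(d)))   # Gaussian kernel scores

sQ = np.linalg.svd(WQ, compute_uv=False)[0]
sK = np.linalg.svd(WK, compute_uv=False)[0]
kappa = np.exp(-2.25 * np.linalg.norm(Xi) ** 2 * (sQ ** 2 + sK ** 2))
assert S.min() >= kappa
print(S.min(), kappa)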

Finally, by Lemma 4, choosing \eta<\frac{1}{2G^{\prime}}, the following holds:

f(Mt+1;X)\displaystyle f\left(M_{t+1};X\right) =f(Mtηf(Mt;X);X)\displaystyle=f\left(M_{t}-\eta\nabla f(M_{t};X);X\right)
(i)f(Mt;X)ηWQf(Mt;X)2+G2η2WQf(Mt;X)2\displaystyle\stackrel{{\scriptstyle(i)}}{{\leq}}f\left(M_{t};X\right)-\eta\left\|\nabla_{W^{Q}}f\left(M_{t};X\right)\right\|^{2}+\frac{G^{\prime}}{2}\eta^{2}\left\|\nabla_{W^{Q}}f\left(M_{t};X\right)\right\|^{2}
(ii)f(Mt;X)12ηWhQf(Mt;X)2\displaystyle\stackrel{{\scriptstyle(ii)}}{{\leq}}f\left(M_{t};X\right)-\frac{1}{2}\eta\left\|\nabla_{W^{Q}_{h}}f\left(M_{t};X\right)\right\|^{2}
=(iii)f(Mt;X)12ηf(Mt;X)C(Mt)(C(Mt)WhQ)F2\displaystyle\stackrel{{\scriptstyle(iii)}}{{=}}f\left(M_{t};X\right)-\frac{1}{2}\eta\left\|\frac{\partial f(M_{t};X)}{\partial C(M_{t})}\cdot\left(\frac{\partial C(M_{t})}{\partial W_{h}^{Q}}\right)\right\|_{F}^{2}
(iv)f(Mt;X)14ηδ2f(Mt;X)C(Mt)F2\displaystyle\stackrel{{\scriptstyle(iv)}}{{\leq}}f\left(M_{t};X\right)-\frac{1}{4}\eta\delta^{2}\left\|\frac{\partial f\left(M_{t};X\right)}{\partial C\left(M_{t}\right)}\right\|_{F}^{2}
(v)f(Mt;X)14ηδ2((𝖬𝖧(M;X)y)(WO)(V))SF2\displaystyle\stackrel{{\scriptstyle(v)}}{{\leq}}f\left(M_{t};X\right)-\frac{1}{4}\eta\delta^{2}\left\|\left(\left({\sf MH}\left(M;X\right)-y\right)\left(W^{O}\right)^{\top}\left(V^{\prime}\right)^{\top}\right)\odot S\right\|^{2}_{F}
\displaystyle\stackrel{{\scriptstyle(vi)}}{{\leq}}f\left(M_{t};X\right)-\frac{1}{4}\eta\delta^{2}\kappa^{2}\cdot(\min|V^{\prime}W^{O}|)^{2}\left\|{\sf MH}\left(M_{t};X\right)-y\right\|^{2}_{2}
=(vii)(1ηγ)f(Mt;X),\displaystyle\stackrel{{\scriptstyle(vii)}}{{=}}(1-\eta\gamma)f(M_{t};X),

where (i) uses Lemma 4 (2); (ii) is because we choose \eta<\frac{1}{2G^{\prime}}; (iii) writes down the gradient via the chain rule in Lemma 6; (iv) uses the induction assumption \sigma_{\min}\left(\frac{\partial C_{h}\left(M_{t}\right)}{\partial W_{h}^{Q}}\right)\geq\frac{1}{2}\delta and the property of singular values; (v) uses Lemma 6 (4); (vi) comes from Lemma 7 (3); (vii) uses the definition of \gamma.

(2) Next, we show the convergence result for the Transformer with the Softmax kernel when only W^{Q} is updated. Since we assume all parameters are bounded during the optimization phase, by Lemma 8 we can easily show that there exists a constant G^{\prime} (see xx for details) such that

WhQf(Mt+1;X)WhQf(Mt;X)2GMt+1MtF\displaystyle\left\|\nabla_{W^{Q}_{h}}f\left(M_{t+1};X\right)-\nabla_{W^{Q}_{h}}f\left(M_{t};X\right)\right\|_{2}\leq G^{\prime}\left\|M_{t+1}-M_{t}\right\|_{F} (71)

Then by Lemma 4, choosing \eta^{\prime}<\frac{1}{2G^{\prime}}, we have

f(Mt+1;X)\displaystyle f\left(M_{t+1};X\right) =f(Mtηf(Mt;X);X)\displaystyle=f\left(M_{t}-\eta\nabla f\left(M_{t};X\right);X\right)
f(Mt;X)ηf(Mt;X)2+G2η2f(Mt;X)2\displaystyle{\leq}f\left(M_{t};X\right)-\eta^{\prime}\left\|\nabla f\left(M_{t};X\right)\right\|^{2}+\frac{G^{\prime}}{2}\eta^{\prime 2}\left\|\nabla f\left(M_{t};X\right)\right\|^{2}
f(Mt;X)12ηf(Mt;X)2\displaystyle\leq f\left(M_{t};X\right)-\frac{1}{2}\eta^{\prime}\left\|\nabla f\left(M_{t};X\right)\right\|^{2}

1.6 Proof of Lemma in Section 1.2

Proof of Lemma 2 (1).

Proof.

Step 1: When WQ,WKW^{Q},W^{K} are updated, we aim to prove

d(Si)Fnd(Ci)F.\displaystyle\|d(S_{i})\|_{F}\leq n\|d(C_{i})\|_{F}.

Step 2: We aim to show \|d(C_{i})\|_{F}\leq\frac{n}{\sqrt{d}}\left\|X_{i}\right\|_{F}^{2}\sqrt{\sum\limits_{h=1}^{H}\sigma_{\max}^{2}\left(W_{h}^{K}\right)}\cdot\left\|d\left(W^{Q}\right)\right\|_{F}. Combining the above two steps, we can derive the bound in Equation (33).
Proof of Step 1: First, we can write down the closed form of the differential of SiS_{i}:

\displaystyle\|d(S_{i})\|_{F} =\|S_{i}\odot d(C_{i})-S_{i}\odot\Upsilon((\exp C_{i})\mathbb{E})\odot d\big{(}(\exp C_{i})\mathbb{E}\big{)}\|_{F} (72)

Reorganizing the terms on the right-hand side of Equation (72), we have the following equation:

d(Si)F\displaystyle\|d(S_{i})\|_{F} =Si(d(Ci)Υ((expCi)𝔼)d((expCi)𝔼))F\displaystyle=\|S_{i}\odot\big{(}d(C_{i})-\Upsilon((\exp C_{i})\mathbb{E})\odot d((\exp C_{i})\mathbb{E})\big{)}\|_{F}
=Si(d(Ci)Υ((expCi)𝔼)((expCi)d(Ci))𝔼)F\displaystyle=\|S_{i}\odot\big{(}d(C_{i})-\Upsilon((\exp C_{i})\mathbb{E})\odot\left((\exp C_{i})\odot d(C_{i})\right)\mathbb{E}\big{)}\|_{F} (73)

Since C_{i}=[C_{i1},\cdots,C_{iH}], we investigate each C_{ih},\;h=1,2,\cdots,H. We focus on the term d(C_{i})-\Upsilon((\exp C_{i})\mathbb{E})\odot\left((\exp C_{i})\odot d(C_{i})\right)\mathbb{E} in Equation (73), and write down the closed form of the element in the k-th row and j-th column:

[d(Cih)Υ(exp(Cih)𝔼)(exp(Cih)d(Cih))𝔼]kj\displaystyle\left[d(C_{ih})-\Upsilon(\exp(C_{ih})\mathbb{E})\odot\left(\exp(C_{ih})\odot d(C_{ih})\right)\mathbb{E}\right]_{kj} (74)
=(i)(1exp(Cihkj)j=1nexp(Cihkj))d(Cihkj)pjexp(Cihkp)d(Cihkp)j=1nexp(Cihkj)\displaystyle\stackrel{{\scriptstyle(i)}}{{=}}\left(1-\frac{\exp\left(C_{ihkj}\right)}{\sum\limits_{j=1}^{n}\exp\left(C_{ihkj}\right)}\right)d(C_{ihkj})-\frac{\sum\limits_{p\neq j}\exp\left(C_{ihkp}\right)d(C_{ihkp})}{\sum\limits_{j=1}^{n}\exp\left(C_{ihkj}\right)} (75)
(ii)(1exp(Cihkj)j=1nexp(Cihkj))2+pj(exp(Cihkp)j=1nexp(Cihkj))2j=1n(d(Cihkj))2\displaystyle\stackrel{{\scriptstyle(ii)}}{{\leq}}\sqrt{\left(1-\frac{\exp\left(C_{ihkj}\right)}{\sum\limits_{j=1}^{n}\exp\left(C_{ihkj}\right)}\right)^{2}+\sum\limits_{p\neq j}\left(\frac{\exp(C_{ihkp})}{\sum\limits_{j=1}^{n}\exp\left(C_{ihkj}\right)}\right)^{2}}\cdot\sqrt{\sum\limits_{j=1}^{n}\big{(}d(C_{ihkj})\big{)}^{2}} (76)
(iii)nd(Cihk)F,\displaystyle\stackrel{{\scriptstyle(iii)}}{{\leq}}\sqrt{n}\|d(C_{ihk})\|_{F}, (77)

where (i) expands the closed form of Equation (74); (ii) uses the Cauchy-Schwarz inequality; (iii) is because each term inside the square root in (ii) is upper bounded by 1. With Equation (76), we can easily show

d(Cih)Υ((expCih)𝔼)((expCih)d(Cih))𝔼Fnk=1nj=1nd(Cihk)F2nd(Cih)F\displaystyle\left\|d(C_{ih})-\Upsilon((\exp C_{ih})\mathbb{E})\odot\left((\exp C_{ih})\odot d(C_{ih})\right)\mathbb{E}\right\|_{F}\leq\sqrt{n}\sqrt{\sum\limits_{k=1}^{n}\sum\limits_{j=1}^{n}\|d(C_{ihk})\|_{F}^{2}}\leq n\|d(C_{ih})\|_{F} (78)

Since every element in SiS_{i} has magnitude less than 11, we have

d(Si)F=Si(d(Ci)Υ((expCi)𝔼)((expCi)d(Ci))𝔼)F\displaystyle\|d(S_{i})\|_{F}=\left\|S_{i}\odot\left(d\left(C_{i}\right)-\Upsilon\left(\left(\exp C_{i}\right)\mathbb{E}\right)\odot\left(\left(\exp C_{i}\right)\odot d\left(C_{i}\right)\right)\mathbb{E}\right)\right\|_{F} (79)
d(Cih)Υ((expCih)𝔼)((expCih)d(Cih))𝔼F\displaystyle\leq\left\|d(C_{ih})-\Upsilon((\exp C_{ih})\mathbb{E})\odot\left((\exp C_{ih})\odot d(C_{ih})\right)\mathbb{E}\right\|_{F} (80)
(i)nHd(Ci)F,\displaystyle\stackrel{{\scriptstyle(i)}}{{\leq}}n\sqrt{H}\|d(C_{i})\|_{F}, (81)

where (i) is from the Cauchy-Schwarz inequality.

Proof of Step 2: We aim to show d(Ci)FndXiF2h=1Hσmax2(WhK)d(WQ)F\left\|d\left(C_{i}\right)\right\|_{F}\leq\frac{n}{\sqrt{d}}\left\|X_{i}\right\|_{F}^{2}\sqrt{\sum_{h=1}^{H}\sigma_{\max}^{2}\left(W_{h}^{K}\right)}\cdot\left\|d\left(W^{Q}\right)\right\|_{F}. Similarly, we investigate d(Cih)F,h=1,2,,H.\|d(C_{ih})\|_{F},\;h=1,2,\cdots,H. We have

d(Cih)F=Xid(WhQ)(XiWhK)dF1dXiF2σmax(WhK)d(WhQ)F\displaystyle\|d(C_{ih})\|_{F}=\left\|\frac{X_{i}d(W_{h}^{Q})\left(X_{i}W_{h}^{K}\right)^{\top}}{\sqrt{d}}\right\|_{F}\leq\frac{1}{\sqrt{d}}\|X_{i}\|_{F}^{2}\sigma_{\max}(W^{K}_{h})\|d(W^{Q}_{h})\|_{F} (82)

Plugging the above inequality into Equation (78), we can derive

d(Sih)FndXiF2σmax(WhK)d(WhQ)F\displaystyle\|d(S_{ih})\|_{F}\leq\frac{n}{\sqrt{d}}\|X_{i}\|_{F}^{2}\sigma_{\max}\left(W_{h}^{K}\right)\left\|d(W_{h}^{Q})\right\|_{F} (83)

Thus, by the Cauchy-Schwarz inequality, it is easy to show

d(Si)FndXiF2h=1Hσmax2(WhK)d(WQ)F.\displaystyle\|d(S_{i})\|_{F}\leq\frac{n}{\sqrt{d}}\|X_{i}\|_{F}^{2}\sqrt{\sum_{h=1}^{H}\ \sigma_{\max}^{2}\left(W_{h}^{K}\right)}\cdot\left\|d\left(W^{Q}\right)\right\|_{F}.
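The bound just derived can be probed with a small finite-difference experiment in the single-head (H=1) case; the sizes and the perturbation scale below are illustrative assumptions.

import numpy as np

rng = np.random.default_rng(6)
n, D, d = 5, 8, 4
Xi = rng.normal(size=(n, D))
WQ, WK = rng.normal(size=(D, d)), rng.normal(size=(D, d))

def softmax_rows(C):
    E = np.exp(C - C.max(axis=1, keepdims=True))
    return E / E.sum(axis=1, keepdims=True)

def S_of(WQm):
    return softmax_rows((Xi @ WQm) @ (Xi @ WK).T / np.sqrt(d))

Delta = 1e-5 * rng.normal(size=(D, d))                   # small perturbation of W^Q
lhs = np.linalg.norm(S_of(WQ + Delta) - S_of(WQ))
rhs = (n / np.sqrt(d)) * np.linalg.norm(Xi) ** 2 \
      * np.linalg.svd(WK, compute_uv=False)[0] * np.linalg.norm(Delta)
assert lhs <= rhs                                        # the perturbation bound of Lemma 2 (1) with H = 1
print(lhs, rhs)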

Proof of Lemma 2 (4).

Proof.

We first write down the closed form of the gradient of f(\cdot) with respect to W^{Q}_{h} using Lemma 1, and then derive an upper bound on its norm.

f(M;Xi)WhQF=1dXif(M;Xi)CihXiWhKFXiF2σmax(WhK)f(M;Xi)CiF\left\|\frac{\partial f(M;X_{i})}{\partial W_{h}^{Q}}\right\|_{F}=\left\|\frac{1}{\sqrt{d}}X_{i}^{\top}\frac{\partial f(M;X_{i})}{\partial C_{i}}\mathbb{P}_{h}^{\top}X_{i}W_{h}^{K}\right\|_{F}\leq\|X_{i}\|_{F}^{2}\sigma_{\max}(W_{h}^{K})\left\|\frac{\partial f\left(M;X_{i}\right)}{\partial C_{i}}\right\|_{F} (84)

By Lemma 1, we have

f(M;Xi)Ci=((𝖬𝖧(M;Xi)yi)(WO)(Vi))Si\displaystyle\frac{\partial f(M;X_{i})}{\partial C_{i}}=\left(({\sf MH}(M;X_{i})-y_{i})\left(W^{O}\right)^{\top}\left(V^{\prime}_{i}\right)^{\top}\right)\odot S_{i}
((((𝖬𝖧(M;Xi)yi)(WO)(Vi))SiΥ((expCi)𝔼))𝔼)expCi\displaystyle\quad-\left(\left(\left(({\sf MH}(M;X_{i})-y_{i})\left(W^{O}\right)^{\top}\left(V^{\prime}_{i}\right)^{\top}\right)\odot S_{i}\odot\Upsilon\big{(}(\exp C_{i})\mathbb{E}\big{)}\right)\mathbb{E}^{\top}\right)\odot\exp C_{i} (85)

Denote R_{i}=\left({\sf MH}\left(M;X_{i}\right)-y_{i}\right)\left(W^{O}\right)^{\top}\left(V_{i}^{\prime}\right)^{\top},\;R_{i}=[R_{i1},\cdots,R_{iH}]. We write down the closed form of the element in the k-th row and j-th column:

\displaystyle\left[R_{ih}\odot S_{ih}-\left(\left(R_{ih}\odot S_{ih}\odot\Upsilon\left(\left(\exp C_{ih}\right)\mathbb{E}\right)\right)\mathbb{E}^{\top}\right)\odot\left(\exp C_{ih}\right)\right]_{kj}
\displaystyle=R_{ihkj}S_{ihkj}-\frac{\exp(C_{ihkj})\sum\limits_{p=1}^{n}R_{ihkp}S_{ihkp}}{\sum\limits_{p=1}^{n}\exp(C_{ihkp})}
\displaystyle=\left(S_{ihkj}-\frac{\exp(C_{ihkj})S_{ihkj}}{\sum\limits_{p=1}^{n}\exp(C_{ihkp})}\right)\cdot R_{ihkj}-\sum\limits_{p\neq j}\frac{\exp(C_{ihkp})S_{ihkp}}{\sum\limits_{q=1}^{n}\exp(C_{ihkq})}R_{ihkp}
(i)(1exp(Cihkj)j=1nexp(Cihkj))2+pj(exp(Cihkp)j=1nexp(Cihkj))2RihkF\displaystyle\stackrel{{\scriptstyle(i)}}{{\leq}}\sqrt{\left(1-\frac{\exp\left(C_{ihkj}\right)}{\sum_{j=1}^{n}\exp\left(C_{ihkj}\right)}\right)^{2}+\sum_{p\neq j}\left(\frac{\exp\left(C_{ihkp}\right)}{\sum\limits_{j=1}^{n}\exp\left(C_{ihkj}\right)}\right)^{2}}\cdot\|R_{ihk}\|_{F}
(ii)nRihkF\displaystyle\stackrel{{\scriptstyle(ii)}}{{\leq}}\sqrt{n}\|R_{ihk}\|_{F}

where (i) is due to the Cauchy-Schwarz inequality; (ii) is because each term inside the square root in (i) has magnitude at most 1. Thus, we can further derive

f(M;Xi)CihF=RihSih((RihSihΥ((expCih)𝔼))𝔼)expCihF\displaystyle\left\|\frac{\partial f\left(M;X_{i}\right)}{\partial C_{ih}}\right\|_{F}=\left\|R_{ih}\odot S_{ih}-\left(\left(R_{ih}\odot S_{ih}\odot\Upsilon((\exp C_{ih})\mathbb{E})\right)\mathbb{E}^{\top}\right)\odot\exp C_{ih}\right\|_{F}
(i)nk=1nj=1nRihkF(ii)nRihF\displaystyle\stackrel{{\scriptstyle(i)}}{{\leq}}\sqrt{n}\sum\limits_{k=1}^{n}\sum\limits_{j=1}^{n}\|R_{ihk}\|_{F}\stackrel{{\scriptstyle(ii)}}{{\leq}}n\|R_{ih}\|_{F}
(iii)nXiFWO2σmax(WhV)𝖬𝖧(M;Xi)yi2,\displaystyle\stackrel{{\scriptstyle(iii)}}{{\leq}}n\|X_{i}\|_{F}\|W^{O}\|_{2}\sigma_{\max}(W^{V}_{h})\|{\sf MH}(M;X_{i})-y_{i}\|_{2},

where (i) is from the element-wise bound derived above; (ii) comes from the Cauchy-Schwarz inequality; (iii) uses the property of the Frobenius norm. Thus, by the Cauchy-Schwarz inequality, we can derive the upper bound for \left\|\frac{\partial f\left(M;X_{i}\right)}{\partial C_{i}}\right\|_{F}.

f(M;Xi)CiFnHXiFWO2σmax(WV)𝖬𝖧(M;Xi)yi2\displaystyle\left\|\frac{\partial f\left(M;X_{i}\right)}{\partial C_{i}}\right\|_{F}\leq n\sqrt{H}\|X_{i}\|_{F}\|W^{O}\|_{2}\sigma_{\max}(W^{V})\|{\sf MH}(M;X_{i})-y_{i}\|_{2} (86)

Plugging the above inequality into Equation (84), we can derive the upper bound for \left\|\frac{\partial f\left(M;X_{i}\right)}{\partial W_{h}^{Q}}\right\|_{F}:

f(M;Xi)WhQFXiF2σmax(WhK)f(M;X)CiF\displaystyle\left\|\frac{\partial f\left(M;X_{i}\right)}{\partial W_{h}^{Q}}\right\|_{F}\leq\left\|X_{i}\right\|_{F}^{2}\sigma_{\max}\left(W_{h}^{K}\right)\left\|\frac{\partial f\left(M;X\right)}{\partial C_{i}}\right\|_{F}
nHXiF3WO2σmax(WhK)σmax(WhV)𝖬𝖧(M;Xi)yi2\displaystyle\leq n\sqrt{H}\left\|X_{i}\right\|^{3}_{F}\left\|W^{O}\right\|_{2}\sigma_{\max}\left(W_{h}^{K}\right)\sigma_{\max}\left(W^{V}_{h}\right)\left\|{\sf MH}\left(M;X_{i}\right)-y_{i}\right\|_{2}
nHXiF3WO2h=1Hσmax2(WhK)σmax(WV)𝖬𝖧(M;Xi)yi2\displaystyle\leq n\sqrt{H}\left\|X_{i}\right\|^{3}_{F}\left\|W^{O}\right\|_{2}\sqrt{\sum\limits_{h=1}^{H}\sigma^{2}_{\max}\left(W_{h}^{K}\right)}\sigma_{\max}\left(W^{V}\right)\left\|{\sf MH}\left(M;X_{i}\right)-y_{i}\right\|_{2}

Proof of Lemma 3 (1). By the Mean Value Theorem and the Cauchy-Schwarz inequality,

|f(Mt+1;Xi)f(Mt;Xi)|\displaystyle|f\left(M_{t+1};X_{i}\right)-f\left(M_{t};X_{i}\right)|
=f(Mt;Xi)W,Mt+1Mt\displaystyle=\left\langle\frac{\partial f(M_{t}^{\prime};X_{i})}{\partial W},M_{t+1}-M_{t}\right\rangle
f(Mt;Xi)WQ2+f(Mt;Xi)WK2+f(Mt;Xi)WV2Mt+1MtF,\displaystyle\leq\sqrt{\left\|\frac{\partial f(M_{t}^{\prime};X_{i})}{\partial W^{Q}}\right\|^{2}+\left\|\frac{\partial f(M_{t}^{\prime};X_{i})}{\partial W^{K}}\right\|^{2}+\left\|\frac{\partial f(M_{t};X_{i})}{\partial W^{V}}\right\|^{2}}\|M_{t+1}-M_{t}\|_{F}, (87)

where MtM^{\prime}_{t} is between MtM_{t} and Mt+1M_{t+1}. We can derive the upper bound of the norm of WVf(M;Xi)\nabla_{W^{V}}f(M;X_{i}):

f(Mt;Xi)WVF=Bi(𝖬𝖧(Mt;Xi)yi)(WO)F\displaystyle\left\|\frac{\partial f(M_{t};X_{i})}{\partial W^{V}}\right\|_{F}=\|B_{i}^{\top}\left({\sf MH}(M_{t};X_{i})-y_{i}\right)\left(W^{O}\right)^{\top}\|_{F}
BiF𝖬𝖧(Mt;Xi)yiFWO2\displaystyle\leq\|B_{i}\|_{F}\|{\sf MH}(M_{t};X_{i})-y_{i}\|_{F}\|W^{O}\|_{2}
nHXiFWO2𝖬𝖧(Mt;Xi)yiF\displaystyle\leq n\sqrt{H}\|X_{i}\|_{F}\|W^{O}\|_{2}\|{\sf MH}(M_{t};X_{i})-y_{i}\|_{F} (88)

By Lemma 2, we know

f(Mt;Xi)WQFQi𝖬𝖧(M;Xi)yi2;f(Mt;Xi)WKFKi𝖬𝖧(M;Xi)yi2.\displaystyle\left\|\frac{\partial f(M_{t};X_{i})}{\partial W^{Q}}\right\|_{F}\leq Q_{i}\left\|{\sf MH}\left(M;X_{i}\right)-y_{i}\right\|_{2};\;\left\|\frac{\partial f(M_{t};X_{i})}{\partial W^{K}}\right\|_{F}\leq K_{i}\left\|{\sf MH}\left(M;X_{i}\right)-y_{i}\right\|_{2}.
f(Mt+1;Xi)f(Mt;Xi)2Qi2+Ki2+n2Hσmax2(Xi)WO2Mt+1MtF\displaystyle\left\|f\left(M_{t+1};X_{i}\right)-f\left(M_{t};X_{i}\right)\right\|_{2}\leq\sqrt{Q_{i}^{2}+K_{i}^{2}+n^{2}H\sigma^{2}_{\max}(X_{i})\|W^{O}\|^{2}}\|M_{t+1}-M_{t}\|_{F}
:=ZiMt+1MtF\displaystyle:=Z_{i}\|M_{t+1}-M_{t}\|_{F} (89)

Therefore, together with Equation (88), we have

f(Mt+1;X)f(Mt;X)2NmaxiQi2+maxiKi2+n2HmaxiXiF2Mt+1MtF\displaystyle\left\|f\left(M_{t+1};X\right)-f\left(M_{t};X\right)\right\|_{2}\leq N\sqrt{\max\limits_{i}Q_{i}^{2}+\max\limits_{i}K_{i}^{2}+n^{2}H\max\limits_{i}\|X_{i}\|_{F}^{2}}\|M_{t+1}-M_{t}\|_{F}
:=ZMt+1MtF\displaystyle:=Z\|M_{t+1}-M_{t}\|_{F} (90)

Proof of Lemma 3 (2).

Proof.

By triangle inequality, we have

\displaystyle\left\|\nabla_{W}f(M_{t+1};X)-\nabla_{W}f(M_{t};X)\right\|_{F}
\displaystyle\leq\left\|\nabla_{W^{Q}}f(M_{t+1};X)-\nabla_{W^{Q}}f(M_{t};X)\right\|_{F}+\left\|\nabla_{W^{K}}f(M_{t+1};X)-\nabla_{W^{K}}f(M_{t};X)\right\|_{F}
\displaystyle\quad+\left\|\nabla_{W^{V}}f(M_{t+1};X)-\nabla_{W^{V}}f(M_{t};X)\right\|_{F} (91)
\displaystyle\leq\sum\limits_{i=1}^{N}\big{(}\left\|\nabla_{W^{Q}}f\left(M_{t+1};X_{i}\right)-\nabla_{W^{Q}}f\left(M_{t};X_{i}\right)\right\|_{F}+\left\|\nabla_{W^{K}}f\left(M_{t+1};X_{i}\right)-\nabla_{W^{K}}f\left(M_{t};X_{i}\right)\right\|_{F} (92)
\displaystyle\quad+\|\nabla_{W^{V}}f\left(M_{t+1};X_{i}\right)-\nabla_{W^{V}}f\left(M_{t};X_{i}\right)\|_{F}\big{)} (93)

Step 1: Derive upper bound for

\|\nabla_{W^{Q}}f(M_{t+1};X_{i})-\nabla_{W^{Q}}f(M_{t};X_{i})\|_{F}=\|\operatorname{vec}(\nabla_{W^{Q}}f(M_{t+1};X_{i}))-\operatorname{vec}(\nabla_{W^{Q}}f(M_{t};X_{i}))\|_{2}.

First, we give the vectorized expression of \nabla_{W^{Q}}f\left(M_{t};X_{i}\right). Recall we denote U_{i}=\left(\left({\sf MH}\left(M;X_{i}\right)-y_{i}\right)\left(W^{O}\right)^{\top}\left(V_{i}^{\prime}\right)^{\top}\right)\odot S_{i}. By Lemma 1, we can derive the closed form of \operatorname{vec}(\nabla_{W^{Q}}f(M_{t};X_{i})):

vec(WQf(M;Xi))=(i)vec(Ui)vec((UiΥ((expCi)𝔼))𝔼)vec(expCi)\displaystyle\operatorname{vec}(\nabla_{W^{Q}}f(M;X_{i}))\stackrel{{\scriptstyle(i)}}{{=}}\operatorname{vec}(U_{i})-\operatorname{vec}\left(\left(U_{i}\odot\Upsilon((\exp C_{i})\mathbb{E})\right)\mathbb{E}^{\top}\right)\odot\operatorname{vec}(\exp C_{i})
=(ii)vec(Ui)(𝔼𝕀n)vec(UiΥ((expCi)𝔼))vec(expCi)\displaystyle\stackrel{{\scriptstyle(ii)}}{{=}}\operatorname{vec}(U_{i})-\left(\mathbb{E}\otimes\mathbb{I}_{n}\right)\operatorname{vec}\big{(}U_{i}\odot\Upsilon((\exp C_{i})\mathbb{E})\big{)}\odot\operatorname{vec}(\exp C_{i})
=(iii)vec(Ui)(𝔼𝕀n)vec(Ui)vec(Υ((expCi)𝔼))vec(expCi)\displaystyle\stackrel{{\scriptstyle(iii)}}{{=}}\operatorname{vec}(U_{i})-\left(\mathbb{E}\otimes\mathbb{I}_{n}\right)\operatorname{vec}(U_{i})\odot\operatorname{vec}\big{(}\Upsilon((\exp C_{i})\mathbb{E})\big{)}\odot\operatorname{vec}(\exp C_{i})
=(iv)vec(Ui)(𝔼𝕀n)vec(Ui)vec(Si)\displaystyle\stackrel{{\scriptstyle(iv)}}{{=}}\operatorname{vec}(U_{i})-\left(\mathbb{E}\otimes\mathbb{I}_{n}\right)\operatorname{vec}(U_{i})\odot\operatorname{vec}(S_{i})
=(v)𝕀n2Hvec(Ui)vec(𝟏n𝟏nH)(𝔼𝕀n)vec(Ui)vec(Si)\displaystyle\stackrel{{\scriptstyle(v)}}{{=}}\mathbb{I}_{n^{2}H}\operatorname{vec}(U_{i})\odot\operatorname{vec}(\mathbf{1}_{n}\mathbf{1}_{nH}^{\top})-\left(\mathbb{E}\otimes\mathbb{I}_{n}\right)\operatorname{vec}(U_{i})\odot\operatorname{vec}(S_{i})
=(vi)(𝕀n2H(𝔼𝕀n))vec(Ui)vec(𝟏n𝟏nHSi),\displaystyle\stackrel{{\scriptstyle(vi)}}{{=}}\big{(}\mathbb{I}_{n^{2}H}-(\mathbb{E}\otimes\mathbb{I}_{n})\big{)}\operatorname{vec}(U_{i})\odot\operatorname{vec}(\mathbf{1}_{n}\mathbf{1}^{\top}_{nH}-S_{i}), (94)

where (i) uses Lemma 1; (ii) and (iii) come from the vectorization properties in Lemma 5; (iv) uses the definition of S_{i}; (v) gives an equivalent expression of \operatorname{vec}(U_{i}); (vi) reorganizes (v). Further, it is easy to verify that:

UiF=((𝖬𝖧(M;Xi)yi)(WO)(Vi))SiFRiF\displaystyle\|U_{i}\|_{F}=\left\|\left(\left({\sf MH}\left(M;X_{i}\right)-y_{i}\right)\left(W^{O}\right)^{\top}\left(V_{i}^{\prime}\right)^{\top}\right)\odot S_{i}\right\|_{F}\leq\|R_{i}\|_{F}
\displaystyle\leq(\|{\sf MH}\left(M;X_{i}\right)\|_{2}+\|y_{i}\|_{2})\|W^{O}\|_{2}\|X_{i}\|_{F}\sigma_{\max}(W^{V})
(nHσmax(WV)XiFWO2+yi2)WO2XiFσmax(WV)\displaystyle\leq\big{(}n\sqrt{H}\sigma_{\max}(W^{V})\|X_{i}\|_{F}\|W^{O}\|_{2}+\|y_{i}\|_{2}\big{)}\left\|W^{O}\right\|_{2}\left\|X_{i}\right\|_{F}\sigma_{\max}\left(W^{V}\right)
(nHσmax(WV)XFWO2+y2)WO2XFσmax(WV)\displaystyle\leq\left(n\sqrt{H}\sigma_{\max}\left(W^{V}\right)\left\|X\right\|_{F}\left\|W^{O}\right\|_{2}+\left\|y\right\|_{2}\right)\left\|W^{O}\right\|_{2}\left\|X\right\|_{F}\sigma_{\max}\left(W^{V}\right)
:=R¯\displaystyle:=\bar{R} (95)

Next, let us derive an upper bound for \left\|\nabla_{W^{Q}}f\left(M_{t+1};X_{i}\right)-\nabla_{W^{Q}}f\left(M_{t};X_{i}\right)\right\|_{F}.

\displaystyle\left\|\nabla_{W^{Q}}f(M_{t+1};X_{i})-\nabla_{W^{Q}}f(M_{t};X_{i})\right\|_{F}
\displaystyle\stackrel{{\scriptstyle(i)}}{{=}}\left\|\big{(}\mathbb{I}_{n^{2}H}-(\mathbb{E}\otimes\mathbb{I}_{n})\big{)}\big{(}\operatorname{vec}(U_{i,t+1})\odot\operatorname{vec}(S_{i,t+1})-\operatorname{vec}(U_{i,t})\odot\operatorname{vec}(S_{i,t})\big{)}\right\|_{F}
\displaystyle=\left\|\big{(}\mathbb{I}_{n^{2}H}-(\mathbb{E}\otimes\mathbb{I}_{n})\big{)}\big{(}\operatorname{vec}(U_{i,t+1})\odot\operatorname{vec}(S_{i,t+1})-\operatorname{vec}(U_{i,t})\odot\operatorname{vec}(S_{i,t+1})+\operatorname{vec}(U_{i,t})\odot\operatorname{vec}(S_{i,t+1})-\operatorname{vec}(U_{i,t})\odot\operatorname{vec}(S_{i,t})\big{)}\right\|_{F}
\displaystyle\stackrel{{\scriptstyle(ii)}}{{\leq}}\|\mathbb{I}_{n^{2}H}-(\mathbb{E}\otimes\mathbb{I}_{n})\|_{F}\bigg{(}\|\operatorname{vec}(U_{i,t+1}-U_{i,t})\|_{F}+\|U_{i,t}\|_{F}\|S_{i,t+1}-S_{i,t}\|_{F}\bigg{)}
\displaystyle\stackrel{{\scriptstyle(iii)}}{{\leq}}n\sqrt{H}\left(\left\|\operatorname{vec}\left(U_{i,t+1}-U_{i,t}\right)\right\|_{F}+\bar{R}\left\|S_{i,t+1}-S_{i,t}\right\|_{F}\right)
\displaystyle\stackrel{{\scriptstyle(iv)}}{{=}}n\sqrt{H}\big{(}\|R_{i,t+1}\odot S_{i,t+1}-R_{i,t}\odot S_{i,t}\|_{F}+\bar{R}\|S_{i,t+1}-S_{i,t}\|_{F}\big{)}
\displaystyle=n\sqrt{H}\big{(}\|R_{i,t+1}\odot S_{i,t+1}-R_{i,t}\odot S_{i,t+1}+R_{i,t}\odot S_{i,t+1}-R_{i,t}\odot S_{i,t}\|_{F}+\bar{R}\|S_{i,t+1}-S_{i,t}\|_{F}\big{)}
\displaystyle\stackrel{{\scriptstyle(v)}}{{\leq}}n\sqrt{H}\big{(}\|(R_{i,t+1}-R_{i,t})\odot S_{i,t+1}\|_{F}+\|R_{i,t}\odot S_{i,t+1}-R_{i,t}\odot S_{i,t}\|_{F}+\bar{R}\|S_{i,t+1}-S_{i,t}\|_{F}\big{)}
\displaystyle\stackrel{{\scriptstyle(vi)}}{{\leq}}n\sqrt{H}\big{(}\|R_{i,t+1}-R_{i,t}\|_{F}+\|R_{i,t}\|_{F}\|S_{i,t+1}-S_{i,t}\|_{F}+\bar{R}\|S_{i,t+1}-S_{i,t}\|_{F}\big{)}, (96)

where (i) plugs in the expression in Equation (94); (ii) uses the fact that each element in S_{i,t+1} has magnitude at most 1, together with the Cauchy-Schwarz inequality; (iii) comes from the definitions of \mathbb{I},\mathbb{E} and \bar{R}; (iv) uses the definition of U_{i,t}; (v) is because of the triangle inequality; (vi) uses the fact that each element in S_{i,t+1} has magnitude at most 1, together with the Cauchy-Schwarz inequality. Next, we aim to derive an upper bound for \left\|R_{i,t+1}-R_{i,t}\right\|_{F} in Equation (96).

Ri,t+1Ri,tF=(𝖬𝖧(Mt+1;Xi)yi)WO(Vi,t+1)(𝖬𝖧(Mt;Xi)yi)WO(Vi,t)F\displaystyle\|R_{i,t+1}-R_{i,t}\|_{F}=\left\|\big{(}{\sf MH}(M_{t+1};X_{i})-y_{i}\big{)}W^{O}(V^{\prime}_{i,t+1})-\big{(}{\sf MH}(M_{t};X_{i})-y_{i}\big{)}W^{O}(V^{\prime}_{i,t})\right\|_{F}
=(𝖬𝖧(Mt+1;Xi)yi)WO(Vi,t+1)(𝖬𝖧(Mt;Xi)yi)WO(Vi,t+1)+\displaystyle=\|\big{(}{\sf MH}(M_{t+1};X_{i})-y_{i}\big{)}W^{O}(V^{\prime}_{i,t+1})-\big{(}{\sf MH}(M_{t};X_{i})-y_{i}\big{)}W^{O}(V^{\prime}_{i,t+1})+
(𝖬𝖧(Mt;Xi)yi)WO(Vi,t+1)(𝖬𝖧(Mt;Xi)yi)WO(Vi,t)F\displaystyle\quad\big{(}{\sf MH}(M_{t};X_{i})-y_{i}\big{)}W^{O}(V^{\prime}_{i,t+1})-\big{(}{\sf MH}(M_{t};X_{i})-y_{i}\big{)}W^{O}(V^{\prime}_{i,t})\|_{F}
(i)(𝖬𝖧(Mt+1;Xi)𝖬𝖧(Mt;Xi))(Vi,t+1)WOF+(𝖬𝖧(Mt;Xi)yi)(Vi,t+1Vi,t)WOF\displaystyle\stackrel{{\scriptstyle(i)}}{{\leq}}\|\big{(}{\sf MH}(M_{t+1};X_{i})-{\sf MH}(M_{t};X_{i})\big{)}(V^{\prime}_{i,t+1})W^{O}\|_{F}+\|\big{(}{\sf MH}(M_{t};X_{i})-y_{i}\big{)}(V^{\prime}_{i,t+1}-V^{\prime}_{i,t})W^{O}\|_{F}
\displaystyle\stackrel{{\scriptstyle(ii)}}{{\leq}}Z_{i}\|M_{t+1}-M_{t}\|_{F}\|X_{i}\|_{F}\sigma_{\max}(W^{V})\|W^{O}\|_{2}
\displaystyle\quad+\big{(}\|{\sf MH}(M_{t+1};X_{i})\|_{F}+\|y_{i}\|_{2}\big{)}\|X_{i}\|_{F}\|W_{t+1}^{V}-W_{t}^{V}\|_{F}\|W^{O}\|_{2}
\displaystyle\stackrel{{\scriptstyle(iii)}}{{\leq}}Z_{i}\|X_{i}\|_{F}\sigma_{\max}\left(W^{V}\right)\|W^{O}\|_{2}\left\|M_{t+1}-M_{t}\right\|_{F}
\displaystyle\quad+\big{(}n\sqrt{H}\sigma_{\max}\left(W^{V}\right)\left\|X_{i}\right\|_{F}\left\|W^{O}\right\|_{2}+\|y_{i}\|_{2}\big{)}\left\|X_{i}\right\|_{F}\left\|W_{t+1}^{V}-W_{t}^{V}\right\|_{F}\left\|W^{O}\right\|_{2}
(iv)(ZiXiFσmax(WV)WO2+(nHσmax(WV)XiFWO2+yi2)XiFWO2)\displaystyle\stackrel{{\scriptstyle(iv)}}{{\leq}}\big{(}Z_{i}\left\|X_{i}\right\|_{F}\sigma_{\max}\left(W^{V}\right)\left\|W^{O}\right\|_{2}+(n\sqrt{H}\sigma_{\max}\left(W^{V}\right)\left\|X_{i}\right\|_{F}\left\|W^{O}\right\|_{2}+\left\|y_{i}\right\|_{2})\left\|X_{i}\right\|_{F}\|W^{O}\|_{2}\big{)}
×Mt+1MtF\displaystyle\quad\times\|M_{t+1}-M_{t}\|_{F}
:=PiMt+1MtF,\displaystyle:=P_{i}\|M_{t+1}-M_{t}\|_{F}, (97)

where (i) is because of the triangle inequality; (ii) uses the definition of Z_{i} in Equation (89), the Cauchy-Schwarz inequality, and the triangle inequality; (iii) uses the Cauchy-Schwarz inequality; (iv) reorganizes the terms in (iii). Plugging Equation (97) into Equation (96), we can finally derive the bound for \left\|\nabla_{W^{Q}}f\left(M_{t+1};X_{i}\right)-\nabla_{W^{Q}}f\left(M_{t};X_{i}\right)\right\|_{F}.

\displaystyle\left\|\nabla_{W^{Q}}f(M_{t+1};X_{i})-\nabla_{W^{Q}}f(M_{t};X_{i})\right\|_{F}
\displaystyle\stackrel{{\scriptstyle(i)}}{{\leq}}n\sqrt{H}\left(\left\|R_{i,t+1}-R_{i,t}\right\|_{F}+\left\|R_{i,t}\right\|_{F}\left\|S_{i,t+1}-S_{i,t}\right\|_{F}+\bar{R}\left\|S_{i,t+1}-S_{i,t}\right\|_{F}\right)
(ii)nHPiMt+1MtF+2R¯nHSi,t+1Si,tF\displaystyle\stackrel{{\scriptstyle(ii)}}{{\leq}}n\sqrt{H}P_{i}\|M_{t+1}-M_{t}\|_{F}+2\bar{R}n\sqrt{H}\|S_{i,t+1}-S_{i,t}\|_{F}
(iii)nHPiMt+1MtF+2R¯nHϕi2+ψi2Mt+1MtF\displaystyle\stackrel{{\scriptstyle(iii)}}{{\leq}}n\sqrt{H}P_{i}\|M_{t+1}-M_{t}\|_{F}+2\bar{R}n\sqrt{H}\sqrt{\phi_{i}^{2}+\psi_{i}^{2}}\|M_{t+1}-M_{t}\|_{F}
:=LiQMt+1MtF,\displaystyle:=L^{Q}_{i}\|M_{t+1}-M_{t}\|_{F},

where (i) is from Equation (96); (ii) plugs in Equation (97) and the bound \|R_{i,t}\|_{F}\leq\bar{R} from Equation (95); (iii) comes from Lemma 3 (3). Since W^{Q} and W^{K} play symmetric roles in the Transformer structure, we can similarly derive L^{K}_{i}.
Step 2: In this step, we aim to derive a bound for \left\|\nabla_{W^{V}}f\left(M_{t+1};X_{i}\right)-\nabla_{W^{V}}f\left(M_{t};X_{i}\right)\right\|_{F}.

WVf(Mt+1;Xi)WVf(Mt;Xi)F\displaystyle\left\|\nabla_{W^{V}}f(M_{t+1};X_{i})-\nabla_{W^{V}}f(M_{t};X_{i})\right\|_{F}
\displaystyle\stackrel{{\scriptstyle(i)}}{{=}}\left\|B_{i,t+1}^{\top}\left({\sf MH}\left(M_{t+1};X_{i}\right)-y_{i}\right)\left(W^{O}\right)^{\top}-B_{i,t}^{\top}\left({\sf MH}\left(M_{t};X_{i}\right)-y_{i}\right)\left(W^{O}\right)^{\top}\right\|_{F}
(ii)Bi,t+1(𝖬𝖧(Mt+1;Xi)yi)(WO)Bi,t+1(𝖬𝖧(Mt;Xi)yi)(WO)F\displaystyle\stackrel{{\scriptstyle(ii)}}{{\leq}}\left\|B_{i,t+1}^{\top}\left({\sf MH}\left(M_{t+1};X_{i}\right)-y_{i}\right)\left(W^{O}\right)^{\top}-B_{i,t+1}^{\top}\left({\sf MH}\left(M_{t};X_{i}\right)-y_{i}\right)\left(W^{O}\right)^{\top}\right\|_{F}
+Bi,t+1(𝖬𝖧(Mt;Xi)yi)(WO)Bi,t(𝖬𝖧(Mt;Xi)yi)(WO)F\displaystyle\quad+\left\|B_{i,t+1}^{\top}\left({\sf MH}\left(M_{t};X_{i}\right)-y_{i}\right)\left(W^{O}\right)^{\top}-B_{i,t}^{\top}\left({\sf MH}\left(M_{t};X_{i}\right)-y_{i}\right)\left(W^{O}\right)^{\top}\right\|_{F}
\displaystyle\stackrel{{\scriptstyle(iii)}}{{\leq}}\|B_{i,t+1}\|_{F}\left\|{\sf MH}\left(M_{t+1};X_{i}\right)-{\sf MH}\left(M_{t};X_{i}\right)\right\|_{F}\|W^{O}\|_{2}+\|B_{i,t+1}-B_{i,t}\|_{F}\|{\sf MH}\left(M_{t};X_{i}\right)-y_{i}\|_{F}\|W^{O}\|_{2}
(iv)nHXiFWO2ZiMt+1MtF+Si,t+1Si,tFXiFWO2(𝖬𝖧(Mt+1;Xi)F+yi2)\displaystyle\stackrel{{\scriptstyle(iv)}}{{\leq}}n\sqrt{H}\|X_{i}\|_{F}\|W^{O}\|_{2}Z_{i}\|M_{t+1}-M_{t}\|_{F}+\left\|S_{i,t+1}-S_{i,t}\right\|_{F}\|X_{i}\|_{F}\|W^{O}\|_{2}\left(\left\|{\sf MH}\left(M_{t+1};X_{i}\right)\right\|_{F}+\left\|y_{i}\right\|_{2}\right)
(v)ϕi2+ψi2XiFWO2(nHσmax(WV)XiFWO2+yi2)Mt+1MtF\displaystyle\stackrel{{\scriptstyle(v)}}{{\leq}}\sqrt{\phi_{i}^{2}+\psi_{i}^{2}}\left\|X_{i}\right\|_{F}\left\|W^{O}\right\|_{2}\left(n\sqrt{H}\sigma_{\max}\left(W^{V}\right)\left\|X_{i}\right\|_{F}\left\|W^{O}\right\|_{2}+\left\|y_{i}\right\|_{2}\right)\|M_{t+1}-M_{t}\|_{F}
+nHWO2XiFZiMt+1MtF\displaystyle\quad+n\sqrt{H}\left\|W^{O}\right\|_{2}\|X_{i}\|_{F}Z_{i}\left\|M_{t+1}-M_{t}\right\|_{F}
\displaystyle\stackrel{{\scriptstyle(vi)}}{{\leq}}\left(\sqrt{\phi_{i}^{2}+\psi_{i}^{2}}\left\|X_{i}\right\|_{F}\left\|W^{O}\right\|_{2}\left(n\sqrt{H}\sigma_{\max}\left(W^{V}\right)\left\|X_{i}\right\|_{F}\left\|W^{O}\right\|_{2}+\left\|y_{i}\right\|_{2}\right)+n\sqrt{H}\left\|X_{i}\right\|_{F}\left\|W^{O}\right\|_{2}Z_{i}\right)\|M_{t+1}-M_{t}\|_{F}
:=LiVMt+1MtF\displaystyle:=L_{i}^{V}\|M_{t+1}-M_{t}\|_{F}

where (i) is from Lemma 1 (1); (ii) uses the triangle inequality; (iii) uses the Cauchy-Schwarz inequality; (iv) comes from the definitions of B_{i,t} and of Z_{i} (in Equation (89)), the Cauchy-Schwarz inequality, and the triangle inequality; (v) comes from Lemma 2 (3) and the Cauchy-Schwarz inequality; (vi) reorganizes (v).

Combining the results of Steps 1 and 2 and plugging them into Equation (93), we can finally derive

Wf(Mt+1;X)Wf(Mt;X)Fi=1N(LiQ+LiK+LiV)Mt+1MtF\displaystyle\left\|\nabla_{W}f\left(M_{t+1};X\right)-\nabla_{W}f\left(M_{t};X\right)\right\|_{F}\leq\sum\limits_{i=1}^{N}(L_{i}^{Q}+L_{i}^{K}+L_{i}^{V})\|M_{t+1}-M_{t}\|_{F}
N(maxiLiQ+maxiLiK+maxiLiV)Mt+1MtF\displaystyle\leq N(\max\limits_{i}L_{i}^{Q}+\max\limits_{i}L_{i}^{K}+\max\limits_{i}L_{i}^{V})\|M_{t+1}-M_{t}\|_{F} (98)
:=GMt+1MtF.\displaystyle:=G\|M_{t+1}-M_{t}\|_{F}. (99)

1.7 Proof of Lemma in Section 1.4

Proof.

Proof of Lemma 7 (1): We consider the differential of the element in the k-th row and j-th column. First, let us write down the closed form of each element:

(Cih)kj=XikWhQXijWhK2/2d\displaystyle(C_{ih})_{kj}=-\|X_{ik\cdot}W_{h}^{Q}-X_{ij\cdot}W_{h}^{K}\|^{2}/2\sqrt{d}

Next, we consider the differential of each element over WhQW^{Q}_{h}:

d(Cih)kj=12d(Xik(WhQ+d(WhQ))XijWhK2Xik(WhQ)XijWhK2)\displaystyle d\left(C_{ih}\right)_{kj}=-\frac{1}{2\sqrt{d}}\left(\left\|X_{ik\cdot}\big{(}W^{Q}_{h}+d(W_{h}^{Q})\big{)}-X_{ij\cdot}W_{h}^{K}\right\|^{2}-\left\|X_{ik\cdot}(W^{Q}_{h})-X_{ij\cdot}W_{h}^{K}\right\|^{2}\right)
=1dXikd(WhQ),XikWhQXijWhK+o(d(WhQ)),\displaystyle=-\frac{1}{\sqrt{d}}\langle X_{ik\cdot}d(W_{h}^{Q}),X_{ik\cdot}W_{h}^{Q}-X_{ij\cdot}W_{h}^{K}\rangle+o\big{(}d(W_{h}^{Q})\big{)},

where o(d(W^{Q}_{h})) denotes higher-order terms in d(W^{Q}_{h}). Dropping the higher-order term, we derive

d(Cih)kjF1d(Xik2d(WhQ)Fσmax(WhQ)Xik2+d(WhQ)Fσmax(WhK)Xik2Xij2)\displaystyle\|d\left(C_{ih}\right)_{kj}\|_{F}\leq\frac{1}{\sqrt{d}}\left(\|X_{ik\cdot}\|_{2}\|d(W_{h}^{Q})\|_{F}\cdot\sigma_{\max}(W_{h}^{Q})\|X_{ik\cdot}\|_{2}+\|d(W_{h}^{Q})\|_{F}\cdot\sigma_{\max}(W_{h}^{K})\|X_{ik\cdot}\|_{2}\|X_{ij\cdot}\|_{2}\right)
1dXik2d(WhQ)F(σmax(WhQ)Xik2+σmax(WhK)Xij2)\displaystyle\leq\frac{1}{\sqrt{d}}\|X_{ik\cdot}\|_{2}\|d(W_{h}^{Q})\|_{F}(\sigma_{\max}(W_{h}^{Q})\|X_{ik\cdot}\|_{2}+\sigma_{\max}(W_{h}^{K})\|X_{ij\cdot}\|_{2})
1dXik2σmax2(WhQ)+σmax2(WhK)Xik22+Xij22d(WhQ)F\displaystyle\leq\frac{1}{\sqrt{d}}\|X_{ik\cdot}\|_{2}\sqrt{\sigma^{2}_{\max}(W_{h}^{Q})+\sigma^{2}_{\max}(W_{h}^{K})}\cdot\sqrt{\|X_{ik\cdot}\|_{2}^{2}+\|X_{ij\cdot}\|_{2}^{2}}\|d(W_{h}^{Q})\|_{F}
\displaystyle\left\|d\left(C_{ih}\right)\right\|_{F}=\sqrt{\sum\limits_{k=1}^{n}\sum\limits_{j=1}^{n}|d(C_{ih})_{kj}|^{2}}
\displaystyle\leq\frac{1}{\sqrt{d}}\sqrt{\sigma_{\max}^{2}\left(W_{h}^{Q}\right)+\sigma_{\max}^{2}\left(W_{h}^{K}\right)}\cdot\sqrt{\sum\limits_{k=1}^{n}\sum\limits_{j=1}^{n}\|X_{ik\cdot}\|_{2}^{2}\left(\|X_{ik\cdot}\|_{2}^{2}+\|X_{ij\cdot}\|_{2}^{2}\right)}\;\|d(W_{h}^{Q})\|_{F}
\displaystyle=\frac{1}{\sqrt{d}}\sqrt{\sigma_{\max}^{2}\left(W_{h}^{Q}\right)+\sigma_{\max}^{2}\left(W_{h}^{K}\right)}\cdot\sqrt{n\sum\limits_{k=1}^{n}\|X_{ik\cdot}\|_{2}^{4}+\|X_{i}\|_{F}^{4}}\;\|d(W_{h}^{Q})\|_{F}
\displaystyle\leq\frac{1}{\sqrt{d}}\sqrt{\sigma_{\max}^{2}\left(W_{h}^{Q}\right)+\sigma_{\max}^{2}\left(W_{h}^{K}\right)}\cdot\sqrt{2n}\,\|X_{i}\|_{F}^{2}\,\|d(W_{h}^{Q})\|_{F}
\displaystyle=\sqrt{\frac{2n}{d}}\|X_{i}\|_{F}^{2}\sqrt{\sigma_{\max}^{2}\left(W_{h}^{Q}\right)+\sigma_{\max}^{2}\left(W_{h}^{K}\right)}\|d(W_{h}^{Q})\|_{F}
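The resulting bound can be probed with a small finite-difference experiment; the sizes and the perturbation scale below are illustrative assumptions.

import numpy as np

rng = np.random.default_rng(7)
n, D, d = 5, 8, 4
Xi = rng.normal(size=(n, D))
WQ, WK = rng.normal(size=(D, d)), rng.normal(size=(D, d))

def C_of(WQm):
    Q, K = Xi @ WQm, Xi @ WK
    return -((Q[:, None, :] - K[None, :, :]) ** 2).sum(-1) / (2 * np.sqrt(d))

Delta = 1e-5 * rng.normal(size=(D, d))                   # small perturbation of W_h^Q
lhs = np.linalg.norm(C_of(WQ + Delta) - C_of(WQ))
sQ = np.linalg.svd(WQ, compute_uv=False)[0]
sK = np.linalg.svd(WK, compute_uv=False)[0]
rhs = np.sqrt(2 * n / d) * np.linalg.norm(Xi) ** 2 * np.sqrt(sQ ** 2 + sK ** 2) * np.linalg.norm(Delta)
assert lhs <= rhs                                        # the bound of Lemma 7 (1)
print(lhs, rhs)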

Proof.

Proof of Lemma 7 (2): First, let us write down the closed form of (Cih)kjWhQ\frac{\partial(C_{ih})_{kj}}{\partial W_{h}^{Q}}. We have

\displaystyle\frac{\partial(C_{ih})_{kj}}{\partial W_{h}^{Q}}=-\frac{1}{\sqrt{d}}\left(X_{ik\cdot}W_{h}^{Q}-X_{ij\cdot}W_{h}^{K}\right)\mathbb{I}_{d}\otimes X_{ik\cdot} (100)

Thus, we can derive upper bound for d((Cih)kjWhQ)F\left\|d\left(\frac{\partial\left(C_{ih}\right)_{kj}}{\partial W_{h}^{Q}}\right)\right\|_{F}:

d((Cih)kjWhQ)F=(Xik(WhQ+d(WhQ))XijWhK)𝕀dXik+(XikWhQXijWhK)𝕀dXikF/d\displaystyle\left\|d\left(\frac{\partial\left(C_{ih}\right)_{kj}}{\partial W_{h}^{Q}}\right)\right\|_{F}=\left\|-\left(X_{ik\cdot}(W_{h}^{Q}+d(W_{h}^{Q}))-X_{ij\cdot}W_{h}^{K}\right)\mathbb{I}_{d}\otimes X_{ik}+\left(X_{ik\cdot}W_{h}^{Q}-X_{ij\cdot}W_{h}^{K}\right)\mathbb{I}_{d}\otimes X_{ik\cdot}\right\|_{F}/\sqrt{d}
=Xikd(WhQ)𝕀dXikF/d\displaystyle=\|X_{ik\cdot}d(W_{h}^{Q})\mathbb{I}_{d}\otimes X_{ik\cdot}\|_{F}/\sqrt{d}
Xik22𝕀dFd(WhQ)F/d\displaystyle\leq\|X_{ik\cdot}\|^{2}_{2}\|\mathbb{I}_{d}\|_{F}\|d(W_{h}^{Q})\|_{F}/\sqrt{d}
=Xik22d(WhQ)F\displaystyle=\|X_{ik\cdot}\|^{2}_{2}\|d(W_{h}^{Q})\|_{F} (101)

Thus, we have the following:

d((Cih)WhQ)Fk=1nj=1nd((Cih)kjWhQ)F\displaystyle\left\|d\left(\frac{\partial\left(C_{ih}\right)}{\partial W_{h}^{Q}}\right)\right\|_{F}\leq\sum\limits_{k=1}^{n}\sum\limits_{j=1}^{n}\left\|d\left(\frac{\partial\left(C_{ih}\right)_{kj}}{\partial W_{h}^{Q}}\right)\right\|_{F}
d(WhQ)Fk=1nj=1nXik22\displaystyle\leq\|d(W^{Q}_{h})\|_{F}\sum\limits_{k=1}^{n}\sum\limits_{j=1}^{n}\|X_{ik\cdot}\|_{2}^{2}
nXiF2d(WhQ)F\displaystyle\leq n\|X_{i}\|_{F}^{2}\|d(W^{Q}_{h})\|_{F}

Proof.

Proof of Lemma 7 (3):

f(M;Xi)CiF=((𝖬𝖧(M;Xi)yi)(WO)(Vi))SiF\displaystyle\left\|\frac{\partial f\left(M;X_{i}\right)}{\partial C_{i}}\right\|_{F}=\left\|\left(\left({\sf MH}\left(M;X_{i}\right)-y_{i}\right)\left(W^{O}\right)^{\top}\left(V_{i}^{\prime}\right)^{\top}\right)\odot S_{i}\right\|_{F}
((𝖬𝖧(M;Xi)yi)(WO)(Vi))Fmin|Si|\displaystyle\geq\left\|\left(\left({\sf MH}\left(M;X_{i}\right)-y_{i}\right)\left(W^{O}\right)^{\top}\left(V_{i}^{\prime}\right)^{\top}\right)\right\|_{F}\cdot\min|S_{i}|
min|ViWO|min|Si|𝖬𝖧(M;Xi)yi2.\displaystyle\geq\min|V_{i}^{\prime}W^{O}|\cdot\min|S_{i}|\cdot\|{\sf MH}\left(M;X_{i}\right)-y_{i}\|_{2}.

Proof.

Proof of Lemma 7 (4):

\displaystyle\left\|\frac{\partial f(M;X_{i})}{\partial W^{Q}_{h}}\right\|_{F}=\left\|\operatorname{vec}\left(\frac{\partial f(M;X_{i})}{\partial W^{Q}_{h}}\right)\right\|_{2}=\left\|\operatorname{vec}\left(\frac{\partial f(M;X_{i})}{\partial C_{i}}\right)\cdot\frac{\partial C_{i}}{\partial W_{h}^{Q}}\right\|_{2}
\displaystyle\leq\left\|\frac{\partial f(M;X_{i})}{\partial C_{i}}\right\|_{F}\cdot\left\|\frac{\partial C_{i}}{\partial W_{h}^{Q}}\right\|_{2}
\displaystyle=\left\|\left(\left({\sf MH}\left(M;X_{i}\right)-y_{i}\right)\left(W^{O}\right)^{\top}\left(V_{i}^{\prime}\right)^{\top}\right)\odot S_{i}\right\|_{F}\cdot\sqrt{\frac{2n}{d}}\left\|X_{i}\right\|_{F}^{2}\sqrt{\sigma_{\max}^{2}\left(W_{h}^{Q}\right)+\sigma_{\max}^{2}\left(W_{h}^{K}\right)}
\displaystyle\leq\sqrt{\frac{2n}{d}}\left\|X_{i}\right\|_{F}^{3}\|W^{O}\|_{2}\sigma_{\max}(W^{V})\sqrt{\sigma_{\max}^{2}\left(W_{h}^{Q}\right)+\sigma_{\max}^{2}\left(W_{h}^{K}\right)}\left\|{\sf MH}\left(M;X_{i}\right)-y_{i}\right\|_{2}
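The key step in this chain is the generic bound $\|\operatorname{vec}(A)^{\top}J\|_{2}\leq\|A\|_{F}\,\|J\|_{2}$: a row vector times a matrix is controlled by the vector's Euclidean norm (here the Frobenius norm of $\partial f/\partial C_{i}$) times the spectral norm of the matrix (here the Jacobian $\partial C_{i}/\partial W_{h}^{Q}$). A hedged numerical illustration with random placeholders of arbitrary sizes:

    import numpy as np

    rng = np.random.default_rng(2)

    for _ in range(1000):
        A = rng.standard_normal((4, 6))    # stands in for df/dC_i
        J = rng.standard_normal((24, 10))  # stands in for the Jacobian dC_i/dW_h^Q

        lhs = np.linalg.norm(A.reshape(-1) @ J)                  # ||vec(A)^T J||_2
        rhs = np.linalg.norm(A, "fro") * np.linalg.norm(J, 2)    # ||A||_F * ||J||_2 (spectral norm)

        assert lhs <= rhs + 1e-9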

Proof.

Proof of Lemma 8 (1): The proof is similar to that of Lemma 3 (1), so we omit the details here. We can similarly derive

\displaystyle\begin{aligned} &\left\|f\left(M_{t+1};X\right)-f\left(M_{t};X\right)\right\|_{2}\leq N\sqrt{\max_{i}Q_{i}^{\prime 2}+\max_{i}K_{i}^{\prime 2}+n^{2}H\max_{i}\left\|X_{i}\right\|_{F}^{2}\|W^{O}\|_{2}^{2}}\left\|M_{t+1}-M_{t}\right\|_{F}\\ &:=Z^{\prime}\left\|M_{t+1}-M_{t}\right\|_{F}\end{aligned} (102)
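Equation (102) asserts that the stacked network outputs are Lipschitz in the weights along the GD trajectory. To build intuition, one can probe this property empirically on a toy layer. The sketch below assumes a standard single-head softmax attention map $f(M;X)=\operatorname{softmax}\big(XW^{Q}(XW^{K})^{\top}/\sqrt{d}\big)XW^{V}W^{O}$, which is our own simplified stand-in rather than the exact multi-head model analyzed above; it only checks that the empirical ratio $\|f(M_{t+1};X)-f(M_{t};X)\|/\|M_{t+1}-M_{t}\|_{F}$ stays bounded and does not compute the constant $Z^{\prime}$:

    import numpy as np

    rng = np.random.default_rng(3)
    n, d = 6, 8  # toy sequence length and embedding dimension

    def softmax(z):
        z = z - z.max(axis=-1, keepdims=True)
        e = np.exp(z)
        return e / e.sum(axis=-1, keepdims=True)

    def attn_out(X, WQ, WK, WV, WO):
        # single-head softmax attention followed by a scalar output projection
        S = softmax(X @ WQ @ (X @ WK).T / np.sqrt(d))
        return S @ X @ WV @ WO

    X = rng.standard_normal((n, d))
    WQ, WK, WV = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(3))
    WO = rng.standard_normal((d, 1)) / np.sqrt(d)

    ratios = []
    for _ in range(200):
        # small random perturbation of (W^Q, W^K, W^V), mimicking one GD step M_t -> M_{t+1}
        dQ, dK, dV = (1e-3 * rng.standard_normal((d, d)) for _ in range(3))
        step = np.sqrt(sum(np.linalg.norm(D, "fro") ** 2 for D in (dQ, dK, dV)))
        diff = attn_out(X, WQ + dQ, WK + dK, WV + dV, WO) - attn_out(X, WQ, WK, WV, WO)
        ratios.append(np.linalg.norm(diff) / step)

    print("largest empirical Lipschitz ratio:", max(ratios))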

Proof.

Proof of Lemma 8 (2): By the triangle inequality, we have

\displaystyle\left\|\nabla_{W}f(M_{t+1};X)-\nabla_{W}f(M_{t};X)\right\|_{F}
\displaystyle\leq\left\|\nabla_{W^{Q}}f(M_{t+1};X)-\nabla_{W^{Q}}f(M_{t};X)\right\|_{F}+\left\|\nabla_{W^{K}}f(M_{t+1};X)-\nabla_{W^{K}}f(M_{t};X)\right\|_{F}
\displaystyle\quad+\left\|\nabla_{W^{V}}f(M_{t+1};X)-\nabla_{W^{V}}f(M_{t};X)\right\|_{F}
\displaystyle\leq\sum\limits_{i=1}^{N}\big{(}\left\|\nabla_{W^{Q}}f\left(M_{t+1};X_{i}\right)-\nabla_{W^{Q}}f\left(M_{t};X_{i}\right)\right\|_{F}+\left\|\nabla_{W^{K}}f\left(M_{t+1};X_{i}\right)-\nabla_{W^{K}}f\left(M_{t};X_{i}\right)\right\|_{F}
\displaystyle\quad+\left\|\nabla_{W^{V}}f\left(M_{t+1};X_{i}\right)-\nabla_{W^{V}}f\left(M_{t};X_{i}\right)\right\|_{F}\big{)} (103)

Step 1: Derive an upper bound for

\|\nabla_{W^{Q}}f(M_{t+1};X_{i})-\nabla_{W^{Q}}f(M_{t};X_{i})\|_{F}=\|\operatorname{vec}(\nabla_{W^{Q}}f(M_{t+1};X_{i}))-\operatorname{vec}(\nabla_{W^{Q}}f(M_{t};X_{i}))\|_{2}.

First, we give the vectorized expression of $\nabla_{W^{Q}}f\left(M_{t};X_{i}\right)$. Recall that we denote $U_{i}=\left(\left({\sf MH}\left(M;X_{i}\right)-y_{i}\right)\left(W^{O}\right)^{\top}\left(V_{i}^{\prime}\right)^{\top}\right)\odot S_{i}$. By Lemma 6, we can derive the closed form of $\operatorname{vec}(\nabla_{W^{Q}}f(M_{t};X_{i}))$:

\displaystyle\operatorname{vec}(\nabla_{W^{Q}}f(M;X_{i}))\stackrel{{\scriptstyle(i)}}{{=}}\operatorname{vec}(U_{i})\cdot\frac{\partial C_{i}}{\partial W_{h}^{Q}} (104)

Further, recall that we have defined $\bar{R}$ such that the following inequality holds:

\displaystyle\|U_{i}\|_{F}\leq\bar{R} (105)

Next, let us derive an upper bound for $\left\|\nabla_{W^{Q}}f\left(M_{t+1};X_{i}\right)-\nabla_{W^{Q}}f\left(M_{t};X_{i}\right)\right\|_{F}$.

\displaystyle\left\|\nabla_{W^{Q}}f(M_{t+1};X_{i})-\nabla_{W^{Q}}f(M_{t};X_{i})\right\|_{F}
\displaystyle\stackrel{{\scriptstyle(i)}}{{=}}\left\|\operatorname{vec}(U_{i,t+1})\cdot\left(\frac{\partial C_{i}(M_{t+1})}{\partial W_{h}^{Q}}\right)-\operatorname{vec}(U_{i,t})\cdot\left(\frac{\partial C_{i}(M_{t})}{\partial W_{h}^{Q}}\right)\right\|_{F}
\displaystyle=\Bigg{\|}\operatorname{vec}(U_{i,t+1})\cdot\left(\frac{\partial C_{i}(M_{t+1})}{\partial W_{h}^{Q}}\right)-\operatorname{vec}(U_{i,t+1})\cdot\left(\frac{\partial C_{i}(M_{t})}{\partial W_{h}^{Q}}\right) (106)
\displaystyle\quad+\operatorname{vec}(U_{i,t+1})\cdot\left(\frac{\partial C_{i}(M_{t})}{\partial W_{h}^{Q}}\right)-\operatorname{vec}(U_{i,t})\cdot\left(\frac{\partial C_{i}(M_{t})}{\partial W_{h}^{Q}}\right)\Bigg{\|}_{F}
\displaystyle\stackrel{{\scriptstyle(ii)}}{{\leq}}\|\operatorname{vec}(U_{i,t+1})\|_{2}\left\|\frac{\partial C_{i}\left(M_{t+1}\right)}{\partial W_{h}^{Q}}-\frac{\partial C_{i}\left(M_{t}\right)}{\partial W_{h}^{Q}}\right\|_{2}+\|\operatorname{vec}(U_{i,t+1}-U_{i,t})\|_{2}\left\|\frac{\partial C_{i}\left(M_{t}\right)}{\partial W_{h}^{Q}}\right\|_{2}
\displaystyle\stackrel{{\scriptstyle(iii)}}{{\leq}}\bar{R}\sqrt{n}\left\|X_{i}\right\|_{F}^{2}\cdot\left\|d\left(W_{h}^{Q}\right)\right\|_{F}+\left\|\operatorname{vec}\left(U_{i,t+1}-U_{i,t}\right)\right\|_{F}\cdot\sqrt{n}\left\|X_{i}\right\|_{F}^{2}\cdot\left(\sigma_{\max}\left(W_{h}^{Q}\right)+\sigma_{\max}\left(W_{h}^{K}\right)\right)
\displaystyle\leq\bar{R}\sqrt{n}\left\|X_{i}\right\|_{F}^{2}\cdot\left\|d\left(W_{h}^{Q}\right)\right\|_{F}+\sqrt{n}\left\|X_{i}\right\|_{F}^{2}\cdot\left(\sigma_{\max}\left(W_{h}^{Q}\right)+\sigma_{\max}\left(W_{h}^{K}\right)\right)\left\|R_{i,t+1}\odot S_{i,t+1}-R_{i,t}\odot S_{i,t}\right\|_{F}
\displaystyle\leq\bar{R}\sqrt{n}\left\|X_{i}\right\|_{F}^{2}\cdot\left\|d\left(W_{h}^{Q}\right)\right\|_{F}+\sqrt{n}\left\|X_{i}\right\|_{F}^{2}\cdot\left(\sigma_{\max}\left(W_{h}^{Q}\right)+\sigma_{\max}\left(W_{h}^{K}\right)\right)
\displaystyle\quad\times\big{(}\left\|\left(R_{i,t+1}-R_{i,t}\right)\odot S_{i,t+1}\right\|_{F}+\left\|R_{i,t}\odot S_{i,t+1}-R_{i,t}\odot S_{i,t}\right\|_{F}\big{)}
\displaystyle\leq\bar{R}\sqrt{n}\left\|X_{i}\right\|_{F}^{2}\cdot\left\|d\left(W_{h}^{Q}\right)\right\|_{F}+\sqrt{n}\left\|X_{i}\right\|_{F}^{2}\cdot\left(\sigma_{\max}\left(W_{h}^{Q}\right)+\sigma_{\max}\left(W_{h}^{K}\right)\right)
\displaystyle\quad\times(\left\|R_{i,t+1}-R_{i,t}\right\|_{F}+\left\|R_{i,t}\right\|_{F}\left\|S_{i,t+1}-S_{i,t}\right\|_{F}) (107)
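The last two steps above add and subtract $R_{i,t}\odot S_{i,t+1}$ and then drop the factor $\max_{k,j}|(S_{i,t+1})_{kj}|$, implicitly using that the entries of $S_{i}$ are bounded by one in this setting. The following is a hedged numerical check of the resulting splitting inequality, with random matrices for the $R$'s and row-stochastic (softmax) matrices standing in for the $S$'s:

    import numpy as np

    rng = np.random.default_rng(4)

    def softmax_rows(Z):
        Z = Z - Z.max(axis=1, keepdims=True)
        E = np.exp(Z)
        return E / E.sum(axis=1, keepdims=True)

    for _ in range(1000):
        R0, R1 = rng.standard_normal((2, 5, 5))          # stand in for R_{i,t}, R_{i,t+1}
        S0 = softmax_rows(rng.standard_normal((5, 5)))   # entries in [0, 1], as the step assumes
        S1 = softmax_rows(rng.standard_normal((5, 5)))

        lhs = np.linalg.norm(R1 * S1 - R0 * S0, "fro")
        rhs = (np.linalg.norm(R1 - R0, "fro")
               + np.linalg.norm(R0, "fro") * np.linalg.norm(S1 - S0, "fro"))

        assert lhs <= rhs + 1e-9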

Next, we aim to derive an upper bound on $\left\|R_{i,t+1}-R_{i,t}\right\|_{F}$ in Equation (107). Similarly to the derivation of Equation (96), we can derive

\displaystyle\|R_{i,t+1}-R_{i,t}\|_{F}\stackrel{{\scriptstyle(iv)}}{{\leq}}\big{(}Z_{i}^{\prime}\left\|X_{i}\right\|_{F}\sigma_{\max}\left(W^{V}\right)\left\|W^{O}\right\|_{2}+(n\sqrt{H}\sigma_{\max}\left(W^{V}\right)\left\|X_{i}\right\|_{F}\left\|W^{O}\right\|_{2}+\left\|y_{i}\right\|_{2})\left\|X_{i}\right\|_{F}\|W^{O}\|_{2}\big{)}
\displaystyle\quad\times\|M_{t+1}-M_{t}\|_{F}
\displaystyle:=P_{i}^{\prime}\|M_{t+1}-M_{t}\|_{F}. (108)

Plugging Equations (97) and (108) into Equation (107), we can finally derive the bound for $\left\|\nabla_{W^{Q}}f\left(M_{t+1};X_{i}\right)-\nabla_{W^{Q}}f\left(M_{t};X_{i}\right)\right\|_{F}$:

\displaystyle\left\|\nabla_{W^{Q}}f(M_{t+1};X_{i})-\nabla_{W^{Q}}f(M_{t};X_{i})\right\|_{F}
\displaystyle\leq\bar{R}\sqrt{n}\left\|X_{i}\right\|_{F}^{2}\cdot\left\|d\left(W_{h}^{Q}\right)\right\|_{F}+\sqrt{n}\left\|X_{i}\right\|_{F}^{2}\cdot\left(\sigma_{\max}\left(W_{h}^{Q}\right)+\sigma_{\max}\left(W_{h}^{K}\right)\right)
\displaystyle\quad\times(\left\|R_{i,t+1}-R_{i,t}\right\|_{F}+\left\|R_{i,t}\right\|_{F}\left\|S_{i,t+1}-S_{i,t}\right\|_{F})
\displaystyle\leq\bar{R}\sqrt{n}\left\|X_{i}\right\|_{F}^{2}\cdot\left\|M_{t+1}-M_{t}\right\|_{F}+\sqrt{n}\left\|X_{i}\right\|_{F}^{2}\cdot\left(\sigma_{\max}\left(W_{h}^{Q}\right)+\sigma_{\max}\left(W_{h}^{K}\right)\right)
\displaystyle\quad\times\left(P_{i}^{\prime}\|M_{t+1}-M_{t}\|_{F}+\sqrt{n}\bar{R}\left\|X_{i}\right\|_{F}^{2}\cdot\left(\sigma_{\max}\left(W_{h}^{Q}\right)+\sigma_{\max}\left(W_{h}^{K}\right)\right)\|M_{t+1}-M_{t}\|_{F}\right)
\displaystyle:=L^{Q^{\prime}}_{i}\|M_{t+1}-M_{t}\|_{F},

and, plugging this into Equation (103), we can finally derive

\displaystyle\left\|\nabla_{W}f\left(M_{t+1};X\right)-\nabla_{W}f\left(M_{t};X\right)\right\|_{F}\leq\sum\limits_{i=1}^{N}(L_{i}^{Q^{\prime}}+L_{i}^{K^{\prime}}+L_{i}^{V^{\prime}})\|M_{t+1}-M_{t}\|_{F}
\displaystyle\leq N(\max\limits_{i}L_{i}^{Q^{\prime}}+\max\limits_{i}L_{i}^{K^{\prime}}+\max\limits_{i}L_{i}^{V^{\prime}})\|M_{t+1}-M_{t}\|_{F} (109)
\displaystyle:=G^{\prime}\|M_{t+1}-M_{t}\|_{F}. (110)
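The conclusion of Lemma 8 (2) is a smoothness-type statement: the gradient with respect to the weights changes at most proportionally to the size of the weight update. One can probe this empirically with finite-difference gradients of a toy single-head softmax attention layer under a squared loss. The sketch below is our own illustrative setup (not the paper's exact model, and it does not reproduce the constant $G^{\prime}$); it only checks that the empirical ratio $\|\nabla f(M_{t+1})-\nabla f(M_{t})\|/\|M_{t+1}-M_{t}\|_{F}$ remains bounded for small random steps:

    import numpy as np

    rng = np.random.default_rng(5)
    n, d = 4, 4  # tiny toy sizes so that numerical gradients stay cheap

    def softmax(z):
        z = z - z.max(axis=-1, keepdims=True)
        e = np.exp(z)
        return e / e.sum(axis=-1, keepdims=True)

    X = rng.standard_normal((n, d))
    WO = rng.standard_normal((d, 1)) / np.sqrt(d)
    y = rng.standard_normal((n, 1))

    def loss(theta):
        # theta stacks (W^Q, W^K, W^V); squared loss of a single-head softmax layer
        WQ, WK, WV = theta.reshape(3, d, d)
        out = softmax(X @ WQ @ (X @ WK).T / np.sqrt(d)) @ X @ WV @ WO
        return 0.5 * np.sum((out - y) ** 2)

    def num_grad(theta, eps=1e-5):
        # central finite-difference gradient of the loss with respect to theta
        g = np.zeros_like(theta)
        for i in range(theta.size):
            e = np.zeros_like(theta)
            e[i] = eps
            g[i] = (loss(theta + e) - loss(theta - e)) / (2 * eps)
        return g

    theta0 = rng.standard_normal(3 * d * d) / np.sqrt(d)
    ratios = []
    for _ in range(20):
        delta = 1e-3 * rng.standard_normal(theta0.size)  # mimics one GD step M_t -> M_{t+1}
        ratios.append(np.linalg.norm(num_grad(theta0 + delta) - num_grad(theta0))
                      / np.linalg.norm(delta))

    print("largest empirical smoothness ratio:", max(ratios))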

2 NeurIPS paper checklist

  1. Claims

    Question: Do the main claims made in the abstract and introduction accurately reflect the paper’s contributions and scope?

    Answer: [Yes]

    Justification: See Theorems 1, 2, and 3.

    Guidelines:

    • The answer NA means that the abstract and introduction do not include the claims made in the paper.

    • The abstract and/or introduction should clearly state the claims made, including the contributions made in the paper and important assumptions and limitations. A No or NA answer to this question will not be perceived well by the reviewers.

    • The claims made should match theoretical and experimental results, and reflect how much the results can be expected to generalize to other settings.

    • It is fine to include aspirational goals as motivation as long as it is clear that these goals are not attained by the paper.

  2. Limitations

    Question: Does the paper discuss the limitations of the work performed by the authors?

    Answer: [Yes]

    Justification: Please see our conclusion in Section 6.

    Guidelines:

    • The answer NA means that the paper has no limitation while the answer No means that the paper has limitations, but those are not discussed in the paper.

    • The authors are encouraged to create a separate “Limitations” section in their paper.

    • The paper should point out any strong assumptions and how robust the results are to violations of these assumptions (e.g., independence assumptions, noiseless settings, model well-specification, asymptotic approximations only holding locally). The authors should reflect on how these assumptions might be violated in practice and what the implications would be.

    • The authors should reflect on the scope of the claims made, e.g., if the approach was only tested on a few datasets or with a few runs. In general, empirical results often depend on implicit assumptions, which should be articulated.

    • The authors should reflect on the factors that influence the performance of the approach. For example, a facial recognition algorithm may perform poorly when image resolution is low or images are taken in low lighting. Or a speech-to-text system might not be used reliably to provide closed captions for online lectures because it fails to handle technical jargon.

    • The authors should discuss the computational efficiency of the proposed algorithms and how they scale with dataset size.

    • If applicable, the authors should discuss possible limitations of their approach to address problems of privacy and fairness.

    • While the authors might fear that complete honesty about limitations might be used by reviewers as grounds for rejection, a worse outcome might be that reviewers discover limitations that aren’t acknowledged in the paper. The authors should use their best judgment and recognize that individual actions in favor of transparency play an important role in developing norms that preserve the integrity of the community. Reviewers will be specifically instructed to not penalize honesty concerning limitations.

  3. Theory Assumptions and Proofs

    Question: For each theoretical result, does the paper provide the full set of assumptions and a complete (and correct) proof?

    Answer: [Yes]

    Justification: See the Appendix, which provides a proof for each theorem.

    Guidelines:

    • The answer NA means that the paper does not include theoretical results.

    • All the theorems, formulas, and proofs in the paper should be numbered and cross-referenced.

    • All assumptions should be clearly stated or referenced in the statement of any theorems.

    • The proofs can either appear in the main paper or the supplemental material, but if they appear in the supplemental material, the authors are encouraged to provide a short proof sketch to provide intuition.

    • Inversely, any informal proof provided in the core of the paper should be complemented by formal proofs provided in appendix or supplemental material.

    • Theorems and Lemmas that the proof relies upon should be properly referenced.

  4. Experimental Result Reproducibility

    Question: Does the paper fully disclose all the information needed to reproduce the main experimental results of the paper to the extent that it affects the main claims and/or conclusions of the paper (regardless of whether the code and data are provided or not)?

    Answer: [Yes]

    Justification: See the experimental setting in Section 5.2.

    Guidelines:

    • The answer NA means that the paper does not include experiments.

    • If the paper includes experiments, a No answer to this question will not be perceived well by the reviewers: Making the paper reproducible is important, regardless of whether the code and data are provided or not.

    • If the contribution is a dataset and/or model, the authors should describe the steps taken to make their results reproducible or verifiable.

    • Depending on the contribution, reproducibility can be accomplished in various ways. For example, if the contribution is a novel architecture, describing the architecture fully might suffice, or if the contribution is a specific model and empirical evaluation, it may be necessary to either make it possible for others to replicate the model with the same dataset, or provide access to the model. In general, releasing code and data is often one good way to accomplish this, but reproducibility can also be provided via detailed instructions for how to replicate the results, access to a hosted model (e.g., in the case of a large language model), releasing of a model checkpoint, or other means that are appropriate to the research performed.

    • While NeurIPS does not require releasing code, the conference does require all submissions to provide some reasonable avenue for reproducibility, which may depend on the nature of the contribution. For example

      (a) If the contribution is primarily a new algorithm, the paper should make it clear how to reproduce that algorithm.

      (b) If the contribution is primarily a new model architecture, the paper should describe the architecture clearly and fully.

      (c) If the contribution is a new model (e.g., a large language model), then there should either be a way to access this model for reproducing the results or a way to reproduce the model (e.g., with an open-source dataset or instructions for how to construct the dataset).

      (d) We recognize that reproducibility may be tricky in some cases, in which case authors are welcome to describe the particular way they provide for reproducibility. In the case of closed-source models, it may be that access to the model is limited in some way (e.g., to registered users), but it should be possible for other researchers to have some path to reproducing or verifying the results.

  5. Open access to data and code

    Question: Does the paper provide open access to the data and code, with sufficient instructions to faithfully reproduce the main experimental results, as described in supplemental material?

    Answer: [No]

    Justification: We do not provide open access to the code.

    Guidelines:

    • The answer NA means that the paper does not include experiments requiring code.

    • Please see the NeurIPS code and data submission guidelines (https://nips.cc/public/guides/CodeSubmissionPolicy) for more details.

    • While we encourage the release of code and data, we understand that this might not be possible, so “No” is an acceptable answer. Papers cannot be rejected simply for not including code, unless this is central to the contribution (e.g., for a new open-source benchmark).

    • The instructions should contain the exact command and environment needed to run to reproduce the results. See the NeurIPS code and data submission guidelines (https://nips.cc/public/guides/CodeSubmissionPolicy) for more details.

    • The authors should provide instructions on data access and preparation, including how to access the raw data, preprocessed data, intermediate data, and generated data, etc.

    • The authors should provide scripts to reproduce all experimental results for the new proposed method and baselines. If only a subset of experiments are reproducible, they should state which ones are omitted from the script and why.

    • At submission time, to preserve anonymity, the authors should release anonymized versions (if applicable).

    • Providing as much information as possible in supplemental material (appended to the paper) is recommended, but including URLs to data and code is permitted.

  6. Experimental Setting/Details

    Question: Does the paper specify all the training and test details (e.g., data splits, hyperparameters, how they were chosen, type of optimizer, etc.) necessary to understand the results?

    Answer: [Yes]

    Justification: See Section 5.2 for the experimental setting.

    Guidelines:

    • The answer NA means that the paper does not include experiments.

    • The experimental setting should be presented in the core of the paper to a level of detail that is necessary to appreciate the results and make sense of them.

    • The full details can be provided either with the code, in appendix, or as supplemental material.

  7. Experiment Statistical Significance

    Question: Does the paper report error bars suitably and correctly defined or other appropriate information about the statistical significance of the experiments?

    Answer: [Yes]

    Justification: Please see Figures 2 and 3; we report 1-σ error bars.

    Guidelines:

    • The answer NA means that the paper does not include experiments.

    • The authors should answer “Yes” if the results are accompanied by error bars, confidence intervals, or statistical significance tests, at least for the experiments that support the main claims of the paper.

    • The factors of variability that the error bars are capturing should be clearly stated (for example, train/test split, initialization, random drawing of some parameter, or overall run with given experimental conditions).

    • The method for calculating the error bars should be explained (closed form formula, call to a library function, bootstrap, etc.)

    • The assumptions made should be given (e.g., Normally distributed errors).

    • It should be clear whether the error bar is the standard deviation or the standard error of the mean.

    • It is OK to report 1-sigma error bars, but one should state it. The authors should preferably report a 2-sigma error bar rather than state that they have a 96% CI, if the hypothesis of Normality of errors is not verified.

    • For asymmetric distributions, the authors should be careful not to show in tables or figures symmetric error bars that would yield results that are out of range (e.g. negative error rates).

    • If error bars are reported in tables or plots, the authors should explain in the text how they were calculated and reference the corresponding figures or tables in the text.

  8. Experiments Compute Resources

    Question: For each experiment, does the paper provide sufficient information on the computer resources (type of compute workers, memory, time of execution) needed to reproduce the experiments?

    Answer: [No]

    Justification: We do not include details of the compute resources.

    Guidelines:

    • The answer NA means that the paper does not include experiments.

    • The paper should indicate the type of compute workers CPU or GPU, internal cluster, or cloud provider, including relevant memory and storage.

    • The paper should provide the amount of compute required for each of the individual experimental runs as well as estimate the total compute.

    • The paper should disclose whether the full research project required more compute than the experiments reported in the paper (e.g., preliminary or failed experiments that didn’t make it into the paper).

  9. Code Of Ethics

    Question: Does the research conducted in the paper conform, in every respect, with the NeurIPS Code of Ethics https://neurips.cc/public/EthicsGuidelines?

    Answer: [Yes]

    Justification: The research process causes no harm and has no negative societal impact, and the paper preserves anonymity.

    Guidelines:

    • The answer NA means that the authors have not reviewed the NeurIPS Code of Ethics.

    • If the authors answer No, they should explain the special circumstances that require a deviation from the Code of Ethics.

    • The authors should make sure to preserve anonymity (e.g., if there is a special consideration due to laws or regulations in their jurisdiction).

  10. Broader Impacts

    Question: Does the paper discuss both potential positive societal impacts and negative societal impacts of the work performed?

    Answer: [N/A]

    Justification: There is no societal impact of the work performed.

    Guidelines:

    • The answer NA means that there is no societal impact of the work performed.

    • If the authors answer NA or No, they should explain why their work has no societal impact or why the paper does not address societal impact.

    • Examples of negative societal impacts include potential malicious or unintended uses (e.g., disinformation, generating fake profiles, surveillance), fairness considerations (e.g., deployment of technologies that could make decisions that unfairly impact specific groups), privacy considerations, and security considerations.

    • The conference expects that many papers will be foundational research and not tied to particular applications, let alone deployments. However, if there is a direct path to any negative applications, the authors should point it out. For example, it is legitimate to point out that an improvement in the quality of generative models could be used to generate deepfakes for disinformation. On the other hand, it is not needed to point out that a generic algorithm for optimizing neural networks could enable people to train models that generate Deepfakes faster.

    • The authors should consider possible harms that could arise when the technology is being used as intended and functioning correctly, harms that could arise when the technology is being used as intended but gives incorrect results, and harms following from (intentional or unintentional) misuse of the technology.

    • If there are negative societal impacts, the authors could also discuss possible mitigation strategies (e.g., gated release of models, providing defenses in addition to attacks, mechanisms for monitoring misuse, mechanisms to monitor how a system learns from feedback over time, improving the efficiency and accessibility of ML).

  11. Safeguards

    Question: Does the paper describe safeguards that have been put in place for responsible release of data or models that have a high risk for misuse (e.g., pretrained language models, image generators, or scraped datasets)?

    Answer: [N/A]

    Justification: The paper poses no such risks.

    Guidelines:

    • The answer NA means that the paper poses no such risks.

    • Released models that have a high risk for misuse or dual-use should be released with necessary safeguards to allow for controlled use of the model, for example by requiring that users adhere to usage guidelines or restrictions to access the model or implementing safety filters.

    • Datasets that have been scraped from the Internet could pose safety risks. The authors should describe how they avoided releasing unsafe images.

    • We recognize that providing effective safeguards is challenging, and many papers do not require this, but we encourage authors to take this into account and make a best faith effort.

  12. Licenses for existing assets

    Question: Are the creators or original owners of assets (e.g., code, data, models), used in the paper, properly credited and are the license and terms of use explicitly mentioned and properly respected?

    Answer: [Yes]

    Justification: We have cited the code framework we use, Chen et al. [2021].

    Guidelines:

    • The answer NA means that the paper does not use existing assets.

    • The authors should cite the original paper that produced the code package or dataset.

    • The authors should state which version of the asset is used and, if possible, include a URL.

    • The name of the license (e.g., CC-BY 4.0) should be included for each asset.

    • For scraped data from a particular source (e.g., website), the copyright and terms of service of that source should be provided.

    • If assets are released, the license, copyright information, and terms of use in the package should be provided. For popular datasets, paperswithcode.com/datasets has curated licenses for some datasets. Their licensing guide can help determine the license of a dataset.

    • For existing datasets that are re-packaged, both the original license and the license of the derived asset (if it has changed) should be provided.

    • If this information is not available online, the authors are encouraged to reach out to the asset’s creators.

  13. New Assets

    Question: Are new assets introduced in the paper well documented and is the documentation provided alongside the assets?

    Answer: [N/A]

    Justification: The paper does not release new assets.

    Guidelines:

    • The answer NA means that the paper does not release new assets.

    • Researchers should communicate the details of the dataset/code/model as part of their submissions via structured templates. This includes details about training, license, limitations, etc.

    • The paper should discuss whether and how consent was obtained from people whose asset is used.

    • At submission time, remember to anonymize your assets (if applicable). You can either create an anonymized URL or include an anonymized zip file.

  14. Crowdsourcing and Research with Human Subjects

    Question: For crowdsourcing experiments and research with human subjects, does the paper include the full text of instructions given to participants and screenshots, if applicable, as well as details about compensation (if any)?

    Answer: [N/A]

    Justification: The paper does not involve crowdsourcing nor research with human subjects.

    Guidelines:

    • The answer NA means that the paper does not involve crowdsourcing nor research with human subjects.

    • Including this information in the supplemental material is fine, but if the main contribution of the paper involves human subjects, then as much detail as possible should be included in the main paper.

    • According to the NeurIPS Code of Ethics, workers involved in data collection, curation, or other labor should be paid at least the minimum wage in the country of the data collector.

  15. Institutional Review Board (IRB) Approvals or Equivalent for Research with Human Subjects

    Question: Does the paper describe potential risks incurred by study participants, whether such risks were disclosed to the subjects, and whether Institutional Review Board (IRB) approvals (or an equivalent approval/review based on the requirements of your country or institution) were obtained?

    Answer: [N/A]

    Justification: The paper does not involve crowdsourcing nor research with human subjects.

    Guidelines:

    • The answer NA means that the paper does not involve crowdsourcing nor research with human subjects.

    • Depending on the country in which research is conducted, IRB approval (or equivalent) may be required for any human subjects research. If you obtained IRB approval, you should clearly state this in the paper.

    • We recognize that the procedures for this may vary significantly between institutions and locations, and we expect authors to adhere to the NeurIPS Code of Ethics and the guidelines for their institution.

    • For initial submissions, do not include any information that would break anonymity (if applicable), such as the institution conducting the review.