
JoMA: Demystifying Multilayer Transformers via JOint Dynamics of MLP and Attention

Yuandong Tian
AI@Meta (FAIR)
[email protected]

Yiping Wang
University of Washington
[email protected]

Zhenyu Zhang
University of Texas at Austin
[email protected]

Beidi Chen
Carnegie Mellon University, AI@Meta (FAIR)
[email protected], [email protected]

Simon Du
University of Washington
[email protected]
Abstract

We propose Joint MLP/Attention (JoMA) dynamics, a novel mathematical framework to understand the training procedure of multilayer Transformer architectures. This is achieved by integrating out the self-attention layer in Transformers, producing a modified dynamics of MLP layers only. JoMA removes unrealistic assumptions from previous analyses (e.g., lack of residual connections) and predicts that the attention first becomes sparse (to learn salient tokens), then dense (to learn less salient tokens) in the presence of nonlinear activations, while in the linear case it is consistent with existing works that show attention becomes sparse over time. We leverage JoMA to qualitatively explain how tokens are combined to form hierarchies in multilayer Transformers, when the input tokens are generated by a latent hierarchical generative model. Experiments on models trained from scratch on real-world datasets (Wikitext2/Wikitext103) and various pre-trained models (OPT, Pythia) verify our theoretical findings. The code is available at https://github.com/facebookresearch/luckmatters/tree/yuandong3.

1 Introduction

Since its debut, the Transformer (Vaswani et al., 2017) has been extensively used in many applications and demonstrates impressive performance (Dosovitskiy et al., 2020; OpenAI, 2023) compared to domain-specific models (e.g., CNNs in computer vision, GNNs in graph modeling, RNNs/LSTMs in language modeling, etc.). In all these scenarios, the basic Transformer block, which consists of one self-attention layer plus a two-layer nonlinear MLP, plays a critical role. A natural question arises:

How does the basic Transformer block lead to effective learning?

Due to the complexity and nonlinearity of Transformer architectures, finding a unified mathematical framework that characterizes the learning mechanism of multi-layer Transformers remains a highly nontrivial open problem. Existing works mostly focus on 1-layer Transformers (Li et al., 2023a; Tarzanagh et al., 2023b) with a fixed MLP layer (Tarzanagh et al., 2023a), linear activation functions (Tian et al., 2023), or local gradient steps at initialization (Bietti et al., 2023; Oymak et al., 2023), etc.

In this paper, we propose a novel joint dynamics of self-attention plus MLP, based on the Joint MLP/Attention Integral (JoMA), a first integral that combines the lower layer of the MLP and the self-attention layer. Leveraging this joint dynamics, the self-attention is shown to have more fine-grained and delicate behavior in the case of nonlinear activations: it first becomes sparse as in the linear case (Tian et al., 2023), attending only to tokens that frequently co-occur with the query, and then becomes denser, gradually including tokens with less frequent co-occurrence. This reveals a changing inductive bias during Transformer training: the model first focuses on the most salient features, then extends to less salient ones.

Another natural question arises: why is such a learning pattern preferred? While for a 1-layer Transformer it gives no clear benefit, in the multilayer setting we show qualitatively that such dynamics plays an important role. To demonstrate this, we assume a hierarchical tree generative model for the input tokens. In this model, the latent variables $\texttt{LV}_{s}$ at level $s$ (the top-most being the class label of the input sequence) generate the latents $\texttt{LV}_{s-1}$ at the level below, until the token level ($s=0$) is reached. With this model, we show that the tokens generated by the lowest latents $\texttt{LV}_{1}$ co-occur a lot and thus can be picked up first by the attention dynamics as "salient features". This leads to learning of such token combinations in hidden MLP nodes, which triggers self-attention grouping at $s=1$, and so on. In this way, the non-salient co-occurrences are naturally explained by the top levels of the hierarchy, rather than being incorrectly learned by the lower layers as spurious correlations, since their learning is fortunately delayed by the attention mechanism. Our theoretical finding is consistent with both pre-trained models such as OPT/Pythia and models trained from scratch on real-world datasets (Wikitext2 and Wikitext103).

We show that JoMA overcomes several main limitations of Scan&Snap (Tian et al., 2023). JoMA incorporates residual connections and MLP nonlinearity as key ingredients, analyzes the joint training of the MLP and self-attention layers, and qualitatively explains the dynamics of multilayer Transformers. For linear activations, JoMA coincides with Scan&Snap, i.e., the attention becomes sparse during training.

1.1 Related Work

Expressiveness of Attention-based Models. The universal approximation abilities of attention-based models have been studied extensively (Yun et al., 2019; Bhattamishra et al., 2020a; b; Dehghani et al., 2018; Pérez et al., 2021). More recent studies offer detailed insights into their expressiveness for specific functions across various scenarios, sometimes incorporating statistical evaluations (Edelman et al., 2022; Elhage et al., 2021; Likhosherstov et al., 2021; Akyürek et al., 2022; Zhao et al., 2023; Yao et al., 2021; Anil et al., 2022; Barak et al., 2022). A fruitful line of work studies the in-context learning capabilities of the Transformer (Dong et al., 2022), linking gradient descent in classification/regression learning to the feedforward actions in Transformer layers (Garg et al., 2022; Von Oswald et al., 2022; Bai et al., 2023; Olsson et al., 2022; Akyürek et al., 2022; Li et al., 2023b). However, unlike our study, these works do not characterize the training dynamics.

Training Dynamics of Neural Networks. Earlier research has delved into training dynamics within multi-layer linear neural networks (Arora et al., 2018; Bartlett et al., 2018), the teacher-student setting (Brutzkus & Globerson, 2017; Tian, 2017; Soltanolkotabi, 2017; Goel et al., 2018; Du et al., 2017; 2018a; Zhou et al., 2019; Liu et al., 2019; Xu & Du, 2023), and infinite-width limits (Jacot et al., 2018; Chizat et al., 2019; Du et al., 2018b; 2019; Allen-Zhu et al., 2019; Arora et al., 2019; Oymak & Soltanolkotabi, 2020; Zou et al., 2020; Li & Liang, 2018; Chizat & Bach, 2018; Mei et al., 2018; Nguyen & Pham, 2020; Fang et al., 2021; Lu et al., 2020). This includes extensions to attention-based models (Hron et al., 2020; Yang et al., 2022). For self-supervised learning, there are analyses of linear networks (Tian, 2022) and explorations into the impact of nonlinearity (Tian, 2023).

Dynamics of Attention-based Models. Regarding attention-based models, Zhang et al. (2020) delves into adaptive optimization techniques. Jelassi et al. (2022) introduces an idealized setting, demonstrating that the vision transformer (Dosovitskiy et al., 2020) trained via gradient descent can discern spatial structure. Li et al. (2023c) illustrates that a single-layer Transformer can learn a constrained topic model, where each word is tied to a single topic, using an $\ell_{2}$ loss, a BERT-like framework (Devlin et al., 2018), and certain assumptions on attention patterns. Snell et al. (2021) investigates the training dynamics of single-head attention in mimicking Seq2Seq learning. Tian et al. (2023) characterizes the SGD training dynamics of a 1-layer Transformer and shows that, with cross-entropy loss, the model pays more attention to the key tokens that frequently co-occur with the query token. Oymak et al. (2023) constructs an attention-based contextual mixture model and demonstrates how the prompt can attend to the sparse context-relevant tokens via gradient descent. Tarzanagh et al. (2023b) also finds that gradient descent converges in direction to the max-margin solution that separates the locally optimal tokens from the others, and Tarzanagh et al. (2023a) further discloses the connection between the optimization geometry of self-attention and the hard-margin SVM problem. For the in-context learning scenario, several recent works analyze linear Transformers trained on random instances of linear regression tasks from the perspective of the loss landscape (Boix-Adsera et al., 2023; Zhang et al., 2023). While these studies also study the optimization dynamics of attention-based models, they do not reveal the phenomena we discuss.

2 Problem Setting

Let the total vocabulary size be $M$, in which $M_{C}$ is the number of contextual tokens and $M_{Q}$ is the number of query tokens. Consider one layer in a multilayer Transformer (Fig. 1(b)):

h_{k}=\phi({\bm{w}}^{\top}_{k}{\bm{f}}),\quad{\bm{f}}=U_{C}{\bm{b}}+{\bm{u}}_{q},\quad{\bm{b}}=\sigma({\bm{z}}_{q})\circ{\bm{x}}/A    (1)

Input/outputs. ${\bm{x}}=[x_{l}]\in\mathbb{R}^{M_{C}}$ is the input frequency vector of the contextual tokens $1\leq l\leq M_{C}$, $1\leq q\leq M_{Q}$ is the query token index, and $K$ is the number of nodes in the hidden MLP layer, whose outputs are $h_{k}$. All the quantities above vary across the sample index $i$ (i.e., $x_{l}=x_{l}[i]$, $q=q[i]$). In addition, $\phi$ is the nonlinearity (e.g., ReLU).

Model weights. ${\bm{z}}_{q}=[z_{ql}]\in\mathbb{R}^{M_{C}}$ are the (unnormalized) attention logits given query $q$, and ${\bm{w}}_{k}\in\mathbb{R}^{d}$ are the weights of the lower MLP layer. These are the quantities analyzed in this paper.

Figure 1: (a) Overview of the JoMA framework. Using an invariant of the training dynamics, the self-attention layer and the lower layer of the MLP can be merged together to yield an MLP layer with modified dynamics (Theorem 1), which explains the behavior of attention with linear (Sec. 3.1) and nonlinear (Sec. 4) MLP activation $\phi$, as well as hierarchical concept learning in the multilayer case (Sec. 5). (b) Problem setting. The JoMA framework supports different kinds of attention, including linear attention $b_{l}:=x_{l}z_{ql}$, exp attention $b_{l}:=x_{l}e^{z_{ql}}/A$ and softmax attention $b_{l}:=x_{l}e^{z_{ql}}/\sum_{l}x_{l}e^{z_{ql}}$.

The Attention Mechanism. In this paper, we mainly study three kinds of attention:

  • Linear Attention (Von Oswald et al., 2022): $\sigma(x)=x$ and $A:=1$;

  • Exp Attention: $\sigma(x)=\exp(x)$ and $A:=\mathrm{const}$;

  • Softmax Attention (Vaswani et al., 2017): $\sigma(x)=\exp(x)$ and $A:={\bm{1}}^{\top}\left(\sigma({\bm{z}}_{q})\circ{\bm{x}}\right)$.

Here $\circ$ is the Hadamard (element-wise) product. ${\bm{b}}\in\mathbb{R}^{M_{C}}$ are the attention scores for the contextual tokens, given by a point-wise attention function $\sigma$. $A$ is the normalization constant.
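To make the setting concrete, below is a minimal numerical sketch of Eqn. 1 (an illustration, not the paper's code); the embeddings, token frequencies, and weights are random placeholders, and attn_kind selects among the three attention variants listed above.

    import numpy as np

    rng = np.random.default_rng(0)
    M_C, M_Q, d, K = 8, 4, 64, 3                       # contextual/query vocab sizes, embedding dim, hidden MLP nodes

    U_C = rng.standard_normal((d, M_C)) / np.sqrt(d)   # contextual embeddings (roughly orthonormal for large d)
    U_Q = rng.standard_normal((d, M_Q)) / np.sqrt(d)   # query embeddings
    W = rng.standard_normal((K, d)) * 0.01             # lower MLP layer; rows are w_k
    Z = np.zeros((M_Q, M_C))                           # attention logits z_{ql}

    def forward(x, q, attn_kind="softmax"):
        """One layer of Eqn. 1: h_k = phi(w_k^T f), f = U_C b + u_q, b = sigma(z_q) o x / A."""
        z_q = Z[q]
        if attn_kind == "linear":                      # sigma(x) = x, A = 1
            b = z_q * x
        elif attn_kind == "exp":                       # sigma(x) = exp(x), A = const
            b = np.exp(z_q) * x / 10.0
        else:                                          # softmax: sigma(x) = exp(x), A = 1^T (sigma(z_q) o x)
            e = np.exp(z_q) * x
            b = e / e.sum()
        f = U_C @ b + U_Q[:, q]                        # context mixing plus the residual term u_q
        return np.maximum(W @ f, 0.0)                  # phi = ReLU; outputs h_k

    x = rng.dirichlet(np.ones(M_C))                    # frequency vector of contextual tokens in one sample
    print(forward(x, q=1))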

Embedding vectors. ${\bm{u}}_{l}$ is the embedding vector of token $l$. We assume that the embedding dimension $d$ is sufficiently large and thus ${\bm{u}}_{l}^{\top}{\bm{u}}_{l^{\prime}}=\mathbb{I}(l=l^{\prime})$, i.e., $\{{\bm{u}}_{l}\}$ are orthonormal bases. Let $U_{C}=[{\bm{u}}_{1},{\bm{u}}_{2},\ldots,{\bm{u}}_{M_{C}}]\in\mathbb{R}^{d\times M_{C}}$ be the matrix that encodes all embedding vectors of the contextual tokens; then $U_{C}^{\top}U_{C}=I$. Appendix B.1 verifies the orthogonality assumption in multiple pre-trained models (Pythia, LLaMA, etc.).

Residual connections are introduced as the additional term ${\bm{u}}_{q}$ in Eqn. 1, which captures a critical component of the Transformer architecture. Note that we do not model the value matrix $W_{V}$, since it can be merged into the embedding vectors (e.g., by ${\bm{u}}_{l}^{\prime}=W_{V}{\bm{u}}_{l}$), while $W_{K}$ and $W_{Q}$ are already implicitly modeled by the self-attention logits $z_{ql}={\bm{u}}^{\top}_{q}W^{\top}_{Q}W_{K}{\bm{u}}_{l}$.

Gradient backpropagation in multilayer models. In the multilayer setting, the gradient is backpropagated from the top layers. Specifically, let $g_{h_{k}}[i]$ be the backpropagated gradient sent to node $k$ at sample $i$. For a 1-layer Transformer with the softmax loss applied directly to the hidden nodes of the MLP, we have $g_{h_{k}}[i]\sim\mathbb{I}(y_{0}[i]=k)$, where $y_{0}[i]$ is the label to be predicted for sample $i$. For brevity, we often omit the sample index $i$ when there is no ambiguity.

Assumption 1 (Stationary backpropagated gradient $g_{h_{k}}$).

Expectation terms involving $g_{h_{k}}$ (e.g., $\mathbb{E}\left[g_{h_{k}}{\bm{x}}\right]$) remain constant during training.

Note that this holds exactly for layer-wise training: when optimizing the weights of a specific Transformer layer while fixing the weights of the others, the statistics of the backpropagated gradients are stationary. For joint training, this condition also holds approximately, since the weights change gradually during the training process. Under Assumption 1, Appendix A.1 gives an equivalent formulation in terms of a per-hidden-node loss.

Training Dynamics. Define the conditional expectation $\mathbb{E}_{q=m}\left[\cdot\right]:=\mathbb{E}\left[\cdot|q=m\right]$. Now consider the dynamics of ${\bm{w}}_{k}$ and ${\bm{z}}_{m}$ when we train the model with a batch of inputs that always end with the query $q[i]=m$:

\dot{\bm{w}}_{k}=\mathbb{E}_{q=m}\left[g_{h_{k}}h^{\prime}_{k}{\bm{f}}\right],\quad\quad\dot{\bm{z}}_{m}=\mathbb{E}_{q=m}\left[\left(\partial{\bm{b}}/\partial{\bm{z}}_{m}\right)^{\top}U_{C}^{\top}{\bm{g}}_{{\bm{f}}}\right]    (2)

Here $h^{\prime}_{k}:=\phi^{\prime}({\bm{w}}_{k}^{\top}{\bm{f}})$ is the derivative of the current activation and ${\bm{g}}_{{\bm{f}}}:=\sum_{k}g_{h_{k}}h^{\prime}_{k}{\bm{w}}_{k}$.

Figure 2: Test of the training dynamics with linear MLP activation ($\phi(x)=x$) under softmax attention. Left two: the distribution of ${\bm{x}}$ transits smoothly over different class labels. Right two: the distribution of ${\bm{x}}$ over different classes is randomly generated. In both cases, the estimate $\hat{\bm{z}}_{m}(t)$ given by the first integral (Theorem 1), despite the assumptions on $\bar{\bm{b}}_{m}$, shows high correlation with the ground-truth self-attention logits ${\bm{z}}_{m}(t)$, while its two components $\hat{\bm{z}}_{m1}(t):=\frac{1}{2}\sum_{k}{\bm{v}}_{k}^{2}(t)$ and $\hat{\bm{z}}_{m2}(t):=-\frac{1}{2}\sum_{k}\|{\bm{v}}_{k}(t)\|_{2}^{2}\bar{\bm{b}}_{m}$ do not.

3 JoMA: Existence of JOint dynamics of Attention and MLP

While the learning dynamics of ${\bm{w}}_{k}$ and ${\bm{z}}_{m}$ can be complicated, surprisingly, the training dynamics implies that the attention logits ${\bm{z}}_{m}(t)$ have a closed-form relationship with the MLP weights ${\bm{w}}_{k}(t)$, which lays the foundation of our JoMA framework:

Theorem 1 (JoMA).

Let ${\bm{v}}_{k}:=U_{C}^{\top}{\bm{w}}_{k}$; then the dynamics of Eqn. 2 satisfies the following invariants:

  • Linear attention. The dynamics satisfies ${\bm{z}}^{2}_{m}(t)=\sum_{k}{\bm{v}}^{2}_{k}(t)+{\bm{c}}$.

  • Exp attention. The dynamics satisfies ${\bm{z}}_{m}(t)=\frac{1}{2}\sum_{k}{\bm{v}}^{2}_{k}(t)+{\bm{c}}$.

  • Softmax attention. If $\bar{\bm{b}}_{m}:=\mathbb{E}_{q=m}\left[{\bm{b}}\right]$ is constant over time and $\mathbb{E}_{q=m}\left[\sum_{k}g_{h_{k}}h_{k}^{\prime}{\bm{b}}{\bm{b}}^{\top}\right]=\bar{\bm{b}}_{m}\mathbb{E}_{q=m}\left[\sum_{k}g_{h_{k}}h_{k}^{\prime}{\bm{b}}\right]$, then the dynamics satisfies ${\bm{z}}_{m}(t)=\frac{1}{2}\sum_{k}\left({\bm{v}}^{2}_{k}(t)-\|{\bm{v}}_{k}(t)\|_{2}^{2}\bar{\bm{b}}_{m}\right)+{\bm{c}}$.

Under zero initialization (${\bm{w}}_{k}(0)=0$, ${\bm{z}}_{m}(0)=0$), the time-independent constant ${\bm{c}}=0$.

Therefore, we do not need to explicitly update the self-attention, since it is already implicitly incorporated in the lower-layer MLP weights! For softmax attention, we verify that even with the additional assumptions, the invariant proposed by Theorem 1 still predicts ${\bm{z}}_{m}(t)$ fairly well (Fig. 2).
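Theorem 1 can also be checked numerically. The following sketch (our own toy setup, not the paper's code) integrates Eqn. 2 with exp attention, linear activation, and a fixed population of samples, and compares the learned logits ${\bm{z}}_{m}(t)$ against $\frac{1}{2}\sum_{k}{\bm{v}}^{2}_{k}(t)$.

    import numpy as np

    rng = np.random.default_rng(1)
    M_C, K, A, dt, steps = 6, 2, 5.0, 1e-2, 2000

    # Toy stationary population for a fixed query m: samples of (x, g_{h_k}) drawn once and held fixed.
    N = 256
    X = rng.dirichlet(np.ones(M_C), size=N)            # token frequency vectors x[i]
    G = 1.0 + 0.3 * rng.standard_normal((N, K))        # backpropagated gradients g_{h_k}[i]

    V = np.zeros((K, M_C))                             # v_k = U_C^T w_k
    z = np.zeros(M_C)                                  # attention logits z_m

    for _ in range(steps):
        b = np.exp(z) * X / A                          # exp attention: b = exp(z_m) o x / A, per sample
        dV = (G[:, :, None] * b[:, None, :]).mean(0)   # dv_k/dt = E[g_{h_k} h'_k b], with h'_k = 1 (linear phi)
        dz = (dV * V).sum(0)                           # dz_m/dt = sum_k E[g_{h_k} h'_k b] o v_k
        V += dt * dV
        z += dt * dz

    # Theorem 1 (exp attention, zero initialization): z_m(t) = 1/2 sum_k v_k^2(t).
    print(np.abs(z - 0.5 * (V**2).sum(0)).max())       # small residual (only Euler discretization error)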

3.1 Linear activations: winner-take-all

Now we can solve the dynamics of ${\bm{w}}_{k}(t)$ (Eqn. 2) by plugging in the closed-form solution of the self-attention. For simplicity, we consider exp attention with $K=1$ (i.e., a single hidden MLP node). Let $\Delta_{m}:=\mathbb{E}_{q=m}\left[g_{h_{k}}h^{\prime}_{k}{\bm{x}}\right]$; then the dynamics of ${\bm{v}}_{k}$ (written as ${\bm{v}}$) is:

\dot{\bm{v}}=\Delta_{m}\circ\exp({\bm{z}}_{m})=\Delta_{m}\circ\exp({\bm{v}}^{2}/2+{\bm{c}})    (3)

In the case of linear activation $\phi(x)=x$, we have $h^{\prime}_{k}\equiv 1$. According to Assumption 1, $\Delta_{m}$ does not depend on ${\bm{v}}$ and we arrive at the following theorem:

Figure 3: Growth of different components of ${\bm{v}}_{0}(t)$ (the first few components of the first column of $V(t)$) with linear MLP activation and softmax attention. As predicted by Sec. 3.1, after convergence only some components of ${\bm{v}}_{0}$ keep growing while the remaining components saturate after an initial growth, consistent with Theorem 2 even though it is derived from JoMA's approximation in Theorem 1. Each node $k$ (and thus ${\bm{w}}_{k}$) receives the backpropagated gradient from the $k$-th class via the cross-entropy loss.
Theorem 2 (Linear Dynamics with Self-attention).

With linear MLP activation and zero initialization, for exp attention any two tokens $l\neq l^{\prime}$ satisfy the following invariant:

\frac{\mathrm{erf}\left(v_{l}(t)/2\right)}{\Delta_{lm}}=\frac{\mathrm{erf}(v_{l^{\prime}}(t)/2)}{\Delta_{l^{\prime}m}}    (4)

where $\Delta_{lm}=\mathbb{E}_{q=m}\left[g_{h_{k}}x_{l}\right]$ and $\mathrm{erf}(x)=\frac{2}{\sqrt{\pi}}\int_{0}^{x}e^{-t^{2}}\mathrm{d}t$ is the Gauss error function.

Remarks. The dynamics suggests that the weights become one-hot over training. Specifically, let $l^{*}=\operatorname*{arg\,max}_{l}|\Delta_{lm}|$; then $v_{l^{*}}(t)\rightarrow\operatorname{sign}(\Delta_{l^{*}m})\times\infty$ while the other $v_{l}(t)$ converge to finite values, because of the constraint imposed by Eqn. 4 (see Fig. 3). For softmax attention, there is an additional sample-dependent normalization constant $A[i]$; if $A[i]$ remains constant across samples and all elements of $\bar{\bm{b}}_{m}$ are the same, then Theorem 2 also applies.

Beyond distinct/common tokens. The quantity $\Delta_{lm}:=\mathbb{E}_{l,q=m}\left[g_{h_{k}}\right]\mathbb{P}(l|m)$ (since $x_{l}[i]$ is the empirical frequency of token $l$ in sample $i$, we have $\Delta_{lm}=\mathbb{E}_{q=m}\left[g_{h_{k}}x_{l}\right]=\sum_{i}g_{h_{k}}[i]\mathbb{P}(l|q=m,i)\mathbb{P}(i|q=m)=\sum_{i}g_{h_{k}}[i]\mathbb{P}(i|q=m,l)\mathbb{P}(l|q=m)=\mathbb{E}_{l,q=m}\left[g_{h_{k}}\right]\mathbb{P}(l|m)$) is a product of token discriminancy (i.e., $\mathbb{E}_{l,q=m}\left[g_{h_{k}}\right]>0$ means token $l$ is positively correlated with the backpropagated gradient $g_{h_{k}}$, or with the label in the 1-layer case) and token frequency (i.e., $\mathbb{P}(l|m)$, how frequently $l$ appears given $m$). This covers a broader spectrum of tokens than Tian et al. (2023), which only discusses distinct tokens (i.e., large $|\Delta_{lm}|$) and common tokens (i.e., $\Delta_{lm}\approx 0$).
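The winner-take-all behavior described in the remarks above can be reproduced with a short sketch of Eqn. 3 (illustrative only; the $\Delta_{lm}$ values below are hypothetical and $A=1$, ${\bm{c}}=0$ are assumed):

    import numpy as np

    Delta = np.array([0.9, 0.5, 0.3, 0.1])    # hypothetical Delta_{lm} for four contextual tokens
    v = np.zeros_like(Delta)
    dt = 1e-3

    while np.abs(v).max() < 3.0:              # integrate dv_l/dt = Delta_{lm} exp(v_l^2 / 2) until the winner takes off
        v += dt * Delta * np.exp(v**2 / 2)

    print(np.round(v, 3))
    # Only the component with the largest |Delta_{lm}| keeps growing (it hits the cap), while the
    # others level off at finite values, matching Fig. 3 and the constraint of Eqn. 4.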

4 Training Dynamics under Nonlinear Activations

In the nonlinear case, the dynamics turns out to be very different: $\Delta_{m}$ is no longer constant but changes over time, and as a result the dynamics changes substantially.

Theorem 3 (Dynamics of nonlinear activation with uniform attention).

If ${\bm{x}}$ is sampled from a mixture of $C$ isotropic distributions centered at $[\bar{\bm{x}}_{1},\ldots,\bar{\bm{x}}_{C}]$, where each $\bar{\bm{x}}_{c}\in\mathbb{R}^{d}$, and the gradient $g_{h_{k}}$ is constant within each mixture component, then:

\dot{\bm{v}}=\Delta_{m}=\frac{1}{\|{\bm{v}}\|_{2}}\sum_{c}a_{c}\theta_{1}(r_{c})\bar{\bm{x}}_{c}+\frac{1}{\|{\bm{v}}\|_{2}^{3}}\sum_{c}a_{c}\theta_{2}(r_{c}){\bm{v}}    (5)

where $a_{c}:=\mathbb{E}_{q=m,c}\left[g_{h_{k}}\right]\mathbb{P}[c]$, $r_{c}:={\bm{v}}^{\top}\bar{\bm{x}}_{c}+\xi$ is the affinity to $\bar{\bm{x}}_{c}$, and the "bias" term is $\xi(t):=\int_{0}^{t}\mathbb{E}_{q=m}\left[g_{h_{k}}h^{\prime}_{k}\right]\mathrm{d}t$. The functions $\theta_{1}$ and $\theta_{2}$ depend on the derivative of the nonlinearity, $\psi:=\phi^{\prime}$, and on the data distribution, but not on ${\bm{v}}$. If $\psi$ is monotonic with $\psi(-\infty)=0$ and $\psi(+\infty)=1$, so is $\theta_{1}$.

Appendix A.3.2 presents a critical point analysis. Here we focus on the simplified case in which ${\bm{v}}$ is constrained to be a unit vector, which leads to the following modified dynamics ($P^{\perp}_{\bm{v}}{\bm{v}}=0$):

\dot{\bm{v}}=P^{\perp}_{\bm{v}}\Delta_{m}=\sum_{c}a_{c}\theta_{1}(r_{c})P^{\perp}_{\bm{v}}\bar{\bm{x}}_{c}=\sum_{c}a_{c}\theta_{1}(r_{c})\|\bar{\bm{x}}_{c}\|[\bm{\mu}_{c}-({\bm{v}}^{\top}\bm{\mu}_{c}){\bm{v}}]    (6)

where $\bm{\mu}_{c}:=\bar{\bm{x}}_{c}/\|\bar{\bm{x}}_{c}\|$. Consider the case where ${\bm{v}}$ is aligned with one cluster $\bar{\bm{x}}_{c}$ but far away from the others; then $r_{c}\gg r_{c^{\prime}}$ for $c^{\prime}\neq c$ and $\theta_{1}(r_{c})\gg\theta_{1}(r_{c^{\prime}})$ since $\theta_{1}$ is monotonically increasing. Hence $\bm{\mu}_{c}$ dominates, and we write $\bm{\mu}:=\bm{\mu}_{c}$ for brevity. Similar to Eqn. 3, we use the closed-form simplification of JoMA to incorporate self-attention, which leads to (using exp attention):

\dot{\bm{v}}\propto(\bm{\mu}-{\bm{v}})\circ\exp({\bm{v}}^{2}/2)    (7)

Here we omit the scalar terms and study the regime where ${\bm{v}}$ is close to $\bm{\mu}$, in which ${\bm{v}}^{\top}\bm{\mu}=1+O(\|\bm{\mu}-{\bm{v}}\|_{2}^{2})\approx 1$. It is clear that the critical point ${\bm{v}}_{*}=\bm{\mu}$ does not change after adding the term $\exp({\bm{v}}^{2}/2)$. However, the convergence speed changes drastically. As shown in the following theorem, convergence towards the salient components of $\bm{\mu}$ (i.e., components with large magnitude) is much faster than towards the non-salient ones:

Theorem 4 (Convergence speed of salient vs. non-salient components).

Let $\delta_{j}(t):=1-v_{j}(t)/\mu_{j}$ be the convergence metric for component $j$ ($\delta_{j}(t)=0$ means that component $j$ has converged). For the nonlinear dynamics with attention (Eqn. 7), we have

\frac{\ln\delta_{j}(0)/\delta_{j}(t)}{\ln\delta_{k}(0)/\delta_{k}(t)}=\frac{e^{\mu^{2}_{j}/2}}{e^{\mu^{2}_{k}/2}}(1+\Lambda_{jk}(t))    (8)

Here $\Lambda_{jk}(t)=\lambda_{jk}(t)\cdot e^{\mu_{k}^{2}/2}\ln^{-1}(\delta_{k}(0)/\delta_{k}(t))$, where $|\lambda_{jk}(t)|\leq C_{jk}$ and $C_{jk}$ only depends on $\delta_{j}(0)$ and $\delta_{k}(0)$. Hence, when $|\delta_{k}(t)|\ll|\delta_{k}(0)|\exp[-C_{jk}\exp(\mu_{k}^{2})]$, we have $|\Lambda_{jk}(t)|\ll 1$.

Remarks. For linear attention, the ratio is different but the derivation is similar and simpler. Note that the convergence speed heavily depends on the magnitude of $\mu_{j}$: if $\mu_{j}>\mu_{k}$, then $\delta_{j}(t)\ll\delta_{k}(t)$ and $v_{j}(t)$ converges much faster than $v_{k}(t)$. Therefore, the salient (i.e., large) components are learned first, and the non-salient (i.e., small) components are learned later, due to the modulation by the extra term $\exp({\bm{v}}^{2}/2)$ introduced by the self-attention, as demonstrated in Fig. 4.

A follow-up question arises: what is the intuition behind the salient and non-salient components of $\bm{\mu}$? Note that $\bm{\mu}$ is an $\ell_{2}$-normalized version of the conditional token frequency ${\bm{x}}$, given the query $q=m$. In this case, similar to Theorem 2 (and Tian et al. (2023)), we again see that if a contextual token $l$ co-occurs frequently with the query $m$, then the corresponding component $\mu_{l}$ is larger and $v_{l}$ grows towards $\mu_{l}$ much faster.

Figure 4: Dynamics of the nonlinear MLP with the self-attention component included (Eqn. 7). Left: training dynamics (color indicates training steps). The salient components (i.e., components with large magnitude in $\bm{\mu}$) of ${\bm{v}}(t)$ are learned first, followed by the non-salient ones. Right: the entropy of the attention (i.e., $\mathrm{entropy}(\mathrm{softmax}({\bm{v}}^{2}))$) drops when the salient components are learned first, and then rebounds when the other components catch up.
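Eqn. 7 is simple enough to integrate directly. The sketch below (a toy illustration with a hand-picked, unnormalized target $\bm{\mu}$ to make the effect visible, and plain Euler steps) reproduces the qualitative behavior of Fig. 4: salient components converge first and the attention entropy drops before rebounding.

    import numpy as np

    mu = np.array([2.0, 1.5, 0.8, 0.3, 0.1])            # salient components first (hand-picked)
    v = np.zeros_like(mu)
    dt, steps = 1e-2, 3000

    def attention_entropy(v):
        p = np.exp(v**2 - (v**2).max())                 # softmax(v^2), numerically stabilized
        p /= p.sum()
        return float(-(p * np.log(p + 1e-12)).sum())

    for t in range(1, steps + 1):
        v += dt * (mu - v) * np.exp(v**2 / 2)           # Eqn. 7
        if t in (50, 200, 800, 3000):
            print(t, "1 - v/mu =", np.round(1 - v / mu, 3), "entropy =", round(attention_entropy(v), 3))
    # The convergence gaps 1 - v_j/mu_j shrink first for the large mu_j; the attention entropy
    # first drops and then rebounds as the smaller components catch up.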

Relationship with the rank of the MLP lower layer. Since the MLP and attention layers have joint dynamics (Theorem 1), this also suggests that the rank of the lower-layer MLP matrix $W$ (which projects into the hidden nodes) first drops, since the weight components that correspond to large target values $\mu_{j}$ grow first, and then bounces back to a higher rank when the components that correspond to small target values $\mu_{j}$ catch up later.

5 How does self-attention learn hierarchical data distributions?

A critical difference between the training dynamics of linear and nonlinear MLPs is that in the nonlinear case, although slowly, the non-salient components still grow, and the entropy of the attention bounces back later. While for a 1-layer Transformer this may only slow down training with no clear benefit, the importance of such behavior becomes manifest once we consider the dynamics of multiple Transformer layers trained on a data distribution generated in a hierarchical manner.

Consider a simple generative hierarchical binary latent tree model (HBLT) (Tian et al., 2020) (Fig. 5(a)), in which latent (unobservable) binary variables $y$ at layer $s$ generate the latents at layer $s-1$, until the observable tokens are generated at the lowest level ($s=0$). The topmost layer is the class label $y_{0}$, which can take $D$ discrete values. In HBLT, the generation of $y_{\beta}$ at layer $s-1$ given $y_{\alpha}$ at layer $s$ is characterized by the conditional probabilities $\mathbb{P}[y_{\beta}=1|y_{\alpha}=1]=\mathbb{P}[y_{\beta}=0|y_{\alpha}=0]=\frac{1}{2}(1+\rho)$. The uncertainty hyperparameter $\rho\in[-1,1]$ determines how strongly the top-level latents determine the values of the low-level ones. Please see Appendix A.5 for the formal definition.

With HBLT, we can compute the co-occurrence frequency of two tokens $l$ and $m$ as a function of the depth of their common latent ancestor (CLA):

Theorem 5 (Token Co-occurrence in $\texttt{HBLT}(\rho)$).

If tokens $l$ and $m$ have a common latent ancestor (CLA) of depth $H$ (Fig. 5(c)), then $\mathbb{P}[y_{l}=1|y_{m}=1]=\frac{1}{2}\left(\frac{1+\rho^{2H}-2\rho^{L-1}\rho_{0}}{1-\rho^{L-1}\rho_{0}}\right)$, where $L$ is the total depth of the hierarchy and $\rho_{0}:={\bm{p}}_{\cdot|0}^{\top}{\bm{p}}_{0}$, in which ${\bm{p}}_{0}=[\mathbb{P}[y_{0}=k]]\in\mathbb{R}^{D}$ and ${\bm{p}}_{\cdot|0}:=[\mathbb{P}[y_{l}=0|y_{0}=k]]\in\mathbb{R}^{D}$, where $\{y_{l}\}$ are the immediate children of the root node $y_{0}$.

Remarks. If $y_{0}$ takes many values (many classes) and each class only triggers one specific latent binary variable, then most of the top-layer latents are triggered very sparsely and thus $\rho_{0}$ is very close to $1$. If $\rho$ is also close to $1$, then for a deep hierarchy and a shallow common ancestor, $\mathbb{P}[y_{l}=1|y_{m}=1]\rightarrow 1$. To see this, assume $\rho=\rho_{0}=1-\epsilon$; then we have:

\mathbb{P}[y_{l}=1|y_{m}=1]=\frac{1}{2}\left[\frac{1+1-2H\epsilon-2(1-L\epsilon)}{1-(1-L\epsilon)}\right]+O(\epsilon^{2})=1-\frac{H}{L}+O(\epsilon^{2})    (9)

This means that two tokens $l$ and $m$ co-occur a lot if they have a shallow CLA (small $H$) that is close to both tokens. If their CLA is high in the hierarchy (e.g., for $l^{\prime}$ and $m$), then the tokens $l^{\prime}$ and $m$ have much weaker co-occurrence and $\mathbb{P}(l^{\prime}|m)$ (and thus $x_{l^{\prime}}$ and $\mu_{l^{\prime}}$) is small.

Figure 5: (a) Hierarchical binary latent tree generative model. Except for $y_{0}$, which is the observable label of a sequence and can take $D$ discrete values, all latent variables follow a binomial distribution. A binary leaf variable $y_{l}=1$ indicates that token $l$ appears in the sequence. (b) Attention dynamics in the multi-layer setting. There is a strong co-occurrence between the query $m$ and the token $l$, but a weak co-occurrence between $m$ and $l^{\prime}$. As a result, $m$ associates with $l$ first and eventually also associates with $l^{\prime}$, even though they co-occur weakly, according to Theorem 4. (c) If there exists an additional layer $y_{\beta}$ and $y_{\beta^{\prime}}$ in the latent hierarchy, the associations $m$-$l$ and $m^{\prime}$-$l^{\prime}$ are learned first due to their high co-occurrence. Once the lower hierarchy is learned and some hidden nodes in the MLP represent $y_{\beta}$ and $y_{\beta^{\prime}}$ (see Sec. 6 for experimental validation), on the next level $y_{\beta}$ and $y_{\beta^{\prime}}$ show strong co-occurrence and get picked up by the self-attention mechanism to form even higher-level features. In contrast, the association $l^{\prime}$-$m$ forms much more slowly and does not affect latent hierarchy learning, showing that the self-attention mechanism is adaptive to the structure of the data distribution.
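The qualitative prediction of Theorem 5, that tokens with a shallow common latent ancestor co-occur much more often, can be checked with a toy sampler. The sketch below uses a simplified HBLT-like parameterization (a perfect binary tree with a fair binary root; not the exact model of Appendix A.5):

    import numpy as np

    rng = np.random.default_rng(0)
    L, rho = 4, 0.9                                           # hierarchy depth and uncertainty parameter

    def sample_leaves(n):
        level = rng.integers(0, 2, size=(n, 1))               # root latent, one per sample
        for _ in range(L):
            parents = np.repeat(level, 2, axis=1)             # two children per latent
            flip = rng.random(parents.shape) < (1 - rho) / 2  # a child disagrees with its parent w.p. (1-rho)/2
            level = np.where(flip, 1 - parents, parents)
        return level                                          # binary leaf variables y_l (tokens), shape (n, 2**L)

    leaves = sample_leaves(200_000)
    m = 0
    for l, depth in [(1, 1), (leaves.shape[1] - 1, L)]:       # a sibling of m vs. a leaf whose CLA is the root
        cond = leaves[:, m] == 1
        print(f"CLA depth {depth}: P[y_l=1 | y_m=1] ~= {(leaves[cond, l] == 1).mean():.3f}")
    # Tokens sharing a shallow CLA (depth 1) co-occur far more often than tokens whose CLA is the
    # root (depth L), in line with the qualitative message of Theorem 5 and Eqn. 9.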

With this generative model, we can qualitatively analyze the learning dynamics under JoMA: the model first focuses on associating the tokens in the same lowest-level hierarchy as the query $m$ (these tokens co-occur a lot with $m$), then gradually reaches out to other tokens $l^{\prime}$ that co-occur less with $m$, if they have not already been picked up by other queries (Fig. 5(b)); if $l^{\prime}$ co-occurs a lot with some other query $m^{\prime}$, then $m$-$l$ and $m^{\prime}$-$l^{\prime}$ form their own lower hierarchies, respectively. This leads to the learning of the high-level features $y_{\beta}$ and $y_{\beta^{\prime}}$, which have high correlation and are associated at the higher level. Therefore, the latent hierarchy is implicitly learned.

6 Experiments

Dynamics of Attention Sparsity. Fig. 6 shows how attention sparsity changes over time when training from scratch. We use a $10^{-4}$ learning rate and test our hypothesis on Wikitext2/Wikitext103 (Merity et al., 2016) (top/bottom row). Fig. 8 further shows that different learning rates lead to different attention sparsity patterns. With a large learning rate, attention becomes extremely sparse as in Tian et al. (2023). Interestingly, the attention patterns that coincide with our theoretical analysis yield the best validation score.

We also test our hypothesis on the OPT (Zhang et al., 2022) (OPT-2.7B) and Pythia (Biderman et al., 2023) (Pythia-70M/1.4B/6.9B) pre-trained models, both of which have public intermediate checkpoints. While the attention patterns show less salient drop-and-bounce behavior, the dynamics of the stable ranks of the MLP lower layer (the projection into the hidden neurons) shows such structure much more saliently for the top layers, and dropping curves for the bottom layers, since they are suppressed by top-level learning (Sec. 5). Note that stable ranks only depend on the model parameters and thus may be more reliable than attention sparsity.
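The stable rank statistic can be computed directly from checkpoint weights. A minimal sketch is below, using the common definition $\|W\|_F^2/\|W\|_2^2$ (the paper's exact measurement details may differ); the checkpoint-loading line is a hypothetical placeholder.

    import numpy as np

    def stable_rank(W: np.ndarray) -> float:
        """Stable rank ||W||_F^2 / ||W||_2^2 of a weight matrix (a common definition; details may differ)."""
        sv = np.linalg.svd(W, compute_uv=False)
        return float((sv**2).sum() / (sv[0]**2))

    # Hypothetical usage on an intermediate checkpoint (module names are placeholders and depend on the model):
    #   W = model.layers[layer_idx].mlp.up_proj.weight.detach().cpu().numpy()   # lower MLP layer (projection into hidden)
    #   print(layer_idx, stable_rank(W))
    # Tracking this quantity across public intermediate checkpoints produces curves like those in Fig. 7.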

Figure 6: Dynamics of attention sparsity. In the 1-layer setting, the curves bear strong resemblance to our theoretical prediction (Fig. 4); in multi-layer settings, the attention entropy in the top Transformer layers has a similar shape, while the entropy in the bottom layers is suppressed due to layer interactions (Sec. 4). Top row: Wikitext2; bottom row: Wikitext103.
Figure 7: Dynamics of attention sparsity and stable rank in OPT-2.7B and Pythia-70M/1.4B/6.9B. Results are evaluated on Wikitext103 (Merity et al., 2016).
Figure 8: Effect of different learning rates on attention sparsity. Different learning rates lead to different dynamics of attention sparsity, and the attention patterns consistent with our theoretical analysis (Fig. 4) give the lowest validation losses.

Validation of the alignment between latents and hidden nodes in the MLP. Sec. 5 is based on the assumption that the hidden nodes in the MLP layer learn the latent variables. We verify this assumption on synthetic data sampled from HBLT, which generates the latent variables in a top-down manner until the final tokens are produced. The latent hierarchy has two hyperparameters: the number of latents per layer ($N_{s}$) and the number of children per latent ($N_{\mathrm{ch}}$). $C$ is the number of classes. The Adam optimizer is used with learning rate $10^{-5}$. The vocabulary size is $M=100$, the sequence length is $T=30$, and the embedding dimension is $d=1024$.

We use a 3-layer generative model as well as 3-layer Transformer models. We indeed observe high correlation between the latents and the hidden neurons of the corresponding layers. Note that the latents are known during the input generation procedure but are not known to the Transformer being trained. We take the maximal activation of each neuron across the sequence length and compute the normalized correlation between the maximal activation of each neuron and the latents, after centering across the sample dimension (see the sketch after Tbl. 1). Tbl. 1 shows that in the learned models, for each latent there indeed exists at least one hidden node in the MLP that has high normalized correlation with it, in particular in the lowest layer. When the generative model becomes more complicated (i.e., both $N_{\mathrm{ch}}$ and $N_{s}$ become larger), the correlation drops slightly.

                 C=20, N_ch=2         C=20, N_ch=3         C=30, N_ch=2
(N_0, N_1)       (10, 20)  (20, 30)   (10, 20)  (20, 30)   (10, 20)  (20, 30)
NCorr (s=0)      0.99±0.01 0.97±0.02  1.00±0.00 0.96±0.02  0.99±0.01 0.94±0.04
NCorr (s=1)      0.81±0.05 0.80±0.05  0.69±0.05 0.68±0.04  0.73±0.08 0.74±0.03

                 C=30, N_ch=3         C=50, N_ch=2         C=50, N_ch=3
(N_0, N_1)       (10, 20)  (20, 30)   (10, 20)  (20, 30)   (10, 20)  (20, 30)
NCorr (s=0)      0.99±0.01 0.95±0.03  0.99±0.01 0.95±0.03  0.99±0.01 0.95±0.03
NCorr (s=1)      0.72±0.04 0.66±0.02  0.58±0.02 0.55±0.01  0.64±0.02 0.61±0.04

Table 1: Normalized correlation between the latents and their best-matched hidden node in the MLP of the same layer. All experiments are run with 5 random seeds.
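The normalized correlation (NCorr) reported in Tbl. 1 can be sketched as follows (our reading of the description above; the paper's exact implementation may differ):

    import numpy as np

    def ncorr_per_latent(acts, latents):
        """acts: (n_samples, seq_len, n_neurons) MLP activations; latents: (n_samples, n_latents) binary latents.
        For each latent, returns the normalized correlation with its best-matched hidden neuron."""
        a = acts.max(axis=1)                                   # maximal activation of each neuron across the sequence
        a = a - a.mean(axis=0, keepdims=True)                  # center across the sample dimension
        y = latents - latents.mean(axis=0, keepdims=True)
        a = a / (np.linalg.norm(a, axis=0, keepdims=True) + 1e-12)
        y = y / (np.linalg.norm(y, axis=0, keepdims=True) + 1e-12)
        corr = y.T @ a                                         # (n_latents, n_neurons) normalized correlations
        return np.abs(corr).max(axis=1)                        # best-matched neuron for each latent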

7 Discussion

Dealing with almost-orthogonal embeddings. In this paper, we focus on fixed orthonormal embedding vectors. However, in real-world Transformer training this assumption may not hold, since the embedding dimension $d$ is often smaller than the vocabulary size $M$, so the embedding vectors cannot all be orthogonal to each other. In this setting, one reasonable assumption is that the embedding vectors are almost orthogonal. Thanks to the Johnson–Lindenstrauss lemma, one interesting property of high-dimensional spaces is that for $M$ embedding vectors to achieve almost orthogonality $|{\bm{u}}_{l}^{\top}{\bm{u}}_{l^{\prime}}|\leq\epsilon$, only $d\geq 8\epsilon^{-2}\log M$ is needed. As a result, our JoMA framework (Theorem 1) will have additional $\epsilon$-related terms, and we leave the detailed analysis as future work.
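The bound can be checked numerically with random unit embeddings (a quick sketch under the assumption of Gaussian random embeddings):

    import numpy as np

    rng = np.random.default_rng(0)
    M, eps = 2000, 0.1
    d = int(np.ceil(8 * eps**-2 * np.log(M)))          # d >= 8 eps^-2 log M
    U = rng.standard_normal((M, d))
    U /= np.linalg.norm(U, axis=1, keepdims=True)      # unit-norm embedding vectors
    G = U @ U.T - np.eye(M)                            # off-diagonal inner products u_l^T u_{l'}
    print(d, np.abs(G).max())                          # the maximum |u_l^T u_{l'}| stays around or below eps = 0.1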

Training embedding vectors. Another factor not considered in JoMA is that the embedding vectors are also trained simultaneously. This could further boost the efficiency of the Transformer architecture, since concepts with similar semantics will learn similar embeddings. This essentially reduces the vocabulary size at each layer, making learning more effective and leading to better generalization. For example, although $4d$ hidden neurons are computed in each hidden layer, this does not mean there are $4d$ independent intermediate "tokens", because many of their embeddings are highly correlated.

Self-attention computed from the embeddings. JoMA arrives at the joint dynamics of MLP and attention by assuming that the pairwise attention score $Z$ is an independent parameter optimized under the SGD dynamics. In practice, $Z=UW_{Q}W^{\top}_{K}U^{\top}$ is also parameterized by the embedding matrix, which allows generalization to tokens with similar embeddings and may accelerate the training dynamics of $Z$. We leave this to future work.

8 Conclusion

We propose JoMA, a framework that characterizes the joint training dynamics of the nonlinear MLP and the attention layer by integrating out the self-attention logits. The resulting dynamics connects the dynamics of the nonlinear MLP lower-layer weights (the projection into hidden neurons) and the self-attention, and shows that the attention first becomes sparse (or the weights become low-rank) and then becomes dense (or the weights become high-rank). Furthermore, we qualitatively describe a learning mechanism of multilayer Transformers that reveals how the self-attention at different layers interacts to learn the latent feature hierarchy.

Acknowledgments

Simon S. Du is supported by NSF IIS 2110170, NSF DMS 2134106, NSF CCF 2212261, NSF IIS 2143493, NSF CCF 2019844, and NSF IIS 2229881.

References

  • Akyürek et al. (2022) Ekin Akyürek, Dale Schuurmans, Jacob Andreas, Tengyu Ma, and Denny Zhou. What learning algorithm is in-context learning? investigations with linear models. arXiv preprint arXiv:2211.15661, 2022.
  • Allen-Zhu et al. (2019) Zeyuan Allen-Zhu, Yuanzhi Li, and Zhao Song. A convergence theory for deep learning via over-parameterization. In International Conference on Machine Learning, pp. 242–252. PMLR, 2019.
  • Anil et al. (2022) Cem Anil, Yuhuai Wu, Anders Andreassen, Aitor Lewkowycz, Vedant Misra, Vinay Ramasesh, Ambrose Slone, Guy Gur-Ari, Ethan Dyer, and Behnam Neyshabur. Exploring length generalization in large language models. arXiv preprint arXiv:2207.04901, 2022.
  • Arora et al. (2018) Sanjeev Arora, Nadav Cohen, Noah Golowich, and Wei Hu. A convergence analysis of gradient descent for deep linear neural networks. arXiv preprint arXiv:1810.02281, 2018.
  • Arora et al. (2019) Sanjeev Arora, Simon Du, Wei Hu, Zhiyuan Li, and Ruosong Wang. Fine-grained analysis of optimization and generalization for overparameterized two-layer neural networks. In International Conference on Machine Learning, pp. 322–332. PMLR, 2019.
  • Bai et al. (2023) Yu Bai, Fan Chen, Huan Wang, Caiming Xiong, and Song Mei. Transformers as statisticians: Provable in-context learning with in-context algorithm selection. arXiv preprint arXiv:2306.04637, 2023.
  • Barak et al. (2022) Boaz Barak, Benjamin Edelman, Surbhi Goel, Sham Kakade, Eran Malach, and Cyril Zhang. Hidden progress in deep learning: Sgd learns parities near the computational limit. Advances in Neural Information Processing Systems, 35:21750–21764, 2022.
  • Bartlett et al. (2018) Peter Bartlett, Dave Helmbold, and Philip Long. Gradient descent with identity initialization efficiently learns positive definite linear transformations by deep residual networks. In International conference on machine learning, pp. 521–530. PMLR, 2018.
  • Bhattamishra et al. (2020a) Satwik Bhattamishra, Kabir Ahuja, and Navin Goyal. On the ability and limitations of transformers to recognize formal languages. arXiv preprint arXiv:2009.11264, 2020a.
  • Bhattamishra et al. (2020b) Satwik Bhattamishra, Arkil Patel, and Navin Goyal. On the computational power of transformers and its implications in sequence modeling. arXiv preprint arXiv:2006.09286, 2020b.
  • Biderman et al. (2023) Stella Biderman, Hailey Schoelkopf, Quentin Gregory Anthony, Herbie Bradley, Kyle O’Brien, Eric Hallahan, Mohammad Aflah Khan, Shivanshu Purohit, USVSN Sai Prashanth, Edward Raff, et al. Pythia: A suite for analyzing large language models across training and scaling. In International Conference on Machine Learning, pp. 2397–2430. PMLR, 2023.
  • Bietti et al. (2023) Alberto Bietti, Vivien Cabannes, Diane Bouchacourt, Herve Jegou, and Leon Bottou. Birth of a transformer: A memory viewpoint. arXiv preprint arXiv:2306.00802, 2023.
  • Boix-Adsera et al. (2023) Enric Boix-Adsera, Etai Littwin, Emmanuel Abbe, Samy Bengio, and Joshua Susskind. Transformers learn through gradual rank increase. arXiv preprint arXiv:2306.07042, 2023.
  • Brutzkus & Globerson (2017) Alon Brutzkus and Amir Globerson. Globally optimal gradient descent for a convnet with gaussian inputs. In International conference on machine learning, pp. 605–614. PMLR, 2017.
  • Chizat & Bach (2018) Lenaic Chizat and Francis Bach. On the global convergence of gradient descent for over-parameterized models using optimal transport. Advances in neural information processing systems, 31, 2018.
  • Chizat et al. (2019) Lenaic Chizat, Edouard Oyallon, and Francis Bach. On lazy training in differentiable programming. Advances in neural information processing systems, 32, 2019.
  • Dehghani et al. (2018) Mostafa Dehghani, Stephan Gouws, Oriol Vinyals, Jakob Uszkoreit, and Łukasz Kaiser. Universal transformers. arXiv preprint arXiv:1807.03819, 2018.
  • Devlin et al. (2018) Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. Bert: Pre-training of deep bidirectional transformers for language understanding. arXiv preprint arXiv:1810.04805, 2018.
  • Dong et al. (2022) Qingxiu Dong, Lei Li, Damai Dai, Ce Zheng, Zhiyong Wu, Baobao Chang, Xu Sun, Jingjing Xu, and Zhifang Sui. A survey for in-context learning. arXiv preprint arXiv:2301.00234, 2022.
  • Dosovitskiy et al. (2020) Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, et al. An image is worth 16x16 words: Transformers for image recognition at scale. arXiv preprint arXiv:2010.11929, 2020.
  • Du et al. (2018a) Simon Du, Jason Lee, Yuandong Tian, Aarti Singh, and Barnabas Poczos. Gradient descent learns one-hidden-layer cnn: Don’t be afraid of spurious local minima. In International Conference on Machine Learning, pp. 1339–1348. PMLR, 2018a.
  • Du et al. (2019) Simon Du, Jason Lee, Haochuan Li, Liwei Wang, and Xiyu Zhai. Gradient descent finds global minima of deep neural networks. In International conference on machine learning, pp. 1675–1685. PMLR, 2019.
  • Du et al. (2017) Simon S Du, Jason D Lee, and Yuandong Tian. When is a convolutional filter easy to learn? arXiv preprint arXiv:1709.06129, 2017.
  • Du et al. (2018b) Simon S. Du, Xiyu Zhai, Barnabas Poczos, and Aarti Singh. Gradient descent provably optimizes over-parameterized neural networks, 2018b. URL https://arxiv.org/abs/1810.02054.
  • Edelman et al. (2022) Benjamin L Edelman, Surbhi Goel, Sham Kakade, and Cyril Zhang. Inductive biases and variable creation in self-attention mechanisms. In International Conference on Machine Learning, pp. 5793–5831. PMLR, 2022.
  • Elhage et al. (2021) N Elhage, N Nanda, C Olsson, T Henighan, N Joseph, B Mann, A Askell, Y Bai, A Chen, T Conerly, et al. A mathematical framework for transformer circuits. Transformer Circuits Thread, 2021.
  • Fang et al. (2021) Cong Fang, Jason Lee, Pengkun Yang, and Tong Zhang. Modeling from features: a mean-field framework for over-parameterized deep neural networks. In Conference on learning theory, pp.  1887–1936. PMLR, 2021.
  • Garg et al. (2022) Shivam Garg, Dimitris Tsipras, Percy S Liang, and Gregory Valiant. What can transformers learn in-context? a case study of simple function classes. Advances in Neural Information Processing Systems, 35:30583–30598, 2022.
  • Goel et al. (2018) Surbhi Goel, Adam Klivans, and Raghu Meka. Learning one convolutional layer with overlapping patches. In International Conference on Machine Learning, pp. 1783–1791. PMLR, 2018.
  • Hron et al. (2020) Jiri Hron, Yasaman Bahri, Jascha Sohl-Dickstein, and Roman Novak. Infinite attention: Nngp and ntk for deep attention networks. In International Conference on Machine Learning, pp. 4376–4386. PMLR, 2020.
  • Jacot et al. (2018) Arthur Jacot, Franck Gabriel, and Clément Hongler. Neural tangent kernel: Convergence and generalization in neural networks. Advances in neural information processing systems, 31, 2018.
  • Jelassi et al. (2022) Samy Jelassi, Michael Sander, and Yuanzhi Li. Vision transformers provably learn spatial structure. Advances in Neural Information Processing Systems, 35:37822–37836, 2022.
  • Li et al. (2023a) Hongkang Li, Meng Wang, Sijia Liu, and Pin-Yu Chen. A theoretical understanding of shallow vision transformers: Learning, generalization, and sample complexity. In The Eleventh International Conference on Learning Representations, 2023a. URL https://openreview.net/forum?id=jClGv3Qjhb.
  • Li et al. (2023b) Shuai Li, Zhao Song, Yu Xia, Tong Yu, and Tianyi Zhou. The closeness of in-context learning and weight shifting for softmax regression. arXiv preprint arXiv:2304.13276, 2023b.
  • Li & Liang (2018) Yuanzhi Li and Yingyu Liang. Learning overparameterized neural networks via stochastic gradient descent on structured data. Advances in neural information processing systems, 31, 2018.
  • Li et al. (2023c) Yuchen Li, Yuanzhi Li, and Andrej Risteski. How do transformers learn topic structure: Towards a mechanistic understanding. arXiv preprint arXiv:2303.04245, 2023c.
  • Likhosherstov et al. (2021) Valerii Likhosherstov, Krzysztof Choromanski, and Adrian Weller. On the expressive power of self-attention matrices. arXiv preprint arXiv:2106.03764, 2021.
  • Liu et al. (2019) Tianyi Liu, Minshuo Chen, Mo Zhou, Simon S Du, Enlu Zhou, and Tuo Zhao. Towards understanding the importance of shortcut connections in residual networks. Advances in neural information processing systems, 32, 2019.
  • Lu et al. (2020) Yiping Lu, Chao Ma, Yulong Lu, Jianfeng Lu, and Lexing Ying. A mean field analysis of deep resnet and beyond: Towards provably optimization via overparameterization from depth. In International Conference on Machine Learning, pp. 6426–6436. PMLR, 2020.
  • Mei et al. (2018) Song Mei, Andrea Montanari, and Phan-Minh Nguyen. A mean field view of the landscape of two-layer neural networks. Proceedings of the National Academy of Sciences, 115(33):E7665–E7671, 2018.
  • Merity et al. (2016) Stephen Merity, Caiming Xiong, James Bradbury, and Richard Socher. Pointer sentinel mixture models. arXiv preprint arXiv:1609.07843, 2016.
  • Nguyen & Pham (2020) Phan-Minh Nguyen and Huy Tuan Pham. A rigorous framework for the mean field limit of multilayer neural networks. arXiv preprint arXiv:2001.11443, 2020.
  • Olsson et al. (2022) Catherine Olsson, Nelson Elhage, Neel Nanda, Nicholas Joseph, Nova DasSarma, Tom Henighan, Ben Mann, Amanda Askell, Yuntao Bai, Anna Chen, et al. In-context learning and induction heads. arXiv preprint arXiv:2209.11895, 2022.
  • OpenAI (2023) OpenAI. Gpt-4 technical report, 2023.
  • Oymak & Soltanolkotabi (2020) Samet Oymak and Mahdi Soltanolkotabi. Toward moderate overparameterization: Global convergence guarantees for training shallow neural networks. IEEE Journal on Selected Areas in Information Theory, 1(1):84–105, 2020.
  • Oymak et al. (2023) Samet Oymak, Ankit Singh Rawat, Mahdi Soltanolkotabi, and Christos Thrampoulidis. On the role of attention in prompt-tuning. ICML, 2023.
  • Pérez et al. (2021) Jorge Pérez, Pablo Barceló, and Javier Marinkovic. Attention is turing complete. The Journal of Machine Learning Research, 22(1):3463–3497, 2021.
  • Snell et al. (2021) Charlie Snell, Ruiqi Zhong, Dan Klein, and Jacob Steinhardt. Approximating how single head attention learns. arXiv preprint arXiv:2103.07601, 2021.
  • Soltanolkotabi (2017) Mahdi Soltanolkotabi. Learning relus via gradient descent. Advances in neural information processing systems, 30, 2017.
  • Tarzanagh et al. (2023a) Davoud Ataee Tarzanagh, Yingcong Li, Christos Thrampoulidis, and Samet Oymak. Transformers as support vector machines. arXiv preprint arXiv:2308.16898, 2023a.
  • Tarzanagh et al. (2023b) Davoud Ataee Tarzanagh, Yingcong Li, Xuechen Zhang, and Samet Oymak. Max-margin token selection in attention mechanism. arXiv preprint arXiv:2306.13596, 3(7):47, 2023b.
  • Tian (2017) Yuandong Tian. An analytical formula of population gradient for two-layered relu network and its applications in convergence and critical point analysis. In International conference on machine learning, pp. 3404–3413. PMLR, 2017.
  • Tian (2022) Yuandong Tian. Understanding the role of nonlinearity in training dynamics of contrastive learning. arXiv preprint arXiv:2206.01342, 2022.
  • Tian (2023) Yuandong Tian. Understanding the role of nonlinearity in training dynamics of contrastive learning. ICLR, 2023.
  • Tian et al. (2020) Yuandong Tian, Lantao Yu, Xinlei Chen, and Surya Ganguli. Understanding self-supervised learning with dual deep networks. arXiv preprint arXiv:2010.00578, 2020.
  • Tian et al. (2023) Yuandong Tian, Yiping Wang, Beidi Chen, and Simon Du. Scan and snap: Understanding training dynamics and token composition in 1-layer transformer, 2023.
  • Vaswani et al. (2017) Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, and Illia Polosukhin. Attention is all you need. 2017. URL https://arxiv.org/pdf/1706.03762.pdf.
  • Von Oswald et al. (2022) Johannes Von Oswald, Eyvind Niklasson, Ettore Randazzo, João Sacramento, Alexander Mordvintsev, Andrey Zhmoginov, and Max Vladymyrov. Transformers learn in-context by gradient descent. arXiv preprint arXiv:2212.07677, 2022.
  • Xu & Du (2023) Weihang Xu and Simon S Du. Over-parameterization exponentially slows down gradient descent for learning a single neuron. arXiv preprint arXiv:2302.10034, 2023.
  • Yang et al. (2022) Greg Yang, Edward J Hu, Igor Babuschkin, Szymon Sidor, Xiaodong Liu, David Farhi, Nick Ryder, Jakub Pachocki, Weizhu Chen, and Jianfeng Gao. Tensor programs v: Tuning large neural networks via zero-shot hyperparameter transfer. arXiv preprint arXiv:2203.03466, 2022.
  • Yao et al. (2021) Shunyu Yao, Binghui Peng, Christos Papadimitriou, and Karthik Narasimhan. Self-attention networks can process bounded hierarchical languages. arXiv preprint arXiv:2105.11115, 2021.
  • Yun et al. (2019) Chulhee Yun, Srinadh Bhojanapalli, Ankit Singh Rawat, Sashank J Reddi, and Sanjiv Kumar. Are transformers universal approximators of sequence-to-sequence functions? arXiv preprint arXiv:1912.10077, 2019.
  • Zhang et al. (2020) Jingzhao Zhang, Sai Praneeth Karimireddy, Andreas Veit, Seungyeon Kim, Sashank Reddi, Sanjiv Kumar, and Suvrit Sra. Why are adaptive methods good for attention models? Advances in Neural Information Processing Systems, 33:15383–15393, 2020.
  • Zhang et al. (2023) Ruiqi Zhang, Spencer Frei, and Peter L Bartlett. Trained transformers learn linear models in-context. arXiv preprint arXiv:2306.09927, 2023.
  • Zhang et al. (2022) Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen, Christopher Dewan, Mona Diab, Xian Li, Xi Victoria Lin, et al. Opt: Open pre-trained transformer language models. arXiv preprint arXiv:2205.01068, 2022.
  • Zhao et al. (2023) Haoyu Zhao, Abhishek Panigrahi, Rong Ge, and Sanjeev Arora. Do transformers parse while predicting the masked word? arXiv preprint arXiv:2303.08117, 2023.
  • Zhou et al. (2019) Mo Zhou, Tianyi Liu, Yan Li, Dachao Lin, Enlu Zhou, and Tuo Zhao. Toward understanding the importance of noise in training neural networks. In International Conference on Machine Learning, pp. 7594–7602. PMLR, 2019.
  • Zou et al. (2020) Difan Zou, Yuan Cao, Dongruo Zhou, and Quanquan Gu. Gradient descent optimizes over-parameterized deep relu networks. Machine learning, 109:467–492, 2020.

Appendix A Proofs

A.1 Per-hidden loss formulation

Our Assumption 1 has an equivalent per-hidden node loss:

\max_{\{{\bm{w}}_{k}\},\{{\bm{z}}_{m}\}}\mathbb{E}_{\mathcal{D}}\left[\sum_{k}g_{h_{k}}h_{k}\right]:=\max_{\{{\bm{w}}_{k}\},\{{\bm{z}}_{m}\}}\mathbb{E}_{i\sim\mathcal{D}}\left[\sum_{k}g_{h_{k}}[i]h_{k}[i]\right]    (10)

where $g_{h_{k}}[i]$ is the backpropagated gradient sent to node $h_{k}$ at sample $i$.

A.2 JoMA framework (Section  3)

See Theorem 1.

Proof.

Let $L:=\partial{\bm{b}}/\partial{\bm{z}}_{m}$. Plugging the dynamics of ${\bm{w}}_{k}$ into the dynamics of the self-attention logits ${\bm{z}}_{m}$, we have:

\dot{\bm{z}}_{m}=\mathbb{E}_{q=m}\left[L^{\top}U_{C}^{\top}\sum_{k}g_{h_{k}}h^{\prime}_{k}{\bm{w}}_{k}\right]=\sum_{k}\mathbb{E}_{q=m}\left[g_{h_{k}}h^{\prime}_{k}L^{\top}{\bm{v}}_{k}\right]    (11)

Before we start, we first define $\xi_{k}(t):=\int_{0}^{t}\mathbb{E}_{q=m}\left[g_{h_{k}}(t^{\prime})h_{k}^{\prime}(t^{\prime})\right]\mathrm{d}t^{\prime}$. Therefore, $\dot{\xi}_{k}=\mathbb{E}_{q=m}\left[g_{h_{k}}h_{k}^{\prime}\right]$. Intuitively, $\xi_{k}$ is the bias of node $k$, regardless of whether there exists an actual bias parameter to optimize.

Notice that $U_{C}^{\top}{\bm{f}}={\bm{b}}+U_{C}^{\top}{\bm{u}}_{q}$. With the orthonormality condition between contextual and query tokens, $U_{C}^{\top}{\bm{u}}_{m}=0$, and thus $U_{C}^{\top}{\bm{f}}={\bm{b}}$, which leads to

\dot{\bm{v}}_{k}=U_{C}^{\top}\dot{\bm{w}}_{k}=U_{C}^{\top}\mathbb{E}_{q=m}\left[g_{h_{k}}h_{k}^{\prime}{\bm{f}}\right]=\mathbb{E}_{q=m}\left[g_{h_{k}}h_{k}^{\prime}{\bm{b}}\right]    (12)

Unnormalized attention ($A:=\mathrm{const}$). In this case, we have ${\bm{b}}=\sigma({\bm{z}}_{m})\circ{\bm{x}}/A$ and $L=\mathrm{diag}(\sigma^{\prime}({\bm{z}}_{m})\circ{\bm{x}})/A=\mathrm{diag}\left(\frac{\sigma^{\prime}({\bm{z}}_{m})}{\sigma({\bm{z}}_{m})}\right)\mathrm{diag}({\bm{b}})$, and thus

\dot{\bm{z}}_{m}=\sum_{k}\mathbb{E}_{q=m}\left[g_{h_{k}}h^{\prime}_{k}L^{\top}{\bm{v}}_{k}\right]=\mathrm{diag}\left(\frac{\sigma^{\prime}({\bm{z}}_{m})}{\sigma({\bm{z}}_{m})}\right)\sum_{k}\mathbb{E}_{q=m}\left[g_{h_{k}}h^{\prime}_{k}{\bm{b}}\right]\circ{\bm{v}}_{k}    (13)
=\mathrm{diag}\left(\frac{\sigma^{\prime}({\bm{z}}_{m})}{\sigma({\bm{z}}_{m})}\right)\sum_{k}\dot{\bm{v}}_{k}\circ{\bm{v}}_{k}    (14)

which leads to

\mathrm{diag}\left(\frac{\sigma({\bm{z}}_{m})}{\sigma^{\prime}({\bm{z}}_{m})}\right)\dot{\bm{z}}_{m}=\sum_{k}\dot{\bm{v}}_{k}\circ{\bm{v}}_{k}    (15)

Therefore, for linear attention, $\sigma({\bm{z}}_{m})/\sigma^{\prime}({\bm{z}}_{m})={\bm{z}}_{m}$; integrating both sides gives ${\bm{z}}_{m}^{2}(t)=\sum_{k}{\bm{v}}^{2}_{k}(t)+{\bm{c}}$. For exp attention, $\sigma({\bm{z}}_{m})/\sigma^{\prime}({\bm{z}}_{m})=1$; integrating both sides gives ${\bm{z}}_{m}(t)=\frac{1}{2}\sum_{k}{\bm{v}}^{2}_{k}(t)+{\bm{c}}$.

Softmax attention. In this case, we have L=diag(𝒃)𝒃𝒃L=\mathrm{diag}({\bm{b}})-{\bm{b}}{\bm{b}}^{\top}. Therefore,

𝔼q=m[ghkhkdiag(𝒃)]UC𝒘k=𝔼q=m[ghkhk𝒃]𝒗k=𝒗˙k𝒗k\mathbb{E}_{q=m}\left[g_{h_{k}}h^{\prime}_{k}\mathrm{diag}({\bm{b}})\right]U_{C}^{\top}{\bm{w}}_{k}=\mathbb{E}_{q=m}\left[g_{h_{k}}h^{\prime}_{k}{\bm{b}}\right]\circ{\bm{v}}_{k}=\dot{\bm{v}}_{k}\circ{\bm{v}}_{k} (16)

where \circ is the Hadamard (element-wise) product. Now Therefore, we have:

𝔼q=m[ghkhk𝒃]UC𝒘k=𝒗˙k𝒗k\mathbb{E}_{q=m}\left[g_{h_{k}}h^{\prime}_{k}{\bm{b}}^{\top}\right]U_{C}^{\top}{\bm{w}}_{k}=\dot{\bm{v}}_{k}^{\top}{\bm{v}}_{k} (17)

Given the assumption that 𝒃{\bm{b}} is uncorrelated with kghkhk𝒃\sum_{k}g_{h_{k}}h^{\prime}_{k}{\bm{b}} (e.g., due to top-down gradient information), and let 𝒃¯m=𝔼q=m[𝒃]\bar{\bm{b}}_{m}=\mathbb{E}_{q=m}\left[{\bm{b}}\right], we have:

𝒛˙m=k𝒗˙k𝒗k𝒃¯m𝒗˙k𝒗k\dot{\bm{z}}_{m}=\sum_{k}\dot{\bm{v}}_{k}\circ{\bm{v}}_{k}-\bar{\bm{b}}_{m}\dot{\bm{v}}_{k}^{\top}{\bm{v}}_{k} (18)

If we further assume that 𝒃¯m\bar{\bm{b}}_{m} is constant over time, then we can integrate both sides to get a closed-form relation between 𝒛m(t){\bm{z}}_{m}(t) and {𝒗k(t)}\{{\bm{v}}_{k}(t)\}:

𝒛m(t)=12k(𝒗k2𝒗k22𝒃¯m)+𝒄{\bm{z}}_{m}(t)=\frac{1}{2}\sum_{k}\left({\bm{v}}^{2}_{k}-\|{\bm{v}}_{k}\|_{2}^{2}\bar{\bm{b}}_{m}\right)+{\bm{c}} (19)
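The integration step above (with \bar{\bm{b}}_{m} treated as a constant) can be checked symbolically. The following short sketch, which is ours and uses sympy with toy sizes, differentiates the closed form in Eqn. 19 and recovers Eqn. 18:

```python
# Symbolic sanity check (ours; toy sizes) that differentiating the closed form
# in Eqn. 19 w.r.t. time recovers Eqn. 18 when \bar{b}_m is held constant.
import sympy as sp

t = sp.symbols('t')
K, d = 2, 3                                        # number of MLP nodes / context tokens (toy)
b_bar = sp.Matrix(sp.symbols('b0:3'))              # constant \bar{b}_m
v = [sp.Matrix([sp.Function(f'v{k}_{i}')(t) for i in range(d)]) for k in range(K)]

# z_m(t) from Eqn. 19 (the integration constant c is dropped).
z = sp.Rational(1, 2) * sum(
    (vk.applyfunc(lambda e: e**2) - (vk.T * vk)[0, 0] * b_bar for vk in v),
    sp.zeros(d, 1),
)
lhs = z.diff(t)

# Right-hand side of Eqn. 18: sum_k ( v'_k ∘ v_k - \bar{b}_m v'_k^T v_k ).
rhs = sum(
    (vk.diff(t).multiply_elementwise(vk) - (vk.diff(t).T * vk)[0, 0] * b_bar for vk in v),
    sp.zeros(d, 1),
)
print(sp.simplify(lhs - rhs))                      # prints the zero vector
```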

See Theorem 2

Proof.

Due to the assumption, we have:

v˙l=𝔼q=m[ghkxl]exp(zml)/A=Δlmexp(zml)/A\dot{v}_{l}=\mathbb{E}_{q=m}\left[g_{h_{k}}x_{l}\right]\exp(z_{ml})/A=\Delta_{lm}\exp(z_{ml})/A (20)

where Δlm:=𝔼q=m[ghkxl]\Delta_{lm}:=\mathbb{E}_{q=m}\left[g_{h_{k}}x_{l}\right]. If xl[i]=(l|m,y[i])x_{l}[i]=\mathbb{P}(l|m,y[i]), then Δlm=𝔼l,q=m[ghk](l|m)\Delta_{lm}=\mathbb{E}_{l,q=m}\left[g_{h_{k}}\right]\mathbb{P}(l|m). Note that for a linear model, Δlm\Delta_{lm} is constant over time.

Plugging in the closed-form solution for exp attention, the dynamics becomes

v˙l=Δlmexp(vl2/2+cl)/A\dot{v}_{l}=\Delta_{lm}\exp(v_{l}^{2}/2+c_{l})/A (21)

Assuming cl=0c_{l}=0, then for any two tokens lll\neq l^{\prime}, we get

v˙lv˙l=Δlmexp(zml)Δlmexp(zml)=Δlmexp(vl2/2)Δlmexp(vl2/2)\frac{\dot{v}_{l}}{\dot{v}_{l^{\prime}}}=\frac{\Delta_{lm}\exp(z_{ml})}{\Delta_{l^{\prime}m}\exp(z_{ml^{\prime}})}=\frac{\Delta_{lm}\exp(v^{2}_{l}/2)}{\Delta_{l^{\prime}m}\exp(v^{2}_{l^{\prime}}/2)} (22)

which can be integrated using the erf()\mathrm{erf}(\cdot) function (i.e., the Gauss error function: erf(x)=2π0xet2dt\mathrm{erf}(x)=\frac{2}{\sqrt{\pi}}\int_{0}^{x}e^{-t^{2}}\mathrm{d}t):

\frac{\mathrm{erf}\left(v_{l}(t)/\sqrt{2}\right)}{\Delta_{lm}}=\frac{\mathrm{erf}(v_{l^{\prime}}(t)/\sqrt{2})}{\Delta_{l^{\prime}m}}+c_{ll^{\prime}} (23)

if 𝒗(0)=0{\bm{v}}(0)=0, then cll=0c_{ll^{\prime}}=0. ∎
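The invariant in Eqn. 23 can also be checked numerically. The sketch below is our own illustration: the values of \Delta_{lm}, the constant AA, and the step size are arbitrary choices; it integrates Eqn. 21 with forward Euler from 𝒗(0)=0{\bm{v}}(0)=0 and verifies that \mathrm{erf}(v_{l}(t)/\sqrt{2})/\Delta_{lm} agrees across tokens:

```python
# Numerical check (ours; Delta values, A, and dt are arbitrary) of Eqn. 23:
# integrate dv_l/dt = Delta_lm * exp(v_l^2 / 2) / A from v(0) = 0 by forward
# Euler and verify that erf(v_l / sqrt(2)) / Delta_lm agrees across tokens.
import math

delta = [1.0, 0.5, 0.2]            # Delta_lm for three contextual tokens
A, dt, T = 10.0, 1e-4, 2.0
v = [0.0] * len(delta)

t = 0.0
while t < T:
    v = [vl + dt * d_ * math.exp(vl * vl / 2.0) / A for vl, d_ in zip(v, delta)]
    t += dt

invariants = [math.erf(vl / math.sqrt(2.0)) / d_ for vl, d_ in zip(v, delta)]
print(v)             # v_l grows faster for larger Delta_lm
print(invariants)    # entries agree up to O(dt) integration error
```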

A.3 Dynamics of Nonlinear activations (Sec. 4)

A.3.1 Without self-attention (or equivalently, with uniform attention)

Lemma 1 (Expectation of Hyperplane function under Isotropic distribution).

For any isotropic distribution p(𝐱𝐱¯)p({\bm{x}}-\bar{\bm{x}}) with mean 𝐱¯\bar{\bm{x}} in a subspace spanned by orthonormal bases RR, if 𝐯𝟎{\bm{v}}\neq{\bm{0}}, we have:

𝔼p[𝒙ψ(𝒗𝒙+ξ)]=θ1(r𝒗)𝒗2𝒙¯+θ2(r𝒗)𝒗23RR𝒗,𝔼p[ψ(𝒗𝒙+ξ)]=θ1(r𝒗)𝒗2\mathbb{E}_{p}\left[{\bm{x}}\psi({\bm{v}}^{\top}{\bm{x}}+\xi)\right]=\frac{\theta_{1}(r_{{\bm{v}}})}{\|{\bm{v}}\|_{2}}\bar{\bm{x}}+\frac{\theta_{2}(r_{{\bm{v}}})}{\|{\bm{v}}\|^{3}_{2}}RR^{\top}{\bm{v}},\quad\quad\mathbb{E}_{p}\left[\psi({\bm{v}}^{\top}{\bm{x}}+\xi)\right]=\frac{\theta_{1}(r_{{\bm{v}}})}{\|{\bm{v}}\|_{2}} (24)

where r𝐯:=𝐯𝐱¯+ξr_{{\bm{v}}}:={\bm{v}}^{\top}\bar{\bm{x}}+\xi is the (signed) distance between the distribution mean 𝐱¯\bar{\bm{x}} and the affine hyperplane (𝐯,ξ)({\bm{v}},\xi). θ1(r)\theta_{1}(r) and θ2(r)\theta_{2}(r) only depend on ψ\psi and the underlying distribution, but not on 𝐯{\bm{v}}. Additionally,

  • If ψ(r)\psi(r) is monotonically increasing, then θ1(r)\theta_{1}(r) is also monotonically increasing;

  • If ψ(r)0\psi(r)\geq 0, then θ1(r)0\theta_{1}(r)\geq 0;

  • If ψ()=0\psi(-\infty)=0, ψ(+)=1\psi(+\infty)=1, then θ1()=0\theta_{1}(-\infty)=0 and θ1(+)=1\theta_{1}(+\infty)=1;

  • If ψ()=0\psi(-\infty)=0, then θ2()=0\theta_{2}(-\infty)=0.

Proof.

Note that 𝒙{\bm{x}}^{\prime} is isotropic in span(RR) and thus p(𝒙)p({\bm{x}}^{\prime}) depends only on 𝒙\|{\bm{x}}^{\prime}\|; we let p0:++p_{0}:\mathbb{R}^{+}\rightarrow\mathbb{R}^{+} satisfy p0(𝒙)=p(𝒙)p_{0}(\|{\bm{x}}^{\prime}\|)=p({\bm{x}}^{\prime}). Here and below, \bm{\mu}:=\bar{\bm{x}} denotes the mean and {\bm{x}}^{\prime}:={\bm{x}}-\bm{\mu}. Our goal is to calculate

𝔼p[𝒙ψ(𝒘𝒙+ξ)]\displaystyle\mathbb{E}_{p}\left[{\bm{x}}\psi({\bm{w}}^{\top}{\bm{x}}+\xi)\right] =\displaystyle= span(R)𝒙ψ(𝒘𝒙+ξ)p(𝒙𝝁)d𝒙\displaystyle\int_{\text{span}(R)}{\bm{x}}\psi({\bm{w}}^{\top}{\bm{x}}+\xi)p({\bm{x}}-\bm{\mu})\mathrm{d}{\bm{x}} (25)
=\displaystyle= span(R)(𝒙+𝝁)ψ(𝒘𝒙+r𝒘)p(𝒙)d𝒙\displaystyle\int_{\text{span}(R)}({\bm{x}}^{\prime}+\bm{\mu})\psi({\bm{w}}^{\top}{\bm{x}}^{\prime}+r_{{\bm{w}}})p({\bm{x}}^{\prime})\mathrm{d}{\bm{x}}^{\prime} (26)

where 𝒙:=𝒙𝝁{\bm{x}}^{\prime}:={\bm{x}}-\bm{\mu} is isotropic. Since RR𝒘RR^{\top}{\bm{w}} is the projection of 𝒘{\bm{w}} onto the subspace span(RR), we denote 𝒗:=RR𝒘{\bm{v}}:=RR^{\top}{\bm{w}} and y:=𝒘𝒙=𝒗𝒙y^{\prime}:={\bm{w}}^{\top}{\bm{x}}^{\prime}={\bm{v}}^{\top}{\bm{x}}^{\prime}, since 𝒙{\bm{x}}^{\prime} lies in span(RR). Let SS be any hyper-plane through 𝒗{\bm{v}}, which divides span(RR) into two symmetric parts V+V_{+} and VV_{-} (the boundary has measure zero and can be ignored). Then we have

P1\displaystyle P_{1} :=\displaystyle:= span(R)𝒙ψ(𝒘𝒙+r𝒘)p(𝒙)d𝒙\displaystyle\int_{\text{span}(R)}{\bm{x}}^{\prime}\psi({\bm{w}}^{\top}{\bm{x}}^{\prime}+r_{{\bm{w}}})p({\bm{x}}^{\prime})\mathrm{d}{\bm{x}}^{\prime} (27)
=\displaystyle= (V++V)𝒙ψ(𝒗𝒙+r𝒘)p(𝒙)d𝒙\displaystyle(\int_{V_{+}}+\int_{V_{-}}){\bm{x}}^{\prime}\psi({\bm{v}}^{\top}{\bm{x}}^{\prime}+r_{{\bm{w}}})p({\bm{x}}^{\prime})\mathrm{d}{\bm{x}}^{\prime} (28)
=\displaystyle= 2×V+𝒗𝒙𝒗𝒗𝒗ψ(𝒗𝒙+r𝒘)p(𝒙)d𝒙\displaystyle 2\times\int_{V_{+}}\frac{{\bm{v}}^{\top}{\bm{x}}^{\prime}}{\|{\bm{v}}\|}\cdot\frac{{\bm{v}}}{\|{\bm{v}}\|}\cdot\psi({\bm{v}}^{\top}{\bm{x}}^{\prime}+r_{{\bm{w}}})p({\bm{x}}^{\prime})\mathrm{d}{\bm{x}}^{\prime} (29)
=\displaystyle= {span(R)yψ(y+r𝒘)p(𝒙)d𝒙}𝒗𝒗2\displaystyle\{\int_{\text{span}(R)}y^{\prime}\psi(y^{\prime}+r_{{\bm{w}}})p({\bm{x}}^{\prime})\mathrm{d}{\bm{x}}^{\prime}\}\cdot\frac{{\bm{v}}}{\|{\bm{v}}\|^{2}} (30)

Eqn. 29 holds since for every 𝒙V+{\bm{x}}^{\prime}\in V_{+}, we can always find a unique 𝒙′′V{\bm{x}}^{\prime\prime}\in V_{-} defined as

𝒙′′=(𝒙𝒗𝒙𝒗2𝒗)+𝒗𝒙𝒗2𝒗=2y𝒗2𝒗𝒙{\bm{x}}^{\prime\prime}=-({\bm{x}}^{\prime}-\frac{{\bm{v}}^{\top}{\bm{x}}^{\prime}}{\|{\bm{v}}\|^{2}}{\bm{v}})+\frac{{\bm{v}}^{\top}{\bm{x}}^{\prime}}{\|{\bm{v}}\|^{2}}{\bm{v}}=\frac{2y^{\prime}}{\|{\bm{v}}\|^{2}}{\bm{v}}-{\bm{x}}^{\prime} (31)

where 𝒙′′{\bm{x}}^{\prime\prime} and 𝒙{\bm{x}}^{\prime} satisfy 𝒙′′=𝒙\|{\bm{x}}^{\prime\prime}\|=\|{\bm{x}}^{\prime}\|, 𝒗𝒙′′=𝒗𝒙{\bm{v}}^{\top}{\bm{x}}^{\prime\prime}={\bm{v}}^{\top}{\bm{x}}^{\prime}, and have opposite components ±(𝒙𝒗𝒙𝒗2𝒗)\pm({\bm{x}}^{\prime}-\frac{{\bm{v}}^{\top}{\bm{x}}^{\prime}}{\|{\bm{v}}\|^{2}}{\bm{v}}) perpendicular to 𝒗{\bm{v}}. Thus for the 𝒙{\bm{x}}^{\prime} in Eqn. 28, only the component parallel to 𝒗{\bm{v}} remains. Furthermore, let {𝒖1,,𝒖n1,𝒗/𝒗}\{{\bm{u}}_{1},\ldots,{\bm{u}}_{n-1},{\bm{v}}/\|{\bm{v}}\|\} be an orthonormal basis of span(RR) and denote xi:=𝒖i𝒙,i[n1]x^{\prime}_{i}:={\bm{u}}_{i}^{\top}{\bm{x}}^{\prime},\forall i\in[n-1]; then we have

P1\displaystyle P_{1} =\displaystyle= {yyψ(y+r𝒘)d(y𝒗)[x1xn1p(𝒙)dx1dxn1]}𝒗𝒗2\displaystyle\{\int_{y^{\prime}}y^{\prime}\psi(y^{\prime}+r_{{\bm{w}}})\mathrm{d}(\frac{y^{\prime}}{\|{\bm{v}}\|})[\int_{x^{\prime}_{1}}\cdots\int_{x^{\prime}_{n-1}}p({\bm{x}}^{\prime})\mathrm{d}x^{\prime}_{1}\ldots\mathrm{d}x^{\prime}_{n-1}]\}\cdot\frac{{\bm{v}}}{\|{\bm{v}}\|^{2}} (32)
=:\displaystyle=: {+yψ(y+r𝒘)pn(y)dy}𝒗𝒗3\displaystyle\{\int_{-\infty}^{+\infty}y^{\prime}\psi(y^{\prime}+r_{{\bm{w}}})p_{n}(y^{\prime})\mathrm{d}y^{\prime}\}\cdot\frac{{\bm{v}}}{\|{\bm{v}}\|^{3}} (33)

Here pn(y)p_{n}(y^{\prime}) is the probability density function of yy^{\prime} obtained from 𝒙{\bm{x}}^{\prime}. For the trivial case where n=1n=1, clearly pn(y)=p0(|y|)=p(y)p_{n}(y^{\prime})=p_{0}(|y^{\prime}|)=p(y^{\prime}). If n2n\geq 2, it can be further calculated as:

pn(y)\displaystyle p_{n}(y^{\prime}) =\displaystyle= x1xn1p0((x1)2++(xn1)2+(y)2)dx1dxn1\displaystyle\int_{x^{\prime}_{1}}\cdots\int_{x^{\prime}_{n-1}}p_{0}(\sqrt{(x^{\prime}_{1})^{2}+\ldots+(x^{\prime}_{n-1})^{2}+(y^{\prime})^{2}})\ \cdot\mathrm{d}x^{\prime}_{1}\ldots\mathrm{d}x^{\prime}_{n-1} (34)
=\displaystyle= 0+p0(y2+l2)Sn1(l)dl\displaystyle\int_{0}^{+\infty}p_{0}(\sqrt{y^{\prime 2}+l^{2}})\cdot S_{n-1}(l)\mathrm{d}l (35)
=\displaystyle= (n1)π(n1)/2Γ(n+12)0+p0(y2+l2)ln2dl\displaystyle\frac{(n-1)\pi^{(n-1)/{2}}}{\Gamma(\frac{n+1}{2})}\int_{0}^{+\infty}p_{0}(\sqrt{y^{\prime 2}+l^{2}})\cdot l^{n-2}\mathrm{d}l (36)
=\displaystyle= {2n/2πn/21(n3)!!0+p0(y2+l2)ln2dl,n is even2π(n1)/2(n32)!0+p0(y2+l2)ln2dl,n is odd\displaystyle\begin{cases}\begin{aligned} &\frac{2^{n/2}\pi^{n/2-1}}{(n-3)!!}\int_{0}^{+\infty}p_{0}(\sqrt{y^{\prime 2}+l^{2}})\cdot l^{n-2}\mathrm{d}l,&\quad&n\text{ is even}\\ &\frac{2\pi^{(n-1)/2}}{(\frac{n-3}{2})!}\int_{0}^{+\infty}p_{0}(\sqrt{y^{\prime 2}+l^{2}})\cdot l^{n-2}\mathrm{d}l,&\quad&n\text{ is odd}\end{aligned}\end{cases} (37)

where Sn(R)=nπn/2Γ(n/2+1)Rn1S_{n}(R)=\frac{n\pi^{n/2}}{\Gamma(n/2+1)}R^{n-1} represents the surface area of an nn-dimensional hyper-sphere of radius RR (here applied with R=l). Γ\Gamma denotes the gamma function, and we use the properties Γ(n+1)=n!\Gamma(n+1)=n! and Γ(n+12)=(2n1)!!π2n\Gamma(n+\frac{1}{2})=(2n-1)!!\sqrt{\pi}2^{-n} for any n+n\in\mathbb{N}^{+}.

Similarly, for the other term we have

P2\displaystyle P_{2} =\displaystyle= span(R)𝝁ψ(𝒘𝒙+r𝒘)p(𝒙)d𝒙\displaystyle\int_{\text{span}(R)}\bm{\mu}\cdot\psi({\bm{w}}^{\top}{\bm{x}}^{\prime}+r_{{\bm{w}}})p({\bm{x}}^{\prime})\mathrm{d}{\bm{x}}^{\prime} (38)
=\displaystyle= {+ψ(y+r𝒘)pn(y)dy}𝝁𝒗\displaystyle\{\int_{-\infty}^{+\infty}\psi(y^{\prime}+r_{{\bm{w}}})p_{n}(y^{\prime})\mathrm{d}y^{\prime}\}\cdot\frac{\bm{\mu}}{\|{\bm{v}}\|} (39)

Finally, let

θ1(r𝒘)\displaystyle\theta_{1}(r_{{\bm{w}}}) :=\displaystyle:= +ψ(y+r𝒘)pn(y)dy\displaystyle\int_{-\infty}^{+\infty}\psi(y^{\prime}+r_{{\bm{w}}})p_{n}(y^{\prime})\mathrm{d}y^{\prime} (41)
θ2(r𝒘)\displaystyle\theta_{2}(r_{{\bm{w}}}) :=\displaystyle:= +yψ(y+r𝒘)pn(y)dy\displaystyle\int_{-\infty}^{+\infty}y^{\prime}\cdot\psi(y^{\prime}+r_{{\bm{w}}})p_{n}(y^{\prime})\mathrm{d}y^{\prime} (42)

Then we arrive at the conclusion. ∎
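Lemma 1 can be sanity-checked by Monte-Carlo sampling. The sketch below is ours: it assumes a standard isotropic Gaussian in span(RR), a unit-norm 𝒗{\bm{v}} lying in span(RR), and \psi(u)=\mathbf{1}[u>0], for which \theta_{1} and \theta_{2} reduce to the Gaussian CDF \Phi and density \varphi; all of these specific choices are assumptions made only for the check.

```python
# Monte-Carlo sanity check (ours) of Lemma 1 with an isotropic standard Gaussian
# in span(R), a unit-norm v in span(R), and psi(u) = 1[u > 0], for which
# theta_1(r) = Phi(r) (Gaussian CDF) and theta_2(r) = phi(r) (Gaussian density).
import math
import numpy as np

rng = np.random.default_rng(0)
d, n, N = 16, 4, 500_000
R, _ = np.linalg.qr(rng.standard_normal((d, n)))      # orthonormal basis of the subspace
x_bar = R @ rng.standard_normal(n)                    # mean inside span(R)
v = R[:, 0]                                           # unit-norm direction in span(R)
xi = 0.3
r = float(v @ x_bar + xi)                             # signed distance r_v

X = x_bar + rng.standard_normal((N, n)) @ R.T         # isotropic samples around x_bar
psi = (X @ v + xi > 0).astype(float)

Phi = 0.5 * (1 + math.erf(r / math.sqrt(2)))          # theta_1(r)
phi = math.exp(-r * r / 2) / math.sqrt(2 * math.pi)   # theta_2(r)

lhs = (X * psi[:, None]).mean(axis=0)                 # E[x psi(v^T x + xi)]
rhs = Phi * x_bar + phi * (R @ (R.T @ v))             # theta_1 x_bar + theta_2 R R^T v
print(np.abs(lhs - rhs).max())                        # small (Monte-Carlo error)
print(psi.mean(), Phi)                                # E[psi(v^T x + xi)] vs theta_1(r)
```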

See Theorem 3

Proof.

Since the backpropagated gradient ghkg_{h_{k}} is constant within each mixture component, we have:

Δm\displaystyle\Delta_{m} :=\displaystyle:= 𝔼q=m[ghkhk𝒃]=j𝔼q=m,c=j[ghkhk𝒃][c=j]\displaystyle\mathbb{E}_{q=m}\left[g_{h_{k}}h^{\prime}_{k}{\bm{b}}\right]=\sum_{j}\mathbb{E}_{q=m,c=j}\left[g_{h_{k}}h^{\prime}_{k}{\bm{b}}\right]\mathbb{P}[c=j] (43)
=\displaystyle= j𝔼q=m,c=j[ghk][c=j]𝔼q=m,c=j[hk𝒃]\displaystyle\sum_{j}\mathbb{E}_{q=m,c=j}\left[g_{h_{k}}\right]\mathbb{P}[c=j]\mathbb{E}_{q=m,c=j}\left[h^{\prime}_{k}{\bm{b}}\right] (44)
=\displaystyle= jaj𝔼𝒙p(𝒙𝒙¯j)[𝒃ϕ(𝒘𝒇)]\displaystyle\sum_{j}a_{j}\mathbb{E}_{{\bm{x}}\sim p({\bm{x}}-\bar{\bm{x}}_{j})}\left[{\bm{b}}\phi^{\prime}({\bm{w}}^{\top}{\bm{f}})\right] (45)

Let ψ=ϕ\psi=\phi^{\prime}. Note that 𝒘𝒇=𝒘(Uc𝒃+𝒖q)=𝒗𝒃+ξ{\bm{w}}^{\top}{\bm{f}}={\bm{w}}^{\top}(U_{c}{\bm{b}}+{\bm{u}}_{q})={\bm{v}}^{\top}{\bm{b}}+\xi and with uniform attention 𝒃=𝒙{\bm{b}}={\bm{x}}, we have:

Δm=jaj𝔼𝒙p(𝒙𝒙¯j)[𝒙ψ(𝒗𝒙+ξ)]\Delta_{m}=\sum_{j}a_{j}\mathbb{E}_{{\bm{x}}\sim p({\bm{x}}-\bar{\bm{x}}_{j})}\left[{\bm{x}}\psi({\bm{v}}^{\top}{\bm{x}}+\xi)\right] (46)

Using Lemma 1 leads to the conclusion. ∎

Remarks. Note that if ϕ\phi is linear, then ψ1\psi\equiv 1, θ11\theta_{1}\equiv 1 and θ20\theta_{2}\equiv 0. In this case, θ1\theta_{1} is a constant, which marks a key difference between linear and nonlinear dynamics.

A.3.2 (Tentative) Critical Point Analysis of Dynamics in Theorem 3

Lemma 2 (Property of θ1,θ2\theta_{1},\theta_{2} with homogeneous activation).

If ϕ(x)=xϕ(x)\phi(x)=x\phi^{\prime}(x) is a homogeneous activation function and ψ=ϕ\psi=\phi^{\prime}, then we have:

ddr(θ2(r)+rθ1(r))=θ1(r)\frac{\mathrm{d}}{\mathrm{d}r}\left(\theta_{2}(r)+r\theta_{1}(r)\right)=\theta_{1}(r) (47)

Integrating both sides, we get:

θ2(r)+rθ1(r)=F(r):=0rθ1(r)dr+C\theta_{2}(r)+r\theta_{1}(r)=F(r):=\int_{0}^{r}\theta_{1}(r^{\prime})\mathrm{d}r^{\prime}+C (48)

Let r=0r=0 and it is clear that C=θ2(0)C=\theta_{2}(0). Thus

θ2(r)+rθ1(r)=F(r)=0rθ1(r)dr+θ2(0)\theta_{2}(r)+r\theta_{1}(r)=F(r)=\int_{0}^{r}\theta_{1}(r^{\prime})\mathrm{d}r^{\prime}+\theta_{2}(0) (49)

If ψ0\psi\geq 0, then F(r)F(r) is a monotonically increasing function with F(+)=+F(+\infty)=+\infty. Furthermore, if limrrθ1(r)=0\lim_{r\rightarrow-\infty}r\theta_{1}(r)=0 and ψ()=0\psi(-\infty)=0, then θ2()=0\theta_{2}(-\infty)=0 and F()=0F(-\infty)=0, and thus F(r)0F(r)\geq 0.

Proof.

Direct differentiation verifies Eqn. 47. ∎
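For a concrete instantiation (our own choice: \psi(u)=\mathbf{1}[u>0] with a standard Gaussian marginal p_{n}, so that \theta_{1}=\Phi and \theta_{2}=\varphi), Eqn. 49 can be verified numerically:

```python
# Numerical check (ours) of Eqn. 49 with psi(u) = 1[u > 0] and a standard
# Gaussian marginal p_n, i.e. theta_1 = Phi, theta_2 = phi, so the identity is
# phi(r) + r * Phi(r) = integral_0^r Phi(r') dr' + phi(0).
import math

Phi = lambda r: 0.5 * (1 + math.erf(r / math.sqrt(2)))
phi = lambda r: math.exp(-r * r / 2) / math.sqrt(2 * math.pi)

def F(r, steps=20_000):
    """Trapezoidal approximation of the integral of Phi from 0 to r, plus phi(0)."""
    h = r / steps
    integral = sum(0.5 * (Phi(i * h) + Phi((i + 1) * h)) * h for i in range(steps))
    return integral + phi(0.0)

for r in (-2.0, -0.5, 0.7, 1.5):
    print(r, phi(r) + r * Phi(r), F(r))   # the last two columns agree
```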

Overall, the dynamics can be quite complicated. We consider the special case of C=2C=2 mixture components: one positive (a+a_{+}, r+r_{+} and 𝒙¯+\bar{\bm{x}}_{+}) and one negative (aa_{-}, rr_{-} and 𝒙¯\bar{\bm{x}}_{-}) component.

Lemma 3 (Existence of critical point of dynamics with ReLU activation).

For any homogeneous activation ϕ(x)=xϕ(x)\phi(x)=x\phi^{\prime}(x), any stationary point of Eqn. 5 must satisfy jajF(rj)=0\sum_{j}a_{j}F(r_{j})=0, where F(r):=θ2(0)+0rθ1(r)drF(r):=\theta_{2}(0)+\int_{0}^{r}\theta_{1}(r^{\prime})\mathrm{d}r^{\prime} is a monotonically increasing function.

Proof.

We rewrite the dynamics equations for the case of nonlinear activation without attention:

𝒗˙=1𝒗2jajθ1(rj)𝒙¯j+1𝒗23jajθ2(rj)𝒗,ξ˙=1𝒗2jajθ1(rj)\dot{\bm{v}}=\frac{1}{\|{\bm{v}}\|_{2}}\sum_{j}a_{j}\theta_{1}(r_{j})\bar{\bm{x}}_{j}+\frac{1}{\|{\bm{v}}\|^{3}_{2}}\sum_{j}a_{j}\theta_{2}(r_{j}){\bm{v}},\qquad\dot{\xi}=\frac{1}{\|{\bm{v}}\|_{2}}\sum_{j}a_{j}\theta_{1}(r_{j}) (50)

Notice that 𝒙¯j𝒗=rjξ\bar{\bm{x}}_{j}^{\top}{\bm{v}}=r_{j}-\xi, which gives:

𝒗2𝒗𝒗˙\displaystyle\|{\bm{v}}\|_{2}{\bm{v}}^{\top}\dot{\bm{v}} =\displaystyle= jajθ1(rj)(rjξ)+jajθ2(rj)\displaystyle\sum_{j}a_{j}\theta_{1}(r_{j})(r_{j}-\xi)+\sum_{j}a_{j}\theta_{2}(r_{j}) (51)
=\displaystyle= jaj(rjθ1(rj)+θ2(rj))ξjajθ1(rj)\displaystyle\sum_{j}a_{j}(r_{j}\theta_{1}(r_{j})+\theta_{2}(r_{j}))-\xi\sum_{j}a_{j}\theta_{1}(r_{j}) (52)
=\displaystyle= jajF(rj)𝒗2ξξ˙\displaystyle\sum_{j}a_{j}F(r_{j})-\|{\bm{v}}\|_{2}\xi\dot{\xi} (53)

in which the last equality follows from the dynamics of ξ\xi and from Lemma 2. Leveraging the stationary-point condition (𝒗˙=0\dot{\bm{v}}=0 and ξ˙=0\dot{\xi}=0), we arrive at the following necessary condition at any stationary point:

jajF(rj)=0\sum_{j}a_{j}F(r_{j})=0 (54)

Note that in general, the scalar condition above is only necessary but not sufficient: Eqn. 50 has Mc+1M_{c}+1 equations, but we only have two scalar conditions (Eqn. 54 and 𝒗2ξ˙=jajθ1(rj)=0\|{\bm{v}}\|_{2}\dot{\xi}=\sum_{j}a_{j}\theta_{1}(r_{j})=0). However, we can get a better characterization of the stationary points if there are only two components a+a_{+} and aa_{-}:

A special case: one positive and one negative component. In this case, we have (here r+:=𝒗𝒙¯++ξr_{+}:={\bm{v}}^{\top}\bar{\bm{x}}_{+}+\xi and r:=𝒗𝒙¯+ξr_{-}:={\bm{v}}^{\top}\bar{\bm{x}}_{-}+\xi):

a+F(r+)aF(r)=0a_{+}F(r_{+})-a_{-}F(r_{-})=0 (55)

So the necessary and sufficient condition for (𝒗,ξ)({\bm{v}},\xi) to be a critical point is that

F(r+)F(r)=θ1(r+)θ1(r)=aa+\frac{F(r_{+})}{F(r_{-})}=\frac{\theta_{1}(r_{+})}{\theta_{1}(r_{-})}=\frac{a_{-}}{a_{+}} (56)

Without loss of generality, we consider the case where ϕ\phi is ReLU and ψ(r)=𝐈[r>0]\psi(r)=\mathbf{I}[r>0]. Since θ1\theta_{1} is a monotonically increasing function, it has an inverse θ11:(0,1)\theta_{1}^{-1}:(0,1)\rightarrow\mathbb{R} such that θ11(θ1(r))=r\theta_{1}^{-1}(\theta_{1}(r))=r for any rr\in\mathbb{R}. We then define G:(0,1)G:(0,1)\rightarrow\mathbb{R} by:

G(y)=F(θ11(y))G(y)=F(\theta_{1}^{-1}(y)) (57)

and y_{+}:=\theta_{1}(r_{+}), y_{-}:=\theta_{1}(r_{-}). Then if we can find some line lk:y=kxl_{k}:y=kx for some kk\in\mathbb{R} such that lkl_{k} has at least two points of intersection (yi,kyi),i=1,2(y_{i},ky_{i}),i=1,2 with the curve GG, and a/a+=y1/y2a_{-}/a_{+}=y_{1}/y_{2} or a/a+=y2/y1a_{-}/a_{+}=y_{2}/y_{1}, then we can always find some 𝒗{\bm{v}} and ξ\xi such that Eqn. 56 holds.

On the other hand, it’s easy to find that (Fig. 9):

dG(y)dy|y=θ1(x)\displaystyle\frac{\mathrm{d}G(y)}{\mathrm{d}y}\left.\right|_{y=\theta_{1}(x)} =\displaystyle= θ1(x)pn(x)>0\displaystyle\frac{\theta_{1}(x)}{p_{n}(x)}>0
limy1G(y)\displaystyle\lim_{y\rightarrow 1}G(y) =\displaystyle= limr+F(r)=+\displaystyle\lim_{r\rightarrow+\infty}F(r)=+\infty
limy0G(y)\displaystyle\lim_{y\rightarrow 0}G(y) =\displaystyle= limrF(r)=limrrθ1(r)\displaystyle\lim_{r\rightarrow-\infty}F(r)=\lim_{r\rightarrow-\infty}r\theta_{1}(r)
Figure 9: The plot of function G(y)G(y).

Note that since G(y+)/G(y)=y+/yG(y_{+})/G(y_{-})=y_{+}/y_{-}, we have G(y+)/y+=G(y)/yG(y_{+})/y_{+}=G(y_{-})/y_{-}, and thus (y+,G(y+))(y_{+},G(y_{+})) and (y,G(y))(y_{-},G(y_{-})) lie on the same straight line through the origin.

To find a sufficient condition, we focus on the range x0x\geq 0 and θ1(x)12\theta_{1}(x)\geq\frac{1}{2}. Then, for the line lk:y=kxl_{k}:y=kx (for some kk\in\mathbb{R}) to have at least two points of intersection with the curve GG, it suffices to require

G(θ~1(0))θ~1(0)dG(y)dy|y=θ~1(0)θ~2(0)pn(0)=pn(0)0+ypn(y)dy14\frac{G(\tilde{\theta}_{1}(0))}{\tilde{\theta}_{1}(0)}\geq\frac{\mathrm{d}G(y)}{\mathrm{d}y}\left.\right|_{y=\tilde{\theta}_{1}(0)}\iff\tilde{\theta}_{2}(0)\cdot p_{n}(0)=p_{n}(0)\int_{0}^{+\infty}y^{\prime}p_{n}(y^{\prime})\mathrm{d}y^{\prime}\geq\frac{1}{4} (58)

For convenience, let Slk:={(x,y)|y=kx}S_{l_{k}}:=\{(x,y)|y=kx\} and SG:={(x,y)|y=G(x)}S_{G}:=\{(x,y)|y=G(x)\} be the graphs of the two functions. Denote π1:2:π1((x,y))=x\pi_{1}:\mathbb{R}^{2}\rightarrow\mathbb{R}:\pi_{1}((x,y))=x for any x,yx,y\in\mathbb{R}, and π1(S)={π1(s)|sS}\pi_{1}(S)=\{\pi_{1}(s)|\forall s\in S\}. Therefore, if Eqn. 58 holds, then the following set 𝒮\mathcal{S} is not empty.

𝒮:=k{x2x1|x1x2π1(SlkSG)}\mathcal{S}:=\bigcup_{k\in\mathbb{R}}\{\frac{x_{2}}{x_{1}}\ |\ \forall x_{1}\neq x_{2}\in\pi_{1}(S_{l_{k}}\cap S_{G})\} (59)

Then Eqn. 5 has critical points if a+/a𝒮a_{+}/a_{-}\in\mathcal{S}. It is easy to see that s𝒮\forall s\in\mathcal{S}, s(12,1)(1,2)s\in(\frac{1}{2},1)\cup(1,2). Similar results also hold for other homogeneous activations.

Remarks. It is often the case that y<1/2y_{-}<1/2 and y+>1/2y_{+}>1/2, since G(y)G(y) is convex for y>1/2y>1/2 and there are at most two intersections between a convex function and a straight line. This means that r+>0r^{*}_{+}>0 and r=ξ<0r^{*}_{-}=\xi_{*}<0.

A.4 Several remarks

The intuition behind ξ\xi: Note that while node kk in the MLP layer does not have an explicit bias term, our analysis above demonstrates that there exists an “implicit bias” term ξk(t)\xi_{k}(t) embedded in the weight vector 𝒘k{\bm{w}}_{k}:

𝒘(t)=𝒘(0)+UC[𝒗(t)𝒗(0)]+𝒖mξ(t){\bm{w}}(t)={\bm{w}}(0)+U_{C}[{\bm{v}}(t)-{\bm{v}}(0)]+{\bm{u}}_{m}\xi(t) (60)

This bias term allows the query embedding 𝒖m{\bm{u}}_{m} to be encoded into the weight, and the negative bias ξ<0\xi^{*}<0 ensures that, given the query q=mq=m, a positive inner product between 𝒗{\bm{v}}_{*} (i.e., the “pattern template”) and the input contextual tokens is needed in order to activate node kk.

Pattern superposition. Note that due to this mechanism, a single weight vector 𝒘{\bm{w}} may contain multiple query vectors (e.g., 𝒖m1{\bm{u}}_{m_{1}} and 𝒖m2{\bm{u}}_{m_{2}}) and their associated pattern templates (e.g., 𝒗m1{\bm{v}}_{m_{1}} and 𝒗m2{\bm{v}}_{m_{2}}), as long as they are orthogonal to each other. Specifically, if 𝒘=𝒗m1ξm1𝒖m1+𝒗m2ξm2𝒖m2{\bm{w}}={\bm{v}}_{m_{1}}-\xi_{m_{1}}{\bm{u}}_{m_{1}}+{\bm{v}}_{m_{2}}-\xi_{m_{2}}{\bm{u}}_{m_{2}}, then it can match both pattern 1 and pattern 2. We call this “pattern superposition”, as demonstrated in Fig. 10 and in the numerical sketch below.

Figure 10: Examples of pattern superposition: the same neuron in MLP hidden layers can be activated by multiple irrelevant combinations of tokens (A and B in each group, e.g., the same neuron activated by both “Every morning” and “In the realm of physics”), in Pythia-70M and Pythia-160M models. Bold tokens are what the query token attends to.
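The following minimal numerical sketch is our own illustration of pattern superposition; the embedding dimension, the toy token embeddings, and the bias value are assumptions, not quantities from the trained models:

```python
# Toy illustration (ours; embeddings, dimension, and bias value are assumed)
# of pattern superposition: one weight vector w stores two (query, template)
# pairs in orthogonal directions, so the same neuron fires for two unrelated
# token combinations but stays silent for an unrelated context.
import numpy as np

rng = np.random.default_rng(0)
d = 64
# Random orthonormal embeddings: 2 query tokens, 4 contextual tokens, 1 distractor.
E, _ = np.linalg.qr(rng.standard_normal((d, 7)))
u_q1, u_q2, e_a, e_b, e_c, e_d, e_z = E.T

xi = -0.5                              # implicit negative bias per query (assumed value)
v1 = e_a + e_b                         # pattern template for query q1: context {a, b}
v2 = e_c + e_d                         # pattern template for query q2: context {c, d}
w = v1 + xi * u_q1 + v2 + xi * u_q2    # superposed weight, in the spirit of Eqn. 60

def pre_activation(context_tokens, u_query):
    """f = U_C b + u_q with uniform attention over the listed context tokens."""
    b = np.mean(context_tokens, axis=0)
    return w @ (b + u_query)

relu = lambda x: max(x, 0.0)
print("pattern 1 (q1 + {a,b}):", relu(pre_activation([e_a, e_b], u_q1)))  # > 0, fires
print("pattern 2 (q2 + {c,d}):", relu(pre_activation([e_c, e_d], u_q2)))  # > 0, fires
print("unrelated (q1 + {z})  :", relu(pre_activation([e_z], u_q1)))       # = 0, silent
```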
Lemma 4.

If ϕ(x)\phi(x) is homogeneous, i.e., ϕ(x)=ϕ(x)x\phi(x)=\phi^{\prime}(x)x, then there exist constants c,c+c_{-},c_{+}\in\mathbb{R} depending on ϕ\phi such that \phi^{\prime}(x)=c_{-}\mathbf{1}[x<0]+c_{+}\mathbf{1}[x>0], and thus

\frac{\mathrm{d}\theta_{1}}{\mathrm{d}r}=(c_{+}-c_{-})p_{n}(r),\quad\frac{\mathrm{d}\theta_{2}}{\mathrm{d}r}=-(c_{+}-c_{-})r\cdot p_{n}(r) (61)
Proof.

For any x>0x>0, we have

ϕ(x)\displaystyle\phi^{\prime}(x) =\displaystyle= limδx0+ϕ(x+δx)ϕ(x)δx\displaystyle\lim_{\delta x\rightarrow 0+}\frac{\phi(x+\delta x)-\phi(x)}{\delta x} (62)
=\displaystyle= limδx0+ϕ(x+δx)ϕ(x)δxx+limδx0ϕ(x+δx)\displaystyle\lim_{\delta x\rightarrow 0+}\frac{\phi^{\prime}(x+\delta x)-\phi^{\prime}(x)}{\delta x}\cdot x+\lim_{\delta x\rightarrow 0}\phi^{\prime}(x+\delta x) (63)
=\displaystyle= xlimδx0+ϕ(x+δx)ϕ(x)δx+ϕ(x)\displaystyle x\cdot\lim_{\delta x\rightarrow 0+}\frac{\phi^{\prime}(x+\delta x)-\phi^{\prime}(x)}{\delta x}+\phi^{\prime}(x) (64)

So for any x>0x>0, ϕ(x)\phi^{\prime}(x) must be constant, and similar results hold for x<0x<0. Then by direct calculation, we can get the results. ∎

A.4.1 With self-attention

Lemma 5.

Let g(y):=1ey2yg(y):=\frac{1-e^{-y^{2}}}{y}. Then maxy0g(y)12\max_{y\geq 0}g(y)\leq\frac{1}{\sqrt{2}}.

Proof.

Any stationary point yy_{*} must satisfy g^{\prime}(y_{*})=0, which gives:

ey2=12y2+1e^{-y_{*}^{2}}=\frac{1}{2y_{*}^{2}+1} (66)

Therefore, at any stationary points, we have:

g(y)=2y2y2+1=22y+y112g(y_{*})=\frac{2y_{*}}{2y_{*}^{2}+1}=\frac{2}{2y_{*}+y_{*}^{-1}}\leq\frac{1}{\sqrt{2}} (67)

Since g(0)=g(+)=0g(0)=g(+\infty)=0, the conclusion follows. ∎

Lemma 6 (Bound of Gaussian integral).

Let G(y):=ey2/20yex2/2dxG(y):=e^{-y^{2}/2}\int_{0}^{y}e^{x^{2}/2}\mathrm{d}x, then 0G(y)10\leq G(y)\leq 1 for y0y\geq 0.

Proof.

G(y)0G(y)\geq 0 is obvious. Note that

G(y)\displaystyle G(y) :=\displaystyle:= ey2/20yex2/2dxey2/20yexy/2dx=2y(1ey2/2)=2g(y/2)\displaystyle e^{-y^{2}/2}\int_{0}^{y}e^{x^{2}/2}\mathrm{d}x\leq e^{-y^{2}/2}\int_{0}^{y}e^{xy/2}\mathrm{d}x=\frac{2}{y}\left(1-e^{-y^{2}/2}\right)=\sqrt{2}g(y/\sqrt{2})

Applying Lemma 5 gives the conclusion. ∎
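Both bounds can be checked numerically on a grid; the short sketch below is ours and purely illustrative:

```python
# Grid-based numerical check (ours) of Lemma 5 and Lemma 6.
import numpy as np

y = np.linspace(1e-6, 20.0, 200_001)
g = (1 - np.exp(-y**2)) / y
print(g.max(), 1 / np.sqrt(2))            # max_y g(y) ~ 0.64 < 1/sqrt(2) ~ 0.71

def G(u, steps=20_000):
    """G(u) = exp(-u^2/2) * integral_0^u exp(x^2/2) dx via a trapezoidal sum."""
    x = np.linspace(0.0, u, steps)
    fx = np.exp(x**2 / 2)
    integral = (x[1] - x[0]) * (fx.sum() - 0.5 * (fx[0] + fx[-1]))
    return np.exp(-u**2 / 2) * integral

print(max(G(u) for u in np.linspace(0.1, 6.0, 60)))   # stays below 1
```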

See Theorem 4

Proof.

We first consider the case 𝝁>0\bm{\mu}>0. We can write down the dynamics in a component-wise manner, since all components share the same scalar constant:

v˙jv˙k=(μjvj)evj2/2(μkvk)evk2/2\frac{\dot{v}_{j}}{\dot{v}_{k}}=\frac{(\mu_{j}-v_{j})e^{v_{j}^{2}/2}}{(\mu_{k}-v_{k})e^{v_{k}^{2}/2}} (68)

which gives the following separable form:

v˙jevj2/2μjvj=v˙kevk2/2μkvk\frac{\dot{v}_{j}e^{-v^{2}_{j}/2}}{\mu_{j}-v_{j}}=\frac{\dot{v}_{k}e^{-v^{2}_{k}/2}}{\mu_{k}-v_{k}} (69)

Let

F(r,r0,μ):=r0μrμev2/2μvdv=r0reμ2x2/21xdx(x=v/μ)F(r,r_{0},\mu):=\int_{r_{0}\mu}^{r\mu}\frac{e^{-v^{2}/2}}{\mu-v}\mathrm{d}v=\int_{r_{0}}^{r}\frac{e^{-\mu^{2}x^{2}/2}}{1-x}\mathrm{d}x\quad\quad(x=v/\mu) (70)

Integrating both sides of Eqn. 69 from t=0t=0 to tt, the dynamics must satisfy the following equation at time tt:

F(rj(t),rj(0),μj)=F(rk(t),rk(0),μk)F(r_{j}(t),r_{j}(0),\mu_{j})=F(r_{k}(t),r_{k}(0),\mu_{k}) (71)

where rj(t):=vj(t)/μjr_{j}(t):=v_{j}(t)/\mu_{j}. According to the dynamics, rj(t)1r_{j}(t)\rightarrow 1 and the question is how fast the convergence is. Depending on the initialization, rj(t)>1r_{j}(t)>1 or rj(t)<1r_{j}(t)<1.

Eqn. 71 implicitly gives the relationship between rj(t)r_{j}(t) and rk(t)r_{k}(t) (and thus δj(t)\delta_{j}(t) and δk(t)\delta_{k}(t)). Now the question is how to bound F(r,r0,μ)F(r,r_{0},\mu), which does not have a closed-form solution.

Note that we have:

Fμ\displaystyle\frac{\partial F}{\partial\mu} =\displaystyle= μr0rx2eμ2x2/21xdx\displaystyle-\mu\int_{r_{0}}^{r}\frac{x^{2}e^{-\mu^{2}x^{2}/2}}{1-x}\mathrm{d}x (72)
=\displaystyle= μr0r1x21xeμ2x2/2dxμr0reμ2x2/21xdx\displaystyle\mu\int_{r_{0}}^{r}\frac{1-x^{2}}{1-x}e^{-\mu^{2}x^{2}/2}\mathrm{d}x-\mu\int_{r_{0}}^{r}\frac{e^{-\mu^{2}x^{2}/2}}{1-x}\mathrm{d}x (73)
=\displaystyle= μr0r(1+x)eμ2x2/2dxμF(r,r0,μ)\displaystyle\mu\int_{r_{0}}^{r}(1+x)e^{-\mu^{2}x^{2}/2}\mathrm{d}x-\mu F(r,r_{0},\mu) (74)
=\displaystyle= π2[erf(rμ2)erf(r0μ2)]+1μ(er02μ2/2er2μ2/2)μF(r,r0,μ)\displaystyle\sqrt{\frac{\pi}{2}}\left[\mathrm{erf}\left(\frac{r\mu}{\sqrt{2}}\right)-\mathrm{erf}\left(\frac{r_{0}\mu}{\sqrt{2}}\right)\right]+\frac{1}{\mu}(e^{-r_{0}^{2}\mu^{2}/2}-e^{-r^{2}\mu^{2}/2})-\mu F(r,r_{0},\mu) (75)

Let

ζ(r,r0,μ):=π2[erf(rμ2)erf(r0μ2)]+1μ(er02μ2/2er2μ2/2)\zeta(r,r_{0},\mu):=\sqrt{\frac{\pi}{2}}\left[\mathrm{erf}\left(\frac{r\mu}{\sqrt{2}}\right)-\mathrm{erf}\left(\frac{r_{0}\mu}{\sqrt{2}}\right)\right]+\frac{1}{\mu}(e^{-r_{0}^{2}\mu^{2}/2}-e^{-r^{2}\mu^{2}/2}) (76)

Applying Lemma 5 and noticing that μ>0\mu>0, we have

|ζ(r,r0,μ)|2π+2(|r0|+|r|)/22π+max(2|r0|,|r0|+1)=:M(r0)|\zeta(r,r_{0},\mu)|\leq\sqrt{2\pi}+\sqrt{2}(|r_{0}|+|r|)/\sqrt{2}\leq\sqrt{2\pi}+\max(2|r_{0}|,|r_{0}|+1)=:M(r_{0}) (77)

which means that |ζ(r,r0,μ)||\zeta(r,r_{0},\mu)| is uniformly bounded, regardless of μ\mu and r(t)r(t) (note that rr is bounded and will converge to 11 from the dynamics). Integrating both sides, we have:

μ(eμ2/2F(r,r0,μ))\displaystyle\frac{\partial}{\partial\mu}\left(e^{\mu^{2}/2}F(r,r_{0},\mu)\right) =\displaystyle= ζ(r,r0,μ)eμ2/2\displaystyle\zeta(r,r_{0},\mu)e^{\mu^{2}/2} (78)
eμ2/2F(r,r0,μ)F(r,r0,0)\displaystyle e^{\mu^{2}/2}F(r,r_{0},\mu)-F(r,r_{0},0) =\displaystyle= 0μζ(r,r0,x)ex2/2dx\displaystyle\int_{0}^{\mu}\zeta(r,r_{0},x)e^{x^{2}/2}\mathrm{d}x (79)
F(r,r0,μ)\displaystyle F(r,r_{0},\mu) =\displaystyle= eμ2/2F(r,r0,0)+eμ2/20μζ(r,r0,x)ex2/2dx\displaystyle e^{-\mu^{2}/2}F(r,r_{0},0)+e^{-\mu^{2}/2}\int_{0}^{\mu}\zeta(r,r_{0},x)e^{x^{2}/2}\mathrm{d}x (80)

Note that F(r,r0,0)F(r,r_{0},0) has a closed form:

F(r,r0,0)=r0r11xdx=ln1r01rF(r,r_{0},0)=\int_{r_{0}}^{r}\frac{1}{1-x}\mathrm{d}x=\ln\frac{1-r_{0}}{1-r} (81)

which works for both r0<r<1r_{0}<r<1 and r0>r>1r_{0}>r>1 (the situation where 11 lies between r0r_{0} and rr cannot happen). Using the mean value theorem, we have:

F(r,r0,μ)=eμ2/2ln1r01r+ζ(r,r0,μ¯)eμ2/20μex2/2dxF(r,r_{0},\mu)=e^{-\mu^{2}/2}\ln\frac{1-r_{0}}{1-r}+\zeta(r,r_{0},\bar{\mu})e^{-\mu^{2}/2}\int_{0}^{\mu}e^{x^{2}/2}\mathrm{d}x (82)

Applying Lemma 6, we have the following bound for F(r,μ)F(r,\mu):

M(r0)F(r,μ)eμ2/2ln1r01rM(r0)-M(r_{0})\leq F(r,\mu)-e^{-\mu^{2}/2}\ln\frac{1-r_{0}}{1-r}\leq M(r_{0}) (83)

When rr is close to 11 (near convergence), the term eμ2/2ln1r01re^{-\mu^{2}/2}\ln\frac{1-r_{0}}{1-r} (with fixed μ\mu and fixed r0r_{0}) is huge compared to the constant M(r0)M(r_{0}), which is 2π+1.54.0066\sqrt{2\pi}+1.5\approx 4.0066 for e.g., |r0|=1/2|r_{0}|=1/2, and thus F(r,μ)eμ2/2ln1r01rF(r,\mu)\rightarrow e^{-\mu^{2}/2}\ln\frac{1-r_{0}}{1-r}.

To be more concrete, noting that δ(t)=1v(t)/μ=1r(t)\delta(t)=1-v(t)/\mu=1-r(t), we let

ρ(δ(t),μ)=F(1δ(t),1δ(0),μ)eμ2/2lnδ(0)δ(t)(M(r0),M(r0))\rho(\delta(t),\mu)=F(1-\delta(t),1-\delta(0),\mu)-e^{-\mu^{2}/2}\ln\frac{\delta(0)}{\delta(t)}\in(-M(r_{0}),M(r_{0})) (84)

And using Eqn. 71, we have:

F(1δj(t),1δj(0),μj)=F(1δk(t),1δk(0),μk)F(1-\delta_{j}(t),1-\delta_{j}(0),\mu_{j})=F(1-\delta_{k}(t),1-\delta_{k}(0),\mu_{k}) (85)

Then

λjk(t)\displaystyle\lambda_{jk}(t) :=\displaystyle:= ρ(δk(t),μk)ρ(δj(t),μj)\displaystyle\rho(\delta_{k}(t),\mu_{k})-\rho(\delta_{j}(t),\mu_{j}) (86)
=\displaystyle= eμj2/2lnδj(0)δj(t)eμk2/2lnδk(0)δk(t)\displaystyle e^{-\mu_{j}^{2}/2}\ln\frac{\delta_{j}(0)}{\delta_{j}(t)}-e^{-\mu_{k}^{2}/2}\ln\frac{\delta_{k}(0)}{\delta_{k}(t)} (87)

and |λjk(t)|M(rj(0))+M(rk(0))|\lambda_{jk}(t)|\leq M(r_{j}(0))+M(r_{k}(0)). Then we arrive at the conclusion. ∎
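The boundedness of \lambda_{jk}(t) can be illustrated numerically. The sketch below is ours: the values of \mu_{j}, the step size, and the stopping rule are arbitrary choices; it integrates the component-wise dynamics \dot{v}_{j}=(\mu_{j}-v_{j})e^{v_{j}^{2}/2} from 𝒗(0)=0{\bm{v}}(0)=0 (any common time-dependent positive factor does not affect the relation between components) and tracks \lambda_{jk}(t) from Eqn. 87:

```python
# Numerical illustration (ours; mu values, dt, and the stopping rule are
# arbitrary) that lambda_jk(t) of Eqn. 87 stays uniformly bounded when
# integrating dv_j/dt = (mu_j - v_j) * exp(v_j^2 / 2) from v(0) = 0.
import math

mu = [1.0, 2.0]
v = [0.0, 0.0]
dt = 1e-4
lam_max = 0.0

# Run until the faster-converging delta_j = 1 - v_j / mu_j drops below 1e-3.
while min(1 - v[j] / mu[j] for j in range(2)) > 1e-3:
    v = [vj + dt * (m - vj) * math.exp(vj * vj / 2) for vj, m in zip(v, mu)]
    delta = [1 - vj / m for vj, m in zip(v, mu)]
    lam = (math.exp(-mu[0]**2 / 2) * math.log(1.0 / delta[0])
           - math.exp(-mu[1]**2 / 2) * math.log(1.0 / delta[1]))
    lam_max = max(lam_max, abs(lam))

bound = 2 * (math.sqrt(2 * math.pi) + 1.0)    # M(r_j(0)) + M(r_k(0)) with r(0) = 0
print(lam_max, bound)                         # lam_max stays below the bound
```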

A.5 Hierarchical Latent Tree Models (Section  5)

We formally introduce the definition of HBLT here. Let yαy_{\alpha} be a binary variable at layer ss (upper layer) and yβy_{\beta} be a binary variable at layer s1s-1 (lower layer). We use a 2×22\times 2 matrix Pβ|αP_{\beta|\alpha} to represent their conditional probability:

Pβ|α:=[[yβ|yα]]=[[yβ=0|yα=0][yβ=0|yα=1][yβ=1|yα=0][yβ=1|yα=1]]P_{\beta|\alpha}:=[\mathbb{P}[y_{\beta}|y_{\alpha}]]=\left[\begin{array}[]{cc}\mathbb{P}[y_{\beta}=0|y_{\alpha}=0]&\mathbb{P}[y_{\beta}=0|y_{\alpha}=1]\\ \mathbb{P}[y_{\beta}=1|y_{\alpha}=0]&\mathbb{P}[y_{\beta}=1|y_{\alpha}=1]\end{array}\right] (88)
Definition 1.

Define 2×22\times 2 matrix M(ρ):=12[1+ρ1ρ1ρ1+ρ]M(\rho):=\frac{1}{2}\left[\begin{array}[]{cc}1+\rho&1-\rho\\ 1-\rho&1+\rho\end{array}\right] and 22-dimensional vector 𝐩(ρ)=12[1+ρ,1ρ]{\bm{p}}(\rho)=\frac{1}{2}[1+\rho,1-\rho]^{\top} for ρ[1,1]\rho\in[-1,1].

Lemma 7 (Property of M(ρ)M(\rho)).

M(ρ)M(\rho) has the following properties:

  • M(ρ)M(\rho) is a symmetric matrix.

  • M(ρ)𝟏2=𝟏2M(\rho){\bm{1}}_{2}={\bm{1}}_{2}.

  • M(ρ1)M(ρ2)=M(ρ1ρ2)M(\rho_{1})M(\rho_{2})=M(\rho_{1}\rho_{2}). So matrix multiplication in {M(ρ)}ρ[1,1]\{M(\rho)\}_{\rho\in[-1,1]} is commutative and isomorphic to scalar multiplication.

  • M(ρ1)𝒑(ρ2)=𝒑(ρ1ρ2)M(\rho_{1}){\bm{p}}(\rho_{2})={\bm{p}}(\rho_{1}\rho_{2}).

Proof.

The first two are trivial properties. For the third one, notice that M(ρ)=12(𝟏𝟏T+ρ𝒆𝒆)M(\rho)=\frac{1}{2}({\bm{1}}{\bm{1}}^{T}+\rho{\bm{e}}{\bm{e}}^{\top}), in which 𝒆:=[1,1]{\bm{e}}:=[1,-1]^{\top}. Therefore, 𝒆𝒆=2{\bm{e}}^{\top}{\bm{e}}=2 and 𝟏𝒆=0{\bm{1}}^{\top}{\bm{e}}=0 and thus:

M(ρ1)M(ρ2)=14(𝟏𝟏T+ρ1𝒆𝒆)(𝟏𝟏T+ρ2𝒆𝒆)=12(𝟏𝟏+ρ1ρ2𝒆𝒆)=M(ρ1ρ2)M(\rho_{1})M(\rho_{2})=\frac{1}{4}({\bm{1}}{\bm{1}}^{T}+\rho_{1}{\bm{e}}{\bm{e}}^{\top})({\bm{1}}{\bm{1}}^{T}+\rho_{2}{\bm{e}}{\bm{e}}^{\top})=\frac{1}{2}({\bm{1}}{\bm{1}}^{\top}+\rho_{1}\rho_{2}{\bm{e}}{\bm{e}}^{\top})=M(\rho_{1}\rho_{2}) (89)

For the last one, note that 𝒑(ρ)=12(𝟏+ρ𝒆){\bm{p}}(\rho)=\frac{1}{2}({\bm{1}}+\rho{\bm{e}}) and the conclusion follows. ∎

Definition 2 (Definition of HBLT).

In an HBLT, Pβ|α=M(ρβ|α)P_{\beta|\alpha}=M(\rho_{\beta|\alpha}), where ρβ|α[1,1]\rho_{\beta|\alpha}\in[-1,1] is the uncertainty parameter. In particular, if ρβ|α=ρ\rho_{\beta|\alpha}=\rho for all parent-child pairs, then we just write the entire HBLT model as HBLT(ρ)\texttt{HBLT}(\rho).

Lemma 8.

For a latent yαy_{\alpha} and its descendant yγy_{\gamma}, we have:

Pγ|α=Pγ|β1Pβ1|β2Pβk|α=M(ργ|α)P_{\gamma|\alpha}=P_{\gamma|\beta_{1}}P_{\beta_{1}|\beta_{2}}\ldots P_{\beta_{k}|\alpha}=M\left(\rho_{\gamma|\alpha}\right) (90)

where ργ|α:=ργ|β1ρβ1|β2ρβk|α\rho_{\gamma|\alpha}:=\rho_{\gamma|\beta_{1}}\rho_{\beta_{1}|\beta_{2}}\ldots\rho_{\beta_{k}|\alpha} and αβ1β2βkγ\alpha\succ\beta_{1}\succ\beta_{2}\succ\ldots\succ\beta_{k}\succ\gamma is the descendant chain from yαy_{\alpha} to yγy_{\gamma}.

Proof.

Due to the tree structure of HBLT, we have:

[yγ|yα]=yβ1,yβ2,,yβk[yγ|yβ1][yβ1|yβ2][yβk|yα]\mathbb{P}[y_{\gamma}|y_{\alpha}]=\sum_{y_{\beta_{1}},y_{\beta_{2}},\ldots,y_{\beta_{k}}}\mathbb{P}[y_{\gamma}|y_{\beta_{1}}]\mathbb{P}[y_{\beta_{1}}|y_{\beta_{2}}]\ldots\mathbb{P}[y_{\beta_{k}}|y_{\alpha}] (91)

which is precisely how the entries of Pγ|β1Pβ1|β2Pβk|αP_{\gamma|\beta_{1}}P_{\beta_{1}|\beta_{2}}\ldots P_{\beta_{k}|\alpha} get computed. By leveraging the property of M(ρ)M(\rho), we arrive at the conclusion. ∎

See Theorem 5

Proof.

Let the common latent ancestor (CLA) of yβ1y_{\beta_{1}} and yβ2y_{\beta_{2}} be ycy_{c}, then we have:

[yβ1,yβ2]=yc[yβ1|yc][yβ2|yc][yc]\mathbb{P}[y_{\beta_{1}},y_{\beta_{2}}]=\sum_{y_{c}}\mathbb{P}[y_{\beta_{1}}|y_{c}]\mathbb{P}[y_{\beta_{2}}|y_{c}]\mathbb{P}[y_{c}] (92)

Let Pβ1β2=[[yβ1,yβ2]]P_{\beta_{1}\beta_{2}}=[\mathbb{P}[y_{\beta_{1}},y_{\beta_{2}}]], then we have:

Pβ1β2=M(ρβ1|c)D(c)M(ρβ2|c)P_{\beta_{1}\beta_{2}}=M(\rho_{\beta_{1}|c})D(c)M^{\top}(\rho_{\beta_{2}|c}) (93)

where D(c):=diag([yc])=12[1+ρc001ρc]D(c):=\mathrm{diag}(\mathbb{P}[y_{c}])=\frac{1}{2}\left[\begin{array}[]{cc}1+\rho_{c}&0\\ 0&1-\rho_{c}\end{array}\right] is a diagonal matrix, and ρc:=2[yc=0]1\rho_{c}:=2\mathbb{P}[y_{c}=0]-1. Note that

𝟏D(c)𝟏=𝒆D(c)𝒆=1,𝟏D(c)𝒆=𝒆D(c)𝟏=ρc{\bm{1}}^{\top}D(c){\bm{1}}={\bm{e}}^{\top}D(c){\bm{e}}=1,\quad\quad{\bm{1}}^{\top}D(c){\bm{e}}={\bm{e}}^{\top}D(c){\bm{1}}=\rho_{c} (94)

And M(ρ)=12(𝟏𝟏T+ρ𝒆𝒆)M(\rho)=\frac{1}{2}({\bm{1}}{\bm{1}}^{T}+\rho{\bm{e}}{\bm{e}}^{\top}), therefore we have:

Pβ1β2\displaystyle P_{\beta_{1}\beta_{2}} =\displaystyle= M(ρβ1|c)D(c)M(ρβ2|c)\displaystyle M(\rho_{\beta_{1}|c})D(c)M^{\top}(\rho_{\beta_{2}|c}) (95)
=\displaystyle= 14(𝟏𝟏T+ρβ1|c𝒆𝒆)D(c)(𝟏𝟏T+ρβ2|c𝒆𝒆)\displaystyle\frac{1}{4}({\bm{1}}{\bm{1}}^{T}+\rho_{\beta_{1}|c}{\bm{e}}{\bm{e}}^{\top})D(c)({\bm{1}}{\bm{1}}^{T}+\rho_{\beta_{2}|c}{\bm{e}}{\bm{e}}^{\top}) (96)
=\displaystyle= 14(𝟏𝟏T+ρβ1|cρβ2|c𝒆𝒆+ρβ1|cρc𝒆𝟏+ρβ2|cρc𝟏𝒆)\displaystyle\frac{1}{4}\left({\bm{1}}{\bm{1}}^{T}+\rho_{\beta_{1}|c}\rho_{\beta_{2}|c}{\bm{e}}{\bm{e}}^{\top}+\rho_{\beta_{1}|c}\rho_{c}{\bm{e}}{\bm{1}}^{\top}+\rho_{\beta_{2}|c}\rho_{c}{\bm{1}}{\bm{e}}^{\top}\right) (97)

Now we compute ρc\rho_{c}. Note that

[yc]=y0[yc|y0][y0]\mathbb{P}[y_{c}]=\sum_{y_{0}}\mathbb{P}[y_{c}|y_{0}]\mathbb{P}[y_{0}] (98)

Let 𝒑c:=[[yc]]{\bm{p}}_{c}:=[\mathbb{P}[y_{c}]] be a 2-dimensional vector. Then we have {\bm{p}}_{c}=P_{y_{c}|y_{0}}{\bm{p}}_{0}={\bm{p}}(\rho_{c|1}\rho_{0}) (with ρ0\rho_{0} defined below), where 𝒑0{\bm{p}}_{0} is the probability distribution of the class label y0y_{0}, which can be categorical of size CC:

𝒑c\displaystyle{\bm{p}}_{c} =\displaystyle= Pyc|y0𝒑0=y1Pyc|y1Py1|y0𝒑0\displaystyle P_{y_{c}|y_{0}}{\bm{p}}_{0}=\sum_{y_{1}}P_{y_{c}|y_{1}}P_{y_{1}|y_{0}}{\bm{p}}_{0} (99)
=\displaystyle= M(ρc|1)12[1+p1|01+p2|01+pC|01p1|01p2|01pC|0]𝒑0\displaystyle M(\rho_{c|1})\frac{1}{2}\left[\begin{array}[]{cccc}1+p_{1|0}&1+p_{2|0}&\ldots&1+p_{C|0}\\ 1-p_{1|0}&1-p_{2|0}&\ldots&1-p_{C|0}\end{array}\right]{\bm{p}}_{0} (102)
=\displaystyle= M(ρc|1)12[1+𝒑|0𝒑01𝒑|0𝒑0]\displaystyle M(\rho_{c|1})\frac{1}{2}\left[\begin{array}[]{c}1+{\bm{p}}_{\cdot|0}^{\top}{\bm{p}}_{0}\\ 1-{\bm{p}}_{\cdot|0}^{\top}{\bm{p}}_{0}\end{array}\right] (105)
=\displaystyle= {\bm{p}}(\rho_{c|1}{\bm{p}}_{\cdot|0}^{\top}{\bm{p}}_{0}) (106)

in which y1y_{1} is the last binary variable right below the root node class label y0y_{0}.

Therefore, ρc=ρc|1ρ0\rho_{c}=\rho_{c|1}\rho_{0}, where ρ0:=𝒑|0𝒑0\rho_{0}:={\bm{p}}_{\cdot|0}^{\top}{\bm{p}}_{0} is the uncertainty parameter of the root node y0y_{0}.

If ρβ|α=ρ\rho_{\beta|\alpha}=\rho for every immediate parent yαy_{\alpha} and child yβy_{\beta}, and yβ1y_{\beta_{1}} corresponds to token ll and yβ2y_{\beta_{2}} to token mm, then ρβ1|c=ρβ2|c=ρH\rho_{\beta_{1}|c}=\rho_{\beta_{2}|c}=\rho^{H} and ρc|1=ρL1H\rho_{c|1}=\rho^{L-1-H}, and thus we have:

[yl=1|ym=1]\displaystyle\mathbb{P}[y_{l}=1|y_{m}=1] =\displaystyle= [yl=1,ym=1][ym=1]=12(1+ρ2H2ρHρc1ρHρc)\displaystyle\frac{\mathbb{P}[y_{l}=1,y_{m}=1]}{\mathbb{P}[y_{m}=1]}=\frac{1}{2}\left(\frac{1+\rho^{2H}-2\rho^{H}\rho_{c}}{1-\rho^{H}\rho_{c}}\right) (107)
=\displaystyle= 12(1+ρ2H2ρL1ρ01ρL1ρ0)\displaystyle\frac{1}{2}\left(\frac{1+\rho^{2H}-2\rho^{L-1}\rho_{0}}{1-\rho^{L-1}\rho_{0}}\right) (108)

and the conclusion follows. ∎
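Eqn. 107 can be verified by exact enumeration on a small tree. The sketch below is ours; the depth HH, the parameters \rho and \rho_{c}, and the use of binary chains are arbitrary choices for the check:

```python
# Exact enumeration check (ours; H, rho, rho_c are arbitrary) of Eqn. 107: two
# leaves y_l, y_m hang H levels below a common ancestor y_c, every edge has
# conditional matrix M(rho), and P[y_c = 0] = (1 + rho_c) / 2.
import itertools
import numpy as np

def M(rho):
    return 0.5 * np.array([[1 + rho, 1 - rho], [1 - rho, 1 + rho]])

rho, rho_c, H = 0.7, 0.4, 3
p_c = np.array([(1 + rho_c) / 2, (1 - rho_c) / 2])
Mr = M(rho)

def chain_prob(start, chain):
    """Probability of a chain of latents below `start`, multiplying edge factors."""
    p, prev = 1.0, start
    for node in chain:
        p *= Mr[node, prev]
        prev = node
    return p, prev

joint = np.zeros((2, 2))                       # brute-force joint P[y_l, y_m]
for yc in (0, 1):
    for chain_l in itertools.product((0, 1), repeat=H):
        for chain_m in itertools.product((0, 1), repeat=H):
            pl, yl = chain_prob(yc, chain_l)
            pm, ym = chain_prob(yc, chain_m)
            joint[yl, ym] += p_c[yc] * pl * pm

cond = joint[1, 1] / joint[:, 1].sum()         # P[y_l = 1 | y_m = 1]
pred = 0.5 * (1 + rho**(2 * H) - 2 * rho**H * rho_c) / (1 - rho**H * rho_c)
print(cond, pred)                              # the two values agree
```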

Appendix B More Experiment Results

B.1 Orthogonality of embedding vectors

We verify the orthogonality assumption mentioned in our problem setting (Sec. 2). The orthogonality is measured by the absolute cosine similarity cossim(𝒙1,𝒙2)[0,1]\mathrm{cossim}({\bm{x}}_{1},{\bm{x}}_{2})\in[0,1] of two vectors 𝒙1{\bm{x}}_{1} and 𝒙2{\bm{x}}_{2}:

cossim(𝒙1,𝒙2):=|𝒙1𝒙2|𝒙1𝒙2\mathrm{cossim}({\bm{x}}_{1},{\bm{x}}_{2}):=\frac{|{\bm{x}}^{\top}_{1}{\bm{x}}_{2}|}{\|{\bm{x}}_{1}\|\|{\bm{x}}_{2}\|} (109)

Here the two vectors 𝒙1{\bm{x}}_{1} and 𝒙2{\bm{x}}_{2} are column vectors of the out-projection (or upper) matrix of MLPs at different layers, each corresponding to one hidden neuron. For an MLP layer with model dimension dd and hidden dimension 4d4d, there are 4d4d such column vectors. We measure the average cosine similarity across all 2d(4d1)2d(4d-1) pairs and report it in the figures; a sketch of this measurement is given below.
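The sketch below is our own illustration of the measurement; a random matrix stands in for the actual out-projection weight of shape (d,4d)(d,4d) that would be loaded from a checkpoint:

```python
# Sketch (ours) of the measurement: average absolute cosine similarity between
# columns of an MLP out-projection matrix W of shape (d, 4d). A random matrix
# replaces an actual checkpoint here.
import numpy as np

def mean_abs_cossim(W: np.ndarray) -> float:
    """Average |cos| over all distinct pairs of columns of W."""
    Wn = W / np.linalg.norm(W, axis=0, keepdims=True)   # normalize each column
    C = np.abs(Wn.T @ Wn)                               # pairwise |cosine| matrix
    n = C.shape[0]
    return float((C.sum() - np.trace(C)) / (n * (n - 1)))

d = 256
W = np.random.default_rng(0).standard_normal((d, 4 * d))
print(mean_abs_cossim(W))   # roughly sqrt(2 / (pi * d)) ~ 0.05 for random Gaussian columns
```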

While any 4d4d dd-dimensional vectors must be linearly dependent, they are indeed almost orthogonal (i.e., cossim(𝒙1,𝒙2)1\mathrm{cossim}({\bm{x}}_{1},{\bm{x}}_{2})\ll 1) throughout the training process, as shown below. In Fig. 11, we show the cosine similarity over the entire training process of Pythia models of different sizes. Fig. 12 further examines the early training stages, since Pythia checkpoints are more densely sampled there, i.e., “steps 0 (initialization), 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1000, and then every 1,000 subsequent steps” (Biderman et al., 2023). Finally, for models whose intermediate checkpoints are not available, we show the cosine similarity of the publicly released pre-trained models (Fig. 13).

Figure 11: Orthogonality of embeddings of MLP in LLMs during the whole training process.
Figure 12: Orthogonality of embeddings of MLP in LLMs during the early training stage.
Figure 13: Orthogonality measures in model architectures (BERT-Base, OPT-6.7B, LLaMA-2-7B, ViT-Huge), with only final checkpoint available.

B.2 Attention Entropy for Encoder-only models

We also measure how the attention entropy, as well as the stable rank of the in-projection (or lower) matrix of the MLP, changes over time for encoder-only models like BERT, as shown in Fig. 14. The behavior is very similar to the decoder-only case (Fig. 7), further verifying our theoretical findings. A sketch of both measurements is given after Fig. 14.

Figure 14: (Left) Attention entropy of BERT; (Right) Stable rank in BERT.
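Both diagnostics can be sketched as follows. This is our own illustration and makes two assumptions: the attention entropy is computed as the per-query entropy of the attention distribution, averaged over queries and heads, and the stable rank is \|W\|_{F}^{2}/\sigma_{\max}^{2}(W); random arrays stand in for a real forward pass and checkpoint:

```python
# Sketch (ours) of the two diagnostics; random arrays replace a real forward
# pass and checkpoint. Entropy is per-query and averaged; stable rank is
# ||W||_F^2 / sigma_max(W)^2.
import numpy as np

def attention_entropy(attn: np.ndarray, eps: float = 1e-12) -> float:
    """attn: (..., T_query, T_key), rows summing to 1; returns mean entropy (nats)."""
    h = -(attn * np.log(attn + eps)).sum(axis=-1)
    return float(h.mean())

def stable_rank(W: np.ndarray) -> float:
    s = np.linalg.svd(W, compute_uv=False)        # singular values, descending
    return float((s ** 2).sum() / s[0] ** 2)

rng = np.random.default_rng(0)
logits = rng.standard_normal((8, 128, 128))                         # (heads, T, T)
attn = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)  # softmax rows
W_in = rng.standard_normal((1024, 256))                             # stand-in in-projection
print(attention_entropy(attn), stable_rank(W_in))
```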