
Learning Linear Attention in Polynomial Time

Morris Yau MIT CSAIL Ekin Akyürek MIT CSAIL Jiayuan Mao MIT CSAIL
Joshua B. Tenenbaum
MIT CSAIL Department of Brain and Cognitive Sciences, MIT Center for Brains, Minds, and Machines, MIT
Stefanie Jegelka School of CIT; MCML and MDSI, TU Munich Jacob Andreas MIT CSAIL
Abstract

Previous research has explored the computational expressivity of Transformer models in simulating Boolean circuits or Turing machines. However, the learnability of these simulators from observational data has remained an open question. Our study addresses this gap by providing the first polynomial-time learnability results (specifically strong, agnostic PAC learning) for single-layer Transformers with linear attention. We show that linear attention may be viewed as a linear predictor in a suitably defined RKHS. As a consequence, the problem of learning any linear transformer may be converted into the problem of learning an ordinary linear predictor in an expanded feature space, and any such predictor may be converted back into a multiheaded linear transformer. Moving to generalization, we show how to efficiently identify training datasets for which every empirical risk minimizer is equivalent (up to trivial symmetries) to the linear Transformer that generated the data, thereby guaranteeing the learned model will correctly generalize across all inputs. Finally, we provide examples of computations expressible via linear attention and therefore polynomial-time learnable, including associative memories, finite automata, and a class of universal Turing machines (UTMs) with polynomially bounded computation histories. We empirically validate our theoretical findings on three tasks: learning random linear attention networks, key–value associations, and learning to execute finite automata. Our findings bridge a critical gap between theoretical expressivity and learnability of Transformers, and show that flexible and general models of computation are efficiently learnable.

1 Introduction

Transformers have become a ubiquitous tool in a range of learning problems, due to their versatility in generating structured sequences (including text) conditioned on complex inputs (including natural language instructions). A large body of work has attempted to explain the behavior of trained Transformers and characterize their expressive power. Here one central problem—particularly useful for understanding Transformers’ ability to follow instructions in language—is understanding how reliably they can simulate execution of automata like universal Turing machines, which in turn can execute arbitrary programs (i.e., the input “instructions”). While several recent papers have shown that Transformers are expressive enough to implement important models of computation, it remains an open question whether these machines may be effectively learned. Even verifying that a trained model has successfully learned a generalizable computation procedure has remained challenging.

Concretely, existing work has shown positive results on how Transformer-like architectures can realize many algorithmic computations, including simulating universal Turing machines (Li et al., 2024), evaluating sentences of first-order logic (Barceló et al., 2020), and recognizing various formal languages (Strobl et al., 2024). However, these proofs rely on manual constructions of network weights. For questions about learnability, the most comprehensive result to date is due to Luo et al. (2023), who provide a proof under very strong assumptions about the data distribution (i.e., under complete presentation, by enumerating all Turing machines up to a constant size). While the result is positive in the sense that training a Transformer model on a finite dataset enables generalization to input strings of arbitrary lengths, the bound on the amount of data required for training grows exponentially in the size of the model to be trained.

In this paper, we establish the strong, agnostic PAC-learnability of a family of simplified Transformer models, namely linear attention modules (without a softmax function; Ahn et al., 2024, Katharopoulos et al., 2020). We focus our analysis on single-layer linear Transformers (i.e., multi-head linear attention networks, or MHLAs) for regression tasks. An MHLA is parameterized by two matrices (V_{h},Q_{h}) for each of H heads, so that \Theta=\{(V_{h},Q_{h})\}_{h\in[H]}. A one-layer MHLA computes Y=\sum_{h\in[H]}V_{h}Z(Z^{T}Q_{h}Z).

Despite its simplicity, this model class retains significant expressive power: we show that it can realize a restricted but expressive class of universal Turing machines with bounded computation history size. Our results imply that Transformer models can efficiently (in polynomial time) learn these machines with polynomial sample complexity. Moreover, we show that under checkable conditions, minimizing the empirical risk is guaranteed to recover a model that correctly simulates arbitrary new input machines.

We first show that the computation performed by MHLAs can be reformulated as an inner product \langle W,\mathcal{X}(Z)\rangle between two larger matrices, where W=\sum_{h\in[H]}\text{flatten}(V_{h})\text{flatten}(Q_{h})^{T} and \mathcal{X}(Z) is a fixed cubic polynomial function of Z. Consequently, optimizing over the class of H-head MHLA models is equivalent to optimizing over the class of rank-H matrices W. Furthermore, in the full-rank space of d^{2}\times d^{2} matrices, optimization of W can be performed via linear regression in time polynomial in the inverse target error and the size of the dataset. Finally, decomposing an optimal W via SVD recovers an MHLA model with no more than d^{2} heads that is guaranteed to compete against the best MHLA parameters—establishing our agnostic learning result (the learned model competes against the best choice of parameters in the hypothesis class).

Next we define a certificate of identifiability for MHLAs—an efficiently checkable condition on any dataset D that, if true, ensures every empirical risk minimizer in the class of MHLAs computes the same function on all possible inputs. More specifically, let \Lambda_{D} be the second moment of the flattening of the data in the \mathcal{X}-feature space: \Lambda_{D}=\mathbb{E}_{D}\left[\mathcal{X}(Z)\mathcal{X}(Z)^{T}\right]. We show that, if \Lambda_{D} is full rank, MHLA is guaranteed to be identifiable with respect to the data. We call this second moment condition “certifiable identifiability”. When combined with existing results on realizability, this leads to a polynomial-time algorithm for learning computations expressible as linear attention, with checkable certificates of identifiability. Given a dataset satisfying certifiable identifiability, if a computational procedure can fit the dataset and can be realized by an MHLA, then all empirical risk minimizers will be equivalent to that procedure. This applies to any function realizable by single-layer linear Transformers.

In the experimental section, we validate our theoretical findings beyond their initial scope. In Section 6.1, we train multiple models using stochastic gradient descent on a dataset generated by a single linear attention network’s output. Our results demonstrate that multi-head linear attention outperforms both single-layer linear attention and multi-layer linear attention, achieving comparable results to our polynomial-time algorithm. In Section 6.2, we show that our proposed certificate correlates with generalization error, even for models trained using stochastic gradient descent. In our final set of experiments (Section 6.3), we examine the data requirements for learning a specific universal Turing machine (UTM) capable of executing deterministic finite automata, in which we observe sub-exponential data requirements.

In summary, we show:

  • We provide a polynomial time algorithm that, given any dataset, finds the best-fit parameters for single-layer linear transformers (MHLAs) and generalizes with polynomially many samples, i.e., strong agnostic PAC learning (Sections 2.1 and 3).

  • We find an efficiently checkable condition on the training dataset that certifies that every empirical risk minimizer of an MHLA is functionally equivalent, and therefore has the same behavior out of distribution (Section 4). We call this condition “certifiable identifiability” (Lemma 4.1).

  • MHLAs are an expressive model of computation that includes universal Turing machines with polynomially bounded computation histories (Section 2.3). Combined with our identifiability results, we conclude that given a certifiably identifiable dataset of Turing machines and computation histories on input words, empirical risk minimization, and in particular Algorithm 1, will learn the universal Turing machine in a strong sense: at test time it will correctly run any TM and input word up to a given size for a bounded number of steps (Section 5).

Our results shed new light on the learnability of instruction following procedures, and indeed the learnability of a broad class of simple attention-based models.

2 Technical Overview

We start with basic definitions of a multi-head linear attention (MHLA) module, a stackable attention-based neural module simplified from the standard Transformer module by removing the softmax activation. MHLA has been a standard subject of study for expressivity and learning theory.

Definition (Multi-Head Linear Attention).

Let Z\in\mathbb{R}^{d\times n} be a matrix of input data. Let \Theta=\{(V_{h},Q_{h})\}_{h\in[H]} be a set of parameters where V_{h},Q_{h}\in\mathbb{R}^{d\times d} denote the value and key-query matrices for each head h\in[H]. We say \Theta\in\Omega_{H}, where \Omega_{H} is the space of sets of H ordered tuples of d\times d matrices. We define multi-head linear attention (MHLA) to be the function \text{MHLA}_{\Theta}:\mathbb{R}^{d\times n}\rightarrow\mathbb{R}^{d\times n},

\hat{Y}=\text{MHLA}_{\Theta}(Z)=\sum\nolimits_{h\in[H]}V_{h}Z(Z^{T}Q_{h}Z), (1)

where \hat{Y}\in\mathbb{R}^{d\times n} is the output of the one-layer linear attention. We will primarily be interested in the rightmost column vector output by \text{MHLA}_{\Theta} (e.g., as in auto-regressive language models), which is:

\hat{y}=\text{MHLA}_{\Theta}(Z)=\sum\nolimits_{h\in[H]}V_{h}Z(Z^{T}Q_{h}Z[:,n]), (2)

where Z[:,n] is the n-th column of Z.

Note that \text{MHLA}_{\Theta} is a uniform circuit family: it can take inputs of any length n with a fixed embedding dimension d. It is uniform in the sense that there is a polynomial-time algorithm that maps the parameters \Theta to the circuit that processes inputs of length n. We will occasionally abuse terminology and refer to MHLA as a function rather than as a family of functions.
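For concreteness, the following is a minimal numpy sketch of Eqs. (1) and (2); the shapes and parameter layout follow the definition above, while the function names and the small random example are ours.

```python
# Minimal numpy sketch of multi-head linear attention, Eqs. (1)-(2).
# Z is d x n; each head has value matrix V and key-query matrix Q, both d x d.
import numpy as np

def mhla_full(Z, params):
    """Eq. (1): Y_hat = sum_h V_h Z (Z^T Q_h Z), a d x n matrix."""
    d, n = Z.shape
    Y = np.zeros((d, n))
    for V, Q in params:
        Y += V @ Z @ (Z.T @ Q @ Z)
    return Y

def mhla_last_column(Z, params):
    """Eq. (2): only the rightmost output column, as in autoregressive decoding."""
    y = np.zeros(Z.shape[0])
    for V, Q in params:
        y += V @ Z @ (Z.T @ Q @ Z[:, -1])
    return y

# Example usage with random parameters (H = 3 heads, d = 4, n = 6).
rng = np.random.default_rng(0)
d, n, H = 4, 6, 3
params = [(rng.normal(size=(d, d)), rng.normal(size=(d, d))) for _ in range(H)]
Z = rng.normal(size=(d, n))
assert np.allclose(mhla_full(Z, params)[:, -1], mhla_last_column(Z, params))
```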

2.1 Polynomial-time learnability

Our main result is that MHLA is learnable in polynomial time. Colloquially, Algorithm 1 returns an MHLA that attains the global minimum of the training loss, and requires as few as \text{poly}(d,\epsilon^{-1},\log(\delta^{-1})) samples to achieve \epsilon generalization error with probability 1-\delta. Here our algorithmic guarantees do not require the data to be “realizable” (that is, there need be no underlying MHLA that generates the data).

Theorem 2.1 (Learnability of Linear Attention).

Let D=\{(Z_{i},y_{i})\}_{i\in[N]} be a dataset drawn i.i.d. from a distribution \mathcal{D}, where each Z_{i}\in\mathbb{R}^{d\times n_{i}} and y_{i}\in\mathbb{R}^{d}. Here the embedding dimension d is fixed across the dataset, whereas n_{i} can differ across datapoints. Let n_{\max} be the maximum sequence length n_{i} over i\in[N], and let \Omega_{H} be the space of H pairs of value and key-query matrices \{(V_{h},Q_{h})\}_{h\in[H]} for any H\in[1,\infty). Then there is an algorithm (Algorithm 1) that runs in time O(Nd^{4}n_{\max}\epsilon^{-1}) and that, given input–output pairs \{(Z_{i},y_{i})\}_{i\in[N]}, returns \hat{\Theta}=\{(\hat{V}_{h},\hat{Q}_{h})\}_{h\in[\hat{H}]}\in\Omega_{\hat{H}} with \hat{H}\leq d^{2} such that with probability 1-\delta,

\mathbb{E}_{(Z,y)\in\mathcal{D}}\left[\|\text{MHLA}_{\hat{\Theta}}(Z)-y\|^{2}\right]-\min\nolimits_{\Theta\in\Omega_{H}}\mathbb{E}_{(Z,y)\in\mathcal{D}}\left[\|\text{MHLA}_{\Theta}(Z)-y\|^{2}\right]\leq\epsilon (3)

with sample complexity N=O\left(\frac{1}{\epsilon}\left(d^{4}+\log(\delta^{-1})\right)\right).

Below we describe the high-level intuition behind this proof; a formal statement is given in Appendix A. Note additionally that, if we are purely concerned with guaranteeing that we can find a global minimum of the training loss, we may remove the i.i.d. assumption: Algorithm 1 is always within error \epsilon of the minimum training loss. This is also detailed in Appendix A. Specific issues related to generalization over autoregressive sequences rather than i.i.d. data are handled in the UTM learning result: see Section C.2.

The main idea behind Algorithm 1 is to construct a feature mapping \mathcal{X}:\mathbb{R}^{d\times n}\rightarrow\mathbb{R}^{d\times d^{2}} from the data covariates Z to a feature space of dimension d\times d^{2}. The map \mathcal{X}(Z) is defined as:

\mathcal{X}(Z)\vcentcolon=\begin{bmatrix}\langle z_{1:},z_{1:}\rangle z_{1n}&\langle z_{1:},z_{2:}\rangle z_{1n}&\cdots&\langle z_{1:},z_{d:}\rangle z_{2n}&\cdots&\langle z_{1:},z_{d:}\rangle z_{dn}\\ \langle z_{2:},z_{1:}\rangle z_{1n}&\langle z_{2:},z_{2:}\rangle z_{1n}&\cdots&\langle z_{2:},z_{d:}\rangle z_{2n}&\cdots&\langle z_{2:},z_{d:}\rangle z_{dn}\\ \vdots&\vdots&\ddots&\vdots&\ddots&\vdots\\ \langle z_{d:},z_{1:}\rangle z_{1n}&\langle z_{d:},z_{2:}\rangle z_{1n}&\cdots&\langle z_{d:},z_{d:}\rangle z_{2n}&\cdots&\langle z_{d:},z_{d:}\rangle z_{dn}\end{bmatrix}. (4)

Here, we index the rows of \mathcal{X}(Z) by j\in[d] and the columns by tuples (k,\ell)\in[d]^{2}, so that \mathcal{X}(Z)_{j,(k,\ell)}=\langle z_{j:},z_{k:}\rangle z_{\ell n}. At a high level, Algorithm 1 is a kernel method defined by the feature mapping \mathcal{X}. The learned kernel predictor (a regressor) can be mapped back onto a set of parameters \{\hat{V}_{h},\hat{Q}_{h}\}_{h\in[\hat{H}]} for an MHLA with no more than d^{2} heads via SVD.
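The following numpy sketch checks this reformulation numerically: the rightmost MHLA output equals a linear function of the cubic features \mathcal{X}(Z), with coefficient matrix W=\sum_{h}\text{flatten}(V_{h})\text{flatten}(Q_{h})^{T}. The function name and the row-major flattening convention are ours.

```python
# Sketch verifying the linear-predictor view of MHLA in the feature space X(Z).
import numpy as np

def feature_map(Z):
    """X(Z)[j, (k, l)] = <z_j:, z_k:> * Z[l, -1]; column (k, l) is flattened as k*d + l."""
    d = Z.shape[0]
    G = Z @ Z.T                      # Gram matrix of the rows, <z_j:, z_k:>
    return np.einsum('jk,l->jkl', G, Z[:, -1]).reshape(d, d * d)

rng = np.random.default_rng(1)
d, n, H = 3, 5, 2
params = [(rng.normal(size=(d, d)), rng.normal(size=(d, d))) for _ in range(H)]
Z = rng.normal(size=(d, n))

# Direct MHLA output (rightmost column), Eq. (2).
y = sum(V @ Z @ (Z.T @ Q @ Z[:, -1]) for V, Q in params)

# Linear-predictor form: W is d^2 x d^2; coordinate a pairs the a-th row block of W with X(Z).
W = sum(np.outer(V.reshape(-1), Q.reshape(-1)) for V, Q in params)
X = feature_map(Z)
y_linear = np.array([np.sum(W[a * d:(a + 1) * d, :] * X) for a in range(d)])
assert np.allclose(y, y_linear)
```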

2.2 Identifiability

Our algorithmic result also sheds light on a closely related question about generalization in MHLAs: Is there an efficiently checkable condition on any given dataset, such that empirical risk minimization is guaranteed to learn the true data-generating process and generalize even to out-of-distribution examples? Without any qualifiers, “out of distribution” generalization is ill-defined. Nevertheless, if a model class is identifiable—if the empirical risk minimizers of the class all compute the same function on all inputs—we can at least know that the empirical risk minimization algorithm produces models with the same behavior (in- and out-of-distribution) on held-out data. Furthermore, if we have a specific computational procedure (e.g., a Universal Turing Machine) that can fit the data and can be realized by some setting of model parameters, then, assuming the dataset is identifiable, all empirical risk minimizers will be equivalent to that procedure. (For a more formal definition of realizability, see the Realizability definition in Section 4.)

We show that, as a direct implication of our algorithmic result, it is possible to produce such an efficiently checkable condition (a certificate) on the data that guarantees every empirical risk minimizer in a family of MHLAs computes the same function. Let \Lambda_{D} be the second moment of the flattening of the data, denoted \mathcal{H}(Z), in feature space:

\Lambda_{D}=\mathbb{E}[\mathcal{H}(Z)\,\mathcal{H}(Z)^{T}]=\frac{1}{N}\sum_{Z\in D}\mathcal{H}(Z)\,\mathcal{H}(Z)^{T}. (5)

Then if \Lambda_{D} is full rank, it is guaranteed that MHLA is identifiable with respect to the data.

Lemma 2.1 (Certificate of Identifiability—Informal).

Let dataset D=\{(Z_{i},y_{i})\}_{i\in[N]} be realizable (see the Realizability definition in Section 4) by an H-head MHLA for any H\geq 1. Let \mathcal{H} be the uniform family of polynomials \mathcal{H}_{n}:\mathbb{R}^{d\times n}\rightarrow\mathbb{R}^{\psi} for \psi\vcentcolon={d\choose 2}d+d^{2} defined as in Algorithm 2, and for convenience write \mathcal{H}(Z)=\mathcal{H}_{n}(Z) for Z\in\mathbb{R}^{d\times n} (there is one \mathcal{H}_{n}\in\mathcal{H} for each input length n). Finally, define \Lambda_{D}\in\mathbb{R}^{\psi\times\psi} to be the second moment of the data features:

\Lambda_{D}\vcentcolon=\mathbb{E}_{D}\left[\mathcal{H}(Z)\mathcal{H}(Z)^{T}\right]. (6)

Then if the minimum eigenvalue \lambda_{\min}\left(\Lambda_{D}\right)>0, we say that \text{MHLA}_{\Theta} is certifiably identifiable with respect to D. That is, for every pair of empirical risk minimizers \Theta,\Theta^{\prime}\in\Omega_{H},

\text{MHLA}_{\Theta}=\text{MHLA}_{\Theta^{\prime}}, (7)

i.e., the two models have the same outputs on all inputs.

Corollary 1.

There is a polynomial p:\Omega\rightarrow\mathbb{R}^{\psi} such that for any pair of parameters \Theta,\Theta^{\prime}\in\Omega_{H} we have \text{MHLA}_{\Theta}=\text{MHLA}_{\Theta^{\prime}} if and only if p(\Theta)=p(\Theta^{\prime}).

Clearly MHLAs with very different parameters can compute the same function, due to simple symmetries like reciprocal scaling of the V and Q matrices. The polynomial p defines the equivalence class of parameters that compute the same function. For a formal statement of Lemma 2.1 see Lemma 4.1. For the handling of errors under approximate empirical risk minimization see Lemma 4.4. Moreover, the certificate given by Algorithm 2 is not the only choice of feature mapping \mathcal{H} that would certify identifiability; for a general lemma on certifiable identifiability see Lemma B.1. One way to interpret Corollary 1 is that two MHLA models parameterized by \Theta and \Theta^{\prime} compute the same function if and only if they are the same linear function in feature space (akin to matching coefficients in polynomial regression), which in turn holds if and only if p(\Theta)=p(\Theta^{\prime}) for the polynomial p given in Lemma 4.1. Comparing the distance between coefficients in p-space is essentially the only meaningful metric of distance that is agnostic to the choice of dataset.

Finally, we answer a few natural questions related to identifiability, which we briefly summarize here. Firstly, perfectly noisy inputs yield identifiable data under weak assumptions on the moments of the noise (see Lemma 4.2). Secondly, the model class of MHLAs with at least d^{2} heads is certifiably identifiable from the second moment condition alone, and does not require realizability of the data (see Lemma 4.3). Finally, we empirically verify that these certificates predict the training behavior of SGD for MHLA on the problem of learning key–value memories (see Figure 2).

2.3 Realizability of Universal Automata in MHLA

We also include an application of our theory of learnability and identifiability to the problem of learning a universal Turing machine (UTM) with polynomially bounded computation length. We prove such a UTM is expressible via MHLA in Lemma 2.2, and show that for certifiably identifiable data the learned MHLA generalizes to any TM M and input word x in Lemma 5.2.

Lemma 2.2 (UTM Expressibility).

Let \Delta(\hat{\mathcal{Q}},\hat{\Sigma},\hat{n},\hat{\Phi}) be the set of Turing machines M=\{\delta,\Sigma,\mathcal{Q},q_{start},q_{accept},q_{reject}\} and words x\in\Sigma^{*} with number of states, size of alphabet, size of input, and number of steps in the computation history bounded by \hat{\mathcal{Q}},\hat{\Sigma},\hat{n},\hat{\Phi} respectively. For any (M,x)\in\Delta, let \{x_{t}\}_{t\in[\Phi]} be the computation history of the UTM on (M,x). Let the autoregressive computation history (see the Autoregressive Computation History definition in Section 5) of \text{MHLA}_{\Theta} on input (M,x) be denoted \text{CH}_{\Theta}(M,x)=\{Z^{1},Z^{2},...,Z^{\Phi}\}. Then there exists a set of parameters \Theta\in\Omega_{H} with H=O(\hat{n}\hat{\Phi}\hat{\Sigma}) and embedding dimension d=O(\hat{n}\hat{\Phi}\hat{\Sigma}\max(\hat{\Sigma},\hat{\mathcal{Q}})) such that for all (M,x)\in\Delta, the TM computation history at time step t is equivalent to the autoregressive computation history at time step c(t), where c(t)\leq O((n+t)t), i.e., Z^{c(t)}[:-\text{length}(x^{t})]=x^{t}. Furthermore, this can be achieved with 2 bits of precision.

Our construction bears similarities to (Pérez et al., 2019; Hahn, 2020; Merrill & Sabharwal, 2023; Merrill et al., 2022; 2021; Liu et al., 2022; Feng et al., 2023); the high-level idea is to write down every letter in the computation history of M on x. If we use orthogonal vectors to encode every letter, state, and positional embedding, we arrive at a natural construction involving a few basic primitives: copy, lookup, and if-then-else. For details see the discussion in Section C and Proof C.1.

We now proceed to a more detailed discussion of the main technical result, Theorem 2.1.

3 Polynomial Algorithm for Learning MHLA

We first show that, given any dataset \mathcal{D}, there exists a learning algorithm that can recover the optimal parameter \Theta of an MHLA with a fixed latent dimension d, in the sense of empirical risk minimization.

Algorithm 1 MHLA Learning via Regression
1:  Input: Data D\vcentcolon=\{(Z_{i},y_{i})\}_{i\in[N]} for Z_{i}\in\mathbb{R}^{d\times n_{i}} and y_{i}\in\mathbb{R}^{d}
2:  \{\mathcal{X}_{i}\}_{i\in[N]}\vcentcolon=\text{ExtractFeature}(D) (see Algorithm 8)
3:
\mathcal{X}_{i}\vcentcolon=\begin{bmatrix}\langle z_{1:},z_{1:}\rangle z_{1n_{i}}&\langle z_{1:},z_{2:}\rangle z_{1n_{i}}&\cdots&\langle z_{1:},z_{d:}\rangle z_{1n_{i}}&\cdots&\langle z_{1:},z_{d:}\rangle z_{dn_{i}}\\ \langle z_{2:},z_{1:}\rangle z_{1n_{i}}&\langle z_{2:},z_{2:}\rangle z_{1n_{i}}&\cdots&\langle z_{2:},z_{d:}\rangle z_{1n_{i}}&\cdots&\langle z_{2:},z_{d:}\rangle z_{dn_{i}}\\ \vdots&\vdots&\ddots&\vdots&\ddots&\vdots\\ \langle z_{d:},z_{1:}\rangle z_{1n_{i}}&\langle z_{d:},z_{2:}\rangle z_{1n_{i}}&\cdots&\langle z_{d:},z_{d:}\rangle z_{1n_{i}}&\cdots&\langle z_{d:},z_{d:}\rangle z_{dn_{i}}\end{bmatrix}. (8)
4:  Create the dataset \{X_{i,a}\}_{i\in[N],a\in[d]}. Let X_{i,a}\in\mathbb{R}^{d^{2}\times d^{2}} be the matrix comprised of \mathcal{X}_{i} in the a-th block of d rows and 0 everywhere else:
5:
X_{i,a}=\begin{bmatrix}0&\ldots&\mathcal{X}_{i}^{T}&\ldots&0\end{bmatrix}^{T} (9)
6:  Let W\in\mathbb{R}^{d^{2}\times d^{2}} be the regressor:
7:
\hat{W}\vcentcolon=\operatorname*{arg\,min}_{W\in\mathbb{R}^{d^{2}\times d^{2}}}\sum_{i\in[N]}\sum_{a\in[d]}\left(\langle W,X_{i,a}\rangle-y_{i,a}\right)^{2} (10)
8:  Take the SVD of \hat{W}=AB^{T}=\sum_{h\in[\hat{H}]}A_{h}B_{h}^{T}, where \hat{H} is the rank of \hat{W}.
9:  Set V_{h}=\text{Fold}(A_{h}) and Q_{h}=\text{Fold}(B_{h}), where \text{Fold}:\mathbb{R}^{d^{2}}\rightarrow\mathbb{R}^{d\times d} takes a vector p\vcentcolon=[p_{ij}\text{ for }i\in[d]\text{ and }j\in[d]] and reshapes it into a matrix P\in\mathbb{R}^{d\times d} such that P_{ij}=p_{ij}.
10:  Return: \{V_{h},Q_{h}\}_{h\in[\hat{H}]}
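Below is a compact numpy sketch of Algorithm 1 under the paper's setup: build the cubic features of Eq. (8), solve the least-squares relaxation of Eq. (10) for an unconstrained W, then apply SVD and Fold to recover at most d^{2} heads. Function names are ours, and np.linalg.lstsq stands in for the regression step.

```python
# Sketch of Algorithm 1: MHLA learning via regression plus an SVD fold-back.
import numpy as np

def feature_map(Z):
    """X(Z)[j, (k, l)] = <z_j:, z_k:> * Z[l, -1], as in Eq. (8)."""
    d = Z.shape[0]
    return np.einsum('jk,l->jkl', Z @ Z.T, Z[:, -1]).reshape(d, d * d)

def learn_mhla(data):
    """data: list of (Z_i, y_i) with Z_i of shape (d, n_i) and y_i of shape (d,)."""
    d = data[0][0].shape[0]
    rows, targets = [], []
    for Z, y in data:
        X = feature_map(Z)
        for a in range(d):                        # one regression row per output coordinate
            Xa = np.zeros((d * d, d * d))
            Xa[a * d:(a + 1) * d, :] = X          # X_{i,a}: X placed in the a-th row block
            rows.append(Xa.ravel())
            targets.append(y[a])
    w, *_ = np.linalg.lstsq(np.stack(rows), np.array(targets), rcond=None)
    W = w.reshape(d * d, d * d)                   # unconstrained surrogate for T_Theta
    U, S, Vt = np.linalg.svd(W)                   # W = sum_h (scaled left)(scaled right)^T
    heads = []
    for h in range(np.linalg.matrix_rank(W)):
        Vh = (np.sqrt(S[h]) * U[:, h]).reshape(d, d)    # Fold of the scaled left vector
        Qh = (np.sqrt(S[h]) * Vt[h, :]).reshape(d, d)   # Fold of the scaled right vector
        heads.append((Vh, Qh))
    return heads
```

On realizable synthetic data, the returned heads reproduce the teacher's outputs on the training set up to least-squares error; no claim is made here about the exact experimental implementation.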

We restate Theorem 2.1 from Section 2.1.

Proof Idea:

First we write down the loss, and observe that the output of a one-layer attention network is linear in the input features X_{i,a}, with coefficients given by a quadratic polynomial in \{V_{h},Q_{h}\}_{h\in[H]}:

\mathcal{L}_{\Theta}(\{(Z_{i},y_{i})\}_{i\in[N]})=\frac{1}{N}\sum_{i\in[N]}\sum_{a\in[d]}\left(\left\langle\mathcal{T}_{\Theta},X_{i,a}\right\rangle-y_{i,a}\right)^{2} (11)

Here

\mathcal{T}_{\Theta}\vcentcolon=\sum_{h\in[H]}\text{flatten}(V_{h})\text{flatten}(Q_{h})^{T}=\sum_{h\in[H]}\begin{bmatrix}V_{h,00}Q_{h,00}&V_{h,00}Q_{h,01}&\ldots&V_{h,00}Q_{h,dd}\\ V_{h,01}Q_{h,00}&V_{h,01}Q_{h,01}&\ldots&V_{h,01}Q_{h,dd}\\ \vdots&\vdots&\ddots&\vdots\\ V_{h,dd}Q_{h,00}&V_{h,dd}Q_{h,01}&\ldots&V_{h,dd}Q_{h,dd}\end{bmatrix} (12)

Now we relax this objective by replacing \mathcal{T}_{\Theta} with an unconstrained matrix W\in\mathbb{R}^{d^{2}\times d^{2}}. Where \mathcal{T}_{\Theta} is a rank-H matrix, we allow W to be a general matrix, so this relaxation is guaranteed to achieve a loss no larger than the original. Furthermore, the relaxed loss can be optimized via ordinary least squares. Finally, if we apply SVD to W, we obtain a set of at most d^{2} left and right singular vectors, each scaled by the square root of its singular value. The scaled left singular vectors correspond to \hat{V}_{h} and the scaled right singular vectors correspond to \hat{Q}_{h} for h\in[\hat{H}]. Since the rank of W is no greater than d^{2}, the resulting MHLA satisfies \hat{H}\leq d^{2}. The sample complexity follows from classical results in VC theory (Kearns & Vazirani, 1994). For the full proof see Appendix A.

4 Certificate for identifiability of linear attention

We begin by defining identifiability of a model class with respect to a dataset.

Definition (Identifiability).

Let D=\{(Z_{i},y_{i})\}_{i\in[N]}. Let \mathcal{U}_{\Theta} denote a model class which is a uniform circuit family parameterized by parameters \Theta\in\Omega. Let \mathcal{L} be a loss function and \Omega_{\text{ERM}} be the set of empirical risk minimizers:

\Omega_{\text{ERM}}=\{\hat{\Theta}\in\Omega\mid\hat{\Theta}=\operatorname*{arg\,min}\nolimits_{\Theta\in\Omega}\mathcal{L}(\mathcal{U}_{\Theta},D)\}. (13)

We say the model class \mathcal{U}_{\Theta} is identifiable with respect to the dataset D if for all Z\in\mathbb{R}^{d\times n^{\prime}} and for all pairs of empirical risk minimizers \Theta,\Theta^{\prime}\in\Omega_{\text{ERM}}, the circuits \mathcal{U}_{\Theta} and \mathcal{U}_{\Theta^{\prime}} compute the same function, i.e., they agree on all inputs (are the same uniform circuit family):

\mathcal{U}_{\Theta}(Z)=\mathcal{U}_{\Theta^{\prime}}(Z). (14)

In establishing conditions for identifiability, it will be useful to refer to another condition relating models to datasets.

Definition (Realizability).

Let \Theta\in\Omega_{H} be an MHLA parameterization. We say a dataset D=\{(Z_{i},y_{i})\}_{i\in[N]} is realizable by the parameterization \Theta if y_{i}=\text{MHLA}_{\Theta}(Z_{i}) for all i\in[N].

The definition of realizability can be modified to include independent noise at the expense of adding some terms to our analyses. See Lemma 4.4 for details.

Next, we prove that for the model class MHLA there is an efficiently checkable condition (certificate) on the data D that guarantees the model class is identifiable with respect to D. Our results follow by reinterpreting the results of Section 3 with a focus on data conditions that uniquely determine the optimal regressor. In this section we denote the mapping from data to feature space by \mathcal{H} and the mapping from parameters to feature space by p, which are analogous to the X and \mathcal{T}_{\Theta} of Section 3. The overall structure is the following:

  1. In Lemma 4.1, we prove that if we define \mathcal{H} to be the polynomial defined in Algorithm 2, then checking the minimum eigenvalue condition \lambda_{\min}(\Lambda_{D})>0 provides a certificate that the model class is identifiable with respect to the data.

  2. In Corollary 2, we establish that there exists a polynomial p mapping parameter space to feature space such that \text{MHLA}_{\Theta}=\text{MHLA}_{\Theta^{\prime}} if and only if p(\Theta)=p(\Theta^{\prime}) for any pair of parameters \Theta,\Theta^{\prime}\in\Omega_{H}. In other words, the mapping p determines the equivalence class of parameters that compute the same function. We crucially use this fact in our empirical demonstration, where we measure the functional distance between models with learned parameters vs. ground-truth parameters by comparing their distance in feature space.

We instantiate the feature mapping \mathcal{H} and the parameter mapping polynomial p as follows.

Lemma 4.1 (Certificate of Identifiability).

Let dataset D=\{(Z_{i},y_{i})\}_{i\in[N]} be a realizable dataset. Let \mathcal{H}=\{\mathcal{H}_{n}\}_{n=1}^{\infty} be a family of polynomials \mathcal{H}_{n}:\mathbb{R}^{d\times n}\rightarrow\mathbb{R}^{\psi} for \psi={d\choose 2}d+d^{2} defined as follows. We index the entries of \mathcal{H} by taking the Kronecker product between all sets of pairs \{j,k\} (for all j,k\in[d]) with all \ell\in[d]. We define \mathcal{H}(Z)_{\{j,k\}\ell} as in Algorithm 2 to be

\mathcal{H}(Z)_{\{j,k\}\ell}\vcentcolon=\langle z_{j:},z_{k:}\rangle z_{\ell n_{i}}. (15)

Then if \lambda_{\min}\left(\mathbb{E}_{D}\left[\mathcal{H}(Z)\mathcal{H}(Z)^{T}\right]\right)>0, we have that \text{MHLA}_{\Theta} is identifiable with respect to D.

Next we construct a mapping p:\Omega\rightarrow\mathbb{R}^{d\times\psi} that partitions parameter space into equivalence classes of parameters that compute the same function. This is akin to matching coefficients in polynomial regression. While not necessary to prove identifiability, this mapping defines a meaningful notion of “distance” between different transformer parameters by constructing a feature space in which computationally equivalent models have the same representation. We denote the a-th row of p by p_{a}:\Omega\rightarrow\mathbb{R}^{\psi} and define it as follows.

Corollary 2.

Let \{p_{a}\}_{a\in[d]} be a collection of polynomials such that p_{a}:\Omega_{H}\rightarrow\mathbb{R}^{\psi} is defined as follows. Each p_{a}(\Theta) is indexed by pairs \{j,k\} for j,k\in[d] and \ell\in[d], and defined to be

p_{a}(\Theta)_{\{j,k\}\ell}=\sum_{h\in[H]}\left(V_{h,aj}Q_{h,k\ell}+V_{h,ak}Q_{h,j\ell}\right). (16)

Let the polynomial p:\Omega\rightarrow\mathbb{R}^{d\times\psi} with output dimension d be p\vcentcolon=(p_{1},p_{2},...,p_{d}). Then for any pair of parameters \Theta,\Theta^{\prime}\in\Omega_{H} we have \text{MHLA}_{\Theta}=\text{MHLA}_{\Theta^{\prime}} if and only if p(\Theta)=p(\Theta^{\prime}).

Algorithm 2 Constructing Features for Certificates of Identifiability
1:  Input: Data D\vcentcolon=\{Z_{i}\}_{i\in[N]} for Z_{i}\in\mathbb{R}^{d\times n_{i}}
2:  Output: feature vectors \mathcal{H}(Z_{i}) for i\in[N]
3:  for Z_{i}\in D do
4:     Let z_{1},z_{2},...,z_{d} be the rows of Z_{i} and let z_{a,b} be the (a,b) entry of Z_{i}
5:     for sets \{j,k\} in distinct pairs of indices in [d]^{2} do
6:        for \ell\in[d] do
7:           \mathcal{H}(Z_{i})=\mathcal{H}(Z_{i})\circ\left[\langle z_{j:},z_{k:}\rangle z_{\ell n_{i}}\right]
8:        end for
9:     end for
10:     for j\in[d] do
11:        for \ell\in[d] do
12:           \mathcal{H}(Z_{i})=\mathcal{H}(Z_{i})\circ\left[\|z_{j}\|^{2}z_{\ell n_{i}}\right]
13:        end for
14:     end for
15:  end for
16:  Return: \{\mathcal{H}(Z_{i})\}_{i\in[N]}
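The following numpy sketch implements Algorithm 2 and the certificate check of Lemma 4.1: it builds the feature vector \mathcal{H}(Z) (distinct pairs \{j,k\} crossed with \ell, then the diagonal terms), forms the second moment \Lambda_{D}, and returns its minimum eigenvalue. Function names and the ordering of entries are ours.

```python
# Sketch of Algorithm 2 and the certificate of identifiability (Lemma 4.1).
import numpy as np

def H_features(Z):
    """Feature vector H(Z) of length C(d,2)*d + d^2."""
    d = Z.shape[0]
    G = Z @ Z.T                      # <z_j:, z_k:>
    zn = Z[:, -1]
    feats = []
    for j in range(d):               # distinct pairs {j, k}, j < k
        for k in range(j + 1, d):
            for l in range(d):
                feats.append(G[j, k] * zn[l])
    for j in range(d):               # diagonal terms ||z_j||^2 * z_{l n}
        for l in range(d):
            feats.append(G[j, j] * zn[l])
    return np.array(feats)

def certificate(dataset):
    """Lambda_D = E_D[H(Z) H(Z)^T]; the data is certifiably identifiable if the
    smallest eigenvalue returned here is strictly positive."""
    feats = np.stack([H_features(Z) for Z, _ in dataset])
    Lambda = feats.T @ feats / len(dataset)
    return np.linalg.eigvalsh(Lambda).min()
```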

We give an overview of a few results using our certifiable identifiability machinery:

  1. Data drawn from independent noise is certifiably identifiable. If the data matrices Z are drawn with each entry being standard normal noise, then \text{MHLA}_{\Theta} for \Theta\in\Omega_{H} is identifiable with respect to the data. The statement holds beyond standard normals, for distributions satisfying weak moment conditions. See Lemma 4.2.

  2. If the model class is MHLA with at least d^{2} heads, then even if the data is not realizable, as long as the minimum eigenvalue of \Lambda_{D} is greater than zero, MHLA remains identifiable with respect to D. See Lemma 4.3.

  3. This identifiability condition is robust, and can be generalized to \epsilon-approximate empirical risk minimizers. See Lemma 4.4.

More formally:

Lemma 4.2 (Independent input noise yields identifiability).

Let (Z,y)\sim\mathcal{D} be drawn from a realizable distribution. Let Z be drawn from a distribution \mathcal{Z} where the (a,b)-th entry of Z, denoted Z_{ab}, is drawn i.i.d. from a distribution \nu over \mathbb{R} for all a\in[d] and b\in[n]. Let the second and fourth moments of \nu be denoted m_{2} and m_{4} respectively, with m_{2}>0 and m_{4}>m_{2}^{2}. Then \text{MHLA}_{\Theta} for \Theta\in\Omega_{H} is identifiable with respect to D. That is to say, for any population risk minimizers \Theta,\Theta^{\prime}\in\Omega_{\text{PRM}},

\text{MHLA}_{\Theta}=\text{MHLA}_{\Theta^{\prime}}. (17)

When specialized to the case of multi-head linear attention \text{MHLA}_{\Theta} with at least d^{2} heads, we can avoid the realizability assumption entirely. This is because the class of MHLAs with an arbitrary number of heads is linear in the feature space \mathcal{H} given in Lemma 4.1.

Lemma 4.3 (Identifiability without realizability for MHLA with arbitrarily many heads).

Let dataset D=\{(Z_{i},y_{i})\}_{i\in[N]} be any dataset drawn i.i.d. from a distribution \mathcal{D}. Let \mathcal{H} be defined as in Lemma 4.1. Then if \lambda_{\min}\left(\mathbb{E}_{D}[\mathcal{H}(Z)\mathcal{H}(Z)^{T}]\right)>0, \text{MHLA}_{\Theta} for \Theta\in\Omega_{H} for any H\in[d^{2},\infty) is identifiable with respect to the data D. That is,

\text{MHLA}_{\Theta}=\text{MHLA}_{\Theta^{\prime}} (18)

for all pairs of empirical risk minimizers \Theta,\Theta^{\prime}\in\Omega_{\text{ERM}}.

We also add a quantitative version of identifiability with precise treatment of issues related to error. (For a corresponding statement of realizability with noise see Lemma B.2.)

Lemma 4.4 (Identifiability with Error).

Let \Omega_{\epsilon-\text{ERM}} be the set of \epsilon-approximate empirical risk minimizers:

\Omega_{\epsilon-\text{ERM}}=\left\{\Theta\in\Omega_{H}\;\middle|\;\mathbb{E}_{(Z_{i},y_{i})\in D}\left[\left(\text{MHLA}_{\Theta}(Z_{i})-y_{i}\right)^{2}\right]\leq\epsilon\right\}. (19)

Then we have for any \Theta,\Theta^{\prime}\in\Omega_{\epsilon-\text{ERM}} that for all inputs Z\in\mathbb{R}^{d\times n},

\|\text{MHLA}_{\Theta}(Z)-\text{MHLA}_{\Theta^{\prime}}(Z)\|\leq\frac{\epsilon}{\lambda_{\min}\left(\Lambda_{D}\right)}\|Z\|_{F}^{6}. (20)

Proofs of all of the above statements are given in Appendix B.

5 Application to Learning Universal Turing Machines

We apply our algorithmic and identifiability machinery to show that an important computational procedure is representable and learnable as an MHLA: namely, a restricted class of universal Turing machines (UTMs) with bounded computation history. We must first generalize our previous MHLA definition to enable multi-step computation:

Definition (Autoregressive MHLA).

Let Z^{0} be an input matrix in \mathbb{R}^{d\times n}. We define the iterative process of \Phi-step autoregressive MHLA as follows: starting from t=0, let the next token y^{t+1}\in\mathbb{R}^{d} be:

y^{t+1}=\text{MHLA}_{\Theta}(Z^{t}), (21)

and, for all t\in[\Phi], let Z^{t+1}\in\mathbb{R}^{d\times(n+t+1)} be the concatenation:

Z^{t+1}=Z^{t}\circ y^{t+1}. (22)

Next we define the computation history of an autoregressive model analogously to the computation history of a Turing machine.

Definition (Autoregressive Computation History).

We refer to \text{CH}_{\Theta}(Z)=\{Z^{t}\}_{t\in[\Phi]} as the computation history of the \Phi-step autoregressive MHLA. We denote the t-th step of the computation history as \text{CH}_{\Theta}^{t}(Z)=Z^{t}.
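A minimal numpy sketch of the \Phi-step autoregressive rollout and its computation history, following Eqs. (21)–(22); function names are ours.

```python
# Sketch of the Phi-step autoregressive MHLA and its computation history.
import numpy as np

def mhla_step(Z, params):
    """Next token y^{t+1} = MHLA_Theta(Z^t), Eq. (21)."""
    return sum(V @ Z @ (Z.T @ Q @ Z[:, -1]) for V, Q in params)

def computation_history(Z0, params, num_steps):
    """Returns [Z^1, ..., Z^Phi], appending each predicted token to Z (Eq. (22))."""
    history, Z = [], Z0
    for _ in range(num_steps):
        y_next = mhla_step(Z, params)
        Z = np.concatenate([Z, y_next[:, None]], axis=1)
        history.append(Z)
    return history
```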

We will often use the notation Z^{t}[:-k] to denote the last k\in\mathbb{Z}^{+} tokens of Z^{t}. Often, Z will be the embeddings corresponding to a word x in a language \mathcal{L}, in which case we will use the notation \text{CH}_{\Theta}(x) and \text{CH}_{\Theta}(Z) interchangeably. For a pedagogical discussion of how to map embeddings to letters in an alphabet, see Section D.

Although the theory derived in this paper applies to all functions expressible by MHLAs, we are particularly interested in the task of learning universal Turing machines (UTMs). Let \Sigma be an alphabet. Let \mathcal{Q} be a set of states that includes a start, accept, and reject state \{q_{start},q_{accept},q_{reject}\}. Let \delta:\mathcal{Q}\times\Sigma\rightarrow\mathcal{Q}\times\Sigma\times\{L,R\} be a transition function that takes a state and an alphabet symbol and maps them to a state transition, an output symbol, and a head movement left or right. Typically there is also a tape alphabet \Gamma of which the input alphabet \Sigma is a subset.

Definition (Accept TM).

Let M=\{\delta,\Sigma,\Gamma,\mathcal{Q},q_{start},q_{accept},q_{reject}\} be a TM, and let x\in\Sigma^{*} range over all strings in the alphabet \Sigma. Then let A_{\text{TM}} be the language A_{\text{TM}}=\{(M,x)\mid M\text{ accepts }x\}.

The UTM constructed in Turing’s 1936 paper recognizes A_{\text{TM}}. In practice, we are most often interested in the behavior of TMs that run in polynomial time, and focus below on implementing a universal simulator for this restricted class:

Definition.

(Polynomially Bounded Universal Turing Machine) In general, a UTM is a recognizer for the language A_{\text{TM}}: if (M,x) is in A_{\text{TM}}, the UTM accepts; otherwise, the UTM rejects or does not halt. Let A_{\text{TM}}\cap P be the language of input pairs (M,x) for TM M and word x\in\Sigma^{*} such that M decides x in polynomial time. Here, we take the UTM to be a polynomial-time decider for A_{\text{TM}}\cap P.

To define what it means for an autoregressive MHLA to perform the same computation as a TM, our main idea is to construct parameters for the MHLA such that it executes the computation history of TM M on input x. Let the UTM computation history at step t include the contents x_{0},\ldots,x_{k_{t}} on the tape after t transition steps of the Turing machine M, the current state q_{t}, and the current head position h_{t}. Here k_{t} is the number of tokens at timestep t. Then, there is a single-layer MHLA capable of simulating a UTM:

We restate Lemma 2.2 from Section 2.3.

We include the full proof of the existence of \Theta in the appendix. For simplicity, we adopt a naive embedding scheme that represents different letters in an alphabet as orthogonal unit vectors. This makes it easy to contrive embedding schemes that incorporate arbitrary polynomial-sized circuits which could compute whether x\in\mathcal{L}(M). Moreover, we adopt positional encodings that are simply orthogonal unit vectors. Thus, in order to give each of T tokens a unique ID, we require O(T)-dimensional positional embeddings.

This can be combined with the learnability results above to yield a specialized result for UTMs:

Lemma 5.1 (Learning a UTM).

Let \Theta\in\Omega_{H} in dimension d be the MHLA parameters of Lemma 2.2. Let \{M_{i},x_{i}\}_{i\in[N]} be pairs of TMs M and words x of maximum length n drawn i.i.d. from a distribution \mathcal{D}. Let Z_{i}=\text{Embed}(M_{i},x_{i}). For each TM/word pair (M_{i},x_{i}), let \text{CH}_{\Theta}(Z_{i})=\{Z^{1}_{i},Z^{2}_{i},...,Z^{\Phi}_{i}\} be the \Phi-step autoregressive computation history of \text{MHLA}_{\Theta} on Z_{i}. Let D be the dataset D\vcentcolon=\{(\text{CH}_{\Theta}(Z_{i})^{t},y^{t+1}_{i})\}_{i\in[N],t\in[T]} where y^{t+1}_{i}=\text{MHLA}_{\Theta}(Z^{t}_{i}). Then Algorithm 1 applied to input D returns \hat{\Theta}\in\Omega_{H} for H\leq d^{2} such that with probability 1-\delta,

\mathbb{E}_{(Z,y)\in\mathcal{D}}\left[\left(\text{MHLA}_{\hat{\Theta}}(Z)-y\right)^{2}\right]\leq\epsilon (23)

for sample complexity N=\text{poly}(d,\epsilon^{-1},\log(\delta^{-1})). Then with probability 1-\delta over the randomness in the data, the probability over \mathcal{D} that the \Phi-step autoregressive computation histories \text{CH}_{\hat{\Theta}}(M,x) and \text{CH}_{\Theta}(M,x) differ is upper bounded by

\Pr\nolimits_{(M,x)\sim\mathcal{D}}[\text{CH}_{\hat{\Theta}}(M,x)\neq\text{CH}_{\Theta}(M,x)]\leq O(\epsilon\Phi). (24)

Finally, if the dataset D is certifiably identifiable, then generalization holds out-of-distribution. For the proof see Section C.2.

Lemma 5.2 (Learning UTM from Certifiably Identifiable Data).

Let D=\{(Z_{i},y_{i})\}_{i\in[N]} be a dataset satisfying y_{i}=\text{MHLA}_{\Theta}(Z_{i}), where \Theta\in\Omega_{H} are the expressibility parameters of Lemma 2.2 for the set of TMs/words (M,x)\in\Delta(\hat{\mathcal{Q}},\hat{\Sigma},\hat{n},\hat{\Phi}). If D is certifiably identifiable with \lambda_{\min}(\Lambda_{D})>\eta, then there is a \text{poly}(d,N,\hat{\mathcal{Q}},\hat{\Sigma},\hat{n},\hat{\Phi},\eta^{-1})-time algorithm that outputs a set of parameters \hat{\Theta}\in\Omega_{d^{2}} such that for all TMs M and input words x in \Delta(\hat{\mathcal{Q}},\hat{\Sigma},\hat{n},\hat{\Phi}), we have

\text{CH}_{\hat{\Theta}}(M,x)^{c(t)}[:-k_{t}]=x^{t}. (25)

That is, the c(t)-th step of the autoregressive computation history of \hat{\Theta} is equal to the t-th step of the computation history of M on x.

6 Experiments

The theoretical analysis establishes three key results, rephrased here in the context of empirical validation:

  • An overparameterized family of linear attention networks can learn linear attention in polynomial time and samples (Theorem 2.1).

  • In the realizable setting, there are sufficient and checkable conditions under which empirical risk minimization recovers the equivalence class of the ground truth parameter values (Lemma 4.1).

  • Linear attention networks can simulate a restricted class of universal Turing machines with polynomial hidden size, using polynomially bounded computation histories for state tracking (Lemma 5.1).

In our experiments, we validate these theoretical predictions in practical settings where Transformers are trained using stochastic gradient descent (SGD), as follows:

  1. One interpretation of Theorem 2.1 is that relaxing MHLA learning into an “easy” linear regression problem corresponds to adding extra attention heads, and suggests that adding extra heads might provide optimization benefits even when learning MHLA models in their native form. We investigate the role of over-parameterization in multi-head and multi-layer linear attention networks. For random data generated from linear attention networks, we observe that adding more heads achieves faster convergence of the training loss than adding more layers. This suggests that while depth is important for expressiveness, the number of heads is important for optimization (Figure 1).

  2. We empirically verify the certificate of identifiability provided by Lemma 4.1 on datasets for associative memory (Bietti et al., 2023; Cabannes et al., 2024) with different choices of embeddings, demonstrating convergence to the equivalence class of the true parameters when \lambda_{\min}(\Lambda_{D})>0 and convergence to spurious solutions when \lambda_{\min}(\Lambda_{D})=0 (Figure 2).

  3. We test the practical data requirements for learning universal DFA executors, testing our polynomial complexity predictions from Lemma 5.1. We provide evidence that the sample requirement for learning DFA execution with N states, alphabet size V, and words of length L is polynomial in the relevant parameters (Figure 3).

6.1 Do extra heads help independent of the learning algorithm?

(a) N=512, d=2
(b) N=2048, d=4
Figure 1: Performance comparison of multi-head, multi-layer linear attention models and the original Transformer model (denoted as full). We trained using stochastic gradient descent (SGD) on synthetic data generated from a single-layer linear attention model for varying training set sizes (N), input dimensions (d), numbers of heads m, and numbers of layers n. We present the mean squared error of the predictions with respect to the number of training epochs. Results demonstrate that multi-head architectures converge faster on different input dimensions and match the performance of Algorithm 1 (convex algorithm). Increasing the number of layers or incorporating multilayer perceptrons (MLPs) and layer normalization did not yield consistent improvements. Shading indicates the standard error over three different runs.

Algorithm 1 returns a d^{2}-head MHLA which competes with the optimal model on the training data. If the data is generated by a single-head linear attention, our method can be viewed as learning with an over-parameterized model. This raises the question: are other forms of over-parameterization equally effective in learning linear attention networks? To address this, we train three types of over-parameterized models with SGD on data generated from a single-layer linear attention network: (1) multi-layer linear attention networks, (2) multi-head linear attention networks, and (3) a full Transformer layer. The results are presented in Figure 1.

Experimental Setup:

We initialize a single-layer linear attention network with parameters V\in\mathbb{R}^{1\times d} and Q\in\mathbb{R}^{d\times d}, sampled from a Gaussian distribution \mathcal{N}(0,\frac{I}{\sqrt{d}}). Input sequences Z^{i}\in\mathbb{R}^{T\times d} are sampled from \mathcal{N}(0,\frac{I}{\sqrt{T}}), where i=1,\ldots,N, T=100 is the maximum number of time steps, and N is the dataset size. We generate outputs by running the ground-truth network auto-regressively: y^{i}_{t}=VZ^{i}_{1:t}(Z^{i}[:,:t]QZ^{i}[:,t]), creating our dataset \mathcal{D}=\{(Z^{i},y^{i})\}_{i=1}^{N}.
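A sketch of this data generation, written in the d x n convention of Eq. (2) (the text describes the transposed T x d layout). For consistency with the earlier sketches we use a d x d value matrix for the teacher rather than the 1 x d map described above, and the noise scales are illustrative rather than the exact experimental configuration.

```python
# Sketch of the synthetic data generation in Section 6.1: a random single-head
# teacher labels every prefix of each sampled sequence.
import numpy as np

rng = np.random.default_rng(0)
d, T, N = 4, 100, 2048
V = rng.normal(scale=1 / np.sqrt(d), size=(d, d))   # teacher value matrix (illustrative)
Q = rng.normal(scale=1 / np.sqrt(d), size=(d, d))   # teacher key-query matrix

dataset = []
for _ in range(N):
    Z = rng.normal(scale=1 / np.sqrt(T), size=(d, T))
    for t in range(1, T + 1):                       # label every prefix Z[:, :t]
        Zt = Z[:, :t]
        y = V @ Zt @ (Zt.T @ Q @ Zt[:, -1])         # single-head teacher output
        dataset.append((Zt, y))
```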

In addition to learning with Algorithm 1, we train three types of models on this data using SGD:

  • Multi-head linear attention as in Equation 1.

  • Multi-layer linear attention with a single head.

  • An ordinary Transformer network (Vaswani et al., 2017) with softmax attention, multi-layer perceptron blocks, and layer normalization.

For detailed hyperparameters and optimization procedures, please refer to Section D.1.

Multi-head attention outperforms both multi-layer attention and the full transformer:

We observe that multi-head attention scales effectively with an increasing number of heads, resulting in improved performance. Notably, for d=2,4 input dimensions, using d^{2} heads (Figure 1(a)) yields the best performance, comparable to Algorithm 1 and approaching floating-point error precision. In contrast, multi-layer attention models show diminishing returns and perform worse than single-layer attention. Interestingly, adding more layers can sometimes degrade performance. The full Transformer model, which incorporates softmax attention, MLP layers, and layer normalization, does not significantly outperform the single-layer linear attention model on this task.

These findings suggest that the type of over-parameterization matters significantly in learning linear attention networks. Interestingly, multi-head architectures appear to be particularly effective—aligning with the structure of Algorithm 1.

6.2 Does certifiable identifiability predict generalization?

(a) Minimum eigenvalue \lambda_{\min}(\Lambda_{D}) vs. Euclidean distance in feature space of parameters learned by Algorithm 1.
(b) Distance to ground-truth parameters in feature space for certifiably identifiable data (minimum eigenvalue =0.06) vs. nonidentifiable data (minimum eigenvalue =0). Here the parameters of the MHLA are learned via SGD. Note that the error on identifiable data is effectively zero and is barely visible. Error bars are the standard error over three different runs.
Figure 2: Impact of data distribution on the associative lookup task performance: We generated training data for an associative lookup task (Bietti et al., 2023; Cabannes et al., 2024) using mixtures of two distributions: (a) Gaussian key and value vectors, and (b) random unitary key and value vectors. By adjusting the mixture probability, we can manipulate the certificate value (minimum eigenvalue of the data covariance matrix), as unitary key–value vectors give rank-deficient “certificates”. In Figure 2(a), we demonstrate that as the minimum eigenvalue increases, Algorithm 1 converges more closely to the true parameters. In Figure 2(b), we see that SGD learns parameters that are equivalent to the ground-truth parameters in feature space for certifiably identifiable data, but for unidentifiable data they are far apart in feature space and therefore compute different functions.

In Lemma 4.1, we developed a certificate that provides a sufficient condition for identifiability. However, it gives a sufficient, not necessary, condition for generalization. To assess the practical relevance of this certificate, we conducted an empirical analysis of convergence in cases where the condition is not satisfied. The results of this analysis are presented in Figure 2.

Associative Memory

Associative memory (Bietti et al., 2023; Cabannes et al., 2024) is the task of looking up a value in a table with a query. As a single-head one-layer linear attention, it can be represented with ground-truth parameters \Theta=\{V,Q\} where V,Q\in\mathbb{R}^{2d\times 2d}:

V=\begin{bmatrix}0&0\\ 0&I_{d\times d}\end{bmatrix}\quad Q=\begin{bmatrix}I_{d\times d}&0\\ 0&0\end{bmatrix}.

The data Z is drawn as follows: let k_{1},k_{2},...,k_{d}\in\mathbb{R}^{d} be random variables corresponding to keys in a lookup table, let v_{1},v_{2},...,v_{d}\in\mathbb{R}^{d} be random variables corresponding to values in a lookup table, let q\in\mathbb{R}^{d} be a random variable corresponding to a query to the lookup table, and let \zeta\sim\mathcal{N}(0,I) be some random noise, with:

Z=\begin{bmatrix}k_{1}&k_{2}&\ldots&k_{d}&q\\ v_{1}&v_{2}&\ldots&v_{d}&\zeta\end{bmatrix}. (26)

We set the output vector y to be

y=\text{MHLA}_{\bar{\Theta}}(Z)=\begin{bmatrix}0\\ \sum_{j\in[d]}\langle q,k_{j}\rangle v_{j}\end{bmatrix}. (27)
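The following sketch generates one example of the Gaussian ("identifiable") variant of this data, following Eqs. (26)–(27); the function name is ours, and the target is computed directly from Eq. (27).

```python
# Sketch of one associative-memory example: keys/values stacked into a 2d x (d+1)
# matrix Z whose last column holds the query and a noise vector; the target is the
# value retrieved by inner-product lookup.
import numpy as np

def sample_example(d, rng):
    K = rng.normal(size=(d, d))            # columns k_1, ..., k_d
    Vals = rng.normal(size=(d, d))         # columns v_1, ..., v_d
    q = K[:, rng.integers(d)]              # query = one of the keys, chosen uniformly
    zeta = rng.normal(size=d)              # noise in the value slot of the query column
    Z = np.block([[K, q[:, None]], [Vals, zeta[:, None]]])   # Eq. (26)
    y = np.concatenate([np.zeros(d), Vals @ (K.T @ q)])      # Eq. (27): sum_j <q, k_j> v_j
    return Z, y

rng = np.random.default_rng(0)
Z, y = sample_example(4, rng)
```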

Mixture of distributions:

We generate two datasets: one that is identifiable, with \lambda_{\min}(\Lambda_{D})>0, and one that is nonidentifiable, with \lambda_{\min}(\Lambda_{D})=0. The identifiable dataset is generated with \{k_{j}\}_{j\in[d]} and \{v_{j}\}_{j\in[d]} drawn i.i.d. from \mathcal{N}(0,I). The query q is chosen to be one of the \{k_{j}\}_{j\in[d]} uniformly at random. The unidentifiable dataset is drawn such that \{k_{j}\}_{j\in[d]} forms a random unitary matrix, i.e., \|k_{j}\|=1 for all j\in[d] and \langle k_{j},k_{j^{\prime}}\rangle=0 for all j\neq j^{\prime}. Similarly, \{v_{j}\}_{j\in[d]} is also drawn from a randomly generated unitary matrix. We draw new random unitary matrices for each datapoint, where q is again chosen to be one of the \{k_{j}\}_{j\in[d]} uniformly at random. We set d=4 dimensions for both datasets, and draw N=2^{14} samples for each dataset. We mix the two datasets together with a mixing probability ranging from 95% unidentifiable to 100% unidentifiable. In this manner we generate a spread of datasets with different values of \lambda_{\min}(\Lambda_{D}) that tend to zero.

Certifiable Identifiability for Algorithm 1:

For each dataset, we run Algorithm 1, which returns \hat{\Theta}. We compare \hat{\Theta} to the ground truth \Theta in feature space via the distance metric

d(\Theta,\hat{\Theta})\vcentcolon=\|p(\Theta)-p(\hat{\Theta})\|_{F}. (28)

Here, p is the polynomial given in Lemma 4.1. Recall from Corollary 2 that p defines the equivalence class of parameters that compute the same function, i.e., \text{MHLA}_{\Theta}=\text{MHLA}_{\hat{\Theta}} if and only if p(\Theta)=p(\hat{\Theta}). On each dataset, we measure the certificate value \lambda_{\min}(\Lambda_{D}) on the x-axis vs. d(\Theta,\hat{\Theta}) on the y-axis. In Figure 2, we see that as the certificate value increases, d(\Theta,\hat{\Theta}) decreases, indicating that \text{MHLA}_{\Theta} and \text{MHLA}_{\hat{\Theta}} compute the same function.
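A sketch of this feature-space distance: we build p(\Theta) by matching coefficients against the features \mathcal{H}(Z) (distinct pairs \{j,k\} first, then diagonal terms), so that two parameter sets at distance zero compute the same MHLA function. The entry ordering, the treatment of the diagonal terms, and the function names are ours.

```python
# Sketch of d(Theta, Theta_hat) = ||p(Theta) - p(Theta_hat)||_F from Eq. (28).
import numpy as np

def p_map(params, d):
    """Feature-space image of an MHLA parameter set, one row per output coordinate a."""
    rows = []
    for a in range(d):
        entries = []
        for j in range(d):                    # distinct pairs {j, k}, j < k
            for k in range(j + 1, d):
                for l in range(d):
                    entries.append(sum(V[a, j] * Q[k, l] + V[a, k] * Q[j, l]
                                       for V, Q in params))
        for j in range(d):                    # diagonal terms, coefficient of ||z_j||^2 z_{l n}
            for l in range(d):
                entries.append(sum(V[a, j] * Q[j, l] for V, Q in params))
        rows.append(entries)
    return np.array(rows)                     # shape d x psi

def feature_distance(params1, params2, d):
    return np.linalg.norm(p_map(params1, d) - p_map(params2, d))
```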

Certifiable Identifiability for MHLA:

Our notion of certifiable identifiability in Lemma 4.1 applies to any empirical risk minimizer. Therefore, it applies to popular optimizers like SGD and Adam if they achieve the minimum of the loss, which in our synthetic case is equal to zero. In Figure 2(b), we train MHLA models via SGD with 1, 2, 4, and 8 heads. For identifiable data with minimum eigenvalue 0.06, we see that the learned parameters and ground-truth parameters are the same in feature space. However, for unidentifiable data with minimum eigenvalue 0, learned parameters and ground-truth parameters are far apart in feature space and therefore compute different functions.

6.3 Learning the Computation History of Deterministic Finite Automata

Universal automata (like the universal Turing machine discussed in Section C.2) receive descriptions of other automata as input, and simulate them to produce an output. Here we empirically evaluate the ability of MHLA models to perform universal simulation of deterministic finite automata (DFAs). We limit our study to DFAs with a maximum number of states (N), alphabet size (V), and input length (L). While recent work on in-context learning (Akyürek et al., 2024) has focused on inferring DFA behavior from input–output examples, here we aim to simulate DFAs given explicit descriptions of their state transitions as input—a task somewhat analogous to instruction following in large-scale language models.

The construction in Lemma 5.1 shows that a linear attention layer can output the polynomially bounded computation history of any TM (and therefore any DFA). Our construction requires embedding size linear in the maximum length of the computation history, the number of states, and the alphabet size. Therefore, we predict that the data requirements are polynomial in each of N, V, and L.

Figure 3: Data requirement for universal DFA simulation: We train a fixed-size Transformer (4 layers, 16 heads, and 2048 hidden dimensions) to simulate a DFA given a transition table and input word. The vertical axis shows the number of tokens (expressed as word length L times the number of examples Q) required to obtain 99% next-token accuracy.

Dataset

Our dataset consists of strings containing three components: the input DFA's transition function $\delta:\mathcal{Q}\times\Sigma\rightarrow\mathcal{Q}$, the input word $x\in\Sigma^{L}$, and the computation history $h\in\mathcal{Q}^{L}$, i.e., the sequence of states visited by the DFA as it decides whether $x$ is in its language. The first two components are the input to the model, while the computation history is the target output. We adopt the following schema for representing $\delta$, $x$, and $h$:

\underbrace{(s_{i},w,\delta(s_{i},w))\;\;\forall\, s_{i}\in\mathcal{Q},\,w\in\Sigma}_{\text{DFA transition function}}\;\mid\;\underbrace{w_{0}w_{1}\dots w_{L-1}}_{\text{word}}\;\mid\;\underbrace{(s^{0},w_{0},s^{1}),(s^{1},w_{1},s^{2}),\dots,(s^{L-1},w_{L-1},s^{L})}_{\text{computation history}}

We encode each input–output relation in the transition function as a sequence of three tokens $(s_{i},w,s_{j})$ where $\delta(s_{i},w)=s_{j}$. We also include two parentheses separating each triplet, for a total of five tokens per input–output relation; the total description length of $\delta$ is then $5|\mathcal{Q}||\Sigma|$. We encode the word $x$ of length $L$ as a sequence of $L$ tokens. Finally, we encode the computation history as the sequence of state transitions the DFA visits when deciding whether $x$ is in its language. Here we designate $s^{0}$ as the start state and let $s^{i}=\delta(s^{i-1},w_{i-1})$. Each state transition is again represented by a triplet $(s,w,\delta(s,w))$. We train an autoregressive Transformer model with a cross-entropy loss to predict the computation-history tokens given the transition function and the word. Please refer to Section D.2 for hyperparameter details.
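For illustration, here is a minimal sketch of how one such training string can be serialized; the helper names and exact token spellings are hypothetical stand-ins for the actual data pipeline.

```python
import random

# Illustrative sketch of the DFA-execution data format; not the released pipeline.
def random_dfa(num_states, alphabet):
    # transition function delta: (state, symbol) -> next state
    return {(s, a): random.randrange(num_states)
            for s in range(num_states) for a in alphabet}

def encode_example(delta, word):
    # transition table: one "( s_i w s_j )" triplet per (state, symbol) pair
    table = [f"( s{s} {a} s{delta[s, a]} )" for s, a in sorted(delta)]
    # computation history: start in s0 and follow delta along the word
    state, history = 0, []
    for a in word:
        nxt = delta[state, a]
        history.append(f"( s{state} {a} s{nxt} )")
        state = nxt
    return " ".join(table) + " | " + " ".join(word) + " | " + " ".join(history)

delta = random_dfa(num_states=4, alphabet="ab")
print(encode_example(delta, list("abba")))
```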

Results

In Figure 3, we vary each of the parameters $N$, $L$, and $V$ in turn, while the other two are fixed to a constant (here, 4). On the vertical axis, we display the minimum number of tokens (number of examples times the word length) required to reach 99% accuracy on next-token prediction. The plots are suggestive of a sub-exponential dependence on DFA complexity.

7 Related Work

7.1 Formal Expressivity of Transformers

A large body of work has sought to characterize which algorithmic tasks a Transformer can perform, in terms of various kinds of circuit families (Pérez et al., 2019; Edelman et al., 2022b; Hahn, 2020; Merrill & Sabharwal, 2023; Merrill et al., 2022; 2021; Liu et al., 2022; Feng et al., 2023). In particular, researchers have studied how Transformers can realize specific DSLs (Weiss et al., 2021), logic expressions (Dong et al., 2019; Barceló et al., 2020; 2024), Turing machines (Dehghani et al., 2018; Giannou et al., 2023; Pérez et al., 2021), formal language recognition (Hao et al., 2022; Chiang et al., 2023), as well as automata and universal Turing machines (Liu et al., 2022; Li et al., 2024). However, while these works primarily focus on determining the types of problems whose solutions a Transformer can express, they often overlook the crucial question of how these solutions can be learned from data. Moreover, there is limited discussion of the sufficiency of the dataset itself: whether the available data can identify the underlying “true” function or algorithm that we aim to capture.

7.2 Learning Transformers

We organize the literature on learning Transformers into three threads. First, there is the literature on statistical learnability, where the focus is on the amount of data required to learn, without regard to whether a computationally tractable learning algorithm exists (Edelman et al., 2022a; Wei et al., 2021; Zhang et al., 2024; Trauger & Tewari, 2023).

Second, there are learnability results for single-head Transformers under a variety of assumptions on the data distribution. In particular, Zhang et al. (2023) provide learnability results for in-context linear regression; Jelassi et al. (2022) show that data with spatial structure can be learned; Tian et al. (2023) analyze SGD training dynamics under a toy data model; and Oymak et al. (2023) study the prompt attention model.

Third, the literature on provable guarantees for learning multi-head attention is rather sparse. Fu et al. (2023) give learnability results in a regime where the attention matrices are fixed and only the projection matrices are trained. Tarzanagh et al. (2024) show connections between single-layer attention optimization and SVM learning: under a good gradient initialization, an overparameterization condition, and a condition on the scores of optimal tokens, global convergence of gradient descent to a particular SVM problem can be established. Deora et al. (2023) analyze learning multi-head attention with gradient descent under their Assumption 2; in the words of the authors, “these conditions are related to the realizability condition, which guarantees obtaining small training error near initialization”, which they instantiate with separability of the data in an NTK space and proximity of the initialization to realizable parameters. Interestingly, they find that multi-head attention has benign optimization properties. Finally, Chen & Li (2024) study learning multi-head attention for well-structured data drawn from independent Bernoulli or Gaussian distributions, and they provide an extensive discussion of lower bounds for learning multi-head attention.

Acknowledgments

We gratefully acknowledge support from NSF grants IIS-2214177, IIS-2238240, CCF-2112665 and DMS-2134108; from AFOSR grant FA9550-22-1-0249; from ONR MURI grant N00014-22-1-2740; from ARO grant W911NF-23-1-0034; from the OpenPhilanthropy Foundation; from MIT Quest for Intelligence; from the MIT-IBM Watson AI Lab; from ONR Science of AI; from the Simons Center for the Social Brain; and from an Alexander von Humboldt fellowship. Any opinions, findings, and conclusions or recommendations expressed in this material are those of the authors and do not necessarily reflect the views of our sponsors.

References

  • Ahn et al. (2024) Kwangjun Ahn, Xiang Cheng, Minhak Song, Chulhee Yun, Ali Jadbabaie, and Suvrit Sra. Linear attention is (maybe) all you need (to understand transformer optimization), 2024. URL https://arxiv.org/abs/2310.01082.
  • Akyürek et al. (2024) Ekin Akyürek, Bailin Wang, Yoon Kim, and Jacob Andreas. In-context language learning: Architectures and algorithms, 2024. URL https://arxiv.org/abs/2401.12973.
  • Barceló et al. (2020) Pablo Barceló, Egor V Kostylev, Mikael Monet, Jorge Pérez, Juan Reutter, and Juan-Pablo Silva. The logical expressiveness of graph neural networks. In ICLR, 2020.
  • Barceló et al. (2024) Pablo Barceló, Alexander Kozachinskiy, Anthony Widjaja Lin, and Vladimir Podolskii. Logical languages accepted by transformer encoders with hard attention. 2024.
  • Bietti et al. (2023) Alberto Bietti, Vivien Cabannes, Diane Bouchacourt, Herve Jegou, and Leon Bottou. Birth of a transformer: A memory viewpoint, 2023. URL https://arxiv.org/abs/2306.00802.
  • Cabannes et al. (2024) Vivien Cabannes, Berfin Simsek, and Alberto Bietti. Learning associative memories with gradient descent, 2024. URL https://arxiv.org/abs/2402.18724.
  • Chen & Li (2024) Sitan Chen and Yuanzhi Li. Provably learning a multi-head attention layer, 2024. URL https://arxiv.org/abs/2402.04084.
  • Chiang et al. (2023) David Chiang, Peter Cholak, and Anand Pillay. Tighter bounds on the expressivity of transformer encoders. arXiv preprint arXiv:2301.10743, 2023.
  • Dehghani et al. (2018) Mostafa Dehghani, Stephan Gouws, Oriol Vinyals, Jakob Uszkoreit, and Łukasz Kaiser. Universal transformers. arXiv preprint arXiv:1807.03819, 2018.
  • Deora et al. (2023) Puneesh Deora, Rouzbeh Ghaderi, Hossein Taheri, and Christos Thrampoulidis. On the optimization and generalization of multi-head attention, 2023. URL https://arxiv.org/abs/2310.12680.
  • Dong et al. (2019) Honghua Dong, Jiayuan Mao, Tian Lin, Chong Wang, Lihong Li, and Denny Zhou. Neural logic machines. In ICLR, 2019.
  • Edelman et al. (2022a) Benjamin L. Edelman, Surbhi Goel, Sham Kakade, and Cyril Zhang. Inductive biases and variable creation in self-attention mechanisms, 2022a. URL https://arxiv.org/abs/2110.10090.
  • Edelman et al. (2022b) Benjamin L Edelman, Surbhi Goel, Sham Kakade, and Cyril Zhang. Inductive biases and variable creation in self-attention mechanisms. In International Conference on Machine Learning, pp. 5793–5831. PMLR, 2022b.
  • Feng et al. (2023) Guhao Feng, Yuntian Gu, Bohang Zhang, Haotian Ye, Di He, and Liwei Wang. Towards revealing the mystery behind chain of thought: a theoretical perspective. arXiv preprint arXiv:2305.15408, 2023.
  • Fu et al. (2023) Hengyu Fu, Tianyu Guo, Yu Bai, and Song Mei. What can a single attention layer learn? a study through the random features lens, 2023. URL https://arxiv.org/abs/2307.11353.
  • Giannou et al. (2023) Angeliki Giannou, Shashank Rajput, Jy-yong Sohn, Kangwook Lee, Jason D Lee, and Dimitris Papailiopoulos. Looped transformers as programmable computers. arXiv preprint arXiv:2301.13196, 2023.
  • Hahn (2020) Michael Hahn. Theoretical limitations of self-attention in neural sequence models. Transactions of the Association for Computational Linguistics, 8:156–171, 2020.
  • Hao et al. (2022) Yiding Hao, Dana Angluin, and Robert Frank. Formal language recognition by hard attention transformers: Perspectives from circuit complexity. Transactions of the Association for Computational Linguistics, 10:800–810, 2022.
  • Jelassi et al. (2022) Samy Jelassi, Michael E. Sander, and Yuanzhi Li. Vision transformers provably learn spatial structure, 2022. URL https://arxiv.org/abs/2210.09221.
  • Katharopoulos et al. (2020) Angelos Katharopoulos, Apoorv Vyas, Nikolaos Pappas, and François Fleuret. Transformers are rnns: Fast autoregressive transformers with linear attention, 2020. URL https://arxiv.org/abs/2006.16236.
  • Kearns & Vazirani (1994) Michael J. Kearns and Umesh Vazirani. An Introduction to Computational Learning Theory. The MIT Press, 08 1994. ISBN 9780262276863. doi: 10.7551/mitpress/3897.001.0001. URL https://doi.org/10.7551/mitpress/3897.001.0001.
  • Kingma & Ba (2014) Diederik P Kingma and Jimmy Ba. Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980, 2014.
  • Li et al. (2024) Zhiyuan Li, Hong Liu, Denny Zhou, and Tengyu Ma. Chain of thought empowers transformers to solve inherently serial problems. arXiv preprint arXiv:2402.12875, 2024.
  • Liu et al. (2022) Bingbin Liu, Jordan T Ash, Surbhi Goel, Akshay Krishnamurthy, and Cyril Zhang. Transformers learn shortcuts to automata. arXiv preprint arXiv:2210.10749, 2022.
  • Loshchilov & Hutter (2018) Ilya Loshchilov and Frank Hutter. Fixing weight decay regularization in adam, 2018. URL https://openreview.net/forum?id=rk6qdGgCZ.
  • Luo et al. (2023) Zhezheng Luo, Jiayuan Mao, Joshua B Tenenbaum, and Leslie Pack Kaelbling. On the expressiveness and generalization of hypergraph neural networks. In Learning on Graphs Conference, 2023.
  • Merrill & Sabharwal (2023) William Merrill and Ashish Sabharwal. The parallelism tradeoff: Limitations of log-precision transformers. Transactions of the Association for Computational Linguistics, 11:531–545, 2023.
  • Merrill et al. (2021) William Merrill, Yoav Goldberg, and Noah A Smith. On the power of saturated transformers: A view from circuit complexity. arXiv preprint arXiv:2106.16213, 2021.
  • Merrill et al. (2022) William Merrill, Ashish Sabharwal, and Noah A Smith. Saturated transformers are constant-depth threshold circuits. Transactions of the Association for Computational Linguistics, 10:843–856, 2022.
  • Oymak et al. (2023) Samet Oymak, Ankit Singh Rawat, Mahdi Soltanolkotabi, and Christos Thrampoulidis. On the role of attention in prompt-tuning, 2023. URL https://arxiv.org/abs/2306.03435.
  • Pérez et al. (2019) Jorge Pérez, Javier Marinković, and Pablo Barceló. On the turing completeness of modern neural network architectures. In ICLR, 2019.
  • Pérez et al. (2021) Jorge Pérez, Pablo Barceló, and Javier Marinkovic. Attention is turing complete. The Journal of Machine Learning Research, 22(1):3463–3497, 2021.
  • Strobl et al. (2024) Lena Strobl, William Merrill, Gail Weiss, David Chiang, and Dana Angluin. What formal languages can transformers express? a survey. Transactions of the Association for Computational Linguistics, 12:543–561, 2024.
  • Tarzanagh et al. (2024) Davoud Ataee Tarzanagh, Yingcong Li, Christos Thrampoulidis, and Samet Oymak. Transformers as support vector machines, 2024. URL https://arxiv.org/abs/2308.16898.
  • Tian et al. (2023) Yuandong Tian, Yiping Wang, Beidi Chen, and Simon Du. Scan and snap: Understanding training dynamics and token composition in 1-layer transformer, 2023. URL https://arxiv.org/abs/2305.16380.
  • Touvron et al. (2023) Hugo Touvron, Louis Martin, Kevin Stone, Peter Albert, Amjad Almahairi, Yasmine Babaei, Nikolay Bashlykov, Soumya Batra, Prajjwal Bhargava, Shruti Bhosale, et al. Llama 2: Open foundation and fine-tuned chat models. arXiv preprint arXiv:2307.09288, 2023.
  • Trauger & Tewari (2023) Jacob Trauger and Ambuj Tewari. Sequence length independent norm-based generalization bounds for transformers, 2023. URL https://arxiv.org/abs/2310.13088.
  • Vaswani et al. (2017) Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Łukasz Kaiser, and Illia Polosukhin. Attention is all you need. Advances in neural information processing systems, 30, 2017.
  • Wei et al. (2021) Colin Wei, Yining Chen, and Tengyu Ma. Statistically meaningful approximation: a case study on approximating turing machines with transformers. CoRR, abs/2107.13163, 2021. URL https://arxiv.org/abs/2107.13163.
  • Weiss et al. (2021) Gail Weiss, Yoav Goldberg, and Eran Yahav. Thinking like transformers. In International Conference on Machine Learning, pp. 11080–11090. PMLR, 2021.
  • Zhang et al. (2023) Ruiqi Zhang, Spencer Frei, and Peter L. Bartlett. Trained transformers learn linear models in-context, 2023. URL https://arxiv.org/abs/2306.09927.
  • Zhang et al. (2024) Yufeng Zhang, Boyi Liu, Qi Cai, Lingxiao Wang, and Zhaoran Wang. An analysis of attention via the lens of exchangeability and latent variable models, 2024. URL https://arxiv.org/abs/2212.14852.

Appendix A Proof of Main Theorem

See Theorem 2.1.

Proof.

First we write down the loss.

\mathcal{L}_{\Theta}(\{(Z_{i},y_{i})\}_{i\in[N]})\vcentcolon=\frac{1}{N}\sum_{i\in[N]}\left\|\sum_{h\in[H]}V_{h}Z_{i}(Z_{i}^{T}Q_{h}Z_{i}[:,n_{i}])-y_{i}\right\|_{F}^{2}\\ =\frac{1}{N}\sum_{i\in[N]}\sum_{a\in[d]}\left(\sum_{h\in[H]}e_{a}^{T}V_{h}Z_{i}(Z_{i}^{T}Q_{h}Z_{i}[:,n_{i}])-y_{i,a}\right)^{2} (29)

Observe that the one-layer attention network is a quadratic polynomial in $\{V_{h},Q_{h}\}_{h\in[H]}$:

=1Ni[N]a[d](𝒯Θ,Xi,ayi,a)2=\frac{1}{N}\sum_{i\in[N]}\sum_{a\in[d]}(\left\langle\mathcal{T}_{\Theta},X_{i,a}\right\rangle-y_{i,a})^{2} (30)

Here

𝒯Θ:=h[H]flatten(Vh)flatten(Qh)T=h[H][Vh,00Qh,00Vh,00Qh,01Vh,00Qh,ddVh,01Qh,00Vh,01Qh,01Vh,01Qh,ddVh,ddQh,00Vh,ddQh,01Vh,ddQh,dd]\mathcal{T}_{\Theta}\vcentcolon=\sum_{h\in[H]}\text{flatten}(V_{h})\text{flatten}(Q_{h})^{T}=\sum_{h\in[H]}\begin{bmatrix}V_{h,00}Q_{h,00}&V_{h,00}Q_{h,01}&\ldots&V_{h,00}Q_{h,dd}\\ V_{h,01}Q_{h,00}&V_{h,01}Q_{h,01}&\ldots&V_{h,01}Q_{h,dd}\\ \vdots&\vdots&\vdots\\ V_{h,dd}Q_{h,00}&V_{h,dd}Q_{h,01}&\ldots&V_{h,dd}Q_{h,dd}\end{bmatrix} (31)

Now we relax the objective by replacing $\mathcal{T}_{\Theta}$ with an unconstrained matrix $W\in\mathbb{R}^{d^{2}\times d^{2}}$. Put differently, $\mathcal{T}_{\Theta}$ has rank at most $H$, whereas $W$ may be an arbitrary matrix. Because the unconstrained class is larger, this relaxation is guaranteed to achieve a loss no larger than that of the original objective, and it can be optimized via ordinary least squares.

minWd2×d2W({(Zi,yi)}i[N]):=1Ni[N]a[d](W,Xi,ayi,a)2minΘΩHΘ({(Zi,yi)}i[N])+ϵ\min_{W\in\mathbb{R}^{d^{2}\times d^{2}}}\mathcal{L}_{W}(\{(Z_{i},y_{i})\}_{i\in[N]})\vcentcolon=\frac{1}{N}\sum_{i\in[N]}\sum_{a\in[d]}(\left\langle W,X_{i,a}\right\rangle-y_{i,a})^{2}\\ \leq\min_{\Theta\in\Omega_{H}}\mathcal{L}_{\Theta}(\{(Z_{i},y_{i})\}_{i\in[N]})+\epsilon (32)

Thus the minimizer of the regression achieves the optimum of the loss up to error $\epsilon$ in time $O(\frac{1}{\epsilon}d^{4}N)$. The sample complexity to achieve error $\epsilon$ is $O(\frac{1}{\epsilon}(d^{4}+\log(\delta^{-1})))$ with probability $1-\delta$ over the data distribution. Furthermore, taking the SVD $W=\sum_{h\in[\hat{H}]}A_{h}B_{h}^{T}$, where the singular values are absorbed into the left and right singular vectors, we define $\hat{\Theta}=\{\text{Fold}(A_{h}),\text{Fold}(B_{h})\}_{h\in[\hat{H}]}$, i.e., $\hat{V}_{h}=\text{Fold}(A_{h})$ and $\hat{Q}_{h}=\text{Fold}(B_{h})$. Then

\mathcal{L}_{\hat{\Theta}}(\{(Z_{i},y_{i})\}_{i\in[N]})\vcentcolon=\frac{1}{N}\sum_{i\in[N]}\left\|\sum_{h\in[\hat{H}]}\hat{V}_{h}Z_{i}(Z_{i}^{T}\hat{Q}_{h}Z_{i}[:,n_{i}])-y_{i}\right\|_{F}^{2}\\ =\frac{1}{N}\sum_{i\in[N]}\sum_{a\in[d]}\left(\sum_{h\in[\hat{H}]}e_{a}^{T}\hat{V}_{h}Z_{i}(Z_{i}^{T}\hat{Q}_{h}Z_{i}[:,n_{i}])-y_{i,a}\right)^{2}\leq\epsilon (33)

as desired. ∎
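To make the procedure concrete, here is a minimal NumPy sketch (ours, not the authors' implementation) of the steps in this proof: extract the degree-3 features of Algorithm 8, solve the least-squares relaxation for $W$, and fold the SVD factors of $W$ back into value/query heads. A small ridge term is added purely for numerical stability.

```python
import numpy as np

# Illustrative sketch of the relax-then-fold procedure; not the released code.
def mhla_forward(Vs, Qs, Z):
    # single-layer multi-head linear attention: sum_h V_h Z (Z^T Q_h Z[:, -1])
    z_last = Z[:, -1]
    return sum(V @ Z @ (Z.T @ (Q @ z_last)) for V, Q in zip(Vs, Qs))

def features(Z, d):
    # X[j, k*d + l] = <z_j, z_k> * z_{l, last}, as in Algorithm 8
    G = Z @ Z.T
    return np.einsum('jk,l->jkl', G, Z[:, -1]).reshape(d, d * d)

def learn_mhla(dataset, d, ridge=1e-8):
    # least-squares relaxation over W in R^{d^2 x d^2}; output coordinate a uses
    # the block of rows W[a*d:(a+1)*d, :]
    A = np.stack([features(Z, d).ravel() for Z, _ in dataset])           # N x d^3
    Y = np.stack([y for _, y in dataset])                                # N x d
    blocks = np.linalg.solve(A.T @ A + ridge * np.eye(d ** 3), A.T @ Y)  # d^3 x d
    W = np.concatenate([blocks[:, a].reshape(d, d * d) for a in range(d)])
    # fold the SVD factors of W back into one (value, query) pair per head
    U, S, Vt = np.linalg.svd(W)
    heads = [(np.sqrt(s) * u.reshape(d, d), np.sqrt(s) * v.reshape(d, d))
             for u, v, s in zip(U.T, Vt, S) if s > 1e-8]
    Vhat, Qhat = zip(*heads)
    return list(Vhat), list(Qhat)

# sanity check on data realizable by a random 2-head ground truth
rng = np.random.default_rng(0)
d, n, N = 3, 5, 500
Vs0 = [rng.standard_normal((d, d)) for _ in range(2)]
Qs0 = [rng.standard_normal((d, d)) for _ in range(2)]
data = [(Z, mhla_forward(Vs0, Qs0, Z))
        for Z in (rng.standard_normal((d, n)) for _ in range(N))]
Vs1, Qs1 = learn_mhla(data, d)
Ztest = rng.standard_normal((d, n))
assert np.allclose(mhla_forward(Vs0, Qs0, Ztest), mhla_forward(Vs1, Qs1, Ztest), atol=1e-4)
```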

Appendix B Proofs from Identifiability Section

First, we state a general lemma (Lemma B.1) giving a sufficient condition for identifiability of any model class that can be written as an inner product between a polynomial of the parameters $\Theta$ and a polynomial feature map $\mathcal{H}$: if the data is realizable by the model class and $\Lambda_{D}=\mathbb{E}_{D}\left[\mathcal{H}(Z)\mathcal{H}(Z)^{T}\right]$ is full rank, then the model class is identifiable with respect to $D$.

The following is the certificate of identifiability written in an abstract form involving polynomials to map parameters to feature space and polynomials to map data to feature space. The proof does not require the model to be an MHLA, but we state it in MHLA terms for the sake of concreteness.

Lemma B.1 (General Certificate of Identifiability).

Let $D=\{(Z_{i},y_{i})\}_{i\in[N]}$ be a dataset realizable by $\Theta\in\Omega_{H}$. Let $p\vcentcolon=\{p_{a}\}_{a\in[d]}$ be a collection of polynomials $p_{a}:\Omega\rightarrow\mathbb{R}^{\psi}$ mapping the parameters $\Theta\in\Omega$ to a feature space of fixed dimension $\psi\in\mathbb{Z}^{+}$. Let $\mathcal{H}=\{\mathcal{H}_{n}\}_{n=1}^{\infty}$ be a uniform family of polynomials such that $\mathcal{H}_{n}:\mathbb{R}^{d\times n}\rightarrow\mathbb{R}^{\psi}$. Let $p$ and $\mathcal{H}$ satisfy

MHLAΘ(Z)[a]=pa(Θ),n(Z)\text{MHLA}_{\Theta}(Z)[a]=\langle p_{a}(\Theta),\mathcal{H}_{n}(Z)\rangle (34)

for all Zd×nZ\in\mathbb{R}^{d\times n} for all n[1,)n\in[1,\infty). Then if λmin(𝔼D[(Z)(Z)T])>0\lambda_{\min}\left(\mathbb{E}_{D}\left[\mathcal{H}(Z)\mathcal{H}(Z)^{T}\right]\right)>0 , we have

MHLAΘ=MHLAΘ\text{MHLA}_{\Theta}=\text{MHLA}_{\Theta^{\prime}} (35)

for all empirical risk minimizers Θ,ΘΩERM\Theta,\Theta^{\prime}\in\Omega_{\text{ERM}}. That is, all empirical risk minimizers compute the same function.

Proof.

We construct a map p:Ωψp:\Omega\rightarrow\mathbb{R}^{\psi} such that MHLAΘ=MHLAΘ\text{MHLA}_{\Theta}=\text{MHLA}_{\Theta^{\prime}} if and only if p(Θ)=p(Θ)p(\Theta)=p(\Theta^{\prime}). Then we show that any empirical risk minimizer ΘERM\Theta_{\text{ERM}} and the ground truth Θ¯\bar{\Theta} satisfy p(ΘERM)=p(Θ¯)p(\Theta_{\text{ERM}})=p(\bar{\Theta}).

In more detail, we construct some polynomials {pa}a[d]\{p_{a}\}_{a\in[d]} and family of polynomials \mathcal{H} such that

MHLAΘ(Z)|a=pa(Θ),(Z)\text{MHLA}_{\Theta}(Z)|_{a}=\langle p_{a}(\Theta),\mathcal{H}(Z)\rangle (36)

We construct a linear model class $\mathcal{R}$ that takes as parameters $v\in\mathbb{R}^{\psi}$ and as data $\mathcal{H}(Z)\in\mathbb{R}^{\psi}$, such that

v((Z))=v,(Z)\mathcal{R}_{v}(\mathcal{H}(Z))=\langle v,\mathcal{H}(Z)\rangle (37)

Let ΘERM\Theta_{\text{ERM}} be defined as

\Theta_{\text{ERM}}\vcentcolon=\left\{\Theta^{\prime}\in\Omega~\Big{|}~\Theta^{\prime}=\operatorname*{arg\,min}_{\Theta\in\Omega}\mathbb{E}_{i\in[N]}\left[\mathcal{L}(\text{MHLA}_{\Theta}(Z_{i}),y_{i})\right]\right\} (38)

Let vERMv_{\text{ERM}} be defined as

vERM:={vψ|v=argminvψ𝔼i[N][(v((Zi)),yi)]}v_{\text{ERM}}\vcentcolon=\{v^{\prime}\in\mathbb{R}^{\psi}|v^{\prime}=\operatorname*{arg\,min}_{v\in\mathbb{R}^{\psi}}\mathbb{E}_{i\in[N]}\left[\mathcal{L}(\mathcal{R}_{v}(\mathcal{H}(Z_{i})),y_{i})\right]\} (39)

Observe that for all $\Theta\in\Theta_{\text{ERM}}$, we have $p(\Theta)\in v_{\text{ERM}}$. Here we use the fact that $y$ is realizable by the ground truth $\bar{\Theta}$. Therefore, if we show that $v_{\text{ERM}}$ is unique, i.e., consists of a single element, then $p_{\text{ERM}}\vcentcolon=\{p(\Theta)\,|\,\Theta\in\Theta_{\text{ERM}}\}$ is also a single element, and hence $\text{MHLA}_{\Theta}$ is the same function for every $\Theta\in\Theta_{\text{ERM}}$.

To show that $v_{\text{ERM}}$ is unique, all we need is that the second moment of the features, $\Lambda_{D}=\mathbb{E}_{D}\left[\mathcal{H}(Z)\mathcal{H}(Z)^{T}\right]$, is positive definite, i.e., its minimum eigenvalue is strictly greater than zero. ∎

Next we prove the main certifiable identifiability lemma by instantiating the polynomials $\mathcal{H}$ and $p$ from Lemma B.1. See Lemma 4.1.

Proof.

First we construct a polynomial p:Ωψp:\Omega\rightarrow\mathbb{R}^{\psi} and :d×nψ\mathcal{H}:\mathbb{R}^{d\times n}\rightarrow\mathbb{R}^{\psi} for ψ=(d2)d+d2\psi={d\choose 2}d+d^{2} such that

MHLAΘ(Z)[a]=pa(Θ),(Z)\text{MHLA}_{\Theta}(Z)[a]=\langle p_{a}(\Theta),\mathcal{H}(Z)\rangle (40)

We begin by rewriting $\text{MHLA}_{\Theta}(Z)[a]$. We index the first ${d\choose 2}d$ entries of $p_{a}(\Theta)$ by the unordered pairs $\{j,k\}$ with $j\neq k$, $j,k\in[d]$, together with all $\ell\in[d]$.

pa(Θ){j,k},{}:=h[H](Vh,ajQh,k+Vh,akQh,j)p_{a}(\Theta)_{\{j,k\},\{\ell\}}\vcentcolon=\sum_{h\in[H]}\left(V_{h,aj}Q_{h,k\ell}+V_{h,ak}Q_{h,j\ell}\right) (41)

We define the entries of pa(Θ)p_{a}(\Theta) from [(d2)d,(d2)d+d2][{d\choose 2}d,{d\choose 2}d+d^{2}] as follows.

pa(Θ){j2}{}:=h[H]Vh,ajQh,jp_{a}(\Theta)_{\{j^{2}\}\{\ell\}}\vcentcolon=\sum_{h\in[H]}V_{h,aj}Q_{h,j\ell} (42)

Similarly, we define $\mathcal{H}(Z)$ to be the following ${d\choose 2}d+d^{2}$ features, $\mathcal{H}(Z)_{\{j,k\}\{\ell\}}$ and $\mathcal{H}(Z)_{\{j^{2}\}\{\ell\}}$:

(Z){j,k}{}:=zj:,zk:zn\mathcal{H}(Z)_{\{j,k\}\{\ell\}}\vcentcolon=\langle z_{j:},z_{k:}\rangle z_{\ell n} (43)

and

(Z){j2}{}:=zj:2zn\mathcal{H}(Z)_{\{j^{2}\}\{\ell\}}\vcentcolon=\|z_{j:}\|^{2}z_{\ell n} (44)

Thus we rewrite MHLAΘ(Z)[a]\text{MHLA}_{\Theta}(Z)[a] as

MHLAΘ(Z)[a]={j,k}𝒮2d[d]pa(Θ){j,k},{}(Z){j,k}{}+j,[d]pa(Θ){j2}{}(Z){j2}{}=pa(Θ),(Z)\text{MHLA}_{\Theta}(Z)[a]=\sum_{\{j,k\}\in\mathcal{S}^{d}_{2}}\sum_{\ell\in[d]}p_{a}(\Theta)_{\{j,k\},\{\ell\}}\mathcal{H}(Z)_{\{j,k\}\{\ell\}}+\sum_{j,\ell\in[d]}p_{a}(\Theta)_{\{j^{2}\}\{\ell\}}\mathcal{H}(Z)_{\{j^{2}\}\{\ell\}}\\ =\langle p_{a}(\Theta),\mathcal{H}(Z)\rangle (45)

Here we introduce the notation $\mathcal{S}_{2}^{d}$ to denote the set of all unordered pairs $\{j,k\}$ with $j\neq k$, $j,k\in[d]$. We have thus constructed polynomials $p_{a}(\Theta)$ such that any $\Theta,\Theta^{\prime}\in\Omega$ with $p_{a}(\Theta)=p_{a}(\Theta^{\prime})$ for all $a\in[d]$ satisfy $\text{MHLA}_{\Theta}=\text{MHLA}_{\Theta^{\prime}}$. Furthermore, if $\lambda_{\min}\left(\mathbb{E}_{D}\left[\mathcal{H}(Z)\mathcal{H}(Z)^{T}\right]\right)>0$, then OLS returns a unique solution for $p_{a}(\Theta)$. Since the data is realizable, we conclude $p_{a}(\Theta)=p_{a}(\bar{\Theta})$ for all $\Theta\in\Omega_{\text{ERM}}$. ∎
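As a numerical sanity check of the identity in Equation (45), the following sketch (our own code, with illustrative names) builds $p_a(\Theta)$ and $\mathcal{H}(Z)$ for random parameters and verifies that their inner product reproduces the attention output.

```python
import numpy as np
from itertools import combinations

# Illustrative sketch of the feature-space identity; not the released code.
def p_and_H(Vs, Qs, Z):
    # parameter features p_a(Theta) and data features H(Z) from the proof above
    d = Z.shape[0]
    G, z_last = Z @ Z.T, Z[:, -1]
    pairs = list(combinations(range(d), 2))          # unordered pairs {j, k}, j < k
    H = np.array([G[j, k] * z_last[l] for j, k in pairs for l in range(d)]
                 + [G[j, j] * z_last[l] for j in range(d) for l in range(d)])
    P = np.zeros((d, len(H)))
    for a in range(d):
        off = [sum(V[a, j] * Q[k, l] + V[a, k] * Q[j, l] for V, Q in zip(Vs, Qs))
               for j, k in pairs for l in range(d)]
        diag = [sum(V[a, j] * Q[j, l] for V, Q in zip(Vs, Qs))
                for j in range(d) for l in range(d)]
        P[a] = off + diag
    return P, H

# sanity check: the attention output equals the inner product in feature space
rng = np.random.default_rng(1)
d, n, num_heads = 3, 5, 2
Vs = [rng.standard_normal((d, d)) for _ in range(num_heads)]
Qs = [rng.standard_normal((d, d)) for _ in range(num_heads)]
Z = rng.standard_normal((d, n))
out = sum(V @ Z @ (Z.T @ (Q @ Z[:, -1])) for V, Q in zip(Vs, Qs))
P, H = p_and_H(Vs, Qs, Z)
assert np.allclose(out, P @ H)
```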

Next we present the proof that realizability is not necessary to identify the function learned by an MHLA with more than $d^{2}$ heads. See Lemma 4.3.

Proof.

We know from the construction underlying the main algorithm (Theorem 2.1) that the map $p_{a}$ taking $\Theta\in\Omega$ to $v\in\mathbb{R}^{\psi}$ is surjective. This implies that for every $v\in\mathbb{R}^{\psi}$ there exists a right inverse $p^{r}(v)=\Theta$ satisfying $p(\Theta)=v$, given by the SVD. Therefore, $p(\Theta_{\text{ERM}})\in v_{\text{ERM}}$, i.e., optimizing over $v\in\mathbb{R}^{\psi}$ does no better than optimizing over $\Theta\in\Omega$. To see this, suppose for contradiction that there exists $v^{\prime}\in v_{\text{ERM}}$ such that no $\Theta\in\Omega$ achieves the same empirical risk as $v^{\prime}$. However, $p^{r}(v^{\prime})\in\Omega$ is exactly such a $\Theta$, and we have a contradiction. The key point is that we avoid the assumption of realizability and replace it with surjectivity of the polynomials $p_{a}$. ∎

Finally, we prove that data drawn from independent noise is certifiably identifiable. A subtlety in the proof is that we use a somewhat different set of polynomials than in Lemma 4.1, as we center and normalize our features; this still satisfies the assumptions of the general certificate, Lemma B.1. See Lemma 4.2.

Proof.

We adopt the following naming convention for the entries of $\Lambda(Z)$. Entries indexed by $\{j,k\}\{\ell\}$ and $\{j^{\prime},k^{\prime}\}\{\ell^{\prime}\}$ with $j\neq k$ and $j^{\prime}\neq k^{\prime}$ are referred to as “pairs”; entries that involve $\{j^{2}\}\{\ell\}$ and $\{j^{\prime 2}\}\{\ell^{\prime}\}$ are referred to as “singles”.

𝔼[b(Z){j,k}{}b(Z){j,k}{}]=1n𝔼[zj:,zk:zj:,zk:zbzb]\mathbb{E}\left[\mathcal{H}_{b}(Z)_{\{j,k\}\{\ell\}}\mathcal{H}_{b}(Z)_{\{j^{\prime},k^{\prime}\}\{\ell^{\prime}\}}\right]=\frac{1}{n}\mathbb{E}\left[\langle z_{j:},z_{k:}\rangle\langle z_{j^{\prime}:},z_{k^{\prime}:}\rangle z_{\ell b}z_{\ell^{\prime}b}\right] (46)

We give entries of the following form the name ”singles to singles”

𝔼[b(Z){j2}{}b(Z){j2}{}]=1n𝔼[(zj:2nm2)(zj:2nm2)zb2]\mathbb{E}\left[\mathcal{H}_{b}(Z)_{\{j^{2}\}\{\ell\}}\mathcal{H}_{b}(Z)_{\{j^{\prime 2}\}\{\ell^{\prime}\}}\right]=\frac{1}{n}\mathbb{E}[(\|z_{j:}\|^{2}-nm_{2})(\|z_{j^{\prime}:}\|^{2}-nm_{2})z_{\ell b}^{2}] (47)

For the case where each entry of $Z$ is drawn i.i.d. from $\nu$, we proceed by case analysis.

Case 1: Pairs to Pairs, jkj\neq k and jkj^{\prime}\neq k^{\prime}

  1. 1.

    Subcase 1: {j,k}{j,k}\{j,k\}\neq\{j^{\prime},k^{\prime}\} and =\ell=\ell^{\prime}:

    1n𝔼[zj:,zk:zj:,zk:zbzb]=0\frac{1}{n}\mathbb{E}[\langle z_{j:},z_{k:}\rangle\langle z_{j^{\prime}:},z_{k^{\prime}:}\rangle z_{\ell b}z_{\ell^{\prime}b}]=0 (48)
  2. 2.

    Subcase 2: {j,k}={j,k}\{j,k\}=\{j^{\prime},k^{\prime}\} and =\ell=\ell^{\prime}:

    1n𝔼[zj:,zk:2zb2]=m23\frac{1}{n}\mathbb{E}[\langle z_{j:},z_{k:}\rangle^{2}z_{\ell b}^{2}]=m_{2}^{3} (49)

Case 2: Singles to Singles, j=kj=k and j=kj^{\prime}=k^{\prime}

  1. 1.

    Subcase 1: jjj\neq j^{\prime} and =\ell=\ell^{\prime}:

    1n𝔼[(zj:2nm2)(zj:2nm2)zb2]=0\frac{1}{n}\mathbb{E}\left[\left(\|z_{j:}\|^{2}-nm_{2}\right)\left(\|z_{j^{\prime}:}\|^{2}-nm_{2}\right)z_{\ell b}^{2}\right]=0 (50)
  2. 2.

    Subcase 2: j=jj=j^{\prime} and =\ell=\ell^{\prime}:

    1n𝔼[(zj:2nm2)2zb2]=1n((n2n)m22+nm4n2m22)m2=(m4m22)m2\frac{1}{n}\mathbb{E}\left[\left(\|z_{j:}\|^{2}-nm_{2}\right)^{2}z_{\ell b}^{2}\right]=\frac{1}{n}\left((n^{2}-n)m_{2}^{2}+nm_{4}-n^{2}m_{2}^{2}\right)m_{2}=(m_{4}-m_{2}^{2})m_{2} (51)

Case 3: Singles to Pairs, j=kj=k and jkj^{\prime}\neq k^{\prime}

  1. 1.

    Subcase 1: =\ell=\ell^{\prime}:

    1n𝔼[(zj:2nm2)zj:,zk:zb2]=0\frac{1}{n}\mathbb{E}\left[\left(\|z_{j:}\|^{2}-nm_{2}\right)\langle z_{j^{\prime}:},z_{k^{\prime}:}\rangle z_{\ell b}^{2}\right]=0 (52)

Finally, for the feature $\mathcal{H}(Z)_{\ell b}=m_{2}z_{\ell b}$, the corresponding diagonal entry is $\mathbb{E}[m_{2}^{2}z_{\ell b}^{2}]=m_{2}^{3}$, and the off-diagonal entries are 0.

We have therefore shown that $\Lambda(Z)$ is block diagonal, because the $\ell\neq\ell^{\prime}$ blocks are zero. All that remains is to verify that the diagonal blocks are full rank.

  1. 1.

    Pairs to Pairs: m23Im_{2}^{3}I is full rank with min eigenvalue m23m_{2}^{3}

  2. 2.

    Singles to Singles: (m4m22)m2I(m_{4}-m_{2}^{2})m_{2}I is full rank with min eigenvalue (m4m22)m2(m_{4}-m_{2}^{2})m_{2}.

Finally, we provide a simple error bound for approximate empirical risk minimizers to demonstrate the robustness of the conclusions of Lemma 4.1. See Lemma 4.4.

Proof.
MHLAΘ(Z)MHLAΘ(Z)2=a[d](pa(Θ)pa(Θ),(Z))2a[d]pa(Θ)pa(Θ)2(Z)2(a[d]pa(Θ)pa(Θ)2)ZF6ϵλmin(ΛD)ZF6\|\text{MHLA}_{\Theta}(Z)-\text{MHLA}_{\Theta^{\prime}}(Z)\|^{2}=\sum_{a\in[d]}\left(\langle p_{a}(\Theta)-p_{a}(\Theta^{\prime}),\mathcal{H}(Z)\rangle\right)^{2}\\ \leq\sum_{a\in[d]}\|p_{a}(\Theta)-p_{a}(\Theta^{\prime})\|^{2}\|\mathcal{H}(Z)\|^{2}\\ \leq\left(\sum_{a\in[d]}\|p_{a}(\Theta)-p_{a}(\Theta^{\prime})\|^{2}\right)\|Z\|_{F}^{6}\\ \leq\frac{\epsilon}{\lambda_{min}\left(\Lambda_{D}\right)}\|Z\|_{F}^{6}\\ (53)

Here the first equality follows from the linearization exhibited in Lemma B.1. The first inequality is Cauchy-Schwarz. The second inequality uses the crude upper bound that $\|\mathcal{H}(Z)\|^{2}$ involves only degree-6 polynomials, each a product of three squared entries of $Z$:

(Z)2a,a,a′′[d]b,b,b′′[n]Zab2Zab2Za′′b′′2ZF6\|\mathcal{H}(Z)\|^{2}\leq\sum_{a,a^{\prime},a^{\prime\prime}\in[d]\text{, }b,b^{\prime},b^{\prime\prime}\in[n]}Z_{ab}^{2}Z_{a^{\prime}b^{\prime}}^{2}Z_{a^{\prime\prime}b^{\prime\prime}}^{2}\leq\|Z\|_{F}^{6} (54)

The last inequality comes from the fact that $\Theta,\Theta^{\prime}$ are $\epsilon$-approximate empirical risk minimizers, from which we know

λmin(ΛD)a[d]pa(Θ)pa(Θ)2a[d](pa(Θ)pa(Θ),(Z))2ϵ\lambda_{min}(\Lambda_{D})\sum_{a\in[d]}\|p_{a}(\Theta)-p_{a}(\Theta^{\prime})\|^{2}\leq\sum_{a\in[d]}\left(\langle p_{a}(\Theta)-p_{a}(\Theta^{\prime}),\mathcal{H}(Z)\rangle\right)^{2}\leq\epsilon (55)

which implies

a[d]pa(Θ)pa(Θ)2ϵλmin(ΛD)\sum_{a\in[d]}\|p_{a}(\Theta)-p_{a}(\Theta^{\prime})\|^{2}\leq\frac{\epsilon}{\lambda_{min}(\Lambda_{D})} (56)

which concludes the proof. ∎

Lemma B.2 (Identifiability with Error and Noise in Realizability).

Let $D=\{(Z_{i},y_{i})\}_{i\in[N]}$ be a dataset such that $y_{i}=\text{MHLA}_{\bar{\Theta}}(Z_{i})+\zeta_{i}$ for some ground-truth parameters $\bar{\Theta}\in\Omega_{H}$, where the $\zeta_{i}$ are i.i.d. and bounded. Let $\Omega_{\epsilon-\text{ERM}}$ be the set of $\epsilon$-approximate empirical risk minimizers.

ΩϵERM={ΘΩH|𝔼(Zi,yi)D[(MHLAΘ(Zi)yi)2]ϵ}.\Omega_{\epsilon-\text{ERM}}=\left\{\Theta\in\Omega_{H}~{}\big{|}~{}\mathbb{E}_{(Z_{i},y_{i})\in D}\left[\left(\text{MHLA}_{\Theta}(Z_{i})-y_{i}\right)^{2}\right]\leq\epsilon\right\}. (57)

Let maxi[N]ZiFB\max_{i\in[N]}\|Z_{i}\|_{F}\leq B . Then we have for any Θ,ΘΩϵERM\Theta,\Theta^{\prime}\in\Omega_{\epsilon-\text{ERM}} that for all inputs Zd×nZ\in\mathbb{R}^{d\times n}

MHLAΘ(Z)MHLAΘ(Z)ϵ1Ni[N]ζi2+B2Nlog(δ1)λmin(ΛD)ZF6.\|\text{MHLA}_{\Theta}(Z)-\text{MHLA}_{\Theta^{\prime}}(Z)\|\leq\frac{\epsilon-\frac{1}{N}\sum_{i\in[N]}\zeta_{i}^{2}+\frac{B^{2}}{N}\log(\delta^{-1})}{\lambda_{\min}\left(\Lambda_{D}\right)}\|Z\|_{F}^{6}. (58)
Proof.

The proof follows directly from Lemma 4.4 but we incorporate the ζi\zeta_{i} terms as is standard in analyses of linear regression. ∎

Appendix C Programs Expressible as Fixed Depth Linear Transformer

In this section we give examples of programs that can be expressed as fixed-depth linear Transformers. Expressibility results can be derived in a variety of equivalent ways. The main takeaway is that the computation history of a TM $M$ on a word $x$, when written down step by step, can be captured by next-token prediction with linear attention. This is because the key-query-value structure naturally implements a table lookup, sometimes referred to as “associative memory” or, in the linear case, “in-context linear regression”.

The notion of an Autoregressive MHLA Program is useful for condensing the proofs of expressibility. We write such programs in an object-oriented syntax, with each token representing an object with multiple attributes. Attributes can be updated and looked up from other objects using a generalized lookup akin to associative memory.

Algorithm 3 Autoregressive MHLA Program
1:  Instantiate N instances OBJ = {obj(i)}i[N]\{\text{obj(i)}\}_{i\in[N]} of Class with set of Attributes {Attr1,Attr2,,Attrk}\{\text{Attr}_{1},\text{Attr}_{2},...,\text{Attr}_{k}\}
2:  Each Attribute takes on values in an alphabet ΣAttribute\Sigma_{\text{Attribute}}
3:  for iter \in [T] do
4:     Let obj[r]\text{obj}[r] be the rightmost token
5:     Let obj[r+1]\text{obj}[r+1] be a new token initialized with positional embedding obj[r+1].pos=r+1\text{obj}[r+1].\text{pos}=r+1
6:     for each {AttrSource, AttrDest} in {Pairs of Attributes in Class} do
7:        # AttrKey and AttrValue can be any pair of Attributes (and can be distinct from AttrSource/AttrDest)
8:        LookupDict={{obj.AttrKey: obj.AttrValue} for obj in OBJ}\text{LookupDict}=\{\{\text{obj.AttrKey: obj.AttrValue}\}\text{ for obj in OBJ}\}
9:        # if multiple objects have same obj.AttrKey then returns sum of obj.AttrValues which we aim to avoid
10:        Let Q\mathcal{B}_{Q} be any function from ΣAttrSource\Sigma_{\text{AttrSource}} to ΣAttrKey\Sigma_{\text{AttrKey}}
11:        Let V\mathcal{B}_{V} be any function from ΣAttrValue\Sigma_{\text{AttrValue}} to ΣAttrDest\Sigma_{\text{AttrDest}}
12:        Let query = Q\mathcal{B}_{Q}(obj[r].AttrSource)
13:        if query in LookupDict.Keys then
14:           obj[r+1].AttrDest = V\mathcal{B}_{V}(LookupDict(query))
15:        end if
16:     end for
17:     Append next token OBJ={obj[i]}i[r]{obj[r+1]}\text{OBJ}=\{\text{obj}[i]\}_{i\in[r]}\cup\{\text{obj}[r+1]\}
18:     r=r+1r=r+1
19:  end for

Lemma C.1.

For any program $\mathcal{P}$ written in the form of Algorithm 3, there exist corresponding MHLA parameters $\Theta\in\Omega_{H}$ such that $\text{MHLA}_{\Theta}(Z)=\mathcal{P}(Z)$.

Proof.

We construct matrices that implement lookup tables. For any function $f:A\rightarrow B$ between finite sets $A$ and $B$, there is a canonical representation of the input domain as orthogonal unit vectors $v_{1},v_{2},...,v_{|A|}\in\mathbb{R}^{|A|}$ and of the output domain as another set of orthogonal unit vectors $u_{1},u_{2},...,u_{|B|}\in\mathbb{R}^{|B|}$. Therefore, there is a matrix $G_{f}$ that maps input vectors to output vectors, satisfying $G_{f}v_{i}=u_{j}$ for $j=f(i)$, for all $i\in[|A|]$.

For functions f:ΣAttrSourceΣAttrKeyf:\Sigma_{\text{AttrSource}}\rightarrow\Sigma_{\text{AttrKey}} and f:ΣAttrValueΣAttrDestf^{\prime}:\Sigma_{\text{AttrValue}}\rightarrow\Sigma_{\text{AttrDest}} we associate matrices BQ|ΣAttrSource|×|ΣAttrKey|B_{Q}\in\mathbb{R}^{|\Sigma_{\text{AttrSource}}|\times|\Sigma_{\text{AttrKey}}|} and BV|ΣAttrValue|×|ΣAttrDest|B_{V}\in\mathbb{R}^{|\Sigma_{\text{AttrValue}}|\times|\Sigma_{\text{AttrDest}}|} respectively.

Then we form $\{V_{h},Q_{h}\}_{h\in[H]}$ as follows. Let $Q$ be the matrix that is all zeros except for $B_{Q}$ in the rows associated with $\Sigma_{\text{AttrSource}}$ and the columns associated with $\Sigma_{\text{AttrKey}}$. Let $V$ be the matrix that is all zeros except for $B_{V}$ in the rows associated with $\Sigma_{\text{AttrValue}}$ and the columns associated with $\Sigma_{\text{AttrDest}}$.

In each layer we have multiple heads, each one performs the lookup operation for each pair of attributes in the class. ∎
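As a concrete illustration of the lookup matrix $G_f$ used in this proof, here is a toy sketch (ours), with standard basis vectors standing in for the orthogonal embeddings.

```python
import numpy as np

# Illustrative sketch of the lookup-table matrix G_f; not part of the paper's code.
def lookup_matrix(f, in_size, out_size):
    # G_f maps the i-th standard basis vector to the f(i)-th standard basis vector
    G = np.zeros((out_size, in_size))
    for i in range(in_size):
        G[f(i), i] = 1.0
    return G

# example: a 3-state transition on a fixed input symbol, f(state) = next state
f = {0: 1, 1: 2, 2: 0}.__getitem__
G = lookup_matrix(f, 3, 3)
assert (G @ np.eye(3)[1] == np.eye(3)[2]).all()   # G e_1 = e_{f(1)}
```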

C.1 Construction of UTM

Now we proceed with our construction of an Autoregressive MHLA-Program for the UTM, which requires only a small number of primitive operations.

We define an embedding function that takes as input a TM $M$ and a word $x$, as follows.

Definition (Embedding).

Let MM be a TM over state space QQ, alphabet AA, transition function δ\delta. Then

Embedding(M)=[q0q1qk#a0a0a0#δ(q0,a0)δ(q1,a0)δ(qk,a0)#a1a1a1#δ(q0,a1)δ(q1,a1)δ(qk,a1)#]\text{Embedding}(M)=\begin{bmatrix}q_{0}&q_{1}&\cdots&q_{k}&\#\\ a_{0}&a_{0}&\cdots&a_{0}&\#\\ \delta(q_{0},a_{0})&\delta(q_{1},a_{0})&\cdots&\delta(q_{k},a_{0})&\#\\ a_{1}&a_{1}&\cdots&a_{1}&\#\\ \delta(q_{0},a_{1})&\delta(q_{1},a_{1})&\cdots&\delta(q_{k},a_{1})&\#\\ \end{bmatrix} (59)

Let $p_{1},p_{2},...,p_{\delta}$ be “positional encodings” that assign a unique ID to every letter of the word $x$.

Embedding(x)=[p1p2pipi+1pδ#x1x2xixi+1xδ#00q00#]\text{Embedding}(x)=\begin{bmatrix}p_{1}&p_{2}&\cdots&p_{i}&p_{i+1}&\cdots&p_{\delta}&\#\\ x_{1}&x_{2}&\cdots&x_{i}&x_{i+1}&\cdots&x_{\delta}&\#\\ 0&0&\cdots&q&0&\cdots&0&\#\end{bmatrix} (60)

Then we define Embedding(M,x) to be

Embedding(M,x)=[Embedding(M)00Embedding(x)]\text{Embedding(M,x)}=\begin{bmatrix}\text{Embedding}(M)&0\\ 0&\text{Embedding}(x)\end{bmatrix} (61)

Henceforth we will write the construction in the syntax of an Autoregressive MHLA-Program instead of matrices with blocks of zeros and token embeddings to save space.

See Lemma 2.2.

The construction is given in the language of Autoregressive MHLA-Programs in Algorithm 6, which provides the instruction set for writing the next letter of the computation history onto the output tape.

Proof.

Proof Idea: A few elementary operations, which can be composed to output the computation history of $M$ on $x$, can each be captured by an MHLA-program. We begin by introducing some notation for the “Lookup” operation, which we then build into Copy, Move, and If-Then; these are all the operations required to construct the UTM.

General Lookup: Each lookup involves three objects. Let Token $=\text{obj}[r]$ be the “source”, which is always the rightmost token. An attribute of the source object, known as AttrSource, is linearly transformed to form a “query”. Lookup involves a table $T=\{\text{obj[i].AttrKey: obj[i].AttrValue}\}_{i\in[r]}$, which is used to match an AttrKey and look up an AttrValue from an object $\text{obj}[p]$ that we call the “target”. Note that if obj[i] has an AttrKey equal to zero, this is the same as not being in the table; in the pseudocode of Algorithm 6, these zero attributes are denoted “None”.

Given a query, we copy the associated AttrValue from the lookup table $T$ and update AttrDest in an object NextToken $=\text{obj}[r+1]$, which we call the “destination”. Multiple lookup operations can be performed in parallel by multiple heads, with each head responsible for a single lookup.

To output each letter of the computation history, we increase the number of tokens $r$ by a constant $c$. We refer to the set of contiguous tokens $[0,c],[c,2c],$ etc. involved in the computation of a single letter as a “block”; here block[i] $=\{\text{obj}[j]\}_{j\in[ic,(i+1)c]}$. We construct a different set of heads to act on each token and enforce that the nonzero rows occupied by each block of tokens are disjoint. Furthermore, within a block, the states of the tokens occupy disjoint sets of rows, except when they are used to construct a table; tables are the only case where we want tokens to occupy the same rows. In this manner the following abstraction can be made.

At the beginning of each block starting with obj[r], we can lookup attributes from anywhere in OBJ that we want to load into different attributes in obj[r]. Then we can apply any sequence of if-then statements involving the attributes of obj[r] to update the attributes (or create new attributes). To run the UTM we need a few simple primitives denoted Lookup and If-Then.

Construction of Primitives: We write down the construction by constructing a sufficient set of primitives Lookup and If-Then. We also include Copy which is a special case of Lookup that is used frequently.

Lookup:

When the transforms BQB_{Q} and BVB_{V} are the identity we denote the lookup operation for table TT where we query an attribute ss^{\prime} of obj[r]\text{obj}[r] to update the attribute ss of obj[r+1] as obj[r+1].s=Lookup(T,obj[r].s’)\text{obj[r+1].s}=\text{Lookup(T,obj[r].s')}

Copy:

A special case of lookup is copy, where we need to copy attributes from tokens that are at an offset k-k for k[r]k\in[r]. This can be done by setting Q\mathcal{B}_{Q} to permute the positional encoding by k-k positions. Then the query matches the key that is the positional encoding of the target object. Let s,ss,s^{\prime} be target and destination attributes. We denote the copy operation of the attribute ss^{\prime} of the obj at offset k-k from rr into the attribute ss of the destination object to be obj[r+1].s=Copy(obj[r-k].s’)\text{obj[r+1].s}=\text{Copy}(\text{obj[r-k].s'}).

If-Then:

We write down an If-Then Program (Algorithm 4) and a corresponding Autoregressive MHLA-Program (Algorithm 5) to implement If-Then. An If-Then program checks whether an attribute $x$ is equal to any of the attributes $a_{1},a_{2},...,a_{k}$; if $x=a_{i}$, it sets attribute $x^{\prime}$ to $b_{i}$. This is achieved by copying the attributes $a_{i}$ and $b_{i}$ into dummy attributes s0 and s1, for all $i\in[k]$, across a series of $k$ consecutive tokens. This creates a table with key s0 and value s1. We then use attribute $x$ as the query, which looks up the corresponding value s1, which we use to update the attribute $x^{\prime}$.

Algorithm 4 If-Then Program
1:  # If attribute x is equal to any of a1,a2,,aka_{1},a_{2},...,a_{k} then set attribute xx^{\prime} to b1,b2,,bkb_{1},b_{2},...,b_{k} respectively
2:  if Token.x == Token.a1a_{1}:  then
3:     NextToken.x’ = Token.b1b_{1}
4:  end if
5:  if Token.x == Token.a2a_{2}:  then
6:     NextToken.x’ = Token.b2b_{2}
7:  end if
8:  \ldots
9:  if Token.x == Token.aka_{k}:  then
10:     NextToken.x’ = Token.bkb_{k}
11:  end if
Algorithm 5 MHLA If-Then Program
1:  # If attribute x is equal to any of a1,a2,,aka_{1},a_{2},...,a_{k} then set attribute xx^{\prime} to b1,b2,,bkb_{1},b_{2},...,b_{k} respectively
2:  token[r+1].s0 = token[r].a1a_{1}
3:  token[r+1].s1 = token[r].b1b_{1}
4:  NEXT TOKEN r=r+1r=r+1
5:  token[r+1].s0 = token[r].a2a_{2}
6:  token[r+1].s1 = token[r].b2b_{2}
7:  \ldots
8:  NEXT TOKEN r=r+1r=r+1
9:  token[r+1].s0 = token[r].aka_{k}
10:  token[r+1].s1 = token[r].bkb_{k}
11:  NEXT TOKEN r=r+1r=r+1
12:  Table T = {obj[i].s0 : obj[i].s1}i[r,rk+1]\{\text{obj[i].s0 : obj[i].s1}\}_{i\in[r,r-k+1]}
13:  token[r+1].x’ = Lookup(T,token[r].x)

Algorithm 6 Simplified Instruction Set MHLA Program for UTM for a single block
1:  # Initialize Lookup Tables for TM M and tape T1T_{1}
2:  # δ(q,a)=[next-state, next-letter, next-move]\delta(q,a)=[\text{next-state, next-letter, next-move}]
3:  M = {q:[a0,δ(q,a0),a1,δ(q,a1)]}qQ\{q:[a_{0},\delta(q,a_{0}),a_{1},\delta(q,a_{1})]\}_{q\in Q}
4:  T1T_{1} = {token[i].PosEncoding: token[i].Letter}i[r]\{\text{token[i].PosEncoding: token[i].Letter}\}_{i\in[r]}
5:  # Begin Loading Information from M and previous tokens on tape
6:  # First copy letter/state from token -N-1 positions away
7:  # Attribute s(-1) = {letter, state} where state can be equal to None
8:  NextToken.s(-1) = Copy(Token[-N-1].s0)
9:  # Second copy letter/state from token -N positions away
10:  # Attribute s0 = {letter, state} where state can be equal to None
11:  NextToken.s0 = Copy(Token[-N].s0)
12:  # Third copy letter/state from token -N+1 positions away
13:  # Attribute s1 = {letter, state} where state can be equal to None
14:  NextToken.s1 = Copy(Token[-N+1].s0)
15:  NEXT TOKEN r=r+1r=r+1
16:  #Split into three branches to handle left, head, and right positions relative to head
17:  RUN BRANCH 1 (Token is Left of Head Position) See algorithm 7
18:  RUN BRANCH 2 (Token is at Head Position) See algorithm 7
19:  RUN BRANCH 3 (Token is Right of Head Position) See algorithm 7
Algorithm 7 Branches to handle cases Left of Head, Head, and Right of Head
1:  #Split into three branches to handle left, head, and right positions relative to head
2:  BRANCH 1 (Token is Left of Head Position)
3:  # we have loaded a state q into s1 (if left of head) and next we load [a0,δ(q,a0),a1,δ(q,a1)][a_{0},\delta(q,a_{0}),a_{1},\delta(q,a_{1})] into s2
4:  NextToken.s2 = Lookup(M,Token.s1.state)
5:  NEXT TOKEN r=r+3r=r+3
6:  if Token.s2.letter==a0\text{Token.s2.letter}==a_{0} then
7:     NextToken.s3 = δ(q,a0)\delta(q,a_{0}) = [q’,w’,L/R]
8:  end if
9:  if Token.s2.letter==a1\text{Token.s2.letter}==a_{1} then
10:     NextToken.s3 = δ(q,a1)\delta(q,a_{1}) = [q’,w’,L/R]
11:  end if
12:  NEXT TOKEN r = r+3
13:  if Token.s3.move == L then
14:     NextToken.return-letter = Token.s0.letter
15:     NextToken.return-state = q’
16:  end if
17:  if Token.s3.move == R then
18:     NextToken.return-letter = Token.s0.letter
19:     NextToken.return-state = None
20:  end if
21:  BRANCH 2 (Token is at Head Position)
22:  # we have loaded a state q into s0 and next we load [a0,δ(q,a0),a1,δ(q,a1)][a_{0},\delta(q,a_{0}),a_{1},\delta(q,a_{1})] into s2
23:  NextToken.s2 = Lookup(M,Token.s0.state)
24:  NEXT TOKEN r = r+3
25:  if Token.s2.letter==a0\text{Token.s2.letter}==a_{0} then
26:     NextToken.s3 = δ(q,a0)\delta(q,a_{0}) = [q’,w’,L/R]
27:  end if
28:  if Token.s2.letter==a1\text{Token.s2.letter}==a_{1} then
29:     NextToken.s3 = δ(q,a1)\delta(q,a_{1}) = [q’,w’,L/R]
30:  end if
31:  NEXT TOKEN r = r+3
32:  if Token.s3.next-letter is not None then
33:     NextToken.return-letter = Token.s3.next-letter
34:     NextToken.return-state = None
35:  end if
36:  BRANCH 3 (Token is Right of Head Position)
37:  # we have loaded a state q into s(-1) and next we load [a0,δ(q,a0),a1,δ(q,a1)][a_{0},\delta(q,a_{0}),a_{1},\delta(q,a_{1})] into s2
38:  NextToken.s2 = Lookup(M,Token.s(-1).state)
39:  NEXT TOKEN r = r+3
40:  if Token.s2.letter==a0\text{Token.s2.letter}==a_{0} then
41:     NextToken.s3 = δ(q,a0)\delta(q,a_{0}) = [q’,w’,L/R]
42:  end if
43:  if Token.s2.letter==a1\text{Token.s2.letter}==a_{1} then
44:     NextToken.s3 = δ(q,a1)\delta(q,a_{1}) = [q’,w’,L/R]
45:  end if
46:  NEXT TOKEN r = r+3
47:  if Token.s3.move == L then
48:     NextToken.return-letter = Token.s0.letter
49:     NextToken.return-state = None
50:  end if
51:  if Token.s3.move == R then
52:     NextToken.return-letter = Token.s0.letter
53:     NextToken.return-state = Token.s3.next-state
54:  end if

C.2 Proofs For Learning UTM

See 5.1

Corollary 3.

In particular, for sample complexity $N=\text{poly}(d,\epsilon^{-1},\log(\delta^{-1}),n,t)$, by Lemma 2.2, we have with probability $1-\delta$ over the randomness in the data that the probability that the $c(t)$-th step of the computation history of $\text{MHLA}_{\hat{\Theta}}$ equals $x^{t}$ is

Pr(M,x)𝒟[CHΘ^(M,x)c(t)[:kt]=xt]1ϵ,\Pr\nolimits_{(M,x)\sim\mathcal{D}}\left[\text{CH}_{\hat{\Theta}}(M,x)^{c(t)}[:-k_{t}]=x^{t}\right]\geq 1-\epsilon, (62)

where $c(t)\leq O((n+t)t)$. That is, with high probability the computation history produced by the MHLA returned by Algorithm 1 equals the computation history of $M$ on $x$.

Proof.

We have from Theorem 2.1 that Algorithm 1 returns $\hat{\Theta}$ such that

𝔼(Z,y)𝒟[(MHLAΘ^(Z)y)2]minΘΩH𝔼(Z,y)𝒟[(MHLAΘ(Z)y)2]ϵ\mathbb{E}_{(Z,y)\in\mathcal{D}}\left[\left(\text{MHLA}_{\hat{\Theta}}(Z)-y\right)^{2}\right]-\min_{\Theta\in\Omega_{H}}\mathbb{E}_{(Z,y)\in\mathcal{D}}\left[\left(\text{MHLA}_{\Theta}(Z)-y\right)^{2}\right]\leq\epsilon (63)

To obtain an error bound on the $\Phi$-step computation history, which involves $O(n\Phi)$ tokens, we observe that by a union bound each step rounds to an incorrect set of tokens with probability less than $\epsilon$. Therefore, over $O(\Phi)$ steps the error probability is upper bounded by $O(\epsilon\Phi)$. Equivalently,

Pr(M,x)𝒟[CHΘ^(M,x)CHΘ(M,x)]O(ϵΦ).\Pr_{(M,x)\sim\mathcal{D}}[\text{CH}_{\hat{\Theta}}(M,x)\neq\text{CH}_{\Theta}(M,x)]\leq O(\epsilon\Phi). (64)

Corollary 3 then follows: for a larger sample complexity $N=\text{poly}(d,\epsilon^{-1},\log(\delta^{-1}),n,t)$, Lemma 2.2 gives that every token of the autoregressive computation history of $\text{MHLA}_{\hat{\Theta}}$ matches the corresponding step $x^{t}$ with high probability:

Pr(M,x)𝒟[CHΘ^(M,x)c(t)[:kt]=xt]1ϵ\Pr_{(M,x)\sim\mathcal{D}}\left[\text{CH}_{\hat{\Theta}}(M,x)^{c(t)}[:-k_{t}]=x^{t}\right]\geq 1-\epsilon (65)

See 5.2

Proof.

The proof follows from the quantitative version of Lemma 4.4. Using the assumption that $\lambda_{\min}(\Lambda_{D})>\eta$, we conclude that for any $\hat{\Theta}\in\Omega_{\epsilon-\text{ERM}}$ and all inputs $Z\in\mathbb{R}^{d\times n}$,

MHLAΘ^(Z)MHLAΘ(Z)ϵηZF6.\|\text{MHLA}_{\hat{\Theta}}(Z)-\text{MHLA}_{\Theta}(Z)\|\leq\frac{\epsilon}{\eta}\|Z\|_{F}^{6}. (66)

If we select a sufficiently small ϵ=1/poly(d,N,|Q|,|Σ|,n,t,η1)\epsilon=1/\text{poly}(d,N,|Q|,|\Sigma|,n,t,\eta^{-1}) then we can ensure

\Pr_{(M,x)\sim\mathcal{D}}\left[\text{CH}_{\hat{\Theta}}(M,x)^{c(t)}[:-k_{t}]=x^{t}\right]\geq 1-\epsilon. (67)

The runtime then scales with poly(d,N,|Q|,|Σ|,n,t,η1)\text{poly}(d,N,|Q|,|\Sigma|,n,t,\eta^{-1}) as desired. ∎

Appendix D Additional Definitions

Definition (Orthogonal Embeddings).

Let $\Sigma$ be an alphabet and let $e_{1},e_{2},...,e_{|\Sigma|}\in\mathbb{R}^{|\Sigma|}$ be a basis of orthogonal unit vectors. The embedding function $\text{Embed}:\Sigma\rightarrow\mathbb{R}^{|\Sigma|}$ assigns a distinct unit vector to each letter: $\text{Embed}(a)=e_{a}$ for each $a\in\Sigma$.

We adopt a naive “rounding” scheme for converting vectors into tokens. This can be done in a variety of ways; we choose simply to round each output vector to the nearest token embedding.

Definition (Rounding).

For any vector $v=(v_{1},v_{2},...,v_{d})\in\mathbb{R}^{d}$, let $\text{Round}(v)=e_{j}$ for $j=\operatorname*{arg\,max}_{i\in[d]}\langle v,e_{i}\rangle$. Since we use orthogonal unit vectors as token embeddings, we will refer to $\text{Round}(v)$ as a token. We will often say that a matrix $Z\in\mathbb{R}^{d\times n}$ is equivalent to a sequence of $n$ tokens $a_{1},a_{2},...,a_{n}$, meaning $\text{Round}(Z[:,i])=\text{Embed}(a_{i})$ for all $i\in[n]$.
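A minimal sketch of this rounding rule (our own code, with the identity matrix standing in for the orthogonal token embeddings):

```python
import numpy as np

# Illustrative sketch of the rounding rule; not part of the paper's code.
def round_to_token(v, token_embeddings):
    # token_embeddings: rows are the orthogonal unit vectors e_1, ..., e_d
    return int(np.argmax(token_embeddings @ v))

E = np.eye(4)                        # orthogonal token embeddings
v = np.array([0.1, 0.9, -0.2, 0.3])  # a "noisy" output vector
assert round_to_token(v, E) == 1     # rounds to the nearest token embedding, e_2 (index 1)
```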

Algorithm 8 Extract Features
1:  Input: Data D:={Zi}i[N]D\vcentcolon=\{Z_{i}\}_{i\in[N]} for Zid×niZ_{i}\in\mathbb{R}^{d\times n_{i}} and yidy_{i}\in\mathbb{R}^{d}
2:  for ZiDZ_{i}\in D do
3:     Let z1,z2,zdz_{1},z_{2},...z_{d} be the rows of ZiZ_{i} and let za,bz_{a,b} be the (a,b)(a,b) entry of ZiZ_{i}
4:     for j[d]j\in[d]  do
5:        for k[d]k\in[d] do
6:           for [d]\ell\in[d] do
7:              Let 𝒳id×d2\mathcal{X}_{i}\in\mathbb{R}^{d\times d^{2}} be defined as follows
8:              𝒳i[j,kd+]=[zj:,zk:zni]\mathcal{X}_{i}[j,kd+\ell]=\left[\langle z_{j:},z_{k:}\rangle z_{\ell n_{i}}\right]
9:           end for
10:        end for
11:     end for
12:  end for
13:  Return: {𝒳i}i[N]\{\mathcal{X}_{i}\}_{i\in[N]} such that
𝒳i:=[z1:,z1:z1niz1:,z2:z1niz1:,zd:z1niz1:,zd:zdniz2:,z1:z1niz2:,z2:z1niz2:,zd:z1niz2:,zd:zdnizd:,z1:z1nizd:,z2:z1nizd:,zd:z1nizd:,zd:zdni].\mathcal{X}_{i}\vcentcolon=\begin{bmatrix}\langle z_{1:},z_{1:}\rangle z_{1n_{i}}&\langle z_{1:},z_{2:}\rangle z_{1n_{i}}&\cdots&\langle z_{1:},z_{d:}\rangle z_{1n_{i}}&\cdots&\langle z_{1:},z_{d:}\rangle z_{dn_{i}}\\ \langle z_{2:},z_{1:}\rangle z_{1n_{i}}&\langle z_{2:},z_{2:}\rangle z_{1n_{i}}&\cdots&\langle z_{2:},z_{d:}\rangle z_{1n_{i}}&\cdots&\langle z_{2:},z_{d:}\rangle z_{dn_{i}}\\ \vdots&\vdots&\ddots&\vdots&\ddots&\vdots\\ \langle z_{d:},z_{1:}\rangle z_{1n_{i}}&\langle z_{d:},z_{2:}\rangle z_{1n_{i}}&\cdots&\langle z_{d:},z_{d:}\rangle z_{1n_{i}}&\cdots&\langle z_{d:},z_{d:}\rangle z_{dn_{i}}\\ \end{bmatrix}~{}. (68)

D.1 Training details of attention networks

We use the Adam optimizer (Kingma & Ba, 2014) to train the linear attention models of Equation 1 and the full Transformer models (Vaswani et al., 2017).

hyperparameter | search space
input dimension $d$ | [2, 4, 8, 16]
number of heads $m$ | [1, 2, 4, 8, 16]
number of layers $n$ | [1, 2, 4]
learning rate | [0.01, 0.001]
batch size | [32, 64]
optimizer | AdamW (Loshchilov & Hutter, 2018)

D.2 Training details in DFA Execution

We use the Llama variant of the Transformer architecture from Touvron et al. (2023). We run each setting with the following numbers of training examples: $N\in\{16,32,64,128,256,512,1024,2048,4096,6144,8192,12290,16384,20480,32768,65536\}$. The other hyperparameters are given in the table below.

hyperparameter | value
input dimension $d$ | 2048
number of heads $m$ | 16
number of layers $n$ | 4
learning rate | 0.00025
epochs | 100
optimizer | AdamW (Loshchilov & Hutter, 2018)