
Transformers are Provably Optimal In-context Estimators for Wireless Communications

Vishnu Teja Kunde1,∗  Vicram Rajagopalan2,∗
Chandra Shekhara Kaushik Valmeekam1,∗
Krishna Narayanan1  Srinivas Shakkottai1
Dileep Kalathil1  Jean-Francois Chamberland1
1Department of Electrical and Computer Engineering, Texas A&M University
2Department of Computer Science and Engineering, Texas A&M University
(February 2024)
Abstract

Pre-trained transformers exhibit the capability of adapting to new tasks through in-context learning (ICL), where they efficiently utilize a limited set of prompts without explicit model optimization. The canonical communication problem of estimating transmitted symbols from received observations can be modelled as an in-context learning problem: received observations are essentially a noisy function of transmitted symbols, and this function can be represented by an unknown parameter whose statistics depend on an (also unknown) latent context. This problem, which we term in-context estimation (ICE), has significantly greater complexity than the extensively studied linear regression problem. The optimal solution to the ICE problem is a non-linear function of the underlying context. In this paper, we prove that, for a subclass of such problems, a single-layer softmax attention transformer (SAT) computes the optimal solution of the above estimation problem in the limit of large prompt length. We also prove that the optimal configuration of such a transformer is indeed the minimizer of the corresponding training loss. Further, we empirically demonstrate the proficiency of multi-layer transformers in efficiently solving broader in-context estimation problems. Through extensive simulations, we show that solving ICE problems using transformers significantly outperforms standard approaches. Moreover, with just a few context examples, it achieves the same performance as an estimator with perfect knowledge of the latent context.

∗Equal contribution.

I Introduction

Recent advances in our understanding of transformers [NIPS2017_3f5ee243] have brought to the fore the notion that they are capable of in-context learning (ICL). Here, a pre-trained transformer is presented with example prompts followed by a query to be answered, of the form $(\mathbf{x}_{1},f(\mathbf{x}_{1},\theta),\ldots,\mathbf{x}_{k},f(\mathbf{x}_{k},\theta),\mathbf{x}_{k+1})$, where $\theta$ is a context common to all the prompts. The finding is that the transformer is able to respond with a good approximation to $f(\mathbf{x}_{k+1},\theta)$ for many function classes [garg2023transformers; panwar2024incontext]. The transformer itself is pre-trained, either implicitly or explicitly, over a variety of contexts and so acquires the ability to generate in-distribution outputs conditioned on a specific context.

Although the interpretability of in-context learning for solving rich classes of problems can be very challenging, there have been attempts to understand the theoretical aspects of in-context learning for simpler problems, specifically linear regression. The work [zhang2024trained] characterizes the convergence of the training dynamics of a single-layer linear attention (LA) model for solving a linear regression problem over random regression instances. It shows that this configuration of the trained transformer is optimal in the limit of large prompt length.

In this paper, we introduce an in-context estimation problem where we are presented with the sequence $(g(\mathbf{x}_{1},\theta),\mathbf{x}_{1},\ldots,g(\mathbf{x}_{k},\theta),\mathbf{x}_{k},g(\mathbf{x}_{k+1},\theta))$, where $g(\mathbf{x},\theta)$ is a stochastic function whose distribution is parameterized by $\theta$, and our goal is to estimate $\mathbf{x}_{k+1}$. This inverse problem is inherently more challenging than predicting $f(\mathbf{x}_{k+1},\theta)$ for $\mathbf{x}_{k+1}$ (as in linear regression) due to the complicated dependence of the optimal estimator on noisy observations.

The above formulation of in-context estimation models many problems of interest in engineering. However, to keep the paper focused, we consider a subclass of problems that is of particular interest to communication theory. Here, $\mathbf{x}$ represents a sequence of transmitted symbols, and the communication channel with the unknown state (context) $\theta$ maps $\mathbf{x}$ into a sequence of received symbols represented by $\mathbf{y}=f(\mathbf{x},\theta)$. The receiver is then required to recover $\mathbf{x}$ using in-context examples that can be constructed from the received symbols and transmitted pilot symbols. More details about the function $f(\mathbf{x},\theta)$ and the in-context examples can be found in Section II.

Our work is motivated by the observation that a pre-trained transformer’s ability to perform in-context sequence completion strongly suggests that it should be able to approximate the desired conditional mean estimator given the in-context examples for our inverse problem. Essentially, if a transformer is pre-trained on a variety of contexts, it should be able to implicitly determine the latent context from pilot symbols and then perform end-to-end in-context estimation of transmitted symbols. Once trained, such a transformer is simple to deploy because there are no runtime modifications, which could make it a potential building block of future wireless receivers. We indeed show this result both theoretically and empirically.

Our main contributions are:

  • We define the inverse problem of estimating symbols from noisy observations using in-context examples, which we term in-context estimation (ICE). This is more difficult than linear regression, which is a well-studied in-context learning (ICL) problem.

  • For special cases of such problems (detailed in Section III), we theoretically show that there exists a configuration of a single-layer softmax attention transformer (SAT) that is asymptotically optimal.

  • We empirically demonstrate that multi-layer transformers are capable of efficiently solving representative ICE problems with finite in-context examples.

II In-Context Estimation

In this section, we define the problem of in-context estimation, and we discuss the complexities associated with solving it.

Definition 1

(In-Context Estimation) Let $\{\mathbf{x}_{t}\}_{t=1}^{n+1}\stackrel{\rm iid}{\sim}\mathbb{P}_{x}$ be the input symbols and $\{\mathbf{z}_{t}\}_{t=1}^{n+1}\stackrel{\rm iid}{\sim}\mathbb{P}_{z}$ be the noise samples. Let $\{\mathbf{H}_{t}(\theta)\}_{t=1}^{n+1}\sim\mathbb{P}_{h\mid\theta}$ be a random process determined by a latent parameter $\theta\sim\mathbb{P}_{\Theta}$. For $t\in[n+1]$, $\mathbf{y}_{t}=\mathbf{g}(\mathbf{x}_{t},\mathbf{H}_{t}(\theta))+\mathbf{z}_{t}$. Let $\mathcal{P}^{n}_{\theta}\triangleq\{(\mathbf{y}_{t},\mathbf{x}_{t}),t\in[n]\}$ be the set of examples in the prompt until time $n$. The problem of in-context estimation (ICE) involves estimating $\mathbf{x}_{n+1}$ from $\mathbf{y}_{n+1}$ and $\mathcal{P}_{\theta}^{n}$ such that the mean squared error (MSE) $\mathbb{E}_{\theta\sim\mathbb{P}_{\Theta},\{\mathbf{H}(\theta)\}\sim\mathbb{P}_{h\mid\theta},\mathbf{z}\sim\mathbb{P}_{z},\mathbf{x}\sim\mathbb{P}_{x}}[\lVert\hat{\mathbf{x}}_{n+1}(\mathcal{P}_{\theta}^{n},\mathbf{y}_{n+1})-\mathbf{x}_{n+1}\rVert_{2}^{2}]$ is minimized.
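For concreteness, the following minimal numpy sketch generates one prompt $\mathcal{P}^{n}_{\theta}$ together with the query pair $(\mathbf{y}_{n+1},\mathbf{x}_{n+1})$. The specific choices (a QPSK-style alphabet for $\mathbb{P}_{x}$, a time-invariant complex Gaussian channel for $\mathbb{P}_{h\mid\theta}$, and the linear map $\mathbf{g}(\mathbf{x},\mathbf{H})=\mathbf{H}\mathbf{x}$) are illustrative assumptions and not the paper's experimental configuration; the function and variable names are ours.

```python
import numpy as np

def sample_prompt(n, d=4, snr_db=10.0, seed=0):
    """Generate one ICE prompt P^n_theta plus the query pair (y_{n+1}, x_{n+1}).

    Illustrative assumptions: x_t uniform over a QPSK alphabet on the unit circle,
    H_t(theta) = H(theta) fixed over the prompt with h(theta) ~ CN(0, I_d), and
    g(x, H) = H x, i.e., the linear observation model used later in Eq. (2).
    """
    rng = np.random.default_rng(seed)
    alphabet = np.exp(1j * (np.pi / 4 + np.pi / 2 * np.arange(4)))   # QPSK, |x| = 1
    x = rng.choice(alphabet, size=n + 1)                             # iid input symbols

    h = (rng.standard_normal(d) + 1j * rng.standard_normal(d)) / np.sqrt(2)  # one context

    sigma2 = 10 ** (-snr_db / 10)                                    # noise variance from SNR
    z = np.sqrt(sigma2 / 2) * (rng.standard_normal((n + 1, d))
                               + 1j * rng.standard_normal((n + 1, d)))

    y = x[:, None] * h[None, :] + z                                  # y_t = h x_t + z_t

    prompt = list(zip(y[:n], x[:n]))                                 # P^n_theta = {(y_t, x_t)}
    return prompt, y[n], x[n]                                        # query observation and symbol
```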

For the above problem, consider the following estimator. Given the past $n$ examples $\mathcal{P}_{\theta}^{n}$ from the context $\theta\in\Theta$ and the current observation $\mathbf{y}_{n+1}$, the context-unaware (CU) minimum mean squared error (MMSE) estimate of $\mathbf{x}_{n+1}$ is known to be the conditional mean estimate (CME)

$$\hat{\mathbf{x}}_{n+1}^{\rm CU}(\mathcal{P}_{\theta}^{n},\mathbf{y}_{n+1})=\mathbb{E}[\mathbf{x}_{n+1}\mid\mathcal{P}_{\theta}^{n},\mathbf{y}_{n+1}].$$ (1)

Taking a closer look at the computation of the conditional expectation in (1), we get

$$\begin{aligned}
\hat{\mathbf{x}}_{n+1}^{\rm CU}(\mathcal{P}_{\theta}^{n},\mathbf{y}_{n+1})&=\mathbb{E}[\mathbf{x}_{n+1}\mid\mathcal{P}_{\theta}^{n},\mathbf{y}_{n+1}]=\sum_{\mathbf{x}\in\mathcal{X}}\mathbf{x}\,\mathbb{P}(\mathbf{x}_{n+1}=\mathbf{x}\mid\mathcal{P}_{\theta}^{n},\mathbf{y}_{n+1})\\
&\propto\sum_{\mathbf{x}\in\mathcal{X}}\mathbf{x}\,\mathbb{P}_{x}(\mathbf{x})\,p(\mathcal{P}_{\theta}^{n},\mathbf{y}_{n+1}\mid\mathbf{x})\\
&=\sum_{\mathbf{x}\in\mathcal{X}}\mathbf{x}\,\mathbb{P}_{x}(\mathbf{x})\int_{\theta^{\prime}\in\Theta}p_{\Theta}(\theta^{\prime})\int_{\mathbf{H}_{1},\dots,\mathbf{H}_{n+1}\in\mathbb{R}^{d_{\mathbf{y}}\times d_{\mathbf{x}}}}p_{h\mid\theta^{\prime}}(\mathbf{H}_{1},\dots,\mathbf{H}_{n+1})\,p(\mathcal{P}_{\theta}^{n},\mathbf{y}_{n+1}\mid\mathbf{x},\mathbf{H}_{1},\dots,\mathbf{H}_{n+1})\,d\mathbf{H}_{1}\cdots d\mathbf{H}_{n+1}.
\end{aligned}$$

The above computation involves evaluating a multi-dimensional integral. In most practical estimation problems, computing (1) can be very difficult or even intractable.
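To illustrate why (1) is burdensome even in a simple special case, the following sketch approximates the context-unaware CME by Monte Carlo integration over the hidden channel. It is only an illustration under assumed priors (a single complex Gaussian channel absorbing the outer integral over $\theta^{\prime}$, the linear model introduced in the next subsection, and a uniform $\mathbb{P}_{x}$); the names are hypothetical and follow the conventions of the earlier sketch.

```python
import numpy as np

def cu_mmse_monte_carlo(prompt, y_query, alphabet, sigma2, num_h=5000, seed=1):
    """Naive Monte Carlo approximation of the context-unaware CME in Eq. (1).

    Illustrative assumptions: the outer integral over theta is absorbed into a single
    channel prior h ~ CN(0, I_d), the observation model is the linear one y = h x + z,
    and P_x is uniform over `alphabet`. Even here the cost scales as |X| * num_h * n,
    which hints at why the exact nested integral is impractical in richer models.
    """
    rng = np.random.default_rng(seed)
    ys = np.array([y for y, _ in prompt])              # (n, d) past observations
    xs = np.array([x for _, x in prompt])              # (n,)   past pilot symbols
    d = ys.shape[1]
    h = (rng.standard_normal((num_h, d)) + 1j * rng.standard_normal((num_h, d))) / np.sqrt(2)

    def log_lik(y, x):
        # log CN(y; h x, sigma2 I) up to an additive constant, for all channel samples
        return -np.sum(np.abs(y[None, :] - x * h) ** 2, axis=-1) / sigma2

    ll_prompt = sum(log_lik(ys[t], xs[t]) for t in range(len(prompt)))
    lls = np.stack([ll_prompt + log_lik(y_query, x) for x in alphabet])   # (|X|, num_h)
    w = np.exp(lls - lls.max())                         # one global shift keeps ratios intact
    post = w.mean(axis=1)                               # proportional to p(P^n, y_{n+1} | x)
    return (alphabet * post).sum() / post.sum()         # approximate conditional mean
```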

Why can estimation problems be more challenging than regression problems?

To distinguish these two settings, consider the linear model $\mathbf{g}:\mathbb{R}^{d_{\mathbf{x}}}\times\mathbb{R}^{d_{\mathbf{y}}\times d_{\mathbf{x}}}\to\mathbb{R}^{d_{\mathbf{y}}}$ defined by $\mathbf{g}(\mathbf{x},\mathbf{H})=\mathbf{H}\mathbf{x}$. In the linear regression setting with the relation $\mathbf{y}=\mathbf{H}\mathbf{x}$, the optimal estimate $\hat{\mathbf{y}}$ obtained from the observation $\mathbf{x}$ for a given $\mathbf{H}$ is simply the linear transform $\hat{\mathbf{y}}=\mathbf{H}\mathbf{x}=\mathbf{y}$. In other words, the optimal estimator is a linear function of the underlying context $\mathbf{H}$, and the optimal estimate has trivial (zero) error when the context $\mathbf{H}$ is perfectly known. On the other hand, the optimal estimate $\hat{\mathbf{x}}$ of $\mathbf{x}$ from the noisy observation $\mathbf{y}=\mathbf{H}\mathbf{x}+\mathbf{z}$ is not straightforward due to the presence of noise. The difficulty arises because the independent variable $\mathbf{x}$ must be estimated from a noisy dependent observation $\mathbf{y}$ in the latter setting and, hence, the optimal estimate does not have a simple form.

Motivated by canonical estimation problems in wireless communication, we restrict our attention to the problem of estimating complex symbols under an unknown transformation embedded in Gaussian noise, and we show that it naturally fits into the above framework. In this class of problems, the relationship between the input symbols and the observed symbols (in the complex domain) is captured by:

$$\tilde{\mathbf{y}}_{t}=\tilde{\mathbf{h}}_{t}(\theta)\tilde{x}_{t}+\tilde{\mathbf{z}}_{t}\in\mathbb{C}^{d},\quad t\in[n+1].$$ (2)

Writing the above as a real-matrix equation, we obtain

$$\mathbf{y}_{t}\triangleq\begin{bmatrix}\operatorname{Re}(\tilde{\mathbf{y}}_{t})\\ \operatorname{Im}(\tilde{\mathbf{y}}_{t})\end{bmatrix}=\begin{bmatrix}\operatorname{Re}(\tilde{\mathbf{h}}_{t}(\theta))&-\operatorname{Im}(\tilde{\mathbf{h}}_{t}(\theta))\\ \operatorname{Im}(\tilde{\mathbf{h}}_{t}(\theta))&\operatorname{Re}(\tilde{\mathbf{h}}_{t}(\theta))\end{bmatrix}\begin{bmatrix}\operatorname{Re}(\tilde{x}_{t})\\ \operatorname{Im}(\tilde{x}_{t})\end{bmatrix}+\begin{bmatrix}\operatorname{Re}(\tilde{\mathbf{z}}_{t})\\ \operatorname{Im}(\tilde{\mathbf{z}}_{t})\end{bmatrix}\triangleq\mathbf{H}_{t}(\theta)\mathbf{x}_{t}+\mathbf{z}_{t},$$ (3)

where $\mathbf{x}_{t}\in\mathcal{X}\subset\mathbb{R}^{2}$, $\mathbf{H}_{t}(\theta)\in\mathbb{R}^{2d\times 2}$, and $\mathbf{z}_{t},\mathbf{y}_{t}\in\mathbb{R}^{2d}$, which takes the form of the problem in Definition 1.
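The complex-to-real embedding in (3) is mechanical; a short sketch (with illustrative dimensions and our own helper names) is given below to fix the conventions used in the rest of the paper.

```python
import numpy as np

def complex_to_real_channel(h_tilde):
    """Real 2d x 2 channel matrix H of Eq. (3) from a complex vector h_tilde in C^d."""
    re, im = h_tilde.real, h_tilde.imag
    return np.block([[re[:, None], -im[:, None]],
                     [im[:, None],  re[:, None]]])

def complex_to_real_vec(v_tilde):
    """Stack real and imaginary parts: C^d -> R^{2d}."""
    return np.concatenate([v_tilde.real, v_tilde.imag])

# Sanity check that the real-valued model reproduces the complex one.
rng = np.random.default_rng(0)
d = 3
h = rng.standard_normal(d) + 1j * rng.standard_normal(d)
x = np.exp(1j * np.pi / 4)                                          # a unit-modulus symbol
z = 0.1 * (rng.standard_normal(d) + 1j * rng.standard_normal(d))
y_complex = h * x + z                                               # Eq. (2)
H = complex_to_real_channel(h)                                      # H in R^{2d x 2}
y_real = H @ np.array([x.real, x.imag]) + complex_to_real_vec(z)    # Eq. (3)
assert np.allclose(y_real, complex_to_real_vec(y_complex))
```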

III Main Results

In this section, we present our theoretical analysis for a single layer softmax attention transformer (SAT) as an estimator for the problem corresponding to (2). We work with the following assumption.

Assumption 2

The distributions of the hidden process, input symbols, and noise are characterized as follows.

  (a) The hidden process is time invariant, i.e., $\tilde{\mathbf{h}}_{t}(\theta)=\tilde{\mathbf{h}}(\theta)$, where $\tilde{\mathbf{h}}(\theta)\sim\mathbb{P}_{h\mid\theta}$ is a fixed hidden variable under the context $\theta\sim\mathbb{P}_{\Theta}$.

  (b) The inputs $\tilde{x}_{t}\stackrel{\rm iid}{\sim}\mathbb{P}_{x}$, where $\mathbb{P}_{x}$ is some distribution on $\tilde{\mathcal{X}}$, and $\tilde{\mathcal{X}}\subset\mathbb{C}$ is a finite subset of the unit circle in $\mathbb{C}$, i.e., $|z|=1$ for any $z\in\tilde{\mathcal{X}}$.

  (c) The noise samples $\tilde{\mathbf{z}}_{t}\stackrel{\rm iid}{\sim}\mathcal{CN}(\mathbf{0},\tilde{\boldsymbol{\Sigma}}_{z})$ for some real positive definite matrix $\tilde{\boldsymbol{\Sigma}}_{z}\in\mathbb{R}^{d\times d}$.

Then, the real equations are given by

$$\mathbf{y}_{t}=\mathbf{H}(\theta)\mathbf{x}_{t}+\mathbf{z}_{t},\quad t\geq 1,$$ (4)

where $\mathbf{z}_{t}\sim\mathcal{N}(\mathbf{0},\boldsymbol{\Sigma}_{z})$ with $\boldsymbol{\Sigma}_{z}\triangleq\frac{1}{2}\begin{bmatrix}\tilde{\boldsymbol{\Sigma}}_{z}&\mathbf{0}\\ \mathbf{0}&\tilde{\boldsymbol{\Sigma}}_{z}\end{bmatrix}$, and note that $\boldsymbol{\Sigma}_{z}^{-1}=2\begin{bmatrix}\tilde{\boldsymbol{\Sigma}}_{z}^{-1}&\mathbf{0}\\ \mathbf{0}&\tilde{\boldsymbol{\Sigma}}_{z}^{-1}\end{bmatrix}$. Further, $\mathbf{x}_{t}\in\mathcal{X}$, where $\mathcal{X}$ can be identified with a subset of the unit sphere $\mathbb{S}^{2}$ in $\mathbb{R}^{2}$. From now on, we only work with the above real quantities.

We next derive the minimum mean squared error (MMSE) estimate for $\mathbf{x}_{t}$ given $\mathbf{y}_{t}$ and $\mathbf{H}_{t}=\mathbf{H}$. Note that when $\mathbf{H}_{t}=\mathbf{H}$ is known, the dependence of the estimation problem on $\theta$ vanishes. Further, $\{(\mathbf{y}_{t},\mathbf{x}_{t})\}$ are conditionally independent and identically distributed (iid) given $\mathbf{H}_{t}=\mathbf{H}$ for all $t$. Therefore, the MMSE estimation corresponds to estimating $\mathbf{x}$ from $\mathbf{y}=\mathbf{H}\mathbf{x}+\mathbf{z}$ when $\mathbf{H}$ is known.

Lemma 1

(MMSE estimate) The optimal estimator $\hat{\mathbf{x}}^{\rm MMSE}:\mathbb{R}^{2d}\times\mathbb{R}^{2d\times 2}\to\mathbb{R}^{2}$ for $\mathbf{x}\in\mathcal{X}\subset\mathbb{S}^{2}$ given $\mathbf{y}\in\mathbb{R}^{2d}$ and $\mathbf{H}\in\mathbb{R}^{2d\times 2}$, where $\mathbf{y}=\mathbf{H}\mathbf{x}+\mathbf{z}$ for some $\mathbf{z}\sim\mathcal{N}(\mathbf{0},\boldsymbol{\Sigma}_{z})$, is given by

$$\hat{\mathbf{x}}^{\rm MMSE}(\mathbf{y},\mathbf{H};\boldsymbol{\Sigma}_{z},\mathbb{P}_{x})=\mathbb{E}[\mathbf{x}\mid\mathbf{y},\mathbf{H}]=\frac{\sum_{\mathbf{x}\in\mathcal{X}}\mathbf{x}\,\mathbb{P}_{x}(\mathbf{x})\exp(\mathbf{y}^{T}\boldsymbol{\Sigma}_{z}^{-1}\mathbf{H}\mathbf{x})}{\sum_{\mathbf{x}\in\mathcal{X}}\mathbb{P}_{x}(\mathbf{x})\exp(\mathbf{y}^{T}\boldsymbol{\Sigma}_{z}^{-1}\mathbf{H}\mathbf{x})},$$

where $\mathbf{H}\triangleq\begin{bmatrix}\operatorname{Re}(\tilde{\mathbf{h}})&-\operatorname{Im}(\tilde{\mathbf{h}})\\ \operatorname{Im}(\tilde{\mathbf{h}})&\operatorname{Re}(\tilde{\mathbf{h}})\end{bmatrix}$ for some $\tilde{\mathbf{h}}\in\mathbb{C}^{d}$, and $\boldsymbol{\Sigma}_{z}\triangleq\frac{1}{2}\begin{bmatrix}\tilde{\boldsymbol{\Sigma}}_{z}&\mathbf{0}\\ \mathbf{0}&\tilde{\boldsymbol{\Sigma}}_{z}\end{bmatrix}$ for some symmetric positive definite matrix $\tilde{\boldsymbol{\Sigma}}_{z}\in\mathbb{R}^{d\times d}$.
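The closed form above is a softmax-weighted average over the finite alphabet and is straightforward to evaluate; a minimal sketch (our own names, assuming the real-valued quantities of (3)-(4)) is given below. It serves as the genie-aided baseline in later illustrations.

```python
import numpy as np

def mmse_estimate(y, H, Sigma_z_inv, alphabet_real, p_x=None):
    """Known-channel MMSE estimate of Lemma 1 (illustrative implementation).

    y: observation in R^{2d}; H: real channel matrix in R^{2d x 2};
    Sigma_z_inv: inverse noise covariance in R^{2d x 2d};
    alphabet_real: (|X|, 2) array of real embeddings of X; p_x: prior (uniform if None).
    """
    if p_x is None:
        p_x = np.full(len(alphabet_real), 1.0 / len(alphabet_real))
    scores = alphabet_real @ (H.T @ (Sigma_z_inv @ y))   # y^T Sigma_z^{-1} H x for each x
    w = p_x * np.exp(scores - scores.max())              # stabilized softmax weights
    w /= w.sum()
    return w @ alphabet_real                             # posterior mean in R^2
```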

Proof:

This follows from elementary computations involving Bayes’ theorem (see Appendix LABEL:sec:appendix-finite-estimation).

Since $\mathbf{H}(\theta)$ is not known in the problems we consider, the above estimator provides a lower bound on the mean squared error achievable by any in-context estimator.

Let us consider a single-layer softmax attention transformer (SAT) to solve the in-context estimation of $\mathbf{x}$ given an observation $\mathbf{y}$ and $n$ in-context examples $\mathcal{P}^{n}_{\theta}\triangleq\{(\mathbf{y}_{t},\mathbf{x}_{t}),t\in[n]\}$, where $(\mathbf{y},\mathbf{x})$ and the pairs $(\mathbf{y}_{t},\mathbf{x}_{t})$ in $\mathcal{P}_{\theta}^{n}$ satisfy (4). Let $\mathbf{T}^{\rm SA}$ denote an SAT with query, key, and value matrices $\tilde{\mathbf{W}}_{Q},\tilde{\mathbf{W}}_{K},\tilde{\mathbf{W}}_{V}$, respectively, acting on the $(n+1)$ tokens $\mathbf{u}_{t}\triangleq[\mathbf{y}_{t}^{T}~\mathbf{x}_{t}^{T}]^{T}\in\mathbb{R}^{2d+2}$ for $t\in[n]$ and $\mathbf{u}_{n+1}\triangleq[\mathbf{y}^{T}~\mathbf{0}_{2}^{T}]^{T}\in\mathbb{R}^{2d+2}$. Let $\mathbf{U}_{n+1}$ be the matrix with columns $\mathbf{u}_{t}$ for $t\in[n+1]$. Then, the $(n+1)$th output token is given by

$$\mathbf{T}_{n+1}^{\rm SA}(\mathbf{U}_{n+1})=\frac{\sum_{t=1}^{n+1}\tilde{\mathbf{W}}_{V}\mathbf{u}_{t}\exp(\mathbf{u}_{n+1}^{T}\tilde{\mathbf{W}}_{Q}^{T}\tilde{\mathbf{W}}_{K}\mathbf{u}_{t})}{\sum_{t=1}^{n+1}\exp(\mathbf{u}_{n+1}^{T}\tilde{\mathbf{W}}_{Q}^{T}\tilde{\mathbf{W}}_{K}\mathbf{u}_{t})}.$$ (5)
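Equation (5) is the bare attention map, with no positional encoding, layer normalization, MLP block, or causal mask; this simplification is part of the SAT model analyzed here. A direct transcription (our own names) reads:

```python
import numpy as np

def softmax_attention_output(U, Wq, Wk, Wv):
    """Last output token of a single-layer softmax attention, as in Eq. (5).

    U is (2d+2) x (n+1) with tokens as columns; Wq, Wk, Wv are (2d+2) x (2d+2).
    """
    q_last = Wq @ U[:, -1]                  # query formed from the last token u_{n+1}
    logits = (Wk @ U).T @ q_last            # u_{n+1}^T Wq^T Wk u_t for every t
    attn = np.exp(logits - logits.max())    # softmax over all n+1 tokens (stabilized)
    attn /= attn.sum()
    return (Wv @ U) @ attn                  # sum_t (Wv u_t) * attention weight
```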

The estimate of the SAT for $\mathbf{x}$ using $n$ examples, denoted by $\hat{\mathbf{x}}^{\rm SA}_{n}$, is obtained from the last two elements of the $(n+1)$th output token, i.e., $[\mathbf{T}_{n+1}^{\rm SA}]_{2d+1:2d+2}$. Thus, the first $2d$ rows of $\tilde{\mathbf{W}}_{V}$ do not affect the output $\hat{\mathbf{x}}^{\rm SA}_{n}$; hence, we set them to zero without loss of generality. Motivated by the form of the MMSE estimate, we choose to re-parameterize the remaining entries of the weights as below:

$$\tilde{\mathbf{W}}_{Q}=\begin{bmatrix}\mathbf{W}_{Q}&\mathbf{0}_{2d\times 2}\\ \mathbf{0}_{2\times 2d}&\mathbf{0}_{2\times 2}\end{bmatrix},\quad\tilde{\mathbf{W}}_{K}=\begin{bmatrix}\mathbf{W}_{K}&\mathbf{0}_{2d\times 2}\\ \mathbf{0}_{2\times 2d}&\mathbf{0}_{2\times 2}\end{bmatrix},\quad\tilde{\mathbf{W}}_{V}=\begin{bmatrix}\mathbf{0}_{2d\times 2d}&\mathbf{0}_{2d\times 2}\\ \mathbf{0}_{2\times 2d}&\mathbf{I}_{2}\end{bmatrix}.$$ (6)

Denoting $\mathbf{W}\triangleq\mathbf{W}_{Q}^{T}\mathbf{W}_{K}\in\mathbb{R}^{2d\times 2d}$ and using (6) in (5), we get

$$\hat{\mathbf{x}}^{\rm SA}_{n}(\mathbf{U}_{n+1};\mathbf{W})=\frac{\sum_{t=1}^{n}\mathbf{x}_{t}\exp(\mathbf{y}^{T}\mathbf{W}\mathbf{y}_{t})}{\exp(\mathbf{y}^{T}\mathbf{W}\mathbf{y})+\sum_{t=1}^{n}\exp(\mathbf{y}^{T}\mathbf{W}\mathbf{y}_{t})}.$$
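In this re-parameterized form the estimator depends only on the single matrix $\mathbf{W}$ and can be computed directly from the prompt; the sketch below (hypothetical names, tokens constructed as above) makes the computation explicit.

```python
import numpy as np

def sat_estimate(Y, X, y_query, W):
    """Re-parameterized single-layer SAT estimate of x (the expression above).

    Y: (n, 2d) past observations; X: (n, 2) past symbols; y_query: (2d,); W: (2d, 2d).
    """
    logits = Y @ (W.T @ y_query)             # y^T W y_t for t = 1, ..., n
    self_logit = y_query @ W @ y_query       # y^T W y (the query token attends to itself)
    shift = max(logits.max(), self_logit)    # numerical stabilization only
    num = (np.exp(logits - shift)[:, None] * X).sum(axis=0)
    den = np.exp(self_logit - shift) + np.exp(logits - shift).sum()
    return num / den
```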
Lemma 2

(Functional form of asymptotic softmax attention) For any $\theta\in\Theta$, suppose $\mathbf{H}_{i}(\theta)=\mathbf{H}$ is the common hidden parameter for $i\in[n+1]$, such that $\mathbf{y}_{i}=\mathbf{H}\mathbf{x}_{i}+\mathbf{z}_{i}$ and $\mathbf{y}=\mathbf{H}\mathbf{x}+\mathbf{z}$. For a prompt $\mathbf{U}_{n+1}$ with the $t$th column constructed as $\mathbf{u}_{t}\triangleq[\mathbf{y}_{t}^{T},\mathbf{x}_{t}^{T}]^{T}$ for $t\in[n]$ and $\mathbf{u}_{n+1}\triangleq[\mathbf{y}^{T},\mathbf{0}_{2}^{T}]^{T}$, the estimate produced by the transformer, $\hat{\mathbf{x}}_{n}^{\rm SA}$, with parameter $\mathbf{W}$ satisfies

$$\lim_{n\to\infty}\hat{\mathbf{x}}^{\rm SA}_{n}(\mathbf{U}_{n+1};\mathbf{W})=\frac{\sum_{\mathbf{x}\in\mathcal{X}}\mathbf{x}\,\mathbb{P}_{x}(\mathbf{x})\exp(\mathbf{y}^{T}\mathbf{W}\mathbf{H}\mathbf{x})}{\sum_{\mathbf{x}\in\mathcal{X}}\mathbb{P}_{x}(\mathbf{x})\exp(\mathbf{y}^{T}\mathbf{W}\mathbf{H}\mathbf{x})}\quad{\rm a.s.}$$
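Comparing this limit with Lemma 1 shows that the two expressions coincide when $\mathbf{W}=\boldsymbol{\Sigma}_{z}^{-1}$. The following sketch illustrates the convergence numerically, reusing the hypothetical `sat_estimate` and `mmse_estimate` helpers from the earlier sketches; the QPSK alphabet and white-noise covariance are again illustrative assumptions, not the paper's experimental setup.

```python
import numpy as np

# Illustrative check: with W = Sigma_z^{-1}, the finite-n SAT estimate should approach
# the known-channel MMSE estimate of Lemma 1 as the prompt length n grows (Lemma 2).
rng = np.random.default_rng(0)
d, sigma2 = 2, 0.5
alphabet_c = np.exp(1j * (np.pi / 4 + np.pi / 2 * np.arange(4)))          # QPSK
alphabet_real = np.stack([alphabet_c.real, alphabet_c.imag], axis=1)

h = (rng.standard_normal(d) + 1j * rng.standard_normal(d)) / np.sqrt(2)   # one context theta
H = np.block([[h.real[:, None], -h.imag[:, None]],
              [h.imag[:, None],  h.real[:, None]]])
Sigma_z_inv = (2.0 / sigma2) * np.eye(2 * d)                               # Sigma_z = (sigma2/2) I
W = Sigma_z_inv                                                            # candidate optimal W

def observe(x_real, n_obs):
    """Real-valued observations y = H x + z for each row of x_real."""
    z = np.sqrt(sigma2 / 2) * rng.standard_normal((n_obs, 2 * d))
    return x_real @ H.T + z

for n in (16, 256, 4096):
    X = alphabet_real[rng.integers(0, 4, size=n)]                          # prompt symbols
    Y = observe(X, n)                                                      # prompt observations
    x_true = alphabet_real[rng.integers(0, 4)]
    y = observe(x_true[None, :], 1)[0]                                     # query observation
    gap = np.linalg.norm(sat_estimate(Y, X, y, W)
                         - mmse_estimate(y, H, Sigma_z_inv, alphabet_real))
    print(n, gap)   # the gap should shrink as n grows, consistent with Lemma 2
```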