
Towards Understanding How Transformers Learn In-context Through a Representation Learning Lens

Ruifeng Ren
Gaoling School of Artificial Intelligence
Renmin University of China
Beijing, China
[email protected]
Yong Liu
Gaoling School of Artificial Intelligence
Renmin University of China
Beijing, China
[email protected]
Corresponding Author
Abstract

Pre-trained large language models based on Transformers have demonstrated remarkable in-context learning (ICL) abilities. With just a few demonstration examples, the models can implement new tasks without any parameter updates. However, the mechanism of ICL remains an open question. In this paper, we attempt to explore the ICL process in Transformers through the lens of representation learning. First, leveraging kernel methods, we derive a dual model for one softmax attention layer. The ICL inference process of the attention layer aligns with the training procedure of its dual model, producing token representation predictions that are equivalent to the dual model's test outputs. We delve into the training process of this dual model from a representation learning standpoint and further derive a generalization error bound related to the number of demonstration tokens. Subsequently, we extend our theoretical conclusions to more complicated scenarios, including one Transformer layer and multiple attention layers. Furthermore, drawing inspiration from existing representation learning methods, especially contrastive learning, we propose potential modifications for the attention layer. Finally, experiments are designed to support our findings.

1 Introduction

Recently, large language models (LLMs) based on the Transformer architecture (Vaswani et al., 2017) have shown surprising in-context learning (ICL) capabilities (Brown et al., 2020; Wei et al., 2022; Dong et al., 2022; Liu et al., 2023). By prepending a few demonstration examples to unlabeled query inputs, the models can make predictions for the queries and achieve excellent performance without any parameter updates. This capability enables pre-trained LLMs such as GPT models to be conveniently used in general downstream tasks. Despite this strong performance, the mechanism of ICL remains an open question.

In order to better understand ICL capabilities, many works have offered explanations from different aspects. Xie et al. (2021) propose a Bayesian inference framework to explain how ICL occurs between pretraining and test time, where the LLM infers a shared latent concept among the demonstration examples. Garg et al. (2022) demonstrate through experiments that pre-trained Transformer-based models can learn new functions from in-context examples, including (sparse) linear functions, two-layer neural networks, and decision trees. Zhang et al. (2023b) adopt a Bayesian perspective and show that ICL implicitly performs the Bayesian model averaging algorithm, which is approximated by the attention mechanism. Li et al. (2023) define ICL as an algorithm learning problem in which a Transformer model implicitly builds a hypothesis function at inference time, and derive generalization bounds for ICL. Han et al. (2023) suggest that LLMs can emulate kernel regression algorithms and exhibit similar behaviors during ICL. These works have provided significant insights into the interpretation of ICL capabilities from various perspectives.

In addition to the above explorations, there have also been attempts to relate ICL capabilities to gradient descent. Inspired by the dual form of linear attention proposed in Aiserman et al. (1964) and Irie et al. (2022), the ICL process is interpreted as implicit fine-tuning in the linear attention setting by Dai et al. (2022). However, there remains a noticeable gap between linear attention and the widely used softmax attention. Moreover, this comparison is more of a formal resemblance, and the specific details of the gradient descent process, including the form of the loss function and the training data, require a more fine-grained exploration. Akyürek et al. (2022) show that, by constructing specific weights, Transformer layers can perform fundamental operations (mov, mul, div, aff), which can be combined to execute gradient descent. Von Oswald et al. (2023a) adopt another construction, under which the inference process on one or multiple linear attention layers can be equivalently seen as taking one or multiple steps of gradient descent on linear regression tasks. Building upon this weight construction method, subsequent work has conducted a more in-depth exploration of ICL capabilities under a causal setting, noting that the inference of such attention layers is akin to performing online gradient descent (Ding et al., 2023; Von Oswald et al., 2023b). However, these analyses are still conducted under the assumption of linear attention and primarily focus on linear regression tasks, adopting specific constructions for the input tokens (concatenated from features and labels) and model weights. This limits the explanation of the Transformer's ICL capabilities in more general settings. Thus, the question arises: Can we relate ICL to gradient descent under the softmax attention setting, rather than the linear attention setting, without assuming specific constructions for model weights and input tokens?

Motivated by the aforementioned challenges and following these works that connect ICL with gradient descent, we explore the ICL inference process from a representation learning lens. First, by incorporating kernel methods, we establish a connection between the ICL inference process of one softmax attention layer and the gradient descent process of its dual model. The test prediction of the trained dual model will be equivalent to the ICL inference result. We analyze the training process of this dual model from the perspective of representation learning and compare it with existing representation learning methods. Then, we derive a generalization error bound of this process, which is related to the number of demonstration tokens. Our conclusions can be easily extended to more complex scenarios, including a single Transformer layer and multiple attention layers. Furthermore, inspired by existing representation learning methods especially contrastive learning, we propose potential modifications to the attention layer and experiments are designed to support our findings.

2 Preliminaries

2.1 In-context Learning with Transformers

The model we consider is composed of many stacked Transformer decoder layers, each of which consists of an attention layer and an FFN layer. For simplicity, we omit structures such as residual connections and layer normalization, retaining only the most essential parts. We consider the standard ICL scenario, where the model's input consists of demonstrations followed by query inputs; that is, the input can be represented as ${\bm{X}}=[{\bm{X}}_{D},{\bm{X}}_{T}]\in\mathbb{R}^{d_{i}\times(N+T)}$, where ${\bm{X}}_{D}=[{\bm{x}}_{1},{\bm{x}}_{2},\dots,{\bm{x}}_{N}]$ denotes the $N$ demonstration tokens and ${\bm{X}}_{T}=[{\bm{x}}^{\prime}_{1},{\bm{x}}^{\prime}_{2},\dots,{\bm{x}}^{\prime}_{T}]$ denotes the $T$ query tokens. Here, we focus on how tokens interact during model inference while ignoring the internal structure of demonstration tokens. For the query input at position $T+1$, its output after one Transformer layer can be represented as

{\bm{h}}^{\prime}_{T+1}={\bm{W}}_{V}{\bm{X}}\mathrm{softmax}\left(({\bm{W}}_{K}{\bm{X}})^{T}{\bm{W}}_{Q}{\bm{x}}^{\prime}_{T+1}/\sqrt{d_{o}}\right), (1)
\widehat{{\bm{x}}}^{\prime}_{T+1}={\bm{W}}_{2}\mathrm{ReLu}({\bm{W}}_{1}{\bm{h}}^{\prime}_{T+1}+{\bm{b}}_{1})+{\bm{b}}_{2}, (2)

where ${\bm{W}}_{K},{\bm{W}}_{Q},{\bm{W}}_{V}\in{\mathbb{R}}^{d_{o}\times d_{i}}$ are the parameters for the key, query, and value projections, and ${\bm{W}}_{1}\in{\mathbb{R}}^{d_{h}\times d_{o}}$, ${\bm{W}}_{2}\in{\mathbb{R}}^{d_{o}\times d_{h}}$, ${\bm{b}}_{1}\in{\mathbb{R}}^{d_{h}}$, ${\bm{b}}_{2}\in{\mathbb{R}}^{d_{o}}$ are the FFN parameters. Our concern is how the query token ${\bm{x}}^{\prime}_{T+1}$ learns in-context information from the demonstrations. Unlike previous work (Von Oswald et al., 2023a; Zhang et al., 2023a; Bai et al., 2023), we do not make additional assumptions about the structure of the input matrix ${\bm{X}}$ and the parameters in order to study the Transformer's ability to implement specific algorithms. Instead, we adopt the same setting as Dai et al. (2022) to study more general cases.

2.2 Self-Supervised Representation Learning Using Contrastive Loss Functions

Representation learning aims to learn embeddings of data that preserve useful information for downstream tasks. The class of methods perhaps most relevant to our work is contrastive learning without negative samples (Chen and He, 2021; Grill et al., 2020; Caron et al., 2020; Tian et al., 2021). Contrastive learning is a major approach to self-supervised learning (SSL) that learns representations by minimizing the distance between augmentations of the same data point (positive samples) while maximizing the distance between different data points (negative samples) (He et al., 2020; Chen et al., 2020b; Oord et al., 2018; Oh Song et al., 2016). To alleviate the burden of constructing a sufficient number of negative samples while avoiding representational collapse, some works propose contrastive architectures without negative samples, mainly based on weight-sharing networks known as Siamese networks (Chen and He, 2021; Grill et al., 2020; Caron et al., 2020; Tian et al., 2021). Such an architecture takes two augmentations ${\bm{x}}_{1},{\bm{x}}_{2}$ of the same data point ${\bm{x}}$ as inputs and processes them by an online network and a target network respectively to obtain the corresponding representations, that is, $\hat{{\bm{x}}}_{1}=f_{\mathrm{online}}({\bm{x}}_{1})$ and $\hat{{\bm{x}}}_{2}=f_{\mathrm{target}}({\bm{x}}_{2})$. The two encoder networks share weights directly or via an Exponential Moving Average (EMA). Then, $\hat{{\bm{x}}}_{1}$ is fed into a predictor head to obtain the predictive representation ${\bm{z}}_{1}=g(\hat{{\bm{x}}}_{1})$. Finally, we minimize the distance between the predictive representation and the target representation, that is, $\mathcal{L}\left({\bm{z}}_{1},\mathrm{StopGrad}(\hat{{\bm{x}}}_{2})\right)$, where $\mathrm{StopGrad}(\cdot)$ means $\hat{{\bm{x}}}_{2}$ is treated as a constant during backpropagation. For $\mathcal{L}(\cdot)$, the cosine similarity or the $l_{2}$-norm is often chosen as the distance measure; the two are equivalent when the vectors are normalized. Another class of methods similar to our work is kernel contrastive learning (Esser et al., 2024). Given an anchor ${\bm{x}}$ and its positive and negative samples ${\bm{x}}^{+},{\bm{x}}^{-}$, it optimizes the loss function $\mathcal{L}=f({\bm{x}})^{T}(f({\bm{x}}^{-})-f({\bm{x}}^{+}))$, where $f({\bm{x}})={\bm{W}}\phi({\bm{x}})$ and $\phi({\bm{x}})$ is the feature map of some kernel. We will examine the gradient descent process corresponding to the ICL inference process from the perspective of representation learning and compare it with these two representation learning patterns.
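To make the Siamese pattern concrete, below is a minimal PyTorch sketch of contrastive learning without negative samples in the weight-sharing (SimSiam-style) variant; the encoder and predictor sizes, the Gaussian-noise augmentations, and all names are illustrative rather than taken from the cited works.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

encoder = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 64))    # shared encoder f
predictor = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64))  # predictor head g

def siamese_loss(x1, x2):
    """Negative cosine similarity with stop-gradient on the target branch."""
    z1 = predictor(encoder(x1))      # online branch: encoder, then predictor head
    x2_hat = encoder(x2).detach()    # target branch: StopGrad(.)
    return -F.cosine_similarity(z1, x2_hat, dim=-1).mean()

x = torch.randn(8, 32)                      # a batch of data points
x1 = x + 0.1 * torch.randn_like(x)          # augmentation 1
x2 = x + 0.1 * torch.randn_like(x)          # augmentation 2
loss = siamese_loss(x1, x2)
loss.backward()                             # only the online branch receives gradients
```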

2.3 Gradient Descent on Linear Layer is the Dual Form of Linear Attention

It has been found that linear attention can be connected to a linear layer optimized by gradient descent (Aiserman et al., 1964; Irie et al., 2022; Dai et al., 2022); that is, gradient descent on a linear layer can be seen as the dual form of linear attention. (The term "dual" here differs from its use in mathematical optimization theory; following the terminology of previous works (Irie et al., 2022; Dai et al., 2022), the forward process of the attention layer and the backward process on some model are referred to as "dual" forms of each other.) A simple linear layer can be defined as $f_{L}({\bm{x}})={\bm{W}}{\bm{x}}$, where ${\bm{W}}\in{\mathbb{R}}^{d_{o}\times d_{i}}$ is the projection matrix. Given training inputs $[{\bm{x}}_{i}]^{N}_{i=1}\in{\mathbb{R}}^{d_{i}}$ with their labels $[{\bm{y}}_{i}]^{N}_{i=1}\in{\mathbb{R}}^{d_{o}}$, the linear layer outputs the predictions $[\hat{{\bm{y}}}_{i}]^{N}_{i=1}$, where $\hat{{\bm{y}}}_{i}={\bm{W}}{\bm{x}}_{i}$, and a loss $\mathcal{L}(\hat{{\bm{y}}}_{i},{\bm{y}}_{i})$ is computed for training. Backpropagation signals $[{\bm{e}}_{i}]^{N}_{i=1}\in{\mathbb{R}}^{d_{o}}$ are produced to update ${\bm{W}}$ in the gradient descent process, where ${\bm{e}}_{i}=-\eta\left(\nabla_{\hat{{\bm{y}}}_{i}}\mathcal{L}\right)$ with $\eta$ the learning rate. During test time, the trained weight matrix $\widehat{{\bm{W}}}$ can be represented by its initialization ${\bm{W}}_{0}$ and the updated part $\Delta{\bm{W}}$, that is,

\widehat{{\bm{W}}}={\bm{W}}_{0}+\Delta{\bm{W}}={\bm{W}}_{0}+\sum_{i=1}^{N}{\bm{e}}_{i}\otimes{\bm{x}}_{i}, (3)

where $\otimes$ denotes the outer product; this form of the update follows from the chain rule of differentiation. On the other hand, this process can be viewed from the perspective of linear attention. Let $[{\bm{k}}_{i}]^{N}_{i=1},[{\bm{v}}_{i}]^{N}_{i=1}\in{\mathbb{R}}^{d_{i}}$ denote the $N$ key and value vectors constituting the matrices ${\bm{K}},{\bm{V}}\in{\mathbb{R}}^{d_{i}\times N}$, respectively. For a given query input ${\bm{q}}\in{\mathbb{R}}^{d_{i}}$, linear attention is typically defined as the weighted sum of these value vectors

\mathrm{LA}({\bm{V}},{\bm{K}},{\bm{q}})={\bm{V}}{\bm{K}}^{T}{\bm{q}}=\sum_{i=1}^{N}{\bm{v}}_{i}{\bm{k}}_{i}^{T}{\bm{q}}=\left(\sum_{i=1}^{N}{\bm{v}}_{i}\otimes{\bm{k}}_{i}\right){\bm{q}}.

Then, we can rewrite the output of a linear layer during test time as

f_{L}({\bm{x}}_{test})=\widehat{{\bm{W}}}{\bm{x}}_{test}={\bm{W}}_{0}{\bm{x}}_{test}+\left(\sum_{i=1}^{N}{\bm{e}}_{i}\otimes{\bm{x}}_{i}\right){\bm{x}}_{test}={\bm{W}}_{0}{\bm{x}}_{test}+\mathrm{LA}({\bm{E}},{\bm{X}},{\bm{x}}_{test}), (4)

where ${\bm{E}}\in{\mathbb{R}}^{d_{o}\times N}$ and ${\bm{X}}\in{\mathbb{R}}^{d_{i}\times N}$ are formed by stacking the backpropagation signals $[{\bm{e}}_{i}]^{N}_{i=1}$ and the training inputs $[{\bm{x}}_{i}]^{N}_{i=1}$, respectively. Eq (4) shows that the trained weight $\widehat{{\bm{W}}}$ records all training datapoints, and the test prediction of the linear layer selects which training datapoints to activate via $\mathrm{LA}(\cdot)$, where $[{\bm{e}}_{i}]^{N}_{i=1}$ can be considered as values, $[{\bm{x}}_{i}]^{N}_{i=1}$ as keys, and ${\bm{x}}_{test}$ as the query. This interpretation uses gradient descent as a bridge connecting the predictions of linear layers with linear attention, which can be seen as a simplified version of the softmax attention used in Transformers.
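The following minimal numerical sketch checks the duality of Eq (3) and Eq (4); the dimensions, random data, and names are illustrative.

```python
# Numerical check of Eq (3)-(4): the weight updated by gradient descent,
# applied to a test input, equals W_0 x_test plus linear attention over
# the training inputs (keys) and backprop signals (values).
import numpy as np

rng = np.random.default_rng(0)
d_i, d_o, N = 4, 3, 10
X = rng.normal(size=(d_i, N))       # training inputs [x_i]
E = rng.normal(size=(d_o, N))       # backprop signals e_i = -eta * grad
W0 = rng.normal(size=(d_o, d_i))    # initialization
x_test = rng.normal(size=d_i)

W_hat = W0 + E @ X.T                        # Eq (3): W_0 + sum_i e_i (x_i)^T
pred_gd = W_hat @ x_test                    # gradient-descent view
pred_la = W0 @ x_test + E @ (X.T @ x_test)  # Eq (4): W_0 x_test + LA(E, X, x_test)

assert np.allclose(pred_gd, pred_la)
```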

Inspired by this relationship, Dai et al. (2022) understand ICL as implicit fine-tuning. However, this interpretation based on linear attention deviates from the softmax attention used in practical Transformers. Furthermore, the alignment is also ambiguous, as the specific details of the gradient descent process, including the form of the loss function and the dataset, have not been explicitly addressed. In addition, Von Oswald et al. (2023a); Ding et al. (2023) also connect ICL with gradient descent for linear regression tasks using weight construction methods, where the parameters ${\bm{W}}_{K}$, ${\bm{W}}_{Q}$ and ${\bm{W}}_{V}$ of the self-attention layer need to roughly adhere to a specific constructed form. However, these analyses rely on the setting of linear regression tasks and on assumptions about the form of the input tokens (concatenations of features and labels), which limits the interpretability of ICL capabilities from the perspective of gradient descent. We attempt to address these issues in the following sections.

Figure 1: The ICL output ${\bm{h}}^{\prime}_{N+1}$ of one softmax attention layer is equivalent to the test prediction $\hat{{\bm{y}}}_{test}$ of its trained dual model $f({\bm{x}})=\widehat{{\bm{W}}}\phi({\bm{x}})$. The training data and the test input can be obtained by linear transformations of the demonstration and query tokens, respectively.

3 Connecting ICL with Gradient Descent

In this section, we address the two questions raised above: (i) without assuming specific constructions for model weights and input tokens, how can ICL be related to gradient descent in the softmax attention setting rather than the linear attention setting? (ii) What are the specific forms of the training data and the loss function in the gradient descent process corresponding to ICL? In answering these questions, we explore the gradient descent process corresponding to ICL from the perspective of representation learning.

3.1 Connecting Softmax Attention with Kernels

Before establishing the connection between ICL and gradient descent, we first need to rethink softmax attention in terms of kernel methods. Dai et al. (2022) connect ICL with gradient descent under the linear attention setting; in fact, with the help of kernel methods, it is entirely feasible to interpret ICL under softmax attention. We define the attention block as

{\bm{A}}=\mathrm{softmax}\left(({\bm{W}}_{K}{\bm{X}})^{T}{\bm{W}}_{Q}{\bm{X}}/\sqrt{d_{o}}\right), (5)

which can be viewed as the product of an unnormalized part ${\bm{A}}_{u}$ and a normalizing multiplier ${\bm{D}}$, that is,

{\bm{A}}={\bm{A}}_{u}{\bm{D}}^{-1},~~{\bm{A}}_{u}=\mathrm{exp}\left(({\bm{W}}_{K}{\bm{X}})^{T}{\bm{W}}_{Q}{\bm{X}}/\sqrt{d_{o}}\right),~~{\bm{D}}=\mathrm{diag}({\bm{1}}_{N}^{T}{\bm{A}}_{u}), (6)

where $\mathrm{exp}(\cdot)$ is applied element-wise. Similarly to Choromanski et al. (2020), we define the softmax kernel $K_{sm}:{\mathbb{R}}^{d_{o}}\times{\mathbb{R}}^{d_{o}}\to{\mathbb{R}}_{+}$ as $K_{sm}({\bm{x}},{\bm{y}})=e^{{\bm{x}}^{T}{\bm{y}}}=e^{\frac{\|{\bm{x}}\|^{2}+\|{\bm{y}}\|^{2}}{2}}K_{gauss}({\bm{x}},{\bm{y}})$, where $K_{gauss}({\bm{x}},{\bm{y}})=e^{-\|{\bm{x}}-{\bm{y}}\|^{2}/2}$ is the Gaussian kernel with variance $\sigma^{2}=1$. According to Mercer's theorem (Mercer, 1909), there exists some mapping function $\phi:{\mathbb{R}}^{d_{o}}\to{\mathbb{R}}^{d_{r}}$ satisfying $K_{sm}({\bm{x}},{\bm{y}})=\phi({\bm{x}})^{T}\phi({\bm{y}})$. Thus, omitting the $\sqrt{d_{o}}$-renormalization (or, equivalently, renormalizing the key and query vectors) in Eq (6), every entry of the unnormalized part ${\bm{A}}_{u}$ can be seen as the output of the softmax kernel $K_{sm}$ defined for the mapping $\phi$, which can be formulated as:

{\bm{A}}_{u}(i,j)=\mathrm{exp}\left(({\bm{W}}_{K}{\bm{x}}_{i})^{T}{\bm{W}}_{Q}{\bm{x}}_{j}\right)=K_{sm}({\bm{W}}_{K}{\bm{x}}_{i},{\bm{W}}_{Q}{\bm{x}}_{j})=\phi({\bm{W}}_{K}{\bm{x}}_{i})^{T}\phi({\bm{W}}_{Q}{\bm{x}}_{j}). (7)

Many forms of the mapping function $\phi(\cdot)$ have been used in linear Transformer research to approximate this non-negative kernel (Choromanski et al., 2020; Katharopoulos et al., 2020; Peng et al., 2021; Lu et al., 2021). For example, we can choose $\phi(\cdot)$ as positive random features of the form $\phi({\bm{x}})=e^{{\bm{w}}^{T}{\bm{x}}-\|{\bm{x}}\|^{2}/2}$ to obtain an unbiased approximation (Choromanski et al., 2020). Alternatively, we can choose $\phi({\bm{x}})=\mathrm{elu}({\bm{x}})+1$, as proposed by Katharopoulos et al. (2020).
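As an illustration, the sketch below estimates the softmax kernel $K_{sm}({\bm{x}},{\bm{y}})=e^{{\bm{x}}^{T}{\bm{y}}}$ with positive random features; the number of features $m$ and the input scaling are illustrative choices, and averaging over many draws of ${\bm{w}}$ reduces the variance of the estimate.

```python
import numpy as np

rng = np.random.default_rng(0)
d, m = 8, 100_000                      # input dim, number of random features
W = rng.normal(size=(m, d))            # rows w_j ~ N(0, I_d)

def phi(x):
    # phi(x)_j = exp(w_j^T x - ||x||^2 / 2) / sqrt(m): unbiased for K_sm
    return np.exp(W @ x - x @ x / 2) / np.sqrt(m)

x = rng.normal(size=d) / np.sqrt(d)
y = rng.normal(size=d) / np.sqrt(d)
print(np.exp(x @ y), phi(x) @ phi(y))  # exact kernel vs. random-feature estimate
```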

3.2 The Gradient Descent Process of ICL

Now, we begin to establish the connection between the ICL inference process of a softmax attention layer and gradient descent. We focus on a softmax attention layer in a trained Transformer model, where the parameters $\{{\bm{W}}_{Q},{\bm{W}}_{K},{\bm{W}}_{V}\}$ have been determined and the input ${\bm{X}}=[{\bm{X}}_{D},{\bm{X}}_{T}]$ has the form introduced in Section 2.1. Then, after inference through one attention layer, the query token at position $T+1$ takes the form ${\bm{h}}^{\prime}_{T+1}$ given by Eq (1).

On the other hand, given a specific softmax kernel mapping function $\phi({\bm{x}})$ that satisfies Eq (7), we can define the dual model for the softmax attention layer as

f({\bm{x}})={\bm{W}}\phi({\bm{x}}), (8)

where ${\bm{W}}\in{\mathbb{R}}^{d_{o}\times d_{r}}$ is the parameter matrix. We assume that the dual model obtains its updated weights $\widehat{{\bm{W}}}$ after undergoing one step of gradient descent with some loss function $\mathcal{L}$. Subsequently, taking ${\bm{z}}_{test}={\bm{W}}_{Q}{\bm{x}}^{\prime}_{T+1}$ as the test input, we obtain the test prediction

\hat{{\bm{y}}}_{test}=f({\bm{z}}_{test})=f\left({\bm{W}}_{Q}{\bm{x}}^{\prime}_{T+1}\right)=\widehat{{\bm{W}}}\phi\left({\bm{W}}_{Q}{\bm{x}}^{\prime}_{T+1}\right).

We will show that ${\bm{h}}^{\prime}_{T+1}$ in Eq (1) is strictly equivalent to the above test prediction $\hat{{\bm{y}}}_{test}$, which implies that the ICL inference process involves a gradient descent step on the dual model. This is stated in the following theorem:

Theorem 3.1.

The query token ${\bm{h}}^{\prime}_{T+1}$ obtained through the ICL inference process with one softmax attention layer is equivalent to the test prediction $\hat{{\bm{y}}}_{test}$ obtained by performing one step of gradient descent on the dual model $f({\bm{x}})={\bm{W}}\phi({\bm{x}})$. The loss function $\mathcal{L}$ has the form:

\mathcal{L}=-\frac{1}{\eta D}\sum_{i=1}^{N}\left({\bm{W}}_{V}{\bm{x}}_{i}\right)^{T}{\bm{W}}\phi({\bm{W}}_{K}{\bm{x}}_{i}), (9)

where $\eta$ is the learning rate and $D$ is a constant.

Proof can be found in Appendix A. Theorem 3.1 demonstrates the equivalence between the ICL inference process and gradient descent. Below, we delve into more detailed discussions:

Training Set and Test Input: Once the attention layer has been trained, that is, ${\bm{W}}_{K},{\bm{W}}_{Q},{\bm{W}}_{V}$ have been determined, the demonstration tokens $[{\bm{x}}_{i}]_{i=1}^{N}$ are used to construct a training set for the dual model. Specifically, the training data take the form $\{{\bm{z}}_{std}^{(i)},{\bm{y}}_{std}^{(i)}\}_{i=1}^{N}$, with inputs ${\bm{z}}_{std}^{(i)}={\bm{W}}_{K}{\bm{x}}_{i}$ and labels ${\bm{y}}_{std}^{(i)}={\bm{W}}_{V}{\bm{x}}_{i}$. During the training stage, for each input ${\bm{z}}_{std}^{(i)}$, the dual model outputs the prediction $\hat{{\bm{y}}}^{(i)}=f\left({\bm{z}}_{std}^{(i)}\right)={\bm{W}}\phi\left({\bm{z}}_{std}^{(i)}\right)={\bm{W}}\phi\left({\bm{W}}_{K}{\bm{x}}_{i}\right)$. The loss function in Eq (9) can then be rewritten as $\mathcal{L}=-\frac{1}{\eta D}\sum_{i=1}^{N}({\bm{y}}_{std}^{(i)})^{T}\hat{{\bm{y}}}^{(i)}$, which can be regarded as an unnormalized cosine similarity. Using this loss function and the training data, we perform one step of Stochastic Gradient Descent (SGD) on the dual model and obtain the updated $\widehat{{\bm{W}}}$. Finally, during the testing stage, we take ${\bm{z}}_{test}={\bm{W}}_{Q}{\bm{x}}^{\prime}_{T+1}$ as the test input and obtain a prediction consistent with the ICL result ${\bm{h}}^{\prime}_{T+1}$, that is, $\hat{{\bm{y}}}_{test}=f({\bm{z}}_{test})=\widehat{{\bm{W}}}\phi\left({\bm{W}}_{Q}{\bm{x}}^{\prime}_{T+1}\right)={\bm{h}}^{\prime}_{T+1}$. This process is illustrated in Figure 1. The demonstration tokens provide information about the training datapoints, and the weight matrix $\widehat{{\bm{W}}}$ is optimized to learn sufficient knowledge about the demonstrations. This gradient descent process, using the loss function $\mathcal{L}$ applied to $f({\bm{x}})$, can be seen as the dual form of the ICL inference process of the attention layer.
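The sketch below verifies Theorem 3.1 numerically under simplifying assumptions: attention runs over the $N$ demonstration tokens only, the initialization ${\bm{W}}_{0}$ of the dual model is set to zero (the query tokens' own contributions would enter through ${\bm{W}}_{0}$), and the attention is kernelized with the same feature map $\phi$, so the equality is exact; all names and sizes are illustrative.

```python
# Numerical sketch of Theorem 3.1: one SGD step on the dual model f(x) = W phi(x)
# reproduces the (kernelized) softmax attention output.
import numpy as np

rng = np.random.default_rng(1)
d_i, d_o, d_r, N = 6, 6, 128, 15
WK, WQ, WV = (rng.normal(size=(d_o, d_i)) / np.sqrt(d_i) for _ in range(3))
Wrf = rng.normal(size=(d_r, d_o))                    # random features for phi

def phi(u):                                          # positive random features
    return np.exp(Wrf @ u - u @ u / 2) / np.sqrt(d_r)

X = rng.normal(size=(d_i, N)) / np.sqrt(d_i)         # demonstration tokens
xq = rng.normal(size=d_i) / np.sqrt(d_i)             # query token x'_{T+1}

# Attention side: unnormalized scores phi(WK x_i)^T phi(WQ x'), normalized by D.
Phi_K = np.stack([phi(WK @ X[:, i]) for i in range(N)], axis=1)   # (d_r, N)
phi_q = phi(WQ @ xq)
scores = Phi_K.T @ phi_q
D = scores.sum()
h_attn = (WV @ X) @ (scores / D)

# Dual-model side: one SGD step from W_0 = 0 on the loss of Eq (9).
eta = 0.1                                            # cancels in the update
grad = -(1.0 / (eta * D)) * (WV @ X) @ Phi_K.T       # dL/dW
W_hat = np.zeros((d_o, d_r)) - eta * grad            # = (1/D) (WV X) Phi_K^T
y_test = W_hat @ phi_q                               # dual-model test prediction

assert np.allclose(h_attn, y_test)                   # Theorem 3.1, numerically
```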

Figure 2: Left Part: The representation learning process for the ICL inference by one attention layer. Remaining Part: Comparison of the ICL Representation Learning Process (Center Left), Contrastive Learning without Negative Samples (Center Right), and Contrastive Kernel Learning (Right).

Representation Learning Lens: Having clarified the details of the gradient descent process of ICL, what does this process reveal more profoundly? In fact, for an encoded demonstration token ${\bm{x}}_{i}$, the key and value mappings generate a pair of features ${\bm{W}}_{K}{\bm{x}}_{i}$ and ${\bm{W}}_{V}{\bm{x}}_{i}$ that exhibit a certain distance from each other, akin to positive samples in contrastive learning. Then, $\phi(\cdot)$ projects ${\bm{W}}_{K}{\bm{x}}_{i}$ into a higher-dimensional space to capture deeper features. Finally, the weight matrix ${\bm{W}}$, which maps $\phi({\bm{W}}_{K}{\bm{x}}_{i})$ back to the original space, is trained to bring the mapped vector as close as possible to ${\bm{W}}_{V}{\bm{x}}_{i}$. This process is illustrated in Figure 2. Below, we attempt to understand this process through the lens of the existing representation learning methods introduced in Section 2.2, although we emphasize that there are certain differences between them.

Comparison with Contrastive Learning without Negative Samples: If we view the key and value mappings as two types of data augmentation, then from the perspective of contrastive learning without negative samples, this process can be formalized as

\min_{{\bm{W}}}~\mathcal{L}\left(\hat{{\bm{y}}}^{(i)},{\bm{y}}_{std}^{(i)}\right)=\mathcal{L}\left(\hat{{\bm{y}}}^{(i)},\mathrm{StopGrad}({\bm{y}}_{std}^{(i)})\right),

where $\mathrm{StopGrad}(\cdot)$ applies naturally because no learnable parameters are involved in generating the representation ${\bm{y}}_{std}^{(i)}$. However, it is important to note that the representation learning process of ICL is much simpler: first, the online and target networks are absent, and the augmentations ${\bm{W}}_{K}{\bm{x}}_{i},{\bm{W}}_{V}{\bm{x}}_{i}$ are used directly as the online and target representations, respectively; second, the predictor head is not discarded but is reused during the test stage.

Comparison with Contrastive Kernel Learning: Given an anchor ${\bm{x}}$ and its positive and negative samples ${\bm{x}}^{+}$, ${\bm{x}}^{-}$, contrastive kernel learning optimizes the loss $\mathcal{L}=f({\bm{x}})^{T}(f({\bm{x}}^{-})-f({\bm{x}}^{+}))$, where $f({\bm{x}})={\bm{W}}\phi({\bm{x}})$. The representation learning process of ICL differs significantly: first, it does not involve negative samples; second, there is no corresponding processing for positive samples, so parameter updates depend solely on the processing of the anchor.

Extension to More Complicated Scenarios: Theorem 3.1 extends naturally to a single Transformer layer and to multiple attention layers. For one Transformer layer as formed in Section 2.1, the dual model $f^{+}({\bm{x}})={\bm{W}}\phi({\bm{x}})+{\bm{b}}$ introduces an additional bias ${\bm{b}}$; only ${\bm{W}}$ is trained while ${\bm{b}}$ remains fixed. In addition, the labels of the training set become ${\bm{y}}_{std}^{(i)}={\bm{W}}_{F}{\bm{W}}_{K}{\bm{x}}_{i}$, where ${\bm{W}}_{F}$ has a potential low-rankness property induced by $\mathrm{ReLu}(\cdot)$. For multiple attention layers, the ICL inference process is equivalent to sequentially performing gradient descent and making predictions on the sequence of dual models. We provide more details in Appendix B.

Compared to Dai et al. (2022), who consider the connection under the linear attention setting, Theorem 3.1 provides an explanation for the more widely used softmax attention and offers a more detailed exploration of the training process. Additionally, unlike Von Oswald et al. (2023a, b); Ding et al. (2023), which focus on particular linear regression tasks and specific configurations of tokens and parameters, we aim to explain the process of token interactions during ICL inference in a more general setting.

3.3 Generalization Bound of the Dual Gradient Descent Process for ICL

In this part, we are interested in the generalization bound of the gradient descent process dual to ICL. When ICL inference is performed for some task ${\mathcal{T}}$, we cannot provide all demonstrations related to task ${\mathcal{T}}$ because the input length is limited. We denote by ${\mathcal{S}}_{{\mathcal{T}}}\subseteq{\mathbb{R}}^{d_{i}}$ the set of all possible tokens for the task ${\mathcal{T}}$ and assume that tokens are selected according to a distribution ${\mathcal{D}}_{\mathcal{T}}$. For a particular instance of ICL inference, let $\mathcal{S}=\{{\bm{x}}_{i}\}_{i=1}^{N}\subseteq{\mathcal{S}}_{{\mathcal{T}}}$ denote the example tokens we selected. We define the function class as $\mathcal{F}:=\left\{f({\bm{x}})={\bm{W}}\phi({\bm{W}}_{K}{\bm{x}})~|~\|{\bm{W}}\|\leq w\right\}$, where $\|\cdot\|$ denotes the Frobenius norm. Ignoring the constant term in Eq (9), we consider the representation learning loss

{\mathcal{L}}(f)=\mathbb{E}_{{\bm{x}}\sim{\mathcal{D}}_{\mathcal{T}}}\left[-\left({\bm{W}}_{V}{\bm{x}}\right)^{T}f({\bm{x}})\right]=\mathbb{E}_{{\bm{x}}\sim{\mathcal{D}}_{\mathcal{T}}}\left[-\left({\bm{W}}_{V}{\bm{x}}\right)^{T}{\bm{W}}\phi({\bm{W}}_{K}{\bm{x}})\right], (10)

where $f\in{\mathcal{F}}$ and ${\mathcal{D}}_{{\mathcal{T}}}$ is the distribution for some ICL task ${\mathcal{T}}$. Correspondingly, the empirical loss is $\hat{{\mathcal{L}}}(f)=-\frac{1}{N}\sum_{i=1}^{N}\left({\bm{W}}_{V}{\bm{x}}_{i}\right)^{T}f({\bm{x}}_{i})$, and we let $\hat{f}=\operatorname*{arg\,min}_{f\in{\mathcal{F}}}\hat{{\mathcal{L}}}(f)$. In addition, we denote the kernel matrix of the demonstration tokens $\mathcal{S}$ by ${\bm{K}}_{\mathcal{S}}\in{\mathbb{R}}^{N\times N}$, where $({\bm{K}}_{\mathcal{S}})_{i,j}=\left\langle\phi({\bm{W}}_{K}{\bm{x}}_{i}),\phi({\bm{W}}_{K}{\bm{x}}_{j})\right\rangle$, that is, the inner product of the feature maps of the $i$-th and $j$-th tokens after the ${\bm{W}}_{K}$ projection. We state our theorem as follows:

Theorem 3.2.

Define the function class as $\mathcal{F}:=\left\{f({\bm{x}})={\bm{W}}\phi({\bm{W}}_{K}{\bm{x}})~|~\|{\bm{W}}\|\leq w\right\}$ and let the loss function be defined as in Eq (10). Consider the given demonstration set $\mathcal{S}=\{{\bm{x}}_{i}\}_{i=1}^{N}$, where $\mathcal{S}\subseteq\mathcal{S}_{\mathcal{T}}$ and $\mathcal{S}_{\mathcal{T}}$ is the set of all possible demonstration tokens for some task $\mathcal{T}$. Under the assumption that $\|{\bm{W}}_{V}{\bm{x}}_{i}\|,\|{\bm{W}}\phi({\bm{W}}_{K}{\bm{x}}_{i})\|\leq\rho$, for any $\delta>0$, the following statement holds with probability at least $1-\delta$ for any $f\in\mathcal{F}$:

\mathcal{L}(\hat{f})\leq\mathcal{L}(f)+O\left(\frac{w\rho d_{o}\sqrt{\mathrm{Tr}({\bm{K}}_{\mathcal{S}})}}{N}+\sqrt{\frac{\log\frac{1}{\delta}}{N}}\right). (11)

The proof of Theorem 3.2 can be found in Appendix C. Theorem 3.2 provides the generalization bound of the optimal dual model trained on a finite selected demonstration set, under the mild assumption that $\|{\bm{W}}\|$ is bounded. Intuitively, as the number of demonstrations (and therefore the number of demonstration tokens) increases, the generalization error decreases, which is consistent with existing experimental observations (Xie et al., 2021; Garg et al., 2022; Wang et al., 2024).
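As a quick illustration of the bound, the sketch below computes the $\mathrm{Tr}({\bm{K}}_{\mathcal{S}})$ term for random tokens; with bounded feature norms, $\mathrm{Tr}({\bm{K}}_{\mathcal{S}})=\sum_{i}\|\phi({\bm{W}}_{K}{\bm{x}}_{i})\|^{2}$ grows roughly linearly in $N$, so the first term of Eq (11) decays as $O(1/\sqrt{N})$. All names and sizes are ad hoc.

```python
import numpy as np

rng = np.random.default_rng(0)
d_i, d_o, d_r, N = 6, 6, 64, 100
WK = rng.normal(size=(d_o, d_i)) / np.sqrt(d_i)
Wrf = rng.normal(size=(d_r, d_o))

def phi(u):  # positive random features, as in Section 3.1
    return np.exp(Wrf @ u - u @ u / 2) / np.sqrt(d_r)

X = rng.normal(size=(d_i, N)) / np.sqrt(d_i)            # demonstration tokens
Phi = np.stack([phi(WK @ X[:, i]) for i in range(N)], axis=1)
K_S = Phi.T @ Phi            # (K_S)_{ij} = <phi(W_K x_i), phi(W_K x_j)>
print(np.trace(K_S) / N)     # roughly constant in N, so sqrt(Tr(K_S))/N -> 0
```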

4 Attention Modification Inspired by the Representation Learning Lens

Analyzing the dual gradient descent process of ICL from the perspective of representation learning inspires us to ask: do existing representation learning methods, especially contrastive learning methods, also have a dual attention inference process? Alternatively, can we modify the attention mechanism by drawing on existing methods? Since there is a large body of mature work in representation learning, especially contrastive learning, it is possible to achieve this by drawing on these works (He et al., 2020; Chen et al., 2020c; Wu et al., 2018; Chen et al., 2020a; Chen and He, 2021). We provide some simple perspectives based on the loss function, data augmentations, and negative samples to adjust the attention mechanism. It is worth noting that these modifications are also applicable to the self-attention mechanism, and we explore these variants in the experiments. More details can be found in Appendix D.

Attention Modification Inspired by the Contrastive Loss: The unnormalized similarity in Eq (9) allows $\|{\bm{W}}\|$ to be optimized toward infinity; in practice, the Layer Normalization (LN) layer, which we have ignored, prevents this. For a single attention layer without an LN layer, we can address this issue by introducing a regularization term to constrain the norm of ${\bm{W}}$, specifically

\mathcal{L}=-\frac{1}{\eta D}\sum_{i=1}^{N}\left({\bm{W}}_{V}{\bm{x}}_{i}\right)^{T}{\bm{W}}\phi({\bm{W}}_{K}{\bm{x}}_{i})+\frac{\alpha}{2\eta}\|{\bm{W}}\|_{F}^{2}, (12)

where $\alpha$ is a hyperparameter. Equivalently, the attention output in Eq (1) is modified as

{\bm{h}}^{\prime}_{T+1}={\bm{W}}_{V}\left[{\bm{X}}_{D},(1-\alpha){\bm{X}}_{T}\right]\mathrm{softmax}\left(({\bm{W}}_{K}{\bm{X}})^{T}{\bm{W}}_{Q}{\bm{x}}^{\prime}_{T+1}/\sqrt{d_{o}}\right). (13)

This modification is equivalent to retaining less prompt information for the query token during aggregation, so relatively more demonstration information is attended to.
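A minimal sketch of the regularized attention in Eq (13) follows, assuming the single-query setting of Eq (1) and that ${\bm{X}}_{T}$ contains the query token as one of its columns; all names are illustrative.

```python
import numpy as np

def regularized_attention(WK, WQ, WV, X_D, X_T, x_query, alpha):
    """Eq (13): scale the query-block values by (1 - alpha); X_T is assumed
    to contain the query tokens (with x_query as one of its columns)."""
    d_o = WK.shape[0]
    X = np.concatenate([X_D, X_T], axis=1)
    V = WV @ np.concatenate([X_D, (1.0 - alpha) * X_T], axis=1)
    scores = (WK @ X).T @ (WQ @ x_query) / np.sqrt(d_o)
    a = np.exp(scores - scores.max())          # numerically stable softmax
    return V @ (a / a.sum())                   # weighted value aggregation

rng = np.random.default_rng(0)
d = 8
WK, WQ, WV = (rng.normal(size=(d, d)) / np.sqrt(d) for _ in range(3))
X_D, X_T = rng.normal(size=(d, 10)), rng.normal(size=(d, 1))
h = regularized_attention(WK, WQ, WV, X_D, X_T, X_T[:, -1], alpha=0.1)
```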

Attention Modification Inspired by the Data Augmentation: If we analogize the key and value mappings to data augmentations in contrastive learning, then for the representation learning process of ICL, these overly simple linear augmentations may limit the model's ability to learn deeper representations. Thus, more complicated augmentations can be considered. Denoting the two augmentations by $g_{1}$ and $g_{2}$, the loss function is modified as

\mathcal{L}=-\frac{1}{\eta D}\sum_{i=1}^{N}\left[g_{1}({\bm{W}}_{V}{\bm{x}}_{i})\right]^{T}{\bm{W}}\phi(g_{2}({\bm{W}}_{K}{\bm{x}}_{i})).

Correspondingly, the attention layer can be adjusted as,

{\bm{h}}^{\prime}_{T+1}=g_{1}({\bm{W}}_{V}{\bm{X}})\mathrm{softmax}\left([g_{2}({\bm{W}}_{K}{\bm{X}})]^{T}{\bm{W}}_{Q}{\bm{x}}^{\prime}_{T+1}/\sqrt{d_{o}}\right), (14)

where $g_{1}(\cdot)$ and $g_{2}(\cdot)$ are applied column-wise. Here we add augmentations for all tokens instead of only the demonstration ones, in order to maintain uniformity in the semantic space. In the experiments, we simply choose MLPs for $g_{1}$ and $g_{2}$. It is worth noting that we only propose the framework here; for different tasks, the augmentation approach should be specifically designed to suit them.
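A minimal sketch of the augmented attention in Eq (14), with single-layer GELU MLPs as $g_{1}$ and $g_{2}$ as in Section 5; the layout here is token-major (tokens as rows), so the per-token application corresponds to the column-wise application in the equation, and all sizes and names are illustrative.

```python
import torch
import torch.nn as nn

d_i = d_o = 12
WK = nn.Linear(d_i, d_o, bias=False)
WQ = nn.Linear(d_i, d_o, bias=False)
WV = nn.Linear(d_i, d_o, bias=False)
g1 = nn.Sequential(nn.Linear(d_o, d_o), nn.GELU())   # value augmentation
g2 = nn.Sequential(nn.Linear(d_o, d_o), nn.GELU())   # key augmentation

def augmented_attention(X, x_query):
    """Eq (14); X: (n_tokens, d_i) token-major, x_query: (d_i,)."""
    scores = g2(WK(X)) @ WQ(x_query) / d_o ** 0.5     # augmented-key scores
    return g1(WV(X)).T @ torch.softmax(scores, dim=0) # augmented-value sum

out = augmented_attention(torch.randn(16, d_i), torch.randn(d_i))
```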

Attention Modification Inspired by the Negative Samples: Negative samples play a crucial role in preventing feature collapse in contrastive learning methods, whereas the representation learning process of ICL only brings a single pair of features closer together, without modeling what should be pushed apart, which could limit the model's ability to learn representations effectively. Therefore, we can introduce negative samples to address this:

\mathcal{L}=-\frac{1}{\eta D}\sum_{i=1}^{N}\left({\bm{W}}_{V}\tilde{{\bm{x}}}_{i}\right)^{T}{\bm{W}}\phi({\bm{W}}_{K}{\bm{x}}_{i}),\quad\tilde{{\bm{x}}}_{i}={\bm{x}}_{i}-\frac{\beta}{|\mathcal{N}(i)|}\sum_{j\in\mathcal{N}(i)}{\bm{x}}_{j},

where $\mathcal{N}(i)$ is the set of negative samples for ${\bm{x}}_{i}$ and $\beta$ is a hyperparameter. Correspondingly, the attention layer is modified as

{\bm{h}}^{\prime}_{T+1}={\bm{W}}_{V}\left[\tilde{{\bm{X}}}_{D},{\bm{X}}_{T}\right]\mathrm{softmax}\left(({\bm{W}}_{K}{\bm{X}})^{T}{\bm{W}}_{Q}{\bm{x}}^{\prime}_{T+1}/\sqrt{d_{o}}\right), (15)

where $\tilde{{\bm{X}}}_{D}=[\tilde{{\bm{x}}}_{1},\tilde{{\bm{x}}}_{2},\dots,\tilde{{\bm{x}}}_{N}]$. Here we simply use other tokens as negative samples, and we emphasize that for specific tasks, an appropriate design of negative samples would be more effective.
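A minimal sketch of the negative-sample shift in Eq (15), using the score-based selection described in Section 5 (the $k$ tokens with the lowest attention scores as negatives); the score matrix here is a random stand-in, and all names are illustrative.

```python
import numpy as np

def shift_with_negatives(X_D, scores, k, beta):
    """Eq (15): x~_i = x_i - (beta/k) * mean of its k lowest-scoring tokens.
    scores[j, i] stands in for the attention score of token j for token i."""
    N = X_D.shape[1]
    X_tilde = X_D.copy()
    for i in range(N):
        neg = np.argsort(scores[:, i])[:k]                  # k lowest scores
        X_tilde[:, i] = X_D[:, i] - beta * X_D[:, neg].mean(axis=1)
    return X_tilde

rng = np.random.default_rng(0)
X_D = rng.normal(size=(8, 10))             # demonstration tokens
scores = rng.normal(size=(10, 10))         # stand-in for attention scores
X_tilde = shift_with_negatives(X_D, scores, k=3, beta=0.1)
```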

Figure 3: The equivalence between the ICL of one softmax attention layer and gradient descent, along with analysis of different model modifications. Left Part: $\|\hat{{\bm{y}}}_{test}-{\bm{h}}^{\prime}_{T+1}\|_{2}$ as gradient descent proceeds under the setting $N=15$; Remaining Part: the performance of regularized models (Center Left), augmented models (Center Right), and negative models (Right) under different settings.

5 Experiments

In this section, we design experiments on synthetic tasks to support our findings; more experiments, including on more realistic tasks, can be found in Appendix E. The questions of interest are: (i) Is the result of ICL inference equivalent to the test prediction of the trained dual model? (ii) Can the attention mechanism potentially be improved from the perspective of representation learning?

Linear Task Setting: Inspired by Von Oswald et al. (2023a), to validate the equivalence and demonstrate the effectiveness of the modifications, we first train one softmax self-attention layer on linear regression tasks. We generate each task by ${\bm{s}}={\bm{W}}{\bm{t}}$, where every element of ${\bm{W}}\in\mathbb{R}^{d_{s}\times d_{t}}$ is sampled from a normal distribution, ${\bm{W}}_{ij}\sim\mathcal{N}(0,1)$, and ${\bm{t}}$ from a uniform distribution, ${\bm{t}}\sim U(-1,1)^{d_{t}}$. We set $d_{t}=11$ and $d_{s}=1$. Then, at each step, we use the generated $\{{\bm{x}}_{i}=[{\bm{t}}_{i};s_{i}]\}^{N+1}_{i=1}$ to form the input matrix ${\bm{X}}$, where the last token serves as the query token and its label part is masked, that is, ${\bm{x}}_{N+1}=[{\bm{t}}_{N+1};0]$. Here we consider only one query token ($T=0$) and denote ${\bm{x}}^{\prime}_{T+1}={\bm{x}}_{N+1}$ to maintain consistency with the notation of Section 2.1. Finally, the attention layer is trained with the mean square error (MSE) loss so that its prediction $\hat{s}_{N+1}$ approximates the true label $s_{N+1}$.
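A minimal sketch of this data generation follows; the function name and the return convention are illustrative.

```python
import numpy as np

def make_prompt(rng, d_t=11, d_s=1, N=15):
    W = rng.normal(size=(d_s, d_t))                  # task weights, W_ij ~ N(0,1)
    T = rng.uniform(-1.0, 1.0, size=(d_t, N + 1))    # inputs t_i ~ U(-1,1)^{d_t}
    S = W @ T                                        # labels s_i = W t_i
    X = np.concatenate([T, S], axis=0)               # tokens x_i = [t_i; s_i]
    target = X[-1, -1]                               # true label s_{N+1}
    X[-1, -1] = 0.0                                  # mask the query's label
    return X, target                                 # X: (d_t + d_s) x (N+1)

X, s_true = make_prompt(np.random.default_rng(0))
```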

Model Setting: To facilitate direct access to the dual model, we use positive random features as the kernel mapping function (the Performer architecture (Choromanski et al., 2020)) to approximate standard softmax attention, that is, $\phi({\bm{x}})=e^{{\bm{w}}^{T}{\bm{x}}-\|{\bm{x}}\|^{2}/2}$ with ${\bm{w}}\sim\mathcal{N}(0,I)$. We set the dimension of the random features to $d_{r}=100(d_{t}+d_{s})=1200$ to obtain a relatively accurate estimate. After training, the weights of the attention layer are fixed. Thus, given a specified input ${\bm{X}}$, we can construct the dual model $f({\bm{x}})={\bm{W}}\phi({\bm{x}})$ and its corresponding training data and test input according to Theorem 3.1.

We run three experiments under different random seeds for the linear regression tasks, with the results of one run presented in Figure 3. In addition, we conduct further experiments, including on trigonometric and exponential synthetic regression tasks and on more realistic tasks. More details of the experimental settings and results can be found in Appendix E. We mainly discuss the results on the linear regression task below.

Equivalence Between ICL and Gradient Descent: To answer the first question, we generate the test input ${\bm{X}}_{test}$ using the same method as in training and obtain the ICL result ${\bm{h}}^{\prime}_{T+1}$ for the query token. On the other hand, we use ${\bm{X}}_{test}$ to train the dual model according to Theorem 3.1 and obtain the test prediction $\hat{{\bm{y}}}_{test}$. The result is shown in the left part of Figure 3. It can be observed that after $N=15$ epochs of training on the dual model, the test prediction $\hat{{\bm{y}}}_{test}$ is exactly equivalent to the ICL inference result ${\bm{h}}^{\prime}_{T+1}$ of one softmax attention layer, which aligns with our analysis in Theorem 3.1. More detailed experiments can be found in Appendix E.1.

Analysis on the Modifications: In Section 4, we discussed different modifications to the attention mechanism from the perspectives of the contrastive loss, data augmentation, and negative samples. We call these modifications regularized models, augmented models, and negative models, respectively. More details of the modifications for the self-attention mechanism can be found in Appendix D.

For regularized models, we vary $\alpha$ to investigate its impact on pretraining performance under the same setting, as shown in the center-left part of Figure 3. It can be observed that when $\alpha>0$, the regularized models converge to a poorer result, while when $\alpha<0$, the model converges faster and achieves final results comparable to the normal model without regularization ($\alpha=0$). At least for this setting, this is somewhat contrary to our initial intention of applying regularization to the contrastive loss, where $\alpha$ should be positive. We attribute this to the fact that an appropriate $\alpha$ contributes to achieving a full-rank attention matrix, as stated in Appendix D, preserving information and accelerating convergence.

For augmented models, we simply choose single-layer MLPs for $g_{1}(\cdot)$ and $g_{2}(\cdot)$ as data augmentations to enhance the value and key embeddings respectively in Eq (14), with GELU (Hendrycks and Gimpel, 2016) as the activation function. As the center-right part of Figure 3 shows, when we only use $g_{2}$, that is, only augment the keys, the model converges slightly faster than in the other cases. Furthermore, when we use a two-layer MLP $g_{2}^{+}({\bm{x}})$ as a more complicated augmentation function, the model initially converges slightly more slowly due to the increased number of parameters but eventually accelerates and reaches a better solution. This indicates that appropriate data augmentation indeed has the potential to enhance the capabilities of the attention layer.

For negative models, we select the $k$ tokens with the lowest attention scores as negative samples for each token. From Eq (15), this is equivalent to subtracting a certain value from the attention scores of those negative samples. We vary the number of negative samples $k$ and the coefficient $\beta$ in Eq (15), with the results shown in the right part of Figure 3. The model can achieve slightly faster convergence with appropriate settings ($k=3$ and $\beta=0.1$). In fact, in the original attention mechanism, attention scores are always non-negative, meaning that some irrelevant information is always preserved to some extent. In the modified structure, attention scores can become negative, which makes the model more flexible in utilizing information. Certainly, as discussed in Section 4, for different tasks, more refined methods of selecting augmentations and constructing negative samples may be more effective; we leave these aspects for future work.

6 Related Work

Since Transformers have shown remarkable ICL abilities (Brown et al., 2020), many works have aimed to analyze the underlying mechanisms (Garg et al., 2022; Wang et al., 2023). To explain how Transformers can learn new tasks from a few demonstrations without parameter updates, an intuitive idea is to link ICL with (implicit) gradient updates. The work most relevant to ours is that of Dai et al. (2022), which utilizes the dual form to understand ICL as implicit fine-tuning (gradient descent) of the original model under a linear attention setting (Aiserman et al., 1964; Irie et al., 2022). They design a specific fine-tuning setting where only the parameters for the key and value projections are updated and the causal language modeling objective is adopted; in this context, they find that ICL shares common properties with fine-tuning. Building on this, Deutch et al. (2024) investigate potential shortcomings in the evaluation metrics used by Dai et al. (2022) in real model assessments and propose a layer-causal GD variant that better simulates ICL. As a comparison, our research also uses the dual form to analyze the nonlinear attention layer and explores the specific form of the loss used in the training process. However, we link ICL to gradient descent performed on the dual model rather than to fine-tuning of the original model. The former process uses a self-supervised representation learning loss, formalized as Eq (9) and determined by the attention structure itself, whereas supervised fine-tuning of the original model is typically driven by task-specific training objectives (or a manually specified causal language modeling objective, as in Dai et al. (2022)). A more formal and detailed comparison can be found in Appendix F.

Additionally, many other works also link ICL with gradient descent, aiming to explore the Transformer's ability to perform gradient descent algorithms to achieve ICL (Bai et al., 2023; Schlag et al., 2021). Akyürek et al. (2022) reveal that, under certain constructions, a Transformer can implement simple basic operations (mov, mul, div and aff), which can be combined to further perform gradient descent. Von Oswald et al. (2023a) provide a simple and appealing construction for obtaining least squares solutions in the linear attention setting. Subsequently, Zhang et al. (2023a); Ahn et al. (2023); Mahankali et al. (2023) provide theoretical evidence showing that the local or global minima have a form similar to the specific construction proposed by Von Oswald et al. (2023a) under certain assumptions. These works, both experimentally and theoretically, often focus on specific linear regression tasks ($y={\bm{w}}^{T}{\bm{x}}$) and a specific structured input format where each token takes the form $[{\bm{x}},y]$, consisting of the input part ${\bm{x}}$ and the label part $y$; the label part of the final query to be predicted is masked, represented as $[{\bm{x}},0]$. Subsequent works have expanded this exploration to more complicated setups, including examining nonlinear attention instead of linear attention (Cheng et al., 2023; Collins et al., 2024), using unstructured inputs rather than structured ones (Xing et al., 2024), and considering causal or autoregressive settings (Ding et al., 2023; Von Oswald et al., 2023b). In comparison to these works, ours does not target specific tasks like linear regression; therefore, we do not make detailed assumptions about the model weights (simply treated as weights after pre-training) or specific input forms. Instead, we view the ICL inference process from the perspective of representation learning in the dual model. However, we point out that under these specific weight and input settings, an intuitive explanation can also be given from a representation learning perspective (see Appendix F). We also notice that Shen et al. (2023) experimentally show that there may exist differences between ICL inference in LLMs and fine-tuned models in real-world scenarios from various perspectives, and that assumptions used in previous works may be strong. As mentioned earlier, our analysis primarily focuses on linking ICL with gradient descent on the dual model of a simplified Transformer rather than on fine-tuning the original model. Analyzing more realistic models is also a direction for our future work.

7 Conclusion and Impact Statements

In this paper, we establish a connection between the ICL process of Transformers and gradient descent on the dual model, offering novel insights from a representation learning lens. Based on this, we propose modifications for the attention layer, and experiments under our setup demonstrate their potential. Although we have made efforts toward understanding ICL, some limitations remain in our analysis: (1) our work primarily focuses on a simplified Transformer, and the impact of structures like layer normalization and residual connections requires more nuanced analysis; (2) for more tasks and settings, the proposed model modifications may require more nuanced design and validation. We leave these aspects for future exploration. We believe that this work mainly studies the theory of in-context learning and does not present any foreseeable societal consequences.

8 Acknowledgements

We sincerely appreciate the anonymous reviewers for their helpful suggestions and constructive comments. This research was supported by the National Natural Science Foundation of China (No.62476277, No.6207623), the Beijing Natural Science Foundation (No.4222029), the CCF-ALIMAMA TECH Kangaroo Fund (No.CCF-ALIMAMA OF 2024008), and the Huawei-Renmin University joint program on Information Retrieval. We also acknowledge the support provided by the fund for building world-class universities (disciplines) of Renmin University of China and by funds from the Beijing Key Laboratory of Big Data Management and Analysis Methods, Gaoling School of Artificial Intelligence, Renmin University of China; from the Engineering Research Center of Next-Generation Intelligent Search and Recommendation, Ministry of Education; from the Intelligent Social Governance Interdisciplinary Platform, Major Innovation & Planning Interdisciplinary Platform for the "Double First-Class" Initiative, Renmin University of China; from the Public Policy and Decision-making Research Lab of Renmin University of China; and from the Public Computing Cloud, Renmin University of China.

References

  • Ahn et al. [2023] Kwangjun Ahn, Xiang Cheng, Hadi Daneshmand, and Suvrit Sra. Transformers learn to implement preconditioned gradient descent for in-context learning. Advances in Neural Information Processing Systems, 36:45614–45650, 2023.
  • Aiserman et al. [1964] MA Aiserman, Emmanuil M Braverman, and Lev I Rozonoer. Theoretical foundations of the potential function method in pattern recognition. Avtomat. i Telemeh, 25(6):917–936, 1964.
  • Akyürek et al. [2022] Ekin Akyürek, Dale Schuurmans, Jacob Andreas, Tengyu Ma, and Denny Zhou. What learning algorithm is in-context learning? investigations with linear models. arXiv preprint arXiv:2211.15661, 2022.
  • Amari [1993] Shun-ichi Amari. Backpropagation and stochastic gradient descent method. Neurocomputing, 5(4-5):185–196, 1993.
  • Bai et al. [2023] Yu Bai, Fan Chen, Huan Wang, Caiming Xiong, and Song Mei. Transformers as statisticians: Provable in-context learning with in-context algorithm selection. arXiv preprint arXiv:2306.04637, 2023.
  • Brown et al. [2020] Tom Brown, Benjamin Mann, Nick Ryder, Melanie Subbiah, Jared D Kaplan, Prafulla Dhariwal, Arvind Neelakantan, Pranav Shyam, Girish Sastry, Amanda Askell, et al. Language models are few-shot learners. Advances in neural information processing systems, 33:1877–1901, 2020.
  • Caron et al. [2020] Mathilde Caron, Ishan Misra, Julien Mairal, Priya Goyal, Piotr Bojanowski, and Armand Joulin. Unsupervised learning of visual features by contrasting cluster assignments. Advances in neural information processing systems, 33:9912–9924, 2020.
  • Chen et al. [2020a] Ting Chen, Simon Kornblith, Mohammad Norouzi, and Geoffrey Hinton. A simple framework for contrastive learning of visual representations. In International conference on machine learning, pages 1597–1607. PMLR, 2020a.
  • Chen et al. [2020b] Ting Chen, Simon Kornblith, Mohammad Norouzi, and Geoffrey Hinton. A simple framework for contrastive learning of visual representations. In International conference on machine learning, pages 1597–1607. PMLR, 2020b.
  • Chen and He [2021] Xinlei Chen and Kaiming He. Exploring simple siamese representation learning. In Proceedings of the IEEE/CVF conference on computer vision and pattern recognition, pages 15750–15758, 2021.
  • Chen et al. [2020c] Xinlei Chen, Haoqi Fan, Ross Girshick, and Kaiming He. Improved baselines with momentum contrastive learning. arXiv preprint arXiv:2003.04297, 2020c.
  • Cheng et al. [2023] Xiang Cheng, Yuxin Chen, and Suvrit Sra. Transformers implement functional gradient descent to learn non-linear functions in context. arXiv preprint arXiv:2312.06528, 2023.
  • Choromanski et al. [2020] Krzysztof Choromanski, Valerii Likhosherstov, David Dohan, Xingyou Song, Andreea Gane, Tamas Sarlos, Peter Hawkins, Jared Davis, Afroz Mohiuddin, Lukasz Kaiser, et al. Rethinking attention with performers. arXiv preprint arXiv:2009.14794, 2020.
  • Collins et al. [2024] Liam Collins, Advait Parulekar, Aryan Mokhtari, Sujay Sanghavi, and Sanjay Shakkottai. In-context learning with transformers: Softmax attention adapts to function lipschitzness. arXiv preprint arXiv:2402.11639, 2024.
  • Dai et al. [2022] Damai Dai, Yutao Sun, Li Dong, Yaru Hao, Zhifang Sui, and Furu Wei. Why can gpt learn in-context? language models secretly perform gradient descent as meta optimizers. arXiv preprint arXiv:2212.10559, 2022.
  • Deutch et al. [2024] Gilad Deutch, Nadav Magar, Tomer Natan, and Guy Dar. In-context learning and gradient descent revisited. In Proceedings of the 2024 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies (Volume 1: Long Papers), pages 1017–1028, 2024.
  • Ding et al. [2023] Nan Ding, Tomer Levinboim, Jialin Wu, Sebastian Goodman, and Radu Soricut. Causallm is not optimal for in-context learning. arXiv preprint arXiv:2308.06912, 2023.
  • Dong et al. [2022] Qingxiu Dong, Lei Li, Damai Dai, Ce Zheng, Zhiyong Wu, Baobao Chang, Xu Sun, Jingjing Xu, and Zhifang Sui. A survey for in-context learning. arXiv preprint arXiv:2301.00234, 2022.
  • Esser et al. [2024] Pascal Esser, Maximilian Fleissner, and Debarghya Ghoshdastidar. Non-parametric representation learning with kernels. In Proceedings of the AAAI Conference on Artificial Intelligence, volume 38, pages 11910–11918, 2024.
  • Garg et al. [2022] Shivam Garg, Dimitris Tsipras, Percy S Liang, and Gregory Valiant. What can transformers learn in-context? a case study of simple function classes. Advances in Neural Information Processing Systems, 35:30583–30598, 2022.
  • Grill et al. [2020] Jean-Bastien Grill, Florian Strub, Florent Altché, Corentin Tallec, Pierre Richemond, Elena Buchatskaya, Carl Doersch, Bernardo Avila Pires, Zhaohan Guo, Mohammad Gheshlaghi Azar, et al. Bootstrap your own latent-a new approach to self-supervised learning. Advances in neural information processing systems, 33:21271–21284, 2020.
  • Guo et al. [2022] Jianyuan Guo, Kai Han, Han Wu, Yehui Tang, Xinghao Chen, Yunhe Wang, and Chang Xu. Cmt: Convolutional neural networks meet vision transformers. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 12175–12185, 2022.
  • Han et al. [2023] Chi Han, Ziqi Wang, Han Zhao, and Heng Ji. In-context learning of large language models explained as kernel regression. arXiv preprint arXiv:2305.12766, 2023.
  • He et al. [2021] Junxian He, Chunting Zhou, Xuezhe Ma, Taylor Berg-Kirkpatrick, and Graham Neubig. Towards a unified view of parameter-efficient transfer learning. arXiv preprint arXiv:2110.04366, 2021.
  • He et al. [2020] Kaiming He, Haoqi Fan, Yuxin Wu, Saining Xie, and Ross Girshick. Momentum contrast for unsupervised visual representation learning. In Proceedings of the IEEE/CVF conference on computer vision and pattern recognition, pages 9729–9738, 2020.
  • Hendrycks and Gimpel [2016] Dan Hendrycks and Kevin Gimpel. Gaussian error linear units (gelus). arXiv preprint arXiv:1606.08415, 2016.
  • Irie et al. [2022] Kazuki Irie, Róbert Csordás, and Jürgen Schmidhuber. The dual form of neural networks revisited: Connecting test time predictions to training patterns via spotlights of attention. In International Conference on Machine Learning, pages 9639–9659. PMLR, 2022.
  • Katharopoulos et al. [2020] Angelos Katharopoulos, Apoorv Vyas, Nikolaos Pappas, and François Fleuret. Transformers are rnns: Fast autoregressive transformers with linear attention. In International Conference on Machine Learning, pages 5156–5165. PMLR, 2020.
  • Kenton and Toutanova [2019] Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. Bert: Pre-training of deep bidirectional transformers for language understanding. In Proceedings of NAACL-HLT, volume 1, page 2. Minneapolis, Minnesota, 2019.
  • Li et al. [2023] Yingcong Li, M. Emrullah Ildiz, Dimitris Papailiopoulos, and Samet Oymak. Transformers as algorithms: Generalization and stability in in-context learning. arXiv preprint arXiv:2301.07067, 2023.
  • Liu et al. [2023] Pengfei Liu, Weizhe Yuan, Jinlan Fu, Zhengbao Jiang, Hiroaki Hayashi, and Graham Neubig. Pre-train, prompt, and predict: A systematic survey of prompting methods in natural language processing. ACM Computing Surveys, 55(9):1–35, 2023.
  • Lu et al. [2021] Jiachen Lu, Jinghan Yao, Junge Zhang, Xiatian Zhu, Hang Xu, Weiguo Gao, Chunjing Xu, Tao Xiang, and Li Zhang. Soft: Softmax-free transformer with linear complexity. Advances in Neural Information Processing Systems, 34:21297–21309, 2021.
  • Mahankali et al. [2023] Arvind Mahankali, Tatsunori B Hashimoto, and Tengyu Ma. One step of gradient descent is provably the optimal in-context learner with one layer of linear self-attention. arXiv preprint arXiv:2307.03576, 2023.
  • Maurer [2016] Andreas Maurer. A vector-contraction inequality for rademacher complexities. In Algorithmic Learning Theory: 27th International Conference, ALT 2016, Bari, Italy, October 19-21, 2016, Proceedings 27, pages 3–17. Springer, 2016.
  • Mercer [1909] J. Mercer. Functions of positive and negative type, and their connection with the theory of integral equations. Philosophical Transactions of the Royal Society of London. Series A, Containing Papers of a Mathematical or Physical Character, 209:415–446, 1909. ISSN 02643952. URL http://www.jstor.org/stable/91043.
  • Mohri et al. [2018] Mehryar Mohri, Afshin Rostamizadeh, and Ameet Talwalkar. Foundations of machine learning. MIT press, 2018.
  • Oh Song et al. [2016] Hyun Oh Song, Yu Xiang, Stefanie Jegelka, and Silvio Savarese. Deep metric learning via lifted structured feature embedding. In Proceedings of the IEEE conference on computer vision and pattern recognition, pages 4004–4012, 2016.
  • Oord et al. [2018] Aaron van den Oord, Yazhe Li, and Oriol Vinyals. Representation learning with contrastive predictive coding. arXiv preprint arXiv:1807.03748, 2018.
  • Peng et al. [2021] Hao Peng, Nikolaos Pappas, Dani Yogatama, Roy Schwartz, Noah A Smith, and Lingpeng Kong. Random feature attention. arXiv preprint arXiv:2103.02143, 2021.
  • Reid et al. [2023] Isaac Reid, Krzysztof Marcin Choromanski, Valerii Likhosherstov, and Adrian Weller. Simplex random features. In International Conference on Machine Learning, pages 28864–28888. PMLR, 2023.
  • Roberts et al. [2019] Adam Roberts, Colin Raffel, Katherine Lee, Michael Matena, Noam Shazeer, Peter J. Liu, Sharan Narang, Wei Li, and Yanqi Zhou. Exploring the limits of transfer learning with a unified text-to-text transformer. Technical report, Google, 2019.
  • Saunshi et al. [2019] Nikunj Saunshi, Orestis Plevrakis, Sanjeev Arora, Mikhail Khodak, and Hrishikesh Khandeparkar. A theoretical analysis of contrastive unsupervised representation learning. In International Conference on Machine Learning, pages 5628–5637. PMLR, 2019.
  • Schlag et al. [2021] Imanol Schlag, Kazuki Irie, and Jürgen Schmidhuber. Linear transformers are secretly fast weight programmers. In International Conference on Machine Learning, pages 9355–9366. PMLR, 2021.
  • Shen et al. [2023] Lingfeng Shen, Aayush Mishra, and Daniel Khashabi. Do pretrained transformers really learn in-context by gradient descent? arXiv preprint arXiv:2310.08540, 2023.
  • Tian et al. [2021] Yuandong Tian, Xinlei Chen, and Surya Ganguli. Understanding self-supervised learning dynamics without contrastive pairs. In International Conference on Machine Learning, pages 10268–10278. PMLR, 2021.
  • Vaswani et al. [2017] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Łukasz Kaiser, and Illia Polosukhin. Attention is all you need. Advances in neural information processing systems, 30, 2017.
  • Von Oswald et al. [2023a] Johannes Von Oswald, Eyvind Niklasson, Ettore Randazzo, João Sacramento, Alexander Mordvintsev, Andrey Zhmoginov, and Max Vladymyrov. Transformers learn in-context by gradient descent. In International Conference on Machine Learning, pages 35151–35174. PMLR, 2023a.
  • Von Oswald et al. [2023b] Johannes Von Oswald, Eyvind Niklasson, Maximilian Schlegel, Seijin Kobayashi, Nicolas Zucchet, Nino Scherrer, Nolan Miller, Mark Sandler, Max Vladymyrov, Razvan Pascanu, et al. Uncovering mesa-optimization algorithms in transformers. arXiv preprint arXiv:2309.05858, 2023b.
  • Wang [2018] Alex Wang. Glue: A multi-task benchmark and analysis platform for natural language understanding. arXiv preprint arXiv:1804.07461, 2018.
  • Wang et al. [2023] Lean Wang, Lei Li, Damai Dai, Deli Chen, Hao Zhou, Fandong Meng, Jie Zhou, and Xu Sun. Label words are anchors: An information flow perspective for understanding in-context learning. arXiv preprint arXiv:2305.14160, 2023.
  • Wang et al. [2024] Xinyi Wang, Wanrong Zhu, Michael Saxon, Mark Steyvers, and William Yang Wang. Large language models are latent variable models: Explaining and finding good demonstrations for in-context learning. Advances in Neural Information Processing Systems, 36, 2024.
  • Wei et al. [2022] Jason Wei, Yi Tay, Rishi Bommasani, Colin Raffel, Barret Zoph, Sebastian Borgeaud, Dani Yogatama, Maarten Bosma, Denny Zhou, Donald Metzler, et al. Emergent abilities of large language models. arXiv preprint arXiv:2206.07682, 2022.
  • Wolf [2019] T Wolf. Huggingface’s transformers: State-of-the-art natural language processing. arXiv preprint arXiv:1910.03771, 2019.
  • Wu et al. [2018] Zhirong Wu, Yuanjun Xiong, Stella X Yu, and Dahua Lin. Unsupervised feature learning via non-parametric instance discrimination. In Proceedings of the IEEE conference on computer vision and pattern recognition, pages 3733–3742, 2018.
  • Xie et al. [2021] Sang Michael Xie, Aditi Raghunathan, Percy Liang, and Tengyu Ma. An explanation of in-context learning as implicit bayesian inference. arXiv preprint arXiv:2111.02080, 2021.
  • Xing et al. [2024] Yue Xing, Xiaofeng Lin, Namjoon Suh, Qifan Song, and Guang Cheng. Benefits of transformer: In-context learning in linear regression tasks with unstructured data. arXiv preprint arXiv:2402.00743, 2024.
  • Yu et al. [2016] Felix Xinnan X Yu, Ananda Theertha Suresh, Krzysztof M Choromanski, Daniel N Holtmann-Rice, and Sanjiv Kumar. Orthogonal random features. Advances in neural information processing systems, 29, 2016.
  • Zhang et al. [2023a] Ruiqi Zhang, Spencer Frei, and Peter L Bartlett. Trained transformers learn linear models in-context. arXiv preprint arXiv:2306.09927, 2023a.
  • Zhang et al. [2023b] Yufeng Zhang, Fengzhuo Zhang, Zhuoran Yang, and Zhaoran Wang. What and how does in-context learning learn? bayesian model averaging, parameterization, and generalization. arXiv preprint arXiv:2305.19420, 2023b.

Appendix A Details of Theorem 3.1

We restate Theorem 3.1 below and provide its proof together with further discussion.

Theorem A.1.

The query token 𝐡T+1{\bm{h}}^{\prime}_{T+1} obtained through the ICL inference process with one softmax attention layer is equivalent to the test prediction 𝐲^test\hat{{\bm{y}}}_{test} obtained by performing one step of gradient descent on the dual model f(𝐱)=𝐖ϕ(𝐱)f({\bm{x}})={\bm{W}}\phi({\bm{x}}). The form of the loss function \mathcal{L} is:

=1ηDi=1N(𝑾V𝒙i)T𝑾ϕ(𝑾K𝒙i),\mathcal{L}=-\frac{1}{\eta D}\sum_{i=1}^{N}\left({\bm{W}}_{V}{\bm{x}}_{i}\right)^{T}{\bm{W}}\phi({\bm{W}}_{K}{\bm{x}}_{i}), (16)

where η\eta is the learning rate and DD is a constant.

Proof.

The derivative of \mathcal{L} with respect to 𝑾{\bm{W}} is:

𝑾=[i=1N1ηD𝑾V𝒙iϕ(𝑾K𝒙i)].\frac{\partial\mathcal{L}}{\partial{\bm{W}}}=-\left[\sum_{i=1}^{N}\frac{1}{\eta D}{\bm{W}}_{V}{\bm{x}}_{i}\otimes\phi({\bm{W}}_{K}{\bm{x}}_{i})\right].

Thus, after one step of gradient descent, the learned 𝑾^\widehat{{\bm{W}}} will be

𝑾^=𝑾0η𝑾=𝑾0+[i=1N1D𝑾V𝒙iϕ(𝑾K𝒙i)],\widehat{{\bm{W}}}={\bm{W}}_{0}-\eta\frac{\partial\mathcal{L}}{\partial{\bm{W}}}={\bm{W}}_{0}+\left[\sum_{i=1}^{N}\frac{1}{D}{\bm{W}}_{V}{\bm{x}}_{i}\otimes\phi({\bm{W}}_{K}{\bm{x}}_{i})\right], (17)

where 𝑾0{\bm{W}}_{0} is the initialization of the reference model and η\eta is the learning rate. So the test prediction will be

𝒚^test=𝑾0ϕ(𝑾Q𝒙T+1)+[i=1N1D𝑾V𝒙iϕ(𝑾K𝒙i)]ϕ(𝑾Q𝒙T+1).\hat{{\bm{y}}}_{test}={\bm{W}}_{0}\phi\left({\bm{W}}_{Q}{\bm{x}}^{\prime}_{T+1}\right)+\left[\sum_{i=1}^{N}\frac{1}{D}{\bm{W}}_{V}{\bm{x}}_{i}\otimes\phi({\bm{W}}_{K}{\bm{x}}_{i})\right]\phi\left({\bm{W}}_{Q}{\bm{x}}^{\prime}_{T+1}\right). (18)

On the other hand, from the perspective of the ICL process with one attention layer, keeping Eq (7) in mind, we can rewrite Eq (1) as

𝒉T+1\displaystyle{\bm{h}}^{\prime}_{T+1} =𝑾V𝑿softmax((𝑾K𝑿)T𝑾Q𝒙T+1do)\displaystyle={\bm{W}}_{V}{\bm{X}}\mathrm{softmax}\left(\frac{({\bm{W}}_{K}{\bm{X}})^{T}{\bm{W}}_{Q}{\bm{x}}^{\prime}_{T+1}}{\sqrt{d_{o}}}\right)
=1D𝑾V[𝑿D,𝑿T][ϕ(𝑾K𝑿D),ϕ(𝑾K𝑿T)]Tϕ(𝑾Q𝒙T+1)\displaystyle=\frac{1}{D^{\prime}}{\bm{W}}_{V}[{\bm{X}}_{D},{\bm{X}}_{T}]\left[\phi({\bm{W}}_{K}{\bm{X}}_{D}),\phi({\bm{W}}_{K}{\bm{X}}_{T})\right]^{T}\phi({\bm{W}}_{Q}{\bm{x}}^{\prime}_{T+1})
=1D[𝑽D,𝑽T][ϕ(𝑲D),ϕ(𝑲T)]Tϕ(𝒒),\displaystyle=\frac{1}{D^{\prime}}[{\bm{V}}_{D},{\bm{V}}_{T}]\left[\phi({\bm{K}}_{D}),\phi({\bm{K}}_{T})\right]^{T}\phi({\bm{q}}),

where we use [𝑽D,𝑽T]=𝑾V[𝑿D,𝑿T][{\bm{V}}_{D},{\bm{V}}_{T}]={\bm{W}}_{V}[{\bm{X}}_{D},{\bm{X}}_{T}], [𝑲D,𝑲T]=𝑾K[𝑿D,𝑿T][{\bm{K}}_{D},{\bm{K}}_{T}]={\bm{W}}_{K}[{\bm{X}}_{D},{\bm{X}}_{T}], 𝒒=𝑾Q𝒙T+1{\bm{q}}={\bm{W}}_{Q}{\bm{x}}^{\prime}_{T+1} for simplicity, and D=𝟏NTϕ(𝑲D)Tϕ(𝒒)+𝟏TTϕ(𝑲T)Tϕ(𝒒)D^{\prime}={\bm{1}}_{N}^{T}\phi({\bm{K}}_{D})^{T}\phi({\bm{q}})+{\bm{1}}_{T}^{T}\phi({\bm{K}}_{T})^{T}\phi({\bm{q}}) is a constant normalizing the equivalent attention block. Further, we expand the above equation to connect the ICL inference process using softmax attention with gradient descent as follows

𝒉T+1\displaystyle{\bm{h}}^{\prime}_{T+1} =1D𝑽Tϕ(𝑲T)Tϕ(𝒒)+1D𝑽Dϕ(𝑲D)Tϕ(𝒒)\displaystyle=\frac{1}{D^{\prime}}{\bm{V}}_{T}\phi({\bm{K}}_{T})^{T}\phi({\bm{q}})+\frac{1}{D^{\prime}}{\bm{V}}_{D}\phi({\bm{K}}_{D})^{T}\phi({\bm{q}})
=𝑾0ϕ(𝒒)+1D[i=1N𝑽D(i)ϕ(𝑲D(i))]ϕ(𝒒)\displaystyle={\bm{W}}_{0}^{\prime}\phi({\bm{q}})+\frac{1}{D^{\prime}}\left[\sum_{i=1}^{N}{\bm{V}}_{D}^{(i)}\otimes\phi({\bm{K}}_{D}^{(i)})\right]\phi({\bm{q}})

where 𝑾0=1D𝑽Tϕ(𝑲T)T{\bm{W}}_{0}^{\prime}=\frac{1}{D^{\prime}}{\bm{V}}_{T}\phi({\bm{K}}_{T})^{T} and 𝑽D(i),𝑲D(i){\bm{V}}_{D}^{(i)},{\bm{K}}_{D}^{(i)} are the ii-th column vectors respectively.

Then, in Eq (18), when setting the initialization 𝑾0=𝑾0{\bm{W}}_{0}={\bm{W}}_{0}^{\prime} and the constant D=DD=D^{\prime}, we will find that

𝒚^test=𝑾0ϕ(𝒒)+1D[i=1N𝑽D(i)ϕ(𝑲D(i))]ϕ(𝒒)=𝒉T+1,\hat{{\bm{y}}}_{test}={\bm{W}}_{0}\phi({\bm{q}})+\frac{1}{D}\left[\sum_{i=1}^{N}{\bm{V}}_{D}^{(i)}\otimes\phi({\bm{K}}_{D}^{(i)})\right]\phi({\bm{q}})={\bm{h}}^{\prime}_{T+1}, (19)

which means 𝒚^test\hat{{\bm{y}}}_{test} is strictly equivalent to 𝒉T+1{\bm{h}}^{\prime}_{T+1}. Thus, we have completed our proof. ∎

Given a reference model f(𝒙)=𝑾ϕ(𝒙)f({\bm{x}})={\bm{W}}\phi({\bm{x}}), by comparing Eq (19) and Eq (4), we can easily observe that gradient descent on the loss function \mathcal{L} applied to f(𝒙)f({\bm{x}}) is the dual form of the ICL inference process, where 𝑽D(i){\bm{V}}_{D}^{(i)}, ϕ(𝑲D(i))\phi({\bm{K}}_{D}^{(i)}) and ϕ(𝒒)\phi({\bm{q}}) play the roles of backpropagation signals, training inputs and test inputs respectively. Recalling the form of Eq (4), we can interpret 𝑾0{\bm{W}}_{0} as the initialization of the weight matrix, which provides the information available in the zero-shot case, while the second part of Eq (19) shows that the demonstration examples in ICL act as the training samples in gradient descent. The reference model f(𝒙)=𝑾ϕ(𝒙)f({\bm{x}})={\bm{W}}\phi({\bm{x}}), initialized with 𝑾0{\bm{W}}_{0}, will have test prediction 𝒚^test=𝒉T+1\hat{{\bm{y}}}_{test}={\bm{h}}^{\prime}_{T+1} after training. This is also why we refer to it as the dual model of the softmax attention layer. We also note that even with the same query input, different given demonstrations will result in different outputs; this is equivalent to the dual model performing gradient descent in different directions from the same initialization.
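To make this equivalence concrete, the following minimal NumPy sketch instantiates both sides of Eq (19) for a toy attention layer. Since the exact feature map ϕ\phi of the softmax kernel is infinite-dimensional, we substitute a Performer-style positive random feature map as a stand-in (an illustrative assumption; names such as `phi`, `Omega` and the toy dimensions are ours, not from the paper). For any such fixed feature map, the two sides agree exactly.

```python
import numpy as np

rng = np.random.default_rng(0)
d, N, T, m = 8, 5, 3, 64    # token dim, #demos, #query-prefix tokens, #features

WQ, WK, WV = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(3))
XD = rng.standard_normal((d, N))       # demonstration tokens X_D
XT = rng.standard_normal((d, T))       # preceding query tokens X_T
xq = rng.standard_normal((d, 1))       # the query token x'_{T+1}

Omega = rng.standard_normal((m, d))    # random projections for the feature map
def phi(Z):
    # Performer-style positive random features, applied column-wise
    return np.exp(Omega @ Z - 0.5 * np.sum(Z * Z, axis=0)) / np.sqrt(m)

X = np.concatenate([XD, XT], axis=1)   # [X_D, X_T]
q = WQ @ xq
Dp = float(np.sum(phi(WK @ X).T @ phi(q)))      # the normalizer D'

# (1) ICL view: phi-linearized attention output h'_{T+1}
h_icl = (WV @ X) @ (phi(WK @ X).T @ phi(q)) / Dp

# (2) Dual-model view: W'_0 from the query prefix, one GD step from the demos
W0 = (WV @ XT) @ phi(WK @ XT).T / Dp            # initialization W'_0
dW = (WV @ XD) @ phi(WK @ XD).T / Dp            # sum of outer products in Eq (17)
y_test = (W0 + dW) @ phi(q)                     # dual-model test prediction

print(np.allclose(h_icl, y_test))               # True
```

The printed equality is exact by construction; with the true softmax kernel, the finite random feature map only approximates ϕ\phi.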

Appendix B Extensions to more complex scenarios

In Theorem 3.1, we provided the dual form of gradient descent for the ICL of one softmax attention layer. Here, we extend the conclusion to more complex scenarios, including one Transformer layer (an attention layer plus one FFN layer) and multiple attention layers.

B.1 Extension to One Transformer Layer

For the Transformer layer introduced in Section 2.1, we define the new dual model as

f+(𝒙)=𝑾ϕ(𝒙)+𝒃.f^{+}({\bm{x}})={\bm{W}}\phi({\bm{x}})+{\bm{b}}. (20)

We will show that after performing gradient descent on 𝑾{\bm{W}}, the test output 𝒚^test=f+(𝑾Q𝒙T+1)\hat{{\bm{y}}}_{test}=f^{+}({\bm{W}}_{Q}{\bm{x}}^{\prime}_{T+1}) will be equivalent to 𝒙^T+1\hat{{\bm{x}}}^{\prime}_{T+1}. Our theorem is given as follows.

Theorem B.1.

The output 𝐱^T+1\hat{{\bm{x}}}^{\prime}_{T+1} of the ICL inference process with one Transformer layer is strictly equivalent to the test prediction of its dual model f+(𝐱)=𝐖ϕ(𝐱)+𝐛f^{+}({\bm{x}})={\bm{W}}\phi({\bm{x}})+{\bm{b}}, where f+(𝐱)f^{+}({\bm{x}}) is trained under the loss function \mathcal{L} formed as

=1ηDi=1N(𝑾F𝑾V𝒙i)T(𝑾ϕ(𝑾K𝒙i)+𝒃),\mathcal{L}=-\frac{1}{\eta D}\sum_{i=1}^{N}\left({\bm{W}}_{F}{\bm{W}}_{V}{\bm{x}}_{i}\right)^{T}\left({\bm{W}}\phi({\bm{W}}_{K}{\bm{x}}_{i})+{\bm{b}}\right), (21)

where η\eta is the learning rate, DD is a constant, and 𝐖F{\bm{W}}_{F} will be determined once the specified pre-trained model, demonstrations and query tokens are given.

Proof.

Recalling the proof of Theorem 3.1, we can rewrite Eq (1) as

𝒉T+1=𝑾0ϕ(𝑾Q𝒙T+1)+[i=1N1D𝑾V𝒙iϕ(𝑾K𝒙i)]ϕ(𝑾Q𝒙T+1)\displaystyle{\bm{h}}^{\prime}_{T+1}={\bm{W}}_{0}\phi\left({\bm{W}}_{Q}{\bm{x}}^{\prime}_{T+1}\right)+\left[\sum_{i=1}^{N}\frac{1}{D}{\bm{W}}_{V}{\bm{x}}_{i}\otimes\phi({\bm{W}}_{K}{\bm{x}}_{i})\right]\phi\left({\bm{W}}_{Q}{\bm{x}}^{\prime}_{T+1}\right) (22)

where D=𝟏NTϕ(𝑾K𝑿D)Tϕ(𝑾Q𝒙T+1)+𝟏TTϕ(𝑾K𝑿T)Tϕ(𝑾Q𝒙T+1)D={\bm{1}}_{N}^{T}\phi({\bm{W}}_{K}{\bm{X}}_{D})^{T}\phi({\bm{W}}_{Q}{\bm{x}}^{\prime}_{T+1})+{\bm{1}}_{T}^{T}\phi({\bm{W}}_{K}{\bm{X}}_{T})^{T}\phi({\bm{W}}_{Q}{\bm{x}}^{\prime}_{T+1}) is a constant to normalize the attention scores and 𝑾0=1D(𝑾V𝑿T)ϕ(𝑾K𝑿T)T{\bm{W}}_{0}=\frac{1}{D}({\bm{W}}_{V}{\bm{X}}_{T})\phi({\bm{W}}_{K}{\bm{X}}_{T})^{T}. Furthermore, 𝒉T+1{\bm{h}}^{\prime}_{T+1} will be taken as input for the FFN sublayer and Eq (2) can be rewritten as

𝒙^T+1=𝑾2𝑰M(𝑾1𝒉T+1+𝒃1)+𝒃2=𝑾2𝑰M𝑾1𝒉T+1+𝑾2𝑰M𝒃1+𝒃2,\hat{{\bm{x}}}^{\prime}_{T+1}={\bm{W}}_{2}{\bm{I}}_{M}({\bm{W}}_{1}{\bm{h}}^{\prime}_{T+1}+{\bm{b}}_{1})+{\bm{b}}_{2}={\bm{W}}_{2}{\bm{I}}_{M}{\bm{W}}_{1}{\bm{h}}^{\prime}_{T+1}+{\bm{W}}_{2}{\bm{I}}_{M}{\bm{b}}_{1}+{\bm{b}}_{2},

where 𝑰Mdh×dh{\bm{I}}_{M}\in{\mathbb{R}}^{d_{h}\times d_{h}} is a diagonal matrix whose ii-th diagonal element is one if (𝑾1𝒉T+1+𝒃1)i0({\bm{W}}_{1}{\bm{h}}^{\prime}_{T+1}+{\bm{b}}_{1})_{i}\geq 0 and zero otherwise. We note that this process is reasonable: for given demonstration and query tokens, once the parameters {𝑾Q,𝑾K,𝑾V,𝑾1,𝒃1}\{{\bm{W}}_{Q},{\bm{W}}_{K},{\bm{W}}_{V},{\bm{W}}_{1},{\bm{b}}_{1}\} of the Transformer layer are fixed after training, 𝑰M{\bm{I}}_{M} will be determined implicitly (otherwise, 𝑰M{\bm{I}}_{M} would be a function that varies with these settings). For simplicity, we rewrite 𝒙^T+1\hat{{\bm{x}}}^{\prime}_{T+1} as

𝒙^T+1=𝑾F𝒉T+1+𝒃F,\hat{{\bm{x}}}^{\prime}_{T+1}={\bm{W}}_{F}{\bm{h}}^{\prime}_{T+1}+{\bm{b}}_{F},

where 𝑾F=𝑾2𝑰M𝑾1{\bm{W}}_{F}={\bm{W}}_{2}{\bm{I}}_{M}{\bm{W}}_{1} and 𝒃F=𝑾2𝑰M𝒃1+𝒃2{\bm{b}}_{F}={\bm{W}}_{2}{\bm{I}}_{M}{\bm{b}}_{1}+{\bm{b}}_{2}. Furthermore, expanding 𝒉T+1{\bm{h}}^{\prime}_{T+1} in the above equation, we get:

𝒙^T+1\displaystyle\hat{{\bm{x}}}^{\prime}_{T+1} =𝑾F𝑾0ϕ(𝑾Q𝒙T+1)+[i=1N1D𝑾F𝑾V𝒙iϕ(𝑾K𝒙i)]ϕ(𝑾Q𝒙T+1)+𝒃F\displaystyle={\bm{W}}_{F}{\bm{W}}_{0}\phi\left({\bm{W}}_{Q}{\bm{x}}^{\prime}_{T+1}\right)+\left[\sum_{i=1}^{N}\frac{1}{D}{\bm{W}}_{F}{\bm{W}}_{V}{\bm{x}}_{i}\otimes\phi({\bm{W}}_{K}{\bm{x}}_{i})\right]\phi\left({\bm{W}}_{Q}{\bm{x}}^{\prime}_{T+1}\right)+{\bm{b}}_{F} (23)
=[𝑾F𝑾0+i=1N1D𝑾F𝑾V𝒙iϕ(𝑾K𝒙i)]ϕ(𝑾Q𝒙T+1)+𝒃F.\displaystyle=\left[{\bm{W}}_{F}{\bm{W}}_{0}+\sum_{i=1}^{N}\frac{1}{D}{\bm{W}}_{F}{\bm{W}}_{V}{\bm{x}}_{i}\otimes\phi({\bm{W}}_{K}{\bm{x}}_{i})\right]\phi\left({\bm{W}}_{Q}{\bm{x}}^{\prime}_{T+1}\right)+{\bm{b}}_{F}.

On the other hand, we define a reference model:

f+(𝒙)=𝑾ϕ(𝒙)+𝒃,f^{+}({\bm{x}})={\bm{W}}\phi({\bm{x}})+{\bm{b}},

where ϕ()\phi(\cdot) is exactly the mapping function satisfying Eq (7) to approximate the softmax kernel. Given the loss formed in Eq (21), we can note that the right part in \mathcal{L} is exactly the output of this reference model when taking 𝑾K𝒙i{\bm{W}}_{K}{\bm{x}}_{i} as input, that is,

=1ηDi=1N(𝑾F𝑾V𝒙i)T(𝑾ϕ(𝑾K𝒙i)+𝒃)=1ηDi=1N(𝑾F𝑾V𝒙i)Tf+(𝑾K𝒙i).\mathcal{L}=-\frac{1}{\eta D}\sum_{i=1}^{N}\left({\bm{W}}_{F}{\bm{W}}_{V}{\bm{x}}_{i}\right)^{T}\left({\bm{W}}\phi({\bm{W}}_{K}{\bm{x}}_{i})+{\bm{b}}\right)=-\frac{1}{\eta D}\sum_{i=1}^{N}\left({\bm{W}}_{F}{\bm{W}}_{V}{\bm{x}}_{i}\right)^{T}f^{+}\left({\bm{W}}_{K}{\bm{x}}_{i}\right).

We can calculate the derivative of \mathcal{L} with respect to 𝑾{\bm{W}} as

𝑾=1ηD[i=1N𝑾F𝑾V𝒙iϕ(𝑾K𝒙i)].\frac{\partial\mathcal{L}}{\partial{\bm{W}}}=-\frac{1}{\eta D}\left[\sum_{i=1}^{N}{\bm{W}}_{F}{\bm{W}}_{V}{\bm{x}}_{i}\otimes\phi({\bm{W}}_{K}{\bm{x}}_{i})\right].

Suppose that the weight matrix 𝑾{\bm{W}} in the reference model f+(𝒙)f^{+}({\bm{x}}) is initialized as 𝑾init{\bm{W}}_{init}; then, using one step of stochastic gradient descent (SGD) [Amari, 1993] with learning rate η\eta, the weight matrix 𝑾{\bm{W}} will be updated as

𝑾^=𝑾initη𝑾=𝑾init+[i=1N1D𝑾F𝑾V𝒙iϕ(𝑾K𝒙i)].\widehat{{\bm{W}}}={\bm{W}}_{init}-\eta\frac{\partial\mathcal{L}}{\partial{\bm{W}}}={\bm{W}}_{init}+\left[\sum_{i=1}^{N}\frac{1}{D}{\bm{W}}_{F}{\bm{W}}_{V}{\bm{x}}_{i}\otimes\phi({\bm{W}}_{K}{\bm{x}}_{i})\right].

Compared to Eq (23), we can set 𝑾init=𝑾F𝑾0{\bm{W}}_{init}={\bm{W}}_{F}{\bm{W}}_{0}, 𝒃=𝒃F{\bm{b}}={\bm{b}}_{F} and take 𝑾Q𝒙T+1{\bm{W}}_{Q}{\bm{x}}^{\prime}_{T+1} as test input. Then, after one step update to 𝑾{\bm{W}}, the output of the reference model will be

f+(𝑾Q𝒙T+1)\displaystyle f^{+}({\bm{W}}_{Q}{\bm{x}}^{\prime}_{T+1}) =𝑾^ϕ(𝑾Q𝒙T+1)+𝒃\displaystyle=\widehat{{\bm{W}}}\phi({\bm{W}}_{Q}{\bm{x}}^{\prime}_{T+1})+{\bm{b}}
=[𝑾init+i=1N1D𝑾F𝑾V𝒙iϕ(𝑾K𝒙i)]ϕ(𝑾Q𝒙T+1)+𝒃\displaystyle=\left[{\bm{W}}_{init}+\sum_{i=1}^{N}\frac{1}{D}{\bm{W}}_{F}{\bm{W}}_{V}{\bm{x}}_{i}\otimes\phi({\bm{W}}_{K}{\bm{x}}_{i})\right]\phi({\bm{W}}_{Q}{\bm{x}}^{\prime}_{T+1})+{\bm{b}}
=[𝑾F𝑾0+i=1N1D𝑾F𝑾V𝒙iϕ(𝑾K𝒙i)]ϕ(𝑾Q𝒙T+1)+𝒃F=𝒙^T+1,\displaystyle=\left[{\bm{W}}_{F}{\bm{W}}_{0}+\sum_{i=1}^{N}\frac{1}{D}{\bm{W}}_{F}{\bm{W}}_{V}{\bm{x}}_{i}\otimes\phi({\bm{W}}_{K}{\bm{x}}_{i})\right]\phi\left({\bm{W}}_{Q}{\bm{x}}^{\prime}_{T+1}\right)+{\bm{b}}_{F}=\hat{{\bm{x}}}^{\prime}_{T+1},

which implies that if we initialize the reference model f+(𝒙)=𝑾ϕ(𝒙)+𝒃f^{+}({\bm{x}})={\bm{W}}\phi({\bm{x}})+{\bm{b}} with 𝑾init=𝑾F𝑾0{\bm{W}}_{init}={\bm{W}}_{F}{\bm{W}}_{0}, 𝒃=𝒃F{\bm{b}}={\bm{b}}_{F}, then after one step of gradient descent for 𝑾{\bm{W}}, the test output f+(𝑾Q𝒙T+1)f^{+}({\bm{W}}_{Q}{\bm{x}}^{\prime}_{T+1}) will be identical to the ICL result of one Transformer layer. Thus, we call the reference model with the settings 𝑾init=𝑾F𝑾0{\bm{W}}_{init}={\bm{W}}_{F}{\bm{W}}_{0}, 𝒃=𝒃F{\bm{b}}={\bm{b}}_{F} the dual model corresponding to the ICL inference process. Finally, we complete our proof. ∎

Now, we discuss Theorem B.1 from the following perspectives:

  • Training set and test input: In fact, we can observe that the loss function \mathcal{L} can be seen as the negative sum of inner products of NN vector pairs. In Eq (21), the right vector happens to be the predicted output 𝒚^std(i)=f+(𝒛std(i))=𝑾ϕ(𝒛std(i))+𝒃\hat{{\bm{y}}}_{std}^{(i)}=f^{+}({\bm{z}}_{std}^{(i)})={\bm{W}}\phi({\bm{z}}_{std}^{(i)})+{\bm{b}} of the dual model for training input 𝒛std(i)=𝑾K𝒙i{\bm{z}}_{std}^{(i)}={\bm{W}}_{K}{\bm{x}}_{i}. Correspondingly, the vector on the left can be regarded as the true label 𝒚std(i)=𝑾F𝑾V𝒙i{\bm{y}}_{std}^{(i)}={\bm{W}}_{F}{\bm{W}}_{V}{\bm{x}}_{i}. In other words, the dual model performs one step of SGD on 𝑾{\bm{W}} given the training set {𝒛std(i),𝒚std(i)}i=1N\{{\bm{z}}_{std}^{(i)},{\bm{y}}_{std}^{(i)}\}_{i=1}^{N} using the loss \mathcal{L}:

    =1ηDi=1N(𝒚std(i))T𝒚^std(i).\mathcal{L}=-\frac{1}{\eta D}\sum_{i=1}^{N}\left({\bm{y}}_{std}^{(i)}\right)^{T}\hat{{\bm{y}}}_{std}^{(i)}.

    And then taking 𝒛test=𝑾Q𝒙T+1{\bm{z}}_{test}={\bm{W}}_{Q}{\bm{x}}^{\prime}_{T+1} as test input, it finally outputs the prediction 𝒚test{\bm{y}}_{test}, which achieves the ICL result 𝒙^T+1\hat{{\bm{x}}}^{\prime}_{T+1}. Compared to Theorem 3.1, after introducing the FFN layer, the main difference is that the labels of the training data become 𝒚std(i)=𝑾F𝑾V𝒙i{\bm{y}}_{std}^{(i)}={\bm{W}}_{F}{\bm{W}}_{V}{\bm{x}}_{i} instead of 𝒚std(i)=𝑾V𝒙i{\bm{y}}_{std}^{(i)}={\bm{W}}_{V}{\bm{x}}_{i}. Additionally, compared to f(𝒙)f({\bm{x}}), an extra bias 𝒃{\bm{b}} is introduced in the new dual model f+(𝒙)f^{+}({\bm{x}}), which also has a different initialization 𝑾init=𝑾F𝑾0{\bm{W}}_{init}={\bm{W}}_{F}{\bm{W}}_{0} rather than 𝑾0{\bm{W}}_{0}. We also need to note that in the dual model f+(𝒙)f^{+}({\bm{x}}), only 𝑾{\bm{W}} is trained, while 𝒃{\bm{b}} remains unchanged after initialization.

  • Potential Low-rankness of WF{\bm{W}}_{F}: Noting that 𝑾F=𝑾2𝑰M𝑾1{\bm{W}}_{F}={\bm{W}}_{2}{\bm{I}}_{M}{\bm{W}}_{1} where 𝑾1dh×d{\bm{W}}_{1}\in{\mathbb{R}}^{d_{h}\times d}, 𝑾2d×dh{\bm{W}}_{2}\in{\mathbb{R}}^{d\times d_{h}}, 𝑰Mdh×dh{\bm{I}}_{M}\in{\mathbb{R}}^{d_{h}\times d_{h}} (here we assume that di=do=dd_{i}=d_{o}=d for simplicity), the rank of 𝑾F{\bm{W}}_{F} satisfies

    Rank(𝑾F)min{Rank(𝑾1),Rank(𝑾2),Rank(𝑰M)}.\mathrm{Rank}({\bm{W}}_{F})\leq\min\left\{\mathrm{Rank}({\bm{W}}_{1}),\mathrm{Rank}({\bm{W}}_{2}),\mathrm{Rank}({\bm{I}}_{M})\right\}.

    We observe that 𝑰M{\bm{I}}_{M} is a diagonal matrix with elements being zero or one, whose rank is determined by the number of non-zero elements. Here, we make the mild assumption that Rank(𝑾1)=Rank(𝑾2)=min{d,dh}\mathrm{Rank}({\bm{W}}_{1})=\mathrm{Rank}({\bm{W}}_{2})=\min\{d,d_{h}\}. This assumption is quite mild, since a random matrix is full-rank with probability 1. In addition, we also assume dh>dd_{h}>d, which is consistent with settings in practice. Therefore, we get Rank(𝑾1)=Rank(𝑾2)=d\mathrm{Rank}({\bm{W}}_{1})=\mathrm{Rank}({\bm{W}}_{2})=d, and the upper bound of Rank(𝑾F)\mathrm{Rank}({\bm{W}}_{F}) will be

    Rank(𝑾F)min{d,Rank(𝑰M)}.\mathrm{Rank}({\bm{W}}_{F})\leq\min\left\{d,\mathrm{Rank}({\bm{I}}_{M})\right\}.

    Thus, to avoid losing information, we should strive to maintain Rank(𝑰M)d\mathrm{Rank}({\bm{I}}_{M})\geq d, which is more easily achieved as dhd_{h} becomes larger than dd. Otherwise, Rank(𝑰M)\mathrm{Rank}({\bm{I}}_{M}) is likely to gradually decrease as the number of Transformer layers increases. This explains the necessity of setting dh>dd_{h}>d in practice. In some cases where Rank(𝑰M)<d\mathrm{Rank}({\bm{I}}_{M})<d, meaning that the number of non-zero elements in 𝑰M{\bm{I}}_{M}, or of positive elements in 𝑾1𝒉T+1+𝒃1{\bm{W}}_{1}{\bm{h}}^{\prime}_{T+1}+{\bm{b}}_{1}, is less than dd, the upper bound of Rank(𝑾F)\mathrm{Rank}({\bm{W}}_{F}) will be Rank(𝑰M)\mathrm{Rank}({\bm{I}}_{M}) and the lower bound of Rank(𝑾F)\mathrm{Rank}({\bm{W}}_{F}) will be given as

    Rank(𝑾F)Rank(𝑾2𝑰M)+Rank(𝑰M𝑾1)Rank(𝑰M)=Rank(𝑰M),\mathrm{Rank}({\bm{W}}_{F})\geq\mathrm{Rank}({\bm{W}}_{2}{\bm{I}}_{M})+\mathrm{Rank}({\bm{I}}_{M}{\bm{W}}_{1})-\mathrm{Rank}({\bm{I}}_{M})=\mathrm{Rank}({\bm{I}}_{M}),

    which implies that the rank of 𝑾F{\bm{W}}_{F} will exactly equal Rank(𝑰M)\mathrm{Rank}({\bm{I}}_{M}). We should note that this condition, i.e., Rank(𝑰M)<d\mathrm{Rank}({\bm{I}}_{M})<d, is easily satisfied when dh=dd_{h}=d or when dhd_{h} is slightly larger than dd (for example, d<dh<2dd<d_{h}<2d in an expected sense). Thus, we conclude that 𝑾F{\bm{W}}_{F} has a potential low-rank property; the numerical sketch following Figure 4 checks the affine rewriting and this rank bound.

  • Representation Learning Lens: For an encoded demonstration representation 𝒙i{\bm{x}}_{i}, the key and value projections generate a pair of features 𝑾K𝒙i{\bm{W}}_{K}{\bm{x}}_{i} and 𝑾V𝒙i{\bm{W}}_{V}{\bm{x}}_{i}, creating a certain distance between data representations in space. Then, on the one hand, a potentially low-rank transformation 𝑾F{\bm{W}}_{F} is applied to 𝑾V𝒙i{\bm{W}}_{V}{\bm{x}}_{i}, attempting to compress some information, which increases the difficulty of contrastive learning and forces the model to learn better features; on the other hand, ϕ()\phi(\cdot) projects 𝑾K𝒙i{\bm{W}}_{K}{\bm{x}}_{i} into a higher-dimensional space to capture deeper-level features. Finally, we need to train the weight matrix 𝑾{\bm{W}}, which maps ϕ(𝑾K𝒙i)\phi({\bm{W}}_{K}{\bm{x}}_{i}) back to the original space, aiming to make the mapped vector as close as possible to 𝑾F𝑾V𝒙i{\bm{W}}_{F}{\bm{W}}_{V}{\bm{x}}_{i}. This interpretation is illustrated in Figure 4.

Figure 4: The representation learning process for the ICL inference by one Transformer layer.
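The affine rewriting of the FFN sublayer and the rank bound on 𝑾F{\bm{W}}_{F} discussed above can be checked numerically. Below is a minimal sketch assuming a ReLU FFN (variable names are illustrative): it builds 𝑰M{\bm{I}}_{M} from the sign pattern of 𝑾1𝒉+𝒃1{\bm{W}}_{1}{\bm{h}}+{\bm{b}}_{1} and confirms both the rewriting 𝒙^=𝑾F𝒉+𝒃F\hat{{\bm{x}}}={\bm{W}}_{F}{\bm{h}}+{\bm{b}}_{F} and the bound Rank(𝑾F)min{d,Rank(𝑰M)}\mathrm{Rank}({\bm{W}}_{F})\leq\min\{d,\mathrm{Rank}({\bm{I}}_{M})\}.

```python
import numpy as np

rng = np.random.default_rng(1)
d, dh = 8, 16                           # token dim and FFN hidden dim (dh > d)
W1, b1 = rng.standard_normal((dh, d)), rng.standard_normal(dh)
W2, b2 = rng.standard_normal((d, dh)), rng.standard_normal(d)
h = rng.standard_normal(d)              # the attention output h'_{T+1}

# FFN forward pass with ReLU
ffn_out = W2 @ np.maximum(W1 @ h + b1, 0.0) + b2

# I_M masks the ReLU units that are active for this particular h
IM = np.diag((W1 @ h + b1 >= 0).astype(float))
WF = W2 @ IM @ W1
bF = W2 @ IM @ b1 + b2
print(np.allclose(ffn_out, WF @ h + bF))                       # True

# Rank(I_M) is the number of ones on its diagonal, i.e., its trace
print(np.linalg.matrix_rank(WF) <= min(d, int(np.trace(IM))))  # True
```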

B.2 Extension to Multiple Attention Layers

In this part, we extend Theorem 3.1 to multiple attention layers. Here we adopt the attention layer based on PrefixLM [Roberts et al., 2019], where the query tokens can compute attention with all preceding tokens (including themselves), while the demonstration tokens attend only among themselves, excluding the query tokens. Existing work [Ding et al., 2023] has theoretically and experimentally shown that PrefixLM achieves better results than CausalLM. In this paper, we assume we have only one query token, that is, there is no query input before the considered query token. With the assumption that T=0T=0, to maintain notational simplicity, we use 𝒙N+1{\bm{x}}_{N+1} to represent the query token here instead of 𝒙T+1{\bm{x}}^{\prime}_{T+1}, and the input will be 𝑿=[𝑿D,𝒙N+1]{\bm{X}}=[{\bm{X}}_{D},{\bm{x}}_{N+1}]. We assume that there are LL attention layers and the output of the ll-th layer 𝑿(l){\bm{X}}^{(l)} can be expressed as:

𝑯(l)\displaystyle{\bm{H}}^{(l)} =[𝑯D(l),𝒉N+1(l)]=Atten(𝑯(l1);𝑾Q(l),𝑾K(l),𝑾V(l)),\displaystyle=[{\bm{H}}_{D}^{(l)},{\bm{h}}_{N+1}^{(l)}]=\mathrm{Atten}({\bm{H}}^{(l-1)};~{}{\bm{W}}_{Q}^{(l)},{\bm{W}}_{K}^{(l)},{\bm{W}}_{V}^{(l)}),
𝑯D(l)\displaystyle{\bm{H}}_{D}^{(l)} =𝑾V(l)𝑯D(l1)Softmax((𝑾K(l)𝑯D(l1))T𝑾Q(l)𝑯D(l1)d),\displaystyle={\bm{W}}_{V}^{(l)}{\bm{H}}_{D}^{(l-1)}\mathrm{Softmax}\left(\frac{({\bm{W}}_{K}^{(l)}{\bm{H}}_{D}^{(l-1)})^{T}{\bm{W}}_{Q}^{(l)}{\bm{H}}_{D}^{(l-1)}}{\sqrt{d}}\right),
𝒉N+1(l)\displaystyle{\bm{h}}_{N+1}^{(l)} =𝑾V(l)𝑯(l1)Softmax((𝑾K(l)𝑯(l1))T𝑾Q(l)𝒉N+1(l1)d).\displaystyle={\bm{W}}_{V}^{(l)}{\bm{H}}^{(l-1)}\mathrm{Softmax}\left(\frac{({\bm{W}}_{K}^{(l)}{\bm{H}}^{(l-1)})^{T}{\bm{W}}_{Q}^{(l)}{\bm{h}}_{N+1}^{(l-1)}}{\sqrt{d}}\right).

where we set 𝑯(0)=𝑿=[𝑿D,𝒙N+1]{\bm{H}}^{(0)}={\bm{X}}=[{\bm{X}}_{D},{\bm{x}}_{N+1}] as the initial input. The final output of the query token is 𝒉^N+1(L)=𝒉N+1(L)\hat{{\bm{h}}}_{N+1}^{(L)}={\bm{h}}_{N+1}^{(L)}. Here, we assume that after training, the parameters 𝑾Q(l),𝑾K(l),𝑾V(l)do×di{\bm{W}}_{Q}^{(l)},{\bm{W}}_{K}^{(l)},{\bm{W}}_{V}^{(l)}\in{\mathbb{R}}^{d_{o}\times d_{i}} are fixed, and we set do=di=dd_{o}=d_{i}=d.

Next, we extend Theorem 3.1 to the case of multiple softmax attention layers. Formally, we present our result in the following theorem.

Theorem B.2.

Given LL softmax attention layers whose parameters {𝐖Q(l),𝐖K(l),𝐖V(l)}l=1L\{{\bm{W}}_{Q}^{(l)},{\bm{W}}_{K}^{(l)},{\bm{W}}_{V}^{(l)}\}_{l=1}^{L} are fixed after training, the ICL output of these layers is equivalent to sequentially performing one step gradient descent on a sequence of dual models {f(l)(𝐱)=𝐖(l)ϕ(𝐱)}l=1L\left\{f^{(l)}({\bm{x}})={\bm{W}}^{(l)}\phi({\bm{x}})\right\}_{l=1}^{L}, where the loss function for the ll-th dual model is:

(l)=1ηD(l)i=1N(𝑾V(l)𝒉i(l1))T𝑾(l)ϕ(𝑾K(l)𝒉i(l1)),\mathcal{L}^{(l)}=-\frac{1}{\eta D^{(l)}}\sum_{i=1}^{N}\left({\bm{W}}_{V}^{(l)}{\bm{h}}_{i}^{(l-1)}\right)^{T}{\bm{W}}^{(l)}\phi({\bm{W}}_{K}^{(l)}{\bm{h}}_{i}^{(l-1)}), (24)

where 𝐡i(l1){\bm{h}}_{i}^{(l-1)} is the output of the (l1)(l-1)-th attention layer for the ii-th token, η\eta is the learning rate and D(l)D^{(l)} is a constant. The input for the ll-th dual model is generated by the trained (l1)(l-1)-th dual model.

Figure 5: Illustrating the ICL inference process of multiple softmax attention layers from the perspective of dual models. The layer-wise process of ICL can be viewed as a gradual gradient descent on the dual model sequence. The datasets used for each gradient descent, including training data and test input, are obtained from the outputs of the previous dual model before and after training.
Proof.

Given 𝑯(l1){\bm{H}}^{(l-1)} as the input for the ll-th attention layer, the inference process of 𝒉N+1(l){\bm{h}}_{N+1}^{(l)} is

𝒉N+1(l)\displaystyle{\bm{h}}_{N+1}^{(l)} =𝑾V(l)𝑯(l1)Softmax((𝑾K(l)𝑯(l1))T𝑾Q(l)𝒉N+1(l1)d)\displaystyle={\bm{W}}_{V}^{(l)}{\bm{H}}^{(l-1)}\mathrm{Softmax}\left(\frac{({\bm{W}}_{K}^{(l)}{\bm{H}}^{(l-1)})^{T}{\bm{W}}_{Q}^{(l)}{\bm{h}}_{N+1}^{(l-1)}}{\sqrt{d}}\right)
=𝑾0(l)ϕ(𝑾Q(l)𝒉N+1(l1))+[i=1N1D(l)𝑾V(l)𝒉i(l1)ϕ(𝑾K(l)𝒉i(l1))]ϕ(𝑾Q(l)𝒉N+1(l1)),\displaystyle={\bm{W}}_{0}^{(l)}\phi\left({\bm{W}}_{Q}^{(l)}{\bm{h}}_{N+1}^{(l-1)}\right)+\left[\sum_{i=1}^{N}\frac{1}{D^{(l)}}{\bm{W}}_{V}^{(l)}{\bm{h}}_{i}^{(l-1)}\otimes\phi({\bm{W}}_{K}^{(l)}{\bm{h}}_{i}^{(l-1)})\right]\phi\left({\bm{W}}_{Q}^{(l)}{\bm{h}}_{N+1}^{(l-1)}\right),

where D(l)=𝟏N+1Tϕ(𝑾K(l)𝑯(l1))Tϕ(𝑾Q(l)𝒉N+1(l1))D^{(l)}={\bm{1}}_{N+1}^{T}\phi({\bm{W}}_{K}^{(l)}{\bm{H}}^{(l-1)})^{T}\phi({\bm{W}}_{Q}^{(l)}{\bm{h}}_{N+1}^{(l-1)}) is a constant to normalize the attention scores and 𝑾0(l)=1D(l)(𝑾V(l)𝒉N+1(l1))ϕ(𝑾K(l)𝒉N+1(l1))T{\bm{W}}_{0}^{(l)}=\frac{1}{D^{(l)}}({\bm{W}}_{V}^{(l)}{\bm{h}}_{N+1}^{(l-1)})\phi({\bm{W}}_{K}^{(l)}{\bm{h}}_{N+1}^{(l-1)})^{T}. According to Theorem 3.1, we can easily get the dual model finit(l)(𝒉)=𝑾init(l)ϕ(𝒉)f^{(l)}_{init}({\bm{h}})={\bm{W}}^{(l)}_{init}\phi({\bm{h}}) where the initialization is 𝑾init(l)=𝑾0(l){\bm{W}}^{(l)}_{init}={\bm{W}}_{0}^{(l)}. Given the loss function (l)\mathcal{L}^{(l)} formed as Eq (24) and the training set {𝒛std(i),𝒚std(i)}i=1N\left\{{\bm{z}}_{std}^{(i)},{\bm{y}}_{std}^{(i)}\right\}_{i=1}^{N} where 𝒛std(i)=𝑾K(l)𝒉i(l1){\bm{z}}_{std}^{(i)}={\bm{W}}_{K}^{(l)}{\bm{h}}_{i}^{(l-1)} and 𝒚std(i)=𝑾V(l)𝒉i(l1){\bm{y}}_{std}^{(i)}={\bm{W}}_{V}^{(l)}{\bm{h}}_{i}^{(l-1)}, we perform one step of SGD with learning rate η\eta on the weight matrix 𝑾(l){\bm{W}}^{(l)} and obtain the trained dual model:

f^(l)(𝒙)\displaystyle\hat{f}^{(l)}({\bm{x}}) =𝑾^(l)ϕ(𝒙)=(𝑾init(l)+Δ𝑾(l))ϕ(𝒙)\displaystyle=\widehat{{\bm{W}}}^{(l)}\phi({\bm{x}})=({\bm{W}}^{(l)}_{init}+\Delta{\bm{W}}^{(l)})\phi({\bm{x}})
=[𝑾0(l)+i=1N1D(l)𝑾V(l)𝒉i(l1)ϕ(𝑾K(l)𝒉i(l1))]ϕ(𝒙)\displaystyle=\left[{\bm{W}}_{0}^{(l)}+\sum_{i=1}^{N}\frac{1}{D^{(l)}}{\bm{W}}_{V}^{(l)}{\bm{h}}_{i}^{(l-1)}\otimes\phi({\bm{W}}_{K}^{(l)}{\bm{h}}_{i}^{(l-1)})\right]\phi({\bm{x}})

Taking 𝒛test(l)=𝑾Q(l)𝒉N+1(l1){\bm{z}}^{(l)}_{test}={\bm{W}}_{Q}^{(l)}{\bm{h}}_{N+1}^{(l-1)} as test input, the prediction f^(l)(𝒛test(l))\hat{f}^{(l)}({\bm{z}}^{(l)}_{test}) will exactly equal 𝒉N+1(l){\bm{h}}_{N+1}^{(l)}.

Next, we will show how to obtain 𝑯D(l){\bm{H}}_{D}^{(l)} through the trained dual model f^(l)(𝒙)\hat{f}^{(l)}({\bm{x}}). After the 𝑾K(l+1),𝑾Q(l+1),𝑾V(l+1){\bm{W}}_{K}^{(l+1)},{\bm{W}}_{Q}^{(l+1)},{\bm{W}}_{V}^{(l+1)} projections, 𝑯D(l){\bm{H}}_{D}^{(l)} will constitute the training set as well as the test input for the next dual model finit(l+1)(𝒙)f^{(l+1)}_{init}({\bm{x}}).

Keeping the initialized dual model finit(l)(𝒙)f^{(l)}_{init}({\bm{x}}) and the trained one f^(l)(𝒙)\hat{f}^{(l)}({\bm{x}}) in mind, we can compute the demonstration token output 𝒉i(l){\bm{h}}^{(l)}_{i} (i=1,2,,Ni=1,2,...,N) of the ll-th attention layer as

\begin{aligned}
{\bm{h}}^{(l)}_{i} &= {\bm{W}}_{V}^{(l)}{\bm{H}}_{D}^{(l-1)}\mathrm{Softmax}\left(\frac{({\bm{W}}_{K}^{(l)}{\bm{H}}_{D}^{(l-1)})^{T}{\bm{W}}_{Q}^{(l)}{\bm{h}}_{i}^{(l-1)}}{\sqrt{d}}\right) \\
&= \left[\sum_{j=1}^{N}\frac{1}{D^{(l)}_{i}}{\bm{W}}_{V}^{(l)}{\bm{h}}_{j}^{(l-1)}\otimes\phi({\bm{W}}_{K}^{(l)}{\bm{h}}_{j}^{(l-1)})\right]\phi\left({\bm{W}}_{Q}^{(l)}{\bm{h}}_{i}^{(l-1)}\right) \\
&= \frac{D^{(l)}}{D^{(l)}_{i}}\left[\sum_{j=1}^{N}\frac{1}{D^{(l)}}{\bm{W}}_{V}^{(l)}{\bm{h}}_{j}^{(l-1)}\otimes\phi({\bm{W}}_{K}^{(l)}{\bm{h}}_{j}^{(l-1)})\right]\phi\left({\bm{W}}_{Q}^{(l)}{\bm{h}}_{i}^{(l-1)}\right) \\
&= \frac{D^{(l)}}{D^{(l)}_{i}}\left[{\bm{W}}_{init}^{(l)}+\sum_{j=1}^{N}\frac{1}{D^{(l)}}{\bm{W}}_{V}^{(l)}{\bm{h}}_{j}^{(l-1)}\otimes\phi({\bm{W}}_{K}^{(l)}{\bm{h}}_{j}^{(l-1)})-{\bm{W}}_{init}^{(l)}\right]\phi\left({\bm{W}}_{Q}^{(l)}{\bm{h}}_{i}^{(l-1)}\right) \\
&= \frac{D^{(l)}}{D^{(l)}_{i}}\left[\widehat{{\bm{W}}}^{(l)}-{\bm{W}}_{init}^{(l)}\right]\phi\left({\bm{W}}_{Q}^{(l)}{\bm{h}}_{i}^{(l-1)}\right)=\frac{D^{(l)}}{D^{(l)}_{i}}\left[\hat{f}^{(l)}({\bm{W}}_{Q}^{(l)}{\bm{h}}_{i}^{(l-1)})-f^{(l)}_{init}({\bm{W}}_{Q}^{(l)}{\bm{h}}_{i}^{(l-1)})\right],
\end{aligned} \tag{25}

where Di(l)=𝟏NTϕ(𝑾K(l)𝑯D(l1))Tϕ(𝑾Q(l)𝒉i(l1))D^{(l)}_{i}={\bm{1}}_{N}^{T}\phi({\bm{W}}_{K}^{(l)}{\bm{H}}_{D}^{(l-1)})^{T}\phi({\bm{W}}_{Q}^{(l)}{\bm{h}}_{i}^{(l-1)}) is a constant to normalize the attention scores for 𝒉i(l){\bm{h}}^{(l)}_{i}. Therefore, once we obtain the trained dual model f^(l)(𝒉)\hat{f}^{(l)}({\bm{h}}), we can use Eq (25) to get the demonstration token outputs 𝑯D(l)=[𝒉1(l),𝒉2(l),,𝒉N(l)]{\bm{H}}_{D}^{(l)}=[{\bm{h}}_{1}^{(l)},{\bm{h}}_{2}^{(l)},...,{\bm{h}}_{N}^{(l)}]. These demonstration token outputs, along with the output 𝒉N+1(l){\bm{h}}_{N+1}^{(l)} for the query token, will together constitute the training set and test input for the next dual model finit(l+1)(𝒉)f^{(l+1)}_{init}({\bm{h}}). This process continues layer by layer until we obtain the ultimate ICL output 𝒉N+1(L){\bm{h}}_{N+1}^{(L)}. In summary, the ICL inference process across LL attention layers is equivalent to performing gradient descent on LL dual models sequentially. Thus, we complete our proof. ∎

This theorem is a natural extension of Theorem 3.1: when considering the stacking of multiple attention layers, a sequence of dual models is correspondingly generated. Although these dual models share the same form f(𝒙)=𝑾ϕ(𝒙)f({\bm{x}})={\bm{W}}\phi({\bm{x}}), they have different initializations and datasets. As the ICL inference process progresses layer by layer through the attention layers, we equivalently perform gradient descent on the dual models one by one. The input 𝑯(l){\bm{H}}^{(l)} for each attention layer, including demonstration tokens and query tokens, can be obtained from the test outputs of the dual models. This process is illustrated in Figure 5.
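Under the same random-feature stand-in for ϕ\phi as in Appendix A (again an illustrative assumption), the layer-by-layer correspondence of Theorem B.2 can be simulated directly: each layer's dual model is "trained" by a single rank update and then queried, while Eq (25) recovers the demonstration outputs fed to the next layer. A minimal sketch with T=0T=0 and random fixed weights per layer:

```python
import numpy as np

rng = np.random.default_rng(2)
d, N, m, L = 6, 4, 64, 3    # token dim, #demos, #features, #attention layers
Omega = rng.standard_normal((m, d))
phi = lambda Z: np.exp(Omega @ Z - 0.5 * np.sum(Z * Z, axis=0)) / np.sqrt(m)

H = rng.standard_normal((d, N + 1))         # H^{(0)} = [X_D, x_{N+1}]
for l in range(L):
    WQ, WK, WV = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(3))
    K, Q, V = phi(WK @ H), phi(WQ @ H), WV @ H

    # Dual model for layer l: W_0 from the query token, one GD step from demos
    Dl = float(np.sum(K.T @ Q[:, -1]))                 # D^{(l)}
    W_init = np.outer(V[:, -1], K[:, -1]) / Dl         # W_0^{(l)}
    dW = (V[:, :N] @ K[:, :N].T) / Dl                  # one-step update on L^{(l)}
    h_query = (W_init + dW) @ Q[:, -1]                 # = h^{(l)}_{N+1}

    # Demonstration outputs via Eq (25): rescale the update term per token
    Di = np.sum(K[:, :N].T @ Q[:, :N], axis=0)         # D_i^{(l)}, i = 1..N
    H_demo = (dW @ Q[:, :N]) * (Dl / Di)               # trained minus initialized dual model

    H = np.column_stack([H_demo, h_query])             # input for layer l+1
print(H[:, -1])                                        # final ICL output h^{(L)}_{N+1}
```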

Appendix C Proof of the Generalization bound

C.1 Proof of Theorem 3.2

In this part, we provide the proof of the generalization bound in Theorem 3.2. We restate our theorem as follows:

Theorem C.1.

Define the function class as :={f(𝐱)=𝐖ϕ(𝐖K𝐱)|𝐖w}\mathcal{F}:=\left\{f({\bm{x}})={\bm{W}}\phi({\bm{W}}_{K}{\bm{x}})~{}|~{}\|{\bm{W}}\|\leq w\right\} and let the loss function be defined as in Eq (10). Consider the given demonstration set 𝒮={𝐱i}i=1N\mathcal{S}=\{{\bm{x}}_{i}\}_{i=1}^{N} where 𝒮𝒮𝒯\mathcal{S}\subseteq\mathcal{S}_{\mathcal{T}} and 𝒮𝒯\mathcal{S}_{\mathcal{T}} is the set of all possible demonstration tokens for some task 𝒯\mathcal{T}. Under the assumption that 𝐖V𝐱i,𝐖ϕ(𝐖K𝐱i)ρ\|{\bm{W}}_{V}{\bm{x}}_{i}\|,\|{\bm{W}}\phi({\bm{W}}_{K}{\bm{x}}_{i})\|\leq\rho, for any δ>0\delta>0, the following statement holds with probability at least 1δ1-\delta for any ff\in\mathcal{F}

(f^)(f)+O(wρdoTr(𝑲𝒮)N+log1δN).\mathcal{L}(\hat{f})\leq\mathcal{L}(f)+O\left(\frac{w\rho d_{o}\sqrt{\mathrm{Tr}({\bm{K}}_{\mathcal{S}})}}{N}+\sqrt{\frac{log\frac{1}{\delta}}{N}}\right). (26)
Proof.

Our proof is similar to that of Lemma 4.2 in Saunshi et al. [2019], but here we focus on a different function class. Firstly, we consider the classical generalization bound based on the Rademacher complexity of the function class (see Theorem 3.1 in Mohri et al. [2018]). For a real function class GG whose functions map from a set ZZ to [0,1][0,1] and for any δ>0\delta>0, if 𝒮{\mathcal{S}} is a training set composed of NN i.i.d. samples {𝒙i}i=1N\{{\bm{x}}_{i}\}_{i=1}^{N}, then with probability at least 1δ21-\frac{\delta}{2}, for all gGg\in G

𝔼[g(𝒙)]1Ni=1Ng(𝒙i)+2𝒮(G)N+3log4δ2N\mathbb{E}\left[g({\bm{x}})\right]\leq\frac{1}{N}\sum_{i=1}^{N}g({\bm{x}}_{i})+\frac{2{\mathcal{R}}_{{\mathcal{S}}}(G)}{N}+3\sqrt{\frac{\log\frac{4}{\delta}}{2N}} (27)

where 𝒮(G){\mathcal{R}}_{{\mathcal{S}}}(G) is the traditional Rademacher complexity. By taking 𝒮{\mathcal{S}} to be exactly the demonstration set and G={gf(𝒙)=(𝑾V𝒙)T𝑾ϕ(𝑾K𝒙)|𝑾w}G=\left\{g_{f}({\bm{x}})=-\left({\bm{W}}_{V}{\bm{x}}\right)^{T}{\bm{W}}\phi({\bm{W}}_{K}{\bm{x}})\big{|}\|{\bm{W}}\|\leq w\right\}, we can apply this bound to our case.

Then, we construct a function class ~={f~(𝒙)=[f(𝒙);𝑾V𝒙]=[𝑾ϕ(𝑾K𝒙);𝑾V𝒙]|𝑾w}\tilde{{\mathcal{F}}}=\left\{\tilde{f}({\bm{x}})=[f({\bm{x}});{\bm{W}}_{V}{\bm{x}}]=[{\bm{W}}\phi({\bm{W}}_{K}{\bm{x}});{\bm{W}}_{V}{\bm{x}}]\big{|}\|{\bm{W}}\|\leq w\right\} whose functions map from 𝒮{\mathcal{S}} to 2do{\mathbb{R}}^{2d_{o}}. Next, we first prove 𝒮(G)2ρ𝒮(~){\mathcal{R}}_{{\mathcal{S}}}(G)\leq 2\rho{\mathcal{R}}_{{\mathcal{S}}}(\tilde{{\mathcal{F}}}); to do this, we need the following lemma:

Lemma C.2 (Corollary 4 in Maurer [2016]).

Let ZZ be any set, and 𝒮={𝐳i}i=1MZM{\mathcal{S}}=\{{\bm{z}}_{i}\}_{i=1}^{M}\in Z^{M}. Let ~\tilde{{\mathcal{F}}} be a class of functions f~:Zn\tilde{f}:Z\to{\mathbb{R}}^{n} and h:nh:{\mathbb{R}}^{n}\to{\mathbb{R}} be LL-Lipschitz. For all f~~\tilde{f}\in\tilde{{\mathcal{F}}}, let gf~=hf~g_{\tilde{f}}=h\circ\tilde{f}. Then

𝔼σ{±1}M[supf~F~σ,(gf~|𝒮)]2L𝔼σ{±1}nM[supf~F~σ,(f~|𝒮)]\mathop{\mathbb{E}}\limits_{\sigma\sim\{\pm 1\}^{M}}\left[\sup_{\tilde{f}\in\tilde{F}}\langle\sigma,(g_{\tilde{f}_{|{\mathcal{S}}}})\rangle\right]\leq\sqrt{2}L\mathop{\mathbb{E}}\limits_{\sigma\sim\{\pm 1\}^{nM}}\left[\sup_{\tilde{f}\in\tilde{F}}\langle\sigma,(\tilde{f}_{|{\mathcal{S}}})\rangle\right] (28)

where f~|𝒮=(f~t(𝐳j))t[n],j[M]\tilde{f}_{|{\mathcal{S}}}=\left(\tilde{f}_{t}({\bm{z}}_{j})\right)_{t\in[n],j\in[M]}.

We apply Lemma C.2 to our case by setting Z=diZ={\mathbb{R}}^{d_{i}}, 𝒮{\mathcal{S}} to be exactly the demonstration set, ~\tilde{{\mathcal{F}}} to be the function class we constructed, and n=2don=2d_{o}. We also use h:2doh:{\mathbb{R}}^{2d_{o}}\to{\mathbb{R}} where h(𝒙)=𝒙1:do,𝒙do+1:2doh({\bm{x}})=-\langle{\bm{x}}_{1:d_{o}},{\bm{x}}_{d_{o}+1:2d_{o}}\rangle and thus we have gf~(𝒙)=h(f~(𝒙))=h([f(𝒙);𝑾V𝒙])=(𝑾V𝒙)T𝑾ϕ(𝑾K𝒙)g_{\tilde{f}}({\bm{x}})=h(\tilde{f}({\bm{x}}))=h([f({\bm{x}});{\bm{W}}_{V}{\bm{x}}])=-\left({\bm{W}}_{V}{\bm{x}}\right)^{T}{\bm{W}}\phi({\bm{W}}_{K}{\bm{x}}). We can find that gf(𝒙)=gf~(𝒙)g_{f}({\bm{x}})=g_{\tilde{f}}({\bm{x}}) and the left side of inequality (28) is exactly 𝒮(G){\mathcal{R}}_{{\mathcal{S}}}(G).

Then, under the assumption that 𝑾V𝒙i,𝑾ϕ(𝑾K𝒙i)ρ\|{\bm{W}}_{V}{\bm{x}}_{i}\|,\|{\bm{W}}\phi({\bm{W}}_{K}{\bm{x}}_{i})\|\leq\rho, the gradient of hh at 𝒙{\bm{x}} is -[{\bm{x}}_{d_{o}+1:2d_{o}};{\bm{x}}_{1:d_{o}}], whose norm is at most \sqrt{\rho^{2}+\rho^{2}}=\sqrt{2}\rho, so hh is 2ρ\sqrt{2}\rho-Lipschitz and we have 𝒮(G)2ρ𝒮(~){\mathcal{R}}_{{\mathcal{S}}}(G)\leq 2\rho{\mathcal{R}}_{{\mathcal{S}}}(\tilde{{\mathcal{F}}}). Now using Lemma C.2 and the classical generalization bound (27), we have that with probability at least 1δ21-\frac{\delta}{2}

(f^)^(f^)+O(ρ𝒮(F~)N+log1δN),\mathcal{L}(\hat{f})\leq\hat{\mathcal{L}}(\hat{f})+O\left(\frac{\rho{\mathcal{R}}_{{\mathcal{S}}}(\tilde{F})}{N}+\sqrt{\frac{log\frac{1}{\delta}}{N}}\right), (29)

Let fargminf(f)f^{*}\in\operatorname*{arg\,min}_{f\in{\mathcal{F}}}{\mathcal{L}}(f). According to Hoeffding’s inequality, with probability at least 1δ21-\frac{\delta}{2}, we have that ^(f)(f)+3log2δ2N\hat{{\mathcal{L}}}(f^{*})\leq{\mathcal{L}}(f^{*})+3\sqrt{\frac{\log\frac{2}{\delta}}{2N}}. Combining this with (29), the fact that ^(f^)^(f)\hat{{\mathcal{L}}}(\hat{f})\leq\hat{{\mathcal{L}}}(f^{*}) and applying a union bound, we can get that

(f^)(f)+O(ρ𝒮(F~)N+log1δN).\mathcal{L}(\hat{f})\leq{\mathcal{L}}(f)+O\left(\frac{\rho{\mathcal{R}}_{{\mathcal{S}}}(\tilde{F})}{N}+\sqrt{\frac{log\frac{1}{\delta}}{N}}\right). (30)

Next, we give the upper bound for 𝒮(~){\mathcal{R}}_{{\mathcal{S}}}(\tilde{{\mathcal{F}}}).

\begin{aligned}
{\mathcal{R}}_{{\mathcal{S}}}(\tilde{{\mathcal{F}}}) &=\mathop{\mathbb{E}}\limits_{\sigma\sim\{\pm 1\}^{2Nd_{o}}}\left[\sup_{\|{\bm{W}}\|\leq w}\sum_{t=1}^{2Nd_{o}}\sigma_{t}(\tilde{f}_{|{\mathcal{S}}})_{t}\right]\quad\text{(definition of Rademacher complexity)}\\
&=\mathop{\mathbb{E}}\limits_{\sigma\sim\{\pm 1\}^{Nd_{o}}}\left[\sup_{\|{\bm{W}}\|\leq w}\sum_{j=1}^{d_{o}}{\bm{W}}_{j}\sum_{i=1}^{N}\sigma_{i,j}\phi({\bm{W}}_{K}{\bm{x}}_{i})\right]\quad\text{(${\bm{W}}_{V}{\bm{x}}_{i}$ is independent of ${\bm{W}}$)}\\
&\leq\mathop{\mathbb{E}}\limits_{\sigma\sim\{\pm 1\}^{Nd_{o}}}\left[\sup_{\|{\bm{W}}\|\leq w}\sum_{j=1}^{d_{o}}\left\|{\bm{W}}_{j}\right\|\left\|\sum_{i=1}^{N}\sigma_{i,j}\phi({\bm{W}}_{K}{\bm{x}}_{i})\right\|\right]\quad\text{(by the Cauchy--Schwarz inequality)}\\
&\leq wd_{o}\mathop{\mathbb{E}}\limits_{\sigma\sim\{\pm 1\}^{N}}\left[\left\|\sum_{i=1}^{N}\sigma_{i}\phi({\bm{W}}_{K}{\bm{x}}_{i})\right\|\right]\quad\text{(using the fact that $\|{\bm{W}}_{j}\|\leq w$)}\\
&\leq wd_{o}\sqrt{\mathbb{E}_{\sigma\sim\{\pm 1\}^{N}}\left[\left\|\sum_{i=1}^{N}\sigma_{i}\phi({\bm{W}}_{K}{\bm{x}}_{i})\right\|^{2}\right]}\quad\text{(by Jensen's inequality)}\\
&=wd_{o}\sqrt{\mathrm{Tr}({\bm{K}}_{\mathcal{S}})}.
\end{aligned}

Substituting the upper bound of 𝒮(~){\mathcal{R}}_{{\mathcal{S}}}(\tilde{{\mathcal{F}}}) into (30), we will get that

(f^)(f)+O(wρdoTr(𝑲𝒮)N+log1δN).\mathcal{L}(\hat{f})\leq\mathcal{L}(f)+O\left(\frac{w\rho d_{o}\sqrt{\mathrm{Tr}({\bm{K}}_{\mathcal{S}})}}{N}+\sqrt{\frac{log\frac{1}{\delta}}{N}}\right). (31)

Thus we finish our proof. ∎

C.2 Extension to Negative Samples

One may also wonder whether the ratio of negative samples mentioned in Section 4 affects the generalization bounds. In fact, after introducing negative samples and ignoring the constant term in Eq (9), we consider the following representation loss:

(f)=𝔼x𝒟𝒯[1Kj=1K(𝑾ϕ(𝑾Kx))T(𝑾Vx𝑾Vxj)],\mathcal{L}(f)=\mathbb{E}_{x\sim\mathcal{D}_{\mathcal{T}}}\left[-\frac{1}{K}\sum_{j=1}^{K}\left({\bm{W}}\phi({\bm{W}}_{K}x)\right)^{T}\left({\bm{W}}_{V}x-{\bm{W}}_{V}x^{-}_{j}\right)\right],

where we sample KK negative samples for each token and xjx_{j}^{-} denotes the jj-th negative sample for token xx. Correspondingly, the empirical loss is ^(f)=1Ni=1N1Kj=1K(𝑾ϕ(𝑾Kxi))T(𝑾Vxi𝑾Vxij)\hat{\mathcal{L}}(f)=-\frac{1}{N}\sum_{i=1}^{N}\frac{1}{K}\sum_{j=1}^{K}\left({\bm{W}}\phi({\bm{W}}_{K}x_{i})\right)^{T}\left({\bm{W}}_{V}x_{i}-{\bm{W}}_{V}x^{-}_{ij}\right), where xijx_{ij}^{-} is the jj-th negative sample for xix_{i}. Then, by retaining the other definitions in Section 3.3, corresponding to Theorem 3.2, we can obtain the generalization bound as

(f^)(f)+O(wρdoTr(KS)(5N2+1rN3)+log1δN),\mathcal{L}(\hat{f})\leq\mathcal{L}(f)+O\left(w\rho d_{o}\sqrt{\mathrm{Tr}(K_{S})\left(\frac{5}{N^{2}}+\frac{1}{rN^{3}}\right)}+\sqrt{\frac{log\frac{1}{\delta}}{N}}\right),

where r=KNr=\frac{K}{N} is exactly the ratio of the number of negative samples to the number of demonstrations. It can be observed that as the ratio of negative samples increases, the generalization error decreases. However, we also notice that 5N2>1rN3\frac{5}{N^{2}}>\frac{1}{rN^{3}} since rN=K\geq 1, thus the former term dominates, which means the reduction in generalization error due to an increased proportion of negative samples is limited. Nevertheless, we do not rule out the possibility of a tighter generalization bound, which is a promising direction for future research.

Proof Sketch.

The proof process is similar to that of Theorem 3.2. The main difference lies in the fact that we first define the function class G={1Kj=1K(𝑾ϕ(𝑾Kxi))T(𝑾Vxi𝑾Vxj)|𝑾w}G=\left\{-\frac{1}{K}\sum_{j=1}^{K}\left({\bm{W}}\phi({\bm{W}}_{K}x_{i})\right)^{T}\left({\bm{W}}_{V}x_{i}-{\bm{W}}_{V}x^{-}_{j}\right)\big{|}\|{\bm{W}}\|\leq w\right\} in order to use the classical bound. In addition, we define F~={f~(x)=[f(x);𝑾Vx;𝑾Vx1;;𝑾VxK]|𝑾w}\tilde{F}=\left\{\tilde{f}(x)=[f(x);{\bm{W}}_{V}x;{\bm{W}}_{V}x^{-}_{1};...;{\bm{W}}_{V}x^{-}_{K}]\big{|}\|{\bm{W}}\|\leq w\right\} whose functions map from 𝒮{\mathcal{S}} to (K+2)do\mathbb{R}^{(K+2)d_{o}}. Similarly, when using Lemma C.2, we set Z=diZ=\mathbb{R}^{d_{i}}, let F~\tilde{F} be the above function class, and take n=(K+2)don=(K+2)d_{o}. We also use h:(K+2)doh:\mathbb{R}^{(K+2)d_{o}}\rightarrow\mathbb{R} defined as h(x)=1Kj=1Kx1:doT(xdo+1:2dox(j+1)do+1:(j+2)do)h(x)=-\frac{1}{K}\sum_{j=1}^{K}x_{1:d_{o}}^{T}(x_{d_{o}+1:2d_{o}}-x_{(j+1)d_{o}+1:(j+2)d_{o}}). Then we notice that

\frac{\partial h}{\partial x_{1:d_{o}}}=-\frac{1}{K}\sum_{j=1}^{K}(x_{d_{o}+1:2d_{o}}-x_{(j+1)d_{o}+1:(j+2)d_{o}}), (32)
\frac{\partial h}{\partial x_{d_{o}+1:2d_{o}}}=-x_{1:d_{o}},~~\frac{\partial h}{\partial x_{(j+1)d_{o}+1:(j+2)d_{o}}}=\frac{1}{K}x_{1:d_{o}}.

With the assumption that WVx,Wϕ(WKx)ρ\|W_{V}x\|,\|W\phi(W_{K}x)\|\leq\rho, we can get that the Frobenius norm of the Jacobian JJ of hh satisfies JF24ρ2+ρ2+KK2ρ2=(5+1K)ρ2\|J\|_{F}^{2}\leq 4\rho^{2}+\rho^{2}+\frac{K}{K^{2}}\rho^{2}=(5+\frac{1}{K})\rho^{2}. Thus we get that hh is 5+1Kρ\sqrt{5+\frac{1}{K}}\rho-Lipschitz. The rest of the proof is similar to that of Theorem 3.2. Ultimately, we will obtain the aforementioned generalization bound. ∎

Appendix D Details and More Discussions for Section 4

In this section, we provide a more detailed discussion on improving the model structure from the perspective of representation learning, especially contrastive learning, which is presented in Section 4 of the main body. We also point out the corresponding modifications to the self-attention mechanism, which are adopted in our experiments.

D.1 More Discussion on the Contrastive Loss

Although we have figured out the representation learning loss of the implicit gradient updates, it can be observed that this loss function has a flaw: due to the lack of normalization for 𝒚std(i){\bm{y}}_{std}^{(i)} and 𝒚^(i)\hat{{\bm{y}}}^{(i)} when calculating the cosine distance, the loss can theoretically be optimized to negative infinity. To address this issue, we introduce regularization to constrain the norm of 𝑾{\bm{W}}, that is,

=1ηDi=1N(𝑾V𝒙i)T𝑾ϕ(𝑾K𝒙i)+α2η𝑾F2,\mathcal{L}=-\frac{1}{\eta D}\sum_{i=1}^{N}\left({\bm{W}}_{V}{\bm{x}}_{i}\right)^{T}{\bm{W}}\phi({\bm{W}}_{K}{\bm{x}}_{i})+\frac{\alpha}{2\eta}\|{\bm{W}}\|_{F}^{2},

where α\alpha is a hyperparameter to balance the two parts. As a result, the gradient update for 𝑾{\bm{W}} will proceed in an exponentially smoothed manner, meaning that a portion of the previous weights is discarded at every step, that is,

𝑾(t)=𝑾(t1)η𝑾=(1α)𝑾(t1)+i=1ND1𝑾V𝒙iϕ(𝑾K𝒙i).{\bm{W}}^{(t)}={\bm{W}}^{(t-1)}-\eta\frac{\partial\mathcal{L}}{\partial{\bm{W}}}=(1-\alpha){\bm{W}}^{(t-1)}+\sum_{i=1}^{N}D^{-1}{\bm{W}}_{V}{\bm{x}}_{i}\otimes\phi({\bm{W}}_{K}{\bm{x}}_{i}).

Equivalently, the inference process of ICL can be seen as the first step of the aforementioned update, and the attention mechanism will be correspondingly adjusted as,

𝒉T+1=(1α)𝑾0ϕ(𝒒)+D1[i=1N𝑽D(i)ϕ(𝑲D(i))]ϕ(𝒒),\displaystyle{\bm{h}}^{\prime}_{T+1}=(1-\alpha){\bm{W}}_{0}\phi({\bm{q}})+D^{-1}\left[\sum_{i=1}^{N}{\bm{V}}_{D}^{(i)}\otimes\phi({\bm{K}}_{D}^{(i)})\right]\phi({\bm{q}}),

which means more demonstration information will be attended to. This directly results in Eq (13).

This result can be easily extended to the self-attention mechanism. For a self-attention layer, if all other tokens adopt the same modification, the self-attention layer will become

𝑯\displaystyle{\bm{H}} =𝑾V𝑿softmax((𝑾K𝑿)T𝑾Q𝑿do)α𝑾V𝑿\displaystyle={\bm{W}}_{V}{\bm{X}}\mathrm{softmax}\left(\frac{({\bm{W}}_{K}{\bm{X}})^{T}{\bm{W}}_{Q}{\bm{X}}}{\sqrt{d_{o}}}\right)-\alpha{\bm{W}}_{V}{\bm{X}}
=𝑾V𝑿[softmax((𝑾K𝑿)T𝑾Q𝑿do)α𝑰],\displaystyle={\bm{W}}_{V}{\bm{X}}\left[\mathrm{softmax}\left(\frac{({\bm{W}}_{K}{\bm{X}})^{T}{\bm{W}}_{Q}{\bm{X}}}{\sqrt{d_{o}}}\right)-\alpha{\bm{I}}\right],

which leads to the model structure incorporating an operation similar to skip connections. Furthermore, to ensure numerical stability, we normalize the attention scores yielding:

𝑯=𝑾V𝑿Norm(softmax((𝑾K𝑿)T𝑾Q𝑿do)α𝑰),{\bm{H}}={\bm{W}}_{V}{\bm{X}}\cdot\mathrm{Norm}\left(\mathrm{softmax}\left(\frac{({\bm{W}}_{K}{\bm{X}})^{T}{\bm{W}}_{Q}{\bm{X}}}{\sqrt{d_{o}}}\right)-\alpha{\bm{I}}\right),

where Norm()\mathrm{Norm}(\cdot) is performed column-wise to ensure that the attention scores sum to 11. The above modification reduces the attention score each token assigns to its own information during aggregation. It is worth noting that, although our initial intention is to impose regularization on the contrastive loss with α>0\alpha>0 to prevent it from diverging to negative infinity, we find in experiments that this modification remains effective even when α\alpha is less than 0. We interpret this as possibly stemming from the fact that an appropriate α\alpha helps the attention block become full-rank, thereby better preserving information, as illustrated by Lemma D.1:

Lemma D.1.

Let the attention block 𝐀n×n{\bm{A}}\in{\mathbb{R}}^{n\times n}. There exists some δ>0\delta>0 such that, for any 0<|α|<δ0<|\alpha|<\delta, the attention block 𝐀+α𝐈n{\bm{A}}+\alpha{\bm{I}}_{n} will become full-rank.

Proof.

Define f(α)=det(α𝑰n+𝑨)f(\alpha)=\det(\alpha{\bm{I}}_{n}+{\bm{A}}), which is a polynomial of degree nn in α\alpha. Then, f(α)f(\alpha) has only finitely many roots. Let α1,α2,,αr\alpha_{1},\alpha_{2},\ldots,\alpha_{r} be the non-zero roots of f(α)f(\alpha). Now, consider δ=min{|α1|,|α2|,,|αr|}\delta=\min\{|\alpha_{1}|,|\alpha_{2}|,\ldots,|\alpha_{r}|\} (if ff has no non-zero roots, take δ=+\delta=+\infty). For 0<|α|<δ0<|\alpha|<\delta, we can claim that f(α)=det(α𝑰n+𝑨)0f(\alpha)=\det(\alpha{\bm{I}}_{n}+{\bm{A}})\neq 0. Thus, 𝑨+α𝑰n{\bm{A}}+\alpha{\bm{I}}_{n} becomes non-singular (full-rank) and we complete the proof. ∎
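A quick numerical illustration of the lemma (with hypothetical values): a rank-one attention block becomes full-rank after adding a small α𝑰\alpha{\bm{I}}.

```python
import numpy as np

A = np.ones((4, 4)) / 4.0                            # a rank-1 block: uniform attention scores
print(np.linalg.matrix_rank(A))                      # 1
print(np.linalg.matrix_rank(A + 0.05 * np.eye(4)))   # 4: full-rank after the shift
```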

Lemma D.1 provides one possible case for an appropriate α\alpha. In fact, the selection of α\alpha can be quite flexible; for instance, when δ=max{|α1|,|α2|,,|αr|}\delta=\max\{|\alpha_{1}|,|\alpha_{2}|,\ldots,|\alpha_{r}|\} and |α|>δ|\alpha|>\delta holds, 𝑨+α𝑰n{\bm{A}}+\alpha{\bm{I}}_{n} also remains full-rank. Our experimental results on regularized models further illustrate the effectiveness of an appropriate α\alpha in enhancing model performance.
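For concreteness, the regularized self-attention above can be sketched as follows (illustrative NumPy; `alpha` is the hyperparameter α\alpha, and Norm()\mathrm{Norm}(\cdot) simply rescales each column to sum to one, assuming |α|<1|\alpha|<1 so the column sums stay positive):

```python
import numpy as np

def regularized_attention(X, WQ, WK, WV, alpha=0.1):
    """Self-attention with the alpha*I correction of Appendix D.1 (a sketch)."""
    d_o, n = WQ.shape[0], X.shape[1]
    scores = (WK @ X).T @ (WQ @ X) / np.sqrt(d_o)      # (keys, queries) logits
    A = np.exp(scores - scores.max(axis=0, keepdims=True))
    A = A / A.sum(axis=0, keepdims=True)               # column-wise softmax
    A = A - alpha * np.eye(n)                          # down-weight each token's own info
    A = A / A.sum(axis=0, keepdims=True)               # Norm(.): columns sum to 1 again
    return (WV @ X) @ A

rng = np.random.default_rng(3)
d, n = 8, 5
X = rng.standard_normal((d, n))
WQ, WK, WV = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(3))
print(regularized_attention(X, WQ, WK, WV).shape)      # (8, 5)
```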

We also acknowledge that our modification is relatively straightforward and may not be optimal. However, we believe that it may be a good choice to make structural improvements to the model from the perspective of the loss function, or more generally, from an optimization standpoint. For example, to address the issue of non-normalized 𝒚std(i){\bm{y}}_{std}^{(i)} and 𝒚^(i)\hat{{\bm{y}}}^{(i)}, we can also modify the loss function from the perspective of ridge regression as:

=12ηDi=1N𝑾V𝒙i𝑾ϕ(𝑾K𝒙i)F2+α2η𝑾F2.\mathcal{L}=\frac{1}{2\eta D}\sum_{i=1}^{N}\|{\bm{W}}_{V}{\bm{x}}_{i}-{\bm{W}}\phi({\bm{W}}_{K}{\bm{x}}_{i})\|_{F}^{2}+\frac{\alpha}{2\eta}\|{\bm{W}}\|_{F}^{2}.

And the optimal 𝑾{\bm{W}}^{*} will be

𝑾=𝑾V𝑿ϕ(𝑾K𝑿)T[ϕ(𝑾K𝑿)ϕ(𝑾K𝑿)T+αD𝑰]1.{\bm{W}}^{*}={\bm{W}}_{V}{\bm{X}}\phi({\bm{W}}_{K}{\bm{X}})^{T}\left[\phi({\bm{W}}_{K}{\bm{X}})\phi({\bm{W}}_{K}{\bm{X}})^{T}+\alpha D{\bm{I}}\right]^{-1}.

Correspondingly, the attention mechanism will be modified to

𝑯=𝑾ϕ(𝑾Q𝑿)=𝑾V𝑿ϕ(𝑾K𝑿)T[ϕ(𝑾K𝑿)ϕ(𝑾K𝑿)T+αD𝑰]1ϕ(𝑾Q𝑿),{\bm{H}}={\bm{W}}^{*}\phi({\bm{W}}_{Q}{\bm{X}})={\bm{W}}_{V}{\bm{X}}\phi({\bm{W}}_{K}{\bm{X}})^{T}\left[\phi({\bm{W}}_{K}{\bm{X}})\phi({\bm{W}}_{K}{\bm{X}})^{T}+\alpha D{\bm{I}}\right]^{-1}\phi({\bm{W}}_{Q}{\bm{X}}), (33)

where we neglect the normalization operation. This result is very similar to the mesa-layer proposed by Von Oswald et al. [2023b], which optimizes linear attention layers under the auto-regressive setting. Here, we present its form in the softmax self-attention setting using kernel methods and explain it from the perspectives of contrastive loss and ridge regression. Although the matrix inversion in Eq (33) can be computationally expensive, effective methods for computing Eq (33), including both forward computation and backward propagation, have been thoroughly researched by Von Oswald et al. [2023b], which helps make the above modification practically applicable.
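A sketch of the closed-form layer in Eq (33), again using a random-feature stand-in for ϕ\phi and neglecting normalization as in the text (all names are illustrative; in practice, the inverse would be computed with the efficient routines studied by Von Oswald et al. [2023b]):

```python
import numpy as np

rng = np.random.default_rng(4)
d, n, m, alpha, D = 8, 5, 32, 0.1, 1.0
Omega = rng.standard_normal((m, d))
phi = lambda Z: np.exp(Omega @ Z - 0.5 * np.sum(Z * Z, axis=0)) / np.sqrt(m)

X = rng.standard_normal((d, n))
WQ, WK, WV = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(3))

PhiK = phi(WK @ X)                                     # (m, n) key features
# Ridge-optimal readout: W* = W_V X PhiK^T (PhiK PhiK^T + alpha D I)^{-1}
W_star = (WV @ X) @ PhiK.T @ np.linalg.inv(PhiK @ PhiK.T + alpha * D * np.eye(m))
H = W_star @ phi(WQ @ X)                               # modified attention output
print(H.shape)                                         # (8, 5)
```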

D.2 More Discussion on the Data Augmentation

In addition to discussing the loss function, the contrastive learning paradigm also offers us some insights. In the representation learning process corresponding to ICL, we can easily notice that the "data augmentation" is performed using a simple linear mapping, which may not be sufficient for learning deeper-level features. To address this, we can employ more complicated nonlinear functions as stronger augmentations. Denoting these two augmentations by g1g_{1} and g2g_{2}, the process of contrastive learning will be modified as follows

=1ηDi=1N[g1(𝑾V𝒙i)]T𝑾ϕ(g2(𝑾K𝒙i)).\mathcal{L}=-\frac{1}{\eta D}\sum_{i=1}^{N}\left[g_{1}({\bm{W}}_{V}{\bm{x}}_{i})\right]^{T}{\bm{W}}\phi(g_{2}({\bm{W}}_{K}{\bm{x}}_{i})).

Correspondingly, the gradient update for 𝑾{\bm{W}} will become

𝑾(t)=𝑾(t1)η𝑾=𝑾(t1)+i=1ND1g1(𝑾V𝒙i)ϕ(g2(𝑾K𝒙i)).{\bm{W}}^{(t)}={\bm{W}}^{(t-1)}-\eta\frac{\partial\mathcal{L}}{\partial{\bm{W}}}={\bm{W}}^{(t-1)}+\sum_{i=1}^{N}D^{-1}g_{1}({\bm{W}}_{V}{\bm{x}}_{i})\otimes\phi(g_{2}({\bm{W}}_{K}{\bm{x}}_{i})).

Correspondingly, from the perspective of ICL, the last token will be updated as

𝒉T+1=𝑾0ϕ(𝒒)+D1[i=1Ng1(𝑽D(i))ϕ(g2(𝑲D(i)))]ϕ(𝒒).\displaystyle{\bm{h}}^{\prime}_{T+1}={\bm{W}}_{0}\phi({\bm{q}})+D^{-1}\left[\sum_{i=1}^{N}g_{1}({\bm{V}}_{D}^{(i)})\otimes\phi(g_{2}({\bm{K}}_{D}^{(i)}))\right]\phi({\bm{q}}).

By reformulating the above equation, we get Eq (14) in the main body.

Correspondingly, the modification for the self-attention layer can be adjusted as

𝑯=g1(𝑾V𝑿)softmax(g2(𝑾K𝑿)T𝑾Q𝑿do),{\bm{H}}=g_{1}({\bm{W}}_{V}{\bm{X}})\mathrm{softmax}\left(\frac{g_{2}({\bm{W}}_{K}{\bm{X}})^{T}{\bm{W}}_{Q}{\bm{X}}}{\sqrt{d_{o}}}\right),

where g1()g_{1}(\cdot) and g2()g_{2}(\cdot) are applied column-wise here. It is worth noting that we have only presented the framework of using nonlinear functions as data augmentations to modify the self-attention layer; in the simplest case, we can set g1g_{1} and g2g_{2} to be MLPs (Multi-Layer Perceptrons). However, in practice, it is encouraged to use data augmentation functions tailored to specific data structures. For example, in the case of CMT [Guo et al., 2022], the Convolutional Neural Networks (CNNs) used there can be considered a form of "strong data augmentation" suitable for image data within our framework. We consider the exploration of various augmentation methods tailored to different types of data as an open question for future research.
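In the simplest MLP case mentioned above, the framework can be instantiated as follows (a sketch; `g1` and `g2` are illustrative two-layer ReLU MLPs applied column-wise):

```python
import numpy as np

rng = np.random.default_rng(5)
d, n, dh = 8, 5, 16

def make_mlp(d_in, d_hidden, rng):
    """A small column-wise MLP serving as a nonlinear 'data augmentation'."""
    A = rng.standard_normal((d_hidden, d_in)) / np.sqrt(d_in)
    B = rng.standard_normal((d_in, d_hidden)) / np.sqrt(d_hidden)
    return lambda Z: B @ np.maximum(A @ Z, 0.0)

g1, g2 = make_mlp(d, dh, rng), make_mlp(d, dh, rng)
X = rng.standard_normal((d, n))
WQ, WK, WV = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(3))

scores = g2(WK @ X).T @ (WQ @ X) / np.sqrt(d)          # augmented keys
A = np.exp(scores - scores.max(axis=0, keepdims=True))
A = A / A.sum(axis=0, keepdims=True)                   # column-wise softmax
H = g1(WV @ X) @ A                                     # augmented values
print(H.shape)                                         # (8, 5)
```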

D.3 More Discussion on the Negative Samples

Although the gradient descent process corresponding to ICL exhibits some similarities with traditional contrastive learning approaches without negative samples, there are also significant differences: in traditional Siamese networks, the augmented representations, serving as positive pairs, are further learned through target and online networks that share weights (or at least influence each other via exponential moving averages). The output of the target network is then passed through a predictor to compute the contrastive loss. In contrast, the representation learning pattern corresponding to ICL performs more simply, which may limit the ability of the dual model to learn representations fully without negative samples. To address this, similar to most contrastive learning approaches, we can introduce negative samples, forcing the model to simultaneously separate the distances between positive and negative samples, that is,

\displaystyle\mathcal{L} =1ηDi=1N(𝑾V𝒙i)T𝑾ϕ(𝑾K𝒙i)+βηDi=1N1|𝒩(i)|j𝒩(i)(𝑾V𝒙j)T𝑾ϕ(𝑾K𝒙i)\displaystyle=-\frac{1}{\eta D}\sum_{i=1}^{N}\left({\bm{W}}_{V}{\bm{x}}_{i}\right)^{T}{\bm{W}}\phi({\bm{W}}_{K}{\bm{x}}_{i})+\frac{\beta}{\eta D}\sum^{N}_{i=1}\frac{1}{|\mathcal{N}(i)|}\sum_{j\in\mathcal{N(\mathit{i})}}\left({\bm{W}}_{V}{\bm{x}}_{j}\right)^{T}{\bm{W}}\phi({\bm{W}}_{K}{\bm{x}}_{i})
=1ηDi=1N(𝑾V(𝒙iβ|𝒩(i)|j𝒩(i)𝒙j))T𝑾ϕ(𝑾K𝒙i)\displaystyle=-\frac{1}{\eta D}\sum_{i=1}^{N}\left({\bm{W}}_{V}\left({\bm{x}}_{i}-\frac{\beta}{|\mathcal{N}(i)|}\sum_{j\in\mathcal{N(\mathit{i})}}{\bm{x}}_{j}\right)\right)^{T}{\bm{W}}\phi({\bm{W}}_{K}{\bm{x}}_{i})
=1ηDi=1N(𝑾V𝒙~i)T𝑾ϕ(𝑾K𝒙i),\displaystyle=-\frac{1}{\eta D}\sum_{i=1}^{N}\left({\bm{W}}_{V}\tilde{{\bm{x}}}_{i}\right)^{T}{\bm{W}}\phi({\bm{W}}_{K}{\bm{x}}_{i}),

where 𝒙~i=𝒙iβ|𝒩(i)|j𝒩(i)𝒙j\tilde{{\bm{x}}}_{i}={\bm{x}}_{i}-\frac{\beta}{|\mathcal{N}(i)|}\sum_{j\in\mathcal{N}(i)}{\bm{x}}_{j}, 𝒩(i)\mathcal{N}(i) is the set of negative samples for 𝒙i{\bm{x}}_{i} and β\beta is a hyperparameter. As a result, the gradient descent on 𝑾{\bm{W}} is modified as

𝑾(t)=𝑾(t1)η𝑾=𝑾(t1)+i=1ND1𝑾V𝒙~iϕ(𝑾K𝒙i).{\bm{W}}^{(t)}={\bm{W}}^{(t-1)}-\eta\frac{\partial\mathcal{L}}{\partial{\bm{W}}}={\bm{W}}^{(t-1)}+\sum_{i=1}^{N}D^{-1}{\bm{W}}_{V}\tilde{{\bm{x}}}_{i}\otimes\phi({\bm{W}}_{K}{\bm{x}}_{i}).

Correspondingly, the ICL update for the last token 𝒉T+1{\bm{h}}^{\prime}_{T+1} will be

𝒉T+1=𝑾0ϕ(𝑾Q𝒙T+1)+D1[i=1N𝑾V𝒙~iϕ(𝑾K𝒙i)]ϕ(𝑾Q𝒙T+1).\displaystyle{\bm{h}}^{\prime}_{T+1}={\bm{W}}_{0}\phi({\bm{W}}_{Q}{\bm{x}}^{\prime}_{T+1})+D^{-1}\left[\sum_{i=1}^{N}{\bm{W}}_{V}\tilde{{\bm{x}}}_{i}\otimes\phi({\bm{W}}_{K}{\bm{x}}_{i})\right]\phi({\bm{W}}_{Q}{\bm{x}}^{\prime}_{T+1}).

This directly results in Eq (15) in the main body.

For a self-attention layer, we can similarly obtain the corresponding modification as

𝑯=𝑾V𝑿~softmax((𝑾K𝑿)T𝑾Q𝑿do),\displaystyle{\bm{H}}={\bm{W}}_{V}\tilde{{\bm{X}}}\mathrm{softmax}\left(\frac{({\bm{W}}_{K}{\bm{X}})^{T}{\bm{W}}_{Q}{\bm{X}}}{\sqrt{d_{o}}}\right), (34)

where 𝑿~(i)=𝒙~i\tilde{{\bm{X}}}^{(i)}=\tilde{{\bm{x}}}_{i}. In the corresponding experiments, for each token, we simply choose the kk least relevant tokens as its negative samples, i.e., the kk tokens with the lowest attention scores. Note that here we simply use other token representations as negative samples for 𝒙i{\bm{x}}_{i}. However, there are more ways to construct negative samples that are worth exploring (for instance, using noise vectors or tokens with low semantic similarity as negative samples). For specific data structures and application scenarios, customizing the selection or construction of negative samples may be more effective.
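As an illustration, the following is a minimal NumPy sketch of Eq (34) with this simple negative-sample choice; the function name, the per-query selection of the k lowest-scoring tokens, and the column-wise layout are our own assumptions rather than the authors' implementation.

import numpy as np

def negative_attention(X, W_Q, W_K, W_V, k=2, beta=0.1):
    # X: (d, n) with tokens as columns, following the paper's notation
    d_o = W_Q.shape[0]
    S = (W_K @ X).T @ (W_Q @ X) / np.sqrt(d_o)       # raw attention scores
    A = np.exp(S - S.max(axis=0))
    A = A / A.sum(axis=0)                            # column-wise softmax
    X_tilde = X.copy()
    for i in range(X.shape[1]):
        neg = np.argsort(A[:, i])[:k]                # k least-attended tokens for token i
        X_tilde[:, i] = X[:, i] - beta * X[:, neg].mean(axis=1)
    return (W_V @ X_tilde) @ A                       # H = W_V X_tilde softmax(...)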

Appendix E More Experiments

E.1 More details of Experiments on Linear Task

In this part, we discuss our experimental setup in more detail and provide more results on the linear regression task.

Inspired by Garg et al. [2022] and Von Oswald et al. [2023a], we choose to pretrain a softmax attention layer before exploring the equivalence proposed in Theorem 3.1. In fact, pretraining is not mandatory, since our theoretical analysis does not depend on any specific weight construction; in other words, the ICL inference results and the test predictions of the dual model remain consistent for an attention layer with arbitrary weights, even at random initialization. However, to facilitate investigating the impact of the subsequent modifications to the model structure and to better align with real-world scenarios, we still opt for pretraining so that the model acquires some task-specific knowledge. Additionally, our experiments are conducted in a self-attention setting; when we focus only on the last token, this is equivalent to the case with only one query token (T=0T=0) in Section 2.1. The experiments are conducted on a single 24GB NVIDIA GeForce RTX 3090 and can be completed within one day.

For the linear regression task, we generate the task by 𝒔=𝑾𝒕{\bm{s}}={\bm{W}}{\bm{t}}, where every element of 𝑾ds×dt{\bm{W}}\in\mathbb{R}^{d_{s}\times d_{t}} is sampled from a normal distribution 𝑾ij𝒩(0,1){\bm{W}}_{ij}\sim\mathcal{N}(0,1) and 𝒕{\bm{t}} is sampled from a uniform distribution 𝒕U(1,1)dt{\bm{t}}\sim U(-1,1)^{d_{t}}. To facilitate a more accurate estimation of the attention matrices using random features, and considering the limited learning capacity of a single attention layer, we set small values dt=11d_{t}=11 and ds=1d_{s}=1. Then, at each step, we use the generated {𝒙i=[𝒕i;si]}i=1N+1\{{\bm{x}}_{i}=[{\bm{t}}_{i};s_{i}]\}^{N+1}_{i=1} to form the input matrix 𝑿{\bm{X}}, while the label part of the query token is masked to zero, that is, 𝒙N+1=[𝒕N+1;0]{\bm{x}}_{N+1}=[{\bm{t}}_{N+1};0], where we consider only one query token and denote 𝒙T+1=𝒙N+1{\bm{x}}^{\prime}_{T+1}={\bm{x}}_{N+1} to maintain consistency with the notation in Section 2.1. The softmax attention layer is expected to predict 𝒔^N+1\hat{{\bm{s}}}_{N+1} to approximate the ground truth value 𝒔N+1{\bm{s}}_{N+1}. We use the mean square error (MSE) as the loss function, that is, for each epoch,

=1Nstepj=1Nstep𝒔^N+1(j)𝒔N+1(j)2,\mathcal{L}=\frac{1}{N_{step}}\sum_{j=1}^{N_{step}}\|\hat{{\bm{s}}}_{N+1}^{(j)}-{\bm{s}}_{N+1}^{(j)}\|^{2},

where 𝒔^N+1(j)\hat{{\bm{s}}}_{N+1}^{(j)} and 𝒔N+1(j){\bm{s}}_{N+1}^{(j)} are the prediction and ground truth value at the jj-th step and NstepN_{step} is the number of steps. We set Nstep=1024N_{step}=1024 for N+1=16N+1=16, which means the total number of tokens remains 1638416384. We choose stochastic gradient descent (SGD) [Amari, 1993] as the optimizer and set the learning rate to 0.003 for the normal and regularized models, and to 0.0050.005 for the remaining experiments. We also attempt the multi-task scenario, where the input tokens at each step are generated from a different task. However, we find it challenging for a single attention layer to learn effectively in this setting, resulting in disordered predictions. Therefore, our experiments are currently limited to single-task settings, and the multi-task scenario is worth further investigation in the future.
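For reference, the following is a minimal NumPy sketch of the data generation just described (d_t = 11, d_s = 1, N + 1 = 16); the function and variable names are ours, not from the original implementation.

import numpy as np

def make_linear_batch(n_tokens=16, d_t=11, d_s=1, rng=np.random.default_rng(0)):
    W_task = rng.standard_normal((d_s, d_t))          # W_ij ~ N(0, 1)
    T = rng.uniform(-1.0, 1.0, size=(d_t, n_tokens))  # t ~ U(-1, 1)^{d_t}
    S = W_task @ T                                    # s = W t
    X = np.concatenate([T, S], axis=0)                # x_i = [t_i; s_i], tokens as columns
    s_target = S[:, -1].copy()                        # ground truth for the query token
    X[d_t:, -1] = 0.0                                 # mask the query token's label part
    return X, s_target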

[Figure 6: The estimation of the attention matrix by positive random features when varying drd_{r}. Panels: (a) dr=3d_{r}=3, (b) dr=12d_{r}=12, (c) dr=120d_{r}=120, (d) dr=12000d_{r}=12000, (e) exact attention.]

[Figure 7: The estimation of the output matrix by positive random features when varying drd_{r}. Panels: (a) dr=3d_{r}=3, (b) dr=12d_{r}=12, (c) dr=120d_{r}=120, (d) dr=12000d_{r}=12000, (e) exact output.]

It is worth noting that we approximate the attention matrix calculation using random features as the kernel mapping function, instead of using the traditional softmax function in the self-attention layer [Choromanski et al., 2020]. The mapping function ϕ:dodr\phi:{\mathbb{R}}^{d_{o}}\to{\mathbb{R}}^{d_{r}} has the form ϕ(𝒙)=e𝒘T𝒙𝒙2/2\phi({\bm{x}})=e^{{\bm{w}}^{T}{\bm{x}}-\|{\bm{x}}\|^{2}/2} where 𝒘𝒩(0,I){\bm{w}}\sim\mathcal{N}(0,I). Orthogonal random features [Yu et al., 2016, Choromanski et al., 2020] or simplex random features [Reid et al., 2023] can be chosen to achieve theoretically better performance. We investigate the impact of the random feature dimension drd_{r} on the approximation of the attention matrices and the output, using Mean Squared Error (MSE) and Mean Absolute Error (MAE) as evaluation metrics; we conduct 5050 repeated experiments and report the average values for each value of drd_{r}, as shown in Figure 8. It can be observed that as the dimension of the random features increases, the approximation quality gradually improves, with both errors reaching a low level in the end. We visualize the exact attention matrix and compare it with the estimated attention matrices obtained using different values of drd_{r}, as shown in Figure 6. Again, as drd_{r} increases, the approximation of the true attention matrix gradually improves, and similar results can be observed for the output matrices in Figure 7.
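To see why this mapping recovers the softmax kernel, note that E_w[phi(q)^T phi(k)] = exp(q^T k); the following minimal NumPy sketch (the shapes and seed are our own choices) averages over d_r random draws to approximate this kernel.

import numpy as np

def positive_random_features(X, d_r, rng):
    # X: (d_o, n) with tokens as columns; each row of W is one draw w ~ N(0, I)
    W = rng.standard_normal((d_r, X.shape[0]))
    return np.exp(W @ X - 0.5 * (X ** 2).sum(axis=0)) / np.sqrt(d_r)

rng = np.random.default_rng(0)
q, k = rng.normal(size=4), rng.normal(size=4)
Phi = positive_random_features(np.stack([q, k], axis=1), d_r=12000, rng=rng)
# phi(q)^T phi(k) should be close to exp(q^T k)
print(np.exp(q @ k), Phi[:, 0] @ Phi[:, 1])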

To obtain a more accurate estimation of the attention matrix, we set the output dimension of the mapping function to 100 times the input dimension, that is, dr=100(ds+dt)=1200d_{r}=100(d_{s}+d_{t})=1200. Furthermore, we visualize the exact attention matrix and the output together with the approximation results, as shown in Figure 9. Although some larger values are not estimated accurately due to the limited dimension of the random features, the majority of the information is still captured well. These findings indicate that our choice of positive random features as mapping functions to estimate the true softmax attention is feasible.

[Figure 8: The error of positive random features in estimating the attention and output matrices as drd_{r} varies. Panels: (a) estimation error of the attention matrix vs. drd_{r}, (b) estimation error of the output matrix vs. drd_{r}.]

[Figure 9: The comparison between the exact attention matrix, the output, and their estimated approximations using random features under the setting N=16N=16 and dr=1200d_{r}=1200. Panels: (a) the exact attention matrix and its approximation, (b) the exact output and its approximation.]

After the weights 𝑾Q{\bm{W}}_{Q}, 𝑾K{\bm{W}}_{K}, 𝑾V{\bm{W}}_{V} of the attention layer have been determined, we generate N+1N+1 test tokens in the same way, where the ss part of the (N+1)(N+1)-th token is also set to zero, and finally input the test tokens into the attention layer to obtain the corresponding prediction 𝒉^N+1=[𝒕^N+1(1),s^N+1(1)]\hat{{\bm{h}}}_{N+1}=[\hat{{\bm{t}}}_{N+1}^{(1)},\hat{s}_{N+1}^{(1)}]. Here, we also use 𝒉T+1=𝒉^N+1{\bm{h}}^{\prime}_{T+1}=\hat{{\bm{h}}}_{N+1} to maintain the notation consistency with Section 2.1.

On the other hand, we construct a dual model f(𝒙)=𝑾ϕ(𝒙)f({\bm{x}})={\bm{W}}\phi({\bm{x}}), where ϕ()\phi(\cdot) is exactly the kernel mapping function used in the attention layer. We transform the first NN tokens into the training set according to Theorem 3.1 and train the dual model using the loss in Eq (9). According to Theorem 3.1, after one step of gradient descent on this training set, the test prediction 𝒚^test=[𝒕^N+1(2),s^N+1(2)]\hat{{\bm{y}}}_{test}=[\hat{{\bm{t}}}_{N+1}^{(2)},\hat{s}_{N+1}^{(2)}] of the dual model strictly equals 𝒉T+1{\bm{h}}^{\prime}_{T+1}.
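The equivalence can also be checked numerically. Below is a minimal NumPy sketch under simplifying assumptions of ours (the dual model is initialized at zero and the query's attention to itself is set aside, i.e., the W_0 term is dropped): one gradient step on the N demonstration tokens reproduces the random-feature attention output for the query exactly, by linearity.

import numpy as np

def dual_vs_attention(X, W_Q, W_K, W_V, phi):
    # X: (d, N+1), tokens as columns; the last column is the query.
    # phi: a feature map such as the positive random features defined above.
    N = X.shape[1] - 1
    phi_q = phi(W_Q @ X[:, -1:])                     # phi(q), shape (d_r, 1)
    phi_K = phi(W_K @ X[:, :N])                      # phi(W_K x_i), shape (d_r, N)
    D = float((phi_K.T @ phi_q).sum())               # normalizer over demonstrations
    # attention side: h = sum_i (W_V x_i) <phi(W_K x_i), phi(q)> / D
    h_attn = (W_V @ X[:, :N]) @ (phi_K.T @ phi_q) / D
    # dual side: one step on Eq (9) from W^(0) = 0 gives
    # W^(1) = sum_i D^{-1} (W_V x_i) outer phi(W_K x_i)
    W1 = (W_V @ X[:, :N]) @ phi_K.T / D
    h_dual = W1 @ phi_q
    assert np.allclose(h_attn, h_dual)               # exact agreement by linearity
    return h_dual.ravel()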

We conduct experiments under the same setup using different random seeds to explore the effects of the various model modifications. The data for all three experiments are generated under identical conditions. One set of experimental results is presented in the main text, while the results of the other two sets are shown in Figure 10. Similar to the discussion in the main text, we can achieve better performance than the normal model with appropriate parameter settings.

[Figure 10: The performance for regularized models (Center Left), augmented models (Center Right) and negative models (Right) with different settings for different random seeds.]

E.2 More details of Experiments on Different Tasks

In addition to the experiments on the linear regression task, we also extend our experiments to trigonometric and exponential tasks.

E.2.1 More details of Experiments on Trigonometric Tasks

[Figure 11: The equivalence between ICL of one softmax attention layer and gradient descent, along with the analysis of different model modifications for trigonometric tasks. Left Part: 𝒚^test𝒉T+12\|\hat{{\bm{y}}}_{test}-{\bm{h}}^{\prime}_{T+1}\|_{2} as the gradient descent proceeds under the setting N=127N=127; Remaining Part: the performance for regularized models (Center Left), augmented models (Center Right) and negative models (Right) with different settings.]

For the trigonometric task, we generate the task by 𝒔=cos(𝑾𝒕){\bm{s}}=\mathrm{cos}({\bm{W}}{\bm{t}}) where cos()\mathrm{cos}(\cdot) is element-wise, 𝑾ds×dt{\bm{W}}\in\mathbb{R}^{d_{s}\times d_{t}} is sampled from the normal distribution 𝑾ij𝒩(0,1){\bm{W}}_{ij}\sim\mathcal{N}(0,1), and 𝒕{\bm{t}} is sampled from the uniform distribution 𝒕U(0,π)dt{\bm{t}}\sim U(0,\pi)^{d_{t}}. In experiments, we find that learning higher-dimensional tasks is challenging for one softmax attention layer; therefore, we only set dt=7d_{t}=7 and ds=1d_{s}=1. At each step, we use N+1=128N+1=128 tokens {𝒙i=[𝒕i;𝒔i]}i=1N+1\{{\bm{x}}_{i}=[{\bm{t}}_{i};{\bm{s}}_{i}]\}^{N+1}_{i=1} and the total number of tokens remains unchanged at 1638416384. Compared to the setting N+1=16N+1=16 of the linear task, we observe that for more complex tasks, the attention layer needs more tokens to provide information at each training step. Similarly, we mask the label part of the last token, that is, 𝒔N+1=𝟎{\bm{s}}_{N+1}={\bm{0}}, and use the mean square error (MSE) loss to train the attention layer. We choose SGD as the optimizer and set the learning rate to 0.005. The rest of the settings remain consistent with those used in the linear task. The results for the trigonometric regression task are shown in Figure 11.

Firstly, as shown in the left part of Figure 11, the inference result of ICL is strictly equivalent to the prediction of the dual model, that is, 𝒉^N+1=𝒚^test\hat{{\bm{h}}}_{N+1}=\hat{{\bm{y}}}_{test} as well as the label parts s^N+1(1)=s^N+1(2)\hat{s}_{N+1}^{(1)}=\hat{s}_{N+1}^{(2)}, aligning with our analysis in Theorem 3.1.

The performance of the modified models during the training process can be seen in the remaining parts of Figure 11. For the regularized models, as seen in the center left part of Figure 11, the models with α<0\alpha<0 converge slightly faster and reach better final results compared to the normal model (α=0\alpha=0). For the augmented models, we use the same augmentation functions g1g_{1} and g2g_{2} as in the linear regression task, that is, g1(𝒙)=g2(𝒙)=σ(𝑾𝒙)g_{1}({\bm{x}})=g_{2}({\bm{x}})=\sigma({\bm{W}}{\bm{x}}) where σ()\sigma(\cdot) is the GELU\mathrm{GELU} activation function. However, for g2+g_{2}^{+}, we use ELU\mathrm{ELU} as the activation function. We can see from the center right part of Figure 11 that, compared to the normal model, using g1g_{1} alone or using g1g_{1} and g2g_{2} simultaneously as data augmentations significantly degrades the model's performance, including convergence speed and final results. However, using g2g_{2} alone yields results comparable with the normal model. In particular, when using g2+g_{2}^{+}, the model converges faster. However, for the negative models, the performance with the selected number of negative samples kk and the parameter β\beta is worse than the normal model, which suggests that our simple approach of selecting tokens with low attention scores as negative samples is not a reasonable method. Just as we discussed in Section 4, for different tasks, a more refined strategy for selecting negative samples should be considered.

E.2.2 More details of Experiments on Exponential Tasks

[Figure 12: The equivalence between ICL of one softmax attention layer and gradient descent, along with the analysis of different model modifications for exponential tasks. Left Part: 𝒚^test𝒉T+12\|\hat{{\bm{y}}}_{test}-{\bm{h}}^{\prime}_{T+1}\|_{2} as the gradient descent proceeds under the setting N=511N=511; Remaining Part: the performance for regularized models (Center Left), augmented models (Center Right) and negative models (Right) with different settings.]

For the exponential task, we generate the task by 𝒔=exp(𝑾𝒕){\bm{s}}=\mathrm{exp}({\bm{W}}{\bm{t}}) where exp()\mathrm{exp}(\cdot) is also element-wise, 𝑾ds×dt{\bm{W}}\in\mathbb{R}^{d_{s}\times d_{t}} is sampled from the normal distribution 𝑾ij𝒩(0,1){\bm{W}}_{ij}\sim\mathcal{N}(0,1), and 𝒕{\bm{t}} is sampled from the uniform distribution 𝒕U(1,1)dt{\bm{t}}\sim U(-1,1)^{d_{t}}. We only set dt=6d_{t}=6 and ds=1d_{s}=1 considering the limited learning capacity of one softmax attention layer. At each training step, we use N+1=512N+1=512 tokens {𝒙i=[𝒕i;𝒔i]}i=1N+1\{{\bm{x}}_{i}=[{\bm{t}}_{i};{\bm{s}}_{i}]\}^{N+1}_{i=1} and the total number of tokens remains unchanged at 1638416384. Compared to the settings N+1=16N+1=16 for the linear task and N+1=128N+1=128 for the trigonometric task, we again find that for exponential tasks the attention layer needs more tokens to provide in-context information at each training step. The rest of the settings remain consistent with those used in the trigonometric task. The results for the exponential regression task are shown in Figure 12.

Similarly, as shown in the left part of Figure 12, the result 𝒉^N+1\hat{{\bm{h}}}_{N+1} of ICL inference is equivalent to the test prediction 𝒚^test\hat{{\bm{y}}}_{test} of the dual model after training, just as stated in Theorem 3.1. For the regularized models, it can be observed that when α=16\alpha=16, the model converges faster and achieves a better result. For the augmented models, using g1g_{1} or g2g_{2} alone as data augmentations results in better performance. However, when both g1g_{1} and g2g_{2} are used simultaneously, the training process becomes unstable, so we do not show it in the center right part of Figure 12. For the negative models, similar to the trigonometric task, the different combinations of the number of negative samples kk and the parameter β\beta do not show a significant improvement over the normal model, highlighting the importance of the strategy for selecting negative samples. We leave the exploration of a more refined negative sample selection strategy for various tasks to future work.

E.3 More Experiments on Combinations

In addition, we also conduct experiments with combinations of the modifications on linear, trigonometric, and exponential tasks. The results are shown in Figure 13. For linear tasks, a combination of the regularized and augmented modifications is sufficient. However, for the other two tasks, the results are actually worse than using the regularized or augmented modification individually (compare Figures 11 and 12). We think this may be due to the ineffective selection of negative samples, whose effect is amplified when combined. Therefore, when the design of the augmentation or negative-sample method is not effective, we recommend using a single modification method.

[Figure 13: Performance of different combinations on linear (Left), exponential (Center), and trigonometric (Right) tasks.]

E.4 More Experiments on One Transformer Layer

[Figure 14: The equivalence between ICL of one Transformer layer and gradient descent, along with the analysis of the upper bound of Rank(𝑾F{\bm{W}}_{F}). Left: 𝒚^test𝒉T+12\|\hat{{\bm{y}}}_{test}-{\bm{h}}^{\prime}_{T+1}\|_{2} as the gradient descent proceeds under the setting N=15N=15; Right: the upper bound of Rank(𝑾F{\bm{W}}_{F}) when setting d=12d=12 and varying dhd_{h}.]

Similar to the experiments with one softmax attention layer, we also conduct experiments on one Transformer layer (introducing one FFN layer after the attention layer) and train its dual model based on Theorem B.1. As shown in Figure 14, the inference result 𝒉^N+1\hat{{\bm{h}}}_{N+1} of ICL remains equivalent to the test prediction 𝒚^test\hat{{\bm{y}}}_{test} of the trained dual model. Furthermore, to validate the potential low-rank property of the matrix 𝑾F{\bm{W}}_{F}, we explore the upper bound of its rank. Note that 𝑾F=𝑾2𝑰M𝑾1{\bm{W}}_{F}={\bm{W}}_{2}{\bm{I}}_{M}{\bm{W}}_{1} where 𝑾1dh×d{\bm{W}}_{1}\in{\mathbb{R}}^{d_{h}\times d}, 𝑾2d×dh{\bm{W}}_{2}\in{\mathbb{R}}^{d\times d_{h}}, 𝑰Mdh×dh{\bm{I}}_{M}\in{\mathbb{R}}^{d_{h}\times d_{h}}; the upper bound of Rank(𝑾F)\mathrm{Rank}({\bm{W}}_{F}) is

Rank(𝑾F)min{d,dh,Rank(𝑰M)},\mathrm{Rank}({\bm{W}}_{F})\leq\min\left\{d,d_{h},\mathrm{Rank}({\bm{I}}_{M})\right\},

where Rank(𝑰M)\mathrm{Rank}({\bm{I}}_{M}) equals the number of non-zero elements in 𝑰M{\bm{I}}_{M}. We fix d=12d=12 while varying the value of dhd_{h}. We generate 1024 sets of 𝑿test{\bm{X}}_{test} for different tasks and repeat the experiments 5 times. Finally, we calculate the average upper bound of the rank of 𝑾F{\bm{W}}_{F}. The results are shown in the right part of Figure 14, indicating that when dh2.75d=33d_{h}\geq 2.75d=33, the upper bound remains stable and equals d=12d=12. Otherwise, when dhd_{h} is set to a smaller value, 𝑾F{\bm{W}}_{F} exhibits a clear low-rank property.
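The bound itself is elementary to compute; a minimal NumPy sketch follows, with our own helper name and with I_M represented by its diagonal.

import numpy as np

def rank_upper_bound(d, d_h, mask_diag):
    # W_F = W_2 I_M W_1 with W_1: (d_h, d), W_2: (d, d_h), I_M: diagonal 0/1 mask,
    # so Rank(W_F) <= min{d, d_h, nnz(I_M)}
    return min(d, d_h, int(np.count_nonzero(mask_diag)))

# e.g. d = 12, d_h = 24, and a mask that keeps 9 hidden units active
print(rank_upper_bound(12, 24, np.array([1] * 9 + [0] * 15)))  # -> 9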

E.5 More Experiments on More Realistic NLP Tasks

Model Types         Setting               CoLA    MRPC          STS-B         RTE
Normal              BERT-base-uncased     56.82   90.24/86.27   88.29/87.96   68.23
Regularized Models  α = -1.0              0.0     79.01/68.87   57.23/60.16   52.71
                    α = -0.5              61.42   83.17/74.02   85.28/85.22   57.04
                    α = -0.1              58.06   89.50/85.05   88.71/88.27   65.70
                    α = 0.1               58.34   90.59/86.76   88.12/87.81   64.98
                    α = 0.5               27.01   83.56/73.28   85.25/85.03   59.93
                    α = 1.0               0.0     81.22/68.38   52.07/55.60   47.29
                    Local Best            61.42   90.59/86.76   88.71/88.27   65.70
Augmented Models    g1 / c = 0.2          59.85   88.11/83.33   88.56/88.22   68.59
                    g1 / c = 1            56.51   90.88/87.01   88.96/88.60   71.12
                    g2 / c = 0.2          56.29   87.65/82.60   88.60/88.24   68.59
                    g2 / c = 1            58.85   87.74/82.60   88.68/88.32   70.40
                    g1 & g2 / c = 0.2     57.32   89.62/85.29   88.48/88.19   71.12
                    g1 & g2 / c = 1       58.30   90.40/86.52   88.83/88.45   68.95
                    Local Best            59.85   90.88/87.01   88.96/88.60   71.12
Negative Models     r = 0.1 / β = 0.1     56.22   88.54/83.82   88.25/87.91   65.34
                    r = 0.2 / β = 0.1     57.92   90.00/85.78   88.22/87.84   66.06
                    r = 0.3 / β = 0.1     57.92   89.31/84.80   88.26/87.90   67.15
                    r = 0.1 / β = 0.2     58.92   87.90/83.33   88.34/88.11   63.54
                    r = 0.2 / β = 0.2     57.13   87.87/83.09   88.59/88.27   64.98
                    r = 0.3 / β = 0.2     58.14   88.97/84.56   88.64/88.33   66.79
                    Local Best            58.92   90.00/85.78   88.64/88.33   67.15
Combined Models     Reg & Aug             56.56   88.54/83.82   88.86/88.60   68.59
                    Reg & Neg             58.11   88.19/83.33   88.41/88.17   69.31
                    Aug & Neg             59.07   90.49/86.76   88.59/88.21   70.76
                    Reg & Aug & Neg       58.92   88.39/83.58   88.32/88.01   67.87
                    Local Best            59.07   90.49/86.76   88.86/88.60   70.76
Global Best                               61.42   90.88/87.01   88.96/88.60   71.12

Table 1: Partial GLUE test results of different modifications. "Local Best" displays the best results for each modification type, where results superior to the original model are highlighted in bold in the original table. "Global Best" showcases the best results among all modifications. Matthews correlation, F1 score/accuracy, Pearson/Spearman correlation, and accuracy are reported for CoLA, MRPC, STS-B, and RTE, respectively.

We supplement our experiments with more realistic NLP tasks. We choose the BERT-base-uncased model (downloadable from the Huggingface library [Wolf, 2019]; hereafter referred to as BERT [Kenton and Toutanova, 2019]) to validate the effectiveness of the modifications to the attention mechanism, and we select four relatively small GLUE datasets (CoLA, MRPC, STS-B, RTE) [Wang, 2018]. We load the checkpoint of the pre-trained BERT model, where 'classifier.bias' and 'classifier.weight' are newly initialized, and then fine-tune the model to explore the performance of the three attention modifications as well as their combinations. Regarding the detailed experiment settings, we set the batch size to 32, the learning rate to 2e-5, and the number of epochs to 5 for all datasets. All experiments are conducted on a single 24GB NVIDIA GeForce RTX 3090. All experimental results are presented in Table 1. Below, we discuss the various modifications and their performance.

For the regularized modification, we consider different values of α\alpha, specifically selected from {1.0,0.5,0.1,0.1,0.5,1.0}\{-1.0,-0.5,-0.1,0.1,0.5,1.0\}. As can be observed in Table 1, except for RTE, the best regularized models outperform the original model on the other three datasets. However, we also note that when the absolute value of α\alpha is too large, the model's performance declines significantly, so we recommend using smaller absolute values of α\alpha.

For the augmented modification, we also consider applying more complex "augmentation" functions to the linear key/value mappings. However, unlike the methods used in the simulation tasks, we do not simply select g1g_{1} and g2g_{2} as MLPs, i.e., g1(𝑾V𝒙)=𝑾2σ(𝑾1𝑾V𝒙)g_{1}({\bm{W}}_{V}{\bm{x}})={\bm{W}}_{2}\sigma({\bm{W}}_{1}{\bm{W}}_{V}{\bm{x}}). This design is avoided because it could undermine the effort made during pre-training to learn the weights 𝑾V{\bm{W}}_{V} and 𝑾K{\bm{W}}_{K}, leading to difficulties in training and challenges in comparison. Instead, we adopt a parallel approach, i.e., g1(𝑾V𝒙)=𝑾V𝒙+c𝑾2σ(𝑾1𝒙)g_{1}({\bm{W}}_{V}{\bm{x}})={\bm{W}}_{V}{\bm{x}}+c{\bm{W}}_{2}\sigma({\bm{W}}_{1}{\bm{x}}), where cc is a hyperparameter controlling the influence of the new branch, σ\sigma is the GELU activation function, and the hidden layer dimension is set to twice the original size of 𝑾V𝒙{\bm{W}}_{V}{\bm{x}}. g2(𝑾K𝒙)=𝑾K𝒙+c𝑾2σ(𝑾1𝒙)g_{2}({\bm{W}}_{K}{\bm{x}})={\bm{W}}_{K}{\bm{x}}+c{\bm{W}}_{2}\sigma({\bm{W}}_{1}{\bm{x}}) follows the same format.
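A minimal PyTorch sketch of this parallel augmentation follows; the wrapper class and its names are our own, assuming the pre-trained mapping is given as an nn.Linear.

import torch
import torch.nn as nn

class ParallelAugmented(nn.Module):
    """g(W x) = W x + c * W_2 sigma(W_1 x): a parallel branch beside a
    pre-trained value (or key) mapping, per the description above."""
    def __init__(self, W: nn.Linear, c: float = 1.0):
        super().__init__()
        d_in, d_out = W.in_features, W.out_features
        self.W, self.c = W, c                        # pre-trained mapping, still trainable
        self.branch = nn.Sequential(                 # hidden width: twice the output size
            nn.Linear(d_in, 2 * d_out), nn.GELU(), nn.Linear(2 * d_out, d_out))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.W(x) + self.c * self.branch(x)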

Experimental results show that the best augmented models achieve better performance than the original model across all four datasets. Notably, augmentation of the value mapping (i.e., using g1g_{1} alone) proves more effective than the other variants, both in terms of performance and in the number of additional parameters introduced. Using both g1g_{1} and g2g_{2} introduces more parameters, which is particularly undesirable for larger models. Thus, under the augmentation methods and experimental settings we selected, using g1g_{1} alone is recommended.

In addition, we do not rule out the possibility of more powerful and efficient augmentation methods. Our choice of g1g_{1} and g2g_{2} as parallel MLPs is primarily motivated by the desire to make better use of the pre-trained weights 𝑾K{\bm{W}}_{K} and 𝑾V{\bm{W}}_{V}. We have also noticed that this specific augmentation function design is structurally similar to the Parallel Adapter [He et al., 2021]. However, we would like to emphasize that our parallel design is just a specific case within this broader augmented modification framework, which offers a new perspective for understanding the Parallel Adapter. In terms of practical implementation, the Parallel Adapter method focuses more on efficient training, so it uses fewer parameters, and the original 𝑾V{\bm{W}}_{V} and 𝑾K{\bm{W}}_{K} are frozen; only the newly introduced parameters are trained. In contrast, our approach aims to validate the benefits of introducing stronger nonlinear augmentation functions into the linear value/key mappings. Therefore, we set a higher hidden layer dimension (twice that of 𝑾V𝒙{\bm{W}}_{V}{\bm{x}} or 𝑾K𝒙{\bm{W}}_{K}{\bm{x}}) and also train 𝑾V{\bm{W}}_{V} and 𝑾K{\bm{W}}_{K} simultaneously. This design is relatively general and does not take into account the specific characteristics of individual tasks. We still encourage the development of more task-specific augmentation strategies.

For the negative modification, we continue to select tokens with lower attention scores as negative samples. The parameter rr represents the proportion of tokens used as negative samples, while β\beta indicates the overall reduction in attention scores. We choose rr from {0.1,0.2,0.3}\{0.1,0.2,0.3\} and β\beta from {0.1,0.2}\{0.1,0.2\}. Under these combinations, the best negative models only outperform the original model on CoLA and STS-B, whereas their performance on MRPC and RTE is worse than the original model. This suggests that our simple approach of treating tokens with low attention scores as negative samples may be too coarse. A more effective method for constructing negative samples should be designed, which is a direction worth exploring in the future.

We also consider combining the different modification methods. Specifically, we choose α=0.1\alpha=0.1, g1g_{1} with c=1c=1, and r=0.2r=0.2 / β=0.1\beta=0.1 as the basis for combining the three types of modifications, considering their overall performance across all datasets. The results indicate that under our settings, the combination of the augmented and negative modifications achieves the best performance on CoLA, MRPC, and RTE, while the combination of the regularized and augmented modifications achieves the best performance on STS-B. However, their optimal performance is slightly inferior to the best performance achieved with the augmented models alone. Therefore, we conclude that using all three modifications simultaneously is not necessary. With appropriate hyperparameter choices, using the augmented modification alone, or in combination with one other modification, is sufficient.

Overall, the experimental results show that our modifications inspired by the representation learning process are helpful in enhancing performance. This further validates the potential of our approach of thinking about and improving the attention mechanism from a representation learning perspective. In addition, we would like to reiterate that more validation across additional tasks and models, and the development of task-specific augmentation and negative sampling methods are all interesting directions worth exploring in the future.

Appendix F More Details about Related Work

In this section, we provide additional details about the related work in Section 6, especially the parts that involve formalization. Dai et al. [2022] interpret ICL as implicit fine-tuning. More specifically, let 𝑿=[𝑿D,𝑿T]{\bm{X}}=[{\bm{X}}_{D},{\bm{X}}_{T}], where 𝑿D=[𝒙1,𝒙2,,𝒙N]{\bm{X}}_{D}=[{\bm{x}}_{1},{\bm{x}}_{2},\dots,{\bm{x}}_{N}] denotes the demonstration tokens and 𝑿T=[𝒙1,𝒙2,,𝒙T]{\bm{X}}_{T}=[{\bm{x}}^{\prime}_{1},{\bm{x}}^{\prime}_{2},\dots,{\bm{x}}^{\prime}_{T}] denotes the query tokens. On the one hand, for ICL, they consider the output for 𝒒=𝑾Q𝒙T+1{\bm{q}}={\bm{W}}_{Q}{\bm{x}}^{\prime}_{T+1} under the linear attention setting as

\tilde{F}_{\mathrm{ICL}}({\bm{q}}) = {\bm{W}}_{V}[{\bm{X}}_{D},{\bm{X}}_{T}]({\bm{W}}_{K}[{\bm{X}}_{D},{\bm{X}}_{T}])^{T}{\bm{q}}
= {\bm{W}}_{V}{\bm{X}}_{T}({\bm{W}}_{K}{\bm{X}}_{T})^{T}{\bm{q}} + {\bm{W}}_{V}{\bm{X}}_{D}({\bm{W}}_{K}{\bm{X}}_{D})^{T}{\bm{q}}
= {\bm{W}}_{\mathrm{ZSL}}{\bm{q}} + \mathrm{LinearAtten}({\bm{W}}_{V}{\bm{X}}_{D},{\bm{W}}_{K}{\bm{X}}_{D},{\bm{q}})
= {\bm{W}}_{\mathrm{ZSL}}{\bm{q}} + \sum_{i}\left(({\bm{W}}_{V}{\bm{x}}_{i})\otimes({\bm{W}}_{K}{\bm{x}}_{i})\right){\bm{q}}
= {\bm{W}}_{\mathrm{ZSL}}{\bm{q}} + \Delta{\bm{W}}_{\mathrm{ICL}}{\bm{q}},

where 𝑾ZSL𝒒{\bm{W}}_{\mathrm{ZSL}}{\bm{q}} is interpreted as the output in the zero-shot learning (ZSL) setting where no demonstrations are given. On the other hand, they consider a specific fine-tuning setting, which updates only the parameters of the key and value projections, that is,

F~FT(𝒒)\displaystyle\tilde{F}_{\mathrm{FT}}({\bm{q}}) =(𝑾V+Δ𝑾V)𝑿𝑿T(𝑾K+Δ𝑾K)T𝒒\displaystyle=({\bm{W}}_{V}+\Delta{\bm{W}}_{V}){\bm{X}}{\bm{X}}^{T}({\bm{W}}_{K}+\Delta{\bm{W}}_{K})^{T}{\bm{q}}
=(𝑾ZSL+Δ𝑾FT)𝒒\displaystyle=({\bm{W}}_{\mathrm{ZSL}}+\Delta{\bm{W}}_{\mathrm{FT}}){\bm{q}}

where Δ𝑾K\Delta{\bm{W}}_{K} and Δ𝑾V\Delta{\bm{W}}_{V} denote the parameter updates, which are acquired by back-propagation from task-specific training objectives [Dai et al., 2022], i.e., through a supervised learning process on the original model. Considering the similarity in form between F~ICL\tilde{F}_{\mathrm{ICL}} and F~FT\tilde{F}_{\mathrm{FT}}, their focus is on establishing a connection between ICL and implicit fine-tuning of the original model.
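As a quick numerical sanity check of the decomposition of F_ICL above, the following minimal NumPy sketch (random shapes of our choosing) verifies that the full linear attention output splits exactly into the zero-shot term and the in-context update.

import numpy as np

rng = np.random.default_rng(0)
d, N, T = 4, 6, 3
W_V, W_K = rng.normal(size=(d, d)), rng.normal(size=(d, d))
X_D, X_T = rng.normal(size=(d, N)), rng.normal(size=(d, T))
q = rng.normal(size=(d, 1))

X = np.concatenate([X_D, X_T], axis=1)
full = W_V @ X @ (W_K @ X).T @ q                     # F_ICL(q)
W_ZSL = W_V @ X_T @ (W_K @ X_T).T                    # zero-shot part from query tokens
dW_ICL = sum(np.outer(W_V @ X_D[:, i], W_K @ X_D[:, i]) for i in range(N))
assert np.allclose(full, W_ZSL @ q + dW_ICL @ q)     # the decomposition is exact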

As a comparison, we turn our attention to establishing a connection between ICL and the gradient descent process of the dual model, rather than the original model. More specifically, we consider the dual model f(𝒙)=𝑾ϕ(𝒙)f({\bm{x}})={\bm{W}}\phi({\bm{x}}) of the nonlinear attention layer, where the weights 𝑾{\bm{W}} are updated according to the following loss (presented as Eq (9) in Section 3.2):

=1ηDi=1N(𝑾V𝒙i)T𝑾ϕ(𝑾K𝒙i),\mathcal{L}=-\frac{1}{\eta D}\sum_{i=1}^{N}({\bm{W}}_{V}{\bm{x}}_{i})^{T}{\bm{W}}\phi({\bm{W}}_{K}{\bm{x}}_{i}),

where 𝒙i{\bm{x}}_{i} is the ii-th demonstration token. The prediction output of the trained dual model is consistent with the ICL output of the attention layer. The gradient descent process of the dual model under this loss can be viewed through a self-supervised learning lens: unlike in supervised fine-tuning, where the original model performs gradient descent on a given objective (loss), the loss in Eq (9) is determined (derived) by the attention mechanism itself, and it does not require additional "true labels" to supervise each token 𝒙i{\bm{x}}_{i} (hence self-supervised). Therefore, modifications to this self-supervised learning loss will in turn cause corresponding modifications to the attention mechanism, as discussed in Section 4. We believe this perspective offers several benefits:

  • By analyzing from the dual perspective, we can transform the forward inference process into an optimization process. Since optimization processes are well understood and come with established theoretical tools (for example, the generalization error mentioned in Section 3.3), this transformation can provide insights back into the model's mechanisms.

  • It can be clearly observed that the dual model involves a self-supervised representation learning process from the dual perspective. Considering that there are many mature works in this area, we can draw on them to reflect on the attention mechanism, which has also inspired the attention modifications illustrated in Section 4.

  • Intuitively, this explanation might also be reasonable, as the original model is not explicitly instructed to provide the answer under some given objective (e.g., minimizing cross-entropy) during the ICL inference process. Instead, the underlying criterion should be determined by the model's own structure (self-supervised), as mentioned above.

In addition, although we do not target specific tasks like linear regression as the previous works mentioned in Section 6 do, we would like to point out that under those specific weight and input settings, an intuitive explanation can also be provided from a representation learning perspective. Here, we take the linear regression task and the weight construction considered by Von Oswald et al. [2023a] as an example. Specifically, it assumes that the structured input is 𝑯=[𝒉i]i=1N(d+1)×N{\bm{H}}=[{\bm{h}}_{i}]_{i=1}^{N}\in\mathbb{R}^{(d+1)\times N}, where 𝒉i=[𝒙i,yi]{\bm{h}}_{i}=[{\bm{x}}_{i},y_{i}] is sampled from some linear task y=𝒘T𝒙y={\bm{w}}^{T}{\bm{x}}, and the query token is 𝒉N+1=[𝒙N+1,𝒘0T𝒙N+1]{\bm{h}}_{N+1}=[{\bm{x}}_{N+1},-{\bm{w}}_{0}^{T}{\bm{x}}_{N+1}]. The considered linear self-attention layer takes the constructed weights and produces the query output as:

𝑾K=\displaystyle{\bm{W}}_{K}= 𝑾Q=[𝑰d×d000],𝑾V=[0d×d0𝒘0T1],𝑷=ηN𝑰,\displaystyle{\bm{W}}_{Q}=\begin{bmatrix}{\bm{I}}_{d\times d}&0\\ 0&0\end{bmatrix},{\bm{W}}_{V}=\begin{bmatrix}0_{d\times d}&0\\ {\bm{w}}_{0}^{T}&-1\end{bmatrix},{\bm{P}}=\frac{\eta}{N}{\bm{I}}, (35)
𝒉~N+1=𝒉N+1+𝑷(𝑾V𝑯)(𝑾K𝑯)T𝑾Q𝒉N+1,\displaystyle\tilde{{\bm{h}}}_{N+1}={\bm{h}}_{N+1}+{\bm{P}}({\bm{W}}_{V}{\bm{H}})({\bm{W}}_{K}{\bm{H}})^{T}{\bm{W}}_{Q}{\bm{h}}_{N+1},

where 𝒘0{\bm{w}}_{0} is the underlying initial matrix. Then the label part of 𝒉~N+1\tilde{{\bm{h}}}_{N+1} will have the form y~N+1=𝒘0T𝒙N+1+Δ𝒘T𝒙N+1=(𝒘0TηNi=1N(𝒘0T𝒙iyi)𝒙iT)𝒙N+1\tilde{y}_{N+1}=-{\bm{w}}_{0}^{T}{\bm{x}}_{N+1}+\Delta{\bm{w}}^{T}{\bm{x}}_{N+1}=-({\bm{w}}_{0}^{T}-\frac{\eta}{N}\sum_{i=1}^{N}({\bm{w}}_{0}^{T}{\bm{x}}_{i}-y_{i}){\bm{x}}_{i}^{T}){\bm{x}}_{N+1}=-\hat{y}_{N+1}, which is equivalent to the output y^N+1-\hat{y}_{N+1} (multiplied by 1-1) of the linear model y=𝒘T𝒙y={\bm{w}}^{T}{\bm{x}}, with 𝒘{\bm{w}} initialized as 𝒘0{\bm{w}}_{0}, after performing one step of gradient descent under the mean squared loss =12Ni=1N𝒘0T𝒙iyi2{\mathcal{L}}=\frac{1}{2N}\sum_{i=1}^{N}\|{\bm{w}}_{0}^{T}{\bm{x}}_{i}-y_{i}\|^{2}.
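The following minimal NumPy sketch (with w_0 = 0 for simplicity; the shapes and seed are our own) verifies this construction numerically: the label part of the attention update equals minus the prediction of a linear model after one gradient descent step on the squared loss.

import numpy as np

rng = np.random.default_rng(0)
d, N, eta = 3, 8, 0.1
w = rng.normal(size=d)                               # underlying task y = w^T x
X = rng.normal(size=(d, N))
y = w @ X
x_q = rng.normal(size=d)

# attention side, with the constructed weights of Eq (35) and w_0 = 0:
H = np.vstack([X, y])                                # h_i = [x_i; y_i]
h_q = np.concatenate([x_q, [0.0]])                   # query token with masked label
W_K = W_Q = np.block([[np.eye(d), np.zeros((d, 1))],
                      [np.zeros((1, d)), np.zeros((1, 1))]])
W_V = np.vstack([np.zeros((d, d + 1)),
                 np.concatenate([np.zeros(d), [-1.0]])[None, :]])
h_tilde = h_q + (eta / N) * (W_V @ H) @ (W_K @ H).T @ (W_Q @ h_q)

# gradient side: one GD step on L = 1/(2N) sum_i (w0^T x_i - y_i)^2 from w0 = 0
w1 = (eta / N) * X @ y
assert np.isclose(h_tilde[-1], -(w1 @ x_q))          # label part = -y_hat_{N+1}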

In practice, the underlying initial weight matrix 𝒘0{\bm{w}}_{0} is set to be approximately 𝟎{\bm{0}}, and thus the test input can be formed as 𝒉N+1=[𝒙N+1,𝟎]{\bm{h}}_{N+1}=[{\bm{x}}_{N+1},{\bm{0}}] [Von Oswald et al., 2023a]. In addition, when reading out the label y^N+1\hat{y}_{N+1}, the test prediction y~N+1\tilde{y}_{N+1} is multiplied again by 1-1, which can be done by a final projection matrix (or, equivalently, 𝑷=ηN𝑰{\bm{P}}=-\frac{\eta}{N}{\bm{I}}). In this case, we first note that the dual model of the linear attention layer can be written as f(𝒛)=𝑾𝒛f({\bm{z}})={\bm{W}}{\bm{z}} where 𝑾(d+1)×(d+1){\bm{W}}\in{\mathbb{R}}^{(d+1)\times(d+1)}; similar to Eq (9), it is trained under the loss below:

min𝑾=1ηi=1N(𝑷𝑾V𝒉i)T𝑾𝑾K𝒉i.\min_{{\bm{W}}}{\mathcal{L}}=-\frac{1}{\eta}\sum_{i=1}^{N}\left({\bm{P}}{\bm{W}}_{V}{\bm{h}}_{i}\right)^{T}{\bm{W}}{\bm{W}}_{K}{\bm{h}}_{i}. (36)

By substituting the corresponding weights from Eq (35), where we replace 𝑷=ηN𝑰{\bm{P}}=-\frac{\eta}{N}{\bm{I}} for the readout, the loss can be reformulated as:

min𝑾=1Ni=1N[0,yi]𝑾[𝒙i0].\min_{{\bm{W}}}{\mathcal{L}}=-\frac{1}{N}\sum_{i=1}^{N}\begin{bmatrix}0,~{}y_{i}\end{bmatrix}{\bm{W}}\begin{bmatrix}{\bm{x}}_{i}\\ 0\end{bmatrix}. (37)

Recalling that 𝒉i=[𝒙i,yi]{\bm{h}}_{i}=[{\bm{x}}_{i},y_{i}] is sampled from some linear task y=𝒘T𝒙y={\bm{w}}^{T}{\bm{x}}, and assuming that 𝑾F𝒘2\|{\bm{W}}\|_{F}\leq\|{\bm{w}}\|_{2}, it can easily be seen that the optimal solution of Eq (37) is

𝑾=[𝟎𝟎𝒘T𝟎].{\bm{W}}^{*}=\begin{bmatrix}{\bm{0}}&{\bm{0}}\\ {\bm{w}}^{T}&{\bm{0}}\end{bmatrix}. (38)

Furthermore, similar to Section 3.2, we take 𝑾Q𝒉N+1{\bm{W}}_{Q}{\bm{h}}_{N+1} as the input, where 𝑾Q{\bm{W}}_{Q} is constructed as in Eq (35) and 𝒉N+1=[𝒙N+1,0]{\bm{h}}_{N+1}=[{\bm{x}}_{N+1},0]; the optimal dual model then outputs f(𝑾Q𝒉N+1)=𝑾𝑾Q𝒉N+1=[𝟎,𝒘T𝒙N+1]=[𝟎,yN+1]f({\bm{W}}_{Q}{\bm{h}}_{N+1})={\bm{W}}^{*}{\bm{W}}_{Q}{\bm{h}}_{N+1}=[{\bm{0}},{\bm{w}}^{T}{\bm{x}}_{N+1}]=[{\bm{0}},y_{N+1}], where the label part is exactly the answer to the test query. Additionally, it would be interesting to explore, from the perspective of the dual model, how these weights converge to the constructed form in Eq (35) or other forms under this special setting, as previous works have illustrated. Investigating this issue goes beyond the scope of this paper, and we leave it for future exploration.