
Approximating How Single Head Attention Learns

Charlie Snell  Ruiqi Zhong  Dan Klein  Jacob Steinhardt
Computer Science Division, University of California, Berkeley
{csnell22, ruiqi-zhong, klein, jsteinhardt}@berkeley.edu
Abstract

Why do models often attend to salient words, and how does this evolve throughout training? We approximate model training as a two-stage process: early on in training, when the attention weights are uniform, the model learns to translate individual input words $i$ to output words $o$ if they co-occur frequently. Later, the model learns to attend to $i$ when the correct output is $o$ because it knows $i$ translates to $o$. To formalize this, we define a model property, Knowledge to Translate Individual Words (KTIW) (e.g. knowing that $i$ translates to $o$), and claim that it drives the learning of the attention. This claim is supported by the fact that before the attention mechanism is learned, KTIW can be learned from word co-occurrence statistics, but not the other way around. In particular, we can construct a training distribution that makes KTIW hard to learn; the learning of the attention then fails, and the model cannot even learn the simple task of copying the input words to the output. Our approximation explains why models sometimes attend to salient words, and inspires a toy example where a multi-head attention model can overcome the above hard training distribution by improving learning dynamics rather than expressiveness. We end by discussing the limitations of our approximation framework and suggest future directions.

1 Introduction

The attention mechanism underlies many recent advances in natural language processing, such as machine translation Bahdanau et al. (2015) and pretraining Devlin et al. (2019). While many works focus on analyzing attention in already-trained models Jain and Wallace (2019); Vashishth et al. (2019); Brunner et al. (2019), little is understood about how the attention mechanism is learned via gradient descent at training time.

These learning dynamics are important: standard, gradient-trained models can have distinctive inductive biases that distinguish them from more esoteric but equally accurate models. For example, in text classification, while standard models typically attend to salient (high gradient influence) words Serrano and Smith (2019a), recent work constructs accurate models that attend to irrelevant words instead Wiegreffe and Pinter (2019a); Pruthi et al. (2020). In machine translation, while standard gradient descent cannot train a high-accuracy transformer with relatively few attention heads, we can construct one by first training with more heads and then pruning the redundant ones Voita et al. (2019); Michel et al. (2019). To explain these differences, we need to understand how attention is learned at training time.

Our work opens the black box of attention training, focusing on attention in LSTM Seq2Seq models Luong et al. (2015) (Section 2.1). Intuitively, if the model knows that an individual input word $i$ translates to the correct output word $o$, it should attend to $i$ to minimize the loss. This motivates us to investigate the model's knowledge to translate individual words (abbreviated as KTIW), and we define a lexical probe $\beta$ to measure this property.

We claim that KTIW drives the attention mechanism to be learned. This is supported by the fact that KTIW can be learned when the attention mechanism has not been learned (Section 3.2), but not the other way around (Section 3.3). Specifically, even when the attention weights are frozen to be uniform, the probe $\beta$ still strongly agrees with the attention weights of a standardly trained model. On the other hand, when KTIW cannot be learned, the attention mechanism cannot be learned. In particular, we can construct a distribution where KTIW is hard to learn; as a result, the model fails to learn a simple task of copying the input to the output.

Now the problem of understanding how the attention mechanism is learned reduces to understanding how KTIW is learned. Section 2.3 builds a simpler proxy model that approximates how KTIW is learned, and Section 3.2 verifies empirically that the approximation is reasonable. This proxy model is simple enough to analyze, and we interpret its training dynamics with the classical IBM Translation Model 1 (Section 4.2), which translates individual word $i$ to $o$ if they co-occur more frequently.

Collapsing this chain of reasoning, we approximate model training as a two-stage process. Early on in training, when the attention mechanism has not been learned, the model learns KTIW through word co-occurrence statistics; KTIW later drives the learning of the attention.

Using these insights, we explain why attention weights sometimes correlate with word saliency in binary text classification (Section 5.1): the model first learns to “translate” salient words into labels and then attends to them. We also present a toy experiment (Section 5.2) where multi-head attention improves learning dynamics by combining differently initialized attention heads, even though a single-head model can express the target function.

Nevertheless, “all models are wrong”. Even though our framework successfully explains and predicts the above empirical phenomena, it cannot fully explain the behavior of attention-based models, since it is, after all, an approximation. Section 6 identifies and discusses two key assumptions: (1) information about a word tends to stay in the local hidden state (Section 6.1) and (2) attention weights are free variables (Section 6.2). We discuss future directions in Section 7.

2 Model

Section 2.1 defines the LSTM with attention Seq2Seq architecture. Section 2.2 defines the lexical probe $\beta$, which measures the model's knowledge to translate individual words (KTIW). Section 2.3 approximates how KTIW is learned early on in training by building a “bag of words” proxy model. Section 2.4 shows that our framework generalizes to binary classification.

Figure 1: Attention mechanism in recurrent models (left, Section 2.1) and word alignments in the classical model (right, Section 4.2) are learned similarly. Both first learn how to translate individual words (KTIW) under uniform attention weights/alignment at the start of training (upper, blue background), which then drives the attention mechanism/alignment to be learned (lower, red background).

2.1 Machine Translation Model

We use the dot-attention variant from Luong et al. (2015). The model maps an input sequence $\{x_l\}$ of length $L$ to an output sequence $\{y_t\}$ of length $T$. We first use LSTMs to embed $\{x_l\}\subset\mathcal{I}$ and $\{y_t\}\subset\mathcal{O}$ respectively, where $\mathcal{I}$ and $\mathcal{O}$ are the input and output vocab spaces, and obtain encoder and decoder hidden states $\{h_l\}$ and $\{s_t\}$. Then we calculate the attention logits $a_{t,l}$ by applying a learnable mapping to $h_l$ and $s_t$, and use softmax to obtain the attention weights $\alpha_{t,l}$:

a_{t,l}=s_{t}^{T}Wh_{l};\quad\alpha_{t,l}=\frac{e^{a_{t,l}}}{\sum_{l^{\prime}=1}^{L}e^{a_{t,l^{\prime}}}}. (1)

Next we sum the encoder hidden states $\{h_l\}$ weighted by the attention to obtain the “context vector” $c_t$, concatenate it with the decoder state $s_t$, and obtain the output vocab probabilities $p_t$ by applying a learnable neural network $N$ with one hidden layer and a softmax activation at the output:

c_{t}=\sum_{l=1}^{L}\alpha_{t,l}h_{l};\quad p_{t}=N([c_{t},s_{t}]). (2)

We train the model by minimizing the sum of the negative log likelihoods of all the output words $y_t$:

\mathcal{L}=-\sum_{t=1}^{T}\log p_{t,y_{t}}. (3)
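
For concreteness, here is a minimal PyTorch sketch of one decoder step of Equations (1)-(3); the function and variable names are ours, and batching, masking, and the LSTM encoder itself are omitted:

```python
import torch
import torch.nn.functional as F

# One decoder step of Equations (1)-(3), batch size 1.
# h: (L, d) encoder states; s_t: (d,) decoder state; W: (d, d) learnable map;
# N: network mapping the concatenated (2d,) vector to a (V,) distribution.
def decoder_step_loss(h, s_t, W, N, y_t):
    a_t = h @ (W @ s_t)              # attention logits a_{t,l}, shape (L,)
    alpha_t = F.softmax(a_t, dim=0)  # attention weights alpha_{t,l}
    c_t = alpha_t @ h                # context vector c_t, shape (d,)
    p_t = N(torch.cat([c_t, s_t]))   # output vocab probabilities p_t
    return -torch.log(p_t[y_t])      # this step's term of the loss in Eq. (3)
```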

2.2 Lexical Probe $\beta$

We define the lexical probe $\beta_{t,l}$ as:

\beta_{t,l}:=N([h_{l},s_{t}])_{y_{t}}, (4)

which means “the probability assigned to the correct word $y_t$ if the network attends only to the input encoder state $h_l$”. If we assume that $h_l$ only contains information about $x_l$, $\beta$ closely reflects KTIW, since $\beta$ can be interpreted as “the probability that $x_l$ is translated to the output $y_t$”.
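
A minimal sketch of the probe in Equation (4), reusing the hypothetical names from the sketch above; note this takes one extra forward pass of $N$ per input position:

```python
import torch

# Lexical probe of Equation (4): probability of the correct word y_t when the
# model attends solely to h_l, computed for every input position l.
def lexical_probe(h, s_t, N, y_t):
    return torch.stack([N(torch.cat([h_l, s_t]))[y_t] for h_l in h])
```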

Heuristically, to minimize the loss, the attention weights $\alpha$ should be attracted to positions with larger $\beta_{t,l}$ (this statement is heuristic rather than rigorous; see Appendix A.1 for a counterexample). Hence, we expect the learning of the attention to be driven by KTIW (Figure 1 left). We next discuss how KTIW is learned.

2.3 Early Dynamics of Lexical Knowledge

To approximate how KTIW is learned early on in training, we build a proxy model by making a few simplifying assumptions. First, since attention weights are uniform early on in training, we replace the attention distribution with a uniform one. Second, since we are defining individual word translation, we assume that information about each word is localized to its corresponding hidden state. Therefore, similar to Sun and Lu (2020), we replace $h_l$ with an input word embedding $e_{x_l}\in\mathbb{R}^{d}$, where $e$ represents the word embedding matrix and $d$ is the embedding dimension. Third, to simplify analysis, we assume $N$ contains only one linear layer $W\in\mathbb{R}^{|\mathcal{O}|\times d}$ before the softmax activation and ignore the decoder state $s_t$. Putting these assumptions together, we define a new proxy model that produces output vocab probabilities $p_t$:

\forall t,\quad p_{t}=\sigma\left(\frac{1}{L}\sum_{l=1}^{L}We_{x_{l}}\right). (5)

On a high level, this proxy averages the embeddings of the input “bag of words” and produces a distribution over output vocabs to predict the output “bag of words”. This implies that the sets of input and output words for each sentence pair are sufficient statistics for this proxy.

The probe $\beta^{\mathrm{px}}$ can be similarly defined as:

\beta^{\mathrm{px}}_{t,l}=\sigma(We_{x_{l}})_{y_{t}}. (6)
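
A minimal sketch of the proxy model and its probe (Equations 5-6); the function names are ours, and E and W stand for the embedding matrix $e$ and the output layer $W$:

```python
import torch
import torch.nn.functional as F

# Bag-of-words proxy, Equations (5)-(6). E: (|I|, d) embeddings; W: (|O|, d).
def proxy_forward(x, E, W):          # x: LongTensor of input word ids
    avg = E[x].mean(dim=0)           # average the input word embeddings
    return F.softmax(W @ avg, dim=0) # p_t, identical for every output step t

def proxy_probe(x, E, W, y_t):
    probs = F.softmax(E[x] @ W.T, dim=1)  # softmax(W e_{x_l}) for each l
    return probs[:, y_t]             # beta^px_{t,l}, Equation (6)
```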

We provide more intuitions on how this proxy learns in Section 4.

2.4 Binary Classification Model

Binary classification can be reduced to “machine translation” with $T=1$ and $|\mathcal{O}|=2$. We drop the subscript $t=1$ when discussing classification.

We use the standard architecture from Wiegreffe and Pinter (2019a). After obtaining the encoder hidden states $\{h_l\}$, we calculate the attention logits $a_l$ by applying a feed-forward neural network with one hidden layer, and take the softmax of $a$ to obtain the attention weights $\alpha$:

a_{l}=v^{T}\mathrm{ReLU}(Qh_{l});\quad\alpha_{l}=\frac{e^{a_{l}}}{\sum_{l^{\prime}=1}^{L}e^{a_{l^{\prime}}}}, (7)

where $Q$ and $v$ are learnable.

We sum the hidden states $\{h_l\}$ weighted by the attention, feed the result to a final linear layer, and apply the sigmoid activation function ($\sigma$) to obtain the probability of the positive class:

p^{\text{pos}}=\sigma\left(W^{T}\sum_{l=1}^{L}\alpha_{l}h_{l}\right)=\sigma\left(\sum_{l=1}^{L}\alpha_{l}W^{T}h_{l}\right). (8)

Similar to the machine translation model (Section 2.1), we define the “lexical probe”:

\beta_{l}:=\sigma((2y-1)W^{T}h_{l}), (9)

where $y\in\{0,1\}$ is the label and $2y-1\in\{-1,1\}$ controls the sign.

On a high level, Sun and Lu (2020) focus on binary classification and provide almost exactly the same arguments as ours. Specifically, their polarity score $s_l$ equals $\frac{\beta_l}{1-\beta_l}$ in our context, and they provide a more subtle analysis of how the attention mechanism is learned in binary classification.

3 Empirical Evidence

We provide evidence that KTIW drives the learning of the attention early on in training: KTIW can be learned when the attention mechanism has not been learned (Section 3.2), but not the other way around (Section 3.3).

3.1 Measuring Agreement

We start by describing how to evaluate the agreement between quantities of interest, such as $\alpha$ and $\beta$. For any input-output sentence pair $(x^m, y^m)$ and each output index $t$, the quantities $\alpha^{m}_{t},\beta^{m}_{t},\beta^{\mathrm{px},m}_{t}\in\mathbb{R}^{L^{m}}$ all associate each input position $l$ with a real number. Since attention weights and word alignments tend to be sparse, we focus on the agreement of the highest-valued position. For $u,v\in\mathbb{R}^{L}$, we formally define the agreement of $v$ with $u$ as:

\mathcal{A}(u,v):=\mathbf{1}\left[\left|\left\{j \mid v_{j}>v_{\operatorname{arg\,max}_{i}u_{i}}\right\}\right|<5\%\cdot L\right], (10)

which means “whether the highest-valued position (dimension) in $u$ is among the top 5% highest-valued positions in $v$”. We average the $\mathcal{A}$ values across all output words on the validation set to measure the agreement between two model properties. We also report Kendall's $\tau$ rank correlation coefficient in Appendix Table 3 for completeness.
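
A minimal numpy sketch of Equation (10); the function name is ours:

```python
import numpy as np

# Agreement metric of Equation (10): is the argmax of u among the top 5%
# highest-valued positions of v?
def agreement(u, v):
    u, v = np.asarray(u), np.asarray(v)
    n_above = np.sum(v > v[np.argmax(u)])  # |{j : v_j > v_{argmax_i u_i}}|
    return float(n_above < 0.05 * len(u))
```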

We denote its random baseline as $\hat{\mathcal{A}}$; $\hat{\mathcal{A}}$ is close to but not exactly $5\%$ because of integer rounding.

Contextualized Agreement Metric.

However, since datasets differ in their sentence length distributions and in the variance of attention weights across random seeds, this agreement metric can be hard to interpret directly. Therefore, we contextualize it with model performance. We use the standard method to train a model to convergence over $\mathcal{T}$ steps and denote its attention weights by $\alpha$; next we train the same model from scratch with another random seed. We denote its attention weights at training step $\tau$ by $\hat{\alpha}(\tau)$ and its performance by $\hat{p}(\tau)$. Roughly speaking, when $\tau<\mathcal{T}$, both $\mathcal{A}(\alpha,\hat{\alpha}(\tau))$ and $\hat{p}(\tau)$ increase as $\tau$ increases. We define the contextualized agreement $\xi$ as:

\xi(u,v):=\hat{p}(\inf\{\tau \mid \mathcal{A}(\alpha,\hat{\alpha}(\tau))>\mathcal{A}(u,v)\}). (11)

In other words, we find the training step $\tau_0$ at which the attention weights $\hat{\alpha}(\tau_0)$ agree with the standard attention weights $\alpha$ more than $u$ and $v$ agree, and report the performance at that iteration. See Figure 2. We refer to the model performance when training finishes ($\tau=\mathcal{T}$) as $\xi^{*}$. Table 1 lists the rough intuition for each abstract symbol.
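
A minimal sketch of how $\xi$ could be computed; the time-ordered checkpoint log of $(\mathcal{A}(\alpha,\hat{\alpha}(\tau)),\hat{p}(\tau))$ pairs is our own assumed format:

```python
# Contextualized agreement xi, Equation (11). `checkpoints` is assumed to be a
# time-ordered log of (A(alpha, alpha_hat(tau)), p_hat(tau)) from the re-run.
def contextualized_agreement(target_agreement, checkpoints):
    for agree, perf in checkpoints:
        if agree > target_agreement:  # first tau exceeding A(u, v)
            return perf               # report p_hat at that step
    return None                       # threshold never reached before T
```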

Figure 2: We find the smallest training step $\tau_0$ where $\mathcal{A}(\alpha,\hat{\alpha}(\tau_0))>\mathcal{A}(u,v)$, and define $\xi(u,v):=\hat{p}(\tau_0)$.
Symbol Intuition
$\alpha$ Attention weights
$\beta$ Lexical probe
$\gamma$ Logits, before $\sigma$ (softmax/sigmoid)
$\Delta$ Word saliency by gradient
$\mathcal{A}$ Agreement metric
$\xi$ Contextualized agreement metric
$\tau$ Training steps
$C$ Co-occurrence statistics (count table)
$\sigma$ Softmax or sigmoid function
Table 1: Intuitions for each abstract symbol (some occur later in the paper). The first group is model activations/gradients, the second metrics, and the third others.
Task $\mathcal{A}(\alpha,\beta^{\mathrm{uf}})$ $\mathcal{A}(\beta^{\mathrm{uf}},\beta^{\mathrm{px}})$ $\mathcal{A}(\Delta,\beta^{\mathrm{uf}})$ $\mathcal{A}(\alpha,\beta)$ $\hat{\mathcal{A}}$ $\xi(\alpha,\beta^{\mathrm{uf}})$ $\xi(\alpha,\beta)$ $\xi^{*}$
IMDB 53 82 62 60 5 87 87 90
AG News 39 55 43 48 6 94 95 96
20 NG 65 41 65 63 5 91 85 94
SST 20 34 22 25 8 78 82 84
Multi30k 31 34 27 49 7 43 49 66
IWSLT14 36 39 28 55 7 36 44 67
News It-Pt 29 39 25 52 6 22 25 55
Table 2: Tasks above the horizontal line are classification; those below are translation. The (contextualized) agreement metrics $\mathcal{A}$ ($\xi$) are described in Section 3.1. Across all tasks, $\mathcal{A}(\alpha,\beta)$, $\mathcal{A}(\alpha,\beta^{\mathrm{uf}})$, and $\mathcal{A}(\beta^{\mathrm{uf}},\beta^{\mathrm{px}})$ significantly outperform the random baseline $\hat{\mathcal{A}}$, and the corresponding contextualized interpretations $\xi$ are also non-trivial. This implies that 1) the proxy model from Section 2.3 approximates well how KTIW is learned, 2) the attention weights $\alpha$ and the probe $\beta$ of KTIW strongly agree, and 3) KTIW can still be learned when the attention weights are uniform.

Datasets.

We evaluate the agreement metrics $\mathcal{A}$ and $\xi$ on multiple machine translation and text classification datasets. For machine translation, we use Multi30k (En-De), IWSLT'14 (De-En), and News Commentary v14 (En-Nl, En-Pt, and It-Pt). For text classification, we use IMDB Sentiment Analysis, AG News Corpus, 20 Newsgroups (20 NG), Stanford Sentiment Treebank, Amazon Review, and the Yelp Open Dataset. All of them are in English. The details and citations for these datasets are in Appendix A.5. We use token accuracy (Appendix Tables 6, 4, and 8 include results for BLEU) to evaluate the performance of the translation models and accuracy to evaluate the classification models.

Due to space limits, we round to integers and include a subset of the datasets in Table 2 in the main paper. Appendix Table 5 includes the full results.

3.2 KTIW Learns under Uniform Attention

Even when the attention mechanism has not been learned, KTIW can still be learned. We train the same model architecture with the attention weights frozen to be uniform, and denote its lexical probe by $\beta^{\mathrm{uf}}$. Across all tasks, $\mathcal{A}(\alpha,\beta^{\mathrm{uf}})$ and $\mathcal{A}(\beta^{\mathrm{uf}},\beta^{\mathrm{px}})$ significantly outperform the random baseline $\hat{\mathcal{A}}$, and the contextualized agreement $\xi(\alpha,\beta^{\mathrm{uf}})$ is also non-trivial. (Empirically, $\beta^{\mathrm{px}}$ converges to the unigram weights of a bag-of-words logistic regression model, and hence $\beta^{\mathrm{px}}$ does capture an interpretable notion of “keywords”; see Appendix A.10.) This indicates that 1) the proxy we built in Section 2.3 approximates KTIW, and 2) even when the attention weights are uniform, KTIW is still learned.

3.3 Attention Fails When KTIW Fails

We consider a simple task of copying the input to the output, where each input is a permutation of the same set of 40 vocab types. Under this training distribution, the proxy model provably cannot learn: every input-output pair contains exactly the same set of input-output words (we provide more intuition on this in Section 4). As a result, our framework predicts that KTIW is unlikely to be learned, and hence the learning of attention is likely to fail.

The training curves for learning to copy the permutations are shown in Figure 3 left, colored red: the model sometimes fails to learn. For the control experiment, if we instead randomly sample and permute 40 vocabs from 60 vocab types as training samples, the model successfully learns from this distribution every time (blue curves). Therefore, even though the model is able to express this task, it might fail to learn it when KTIW is not learned. The same qualitative conclusion holds for a training distribution that mixes permutations of two disjoint sets of words (Figure 3 right); Appendix A.3 illustrates the intuition.
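
A minimal numpy sketch of the two training distributions; the function names are ours, and the 20-seed training loop is omitted:

```python
import numpy as np

rng = np.random.default_rng(0)

# Hard distribution: every example permutes the same 40 word types, so the
# sets of input/output words carry no co-occurrence signal for the proxy.
def hard_example():
    x = rng.permutation(40)
    return x, x.copy()               # copy task: target equals input

# Control distribution: permute 40 word types sampled from 60 types, which
# the model learns reliably.
def control_example():
    x = rng.permutation(60)[:40]
    return x, x.copy()
```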

For binary classification, it follows from the model definition that the attention mechanism cannot be learned if KTIW cannot be learned, since

p^{\text{correct}}=\sigma\left(\sum_{l=1}^{L}\alpha_{l}\sigma^{-1}(\beta_{l})\right);\quad\sigma(x)=\frac{1}{1+e^{-x}}, (12)

and the model needs to attend to positions with higher $\beta$ in order to predict correctly and minimize the loss. For completeness, we include results in Appendix A.6 where we freeze $\beta$ and find that the learning of the attention fails.

Figure 3: Each curve shows accuracy on the test distribution vs. number of training steps, for 20 random seeds per setting. When trained on a distribution of permutations of 40 vocabs (red, Left) or a mixture of permutations (Right), the model sometimes fails to learn and converges more slowly.

4 Connection to IBM Model 1

Section 2.3 built a simple proxy model to approximate how KTIW is learned when the attention weights are uniform early on in training, and Section 3.2 verified that such an approximation is empirically sound. However, it is still hard to intuitively reason about how this proxy model learns. This section provides more intuitions by connecting its initial gradient (Section 4.1) to the classical IBM Model 1 alignment algorithm Brown et al. (1993) (Section 4.2).

4.1 Derivative at Initialization

We continue from the end of Section 2.3. For each input word $i$ and output word $o$, we are interested in understanding the probability that $i$ assigns to $o$, defined as:

\theta^{\mathrm{px}}_{i,o}:=\sigma(We_{i})_{o}. (13)

This quantity is directly tied to $\beta^{\mathrm{px}}$, since $\beta^{\mathrm{px}}_{t,l}=\theta^{\mathrm{px}}_{x_l,y_t}$. Using superscript $m$ to index sentence pairs in the dataset, the total loss $\mathcal{L}$ is:

\mathcal{L}=-\sum_{m}\sum_{t=1}^{T^{m}}\log\left(\sigma\left(\frac{1}{L^{m}}\sum_{l=1}^{L^{m}}We_{x^{m}_{l}}\right)_{y^{m}_{t}}\right). (14)

Suppose each $e_i$ or $W_o$ is independently initialized from a normal distribution $\mathcal{N}(0,I_d/d)$ and we minimize $\mathcal{L}$ over $W$ and $e$ using gradient flow; then the values of $e$ and $W$ are uniquely defined at each continuous time step $\tau$. By some straightforward but tedious calculations (details in Appendix A.2), the derivative of $\theta_{i,o}$ when training starts is:

\lim_{d\rightarrow\infty}\frac{\partial\theta^{\mathrm{px}}_{i,o}}{\partial\tau}(\tau=0)\overset{p}{\to}2\left(C^{\mathrm{px}}_{i,o}-\frac{1}{|\mathcal{O}|}\sum_{o^{\prime}\in\mathcal{O}}C^{\mathrm{px}}_{i,o^{\prime}}\right), (15)

where $\overset{p}{\to}$ denotes convergence in probability and $C^{\mathrm{px}}_{i,o}$ is defined as

C^{\mathrm{px}}_{i,o}:=\sum_{m}\sum^{L^{m}}_{l=1}\sum^{T^{m}}_{t=1}\frac{1}{L^{m}}\mathbf{1}[x^{m}_{l}=i]\,\mathbf{1}[y^{m}_{t}=o]. (16)

Equation 15 tells us that $\beta^{\mathrm{px}}_{t,l}=\theta^{\mathrm{px}}_{x_l,y_t}$ is likely to be larger if $C_{x_l,y_t}$ is large. The definition of $C$ seems hard to interpret from Equation 16, but in the next subsection we will find that this quantity naturally corresponds to the “count table” used in the classical IBM Model 1 alignment learning algorithm.
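
A minimal numpy sketch of the count table in Equation (16); the function name is ours:

```python
import numpy as np

# Count table C of Equation (16): add 1/L^m to C[i, o] whenever input word i
# and output word o co-occur in sentence pair m.
def count_table(pairs, n_in, n_out):
    C = np.zeros((n_in, n_out))
    for x, y in pairs:               # x, y: sequences of word ids
        for i in x:
            for o in y:
                C[i, o] += 1.0 / len(x)
    return C
```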

4.2 IBM Model 1 Alignment Learning

The classical alignment algorithm aims to learn which input word is responsible for each output word (e.g. knowing that $y_2$ “movie” aligns to $x_2$ “Film” in Figure 1 upper left) from a set of input-output sentence pairs. IBM Model 1 Brown et al. (1993) starts with a 2-dimensional count table $C^{\mathrm{IBM}}$ indexed by $i\in\mathcal{I}$ and $o\in\mathcal{O}$, denoting input and output vocabs. Whenever vocabs $i$ and $o$ co-occur in an input-output pair, we add $\frac{1}{L}$ to the $C^{\mathrm{IBM}}_{i,o}$ entry (steps 1 and 2 in Figure 1 right). After updating $C^{\mathrm{IBM}}$ over the entire dataset, $C^{\mathrm{IBM}}$ is exactly the same as $C^{\mathrm{px}}$ defined in Equation 16. We drop the superscript of $C$ to keep the notation uncluttered.

Given $C$, the classical model estimates a probability distribution of “what output word $o$ does the input word $i$ translate to” (Figure 1 right, step 3) as

\text{Trans}(o|i)=\frac{C_{i,o}}{\sum_{o^{\prime}}C_{i,o^{\prime}}}. (17)

In a pair of sequences $(\{x_l\},\{y_t\})$, the probability $\beta^{\mathrm{IBM}}$ that $x_l$ is translated to the output $y_t$ is:

\beta^{\mathrm{IBM}}_{t,l}:=\text{Trans}(y_{t}|x_{l}), (18)

and the alignment probability $\alpha^{\mathrm{IBM}}$ that “$x_l$ is responsible for outputting $y_t$, versus the other $x_{l^{\prime}}$” is

\alpha^{\mathrm{IBM}}(t,l)=\frac{\beta^{\mathrm{IBM}}_{t,l}}{\sum_{l^{\prime}=1}^{L}\beta^{\mathrm{IBM}}_{t,l^{\prime}}}, (19)

which monotonically increases with respect to $\beta^{\mathrm{IBM}}_{t,l}$. See Figure 1 right, step 5.
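
A minimal numpy sketch of Equations (17)-(19) on top of the count table; it assumes every input vocab has a non-zero row in $C$:

```python
import numpy as np

# IBM Model 1 quantities (Equations 17-19) built from the count table C.
def ibm_model1(C, x, y):
    trans = C / C.sum(axis=1, keepdims=True)        # Trans(o | i), Eq. (17)
    beta = trans[np.ix_(x, y)].T                    # beta[t, l] = Trans(y_t | x_l)
    alpha = beta / beta.sum(axis=1, keepdims=True)  # alignment, Eq. (19)
    return beta, alpha
```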

4.3 Visualizing Aforementioned Tasks

Figure 1 (right) visualizes the count table $C$ for the machine translation task, and illustrates how KTIW is learned and drives the learning of attention. We provide a similar visualization of why KTIW is hard to learn under a distribution of vocab permutations (Section 3.3) in Figure 4, and of how word polarity is learned in binary classification (Section 2.4) in Figure 5.

Figure 4: The co-occurrence table $C$ is non-informative under a distribution of permutations. Therefore, this distribution is hard for the attention-based model to learn.
Figure 5: The classical model first learns word polarity, which later attracts attention.

5 Application

5.1 Interpretability in Classification

We use a gradient-based method Ebrahimi et al. (2018) to approximate the influence $\Delta_l$ of each input word $x_l$. The column $\mathcal{A}(\Delta,\beta^{\mathrm{uf}})$ in Table 2 reports the agreement between $\Delta$ and $\beta^{\mathrm{uf}}$, which significantly outperforms the random baseline. Since KTIW initially drives the attention mechanism to be learned, this explains why attention weights correlate with word saliency on many classification tasks, even though the training objective does not explicitly reward this.

5.2 Multi-head Improves Training Dynamics

We saw in Section 3.3 that learning to copy sequences under a distribution of permutations is hard and the model can fail to learn; however, sometimes it is still able to learn. Can we improve learning and overcome this hard distribution by ensembling several attention parameters together?

We introduce a multi-head attention architecture by summing the context vectors $c_t$ obtained by each head. Suppose there are $K$ heads, each indexed by $k$; then, similar to Section 2.1:

a^{(k)}_{t,l}=s_{t}^{T}W^{(k)}h_{l};\quad\alpha^{(k)}_{t,l}=\frac{e^{a^{(k)}_{t,l}}}{\sum_{l^{\prime}=1}^{L}e^{a^{(k)}_{t,l^{\prime}}}}, (20)

and the context vectors and final probability $p_t$ are defined as:

c^{(k)}_{t}=\sum_{l=1}^{L}\alpha^{(k)}_{t,l}h_{l};\quad p_{t}=N\left(\left[\sum_{k=1}^{K}c^{(k)}_{t},s_{t}\right]\right), (21)

where the $W^{(k)}$ are separate learnable parameters.
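
A minimal PyTorch sketch of one decoder step of Equations (20)-(21); names are ours and batching is omitted:

```python
import torch
import torch.nn.functional as F

# One decoder step of the multi-head model, Equations (20)-(21): each head k
# computes its own attention weights, and the K context vectors are summed
# before being passed (with the decoder state) to the output network N.
def multihead_step(h, s_t, Ws, N):   # Ws: list of K matrices W^(k)
    c = sum(F.softmax(h @ (W_k @ s_t), dim=0) @ h for W_k in Ws)
    return N(torch.cat([c, s_t]))    # p_t
```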

We call $W^{(k)}_{\mathrm{init}}$ a good initialization if training with this single head converges, and bad otherwise. We use rejection sampling to find good/bad head initializations and combine them to form 8-head ($K=8$) attention models. We experiment with 3 scenarios: (1) all head initializations are bad, (2) only one initialization is good, and (3) initializations are sampled independently at random.

Figure 6 presents the training curves. If all head initializations are bad, the model fails to converge (red). However, as long as one of the eight initializations is good, the model can converge (blue). As the number of heads increases, the probability that all initializations are bad is exponentially small if all initializations are sampled independently; hence the model converges with very high probability (green). In this experiment, multi-head attention improves not by increasing expressiveness, since one head is sufficient to accomplish the task, but by improving the learning dynamics.

Figure 6: If all head initializations (head-init) are bad (red), the model is likely to fail; if one of the head-init is good (blue), it is likely to learn; with high chance, at least one out of eight random head-init is good (green). We used 20 random seeds for each setting.

6 Assumptions

We revisit the approximation assumptions used in our framework. Section 6.1 discusses whether the lexical probe $\beta_{t,l}$ necessarily reflects local information about the input word $x_l$, and Section 6.2 discusses whether attention weights can be freely optimized to attend to large $\beta$. These assumptions are accurate enough to predict the phenomena in Sections 3 and 5, but they are not always true and hence warrant further research. We provide simple examples where these assumptions fail.

6.1 $\beta$ Remains Local

We use a toy classification task to show that early on in training, as expected, $\beta^{\mathrm{uf}}$ is larger near positions that contain the keyword. However, counterintuitively, $\beta^{\mathrm{uf}}_{L}$ ($\beta$ at the last position in the sequence) becomes the largest if we train the model for too long under uniform attention weights.

In this toy task, each input is a length-40 sequence of words sampled uniformly at random from $\{1,\dots,40\}$; a sequence is positive if and only if the keyword “1” appears in it. We restrict “1” to appear only once in each positive sequence, and use rejection sampling to balance positive and negative examples. Let $l^{*}$ be the position where $x_{l^{*}}=1$.

For the positive sequences, we examine the log-odds ratio $\gamma_l$ before the sigmoid activation in Equation 8, since the $\beta$ values will all be close to 1 and comparing $\gamma$ is more informative: $\gamma_{l}:=\log\frac{\beta^{\mathrm{uf}}_{l}}{1-\beta^{\mathrm{uf}}_{l}}$.

We measure four quantities: 1) $\gamma_{l^{*}}$, the log-odds ratio if the model attends only to the keyword position; 2) $\gamma_{l^{*}+1}$, one position after the keyword; 3) $\bar{\gamma}:=\frac{1}{L}\sum_{l=1}^{L}\gamma_{l}$, corresponding to uniform attention weights; and 4) $\gamma_{L}$, if the model attends to the last hidden state. If $\gamma_l$ only contains information about the word $x_l$, we should expect:

\text{Hypothesis 1}:\quad\gamma_{l^{*}}\gg\bar{\gamma}\gg\gamma_{L}\approx\gamma_{l^{*}+1}. (22)

However, if we accept the conventional wisdom that hidden states contain information about nearby words Khandelwal et al. (2018), we should expect:

\text{Hypothesis 2}:\quad\gamma_{l^{*}}\gg\gamma_{l^{*}+1}\gg\bar{\gamma}\approx\gamma_{L}. (23)

To verify these hypotheses, we plot how $\gamma_{l^{*}}$, $\gamma_{l^{*}+1}$, $\bar{\gamma}$, and $\gamma_{L}$ evolve as training proceeds in Figure 7. Hypothesis 2 is indeed true when training starts; however, we find the following to be true asymptotically:

\text{Observation 3}:\quad\gamma_{L}\gg\gamma_{l^{*}+1}\gg\bar{\gamma}\approx\gamma_{l^{*}}, (24)

which is wildly different from Hypothesis 2. If we train under uniform attention weights for too long, the information about keywords can freely flow to other non-local hidden states.

Figure 7: When training begins, Hypothesis 2 (Equation 23) is true; however, asymptotically, Observation 3 (Equation 24) is true.

6.2 Attention Weights are Free Variables

In Section 2.1 we assumed that the attention weights $\alpha$ behave like free variables that can assign arbitrarily high probabilities to positions with larger $\beta$. However, $\alpha$ is produced by a model, and sometimes learning the correct $\alpha$ can be challenging.

Let $\pi$ be a random permutation of the integers from 1 to 40, and suppose we want to learn the function $f$ that permutes the input with $\pi$:

f([x_{1},x_{2},\dots,x_{40}]):=[x_{\pi(1)},x_{\pi(2)},\dots,x_{\pi(40)}]. (25)

Inputs $x$ are randomly sampled from a vocab of size 60, as in Section 3.3. Even though $\beta^{\mathrm{uf}}$ behaves exactly the same for these two tasks, copying a sequence is much easier to learn than the permutation function: while the model always reaches perfect accuracy in the former setting within 300 iterations, it always fails in the latter. The LSTM appears to have a built-in inductive bias toward learning monotonic attention.
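
A minimal numpy sketch of the two target functions; names are ours:

```python
import numpy as np

rng = np.random.default_rng(0)
pi = rng.permutation(40)             # the fixed random target permutation

def example(permute):
    x = rng.permutation(60)[:40]     # inputs as in Section 3.3
    y = x[pi] if permute else x.copy()  # Equation (25) vs. plain copying
    return x, y
```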

7 Conclusions and Future Directions

Our work tries to open the black box of attention training. Early on in training, LSTM attention models first learn how to translate individual words from bag-of-words co-occurrence statistics, which then drives the learning of the attention. Our framework explains why attention weights obtained by standard training often correlate with saliency, and how multi-head attention can increase performance by improving the training dynamics rather than expressiveness. These phenomena could not be explained if we treated training as a black box.

A growing number of theoretical deep learning papers study the optimization trajectory, since many important properties of neural networks are determined by what happens at training time Jacot et al. (2018); Du et al. (2018); Şimşekli et al. (2019). However, it is hard to extract useful intuitions for practitioners from these results in abstract high-dimensional parameter space. In contrast, the NLP community takes another path and mostly interprets models using intuitive concepts Andreas and Klein (2017); Strobelt et al. (2018); Hewitt and Liang (2019), while relatively few works look at the training dynamics. We look forward to future work that can qualitatively predict the training dynamics using intuitive concepts while formally reasoning about the optimization trajectory.

8 Ethical Considerations

We present a new framework for understanding and predicting the behavior of an existing technology: the attention mechanism in recurrent neural networks. We do not propose any new technologies or datasets that could directly raise ethical questions. However, it is useful to keep in mind that our framework is far from solving the question of neural network interpretability, and should not be interpreted as ground truth in high-stakes domains like medicine or recidivism prediction. We are explicit about the limitations of our framework, which we discuss in Section 6.

9 Reasons to Reject

Almost all reviewers are worried that the assumption “lexical information stays local” is “wrong”. For example, Jain and Wallace (2019); Serrano and Smith (2019b); Wiegreffe and Pinter (2019b); Pruthi et al. (2020) all show that information flows across hidden states. I completely agree that information diffuses across positions; in fact, I reported the same finding in one of my own prior works Zhong et al. (2019).

Nevertheless, an assumption being wrong does not mean that we should not apply it. For example, according to modern physics, it can be argued that friction does not “exist”, since it is only an aggregation of electromagnetic forces at a microscopic level. However, the concept of “friction” is still extremely useful in engineering domains and has successfully contributed to important technological progress. From an instrumentalist view, a scientific theory should ultimately be benchmarked by its ability to explain and predict (unknown) empirical phenomena, not by whether it is literally true.

We demonstrated the predictive power of our theoretical framework in Sections 3.2, 3.3, and 5, and it is up to the readers to decide whether our approximation is accurate enough. Anecdotally, before running the experiments in Section 3.3, I expected the model to always be able to learn permutation copying, since the task looks extremely simple. However, our theory predicts otherwise, and the empirical result agrees with the theoretical prediction. The result indeed surprised me at the time of experimentation (in March 2020).

References

  • Andreas and Klein (2017) Jacob Andreas and Dan Klein. 2017. Analogs of linguistic structure in deep representations. In Proceedings of the 2017 Conference on Empirical Methods in Natural Language Processing, pages 2893–2897, Copenhagen, Denmark. Association for Computational Linguistics.
  • Bahdanau et al. (2015) Dzmitry Bahdanau, Kyunghyun Cho, and Yoshua Bengio. 2015. Neural machine translation by jointly learning to align and translate. In 3rd International Conference on Learning Representations, ICLR 2015, San Diego, CA, USA, May 7-9, 2015, Conference Track Proceedings.
  • Brown et al. (1993) Peter F Brown, Stephen A Della Pietra, Vincent J Della Pietra, and Robert L Mercer. 1993. The mathematics of statistical machine translation: Parameter estimation. Computational linguistics, 19(2):263–311.
  • Brunner et al. (2019) Gino Brunner, Yang Liu, Damian Pascual, Oliver Richter, Massimiliano Ciaramita, and Roger Wattenhofer. 2019. On identifiability in transformers. In International Conference on Learning Representations.
  • Cettolo et al. (2015) Mauro Cettolo, Jan Niehues, Sebastian Stüker, Luisa Bentivogli, and Marcello Federico. 2015. Report on the 11th IWSLT evaluation campaign, IWSLT 2014.
  • Devlin et al. (2019) Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. 2019. BERT: Pre-training of deep bidirectional transformers for language understanding. In Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers), pages 4171–4186, Minneapolis, Minnesota. Association for Computational Linguistics.
  • Du et al. (2018) Simon S Du, Wei Hu, and Jason D Lee. 2018. Algorithmic regularization in learning deep homogeneous models: Layers are automatically balanced. In Advances in Neural Information Processing Systems, pages 384–395.
  • Ebrahimi et al. (2018) Javid Ebrahimi, Anyi Rao, Daniel Lowd, and Dejing Dou. 2018. HotFlip: White-box adversarial examples for text classification. In Proceedings of the 56th Annual Meeting of the Association for Computational Linguistics (Volume 2: Short Papers), pages 31–36, Melbourne, Australia. Association for Computational Linguistics.
  • Elliott et al. (2016) Desmond Elliott, Stella Frank, Khalil Sima’an, and Lucia Specia. 2016. Multi30K: Multilingual English-German image descriptions. In Proceedings of the 5th Workshop on Vision and Language, pages 70–74, Berlin, Germany. Association for Computational Linguistics.
  • Hewitt and Liang (2019) John Hewitt and Percy Liang. 2019. Designing and interpreting probes with control tasks. In Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP), pages 2733–2743.
  • Jacot et al. (2018) Arthur Jacot, Franck Gabriel, and Clément Hongler. 2018. Neural tangent kernel: Convergence and generalization in neural networks. In Advances in neural information processing systems, pages 8571–8580.
  • Jain and Wallace (2019) Sarthak Jain and Byron C. Wallace. 2019. Attention is not Explanation. In Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers), pages 3543–3556, Minneapolis, Minnesota. Association for Computational Linguistics.
  • Khandelwal et al. (2018) Urvashi Khandelwal, He He, Peng Qi, and Dan Jurafsky. 2018. Sharp nearby, fuzzy far away: How neural language models use context. In Proceedings of the 56th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), pages 284–294, Melbourne, Australia. Association for Computational Linguistics.
  • Luong et al. (2015) Thang Luong, Hieu Pham, and Christopher D. Manning. 2015. Effective approaches to attention-based neural machine translation. In Proceedings of the 2015 Conference on Empirical Methods in Natural Language Processing, pages 1412–1421, Lisbon, Portugal. Association for Computational Linguistics.
  • Maas et al. (2011) Andrew L. Maas, Raymond E. Daly, Peter T. Pham, Dan Huang, Andrew Y. Ng, and Christopher Potts. 2011. Learning word vectors for sentiment analysis. In Proceedings of the 49th Annual Meeting of the Association for Computational Linguistics: Human Language Technologies, pages 142–150, Portland, Oregon, USA. Association for Computational Linguistics.
  • Michel et al. (2019) Paul Michel, Omer Levy, and Graham Neubig. 2019. Are sixteen heads really better than one? In Advances in Neural Information Processing Systems, pages 14014–14024.
  • Pruthi et al. (2020) Danish Pruthi, Mansi Gupta, Bhuwan Dhingra, Graham Neubig, and Zachary C. Lipton. 2020. Learning to deceive with attention-based explanations. In Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics, pages 4782–4793, Online. Association for Computational Linguistics.
  • Serrano and Smith (2019a) Sofia Serrano and Noah A. Smith. 2019a. Is attention interpretable? In Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics, pages 2931–2951, Florence, Italy. Association for Computational Linguistics.
  • Serrano and Smith (2019b) Sofia Serrano and Noah A Smith. 2019b. Is attention interpretable? arXiv preprint arXiv:1906.03731.
  • Şimşekli et al. (2019) Umut Şimşekli, Levent Sagun, and Mert Gurbuzbalaban. 2019. A tail-index analysis of stochastic gradient noise in deep neural networks. In Proceedings of the 36th International Conference on Machine Learning (ICML 2019).
  • Socher et al. (2013) Richard Socher, Alex Perelygin, Jean Wu, Jason Chuang, Christopher D Manning, Andrew Y Ng, and Christopher Potts. 2013. Recursive deep models for semantic compositionality over a sentiment treebank. In Proceedings of the 2013 conference on empirical methods in natural language processing, pages 1631–1642.
  • Strobelt et al. (2018) Hendrik Strobelt, Sebastian Gehrmann, Michael Behrisch, Adam Perer, Hanspeter Pfister, and Alexander M Rush. 2018. Seq2Seq-Vis: A visual debugging tool for sequence-to-sequence models. IEEE Transactions on Visualization and Computer Graphics, 25(1):353–363.
  • Sun and Lu (2020) Xiaobing Sun and Wei Lu. 2020. Understanding attention for text classification. In Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics, pages 3418–3428, Online. Association for Computational Linguistics.
  • Vashishth et al. (2019) Shikhar Vashishth, Shyam Upadhyay, Gaurav Singh Tomar, and Manaal Faruqui. 2019. Attention interpretability across nlp tasks. arXiv preprint arXiv:1909.11218.
  • Voita et al. (2019) Elena Voita, David Talbot, Fedor Moiseev, Rico Sennrich, and Ivan Titov. 2019. Analyzing multi-head self-attention: Specialized heads do the heavy lifting, the rest can be pruned. In Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics, pages 5797–5808, Florence, Italy. Association for Computational Linguistics.
  • Wiegreffe and Pinter (2019a) Sarah Wiegreffe and Yuval Pinter. 2019a. Attention is not not explanation. In Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP), pages 11–20, Hong Kong, China. Association for Computational Linguistics.
  • Wiegreffe and Pinter (2019b) Sarah Wiegreffe and Yuval Pinter. 2019b. Attention is not not explanation. arXiv preprint arXiv:1908.04626.
  • Zhang et al. (2015) Xiang Zhang, Junbo Zhao, and Yann LeCun. 2015. Character-level convolutional networks for text classification. In Advances in neural information processing systems, pages 649–657.
  • Zhong et al. (2019) Ruiqi Zhong, Steven Shao, and Kathleen McKeown. 2019. Fine-grained sentiment analysis with faithful attention. arXiv preprint arXiv:1908.06870.

Appendix A Appendices

A.1 Heuristic that $\alpha$ Attends to Larger $\beta$

It is a heuristic, rather than a rigorous theorem, that attention $\alpha$ is attracted to larger $\beta$, for two reasons. First, there is a non-linear layer after averaging the hidden states, which can interact in arbitrarily complex ways and break this heuristic. Second, even if there are no non-linear operations after hidden state aggregation, the optimal attention that minimizes the loss does not necessarily assign any probability to the position with the largest $\beta$ value when there are more than two output vocabs.

Specifically, we consider the following model:

p_{t}=\sigma\left(W_{c}\sum_{l=1}^{L}\alpha_{t,l}h_{l}+W_{s}s_{t}\right)=\sigma\left(\sum_{l=1}^{L}\alpha_{t,l}\gamma_{l}+\gamma_{s}\right), (26)

where $W_c$ and $W_s$ are learnable weights, and $\gamma$ is defined as:

\gamma_{l}:=W_{c}h_{l};\quad\gamma_{s}:=W_{s}s_{t}\;\Rightarrow\;\beta_{t,l}=\sigma(\gamma_{l}+\gamma_{s})_{y_{t}}. (27)

Consider the following scenario, which outputs a probability distribution $p$ over 3 output vocabs, with $\gamma_s$ set to 0:

p=\sigma(\alpha_{1}\gamma_{1}+\alpha_{2}\gamma_{2}+\alpha_{3}\gamma_{3}), (28)

where $\gamma_{l=1,2,3}\in\mathbb{R}^{|\mathcal{O}|=3}$ are the logits, $\alpha$ is a valid attention probability distribution, $\sigma$ is the softmax, and $p$ is the probability distribution produced by this model. Suppose

\gamma_{1}=[0,0,0],\quad\gamma_{2}=[0,-10,5],\quad\gamma_{3}=[0,5,-10], (29)

and the correct output is the first output vocab (i.e. the first dimension). Taking the softmax of each $\gamma_l$ and considering the first dimension:

\beta_{l=1}=\frac{1}{3}>\beta_{l=2}=\beta_{l=3}\approx e^{-5}. (30)

We calculate the “optimal $\alpha$” $\alpha^{\mathrm{opt}}$: the attention weights that maximize the correct output word probability $p_0$ and minimize the loss. We find that $\alpha^{\mathrm{opt}}_{2}=\alpha^{\mathrm{opt}}_{3}=0.5$, while $\alpha^{\mathrm{opt}}_{1}=0$. In this example, the optimal attention assigns 0 weight to the position $l$ with the highest $\beta_l$.
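
This counterexample is easy to check numerically; a minimal numpy sketch (names ours):

```python
import numpy as np

gammas = np.array([[0., 0., 0.], [0., -10., 5.], [0., 5., -10.]])

def p_first(alpha):                  # probability of the correct (first) vocab
    z = alpha @ gammas               # combined logits, Equation (28)
    return np.exp(z[0]) / np.exp(z).sum()

print(p_first(np.array([1.0, 0.0, 0.0])))  # attend to the highest beta: 1/3
print(p_first(np.array([0.0, 0.5, 0.5])))  # alpha_opt: about 0.86
```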

Fortunately, such pathological examples rarely occur in real datasets, and the optimal $\alpha$ is usually attracted to positions with higher $\beta$. We empirically verify this for the variant of the machine translation model below on Multi30k.

As before, we obtain the context vector $c_t$. Instead of concatenating $c_t$ and $s_t$ and passing the result into a non-linear neural network $N$, we add them and apply a linear layer followed by a softmax to obtain the output word probability distribution:

p_{t}=\sigma(W(c_{t}+s_{t})). (31)

This model is desirable because we can now provably find the optimal $\alpha$ using gradient descent (we defer the proof to the end of this subsection). Additionally, this model has performance comparable to the variant in our main paper (Section 2.1), achieving a 38.2 BLEU score vs. 37.9 for the main-paper model. We use $\alpha^{\mathrm{opt}}$ to denote the attention that minimizes the loss, and we find that $\mathcal{A}(\alpha^{\mathrm{opt}},\beta)=0.53$: $\beta$ does strongly agree with $\alpha^{\mathrm{opt}}$.

Now we are left to show that we can use gradient descent to find the optimal attention weights that minimize the loss. We can rewrite $p_t$ as

p_{t}=\sigma\left(\sum_{l=1}^{L}\alpha_{l}Wh_{l}+Ws_{t}\right). (32)

We define

\gamma_{l}:=Wh_{l};\quad\gamma_{s}:=Ws_{t}. (33)

Without loss of generality, suppose the first dimensions of $\gamma_{1\dots L}$ and $\gamma_s$ are all 0, and the correct token whose probability we want to maximize is the first dimension; then the loss for the output word is

\mathcal{L}=\log(1+g(\alpha)), (34)

where

g(\alpha):=\sum_{o\in\mathcal{O},o\neq 0}e^{\alpha^{T}\gamma^{\prime}_{o}+\gamma_{s,o}}, (35)

where

\gamma^{\prime}_{o}=[\gamma_{1,o}\dots\gamma_{l,o}\dots\gamma_{L,o}]\in\mathbb{R}^{L}. (36)

Since $\alpha$ is defined within the convex probability simplex and $g(\alpha)$ is convex with respect to $\alpha$, the global optimum $\alpha^{\mathrm{opt}}$ can be found by gradient descent.

A.2 Calculating $\frac{\partial\theta_{i,o}}{\partial\tau}$

We drop the px superscript of $\theta$ to keep the notation uncluttered. We copy the loss function here to remind the reader:

\mathcal{L}=-\sum_{m}\sum_{t=1}^{T^{m}}\log\left(\sigma\left(\frac{1}{L^{m}}\sum_{l=1}^{L^{m}}We_{x^{m}_{l}}\right)_{y^{m}_{t}}\right), (37)

and since we optimize $W$ and $e$ with gradient flow,

\frac{\partial W}{\partial\tau}:=-\frac{\partial\mathcal{L}}{\partial W};\quad\frac{\partial e}{\partial\tau}:=-\frac{\partial\mathcal{L}}{\partial e}. (38)

We first define the unnormalized logits $\hat{\theta}$ (before the softmax):

\hat{\theta}=We, (39)

then

\frac{\partial\hat{\theta}}{\partial\tau}=\frac{\partial(We)}{\partial\tau}=W\frac{\partial e}{\partial\tau}+\frac{\partial W}{\partial\tau}e=-W\frac{\partial\mathcal{L}}{\partial e}-\frac{\partial\mathcal{L}}{\partial W}e. (40)

We first analyze $\epsilon:=-W\frac{\partial\mathcal{L}}{\partial e}$. Since $\epsilon\in\mathbb{R}^{|\mathcal{I}|\times|\mathcal{O}|}$, we analyze each entry $\epsilon_{i,o}$. Since differentiation and left multiplication by the matrix $W$ are linear, we analyze each individual loss term in Equation 37 and then sum them up.

We define

p^{m}:=\sigma\left(\frac{1}{L^{m}}\sum_{l=1}^{L^{m}}We_{x^{m}_{l}}\right), (41)

and

\mathcal{L}^{m}_{t}:=-\log(p^{m}_{y^{m}_{t}});\quad\epsilon^{m}_{t,i,o}:=-W_{o}^{T}\frac{\partial\mathcal{L}^{m}_{t}}{\partial e_{i}}. (42)

Hence,

\mathcal{L}=\sum_{m}\sum_{t=1}^{T^{m}}\mathcal{L}^{m}_{t};\quad\epsilon_{i,o}=\sum_{m}\sum_{t=1}^{T^{m}}\epsilon^{m}_{t,i,o}. (43)

Therefore,

-\frac{\partial\mathcal{L}^{m}_{t}}{\partial e_{i}}=\frac{1}{L^{m}}\sum_{l=1}^{L^{m}}\mathbf{1}[x^{m}_{l}=i]\left(W_{y^{m}_{t}}-\sum_{o=1}^{|\mathcal{O}|}p^{m}_{o}W_{o}\right). (44)

Hence,

\epsilon^{m}_{t,i,y^{m}_{t}}=-W^{T}_{y^{m}_{t}}\frac{\partial\mathcal{L}^{m}_{t}}{\partial e_{i}}=\frac{1}{L^{m}}\sum_{l=1}^{L^{m}}\mathbf{1}[x^{m}_{l}=i]\left(||W_{y^{m}_{t}}||^{2}_{2}-\sum_{o=1}^{|\mathcal{O}|}p^{m}_{o}W_{y^{m}_{t}}^{T}W_{o}\right), (45)

while for $o^{\prime}\neq y^{m}_{t}$,

\epsilon^{m}_{t,i,o^{\prime}}=-W^{T}_{o^{\prime}}\frac{\partial\mathcal{L}^{m}_{t}}{\partial e_{i}}=\frac{1}{L^{m}}\sum_{l=1}^{L^{m}}\mathbf{1}[x^{m}_{l}=i]\left(W_{o^{\prime}}^{T}W_{y^{m}_{t}}-\sum_{o=1}^{|\mathcal{O}|}p^{m}_{o}W_{o^{\prime}}^{T}W_{o}\right). (46)

If $W_o$ and $e_i$ are each sampled i.i.d. from $\mathcal{N}(0,I_d/d)$, then by the central limit theorem:

\forall o\neq o^{\prime},\quad\sqrt{d}\,W_{o}^{T}W_{o^{\prime}}\overset{p}{\to}\mathcal{N}(0,1), (47)
\forall o,i,\quad\sqrt{d}\,W_{o}^{T}e_{i}\overset{p}{\to}\mathcal{N}(0,1), (48)

and

\forall o,\quad\sqrt{d}\,(||W_{o}||_{2}^{2}-1)\overset{p}{\to}\mathcal{N}(0,2). (49)

Therefore, when $\tau=0$,

\lim_{d\rightarrow\infty}\epsilon^{m}_{t,i,o}\overset{p}{\to}\frac{1}{L^{m}}\sum_{l=1}^{L^{m}}\mathbf{1}[x^{m}_{l}=i]\left(\mathbf{1}[y^{m}_{t}=o]-\frac{1}{|\mathcal{O}|}\right). (50)

Summing over all the $\epsilon^{m}_{t,i,o}$ terms, we have

\epsilon_{i,o}=C_{i,o}-\frac{1}{|\mathcal{O}|}\sum_{o^{\prime}}C_{i,o^{\prime}}, (51)

where $C$ is defined as

C_{i,o}:=\sum_{m}\sum^{L^{m}}_{l=1}\sum^{T^{m}}_{t=1}\frac{1}{L^{m}}\mathbf{1}[x^{m}_{l}=i]\,\mathbf{1}[y^{m}_{t}=o]. (52)

We find that the second term, $-\frac{\partial\mathcal{L}}{\partial W}e$, converges to exactly the same value. Hence

\frac{\partial\hat{\theta}_{i,o}}{\partial\tau}=2\left(C_{i,o}-\frac{1}{|\mathcal{O}|}\sum_{o^{\prime}}C_{i,o^{\prime}}\right). (53)

Since $\lim_{d\rightarrow\infty}\theta(\tau=0)\overset{p}{\to}\frac{1}{|\mathcal{O}|}\mathbf{1}^{|\mathcal{I}|\times|\mathcal{O}|}$, by the chain rule,

\lim_{d\rightarrow\infty}\frac{\partial\theta_{i,o}}{\partial\tau}(\tau=0)\overset{p}{\to}2\left(C_{i,o}-\frac{1}{|\mathcal{O}|}\sum_{o^{\prime}\in\mathcal{O}}C_{i,o^{\prime}}\right). (54)

A.3 Mixture of Permutations

For this experiment, each input is either a random permutation of the set $\{1\dots 40\}$ or a random permutation of the set $\{41\dots 80\}$. The proxy model can easily learn whether the input words are all at most 40, and hence whether the output words are all at most 40. However, $\beta^{\mathrm{px}}$ is still the same for every position; as a result, the attention, and hence the model, fails to learn. The count table $C$ can be seen in Figure 8.

Figure 8: The training distribution mixes random permutations of two disjoint sets of words (left and right, respectively). From the count table, $\beta^{\mathrm{px}}$ could learn that the set of input words $\{A,B,C,D\}$ corresponds to the set of output words $\{A^{\prime},B^{\prime},C^{\prime},D^{\prime}\}$, but its $\beta$ value for each input position is still uniformly 0.25.

A.4 Additional Tables for Completeness

We report several variants of Table 2. We chose token accuracy to contextualize the agreement metric in the main paper because errors would accumulate much more if we used a not-fully-trained model to autoregressively generate output words.

  • Table 3 contains the same results as Table 2, except that its agreement score $\mathcal{A}(u,v)$ is the Kendall's $\tau$ rank correlation coefficient, which is a more popular metric.

  • Table 5 contains the same results as Table 2, except that results are now rounded to two decimal places.

  • Table 7 consists of the same results as Table 2, except that the statistics are calculated over the training set rather than the validation set.

  • Tables 4, 6, and 8 contain the translation results from the three tables above, respectively, except that $\hat{p}$ is defined as the BLEU score rather than token accuracy; the contextualized metric $\xi$ changes correspondingly.

Task $\mathcal{A}(\alpha,\beta^{\mathrm{uf}})$ $\mathcal{A}(\beta^{\mathrm{uf}},\beta^{\mathrm{px}})$ $\mathcal{A}(\Delta,\beta^{\mathrm{uf}})$ $\hat{\mathcal{A}}$
IMDB 12.77 33.31 12.56 0.00
Yelp 20.38 36.75 20.98 0.00
AG News 26.31 36.65 20.55 0.00
20 NG 16.06 22.03 06.50 0.00
SST 11.68 31.43 15.01 0.00
Amzn 15.21 35.84 09.33 0.00
Multi30k 07.89 27.54 03.93 0.00
IWSLT14 08.64 22.56 02.72 0.00
News It-Pt 04.82 17.16 01.63 0.00
News En-Nl 04.53 20.35 02.08 0.00
News En-Pt 04.65 18.20 02.17 0.00
Task $\xi(\alpha,\beta^{\mathrm{uf}})$ $\xi(\beta^{\mathrm{uf}},\beta^{\mathrm{px}})$ $\xi(\Delta,\beta^{\mathrm{uf}})$ $\xi^{*}$
IMDB 70.60 80.50* 70.60 89.55
Yelp 87.44 93.20* 87.44 96.20
AG News 89.31 93.54* 85.85 96.05
20 NG 60.75 60.75* 60.75 94.22
SST 76.69 83.53 76.69 83.53
Amzn 69.91 88.07* 57.36 90.38
Multi30k 22.94 36.61* 22.94 66.29
IWSLT14 29.07 32.98* 29.07 67.36
News It-Pt 14.01 18.25* 08.10 55.41
News En-Nl 09.60 18.59* 09.60 62.90
News En-Pt 14.10 14.10* 07.71 67.75
Table 3: Same as Table 2, except that agreement is defined by Kendall's $\tau$ (Section A.4).
Task $\mathcal{A}(\alpha,\beta^{\mathrm{uf}})$ $\mathcal{A}(\beta^{\mathrm{uf}},\beta^{\mathrm{px}})$ $\mathcal{A}(\Delta,\beta^{\mathrm{uf}})$ $\hat{\mathcal{A}}$
Multi30k 8.68 27.54 4.24 0.00
IWSLT14 8.64 22.56 2.72 0.00
News It-Pt 4.82 17.16 1.63 0.00
News En-Nl 4.53 20.35 2.08 0.00
News En-Pt 4.41 18.20 2.05 0.00
Task $\xi(\alpha,\beta^{\mathrm{uf}})$ $\xi(\beta^{\mathrm{uf}},\beta^{\mathrm{px}})$ $\xi(\Delta,\beta^{\mathrm{uf}})$ $\xi^{*}$
Multi30k 1.99 6.91 1.99 37.89
IWSLT14 5.38 5.31 5.38 32.95
News It-Pt 0.09 0.55 0.04 24.71
News En-Nl 0.01 0.94 0.01 29.42
News En-Pt 0.01 0.22 0.01 37.04
Table 4: Translation results from Table 3, except that performance is measured by BLEU rather than token accuracy (Section A.4).
Task $\mathcal{A}(\alpha,\beta^{\mathrm{uf}})$ $\mathcal{A}(\beta^{\mathrm{uf}},\beta^{\mathrm{px}})$ $\mathcal{A}(\Delta,\beta^{\mathrm{uf}})$ $\mathcal{A}(\alpha,\beta)$ $\hat{\mathcal{A}}$
IMDB 52.55 81.60 61.55 60.35 5.30
Yelp 17.55 75.38 58.90 35.00 5.80
AG News 39.24 55.13 43.08 48.13 6.20
20 NG 65.08 41.33 64.82 63.07 5.11
SST 19.85 33.57 22.45 25.33 8.39
Amzn 52.02 76.78 49.68 62.13 5.80
Multi30k 31.02 34.43 27.06 48.78 7.11
IWSLT14 35.75 39.09 27.69 55.25 6.52
News It-Pt 29.13 38.62 25.45 52.48 6.17
News En-Nl 35.53 41.72 29.15 60.15 6.36
News En-Pt 35.90 37.37 30.23 65.49 6.34
Task $\xi(\alpha,\beta^{\mathrm{uf}})$ $\xi(\beta^{\mathrm{uf}},\beta^{\mathrm{px}})$ $\xi(\Delta,\beta^{\mathrm{uf}})$ $\xi(\alpha,\beta)$ $\xi^{*}$
IMDB 86.81 88.88* 86.81 86.81 89.55
Yelp 90.39 95.22* 95.31 93.59 96.20
AG News 93.54 96.05 94.32 94.50 96.05
20 NG 91.16 60.75* 84.57 84.57 94.22
SST 78.16 83.53 78.16 82.38 83.53
Amzn 82.48 90.38 82.48 88.07 90.38
Muti30k 43.45 43.45* 43.45 48.58 66.29
IWSLT14 35.82 35.82* 32.98 44.09 67.36
News It-Pt 21.82 25.06* 21.82 25.06 55.41
News En-Nl 18.59 23.21* 18.59 26.79 62.90
News En-Pt 19.12 19.12* 19.12 27.85 67.75
Table 5: Table 2 with 2 decimal results. Section A.4
Task \mathcal{A}(\alpha,\beta^{\mathrm{uf}}) \mathcal{A}(\beta^{\mathrm{uf}},\beta^{\mathrm{px}}) \mathcal{A}(\Delta,\beta^{\mathrm{uf}}) \mathcal{A}(\alpha,\beta) \hat{\mathcal{A}}
Multi30k 30.77 34.43 27.24 48.70 7.19
IWSLT14 35.75 39.09 27.69 55.25 6.52
News It-Pt 29.13 38.62 25.45 52.48 6.17
News En-Nl 35.53 41.72 29.15 60.15 6.35
News En-Pt 35.77 37.37 30.37 64.94 6.34
Task \xi(\alpha,\beta^{\mathrm{uf}}) \xi(\beta^{\mathrm{uf}},\beta^{\mathrm{px}}) \xi(\Delta,\beta^{\mathrm{uf}}) \xi(\alpha,\beta) \xi^{*}
Multi30k 11.43 11.43 11.43 16.41 37.89
IWSLT14 6.71 6.71 5.31 9.89 32.95
News It-Pt 1.29 2.16 1.29 2.16 24.71
News En-Nl 0.94 2.39 0.94 4.12 29.42
News En-Pt 0.74 0.74 0.74 4.28 37.04
Table 6: Translation results from Table 5, except with performance measured by BLEU rather than token accuracy. Section A.4
Task \mathcal{A}(\alpha,\beta^{\mathrm{uf}}) \mathcal{A}(\beta^{\mathrm{uf}},\beta^{\mathrm{px}}) \mathcal{A}(\Delta,\beta^{\mathrm{uf}}) \mathcal{A}(\alpha,\beta) \hat{\mathcal{A}}
IMDB 51.52 80.10 42.85 64.88 5.29
Yelp 11.15 76.12 55.50 37.63 5.85
AG News 36.97 53.95 43.11 46.89 6.17
20 NG 72.36 38.69 71.73 69.47 5.32
SST 21.82 29.35 20.48 28.50 8.48
Amzn 51.95 77.18 40.15 61.78 5.91
Multi30k 32.89 34.67 28.36 56.39 7.21
IWSLT14 36.61 38.95 28.37 57.71 6.52
News It-Pt 31.03 38.70 27.11 64.81 6.15
News En-Nl 37.86 41.91 31.11 67.68 6.39
News En-Pt 37.43 37.23 31.76 71.96 6.35
Task \xi(\alpha,\beta^{\mathrm{uf}}) \xi(\beta^{\mathrm{uf}},\beta^{\mathrm{px}}) \xi(\Delta,\beta^{\mathrm{uf}}) \xi(\alpha,\beta) \xi^{*}
IMDB 90.40 99.95 90.40 95.01 99.95
Yelp 75.61 96.54* 96.19 94.44 98.22
AG News 93.57 98.42 94.63 95.54 98.42
20 NG 100.00 65.40 100.00 100.00 100.00
SST 97.72 100.00 84.11 100.00 100.00
Amzn 87.96 99.58 80.98 91.09 99.58
Multi30k 43.27 43.27* 43.27 51.97 80.76
IWSLT14 35.94 35.94* 35.94 44.18 71.18
News It-Pt 22.69 25.96* 22.69 39.98 77.10
News En-Nl 18.85 23.56* 18.85 40.09 74.49
News En-Pt 19.33 19.33* 19.33 42.41 77.97
Table 7: Table 2 except with correlations and performance metrics taken over the training set instead of the validation set. Section A.4
Task \mathcal{A}(\alpha,\beta^{\mathrm{uf}}) \mathcal{A}(\beta^{\mathrm{uf}},\beta^{\mathrm{px}}) \mathcal{A}(\Delta,\beta^{\mathrm{uf}}) \hat{\mathcal{A}}
Multi30k 32.89 34.67 28.36 7.16
IWSLT14 36.61 38.95 28.37 6.54
News It-Pt 31.03 38.70 27.11 6.17
News En-Nl 37.86 41.91 31.11 6.38
News En-Pt 37.43 37.23 31.76 6.37
Task \xi(\alpha,\beta^{\mathrm{uf}}) \xi(\beta^{\mathrm{uf}},\beta^{\mathrm{px}}) \xi(\Delta,\beta^{\mathrm{uf}}) \xi^{*}
Multi30k 11.87 11.87 11.87 52.28
IWSLT14 6.82 6.82 6.82 36.23
News It-Pt 1.30 2.30 1.30 42.40
News En-Nl 1.11 2.29 1.11 39.40
News En-Pt 0.83 0.83 0.83 46.57
Table 8: Translation results from Table 7, except with performance measured by BLEU rather than token accuracy. Section A.4

A.5 Dataset Description

We summarize the datasets that we use for classification and machine translation. See Table 9 for details on train/test splits and median sequence lengths for each dataset.

IMDB Sentiment Analysis Maas et al. (2011) A sentiment analysis dataset with 50,000 IMDB movie reviews (25,000 train and 25,000 test), each labeled with positive or negative sentiment.

AG News Corpus Zhang et al. (2015) 120,000 news articles and their corresponding topic (world, sports, business, or science/tech). We classify between the world and business articles.

20 Newsgroups (http://qwone.com/~jason/20Newsgroups/) A news dataset containing around 18,000 newsgroup articles split among 20 labeled categories. We classify between baseball and hockey articles.

Stanford Sentiment Treebank Socher et al. (2013) A dataset for classifying the sentiment of movie reviews, labeled on a scale from 1 (negative) to 5 (positive). We remove all reviews labeled 3, and classify between labels 4 or 5 and labels 1 or 2.

Multi Domain Sentiment Dataset (https://www.cs.jhu.edu/~mdredze/datasets/sentiment/) Approximately 40,000 Amazon reviews from various product categories, each labeled positive or negative. Since some of the sequences are particularly long, we only use sequences of fewer than 400 words.

Yelp Open Dataset (https://www.yelp.com/dataset) 20,000 Yelp reviews and their corresponding star ratings from 1 to 5. We classify between reviews with rating \leq 2 and \geq 4.

Multi-30k Elliott et al. (2016) English to German translation. The data consists of translated image captions.

IWSLT’14 Cettolo et al. (2015) German to English translation. The data is from translated TED talk transcriptions.

News Commentary v14 Cettolo et al. (2015) A collection of news commentary translation datasets in different languages from WMT19 (http://www.statmt.org/wmt19/translation-task.html). We use the following translation splits: English-Dutch (En-Nl), English-Portuguese (En-Pt), and Italian-Portuguese (It-Pt). In pre-processing for this dataset, we removed all purely numerical examples.

Data median train seq len train #
IMDB 181 25000
AG News 40 60000
NewsG 183 1197
SST 16 5130
Amzn 71 32514
Yelp 74 88821
IWSLT14 (23 src, 24 trg) 160240
Multi-30k (14 src, 14 trg) 29000
News-en-nl (30 src, 34 trg) 52070
News-en-pt (31 src, 35 trg) 48538
News-it-pt (36 src, 35 trg) 21572
Data median val seq len val #
IMDB 178 4000
AG News 40 3800
NewsG 207 796
SST 17 1421
Amzn 72 4000
Yelp 74 4000
IWSLT14 (22 src, 23 trg) 7284
Multi-30k (15 src, 14 trg) 1014
News-en-nl (30 src, 34 trg) 5786
News-en-pt (31 src, 35 trg) 5394
News-it-pt (36 src, 36 trg) 2397
Data vocab size
IMDB 60338
AG News 31065
NewsG 31065
SST 11022
Amzn 37110
Yelp 41368
IWSLT14 8000
Multi-30k 8000
News-en-nl 8000
News-en-pt 8000
News-it-pt 8000
Table 9: Statistics for each dataset: median sequence length and number of examples for the training and validation sets, and vocabulary size. Note: src refers to the input "source" sequence, and trg refers to the output "target" sequence. Section A.5

A.6 \alpha Fails When \beta is Frozen

For each classification task, we initialize a random model and freeze all parameters except those of the attention layer (the frozen-\beta model). We then compute the correlation between this trained attention (denoted \alpha^{\mathrm{fr}}) and the normal attention \alpha. Table 10 reports this correlation at the iteration where \alpha^{\mathrm{fr}} is most correlated with \alpha on the validation set. As shown in Table 10, the left column is consistently lower than the right column, indicating that the model can learn output relevance without attention, but not vice versa.

Dataset \mathcal{A}(\alpha,\alpha^{\mathrm{fr}}) \mathcal{A}(\alpha,\beta^{\mathrm{uf}})
IMDB 9 53
AG News 17 39
20 NG 19 65
SST 14 20
Amzn 15 52
Yelp 8 18
Table 10: We report the correlation between \alpha^{\mathrm{fr}} and \alpha on classification datasets, and compare it against \mathcal{A}(\alpha,\beta^{\mathrm{uf}}), the same column defined in Table 2. Section A.6
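For reference, a minimal PyTorch sketch of this frozen-\beta setup: everything is frozen except the attention layer, and only the trainable parameters are handed to the optimizer. The submodule name "attention" is a placeholder for the actual module name in the model.

    import torch

    def freeze_all_but_attention(model, attn_name="attention"):
        # Freeze every parameter, then unfreeze only the attention submodule
        # (`attn_name` is a placeholder for the actual module name).
        for p in model.parameters():
            p.requires_grad = False
        for p in getattr(model, attn_name).parameters():
            p.requires_grad = True
        # Optimize only the parameters that remain trainable.
        return torch.optim.Adam(p for p in model.parameters() if p.requires_grad)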

A.7 Training \beta^{\mathrm{uf}}

We find that \mathcal{A}(\alpha,\beta^{\mathrm{uf}}(\tau)) first increases and then decreases as training proceeds (i.e., as \tau increases), so we report the maximum agreement over the course of training in Table 2. Since this trend is consistent across all datasets, our choice minimally inflates the agreement measure, and is comparable to the practice of reporting dev set results. As discussed in Section 6.1, training under uniform attention for too long might bring unintuitive results.
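Concretely, the reported value is just the peak of the agreement curve over evaluation checkpoints; a minimal sketch, where agreement_at is a hypothetical helper evaluating \mathcal{A}(\alpha,\beta^{\mathrm{uf}}(\tau)) on the validation set:

    def peak_agreement(checkpoints, agreement_at):
        # `checkpoints` are the evaluation steps tau; `agreement_at(tau)` is
        # assumed to compute the validation-set agreement at that checkpoint.
        return max(agreement_at(tau) for tau in checkpoints)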

A.8 Model and Training Details

Classification

Our model uses 300-dimensional GloVe-6B pre-trained embeddings to initialize the token embeddings wherever they align with our vocabulary. The sequences are encoded with a 1-layer bidirectional LSTM of dimension 256. The rest of the model, including the attention mechanism, is exactly as described in Section 2.4. Our model has 1,274,882 parameters excluding embeddings. Since each classification dataset has a different vocab size, each model has a slightly different parameter count when embeddings are included: 19,376,282 for IMDB, 10,594,382 for AG News, 5,021,282 for 20 Newsgroups, 4,581,482 for SST, 13,685,282 for Yelp, 12,407,882 for Amazon, and 2,682,182 for SMS.
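A condensed PyTorch sketch of this classifier under the stated sizes; the attention scorer here is a simplified stand-in, and the exact parameterization follows Section 2.4.

    import torch
    import torch.nn as nn

    class AttnClassifier(nn.Module):
        # Sketch: GloVe-initialized embeddings, a 1-layer bi-LSTM encoder,
        # a (simplified) attention scorer, and a linear output layer.
        def __init__(self, vocab_size, num_classes, emb_dim=300, hid=256):
            super().__init__()
            self.emb = nn.Embedding(vocab_size, emb_dim)   # init from GloVe where aligned
            self.lstm = nn.LSTM(emb_dim, hid, bidirectional=True, batch_first=True)
            self.attn = nn.Linear(2 * hid, 1)              # illustrative scorer
            self.out = nn.Linear(2 * hid, num_classes)

        def forward(self, x):                              # x: (batch, seq_len)
            h, _ = self.lstm(self.emb(x))                  # (batch, seq_len, 2 * hid)
            alpha = torch.softmax(self.attn(h).squeeze(-1), dim=-1)
            ctx = (alpha.unsqueeze(-1) * h).sum(dim=1)     # attention-weighted sum
            return self.out(ctx)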

Translation

We use a two-layer bidirectional LSTM of dimension 256 to encode the source, and use the last hidden state h_{L} as the first hidden state of the decoder. The attention and outputs are then calculated as described in Section 2. The learnable neural network before the outputs, mentioned in Section 2, is a one-hidden-layer model with ReLU non-linearity; the hidden layer has dimension 256. Our model contains 6,132,544 parameters excluding embeddings and 8,180,544 including embeddings on all datasets.
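A sketch of a single decoder step under these sizes, assuming dot-product attention scoring and encoder states already combined to 256 dimensions; the class and argument names are illustrative.

    import torch
    import torch.nn as nn

    class DecoderStep(nn.Module):
        # Sketch of one decoder step: dot-product attention over encoder
        # states, then the one-hidden-layer ReLU network producing logits.
        def __init__(self, vocab_size, hid=256):
            super().__init__()
            self.out = nn.Sequential(
                nn.Linear(2 * hid, hid), nn.ReLU(), nn.Linear(hid, vocab_size))

        def forward(self, dec_h, enc_h):
            # dec_h: (batch, hid); enc_h: (batch, src_len, hid)
            scores = torch.bmm(enc_h, dec_h.unsqueeze(-1)).squeeze(-1)
            alpha = torch.softmax(scores, dim=-1)            # attention weights
            ctx = (alpha.unsqueeze(-1) * enc_h).sum(dim=1)   # context vector
            return self.out(torch.cat([ctx, dec_h], dim=-1))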

Permutation Copying

We use a unidirectional, single-layer LSTM with hidden dimension 256 for both the encoder and the decoder.

Classification Procedure

For all classification datasets we used a batch size of 32 and trained for 4000 iterations. For each dataset, we train on the pre-defined training set if the dataset has one. Additionally, if a dataset has a pre-defined test set, we randomly sample at most 4000 examples from it for validation. Specific dataset split sizes are given in Table 9.

Classification Evaluation

We evaluated each model at steps 0, 10, 50, 100, 150, 200, 250, and then every 250 iterations after that.

Classification Tokenization

We tokenized the data at the word level and mapped all words occurring fewer than 3 times in the training set to <unk>. For 20 Newsgroups and AG News, we additionally mapped all multi-digit integer "words" to <unk>. For 20 Newsgroups we also split words on the "_" character.
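A minimal sketch of this vocabulary construction, assuming whitespace-tokenized training sentences; train_sentences is a placeholder for the training corpus.

    from collections import Counter

    def build_vocab(train_sentences, min_count=3):
        # Count word-level tokens and keep those appearing at least min_count times.
        counts = Counter(w for sent in train_sentences for w in sent.split())
        return {w for w, c in counts.items() if c >= min_count}

    def tokenize(sent, vocab):
        # Map out-of-vocabulary words to <unk>.
        return [w if w in vocab else "<unk>" for w in sent.split()]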

Classification Training

We trained all classification models on a single GPU. Some datasets took slightly longer to train than others (largely depending on average sequence length), but each training run took at most 45 minutes.

Translation Hyper Parameters

For translation, all hidden states in the model have dimension 256. We use the sequence-to-sequence architecture described above, and the LSTMs use dropout 0.5.

Translation Procedure

For all translation tasks we used a batch size of 16 during training. For IWSLT'14 and Multi-30k we used the provided dataset splits. For the News Commentary v14 datasets we split the data 90-10 into training and validation sets.

Translation Evaluation

We evaluated each model at steps 0, 50, 100, 500, 1000, 1500, and then every 2000 iterations after that.

Translation Training

We trained all translation models on a single GPU. IWSLT'14 and the News Commentary datasets took approximately 5-6 hours to train, while Multi-30k took closer to 1 hour.

Translation Tokenization

We tokenized all translation datasets using a SentencePiece tokenizer trained on the corresponding training set with a vocab size of 8,000. We used a single tokenization for source and target tokens, and accordingly also used the same embedding matrix for source and target sequences.
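A sketch of this setup with the sentencepiece library; the file names are placeholders, and the input file is assumed to contain the concatenated source and target training text so that one shared vocabulary is learned.

    import sentencepiece as spm

    # Train one shared tokenizer on the concatenated source + target text.
    # "train.all.txt" and "joint" are placeholder names.
    spm.SentencePieceTrainer.train(
        input="train.all.txt", model_prefix="joint", vocab_size=8000)

    sp = spm.SentencePieceProcessor(model_file="joint.model")
    ids = sp.encode("ein kleines Beispiel", out_type=int)  # subword ids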

A.9 A Note On SMS Dataset

In addition to the classification datasets reported in the tables, we also ran experiments on the SMS Spam Collection V.1 dataset (http://www.dt.fee.unicamp.br/~tiago/smsspamcollection/). The attention learned on this dataset was very high variance: two different random seeds would consistently produce attentions that did not correlate much with each other. The dataset itself was also a bit of an outlier: it had shorter sequences than any other dataset (median sequence length 13 on both the train and validation sets), the smallest training set (3500 examples), and by far the smallest vocab (4691 unique tokens). We decided not to include this dataset in the main paper due to these unusual results and leave further exploration to future work.

A.10 Logistic Regression Proxy Model

Our proxy model can be shown to be equivalent to a bag-of-words logistic regression model in the classification case. Specifically, we define a bag-of-words logistic regression model to be:

\forall t,\; p_{t}=\sigma(\beta^{\mathrm{log}}x).  (55)

where x\in\mathbb{R}^{|\mathcal{I}|}, \beta^{\mathrm{log}}\in\mathbb{R}^{|\mathcal{O}|\times|\mathcal{I}|}, and \sigma is the softmax function. The entries of x are the number of times each word occurs in the input sequence, normalized by the sequence length, and \beta^{\mathrm{log}} is learned. This is equivalent to:

\forall t,\; p_{t}=\sigma\left(\frac{1}{L}\sum_{l=1}^{L}\beta^{\mathrm{log}}_{x_{l}}\right).  (56)

Here \beta^{\mathrm{log}}_{i} denotes the i-th column of \beta^{\mathrm{log}}; these are the entries of \beta^{\mathrm{log}} corresponding to predictions for the i-th word in the vocab. Now it is easy to arrive at the equivalence between logistic regression and our proxy model. If we restrict the rank of \beta^{\mathrm{log}} to be at most \min(d,|\mathcal{O}|,|\mathcal{I}|) by factoring it as \beta^{\mathrm{log}}=WE, where W\in\mathbb{R}^{|\mathcal{O}|\times d} and E\in\mathbb{R}^{d\times|\mathcal{I}|}, then the logistic regression looks like:

\forall t,\; p_{t}=\sigma\left(\frac{1}{L}\sum_{l=1}^{L}WE_{x_{l}}\right),  (57)

which is equivalent to our proxy model:

\forall t,\; p_{t}=\sigma\left(\frac{1}{L}\sum_{l=1}^{L}We_{x_{l}}\right).  (58)

Since d=256 for the proxy model, which is larger than |\mathcal{O}|=2 in the classification case, the proxy model is not rank-limited and is hence fully equivalent to the logistic regression model. Therefore \beta^{\mathrm{px}} can be interpreted as "keywords" in the same way that the logistic regression weights can.
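The equivalence between (57) and (58) is just the identity (WE)_{x_l} = W E_{x_l} inside the softmax; a small numeric check with arbitrary random parameters (the dimensions below are illustrative):

    import numpy as np

    def softmax(z):
        e = np.exp(z - z.max())
        return e / e.sum()

    rng = np.random.default_rng(0)
    n_out, d, n_vocab, L = 2, 256, 1000, 7   # |O|, d, |I|, sequence length
    W = rng.normal(size=(n_out, d))
    E = rng.normal(size=(d, n_vocab))
    x = rng.integers(0, n_vocab, size=L)     # input word ids

    # Eq. (57): logistic regression with the factored weights beta_log = W @ E.
    beta_log = W @ E
    p_lr = softmax(beta_log[:, x].mean(axis=1))

    # Eq. (58): proxy model, averaging word embeddings before applying W.
    p_px = softmax(W @ E[:, x].mean(axis=1))

    assert np.allclose(p_lr, p_px)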

To empirically verify this equivalence, we trained a logistic regression model with \ell_{2} regularization on each of our classification datasets. To pick the optimal regularization level, we swept regularization coefficients across ten orders of magnitude and picked the one with the best validation accuracy. We report \mathcal{A}(\beta^{\mathrm{uf}},\beta^{\mathrm{log}}) in comparison to \mathcal{A}(\beta^{\mathrm{uf}},\beta^{\mathrm{px}}) in Table 11. (These numbers were obtained from a retrain of all the models in the main table, so, for instance, the LSTM model used to produce \beta^{\mathrm{uf}} might not be exactly the same as the one used for the results in all the other tables due to random seed differences.)

Note that these numbers are similar but not exactly equal; the reason is that the proxy model did not use \ell_{2} regularization, while the logistic regression did.
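The regularization sweep described above can be reproduced with scikit-learn; a sketch, where X_train, y_train, X_val, and y_val are placeholder bag-of-words features and labels:

    import numpy as np
    from sklearn.linear_model import LogisticRegression

    # Sweep l2 regularization over ten orders of magnitude, keeping the
    # model with the best validation accuracy. X_*/y_* are placeholders.
    best_acc, best_model = -1.0, None
    for C in 10.0 ** np.arange(-5, 5):   # C is the inverse regularization strength
        clf = LogisticRegression(C=C, max_iter=1000).fit(X_train, y_train)
        acc = clf.score(X_val, y_val)
        if acc > best_acc:
            best_acc, best_model = acc, clf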

Task \mathcal{A}(\beta^{\mathrm{uf}},\beta^{\mathrm{px}}) \mathcal{A}(\beta^{\mathrm{uf}},\beta^{\mathrm{log}})
IMDB 0.81 0.84
Yelp 0.74 0.76
AG News 0.57 0.58
20 NG 0.40 0.45
SST 0.39 0.46
Amzn 0.53 0.60
Table 11: We report \mathcal{A}(\beta^{\mathrm{uf}},\beta^{\mathrm{log}}) to demonstrate its effective equivalence to \mathcal{A}(\beta^{\mathrm{uf}},\beta^{\mathrm{px}}). These values are not exactly the same due to differences in regularization strategies.