
Where’s Waldo: Identifying layers to interpret in BERT for RCQA

Abstract

BERT and its variants have achieved state-of-the-art performance on multiple NLP tasks. Since then, various works have been proposed to analyse the linguistic information captured in BERT. However, current works do not provide insight into how BERT achieves near human-level performance on the task of Reading Comprehension based Question Answering (RCQA), owing to the complexity of both BERT (110M parameters across 12 layers) and the QA task. As a first step, (i) we define a layer’s role or functionality using Integrated Gradients, (ii) we identify layers that have distinctive functionalities compared to other layers, and (iii) we perform a preliminary analysis of these distinctive layers. We observe that, on the basis of Integrated Gradients, layers 0-6 mainly focus on passage-query interactions, while later layers seem to verify the answer prediction.

1 Introduction

The past decade has seen a surge in the development of deep neural network models to solve NLP tasks. Pretrained language models such as ELMo (elmo), BERT (Devlin:18), and XLNet (xlnet) have achieved state-of-the-art results on various NLP tasks. For tasks such as QA, earlier models were designed with a fixed purpose for each layer, such as modelling query-passage interaction. However, BERT does not have a predefined purpose for its individual layers. With the rapid increase in the introduction of such models, it becomes necessary to analyse them at the layer level in order to interpret the model.
Earlier works  (tenney2019bert; peters2018dissecting) analyse syntactic and semantic purposes for each layer in the model. Clark:19 specifically analyses BERT’s attention heads for syntactic and linguistic phenomena. Most works on the analysis of such models work with tasks such as sentiment classification, syntactic/semantic tags prediction, NLI etc.
However, there has not been much work on analyzing models for complex tasks like RCQA. We believe that this is due to the sheer number of parameters and the non-linearity in deep QA models. Therefore, in this work, we analyse BERT’s layers on the RCQA dataset SQuAD (Rajpurkar:16) and take the first step in identifying logical functions for each layer, leading to the identification of specific layers in the model which are to be further interpreted from a human point of view.
We first define 3 logical purposes a layer could have and classify BERT’s layers into them:

  • distinctive: a distinct part in the model’s overall logic (could be one of predefined purposes as in earlier QA models, but not necessarily)

  • performance booster: performing mathematical adjustments in its high dimensions that are essential to the model’s performance but are not perceivable by the human eye.

  • verification: re-affirming/verifying the work done by the previous layers

The first 2 purposes describe layers that are essential, and the third refers to layers that are redundant for the model’s performance and interpretability. In Section 4 we mathematically define a layer’s functionality as the way it distributes its attribution over its input items (in this case, passage words) using Integrated Gradients (Sundararajan:17). In Section 5 we provide comparison experiments and experiments on pruning BERT’s layers; together, these help to classify BERT’s layers into the 3 defined logical purposes. Finally, we provide a qualitative study on identifying the roles of distinctive versus other layers.

2 Related Work

The past few years have seen many works on RCQA datasets (Lai:17; Nguyen:16; Joshi:17) and subsequent deep QA models (Seo:16; dhingra2016gated; Yu:18) to solve them. Numerous BERT-based models (Liu:19; Lan:19; Devlin:18) have neared human-level performance on such datasets. To study the interpretability of such models, various attribution methods (Bach:15; Sundararajan:17; Ribeiro:16) have been proposed. Works such as (tenney2019bert; peters2018dissecting; Clark:19) focus on analyzing model layers and assigning them syntactic and semantic meaning using probing classifiers. si2019does analyzes BERT using adversarial attacks similar to earlier works (Jia:17; Mudrakarta:18) and shows that these models are likely to be fooled.

3 Implementation Details

SQuAD: 90k/10k train/dev samples, each with a 100-300 word passage, a natural language query, and an answer span in the passage itself.

Integrated Gradients: The integrated gradients for a model $M$ and a passage word $w_i$, embedded as $x_i \in \mathbf{R}^{L}$, are as follows:

$$IG(x_i) = \int_{\alpha=0}^{1} \frac{\partial M(\tilde{x}+\alpha(x_i-\tilde{x}))}{\partial x_i}\, d\alpha$$

where $\tilde{x}$ is a zero vector that serves as a baseline to measure the integrated gradient for $w_i$. We approximate the above integral across $50$ uniform samples in $[0,1]$.
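The approximation of the integral by uniform samples can be sketched as follows. This is a minimal illustration in plain Python, not the authors' code: the toy model $M(x)=\sum_j x_j^2$ with its analytic gradient is a hypothetical stand-in for BERT, and sampling at $\alpha = k/m$ for $k=1,\dots,m$ is one assumed choice of uniform samples. The second function applies the scaling by $(x_i-\tilde{x})$ used in the standard formulation of Integrated Gradients and in step 3 of Algorithm 1.

```python
def toy_model_grad(x):
    """Gradient of the hypothetical toy model M(x) = sum(x_j ** 2)."""
    return [2.0 * v for v in x]


def averaged_gradient(x, grad_fn, baseline=None, steps=50):
    """Riemann-sum approximation of the integral of dM/dx along the
    straight path from the baseline to x, using `steps` uniform samples."""
    if baseline is None:
        baseline = [0.0] * len(x)  # zero-vector baseline, as in the text
    total = [0.0] * len(x)
    for k in range(1, steps + 1):
        alpha = k / steps
        point = [b + alpha * (v - b) for v, b in zip(x, baseline)]
        grad = grad_fn(point)
        total = [t + g for t, g in zip(total, grad)]
    return [t / steps for t in total]


def integrated_gradients(x, grad_fn, baseline=None, steps=50):
    """Scale the averaged gradient element-wise by (x - baseline)."""
    if baseline is None:
        baseline = [0.0] * len(x)
    avg = averaged_gradient(x, grad_fn, baseline, steps)
    return [(v - b) * g for v, b, g in zip(x, baseline, avg)]
```

For the toy model, the attributions sum to approximately $M(x) - M(\tilde{x})$; the gap shrinks as the number of samples grows.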

BERT: In this work, we use the BERT-BASE model, which has 12 Transformer blocks (layers), each with a multi-head self-attention and a feed-forward neural network. We use the official code and pre-trained checkpoints (https://github.com/google-research/bert) and fine-tune it to get an F1 score of $88.73$ on SQuAD’s dev split.

4 Layer-wise Integrated Gradients

In this section, we extend integrated gradients for BERT at the layer-level.

For a given passage $P$ consisting of $n$ words $[w_1, w_2, \dots, w_n]$, query $Q$, and model $f$ with parameters $\theta$, the task of predicting the answer span is modeled as:

$$p(w_s, w_e) = f(w_s, w_e \mid P, Q, \theta)$$

where $w_s, w_e$ are the predicted answer start and end words or positions.

For any given layer $l$, the above is equivalent to:

$$p(w_s, w_e) = f_l(w_s, w_e \mid E_{l-1}(P), E_{l-1}(Q), \theta)$$

where $f_l$ is the forward propagation from layer $l$ to the prediction, and $E_l(\cdot)$ is the representation learnt for passage or query words by a given layer $l$. To elaborate, we consider the network below layer $l$ as a black box which generates input representations for layer $l$.
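The black-box view can be illustrated with a small sketch. The three toy layers below are hypothetical stand-ins for Transformer blocks (the prediction head is omitted), and the helper names are assumptions made for illustration only. The point is that running the layers below $l$ and then the layers from $l$ onward reproduces the full forward pass.

```python
def make_layers():
    """Three hypothetical toy 'layers', each mapping a list of floats to a list."""
    return [
        lambda h: [v + 1.0 for v in h],
        lambda h: [2.0 * v for v in h],
        lambda h: [v - 3.0 for v in h],
    ]


def run(layers, h):
    """Apply the given layers in order."""
    for layer in layers:
        h = layer(h)
    return h


def split_at(layers, l):
    """Return (below, f_l): the black box producing the input of layer l,
    and the forward pass from layer l to the output."""
    def below(h):
        return run(layers[:l], h)

    def f_l(h):
        return run(layers[l:], h)

    return below, f_l
```

With this split, attributions for layer $l$ only need gradients of `f_l` with respect to the output of `below`.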
We now calculate the integrated gradients for each layer, $IG_l(x_i)$, for all passage words $w_i$ using Algorithm 1. We then compute an importance score for each $w_i$ by taking the Euclidean norm of $IG_l(x_i)$, and normalize the scores to give a probability distribution $I_l$ over passage words.

Algorithm 1 To compute Layer-wise Integrated Gradients for layer $l$
1:  $\tilde{p} = 0$  // zero baseline
2:  $G_l(p) = \frac{1}{m}\sum_{k=1}^{m} \frac{\partial f_l(\tilde{p}+\frac{k}{m}(p-\tilde{p}))}{\partial E_l}$
3:  $IG_l(p) = [(p-\tilde{p}) \times G_l(p)]$
4:  // Compute squared norm at each row
5:  $\tilde{I}_l([w_1,\dots,w_k]) = ||IG_l(p)||$
6:  Normalize $\tilde{I}_l$ to a probability distribution $I_l$
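The steps of Algorithm 1 can be sketched in plain Python as follows. This is an illustrative sketch under stated assumptions, not the authors' implementation: `p` is a list of rows (one representation vector per passage word), `grad_fn` is a hypothetical callable returning the gradient of $f_l$ with respect to that matrix, and the toy gradient below corresponds to $f_l(p)=\sum_{i,j} p_{ij}^2$. The per-row score follows the prose description (Euclidean norm) rather than the "squared norm" comment in the listing.

```python
import math


def toy_layer_grad(p):
    """Gradient of the hypothetical f_l(p) = sum of squared entries."""
    return [[2.0 * v for v in row] for row in p]


def layerwise_ig_importance(p, grad_fn, m=50):
    n, d = len(p), len(p[0])
    # Step 1: zero baseline, so (p - baseline) is p itself.
    # Step 2: average the gradient over m points on the path baseline -> p.
    g = [[0.0] * d for _ in range(n)]
    for k in range(1, m + 1):
        scaled = [[(k / m) * v for v in row] for row in p]
        grad = grad_fn(scaled)
        for i in range(n):
            for j in range(d):
                g[i][j] += grad[i][j] / m
    # Step 3: element-wise product of (p - baseline) and the averaged gradient.
    ig = [[p[i][j] * g[i][j] for j in range(d)] for i in range(n)]
    # Steps 4-5: one score per passage word (Euclidean norm of its row).
    scores = [math.sqrt(sum(v * v for v in row)) for row in ig]
    # Step 6: normalize to a probability distribution over passage words.
    total = sum(scores)
    if total == 0:
        return [1.0 / n] * n  # assumed fallback: uniform if all scores vanish
    return [s / total for s in scores]
```

For example, `layerwise_ig_importance([[1.0, 0.0], [0.0, 2.0], [0.0, 0.0]], toy_layer_grad)` assigns most of the mass to the second word and none to the third.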

5 Experiments

As described in Section 4, we quantify and visualize a layer’s function as how it distributes importance over the passage words, using the distribution $I_l$. To compute the similarity between any two layers $x, y$, we measure the Jensen-Shannon Divergence (JSD) between their corresponding importance distributions $I_x, I_y$. The JSD scores are calculated between every pair of layers in the model, and are visualised as an $n_l \times n_l$ heatmap ($n_l$ is the number of layers in the model). The higher the JSD score, the more dissimilar the two layers are and the more the words they consider salient differ. All heatmaps in this section are averaged over 500 samples from SQuAD’s dev split.
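A minimal sketch of the pairwise JSD computation is given below. The logarithm base is not stated in the text; base 2 is assumed here, which bounds the JSD in $[0, 1]$. The function names are illustrative, and each distribution is assumed to be a list of probabilities over the same passage words.

```python
import math


def kl(p, q):
    """KL divergence in bits; terms with p_i = 0 contribute nothing."""
    return sum(pi * math.log2(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def jsd(p, q):
    """Jensen-Shannon divergence between two distributions of equal length."""
    mix = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    return 0.5 * kl(p, mix) + 0.5 * kl(q, mix)


def pairwise_jsd(dists):
    """n_l x n_l matrix of JSD scores for a list of per-layer distributions."""
    n = len(dists)
    return [[jsd(dists[x], dists[y]) for y in range(n)] for x in range(n)]
```

The matrix is symmetric with a zero diagonal. Averaging such matrices element-wise over samples would give a heatmap of the kind described above.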

5.1 JSD analysis

We first present pairwise layer JSD scores for BERT in Fig. 1. We observe low JSD scores between all pairs of layers, with only a minimal increase as the layers get further apart (the min/max JSD observed is just 0.06/0.41), giving a preliminary result that the layers are highly similar to each other.

Figure 1: JSD between $I_l$’s

5.2 JSD with top-k retained/removed

To further evaluate the source of the similarity, we analyse the distribution in two parts: (i) we retain only top-k scores in each layer and zero out the rest. This denotes the head of the distribution. (ii) we zero the top-k scores in each layer and retain the rest, which denotes the tail of the distribution. In either case we re-normalize to maintain the probability distribution. The resulting heatmaps can be seen in Fig. 2.
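The head/tail split can be sketched as follows. This is an illustrative sketch with hypothetical function names; ties between equal scores are broken by position, which is an assumption the text does not address.

```python
def top_k_indices(dist, k):
    """Indices of the k largest scores (ties broken by lower index)."""
    order = sorted(range(len(dist)), key=lambda i: dist[i], reverse=True)
    return set(order[:k])


def renormalize(scores):
    """Rescale to sum to 1; left unchanged if everything is zero."""
    total = sum(scores)
    return [s / total for s in scores] if total > 0 else scores


def retain_top_k(dist, k=2):
    """Head of the distribution: keep the top-k scores, zero out the rest."""
    top = top_k_indices(dist, k)
    return renormalize([v if i in top else 0.0 for i, v in enumerate(dist)])


def remove_top_k(dist, k=2):
    """Tail of the distribution: zero out the top-k scores, keep the rest."""
    top = top_k_indices(dist, k)
    return renormalize([0.0 if i in top else v for i, v in enumerate(dist)])
```

Either result can then be compared across layers with the same JSD measure used for the full distributions.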

Figure 2 (panels (a) and (b)): JSD between $I_l$’s with top-2 items removed/retained

When comparing just the top-2 items in heatmap 2(a), we see higher values (min 0.08/max 0.72) than in the heatmap of Fig. 1; when the top-2 items are removed, we see lower values (min 0.09/max 0.26). We therefore conclude that a layer’s function is reflected in the words high up in the importance distribution; once these are removed, we encounter an almost uniform distribution across the less important words. Hence, to correctly identify different layer functionalities, we need to focus only on the head (top-k words) and not the tail.
Further, in heatmap 2(a), we see higher JSD scores when layers 0-6 are each compared with all the layers (min 0.28/max 0.72), whereas we see much lower JSD values between layers 7-11 (min 0.08/max 0.32). This suggests that layers 7-11 have fairly similar functionalities.

5.3 JSD analysis split by question type

In this section, we analyze the JSD heatmaps split by question type, on the motivation that the model approaches different question types differently. For example, “what” or “who” questions require entities as answers, and in SQuAD can probably be answered more directly, whereas questions like “why” or “how” require a more in-depth reading of the passage. Hence we analyze the heatmaps for each question type separately (with the top-2 words retained). The results can be found in Fig. 3.

Figure 3 (panels (a) and (b)): JSD of $I_l$’s, split by question types

The ‘what’ heatmap (Fig. 3(a)) shows behaviour similar to that observed in Section 5.1. However, the heatmap for “why” (Fig. 3(b)) shows a slightly higher JSD in the later layers as well, supporting the hypothesis that such questions require a deeper understanding of the passage and hence more work to be done by the model. We present heatmaps for other question types in the appendix.
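The text does not say how questions are assigned to a type. One simple possibility, sketched below purely as an assumption, is to bucket each question by the first question word it contains; the word list and function names are hypothetical.

```python
from collections import defaultdict

# Assumed set of question words; not taken from the paper.
QUESTION_WORDS = ("what", "who", "when", "where", "why", "how", "which")


def question_type(question):
    """Return the first question word found in the question, else 'other'."""
    for token in question.lower().split():
        word = token.strip("?,.'\"")
        if word in QUESTION_WORDS:
            return word
    return "other"


def group_by_type(questions):
    """Map each question type to the list of questions assigned to it."""
    groups = defaultdict(list)
    for q in questions:
        groups[question_type(q)].append(q)
    return dict(groups)
```

A per-type heatmap would then average the pairwise JSD matrices over the samples in each group.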

5.4 Case Study: Classifying Layers

We observed in Section 5.2 that layers 0-6 have high JSD scores with all of BERT’s layers; hence, we first classify them as distinctive. Layers 7-11 have low JSD scores between each other; they could be classified as either verification or performance booster. To resolve this ambiguity, we perform pruning experiments on BERT, wherein we remove certain layers from the original pre-trained BERT, train it on SQuAD, and observe the change in performance (Table 1).

Layers pruned      F1 (%)   Drop in F1 (%) vs. previous row
None               88.73    -
11                 88.66    0.07
10,11              87.81    0.85
9,10,11            86.58    1.23
8,9,10,11          86.4     0.18
7,8,9,10,11        85.15    1.25
6,7,8,9,10,11      83.75    1.4
Table 1: Pruned BERT models on SQuAD’s dev-set

We iteratively drop layers to identify the impact they have on the model’s performance. We first drop only layer 11, then layers 10 and 11, and so on down to layer 6.
We see that pruning layer 11 causes almost no change in performance (only a 0.07% dip). Dropping layers 10 and then 9 causes a further dip of roughly 1% each (0.85% and 1.23%). However, dropping layer 8 does not have a large impact (only 0.18%). Again, dropping layers 7 and then 6 has a tangible impact (1.25% and 1.4% respectively).
We have already classified layer 6 as distinctive using the JSD scores, and the pruning experiment further corroborates that. Based on the pruning results, we now classify layers 11 and 8 as verification, since their removal causes almost no reduction in the model’s performance. From the JSD experiment, layers 10, 9, and 7 seemed redundant in functionality; however, their removal caused a noticeable dip in BERT’s performance. Hence, we classify them as performance booster.
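The reasoning above can be written as a small decision rule. The F1 values are those of Table 1; the 0.5-point threshold separating "almost no reduction" from "a noticeable dip" is a hypothetical cut-off chosen for illustration, as the text does not state one, and the function names are likewise assumptions.

```python
# (lowest pruned layer, F1) pairs from Table 1; None means no pruning.
F1_AFTER_PRUNING = [
    (None, 88.73), (11, 88.66), (10, 87.81), (9, 86.58),
    (8, 86.4), (7, 85.15), (6, 83.75),
]


def incremental_drops(rows):
    """Drop in F1 caused by additionally pruning each layer."""
    drops = {}
    for (_, prev_f1), (layer, f1) in zip(rows, rows[1:]):
        drops[layer] = round(prev_f1 - f1, 2)
    return drops


def classify(layer, drop, distinctive_layers, threshold=0.5):
    """Combine the JSD result (distinctive layers) with the pruning drop."""
    if layer in distinctive_layers:
        return "distinctive"
    return "verification" if drop < threshold else "performance booster"


def classify_all(rows, distinctive_layers, threshold=0.5):
    drops = incremental_drops(rows)
    return {layer: classify(layer, drop, distinctive_layers, threshold)
            for layer, drop in drops.items()}
```

Calling `classify_all(F1_AFTER_PRUNING, set(range(0, 7)))` reproduces the assignment described above: layers 11 and 8 as verification, layers 10, 9, and 7 as performance booster, and layer 6 as distinctive.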

5.5 Qualitative Analysis

Finally, we determine and analyse the top-5 words of each layer, and compute their overlap with the question words and the predicted answer span. We find that query words make up 14.7-24.3% of the top-5 words of the (distinctive) layers 0-6 and 10.2-18.9% of those of the other layers 7-11. Further, answer words make up 26-31.2% of the top-5 words of layers 0-6 and 30.8-34.6% of those of layers 7-11.
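The overlap statistic can be sketched as below. The matching criterion is an assumption (case-insensitive exact word match), since the text does not specify it, and the function names are illustrative.

```python
def top_k_words(words, scores, k=5):
    """The k passage words with the highest importance scores."""
    order = sorted(range(len(words)), key=lambda i: scores[i], reverse=True)
    return [words[i] for i in order[:k]]


def overlap_fraction(top_words, reference_words):
    """Fraction of the top words that also occur in the reference words
    (e.g. the query words or the predicted answer span)."""
    reference = {w.lower() for w in reference_words}
    hits = sum(1 for w in top_words if w.lower() in reference)
    return hits / len(top_words)
```

Averaging `overlap_fraction` over samples, once against the query words and once against the answer span, would give per-layer percentages of the kind reported above.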
We present a corroborating example in Table LABEL:tab:qual_eg. We see that all these six layers give a high score to the answer span itself (‘disastrous’, ‘situation’). Further, we see that the initial layers 0,1 and 2 are also trying to make a connection between the passage and the query (‘relegated’, ‘because’, ‘Polonia’ get high importance scores).
Hence, based on this qualitative analysis, we conclude that the distinctive layers focus on contextual understanding and interaction between the query and passage. In contrast, the non-distinctive layers focus on enhancing and verifying the model’s prediction.

6 Conclusion

In this work, we analyzed BERT’s layers’ functionalities using their respective importance distributions over passage words. We presented a JSD-based comparison of the same to understand how the layers work individually as well as collectively. Further, we presented a pruning experiment on BERT’s layers, which in combination with the JSD experiment, helped to classify BERT’s layers into 3 logical roles (distinctive, verification, and performance booster). Through preliminary experiments, we found that the identified distinctive layers have contextual and query-passage interaction purposes. In contrast, the other layers work on enhancing the performance and re-verifying the predicted answer. As future work, we would like to extend this to a more detailed analysis of the discovered distinctive layers.