Where’s Waldo: Identifying layers to interpret in BERT for RCQA
Abstract
BERT and its variants have achieved state-of-the-art performance on multiple NLP tasks. Since then, various works have analysed the linguistic information captured in BERT. However, current works do not provide insight into how BERT achieves near human-level performance on Reading Comprehension based Question Answering (RCQA), owing to the complexity of both BERT (110M parameters across 12 layers) and the QA task itself. As a first step, i) we define a layer’s role or functionality using Integrated Gradients, ii) we then identify layers that have distinctive functionalities compared to other layers, and iii) we perform a preliminary analysis of these distinctive layers. Based on Integrated Gradients, we observe that layers 0-6 mainly focus on passage-query interactions, while later layers seem to verify the answer prediction.
1 Introduction
The past decade has seen a surge in the development of deep neural network models for NLP tasks. Pretrained language models such as ELMo (elmo), BERT (Devlin:18), and XLNet (xlnet) have achieved state-of-the-art results on various NLP tasks. For tasks such as QA, earlier models were designed with a fixed purpose for each layer, such as modelling query-passage interaction. However, BERT does not assign a predefined purpose to its individual layers. With the rapid introduction of such models, it becomes necessary to analyse them at the layer level in order to interpret them.
Earlier works (tenney2019bert; peters2018dissecting) analyse the syntactic and semantic roles of each layer in the model. Clark:19 specifically analyses BERT’s attention heads for syntactic and linguistic phenomena. Most analyses of such models focus on tasks such as sentiment classification, syntactic/semantic tag prediction, and NLI.
However, little work has been done to analyse models on complex tasks like RCQA. We believe this is due to the sheer number of parameters and the non-linearity of deep QA models. Therefore, in this work, we analyse BERT’s layers on the RCQA dataset SQuAD (Rajpurkar:16) and take a first step towards identifying logical functions for each layer, which leads to the identification of specific layers in the model that should be further interpreted from a human point of view.
We first define 3 logical purposes a layer could have, and classify BERT’s layers into them:
• distinctive: a distinct part in the model’s overall logic (possibly, but not necessarily, one of the predefined purposes of earlier QA models)
• performance booster: performing mathematical adjustments in its high-dimensional space that are essential to the model’s performance but are not perceptible to humans
• verification: re-affirming/verifying the work done by the previous layers
The first two purposes describe layers that are essential, while the third refers to layers that are redundant for the model’s performance and interpretability. In Section 4, we mathematically define a layer’s functionality as the way it distributes its attribution over its input items (in this case, passage words), using Integrated Gradients (Sundararajan:17). In Section 5, we compare these attribution distributions across layers and prune BERT’s layers; together, these experiments help classify BERT’s layers into the 3 defined logical purposes. Finally, we provide a qualitative study of the roles of distinctive versus other layers.
2 Related Work
The past few years have seen many works on RCQA datasets (Lai:17; Nguyen:16; Joshi:17) and subsequent deep QA models (Seo:16; dhingra2016gated; Yu:18) to solve them. Numerous BERT-based models (Liu:19; Lan:19; Devlin:18) have neared human-level performance on such datasets. To study the interpretability of such models, various attribution methods (Bach:15; Sundararajan:17; Ribeiro:16) have been proposed. Works such as (tenney2019bert; peters2018dissecting; Clark:19) focus on analysing model layers and assigning them syntactic and semantic meaning using probing classifiers. si2019does analyses BERT using adversarial attacks, similar to earlier works (Jia:17; Mudrakarta:18), and shows that these models are likely to be fooled.
3 Implementation Details
SQuAD: approximately 90k/10k train/dev samples, each consisting of a 100-300-word passage, a natural language query, and an answer span in the passage itself.
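Integrated Gradients: For a model $M$ and a passage word $p_i$ embedded as $x_i$, the integrated gradient is defined as:
$$\mathrm{IG}(x_i) = (x_i - x'_i) \times \int_{\alpha=0}^{1} \frac{\partial M\big(x' + \alpha\,(x - x')\big)}{\partial x_i}\, d\alpha$$
where $x'$ is a zero vector that serves as the baseline for measuring the integrated gradient of $x_i$, so the expression reduces to $x_i \times \int_{0}^{1} \partial M(\alpha x)/\partial x_i \, d\alpha$. We approximate the above integral with a Riemann sum over uniformly spaced samples of $\alpha$ in $[0, 1]$.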
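The Riemann approximation above can be sketched in a few lines of NumPy. This is an illustrative sketch, not the authors’ code: the midpoint rule, the step count (`steps=300`), and the toy scalar model `F` (a stand-in for one of BERT’s start/end logits) are all assumptions. The final check uses the completeness property of Integrated Gradients: attributions sum to $M(x) - M(x')$.

```python
import numpy as np

def integrated_gradients(x, grad_fn, steps=300):
    """Midpoint Riemann approximation of Integrated Gradients with a
    zero-vector baseline x' (as in the text)."""
    baseline = np.zeros_like(x)                   # x' = 0
    alphas = (np.arange(steps) + 0.5) / steps     # uniform midpoints in (0, 1)
    grad_sum = np.zeros_like(x)
    for a in alphas:
        grad_sum += grad_fn(baseline + a * (x - baseline))
    return (x - baseline) * grad_sum / steps      # (x - x') * average gradient

# Hypothetical toy scalar "model": F(X) = sum_i v_i * tanh(X_i . w),
# where row X_i is the embedding of passage word i.
rng = np.random.default_rng(0)
n_words, dim = 6, 8
w = rng.normal(size=dim)
v = rng.normal(size=n_words)
X = 0.5 * rng.normal(size=(n_words, dim))

def F(X):
    return float(np.sum(v * np.tanh(X @ w)))

def grad_F(X):
    # dF/dX_i = v_i * (1 - tanh^2(X_i . w)) * w
    return (v * (1.0 - np.tanh(X @ w) ** 2))[:, None] * w[None, :]

ig = integrated_gradients(X, grad_F)
# Completeness: ig.sum() should be close to F(X) - F(baseline)
print(ig.sum(), F(X) - F(np.zeros_like(X)))
```

For a real model, `grad_fn` would return the gradient of the chosen output logit with respect to the embeddings; how that gradient is obtained is framework-specific and left out here.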
BERT: In this work, we use the BERT-BASE model, which has 12 Transformer blocks (layers), each with a multi-head self-attention sublayer and a feed-forward neural network. We use the official code and pre-trained checkpoints (https://github.com/google-research/bert) and fine-tune the model on SQuAD, evaluating F1 on SQuAD’s dev split.
4 Layer-wise Integrated Gradients
In this section, we extend integrated gradients for BERT at the layer-level.
For a given passage $P$ consisting of words $p_1, \dots, p_n$, a query $Q$, and a model $M$ with parameters $\theta$, the task of predicting the answer span is modeled as:
$$(\hat{s}, \hat{e}) = M(P, Q; \theta)$$
where $\hat{s}$ and $\hat{e}$ are the predicted answer start and end words (positions).
For any given layer $l$, the above is equivalent to:
$$(\hat{s}, \hat{e}) = F_l(X^l; \theta)$$
where $F_l$ is the forward propagation from layer $l$ to the prediction, and $X^l = (x^l_1, x^l_2, \dots)$ contains the representations of the passage and query words at layer $l$. In other words, we treat the network below layer $l$ as a black box that generates the input representations for layer $l$.
We now calculate the integrated gradients $\mathrm{IG}_l(x^l_i)$ for each layer $l$ and all passage words using Algorithm 1. We then compute an importance score for each passage word $p_i$ as the Euclidean norm of $\mathrm{IG}_l(x^l_i)$, and normalize these scores to obtain a probability distribution $A_l$ over passage words.
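The norm-and-normalize step can be sketched as follows. This is a minimal sketch assuming the per-layer integrated gradients are already available as a `(num_tokens, hidden)` array; the function name `importance_distribution` and the `passage_idx` argument (used to exclude query and special tokens) are hypothetical.

```python
import numpy as np

def importance_distribution(ig, passage_idx):
    """ig: (num_tokens, hidden) integrated gradients at one layer's input.
    passage_idx: positions of passage words (query/special tokens excluded).
    Returns a probability distribution over the passage words."""
    scores = np.linalg.norm(ig[passage_idx], axis=1)   # Euclidean norm per word
    total = scores.sum()
    if total == 0:
        raise ValueError("all attributions are zero; distribution undefined")
    return scores / total

# Toy example: 4 tokens; token 0 is treated as a query token and excluded.
ig_layer = np.array([[9.0, 9.0],
                     [3.0, 4.0],
                     [0.0, 0.0],
                     [1.0, 0.0]])
A = importance_distribution(ig_layer, passage_idx=[1, 2, 3])
print(A)  # approx [0.833, 0.0, 0.167]
```

Repeating this for every layer yields one distribution per layer, which is what the layer comparisons in Section 5 operate on.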
5 Experiments
As described in Section 4, we quantify and visualize a layer’s function as the way it distributes importance over the passage words, using its importance distribution $A_l$. To compute the similarity between any two layers $l_i$ and $l_j$, we measure the Jensen-Shannon Divergence (JSD) between their importance distributions $A_{l_i}$ and $A_{l_j}$. JSD scores are calculated between every pair of layers in the model and visualised as an $L \times L$ heatmap, where $L$ is the number of layers in the model. The higher the JSD score, the more dissimilar the two layers are, i.e., the more the two layers differ in which words they consider salient. Every heatmap in this section is averaged over 500 samples from SQuAD’s dev split.
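The pairwise heatmap computation can be sketched as below. The text does not state the logarithm base; this sketch assumes base 2, which bounds the JSD to $[0, 1]$. Note that SciPy’s `scipy.spatial.distance.jensenshannon` returns the Jensen-Shannon *distance* (the square root of the divergence), so it is not a drop-in replacement here. Function names are illustrative.

```python
import numpy as np

def jsd(p, q):
    """Jensen-Shannon divergence with base-2 logs (bounded in [0, 1])."""
    p, q = np.asarray(p, float), np.asarray(q, float)
    m = 0.5 * (p + q)
    def kl(a, b):
        nz = a > 0                        # 0 * log(0 / b) is taken as 0
        return float(np.sum(a[nz] * np.log2(a[nz] / b[nz])))
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

def jsd_heatmap(layer_dists):
    """Pairwise L x L JSD matrix for one sample's per-layer distributions."""
    L = len(layer_dists)
    H = np.zeros((L, L))
    for i in range(L):
        for j in range(i + 1, L):
            H[i, j] = H[j, i] = jsd(layer_dists[i], layer_dists[j])
    return H

def average_heatmap(samples):
    """Average the per-sample heatmaps (e.g. over dev-split samples)."""
    return np.mean([jsd_heatmap(s) for s in samples], axis=0)

# Toy example: one sample, 3 "layers", 2 passage words.
H = average_heatmap([[[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]])
print(np.round(H, 3))
```

Two layers that put all their mass on different words get the maximum score of 1, while identical distributions score 0.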
5.1 JSD analysis
We first present pairwise layer JSD scores for BERT in Fig. 1. We observe low JSD scores between all pairs of layers, increasing only slightly as the layers grow further apart (min/max JSD observed is just 0.06/0.41), which gives a preliminary indication that the layers are highly similar to each other.

5.2 JSD with top-k retained/removed
To further examine the source of this similarity, we analyse the distribution in two parts: (i) we retain only the top-k scores in each layer and zero out the rest, which represents the head of the distribution; (ii) we zero out the top-k scores in each layer and retain the rest, which represents the tail of the distribution. In either case, we re-normalize to obtain a probability distribution. The resulting heatmaps are shown in Fig. 2.
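A minimal sketch of the head/tail split with re-normalization is shown below. The helper names are hypothetical, and ties among equal scores are broken arbitrarily by `np.argsort`, which is an assumption the text does not address.

```python
import numpy as np

def _renormalize(x):
    s = x.sum()
    if s == 0:
        raise ValueError("no probability mass left to renormalize")
    return x / s

def keep_top_k(dist, k):
    """Head: keep the k largest scores, zero out the rest, renormalize."""
    dist = np.asarray(dist, float)
    top = np.argsort(dist)[::-1][:k]      # ties are broken arbitrarily
    head = np.zeros_like(dist)
    head[top] = dist[top]
    return _renormalize(head)

def drop_top_k(dist, k):
    """Tail: zero out the k largest scores, keep the rest, renormalize."""
    dist = np.asarray(dist, float)
    top = np.argsort(dist)[::-1][:k]
    tail = dist.copy()
    tail[top] = 0.0
    return _renormalize(tail)

d = [0.5, 0.3, 0.1, 0.1]
head = keep_top_k(d, 2)   # 0.625, 0.375, 0, 0
tail = drop_top_k(d, 2)   # 0, 0, 0.5, 0.5
```

The tail of this toy distribution is already uniform, which mirrors the observation that removing the top items leaves an almost uniform distribution.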


When comparing only the top-2 items (Fig. 2(a)), we see higher values (min 0.08/max 0.72) than in Fig. 1; when the top-2 items are removed, we see a much narrower range of values (min 0.09/max 0.26). We therefore conclude that a layer’s function is reflected in the words at the head of its importance distribution, and that once these are removed, what remains is an almost uniform distribution over the less important words. Hence, to correctly identify differences in layer functionality, we need to focus only on the head (top-k words) and not the tail.
Further, in Fig. 2(a), we see higher JSD scores when each of layers 0-6 is compared with all the layers (min 0.28/max 0.72), whereas we see much lower JSD values among layers 7-11 (min 0.08/max 0.32). This suggests that layers 7-11 have fairly similar functionalities.
5.3 JSD analysis split by question type
In this section, we analyse the JSD heatmaps split by question type, motivated by the hypothesis that the model approaches different question types differently. For example, “what” or “who” questions require entities as answers and, in SQuAD, can probably be answered more directly, whereas “why” or “how” questions require a more in-depth reading of the passage. Hence, we analyse the heatmaps for each question type separately (with the top-2 words retained). The results are shown in Fig. 3.
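One way to implement the split is to bucket samples by their wh-word and average the per-sample heatmaps within each bucket. The text does not say how question types were assigned; the first-wh-word heuristic, the `WH_WORDS` list, and the helper names below are all hypothetical.

```python
import re
from collections import defaultdict
import numpy as np

WH_WORDS = ("what", "who", "when", "where", "which", "why", "how")

def question_type(question):
    """Hypothetical heuristic: the first wh-word in the question, else 'other'."""
    for tok in re.findall(r"[a-z]+", question.lower()):
        if tok in WH_WORDS:
            return tok
    return "other"

def heatmaps_by_type(questions, heatmaps):
    """Average per-sample JSD heatmaps separately for each question type."""
    groups = defaultdict(list)
    for q, h in zip(questions, heatmaps):
        groups[question_type(q)].append(h)
    return {t: np.mean(hs, axis=0) for t, hs in groups.items()}

qs = ["What is the capital?", "Why did the bridge collapse?", "In what year?"]
hms = [np.full((2, 2), 0.2), np.full((2, 2), 0.6), np.full((2, 2), 0.4)]
by_type = heatmaps_by_type(qs, hms)   # "what" -> 0.3 everywhere, "why" -> 0.6
```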


The “what” heatmap (Fig. 3(a)) shows behaviour similar to that observed in Section 5.2. However, the “why” heatmap (Fig. 3(b)) shows slightly higher JSD in the later layers as well, supporting the hypothesis that such questions require a deeper understanding of the passage and hence more work by the model. Heatmaps for other question types are presented in the appendix.
5.4 Case Study: Classifying Layers
We observed in Section 5.2 that layers 0-6 have high JSD scores with all of BERT’s layers; hence, we first classify them as distinctive. Layers 7-11 have low JSD scores with each other, so they could be classified as either verification or performance booster. To resolve this ambiguity, we perform pruning experiments on BERT, wherein we remove certain layers from the original pre-trained BERT, fine-tune it on SQuAD, and observe the change in performance (Table 1).
Layers pruned | F1 (%) | Drop in F1 vs. previous row (points) |
None | 88.73 | - |
11 | 88.66 | 0.07 |
10,11 | 87.81 | 0.85 |
9,10,11 | 86.58 | 1.23 |
8,9,10,11 | 86.40 | 0.18 |
7,8,9,10,11 | 85.15 | 1.25 |
6,7,8,9,10,11 | 83.75 | 1.40 |
We prune layers iteratively to identify their impact on the model’s performance: we first drop only layer 11, then layers 10 and 11, and so on down to layer 6.
Pruning layer 11 causes almost no change in performance (a dip of only 0.07 F1 points). Additionally pruning layer 10 and then layer 9 causes further dips of 0.85 and 1.23 points, respectively (roughly 1 point each). However, additionally pruning layer 8 has little impact (only 0.18 points). Pruning layers 7 and then 6 again has a tangible impact (1.25 and 1.40 points, respectively).
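The per-layer drops quoted above follow from the F1 column of Table 1 by differencing consecutive rows; the short sketch below reproduces that arithmetic and also the cumulative drop from pruning layers 6-11 together. Variable names are illustrative.

```python
# F1 scores from Table 1: no pruning, then pruning layers top-down.
newly_pruned = [None, 11, 10, 9, 8, 7, 6]
f1 = [88.73, 88.66, 87.81, 86.58, 86.40, 85.15, 83.75]

# Drop attributed to each newly pruned layer, relative to the previous row.
incremental = {layer: round(prev - cur, 2)
               for layer, prev, cur in zip(newly_pruned[1:], f1, f1[1:])}
cumulative = round(f1[0] - f1[-1], 2)   # total drop after pruning layers 6-11
print(incremental)  # {11: 0.07, 10: 0.85, 9: 1.23, 8: 0.18, 7: 1.25, 6: 1.4}
print(cumulative)   # 4.98
```

The drops are differences in absolute F1 points, not percentages of the previous score.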
We had already classified layer 6 as distinctive using the JSD scores, and the pruning experiment further corroborates this. Based on the pruning results, we now classify layers 11 and 8 as verification, since their removal causes almost no reduction in the model’s performance. From the JSD experiment, layers 10, 9 and 7 seemed redundant in functionality; however, their removal caused a noticeable dip in BERT’s performance. Hence, we classify them as performance boosters.
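The two-stage reasoning (JSD first, then pruning) can be written as a simple decision rule. This is a sketch of the logic, not a procedure stated in the text: the `drop_threshold` of 0.5 F1 points is a hypothetical cut-off chosen only to separate the small drops (layers 11 and 8) from the larger ones.

```python
def classify_layer(is_distinctive, f1_drop, drop_threshold=0.5):
    """Hypothetical decision rule mirroring the reasoning in the text.
    drop_threshold (F1 points) is an illustrative assumption."""
    if is_distinctive:
        return "distinctive"
    return "verification" if f1_drop < drop_threshold else "performance booster"

# Layers 0-6: high JSD with all layers (Section 5.2) -> distinctive.
# Layers 7-11: low mutual JSD; use the per-layer F1 drops from Table 1.
drops = {7: 1.25, 8: 0.18, 9: 1.23, 10: 0.85, 11: 0.07}
roles = {layer: classify_layer(layer <= 6, drops.get(layer, 0.0))
         for layer in range(12)}
```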
5.5 Qualitative Analysis
Finally, we determine the top-5 words of each layer and compute their overlap with the question words and the predicted answer span. We find that query words make up 14.7-24.3% of the top-5 words of the (distinctive) layers 0-6, and 10.2-18.9% of those of the other layers 7-11. Further, answer words make up 26-31.2% of the top-5 words of layers 0-6 and 30.8-34.6% of those of layers 7-11.
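The overlap statistics can be sketched as follows. The matching rules are assumptions, since the text does not specify them: query overlap uses exact lowercase token matches (no stemming or stop-word removal), and answer overlap counts top-k positions falling inside the predicted span. The toy passage and scores are made up for illustration and are not taken from the paper’s data.

```python
import re
import numpy as np

def top_k_indices(dist, k=5):
    return list(np.argsort(np.asarray(dist))[::-1][:k])

def query_overlap(dist, words, question, k=5):
    """Fraction of a layer's top-k words that also occur in the question
    (exact lowercase match; no stemming or stop-word removal)."""
    q_tokens = set(re.findall(r"[a-z]+", question.lower()))
    top = top_k_indices(dist, k)
    return sum(words[i].lower() in q_tokens for i in top) / len(top)

def answer_overlap(dist, start, end, k=5):
    """Fraction of a layer's top-k words inside the predicted span [start, end]."""
    top = top_k_indices(dist, k)
    return sum(start <= i <= end for i in top) / len(top)

# Toy example (made-up passage, question and importance scores).
words = ["the", "river", "flows", "north", "into", "the", "lake", "today"]
dist = [0.06, 0.30, 0.10, 0.25, 0.02, 0.03, 0.20, 0.04]
q_frac = query_overlap(dist, words, "Which direction does the river flow?")
a_frac = answer_overlap(dist, start=3, end=3)   # predicted answer: "north"
```

Here the top-5 words are river, north, lake, flows and the; two of them (river, the) appear in the question and one (north) lies in the answer span.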
We present a corroborating example in the qualitative example table (tab:qual_eg). All six layers shown give a high score to the answer span itself (‘disastrous’, ‘situation’). Further, the initial layers 0, 1 and 2 also try to connect the passage and the query (‘relegated’, ‘because’ and ‘Polonia’ receive high importance scores).
Hence, based on this qualitative analysis, we conclude that the distinctive layers focus on contextual understanding and interaction between the query and passage. In contrast, the non-distinctive layers focus on enhancing and verifying the model’s prediction.
6 Conclusion
In this work, we analysed the functionality of BERT’s layers using their respective importance distributions over passage words. We presented a JSD-based comparison of these distributions to understand how the layers work individually as well as collectively. Further, we presented a pruning experiment on BERT’s layers which, in combination with the JSD experiment, helped classify BERT’s layers into 3 logical roles (distinctive, verification, and performance booster). Through preliminary experiments, we found that the identified distinctive layers serve contextual and query-passage interaction purposes, whereas the other layers enhance performance and re-verify the predicted answer. As future work, we would like to extend this to a more detailed analysis of the discovered distinctive layers.