
Learning Dynamic BERT via
Trainable Gate Variables and a Bi-modal Regularizer

Seohyeong Jeong
Seoul National University
[email protected]
Nojun Kwak
Seoul National University
[email protected]
Abstract

The BERT model has shown significant success on various natural language processing tasks. However, due to its heavy model size and high computational cost, the model suffers from high latency, which is fatal to deployment on resource-limited devices. To tackle this problem, we propose a dynamic inference method for BERT via trainable gate variables applied to input tokens and a regularizer with a bi-modal property. Our method reduces computational cost on the GLUE dataset with a minimal performance drop. Moreover, the model adjusts the trade-off between performance and computational cost according to a user-specified hyperparameter.

1 Introduction

BERT Devlin et al. (2018), the large-scale pre-trained language model, has shown significant improvements on natural language processing tasks Dai and Le (2015); Rajpurkar et al. (2016); McCann et al. (2017); Peters et al. (2018); Howard and Ruder (2018). However, the model suffers from a heavy model size and high computational cost, which hinders it from being deployed in real-time scenarios on resource-limited devices.

Figure 1: The overview of the two-stage method framework (a) and the one-stage method framework (b) for dynamic inference models.

Figure 2: Comparison of an original encoder block of BERT and our model with the mask matching module.

Khetan and Karnin (2020) have shown that a “tall and narrow” architecture provides better performance than a “wide and shallow” one when obtaining a computationally lighter model. Inspired by this finding, we propose a task-agnostic method that dynamically learns masks on the input word vectors of each BERT layer during fine-tuning. To do this, we propose a gating module, which we call a mask matching module, that sorts the input tokens and matches each of them to a corresponding learned mask. Note that we use the words “gate” and “mask” interchangeably. Inspired by Srinivas et al. (2017), we train our model with an additional regularizer that has a bi-modal property, applied on top of the $l_1$-variant regularizer that we propose in this work and the original task loss. Using the bi-modal regularizer allows the model to learn a downstream task and search the model architecture simultaneously, without requiring any further fine-tuning stage.

In this paper, we conduct experiments with BERT-base on the GLUE Wang et al. (2018) dataset and show that the mask matching module and the bi-modal regularizer enable the model to search its architecture and fine-tune on a downstream dataset simultaneously. Compared to previous works on compressing BERT or accelerating its inference Sanh et al. (2019); Jiao et al. (2019); Sun et al. (2019); Liu et al. (2020), our method has three main differences. First, it allows task-agnostic dynamic inference rather than producing a single reduced-size model. Second, it does not require any additional stage of fine-tuning or knowledge distillation (KD) Hinton et al. (2015). Lastly, it provides a hyperparameter that the user can specify to control the trade-off between computational complexity and performance.

2 Related Work

There have been numerous works to compress and accelerate the inference of BERT. Adopting KD, Sanh et al. (2019) attempts to distill heavy teacher models into a lighter student model. Pruning the pre-trained model is another method to handle the issue of heavy model sizes and high computational cost Michel et al. (2019); Gordon et al. (2020). Sajjad et al. (2020) prunes BERT by dropping unnecessary blocks and Goyal et al. (2020) does it by dropping semantically redundant word vectors. Some other works have introduced dynamic inference to accelerate the inference speed on BERT. Xin et al. (2020) allows early exiting and Liu et al. (2020) adjusts the number of executed blocks dynamically.

Our work is mainly inspired by Goyal et al. (2020), and we integrate dynamic inference into sequence pruning. The main difference of our method compared to theirs and other pruning methods Michel et al. (2019); Gordon et al. (2020); Sajjad et al. (2020) is that our model allows task-agnostic dynamic inference without requiring additional fine-tuning after a model architecture search, as illustrated in Figure 1 (b). There exist other works Hou et al. (2020); Goyal et al. (2020); Liu et al. (2020); Fan et al. (2019); Elbayad et al. (2019) that dynamically adjust the size and latency of language models. However, these approaches either work in a two-stage setting where further fine-tuning or knowledge distillation is required, as shown in Figure 1 (a), or consider depth-wise rather than width-wise compression. We experimentally show that the computational cost can be reduced with a minimal performance drop on GLUE Wang et al. (2018).

3 Method

In this section, we introduce our proposed method that mainly consists of a mask matching module and an additional regularizer to induce polarization on mask variables.

3.1 Mask Matching Module

As presented in Figure 2 (a), the original encoder block of BERT consists of multi-head attention and feed-forward networks. The intuition behind the mask matching module is to filter out input tokens that do not contribute much to solving a given task, so that the model can benefit from a reduced computational burden during multi-head attention. Since multi-head attention over a sequence of length $L$ has $O(L^2)$ computational complexity, we expect to reduce this cost by masking out unnecessary tokens for each encoder block.
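As a rough back-of-the-envelope illustration (our own arithmetic, not a claim from the results): if, loosely, only a fraction $\gamma$ of the $L$ tokens survive in a block, the quadratic attention term shrinks by a factor of $\gamma^2$, while the per-token linear terms (projections and feed-forward layers) shrink only by $\gamma$:

$$\frac{(\gamma L)^2}{L^2}=\gamma^2, \qquad \text{e.g. } L=128,\ \gamma=0.3 \;\Rightarrow\; (0.3\cdot 128)^2 \approx 38^2 \approx 0.09\,L^2 .$$

Since the per-token linear terms dominate the FLOPs of BERT-base at this sequence length, the overall reduction lands closer to $\gamma$ than to $\gamma^2$, which is in line with the roughly $3\times$ savings reported at $\gamma=0.3$ in Section 4.2.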

In order to learn important tokens during training, we introduce the mask matching module, which is placed before the original encoder block of BERT, as shown in Figure 2 (b). Figure 2 (c) shows the detailed process of the mask matching module. The superscript $l$ denotes the $l^{th}$ block; we omit it in the rest of this section. The module sorts the input tokens according to their importance scores, matches each input token to a mask, and thresholds the masked tokens with a certain value.

We first compute the importance score $\mathbf{s}\in\mathbb{R}^{I}$ of each token in the input sequence as $s_{i}=\sum_{j=1}^{J}|\mathbf{X}_{ij}|$, where $\mathbf{X}\in\mathbb{R}^{I\times J}$ is the matrix representation of the input, with $I$ being the length of the input sequence and $J$ being the size of the hidden dimension. As each token has a corresponding $s_{i}$, we sort the rows of the input matrix $\mathbf{X}$ according to the importance score of each token. Then, the sorted input matrix and the sorted masks are multiplied element-wise to obtain the mask-matched input matrix $\mathbf{S}\in\mathbb{R}^{I\times J}$:

$$\mathbf{S}=\mathrm{sort}(\mathbf{X})\odot \mathrm{expand}(\mathrm{sort}(\sigma(\mathbf{m}))) \qquad (1)$$

where $\sigma(\cdot)$ is the sigmoid function and $\mathbf{m}\in\mathbb{R}^{I}$ is a trainable parameter. Note that since $\sigma(\mathbf{m})\in\mathbb{R}^{I}$ and $\mathbf{X}\in\mathbb{R}^{I\times J}$, we expand $\sigma(\mathbf{m})$ to match the shape of $\mathbf{X}$ by replicating it $J$ times and stacking the copies.
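For concreteness, the following PyTorch-style sketch implements Eq. 1 for a single block. It is a minimal sketch under our own naming; in particular, sorting both the tokens and the gates in descending order, so that the most important tokens receive the largest gates, is our assumption rather than something specified above.

```python
import torch

def mask_match(X, m):
    """Eq. (1): sort tokens by importance and match them to sorted gates.

    X: [I, J] input token matrix of one encoder block.
    m: [I]    trainable mask (gate) logits for that block.
    Returns the mask-matched matrix S, the sorted gates, and the sort order.
    """
    # Importance score s_i = sum_j |X_ij|
    s = X.abs().sum(dim=1)                                  # [I]
    order = torch.argsort(s, descending=True)               # tokens by importance
    X_sorted = X[order]                                     # sort(X), row-wise

    gates = torch.sigmoid(m)                                # sigma(m), values in (0, 1)
    gates_sorted, _ = torch.sort(gates, descending=True)    # sort(sigma(m))

    # expand(.): replicate the gate vector J times to match the shape of X
    gates_expanded = gates_sorted.unsqueeze(1).expand_as(X_sorted)   # [I, J]

    S = X_sorted * gates_expanded                           # element-wise product, Eq. (1)
    return S, gates_sorted, order
```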

Then, we introduce a thresholding scheme on masked tokens as follows:

$$th(\mathbf{S}_{i,1:J})=\begin{cases}\mathbf{S}_{i,1:J} & \sigma(m_{i})\geq\alpha\\ \mathbf{0} & \sigma(m_{i})<\alpha\end{cases} \qquad (2)$$

where $\alpha$ is a hyperparameter and $m_{i}$ is the learned mask value for the $i^{th}$ token in the input. The thresholded output is unsorted back into the original order of the input tokens and passed to the subsequent encoder block as its input. The final output of the mask matching module is written as follows:

$$\mathbf{X}^{m}=\mathrm{unsort}(th(\mathbf{S})) \qquad (3)$$
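Continuing the sketch above, Eqs. 2 and 3 amount to zeroing the rows whose gate falls below $\alpha$ and then restoring the original token order; again, the function and variable names are ours.

```python
import torch

def threshold_and_unsort(S, gates_sorted, order, alpha=0.5):
    """Eqs. (2)-(3): hard-threshold rows by their gate value, then undo the sort.

    S:            [I, J] mask-matched matrix from mask_match().
    gates_sorted: [I]    sorted gate values sigma(m).
    order:        [I]    permutation used to sort the tokens.
    """
    keep = (gates_sorted >= alpha).to(S.dtype).unsqueeze(1)   # [I, 1], Eq. (2)
    S_th = S * keep                     # rows with sigma(m_i) < alpha become all-zero

    # unsort: scatter rows back to their original positions, Eq. (3)
    X_masked = torch.empty_like(S_th)
    X_masked[order] = S_th
    return X_masked
```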

3.2 Inducing Polarization

Traditional $l_{1}$ and $l_{2}$ regularizers do not guarantee well-polarized values for a gate (mask) variable. In order to induce polarization on our masks, we utilize a bi-modal regularizer proposed by Murray and Ng (2010); Srinivas et al. (2017) to learn binary values for parameters. Srinivas et al. (2017) used an overall regularizer which is a combination of the bi-modal regularizer and a traditional $l_{1}$ or $l_{2}$ regularizer. In this work, we use a customized regularizer, which is a variant of $l_{1}$, denoted as $l_{filter}$, to dynamically adjust the level of sparsity according to the user-specified hyperparameter.

$$l_{filter}=\frac{1}{L}\sum_{l=1}^{L}|\mathbf{v}_{filter,l}|,\qquad \mathbf{v}_{filter}=\mathbf{w}\odot(\mathbf{v}_{masks}-\mathbf{v}_{user}). \qquad (4)$$

$\mathbf{v}_{masks}$, $\mathbf{v}_{user}$, $\mathbf{w}\in\mathbb{R}^{L}$ are the mass of masks, the user-specified mass of masks, and the filtering weights, respectively, with $L$ being the number of blocks in the model.

$$\mathbf{v}_{masks,l}=\sum_{i=1}^{I}\sigma(m_{i}^{l}),\qquad \mathbf{v}_{user,l}=I\times L\times\gamma,\qquad \mathbf{w}_{l}=1.5-\frac{1}{I}\sum_{i=1}^{I}\sigma(m_{i}^{l}). \qquad (5)$$

where $I$ is the length of the input token sequence and $0\leq\gamma\leq 1$ is a hyperparameter to enforce the user-specified level of filtering tokens in the model. Then, the polarization regularizer is written as a linear combination of $l_{filter}$ and $l_{bi\text{-}modal}$, which has the form $w\times(1-w)$, as follows:

$$L_{polar}=\lambda_{filter}\cdot l_{filter}+\lambda_{bi}\cdot\sum_{l=1}^{L}\sum_{i=1}^{I}\sigma(m_{i}^{l})\,(1-\sigma(m_{i}^{l})). \qquad (6)$$

Our total objective function is stated as follows:

$$L_{total}=L_{task}+L_{polar} \qquad (7)$$

$L_{task}$ is the loss for a downstream task. We show the effect of the bi-modal regularizer in Sec. 4.3.
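As a concrete illustration, the following PyTorch-style sketch computes $L_{polar}$ from Eqs. 4-6 exactly as they are printed above; the function and tensor names (e.g. `polarization_loss`, `mask_logits`) are ours and not part of the paper.

```python
import torch

def polarization_loss(mask_logits, gamma, lambda_filter=0.01, lambda_bi=2.0):
    """mask_logits: [L, I] trainable mask variables m_i^l (blocks x tokens)."""
    L, I = mask_logits.shape
    gates = torch.sigmoid(mask_logits)            # sigma(m_i^l), values in (0, 1)

    v_masks = gates.sum(dim=1)                    # Eq. 5: mass of masks per block, [L]
    v_user = I * L * gamma                        # Eq. 5: user-specified mass (as printed)
    w = 1.5 - v_masks / I                         # Eq. 5: per-block filtering weights, [L]

    v_filter = w * (v_masks - v_user)             # Eq. 4
    l_filter = v_filter.abs().mean()              # Eq. 4: average of |v_filter,l| over blocks

    l_bimodal = (gates * (1.0 - gates)).sum()     # Eq. 6: bi-modal term, smallest at 0 and 1
    return lambda_filter * l_filter + lambda_bi * l_bimodal
```

The total objective of Eq. 7 is then simply the downstream task loss plus this value.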

Models (GLUE test)                MNLI-(m/mm)   QNLI     QQP      RTE      SST-2    MRPC     CoLA     STS-B
BERT-base Devlin et al. (2018)    84.6 / 83.4   90.5     71.2     66.4     93.5     88.9     52.1     85.8
BERT-base-ours                    84.5 / 83.7   90.7     71.8     62.4     93.9     83.7     51.2     78.9
  (FLOPs)                         10872M        10872M   10872M   10872M   10872M   10872M   10872M   10872M
Ours ($\gamma=0.3$)               82.2 / 81.8   87.5     69.9     58.2     92.8     84.9     33.3     79.5
  (FLOPs)                         3357M         3915M    3766M    4629M    3887M    4417M    2629M    3371M
  (speedup)                       3.23×         2.77×    2.88×    2.35×    2.80×    2.46×    4.13×    3.23×
Table 1: Comparison of GLUE test results, scored by the official evaluation server. BERT-base-ours is our implementation of the baseline BERT model. Performance for Ours is reported with $\gamma=0.3$. The last row shows the speedup relative to the FLOPs of the original BERT-base.
Models (GLUE eval)         MNLI-(m/mm)   QNLI     SST-2
BERT-base-ours             84.3 / 84.9   91.7     92.5
  (FLOPs)                  10872M        10872M   10872M
Ours ($\gamma=0.1$)        70.8 / 71.0   73.6     87.0
  (FLOPs)                  1907M         1872M    1865M
Ours ($\gamma=0.2$)        79.0 / 79.0   85.5     91.3
  (FLOPs)                  2904M         2883M    2883M
Ours ($\gamma=0.3$)        82.4 / 82.6   88.7     91.6
  (FLOPs)                  3357M         3915M    3887M
Ours ($\gamma=0.4$)        82.9 / 83.7   89.8     92.2
  (FLOPs)                  4919M         4926M    4863M
Ours ($\gamma=0.5$)        83.1 / 83.7   90.4     91.6
  (FLOPs)                  5923M         5994M    5916M
Ours ($\gamma=0.6$)        83.1 / 83.7   90.6     91.6
  (FLOPs)                  6962M         7033M    6676M
Ours ($\gamma=0.7$)        83.2 / 83.9   89.8     92.0
  (FLOPs)                  9218M         9027M    9182M
Ours ($\gamma=0.8$)        83.8 / 84.2   89.8     92.2
  (FLOPs)                  10829M        10398M   9818M
Ours ($\gamma=0.9$)        83.9 / 84.4   91.1     92.4
  (FLOPs)                  10872M        10872M   10872M

Table 2: Performances and FLOPs on the GLUE evaluation set with different values of $\gamma$.
Models (GLUE eval)            MNLI-(m/mm)   QNLI   SST-2
Ours ($\lambda_{bi}=2.0$)     82.4 / 82.6   88.7   91.6
Ours ($\lambda_{bi}=0.0$)     60.9 / 61.3   67.2   76.0

Table 3: Ablation study of the bi-modal regularizer on the GLUE evaluation set.

4 Experiments

We evaluate the proposed method on eight datasets of the GLUE Wang et al. (2018) benchmark.

4.1 Implementation Details

We fine-tune the pre-trained BERT-base model on 8 datasets of the GLUE benchmark for 3 epochs with a batch size of 128. The hidden dimension is set to $J=768$ and the length of the input token sequence is set to $I=128$. For the rest of the details, we follow the original settings of BERT.

We set $\alpha=0.5$, $\lambda_{filter}=0.01$, and $\lambda_{bi}=2.0$. We use a separate Adam Kingma and Ba (2014) optimizer for training the mask variables. This optimizer is set with an initial learning rate of $0.05$, momentum parameters $\beta_{1}=0.9$ and $\beta_{2}=0.999$, and $\epsilon=1\times 10^{-8}$. Mask variables are initialized with random values drawn from a uniform distribution on the interval [0, 1). We do not introduce mask variables for the very first block of the model. Additionally, we never filter (do not mask) the first token of each input, the special token [CLS]. Introducing mask variables adds "length of the token sequence $\times$ number of blocks" extra parameters.
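A minimal sketch of this two-optimizer setup is shown below. It assumes the mask logits are registered as parameters whose names contain "mask"; that naming, and the learning rate used for the BERT parameters, are our assumptions for illustration.

```python
import torch

def build_optimizers(model, mask_lr=0.05, bert_lr=2e-5):
    """Split parameters into mask variables and BERT weights, each with its own Adam."""
    mask_params = [p for n, p in model.named_parameters() if "mask" in n]
    bert_params = [p for n, p in model.named_parameters() if "mask" not in n]

    # Mask variables initialized uniformly on [0, 1)
    for p in mask_params:
        torch.nn.init.uniform_(p, 0.0, 1.0)

    opt_bert = torch.optim.Adam(bert_params, lr=bert_lr)
    opt_mask = torch.optim.Adam(mask_params, lr=mask_lr,
                                betas=(0.9, 0.999), eps=1e-8)
    return opt_bert, opt_mask
```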

4.2 Main Results

We compare our model with the BERT-base baseline. Table 1 summarizes the results. The performance in the first row is taken from Devlin et al. (2018), and the second row reports our implementation. The last row shows the improvement relative to the FLOPs of the baseline model. Our dynamic inference method with $\gamma=0.3$ exhibits minimal degradation on the GLUE datasets with, on average, 3 times fewer FLOPs. Furthermore, our model works in a task-agnostic manner and outputs an optimal architecture for each given downstream dataset, instead of a single reduced-size model.

Table 2 shows that our model is capable of dynamically adjusting the computational cost, trading FLOPs against performance. The hyperparameter $\gamma$ works as intended: the resulting FLOPs are roughly proportional to its value, and the results show a generally consistent trade-off between FLOPs and performance.

4.3 Ablation Studies

To analyze the effect of the bi-modal regularizer, we conduct an ablation study by removing it from the training process. Table 3 shows its effect: employing this regularizer during training plays a crucial role both in learning to perform well on a downstream task and in searching for the optimal model structure with the help of well-polarized mask variables. Further analysis of the behavior of mask variables with and without the bi-modal regularizer is given in Appendix C.

5 Conclusions

In this work, we explore a task-agnostic dynamic inference method for BERT that works by masking out the input sequence of each block. To do this, we propose a mask matching module and a variant of the $l_{1}$ regularizer, which we call $l_{filter}$. Our method yields models at various levels of performance and computational complexity, depending on the hyperparameter value that the user specifies. Through experiments on the GLUE dataset, we show that BERT, used with our method, can enjoy lighter computation with minimal performance degradation.

References

  • Dai and Le (2015) Andrew M Dai and Quoc V Le. 2015. Semi-supervised sequence learning. In Advances in neural information processing systems, pages 3079–3087.
  • Devlin et al. (2018) Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. 2018. Bert: Pre-training of deep bidirectional transformers for language understanding. arXiv preprint arXiv:1810.04805.
  • Elbayad et al. (2019) Maha Elbayad, Jiatao Gu, Edouard Grave, and Michael Auli. 2019. Depth-adaptive transformer. arXiv preprint arXiv:1910.10073.
  • Fan et al. (2019) Angela Fan, Edouard Grave, and Armand Joulin. 2019. Reducing transformer depth on demand with structured dropout. arXiv preprint arXiv:1909.11556.
  • Gordon et al. (2020) Mitchell A Gordon, Kevin Duh, and Nicholas Andrews. 2020. Compressing bert: Studying the effects of weight pruning on transfer learning. arXiv preprint arXiv:2002.08307.
  • Goyal et al. (2020) Saurabh Goyal, Anamitra Roy Choudhury, Saurabh Raje, Venkatesan Chakaravarthy, Yogish Sabharwal, and Ashish Verma. 2020. Power-bert: Accelerating bert inference via progressive word-vector elimination. In International Conference on Machine Learning, pages 3690–3699. PMLR.
  • Hinton et al. (2015) Geoffrey Hinton, Oriol Vinyals, and Jeff Dean. 2015. Distilling the knowledge in a neural network. arXiv preprint arXiv:1503.02531.
  • Hou et al. (2020) Lu Hou, Lifeng Shang, Xin Jiang, and Qun Liu. 2020. Dynabert: Dynamic bert with adaptive width and depth. arXiv preprint arXiv:2004.04037.
  • Howard and Ruder (2018) Jeremy Howard and Sebastian Ruder. 2018. Universal language model fine-tuning for text classification. arXiv preprint arXiv:1801.06146.
  • Jiao et al. (2019) Xiaoqi Jiao, Yichun Yin, Lifeng Shang, Xin Jiang, Xiao Chen, Linlin Li, Fang Wang, and Qun Liu. 2019. Tinybert: Distilling bert for natural language understanding. arXiv preprint arXiv:1909.10351.
  • Khetan and Karnin (2020) Ashish Khetan and Zohar Karnin. 2020. schubert: Optimizing elements of bert. arXiv preprint arXiv:2005.06628.
  • Kingma and Ba (2014) Diederik P Kingma and Jimmy Ba. 2014. Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980.
  • Liu et al. (2020) Weijie Liu, Peng Zhou, Zhe Zhao, Zhiruo Wang, Haotang Deng, and Qi Ju. 2020. Fastbert: a self-distilling bert with adaptive inference time. arXiv preprint arXiv:2004.02178.
  • McCann et al. (2017) Bryan McCann, James Bradbury, Caiming Xiong, and Richard Socher. 2017. Learned in translation: Contextualized word vectors. In Advances in Neural Information Processing Systems, pages 6294–6305.
  • Michel et al. (2019) Paul Michel, Omer Levy, and Graham Neubig. 2019. Are sixteen heads really better than one? arXiv preprint arXiv:1905.10650.
  • Murray and Ng (2010) Walter Murray and Kien-Ming Ng. 2010. An algorithm for nonlinear optimization problems with binary variables. Computational optimization and applications, 47(2):257–288.
  • Peters et al. (2018) Matthew E Peters, Mark Neumann, Mohit Iyyer, Matt Gardner, Christopher Clark, Kenton Lee, and Luke Zettlemoyer. 2018. Deep contextualized word representations. arXiv preprint arXiv:1802.05365.
  • Rajpurkar et al. (2016) Pranav Rajpurkar, Jian Zhang, Konstantin Lopyrev, and Percy Liang. 2016. Squad: 100,000+ questions for machine comprehension of text. arXiv preprint arXiv:1606.05250.
  • Sajjad et al. (2020) Hassan Sajjad, Fahim Dalvi, Nadir Durrani, and Preslav Nakov. 2020. Poor man’s bert: Smaller and faster transformer models. arXiv preprint arXiv:2004.03844.
  • Sanh et al. (2019) Victor Sanh, Lysandre Debut, Julien Chaumond, and Thomas Wolf. 2019. Distilbert, a distilled version of bert: smaller, faster, cheaper and lighter. arXiv preprint arXiv:1910.01108.
  • Srinivas et al. (2017) Suraj Srinivas, Akshayvarun Subramanya, and R Venkatesh Babu. 2017. Training sparse neural networks. In Proceedings of the IEEE conference on computer vision and pattern recognition workshops, pages 138–145.
  • Sun et al. (2019) Siqi Sun, Yu Cheng, Zhe Gan, and Jingjing Liu. 2019. Patient knowledge distillation for bert model compression. arXiv preprint arXiv:1908.09355.
  • Wang et al. (2018) Alex Wang, Amanpreet Singh, Julian Michael, Felix Hill, Omer Levy, and Samuel R Bowman. 2018. Glue: A multi-task benchmark and analysis platform for natural language understanding. arXiv preprint arXiv:1804.07461.
  • Xin et al. (2020) Ji Xin, Raphael Tang, Jaejun Lee, Yaoliang Yu, and Jimmy Lin. 2020. Deebert: Dynamic early exiting for accelerating bert inference. arXiv preprint arXiv:2004.12993.
Figure 3: Histograms of $\sigma(\mathbf{m})$ for each encoder block without and with the $l_{bi\text{-}modal}$ regularizer. For the MNLI-(m/mm), QNLI, and SST-2 datasets, we show the histogram of $\sigma(\mathbf{m})$ values without the $l_{bi\text{-}modal}$ regularizer on the left and with the regularizer on the right. The performance in each scenario is written in the top left corner of each panel. It shows that the $l_{bi\text{-}modal}$ regularizer not only participates in training the mask variables in a well-polarized manner but also plays an important role in learning to perform well on a given task.

Appendix A Additional Details

A.1 Experimental Details

The outputs of multi-head attention, the feed-forward network, and layer normalization in Figure 2 further need to be masked, since these computations contain bias terms. Our goal is to mask out the input matrix of each encoder block token-wise. Therefore, after the computations mentioned above, we apply hard masking to the rows of the input matrix that were masked out by the mask matching module.
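A minimal PyTorch-style sketch of this re-masking step is shown below; the function and variable names are ours. The idea is simply to zero again the rows that the mask matching module dropped, since bias terms would otherwise re-populate them.

```python
def remask(hidden, keep):
    """Re-apply the hard token mask after a sublayer whose bias terms would
    otherwise re-populate dropped rows.

    hidden: [I, J] output of multi-head attention, the FFN, or layer norm.
    keep:   [I]    0/1 indicator of tokens kept by the mask matching module.
    """
    return hidden * keep.unsqueeze(1)   # zero out the rows of dropped tokens
```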

A.2 Reported Measures for GLUE

QQP and MRPC are reported with F1 scores, STS-B is reported with Spearman correlations and other tasks are reported with accuracy.

Appendix B Interpretation of lfilterl_{filter} Regularizer

We propose a variant of the $l_{1}$ regularizer, called $l_{filter}$, as shown in Eqs. 4 and 5. As $l_{filter}$ can come across as somewhat heuristic, we explain the intuition and interpretation behind the regularizer. Consider an extreme case of $\sum_{i=1}^{I}\sigma(m_{i}^{a})=I$ and $\sum_{i=1}^{I}\sigma(m_{i}^{b})=0$. Then, from the last line of Eq. 5, $\mathbf{w}_{a}=0.5$ and $\mathbf{w}_{b}=1.5$. This means that the input tokens of the $a^{th}$ block are needed more than those of the $b^{th}$ block, since the former uses more tokens (masks out fewer tokens). As shown in the second line of Eq. 4, $\mathbf{w}$ works as a weight on $\mathbf{v}_{masks}-\mathbf{v}_{user}$. Instead of applying the same weight to each block, we apply weights according to the number of masks used in each block. In other words, we wish to impose a heavier loss on the $b^{th}$ block than on the $a^{th}$ block of the model.
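Substituting these extremes into the last line of Eq. 5 makes the two weights explicit (our substitution):

$$\mathbf{w}_{a}=1.5-\frac{1}{I}\sum_{i=1}^{I}\sigma(m_{i}^{a})=1.5-\frac{I}{I}=0.5,\qquad \mathbf{w}_{b}=1.5-\frac{0}{I}=1.5 .$$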

Appendix C Analysis on Mask Variables

Figure 3 shows the learned values of the mask variables after the sigmoid function. Each histogram has the post-sigmoid mask values on the x-axis. Since we conduct experiments on BERT-base, we show results for every block in the model, from the $2^{nd}$ block to the $12^{th}$ block, from bottom to top in the figure. It shows that the $l_{bi\text{-}modal}$ regularizer not only participates in training the mask variables in a well-polarized manner but also plays an important role in learning to perform well on a given task.