Learning Dynamic BERT via Trainable Gate Variables and a Bi-modal Regularizer
Abstract
The BERT model has shown significant success on various natural language processing tasks. However, due to its large model size and high computational cost, the model suffers from high latency, which is a critical obstacle to deploying it on resource-limited devices. To tackle this problem, we propose a dynamic inference method for BERT that uses trainable gate variables applied to input tokens and a regularizer with a bi-modal property. Our method reduces computational cost on the GLUE benchmark with a minimal performance drop. Moreover, a user-specified hyperparameter lets the model adjust the trade-off between performance and computational cost.
1 Introduction
BERT Devlin et al. (2018), a large-scale pre-trained language model, has shown significant improvements in natural language processing tasks Dai and Le (2015); Rajpurkar et al. (2016); McCann et al. (2017); Peters et al. (2018); Howard and Ruder (2018). However, the model suffers from a large model size and high computational cost, which hinder its use in real-time scenarios on resource-limited devices.

Figure 1: (a) Two-stage method framework. (b) One-stage method framework.

Khetan and Karnin (2020) have shown that, when obtaining a computationally lighter model, a “tall and narrow” architecture provides better performance than a “wide and shallow” one. Inspired by this finding, we propose a task-agnostic method that dynamically learns masks on the input word vectors of each BERT layer during fine-tuning. To do this, we propose a gating module, which we call a mask matching module, that sorts the input tokens and matches each of them to a corresponding learned mask. Note that we use the words “gate” and “mask” interchangeably. Inspired by Srinivas et al. (2017), we train our model with the original task loss, an $\ell_1$-variant regularizer that we propose in this work, and an additional regularizer with a bi-modal property. Using a bi-modal regularizer allows the model to learn a downstream task and search the model architecture simultaneously, without requiring any further fine-tuning stage.
In this paper, we conduct experiments with BERT-base on the GLUE benchmark Wang et al. (2018) and show that the mask matching module and the bi-modal regularizer enable the model to search the architecture and fine-tune on a downstream dataset simultaneously. Compared to previous works on compressing or accelerating the inference of BERT Sanh et al. (2019); Jiao et al. (2019); Sun et al. (2019); Liu et al. (2020), our method has three main differences. First, our method allows task-agnostic dynamic inference rather than producing a single reduced-size model. Second, it does not require any additional stage of fine-tuning or knowledge distillation (KD) Hinton et al. (2015). Lastly, it provides a hyperparameter that the user can specify to control the trade-off between computational complexity and performance.
2 Related Work
Numerous works have sought to compress and accelerate the inference of BERT. Adopting KD, Sanh et al. (2019) distill a heavy teacher model into a lighter student model. Pruning the pre-trained model is another way to handle the large model size and high computational cost Michel et al. (2019); Gordon et al. (2020). Sajjad et al. (2020) prune BERT by dropping unnecessary blocks, and Goyal et al. (2020) do so by dropping semantically redundant word vectors. Other works have introduced dynamic inference to accelerate BERT: Xin et al. (2020) allow early exiting, and Liu et al. (2020) adjust the number of executed blocks dynamically.
Our work is mainly inspired by Goyal et al. (2020), and we integrate dynamic inference into sequence pruning. The main difference between our method and theirs, as well as other pruning methods Michel et al. (2019); Gordon et al. (2020); Sajjad et al. (2020), is that our model allows task-agnostic dynamic inference without requiring additional fine-tuning after the model architecture search, as illustrated in Figure 1 (b). Other works Hou et al. (2020); Goyal et al. (2020); Liu et al. (2020); Fan et al. (2019); Elbayad et al. (2019) dynamically adjust the size and latency of language models. However, these approaches either work in a two-stage setting that requires further fine-tuning or knowledge distillation, as shown in Figure 1 (a), or consider depth-wise rather than width-wise compression. We experimentally show that the computational cost can be reduced with a minimal performance drop on GLUE Wang et al. (2018).
3 Method
In this section, we introduce our proposed method, which mainly consists of a mask matching module and an additional regularizer that induces polarization of the mask variables.
3.1 Mask Matching Module
As presented in Figure 2 (a), the original encoder block of BERT consists of multi-head attention and feed-forward networks. The intuition behind the mask matching module is to filter out input tokens that contribute little to solving a given task, so that the model benefits from a reduced computational burden in multi-head attention. Since multi-head attention on a sequence of length $n$ has $O(n^2)$ computational complexity, we expect to reduce this cost by masking out unnecessary tokens in each encoder block.
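To make the quadratic term concrete, the sketch below counts only the multiply-accumulates of the two sequence-length-dependent matrix products in self-attention, $QK^\top$ and the attention-weighted sum over $V$, each costing $n^2 d$. It ignores the projections and feed-forward layers, whose cost grows only linearly in $n$, so end-to-end savings from dropping tokens are smaller than the attention-only ratio shown here. The sequence length of 128 is an assumed example value; $d = 768$ is the hidden size of BERT-base.

```python
def attention_matmul_flops(n: int, d: int) -> int:
    """Multiply-accumulates of Q K^T (n x d times d x n) and A V (n x n times n x d).

    Each product costs n * n * d; projections and the FFN are deliberately excluded.
    """
    return 2 * n * n * d


def relative_attention_cost(kept: int, n: int, d: int) -> float:
    """Fraction of the quadratic attention cost left when only `kept` of `n` tokens remain."""
    return attention_matmul_flops(kept, d) / attention_matmul_flops(n, d)


if __name__ == "__main__":
    n, d = 128, 768  # n is an assumed sequence length; d is the BERT-base hidden size
    for kept in (128, 96, 64, 32):
        print(kept, relative_attention_cost(kept, n, d))
```

Halving the number of kept tokens therefore quarters this part of the cost, which is why filtering tokens block by block is attractive.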
To learn which tokens are important during training, we introduce the mask matching module, which is placed before the original encoder block of BERT, as shown in Figure 2 (b). Figure 2 (c) shows the detailed process of the mask matching module. The superscript $l$ denotes the $l$-th block; we omit it from the rest of this section. The module sorts the input tokens according to their importance scores, matches each input token to a mask, and thresholds the masked tokens.
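We first compute an importance score $s_i$ for each token of the input sequence from $X \in \mathbb{R}^{n \times d}$, the matrix representation of the input, where $n$ is the length of the input sequence and $d$ is the size of the hidden dimension. Since each token has a corresponding score, we sort the rows of $X$ along the sequence dimension according to the importance scores. The sorted input matrix and the sorted masks are then multiplied element-wise to match each token to a mask, yielding the mask-matched input matrix $\hat{X}$:

$$\hat{X} = \operatorname{sort}(X; s) \odot \operatorname{sort}\big(\sigma(m)\big) \qquad (1)$$

where $\sigma$ is the sigmoid function and $m \in \mathbb{R}^{n}$ is a learnable parameter. Note that since $\sigma(m) \in \mathbb{R}^{n}$ and $X \in \mathbb{R}^{n \times d}$, we expand $\sigma(m)$ to the shape of $X$ by repeating it $d$ times and stacking the copies.
Then, we introduce a thresholding scheme on the masked tokens as follows:

$$\tilde{x}_i = \begin{cases} \hat{x}_i & \text{if } \sigma(m_i) \geq \tau \\ \mathbf{0} & \text{otherwise} \end{cases} \qquad (2)$$

where $\tau$ is a hyperparameter, $\hat{x}_i$ is the $i$-th row of $\hat{X}$, and $\sigma(m_i)$ is the learned mask value matched to token $i$. The thresholded output $\tilde{X}$ is unsorted back into the original order of the input tokens and passed to the subsequent encoder block as its input. The final output of the mask matching module is therefore:

$$Y = \operatorname{unsort}(\tilde{X}) \qquad (3)$$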
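A minimal, framework-free sketch of Eq. (1) to (3) for a single block, operating on plain Python lists. Two choices are our assumptions rather than details given above: the importance score is taken as the L1 norm of each token vector (a hypothetical stand-in for the paper's score), and both tokens and masks are sorted in descending order so that the most important token is paired with the largest mask. The [CLS] exemption and batching are omitted.

```python
import math
from typing import List, Sequence

Matrix = List[List[float]]


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def importance_scores(x: Matrix) -> List[float]:
    # HYPOTHETICAL score: L1 norm of each token vector (not the paper's formula).
    return [sum(abs(v) for v in row) for row in x]


def mask_matching(x: Matrix, m: Sequence[float], tau: float) -> Matrix:
    """Sketch of Eq. (1)-(3) for one block: x is n x d, m holds n raw mask parameters."""
    n = len(x)
    if len(m) != n:
        raise ValueError("expected one mask parameter per token position")
    scores = importance_scores(x)
    # Token indices from most to least important (ASSUMED descending order).
    order = sorted(range(n), key=lambda i: scores[i], reverse=True)
    # Sigmoid mask values, largest first: rank r pairs the r-th most
    # important token with the r-th largest mask value.
    gates = sorted((sigmoid(v) for v in m), reverse=True)
    d = len(x[0]) if n else 0
    out: Matrix = [[0.0] * d for _ in range(n)]
    for rank, tok in enumerate(order):
        g = gates[rank]
        # Eq. (2): keep the token only if its matched mask value reaches tau.
        if g >= tau:
            # Eq. (1): scale the whole row by g (the mask broadcast over d);
            # writing to index `tok` restores the original order, as in Eq. (3).
            out[tok] = [g * v for v in x[tok]]
    return out


if __name__ == "__main__":
    x = [[1.0, 0.0], [3.0, 0.0], [2.0, 0.0]]
    print(mask_matching(x, m=[2.0, 0.0, -2.0], tau=0.4))
```

In the usage example, the first token has the lowest score, so it is paired with the smallest mask value $\sigma(-2) \approx 0.12$ and zeroed, while the other two tokens are kept and scaled by their matched mask values.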
3.2 Inducing Polarization
Traditional $\ell_1$ and $\ell_2$ regularizers do not guarantee well-polarized values for gate (mask) variables. To induce polarization on our masks, we utilize the bi-modal regularizer proposed by Murray and Ng (2010); Srinivas et al. (2017), which drives parameters toward binary values. Srinivas et al. (2017) used an overall regularizer that combines the bi-modal regularizer with a traditional $\ell_1$ or $\ell_2$ regularizer. In this work, we instead use a customized variant of the $\ell_1$ regularizer to dynamically adjust the level of sparsity according to a user-specified hyperparameter.
(4)
where $w$ denotes the filtering weights, and the remaining terms are the mass of the masks and the user-specified mass of the masks for each block of the model.
(5)
where $n$ is the length of the input token sequence, and a hyperparameter sets the user-specified level of token filtering in the model. The polarization regularizer is then written as a linear combination of the bi-modal regularizer and the $\ell_1$-variant regularizer, as follows:
(6)
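Our total objective function is stated as follows:

$$\mathcal{L} = \mathcal{L}_{\text{task}} + \mathcal{R}_{\text{pol}} \qquad (7)$$

where $\mathcal{L}_{\text{task}}$ is the loss for the downstream task and $\mathcal{R}_{\text{pol}}$ is the polarization regularizer of Eq. 6. We show the effect of the bi-modal regularizer in Sec. 4.3.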
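For intuition about the bi-modal property, consider a single gate $g = \sigma(m) \in (0, 1)$ and the penalty $g(1-g)$, which we assume here is the form of the bi-modal term of Murray and Ng (2010) and Srinivas et al. (2017). It vanishes at $g = 0$ and $g = 1$, peaks at $1/4$ when $g = 1/2$, and its derivative $1 - 2g$ is positive below $1/2$ and negative above it, so gradient descent on this term pushes each gate toward the nearer endpoint. The sketch below combines this term with a plain $\ell_1$ term on the gates, i.e. the Srinivas et al.-style combination described above; it deliberately does not implement our $\ell_1$-variant of Eq. 4 and 5, and the weights `lam_bimodal` and `lam_l1` are hypothetical names.

```python
import math
from typing import Sequence


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def bimodal_penalty(gates: Sequence[float]) -> float:
    """Sum of g * (1 - g); each term is 0 only at g = 0 or g = 1."""
    return sum(g * (1.0 - g) for g in gates)


def bimodal_grad(g: float) -> float:
    """Derivative of g * (1 - g): positive below 0.5, negative above it."""
    return 1.0 - 2.0 * g


def srinivas_style_regularizer(m: Sequence[float], lam_bimodal: float, lam_l1: float) -> float:
    """Bi-modal term plus an ordinary L1 term on sigmoid gates.

    NOTE: the L1 part is the plain sum of gates, NOT the paper's L1-variant (Eq. 4-5).
    """
    gates = [sigmoid(v) for v in m]
    l1 = sum(gates)  # gates are positive, so |g| == g
    return lam_bimodal * bimodal_penalty(gates) + lam_l1 * l1
```

A descent step $g \leftarrow g - \eta\,(1 - 2g)$ lowers a gate at $0.3$ and raises a gate at $0.7$, which is the polarizing behavior that plain $\ell_1$ or $\ell_2$ penalties lack.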
Table 1: Results on GLUE-test.
Models | MNLI-(m/mm) | QNLI | QQP | RTE | SST-2 | MRPC | CoLA | STS-B
BERT-base Devlin et al. (2018) | 84.6 / 83.4 | 90.5 | 71.2 | 66.4 | 93.5 | 88.9 | 52.1 | 85.8
BERT-base-ours | 84.5 / 83.7 | 90.7 | 71.8 | 62.4 | 93.9 | 83.7 | 51.2 | 78.9
BERT-base-ours (FLOPs) | 10872M | 10872M | 10872M | 10872M | 10872M | 10872M | 10872M | 10872M
Ours | 82.2 / 81.8 | 87.5 | 69.9 | 58.2 | 92.8 | 84.9 | 33.3 | 79.5
Ours (FLOPs) | 3357M | 3915M | 3766M | 4629M | 3887M | 4417M | 2629M | 3371M
FLOPs reduction (×) | 3.23 | 2.77 | 2.88 | 2.35 | 2.80 | 2.46 | 4.13 | 3.23
Table 2: Results on GLUE-eval for different values of the user-specified hyperparameter.
Models | MNLI-(m/mm) | QNLI | SST-2
BERT-base-ours | 84.3 / 84.9 | 91.7 | 92.5
BERT-base-ours (FLOPs) | 10872M | 10872M | 10872M
Ours | 70.8 / 71.0 | 73.6 | 87.0
Ours (FLOPs) | 1907M | 1872M | 1865M
Ours | 79.0 / 79.0 | 85.5 | 91.3
Ours (FLOPs) | 2904M | 2883M | 2883M
Ours | 82.4 / 82.6 | 88.7 | 91.6
Ours (FLOPs) | 3357M | 3915M | 3887M
Ours | 82.9 / 83.7 | 89.8 | 92.2
Ours (FLOPs) | 4919M | 4926M | 4863M
Ours | 83.1 / 83.7 | 90.4 | 91.6
Ours (FLOPs) | 5923M | 5994M | 5916M
Ours | 83.1 / 83.7 | 90.6 | 91.6
Ours (FLOPs) | 6962M | 7033M | 6676M
Ours | 83.2 / 83.9 | 89.8 | 92.0
Ours (FLOPs) | 9218M | 9027M | 9182M
Ours | 83.8 / 84.2 | 89.8 | 92.2
Ours (FLOPs) | 10829M | 10398M | 9818M
Ours | 83.9 / 84.4 | 91.1 | 92.4
Ours (FLOPs) | 10872M | 10872M | 10872M
Table 3: Ablation of the bi-modal regularizer on GLUE-eval.
Models | MNLI-(m/mm) | QNLI | SST-2
Ours (with bi-modal regularizer) | 82.4 / 82.6 | 88.7 | 91.6
Ours (without bi-modal regularizer) | 60.9 / 61.3 | 67.2 | 76.0
4 Experiments
We evaluate the proposed method on eight datasets of the GLUE benchmark Wang et al. (2018).
4.1 Implementation Details
We fine-tune the pre-trained BERT-base model on eight GLUE datasets for 3 epochs with a batch size of 128. The hidden dimension is set to 768 and the length of the input token sequence is set to . For the remaining details, we follow the original settings of BERT.
We set , , and . We use a separate Adam optimizer Kingma and Ba (2014) for training the mask variables. The Adam optimizer for the mask variables is set with an initial learning rate of with two momentum parameters and , and . Mask variables are initialized with random values drawn from a uniform distribution on the interval [0, 1). We do not introduce mask variables for the very first block of the model. Additionally, we never filter (mask) the first token of each input, the special token [CLS]. Introducing mask variables adds “the length of tokens × the number of blocks” parameters.
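A minimal sketch of this initialization, assuming the uniform draw applies to the raw (pre-sigmoid) mask parameter; under that assumption every gate starts in $[0.5, \sigma(1)) \approx [0.5, 0.731)$, so no mask is initially close to 0 or 1. BERT-base has 12 blocks; the sequence length of 128, the seed, and the dictionary keyed by block index are hypothetical choices for illustration. Because block 0 is skipped, this sketch creates seq_len × (num_blocks − 1) parameters.

```python
import math
import random
from typing import Dict, List


def init_mask_parameters(num_blocks: int, seq_len: int, seed: int = 0) -> Dict[int, List[float]]:
    """Raw mask parameters, one per token position, for every block except
    the first (block 0), each drawn from the uniform distribution on [0, 1)."""
    rng = random.Random(seed)  # rng.random() returns values in [0.0, 1.0)
    return {block: [rng.random() for _ in range(seq_len)] for block in range(1, num_blocks)}


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


if __name__ == "__main__":
    masks = init_mask_parameters(num_blocks=12, seq_len=128)
    gates = [sigmoid(v) for block in masks.values() for v in block]
    print(len(masks), sum(len(b) for b in masks.values()))
    print(min(gates), max(gates))
```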
4.2 Main Results
We compare our model with the BERT-base baseline, and Table 1 summarizes the results. The first row reports the performance taken from Devlin et al. (2018), and the second row reports the performance of our implementation. The last row shows the reduction factor in FLOPs relative to the baseline model. Our dynamic inference method shows minimal degradation on most GLUE datasets, with the largest drop on CoLA, while using on average about 3 times fewer FLOPs. Furthermore, our model works in a task-agnostic manner and learns an architecture for each downstream dataset, instead of producing a single reduced-size model.
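The reduction factors in the last row of Table 1 are the ratio of the baseline FLOPs to ours. The short sketch below recomputes them from the FLOPs reported in the table; the per-task values agree with the table to within 0.01, and their mean is about 2.98, matching the roughly 3 times fewer FLOPs stated above. The dictionary layout is our own.

```python
from statistics import mean

BASELINE_FLOPS_M = 10872  # BERT-base FLOPs (millions), Table 1

# FLOPs (millions) of our models on GLUE-test, copied from Table 1.
OURS_FLOPS_M = {
    "MNLI": 3357, "QNLI": 3915, "QQP": 3766, "RTE": 4629,
    "SST-2": 3887, "MRPC": 4417, "CoLA": 2629, "STS-B": 3371,
}


def reduction_factors(baseline: float, ours: dict) -> dict:
    """Baseline FLOPs divided by our FLOPs, i.e. how many times fewer FLOPs per task."""
    return {task: baseline / flops for task, flops in ours.items()}


if __name__ == "__main__":
    factors = reduction_factors(BASELINE_FLOPS_M, OURS_FLOPS_M)
    for task, f in factors.items():
        print(f"{task}: {f:.2f}x")
    print(f"average: {mean(factors.values()):.2f}x")
```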
Table 2 shows that our model can dynamically adjust its computational cost, trading off FLOPs against performance. FLOPs grow as the value of the user-specified hyperparameter increases, confirming that the hyperparameter controls the computational budget as intended, and the results generally present a consistent trade-off between FLOPs and performance.
4.3 Ablation Studies
To analyze the effect of the bi-modal regularizer, we conduct an ablation study by removing it from the training process. As Table 3 shows, removing the regularizer causes a large performance drop (e.g., from 82.4 to 60.9 on MNLI-m), which suggests that employing it during training plays a key role both in performing well on the downstream task and in searching the model structure with well-polarized mask variables. Further analysis of the behavior of the mask variables with and without the bi-modal regularizer is given in Appendix C.
5 Conclusions
In this work, we explore a task-agnostic dynamic inference method for BERT that works by masking out tokens of the input sequence in each block. To do this, we propose a mask matching module and a variant of the $\ell_1$ regularizer. Our method yields models at various levels of performance and computational complexity, depending on the hyperparameter value specified by the user. Experiments on the GLUE benchmark show that BERT equipped with our method enjoys lighter computation with minimal performance degradation.
References
- Dai and Le (2015) Andrew M Dai and Quoc V Le. 2015. Semi-supervised sequence learning. In Advances in Neural Information Processing Systems, pages 3079–3087.
- Devlin et al. (2018) Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. 2018. BERT: Pre-training of deep bidirectional transformers for language understanding. arXiv preprint arXiv:1810.04805.
- Elbayad et al. (2019) Maha Elbayad, Jiatao Gu, Edouard Grave, and Michael Auli. 2019. Depth-adaptive transformer. arXiv preprint arXiv:1910.10073.
- Fan et al. (2019) Angela Fan, Edouard Grave, and Armand Joulin. 2019. Reducing transformer depth on demand with structured dropout. arXiv preprint arXiv:1909.11556.
- Gordon et al. (2020) Mitchell A Gordon, Kevin Duh, and Nicholas Andrews. 2020. Compressing BERT: Studying the effects of weight pruning on transfer learning. arXiv preprint arXiv:2002.08307.
- Goyal et al. (2020) Saurabh Goyal, Anamitra Roy Choudhury, Saurabh Raje, Venkatesan Chakaravarthy, Yogish Sabharwal, and Ashish Verma. 2020. PoWER-BERT: Accelerating BERT inference via progressive word-vector elimination. In International Conference on Machine Learning, pages 3690–3699. PMLR.
- Hinton et al. (2015) Geoffrey Hinton, Oriol Vinyals, and Jeff Dean. 2015. Distilling the knowledge in a neural network. arXiv preprint arXiv:1503.02531.
- Hou et al. (2020) Lu Hou, Lifeng Shang, Xin Jiang, and Qun Liu. 2020. DynaBERT: Dynamic BERT with adaptive width and depth. arXiv preprint arXiv:2004.04037.
- Howard and Ruder (2018) Jeremy Howard and Sebastian Ruder. 2018. Universal language model fine-tuning for text classification. arXiv preprint arXiv:1801.06146.
- Jiao et al. (2019) Xiaoqi Jiao, Yichun Yin, Lifeng Shang, Xin Jiang, Xiao Chen, Linlin Li, Fang Wang, and Qun Liu. 2019. TinyBERT: Distilling BERT for natural language understanding. arXiv preprint arXiv:1909.10351.
- Khetan and Karnin (2020) Ashish Khetan and Zohar Karnin. 2020. schuBERT: Optimizing elements of BERT. arXiv preprint arXiv:2005.06628.
- Kingma and Ba (2014) Diederik P Kingma and Jimmy Ba. 2014. Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980.
- Liu et al. (2020) Weijie Liu, Peng Zhou, Zhe Zhao, Zhiruo Wang, Haotang Deng, and Qi Ju. 2020. FastBERT: a self-distilling BERT with adaptive inference time. arXiv preprint arXiv:2004.02178.
- McCann et al. (2017) Bryan McCann, James Bradbury, Caiming Xiong, and Richard Socher. 2017. Learned in translation: Contextualized word vectors. In Advances in Neural Information Processing Systems, pages 6294–6305.
- Michel et al. (2019) Paul Michel, Omer Levy, and Graham Neubig. 2019. Are sixteen heads really better than one? arXiv preprint arXiv:1905.10650.
- Murray and Ng (2010) Walter Murray and Kien-Ming Ng. 2010. An algorithm for nonlinear optimization problems with binary variables. Computational Optimization and Applications, 47(2):257–288.
- Peters et al. (2018) Matthew E Peters, Mark Neumann, Mohit Iyyer, Matt Gardner, Christopher Clark, Kenton Lee, and Luke Zettlemoyer. 2018. Deep contextualized word representations. arXiv preprint arXiv:1802.05365.
- Rajpurkar et al. (2016) Pranav Rajpurkar, Jian Zhang, Konstantin Lopyrev, and Percy Liang. 2016. SQuAD: 100,000+ questions for machine comprehension of text. arXiv preprint arXiv:1606.05250.
- Sajjad et al. (2020) Hassan Sajjad, Fahim Dalvi, Nadir Durrani, and Preslav Nakov. 2020. Poor man’s BERT: Smaller and faster transformer models. arXiv preprint arXiv:2004.03844.
- Sanh et al. (2019) Victor Sanh, Lysandre Debut, Julien Chaumond, and Thomas Wolf. 2019. DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter. arXiv preprint arXiv:1910.01108.
- Srinivas et al. (2017) Suraj Srinivas, Akshayvarun Subramanya, and R Venkatesh Babu. 2017. Training sparse neural networks. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition Workshops, pages 138–145.
- Sun et al. (2019) Siqi Sun, Yu Cheng, Zhe Gan, and Jingjing Liu. 2019. Patient knowledge distillation for BERT model compression. arXiv preprint arXiv:1908.09355.
- Wang et al. (2018) Alex Wang, Amanpreet Singh, Julian Michael, Felix Hill, Omer Levy, and Samuel R Bowman. 2018. GLUE: A multi-task benchmark and analysis platform for natural language understanding. arXiv preprint arXiv:1804.07461.
- Xin et al. (2020) Ji Xin, Raphael Tang, Jaejun Lee, Yaoliang Yu, and Jimmy Lin. 2020. DeeBERT: Dynamic early exiting for accelerating BERT inference. arXiv preprint arXiv:2004.12993.

Appendix A Additional Details
A.1 Experimental Details
The outputs of the multi-head attention, the feed-forward network, and layer normalization in Figure 2 must be masked again, because these computations contain bias terms that make masked-out token rows nonzero. Since our goal is to mask out the input matrix of each encoder block token-wise, we apply hard masking, after each of these computations, to the token positions that the mask matching module has masked out.
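Why re-masking is needed can be seen with a single linear layer: a token row that the mask matching module set to zero becomes equal to the bias after $xW + b$, so it must be zeroed again. The sketch below uses plain Python lists with a hypothetical two-dimensional example; the helper names `linear` and `hard_mask` are ours. The same reasoning applies to the shift parameter of layer normalization, which maps an all-zero row to that shift.

```python
from typing import List, Sequence

Matrix = List[List[float]]


def linear(x: Matrix, weight: Matrix, bias: Sequence[float]) -> Matrix:
    """Row-wise affine map y_i = x_i W + b, with W of shape d_in x d_out."""
    return [
        [sum(row[k] * weight[k][j] for k in range(len(row))) + bias[j] for j in range(len(bias))]
        for row in x
    ]


def hard_mask(x: Matrix, keep: Sequence[bool]) -> Matrix:
    """Zero every token row that the mask matching module filtered out."""
    return [list(row) if k else [0.0] * len(row) for row, k in zip(x, keep)]


if __name__ == "__main__":
    x = [[1.0, 2.0], [0.0, 0.0]]  # second token was masked out upstream
    keep = [True, False]
    identity = [[1.0, 0.0], [0.0, 1.0]]
    y = linear(x, identity, bias=[0.5, -0.5])
    print(y)                   # the masked row now equals the bias
    print(hard_mask(y, keep))  # the masked row is zero again
```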
A.2 Reported Measures for GLUE
QQP and MRPC are reported with F1 scores, STS-B is reported with Spearman correlations and other tasks are reported with accuracy.
Appendix B Interpretation of the $\ell_1$-Variant Regularizer
We propose a variant of the $\ell_1$ regularizer, as shown in Eq. 4 and 5. As this regularizer may come across as somewhat heuristic, we explain the intuition and interpretation behind it. Let us consider an extreme case of and . Then, from the last line of Eq. 5, and . This means that input tokens for the block are required more than input tokens for the block of the model, since the former uses more tokens (masks out fewer tokens). As shown in the second line of Eq. 4, w works as a weight for . Instead of applying the same weight to each block, we intend to apply weights according to the number of masks used in each block. In other words, we wish to impose a heavier loss on the block than on the block of the model.
Appendix C Analysis on Mask Variables
Figure 3 shows histograms of the learned values of the mask variables after the sigmoid function, with the mask values on the x-axis. Since we conduct experiments on BERT-base, we show results for each block of the model, ordered from bottom to top in the figure. The figure shows that the bi-modal regularizer not only helps train the mask variables to be well polarized but also plays an important role in learning to perform well on a given task.