Ensemble Distillation Approaches for Grammatical Error Correction
Abstract
Ensemble approaches are commonly used techniques for improving a system by combining multiple model predictions. Additionally, these schemes allow the uncertainty, as well as the source of the uncertainty, to be derived for the prediction. Unfortunately, these benefits come at a computational and memory cost. To address this problem, ensemble distillation (EnD) and, more recently, ensemble distribution distillation (EnDD) have been proposed; both compress the ensemble into a single model, representing either the ensemble average prediction or the prediction distribution, respectively. This paper examines the application of both distillation approaches to a sequence prediction task, grammatical error correction (GEC). This is an important application area for language learning tasks, as it can yield highly useful feedback to the learner. It is, however, more challenging than the standard tasks investigated for distillation, as the prediction of any grammatical correction to a word will be highly dependent on both the input sequence and the generated output history for the word. The performance of both EnD and EnDD is evaluated on publicly available GEC tasks as well as a spoken language task.
Index Terms— Ensemble, Distribution Distillation, Dirichlet, Transformer, Grammatical Error Correction
1 Introduction
Knowledge distillation (KD) is a general approach in machine learning where predictions from a more complex model are modelled by a simpler model, or an ensemble of models is represented by a single model [2]. The aim is to reduce the computational cost and memory requirements for deployment. It has been applied successfully in many domains of deep learning, such as object detection [3], natural language processing [4, 5], acoustic models [6, 7], and also adversarial defence [8]. This paper focuses on applying ensemble distillation approaches to sequence-to-sequence tasks, in this case grammatical error correction (GEC).
Deep ensembles generally perform better than single models, with the added benefit of providing measures of prediction uncertainty [9, 10]. Unfortunately, these benefits come at the cost of higher computational power and memory. To address this problem, ensembles can be distilled into a single model. Standard distillation, however, usually loses the information about the uncertainty in the ensemble's predictions [11]. One approach to maintaining these uncertainty measures after ensemble distillation is Ensemble Distribution Distillation [12]. Here the distilled student models the distribution of the categorical predictions from the ensemble, allowing it to retain more information about its teacher ensemble. Usually a Dirichlet distribution is used as the distribution over categorical distributions. This model can be challenging to train, as the student is required to represent significantly more information from the teacher ensemble than in standard, cross-entropy based, ensemble distillation [12]. If the distilled model is well trained, however, it should efficiently yield good results and enable, for example, detecting whether inputs are in- or out-of-distribution, and yield insight into why a prediction is (un)reliable [12, 11, 13].
This paper examines distillation approaches for sequence-to-sequence tasks such as machine translation and grammatical error correction. For these tasks the inference computational cost is significantly larger than for “static” tasks, as the decoding process must combine the sequence predictions from each member of the ensemble [6, 14]. Distillation schemes already exist that reduce an ensemble to a single model, saving resources while maintaining performance. Distribution distillation methods, however, have not been explored for sequence models [15, 16]. Distribution distillation would enable efficient uncertainty measures for tasks such as speech-to-text and neural machine translation. The application area explored here is GEC [17]. The aim is to extend previous work on uncertainty for spoken language assessment [13, 17] to the problem of providing feedback to the learner. One challenge for general GEC systems is the difference between acceptable grammar for written and spoken English.
The paper first describes standard distillation approaches for sequence-to-sequence models, in particular auto-regressive sequence-to-sequence models. Note that rather than modelling the (approximate) sequence posterior as in [6, 14], this work models the token-level (word) posterior, as this enables a simple extension to distribution distillation for this form of model and can be efficiently used to derive uncertainty measures [15]. The challenges of applying distribution distillation are then discussed, building on the work in [12]. Two possible optimisation approaches for ensemble distribution distillation are described, as well as a scheme combining standard distillation and distribution distillation models. These approaches are all evaluated on grammatical error correction, on both written data, where the training and test data are from the same domain, and speech data, where there is a mismatch. Uncertainty performance is assessed on these tasks and compared to rejection based on the manual, reference, corrections.
2 Ensemble Distribution Distillation
Knowledge and distribution distillation are motivated by the need for high-performance systems with low memory, power and time requirements, which also yield uncertainty measures [11, 12, 14, 15, 18]. There are many ways in which this can be achieved; here we present work on knowledge distillation and distribution distillation. This section focuses on static models, such as those used for image classification [11] and spoken language assessment [13]. A teacher ensemble is represented by a set of parameters $\{\boldsymbol{\theta}^{(m)}\}_{m=1}^{M}$, and the parameters of the standard and distribution distilled students are represented by $\boldsymbol{\phi}$ and $\boldsymbol{\lambda}$, respectively (to signify a fundamental difference). Assuming that the teacher ensemble members are drawn from the posterior $p(\boldsymbol{\theta} \mid \mathcal{D})$, then:

$$\big\{P(y \mid \mathbf{x}; \boldsymbol{\theta}^{(m)})\big\}_{m=1}^{M} = \big\{\boldsymbol{\pi}^{(m)}\big\}_{m=1}^{M}, \qquad \boldsymbol{\theta}^{(m)} \sim p(\boldsymbol{\theta} \mid \mathcal{D})$$
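where $P$ represents the probability, $\mathbf{x}$ is the input and $\boldsymbol{\pi}^{(m)}$ is the categorical distribution predicted by member $m$. The set of categoricals $\{\boldsymbol{\pi}^{(m)}\}_{m=1}^{M}$ can then be used to train new students. In standard distillation [2], a student is trained to emulate the ensemble predictive distribution by minimising:

$$\mathcal{L}(\boldsymbol{\phi}) = \mathrm{KL}\Big[\frac{1}{M}\sum_{m=1}^{M} P(y \mid \mathbf{x}; \boldsymbol{\theta}^{(m)}) \,\Big\|\, P(y \mid \mathbf{x}; \boldsymbol{\phi})\Big]$$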
This method trains the student to predict the mean of the teacher ensemble's predictions, while information about the spread (or disagreement) between members of the ensemble is lost. Thus, although distillation can significantly reduce memory and power consumption, the uncertainty measures that rely on ensemble disagreement are lost.
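To make the loss of spread concrete, here is a minimal Python sketch of the standard distillation target and loss for a single input. The ensembles and the student prediction are hypothetical numbers chosen for illustration, and `ensemble_mean` and `kl_categorical` are helper names introduced here, not code from this work.

```python
import math

def ensemble_mean(members):
    """Average the members' categoricals: the standard distillation target."""
    num_classes = len(members[0])
    return [sum(p[c] for p in members) / len(members) for c in range(num_classes)]

def kl_categorical(p, q):
    """KL[p || q] between categoricals; classes with p_c = 0 contribute zero."""
    return sum(pc * math.log(pc / qc) for pc, qc in zip(p, q) if pc > 0)

# Hypothetical three-member ensembles with the same mean but different spread.
agreeing = [[0.6, 0.4], [0.6, 0.4], [0.6, 0.4]]
disagreeing = [[0.9, 0.1], [0.3, 0.7], [0.6, 0.4]]
student = [0.55, 0.45]  # hypothetical student prediction P(y | x; phi)

loss_agreeing = kl_categorical(ensemble_mean(agreeing), student)
loss_disagreeing = kl_categorical(ensemble_mean(disagreeing), student)
# The two losses are equal: the objective cannot see the members' disagreement.
```

Because the loss depends only on the averaged categorical, any two ensembles with the same mean give the student an identical training signal.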
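Instead, following [12], we let a new student model predict the parameters $\boldsymbol{\alpha}$ of a Dirichlet distribution $p(\boldsymbol{\pi} \mid \mathbf{x}; \boldsymbol{\lambda})$, where $\mathbf{f}$ denotes the student network:

$$p(\boldsymbol{\pi} \mid \mathbf{x}; \boldsymbol{\lambda}) = \mathrm{Dir}(\boldsymbol{\pi}; \boldsymbol{\alpha}), \qquad \boldsymbol{\alpha} = \mathbf{f}(\mathbf{x}; \boldsymbol{\lambda})$$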
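The Dirichlet is a distribution over categoricals and can encapsulate information about the uncertainty and spread in $\{\boldsymbol{\pi}^{(m)}\}_{m=1}^{M}$. Training can then be based on the negative log-likelihood (NLL):

$$\mathcal{L}(\boldsymbol{\lambda}) = -\frac{1}{M}\sum_{m=1}^{M} \ln \mathrm{Dir}\big(\boldsymbol{\pi}^{(m)}; \boldsymbol{\alpha}\big)$$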
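A minimal sketch of the Dirichlet NLL for a single input, using only the standard library. The concentration vectors stand in for the student output $\boldsymbol{\alpha}$; in practice they would be produced by the network, which is not shown here, and all values and helper names are hypothetical.

```python
import math

def dirichlet_log_pdf(pi, alpha):
    """ln Dir(pi; alpha) = ln G(alpha_0) - sum_c ln G(alpha_c) + sum_c (alpha_c - 1) ln pi_c."""
    alpha0 = sum(alpha)
    log_norm = math.lgamma(alpha0) - sum(math.lgamma(a) for a in alpha)
    return log_norm + sum((a - 1.0) * math.log(p) for p, a in zip(pi, alpha))

def dirichlet_nll(members, alpha):
    """Average negative log-likelihood of the members' categoricals under Dir(alpha)."""
    return -sum(dirichlet_log_pdf(pi, alpha) for pi in members) / len(members)

# Hypothetical ensemble predictions for one input (entries must be strictly positive).
members = [[0.70, 0.20, 0.10], [0.65, 0.25, 0.10], [0.75, 0.15, 0.10]]
concentrated = [70.0, 20.0, 10.0]  # mean [0.7, 0.2, 0.1], precision alpha_0 = 100
flat = [1.0, 1.0, 1.0]             # uniform density over the simplex

nll_concentrated = dirichlet_nll(members, concentrated)
nll_flat = dirichlet_nll(members, flat)
```

The concentrated Dirichlet, whose mean matches the ensemble, assigns the members a much higher density than the flat one, so its NLL is lower.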
Now it is possible to quantify uncertainty in the prediction. Three measures of uncertainty, and their associated source, are used in this work: total, expected data and knowledge uncertainty. These are related by [19, 20]:
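$$\underbrace{\mathcal{H}\big[P(y \mid \mathbf{x}; \boldsymbol{\lambda})\big]}_{\text{total (TU)}} = \underbrace{\mathcal{I}\big[y, \boldsymbol{\pi} \mid \mathbf{x}; \boldsymbol{\lambda}\big]}_{\text{knowledge (KU)}} + \underbrace{\mathbb{E}_{p(\boldsymbol{\pi} \mid \mathbf{x}; \boldsymbol{\lambda})}\big[\mathcal{H}[P(y \mid \boldsymbol{\pi})]\big]}_{\text{expected data (DU)}} \tag{1}$$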
These uncertainties have different properties depending on whether the input $\mathbf{x}$ is from the same domain (distribution) as the training set or from a different one. If $\mathbf{x}$ is in-domain (ID), the ensemble members should return consistent predictions, giving low knowledge uncertainty (KU); any remaining uncertainty is expected data uncertainty (DU), reflecting ambiguity in the input itself. If $\mathbf{x}$ is out-of-domain (OOD), the members will generate inconsistent predictions, giving high KU and hence high total uncertainty (TU) [21, 22, 23].
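The decomposition in eq. (1) can also be estimated directly from an ensemble by treating the members' categoricals as samples from $p(\boldsymbol{\pi} \mid \mathbf{x})$: TU is the entropy of the mean, DU is the mean of the member entropies, and KU is their difference. The sketch below assumes natural logarithms and uses hypothetical two-class predictions; `ensemble_uncertainties` is an illustrative name.

```python
import math

def entropy(p):
    """Entropy in nats; classes with zero probability contribute zero."""
    return -sum(pc * math.log(pc) for pc in p if pc > 0)

def ensemble_uncertainties(members):
    """Return (TU, DU, KU) estimated from the members' categoricals for one input."""
    m = len(members)
    mean = [sum(p[c] for p in members) / m for c in range(len(members[0]))]
    total = entropy(mean)                          # TU: entropy of the ensemble mean
    data = sum(entropy(p) for p in members) / m    # DU: average member entropy
    return total, data, total - data               # KU: mutual information

# Consistent but ambiguous members (ID-like) versus disagreeing members (OOD-like).
consistent = ensemble_uncertainties([[0.5, 0.5], [0.5, 0.5]])
inconsistent = ensemble_uncertainties([[1.0, 0.0], [0.0, 1.0]])
```

Both cases have TU = ln 2, but the consistent ensemble attributes it all to DU (KU = 0), whereas the disagreeing ensemble attributes it all to KU (DU = 0).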
3 Sequence Ensemble Distillation
Applying both standard and distribution distillation to sequence models adds another layer of complexity. This section covers token-level distillation schemes, as these yield uncertainties for individual tokens (words), which can also be combined into sequence-level uncertainties [15]. Extending the notation of the previous section to sequence-to-sequence models, the pair $(\mathbf{x}, \mathbf{y})$, with $\mathbf{y} = y_1, \ldots, y_L$, denotes the input-output reference sequence pair, and, when necessary, $\hat{\mathbf{y}}$ represents the corresponding predicted sequence for $\mathbf{x}$. The teacher ensemble now makes predictions by taking expectations over the following ensemble member predictions (for a discussion of alternative approaches to ensemble predictions based on sequences, see [15]):

$$\big\{P(y_l \mid \mathbf{y}_{<l}, \mathbf{x}; \boldsymbol{\theta}^{(m)})\big\}_{m=1}^{M} = \big\{\boldsymbol{\pi}_l^{(m)}\big\}_{m=1}^{M}$$

and students are instead trained on the set $\big\{\boldsymbol{\pi}_l^{(m)}\big\}_{m=1,\,l=1}^{M,\,L}$.
3.1 Ensemble Distillation
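For sequence models, distillation can be performed in multiple ways [14, 24]. This work adopts token-level knowledge distillation, one of the simplest methods. The teacher and student use the same reference back-history $\mathbf{y}_{<l}$ (teacher forcing) and input $\mathbf{x}$. The KL-divergence between the ensemble and student token-level categorical distributions is then minimised:

$$\mathcal{L}(\boldsymbol{\phi}) = \sum_{l=1}^{L} \mathrm{KL}\Big[\frac{1}{M}\sum_{m=1}^{M} P(y_l \mid \mathbf{y}_{<l}, \mathbf{x}; \boldsymbol{\theta}^{(m)}) \,\Big\|\, P(y_l \mid \mathbf{y}_{<l}, \mathbf{x}; \boldsymbol{\phi})\Big] \tag{2}$$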
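Eq. (2) can be sketched for one sentence as a loop over positions, assuming the ensemble and student categoricals at each position have already been computed with the same reference back-history. The nested-list layout (`ensemble_steps[l][m]`) and the numbers are assumptions of this sketch.

```python
import math

def kl_categorical(p, q):
    """KL[p || q]; classes with p_c = 0 contribute zero."""
    return sum(pc * math.log(pc / qc) for pc, qc in zip(p, q) if pc > 0)

def token_level_kd_loss(ensemble_steps, student_steps):
    """Sum over positions l of KL[mean_m P(y_l | y_<l, x; theta_m) || P(y_l | y_<l, x; phi)].
    ensemble_steps[l][m] and student_steps[l] must both be computed with the
    same reference back-history y_<l (teacher forcing)."""
    loss = 0.0
    for members, student in zip(ensemble_steps, student_steps):
        num_members = len(members)
        target = [sum(p[c] for p in members) / num_members for c in range(len(student))]
        loss += kl_categorical(target, student)
    return loss

# Hypothetical two-member ensemble over a two-word vocabulary, for two positions.
ensemble_steps = [[[0.8, 0.2], [0.6, 0.4]], [[0.1, 0.9], [0.3, 0.7]]]
matched_student = [[0.7, 0.3], [0.2, 0.8]]   # equals the ensemble mean at each step
other_student = [[0.5, 0.5], [0.5, 0.5]]
```

A student that reproduces the ensemble mean at every position reaches zero loss; any other student is penalised position by position.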
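Extending the distribution distillation described in the previous section, token-level distribution distillation is now introduced for sequence models. The student with parameters $\boldsymbol{\lambda}$ predicts a Dirichlet distribution with parameters $\boldsymbol{\alpha}_l$ for the $l$-th token:

$$p(\boldsymbol{\pi}_l \mid \mathbf{y}_{<l}, \mathbf{x}; \boldsymbol{\lambda}) = \mathrm{Dir}(\boldsymbol{\pi}_l; \boldsymbol{\alpha}_l), \qquad \boldsymbol{\alpha}_l = \mathbf{f}(\mathbf{y}_{<l}, \mathbf{x}; \boldsymbol{\lambda})$$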
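Given $\{\boldsymbol{\pi}_l^{(m)}\}_{m=1}^{M}$ from the sequence ensemble, the distribution distilled model can be trained using the negative log-likelihood (NLL):

$$\mathcal{L}(\boldsymbol{\lambda}) = -\frac{1}{M}\sum_{l=1}^{L}\sum_{m=1}^{M} \ln \mathrm{Dir}\big(\boldsymbol{\pi}_l^{(m)}; \boldsymbol{\alpha}_l\big) \tag{3}$$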
Eq. (3) optimises the parameters of the distribution distilled model directly from the ensemble predictions at each time instance. Though this yields an appropriate criterion for finding $\boldsymbol{\lambda}$, predicting the distribution over the ensemble members for all back-histories, the network parameters may be very challenging to optimise. To simplify this optimisation, a two-stage approach may be adopted. First, the ensemble predictions for each back-history are modelled by a Dirichlet distribution with parameters $\hat{\boldsymbol{\beta}}_l$, where

$$\hat{\boldsymbol{\beta}}_l = \arg\max_{\boldsymbol{\beta}} \sum_{m=1}^{M} \ln \mathrm{Dir}\big(\boldsymbol{\pi}_l^{(m)}; \boldsymbol{\beta}\big)$$
The distribution distillation parameters are now trained to minimize the KL-divergence between this ensemble Dirichlet and the predicted distilled Dirichlet:
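$$\mathcal{L}(\boldsymbol{\lambda}) = \sum_{l=1}^{L} \mathrm{KL}\Big[\mathrm{Dir}\big(\boldsymbol{\pi}_l; \hat{\boldsymbol{\beta}}_l\big) \,\Big\|\, \mathrm{Dir}\big(\boldsymbol{\pi}_l; \boldsymbol{\alpha}_l\big)\Big] \tag{4}$$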
This has the same general form as eq. (2) but is based on the KL-divergence between distributions of distributions, rather than just the expected posterior distribution based on the ensemble.
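The two-stage criterion of eq. (4) needs an ensemble Dirichlet $\hat{\boldsymbol{\beta}}_l$ for each token and a Dirichlet KL-divergence. The sketch below substitutes a method-of-moments estimate for the maximum-likelihood fit (a simpler, commonly used estimate, not necessarily what this work used), matching the members' mean and pooled variance, and then evaluates the closed-form KL. The standard library has no digamma function, so a small series approximation is included; all helper names and values are hypothetical.

```python
import math

def digamma(x):
    """psi(x) for x > 0: shift with psi(x) = psi(x + 1) - 1/x, then an asymptotic series."""
    result = 0.0
    while x < 10.0:
        result -= 1.0 / x
        x += 1.0
    inv2 = 1.0 / (x * x)
    return result + math.log(x) - 0.5 / x - inv2 * (1.0 / 12 - inv2 * (1.0 / 120 - inv2 / 252))

def fit_dirichlet_moments(members):
    """Method-of-moments Dirichlet fit: Var(pi_c) = mu_c (1 - mu_c) / (beta_0 + 1), pooled over c."""
    m, k = len(members), len(members[0])
    mean = [sum(p[c] for p in members) / m for c in range(k)]
    total_var = sum(sum((p[c] - mean[c]) ** 2 for p in members) / m for c in range(k))
    if total_var == 0.0:
        raise ValueError("members agree exactly; the precision is unbounded")
    precision = sum(mu * (1.0 - mu) for mu in mean) / total_var - 1.0
    if precision <= 0.0:
        raise ValueError("spread too large for a Dirichlet moment fit")
    return [precision * mu for mu in mean]

def kl_dirichlet(beta, alpha):
    """Closed-form KL[Dir(beta) || Dir(alpha)]."""
    b0, a0 = sum(beta), sum(alpha)
    value = math.lgamma(b0) - math.lgamma(a0)
    value += sum(math.lgamma(a) - math.lgamma(b) for a, b in zip(alpha, beta))
    value += sum((b - a) * (digamma(b) - digamma(b0)) for a, b in zip(alpha, beta))
    return value

# Hypothetical ensemble categoricals for one token and a hypothetical student output.
members = [[0.70, 0.20, 0.10], [0.65, 0.25, 0.10], [0.75, 0.15, 0.10]]
beta_hat = fit_dirichlet_moments(members)
student_alpha = [60.0, 25.0, 15.0]
loss = kl_dirichlet(beta_hat, student_alpha)
```

Summing `kl_dirichlet` over positions gives the per-sentence loss of eq. (4); the targets $\hat{\boldsymbol{\beta}}_l$ can be computed once from the ensemble, before the student is trained.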
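Once a distribution distilled sequence model has been obtained, the probability of predicting class $c$ is then:

$$P(y_l = c \mid \mathbf{y}_{<l}, \mathbf{x}; \boldsymbol{\lambda}) = \mathbb{E}_{\mathrm{Dir}(\boldsymbol{\pi}_l; \boldsymbol{\alpha}_l)}\big[\pi_{l,c}\big] = \frac{\alpha_{l,c}}{\alpha_{l,0}}, \qquad \alpha_{l,0} = \sum_{c'} \alpha_{l,c'}$$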
as would be expected when the student parametrises a Dirichlet. This shows that sequence-based distribution distillation is a straightforward generalisation of both ensemble distribution distillation and sequence distillation.
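For a Dirichlet student, each term of eq. (1) has a closed form at every token: TU is the entropy of $\boldsymbol{\alpha}_l / \alpha_{l,0}$, and DU is $-\sum_c (\alpha_c/\alpha_0)\,[\psi(\alpha_c + 1) - \psi(\alpha_0 + 1)]$, where $\psi$ is the digamma function. A minimal sketch with a small digamma approximation included so that the block is self-contained; the function names and values are hypothetical.

```python
import math

def digamma(x):
    """psi(x) for x > 0: shift with psi(x) = psi(x + 1) - 1/x, then an asymptotic series."""
    result = 0.0
    while x < 10.0:
        result -= 1.0 / x
        x += 1.0
    inv2 = 1.0 / (x * x)
    return result + math.log(x) - 0.5 / x - inv2 * (1.0 / 12 - inv2 * (1.0 / 120 - inv2 / 252))

def dirichlet_uncertainties(alpha):
    """Expected categorical and token-level (TU, DU, KU) for Dir(alpha)."""
    alpha0 = sum(alpha)
    probs = [a / alpha0 for a in alpha]  # P(y = c) = alpha_c / alpha_0
    total = -sum(p * math.log(p) for p in probs if p > 0)
    # E_Dir[H(pi)] = -sum_c (alpha_c / alpha_0) * (psi(alpha_c + 1) - psi(alpha_0 + 1))
    data = -sum(p * (digamma(a + 1.0) - digamma(alpha0 + 1.0)) for p, a in zip(probs, alpha))
    return probs, total, data, total - data

# Same expected prediction [0.5, 0.5], different concentration.
_, tu_flat, du_flat, ku_flat = dirichlet_uncertainties([1.0, 1.0])
_, tu_sharp, du_sharp, ku_sharp = dirichlet_uncertainties([100.0, 100.0])
```

Both Dirichlets predict [0.5, 0.5], so their TU is identical, but the concentrated one (all members agreeing on an ambiguous prediction) has far lower KU.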
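Both eq. (3) and eq. (4) yield viable objectives for obtaining a distribution distilled sequence model $\boldsymbol{\lambda}$, from which uncertainty measures can be derived. Instead of calculating entropy over the classes as in eq. (1), uncertainties for sequence models have to enumerate all possible sequences $\mathbf{y}$. Let $\boldsymbol{\pi}_{1:L} = \{\boldsymbol{\pi}_1, \ldots, \boldsymbol{\pi}_L\}$ be the sequence of categorical distributions from which the sequence $\mathbf{y}$ is generated:

$$P(\mathbf{y} \mid \boldsymbol{\pi}_{1:L}) = \prod_{l=1}^{L} P(y_l \mid \boldsymbol{\pi}_l)$$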
The sequence uncertainties for the student can be expressed as:
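$$\underbrace{\mathcal{H}\big[P(\mathbf{y} \mid \mathbf{x}; \boldsymbol{\lambda})\big]}_{\text{TU}} = \underbrace{\mathcal{I}\big[\mathbf{y}, \boldsymbol{\pi}_{1:L} \mid \mathbf{x}; \boldsymbol{\lambda}\big]}_{\text{KU}} + \underbrace{\mathbb{E}_{p(\boldsymbol{\pi}_{1:L} \mid \mathbf{x}; \boldsymbol{\lambda})}\big[\mathcal{H}[P(\mathbf{y} \mid \boldsymbol{\pi}_{1:L})]\big]}_{\text{DU}} \tag{5}$$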
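However, as noted in [15], calculating these uncertainties requires the intractable enumeration of all possible sequences $\mathbf{y}$. Instead, the same approximations as in [15] are used here. Given $S$ sampled sequences $\{\hat{\mathbf{y}}^{(s)}\}_{s=1}^{S}$ with lengths $L^{(s)}$, the following Monte-Carlo approximations can be made:

$$\begin{aligned}
\mathcal{H}\big[P(\mathbf{y} \mid \mathbf{x}; \boldsymbol{\lambda})\big] &\approx \frac{1}{S}\sum_{s=1}^{S}\sum_{l=1}^{L^{(s)}} \mathcal{H}\big[P(y_l \mid \hat{\mathbf{y}}^{(s)}_{<l}, \mathbf{x}; \boldsymbol{\lambda})\big] \\
\mathbb{E}_{p(\boldsymbol{\pi}_{1:L} \mid \mathbf{x}; \boldsymbol{\lambda})}\big[\mathcal{H}[P(\mathbf{y} \mid \boldsymbol{\pi}_{1:L})]\big] &\approx \frac{1}{S}\sum_{s=1}^{S}\sum_{l=1}^{L^{(s)}} \mathbb{E}_{p(\boldsymbol{\pi}_l \mid \hat{\mathbf{y}}^{(s)}_{<l}, \mathbf{x}; \boldsymbol{\lambda})}\big[\mathcal{H}[P(y_l \mid \boldsymbol{\pi}_l)]\big]
\end{aligned} \tag{6}$$

$$\mathcal{I}\big[\mathbf{y}, \boldsymbol{\pi}_{1:L} \mid \mathbf{x}; \boldsymbol{\lambda}\big] \approx \frac{1}{S}\sum_{s=1}^{S}\sum_{l=1}^{L^{(s)}} \mathcal{I}\big[y_l, \boldsymbol{\pi}_l \mid \hat{\mathbf{y}}^{(s)}_{<l}, \mathbf{x}; \boldsymbol{\lambda}\big] \tag{7}$$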
The sequence-level uncertainties, or rates, in [15] may be obtained by normalising the quantities in eqs. (6) and (7) by the sequence length.
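The Monte-Carlo approximations in eqs. (6) and (7) sum token-level terms along each sampled sequence and average over samples. In this sketch the token-level terms are estimated from ensemble member categoricals rather than from a student Dirichlet, the input layout (`samples[s][l][m]`) is an assumption, and the length normalisation divides each sample by its own length before averaging, which is one reading of the rates in [15].

```python
import math

def entropy(p):
    """Entropy in nats; classes with zero probability contribute zero."""
    return -sum(pc * math.log(pc) for pc in p if pc > 0)

def token_terms(members):
    """Token-level (TU, DU) from the members' categoricals at one position."""
    m = len(members)
    mean = [sum(p[c] for p in members) / m for c in range(len(members[0]))]
    return entropy(mean), sum(entropy(p) for p in members) / m

def sequence_uncertainties(samples, normalise=False):
    """samples[s][l][m]: member m's categorical at position l of sampled sequence s,
    computed with that sample's own back-history. Returns (TU, DU, KU)."""
    tu = du = 0.0
    for sample in samples:
        seq_tu = seq_du = 0.0
        for members in sample:
            t, d = token_terms(members)
            seq_tu += t
            seq_du += d
        if normalise:  # per-token rates
            seq_tu /= len(sample)
            seq_du /= len(sample)
        tu += seq_tu
        du += seq_du
    tu /= len(samples)
    du /= len(samples)
    return tu, du, tu - du  # KU follows as the difference, as in eq. (7)
```

For a Dirichlet student, `token_terms` would be replaced by the closed-form token-level TU and DU of the student; the aggregation over positions and samples is unchanged.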
3.2 Guided Uncertainty Approach
Training distribution distilled models can be significantly harder than standard distillation. Hence, a two-model approach is also explored. This is based on the observation that the distribution distilled model tends to have a high Spearman's rank correlation with the ensemble-predicted uncertainties in the teacher-forcing mode used in training. In contrast, when the model is evaluated in free-run decoding, so that predictions are based on its own back-history $\hat{\mathbf{y}}_{<l}$, the same consistency between the model and the ensemble was not observed. To address this, the distribution distilled model was fed the back-history from the standard distilled model, as the standard distilled model was found to be less sensitive to the mismatch between teacher forcing and free running.
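Assuming we have a distilled model $\boldsymbol{\phi}$ obtained from eq. (2), and a distribution distilled model $\boldsymbol{\lambda}$ obtained from either eq. (3) or (4), one can then perform free-run decoding according to:

$$\hat{y}_l = \arg\max_{y} P(y \mid \hat{\mathbf{y}}_{<l}, \mathbf{x}; \boldsymbol{\phi}), \qquad \boldsymbol{\alpha}_l = \mathbf{f}(\hat{\mathbf{y}}_{<l}, \mathbf{x}; \boldsymbol{\lambda})$$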
The distilled model is used to predict the output sequence, which then guides the second model to return higher-quality uncertainties that can be derived from $\boldsymbol{\alpha}_l$. Although this method (referred to as the guided uncertainty approach, GUA) is not as efficient as standard distribution distillation, since two models must be run at test time, it ensures that the best attributes of $\boldsymbol{\phi}$ and $\boldsymbol{\lambda}$ are maintained in testing.
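The guided uncertainty approach can be sketched as a greedy decoding loop in which only the distilled model $\boldsymbol{\phi}$ chooses tokens, while the distribution distilled model $\boldsymbol{\lambda}$ is queried on the same back-history. The two step functions are hypothetical stand-ins for the trained networks, returning a token-to-probability dictionary and a list of Dirichlet parameters, respectively.

```python
def guided_decode(distilled_step, dirichlet_step, x, eos, max_len=50):
    """Greedy (beam of one) decoding with the guided uncertainty approach.
    distilled_step(x, history) -> {token: probability}   (stand-in for phi)
    dirichlet_step(x, history) -> Dirichlet parameters   (stand-in for lambda)"""
    history, alphas = [], []
    for _ in range(max_len):
        probs = distilled_step(x, history)
        alphas.append(dirichlet_step(x, history))  # lambda is fed phi's back-history
        token = max(probs, key=probs.get)           # phi alone chooses the token
        history.append(token)
        if token == eos:
            break
    return history, alphas

# Hypothetical toy models; a real system would call the two trained transformers.
def toy_distilled(x, history):
    table = [{"I": 0.6, "me": 0.3, "</s>": 0.1}, {"am": 0.7, "is": 0.2, "</s>": 0.1}]
    return table[len(history)] if len(history) < len(table) else {"</s>": 0.9, "am": 0.1}

seen_histories = []
def toy_dirichlet(x, history):
    seen_histories.append(list(history))  # record what lambda was conditioned on
    return [1.0 + len(history), 2.0, 3.0]

tokens, alphas = guided_decode(toy_distilled, toy_dirichlet, x="me am", eos="</s>")
```

Each entry of `alphas` can then be turned into token-level uncertainties with the closed-form Dirichlet expressions, and aggregated to the sequence level.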
4 Experimental Evaluation
4.1 Data and Experimental Setup
The data used in this paper are the same as those described in [25]; they have been manually annotated with grammatical errors and corrections. The training set and the FCE test set have been taken from the Cambridge Learner Corpus [26], which contains written exams of candidates with different L1s. The FCE test set is a held-out subset of the corpus and is therefore in-domain (ID) with respect to the training data. NICT-JLE [27] is a public speech corpus of interviews with non-native speakers, in this case Japanese learners of English. Only manual transcriptions of these interviews, along with the grammatical corrections, are available; no audio is available. BULATS [28, 29] is a corpus based on a free-speaking business English test, in which candidates were prompted for responses of up to 1 minute. Both manual (BLT$_\text{man}$) and ASR (BLT$_\text{asr}$) transcriptions were used; the average ASR WER for this data was 19.5% (see [30] for details). The candidates are drawn from 6 L1s and span the full range of proficiencies in English. The NICT and BULATS sets are out-of-domain (OOD) and are used to test the performance of any uncertainties derived from distribution distilled models and ensembles (see Table 1).
Table 1: Data sets used in this work.

| Set | Train | FCE | NICT | BLT$_\text{man}$ | BLT$_\text{asr}$ |
|---|---|---|---|---|---|
| # of sentences | 1.8M | 2.7K | 21.1K | 3.7K | 3.6K |
| Avg. length | 13.4 | 14.0 | 6.6 | 16.6 | 16.7 |
| Domain | Ref. | ID | OOD | OOD | OOD |
All models used in this work are transformers based on [31], with the default parameters of 6 layers, $d_{\text{model}} = 512$, $d_{\text{ff}} = 2048$, 8 heads and 10% dropout. Input words were mapped to randomly initialised embeddings of dimension 512. Different random seeds were used to generate an ensemble of 5 models [10]. As 5 models are used in the ensemble, the memory requirements increase by a factor of 5 over a single model, and the decoding cost by approximately a factor of 5, depending on the form of beam search used. More advanced ensemble generation methods could be used [19, 20, 32], but these are not key to this work. Additionally, sequence ensembles can make predictions in different ways: expectation of products, and product of expectations [15]. Since [15] found that the product of expectations performs better, it is adopted in this work when evaluating ensembles. Beam search with a beam of one is used, and uncertainty measures are evaluated on this single output.
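The two ways of combining ensemble members differ in where the average is taken. A small sketch that scores one fixed hypothesis, assuming each member's probability for each chosen token has already been computed along the same back-history (a hypothetical input layout):

```python
import math

def product_of_expectations(member_token_probs):
    """Average the members' token probabilities at each position, then multiply."""
    num_members = len(member_token_probs)
    prob = 1.0
    for position in zip(*member_token_probs):
        prob *= sum(position) / num_members
    return prob

def expectation_of_products(member_token_probs):
    """Multiply each member's token probabilities, then average over members."""
    return sum(math.prod(m) for m in member_token_probs) / len(member_token_probs)

# member_token_probs[m][l]: hypothetical P(y_l | y_<l, x; theta_m) for the chosen tokens.
member_token_probs = [[0.9, 0.9], [0.1, 0.1]]
poe = product_of_expectations(member_token_probs)
eop = expectation_of_products(member_token_probs)
```

With one confident and one unconfident member, the product of expectations gives 0.25 while the expectation of products gives 0.41. The product-of-expectations form also matches token-level decoding, since the averaged token distribution can be formed at every step.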
As noted in [12], when training, the target categorical is often concentrated in one of the classes, while the Dirichlet predicted by the student often has a distribution spread over many classes. This implies that the common support between the teacher and student is poor. Optimizing KL-divergence when there is a small non-zero common support between target and prediction can lead to instability, and therefore, temperature annealing will be used to alleviate this [2]. In this work, when training a distribution distilled system, the targets from the ensemble are annealed with a temperature , following a schedule from to . Further reducing the temperature down to resulted in instability during training. To remain consistent, all results concerning ensembles will also be based on the final temperature .
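Annealing an ensemble target with temperature $T$ raises each probability to the power $1/T$ and renormalises, which is equivalent to dividing the member logits by $T$ before the softmax; $T > 1$ spreads mass onto more classes and so enlarges the common support with the student Dirichlet. A minimal sketch; the temperature below is illustrative only and is not the schedule used in this work.

```python
import math

def anneal(probs, temperature):
    """Return p_c^(1/T) / sum_k p_k^(1/T), i.e. softmax(logits / T)."""
    powered = [p ** (1.0 / temperature) for p in probs]
    total = sum(powered)
    return [p / total for p in powered]

def entropy(p):
    """Entropy in nats; classes with zero probability contribute zero."""
    return -sum(pc * math.log(pc) for pc in p if pc > 0)

target = [0.90, 0.05, 0.05]     # hypothetical, sharply peaked member prediction
softened = anneal(target, 2.0)  # illustrative temperature
unchanged = anneal(target, 1.0) # T = 1 leaves the target as it is
```

Lowering $T$ back towards 1 over training recovers the original ensemble targets, so the final student is fitted to the unannealed distribution.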
The GEC performance metric used in this work is GLEU [33], and uncertainty measures are computed at the token level. Furthermore, the distribution distilled models and ensembles are compared when a fraction of samples is rejected (based on some metric) and passed on to an oracle. In these cases, the sentences with the highest sentence-level uncertainty are rejected, and this is compared to simply rejecting sentences based on their length. An additional metric based on knowing the true target (referred to as manual) combines the sentence-level GLEU with the true sequence length $L$; the inclusion of length is important, as sentence-level GLEU does not take length into account even though length has an effect when GLEU is used at a dataset level [34]. System performance is evaluated using the relative area under the curve, defined as:
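$$\mathrm{AUC}_{\mathrm{RR}} = \frac{\mathrm{AUC}_{\text{unc}} - \mathrm{AUC}_{\text{rand}}}{\mathrm{AUC}_{\text{manual}} - \mathrm{AUC}_{\text{rand}}} \tag{8}$$

following the same reasoning as in [35] to simplify comparison between metrics; $\mathrm{AUC}_{\text{unc}}$ is the area for rejection based on the metric being evaluated, and $\mathrm{AUC}_{\text{rand}}$ refers to fully random rejection. Finally, GLEU performance is also reported at 10% rejection.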
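A sketch of eq. (8) under simplifying assumptions: each sentence has a score in [0, 1] as a stand-in for sentence-level GLEU, the corpus score is the mean of the sentence scores (dataset-level GLEU is not a simple mean), a rejected sentence receives the oracle score 1.0, and the manual ranking rejects the lowest-scoring sentences first without the length weighting used in this work. All function names are hypothetical.

```python
def rejection_curve(scores, ranking_key):
    """Mean score after rejecting the k highest-key sentences (k = 0..N),
    with each rejected sentence replaced by an oracle score of 1.0."""
    n = len(scores)
    order = sorted(range(n), key=lambda i: ranking_key[i], reverse=True)
    kept_sum = sum(scores)
    curve = [kept_sum / n]
    for k, i in enumerate(order, start=1):
        kept_sum -= scores[i]
        curve.append((kept_sum + k * 1.0) / n)
    return curve

def area(curve):
    """Trapezoidal area over rejection fractions 0, 1/N, ..., 1."""
    n = len(curve) - 1
    return sum((curve[k] + curve[k + 1]) / 2.0 for k in range(n)) / n

def aucrr(scores, uncertainty):
    n = len(scores)
    mean = sum(scores) / n
    random_curve = [mean + (k / n) * (1.0 - mean) for k in range(n + 1)]  # expected under random rejection
    manual_key = [-s for s in scores]  # oracle ranking: reject the worst outputs first
    auc_unc = area(rejection_curve(scores, uncertainty))
    auc_manual = area(rejection_curve(scores, manual_key))
    auc_rand = area(random_curve)
    return (auc_unc - auc_rand) / (auc_manual - auc_rand)
```

An uncertainty that ranks sentences exactly as the oracle does gives an AUCRR of 1, random rejection gives 0 in expectation, and a ranking worse than random is negative.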
4.2 Results
Table 2 shows the GLEU performance of a range of different models. As expected, the ensemble performs best, showing gains over the individual ensemble members. Distilling the ensemble using eq. (2) also yielded gains over the individual ensemble members, though not to the same level as the original ensemble, but at reduced memory and computational cost.
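Table 2: GLEU of the individual ensemble members (Ind, mean ± standard deviation), the ensemble (Ens.), the standard distilled model (Dist.), the distribution distilled models trained with NLL, eq. (3), and KL, eq. (4), and the guided uncertainty approach (GUA).

| Test set | Ind ± σ | Ens. | Dist. | NLL | KL | GUA |
|---|---|---|---|---|---|---|
| FCE | 69.5 ± 0.11 | 70.6 | 69.9 | 68.0 | 68.8 | 69.9 |
| NICT | 47.2 ± 0.20 | 48.0 | 47.8 | 44.7 | 45.7 | 47.8 |
| BLT$_\text{man}$ | 49.8 ± 0.15 | 50.9 | 50.7 | 48.2 | 48.9 | 50.7 |
| BLT$_\text{asr}$ | 31.3 ± 0.09 | 31.6 | 31.5 | 30.9 | 31.3 | 31.5 |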
The table also shows the performance of the distribution distilled models. The simplified training using KL-divergence, eq. (4), outperforms NLL, eq. (3), though neither outperforms the average individual ensemble member (KL only matches it on BLT$_\text{asr}$). This illustrates the challenges in training sequence ensemble distribution distillation models. As the KL model performed better, it is used as the uncertainty model in GUA, though both the NLL- and KL-based models yielded high Spearman rank correlation coefficients with the ensemble uncertainty measures. The same trend can be seen for both the in-domain data (FCE) and the out-of-domain data.
Table 3 shows the average word-level uncertainties predicted for two data sets: FCE and the most mismatched set, BLT$_\text{asr}$.
Table 3: Average word-level uncertainties.

| Test set | Model | TU | DU | KU |
|---|---|---|---|---|
| FCE | Ensemble | 8.41 | 8.24 | 0.16 |
| FCE | GUA | 8.37 | 8.20 | 0.18 |
| BLT$_\text{asr}$ | Ensemble | 8.85 | 8.67 | 0.18 |
| BLT$_\text{asr}$ | GUA | 8.85 | 8.65 | 0.20 |
GUA behaves similarly to the ensemble. High knowledge uncertainty (KU) is, according to theory, an indication of a sample being OOD, and, as expected, KU is higher for the out-of-domain BLT$_\text{asr}$ test set. Furthermore, all uncertainties for BLT$_\text{asr}$ are higher than for FCE, for both models.
Table 4 shows the relative AUC (AUCRR) for FCE and BLT$_\text{asr}$, as well as for a mix of the two, which assesses whether the uncertainty measures can detect the more challenging OOD speech data. Interestingly, simply using the length of the output sequence (Length) is a strong baseline, as longer sentences tend to have more complex grammatical structure and more opportunity for mistakes. KU performs best at rejecting challenging inputs for both the ensemble and GUA in all conditions except the ensemble on BLT$_\text{asr}$, where TU and DU are slightly better, and KU outperforms simple length selection in all conditions.
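Table 4: AUCRR for rejection based on sequence length and on each uncertainty measure.

| Test set | Model | Length | TU | DU | KU |
|---|---|---|---|---|---|
| FCE | Ensemble | 0.701 | 0.734 | 0.734 | 0.740 |
| FCE | GUA | 0.701 | 0.736 | 0.735 | 0.750 |
| BLT$_\text{asr}$ | Ensemble | 0.895 | 0.914 | 0.914 | 0.909 |
| BLT$_\text{asr}$ | GUA | 0.895 | 0.902 | 0.901 | 0.917 |
| FCE + BLT$_\text{asr}$ | Ensemble | 0.810 | 0.840 | 0.840 | 0.843 |
| FCE + BLT$_\text{asr}$ | GUA | 0.813 | 0.837 | 0.837 | 0.856 |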
Table 5 shows the performance when a fixed percentage, 10%, of the data is “rejected” and manually corrected. Compared with the baseline numbers in Table 2, these results show significant gains from rejecting 10% of the data. Trends similar to those for AUCRR can be seen: KU generally yields the best performance, outperforming the baseline length approach in all conditions. GUA yields performance similar to the original ensemble, but at reduced memory and computational cost.
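Table 5: GLEU after rejecting 10% of the sentences for manual correction, selected by each metric.

| Test set | Model | Length | TU | DU | KU | Manual |
|---|---|---|---|---|---|---|
| FCE | Ens. | 80.5 | 80.8 | 80.8 | 81.0 | 83.4 |
| FCE | GUA | 80.1 | 80.5 | 80.5 | 80.5 | 82.9 |
| BLT$_\text{asr}$ | Ens. | 52.4 | 54.4 | 54.2 | 53.2 | 56.5 |
| BLT$_\text{asr}$ | GUA | 51.9 | 52.6 | 52.5 | 53.4 | 56.3 |
| FCE + BLT$_\text{asr}$ | Ens. | 65.9 | 66.9 | 66.8 | 66.9 | 68.7 |
| FCE + BLT$_\text{asr}$ | GUA | 65.3 | 65.8 | 65.8 | 66.2 | 68.3 |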
5 Conclusion
This work describes the application of ensemble distillation approaches to sequence data. Two forms of distillation are discussed: standard distillation, where the distilled model predicts the ensemble mean; and distribution distillation, where the distribution over the ensemble predictions is modelled. Though more challenging to train, ensemble distribution distillation yields both an ensemble prediction and uncertainty measures associated with the prediction, which allow, for example, more challenging predictions to be manually corrected. The approaches were evaluated on grammatical error correction tasks using data either matched to the written training data or mismatched speech data. Standard distillation was found to work well for sequence distillation, but distribution distillation acting alone did not yield good performance. By combining the predictions from standard distillation with the uncertainty predictions of distribution distillation, both good performance and uncertainty measures could be obtained, with reduced computational and memory costs compared to the original ensemble.
Future work will examine: improved optimisation approaches and approximations for sequence ensemble distribution distillation; and alternative uncertainty criteria described in [15].
References
- [2] G. Hinton, O. Vinyals, & J. Dean, “Distilling the Knowledge in a Neural Network,” Proc. NIPS, Montreal, Canada, 2014.
- [3] G. Chen, W. Choi, X. Yu, T. Han, & M. Chandraker, “Learning Efficient Object Detection Models with Knowledge Distillation,” Proc. NeurIPS, Long Beach, CA, United States, 2017.
- [4] J. Cui, B. Kingsbury, B. Ramabhadran, G. Saon, T. Sercu, K. Audhkhasi, A. Sethy, M. Nussbaum-Thom, & A. Rosenberg, “Knowledge Distillation Across Ensembles of Multilingual Models for Low-resource Languages,” Proc. ICASSP, New Orleans, LA, United States, 2017.
- [5] R. Yu, A. Li, V. I. Morariu, & L. S. Davis, “Visual Relationship Detection with Internal and External Linguistic Knowledge Distillation,” Proc. ICCV, Venice, Italy, 2017.
- [6] J. H. M. Wong & M. J. F. Gales, “Sequence Student-teacher Training of Deep Neural Networks,” Proc. Interspeech, San Francisco, CA, United States, 2016.
- [7] T. Asami, R. Masumura, Y. Yamaguchi, H. Masataki, & Y. Aono, “Domain Adaptation of DNN Acoustic Models using Knowledge Distillation,” Proc. ICASSP, New Orleans, LA, United States, 2017.
- [8] N. Papernot, P. McDaniel, X. Wu, S. Jha, & A. Swami, “Distillation as a Defense to Adversarial Perturbations Against Deep Neural Networks,” Proc. IEEE S&P, San Jose, CA, United States, 2016.
- [9] K. P. Murphy, Machine Learning: A Probabilistic Perspective, The MIT Press, 2012.
- [10] B. Lakshminarayanan, A. Pritzel, & C. Blundell, “Simple and Scalable Predictive Uncertainty Estimation using Deep Ensembles,” Proc. NeurIPS, Long Beach, CA, United States, 2017.
- [11] A. Malinin & M. J. F. Gales, “Predictive Uncertainty Estimation via Prior Networks,” Proc. NeurIPS, Montreal, Canada, 2018.
- [12] A. Malinin, B. Mlodozeniec, & M. J. F. Gales, “Ensemble Distribution Distillation,” Proc. ICLR, 2020.
- [13] X. Wu, K. M. Knill, M. J. F. Gales, & A. Malinin, “Ensemble Approaches for Uncertainty in Spoken Language Assessment,” Proc. Interspeech, Shanghai, China, 2020.
- [14] Y. Kim & A. M. Rush, “Sequence-Level Knowledge Distillation,” Proc. EMNLP, Austin, TX, United States, 2016.
- [15] A. Malinin & M. J. F. Gales, “Uncertainty in Structured Prediction,” arXiv.org, 2002.07650v3, 2020.
- [16] C. Zhou, J. Gu, & G. Neubig, “Understanding Knowledge Distillation in Non-autoregressive Machine Translation,” Proc. ICLR, 2020.
- [17] Z. Yuan & T. Briscoe, “Grammatical Error Correction Using Neural Machine Translation,” Proc. NAACL-HLT, San Diego, CA, United States, 2016.
- [18] R. Pang, T. N. Sainath, R. Prabhavalkar, S. Gupta, Y. Wu, S. Zhang, & C. Chiu, “Compression of End-to-End Models,” Proc. Interspeech, Hyderabad, India, 2018.
- [19] J. M. Hernández-Lobato & R. P. Adams, “Probabilistic Backpropagation for Scalable Learning of Bayesian Neural Networks,” Proc. ICML, Lille, France, 2015.
- [20] Y. Gal & Z. Ghahramani, “Dropout as a Bayesian Approximation: Representing Model Uncertainty in Deep Learning,” Proc. ICML, New York City, NY, United States, 2016.
- [21] L. Smith & Y. Gal, “Understanding Measures of Uncertainty for Adversarial Example Detection,” Proc. UAI, Monterey, CA, United States, 2018.
- [22] Y. Gal, “Uncertainty in Deep Learning,” Ph.D. thesis, University of Cambridge, 2016.
- [23] J. Quiñonero-Candela et al. (eds.), Dataset Shift in Machine Learning, The MIT Press, 2009.
- [24] M. Huang, Y. You, Z. Chen, Y. Qian, & K. Yu, “Knowledge Distillation for Sequence Model,” Proc. Interspeech, Hyderabad, India, 2018.
- [25] Y. Lu, M. J. F. Gales, & Y. Wang, “Spoken Language ‘Grammatical Error Correction’,” Proc. Interspeech, Shanghai, China, 2020.
- [26] D. Nicholls, “The Cambridge Learner Corpus: Error Coding and Analysis for Lexicography and ELT,” Proc. Corpus Linguistics, Lancaster, United Kingdom, 2003.
- [27] E. Izumi, K. Uchimoto, & H. Isahara, “The NICT JLE Corpus: Exploiting the Language Learners’ Speech Database for Research and Education,” Proc. IJCIM, Bangkok, Thailand, 2004.
- [28] L. Chambers & K. Ingham, “The BULATS Online Speaking Test,” Research Notes, vol. 43, pp. 21–25, 2011. [Online]. Available: http://www.cambridgeenglish.org/images/23161-research-notes-43.pdf
- [29] K. M. Knill, M. J. F. Gales, P. P. Manakul, & A. Caines, “Automatic Grammatical Error Detection of Non-native Spoken Learner English,” Proc. ICASSP, Brighton, United Kingdom, 2019.
- [30] Y. Lu, M. J. F. Gales, K. M. Knill, P. P. Manakul, L. Wang, & Y. Wang, “Impact of ASR Performance on Spoken Grammatical Error Detection,” Proc. Interspeech, Graz, Austria, 2019.
- [31] A. Vaswani, N. Shazeer, N. Parmar, J. Uszkoreit, L. Jones, A. N. Gomez, Ł. Kaiser, & I. Polosukhin, “Attention is All You Need,” Proc. NeurIPS, Long Beach, CA, United States, 2017.
- [32] W. Maddox, T. Garipov, P. Izmailov, D. P. Vetrov, & A. G. Wilson, “A Simple Baseline for Bayesian Uncertainty in Deep Learning,” Proc. NeurIPS, Vancouver, Canada, 2019.
- [33] C. Napoles, K. Sakaguchi, M. Post, & J. Tetreault, “Ground Truth for Grammatical Error Correction Metrics,” Proc. ACL, Beijing, China, 2015.
- [34] Y. Wu, M. Schuster, Z. Chen, Q. V. Le, et al., “Google’s Neural Machine Translation System: Bridging the Gap between Human and Machine Translation,” arXiv.org, 1609.08144v2, 2016.
- [35] A. Malinin, A. Ragni, K. M. Knill, & M. J. F. Gales, “Incorporating Uncertainty into Deep Learning for Spoken Language Assessment,” Proc. ACL, Vancouver, Canada, 2017.