Less is More: Task-aware Layer-wise Distillation for Language Model Compression
Abstract
Layer-wise distillation is a powerful tool to compress large models (i.e., teacher models) into small ones (i.e., student models). The student distills knowledge from the teacher by mimicking the hidden representations of the teacher at every intermediate layer. However, layer-wise distillation is difficult: because the student has a smaller model capacity than the teacher, it is often under-fitted. Furthermore, the hidden representations of the teacher contain redundant information that the student does not necessarily need for learning the target task. To address these challenges, we propose a novel Task-aware layEr-wise Distillation (TED). TED designs task-aware filters to align the hidden representations of the student and the teacher at each layer. The filters select the knowledge that is useful for the target task from the hidden representations. As such, TED reduces the knowledge gap between the two models and helps the student fit the target task better. We evaluate TED in two scenarios: continual pre-training and fine-tuning. TED demonstrates significant and consistent improvements over existing distillation methods in both scenarios. Code is available at https://github.com/cliang1453/task-aware-distillation.
1 Introduction
Large pre-trained language models have achieved state-of-the-art performances in many natural language processing tasks (wang2018glue; rajpurkar2016squad). However, their deployment in resource-limited scenarios is hindered by their huge number of parameters (raffel2019exploring; radford2019language; brown2020language; he2020deberta; he2023debertav). Knowledge Distillation (KD) (hinton2015distilling) is a powerful tool to compress large models (i.e., teacher models) into small ones (i.e., student models) with a minimal loss of performance. This approach trains the student to match the output predictions of the teacher.
However, such a last-layer-only distillation approach does not exploit the intermediate layers of the teacher, which contain rich semantic and syntactic knowledge. To leverage such knowledge, researchers have proposed a layer-wise distillation approach, which trains the student to match the hidden representation of the teacher at each layer (sun2019patient; jiao2019tinybert; sun2020mobilebert; hou2020dynabert; zuo2022moebert). Such an approach often improves the generalization performance of the student model.
Nevertheless, layer-wise distillation faces two major challenges. First, the student may struggle to mimic the hidden representations of the teacher due to their large capacity gap, which often leads to large discrepancies between their hidden representations. Consequently, training often favors reducing these discrepancies over minimizing the student's own training loss (i.e., the target task's loss, such as cross-entropy), resulting in an under-fitted student model. Second, mimicking the hidden representations may not be beneficial for learning the target task, because the hidden representations of the teacher often contain redundant information (dalvi2020analyzing; durrani2020analyzing). Given the limited capacity of the student, such redundant information may compete with the useful information during distillation, hindering the useful knowledge from being distilled. Our empirical observations show that for some tasks, layer-wise distillation only marginally outperforms standard KD (Table LABEL:tb:deberta_glue).

To address these challenges, we propose a novel layer-wise distillation method, TED (Task-aware layEr-wise Distillation), which distills task-specific knowledge from the teacher to the student. We design a pair of task-aware filters for each layer of the teacher and the student (for simplicity, we assume that the student and the teacher have the same depth, i.e., the same number of layers, but different widths; the case of different depths is elaborated in Section 2). Each filter is a neural network with a task-specific head (e.g., a linear softmax layer for classification), and is trained to extract the predictive knowledge from the hidden representation of the corresponding model. Figure 1 illustrates the training procedure of TED, which consists of two stages:
Stage I: We train the task-aware filters for both the teacher and the student models, while keeping the model parameters frozen. At each layer, the filter takes the hidden representation of the model as input, and produces a target task’s loss (e.g., cross-entropy) as output. The filter is subsequently optimized based on such a loss to capture the predictive knowledge from the hidden representation.
Stage II: We jointly train the student model and its task-aware filters, while keeping the teacher and its filters fixed. At each layer, we feed the hidden representation of the teacher and the student to their respective filters (without the task-specific heads). Then, we adopt a regularizer that penalizes the discrepancy between the filtered representations. This regularizer encourages the student to learn the task-specific knowledge from the teacher, while ignoring the redundant information.
The task-aware filters serve as a selection mechanism that reduces the knowledge gap between the teacher and the student and encourages the distillation of task-specific knowledge. This makes distillation easier for the student.
We evaluate TED on two settings: continual pre-training and task-specific fine-tuning. In the continual pre-training setting, we distill a -layer GPT-2 student model (M) from a -layer GPT-2 teacher model (M) (radford2019language). We show that TED outperforms existing methods in both zero-shot and transfer learning settings on various downstream tasks (paperno2016lambada; merity2016pointer). In the task-specific fine-tuning setting, we distill a DeBERTaV3-xsmall student model (M) from a DeBERTaV3-base teacher model (M) (he2023debertav). We demonstrate that TED achieves significant improvement on the GLUE benchmark (wang2018glue) and the SQuAD v1.1/2.0 question answering datasets (rajpurkar2016squad; rajpurkar2018know).
The rest of the paper is organized as follows: Section 2 briefly reviews the background; Section 3 presents our proposed method; Section LABEL:sec:lm presents experiments on language modeling; Section LABEL:sec:nlu presents experiments on natural language understanding; Section LABEL:sec:analysis presents analysis of models; and Section LABEL:sec:conclusion discusses and concludes the paper.
2 Background
Transformer-based Language Models. The Transformer architecture is a powerful neural network design for modeling sequential data, such as natural language (vaswani2017attention; devlin2018bert; radford2019language; he2023debertav). It consists of multiple layers that are stacked on top of each other. Each layer performs two operations: a multi-head self-attention mechanism and a two-layer feed-forward neural network. We use $f(\cdot;\theta)$ to denote a Transformer-based model with parameters $\theta$, which takes an input sequence $x$ from the input sample space $\mathcal{X}$ and produces an output prediction $f(x;\theta)$. We define the loss function $\mathcal{L}(\theta) = \mathbb{E}_{x \sim \mathcal{X}}\,[\ell(f(x;\theta))]$, where $\ell$ is the target task loss. For example, $\ell$ is the causal language modeling loss for generative models such as GPT-2.
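To make the notation concrete, the following PyTorch sketch shows one possible instance of the target task loss $\ell$: the causal language modeling (next-token prediction) loss for a generative model. The `model` object and its interface are hypothetical placeholders, not part of the TED codebase.

```python
import torch.nn.functional as F

def causal_lm_loss(model, input_ids):
    """ell for a generative model: next-token prediction cross-entropy."""
    logits = model(input_ids)                      # f(x; theta): [batch, seq_len, vocab]
    shift_logits = logits[:, :-1, :].contiguous()  # position t predicts token t+1
    shift_labels = input_ids[:, 1:].contiguous()
    return F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
    )
```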
Knowledge Distillation is a powerful approach to compress large models (i.e., teacher models) into smaller models (i.e., student models) by transferring knowledge from the former to the latter (hinton2015distilling). The student is trained to mimic the output predictions of the teacher. Specifically, we denote the teacher as $f(\cdot;\theta_T)$ and the student as $f(\cdot;\theta_S)$, and consider the following optimization problem:

$\min_{\theta_S}\ \mathcal{L}(\theta_S) + \mathcal{D}_{\mathrm{pred}}(\theta_T, \theta_S),$  (1)

where $\mathcal{D}_{\mathrm{pred}}$ is the distillation loss, a distance metric between the output predictions of the teacher and the student. For example, $\mathcal{D}_{\mathrm{pred}}$ can be the KL-divergence: $\mathcal{D}_{\mathrm{pred}}(\theta_T, \theta_S) = \mathbb{E}_{x \sim \mathcal{X}}\,[\mathrm{KL}(\mathrm{softmax}(f(x;\theta_T)/\tau)\,\|\,\mathrm{softmax}(f(x;\theta_S)/\tau))]$, where $\tau$ is the temperature that controls the softness of the prediction probability distributions (hinton2015distilling). A commonly adopted distillation scheme is offline distillation, where the teacher is fully trained and fixed, and the student is optimized based on Eq. 1.
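As an illustration, a minimal PyTorch sketch of such a prediction distillation loss is given below, using the temperature-scaled KL-divergence. The default temperature and the $\tau^2$ scaling factor are common conventions rather than choices prescribed by the paper.

```python
import torch.nn.functional as F

def prediction_distillation_loss(student_logits, teacher_logits, temperature=2.0):
    # D_pred: KL divergence between temperature-softened output distributions.
    log_p_student = F.log_softmax(student_logits / temperature, dim=-1)
    p_teacher = F.softmax(teacher_logits / temperature, dim=-1)
    # kl_div expects log-probabilities as the first argument; the tau^2 factor keeps
    # gradient magnitudes comparable across temperatures (a common convention).
    return F.kl_div(log_p_student, p_teacher, reduction="batchmean") * temperature ** 2
```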
Layer-wise Distillation. In large Transformer-based models, the output predictions may not capture all of the semantic and syntactic knowledge encoded in the intermediate layers. Therefore, researchers have proposed a layer-wise distillation approach, which aligns the hidden representations of the student and the teacher at each layer (romero2014fitnets; sun2019patient; sun2020mobilebert; jiao2019tinybert; hou2020dynabert; zuo2022moebert; liang2023homodistil). Specifically, we denote the hidden representation at the $n$-th layer of an $N$-layer student as $H_S^{n} \in \mathbb{R}^{|x| \times d_S}$, and the hidden representation at the $M(n)$-th layer of the teacher as $H_T^{M(n)} \in \mathbb{R}^{|x| \times d_T}$. Here $|x|$ is the sequence length; $d_S$ and $d_T$ are the hidden dimensions of the student and the teacher, respectively; and $M(\cdot)$ is a layer mapping function that determines from which layer in the teacher a student layer should distill. For example, if we set $M(n) = 2n$, the student would distill from every other layer in the teacher. The layer-wise distillation loss is defined as:

$\mathcal{D}_{\mathrm{layer}}(\theta_T, \theta_S) = \mathbb{E}_{x \sim \mathcal{X}}\,\big[\textstyle\sum_{n=1}^{N} \mathrm{MSE}\big(H_T^{M(n)},\ H_S^{n} W^{n}\big)\big].$  (2)

Here $\mathrm{MSE}(\cdot,\cdot)$ is the mean-squared error, $W^{n} \in \mathbb{R}^{d_S \times d_T}$ is a randomly initialized and learnable linear projection that projects $H_S^{n}$ into the same space as $H_T^{M(n)}$, and $n = 1, \dots, N$. In practice, the student is often optimized using multiple distillation losses, e.g.,

$\min_{\theta_S}\ \mathcal{L}(\theta_S) + \alpha_1 \mathcal{D}_{\mathrm{pred}}(\theta_T, \theta_S) + \alpha_2 \mathcal{D}_{\mathrm{layer}}(\theta_T, \theta_S),$  (3)

where $\alpha_1$ and $\alpha_2$ are hyper-parameters. Besides the intermediate layers, distilling knowledge from the attention scores and the embedding layers can also improve the distillation performance (sun2020mobilebert; jiao2019tinybert; wang2020minilm; wang2020minilmv2). Eq. 3 can be further extended by adding such losses.
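A minimal PyTorch sketch of the layer-wise distillation loss in Eq. 2 might look as follows; the module and argument names are illustrative, and `layer_map` stands in for the mapping function $M(\cdot)$.

```python
import torch.nn as nn
import torch.nn.functional as F

class LayerwiseDistillationLoss(nn.Module):
    """D_layer (Eq. 2): MSE between teacher hidden states and linearly projected
    student hidden states at matched layers."""

    def __init__(self, num_student_layers, d_student, d_teacher):
        super().__init__()
        # One randomly initialized, learnable projection W^n per student layer (d_S -> d_T).
        self.projections = nn.ModuleList(
            [nn.Linear(d_student, d_teacher, bias=False) for _ in range(num_student_layers)]
        )

    def forward(self, student_hiddens, teacher_hiddens, layer_map):
        # student_hiddens[n]: [batch, seq, d_S]; teacher_hiddens[m]: [batch, seq, d_T]
        loss = 0.0
        for n, proj in enumerate(self.projections):
            loss = loss + F.mse_loss(proj(student_hiddens[n]),
                                     teacher_hiddens[layer_map(n)])
        return loss
```

The objective in Eq. 3 would then combine the target task loss with the prediction and layer-wise distillation losses, weighted by $\alpha_1$ and $\alpha_2$.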
3 Method
We introduce TED, a two-stage training framework that uses task-aware filters to distill knowledge from a teacher to a student. The task-aware filters are neural networks that learn to extract task-specific knowledge from the hidden representations of the teacher and the student. In the first stage, we add a task-aware filter to each layer of the teacher and the student. We train these filters using the task-specific loss while keeping the model parameters frozen. In the second stage, we fine-tune the student and its filters by minimizing the discrepancy between the filtered representations of the teacher and the student.
3.1 Stage I: Training Task-aware Filters
For a student that contains $N$ layers, we select $N$ corresponding layers from the teacher to match with the student using a layer mapping function $M(\cdot)$, as defined in Section 2. We then equip each of these layers with a task-aware filter to extract the task-specific knowledge from the hidden representation of that layer. Each filter is a neural network with a task-specific head (e.g., a linear softmax layer for classification). It takes in the hidden representation generated by the layer and outputs a prediction for the target task. For example, for a classification task, the filter outputs a probability distribution over the classes.
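As a rough sketch of what such a filter could look like (the class and argument names are hypothetical, and the single-linear extractor with mean pooling is only one of several possible designs; see Remark 3.1):

```python
import torch.nn as nn

class TaskAwareFilter(nn.Module):
    """One task-aware filter g(.; W): an extractor plus a task-specific head."""

    def __init__(self, hidden_dim, filter_dim, num_classes):
        super().__init__()
        self.extractor = nn.Linear(hidden_dim, filter_dim)  # produces the filtered representation
        self.head = nn.Linear(filter_dim, num_classes)      # task-specific head, used in Stage I only

    def forward(self, hidden_states, return_filtered=False):
        filtered = self.extractor(hidden_states)            # [batch, seq, filter_dim]
        if return_filtered:                                  # Stage II: the head is removed/bypassed
            return filtered
        # Stage I (classification-style head): pool over the sequence and predict class logits.
        # For generative tasks, the head would instead be applied at every position.
        return self.head(filtered.mean(dim=1))
```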
For simplicity, we only specify how to train the task-aware filters for the teacher; the student's filters are trained similarly (see Section LABEL:sec:lm for details). To train the task-aware filters, we fix the parameters of the teacher, which is already pre-trained (we discuss in detail how to initialize the teacher and the student models in Sections LABEL:sec:lm and LABEL:sec:nlu). In other words, we only update the parameters of the filters. We denote the task-aware filter at the $M(n)$-th layer of the teacher as $g(\cdot; W_T^{n})$, where $W_T^{n}$ denotes the filter's parameters. The filter takes in the hidden representation $H_T^{M(n)}$ at the $M(n)$-th layer, and outputs a task-specific loss

$\mathcal{L}^{n}_{T}(W_T^{n}) = \mathbb{E}_{x \sim \mathcal{X}}\,\big[\ell\big(g(H_T^{M(n)}; W_T^{n})\big)\big],$  (4)

where $H_T^{M(n)}$ is computed using the teacher's parameters up to the $M(n)$-th layer. The loss function $\ell$ depends on the task and the setting. For example, $\ell$ is the causal language modeling loss for continual pre-training and the cross-entropy loss for fine-tuning on classification tasks. Given the loss in Eq. 4, we train the filters jointly:

$\min_{W_T^{1}, \dots, W_T^{N}}\ \textstyle\sum_{n=1}^{N} \mathcal{L}^{n}_{T}(W_T^{n}),$  (5)

where $N$ is the number of layers in the student. By training the task-aware filters, we can reduce the redundant information in the hidden representations and keep the information that is useful for learning the target task.
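A simplified sketch of Stage I for the teacher is shown below, assuming a classification setting and a HuggingFace-style `output_hidden_states` interface for accessing per-layer hidden states; all names are illustrative.

```python
import torch

def train_teacher_filters(teacher, filters, layer_map, dataloader, task_loss_fn,
                          num_epochs=1, lr=1e-4):
    # Stage I sketch (Eqs. 4-5): the teacher is frozen and only the filter
    # parameters are updated; each filter minimizes the target-task loss computed
    # on the hidden states of its assigned teacher layer.
    for p in teacher.parameters():
        p.requires_grad_(False)
    filter_params = [p for f in filters for p in f.parameters()]
    optimizer = torch.optim.AdamW(filter_params, lr=lr)

    for _ in range(num_epochs):
        for inputs, labels in dataloader:
            with torch.no_grad():
                # Assumed interface exposing the per-layer hidden states of the teacher.
                hiddens = teacher(inputs, output_hidden_states=True).hidden_states
            # Sum the per-layer filter losses (Eq. 5); filters[n] outputs task logits.
            loss = sum(task_loss_fn(filters[n](hiddens[layer_map(n)]), labels)
                       for n in range(len(filters)))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
```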
Remark 3.1.
We can choose different neural network architectures to implement the task-aware filters, such as a simple linear projection that maps the input to a lower-dimensional space, a multi-layer perceptron that applies a sequence of nonlinear transformations, or a stack of Transformer layers that encodes the input with an attention mechanism. We compare the performance of these architectures in Section LABEL:ana:complexity.
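For illustration, the three options could be instantiated in PyTorch roughly as follows (dimensions and hyper-parameters are placeholders, not the settings used in the paper):

```python
import torch.nn as nn

def build_filter_body(kind, hidden_dim, filter_dim, num_heads=8):
    # Illustrative instantiations of the filter architectures discussed in Remark 3.1.
    if kind == "linear":
        return nn.Linear(hidden_dim, filter_dim)
    if kind == "mlp":
        return nn.Sequential(
            nn.Linear(hidden_dim, filter_dim),
            nn.GELU(),
            nn.Linear(filter_dim, filter_dim),
        )
    if kind == "transformer":
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim, nhead=num_heads, batch_first=True)
        return nn.Sequential(
            nn.TransformerEncoder(encoder_layer, num_layers=1),
            nn.Linear(hidden_dim, filter_dim),
        )
    raise ValueError(f"unknown filter architecture: {kind}")
```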
3.2 Stage II: Task-aware Layer-wise Distillation
In Stage II, we remove the task-specific heads from the task-aware filters learned in Stage I. We then freeze the parameters of the teacher and its filters, and fine-tune the student and its filters by minimizing the discrepancy between the filtered representations of the two models at each layer.
Formally, we denote $g(\cdot; W_S^{n})$ as the task-aware filter at the $n$-th layer of the student. Then the task-aware layer-wise distillation loss is defined as

$\mathcal{D}_{\mathrm{TED}}(\theta_T, \theta_S) = \mathbb{E}_{x \sim \mathcal{X}}\,\big[\textstyle\sum_{n=1}^{N} \mathrm{MSE}\big(g(H_T^{M(n)}; W_T^{n}),\ g(H_S^{n}; W_S^{n})\big)\big],$  (6)

which measures the discrepancy between the filtered representations of the teacher and the student. Based on this distillation loss, the training objective for the student and its filters is

$\min_{\theta_S,\ W_S^{1}, \dots, W_S^{N}}\ \mathcal{L}(\theta_S) + \alpha_1 \mathcal{D}_{\mathrm{pred}}(\theta_T, \theta_S) + \alpha_2 \mathcal{D}_{\mathrm{TED}}(\theta_T, \theta_S),$  (7)

where $\mathcal{L}$ is the target task's loss, $\mathcal{D}_{\mathrm{pred}}$ is the prediction distillation loss defined in Eq. 1, and $\alpha_1, \alpha_2$ are hyper-parameters. By using the task-aware filters, Eq. 7 imposes an easier requirement on the student than the conventional layer-wise distillation loss (Eq. 3). That is, Eq. 3 requires the student to match the teacher on the unfiltered hidden representations, regardless of their relevance to the target task.