Less is More: Task-aware Layer-wise Distillation for Language Model Compression

Chen Liang    Simiao Zuo    Qingru Zhang    Pengcheng He    Weizhu Chen    Tuo Zhao
Abstract

Layer-wise distillation is a powerful tool to compress large models (i.e., teacher models) into small ones (i.e., student models). The student distills knowledge from the teacher by mimicking the hidden representations of the teacher at every intermediate layer. However, layer-wise distillation is difficult. Since the student has a smaller model capacity than the teacher, it is often under-fitted. Furthermore, the hidden representations of the teacher contain redundant information that the student does not necessarily need for learning the target task. To address these challenges, we propose a novel Task-aware layEr-wise Distillation (TED). TED designs task-aware filters to align the hidden representations of the student and the teacher at each layer. The filters select the knowledge that is useful for the target task from the hidden representations. As such, TED reduces the knowledge gap between the two models and helps the student fit the target task better. We evaluate TED in two scenarios: continual pre-training and fine-tuning. TED demonstrates significant and consistent improvements over existing distillation methods in both scenarios. Code is available at https://github.com/cliang1453/task-aware-distillation.

Machine Learning, ICML

1 Introduction

Large pre-trained language models have achieved state-of-the-art performances in many natural language processing tasks (wang2018glue; rajpurkar2016squad). However, their deployment in resource-limited scenarios is hindered by their huge number of parameters (raffel2019exploring; radford2019language; brown2020language; he2020deberta; he2023debertav). Knowledge Distillation (KD) (hinton2015distilling) is a powerful tool to compress large models (i.e., teacher models) into small ones (i.e., student models) with a minimal loss of performance. This approach trains the student to match the output predictions of the teacher.

However, such a last-layer-only distillation approach does not exploit the intermediate layers of the teacher, which contain rich semantic and syntactic knowledge. To leverage such knowledge, researchers have proposed a layer-wise distillation approach, which trains the student to match the hidden representation of the teacher at each layer (sun2019patient; jiao2019tinybert; sun2020mobilebert; hou2020dynabert; zuo2022moebert). Such an approach often improves the generalization performance of the student model.

Nevertheless, layer-wise distillation faces two major challenges. First, the student may struggle to mimic the hidden representations of the teacher due to their large capacity gap. This often leads to large discrepancies between their hidden representations. Consequently, model training/optimization often favors reducing such large discrepancies over the training loss of the student (i.e., the target task’s loss such as cross-entropy), resulting in an under-fitted student model. Second, mimicking the hidden representations may not be beneficial for the target task’s learning. This is because the hidden representations of the teacher often contain redundant information (dalvi2020analyzing; durrani2020analyzing). Given the limited capacity of the student, such redundant information may compete with the useful information for distillation, hindering the useful knowledge from being distilled. Our empirical observations show that for some tasks, layer-wise distillation only marginally outperforms standard KD (Table LABEL:tb:deberta_glue).

Figure 1: An illustration of TED’s two-stage training framework. In Stage I (left), we fix the model parameters and only train the filters and task-specific heads based on the target task loss. In Stage II (right), we jointly train the student and its filters by aligning the filter outputs of each pair of the teacher and the student layers.

To address these challenges, we propose a novel layer-wise distillation method, TED (Task-aware layEr-wise Distillation), which distills task-specific knowledge from the teacher to the student. We design a pair of task-aware filters for each layer of the teacher and student. (For simplicity, we assume that the student and the teacher have the same depth, i.e., number of layers, but different widths. The case of different depths is elaborated in Section 2.) Each filter is a neural network with a task-specific head (e.g., a linear soft-max layer for classification), and is trained to extract the predictive knowledge from the hidden representation of the corresponding model. Figure 1 illustrates the training procedure of TED, which consists of two stages:

• Stage I: We train the task-aware filters for both the teacher and the student models, while keeping the model parameters frozen. At each layer, the filter takes the hidden representation of the model as input, and produces a target task’s loss (e.g., cross-entropy) as output. The filter is subsequently optimized based on such a loss to capture the predictive knowledge from the hidden representation.

• Stage II: We jointly train the student model and its task-aware filters, while keeping the teacher and its filters fixed. At each layer, we feed the hidden representation of the teacher and the student to their respective filters (without the task-specific heads). Then, we adopt a regularizer that penalizes the discrepancy between the filtered representations. This regularizer encourages the student to learn the task-specific knowledge from the teacher, while ignoring the redundant information.

The task-aware filters serve as a selection mechanism that reduces the knowledge gap between the teacher and the student and encourages the distillation of task-specific knowledge. This makes distillation easier for the student.

We evaluate TED in two settings: continual pre-training and task-specific fine-tuning. In the continual pre-training setting, we distill a 6-layer GPT-2 student model (82M) from a 12-layer GPT-2 teacher model (125M) (radford2019language). We show that TED outperforms existing methods in both zero-shot and transfer learning settings on various downstream tasks (paperno2016lambada; merity2016pointer). In the task-specific fine-tuning setting, we distill a DeBERTaV3-xsmall student model (70M) from a DeBERTaV3-base teacher model (183M) (he2023debertav). We demonstrate that TED achieves significant improvement on the GLUE benchmark (wang2018glue) and the SQuAD v1.1/2.0 question answering datasets (rajpurkar2016squad; rajpurkar2018know).

The rest of the paper is organized as follows: Section 2 briefly reviews the background; Section 3 presents our proposed method; Section LABEL:sec:lm presents experiments on language modeling; Section LABEL:sec:nlu presents experiments on natural language understanding; Section LABEL:sec:analysis presents analysis of models; and Section LABEL:sec:conclusion discusses and concludes the paper.

2 Background

Transformer-based Language Models. The Transformer architecture is a powerful neural network design for modeling sequential data, such as natural language (vaswani2017attention; devlin2018bert; radford2019language; he2023debertav). It consists of multiple layers that are stacked on top of each other. Each layer contains two components: a multi-head self-attention mechanism and a two-layer feed-forward neural network. We use $f(\cdot;\Theta)$ to denote a Transformer-based model $f$ that has a set of parameters $\Theta$, where $f$ takes an input sequence $x$ from the input sample space $\mathcal{X}$ and produces an output prediction. We define the loss function $\mathcal{L}(\Theta)=\mathbb{E}_{x\sim\mathcal{X}}[\ell(f(x;\Theta))]$, where $\ell$ is the target task loss. For example, $\ell$ is the causal language modeling loss for generative models (i.e., the negative log-likelihood $-\sum_{t=1}^{|x|}\log p(x_{t}|x_{<t};\Theta)$).
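As a small worked illustration of the causal language modeling loss, the sketch below sums the log-probabilities of the observed tokens and negates the result, the usual convention for a quantity that is minimized. The probabilities are hypothetical placeholders, not outputs of any model used in this paper.

```python
import math

def causal_lm_loss(token_probs):
    """Negative log-likelihood of a sequence: -sum_t log p(x_t | x_<t)."""
    return -sum(math.log(p) for p in token_probs)

# Hypothetical probabilities that a model assigns to each observed next token.
token_probs = [0.5, 0.25, 0.8]
loss = causal_lm_loss(token_probs)

# The product of the probabilities is 0.1, so the loss equals log(10).
print(round(loss, 4))
```

A sequence whose tokens all receive probability close to 1 yields a loss close to 0, which is why minimizing this quantity fits the model to the data.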

Knowledge Distillation is a powerful approach to compress large models (i.e., teacher models) into smaller models (i.e., student models) by transferring knowledge from the former to the latter (hinton2015distilling). The student is trained to mimic the output predictions of the teacher. Specifically, we denote the teacher as $f_{t}(\Theta_{t})$ and the student as $f_{s}(\Theta_{s})$ and consider the following optimization problem:

$$\min_{\Theta_{s}}\ \mathcal{L}(\Theta_{s})+\mathcal{D}_{\rm pred}(\Theta_{t},\Theta_{s}), \tag{1}$$

where $\mathcal{D}_{\rm pred}(\Theta_{t},\Theta_{s})$ is the distillation loss, which measures the discrepancy between the output predictions of the teacher and the student. For example, $\mathcal{D}_{\rm pred}$ can be the KL-divergence $\texttt{KL}(f_{t}(\Theta_{t})/T,f_{s}(\Theta_{s})/T)$, where $T>0$ is the temperature that controls the softness of the prediction probability distributions (hinton2015distilling). A commonly adopted distillation scheme is offline distillation, where the teacher is fully trained and fixed, and the student is optimized based on Eq. 1.
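The effect of the temperature can be seen in a minimal sketch, assuming the model outputs are logits that are divided by the temperature before a soft-max. The function names and the logit values are illustrative assumptions, not taken from the released code.

```python
import math

def softmax(logits, temperature=1.0):
    """Soft-max of logits / temperature, shifted by the max for stability."""
    scaled = [z / temperature for z in logits]
    top = max(scaled)
    exps = [math.exp(z - top) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def kd_kl(teacher_logits, student_logits, temperature=1.0):
    """KL(p_teacher || p_student) between temperature-softened distributions."""
    p = softmax(teacher_logits, temperature)
    q = softmax(student_logits, temperature)
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0.0)

teacher_logits = [4.0, 1.0, 0.5]  # hypothetical teacher outputs for one input
student_logits = [2.5, 1.5, 0.2]  # hypothetical student outputs for the same input

kl_t1 = kd_kl(teacher_logits, student_logits, temperature=1.0)
kl_t4 = kd_kl(teacher_logits, student_logits, temperature=4.0)
print(kl_t1, kl_t4)
```

For these particular logits, the higher temperature flattens both distributions and the divergence between them shrinks, which is the sense in which the temperature controls the softness of the targets.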

Layer-wise Distillation. In large Transformer-based models, the output predictions of the models may not capture all the semantic and syntactic knowledge encoded in the intermediate layers. Therefore, researchers propose a layer-wise distillation approach, which aligns the hidden representations of the student and the teacher at each layer (romero2014fitnets; sun2019patient; sun2020mobilebert; jiao2019tinybert; hou2020dynabert; zuo2022moebert; liang2023homodistil). Specifically, we denote the hidden representation at the $k$-th layer of a $K$-layer student as $H_{s}^{k}\in\mathbb{R}^{|x|\times d_{s}}$, and at the $M(k)$-th layer of the teacher as $H_{t}^{M(k)}\in\mathbb{R}^{|x|\times d_{t}}$. Here $|x|$ is the sequence length; $d_{s}$ and $d_{t}$ are the hidden dimensions of the student and the teacher, respectively. $M(\cdot)$ is a layer mapping function that determines which teacher layer each student layer distills from. For example, if we set $M(k)=2k$, the student distills from every other layer of the teacher. The layer-wise distillation loss is defined as:

$$\mathcal{D}_{\rm layer}(\Theta_{t},[\Theta_{s},\mathcal{W}_{s}])=\sum_{k=1}^{K}\texttt{MSE}(H_{t}^{M(k)},H_{s}^{k}W_{s}^{k}). \tag{2}$$

Here $\texttt{MSE}(\cdot,\cdot)$ is the mean-squared error, $W_{s}^{k}\in\mathbb{R}^{d_{s}\times d_{t}}$ is a randomly initialized and learnable linear projection that projects $H_{s}^{k}$ into the same space as $H_{t}^{M(k)}$, and $\mathcal{W}_{s}=\{W_{s}^{k}\}_{k=1}^{K}$. In practice, the student is often optimized using multiple distillation losses, e.g.,

$$\min_{\Theta_{s},\mathcal{W}_{s}}\ \mathcal{L}(\Theta_{s})+\alpha_{1}\mathcal{D}_{\rm pred}(\Theta_{t},\Theta_{s})+\alpha_{2}\mathcal{D}_{\rm layer}(\Theta_{t},[\Theta_{s},\mathcal{W}_{s}]), \tag{3}$$

where $\alpha_{1},\alpha_{2}\geq 0$ are hyper-parameters. Besides the intermediate layers, distilling knowledge from the attention scores and the embedding layers can also improve the distillation performance (sun2020mobilebert; jiao2019tinybert; wang2020minilm; wang2020minilmv2). Eq. 3 can be further extended by adding such losses.
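The layer-wise loss in Eq. 2 and its use in Eq. 3 can be sketched on toy matrices. This is an illustrative sketch under stated assumptions: hidden representations are plain lists of rows, the mapping is M(k) = 2k, and the projections, loss weights, task loss and prediction loss are hypothetical fixed values rather than learned or measured quantities.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def mse(a, b):
    """Mean-squared error over all entries of two equally shaped matrices."""
    diffs = [(x - y) ** 2 for ra, rb in zip(a, b) for x, y in zip(ra, rb)]
    return sum(diffs) / len(diffs)

def layer_map(k):
    """M(k) = 2k: student layer k distills from teacher layer 2k (1-indexed)."""
    return 2 * k

def layerwise_loss(teacher_hidden, student_hidden, projections):
    """Eq. 2: sum over student layers of MSE(H_t^{M(k)}, H_s^k W_s^k)."""
    total = 0.0
    for k in sorted(student_hidden):
        projected = matmul(student_hidden[k], projections[k])
        total += mse(teacher_hidden[layer_map(k)], projected)
    return total

# Toy sizes: |x| = 2 tokens, d_s = 2, d_t = 3, K = 2 student layers, 4 teacher layers.
student_hidden = {1: [[1.0, 0.0], [0.0, 1.0]], 2: [[2.0, 1.0], [1.0, 2.0]]}
teacher_hidden = {l: [[float(l), 0.0, 1.0], [0.0, float(l), 1.0]] for l in range(1, 5)}
projections = {k: [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]] for k in student_hidden}

d_layer = layerwise_loss(teacher_hidden, student_hidden, projections)

# Eq. 3 with hypothetical weights and placeholder values for the other two terms.
alpha_1, alpha_2 = 1.0, 0.5
task_loss, d_pred = 0.7, 0.2
objective = task_loss + alpha_1 * d_pred + alpha_2 * d_layer
print(d_layer, objective)
```

In actual training the projections would be updated together with the student parameters; here they are held fixed only to keep the arithmetic easy to follow.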

3 Method

We introduce TED, a two-stage training framework that uses task-aware filters to distill knowledge from a teacher to a student. The task-aware filters are neural networks that learn to extract task-specific knowledge from the hidden representations of the teacher and the student. In the first stage, we add a task-aware filter to each layer of the teacher and the student. We train these filters using the task-specific loss while keeping the model parameters frozen. In the second stage, we fine-tune the student and its filters by minimizing the discrepancy between the filtered representations of the teacher and the student.

3.1 Stage I: Training Task-aware Filters

For a student that contains $K$ layers, we select $K$ corresponding layers from the teacher to match with the student using a layer mapping function, $M(\cdot)$, as defined in Section 2. We then equip each layer with a task-aware filter to extract the task-specific knowledge from the hidden representation of this layer. Each filter is a neural network with a task-specific head (e.g., a linear soft-max layer for classification). It takes in the hidden representation generated by this layer and outputs a prediction for the target task. For example, for a classification task, the filter outputs a probability distribution over the classes.

For simplicity, we only specify how to train task-aware filters for the teacher. The student is treated similarly (see Section LABEL:sec:lm for details). To train the task-aware filters, we fix the parameters of the teacher, which is already pre-trained. (We discuss in detail how to initialize the teacher and the student models in Section LABEL:sec:lm and LABEL:sec:nlu.) In other words, we only update the parameters of the filters. We denote the task-aware filter at the $M(k)$-th layer as $g_{t}^{k}(\cdot;W_{t}^{k})$, where $W_{t}^{k}$ denotes the filter’s parameters. The filter takes in the hidden representation $H_{t}^{M(k)}$ at the $M(k)$-th layer, and outputs a task-specific loss

$$\mathcal{L}_{t}^{k}(\Theta_{t}^{M(k)},W_{t}^{k})=\mathbb{E}_{x\sim\mathcal{X}}[\ell(g_{t}^{k}(H_{t}^{M(k)};W_{t}^{k}))], \tag{4}$$

where $\Theta_{t}^{M(k)}$ denotes the teacher’s parameters up to the $M(k)$-th layer. The loss function $\ell$ depends on the task and the setting. For example, $\ell$ is the causal language modeling loss for continual pre-training and the cross-entropy loss for fine-tuning of classification tasks. Given the loss in Eq. 4, we train the $K$ filters jointly:

$$\min_{\mathcal{W}_{t}}\ \sum_{k=1}^{K}\mathcal{L}_{t}^{k}(\Theta_{t}^{M(k)},W_{t}^{k}), \tag{5}$$

where $\mathcal{W}_{t}=\{W_{t}^{k}\}_{k=1}^{K}$. By training the task-aware filters, we can reduce the redundant information in the hidden representations, and keep the information that is useful for learning the target task.
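Stage I can be sketched for a single layer. Because the model parameters are frozen and each filter has its own parameters, the sum in Eq. 5 separates across layers, so one layer is enough to show the idea. The sketch rests on several simplifying assumptions that are not from the paper: each example's hidden representation is pooled to one vector, the task is binary classification, the filter is a linear projection to a scalar, and the head is a sigmoid with a scalar weight and bias. The data are synthetic, with one informative coordinate and two noise coordinates.

```python
import math
import random

def filter_loss(w, v, b, hidden, labels):
    """Mean binary cross-entropy of head(filter(h)) over the data set."""
    total = 0.0
    for h, y in zip(hidden, labels):
        z = sum(wi * hi for wi, hi in zip(w, h))  # filter: linear projection
        logit = v * z + b                         # task-specific head
        signed = logit if y == 1 else -logit
        total += math.log1p(math.exp(-signed))    # equals -log sigmoid(signed)
    return total / len(hidden)

def train_filter(hidden, labels, steps=200, lr=0.3):
    """Gradient descent on the filter and head only; `hidden` is never modified."""
    w, v, b = [0.1] * len(hidden[0]), 0.1, 0.0
    n = len(hidden)
    for _ in range(steps):
        gw, gv, gb = [0.0] * len(w), 0.0, 0.0
        for h, y in zip(hidden, labels):
            z = sum(wi * hi for wi, hi in zip(w, h))
            p = 1.0 / (1.0 + math.exp(-(v * z + b)))
            err = p - y                           # d(loss)/d(logit)
            gv += err * z
            gb += err
            for i, hi in enumerate(h):
                gw[i] += err * v * hi
        w = [wi - lr * gi / n for wi, gi in zip(w, gw)]
        v -= lr * gv / n
        b -= lr * gb / n
    return w, v, b

# Synthetic frozen representations: coordinate 0 encodes the label, the rest is noise.
random.seed(0)
labels = [i % 2 for i in range(40)]
hidden = [[1.0 if y == 1 else -1.0, random.gauss(0, 1), random.gauss(0, 1)] for y in labels]
frozen_copy = [row[:] for row in hidden]

loss_before = filter_loss([0.1, 0.1, 0.1], 0.1, 0.0, hidden, labels)
w, v, b = train_filter(hidden, labels)
loss_after = filter_loss(w, v, b, hidden, labels)
print(loss_before, loss_after, w)
```

Only the filter and head parameters change during training, while the representations stay exactly as the frozen model produced them. Inspecting the learned projection `w` shows how much weight the filter places on the informative coordinate relative to the noise coordinates.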

Remark 3.1.

We can choose different neural network architectures to implement the task-aware filters, such as a simple linear projection that maps the input to a lower-dimensional space, a multi-layer perceptron that applies a sequence of nonlinear transformations, or a stack of Transformer layers that encode the input with attention mechanism. We compare the performances of these architectures in Section LABEL:ana:complexity.
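The first two filter choices in this remark can be sketched on a single hidden vector. This is an illustrative sketch only: the weights are hand-picked rather than learned, and the ReLU non-linearity in the perceptron is an assumption, since the remark does not name a specific activation.

```python
def linear(h, weight):
    """Map a vector h of length d_in through a d_in x d_out weight matrix."""
    return [sum(hi * row[j] for hi, row in zip(h, weight)) for j in range(len(weight[0]))]

def linear_filter(h, weight):
    """Filter as a single linear projection to a lower-dimensional space."""
    return linear(h, weight)

def mlp_filter(h, w1, w2):
    """Filter as a two-layer perceptron with a ReLU in between (assumed activation)."""
    hidden = [max(0.0, z) for z in linear(h, w1)]
    return linear(hidden, w2)

h = [1.0, -2.0, 3.0]                           # hypothetical hidden vector, d = 3
w_lin = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]   # 3 x 2 projection
w1 = [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]]      # 3 x 2 first MLP layer
w2 = [[1.0, -1.0], [0.0, 2.0]]                 # 2 x 2 second MLP layer

linear_out = linear_filter(h, w_lin)  # [4.0, 1.0]
mlp_out = mlp_filter(h, w1, w2)       # [1.0, -1.0]
print(linear_out, mlp_out)
```

Both variants map the 3-dimensional input to 2 dimensions; the perceptron additionally discards the negative pre-activation before the second layer, which a purely linear filter cannot do.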

3.2 Stage II: Task-aware Layer-wise Distillation

In Stage II, we remove the task-specific heads in the task-aware filters, which are learned in Stage I. Then, we freeze the parameters of the teacher and its filters, and fine-tune the student and its filters by minimizing the discrepancy between the filtered representations at each layer of the two models.

Formally, we denote $g_{s}^{k}(\cdot;W_{s}^{k})$ as the task-aware filter at the $k$-th layer of the student. Then the task-aware layer-wise distillation loss is defined as

$$\mathcal{D}_{\rm TED}\left(\left[\Theta_{t},\mathcal{W}_{t}\right],\left[\Theta_{s},\mathcal{W}_{s}\right]\right)=\sum_{k=1}^{K}\texttt{MSE}\left(g_{t}^{k}(H_{t}^{M(k)};W_{t}^{k}),\,g_{s}^{k}(H_{s}^{k};W_{s}^{k})\right), \tag{6}$$

which measures the discrepancy between the filtered representations of the teacher and the student. Based on the distillation loss, the training objective for the student and its filters is

$$\min_{\Theta_{s},\mathcal{W}_{s}}\ \mathcal{L}(\Theta_{s})+\alpha_{1}\mathcal{D}_{\rm pred}(\Theta_{t},\Theta_{s})+\alpha_{2}\mathcal{D}_{\rm TED}([\Theta_{t},\mathcal{W}_{t}],[\Theta_{s},\mathcal{W}_{s}]), \tag{7}$$

where $\mathcal{L}$ is the target task’s loss, $\mathcal{D}_{\rm pred}$ is the prediction distillation loss defined in Eq. 1, and $\alpha_{1},\alpha_{2}\geq 0$ are hyper-parameters. By using the task-aware filters, Eq. 7 imposes an easier requirement on the student than the conventional layer-wise distillation objective (Eq. 3). That is, Eq. 3 requires the student to match the teacher on the unfiltered hidden representations, regardless of their relevance to the target task, whereas Eq. 7 only requires the filtered representations to match.
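The difference between the two requirements can be made concrete with a toy sketch. It assumes linear filters that map both models into a shared 2-dimensional space, a single layer with M(k) = k, and hand-picked matrices in which the teacher's third coordinate plays the role of redundant information. All names and values are illustrative, not taken from the released code.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def mse(a, b):
    """Mean-squared error over all entries of two equally shaped matrices."""
    diffs = [(x - y) ** 2 for ra, rb in zip(a, b) for x, y in zip(ra, rb)]
    return sum(diffs) / len(diffs)

def ted_loss(teacher_hidden, student_hidden, teacher_filters, student_filters, layer_map):
    """Eq. 6 with linear filters: sum_k MSE(H_t^{M(k)} W_t^k, H_s^k W_s^k)."""
    total = 0.0
    for k in sorted(student_hidden):
        filtered_t = matmul(teacher_hidden[layer_map(k)], teacher_filters[k])
        filtered_s = matmul(student_hidden[k], student_filters[k])
        total += mse(filtered_t, filtered_s)
    return total

# 3 tokens; teacher width 3 (third column is "redundant"), student width 2.
teacher_hidden = {1: [[1.0, 2.0, 9.0], [3.0, 4.0, -7.0], [5.0, 6.0, 2.0]]}
student_hidden = {1: [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]}
teacher_filters = {1: [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]]}  # drops the third column
student_filters = {1: [[1.0, 0.0], [0.0, 1.0]]}              # identity

d_ted = ted_loss(teacher_hidden, student_hidden, teacher_filters, student_filters, lambda k: k)

# Unfiltered loss of Eq. 2 for one fixed projection of the student into width 3.
projection = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
d_layer = mse(teacher_hidden[1], matmul(student_hidden[1], projection))

# Eq. 7 with hypothetical weights and placeholder values for the other two terms.
alpha_1, alpha_2 = 1.0, 0.5
task_loss, d_pred = 0.7, 0.2
objective = task_loss + alpha_1 * d_pred + alpha_2 * d_ted
print(d_ted, d_layer, objective)
```

Here the filtered loss is exactly zero because the teacher's filter discards the third column, while the unfiltered loss stays positive. In this toy case no projection could remove that gap: the projected student matrix has rank at most 2, whereas the teacher matrix has rank 3.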