
Pruning Pre-trained Language Models with Principled Importance and Self-regularization

Siyu Ren           Kenny Q. Zhu
Shanghai Jiao Tong University
Shanghai, China
[email protected], [email protected]
  The corresponding author.
Abstract

Iterative pruning is one of the most effective compression methods for pre-trained language models. We discovered that finding the optimal pruning decision is an equality-constrained 0-1 Integer Linear Programming problem. The solution to this optimization problem leads to a principled importance criterion which we use to rank parameters during iterative model pruning. To mitigate the poor generalization at high sparsity levels, we propose a self-regularization scheme where model prediction is regularized by the latest checkpoint with increasing sparsity throughout pruning. Our experiments on natural language understanding, question answering, named entity recognition, and data-to-text generation with various Transformer-based PLMs show the effectiveness of the approach at various sparsity levels.

1 Introduction

Pre-trained language models (PLMs) Devlin et al. (2019); Radford et al. (2018) have significantly advanced the state-of-the-art in various natural language processing tasks Wang et al. (2018); Zhou and Lampouras (2020); Dušek et al. (2020); Radev et al. (2020). However, these models often contain a vast number of parameters, posing non-trivial storage and computation requirements. Due to this inefficiency, applications of PLMs in resource-constrained scenarios remain limited.

To resolve the above challenge, model compression Sun et al. (2019); Ben Noach and Goldberg (2020); Lan et al. (2020) has been actively studied to make PLMs meet practical requirements. Among these methods, iterative pruning is widely adopted because it compresses PLMs during downstream adaptation at only a tiny expense of model performance. During iterative pruning, model parameters are not only updated but also pruned based on the rank of their importance scores in order to satisfy the cardinality constraint. Prevalent importance criteria are based on the parameter's magnitude Zhu and Gupta (2017); Renda et al. (2020) or sensitivity Louizos et al. (2018); Sanh et al. (2020); Liang et al. (2021); Zhang et al. (2022). Parameters with low importance scores are pruned and are expected to have little impact on model performance.

Despite the empirical success, existing importance criteria for model pruning still face two major limitations: (1) they are heuristically defined and may not accurately quantify a parameter’s contribution to the learning process, e.g., absolute weight value in magnitude-based pruning and gradient-weight product in sensitivity-based pruning; (2) they determine the importance of each parameter individually without considering the effect of coinstantaneous parameter updates on model performance, e.g., sensitivity is estimated by the absolute change in training error if only a single parameter is pruned and others remain unchanged.

In this paper, we rethink the design of the importance criterion for model pruning from an optimization perspective. We begin by analyzing the temporal variation of any given learning objective based on a single-step gradient descent update under the iterative pruning setting. We show that finding the optimal pruning decision can be framed as solving an equality-constrained 0-1 Integer Linear Programming (ILP) problem, where the constraint is defined by the specified sparsity. The resulting problem is a particular case of a general 0-1 Knapsack problem in which the weight for each item is the same. The solution to this problem naturally leads to a principled importance criterion which we use to rank all model parameters and derive the optimal stepwise pruning decision.

When high sparsity (e.g., 80% to 90%) is pursued, the limited capacity often prevents the pruned model from retaining satisfactory performance with conventional fine-tuning. To further improve the model's generalization ability, we propose a self-regularization scheme, where the model prediction is regularized by that of the latest best-performing checkpoint during pruning. We show that such a scheme eases model learning with decreasing capacity and effectively yields a tighter upper bound on the expected generalization error than learning from training data alone.

To validate the effectiveness of our approach, dubbed PINS (Pruning with principled Importance aNd Self-regularization), we conducted extensive experiments with various pre-trained language models on a wide variety of tasks, including natural language understanding on GLUE Wang et al. (2018), question answering on SQuAD Rajpurkar et al. (2016), named entity recognition on CoNLL 2003 Tjong Kim Sang and De Meulder (2003), and data-to-text generation on WebNLG Zhou and Lampouras (2020), DART Radev et al. (2020), and E2E Dušek et al. (2020). Experimental results show that PINS provides more accurate models at different sparsity levels. Detailed analysis sheds further light on some intriguing properties of models pruned by PINS. By exploiting the resulting high sparsity, storage can be reduced by 8.9x and inference accelerated by 2.7x using the CSR format and a sparsity-aware inference runtime Kurtz et al. (2020) on consumer-level CPUs (code available at https://github.com/DRSY/PINS).

In summary, our contributions are:

  • We establish the equivalence between the optimal pruning decision and the solution to an equality-constrained 0-1 Integer Linear Programming problem. The solution to this problem leads to a principled importance criterion that can be used to rank parameters during iterative pruning.

  • We propose a simple yet effective self-regularization scheme to enhance the model’s generalization capability, especially under a high-sparsity regime.

  • Comprehensive experiments and analyses confirm the effectiveness of our approach at various sparsity levels.

2 Background and Related Work

In this section, we review the necessary background on Transformer-based pre-trained language models and popular importance criteria for iterative pruning.

2.1 Transformer-based Pre-trained Language Models

Most existing pre-trained neural language models Radford et al. (2018); Devlin et al. (2019); Wang et al. (2020); Clark et al. (2020) are based on the Transformer Vaswani et al. (2017) architecture, which consists of several identical blocks of self-attention and feed-forward networks. After self-supervised pre-training on massive amounts of unlabeled general-domain text, these models exhibit superior performance on various downstream tasks via fine-tuning. However, good generalization performance comes at the cost of a vast number of parameters; for example, the base version of BERT has 110M parameters and requires more than 400MB of disk storage. Therefore, how to effectively reduce model size while preserving as much task accuracy as possible remains a challenging research problem.

2.2 Iterative Pruning

Pruning methods can be divided into two categories: one-shot pruning Lee et al. (2018); Frankle and Carbin (2018) and iterative pruning Louizos et al. (2018); Sanh et al. (2020); Zhang et al. (2022). One-shot pruning removes parameters of low importance after training. It is efficient but ignores the complicated training dynamics when applied to modern large neural language models. On the contrary, iterative pruning performs training and pruning simultaneously. Therefore, the resulting sparsity pattern is aware of the complex dynamics of parameters through the course of training and delivers considerable improvement compared to one-shot pruning.

Let $\bm{\theta}^{(t)}=\{\theta_{1}^{(t)},\theta_{2}^{(t)},\dots,\theta_{d}^{(t)}\}$ denote the $d$-dimensional model parameters at the $t$-th training iteration. The typical update rule of iterative pruning can be formulated as:

$\hat{\bm{\theta}}^{(t+1)}=\bm{\theta}^{(t)}-\eta^{(t)}\nabla_{\bm{\theta}}\mathcal{L}(\bm{\theta}^{(t)})$   (1)
$\bm{\theta}^{(t+1)}=\hat{\bm{\theta}}^{(t+1)}\odot\bm{M}^{(t)}$   (2)

where $\eta^{(t)}$ is the learning rate at time step $t$ and $\mathcal{L}$ is the learning objective. The temporarily updated $\hat{\bm{\theta}}^{(t+1)}$ is further pruned by the binary mask $\bm{M}^{(t)}\in\{0,1\}^{d}$, which is computed based on a given importance criterion $\bm{S}^{(t)}$:

$\bm{M}^{(t)}_{i}=\begin{cases}1,&\text{if }\bm{S}^{(t)}_{i}\text{ is in the top-}r^{(t)}\text{ of }\bm{S}^{(t)}\\ 0,&\text{otherwise}\end{cases}$   (3)

where $r^{(t)}\leq d$ indicates the number of remaining parameters at time step $t$ according to a given sparsity scheduler.
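To make the update rule concrete, the following PyTorch-style sketch performs one step of Eqs. (1)-(3) on a flattened parameter vector; the function and argument names (`iterative_pruning_step`, `importance_scores`, `remaining`) are illustrative and not taken from any released implementation.

```python
import torch

def iterative_pruning_step(params: torch.Tensor, grads: torch.Tensor,
                           lr: float, importance_scores: torch.Tensor,
                           remaining: int) -> torch.Tensor:
    """One step of Eqs. (1)-(3) on a flattened parameter vector."""
    updated = params - lr * grads                       # Eq. (1): gradient descent update
    mask = torch.zeros_like(params)                     # Eq. (3): top-r^(t) binary mask
    mask[torch.topk(importance_scores, k=remaining).indices] = 1.0
    return updated * mask                               # Eq. (2): prune by masking
```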

2.3 Importance Criteria for Model Pruning

Popular importance criteria for model pruning include parameters’ magnitude and sensitivity.

Magnitude

is a simple yet effective importance criterion that is widely used for model pruning. It estimates the importance of each parameter as its absolute value, i.e., $\bm{S}^{(t)}_{i}=|\bm{\theta}^{(t)}_{i}|$. Despite its simplicity, magnitude cannot accurately gauge the importance of parameters because even parameters with small magnitude can have a large impact on the model prediction due to the complex compositional structure of PLMs.

Sensitivity

is another useful importance criterion. It estimates the importance of each parameter as the absolute change of the learning objective if the parameter is pruned, i.e., set to zero. The sensitivity of the $i$-th parameter is formulated as:

$\bm{S}^{(t)}_{i}=|\mathcal{L}(\bm{\theta}^{(t)}_{-i})-\mathcal{L}(\bm{\theta}^{(t)})|$   (4)
$\approx|\bm{g}_{i}^{(t)}\bm{\theta}^{(t)}_{i}|$   (5)

where $\bm{\theta}^{(t)}_{-i}$ is identical to $\bm{\theta}^{(t)}$ except that the $i$-th entry is set to zero, and $\bm{g}_{i}^{(t)}$ is the gradient of the $i$-th entry. Though taking the training dynamics into account, sensitivity still estimates the importance of each parameter individually without considering the effect of the holistic parameter update.
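For reference, both criteria reduce to a couple of lines of tensor arithmetic; this is a minimal sketch assuming PyTorch tensors, not code from any particular pruning library.

```python
import torch

def magnitude_scores(param: torch.Tensor) -> torch.Tensor:
    # Magnitude criterion: S_i = |theta_i|
    return param.abs()

def sensitivity_scores(param: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
    # First-order approximation of Eqs. (4)-(5): S_i ~= |g_i * theta_i|
    return (grad * param).abs()
```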

3 Methodology

Instead of heuristically defining the importance criterion as in prior pruning methods, we take a step back and rethink the design of the importance criterion for model pruning from an optimization perspective. From our analysis, we draw an equivalence between finding the optimal stepwise pruning decision and solving an equality-constrained 0-1 Integer Linear Programming problem. We further show that the optimal solution to this problem leads to a new importance criterion for model pruning. Moreover, we propose a simple yet effective self-regularization scheme to facilitate the generalization ability of the sparse model. We elucidate our analysis in Section 3.1 and describe our self-regularization scheme in Section 3.2.

3.1 Rethinking Importance Criterion from the Optimization Perspective

Without loss of generality, we denote $\mathcal{L}$ as the learning objective when adapting a pre-trained language model $f$ with parameters $\bm{\theta}$ to a downstream task. At the $t$-th training iteration, we denote the current model parameters as $\bm{\theta}^{(t)}$ and the evaluated learning objective as $\mathcal{L}(\bm{\theta}^{(t)})$.

The temporal variation of the learning objective $\mathcal{L}(\bm{\theta}^{(t)})$ at time step $t$ is given by the second-order Taylor series expansion:

$\Delta\mathcal{L}^{(t)}=\mathcal{L}(\bm{\theta}^{(t)}+\Delta\bm{\theta}^{(t)})-\mathcal{L}(\bm{\theta}^{(t)})$   (6)
$=\nabla_{\bm{\theta}}\mathcal{L}(\bm{\theta}^{(t)})^{\top}\Delta\bm{\theta}^{(t)}+\frac{1}{2}\Delta\bm{\theta}^{(t)\top}\bm{H}^{(t)}\Delta\bm{\theta}^{(t)}+o(\|\Delta\bm{\theta}^{(t)}\|^{2})$   (7)

where $\bm{H}^{(t)}$ is the Hessian matrix at step $t$. It is known that the largest eigenvalue $\lambda_{max}$ of Hessian matrices in a PLM is typically small Shen et al. (2019), i.e., $\Delta\bm{\theta}^{(t)\top}\bm{H}^{(t)}\Delta\bm{\theta}^{(t)}\leq\lambda_{max}\|\Delta\bm{\theta}^{(t)}\|_{2}^{2}\approx 0$. Thus, we ignore the second-order term as well as the higher-order infinitesimal in Eq. (7):

$\Delta\mathcal{L}^{(t)}=\nabla_{\bm{\theta}}\mathcal{L}(\bm{\theta}^{(t)})^{\top}\Delta\bm{\theta}^{(t)}=\sum_{i=1}^{d}\bm{g}_{i}^{(t)}\cdot\Delta\bm{\theta}^{(t)}_{i}$   (8)

Under the iterative pruning setting, the actual temporal variation $\Delta\bm{\theta}^{(t)}_{i}$ of the $i$-th parameter depends on whether it is allowed to be updated or forced to be zeroed out. Formally, we use a binary variable $\bm{x}_{i}^{(t)}$ to indicate the pruning decision of the $i$-th parameter at time step $t$, i.e., $\bm{x}_{i}^{(t)}=1$ means $\bm{\theta}^{(t)}_{i}$ is updated and $\bm{x}_{i}^{(t)}=0$ means $\bm{\theta}^{(t)}_{i}$ is pruned. The temporal variation in Eq. (8) can now be rewritten as:

$\Delta\mathcal{L}^{(t)}=\sum_{i=1}^{d}\bm{g}_{i}^{(t)}\big(\bm{x}_{i}^{(t)}\Delta\hat{\bm{\theta}}_{i}^{(t)}+(1-\bm{x}_{i}^{(t)})(-\bm{\theta}_{i}^{(t)})\big)$   (9)

where $\Delta\hat{\bm{\theta}}_{i}^{(t)}=-\eta^{(t)}\bm{g}_{i}^{(t)}$ is the gradient descent update. Finding the optimal pruning decision that leads to the smallest $\Delta\mathcal{L}^{(t)}$ is thus converted to an equality-constrained 0-1 integer linear programming (ILP) problem over the variables $\bm{x}^{(t)}$:

$\tilde{\bm{x}}^{(t)}=\underset{\bm{x}^{(t)}}{\arg\min}~\Delta\mathcal{L}^{(t)}$
$\text{s.t.}~~\sum_{i=1}^{d}\bm{x}_{i}^{(t)}=r^{(t)},~\bm{x}_{i}^{(t)}\in\{0,1\}$   (10)

where $r^{(t)}$ is the number of remaining parameters at step $t$ according to the pre-defined sparsity scheduler. If we consider each parameter $\bm{\theta}^{(t)}_{i}$ as an item and $r^{(t)}$ as the total capacity, the problem defined by Eq. (10) can be treated as a special case of the 0-1 Knapsack problem in which the weight of each item is one and the value of each item is given by:

$\bm{S}_{i}^{(t)}=-\bm{g}_{i}^{(t)}\Delta\hat{\bm{\theta}}_{i}^{(t)}-\bm{g}_{i}^{(t)}\bm{\theta}_{i}^{(t)}$   (11)

Contrary to the general 0-1 Knapsack problem, which is known to be NP-complete, the equal-weight 0-1 Knapsack problem is solvable in polynomial time. Its optimal solution is obtained by sorting items in descending order of their values and selecting the top-$r^{(t)}$ ones:

$\tilde{\bm{x}}^{(t)}_{i}=\begin{cases}1,&\text{if }\bm{S}^{(t)}_{i}\text{ is in the top-}r^{(t)}\text{ of }\bm{S}^{(t)}\\ 0,&\text{otherwise}\end{cases}$   (12)

Putting this in the context of iterative pruning, our analysis theoretically justifies: (1) selecting parameters based on the ranking of a certain importance criterion; (2) using Eq. (11) as a principled new importance criterion.
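For illustration, a minimal sketch of the resulting criterion and pruning decision is given below; it assumes plain SGD (so that $\Delta\hat{\bm{\theta}}_{i}^{(t)}=-\eta^{(t)}\bm{g}_{i}^{(t)}$) and uses hypothetical names (`pins_importance`, `optimal_pruning_decision`) rather than those of the released code.

```python
import torch

def pins_importance(param: torch.Tensor, grad: torch.Tensor, lr: float) -> torch.Tensor:
    """Eq. (11): S_i = -g_i * dtheta_i - g_i * theta_i with dtheta_i = -lr * g_i,
    i.e. S_i = lr * g_i^2 - g_i * theta_i."""
    return lr * grad.pow(2) - grad * param

def optimal_pruning_decision(scores: torch.Tensor, r: int) -> torch.Tensor:
    """Eq. (12): solution of the equal-weight 0-1 Knapsack, i.e. keep the top-r values."""
    decision = torch.zeros_like(scores)
    decision.view(-1)[torch.topk(scores.flatten(), k=r).indices] = 1.0
    return decision
```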

3.2 Self-regularization

In vanilla fine-tuning, the learning objective $\mathcal{L}$ is defined as the training error $\mathcal{L}_{er}$ (a.k.a. empirical risk in statistical learning) over the empirical data distribution. However, minimizing the training error does not necessarily translate to good generalization. Moreover, as iterative pruning proceeds, the number of non-zero parameters in the model monotonically decreases. The reduced model capacity increases the learning difficulty Lopez-Paz et al. (2015); Mirzadeh et al. (2019) and usually leads to degraded generalization performance of the sparsified model Sanh et al. (2020).

To confront the above challenges, we propose an effective self-regularization scheme tailored to improving the model's generalization ability during iterative pruning. Concretely, besides learning from the hard labels of the training data, the output of the current model with parameters $\bm{\theta}^{(t)}$ is also regularized by the output of the latest best-performing checkpoint with parameters $\bm{\theta}^{(t_{l})}$, where $t_{l}\leq t$ denotes the time step at which the latest checkpoint was saved. The learning objective of self-regularization is defined as:

$\mathcal{L}_{sr}=\mathcal{D}(y_{\bm{\theta}^{(t)}},y_{\bm{\theta}^{(t_{l})}})$   (13)

where $\mathcal{D}$ can be any divergence metric, e.g., KL-divergence for classification tasks. $\mathcal{L}_{sr}$ is then combined with the original learning objective, i.e., $\mathcal{L}=\mathcal{L}_{er}+\mathcal{L}_{sr}$.
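A minimal sketch of the combined objective for a classification task is shown below, assuming KL-divergence as $\mathcal{D}$; the function name `pins_loss` is ours, not from the released code.

```python
import torch
import torch.nn.functional as F

def pins_loss(logits: torch.Tensor, ckpt_logits: torch.Tensor,
              labels: torch.Tensor) -> torch.Tensor:
    """L = L_er + L_sr: cross-entropy on hard labels plus KL-divergence
    between the current model and the latest best-performing checkpoint."""
    l_er = F.cross_entropy(logits, labels)
    l_sr = F.kl_div(
        F.log_softmax(logits, dim=-1),
        F.softmax(ckpt_logits.detach(), dim=-1),  # checkpoint gives soft targets, no gradient
        reduction="batchmean",
    )
    return l_er + l_sr
```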

Why does self-regularization work?

Our self-regularization is similar to teacher-student knowledge distillation in the sense that the model output is regularized by the output of another model. However, the most critical difference is that the "teacher" in self-regularization is instantiated by checkpoints of increasing sparsity, so that the capacity gap between "teacher" and "student" is dynamically adjusted. We theoretically justify the effectiveness of self-regularization as follows:

Theorem 1.

Let $t_{i}$ and $t_{j}$ with $t_{i}\geq t_{j}$ denote the time steps at which two different checkpoints are saved; let $R(f_{\bm{\theta}^{(t\leftarrow t_{i})}})$ and $R(f_{\bm{\theta}^{(t\leftarrow t_{j})}})$ denote the expected generalization error of models learned from $f_{\bm{\theta}^{(t_{i})}}$ and $f_{\bm{\theta}^{(t_{j})}}$; let $n$ denote the size of the training data; let $|\cdot|_{\text{C}}$ denote a capacity measure of the function class $\mathcal{F}_{\bm{\theta}}$. Based on previous expositions on VC theory Vapnik (1998), the following asymptotic generalization bounds hold:

$R(f_{\bm{\theta}^{(t\leftarrow t_{i})}})\leq\underbrace{O\Big(\frac{|\mathcal{F}_{\bm{\theta}^{(t)}}|_{\text{C}}}{n^{\alpha_{i}}}\Big)+\underset{\mathcal{F}_{\bm{\theta}^{(t\leftarrow t_{i})}}}{\inf}R(f_{\bm{\theta}^{(t)}})}_{bound(f_{\bm{\theta}^{(t\leftarrow t_{i})}})}$
$R(f_{\bm{\theta}^{(t\leftarrow t_{j})}})\leq\underbrace{O\Big(\frac{|\mathcal{F}_{\bm{\theta}^{(t)}}|_{\text{C}}}{n^{\alpha_{j}}}\Big)+\underset{\mathcal{F}_{\bm{\theta}^{(t\leftarrow t_{j})}}}{\inf}R(f_{\bm{\theta}^{(t)}})}_{bound(f_{\bm{\theta}^{(t\leftarrow t_{j})}})}$

Because $\bm{\theta}^{(t_{i})}$ is a later checkpoint with higher sparsity than $\bm{\theta}^{(t_{j})}$, we have the learning speed $1\geq\alpha_{i}\geq\alpha_{j}\geq\frac{1}{2}$, and the following inequality holds with high probability:

$bound(f_{\bm{\theta}^{(t\leftarrow t_{i})}})\leq bound(f_{\bm{\theta}^{(t\leftarrow t_{j})}})$

In summary, self-regularization works by enabling a tighter generalization bound compared to learning from training data alone or a static dense teacher as in knowledge distillation. Please refer to Appendix B for detailed derivation.

3.3 The Algorithm

Here we formally summarize our algorithm PINS (Pruning with principled Importance aNd Self-regularization) in Algorithm 1:

Algorithm 1 PINS

Input: training set $\mathcal{D}_{tr}=\{(x_{i},y_{i})\}_{i=1}^{N}$; validation set $\mathcal{D}_{val}$; pre-trained parameters $\bm{\theta}$; maximum training steps $T$; evaluation interval $t_{eval}$.
Initialize: $\bm{\theta}^{(0)}\leftarrow\bm{\theta}$, $t_{l}\leftarrow 0$, best validation accuracy $\text{acc}_{t_{l}}\leftarrow-\text{INF}$.

1: for $t=0$ to $T-1$ do
2:     Sample a mini-batch $(\bm{x},\bm{y})$ from $\mathcal{D}_{tr}$
3:     Compute the current model's output $\bm{y}_{\bm{\theta}^{(t)}}$
4:     Compute the latest best-performing checkpoint's output $\bm{y}_{\bm{\theta}^{(t_{l})}}$
5:     Compute $\mathcal{L}$ based on $\bm{y}_{\bm{\theta}^{(t)}}$, $\bm{y}_{\bm{\theta}^{(t_{l})}}$, and $\bm{y}$
6:     Compute $\bm{S}^{(t)}$ via Eq. (11)
7:     Compute $\bm{\theta}^{(t+1)}$ via Eq. (2) and Eq. (3)
8:     if $t\,\%\,t_{eval}=0$ and $\text{acc}_{t}>\text{acc}_{t_{l}}$ then
9:         $\text{acc}_{t_{l}}\leftarrow\text{acc}_{t}$, $\bm{\theta}^{(t_{l})}\leftarrow\bm{\theta}^{(t)}$

Output: the pruned parameters $\bm{\theta}^{(T)}$.
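The following PyTorch-style sketch shows how one training step of Algorithm 1 (lines 2-7) could be assembled; it reuses the hypothetical helper `pins_loss` sketched above, assumes `model(batch)` returns logits and that plain SGD-style updates are used, and may differ from the released implementation.

```python
import torch

def pins_training_step(model, ckpt_model, batch, labels, optimizer, lr, r):
    """One step of Algorithm 1: forward both models, compute L = L_er + L_sr,
    score parameters with Eq. (11), update, then apply the top-r mask."""
    logits = model(batch)                           # assumed to return logits
    with torch.no_grad():
        ckpt_logits = ckpt_model(batch)             # latest best-performing checkpoint
    loss = pins_loss(logits, ckpt_logits, labels)   # hypothetical helper sketched above
    optimizer.zero_grad()
    loss.backward()
    # Eq. (11) on the pre-update parameters, ranked globally across the model
    params = [p for p in model.parameters() if p.grad is not None]
    scores = [lr * p.grad.pow(2) - p.grad * p.detach() for p in params]
    threshold = torch.topk(torch.cat([s.flatten() for s in scores]), k=r).values.min()
    masks = [(s >= threshold).float() for s in scores]
    optimizer.step()                                # Eq. (1): temporary update
    with torch.no_grad():
        for p, m in zip(params, masks):
            p.mul_(m)                               # Eqs. (2)-(3): zero out pruned entries
    return loss.item()
```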

4 Experiments

In this section, we compare PINS with state-of-the-art pruning algorithms and perform detailed analyses to understand the effectiveness of PINS.

4.1 Setup

4.1.1 Tasks

We conduct experiments on a comprehensive spectrum of tasks following standard data splits.
Natural Language Understanding. We opt for tasks from the GLUE Wang et al. (2018) benchmark, including linguistic acceptability (CoLA), natural language inference (RTE, QNLI, MNLI), paraphrase (MRPC, QQP), sentiment analysis (SST-2), and textual similarity (STS-B). Because the official GLUE test sets are hidden, we randomly split a small portion of the training set as the validation set and treat the original validation set as the test set.
Question Answering. We use SQuAD v1.1 Rajpurkar et al. (2016) as a representative dataset for extractive question answering following previous work Zhang et al. (2022).
Named Entity Recognition. We also examine our approach on CoNLL 2003 Tjong Kim Sang and De Meulder (2003) for the token-level named entity recognition task.
Data-to-Text Generation. Besides language understanding tasks, we also extend our evaluation to data-to-text generation on three datasets: E2E Dušek et al. (2020), DART Radev et al. (2020), and WebNLG Zhou and Lampouras (2020), which involves generating a piece of fluent text from a set of structured relational triples.

4.1.2 Baselines

Magnitude-based. Iterative magnitude pruning (IMP) Zhu and Gupta (2017) is the state-of-the-art magnitude-based approach.

Sensitivity-based. $l_{0}$-regularization Louizos et al. (2018) trains masking variables via the re-parametrization trick with an $l_{0}$ penalty; SMvP Sanh et al. (2020) uses accumulated sensitivity as the importance metric; PST Li et al. (2022) proposes a hybrid importance criterion combining both magnitude and sensitivity; PLATON Zhang et al. (2022) uses a modified variant of sensitivity with exponential moving averaging and uncertainty re-weighting.

Sparsity Method RTE (Acc) MRPC (F1) STS-B (Pearson) CoLA (Mcc) SST-2 (Acc) QNLI (Acc) MNLI (Acc) QQP (Acc) Avg.
0% Fine-tune\dagger 69.3 90.3 90.2 58.3 92.4 91.3 84.0 91.5 83.4
80% IMP\dagger 65.7 86.2 86.8 42.5 84.3 89.2 82.2 86.0 77.9
$l_{0}$-regularization\dagger 63.2 80.2 82.8 0.0 85.0 85.0 80.8 88.5 70.7
SMvP\dagger 62.8 86.7 87.8 48.5 89.0 88.3 81.9 90.6 79.5
PST 63.0 87.4 88.0 44.6 89.3 88.3 79.3 88.9 78.6
PLATON\dagger 68.6 89.8 89.0 54.5 91.2 90.1 83.3 90.7 82.2
PINS (ours) 72.7 90.9 89.2 57.1 91.9 91.2 83.9 90.9 83.5
90% IMP\dagger 57.4 80.3 83.4 18.3 80.7 86.6 78.9 78.8 70.5
$l_{0}$-regularization\dagger 59.9 79.5 82.7 0.0 82.5 82.8 78.4 87.6 69.1
SMvP\dagger 58.8 85.9 86.5 0.0 87.4 86.6 80.9 90.2 72.1
PST\ddagger 62.8 85.6 81.7 42.5 88.7 86.0 76.7 83.9 76.0
PLATON\dagger 65.3 88.8 87.4 44.3 90.5 88.9 81.8 90.2 79.6
PINS (ours) 68.5 90.1 87.9 49.8 91.0 89.5 82.7 90.6 81.3
Table 1: Results with BERT$_{\text{base}}$ on the GLUE development set. For MNLI, the results are averaged over MNLI-m and MNLI-mm. \dagger indicates results directly quoted from Zhang et al. (2022), while \ddagger indicates results reported by Li et al. (2022).

4.1.3 Implementation Details

We mainly conduct experiments on the pre-trained BERT$_{\text{base}}$ Devlin et al. (2019) as the pruning target for all tasks except data-to-text generation. We defer the pruning results of MiniLM$_{\text{12L-384H}}$ Wang et al. (2020) and Electra$_{\text{base}}$ Clark et al. (2020) to Appendix A. For data-to-text generation, we adopt the pre-trained GPT-2 Radford et al. (2018) following a prior study Li et al. (2022).

During pruning, we employ the cubic sparsity scheduler Sanh et al. (2020); Zhang et al. (2022) to gradually increase the sparsity level from 0 to the specified target sparsity. To avoid the tremendous computation cost of hyper-parameter tuning, we only search the batch size in $\{16,32\}$ and fix the learning rate to 3e-5 for all experiments on GLUE and CoNLL. For SQuAD v1.1, we fix the batch size to 16 and the learning rate to 3e-5 following Zhang et al. (2022). We adopt AdamW Loshchilov and Hutter (2017) as the default optimizer. To reduce the variance induced by mini-batch sampling, we adopt a smoothing technique similar to PLATON. We run each experiment five times with different random seeds and report the average results (significance tests with $p$-value < 0.05 are conducted for all performance gains).

4.2 Main Results

4.2.1 Comparison with Baselines

Sparsity 80% 70% 60% 50%
Fine-tune\dagger 88.1
IMP\dagger 82.9 86.5 86.7 87.0
$l_{0}$-regularization\dagger 81.9 82.8 83.9 84.6
SMvP\dagger 84.6 85.8
PLATON\dagger 86.1 86.7 86.9 87.2
PINS (ours) 86.4 86.9 87.4 88.0
Table 2: Results with BERT$_{\text{base}}$ on SQuAD v1.1. \dagger indicates numbers reported by Zhang et al. (2022). F1 score is reported as the evaluation metric.
Natural language understanding

We present the experimental results on GLUE at high sparsity, i.e., 80% and 90%, in Table 1. Among all baselines, sensitivity-based methods generally achieve better results than magnitude-based IMP, which underlines the importance of training dynamics when designing pruning criteria. PINS delivers more accurate sparsified models on all datasets at both sparsity levels. Its advantage is more evident on small datasets: for example, PINS outperforms the previous best-performing baseline (PLATON) by 4.1 and 2.6 points on RTE and CoLA at 80% sparsity, where there are only a few thousand training examples. Under extremely high sparsity, i.e., 90%, PINS still retains 97.5% of the overall fine-tuning performance, compared to 95.4% for the previous best method PLATON. Notably, PINS even surpasses fine-tuning on RTE and MRPC at 80% sparsity. We attribute this to the fact that PLMs are heavily over-parameterized and PINS can effectively identify the parameters crucial to the task, realizing low bias and low variance simultaneously.

Sparsity Method P R F1
0% Fine-tune 93.5 94.6 94.0
70% IMP 90.7 91.8 91.2
SMvP 92.9 94.1 93.5
PINS(ours) 93.5 94.3 93.9
80% IMP 84.4 87.3 85.8
SMvP 92.1 93.1 92.6
PINS(ours) 92.8 93.8 93.3
Table 3: Results with BERT$_{\text{base}}$ on CoNLL 2003. P and R stand for Precision and Recall, respectively.
Sparsity Method E2E (BLEU / ROUGE-L / METEOR) DART (BLEU / BLEURT) WebNLG (BLEU / BLEURT)
0% Fine-tune 69.4 71.1 46.2 46.6 0.30 46.9 0.23
80% IMP 69.3 71.0 45.8 44.9 0.22 39.9 0.00
PST 69.4 70.8 45.9 44.1 0.22 44.3 0.16
PINS (ours) 69.6 71.8 46.6 46.2 0.29 45.5 0.18
Table 4: Results with GPT-2 on data-to-text generation datasets. The higher the BLEU, ROUGE-L, METEOR, and BLEURT scores are, the better the performance.
Sparsity Method RTE (Acc) MRPC (F1) STS-B (Pearson) CoLA (Mcc) SST-2 (Acc) QNLI (Acc) MNLI (Acc) QQP (Acc) Avg.
0% Fine-tune 69.3 90.3 90.2 58.3 92.4 91.3 84.0 91.5 83.4
50% PINS 70.8 91.4 89.7 60.6 92.9 91.8 85.1 91.3 84.2
30% PINS 71.7 91.2 89.8 60.4 93.3 92.0 85.1 91.5 84.4
Table 5: Results with BERT$_{\text{base}}$ on the GLUE development set under the medium-to-low sparsity regime. Numbers are the mean of five trials with different random seeds. PINS outperforms fine-tuning at medium-to-low sparsity.
Question answering

Table 2 summarizes the pruning results on SQuAD v1.1. Interestingly, IMP outperforms all sensitivity-based methods except for PLATON at all considered sparsity levels, in contrast to the observations on GLUE. Our method, however, consistently yields the best performance at all sparsity settings.

Named entity recognition

Table 3 demonstrates the pruning results on CoNLL 2003 dataset for named entity recognition. At 70% sparsity, our method almost matches the performance of fine-tuning, outperforming baselines on all evaluation metrics. The gain of PINS is more prominent when further increasing sparsity.

Data-to-text generation

Table 4 shows the pruning results on E2E, DART and WebNLG at 80% sparsity. PINS achieves the best performance on all three datasets in all evaluation metrics. In particular, PINS delivers performance even better than fine-tuning on the E2E dataset by 0.7 ROUGE-L and 0.4 METEOR scores, respectively. We posit that this is due to the relative easiness of E2E compared to the other two datasets.

4.2.2 Results at Medium-to-Low Sparsity

The typical utility of pruning is to produce a sparse yet competitive model that can benefit downstream applications in terms of efficiency without sacrificing much task accuracy. We hypothesize that PINS might also bring a regularization effect compared to vanilla fine-tuning under the medium-to-low sparsity regime.

As shown in Table 5, when specifying a medium-to-low sparsity, e.g., 30% to 50%, our method effectively plays a regularizing role and improves model performance compared to vanilla fine-tuning. With half of the parameters pruned, the sparse model produced by PINS outperforms fine-tuning by 1 percentage point on the average GLUE score. This observation suggests that appropriate pruning can effectively reduce variance without hurting model expressiveness.

4.3 Ablation Study

The self-regularization scheme is proposed and integrated into PINS to improve model generalization. Here we investigate the effectiveness of self-regularization by comparing it to the conventional knowledge distillation scheme and the classical empirical risk minimization scheme.

The pruning results of the three learning objectives on RTE, CoLA, and MRPC are listed in Table 6. Pruning with PINS under classical empirical risk minimization already achieves better performance than existing baselines (Table 1). Learning from a densely fine-tuned BERT$_{\text{base}}$ teacher does not always help and sometimes even hurts performance. In contrast, our proposed self-regularization consistently boosts model performance, which echoes our theoretical justification in Section 3.2.

$\mathcal{L}$ RTE CoLA MRPC
empirical risk 70.9 55.4 90.6
w/ knowledge distillation 70.3 56.0 90.6
w/ self-regularization 72.7 57.1 90.9
Table 6: Ablation study with BERT$_{\text{base}}$ on the learning objective during iterative pruning at 80% sparsity.

4.4 Analysis

We provide an in-depth analysis of various importance criteria to uncover more valuable insights.

Sparsity pattern of weight matrices

We are interested in the sparsity patterns produced by different pruning criteria. To this end, we plot the distribution of remaining parameters in the same weight matrix of BERT$_{\text{base}}$ pruned via magnitude, sensitivity, and PINS in Figure 1. We observe that magnitude-based pruning generates a sparsity pattern close to random. Sensitivity-based pruning produces a more structured pattern in which the remaining parameters tend to occupy complete rows. Interestingly, the sparsity pattern produced by PINS exhibits the highest concentration on specific rows. This implies that the parameters contributing most to the end task are preferentially distributed in a structured way, and PINS is more effective at extracting such patterns.

Figure 1: Sparsity pattern (80%) of the same weight matrix in BERT$_{\text{base}}$ trained on SST-2. See Appendix C for more details on the matrix rank distribution.
Figure 2: Layerwise distribution of average matrix rank in BERT$_{\text{base}}$ pruned at 80% sparsity on SST-2.
Layerwise rank distribution

The highly structured sparsity pattern generated by PINS prompts us to further analyze the intrinsic properties of the parameter matrices after pruning. Specifically, we inspect the matrix rank, as it is usually associated with the complexity of a matrix. To this end, we visualize the layerwise rank distribution of BERT$_{\text{base}}$ pruned using different importance criteria on the SST-2 dataset. As shown in Figure 2, magnitude pruning produces sparse matrices that are still nearly full-rank despite containing 80% zeros. Sensitivity pruning tends to generate sparsity patterns with lower rank than magnitude pruning. Notably, the model pruned by PINS shows consistently lower matrix rank than the other two criteria. This implies that PINS is more effective at identifying the low-dimensional task representation during adaptation, which is usually correlated with tighter generalization bounds Arora et al. (2018); Aghajanyan et al. (2021).

Empirical validation of importance criterion

In Section 3.1 we prove that the pruning decision derived from our importance criterion is theoretically optimal. Here we empirically validate this point by visualizing the change of the learning objective as pruning proceeds. Figure 3 illustrates that our importance criterion indeed leads to the largest decrease in the learning objective compared to heuristic criteria such as magnitude and sensitivity.

Figure 3: Change of learning objective (cross-entropy) during iterative pruning on SST-2.
Sparsity Time(s) Storage(MB) Acc.
0% 0.110 (1.0x) 340 (1.0x) 69.3
80% 0.041 (2.7x) 38 (8.9x) 69.0
Table 7: Practical time and storage efficiency gains on RTE with DeepSparse and the CSR format. Inference is performed on an Intel Xeon E5-2640 CPU with batch size 1.

4.5 Efficiency Gain

We can exploit the resulting high sparsity to attain practical efficiency gains in storage and inference speed. We first quantize the pruned model to the INT8 data type before saving it in Compressed Sparse Row (CSR) format. We then leverage a sparsity-aware runtime Kurtz et al. (2020) to accelerate inference. As shown in Table 7, on the RTE dataset, the disk space and inference time of BERT$_{\text{base}}$ pruned at 80% sparsity are reduced by roughly 8.9x and 2.7x, respectively, with negligible accuracy loss.

5 Conclusion

We present PINS, a new iterative pruning method that hinges on a principled weight importance criterion to deliver the optimal stepwise pruning decision. Integrated with a self-regularization scheme tailored to pruning-during-adaptation, PINS allows for provably better generalization ability. Empirical experiments and analyses confirm the effectiveness of our method and shed further light on the different sparsity patterns produced by PINS and other existing methods.

Limitations

Compared to the empirical risk minimization scheme, the proposed self-regularization scheme incurs a certain overhead because each mini-batch of data goes through two models. For BERT$_{\text{base}}$-scale pre-trained language models, the additional memory overhead is about 27% and the additional training time overhead is about 30%. Nevertheless, once pruned, the sparsified model enjoys considerable efficiency gains in storage and inference time. This is therefore a trade-off that future practitioners may need to consider.

Acknowledgments

This work was generously supported by the CMB Credit Card Center & SJTU joint research grant, and Meituan-SJTU joint research grant.

References

Appendix A Results with More PLMs on a Subset of GLUE

In addition to the widely used BERT and GPT-2 models, we also perform pruning experiments on two other pre-trained language models, Electra$_{\text{base}}$ and MiniLM$_{\text{12L-384H}}$, to further verify the effectiveness of our method.

Due to computing resource constraints, we restrict our experiments to a subset of GLUE tasks, including RTE, CoLA, and QNLI, at 80% and 90% sparsity. We compare PINS against IMP and PLATON as representative magnitude-based and sensitivity-based pruning baselines. We fix the batch size to 32 and the learning rate to 3e-5, similar to the BERT experiments. We report the pruning results in Table 8 and Table 9. At both sparsity levels, PINS consistently outperforms IMP and PLATON on all three datasets, verifying the general effectiveness of PINS for language model pruning.

Sparsity Method RTE (Acc) CoLA (Mcc) QNLI (Acc)
0% Fine-tune 73.0 58.5 91.5
80% IMP 60.5 21.6 87.5
PLATON 68.2 54.1 89.8
PINS (ours) 69.5 54.4 90.4
90% IMP 57.5 14.1 83.9
PLATON 63.1 38.8 88.0
PINS (ours) 66.2 44.8 88.6
Table 8: Results with MiniLM$_{\text{12L-384H}}$ on the GLUE development set.
Sparsity Method RTE (Acc) CoLA (Mcc) QNLI (Acc)
0% Fine-tune 81.9 69.0 93.1
80% IMP 59.9 11.2 87.5
PLATON 73.6 60.0 91.0
PINS (ours) 75.5 63.7 92.0
90% IMP 52.9 0.0 83.0
PLATON 69.9 48.0 89.7
PINS (ours) 72.3 49.2 90.2
Table 9: Results with Electra$_{\text{base}}$ on the GLUE development set.

Appendix B Proof of Theorem 1

Proof.

Let $t_{i}$ and $t_{j}$ with $t_{i}\geq t_{j}$ denote the time steps at which two different checkpoints are saved; let $R(f_{\bm{\theta}^{(t\leftarrow t_{i})}})$ and $R(f_{\bm{\theta}^{(t\leftarrow t_{j})}})$ denote the expected generalization error of models learned from $f_{\bm{\theta}^{(t_{i})}}$ and $f_{\bm{\theta}^{(t_{j})}}$; let $n$ denote the size of the training data; let $|\cdot|_{\text{C}}$ denote a capacity measure, such as the VC-dimension, of the function class $\mathcal{F}_{\bm{\theta}}$. Based on previous expositions on VC theory, the following asymptotic generalization bounds hold:

$R(f_{\bm{\theta}^{(t\leftarrow t_{i})}})=R(f_{\bm{\theta}^{(t\leftarrow t_{i})}})-R(f_{\bm{\theta}^{(t_{i})}})+R(f_{\bm{\theta}^{(t_{i})}})$
$\leq O\Big(\frac{|\mathcal{F}_{\bm{\theta}^{(t)}}|_{\text{C}}}{n^{\alpha_{i}}}\Big)+\epsilon_{t,t_{i}}+R(f_{\bm{\theta}^{(t_{i})}})$
$=\underbrace{O\Big(\frac{|\mathcal{F}_{\bm{\theta}^{(t)}}|_{\text{C}}}{n^{\alpha_{i}}}\Big)+\underset{f_{\bm{\theta}^{(t)}}\in\mathcal{F}_{\bm{\theta}^{(t\leftarrow t_{i})}}}{\inf}R(f_{\bm{\theta}^{(t)}})}_{bound(f_{\bm{\theta}^{(t\leftarrow t_{i})}})}$

$R(f_{\bm{\theta}^{(t\leftarrow t_{j})}})=R(f_{\bm{\theta}^{(t\leftarrow t_{j})}})-R(f_{\bm{\theta}^{(t_{j})}})+R(f_{\bm{\theta}^{(t_{j})}})$
$\leq O\Big(\frac{|\mathcal{F}_{\bm{\theta}^{(t)}}|_{\text{C}}}{n^{\alpha_{j}}}\Big)+\epsilon_{t,t_{j}}+R(f_{\bm{\theta}^{(t_{j})}})$
$=\underbrace{O\Big(\frac{|\mathcal{F}_{\bm{\theta}^{(t)}}|_{\text{C}}}{n^{\alpha_{j}}}\Big)+\underset{f_{\bm{\theta}^{(t)}}\in\mathcal{F}_{\bm{\theta}^{(t\leftarrow t_{j})}}}{\inf}R(f_{\bm{\theta}^{(t)}})}_{bound(f_{\bm{\theta}^{(t\leftarrow t_{j})}})}$

where $\epsilon_{t,t_{i}}$ is the approximation error of the function class $\mathcal{F}_{\bm{\theta}^{(t\leftarrow t_{i})}}$ with respect to $f_{\bm{\theta}^{(t_{i})}}$, and $\epsilon_{t,t_{j}}$ is defined analogously. Because (1) $\bm{\theta}^{(t_{i})}$ is a later checkpoint with higher sparsity than $\bm{\theta}^{(t_{j})}$, we have the learning speed $1\geq\alpha_{i}\geq\alpha_{j}\geq\frac{1}{2}$; and (2) $f_{\bm{\theta}^{(t_{i})}}$ has lower generalization error than $f_{\bm{\theta}^{(t_{j})}}$, the following inequality holds with high probability:

$bound(f_{\bm{\theta}^{(t\leftarrow t_{i})}})\leq bound(f_{\bm{\theta}^{(t\leftarrow t_{j})}})$

Figure 4: Weight distributions of BERT$_{\text{base}}$ pruned using different importance criteria on the RTE dataset. The left figure shows the value distribution and the right figure shows how the remaining parameters are distributed across model layers.
Figure 5: Layerwise rank distribution of BERT$_{\text{base}}$ pruned using different importance criteria on the RTE dataset.

Appendix C More Post-pruning Analyses

This section presents more visualized analyses of models sparsified by different pruning methods.

Figure 5 shows the layerwise rank distribution of BERT$_{\text{base}}$ pruned using different importance criteria on the RTE dataset. The observation here is similar to what is discussed in the main body of the paper: PINS exhibits the lowest average matrix rank in the sparsified model compared to the other two criteria.

Figure 4 illustrates the weight distributions of BERT$_{\text{base}}$ pruned using different importance criteria. From the left figure we can see that magnitude-based pruning tends to keep parameters with high absolute values, which is expected from its definition. Sensitivity and PINS produce similar weight value distributions, mainly because both methods contain the $\bm{g}_{i}\bm{\theta}_{i}$ term in their importance calculation. Despite the similarity, we can still observe that PINS produces a smoother distribution than sensitivity and covers more weights with larger absolute values.

The right figure shows the layerwise distribution of remaining parameters after pruning. A clear trend is that PINS tends to retain more parameters in the middle layers (4-7), which also coincides with the inter-model sparsity pattern analysis in the main body of our paper. Both sensitivity and PINS remove a large proportion of parameters in the top layers (10-12), while magnitude-based pruning shows no preference for particular model layers.

Appendix D Sparsity Scheduler

The proportion of remaining weights is controlled by the sparsity scheduler. We adopt the commonly used cubic sparsity schedule to progressively reach the target sparsity, i.e., the proportion of remaining parameters $r^{(t)}$ at time step $t$ within the maximum number of time steps $T$ is given by:

$r^{(t)}=\begin{cases}r_{i}&t\in[0,t_{i})\\ r_{f}+(r_{i}-r_{f})\big(\frac{T-t_{f}-t}{T-t_{f}-t_{i}}\big)^{3}&t\in[t_{i},T-t_{f})\\ r_{f}&\text{otherwise}\end{cases}$   (14)

where $r_{i}=1.0$, $r_{f}$ is the final proportion of remaining parameters, and $t_{i}$ and $t_{f}$ are the warm-up and cool-down steps.
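A direct sketch of Eq. (14) in Python (the function name and signature are ours, not from the released code):

```python
def cubic_remaining_ratio(t: int, T: int, r_f: float, t_i: int, t_f: int,
                          r_i: float = 1.0) -> float:
    """Cubic sparsity schedule of Eq. (14): fraction of weights kept at step t."""
    if t < t_i:                      # warm-up: keep everything
        return r_i
    if t < T - t_f:                  # cubic decay toward the target ratio
        frac = (T - t_f - t) / (T - t_f - t_i)
        return r_f + (r_i - r_f) * frac ** 3
    return r_f                       # cool-down: hold the final ratio
```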

Appendix E Accelerating Inference and Reducing Storage

We attain practical efficiency gains in inference time and disk storage using different sets of off-the-shelf techniques. Specifically, we use DeepSparse (https://github.com/neuralmagic/deepsparse), a sparsity-aware inference runtime, to accelerate inference of the sparse model on CPUs. We also utilize the PyTorch built-in quantization function (https://pytorch.org/docs/stable/quantization.html) and the Compressed Sparse Row (CSR) format (https://github.com/huggingface/block_movement_pruning/blob/master/Saving_PruneBERT.ipynb) to achieve a much smaller disk space requirement.
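A rough sketch of the storage side of this pipeline is shown below, assuming PyTorch dynamic quantization and SciPy's CSR format; the exact pipeline in the linked notebook may differ, and `save_sparse_int8` is an illustrative name, not part of the released code.

```python
import torch
import scipy.sparse as sp

def save_sparse_int8(model: torch.nn.Module, path: str) -> None:
    """Dynamically quantize Linear layers to INT8, then store each weight
    matrix in Compressed Sparse Row format so that pruned zeros take no space."""
    quantized = torch.quantization.quantize_dynamic(
        model, {torch.nn.Linear}, dtype=torch.qint8
    )
    csr_weights = {}
    for name, module in quantized.named_modules():
        # Quantized dynamic Linear modules expose their weight via a method
        if callable(getattr(module, "weight", None)):
            w = module.weight().int_repr().numpy()   # INT8 values of the quantized weight
            csr_weights[name] = sp.csr_matrix(w)     # zeros are dropped by the CSR layout
    torch.save(csr_weights, path)
```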