Teacher Guided Training:
An Efficient Framework for Knowledge Transfer
Abstract
The remarkable performance gains realized by large pretrained models, e.g., GPT-3, hinge on the massive amounts of data they are exposed to during training. Analogously, distilling such large models to compact models for efficient deployment also necessitates a large amount of (labeled or unlabeled) training data. In this paper, we propose the teacher-guided training (TGT) framework for training a high-quality compact model that leverages the knowledge acquired by pretrained generative models, while obviating the need to go through a large volume of data. TGT exploits the fact that the teacher has acquired a good representation of the underlying data domain, which typically corresponds to a much lower dimensional manifold than the input space. Furthermore, we can use the teacher to explore input space more efficiently through sampling or gradient-based methods; thus, making TGT especially attractive for limited data or long-tail settings. We formally capture this benefit of proposed data-domain exploration in our generalization bounds. We find that TGT can improve accuracy on several image classification benchmarks as well as a range of text classification and retrieval tasks.
1 Introduction
Recent general purpose machine learning models (e.g., BERT (Devlin et al., 2019), DALL-E (Ramesh et al., 2021), SimCLR (Chen et al., 2020a), Perceiver (Jaegle et al., 2021), GPT-3 (Brown et al., 2020)), trained on broad data at scale, have demonstrated adaptability to a diverse range of downstream tasks. Despite being trained in an unsupervised (or so-called self-supervised) fashion, these models have been shown to capture highly specialized information in their internal representations, such as relations between entities (Heinzerling and Inui, 2021) or object hierarchies from images (Weng et al., 2021).
Despite their impressive performance, the prohibitively high inference cost of such large models prevents their widespread deployment. A standard approach to reducing the inference cost while preserving performance is to train a compact (student) model via knowledge distillation (Bucilua et al., 2006; Hinton et al., 2015) from a large (teacher) model. However, existing distillation methods require a large amount of training data (labeled or unlabeled) for knowledge transfer. For each data point, the teacher must be evaluated, making the process computationally expensive (Xie et al., 2020d; He et al., 2021; Sanh et al., 2019a). This is compounded by the need to repeat the distillation process separately for every downstream task, each with its own training set. Enabling efficient distillation is thus an important challenge. Additionally, minimizing the number of distillation samples would especially benefit low-data downstream tasks (e.g., those with long tails).
Another inefficiency of standard distillation approaches is that, within each evaluation of the teacher, only the final layer output (aka logits) is utilized. This ignores potentially useful internal representations which can also be leveraged for knowledge transfer. Various extensions have been proposed in the literature along these lines (see, e.g., (Sun et al., 2020; Aguilar et al., 2020; Li et al., 2019; Sun et al., 2019) and references therein). However, despite their success, most use the teacher model in a black-box manner, and do not fully utilize the domain understanding it contains (Cho and Hariharan, 2019; Stanton et al., 2021). In these approaches, the teacher is used passively, as the input sample distribution is fixed and does not adapt to the student model's performance. Consequently, these forms of distillation do not lead to faster training of a high-performance student model.

In this work, we go beyond the passive application of large teacher models for training compact student models, and leverage the domain understanding captured by the teacher to generate new informative training instances that help the compact model achieve higher accuracy with fewer samples, thus reducing training time. In particular, we propose the teacher guided training (TGT) framework for a more efficient transfer of knowledge from large models to a compact model. TGT relies on the fact that the teacher's internal representation of data often lies on a manifold of much smaller dimension than the input dimension. Furthermore, we can use the teacher to help guide training by identifying the directions in which the student's current decision boundary starts to diverge from that of the teacher, e.g., via backpropagating through the teacher to identify regions of disagreement.
We also give a formal justification for the TGT algorithm, showing that leveraging the internal data representation of large models enables better generalization bounds for the student model. Given $n$ instances in a $d$-dimensional input space, the generalization gap for learning a Lipschitz decision boundary of an underlying classification task decays only as $\mathcal{O}(n^{-1/d})$ (Györfi et al., 2002). In contrast, assuming that the large model can learn a good data representation in a $k$-dimensional latent space, the TGT framework realizes a generalization gap of $\mathcal{O}(n^{-1/k}) + \mathcal{W}$, where $\mathcal{W}$ denotes the Wasserstein distance between the data distribution and the distribution learned by the underlying generative teacher model. Typically $k \ll d$; thus, TGT ensures much faster convergence whenever we employ a high-quality generative teacher model. This makes TGT especially attractive for low-data or long-tail regimes.
In order to realize TGT, we take advantage of the fact that most unsupervised pretrained models, such as Transformers, VAEs, and GANs, have two components: (1) an encoder that maps data to a latent representation, and (2) a decoder that transforms the latent representation back to the original data space. We utilize this latent space of the data representations learned by the teacher model to efficiently search for regions of mismatch between the teacher's and student's decision boundaries. This search can take the form of either (i) a zero-order approach involving random perturbation or (ii) a first-order method exploring along the direction of the gradient of a suitably defined distance measure between the teacher and student models.
Some of these pretrained models, particularly in NLP such as T5 (Raffel et al., 2020), can also provide labels for a downstream task and act as a sole teacher. However, our approach is sufficiently general to utilize separate pretrained models for generative and discriminative (labeler) functions (cf. Fig. 1), e.g., we employ an image BiGAN as generator and an EfficientNet as labeler for an image classification task.
Our main contributions are summarized as follows:
1. We introduce the TGT framework, a conceptually simple and scalable approach to distilling knowledge from a large teacher into a smaller student. TGT adaptively changes the distribution of distillation examples, yielding higher performing student models with fewer training examples.
2. We provide theoretical justifications for utilizing the latent space of the teacher generator in the TGT framework, which yields tighter generalization bounds.
3. We empirically demonstrate the superiority of TGT over existing state-of-the-art distillation approaches, showing results on both vision and NLP tasks, unlike most previous work, which specializes in a single domain.
2 Related Work
Our proposed TGT framework can be considered a form of data augmentation where data is dynamically added at points of current discrepancy between the teacher and student. Next, we provide a brief overview of how data augmentation has been used in the context of distillation.
Using pseudo labels. The earliest line of work involves using consistency regularization (Sajjadi et al., 2016; Laine and Aila, 2017; Tarvainen and Valpola, 2017) to obtain pseudo labels for unlabelled data where a model is expected to make consistent predictions on an unlabeled instance and its augmented versions, cf. (Miyato et al., 2019; Xie et al., 2020a; Verma et al., 2019; Berthelot et al., 2019; Sohn et al., 2020; Zhu et al., 2021, inter alia). Another approach is self-training (Xie et al., 2020d; Du et al., 2021) where a smaller teacher model is learned on the labeled data first which is then used to generate pseudo labels for a large but relevant unlabeled set. A large student model is then trained on both labeled and pseudo labeled sets. Label propagation (Iscen et al., 2019) is another direction where unlabeled instances receive pseudo labels based on neighboring labeled instances in a similarity graph constructed based on the representations from a model trained on only labeled data.
Furthermore, there has been work on learning to teach (Fan et al., 2018; Raghu et al., 2021; Pham et al., 2021), where the teacher is dynamically updated so as to provide more valuable pseudo labels based on the student's loss. Such an interactive approach presents a challenging optimization problem and potentially opens the door to borrowing techniques from reinforcement learning. In contrast, our work focuses on the setting where a high-quality pretrained teacher model is kept fixed throughout training, since updating the large teacher model is prohibitively costly or undesirable when the same model is potentially used to distill many student models. Moreover, many large models like GPT-3 may only be available through API access, making it infeasible to update the teacher.
Using pretrained models. One can use large scale pretrained class conditional generative models like BigGAN (Brock et al., 2019) or VQ-VAE2 (Razavi et al., 2019) to generate more data for augmentation. Despite evidence (Webster et al., 2019) that GANs are not memorizing training data, using them to simply augment the training dataset has limited utility when training ResNets (Ravuri and Vinyals, 2019b, a). One possible reason might be the lack of diversity (Arora et al., 2017) in data generated by GANs, especially among high density regions (Arora et al., 2018). In contrast, we use generative models to adaptively explore the local region of disagreement between teacher and student as opposed to blindly sampling from the generative model. This way we circumvent the excessive reliance on samples from high density regions which often have low diversity.
Another line of work by Chen et al. (2020b) combines unsupervised/self-supervised pretraining (on unlabeled data) with a SimCLR-based approach (Chen et al., 2020a), task-specific finetuning (on labeled data), and distillation (natural loss on labeled and distillation loss on unlabeled data). The setup considered in that work is very close to ours, with two key differences: (1) We assume access to a very high-quality teacher model, potentially trained on a much larger labeled set, to provide pseudo labels; (2) We go beyond utilizing a given unlabeled dataset from the domain of interest, exploring the dynamic generation of domain-specific unlabeled data by leveraging the representations learned by pretrained models. Additionally, our work aims to develop a theoretical framework to identify the utility of unlabeled data instances for student training, especially unlabeled instances generated based on teacher-learned representations.
Using both pseudo labels and pretrained models. The idea of combining pretrained models that generate training instances with pseudo-labelers has previously been considered under the name of the GAL framework (He et al., 2021). However, the GAL framework generates these new instances in an offline manner at the beginning of student training. In contrast, our proposed approach (cf. Fig. 1) generates new informative training instances in an online fashion, aiming to improve the student's performance while reducing its training time.
Recently, MATE-KD (Rashid et al., 2021) also considered a setup where a generator model is used to obtain new training instances based on the current student model's performance (by looking at the divergence between the student and teacher predictions). However, there are two key differences between our proposed TGT approach and the MATE-KD framework: First, their method adversarially updates the generator so as to find adversarial examples for the student, but this can cause the generator to drift away from the true data distribution. Second, they perturb in the input space itself and do not leverage the latent space of the teacher, which is the crux of our method. Further details are provided in App. A.
Another work worth mentioning is KDGAN (Wang et al., 2018), which leverages a GAN during distillation. However, it samples examples from the GAN without taking the student's performance into account. We also note works that search for adversarial examples during distillation (Heo et al., 2019; Dong et al., 2020). However, their search also does not depend on the student's performance, resulting in wasteful exploration of regions of the input space where the student is already good. Further, unlike TGT, Heo et al. (2019) and Dong et al. (2020) perform the example search in the input space, which is often inefficient due to the large ambient dimension of the input space.
Finally, data-free KD approaches perform knowledge distillation using only synthetically generated data (Nayak et al., 2019; Yoo et al., 2019; Chen et al., 2019). Unlike TGT, in these approaches the synthetic data distribution is updated at each epoch, which causes the student model to lose information over epochs and suffer accuracy degradation (Binici et al., 2022). Within this framework, Micaelli and Storkey (2019) target generating samples that would yield the maximum information gain for the student when learned; however, this approach suffers from drawbacks similar to those of MATE-KD noted above.
3 Teacher Guided Training
We begin by formally introducing our setup in Section 3.1. We then describe our proposed TGT framework in Sec. 3.2 and present a theoretical analysis in Sec. 3.3 and Sec. 3.4.
3.1 Problem setup
In this paper, we focus on a multiclass classification task where, given an instance $x \in \mathcal{X}$, the objective is to predict its true label out of $K$ potential classes. Let $P$ denote the underlying (joint) data distribution over the instance and label spaces $\mathcal{X} \times [K]$ for the task. Moreover, we use $P_{\mathcal{X}}$ and $P(\cdot \mid x)$ to denote the marginal distribution over the instance space and the conditional label distribution for a given instance $x$, respectively. A classification model $f$, with $f : \mathcal{X} \to \mathbb{R}^K$, takes in an input instance and yields scores for each of the $K$ classes. Finally, we are given a (tractable) loss function $\ell(y, f(x))$ which closely approximates the model's misclassification error on an example $(x, y)$, e.g., the softmax-based cross-entropy loss.
We assume access to $n$ i.i.d. labeled samples $S = \{(x_i, y_i)\}_{i \in [n]}$ generated from $P$. Given $S$ and a collection of allowable models $\mathcal{F}$, one typically learns a model with small misclassification error by solving the following empirical risk minimization (ERM) problem:
$\hat{f} \;=\; \operatorname*{arg\,min}_{f \in \mathcal{F}} \; \frac{1}{n} \sum_{i \in [n]} \ell\big(y_i, f(x_i)\big). \qquad (1)$
Besides the standard classification setting introduced above, in our TGT setup, we further assume access to a high quality teacher model, which has:
• Teacher generator. A generative component that captures $P_{\mathcal{X}}$ well, e.g., a Transformer, VAE, or ALI-GAN. This usually consists of an encoder $\mathcal{E} : \mathcal{X} \to \mathcal{Z}$ and a decoder $\mathcal{D} : \mathcal{Z} \to \mathcal{X}$, where $\mathcal{Z} \subseteq \mathbb{R}^k$ denotes the latent space.
• Teacher labeler. A classification network, denoted by $f^t$, with good performance on the underlying classification task. In general, our framework allows for $f^t$ to be either a head on top of the teacher generator or an independent large teacher classification model.
Given $S$ and such a teacher model, our objective is to learn a high-quality compact student (classification) model $f^s$ from a family $\mathcal{F}^s$, as assessed by its misclassification error on $P$.
3.2 Proposed approach
To train a student model $f^s \in \mathcal{F}^s$, we propose to minimize:
$\mathcal{L}(f^s) \;=\; \underbrace{\frac{1}{n}\sum_{i \in [n]} \ell\big(y_i, f^s(x_i)\big)}_{\mathcal{L}_{\mathrm{label}}} \;+\; \underbrace{\frac{1}{n}\sum_{i \in [n]} \ell_{\mathrm{d}}\big(f^t(x_i), f^s(x_i)\big)}_{\mathcal{L}_{\mathrm{distill}}} \;+\; \underbrace{\frac{1}{n}\sum_{i \in [n]} \ell_{\mathrm{d}}\big(f^t(x'_i), f^s(x'_i)\big)}_{\mathcal{L}_{\mathrm{TGT}}} \qquad (2)$
where $\ell_{\mathrm{d}}$ is a loss function that captures the mismatch between the predictions of two models, and the generated instances $x'_i$ are introduced in the subsequent passage. The first term, $\mathcal{L}_{\mathrm{label}}$, corresponds to the standard ERM problem (cf. Eq. 1). The subsequent terms, $\mathcal{L}_{\mathrm{distill}}$ and $\mathcal{L}_{\mathrm{TGT}}$, do not make use of labels. In particular, the second term, $\mathcal{L}_{\mathrm{distill}}$, corresponds to knowledge distillation (Bucilua et al., 2006; Hinton et al., 2015), where the teacher model provides supervision for the student model.
We introduce a novel third term, $\mathcal{L}_{\mathrm{TGT}}$, where the data $x'_i$ is generated based on $x_i$ using the teacher. Here, we want to generate additional informative samples which will help the student learn faster, e.g., points where it disagrees with the teacher but which still lie on the data manifold. In other words, we want to find $x'_i$ as follows:
$x'_i \;=\; \operatorname*{arg\,max}_{x \,\in\, \text{data manifold}} \; \ell_{\mathrm{d}}\big(f^t(x), f^s(x)\big) \qquad (3)$
We propose two specific approaches to generate novel samples:
1. Isotropically perturb in the latent space:
$x'_i \;=\; \mathcal{D}\big(\mathcal{E}(x_i) + \epsilon\big), \quad \text{with isotropic noise } \epsilon \text{ (e.g., } \epsilon \sim \mathcal{N}(0, \sigma^2 I)\text{)}. \qquad (4)$
This can be regarded as a zero-order search in the latent space, which satisfies the constraint of remaining within the data manifold.
2. Gradient-based exploration: Run a few iterations of gradient ascent on Eq. 3 in order to find the example that diverges most from the teacher. To enforce the manifold constraint, we run the gradient ascent in the latent space of the teacher generator, as opposed to performing gradient ascent in the instance space $\mathcal{X}$, which might move the perturbed point off the data manifold. For a high-quality teacher generator, the latent space should capture the data manifold well. To implement this, we need to backpropagate all the way through the student and teacher labeler to the teacher decoder, as shown in Fig. 1. Mathematically, it involves the following three operations:
$z^{(0)}_i = \mathcal{E}(x_i); \qquad z^{(t+1)}_i = z^{(t)}_i + \eta \, \nabla_{z} \, \ell_{\mathrm{d}}\big(f^t(\mathcal{D}(z)), f^s(\mathcal{D}(z))\big)\Big|_{z = z^{(t)}_i}; \qquad x'_i = \mathcal{D}\big(z^{(T)}_i\big). \qquad (5)$
This is akin to a first-order search in the latent space (see the sketch below).
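The following is a minimal PyTorch-style sketch of both search strategies; the module handles (`encoder`, `decoder`, `teacher`, `student`), the KL-based choice of $\ell_{\mathrm{d}}$, and the hyperparameters `sigma`, `eta`, and `num_steps` are illustrative assumptions rather than the exact implementation used in our experiments.

```python
import torch
import torch.nn.functional as F

def kl_divergence(teacher_logits, student_logits):
    # Distillation loss l_d: KL between teacher and student predictive distributions.
    return F.kl_div(F.log_softmax(student_logits, dim=-1),
                    F.softmax(teacher_logits, dim=-1),
                    reduction="batchmean")

def tgt_generate_random(x, encoder, decoder, sigma=0.1):
    # Zero-order search (Eq. 4): isotropically perturb the latent code and decode.
    with torch.no_grad():
        z = encoder(x)
        z_prime = z + sigma * torch.randn_like(z)
        return decoder(z_prime)

def tgt_generate_gradient(x, encoder, decoder, teacher, student, eta=0.1, num_steps=5):
    # First-order search (Eq. 5): gradient ascent in the latent space on the
    # teacher-student divergence, followed by decoding the final latent code.
    z = encoder(x).detach()
    for _ in range(num_steps):
        z = z.clone().requires_grad_(True)
        x_hat = decoder(z)
        loss = kl_divergence(teacher(x_hat), student(x_hat))
        (grad,) = torch.autograd.grad(loss, z)
        z = (z + eta * grad).detach()   # ascend: move toward larger disagreement
    with torch.no_grad():
        return decoder(z)
```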
Extension to discrete data. Note that perturbing an instance from a discrete domain, e.g., text data, is not as straightforward as in a continuous space. Typically, one has to resort to expensive combinatorial search or crude approximations to perform such perturbations (Tan et al., 2020; Zang et al., 2020; Ren et al., 2019). Interestingly, our approach in Eq. 4 provides a simple alternative where one performs the perturbation in the latent space, which is continuous. On the other hand, for gradient-based exploration, we assume that $\mathcal{X}$ is a differentiable space in order to calculate the necessary quantities, such as the gradients in Eq. 5. This assumption holds for various data such as images and point clouds, but not for discrete data like text. We can, however, circumvent this limitation by implementing weight sharing between the output softmax layer of the teacher's decoder and the input embedding layer of the student (and also of the teacher labeler when an independent model is used). One can then bypass the discrete space during the backward pass, similar to the ideas behind VQ-VAE (Hafner et al., 2019). Note that, during the forward pass, we still need the discrete representation for decoding, e.g., using beam search.
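As a rough illustration of this weight-sharing trick, the sketch below replaces hard token ids with the expected embedding under the decoder's output softmax so that gradients can flow through an otherwise discrete step; the function name, tensor shapes, and the assumption of a single shared embedding matrix are illustrative, not a prescribed interface.

```python
import torch
import torch.nn.functional as F

def soft_embed(decoder_logits, shared_embedding):
    # decoder_logits: (batch, seq_len, vocab_size) scores from the teacher decoder.
    # shared_embedding: (vocab_size, emb_dim) matrix shared between the decoder's
    # output softmax layer and the student's (and labeler's) input embedding layer.
    # In the backward pass we feed the expected embedding instead of discrete token
    # ids, bypassing the non-differentiable decoding step; the forward pass can
    # still use beam search to produce actual text.
    token_probs = F.softmax(decoder_logits, dim=-1)   # (batch, seq_len, vocab_size)
    return token_probs @ shared_embedding             # (batch, seq_len, emb_dim)
```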
Finally, we address the superficial resemblance between our approach and adversarial training. In adversarial training, the goal is to learn a robust classifier, i.e., to increase the margin. Towards this, for any training instance $x$, one wants to enforce model agreement in its local neighborhood, i.e., $f(x') \approx f(x)$ for every $x'$ in a small ball around $x$. One needs to carefully choose a small enough neighborhood so as to not cross the decision boundary. In contrast, we are not looking for such max-margin training, which has its own issues (cf. (Nowak-Vila et al., 2021)). We simply want to encourage agreement between the teacher and the student, i.e., $f^s(x') \approx f^t(x')$ on the generated instances $x'$. Thus, we do not have any limitation on the size of the neighborhood to consider. As a result, we can explore much bigger regions as long as we remain on the data manifold.
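Putting the pieces together, a minimal sketch of one TGT training step implementing Eq. 2 might look as follows, reusing `kl_divergence` and `tgt_generate_random` from the earlier sketch; the equal weighting of the three loss terms is an assumption for illustration, not the tuned setting used in our experiments.

```python
import torch
import torch.nn.functional as F

def tgt_training_step(x, y, student, teacher, encoder, decoder, optimizer, sigma=0.1):
    # One optimization step of the TGT objective (Eq. 2):
    # label loss + distillation loss on real data + distillation loss on generated data.
    x_gen = tgt_generate_random(x, encoder, decoder, sigma=sigma)

    student_logits = student(x)
    student_logits_gen = student(x_gen)
    with torch.no_grad():
        teacher_logits = teacher(x)
        teacher_logits_gen = teacher(x_gen)

    loss_label = F.cross_entropy(student_logits, y)                   # ERM term (Eq. 1)
    loss_distill = kl_divergence(teacher_logits, student_logits)      # standard distillation
    loss_tgt = kl_divergence(teacher_logits_gen, student_logits_gen)  # teacher-guided term

    loss = loss_label + loss_distill + loss_tgt   # equal weights assumed for illustration
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```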
3.3 Value of generating samples via the latent space
In this section, we formally show how leveraging the latent space can help learning. For this exposition, we assume that the instance space $\mathcal{X}$ is a bounded subset of $\mathbb{R}^d$. Furthermore, for directly learning in the input space, we assume that our function class $\mathcal{F}$ corresponds to all Lipschitz functions that map $\mathcal{X}$ to $\mathbb{R}^K$. Then, for any such function $f$, there are existing generalization bounds of the form (Devroye et al., 2013; Mohri et al., 2018):
$R(f) \;\le\; \hat{R}_n(f) \;+\; 2\,\mathfrak{R}_n(\ell \circ \mathcal{F}) \;+\; \mathcal{O}\Big(\sqrt{\log(1/\delta)/n}\Big),$
where $R(f)$ is the true population risk of the classifier, $\hat{R}_n(f)$ is its empirical risk, and $\mathfrak{R}_n(\ell \circ \mathcal{F})$ is the Rademacher complexity of the induced function class, which is known in our case to be $\mathcal{O}(n^{-1/d})$ (see App. B for more details). Any reduction in the Rademacher term would imply a smaller generalization gap, which is our goal.
In our TGT framework, we assume availability of a teacher that is able to learn a good representation for the underlying data distribution. In particular, we assume that, for $x \sim P_{\mathcal{X}}$, we have
$\big\| \mathcal{D}\big(\mathcal{E}(x)\big) - x \big\| \;\le\; \epsilon_{\mathrm{r}}, \qquad (6)$
i.e., for $x \sim P_{\mathcal{X}}$, applying the decoder $\mathcal{D}$ on the latent representation of $x$, as produced by the encoder $\mathcal{E}$, leads to a point that approximates $x$ with a small reconstruction error $\epsilon_{\mathrm{r}}$.
This ability of the teacher generator to model the data distribution using the latent representation can be used to reduce the complexity of the required function class. Specifically, in the TGT framework, we leverage the teacher decoder to restrict the function class to be a composition of the decoder function $\mathcal{D}$ and a learnable Lipschitz function operating on the latent space $\mathcal{Z} \subseteq \mathbb{R}^k$. Since $k \ll d$, this leads to a function class with much lower complexity. Next, we formally capture this idea for distillation with both the original samples drawn from $P_{\mathcal{X}}$ and the novel samples introduced by the teacher generator. In what follows, we only consider the distillation losses and ignore the first loss term (which depends on true labels). Our analysis can be easily extended to take the latter term into account (e.g., by using tools from Foster et al. (2019)).
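For intuition about the resulting complexity reduction, the following standard rates for classes of bounded Lipschitz functions on bounded domains (a textbook covering-number fact stated up to constants and logarithmic factors, not a statement taken from our appendix) illustrate the dimension dependence:

```latex
% Rademacher complexity of bounded Lipschitz classes on bounded domains
% (up to constants and log factors, for dimension at least 3); the latent-space
% class follows from the same argument with d replaced by the latent dimension k.
\mathfrak{R}_n\big(\mathrm{Lip}(\mathcal{X})\big) = \mathcal{O}\big(n^{-1/d}\big),
\qquad
\mathfrak{R}_n\big(\mathrm{Lip}(\mathcal{Z})\big) = \mathcal{O}\big(n^{-1/k}\big),
\qquad k \ll d .
```

Composing the latent-space class with the Lipschitz decoder and teacher labeler only rescales the effective Lipschitz constant, which is the source of the improvement formalized below.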
We start with the standard distillation in the following result.
Theorem 3.1.
Suppose a generative model with encoder $\mathcal{E}$ and decoder $\mathcal{D}$ satisfies the approximation guarantee in Eq. 6 for $P_{\mathcal{X}}$. Let $\mathcal{D}$ and the teacher labeler $f^t$ be Lipschitz functions, and let the distillation loss $\ell_{\mathrm{d}}$ satisfy Assumption C.1. Then, with probability at least $1-\delta$, the following holds for any $f^s \in \mathcal{F}^s$:
where $L_{\mathcal{Z}}$ is the effective Lipschitz constant of the induced function class that maps $\mathcal{Z}$ (the latent space of the generator) to the loss values.
Thus, we can reduce the Rademacher term from $\mathcal{O}(n^{-1/d})$ to $\mathcal{O}(n^{-1/k})$ (up to the effective Lipschitz constant $L_{\mathcal{Z}}$), which yields a significant reduction in sample complexity. However, as the teacher model is not perfect, a penalty is incurred in terms of its reconstruction and prediction errors. See Sec. C.1 for the details.
Thus far, we have not leveraged the fact that we can also use the teacher to generate further samples. Accounting instead for samples generated from the teacher generator, one can obtain a similar generalization gap for distillation based on the teacher-generated samples:
Theorem 3.2.
Let $S' = \{x'_i\}_{i \in [n]}$ be i.i.d. samples generated by the TGT framework, whose distribution is denoted by $P_{\mathcal{G}}$. Further, let $f^s$ denote the student model learned via distillation on $S'$, with $f^t$ as the teacher model and $\ell_{\mathrm{d}}$ the distillation loss satisfying Assumption C.1. Then, with probability at least $1-\delta$, we have
Remark 3.3.
Compared with the generalization gap for standard distillation (cf. Thm. 3.1), the generalization gap for TGT in Thm. 3.2 does not contain the reconstruction error term $\epsilon_{\mathrm{r}}$. Thus, by working with samples that have exact latent representations, TGT avoids this reconstruction error penalty. On the other hand, the generalization gap for TGT does have an additional term $\mathcal{W}(P_{\mathcal{X}}, P_{\mathcal{G}})$, capturing the mismatch between the original data distribution and the distribution of the samples used by TGT. However, this term becomes smaller as the teacher generator gets better at capturing the data. Note that generative models like WGAN explicitly minimize this term (Arjovsky et al., 2017).
3.4 Motivation for gradient based exploration
The results so far do not shed light on how to design an optimal sampling distribution, i.e., the search mechanism in the latent space for our TGT framework. In this regard, we look at variance-based generalization bounds (Maurer and Pontil, 2009), which were previously utilized by Menon et al. (2021a) in the context of distillation. Applying this technique in our TGT approach, we would obtain a generalization bound of the form:
(7)
where the constants depend on the covering number for the induced function class (cf. Eq. 16). Here, we note that combining Eq. 7 with Lemma D.4 translates this bound on the distillation risk into a bound on the student's misclassification error, with an additional penalty term that depends on the quality of the teacher labeler $f^t$.
Note that Eq. 7 suggests a general approach to selecting the distribution that generates the training samples. In order to ensure a small generalization gap, we need to focus on two terms that depend on this distribution: (1) the variance term; and (2) the divergence term. We note that finding a distribution that jointly minimizes both terms is a non-trivial task. That said, in our sampling approach in Eq. 5, we control the divergence term by operating in the latent space of a good-quality teacher generative model and minimize the variance by finding points with high loss values through gradient ascent, thereby striking a balance between the two objectives. We refer to Sec. C.3 for more details on the bound stated in Eq. 7.
4 Experiments
We now conduct a comprehensive empirical study of our TGT framework in order to establish that TGT (i) leads to high accuracy in transferring knowledge in low data/long-tail regimes (Sec. 4.1); (ii) effectively increases sample size (Sec. 4.2); and (iii) has wide adaptability even to discrete data domains such as text classification (Sec. 4.3) and retrieval (Sec. 4.4).
Dataset | Approach | Architecture | Balanced Accuracy | # parameters | FLOPs
---|---|---|---|---|---
ImageNet1K-LT | LDAM-DRW* (Cao et al., 2019) | ResNet-50 | 47.8 | 26 M | 4.1 B
ImageNet1K-LT | LWS (Kang et al., 2020) | ResNeXt-50 | 49.9 | 25 M | 4.2 B
ImageNet1K-LT | Logit adjustment loss* (Menon et al., 2021b) | ResNet-50 | 50.4 | 26 M | 4.1 B
ImageNet1K-LT | LDAM-DRS-RSG (Wang et al., 2021) | ResNeXt-50 | 51.8 | 25 M | 4.2 B
ImageNet1K-LT | DRAGON + Bal’Loss (Samuel et al., 2021) | ResNet-10 | 46.5 | 5.4 M | 819 M
ImageNet1K-LT | DRAGON + Bal’Loss (Samuel et al., 2021) | ResNet-50 | 53.5 | 26 M | 4.1 B
ImageNet1K-LT | Teacher (labeler) model | EfficientNet-b3 | 79.2 | 12 M | 1.8 B
ImageNet1K-LT | One-hot | MobileNetV3-0.75 | 35.5 | 4.01 M | 156 M
ImageNet1K-LT | Distillation | MobileNetV3-0.75 | 47.2 | 4.01 M | 156 M
ImageNet1K-LT | TGT (random) | MobileNetV3-0.75 | 53.2 | 4.01 M | 156 M
ImageNet1K-LT | TGT (gradient-based) | MobileNetV3-0.75 | 53.3 | 4.01 M | 156 M
SUN397-LT | LDAM-DRS-RSG (Wang et al., 2021) | ResNeXt-50 | 29.8 | 25 M | 4.2 B
SUN397-LT | CAD-VAE (Schönfeld et al., 2019) | ResNet-101 | 32.8 | 42 M | 7.6 B
SUN397-LT | LWS (Kang et al., 2020) | ResNeXt-50 | 33.9 | 25 M | 4.2 B
SUN397-LT | DRAGON + Bal’Loss (Samuel et al., 2021) | ResNet-101 | 36.1 | 42 M | 7.6 B
SUN397-LT | Teacher (labeler) model | EfficientNet-b3 | 65.3 | 12 M | 1.8 B
SUN397-LT | One-hot | MobileNetV3-0.75 | 39.3 | 4.01 M | 156 M
SUN397-LT | Distillation | MobileNetV3-0.75 | 42.2 | 4.01 M | 156 M
SUN397-LT | TGT (random) | MobileNetV3-0.75 | 44.3 | 4.01 M | 156 M
SUN397-LT | TGT (gradient-based) | MobileNetV3-0.75 | 44.7 | 4.01 M | 156 M
4.1 Long-tail image classification
Setup. We evaluate TGT by training student models on three benchmark long-tail image classification datasets: ImageNet-LT (Liu et al., 2019c), SUN-LT (Patterson and Hays, 2012), and Places-LT (Liu et al., 2019c). We employ off-the-shelf teacher models, in particular BigBiGAN (ResNet-50) (Donahue and Simonyan, 2019) and EfficientNet-B3 (Xie et al., 2020c) as the teacher generator and teacher labeler models, respectively. We utilize MobileNetV3 (Howard et al., 2019) as the compact student model architecture. The teacher labeler model is self-trained on JFT-300M (Sun et al., 2017) and then finetuned on the task-specific long-tail dataset. The teacher generator is trained on the unlabelled full version of ImageNet (Russakovsky et al., 2015).
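The balanced accuracy reported in Table 1 is commonly defined as the unweighted mean of per-class recalls, so head classes cannot dominate the score in long-tail settings. A minimal sketch of the metric is given below (the function is illustrative, not our evaluation code).

```python
import numpy as np

def balanced_accuracy(y_true, y_pred, num_classes):
    # Unweighted mean of per-class recalls: each class contributes equally,
    # regardless of how many test examples it has.
    recalls = []
    for c in range(num_classes):
        mask = (y_true == c)
        if mask.sum() > 0:
            recalls.append((y_pred[mask] == c).mean())
    return float(np.mean(recalls))
```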
Results. The results are reported in Table 1, where we compare against similarly sized baselines (we exclude gigantic transformer models); results for Places-LT are deferred to App. E due to space constraints. We see that TGT is able to effectively transfer the knowledge acquired by the teacher during its training on a huge amount of data into a significantly smaller student model, which also has a lower inference cost. TGT considerably improves performance across the board over standard distillation, even on SUN-LT and Places-LT, whose data distributions do not exactly match the distribution on which the teacher's generator was trained. Comparing TGT (random) (cf. Eq. 4) and TGT (gradient-based) (cf. Eq. 5) indicates that most of the gains come from utilizing the latent space, with the form of search being of secondary importance. Thus, for all subsequent experiments we only consider TGT (random).
Here, we note that the baselines from the literature in Table 1 rely on specialized loss functions and/or training procedures designed for the long-tail setting, whereas we do not leverage such techniques. One could potentially combine the TGT framework with a long-tail-specific loss function, as opposed to the standard cross-entropy loss, to further improve its performance. We leave this direction for future exploration.
Method | Architecture | Amazon-5 (2.5k) | Amazon-5 (3M) | IMDB | MNLI | Yelp-5 (2.5k) | Yelp-5 (650k)
---|---|---|---|---|---|---|---
UDA (Random Init) (Xie et al., 2020b) | BERT base | 55.8 | - | - | - | 58.6 | -
UDA (Pretrained) (Xie et al., 2020b) | BERT base | 62.9 | - | - | - | 67.9 | -
Layer-wise Distillation (Sun et al., 2020) | MobileBERT | - | - | 93.6 | 83.3 | - | -
MATE-KD (Rashid et al., 2021) | DistilBERT | - | - | - | 85.8 | - | -
Teacher (labeler) model | RoBERTa large | - | 67.6 | 96.2 | 90.6 | - | 72.0
One-hot (Random Init) | DistilBERT | 44.3 | 53.6 | 50.0 | 63.0 | 50.4 | 58.1
One-hot (Pretrained) | DistilBERT | 55.9 | 66.3 | 93.6 | 84.1 | 59.1 | 67.3
Distillation (Random Init) | DistilBERT | 56.5 | 65.3 | 87.9 | 77.4 | 54.8 | 69.5
Distillation (Pretrained) | DistilBERT | 60.2 | 66.8 | 94.0 | 84.5 | 63.2 | 71.4
TGT (Random Init) | DistilBERT | 61.3 | 66.6 | 91.0 | 79.3 | 62.0 | 70.4
TGT (Pretrained) | DistilBERT | 64.6 | 67.1 | 95.4 | 86.0 | 68.6 | 71.7

4.2 TGT in low-data regime
To further showcase the effectiveness of knowledge transfer via TGT, we simulate a low-data regime by varying the amount of available training data for ImageNet (Russakovsky et al., 2015) and studying its impact on the student's performance. For these experiments, we use the same model architectures as in Sec. 4.1, but finetune the teacher labeler on the entire ImageNet. We then compare the performance of the student trained via TGT with students trained via normal training (with one-hot labels) and standard distillation.
The results are shown in Fig. 2. Note that both TGT and standard distillation utilize additional training data more effectively than normal training, with TGT being the more efficient of the two. The curves show that TGT is equivalent to a 4x increase in sample size compared to normal training. This verifies that TGT generates informative training instances for the student.
4.3 Text classification
Setup. We evaluate the proposed TGT framework on four benchmark text classification datasets: Amazon-5 (Zhang et al., 2015), IMDB (Maas et al., 2011), MNLI (Williams et al., 2018), and Yelp-5 (Zhang et al., 2015). Following Xie et al. (2020a), we also consider an extremely sub-sampled version of Amazon-5 and Yelp-5 datasets consisting of only 2500 labeled examples. Similar to the image setting, we utilize off-the-shelf teacher models, in particular a BART-base (Lewis et al., 2020) and RoBERTa-large (Liu et al., 2019a) as the teacher generator and teacher labeler, respectively. Following standard practice (Rashid et al., 2021), we employ a DistilBERT (Sanh et al., 2019b) model for the student model architecture. Both teacher networks are pretrained on a very large generic text corpus of size 160GB. The teacher labeler model is finetuned on the corresponding dataset for each task. The teacher generator is not specialized to any specific classification task.
Results. The results are reported in Table 2, where we compare TGT with other data augmentation and distillation baselines. We see that TGT considerably improves performance and beats the state-of-the-art methods MATE-KD (Rashid et al., 2021) and UDA (Xie et al., 2020a). Also, note that by using TGT on a randomly initialized student, we can match the performance of normal finetuning (with one-hot labels) of a pretrained model on Amazon-5 and Yelp-5. We highlight that baselines such as MATE-KD (Rashid et al., 2021) always work with a pretrained student model. Thus, the aforementioned improvements realized by TGT with a randomly initialized student model demonstrate enormous savings in overall data and training-time requirements, as it eliminates the need for pretraining on a large corpus. This further establishes that TGT enables data-efficient knowledge transfer from the teacher to the student.
4.4 Text retrieval
Method | recall@20 | recall@100 |
---|---|---|
Teacher (labeler) model | 0.7957 | 0.8855 |
one-hot | 0.6453 | 0.8198 |
distillation (standard) | 0.7486 | 0.8608 |
uniform negatives distillation | 0.7536 | 0.8496 |
TGT distillation (ours) | 0.7659 | 0.8763 |
Setup. Finally, we evaluate TGT on Natural Questions (NQ) (Kwiatkowski et al., 2019), a text retrieval benchmark dataset. The task is to find a matching passage for a given question from a large set of candidate passages (21M). We utilize the RoBERTa-Base dual-encoder model (Oğuz et al., 2021) as our teacher labeler. For the teacher generator, we employ BART-base (Lewis et al., 2020). We utilize a DistilBERT dual-encoder model as our student architecture. We follow the standard retrieval distillation setup where the teacher labeler provides labels for all within-batch question-to-passage pairs for the student to match.
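To make the within-batch distillation objective concrete, a minimal sketch is given below; the dual-encoder handles, the dot-product scoring, and the KL-based matching loss are assumptions for illustration rather than the exact training code.

```python
import torch
import torch.nn.functional as F

def in_batch_distillation_loss(questions, passages,
                               student_q, student_p, teacher_q, teacher_p):
    # Score every question against every passage in the batch via dot products,
    # then have the student match the teacher's within-batch score distribution.
    with torch.no_grad():
        t_scores = teacher_q(questions) @ teacher_p(passages).T   # (batch, batch)
    s_scores = student_q(questions) @ student_p(passages).T       # (batch, batch)
    return F.kl_div(F.log_softmax(s_scores, dim=-1),
                    F.softmax(t_scores, dim=-1),
                    reduction="batchmean")
```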
We consider three baselines: One-hot trains the student with the original one-hot training labels, whereas Distillation utilizes the teacher labeler instead. In uniform negatives, for a given question-to-passage pair in NQ, we uniformly sample and label 2 additional passages from the entire passage corpus (21M). TGT instead dynamically generates 2 confusing passages for each question-passage pair with the BART generator, using the isotropic perturbation as per Eq. 4.
Results. Table 3 compares TGT with the other baselines. TGT significantly improves retrieval performance, closing the teacher-student gap by 37% at recall@20 and 63% at recall@100 compared to standard distillation. Unlike TGT, uniformly sampled random passages only partially help (slightly improving recall@20 but degrading recall@100 relative to standard distillation). A plausible explanation is that randomly sampled passages are rarely relevant to the matching pair, since the output space is extremely large (21M). TGT instead generates informative passages that are close to the matching pair.
5 Conclusion and Future Directions
We have introduced a simple and formally justified distillation scheme (TGT) that adaptively generates samples with the aim of closing the divergence between student and teacher predictions. Our results show it to outperform, in aggregate, existing distillation approaches. Unlike alternative methods, it is also applicable to both continuous and discrete domains, as the results on image and text data show. TGT is orthogonal to other approaches that enable efficient inference such as quantization and pruning, and combining them is an interesting avenue for future work. Another potential research direction is to employ TGT for multi-modal data which would require accommodating multiple generative models with their own latent space, raising both practical and theoretical challenges.
References
- Aguilar et al. [2020] Gustavo Aguilar, Yuan Ling, Yu Zhang, Benjamin Yao, Xing Fan, and Chenlei Guo. Knowledge distillation from internal representations. In Proceedings of the AAAI Conference on Artificial Intelligence, volume 34, pages 7350–7357, 2020.
- Arjovsky et al. [2017] Martin Arjovsky, Soumith Chintala, and Léon Bottou. Wasserstein generative adversarial networks. In International conference on machine learning, pages 214–223. PMLR, 2017.
- Arora et al. [2017] Sanjeev Arora, Rong Ge, Yingyu Liang, Tengyu Ma, and Yi Zhang. Generalization and equilibrium in generative adversarial nets (GANs). In Doina Precup and Yee Whye Teh, editors, Proceedings of the 34th International Conference on Machine Learning, volume 70 of Proceedings of Machine Learning Research, pages 224–232. PMLR, 06–11 Aug 2017.
- Arora et al. [2018] Sanjeev Arora, Andrej Risteski, and Yi Zhang. Do GANs learn the distribution? some theory and empirics. In International Conference on Learning Representations, 2018.
- Berthelot et al. [2019] David Berthelot, Nicholas Carlini, Ian Goodfellow, Nicolas Papernot, Avital Oliver, and Colin A Raffel. Mixmatch: A holistic approach to semi-supervised learning. In H. Wallach, H. Larochelle, A. Beygelzimer, F. d Alché-Buc, E. Fox, and R. Garnett, editors, Advances in Neural Information Processing Systems, volume 32. Curran Associates, Inc., 2019.
- Binici et al. [2022] Kuluhan Binici, Nam Trung Pham, Tulika Mitra, and Karianto Leman. Preventing catastrophic forgetting and distribution mismatch in knowledge distillation via synthetic data. In Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision, pages 663–671, 2022.
- Brock et al. [2019] Andrew Brock, Jeff Donahue, and Karen Simonyan. Large scale GAN training for high fidelity natural image synthesis. In International Conference on Learning Representations, 2019.
- Brown et al. [2020] Tom Brown, Benjamin Mann, Nick Ryder, Melanie Subbiah, Jared D Kaplan, Prafulla Dhariwal, Arvind Neelakantan, Pranav Shyam, Girish Sastry, Amanda Askell, Sandhini Agarwal, Ariel Herbert-Voss, Gretchen Krueger, Tom Henighan, Rewon Child, Aditya Ramesh, Daniel Ziegler, Jeffrey Wu, Clemens Winter, Chris Hesse, Mark Chen, Eric Sigler, Mateusz Litwin, Scott Gray, Benjamin Chess, Jack Clark, Christopher Berner, Sam McCandlish, Alec Radford, Ilya Sutskever, and Dario Amodei. Language models are few-shot learners. In H. Larochelle, M. Ranzato, R. Hadsell, M. F. Balcan, and H. Lin, editors, Advances in Neural Information Processing Systems, volume 33, pages 1877–1901. Curran Associates, Inc., 2020.
- Bucilua et al. [2006] Cristian Bucilua, Rich Caruana, and Alexandru Niculescu-Mizil. Model compression. In Proceedings of the 12th ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, KDD ’06, page 535–541, New York, NY, USA, 2006. Association for Computing Machinery. ISBN 1595933395. doi: 10.1145/1150402.1150464.
- Cao et al. [2019] Kaidi Cao, Colin Wei, Adrien Gaidon, Nikos Arechiga, and Tengyu Ma. Learning imbalanced datasets with label-distribution-aware margin loss. In Proceedings of the 33rd International Conference on Neural Information Processing Systems, pages 1567–1578, 2019.
- Chen et al. [2019] Hanting Chen, Yunhe Wang, Chang Xu, Zhaohui Yang, Chuanjian Liu, Boxin Shi, Chunjing Xu, Chao Xu, and Qi Tian. Data-free learning of student networks. In Proceedings of the IEEE/CVF International Conference on Computer Vision, pages 3514–3522, 2019.
- Chen et al. [2020a] Ting Chen, Simon Kornblith, Mohammad Norouzi, and Geoffrey Hinton. A simple framework for contrastive learning of visual representations. In Hal Daumé III and Aarti Singh, editors, Proceedings of the 37th International Conference on Machine Learning, volume 119 of Proceedings of Machine Learning Research, pages 1597–1607. PMLR, 13–18 Jul 2020a.
- Chen et al. [2020b] Ting Chen, Simon Kornblith, Kevin Swersky, Mohammad Norouzi, and Geoffrey E Hinton. Big self-supervised models are strong semi-supervised learners. In H. Larochelle, M. Ranzato, R. Hadsell, M. F. Balcan, and H. Lin, editors, Advances in Neural Information Processing Systems, volume 33, pages 22243–22255. Curran Associates, Inc., 2020b.
- Cho and Hariharan [2019] J. Cho and B. Hariharan. On the efficacy of knowledge distillation. In 2019 IEEE/CVF International Conference on Computer Vision (ICCV), pages 4793–4801, Los Alamitos, CA, USA, nov 2019. IEEE Computer Society.
- Devlin et al. [2019] Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. BERT: pre-training of deep bidirectional transformers for language understanding. In Jill Burstein, Christy Doran, and Thamar Solorio, editors, Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, NAACL-HLT 2019, Minneapolis, MN, USA, June 2-7, 2019, Volume 1 (Long and Short Papers), pages 4171–4186. Association for Computational Linguistics, 2019. doi: 10.18653/v1/n19-1423.
- Devroye et al. [2013] Luc Devroye, László Györfi, and Gábor Lugosi. A probabilistic theory of pattern recognition, volume 31. Springer Science & Business Media, 2013.
- Donahue and Simonyan [2019] Jeff Donahue and Karen Simonyan. Large scale adversarial representation learning. In H. Wallach, H. Larochelle, A. Beygelzimer, F. d Alché-Buc, E. Fox, and R. Garnett, editors, Advances in Neural Information Processing Systems, volume 32. Curran Associates, Inc., 2019.
- Dong et al. [2020] Zihe Dong, Xin Sun, Junyu Dong, and Haoran Zhao. Adversarial metric knowledge distillation. In 2020 the 6th International Conference on Communication and Information Processing, page 159–164, New York, NY, USA, 2020. Association for Computing Machinery. ISBN 9781450388092.
- Du et al. [2021] Jingfei Du, Edouard Grave, Beliz Gunel, Vishrav Chaudhary, Onur Celebi, Michael Auli, Veselin Stoyanov, and Alexis Conneau. Self-training improves pre-training for natural language understanding. In Proceedings of the 2021 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, pages 5408–5418, Online, June 2021. Association for Computational Linguistics. doi: 10.18653/v1/2021.naacl-main.426.
- Fan et al. [2018] Yang Fan, Fei Tian, Tao Qin, Xiang-Yang Li, and Tie-Yan Liu. Learning to teach. In International Conference on Learning Representations, 2018.
- Foret et al. [2020] Pierre Foret, Ariel Kleiner, Hossein Mobahi, and Behnam Neyshabur. Sharpness-aware minimization for efficiently improving generalization. arXiv preprint arXiv:2010.01412, 2020.
- Foster et al. [2019] Dylan J Foster, Spencer Greenberg, Satyen Kale, Haipeng Luo, Mehryar Mohri, and Karthik Sridharan. Hypothesis set stability and generalization. In H. Wallach, H. Larochelle, A. Beygelzimer, F. d’Alché Buc, E. Fox, and R. Garnett, editors, Advances in Neural Information Processing Systems, volume 32. Curran Associates, Inc., 2019.
- Gottlieb et al. [2016] Lee-Ad Gottlieb, Aryeh Kontorovich, and Robert Krauthgamer. Adaptive metric dimensionality reduction. Theoretical Computer Science, 620:105–118, 2016.
- Györfi et al. [2002] László Györfi, Michael Kohler, Adam Krzyżak, and Harro Walk. A distribution-free theory of nonparametric regression, volume 1. Springer, 2002.
- Hafner et al. [2019] Danijar Hafner, Timothy Lillicrap, Ian Fischer, Ruben Villegas, David Ha, Honglak Lee, and James Davidson. Learning latent dynamics for planning from pixels. In International Conference on Machine Learning, pages 2555–2565. PMLR, 2019.
- He et al. [2016] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition, pages 770–778, 2016.
- He et al. [2021] Xuanli He, Islam Nassar, Jamie Kiros, Gholamreza Haffari, and Mohammad Norouzi. Generate, annotate, and learn: Generative models advance self-training and knowledge distillation. arXiv preprint arXiv:2106.06168, 2021.
- Heinzerling and Inui [2021] Benjamin Heinzerling and Kentaro Inui. Language models as knowledge bases: On entity representations, storage capacity, and paraphrased queries. In Proceedings of the 16th Conference of the European Chapter of the Association for Computational Linguistics: Main Volume, pages 1772–1791, Online, April 2021. Association for Computational Linguistics. doi: 10.18653/v1/2021.eacl-main.153.
- Heo et al. [2019] Byeongho Heo, Minsik Lee, Sangdoo Yun, and Jin Young Choi. Knowledge distillation with adversarial samples supporting decision boundary. AAAI’19/IAAI’19/EAAI’19. AAAI Press, 2019. ISBN 978-1-57735-809-1.
- Hinton et al. [2015] Geoffrey Hinton, Oriol Vinyals, and Jeff Dean. Distilling the knowledge in a neural network. arXiv preprint arXiv:1503.02531, 2015.
- Howard et al. [2019] Andrew Howard, Mark Sandler, Grace Chu, Liang-Chieh Chen, Bo Chen, Mingxing Tan, Weijun Wang, Yukun Zhu, Ruoming Pang, Vijay Vasudevan, Quoc V. Le, and Hartwig Adam. Searching for mobilenetv3. In Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV), October 2019.
- Iscen et al. [2019] Ahmet Iscen, Giorgos Tolias, Yannis Avrithis, and Ondrej Chum. Label propagation for deep semi-supervised learning. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), June 2019.
- Jaegle et al. [2021] Andrew Jaegle, Felix Gimeno, Andy Brock, Oriol Vinyals, Andrew Zisserman, and Joao Carreira. Perceiver: General perception with iterative attention. In Marina Meila and Tong Zhang, editors, Proceedings of the 38th International Conference on Machine Learning, volume 139 of Proceedings of Machine Learning Research, pages 4651–4664. PMLR, 18–24 Jul 2021.
- Jia et al. [2018] Xianyan Jia, Shutao Song, Wei He, Yangzihao Wang, Haidong Rong, Feihu Zhou, Liqiang Xie, Zhenyu Guo, Yuanzhou Yang, Liwei Yu, et al. Highly scalable deep learning training system with mixed-precision: Training imagenet in four minutes. arXiv preprint arXiv:1807.11205, 2018.
- Kang et al. [2020] Bingyi Kang, Saining Xie, Marcus Rohrbach, Zhicheng Yan, Albert Gordo, Jiashi Feng, and Yannis Kalantidis. Decoupling representation and classifier for long-tailed recognition. In International Conference on Learning Representations, 2020.
- Karpukhin et al. [2020] Vladimir Karpukhin, Barlas Oğuz, Sewon Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih. Dense passage retrieval for open-domain question answering. arXiv preprint arXiv:2004.04906, 2020.
- Kwiatkowski et al. [2019] Tom Kwiatkowski, Jennimaria Palomaki, Olivia Redfield, Michael Collins, Ankur Parikh, Chris Alberti, Danielle Epstein, Illia Polosukhin, Jacob Devlin, Kenton Lee, et al. Natural questions: a benchmark for question answering research. Transactions of the Association for Computational Linguistics, 7:453–466, 2019.
- Laine and Aila [2017] Samuli Laine and Timo Aila. Temporal ensembling for semi-supervised learning. In International Conference on Learning Representations, 2017.
- Lee et al. [2019] Kenton Lee, Ming-Wei Chang, and Kristina Toutanova. Latent retrieval for weakly supervised open domain question answering. In Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics, pages 6086–6096, Florence, Italy, July 2019. Association for Computational Linguistics. doi: 10.18653/v1/P19-1612.
- Lewis et al. [2020] Mike Lewis, Yinhan Liu, Naman Goyal, Marjan Ghazvininejad, Abdelrahman Mohamed, Omer Levy, Veselin Stoyanov, and Luke Zettlemoyer. Bart: Denoising sequence-to-sequence pre-training for natural language generation, translation, and comprehension. In Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics, pages 7871–7880, 2020.
- Lewis et al. [2021] Patrick Lewis, Yuxiang Wu, Linqing Liu, Pasquale Minervini, Heinrich Küttler, Aleksandra Piktus, Pontus Stenetorp, and Sebastian Riedel. Paq: 65 million probably-asked questions and what you can do with them. Transactions of the Association for Computational Linguistics, 9:1098–1115, 2021.
- Li et al. [2019] Hao-Ting Li, Shih-Chieh Lin, Cheng-Yeh Chen, and Chen-Kuo Chiang. Layer-level knowledge distillation for deep neural network learning. Applied Sciences, 9(10):1966, 2019.
- Liu et al. [2019a] Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, and Veselin Stoyanov. RoBERTa: A robustly optimized bert pretraining approach. arXiv preprint arXiv:1907.11692, 2019a.
- Liu et al. [2019b] Ziwei Liu, Zhongqi Miao, Xiaohang Zhan, Jiayun Wang, Boqing Gong, and Stella X. Yu. Large-scale long-tailed recognition in an open world. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), June 2019b.
- Liu et al. [2019c] Ziwei Liu, Zhongqi Miao, Xiaohang Zhan, Jiayun Wang, Boqing Gong, and Stella X Yu. Large-scale long-tailed recognition in an open world. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 2537–2546, 2019c.
- Maas et al. [2011] Andrew L. Maas, Raymond E. Daly, Peter T. Pham, Dan Huang, Andrew Y. Ng, and Christopher Potts. Learning word vectors for sentiment analysis. In Proceedings of the 49th Annual Meeting of the Association for Computational Linguistics: Human Language Technologies, pages 142–150, Portland, Oregon, USA, June 2011. Association for Computational Linguistics.
- Maurer and Pontil [2009] A. Maurer and M. Pontil. Empirical bernstein bounds and sample-variance penalization. In Proceedings of the 22nd Conference on Learning Theory (COLT), June 2009.
- Menon et al. [2021a] Aditya K Menon, Ankit Singh Rawat, Sashank Reddi, Seungyeon Kim, and Sanjiv Kumar. A statistical perspective on distillation. In Marina Meila and Tong Zhang, editors, Proceedings of the 38th International Conference on Machine Learning, volume 139 of Proceedings of Machine Learning Research, pages 7632–7642. PMLR, 18–24 Jul 2021a.
- Menon et al. [2021b] Aditya Krishna Menon, Sadeep Jayasumana, Ankit Singh Rawat, Himanshu Jain, Andreas Veit, and Sanjiv Kumar. Long-tail learning via logit adjustment. In International Conference on Learning Representations, 2021b.
- Micaelli and Storkey [2019] Paul Micaelli and Amos J Storkey. Zero-shot knowledge transfer via adversarial belief matching. Advances in Neural Information Processing Systems, 32, 2019.
- Miyato et al. [2019] Takeru Miyato, Shin-Ichi Maeda, Masanori Koyama, and Shin Ishii. Virtual adversarial training: A regularization method for supervised and semi-supervised learning. IEEE Transactions on Pattern Analysis and Machine Intelligence, 41(8):1979–1993, 2019. doi: 10.1109/TPAMI.2018.2858821.
- Mohri et al. [2018] Mehryar Mohri, Afshin Rostamizadeh, and Ameet Talwalkar. Foundations of machine learning. MIT press, 2018.
- Nayak et al. [2019] Gaurav Kumar Nayak, Konda Reddy Mopuri, Vaisakh Shaj, Venkatesh Babu Radhakrishnan, and Anirban Chakraborty. Zero-shot knowledge distillation in deep networks. In International Conference on Machine Learning, pages 4743–4751. PMLR, 2019.
- Nowak-Vila et al. [2021] Alex Nowak-Vila, Alessandro Rudi, and Francis Bach. Max-margin is dead, long live max-margin! arXiv preprint arXiv:2105.15069, 2021.
- Oğuz et al. [2021] Barlas Oğuz, Kushal Lakhotia, Anchit Gupta, Patrick Lewis, Vladimir Karpukhin, Aleksandra Piktus, Xilun Chen, Sebastian Riedel, Wen-tau Yih, Sonal Gupta, et al. Domain-matched pre-training tasks for dense retrieval. arXiv preprint arXiv:2107.13602, 2021.
- Patterson and Hays [2012] Genevieve Patterson and James Hays. Sun attribute database: Discovering, annotating, and recognizing scene attributes. In 2012 IEEE Conference on Computer Vision and Pattern Recognition, pages 2751–2758. IEEE, 2012.
- Pham et al. [2021] Hieu Pham, Zihang Dai, Qizhe Xie, and Quoc V. Le. Meta pseudo labels. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), pages 11557–11568, June 2021.
- Raffel et al. [2020] Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang, Michael Matena, Yanqi Zhou, Wei Li, and Peter J. Liu. Exploring the limits of transfer learning with a unified text-to-text transformer. Journal of Machine Learning Research, 21(140):1–67, 2020.
- Raghu et al. [2021] Aniruddh Raghu, Maithra Raghu, Simon Kornblith, David Duvenaud, and Geoffrey Hinton. Teaching with commentaries. In International Conference on Learning Representations, 2021.
- Ramesh et al. [2021] Aditya Ramesh, Mikhail Pavlov, Gabriel Goh, Scott Gray, Chelsea Voss, Alec Radford, Mark Chen, and Ilya Sutskever. Zero-shot text-to-image generation. In Marina Meila and Tong Zhang, editors, Proceedings of the 38th International Conference on Machine Learning, volume 139 of Proceedings of Machine Learning Research, pages 8821–8831. PMLR, 18–24 Jul 2021.
- Rashid et al. [2021] Ahmad Rashid, Vasileios Lioutas, and Mehdi Rezagholizadeh. MATE-KD: Masked adversarial TExt, a companion to knowledge distillation. In Proceedings of the 59th Annual Meeting of the Association for Computational Linguistics and the 11th International Joint Conference on Natural Language Processing (Volume 1: Long Papers), pages 1062–1071, Online, August 2021. Association for Computational Linguistics. doi: 10.18653/v1/2021.acl-long.86.
- Ravuri and Vinyals [2019a] Suman Ravuri and Oriol Vinyals. Classification accuracy score for conditional generative models. In H. Wallach, H. Larochelle, A. Beygelzimer, F. d Alché-Buc, E. Fox, and R. Garnett, editors, Advances in Neural Information Processing Systems, volume 32. Curran Associates, Inc., 2019a.
- Ravuri and Vinyals [2019b] Suman Ravuri and Oriol Vinyals. Seeing is not necessarily believing: Limitations of biggans for data augmentation. Learning from Limited Labeled Data: ICLR 2019 Workshop, 2019b.
- Razavi et al. [2019] Ali Razavi, Aaron van den Oord, and Oriol Vinyals. Generating diverse high-fidelity images with vq-vae-2. In H. Wallach, H. Larochelle, A. Beygelzimer, F. d Alché-Buc, E. Fox, and R. Garnett, editors, Advances in Neural Information Processing Systems, volume 32. Curran Associates, Inc., 2019.
- Ren et al. [2019] Shuhuai Ren, Yihe Deng, Kun He, and Wanxiang Che. Generating natural language adversarial examples through probability weighted word saliency. In Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics, pages 1085–1097, Florence, Italy, July 2019. Association for Computational Linguistics.
- Russakovsky et al. [2015] Olga Russakovsky, Jia Deng, Hao Su, Jonathan Krause, Sanjeev Satheesh, Sean Ma, Zhiheng Huang, Andrej Karpathy, Aditya Khosla, Michael Bernstein, Alexander C. Berg, and Li Fei-Fei. ImageNet Large Scale Visual Recognition Challenge. International Journal of Computer Vision (IJCV), 115(3):211–252, 2015. doi: 10.1007/s11263-015-0816-y.
- Sajjadi et al. [2016] Mehdi Sajjadi, Mehran Javanmardi, and Tolga Tasdizen. Regularization with stochastic transformations and perturbations for deep semi-supervised learning. In D. Lee, M. Sugiyama, U. Luxburg, I. Guyon, and R. Garnett, editors, Advances in Neural Information Processing Systems, volume 29. Curran Associates, Inc., 2016.
- Samuel et al. [2021] Dvir Samuel, Yuval Atzmon, and Gal Chechik. From generalized zero-shot learning to long-tail with class descriptors. In Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision (WACV), 2021.
- Sanh et al. [2019a] Victor Sanh, Lysandre Debut, Julien Chaumond, and Thomas Wolf. Distilbert, a distilled version of BERT: smaller, faster, cheaper and lighter. arXiv preprint arXiv:1910.01108, 2019a.
- Sanh et al. [2019b] Victor Sanh, Lysandre Debut, Julien Chaumond, and Thomas Wolf. Distilbert, a distilled version of bert: smaller, faster, cheaper and lighter. arXiv preprint arXiv:1910.01108, 2019b.
- Schönfeld et al. [2019] Edgar Schönfeld, Sayna Ebrahimi, Samarth Sinha, Trevor Darrell, and Zeynep Akata. Generalized zero-shot learning via aligned variational autoencoders. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), 2019.
- Sohn et al. [2020] Kihyuk Sohn, David Berthelot, Nicholas Carlini, Zizhao Zhang, Han Zhang, Colin A Raffel, Ekin Dogus Cubuk, Alexey Kurakin, and Chun-Liang Li. Fixmatch: Simplifying semi-supervised learning with consistency and confidence. In H. Larochelle, M. Ranzato, R. Hadsell, M. F. Balcan, and H. Lin, editors, Advances in Neural Information Processing Systems, volume 33, pages 596–608. Curran Associates, Inc., 2020.
- Stanton et al. [2021] Samuel Don Stanton, Pavel Izmailov, Polina Kirichenko, Alexander A Alemi, and Andrew Gordon Wilson. Does knowledge distillation really work? In A. Beygelzimer, Y. Dauphin, P. Liang, and J. Wortman Vaughan, editors, Advances in Neural Information Processing Systems, 2021.
- Sun et al. [2017] C. Sun, A. Shrivastava, S. Singh, and A. Gupta. Revisiting unreasonable effectiveness of data in deep learning era. In 2017 IEEE International Conference on Computer Vision (ICCV), pages 843–852, 2017. doi: 10.1109/ICCV.2017.97.
- Sun et al. [2019] Siqi Sun, Yu Cheng, Zhe Gan, and Jingjing Liu. Patient knowledge distillation for BERT model compression. In Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP), pages 4323–4332, Hong Kong, China, November 2019. Association for Computational Linguistics.
- Sun et al. [2020] Zhiqing Sun, Hongkun Yu, Xiaodan Song, Renjie Liu, Yiming Yang, and Denny Zhou. MobileBERT: a compact task-agnostic BERT for resource-limited devices. In Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics, pages 2158–2170, Online, July 2020. Association for Computational Linguistics.
- Tan et al. [2020] Samson Tan, Shafiq Joty, Min-Yen Kan, and Richard Socher. It’s morphin’ time! Combating linguistic discrimination with inflectional perturbations. In Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics, pages 2920–2935, Online, July 2020. Association for Computational Linguistics.
- Tarvainen and Valpola [2017] Antti Tarvainen and Harri Valpola. Mean teachers are better role models: Weight-averaged consistency targets improve semi-supervised deep learning results. In I. Guyon, U. V. Luxburg, S. Bengio, H. Wallach, R. Fergus, S. Vishwanathan, and R. Garnett, editors, Advances in Neural Information Processing Systems, volume 30. Curran Associates, Inc., 2017.
- Verma et al. [2019] Vikas Verma, Alex Lamb, Juho Kannala, Yoshua Bengio, and David Lopez-Paz. Interpolation consistency training for semi-supervised learning. In Proceedings of the Twenty-Eighth International Joint Conference on Artificial Intelligence (IJCAI-19), pages 3635–3641, 2019. International Joint Conference on Artificial Intelligence, IJCAI ; Conference date: 10-08-2019 Through 16-08-2019.
- Villani [2008] C. Villani. Optimal transport: old and new. Springer Verlag, 2008.
- Wang et al. [2021] Jianfeng Wang, Thomas Lukasiewicz, Xiaolin Hu, Jianfei Cai, and Zhenghua Xu. Rsg: A simple but effective module for learning imbalanced datasets. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), pages 3784–3793, June 2021.
- Wang et al. [2018] Xiaojie Wang, Rui Zhang, Yu Sun, and Jianzhong Qi. Kdgan: Knowledge distillation with generative adversarial networks. In S. Bengio, H. Wallach, H. Larochelle, K. Grauman, N. Cesa-Bianchi, and R. Garnett, editors, Advances in Neural Information Processing Systems, volume 31. Curran Associates, Inc., 2018.
- Webster et al. [2019] Ryan Webster, Julien Rabin, Loïc Simon, and Frédéric Jurie. Detecting overfitting of deep generative networks via latent recovery. In 2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), pages 11265–11274, 2019. doi: 10.1109/CVPR.2019.01153.
- Weng et al. [2021] Zhenzhen Weng, Mehmet Giray Ogut, Shai Limonchik, and Serena Yeung. Unsupervised discovery of the long-tail in instance segmentation using hierarchical self-supervision. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 2603–2612, 2021.
- Williams et al. [2018] Adina Williams, Nikita Nangia, and Samuel Bowman. A broad-coverage challenge corpus for sentence understanding through inference. In Proceedings of the 2018 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long Papers), pages 1112–1122. Association for Computational Linguistics, 2018.
- Xie et al. [2020a] Qizhe Xie, Zihang Dai, Eduard Hovy, Thang Luong, and Quoc Le. Unsupervised data augmentation for consistency training. In H. Larochelle, M. Ranzato, R. Hadsell, M. F. Balcan, and H. Lin, editors, Advances in Neural Information Processing Systems, volume 33, pages 6256–6268. Curran Associates, Inc., 2020a.
- Xie et al. [2020b] Qizhe Xie, Zihang Dai, Eduard Hovy, Thang Luong, and Quoc Le. Unsupervised data augmentation for consistency training. Advances in Neural Information Processing Systems, 33, 2020b.
- Xie et al. [2020c] Qizhe Xie, Minh-Thang Luong, Eduard Hovy, and Quoc V. Le. Self-training with noisy student improves imagenet classification. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), June 2020c.
- Xie et al. [2020d] Qizhe Xie, Minh-Thang Luong, Eduard Hovy, and Quoc V. Le. Self-training with noisy student improves imagenet classification. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), June 2020d.
- Yoo et al. [2019] Jaemin Yoo, Minyong Cho, Taebum Kim, and U Kang. Knowledge extraction with no observable data. Advances in Neural Information Processing Systems, 32, 2019.
- Zang et al. [2020] Yuan Zang, Fanchao Qi, Chenghao Yang, Zhiyuan Liu, Meng Zhang, Qun Liu, and Maosong Sun. Word-level textual adversarial attacking as combinatorial optimization. In Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics, pages 6066–6080, Online, July 2020. Association for Computational Linguistics. doi: 10.18653/v1/2020.acl-main.540.
- Zhang et al. [2015] Xiang Zhang, Junbo Zhao, and Yann LeCun. Character-level convolutional networks for text classification. Advances in neural information processing systems, 28:649–657, 2015.
- Zhu et al. [2021] Zhaowei Zhu, Tianyi Luo, and Yang Liu. The rich get richer: Disparate impact of semi-supervised learning. arXiv preprint arXiv:2110.06282, 2021.
Appendix A Further comparison with MATE-KD
MATE-KD [Rashid et al., 2021] alternately trains a generator model and a student model, with the aim of generating maximally adversarial examples for the student during training. This can cause the generator to drift away from the true data distribution. In contrast, we keep the pre-trained teacher-generator model fixed throughout the student's training. Our objective in employing the generator model is to leverage the domain knowledge it has already acquired during pre-training. While we do want to generate 'hard' instances for the student, we also want those instances to be relevant to the underlying task. Keeping the generator fixed thus acts as a regularizer, ensuring that the training instances the student encounters do not introduce domain mismatch.
With the objective of producing new informative training instances that remain in-domain, we introduce perturbations in the latent space realized by the encoder of the teacher-generator model (see Figure 1). This differs from directly perturbing an original training instance in the input space itself, as done by MATE-KD. As evident from our theoretical analysis and empirical evaluation, for a fixed teacher-generator model, perturbing in the latent space leads to more informative data augmentation and enables good performance in both the image and text domains.
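To make the contrast concrete, here is a minimal PyTorch sketch of the two perturbation strategies; `encoder` and `decoder` stand in for the fixed components of the pre-trained teacher-generator (cf. Figure 1) and are hypothetical interfaces, not our actual implementation.

```python
import torch

def perturb_in_input_space(x, eps=0.1):
    # MATE-KD-style: perturb the raw input directly; the perturbed point
    # may fall off the data manifold.
    return x + eps * torch.randn_like(x)

def perturb_in_latent_space(x, encoder, decoder, eps=0.1):
    # TGT-style: encode onto the latent space learned by the fixed
    # teacher-generator, perturb there, and decode back, so the new
    # instance stays close to the data domain.
    with torch.no_grad():
        z = encoder(x)                   # map input onto the latent manifold
        z_perturbed = z + eps * torch.randn_like(z)
        return decoder(z_perturbed)      # decode back to an in-domain instance
```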
Appendix B Background and notation
For $a, b \ge 0$, we use $a \lesssim b$ to denote that there exists a constant $C > 0$ such that $a \le C\, b$.
Given a collection of i.i.d. random variables $S = \{s_1, \dots, s_n\}$ generated from a distribution $P$ and a function $h$, we define the empirical mean of $h$ as
$$\hat{\mathbb{E}}_{S}[h] = \frac{1}{n}\sum_{i=1}^{n} h(s_i). \qquad (8)$$
For the underlying multiclass classification problem defined by the distribution $\mathcal{D}$ over $\mathcal{X} \times \mathcal{Y}$, we assume that the label set with $K$ classes takes the form $\mathcal{Y} = [K] := \{1, \dots, K\}$. We use $\mathcal{F}$ to denote the collection of potential classification models that the learning method is allowed to select from, namely the function class or hypothesis set:
$$\mathcal{F} \subseteq \{ f : \mathcal{X} \to \mathbb{R}^{K} \}, \qquad (9)$$
which is a subset of all functions that map elements of the instance space $\mathcal{X}$ to elements of $\mathbb{R}^{K}$.
Given a classification loss function $\ell : \mathbb{R}^{K} \times \mathcal{Y} \to \mathbb{R}_{\ge 0}$, a model $f \in \mathcal{F}$, and a sample $S = \{(x_i, y_i)\}_{i \in [n]}$ generated from $\mathcal{D}$, we define the empirical risk for $f$ as follows.
$$\hat{R}_{S, \ell}(f) = \frac{1}{n}\sum_{i=1}^{n} \ell\big(f(x_i), y_i\big). \qquad (10)$$
Further, we define the population risk for $f$ associated with the data distribution $\mathcal{D}$ as follows.
$$R_{\ell}(f) = \mathbb{E}_{(x, y) \sim \mathcal{D}}\big[\ell\big(f(x), y\big)\big]. \qquad (11)$$
Note that, when the loss function $\ell$ is clear from the context, we drop $\ell$ from the notation and simply use $\hat{R}_{S}(f)$ and $R(f)$ to denote the empirical and population risks for $f$, respectively.
Given a function class $\mathcal{F}$, the loss function $\ell$ induces the following function class.
$$\ell \circ \mathcal{F} = \big\{ (x, y) \mapsto \ell\big(f(x), y\big) : f \in \mathcal{F} \big\}. \qquad (12)$$
Definition B.1 (Rademacher complexity of $\ell \circ \mathcal{F}$).
Given a sample $S = \{(x_i, y_i)\}_{i \in [n]}$ and a vector $\sigma = (\sigma_1, \dots, \sigma_n)$ of i.i.d. Rademacher random variables (uniform over $\{-1, +1\}$), the empirical Rademacher complexity and Rademacher complexity are defined as
$$\hat{\mathfrak{R}}_{S}(\ell \circ \mathcal{F}) = \mathbb{E}_{\sigma}\Big[\sup_{f \in \mathcal{F}} \frac{1}{n}\sum_{i=1}^{n} \sigma_i\, \ell\big(f(x_i), y_i\big)\Big]
\quad\text{and}\quad
\mathfrak{R}_{n}(\ell \circ \mathcal{F}) = \mathbb{E}_{S \sim \mathcal{D}^{n}}\big[\hat{\mathfrak{R}}_{S}(\ell \circ \mathcal{F})\big]. \qquad (13)$$
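As an illustration of Definition B.1, the following sketch estimates the empirical Rademacher complexity of a small finite function class by Monte Carlo sampling of the Rademacher signs; the helper name and the toy loss matrix are ours, not part of the paper's code.

```python
import numpy as np

def empirical_rademacher_complexity(losses, num_draws=1000, seed=0):
    """Monte Carlo estimate of the empirical Rademacher complexity of a
    finite function class, given the per-sample losses of each function.

    losses: array of shape (num_functions, n) with entries l(f(x_i), y_i).
    """
    rng = np.random.default_rng(seed)
    num_functions, n = losses.shape
    estimates = []
    for _ in range(num_draws):
        sigma = rng.choice([-1.0, 1.0], size=n)      # i.i.d. Rademacher signs
        # sup over the (finite) function class of the signed empirical average
        estimates.append(np.max(losses @ sigma / n))
    return float(np.mean(estimates))

# Toy usage: 5 candidate functions evaluated on 100 samples.
toy_losses = np.random.default_rng(1).uniform(0.0, 1.0, size=(5, 100))
print(empirical_rademacher_complexity(toy_losses))
```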
Let $S^{u} = \{\tilde{x}_1, \dots, \tilde{x}_m\}$ be a set of unlabeled samples generated from $\mathcal{D}_{\mathcal{X}}$, the marginal of $\mathcal{D}$ over the instance space. Then, given a teacher model $f^{t}$ and a distillation loss $\ell^{\rm d}$, we define the empirical (distillation) risk for $f \in \mathcal{F}$ to be
$$\hat{R}^{\rm d}_{S^{u}, \ell^{\rm d}}(f) = \frac{1}{m}\sum_{i=1}^{m} \ell^{\rm d}\big(f(\tilde{x}_i), f^{t}(\tilde{x}_i)\big). \qquad (14)$$
Accordingly, the population (distillation) risk for $f$ is defined as
$$R^{\rm d}_{\ell^{\rm d}}(f) = \mathbb{E}_{x \sim \mathcal{D}_{\mathcal{X}}}\big[\ell^{\rm d}\big(f(x), f^{t}(x)\big)\big]. \qquad (15)$$
Again, when $\ell^{\rm d}$ is clear from the context, we simply use $\hat{R}^{\rm d}_{S^{u}}(f)$ and $R^{\rm d}(f)$ to denote the empirical and population distillation risks for $f$, respectively.
Note that, for a (student) function class $\mathcal{F}$ and a teacher model $f^{t}$, $\ell^{\rm d}$ produces an induced function class $\ell^{\rm d} \circ \mathcal{F}$, defined as
$$\ell^{\rm d} \circ \mathcal{F} = \big\{ x \mapsto \ell^{\rm d}\big(f(x), f^{t}(x)\big) : f \in \mathcal{F} \big\}. \qquad (16)$$
Definition B.2 (Rademacher complexity of $\ell^{\rm d} \circ \mathcal{F}$).
Given a sample $S^{u} = \{\tilde{x}_1, \dots, \tilde{x}_m\}$ and a vector $\sigma = (\sigma_1, \dots, \sigma_m)$ of i.i.d. Rademacher random variables, the empirical Rademacher complexity and Rademacher complexity are defined as
$$\hat{\mathfrak{R}}_{S^{u}}(\ell^{\rm d} \circ \mathcal{F}) = \mathbb{E}_{\sigma}\Big[\sup_{f \in \mathcal{F}} \frac{1}{m}\sum_{i=1}^{m} \sigma_i\, \ell^{\rm d}\big(f(\tilde{x}_i), f^{t}(\tilde{x}_i)\big)\Big] \qquad (17)$$
and
$$\mathfrak{R}_{m}(\ell^{\rm d} \circ \mathcal{F}) = \mathbb{E}_{S^{u} \sim \mathcal{D}_{\mathcal{X}}^{m}}\big[\hat{\mathfrak{R}}_{S^{u}}(\ell^{\rm d} \circ \mathcal{F})\big], \qquad (18)$$
respectively.
Appendix C Deferred proofs from Section 3
C.1 Proof of Theorem 3.1
In this subsection, we present a general version of Theorem 3.1. Before that, we state the following relevant assumption on the distillation loss $\ell^{\rm d}$.
Assumption C.1.
Let $\ell$ be a bounded classification loss function. For a teacher function $f^{t}$, the distillation loss takes the form
$$\ell^{\rm d}\big(f(x), f^{t}(x)\big) = \big\langle f^{t}(x),\, \ell\big(f(x), \cdot\big) \big\rangle = \sum_{y \in [K]} f^{t}(x)_y\, \ell\big(f(x), y\big).$$
Remark C.2.
Note that the cross-entropy loss, one of the most common choices for the distillation loss, indeed satisfies Assumption C.1. (For the sake of brevity, we include the softmax operation in the definitions of $f$ and $f^{t}$, i.e., $f(x)$ and $f^{t}(x)$ are valid probability distributions over $\mathcal{Y}$.)
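For concreteness, here is a minimal sketch of a distillation loss of this form, with the softmax folded into the teacher and student outputs as in Remark C.2; the function name is ours and temperature scaling is omitted.

```python
import torch
import torch.nn.functional as F

def distillation_cross_entropy(student_logits, teacher_logits):
    # Cross-entropy between the teacher's and the student's predictive
    # distributions (softmax applied inside, as in Remark C.2).
    teacher_probs = F.softmax(teacher_logits, dim=-1)
    student_log_probs = F.log_softmax(student_logits, dim=-1)
    return -(teacher_probs * student_log_probs).sum(dim=-1).mean()
```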
The following result is a general version of Theorem 3.1 in the main body.
Theorem C.3.
Suppose that a generator with the given encoder and decoder ensures the approximation guarantee in Eq. 6. Let the decoder and the teacher labeler be Lipschitz functions, let the (student) function class consist of Lipschitz functions, and let the distillation loss be Lipschitz. Then, with probability at least $1 - \delta$, the following holds for any function in the class.
(19)
where denotes the effective Lipschitz constant of the induced function class in Eq. 16. Additionally, if the distillation loss satisfies Assumption C.1 with a classification loss , then Eq. 19 further implies the following.
(20) |
Proof.
Note that
(21) |
where the first step follows from the definition of the induced function class in Eq. 16 and the remaining steps follow from the standard symmetrization argument [Devroye et al., 2013, Mohri et al., 2018]. Next, we turn our focus to the empirical Rademacher complexity. Recall that the sample contains i.i.d. points generated from the underlying distribution. We define another set of points
It follows from our assumption on the quality of the generator (cf. Eq. 6) that
(22) |
Note that
where $\sigma$ denotes a vector of i.i.d. Rademacher random variables.
(23) |
where , for . By definition of , for some . Now, we can define a new function class from to :
(24) |
Therefore, it follows from Eq. 23 and Eq. 24 that
(25) |
where . It follows from the standard concentration results for empirical Rademacher complexity around Rademacher complexity that with probability at least ,
(26) |
C.2 Proof of Theorem 3.2
Here, we present the following result, which along with Remark C.5 implies Theorem 3.2 stated in the main body.
Theorem C.4.
Let be i.i.d. samples generated from a distribution . Further, let denote the student model learned via distillation on , with and as the teacher model and distillation loss, respectively. Then, with probability at least , we have
(28) |
where the bound is in terms of the Rademacher complexity of the induced function class defined in Eq. 16. If the training sample is constructed with the TGT framework based on a generator with the given encoder and decoder, then Eq. 28 further specializes to
(29)
where the bound involves the following induced function class, defined from the latent space of the generator to the reals.
(30) |
Proof.
Note that
(31) |
where the last two inequalities follow from the definition of the induced class (cf. Eq. 16) and the standard symmetrization argument [Devroye et al., 2013, Mohri et al., 2018], respectively.
Now, the standard concentration results for the empirical Rademacher complexity imply that, with probability at least $1 - \delta$, we have the following.
(32)
(33) |
It follows from Lemma D.3 that
(34) |
Now the first part of Theorem C.4, as stated in Eq. 28, follows by combining Eq. 31, Eq. 32, and Eq. 34.
We now focus on establishing Eq. 29. Note that, for a sample generated by the TGT framework, there exists such that
(35) |
Thus,
(36) |
where the last step employs Eq. 35. Thus, combining the two preceding displays gives us that
(37) |
Now, as in the proof of Eq. 28, we can invoke Lemma D.3 and the concentration result for the empirical Rademacher complexity to obtain the desired result in Eq. 29 from Eq. 37. ∎
C.3 Weighted ERM: An alternative training procedure for TGT
Note that, given the samples generated by the TGT framework and a teacher labeler, we minimize the following empirical risk for student training:
(38) |
However, as we notice in Theorem C.4, this leads to an additional penalty term in the generalization bound. One standard approach to address this issue is to consider the following weighted empirical risk.
(39) |
where the weights are the ratio of the probability density functions (pdfs) of the data distribution and the sampling distribution. (Note that this formulation assumes the data distribution is absolutely continuous w.r.t. the sampling distribution; one can replace the pdfs with probability mass functions if the distributions are discrete.) Accordingly, we define a new induced function class related to the weighted empirical risk:
(40) |
Importantly, we have
(41) |
Thus, following the analysis utilized in Theorem C.4, one can obtain a high-probability generalization bound of the form
(42) |
which avoids the additional penalty term.
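For illustration, a minimal sketch of the weighted empirical risk in Eq. 39, assuming the density ratio between the data and sampling distributions were available (which, as noted below, it generally is not); the names are ours and the distillation loss is the cross-entropy form from Remark C.2.

```python
import torch
import torch.nn.functional as F

def weighted_distillation_risk(student_logits, teacher_logits, density_ratio):
    """Weighted empirical distillation risk in the spirit of Eq. 39.

    density_ratio: tensor of shape (batch,) holding the (generally unknown)
    ratio p_data(x_i) / p_sampling(x_i) between the data distribution and
    the distribution the training instances were drawn from.
    """
    teacher_probs = F.softmax(teacher_logits, dim=-1)
    per_example_loss = -(teacher_probs
                         * F.log_softmax(student_logits, dim=-1)).sum(dim=-1)
    # Importance-weight each example so the weighted empirical risk is an
    # unbiased estimate of the population risk under the data distribution.
    return (density_ratio * per_example_loss).mean()
```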
In what follows, we explore an alternative perspective to highlight the importance of the sampling approach adopted by (gradient-based) TGT. By leveraging the variance-based generalization bounds of Maurer and Pontil [2009] that were previously utilized by Menon et al. [2021a] in the context of distillation, we obtain the following result for the weighted empirical risk in Eq. 39.
Proposition C.6.
Proof.
Remark C.7.
Eq. 43 suggests a general approach for selecting the distribution that generates the training samples. In order to ensure a small generalization gap, it is desirable that the variance term be as small as possible. Note that the distribution minimizing this variance takes the form
(45)
This resembles the Lagrangian form of Eq. 3. Interestingly, the TGT framework with gradient-based sampling (cf. Eq. 5) focuses on generating samples that maximize the right-hand side of Eq. 45 by first taking a sample generated according to the base sampling distribution and then perturbing it in the latent space to maximize the loss. Thus, the resulting sampling distribution has a pdf that aims to approximate the variance-minimizing pdf in Eq. 45.
Here it is worth pointing out that, since the exact forms of the densities are generally not available during training, it is not straightforward to directly optimize the weighted risk introduced in Eq. 39.
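The following sketch illustrates gradient-based exploration in the generator's latent space in the spirit of Eq. 5, producing samples on which the current student's distillation loss is large; `decoder`, `teacher`, `student`, and `loss_fn` are hypothetical stand-ins rather than our exact implementation.

```python
import torch

def gradient_based_latent_sample(z_init, decoder, teacher, student,
                                 loss_fn, step_size=0.01, num_steps=2):
    # Ascend the distillation loss in the generator's latent space so that
    # the decoded sample is "hard" for the current student (cf. Eq. 5),
    # while the fixed decoder keeps it close to the data domain.
    z = z_init.clone().detach().requires_grad_(True)
    for _ in range(num_steps):
        x = decoder(z)
        loss = loss_fn(student(x), teacher(x))
        grad, = torch.autograd.grad(loss, z)
        z = (z + step_size * grad).detach().requires_grad_(True)
    return decoder(z).detach()
```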
Remark C.8.
Note that, as introduced in Section 3, the TGT framework optimizes the empirical risk in Eq. 38 as opposed to minimizing Eq. 39. In this case, one can obtain a variance-based bound analogous to Eq. 43 that takes the form:
(46) |
where (II) denotes
with constants depending on the covering number of the induced function class (cf. Eq. 16). Notably, this bound again incurs a penalty term, which is expected to be small for our TGT-based sampling distribution when we employ a high-quality teacher generator.
Appendix D Toolbox
This section presents the definitions and lemmas that we utilize to establish our theoretical results presented in Sec. 3 (and restated in App. C).
Definition D.1 (Wasserstein-1 metric).
Let $(\mathcal{S}, d)$ be a metric space. Given two probability distributions $P$ and $Q$ over $\mathcal{S}$, the Wasserstein-1 distance between $P$ and $Q$ is defined as follows.
$$W_1(P, Q) = \inf_{\gamma \in \Gamma(P, Q)} \mathbb{E}_{(s, s') \sim \gamma}\big[d(s, s')\big], \qquad (47)$$
where $\Gamma(P, Q)$ denotes the set of all joint distributions over $\mathcal{S} \times \mathcal{S}$ that have $P$ and $Q$ as their marginals.
Lemma D.2 (Kantorovich-Rubinstein duality [Villani, 2008]).
Let $\mathrm{Lip}_1(d)$ denote the set of all 1-Lipschitz functions in the metric $d$, i.e., for any $h \in \mathrm{Lip}_1(d)$ and $s, s' \in \mathcal{S}$,
$$|h(s) - h(s')| \le d(s, s'). \qquad (48)$$
Then,
$$W_1(P, Q) = \sup_{h \in \mathrm{Lip}_1(d)} \Big\{ \mathbb{E}_{s \sim P}[h(s)] - \mathbb{E}_{s \sim Q}[h(s)] \Big\}. \qquad (49)$$
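As a small numerical illustration of Definition D.1 for one-dimensional empirical distributions, using SciPy's built-in estimator (this example is ours, not from the paper's experiments):

```python
import numpy as np
from scipy.stats import wasserstein_distance

rng = np.random.default_rng(0)
samples_p = rng.normal(loc=0.0, scale=1.0, size=10_000)   # samples from P
samples_q = rng.normal(loc=0.5, scale=1.0, size=10_000)   # samples from Q

# For 1-D empirical distributions this computes the Wasserstein-1 distance;
# for two Gaussians differing only in their means, W1 equals |mu_P - mu_Q| = 0.5.
print(wasserstein_distance(samples_p, samples_q))
```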
Lemma D.3.
Let a loss function be employed during distillation. For a given teacher and a function class, we assume that the induced function class
(50) |
is contained in the class of -Lipschitz functions with respect to a metric . Then, for any two distributions and , we have
(51) |
where the right-hand side involves the Wasserstein-1 metric between the two distributions (cf. Definition D.1).
Proof.
Note that
(52) |
where the first step follows by dividing and multiplying by the Lipschitz constant; the second follows since each function in the induced class is Lipschitz; and the last follows from Lemma D.2. ∎
Lemma D.4.
Let the distillation loss satisfy Assumption C.1 with a bounded loss function . Then, given a teacher and a student model , we have
(53) |
where is treated as a vector in .
Proof.
Note that
(54) |
where the inequality follows from the Cauchy-Schwarz inequality. Now the statement of Lemma D.4 follows from the assumption that the loss is bounded. ∎
Appendix E Additional experiments
E.1 Long-tail image classification
Results on Places365-LT:

| Approach | Architecture | Balanced Accuracy | # parameters | FLOPs |
|---|---|---|---|---|
| LWS [Kang et al., 2020] | ResNet-152 | 37.6 | 60 M | 11 B |
| LDAM-DRS-RSG [Wang et al., 2021] | ResNet-152 | 39.3 | 60 M | 11 B |
| OLTR [Liu et al., 2019b] | ResNet-152 | 35.9 | 60 M | 11 B |
| DRAGON + Bal'Loss [Samuel et al., 2021] | ResNet-50 | 38.1 | 26 M | 4.1 B |
| Teacher (labeler) model | EfficientNet-b3 | 42.1 | 12 M | 1.8 B |
| One-hot | MobileNetV3-0.75 | 26.8 | 4.01 M | 156 M |
| Distillation | MobileNetV3-0.75 | 33.0 | 4.01 M | 156 M |
| TGT (random) | MobileNetV3-0.75 | 34.7 | 4.01 M | 156 M |
| TGT (gradient-based) | MobileNetV3-0.75 | 35.0 | 4.01 M | 156 M |
Appendix F Details to reproduce our empirical results
Here we provide the details needed to reproduce our experimental results.
F.1 Long-tail image classification (Sec. 4.1)
Dataset. The full balanced versions of the three datasets (ImageNet: https://www.tensorflow.org/datasets/catalog/imagenet2012, Places365: https://www.tensorflow.org/datasets/catalog/places365_small, SUN397: https://www.tensorflow.org/datasets/catalog/sun397) are available in tensorflow-datasets (https://www.tensorflow.org/datasets/). To obtain the long-tail versions of the datasets, we downloaded the image ids (https://drive.google.com/drive/u/1/folders/1j7Nkfe6ZhzKFXePHdsseeeGI877Xu1yf) from the repository of "Large-Scale Long-Tailed Recognition in an Open World" [Liu et al., 2019b], according to which we subsampled the full balanced datasets.
Teacher fine-tuning. For the teacher labeler, we follow the "Sharpness-Aware Minimization" [Foret et al., 2020] codebase (available at https://github.com/google-research/sam) to fine-tune on the long-tail datasets. We start with the pretrained EfficientNet-B3 checkpoint available from the official repository (https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/noisystudent/efficientnet-b3/checkpoint.tar.gz) and use the default parameters from the codebase. We fine-tuned on all three datasets (ImageNet-LT, SUN397-LT, Places365-LT) for 3 epochs.
For the teacher generator, we directly used the BigBiGAN ResNet-50 checkpoint from the official repository (https://github.com/deepmind/deepmind-research/tree/master/bigbigan). (We did not fine-tune it.)
Student training. We start from a randomly initialized MobileNetV3-0.75 model. We employed the SGD optimizer with a cosine schedule (peak learning rate of 0.4, decayed down to 0) and a linear warm-up (from 0 to the peak learning rate of 0.4) over the first 5 epochs. The input image sizes differ between the EfficientNet-B3, BigBiGAN-ResNet50, and MobileNetV3-0.75 models; from the original images in the dataset, we use TensorFlow's bicubic resizing to obtain an appropriately sized image for each model. We did a grid search over the perturbation noise and the gradient-exploration step size (cf. Eq. 4 and Eq. 5). All hyper-parameters and the search grid are listed in the table below; a sketch of the learning-rate schedule follows the table.
| Hyper-param | ImageNet-LT | Places365-LT | SUN397-LT |
|---|---|---|---|
| Num epochs | 90 | 30 | 30 |
| Optimizer | SGD | | |
| Schedule | Cosine | | |
| Warm-up epochs | 5 | | |
| Peak learning rate | 0.4 | | |
| Batch size | 256 | | |
| Teacher labeler image size | | | |
| Teacher generator image size | | | |
| Student image size | | | |
| Perturbation noise | {0, 0.001, 0.01, 0.1} | | |
| Gradient exploration: step size | {0, 0.001, 0.01, 0.1} | | |
| Gradient exploration: num steps | 2 | | |
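For reference, a minimal sketch of the learning-rate schedule described above (linear warm-up to the peak rate over the first 5 epochs, followed by cosine decay to 0); this is an illustrative re-implementation, not our exact training code.

```python
import math

def learning_rate(step, steps_per_epoch, peak_lr=0.4,
                  warmup_epochs=5, total_epochs=90):
    """Linear warm-up to peak_lr followed by cosine decay down to 0."""
    warmup_steps = warmup_epochs * steps_per_epoch
    total_steps = total_epochs * steps_per_epoch
    if step < warmup_steps:
        return peak_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * peak_lr * (1.0 + math.cos(math.pi * progress))
```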
F.2 TGT in low-data regime (Sec. 4.2)
Dataset. We used the ImageNet dataset (https://www.tensorflow.org/datasets/catalog/imagenet2012) from the tensorflow-datasets repository (https://www.tensorflow.org/datasets/). We used the built-in sub-sampling functionality of tensorflow-datasets (https://www.tensorflow.org/datasets/splits) to simulate the low-data regime.
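For example, low-data splits can be obtained with the tensorflow-datasets slicing API along the following lines (the 10% slice is illustrative and not necessarily the exact fraction used in our experiments; imagenet2012 requires a manual download of the source archives).

```python
import tensorflow_datasets as tfds

# Simulate the low-data regime by taking a slice of the ImageNet training
# split via the tensorflow-datasets split API.
train_10pct = tfds.load("imagenet2012", split="train[:10%]")
train_full = tfds.load("imagenet2012", split="train")
```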
Teacher model. For the teacher labeler, we directly used the trained EfficientNet-B3 checkpoint available from the "Sharpness-Aware Minimization" repository (https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/noisystudent/efficientnet-b3/checkpoint.tar.gz). For the teacher generator, we directly used the trained BigBiGAN checkpoint from the official repository (https://github.com/deepmind/deepmind-research/tree/master/bigbigan). (We did not fine-tune either model.)
Student training. We start from a randomly initialized MobileNetV3-0.75 model. We employed the SGD optimizer with a cosine schedule (peak learning rate of 0.4, decayed down to 0) and a linear warm-up (from 0 to the peak learning rate of 0.4) over the first 5 epochs. As in App. F.1, the input image sizes differ between the EfficientNet-B3, BigBiGAN-ResNet50, and MobileNetV3-0.75 models, so we use TensorFlow's bicubic resizing to obtain an appropriately sized image for each model. Following standard practice in the literature [He et al., 2016, Jia et al., 2018], we train the one-hot and standard distillation student models for 90 epochs (= 450k steps). We use 4x fewer steps for TGT than for the simple distillation baseline, which amounts to 450k/4 ≈ 112k steps.
F.3 Text classification (Sec. 4.3)
Dataset. We conduct text classification experiments on the following datasets:
- Amazon-5: downloaded from http://goo.gl/JyCnZq
- IMDB: from tensorflow-datasets (https://www.tensorflow.org/datasets/catalog/imdb_reviews)
- MNLI: from tensorflow-datasets (https://www.tensorflow.org/datasets/catalog/multi_nli)
- Yelp-5: downloaded from http://goo.gl/JyCnZq
Optimizer. For all training, we employed the Adam optimizer with a linear decay schedule (peak learning rate of 3e-5, decayed to 0) and a linear warm-up at the start. We used a batch size of 128.
Teacher fine-tuning. For the teacher labeler, we started from the RoBERTa-Base [Liu et al., 2019a] pretrained checkpoint (https://dl.fbaipublicfiles.com/fairseq/models/roberta.base.tar.gz) from the official FAIRSEQ repository (https://github.com/facebookresearch/fairseq). We fine-tuned using the default parameters, other than the number of steps, which is the same as listed in Table 6.
For the teacher generator, we directly use a pre-trained BART-Base [Lewis et al., 2020] checkpoint (https://dl.fbaipublicfiles.com/fairseq/models/bart.base.tar.gz) from the official FAIRSEQ repository (https://github.com/facebookresearch/fairseq). (We did not fine-tune it.)
Student training. We start from the DistilRoBERTa pretrained checkpoint downloaded from the HuggingFace repository (https://huggingface.co/distilroberta-base/tree/main). We perturb by adding Gaussian noise between the encoder and decoder, as well as by masking out a fraction of the input. We then generate new examples by running greedy decoding of the BART teacher generator with a sequence length of 512. For dual-input classification tasks, such as MNLI, we generate the two inputs independently. We did a grid search over the perturbation noise and the masking fraction. All hyper-parameters and the search grid are listed in the table below; a sketch of the generation procedure follows the table.
| Hyper-param | Amazon-5 | IMDB | MNLI | Yelp-5 |
|---|---|---|---|---|
| Training set size | 2.5k / 3M | | | 2.5k / 650k |
| Num steps | 5000 / 75000 | 20000 | 75000 | 5000 / 75000 |
| Warm-up steps | 1000 / 2000 | 500 | 2000 | 1000 / 2000 |
| Optimizer | Adam | | | |
| Schedule | Linear | | | |
| Peak learning rate | 3e-5 | | | |
| Batch size | 128 | | | |
| Max sequence length | 512 | | | |
| Perturbation noise | {0, 0.01, 0.1} | | | |
| Masking fraction | {0, 0.1, 0.2} | | | |

(For Amazon-5 and Yelp-5, the two values correspond to the low-data and full training sets, respectively.)
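A sketch of the generation procedure described above, using the HuggingFace facebook/bart-base checkpoint as a stand-in for the FAIRSEQ BART-Base checkpoint we actually used; the exact transformers interface (in particular passing pre-computed encoder_outputs to generate) may vary across library versions, so treat this as illustrative rather than our exact pipeline.

```python
import torch
from transformers import BartForConditionalGeneration, BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

def generate_perturbed_example(text, noise_std=0.1, mask_fraction=0.1):
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    input_ids = inputs["input_ids"].clone()

    # Mask out a random fraction of the input tokens
    # (for simplicity, special tokens are not excluded here).
    mask = torch.rand(input_ids.shape) < mask_fraction
    input_ids[mask] = tokenizer.mask_token_id

    with torch.no_grad():
        # Add Gaussian noise to the encoder output (the "latent" between
        # the encoder and decoder), then decode greedily.
        encoder_outputs = model.get_encoder()(input_ids=input_ids)
        encoder_outputs.last_hidden_state = (
            encoder_outputs.last_hidden_state
            + noise_std * torch.randn_like(encoder_outputs.last_hidden_state)
        )
        generated_ids = model.generate(
            encoder_outputs=encoder_outputs, max_length=512, num_beams=1
        )
    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)
```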
F.4 Text retrieval (Sec. 4.4)
Dataset. From the official "Dense Passage Retrieval" repository at https://github.com/facebookresearch/DPR, we download the passage corpus (https://dl.fbaipublicfiles.com/dpr/wikipedia_split/psgs_w100.tsv.gz). Further, from the same repository, we download a pre-processed version of Natural Questions Open [Lee et al., 2019] that has been aligned to the passage corpus (https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-nq-train.json.gz, https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-nq-dev.json.gz). Finally, we download a pre-processed version of the PAQ dataset [Lewis et al., 2021], aligned to the same passage corpus (https://dl.fbaipublicfiles.com/dpr_scale/paq/PAQ.dpr.train.neg1.jsonl.zip), from the official repository of "Domain-matched Pre-training Tasks for Dense Retrieval" available at https://github.com/facebookresearch/dpr-scale.
Optimizer. For all text retrieval model training, we employed the Adam optimizer with a linear decay schedule (peak learning rate of 1e-5, decayed to 1e-7) and a linear warm-up (from 0 to the peak learning rate of 1e-5) over the first 1K steps. We used a batch size of 128.
Teacher fine-tuning. For the teacher labeler dual encoder (a question encoder and a passage encoder), we utilized the RoBERTa-Base [Liu et al., 2019a] pretrained checkpoint (https://dl.fbaipublicfiles.com/fairseq/models/roberta.base.tar.gz) from the official FAIRSEQ repository (https://github.com/facebookresearch/fairseq). We then conducted a first round of fine-tuning for 300k iterations on the passage-aligned PAQ dataset, using the same configuration as Oğuz et al. [2021] except that Oğuz et al. trained on PAQ for longer. After this pretraining, the teacher is fine-tuned for 40K steps on the NQ-Open data [Kwiatkowski et al., 2019] downloaded above. Similar to Karpukhin et al. [2020] and Oğuz et al. [2021], the teacher is trained with within-batch negatives and the softmax-based cross-entropy loss.
For the teacher generator, we directly use a pre-trained BART-Base [Lewis et al., 2020] checkpoint (https://dl.fbaipublicfiles.com/fairseq/models/bart.base.tar.gz) from the official FAIRSEQ repository (https://github.com/facebookresearch/fairseq). (We did not fine-tune it.)
The same teacher labeler and generator are used for all student training except for direct training (one-hot).
Student training. We start from the DistilRoBERTa pretrained checkpoint downloaded from the HuggingFace repository (https://huggingface.co/distilroberta-base/tree/main). All students are trained for 40K steps. The teacher labeler labels all pairs within the batch and labels 2 additional passages per question-passage pair for the uniform negative sampling baseline and TGT. We employed an off-the-shelf BART-Base model as our generator [Lewis et al., 2020]; isotropic perturbation was added via random Gaussian noise combined with masking of the original passage.
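For reference, a minimal sketch of the within-batch-negative softmax cross-entropy loss used for dual-encoder training (in the style of Karpukhin et al. [2020]); the additional teacher-labeled negatives are omitted for brevity and the function name is ours.

```python
import torch
import torch.nn.functional as F

def in_batch_negative_loss(question_embs, passage_embs):
    """Softmax cross-entropy with within-batch negatives.

    question_embs, passage_embs: tensors of shape (batch, dim), where the
    i-th question is paired with the i-th (positive) passage and all other
    passages in the batch serve as negatives.
    """
    scores = question_embs @ passage_embs.t()             # (batch, batch) similarity scores
    targets = torch.arange(scores.size(0), device=scores.device)
    return F.cross_entropy(scores, targets)
```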
Appendix G Qualitative Examples of Generated Examples
G.1 Image Classification
We show some representative examples of images generated using TGT-random as well as TGT-gradient-based sampling, from the ImageNet classification experiment, in Table 7.
| Input | TGT-Random Example | TGT-Gradient-based Example |
|---|---|---|
| Data label: lion | Teacher label: brown bear | Teacher label: chow |
| Data label: cheeseburger | Teacher label: potpie | Teacher label: cheeseburger |
| Data label: digital clock | Teacher label: tape player | Teacher label: grocery store |
| Data label: wall clock | Teacher label: shield | Teacher label: gong |

(Generated images not shown.)
G.2 Text Classification
We show some representative examples of generated text using TGT from the experiment on MNLI classification in Table 8.
Input | TGT Example |
---|---
Data label: Contradicts | Teacher label: Neutral |
The house was bought with the royalties she earned from her first book, The Tales of Peter Rabbit. [SEP] The house was bought with the money she inherited from her grandfather. | The book was published in the United States in 1987 with the royalties she received from her first book, The Tales of Peter Rabbit. [SEP] The house was bought with the money she inherited from her grandfather. |
Data label: Entail | Teacher label: Entail |
Leather goods are no longer a bargain in Spain, though very good quality products may still be priced lower than at home. [SEP] Leather goods are still very cheap in Spain. | Leather and leather goods are no longer a bargain in Spain, though very good quality products may still be priced lower than at home and abroad. [SEP] Leather goods are still very cheap at Spain. |
Data label: Entail | Teacher label: Neutral |
Then I got up as softly as I could, and felt in the dark along the left-hand wall. [SEP] The wall was wet. | Then I got up as softly as I could, and walked the way I felt in the dark along the left [SEP] The wall was wet. |
Data label: Entails | Teacher label: Entail |
But then this very particular island is hardly in danger of being invaded except, of course, by tourism. [SEP] This island is least likely to be invaded by tourism. | But then this very particular island is not in danger of being invaded except, of course, by tourism. [SEP] The island is likely to be invaded by tourism. |
Data label: Contradicts | Teacher label: Neutral |
All you need to do is just wander off the beaten path, beyond the bustling tourist zone. [SEP] There is no point going off the beaten path, there is nothing there. | All you need to do is just wander off the beaten path, and you'll be in the bustling tourist zone of the city. [SEP] There is no point going off the beaten path, there is nothing there.
Data label: Entails | Teacher label: Neutral |
The silt of the River Maeander has also stranded the once-mighty city of Miletus. [SEP] The River Maeander has been depositing silt near Miletus for nearly two millennia. | The silt of the River Mae has also stranded the once-mighty city of Miletus. [SEP] The River Maeander has been depositing silt near Miletus for more than two decades. |
Data label: Entails | Teacher label: Entails |
It was hardly the most enlightened of times, not with the conflict in Indochina rapidly becoming America's costliest and most divisive war. [SEP] The war in Indochina has cost America 100 billion dollars so far. | It was hardly the most enlightened of times, not with the war in Indochina becoming America's costliest and most divisive war. [SEP] The war in Indochina has cost America 100 billion dollars so far.