
Teacher Guided Training:
An Efficient Framework for Knowledge Transfer

Manzil Zaheer, Ankit Singh Rawat, Seungyeon Kim, Chong You, Himanshu Jain, Andreas Veit, Rob Fergus, Sanjiv Kumar
Google Research and DeepMind, New York, USA
{manzilzaheer,ankitsrawat,seungyeonk,cyou,himj,aveit,robfergus,sanjivk}@google.com
Abstract

The remarkable performance gains realized by large pretrained models, e.g., GPT-3, hinge on the massive amounts of data they are exposed to during training. Analogously, distilling such large models to compact models for efficient deployment also necessitates a large amount of (labeled or unlabeled) training data. In this paper, we propose the teacher-guided training (TGT) framework for training a high-quality compact model that leverages the knowledge acquired by pretrained generative models, while obviating the need to go through a large volume of data. TGT exploits the fact that the teacher has acquired a good representation of the underlying data domain, which typically corresponds to a much lower dimensional manifold than the input space. Furthermore, we can use the teacher to explore input space more efficiently through sampling or gradient-based methods; thus, making TGT especially attractive for limited data or long-tail settings. We formally capture this benefit of proposed data-domain exploration in our generalization bounds. We find that TGT can improve accuracy on several image classification benchmarks as well as a range of text classification and retrieval tasks.

* Equal contribution

1 Introduction

Recent general purpose machine learning models (e.g., BERT (Devlin et al., 2019), DALL-E (Ramesh et al., 2021), SimCLR (Chen et al., 2020a), Perceiver (Jaegle et al., 2021), GPT-3 (Brown et al., 2020)), trained on broad data at scale, have demonstrated adaptability to a diverse range of downstream tasks. Despite being trained in an unsupervised (or so-called self-supervised) fashion, these models have been shown to capture highly specialized information in their internal representations, such as relations between entities (Heinzerling and Inui, 2021) or object hierarchies from images (Weng et al., 2021).

Despite their impressive performance, the prohibitively high inference cost of such large models prevents their widespread deployment. A standard approach to reducing the inference cost while preserving performance is to train a compact (student) model via knowledge distillation (Bucilua et al., 2006; Hinton et al., 2015) from a large (teacher) model. However, existing distillation methods require a large amount of training data (labeled or unlabeled) for knowledge transfer. The teacher must be evaluated on each data point, making the process computationally expensive (Xie et al., 2020d; He et al., 2021; Sanh et al., 2019a). This is compounded by the need to repeat the distillation process separately for every downstream task, each with its own training set. Enabling efficient distillation is thus an important challenge. Additionally, minimizing the number of distillation samples would especially benefit low-data downstream tasks (e.g., those with long tails).

Another inefficiency of standard distillation approaches is that each evaluation of the teacher uses only the final-layer output (i.e., the logits). This ignores potentially useful internal representations that could also be leveraged for knowledge transfer. Various extensions along these lines have been proposed in the literature (see, e.g., (Sun et al., 2020; Aguilar et al., 2020; Li et al., 2019; Sun et al., 2019) and references therein). However, despite their success, most use the teacher model in a black-box manner and do not fully utilize the domain understanding it contains (Cho and Hariharan, 2019; Stanton et al., 2021). In these approaches, the teacher is used passively: the input sample distribution is fixed and does not adapt to the student model's performance. Consequently, these forms of distillation do not lead to faster training of a high-performance student model.

Figure 1: An overview of the proposed teacher guided training (TGT) framework. Given a learning task, the framework leverages a large teacher, comprising a pretrained generator and a labeler, that exhibits high performance on the task. In particular, we assume that the generator consists of an encoder and a decoder. TGT performs three key operations during student model training: (1) given an original training instance, use the teacher generator to identify a novel task-relevant instance; we search for informative instances in the lower-dimensional latent space, through which gradients can be propagated; (2) obtain (soft) labels for the original and newly generated training instances from the teacher labeler; and (3) minimize the student training objective, which depends on the original dataset as well as the newly generated instances and their corresponding labels produced by the teacher labeler. Note that TGT reduces to standard knowledge distillation in the absence of the generator component.

In this work, we go beyond the passive application of large teacher models for training compact student models, and leverage the domain understanding captured by the teacher to generate new informative training instances that help the compact model achieve higher accuracy with fewer samples, thereby reducing training time. In particular, we propose the teacher guided training (TGT) framework for a more efficient transfer of knowledge from large models to a compact model. TGT relies on the fact that the teacher's internal representation of the data often lies on a manifold of much lower dimension than the input space. Furthermore, we can use the teacher to guide training by identifying the directions in which the student's current decision boundary starts to diverge from that of the teacher, e.g., by backpropagating through the teacher to identify regions of disagreement.

We also give a formal justification for the TGT algorithm, showing that leveraging the internal data representation of large models enables better generalization bounds for the student model. Given $n$ instances in a $D$-dimensional space, the generalization gap for learning a Lipschitz decision boundary of an underlying classification task decays only as $\mathcal{O}\big(n^{-\frac{1}{D}}\big)$ (Györfi et al., 2002). In contrast, assuming that the large model can learn a good data representation in a $d$-dimensional latent space, the TGT framework realizes a generalization gap of $\mathcal{O}\big(n^{-\frac{1}{d}} + \mathcal{W}(\mathscr{D}, \mathscr{D}^{\rm t})\big)$, where $\mathcal{W}(\mathscr{D}, \mathscr{D}^{\rm t})$ denotes the Wasserstein distance between the data distribution $\mathscr{D}$ and the distribution $\mathscr{D}^{\rm t}$ learned by the underlying generative teacher model. Typically $d \ll D$, so TGT ensures much faster convergence whenever we employ a high-quality generative teacher model. This makes TGT especially attractive for low-data or long-tail regimes.

In order to realize TGT, we take advantage of the fact that most unsupervised pretrained models, such as Transformers, VAEs, and GANs, have two components: (1) an encoder that maps data to a latent representation, and (2) a decoder that transforms the latent representation back to the original data space. We utilize this latent space of data representations learned by the teacher model to efficiently search for regions of mismatch between the teacher's and student's decision boundaries. This search can take the form of either (i) a zero-order approach involving random perturbation or (ii) a first-order method exploring along the gradient of a suitably defined distance measure between the teacher and student models.

Some of these pretrained models, particularly in NLP (e.g., T5 (Raffel et al., 2020)), can also provide labels for a downstream task and thus act as the sole teacher. However, our approach is sufficiently general to utilize separate pretrained models for the generative and discriminative (labeler) functions (cf. Fig. 1); e.g., we employ an image BiGAN as the generator and an EfficientNet as the labeler for an image classification task.

Our main contributions are summarized as follows:

  1. We introduce the TGT framework, a conceptually simple and scalable approach to distilling knowledge from a large teacher into a smaller student. TGT adaptively changes the distribution of distillation examples, yielding higher-performing student models with fewer training examples.

  2. We provide theoretical justification for utilizing the latent space of the teacher generator in the TGT framework, which yields tighter generalization bounds.

  3. We empirically demonstrate the superiority of TGT over existing state-of-the-art distillation approaches, showing results on both vision and NLP tasks, unlike most previous work, which specializes in a single domain.

2 Related Work

Our proposed TGT framework can be considered a form of data augmentation where data is dynamically added at points of current discrepancy between the teacher and student. Next, we provide a brief overview of how data augmentation has been used in the context of distillation.

Using pseudo labels. The earliest line of work uses consistency regularization (Sajjadi et al., 2016; Laine and Aila, 2017; Tarvainen and Valpola, 2017) to obtain pseudo labels for unlabeled data, where a model is expected to make consistent predictions on an unlabeled instance and its augmented versions, cf. (Miyato et al., 2019; Xie et al., 2020a; Verma et al., 2019; Berthelot et al., 2019; Sohn et al., 2020; Zhu et al., 2021, inter alia). Another approach is self-training (Xie et al., 2020d; Du et al., 2021), where a smaller teacher model is first learned on the labeled data and then used to generate pseudo labels for a large but relevant unlabeled set; a large student model is then trained on both the labeled and pseudo-labeled sets. Label propagation (Iscen et al., 2019) is yet another direction, where unlabeled instances receive pseudo labels from neighboring labeled instances in a similarity graph constructed from the representations of a model trained only on labeled data.

Furthermore, there has been work on learning to teach (Fan et al., 2018; Raghu et al., 2021; Pham et al., 2021), where the teacher is dynamically updated so as to provide more valuable pseudo labels based on the student's loss. Such an interactive approach presents a challenging optimization problem and potentially opens the door to borrowing techniques from reinforcement learning. In contrast, our work focuses on the setting where a high-quality pretrained teacher model is fixed throughout training, since updating the large teacher is prohibitively costly or undesirable when the same model is used to distill many student models. Moreover, many large models like GPT-3 may only be available through API access, making it infeasible to update the teacher.

Using pretrained models.  One can use large-scale pretrained class-conditional generative models like BigGAN (Brock et al., 2019) or VQ-VAE2 (Razavi et al., 2019) to generate more data for augmentation. Despite evidence (Webster et al., 2019) that GANs are not memorizing training data, using them to simply augment the training dataset has limited utility when training ResNets (Ravuri and Vinyals, 2019b, a). One possible reason is the lack of diversity (Arora et al., 2017) in data generated by GANs, especially in high-density regions (Arora et al., 2018). In contrast, we use generative models to adaptively explore the local region of disagreement between teacher and student, as opposed to blindly sampling from the generative model. This way we circumvent the excessive reliance on samples from high-density regions, which often have low diversity.

Another line of work by Chen et al. (2020b) combines unsupervised/self-supervised pretraining (on unlabeled data) with a SimCLR-based approach (Chen et al., 2020a), task-specific finetuning (on labeled data), and distillation (natural loss on labeled and distillation loss on unlabeled data). This setup is close to ours, with two key differences: (1) we assume access to a very high-quality teacher model, potentially trained on a much larger labeled set, to provide pseudo labels; and (2) we go beyond utilizing a given unlabeled dataset from the domain of interest, exploring the dynamic generation of domain-specific unlabeled data by leveraging the representations learned by pretrained models. Additionally, our work develops a theoretical framework to identify the utility of unlabeled data instances for student training, especially instances generated from teacher-learned representations.

Using both pseudo labels and pretrained models.  The idea of combining pretrained models that generate training instances with pseudo-labelers has been previously considered under the name of the GAL framework (He et al., 2021). However, the GAL framework generates these new instances in an offline manner, at the beginning of student training. In contrast, our proposed approach (cf. Fig. 1) generates new informative training instances in an online fashion, aiming to improve student performance while reducing training time.

Recently, MATE-KD (Rashid et al., 2021) also considered a setup where a generator model is used to obtain new training instances based on the current student model's performance (by looking at the divergence between the student and teacher predictions). However, there are two key differences between our proposed TGT approach and the MATE-KD framework. First, their method updates the generator so as to find adversarial examples for the student, which can cause the generator to drift away from the true data distribution. Second, they perturb the input space itself and do not leverage the latent space of the teacher, which is the crux of our method. Further details are provided in App. A.

Another work worth mentioning is KDGAN (Wang et al., 2018), which leverages a GAN during distillation. However, it samples examples from the GAN without taking the student's performance into account. We also note that (Heo et al., 2019; Dong et al., 2020) search for adversarial examples during distillation. However, their search likewise does not depend on the student's performance, resulting in wasteful exploration of regions of the input space where the student is already good. Further, unlike TGT, (Heo et al., 2019; Dong et al., 2020) perform the example search in the input space, which is often inefficient due to the large ambient dimension of the input space.

Finally, data-free KD approaches perform knowledge distillation using only synthetically generated data (Nayak et al., 2019; Yoo et al., 2019; Chen et al., 2019). In these approaches, the synthetic data distribution is updated at each epoch, but unlike TGT, this causes the student model to lose information over epochs and suffer accuracy degradation (Binici et al., 2022). Within this framework, Micaelli and Storkey (2019) targeted the generation of samples that would provide the student with maximum information gain when learned, but this approach suffers from drawbacks similar to those of MATE-KD noted above.

3 Teacher Guided Training

We begin by formally introducing our setup in Sec. 3.1. We then describe our proposed TGT framework in Sec. 3.2 and present a theoretical analysis in Secs. 3.3 and 3.4.

3.1 Problem setup

In this paper, we focus on a multiclass classification task where, given an instance $x \in \mathscr{X}$, the objective is to predict its true label $y \in \mathscr{Y} := [K]$ out of $K$ potential classes. Let $\mathscr{D} := \mathscr{D}_{X,Y}$ denote the underlying (joint) data distribution over the instance and label spaces for the task. Moreover, we use $\mathscr{D}_X$ and $\mathscr{D}_{Y|X=x}$ to denote the marginal distribution over the instance space $\mathscr{X}$ and the conditional label distribution for a given instance $x$, respectively. A classification model $f : \mathscr{X} \to \mathbb{R}^K$, with $f(x) = (f(x)_1, \ldots, f(x)_K)$, takes an input instance $x$ and yields scores for each of the $K$ classes. Finally, we are given a (tractable) loss function $\ell : \mathbb{R}^K \times [K] \to \mathbb{R}$ that closely approximates the model's misclassification error on an example $(x, y)$, e.g., the softmax-based cross-entropy loss.

We assume access to $n$ i.i.d. labeled samples $\mathscr{S}^{\rm labeled}_n := \{(x_i, y_i)\}_{i \in [n]}$ generated from $\mathscr{D}$. Given $\mathscr{S}^{\rm labeled}_n$ and a collection of allowable models $\mathcal{F}$, one typically learns a model with small misclassification error by solving the following empirical risk minimization (ERM) problem:

$$\widehat{f}_n = \arg\min_{f \in \mathcal{F}} \frac{1}{n} \sum_{i \in [n]} \ell(f(x_i), y_i). \tag{1}$$

Besides the standard classification setting introduced above, in our TGT setup we further assume access to a high-quality teacher model, which has:

  • Teacher generator. A generative component that captures $\mathscr{D}_X$ well, e.g., a Transformer, VAE, or ALI-GAN. This usually consists of an encoder $\mathsf{Enc} : \mathscr{X} \to \mathbb{R}^d$ and a decoder $\mathsf{Dec} : \mathbb{R}^d \to \mathscr{X}$.

  • Teacher labeler. A classification network, denoted by $h : \mathscr{X} \to \mathbb{R}^K$, with good performance on the underlying classification task. In general, our framework allows $h$ to be either a head on top of the teacher generator or an independent large teacher classification model.

Given $\mathscr{S}^{\rm labeled}_n$ and such a teacher model, our objective is to learn a high-quality compact student (classification) model in $\mathcal{F}$, as assessed by its misclassification error on $\mathscr{D}$.
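To make this setup concrete, the sketch below (PyTorch-style Python; our own illustration rather than a released implementation) shows the teacher interface the rest of this section assumes: a frozen encoder/decoder pair plus a frozen labeler. All names are hypothetical, and any pretrained encoder, decoder, and classifier modules could be plugged in.

```python
import torch.nn as nn

class Teacher(nn.Module):
    """Illustrative wrapper for the teacher assumed by TGT: a pretrained
    generator (encoder Enc: X -> R^d, decoder Dec: R^d -> X) and a
    pretrained labeler h: X -> R^K returning class logits."""

    def __init__(self, enc: nn.Module, dec: nn.Module, labeler: nn.Module):
        super().__init__()
        self.enc, self.dec, self.labeler = enc, dec, labeler
        for p in self.parameters():  # the teacher is never updated in TGT
            p.requires_grad_(False)
```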

3.2 Proposed approach

To train a student model $f \in \mathcal{F}$, we propose to minimize:

$$R^{\rm TGT}_f(\mathscr{S}^{\rm labeled}_n) := \frac{1}{n} \sum_{i \in [n]} \Big[ \ell(f(x_i), y_i) + \ell_{\mathrm{d}}(f(x_i), h(x_i)) \Big] + \frac{1}{m} \sum_{j \in [m]} \ell_{\mathrm{d}}(f(\tilde{x}_j), h(\tilde{x}_j)) \tag{2}$$

where $\ell_{\mathrm{d}} : \mathbb{R}^K \times \mathbb{R}^K \to \mathbb{R}$ is a loss function that captures the mismatch between two models $f$ and $h$, and $\tilde{\mathscr{S}}_m = \{\tilde{x}_j\}_{j \in [m]}$ is introduced below. The first term, $\ell(f(x_i), y_i)$, corresponds to the standard ERM problem (cf. Eq. 1). The subsequent terms, $\ell_{\mathrm{d}}(f(x_i), h(x_i))$ and $\ell_{\mathrm{d}}(f(\tilde{x}_j), h(\tilde{x}_j))$, do not make use of labels. In particular, the second term, $\ell_{\mathrm{d}}(f(x_i), h(x_i))$, corresponds to knowledge distillation (Bucilua et al., 2006; Hinton et al., 2015), where the teacher model provides supervision for the student model.
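As a concrete illustration, a minimal sketch of the objective in Eq. 2 is given below (PyTorch-style; it assumes the `Teacher` wrapper sketched in Sec. 3.1 and uses softmax cross-entropy for $\ell$ and a temperature-scaled KL divergence for $\ell_{\mathrm{d}}$, which are common choices rather than ones prescribed by the analysis).

```python
import torch
import torch.nn.functional as F

def distill_loss(s_logits, t_logits, temp=1.0):
    """l_d: KL divergence between teacher and student softmax distributions."""
    return F.kl_div(F.log_softmax(s_logits / temp, dim=-1),
                    F.softmax(t_logits / temp, dim=-1),
                    reduction="batchmean")

def tgt_objective(student, teacher, x, y, x_tilde):
    """Eq. 2: ERM on (x, y) + distillation on x + distillation on the
    teacher-generated x_tilde (whose construction is described next)."""
    s_logits = student(x)
    erm = F.cross_entropy(s_logits, y)               # l(f(x_i), y_i)
    with torch.no_grad():                            # teacher is frozen
        t_real = teacher.labeler(x)
        t_gen = teacher.labeler(x_tilde)
    kd_real = distill_loss(s_logits, t_real)         # l_d(f(x_i), h(x_i))
    kd_gen = distill_loss(student(x_tilde), t_gen)   # l_d(f(x~_j), h(x~_j))
    return erm + kd_real + kd_gen
```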

We introduce a novel third term, $\ell_{\mathrm{d}}(f(\tilde{x}_j), h(\tilde{x}_j))$, where the data $\tilde{\mathscr{S}}_m = \{\tilde{x}_j\}$ is generated based on $\mathscr{S}_n = \{x_i\}$. Here, we want to generate additional informative samples that help the student learn faster, e.g., points where it disagrees with the teacher but that still lie on the data manifold. In other words, we want to find $\tilde{x}$ as follows:

$$\tilde{x} = \operatorname*{arg\,max}_{x \in \mathscr{X}} \ell(h(x), f(x)) \quad \text{such that} \quad p_{\mathscr{D}_X}(x) \geq \lambda \tag{3}$$

We propose two specific approaches to generate novel samples (a code sketch of both follows the list):

  1. Isotropically perturb in latent space:

$$\tilde{x} = \mathsf{Dec}(\mathsf{Enc}(x) + \nu), \quad \text{where } \nu \sim \mathcal{N}(0, \sigma^2 \mathbb{I}_d). \tag{4}$$

  This can be regarded as a zero-order search in the latent space, which satisfies the constraint of remaining within the data manifold.

  2. Gradient-based exploration: Run a few iterations of gradient ascent on Eq. 3 to find the example on which the student diverges most from the teacher. To enforce the constraint, we run the gradient ascent in the latent space of the teacher generator, as opposed to the instance space $\mathscr{X}$, where a perturbed point might leave the data manifold. For a high-quality teacher generator, the latent space should capture the data manifold well. Implementing this requires backpropagating all the way through the student and teacher labeler to the teacher decoder, as shown in Fig. 1. Mathematically, it involves the following three operations:

$$z := \mathsf{Enc}(x); \qquad z \leftarrow z + \eta \nabla_z \ell_{\mathrm{d}}\big(f(\mathsf{Dec}(z)), h(\mathsf{Dec}(z))\big); \qquad \tilde{x} := \mathsf{Dec}(z). \tag{5}$$

  This is akin to a first-order search in the latent space.
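As promised above, here is a minimal sketch of both search routines (PyTorch-style, reusing the hypothetical `Teacher` wrapper and `distill_loss` from earlier; the step size, noise scale, and iteration count are illustrative, not the paper's tuned values).

```python
import torch

def perturb_random(teacher, x, sigma=0.1):
    """Eq. 4: zero-order search -- add isotropic Gaussian noise to the
    latent code and decode back to the data manifold."""
    with torch.no_grad():
        z = teacher.enc(x)
        return teacher.dec(z + sigma * torch.randn_like(z))

def perturb_gradient(teacher, student, distill_loss, x, eta=0.1, steps=3):
    """Eq. 5: first-order search -- ascend the teacher-student
    disagreement l_d(f(Dec(z)), h(Dec(z))) in the latent space."""
    z = teacher.enc(x).detach().requires_grad_(True)
    for _ in range(steps):
        x_hat = teacher.dec(z)
        gap = distill_loss(student(x_hat), teacher.labeler(x_hat))
        (grad_z,) = torch.autograd.grad(gap, z)
        z = (z + eta * grad_z).detach().requires_grad_(True)  # ascent step
    with torch.no_grad():
        return teacher.dec(z)
```

In a training loop, each minibatch of original instances $x$ would be passed through one of these routines to obtain $\tilde{x}$, which then feeds the third term of Eq. 2.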

Extension to discrete data.  Note that perturbing an instance from a discrete domain, e.g., text data, is not as straightforward as in a continuous space. Typically, one has to resort to expensive combinatorial search or crude approximations to perform such perturbations (Tan et al., 2020; Zang et al., 2020; Ren et al., 2019). Interestingly, our approach in Eq. 4 provides a simple alternative: the perturbation is performed in the latent space, which is continuous. Gradient-based exploration, on the other hand, assumes that $\mathscr{X}$ is a differentiable space, so that quantities such as $\frac{\partial f(x)}{\partial x}$ in Eq. 5 can be computed. This assumption holds for data such as images and point clouds, but not for discrete data like text. We can, however, circumvent this limitation by sharing weights between the output softmax layer of the teacher's decoder $\mathsf{Dec}$ and the input embedding layer of the student $f$ (and of the teacher labeler $h$, when an independent model is used). One can then bypass the discrete space during the backward pass, similar to ideas behind VQ-VAE (Hafner et al., 2019). Note that during the forward pass we still need the discrete representation for decoding, e.g., using beam search.
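The following sketch illustrates one way such weight sharing can bypass the discrete space in the backward pass (our own simplified rendering under stated assumptions; the shapes and names are hypothetical, and a real implementation must still decode discrete tokens in the forward pass as noted above).

```python
import torch
import torch.nn as nn

vocab_size, emb_dim = 30522, 768                       # illustrative sizes
E = nn.Parameter(torch.randn(vocab_size, emb_dim) * 0.02)
# E is shared: it serves both as the decoder's output softmax matrix and
# as the student's (and labeler's) input embedding table.

def soft_student_inputs(dec_logits: torch.Tensor) -> torch.Tensor:
    """dec_logits: (batch, seq, vocab) scores from the teacher decoder.
    Instead of feeding discrete token ids to the student, feed the
    expected embedding p @ E, which is differentiable w.r.t. the latent z."""
    p = dec_logits.softmax(dim=-1)                     # soft token distribution
    return p @ E                                       # (batch, seq, emb_dim)
```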

Finally, we address the superficial resemblance between our approach and adversarial training. In adversarial training, the goal is to learn a robust classifier, i.e., to increase the margin. Towards this, for any $x$, one wants to enforce model agreement in its local neighborhood $B_r(x)$, i.e., $f(x') = f(x),\ \forall x' \in B_r(x)$. One needs to carefully choose a small enough neighborhood by restricting $r$, so as not to cross the decision boundary. In contrast, we are not looking for such max-margin training, which has its own issues (cf. (Nowak-Vila et al., 2021)). We simply want to encourage agreement between the teacher and the student, i.e., $f(x') = h(x'),\ \forall x'$. Thus, we do not have any limitation on the size of the neighborhood to consider. As a result, we can explore much bigger regions as long as we remain on the data manifold.

3.3 Value of generating samples via the latent space

In this section, we formally show how leveraging the latent space can help learning. For this exposition, we assume $\mathscr{X} = \mathbb{R}^D$. Furthermore, for learning directly in the input space, we assume that our function class $\mathcal{F}$ corresponds to all Lipschitz functions mapping $\mathbb{R}^D$ to $\mathbb{R}^K$. Then, for any such function $f \in \mathcal{F}$, there are existing generalization bounds of the form (Devroye et al., 2013; Mohri et al., 2018):

$$R_{\ell,f}(\mathscr{D}) \leq R_{\ell,f}(\mathscr{S}_n) + \underbrace{\mathfrak{R}_n(\mathcal{G}_{\ell,\mathcal{F}})}_{\leq \mathcal{O}(n^{-1/D})} + \mathcal{O}\big(\sqrt{\log(1/\delta)/n}\big),$$

where $R_{\ell,f}(\mathscr{D})$ is the true population risk of the classifier, $R_{\ell,f}(\mathscr{S}_n)$ is the empirical risk, and $\mathfrak{R}_n(\mathcal{G}_{\ell,\mathcal{F}})$ is the Rademacher complexity of the induced function class $\mathcal{G}_{\ell,\mathcal{F}}$, which in our case is known to be $\mathcal{O}(n^{-1/D})$ (see App. B for more details). Any reduction in the Rademacher term implies a smaller generalization gap, which is our goal.

In our TGT framework, we assume availability of a teacher that is able to learn a good representation for the underlying data distribution. In particular, we assume that, for $x \in {\rm supp}(\mathscr{D}_X)$, we have

$$\|\mathsf{Dec} \circ \mathsf{Enc}(x) - x\| \leq \epsilon, \tag{6}$$

i.e., applying the decoder $\mathsf{Dec}$ to the latent representation of $x$, as produced by the encoder $\mathsf{Enc}$, yields a point $\mathsf{Dec} \circ \mathsf{Enc}(x) \in \mathscr{X}$ that approximates $x$ with a small error.

This ability of the teacher generator to model the data distribution via latent representations can be used to reduce the complexity of the required function class. Specifically, in the TGT framework, we leverage the teacher decoder to restrict the function class to compositions of the decoder function $\mathsf{Dec}$ with a learnable Lipschitz function operating on the latent space $\mathbb{R}^d$. Since $d \ll D$, this leads to a function class of much lower complexity. Next, we formally capture this idea for distillation with both the original samples $\mathscr{S}_n$ drawn from $\mathscr{D}_X$ and the novel samples $\tilde{\mathscr{S}}$ introduced by the teacher generator. In what follows, we only consider the distillation losses and ignore the first loss term (which depends on true labels); our analysis easily extends to account for the latter term (e.g., using tools from Foster et al. (2019)).

We start with standard distillation in the following result.

Theorem 3.1.

Suppose a generative model with $\mathsf{Enc}$ and $\mathsf{Dec}$ satisfies the approximation guarantee in Eq. 6 for $\mathscr{D}_X$. Let $\mathsf{Dec}$ and the teacher labeler $h$ be Lipschitz functions, and let the distillation loss $\ell_{\mathrm{d}}$ satisfy Assumption C.1. Then, with probability at least $1 - \delta$, the following holds for any $f \in \mathcal{F}$:

$$R_{\ell,f}(\mathscr{D}) \leq R^h_{\ell_{\mathrm{d}},f}(\mathscr{S}_n) + \underbrace{\mathfrak{R}_n(\mathcal{G}^{h,\mathsf{Dec}}_{\ell_{\mathrm{d}},\mathcal{F}})}_{\leq \mathcal{O}(n^{-1/d})} + \mathcal{O}\Big(\sqrt{\log(1/\delta)/n}\Big) + L\epsilon + \mathcal{O}\big(\sqrt{K}\, \mathbb{E}_{\mathscr{D}_X}\big[\|\mathscr{D}_{Y|X} - h(X)\|_2\big]\big),$$

where $L$ is the effective Lipschitz constant of $\mathcal{G}^{h,\mathsf{Dec}}_{\ell_{\mathrm{d}},\mathcal{F}} = \{z \mapsto \ell_{\mathrm{d}}(f \circ \mathsf{Dec}(z), h \circ \mathsf{Dec}(z)) : f \in \mathcal{F}\}$, an induced function class mapping $\mathbb{R}^d$ (the latent space of the generator) to $\mathbb{R}$.

Thus, we can reduce the Rademacher term from $\mathcal{O}(n^{-1/D})$ to $\mathcal{O}(n^{-1/d})$, which yields a significant reduction in sample complexity. However, as the teacher model is not perfect, a penalty is incurred in terms of its reconstruction and prediction errors. See Sec. C.1 for the details.
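To get a rough quantitative feel for this difference, consider an illustrative calculation (our own, for intuition only): with $n = 10^6$ samples, ambient dimension $D = 100$, and latent dimension $d = 10$, the dominant complexity terms compare as

$$n^{-1/D} = 10^{-6/100} \approx 0.87, \qquad n^{-1/d} = 10^{-6/10} \approx 0.25,$$

so the latent-space rate is already several times smaller, and the advantage grows rapidly as $n$ increases or $d/D$ shrinks.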

Thus far, we have not leveraged the fact that we can also use the teacher to generate further samples. Accounting for samples $\tilde{\mathscr{S}}_n$ generated by the teacher generator instead, one obtains a similar generalization gap for distillation on teacher-generated samples:

Theorem 3.2.

Let $\tilde{\mathscr{S}}_n = \{\tilde{x}_i\}_{i \in [n]}$ be $n$ i.i.d. samples generated by the TGT framework, with distribution denoted by $\tilde{\mathscr{D}}_X$. Further, let $\tilde{f}_n \in \mathcal{F}$ denote the student model learned via distillation on $\tilde{\mathscr{S}}_n$, with $h$ as the teacher model and $\ell_{\mathrm{d}}$ a distillation loss satisfying Assumption C.1. Then, with probability at least $1 - \delta$, we have

$$R_{\ell,f}(\mathscr{D}) \leq R^h_{\ell_{\mathrm{d}},\tilde{f}_n}(\tilde{\mathscr{S}}_n) + \underbrace{\tilde{\mathfrak{R}}_n(\mathcal{G}^{h,\mathsf{Dec}}_{\ell_{\mathrm{d}},\mathcal{F}})}_{\leq \mathcal{O}(n^{-1/d})} + \mathcal{O}\Big(\sqrt{\log(1/\delta)/n}\Big) + \mathcal{W}(\mathscr{D}_X, \tilde{\mathscr{D}}_X) + \mathcal{O}\big(\sqrt{K}\, \mathbb{E}_{\mathscr{D}_X}\big[\|\mathscr{D}_{Y|X} - h(X)\|_2\big]\big),$$

where $\mathcal{G}^{h,\mathsf{Dec}}_{\ell_{\mathrm{d}},\mathcal{F}}$ is defined as in Thm. 3.1.

Please see Sec. C.2 for a more precise statement and proof of Thm. 3.2.

Remark 3.3.

Compared with the generalization gap for standard distillation (cf. Thm. 3.1), the generalization gap for TGT in Thm. 3.2 does not contain the reconstruction-error term $L\epsilon$. Thus, by working with samples that have exact latent representations, TGT avoids this reconstruction error penalty. On the other hand, the generalization gap for TGT does have an additional term $\mathcal{W}(\mathscr{D}_X, \tilde{\mathscr{D}}_X)$, capturing the mismatch between the original data distribution and the distribution of the samples used by TGT. However, this term becomes smaller as the teacher generator gets better at capturing the data. Note that generative models like WGAN explicitly minimize this term (Arjovsky et al., 2017).

3.4 Motivation for gradient based exploration

The results so far do not shed light on how to design the optimal $\tilde{\mathscr{D}}_X$, i.e., the search mechanism in the latent space for our TGT framework. In this regard, we turn to variance-based generalization bounds (Maurer and Pontil, 2009), previously utilized by Menon et al. (2021a) in the context of distillation. Applying this technique to our TGT approach, we obtain a generalization bound of the form:

$$R^h_{\ell_{\mathrm{d}},f}(\mathscr{D}_X) \leq R^h_{\ell_{\mathrm{d}},f}(\tilde{\mathscr{S}}_n) + {\rm (II)} + \mathcal{W}(\mathscr{D}_X, \tilde{\mathscr{D}}_X), \tag{7}$$

where ${\rm (II)} = \mathcal{O}\Big(\sqrt{\frac{{\rm Var}_{\tilde{\mathscr{D}}_X}(\ell^h_{\mathrm{d},f}(\tilde{x})) \cdot \log(\mathcal{M}(n)/\delta)}{n}} + \frac{\log(\mathcal{M}(n)/\delta)}{n}\Big)$, with $\ell^h_{\mathrm{d},f}(\tilde{x}) := \ell_{\mathrm{d}}\big(f(\tilde{x}), h(\tilde{x})\big)$ and $\mathcal{M}(n)$ depending on the covering number of the induced function class $\mathcal{G}^h_{\ell_{\mathrm{d}},\mathcal{F}}$ (cf. Eq. 16). Here, we note that combining Eq. 7 with Lemma D.4 translates the bound on $R^h_{\ell_{\mathrm{d}},f}(\mathscr{D}_X)$ into a bound on $R_{\ell,f}(\mathscr{D})$ with an additional penalty term that depends on the quality of the teacher labeler $h$.

Note that Eq. 7 suggests a general approach to selecting the distribution $\tilde{\mathscr{D}}_X$ that generates the training samples $\tilde{\mathscr{S}}_n$. To ensure a small generalization gap, we need to focus on the two terms that depend on $\tilde{\mathscr{D}}_X$: (1) the variance term ${\rm Var}_{\tilde{\mathscr{D}}_X}(\ell^h_{\mathrm{d},f}(\tilde{x}))$; and (2) the divergence term $\mathcal{W}(\mathscr{D}_X, \tilde{\mathscr{D}}_X)$. Finding a distribution that jointly minimizes both terms is a non-trivial task. That said, in our sampling approach in Eq. 5, we control $\mathcal{W}(\mathscr{D}_X, \tilde{\mathscr{D}}_X)$ by operating in the latent space of a good-quality teacher generative model, and we minimize the variance by finding points with high loss values through gradient ascent, thereby striking a balance between the two objectives. We refer to Sec. C.3 for more details on the bound stated in Eq. 7.

4 Experiments

We now conduct a comprehensive empirical study of our TGT framework in order to establish that TGT (i) leads to high accuracy in transferring knowledge in low data/long-tail regimes (Sec. 4.1); (ii) effectively increases sample size (Sec. 4.2); and (iii) has wide adaptability even to discrete data domains such as text classification (Sec. 4.3) and retrieval (Sec. 4.4).

| Dataset | Approach | Architecture | Balanced Accuracy | # Parameters | FLOPs |
|---|---|---|---|---|---|
| ImageNet1K-LT | LDAM-DRW* (Cao et al., 2019) | ResNet-50 | 47.8 | 26 M | 4.1 B |
| | LWS (Kang et al., 2020) | ResNeXt-50 | 49.9 | 25 M | 4.2 B |
| | Logit adjustment loss* (Menon et al., 2021b) | ResNet-50 | 50.4 | 26 M | 4.1 B |
| | LDAM-DRS-RSG (Wang et al., 2021) | ResNeXt-50 | 51.8 | 25 M | 4.2 B |
| | DRAGON + Bal’Loss (Samuel et al., 2021) | ResNet-10 | 46.5 | 5.4 M | 819 M |
| | DRAGON + Bal’Loss (Samuel et al., 2021) | ResNet-50 | 53.5 | 26 M | 4.1 B |
| | Teacher (labeler) model | EfficientNet-b3 | 79.2 | 12 M | 1.8 B |
| | One-hot | MobileNetV3-0.75 | 35.5 | 4.01 M | 156 M |
| | Distillation | MobileNetV3-0.75 | 47.2 | 4.01 M | 156 M |
| | TGT (random) | MobileNetV3-0.75 | 53.2 | 4.01 M | 156 M |
| | TGT (gradient-based) | MobileNetV3-0.75 | 53.3 | 4.01 M | 156 M |
| SUN397-LT | LDAM-DRS-RSG (Wang et al., 2021) | ResNeXt-50 | 29.8 | 25 M | 4.2 B |
| | CAD-VAE (Schönfeld et al., 2019) | ResNet-101 | 32.8 | 42 M | 7.6 B |
| | LWS (Kang et al., 2020) | ResNeXt-50 | 33.9 | 25 M | 4.2 B |
| | DRAGON + Bal’Loss (Samuel et al., 2021) | ResNet-101 | 36.1 | 42 M | 7.6 B |
| | Teacher (labeler) model | EfficientNet-b3 | 65.3 | 12 M | 1.8 B |
| | One-hot | MobileNetV3-0.75 | 39.3 | 4.01 M | 156 M |
| | Distillation | MobileNetV3-0.75 | 42.2 | 4.01 M | 156 M |
| | TGT (random) | MobileNetV3-0.75 | 44.3 | 4.01 M | 156 M |
| | TGT (gradient-based) | MobileNetV3-0.75 | 44.7 | 4.01 M | 156 M |
Table 1: Performance of TGT and various baselines from the literature on long-tail image classification benchmarks (see App. E for results on Places-LT (Liu et al., 2019c)). Rows with * denote results taken from Menon et al. (2021b) and the rest were taken from Samuel et al. (2021). We report top-1 accuracy on balanced eval sets. We also state the number of model parameters and inference cost (in terms of FLOPs) for all the methods. Note that TGT leads to performance improvements over standard distillation on all three datasets, particularly for ImageNet-LT where the teacher generator models the task distribution well. TGT also often outperforms stated baselines that rely on much larger and expensive models.

4.1 Long-tail image classification

Setup. We evaluate TGT by training student models on three benchmark long-tail image classification datasets: ImageNet-LT (Liu et al., 2019c), SUN-LT (Patterson and Hays, 2012), and Places-LT (Liu et al., 2019c). We employ off-the-shelf teacher models, in particular BigBiGAN (ResNet-50) (Donahue and Simonyan, 2019) and EfficientNet-B3 (Xie et al., 2020c) as the teacher generator and teacher labeler, respectively. We utilize MobileNetV3 (Howard et al., 2019) as the compact student architecture. The teacher labeler is self-trained on JFT-300M (Sun et al., 2017) and then finetuned on the task-specific long-tail dataset. The teacher generator is trained on the unlabeled full version of ImageNet (Russakovsky et al., 2015).

Results. The results are reported in Table 1 alongside similarly sized baselines (we exclude gigantic transformer models); results for Places-LT appear in App. E due to space constraints. We see that TGT effectively transfers the knowledge acquired by the teacher during its training on huge amounts of data into a significantly smaller student model, which also has lower inference cost. TGT considerably improves performance across the board over standard distillation, even on SUN-LT and Places-LT, whose data distributions do not exactly match the distribution on which the teacher's generator was trained. Comparing TGT (random) (cf. Eq. 4) and TGT (gradient-based) (cf. Eq. 5) indicates that most of the gain comes from utilizing the latent space, with the form of search being of secondary importance. Thus, for all subsequent experiments we only consider TGT (random).

Here, we note that the baselines from the literature in Table 1 rely on specialized loss functions and/or training procedures designed for the long-tail setting, whereas we do not leverage such techniques. One could potentially combine the TGT framework with a long-tail-specific loss function, as opposed to the standard cross-entropy loss, to further improve its performance. We leave this direction for future exploration.

| Method | Architecture | Amazon-5 (2.5k) | Amazon-5 (3M) | IMDB | MNLI | Yelp-5 (2.5k) | Yelp-5 (650k) |
|---|---|---|---|---|---|---|---|
| UDA (Random Init) (Xie et al., 2020b) | BERT base | 55.8 | - | - | - | 58.6 | - |
| UDA (Pretrained) (Xie et al., 2020b) | BERT base | 62.9 | - | - | - | 67.9 | - |
| Layer-wise Distillation (Sun et al., 2020) | MobileBERT | - | - | 93.6 | 83.3 | - | - |
| MATE-KD (Rashid et al., 2021) | DistilBERT | - | - | - | 85.8 | - | - |
| Teacher (labeler) model | RoBERTa large | - | 67.6 | 96.2 | 90.6 | - | 72.0 |
| One-hot (Random Init) | DistilBERT | 44.3 | 53.6 | 50.0 | 63.0 | 50.4 | 58.1 |
| One-hot (Pretrained) | DistilBERT | 55.9 | 66.3 | 93.6 | 84.1 | 59.1 | 67.3 |
| Distillation (Random Init) | DistilBERT | 56.5 | 65.3 | 87.9 | 77.4 | 54.8 | 69.5 |
| Distillation (Pretrained) | DistilBERT | 60.2 | 66.8 | 94.0 | 84.5 | 63.2 | 71.4 |
| TGT (Random Init) | DistilBERT | 61.3 | 66.6 | 91.0 | 79.3 | 62.0 | 70.4 |
| TGT (Pretrained) | DistilBERT | 64.6 | 67.1 | 95.4 | 86.0 | 68.6 | 71.7 |
Table 2: Performance of TGT and various baselines from the literature on four text classification benchmarks. The teacher labeler, RoBERTa-large, is pretrained on a large corpus and finetuned on the task-specific dataset. BART-base serves as the task-agnostic teacher generator. For student model training, we show results for task-specific finetuning of both randomly initialized and pretrained DistilBERT models. Note that TGT (Pretrained), i.e., TGT with a pretrained student model, outperforms all other methods across the board. Even more interestingly, on Amazon-5 and Yelp-5, TGT with a randomly initialized student, i.e., TGT (Random Init), outperforms the standard approach of finetuning a pretrained model with one-hot labels, i.e., One-hot (Pretrained).
Figure 2: Performance comparison among normal training (one-hot), standard distillation (distillation), and TGT in simulated low-data regimes. We imitate a low-data regime by subsampling the ImageNet training set and evaluate the trained student models on the entire eval set. We employ 450k training steps for normal training and standard distillation, and 112k training steps for TGT. TGT outperforms the other methods in fewer training steps, effectively simulating an increase in sample size.

4.2 TGT in low-data regime

To further showcase the effectiveness of knowledge transfer via TGT, we simulate a low-data regime by varying the amount of available training data for ImageNet (Russakovsky et al., 2015) and studying its impact on the student's performance. For these experiments, we use the same model architectures as in Sec. 4.1, but finetune the teacher labeler on the entire ImageNet. We then compare the performance of the student trained via TGT with students trained via normal training (with one-hot labels) and standard distillation.

The results are shown in Fig. 2. Note that both TGT and standard distillation utilize additional training data more effectively than normal training, with TGT being the more efficient of the two. The curves show that TGT is equivalent to roughly a 4x increase in sample size compared to normal training. This verifies that TGT generates informative training instances for the student.

4.3 Text classification

Setup. We evaluate the proposed TGT framework on four benchmark text classification datasets: Amazon-5 (Zhang et al., 2015), IMDB (Maas et al., 2011), MNLI (Williams et al., 2018), and Yelp-5 (Zhang et al., 2015). Following Xie et al. (2020a), we also consider extremely subsampled versions of the Amazon-5 and Yelp-5 datasets consisting of only 2500 labeled examples. As in the image setting, we utilize off-the-shelf teacher models, in particular BART-base (Lewis et al., 2020) and RoBERTa-large (Liu et al., 2019a) as the teacher generator and teacher labeler, respectively. Following standard practice (Rashid et al., 2021), we employ a DistilBERT (Sanh et al., 2019b) model as the student architecture. Both teacher networks are pretrained on a very large generic text corpus of 160GB. The teacher labeler is finetuned on the corresponding dataset for each task; the teacher generator is not specialized to any specific classification task.

Results. The results are reported in Table 2, where we compare TGT with other data augmentation and distillation baselines. We see that TGT considerably improves performance and beats the state-of-the-art methods MATE-KD (Rashid et al., 2021) and UDA (Xie et al., 2020a). Also, note that by using TGT on a randomly initialized student, we can match the performance of normal finetuning (with one-hot labels) of a pretrained model on Amazon-5 and Yelp-5. We highlight that baselines such as MATE-KD (Rashid et al., 2021) always work with a pretrained student model. Thus, the aforementioned improvements realized by TGT with a randomly initialized student demonstrate enormous savings in overall data and training time, as they eliminate the need for pretraining on a large corpus. This further establishes that TGT enables data-efficient knowledge transfer from the teacher to the student.

4.4 Text retrieval

| Method | recall@20 | recall@100 |
|---|---|---|
| Teacher (labeler) model | 0.7957 | 0.8855 |
| One-hot | 0.6453 | 0.8198 |
| Distillation (standard) | 0.7486 | 0.8608 |
| Uniform negatives distillation | 0.7536 | 0.8496 |
| TGT distillation (ours) | 0.7659 | 0.8763 |
Table 3: Performance of TGT and various baselines on the NQ retrieval task (Kwiatkowski et al., 2019). The teacher labeler follows the setup of Oğuz et al. (2021): RoBERTa-base pretrained on a large corpus and on PAQ (Lewis et al., 2021), then finetuned on NQ (Kwiatkowski et al., 2019). BART-base (Lewis et al., 2020) serves as a task-agnostic generator. All student models follow the DistilBERT architecture (Sanh et al., 2019b). TGT significantly outperforms standard training (One-hot) and teacher-label-only distillation (Distillation), closing the teacher-student gap by 37% at recall@20 and 63% at recall@100 relative to standard distillation.

Setup. Finally, we evaluate TGT on Natural Questions (NQ) (Kwiatkowski et al., 2019), a text retrieval benchmark. The task is to find a matching passage for a given question from a large set of candidate passages (21M). We utilize the RoBERTa-base dual-encoder model of Oğuz et al. (2021) as our teacher labeler, and BART-base (Lewis et al., 2020) as the teacher generator. We utilize a DistilBERT dual-encoder model as our student architecture. We follow the standard retrieval distillation setup, where the teacher labeler provides labels for all within-batch question-to-passage pairs for the student to match.

We consider three baselines: One-hot trains the student with the original one-hot training labels, whereas Distillation utilizes the teacher labeler instead. In uniform negatives, for a given question-to-passage pair in NQ, we uniformly sample 2 additional passages from the entire passage corpus (21M) and label them with the teacher. TGT instead dynamically generates 2 confusing passages for each question-passage pair with the BART generator, applying the isotropic perturbation of Eq. 4.
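For illustration, a minimal sketch of the within-batch distillation objective with generated passages is given below (PyTorch-style; the inner-product scoring and all names are assumptions consistent with the setup described above, not a released implementation).

```python
import torch
import torch.nn.functional as F

def retrieval_distill_loss(q_emb, p_emb, t_scores, temp=1.0):
    """q_emb: (B, dim) student question embeddings; p_emb: (N, dim) student
    embeddings of the B original passages plus the generated ones (N = 3B
    when 2 confusing passages are generated per pair); t_scores: (B, N)
    teacher similarity scores over the same question-passage pairs. The
    student matches the teacher's softmax over within-batch pairs."""
    s_scores = q_emb @ p_emb.t()                       # (B, N) student scores
    return F.kl_div(F.log_softmax(s_scores / temp, dim=-1),
                    F.softmax(t_scores / temp, dim=-1),
                    reduction="batchmean")
```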

Results. Table 3 compares TGT with the other baselines. TGT significantly improves retrieval performance, closing the teacher-student gap by 37% at recall@20 and 63% at recall@100 compared to standard distillation. Unlike TGT, uniformly sampled random passages only partially help (improving slightly at recall@20 but degrading at recall@100 relative to standard distillation). A plausible explanation is that randomly sampled passages bear little relevance to the matching pair, since the output space is extremely large (21M). TGT instead generates informative passages that are close to the matching pair.

5 Conclusion and Future Directions

We have introduced a simple and formally justified distillation scheme (TGT) that adaptively generates samples with the aim of closing the divergence between student and teacher predictions. Our results show it to outperform, in aggregate, existing distillation approaches. Unlike alternative methods, it is also applicable to both continuous and discrete domains, as the results on image and text data show. TGT is orthogonal to other approaches that enable efficient inference such as quantization and pruning, and combining them is an interesting avenue for future work. Another potential research direction is to employ TGT for multi-modal data which would require accommodating multiple generative models with their own latent space, raising both practical and theoretical challenges.

References

  • Aguilar et al. [2020] Gustavo Aguilar, Yuan Ling, Yu Zhang, Benjamin Yao, Xing Fan, and Chenlei Guo. Knowledge distillation from internal representations. In Proceedings of the AAAI Conference on Artificial Intelligence, volume 34, pages 7350–7357, 2020.
  • Arjovsky et al. [2017] Martin Arjovsky, Soumith Chintala, and Léon Bottou. Wasserstein generative adversarial networks. In International conference on machine learning, pages 214–223. PMLR, 2017.
  • Arora et al. [2017] Sanjeev Arora, Rong Ge, Yingyu Liang, Tengyu Ma, and Yi Zhang. Generalization and equilibrium in generative adversarial nets (GANs). In Doina Precup and Yee Whye Teh, editors, Proceedings of the 34th International Conference on Machine Learning, volume 70 of Proceedings of Machine Learning Research, pages 224–232. PMLR, 06–11 Aug 2017.
  • Arora et al. [2018] Sanjeev Arora, Andrej Risteski, and Yi Zhang. Do GANs learn the distribution? some theory and empirics. In International Conference on Learning Representations, 2018.
  • Berthelot et al. [2019] David Berthelot, Nicholas Carlini, Ian Goodfellow, Nicolas Papernot, Avital Oliver, and Colin A Raffel. Mixmatch: A holistic approach to semi-supervised learning. In H. Wallach, H. Larochelle, A. Beygelzimer, F. d Alché-Buc, E. Fox, and R. Garnett, editors, Advances in Neural Information Processing Systems, volume 32. Curran Associates, Inc., 2019.
  • Binici et al. [2022] Kuluhan Binici, Nam Trung Pham, Tulika Mitra, and Karianto Leman. Preventing catastrophic forgetting and distribution mismatch in knowledge distillation via synthetic data. In Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision, pages 663–671, 2022.
  • Brock et al. [2019] Andrew Brock, Jeff Donahue, and Karen Simonyan. Large scale GAN training for high fidelity natural image synthesis. In International Conference on Learning Representations, 2019.
  • Brown et al. [2020] Tom Brown, Benjamin Mann, Nick Ryder, Melanie Subbiah, Jared D Kaplan, Prafulla Dhariwal, Arvind Neelakantan, Pranav Shyam, Girish Sastry, Amanda Askell, Sandhini Agarwal, Ariel Herbert-Voss, Gretchen Krueger, Tom Henighan, Rewon Child, Aditya Ramesh, Daniel Ziegler, Jeffrey Wu, Clemens Winter, Chris Hesse, Mark Chen, Eric Sigler, Mateusz Litwin, Scott Gray, Benjamin Chess, Jack Clark, Christopher Berner, Sam McCandlish, Alec Radford, Ilya Sutskever, and Dario Amodei. Language models are few-shot learners. In H. Larochelle, M. Ranzato, R. Hadsell, M. F. Balcan, and H. Lin, editors, Advances in Neural Information Processing Systems, volume 33, pages 1877–1901. Curran Associates, Inc., 2020.
  • Bucilua et al. [2006] Cristian Bucilua, Rich Caruana, and Alexandru Niculescu-Mizil. Model compression. In Proceedings of the 12th ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, KDD ’06, page 535–541, New York, NY, USA, 2006. Association for Computing Machinery. ISBN 1595933395. doi: 10.1145/1150402.1150464.
  • Cao et al. [2019] Kaidi Cao, Colin Wei, Adrien Gaidon, Nikos Arechiga, and Tengyu Ma. Learning imbalanced datasets with label-distribution-aware margin loss. In Proceedings of the 33rd International Conference on Neural Information Processing Systems, pages 1567–1578, 2019.
  • Chen et al. [2019] Hanting Chen, Yunhe Wang, Chang Xu, Zhaohui Yang, Chuanjian Liu, Boxin Shi, Chunjing Xu, Chao Xu, and Qi Tian. Data-free learning of student networks. In Proceedings of the IEEE/CVF International Conference on Computer Vision, pages 3514–3522, 2019.
  • Chen et al. [2020a] Ting Chen, Simon Kornblith, Mohammad Norouzi, and Geoffrey Hinton. A simple framework for contrastive learning of visual representations. In Hal Daumé III and Aarti Singh, editors, Proceedings of the 37th International Conference on Machine Learning, volume 119 of Proceedings of Machine Learning Research, pages 1597–1607. PMLR, 13–18 Jul 2020a.
  • Chen et al. [2020b] Ting Chen, Simon Kornblith, Kevin Swersky, Mohammad Norouzi, and Geoffrey E Hinton. Big self-supervised models are strong semi-supervised learners. In H. Larochelle, M. Ranzato, R. Hadsell, M. F. Balcan, and H. Lin, editors, Advances in Neural Information Processing Systems, volume 33, pages 22243–22255. Curran Associates, Inc., 2020b.
  • Cho and Hariharan [2019] J. Cho and B. Hariharan. On the efficacy of knowledge distillation. In 2019 IEEE/CVF International Conference on Computer Vision (ICCV), pages 4793–4801, Los Alamitos, CA, USA, nov 2019. IEEE Computer Society.
  • Devlin et al. [2019] Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. BERT: pre-training of deep bidirectional transformers for language understanding. In Jill Burstein, Christy Doran, and Thamar Solorio, editors, Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, NAACL-HLT 2019, Minneapolis, MN, USA, June 2-7, 2019, Volume 1 (Long and Short Papers), pages 4171–4186. Association for Computational Linguistics, 2019. doi: 10.18653/v1/n19-1423.
  • Devroye et al. [2013] Luc Devroye, László Györfi, and Gábor Lugosi. A probabilistic theory of pattern recognition, volume 31. Springer Science & Business Media, 2013.
  • Donahue and Simonyan [2019] Jeff Donahue and Karen Simonyan. Large scale adversarial representation learning. In H. Wallach, H. Larochelle, A. Beygelzimer, F. d Alché-Buc, E. Fox, and R. Garnett, editors, Advances in Neural Information Processing Systems, volume 32. Curran Associates, Inc., 2019.
  • Dong et al. [2020] Zihe Dong, Xin Sun, Junyu Dong, and Haoran Zhao. Adversarial metric knowledge distillation. In 2020 the 6th International Conference on Communication and Information Processing, page 159–164, New York, NY, USA, 2020. Association for Computing Machinery. ISBN 9781450388092.
  • Du et al. [2021] Jingfei Du, Edouard Grave, Beliz Gunel, Vishrav Chaudhary, Onur Celebi, Michael Auli, Veselin Stoyanov, and Alexis Conneau. Self-training improves pre-training for natural language understanding. In Proceedings of the 2021 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, pages 5408–5418, Online, June 2021. Association for Computational Linguistics. doi: 10.18653/v1/2021.naacl-main.426.
  • Fan et al. [2018] Yang Fan, Fei Tian, Tao Qin, Xiang-Yang Li, and Tie-Yan Liu. Learning to teach. In International Conference on Learning Representations, 2018.
  • Foret et al. [2020] Pierre Foret, Ariel Kleiner, Hossein Mobahi, and Behnam Neyshabur. Sharpness-aware minimization for efficiently improving generalization. arXiv preprint arXiv:2010.01412, 2020.
  • Foster et al. [2019] Dylan J Foster, Spencer Greenberg, Satyen Kale, Haipeng Luo, Mehryar Mohri, and Karthik Sridharan. Hypothesis set stability and generalization. In H. Wallach, H. Larochelle, A. Beygelzimer, F. d’Alché Buc, E. Fox, and R. Garnett, editors, Advances in Neural Information Processing Systems, volume 32. Curran Associates, Inc., 2019.
  • Gottlieb et al. [2016] Lee-Ad Gottlieb, Aryeh Kontorovich, and Robert Krauthgamer. Adaptive metric dimensionality reduction. Theoretical Computer Science, 620:105–118, 2016.
  • Györfi et al. [2002] László Györfi, Michael Kohler, Adam Krzyżak, and Harro Walk. A distribution-free theory of nonparametric regression, volume 1. Springer, 2002.
  • Hafner et al. [2019] Danijar Hafner, Timothy Lillicrap, Ian Fischer, Ruben Villegas, David Ha, Honglak Lee, and James Davidson. Learning latent dynamics for planning from pixels. In International Conference on Machine Learning, pages 2555–2565. PMLR, 2019.
  • He et al. [2016] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition, pages 770–778, 2016.
  • He et al. [2021] Xuanli He, Islam Nassar, Jamie Kiros, Gholamreza Haffari, and Mohammad Norouzi. Generate, annotate, and learn: Generative models advance self-training and knowledge distillation. arXiv preprint arXiv:2106.06168, 2021.
  • Heinzerling and Inui [2021] Benjamin Heinzerling and Kentaro Inui. Language models as knowledge bases: On entity representations, storage capacity, and paraphrased queries. In Proceedings of the 16th Conference of the European Chapter of the Association for Computational Linguistics: Main Volume, pages 1772–1791, Online, April 2021. Association for Computational Linguistics. doi: 10.18653/v1/2021.eacl-main.153.
  • Heo et al. [2019] Byeongho Heo, Minsik Lee, Sangdoo Yun, and Jin Young Choi. Knowledge distillation with adversarial samples supporting decision boundary. AAAI’19/IAAI’19/EAAI’19. AAAI Press, 2019. ISBN 978-1-57735-809-1.
  • Hinton et al. [2015] Geoffrey Hinton, Oriol Vinyals, and Jeff Dean. Distilling the knowledge in a neural network. arXiv preprint arXiv:1503.02531, 2015.
  • Howard et al. [2019] Andrew Howard, Mark Sandler, Grace Chu, Liang-Chieh Chen, Bo Chen, Mingxing Tan, Weijun Wang, Yukun Zhu, Ruoming Pang, Vijay Vasudevan, Quoc V. Le, and Hartwig Adam. Searching for mobilenetv3. In Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV), October 2019.
  • Iscen et al. [2019] Ahmet Iscen, Giorgos Tolias, Yannis Avrithis, and Ondrej Chum. Label propagation for deep semi-supervised learning. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), June 2019.
  • Jaegle et al. [2021] Andrew Jaegle, Felix Gimeno, Andy Brock, Oriol Vinyals, Andrew Zisserman, and Joao Carreira. Perceiver: General perception with iterative attention. In Marina Meila and Tong Zhang, editors, Proceedings of the 38th International Conference on Machine Learning, volume 139 of Proceedings of Machine Learning Research, pages 4651–4664. PMLR, 18–24 Jul 2021.
  • Jia et al. [2018] Xianyan Jia, Shutao Song, Wei He, Yangzihao Wang, Haidong Rong, Feihu Zhou, Liqiang Xie, Zhenyu Guo, Yuanzhou Yang, Liwei Yu, et al. Highly scalable deep learning training system with mixed-precision: Training imagenet in four minutes. arXiv preprint arXiv:1807.11205, 2018.
  • Kang et al. [2020] Bingyi Kang, Saining Xie, Marcus Rohrbach, Zhicheng Yan, Albert Gordo, Jiashi Feng, and Yannis Kalantidis. Decoupling representation and classifier for long-tailed recognition. In International Conference on Learning Representations, 2020.
  • Karpukhin et al. [2020] Vladimir Karpukhin, Barlas Oğuz, Sewon Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih. Dense passage retrieval for open-domain question answering. arXiv preprint arXiv:2004.04906, 2020.
  • Kwiatkowski et al. [2019] Tom Kwiatkowski, Jennimaria Palomaki, Olivia Redfield, Michael Collins, Ankur Parikh, Chris Alberti, Danielle Epstein, Illia Polosukhin, Jacob Devlin, Kenton Lee, et al. Natural questions: a benchmark for question answering research. Transactions of the Association for Computational Linguistics, 7:453–466, 2019.
  • Laine and Aila [2017] Samuli Laine and Timo Aila. Temporal ensembling for semi-supervised learning. In International Conference on Learning Representations, 2017.
  • Lee et al. [2019] Kenton Lee, Ming-Wei Chang, and Kristina Toutanova. Latent retrieval for weakly supervised open domain question answering. In Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics, pages 6086–6096, Florence, Italy, July 2019. Association for Computational Linguistics. doi: 10.18653/v1/P19-1612.
  • Lewis et al. [2020] Mike Lewis, Yinhan Liu, Naman Goyal, Marjan Ghazvininejad, Abdelrahman Mohamed, Omer Levy, Veselin Stoyanov, and Luke Zettlemoyer. Bart: Denoising sequence-to-sequence pre-training for natural language generation, translation, and comprehension. In Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics, pages 7871–7880, 2020.
  • Lewis et al. [2021] Patrick Lewis, Yuxiang Wu, Linqing Liu, Pasquale Minervini, Heinrich Küttler, Aleksandra Piktus, Pontus Stenetorp, and Sebastian Riedel. Paq: 65 million probably-asked questions and what you can do with them. Transactions of the Association for Computational Linguistics, 9:1098–1115, 2021.
  • Li et al. [2019] Hao-Ting Li, Shih-Chieh Lin, Cheng-Yeh Chen, and Chen-Kuo Chiang. Layer-level knowledge distillation for deep neural network learning. Applied Sciences, 9(10):1966, 2019.
  • Liu et al. [2019a] Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, and Veselin Stoyanov. RoBERTa: A robustly optimized bert pretraining approach. arXiv preprint arXiv:1907.11692, 2019a.
  • Liu et al. [2019b] Ziwei Liu, Zhongqi Miao, Xiaohang Zhan, Jiayun Wang, Boqing Gong, and Stella X. Yu. Large-scale long-tailed recognition in an open world. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), June 2019b.
  • Liu et al. [2019c] Ziwei Liu, Zhongqi Miao, Xiaohang Zhan, Jiayun Wang, Boqing Gong, and Stella X Yu. Large-scale long-tailed recognition in an open world. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 2537–2546, 2019c.
  • Maas et al. [2011] Andrew L. Maas, Raymond E. Daly, Peter T. Pham, Dan Huang, Andrew Y. Ng, and Christopher Potts. Learning word vectors for sentiment analysis. In Proceedings of the 49th Annual Meeting of the Association for Computational Linguistics: Human Language Technologies, pages 142–150, Portland, Oregon, USA, June 2011. Association for Computational Linguistics.
  • Maurer and Pontil [2009] A. Maurer and M. Pontil. Empirical Bernstein bounds and sample-variance penalization. In Proceedings of the 22nd Conference on Learning Theory (COLT), June 2009.
  • Menon et al. [2021a] Aditya K Menon, Ankit Singh Rawat, Sashank Reddi, Seungyeon Kim, and Sanjiv Kumar. A statistical perspective on distillation. In Marina Meila and Tong Zhang, editors, Proceedings of the 38th International Conference on Machine Learning, volume 139 of Proceedings of Machine Learning Research, pages 7632–7642. PMLR, 18–24 Jul 2021a.
  • Menon et al. [2021b] Aditya Krishna Menon, Sadeep Jayasumana, Ankit Singh Rawat, Himanshu Jain, Andreas Veit, and Sanjiv Kumar. Long-tail learning via logit adjustment. In International Conference on Learning Representations, 2021b.
  • Micaelli and Storkey [2019] Paul Micaelli and Amos J Storkey. Zero-shot knowledge transfer via adversarial belief matching. Advances in Neural Information Processing Systems, 32, 2019.
  • Miyato et al. [2019] Takeru Miyato, Shin-Ichi Maeda, Masanori Koyama, and Shin Ishii. Virtual adversarial training: A regularization method for supervised and semi-supervised learning. IEEE Transactions on Pattern Analysis and Machine Intelligence, 41(8):1979–1993, 2019. doi: 10.1109/TPAMI.2018.2858821.
  • Mohri et al. [2018] Mehryar Mohri, Afshin Rostamizadeh, and Ameet Talwalkar. Foundations of machine learning. MIT press, 2018.
  • Nayak et al. [2019] Gaurav Kumar Nayak, Konda Reddy Mopuri, Vaisakh Shaj, Venkatesh Babu Radhakrishnan, and Anirban Chakraborty. Zero-shot knowledge distillation in deep networks. In International Conference on Machine Learning, pages 4743–4751. PMLR, 2019.
  • Nowak-Vila et al. [2021] Alex Nowak-Vila, Alessandro Rudi, and Francis Bach. Max-margin is dead, long live max-margin! arXiv preprint arXiv:2105.15069, 2021.
  • Oğuz et al. [2021] Barlas Oğuz, Kushal Lakhotia, Anchit Gupta, Patrick Lewis, Vladimir Karpukhin, Aleksandra Piktus, Xilun Chen, Sebastian Riedel, Wen-tau Yih, Sonal Gupta, et al. Domain-matched pre-training tasks for dense retrieval. arXiv preprint arXiv:2107.13602, 2021.
  • Patterson and Hays [2012] Genevieve Patterson and James Hays. Sun attribute database: Discovering, annotating, and recognizing scene attributes. In 2012 IEEE Conference on Computer Vision and Pattern Recognition, pages 2751–2758. IEEE, 2012.
  • Pham et al. [2021] Hieu Pham, Zihang Dai, Qizhe Xie, and Quoc V. Le. Meta pseudo labels. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), pages 11557–11568, June 2021.
  • Raffel et al. [2020] Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang, Michael Matena, Yanqi Zhou, Wei Li, and Peter J. Liu. Exploring the limits of transfer learning with a unified text-to-text transformer. Journal of Machine Learning Research, 21(140):1–67, 2020.
  • Raghu et al. [2021] Aniruddh Raghu, Maithra Raghu, Simon Kornblith, David Duvenaud, and Geoffrey Hinton. Teaching with commentaries. In International Conference on Learning Representations, 2021.
  • Ramesh et al. [2021] Aditya Ramesh, Mikhail Pavlov, Gabriel Goh, Scott Gray, Chelsea Voss, Alec Radford, Mark Chen, and Ilya Sutskever. Zero-shot text-to-image generation. In Marina Meila and Tong Zhang, editors, Proceedings of the 38th International Conference on Machine Learning, volume 139 of Proceedings of Machine Learning Research, pages 8821–8831. PMLR, 18–24 Jul 2021.
  • Rashid et al. [2021] Ahmad Rashid, Vasileios Lioutas, and Mehdi Rezagholizadeh. MATE-KD: Masked adversarial TExt, a companion to knowledge distillation. In Proceedings of the 59th Annual Meeting of the Association for Computational Linguistics and the 11th International Joint Conference on Natural Language Processing (Volume 1: Long Papers), pages 1062–1071, Online, August 2021. Association for Computational Linguistics. doi: 10.18653/v1/2021.acl-long.86.
  • Ravuri and Vinyals [2019a] Suman Ravuri and Oriol Vinyals. Classification accuracy score for conditional generative models. In H. Wallach, H. Larochelle, A. Beygelzimer, F. d'Alché-Buc, E. Fox, and R. Garnett, editors, Advances in Neural Information Processing Systems, volume 32. Curran Associates, Inc., 2019a.
  • Ravuri and Vinyals [2019b] Suman Ravuri and Oriol Vinyals. Seeing is not necessarily believing: Limitations of biggans for data augmentation. Learning from Limited Labeled Data: ICLR 2019 Workshop, 2019b.
  • Razavi et al. [2019] Ali Razavi, Aaron van den Oord, and Oriol Vinyals. Generating diverse high-fidelity images with VQ-VAE-2. In H. Wallach, H. Larochelle, A. Beygelzimer, F. d'Alché-Buc, E. Fox, and R. Garnett, editors, Advances in Neural Information Processing Systems, volume 32. Curran Associates, Inc., 2019.
  • Ren et al. [2019] Shuhuai Ren, Yihe Deng, Kun He, and Wanxiang Che. Generating natural language adversarial examples through probability weighted word saliency. In Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics, pages 1085–1097, Florence, Italy, July 2019. Association for Computational Linguistics.
  • Russakovsky et al. [2015] Olga Russakovsky, Jia Deng, Hao Su, Jonathan Krause, Sanjeev Satheesh, Sean Ma, Zhiheng Huang, Andrej Karpathy, Aditya Khosla, Michael Bernstein, Alexander C. Berg, and Li Fei-Fei. ImageNet Large Scale Visual Recognition Challenge. International Journal of Computer Vision (IJCV), 115(3):211–252, 2015. doi: 10.1007/s11263-015-0816-y.
  • Sajjadi et al. [2016] Mehdi Sajjadi, Mehran Javanmardi, and Tolga Tasdizen. Regularization with stochastic transformations and perturbations for deep semi-supervised learning. In D. Lee, M. Sugiyama, U. Luxburg, I. Guyon, and R. Garnett, editors, Advances in Neural Information Processing Systems, volume 29. Curran Associates, Inc., 2016.
  • Samuel et al. [2021] Dvir Samuel, Yuval Atzmon, and Gal Chechik. From generalized zero-shot learning to long-tail with class descriptors. In Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision (WACV), 2021.
  • Sanh et al. [2019a] Victor Sanh, Lysandre Debut, Julien Chaumond, and Thomas Wolf. Distilbert, a distilled version of BERT: smaller, faster, cheaper and lighter. arXiv preprint arXiv:1910.01108, 2019a.
  • Sanh et al. [2019b] Victor Sanh, Lysandre Debut, Julien Chaumond, and Thomas Wolf. Distilbert, a distilled version of bert: smaller, faster, cheaper and lighter. arXiv preprint arXiv:1910.01108, 2019b.
  • Schönfeld et al. [2019] Edgar Schönfeld, Sayna Ebrahimi, Samarth Sinha, Trevor Darrell, and Zeynep Akata. Generalized zero-shot learning via aligned variational autoencoders. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), 2019.
  • Sohn et al. [2020] Kihyuk Sohn, David Berthelot, Nicholas Carlini, Zizhao Zhang, Han Zhang, Colin A Raffel, Ekin Dogus Cubuk, Alexey Kurakin, and Chun-Liang Li. Fixmatch: Simplifying semi-supervised learning with consistency and confidence. In H. Larochelle, M. Ranzato, R. Hadsell, M. F. Balcan, and H. Lin, editors, Advances in Neural Information Processing Systems, volume 33, pages 596–608. Curran Associates, Inc., 2020.
  • Stanton et al. [2021] Samuel Don Stanton, Pavel Izmailov, Polina Kirichenko, Alexander A Alemi, and Andrew Gordon Wilson. Does knowledge distillation really work? In A. Beygelzimer, Y. Dauphin, P. Liang, and J. Wortman Vaughan, editors, Advances in Neural Information Processing Systems, 2021.
  • Sun et al. [2017] C. Sun, A. Shrivastava, S. Singh, and A. Gupta. Revisiting unreasonable effectiveness of data in deep learning era. In 2017 IEEE International Conference on Computer Vision (ICCV), pages 843–852, 2017. doi: 10.1109/ICCV.2017.97.
  • Sun et al. [2019] Siqi Sun, Yu Cheng, Zhe Gan, and Jingjing Liu. Patient knowledge distillation for BERT model compression. In Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP), pages 4323–4332, Hong Kong, China, November 2019. Association for Computational Linguistics.
  • Sun et al. [2020] Zhiqing Sun, Hongkun Yu, Xiaodan Song, Renjie Liu, Yiming Yang, and Denny Zhou. MobileBERT: a compact task-agnostic BERT for resource-limited devices. In Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics, pages 2158–2170, Online, July 2020. Association for Computational Linguistics.
  • Tan et al. [2020] Samson Tan, Shafiq Joty, Min-Yen Kan, and Richard Socher. It’s morphin’ time! Combating linguistic discrimination with inflectional perturbations. In Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics, pages 2920–2935, Online, July 2020. Association for Computational Linguistics.
  • Tarvainen and Valpola [2017] Antti Tarvainen and Harri Valpola. Mean teachers are better role models: Weight-averaged consistency targets improve semi-supervised deep learning results. In I. Guyon, U. V. Luxburg, S. Bengio, H. Wallach, R. Fergus, S. Vishwanathan, and R. Garnett, editors, Advances in Neural Information Processing Systems, volume 30. Curran Associates, Inc., 2017.
  • Verma et al. [2019] Vikas Verma, Alex Lamb, Juho Kannala, Yoshua Bengio, and David Lopez-Paz. Interpolation consistency training for semi-supervised learning. In Proceedings of the Twenty-Eighth International Joint Conference on Artificial Intelligence (IJCAI-19), pages 3635–3641, 2019. International Joint Conference on Artificial Intelligence, IJCAI ; Conference date: 10-08-2019 Through 16-08-2019.
  • Villani [2008] C. Villani. Optimal transport: old and new. Springer Verlag, 2008.
  • Wang et al. [2021] Jianfeng Wang, Thomas Lukasiewicz, Xiaolin Hu, Jianfei Cai, and Zhenghua Xu. Rsg: A simple but effective module for learning imbalanced datasets. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), pages 3784–3793, June 2021.
  • Wang et al. [2018] Xiaojie Wang, Rui Zhang, Yu Sun, and Jianzhong Qi. Kdgan: Knowledge distillation with generative adversarial networks. In S. Bengio, H. Wallach, H. Larochelle, K. Grauman, N. Cesa-Bianchi, and R. Garnett, editors, Advances in Neural Information Processing Systems, volume 31. Curran Associates, Inc., 2018.
  • Webster et al. [2019] Ryan Webster, Julien Rabin, Loïc Simon, and Frédéric Jurie. Detecting overfitting of deep generative networks via latent recovery. In 2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), pages 11265–11274, 2019. doi: 10.1109/CVPR.2019.01153.
  • Weng et al. [2021] Zhenzhen Weng, Mehmet Giray Ogut, Shai Limonchik, and Serena Yeung. Unsupervised discovery of the long-tail in instance segmentation using hierarchical self-supervision. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 2603–2612, 2021.
  • Williams et al. [2018] Adina Williams, Nikita Nangia, and Samuel Bowman. A broad-coverage challenge corpus for sentence understanding through inference. In Proceedings of the 2018 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long Papers), pages 1112–1122. Association for Computational Linguistics, 2018.
  • Xie et al. [2020a] Qizhe Xie, Zihang Dai, Eduard Hovy, Thang Luong, and Quoc Le. Unsupervised data augmentation for consistency training. In H. Larochelle, M. Ranzato, R. Hadsell, M. F. Balcan, and H. Lin, editors, Advances in Neural Information Processing Systems, volume 33, pages 6256–6268. Curran Associates, Inc., 2020a.
  • Xie et al. [2020b] Qizhe Xie, Zihang Dai, Eduard Hovy, Thang Luong, and Quoc Le. Unsupervised data augmentation for consistency training. Advances in Neural Information Processing Systems, 33, 2020b.
  • Xie et al. [2020c] Qizhe Xie, Minh-Thang Luong, Eduard Hovy, and Quoc V. Le. Self-training with noisy student improves imagenet classification. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), June 2020c.
  • Xie et al. [2020d] Qizhe Xie, Minh-Thang Luong, Eduard Hovy, and Quoc V. Le. Self-training with noisy student improves imagenet classification. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), June 2020d.
  • Yoo et al. [2019] Jaemin Yoo, Minyong Cho, Taebum Kim, and U Kang. Knowledge extraction with no observable data. Advances in Neural Information Processing Systems, 32, 2019.
  • Zang et al. [2020] Yuan Zang, Fanchao Qi, Chenghao Yang, Zhiyuan Liu, Meng Zhang, Qun Liu, and Maosong Sun. Word-level textual adversarial attacking as combinatorial optimization. In Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics, pages 6066–6080, Online, July 2020. Association for Computational Linguistics. doi: 10.18653/v1/2020.acl-main.540.
  • Zhang et al. [2015] Xiang Zhang, Junbo Zhao, and Yann LeCun. Character-level convolutional networks for text classification. Advances in neural information processing systems, 28:649–657, 2015.
  • Zhu et al. [2021] Zhaowei Zhu, Tianyi Luo, and Yang Liu. The rich get richer: Disparate impact of semi-supervised learning. arXiv preprint arXiv:2110.06282, 2021.

Appendix A Further comparison with MATE-KD

MATE-KD [Rashid et al., 2021] alternately trains a generator model and the student model, with the aim of generating maximally adversarial examples for the student during training. This can cause the generator to drift away from the true data distribution. In contrast, we keep the pretrained teacher-generator model fixed throughout the student's training. Our objective in employing the generator model is to leverage the domain knowledge it has already acquired during pretraining. While we do want to generate 'hard instances' for the student, we also want those instances to remain relevant to the underlying task. Keeping the generator fixed thus acts as a regularizer: the training instances the student encounters do not introduce a domain mismatch.

With the goal of producing new, informative training instances that remain in-domain, we introduce perturbations in the latent space realized by the encoder of the teacher-generator model (see Figure 1); a minimal sketch of this latent-space exploration follows below. This differs from directly perturbing an original training instance in the input space, as done by MATE-KD. As evidenced by our theoretical analysis and empirical evaluation, for a fixed teacher-generator model, perturbing in the latent space leads to more informative data augmentation and yields good performance in both the image and text domains.
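To make the contrast concrete, the snippet below sketches TGT's latent-space exploration around a frozen generator. This is a minimal illustration in our notation, not the released training code; `enc`, `dec`, and `grad_loss_z` are assumed callables standing in for the teacher-generator encoder, its decoder, and the gradient of the distillation loss with respect to the latent code.

```python
import numpy as np

rng = np.random.default_rng(0)

def tgt_random_exploration(z, sigma):
    """Random latent exploration (cf. Eq. 4): perturb the latent code of a
    real example with isotropic Gaussian noise; the decoded sample stays
    close to the data manifold because the generator itself is frozen."""
    return z + sigma * rng.standard_normal(z.shape)

def tgt_gradient_exploration(z, grad_loss_z, eta, num_steps=2):
    """Gradient-based latent exploration (cf. Eq. 5): take a few ascent
    steps on the distillation loss l_d(f(dec(z)), h(dec(z))) w.r.t. z,
    producing decoded samples that are 'hard' for the current student
    while remaining in-domain."""
    for _ in range(num_steps):
        z = z + eta * grad_loss_z(z)  # grad_loss_z is an assumed callable
    return z

# Usage sketch: x_tilde = dec(tgt_random_exploration(enc(x), sigma=0.01)),
# with enc/dec the frozen teacher-generator encoder/decoder.
```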

Appendix B Background and notation

For $a,b\in\mathbb{R}$, we use $a=\mathcal{O}(b)$ to denote that there exists a constant $\gamma>0$ such that $a\leq\gamma\cdot b$.

Given a collection of $n$ i.i.d. random variables $\mathscr{U}_{n}=\{u_{1},\ldots,u_{n}\}\subset\mathcal{U}$ generated from a distribution $\mathscr{D}_{U}$, and a function $\tau:\mathcal{U}\to\mathbb{R}$, we define the empirical mean of $\{\tau(u_{1}),\ldots,\tau(u_{n})\}$ as

\displaystyle\mathbb{E}_{\mathscr{U}_{n}}[\tau(U)]:=\frac{1}{n}\sum_{i\in[n]}\tau(u_{i}). (8)

For the underlying multiclass classification problem defined by the distribution $\mathscr{D}:=\mathscr{D}_{X\times Y}$, we assume that the label set $\mathscr{Y}$ with $K$ classes takes the form $[K]:=\{1,\ldots,K\}$. We use $\mathcal{F}$ to denote the collection of candidate classification models that the learning method is allowed to select from, namely the function class or hypothesis set:

\displaystyle\mathcal{F}\subseteq\{\mathscr{X}\to\mathbb{R}^{K}\}, (9)

which is a subset of all functions that map elements of the instance space $\mathscr{X}$ to elements of $\mathbb{R}^{K}$.

Given a classification loss function $\ell:\mathbb{R}^{K}\times\mathscr{Y}\to\mathbb{R}$, a model $f:\mathscr{X}\to\mathbb{R}^{K}$, and a sample $\mathscr{S}^{\rm labeled}_{n}=\{(x_{i},y_{i})\}_{i\in[n]}$ generated from $\mathscr{D}$, we define the empirical risk for $f\in\mathcal{F}$ as follows.

\displaystyle R_{\ell,f}(\mathscr{S}^{\rm labeled}_{n}):=\mathbb{E}_{\mathscr{S}^{\rm labeled}_{n}}[\ell(f(X),Y)]=\frac{1}{n}\sum\nolimits_{i\in[n]}\ell\big(f(x_{i}),y_{i}\big). (10)

Further, we define the population risk for $f\in\mathcal{F}$ associated with the data distribution $\mathscr{D}$ as follows.

\displaystyle R_{\ell,f}(\mathscr{D})=\mathbb{E}_{X,Y\sim\mathscr{D}}[\ell(f(X),Y)]. (11)

Note that, when the loss function $\ell$ is clear from the context, we drop $\ell$ from the notation and simply use $R_{f}(\mathscr{S}^{\rm labeled}_{n})$ and $R_{f}(\mathscr{D})$ to denote the empirical and population risks for $f$, respectively.

Given a function class $\mathcal{F}$, the loss function $\ell$ induces the following function class.

\displaystyle\mathcal{G}_{\ell,\mathcal{F}}=\big\{(x,y)\mapsto\ell(f(x),y):f\in\mathcal{F}\big\}. (12)
Definition B.1 (Rademacher complexity of $\mathcal{G}_{\ell,\mathcal{F}}$).

Given a sample $\mathscr{S}^{\rm labeled}_{n}=\{(x_{i},y_{i})\}_{i\in[n]}\sim\mathscr{D}^{n}$ and a vector $\bm{\sigma}=(\sigma_{1},\ldots,\sigma_{n})\in\{+1,-1\}^{n}$ of $n$ i.i.d. Rademacher (uniform random sign) variables, the empirical Rademacher complexity $\mathfrak{R}_{\mathscr{S}^{\rm labeled}_{n}}(\mathcal{G}_{\ell,\mathcal{F}})$ and the Rademacher complexity $\mathfrak{R}_{n}(\mathcal{G}_{\ell,\mathcal{F}})$ are defined as

\displaystyle\mathfrak{R}_{\mathscr{S}^{\rm labeled}_{n}}(\mathcal{G}_{\ell,\mathcal{F}})=\frac{1}{n}\mathbb{E}_{\bm{\sigma}}\left[\sup_{g\in\mathcal{G}_{\ell,\mathcal{F}}}\sum_{i=1}^{n}\sigma_{i}g(x_{i},y_{i})\right]

and

\displaystyle\mathfrak{R}_{n}(\mathcal{G}_{\ell,\mathcal{F}})=\mathbb{E}_{\mathscr{S}^{\rm labeled}_{n}\sim\mathscr{D}^{n}}\left[\mathfrak{R}_{\mathscr{S}^{\rm labeled}_{n}}(\mathcal{G}_{\ell,\mathcal{F}})\right], (13)

respectively.

Let $\mathscr{S}_{n}=\{x_{i}\}_{i\in[n]}$ be a set of $n$ unlabeled samples generated from $\mathscr{D}_{X}$. Then, given a teacher model $h:\mathscr{X}\to\mathbb{R}^{K}$ and a distillation loss $\ell_{\mathrm{d}}:\mathbb{R}^{K}\times\mathbb{R}^{K}\to\mathbb{R}$, we define the empirical (distillation) risk for $f\in\mathcal{F}$ to be

\displaystyle R^{h}_{\ell_{\mathrm{d}},f}(\mathscr{S}_{n}):=\mathbb{E}_{\mathscr{S}_{n}}\left[\ell_{\mathrm{d}}(f(X),h(X))\right]=\frac{1}{n}\sum_{i\in[n]}\ell_{\mathrm{d}}\big(f(x_{i}),h(x_{i})\big). (14)

Accordingly, the population (distillation) risk for $f\in\mathcal{F}$ is defined as

\displaystyle R^{h}_{\ell_{\mathrm{d}},f}(\mathscr{D}):=\mathbb{E}_{X\sim\mathscr{D}_{X}}\left[\ell_{\mathrm{d}}(f(X),h(X))\right]. (15)

Again, when $\ell_{\mathrm{d}}$ is clear from the context, we simply use $R^{h}_{f}(\mathscr{S}_{n})$ and $R^{h}_{f}(\mathscr{D})$ to denote the empirical and population distillation risks for $f$, respectively.
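As a concrete instance of Eq. 14, the sketch below computes the empirical distillation risk when $\ell_{\mathrm{d}}$ is taken to be the cross-entropy between teacher and student distributions (one common choice, cf. Remark C.2 below); the logits are toy numbers, not tied to any particular model.

```python
import numpy as np

def softmax(logits):
    e = np.exp(logits - logits.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def empirical_distillation_risk(student_logits, teacher_logits):
    """Eq. 14 with cross-entropy as l_d: the average, over the n unlabeled
    samples, of the cross-entropy between the teacher's soft labels h(x_i)
    and the student's predictions f(x_i)."""
    f = softmax(student_logits)   # student distributions f(x_i)
    h = softmax(teacher_logits)   # teacher soft labels h(x_i)
    return float(np.mean(-np.sum(h * np.log(f + 1e-12), axis=-1)))

# Toy example with n = 2 samples and K = 3 classes:
student = np.array([[2.0, 0.5, -1.0], [0.1, 1.2, 0.3]])
teacher = np.array([[2.4, 0.3, -0.8], [0.0, 1.5, 0.2]])
print(empirical_distillation_risk(student, teacher))
```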

Note that, for a (student) function class $\mathcal{F}$ and a teacher model $h$, the distillation loss $\ell_{\mathrm{d}}$ produces an induced function class $\mathcal{G}^{h}_{\ell_{\mathrm{d}},\mathcal{F}}$, defined as

\displaystyle\mathcal{G}^{h}_{\ell_{\mathrm{d}},\mathcal{F}}:=\{x\mapsto\ell_{\mathrm{d}}(f(x),h(x)):f\in\mathcal{F}\}. (16)

Definition B.2 (Rademacher complexity of $\mathcal{G}^{h}_{\ell_{\mathrm{d}},\mathcal{F}}$).

Given a sample $\mathscr{S}_{n}=\{x_{i}\}_{i\in[n]}\sim\mathscr{D}_{X}^{n}$ and a vector $\bm{\sigma}=(\sigma_{1},\ldots,\sigma_{n})\in\{+1,-1\}^{n}$ of $n$ i.i.d. Rademacher random variables, the empirical Rademacher complexity $\mathfrak{R}_{\mathscr{S}_{n}}\big(\mathcal{G}^{h}_{\ell_{\mathrm{d}},\mathcal{F}}\big)$ and the Rademacher complexity $\mathfrak{R}_{n}\big(\mathcal{G}^{h}_{\ell_{\mathrm{d}},\mathcal{F}}\big)$ are defined as

\displaystyle\mathfrak{R}_{\mathscr{S}_{n}}(\mathcal{G}^{h}_{\ell_{\mathrm{d}},\mathcal{F}})=\frac{1}{n}\mathbb{E}_{\bm{\sigma}}\Big[\sup_{g\in\mathcal{G}^{h}_{\ell_{\mathrm{d}},\mathcal{F}}}\sum_{i=1}^{n}\sigma_{i}g(x_{i})\Big], (17)

and

\displaystyle\mathfrak{R}_{n}(\mathcal{G}^{h}_{\ell_{\mathrm{d}},\mathcal{F}})=\mathbb{E}_{\mathscr{S}_{n}\sim\mathscr{D}_{X}^{n}}\left[\mathfrak{R}_{\mathscr{S}_{n}}(\mathcal{G}^{h}_{\ell_{\mathrm{d}},\mathcal{F}})\right], (18)

respectively.
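Definitions B.1 and B.2 can be made tangible with a small Monte Carlo estimate: for a finite pool of candidate functions evaluated on the sample, draw random sign vectors and average the resulting suprema. This is an illustration for intuition only; real hypothesis classes are uncountable, so a finite pool only lower-bounds the supremum.

```python
import numpy as np

rng = np.random.default_rng(0)

def empirical_rademacher(G_values, num_draws=1000):
    """Monte Carlo estimate of the empirical Rademacher complexity in
    Eq. 17. `G_values` is an (m, n) array whose rows are (g(x_1),...,g(x_n))
    for m functions g from a finite approximation of the induced class."""
    m, n = G_values.shape
    total = 0.0
    for _ in range(num_draws):
        sigma = rng.choice([-1.0, 1.0], size=n)   # i.i.d. random signs
        total += np.max(G_values @ sigma) / n     # sup_g (1/n) sum_i sigma_i g(x_i)
    return total / num_draws

# Toy example: 5 candidate functions evaluated on n = 8 points.
G = rng.standard_normal((5, 8))
print(empirical_rademacher(G))
```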

Appendix C Deferred proofs from Section 3

C.1 Proof of Theorem 3.1

In this subsection, we present a general version of Theorem 3.1. Before that, we state the following relevant assumption on the distillation loss $\ell_{\mathrm{d}}$.

Assumption C.1.

Let $\ell:\mathbb{R}^{K}\times\mathscr{Y}\to\mathbb{R}$ be a bounded loss function. For a teacher function $h:\mathscr{X}\to\mathbb{R}^{K}$, the distillation loss $\ell_{\mathrm{d}}$ takes the form

\ell_{\mathrm{d}}(f(x),h(x))=\sum_{y\in[K]}h(x)_{y}\cdot\ell(f(x),y).

Remark C.2.

Note that the cross-entropy loss $\ell_{\mathrm{d}}(f(x),h(x))=-\sum_{y}h(x)_{y}\cdot\log\big(f(x)_{y}\big)$, one of the most common choices for the distillation loss, indeed satisfies Assumption C.1.² (²For the sake of brevity, we include the softmax operation in the definitions of $h$ and $f$, i.e., $h(x)$ and $f(x)$ are valid probability distributions over $\mathscr{Y}=[K]$.)
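A quick numerical check of Remark C.2, verifying that the cross-entropy loss matches the weighted form in Assumption C.1 with $\ell(f(x),y)=-\log f(x)_{y}$; the two probability vectors are arbitrary toy numbers.

```python
import numpy as np

# Check that cross-entropy distillation satisfies Assumption C.1:
# l_d(f(x), h(x)) = sum_y h(x)_y * l(f(x), y), with l(f(x), y) = -log f(x)_y.
f_x = np.array([0.7, 0.2, 0.1])   # student distribution over K = 3 classes
h_x = np.array([0.6, 0.3, 0.1])   # teacher distribution

per_class_loss = -np.log(f_x)                 # l(f(x), y) for each y
lhs = float(np.dot(h_x, per_class_loss))      # weighted form of Assumption C.1
rhs = float(-np.sum(h_x * np.log(f_x)))       # cross-entropy l_d
assert np.isclose(lhs, rhs)
print(lhs)  # both evaluate to approximately 0.9271
```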

The following result is a general version of Theorem 3.1 in the main body.

Theorem C.3.

Let a generator with encoder $\mathsf{Enc}$ and decoder $\mathsf{Dec}$ ensure the approximation guarantee in Eq. 6 for $\mathscr{D}_{X}$. Let $\mathsf{Dec}$ and the teacher labeler $h$ be Lipschitz functions, let $\mathcal{F}$ be a function class of Lipschitz functions, and let the distillation loss $\ell_{\mathrm{d}}$ be Lipschitz. Then, with probability at least $1-\delta$, the following holds for any $f\in\mathcal{F}$.

\displaystyle R^{h}_{\ell_{\mathrm{d}},f}(\mathscr{D}_{X})\leq R^{h}_{\ell_{\mathrm{d}},f}(\mathscr{S}_{n})+\mathcal{O}\big(n^{-1/d}\big)+L\epsilon+\mathcal{O}\Big(\sqrt{\frac{\log(1/\delta)}{n}}\Big), (19)

where $L$ denotes the effective Lipschitz constant of the induced function class $\mathcal{G}^{h}_{\ell_{\mathrm{d}},\mathcal{F}}$ in Eq. 16. Additionally, if the distillation loss $\ell_{\mathrm{d}}$ satisfies Assumption C.1 with a classification loss $\ell$, then Eq. 19 further implies the following.

\displaystyle R_{\ell,f}(\mathscr{D})\leq R^{h}_{\ell_{\mathrm{d}},f}(\mathscr{S}_{n})+\mathcal{O}\big(n^{-1/d}\big)+L\epsilon+\mathcal{O}\Big(\sqrt{\frac{\log(1/\delta)}{n}}\Big)+\mathcal{O}\big(\sqrt{K}\cdot\mathbb{E}_{\mathscr{D}_{X}}\left[\|\mathscr{D}_{Y|X}-h(X)\|_{2}\right]\big). (20)
Proof.

Note that

\displaystyle R^{h}_{\ell_{\mathrm{d}},f}(\mathscr{D}_{X})=\mathbb{E}_{\mathscr{D}_{X}}[\ell_{\mathrm{d}}(f(X),h(X))]
\displaystyle\quad\leq\mathbb{E}_{\mathscr{S}_{n}}[\ell_{\mathrm{d}}(f(X),h(X))]+\sup_{f\in\mathcal{F}}\Big|\mathbb{E}_{\mathscr{S}_{n}}[\ell_{\mathrm{d}}(f(X),h(X))]-\mathbb{E}_{\mathscr{D}_{X}}[\ell_{\mathrm{d}}(f(X),h(X))]\Big|
\displaystyle\quad\overset{(i)}{\leq}\mathbb{E}_{\mathscr{S}_{n}}[\ell_{\mathrm{d}}(f(X),h(X))]+\sup_{g\in\mathcal{G}^{h}_{\ell_{\mathrm{d}},\mathcal{F}}}\Big|\mathbb{E}_{\mathscr{S}_{n}}[g(X)]-\mathbb{E}_{\mathscr{D}_{X}}[g(X)]\Big|
\displaystyle\quad\overset{(ii)}{\leq}\mathbb{E}_{\mathscr{S}_{n}}[\ell_{\mathrm{d}}(f(X),h(X))]+\mathfrak{R}_{\mathscr{S}_{n}}(\mathcal{G}^{h}_{\ell_{\mathrm{d}},\mathcal{F}}), (21)

where $(i)$ follows from the definition of $\mathcal{G}^{h}_{\ell_{\mathrm{d}},\mathcal{F}}$ in Eq. 16 and $(ii)$ follows from the standard symmetrization argument [Devroye et al., 2013, Mohri et al., 2018]. Next, we turn our focus to the empirical Rademacher complexity $\mathfrak{R}_{\mathscr{S}_{n}}(\mathcal{G}^{h}_{\ell_{\mathrm{d}},\mathcal{F}})$. Recall that $\mathscr{S}_{n}=\{x_{1},x_{2},\ldots,x_{n}\}$ contains $n$ i.i.d. samples generated from the distribution $\mathscr{D}_{X}$. We define another set of $n$ points

\displaystyle\tilde{\mathscr{S}}_{n}=\{\tilde{x}_{1}=\mathsf{Dec}\circ\mathsf{Enc}(x_{1}),\ldots,\tilde{x}_{n}=\mathsf{Dec}\circ\mathsf{Enc}(x_{n})\}.

It follows from our assumption on the quality of the generator (cf. Eq. 6) that

\displaystyle\|\mathsf{Dec}\circ\mathsf{Enc}(x_{i})-x_{i}\|\leq\epsilon,\quad\forall i\in[n]. (22)

Note that

\displaystyle\mathfrak{R}_{\mathscr{S}_{n}}(\mathcal{G}^{h}_{\ell_{\mathrm{d}},\mathcal{F}})=\frac{1}{n}\mathbb{E}_{\bm{\sigma}}\Big[\sup_{g\in\mathcal{G}^{h}_{\ell_{\mathrm{d}},\mathcal{F}}}\sum_{i}\sigma_{i}g(x_{i})\Big],

where $\bm{\sigma}$ denotes a vector of $n$ i.i.d. Rademacher random variables. Then,

\displaystyle\mathfrak{R}_{\mathscr{S}_{n}}(\mathcal{G}^{h}_{\ell_{\mathrm{d}},\mathcal{F}})=\frac{1}{n}\mathbb{E}_{\bm{\sigma}}\Big[\sup_{g\in\mathcal{G}^{h}_{\ell_{\mathrm{d}},\mathcal{F}}}\sum_{i}\sigma_{i}\big(g(\tilde{x}_{i})+g(x_{i})-g(\tilde{x}_{i})\big)\Big]
\displaystyle\quad\leq\frac{1}{n}\mathbb{E}_{\bm{\sigma}}\Big[\sup_{g\in\mathcal{G}^{h}_{\ell_{\mathrm{d}},\mathcal{F}}}\sum_{i}\sigma_{i}g(\tilde{x}_{i})\Big]+\frac{1}{n}\mathbb{E}_{\bm{\sigma}}\Big[\sup_{g\in\mathcal{G}^{h}_{\ell_{\mathrm{d}},\mathcal{F}}}\sum_{i}\sigma_{i}\big(g(x_{i})-g(\tilde{x}_{i})\big)\Big]
\displaystyle\quad\leq\frac{1}{n}\mathbb{E}_{\bm{\sigma}}\Big[\sup_{g\in\mathcal{G}^{h}_{\ell_{\mathrm{d}},\mathcal{F}}}\sum_{i}\sigma_{i}g(\tilde{x}_{i})\Big]+\sup_{g\in\mathcal{G}^{h}_{\ell_{\mathrm{d}},\mathcal{F}}}\frac{1}{n}\sum_{i}\big|g(x_{i})-g(\tilde{x}_{i})\big|
\displaystyle\quad\leq\frac{1}{n}\mathbb{E}_{\bm{\sigma}}\Big[\sup_{g\in\mathcal{G}^{h}_{\ell_{\mathrm{d}},\mathcal{F}}}\sum_{i}\sigma_{i}g(\tilde{x}_{i})\Big]+\frac{1}{n}\sum_{i}L\cdot\|x_{i}-\tilde{x}_{i}\|
\displaystyle\quad\leq\frac{1}{n}\mathbb{E}_{\bm{\sigma}}\Big[\sup_{g\in\mathcal{G}^{h}_{\ell_{\mathrm{d}},\mathcal{F}}}\sum_{i}\sigma_{i}g(\tilde{x}_{i})\Big]+L\epsilon
\displaystyle\quad=\frac{1}{n}\mathbb{E}_{\bm{\sigma}}\Big[\sup_{g\in\mathcal{G}^{h}_{\ell_{\mathrm{d}},\mathcal{F}}}\sum_{i}\sigma_{i}g(\mathsf{Dec}(z_{i}))\Big]+L\epsilon, (23)

where $z_{i}=\mathsf{Enc}(x_{i})$ for $i\in[n]$. By the definition of $\mathcal{G}^{h}_{\ell_{\mathrm{d}},\mathcal{F}}$, each $g(\mathsf{Dec}(z))$ equals $\ell_{\mathrm{d}}(f(\mathsf{Dec}(z)),h(\mathsf{Dec}(z)))$ for some $f\in\mathcal{F}$. Now, we can define a new function class from $\mathbb{R}^{d}$ to $\mathbb{R}$:

\displaystyle\mathcal{G}^{h,\mathsf{Dec}}_{\ell_{\mathrm{d}},\mathcal{F}}=\{z\mapsto\ell_{\mathrm{d}}(f\circ\mathsf{Dec}(z),h\circ\mathsf{Dec}(z)):f\in\mathcal{F}\}. (24)

Therefore, it follows from Eq. 23 and Eq. 24 that

\displaystyle\mathfrak{R}_{\mathscr{S}_{n}}(\mathcal{G}^{h}_{\ell_{\mathrm{d}},\mathcal{F}})\leq\mathfrak{R}_{\mathscr{E}_{n}}(\mathcal{G}^{h,\mathsf{Dec}}_{\ell_{\mathrm{d}},\mathcal{F}})+L\epsilon, (25)

where $\mathscr{E}_{n}=\{\mathsf{Enc}(x_{1}),\ldots,\mathsf{Enc}(x_{n})\}\subset\mathbb{R}^{d}$. It follows from the standard concentration of the empirical Rademacher complexity around the Rademacher complexity that, with probability at least $1-\delta$,

\displaystyle\mathfrak{R}_{\mathscr{E}_{n}}(\mathcal{G}^{h,\mathsf{Dec}}_{\ell_{\mathrm{d}},\mathcal{F}})\leq\mathfrak{R}_{n}(\mathcal{G}^{h,\mathsf{Dec}}_{\ell_{\mathrm{d}},\mathcal{F}})+\mathcal{O}\Big(\sqrt{\log\Big(\frac{1}{\delta}\Big)\cdot\frac{1}{n}}\Big). (26)

Since $f\in\mathcal{F}$, $h$, and $\mathsf{Dec}$ are Lipschitz functions, $\mathcal{G}^{h,\mathsf{Dec}}_{\ell_{\mathrm{d}},\mathcal{F}}$ is a collection of Lipschitz functions from $\mathbb{R}^{d}$ to $\mathbb{R}$. Thus, it follows from standard results [Gottlieb et al., 2016, Theorem 4.3] that

\displaystyle\mathfrak{R}_{n}(\mathcal{G}^{h,\mathsf{Dec}}_{\ell_{\mathrm{d}},\mathcal{F}})\leq\mathcal{O}\big(n^{-\frac{1}{d}}\big). (27)

Now, Eq. 19 follows by combining Eq. 21, Eq. 25, Eq. 26, and Eq. 27. Finally, Eq. 20 follows by combining Lemma D.4 with Eq. 19. ∎

C.2 Proof of Theorem 3.2

Here, we present the following result, which along with Remark C.5 implies Theorem 3.2 stated in the main body.

Theorem C.4.

Let $\tilde{\mathscr{S}}_{n}=\{\tilde{x}_{i}\}_{i\in[n]}$ be $n$ i.i.d. samples generated from a distribution $\tilde{\mathscr{D}}_{X}$. Further, let $\tilde{f}_{n}\in\mathcal{F}$ denote the student model learned via distillation on $\tilde{\mathscr{S}}_{n}$, with $h$ and $\ell_{\mathrm{d}}$ as the teacher model and distillation loss, respectively. Then, with probability at least $1-\delta$, we have

\displaystyle R^{h}_{\ell_{\mathrm{d}},\tilde{f}_{n}}(\mathscr{D}_{X})\leq R^{h}_{\ell_{\mathrm{d}},\tilde{f}_{n}}(\tilde{\mathscr{S}}_{n})+\mathcal{W}(\mathscr{D}_{X},\tilde{\mathscr{D}}_{X})+\tilde{\mathfrak{R}}_{n}(\mathcal{G}^{h}_{\ell_{\mathrm{d}},\mathcal{F}})+\mathcal{O}\Big(\sqrt{\log\Big(\frac{1}{\delta}\Big)\cdot\frac{1}{n}}\Big), (28)

where $\tilde{\mathfrak{R}}_{n}(\mathcal{G}^{h}_{\ell_{\mathrm{d}},\mathcal{F}})=\mathbb{E}_{\tilde{\mathscr{S}}_{n}\sim\tilde{\mathscr{D}}^{n}}\big[\mathfrak{R}_{\tilde{\mathscr{S}}_{n}}(\mathcal{G}^{h}_{\ell_{\mathrm{d}},\mathcal{F}})\big]$ denotes the Rademacher complexity of the induced function class $\mathcal{G}^{h}_{\ell_{\mathrm{d}},\mathcal{F}}$ defined in Eq. 16. If $\tilde{\mathscr{S}}_{n}$ is constructed with the TGT framework based on a generator with encoder $\mathsf{Enc}$ and decoder $\mathsf{Dec}$, then Eq. 28 further specializes to

\displaystyle R^{h}_{\ell_{\mathrm{d}},\tilde{f}_{n}}(\mathscr{D}_{X})\leq R^{h}_{\ell_{\mathrm{d}},\tilde{f}_{n}}(\tilde{\mathscr{S}}_{n})+\mathcal{W}(\mathscr{D}_{X},\tilde{\mathscr{D}}_{X})+\tilde{\mathfrak{R}}_{n}(\mathcal{G}^{h,\mathsf{Dec}}_{\ell_{\mathrm{d}},\mathcal{F}})+\mathcal{O}\Big(\sqrt{\log\Big(\frac{1}{\delta}\Big)\cdot\frac{1}{n}}\Big), (29)

where $\mathcal{G}^{h,\mathsf{Dec}}_{\ell_{\mathrm{d}},\mathcal{F}}$ is the following induced function class from $\mathbb{R}^{d}$ (i.e., the latent space of the generator) to $\mathbb{R}$.

\displaystyle\mathcal{G}^{h,\mathsf{Dec}}_{\ell_{\mathrm{d}},\mathcal{F}}=\{z\mapsto\ell_{\mathrm{d}}(f\circ\mathsf{Dec}(z),h\circ\mathsf{Dec}(z)):f\in\mathcal{F}\}. (30)
Proof.

Note that

\displaystyle R^{h}_{\ell_{\mathrm{d}},\tilde{f}_{n}}(\tilde{\mathscr{D}}_{X})=\mathbb{E}_{\tilde{\mathscr{D}}_{X}}[\ell_{\mathrm{d}}(\tilde{f}_{n}(X),h(X))]
\displaystyle\quad\leq\mathbb{E}_{\tilde{\mathscr{S}}_{n}}[\ell_{\mathrm{d}}(\tilde{f}_{n}(X),h(X))]+\sup_{f\in\mathcal{F}}\big|\mathbb{E}_{\tilde{\mathscr{S}}_{n}}[\ell_{\mathrm{d}}(f(X),h(X))]-\mathbb{E}_{\tilde{\mathscr{D}}_{X}}[\ell_{\mathrm{d}}(f(X),h(X))]\big|
\displaystyle\quad\leq\mathbb{E}_{\tilde{\mathscr{S}}_{n}}[\ell_{\mathrm{d}}(\tilde{f}_{n}(X),h(X))]+\sup_{g\in\mathcal{G}^{h}_{\ell_{\mathrm{d}},\mathcal{F}}}\Big|\mathbb{E}_{\tilde{\mathscr{S}}_{n}}[g(X)]-\mathbb{E}_{\tilde{\mathscr{D}}_{X}}[g(X)]\Big|
\displaystyle\quad\leq\mathbb{E}_{\tilde{\mathscr{S}}_{n}}[\ell_{\mathrm{d}}(\tilde{f}_{n}(X),h(X))]+\mathfrak{R}_{\tilde{\mathscr{S}}_{n}}(\mathcal{G}^{h}_{\ell_{\mathrm{d}},\mathcal{F}}), (31)

where the last two inequalities follow from the definition of $\mathcal{G}^{h}_{\ell_{\mathrm{d}},\mathcal{F}}$ (cf. Eq. 16) and the standard symmetrization argument [Devroye et al., 2013, Mohri et al., 2018], respectively.

Now, the standard concentration result for the empirical Rademacher complexity implies that, with probability at least $1-\delta$, we have the following.

\displaystyle\mathfrak{R}_{\tilde{\mathscr{S}}_{n}}(\mathcal{G}^{h}_{\ell_{\mathrm{d}},\mathcal{F}})\leq\mathbb{E}_{\tilde{\mathscr{S}}_{n}\sim\tilde{\mathscr{D}}^{n}}\big[\mathfrak{R}_{\tilde{\mathscr{S}}_{n}}(\mathcal{G}^{h}_{\ell_{\mathrm{d}},\mathcal{F}})\big]+\mathcal{O}\Big(\sqrt{\log\Big(\frac{1}{\delta}\Big)\cdot\frac{1}{n}}\Big) (32)
\displaystyle\quad=\tilde{\mathfrak{R}}_{n}(\mathcal{G}^{h}_{\ell_{\mathrm{d}},\mathcal{F}})+\mathcal{O}\Big(\sqrt{\log\Big(\frac{1}{\delta}\Big)\cdot\frac{1}{n}}\Big). (33)

It follows from Lemma D.3 that

\displaystyle R^{h}_{\ell_{\mathrm{d}},\tilde{f}_{n}}(\mathscr{D}_{X})\leq R^{h}_{\ell_{\mathrm{d}},\tilde{f}_{n}}(\tilde{\mathscr{D}}_{X})+\mathcal{W}(\mathscr{D}_{X},\tilde{\mathscr{D}}_{X}). (34)

Now, the first part of Theorem C.4, as stated in Eq. 28, follows by combining Eq. 31, Eq. 32, and Eq. 34.

We now focus on establishing Eq. 29. Note that, for a sample $\tilde{\mathscr{S}}_{n}=\{\tilde{x}_{1},\ldots,\tilde{x}_{n}\}$ generated by the TGT framework, there exist latent codes $\{z_{1},\ldots,z_{n}\}\subset\mathbb{R}^{d}$ such that

\displaystyle\tilde{x}_{i}=\mathsf{Dec}(z_{i}),\quad\forall i\in[n]. (35)

Thus,

\displaystyle\mathfrak{R}_{\tilde{\mathscr{S}}_{n}}(\mathcal{G}^{h}_{\ell_{\mathrm{d}},\mathcal{F}})=\frac{1}{n}\mathbb{E}_{\bm{\sigma}}\Big[\sup_{g\in\mathcal{G}^{h}_{\ell_{\mathrm{d}},\mathcal{F}}}\sum_{i}\sigma_{i}g(\tilde{x}_{i})\Big]
\displaystyle\quad\overset{(i)}{=}\frac{1}{n}\mathbb{E}_{\bm{\sigma}}\Big[\sup_{g\in\mathcal{G}^{h}_{\ell_{\mathrm{d}},\mathcal{F}}}\sum_{i}\sigma_{i}g(\mathsf{Dec}(z_{i}))\Big]
\displaystyle\quad\leq\frac{1}{n}\mathbb{E}_{\bm{\sigma}}\Big[\sup_{g'\in\mathcal{G}^{h,\mathsf{Dec}}_{\ell_{\mathrm{d}},\mathcal{F}}}\sum_{i}\sigma_{i}g'(z_{i})\Big]
\displaystyle\quad=\mathfrak{R}_{\tilde{\mathscr{S}}_{n}}(\mathcal{G}^{h,\mathsf{Dec}}_{\ell_{\mathrm{d}},\mathcal{F}}), (36)

where $(i)$ employs Eq. 35. Thus, combining Eq. 31 and Eq. 36 gives us

\displaystyle R^{h}_{\ell_{\mathrm{d}},\tilde{f}_{n}}(\tilde{\mathscr{D}}_{X})\leq\mathbb{E}_{\tilde{\mathscr{S}}_{n}}[\ell_{\mathrm{d}}(\tilde{f}_{n}(X),h(X))]+\mathfrak{R}_{\tilde{\mathscr{S}}_{n}}(\mathcal{G}^{h,\mathsf{Dec}}_{\ell_{\mathrm{d}},\mathcal{F}}). (37)

Now, similar to the proof of Eq. 28, we can invoke Lemma D.3 and the concentration result for empirical Rademacher complexity to obtain the desired result in Eq. 29 from Eq. 37. ∎

Remark C.5.

Note that, if the distillation loss $\ell_{\mathrm{d}}$ satisfies Assumption C.1 with a loss function $\ell$, then one can combine Theorem C.4 and Lemma D.4 to readily obtain bounds on $R_{\ell,\tilde{f}_{n}}(\mathscr{D})$ with an additional term

\mathcal{O}\big(\sqrt{K}\cdot\mathbb{E}_{\mathscr{D}_{X}}\left[\|\mathscr{D}_{Y|X}-h(X)\|_{2}\right]\big).

This term captures the quality of the teacher labeler $h$.

C.3 Weighted ERM: An alternative training procedure for TGT

Note that, given samples $\tilde{\mathscr{S}}_{n}=\{\tilde{x}_{i}\}_{i\in[n]}$ generated from $\tilde{\mathscr{D}}_{X}$ and a teacher labeler $h$, we minimize the following empirical risk for student training:

\displaystyle R^{h}_{\ell_{\mathrm{d}},f}(\tilde{\mathscr{S}}_{n})=\frac{1}{n}\sum_{i\in[n]}\ell_{\mathrm{d}}\big(f(\tilde{x}_{i}),h(\tilde{x}_{i})\big). (38)

However, as Theorem C.4 shows, this leads to an additional $\mathcal{W}(\mathscr{D}_{X},\tilde{\mathscr{D}}_{X})$ penalty term in the generalization bound. One standard approach to address this issue is to consider the following weighted empirical risk.

\displaystyle R^{h,{\rm IS}}_{\ell_{\mathrm{d}},f}(\tilde{\mathscr{S}}_{n})=\frac{1}{n}\sum_{i\in[n]}\ell_{\mathrm{d}}\big(f(\tilde{x}_{i}),h(\tilde{x}_{i})\big)\cdot\frac{p_{\mathscr{D}_{X}}(\tilde{x}_{i})}{p_{\tilde{\mathscr{D}}_{X}}(\tilde{x}_{i})}, (39)

where $p_{\mathscr{D}_{X}}$ and $p_{\tilde{\mathscr{D}}_{X}}$ denote the probability density functions (pdfs) of $\mathscr{D}_{X}$ and $\tilde{\mathscr{D}}_{X}$, respectively.³ (³Note that this formulation assumes $\mathscr{D}_{X}\ll\tilde{\mathscr{D}}_{X}$, i.e., $\mathscr{D}_{X}$ is absolutely continuous w.r.t. $\tilde{\mathscr{D}}_{X}$. Also, one can replace the pdfs with probability mass functions if $\mathscr{D}_{X}$ and $\tilde{\mathscr{D}}_{X}$ are discrete distributions.) Accordingly, we define a new induced function class related to the weighted empirical risk:

\displaystyle\mathcal{G}^{h,{\rm IS}}_{\ell_{\mathrm{d}},\mathcal{F}}=\Big\{\tilde{x}\mapsto\ell_{\mathrm{d}}\big(f(\tilde{x}),h(\tilde{x})\big)\cdot\frac{p_{\mathscr{D}_{X}}(\tilde{x})}{p_{\tilde{\mathscr{D}}_{X}}(\tilde{x})}:f\in\mathcal{F}\Big\}. (40)

Importantly, we have

\displaystyle R^{h,{\rm IS}}_{\ell_{\mathrm{d}},f}(\tilde{\mathscr{D}}_{X})=\mathbb{E}_{\tilde{\mathscr{D}}_{X}}\big[R^{h,{\rm IS}}_{\ell_{\mathrm{d}},f}(\tilde{\mathscr{S}}_{n})\big]=R^{h}_{\ell_{\mathrm{d}},f}(\mathscr{D}_{X}). (41)

Thus, following the analysis utilized in Theorem C.4, one can obtain a high-probability generalization bound of the form

\displaystyle R^{h}_{\ell_{\mathrm{d}},f}(\mathscr{D}_{X})\leq R^{h,{\rm IS}}_{\ell_{\mathrm{d}},f}(\tilde{\mathscr{S}}_{n})+\tilde{\mathfrak{R}}_{n}(\mathcal{G}^{h,{\rm IS}}_{\ell_{\mathrm{d}},\mathcal{F}})+\mathcal{O}\Big(\sqrt{\log\Big(\frac{1}{\delta}\Big)\cdot\frac{1}{n}}\Big), (42)

which avoids the $\mathcal{W}(\mathscr{D}_{X},\tilde{\mathscr{D}}_{X})$ term.
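A minimal sketch of the importance-weighted risk in Eq. 39 follows. The density values are assumed known here purely for illustration; as noted in Remark C.7 below, they are generally unavailable in practice, which is why TGT does not optimize this objective directly.

```python
import numpy as np

def weighted_empirical_risk(losses, p_data, p_gen):
    """Importance-weighted empirical risk from Eq. 39: reweight each
    distillation loss l_d(f(x~_i), h(x~_i)) by the density ratio
    p_D(x~_i) / p_D~(x~_i), so the estimate is unbiased for the risk
    under the true data distribution D_X (cf. Eq. 41)."""
    weights = p_data / p_gen
    return float(np.mean(losses * weights))

# Toy example: 4 generated samples with their losses and (assumed known)
# density values under the data and generator distributions.
losses = np.array([0.9, 0.4, 1.2, 0.3])
p_data = np.array([0.10, 0.30, 0.05, 0.20])   # pdf of D_X at x~_i
p_gen  = np.array([0.20, 0.25, 0.10, 0.15])   # pdf of D~_X at x~_i
print(weighted_empirical_risk(losses, p_data, p_gen))
```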

In what follows, we explore an alternative approach to highlight the importance of the sampling strategy adopted by (gradient-based) TGT. By leveraging the variance-based generalization bounds [Maurer and Pontil, 2009] that were previously utilized by Menon et al. [2021a] in the context of distillation, we obtain the following result for the weighted empirical risk in Eq. 39.

Proposition C.6.

Let $h$, $\ell_{\mathrm{d}}$, $\mathcal{F}$, and $\tilde{\mathscr{S}}_{n}$ be as defined in the statement of Theorem C.4. Further, assume that $\ell^{h,{\rm IS}}_{\mathrm{d},f}(\tilde{x}):=\ell_{\mathrm{d}}\big(f(\tilde{x}),h(\tilde{x})\big)\cdot\frac{p_{\mathscr{D}_{X}}(\tilde{x})}{p_{\tilde{\mathscr{D}}_{X}}(\tilde{x})}$ is bounded for all $\tilde{x}\in{\rm supp}(\tilde{\mathscr{D}}_{X})$. Then, for any $f\in\mathcal{F}$, the following holds with probability at least $1-\delta$.

\displaystyle R^{h}_{\ell_{\mathrm{d}},f}(\mathscr{D}_{X})\leq R^{h,{\rm IS}}_{\ell_{\mathrm{d}},f}(\tilde{\mathscr{S}}_{n})+\text{(I)}, (43)

where (I) denotes

\mathcal{O}\bigg(\sqrt{\frac{{\rm Var}_{\tilde{\mathscr{D}}_{X}}(\ell^{h,{\rm IS}}_{\mathrm{d},f}(\tilde{x}))\cdot\log(\frac{\mathcal{M}(n)}{\delta})}{n}}+\frac{\log(\frac{\mathcal{M}(n)}{\delta})}{n}\bigg).

Here, $\mathcal{M}(n)=\sup_{\mathscr{S}_{n}\subset\mathscr{X}^{n}}\mathcal{N}(1/n,\mathcal{G}^{h,{\rm IS}}_{\ell_{\mathrm{d}},\mathcal{F}}(\mathscr{S}_{n}),\|\cdot\|_{\infty})$, with $\mathcal{N}(\epsilon,\mathcal{G}^{h,{\rm IS}}_{\ell_{\mathrm{d}},\mathcal{F}}(\mathscr{S}_{n}),\|\cdot\|_{\infty})$ denoting the covering number [Devroye et al., 2013] of the set

\displaystyle\mathcal{G}^{h,{\rm IS}}_{\ell_{\mathrm{d}},\mathcal{F}}(\mathscr{S}_{n}):=\{(g(x_{1}),\ldots,g(x_{n})):g\in\mathcal{G}^{h,{\rm IS}}_{\ell_{\mathrm{d}},\mathcal{F}}\}.
Proof.

By utilizing the uniform-convergence version of Bennett's inequality and a uniform bound for $\sqrt{{\rm Var}_{\tilde{\mathscr{S}}_{n}}(\ell^{h,{\rm IS}}_{\mathrm{d},f}(\tilde{x}))}$, where ${\rm Var}_{\tilde{\mathscr{S}}_{n}}(\ell^{h,{\rm IS}}_{\mathrm{d},f}(\tilde{x}))$ denotes the empirical variance of $\ell^{h,{\rm IS}}_{\mathrm{d},f}(\tilde{x})$ based on $\tilde{\mathscr{S}}_{n}$, the following holds with probability at least $1-\delta$ [Maurer and Pontil, 2009].

\displaystyle R^{h,{\rm IS}}_{\ell_{\mathrm{d}},f}(\tilde{\mathscr{D}}_{X})\leq R^{h,{\rm IS}}_{\ell_{\mathrm{d}},f}(\tilde{\mathscr{S}}_{n})+\mathcal{O}\bigg(\sqrt{\frac{{\rm Var}_{\tilde{\mathscr{D}}_{X}}(\ell^{h,{\rm IS}}_{\mathrm{d},f}(\tilde{x}))\cdot\log(\frac{\mathcal{M}(n)}{\delta})}{n}}+\frac{\log(\frac{\mathcal{M}(n)}{\delta})}{n}\bigg),\quad\forall f\in\mathcal{F}. (44)

Since $R^{h,{\rm IS}}_{\ell_{\mathrm{d}},f}(\tilde{\mathscr{D}}_{X})=\mathbb{E}_{\tilde{\mathscr{D}}_{X}}\big[R^{h,{\rm IS}}_{\ell_{\mathrm{d}},f}(\tilde{\mathscr{S}}_{n})\big]=R^{h}_{\ell_{\mathrm{d}},f}(\mathscr{D}_{X})$ (cf. Eq. 41), the statement of Proposition C.6 follows from Eq. 44. ∎

Remark C.7.

Eq. 43 suggests a general approach for selecting the distribution $\tilde{\mathscr{D}}_{X}$ that generates the training samples $\tilde{\mathscr{S}}_{n}$. To ensure a small generalization gap, it is desirable that the variance term ${\rm Var}_{\tilde{\mathscr{D}}_{X}}(\ell^{h,{\rm IS}}_{\mathrm{d},f}(\tilde{x}))$ be as small as possible. Note that the distribution minimizing this variance takes the form

\displaystyle\log p_{\tilde{\mathscr{D}}_{X}}(x)\propto\log\ell_{\mathrm{d}}\big(f(x),h(x)\big)+\log p_{\mathscr{D}_{X}}(x),\quad\forall x\in\mathscr{X}. (45)

This resembles the Lagrangian form of Eq. 3. Interestingly, the TGT framework with gradient-based sampling (cf. Eq. 5) focuses on generating samples that maximize the right-hand side of Eq. 45: it first takes a sample generated according to $\mathscr{D}_{X}$ and then perturbs it in the latent space so as to maximize the loss $\ell_{\mathrm{d}}\big(f(x),h(x)\big)$. Thus, the resulting distribution $\tilde{\mathscr{D}}_{X}$ has a pdf that aims to approximate the variance-minimizing pdf in Eq. 45.

Here it is worth pointing out that, since the exact forms of $p_{\tilde{\mathscr{D}}_{X}}(\cdot)$ and $p_{\mathscr{D}_{X}}(\cdot)$ are generally not available during training, it is not straightforward to optimize the weighted risk introduced in Eq. 39.

Remark C.8.

Note that, as introduced in Section 3, the TGT framework optimizes the empirical risk in Eq. 38, as opposed to minimizing Eq. 39. In this case, one can obtain a variance-based bound analogous to Eq. 43 that takes the form

\displaystyle R^{h}_{\ell_{\mathrm{d}},f}(\mathscr{D}_{X})\leq R^{h}_{\ell_{\mathrm{d}},f}(\tilde{\mathscr{S}}_{n})+\text{(II)}+\mathcal{W}(\mathscr{D}_{X},\tilde{\mathscr{D}}_{X}), (46)

where (II) denotes

\mathcal{O}\bigg(\sqrt{\frac{{\rm Var}_{\tilde{\mathscr{D}}_{X}}(\ell^{h}_{\mathrm{d},f}(\tilde{x}))\cdot\log(\frac{\mathcal{M}(n)}{\delta})}{n}}+\frac{\log(\frac{\mathcal{M}(n)}{\delta})}{n}\bigg),

with $\ell^{h}_{\mathrm{d},f}(\tilde{x}):=\ell_{\mathrm{d}}\big(f(\tilde{x}),h(\tilde{x})\big)$ and $\mathcal{M}(n)$ depending on the covering number of the induced function class $\mathcal{G}^{h}_{\ell_{\mathrm{d}},\mathcal{F}}$ (cf. Eq. 16). Notably, this bound again incurs a penalty of $\mathcal{W}(\mathscr{D}_{X},\tilde{\mathscr{D}}_{X})$, which is expected to be small for our TGT-based sampling distribution when we employ a high-quality teacher generator.

Appendix D Toolbox

This section presents the necessary definitions and lemmas that we utilize to establish the theoretical results presented in Sec. 3 (and restated in App. C).

Definition D.1 (Wasserstein-1 metric).

Let $(\mathscr{X},\rho)$ be a metric space. Given two probability distributions $\mathscr{D}^{1}_{X}$ and $\mathscr{D}^{2}_{X}$ over $\mathscr{X}$, the Wasserstein-1 distance between $\mathscr{D}^{1}_{X}$ and $\mathscr{D}^{2}_{X}$ is defined as follows.

\displaystyle\mathcal{W}(\mathscr{D}^{1}_{X},\mathscr{D}^{2}_{X}):=\inf_{\pi\in\Pi(\mathscr{D}^{1}_{X},\mathscr{D}^{2}_{X})}\mathbb{E}_{(X,X')\sim\pi}\left[\rho(X,X')\right]=\inf_{\pi\in\Pi(\mathscr{D}^{1}_{X},\mathscr{D}^{2}_{X})}\int_{\mathscr{X}\times\mathscr{X}}\rho(x,x')\,d\pi(x,x'), (47)

where $\Pi(\mathscr{D}^{1}_{X},\mathscr{D}^{2}_{X})$ denotes the set of all joint distributions over $\mathscr{X}\times\mathscr{X}$ that have $\mathscr{D}^{1}_{X}$ and $\mathscr{D}^{2}_{X}$ as their marginals.

Lemma D.2 (Kantorovich-Rubinstein duality [Villani, 2008]).

Let ${\rm Lip}_{1}(\rho)$ denote the set of all 1-Lipschitz functions with respect to the metric $\rho$, i.e., for any $f\in{\rm Lip}_{1}(\rho)$,

\displaystyle|f(x)-f(x')|\leq\rho(x,x'),\quad\forall x,x'. (48)

Then,

\displaystyle\mathcal{W}(\mathscr{D}^{1}_{X},\mathscr{D}^{2}_{X})=\sup_{f\in{\rm Lip}_{1}(\rho)}\left(\mathbb{E}_{X\sim\mathscr{D}^{1}_{X}}\left[f(X)\right]-\mathbb{E}_{X'\sim\mathscr{D}^{2}_{X}}\left[f(X')\right]\right). (49)
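For intuition, the Wasserstein-1 distance of Definition D.1 can be computed in closed form for one-dimensional empirical distributions, e.g., via SciPy's `wasserstein_distance`. This is a toy illustration; our bounds use the metric abstractly.

```python
import numpy as np
from scipy.stats import wasserstein_distance

# Wasserstein-1 distance between two 1-D empirical distributions
# (Definition D.1 with rho(x, x') = |x - x'|); for point masses on the
# real line it equals the area between the two empirical CDFs.
rng = np.random.default_rng(0)
samples_p = rng.normal(loc=0.0, scale=1.0, size=1000)   # draws from D^1_X
samples_q = rng.normal(loc=0.5, scale=1.0, size=1000)   # draws from D^2_X
# Two equal-variance Gaussians have W1 equal to the mean shift, so this
# prints a value close to 0.5 (up to sampling noise).
print(wasserstein_distance(samples_p, samples_q))
```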
Lemma D.3.

Let $\ell_{\mathrm{d}}:\mathbb{R}^{K}\times\mathbb{R}^{K}\to\mathbb{R}$ be a loss function employed during distillation. For a given teacher $h:\mathscr{X}\to\mathbb{R}^{K}$ and a function class $\mathcal{F}$, we assume that the induced function class

\displaystyle\mathcal{G}^{h}_{\ell_{\mathrm{d}},\mathcal{F}}=\{x\mapsto\ell_{\mathrm{d}}(f(x),h(x)):f\in\mathcal{F}\} (50)

is contained in the class of $L$-Lipschitz functions with respect to a metric $\rho$. Then, for any two distributions $\mathscr{D}^{1}_{X}$ and $\mathscr{D}^{2}_{X}$, we have

\displaystyle R^{h}_{\ell_{\mathrm{d}},f}(\mathscr{D}^{1}_{X})-R^{h}_{\ell_{\mathrm{d}},f}(\mathscr{D}^{2}_{X})\leq L\cdot\mathcal{W}(\mathscr{D}^{1}_{X},\mathscr{D}^{2}_{X}),\quad\forall f\in\mathcal{F}, (51)

where $\mathcal{W}(\mathscr{D}^{1}_{X},\mathscr{D}^{2}_{X})$ denotes the Wasserstein-1 metric between the two distributions $\mathscr{D}^{1}_{X}$ and $\mathscr{D}^{2}_{X}$ (cf. Definition D.1).

Proof.

Note that

\displaystyle R^{h}_{\ell_{\mathrm{d}},f}(\mathscr{D}^{1}_{X})-R^{h}_{\ell_{\mathrm{d}},f}(\mathscr{D}^{2}_{X})=\mathbb{E}_{X\sim\mathscr{D}^{1}_{X}}\big[\ell_{\mathrm{d}}(f(X),h(X))\big]-\mathbb{E}_{X'\sim\mathscr{D}^{2}_{X}}\big[\ell_{\mathrm{d}}(f(X'),h(X'))\big]
\displaystyle\quad\leq\sup_{g\in\mathcal{G}^{h}_{\ell_{\mathrm{d}},\mathcal{F}}}\Big(\mathbb{E}_{X\sim\mathscr{D}^{1}_{X}}\left[g(X)\right]-\mathbb{E}_{X'\sim\mathscr{D}^{2}_{X}}\left[g(X')\right]\Big)
\displaystyle\quad\overset{(i)}{=}L\cdot\sup_{g\in\mathcal{G}^{h}_{\ell_{\mathrm{d}},\mathcal{F}}}\Big(\mathbb{E}_{X\sim\mathscr{D}^{1}_{X}}\Big[\frac{g(X)}{L}\Big]-\mathbb{E}_{X'\sim\mathscr{D}^{2}_{X}}\Big[\frac{g(X')}{L}\Big]\Big)
\displaystyle\quad\overset{(ii)}{\leq}L\cdot\sup_{g\in{\rm Lip}_{1}(\rho)}\Big(\mathbb{E}_{X\sim\mathscr{D}^{1}_{X}}\left[g(X)\right]-\mathbb{E}_{X'\sim\mathscr{D}^{2}_{X}}\left[g(X')\right]\Big)
\displaystyle\quad\overset{(iii)}{=}L\cdot\mathcal{W}(\mathscr{D}^{1}_{X},\mathscr{D}^{2}_{X}), (52)

where $(i)$ follows by dividing and multiplying by $L$; $(ii)$ follows since, for any $g\in\mathcal{G}^{h}_{\ell_{\mathrm{d}},\mathcal{F}}$, the function $g/L$ is 1-Lipschitz; and $(iii)$ follows from Lemma D.2. ∎

Lemma D.4.

Let the distillation loss $\ell_{\mathrm{d}}$ satisfy Assumption C.1 with a bounded loss function $\ell:\mathbb{R}^{K}\times\mathscr{Y}\to\mathbb{R}$. Then, given a teacher $h:\mathscr{X}\to\mathbb{R}^{K}$ and a student model $f:\mathscr{X}\to\mathbb{R}^{K}$, we have

\displaystyle\Big|R^{h}_{\ell_{\mathrm{d}},f}(\mathscr{D}_{X})-R_{\ell,f}(\mathscr{D})\Big|\leq\mathcal{O}\big(\sqrt{K}\cdot\mathbb{E}_{\mathscr{D}_{X}}\left[\|\mathscr{D}_{Y|X}-h(X)\|_{2}\right]\big), (53)

where $\mathscr{D}_{Y|X}=(\mathscr{D}_{Y|X}(1),\ldots,\mathscr{D}_{Y|X}(K))$ is treated as a vector in $\mathbb{R}^{K}$.

Proof.

Note that

\displaystyle\Big|R^{h}_{\ell_{\mathrm{d}},f}(\mathscr{D}_{X})-R_{\ell,f}(\mathscr{D})\Big|=\Big|\mathbb{E}_{\mathscr{D}_{X}}[\ell_{\mathrm{d}}(f(X),h(X))]-\mathbb{E}_{\mathscr{D}}[\ell(f(X),Y)]\Big|
\displaystyle\quad=\Big|\mathbb{E}_{\mathscr{D}_{X}}\Big[\sum_{y\in[K]}h(X)_{y}\cdot\ell(f(X),y)\Big]-\mathbb{E}_{\mathscr{D}_{X}}\Big[\sum_{y\in[K]}\mathscr{D}_{Y|X}(y)\cdot\ell(f(X),y)\Big]\Big|
\displaystyle\quad=\Big|\mathbb{E}_{\mathscr{D}_{X}}\Big[\sum_{y\in[K]}\big(h(X)_{y}-\mathscr{D}_{Y|X}(y)\big)\cdot\ell(f(X),y)\Big]\Big|
\displaystyle\quad\overset{(i)}{\leq}\mathbb{E}_{\mathscr{D}_{X}}\big[\|\mathscr{D}_{Y|X}-h(X)\|_{2}\cdot\|\ell(f(X))\|_{2}\big], (54)

where $(i)$ follows from the Cauchy–Schwarz inequality, with $\ell(f(X))$ denoting the vector $(\ell(f(X),1),\ldots,\ell(f(X),K))\in\mathbb{R}^{K}$. Now the statement of Lemma D.4 follows since the loss $\ell$ is assumed bounded, so that $\|\ell(f(X))\|_{2}=\mathcal{O}(\sqrt{K})$. ∎

Appendix E Additional experiments

E.1 Long-tail image classification

Please see Table 4 for the Places365-LT results; the corresponding discussion is in Sec. 4.1.

Approach Architecture Balanced Accuracy # parameters FLOPs
Places365-LT:
LWS [Kang et al., 2020] ResNet-152 37.6 60 M 11 B
LDAM-DRS-RSG [Wang et al., 2021] ResNet-152 39.3 60 M 11 B
OLTR [Liu et al., 2019b] ResNet-152 35.9 60 M 11 B
DRAGON + Bal’Loss [Samuel et al., 2021] ResNet-50 38.1 26 M 4.1 B
Teacher (labeler) model EfficientNet-b3 42.1 12 M 1.8 B
One-hot MobileNetV3-0.75 26.8 4.01 M 156 M
Distillation MobileNetV3-0.75 33.0 4.01 M 156 M
TGT (random) MobileNetV3-0.75 34.7 4.01 M 156 M
TGT (gradient-based) MobileNetV3-0.75 35.0 4.01 M 156 M
Table 4: Performance of TGT on Places-LT [Liu et al., 2019c]. The table shows the top-1 accuracy on the corresponding balanced eval sets for TGT and different long-tail baselines from the literature (taken from [Samuel et al., 2021]). We also state the number of model parameters and the inference cost (in FLOPs) for all methods. Note that TGT leads to performance improvements over standard distillation. Also note that, for Places-LT, TGT does not outperform the stated baselines from the literature, which rely on specialized loss functions and/or training procedures designed for the long-tail setting. One reason for this could be that BigBiGAN does not generate very informative samples for Places-LT due to a distribution mismatch. That said, as discussed in Sec. 4.1, one can combine the TGT framework with a long-tail-specific loss function, as opposed to employing the standard cross-entropy loss, as a way to further improve its performance.

Appendix F Details to reproduce our empirical results

Here we provide the details needed to reproduce our experimental results.

F.1 Long-tail image classification (Sec. 4.1)

Dataset. The full balanced versions of the 3 datasets are available in tensorflow-datasets (https://www.tensorflow.org/datasets/): ImageNet (https://www.tensorflow.org/datasets/catalog/imagenet2012), Places365 (https://www.tensorflow.org/datasets/catalog/places365_small), and SUN397 (https://www.tensorflow.org/datasets/catalog/sun397). To obtain the long-tail versions of the datasets, we downloaded the image ids (https://drive.google.com/drive/u/1/folders/1j7Nkfe6ZhzKFXePHdsseeeGI877Xu1yf) from the repository of "Large-Scale Long-Tailed Recognition in an Open World" [Liu et al., 2019b], according to which we subsampled the full balanced datasets.

Teacher fine-tuning. For teacher labeler, we follow ”Sharpness Aware Minimization’ [Foret et al., 2020] codebase (available at https://github.com/google-research/sam) to fine-tune on the long-tail datasets. We start with pretrained EfficientNet-B3 model checkpoint available from official repository888https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/noisystudent/efficientnet-b3/checkpoint.tar.gz and used default parameters from the codebase. We fine-tuned all 3 datasets (ImageNet-LT, SUN397-LT, Place365-LT) for 3 epochs.

For the teacher generator, we directly used the BigBiGAN ResNet-50 checkpoint from the official repository (https://github.com/deepmind/deepmind-research/tree/master/bigbigan). (We did not fine-tune it.)

Student training. We start from a randomly initialized MobileNetV3-0.75 model. We employ the SGD optimizer with a cosine schedule (peak learning rate of 0.4, decaying to 0), with a linear warm-up (from 0 to the peak learning rate of 0.4) over the first 5 epochs. The input image sizes unfortunately differ between the EfficientNet-B3, BigBiGAN-ResNet50, and MobileNetV3-0.75 models; from the original images in the dataset, we use TensorFlow's bicubic resizing to obtain an appropriately sized image for each model. We did a grid search over the perturbation parameters σ and η (cf. Eq. 4 and Eq. 5); a sketch of the corresponding exploration step follows Table 5. All hyper-parameters and grids are listed in the table below:

| Hyper-param | Value |
|---|---|
| Num epochs | 90 (ImageNet-LT), 30 (Places365-LT), 30 (SUN397-LT) |
| Optimizer | SGD |
| Schedule | Cosine |
| Warm-up epochs | 5 |
| Peak learning rate | 0.4 |
| Batch size | 256 |
| Teacher labeler image size | 300×300×3 |
| Teacher generator image size | 256×256×3 |
| Student image size | 224×224×3 |
| Perturbation noise (σ) | {0, 0.001, 0.01, 0.1} |
| Gradient exploration: step size (η) | {0, 0.001, 0.01, 0.1} |
| Gradient exploration: num steps | 2 |

Table 5: Hyper-parameters for long-tail image classification. All values other than the number of epochs are shared across the three datasets.
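For illustration, the sketch below shows one way the two exploration modes could be implemented. This is our plausible reading of Eqs. 4 and 5, not the actual training code; in particular, the disagreement objective and the `encoder`, `generator`, `teacher`, and `student` callables (standing in for the BigBiGAN encoder/decoder and the two classifiers) are placeholders:

```python
import torch

def tgt_explore(x, encoder, generator, teacher, student,
                sigma=0.01, eta=0.01, num_steps=2, gradient_based=True):
    """Generate a synthetic neighbour of x in the generator's latent space.

    Sketch only: the exact objective in Eq. 5 may differ; here we ascend
    the student-teacher disagreement (KL) as one plausible instantiation.
    """
    z = encoder(x)                       # map the real image to latent space
    z = z + sigma * torch.randn_like(z)  # random perturbation (Eq. 4)
    if gradient_based:                   # gradient-based exploration (Eq. 5)
        for _ in range(num_steps):
            z = z.detach().requires_grad_(True)
            x_syn = generator(z)
            disagreement = torch.nn.functional.kl_div(
                student(x_syn).log_softmax(-1),
                teacher(x_syn).softmax(-1),
                reduction="batchmean")
            disagreement.backward()
            z = z + eta * z.grad         # move latent towards disagreement
    x_syn = generator(z.detach())
    y_syn = teacher(x_syn).softmax(-1)   # teacher labeler provides soft label
    return x_syn, y_syn
```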

F.2 TGT in low-data regime (Sec. 4.2)

Dataset. We used the ImageNet dataset (https://www.tensorflow.org/datasets/catalog/imagenet2012) from the tensorflow-datasets repository (https://www.tensorflow.org/datasets/). We used the built-in sub-sampling functionality in tensorflow-datasets (https://www.tensorflow.org/datasets/splits) to simulate the low-data regime; an example follows.
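As an illustration, sub-sampling with the tfds split-slicing syntax looks like this; the 10% figure is just an example, not the exact fraction used in Sec. 4.2:

```python
import tensorflow_datasets as tfds

# Slicing syntax documented at https://www.tensorflow.org/datasets/splits:
# take the first 10% of the ImageNet training split.
ds_low_data = tfds.load("imagenet2012", split="train[:10%]")
```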

Teacher model. For the teacher labeler, we directly used the trained EfficientNet-B3 checkpoint available from the "Sharpness-Aware Minimization" repository (https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/noisystudent/efficientnet-b3/checkpoint.tar.gz). For the teacher generator, we directly used the trained BigBiGAN checkpoint from the official repository (https://github.com/deepmind/deepmind-research/tree/master/bigbigan). (We did not fine-tune either model.)

Student training. We start from a randomly initialized MobileNetV3-0.75 model. We employ the SGD optimizer with a cosine schedule (peak learning rate of 0.4, decaying to 0), with a linear warm-up (from 0 to the peak learning rate of 0.4) over the first 5 epochs; a sketch of this schedule is given below. The input image sizes unfortunately differ between the EfficientNet-B3, BigBiGAN-ResNet50, and MobileNetV3-0.75 models; from the original images in the dataset, we use TensorFlow's bicubic resizing to obtain an appropriately sized image for each model. Following standard practice in the literature [He et al., 2016; Jia et al., 2018], we train the one-hot and standard distillation student models for 90 epochs (= 450k steps at batch size 256, since ImageNet's 1.28M training images give roughly 5k steps per epoch). We use 4× fewer steps for TGT than for the simple distillation baseline, which amounts to 450k/4 ≈ 112k steps.
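The warm-up-plus-cosine schedule used for all image students is a few lines of code; a minimal sketch (our illustration, framework-agnostic):

```python
import math

def learning_rate(step, total_steps, warmup_steps, peak_lr=0.4):
    """Linear warm-up to peak_lr, then cosine decay down to 0."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * progress))
```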

F.3 Text classification (Sec. 4.3)

Dataset. We conduct text classification experiments on the following datasets: Amazon-5, IMDB, MNLI, and Yelp-5 (see Table 6 for the training-set sizes).

Optimizer. For all training, we employed the Adam optimizer with a linear decay schedule (peak learning rate of 3e-5, decaying to 0) and a linear warm-up at the start. We used a batch size of 128.

Teacher fine-tuning. For the teacher labeler, we started from the RoBERTa-Base [Liu et al., 2019a] pretrained checkpoint (https://dl.fbaipublicfiles.com/fairseq/models/roberta.base.tar.gz) from the official FAIRSEQ repository (https://github.com/facebookresearch/fairseq). We fine-tuned using the default parameters, other than the number of steps, which is the same as listed in Table 6.

For the teacher generator, we directly use the pre-trained BART-Base [Lewis et al., 2020] checkpoint (https://dl.fbaipublicfiles.com/fairseq/models/bart.base.tar.gz) from the official FAIRSEQ repository (https://github.com/facebookresearch/fairseq). (We did not fine-tune it.)

Student training. We start from the DistilRoBERTa pretrained checkpoint downloaded from the HuggingFace repository (https://huggingface.co/distilroberta-base/tree/main). We perturb by adding Gaussian noise of variance σ² between the encoder and the decoder, as well as by masking out a fraction p of the input tokens. We then generate new examples by running greedy decoding of the BART teacher generator up to a sequence length of 512; a sketch of this step follows Table 6. For dual-input classification tasks such as MNLI, we generate the two inputs independently. We did a grid search over the perturbation parameter σ and the masking fraction p. All hyper-parameters and grids are listed in the table below:

| Hyper-param | Amazon-5 (2.5k) | Amazon-5 (3M) | IMDB | MNLI | Yelp-5 (2.5k) | Yelp-5 (650k) |
|---|---|---|---|---|---|---|
| Num steps | 5000 | 75000 | 20000 | 75000 | 5000 | 75000 |
| Warm-up steps | 1000 | 2000 | 500 | 2000 | 1000 | 2000 |

Shared across all datasets: Optimizer = Adam; Schedule = Linear; Peak learning rate = 3e-5; Batch size = 128; Max sequence length = 512; Perturbation noise (σ) ∈ {0, 0.01, 0.1}; Masking fraction (p) ∈ {0, 0.1, 0.2}.

Table 6: Hyper-parameters for student training on text classification. Parenthesized column sub-headers give the training-set size.
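The generation step referenced above can be sketched as follows. This is an illustration assuming the HuggingFace `transformers` BART API (our experiments used the fairseq checkpoints, so the noise-injection point and masking details may differ slightly):

```python
import torch
from transformers import BartForConditionalGeneration, BartTokenizer

tok = BartTokenizer.from_pretrained("facebook/bart-base")
bart = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

def generate_neighbour(text, sigma=0.1, p=0.2, max_len=512):
    """Perturb `text` in BART's latent space and greedily decode a neighbour."""
    ids = tok(text, return_tensors="pt", truncation=True,
              max_length=max_len).input_ids
    # Mask out a fraction p of the input tokens (sparing special tokens).
    mask = torch.rand(ids.shape) < p
    mask &= (ids != tok.bos_token_id) & (ids != tok.eos_token_id)
    ids = ids.masked_fill(mask, tok.mask_token_id)
    # Add Gaussian noise between the encoder and the decoder.
    enc = bart.get_encoder()(input_ids=ids)
    enc.last_hidden_state = (enc.last_hidden_state
                             + sigma * torch.randn_like(enc.last_hidden_state))
    # Greedy decoding (num_beams=1, do_sample=False).
    out = bart.generate(input_ids=ids, encoder_outputs=enc,
                        max_length=max_len, do_sample=False, num_beams=1)
    return tok.decode(out[0], skip_special_tokens=True)
```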

F.4 Text retrieval (Sec. 4.4)

Dataset. From the official "Dense Passage Retrieval" repository at https://github.com/facebookresearch/DPR, we download the passage corpus (https://dl.fbaipublicfiles.com/dpr/wikipedia_split/psgs_w100.tsv.gz). From the same repository, we download a pre-processed version of Natural Questions Open [Lee et al., 2019] that has been aligned to the passage corpus (https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-nq-train.json.gz, https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-nq-dev.json.gz). Finally, we download a pre-processed version of the PAQ dataset [Lewis et al., 2021], aligned to the same passage corpus, from the official repository of "Domain-matched Pre-training Tasks for Dense Retrieval" available at https://github.com/facebookresearch/dpr-scale (https://dl.fbaipublicfiles.com/dpr_scale/paq/PAQ.dpr.train.neg1.jsonl.zip).

Optimizer. For all text retrieval model training, we employed the Adam optimizer with a linear decay schedule (peak learning rate of 1e-5, decaying to 1e-7) and a linear warm-up (from 0 to the peak learning rate of 1e-5) over the first 1K steps. We used a batch size of 128.

Teacher fine-tuning. For the teacher labeler dual encoder (a question encoder and a passage encoder), we utilized the RoBERTa-Base [Liu et al., 2019a] pretrained checkpoint (https://dl.fbaipublicfiles.com/fairseq/models/roberta.base.tar.gz) from the official FAIRSEQ repository (https://github.com/facebookresearch/fairseq). We then conducted a first round of fine-tuning for 300k iterations on the passage-aligned PAQ dataset, using the same configuration as Oğuz et al. [2021], except that Oğuz et al. trained on PAQ for longer. After this pretraining, the teacher is fine-tuned on NQ-Open [Kwiatkowski et al., 2019] for 40K steps. Similar to Karpukhin et al. [2020] and Oğuz et al. [2021], the teacher is trained with within-batch negatives and a softmax-based cross-entropy loss; a sketch of this loss follows.
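The within-batch-negative objective is standard; below is a minimal sketch of the softmax-based cross-entropy over in-batch similarities (our actual training code follows Karpukhin et al. [2020] and may differ in details such as score scaling):

```python
import torch
import torch.nn.functional as F

def in_batch_negatives_loss(q_emb, p_emb):
    """q_emb, p_emb: [B, d] question/passage embeddings for aligned pairs.

    Every other passage in the batch serves as a negative for each question;
    the gold passage sits on the diagonal of the similarity matrix.
    """
    scores = q_emb @ p_emb.T               # [B, B] similarity matrix
    labels = torch.arange(q_emb.size(0))   # gold index for each question
    return F.cross_entropy(scores, labels)
```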

For the teacher generator, we directly use the pre-trained BART-Base [Lewis et al., 2020] checkpoint (https://dl.fbaipublicfiles.com/fairseq/models/bart.base.tar.gz) from the official FAIRSEQ repository (https://github.com/facebookresearch/fairseq). (We did not fine-tune it.)

The same teacher labeler and generator are used for all student training, except for direct training (one-hot).

Student training. We start from the DistilRoBERTa pretrained checkpoint downloaded from the HuggingFace repository (https://huggingface.co/distilroberta-base/tree/main). All students are trained for 40K steps. The teacher labeler labels all pairs within the batch and, for the uniform-negative-sampling baseline and TGT, labels 2 additional passages per question-passage pair. We employed an off-the-shelf BART-Base model as our generator [Lewis et al., 2020]; an isotropic perturbation was added via random Gaussian noise of scale σ = 0.1, combined with masking a fraction p = 0.2 of the original passage.

Appendix G Qualitative Examples of Generated Examples

G.1 Image Classification

We show some representative examples of images generated with TGT-random and TGT-gradient-based exploration, from the ImageNet classification experiment, in Table 7.

| Input | TGT-Random Example | TGT-Gradient-based Example |
|---|---|---|
| Data label: lion | Teacher label: brown bear | Teacher label: chow |
| [image] | [image] | [image] |
| Data label: cheeseburger | Teacher label: potpie | Teacher label: cheeseburger |
| [image] | [image] | [image] |
| Data label: digital clock | Teacher label: tape player | Teacher label: grocery store |
| [image] | [image] | [image] |
| Data label: wall clock | Teacher label: shield | Teacher label: gong |
| [image] | [image] | [image] |

Table 7: Image examples (images omitted in this text conversion).

G.2 Text Classification

We show some representative examples of text generated by TGT, from the MNLI classification experiment, in Table 8.

Example 1 (Data label: Contradicts; Teacher label: Neutral)
Input: The house was bought with the royalties she earned from her first book, The Tales of Peter Rabbit. [SEP] The house was bought with the money she inherited from her grandfather.
TGT: The book was published in the United States in 1987 with the royalties she received from her first book, The Tales of Peter Rabbit. [SEP] The house was bought with the money she inherited from her grandfather.

Example 2 (Data label: Entails; Teacher label: Entails)
Input: Leather goods are no longer a bargain in Spain, though very good quality products may still be priced lower than at home. [SEP] Leather goods are still very cheap in Spain.
TGT: Leather and leather goods are no longer a bargain in Spain, though very good quality products may still be priced lower than at home and abroad. [SEP] Leather goods are still very cheap at Spain.

Example 3 (Data label: Entails; Teacher label: Neutral)
Input: Then I got up as softly as I could, and felt in the dark along the left-hand wall. [SEP] The wall was wet.
TGT: Then I got up as softly as I could, and walked the way I felt in the dark along the left [SEP] The wall was wet.

Example 4 (Data label: Entails; Teacher label: Entails)
Input: But then this very particular island is hardly in danger of being invaded except, of course, by tourism. [SEP] This island is least likely to be invaded by tourism.
TGT: But then this very particular island is not in danger of being invaded except, of course, by tourism. [SEP] The island is likely to be invaded by tourism.

Example 5 (Data label: Contradicts; Teacher label: Neutral)
Input: All you need to do is just wander off the beaten path, beyond the bustling tourist zone. [SEP] There is no point going off the beaten path, there is nothing there.
TGT: All you need to do is just wander off the beaten path, and you'll be in the bustling tourist zone of the city. [SEP] There is no point going off the beaten path, there is nothing there.

Example 6 (Data label: Entails; Teacher label: Neutral)
Input: The silt of the River Maeander has also stranded the once-mighty city of Miletus. [SEP] The River Maeander has been depositing silt near Miletus for nearly two millennia.
TGT: The silt of the River Mae has also stranded the once-mighty city of Miletus. [SEP] The River Maeander has been depositing silt near Miletus for more than two decades.

Example 7 (Data label: Entails; Teacher label: Entails)
Input: It was hardly the most enlightened of times, not with the conflict in Indochina rapidly becoming America's costliest and most divisive war. [SEP] The war in Indochina has cost America 100 billion dollars so far.
TGT: It was hardly the most enlightened of times, not with the war in Indochina becoming America's costliest and most divisive war. [SEP] The war in Indochina has cost America 100 billion dollars so far.

Table 8: Text examples. Each cell is of the form "premise [SEP] hypothesis".