
SLaM: Student-Label Mixing for Distillation with Unlabeled Examples

Vasilis Kontonis
UW-Madison, Google Research

Fotis Iliopoulos
Google Research

Khoa Trinh
Google Research

Cenk Baykal
Google Research

Gaurav Menghani
Google Research

Erik Vee
Google Research
Abstract

Knowledge distillation with unlabeled examples is a powerful training paradigm for generating compact and lightweight student models in applications where the amount of labeled data is limited but one has access to a large pool of unlabeled data. In this setting, a large teacher model generates “soft” pseudo-labels for the unlabeled dataset which are then used for training the student model. Despite its success in a wide variety of applications, a shortcoming of this approach is that the teacher’s pseudo-labels are often noisy, leading to impaired student performance. In this paper, we present a principled method for knowledge distillation with unlabeled examples that we call Student-Label Mixing (SLaM) and we show that it consistently improves over prior approaches by evaluating it on several standard benchmarks. Finally, we show that SLaM comes with theoretical guarantees; along the way we give an algorithm improving the best-known sample complexity for learning halfspaces with margin under random classification noise, and provide the first convergence analysis for so-called “forward loss-adjustment” methods.

1 Introduction

While good-quality human-labeled data are often hard to obtain, finding huge amounts of unlabeled data is relatively easy. Therefore, in modern machine learning applications, we often face the situation where we have a small “golden” dataset with human labels and a large unlabeled dataset. In Distillation with Unlabeled Examples [12, 26, 16], a large teacher model is first trained (or fine-tuned) on the human-labeled data and is then used to generate “soft” pseudo-labels for the unlabeled dataset. Then the (typically smaller) student model, i.e., the model that will be deployed for the purposes of the application, is trained on the combined dataset that contains both the labels generated by humans and the pseudo-labels generated by the teacher model. This general-purpose training paradigm has been applied in a wide variety of contexts [16, 46, 53, 54, 58, 57], including but not limited to distilling knowledge from large-scale foundation models like BERT [18] and GPT-3 [11]. We remark that in such settings one does not have access to the teacher model but only to its pseudo-labels (which were generated during some previous “bulk-inference” phase). This “bulk-inference” step is typically computationally expensive and happens once: one cannot modify the teacher network (or even use it for inference) during the training of the student.

Despite its widespread success in practice, the effectiveness of this powerful approach generally depends on the quality of the pseudo-labels generated by the teacher model. Indeed, training the student model on noisy pseudo-labels often leads to significant degradation of its generalization performance, and this is a well-known phenomenon that has been observed and studied in a plethora of papers in the literature, e.g., [6, 36, 44, 51, 53, 8, 27].

In this work, we propose Student-Label Mixing (SLaM), a principled method for knowledge distillation with unlabeled examples that accounts for the teacher’s noise and consistently improves over prior approaches. At the heart of our method lies the observation that the noise introduced by the teacher is neither random nor adversarial, in the sense that it correlates well with metrics of “confidence” such as the margin score or the entropy of the teacher’s predictions. We exploit this empirical fact to our benefit in order to introduce a model for the teacher’s noise, which we use to appropriately modify the student’s loss function. At a high level, for any given example during the student’s training process, we evaluate the student’s loss function on a convex combination of the student’s current prediction and another (soft-)label that we estimate using our model for the teacher’s noise (hence the name “student-label mixing”).

Our contributions can be summarized as follows:

  1. We propose SLaM: a principled method for improving knowledge distillation with unlabeled examples. The method is efficient, data-agnostic and simple to implement.

  2. We provide extensive experimental evidence and comparisons which show that our method consistently outperforms previous approaches on standard benchmarks. Moreover, we show that SLaM can be combined with standard distillation techniques such as temperature scaling and confidence-based weighting schemes.

  3. We give theoretical guarantees for SLaM under standard assumptions. As a byproduct of our analysis we obtain a simple “forward loss-adjustment” iteration that provably learns halfspaces with $\gamma$-margin under Random Classification Noise with $O(1/(\epsilon^{2}\gamma^{2}))$ samples, improving over prior works that had worse dependence on either the margin $\gamma$ or the generalization error $\epsilon$ (see Theorem 5.1 and Remark 5.2).

2 Related Work

Knowledge Distillation. Most of the literature on knowledge distillation has focused on the fully supervised/labeled setting, i.e., when distillation is performed on the labeled training data of the teacher model rather than on new, unlabeled data (see, e.g., the original paper of [26]). Naturally, in this setting the pseudo-labels generated by the teacher are almost always accurate, and so many follow-up works [2, 14, 15, 41, 52] have developed advanced distillation techniques that aim to enforce greater consistency between the teacher’s and the student’s predictions, or even between the intermediate representations learned by the two models. Applying such methods in our setting, where the training dataset contains mainly unlabeled examples, is still possible, but in this case it is known [51, 27] that fully trusting the teacher model can actually be harmful to the student model, making these methods less effective. (In fact, when the teacher is highly noisy these methods even underperform vanilla distillation with unlabeled examples.) In Section 4.2 we present results that show the improved effectiveness of SLaM relative to state-of-the-art supervised knowledge distillation methods like the Variational Information Distillation for Knowledge Transfer (VID) framework [2]. Moreover, in Section D.5 we show that our method can be combined with (i.e., provides an additional improvement over) one of the simplest yet surprisingly effective methods of improving knowledge distillation, namely the temperature-scaling idea introduced by [26].

For distillation with unlabeled examples, many approaches [17, 33, 29] propose filtering out or reweighting the teacher’s pseudo-labels based on measures of the teacher’s uncertainty, such as dropout variance, entropy, margin score, or the cut-statistic. These methods are independent of the student model and can be synergistically combined with our technique. For instance, in Section D.4 we demonstrate that combining our method with teacher-uncertainty-based reweighting schemes leads to improved student performance relative to applying the reweighting scheme alone.

Much more closely related to our work is the recently introduced approach of [27]. There, the authors design a model for the teacher’s noise and use it to modify the student’s loss function so that, in expectation, the loss simulates the loss with respect to noise-free pseudo-labels. One of the main advantages of our method compared to that of [27] is that our model for the teacher’s noise is more structured and easier to learn, which, as our experiments in Section 4.2 show, leads to consistently better student performance.

Learning From Noisy Labels. Learning from noisy labels is an important and well-studied problem with a vast literature [7, 21, 23, 28, 31, 37, 40, 42, 45, 47]; see [50] for a recent survey. The fundamental difference between our setting and papers in this literature is that the noise introduced by the teacher is structured, and this is a crucial observation we utilize in our design. Specifically, our approach is inspired by the so-called forward loss-adjustment methods, e.g. [43], but it is specifically tailored to the structure of the distillation with unlabeled examples setting. Indeed, forward methods typically attempt to estimate a noise transition matrix whose $(i,j)$ entry is the probability of the true label $i$ being flipped into a corrupted label $j$, which can be rather problematic when dealing with general, instance-specific noise like in the case of distillation with unlabeled examples. On the other hand, we exploit that (i) we have access to confidence metrics of the teacher’s predictions; and (ii) oftentimes, when the teacher model’s top-$1$ prediction is inaccurate, the true label is within its top-$k$ predictions for some appropriate $k$, to design and estimate a much more refined model for the teacher’s noise that we use to inform the design of the student’s loss function.
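For contrast, a generic forward loss-adjustment step can be sketched as follows. This is a minimal illustration, assuming a known, class-level (not instance-dependent) transition matrix `T` with `T[i][j]` the probability that true label `i` is reported as `j`; the function names are hypothetical, and in practice `T` has to be estimated. The model's clean prediction `p` is pushed through `T` before the loss is taken against the noisy label.

```python
import math

def uniform_flip_matrix(L, alpha):
    """T[i][j] = alpha if i == j, else (1 - alpha) / (L - 1): symmetric class-level noise."""
    off = (1.0 - alpha) / (L - 1)
    return [[alpha if i == j else off for j in range(L)] for i in range(L)]

def forward_corrected_prediction(p, T):
    """Noisy-label prediction q_j = sum_i p_i * T[i][j] (clean prediction pushed through T)."""
    L = len(p)
    return [sum(p[i] * T[i][j] for i in range(L)) for j in range(L)]

def cross_entropy(target, pred, eps=1e-12):
    """ce(target, pred) = -sum_i target_i * log(pred_i)."""
    return -sum(t * math.log(max(q, eps)) for t, q in zip(target, pred))

p = [0.7, 0.2, 0.1]                       # model's clean class probabilities
T = uniform_flip_matrix(3, alpha=0.8)     # off-diagonal entries are 0.1
q = forward_corrected_prediction(p, T)    # approximately [0.59, 0.24, 0.17]
loss = cross_entropy([0.0, 1.0, 0.0], q)  # loss against the observed noisy label
```

With the uniform flip matrix used here, $q_j=\alpha p_j+\frac{1-\alpha}{L-1}(1-p_j)$ whenever $p$ sums to one, and the same $T$ is shared by every example; a class-level $T$ of this kind cannot express instance-dependent noise.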

Another related technique for dealing with noisy data is using “robust” loss functions [4, 20, 24, 35, 56] such that they achieve a small risk for new clean examples even under the presence of noise in the training dataset. In Section 4.2 we compare our method with the general framework of  [20] for designing robust loss functions and we show that our approach, when applied to the standard cross-entropy loss, consistently outperforms [20] in the setting of distillation with unlabeled examples. That said, we stress that our method is not tied to the cross-entropy loss and, in fact, it often gives better results when combined with more sophisticated loss functions. We demonstrate this in Section D.6 where we apply our method in cases where the student loss function comes from the families of losses introduced in [20] and [35].

Semi-Supervised Learning. Akin to our setting, in semi-supervised learning (SSL) (see e.g. [55] for a recent survey) the learner is presented with a small labeled dataset $A$ and a typically much larger unlabeled dataset $B$. Unlike our setting, though, there is typically no distinction between the student and teacher: the model of interest generates pseudo-labels on $B$ which are utilized by using appropriate loss functions or preprocessing procedures (e.g. “filtering” or “correcting”), often in an iterative fashion with the goal of improving the quality of the newly-generated pseudo-labels. It is also worth noting that in many real-world applications of distillation with unlabeled examples either the teacher model is unavailable or it is too expensive to retrain it and create fresh pseudo-labels on the data (e.g., when we request labels from a pretrained large language model). Therefore, SSL approaches that either (i) update the “teacher” model (e.g., [34]), or (ii) require several fresh teacher-generated pseudo-labels (e.g., by requesting teacher predictions on random data augmentations or perturbed versions of the unlabeled examples of $B$, as in [9]) are not applicable in our setting. We implement the recent SSL technique of [48] and show that our method outperforms it in the context of distillation with unlabeled examples. Besides performing on par with state-of-the-art SSL approaches like [9], the method of [48] is free of inherent limitations like using domain-specific data augmentations, which is also an important feature of our approach.

Learning Halfspaces with Random Classification Noise. The theoretical study of classification with Random Classification Noise (RCN) was initiated by [5]. For the fundamental class of linear classifiers (halfspaces) the first polynomial-time algorithms for the problem were given in [13] and [10]. The iteration proposed in [13] is a “backward loss-adjustment” method [43] for which it is known that the resulting optimization landscape is convex (for linear classifiers). In [19] an improved analysis of the method of [13] was given, showing that SGD on this convex loss learns $\gamma$-margin halfspaces with RCN with $\widetilde{O}(1/(\gamma^{4}\epsilon^{2}))$ samples. On the other hand, forward loss-adjustment methods for dealing with RCN are known to result in an inherently non-convex landscape (see [38] and Figure 8). Our theoretical result for SLaM (see Theorem 5.1) is the first convergence result for a “forward loss-adjustment” method and, at the same time, achieves a sample complexity of $O(1/(\gamma^{2}\epsilon^{2}))$, improving over the prior work.

3 SLaM: Student-Label Mixing Distillation

In this section, we describe our distillation with unlabeled examples setting and present SLaM. In what follows, we assume that examples are represented by feature vectors in some space $\mathcal{X}$. We shall denote by $X$ the distribution over examples. We consider multi-class classification with $L$ classes and assume that the ground-truth label of an example $x$ is represented by a one-hot vector in $\mathcal{Y}=\{0,1\}^{L}$ given by some unknown function $g(x):\mathcal{X}\mapsto\mathcal{Y}$. In multi-class classification the learning algorithm typically optimizes over a parametric family of classification models $\mathcal{F}=\{f(\cdot;w):\mathcal{X}\mapsto\mathbf{R}^{L}:w\in\mathcal{W}\}$, i.e., for every parameter $w\in\mathcal{W}$, $f(x;w)$ is an $L$-dimensional “score vector”, where $f(x;w)_{i}$ corresponds to the probability that the model assigns to class $i$ for the example $x$. We shall denote by $\ell(\cdot,\cdot):\mathbf{R}^{L}\times\mathbf{R}^{L}\mapsto\mathbf{R}$ the classification loss function used by the learning algorithm. During training the algorithm considers a set of labeled examples $S=\{(x^{(1)},g(x^{(1)})),\ldots,(x^{(n)},g(x^{(n)}))\}$ and optimizes the loss $\ell(\cdot,\cdot)$ over $S$, i.e., solves the problem $\min_{w\in\mathcal{W}}\frac{1}{|S|}\sum_{(x,g(x))\in S}\ell(g(x),f(x;w))$. For two vectors $v,u\in\mathbf{R}^{L}$ we denote by $\mathrm{err}(v,u)=\mathbf{1}\{\operatorname*{argmax}(v)\neq\operatorname*{argmax}(u)\}$ the indicator of the event that the positions of the maximum elements of $v,u$ differ. Similarly, for two classifiers $h(x),f(x):\mathbf{R}^{d}\mapsto\mathbf{R}^{L}$ we use $\mathrm{err}(h(x),f(x))$ to indicate whether their top-1 predictions for the example $x$ disagree.
Our goal is to train a classifier over the sample $S$ so that its generalization error, i.e., $\mathbf{E}_{x\sim X}[\mathrm{err}(f(x;w),g(x))]$, is small.
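The definitions of $\mathrm{err}$ and of the generalization error translate directly into code. The sketch below (plain Python, hypothetical helper names) computes the top-1 disagreement indicator and its average over a finite sample, which is the usual plug-in estimate of $\mathbf{E}_{x\sim X}[\mathrm{err}(f(x;w),g(x))]$. Ties in the argmax are broken toward the lowest index, an assumption the text does not specify.

```python
def argmax(v):
    """Index of the largest entry of v (the first one on ties)."""
    return max(range(len(v)), key=lambda i: v[i])

def err(v, u):
    """err(v, u) = 1 if the top-1 positions of v and u differ, else 0."""
    return int(argmax(v) != argmax(u))

def empirical_error(predictions, labels):
    """Average of err(f(x; w), g(x)) over a finite sample of examples."""
    return sum(err(f, g) for f, g in zip(predictions, labels)) / len(labels)

# Toy predictions f(x; w) and one-hot ground truths g(x) for three examples.
preds = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2], [0.2, 0.2, 0.6]]
labels = [[0, 1, 0], [0, 0, 1], [0, 0, 1]]
estimate = empirical_error(preds, labels)  # only the second example is misclassified
```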

Distillation with Unlabeled Examples.

We assume that we are given a (usually small) dataset $A$ of correctly labeled examples $(x,g(x))$ and a set of unlabeled data $U$. A “teacher” model $y_{s}(\cdot):\mathcal{X}\mapsto\mathbf{R}^{L}$ is first trained on the labeled dataset $A$ and then provides soft-labels for the examples of dataset $U$, i.e., we create a dataset $B=\{(x,y_{s}(x)):x\in U\}$ containing examples labeled with the corresponding probability distribution over classes (soft-labels) of the teacher model. We then train a (typically smaller) student model using both the original labeled data $A$ and the teacher-labeled dataset $B$, i.e., $\min_{w\in\mathcal{W}}\frac{1}{|A\cup B|}\sum_{(x,z)\in A\cup B}\ell(z,f(x;w))$. In what follows, we refer to the above training procedure as “vanilla distillation”.

Remark 3.1 (“Hard-” vs “Soft-” Distillation).

We remark that the process where, instead of using the soft-labels provided by the teacher model on the unlabeled dataset $U$, we use one-hot vectors representing the class with maximum score according to the teacher is known as hard-distillation. We will denote by $y_{s}(x)$ the soft-label of the teacher and by $y(x)$ the corresponding hard-label, i.e., $y(x)$ is the one-hot representation of $\operatorname*{argmax}y_{s}(x)$. When it is clear from the context we may simply write $y$ instead of $y(x)$.
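A minimal sketch of the vanilla-distillation objective with the cross-entropy loss, covering both soft- and hard-distillation. It assumes the student is any callable that returns a probability vector and that $A$ and $B$ are disjoint lists of `(x, label)` pairs; all names and the toy data are hypothetical.

```python
import math

def cross_entropy(target, pred, eps=1e-12):
    """ce(target, pred) = -sum_i target_i * log(pred_i); target may be a soft label."""
    return -sum(t * math.log(max(p, eps)) for t, p in zip(target, pred))

def hard_label(soft):
    """y(x): one-hot vector of argmax y_s(x) (ties broken toward the lower index)."""
    j = max(range(len(soft)), key=lambda i: soft[i])
    return [1.0 if i == j else 0.0 for i in range(len(soft))]

def vanilla_distillation_loss(student, A, B, hard=False):
    """Average cross-entropy over the union of A and B (assumed disjoint).

    A holds (x, g(x)) with true labels; B holds (x, y_s(x)) with teacher soft labels.
    With hard=True, each y_s(x) is replaced by y(x) (hard-distillation).
    """
    total = sum(cross_entropy(z, student(x)) for x, z in A)
    for x, ys in B:
        total += cross_entropy(hard_label(ys) if hard else ys, student(x))
    return total / (len(A) + len(B))

def student(x):
    """Toy student with a constant prediction."""
    return [0.8, 0.2]

A = [("a", [1.0, 0.0])]
B = [("b", [0.8, 0.2])]
soft_loss = vanilla_distillation_loss(student, A, B)
hard_loss = vanilla_distillation_loss(student, A, B, hard=True)
```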

Modelling the Teacher as a “Noisy” Label Oracle.

In the distillation setting described in the previous paragraph, it is known [51, 27, 8, 44] that the teacher model often generates incorrect predictions on the unlabeled examples, impairing the student’s performance. Given any $x\in U$, we model the teacher’s prediction $y$ as a random variable. Similarly to [27], we assume that, for every unlabeled datapoint $x\in U$, the provided teacher label $y$ is correct with probability $\alpha(x)$ and incorrect with probability $1-\alpha(x)$. However, in contrast with [27], our noise model prescribes a non-adversarial (semi-random) behavior of the teacher when its top-1 prediction is incorrect.

A first step towards more benign noisy teachers is to assume that, conditionally on being wrong, the teacher label is a uniformly random class among the remaining $L-1$ classes. We remark that this model is already enough to give improvements in datasets with a moderately large number of classes (e.g., up to 100). In particular, it perfectly captures the noisy teacher in binary classification: when the teacher label is different from the ground-truth $g(x)$, it has to be equal to the “flipped” ground-truth $1-g(x)$.

We now further refine our model so that it is realistic for datasets with thousands of classes. Even though the top-1 accuracy of the teacher model may not be very high on the unlabeled data $U$, the true label is much more likely to belong to the top-5 or top-10 predictions of the teacher than to be completely arbitrary. For example, training a ResNet50 network on $10\%$ of ImageNet [49] yields an average top-1 accuracy of about $52.78\%$ on the test dataset, whereas the top-10 accuracy of the same model is about $83.55\%$. In datasets with a large number of classes, this observation significantly reduces the number of potential correct classes for the examples where the teacher label is incorrect. Motivated by the above, we assume the following structured, semi-random noise model for the teacher, tailored to multi-class settings.

Definition 3.2 (Noisy Teacher Model).

Let $x$ be any example of the unlabeled data $U$ and denote by $g(x)$ its ground-truth label. Let $y_{s}(x)$ (resp. $y(x)$) be the random variable corresponding to the soft (resp. hard) prediction of the teacher model for the example $x$. We assume that for every $x$ there exist (unknown to the learner) $\alpha(x)\in[0,1]$ and $k(x)\in\{2,\ldots,L\}$ such that the teacher’s top-1 prediction $y$ agrees with the ground-truth $g(x)$ with probability $\alpha(x)$ and, with probability $1-\alpha(x)$: (i) the ground-truth belongs to the top-$k(x)$ predictions of the teacher; and (ii) the teacher’s (hard) prediction is a uniformly random incorrect class out of the top-$k(x)$ predictions of the teacher soft-label $y_{s}(x)$.¹

¹ Given that the teacher’s prediction is incorrect and that the ground-truth belongs to the top-$k(x)$ predictions of the teacher, assumption (ii) describes a uniform distribution on $k(x)-1$ labels.
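Definition 3.2 fully specifies the conditional law of the teacher's hard label once $\alpha(x)$, $k(x)$, the ground truth, and the teacher's top-$k(x)$ set are fixed. A minimal sketch (hypothetical helper names; classes are indexed from 0, and ties in the soft label are broken toward the lower index) that returns this distribution and samples from it:

```python
import random

def top_indices(z, k):
    """Indices of the k largest entries of z (ties broken toward the lower index)."""
    return sorted(range(len(z)), key=lambda i: (-z[i], i))[:k]

def teacher_label_distribution(g_index, ys, alpha, k):
    """P(y = j | x) for each class j under Definition 3.2.

    With probability alpha the hard label equals the ground truth g_index; otherwise
    it is uniform over the k - 1 other classes among the top-k entries of ys.
    """
    top = top_indices(ys, k)
    if g_index not in top:
        raise ValueError("Definition 3.2 assumes g(x) lies in the teacher's top-k(x)")
    probs = [0.0] * len(ys)
    probs[g_index] = alpha
    for j in top:
        if j != g_index:
            probs[j] = (1.0 - alpha) / (k - 1)
    return probs

def sample_teacher_label(g_index, ys, alpha, k, rng):
    """Draw one hard-label index from teacher_label_distribution."""
    probs = teacher_label_distribution(g_index, ys, alpha, k)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]

# Toy example: 4 classes, ground truth is class 1, teacher's top-3 set is {0, 1, 2}.
ys = [0.5, 0.3, 0.15, 0.05]
probs = teacher_label_distribution(1, ys, alpha=0.6, k=3)  # [0.2, 0.6, 0.2, 0.0]
y = sample_teacher_label(1, ys, 0.6, 3, random.Random(0))
```

Read as a vector, `probs` is the expected one-hot teacher label $\mathbf{E}[y\mid x]$ for this example.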

Remark 3.3.

We remark that the model of Definition 3.2 captures having a “perfect” teacher model by setting $\alpha(x)=1$ for all $x$, and also generalizes the binary case described above by taking $k(x)=2$ for all $x\in X$.

Given the above noise model for the teacher, the problem of improving knowledge distillation consists of two main tasks: (i) obtaining estimates of the accuracy statistics $\alpha(x),k(x)$ for each example $x\in U$; and (ii) using those estimated values to improve the training of the student model so that it is less affected by the mistakes of the teacher on dataset $B$.

Training Better Students Using $\alpha(x),k(x)$

We first assume that for every $x$ we have oracle access to the values $\alpha(x),k(x)$ and present our Student-Label Mixing loss function. Instead of using $\alpha(x),k(x)$ to “denoise” the teacher’s label, we use them to add noise to the student’s predictions. To make notation more compact, in what follows, given a vector $z\in\mathbf{R}^{L}$ we denote by $\mathrm{top}(z;k)$ the vector that has the value $1$ in the positions of the $1$-st up to $k$-th largest elements of $z$ and $0$ in all other positions, e.g., $\mathrm{top}((1,2,3);1)=(0,0,1)$ and $\mathrm{top}((-1,1,0,2);3)=(0,1,1,1)$. Assuming that the student-label for some $x\in U$ is $f(x;w)$, we “mix” it (hence the name Student-Label Mixing) using $\alpha(x),k(x)$ to obtain the mixed prediction

$$\mathrm{mix}(f(x;w);\alpha(x),k(x))=\alpha(x)\,f(x;w)+(1-\alpha(x))\,\mathrm{top}(y_{s}(x);k(x))*\frac{1-f(x;w)}{k(x)-1}\,, \tag{1}$$

where $q*p$ denotes the element-wise multiplication of the vectors $p,q$. We then train the mixed student model on the “noisy” dataset $B$:

$$\min_{w\in\mathcal{W}}\frac{1}{|A\cup B|}\Bigg(\sum_{(x,z)\in A}\ell(z,f(x;w))+\sum_{(x,y)\in B}\ell\big(y,\mathrm{mix}(f(x;w);\alpha(x),k(x))\big)\Bigg) \tag{2}$$
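A minimal sketch of Eq. (1) and of the objective in Eq. (2) with the cross-entropy loss, in plain Python. It assumes `alpha(x)` and `k(x)` are given callables (standing in for the oracle values), that each element of $B$ carries both the teacher's hard label $y$ and its soft label $y_s(x)$ (the latter is needed for $\mathrm{top}(y_s(x);k(x))$), and that the student returns a probability vector; all names are hypothetical.

```python
import math

def top(z, k):
    """top(z; k): 1 at the positions of the k largest entries of z, 0 elsewhere."""
    idx = sorted(range(len(z)), key=lambda i: (-z[i], i))[:k]
    return [1.0 if i in idx else 0.0 for i in range(len(z))]

def mix(f, ys, alpha, k):
    """Eq. (1), element-wise: alpha * f + (1 - alpha) * top(ys; k) * (1 - f) / (k - 1).

    For alpha < 1 the result sums to 1 only if f puts all its mass on the top-k classes.
    """
    t = top(ys, k)
    return [alpha * fi + (1.0 - alpha) * ti * (1.0 - fi) / (k - 1) for fi, ti in zip(f, t)]

def cross_entropy(target, pred, eps=1e-12):
    return -sum(t * math.log(max(p, eps)) for t, p in zip(target, pred))

def slam_loss(student, A, B, alpha, k):
    """Eq. (2) with cross-entropy. A: (x, g(x)); B: (x, y, y_s(x)) with hard label y."""
    total = sum(cross_entropy(z, student(x)) for x, z in A)
    total += sum(cross_entropy(y, mix(student(x), ys, alpha(x), k(x))) for x, y, ys in B)
    return total / (len(A) + len(B))

def student(x):
    """Toy student with a constant prediction."""
    return [0.7, 0.2, 0.1]

A = [("a", [1.0, 0.0, 0.0])]
B = [("b", [0.0, 1.0, 0.0], [0.5, 0.3, 0.2])]
loss = slam_loss(student, A, B, alpha=lambda x: 0.8, k=lambda x: 2)
```

In a real training loop the same computation would be written with the framework's tensor operations so that gradients flow through `mix` into the student's parameters.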

The main intuition behind the mixing of the student’s labels is that by training the “noisy” student to match the “noisy” teacher label $y$ on dataset $B$, the underlying (non-mixed) student $f(x;w)$ will eventually learn the ground-truth. In particular, when $\ell(\cdot,\cdot)$ is the Cross-Entropy loss, the expected mixed loss conditioned on any $x$ is

$$\mathbf{E}\big[\ell(y,\mathrm{mix}(f(x;w);\alpha(x),k(x)))\mid x\big]=\ell\big(\mathrm{mix}(g(x);\alpha(x),k(x)),\mathrm{mix}(f(x;w);\alpha(x),k(x))\big)\,,$$

where we used the fact that the cross-entropy is linear in its first argument and that, by the definition of our noise model (Definition 3.2), it holds that $\mathbf{E}[y\mid x]=\mathrm{mix}(g(x);\alpha(x),k(x))$. Therefore, when the student is equal to the ground-truth, $f(x;w)=g(x)$, the mixed student model satisfies $\mathrm{mix}(g(x);\alpha(x),k(x))=\mathrm{mix}(f(x;w);\alpha(x),k(x))$ for all $x\in X$, and (by Gibbs’ inequality) $g(x)$ is a minimizer of the SLaM loss. We show the following proposition; see Appendix C for the formal statement and proof.

Proposition 3.4 (SLaM Consistency (Informal)).

Let $D$ be the distribution of the teacher-labeled examples of dataset $B$, i.e., we first draw $x\sim X$ and then label it using the noisy teacher of Definition 3.2. Moreover, assume that there exists some parameter $w^{\ast}\in\mathcal{W}$ such that the ground-truth $g(x)=f(x;w^{\ast})$. Then $w^{\ast}$ is the minimizer of the (population) SLaM objective $\min_{w}\mathbf{E}_{(x,y)\sim D}[\mathrm{ce}(y,\mathrm{mix}(f(x;w);\alpha(x),k(x)))]$, where $\mathrm{ce}(\cdot,\cdot)$ is the Cross-Entropy loss.
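The argument preceding Proposition 3.4 can be sanity-checked numerically. The sketch below builds a toy example that satisfies Definition 3.2 (the numbers are illustrative, not from the paper), computes the expected SLaM cross-entropy by averaging over the teacher's random hard label, and compares it with $\mathrm{ce}(\mathrm{mix}(g),\mathrm{mix}(f))$. It is a spot check on one instance, not a proof.

```python
import math

def top(z, k):
    idx = sorted(range(len(z)), key=lambda i: (-z[i], i))[:k]
    return [1.0 if i in idx else 0.0 for i in range(len(z))]

def mix(f, ys, alpha, k):
    t = top(ys, k)
    return [alpha * fi + (1.0 - alpha) * ti * (1.0 - fi) / (k - 1) for fi, ti in zip(f, t)]

def cross_entropy(target, pred, eps=1e-12):
    return -sum(t * math.log(max(p, eps)) for t, p in zip(target, pred))

# Toy instance of Definition 3.2: L = 4 classes, ground truth is class 1,
# and the teacher's top-3 classes are {0, 1, 2}.
ys = [0.5, 0.3, 0.15, 0.05]
g = [0.0, 1.0, 0.0, 0.0]
alpha, k = 0.6, 3

# Law of the teacher's hard label: P(y = g) = alpha, and (1 - alpha) / (k - 1)
# for each of the other top-k classes.
label_probs = [0.2, 0.6, 0.2, 0.0]

def expected_slam_loss(f):
    """E[ce(y, mix(f)) | x], averaging over the teacher's random one-hot label y."""
    m = mix(f, ys, alpha, k)
    total = 0.0
    for j, pj in enumerate(label_probs):
        onehot = [1.0 if i == j else 0.0 for i in range(len(ys))]
        total += pj * cross_entropy(onehot, m)
    return total

f = [0.5, 0.3, 0.2, 0.0]  # another student prediction, supported on the top-3 classes
print(expected_slam_loss(g) < expected_slam_loss(f))  # True
```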

Estimating the Teacher’s Accuracy Statistics $\alpha(x),k(x)$ via Isotonic Regression

We first show how we estimate $\alpha(x)$ for each $x$ of dataset $B$, i.e., the dataset labeled by the teacher model. In [27] the authors empirically observed that $\alpha(x)$ correlates with metrics of the teacher’s confidence such as the “margin”, i.e., the difference between the probabilities assigned to the top-1 class and to the second-largest class according to the teacher’s soft label $y_{s}$. In particular, the larger the margin, the more likely it is that the corresponding teacher label is correct. We exploit this monotonicity by employing isotonic regression on a small validation dataset to learn the mapping from the teacher’s margin at an example $x$ to the corresponding teacher’s accuracy $\alpha(x)$. For more details, see Section B.1.

To perform this regression task we use a small validation dataset $V$ with correct labels that the teacher has not seen during training. For every example $x\in V$ we compute the corresponding soft teacher label $y_{s}(x)$ and its margin $\mathrm{margin}(x)=\max_{1}(y_{s}(x))-\max_{2}(y_{s}(x))$, where $\max_{1}$ and $\max_{2}$ denote the largest and second-largest entries. For every $x\in V$ we also compute the hard prediction of the teacher and compare it with the ground-truth, i.e., for every $x\in V$ the covariate and response pair is $(\mathrm{margin}(x),1-\mathrm{err}(g(x),y(x)))$. We then use isotonic regression to fit a piecewise constant, non-decreasing function to the data. We remark that isotonic regression can be implemented very efficiently in $O(n\log n)$ time (where $n$ is the size of the validation dataset).
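The text does not prescribe a particular isotonic-regression solver. As one possibility, the sketch below uses the standard pool-adjacent-violators (PAV) algorithm in plain Python: sorting the covariates costs $O(n\log n)$ and the pooling pass is linear (amortized), matching the bound above. The helper names and the toy validation data are hypothetical, and ties between equal margins are not pooled (a simplification); scikit-learn's `IsotonicRegression` is an off-the-shelf alternative.

```python
import bisect

def margin(ys):
    """margin(x) = max_1(y_s(x)) - max_2(y_s(x)): gap between the two largest entries."""
    first, second = sorted(ys, reverse=True)[:2]
    return first - second

def isotonic_fit(xs, rs):
    """Pool-adjacent-violators fit of a non-decreasing step function.

    xs are covariates (margins), rs are 0/1 responses (1 = teacher correct).
    Returns (breakpoints, values): each block's smallest covariate and its mean response.
    """
    order = sorted(range(len(xs)), key=lambda i: xs[i])  # O(n log n)
    blocks = []  # each block is [sum_of_responses, count, smallest_covariate]
    for i in order:
        blocks.append([rs[i], 1, xs[i]])
        # Pool the last two blocks while their means violate monotonicity.
        while len(blocks) > 1 and blocks[-2][0] / blocks[-2][1] > blocks[-1][0] / blocks[-1][1]:
            total, count, _ = blocks.pop()
            blocks[-1][0] += total
            blocks[-1][1] += count
    return [b[2] for b in blocks], [b[0] / b[1] for b in blocks]

def predict_alpha(breakpoints, values, m):
    """Evaluate the fitted step function at margin m (constant outside the data range)."""
    j = bisect.bisect_right(breakpoints, m) - 1
    return values[max(j, 0)]

# Toy validation data: teacher margins and whether the teacher's top-1 label was correct.
breaks, values = isotonic_fit([0.05, 0.1, 0.2, 0.3, 0.5, 0.8], [0, 1, 0, 1, 1, 1])
alpha_hat = predict_alpha(breaks, values, margin([0.55, 0.40, 0.05]))
```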

For $k(x)$ we consider two different options: (i) using the same value for all examples (e.g., choosing $k$ so that the top-$k$ accuracy of the teacher on the validation data is above some threshold); and (ii) using a “data-dependent” $k(x)$ that we estimate by solving $L$ (recall that $L$ is the number of classes) isotonic-regression problems (similar to the one for estimating $\alpha(x)$ above). We refer to Section B.1 for more details.
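For option (i), a minimal sketch: scan $k=2,\ldots,L$ and return the smallest $k$ whose top-$k$ accuracy on the validation set reaches a threshold. The threshold values, the toy data, and the helper names are assumptions for illustration; the data-dependent option (ii) is not covered here.

```python
def topk_accuracy(soft_labels, true_indices, k):
    """Fraction of validation examples whose true class is among the teacher's top-k."""
    hits = 0
    for ys, g in zip(soft_labels, true_indices):
        top = sorted(range(len(ys)), key=lambda i: (-ys[i], i))[:k]
        hits += g in top
    return hits / len(true_indices)

def choose_global_k(soft_labels, true_indices, threshold):
    """Smallest k in {2, ..., L} whose validation top-k accuracy is >= threshold."""
    L = len(soft_labels[0])
    for k in range(2, L + 1):
        if topk_accuracy(soft_labels, true_indices, k) >= threshold:
            return k
    return L  # unreachable in practice: top-L accuracy is always 1

# Toy validation set: teacher soft labels and true class indices.
V = [[0.5, 0.3, 0.2, 0.0], [0.1, 0.6, 0.2, 0.1], [0.4, 0.35, 0.15, 0.1]]
g = [1, 2, 3]
k_loose = choose_global_k(V, g, threshold=0.6)
k_strict = choose_global_k(V, g, threshold=0.9)
```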

Figure 1: Learning $\alpha(x)$ via isotonic regression. The data were generated by a ResNet-110 teacher trained on $5000$ examples of CIFAR-100 and evaluated on a validation dataset $V$ of $500$ examples. The regression data $\{(\mathrm{margin}(y_{s}(x)),1-\mathrm{err}(y_{s}(x),g(x))):x\in V\}$ are shown in gray (the response is binary, $0/1$). By enforcing monotonicity, isotonic regression yields a more stable and robust curve than, for example, the KNN predictor.

4 Experimental Evaluation

In this section, we present our experimental results. In Section 4.1 we describe our experimental setup and in Section 4.2 we compare the performance of our method with previous approaches on standard benchmarks. In Section D.4 we show that our method can be combined with teacher-uncertainty-based reweighting techniques. Finally, due to space limitations, we provide additional empirical results in the Appendix: in Section D.5 we show that SLaM can effectively be used with distillation temperature, and in Section D.6 we consider using SLaM with other losses beyond the Cross-Entropy.

4.1 The Setup

Here, we describe our procedure for simulating knowledge distillation with unlabeled examples on academic datasets. We start by splitting the training dataset into two parts: dataset A and dataset C. We then train the teacher and student models on dataset A (using the standard cross-entropy loss).² Then we perform multiple independent trials where, for each trial, we randomly split dataset C into a small (e.g., 500 examples) validation dataset V and an unlabeled training dataset U. For each trial we (i) use the teacher model to label the points of dataset U to obtain the teacher-labeled dataset B; (ii) initialize the weights of the student to those of the student model that was pre-trained on dataset A; and (iii) train the student model (using each distillation method) on the combined labeled data of A and V (that have true labels) and the data of B (that have teacher labels). We remark here that we include the validation data V during the training of the student to be fair towards methods that do not use a validation dataset. However, while it is important that the teacher has not seen the validation data during training, including (or excluding) the validation data from the student’s training dataset did not significantly affect the performance of any method.

² We remark that our method does not require pre-training the student on dataset A; however, since [27] requires pre-training the student, we do the same for all methods that we compare.
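A minimal sketch of the data handling in one trial of this protocol, in plain Python. It assumes datasets are lists of `(x, label)` pairs and that the teacher is a callable returning a soft label; the helper names, toy data, and seeds are illustrative and not from the paper, and model training itself is omitted.

```python
import random

def split_trial(C, val_size, seed):
    """Randomly split the labeled pool C = [(x, label), ...] into V (labeled) and U (unlabeled)."""
    rng = random.Random(seed)
    idx = list(range(len(C)))
    rng.shuffle(idx)
    V = [C[i] for i in idx[:val_size]]
    U = [C[i][0] for i in idx[val_size:]]  # labels are dropped: U is unlabeled
    return V, U

def student_training_set(A, V, U, teacher):
    """Combine A and V (true labels) with B = {(x, teacher(x)) : x in U} (teacher labels)."""
    B = [(x, teacher(x)) for x in U]
    return A + V + B, B

def toy_teacher(x):
    """Stand-in teacher that always returns the same soft label."""
    return [1.0, 0.0, 0.0]

# Toy data; one independent trial per seed, as in the protocol above.
C = [(i, i % 3) for i in range(10)]
A = [(100, 0), (101, 1)]
for seed in range(3):
    V, U = split_trial(C, val_size=3, seed=seed)
    train, B = student_training_set(A, V, U, toy_teacher)
```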

4.2 Comparison with Previous Approaches

The Baselines

A natural question is whether a more sophisticated distillation method that enforces greater consistency between the teacher and the student would improve distillation with unlabeled examples: we use the VID method [2], which incorporates the penultimate layer of the student model (after a suitable trainable projection) in the loss. We also compare our method against the weighted distillation method of [27] that reweights the examples of dataset $B$ in order to “correct” the effect of the noisy pseudo-labels provided by the teacher. The Taylor cross-entropy method of [20] is a modification of CE that truncates the Taylor series of the CE loss. In [20] it was shown that it offers significant improvements when the labels are corrupted by random classification noise. The fact that the teacher’s noise is much closer to random than to adversarial makes this approach a natural baseline. The UPS loss of [48] is a semi-supervised technique that takes into account the variance (uncertainty) of the teacher model on the examples of dataset $B$ in order to transform the soft pseudo-labels provided by the teacher into more “robust” binary vectors, and then uses a modified binary CE loss. To estimate the uncertainty of the teacher model, we used either dropout with Monte-Carlo estimation or random data augmentations as suggested in [48]. We remark that, as discussed in Section 1 and Section 2, strictly speaking this method is not applicable in our setting because it requires multiple forward passes of the teacher model to estimate its variance, but we implement it as it is a relevant approach that aims to improve the pseudo-labels of the teacher.
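To illustrate what truncating the Taylor series of CE means: for a predicted probability $p\in(0,1]$ of the labeled class, $-\log p=\sum_{i\ge 1}(1-p)^{i}/i$, and keeping only the first $t$ terms gives a bounded surrogate. The sketch below assumes the truncation is applied to the probability of the (pseudo-)labeled class; the order $t$ is a hyperparameter, and the exact formulation and hyperparameters used in [20] may differ.

```python
import math

def taylor_ce(p_label, order):
    """Order-t truncation of -log(p) = sum_{i>=1} (1 - p)^i / i, for p in [0, 1]."""
    return sum((1.0 - p_label) ** i / i for i in range(1, order + 1))

# Higher orders approach the usual cross-entropy -log(p) ...
approx = taylor_ce(0.5, order=60)   # close to math.log(2)
# ... while any finite order stays bounded as p -> 0 (at most the harmonic number H_t).
bounded = taylor_ce(0.0, order=2)   # 1 + 1/2
```

The bounded value at $p=0$ is one intuition for why such truncated losses are less sensitive to confidently wrong (noisy) labels than the unbounded $-\log p$.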

CIFAR-{10,100} and CelebA

Figure 2: Comparison of distillation methods on CIFAR-{10,100} and CelebA. On the horizontal axis we plot the size of Dataset A as a percentage of the whole training dataset. On the vertical axis we plot the accuracy of the trained student model on the test dataset.

Here we present our results on CIFAR-{10, 100} [30] and CelebA [22]. CIFAR-10 and CIFAR-100 are image classification datasets with 10 and 100 classes, respectively. They contain 60000 labeled images, which are split into a training set of 50000 images and a test set of 10000 images. From the 50000 images of the training set we use $10\%, 15\%, 20\%, 25\%, 30\%, 35\%$ (i.e., 5000, 7500, 10000, 12500, 15000, and 17500 examples) as the labeled dataset A on which we train the teacher and pre-train the student models. For each size of dataset A, we perform a random split on the remaining training data and use 500 labeled examples as the validation dataset and the remaining examples as the unlabeled dataset U. For the CIFAR-10 experiments, we use a MobileNet with depth multiplier 2 as the teacher and a MobileNet with depth multiplier 1 as the student. For CIFAR-100, we use a ResNet-110 as the teacher and a ResNet-56 as the student. We compare the methods on both soft- and hard-distillation. For each trial we train the student model for $200$ epochs and keep the best test accuracy over all epochs. We perform 3 trials and report the average of each method and the variance of the achieved accuracies over the trials. The results of our experiments for soft-distillation can be found in Table 1 and Table 2, and the corresponding plots are given in Figure 2. We include our results for hard-distillation in Section D.2.

Table 1: Experiments on CIFAR-10 (soft-distillation). See Section 4.2 for details.
| Labeled Examples | 5000 | 7500 | 10000 | 12500 | 15000 | 17500 |
|---|---|---|---|---|---|---|
| Teacher | 61.30 | 68.98 | 72.42 | 73.92 | 76.63 | 78.63 |
| Vanilla | 63.53 ± 0.29 | 70.39 ± 0.11 | 73.23 ± 0.15 | 74.29 ± 0.25 | 76.64 ± 0.20 | 78.63 ± 0.16 |
| Taylor-CE [20] | 64.07 ± 0.26 | 71.19 ± 0.17 | 74.18 ± 0.25 | 74.65 ± 0.24 | 77.17 ± 0.04 | 78.67 ± 0.13 |
| UPS [48] | 64.56 ± 0.13 | 71.10 ± 0.34 | 74.17 ± 0.06 | 75.05 ± 0.24 | 77.64 ± 0.12 | **79.21 ± 0.27** |
| VID [3] | 63.76 ± 0.13 | 70.58 ± 0.17 | 73.77 ± 0.40 | 74.95 ± 0.21 | 77.25 ± 0.06 | 78.23 ± 0.09 |
| Weighted [27] | 63.85 ± 0.13 | 71.04 ± 0.24 | 73.64 ± 0.36 | 75.00 ± 0.17 | 77.40 ± 0.17 | 78.93 ± 0.19 |
| SLaM (Ours) | **66.82 ± 0.61** | **72.61 ± 0.30** | **75.01 ± 0.25** | **75.72 ± 0.17** | **78.04 ± 0.16** | **79.22 ± 0.11** |
Table 2: Experiments on CIFAR-100 (soft-distillation). See Section 4.2 for details. Entries marked with * are the best results in each column.
Labeled Examples | 5000 | 7500 | 10000 | 12500 | 15000 | 17500
Teacher | 35.97 | 44.65 | 49.62 | 55.68 | 59.19 | 62.05
Vanilla | 37.94±0.10 | 46.42±0.24 | 52.17±0.21 | 57.72±0.17 | 60.91±0.07 | 63.47±0.23
Taylor-CE [20] | 40.18±0.07 | 48.05±0.29 | 54.08±0.24 | 58.45±0.17 | 61.13±0.10 | 63.54±0.26
UPS [48] | 39.62±0.23 | 48.48±0.15 | 54.43±0.27 | 58.17±0.07 | 60.74±0.10 | 62.13±0.12
VID [3] | 38.93±0.39 | 46.76±0.10 | 52.56±0.17 | 57.94±0.37 | 61.14±0.28 | 63.56±0.18
Weighted [27] | 38.63±0.32 | 47.11±0.29 | 53.16±0.25 | 58.20±0.11 | 61.29±0.15* | 63.58±0.07
SLaM (Ours) | 42.72±0.30* | 49.89±0.23* | 54.73±0.27* | 58.78±0.15* | 61.30±0.09* | 63.98±0.19*

We consider the male/female binary classification task on the CelebA dataset [22], which consists of a training set of 162770 images and a test set of 19962 images. We use a MobileNet with depth multiplier 2 as the teacher and a ResNet-11 as the student. As the labeled dataset A we use 2%, 3%, 4%, 5%, 6%, or 7% (i.e., 3256, 4883, 6510, 8138, 9766, or 11394 examples) of the training dataset, and we split the remaining data into a validation dataset of 500 examples and an unlabeled dataset U. Our results for CelebA can be found in Table 3 (soft-distillation) and Table 7 (hard-distillation), and the corresponding plots in Figure 2. Due to space limitations, our hard-distillation results are deferred to Section D.2.

Taken together, our comparisons show that SLaM performs best or on par with the best baseline in all of these settings except CelebA with 2% labeled examples, often by a large margin. The reader is referred to Section D.1 for additional details.

Remark 4.1 (Soft-Distillation and Temperature Scaling).

We remark that in these comparisons we performed soft-distillation with temperature set to 1, i.e., we do not scale the teacher and student logits of any example. In Section D.5 we show that our method can readily be combined with temperature scaling to improve the accuracy of the student model.
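To make the role of the temperature concrete, here is a minimal NumPy sketch of a temperature-scaled soft-distillation cross-entropy. It assumes access to teacher and student logits, the function names are hypothetical, and the common practice of rescaling the loss by T² is omitted. With `temperature=1.0` the logits are left unscaled, which corresponds to the setting of the comparisons above.

```python
import numpy as np


def log_softmax(logits, temperature=1.0):
    """Log-softmax of logits / temperature, computed in a numerically stable way."""
    z = np.asarray(logits, dtype=float) / temperature
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))


def soft_distillation_loss(teacher_logits, student_logits, temperature=1.0):
    """Mean cross-entropy between temperature-scaled teacher and student distributions."""
    teacher_probs = np.exp(log_softmax(teacher_logits, temperature))
    student_log_probs = log_softmax(student_logits, temperature)
    return float(-(teacher_probs * student_log_probs).sum(axis=-1).mean())


teacher = np.array([[2.0, 0.5, -1.0]])
student = np.array([[1.0, 1.0, 0.0]])
for t in (1.0, 2.0, 4.0):
    print(t, soft_distillation_loss(teacher, student, temperature=t))
```

Larger temperatures flatten both distributions, so the teacher's probabilities on non-top classes carry more weight in the loss.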

Table 3: Experiments on CelebA (soft-distillation). See Section 4.2 for details. Entries marked with * are the best results in each column.
Labeled Examples | 2% | 3% | 4% | 5% | 6% | 7%
Teacher | 86.19 | 88.25 | 88.95 | 91.31 | 92.09 | 92.62
Vanilla | 89.96±0.08 | 91.55±0.14 | 92.16±0.10 | 93.42±0.06 | 93.98±0.04 | 94.29±0.03
Taylor-CE [20] | 90.80±0.07* | 92.23±0.1* | 92.56±0.14 | 93.80±0.20 | 94.17±0.07 | 94.47±0.01
UPS [48] | 89.96±0.11 | 92.03±0.09 | 92.44±0.04 | 93.9±0.05 | 94.28±0.07 | 94.68±0.03
VID [3] | 89.91±0.10 | 91.75±0.21 | 92.21±0.10 | 93.67±0.21 | 94.15±0.07 | 94.33±0.16
Weighted [27] | 89.92±0.12 | 91.73±0.09 | 92.31±0.22 | 93.64±0.10 | 93.93±0.14 | 94.23±0.11
SLaM (Ours) | 90.37±0.17 | 92.25±0.11* | 92.74±0.17* | 94.06±0.07* | 94.39±0.10* | 94.75±0.08*

ImageNet

Here we present our results on ImageNet [49]. ImageNet is an image classification dataset with 1000 classes, consisting of a training set of approximately 1.3 million images and a test set of 50000 images. From the 1.3 million training images we use 5%, 10%, 15%, 20%, 25%, or 30% as the labeled dataset A (for 5%, 10%, 15%, and 20%, this corresponds to 64058, 128116, 192174, and 256232 examples, respectively), on which we train the teacher and pre-train the student models. For each size of dataset A, we randomly split the remaining training data, using 10000 labeled examples as the validation dataset and the remaining examples as the unlabeled dataset U. We use a ResNet-50 as the teacher and a ResNet-18 as the student, and we compare the methods on soft-distillation. For each trial, we train the student model for 100 epochs and keep the best test accuracy over all epochs. We perform 4 trials and report the average accuracy of each method and the variance of the achieved accuracies over the trials. Our results for ImageNet can be found in Table 4. We do not include the UPS method in Table 4 because it did not improve over the accuracy achieved by pre-training the student model on dataset A. The reader is referred to Section D.1 for additional details.

Table 4: Experiments on ImageNet (soft-distillation). See Section 4.2 for details. Entries marked with * are the best results in each column.
Labeled Examples | 5% | 10% | 15% | 20% | 25% | 30%
Teacher | 39.48 | 52.96 | 59.64 | 63.62 | 66.00 | 67.85
Vanilla | 41.67±0.05 | 55.9±0.06 | 62.3±0.09 | 65.91±0.05 | 67.98±0.07 | 69.12±0.08
Taylor-CE [20] | 41.61±0.06 | 56.43±0.06 | 62.38±0.11 | 65.86±0.08 | 67.70±0.22 | 68.62±0.07
VID [3] | 40.12±0.04 | 52.75±0.04 | 58.01±0.03 | 61.21±0.06 | 62.37±0.06 | 63.05±0.07
Weighted [27] | 41.67±0.04 | 55.96±0.07 | 62.29±0.08 | 65.91±0.05 | 67.96±0.06 | 69.16±0.08*
SLaM (Ours) | 48.1±0.05* | 59.51±0.07* | 64.08±0.06* | 66.72±0.11* | 68.17±0.07* | 69.07±0.05

Large Movie Review Dataset

Here we present results on the Large Movie Review Dataset [39], a binary sentiment classification dataset containing 25000 movie reviews for training and 25000 for testing. We use an ALBERT-large model [32] as the teacher and an ALBERT-base model as the student. We use 2%, 4%, 8%, or 40% (i.e., 500, 1000, 2000, or 10000 examples) of the training dataset as the labeled dataset A, and we split the remaining data into a validation dataset of 500 examples and an unlabeled dataset U. Our results and more experimental details can be found in Section D.3.

Performance Gains of SLaM as a Function of the Number of Labeled Examples

In our experiments, the fraction of examples we consider “labeled” controls two things at the same time: (i) the accuracy of the teacher model, since the teacher is trained on the available labeled examples; and (ii) the number of unlabeled examples for which the teacher model provides pseudo-labels. The less accurate the teacher model is, the larger the improvements provided by our method. (Given a “perfect” teacher that never generates incorrect pseudo-labels for the unlabeled examples, our method is mathematically equivalent to the “vanilla” approach; see the mixing operation in Equation 1.) Therefore, the fewer labeled examples are available, the bigger the performance gains of SLaM, because (i) the teacher is less accurate; and (ii) it has to generate labels for more unlabeled examples, so the absolute number of inaccurate predictions that SLaM “corrects” increases in expectation. It is worth emphasizing that the main reason behind the enormous success of distillation is exactly that the teacher network can blow up the size of the student’s training dataset: in practice, the ratio of labeled to unlabeled examples is typically (much) less than 1%.
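The equivalence with vanilla distillation for a perfect teacher can be checked directly on the unnormalized mixing operation of Remark B.1: setting α(x) = 1 removes the second term. The same holds if that term is normalized by k(x) − 1, since it carries the factor 1 − α(x).

```latex
\mathrm{mix}\big(f(x;w);\alpha(x),k(x)\big)
  = \alpha(x)\,f(x;w) + \big(1-\alpha(x)\big)\big(1-f(x;w)\big)*\mathrm{top}\big(y_s(x);k(x)\big)
\\[4pt]
\alpha(x)=1 \;\Longrightarrow\;
\mathrm{mix}\big(f(x;w);1,k(x)\big) = f(x;w) + 0 = f(x;w),
\qquad
\ell\big(y,\mathrm{mix}(f(x;w);1,k(x))\big) = \ell\big(y,f(x;w)\big).
```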

5 Distilling Linear Models and Learning Noisy Halfspaces

In this section we show that, when the dataset is separable by a halfspace, i.e., for every example x the ground truth is g(x) = (𝟏{w*·x > 0}, 𝟏{w*·x ≤ 0}) for some unknown weight vector w*, then using SLaM with a linear model as the student recovers the ground-truth classifier. We make the standard assumption that the ground-truth halfspace has γ-margin, i.e., that ‖w*‖_2 = 1 and |w*·x| ≥ γ for all examples x. For a fixed example x, the observed noisy teacher label y satisfies Definition 3.2, i.e., y = g(x) with probability α(x) and y = 1 − g(x) with probability 1 − α(x) (since k = 2 for binary classification). Our approach uses the standard cross-entropy loss ce(p, q) to train a student model consisting of a linear layer followed by a softmax activation, i.e., f(x;w) = (1/(1 + e^{−w·x}), e^{−w·x}/(1 + e^{−w·x})).
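For this linear student, the SLaM cross-entropy and its gradient have a simple closed form, sketched below as a worked derivation. We write σ(z) = 1/(1 + e^{−z}), so that p(w) = σ(w·x) is the first coordinate of f(x;w), and y ∈ {0, 1} denotes the first coordinate of the noisy teacher label. As an assumption of this sketch, in the binary case the mixing operation reduces to α(x) f(x;w) + (1 − α(x))(1 − f(x;w)), i.e., top(y_s(x); 2) is taken to be the all-ones vector.

```latex
q(w) = \alpha(x)\,p(w) + \big(1-\alpha(x)\big)\big(1-p(w)\big)
     = \big(1-\alpha(x)\big) + \big(2\alpha(x)-1\big)\,p(w)
\\[4pt]
\mathrm{ce}(w) = -\,y\log q(w) - (1-y)\log\big(1-q(w)\big)
\\[4pt]
\partial_w\,\mathrm{ce}(w)
  = \frac{q(w)-y}{q(w)\big(1-q(w)\big)}\,\big(2\alpha(x)-1\big)\,p(w)\big(1-p(w)\big)\,x
```

Since q(w) always lies between 1 − α(x) and α(x), the loss is bounded whenever 0 < α(x) < 1, even on mislabeled examples; for α(x) = 1/2 the gradient is identically zero, as the teacher label then carries no information.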

Theorem 5.1 (SLaM Convergence).

Let X be a distribution on ℝ^d and let g(x) be the ground-truth halfspace with normal vector w* ∈ ℝ^d. Let D be the distribution over (noisy) teacher-labeled examples (x, y) whose x-marginal is X. Assume that there exist β, γ > 0 such that for all examples x in the support of X it holds that |w*·x| ≥ γ and |1/2 − α(x)| ≥ β. Let ε > 0. After T = O(1/(β²γ²ε²)) SGD iterations on the SLaM objective (see Algorithm 3), with probability at least 99%, there exists an iteration t ≤ T such that P_{x∼X}[err(f(x; w^{(t)}), g(x))] ≤ ε.

Remark 5.2 (Learning Halfspaces with RCN).

The problem of learning halfspaces with Random Classification Noise (RCN) can be modeled as having a teacher with constant accuracy probability, i.e., α(x) = α > 1/2 for all x. As a corollary of Theorem 5.1 (with β = α − 1/2 treated as a constant), we obtain an efficient learning algorithm for γ-margin halfspaces under RCN with sample complexity O(1/(γ²ε²)). Prior to our work, the best known sample complexity for provably learning halfspaces with RCN was Õ(1/(γ⁴ε²)) [19], obtained using the “backward loss-adjustment” of [13].
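The RCN special case lends itself to a small simulation. The following is a minimal NumPy sketch and an illustration only, not a reproduction of Algorithm 3: it assumes the accuracy α is known, draws unit-norm examples with margin γ around a hypothetical ground truth w* = e_1, flips each label with probability 1 − α, and runs one pass of constant-step SGD on the binary SLaM cross-entropy with its closed-form gradient. The helper names, step size, and data distribution are illustrative choices.

```python
import numpy as np


def sigmoid(z):
    """Numerically stable logistic function, using sigma(z) = (1 + tanh(z / 2)) / 2."""
    return 0.5 * (1.0 + np.tanh(0.5 * z))


def sample_rcn_halfspace(n, d, gamma, alpha, rng):
    """Hypothetical generator: unit-norm x with |w* . x| >= gamma, labels flipped w.p. 1 - alpha."""
    w_star = np.zeros(d)
    w_star[0] = 1.0                                       # assumed ground-truth normal vector
    kept, total = [], 0
    while total < n:
        g = rng.standard_normal((2 * n, d))
        x = g / np.linalg.norm(g, axis=1, keepdims=True)  # points on the unit sphere
        x = x[np.abs(x @ w_star) >= gamma]                # enforce the gamma-margin
        kept.append(x)
        total += len(x)
    x = np.concatenate(kept)[:n]
    clean = (x @ w_star > 0).astype(float)                # first coordinate of g(x)
    flip = rng.random(n) >= alpha                         # RCN: flip with probability 1 - alpha
    noisy = np.where(flip, 1.0 - clean, clean)
    return x, clean, noisy


def slam_sgd(x, y, alpha, lr=0.1):
    """One pass of SGD on the binary SLaM cross-entropy with a linear student."""
    w = np.zeros(x.shape[1])
    for xi, yi in zip(x, y):
        p = sigmoid(w @ xi)                          # student probability of the first class
        q = (1.0 - alpha) + (2.0 * alpha - 1.0) * p  # mixed probability alpha*p + (1-alpha)*(1-p)
        grad = (q - yi) / (q * (1.0 - q)) * (2.0 * alpha - 1.0) * p * (1.0 - p) * xi
        w -= lr * grad
    return w


rng = np.random.default_rng(0)
alpha, gamma, d = 0.8, 0.1, 5
x_train, clean_train, noisy_train = sample_rcn_halfspace(20_000, d, gamma, alpha, rng)
w_hat = slam_sgd(x_train, noisy_train, alpha)

x_test, clean_test, _ = sample_rcn_halfspace(5_000, d, gamma, alpha, rng)
test_acc = np.mean((x_test @ w_hat > 0) == (clean_test == 1.0))
noise_rate = np.mean(noisy_train != clean_train)
print(f"label noise rate: {noise_rate:.3f}, clean test accuracy: {test_acc:.3f}")
```

Theorem 5.1 only guarantees that some iterate t ≤ T is accurate; for simplicity the sketch evaluates the last iterate.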

6 Conclusion, Limitations, and Broader Impact

In this work we propose SLaM, a novel and principled method for improving distillation with unlabeled examples. We empirically show that SLaM consistently outperforms the baselines, often by a large margin. We also show that SLaM can be combined with, and improves, (i) knowledge distillation with temperature scaling; (ii) loss functions beyond the standard cross-entropy loss; and (iii) confidence-based weighting schemes that down-weight examples where the teacher model is not very confident. Apart from an extensive experimental evaluation, we provide strong theoretical guarantees establishing the consistency and optimality of SLaM. As a byproduct of our theoretical analysis, we obtain a new algorithm for learning γ-margin halfspaces with RCN that improves the best known sample complexity for this problem.

A limitation of SLaM is that it does not necessarily improve over vanilla distillation when the teacher model makes only a few mistakes (this is to be expected, as our method is designed for the case where the teacher model is imperfect). Moreover, while our theoretical result for learning noisy halfspaces improves the theoretical state of the art, it does not match the information-theoretic lower bound of Ω(1/(γ²ε)). An interesting question for future work is whether the sample complexity can be further improved to match this lower bound.

Knowledge distillation is a very popular deep learning method; therefore, potential malicious use of our work is an important societal issue, as deep learning has far-reaching applications, from NLP to robotics and self-driving cars.

References

  • [1] Martín Abadi, Ashish Agarwal, Paul Barham, Eugene Brevdo, Zhifeng Chen, Craig Citro, Greg S Corrado, Andy Davis, Jeffrey Dean, Matthieu Devin, et al. TensorFlow: Large-scale machine learning on heterogeneous distributed systems. arXiv preprint arXiv:1603.04467, 2016.
  • [2] Sungsoo Ahn, Shell Xu Hu, Andreas Damianou, Neil D Lawrence, and Zhenwen Dai. Variational information distillation for knowledge transfer. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 9163–9171, 2019.
  • [3] Sungsoo Ahn, Shell Xu Hu, Andreas Damianou, Neil D. Lawrence, and Zhenwen Dai. Variational information distillation for knowledge transfer. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), June 2019.
  • [4] Ehsan Amid, Manfred KK Warmuth, Rohan Anil, and Tomer Koren. Robust bi-tempered logistic loss based on Bregman divergences. Advances in Neural Information Processing Systems, 32, 2019.
  • [5] D. Angluin and P. Laird. Learning from noisy examples. Machine Learning, 2(4):343–370, 1988.
  • [6] Eric Arazo, Diego Ortego, Paul Albert, Noel E O’Connor, and Kevin McGuinness. Pseudo-labeling and confirmation bias in deep semi-supervised learning. In 2020 International Joint Conference on Neural Networks (IJCNN), pages 1–8. IEEE, 2020.
  • [7] Noga Bar, Tomer Koren, and Raja Giryes. Multiplicative reweighting for robust neural network optimization. arXiv preprint arXiv:2102.12192, 2021.
  • [8] Cenk Baykal, Khoa Trinh, Fotis Iliopoulos, Gaurav Menghani, and Erik Vee. Robust active distillation. International Conference on Learning Representations (ICLR), 2023.
  • [9] David Berthelot, Nicholas Carlini, Ian Goodfellow, Nicolas Papernot, Avital Oliver, and Colin A Raffel. MixMatch: A holistic approach to semi-supervised learning. Advances in Neural Information Processing Systems, 32, 2019.
  • [10] A. Blum, A. Frieze, R. Kannan, and S. Vempala. A polynomial-time algorithm for learning noisy linear threshold functions. Algorithmica, 22(1):35–52, 1998.
  • [11] Tom Brown, Benjamin Mann, Nick Ryder, Melanie Subbiah, Jared D Kaplan, Prafulla Dhariwal, Arvind Neelakantan, Pranav Shyam, Girish Sastry, Amanda Askell, et al. Language models are few-shot learners. Advances in Neural Information Processing Systems, 33:1877–1901, 2020.
  • [12] Cristian Buciluǎ, Rich Caruana, and Alexandru Niculescu-Mizil. Model compression. In Proceedings of the 12th ACM SIGKDD international conference on Knowledge discovery and data mining, pages 535–541, 2006.
  • [13] T. Bylander. Learning linear threshold functions in the presence of classification noise. In Proceedings of the seventh annual conference on Computational learning theory, pages 340–347, 1994.
  • [14] Hanting Chen, Yunhe Wang, Chang Xu, Chao Xu, and Dacheng Tao. Learning student networks via feature embedding. IEEE Transactions on Neural Networks and Learning Systems, 32(1):25–35, 2020.
  • [15] Liqun Chen, Dong Wang, Zhe Gan, Jingjing Liu, Ricardo Henao, and Lawrence Carin. Wasserstein contrastive representation distillation. In Proceedings of the IEEE/CVF conference on computer vision and pattern recognition, pages 16296–16305, 2021.
  • [16] Ting Chen, Simon Kornblith, Kevin Swersky, Mohammad Norouzi, and Geoffrey E Hinton. Big self-supervised models are strong semi-supervised learners. Advances in Neural Information Processing Systems, 33:22243–22255, 2020.
  • [17] Mostafa Dehghani, Arash Mehrjou, Stephan Gouws, Jaap Kamps, and Bernhard Schölkopf. Fidelity-weighted learning. arXiv preprint arXiv:1711.02799, 2017.
  • [18] Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. BERT: Pre-training of deep bidirectional transformers for language understanding. arXiv preprint arXiv:1810.04805, 2018.
  • [19] I. Diakonikolas, T. Gouleakis, and C. Tzamos. Distribution-independent PAC learning of halfspaces with Massart noise. Advances in Neural Information Processing Systems, 32, 2019.
  • [20] Lei Feng, Senlin Shu, Zhuoyi Lin, Fengmao Lv, Li Li, and Bo An. Can cross entropy loss be robust to label noise? In Proceedings of the Twenty-Ninth International Conference on International Joint Conferences on Artificial Intelligence, pages 2206–2212, 2021.
  • [21] Benoît Frénay and Michel Verleysen. Classification in the presence of label noise: a survey. IEEE Transactions on Neural Networks and Learning Systems, 25(5):845–869, 2013.
  • [22] Tommaso Furlanello, Zachary Lipton, Michael Tschannen, Laurent Itti, and Anima Anandkumar. Born again neural networks. In International Conference on Machine Learning, pages 1607–1616. PMLR, 2018.
  • [23] Dragan Gamberger, Nada Lavrac, and Ciril Groselj. Experiments with noise filtering in a medical domain. In ICML, volume 99, pages 143–151, 1999.
  • [24] Aritra Ghosh, Himanshu Kumar, and P Shanti Sastry. Robust loss functions under label noise for deep neural networks. In Proceedings of the AAAI conference on artificial intelligence, volume 31, 2017.
  • [25] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition, pages 770–778, 2016.
  • [26] Geoffrey E. Hinton, Oriol Vinyals, and Jeffrey Dean. Distilling the knowledge in a neural network. CoRR, abs/1503.02531, 2015.
  • [27] Fotis Iliopoulos, Vasilis Kontonis, Cenk Baykal, Gaurav Menghani, Khoa Trinh, and Erik Vee. Weighted distillation with unlabeled examples. In NeurIPS, 2022.
  • [28] Lu Jiang, Zhengyuan Zhou, Thomas Leung, Li-Jia Li, and Li Fei-Fei. MentorNet: Learning data-driven curriculum for very deep neural networks on corrupted labels. In International Conference on Machine Learning, pages 2304–2313. PMLR, 2018.
  • [29] Akisato Kimura, Zoubin Ghahramani, Koh Takeuchi, Tomoharu Iwata, and Naonori Ueda. Few-shot learning of neural networks from scratch by pseudo example optimization. arXiv preprint arXiv:1802.03039, 2018.
  • [30] Alex Krizhevsky, Geoffrey Hinton, et al. Learning multiple layers of features from tiny images. 2009.
  • [31] Abhishek Kumar and Ehsan Amid. Constrained instance and class reweighting for robust learning under label noise. arXiv preprint arXiv:2111.05428, 2021.
  • [32] Zhenzhong Lan, Mingda Chen, Sebastian Goodman, Kevin Gimpel, Piyush Sharma, and Radu Soricut. ALBERT: A lite BERT for self-supervised learning of language representations. arXiv preprint arXiv:1909.11942, 2019.
  • [33] Hunter Lang, Aravindan Vijayaraghavan, and David Sontag. Training subset selection for weak supervision. arXiv preprint arXiv:2206.02914, 2022.
  • [34] Dong-Hyun Lee et al. Pseudo-label: The simple and efficient semi-supervised learning method for deep neural networks. In Workshop on challenges in representation learning, ICML, volume 3, page 896, 2013.
  • [35] Zhaoqi Leng, Mingxing Tan, Chenxi Liu, Ekin Dogus Cubuk, Jay Shi, Shuyang Cheng, and Dragomir Anguelov. PolyLoss: A polynomial expansion perspective of classification loss functions. In International Conference on Learning Representations, 2022.
  • [36] Lu Liu and Robby T Tan. Certainty driven consistency loss on multi-teacher networks for semi-supervised learning. Pattern Recognition, 120:108140, 2021.
  • [37] Tongliang Liu and Dacheng Tao. Classification with noisy labels by importance reweighting. IEEE Transactions on Pattern Analysis and Machine Intelligence, 38(3):447–461, 2015.
  • [38] Michal Lukasik, Srinadh Bhojanapalli, Aditya Menon, and Sanjiv Kumar. Does label smoothing mitigate label noise? In International Conference on Machine Learning, pages 6448–6458. PMLR, 2020.
  • [39] Andrew L. Maas, Raymond E. Daly, Peter T. Pham, Dan Huang, Andrew Y. Ng, and Christopher Potts. Learning word vectors for sentiment analysis. In Proceedings of the 49th Annual Meeting of the Association for Computational Linguistics: Human Language Technologies, pages 142–150, Portland, Oregon, USA, June 2011. Association for Computational Linguistics.
  • [40] Negin Majidi, Ehsan Amid, Hossein Talebi, and Manfred K Warmuth. Exponentiated gradient reweighting for robust training under label noise and beyond. arXiv preprint arXiv:2104.01493, 2021.
  • [41] Rafael Müller, Simon Kornblith, and Geoffrey Hinton. Subclass distillation. arXiv preprint arXiv:2002.03936, 2020.
  • [42] Nagarajan Natarajan, Inderjit S Dhillon, Pradeep K Ravikumar, and Ambuj Tewari. Learning with noisy labels. Advances in Neural Information Processing Systems, 26, 2013.
  • [43] Giorgio Patrini, Alessandro Rozza, Aditya Krishna Menon, Richard Nock, and Lizhen Qu. Making deep neural networks robust to label noise: A loss correction approach. In Proceedings of the IEEE conference on computer vision and pattern recognition, pages 1944–1952, 2017.
  • [44] Hieu Pham, Zihang Dai, Qizhe Xie, and Quoc V Le. Meta pseudo labels. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 11557–11568, 2021.
  • [45] Geoff Pleiss, Tianyi Zhang, Ethan Elenberg, and Kilian Q Weinberger. Identifying mislabeled data using the area under the margin ranking. Advances in Neural Information Processing Systems, 33:17044–17056, 2020.
  • [46] Ilija Radosavovic, Piotr Dollár, Ross Girshick, Georgia Gkioxari, and Kaiming He. Data distillation: Towards omni-supervised learning. In Proceedings of the IEEE conference on computer vision and pattern recognition, pages 4119–4128, 2018.
  • [47] Mengye Ren, Wenyuan Zeng, Bin Yang, and Raquel Urtasun. Learning to reweight examples for robust deep learning. In International conference on machine learning, pages 4334–4343. PMLR, 2018.
  • [48] Mamshad Nayeem Rizve, Kevin Duarte, Yogesh S Rawat, and Mubarak Shah. In defense of pseudo-labeling: An uncertainty-aware pseudo-label selection framework for semi-supervised learning. In International Conference on Learning Representations (ICLR), 2021.
  • [49] Olga Russakovsky, Jia Deng, Hao Su, Jonathan Krause, Sanjeev Satheesh, Sean Ma, Zhiheng Huang, Andrej Karpathy, Aditya Khosla, Michael Bernstein, et al. ImageNet large scale visual recognition challenge. International Journal of Computer Vision, 115(3):211–252, 2015.
  • [50] Hwanjun Song, Minseok Kim, Dongmin Park, Yooju Shin, and Jae-Gil Lee. Learning from noisy labels with deep neural networks: A survey. IEEE Transactions on Neural Networks and Learning Systems, 2022.
  • [51] Samuel Stanton, Pavel Izmailov, Polina Kirichenko, Alexander A Alemi, and Andrew G Wilson. Does knowledge distillation really work? Advances in Neural Information Processing Systems, 34:6906–6919, 2021.
  • [52] Yonglong Tian, Dilip Krishnan, and Phillip Isola. Contrastive representation distillation. arXiv preprint arXiv:1910.10699, 2019.
  • [53] Qizhe Xie, Minh-Thang Luong, Eduard Hovy, and Quoc V Le. Self-training with noisy student improves ImageNet classification. In Proceedings of the IEEE/CVF conference on computer vision and pattern recognition, pages 10687–10698, 2020.
  • [54] I Zeki Yalniz, Hervé Jégou, Kan Chen, Manohar Paluri, and Dhruv Mahajan. Billion-scale semi-supervised learning for image classification. arXiv preprint arXiv:1905.00546, 2019.
  • [55] Xiangli Yang, Zixing Song, Irwin King, and Zenglin Xu. A survey on deep semi-supervised learning. IEEE Transactions on Knowledge and Data Engineering, 2022.
  • [56] Zhilu Zhang and Mert Sabuncu. Generalized cross entropy loss for training deep neural networks with noisy labels. Advances in Neural Information Processing Systems, 31, 2018.
  • [57] Barret Zoph, Golnaz Ghiasi, Tsung-Yi Lin, Yin Cui, Hanxiao Liu, Ekin Dogus Cubuk, and Quoc Le. Rethinking pre-training and self-training. Advances in Neural Information Processing Systems, 33:3833–3845, 2020.
  • [58] Yang Zou, Zhiding Yu, Xiaofeng Liu, BVK Kumar, and Jinsong Wang. Confidence regularized self-training. In Proceedings of the IEEE/CVF International Conference on Computer Vision, pages 5982–5991, 2019.

Appendix A Notation

For two vectors p, q ∈ ℝ^d we denote by p·q = Σ_{i=1}^{d} p_i q_i their inner product. We use p * q to denote their element-wise product, i.e., (p * q)_i = p_i q_i. We use the notation max_i p to denote the i-th largest element of the vector p. We use margin(p) to denote the difference between the top-2 elements of p, i.e., margin(p) = max_1 p − max_2 p. Moreover, we use margin_k(p) to denote the top-k margin, i.e., margin_k(p) = Σ_{i=1}^{k} max_i p − max_{k+1} p. Given a function f: ℝ^d → ℝ, we denote by ∂_w f(w) the gradient of f with respect to the parameter w.
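The margin notation can be computed directly from a probability vector. The following is a minimal NumPy sketch; the function names are hypothetical, and `max_i` is 1-indexed as in the text.

```python
import numpy as np


def max_i(p, i):
    """i-th largest element of p (1-indexed), i.e. max_i p."""
    return float(np.sort(np.asarray(p, dtype=float))[::-1][i - 1])


def margin_k(p, k=1):
    """Top-k margin: sum of the k largest elements of p minus the (k+1)-th largest."""
    s = np.sort(np.asarray(p, dtype=float))[::-1]
    return float(s[:k].sum() - s[k])


def margin(p):
    """Top-1 margin: max_1 p - max_2 p."""
    return margin_k(p, 1)


p = [0.1, 0.6, 0.3]
print(max_i(p, 2), margin(p), margin_k(p, 2))
```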

Appendix B Detailed Description of SLaM

B.1 Estimating the Teacher’s Accuracy Parameters: α(x), k(x)

Estimating the Teacher’s Accuracy α(x) via Isotonic Regression

We now turn our attention to the problem of estimating α(x) for each example x of dataset B, i.e., the dataset labeled by the teacher model. In [27] the authors empirically observed that α(x) correlates with metrics of the teacher’s confidence such as the “margin”, i.e., the difference between the probabilities that the teacher’s soft label y_s assigns to its top-1 class and to its second most likely class. In particular, the larger the margin, the more likely it is that the corresponding teacher label is correct. We exploit (and enforce) this monotonicity by employing isotonic regression on a small validation dataset to learn the mapping from the teacher’s margin at an example x to the corresponding teacher accuracy α(x).

To perform this regression task we use a small validation dataset V with correct labels that the teacher has not seen during training. For every example x ∈ V we compute the corresponding soft teacher label y_s(x) and its margin margin(x) = max_1(y_s(x)) − max_2(y_s(x)). For every x ∈ V we also compute the hard prediction y(x) of the teacher and compare it with the ground truth, i.e., for every x the covariate and response pair is (margin(x), 1 − err(g(x), y(x))). We then use isotonic regression to fit a piecewise constant, non-decreasing function to the data. After sorting the regression data {(margin(x), 1 − err(g(x), y(x))) : x ∈ V} by increasing margin to obtain a list (c^{(1)}, r^{(1)}), …, (c^{(m)}, r^{(m)}), isotonic regression solves the following task:

\min_{\hat{r}^{(1)},\ldots,\hat{r}^{(m)}}\ \sum_{i=1}^{m}\big(r^{(i)}-\hat{r}^{(i)}\big)^{2}
\text{subject to }\ \mathrm{lb}\leq\hat{r}^{(1)}\leq\hat{r}^{(2)}\leq\cdots\leq\hat{r}^{(m)}\leq 1,

where the parameter lb is a lower bound on the values r̂^{(i)} and is a hyper-parameter that we tune. The upper bound, on the other hand, can be set to 1, since the true value α(x) is at most 1 for every x (it is the probability that the teacher label is correct). After computing the values r̂^{(1)}, …, r̂^{(m)}, for any given c ∈ [0, 1] the output of the regressor is the value r̂^{(i)} corresponding to the smallest c^{(i)} that is greater than or equal to c; this is our estimate of α(x) for an example x with margin(x) = c. Computing the values r̂^{(i)} takes O(m) time after sorting the data (which takes O(m log m) time), so the whole isotonic regression task can be solved very efficiently.
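A minimal sketch of this regression step follows. It uses the pool-adjacent-violators (PAV) algorithm, which computes the least-squares non-decreasing fit in O(m) time after sorting, and then clips the fit to [lb, 1], relying on the fact that clipping the unconstrained isotonic fit to constant bounds solves the bounded problem. The function names, and returning the last fitted value when a query margin exceeds every c^{(i)}, are our own assumptions.

```python
import numpy as np


def isotonic_fit(covariates, responses, lb=0.0, ub=1.0):
    """Least-squares non-decreasing fit via pool-adjacent-violators, clipped to [lb, ub].

    Returns the sorted covariates c^(1..m) and the fitted values r_hat^(1..m).
    """
    order = np.argsort(covariates, kind="stable")
    c = np.asarray(covariates, dtype=float)[order]
    r = np.asarray(responses, dtype=float)[order]
    sums, counts = [], []                   # one entry per pooled block
    for value in r:
        sums.append(float(value))
        counts.append(1)
        # Merge adjacent blocks while their means violate monotonicity.
        while len(sums) > 1 and sums[-2] / counts[-2] > sums[-1] / counts[-1]:
            s, n = sums.pop(), counts.pop()
            sums[-1] += s
            counts[-1] += n
    fitted = np.repeat([s / n for s, n in zip(sums, counts)], counts)
    return c, np.clip(fitted, lb, ub)


def isotonic_predict(c_sorted, r_hat, query):
    """Value r_hat^(i) at the smallest c^(i) >= query (last value if query exceeds all c^(i))."""
    i = int(np.searchsorted(c_sorted, query, side="left"))
    return float(r_hat[min(i, len(r_hat) - 1)])


# Toy validation data: (teacher margin, 1 if the teacher's hard label was correct else 0).
margins = [0.1, 0.4, 0.2, 0.9, 0.7]
correct = [1, 0, 0, 1, 1]
c_sorted, r_hat = isotonic_fit(margins, correct, lb=0.5)
alpha_estimate = isotonic_predict(c_sorted, r_hat, 0.5)
```

On this toy data the unconstrained fit is (1/3, 1/3, 1/3, 1, 1), and the lower bound lb = 0.5 lifts the first block to 0.5.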

Estimating k(x).

We now describe our process for estimating the values of α(x) and k(x) for every example of dataset B. As in the binary classification setting, we estimate the accuracy probability α(x) using isotonic regression on a small validation dataset. The value of k(x) can be set to a fixed value k for all data, chosen so that the top-k accuracy of the teacher on the validation data is reasonable (say, above 60%). For example, in our ImageNet experiments we used k = 5. We also provide a data-dependent method that finds a different value k(x) for every example x. To do this, we adapt the method for estimating the top-1 accuracy α(x) of the teacher from the validation dataset. For every value of k = 2, …, L − 1 we compute the top-k margin of the teacher’s predictions on the validation data, which is equal to the sum of the top-k probabilities of the teacher soft label minus the (k+1)-th largest probability, i.e.,

\mathrm{margin}_{k}(y_{s}(x))=\Big(\sum_{i=1}^{k}\max_{i}y_{s}(x)\Big)-\max_{k+1}y_{s}(x)\,.

Using the top-k margin as the covariate and the top-k accuracy as the response, we solve the corresponding regression task with isotonic regression to obtain the value α_k(x), the probability that the true label belongs to the top-k predictions of the teacher soft label. For some threshold, say 90%, we set k(x), for every x, to be the smallest value of k such that α_k(x) ≥ 90%. We empirically observed that larger thresholds for the top-k accuracy (e.g., 90% or 95%) work better. While using the top-k margin as the covariate in the regression task is reasonable, our method can be used with other “uncertainty metrics” of the teacher’s soft labels, e.g., the entropy of the distribution y_s(x) after grouping together its top-k elements. The higher this entropy, the more likely it is that the top-k accuracy probability α_k(x) of the teacher is low.
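The per-example choice of k(x) can be sketched as follows. `topk_correct` produces the top-k accuracy responses used as regression targets on the validation set, and `choose_k` applies the threshold rule to already-estimated values α_1(x), …, α_{L−1}(x), appending α_L(x) = 1 because the true label is always among all L classes. Both helper names are hypothetical, and ties in the teacher's probabilities are broken arbitrarily by `np.argsort`.

```python
import numpy as np


def topk_correct(soft_labels, true_labels, k):
    """1.0 where the true class is among the k highest-probability classes of y_s(x), else 0.0."""
    soft_labels = np.asarray(soft_labels, dtype=float)
    true_labels = np.asarray(true_labels)
    topk = np.argsort(-soft_labels, axis=1)[:, :k]
    return (topk == true_labels[:, None]).any(axis=1).astype(float)


def choose_k(alpha_estimates, threshold=0.9):
    """Smallest k with alpha_k(x) >= threshold.

    alpha_estimates[j] is the estimated top-(j+1) accuracy for j = 0, ..., L-2;
    the top-L accuracy equals 1, so a valid k always exists when threshold <= 1.
    """
    alphas = np.append(np.asarray(alpha_estimates, dtype=float), 1.0)
    return int(np.argmax(alphas >= threshold)) + 1


soft = [[0.5, 0.3, 0.2], [0.1, 0.2, 0.7]]
print(topk_correct(soft, [1, 0], k=2))
print(choose_k([0.55, 0.80, 0.93, 0.97], threshold=0.9))
```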

B.2 SLaM for Distillation with Unlabeled Examples: Pseudocode

In this section we present pseudocode describing the setting of distillation with unlabeled examples and the SLaM method (Algorithm 1), together with the procedure for estimating the teacher’s accuracy statistics (Algorithm 2).

Remark B.1.

We remark that in our experiments we observed that not normalizing the mixing operation by k(x) − 1 gave better results overall. Therefore, the mixing operation used in our experimental evaluation of SLaM is mix(f(x;w); α(x), k(x)) = α(x) f(x;w) + (1 − α(x)) (1 − f(x;w)) * top(y_s(x); k(x)). For more details we refer the reader to the code provided in the supplementary material.
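The unnormalized mixing operation above, and the per-example SLaM term of Algorithm 1 for an example of dataset B, can be sketched in NumPy as follows. We assume that top(y_s(x); k) is the 0/1 indicator vector of the k largest entries of the teacher's soft label; the helper names and the small ε inside the logarithm are our own choices. Note that the mixed vector need not sum to 1.

```python
import numpy as np


def top_indicator(soft_label, k):
    """Assumed form of top(y_s; k): 1.0 on the k largest entries of soft_label, 0.0 elsewhere."""
    soft_label = np.asarray(soft_label, dtype=float)
    mask = np.zeros_like(soft_label)
    mask[np.argsort(-soft_label)[:k]] = 1.0
    return mask


def mix(student_probs, alpha, k, soft_label):
    """Unnormalized mixing: alpha * f + (1 - alpha) * (1 - f) * top(y_s; k)."""
    f = np.asarray(student_probs, dtype=float)
    return alpha * f + (1.0 - alpha) * (1.0 - f) * top_indicator(soft_label, k)


def slam_term(teacher_label, student_probs, alpha, k, soft_label, eps=1e-12):
    """Cross-entropy between the teacher label and the mixed student prediction."""
    mixed = mix(student_probs, alpha, k, soft_label)
    return float(-np.sum(np.asarray(teacher_label, dtype=float) * np.log(mixed + eps)))


f = np.array([0.7, 0.2, 0.1])    # student prediction f(x; w)
y_s = np.array([0.6, 0.3, 0.1])  # teacher soft label y_s(x)
print(mix(f, alpha=0.5, k=2, soft_label=y_s))
print(slam_term(y_s, f, alpha=0.5, k=2, soft_label=y_s))
```

With α(x) = 0.5 and k(x) = 2, the mixed prediction is (0.5, 0.5, 0.05): only the teacher's top-2 classes receive the (1 − α)(1 − f) correction.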

Algorithm 1 Student Label Mixing (SLaM) Distillation
  Input: Labeled Dataset A, Labeled Validation dataset V, Unlabeled Dataset U
  Output: A trained Student model $f(x;w)$
  
  Train Teacher model on Labeled Dataset A
  Pre-train Student model on Labeled Dataset A
  
  #   Label examples of Dataset U using the Teacher
  $B\leftarrow\emptyset$
  for each $x\in U$ do
     Add $(x,y_{s}(x))$ to $B$     # For hard-distillation use $y(x)$
  end for
  
  #   Learn Teacher Accuracy Statistics $\alpha(x),k(x)$ (Algorithm 2)
  $\hat{\alpha}(x),\hat{k}(x)\leftarrow\mathrm{LearnAccuracyStatistics}(y(\cdot),V,B)$
  Train student $f(x;w)$ using the SLaM loss:
$$\sum_{(x,y)\in A\cup V}\ell(y,f(x;w))+\sum_{(x,y)\in B}\ell(y,\mathrm{mix}(f(x;w);\hat{\alpha}(x),\hat{k}(x)))$$
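The SLaM loss at the end of Algorithm 1 can be written as the plain-Python sketch below. It assumes $\ell$ is cross-entropy and uses the unnormalized mixing of Remark B.1. The student is modeled as any callable returning class probabilities. The names `slam_objective` and `cross_entropy` are hypothetical, and a real implementation would use batched tensors in a deep learning framework.

```python
import math
from typing import Callable, List, Sequence, Tuple

Vector = List[float]


def cross_entropy(target: Sequence[float], pred: Sequence[float],
                  eps: float = 1e-12) -> float:
    # -sum_j target_j * log(pred_j); eps guards against log(0).
    return -sum(t * math.log(max(p, eps)) for t, p in zip(target, pred))


def mix(f: Sequence[float], alpha: float, k: int,
        teacher_soft: Sequence[float]) -> Vector:
    # Unnormalized mixing of Remark B.1: alpha*f + (1-alpha)*(1-f)*top(y_s; k).
    largest = set(sorted(range(len(f)), key=lambda i: teacher_soft[i],
                         reverse=True)[:k])
    return [alpha * fi + (1 - alpha) * (1 - fi) * (1.0 if i in largest else 0.0)
            for i, fi in enumerate(f)]


def slam_objective(student: Callable[[object], Vector],
                   human_labeled: Sequence[Tuple[object, Vector]],
                   teacher_labeled: Sequence[Tuple[object, Vector, float, int]]) -> float:
    """Cross-entropy on A ∪ V plus cross-entropy against mixed predictions on B.

    teacher_labeled holds (x, y_s(x), alpha_hat(x), k_hat(x)) tuples; the
    teacher soft-label serves both as the target and as the input to top(.; k).
    """
    loss = sum(cross_entropy(y, student(x)) for x, y in human_labeled)
    for x, y_soft, alpha_hat, k_hat in teacher_labeled:
        loss += cross_entropy(y_soft, mix(student(x), alpha_hat, k_hat, y_soft))
    return loss
```

With $\hat{\alpha}(x)=1$ the mixing is the identity, and the objective reduces to vanilla distillation on $B$.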
Algorithm 2 Estimating Teacher’s Accuracy Statistics $\alpha(x),k(x)$
  Input: (Noisy) Teacher Model $y_{s}(x)$, Labeled Validation dataset V, Isotonic-Regression lower-bound $\mathrm{lb}\in[0,1]$, and top-k accuracy threshold $t\in[0,1]$.
  Output: Estimates $\hat{\alpha}(x),\hat{k}(x)$ of the actual $\alpha(x),k(x)$.
  
  Create soft-labels for the Validation dataset using the teacher model: $\{y_{s}(x):x\in V\}$.
  for $j=1$ to $L-1$ do
     # Map $y_{s}(x)$ to (top-j margin, top-j accuracy) pairs on the Validation dataset V
     $C\leftarrow\left\{\left(\sum_{r=1}^{j}\max_{r}y_{s}(x)-\max_{j+1}y_{s}(x),~1-\mathrm{err}(y_{s}(x),z)\right):(x,z)\in V\right\}$
     Set $\hat{\alpha}_{j}(x)$ to be the output of Isotonic-Regression with lower-bound $\mathrm{lb}$ on the (covariate, response) pairs in $C$.     # See Section B.1
  end for
  $\hat{\alpha}(x)\leftarrow\hat{\alpha}_{1}(x)$
  $\hat{\alpha}_{L}(x)\leftarrow 1$     # The top-L accuracy is always (trivially) equal to 1
  Given an example $x$ and the threshold $t$, set $\hat{k}(x)$ to be the smallest integer $r\in\{1,\ldots,L\}$ such that $\hat{\alpha}_{r}(x)\geq t$.
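The last step of Algorithm 2 picks $\hat{k}(x)$ from the per-$j$ estimates. A minimal sketch follows, assuming the values $\hat{\alpha}_{1}(x),\ldots,\hat{\alpha}_{L-1}(x)$ for a fixed $x$ have already been computed (for example with isotonic fits as in Section B.1). The function name `choose_k` is hypothetical.

```python
from typing import Sequence


def choose_k(alpha_hats: Sequence[float], t: float) -> int:
    """Return the smallest r in {1, ..., L} with alpha_hat_r(x) >= t.

    alpha_hats[j - 1] holds the estimated top-j accuracy for j = 1, ..., L - 1;
    the top-L accuracy is 1 by definition, so r = L is always a valid answer.
    """
    for r, a in enumerate(alpha_hats, start=1):
        if a >= t:
            return r
    return len(alpha_hats) + 1
```

For example, with estimated top-1 to top-4 accuracies $0.7, 0.85, 0.93, 0.97$ and $t=0.9$, the sketch returns $\hat{k}(x)=3$.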

Appendix C SLaM Consistency

In the following proposition we show that any minimizer of the SLaM loss over the noisy teacher data must agree with the ground truth for all $x$ (that have positive probability). To keep the presentation simple and avoid measurability issues (e.g., having to consider measure-zero sets under $X$), we assume in the following that the example distribution $X$ is supported on a finite set. One can easily adapt the proof to any distribution $X$, in which case the result holds after excluding measure-zero sets under $X$.

Proposition C.1 (SLaM Consistency).

Let $D$ be the distribution of the teacher-labeled examples of dataset $B$, i.e., we first draw $x\sim X$ and then label it using the noisy teacher of Definition 3.2. Moreover, assume that there exists some parameter $w^{\ast}\in\mathcal{W}$ such that the ground truth satisfies $g(x)=f(x;w^{\ast})$. Denote by $\mathcal{L}^{\text{SLaM}}(w)=\operatorname{\mathbf{E}}_{(x,y)\sim D}[\ell(y,\mathrm{mix}(f(x;w);\alpha(x),k(x)))]$ the SLaM objective. The following hold true.

  1. $w^{\ast}$ minimizes the SLaM objective.

  2. Assuming further that $\alpha(x)k(x)\neq 1$ for all $x$, any minimizer $w$ of the SLaM objective satisfies $f(x;w)=g(x)$ for all $x$.

Proof.

Fix any example $x\in X$. By Definition 3.2, the corresponding teacher label $y$ is correct with probability $\alpha(x)$. Otherwise it is a uniformly random incorrect label among the top-$k(x)$ labels of the teacher soft-label $y_{s}(x)$. Recall that, for an $L$-dimensional score vector $p$, we denote by $\mathrm{top}(p;k)\in\{0,1\}^{L}$ the vector that has $1$ in the positions of the top-k elements of $p$, e.g., $\mathrm{top}((1,2,3,4,5);2)=(0,0,0,1,1)$. Conditional on $x$, the expected noisy teacher label is

$$\begin{aligned}\operatorname{\mathbf{E}}[y\mid x]&=\operatorname{\mathbf{P}}[y=g(x)\mid x]\,g(x)+\operatorname{\mathbf{P}}[y\neq g(x)\mid x]\,\operatorname{\mathbf{E}}[y\mid x,y\neq g(x)]\\&=\alpha(x)g(x)+(1-\alpha(x))\,\operatorname{\mathbf{E}}[y\mid x,y\neq g(x)]\,.\end{aligned}$$

Conditional on being wrong, the teacher label is a uniformly random incorrect label among the top-$k(x)$ labels of $y_{s}(x)$, so $\operatorname{\mathbf{E}}[y\mid x,y\neq g(x)]$ is the mean of this distribution. Assume first that $k(x)=L$. Since the ground truth is represented by a one-hot vector, the distribution of a uniformly random incorrect label conditional on $x$ can be written as $(1-g(x))/(L-1)$. For example, if the ground-truth label is $g(x)=(1,0,0,0,0)$, then a uniformly random incorrect label has probability distribution $(0,1/4,1/4,1/4,1/4)$. Assume now that $k(x)=3$ and $\mathrm{top}(y_{s}(x);3)=(1,1,1,0,0)$. Then the distribution of the (incorrect) teacher label becomes $(0,1/2,1/2,0,0)$. Using $*$ to denote element-wise multiplication of two vectors, we have

$$\operatorname{\mathbf{E}}[y\mid x,y\neq g(x)]=\frac{1-g(x)}{k(x)-1}*\mathrm{top}(y_{s}(x);k(x))\,.$$

Therefore, we obtain

$$\operatorname{\mathbf{E}}[y\mid x]=\alpha(x)g(x)+(1-\alpha(x))\frac{1-g(x)}{k(x)-1}*\mathrm{top}(y_{s}(x);k(x))=\mathrm{mix}(g(x);\alpha(x),k(x))\,.$$
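The identity $\operatorname{\mathbf{E}}[y\mid x]=\mathrm{mix}(g(x);\alpha(x),k(x))$ can be checked exactly with rational arithmetic on the example from the text: $g(x)=(1,0,0,0,0)$ and $\mathrm{top}(y_{s}(x);3)=(1,1,1,0,0)$. The value $\alpha(x)=7/10$ and the helper names are illustrative choices. The sketch assumes, as in the proof, that the ground truth lies in the teacher's top-$k(x)$ set.

```python
from fractions import Fraction
from typing import List, Set


def expected_noisy_label(g_index: int, top: Set[int], alpha: Fraction,
                         L: int) -> List[Fraction]:
    """E[y | x]: the true label w.p. alpha, otherwise uniform over the other
    labels in the teacher's top-k set (assumed to contain g_index)."""
    wrong = [j for j in sorted(top) if j != g_index]
    e = [Fraction(0)] * L
    e[g_index] += alpha
    for j in wrong:
        e[j] += (1 - alpha) / len(wrong)
    return e


def mix_normalized(f: List[Fraction], alpha: Fraction, top: Set[int]) -> List[Fraction]:
    """alpha * f + (1 - alpha) * (1 - f) / (k - 1) * top, with k = |top|."""
    k = len(top)
    return [alpha * fj + (1 - alpha) * (1 - fj) / (k - 1) * (1 if j in top else 0)
            for j, fj in enumerate(f)]


alpha = Fraction(7, 10)
g = [Fraction(1), Fraction(0), Fraction(0), Fraction(0), Fraction(0)]
print(expected_noisy_label(0, {0, 1, 2}, alpha, 5) == mix_normalized(g, alpha, {0, 1, 2}))  # True
```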

Therefore, since cross-entropy is linear in its first argument, the expected SLaM loss on an example $x$ is

$$\begin{aligned}\operatorname{\mathbf{E}}[\mathrm{ce}(y,\mathrm{mix}(f(x;w);\alpha(x),k(x)))\mid x]&=\mathrm{ce}(\operatorname{\mathbf{E}}[y\mid x],\mathrm{mix}(f(x;w);\alpha(x),k(x)))\\&=\mathrm{ce}(\mathrm{mix}(g(x);\alpha(x),k(x)),\mathrm{mix}(f(x;w);\alpha(x),k(x)))\,.\end{aligned}$$

We first show that there exists some parameter $w\in\mathcal{W}$ that matches the (expected) observed labels $\operatorname{\mathbf{E}}[y\mid x]$. By the realizability assumption, there exists $w^{\ast}$ such that $f(x;w^{\ast})=g(x)$. Hence, for every $x$, $\mathrm{mix}(g(x);\alpha(x),k(x))=\mathrm{mix}(f(x;w^{\ast});\alpha(x),k(x))$. In fact, by Gibbs' inequality (convexity of cross-entropy), $w^{\ast}$ is a (global) minimizer of the SLaM objective.

We next show that any (global) minimizer of the SLaM objective must agree with the ground truth for every $x$. Since $w^{\ast}$ matches the (expected) labels $\operatorname{\mathbf{E}}[y\mid x]$, any other minimizer $w$ must also satisfy $\mathrm{mix}(g(x);\alpha(x),k(x))=\mathrm{mix}(f(x;w);\alpha(x),k(x))$. Assume without loss of generality that $g_{0}(x)=1$, i.e., the ground-truth label is $0$. Recall that $\mathrm{mix}(g(x);\alpha(x),k(x))=\alpha(x)g(x)+(1-\alpha(x))\frac{1-g(x)}{k(x)-1}*\mathrm{top}(y_{s}(x);k(x))$. Recall also that, conditional on the teacher's top-1 prediction being incorrect, the ground truth belongs to the top-$k(x)$ of the teacher's predictions; thus $\mathrm{top}(y_{s}(x);k(x))_{0}=1$. Comparing the $0$-th coordinates of both sides, we obtain

$$\alpha(x)g_{0}(x)+(1-\alpha(x))\frac{1-g_{0}(x)}{k(x)-1}=\alpha(x)f(x;w)_{0}+(1-\alpha(x))\frac{1-f(x;w)_{0}}{k(x)-1}\,.$$

Using the fact that $g_{0}(x)=1$, we can simplify the above expression to

$$(1-f(x;w)_{0})\left(\alpha(x)-\frac{1-\alpha(x)}{k(x)-1}\right)=0\,.$$

The factor $\alpha(x)-\frac{1-\alpha(x)}{k(x)-1}$ vanishes exactly when $\alpha(x)(k(x)-1)=1-\alpha(x)$, i.e., when $\alpha(x)k(x)=1$. By the assumption $\alpha(x)k(x)\neq 1$ this factor is nonzero, so $f(x;w)_{0}=1=g_{0}(x)$. Since $f(x;w)$ is a probability vector, this gives $f(x;w)=g(x)$, i.e., the student model must be equal to the ground truth.
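The role of the condition $\alpha(x)k(x)\neq 1$ can be seen on a small example. The sketch uses exact fractions, the normalized mixing from the proof, and illustrative values $k(x)=3$ and $L=5$. When $\alpha(x)=1/k(x)$, a student predicting the wrong one-hot label produces exactly the same mixed prediction as the ground truth, so the expected loss cannot tell them apart. For $\alpha(x)=1/2$ the two mixed predictions differ.

```python
from fractions import Fraction
from typing import List, Set


def mix_normalized(f: List[Fraction], alpha: Fraction, top: Set[int]) -> List[Fraction]:
    # alpha * f + (1 - alpha) * (1 - f) / (k - 1) * top, with k = |top|.
    k = len(top)
    return [alpha * fj + (1 - alpha) * (1 - fj) / (k - 1) * (1 if j in top else 0)
            for j, fj in enumerate(f)]


def one_hot(i: int, L: int) -> List[Fraction]:
    return [Fraction(1 if j == i else 0) for j in range(L)]


top = {0, 1, 2}                          # k(x) = 3
truth, impostor = one_hot(0, 5), one_hot(1, 5)

# alpha * k = 1: both students yield the same mixed prediction.
degenerate = Fraction(1, 3)
print(mix_normalized(truth, degenerate, top) == mix_normalized(impostor, degenerate, top))  # True

# alpha * k != 1: the mixed predictions differ, so the ground truth is identifiable.
regular = Fraction(1, 2)
print(mix_normalized(truth, regular, top) == mix_normalized(impostor, regular, top))  # False
```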

Appendix D Extended Experimental Evaluation

We implemented all algorithms in Python and used the TensorFlow deep learning library [1]. We ran our experiments on 64 Cloud TPU v4s each with two cores.

D.1 Implementation Details: Vision Datasets

Here we present the implementation details for the vision datasets we considered.

Remark D.1.

We note that in all our experiments, “VID” corresponds to the implementation of the loss described in equations (2), (4) and (6) of [2] (which requires appropriately modifying the student model so that we have access to its embedding layer).

Experiments on CIFAR-{10/100} and CelebA

For the experiments on CIFAR-10/100 and CelebA we use the Adam optimizer with initial learning rate $\mathrm{lr}=0.001$. We then proceed according to the following learning rate schedule (see, e.g., [25]):

$$\mathrm{lr}\leftarrow\begin{cases}\mathrm{lr}\cdot 0.5\cdot 10^{-3},&\text{if $\#\mathrm{epochs}>180$}\\ \mathrm{lr}\cdot 10^{-3},&\text{if $\#\mathrm{epochs}>160$}\\ \mathrm{lr}\cdot 10^{-2},&\text{if $\#\mathrm{epochs}>120$}\\ \mathrm{lr}\cdot 10^{-1},&\text{if $\#\mathrm{epochs}>80$}\end{cases}$$
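A minimal sketch of this schedule as a per-epoch function is given below. Two points are our assumptions: $\mathrm{lr}$ on the right-hand side is read as the initial rate $0.001$, and the cases are checked from the largest threshold down (so only the first matching case applies). The function name is hypothetical. In a Keras training loop, a function like this could be wrapped in `tf.keras.callbacks.LearningRateScheduler`.

```python
def cifar_learning_rate(epoch: int, initial_lr: float = 1e-3) -> float:
    """Step schedule; cases are checked from the largest threshold down."""
    if epoch > 180:
        return initial_lr * 0.5e-3
    if epoch > 160:
        return initial_lr * 1e-3
    if epoch > 120:
        return initial_lr * 1e-2
    if epoch > 80:
        return initial_lr * 1e-1
    return initial_lr  # unchanged during the first 80 epochs
```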

Finally, we use data augmentation: random horizontal flipping and random width and height translations, with both the width and the height factor equal to $0.1$.

The hyperparameters of each method are optimized as follows. For SLaM we always use $0.5$ as the lower bound for isotonic regression (i.e., the parameter $\mathrm{lb}$ in Algorithm 2). As CelebA is a binary classification benchmark, $k(x)$ is naturally set to $2$ for all examples. For CIFAR-10/100 we used the data-dependent method for estimating $k(x)$ (see Algorithm 2) with threshold parameter $t=0.9$. For weighted distillation we do a grid search over updating the weights every $\{1,25,50,100,200\}$ epochs and report the best average accuracy achieved. Finally, for VID we search over $\{0.001,0.1,0.2,0.5,0.8,1.0,2.0,10.0,50.0,100.0\}$ for the coefficient of the VID-related term of the loss function. For the PolyLoss we optimize its hyperparameter over $\{-1.0,-0.8,-0.6,-0.4,-0.2,0.5,1.0,2.0,50.0,100.0\}$.

Experiments on ImageNet

For the ImageNet experiments we use SGD with momentum $0.9$ as the optimizer. For data augmentation we use random horizontal flipping and random cropping. For the first $5$ epochs, the learning rate $\mathrm{lr}$ is increased linearly from $0.0$ to $0.1$. After that, the learning rate changes as follows:

$$\mathrm{lr}=\begin{cases}0.01,&\text{if $30<\#\mathrm{epochs}\leq 60$}\\ 0.001,&\text{if $60<\#\mathrm{epochs}\leq 80$}\\ 0.0001,&\text{if $\#\mathrm{epochs}>80$}\,.\end{cases}$$
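The warmup-plus-step schedule can be sketched as a single function of the (possibly fractional) epoch. We assume the rate stays at the warmup peak of $0.1$ between epoch $5$ and epoch $30$, which the text implies but does not state in the formula. The function name is hypothetical.

```python
def imagenet_learning_rate(epoch: float) -> float:
    """Linear warmup from 0.0 to 0.1 over 5 epochs, then step decay."""
    warmup_epochs, peak_lr = 5, 0.1
    if epoch < warmup_epochs:
        return peak_lr * epoch / warmup_epochs
    if epoch > 80:
        return 1e-4
    if epoch > 60:
        return 1e-3
    if epoch > 30:
        return 1e-2
    return peak_lr  # assumed: the rate stays at 0.1 between warmup and epoch 30
```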

The hyperparameters of each method are optimized as follows. For SLaM we do a hyperparameter search over $\{0.55,0.60,0.65,0.70\}$ for the lower bound for isotonic regression, and we keep the best-performing value for each size of dataset $A$. We used the fixed value $5$ for $k(x)$, as the top-5 accuracy of the teacher model on the validation dataset was satisfactory (much higher than its top-1 accuracy). For Taylor-CE we did a hyper-parameter search for the Taylor series truncation values in $\{1,2,3,4,5,6,10,20,50,80,100\}$. For weighted distillation we compute the weights in a one-shot fashion using the pre-trained student (as in the ImageNet experiments in [27]). For VID we search over $\{0.1,0.3,0.5\}$ for the coefficient of the VID-related term of the loss function. For the PolyLoss we optimize its hyperparameter over $\{1.0,2.0,50.0,100.0\}$.

Figure 3: Comparison of distillation methods on ImageNet. On the horizontal axis we plot the size of Dataset A as a percentage of the whole training dataset. On the vertical axis we plot the accuracy of the trained student-model on the test dataset.

D.2 Hard-Distillation

Here we present results on hard-distillation. The hyper-parameters of all methods are chosen in the same way as in our soft-distillation experiments (see Section D.1). Tables 5, 6 and 7 contain our results on CIFAR-10, CIFAR-100 and CelebA, respectively. We observe that SLaM outperforms the other baselines in almost all cases. For CIFAR-10 and CIFAR-100, hard-distillation performs worse than soft-distillation, as is typically the case. On CelebA, however, hard-distillation seems to perform on par with (and sometimes even better than) soft-distillation. A plausible explanation is that in our CelebA experiments the teacher and student have different architectures (MobileNet and ResNet, respectively), so the teacher's soft-labels are not very informative for the student. CelebA is also a binary classification task, so the information the teacher can pass to the student through its soft-labels is limited.

Table 5: Experiments on CIFAR-10 (hard-distillation). See Section 4.2 for details.
Labeled Examples | 5000 | 7500 | 10000 | 12500 | 15000 | 17500
Teacher | 61.30 | 68.98 | 72.42 | 73.92 | 76.63 | 78.63
Vanilla | $62.26\pm 0.45$ | $69.07\pm 0.11$ | $72.09\pm 0.11$ | $73.43\pm 0.16$ | $75.93\pm 0.25$ | $77.43\pm 0.15$
Taylor-CE [20] | $63.14\pm 0.07$ | $69.98\pm 0.11$ | $72.72\pm 0.36$ | $73.77\pm 0.28$ | $76.26\pm 0.29$ | $77.88\pm 0.20$
UPS [48] | $64.27\pm 0.08$ | $70.93\pm 0.26$ | $73.78\pm 0.16$ | $74.66\pm 0.29$ | $77.38\pm 0.37$ | $78.95\pm 0.08$
VID [3] | $61.95\pm 0.22$ | $66.91\pm 0.21$ | $69.59\pm 0.24$ | $72.16\pm 0.47$ | $74.83\pm 0.11$ | $75.55\pm 0.21$
Weighted [27] | $63.22\pm 0.45$ | $71.04\pm 0.26$ | $72.84\pm 0.12$ | $74.20\pm 0.16$ | $76.56\pm 0.24$ | $78.23\pm 0.15$
SLaM (Ours) | $\mathbf{66.40\pm 0.31}$ | $\mathbf{72.44\pm 0.17}$ | $\mathbf{74.77\pm 0.13}$ | $\mathbf{75.64\pm 0.19}$ | $\mathbf{77.99\pm 0.36}$ | $\mathbf{79.26\pm 0.26}$
Table 6: Experiments on CIFAR-100 (hard-distillation). See Section 4.2 for details.
Labeled Examples | 5000 | 7500 | 10000 | 12500 | 15000 | 17500
Teacher | 35.97 | 44.65 | 49.62 | 55.68 | 59.19 | 62.05
Vanilla | $36.36\pm 0.04$ | $44.15\pm 0.10$ | $50.22\pm 0.07$ | $55.55\pm 0.24$ | $58.85\pm 0.1$ | $61.43\pm 0.19$
Taylor-CE [20] | $39.12\pm 0.14$ | $46.87\pm 0.10$ | $52.64\pm 0.22$ | $57.19\pm 0.28$ | $59.95\pm 0.11$ | $62.36\pm 0.21$
UPS [48] | $39.49\pm 0.13$ | $48.36\pm 0.44$ | $53.95\pm 0.10$ | $57.95\pm 0.10$ | $60.59\pm 0.29$ | $62.09\pm 0.28$
VID [3] | $37.19\pm 0.09$ | $44.67\pm 0.16$ | $50.63\pm 0.35$ | $54.78\pm 0.07$ | $59.27\pm 0.14$ | $62.01\pm 0.05$
Weighted [27] | $38.04\pm 0.29$ | $46.45\pm 0.22$ | $52.33\pm 0.18$ | $57.43\pm 0.13$ | $60.81\pm 0.09$ | $63.02\pm 0.06$
SLaM (Ours) | $\mathbf{42.01\pm 0.29}$ | $\mathbf{49.08\pm 0.14}$ | $\mathbf{54.49\pm 0.17}$ | $\mathbf{58.53\pm 0.04}$ | $\mathbf{61.12\pm 0.15}$ | $\mathbf{63.21\pm 0.18}$
Table 7: Experiments on CelebA (hard-distillation). See Section 4.2 for details.
Labeled Examples | 2% | 3% | 4% | 5% | 6% | 7%
Teacher | 86.19 | 88.25 | 88.95 | 91.31 | 92.09 | 92.62
Vanilla | $89.73\pm 0.08$ | $91.61\pm 0.09$ | $92.05\pm 0.11$ | $93.41\pm 0.13$ | $94.02\pm 0.15$ | $94.05\pm 0.04$
Taylor-CE [20] | $\mathbf{90.62\pm 0.05}$ | $92.19\pm 0.02$ | $92.66\pm 0.11$ | $93.60\pm 0.14$ | $94.00\pm 0.04$ | $94.38\pm 0.10$
UPS [48] | $89.35\pm 0.04$ | $91.30\pm 0.04$ | $91.95\pm 0.12$ | $93.18\pm 0.07$ | $93.71\pm 0.04$ | $94.18\pm 0.03$
VID [3] | $89.92\pm 0.21$ | $91.60\pm 0.11$ | $92.20\pm 0.12$ | $93.51\pm 0.15$ | $94.08\pm 0.15$ | $94.27\pm 0.10$
Weighted [27] | $90.06\pm 0.06$ | $91.97\pm 0.13$ | $92.45\pm 0.10$ | $93.60\pm 0.07$ | $93.94\pm 0.12$ | $94.25\pm 0.16$
SLaM (Ours) | $90.43\pm 0.05$ | $\mathbf{92.25\pm 0.11}$ | $\mathbf{92.71\pm 0.08}$ | $\mathbf{93.96\pm 0.17}$ | $\mathbf{94.39\pm 0.21}$ | $\mathbf{94.52\pm 0.12}$
Figure 4: Comparison of distillation methods on CIFAR-10, CIFAR-100 and CelebA. On the horizontal axis we plot the size of Dataset A as a percentage of the whole training dataset. On the vertical axis we plot the accuracy of the trained student-model on the test dataset.

D.3 Large Movies Reviews Dataset Results

Here we present the results and implementation details of the experiments on the Large Movies Reviews dataset. Recall that we use an ALBERT-large model as the teacher and an ALBERT-base model as the student. We use $2\%,4\%,8\%,40\%$ (i.e., 500, 1000, 2000, 10000 examples) of the training dataset as the labeled dataset A. We split the remaining data into a validation dataset of 500 examples and an unlabeled dataset U. We compare the methods on soft-distillation. For each trial we train the student model for $40$ epochs and keep the best test accuracy over all epochs. We perform $3$ trials and report, for each method, the average and the variance of the achieved accuracies over the trials. The results of our experiments can be found in Table 8. We did not implement the UPS method for this dataset, as its data-augmentation method for estimating the teacher's accuracy could not be readily used for this NLP dataset. Moreover, using dropout and Monte Carlo estimation for the uncertainty was also not compatible with the ALBERT model used in this experiment.

Since we are dealing with ALBERT models (which are already pre-trained), we do not pre-train the student model on dataset A. The exception is “weighted distillation” [27], where we pre-train the student model on dataset A for just $1$ epoch. The teacher model is trained using the Adam optimizer for $20$ epochs with initial learning rate $10^{-6}$. The student model is also trained using the Adam optimizer, but for $40$ epochs and with learning rate $10^{-7}$.

The hyperparameters of each method are optimized as follows. For SLaM we do a hyperparameter search over $\{0.5,0.6,0.7,0.8,0.9\}$ for the lower bound for isotonic regression, and we keep the best-performing value for each size of dataset $A$. As this is a binary classification benchmark, we naturally set $k(x)=2$ for all examples. For weighted distillation we do a grid search over updating the weights every $\{1,10,20,40\}$ epochs and, similarly, report the best average accuracy achieved. Finally, for VID (recall also Remark D.1) we search over $\{0.1,0.5,1.0,2.0\}$ for the coefficient of the VID-related term of the loss function. For the PolyLoss we optimize its hyperparameter over $\{-1.0,-0.8,-0.6,-0.4,-0.2,0.5,1.0,2.0\}$.

Table 8: Experiments on the Large Movies Reviews Dataset (soft-distillation). See Section D.3 for details.
Labeled Examples | 2% | 4% | 8% | 40%
Teacher | 77.52 | 84.04 | 85.44 | 88.3
Vanilla | $80.93\pm 0.10$ | $85.12\pm 0.29$ | $85.99\pm 0.08$ | $87.50\pm 0.6$
Taylor-CE [20] | $79.5\pm 0.38$ | $85.14\pm 0.13$ | $85.98\pm 0.14$ | $87.57\pm 0.3$
VID [3] | $81.76\pm 0.32$ | $85.33\pm 0.35$ | $86.17\pm 0.06$ | $87.71\pm 0.01$
Weighted [27] | $81.1\pm 0.1$ | $85.2\pm 0.05$ | $86.13\pm 0.17$ | $\mathbf{87.8\pm 0.25}$
SLaM (Ours) | $\mathbf{81.88\pm 0.23}$ | $\mathbf{85.5\pm 0.09}$ | $\mathbf{86.23\pm 0.13}$ | $87.73\pm 0.38$

D.4 Combining with Teacher-Uncertainty-Based Reweighting Techniques

As we discussed in Section 2, our method can in principle be combined with teacher-uncertainty filtering and weighting schemes, as these can be seen as preprocessing steps. To demonstrate this, we combine our method with the so-called fidelity-based weighting scheme of [17]. This scheme reweights examples using an uncertainty measure for the teacher's labels, obtained, e.g., by estimating the variance of the teacher labels under random data augmentations, or by using dropout and Monte Carlo estimation. More precisely, for every example $x$ in the teacher-labeled dataset $B$, the fidelity weighting scheme assigns the weight $w^{\mathrm{Fid}}(x)=\exp(-\beta\,\mathrm{uncertainty}^{\mathrm{teacher}}(x))$ for some hyper-parameter $\beta>0$. In our experiments we performed $10$ random data augmentations (random crop and resize) and estimated the coordinate-wise variance of the resulting teacher soft-labels. We then computed the average of these variances over the classes, as proposed in [17], and normalized the uncertainty of each example by the total uncertainty of the teacher over the whole dataset $B$. The weights of examples in dataset $A$ are set to $1$, and the reweighted objective is optimized over the combination of the datasets $A$ and $B$:

$$\mathcal{L}^{\mathrm{fid}}(w)=\frac{1}{|A\cup B|}\Bigg(\sum_{(x,y)\in A}\ell(y,f(x;w))+\sum_{(x,y)\in B}w^{\mathrm{Fid}}(x)\,\ell(y,f(x;w))\Bigg)\,.\tag{3}$$
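The fidelity weights can be sketched in plain Python as follows. Three choices here are our assumptions, not details from [17] or the authors' code. The per-class variance across augmentations is the population variance (dividing by the number of augmentations). “Total uncertainty over $B$” is read as the sum over $B$. All examples get weight $1$ when every uncertainty is zero. The function names are hypothetical.

```python
import math
from typing import List, Sequence


def teacher_uncertainty(augmented_soft_labels: Sequence[Sequence[float]]) -> float:
    """Per-class variance of the teacher soft-labels across augmentations,
    averaged over classes (population variance, i.e. divide by n)."""
    n = len(augmented_soft_labels)
    num_classes = len(augmented_soft_labels[0])
    total = 0.0
    for c in range(num_classes):
        values = [p[c] for p in augmented_soft_labels]
        mean = sum(values) / n
        total += sum((v - mean) ** 2 for v in values) / n
    return total / num_classes


def fidelity_weights(uncertainties: Sequence[float], beta: float) -> List[float]:
    """w_Fid(x) = exp(-beta * u(x) / sum_B u), one weight per example of B."""
    total = sum(uncertainties)
    if total == 0:
        return [1.0] * len(uncertainties)
    return [math.exp(-beta * u / total) for u in uncertainties]
```

Examples with unstable teacher predictions under augmentation get exponentially smaller weights, and $\beta$ controls how aggressive this down-weighting is.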
Figure 5: Composability of SLaM with the fidelity-based weighting scheme of [17]. The $x$-axis shows the different values of the fidelity hyper-parameter $\beta$ and the size of dataset A. From left to right we increase the size of dataset A from $10\%$ to $35\%$, and for each size we try different values of $\beta$. We observe that SLaM on its own (shown in green) is usually much better than the fidelity weighting scheme (shown in orange). Moreover, using SLaM on top of the fidelity weighting scheme (shown in blue) consistently improves its performance.

To demonstrate the composability of our method with such uncertainty-based weighting schemes, we use CIFAR-100, with the labeled dataset A making up $10\%,15\%,20\%,25\%,30\%,35\%$ of the whole training set, similar to the setting of Section 4.2. The teacher is a ResNet110 and the student is a ResNet56. We first train the student using only the fidelity weighting scheme, i.e., we optimize the loss function of Equation 3 for $\beta\in\{0.1,0.2,1.0,1.2,2.0,5.0,10.0,20.0\}$. These values range from mildly reweighting the examples of dataset B to more aggressively “removing” examples where the teacher's uncertainty is large. For the same values of $\beta$ we then train the student using the reweighted SLaM objective:

$$\mathcal{L}^{\mathrm{Fid+SLaM}}(w)=\frac{1}{|A\cup B|}\Bigg(\sum_{(x,y)\in A}\ell(y,f(x;w))+\sum_{(x,y)\in B}w^{\mathrm{Fid}}(x)\,\ell(y,\mathrm{mix}(f(x;w);\alpha(x),k(x)))\Bigg)\,.\tag{4}$$

For the combined SLaM + Fidelity method we did not perform a hyper-parameter search and used the same isotonic-regression parameters as in the “standard” SLaM experiment on CIFAR-100 of Section D.1. We present our complete results for all sizes of dataset A and all values of the hyper-parameter $\beta$ in Figure 5. Our results show that, regardless of the value of $\beta$ and the size of the labeled dataset A, using SLaM together with the fidelity weighting scheme provides consistent improvements. Moreover, Figure 5 shows that with SLaM the achieved accuracy depends less on $\beta$. Since SLaM takes into account the fact that some of the teacher's predictions are incorrect, it is not crucial to down-weight them or filter them out.

D.5 Using Distillation Temperature

In this section we show that our approach can be effectively combined with temperature scaling [26]. Choosing the right distillation temperature often provides significant improvements. In our setting, the teacher is much more confident (e.g., its soft-labels have high margin) on dataset A, on which it was trained, than on dataset B. Given this observation, it is reasonable to use different distillation temperatures for dataset A and dataset B. We try different temperatures for datasets A and B and perform vanilla distillation with temperature. We also consider applying temperature scaling before applying SLaM. For each size of dataset A we try pairs of temperatures $t_{A},t_{B}\in\{0.01,0.1,0.5,0.8,1.0,2.0,5.0,10.0,100.0\}$. We report the best accuracy achieved by vanilla distillation and the best accuracy achieved by first applying temperature scaling and then SLaM. Figure 6 shows that SLaM with temperature scaling consistently improves over vanilla distillation with temperature.
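Since only the teacher's pseudo-labels are available after bulk inference, the sketch below assumes that temperature is applied directly to the soft-labels as $\mathrm{softmax}(\log p/t)$. This is equivalent to dividing the teacher's logits by $t$ when $p$ is a softmax output. It also enumerates the $(t_{A},t_{B})$ grid from the text. The function and variable names are hypothetical.

```python
import itertools
import math
from typing import List, Sequence

TEMPERATURES = [0.01, 0.1, 0.5, 0.8, 1.0, 2.0, 5.0, 10.0, 100.0]


def temperature_scale(soft_label: Sequence[float], t: float) -> List[float]:
    """Return softmax(log(p) / t): t > 1 flattens p, t < 1 sharpens it."""
    logits = [math.log(max(p, 1e-12)) / t for p in soft_label]
    top = max(logits)
    exps = [math.exp(z - top) for z in logits]  # shift by the max for stability
    s = sum(exps)
    return [e / s for e in exps]


# Every (t_A, t_B) pair from the grid: t_A rescales soft-labels on dataset A,
# t_B those on dataset B, before the vanilla or SLaM distillation loss.
temperature_pairs = list(itertools.product(TEMPERATURES, repeat=2))
```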

Figure 6: CIFAR100: Temperature Ablation. On the x-axis we have the size of the labeled dataset (as a percentage of the whole training dataset) that the teacher model uses for training.

D.6 Using SLaM with other loss functions beyond cross-entropy

In this section, we demonstrate that our method can be successfully applied when the student loss function comes from the families of losses introduced in [20] and [35]. We perform experiments on CIFAR-100 and ImageNet following the setting of Section 4.2. In particular, we compare vanilla distillation with unlabeled examples using the Taylor-CE loss of [20] or the PolyLoss of [35] against SLaM combined with each of these losses. For the Taylor-CE loss we set the “degree” hyperparameter to $2$ (as suggested in [20]), and we set the hyperparameter of the PolyLoss to $2.0$ (as suggested in [35]). The corresponding results can be found in Figure 7.
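The combination amounts to evaluating the alternative loss on the mixed student prediction instead of on the raw prediction. The sketch below rests on assumptions that may differ in detail from [20] and [35]. Taylor-CE is written as the truncated series $\sum_{i=1}^{\mathrm{degree}}(1-q)^{i}/i$ of $-\log q$. PolyLoss is written in its Poly-1 form $\mathrm{ce}+\epsilon(1-\langle y,q\rangle)$. Both are extended to soft targets by weighting with the target. The binary mixing and the example numbers are illustrative.

```python
import math
from typing import List, Sequence


def taylor_ce(target: Sequence[float], pred: Sequence[float], degree: int = 2) -> float:
    """sum_j target_j * sum_{i=1}^{degree} (1 - pred_j)^i / i.

    As degree grows this tends to the cross-entropy -sum_j target_j log(pred_j).
    """
    return sum(t * sum((1.0 - q) ** i / i for i in range(1, degree + 1))
               for t, q in zip(target, pred))


def poly1(target: Sequence[float], pred: Sequence[float], epsilon: float = 2.0,
          eps: float = 1e-12) -> float:
    """Poly-1 loss: cross-entropy + epsilon * (1 - <target, pred>)."""
    ce = -sum(t * math.log(max(q, eps)) for t, q in zip(target, pred))
    pt = sum(t * q for t, q in zip(target, pred))
    return ce + epsilon * (1.0 - pt)


def mix_binary(f: Sequence[float], alpha: float) -> List[float]:
    """Binary mixing alpha * f + (1 - alpha) * (1 - f) (k = 2)."""
    return [alpha * fi + (1.0 - alpha) * (1.0 - fi) for fi in f]


# SLaM + alternative loss: evaluate the loss on the mixed student prediction.
teacher_label, student_probs, alpha_hat = [0.0, 1.0], [0.3, 0.7], 0.8
slam_taylor = taylor_ce(teacher_label, mix_binary(student_probs, alpha_hat))
slam_poly = poly1(teacher_label, mix_binary(student_probs, alpha_hat))
```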

Figure 7: Using SLaM with PolyLoss [35] and Taylor CE [20]. On the x-axis we have the size of the labeled dataset (as a percentage of the whole training dataset) that the teacher model uses for training. See Section D.6 for more details.

Appendix E Distilling Linear Models and Learning Noisy Halfspaces

In this section we state and prove our convergence result for the SLaM method when applied to linear models. We assume that the ground truth $g(x)$ corresponds to a halfspace, i.e., $g(x)=(\mathbf{1}\{w^{\ast}\cdot x>0\},\mathbf{1}\{w^{\ast}\cdot x\leq 0\})$ for some unknown weight vector $w^{\ast}$. We show that using SLaM with a linear model as the student recovers the ground-truth classifier. We make the standard assumption that the ground-truth halfspace has $\gamma$-margin, i.e., that $\|w^{\ast}\|_{2}=1$ and $|w^{\ast}\cdot x|\geq\gamma$ for all examples $x$. For a fixed example $x$, the observed noisy teacher label $y$ satisfies Definition 3.2, i.e., $y=g(x)$ with probability $\alpha(x)$ and $y=1-g(x)$ with probability $1-\alpha(x)$ (since $k=2$ for binary classification). Our approach uses the standard cross-entropy loss $\mathrm{ce}(p,q)$ to train a student model consisting of a linear layer followed by a soft-max activation, i.e.,

$$f(x;w)=(f_{0}(x;w),f_{1}(x;w))=\left(\frac{1}{1+e^{-w\cdot x}},\frac{e^{-w\cdot x}}{1+e^{-w\cdot x}}\right)\,.$$

Recall that, for binary classification, we define the mixing operation as

$$\mathrm{mix}(f(x;w);\alpha(x))=\alpha(x)f(x;w)+(1-\alpha(x))(1-f(x;w))\,.$$
Algorithm 3 SLaM for Linear Models
  Initialize weight vector of student $w^{(0)}\leftarrow 0$
  for $t=1,\ldots,T$ do
     Draw example $x^{(t)}\sim X$.
     Label $x^{(t)}$ with (noisy) teacher to obtain $y^{(t)}$
     Compute the gradient of the SLaM loss at $(x^{(t)},y^{(t)})$:
     $g^{(t)}\leftarrow\partial_{w}\mathrm{ce}(y^{(t)},\mathrm{mix}(f(x^{(t)};w);\alpha(x^{(t)})))\mid_{w=w^{(t-1)}}$
     Compute step size: $\lambda^{(t)}\leftarrow 1/r(f(x^{(t)};w^{(t-1)}),\alpha(x^{(t)}))$ (see Lemma E.3 for the definition of $r(\cdot,\cdot)$).
     Update the student model: $w^{(t)}\leftarrow w^{(t-1)}-\lambda^{(t)}\,g^{(t)}$
  end for
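The toy simulation below sketches Algorithm 3 on synthetic data. It is illustrative only and assumes a constant $\alpha(x)=0.9$, $w^{\ast}=(1,0)$ in two dimensions, and margin $\gamma=0.2$. The update uses the product $\lambda^{(t)}g^{(t)}=\mathrm{sgn}(2\alpha-1)(\mathrm{mix}(f_{0})-y_{0})\,x$ implied by Lemma E.3 rather than dividing by $r$, which can be close to zero. Because Theorem E.1 only guarantees that some iterate is good, the sketch reports the best disagreement with the ground truth on held-out points. Those clean held-out labels are used for illustration only. All names are hypothetical.

```python
import math
import random
from typing import List, Tuple


def sigmoid(z: float) -> float:
    # Numerically stable logistic function; f0(x; w) = sigmoid(w . x).
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def dot(u: List[float], v: List[float]) -> float:
    return sum(a * b for a, b in zip(u, v))


def slam_step(w: List[float], x: List[float], y0: float, alpha: float) -> List[float]:
    """One iteration of Algorithm 3, w <- w - lambda * g, where by Lemma E.3
    lambda * g = sgn(2 alpha - 1) * (mix(f0) - y0) * x."""
    f0 = sigmoid(dot(w, x))
    mix0 = alpha * f0 + (1 - alpha) * (1 - f0)
    sign = 1.0 if alpha > 0.5 else -1.0
    return [wi - sign * (mix0 - y0) * xi for wi, xi in zip(w, x)]


def draw_point(rng: random.Random, gamma: float) -> List[float]:
    # 2-D point with margin |w* . x| >= gamma for w* = (1, 0).
    x1 = rng.choice([-1.0, 1.0]) * rng.uniform(gamma, 1.0)
    return [x1, rng.uniform(-1.0, 1.0)]


def disagreement(w: List[float], w_star: List[float], points: List[List[float]]) -> float:
    # Fraction of points where sign(w . x) differs from sign(w* . x).
    wrong = sum((dot(w, x) > 0) != (dot(w_star, x) > 0) for x in points)
    return wrong / len(points)


def run(T: int = 4000, alpha: float = 0.9, gamma: float = 0.2,
        seed: int = 0, eval_every: int = 20) -> Tuple[List[float], float]:
    rng = random.Random(seed)
    w_star = [1.0, 0.0]
    held_out = [draw_point(rng, gamma) for _ in range(400)]
    w = [0.0, 0.0]
    best = 1.0
    for t in range(1, T + 1):
        x = draw_point(rng, gamma)
        g0 = 1.0 if dot(w_star, x) > 0 else 0.0
        # Noisy teacher: correct w.p. alpha, flipped otherwise (k = 2).
        y0 = g0 if rng.random() < alpha else 1.0 - g0
        w = slam_step(w, x, y0, alpha)
        if t % eval_every == 0:
            best = min(best, disagreement(w, w_star, held_out))
    return w, best


final_w, best_error = run()
```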
Theorem E.1 (Student Label Mixing Convergence).

Let $X$ be a distribution on $\mathbf{R}^{d}$ and let $g(x)$ be the ground-truth halfspace with normal vector $w^{\ast}\in\mathbf{R}^{d}$. Let $D$ be the distribution over (noisy) teacher-labeled examples $(x,y)$ whose $x$-marginal is $X$. We denote by $\alpha(x)$ the probability that the teacher label $y\in[0,1]^{2}$ is correct, i.e., $\alpha(x)=\operatorname{\mathbf{P}}_{(x,y)\sim D}[\operatorname*{argmax}(y)=g(x)\mid x]$. Assume that there exist $\beta,\gamma>0$ such that for all examples $x$ in the support of $X$ it holds that $|w^{\ast}\cdot x|\geq\gamma$ and $|1/2-\alpha(x)|\geq\beta$. Let $\epsilon>0$. After $T=O(1/(\beta^{2}\gamma^{2}\epsilon^{2}))$ iterations of SLaM (Algorithm 3), with probability at least $99\%$, there exists an iteration $t\leq T$ where $\operatorname{\mathbf{P}}_{x\sim X}[\mathrm{err}(f(x;w^{(t)}),g(x))]\leq\epsilon$.

Remark E.2 (High-Probability Result).

Even though our learner succeeds only with constant probability (at least $99\%$), we can amplify its success probability to $1-\delta$ by standard amplification techniques, i.e., by repeating the algorithm $O(\log(1/\delta))$ times and keeping the best result. To achieve success probability $1-\delta$, the total sample complexity is $O(\log(1/\delta)/(\epsilon^{2}\gamma^{2}\beta^{2}))$.

Proof.

We first provide simplified expressions for the gradient of the SLaM objective and for the update vectors $\lambda^{(t)}g^{(t)}$ used in Algorithm 3. In what follows we use two identities that hold for any binary classification model $f(x;w)=(f_{0}(x;w),f_{1}(x;w))$. (i) $(\mathrm{mix}(f(x;w);\alpha(x)))_{0}=\mathrm{mix}(f_{0}(x;w);\alpha(x))$. To simplify notation, we overload the mixing operation here to also act on the scalar $f_{0}(x;w)$, i.e., $\mathrm{mix}(f_{0}(x;w);\alpha(x))=\alpha(x)f_{0}(x;w)+(1-\alpha(x))(1-f_{0}(x;w))$. (ii) $f_{1}(x;w)=1-f_{0}(x;w)$.

Lemma E.3 (SLaM Gradient).

The gradient of the SLaM objective is equal to

$$\partial_{w}\mathrm{ce}(y,\mathrm{mix}(f(x;w);\alpha(x)))=r(f(x;w),\alpha(x))\,\mathrm{sgn}(2\alpha(x)-1)\,(\mathrm{mix}(f_{0}(x;w);\alpha(x))-y_{0})\,x\,,$$

where

$$r(f(x;w),\alpha(x))=\frac{f_{0}(x;w)(1-f_{0}(x;w))}{\mathrm{mix}(f_{0}(x;w);\alpha(x))\,(1-\mathrm{mix}(f_{0}(x;w);\alpha(x)))}\,|2\alpha(x)-1|\,.$$

Let $L(x;w)=\operatorname{\mathbf{E}}_{(x,y)\sim D}[\mathrm{ce}(y,\mathrm{mix}(f(x;w);\alpha(x)))\mid x]$ be the expected student-label-mixing loss conditional on an example $x\in\mathbf{R}^{d}$. It holds that $\partial_{w}L(x;w)=r(f(x;w),\alpha(x))\,|2\alpha(x)-1|\,(f_{0}(x;w)-g_{0}(x))\,x$.
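Both statements of Lemma E.3 can be sanity-checked numerically before reading the proof. The sketch compares the closed-form gradient with central finite differences of $\mathrm{ce}(y,\mathrm{mix}(f(x;w);\alpha))$, using $f_{0}=\mathrm{sigmoid}(w\cdot x)$ and $y=(y_{0},1-y_{0})$. The test cases also check the expected-gradient formula by averaging over the two possible teacher labels. The test points and helper names are arbitrary illustrative choices.

```python
import math
from typing import List, Sequence


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def slam_ce(w: Sequence[float], x: Sequence[float], y0: float, alpha: float) -> float:
    """ce(y, mix(f(x; w); alpha)) for y = (y0, 1 - y0) and f0 = sigmoid(w . x)."""
    f0 = sigmoid(sum(wi * xi for wi, xi in zip(w, x)))
    m0 = alpha * f0 + (1 - alpha) * (1 - f0)
    return -(y0 * math.log(m0) + (1 - y0) * math.log(1 - m0))


def lemma_gradient(w: Sequence[float], x: Sequence[float], y0: float,
                   alpha: float) -> List[float]:
    """Closed form of Lemma E.3: r * sgn(2 alpha - 1) * (mix(f0) - y0) * x."""
    f0 = sigmoid(sum(wi * xi for wi, xi in zip(w, x)))
    m0 = alpha * f0 + (1 - alpha) * (1 - f0)
    r = f0 * (1 - f0) / (m0 * (1 - m0)) * abs(2 * alpha - 1)
    sign = math.copysign(1.0, 2 * alpha - 1)
    return [r * sign * (m0 - y0) * xi for xi in x]


def numerical_gradient(w: Sequence[float], x: Sequence[float], y0: float,
                       alpha: float, h: float = 1e-6) -> List[float]:
    """Central finite differences of slam_ce with respect to each w_i."""
    grad = []
    for i in range(len(w)):
        up, down = list(w), list(w)
        up[i] += h
        down[i] -= h
        grad.append((slam_ce(up, x, y0, alpha) - slam_ce(down, x, y0, alpha)) / (2 * h))
    return grad
```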

Proof.

We first show the formula

$$\partial_{w}\mathrm{ce}(y,\mathrm{mix}(f(x;w);\alpha(x)))=r(f_{0}(x;w);\alpha(x))\,\mathrm{sgn}(2\alpha(x)-1)\,(\mathrm{mix}(f_{0}(x;w);\alpha(x))-y_{0})\,x\,.\tag{5}$$

Using the chain rule, we obtain

$$
\begin{aligned}
\partial_{w}\mathrm{ce}(y,\mathrm{mix}(f(x;w);\alpha(x)))
&=-\frac{y_{0}}{\mathrm{mix}(f_{0}(x;w);\alpha(x))}\,\partial_{w}\mathrm{mix}(f_{0}(x;w);\alpha(x))\\
&\quad-\frac{y_{1}}{\mathrm{mix}(f_{1}(x;w);\alpha(x))}\,\partial_{w}\mathrm{mix}(f_{1}(x;w);\alpha(x))\,.
\end{aligned}
$$

Now we observe that, for binary classification, $y_{1}=1-y_{0}$ and $\mathrm{mix}(f_{1}(x;w);\alpha(x))=1-\mathrm{mix}(f_{0}(x;w);\alpha(x))$. Therefore $\partial_{w}\mathrm{mix}(f_{1}(x;w);\alpha(x))=-\partial_{w}\mathrm{mix}(f_{0}(x;w);\alpha(x))$, which gives the simplified expression

$$
\begin{aligned}
\partial_{w}\mathrm{ce}(y,\mathrm{mix}(f(x;w);\alpha(x)))
&=-\frac{y_{0}}{\mathrm{mix}(f_{0}(x;w);\alpha(x))}\,\partial_{w}\mathrm{mix}(f_{0}(x;w);\alpha(x))\\
&\quad+\frac{1-y_{0}}{1-\mathrm{mix}(f_{0}(x;w);\alpha(x))}\,\partial_{w}\mathrm{mix}(f_{0}(x;w);\alpha(x))\,.
\end{aligned}
$$

Combining the two fractions, we obtain

$$\partial_{w}\mathrm{ce}(y,\mathrm{mix}(f(x;w);\alpha(x)))=\frac{\mathrm{mix}(f_{0}(x;w);\alpha(x))-y_{0}}{\mathrm{mix}(f_{0}(x;w);\alpha(x))\,(1-\mathrm{mix}(f_{0}(x;w);\alpha(x)))}\,\partial_{w}\mathrm{mix}(f_{0}(x;w);\alpha(x))\,.$$

Using the chain rule again, we obtain

$$\partial_{w}\mathrm{mix}(f_{0}(x;w);\alpha(x))=\alpha(x)\,\partial_{w}f_{0}(x;w)+(1-\alpha(x))\,\partial_{w}(1-f_{0}(x;w))=(2\alpha(x)-1)\,\partial_{w}f_{0}(x;w)\,.$$

The derivative of the sigmoid function $\sigma(t)=1/(1+e^{-t})$ is $\sigma'(t)=e^{-t}/(1+e^{-t})^{2}=\sigma(t)(1-\sigma(t))$. Combining this with the chain rule and $f_{0}(x;w)=\sigma(w\cdot x)$, we obtain $\partial_{w}f_{0}(x;w)=f_{0}(x;w)(1-f_{0}(x;w))\,x$. Putting everything together and writing $2\alpha(x)-1=\mathrm{sgn}(2\alpha(x)-1)\,|2\alpha(x)-1|$, we obtain the claimed formula (5) for $\partial_{w}\mathrm{ce}(y,\mathrm{mix}(f(x;w);\alpha(x)))$.
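Equation 5 can be checked numerically. The sketch below assumes a logistic student $f_0(x;w)=\sigma(w\cdot x)$. It also assumes binary cross-entropy applied to the mixed prediction, $\mathrm{ce}(y,q)=-y_0\log q_0-(1-y_0)\log(1-q_0)$ with $q_0=\mathrm{mix}(f_0(x;w);\alpha(x))$. It then compares the closed form with central finite differences. All function names are hypothetical.

```python
import numpy as np


def sigmoid(t):
    return 1.0 / (1.0 + np.exp(-t))


def mix(p0, alpha):
    # Scalar mixing: alpha * p0 + (1 - alpha) * (1 - p0).
    return alpha * p0 + (1.0 - alpha) * (1.0 - p0)


def slam_loss(w, x, y0, alpha):
    # Cross-entropy between the label (y0, 1 - y0) and the mixed prediction.
    q0 = mix(sigmoid(w @ x), alpha)
    return -(y0 * np.log(q0) + (1.0 - y0) * np.log(1.0 - q0))


def slam_grad(w, x, y0, alpha):
    # Equation 5: r(f0; alpha) * sgn(2 alpha - 1) * (mix(f0; alpha) - y0) * x.
    f0 = sigmoid(w @ x)
    q0 = mix(f0, alpha)
    r = f0 * (1.0 - f0) / (q0 * (1.0 - q0)) * abs(2.0 * alpha - 1.0)
    return r * np.sign(2.0 * alpha - 1.0) * (q0 - y0) * x


def finite_difference_grad(fun, w, h=1e-6):
    # Central differences, one coordinate at a time.
    grad = np.zeros_like(w)
    for i in range(w.size):
        step = np.zeros_like(w)
        step[i] = h
        grad[i] = (fun(w + step) - fun(w - step)) / (2.0 * h)
    return grad


rng = np.random.default_rng(0)
w, x = rng.normal(size=3), rng.normal(size=3)
max_err = 0.0
for alpha in (0.9, 0.3):
    for y0 in (0.0, 1.0, 0.4):  # hard labels and a soft teacher label
        numeric = finite_difference_grad(lambda v: slam_loss(v, x, y0, alpha), w)
        max_err = max(max_err, float(np.max(np.abs(numeric - slam_grad(w, x, y0, alpha)))))
print(f"max |finite difference - Equation 5| = {max_err:.2e}")
```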

To obtain the gradient formula for the expected loss conditional on a fixed example $x$, we use the fact that $\partial_{w}\mathbf{E}[\mathrm{ce}(y,\mathrm{mix}(f(x;w);\alpha(x)))\mid x]=\mathbf{E}[\partial_{w}\mathrm{ce}(y,\mathrm{mix}(f(x;w);\alpha(x)))\mid x]$. The right-hand side of Equation 5 is affine in $y_{0}$, and $\mathbf{E}[y_{0}\mid x]=\mathrm{mix}(g_{0}(x);\alpha(x))$ by the definition of our noise model. Therefore

$$
\begin{aligned}
\partial_{w}L(x;w)&=r(f_{0}(x;w);\alpha(x))\,\mathrm{sgn}(2\alpha(x)-1)\,(\mathrm{mix}(f_{0}(x;w);\alpha(x))-\mathrm{mix}(g_{0}(x);\alpha(x)))\,x\\
&=r(f_{0}(x;w);\alpha(x))\,|2\alpha(x)-1|\,(f_{0}(x;w)-g_{0}(x))\,x\,,
\end{aligned}
$$

where in the last step we used $\mathrm{mix}(f_{0}(x;w);\alpha(x))-\mathrm{mix}(g_{0}(x);\alpha(x))=(2\alpha(x)-1)(f_{0}(x;w)-g_{0}(x))$.

We first show the following claim. It proves that after roughly $T=1/(\beta^{2}\gamma^{2}\epsilon^{2})$ gradient iterations, the student parameter vector $w^{(t)}$ has good correlation with the ground-truth vector $w^{\ast}$.

Claim 1.

Fix any $T$ larger than a sufficiently large constant multiple of $\log(1/\delta)/(\epsilon^{2}\gamma^{2}\beta^{2})$, and assume that for all $t\leq T$ it holds that $\mathbf{P}_{x\sim X}[\mathrm{err}(f(x;w^{(t)}),g(x))]>\epsilon$. Then, with probability at least $1-\delta$, we have $w^{(T)}\cdot w^{\ast}=\Omega(\beta\gamma\epsilon)\,T$.

Proof.

Denote by $u^{(t)}=-\lambda^{(t)}g^{(t)}$ the update vector used in Algorithm 3. The weight vector at round $T$ is $w^{(T)}=\sum_{t=1}^{T}u^{(t)}$. In what follows, we denote by $\mathcal{F}^{(t)}$ the filtration corresponding to the randomness of the updates of Algorithm 3. We define the martingale $q^{(T)}=\sum_{t=1}^{T}(\mathbf{E}[u^{(t)}\mid\mathcal{F}^{(t-1)}]-u^{(t)})$ with $q^{(0)}=0$.

We first show that, under the assumption that $\mathbf{P}_{x\sim X}[\operatorname*{argmax}(f(x;w^{(t)}))\neq g(x)]>\epsilon$ for all $t\leq T$, it holds that $\sum_{t=1}^{T}\mathbf{E}[u^{(t)}\mid\mathcal{F}^{(t-1)}]\cdot w^{\ast}\geq(\epsilon\gamma\beta/2)\,T$. Using the SLaM gradient expression of Lemma E.3 and the definition of the step size $\lambda^{(t)}$, we obtain $\mathbf{E}[u^{(t)}\mid\mathcal{F}^{(t-1)}]=\mathbf{E}_{x\sim X}[|2\alpha(x)-1|\,(g_{0}(x)-f_{0}(x;w^{(t-1)}))\,x]$. Take any step $t$. We have

$$
\begin{aligned}
\mathbf{E}[u^{(t)}\mid\mathcal{F}^{(t-1)}]\cdot w^{\ast}&=\mathbf{E}_{x\sim X}[|2\alpha(x)-1|\,(g_{0}(x)-f_{0}(x;w^{(t-1)}))\,(x\cdot w^{\ast})]\\
&=\mathbf{E}_{x\sim X}[|2\alpha(x)-1|\,|g_{0}(x)-f_{0}(x;w^{(t-1)})|\,|x\cdot w^{\ast}|]\,,
\end{aligned}
$$

where we used the fact that $(g_{0}(x)-f_{0}(x;w^{(t-1)}))\,\mathrm{sgn}(x\cdot w^{\ast})=|g_{0}(x)-f_{0}(x;w^{(t-1)})|$. This holds because $g_{0}(x)=1$ when $x\cdot w^{\ast}>0$ and $g_{0}(x)=0$ when $x\cdot w^{\ast}<0$, while $f_{0}(x;w)\in(0,1)$. Now, using the $\gamma$-margin assumption on the distribution $D$ and the fact that $|2\alpha(x)-1|\geq\beta$, we obtain

$$
\begin{aligned}
\mathbf{E}[u^{(t)}\mid\mathcal{F}^{(t-1)}]\cdot w^{\ast}&\geq\beta\gamma\,\mathbf{E}_{x\sim X}[|g_{0}(x)-f_{0}(x;w^{(t-1)})|]\\
&\geq\beta\gamma\,\mathbf{E}_{x\sim X}[|g_{0}(x)-f_{0}(x;w^{(t-1)})|\,\mathrm{err}(g(x),f(x;w^{(t-1)}))]\\
&\geq(\beta\gamma/2)\,\mathbf{P}_{x\sim X}[\mathrm{err}(g(x),f(x;w^{(t-1)}))]\geq\beta\gamma\epsilon/2\,,
\end{aligned}
$$

For the penultimate inequality, we used the fact that when $g(x)$ and $f(x;w^{(t-1)})$ disagree, $|g_{0}(x)-f_{0}(x;w^{(t-1)})|\geq 1/2$. Take, for example, the case where $g_{0}(x)=1$. Then $f_{0}(x;w^{(t-1)})$ must be at most $1/2$. Otherwise the model's prediction $\operatorname*{argmax}f(x;w^{(t-1)})$ would also be $0$ and would agree with the prediction of $g(x)$. For the last inequality, we used our assumption that $\mathbf{P}_{x\sim X}[\mathrm{err}(g(x),f(x;w^{(t-1)}))]\geq\epsilon$. We conclude that $\sum_{t=1}^{T}\mathbf{E}[u^{(t)}\mid\mathcal{F}^{(t-1)}]\cdot w^{\ast}\geq(\epsilon\gamma\beta/2)\,T$.

Next, we show that $w^{(T)}$ also achieves good correlation with the optimal direction $w^{\ast}$ with high probability. Since $q^{(t)}$ is a martingale, the Azuma-Hoeffding inequality shows that $w^{(T)}\cdot w^{\ast}$ is not far from its expectation.

Lemma E.4 (Azuma-Hoeffding).

Let $\xi^{(t)}$ be a martingale with bounded increments, i.e., $|\xi^{(t)}-\xi^{(t-1)}|\leq M$. Then $\mathbf{P}[\xi^{(T)}\geq\xi^{(0)}+\lambda]\leq e^{-\lambda^{2}/(2M^{2}T)}$.

Recall that, by Lemma E.3, we have $\mathbf{E}[u^{(t)}\mid\mathcal{F}^{(t-1)}]=\mathbf{E}_{x\sim X}[|2\alpha(x)-1|\,(g_{0}(x)-f_{0}(x;w^{(t-1)}))\,x]$ and

$$u^{(t)}=\mathrm{sgn}(2\alpha(x^{(t)})-1)\,(y_{0}^{(t)}-\mathrm{mix}(f_{0}(x^{(t)};w^{(t-1)});\alpha(x^{(t)})))\,x^{(t)}\,.$$

Observe that $\|x\|_{2}\leq 1$ for all $x$, and that $|y_{0}^{(t)}-\mathrm{mix}(f_{0}(x^{(t)};w^{(t-1)});\alpha(x^{(t)}))|\leq 1$ because both terms lie in $[0,1]$. Hence $\|u^{(t)}\|_{2}\leq 1$. Therefore, $\|\mathbf{E}[u^{(t)}\mid\mathcal{F}^{(t-1)}]-u^{(t)}\|_{2}\leq 2$ with probability $1$. Since $\|w^{\ast}\|_{2}=1$, Cauchy-Schwarz also gives $|\mathbf{E}[u^{(t)}\cdot w^{\ast}\mid\mathcal{F}^{(t-1)}]-u^{(t)}\cdot w^{\ast}|\leq 2$.
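The following small simulation makes the update concrete. It is a sketch under assumptions and not Algorithm 3 itself, which is not reproduced here. It takes:

- $w^{(0)}=0$;
- unit-norm features in $\mathbf{R}^2$ with margin $\gamma$ around $w^{\ast}=(1,0)$;
- a known constant $\alpha(x)=\alpha$;
- teacher labels $g_0(x)=\mathbb{1}\{x\cdot w^{\ast}>0\}$;
- noisy labels $y_0\sim\mathrm{Bernoulli}(\mathrm{mix}(g_0(x);\alpha))$;
- the step size $\lambda^{(t)}=1/r(f_0(x^{(t)};w^{(t-1)});\alpha(x^{(t)}))$. This is our reading of the step size: it turns the SLaM gradient into the displayed $u^{(t)}$.

All function and parameter names are hypothetical.

```python
import numpy as np


def sigmoid(t):
    return 1.0 / (1.0 + np.exp(-t))


def mix(p0, alpha):
    return alpha * p0 + (1.0 - alpha) * (1.0 - p0)


def sample_margin_points(rng, n, w_star, gamma):
    # Unit vectors in R^2 with |x . w_star| >= gamma (rejection sampling).
    points = []
    while len(points) < n:
        theta = rng.uniform(0.0, 2.0 * np.pi)
        x = np.array([np.cos(theta), np.sin(theta)])
        if abs(x @ w_star) >= gamma:
            points.append(x)
    return np.array(points)


def run_rescaled_slam(T=2000, alpha=0.8, gamma=0.1, seed=0):
    rng = np.random.default_rng(seed)
    w_star = np.array([1.0, 0.0])
    xs = sample_margin_points(rng, T, w_star, gamma)
    w = np.zeros(2)
    updates = []
    for x in xs:
        g0 = float(x @ w_star > 0)                 # teacher label
        y0 = float(rng.random() < mix(g0, alpha))  # noisy label
        f0 = sigmoid(w @ x)
        # Rescaled SLaM step: u = sgn(2 alpha - 1) (y0 - mix(f0; alpha)) x.
        u = np.sign(2.0 * alpha - 1.0) * (y0 - mix(f0, alpha)) * x
        updates.append(u)
        w = w + u
    return w, np.array(updates), w_star


w_T, U, w_star = run_rescaled_slam()
test_x = sample_margin_points(np.random.default_rng(1), 5000, w_star, 0.1)
disagreement = np.mean((test_x @ w_T > 0) != (test_x @ w_star > 0))
print("max ||u||:", np.linalg.norm(U, axis=1).max(), "test disagreement:", disagreement)
```

Every update has norm at most one, which is exactly the bounded-increment property used for Azuma-Hoeffding. The printed disagreement depends on the seed and is only illustrative.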

We apply Lemma E.4 to the martingale $q^{(t)}\cdot w^{\ast}$, whose increments are bounded by $M=2$, with $\lambda=(\beta\gamma\epsilon/4)\,T$. Using $q^{(0)}=0$, we obtain $\mathbf{P}[q^{(T)}\cdot w^{\ast}\geq(\beta\gamma\epsilon/4)\,T]\leq e^{-\beta^{2}\gamma^{2}\epsilon^{2}T/128}$. Therefore, for any $T$ larger than $128\log(1/\delta)/(\beta^{2}\gamma^{2}\epsilon^{2})$, with probability at least $1-\delta$ it holds that $q^{(T)}\cdot w^{\ast}<(\beta\gamma\epsilon/4)\,T$. Since $w^{(T)}=\sum_{t=1}^{T}\mathbf{E}[u^{(t)}\mid\mathcal{F}^{(t-1)}]-q^{(T)}$, we combine this with our previously obtained bound for the expected updates, $\sum_{t=1}^{T}\mathbf{E}[u^{(t)}\mid\mathcal{F}^{(t-1)}]\cdot w^{\ast}\geq(\beta\gamma\epsilon/2)\,T$. This yields $w^{(T)}\cdot w^{\ast}\geq(\beta\gamma\epsilon/4)\,T$.
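The constants in this Azuma-Hoeffding step can be checked directly. With increments bounded by $M=2$ and deviation $\lambda=(\beta\gamma\epsilon/4)T$, the tail bound is $e^{-\beta^{2}\gamma^{2}\epsilon^{2}T/128}$. This is at most $\delta$ once $T\geq 128\log(1/\delta)/(\beta^{2}\gamma^{2}\epsilon^{2})$. The helper names below are hypothetical, and the parameter values are arbitrary examples.

```python
import math


def azuma_tail(T, beta, gamma, eps, M=2.0):
    # Lemma E.4 with lambda = (beta * gamma * eps / 4) * T and increments <= M.
    lam = beta * gamma * eps * T / 4.0
    return math.exp(-lam ** 2 / (2.0 * M ** 2 * T))


def iterations_for_claim_1(beta, gamma, eps, delta):
    # Smallest integer T >= 128 log(1/delta) / (beta^2 gamma^2 eps^2).
    return math.ceil(128.0 * math.log(1.0 / delta) / (beta * gamma * eps) ** 2)


T = iterations_for_claim_1(beta=0.6, gamma=0.1, eps=0.1, delta=0.01)
print(T, azuma_tail(T, 0.6, 0.1, 0.1))
```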

Claim 2.

Fix any $T\geq 1$. Then, with probability at least $99\%$, we have $\|w^{(T)}\|_{2}=O(\sqrt{T})$.

Proof.

We have $\|w^{(T)}\|_{2}^{2}=\|w^{(T-1)}\|_{2}^{2}+2u^{(T)}\cdot w^{(T-1)}+\|u^{(T)}\|_{2}^{2}$. Unrolling the iteration, we obtain

$$\|w^{(T)}\|_{2}^{2}=2\sum_{t=1}^{T}u^{(t)}\cdot w^{(t-1)}+\sum_{t=1}^{T}\|u^{(t)}\|_{2}^{2}\leq 2\sum_{t=1}^{T}u^{(t)}\cdot w^{(t-1)}+T\,,\tag{6}$$

where we used the fact that $\|u^{(t)}\|_{2}\leq 1$ because $\|x^{(t)}\|_{2}\leq 1$ (see the proof of Claim 1). We first show that $\sum_{t=1}^{T}\mathbf{E}[u^{(t)}\mid\mathcal{F}^{(t-1)}]\cdot w^{(t-1)}$ is at most $O(T)$. We have

$$
\begin{aligned}
\mathbf{E}[u^{(t)}\mid\mathcal{F}^{(t-1)}]\cdot w^{(t-1)}&=\mathbf{E}_{x\sim X}[|2\alpha(x)-1|\,(g_{0}(x)-f_{0}(x;w^{(t-1)}))\,(x\cdot w^{(t-1)})]\\
&\leq\mathbf{E}_{x\sim X}[\max\{(g_{0}(x)-f_{0}(x;w^{(t-1)}))\,(x\cdot w^{(t-1)}),\,0\}]\,,
\end{aligned}
$$

where we used $0\leq|2\alpha(x)-1|\leq 1$.

We will show that for every $x$ it holds that

$$(g_{0}(x)-f_{0}(x;w^{(t-1)}))\,(x\cdot w^{(t-1)})\leq\frac{1}{e}\,.$$

Fix some $x$ and let $s=w^{(t-1)}\cdot x$. Assume first that $g_{0}(x)=1$. Then we have

$$(g_{0}(x)-f_{0}(x;w^{(t-1)}))\,(x\cdot w^{(t-1)})=\left(1-\frac{1}{1+e^{-s}}\right)s=s\,\frac{e^{-s}}{1+e^{-s}}\leq\frac{1}{e}\,,$$

where we used the facts that $s\,\frac{e^{-s}}{1+e^{-s}}\leq 0$ for $s\leq 0$ and $s\,\frac{e^{-s}}{1+e^{-s}}\leq se^{-s}\leq 1/e$ for $s\geq 0$. The latter uses the elementary inequality $ze^{-z}\leq 1/e$ for all $z\in\mathbf{R}$. When $g_{0}(x)=0$, we similarly have

$$(g_{0}(x)-f_{0}(x;w^{(t-1)}))\,(x\cdot w^{(t-1)})=-\frac{s}{1+e^{-s}}\leq\frac{1}{e}\,,$$

where we used the fact that $-\frac{s}{1+e^{-s}}\leq 0$ when $s\geq 0$, and $-\frac{s}{1+e^{-s}}\leq -s/e^{-s}=-se^{s}$ when $s\leq 0$. For the final inequality, we again used $ze^{-z}\leq 1/e$ for all $z\in\mathbf{R}$, with $z=-s$.
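The following is a quick numerical sanity check of the pointwise bound, not a proof. It assumes a logistic student $f_0=\sigma(s)$ with $s=w^{(t-1)}\cdot x$. Over a grid of $s$ values, it checks that both $(1-\sigma(s))\,s$ (case $g_0(x)=1$) and $-\sigma(s)\,s$ (case $g_0(x)=0$) stay below $1/e$.

```python
import numpy as np

s = np.linspace(-50.0, 50.0, 200001)
sigma = 1.0 / (1.0 + np.exp(-s))
case_g0_is_1 = (1.0 - sigma) * s  # equals s e^{-s} / (1 + e^{-s})
case_g0_is_0 = -sigma * s         # equals -s / (1 + e^{-s})
print(case_g0_is_1.max(), case_g0_is_0.max(), 1.0 / np.e)
```

Both maxima come out near $0.2785$, attained where $s=1+e^{-s}$. This is noticeably below $1/e\approx 0.368$, so the constant in the argument is not tight.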

Therefore, $\mathbf{E}[u^{(t)}\mid\mathcal{F}^{(t-1)}]\cdot w^{(t-1)}\leq 1/e$ and $\sum_{t=1}^{T}\mathbf{E}[u^{(t)}\mid\mathcal{F}^{(t-1)}]\cdot w^{(t-1)}\leq T/e$. Using the decomposition of Equation 6, linearity of expectation, and the tower rule for conditional expectations, we conclude that $\mathbf{E}[\|w^{(T)}\|_{2}^{2}]\leq(2/e+1)T$. By Markov's inequality, with probability at least $99\%$ it holds that $\|w^{(T)}\|_{2}^{2}\leq 100(2/e+1)T=O(T)$, or equivalently $\|w^{(T)}\|_{2}=O(\sqrt{T})$.
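Equation 6 is a purely algebraic consequence of $w^{(T)}=\sum_{t}u^{(t)}$ with $w^{(0)}=0$. The sketch below verifies the identity on arbitrary updates of norm at most one, together with the resulting upper bound. Random vectors stand in for the SLaM updates.

```python
import numpy as np

rng = np.random.default_rng(0)
T, d = 500, 5
U = rng.normal(size=(T, d))
U /= np.maximum(1.0, np.linalg.norm(U, axis=1, keepdims=True))  # ||u_t|| <= 1

w = np.zeros(d)
cross = 0.0  # running sum of u_t . w_{t-1}
sq = 0.0     # running sum of ||u_t||^2
for u in U:
    cross += u @ w
    sq += u @ u
    w = w + u

lhs = w @ w
print(lhs, 2.0 * cross + sq, 2.0 * cross + T)
```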

We can now finish the proof of Theorem 5.1. Assume, in order to reach a contradiction, that for all $t\leq T$ it holds that $\mathbf{P}_{x\sim X}[\mathrm{err}(f(x;w^{(t)}),g(x))]>\epsilon$. Pick $T$ larger than a sufficiently large constant multiple of $1/(\epsilon^{2}\gamma^{2}\beta^{2})$. Apply Claim 1 with $\delta=1/200$, and run the Markov step of Claim 2 with failure probability $1/200$ (this only changes the hidden constant). By a union bound over Claims 1 and 2, with probability at least $99\%$ we have $w^{(T)}\cdot w^{\ast}/\|w^{(T)}\|_{2}\geq\Omega(\beta\gamma\epsilon\sqrt{T})$, which our choice of $T$ makes larger than $1$. This is a contradiction, since by Cauchy-Schwarz $w^{(T)}\cdot w^{\ast}/\|w^{(T)}\|_{2}\leq\|w^{\ast}\|_{2}=1$. Therefore, with probability at least $99\%$, there is some $t\leq T$ for which $\mathbf{P}_{x\sim X}[\mathrm{err}(f(x;w^{(t)}),g(x))]\leq\epsilon$.

Figure 8: The landscape and gradient field of the population student-label mixing loss for a simple problem with two-dimensional features whose ground truth is a halfspace. The landscape is non-convex. However, the corresponding gradient field “points towards the optimal direction”, so gradient descent converges to the global minimizer. A potential issue is that the landscape contains regions where the gradients almost vanish, which could trap the student's gradient iteration. To handle this issue, Algorithm 3 multiplies the gradient of SLaM by an appropriate step size.
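The sketch below illustrates one region where the gradients vanish, as described in the caption: the student becomes confidently wrong, i.e. $s=w\cdot x\to-\infty$ while $g_0(x)=1$. It uses a single example with $\alpha=0.8$ and compares two magnitudes per unit $\|x\|$:

- the expected SLaM gradient, $r(f_0;\alpha)\,|2\alpha-1|\,|f_0-g_0|$ (Lemma E.3);
- the expected rescaled update, $|2\alpha-1|\,|g_0-f_0|$, obtained with step size $1/r$.

The step size $1/r$ is our reading of Algorithm 3, not a quoted definition, and the helper names are hypothetical.

```python
import numpy as np


def sigmoid(t):
    return 1.0 / (1.0 + np.exp(-t))


def mix(p0, alpha):
    return alpha * p0 + (1.0 - alpha) * (1.0 - p0)


def ratio_r(f0, alpha):
    # r(f0; alpha) from Lemma E.3; it tends to 0 as f0 saturates at 0 or 1.
    q0 = mix(f0, alpha)
    return f0 * (1.0 - f0) / (q0 * (1.0 - q0)) * abs(2.0 * alpha - 1.0)


alpha, g0 = 0.8, 1.0
for s in (0.0, -5.0, -10.0, -20.0):
    f0 = sigmoid(s)
    raw = ratio_r(f0, alpha) * abs(2.0 * alpha - 1.0) * abs(f0 - g0)
    rescaled = abs(2.0 * alpha - 1.0) * abs(g0 - f0)
    print(f"s={s:6.1f}  raw gradient={raw:.2e}  rescaled update={rescaled:.3f}")
```

The raw gradient decays roughly like $e^{s}$, while the rescaled update stays close to $|2\alpha-1|=0.6$. This matches the motivation given in the caption for the step size.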