
Learning Consistent Deep Generative Models from Sparse Data via Prediction Constraints

Gabriel Hope (UC Irvine), Madina Abdrakhmanova (Nazarbayev University), Xiaoyin Chen (UC Irvine), Michael C. Hughes (Tufts University), Erik B. Sudderth (UC Irvine)
Abstract

We develop a new framework for learning variational autoencoders and other deep generative models that balances generative and discriminative goals. Our framework optimizes model parameters to maximize a variational lower bound on the likelihood of observed data, subject to a task-specific prediction constraint that prevents model misspecification from leading to inaccurate predictions. We further enforce a consistency constraint, derived naturally from the generative model, that requires predictions on reconstructed data to match those on the original data. We show that these two contributions – prediction constraints and consistency constraints – lead to promising image classification performance, especially in the semi-supervised scenario where category labels are sparse but unlabeled data is plentiful. Our approach enables advances in generative modeling to directly boost semi-supervised classification performance, an ability we demonstrate by augmenting deep generative models with latent variables capturing spatial transformations.

1 Introduction

We develop broadly applicable methods for learning flexible models of high-dimensional data, like images, that are paired with (discrete or continuous) labels. We are particularly interested in semi-supervised learning (Zhu, 2005, Oliver et al., 2018) from data that is sparsely labeled, a common situation in practice due to the cost or privacy concerns associated with data annotation. Given a large and sparsely labeled dataset, we seek a single probabilistic model that simultaneously makes good predictions of labels and provides a high-quality generative model of the high-dimensional input data. Strong generative models are valuable because they can allow incorporation of domain knowledge, can address partially missing or corrupted data, and can be visualized to improve interpretability.

Prior approaches for the semi-supervised learning of deep generative models include methods based on variational autoencoders (VAEs) (Kingma et al., 2014, Siddharth et al., 2017), generative adversarial networks (GANs) (Dumoulin et al., 2017, Kumar et al., 2017), and hybrids of the two (Larsen et al., 2016, de Bem et al., 2018, Zhang et al., 2019). While these all allow sampling of data, a major shortcoming of these approaches is that they do not adequately use labels to inform the generative model. Furthermore, GAN-based approaches lack the ability to evaluate the learned probability density function, which can be important for tasks such as model selection and anomaly detection.

This paper develops a framework for training prediction constrained variational autoencoders (PC-VAEs) that minimize application-motivated loss functions in the prediction of labels, while simultaneously learning high-quality generative models of the raw data. Our approach is inspired by the prediction-constrained framework recently proposed for learning supervised topic models of “bag of words” count data (Hughes et al., 2018), but differs in four major ways. First, we develop scalable algorithms for learning a much larger and richer family of deep generative models. Second, we capture uncertainty in latent variables rather than simply using point estimates. Third, we allow more flexible specification of loss functions. Finally, we show that the generative model structure leads to a natural consistency constraint vital for semi-supervised learning from very sparse labels.

Our experiments demonstrate that consistent prediction-constrained (CPC) VAE training leads to prediction performance competitive with state-of-the-art discriminative methods on fully-labeled datasets, and excels over these baselines when given semi-supervised datasets where labels are rare.

[Figure 1 panels, left to right: VAE-then-MLP, PC-VAE, CPC-VAE, M2, M2 (C=14), CPC-VAE (C=14). Top-row accuracies: 77.9%, 78.1%, 98.4%, 98.1%, 80.6%, 98.5%. Bottom-row accuracies: 83.8%, 98.2%, 98.4%, 98.1%, 96.4%, 98.1%.]
Figure 1: Predictions from SSL VAE methods on the half-moon binary classification task, with accuracy in the lower corner. Each dot indicates a 2-dim. feature vector, colored by predicted binary label. Top: 6 labeled examples (diamond markers), 994 unlabeled. Bottom: 100 labeled, 900 unlabeled. The first 4 columns use $C=2$ encoding dimensions, the last 2 use $C=14$. M2 (Kingma et al., 2014) classification accuracy deteriorates when model capacity increases from 2 to 14, especially with only 6 labels (a drop from 98.1% to 80.6% accuracy). In contrast, our CPC VAE approach is reliable at any model capacity, as it better aligns generative and discriminative goals.

2 Background: Deep Generative Models and Semi-supervision

We now describe VAEs as deep generative models and review previous methods for semi-supervised learning (SSL) of VAEs, highlighting weaknesses that we later improve upon. We assume all SSL tasks provide two training datasets: an unsupervised (or unlabeled) dataset $\mathcal{D}^{U}$ of $N$ feature vectors $x$, and a supervised (or labeled) dataset $\mathcal{D}^{S}$ containing $M$ pairs $(x,y)$ of features $x$ and label $y\in\mathcal{Y}$. Labels are often sparse ($N\gg M$) and can be discrete or continuous.

2.1 Unsupervised Generative Modeling with the VAE

The variational autoencoder (Kingma & Welling, 2014) is an unsupervised model with two components: a generative model and an inference model. The generative model defines for each example a joint distribution $p_{\theta}(x,z)$ over “features” (observed vector $x\in\mathbb{R}^{D}$) and “encodings” (hidden vector $z\in\mathbb{R}^{C}$). The “inference model” of the VAE defines an approximate posterior $q_{\phi}(z\mid x)$, which is trained to be close to the true posterior ($q_{\phi}(z\mid x)\approx p_{\theta}(z\mid x)$) but much easier to evaluate. As in Kingma & Welling (2014), we assume the following conditional independence structure:

$$p_{\theta}(x,z)=\mathcal{N}(z\mid 0,I_{C})\cdot\mathcal{F}(x\mid\mu_{\theta}(z),\sigma_{\theta}(z)),\qquad q_{\phi}(z\mid x)=\mathcal{N}(z\mid\mu_{\phi}(x),\sigma_{\phi}(x)). \tag{1}$$

The likelihood $\mathcal{F}$ is often multivariate normal, but other distributions may give robustness to outliers. The (deterministic) functions $\mu_{\theta}$ and $\sigma_{\theta}$, with trainable parameters $\theta$, define the mean and covariance of the likelihood. Given any observation $x$, the posterior of $z$ is approximated as normal with mean $\mu_{\phi}$ and (diagonal) covariance $\sigma_{\phi}$ parameterized by $\phi$. These functions can be represented as multi-layer perceptrons (MLPs), convolutional neural networks (CNNs), or other (deep) neural networks.

We would ideally learn generative parameters $\theta$ by maximizing the marginal likelihood of features $x$, integrating over the latent variable $z$. Since this is intractable, we instead maximize a variational lower bound:

$$\max_{\theta,\phi}\ \sum_{x\in\mathcal{D}}\mathcal{L}^{\text{VAE}}(x;\theta,\phi),\qquad \mathcal{L}^{\text{VAE}}(x;\theta,\phi)=\mathbb{E}_{q_{\phi}(z|x)}\left[\log\frac{p_{\theta}(x,z)}{q_{\phi}(z|x)}\right]\leq\log p_{\theta}(x). \tag{2}$$

This expectation can be evaluated via Monte Carlo samples from the inference model $q_{\phi}(z|x)$. Gradients with respect to $\theta,\phi$ can be similarly estimated by the reparameterization “trick” of representing $q_{\phi}(z\mid x)$ as a linear transformation of standard normal variables (Kingma & Welling, 2014).
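As a concrete illustration, the sketch below (our own, in PyTorch; `encoder` and `decoder` are hypothetical modules returning Gaussian means and log-scales, not the networks used in our experiments) computes a single-sample Monte Carlo estimate of $\mathcal{L}^{\text{VAE}}$ in Eq. (2) using the reparameterization trick.

```python
import torch
from torch.distributions import Normal, Independent, kl_divergence

def elbo_estimate(x, encoder, decoder):
    """Single-sample Monte Carlo estimate of L^VAE(x) in Eq. (2).

    `encoder(x)` is assumed to return (mu_phi, log_sigma_phi) and
    `decoder(z)` to return (mu_theta, log_sigma_theta); both are
    hypothetical stand-ins for the MLP/CNN networks.
    """
    mu_q, log_sig_q = encoder(x)
    q_z = Independent(Normal(mu_q, log_sig_q.exp()), 1)          # q_phi(z | x)
    z = q_z.rsample()                                            # reparameterized sample
    mu_p, log_sig_p = decoder(z)
    log_lik = Normal(mu_p, log_sig_p.exp()).log_prob(x).sum(-1)  # log F(x | z)
    prior = Independent(Normal(torch.zeros_like(z), torch.ones_like(z)), 1)
    # E_q[log p(x,z) - log q(z|x)] = E_q[log p(x|z)] - KL(q || p)
    return log_lik - kl_divergence(q_z, prior)
```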

Throughout this paper, we denote variational parameters by $\phi$. Because the factorization of $q$ changes for more complex models, we will write $\phi^{z|x}$ to denote the parameters specific to the factor $q(z|x)$.

2.2 Two-Stage SSL: Maximize Feature Likelihood then Train Predictor

One way to employ the VAE for a semi-supervised task is a two-stage “VAE-then-MLP” approach. First, we train a VAE to maximize the unsupervised likelihood (2) of all observed features $x$ (both labeled $\mathcal{D}^{S}$ and unlabeled $\mathcal{D}^{U}$). Second, we define a label-from-code predictor $\hat{y}_{w}(z)$ that maps each learned code representation $z$ to a predicted label $y\in\mathcal{Y}$. We use an MLP with weights $w$, though any predictor could be used. Let $\ell_{S}(y,\hat{y})$ be a loss function, such as cross-entropy, appropriate for the prediction task. We train the predictor to minimize the loss: $\min_{w}\sum_{x,y\in\mathcal{D}^{S}}\mathbb{E}_{q_{\phi}(z|x)}\left[\ell_{S}(y,\hat{y}_{w}(z))\right]$. Importantly, this second stage uses only the small labeled dataset and relies on fixed parameters $\phi$ from stage one.
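A minimal sketch of this second stage, assuming a hypothetical `encoder` that returns the mean and log-scale of $q_{\phi}(z|x)$ and using cross-entropy as $\ell_{S}$ (ours, not the reference implementation):

```python
import torch
import torch.nn as nn

def train_predictor(encoder, labeled_loader, code_dim, n_classes, epochs=50):
    """Stage two of VAE-then-MLP: fit y_hat_w(z) on the labeled set only,
    holding the stage-one inference parameters phi fixed."""
    predictor = nn.Sequential(nn.Linear(code_dim, 128), nn.ReLU(),
                              nn.Linear(128, n_classes))
    opt = torch.optim.Adam(predictor.parameters(), lr=1e-3)
    loss_fn = nn.CrossEntropyLoss()                        # ell_S(y, y_hat)
    for _ in range(epochs):
        for x, y in labeled_loader:
            with torch.no_grad():                          # phi stays frozen
                mu, log_sig = encoder(x)
                z = mu + log_sig.exp() * torch.randn_like(mu)  # z ~ q_phi(z|x)
            loss = loss_fn(predictor(z), y)                # 1-sample estimate of E_q[ell_S]
            opt.zero_grad()
            loss.backward()
            opt.step()
    return predictor
```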

While “VAE-then-MLP” is a simple common baseline (Kingma et al., 2014), it has a key disadvantage: Labels are only used in the second stage, and thus a misspecified generative model in stage one will likely produce inferior predictions. Fig. 1 illustrates this weakness.

2.3 Semi-supervised VAEs: Maximize Joint Likelihood of Labels and Features

To overcome the weakness of the two-stage approach, previous work by Kingma et al. (2014) presented a VAE-inspired model called “M2” focused on the joint generative modeling of labels $y$ and data $x$. M2 has two components: a generative model $p_{\theta}(x,y,z)$ and an inference model $q_{\phi}(y,z\mid x)$. Their generative model is factorized to sample labels (with frequencies $\pi$) first, and then features $x$:

$$p_{\theta}(x,y,z)=\mathcal{N}(z\mid 0,I_{C})\cdot\text{Cat}(y\mid\pi)\cdot\mathcal{F}(x\mid\mu_{\theta}(y,z),\sigma_{\theta}(y,z)). \tag{3}$$

The M2 inference model sets $q_{\phi}(y,z\mid x)=q_{\phi^{y|x}}(y\mid x)\,q_{\phi^{z|x,y}}(z\mid x,y)$, where $\phi=(\phi^{y|x},\phi^{z|x,y})$.

To train M2, Kingma et al. (2014) maximize the likelihood of all observations (labels and features):

$$\max_{\theta,\phi^{y|x},\phi^{z|x,y}}\quad \sum_{x,y\in\mathcal{D}^{S}}\mathcal{L}^{S}(x,y;\theta,\phi^{z|x,y})+\sum_{x\in\mathcal{D}^{U}}\mathcal{L}^{U}(x;\theta,\phi^{y|x},\phi^{z|x,y}). \tag{4}$$

The first, “supervised” term in Eq. (4) is a variational bound for the feature-and-label joint likelihood:

$$\mathcal{L}^{S}(x,y;\theta,\phi^{z|x,y})=\mathbb{E}_{q_{\phi^{z|x,y}}(z|x,y)}\left[\log\frac{p_{\theta}(x,y,z)}{q_{\phi^{z|x,y}}(z|x,y)}\right]\leq\log p_{\theta}(x,y). \tag{5}$$

The second, “unsupervised” term is a variational lower bound for the features-only likelihood $\log p_{\theta}(x)\geq\mathcal{L}^{U}$, where $\mathcal{L}^{U}=\mathbb{E}_{q_{\phi}(y,z|x)}\left[\log\frac{p_{\theta}(x,y,z)}{q_{\phi}(y,z|x)}\right]$ can be simply expressed in terms of $\mathcal{L}^{S}$:

$$\mathcal{L}^{U}(x;\theta,\phi^{y|x},\phi^{z|x,y})=\sum_{y\in\mathcal{Y}}q_{\phi^{y|x}}(y\mid x)\left(\mathcal{L}^{S}(x,y;\theta,\phi^{z|x,y})-\log q_{\phi^{y|x}}(y\mid x)\right). \tag{6}$$

As with the unsupervised VAE, both terms in the objective can be computed via Monte Carlo sampling from the variational posterior, and gradients can be estimated via the reparameterization trick.
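To make the cost of Eq. (6) concrete, the following hedged sketch (ours; `log_q_y` and `supervised_bound` are stand-ins for M2's discriminator and its $\mathcal{L}^{S}$ estimator) evaluates $\mathcal{L}^{U}$ by explicitly enumerating all class labels, which is the source of the runtime issue discussed below.

```python
import torch
import torch.nn.functional as F

def m2_unsup_bound(x, log_q_y, supervised_bound, n_classes):
    """Monte Carlo estimate of L^U(x) in Eq. (6) for the M2 model.

    `log_q_y`: log q_phi(y|x), shape (batch, n_classes).
    `supervised_bound(x, y_onehot)`: 1-sample estimate of L^S(x, y), shape (batch,).
    Both inputs are assumptions standing in for M2's networks.
    """
    q_y = log_q_y.exp()
    bound = torch.zeros_like(log_q_y[:, 0])
    for c in range(n_classes):                       # explicit sum over every label
        y = F.one_hot(torch.full((x.shape[0],), c), n_classes).float()
        bound = bound + q_y[:, c] * (supervised_bound(x, y) - log_q_y[:, c])
    return bound
```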

M2’s prediction dilemma and heuristic fix.

After training parameters $\theta,\phi$, we need to predict labels $y$ given test data $x$. M2’s structure assumes we make predictions via the inference model’s discriminator density $q_{\phi^{y|x}}(y\mid x)$. However, the discriminator’s parameter $\phi^{y|x}$ is only informed by the unlabeled data when using the objective above (it is not used to compute $\mathcal{L}^{S}$). We cannot expect accurate predictions from a parameter that does not touch any labeled examples in the training set.

To partially overcome this issue, Kingma et al. (2014) and later work use a weighted objective:

$$\max_{\theta,\phi}\ \sum_{x,y\in\mathcal{D}^{S}}\left(\alpha\log q_{\phi^{y|x}}(y\mid x)+\lambda\mathcal{L}^{S}(x,y;\theta,\phi^{z|x,y})\right)+\sum_{x\in\mathcal{D}^{U}}\mathcal{L}^{U}(x;\theta,\phi^{y|x},\phi^{z|x,y}). \tag{7}$$

This objective biases the inference model’s discriminator to do well on the labeled set via an extra loss term (weighted by hyperparameter $\alpha>0$). We can further include $\lambda>0$ to balance the supervised and unsupervised terms. Originally, Kingma et al. (2014) fixed $\lambda=1$ and tuned $\alpha$ to achieve good performance. Later, Siddharth et al. (2017) tuned $\lambda$ to improve performance. Maaløe et al. (2016) used this same $\alpha\log q(y\mid x)$ term for labeled data to train VAEs with auxiliary variables.

Disadvantage: What Justification? While the $\mathcal{L}^{S}$ and $\mathcal{L}^{U}$ terms in Eq. (7) have a rigorous justification as maximizing the data likelihood under the assumed generative model, the first term ($\alpha\log q(y\mid x)$) is not justified by the generative or inference model. In particular, suppose the training data were fully labeled: we would ignore the $\mathcal{L}^{U}$ terms altogether, and the remaining terms would decouple the parameters $\theta,\phi^{z|x,y}$ from the discriminator parameters $\phi^{y|x}$. This is deeply unsatisfying: we want a single model guided by both generative and discriminative goals, not two separate models. Even in partially-labeled scenarios, including this $\alpha$ term does not adequately balance generative and discriminative goals, as we demonstrate in later examples. An overly flexible yet misspecified generative model may go astray and compromise predictions.

Disadvantage: Runtime Cost. Another disadvantage is that the computation of $\mathcal{L}^{U}$ in Eq. (6) is expensive. If labels are discrete, computing this sum exactly is possible but requires a sum over all $L=|\mathcal{Y}|$ possible class labels, computing a Monte Carlo estimate of $\mathcal{L}^{S}$ for each one. In Appendix C, we demonstrate that in practice, M2’s runtime is roughly $L/2$ times longer than that of our consistent prediction-constrained approach. While further Monte Carlo approximations could avoid the explicit sum over classes in Eq. (6), they may make gradients far too noisy.

Extensions. Siddharth et al. (2017) showed how $\mathcal{L}^{S}$ and $\mathcal{L}^{U}$ could be extended to any desired conditional independence structure for $q_{\phi}(y,z\mid x)$, generalizing the label-then-code factorization $q_{\phi}(y\mid x)\,q_{\phi}(z\mid x,y)$ of Kingma et al. (2014). While importance sampling leads to likelihood bounds, the overall objective still has two undesirable traits. First, it is expensive, requiring either marginalization of $y$ to compute $\mathcal{L}^{U}$ in Eq. (6) or marginalization of $z$ to compute $q(y|x)=\int q_{\phi}(y,z|x)\,dz$. Second, the approach requires the heuristic inclusion of the discriminator loss $\alpha\log q(y\mid x)$. While recent parallel work by Gordon & Hernández-Lobato (2020) also tries to improve SSL for VAEs, their approach couples discriminative and generative terms only distantly through a joint prior over parameters and still requires expensive sums over labels when computing generative likelihoods.

[Figure 2 panels, left to right: VAE-then-MLP (54.9%), Supervised VAE (66.2%), PC-VAE (74.1%), CPC-VAE (81.1%), M2 (69.1%).]
Figure 2: Semi-supervised learning of 2-dim. encodings of MNIST digits, with accuracy in the lower corner. All methods use 100 labeled examples and 49,900 unlabeled examples. Each observed image $x$ is mapped to its most likely 2-dim. encoding $z$ and colored by true label $y$. Labeled examples are emphasized. Where applicable, we also show class decision boundaries. Baselines (from left): 2-stage unsupervised VAE-then-MLP (Sec. 2.2) and a “supervised” VAE maximizing the joint likelihood $\log p(x,y)$ (a special case of our PC method with $\lambda=1$, Sec. 3.1). Our methods: prediction constrained VAE (PC-VAE with $\lambda=25$, Sec. 3.1) and consistent prediction constrained VAE (CPC-VAE, Sec. 3.2). Competitors: M2 from Kingma et al. (2014), which intentionally decouples label $y$ from “style” $z$, has limited accuracy due to an imbalance of discriminative and generative goals.

3 Prediction-Constrained Learning with Consistency

We now highlight two experiments that demonstrate disadvantages of prior SSL methods, and contrast them with our new approaches. In Fig. 1 we show the predictive accuracy of several SSL methods on the widely-used “half-moon” task, where the goal is to predict a binary label $y$ given 2-dimensional features $x$. We focus on the top row, which shows results given only 6 labeled examples (3 of each class) but hundreds of unlabeled examples. Notably, while M2 has 98.1% accuracy with a small encoding space ($C=2$), if the generative model is too flexible ($C=14$) it learns overly complex structure that does not help label-from-feature predictions, dropping accuracy to only 80.6%. In contrast, our consistent prediction constrained (CPC) VAE gets over 98% accuracy with either $C=2$ or $C=14$. We have verified it maintains 98% even at $C=50$, while M2 shows further instability.

Second, in Fig. 2 we show SSL methods for classifying images of MNIST digits (LeCun et al., 2010), given only 10 labeled examples per digit. We seek models with highly accurate label-from-feature predictions, as well as interpretable relationships between the encoding $z$ and these predictions. When forced to use a 2-dimensional latent space, M2 has worse accuracy and (by design) no apparent relationship between encoding $z$ and label $y$. In contrast, our CPC approach offers noticeable advantages over all baselines in both accuracy and interpretability of the encoding space.

3.1 Prediction Constrained Training for VAEs

We develop a framework for jointly learning a strong generative model of features $x$, and making label-given-feature predictions $\hat{y}(x)$ of uncompromised quality, by requiring predictions to meet a user-specified quality threshold. Our prediction constrained training objective enables end-to-end estimation of all parameters while incorporating the same task-specific prediction rules and loss functions that will be used in heldout evaluation (“test”) scenarios. Our goals are similar to previous work on end-to-end approximate inference for task-specific losses with simpler probabilistic models (Lacoste-Julien et al., 2011, Stoyanov et al., 2011), but our approach yields simpler algorithms.

Generative model. Our generative model does not include labels $y$, only features $x$ and encodings $z$. Their joint distribution $p_{\theta}(x,z)$ factorizes as the unsupervised VAE of Eq. (1), and we also use the inference model $q_{\phi}(z\mid x)$ defined in Eq. (1). While M2 included the labels $y$ in its generative model (Kingma et al., 2014), our goals are different: we wish to make label-given-feature predictions, but we are not interested in label marginals or other distributions over $y$ that do not condition on $x$.

Label-from-feature prediction. To predict labels $y$ from features $x$, we use a predictor similar to the two-stage method of Sec. 2.2. We first sample an encoding $z\sim q_{\phi}(z|x)$ from the learned inference model, and then transform this encoding $z$ to a label via the predictor function $\hat{y}_{w}(z)$ with parameter $w$. By sharing the random variable $z$, the generative model is involved in label-from-feature predictions.

Constrained PC objective. Unlike the two-stage model, our approach does not do post-hoc prediction with a previously learned generative model. Instead, we train the predictor simultaneously with the generative model via a new, prediction-constrained (PC) objective:

$$\max_{\theta,\phi^{z|x},w}\quad \sum_{x\in\mathcal{D}^{U}\cup\mathcal{D}^{S}}\mathcal{L}^{\text{VAE}}(x;\theta,\phi^{z|x}),\qquad \textnormal{subj. to:}\ \frac{1}{M}\sum_{x,y\in\mathcal{D}^{S}}\underbrace{\mathbb{E}_{q_{\phi}(z|x)}[\ell_{S}(y,\hat{y}_{w}(z))]}_{\mathcal{P}(x,y;\phi^{z|x},w)}\leq\epsilon. \tag{8}$$

The constraint requires that any feasible solution achieve average prediction loss less than $\epsilon$ on the labeled training set. Both the loss function and the scalar threshold $\epsilon>0$ can be set to reflect task-specific needs (e.g., classification must have a certain false positive rate or overall accuracy). The loss function may be any differentiable function, and need not equal the log-likelihood of discrete labels as assumed by previous work specialized to supervision of topic models (Hughes et al., 2018).

Unconstrained PC objective. Using the KKT conditions, we define an equivalent unconstrained objective that maximizes the unsupervised likelihood but penalizes inaccurate label predictions:

$$\max_{\theta,\phi^{z|x},w}\quad \sum_{x\in\mathcal{D}^{U}\cup\mathcal{D}^{S}}\mathcal{L}^{\text{VAE}}(x;\theta,\phi^{z|x})-\lambda\sum_{x,y\in\mathcal{D}^{S}}\mathcal{P}(x,y;\phi^{z|x},w). \tag{9}$$

Here $\lambda$ is a positive Lagrange multiplier chosen to ensure that the target prediction constraint is achieved; smaller loss tolerances $\epsilon$ require larger penalty multipliers $\lambda$. This PC objective, and gradients for parameters $\theta,\phi,w$, can be estimated via Monte Carlo samples from $q_{\phi}(z\mid x)$.
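For illustration, a one-sample estimate of the negated objective in Eq. (9) for a mixed minibatch might be implemented as below (a sketch under the assumption of Gaussian likelihoods and a cross-entropy prediction loss; `encoder`, `decoder`, and `predictor` are hypothetical modules):

```python
import torch
import torch.nn.functional as F

def pc_loss(x_unlab, x_lab, y_lab, encoder, decoder, predictor, lam):
    """Negative of the unconstrained PC objective in Eq. (9), estimated
    with one z-sample per example (minimize this quantity)."""
    def neg_elbo(x):
        mu, log_sig = encoder(x)
        z = mu + log_sig.exp() * torch.randn_like(mu)          # z ~ q_phi(z|x)
        mu_x, log_sig_x = decoder(z)
        log_lik = torch.distributions.Normal(mu_x, log_sig_x.exp()).log_prob(x).sum(-1)
        # Closed-form KL(q_phi(z|x) || N(0, I))
        kl = -0.5 * (1 + 2 * log_sig - mu ** 2 - (2 * log_sig).exp()).sum(-1)
        return kl - log_lik, z

    nll_u, _ = neg_elbo(x_unlab)
    nll_s, z_lab = neg_elbo(x_lab)
    pred_loss = F.cross_entropy(predictor(z_lab), y_lab, reduction='none')   # P(x, y)
    return nll_u.sum() + nll_s.sum() + lam * pred_loss.sum()
```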

Justification. While the PC objective of Eq. (9) may look superficially similar to Eq. (7), we emphasize two key differences. First, our objective couples a generative likelihood and a prediction loss via the shared variational parameters $\phi^{z|x}$. This makes both generative and discriminative performance depend on the same learned encoding $z$. (Later we show how to partition $z$ so some entries are discriminative, while others affect generative “style” only.) In contrast, the M2 objective uses a label-given-features conditional to make predictions that does not share any of its parameters $\phi^{y|x}$ with the supervised likelihood $\mathcal{L}^{S}$. Second, our objective is more affordable: no term requires an expensive marginalization over labels. This is key to scaling to big unlabeled datasets, and also enables tractable learning from datasets whose labels are continuous or multi-dimensional.

Hyperparameters. The major hyperparameter influencing PC training is the constraint multiplier $\lambda\geq 0$. Setting $\lambda=0$ leads to unsupervised maximum likelihood training (or MAP training, given priors on $\theta$) of a classic VAE. Setting $\lambda=1$ and choosing a probabilistic loss $-\log p(y\mid z)$ produces a “supervised VAE” that maximizes the joint likelihood $p_{\theta}(x,y)$. But as illustrated in Fig. 2, because features $x$ have much higher dimension than labels $y$, the resulting model may have weak predictive performance. Satisfying the strong prediction constraint of Eq. (8) typically requires $\lambda\gg 1$, and in practice we use validation data to select the best of several candidate $\lambda$ values. If a task motivates a concrete tolerance $\epsilon$, we can test an increasing sequence of $\lambda$ values until the constraint is satisfied.

We emphasize that although Eq. (9) is easier to optimize, we prefer to think of the constrained problem in Eq. (8) as the “primary” objective, because our applied goals are to satisfy discriminative quality first; a generative model that predicts poorly is not acceptable. Furthermore, the constrained objective is far more natural for semi-supervised learning. The choice of $\epsilon$ need not depend on the relative sizes of the labeled and unlabeled datasets. In contrast, if either $|\mathcal{D}^{U}|$ or $|\mathcal{D}^{S}|$ changes, the value of $\lambda$ may need to change dramatically to reach the same prediction quality.

Figure 3: Formalization of our consistency-constrained VAE as a decision network (Cowell et al., 2006). Circular nodes are random variables, including the latent VAE code $z$ that generates observed features $x$. Shaded nodes are observed, including the class labels $y$ for some data (left). Square decision nodes indicate predictions $\hat{y}_{w}$ of class labels that depend on the amortized inference network $q_{\phi}(z\mid x)$. Diamonds indicate losses (negative utilities) that influence the variational prediction of labels and latent variables. Generative likelihood: Like standard VAEs, our generalizations choose generative model parameters $\theta$ and variational posteriors $q_{\phi}$ to maximize the ELBO $\mathcal{L}$ (orange). Prediction accuracy: Unlike previous semi-supervised VAEs, we do not directly model the probabilistic dependence of labels $y$ on $z$ and/or $x$. We instead treat label prediction as a decision problem, with application-motivated loss $\ell_{S}$ (red), that constrains the approximate VAE posterior $q_{\phi}(z|x)$ (and therefore the encodings and generative model). Prediction consistency: For unlabeled data (right), we cannot directly evaluate the quality of predictions. However, we do know that if two observations $x$ and $\bar{x}$ are generated from the same latent code $z$, they should have identical labels; otherwise, the model cannot have high accuracy. The loss $\ell_{C}$ (blue) enforces the consistency of such predictions. Aggregate consistency: By the law of large numbers, we also know that aggregate label frequencies for unlabeled data should be close to the frequencies $\pi$ observed in labeled data. The loss $\ell_{A}$ (green) enforces this constraint, and penalizes degenerate predictors that satisfy $\ell_{C}$ by predicting the same label $\hat{y}_{w}$ for most or all of the unlabeled data.

3.2 Enforcing Consistent Predictions from Generative Model Reconstructions

While the PC objective is effective given sufficient labeled data, it may generalize poorly when labels are very sparse (see Fig. 1). This fundamental problem arises because in the PC objective of Eq. (8), the parameters $w$ of the predictor $\hat{y}_{w}(z)$ are only directly informed by the labeled training data.

Revisiting the generative model, let $x\sim p_{\theta}(\cdot\mid z^{\prime})$ and $\bar{x}\sim p_{\theta}(\cdot\mid z^{\prime})$ be two observations sampled from the same latent code $z^{\prime}$. Even if the true label $y$ of $x$ is uncertain, we know that for this model to be useful for predictive tasks, $\bar{x}$ must have the same label as $x$. We formalize this relationship via a consistency constraint requiring label predictions for common-code data pairs $(x,\bar{x})$ to approximately match (see Fig. 3). As we show, this regularization may dramatically boost performance.

Given features $x$, our method predicts labels by sampling $z\sim q_{\phi}(z\mid x)$ from the approximate posterior, and then applying our predictor $\hat{y}_{w}(z)$. Alternatively, given $x$ we can first simulate alternative features $\bar{x}$ with matching code $z$ by sampling from the inference and generative models, and then predict the label associated with $\bar{x}$. We constrain the label predictions $y$ for $x$, and $\bar{y}$ for $\bar{x}$, to be similar via a consistency penalty function $\ell_{C}(y,\bar{y})$. For the classification tasks considered below, we use a cross-entropy consistency penalty $\ell_{C}$. Given this penalty, we constrain the maximum values of the following consistency costs on unlabeled and labeled examples, respectively:

$$\mathcal{C}^{U}(x;\theta,\phi,w)\triangleq\mathbb{E}_{q_{\phi}(z|x)}\left[\mathbb{E}_{p_{\theta}(\bar{x}|z)}\left[\mathbb{E}_{q_{\phi}(\bar{z}|\bar{x})}\left[\ell_{C}(\hat{y}_{w}(z),\hat{y}_{w}(\bar{z}))\right]\right]\right], \tag{10}$$
$$\mathcal{C}^{S}(x,y;\theta,\phi,w)\triangleq\mathbb{E}_{q_{\phi}(z|x)}\left[\mathbb{E}_{p_{\theta}(\bar{x}|z)}\left[\mathbb{E}_{q_{\phi}(\bar{z}|\bar{x})}\left[\ell_{C}(y,\hat{y}_{w}(\bar{z}))\right]\right]\right]. \tag{11}$$

Consistent PC: Unconstrained objective. To train parameters, we apply our consistency costs to unlabeled and labeled feature vectors, respectively. The overall objective becomes:

$$\max_{\theta,\phi,w}\ \sum_{x\in\mathcal{D}^{U}\cup\mathcal{D}^{S}}\mathcal{L}^{\text{VAE}}(x;\theta,\phi)-\sum_{x\in\mathcal{D}^{U}}\gamma\,\mathcal{C}^{U}(x;\theta,\phi,w)+\sum_{x,y\in\mathcal{D}^{S}}-\lambda\,\mathcal{P}(x,y;\phi,w)-\gamma\,\mathcal{C}^{S}(x,y;\theta,\phi,w),$$

where $\mathcal{L}^{\text{VAE}}$ is the unsupervised likelihood, $\mathcal{P}$ is the predictor loss, and $\mathcal{C}$ are the consistency constraints. Here, $\gamma>0$ is a scalar Lagrange multiplier for the consistency terms, with a similar interpretation as $\lambda$.
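The unlabeled consistency cost of Eq. (10) can be estimated with a single sample at each stage of the encode-decode-re-encode chain, as in the following sketch (ours; the modules are stand-ins, and the labeled cost of Eq. (11) is obtained by replacing the prediction on the original $x$ with the observed label $y$):

```python
import torch
import torch.nn.functional as F

def consistency_unlabeled(x, encoder, decoder, predictor):
    """1-sample estimate of C^U(x) in Eq. (10). Reconstructions are sampled
    (not just decoded means) so the predictor must be robust to the
    generative model's own noise."""
    mu, log_sig = encoder(x)
    z = mu + log_sig.exp() * torch.randn_like(mu)             # z ~ q_phi(z | x)
    mu_x, log_sig_x = decoder(z)
    x_bar = mu_x + log_sig_x.exp() * torch.randn_like(mu_x)   # x_bar ~ p_theta(x | z)
    mu_b, log_sig_b = encoder(x_bar)
    z_bar = mu_b + log_sig_b.exp() * torch.randn_like(mu_b)   # z_bar ~ q_phi(z | x_bar)
    p = F.softmax(predictor(z), dim=-1)                       # y_hat_w(z)
    log_p_bar = F.log_softmax(predictor(z_bar), dim=-1)       # y_hat_w(z_bar)
    return -(p * log_p_bar).sum(-1)                           # cross-entropy ell_C
```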

Aggregate Label Consistency. For SSL applications, we find it is also useful to regularize our model with an aggregate label consistency constraint, which forces the distribution of label predictions for unlabeled data to be aligned with a known target distribution $\pi$. This discourages predictions on ambiguous unlabeled examples from collapsing to a single value. We define the aggregate consistency loss as $\ell_{A}(\pi,\mathbb{E}_{x\sim\mathcal{D}^{U},z\sim q(z|x)}[\hat{y}_{w}(z)])$, and again use a cross-entropy penalty. If the target distribution of labels $\pi$ is unknown, we set it to the empirical distribution of the labeled data.
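A sketch of this aggregate penalty, approximating the expectation over $\mathcal{D}^{U}$ by a minibatch average (an assumption of this sketch, not a detail stated above), is:

```python
import torch
import torch.nn.functional as F

def aggregate_consistency(logits_unlabeled, pi, eps=1e-8):
    """Cross-entropy form of ell_A(pi, E[y_hat]): compares the target label
    frequencies pi to the average predicted distribution over an unlabeled
    minibatch (a stand-in for the expectation over D^U)."""
    avg_pred = F.softmax(logits_unlabeled, dim=-1).mean(dim=0)   # E[y_hat_w(z)]
    return -(pi * torch.log(avg_pred + eps)).sum()
```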

Related work on consistency. Recently popular SSL image classifiers focused on discriminative goals train the weights of a CNN to minimize a modified objective that penalizes both label errors and violations of consistency or smoothness on unlabeled data. Examples include consistency under adversarial perturbations (Miyato et al., 2019), under label-invariant transformations (Laine & Aila, 2017), and when interpolating between training features (Berthelot et al., 2019). This regularization can deliver competitive discriminative performance, but does not meet our goal of generative modeling. Recently, Unsupervised Data Augmentation (UDA, Xie et al. (2020)) achieved state-of-the-art vision and text SSL classification by enforcing label consistency on augmented samples of unlabeled features. UDA relies on the availability of well-engineered augmentation routines for specific domains (e.g., image processing library transforms for vision or back-translation for text). In contrast, we learn a generative model that produces the feature vectors for which predictions must be consistent. Our approach is thus more applicable to new domains where advanced augmentation routines are not available.

In broader machine learning, “cycle-consistency” has improved generative adversarial methods for images (Zhu et al., 2017, Zhou et al., 2016) or biomedical data (McDermott et al., 2018). Others have developed cycle-consistent objectives for VAEs (Jha et al., 2018) which focus on consistency in code vectors $z$. In contrast, our work focuses on semi-supervised learning and enforces cycle consistency in labels $y$. Recently, Miller et al. (2019) developed discriminative regularization for VAEs. Their objective is not designed for SSL and uses a direct feature-to-label prediction model that must be consistent with reconstructed predictions. Our approach instead uses code-to-label prediction and targets SSL.

3.3 Improved Generative Models: Robust Likelihoods and Spatial Transformers

As our approach is applicable to any generative model, we can incorporate prior knowledge of the data domain to improve both generative and discriminative performance. We consider two examples: likelihoods that model noisy pixels, and explicit affine transformations to model image deformations.

Robust Likelihoods. Instead of a normal (or other common) likelihood, we use a “Noise-Normal” likelihood to model images more robustly. We assume that pixel intensities $x$ have values in the interval $[-1,1]$ and rescale our datasets to match. Our Noise-Normal likelihood is defined as a 2-component mixture of a truncated Normal and a uniform “noise” distribution, with pixel-specific mixture weights. Define the standard normal PDF as $\phi(\cdot)$ and the standard normal CDF as $\Phi(\cdot)$. We write the probability density function of the Noise-Normal distribution with parameters $(\rho,\mu,\sigma)$ as:

$$\mathcal{F}(x\mid\rho,\mu,\sigma)=\rho\left(\frac{\phi\left(\frac{x-\mu}{\sigma}\right)}{\Phi\left(\frac{1-\mu}{\sigma}\right)-\Phi\left(\frac{-1-\mu}{\sigma}\right)}\right)+\left(1-\rho\right)\left(\frac{1}{2}\right). \tag{12}$$

Following Eq. (1), our (unsupervised) VAE now uses the revised generative and inference models:

$$p_{\theta}(x,z)=\mathcal{N}(z\mid 0,I_{C})\cdot\mathcal{F}(x\mid\rho_{\theta}(z),\mu_{\theta}(z),\sigma_{\theta}(z)),\qquad q_{\phi}(z\mid x)=\mathcal{N}(z\mid\mu_{\phi}(x),\sigma_{\phi}(x)). \tag{13}$$

This approach allows our model to avoid sensitivity to outliers and noise in the observed images, and boosts the performance of our CPC method for SSL.
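For reference, the per-pixel log-density of Eq. (12) could be evaluated as in the sketch below; note that we include the $1/\sigma$ standardization factor implied by the CDF in Eq. (15) of Appendix A.1 so that the truncated-Normal component is properly normalized (an assumption about the intended reading of Eq. (12)).

```python
import torch
from torch.distributions import Normal

def noise_normal_log_prob(x, rho, mu, sigma):
    """Log-density of the Noise-Normal mixture: a Normal truncated to [-1, 1]
    with weight rho, plus a Uniform(-1, 1) "noise" component with weight
    (1 - rho). All arguments are per-pixel tensors of the same shape."""
    std = Normal(torch.zeros_like(mu), torch.ones_like(sigma))
    z = (x - mu) / sigma
    trunc_mass = std.cdf((1.0 - mu) / sigma) - std.cdf((-1.0 - mu) / sigma)
    trunc_pdf = std.log_prob(z).exp() / (sigma * trunc_mass)   # truncated-Normal density
    density = rho * trunc_pdf + (1.0 - rho) * 0.5              # mix with Uniform(-1, 1)
    return torch.log(density)
```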

Spatial Transformer VAE. Our spatial transformer VAE retains the structure of a standard VAE, but reinterprets the latent code $z$ as two components. We denote the first 6 latent dimensions as $z_{t}$, and associate these with 6 affine transformation parameters capturing image translation, rotation, scaling, and shear. The generative model maps each value into a fixed range and creates an affine transformation matrix, $M_{t}$, by applying the transformations in a fixed order. The remainder of the latent code, $z_{*}$, is used to generate parameters for independent, per-pixel likelihoods. Assuming normal likelihoods, the output parameters for the pixel with coordinates $(i,j)$ are $\mu_{\theta}(z_{*})_{ij}$, $\sigma_{\theta}(z_{*})_{ij}$.

We re-orient our per-pixel likelihoods according to the affine transform $M_{t}$. The parameters of the likelihood for pixel $(i,j)$ will use the decoder outputs at coordinate $(i^{\prime},j^{\prime})$, where $M_{t}$ defines the affine mapping from $(i^{\prime},j^{\prime})$ to $(i,j)$. We apply this transformation in a (sub)differentiable way via a spatial transformer layer (Jaderberg et al., 2015) that takes as input $M_{t}$ and the appropriate parameter maps $\mu_{\theta}(z_{*})$, $\sigma_{\theta}(z_{*})$, and outputs a final set of parameters for the individual pixel likelihoods. As $(i^{\prime},j^{\prime})$ may not correspond to integer coordinates, we use bilinear interpolation over an appropriate representation of the likelihood parameters, and appropriately pad the size of the decoder output.

We further account for this special structure in our prediction and consistency constraints. For many applications, we have prior knowledge that small affine transforms should not affect the class of an image, and thus we can define consistency constraints that condition on $z_{*}$ but not $z_{t}$.
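A minimal sketch of the re-orientation step, assuming PyTorch's spatial-transformer utilities (`affine_grid`/`grid_sample`) and leaving the padding strategy and exact parameterization of $M_{t}$ unspecified:

```python
import torch
import torch.nn.functional as F

def warp_likelihood_params(param_maps, theta_2x3):
    """Re-orient decoder output maps (e.g. the channels holding mu_theta(z_*)
    and sigma_theta(z_*)) by a per-example affine transform. grid_sample
    performs the bilinear interpolation at non-integer source coordinates."""
    # param_maps: (batch, channels, H, W); theta_2x3: (batch, 2, 3)
    grid = F.affine_grid(theta_2x3, param_maps.size(), align_corners=False)
    return F.grid_sample(param_maps, grid, mode='bilinear',
                         padding_mode='border', align_corners=False)
```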

Source | Method | MNIST (100) | SVHN (1000) | NORB (1000)
Tables 1-2 of Kingma et al. (2014) | M1 + M2 | 96.67% (±0.14) | 63.98% (±0.10) | -
Table 2 of Maaløe et al. (2016) | ADGM | 99.04% (±0.02) | 77.14% | 89.94% (±0.05)
Table 2 of Maaløe et al. (2016) | SDGM | 98.68% (±0.07) | 83.39% (±0.24) | 90.60% (±0.04)
Gordon & Hernández-Lobato (2020) | Blended M2 | 93.05% (±0.73) | - | -
Tables 3-4 of Miyato et al. (2019) | VAT | 98.64% (±0.03) | 94.23% (±0.32) | -
ours, using labeled-set only | WRN | 73.91% (±1.45) | 87.7% (±1.02) | 86.7% (±1.32)
ours | CPC VAE | 98.29% (±0.50) | 94.22% (±0.62) | 92.00% (±1.21)
Table 1: Test accuracy of SSL methods. Our results show mean (std. dev.) across 10 samples of the labeled set. The labeled-set-only discriminative neural net (WRN) has roughly the same size as our CPC VAE.
Method | MNIST (100) | Method | MNIST (100)
CPC (2 layer) | 96.68% (±0.54) | M1 + M2 (Kingma et al., 2014) | 96.67% (±0.14)
CPC (2 layer, w/o aggregate loss) | 94.27% (±3.78) | M2 (1 layer, $\alpha=0.1$) (Kingma et al., 2014) | 88.03% (±1.71)
CPC (2 layer, w/o transforms) | 91.93% (±1.65) | M2 (2 layer, $\alpha=0.1$) | 83.32% (±5.22)
CPC (4 layer, w/o transforms) | 93.78% (±2.25) | M2 (4 layer, $\alpha=0.1$) | 47.05% (±8.13)
PC (2 layer) | 80.49% (±3.31) | M2 (4 layer, tuned to $\alpha=10$) | 68.15% (±3.43)
VAE + MLP | 72.90% (±1.98) | M2 (1 layer, $\alpha=0.1$, Noise-Normal) | 73.93% (±8.12)
Table 2: Ablation study for SSL on MNIST using a dense MLP matching M2’s network architecture. Trials used 100 labels and encoding size $C=50$. Unless cited, all results come from our implementation, where the encoder and decoder have 1000 units per hidden layer. All of our introduced techniques improve accuracy. M2’s accuracy deteriorates when its networks are overparameterized, even after tuning $\alpha$ on the validation set instead of using the default from Kingma et al. (2014), while our method remains stable.

4 Experiments

We assess our consistent prediction-constrained (CPC) VAE on two key goals: accurate prediction of labels $y$ given features $x$ (especially when labels are rare) and useful generative modeling of $x$. We compare to ablations of our own method (without consistency, without spatial transformations) and to external baselines. We report each method’s mean and standard deviation in classification accuracy across 10 labeled subsets. We trained using ADAM (Kingma & Ba, 2014), with each minibatch containing 50% labeled and 50% unlabeled data. Hyperparameter search used Optuna (Akiba et al., 2019) to maximize accuracy on validation data. Supervised baselines used either MLPs or wide residual nets (WRN, Zagoruyko & Komodakis (2016)). Reproducible details are in the appendices.
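For concreteness, a minimal sketch of how such 50/50 minibatches can be formed (cycling the much smaller labeled set against the unlabeled loader; the loader objects and equal batch sizes are assumptions of this sketch):

```python
import itertools

def mixed_minibatches(labeled_loader, unlabeled_loader):
    """Yield minibatches that are half labeled, half unlabeled, by pairing
    each unlabeled batch with a (recycled) labeled batch of the same size."""
    for (x_lab, y_lab), x_unlab in zip(itertools.cycle(labeled_loader),
                                       unlabeled_loader):
        yield x_lab, y_lab, x_unlab
```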

SSL classification on MNIST with thorough internal comparisons. In Table 2 we compare several variations of our CPC methods and the M2 model on an SSL version of MNIST (LeCun et al., 2010; 10 classes, $|\mathcal{D}^{S}|=100$, $|\mathcal{D}^{U}|=49{,}900$, 10,000 validation, 10,000 test).

SSL classification on SVHN and NORB. Table 1 compares our methods on two standard SSL tasks: Street View House Numbers (SVHN) (Netzer et al., 2011; 10 classes, $|\mathcal{D}^{S}|=1000$, $|\mathcal{D}^{U}|=62{,}257$, 10,000 validation, 26,032 test) and the NYU Object Recognition Benchmark (NORB) (LeCun et al., 2004; 5 classes, $|\mathcal{D}^{S}|=1000$, $|\mathcal{D}^{U}|=21{,}300$, 2,000 validation, 24,300 test).

SSL classification on CelebA. We ran additional experiments on a variant of the CelebA dataset (Liu et al., 2015). For these trials we created a classification problem with 4 classes based on the combination of gender (woman/man) and facial expression (neutral/smiling) (4 classes, $|\mathcal{D}^{S}|=1000$, $|\mathcal{D}^{U}|=159{,}770$, 2,000 validation, 19,962 test). We report our results in Fig. 6.

Across all evaluations, we can conclude:

Both consistency and prediction constraints are needed for high accuracy. In Table 2, PC alone gets 80% accuracy on 100-label MNIST, while adding consistency yields 97% for CPC. The benefits of CPC over PC in both accuracy and latent interpretability are visible in Figs. 1-2. Our aggregate label consistency improves robustness, reducing the variance in CPC accuracy (from $3.78^{2}$ to $0.54^{2}$).

CPC training delivers strong improvements in SSL prediction quality over baselines. In Table 1, our CPC achieves 94.22% and 92.0% on the challenging 1000-label SVHN and NORB benchmarks, which surpasses by >1.4% all reported baselines while being reliable across runs. The M1+M2 baseline (Kingma et al., 2014) is not a coherent generative model, but rather a discriminative ensemble of multiple models. It performs well on MNIST, but very poorly on the more challenging SVHN.

CPC delivers better generative performance; it is not all about prediction. We improve on unsupervised VAEs by explicitly learning latent representations informed by class labels (Fig. 2). In Figs. 4, 5, and 6, we show visually-plausible class-conditional samples from our best CPC models. Additional visuals from learned VAE and CPC-VAE models are in the supplement.

With improved generative models, CPC can improve predictions. Fig. 5 shows that including spatial transformations allows learning a canonical orientation and scale for each digit. This generative improvement boosts classifier accuracy (e.g., MNIST improves from 91.9% to 96.7% in Table 2).

Figure 4: Sampled reconstructions used to compute the consistency loss during training. Top: Original image. Middle: Sampled reconstructions using a “Noise-Normal” likelihood. Bottom: Sampled reconstructions with spatial affine transformations sampled from the prior.
(a) Prior samples
(b) Reconstructions
Figure 5: Visualizations of generative model performance for a prediction- and consistency-constrained VAE incorporating latent affine transformations, trained on the SVHN dataset with all labels. (a) Samples from the learned generative model conditioned on class (more are shown in supplement A). Samples are chosen via rejection sampling in the latent space with a threshold of 95% confidence in the target class. (b) Reconstructions of images in the held-out test set. Each triplet shows the original image (left), the reconstruction (middle), and an aligned reconstruction (right) obtained by fixing the learned affine transform variables to the global mean.
Figure 6: CelebA dataset results. Left: Test accuracy on the 4-class CelebA SSL task (1000 labeled training images, 159,770 unlabeled). CPC-VAE improves on the labeled-set only WRN, and dramatically improves on the unsupervised VAE which poorly separates the classes in the latent space. Right: Class-conditional samples of the 4 possible classes from our semi-supervised CPC VAE.

5 Conclusion

We have developed a new optimization framework for semi-supervised VAEs that can balance discriminative and generative goals. Across image classification tasks, our CPC-VAE method delivers superior accuracy and label-informed generative models with visually-plausible samples. Unlike previous efforts to enforce constraints on latent variable models, such as expectation constraints (Mann & McCallum, 2010), posterior regularization (Zhu et al., 2014; 2012), posterior constraints (Ganchev et al., 2010), or prediction constraints for topic models (Hughes et al., 2018), our approach is the only one that coherently and simultaneously treats uncertainty in latent variables $z$, applies to flexible “deep” non-conjugate models, and offers scalable training and test evaluation via amortized inference.

A further contribution is demonstrating the necessity of consistency for improving discrimination. Our CPC approach is an antidote to model misspecification: the constraints on prediction quality and consistency prevent learning a generative model that is unaligned with the classification task, or that overfits as the generative model becomes more flexible (as M2 is vulnerable to do). As we show with spatial transformers, our work lets improvements in generative model quality directly improve semi-supervised label prediction, helping realize the promise of deep generative models.

References

  • Abadi et al. (2015) Martín Abadi, Ashish Agarwal, Paul Barham, et al. TensorFlow: Large-scale machine learning on heterogeneous systems, 2015. Software available from tensorflow.org.
  • Akiba et al. (2019) Takuya Akiba, Shotaro Sano, Toshihiko Yanase, Takeru Ohta, and Masanori Koyama. Optuna: A next-generation hyperparameter optimization framework. In Proceedings of the 25th ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, 2019.
  • Berthelot et al. (2019) David Berthelot, Nicholas Carlini, Ian Goodfellow, Nicolas Papernot, Avital Oliver, and Colin Raffel. MixMatch: A Holistic Approach to Semi-Supervised Learning. In Advances in Neural Information Processing Systems, 2019. URL http://arxiv.org/abs/1905.02249.
  • Cowell et al. (2006) Robert G Cowell, Philip Dawid, Steffen L Lauritzen, and David J Spiegelhalter. Probabilistic networks and expert systems: Exact computational methods for Bayesian networks. Springer Science & Business Media, 2006.
  • de Bem et al. (2018) Rodrigo de Bem, Arnab Ghosh, Thalaiyasingam Ajanthan, Ondrej Miksik, N. Siddharth, and Philip Torr. A Semi-supervised Deep Generative Model for Human Body Analysis. In Proceedings of the European Conference on Computer Vision (ECCV) Workshops, 2018. URL http://openaccess.thecvf.com/content_eccv_2018_workshops/w11/html/de_A_Semi-supervised_Deep_Generative_Modelfor_Human_Body_Analysis_ECCVW_2018_paper.html.
  • Dumoulin et al. (2017) Vincent Dumoulin, Ishmael Belghazi, Ben Poole, Olivier Mastropietro, Alex Lamb, Martin Arjovsky, and Aaron Courville. Adversarially Learned Inference. In International Conference on Learning Representations (ICLR), 2017.
  • Figurnov et al. (2018) Michael Figurnov, Shakir Mohamed, and Andriy Mnih. Implicit reparameterization gradients. In Advances in Neural Information Processing Systems, 2018.
  • Ganchev et al. (2010) Kuzman Ganchev, João Graça, Jennifer Gillenwater, and Ben Taskar. Posterior Regularization for Structured Latent Variable Models. Journal of Machine Learning Research, 11:2001–2049, 2010.
  • Gordon & Hernández-Lobato (2020) Jonathan Gordon and José Miguel Hernández-Lobato. Combining deep generative and discriminative models for Bayesian semi-supervised learning. Pattern Recognition, 100, 2020.
  • Grandvalet & Bengio (2004) Yves Grandvalet and Yoshua Bengio. Semi-supervised learning by entropy minimization. In Proceedings of the 17th International Conference on Neural Information Processing Systems, NIPS’04, pp.  529–536, Cambridge, MA, USA, 2004. MIT Press.
  • Higgins et al. (2017) Irina Higgins, Loic Matthey, Arka Pal, Christopher Burgess, Xavier Glorot, Matthew Botvinick, Shakir Mohamed, and Alexander Lerchner. beta-vae: Learning basic visual concepts with a constrained variational framework. In International Conference on Learning Representations (ICLR), 2017.
  • Hughes et al. (2018) Michael C. Hughes, Gabriel Hope, Leah Weiner, Thomas H. McCoy, Roy H. Perlis, Erik B. Sudderth, and Finale Doshi-Velez. Semi-Supervised Prediction-Constrained Topic Models. In Artificial Intelligence and Statistics, 2018. URL http://proceedings.mlr.press/v84/hughes18a.html.
  • Jaderberg et al. (2015) Max Jaderberg, Karen Simonyan, Andrew Zisserman, and koray kavukcuoglu. Spatial transformer networks. In Advances in Neural Information Processing Systems, 2015. URL http://papers.nips.cc/paper/5854-spatial-transformer-networks.pdf.
  • Jha et al. (2018) Ananya Harsh Jha, Saket Anand, Maneesh Singh, and V. S. R. Veeravasarapu. Disentangling Factors of Variation with Cycle-Consistent Variational Auto-encoders. In European Conference on Computer Vision (ECCV). Springer International Publishing, 2018.
  • Kingma & Ba (2014) Diederik P. Kingma and Jimmy Ba. Adam: A Method for Stochastic Optimization. arXiv:1412.6980 [cs], 2014. URL http://arxiv.org/abs/1412.6980.
  • Kingma & Welling (2014) Diederik P. Kingma and Max Welling. Auto-Encoding Variational Bayes. In International Conference on Learning Representations, 2014. URL http://arxiv.org/abs/1312.6114.
  • Kingma et al. (2014) Diederik P Kingma, Shakir Mohamed, Danilo Jimenez Rezende, and Max Welling. Semi-supervised learning with deep generative models. In Advances in Neural Information Processing Systems, 2014. URL https://papers.nips.cc/paper/5352-semi-supervised-learning-with-deep-generative-models.pdf.
  • Kumar et al. (2017) Abhishek Kumar, Prasanna Sattigeri, and Tom Fletcher. Semi-supervised Learning with GANs: Manifold Invariance with Improved Inference. In Advances in Neural Information Processing Systems, 2017. URL https://papers.nips.cc/paper/7137-semi-supervised-learning-with-gans-manifold-invariance-with-improved-inference.pdf.
  • Lacoste-Julien et al. (2011) Simon Lacoste-Julien, Ferenc Huszár, and Zoubin Ghahramani. Approximate inference for the loss-calibrated bayesian. In Artificial Intelligence and Statistics, 2011.
  • Laine & Aila (2017) Samuli Laine and Timo Aila. Temporal Ensembling for Semi-Supervised Learning. In International Conference on Learning Representations, 2017. URL https://openreview.net/pdf?id=BJ6oOfqge.
  • Larsen et al. (2016) Anders Boesen Lindbo Larsen, Søren Kaae Sønderby, Hugo Larochelle, and Ole Winther. Autoencoding beyond pixels using a learned similarity metric. In International Conference on Machine Learning, pp. 1558–1566, 2016. URL http://proceedings.mlr.press/v48/larsen16.html.
  • LeCun et al. (2004) Y LeCun, F. J. Huang, and L. Bottou. Learning Methods for Generic Object Recognition with Invariance to Pose and Lighting. In IEEE Computer Vision and Pattern Recognition (CVPR), 2004. URL http://yann.lecun.com/exdb/publis/pdf/lecun-04.pdf.
  • LeCun et al. (2010) Yann LeCun, Corinna Cortes, and CJ Burges. MNIST handwritten digit database, 2010. URL http://yann.lecun.com/exdb/mnist/.
  • Liu et al. (2015) Ziwei Liu, Ping Luo, Xiaogang Wang, and Xiaoou Tang. Deep learning face attributes in the wild. In Proceedings of International Conference on Computer Vision (ICCV), December 2015.
  • Maaløe et al. (2016) Lars Maaløe, Casper Kaae Sønderby, Søren Kaae Sønderby, and Ole Winther. Auxiliary Deep Generative Models. arXiv:1602.05473 [cs, stat], 2016. URL http://arxiv.org/abs/1602.05473.
  • Mann & McCallum (2010) Gideon S Mann and Andrew McCallum. Generalized expectation criteria for semi-supervised learning with weakly labeled data. Journal of Machine Learning Research, 11(Feb):955–984, 2010.
  • McDermott et al. (2018) Matthew B A McDermott, Tom Yan, Tristan Naumann, Nathan Hunt, Harini Suresh, Peter Szolovits, and Marzyeh Ghassemi. Semi-Supervised Biomedical Translation with Cycle Wasserstein Regression GANs. In Thirty-Second AAAI Conference on Artificial Intelligence (AAAI-18), pp.  8, 2018. URL https://www.aaai.org/ocs/index.php/AAAI/AAAI18/paper/viewFile/16938/15951.
  • Miller et al. (2019) Andrew C Miller, Ziad Obermeyer, John P Cunningham, and Sendhil Mullainathan. Discriminative Regularization for Latent Variable Models with Applications to Electrocardiography. In International Conference on Machine Learning, pp.  10, 2019. URL https://proceedings.mlr.press/v97/miller19a/miller19a.pdf.
  • Miyato et al. (2019) Takeru Miyato, Shin-Ichi Maeda, Masanori Koyama, and Shin Ishii. Virtual Adversarial Training: A Regularization Method for Supervised and Semi-Supervised Learning. IEEE Transactions on Pattern Analysis and Machine Intelligence, 41(8):1979–1993, 2019. URL https://ieeexplore.ieee.org/document/8417973/.
  • Netzer et al. (2011) Yuval Netzer, Tao Wang, Adam Coates, Alessandro Bissacco, Bo Wu, and Andrew Y Ng. Reading Digits in Natural Images with Unsupervised Feature Learning. In NeurIPS Workshop on Deep Learning and Unsupervised Feature Learning, 2011. URL http://ufldl.stanford.edu/housenumbers.
  • Oliver et al. (2018) Avital Oliver, Augustus Odena, Colin Raffel, Ekin D. Cubuk, and Ian J. Goodfellow. Realistic Evaluation of Deep Semi-Supervised Learning Algorithms. arXiv:1804.09170 [cs, stat], 2018. URL http://arxiv.org/abs/1804.09170.
  • Siddharth et al. (2017) N. Siddharth, Brooks Paige, Jan-Willem van de Meent, Alban Desmaison, Noah D. Goodman, Pushmeet Kohli, Frank Wood, and Philip H. S. Torr. Learning Disentangled Representations with Semi-Supervised Deep Generative Models. In Advances in Neural Information Processing Systems, 2017. URL http://arxiv.org/abs/1706.00400.
  • Stoyanov et al. (2011) Veselin Stoyanov, Alexander Ropson, and Jason Eisner. Empirical risk minimization of graphical model parameters given approximate inference, decoding, and model structure. In Artificial Intelligence and Statistics, 2011.
  • Xie et al. (2020) Qizhe Xie, Zihang Dai, Eduard Hovy, Minh-Thang Luong, and Quoc V. Le. Unsupervised data augmentation for consistency training. In Advances in Neural Information Processing Systems, 2020.
  • Zagoruyko & Komodakis (2016) Sergey Zagoruyko and Nikos Komodakis. Wide residual networks. In Edwin R. Hancock Richard C. Wilson and William A. P. Smith (eds.), Proceedings of the British Machine Vision Conference (BMVC), pp. 87.1–87.12. BMVA Press, September 2016. ISBN 1-901725-59-6. doi: 10.5244/C.30.87. URL https://dx.doi.org/10.5244/C.30.87.
  • Zhang et al. (2019) Xiang Zhang, Lina Yao, and Feng Yuan. Adversarial Variational Embedding for Robust Semi-supervised Learning. In Proceedings of the 25th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining, pp.  139–147, Anchorage AK USA, 2019. ACM.
  • Zhou et al. (2016) Tinghui Zhou, Philipp Krahenbuhl, Mathieu Aubry, Qixing Huang, and Alexei A. Efros. Learning Dense Correspondence via 3D-Guided Cycle Consistency. In 2016 IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pp.  117–126, Las Vegas, NV, USA, 2016. IEEE.
  • Zhu et al. (2012) Jun Zhu, Amr Ahmed, and Eric P Xing. MedLDA: Maximum margin supervised topic models. The Journal of Machine Learning Research, 13(1):2237–2278, 2012.
  • Zhu et al. (2014) Jun Zhu, Ning Chen, and Eric P Xing. Bayesian inference with posterior regularization and applications to infinite latent SVMs. Journal of Machine Learning Research, 15(1):1799–1847, 2014.
  • Zhu et al. (2017) Jun-Yan Zhu, Taesung Park, Phillip Isola, and Alexei A. Efros. Unpaired Image-to-Image Translation Using Cycle-Consistent Adversarial Networks. In 2017 IEEE International Conference on Computer Vision (ICCV), pp.  2242–2251, Venice, 2017. IEEE.
  • Zhu (2005) Xiaojin Zhu. Semi-Supervised Learning Literature Survey. Technical Report Technical Report 1530, Department of Computer Science, University of Wisconsin Madison., 2005.

Appendix A Details and Visualizations of Generative Models

A.1 Noise-Normal Likelihood

As discussed in Sec. 4.3, we use a “Noise-Normal” distribution as the pixel likelihood for many of our experiments. We define this distribution to be a parameterized two-component mixture of a truncated-normal distribution and a uniform distribution. We will use ρ to denote the mixture probability of the normal component, and μ and σ to denote the mean and standard deviation of the truncated normal, respectively. The generative model (or decoder) predicts a distinct outlier probability (1 − ρ) for each pixel. We assume that pixel intensities are defined on the domain [−1, 1] and rescale our datasets to match. We can write the probability density function of the Noise-Normal distribution via the standard normal PDF ϕ(·) and standard normal CDF Φ(·) as follows:

f(x \mid \rho, \mu, \sigma) = \rho \left( \frac{\frac{1}{\sigma}\,\phi\!\left(\frac{x-\mu}{\sigma}\right)}{\Phi\!\left(\frac{1-\mu}{\sigma}\right) - \Phi\!\left(\frac{-1-\mu}{\sigma}\right)} \right) + \left(1-\rho\right)\left(\frac{1}{2}\right).   (14)

We can similarly express the cumulative distribution function of the Noise-Normal distribution as:

F(x \mid \rho, \mu, \sigma) = \rho \left( \frac{\Phi\!\left(\frac{x-\mu}{\sigma}\right) - \Phi\!\left(\frac{-1-\mu}{\sigma}\right)}{\Phi\!\left(\frac{1-\mu}{\sigma}\right) - \Phi\!\left(\frac{-1-\mu}{\sigma}\right)} \right) + \left(1-\rho\right)\left(\frac{x+1}{2}\right).   (15)

In order to propagate gradients through the sampling process of the Noise-Normal distribution, we use the implicit reparameterization gradients approach of Figurnov et al. (2018). Given a sample x drawn from this distribution, we compute the gradient with respect to the parameters ρ, μ, and σ as:

\nabla_{\rho,\mu,\sigma}\, x = \frac{-\nabla_{\rho,\mu,\sigma} F(x \mid \rho, \mu, \sigma)}{f(x \mid \rho, \mu, \sigma)}.   (16)

When fitting the parameters of this distribution using gradient descent, we enforce the constraints that ρ ∈ [0, 1], μ ∈ [−1, 1], and σ > 0. To do this, we optimize unconstrained parameters ρ*, μ*, σ*, and then define ρ = sigmoid(ρ*), μ = tanh(μ*), and σ = softplus(σ*).
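To make these definitions concrete, the snippet below gives a minimal NumPy/SciPy sketch of the Noise-Normal density, CDF, and implicit reparameterization gradient from Eqs. (14)–(16). The function names are illustrative and do not correspond to our code release, and the gradient with respect to μ is approximated by finite differences for brevity.

```python
# Minimal sketch of the Noise-Normal likelihood (Eqs. 14-16); names are illustrative.
import numpy as np
from scipy.stats import norm

def noise_normal_pdf(x, rho, mu, sigma):
    """Density of the truncated-normal / uniform mixture on [-1, 1] (Eq. 14)."""
    trunc_mass = norm.cdf((1 - mu) / sigma) - norm.cdf((-1 - mu) / sigma)
    normal_part = norm.pdf((x - mu) / sigma) / (sigma * trunc_mass)
    return rho * normal_part + (1 - rho) * 0.5  # uniform density on [-1, 1] is 1/2

def noise_normal_cdf(x, rho, mu, sigma):
    """Cumulative distribution of the mixture (Eq. 15)."""
    trunc_mass = norm.cdf((1 - mu) / sigma) - norm.cdf((-1 - mu) / sigma)
    normal_part = (norm.cdf((x - mu) / sigma) - norm.cdf((-1 - mu) / sigma)) / trunc_mass
    return rho * normal_part + (1 - rho) * (x + 1) / 2

def implicit_grad_wrt_mu(x, rho, mu, sigma, eps=1e-4):
    """Implicit reparameterization gradient dx/dmu (Eq. 16), using a finite-difference dF/dmu."""
    dF_dmu = (noise_normal_cdf(x, rho, mu + eps, sigma)
              - noise_normal_cdf(x, rho, mu - eps, sigma)) / (2 * eps)
    return -dF_dmu / noise_normal_pdf(x, rho, mu, sigma)
```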

A.2 Spatial Transformer VAE

We now describe how to sample affine transformations M_t for use in our generative model of images. As described in Sec. 4.3, the latent transformation code z_t has 6 real-valued dimensions, each corresponding to one of the following 6 affine transformation parameters:

  • z_t^{(1)} → horizontal translation,

  • z_t^{(2)} → vertical translation,

  • z_t^{(3)} → rotation,

  • z_t^{(4)} → shear,

  • z_t^{(5)} → horizontal scale,

  • z_t^{(6)} → vertical scale.

To constrain our transformations to a fixed range of plausible values, we construct M_t using parameters z̄_t^{(i)} = tanh(z_t^{(i)}) that are first mapped to the interval [−1, +1], and then rescaled to an appropriate range via hyperparameters α^{(1)}, …, α^{(6)}. Figure 7 illustrates that the induced prior for z̄_t^{(i)} is heaviest for extreme values, encouraging aggressive augmentation when sampling from the prior. The mapping function could be changed to modify this distribution for other applications.

Figure 7: Prior distribution for latent parameters z̄_t^{(i)} = tanh(z_t^{(i)}) used to represent affine transformations.

Given these latent transformation parameters, we define an affine transformation matrix M_t as follows:

M_t = \begin{bmatrix} 1 & 0 & \alpha^{(1)}\bar{z}_t^{(1)} \\ 0 & 1 & \alpha^{(2)}\bar{z}_t^{(2)} \\ 0 & 0 & 1 \end{bmatrix} \cdot \begin{bmatrix} \cos(\alpha^{(3)}\bar{z}_t^{(3)}) & -\sin(\alpha^{(3)}\bar{z}_t^{(3)} + \alpha^{(4)}\bar{z}_t^{(4)}) & 0 \\ \sin(\alpha^{(3)}\bar{z}_t^{(3)}) & \cos(\alpha^{(3)}\bar{z}_t^{(3)} + \alpha^{(4)}\bar{z}_t^{(4)}) & 0 \\ 0 & 0 & 1 \end{bmatrix} \cdot \begin{bmatrix} (\alpha^{(5)})^{\bar{z}_t^{(5)}} & 0 & 0 \\ 0 & (\alpha^{(6)})^{\bar{z}_t^{(6)}} & 0 \\ 0 & 0 & 1 \end{bmatrix}   (17)

To determine the parameters of the likelihood function for the pixel at coordinate (i, j), we use the generative model (or decoder) output at the pixel (i′, j′) for which

\begin{bmatrix} i \\ j \\ 1 \end{bmatrix} = M_t \begin{bmatrix} i' \\ j' \\ 1 \end{bmatrix}.   (18)

This corresponds to applying horizontal and vertical scaling, followed by rotation and shear, followed by translation. We use the spatial transformer layer proposed by Jaderberg et al. (2015) with bilinear interpolation to apply this transformation at non-integer pixel coordinates. For the Noise-Normal distribution we independently interpolate the ρ, μ, and σ² parameters.
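As an illustration, the sketch below builds the matrix in Eq. (17) with NumPy. The function and argument names are our own, the α values mirror Table 3, and we read the shear hyperparameter as an additive angle inside the sine and cosine terms.

```python
# Minimal sketch of constructing the affine matrix M_t in Eq. (17); names are illustrative.
import numpy as np

def affine_matrix(z_t, alpha):
    """Map a 6-dimensional latent transformation code z_t to a 3x3 affine matrix."""
    z = np.tanh(np.asarray(z_t, dtype=float))           # squash each dimension to [-1, 1]
    tx, ty = alpha[0] * z[0], alpha[1] * z[1]            # horizontal / vertical translation
    theta = alpha[2] * z[2]                              # rotation angle (radians)
    phi = theta + alpha[3] * z[3]                        # rotation plus shear angle (radians)
    sx, sy = alpha[4] ** z[4], alpha[5] ** z[5]          # scales in [1/alpha, alpha]

    T = np.array([[1.0, 0.0, tx], [0.0, 1.0, ty], [0.0, 0.0, 1.0]])
    R = np.array([[np.cos(theta), -np.sin(phi), 0.0],
                  [np.sin(theta),  np.cos(phi), 0.0],
                  [0.0, 0.0, 1.0]])
    S = np.diag([sx, sy, 1.0])
    return T @ R @ S  # scale, then rotate/shear, then translate, as in Eq. (18)

# Example ranges for a 28-pixel-wide image, mirroring Table 3.
alpha = [0.2 * 28, 0.2 * 28, 0.4, 0.2, 1.5, 1.5]
M_t = affine_matrix(np.random.randn(6), alpha)
```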

A.3 Class-conditional sampling

A standard VAE generates data by sampling z ∼ 𝒩(0, I), and then sampling x ∼ 𝒩(μ_θ(z), σ_θ(z)), or from an alternative likelihood like the Noise-Normal. For the PC-VAE or CPC-VAE, we can further sample images conditioned on a particular class label. As labels are not explicitly part of the generative model, we accomplish this by sampling images that would be confidently predicted as the target class. We use a rejection sampler, repeatedly sampling z ∼ 𝒩(0, I) until a sample meets the criterion p_w(y ∣ z) > 1 − ε for some target threshold ε. We typically use ε = 0.05 in our experiments.
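The rejection sampler is summarized by the short sketch below; `decoder` and `classifier` are stand-ins for the trained decoder network and latent-space label predictor, and all names are illustrative rather than taken from our code release.

```python
# Minimal sketch of class-conditional rejection sampling in the latent space; names are illustrative.
import numpy as np

def sample_class_conditional(decoder, classifier, target_class, latent_dim,
                             eps=0.05, max_tries=10000):
    """Draw z ~ N(0, I) until the predictor assigns the target class probability above 1 - eps."""
    for _ in range(max_tries):
        z = np.random.randn(latent_dim)
        class_probs = classifier(z)             # categorical distribution p_w(y | z)
        if class_probs[target_class] > 1.0 - eps:
            return decoder(z)                   # decode the accepted latent code into an image
    raise RuntimeError("No sample accepted; consider increasing eps or max_tries.")
```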

MNIST digit samples for models with a 2-D latent space.

Fig. 2 in the main text shows 2-dimensional latent space encodings of the MNIST dataset using several different models. We provide a complementary visualization of the generative models in Fig. 8, where we compare class-conditional samples for three of these models. The unsupervised VAE's encodings of some classes (e.g., 2s and 4s, or 8s and 9s) are not separated, so samples frequently appear to be the wrong class. Model M2 (Kingma et al., 2014) explicitly encodes the class label as a latent variable, but nevertheless many sampled images do not visually match the conditioned class. In contrast, for our CPC-VAE model almost all samples are easily recognized as the target class.

(a) Unsupervised VAE
(b) CPC-VAE
(c) M2
Figure 8: Class-conditional samples of the 10 possible digit classes in the MNIST dataset. Each column shows multiple samples from one specific digit class. From left to right, the panels show samples from a standard unsupervised VAE, our CPC-VAE, and model M2 (Kingma et al., 2014). All models use a 2-dimensional latent code, and are trained on the MNIST dataset with 100 labeled examples (10 per class).
Figure 9: Class-conditional samples of the 10 possible digit classes in the SVHN dataset, extending Fig. 5(a). The generative model was trained on the fully labeled SVHN dataset with prediction and consistency constraints. Samples were chosen via rejection sampling in the latent space with a threshold of 95% confidence in the target class.

Appendix B Sensitivity to constraint multipliers

Figure 10 compares the test accuracy of our consistency-constrained model for MNIST over a range of values for both λ (the prediction constraint multiplier) and γ (the consistency constraint multiplier). All runs used our best consistency-constrained MNIST model with dense networks, keeping all hyperparameters identical to the previous results (see Appendix D) and changing only the value of interest for each run.

The resulting test accuracy varies smoothly as each multiplier is changed across several orders of magnitude, with the optimum at or near the values chosen for our experiments. Performance exceeds the M2 baseline for a wide range of hyperparameter values.

(a) Sensitivity to λ
(b) Sensitivity to γ
Figure 10: Sensitivity of test accuracy to the constraint (Lagrange multiplier) hyperparameters λ and γ.

Appendix C Training time

Figure 11 below provides an empirical comparison of the average training time per step for the MNIST models summarized in Table 2. Our CPC-VAE implementation runs both the encoder and decoder networks twice to compute the objective (once for the standard VAE loss and once for the consistency reconstruction and prediction), so the runtime is approximately twice that of the PC-VAE. This approximately holds in practice: the PC-VAE requires 38 milliseconds per training step, while the CPC-VAE requires 80.7 milliseconds.

Furthermore, our empirical findings show that training M2 is more expensive than our proposed CPC-VAE in practice, which we expect given the runtime analysis described in Sec. 2.3. The M2 model must run the encoder and decoder networks once per class in order to compute the loss, due to the marginalization over labels required for the unsupervised loss in Eq. (6). This increases the runtime by a factor equal to the number of classes. In our empirical test, the training time per step is approximately 6.7x that of the PC-VAE model, close to the 10x slowdown we would expect for the 10 digit classes of MNIST. In our experiments, we did not find substantial differences in the network sizes or number of training steps needed to train each of these models effectively.

Figure 11: Comparison of training time per update step of stochastic gradient descent. Each model was trained on semi-supervised MNIST with 100 labels using hyperparameter settings identical to those in Table 2. Experiments were run on an RTX Titan GPU, using a common codebase built on top of TensorFlow (Abadi et al., 2015) that implements all methods. Each reported time is the average training step time over the second epoch.

Appendix D Experimental Protocol

Here we provide details about models and experiments that did not fit into the main paper.

D.1 Hyperparameter optimization

The hyperparameter search for all models, including the CPC-VAE and various baselines, used Optuna (Akiba et al., 2019) to maximize accuracy on a labeled validation set. For our 2-layer and 4-layer M2 experiments, we used our own implementation (available in our code release) and followed the hyperparameters used by the original authors. For the 4-layer variant, we tested 10 different settings of α, ranging from 0.05 to 50, reporting both the result using the originally suggested value of α (0.1) and the best value of α we found for this setting (10). For M2, we also dynamically reduced the learning rate when the validation loss plateaued.
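For reference, a minimal Optuna sketch of this kind of search is shown below; the hypothetical train_and_evaluate helper and the search ranges are illustrative placeholders, not the exact search space we used.

```python
# Minimal Optuna sketch of the validation-accuracy search; train_and_evaluate is a
# hypothetical helper that trains one model and returns validation accuracy.
import optuna

def objective(trial):
    # Search constraint multipliers and the beta weight on log / linear scales.
    lam = trial.suggest_float("prediction_multiplier", 1.0, 1000.0, log=True)
    gamma_ratio = trial.suggest_float("consistency_ratio", 0.1, 10.0, log=True)
    beta = trial.suggest_float("beta_vae_weight", 0.5, 5.0)
    return train_and_evaluate(lam=lam, gamma=gamma_ratio * lam, beta=beta)

study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=50)
print(study.best_params)
```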

D.2 Network architectures

For our PC-VAE and CPC-VAE models of the MNIST data, we use fully-connected encoder and decoder networks with two hidden layers, 1000 hidden units per layer, and softplus activation functions. Like the M2 model (Kingma et al., 2014), we use a 50-dimensional latent space. The original M2 experiments used networks with a single hidden layer of 500 units. We compare this to replications with networks matching ours, as well as 4-layer networks.

For the SVHN and NORB datasets, we adapt the wide-residual network architecture (WRN-28-2) (Zagoruyko & Komodakis, 2016) that was proposed as a standard for semi-supervised deep learning research in Oliver et al. (2018). In particular, we use this architecture for our encoder with two notable changes: we replace the final global average pooling layer with 3 dense layers of 1000 hidden units each, and add a final dense layer that outputs means and variances for the latent space. We find that the dense layers provide the capacity needed to accomplish both generative and discriminative tasks with a single network. For the decoder network we use a "mirrored" version of this architecture, reversing the order of layer sizes, replacing convolutions with transposed convolutions, and removing pooling layers; we maintain the residual structure of the network. Our best classification results with this architecture were achieved with a latent space dimension of 200.

D.3 Beta-VAE regularization

As an additional form of regularization for our model, we allow our hyperparameter optimization to adjust a weight on the KL-divergence term in the variational lower bound, which we call β as in previous work (Higgins et al., 2017):

\mathcal{L}_{\beta}^{\text{VAE}}(x;\theta,\phi) = \mathbb{E}_{q_{\phi}(z\mid x)}\left[\log p_{\theta}(x\mid z)\right] + \beta\,\mathbb{E}_{q_{\phi}(z\mid x)}\left[\log\frac{p(z)}{q_{\phi}(z\mid x)}\right]   (19)

This allows us to encourage q_ϕ(z ∣ x) to more closely conform to the prior, which may be necessary to balance the scale of the objective, depending on the likelihoods used and the dimensionality of the dataset.
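For a diagonal-Gaussian encoder and standard normal prior, the second term of Eq. (19) is the negative KL divergence, which is available in closed form; the sketch below is a minimal NumPy illustration with illustrative names, not our training code.

```python
# Minimal sketch of the beta-weighted objective in Eq. (19) for a diagonal-Gaussian
# encoder q_phi(z|x) = N(mu, diag(sigma^2)) and standard normal prior; names are illustrative.
import numpy as np

def beta_elbo(log_px_given_z, enc_mu, enc_sigma, beta=1.0):
    """Reconstruction term minus beta times KL(q_phi(z|x) || N(0, I))."""
    kl = 0.5 * np.sum(enc_mu**2 + enc_sigma**2 - 2.0 * np.log(enc_sigma) - 1.0, axis=-1)
    return log_px_given_z - beta * kl
```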

D.4 Prediction model regularization

We add two standard regularization terms to the prediction model used in our constraint, ŷ_w(y ∣ z). The first is an ℓ₂ regularizer on the regression weights, ‖w‖₂², to help reduce overfitting. The second is an entropy penalty. As ŷ_w(y ∣ z) defines a categorical distribution over labels, we compute this as −𝔼_{ŷ_w(y|z)}[log ŷ_w(y ∣ z)], which has been shown to be helpful for semi-supervised learning in Grandvalet & Bengio (2004) and was used as part of the standardized training framework of Oliver et al. (2018). We allowed our hyperparameter optimization approach to select appropriate weights for both terms.
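Concretely, the two penalties could be computed as in the short sketch below (a NumPy illustration with illustrative names; the actual weights on each term are selected by the hyperparameter search).

```python
# Minimal sketch of the two predictor regularizers; names are illustrative.
import numpy as np

def l2_penalty(weights):
    """Squared L2 norm of the prediction weights, ||w||_2^2."""
    return np.sum(weights**2)

def entropy_penalty(class_probs, eps=1e-8):
    """Entropy of the categorical prediction; penalizing it encourages confident predictions."""
    return -np.sum(class_probs * np.log(class_probs + eps), axis=-1)
```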

D.5 Image pre-processing

For all of our image datasets, we rescale the inputs to the range [-1, 1]. For our NORB classification results, we downsample each image to 48x48 pixels. For our SVHN classification results, we convert images to grayscale to reduce the representational load on our generative model. Before the grayscale conversion, we apply contrast normalization to better disambiguate the colors within each image.

For the SVHN and NORB results, we follow the recommendation of a recent survey of semi-supervised learning methods (Oliver et al., 2018) and apply a single data augmentation technique: random translations by up to 2 pixels in each direction. For generative results, we retain the original color images and train with full labels.
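The rescaling and translation augmentation amount to the following minimal NumPy sketch; the function names and the zero-padding choice are our own illustration, not necessarily the exact implementation.

```python
# Minimal sketch of the preprocessing and augmentation steps; names are illustrative.
import numpy as np

def rescale_to_unit_range(images_uint8):
    """Map 8-bit pixel intensities from [0, 255] to [-1, 1]."""
    return images_uint8.astype(np.float32) / 127.5 - 1.0

def random_translate(image, max_shift=2):
    """Randomly shift an HxWxC image by up to max_shift pixels in each direction (zero padded)."""
    dy, dx = np.random.randint(-max_shift, max_shift + 1, size=2)
    pad = ((max_shift, max_shift), (max_shift, max_shift), (0, 0))
    padded = np.pad(image, pad, mode="constant")
    h, w = image.shape[:2]
    return padded[max_shift + dy: max_shift + dy + h, max_shift + dx: max_shift + dx + w]
```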

D.6 Likelihoods

For all of our image datasets, we use the Noise-Normal likelihood for our CPC methods. For all experiments on toy data (e.g., half-moon), we use a normal likelihood.

For our implementation of M2 in the extensive MNIST experiments, we retained the Bernoulli likelihood used by the original authors (Kingma et al., 2014). That is, we rescaled each pixel's intensity value to the unit interval [0, 1], and then sampled binary values from a Bernoulli distribution with probability equal to the intensity.
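This dynamic binarization can be written as a short transformation; the sketch below is an illustration with our own function name, assuming 8-bit input images.

```python
# Minimal sketch of Bernoulli binarization of pixel intensities; names are illustrative.
import numpy as np

def binarize(images_uint8, rng=None):
    """Rescale 8-bit intensities to [0, 1] and sample binary pixel values from a Bernoulli."""
    rng = np.random.default_rng() if rng is None else rng
    probs = images_uint8.astype(np.float32) / 255.0
    return (rng.random(probs.shape) < probs).astype(np.float32)
```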

D.7 Summary of hyperparameter settings for final results

Table 3 below lists the hyperparameter settings used for our final CPC-VAE results on each dataset.

Hyperparameter | MNIST (100) | SVHN (1000) | NORB (1000)
Encoder/decoder | 2 FC layers | WRN-28-2 + 3 FC | WRN-28-2 + 3 FC
Fully connected layer size | 1000 units | 1000 units | 1000 units
Network activations | Softplus | Leaky ReLU | Leaky ReLU
Latent dimension | 50 | 200 | 200
Pixel likelihood | Noise-Normal | Noise-Normal | Noise-Normal
Prediction multiplier λ | 25 | 140 | 80
Consistency multiplier γ | 4.25λ | 1.25λ | 4λ
Aggregate consistency penalty | 0.1λ | 0.2λ | 0.2λ
β-VAE weight | 1 | 1.3 | 2
Predictor reg. (‖w‖₂²) | 1 | 1 | 1
Entropy reg. (𝔼_{p_w(y|z)}[log p_w(y|z)]) | 0.5λ | 0.5λ | 0.5λ
Translation range (α^(1) = α^(2)) | 0.2 × (image width) | 0.2 × (image width) | 0.2 × (image width)
Rotation range (α^(3)) | 0.4 rad | 0.5 rad | 0.4 rad
Shear range (α^(4)) | 0.2 rad | 0.2 rad | 0.2 rad
Scale range (α^(5) = α^(6)) | 1.5 | 1.5 | 1.5
Optimizer | ADAM | ADAM | ADAM
Learning rate | 3×10^-4 | 3×10^-4 | 3×10^-4
Table 3: Hyperparameter settings for semi-supervised learning experiments with our CPC-VAE.

Appendix E Dataset Details

For each dataset considered in our paper, we provide a more detailed overview of its contents and properties.

E.1 MNIST

Overview. We consider a 10-way exclusive categorization task for MNIST digits.

We use 28x28 pixel grayscale images.

Public availability. We will make code to extract our version available after publication.

Data statistics.

Statistics for MNIST are shown in Table 4.

split num. examples label distribution
labeled train 100 [0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1]
unlabeled train 49900 [0.1 0.11 0.1 0.1 0.1 0.09 0.1 0.1 0.1 0.1]
labeled valid 10000 [0.1 0.11 0.1 0.1 0.1 0.09 0.1 0.1 0.1 0.1]
labeled test 10000 [0.1 0.11 0.1 0.1 0.1 0.09 0.1 0.1 0.1 0.1]
Table 4: MNIST dataset.

E.2 SVHN

Overview. We consider a 10-way exclusive categorization task for SVHN digits.

We use 32x32 pixel grayscale images.

Public availability. We will make code to extract our version available after publication.

Data statistics.

Statistics for SVHN are shown in Table 5.

split num. examples label distribution
labeled train 1000 [0.10 0.10 0.10 0.10 0.10 0.10 0.10 0.10 0.10 0.10]
unlabeled train 62257 [0.07 0.19 0.15 0.12 0.10 0.09 0.08 0.08 0.07 0.06]
labeled valid 10000 [0.07 0.19 0.14 0.12 0.10 0.09 0.08 0.08 0.07 0.06]
labeled test 26032 [0.07 0.20 0.16 0.11 0.10 0.09 0.08 0.08 0.06 0.06]
Table 5: SVHN dataset.

E.3 NORB

Overview. We consider a 5-way exclusive categorization task over the NORB object categories.

We use 48x48 pixel grayscale images.

Public availability. We will make code to extract our version available after publication.

Data statistics.

Statistics for NORB are shown in Table 6.

split num. examples label distribution
labeled train 1000 [0.2 0.2 0.2 0.2 0.2]
unlabeled train 21300 [0.2 0.2 0.2 0.2 0.2]
labeled valid 2000 [0.2 0.2 0.2 0.2 0.2]
labeled test 24300 [0.2 0.2 0.2 0.2 0.2]
Table 6: NORB dataset.

E.4 CelebA

Overview. We consider a 4-way exclusive categorization task for CelebA face images.

We use 64x64 pixel grayscale images. Images were cropped to square from the CelebA aligned variant and downscaled to 64x64 resolution for computational efficiency. Labels were generated from the provided attributes; our dataset uses 4 classes: woman/neutral face, man/neutral face, woman/smiling, man/smiling.

Public availability. We will make code to extract our version available after publication.

Data statistics.

Statistics for CelebA are shown in Table 7.

split num. examples label distribution
labeled train 1000 [0.25 0.25 0.25 0.25]
unlabeled train 21300 [0.25 0.25 0.34 0.16]
labeled valid 2000 [0.25 0.25 0.25 0.25]
labeled test 24300 [0.27 0.23 0.35 0.15]
Table 7: CelebA dataset.