Weakly Supervised Invariant Representation Learning Via Disentangling Known and Unknown Nuisance Factors
email: {jiagengz, hanchenx, wamageed}@isi.edu
Abstract
Disentangled and invariant representations are two critical goals of representation learning, and many approaches have been proposed to achieve either one of them. We argue that these two goals are in fact complementary, and propose a framework that accomplishes both simultaneously. We introduce a weakly supervised signal to learn disentangled representations consisting of three splits that contain predictive, known nuisance, and unknown nuisance information, respectively. Furthermore, we incorporate a contrastive method to enforce representation invariance. Experiments show that the proposed method outperforms state-of-the-art (SOTA) methods on four standard benchmarks, and that it achieves better adversarial defense ability than competing methods without adversarial training.
1 Introduction
Robust representation learning, which aims at preventing overfitting and increasing generality, can benefit various downstream tasks [he2015deep, bib:vae, DBLP:journals/corr/abs-1905-12506]. Typically, a DNN learns to encode a representation which contains all factors of variation of the data, such as pose, expression, illumination, and age for face recognition, as well as other nuisance factors which are unknown or unlabelled. Disentangled representation learning and invariant representation learning are often used to address these challenges.

For disentangled representation learning, Bengio et al. [bengio2014representation] define a disentangled representation as one in which a change in a given dimension corresponds to a variation in one and only one generative factor of the input data. Although many unsupervised learning methods have been proposed [bib:betaVAE, burgess2018understanding, bib:factorvae], Locatello et al. [locatello2019challenging] have shown both theoretically and empirically that disentangling the factors of variation is impossible without supervision or inductive bias. To this end, recent works have adopted the concept of semi-supervised learning [bib:semi-disen] and weakly supervised learning [bib:adavae]. On the other hand, Jaiswal et al. [bib:uai] take an invariant representation learning perspective in which they split the representation into two parts, where one part contains only prediction-related information and the other merely contains nuisance factors.
Invariant representation learning aims to encode predictive latent factors which are invariant to nuisance factors in the inputs [bib:cai, bib:uai, bib:vfae, bib:cvib, bib:irmi]. By removing information about nuisance factors, invariant representation learning achieves good performance on challenges like adversarial attacks [DBLP:conf/cvpr/ChenKI20] and out-of-distribution generalization [bib:uai]. Furthermore, invariant representation learning has also been studied in reinforcement learning settings [DBLP:conf/aaai/Castro20].
Despite the success of both disentangled and invariant representation learning methods, the relation between the two has not been thoroughly investigated. As shown in Figure 1(a), invariant representation learning methods learn representations that maximize prediction accuracy by separating predictive factors from all other nuisance factors, while leaving the representations of both known and unknown nuisance factors entangled. Meanwhile, as illustrated in Figure 1(b), although supervised disentangled representation methods can identify known nuisance factors, they fail to handle unknown nuisance factors, which may hurt downstream prediction tasks. Inspired by this observation, we propose a new training framework that seeks to achieve disentanglement and invariance of representations simultaneously. To split the known nuisance factors $z_k$ from the predictive factors $z_p$ and the unknown nuisance factors $z_u$, we introduce weak supervision signals to achieve disentangled representation learning. To make the predictive factors $z_p$ independent of all nuisance factors, we introduce a new invariant regularizer via reconstruction. The predictive factors from the same class are further aligned through a contrastive loss to enforce invariance. Moreover, since our model achieves more robust representations compared to other invariant models, it is demonstrated to obtain better adversarial defense ability. In summary, our main contributions are:
• Extending and combining both disentangled and invariant representation learning, and proposing a novel approach to robust representation learning.
• Proposing a novel strategy for splitting the predictive, known nuisance, and unknown nuisance factors, where mutual independence of those factors is achieved by the reconstruction step used during training.
• Outperforming state-of-the-art (SOTA) models on invariance tasks on standard benchmarks.
• Showing that the invariant latent representation trained using our method is also disentangled.
• Without using adversarial training, our model has better adversarial defense ability than other invariant models, which reflects that our methods increase the generality of the model.
2 Related Work
Disentangled representation learning: Early works on disentangled representation learning aim at learning disentangled latent factors by implementing an autoencoder framework [bib:betaVAE, bib:factorvae, burgess2018understanding]. The variational autoencoder (VAE) [bib:vae] is commonly used as the basic framework of disentanglement learning methods. A VAE uses a DNN to map the high-dimensional input $x$ to a low-dimensional representation $z$. The latent representation $z$ is then mapped back to a high-dimensional reconstruction $\hat{x}$. As shown in Equation 1, the overall objective function used to train a VAE is the evidence lower bound (ELBO) of the likelihood $\log p(x)$, which contains two parts: the quality of reconstruction and the Kullback-Leibler divergence ($D_{KL}$) between the posterior distribution $q_\phi(z|x)$ and the assumed prior $p(z)$. The VAE then uses the negative of the ELBO, $-\mathcal{L}_{ELBO}$, as the loss function to update the parameters of the model.
$\mathcal{L}_{ELBO} = \mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)] - D_{KL}(q_\phi(z|x) \,\|\, p(z))$   (1)
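To make Equation 1 concrete, below is a minimal PyTorch sketch of the negative ELBO, assuming a diagonal-Gaussian posterior and a Bernoulli decoder; the function and tensor names are illustrative rather than from the paper.

```python
import torch
import torch.nn.functional as F

def negative_elbo(x, x_hat, mu, logvar):
    """Negative of Equation 1 for q(z|x) = N(mu, diag(exp(logvar))) and prior N(0, I)."""
    # -E_q[log p(x|z)]: reconstruction term (Bernoulli decoder -> BCE)
    recon = F.binary_cross_entropy(x_hat, x, reduction="sum")
    # D_KL(q(z|x) || p(z)) in closed form for diagonal Gaussians
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon + kl
```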
Advanced methods based on the VAE improve disentanglement performance by implementing new disentanglement regularizers. $\beta$-VAE [bib:betaVAE] modifies the original VAE by adding a hyper-parameter $\beta$ to balance the weights of the reconstruction loss and the $D_{KL}$ term. When $\beta > 1$, the model gains stronger disentanglement regularization. AnnealedVAE [burgess2018understanding] implements a dynamic algorithm that anneals the weight of the $D_{KL}$ term from a large to a small value during training. FactorVAE [bib:factorvae] proposes to use a discriminator to distinguish between the joint distribution of the latent factors $q(z)$ and the product of the marginal distributions of each latent factor $\prod_j q(z_j)$. By using the discriminator, FactorVAE can automatically find a better balance between reconstruction quality and disentangled representation. Compared to $\beta$-VAE, DIP-VAE [bib:dipvae] adds another regularizer $D(q(z) \,\|\, p(z))$ between the marginal distribution of the latent factors and the prior to further aid disentangled representation learning, where $D$ can be any proper distance function such as mean squared error. $\beta$-TCVAE, proposed by [chen2019isolating], decomposes the $D_{KL}$ term used in the VAE into three parts: total correlation, index-code mutual information, and dimension-wise KL divergence. To overcome the challenge posed by [locatello2019challenging], AdaVAE [bib:adavae] purposely chooses pairs of inputs as a supervision signal to learn representation disentanglement.

Invariant Representation Learning: Methods that aim at learning invariant representations can be classified into two groups: those that require annotations of nuisance factors [bib:nnmmd, bib:vfae] and those that do not. A considerable number of approaches using nuisance factor annotations have recently been proposed. By implementing a regularizer which minimizes the Maximum Mean Discrepancy (MMD) on a neural network (NN), the NN+MMD approach [bib:nnmmd] removes the effects of nuisance factors from the predictive factors. On the basis of NN+MMD, the Variational Fair Autoencoder (VFAE) [bib:vfae] uses special priors which encourage independence between the nuisance factors and the ideal invariant factors. The Controllable Adversarial Invariance (CAI) [bib:cai] approach applies the gradient reversal trick [bib:dann], which penalizes the model if the latent representation contains information about nuisance factors. CVIB [bib:cvib] proposes a conditional form of the Information Bottleneck (IB) and encourages invariant representation learning by optimizing its variational bounds.
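As an illustration of the annotation-based regularizers above, here is a small sketch of a biased RBF-kernel MMD penalty of the kind NN+MMD applies between latent codes for two values of a known nuisance factor; this is a generic estimator, not the exact implementation of [bib:nnmmd].

```python
import torch

def rbf_mmd(z_a, z_b, sigma=1.0):
    """Biased MMD^2 estimate between two batches of latent codes.

    z_a, z_b: (n, d) and (m, d) representations computed for two different
    values of an annotated nuisance factor. Minimizing this penalty pushes
    the two latent distributions together.
    """
    def k(x, y):
        return torch.exp(-torch.cdist(x, y).pow(2) / (2 * sigma ** 2))
    return k(z_a, z_a).mean() - 2 * k(z_a, z_b).mean() + k(z_b, z_b).mean()
```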
However, due to the constraint of demanding annotations, these methods take more effort to pre-process the data and encounter challenges when the annotations are inaccurate or insufficient. Compared to annotation-hungry approaches, annotation-free methods are easier to implement in practice. The Unsupervised Adversarial Invariance (UAI) [bib:uai] approach splits the latent factors into factors useful for prediction and nuisance factors. UAI encourages the independence of these two sets of latent factors by incorporating competition between the prediction and reconstruction objectives. NN+DIM [bib:irmi] achieves invariant representations by using pairs of inputs and applying a neural-network-based mutual information estimator to minimize the mutual information between the two shared representations. Furthermore, Sanchez et al. [bib:irmi] employ a discriminator to distinguish between the shared representation and the nuisance representation.
3 Learning disentangled and invariant representations
3.1 Model Architecture

As illustrated in Figure 2, the architecture of the proposed model contains two components: a generation module and a prediction module. Similar to a VAE, the generation module performs the encoding-decoding task. However, it encodes the input $x$ into latent factors $z = [z_p, z_n]$, where $z_p$ represents the latent predictive factors that contain useful information for the prediction task, whereas $z_n$ represents the latent nuisance factors and can be further divided into known nuisance factors $z_k$ and unknown nuisance factors $z_u$.
The known nuisance factors $z_k$ are discovered and separated from $z_n$ via weakly supervised disentangled representation learning, where the joint distribution factorizes as $q(z_p, z_k, z_u) = q(z_p)\,q(z_k)\,q(z_u)$. Since $z_n$ is the split containing the nuisance factors, after $z_k$ is identified, the remaining factors of $z_n$ naturally form the unknown nuisance factors $z_u$. Then, $z_p$ and $z_n$ are concatenated to generate the reconstruction $\hat{x}$, which is used to measure the quality of reconstruction. To enforce independence between $z_p$ and $z_n$, we add a regularizer using another reconstruction task, in which the average mean and variance of the predictive factors are used to form new latent factors; this is discussed in Section 3.3. In the prediction module, we further incorporate a contrastive loss to cluster the predictive latent factors belonging to the same class.
3.2 Learning independent known nuisance factors
As illustrated in Figure 2, the known nuisance factors $z_k$ are discovered and separated from $z_n$, where $z_n = [z_k, z_u]$, since nuisance information is expected to be present only within $z_n$.
To fulfill the theoretical requirement of including a supervision signal for disentangled representation learning, as proven in [locatello2019challenging], we use selected pairs of inputs $x_1$ and $x_2$ as supervision signals, where the two inputs share all but a few generative factors. As illustrated in Figure 3, during training, the network encodes a pair of inputs $x_1$ and $x_2$ into two latent factors $z_1$ and $z_2$ respectively, which are then decoded to reconstruct $\hat{x}_1$ and $\hat{x}_2$. To encourage representation disentanglement, certain elements of $z_1$ and $z_2$ are detected and swapped to generate two new corresponding latent factors $z'_1$ and $z'_2$.

The two new latent factors are then decoded into new reconstructions $\hat{x}'_1$ and $\hat{x}'_2$. By comparing $\hat{x}'_1$ and $\hat{x}'_2$ with $\hat{x}_1$ and $\hat{x}_2$, the known nuisance factors are discovered, and the elements of $z_k$ are enforced to be mutually independent of each other.
Selecting image pairs for training and latent factor assumptions:
As mentioned in [bib:betaVAE], the true world simulator that uses generative factors to generate $x$ can be modeled as $p(x \mid v, w) = \mathrm{Sim}(v, w)$, where $v$ denotes the generative factors and $w$ denotes other nuisance factors. Inspired by this, we choose pairs of images by randomly selecting several generative factors to keep identical while leaving the values of the remaining generative factors random. Each image has corresponding generative factors $v$, and a training pair is generated as follows: we first randomly select a sample $x_1$ whose generative factors are $v_1$. We then randomly change the value of $k$ elements of $v_1$ to form new generative factors $v_2$ and choose another sample $x_2$ according to $v_2$. During training, the indices of the differing generative factors between $v_1$ and $v_2$, as well as the ground-truth values of all generative factors, are not available to the model. The model is weakly supervised since it is trained with only the knowledge of the number of factors that have changed. Ideally, if the model can learn a disentangled representation, it will encode the image pair $x_1$ and $x_2$ into corresponding representations $z_1$ and $z_2$ with the characteristic shown in Equation 2. We denote the set of all differing elements between $z_1$ and $z_2$ by $S_d$ and the set of all latent factors by $S$, such that $S_d \subseteq S$.
$q(z_{1,i} \mid x_1) = q(z_{2,i} \mid x_2) \;\; \forall i \in S \setminus S_d, \qquad q(z_{1,i} \mid x_1) \neq q(z_{2,i} \mid x_2) \;\; \forall i \in S_d$   (2)
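A small sketch of the pair-sampling procedure above, assuming access to a dataset indexed by its ground-truth generative factors (as in 3dShapes/MPI3D); the array names and exact matching strategy are illustrative.

```python
import numpy as np

def sample_pair(images, factors, k, rng=None):
    """Draw (x1, x2) differing in exactly k randomly chosen generative factors.

    images: (N, ...) samples; factors: (N, F) integer factor values.
    Only k is exposed to the model; the changed indices and the factor
    values themselves stay hidden, which makes the supervision weak.
    """
    rng = rng or np.random.default_rng()
    i = rng.integers(len(images))
    v2 = factors[i].copy()
    for j in rng.choice(factors.shape[1], size=k, replace=False):
        other_vals = np.setdiff1d(np.unique(factors[:, j]), [v2[j]])
        v2[j] = rng.choice(other_vals)
    # pick any sample whose factor vector matches v2 exactly
    # (assumes the dataset covers the full factor grid)
    j = rng.choice(np.flatnonzero((factors == v2).all(axis=1)))
    return images[i], images[j]
```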
Detecting and Swapping the distinct latent factors:
The VAE adopts the reparameterization trick to make the posterior distribution differentiable, where the posterior distribution of the latent factors is commonly assumed to be a factorized multivariate Gaussian: $q_\phi(z|x) = \mathcal{N}(\mu, \sigma^2 I)$ [bib:vae]. Under this assumption, we can directly measure the difference between the corresponding dimensions of the two latent representations $z_1$ and $z_2$ via a divergence $D$, which can be the KL divergence ($D_{KL}$). We show the process of detecting the distinct latent factors in Equation 3, where a larger value of $d_i$ implies a higher difference between the two corresponding latent factor distributions.
$d_i = D\big(q(z_{1,i} \mid x_1) \,\|\, q(z_{2,i} \mid x_2)\big), \quad \forall i \in S$   (3)
Since the model only has knowledge of the number of differing generative factors $k$, we swap all corresponding dimensions of $z_1$ and $z_2$ except the $k$ dimensions with the highest values of $d_i$. We incorporate this swapping step to create two new latent representations $z'_1$ and $z'_2$, as shown in Equation 4.
$z'_{1,i} = z_{2,i}, \;\; z'_{2,i} = z_{1,i} \;\; \forall i \notin \mathrm{top}\text{-}k(d); \qquad z'_{1,i} = z_{1,i}, \;\; z'_{2,i} = z_{2,i} \;\; \text{otherwise}$   (4)
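A sketch of the detect-and-swap step in Equations 3-4, assuming factorized Gaussian posteriors so that the per-dimension KL divergence has a closed form; the variable names are illustrative.

```python
import torch

def detect_and_swap(mu1, logvar1, mu2, logvar2, k):
    """Keep the k most divergent latent dimensions, swap all others (Eqs. 3-4)."""
    var1, var2 = logvar1.exp(), logvar2.exp()
    # Eq. 3: per-dimension KL(q(z_1,i|x1) || q(z_2,i|x2)) between univariate Gaussians
    d = 0.5 * (logvar2 - logvar1 + (var1 + (mu1 - mu2) ** 2) / var2 - 1)
    swap = torch.ones_like(d, dtype=torch.bool)
    swap[torch.topk(d, k).indices] = False          # top-k dims are kept in place
    # Eq. 4: exchange the posterior parameters on every swapped dimension
    mu1s, mu2s = mu1.clone(), mu2.clone()
    mu1s[swap], mu2s[swap] = mu2[swap], mu1[swap]
    lv1s, lv2s = logvar1.clone(), logvar2.clone()
    lv1s[swap], lv2s[swap] = logvar2[swap], logvar1[swap]
    return (mu1s, lv1s), (mu2s, lv2s)
```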
Disentangled representation loss function:
After the swapped elements are obtained, they are concatenated with the retained (unswapped) elements of $z_1$ and $z_2$ respectively, to generate two new latent representations $z'_1$ and $z'_2$. $z'_1$ and $z'_2$ are decoded into new reconstructions $\hat{x}'_1$ and $\hat{x}'_2$. Since there are only $k$ differing generative factors between a pair of images, ideally, after encoding the images, there should also be merely $k$ pairs of differing distributions in the latent representation space. By swapping all other latent factors except these, the new representations $z'_1$ and $z'_2$ are the same as the original representations $z_1$ and $z_2$. Accordingly, the new reconstructions $\hat{x}'_1$ and $\hat{x}'_2$ should be identical to the original reconstructions $\hat{x}_1$ and $\hat{x}_2$. Therefore, we design the disentangled representation loss in Equation 5, where $f$ can be any suitable distance function, e.g., mean squared error (MSE) or binary cross-entropy (BCE).
$\mathcal{L}_{dis} = f(\hat{x}'_1, \hat{x}_1) + f(\hat{x}'_2, \hat{x}_2)$   (5)
Training Strategies for disentangled representation learning:
To further improve the performance of disentangled representation learning, we design two strategies: warmup by amount and warmup by difficulty. Recall that in the swapping step, the model needs to swap elements of the latent representations. At the beginning of training, exchanging too many latent factors easily leads to mistakes. Therefore, in the first strategy, we gradually increase the number of latent factors being swapped from a small number to all $|S| - k$ swappable factors. Furthermore, to smoothly increase the training difficulty, we set the number of differing generative factors $k$ to its smallest value at the beginning and increase it as training continues.
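The two warmup strategies can be implemented as simple schedules; the linear ramps below are an assumption, since the text states only that both quantities grow during training.

```python
def warmup_by_amount(step, warmup_steps, num_swappable):
    """Grow the number of swapped latent dimensions from 1 toward all |S| - k."""
    frac = min(step / max(warmup_steps, 1), 1.0)
    return max(1, round(frac * num_swappable))

def warmup_by_difficulty(step, warmup_steps, max_k):
    """Grow the number of differing generative factors k from 1 to its maximum."""
    frac = min(step / max(warmup_steps, 1), 1.0)
    return max(1, round(frac * max_k))
```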
3.3 Learning invariant predictive factors
After we obtain the disentangled representation, the predictive factors $z_p$ may still be entangled with $z_n$. Therefore, we need to add further constraints to achieve a fully invariant representation $z_p$.
Making $z_p$ independent of $z_n$:
As shown in [locatello2019challenging], supervision signals need to be introduced for disentangled representation learning. Similarly, the independence of $z_p$ and $z_n$ also requires a supervision signal, as we discuss in the Appendix. Fortunately, for supervised training, a batch of samples naturally contains such a signal: the class labels. Similar to Equation 2, the distributions of the predictive representations should be the same for samples of the same class, as shown in Equation 6, where $c(x_i)$ denotes the class of sample $x_i$.
$q(z_p \mid x_i) = q(z_p \mid x_j) \;\; \forall \, i, j \;\; \text{s.t.} \;\; c(x_i) = c(x_j)$   (6)
Similar to the method we use for disentangled representation learning, we generate a new latent representation $\bar{z}$ and its corresponding reconstruction $\bar{x}$. Then, we enforce the disentanglement between $z_p$ and $z_n$ by comparing the new reconstruction $\bar{x}$ with $\hat{x}$. In contrast to the swapping method described in Section 3.2, since a training batch often contains more than two samples from the same class, the swapping method is difficult to implement in this situation. Therefore, we generate the new latent representations by calculating the average mean and average variance of the latent predictive factors from the same class, as shown in Equation 7.
$\bar{\mu}_c = \frac{1}{|B_c|} \sum_{i \in B_c} \mu_i, \qquad \bar{\sigma}_c^2 = \frac{1}{|B_c|} \sum_{i \in B_c} \sigma_i^2, \qquad \bar{z}_p \sim \mathcal{N}(\bar{\mu}_c, \bar{\sigma}_c^2 I)$   (7)

where $B_c$ denotes the set of samples of class $c$ in the batch.
We then generate the new reconstruction $\bar{x}$ using the same decoder as in the other reconstruction tasks, and enforce the disentanglement of $z_p$ and $z_n$ by calculating the reconstruction loss $f(\bar{x}, \hat{x})$ and updating the parameters of the model according to its gradient.
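A sketch of this class-averaging step (Equation 7): each sample's predictive posterior parameters are replaced by their within-class batch averages before reparameterization and decoding. The names, and the choice to average variances rather than log-variances, are assumptions.

```python
import torch

def class_averaged_zp(mu, logvar, labels):
    """Replace each sample's z_p posterior with the class-average posterior.

    mu, logvar: (batch, dim) parameters of q(z_p|x). Returns latents z_bar
    sampled from N(mu_bar_c, sigma_bar_c^2) for each sample's class c.
    """
    var = logvar.exp()
    mu_bar, var_bar = mu.clone(), var.clone()
    for c in labels.unique():
        idx = labels == c
        mu_bar[idx] = mu[idx].mean(dim=0)
        var_bar[idx] = var[idx].mean(dim=0)
    return mu_bar + var_bar.sqrt() * torch.randn_like(mu_bar)   # reparameterize
```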
Contrastive feature alignment:
To achieve an invariant representation, we need to make sure that the latent representations useful for prediction can also be clustered according to their corresponding classes. Even though the commonly used cross-entropy (CE) loss can accomplish a similar goal, the direct aim of the CE loss is logit-level alignment, changing the representation distribution according to the logits, which does not guarantee a uniform distribution of features. We therefore incorporate contrastive methods to ensure that representation/feature alignment is accomplished effectively [Wang_2021_CVPR].
Similar to [bib:super_contras], we use a supervised contrastive loss to achieve feature alignment and cluster the representations according to their classes, as shown in Equation 8, where $P(i)$ is the set of samples from the same class as sample $i$, $A(i)$ is the set of all other samples in the batch, and $\tau$ is a temperature hyper-parameter.
$\mathcal{L}_{con} = \sum_{i} \frac{-1}{|P(i)|} \sum_{j \in P(i)} \log \frac{\exp(z_p^{(i)} \cdot z_p^{(j)} / \tau)}{\sum_{a \in A(i)} \exp(z_p^{(i)} \cdot z_p^{(a)} / \tau)}$   (8)
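A sketch of the supervised contrastive loss in Equation 8, in the style of [bib:super_contras], applied directly to the predictive latents of a batch (the usual two-view augmentation is omitted for brevity):

```python
import torch
import torch.nn.functional as F

def supcon_loss(z_p, labels, tau=0.1):
    """Supervised contrastive loss over a batch of predictive representations."""
    z = F.normalize(z_p, dim=1)
    sim = (z @ z.t()) / tau                                # pairwise similarities
    eye = torch.eye(len(z), dtype=torch.bool, device=z.device)
    sim = sim.masked_fill(eye, float("-inf"))              # anchor never matches itself
    log_prob = sim - torch.logsumexp(sim, dim=1, keepdim=True)
    pos = (labels[:, None] == labels[None, :]) & ~eye      # P(i): same-class pairs
    # average log-probability of positives per anchor (zero out non-positives)
    pos_log_prob = log_prob.masked_fill(~pos, 0.0).sum(dim=1) / pos.sum(dim=1).clamp(min=1)
    return -pos_log_prob[pos.any(dim=1)].mean()            # skip anchors without positives
```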
The final loss function used to train the model, after adding the standard cross-entropy (CE) loss used to train the classifier, is given by Equation 9.
$\mathcal{L} = \mathcal{L}_{VAE} + \lambda_{dis} \mathcal{L}_{dis} + \lambda_{inv} \mathcal{L}_{inv} + \lambda_{con} \mathcal{L}_{con} + \lambda_{CE} \mathcal{L}_{CE}$   (9)
Average and worst-case test accuracy (± std) on Colored-MNIST, 3dShapes, and MPI3D:

| Models | Colored-MNIST Avg Acc | Colored-MNIST Worst Acc | 3dShapes Avg Acc | 3dShapes Worst Acc | MPI3D Avg Acc | MPI3D Worst Acc |
|---|---|---|---|---|---|---|
| Baseline | 95.12 ± 2.42 | 66.17 ± 3.31 | 98.87 ± 0.52 | 96.89 ± 1.25 | 90.12 ± 3.13 | 87.89 ± 4.31 |
| VFAE [bib:vfae] | 93.12 ± 3.07 | 65.54 ± 6.21 | 97.72 ± 0.81 | 93.34 ± 1.05 | 86.69 ± 3.12 | 82.43 ± 3.25 |
| CAI [bib:cai] | 93.56 ± 2.76 | 63.17 ± 5.61 | 97.62 ± 0.53 | 94.32 ± 0.89 | 86.63 ± 2.14 | 82.16 ± 5.83 |
| CVIB [bib:cvib] | 93.31 ± 3.09 | 70.12 ± 4.77 | 97.11 ± 0.59 | 94.46 ± 0.90 | 87.04 ± 3.02 | 85.61 ± 2.08 |
| UAI [bib:uai] | 94.74 ± 2.19 | 74.25 ± 2.69 | 97.13 ± 1.02 | 95.21 ± 1.03 | 87.89 ± 4.23 | 83.01 ± 2.21 |
| NN+DIM [bib:irmi] | 94.48 ± 2.35 | 80.25 ± 3.44 | 97.03 ± 1.07 | 96.02 ± 0.46 | 88.81 ± 1.37 | 82.01 ± 3.34 |
| Our model | 97.96 ± 1.21 | 90.43 ± 2.79 | 98.52 ± 0.51 | 97.63 ± 0.72 | 91.32 ± 2.38 | 89.17 ± 2.69 |
4 Experimental Evaluation
4.1 Benchmarks, Baselines and Metrics
The main objective of this work is to learn invariant representations and reduce overfitting to nuisance factors. As a secondary objective, we also want to ensure that the learned representations are at least no less robust to adversarial attacks. Therefore, all models are evaluated on both an invariant representation learning task and an adversarial robustness task. We use four (4) datasets with different underlying factors of variation to evaluate the model:
• Colored-MNIST: an augmented version of MNIST [lecun-mnisthandwrittendigit-2010] with two known nuisance factors: digit color and background color [bib:irmi]. During training, the background color is chosen from three (3) colors and the digit color from six (6) other colors. At test time, we set the background color to three (3) new colors that differ from the training set.
• Rotation-Colored-MNIST: a further augmented version of Colored-MNIST. The background and digit color settings are the same as in Colored-MNIST. This dataset additionally contains digits rotated to four (4) different angles. For the test data, the rotation angles of the digits are set to ±65° and ±75°. The rotation angles are used as unknown nuisance factors.
• 3dShapes [3dshapes18]: contains 480,000 RGB images, with six (6) different generative factors in the whole dataset. We choose object shape (four (4) classes) as the prediction task; only half of the object colors are used during training, and the samples with the remaining half of the object colors are used to evaluate the performance of the invariant representation.
• MPI3D [gondal2019transfer]: a real-world dataset containing 1,036,800 RGB images, with seven (7) generative factors in the whole dataset. As with 3dShapes, we choose object shape (six (6) classes) as the prediction target, and half of the object colors are used for training.
Average and worst-case test accuracy (± std) on Rotation-Colored-MNIST, with test digits rotated to -75°, -65°, +65°, and +75°:

| Models | -75° Avg Acc | -75° Worst Acc | -65° Avg Acc | -65° Worst Acc | +65° Avg Acc | +65° Worst Acc | +75° Avg Acc | +75° Worst Acc |
|---|---|---|---|---|---|---|---|---|
| Baseline | 77.0 ± 1.3 | 62.3 ± 1.9 | 89.7 ± 1.2 | 77.5 ± 2.2 | 85.8 ± 1.2 | 65.8 ± 3.0 | 68.3 ± 2.2 | 49.9 ± 4.6 |
| VFAE [bib:vfae] | 72.2 ± 2.4 | 58.9 ± 2.3 | 85.8 ± 1.7 | 74.4 ± 2.5 | 84.1 ± 2.1 | 64.6 ± 3.7 | 71.7 ± 1.3 | 48.0 ± 3.8 |
| CAI [bib:cai] | 74.9 ± 0.9 | 59.3 ± 3.9 | 86.5 ± 1.9 | 77.3 ± 2.0 | 84.2 ± 1.7 | 67.8 ± 1.9 | 64.7 ± 4.2 | 42.9 ± 3.7 |
| CVIB [bib:cvib] | 76.1 ± 0.8 | 59.2 ± 3.0 | 88.6 ± 0.9 | 79.1 ± 1.2 | 85.6 ± 0.7 | 68.8 ± 2.9 | 72.2 ± 1.2 | 53.4 ± 2.6 |
| UAI [bib:uai] | 76.0 ± 1.7 | 61.1 ± 5.6 | 88.8 ± 0.7 | 80.0 ± 0.9 | 85.4 ± 1.6 | 68.2 ± 2.3 | 70.2 ± 0.9 | 51.1 ± 2.3 |
| NN+DIM [bib:irmi] | 77.6 ± 2.6 | 69.2 ± 2.7 | 85.2 ± 3.4 | 76.3 ± 4.3 | 84.6 ± 3.1 | 66.7 ± 3.7 | 68.4 ± 3.1 | 53.2 ± 5.6 |
| Our model | 81.0 ± 2.1 | 75.3 ± 2.5 | 90.8 ± 1.6 | 85.7 ± 2.4 | 87.3 ± 2.5 | 82.3 ± 2.1 | 73.2 ± 2.3 | 63.3 ± 2.9 |
Prediction accuracy is used to evaluate the performance of invariant representation learning. Furthermore, we record both the average test accuracy and the worst-case test accuracy, as suggested by [Sagawa*2020Distributionally]. We find that using Equation 9 directly does not guarantee good performance. This may be caused by the inconsistent behavior of the CE loss and the supervised contrastive loss. Thus, we train the classifier separately using the CE loss and use the remaining part of the total loss to train the rest of the model.
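A sketch of this decoupled optimization: the classifier head is trained with CE on detached features, while the encoder/decoder receive gradients only from the remaining losses of Equation 9. The module boundaries and names are illustrative.

```python
import torch
import torch.nn.functional as F

def train_step(batch, labels, model, classifier, model_opt, clf_opt):
    """model(batch, labels) is assumed to return the predictive latents z_p and
    the combined non-CE loss (VAE + disentanglement + invariance + contrastive)."""
    z_p, rep_loss = model(batch, labels)
    model_opt.zero_grad()
    rep_loss.backward()
    model_opt.step()

    # detach so CE gradients update the classifier head only
    ce = F.cross_entropy(classifier(z_p.detach()), labels)
    clf_opt.zero_grad()
    ce.backward()
    clf_opt.step()
```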
Meanwhile, the performance of representation disentanglement is also important for representation invariance, since it evaluates the invariance of the latent factors representing the known nuisance factors.
We adopt the following metrics to evaluate the performance of disentangled representation learning. All metrics range from 0 to 1, where 1 indicates that the latent factors are fully disentangled. (1) Mutual Information Gap (MIG) [chen2019isolating] evaluates the gap between the top two highest mutual information values between a latent factor and the generative factors. (2) Separated Attribute Predictability (SAP) [bib:dipvae] measures the mean difference in prediction error between the top two most predictive latent factors. (3) Interventional Robustness Score (IRS) [suter2019robustly] evaluates the reliance of a latent factor solely on a generative factor, regardless of other generative factors. (4) The FactorVAE (FVAE) score [bib:factorvae] uses a majority-vote classifier to predict the index of a fixed generative factor and takes the prediction accuracy as the final score. (5) DCI-Disentanglement (DCI) [bib:dci] calculates the entropy of the distribution obtained by normalizing the importance of each dimension of the learned representation for predicting the value of a generative factor.