
1 USC Ming Hsieh Department of Electrical and Computer Engineering
2 USC Information Sciences Institute
3 Visual Intelligence and Multimedia Analytics Laboratory
Email: {jiagengz, hanchenx, wamageed}@isi.edu

Weakly Supervised Invariant Representation Learning Via Disentangling Known and Unknown Nuisance Factors

Jiageng Zhu 1,2,3    Hanchen Xie 2,3    Wael Abd-Almageed 1,2,3
Abstract

Disentangled and invariant representations are two critical goals of representation learning, and many approaches have been proposed to achieve either one of them. Since the two goals are in fact complementary, we propose a framework that accomplishes both simultaneously. We introduce a weakly supervised signal to learn a disentangled representation consisting of three splits that contain predictive, known nuisance, and unknown nuisance information, respectively. Furthermore, we incorporate a contrastive method to enforce representation invariance. Experiments show that the proposed method outperforms state-of-the-art (SOTA) methods on four standard benchmarks, and that, without adversarial training, it achieves better adversarial defense ability than competing methods.

1 Introduction

Robust representation learning, which aims at preventing overfitting and improving generalization, can benefit various downstream tasks [he2015deep, bib:vae, DBLP:journals/corr/abs-1905-12506]. Typically, a DNN learns to encode a representation which contains all factors of variation of the data, such as pose, expression, illumination, and age for face recognition, as well as other nuisance factors which are unknown or unlabelled. Disentangled representation learning and invariant representation learning are often used to address these challenges.

Figure 1: Given an image with nuisance factors, (a): invariant representation learning splits the predictive factors from all nuisance factors; (b): disentangled representation learning splits out the known nuisance factors but leaves the predictive and unknown nuisance factors entangled; (c): our method splits the predictive, known nuisance, and unknown nuisance factors simultaneously.

For disentangled representation learning, Bengio et al. [bengio2014representation] define a disentangled representation as one in which a change in a given dimension corresponds to variation in one and only one generative factor of the input data. Although many unsupervised learning methods have been proposed [bib:betaVAE, burgess2018understanding, bib:factorvae], Locatello et al. [locatello2019challenging] have shown both theoretically and empirically that disentanglement of the factors of variation is impossible without supervision or inductive bias. To this end, recent works have adopted the concepts of semi-supervised learning [bib:semi-disen] and weakly supervised learning [bib:adavae]. On the other hand, Jaiswal et al. [bib:uai] take an invariant representation learning perspective in which they split the representation $z$ into two parts $z=[z_p, z_n]$, where $z_p$ only contains prediction-related information, and $z_n$ merely contains nuisance factors.

Invariant representation learning aims to encode predictive latent factors which are invariant to nuisance factors in the inputs [bib:cai, bib:uai, bib:vfae, bib:cvib, bib:irmi]. By removing information about nuisance factors, invariant representation learning achieves good performance on challenges such as adversarial attacks [DBLP:conf/cvpr/ChenKI20] and out-of-distribution generalization [bib:uai]. Furthermore, invariant representation learning has also been studied in reinforcement learning settings [DBLP:conf/aaai/Castro20].

Despite the success of both disentangled and invariant representation learning methods, the relation between the two has not been thoroughly investigated. As shown in Figure 1(a), invariant representation learning methods learn representations that maximize prediction accuracy by separating the predictive factors from all other nuisance factors, while leaving the representations of the known and unknown nuisance factors entangled with each other. Meanwhile, as illustrated in Figure 1(b), although supervised disentangled representation methods can identify known nuisance factors, they fail to handle unknown nuisance factors, which may hurt downstream prediction tasks. Inspired by this observation, we propose a new training framework that seeks to achieve disentanglement and invariance of the representation simultaneously. To split the known nuisance factors $z_{nk}$ from the predictive factors $z_p$ and the unknown nuisance factors $z_{nu}$, we introduce weak supervision signals for disentangled representation learning. To make the predictive factors $z_p$ independent of all nuisance factors $z_n$, we introduce a new invariance regularizer via reconstruction. The predictive factors from the same class are further aligned through a contrastive loss to enforce invariance. Moreover, since our model learns a more robust representation than other invariant models, it is demonstrated to achieve better adversarial defense ability. In summary, our main contributions are:

  • Extending and combining both disentangled and invariant representation learning and proposing a novel approach to robust representation learning.

  • Proposing a novel strategy for splitting the predictive, known nuisance, and unknown nuisance factors, where mutual independence of these factors is achieved through the reconstruction steps used during training.

  • Outperforming state-of-the-art (SOTA) models on invariance tasks on standard benchmarks.

  • Showing that the invariant latent representation trained using our method is also disentangled.

  • Showing that, without adversarial training, our model has better adversarial defense ability than other invariant models, which reflects that our method increases the generality of the model.

2 Related Work

Disentangled representation learning: Early works on disentangled representation learning aim at learning disentangled latent factors $z$ by implementing an autoencoder framework [bib:betaVAE, bib:factorvae, burgess2018understanding]. The variational autoencoder (VAE) [bib:vae] is commonly used as the basic framework in disentanglement learning methods. VAE uses a DNN to map the high-dimensional input $x$ to a low-dimensional representation $z$. The latent representation $z$ is then mapped back to a high-dimensional reconstruction $\hat{x}$. As shown in Equation 1, the overall objective used to train the VAE is the evidence lower bound (ELBO) of the likelihood $\log p_{\theta}(x_1, x_2, \ldots, x_n)$, which contains two parts: the quality of the reconstruction and the Kullback-Leibler divergence ($D_{KL}$) between the posterior $q_{\phi}(z|x)$ and the assumed prior $p(z)$. VAE then uses the negative ELBO, $L_{VAE} = -\mathrm{ELBO}$, as the loss function to update the parameters of the model.

L_{VAE} = -\mathrm{ELBO} = -\sum_{i=1}^{N}\Big[\mathbb{E}_{q_{\phi}(z|x^{(i)})}[\log p_{\theta}(x^{(i)}|z)] - D_{KL}\big(q_{\phi}(z|x^{(i)})\,||\,p(z)\big)\Big]     (1)
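For concreteness, a minimal PyTorch-style sketch of this loss is given below; the encoder/decoder that produce x_rec, mu, and logvar are assumed rather than shown, and BCE is chosen as the reconstruction term only for illustration.

```python
import torch
import torch.nn.functional as F

def vae_loss(x, x_rec, mu, logvar):
    """Negative ELBO (Equation 1): reconstruction term plus KL to the
    standard-normal prior p(z). x_rec, mu, logvar come from a hypothetical
    encoder/decoder pair; all tensors have a leading batch dimension."""
    # Reconstruction term: BCE over pixels here; MSE is an equally valid choice of D.
    rec = F.binary_cross_entropy(x_rec, x, reduction="sum")
    # KL(q_phi(z|x) || N(0, I)) in closed form for a diagonal Gaussian posterior.
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return rec + kl
```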

Advanced methods based on VAE improve disentanglement performance by introducing new disentanglement regularizers. $\beta$-VAE [bib:betaVAE] modifies the original VAE by adding a hyper-parameter $\beta$ that balances the weights of the reconstruction loss and $D_{KL}$; when $\beta > 1$, the model receives stronger disentanglement regularization. AnnealedVAE [burgess2018understanding] implements a dynamic schedule that decreases $\beta$ from a large to a small value during training. FactorVAE [bib:factorvae] uses a discriminator to distinguish between the joint distribution of the latent factors $q(z)$ and the product of their marginal distributions $\prod_i q(z_i)$. By using the discriminator, FactorVAE can automatically find a better balance between reconstruction quality and representation disentanglement. Compared to $\beta$-VAE, DIP-VAE [bib:dipvae] adds another regularizer $D(q_{\phi}(z)\,||\,p(z))$ between the marginal distribution of the latent factors $q_{\phi}(z) = \int q_{\phi}(z|x)p(x)dx$ and the prior $p(z)$ to further aid disentangled representation learning, where $D$ can be any proper distance function such as the mean squared error. $\beta$-TCVAE [chen2019isolating] decomposes the $D_{KL}$ term used in VAE into three parts: total correlation, index-code mutual information, and dimension-wise KL divergence. To overcome the challenge posed by [locatello2019challenging], AdaVAE [bib:adavae] purposely chooses pairs of inputs as a weak supervision signal to learn disentangled representations.
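As a minimal illustration of how $\beta$-VAE modifies Equation 1, the sketch below simply re-weights the KL term; the value of beta and the name rec_loss are placeholders, not settings taken from the cited papers.

```python
import torch

def beta_vae_loss(rec_loss, mu, logvar, beta=4.0):
    """beta-VAE objective sketch: the KL term from Equation 1 is multiplied by a
    hyper-parameter beta (beta > 1 strengthens the disentanglement pressure).
    rec_loss is the reconstruction term computed elsewhere (e.g., BCE or MSE)."""
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return rec_loss + beta * kl
```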

Figure 2: Architecture of the proposed model. The red box is the generation module and the blue box is the prediction module. The generation module aims at learning and splitting the known nuisance factors $z_{nk}$ from the unknown nuisance factors $z_{nu}$, and the prediction module aims at learning good predictive factors $z_p$.

Invariant Representation Learning: Methods that aim at learning invariant representations can be classified into two groups: those that require annotations of nuisance factors [bib:nnmmd, bib:vfae] and those that do not. A considerable number of approaches using nuisance factor annotations have been proposed recently. By implementing a regularizer that minimizes the Maximum Mean Discrepancy (MMD) on a neural network (NN), the NN+MMD approach [bib:nnmmd] removes the effects of nuisance factors from the predictive factors. Building on NN+MMD, the Variational Fair Autoencoder (VFAE) [bib:vfae] uses special priors that encourage independence between the nuisance factors and the ideal invariant factors. The Controllable Adversarial Invariance (CAI) [bib:cai] approach applies the gradient reversal trick [bib:dann], which penalizes the model if the latent representation contains information about nuisance factors. CVIB [bib:cvib] proposes a conditional form of the Information Bottleneck (IB) and encourages invariant representation learning by optimizing its variational bounds.

However, due to the constraint of requiring annotations, these methods take more effort to pre-process the data and encounter difficulties when the annotations are inaccurate or insufficient. Compared to annotation-dependent approaches, annotation-free methods are easier to implement in practice. Unsupervised Adversarial Invariance (UAI) [bib:uai] splits the latent factors into factors useful for prediction and nuisance factors. UAI encourages the independence of these two sets of latent factors by introducing competition between the prediction and reconstruction objectives. NN+DIM [bib:irmi] achieves invariant representations by using pairs of inputs and applying a neural-network-based mutual information estimator to minimize the mutual information between the two shared representations. Furthermore, Sanchez et al. [bib:irmi] employ a discriminator to distinguish between the shared representation and the nuisance representation.

3 Learning disentangled and invariant representation

3.1 Model Architecture

Figure 3: Weakly supervised disentangled representation learning for the known nuisance factors $z_{nk}$

As illustrated in Figure 2, the architecture of the proposed model contains two components: a generation module and a prediction module. Similar to VAE, the generation module performs the encoding-decoding task. However, it encodes the input $x$ into latent factors $z = [z_p, z_n]$, where $z_p$ denotes the latent predictive factors that contain useful information for the prediction task, whereas $z_n$ denotes the latent nuisance factors and can be further divided into known nuisance factors $z_{nk}$ and unknown nuisance factors $z_{nu}$.

$z_{nk}$ are discovered and separated from $z_n$ via weakly supervised disentangled representation learning, where the joint distribution satisfies $p(z_{nk}) = \prod_i p(z_{nk_i})$. Since $z_n$ is the split containing the nuisance factors, once $z_{nk}$ is identified, the remaining factors of $z_n$ naturally constitute the unknown nuisance factors $z_{nu}$. Then, $z_p$ and $z_n$ are concatenated to generate the reconstruction $x_{rec}$, which is used to measure reconstruction quality. To enforce independence between $z_p$ and $z_n$, we add a regularizer based on another reconstruction task, where the average mean and variance of the predictive factors $z_p$ are used to form new latent factors $\bar{z}_p$, as discussed in Section 3.3. In the prediction module, we further incorporate a contrastive loss to cluster the predictive latent factors belonging to the same class.
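The following sketch illustrates how such a split could be realized in PyTorch; the function name and the fixed number of predictive dimensions dim_p are illustrative assumptions, not the exact implementation.

```python
import torch

def split_latent(mu, logvar, dim_p):
    """Reparameterize and split z into predictive and nuisance parts,
    z = [z_p, z_n], following the architecture in Figure 2. dim_p is the
    (hypothetical) number of dimensions reserved for z_p; the known nuisance
    factors z_nk are identified later *inside* z_n, so no fixed split of z_n
    is made here."""
    std = torch.exp(0.5 * logvar)
    z = mu + std * torch.randn_like(std)      # reparameterization trick
    z_p, z_n = z[:, :dim_p], z[:, dim_p:]     # predictive vs. nuisance split
    return z_p, z_n
```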

3.2 Learning independent known nuisance factors $z_{nk}$

As illustrated in Figure 2, the known nuisance factors $z_{nk}$ are discovered and separated from $z_n$, where $p(z_{nk}) = \prod_i p(z_{nk_i})$, since nuisance information is expected to be present only within $z_n$.

To fulfill the theoretical requirement of providing a supervision signal for disentangled representation learning, as proven in [locatello2019challenging], we use selected pairs of inputs $x^{(l)}$ and $x^{(m)}$ as supervision signals, where the two inputs share all but a few generative factors. As illustrated in Figure 3, during training, the network encodes a pair of inputs $x^{(l)}$ and $x^{(m)}$ into two latent representations $z^{(l)} = [z_p^{(l)}, z_n^{(l)}]$ and $z^{(m)} = [z_p^{(m)}, z_n^{(m)}]$, respectively, which are then decoded into the reconstructions $x_{rec}^{(l)}$ and $x_{rec}^{(m)}$. To encourage representation disentanglement, certain elements of $z_n^{(l)}$ and $z_n^{(m)}$ are detected and swapped to generate two new latent representations $\hat{z}^{(l)}$ and $\hat{z}^{(m)}$.

Figure 4: In early training stages, a small number of latent factors are swapped; the number of latent factors to be swapped increases gradually.

The two new latent representations are then decoded into new reconstructions $\hat{x}_{rec}^{(l)}$ and $\hat{x}_{rec}^{(m)}$. By comparing $\hat{x}_{rec}$ with $x_{rec}$, the known nuisance factors $z_{nk}$ are discovered and the elements of $z_{nk}$ are enforced to be mutually independent of each other.

Selecting image pairs for training and latent factor assumptions:

As noted in [bib:betaVAE], the true world simulator that uses generative factors to generate $x$ can be modeled as $p(x|v, w) = \mathrm{Sim}(v, w)$, where $v$ denotes the generative factors and $w$ denotes other nuisance factors. Inspired by this, we construct pairs of images by randomly selecting several generative factors to be the same and letting the values of the other generative factors be random. Each image $x$ has corresponding generative factors $v$, and a training pair is generated as follows: we first randomly select a sample $x^{(l)}$ whose generative factors are $v^{(l)} = [v_1, v_2, \ldots, v_n]$. We then randomly change the values of $k$ elements of $v^{(l)}$ to form new generative factors $v^{(m)}$ and choose another sample $x^{(m)}$ according to $v^{(m)}$. During training, neither the indices of the generative factors that differ between $v^{(l)}$ and $v^{(m)}$ nor the ground-truth values of the generative factors are available to the model. The model is weakly supervised since it is trained with only the knowledge of the number of factors $k$ that have been changed. Ideally, if the model learns a disentangled representation, it will encode the image pair $x^{(l)}$ and $x^{(m)}$ into corresponding representations $z^{(l)}$ and $z^{(m)}$ with the property shown in Equation 2. We denote the set of all indices that differ between $z^{(l)}$ and $z^{(m)}$ as $df_z$ and the set of all latent factor indices as $d_z$, such that $df_z \subseteq d_z$.

p(z_{n_j}^{(l)}|x^{(l)}) = p(z_{n_j}^{(m)}|x^{(m)}), \; j \notin df_z
p(z_{n_i}^{(l)}|x^{(l)}) \neq p(z_{n_i}^{(m)}|x^{(m)}), \; i \in df_z     (2)
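A possible sampling routine for such weakly supervised pairs is sketched below; dataset.factor_values and dataset.lookup are hypothetical helpers for a 3dShapes/MPI3D-style dataset indexed by factor combinations, and each factor is assumed to have more than one admissible value.

```python
import random

def sample_weak_pair(dataset, k):
    """Sample a weakly supervised training pair that differs in exactly k
    generative factors. dataset.factor_values[i] lists the admissible values
    of factor i; dataset.lookup(v) returns the image for factor vector v."""
    v_l = [random.choice(vals) for vals in dataset.factor_values]
    changed = random.sample(range(len(v_l)), k)   # indices never shown to the model
    v_m = list(v_l)
    for i in changed:
        # resample until the factor value actually differs
        while v_m[i] == v_l[i]:
            v_m[i] = random.choice(dataset.factor_values[i])
    # only the pair (x_l, x_m) and the count k are used for training
    return dataset.lookup(v_l), dataset.lookup(v_m), k
```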

Detecting and Swapping the distinct latent factors:

VAE adopts the reparameterization trick to make the posterior distribution $q_{\theta}(z|x)$ differentiable, where the posterior distribution of the latent factors is commonly assumed to be a factorized multivariate Gaussian: $p(z|x) = q_{\theta}(z|x)$ [bib:vae]. Under this assumption, we can directly measure the difference between corresponding dimensions of the two latent representations $z^{(l)}$ and $z^{(m)}$ by computing a divergence ($D_m$), which can be the KL divergence ($D_{KL}$). The process of detecting the distinct latent factors is shown in Equation 3, where a larger value of $D_{KL}$ implies a larger difference between the two corresponding latent factor distributions.

D_{KL}\big(q_{\phi}(z_i^{(l)}|x^{(l)})\,||\,q_{\phi}(z_i^{(m)}|x^{(m)})\big) = \frac{(\sigma_i^{(l)})^2 + (\mu_i^{(l)} - \mu_i^{(m)})^2}{2(\sigma_i^{(m)})^2} + \log\Big(\frac{\sigma_i^{(m)}}{\sigma_i^{(l)}}\Big) - \frac{1}{2}     (3)

Since the model only has knowledge of the number of different generative factors $k$, we swap all corresponding dimensions of $z_n^{(l)}$ and $z_n^{(m)}$ except the $k$ elements with the highest $D_m$ values. This swapping step creates two new latent representations $\hat{z}_n^{(l)}$ and $\hat{z}_n^{(m)}$, as shown in Equation 4.

\hat{z}_{n_i}^{(l)} = z_{n_i}^{(m)}, \; \hat{z}_{n_i}^{(m)} = z_{n_i}^{(l)}, \; i \notin df_z
\hat{z}_{n_j}^{(l)} = z_{n_j}^{(l)}, \; \hat{z}_{n_j}^{(m)} = z_{n_j}^{(m)}, \; j \in df_z     (4)
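A minimal sketch of the detect-and-swap step, written for a single pair of 1-D latent vectors, is shown below; the function names and per-pair (unbatched) shapes are illustrative assumptions.

```python
import torch

def gaussian_kl(mu_l, logvar_l, mu_m, logvar_m):
    """Dimension-wise KL divergence between two diagonal Gaussian posteriors
    (Equation 3); returns one value per latent dimension."""
    var_l, var_m = logvar_l.exp(), logvar_m.exp()
    return 0.5 * (logvar_m - logvar_l) + (var_l + (mu_l - mu_m) ** 2) / (2 * var_m) - 0.5

def swap_all_but_top_k(z_l, z_m, div, k):
    """Swap every nuisance dimension between the pair except the k dimensions
    with the largest divergence, which are presumed to carry the k changed
    generative factors (Equation 4)."""
    keep = torch.topk(div, k).indices            # presumed "different" dimensions
    z_l_hat, z_m_hat = z_m.clone(), z_l.clone()  # start fully swapped
    z_l_hat[keep] = z_l[keep]                    # restore the top-k dimensions
    z_m_hat[keep] = z_m[keep]
    return z_l_hat, z_m_hat
```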

Disentangled representation loss function:

After $\hat{z}_n^{(l)}$ and $\hat{z}_n^{(m)}$ are obtained, they are concatenated with $z_p^{(m)}$ and $z_p^{(l)}$, respectively, to generate two new latent representations $\hat{z}^{(l)} = [z_p^{(m)}, \hat{z}_n^{(l)}]$ and $\hat{z}^{(m)} = [z_p^{(l)}, \hat{z}_n^{(m)}]$. $\hat{z}^{(l)}$ and $\hat{z}^{(m)}$ are decoded into the new reconstructions $\hat{x}_{rec}^{(l)}$ and $\hat{x}_{rec}^{(m)}$. Since there are only $k$ different generative factors between the pair of images, ideally, after encoding, there should also be only $k$ pairs of differing distributions in the latent representation space. By swapping all other latent factors except these, the new representations $\hat{z}^{(l)}$ and $\hat{z}^{(m)}$ remain the same as the original representations $z^{(l)}$ and $z^{(m)}$. Accordingly, the new reconstructions $\hat{x}_{rec}^{(l)}$ and $\hat{x}_{rec}^{(m)}$ should be identical to the original reconstructions $x_{rec}^{(l)}$ and $x_{rec}^{(m)}$. Therefore, we design the disentangled representation loss in Equation 5, where $D$ can be any suitable distance function, e.g., mean squared error (MSE) or binary cross-entropy (BCE).

L = L_{VAE}(x_{rec}^{(l)}, z^{(l)}) + L_{VAE}(x_{rec}^{(m)}, z^{(m)}) + D(\hat{x}_{rec}^{(l)}, x_{rec}^{(l)}) + D(\hat{x}_{rec}^{(m)}, x_{rec}^{(m)})     (5)

Training Strategies for disentangled representation learning:

To further improve the performance of disentangled representation learning, we design two strategies: warmup by amount and warmup by difficulty. Recall that in the swapping step, the model needs to swap $|d_z| - k$ elements of the latent representations. At the beginning of training, exchanging too many latent factors easily leads to mistakes. Therefore, in the first strategy, we gradually increase the number of latent factors being swapped from $1$ to $|d_z| - k$. Furthermore, to smoothly increase the training difficulty, we set the number of different generative factors to $1$ at the beginning and increase it as training continues.
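A minimal sketch of such a curriculum is shown below; the linear ramp over the first half of training is an assumption, since the text only states that both quantities grow gradually.

```python
def warmup_schedule(epoch, total_epochs, dim_z, k_max):
    """Curriculum sketch: 'warmup by difficulty' grows the number of changed
    generative factors k from 1 to k_max, and 'warmup by amount' grows the
    number of swapped dimensions from 1 to |d_z| - k."""
    progress = min(1.0, epoch / max(1, total_epochs // 2))   # assumed linear ramp
    k = 1 + round(progress * (k_max - 1))                    # factors that differ in a pair
    n_swap = 1 + round(progress * (dim_z - k - 1))           # dimensions actually swapped
    return k, n_swap
```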

3.3 Learning invariant predictive factors $z_p$

After we obtain the disentangled representation $z_{nk}$, the predictive factors $z_p$ may still be entangled with $z_{nu}$. Therefore, we need additional constraints to achieve a fully invariant representation $z_p$.

Making $z_p$ independent of $z_n$:

As shown in [locatello2019challenging], supervision signals need to be introduced for disentangled representation learning. Similarly, the independence of $z_p$ and $z_n$ also requires a supervision signal, as we discuss in the Appendix. Fortunately, in supervised training, a batch of samples naturally contains such a signal. Similar to Equation 2, the distribution of the representation $z_p$ should be the same for samples of the same class, as shown in Equation 6, where $C(x^{(l)})$ denotes the class of sample $x^{(l)}$.

p(z_p^{(l)}|x^{(l)}) = p(z_p^{(m)}|x^{(m)}), \; C(x^{(l)}) = C(x^{(m)})
p(z_p^{(l)}|x^{(l)}) \neq p(z_p^{(m)}|x^{(m)}), \; C(x^{(l)}) \neq C(x^{(m)})     (6)

Similar to the method we use for disentangled representation learning, we generate a new latent representation $\bar{z}_p$ and its corresponding reconstruction $\bar{x}_{rec-p}$. Then, we enforce the disentanglement between $z_p$ and $z_n$ by comparing the new reconstruction $\bar{x}_{rec-p}$ with $x_{rec}$. In contrast to the swapping method described in Section 3.2, since a training batch often contains more than two samples from the same class, the swapping method is difficult to apply in this situation. Therefore, we generate the new latent representation $\bar{z}_p$ by computing the average mean $\bar{\mu}_p$ and average variance $\bar{V}_p$ of the latent representations from the same class, as shown in Equation 7.

\bar{z}_p = \mathcal{N}(\bar{\mu}_p, \bar{V}_p); \quad \bar{x}_{rec-p} = \mathrm{Decoder}([\bar{z}_p, z_n])
\bar{\mu}_p = \frac{1}{|C|}\sum_{i \in C} \mu_p^{(i)}; \quad \bar{V}_p = \frac{1}{|C|}\sum_{i \in C} V_p^{(i)}     (7)

We then generate the new reconstruction $\bar{x}_{rec-p}$ using the same decoder as in the other reconstruction tasks, and enforce the disentanglement of $z_p$ and $z_n$ by computing $D(x_{rec}, \bar{x}_{rec-p})$ and updating the model parameters according to its gradient.
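A minimal sketch of this class-averaging step (Equation 7) is given below; tensor shapes and the helper name are assumptions.

```python
import torch

def class_average_latent(mu_p, logvar_p, labels):
    """Replace each sample's predictive posterior with the average mean and
    variance of its class (Equation 7), then resample z_p_bar.
    Shapes: mu_p, logvar_p are (batch, dim_p); labels is (batch,)."""
    var_p = logvar_p.exp()
    mu_bar, var_bar = mu_p.clone(), var_p.clone()
    for c in labels.unique():
        idx = labels == c
        mu_bar[idx] = mu_p[idx].mean(dim=0)     # average mean over the class
        var_bar[idx] = var_p[idx].mean(dim=0)   # average variance over the class
    z_p_bar = mu_bar + var_bar.sqrt() * torch.randn_like(mu_bar)
    return z_p_bar   # concatenated with z_n and decoded into the new reconstruction
```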

Contrastive feature alignment:

To achieve an invariant representation, we need to make sure that the latent representations useful for prediction can also be clustered according to their corresponding classes. Even though the commonly used cross-entropy (CE) loss can accomplish a similar goal, the direct objective of the CE loss is logit-level alignment, shaping the representation distribution according to the logits, which does not guarantee a uniform distribution of features. Alternatively, we incorporate a contrastive method to ensure that representation/feature alignment can be accomplished effectively [Wang_2021_CVPR].

Similar to [bib:super_contras], we use a supervised contrastive loss to achieve feature alignment and cluster the representations $z_p$ according to their classes, as shown in Equation 8, where $C$ is the set of samples from the same class as the anchor $i$ (i.e., $y_p = y_i$) and $A(i)$ is the set of all samples other than the anchor.

\mathcal{L}_{sup} = \sum_{i \in I} \frac{-1}{|C|} \sum_{p \in C} \log \frac{\exp(z_i \cdot z_p / \tau)}{\sum_{a \in A(i)} \exp(z_i \cdot z_a / \tau)}     (8)
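A compact sketch of this supervised contrastive loss over a batch of predictive features is given below; the l2-normalization of the features and the temperature value are common contrastive-learning conventions assumed here, not settings stated in the text.

```python
import torch
import torch.nn.functional as F

def supcon_loss(z, labels, tau=0.1):
    """Supervised contrastive loss (Equation 8) over features z of shape
    (batch, dim) with integer class labels of shape (batch,)."""
    z = F.normalize(z, dim=1)
    sim = z @ z.t() / tau                                     # pairwise similarities
    self_mask = torch.eye(len(z), dtype=torch.bool, device=z.device)
    # denominator sums over A(i): every sample except the anchor itself
    denom = torch.logsumexp(sim.masked_fill(self_mask, float("-inf")), dim=1, keepdim=True)
    log_prob = sim - denom
    pos_mask = (labels[:, None] == labels[None, :]) & ~self_mask
    # average log-probability over positives, then average over anchors
    loss = -(log_prob * pos_mask).sum(1) / pos_mask.sum(1).clamp(min=1)
    return loss.mean()
```

Consistent with Section 4.1, this contrastive term would be optimized together with the rest of the total loss, while the CE classification loss is trained separately.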

The final loss function used to train the model, after adding the standard cross-entropy (CE) loss to train the classifier, is given by Equation 9.

L = L_{CE}(x, y) + L_{VAE} + \alpha L_{disentangle} + \beta L_{Sup} + \gamma L_{Z_p}     (9)
L_{disentangle} = D(\hat{x}_{rec}^{(l)}, x_{rec}^{(l)}) + D(\hat{x}_{rec}^{(m)}, x_{rec}^{(m)})
L_{Z_p} = D(\bar{x}_{rec-p}, x_{rec})
Table 1: Average and worst-case test accuracy on Colored-MNIST, 3dShapes, and MPI3D. Bold: best result

|                   | Colored-MNIST              | 3dShapes                   | MPI3D                      |
| Models            | Avg Acc      | Worst Acc   | Avg Acc      | Worst Acc   | Avg Acc      | Worst Acc   |
| Baseline          | 95.12 ± 2.42 | 66.17 ± 3.31 | 98.87 ± 0.52 | 96.89 ± 1.25 | 90.12 ± 3.13 | 87.89 ± 4.31 |
| VFAE [bib:vfae]   | 93.12 ± 3.07 | 65.54 ± 6.21 | 97.72 ± 0.81 | 93.34 ± 1.05 | 86.69 ± 3.12 | 82.43 ± 3.25 |
| CAI [bib:cai]     | 93.56 ± 2.76 | 63.17 ± 5.61 | 97.62 ± 0.53 | 94.32 ± 0.89 | 86.63 ± 2.14 | 82.16 ± 5.83 |
| CVIB [bib:cvib]   | 93.31 ± 3.09 | 70.12 ± 4.77 | 97.11 ± 0.59 | 94.46 ± 0.90 | 87.04 ± 3.02 | 85.61 ± 2.08 |
| UAI [bib:uai]     | 94.74 ± 2.19 | 74.25 ± 2.69 | 97.13 ± 1.02 | 95.21 ± 1.03 | 87.89 ± 4.23 | 83.01 ± 2.21 |
| NN+DIM [bib:irmi] | 94.48 ± 2.35 | 80.25 ± 3.44 | 97.03 ± 1.07 | 96.02 ± 0.46 | 88.81 ± 1.37 | 82.01 ± 3.34 |
| Our model         | 97.96 ± 1.21 | 90.43 ± 2.79 | 98.52 ± 0.51 | 97.63 ± 0.72 | 91.32 ± 2.38 | 89.17 ± 2.69 |

4 Experimental Evaluation

4.1 Benchmarks, Baselines and Metrics

The main objective of this work is to learn invariant representations and reduce overfitting to nuisance factors. As a secondary objective, we also want to ensure that the learned representations are at least no less robust to adversarial attacks. Therefore, all models are evaluated on both the invariant representation learning task and the adversarial robustness task. We use four (4) datasets with different underlying factors of variation to evaluate the model:

  • Colored-MNIST The Colored-MNIST dataset is an augmented version of MNIST [lecun-mnisthandwrittendigit-2010] with two known nuisance factors: digit color and background color [bib:irmi]. During training, the background color is chosen from three (3) colors and the digit color is chosen from six (6) other colors. For testing, we set the background color to three (3) new colors that differ from those in the training set (see the coloring sketch after this list).

  • Rotation-Colored-MNIST This dataset is a further augmented version of Colored-MNIST. The background color and digit color settings are the same as in Colored-MNIST. The dataset additionally contains rotated digits, with training angles $\Theta_{train} = \{0, \pm 22.5, \pm 45\}$. For the test data, the rotation angles are set to $\Theta_{test} = \{0, \pm 65, \pm 75\}$. The rotation angles are used as unknown nuisance factors.

  • 3dShapes [3dshapes18] contains 480,000 RGB $64 \times 64 \times 3$ images, and the whole dataset has six (6) generative factors. We choose object shape (four (4) classes) as the prediction task; only half of the object colors are used during training, and the samples with the remaining half of the object colors are used to evaluate the performance of the invariant representation.

  • MPI3D [gondal2019transfer] is a real-world dataset containing 1,036,800 RGB images, and the whole dataset has seven (7) generative factors. As with 3dShapes, we choose object shape (six (6) classes) as the prediction target, and half of the object colors are used for training.
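As a rough illustration of the Colored-MNIST construction referenced in the first bullet above, the sketch below colorizes a grayscale digit with a random digit color and background color; the specific RGB values are placeholders, not the paper's palette.

```python
import numpy as np

TRAIN_BG = [(255, 0, 0), (0, 255, 0), (0, 0, 255)]           # 3 training background colors (placeholders)
DIGIT_COLORS = [(255, 255, 0), (0, 255, 255), (255, 0, 255),
                (128, 0, 0), (0, 128, 0), (0, 0, 128)]        # 6 digit colors (placeholders)

def colorize(gray, rng):
    """Turn a (28, 28) grayscale MNIST digit into an RGB image with a random
    digit color and a random background color (both act as nuisance factors)."""
    digit = np.array(DIGIT_COLORS[rng.integers(len(DIGIT_COLORS))], dtype=np.float32)
    background = np.array(TRAIN_BG[rng.integers(len(TRAIN_BG))], dtype=np.float32)
    mask = (gray.astype(np.float32) / 255.0)[..., None]       # digit mask in [0, 1]
    return (mask * digit + (1.0 - mask) * background).astype(np.uint8)
```

For the test split, TRAIN_BG would be replaced by three held-out background colors, matching the evaluation protocol described in the first bullet.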

Table 2: Average and worst-case test accuracy on Rotation-Colored-MNIST at different rotation angles. Bold: best result

|                   | -75                    | -65                    | +65                    | +75                    |
| Models            | Avg Acc   | Worst Acc  | Avg Acc   | Worst Acc  | Avg Acc   | Worst Acc  | Avg Acc   | Worst Acc  |
| Baseline          | 77.0 ±1.3 | 62.3 ±1.9 | 89.7 ±1.2 | 77.5 ±2.2 | 85.8 ±1.2 | 65.8 ±3.0 | 68.3 ±2.2 | 49.9 ±4.6 |
| VFAE [bib:vfae]   | 72.2 ±2.4 | 58.9 ±2.3 | 85.8 ±1.7 | 74.4 ±2.5 | 84.1 ±2.1 | 64.6 ±3.7 | 71.7 ±1.3 | 48.0 ±3.8 |
| CAI [bib:cai]     | 74.9 ±0.9 | 59.3 ±3.9 | 86.5 ±1.9 | 77.3 ±2.0 | 84.2 ±1.7 | 67.8 ±1.9 | 64.7 ±4.2 | 42.9 ±3.7 |
| CVIB [bib:cvib]   | 76.1 ±0.8 | 59.2 ±3.0 | 88.6 ±0.9 | 79.1 ±1.2 | 85.6 ±0.7 | 68.8 ±2.9 | 72.2 ±1.2 | 53.4 ±2.6 |
| UAI [bib:uai]     | 76.0 ±1.7 | 61.1 ±5.6 | 88.8 ±0.7 | 80.0 ±0.9 | 85.4 ±1.6 | 68.2 ±2.3 | 70.2 ±0.9 | 51.1 ±2.3 |
| NN+DIM [bib:irmi] | 77.6 ±2.6 | 69.2 ±2.7 | 85.2 ±3.4 | 76.3 ±4.3 | 84.6 ±3.1 | 66.7 ±3.7 | 68.4 ±3.1 | 53.2 ±5.6 |
| Our model         | 81.0 ±2.1 | 75.3 ±2.5 | 90.8 ±1.6 | 85.7 ±2.4 | 87.3 ±2.5 | 82.3 ±2.1 | 73.2 ±2.3 | 63.3 ±2.9 |

Prediction accuracy is used to evaluate the performance of invariant representation learning. Furthermore, we report both the average test accuracy and the worst-case test accuracy, as suggested by [Sagawa*2020Distributionally]. We find that using Equation 9 directly does not guarantee good performance, which may be caused by the inconsistent behavior of the CE loss and the supervised contrastive loss. Thus, we train the classifier separately using the CE loss and use the remaining part of the total loss to train the rest of the model.

Meanwhile, the performance of representation disentanglement is also important for representation invariance, since it evaluates the invariance of the latent factors representing the known nuisance factors.

We adopt the following metrics to evaluate the performance of disentangled representations. All metrics range from 0 to 1, where 1 indicates that the latent factors are fully disentangled. (1) Mutual Information Gap (MIG) [chen2019isolating] measures, for each generative factor, the gap between the two highest mutual information values with the latent factors (a rough sketch is given below). (2) Separated Attribute Predictability (SAP) [bib:dipvae] measures the mean difference in prediction error between the two most predictive latent factors. (3) Interventional Robustness Score (IRS) [suter2019robustly] evaluates the reliance of a latent factor solely on a generative factor, regardless of other generative factors. (4) The FactorVAE (FVAE) score [bib:factorvae] uses a majority-vote classifier to predict the index of a fixed generative factor and takes the prediction accuracy as the final score. (5) DCI-Disentanglement (DCI) [bib:dci] computes the entropy of the distribution obtained by normalizing, across each dimension of the learned representation, its importance for predicting the value of a generative factor.
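For illustration, a rough sketch of the MIG computation is given below; the equal-width discretization of the latent codes and the normalization by the factor entropy are common conventions assumed here, not details specified in this section.

```python
import numpy as np
from sklearn.metrics import mutual_info_score

def mig(latents, factors, n_bins=20):
    """Rough MIG sketch: for every generative factor, take the gap between the
    two latent dimensions with the highest (discretized) mutual information,
    normalized by the factor's entropy. latents: (N, dim_z) continuous codes;
    factors: (N, num_factors) discrete ground-truth values."""
    # discretize each latent dimension into equal-width bins
    binned = np.stack([np.digitize(z, np.histogram(z, n_bins)[1][:-1]) for z in latents.T])
    gaps = []
    for f in factors.T:
        mi = np.array([mutual_info_score(f, zb) for zb in binned])
        ent = mutual_info_score(f, f)            # entropy of the generative factor
        top2 = np.sort(mi)[-2:]
        gaps.append((top2[1] - top2[0]) / max(ent, 1e-12))
    return float(np.mean(gaps))
```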

4.2 Comparison with Previous Work

Table 3: Disentanglement metrics on 3dShapes and MPI3D. Bold: best result