
Causal Diffusion Autoencoders: Toward Counterfactual Generation via Diffusion Probabilistic Models

Aneesh Komanduri (University of Arkansas; corresponding author, Email: [email protected]), Chen Zhao (Baylor University), Feng Chen (University of Texas at Dallas), Xintao Wu (University of Arkansas)
Abstract

Diffusion probabilistic models (DPMs) have become the state of the art in high-quality image generation. However, DPMs have an arbitrary noisy latent space with no interpretable or controllable semantics. Although there has been significant research effort to improve image sample quality, there is little work on representation-controlled generation using diffusion models. In particular, causal modeling and controllable counterfactual generation using DPMs remain underexplored. In this work, we propose CausalDiffAE, a diffusion-based causal representation learning framework that enables counterfactual generation according to a specified causal model. Our key idea is to use an encoder to extract high-level, semantically meaningful causal variables from high-dimensional data and to model stochastic variation using reverse diffusion. We propose a causal encoding mechanism that maps high-dimensional data to causally related latent factors and parameterize the causal mechanisms among the latent factors using neural networks. To enforce the disentanglement of causal variables, we formulate a variational objective and leverage auxiliary label information in a prior to regularize the latent space. We propose a DDIM-based counterfactual generation procedure subject to do-interventions. Finally, to address the limited-label supervision scenario, we study the application of CausalDiffAE when part of the training data is unlabeled, which additionally enables granular control over the strength of interventions when generating counterfactuals at inference time. We empirically show that CausalDiffAE learns a disentangled latent space and is capable of generating high-quality counterfactual images.


1 Introduction

Diffusion probabilistic models (DPMs) [31, 11, 20, 32, 33] are a class of likelihood-based generative models that have achieved remarkable success in generating high-resolution images, with many large-scale implementations such as DALLE-2 [26], Stable Diffusion [27], and Imagen [28]. There has thus been great interest in probing the capabilities of diffusion models. Two of the most promising formulations perturb the data distribution step-wise in discrete time [11] or continuous time [33]; a model is then trained to estimate the reverse process, which transforms noisy samples into samples from the underlying data distribution. Representation learning has been an integral component of generative models such as GANs [8] and VAEs [14] for extracting robust and interpretable features from complex data [30, 2, 25]. Recently, a thrust of research has focused on whether DPMs can be used to extract a semantically meaningful and decodable representation that increases the quality of, and control over, generated images [21, 24]. However, no prior work models causal relations among the semantic latent codes to learn causal representations and enable counterfactual generation at inference time in DPMs. Generating high-quality counterfactual images is critical in domains such as healthcare and medicine [17, 29]. The ability to generate accurate counterfactual data from a causal graph obtained via domain knowledge can significantly cut the cost of data collection. Furthermore, reasoning about hypothetical scenarios unseen in the training distribution can be quite insightful for gauging the interactions among causal variables in complex systems. Given a causal graph of a system, we study the capability of DPMs as causal representation learners and evaluate their ability to generate counterfactuals upon interventions on causal variables.

Intuitively, we can think of a DPM as an encoder-decoder framework. The encoder maps an input image $\mathbf{x}_0$ to a spatial latent variable $\mathbf{x}_T$ through a series of Gaussian noise perturbations. However, $\mathbf{x}_T$ is best interpreted as a noise representation that lacks high-level semantics [24]. Recently, Preechakul et al. [24] proposed a diffusion-based autoencoder (DiffAE) that extracts a high-level semantic representation alongside the stochastic low-level representation $\mathbf{x}_T$ for decodable representation learning. Learning such a semantic representation enables interpolation in the latent space for controllable generation and has been shown to improve image generation quality. Mittal et al. [19] built on this framework and introduced a diffusion-based representation learning (DRL) objective that instead learns time-conditioned representations throughout the diffusion process. However, both of these approaches learn arbitrary representations and do not focus on disentanglement, a key property of interpretable representations. Disentangled representations enable precise control of generative factors in isolation. When considering causal systems, disentanglement is important for performing isolated interventions.

In this paper, we focus on learning disentangled causal representations, where the high-level semantic factors are causally related. To the best of our knowledge, we are the first to explore representation-based counterfactual image generation using diffusion probabilistic models. We propose CausalDiffAE, a learning framework for causal representation learning and controllable counterfactual generation in DPMs. Our key idea is to learn a causal representation via a learnable stochastic encoder and to model the relations among latents via causal mechanisms parameterized by neural networks. We formulate a variational objective with a label alignment prior to enforce disentanglement of the learned causal factors. We then utilize a conditional denoising diffusion implicit model (DDIM) [32] for decoding and modeling the stochastic variations. Intuitively, the causal representation encodes compact information that is causally relevant for image decoding in reverse diffusion. Furthermore, modeling causal relations in the latent space enables the generation of counterfactuals upon interventions on learned causal variables. We propose a DDIM variant for counterfactual generation subject to $\textbf{do}(\cdot)$ interventions [23]. In an effort to improve the practicality and interpretability of the model, we propose an extension to CausalDiffAE that utilizes weaker supervision. In the scenario where labeled data is limited, we jointly train an unconditional and a representation-conditioned diffusion model on the unlabeled and labeled partitions, respectively. This approach significantly reduces the number of labeled samples required for training and enables granular control over the strength of interventions and the quality of generated counterfactuals.

2 Related Work

Recent work in causal generative modeling has focused on either learning causal representations or controllable counterfactual generation [16]. Yang et al. [35] proposed CausalVAE, a causal representation learning framework that models latent causal variables with a linear SCM. Kocaoglu et al. [15] proposed CausalGAN, an extension of the GAN that models causal variables for sampling from interventional distributions. Diffusion and score-based generative models [11, 33] have shown impressive results in class-conditional generation through either classifier-based [5] or classifier-free [10] paradigms. Recently, there has been interest in exploring the capacity of diffusion models as representation learners. For instance, Mittal et al. [19] and Preechakul et al. [24] considered diffusion-based representation learning objectives. Karimi Mamaghan et al. [12] explored representation learning from a score-based perspective given access to data in the form of counterfactual pairs; however, that work does not focus on counterfactual generation. Another related line of research is counterfactual explanations [1], which focuses on post-hoc methods that generate realistic counterfactuals, though not in the strictly causal sense. Our work focuses on diffusion-based representation learning and is most closely related to DiffAE [24] and DRL [19], which aim to learn semantically meaningful representations. The key distinction is that we learn causal representations to enable counterfactual generation. Our proposed framework extends CausalVAE to diffusion-based models and to a weaker supervision paradigm.

3 Background

3.1 Structural Causal Model

A structural causal model (SCM) is formally defined by a tuple $\mathcal{M}=\langle\mathcal{Z},\mathcal{U},F\rangle$, where $\mathcal{Z}$ is the domain of the set of $n$ endogenous causal variables $\mathbf{z}=\{z_1,\dots,z_n\}$, $\mathcal{U}$ is the domain of the set of $n$ exogenous noise variables $\mathbf{u}=\{u_1,\dots,u_n\}$ (learned as an intermediate latent variable in our framework), and $F=\{f_1,\dots,f_n\}$ is a collection of $n$ independent causal mechanisms of the form

$z_i = f_i(u_i, z_{\textbf{pa}_i})$ (1)

where for all $i$, $f_i:\mathcal{U}_i\times\prod_{j\in\textbf{pa}_i}\mathcal{Z}_j\to\mathcal{Z}_i$ is a causal mechanism that determines each causal variable as a function of its parents and noise, and $z_{\textbf{pa}_i}$ denotes the parents of causal variable $z_i$. The model is equipped with a probability measure $p_{\mathcal{U}}(\mathbf{u})=p_{\mathcal{U}_1}(u_1)p_{\mathcal{U}_2}(u_2)\cdots p_{\mathcal{U}_n}(u_n)$, which admits a product distribution. For the purposes of this work, we assume causal sufficiency (no hidden confounding), no SCM misspecification, and faithfulness.
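To make Eq. (1) concrete, the following minimal Python sketch evaluates a hypothetical two-variable SCM $z_1 \to z_2$ in topological order; the linear mechanisms and the coefficient 0.8 are illustrative placeholders, not part of any model in this paper.

```python
import numpy as np

# Hypothetical two-variable SCM z1 -> z2 (a sketch of Eq. (1)).
rng = np.random.default_rng(0)

def f1(u1):
    # z1 has no parents: z1 = f1(u1)
    return u1

def f2(u2, z1):
    # z2 = f2(u2, z_pa2) with pa(z2) = {z1}; 0.8 is an arbitrary weight
    return 0.8 * z1 + u2

# Exogenous noise with a product distribution p(u) = p(u1) p(u2)
u1, u2 = rng.normal(size=2)
z1 = f1(u1)        # evaluate mechanisms in topological order
z2 = f2(u2, z1)
```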

Figure 1: CausalDiffAE Framework. The left side details the training process of CausalDiffAE, which encodes to a causal representation $\mathbf{z}_{\text{causal}}$ and uses a conditional DDIM decoder conditioned on $\mathbf{z}_{\text{causal}}$ and $\mathbf{x}_T$ for image reconstruction. The right side shows the DDIM-based counterfactual generation procedure using a trained CausalDiffAE model.

3.2 Diffusion Probabilistic Models

Diffusion probabilistic models (DPMs) [11, 20] have shown impressive results in image generation tasks, even outperforming GANs in many cases [5]. The idea of the denoising diffusion probabilistic model (DDPM) [11] is to define a Markov chain of diffusion steps that slowly destroys the structure in the data distribution by adding noise (the forward process) and to learn a reverse diffusion process that restores the structure of the data. Methods such as the denoising diffusion implicit model (DDIM) [32] relax the Markov assumption to speed up sampling by carrying out a deterministic encoding of the noise.

Forward Diffusion. Given input data sampled from a distribution $\mathbf{x}_0 \sim q(\mathbf{x})$, the forward diffusion process adds small amounts of Gaussian noise to the sample over $T$ steps, producing noisy samples $\mathbf{x}_1,\dots,\mathbf{x}_T$. The distribution of the noisy sample at time step $t$ is defined as the conditional distribution:

$q(\mathbf{x}_t|\mathbf{x}_{t-1})=\mathcal{N}(\mathbf{x}_t;\sqrt{1-\beta_t}\,\mathbf{x}_{t-1},\beta_t\mathbf{I})$ (2)

where $\beta_t\in(0,1)$ is a variance parameter that controls the step size of the noise. As $t\to\infty$, the input sample $\mathbf{x}_0$ loses its distinguishable features; in the end, when $t=T$, $\mathbf{x}_T$ follows an isotropic Gaussian. From Eq. (2), we can then define a closed-form, tractable posterior over all time steps, factorized as follows:

$q(\mathbf{x}_{1:T}|\mathbf{x}_0)=\prod_{t=1}^{T}q(\mathbf{x}_t|\mathbf{x}_{t-1})$ (3)

Now, $\mathbf{x}_t$ can be sampled at any arbitrary time step $t$ using the reparameterization trick. Let $\alpha_t=\prod_{i=1}^{t}(1-\beta_i)$:

$q(\mathbf{x}_t|\mathbf{x}_0)=\mathcal{N}(\mathbf{x}_t;\sqrt{\alpha_t}\,\mathbf{x}_0,(1-\alpha_t)\mathbf{I})$ (4)
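As a minimal sketch of Eq. (4), assuming a linear $\beta_t$ schedule (an illustrative choice), $\mathbf{x}_t$ can be sampled in one shot via the reparameterization trick:

```python
import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T)       # variance schedule beta_t (assumed)
alphas = torch.cumprod(1.0 - betas, dim=0)  # alpha_t = prod_{i<=t} (1 - beta_i)

def q_sample(x0: torch.Tensor, t: torch.Tensor, eps: torch.Tensor = None):
    """Sample x_t ~ q(x_t | x_0) = N(sqrt(alpha_t) x_0, (1 - alpha_t) I)."""
    if eps is None:
        eps = torch.randn_like(x0)
    a_t = alphas[t].view(-1, *([1] * (x0.dim() - 1)))  # broadcast over batch
    return a_t.sqrt() * x0 + (1.0 - a_t).sqrt() * eps
```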

Reverse Diffusion. In the reverse process, to sample from $q(\mathbf{x}_{t-1}|\mathbf{x}_t)$, the goal is to recreate the true sample $\mathbf{x}_0$ from a Gaussian noise input $\mathbf{x}_T\sim\mathcal{N}(\mathbf{0},\mathbf{I})$. Unlike the forward diffusion, $q(\mathbf{x}_{t-1}|\mathbf{x}_t)$ is not analytically tractable, so we learn a model $p_\theta$ to approximate the conditional distributions as follows:

$p_\theta(\mathbf{x}_{0:T})=p(\mathbf{x}_T)\prod_{t=1}^{T}p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t), \quad p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)=\mathcal{N}(\mathbf{x}_{t-1};\mu_\theta(\mathbf{x}_t,t),\Sigma_\theta(\mathbf{x}_t,t))$ (5)

where $\mu_\theta$ and $\Sigma_\theta$ are learned via neural networks. It turns out that conditioning on the input $\mathbf{x}_0$ yields a tractable reverse conditional probability

$q(\mathbf{x}_{t-1}|\mathbf{x}_t,\mathbf{x}_0)=\mathcal{N}(\mathbf{x}_{t-1};\tilde{\mu}(\mathbf{x}_t,\mathbf{x}_0),\tilde{\beta}_t\mathbf{I})$ (6)

where $\tilde{\mu}$ and $\tilde{\beta}_t$ are the true posterior mean and variance. The learning objective is then formulated, via reparameterization, as a simplified version of the ELBO that minimizes the following mean squared error loss

$\mathcal{L}_{\text{simple}}=\sum_{t=1}^{T}\mathbb{E}_{\mathbf{x}_0,\epsilon_t}\big[\|\epsilon_t-\epsilon_\theta(\mathbf{x}_t,t)\|_2^2\big]$ (7)

where $\epsilon_t\sim\mathcal{N}(\mathbf{0},\mathbf{I})$ is the noise, which takes an analytical form via a reparameterization of $\mathbf{x}_t$ in terms of $\mathbf{x}_0$, as shown in [11].

DPMs produce latent variables $\mathbf{x}_{1:T}$ through the forward process. However, these variables are stochastic [24]. Song et al. [32] proposed the denoising diffusion implicit model (DDIM), which enables a deterministic process as follows:

$\mathbf{x}_{t-1}=\sqrt{\alpha_{t-1}}\Big(\frac{\mathbf{x}_t-\sqrt{1-\alpha_t}\,\epsilon_\theta(\mathbf{x}_t,t)}{\sqrt{\alpha_t}}\Big)+\sqrt{1-\alpha_{t-1}}\,\epsilon_\theta(\mathbf{x}_t,t)$ (8)

with the following deterministic decoding process

$q(\mathbf{x}_{t-1}|\mathbf{x}_t,\mathbf{x}_0)=\mathcal{N}\Big(\sqrt{\alpha_{t-1}}\,\mathbf{x}_0+\sqrt{1-\alpha_{t-1}}\,\frac{\mathbf{x}_t-\sqrt{\alpha_t}\,\mathbf{x}_0}{\sqrt{1-\alpha_t}},\,\mathbf{0}\Big)$ (9)

which preserves the DDPM marginal distribution $q(\mathbf{x}_t|\mathbf{x}_0)=\mathcal{N}(\sqrt{\alpha_t}\,\mathbf{x}_0,(1-\alpha_t)\mathbf{I})$. This formulation shares the same objective and solution as DDPM and differs only in the sampling procedure. Thus, we can deterministically obtain the noise map $\mathbf{x}_T$ corresponding to a given image $\mathbf{x}_0$.
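A sketch of one deterministic DDIM update (Eq. (8)); `eps_model` stands in for a trained noise predictor $\epsilon_\theta$ and is an assumed name:

```python
import torch

@torch.no_grad()
def ddim_step(x_t, t: int, eps_model, alphas):
    """Deterministic DDIM update x_t -> x_{t-1} (Eq. (8))."""
    eps = eps_model(x_t, t)                                 # predict noise
    a_t = alphas[t]
    a_prev = alphas[t - 1] if t > 0 else torch.tensor(1.0)  # boundary at t = 0
    x0_pred = (x_t - (1 - a_t).sqrt() * eps) / a_t.sqrt()   # predicted x_0
    return a_prev.sqrt() * x0_pred + (1 - a_prev).sqrt() * eps
```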

4 Causal Diffusion Autoencoders

Existing diffusion-based controllable generation methods neglect the scenario where generative factors are causally related and do not support counterfactual generation. To tackle this issue, we propose CausalDiffAE, a diffusion-based causal representation learning framework that enables counterfactual generation. First, we define a latent SCM that describes the semantic causal representation as a function of learned noise encodings. In diffusion autoencoders [24], the semantic latent representation $\mathbf{z}_{\text{sem}}$ captures high-level semantic information and $\mathbf{x}_T$ captures low-level stochastic information. In our formulation, we learn a causal representation $\mathbf{z}_{\text{causal}}$ that captures causally relevant information. Together, the two latent variables $(\mathbf{z}_{\text{causal}},\mathbf{x}_T)$ capture all the detailed causal semantics and stochasticity in the image. Second, given a trained CausalDiffAE model, we propose a counterfactual generation algorithm that utilizes $\textbf{do}(\cdot)$ interventions and the DDIM sampling algorithm. The overall framework of CausalDiffAE is shown in Figure 1.

4.1 Causal Encoding

Let $\mathbf{x}_0\in\mathbb{R}^d$ be the observed input image. We carry out the forward diffusion process until we have a set of $T$ perturbed samples $\{\mathbf{x}_1,\mathbf{x}_2,\dots,\mathbf{x}_T\}$, each at a different noise scale. Suppose there are $n$ abstract causal variables that describe the high-level semantics of the observed image. To learn a meaningful representation, we propose to encode the input image $\mathbf{x}_0$ into a low-dimensional noise encoding $\mathbf{u}\in\mathbb{R}^n$. We then map the noise encoding to latent causal factors $\mathbf{z}_{\text{causal}}\in\mathbb{R}^n$ corresponding to the abstract causal variables. In this formulation, each noise term $u_i$ is the exogenous noise term for causal variable $z_i$ in the SCM. Let $\mathbf{A}$ be the adjacency matrix encoding the causal graph among the underlying factors, where a nonzero $A_{ji}$ implies that $z_j$ is a cause of $z_i$. We then parameterize the mechanisms among causal variables as follows

$z_i = f_i(z_{\textbf{pa}_i}, u_i)$ (10)

where $f_i$ is the causal mechanism generating causal variable $z_i$ as a function of its parents and exogenous noise term, and $z_{\textbf{pa}_i}$ denotes the causal parents of factor $z_i$. In practice, we can implement $f_i$ as a post-nonlinear additive noise model such that

$\mathbf{z}=(I-\mathbf{A}^{T})^{-1}\mathbf{u}, \quad z_i = f_i(\mathbf{A}_i \odot \mathbf{z};\,\nu_i) + u_i$ (11)

where $\nu_i$ are the parameters of the neural network parameterizing each mechanism, $\odot$ is the elementwise product, and $\mathbf{z}_{\text{causal}}=\{z_1,\dots,z_n\}$. This module captures the causal relations between latent variables using neural structural causal models. For the purposes of this work, we assume the causal graph is known, since we focus on counterfactual generation. However, a more end-to-end framework may include a causal discovery component. See Appendix C for a more detailed discussion.
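The following sketch implements the causal encoding of Eq. (11) for a fixed binary adjacency matrix; the MLP architecture of each mechanism $f_i$ is an illustrative choice, not the paper's exact parameterization.

```python
import torch
import torch.nn as nn

class CausalLayer(nn.Module):
    """Map noise u to causal factors z (Eq. (11)); A[j, i] = 1 means z_j -> z_i."""

    def __init__(self, A: torch.Tensor, hidden: int = 32):
        super().__init__()
        self.register_buffer("A", A.float())   # n x n adjacency matrix (assumed fixed)
        n = A.shape[0]
        self.f = nn.ModuleList([
            nn.Sequential(nn.Linear(n, hidden), nn.ELU(), nn.Linear(hidden, 1))
            for _ in range(n)
        ])

    def forward(self, u: torch.Tensor) -> torch.Tensor:  # u: (batch, n)
        n = self.A.shape[0]
        eye = torch.eye(n, device=u.device)
        z = u @ torch.inverse(eye - self.A.T).T           # z = (I - A^T)^{-1} u
        # Post-nonlinear additive-noise step: z_i = f_i(A_i (elementwise) z) + u_i
        cols = [self.f[i](self.A[:, i] * z) for i in range(n)]
        return torch.cat(cols, dim=-1) + u
```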

4.2 Generative Model

Let $\mathbf{x}_0$ denote the high-dimensional input image and $\mathbf{y}\in\mathbb{R}^n$ denote an auxiliary weak supervision signal. The CausalDiffAE generative process can then be factorized as follows:

$p(\mathbf{x}_{0:T},\mathbf{u},\mathbf{z}_{\text{causal}}|\mathbf{y})=p_\theta(\mathbf{x}_{0:T}|\mathbf{u},\mathbf{z}_{\text{causal}},\mathbf{y})\,p(\mathbf{u},\mathbf{z}_{\text{causal}}|\mathbf{y})$ (12)

where $\theta$ are the parameters of the reverse process of the causal diffusion decoder (discussed in Section 4.3), $p(\mathbf{u},\mathbf{z}_{\text{causal}}|\mathbf{y})=p(\mathbf{u})\,p(\mathbf{z}_{\text{causal}}|\mathbf{y})$, $p(\mathbf{u})=\mathcal{N}(\mathbf{0},\mathbf{I})$, and $p(\mathbf{z}_{\text{causal}}|\mathbf{y})$ is the alignment prior defined in Eq. (19). The joint posterior distribution $p(\mathbf{x}_{1:T},\mathbf{u},\mathbf{z}_{\text{causal}}|\mathbf{x}_0,\mathbf{y})$ is intractable, so we approximate it with a variational distribution $q(\mathbf{x}_{1:T},\mathbf{u},\mathbf{z}_{\text{causal}}|\mathbf{x}_0,\mathbf{y})$, which can be factorized into the following conditional distributions

$q(\mathbf{x}_{1:T},\mathbf{u},\mathbf{z}_{\text{causal}}|\mathbf{x}_0,\mathbf{y})=q_\phi(\mathbf{z}_{\text{causal}},\mathbf{u}|\mathbf{x}_0,\mathbf{y})\,q(\mathbf{x}_{1:T}|\mathbf{u},\mathbf{z}_{\text{causal}},\mathbf{x}_0)$ (13)

where $\phi$ are the parameters of the variational encoder network parameterizing the joint distribution over the noise $\mathbf{u}$ and causal factors $\mathbf{z}_{\text{causal}}$. We can drop the dependence on $\mathbf{y}$ in the second conditional term of Eq. (13) since $\mathbf{x}_{1:T}$ is independent of the auxiliary label $\mathbf{y}$. We note that $q_\phi(\mathbf{z}_{\text{causal}},\mathbf{u}|\mathbf{x}_0,\mathbf{y})$ can be factorized as $q_\phi(\mathbf{z}_{\text{causal}}|\mathbf{x}_0,\mathbf{y})\,q_\phi(\mathbf{u}|\mathbf{x}_0)$ since $\mathbf{u}$ and $\mathbf{z}_{\text{causal}}$ have a one-to-one correspondence.
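A sketch of the stochastic encoder $q_\phi(\mathbf{u}|\mathbf{x}_0)$ as a diagonal Gaussian with reparameterized sampling; `backbone` and `feat_dim` are assumed placeholders for an image feature extractor:

```python
import torch
import torch.nn as nn

class NoiseEncoder(nn.Module):
    """Stochastic encoder q_phi(u | x_0) as a diagonal Gaussian (a sketch)."""

    def __init__(self, backbone: nn.Module, feat_dim: int, n: int):
        super().__init__()
        self.backbone = backbone        # assumed CNN mapping images to features
        self.mu = nn.Linear(feat_dim, n)
        self.logvar = nn.Linear(feat_dim, n)

    def forward(self, x0: torch.Tensor):
        h = self.backbone(x0)
        mu, logvar = self.mu(h), self.logvar(h)
        u = mu + (0.5 * logvar).exp() * torch.randn_like(mu)  # reparameterize
        return u, mu, logvar
```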

Algorithm 1 CausalDiffAE Training

Input: (image, label) pairs $(\mathbf{x}_0,\mathbf{y})$
Output: learned parameters $\{\theta,\phi\}$

1: repeat
2:     $\mathbf{x}_0 \sim q(\mathbf{x}_0)$
3:     $\mathbf{u} \sim q_\phi(\mathbf{u}|\mathbf{x}_0)$  ▷ Noise encoding
4:     $\mathbf{z}_{\text{causal}} = \{f_i(u_i, z_{\textbf{pa}_i};\nu_i)\}_{i=1}^{n}$  ▷ Causal encoding
5:     $t \sim \mathcal{U}(\{1,\dots,T\})$  ▷ Sample timestep
6:     $\epsilon_t \sim \mathcal{N}(\mathbf{0},\mathbf{I})$
7:     $\mathbf{x}_t = \sqrt{\alpha_t}\,\mathbf{x}_0 + \sqrt{1-\alpha_t}\,\epsilon_t$  ▷ Corrupt data to sampled timestep
8:     Take a gradient step on $\nabla_{\theta,\phi}\mathcal{L}_{\text{CausalDiffAE}}$
9: until convergence
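A sketch of one optimization step following Algorithm 1, assuming the `NoiseEncoder` and `CausalLayer` sketches above and a conditional noise predictor `eps_model`; `kl_z` is a hypothetical helper for the $\mathcal{D}_{KL}(q_\phi(\mathbf{z}_{\text{causal}}|\mathbf{x}_0,\mathbf{y})\,\|\,p(\mathbf{z}_{\text{causal}}|\mathbf{y}))$ term of Eq. (18), while the KL for $\mathbf{u}$ against $\mathcal{N}(\mathbf{0},\mathbf{I})$ is written in closed form.

```python
import torch

def causal_diffae_step(x0, y, encoder, causal_layer, eps_model,
                       alphas, optimizer, gamma=1.0):
    """One CausalDiffAE training step (a sketch of Algorithm 1)."""
    u, mu, logvar = encoder(x0)                        # line 3: noise encoding
    z_causal = causal_layer(u)                         # line 4: causal encoding
    t = torch.randint(0, len(alphas), (x0.shape[0],), device=x0.device)  # line 5
    eps = torch.randn_like(x0)                         # line 6
    a_t = alphas[t].view(-1, 1, 1, 1)
    x_t = a_t.sqrt() * x0 + (1 - a_t).sqrt() * eps     # line 7: corrupt data
    l_simple = ((eps_model(x_t, t, z_causal) - eps) ** 2).mean()
    kl_u = -0.5 * (1 + logvar - mu ** 2 - logvar.exp()).sum(-1).mean()
    loss = l_simple + gamma * (kl_z(z_causal, y) + kl_u)  # Eq. (18); kl_z assumed
    optimizer.zero_grad()                              # line 8: gradient step
    loss.backward()
    optimizer.step()
    return loss.item()
```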

4.3 Causal Diffusion Decoder

We use a conditional DDIM decoder that takes the pair of latent variables $(\mathbf{z}_{\text{causal}},\mathbf{x}_T)$ as input to generate the output image. We approximate the inference distribution $q(\mathbf{x}_{t-1}|\mathbf{x}_t,\mathbf{x}_0)$ by parameterizing the probabilistic decoder with a conditional DDIM $p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t,\mathbf{z}_{\text{causal}})$. With DDIM, the generative process becomes completely deterministic except for $t=1$. Similar to [24], we define the joint distribution of the reverse generative process as follows:

$p_\theta(\mathbf{x}_{0:T}|\mathbf{z}_{\text{causal}})=p(\mathbf{x}_T)\prod_{t=1}^{T}p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t,\mathbf{z}_{\text{causal}})$ (14)
$p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t,\mathbf{z}_{\text{causal}})=\begin{cases}\mathcal{N}(\mathbf{f}_\theta(\mathbf{x}_1,1,\mathbf{z}_{\text{causal}}),\mathbf{0})&\text{if }t=1\\q(\mathbf{x}_{t-1}|\mathbf{x}_t,\mathbf{f}_\theta(\mathbf{x}_t,t,\mathbf{z}_{\text{causal}}))&\text{otherwise}\end{cases}$ (15)

where $\mathbf{f}_\theta$ is parameterized by a noise prediction network $\epsilon_\theta$ (i.e., a UNet [5]) as follows:

$\mathbf{f}_\theta(\mathbf{x}_t,t,\mathbf{z}_{\text{causal}})=\frac{1}{\sqrt{\alpha_t}}\big(\mathbf{x}_t-\sqrt{1-\alpha_t}\,\epsilon_\theta(\mathbf{x}_t,t,\mathbf{z}_{\text{causal}})\big)$ (16)

Note that in Eq. (14), $\mathbf{u}$ is omitted since $\mathbf{z}_{\text{causal}}$ already captures all the information about the noise. By leveraging the reparameterization trick, we can optimize the following mean squared error between noise terms

$\mathcal{L}_{\text{simple}}=\sum_{t=1}^{T}\mathbb{E}_{\mathbf{x}_0,\epsilon_t}\big[\|\epsilon_\theta(\mathbf{x}_t,t,\mathbf{z}_{\text{causal}})-\epsilon_t\|_2^2\big]$ (17)

where $\epsilon_t\sim\mathcal{N}(\mathbf{0},\mathbf{I})$ and $\mathbf{x}_t=\sqrt{\alpha_t}\,\mathbf{x}_0+\sqrt{1-\alpha_t}\,\epsilon_t$.

4.4 Learning Objective

To ensure the causal representation is disentangled, we incorporate label information $\mathbf{y}\in\mathbb{R}^n$ as a prior in the variational objective to aid in learning semantic factors and to obtain identifiability guarantees [13]. We define the following joint loss objective:

$\mathcal{L}_{\text{CausalDiffAE}}=\mathcal{L}_{\text{simple}}+\gamma\big\{\mathcal{D}_{KL}(q_\phi(\mathbf{z}_{\text{causal}}|\mathbf{x}_0,\mathbf{y})\,\|\,p(\mathbf{z}_{\text{causal}}|\mathbf{y}))+\mathcal{D}_{KL}(q_\phi(\mathbf{u}|\mathbf{x}_0)\,\|\,\mathcal{N}(\mathbf{0},\mathbf{I}))\big\}$ (18)

where $\gamma$ is a regularization hyperparameter similar to the bottleneck parameter in $\beta$-VAEs [9], and the alignment prior over latent variables is defined as the following exponential family distribution

$p(\mathbf{z}_{\text{causal}}|\mathbf{y})=\prod_{i=1}^{n}p(z_i|y_i)=\prod_{i=1}^{n}\mathcal{N}(z_i;\mu_\nu(y_i),\sigma^2_\nu(y_i)\mathbf{I})$ (19)

where $\mu_\nu$ and $\sigma^2_\nu$ are functions that estimate the mean and variance of the Gaussian, respectively. Intuitively, this prior ensures that the learned factors are one-to-one mapped to indicators of the underlying ground-truth factors. DiffAE requires training a latent DDIM in the latent space of the pre-trained autoencoder to enable sampling of the latent semantic representation. In contrast, CausalDiffAE is formulated as a variational objective with a stochastic encoder, so we can sample the representation directly from the defined prior without training a separate diffusion model in the latent space. The training procedure for CausalDiffAE is outlined in Algorithm 1. See Appendix A for a derivation of the ELBO. For a detailed discussion of the connection between our diffusion objective and score-based generative models [33], see Appendix B.
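For the alignment term in Eq. (18), with the Gaussian prior of Eq. (19), the KL divergence has the standard closed form between diagonal Gaussians; a sketch, where the prior moments are assumed to come from networks $\mu_\nu(\mathbf{y})$ and $\sigma^2_\nu(\mathbf{y})$:

```python
import torch

def kl_alignment(mu_q, logvar_q, mu_p, logvar_p):
    """KL( N(mu_q, var_q) || N(mu_p, var_p) ) for diagonal Gaussians,
    i.e., D_KL(q_phi(z|x0, y) || p(z|y)) with the prior of Eq. (19)."""
    var_q, var_p = logvar_q.exp(), logvar_p.exp()
    kl = 0.5 * (logvar_p - logvar_q + (var_q + (mu_q - mu_p) ** 2) / var_p - 1.0)
    return kl.sum(dim=-1).mean()
```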

Algorithm 2 CausalDiffAE Counterfactual Generation

Input: factual sample $\mathbf{x}_0$, intervention target set $\mathcal{I}$ with intervention values $c$, noise predictor $\epsilon_\theta$, encoder $\phi$
Output: counterfactual sample $\mathbf{x}_0^{CF}$

1: $\mathbf{u} \sim q_\phi(\mathbf{u}|\mathbf{x}_0)$  ▷ Noise encoding (abduction)
2: for $i=1$ to $n$ do  ▷ in topological order
3:     if $i \in \mathcal{I}$ then
4:         $z_i = c_i$
5:     else
6:         $z_i = f_i(u_i, z_{\textbf{pa}_i})$
7:     end if
8: end for
9: $\bar{\mathbf{z}}_{\text{causal}} = \{z_1,\dots,z_n\}$  ▷ Intervened representation
10: $\mathbf{x}_T \sim \mathcal{N}(\sqrt{\alpha_T}\,\mathbf{x}_0,(1-\alpha_T)\mathbf{I})$
11: $\mathbf{x}_T^{CF} = \mathbf{x}_T$
12: for $t=T,\dots,1$ do  ▷ DDIM sampling
13:     $\mathbf{x}_{t-1}^{CF} = \sqrt{\alpha_{t-1}}\Big(\frac{\mathbf{x}_t^{CF}-\sqrt{1-\alpha_t}\,\epsilon_\theta(\mathbf{x}_t^{CF},t,\bar{\mathbf{z}}_{\text{causal}})}{\sqrt{\alpha_t}}\Big) + \sqrt{1-\alpha_{t-1}}\,\epsilon_\theta(\mathbf{x}_t^{CF},t,\bar{\mathbf{z}}_{\text{causal}})$
14: end for
15: return $\mathbf{x}_0^{CF}$

4.5 Counterfactual Generation

A fundamental property of causal models is the ability to perform interventions and observe changes to a system. In generative models, this enables the sampling of counterfactual data. Given a pre-trained CausalDiffAE, we can controllably manipulate any factor of variation, propagate the causal effects to its descendants, and perform reverse diffusion to sample from the counterfactual distribution. Algorithm 2 shows the process of generating counterfactuals from a trained CausalDiffAE, where $\mathbf{x}_0$ denotes the factual observation and $\mathbf{x}_0^{CF}$ the generated counterfactual sample. To generate counterfactual instances, we first encode the high-dimensional observation $\mathbf{x}_0$ into a noise encoding $\mathbf{u}$ (abduction) and transform it into causal latent variables $\mathbf{z}_{\text{causal}}$. Then, we intervene on a desired variable and propagate the causal effects via the neural mechanisms to yield the intervened representation $\bar{\mathbf{z}}_{\text{causal}}$. We utilize the DDIM sampling algorithm to ensure the stochastic noise $\mathbf{x}_T$ is a deterministic encoding, which enables semantic manipulations. Finally, we decode using DDIM conditioned on $(\bar{\mathbf{z}}_{\text{causal}},\mathbf{x}_T)$ to obtain a counterfactual $\mathbf{x}_0^{CF}$. In lines 12-13, we use the non-Markovian deterministic DDIM generative process to generate counterfactual instances as follows:

$\mathbf{x}_{t-1}^{CF}=\sqrt{\alpha_{t-1}}\Big(\frac{\mathbf{x}_t^{CF}-\sqrt{1-\alpha_t}\,\epsilon_\theta(\mathbf{x}_t^{CF},t,\bar{\mathbf{z}}_{\text{causal}})}{\sqrt{\alpha_t}}\Big)+\sqrt{1-\alpha_{t-1}}\,\epsilon_\theta(\mathbf{x}_t^{CF},t,\bar{\mathbf{z}}_{\text{causal}})$ (20)
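A sketch of Algorithm 2, reusing the encoder and mechanism sketches above; `topological_order`, `mechanism`, and `ddim_step_cond` (a $\mathbf{z}$-conditioned variant of the earlier `ddim_step`) are assumed names for illustration only.

```python
import torch

@torch.no_grad()
def generate_counterfactual(x0, targets, encoder, causal_layer,
                            eps_model, alphas, T):
    """Generate x_0^CF from a factual x_0 (a sketch of Algorithm 2).
    `targets` maps intervened variable indices i to values c_i."""
    u, _, _ = encoder(x0)                               # line 1: abduction
    z_cf = causal_layer(u).clone()
    for i in causal_layer.topological_order:            # assumed attribute
        if i in targets:
            z_cf[:, i] = targets[i]                     # line 4: do(z_i = c_i)
        else:
            z_cf[:, i] = causal_layer.mechanism(i, u, z_cf)  # assumed helper
    a_T = alphas[T - 1]
    x_t = a_T.sqrt() * x0 + (1 - a_T).sqrt() * torch.randn_like(x0)  # line 10
    for t in reversed(range(T)):                        # lines 12-13: DDIM decode
        x_t = ddim_step_cond(x_t, t, z_cf, eps_model, alphas)
    return x_t                                          # line 15: x_0^CF
```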

Conditioning vs. Intervening. When we study causal generative models, we utilize the intervention operation, which is a fundamentally different operation than conditioning. When we condition, we narrow our scope to a specific subgroup of the data based on the conditioning variable. Interventions are population-level operations that fix a variable’s value (rendering it independent of its parents) to determine causal effects downstream. We emphasize that, under this intervention operation, causal models are robust to distribution shifts and can generate data outside the support of the training distribution.

4.6 Weak Supervision

To reduce reliance on labeled data, and inspired by classifier-free guidance [10], we train CausalDiffAE under a weak-supervision paradigm with guidance at the representation level [5].

Training. In the limited labeled-data regime, we train two models: an unconditional denoising diffusion model $p_\theta(\mathbf{x})$ parameterized by the score estimator $\epsilon_\theta(\mathbf{x}_t,t)$ and a representation-conditioned model $p_\theta(\mathbf{x}|\mathbf{z}_{\text{causal}})$ parameterized through $\epsilon_\theta(\mathbf{x}_t,t,\mathbf{z}_{\text{causal}})$. We use a single neural network to parameterize both models; the unconditional model uses only the unlabeled data for predicting the score (i.e., $\epsilon_\theta(\mathbf{x}_t,t)$).

Generation. The counterfactual generation procedure in lines 12-13 of Algorithm 2 can be modified to generate counterfactuals with a guidance strength $\omega$, which in our case can be interpreted as controlling the strength of the intervention on the causal variable when generating the counterfactual. The modified score estimate used during generation is the following linear combination of the conditional and unconditional score estimates

$\bar{\epsilon}_\theta(\mathbf{x}_t,t,\bar{\mathbf{z}}_{\text{causal}})=\underbrace{\omega\,\epsilon_\theta(\mathbf{x}_t,t,\bar{\mathbf{z}}_{\text{causal}})}_{\text{causal conditional model}}+\underbrace{(1-\omega)\,\epsilon_\theta(\mathbf{x}_t,t)}_{\text{unconditional model}}$ (21)

where $\bar{\mathbf{z}}_{\text{causal}}$ is the set of latent causal factors after an intervention. The original utility of the classifier-free paradigm was to trade sample diversity for higher-quality images without requiring classifier gradients, so $\omega$ controls the trade-off between quality and diversity. In our case, we care about generating high-quality counterfactual data. Intuitively, a higher $\omega$ implies a stronger effect of the intervention on the generated counterfactual, since the conditional model $\epsilon_\theta(\mathbf{x}_t,t,\mathbf{z}_{\text{causal}})$ is sensitive to interventions. As $\omega$ decreases, the unconditional model dilutes the effect of the intervention-sensitive model. In this sense, the sampling mechanism can be used to evaluate the causal strength of interventions. We find that the weak supervision paradigm enables (1) more efficient training with a weaker supervision signal, and (2) fine-grained control over generated counterfactuals.
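A sketch of the guided estimate in Eq. (21), assuming the shared network accepts `z=None` for its unconditional branch:

```python
def guided_eps(x_t, t, z_cf, eps_model, omega: float):
    """Blend conditional and unconditional noise estimates (Eq. (21));
    omega scales the effective strength of the intervention."""
    eps_cond = eps_model(x_t, t, z_cf)      # causal conditional model
    eps_uncond = eps_model(x_t, t, None)    # unconditional model
    return omega * eps_cond + (1.0 - omega) * eps_uncond
```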

5 Experiments

5.1 Empirical Setting

Datasets. We experiment on three datasets. We use the MorphoMNIST dataset [4], created by imposing a 2-variable SCM to generate morphological transformations of the original MNIST digits, where thickness is a cause of the digit's intensity [22], as shown in Figure 2(a). The Pendulum dataset [35] consists of images of a causal system comprising a light source, a pendulum, and a shadow. The light source and the pendulum angle determine the length and position of the shadow, as shown in the causal graph in Figure 2(b). We also use CausalCircuit [3], a complex 3D robotics dataset in which a robot arm moves around to turn on red, green, or blue lights. The causal graph of this system is shown in Figure 3.

Figure 2: Counterfactual trajectories generated by CausalDiffAE and baseline models for (a) MorphoMNIST (Orig: $y_1=2.399$, $y_2=162.2739$) and (b) Pendulum (Orig: $y_1=16$, $y_2=113$, $y_3=3$, $y_4=12$). We observe that CausalDiffAE generates much more accurate counterfactuals upon interventions on causal factors compared to the baselines.

Baselines. CausalVAE [35] is a VAE-based causal representation learning framework that models causal variables using a linear SCM and enables counterfactual generation at inference time through interventions on causal variables. The class-conditional diffusion model (CCDM) [20] is a conditional diffusion model that utilizes class labels as the conditioning signal in reverse diffusion; it can generate new samples determined by a discrete or continuous set of labels $\mathbf{y}$. DiffAE [24] is a diffusion model that aims to learn manipulable and semantically meaningful latent codes. However, it learns an arbitrary representation in an unsupervised fashion and does not disentangle the latent space; manipulations are performed via linear interpolation with a post-hoc classifier, so the learned representation is not ideal for performing causal interventions. For a fair comparison in counterfactual generation, we modify its objective to disentangle the latent space by incorporating label information in a prior that regularizes the posterior; we call this extension DisDiffAE. We use DisDiffAE as a baseline for evaluating counterfactual generation and DiffAE for evaluating disentanglement.

Metrics. We primarily use two quantitative metrics to evaluate the performance of our approach. To evaluate the disentanglement of the learned representations, we use the DCI disentanglement metric [7]. A high DCI score also suggests effective controllable generation; in the context of a causal representation, it means we can intervene on latent codes in an isolated fashion without entanglements (i.e., two factors encoded in the same latent code). To quantitatively evaluate generated counterfactuals, we adopt the effectiveness metric from Melistas et al. [18], which evaluates how successful the performed intervention was at generating the counterfactual. We train anti-causal predictors, implemented as convolutional regressors, on the training dataset for each continuous causal variable $z_i$. We then report the mean absolute error (MAE) between the values predicted from the generated counterfactual and the true values of the counterfactual. This metric captures how controllable the factors are and the accuracy of the generated counterfactuals.
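As a sketch of how the effectiveness metric could be computed, assuming pre-trained anti-causal regressors (one per causal factor; all names illustrative):

```python
import torch

def effectiveness_mae(x_cf, z_true, predictors):
    """MAE between factor values predicted from the generated counterfactual
    and the true counterfactual values, per causal factor (a sketch)."""
    maes = {}
    for name, predictor in predictors.items():
        with torch.no_grad():
            pred = predictor(x_cf)          # anti-causal prediction from image
        maes[name] = (pred - z_true[name]).abs().mean().item()
    return maes
```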

For details about the datasets, implementation, metrics, and computational requirements, see Appendix D. Our code is available at https://github.com/Akomand/CausalDiffAE.

Table 1: Disentanglement (DCI)

Dataset         Model          DCI ↑
MorphoMNIST     CausalVAE      0.784 ± 0.01
                DiffAE         0.358 ± 0.01
                CausalDiffAE   0.993 ± 0.01
Pendulum        CausalVAE      0.885 ± 0.01
                DiffAE         0.353 ± 0.01
                CausalDiffAE   0.999 ± 0.01
CausalCircuit   CausalVAE      0.886 ± 0.01
                DiffAE         0.353 ± 0.01
                CausalDiffAE   0.999 ± 0.01

5.2 Disentanglement of Latent Space

We compare the disentanglement of CausalDiffAE with other baseline models, as shown in Table 1. We observe that diffusion-based representation learning objectives coupled with a suitable prior can better disentangle latent variables compared to VAE-based models. We do not include CCDM as a baseline here since it does not produce a representation to be evaluated. Compared to CausalVAE, the diffusion-based decoder in CausalDiffAE disentangles the semantic factors of variation to a much greater degree. Thus, we can perform interventions on causal variables in isolation and observe their downstream effects. We also note that DiffAE [24] does not learn a disentangled latent space since the semantic representation learned is arbitrary. To perform controllable manipulations with DiffAE, a post-hoc classifier must be trained to guide the sampling process. CausalDiffAE offers more precise control over learned factors through the disentanglement objective without the need to train additional classifiers.

5.3 Controllable Counterfactual Generation

Qualitative Evaluation. We show that CausalDiffAE produces much more realistic counterfactual samples than the acausal baselines and its VAE counterpart, CausalVAE. We attribute this to the diffusion process, which better captures causally relevant information along with low-level stochastic variation.

Figure 2(a) shows the counterfactual generation results for the MorphoMNIST dataset. CausalVAE can generate counterfactual images after intervening on either thickness or intensity, but the accuracy and quality of its counterfactuals are far lower than CausalDiffAE's. For instance, intervening to lower the thickness does not lower the intensity, and intervening to lower the intensity appears to change the thickness of the digit. CCDM fails to produce samples consistent with the underlying causal model; for example, intervening on the intensity produces a sample whose thickness increases, which is expected from a conditioning perspective since high-intensity digits tend to be thicker in the training distribution. For DisDiffAE, increasing the thickness does not influence the intensity.

Figure 3: CausalCircuit results (Orig: $y_1=0.02$, $y_2=0.03$, $y_3=0.04$, $y_4=0.14$)

Figure 2(b) shows counterfactual generation results for the Pendulum causal system. Upon interventions, images generated by CCDM are not consistent with the causal model; for example, intervening on the light position changes both the light position and the pendulum angle. DisDiffAE produces images in which we can control one factor at a time, but it does not reflect causal effects; for example, changing the angle of the pendulum does not accurately change the shadow length and position. CausalDiffAE generates higher-quality counterfactuals that are consistent with the causal model. Specifically, intervening on the pendulum angle or light position changes the shadow length and position accurately, while interventions on child variables leave the parents unchanged.

Figure 3 shows counterfactuals from the CausalCircuit dataset. CausalVAE generates inaccurate counterfactuals in many scenarios (e.g., intervening on the blue light's intensity changes the intensity of all other lights). For CCDM, moving the robot arm over the green light fails to turn the light on; furthermore, manipulating the intensity of the other lights affects the position of the robot arm. DisDiffAE enables control over the generative factors but does not account for causal effects; for example, moving the robot arm over the green light does not turn it on. Counterfactuals generated by CausalDiffAE are consistent with the causal system: moving the robot arm over the green button turns the light on and, as a result, also turns on the red light, a downstream child variable. Intervening on the blue or green light slightly increases the intensity of the red light, and intervening on the red light leaves all parent variables unchanged. For additional counterfactual generation results, see Appendix D.5.

Table 2: Effectiveness on MorphoMNIST test set (MAE)

Factor          Model          do(t)           do(i)
Thickness (t)   CausalVAE      3.763 ± 0.01    4.645 ± 0.01
                DisDiffAE      0.377 ± 0.02    0.326 ± 0.02
                CausalDiffAE   0.392 ± 0.02    0.309 ± 0.02
Intensity (i)   CausalVAE      13.233 ± 0.01   15.087 ± 0.01
                DisDiffAE      0.794 ± 0.02    0.262 ± 0.02
                CausalDiffAE   0.503 ± 0.01    0.256 ± 0.01
Table 3: Effectiveness on Pendulum test set (MAE)

Factor           Model          do(a)    do(lp)   do(sl)   do(sp)
Angle (a)        CausalVAE      24.860   23.030   20.470   11.580
                 DisDiffAE      0.668    0.648    0.647    0.647
                 CausalDiffAE   0.297    0.132    0.031    0.034
LightPos (lp)    CausalVAE      34.200   26.010   35.490   47.060
                 DisDiffAE      0.656    0.654    0.630    0.651
                 CausalDiffAE   0.045    0.434    0.035    0.064
ShadowLen (sl)   CausalVAE      1.946    1.430    2.020    1.720
                 DisDiffAE      0.550    0.527    0.560    0.516
                 CausalDiffAE   0.136    0.322    0.492    0.082
ShadowPos (sp)   CausalVAE      52.520   72.500   57.030   32.780
                 DisDiffAE      0.474    0.475    0.479    0.534
                 CausalDiffAE   0.146    0.303    0.064    0.471

* Standard error is roughly in the range ±0.01 to ±0.02 for all averages.

Quantitative Evaluation. Using the effectiveness metric, we quantitatively show that CausalDiffAE generates counterfactuals that are both accurate and realistic. We perform random interventions drawn from a uniform distribution over the test dataset for each causal variable. CausalDiffAE almost always outperforms the other baselines on the effectiveness metric, as shown in Tables 2 and 3, for all causal factors. Specifically, for the MorphoMNIST dataset, interventions on thickness produce counterfactuals that accurately reflect both the thickness and intensity values. When we intervene on thickness, the intensity MAE is lower for CausalDiffAE than for the other baselines, indicating that the generated counterfactual has an intensity value consistent with the causal effect of thickness on intensity. When we intervene on intensity, the thickness MAE is lower for CausalDiffAE than for the baselines, suggesting that the generated counterfactual retains its original thickness value under the intervention. For the Pendulum dataset, we observe a similar phenomenon: interventions on causal factors, along with their downstream effects, are accurately captured in the generated counterfactuals. We do not evaluate effectiveness for the CausalCircuit dataset since we do not have access to the generative process used to obtain the factors. We compute the average effectiveness over 5 runs with different random seeds. Our results strongly imply that the generated counterfactuals closely match the true counterfactuals.

5.4 Case Study: Weak Supervision Results

Unlike VAE-based approaches, the weak supervision paradigm for diffusion models reduces the need for full label supervision. We study the weak supervision scenario on the MorphoMNIST dataset. We jointly train a representation-conditioned and an unconditional model, where the conditioned split is far smaller than the unconditional split. We have two main motivations: (1) it greatly reduces the need for fully labeled datasets, and (2) it enables granular control over generated counterfactuals. We denote the proportion of unlabeled data by $p_{\text{unlabeled}}$. The CausalDiffAE model trained with $p_{\text{unlabeled}}=0.8$ on MorphoMNIST yields a DCI score of 0.9964, which suggests that even under strictly limited label supervision, CausalDiffAE learns disentangled representations. We also empirically show that varying the $\omega$ parameter controls the strength of the intervention on the generated counterfactual. Figure 4 shows MNIST digits generated using the joint estimated score from the reduced-supervision version of CausalDiffAE. Interventions have virtually no effect when sampling with $\omega=0.2$; for $\omega=0.5$, we see a stronger effect of the intervention on the thickness and intensity of the digit; and for the fully supervised score, $\omega=1.0$, the intervention acts the strongest. Thus, varying $\omega$ over the range $(0,1)$ can be interpreted as generating a range of different counterfactuals.

Figure 4: MorphoMNIST Weak Supervision

6 Conclusion

In this work, we propose CausalDiffAE, a diffusion-based framework for causal representation learning and counterfactual generation. We propose a causal encoding mechanism that maps images to causally related factors and learn the causal mechanisms among factors via neural networks. We formulate a variational diffusion-based objective that enforces disentanglement of the latent space to enable latent space manipulations, and we propose a DDIM-based counterfactual generation algorithm subject to interventions. For limited-supervision scenarios, we propose a weak supervision extension of our model that jointly learns an unconditional and a conditional model; this objective also enables granular control over generated counterfactuals. We empirically demonstrate the capabilities of our model through both qualitative and quantitative evaluations. Future work includes exploring counterfactual generation in text-to-image diffusion models.

Acknowledgements

This work is supported in part by the National Science Foundation under awards 1910284, 1946391, and 2147375, the National Institute of General Medical Sciences of the National Institutes of Health under award P20GM139768, and the Arkansas Integrative Metabolic Research Center at the University of Arkansas.

References

  • Augustin et al. [2022] M. Augustin, V. Boreiko, F. Croce, and M. Hein. Diffusion visual counterfactual explanations. In Advances in Neural Information Processing Systems, 2022.
  • Bengio et al. [2013] Y. Bengio, A. Courville, and P. Vincent. Representation learning: A review and new perspectives. IEEE Transactions on Pattern Analysis and Machine Intelligence, 35(8):1798–1828, 2013.
  • Brehmer et al. [2022] J. Brehmer, P. D. Haan, P. Lippe, and T. Cohen. Weakly supervised causal representation learning. In Advances in Neural Information Processing Systems, 2022.
  • Castro et al. [2019] D. C. Castro, J. Tan, B. Kainz, E. Konukoglu, and B. Glocker. Morpho-MNIST: Quantitative assessment and diagnostics for representation learning. Journal of Machine Learning Research, 20(178), 2019.
  • Dhariwal and Nichol [2021] P. Dhariwal and A. Q. Nichol. Diffusion models beat GANs on image synthesis. In Advances in Neural Information Processing Systems, 2021.
  • Eastwood and Williams [2018] C. Eastwood and C. K. I. Williams. A framework for the quantitative evaluation of disentangled representations. In International Conference on Learning Representations, 2018.
  • Eastwood et al. [2023] C. Eastwood, A. L. Nicolicioiu, J. V. Kügelgen, A. Kekić, F. Träuble, A. Dittadi, and B. Schölkopf. DCI-ES: An extended disentanglement framework with connections to identifiability. In The Eleventh International Conference on Learning Representations, 2023.
  • Goodfellow et al. [2014] I. Goodfellow, J. Pouget-Abadie, M. Mirza, B. Xu, D. Warde-Farley, S. Ozair, A. Courville, and Y. Bengio. Generative adversarial nets. In Advances in Neural Information Processing Systems, 2014.
  • Higgins et al. [2017] I. Higgins, L. Matthey, A. Pal, C. Burgess, X. Glorot, M. Botvinick, S. Mohamed, and A. Lerchner. beta-VAE: Learning basic visual concepts with a constrained variational framework. In International Conference on Learning Representations, 2017.
  • Ho and Salimans [2021] J. Ho and T. Salimans. Classifier-free diffusion guidance. In NeurIPS 2021 Workshop on Deep Generative Models and Downstream Applications, 2021.
  • Ho et al. [2020] J. Ho, A. Jain, and P. Abbeel. Denoising diffusion probabilistic models. In Advances in Neural Information Processing Systems, 2020.
  • Karimi Mamaghan et al. [2024] A. M. Karimi Mamaghan, A. Dittadi, S. Bauer, K. H. Johansson, and F. Quinzan. Diffusion-based causal representation learning. Entropy, 26(7), 2024.
  • Khemakhem et al. [2020] I. Khemakhem, D. Kingma, R. Monti, and A. Hyvarinen. Variational autoencoders and nonlinear ICA: A unifying framework. In Proceedings of the Twenty Third International Conference on Artificial Intelligence and Statistics, 2020.
  • Kingma and Welling [2014] D. P. Kingma and M. Welling. Auto-encoding variational bayes. In International Conference on Learning Representations, 2014.
  • Kocaoglu et al. [2018] M. Kocaoglu, C. Snyder, A. G. Dimakis, and S. Vishwanath. CausalGAN: Learning causal implicit generative models with adversarial training. In International Conference on Learning Representations, 2018.
  • Komanduri et al. [2024] A. Komanduri, X. Wu, Y. Wu, and F. Chen. From identifiable causal representations to controllable counterfactual generation: A survey on causal generative modeling. Transactions on Machine Learning Research, 2024.
  • Liu et al. [2022] X. Liu, P. Sanchez, S. Thermos, A. Q. O’Neil, and S. A. Tsaftaris. Learning disentangled representations in the imaging domain. Medical Image Analysis, 80:102516, 2022.
  • Melistas et al. [2024] T. Melistas, N. Spyrou, N. Gkouti, P. Sanchez, A. Vlontzos, G. Papanastasiou, and S. A. Tsaftaris. Benchmarking counterfactual image generation. arXiv preprint arXiv:2403.20287, 2024.
  • Mittal et al. [2023] S. Mittal, K. Abstreiter, S. Bauer, B. Schölkopf, and A. Mehrjou. Diffusion based representation learning. In Proceedings of the 40th International Conference on Machine Learning, 2023.
  • Nichol and Dhariwal [2021] A. Q. Nichol and P. Dhariwal. Improved denoising diffusion probabilistic models. In Proceedings of the 38th International Conference on Machine Learning, 2021.
  • Pandey et al. [2022] K. Pandey, A. Mukherjee, P. Rai, and A. Kumar. DiffuseVAE: Efficient, controllable and high-fidelity generation from low-dimensional latents. Transactions on Machine Learning Research, 2022.
  • Pawlowski et al. [2020] N. Pawlowski, D. Coelho de Castro, and B. Glocker. Deep structural causal models for tractable counterfactual inference. In Advances in Neural Information Processing Systems, 2020.
  • Pearl [2009] J. Pearl. Causality. Cambridge University Press, Cambridge, UK, 2 edition, 2009. ISBN 978-0-521-89560-6.
  • Preechakul et al. [2022] K. Preechakul, N. Chatthee, S. Wizadwongsa, and S. Suwajanakorn. Diffusion autoencoders: Toward a meaningful and decodable representation. In IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2022.
  • Radford et al. [2021] A. Radford, J. W. Kim, C. Hallacy, A. Ramesh, G. Goh, S. Agarwal, G. Sastry, A. Askell, P. Mishkin, J. Clark, G. Krueger, and I. Sutskever. Learning transferable visual models from natural language supervision. In Proceedings of the 38th International Conference on Machine Learning, 2021.
  • Ramesh et al. [2022] A. Ramesh, P. Dhariwal, A. Nichol, C. Chu, and M. Chen. Hierarchical text-conditional image generation with clip latents. arXiv preprint arXiv:2204.06125, 2022.
  • Rombach et al. [2022] R. Rombach, A. Blattmann, D. Lorenz, P. Esser, and B. Ommer. High-resolution image synthesis with latent diffusion models. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), 2022.
  • Saharia et al. [2022] C. Saharia, W. Chan, S. Saxena, L. Li, J. Whang, E. Denton, S. K. S. Ghasemipour, B. K. Ayan, S. S. Mahdavi, R. G. Lopes, T. Salimans, J. Ho, D. J. Fleet, and M. Norouzi. Photorealistic text-to-image diffusion models with deep language understanding. In Advances in Neural Information Processing Systems, 2022.
  • Sanchez et al. [2022] P. Sanchez, J. P. Voisey, T. Xia, H. I. Watson, A. Q. ONeil, and S. A. Tsaftaris. Causal machine learning for healthcare and precision medicine. Royal Society Open Science, 2022.
  • Schölkopf et al. [2021] B. Schölkopf, F. Locatello, S. Bauer, N. R. Ke, N. Kalchbrenner, A. Goyal, and Y. Bengio. Toward causal representation learning. Proceedings of the IEEE, 109:612–634, 2021.
  • Sohl-Dickstein et al. [2015] J. Sohl-Dickstein, E. Weiss, N. Maheswaranathan, and S. Ganguli. Deep unsupervised learning using nonequilibrium thermodynamics. In Proceedings of the 32nd International Conference on Machine Learning, 2015.
  • Song et al. [2021a] J. Song, C. Meng, and S. Ermon. Denoising diffusion implicit models. In International Conference on Learning Representations, 2021a.
  • Song et al. [2021b] Y. Song, J. Sohl-Dickstein, D. P. Kingma, A. Kumar, S. Ermon, and B. Poole. Score-based generative modeling through stochastic differential equations. In International Conference on Learning Representations, 2021b.
  • Vowels et al. [2022] M. J. Vowels, N. C. Camgoz, and R. Bowden. D’ya like dags? a survey on structure learning and causal discovery. ACM Computing Surveys, 2022.
  • Yang et al. [2021] M. Yang, F. Liu, Z. Chen, X. Shen, J. Hao, and J. Wang. CausalVAE: Disentangled representation learning via neural structural causal models. In IEEE Conference on Computer Vision and Pattern Recognition, 2021.
  • Zheng et al. [2018] X. Zheng, B. Aragam, P. K. Ravikumar, and E. P. Xing. Dags with no tears: Continuous optimization for structure learning. In Advances in Neural Information Processing Systems, 2018.

Appendices

Appendix A Derivation of ELBO

Given a high-dimensional input image $\mathbf{x}_0$, an auxiliary weak supervision signal $\mathbf{y}$, a latent noise encoding $\mathbf{u}$, a latent representation $\mathbf{z}_{\text{causal}}$, and a sequence of $T$ latent representations $\mathbf{x}_{1:T}$ learned by the diffusion model, the CausalDiffAE generative process can be factorized as follows:

$p(\mathbf{x}_{0:T},\mathbf{u},\mathbf{z}_{\text{causal}}|\mathbf{y})=p_\theta(\mathbf{x}_{0:T}|\mathbf{u},\mathbf{z}_{\text{causal}},\mathbf{y})\,p(\mathbf{u},\mathbf{z}_{\text{causal}}|\mathbf{y})$ (22)

where $\theta$ are the parameters of the reverse process of the conditional diffusion model. The log-likelihood of the input data distribution can be obtained as follows:

$\log p(\mathbf{x}_0,\mathbf{y})=\log\int p(\mathbf{x}_{0:T},\mathbf{u},\mathbf{z}_{\text{causal}},\mathbf{y})\,d\mathbf{x}_{1:T}\,d\mathbf{u}\,d\mathbf{z}_{\text{causal}}$ (23)

The joint posterior distribution $p(\mathbf{x}_{1:T},\mathbf{u},\mathbf{z}_{\text{causal}}|\mathbf{x}_0,\mathbf{y})$ is intractable, so we approximate it with a variational distribution $q(\mathbf{x}_{1:T},\mathbf{u},\mathbf{z}_{\text{causal}}|\mathbf{x}_0,\mathbf{y})$, which can be factorized into the following conditional distributions

$q(\mathbf{x}_{1:T},\mathbf{u},\mathbf{z}_{\text{causal}}|\mathbf{x}_0,\mathbf{y})=q_\phi(\mathbf{z}_{\text{causal}},\mathbf{u}|\mathbf{x}_0,\mathbf{y})\,q(\mathbf{x}_{1:T}|\mathbf{u},\mathbf{z}_{\text{causal}},\mathbf{x}_0)$ (24)

where $\phi$ are the parameters of the variational encoder network. Since the likelihood of the data is intractable, we can instead maximize the following evidence lower bound (ELBO):

$\log p(\mathbf{x}_0,\mathbf{y}) \geq \mathbb{E}_{q(\mathbf{x}_{1:T},\mathbf{u},\mathbf{z}_{\text{causal}}|\mathbf{x}_0,\mathbf{y})}\Big[\log\frac{p(\mathbf{x}_{0:T},\mathbf{u},\mathbf{z}_{\text{causal}},\mathbf{y})}{q(\mathbf{x}_{1:T},\mathbf{u},\mathbf{z}_{\text{causal}}|\mathbf{x}_0,\mathbf{y})}\Big]$ (25)
$=\mathbb{E}_{q(\mathbf{x}_{1:T},\mathbf{u},\mathbf{z}_{\text{causal}}|\mathbf{x}_0,\mathbf{y})}\Big[\log\frac{p(\mathbf{u})\,p(\mathbf{z}_{\text{causal}}|\mathbf{y})\,p_\theta(\mathbf{x}_{0:T}|\mathbf{u},\mathbf{z}_{\text{causal}})}{q_\phi(\mathbf{z}_{\text{causal}},\mathbf{u}|\mathbf{x}_0,\mathbf{y})\,q(\mathbf{x}_{1:T}|\mathbf{u},\mathbf{z}_{\text{causal}},\mathbf{x}_0)}\Big]$ (26)
$=\mathbb{E}_{q(\mathbf{x}_{1:T},\mathbf{u},\mathbf{z}_{\text{causal}}|\mathbf{x}_0,\mathbf{y})}\Big[\log\frac{p(\mathbf{u},\mathbf{z}_{\text{causal}}|\mathbf{y})}{q_\phi(\mathbf{z}_{\text{causal}},\mathbf{u}|\mathbf{x}_0,\mathbf{y})}+\log\frac{p_\theta(\mathbf{x}_{0:T}|\mathbf{u},\mathbf{z}_{\text{causal}})}{q(\mathbf{x}_{1:T}|\mathbf{u},\mathbf{z}_{\text{causal}},\mathbf{x}_0)}\Big]$ (27)
$=\mathbb{E}_{q(\mathbf{u},\mathbf{z}_{\text{causal}}|\mathbf{x}_0,\mathbf{y})}\Big[\log\frac{p(\mathbf{u},\mathbf{z}_{\text{causal}}|\mathbf{y})}{q_\phi(\mathbf{z}_{\text{causal}},\mathbf{u}|\mathbf{x}_0,\mathbf{y})}\Big]+\mathbb{E}_{q(\mathbf{x}_{1:T},\mathbf{u},\mathbf{z}_{\text{causal}}|\mathbf{x}_0)}\Big[\log\frac{p_\theta(\mathbf{x}_{0:T}|\mathbf{u},\mathbf{z}_{\text{causal}})}{q(\mathbf{x}_{1:T}|\mathbf{u},\mathbf{z}_{\text{causal}},\mathbf{x}_0)}\Big]$ (28)
$=\mathbb{E}_{q(\mathbf{u},\mathbf{z}_{\text{causal}}|\mathbf{x}_0,\mathbf{y})}\Big[\underbrace{\mathbb{E}_{q(\mathbf{x}_{1:T},\mathbf{u},\mathbf{z}_{\text{causal}}|\mathbf{x}_0)}\Big[\log\frac{p_\theta(\mathbf{x}_{0:T}|\mathbf{u},\mathbf{z}_{\text{causal}})}{q(\mathbf{x}_{1:T}|\mathbf{u},\mathbf{z}_{\text{causal}},\mathbf{x}_0)}\Big]}_{\text{representation-conditioned DDPM loss}}\Big]-\underbrace{\mathcal{D}_{KL}(q_\phi(\mathbf{u},\mathbf{z}_{\text{causal}}|\mathbf{x}_0,\mathbf{y})\,\|\,p(\mathbf{u},\mathbf{z}_{\text{causal}}|\mathbf{y}))}_{\text{joint latent posterior loss}}$ (29)

In the learning process, we minimize the negative of the derived ELBO. We simplify this objective by using the $\epsilon_{\theta}$ parameterization to optimize the representation-conditioned DDPM loss. Further, since $\mathbf{u}$ and $\mathbf{z}_{\text{causal}}$ are mapped one-to-one, the joint conditional distribution splits into separate conditional distributions. Thus, we arrive at the following final objective for CausalDiffAE:

\mathcal{L}_{\text{CausalDiffAE}} = \sum_{t=1}^{T} \mathbb{E}_{t, \mathbf{x}_0, \epsilon} \Big[ \|\epsilon_{\theta}(\mathbf{x}_t, t, \mathbf{z}_{\text{causal}}) - \epsilon_t\|_2^2 \Big] + \gamma \Big\{ \mathcal{D}_{KL}\big(q_{\phi}(\mathbf{z}_{\text{causal}} | \mathbf{x}_0, \mathbf{y}) \,\|\, p(\mathbf{z}_{\text{causal}} | \mathbf{y})\big) + \mathcal{D}_{KL}\big(q_{\phi}(\mathbf{u} | \mathbf{x}_0) \,\|\, \mathcal{N}(\mathbf{0}, \mathbf{I})\big) \Big\}   (31)
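For concreteness, the following is a minimal PyTorch-style sketch of estimating this objective on a single minibatch, assuming diagonal-Gaussian posteriors and priors; `encoder`, `eps_model`, and `prior` are hypothetical stand-ins for the actual networks, not our released implementation.

```python
import torch
import torch.nn.functional as F

def causal_diffae_loss(x0, y, encoder, eps_model, prior, alphas_cumprod, gamma):
    """Sketch of Eq. (31): representation-conditioned DDPM loss plus
    gamma-weighted KL terms, assuming diagonal-Gaussian distributions."""
    # Encode x0 into causal latents and noise encoding (reparameterized).
    z_mu, z_logvar, u_mu, u_logvar = encoder(x0, y)
    z_causal = z_mu + torch.randn_like(z_mu) * (0.5 * z_logvar).exp()

    # Sample a timestep and noise the input with the closed-form forward process.
    t = torch.randint(0, len(alphas_cumprod), (x0.size(0),), device=x0.device)
    a_bar = alphas_cumprod[t].view(-1, 1, 1, 1)
    eps = torch.randn_like(x0)
    x_t = a_bar.sqrt() * x0 + (1 - a_bar).sqrt() * eps

    # Representation-conditioned DDPM (noise prediction) loss.
    ddpm_loss = F.mse_loss(eps_model(x_t, t, z_causal), eps)

    # KL of the causal latents against the label-conditioned prior p(z|y).
    p_mu, p_logvar = prior(y)
    kl_z = 0.5 * (p_logvar - z_logvar
                  + (z_logvar.exp() + (z_mu - p_mu) ** 2) / p_logvar.exp()
                  - 1).sum(dim=1).mean()

    # KL of the noise encoding against a standard normal.
    kl_u = -0.5 * (1 + u_logvar - u_mu ** 2 - u_logvar.exp()).sum(dim=1).mean()

    return ddpm_loss + gamma * (kl_z + kl_u)
```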

Appendix B Connection to Score-based Generative Models

Diffusion models can also be represented as stochastic differential equations (SDEs) [33] to model continuous-time perturbations. Specifically, the forward diffusion process can be modeled as the solution to an SDE on the continuous-time domain $t \in [0, T]$ with stochastic trajectories:

d\mathbf{x} = f(\mathbf{x}, t) \, dt + g(t) \, dw   (32)

where $w$ is the standard Wiener process, $f$ is a vector-valued function known as the drift coefficient of $\mathbf{x}(t)$, and $g$ is a scalar function known as the diffusion coefficient of $\mathbf{x}(t)$. The drift and diffusion coefficients can be viewed as controlling the mean and variance of the noise perturbations in the diffusion process, respectively. The reverse diffusion process can be modeled by the solution to the reverse-time SDE of Eq. (32), which can be derived analytically as:

d\mathbf{x} = \big[f(\mathbf{x}, t) - g^2(t) \nabla_{\mathbf{x}} \log p_t(\mathbf{x})\big] \, dt + g(t) \, d\bar{w}   (33)

where $\bar{w}$ is the standard Wiener process in reverse time and $\nabla_{\mathbf{x}} \log p_t(\mathbf{x})$ is the score of the data distribution at timestep $t$. Once the score of the marginal distribution is known for all timesteps $t$, the reverse diffusion process can be simulated via Eq. (33).

Song et al. [33] showed that the denoising diffusion probabilistic model (DDPM) is a discretization of the following variance-preserving SDE (VP-SDE):

d\mathbf{x} = -\frac{1}{2} \beta(t) \mathbf{x} \, dt + \sqrt{\beta(t)} \, dw   (34)

Thus, learning a noise prediction network $\epsilon_{\theta}$ by minimizing the MSE objective in diffusion probabilistic models is equivalent to approximating the score of the data distribution in the SDE formulation. From a score-based perspective, we aim to minimize the following conditional denoising score-matching form of our objective:

\begin{split}
&\mathbb{E}_{p(\mathbf{x}_0)} \mathbb{E}_{q_{\phi}(\mathbf{z}_{\text{causal}} | \mathbf{x}_0)} \mathbb{E}_{q(\mathbf{x}_t | \mathbf{x}_0)} \Big[ \log p(\mathbf{u}) + \log p(\mathbf{z}_{\text{causal}} | \mathbf{y}) \\
&\quad - \log q_{\phi}(\mathbf{u} | \mathbf{x}_0) - \log q_{\phi}(\mathbf{z}_{\text{causal}} | \mathbf{x}_0, \mathbf{y}) \\
&\quad + \lambda(t) \, \|s_{\theta}(\mathbf{x}_t, \mathbf{z}_{\text{causal}}, t) - \nabla_{\mathbf{x}_t} \log p(\mathbf{x}_t | \mathbf{x}_0)\|_2^2 \Big]
\end{split}   (35)

where $s_{\theta}$ approximates the score of the data distribution conditioned on $\mathbf{x}_0$ and $\lambda(t)$ is a positive weighting function. Differential equations are the ideal formalism for modeling the physical mechanisms underlying natural phenomena [30]. In the SDE formulation, the causal variables guide the denoising of the high-dimensional data, which is modeled as a reverse-time stochastic trajectory. We can interpret this as modeling the dynamics of high-dimensional systems while incorporating causal information. As opposed to learning an arbitrary latent representation, a disentangled causal representation encodes causal information that the denoising process can use to reconstruct causally relevant features in high-dimensional data.
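The two parameterizations are linked by the standard identity $s_{\theta}(\mathbf{x}_t, t) = -\epsilon_{\theta}(\mathbf{x}_t, t) / \sqrt{1 - \bar{\alpha}_t}$ for the DDPM/VP-SDE setting. The following sketch illustrates this conversion; the noise predictor `eps_model` is a hypothetical stand-in.

```python
import torch

def score_from_eps(eps_model, x_t, t, z_causal, alphas_cumprod):
    """Convert a noise-prediction network into a score estimate via
    s_theta(x_t, t) = -eps_theta(x_t, t) / sqrt(1 - a_bar_t)."""
    a_bar = alphas_cumprod[t].view(-1, 1, 1, 1)
    return -eps_model(x_t, t, z_causal) / (1 - a_bar).sqrt()
```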

Appendix C Discussion on Causal Discovery

In this work, we assume the latent causal structure is known, since our focus is counterfactual generation. In principle, our framework can be combined with causal structure learning methods such as NOTEARS [36] by augmenting the VAE loss with penalty terms that enforce sparsity and acyclicity of the learned adjacency matrix $A$:

\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{CausalDiffAE}} + H(A) + \|A\|_0   (36)

where $H(A) = \operatorname{tr}\big[(I + \alpha A \odot A)^n\big] - n = 0$ is the acyclicity constraint and $\|\cdot\|_0$ enforces the sparsity of the DAG. The $\ell_0$ norm can be relaxed to the $\ell_1$ norm to keep the objective differentiable. Similar to [36], the joint loss can then be optimized with an augmented Lagrangian scheme. Other causal discovery algorithms, each under different assumptions, could also be used heuristically [34]. We leave this direction to future work.
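A minimal sketch of these penalty terms is given below, using the $\ell_1$ relaxation in place of $\ell_0$; the weight `lambda_sparse` is a hypothetical hyperparameter, not one from our experiments.

```python
import torch

def acyclicity_penalty(A, alpha=1.0):
    """Acyclicity constraint H(A) = tr[(I + alpha * A ∘ A)^n] - n, which is
    zero iff the weighted adjacency matrix A encodes a DAG."""
    n = A.size(0)
    M = torch.eye(n, device=A.device) + alpha * A * A  # A ∘ A is elementwise
    return torch.linalg.matrix_power(M, n).trace() - n

def structure_regularizer(A, lambda_sparse=0.1, alpha=1.0):
    """Sketch of the added penalty in Eq. (36), with an L1 relaxation of the
    L0 sparsity term to keep the objective differentiable."""
    return acyclicity_penalty(A, alpha) + lambda_sparse * A.abs().sum()
```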

Appendix D Experiment Details

D.1 Dataset Details

MorphoMNIST. The MorphoMNIST dataset [4] is produced by applying morphological transformations to the original MNIST handwritten digit dataset. The digits can be described by measurable shape attributes such as stroke thickness, stroke length, width, height, and slant. Pawlowski et al. [22] impose a 3-variable SCM to generate the morphological transformations, where stroke thickness is a cause of the brightness of each digit. That is, thicker digits are typically brighter, whereas thinner digits are dimmer. The data-generating process is as follows:

\begin{aligned}
t &= f_T(u_T) = 0.5 + u_T, & u_T &\sim \Gamma(10, 5),\\
i &= f_I(u_I; t) = 191 \cdot \sigma(0.5 \cdot u_I + 2 \cdot t - 5) + 64, & u_I &\sim \mathcal{N}(0, 1),\\
x &= f_X(u_X; i, t) = \text{SetIntensity}(\text{SetThickness}(u_X; t); i), & u_X &\sim \text{MNIST},
\end{aligned}   (37)

where $x$ is the resulting image, $u_{\cdot}$ denotes the exogenous noise for each variable, and $\sigma(\cdot)$ is the logistic sigmoid.
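A minimal simulation of these structural equations (attributes only; the image-level step requires the actual MorphoMNIST morphology operations) might look as follows, interpreting $\Gamma(10, 5)$ as shape 10 and rate 5, which is an assumption on our part.

```python
import numpy as np

def sample_morphomnist_attrs(rng=None):
    """Sketch of the attribute-level SCM in Eq. (37): thickness t causes
    intensity i. Gamma(10, 5) is read as shape=10, rate=5 (scale=1/5)."""
    if rng is None:
        rng = np.random.default_rng(0)
    u_t = rng.gamma(shape=10.0, scale=1.0 / 5.0)
    t = 0.5 + u_t
    u_i = rng.normal(0.0, 1.0)
    sigmoid = lambda a: 1.0 / (1.0 + np.exp(-a))
    i = 191.0 * sigmoid(0.5 * u_i + 2.0 * t - 5.0) + 64.0
    return t, i
```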

Pendulum. The Pendulum dataset [35] consists of roughly 7K images of resolution 96×96×4 describing a physical system in which a pendulum and a light source jointly determine the length and position of a shadow. The causal variables of interest are the angle of the pendulum, the position of the light source, the length of the shadow, and the position of the shadow. The data-generating process is as follows:

\begin{aligned}
y_1 &\sim U(-45, 45); & \theta &= y_1 \cdot \frac{\pi}{200}; & x &= 10 + 9.5 \sin\theta\\
y_2 &\sim U(60, 145); & \phi &= y_2 \cdot \frac{\pi}{200}; & y &= 10 - 9.5 \cos\theta\\
y_3 &= \max\Big(3, \Big|\, 9.5 \, \frac{\cos\theta}{\tan\phi} + 9.5 \sin\theta \,\Big|\Big)\\
y_4 &= \frac{-11 + 4.75 \cos\theta}{\tan\phi} + (10 + 4.75 \sin\theta)
\end{aligned}
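A minimal simulation of this process, transcribing the equations above directly, might look as follows.

```python
import numpy as np

def sample_pendulum(rng=None):
    """Sketch of the Pendulum generating process: pendulum angle (y1) and
    light position (y2) determine shadow length (y3) and position (y4)."""
    if rng is None:
        rng = np.random.default_rng(0)
    y1 = rng.uniform(-45.0, 45.0)
    y2 = rng.uniform(60.0, 145.0)
    theta, phi = y1 * np.pi / 200.0, y2 * np.pi / 200.0
    y3 = max(3.0, abs(9.5 * np.cos(theta) / np.tan(phi) + 9.5 * np.sin(theta)))
    y4 = (-11.0 + 4.75 * np.cos(theta)) / np.tan(phi) + (10.0 + 4.75 * np.sin(theta))
    return y1, y2, y3, y4
```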

Causal Circuit. The Causal Circuit dataset is a new dataset created by [3] to support research in causal representation learning. The dataset consists of 512×512×3 resolution images generated from 4 ground-truth latent causal variables: robot arm position, red light intensity, green light intensity, and blue light intensity. The images show a robot arm interacting with a system of buttons and lights, rendered with an open-source physics engine. The original dataset consists of pairs of images before and after an intervention has taken place; for the purposes of this work, we use only observational images (either the before or the after state). The data is generated according to the following process:

\begin{aligned}
v_R &= 0.2 + 0.6 \cdot \text{clip}(y_2 + y_3 + b_R, 0, 1)\\
v_G &= 0.2 + 0.6 \cdot b_G\\
v_B &= 0.2 + 0.6 \cdot b_B\\
y_4 &\sim \text{Beta}(5 v_R, 5(1 - v_R))\\
y_3 &\sim \text{Beta}(5 v_G, 5(1 - v_G))\\
y_2 &\sim \text{Beta}(5 v_B, 5(1 - v_B))\\
y_1 &\sim U(0, 1)
\end{aligned}

where $b_R$, $b_G$, and $b_B$ are the pressed states of the buttons, which depend on how far from the center each button is touched; $y_1$ is the robot arm position; and $y_2$, $y_3$, and $y_4$ are the intensities of the blue, green, and red lights, respectively.
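A minimal simulation of this process, taking the pressed states as inputs, might look as follows; note that $y_2$ and $y_3$ must be sampled before $v_R$, since the red light depends on them.

```python
import numpy as np

def sample_causal_circuit(b_R, b_G, b_B, rng=None):
    """Sketch of the CausalCircuit generating process, given the pressed
    states b_R, b_G, b_B of the three buttons in [0, 1]."""
    if rng is None:
        rng = np.random.default_rng(0)
    y1 = rng.uniform(0.0, 1.0)                    # robot arm position
    v_G = 0.2 + 0.6 * b_G
    v_B = 0.2 + 0.6 * b_B
    y3 = rng.beta(5.0 * v_G, 5.0 * (1.0 - v_G))   # green light intensity
    y2 = rng.beta(5.0 * v_B, 5.0 * (1.0 - v_B))   # blue light intensity
    v_R = 0.2 + 0.6 * np.clip(y2 + y3 + b_R, 0.0, 1.0)
    y4 = rng.beta(5.0 * v_R, 5.0 * (1.0 - v_R))   # red light intensity
    return y1, y2, y3, y4
```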

D.2 Implementation Details

We use the same network architectures and hyperparameters as other diffusion-based works [11, 5, 20]. We set the causal latent variable size to 512 to ensure enough capacity to capture causally relevant information. The representation-conditioned noise predictor is parameterized by a UNet with an attention mechanism. Similar to [11], we use a linear schedule for the variance parameter $\beta$ between $\beta_1 = 10^{-4}$ and $\beta_T = 0.02$ during training (for CausalCircuit we use a cosine schedule; see Table 4). For all three datasets, we initialize the bottleneck parameter at $\gamma = 0$ and linearly anneal it to a final value of $\gamma = 1.0$ over the course of training.

Table 4: Implementation details of CausalDiffAE

Parameter                  MorphoMNIST    Pendulum       CausalCircuit
Batch size                 768            128            128
Base channels              128            128            128
Channel multipliers        [1, 2, 2]      [1, 2, 4, 8]   [1, 2, 4, 8]
Training set               60K            5K             50K
Test set                   10K            2K             10K
Image resolution           28×28×1        96×96×4        128×128×3
Num causal variables       2              4              4
$\mathbf{z}_{\text{causal}}$ size   512            512            512
$\beta$ scheduler          Linear         Linear         Cosine
Learning rate              $10^{-4}$      $10^{-4}$      $10^{-4}$
Optimizer                  Adam           Adam           Adam
Diffusion steps            1000           1000           4000
Iterations                 10K            40K            20K
Diffusion loss             MSE            MSE            MSE
Sampling                   DDIM           DDIM           DDIM
Stride                     250            250            250
Bottleneck $\gamma$        1.0            1.0            1.0
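As a point of reference, the linear variance schedule and the cumulative products used in the forward noising step can be constructed as in the sketch below; this is the standard DDPM schedule [11], not code from our implementation.

```python
import torch

def linear_beta_schedule(T=1000, beta_1=1e-4, beta_T=0.02):
    """Linear variance schedule; alphas_cumprod gives the cumulative
    products needed to noise x0 directly to any timestep t."""
    betas = torch.linspace(beta_1, beta_T, T)
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
    return betas, alphas_cumprod
```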

D.3 Metrics Details

DCI Disentanglement [6]. The DCI disentanglement score quantifies the degree to which a representation disentangles the underlying factors of variation, with each latent variable capturing at most one generative factor. Let $R_{ij}$ denote the importance of latent code $z_i$ for predicting generative factor $y_j$ (e.g., the feature importance of a regressor trained to predict $y_j$), and let $P_{ij} = R_{ij} / \sum_{k=0}^{K-1} R_{ik}$ be the probability of $z_i$ being a strong predictor of $y_j$. The disentanglement score is then defined as

D_i = 1 + \sum_{k=0}^{K-1} P_{ik} \log_K P_{ik}   (38)

If $z_i$ is a strong predictor of only a single generative factor, $D_i = 1$; if $z_i$ is equally important for predicting all generative factors, $D_i = 0$. Let $\rho_i = \sum_j R_{ij} / \sum_{ij} R_{ij}$ denote the relative latent code importances. The total disentanglement score is the weighted average of the individual $D_i$:

D = \sum_i \rho_i D_i   (39)
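A minimal NumPy sketch of this computation, given an importance matrix `R` with `R[i, j]` as above (here assumed to be obtained from per-factor regressors), might look as follows.

```python
import numpy as np

def dci_disentanglement(R, eps=1e-12):
    """Sketch of Eqs. (38)-(39): per-latent score is one minus the entropy
    (log base K) of the row-normalized importances, weighted by rho_i."""
    K = R.shape[1]
    P = R / (R.sum(axis=1, keepdims=True) + eps)
    H = -np.sum(P * np.log(P + eps) / np.log(K), axis=1)
    D_i = 1.0 - H
    rho = R.sum(axis=1) / R.sum()
    return np.sum(rho * D_i)
```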

Effectiveness [18]. The effectiveness metric quantifies how successful a performed intervention is. To evaluate the effectiveness of a given counterfactual image, an anti-causal predictor $h_{\theta}^i$ is trained on the data distribution for each causal variable $y_i$. Each predictor estimates the counterfactual value $y_x^{i*}$ of the variable from the counterfactual image $x^*$:

\text{effectiveness}_i(x^*, y_x^{i*}) = d\big(y_x^{i*}, h_{\theta}^i(x^*)\big)   (40)

where $d(\cdot, \cdot)$ is an appropriate distance, defined as a classification metric for categorical variables and a regression metric for continuous ones.
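A minimal sketch of evaluating this metric, assuming pre-trained anti-causal predictors and per-variable distance functions are supplied by the caller, might look as follows.

```python
def effectiveness(x_cf, y_targets, predictors, distances):
    """Sketch of Eq. (40): `predictors` maps each causal variable name to a
    trained anti-causal predictor h_i, `y_targets` to the intended
    counterfactual value, and `distances` to a distance d (e.g., mean
    absolute error for continuous variables, 0/1 error for categorical)."""
    return {name: distances[name](y_targets[name], h(x_cf))
            for name, h in predictors.items()}
```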

D.4 Computational Requirements

We run our experiments on an Ubuntu 20.04 workstation with eight NVIDIA Tesla V100-SXM2 GPUs, each with 32GB of memory. Diffusion models are well known to incur higher computational cost than other generative models such as VAEs and GANs. All of the diffusion-based approaches we compare have similar runtimes, whereas CausalVAE is considerably faster. We expect advances in the training and sampling efficiency of diffusion probabilistic models to carry over to our diffusion-based approach as well.

D.5 Additional Experiments

Figure 5: CausalDiffAE-generated counterfactuals (MorphoMNIST).
Figure 6: CausalDiffAE-generated counterfactuals via latent traversals in the normalized range (-1, 1) (Pendulum).
Figure 7: CausalDiffAE-generated counterfactuals via latent traversals in the normalized range (-1, 1) (CausalCircuit).