Causal Diffusion Autoencoders: Toward Counterfactual Generation via Diffusion Probabilistic Models
Abstract
Diffusion probabilistic models (DPMs) have become the state-of-the-art in high-quality image generation. However, DPMs have an arbitrary noisy latent space with no interpretable or controllable semantics. Although there has been significant research effort to improve image sample quality, there is little work on representation-controlled generation using diffusion models. In particular, causal modeling and controllable counterfactual generation using DPMs remain underexplored. In this work, we propose CausalDiffAE, a diffusion-based causal representation learning framework that enables counterfactual generation according to a specified causal model. Our key idea is to use an encoder to extract high-level, semantically meaningful causal variables from high-dimensional data and to model stochastic variation using reverse diffusion. We propose a causal encoding mechanism that maps high-dimensional data to causally related latent factors and parameterize the causal mechanisms among latent factors using neural networks. To enforce the disentanglement of causal variables, we formulate a variational objective and leverage auxiliary label information in a prior to regularize the latent space. We propose a DDIM-based counterfactual generation procedure subject to do-interventions. Finally, to address the limited-label-supervision scenario, we study the application of CausalDiffAE when part of the training data is unlabeled, which additionally enables granular control over the strength of interventions when generating counterfactuals at inference time. We empirically show that CausalDiffAE learns a disentangled latent space and is capable of generating high-quality counterfactual images.
1 Introduction
Diffusion probabilistic models (DPMs) [31, 11, 20, 32, 33] are a class of likelihood-based generative models that have achieved remarkable success in generating high-resolution images, with many large-scale implementations such as DALLE-2 [26], Stable Diffusion [27], and Imagen [28]. The two most prominent formulations perturb the data distribution step-wise in discrete time [11] or continuous time [33]; a model is then trained to estimate the reverse process, which transforms noisy samples into samples from the underlying data distribution. Representation learning has been an integral component of generative models such as GANs [8] and VAEs [14] for extracting robust and interpretable features from complex data [30, 2, 25]. Recently, a thrust of research has focused on whether DPMs can be used to extract semantically meaningful and decodable representations that increase the quality of and control over generated images [21, 24]. However, there has been no prior work on modeling causal relations among the semantic latent codes of DPMs to learn causal representations and enable counterfactual generation at inference time. Generating high-quality counterfactual images is critical for domains such as healthcare and medicine [17, 29]: the ability to generate accurate counterfactual data from a causal graph obtained from domain knowledge can significantly cut the cost of data collection, and reasoning about hypothetical scenarios unseen in the training distribution can be insightful for gauging the interactions among causal variables in complex systems. Given a causal graph of a system, we study the capability of DPMs as causal representation learners and evaluate their ability to generate counterfactuals upon interventions on causal variables.
Intuitively, we can think of a DPM as an encoder-decoder framework. The encoder maps an input image $x_0$ to a spatial latent variable $x_T$ through a series of Gaussian noise perturbations. However, $x_T$ can only be interpreted as a noise representation that lacks high-level semantics [24]. Recently, Preechakul et al. [24] proposed the diffusion autoencoder (DiffAE), which extracts a high-level semantic representation alongside the stochastic low-level representation for decodable representation learning. Learning such a semantic representation enables interpolation in the latent space for controllable generation and has been shown to improve image generation quality. Mittal et al. [19] built on this framework and introduced a diffusion-based representation learning (DRL) objective that instead learns time-conditioned representations throughout the diffusion process. However, both of these approaches learn arbitrary representations and do not focus on disentanglement, a key property of interpretable representations. Disentangled representations enable precise control of generative factors in isolation. In causal systems, disentanglement is important for performing isolated interventions.
In this paper, we focus on learning disentangled causal representations, where the high-level semantic factors are causally related. To the best of our knowledge, we are the first to explore representation-based counterfactual image generation using diffusion probabilistic models. We propose CausalDiffAE, a learning framework for causal representation learning and controllable counterfactual generation in DPMs. Our key idea is to learn a causal representation via a learnable stochastic encoder and model the relations among latents via causal mechanisms parameterized by neural networks. We formulate a variational objective with a label alignment prior to enforce disentanglement of the learned causal factors. We then utilize a conditional denoising diffusion implicit model (DDIM) [32] for decoding and modeling the stochastic variations. Intuitively, the causal representation encodes compact information that is causally relevant for image decoding in reverse diffusion. Furthermore, the modeling of causal relations in the latent space enables the generation of counterfactuals upon interventions on learned causal variables. We propose a DDIM variant for counterfactual generation subject to interventions [23]. In an effort to improve the practicality and interpretability of the model, we propose an extension to CausalDiffAE that utilizes weaker supervision. In the scenario where labeled data is limited, we jointly train an unconditional and representation-conditioned diffusion model on the unlabeled and labeled partitions, respectively. This approach significantly reduces the number of labeled samples required for training and enables granular control over the strength of interventions and the quality of generated counterfactuals.
2 Related Work
Recent work in causal generative modeling has focused on either learning causal representations or controllable counterfactual generation [16]. Yang et al. [35] proposed CausalVAE, a causal representation learning framework that models latent causal variables via a linear SCM. Kocaoglu et al. [15] proposed CausalGAN, an extension of the GAN that models causal variables to enable sampling from interventional distributions. Diffusion and score-based generative models [11, 33] have shown impressive results in class-conditional generation through either classifier-based [5] or classifier-free [10] paradigms. Recently, there has been interest in exploring the capacity of diffusion models as representation learners. For instance, Mittal et al. [19] and Preechakul et al. [24] considered diffusion-based representation learning objectives. Karimi Mamaghan et al. [12] explored representation learning from a score-based perspective given access to data in the form of counterfactual pairs; however, that work does not focus on counterfactual generation. Another related area of research is counterfactual explanations [1], which focuses on post-hoc methods to generate realistic counterfactuals, but not in the strictly causal sense. Our work focuses on diffusion-based representation learning and is most closely related to DiffAE [24] and DRL [19], which aim to learn semantically meaningful representations. The key distinction is that we learn causal representations to enable counterfactual generation. Our proposed framework extends CausalVAE to diffusion-based models and a weaker supervision paradigm.
3 Background
3.1 Structural Causal Model
A structural causal model (SCM) is formally defined by a tuple $M = \langle \mathcal{Z}, \mathcal{U}, F, P_U \rangle$, where $\mathcal{Z}$ is the domain of the set of endogenous causal variables $Z = \{z_1, \dots, z_n\}$, $\mathcal{U}$ is the domain of the set of exogenous noise variables $U = \{u_1, \dots, u_n\}$, which is learned as an intermediate latent variable, and $F = \{f_1, \dots, f_n\}$ is a collection of independent causal mechanisms of the form

$$z_i = f_i(\mathrm{pa}(z_i), u_i), \quad i = 1, \dots, n \tag{1}$$

where $f_i$ are causal mechanisms that determine each causal variable as a function of its parents and noise, and $\mathrm{pa}(z_i)$ are the parents of causal variable $z_i$; and a probability measure $P_U$, which admits a product distribution. For the purposes of this work, we assume a causally sufficient setting (no hidden confounding), no SCM misspecification, and that faithfulness is satisfied.
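To make the abduction-action-prediction pattern of an SCM concrete, the following sketch builds a toy two-variable SCM in the spirit of the thickness-causes-intensity example used later in the paper; the mechanism forms here are illustrative assumptions, not the dataset's actual equations.

```python
import numpy as np

# A toy 2-variable SCM (thickness -> intensity) mirroring the MorphoMNIST
# structure used later in the paper. Mechanism forms here are illustrative.
def f_thickness(u_t):
    # z_t = f_1(u_t): root variable, a function of its noise only
    return 0.5 + u_t

def f_intensity(z_t, u_i):
    # z_i = f_2(z_t, u_i): hypothetical mechanism with additive noise
    return np.tanh(2.0 * z_t) + u_i

# Exogenous noise, sampled once and reused for the counterfactual (abduction)
u = {"t": np.random.gamma(10.0, 1.0 / 5.0), "i": np.random.randn()}

# Factual world: evaluate mechanisms in topological order
z_t = f_thickness(u["t"])
z_i = f_intensity(z_t, u["i"])

# Counterfactual world under do(thickness = 2.0): sever f_thickness, keep the
# same exogenous noise, and re-propagate through descendant mechanisms only
z_t_cf = 2.0
z_i_cf = f_intensity(z_t_cf, u["i"])
print(f"factual: ({z_t:.2f}, {z_i:.2f})  counterfactual: ({z_t_cf:.2f}, {z_i_cf:.2f})")
```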
3.2 Diffusion Probabilistic Models
Diffusion Probabilistic Models (DPMs) [11, 20] have shown impressive results in image generation tasks, even outperforming GANs in many cases [5]. The idea of the denoising diffusion probabilistic model (DDPM) [11] is to define a Markov chain of diffusion steps that slowly destroys the structure of the data distribution through a forward process of added noise, and to learn a reverse diffusion process that restores the structure of the data. Other methods, such as the denoising diffusion implicit model (DDIM) [32], break the Markov assumption to speed up sampling by carrying out a deterministic encoding of the noise.
Forward Diffusion. Given input data $x_0$ sampled from a distribution $q(x_0)$, the forward diffusion process adds small amounts of Gaussian noise to the sample in $T$ steps, producing noisy samples $x_1, \dots, x_T$. The distribution of the noisy sample at time step $t$ is defined by the conditional distribution

$$q(x_t \mid x_{t-1}) = \mathcal{N}\big(x_t;\, \sqrt{1 - \beta_t}\, x_{t-1},\, \beta_t I\big) \tag{2}$$

where $\beta_t \in (0, 1)$ is a variance parameter that controls the step size of the noise. As $t$ grows, the input sample loses its distinguishable features; in the end, $x_T$ follows an isotropic Gaussian. From Eq. (2), we can then define a closed-form tractable posterior over all time steps, factorized as follows:

$$q(x_{1:T} \mid x_0) = \prod_{t=1}^{T} q(x_t \mid x_{t-1}) \tag{3}$$

Now, $x_t$ can be sampled at any arbitrary time step using the reparameterization trick. Let $\alpha_t = 1 - \beta_t$ and $\bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s$:

$$x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon, \quad \epsilon \sim \mathcal{N}(0, I) \tag{4}$$
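As a concrete illustration of Eqs. (2)-(4), the following sketch implements the closed-form forward sampling; the linear $\beta$ schedule with endpoints $10^{-4}$ and $0.02$ is the one used in [11].

```python
import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T)           # linear schedule from [11]
alphas_bar = torch.cumprod(1.0 - betas, dim=0)  # cumulative product of alpha_t

def q_sample(x0: torch.Tensor, t: torch.Tensor):
    """Sample x_t ~ q(x_t | x_0) via Eq. (4) for a batch of timesteps t."""
    eps = torch.randn_like(x0)
    a_bar = alphas_bar[t].view(-1, *([1] * (x0.dim() - 1)))
    x_t = a_bar.sqrt() * x0 + (1.0 - a_bar).sqrt() * eps
    return x_t, eps  # eps is the regression target in the loss of Eq. (7)
```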
Reverse Diffusion. In the reverse process, the goal is to recreate the true sample $x_0$ from a Gaussian noise input $x_T \sim \mathcal{N}(0, I)$. Unlike the forward diffusion, $q(x_{t-1} \mid x_t)$ is not analytically tractable, so we learn a model $p_\theta$ to approximate the conditional distributions as follows:

$$p_\theta(x_{t-1} \mid x_t) = \mathcal{N}\big(x_{t-1};\, \mu_\theta(x_t, t),\, \Sigma_\theta(x_t, t)\big) \tag{5}$$

where $\mu_\theta$ and $\Sigma_\theta$ are learned via neural networks. It turns out that conditioning on the input $x_0$ yields a tractable reverse conditional probability

$$q(x_{t-1} \mid x_t, x_0) = \mathcal{N}\big(x_{t-1};\, \tilde{\mu}_t(x_t, x_0),\, \tilde{\beta}_t I\big) \tag{6}$$

where $\tilde{\mu}_t$ and $\tilde{\beta}_t$ are the true mean and variance. The learning objective is then formulated as a simplified version of the ELBO via reparameterization, minimizing the following mean squared error loss

$$L_{\text{simple}} = \mathbb{E}_{t, x_0, \epsilon}\Big[\big\| \epsilon - \epsilon_\theta(x_t, t) \big\|^2\Big] \tag{7}$$

where $\epsilon$ is the noise and $x_t$ takes an analytical form via the reparameterization in Eq. (4), as shown in [11].
DPMs produce latent variables $x_1, \dots, x_T$ through the forward process. However, these variables are stochastic [24]. Song et al. [32] proposed a DPM called the denoising diffusion implicit model (DDIM), which enables a deterministic encoding process

$$x_{t+1} = \sqrt{\bar{\alpha}_{t+1}}\; f_\theta(x_t, t) + \sqrt{1 - \bar{\alpha}_{t+1}}\; \epsilon_\theta(x_t, t) \tag{8}$$

with the following deterministic decoding process

$$x_{t-1} = \sqrt{\bar{\alpha}_{t-1}}\; f_\theta(x_t, t) + \sqrt{1 - \bar{\alpha}_{t-1}}\; \epsilon_\theta(x_t, t) \tag{9}$$

where $f_\theta(x_t, t) = \big(x_t - \sqrt{1 - \bar{\alpha}_t}\; \epsilon_\theta(x_t, t)\big)/\sqrt{\bar{\alpha}_t}$ is the predicted $x_0$, and the DDPM marginal distribution $q(x_t \mid x_0)$ is preserved. This formulation shares the same objective and solution as DDPM and differs only in the sampling procedure. Thus, we can deterministically obtain the noise map $x_T$ corresponding to a given image $x_0$.
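The deterministic DDIM update of Eqs. (8)-(9) can be written as a single transition function. The sketch below assumes a trained noise predictor eps_model(x_t, t) and the alphas_bar schedule from the previous sketch; the same formula encodes or decodes depending on the direction of the timestep pair.

```python
import torch

@torch.no_grad()
def ddim_step(x_t, t, t_next, eps_model, alphas_bar):
    """One deterministic DDIM transition from timestep t to t_next.
    t_next < t denoises (decoding); t_next > t maps x_0 toward the noise map x_T."""
    eps = eps_model(x_t, t)
    a_t, a_next = alphas_bar[t], alphas_bar[t_next]
    # Predicted clean image, i.e., f_theta(x_t, t) from Eqs. (8)-(9)
    x0_pred = (x_t - (1.0 - a_t).sqrt() * eps) / a_t.sqrt()
    return a_next.sqrt() * x0_pred + (1.0 - a_next).sqrt() * eps
```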
4 Causal Diffusion Autoencoders
Existing diffusion-based controllable generation methods neglect the scenario where generative factors are causally related and do not support counterfactual generation. To tackle this issue, we propose CausalDiffAE, a diffusion-based causal representation learning framework that enables counterfactual generation. Firstly, we define a latent SCM to describe the semantic causal representation as a function of learned noise encodings. In diffusion autoencoders [24], the semantic latent representation $z_{\text{sem}}$ captures high-level semantic information and $x_T$ captures low-level stochastic information. In our formulation, we instead learn a causal representation $z$, which captures causally relevant information. Together, the two latent variables capture the detailed causal semantics and the stochasticity in the image. Secondly, given a trained CausalDiffAE model, we propose a counterfactual generation algorithm that combines interventions with the DDIM sampling algorithm. The overall framework of CausalDiffAE is shown in Figure 1.
4.1 Causal Encoding
Let $x_0$ be the observed input image. We carry out the forward diffusion process to obtain a set of perturbed samples $x_1, \dots, x_T$, each at a different noise scale. Suppose there are $n$ abstract causal variables that describe the high-level semantics of the observed image. To learn a meaningful representation, we propose to encode the input image to a low-dimensional noise encoding $u = (u_1, \dots, u_n)$. We then map the noise encoding to latent causal factors $z = (z_1, \dots, z_n)$ corresponding to the abstract causal variables. In this formulation, each $u_i$ is the exogenous noise term for causal variable $z_i$ in the SCM. Let $A$ be the adjacency matrix encoding the causal graph among the underlying factors, where $A_{ij} \neq 0$ implies $z_i$ is a cause of $z_j$. Then, we parameterize the mechanisms among causal variables as follows

$$z_i = f_i(\mathrm{pa}(z_i), u_i), \quad i = 1, \dots, n \tag{10}$$

where $f_i$ is the causal mechanism generating causal variable $z_i$ as a function of its parents and the exogenous noise term $u_i$, and $\mathrm{pa}(z_i)$ denotes the causal parents of factor $z_i$. In practice, we can implement $f_i$ as a post-nonlinear additive noise model such that

$$z_i = g_i(A_i \circ z;\, \nu_i) + u_i \tag{11}$$

where $\nu_i$ are the parameters of the neural network parameterizing each mechanism $g_i$, $\circ$ denotes the elementwise product, and $A_i$ is the $i$-th column of $A$, which masks out the non-parents of $z_i$. This module captures the causal relations between latent variables using neural structural causal models. For the purposes of this work, we assume that the causal graph is known since we focus on counterfactual generation; a more end-to-end framework may include a causal discovery component. See Appendix C for a more detailed discussion.
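A minimal sketch of the masked neural mechanisms in Eq. (11) is given below. The network sizes, the additive-noise form, and the assumption that variable indices follow a topological order of the graph are illustrative choices, not prescribed by the framework.

```python
import torch
import torch.nn as nn

class NeuralSCM(nn.Module):
    """Latent SCM layer for Eq. (11): z_i = g_i(A_i o z; nu_i) + u_i.
    Assumes variable indices are in a topological order of the graph A."""
    def __init__(self, n_vars: int, adjacency: torch.Tensor):
        super().__init__()
        self.register_buffer("A", adjacency)  # A[i, j] != 0 iff z_i causes z_j
        self.mechanisms = nn.ModuleList(
            nn.Sequential(nn.Linear(n_vars, 32), nn.ReLU(), nn.Linear(32, 1))
            for _ in range(n_vars)
        )

    def forward(self, u: torch.Tensor) -> torch.Tensor:
        """Map exogenous noise u (batch, n_vars) to causal factors z."""
        zs = [torch.zeros(u.size(0), device=u.device) for _ in range(u.size(1))]
        for i, g_i in enumerate(self.mechanisms):
            z_so_far = torch.stack(zs, dim=1)
            parents = z_so_far * self.A[:, i]   # mask non-parents (A_i o z)
            zs[i] = g_i(parents).squeeze(-1) + u[:, i]
        return torch.stack(zs, dim=1)
```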
4.2 Generative Model
Let $x_0$ denote the high-dimensional input image and $y$ denote an auxiliary weak supervision signal. Then, the CausalDiffAE generative process can be factorized as follows:
$$p_\theta(x_{0:T}, z, u \mid y) = p_\theta(x_{0:T} \mid z)\; p(z \mid u)\; p(u) \tag{12}$$

where $\theta$ are the parameters of the reverse process of the causal diffusion decoder (discussed in Section 4.3), $p(u) = \mathcal{N}(0, I)$, $p(z \mid u)$ is the distribution over causal factors induced by the mechanisms in Eq. (11), and $p(z \mid y)$ is the alignment prior defined in Eq. (19). The joint posterior distribution is intractable, so we approximate it using a variational distribution, which can be factorized into the following conditional distributions

$$q_\phi(z, u \mid x_0, y) = q_\phi(u \mid x_0)\; q(z \mid u) \tag{13}$$

where $\phi$ are the parameters of the variational encoder network parameterizing the joint distribution over the noise $u$ and causal factors $z$. We can remove the dependence on $y$ in the conditional terms of Eq. (13) since $z$ is fully determined by $u$ and is thus independent of the auxiliary label $y$ given $u$. We note that $q_\phi(z \mid x_0)$ can be obtained from $q_\phi(u \mid x_0)$ since $z$ and $u$ have a one-to-one correspondence.
Algorithm 1 (CausalDiffAE training). Input: (image, label) pairs $(x_0, y)$. Output: learned parameters $\theta, \phi$.
4.3 Causal Diffusion Decoder
We use a conditional DDIM decoder that takes as input the pair of latent variables $(z, x_T)$ to generate the output image. We approximate the inference distribution $q(x_{t-1} \mid x_t, x_0)$ by parameterizing the probabilistic decoder via a conditional DDIM $p_\theta(x_{t-1} \mid x_t, z)$. With DDIM, the generative process becomes completely deterministic except for the initial noise $x_T$. Similar to [24], we define the joint distribution of the reverse generative process as follows:
$$p_\theta(x_{0:T} \mid z) = p(x_T) \prod_{t=1}^{T} p_\theta(x_{t-1} \mid x_t, z) \tag{14}$$

$$p_\theta(x_{t-1} \mid x_t, z) = \begin{cases} \mathcal{N}\big(f_\theta(x_1, 1, z),\, \mathbf{0}\big) & \text{if } t = 1 \\ q\big(x_{t-1} \mid x_t, f_\theta(x_t, t, z)\big) & \text{otherwise} \end{cases} \tag{15}$$

where $f_\theta$ is parameterized by a noise prediction network $\epsilon_\theta$ (i.e., a UNet [5]) as follows:

$$f_\theta(x_t, t, z) = \frac{1}{\sqrt{\bar{\alpha}_t}}\Big(x_t - \sqrt{1 - \bar{\alpha}_t}\; \epsilon_\theta(x_t, t, z)\Big) \tag{16}$$
Note that in Eq. (14), $u$ is omitted since $z$ already captures all the information in the noise encoding. By leveraging the reparameterization trick, we can optimize the following mean squared error between noise terms

$$L_{\text{DDPM}} = \mathbb{E}_{x_0, t, \epsilon}\Big[\big\| \epsilon - \epsilon_\theta(x_t, t, z) \big\|^2\Big] \tag{17}$$

where $\epsilon \sim \mathcal{N}(0, I)$ and $x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon$.
4.4 Learning Objective
To ensure the causal representation is disentangled, we incorporate label information as a prior in the variational objective to aid in learning semantic factors and to provide identifiability guarantees [13]. We define the following joint loss objective:

$$L = L_{\text{DDPM}} + \gamma\, D_{\mathrm{KL}}\big(q_\phi(z \mid x_0)\, \big\|\, p(z \mid y)\big) \tag{18}$$

where $\gamma$ is a regularization hyperparameter similar to the bottleneck parameter in $\beta$-VAEs [9], and the alignment prior over latent variables is defined as the following exponential family (Gaussian) distribution

$$p(z \mid y) = \prod_{i=1}^{n} \mathcal{N}\big(z_i;\, \lambda_1(y_i),\, \lambda_2(y_i)\big) \tag{19}$$

where $\lambda_1$ and $\lambda_2$ are functions that estimate the mean and variance of the Gaussian, respectively. Intuitively, this prior ensures that the learned factors are one-to-one mapped to indicators of the underlying ground-truth factors. DiffAE requires training a latent DDIM in the latent space of the pre-trained autoencoder to enable sampling of the semantic representation. In contrast, CausalDiffAE is formulated as a variational objective with a stochastic encoder, so we can sample the representation directly from the defined prior without training a separate diffusion model in the latent space. The training procedure for CausalDiffAE is outlined in Algorithm 1; a sketch of one training step follows below. See Appendix A for a derivation of the ELBO, and Appendix B for a detailed discussion of the connection of our diffusion objective to score-based generative models [33].
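The following sketch outlines one training step in the spirit of Algorithm 1, reusing the q_sample and NeuralSCM sketches above. The enc, eps_model, and prior networks are hypothetical stand-ins, and for simplicity the Gaussian KL of Eq. (18) is computed in closed form between the encoder posterior and the label-dependent prior.

```python
import torch
import torch.nn.functional as F

def train_step(x0, y, enc, scm, eps_model, prior, optimizer, gamma=1.0, T=1000):
    """One optimization step on the joint objective of Eq. (18)."""
    mu_u, logvar_u = enc(x0)                                    # stochastic encoder
    u = mu_u + torch.randn_like(mu_u) * (0.5 * logvar_u).exp()  # reparameterize
    z = scm(u)                                                  # noise -> causal factors

    t = torch.randint(0, T, (x0.size(0),), device=x0.device)
    x_t, eps = q_sample(x0, t)                                  # forward diffusion, Eq. (4)
    loss_ddpm = F.mse_loss(eps_model(x_t, t, z), eps)           # Eq. (17)

    # Closed-form Gaussian KL against the label-dependent alignment prior, Eq. (19)
    mu_y, logvar_y = prior(y)
    kl = 0.5 * (logvar_y - logvar_u
                + (logvar_u.exp() + (mu_u - mu_y) ** 2) / logvar_y.exp() - 1.0)
    loss = loss_ddpm + gamma * kl.sum(dim=1).mean()             # Eq. (18)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```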
Algorithm 2 (counterfactual generation). Input: factual sample $x^{f}$, intervention target set $\mathcal{I}$ with intervention values $z^{*}_{\mathcal{I}}$, noise predictor $\epsilon_\theta$, encoder $q_\phi$. Output: counterfactual sample $x^{cf}$.
4.5 Counterfactual Generation
A fundamental property of causal models is the ability to perform interventions and observe changes to a system. In generative models, this enables the sampling of counterfactual data. Given a pre-trained CausalDiffAE, we can controllably manipulate any factor of variation, propagate the causal effects to its descendants, and perform reverse diffusion to sample from the counterfactual distribution. Algorithm 2 shows the process of generating counterfactuals from a trained CausalDiffAE, where $x^{f}$ refers to the factual observation and $x^{cf}$ refers to the generated counterfactual sample. To generate counterfactual instances, we first encode the high-dimensional observation into a noise encoding $u$ (abduction) and transform it into causal latent variables $z$. Then, we intervene on a desired variable and propagate the causal effects via the neural mechanisms to yield the intervened representation $\tilde{z}$. We utilize the DDIM sampling algorithm so that the stochastic noise $x_T$ is a deterministic encoding, which enables semantic manipulations. Finally, we decode using DDIM conditioned on $\tilde{z}$ to obtain a counterfactual $x^{cf}$. In lines 12-13, we use the DDIM non-Markovian deterministic generative process to generate counterfactual instances as follows:

$$x_{t-1} = \sqrt{\bar{\alpha}_{t-1}}\; f_\theta(x_t, t, \tilde{z}) + \sqrt{1 - \bar{\alpha}_{t-1}}\; \epsilon_\theta(x_t, t, \tilde{z}) \tag{20}$$
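A sketch of the abduction-intervention-decoding loop of Algorithm 2 is shown below, reusing ddim_step from the background sketch. Here scm.intervene is an assumed helper that fixes the target factor and re-propagates effects through descendant mechanisms only; all other names follow the earlier sketches.

```python
import torch

@torch.no_grad()
def generate_counterfactual(x0, enc, scm, eps_model, alphas_bar,
                            target: int, value: float, timesteps):
    """Abduction -> intervention -> DDIM decoding, following Algorithm 2."""
    mu_u, _ = enc(x0)                           # abduction: infer noise u
    z = scm(mu_u)                               # factual causal factors
    z_cf = scm.intervene(mu_u, target, value)   # do(z_target = value), propagate

    # Deterministic DDIM encoding x_0 -> x_T under the factual representation z
    x = x0
    for t, t_next in zip(timesteps[:-1], timesteps[1:]):   # increasing timesteps
        x = ddim_step(x, t, t_next, lambda xt, tt: eps_model(xt, tt, z), alphas_bar)

    # Reverse DDIM decoding x_T -> x_0^cf conditioned on z_cf, as in Eq. (20)
    for t, t_next in zip(timesteps[:0:-1], timesteps[-2::-1]):  # decreasing
        x = ddim_step(x, t, t_next, lambda xt, tt: eps_model(xt, tt, z_cf), alphas_bar)
    return x
```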
Conditioning vs. Intervening. When we study causal generative models, we utilize the intervention operation, which is fundamentally different from conditioning. When we condition, we narrow our scope to a specific subgroup of the data based on the conditioning variable. Interventions are population-level operations that fix a variable's value (rendering it independent of its parents) in order to determine downstream causal effects. We emphasize that, under the intervention operation, causal models are robust to distribution shifts and can generate data outside the support of the training distribution.
4.6 Weak Supervision
To reduce the reliance on labeled data, we train CausalDiffAE with a guidance paradigm at the representation level [5], inspired by classifier-free guidance [10].
Training. In the limited labeled-data regime, we train two models: an unconditional denoising diffusion model parameterized by the noise estimator $\epsilon_\theta(x_t, t)$ and a representation-conditioned model parameterized by $\epsilon_\theta(x_t, t, z)$. We use a single neural network to parameterize both models, where the unconditional model is trained only on the unlabeled data (i.e., with $z = \emptyset$).
Generation. The counterfactual generation procedure in lines 12-13 of Algorithm 2 can be modified to generate counterfactuals with a guidance strength $\omega$, which in our case can be interpreted as controlling the strength of the intervention on the causal variable used to generate the counterfactual. The modified noise estimate during generation is the following linear combination of the conditional and unconditional estimates

$$\tilde{\epsilon}_\theta(x_t, t, \tilde{z}) = \omega\, \epsilon_\theta(x_t, t, \tilde{z}) + (1 - \omega)\, \epsilon_\theta(x_t, t) \tag{21}$$

where $\tilde{z}$ is the set of latent causal factors after an intervention. The original utility of the classifier-free paradigm was to trade sample diversity for higher-quality images without needing classifier gradients, so $\omega$ controls the trade-off between quality and diversity. In our case, we care about generating high-quality counterfactual data. Intuitively, a higher $\omega$ implies a stronger effect of the intervention on the generated counterfactual, since the conditional model is sensitive to interventions; as $\omega$ decreases, the unconditional model dilutes the effect of the intervention-sensitive model. In this sense, the sampling mechanism can be used to evaluate the causal strength of interventions. We find that the weak supervision paradigm enables (1) more efficient training with a weaker supervision signal, and (2) fine-grained control over generated counterfactuals, as sketched below.
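A sketch of the guided estimate in Eq. (21) follows; passing z = None as a null token for the unconditional branch is an implementation assumption.

```python
import torch

def guided_eps(eps_model, x_t, t, z_cf, omega: float):
    """Convex combination of conditional and unconditional noise estimates, Eq. (21).
    omega = 0 ignores the intervention; omega = 1 uses the fully conditional score."""
    eps_cond = eps_model(x_t, t, z_cf)    # intervention-sensitive estimate
    eps_uncond = eps_model(x_t, t, None)  # z = None denotes the null token
    return omega * eps_cond + (1.0 - omega) * eps_uncond
```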
5 Experiments
5.1 Empirical Setting
Datasets. We experiment on three datasets. MorphoMNIST [4] is created by imposing a two-variable SCM on the original MNIST dataset to generate morphological transformations, where the thickness of a digit is a cause of its intensity [22], as shown in Figure 2(a). The Pendulum dataset [35] consists of images of a causal system comprising a light source, a pendulum, and a shadow; the light position and the pendulum angle determine the length and position of the shadow, as shown in the causal graph in Figure 2(b). We also use CausalCircuit [3], a complex 3D robotics dataset in which a robot arm moves around to turn on red, green, or blue lights; the causal graph of this system is shown in Figure 3.
Baselines. CausalVAE [35] is a VAE-based causal representation learning framework that models causal variables using a linear SCM and enables counterfactual generation at inference time through interventions on causal variables. The class-conditional diffusion model (CCDM) [20] is a conditional diffusion model that uses class labels as the conditioning signal in reverse diffusion and is thus capable of generating new samples determined by a discrete or continuous set of labels $y$. DiffAE [24] is a diffusion model that learns manipulable and semantically meaningful latent codes. However, it learns an arbitrary representation in an unsupervised fashion and does not disentangle the latent space; manipulations are performed using a post-hoc classifier for linear interpolation, so the learned representation is not ideal for causal interventions. For a fair comparison in counterfactual generation, we modify its objective to disentangle the latent space by incorporating label information in a prior that regularizes the posterior; we call this extension DisDiffAE. We use DisDiffAE as a baseline to evaluate counterfactual generation and DiffAE to evaluate disentanglement.
Metrics. We primarily use two quantitative metrics to evaluate the performance of our approach. To evaluate the disentanglement of the learned representations, we use the DCI disentanglement metric [7]. A high DCI score also suggests effective controllable generation: in the context of a causal representation, it means we can intervene on latent codes in an isolated fashion without entanglements (i.e., without two factors being encoded in the same latent code). To quantitatively evaluate generated counterfactuals, we adopt the effectiveness metric from Melistas et al. [18], which evaluates how successful the performed intervention was at generating the counterfactual. We train anti-causal predictors via convolutional regressors on the training dataset for each continuous causal variable. We then report the mean absolute error (MAE) between the values predicted from the generated counterfactual and the true values of the counterfactual. This metric captures how controllable the factors are and how accurate the generated counterfactuals are.
For details about the datasets, implementation, metrics, and computational requirements, see Appendix D. Our code is available at https://github.com/Akomand/CausalDiffAE.
Table 1: DCI disentanglement scores.

| Dataset | Model | DCI |
|---|---|---|
| MorphoMNIST | CausalVAE | |
| | DiffAE | |
| | CausalDiffAE | |
| Pendulum | CausalVAE | |
| | DiffAE | |
| | CausalDiffAE | |
| CausalCircuit | CausalVAE | |
| | DiffAE | |
| | CausalDiffAE | |
5.2 Disentanglement of Latent Space
We compare the disentanglement of CausalDiffAE with other baseline models, as shown in Table 1. We observe that diffusion-based representation learning objectives coupled with a suitable prior can better disentangle latent variables compared to VAE-based models. We do not include CCDM as a baseline here since it does not produce a representation to be evaluated. Compared to CausalVAE, the diffusion-based decoder in CausalDiffAE disentangles the semantic factors of variation to a much greater degree. Thus, we can perform interventions on causal variables in isolation and observe their downstream effects. We also note that DiffAE [24] does not learn a disentangled latent space since the semantic representation learned is arbitrary. To perform controllable manipulations with DiffAE, a post-hoc classifier must be trained to guide the sampling process. CausalDiffAE offers more precise control over learned factors through the disentanglement objective without the need to train additional classifiers.
5.3 Controllable Counterfactual Generation
Qualitative Evaluation. We show that CausalDiffAE produces much more realistic counterfactual samples than the acausal baselines and its VAE counterpart, CausalVAE. We attribute this to the diffusion process, which is better able to capture causally relevant information along with low-level stochastic variation.
Figure 2(a) shows the counterfactual generation results for the MorphoMNIST dataset. CausalVAE can generate counterfactual images after intervening on either thickness or intensity, but the accuracy and quality of its counterfactuals are far lower than CausalDiffAE's. For instance, intervening to lower the thickness does not lower the intensity, and intervening to lower the intensity changes the thickness of the digit. CCDM fails to produce samples consistent with the underlying causal model: intervening on the intensity produces a sample with increased thickness, which is expected from a conditioning perspective, since high-intensity digits tend to be thicker in the training distribution. For DisDiffAE, increasing the thickness does not influence the intensity.
Figure 2(b) shows counterfactual generation results for the Pendulum causal system. Upon interventions, images generated by CCDM are not consistent with the causal model; for example, intervening on the light position changes both the light position and the pendulum angle. DisDiffAE produces images where we can control one factor at a time, but does not reflect causal effects; for example, changing the angle of the pendulum does not accurately change the shadow length and position. CausalDiffAE generates higher-quality counterfactuals that are consistent with the causal model. Specifically, intervening on the pendulum angle or light position changes the shadow length and position accurately. On the other hand, interventions on child variables leave the parents unchanged.
Figure 3 shows counterfactuals from the CausalCircuit dataset. CausalVAE generates inaccurate counterfactuals in many scenarios (e.g., an intervention on the blue light intensity changes the intensity of all the other lights). For CCDM, moving the robot arm over the green light fails to turn the light on; furthermore, manipulating the intensity of the other lights affects the position of the robot arm. DisDiffAE enables control over the generative factors but does not capture causal effects: for example, moving the robot arm over the green light does not turn it on. Counterfactuals generated by CausalDiffAE are consistent with the causal system. For example, moving the robot arm over the green button turns the light on and, as a result, also turns on the red light, a downstream child variable. Intervening on the blue or green light slightly increases the intensity of the red light, and intervening on the red light leaves all parent variables unchanged. For additional counterfactual generation results, see Appendix D.5.
Table 2: Effectiveness (MAE) on MorphoMNIST under interventions on thickness and intensity.

| Factor | Model | do(thickness) | do(intensity) |
|---|---|---|---|
| Thickness | CausalVAE | | |
| | DisDiffAE | | |
| | CausalDiffAE | | |
| Intensity | CausalVAE | | |
| | DisDiffAE | | |
| | CausalDiffAE | | |
Table 3: Effectiveness (MAE) on the Pendulum dataset under interventions on each causal factor.

| Factor | Model | do(angle) | do(light pos.) | do(shadow len.) | do(shadow pos.) |
|---|---|---|---|---|---|
| Angle | CausalVAE | | | | |
| | DisDiffAE | | | | |
| | CausalDiffAE | | | | |
| LightPos | CausalVAE | | | | |
| | DisDiffAE | | | | |
| | CausalDiffAE | | | | |
| ShadowLen | CausalVAE | | | | |
| | DisDiffAE | | | | |
| | CausalDiffAE | | | | |
| ShadowPos | CausalVAE | | | | |
| | DisDiffAE | | | | |
| | CausalDiffAE | | | | |

* Standard error is roughly in the same range for all averages.
Quantitative Evaluation. We quantitatively show using the effectiveness metric that CausalDiffAE generates counterfactuals that are both accurate and realistic. We perform random interventions drawn from a uniform distribution over the test dataset for each causal variable. We find that CausalDiffAE almost always outperforms the baselines on the effectiveness metric for all causal factors, as shown in Tables 2 and 3. Specifically, on MorphoMNIST, interventions on thickness produce counterfactuals that accurately reflect both the thickness and intensity values: when we intervene on thickness, the intensity MAE is lower for CausalDiffAE than for the baselines, indicating that the generated counterfactual has an intensity consistent with the causal effect of thickness on intensity. When we intervene on intensity, the thickness MAE is lower for CausalDiffAE than for the baselines, suggesting that the generated counterfactual retains its original thickness upon intervention on intensity. For the Pendulum dataset, we observe a similar phenomenon, where interventions on causal factors and their downstream effects are accurately captured in the generated counterfactuals. We do not evaluate effectiveness on the CausalCircuit dataset since we do not have access to the generative process used to obtain the factors. We compute the average effectiveness over multiple runs with different random seeds. Our results strongly imply that the generated counterfactuals closely match the true counterfactuals.
5.4 Case Study: Weak Supervision Results
Unlike VAE-based approaches, the weak supervision paradigm of diffusion models reduces the need for full label supervision. We study the weak supervision scenario on the MorphoMNIST dataset by jointly training a representation-conditioned model and an unconditional model, where the conditioned (labeled) split is far smaller than the unconditioned (unlabeled) split. We have two main motivations for doing this: (1) it greatly reduces the need for fully labeled datasets, and (2) it enables granular control over generated counterfactuals. The CausalDiffAE model trained in this regime on MorphoMNIST still achieves a high DCI score, which suggests that even under strictly limited label supervision, CausalDiffAE learns disentangled representations. We also empirically show that the guidance parameter $\omega$ controls the strength of the intervention on the generated counterfactual. Figure 4 shows MNIST digits generated using the joint estimated score from the reduced-supervision version of CausalDiffAE. Interventions have virtually no effect when sampling with $\omega = 0$; for intermediate values of $\omega$, we see a stronger effect of the intervention on the thickness and intensity of the digit; and for the fully conditional score ($\omega = 1$), the intervention acts the strongest. Thus, varying $\omega$ in the range $[0, 1]$ can be interpreted as generating a range of different counterfactuals.
6 Conclusion
In this work, we propose CausalDiffAE, a diffusion-based framework for causal representation learning and counterfactual generation. We propose a causal encoding mechanism that maps images to causally related factors and learn the causal mechanisms among factors via neural networks. We formulate a variational diffusion-based objective that enforces the disentanglement of the latent space to enable latent space manipulations, and we propose a DDIM-based counterfactual generation algorithm subject to interventions. For limited supervision scenarios, we propose a weak supervision extension of our model, which jointly learns an unconditional and a conditional model; this objective also enables granular control over generated counterfactuals. We empirically demonstrate the capability of our model through both qualitative and quantitative evaluations. Future work includes exploring counterfactual generation in text-to-image diffusion models.
Acknowledgements
This work is supported in part by National Science Foundation under awards 1910284, 1946391 and 2147375, the National Institute of General Medical Sciences of National Institutes of Health under award P20GM139768, and the Arkansas Integrative Metabolic Research Center at University of Arkansas.
References
- Augustin et al. [2022] M. Augustin, V. Boreiko, F. Croce, and M. Hein. Diffusion visual counterfactual explanations. In Advances in Neural Information Processing Systems, 2022.
- Bengio et al. [2013] Y. Bengio, A. Courville, and P. Vincent. Representation learning: A review and new perspectives. IEEE Transactions on Pattern Analysis and Machine Intelligence, 35(8):1798–1828, 2013.
- Brehmer et al. [2022] J. Brehmer, P. D. Haan, P. Lippe, and T. Cohen. Weakly supervised causal representation learning. In Advances in Neural Information Processing Systems, 2022.
- Castro et al. [2019] D. C. Castro, J. Tan, B. Kainz, E. Konukoglu, and B. Glocker. Morpho-MNIST: Quantitative assessment and diagnostics for representation learning. Journal of Machine Learning Research, 20(178), 2019.
- Dhariwal and Nichol [2021] P. Dhariwal and A. Q. Nichol. Diffusion models beat GANs on image synthesis. In Advances in Neural Information Processing Systems, 2021.
- Eastwood and Williams [2018] C. Eastwood and C. K. I. Williams. A framework for the quantitative evaluation of disentangled representations. In International Conference on Learning Representations, 2018.
- Eastwood et al. [2023] C. Eastwood, A. L. Nicolicioiu, J. V. Kügelgen, A. Kekić, F. Träuble, A. Dittadi, and B. Schölkopf. DCI-ES: An extended disentanglement framework with connections to identifiability. In The Eleventh International Conference on Learning Representations, 2023.
- Goodfellow et al. [2014] I. Goodfellow, J. Pouget-Abadie, M. Mirza, B. Xu, D. Warde-Farley, S. Ozair, A. Courville, and Y. Bengio. Generative adversarial nets. In Advances in Neural Information Processing Systems, 2014.
- Higgins et al. [2017] I. Higgins, L. Matthey, A. Pal, C. Burgess, X. Glorot, M. Botvinick, S. Mohamed, and A. Lerchner. beta-VAE: Learning basic visual concepts with a constrained variational framework. In International Conference on Learning Representations, 2017.
- Ho and Salimans [2021] J. Ho and T. Salimans. Classifier-free diffusion guidance. In NeurIPS 2021 Workshop on Deep Generative Models and Downstream Applications, 2021.
- Ho et al. [2020] J. Ho, A. Jain, and P. Abbeel. Denoising diffusion probabilistic models. In Advances in Neural Information Processing Systems, 2020.
- Karimi Mamaghan et al. [2024] A. M. Karimi Mamaghan, A. Dittadi, S. Bauer, K. H. Johansson, and F. Quinzan. Diffusion-based causal representation learning. Entropy, 26(7), 2024.
- Khemakhem et al. [2020] I. Khemakhem, D. Kingma, R. Monti, and A. Hyvarinen. Variational autoencoders and nonlinear ica: A unifying framework. In Proceedings of the Twenty Third International Conference on Artificial Intelligence and Statistics, 2020.
- Kingma and Welling [2014] D. P. Kingma and M. Welling. Auto-encoding variational bayes. In International Conference on Learning Representations, 2014.
- Kocaoglu et al. [2018] M. Kocaoglu, C. Snyder, A. G. Dimakis, and S. Vishwanath. CausalGAN: Learning causal implicit generative models with adversarial training. In International Conference on Learning Representations, 2018.
- Komanduri et al. [2024] A. Komanduri, X. Wu, Y. Wu, and F. Chen. From identifiable causal representations to controllable counterfactual generation: A survey on causal generative modeling. Transactions on Machine Learning Research, 2024.
- Liu et al. [2022] X. Liu, P. Sanchez, S. Thermos, A. Q. O’Neil, and S. A. Tsaftaris. Learning disentangled representations in the imaging domain. Medical Image Analysis, 80:102516, 2022.
- Melistas et al. [2024] T. Melistas, N. Spyrou, N. Gkouti, P. Sanchez, A. Vlontzos, G. Papanastasiou, and S. A. Tsaftaris. Benchmarking counterfactual image generation. arXiv preprint arXiv:2403.20287, 2024.
- Mittal et al. [2023] S. Mittal, K. Abstreiter, S. Bauer, B. Schölkopf, and A. Mehrjou. Diffusion based representation learning. In Proceedings of the 40th International Conference on Machine Learning, 2023.
- Nichol and Dhariwal [2021] A. Q. Nichol and P. Dhariwal. Improved denoising diffusion probabilistic models. In Proceedings of the 38th International Conference on Machine Learning, 2021.
- Pandey et al. [2022] K. Pandey, A. Mukherjee, P. Rai, and A. Kumar. DiffuseVAE: Efficient, controllable and high-fidelity generation from low-dimensional latents. Transactions on Machine Learning Research, 2022.
- Pawlowski et al. [2020] N. Pawlowski, D. Coelho de Castro, and B. Glocker. Deep structural causal models for tractable counterfactual inference. In Advances in Neural Information Processing Systems, 2020.
- Pearl [2009] J. Pearl. Causality. Cambridge University Press, Cambridge, UK, 2 edition, 2009. ISBN 978-0-521-89560-6.
- Preechakul et al. [2022] K. Preechakul, N. Chatthee, S. Wizadwongsa, and S. Suwajanakorn. Diffusion autoencoders: Toward a meaningful and decodable representation. In IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2022.
- Radford et al. [2021] A. Radford, J. W. Kim, C. Hallacy, A. Ramesh, G. Goh, S. Agarwal, G. Sastry, A. Askell, P. Mishkin, J. Clark, G. Krueger, and I. Sutskever. Learning transferable visual models from natural language supervision. In Proceedings of the 38th International Conference on Machine Learning, 2021.
- Ramesh et al. [2022] A. Ramesh, P. Dhariwal, A. Nichol, C. Chu, and M. Chen. Hierarchical text-conditional image generation with clip latents. arXiv preprint arXiv:2204.06125, 2022.
- Rombach et al. [2022] R. Rombach, A. Blattmann, D. Lorenz, P. Esser, and B. Ommer. High-resolution image synthesis with latent diffusion models. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), 2022.
- Saharia et al. [2022] C. Saharia, W. Chan, S. Saxena, L. Li, J. Whang, E. Denton, S. K. S. Ghasemipour, B. K. Ayan, S. S. Mahdavi, R. G. Lopes, T. Salimans, J. Ho, D. J. Fleet, and M. Norouzi. Photorealistic text-to-image diffusion models with deep language understanding. In Advances in Neural Information Processing Systems, 2022.
- Sanchez et al. [2022] P. Sanchez, J. P. Voisey, T. Xia, H. I. Watson, A. Q. ONeil, and S. A. Tsaftaris. Causal machine learning for healthcare and precision medicine. Royal Society Open Science, 2022.
- Scholkopf et al. [2021] B. Scholkopf, F. Locatello, S. Bauer, N. R. Ke, N. Kalchbrenner, A. Goyal, and Y. Bengio. Toward Causal Representation Learning. Proceedings of the IEEE, 109:612–634, May 2021. ISSN 0018-9219, 1558-2256.
- Sohl-Dickstein et al. [2015] J. Sohl-Dickstein, E. Weiss, N. Maheswaranathan, and S. Ganguli. Deep unsupervised learning using nonequilibrium thermodynamics. In Proceedings of the 32nd International Conference on Machine Learning, 2015.
- Song et al. [2021a] J. Song, C. Meng, and S. Ermon. Denoising diffusion implicit models. In International Conference on Learning Representations, 2021a.
- Song et al. [2021b] Y. Song, J. Sohl-Dickstein, D. P. Kingma, A. Kumar, S. Ermon, and B. Poole. Score-based generative modeling through stochastic differential equations. In International Conference on Learning Representations, 2021b.
- Vowels et al. [2022] M. J. Vowels, N. C. Camgoz, and R. Bowden. D’ya like dags? a survey on structure learning and causal discovery. ACM Computing Surveys, 2022.
- Yang et al. [2021] M. Yang, F. Liu, Z. Chen, X. Shen, J. Hao, and J. Wang. Causalvae: Disentangled representation learning via neural structural causal models. In IEEE Conference on Computer Vision and Pattern Recognition, 2021.
- Zheng et al. [2018] X. Zheng, B. Aragam, P. K. Ravikumar, and E. P. Xing. Dags with no tears: Continuous optimization for structure learning. In Advances in Neural Information Processing Systems, 2018.
Appendices
Appendix A Derivation of ELBO
Given a high-dimensional input image $x_0$, an auxiliary weak supervision signal $y$, a latent noise encoding $u$, a latent causal representation $z$, and the sequence of latent variables $x_{1:T}$ produced by the diffusion process, the CausalDiffAE generative process can be factorized as follows:

$$p_\theta(x_{0:T}, z, u \mid y) = p_\theta(x_{0:T} \mid z)\; p(z \mid u)\; p(u) \tag{22}$$

where $\theta$ are the parameters of the reverse process of the conditional diffusion model. The log-likelihood of the input data distribution can be written as:

$$\log p(x_0 \mid y) = \log \int p_\theta(x_{0:T}, z, u \mid y)\; dx_{1:T}\; dz\; du \tag{23}$$

The joint posterior distribution is intractable, so we approximate it using a variational distribution, which can be factorized into the following conditional distributions

$$q_\phi(z, u \mid x_0, y) = q_\phi(u \mid x_0)\; q(z \mid u) \tag{24}$$
where are the parameters of the variational encoder network. Since the likelihood of the data is intractable, we can approximate it by maximizing the following evidence lower bound (ELBO):
(25) | ||||
(26) | ||||
(27) | ||||
(28) | ||||
(29) |
In the learning process, we minimize the negative of the derived ELBO. We simplify the first term by using the $\epsilon$-parameterization to optimize the representation-conditioned DDPM loss. Further, since $z$ and $u$ are one-to-one mapped, we can split the joint conditional distribution into separate conditional distributions. Thus, we have the following final objective for CausalDiffAE:

$$L = \mathbb{E}_{x_0, t, \epsilon}\Big[\big\| \epsilon - \epsilon_\theta(x_t, t, z) \big\|^2\Big] + \gamma\, D_{\mathrm{KL}}\big(q_\phi(z \mid x_0)\; \big\|\; p(z \mid y)\big) \tag{31}$$
Appendix B Connection to Score-based Generative Models
Diffusion models can also be represented as stochastic differential equations (SDEs) [33] to model continuous-time perturbations. Specifically, the forward diffusion process can be modeled as the solution to an SDE on a continuous time domain with stochastic trajectories:

$$dx = f(x, t)\, dt + g(t)\, dw \tag{32}$$

where $w$ is the standard Wiener process, $f(\cdot, t)$ is a vector-valued function known as the drift coefficient of $x(t)$, and $g(t)$ is a scalar function known as the diffusion coefficient of $x(t)$. The drift and diffusion coefficients can be considered as the mean and variance of the noise perturbations in the diffusion process, respectively. The reverse diffusion process can be modeled by the solution to the reverse-time SDE of Eq. (32), which can be derived analytically as:

$$dx = \big[f(x, t) - g(t)^2\, \nabla_x \log p_t(x)\big]\, dt + g(t)\, d\bar{w} \tag{33}$$

where $\bar{w}$ is the standard Wiener process in reverse time and $\nabla_x \log p_t(x)$ is the score of the data distribution at timestep $t$. Once we know the score of the marginal distribution for all timesteps $t$, we can derive the reverse diffusion process from Eq. (33).
Song et al. [33] showed that the denoising diffusion probabilistic model (DDPM) is a discretization of the following variance-preserving SDE (VP-SDE)

$$dx = -\tfrac{1}{2}\beta(t)\, x\, dt + \sqrt{\beta(t)}\, dw \tag{34}$$
Thus, learning a noise prediction network by minimizing the MSE in diffusion probabilistic models is equivalent to approximating the score of the data distribution in the SDE formulation. From a score-based perspective, we aim to minimize the following conditional denoising score-matching form of our objective

$$\mathbb{E}_{t}\Big\{\lambda(t)\, \mathbb{E}_{x_0}\, \mathbb{E}_{x_t \mid x_0}\Big[\big\| s_\theta(x_t, t, z) - \nabla_{x_t} \log p(x_t \mid x_0) \big\|^2\Big]\Big\} \tag{35}$$

where $s_\theta$ approximates the score of the data distribution conditioned on $z$ and $\lambda(t)$ is a positive weighting function. Differential equations are a natural formalism for modeling the physical mechanisms underlying natural phenomena [30]. In the SDE formulation, the causal variables are used to denoise the high-dimensional data, which is modeled as a reverse-time stochastic trajectory. We can interpret this idea as modeling the dynamics of high-dimensional systems by incorporating causal information: as opposed to simply learning an arbitrary latent representation, a disentangled causal representation encodes the causal information that the denoising process can use to reconstruct causally relevant features in high-dimensional data.
Appendix C Discussion on Causal Discovery
In this work, we assume the latent causal structure is known, since we focus on counterfactual generation. In principle, our framework can be combined with causal structure learning methods such as NOTEARS [36] by adding penalty terms to the loss objective that enforce sparsity and acyclicity, as follows

$$L_{\text{total}} = L + \lambda_1\, h(A) + \lambda_2\, \|A\|_1 \tag{36}$$

where $h(A) = \mathrm{tr}\big(e^{A \circ A}\big) - n$ is the acyclicity constraint and $\|A\|_1$ enforces the sparsity of the DAG. A smooth relaxation of the sparsity penalty can alternatively be used to ensure a fully differentiable objective. Similar to [36], we can utilize the augmented Lagrangian to optimize the joint loss objective. Additionally, other causal discovery algorithms could be used heuristically under a variety of different assumptions [34]. We look to explore this direction in future work. A sketch of the penalty computation follows.
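The penalty terms of Eq. (36) can be computed as below; the weighting values are illustrative.

```python
import torch

def dag_penalty(A: torch.Tensor, lam1: float = 1.0, lam2: float = 0.1):
    """NOTEARS-style penalties of Eq. (36): h(A) = tr(exp(A o A)) - n and ||A||_1."""
    n = A.size(0)
    h = torch.matrix_exp(A * A).diagonal().sum() - n   # zero iff A is acyclic
    return lam1 * h + lam2 * A.abs().sum()
```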
Appendix D Experiment Details
D.1 Dataset Details
MorphoMNIST. The MorphoMNIST dataset [4] is produced by applying morphological transformations to the original MNIST handwritten digit dataset. The digits can be described by measurable shape attributes such as stroke thickness, stroke length, width, height, and slant. Pawlowski et al. [22] impose a two-variable SCM to generate the morphological transformations, where stroke thickness is a cause of the brightness of each digit; that is, thicker digits are often brighter, whereas thinner digits are dimmer. The data-generating process is as follows
$$t = f_T(u_T) = 0.5 + u_T, \qquad u_T \sim \Gamma(10, 5)$$
$$i = f_I(u_I; t) = 191\, \sigma(0.5\, u_I + 2t - 5) + 64, \qquad u_I \sim \mathcal{N}(0, 1) \tag{37}$$
$$x = f_X(u_X; t, i)$$

where $x$ is the resulting image, the $u_k$ are the exogenous noise variables for each causal variable, and $\sigma$ is the logistic sigmoid.
Pendulum. The Pendulum dataset [35] consists of images describing a physical system of a pendulum and a light source that together determine the length and position of a shadow. The causal variables of interest are the angle of the pendulum, the position of the light source, the length of the shadow, and the position of the shadow; the shadow length and position are computed from the pendulum angle and light position via the geometry of light projection.
CausalCircuit. The CausalCircuit dataset was created by [3] to support research in causal representation learning. It consists of images generated from four ground-truth latent causal variables: robot arm position, red light intensity, green light intensity, and blue light intensity. The images show a robot arm interacting with a system of buttons and lights and are rendered using an open-source physics engine. The original dataset consists of pairs of images before and after an intervention has taken place; for the purposes of this work, we only utilize observational data from either the before or the after system. In the data-generating process, the pressed state of each button depends on how far from the button's center it is touched by the robot arm, and the intensities of the blue, green, and red lights are determined by the pressed states and the arm position according to the causal graph in Figure 3.
D.2 Implementation Details
We use the same network architectures and hyperparameters as prior work on diffusion models [11, 5, 20]. We set the causal latent variable dimension large enough to capture causally relevant information. The representation-conditioned noise predictor is parameterized by a UNet with an attention mechanism. Similar to [11], we use a linear schedule for the variance parameter $\beta_t$ during training. For all three datasets, we start the bottleneck parameter $\gamma$ at a small value and linearly increase it throughout training to its final value. Table 4 summarizes the settings for each dataset.
Table 4: Hyperparameter settings for each dataset.

| Parameter | MorphoMNIST | Pendulum | CausalCircuit |
|---|---|---|---|
| Batch size | | | |
| Base channels | | | |
| Channel multipliers | | | |
| Training set size | | | |
| Test set size | | | |
| Image resolution | | | |
| Num. causal variables | 2 | 4 | 4 |
| $z$ size | | | |
| $\beta$ scheduler | Linear | Linear | Cosine |
| Learning rate | | | |
| Optimizer | Adam | Adam | Adam |
| Diffusion steps | | | |
| Training iterations | | | |
| Diffusion loss | MSE | MSE | MSE |
| Sampling | DDIM | DDIM | DDIM |
| Stride | | | |
| Bottleneck $\gamma$ | | | |
D.3 Metrics Details
DCI Disentanglement [6]. The DCI disentanglement score quantifies the degree to which a representation disentangles the underlying factors of variation, with each latent variable capturing at most one generative factor. Let $P_{ij}$ be the probability of latent code $c_i$ being a strong predictor of generative factor $z_j$. Then, the disentanglement score for code $c_i$ is defined as

$$D_i = 1 - H_K(P_{i\cdot}), \qquad H_K(P_{i\cdot}) = -\sum_{k=1}^{K} P_{ik} \log_K P_{ik} \tag{38}$$

If $c_i$ is a strong predictor for only a single generative factor, $D_i = 1$. If $c_i$ is equally important in predicting all generative factors, $D_i = 0$. Let $\rho_i$ be the relative latent code importances. The total disentanglement score is a weighted average of the individual $D_i$:

$$D = \sum_i \rho_i D_i \tag{39}$$
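A sketch of computing the DCI disentanglement score from a code-by-factor importance matrix R (e.g., feature importances of per-factor regressors) follows.

```python
import numpy as np

def dci_disentanglement(R: np.ndarray) -> float:
    """DCI disentanglement from a (codes x factors) importance matrix R, Eqs. (38)-(39)."""
    K = R.shape[1]
    P = R / (R.sum(axis=1, keepdims=True) + 1e-12)          # P_ij for each code c_i
    H = -(P * np.log(P + 1e-12) / np.log(K)).sum(axis=1)    # entropy with base K
    D = 1.0 - H                                             # per-code score, Eq. (38)
    rho = R.sum(axis=1) / R.sum()                           # relative code importances
    return float((rho * D).sum())                           # weighted average, Eq. (39)
```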
Effectiveness [18]. The effectiveness metric aims to quantify how successful the performed intervention is. To quantitatively evaluate the effectiveness of a given counterfactual image, an anti-causal predictor $f_i$ is trained on the data distribution for each causal variable $z_i$. Each predictor approximates the counterfactual value of the variable given the counterfactual image as input:

$$\mathrm{effectiveness}_i\big(x^{cf}\big) = d\big(z_i^{cf},\, f_i(x^{cf})\big) \tag{40}$$

where $d$ is the corresponding distance, defined as a classification metric for categorical variables and a regression metric for continuous ones.
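For a continuous factor, the effectiveness computation reduces to an MAE between the anti-causal predictor's output on the generated counterfactuals and the true counterfactual values; the sklearn-style predict interface below is an assumption.

```python
import numpy as np

def effectiveness_mae(regressor, x_cf: np.ndarray, z_cf_true: np.ndarray) -> float:
    """Effectiveness (Eq. (40)) for a continuous factor: MAE between the anti-causal
    predictor's output on counterfactuals and the true counterfactual values."""
    z_pred = regressor.predict(x_cf)   # assumed sklearn-style regressor
    return float(np.abs(z_pred - z_cf_true).mean())
```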
D.4 Computational Requirements
We run our experiments on an Ubuntu 20.04 workstation with eight NVIDIA Tesla V100-SXM2 GPUs with 32 GB of memory each. It is well known that diffusion models have a higher computational complexity than other generative models, such as VAEs and GANs. Generally speaking, all of the diffusion-based approaches have similar runtimes, whereas CausalVAE is much faster. We expect any developments in the training and sampling efficiency of diffusion probabilistic models to apply to our proposed diffusion-based approach as well.
D.5 Additional Experiments
Additional qualitative counterfactual generation results for the MorphoMNIST, Pendulum, and CausalCircuit datasets are shown in the figures of this appendix.