DeepMed: Semiparametric Causal Mediation Analysis with Debiased Deep Learning
Abstract
Causal mediation analysis can unpack the black box of causality and is therefore a powerful tool for disentangling causal pathways in biomedical and social sciences, and also for evaluating machine learning fairness. To reduce bias in estimating natural direct and indirect effects in mediation analysis, we propose a new method called DeepMed that uses deep neural networks (DNNs) to cross-fit the infinite-dimensional nuisance functions in the efficient influence functions. We obtain novel theoretical results showing that our method (1) can achieve the semiparametric efficiency bound without imposing sparsity constraints on the DNN architecture and (2) can adapt to certain low-dimensional structures of the nuisance functions, significantly advancing the existing literature on DNN-based semiparametric causal inference. Extensive synthetic experiments are conducted to support our findings and also expose the gap between theory and practice. As a proof of concept, we apply DeepMed to analyze two real datasets on machine learning fairness and reach conclusions consistent with previous findings.
1 Introduction
Tremendous progress has been made over the past decade on deploying deep neural networks (DNNs) in real-world problems (krizhevsky2012imagenet; wolf2019huggingface; jumper2021highly; brown2022deep). Causal inference is no exception. In semiparametric causal inference, a series of seminal works (chen2020causal; chernozhukov2020adversarial; farrell2021deep) initiated the investigation of statistical properties of causal effect estimators when the nuisance functions (the outcome regressions and propensity scores) are estimated by DNNs. However, there are a few limitations in the current literature that need to be addressed before the theoretical results can be used to guide practice:
(1) Most recent works mainly focus on the total effect (chen2020causal; farrell2021deep). In many settings, however, more intricate causal parameters are of greater interest. In biomedical and social sciences, one is often interested in "mediation analysis", which decomposes the total effect into direct and indirect effects to unpack the underlying black-box causal mechanism (baron1986moderator). More recently, mediation analysis has also percolated into machine learning fairness. For instance, in the context of predicting recidivism risk, nabi2018fair argued that, for a "fair" algorithm, sensitive features such as race should have no direct effect on the predicted recidivism risk. If such direct effects can be accurately estimated, one can detect the potential unfairness of a machine learning algorithm. We will revisit such applications in Section LABEL:sec:real and Appendix LABEL:app:real.
(2) Statistical properties of DNN-based causal estimators in recent works mostly follow from several recent results on the convergence rates of DNN-based nonparametric regression estimators (suzuki2019adaptivity; schmidt2020nonparametric; tsuji2021estimation), with the limitation of relying on sparse DNN architectures. The theoretical properties are in turn evaluated by relatively simple synthetic experiments that are not designed to generate nearly infinite-dimensional nuisance functions, even though this is the setting considered by almost all of the above related works.
The above limitations raise the tantalizing question of whether the available statistical guarantees for DNN-based causal inference have practical relevance. In this work, we partially fill these gaps by developing a new method called DeepMed for semiparametric mediation analysis with DNNs. We focus on the natural direct/indirect effects (NDE/NIE) (robins1992identifiability; pearl2001direct) (defined in Section 2.1), but our results can also be applied to more general settings; see Remark 2. The DeepMed estimators leverage the "multiply-robust" property of the efficient influence function (EIF) of NDE/NIE (tchetgen2012semiparametric; farbmacher2022causal) (see Proposition 1 in Section 2.2), together with the flexibility and superior predictive power of DNNs (see Section 3.1 and Algorithm 3.1). In particular, we also make the following novel contributions to deepen our understanding of DNN-based semiparametric causal inference:
• On the theoretical side, we obtain new results showing that our method can achieve the semiparametric efficiency bound without imposing sparsity constraints on the DNN architecture and can adapt to certain low-dimensional structures of the nuisance functions (see Section LABEL:sec:stat). Non-sparse DNN architectures are more commonly employed in practice (farrell2021deep), and low-dimensional structures of the nuisance functions can help avoid the curse of dimensionality. These two points, taken together, significantly advance our understanding of the statistical guarantees of DNN-based causal inference.
• More importantly, on the empirical side, in Section LABEL:sec:sim we design sophisticated synthetic experiments to simulate nearly infinite-dimensional functions, which are much more complex than those in previous related works (chen2020causal; farrell2021deep; adcock2021gap). We emphasize that these nontrivial experiments could be of independent interest to the theory of deep learning beyond causal inference, as they further expose the gap between deep learning theory and practice (adcock2021gap; gottschling2020troublesome); see Remark LABEL:beyond for an extended discussion. As a proof of concept, in Section LABEL:sec:real and Appendix LABEL:app:real, we also apply DeepMed to re-analyze two real-world datasets on algorithmic fairness and reach conclusions similar to related works.
• Finally, a user-friendly R package can be found at https://github.com/siqixu/DeepMed. Making such resources available helps enhance reproducibility, a widely recognized problem in all scientific disciplines, including (causal) machine learning (pineau2021improving; kaddour2022causal).
2 Definition, identification, and estimation of NDE and NIE
2.1 Definition of NDE and NIE
Throughout this paper, we denote $Y$ as the primary outcome of interest, $D$ as a binary treatment variable, $M$ as the mediator on the causal pathway from $D$ to $Y$, and $X \in [0,1]^p$ (or more generally, compactly supported in $\mathbb{R}^p$) as baseline covariates including all potential confounders. We denote the observed data vector as $O = (Y, D, M, X)$. Let $M(d)$ denote the potential outcome for the mediator when setting $D = d$ and $Y(d, m)$ be the potential outcome of $Y$ under $D = d$ and $M = m$, where $d \in \{0, 1\}$ and $m$ is in the support of $M$. We define the average total (treatment) effect as $\tau \equiv \mathbb{E}[Y(1, M(1)) - Y(0, M(0))]$, the average NDE of the treatment on the outcome when the mediator takes the natural potential outcome under $D = 0$ as $\mathrm{NDE} \equiv \mathbb{E}[Y(1, M(0)) - Y(0, M(0))]$, and the average NIE of the treatment on the outcome via the mediator as $\mathrm{NIE} \equiv \mathbb{E}[Y(1, M(1)) - Y(1, M(0))]$. We have the trivial decomposition $\tau = \mathrm{NDE} + \mathrm{NIE}$. In causal mediation analysis, the parameters of interest are NDE and NIE.
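To make the decomposition concrete, here is a minimal simulation sketch that generates the potential outcomes $M(d)$ and $Y(d, m)$ and checks $\tau = \mathrm{NDE} + \mathrm{NIE}$ numerically. The linear structural model below is our own illustrative choice, not one from the paper:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 200_000

# Hypothetical structural model (illustration only):
# X ~ N(0,1); M(d) = 0.5*d + X + eps_M; Y(d, m) = d + 0.8*m + X + eps_Y.
X = rng.normal(size=n)
eps_M = rng.normal(size=n)
eps_Y = rng.normal(size=n)

def M(d):          # potential mediator under treatment level d
    return 0.5 * d + X + eps_M

def Y(d, m):       # potential outcome under treatment d and mediator value m
    return d + 0.8 * m + X + eps_Y

tau = np.mean(Y(1, M(1)) - Y(0, M(0)))   # average total effect
nde = np.mean(Y(1, M(0)) - Y(0, M(0)))   # average natural direct effect
nie = np.mean(Y(1, M(1)) - Y(1, M(0)))   # average natural indirect effect

assert abs(tau - (nde + nie)) < 1e-10    # decomposition tau = NDE + NIE
```

The decomposition holds sample-by-sample because $Y(1,M(1)) - Y(0,M(0))$ telescopes into the NDE and NIE contrasts through the common term $Y(1, M(0))$.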
2.2 Semiparametric multiply-robust estimators of NDE/NIE
Estimating NDE and NIE can be reduced to estimating the counterfactual mean $\psi(d, d') \equiv \mathbb{E}[Y(d, M(d'))]$ for $d, d' \in \{0, 1\}$. We make the following standard identification assumptions:
i. Consistency: if $D = d$, then $M(d) = M$ for all $d \in \{0,1\}$; while if $D = d$ and $M = m$, then $Y(d, m) = Y$ for all $d \in \{0,1\}$ and all $m$ in the support of $M$.
ii. Ignorability: $Y(d, m) \perp\!\!\!\perp D \mid X$, $Y(d, m) \perp\!\!\!\perp M \mid (D, X)$, $M(d) \perp\!\!\!\perp D \mid X$, and $Y(d, m) \perp\!\!\!\perp M(d') \mid X$, almost surely for all $d, d' \in \{0,1\}$ and all $m$. The first three conditions are, respectively, no unmeasured treatment-outcome, mediator-outcome and treatment-mediator confounding, whereas the fourth condition is often referred to as the "cross-world" condition. We provide more detailed comments on these four conditions in Appendix LABEL:app:ignore.
iii. Positivity: The propensity score $p_d(X) \equiv \mathbb{P}(D = d \mid X)$ is strictly bounded between $[c_0, 1 - c_0]$ for some constant $c_0 > 0$, almost surely for all $d$; $f(m \mid d, X)$, the conditional density (mass function when $M$ is discrete) of $M$ given $D = d$ and $X$, is strictly bounded between $[c_1, c_2]$ for some constants $c_1, c_2 > 0$, almost surely for all $m$ in the support of $M$ and all $d$.
Under the above assumptions, the causal parameter $\psi(d, d')$ for $d, d' \in \{0,1\}$ can be identified as either of the following three observed-data functionals:

$$\psi(d, d') = \int\!\!\int \mu(d, m, x)\, f(m \mid d', x)\, q(x)\, \mathrm{d}m\, \mathrm{d}x = \mathbb{E}\!\left[\frac{\mathbb{1}\{D = d\}}{p_d(X)} \frac{f(M \mid d', X)}{f(M \mid d, X)}\, Y\right] = \mathbb{E}\!\left[\frac{\mathbb{1}\{D = d'\}}{p_{d'}(X)}\, \mu(d, M, X)\right], \tag{1}$$

where $\mathbb{1}\{\cdot\}$ denotes the indicator function, $q(x)$ denotes the marginal density of $X$, and $\mu(d, m, X) \equiv \mathbb{E}[Y \mid D = d, M = m, X]$ is the outcome regression model, for which we also make the following standard boundedness assumption:
iv. $\mu(d, m, X)$ is also strictly bounded between $[-c_3, c_3]$ for some constant $c_3 > 0$.
Following the convention in the semiparametric causal inference literature, we call $(p_d, f, \mu)$ "nuisance functions". tchetgen2012semiparametric derived the EIF of $\psi(d, d')$: $\psi^{\mathrm{eif}}_{d,d'}(O) = S_{d,d'}(O) - \psi(d, d')$, where

$$S_{d,d'}(O) \equiv \frac{\mathbb{1}\{D = d\}\, f(M \mid d', X)}{p_d(X)\, f(M \mid d, X)}\left(Y - \mu(d, M, X)\right) + \frac{\mathbb{1}\{D = d'\}}{p_{d'}(X)}\left(\mu(d, M, X) - \eta(d, d', X)\right) + \eta(d, d', X), \tag{2}$$

with $\eta(d, d', X) \equiv \int \mu(d, m, X)\, f(m \mid d', X)\, \mathrm{d}m$.
The nuisance functions $p_d$, $f$ and $\mu$ appearing in (2) are unknown and generally high-dimensional. But with a sample of the observed data, one can construct the following generic sample-splitting multiply-robust estimator of $\psi(d, d')$:

$$\hat\psi(d, d') = \frac{1}{|\mathcal{D}_k|} \sum_{i \in \mathcal{D}_k} \hat S_{d,d'}(O_i), \tag{3}$$

where $\mathcal{D}_k$ is a subset of all data, and $\hat S_{d,d'}$ replaces the unknown nuisance functions in (2) by some generic estimators $(\hat p_d, \hat f, \hat\mu)$ computed using the remaining nuisance sample data, denoted as $\mathcal{D}_k^c$. Cross-fitting is then needed to recover the information lost due to sample splitting; see Algorithm 3.1. It is clear from (2) that $\hat\psi(d, d')$ is a consistent estimator of $\psi(d, d')$ as long as any two of $(\hat p_d, \hat f, \hat\mu)$ are consistent estimators of the corresponding true nuisance functions, hence the name "multiply-robust". Throughout this paper, we also assume:
v. Any nuisance function estimators $\hat p_d$, $\hat f$ and $\hat\mu$ are strictly bounded within the respective lower and upper bounds in Assumptions iii and iv.
To further ease notation, we define, for any $d \in \{0, 1\}$, the $L_2$-estimation errors $\|\hat p_d - p_d\|_2 \equiv \mathbb{E}[(\hat p_d(X) - p_d(X))^2]^{1/2}$, and analogously $\|\hat f - f\|_2$ and $\|\hat\mu - \mu\|_2$, where $\hat p_d - p_d$, $\hat f - f$ and $\hat\mu - \mu$ are point-wise estimation errors of the estimated nuisance functions. In defining the above $L_2$-estimation errors, we choose to take expectation with respect to (w.r.t.) the law of $X$ only for convenience, with no loss of generality by Assumptions iii and v.
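For intuition, the multiply-robust estimator (3) with two-fold cross-fitting can be sketched on a toy fully binary example. The cell-mean nuisance estimators below stand in for the DNN estimators used by DeepMed, and the data-generating model is our own illustration:

```python
import numpy as np

rng = np.random.default_rng(2)
n = 100_000

# Simulated data with binary D, M, X (illustration only).
X = rng.binomial(1, 0.5, size=n)
D = rng.binomial(1, 0.3 + 0.4 * X)
M = rng.binomial(1, 0.2 + 0.3 * D + 0.2 * X)
Y = 1.0 * D + 0.8 * M + 0.5 * X + rng.normal(size=n)

def fit_nuisances(idx):
    """Cell-mean estimates of p_d(x), f(1|d,x), mu(d,m,x) on the nuisance fold."""
    Xi, Di, Mi, Yi = X[idx], D[idx], M[idx], Y[idx]
    p = np.array([Di[Xi == x].mean() for x in (0, 1)])                # P(D=1|X=x)
    f = np.array([[Mi[(Di == d) & (Xi == x)].mean()
                   for x in (0, 1)] for d in (0, 1)])                 # P(M=1|D=d,X=x)
    mu = np.array([[[Yi[(Di == d) & (Mi == m) & (Xi == x)].mean()
                     for x in (0, 1)] for m in (0, 1)] for d in (0, 1)])
    return p, f, mu

def psi_mr(idx, nuis, d, dp):
    """Average of the multiply-robust score (3) over the estimation fold."""
    p, f, mu = nuis
    Xi, Di, Mi, Yi = X[idx], D[idx], M[idx], Y[idx]
    p_d  = p[Xi] if d == 1 else 1 - p[Xi]
    p_dp = p[Xi] if dp == 1 else 1 - p[Xi]
    f_d  = np.where(Mi == 1, f[d, Xi], 1 - f[d, Xi])
    f_dp = np.where(Mi == 1, f[dp, Xi], 1 - f[dp, Xi])
    mu_dM = mu[d, Mi, Xi]
    eta = mu[d, 1, Xi] * f[dp, Xi] + mu[d, 0, Xi] * (1 - f[dp, Xi])
    s = ((Di == d) / p_d * f_dp / f_d * (Yi - mu_dM)
         + (Di == dp) / p_dp * (mu_dM - eta) + eta)
    return s.mean()

# Two-fold cross-fitting: estimate on one half, fit nuisances on the other.
half = np.arange(n) < n // 2
folds = [(np.where(half)[0], np.where(~half)[0]),
         (np.where(~half)[0], np.where(half)[0])]
psi10 = np.mean([psi_mr(est, fit_nuisances(nuis), 1, 0) for est, nuis in folds])
psi00 = np.mean([psi_mr(est, fit_nuisances(nuis), 0, 0) for est, nuis in folds])
nde = psi10 - psi00
print(nde)   # true NDE in this toy model is 1.0
```

Swapping the roles of the two halves and averaging is exactly the information-recovery step that cross-fitting provides.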
To show that the cross-fit version of $\hat\psi(d, d')$ is semiparametric efficient for $\psi(d, d')$, we shall demonstrate under what conditions it is asymptotically linear with influence function $\psi^{\mathrm{eif}}_{d,d'}$ (newey1990semiparametric). The following proposition on the statistical properties of $\hat\psi(d, d')$ is a key step towards this objective.
Proposition 1.
Denote $\mathrm{Bias}(\hat\psi(d, d')) \equiv \mathbb{E}[\hat\psi(d, d') \mid \mathcal{D}_k^c] - \psi(d, d')$ as the bias of $\hat\psi(d, d')$ conditional on the nuisance sample $\mathcal{D}_k^c$. Under Assumptions i - v, $\mathrm{Bias}(\hat\psi(d, d'))$ is of second-order, i.e., bounded up to constants by products of pairs of nuisance estimation errors:

$$\left|\mathrm{Bias}(\hat\psi(d, d'))\right| \lesssim \|\hat p_d - p_d\|_2 \|\hat\mu - \mu\|_2 + \|\hat p_{d'} - p_{d'}\|_2 \|\hat f - f\|_2 + \|\hat f - f\|_2 \|\hat\mu - \mu\|_2. \tag{4}$$

Furthermore, if the RHS of (4) is $o(n^{-1/2})$, then

$$\hat\psi(d, d') - \psi(d, d') = \frac{1}{|\mathcal{D}_k|} \sum_{i \in \mathcal{D}_k} \psi^{\mathrm{eif}}_{d,d'}(O_i) + o_P(n^{-1/2}). \tag{5}$$
Although the above result is a direct consequence of the form of the EIF $\psi^{\mathrm{eif}}_{d,d'}$, we prove Proposition 1 in Appendix LABEL:app:bias for completeness.
Remark 2.
The total effect $\tau$ can be viewed as a special case, for which $d = d'$. Then $\psi^{\mathrm{eif}}_{d,d}$ corresponds to the nonparametric EIF of $\mathbb{E}[Y(d)]$:

$$\frac{\mathbb{1}\{D = d\}}{p_d(X)}\left(Y - \bar\mu(d, X)\right) + \bar\mu(d, X) - \mathbb{E}[Y(d)],$$

where $\bar\mu(d, X) \equiv \mathbb{E}[Y \mid D = d, X]$. Hence all the theoretical results in this paper are applicable to total effect estimation. Our framework can also be applied to all statistical functionals that satisfy a so-called "mixed-bias" property, characterized recently in rotnitzky2021characterization. This class includes the quadratic functional, which is important for uncertainty quantification in machine learning.
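In this special case the score collapses to the familiar augmented inverse-probability-weighting (AIPW) form, which can be sketched as follows. The synthetic data are our own illustration, and the true nuisance functions are plugged in purely for clarity (in practice, DeepMed cross-fits DNN estimates):

```python
import numpy as np

rng = np.random.default_rng(3)
n = 100_000

# Illustrative data: confounded binary treatment, continuous outcome.
X = rng.normal(size=n)
pD = 1 / (1 + np.exp(-X))            # true propensity score P(D=1|X)
D = rng.binomial(1, pD)
Y = D + X + rng.normal(size=n)       # so E[Y(1)] = 1 and E[Y(0)] = 0

def aipw(d, p_hat, mu_hat):
    """One-step AIPW estimator of E[Y(d)]; the d = d' reduction of (3)."""
    p_d = p_hat if d == 1 else 1 - p_hat
    return np.mean((D == d) / p_d * (Y - mu_hat) + mu_hat)

# True nuisances for illustration: mu(d, X) = E[Y | D=d, X] = d + X.
ate = aipw(1, pD, 1 + X) - aipw(0, pD, X)
print(ate)   # true total effect is 1.0
```

The first term corrects the outcome-regression plug-in by inverse-probability-weighted residuals, which is what yields the second-order bias of Proposition 1.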
3 Estimation and inference of NDE/NIE using DeepMed
We now introduce DeepMed, a method for mediation analysis with nuisance functions estimated by DNNs. By leveraging the second-order bias property of the multiply-robust estimators of NDE/NIE (Proposition 1), we will derive the statistical properties of DeepMed in this section. The DNN-based nuisance function estimators are denoted as $\hat p_d$, $\hat f$ and $\hat\mu$.
3.1 Details on DeepMed
First, we introduce the fully-connected feed-forward neural network with the rectified linear unit (ReLU) as the activation function for the hidden-layer neurons (FNN-ReLU), which will be used to estimate the nuisance functions. Then, we introduce an estimation procedure using $K$-fold cross-fitting with sample-splitting to avoid the Donsker-type empirical-process assumption on the nuisance functions, which, in general, is violated in high-dimensional setups. Finally, we provide the asymptotic statistical properties of the DNN-based estimators of $\tau$, NDE and NIE.
We denote the ReLU activation function as $\sigma(x) \equiv \max\{x, 0\}$ for any $x \in \mathbb{R}$. Given vectors $v = (v_1, \ldots, v_r)^\top$ and $y = (y_1, \ldots, y_r)^\top$, we denote the shifted activation $\sigma_v(y) \equiv (\sigma(y_1 - v_1), \ldots, \sigma(y_r - v_r))^\top$, with $\sigma$ acting on the vector component-wise.

Let $\mathcal{F}_{\mathrm{nn}}$ denote the class of the FNN-ReLU functions

$$f(x) = W_L \sigma_{v_L} \circ W_{L-1} \sigma_{v_{L-1}} \circ \cdots \circ W_1 \sigma_{v_1} \circ W_0\, x,$$

where $\circ$ is the composition operator, $L$ is the number of layers (i.e., depth) of the network, and for $l = 0, \ldots, L$, $W_l$ is a $d_{l+1} \times d_l$-dimensional weight matrix with $d_l$ being the number of neurons in the $l$-th layer (i.e., width) of the network, with $d_0 = p$ and $d_{L+1} = 1$, and $v_l$ is a $d_l$-dimensional shift vector. To avoid notation clutter, we concatenate all the network parameters as $(L, W)$ and simply take $d_1 = \cdots = d_L = W$. We also assume $f \in \mathcal{F}_{\mathrm{nn}}$ to be bounded: $\|f\|_\infty \le F$ for some universal constant $F > 0$. We may make the dependence on $L$, $W$, $F$ explicit by writing $\mathcal{F}_{\mathrm{nn}}$ as $\mathcal{F}_{\mathrm{nn}}(L, W, F)$.
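A minimal NumPy sketch of evaluating such an FNN-ReLU function follows; the output is truncated to $[-F, F]$ to enforce the sup-norm bound, and the particular depth, width and parameter values below are arbitrary illustrative choices:

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def fnn_relu(x, Ws, vs, F=10.0):
    """Evaluate f(x) = W_L sigma_{v_L}( ... W_1 sigma_{v_1}(W_0 x) ... ),
    where sigma_v(y) = max(y - v, 0) component-wise, truncated to [-F, F]."""
    h = Ws[0] @ x                      # W_0 x : input layer
    for W, v in zip(Ws[1:], vs):       # apply sigma_{v_l}, then W_l
        h = W @ relu(h - v)
    return np.clip(h, -F, F)           # enforce the bound ||f||_inf <= F

# Hypothetical network: input dim p = 3, depth L = 2, equal width W = 4.
rng = np.random.default_rng(4)
p, width = 3, 4
Ws = [rng.normal(size=(width, p)),     # W_0: R^3 -> R^4
      rng.normal(size=(width, width)), # W_1: R^4 -> R^4
      rng.normal(size=(1, width))]     # W_2: R^4 -> R
vs = [rng.normal(size=width), rng.normal(size=width)]
y = fnn_relu(rng.normal(size=p), Ws, vs)
```

For instance, with `Ws = [np.eye(2), np.ones((1, 2))]` and `vs = [np.zeros(2)]`, the network computes the sum of positive parts of its two inputs.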
DeepMed estimates $\psi(d, d')$ by (3), with the nuisance functions estimated over $\mathcal{F}_{\mathrm{nn}}$ using the $K$-fold cross-fitting strategy, summarized in Algorithm 3.1 below; see also farbmacher2022causal. DeepMed takes the observed data as input and outputs the estimated total effect $\hat\tau$, NDE and NIE, together with their variance estimators.
Algorithm 1 DeepMed with $K$-fold cross-fitting
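The cross-fitting strategy of Algorithm 1 can be sketched generically as follows. Here `fit_nuisance` and `score` are hypothetical stand-ins for DeepMed's DNN nuisance training and multiply-robust score evaluation, and the toy usage merely estimates a simple mean:

```python
import numpy as np

rng = np.random.default_rng(5)

def cross_fit(data, K, fit_nuisance, score):
    """K-fold cross-fitting skeleton: for each fold k, fit nuisances on the
    out-of-fold sample, evaluate the score on fold k, then pool all folds."""
    n = len(data)
    folds = np.array_split(rng.permutation(n), K)
    scores = []
    for k in range(K):
        nuis_idx = np.concatenate([folds[j] for j in range(K) if j != k])
        eta_hat = fit_nuisance(data, nuis_idx)          # e.g., DNN estimates
        scores.append(score(data, folds[k], eta_hat))   # held-out evaluation
    s = np.concatenate(scores)
    return s.mean(), s.std(ddof=1) / np.sqrt(n)         # estimate and its SE

# Toy usage: "estimate" E[Y] with a plug-in score; the stand-in components
# mimic the interface of the nuisance-fitting and scoring steps.
y = rng.normal(loc=2.0, scale=1.0, size=10_000)
fit_mean = lambda d, idx: d[idx].mean()
plug_in = lambda d, idx, eta: (d[idx] - eta) + eta
psi_hat, se_hat = cross_fit(y, K=5, fit_nuisance=fit_mean, score=plug_in)
```

Pooling the held-out scores across all $K$ folds recovers full-sample efficiency while keeping the nuisance estimates independent of the fold on which they are evaluated.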