A Penalty Approach for Normalizing Feature Distributions to Build Confounder-Free Models
Abstract
Translating machine learning algorithms into clinical applications requires addressing challenges related to interpretability, such as accounting for the effect of confounding variables (or metadata). Confounding variables affect the relationship between input training data and target outputs. When we train a model on such data, confounding variables will bias the distribution of the learned features. A recent promising solution, MetaData Normalization (MDN), estimates the linear relationship between the metadata and each feature based on a non-trainable closed-form solution. However, this estimation is confined by the sample size of a mini-batch and may therefore cause the approach to be unstable during training. In this paper, we extend the MDN method by applying a Penalty approach (referred to as PMDN). We cast the problem into a bi-level nested optimization problem. We then approximate this optimization problem using a penalty method so that the linear parameters within the MDN layer are trainable and learned on all samples. This enables PMDN to be plugged into any architecture, even those unfit to run batch-level operations, such as transformers and recurrent models. We show improvement in model accuracy and greater independence from confounders using PMDN over MDN in a synthetic experiment and a multi-label, multi-site dataset of magnetic resonance images (MRIs).
Keywords:
Confounders · Neuroscience · Fairness · Deep Learning.

1 Introduction
Modern machine learning approaches rely on automatically learning features from data [28] using approaches such as convolutional neural networks (CNNs) [2, 8] and attention-based transformer models [6, 9]. Although these methods solve challenging problems, they are known to capture spurious associations and biases introduced by confounding or protected variables [27]. These shortcomings limit the neuroscientific impact of such algorithms, for which controlling for (and explaining the effects of) confounding variables is crucial. To remedy this, several approaches have been proposed, such as those based on adversarial training [27, 15], counterfactual generative models [19, 13], disentanglement [22, 16], and correlation-based fair inference [5]. These approaches learn features that are invariant or conditionally independent of the confounding variables.
These training methods reduce the error introduced by confounders with minimal compromise to model accuracy. However, adversarial models and those based on disentanglement or correlation are inefficient when accounting for multiple confounders (or metadata) and only partially remove their effects from the feature maps of a single layer in the network [27]. Methods based on counterfactuals require reliable generative models with respect to arbitrary variables, which adds complexity. To remove confounding effects at different feature layers, MetaData Normalization (MDN) [17] can be plugged into a CNN to remove the effects of multiple confounders (or metadata) from the features while training the network. MDN aims to fix the distribution shift [3] caused by the confounding variables using a closed-form solution to a linear regression that captures the relationship between the confounders and each feature.
The closed-form solution in MDN requires building a linear model (the relationship between the metadata and each feature) as a batch-level operation, and it requires large batch sizes to obtain accurate approximations of the linear model. However, batch-level statistics in MDN (similar to batch normalization) face several challenges, including (1) instability when using small batch sizes, (2) increased training time due to the calculation of closed-form solutions for each feature at each iteration, (3) inconsistent results between training and inference since there are no batches during inference, (4) inability to use MDN for online training, in which the model is trained incrementally by feeding the samples in a sequential manner, and finally (5) inability to apply MDN to recurrent neural networks [4] and selected transformer models [23, 24]. To overcome these limitations, we now introduce a new penalty method that turns MDN into a layer with parameters that can be optimized together with the other components of the network during training.
Referred to as a Penalty approach for MetaData Normalization (PMDN), our method improves upon the batch-level MDN operation. Specifically, PMDN can be applied to all architectures and any number of confounding variables, and we show that PMDN is not dependent on the batch size. We apply PMDN to a synthetic dataset to analyze and validate the method within a controlled setting. We then examine the applicability of PMDN compared to MDN in classifying multi-site MRIs into 4 diagnostic groups with image acquisition site, sex, and age as confounders.
2 Methodology

Given a dataset of $N$ training samples, we define the metadata matrix as $M \in \mathbb{R}^{N \times K}$, where $K$ is the number of metadata variables. Each row of $M$, denoted $m_i$, defines the metadata for sample $i$. Also, let $f \in \mathbb{R}^{N}$ be the features for all training samples extracted at a particular layer. The goal of MDN is to remove confounding information from the features and use the residual component, $r$, as input to the next layer of the network. The next subsection reviews how MDN performs this task via batch-level operations (Section 2.1), while Section 2.2 reformulates MDN so that it can be parameterized with respect to all training samples (Figure 1).
2.1 MDN Review
Lu et al. [17] implemented the MDN layer as a general linear model (GLM), i.e.,
$f = M\beta + r$,  (1)
where $\beta \in \mathbb{R}^{K}$ is an unknown set of parameters, $M\beta$ describes the component in $f$ that is relevant to the metadata, and $r$ is the residual component that is irrelevant to the metadata. Then, the MDN operation is defined as
$r = \mathrm{MDN}(f; \beta) = f - M\beta$.  (2)
The MDN layer is not trained but instead is determined by the closed-form solution of least squares, i.e.,
$\beta = (M^\top M)^{-1} M^\top f$.  (3)
The underlying model assumes that the computation is performed across the features of all $N$ training samples. However, training of deep learning models is generally confined to batches of samples, producing only the features of the current batch. Therefore, [17] approximates $\beta$ as
$\beta \approx N\,(M^\top M)^{-1}\, \mathbb{E}_b\!\left[m_i f_i\right]$,  (4)

where $(M^\top M)^{-1}$ is precomputed on the metadata of all training samples and the expectation $\mathbb{E}_b[\cdot]$ is estimated from the samples of the current batch.
As Eq. (4) approximates the expectation only using data from a batch, the estimates are generally inaccurate for a small batch size and are likely to vary from batch to batch, resulting in model instability, similar to Batch Norm [25].
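To make the batch-level operation concrete, the following minimal numpy sketch (our own illustration, not the reference implementation of [17]) computes the full closed-form solution of Eq. (3) and the batch-level approximation of Eq. (4); all function and variable names are ours.

```python
import numpy as np

def mdn_closed_form(M, f):
    """Eq. (3): least-squares fit of the feature vector f (N,) onto the metadata M (N, K)."""
    return np.linalg.pinv(M.T @ M) @ M.T @ f          # beta = (M^T M)^{-1} M^T f

def mdn_batch_approx(M_full, M_batch, f_batch):
    """Eq. (4): batch-level approximation of beta.

    (M^T M)^{-1} depends only on the fixed metadata, so it can be precomputed on
    all N samples; the expectation E[m_i f_i] is estimated from the current batch.
    """
    N = M_full.shape[0]
    sigma = np.linalg.pinv(M_full.T @ M_full)          # precomputed once per training run
    e_mf = (M_batch * f_batch[:, None]).mean(axis=0)   # batch estimate of E[m_i f_i]
    return N * sigma @ e_mf

def mdn_residual(f, M, beta):
    """Eq. (2): metadata-free residual r = f - M beta."""
    return f - M @ beta
```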
2.2 PMDN: MDN as a Bi-Level Optimization
To improve model stability and accuracy, we realize that inserting an MDN layer into a generic neural network is equivalent to reformulating the original objective function of the network to a bi-level optimization.
Specifically, let $X$ be the training samples and $Y$ be their corresponding prediction targets. Without loss of generality, let us assume that a network can be defined as the composition $\mathcal{G} \circ \mathcal{F}$ between the first few layers $\mathcal{F}$ and the layers afterwards $\mathcal{G}$. For simplicity, we assume $\mathcal{F}$ results in a one-channel feature $f = \mathcal{F}(X)$, but the following formulation generalizes to multi-channel features. Let $W$ be the network parameters of $\mathcal{F}$ and $\mathcal{G}$; then training of the network often reduces to solving the minimization problem
$\min_{W} \; \mathcal{L}\big(\mathcal{G}(\mathcal{F}(X)),\, Y\big)$,  (5)

where $\mathcal{L}$ is the task loss.
Adding an MDN layer after $\mathcal{F}$ changes the minimization problem to
$\min_{W} \; \mathcal{L}\big(\mathcal{G}(\mathrm{MDN}(\mathcal{F}(X); \beta^{*})),\, Y\big)$  (6)
$\text{s.t.}\;\; \beta^{*} = \arg\min_{\beta} \left\| \mathcal{F}(X) - M\beta \right\|_2^2$.  (7)
In other words, the constraint itself is a nested optimization, which aims to maximally remove the metadata effect from the feature learned by $\mathcal{F}$.
To solve this bi-level optimization problem, PMDN (a Penalty approach for MDN) determines the minimum to a proxy objective function that combines the two minimization problems:
$\min_{W,\, \beta} \; \mathcal{L}\big(\mathcal{G}(\mathrm{MDN}(\mathcal{F}(X); \beta)),\, Y\big) + \lambda \left\| \mathcal{F}(X) - M\beta \right\|_2^2$,  (8)

where $\lambda$ is the penalty weight.
Now, Eq. (8) is a well-defined, differentiable function that can be optimized by any gradient descent algorithm. Unlike MDN, which sets $\beta$ to different values according to the batch construction, the $\beta$ estimates in PMDN can converge to a local optimum defined with respect to all training data. Here, we use an alternating optimization schema for removing metadata effects (Alg. 1). As can be seen in lines 6 and 9, the two objectives have their own learning rates, which are then consolidated into the optimizer (e.g., Adam [12]), making the implementation independent from the hyperparameter $\lambda$.
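To illustrate Eq. (8) and the alternating schema of Alg. 1, the following minimal PyTorch sketch implements a PMDN layer with a trainable $\beta$ and one alternating update step. This is our own illustration rather than the reference implementation: the names (`PMDNLayer`, `train_step`, `opt_w`, `opt_beta`) are ours, and using two separate optimizers to realize the two learning rates (and thereby avoid an explicit $\lambda$) is an assumption consistent with the description above.

```python
import torch
import torch.nn as nn

class PMDNLayer(nn.Module):
    """Metadata normalization with a trainable beta (penalty formulation of Eq. (8)).
    Handles a multi-channel feature f of shape (B, C) with beta of shape (K, C)."""
    def __init__(self, num_metadata: int, num_features: int):
        super().__init__()
        self.beta = nn.Parameter(torch.zeros(num_metadata, num_features))

    def forward(self, f, m):
        # f: (B, C) features, m: (B, K) metadata rows of the current batch
        return f - m @ self.beta                       # residual r = f - M beta, Eq. (2)

    def penalty(self, f, m):
        # inner objective of Eq. (7); f is detached so this term only drives beta
        return ((f.detach() - m @ self.beta) ** 2).mean()


def train_step(F, G, pmdn, loss_fn, x, y, m, opt_w, opt_beta):
    """One alternating update in the spirit of Alg. 1: first beta on the penalty
    term, then the network weights W on the task loss of the residual feature."""
    f = F(x)                                           # feature before normalization
    opt_beta.zero_grad()
    pmdn.penalty(f, m).backward()                      # update beta with its own learning rate
    opt_beta.step()
    opt_w.zero_grad()
    loss_fn(G(pmdn(f, m)), y).backward()               # update W with its own learning rate
    opt_w.step()
```

In this sketch, `opt_beta` holds only `pmdn.beta` while `opt_w` holds the parameters of $\mathcal{F}$ and $\mathcal{G}$, so each objective can use its own learning rate as in lines 6 and 9 of Alg. 1.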
Although Alg. 1 is based on only one PMDN layer, multiple PMDN layers can be added without loss of generality to further remove any remaining residual confounding effects. If we perform metadata normalization after each of the features (from different layers or channels), the penalty term in Eq. (8) becomes the sum of all PMDN losses $\|f^{(k)} - M\beta^{(k)}\|_2^2$, where $f^{(k)}$ and $\beta^{(k)}$ are the feature vector and parameters of the $k$-th PMDN layer, respectively. Furthermore, Alg. 1 uses Stochastic Gradient Descent (SGD) [21] to update $W$ and $\beta$. However, SGD can be replaced with other optimizers such as Adam [12].
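When several PMDN layers are used, their penalty terms are simply accumulated before the $\beta$ update; a brief sketch reusing the hypothetical `PMDNLayer` above:

```python
def total_pmdn_penalty(pmdn_layers, features, m):
    """Sum of the per-layer PMDN losses ||f^(k) - M beta^(k)||^2 over all
    normalized features; pmdn_layers and features are matched lists and m
    holds the metadata rows of the current batch."""
    return sum(layer.penalty(f, m) for layer, f in zip(pmdn_layers, features))
```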
3 Experiments
We apply the method to a synthetic and an MRI dataset with both continuous and discrete metadata. For each experiment, we investigate the effect of metadata on a variety of architectures, including a baseline CNN, the baseline network with MDN as described in Section 2.1, and the baseline network with PMDN as described in Section 2.2. The code is available at https://github.com/vento99/PMDN.
3.1 Synthetic Dataset Experiments
Data.
The synthetic dataset [17] consists of 2,000 images subdivided into two groups of 1,000 images each. In both groups, quadrants two and four of an image contain Gaussian noise whose variance is sampled from a group-specific uniform distribution; the two groups use different but overlapping variance ranges. We introduce metadata into the third quadrant of the images: its Gaussian noise variance is likewise drawn from a group-specific uniform distribution, so that quadrant three acts as a confounding cue for the group label. Theoretically, complete removal of the metadata effect will lead to a maximum model accuracy of 83.3%.
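For intuition, a sketch of how such an image could be generated is given below. The image size and the variance ranges are placeholders chosen for illustration only, since the exact values of [17] are not restated here.

```python
import numpy as np

def make_image(var_main, var_meta, rng, size=32):
    """One synthetic image: quadrants 2 and 4 carry the group signal (variance
    var_main), quadrant 3 carries the metadata effect (variance var_meta)."""
    img = np.zeros((size, size))
    h = size // 2
    img[:h, :h] = rng.normal(0.0, np.sqrt(var_main), (h, h))   # quadrant 2 (top-left)
    img[h:, h:] = rng.normal(0.0, np.sqrt(var_main), (h, h))   # quadrant 4 (bottom-right)
    img[h:, :h] = rng.normal(0.0, np.sqrt(var_meta), (h, h))   # quadrant 3 (metadata)
    return img

rng = np.random.default_rng(0)
# Placeholder (hypothetical) variance ranges; the two groups overlap so that
# the task is not perfectly separable once the metadata effect is removed.
group1 = [make_image(rng.uniform(1, 4), rng.uniform(1, 4), rng) for _ in range(1000)]
group2 = [make_image(rng.uniform(3, 6), rng.uniform(3, 6), rng) for _ in range(1000)]
```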
Implementation.
The baseline is a simple CNN composed of two standard blocks. The first block consists of two convolution layers with 16 and 32 filters, each followed by a ReLU activation. The second block incorporates a fully connected layer of size 84 with ReLU activation, followed by another fully connected layer. We use binary cross entropy as the loss $\mathcal{L}$. For all other methods, we add a normalization layer (one of BatchNorm [11], MDN, or PMDN) after the convolution layers and the first fully connected layer (before the ReLU activations).
Note that we also insert a LayerNorm layer [4] before each PMDN operation in order to stabilize the input features and to enable smoother gradients, faster training, and better generalization accuracy. Similar to the setup in [17], the metadata variable is collinear with the group labels. Thus, the labels are included as an additional column in the metadata matrix $M$ during training to remove the metadata effect while preserving group differences. During inference, we remove the label column from $M$ and the last component from $\beta$, as implemented in [17].
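A brief sketch of this label-column handling (our own illustration; all array and function names are hypothetical):

```python
import numpy as np

def build_metadata(metadata, labels=None):
    """Metadata matrix M: during training the group label is appended as an
    extra column so normalization preserves group differences; at inference
    the column is omitted (pass labels=None)."""
    if labels is None:
        return metadata                                  # (N, K)
    return np.column_stack([metadata, labels])           # (N, K + 1)

def normalize_features(features, metadata, beta, training):
    """Residual r = f - M beta; at inference the last beta component (the one
    fit to the label column) is dropped, as in [17]."""
    if not training:
        beta = beta[:-1]
    return features - metadata @ beta
```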

Evaluation.
To examine whether the metadata effect is removed from the learned features, we calculate the squared distance correlation (dcor2) between the output of the first FC layer and the metadata of each group separately and report the average of the two dcor2 values. Unlike linear correlation, dcor2 captures (possibly non-linear) relationships between variables of arbitrary dimension. Lower dcor2 values reflect greater independence from the metadata and thus better normalization of the feature distribution.
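Squared distance correlation can be computed directly from its definition (double-centered pairwise-distance matrices); a minimal numpy sketch follows, equivalent in spirit to off-the-shelf implementations such as the dcor package (function name ours).

```python
import numpy as np

def dcor2(x, y):
    """Squared (biased) distance correlation between samples x and y (n rows each)."""
    def centered_dist(z):
        z = np.asarray(z, dtype=float).reshape(len(z), -1)
        d = np.sqrt(((z[:, None, :] - z[None, :, :]) ** 2).sum(-1))   # pairwise distances
        return d - d.mean(0, keepdims=True) - d.mean(1, keepdims=True) + d.mean()
    a, b = centered_dist(x), centered_dist(y)
    dcov2_xy = (a * b).mean()                                          # distance covariance
    denom = np.sqrt((a * a).mean() * (b * b).mean())
    return 0.0 if denom == 0 else dcov2_xy / denom
```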
Figure 2 summarizes the results. Results for the baseline, BatchNorm, and standard MDN layers are taken from [17]. We see that the baseline and BatchNorm achieve an accuracy much greater than 83.3% (the theoretical optimum), which means that the metadata effect has not been removed from the features. Instead, the model may have learned spurious associations between metadata and labels, as the dcor2 values are much higher than those for MDN and PMDN.
For MDN, we see an inconsistency of results across different batch sizes. When the batch size is large, the batch-level closed-form solution of $\beta$ approximates the true $\beta$ well, so that MDN successfully normalizes the metadata effect. However, for a batch size of 200, MDN performance is significantly reduced. On the other hand, PMDN shows consistent results across all batch sizes, supporting our hypothesis that the penalty-based approach is insensitive to the batch size.
3.2 Multi-label Multi-site MRI Dataset Experiments
Data.
The dataset consists of 1,262 T1-weighted MRIs from three brain studies, where each MRI was bias-field corrected, skull stripped, and affinely registered to a template. The three studies (see Table 1 for a summary) were performed by (1) the Memory and Aging Center, University of California - San Francisco (UCSF) [26], (2) the Neuroscience Program, SRI International [1], and (3) the public Alzheimer’s Disease Neuroimaging Initiative (ADNI1) [20]. The participants of the three studies were divided into four cohorts: healthy older adults with no neurological/psychiatric diagnosis (CTRL; N=460), adults infected with human immunodeficiency virus (HIV) without cognitive impairment (HIV; N=112), HIV-infected individuals with cognitive impairment who were diagnosed with HIV-Associated Neurocognitive Disorder (HAND; N=145), and HIV-negative adults diagnosed with mild cognitive impairment (MCI) but no HIV (MCI; N=545). MCI is a heterogeneous condition that reflects impairment in memory and other cognitive abilities [7]. By definition, individuals with HAND meet the criteria for both MCI and HIV. Thus, the problem is formulated as a two-label classification problem: predicting whether or not an individual is infected with HIV and predicting whether or not an individual is diagnosed with cognitive impairment. For this dataset, the metadata includes the acquisition site (one-hot encoded), participant age (z-scored), and participant self-identified sex (male/female).
| Site | CTRL | MCI | HIV | HAND | F/M | Age Mean ± Std |
|------|------|-----|-----|------|---------|----------------|
| UCSF | 156 | 148 | 37 | 145 | 97/389 | 67.00 ± 6.47 |
| ADNI | 229 | 397 | - | - | 253/373 | 75.16 ± 6.61 |
| SRI | 75 | - | 75 | - | 44/106 | 50.72 ± 11.33 |
Implementation and Baseline Models.
The baseline consists of a 3D-ResNet [10] followed by a series of fully connected (FC) blocks. The 3D-ResNet consists of four standard residual blocks, each incorporating a 3D convolution with ReLU activation and a skip connection. The numbers of filters for the standard convolutions in the four blocks are 3, 6, 9, and 6, respectively; all use kernel size 3 and padding size 1. The output of the 3D-ResNet (flattened size 2048) is passed through an FC-ReLU-FC-ReLU-FC architecture, whose FC outputs are of size 128, 16, and 2, respectively. As the loss $\mathcal{L}$ we use the focal loss [14] to combat the class imbalance. For MDN, the layers are added after each FC-ReLU and after the final FC layer. As with the synthetic dataset, for PMDN we add a LayerNorm before the first two PMDN layers and include the labels as metadata. We also examine a BatchNorm architecture, where we insert BatchNorms (BNs) after each FC-ReLU. Finally, since most of the previous work focused on domain-adversarial methods for learning confounder-invariant features, we examine an adversarial training method similar to [27]: after the 3D-ResNet, we add an additional FC-ReLU-FC head that attempts to predict the confounding variables, and the correlation loss [27] from this head is adversarially subtracted from the classification loss when updating the weights of the 3D-ResNet.
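To make the placement of the normalization layers concrete, a schematic PyTorch sketch of the FC head is shown below (the 3D-ResNet backbone is omitted). The class name and layer attributes are ours, and `PMDNLayer` refers to the hypothetical sketch from Section 2.2; this mirrors the description above rather than reproducing the released code.

```python
import torch.nn as nn

class ClassifierHead(nn.Module):
    """FC-ReLU-FC-ReLU-FC head on the flattened 3D-ResNet output (size 2048),
    with LayerNorm + PMDN after the first two FC-ReLU blocks and a PMDN layer
    after the final FC; num_metadata counts site (one-hot), age, sex, and labels."""
    def __init__(self, num_metadata: int):
        super().__init__()
        self.fc1, self.fc2, self.fc3 = nn.Linear(2048, 128), nn.Linear(128, 16), nn.Linear(16, 2)
        self.ln1, self.ln2 = nn.LayerNorm(128), nn.LayerNorm(16)
        self.pmdn1 = PMDNLayer(num_metadata, 128)
        self.pmdn2 = PMDNLayer(num_metadata, 16)
        self.pmdn3 = PMDNLayer(num_metadata, 2)
        self.relu = nn.ReLU()

    def forward(self, z, m):
        # z: (B, 2048) flattened backbone features, m: (B, num_metadata) metadata
        h = self.pmdn1(self.ln1(self.relu(self.fc1(z))), m)
        h = self.pmdn2(self.ln2(self.relu(self.fc2(h))), m)
        return self.pmdn3(self.fc3(h), m)               # two output logits
```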
Evaluation.
We perform 5-fold cross-validation and report the results in Table 2; N=240 is the largest batch size feasible under our resource constraints. Based on the type of each metadata variable (continuous, binary, or categorical), we choose a different metric to investigate the metadata effect in the features. For age, we take the magnitude of Pearson’s correlation between the ages and each of the two output logits and report the average. For sex, we report the average magnitude of the point-biserial correlation between the sexes and each of the two output logits. For site, we compute the average dcor2 between the one-hot encoded site and each of the two output logits. Finally, we calculate the accuracy for each of the four groups separately and report the average.
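A minimal sketch of these per-variable metrics (the `dcor2` helper is the one sketched in Section 3.1; function and variable names are ours):

```python
import numpy as np
from scipy.stats import pearsonr, pointbiserialr

def metadata_correlations(logits, age, sex, site_onehot):
    """logits: (N, 2) output logits; age: (N,) continuous (z-scored); sex: (N,)
    binary; site_onehot: (N, num_sites). Returns averaged correlation magnitudes."""
    age_corr = np.mean([abs(pearsonr(age, logits[:, k])[0]) for k in range(2)])
    sex_corr = np.mean([abs(pointbiserialr(sex, logits[:, k])[0]) for k in range(2)])
    site_corr = np.mean([dcor2(site_onehot, logits[:, k]) for k in range(2)])
    return age_corr, sex_corr, site_corr
```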
For each batch size, PMDN achieves the highest accuracy, highlighting that PMDN mitigates the confounding effect and produces a less biased feature distribution. This is also evident in the low dcor2 values for PMDN. As expected, a small batch size significantly compromises MDN model performance. Additionally, we see that the adversarial training method [27] removed the confounding effects for sex and site, but the confounding effect of age remained, as indicated by the correlations. This observation underscores the inherent limitation of the adversarial method in controlling for multiple confounders, because each metadata variable requires a new adversarial component in the network; training multiple adversarial components sacrificed model accuracy.
| \|Batch\| | Metric | Baseline | BN | Adversarial | MDN | PMDN (Ours) |
|---|---|---|---|---|---|---|
| 20 | Age Corr | 0.431 | 0.382 | 0.408 | 0.235 | 0.213 |
|  | Sex Corr | 0.209 | 0.237 | 0.154 | 0.172 | 0.141 |
|  | Site Corr | 0.388 | 0.312 | 0.086 | 0.132 | 0.155 |
|  | Accuracy | 48.8% | 44.0% | 26.5% | 41.2% | 51.3% |
| 80 | Age Corr | 0.461 | 0.374 | 0.385 | 0.225 | 0.208 |
|  | Sex Corr | 0.259 | 0.220 | 0.214 | 0.189 | 0.187 |
|  | Site Corr | 0.402 | 0.285 | 0.127 | 0.126 | 0.172 |
|  | Accuracy | 49.7% | 42.1% | 25.8% | 41.2% | 50.7% |
| 160 | Age Corr | 0.488 | 0.384 | 0.543 | 0.241 | 0.199 |
|  | Sex Corr | 0.199 | 0.268 | 0.094 | 0.174 | 0.188 |
|  | Site Corr | 0.431 | 0.293 | 0.166 | 0.160 | 0.149 |
|  | Accuracy | 45.0% | 45.1% | 26.8% | 45.6% | 50.7% |
| 240 | Age Corr | 0.488 | 0.369 | 0.382 | 0.185 | 0.180 |
|  | Sex Corr | 0.226 | 0.225 | 0.172 | 0.173 | 0.166 |
|  | Site Corr | 0.456 | 0.275 | 0.110 | 0.137 | 0.158 |
|  | Accuracy | 43.5% | 42.2% | 27.8% | 46.2% | 51.9% |
Figure 3 visualizes the feature space (via tSNE [18]) after removing metadata effects with MDN and PMDN for the small batch size of 80. As can be seen, the embedding space does not show a clear pattern with respect to sex (i.e., it is independent of the sex variable). For the site variable, PMDN shows less of a clustering effect than MDN’s embedding. Note, however, that the three sites (UCSF, ADNI, SRI) differ considerably with respect to their class label distributions (see Table 1), which explains the moderate clustering in the embedding space. In conclusion, this analysis qualitatively illustrates that, of the two methods, PMDN is better at removing the confounding factor.

4 Conclusion
Herein, we introduce PMDN, a novel penalty method for removing bias in model training due to confounding factors. PMDN can be plugged into any neural network architecture and is independent of the batch size. By removing the effects of confounding relationships between training data and target outputs, PMDN minimizes the bias in the learned features. We show the improvement of PMDN, a layer with trainable parameters, over MDN, a layer with a closed-form solution, on a synthetic and a neuroimaging dataset. The improvements in accuracy and confounder independence achieved by PMDN represent an important step toward neuroscience, imaging, and clinical applications of machine learning prediction models.
4.0.1 Acknowledgements
This study was partially supported by NIH Grants (AA017347, MH113406, and MH098759) and Stanford Institute for Human-Centered AI (HAI) Google Cloud Platform (GCP) Credit.
References
- [1] Adeli, E., Kwon, D., Zhao, Q., Pfefferbaum, A., Zahr, N.M., Sullivan, E.V., Pohl, K.M.: Chained regularization for identifying brain patterns specific to hiv infection. Neuroimage 183, 425–437 (2018)
- [2] Adeli, E., Zhao, Q., Zahr, N.M., Goldstone, A., Pfefferbaum, A., Sullivan, E.V., Pohl, K.M.: Deep learning identifies morphological determinants of sex differences in the pre-adolescent brain. NeuroImage 223, 117293 (2020)
- [3] Agarwal, A., Kakade, S.M., Lee, J.D., Mahajan, G.: On the theory of policy gradient methods: Optimality, approximation, and distribution shift. Journal of Machine Learning Research 22(98), 1–76 (2021)
- [4] Ba, J.L., Kiros, J.R., Hinton, G.E.: Layer normalization. arXiv preprint arXiv:1607.06450 (2016)
- [5] Baharlouei, S., Nouiehed, M., Beirami, A., Razaviyayn, M.: Rényi fair inference. arXiv preprint arXiv:1906.12005 (2019)
- [6] Chen, J., Lu, Y., Yu, Q., Luo, X., Adeli, E., Wang, Y., Lu, L., Yuille, A.L., Zhou, Y.: Transunet: Transformers make strong encoders for medical image segmentation. arXiv preprint arXiv:2102.04306 (2021)
- [7] Delano-Wood, L., Bondi, M.W., Sacco, J., Abeles, N., Jak, A.J., Libon, D.J., Bozoki, A.: Heterogeneity in mild cognitive impairment: Differences in neuropsychological profile and associated white matter lesion pathology. Journal of the International Neuropsychological Society 15(6), 906–914 (2009)
- [8] Deshmukh, S., Khaparde, A.: Faster region-convolutional neural network oriented feature learning with optimal trained recurrent neural network for bone age assessment for pediatrics. Biomedical Signal Processing and Control 71, 103016 (2022)
- [9] Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner, T., Dehghani, M., Minderer, M., Heigold, G., Gelly, S., Uszkoreit, J., Houlsby, N.: An image is worth 16x16 words: Transformers for image recognition at scale. In: International Conference on Learning Representations (2021), https://openreview.net/forum?id=YicbFdNTTy
- [10] Hara, K., Kataoka, H., Satoh, Y.: Learning spatio-temporal features with 3d residual networks for action recognition. In: Proceedings of the IEEE International Conference on Computer Vision Workshops. pp. 3154–3160 (2017)
- [11] Ioffe, S., Szegedy, C.: Batch normalization: Accelerating deep network training by reducing internal covariate shift. In: International conference on machine learning. pp. 448–456. PMLR (2015)
- [12] Kingma, D.P., Ba, J.: Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980 (2014)
- [13] Lahiri, A., Alipour, K., Adeli, E., Salimi, B.: Combining counterfactuals with shapley values to explain image models. arXiv preprint arXiv:2206.07087 (2022)
- [14] Lin, T.Y., Goyal, P., Girshick, R., He, K., Dollár, P.: Focal loss for dense object detection. In: Proceedings of the IEEE international conference on computer vision. pp. 2980–2988 (2017)
- [15] Liu, T.Y., Kannan, A., Drake, A., Bertin, M., Wan, N.: Bridging the generalization gap: Training robust models on confounded biological data. arXiv preprint arXiv:1812.04778 (2018)
- [16] Liu, X., Li, B., Bron, E.E., Niessen, W.J., Wolvius, E.B., Roshchupkin, G.V.: Projection-wise disentangling for fair and interpretable representation learning: Application to 3d facial shape analysis. In: International Conference on Medical Image Computing and Computer-Assisted Intervention. pp. 814–823. Springer (2021)
- [17] Lu, M., Zhao, Q., Zhang, J., Pohl, K.M., Fei-Fei, L., Niebles, J.C., Adeli, E.: Metadata normalization. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. pp. 10917–10927 (2021)
- [18] Van der Maaten, L., Hinton, G.: Visualizing data using t-sne. Journal of machine learning research 9(11) (2008)
- [19] Neto, E.C.: Causality-aware counterfactual confounding adjustment for feature representations learned by deep models. arXiv preprint arXiv:2004.09466 (2020)
- [20] Petersen, R.C., Aisen, P., Beckett, L.A., Donohue, M., Gamst, A., Harvey, D.J., Jack, C., Jagust, W., Shaw, L., Toga, A., et al.: Alzheimer’s disease neuroimaging initiative (adni): clinical characterization. Neurology 74(3), 201–209 (2010)
- [21] Robbins, H., Monro, S.: A stochastic approximation method. The annals of mathematical statistics pp. 400–407 (1951)
- [22] Tartaglione, E., Barbano, C.A., Grangetto, M.: End: Entangling and disentangling deep representations for bias correction. In: Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. pp. 13508–13517 (2021)
- [23] Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A.N., Kaiser, Ł., Polosukhin, I.: Attention is all you need. Advances in neural information processing systems 30 (2017)
- [24] Yao, Z., Cao, Y., Lin, Y., Liu, Z., Zhang, Z., Hu, H.: Leveraging batch normalization for vision transformers. In: Proceedings of the IEEE/CVF International Conference on Computer Vision. pp. 413–422 (2021)
- [25] Yong, H., Huang, J., Meng, D., Hua, X., Zhang, L.: Momentum batch normalization for deep learning with small batch size. In: European Conference on Computer Vision. pp. 224–240. Springer (2020)
- [26] Zhang, Y., Kwon, D., Esmaeili-Firidouni, P., Pfefferbaum, A., Sullivan, E.V., Javitz, H., Valcour, V., Pohl, K.M.: Extracting patterns of morphometry distinguishing hiv associated neurodegeneration from mild cognitive impairment via group cardinality constrained classification. Human brain mapping 37(12), 4523–4538 (2016)
- [27] Zhao, Q., Adeli, E., Pohl, K.M.: Training confounder-free deep learning models for medical applications. Nature communications 11(1), 1–9 (2020)
- [28] Zhong, G., Wang, L.N., Ling, X., Dong, J.: An overview on data representation learning: From traditional feature learning to recent deep learning. The Journal of Finance and Data Science 2(4), 265–278 (2016)