Environment Diversification with Multi-head Neural Network for Invariant Learning
Abstract
Neural networks are often trained with empirical risk minimization; however, it has been shown that a shift between training and testing distributions can cause unpredictable performance degradation. To address this issue, invariant learning has been proposed as a research direction that extracts invariant features insensitive to distributional changes. This work proposes EDNIL, an invariant learning framework containing a multi-head neural network to absorb data biases. We show that this framework does not require prior knowledge about environments or strong assumptions about the pre-trained model. We also reveal that the proposed algorithm has theoretical connections to recent studies discussing properties of variant and invariant features. Finally, we demonstrate that models trained with EDNIL are empirically more robust against distributional shifts.
1 Introduction
Ensuring model performance on unseen data is a common yet challenging task in machine learning. A widely adopted approach is Empirical Risk Minimization (ERM), where training and testing data are assumed to be independent and identically distributed (i.i.d.). However, data in real-world applications can come with undesired biases, causing a shift between training and testing distributions. Such distributional shifts are known to severely harm ERM model performance and can even make the trained model worse than random prediction [10]. In this work, we focus on invariant learning, which aims to learn invariant features that are expected to be robust against distributional changes. Invariant Risk Minimization (IRM) [1] has been proposed as a popular solution for invariant learning. Specifically, IRM assumes that training data are collected from multiple sources, or environments, with distinct data distributions. The learning objective is then designed as a standard ERM loss with a penalty term constraining the trained model (e.g., a classifier) to be optimal in all environments.
IRM has been shown to be effective; however, we note that IRM and many invariant learning methods rely on strict assumptions that limit their practical impact. The limitations are summarized as follows.
Prior Knowledge of Environments. IRM assumes that training data are collected from multiple environments and that the environment labels (i.e., which data instance belongs to which environment) are given. However, environment labels are often unavailable. Moreover, the definition of environments can be implicit, making human labeling more difficult and expensive. To find environments without supervision, Creager et al. [6] propose Environment Inference for Invariant Learning (EIIL), a min-max optimization framework that trains the model of interest and infers the environment labels. Another work, Heterogeneous Risk Minimization (HRM) [22], and its extension, Kernelized Heterogeneous Risk Minimization (KerHRM) [23], parameterize the environments and propose clustering-based approaches to estimate the parameters. A recent method, ZIN [20], also learns to label data. However, it relies on carefully chosen features satisfying a series of theoretical constraints, so human effort is still required.
Delicate Initialization. EIIL is able to infer environments but requires an ERM model for initialization. Crucially, this ERM model should depend heavily on spurious correlations. Creager et al. [6] reveal, for example, that slightly underfitted ERM models may encode more spurious relationships in some cases. However, since the distributional shifts are assumed to be unknown at training time, appropriate initialization can be difficult to guarantee.
Efficiency Issue. HRM and KerHRM do not have the above two limitations, but they suffer from efficiency issues. Specifically, HRM is assumed to be trained on raw features. Although KerHRM extends HRM by adopting kernel methods to avoid the issue of representation learning, its computational costs can be very high when the data or model size is large.
This work proposes a novel framework, Environment Diversification with multi-head neural Network for Invariant Learning (EDNIL). EDNIL is able to infer environment labels without supervision and achieve joint optimization of environment inference and invariant learning models. The underlying multi-head neural network explicitly diversifies the inferred environments, which is consistent with recent studies [5, 22, 23] revealing the benefits of diverse environments. Notably, the proposed neural network is functionally similar to a multi-class classifier and can be optimized efficiently. The advantages of EDNIL are summarized as:
- EDNIL mitigates the three limitations discussed above. The comparisons between EDNIL and other methods are shown in Table 1.
Table 1: Comparison between EDNIL and other methods.
| Method | Unsupervised¹ | Insensitive Initialization | Efficiency |
|---|---|---|---|
| IRM [1] | ✗ | ✓ | ✓ |
| ZIN [20] | ✗ | ✓ | ✓ |
| EIIL [6] | ✓ | ✗ | ✓ |
| HRM [22] | ✓ | ✓ | ✗ |
| KerHRM [23] | ✓ | ✓ | ✗ |
| EDNIL (Ours) | ✓ | ✓ | ✓ |
¹ A method is unsupervised if it does not require extra human efforts to obtain environments.
2 Preliminaries and Related Work
The goal of EDNIL is to tackle out-of-distribution problems with invariant learning in the absence of manual environment labels. Section 2.1 introduces background on out-of-distribution generalization and invariant learning. Section 2.2 discusses recent studies investigating ideal environments. Section 2.3 introduces existing unsupervised methods for inferring environments.
2.1 Out-of-Distribution Generalization and Invariant Learning
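Following [1, 22], we consider a dataset $D = \{D^e\}_{e \in \mathcal{E}_{tr}}$ with different sources $D^e = \{(x_i^e, y_i^e)\}_{i=1}^{n_e}$ collected under multiple training environments $e \in \mathcal{E}_{tr}$. Random variable $E_{tr}$ indicates training environments. For simplicity, $X^e$, $Y^e$ and $P^e$ denote features, target and distribution in environment $e$, respectively.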
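With $\mathcal{E}$ containing all possible environments such that $\mathcal{E}_{tr} \subset \mathcal{E}$, the goal of out-of-distribution generalization is to learn a predictor $f^*$ as in Equation 1, where $R^e(f) = \mathbb{E}_{(X^e, Y^e) \sim P^e}\left[\ell(f(X^e), Y^e)\right]$ measures the risk under environment $e$ with any loss function $\ell$. In general, for $e \in \mathcal{E}_{tr}$ and $e' \in \mathcal{E} \setminus \mathcal{E}_{tr}$, $P^{e'}$ is rather different from $P^{e}$.
$$f^* = \arg\min_{f} \max_{e \in \mathcal{E}} R^e(f) \tag{1}$$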
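As a minimal sketch of the objective in Equation 1, the Python below computes the empirical risk $R^e$ of a fixed predictor on a few toy environments and takes their maximum. It only covers environments that are actually observed, whereas the maximum in Equation 1 ranges over all of $\mathcal{E}$, including environments unseen at training time. The function names and the toy data are hypothetical.

```python
from typing import Callable, Dict, Sequence, Tuple

Example = Tuple[float, int]  # toy (feature, label) pair


def empirical_risk(predict: Callable[[float], float],
                   data: Sequence[Example],
                   loss: Callable[[float, int], float]) -> float:
    """Average loss of `predict` on one environment's samples, i.e. R^e(f)."""
    return sum(loss(predict(x), y) for x, y in data) / len(data)


def worst_case_risk(predict: Callable[[float], float],
                    envs: Dict[str, Sequence[Example]],
                    loss: Callable[[float, int], float]) -> float:
    """max_e R^e(f) over the environments that can be observed."""
    return max(empirical_risk(predict, data, loss) for data in envs.values())


def zero_one(score: float, y: int) -> float:
    """0-1 loss for a score thresholded at 0.5."""
    return float((score >= 0.5) != bool(y))


def identity_score(x: float) -> float:
    """Toy predictor that uses the feature itself as its score."""
    return x


envs = {
    "e1": [(0.9, 1), (0.2, 0), (0.7, 1), (0.4, 0)],
    "e2": [(0.6, 0), (0.3, 1), (0.8, 1), (0.1, 0)],
}
print(worst_case_risk(identity_score, envs, zero_one))  # 0.5, driven by e2
```

The predictor is perfect on `e1` but errs on half of `e2`, so the worst-case risk is set entirely by the harder environment.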
Recently, several studies [1, 26, 28, 18, 4] have attempted to tackle the generalization problems by discovering invariant relationships across all environments. A commonly proposed assumption is the existence of invariant features $\Phi^*$ and variant features $\Psi^*$. Specifically, raw features $X$ are assumed to be composed of $\Phi^*$ and $\Psi^*$, or $X = h(\Phi^*, \Psi^*)$, where $h$ is a transformation function. Invariant features $\Phi^*$ are assumed to be equally informative for predicting targets across environments $e \in \mathcal{E}$. On the contrary, the distribution $P^e(Y \mid \Psi^*)$ can arbitrarily vary across $e \in \mathcal{E}$. As a result, predictors depending on $\Psi^*$ can have unpredictable performance in unseen environments. In particular, the correlations between $\Psi^*$ and $Y$ are known as spurious and unreliable.
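To extract $\Phi^*$, IRM [1] assumes there is an encoder $\Phi$ for obtaining representations $\Phi(X)$. The encoder is trained with a regularization term enforcing simultaneous optimality of the predictor in all training environments, where the dummy classifier $w = 1.0$ is a fixed multiplier for the encoder outputs:
$$\min_{\Phi} \sum_{e \in \mathcal{E}_{tr}} R^e(\Phi) + \lambda \left\| \nabla_{w \mid w = 1.0} R^e(w \cdot \Phi) \right\|^2 \tag{2}$$
As $w$ is a dummy layer, the encoder $\Phi$ is also regarded as a predictor.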
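To make the penalty in Equation 2 concrete, the sketch below evaluates it in closed form for a binary logistic loss, where the derivative with respect to the dummy scale $w$ is available analytically: $\frac{d}{dw}\ell(wz, y) = (\sigma(wz) - y)\,z$ for a logit $z = \Phi(x)$ and a label $y \in \{0, 1\}$. This is an illustrative special case under that assumption; practical implementations usually obtain the gradient with automatic differentiation. The names `irm_penalty` and `irm_objective` are hypothetical.

```python
import math
from typing import Dict, List, Sequence, Tuple


def sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def bce_with_logit(z: float, y: int) -> float:
    """Binary cross entropy log(1 + exp(z)) - y * z, written stably."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z))) - y * z


def irm_penalty(logits: Sequence[float], labels: Sequence[int]) -> float:
    """Squared gradient of R^e(w * Phi) with respect to w, evaluated at w = 1.0."""
    grad = sum((sigmoid(z) - y) * z for z, y in zip(logits, labels)) / len(logits)
    return grad ** 2


def irm_objective(envs: Dict[str, Tuple[List[float], List[int]]], lam: float) -> float:
    """Equation 2 for fixed encoder outputs: sum over environments of R^e + lam * penalty."""
    total = 0.0
    for logits, labels in envs.values():
        risk = sum(bce_with_logit(z, y) for z, y in zip(logits, labels)) / len(logits)
        total += risk + lam * irm_penalty(logits, labels)
    return total
```

A penalty of zero means the fixed scale $w = 1.0$ is already optimal for that environment (the logistic loss is convex in $w$), which is exactly the simultaneous-optimality condition that IRM enforces across environments.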
2.2 Ideal Environments
As $\mathcal{E}_{tr}$ is unavailable or sub-optimal in most applications, learning to find appropriate environments (denoted by $\mathcal{E}_{learn}$) is attractive. However, the challenge is the lack of knowledge about valid environments. Recently, Lin et al. [20] have proposed Equations 3 and 4 as the conditions of ideal environments, where $H(\cdot \mid \cdot)$ is conditional entropy and $E_{learn}$ is the random variable indicating the learned environments. To satisfy the conditions, Lin et al. [20] propose leveraging auxiliary information for model training. However, the method still requires extra human effort to collect and verify the additional information.
$$H(Y \mid \Phi^*, E_{learn}) = H(Y \mid \Phi^*) \tag{3}$$
$$H(Y \mid \Psi^*, E_{learn}) < H(Y \mid \Psi^*) \tag{4}$$
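The two conditions can be checked on toy discrete data with a plug-in estimate of conditional entropy. The sketch below assumes that $\Phi^*$, $\Psi^*$ and the environment are all observed, which is not the case in practice: `phi` predicts $Y$ equally well in both environments, while `psi` agrees with $Y$ in one environment and disagrees in the other. Conditioning on the environment then leaves $H(Y \mid \Phi^*)$ unchanged (Equation 3) and strictly lowers $H(Y \mid \Psi^*)$ (Equation 4). All names and data are hypothetical.

```python
import math
from collections import Counter
from typing import Hashable, List, Sequence


def cond_entropy(y: Sequence[Hashable], given: Sequence[Hashable]) -> float:
    """Plug-in estimate of H(Y | G) in nats: -sum p(g, y) log p(y | g)."""
    n = len(y)
    joint = Counter(zip(given, y))
    marginal = Counter(given)
    return -sum(c / n * math.log(c / marginal[g]) for (g, _), c in joint.items())


# Toy data: two environments with 8 samples each.
ys: List[int] = []
phis: List[int] = []
psis: List[int] = []
envs: List[int] = []
for env in (0, 1):
    for label in (0, 1):
        for k in range(4):
            ys.append(label)
            phis.append(label if k < 3 else 1 - label)     # 75% agreement in both envs
            psis.append(label if env == 0 else 1 - label)  # correlation flips across envs
            envs.append(env)

h_phi = cond_entropy(ys, phis)
h_phi_env = cond_entropy(ys, list(zip(phis, envs)))
h_psi = cond_entropy(ys, psis)
h_psi_env = cond_entropy(ys, list(zip(psis, envs)))
print(math.isclose(h_phi, h_phi_env), h_psi_env < h_psi)  # True True
```

Pooled over both environments, `psi` carries no information about $Y$ ($\log 2$ nats of uncertainty remain), yet within each environment it determines $Y$ exactly; this is the kind of diversity that Equation 4 asks an inferred partition to expose.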
In particular, Equation 4 is supported by empirical studies [5] in which the diversity of environments is recognized as the key to obtaining effective IRM models. More precisely, a large discrepancy of spurious correlations, i.e., of $P(Y \mid \Psi^*)$, between environments is favored. When the environments give a clear indication of distributional shifts, IRM can easily identify and eliminate variant features. Beyond IRM, HRM [22] and KerHRM [23] can also be viewed as optimizing diversity, specifically via clustering-based methods.
2.3 Unsupervised Environment Inference
Here we provide a detailed introduction to the existing environment inference methods that do not require extra human effort. The general idea is to integrate environment inference with invariant learning algorithms that require given environments.
EIIL [6] formulates invariant learning as a min-max optimization problem. Specifically, EIIL is composed of two objectives, Environment Inference (EI) and Invariant Learning (IL), where EI is optimized by maximizing the training penalty through labeling the data, and IL is optimized by minimizing the training loss given the data labeled by EI. The two-stage framework bypasses the difficulty of defining environments; however, the training result relies heavily on the initialization of the optimization. Specifically, the initialization demands a strongly biased ERM reference model; otherwise, EIIL can perform significantly worse.
Another method, HRM [22], proposes a clustering-based method for learning plausible environments. HRM assumes that spurious correlations in each environment can be modeled by a parameterized function and that the dataset is generated by a mixture of these functions. The parameters are then estimated with the EM algorithm. Additionally, HRM adopts a joint learning framework that alternately learns invariant predictors and improves the quality of the clustering results. However, a known issue of HRM is its assumption that the data are represented by raw features. Data such as images and texts, which require non-linear neural networks to obtain representations, are beyond its capability.
To extend HRM to a broader class of applications and improve the model performance, Liu et al. [23] propose KerHRM. The main idea is to adopt the Neural Tangent Kernel [15] method which transforms non-linear neural network training into a linear regression problem on the proposed Neural Tangent Features space. As a result, KerHRM elegantly resolves the shortcomings of HRM and is shown to be more effective. However, the proposed method and its implementations bring additional computational costs depending on data and model capacity. For applications favoring large datasets and pre-trained models, such as Resnet [14] and BERT [7], KerHRM may not be an affordable option at the present stage.
3 Methodology


In this section, we propose a general framework to learn an invariant model without manual environment labels. As shown in Figure 1(a), our proposed method consists of two models, $M_{EI}$ and $M_{IL}$. Given the pooled data $D$, $M_{EI}$ infers environments $\mathcal{E}_{learn}$ satisfying Conditions 3 and 4, and $M_{IL}$ is an invariant model trained with the inferred environments. Our framework is jointly optimized with alternating updates. The learned $M_{IL}$ can provide information about invariant features to $M_{EI}$, so that Conditions 3 and 4 can be fulfilled simultaneously. Note that $M_{EI}$ only serves at training time to provide environments for invariant learning. At test time, only $M_{IL}$ is needed to perform invariant predictions.
3.1 The Environment Inference Model

The target of environment inference is to partition data into environments satisfying Condition 3 and 4. In this regard, we propose a graphical model, which is a sufficient condition for Condition 3 and 4 (the proof is in Appendix A), as our foundation of inference model and learning objectives. The graph is shown in Figure 2, where the data generation process follows the proposed example in [1].
The inference model, $M_{EI}$, is a neural network learning to realize the underlying graphical model. Following the idea of parameterizing environments from HRM [22] and KerHRM [23], we assume that the distinct mapping relation between $\Psi^*$ and $Y$ in environment $e$ can be modeled by a function $f_e(\Psi(x))$, where $\Psi(x)$ is a learned representation expected to encode variant features and $f_e$ is an environmental function responsible for predicting $Y$. Instead of employing clusters, we propose building a multi-head neural network as shown in Figure 1(b); in particular, a single head of $M_{EI}$, together with the shared parameters of $\Psi$, corresponds to a cluster center in HRM or KerHRM. The training procedure of $M_{EI}$ can be divided into an inference stage and a learning stage.
3.1.1 Inference Stage of $M_{EI}$
The goal is to infer an environment label for each training sample. As in the graphical model, $E_{learn}$ is associated with variant relationships. Inspired by multi-class classification, we propose Equation 5, where the probability $P(e \mid x, y)$ is estimated via a softmax of the negative loss $\ell(f_e(\Psi(x)), y)$ divided by a constant temperature $\tau$. The loss function $\ell$ is expected to be the commonly used cross entropy or mean squared error, which measures the discrepancy between $f_e(\Psi(x))$ and $y$ for each environment $e$. Intuitively, each sample prefers the environment whose head predicts it best.
$$P(e \mid x, y) = \frac{\exp\left(-\ell(f_e(\Psi(x)), y) / \tau\right)}{\sum_{e'} \exp\left(-\ell(f_{e'}(\Psi(x)), y) / \tau\right)} \tag{5}$$
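A minimal sketch of the inference stage in Equation 5, assuming the per-head losses $\ell(f_e(\Psi(x)), y)$ of one sample have already been computed; `env_posterior` and `infer_env` are hypothetical names. Subtracting the maximum before exponentiating leaves the softmax unchanged but avoids overflow.

```python
import math
from typing import List, Sequence


def env_posterior(head_losses: Sequence[float], tau: float = 1.0) -> List[float]:
    """P(e | x, y) proportional to exp(-loss_e / tau) over the environment heads."""
    scaled = [-l / tau for l in head_losses]
    m = max(scaled)  # stabilizer: softmax is invariant to a constant shift
    exps = [math.exp(s - m) for s in scaled]
    z = sum(exps)
    return [v / z for v in exps]


def infer_env(head_losses: Sequence[float], tau: float = 1.0) -> int:
    """Hard label: the environment whose head has the smallest loss on this sample."""
    probs = env_posterior(head_losses, tau)
    return max(range(len(probs)), key=probs.__getitem__)


print(env_posterior([0.0, math.log(3.0)]))  # approximately [0.75, 0.25]
print(infer_env([0.9, 0.1, 2.0]))           # 1
```

A smaller temperature $\tau$ sharpens the posterior toward the best-fitting head, while a larger one spreads probability mass more evenly across heads.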
3.1.2 Learning Stage of $M_{EI}$
The goal is to update the neural network to improve the quality of inference. Based on the structure of the graphical model, three losses are designed for minimization: the Environment Diversification Loss ($\mathcal{L}_{ED}$), the Label Independence Loss ($\mathcal{L}_{LI}$) and the Invariance Preserving Loss ($\mathcal{L}_{IP}$). In particular, $\mathcal{L}_{ED}$ and $\mathcal{L}_{IP}$ correspond to the concepts of Conditions 4 and 3, respectively.
Environment Diversification Loss ($\mathcal{L}_{ED}$). We consider maximizing $I(Y; E_{learn} \mid \Psi)$ to satisfy Condition 4 and capture more diverse variant relationships. Given the estimated $P(e \mid x, y)$, $\mathcal{L}_{ED}$ selects the most probable environment $\hat{e}$ and its corresponding network for optimization:
$$\mathcal{L}_{ED} = \mathbb{E}_{(x, y)}\left[-\log P(\hat{e} \mid x, y)\right], \qquad \hat{e} = \arg\max_{e} P(e \mid x, y) \tag{6}$$
For each sample $(x, y)$, although only one environment $\hat{e}$ is selected for the minimization, the softmax simultaneously propagates gradients that maximize $\ell(f_{e'}(\Psi(x)), y)$ for $e' \neq \hat{e}$. The network thus learns to maximize the dependency between $E_{learn}$ and $Y$ given variant representations. In terms of spurious correlations, the distinction between environments is expected to become clearer accordingly. In practice, we scale each term by a weight inversely proportional to the size of environment $\hat{e}$, which enhances the importance of smaller environments within the summation.
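The value of the Environment Diversification Loss, including the inverse-size weighting, can be sketched as follows. The normalization, which divides each term by $K \cdot |D_{\hat{e}}|$ with $K$ the number of non-empty inferred environments so that every environment contributes equally, is an assumption, as are the function names. The sketch only computes the loss value; during training, the gradient of $-\log P(\hat{e} \mid x, y)$ through the softmax lowers the loss of head $\hat{e}$ and raises the losses of all other heads.

```python
import math
from collections import Counter
from typing import List, Sequence


def env_posterior(head_losses: Sequence[float], tau: float = 1.0) -> List[float]:
    """Softmax of -loss / tau over the environment heads."""
    scaled = [-l / tau for l in head_losses]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    z = sum(exps)
    return [v / z for v in exps]


def ed_loss(all_head_losses: Sequence[Sequence[float]], tau: float = 1.0) -> float:
    """Environment Diversification Loss with inverse-size weights (assumed normalization)."""
    posts = [env_posterior(losses, tau) for losses in all_head_losses]
    hard = [max(range(len(p)), key=p.__getitem__) for p in posts]  # e_hat per sample
    sizes = Counter(hard)
    k = len(sizes)  # number of non-empty inferred environments
    total = 0.0
    for p, e_hat in zip(posts, hard):
        # Weight 1 / (k * |D_e_hat|): a small environment counts as much as a large one.
        total += -math.log(p[e_hat]) / (k * sizes[e_hat])
    return total
```

With this normalization the loss equals the unweighted mean, over inferred environments, of each environment's average $-\log P(\hat{e} \mid x, y)$.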
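Label Independence Loss ($\mathcal{L}_{LI}$). By d-separation [25], $E_{learn}$ is independent of $Y$ in the graphical model. Hence, $\mathcal{L}_{LI}$ constrains their dependency as measured by the mutual information $I(Y; E_{learn})$. Since $I(Y; E_{learn}) = H(Y) - H(Y \mid E_{learn})$, it can be verified that minimizing $I(Y; E_{learn})$ is equivalent to minimizing Equation 7 given that $P(Y)$ is known. Empirically, $\mathcal{L}_{LI}$ prevents a trivial solution in which environments are determined purely by target labels regardless of input features, which is undesirable for invariant learning.
$$\mathcal{L}_{LI} = -H(Y \mid E_{learn}) = \sum_{e} P(e) \sum_{y} P(y \mid e) \log P(y \mid e) \tag{7}$$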
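The sketch below estimates $-H(Y \mid E_{learn})$ from soft environment posteriors $P(e \mid x, y)$, assuming that $P(e)$ and $P(y, e)$ are obtained by averaging the posteriors over a batch; this estimator and the function name are assumptions. Because $H(Y)$ is fixed by the data, lowering this value lowers $I(Y; E_{learn})$.

```python
import math
from collections import defaultdict
from typing import Dict, Sequence, Tuple


def label_independence_loss(posteriors: Sequence[Sequence[float]],
                            labels: Sequence[int]) -> float:
    """-H(Y | E) = sum_{y,e} P(y, e) log P(y | e), using soft environment assignments."""
    n = len(labels)
    p_e: Dict[int, float] = defaultdict(float)
    p_ye: Dict[Tuple[int, int], float] = defaultdict(float)
    for post, y in zip(posteriors, labels):
        for e, q in enumerate(post):
            p_e[e] += q / n          # batch estimate of P(e)
            p_ye[(y, e)] += q / n    # batch estimate of P(y, e)
    return sum(p * math.log(p / p_e[e]) for (_, e), p in p_ye.items() if p > 0)


by_label = label_independence_loss([[1.0, 0.0], [0.0, 1.0]], [0, 1])
mixed = label_independence_loss([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]],
                                [0, 1, 0, 1])
print(by_label, mixed)  # 0.0 and about -0.693
```

Environments that simply copy the label reach the largest possible value (0), while label-balanced environments reach $-\log 2$, so minimizing the loss pushes the inference model away from the trivial label-based partition.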
Invariance Preserving Loss ($\mathcal{L}_{IP}$). For Condition 3, as $M_{IL}$ learns some invariant relationships after several training steps (Section 3.2), it can be used to exclude invariant features from the diversification. Specifically, $\mathcal{L}_{IP}$, designed as the counterpart of $\mathcal{L}_{ED}$, limits the variance across environments of the expected loss of the invariant predictor in $M_{IL}$ (Equation 8). A similar idea can be found in [17]. However, instead of training an invariant model given known environments, here we freeze the invariant predictor and regularize the adjustment of $\mathcal{E}_{learn}$ (i.e., the updates of $M_{EI}$).
$$\mathcal{L}_{IP} = \operatorname{Var}_{e \in \mathcal{E}_{learn}}\left(R^e(M_{IL})\right) \tag{8}$$
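In summary, the training loss of $M_{EI}$ is the Environment Inference Loss ($\mathcal{L}_{EI}$), in which the regularization strengths of $\mathcal{L}_{LI}$ and $\mathcal{L}_{IP}$ are controlled by hyper-parameters $\beta$ and $\gamma$, respectively:
$$\mathcal{L}_{EI} = \mathcal{L}_{ED} + \beta\,\mathcal{L}_{LI} + \gamma\,\mathcal{L}_{IP} \tag{9}$$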
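Equations 8 and 9 can be sketched together. The sketch assumes that the environment risks in Equation 8 are computed from the frozen invariant predictor's per-sample losses, weighted by the soft posteriors $P(e \mid x, y)$, and that the variance is the population variance across environments; these choices and all names are assumptions. The combined objective follows $\mathcal{L}_{EI} = \mathcal{L}_{ED} + \beta\,\mathcal{L}_{LI} + \gamma\,\mathcal{L}_{IP}$.

```python
from statistics import pvariance
from typing import Sequence


def invariance_preserving_loss(il_losses: Sequence[float],
                               posteriors: Sequence[Sequence[float]]) -> float:
    """Variance across inferred environments of the frozen invariant model's risk."""
    risks = []
    for e in range(len(posteriors[0])):
        mass = sum(p[e] for p in posteriors)
        if mass == 0:
            continue  # skip environments that received no probability mass
        risks.append(sum(p[e] * l for p, l in zip(posteriors, il_losses)) / mass)
    return pvariance(risks) if len(risks) > 1 else 0.0


def environment_inference_loss(l_ed: float, l_li: float, l_ip: float,
                               beta: float, gamma: float) -> float:
    """L_EI = L_ED + beta * L_LI + gamma * L_IP."""
    return l_ed + beta * l_li + gamma * l_ip
```

If a partition separates samples on which the invariant model does well from those on which it does poorly, the environment risks differ and $\mathcal{L}_{IP}$ grows; this discourages diversifying along invariant features.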
In addition, before minimizing $\mathcal{L}_{EI}$, we pre-train $\Psi$ and one arbitrary head $f_e$ with ERM, which empirically facilitates better feature extraction. Unlike EIIL [6], which takes as reference an ERM model that relies heavily on variant features, EDNIL performs more consistently under various choices of ERM. Namely, the initialization of EDNIL can be more arbitrary than that of EIIL. We verify this argument in Section 4.
3.2 The Invariant Learning Model
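To identify invariance across environments, IRM [1] is selected as our base algorithm for optimizing the invariant predictor in $M_{IL}$. To obtain the environment partitions required during training, we assign each sample $(x, y)$ the environment label $\hat{e}$ with the largest $P(e \mid x, y)$ inferred by $M_{EI}$ (Section 3.1.1). However, some noise in automatically inferred environments is inevitable, especially at the beginning of joint optimization. To reduce the impact of immature environments on invariant learning, we calculate a confidence score $c_e$ for each environment $e$ (Equation 10). Our training objective is modified to minimize the Invariant Learning Loss ($\mathcal{L}_{IL}$), which takes a weighted average of the environmental losses (Equation 11):
$$c_e = \frac{1}{|D_e|} \sum_{(x, y) \in D_e} P(e \mid x, y), \qquad D_e = \{(x, y) : \hat{e} = e\} \tag{10}$$
$$\mathcal{L}_{IL} = \sum_{e \in \mathcal{E}_{learn}} \frac{c_e}{\sum_{e'} c_{e'}} \left[ R^e(\Phi) + \lambda \left\| \nabla_{w \mid w = 1.0} R^e(w \cdot \Phi) \right\|^2 \right] \tag{11}$$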
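A sketch of the confidence-weighted objective, assuming that the confidence score $c_e$ is the mean posterior $P(e \mid x, y)$ over the samples hard-assigned to environment $e$, and that the scores are normalized to sum to one before weighting. The per-environment IRM losses (risk plus penalty, as in Equation 2) are taken as given; the helper names are hypothetical.

```python
from typing import Dict, List, Sequence


def confidence_scores(posteriors: Sequence[Sequence[float]]) -> Dict[int, float]:
    """Mean posterior P(e | x, y) over the samples hard-assigned to each environment e."""
    assigned: Dict[int, List[float]] = {}
    for p in posteriors:
        e_hat = max(range(len(p)), key=p.__getitem__)  # label with the largest P(e | x, y)
        assigned.setdefault(e_hat, []).append(p[e_hat])
    return {e: sum(v) / len(v) for e, v in assigned.items()}


def invariant_learning_loss(env_irm_losses: Dict[int, float],
                            scores: Dict[int, float]) -> float:
    """Confidence-weighted average of per-environment IRM losses (weights sum to 1)."""
    total = sum(scores.values())
    return sum(scores[e] / total * env_irm_losses[e] for e in scores)


s = confidence_scores([[0.9, 0.1], [0.7, 0.3], [0.4, 0.6]])
print(s)  # environment 0 has mean confidence 0.8, environment 1 has 0.6
print(invariant_learning_loss({0: 1.0, 1: 2.0}, s))
```

Environments whose members were assigned with low confidence, as is typical early in joint optimization, contribute proportionally less to the invariant learning objective.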
4 Experiments
We empirically validate the proposed method on four biased datasets: Adult-Confounded, CMNIST, Waterbirds and SNLI. The definition of spurious correlations mainly follows the protocols proposed by [6, 1, 29, 9]. In Section 4.1, Adult-Confounded and CMNIST are tested with a Multilayer Perceptron (MLP). In Section 4.2, two more complex datasets, Waterbirds and SNLI, are used to evaluate the integration with transfer learning, where deep pre-trained models are fine-tuned to extract representations.
The following four methods are selected as our competitors: Empirical Risk Minimization (ERM), Environment Inference for Invariant Learning (EIIL [6]), Kernelized Heterogeneous Risk Minimization (KerHRM [23]) and Invariant Risk Minimization (IRM [1], Equation 2). EIIL and KerHRM are invariant learning methods with unsupervised environment inference, which share the same settings as EDNIL. HRM [22] is replaced by KerHRM, which supports non-linear models. Since IRM requires environment partitions, we re-label $\mathcal{E}_{tr}$ on each biased training set so that the spurious relationships are diversified, which elicits the upper-bound performance of IRM.
For hyper-parameter tuning, we utilize an in-distribution validation set composed of 10% of training data. In each dataset, several testing environments with different distributions are listed to evaluate the robustness of each method, and we mainly take worst-case performance for assessment. As all tasks in this section are classification problems, accuracy is adopted as the evaluation metric.
Further experimental details are given in Appendix B. Additional experiments are discussed in Appendix C, including a regression problem and model stability under different spurious strengths at training time.
4.1 Simple Datasets with MLP
This section includes two simple datasets, Adult-Confounded and CMNIST, where spurious correlations are synthetically produced with predefined strengths. For all competitors, an MLP is taken as the base model and full-batch training is used. Since KerHRM performs unstably over random seeds, we report two scores for it: the average over 10 runs, and the average over the top 5 of these runs, which is marked with an asterisk (∗) in each table.
4.1.1 Discussions on Adult-Confounded
We take UCI Adult [16] to predict binarized income levels (above or below USD 50,000 per year)². Following [6], individuals are re-sampled according to the sensitive features race and sex to simulate spurious correlations. Specifically, with binarized race (Black/Non-Black) and sex (Female/Male), four subgroups are constructed: Non-Black Males (SG1), Non-Black Females (SG2), Black Males (SG3), and Black Females (SG4). Keeping the original train/test split and subgroup sizes from UCI Adult, we sample data based on the target distributions given for each sensitive subgroup in Table 2. In particular, the OOD test set yields the worst-case performance and validates whether the predictions rely on group information. In this task, an MLP with one hidden layer of 96 neurons is used. For IRM, $\mathcal{E}_{tr}$ comprises four environments, where the correlations between variant features and target are distributed without overlapping. More details are provided in Appendix B.
² The UCI Adult dataset is widely used in algorithmic fairness papers. However, a recent study [8] discusses some limitations of this dataset, such as the choice of income threshold.
Results. The results are shown in Table 3. With strong spurious correlations at training time, ERM obtains high accuracy on the IID test set, where the correlations remain aligned; however, its generalization to other testing distributions is limited. Among all invariant learning methods without prior environment labels, EDNIL can perfectly identify variant features and infer ideally diversified environments. Therefore, it achieves the most invariant performance across testing distributions. On the other hand, EIIL improves consistency to some degree, but not as much as EDNIL. One possible reason is that the empirically trained reference model is not guaranteed to be purely variant [6]. KerHRM performs inconsistently across random seeds, which is reflected in its large standard deviation. In some runs, its performance hardly improves over iterations, as also observed by Liu et al. [23].
Ablation Study for $\mathcal{L}_{EI}$. We first show the importance of $\mathcal{L}_{LI}$, which constrains label dependency, by setting its coefficient $\beta$ to zero. As discussed in Section 3, the resulting environments are then determined purely by target labels, leading to the inferior, ERM-like performance shown in Table 3. Next, we demonstrate the effectiveness of joint optimization in Figure 3(a). The regularization $\mathcal{L}_{IP}$ promotes environment inference by considering existing invariant relationships, so that the worst-case performance improves and remains stable over iterations. According to Table 3, if the coefficient $\gamma$ is set to zero, the feedback generated by $M_{IL}$ is ignored and invariant learning becomes less effective.
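Table 2: Target distribution of each sensitive subgroup in Adult-Confounded.
| Subgroup | Train | Test (IID) | Test (IND) | Test (OOD) |
|---|---|---|---|---|
| SG1 | 0.9 | 0.9 | 0.5 | 0.1 |
| SG2 | 0.1 | 0.1 | 0.5 | 0.9 |
| SG3 | 0.9 | 0.9 | 0.5 | 0.1 |
| SG4 | 0.1 | 0.1 | 0.5 | 0.9 |
Table 3: Accuracy (%) on Adult-Confounded (mean ± standard deviation).
| Method | IID | IND | OOD |
|---|---|---|---|
| ERM | 92.4 ± 0.1 | 66.8 ± 0.3 | 40.7 ± 0.5 |
| EIIL | 76.2 ± 0.4 | 73.5 ± 0.5 | 70.2 ± 1.7 |
| KerHRM | 82.4 ± 3.9 | 75.1 ± 4.0 | 67.9 ± 9.3 |
| KerHRM∗ | 81.2 ± 1.8 | 78.5 ± 0.3 | 75.6 ± 1.9 |
| EDNIL | 80.7 ± 0.4 | 79.1 ± 0.4 | 77.5 ± 0.3 |
| EDNIL (β = 0) | 91.8 ± 0.0 | 66.7 ± 0.1 | 41.3 ± 0.7 |
| EDNIL (γ = 0) | 78.2 ± 2.4 | 75.4 ± 1.6 | 72.5 ± 3.3 |
| IRM | 79.9 ± 0.4 | 79.3 ± 0.3 | 78.8 ± 0.4 |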
4.1.2 Discussions on CMNIST
We report our evaluation on a noisy digit recognition dataset, CMNIST. Following [1], we first assign $\tilde{y} = 0$ to images whose digits are smaller than 5 and $\tilde{y} = 1$ to the others. Next, we obtain the label $y$ by applying label noise, randomly flipping $\tilde{y}$ with probability 0.2. Finally, the digits are colored with color labels $c$, which are generated by randomly flipping $y$ with probability $p_e$. For training, two equal-sized environments with $p_e = 0.1$ and $p_e = 0.2$ are merged, which is equivalent to one environment with $p_e = 0.15$ on average. For testing, three situations are considered, with $p_e$ equal to 0.1, 0.5 and 0.9, respectively. Note that when $p_e = 0.1$, the spurious correlation is closely aligned with the training set. On the other hand, $p_e = 0.9$ defines the most challenging case, since the spurious correlation shifts most dramatically from training.
For all competitors except KerHRM, we use an MLP with two hidden layers of 390 neurons and train on the whole dataset (50,000 samples). For KerHRM, which requires massive computing resources, we follow the settings recommended by [23]: we randomly select 5,000 samples and train an MLP with one hidden layer of 1024 neurons. To construct an ideally diversified $\mathcal{E}_{tr}$ for IRM, we pack all examples with $c = y$ into one environment and those with $c \neq y$ into the other.
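The label and color construction described above can be sketched as follows, following the recipe of [1]. Only the binary labels are simulated (loading MNIST and coloring the actual images are omitted), and the function names are hypothetical. The sketch also shows an oracle split for IRM under the assumption that examples are grouped by whether the color agrees with the label.

```python
import random
from typing import List, Sequence, Tuple


def cmnist_labels(digits: Sequence[int], color_flip: float, rng: random.Random,
                  label_flip: float = 0.2) -> List[Tuple[int, int]]:
    """Return (label y, color c) pairs for CMNIST-style data."""
    pairs = []
    for d in digits:
        y = 0 if d < 5 else 1                           # digits 0-4 -> 0, digits 5-9 -> 1
        if rng.random() < label_flip:                   # label noise
            y = 1 - y
        c = 1 - y if rng.random() < color_flip else y  # color flipped with probability p_e
        pairs.append((y, c))
    return pairs


def oracle_split(pairs: Sequence[Tuple[int, int]]) -> List[int]:
    """Environment 0 if the color agrees with the label, environment 1 otherwise."""
    return [0 if y == c else 1 for y, c in pairs]


rng = random.Random(0)
digits = [rng.randrange(10) for _ in range(20000)]
# Two equal-sized training environments with p_e = 0.1 and p_e = 0.2.
train = cmnist_labels(digits[:10000], 0.1, rng) + cmnist_labels(digits[10000:], 0.2, rng)
disagree = sum(oracle_split(train)) / len(train)
print(round(disagree, 2))  # close to 0.15
```

The pooled disagreement rate between color and label is about 0.15, which is why the merged training set behaves like a single environment with $p_e = 0.15$; the color therefore predicts the label better (85%) than the digit shape does (75%), which is what makes it a tempting spurious feature.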
Results. The results are shown in Table 4. Not surprisingly, ERM still adapts poorly to distributional shifts. Among all invariant learning methods without manual environment labels, EDNIL comes closest to IRM with $\mathcal{E}_{tr}$, achieving consistent and robust performance on this dataset. As shown in Figure 3(c), EDNIL provides an almost ideally diversified $\mathcal{E}_{learn}$ for invariant learning.
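Table 4: Accuracy (%) on CMNIST under different test-time color noise $p_e$ (mean ± standard deviation).
| Method | 0.1 | 0.5 | 0.9 |
|---|---|---|---|
| ERM | 88.4 ± 0.3 | 55.0 ± 0.5 | 21.7 ± 0.8 |
| EIIL | 79.6 ± 0.3 | 71.7 ± 0.7 | 63.1 ± 0.5 |
| KerHRM | 74.3 ± 0.7 | 66.2 ± 1.7 | 58.0 ± 11.5 |
| KerHRM∗ | 71.3 ± 0.7 | 68.5 ± 0.5 | 66.1 ± 0.7 |
| EDNIL | 77.7 ± 0.4 | 76.8 ± 0.3 | 75.2 ± 0.4 |
| IRM | 77.8 ± 0.4 | 76.8 ± 0.4 | 75.2 ± 0.3 |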
Number of Environments. While the predefined number of environments is a hyper-parameter that requires tuning, EDNIL shows a certain level of tolerance to it, as illustrated in Figure 3(b). Specifically, when the number of environments is larger than the oracle (i.e., 2), some environment classifiers become redundant. Each of them provides a moderate, roughly constant loss, taking up a fixed and negligible share of the softmax function. The visualization of $\mathcal{E}_{learn}$ with 5 available environments is shown in Figure 3(c). Additionally, training EDNIL with more environments is much more efficient than training clustering-based methods, such as the one proposed in KerHRM.




4.2 Complex Datasets with Pre-trained Deep Learning Models
This section extends MLP to deep learning models with pre-trained weights for more complex data. With mini-batch fine-tuning, we consider all competitors except KerHRM, due to its efficiency issue. In Section 4.2.1, the image dataset Waterbirds [29], with controlled spurious correlations, is selected to evaluate generalization on higher-dimensional images. In Section 4.2.2, a real-world NLP dataset, SNLI [3], is considered. The biases in SNLI arise naturally from the data collection procedure, and we define biased subsets for evaluation following Dranker et al. [9].
4.2.1 Discussions on Waterbirds
In Waterbirds [29], each bird photograph, from CUB dataset [31], is combined with one background image, from Places dataset [33]. Both birds and backgrounds are either from land or water, and our target is to predict the binarized species of birds. At train time, landbirds and waterbirds are frequently present in land and water backgrounds, respectively. Therefore, empirically trained models are prone to learn context features, and fail to generalize as background varies [2, 6, 10, 21, 29].
To obtain an in-distribution validation set, we merge the original training and validation data³ and split off 10% for hyper-parameter tuning. For testing, we evaluate all four combinations of birds and backgrounds in the original testing set. Among them, the minority subgroup (waterbirds on land) is the most challenging case. In this task, Resnet-34 [14] is chosen for mini-batch fine-tuning. For the oracle settings of IRM, given target $Y$ and background $BG$, we place examples with $Y = BG$ and those with $Y \neq BG$ into two different environments and apply balanced class weights.
³ In the original training split, backgrounds are unequally distributed in each class, whereas they are equally distributed in the original validation split.
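A sketch of the oracle environment assignment and class weighting for IRM on Waterbirds. The weighting formula $n / (K \cdot n_y)$, with $n$ samples, $K$ classes and $n_y$ samples in class $y$, is the common "balanced" heuristic and is an assumption here; the function names are hypothetical.

```python
from collections import Counter
from typing import Dict, Hashable, List, Sequence


def waterbirds_oracle_envs(labels: Sequence[str], backgrounds: Sequence[str]) -> List[int]:
    """Environment 0 if the bird type matches its background, environment 1 otherwise."""
    return [0 if y == bg else 1 for y, bg in zip(labels, backgrounds)]


def balanced_class_weights(labels: Sequence[Hashable]) -> Dict[Hashable, float]:
    """Weight n / (K * n_y) per class, giving every class the same total weight."""
    counts = Counter(labels)
    n, k = len(labels), len(counts)
    return {c: n / (k * m) for c, m in counts.items()}


print(waterbirds_oracle_envs(["land", "water", "water"], ["land", "land", "water"]))  # [0, 1, 0]
print(balanced_class_weights(["land", "land", "land", "water"]))  # water weighted 3x land
```

Separating matched and mismatched bird/background pairs makes the spurious background correlation flip between the two environments, which is the diversity IRM needs to discard it.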
Results. The results are shown in Table 5. As observed in [6, 29], ERM suffers in the hardest case (i.e., waterbirds on land). EIIL also performs poorly in this case. With a more sophisticated learning framework, EDNIL narrows the gaps between subgroups and raises the worst-case performance. The results show that EDNIL can be more resistant to distributional shifts.
Choice of Initialization. Both EIIL and EDNIL take ERM as initialization. As mentioned in Section 1, heavy dependency on initialization is risky when the testing distribution is unavailable. Therefore, we initialize EIIL and EDNIL with ERM models trained for different numbers of steps to verify their stability. The results in Figure 3(d) reveal the consistency of EDNIL, which highlights its lower sensitivity to initialization. As implied by [6], EIIL could fail with a better-fitted reference model in this case, since such an ERM model may rely less on variant features. One can easily be misled into an undesirable choice for EIIL when tuning hyper-parameters without prior knowledge of distributional shifts. For instance, the validation score of EIIL with a 500-step reference model (95.8%) slightly surpasses that of the 100-step model (94.7%), but their testing performance differs significantly. Moreover, selecting an effective reference model could be even more intractable when spurious correlations in data are relatively weak. Supporting evidence can be found in Appendix C.2.
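Table 5: Accuracy (%) on Waterbirds for each combination of bird type $Y$ and background $BG$, written as (Y, BG) (mean ± standard deviation).
| Method | (Land, Land) | (Water, Water) | (Land, Water) | (Water, Land) |
|---|---|---|---|---|
| ERM | 99.4 ± 0.0 | 91.4 ± 0.2 | 90.9 ± 0.8 | 72.8 ± 1.0 |
| EIIL | 99.4 ± 0.3 | 90.5 ± 1.8 | 89.3 ± 3.7 | 68.6 ± 5.4 |
| EDNIL | 98.5 ± 0.6 | 89.9 ± 1.5 | 90.3 ± 3.0 | 78.6 ± 4.3 |
| IRM | 98.0 ± 0.5 | 90.6 ± 1.1 | 89.5 ± 1.7 | 83.2 ± 2.2 |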
4.2.2 Discussions on SNLI
The target of SNLI [3] is to predict the relation between two given sentences, a premise and a hypothesis. Recent studies [12, 24, 27] reveal hypothesis bias in SNLI, which is characterized by patterns in hypothesis sentences that are highly correlated with a specific label. One can achieve low empirical risk without considering premises during prediction. However, when the bias no longer holds, performance degrades [11, 24].
We sample 100,000 examples and consider all three classes (entailment, neutral and contradiction) in our experiment. Following [9], we define three subsets by training a biased model that takes the hypothesis as its only input. The subsets are specified as follows:
- Unbiased: examples whose predictions from the biased model are ambiguous
- Aligned: examples that the biased model predicts correctly with high confidence
- Misaligned: examples that the biased model predicts incorrectly with high confidence
The proportions of the three subsets are 17%, 67% and 16%, respectively. As a minority, the bias-misaligned subset is more likely to be ignored and thus defines the worst-case performance.
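The subset definitions above can be sketched from the biased (hypothesis-only) model's predicted class probabilities. The confidence threshold of 0.8 and the function name are hypothetical; the exact criterion follows [9] and is not reproduced here.

```python
from typing import List, Sequence


def bias_subsets(probs: Sequence[Sequence[float]], labels: Sequence[int],
                 threshold: float = 0.8) -> List[str]:
    """Split examples by the hypothesis-only model's confidence and correctness."""
    subsets = []
    for p, y in zip(probs, labels):
        pred = max(range(len(p)), key=p.__getitem__)
        if p[pred] < threshold:
            subsets.append("unbiased")    # ambiguous biased prediction
        elif pred == y:
            subsets.append("aligned")     # confident and correct
        else:
            subsets.append("misaligned")  # confident but wrong
    return subsets


print(bias_subsets([[0.4, 0.3, 0.3], [0.9, 0.05, 0.05], [0.05, 0.9, 0.05]], [0, 0, 2]))
# ['unbiased', 'aligned', 'misaligned']
```

Raising the threshold moves examples from the aligned and misaligned subsets into the unbiased one, so the subset proportions depend on this choice.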
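Table 6: Accuracy (%) on each SNLI bias subset (mean ± standard deviation).
| Method | Unbiased | Aligned | Misaligned |
|---|---|---|---|
| ERM | 74.6 ± 0.3 | 94.7 ± 0.2 | 52.6 ± 0.9 |
| EIIL | 74.2 ± 0.3 | 95.0 ± 0.1 | 51.7 ± 1.3 |
| EDNIL | 74.3 ± 0.8 | 94.2 ± 0.2 | 54.5 ± 1.0 |
| IRM | 74.0 ± 0.9 | 92.3 ± 0.5 | 56.9 ± 1.1 |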
For all methods, DistilBERT [30] is taken as the pre-trained model for further mini-batch fine-tuning. For $\mathcal{E}_{tr}$, we assign the bias-aligned subset to the first environment and the misaligned one to the second. To make bias prevalence equal in the two environments, unbiased samples are distributed proportionally between the two environments.
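The proportional distribution of unbiased samples can be made concrete. With $A$ aligned, $M$ misaligned and $U$ unbiased samples, sending $k_1 = U \cdot A / (A + M)$ unbiased samples (rounded) to the first environment equalizes the share of unbiased samples, and hence of biased ones, because $k_1 / (A + k_1) = k_2 / (M + k_2)$ holds exactly when $k_1 / A = k_2 / M$. The sketch below assumes the unbiased samples are assigned at random and that at least one biased sample exists; the function name is hypothetical.

```python
import random
from typing import List, Sequence


def snli_oracle_envs(subsets: Sequence[str], rng: random.Random) -> List[int]:
    """Aligned -> env 0, misaligned -> env 1, unbiased split in proportion A : M."""
    idx = {s: [i for i, t in enumerate(subsets) if t == s]
           for s in ("aligned", "misaligned", "unbiased")}
    a, m = len(idx["aligned"]), len(idx["misaligned"])
    unbiased = list(idx["unbiased"])
    rng.shuffle(unbiased)
    k1 = round(len(unbiased) * a / (a + m))  # unbiased samples sent to env 0
    env = [-1] * len(subsets)
    for i in idx["aligned"] + unbiased[:k1]:
        env[i] = 0
    for i in idx["misaligned"] + unbiased[k1:]:
        env[i] = 1
    return env


subsets = ["aligned"] * 6 + ["misaligned"] * 2 + ["unbiased"] * 4
env = snli_oracle_envs(subsets, random.Random(0))
print(env.count(0), env.count(1))  # 9 3: one third of each environment is unbiased
```

In the toy example, environment 0 receives 6 aligned and 3 unbiased samples while environment 1 receives 2 misaligned and 1 unbiased sample, so both environments are one-third unbiased.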
Results. The results are shown in Table 6. As reported by [9], ERM scores higher on the bias-aligned subset but fails in the bias-misaligned case. Among all invariant learning methods without environment labels, only EDNIL improves on the bias-misaligned subset. Namely, even though the biases are defined at a high level, EDNIL is still capable of encoding and eliminating possible variant features.
5 Limitation
Our learning algorithm for environment inference is based on the graphical model plotted in Figure 2. However, it is important to acknowledge that data may not always conform to the assumed process, resulting in potential biases that cannot be adequately captured by the proposed neural network. In the paper, while we provide empirical studies of effectiveness on diverse datasets, we are still aware that a stronger guarantee of performance is required.
6 Conclusion and Societal Impacts
This work proposes EDNIL for training models invariant to distributional shifts. To infer environments without supervision, we propose a multi-head neural network structure to identify and diversify plausible environments. With joint optimization, the resulting invariant models are shown to be more robust than existing solutions on data with distinct characteristics and different strengths of biases. We attribute this effectiveness to the underlying learning objectives, which are consistent with recent studies of ideal environments. Additionally, the classifier-like structure of the environment inference model makes EDNIL easy to combine with off-the-shelf pre-trained models and efficient to train.
Our contributions to invariant learning have broader societal impacts on numerous domains. For instance, it can encourage further research and real-world applications on debiasing machine learning systems. The identification and elimination of potential biases can facilitate more robust model training. It can be beneficial to many real applications where distributional shifts commonly occur, such as autonomous driving, social media and healthcare.
Furthermore, as discussed by [6], invariant learning can promote algorithmic fairness in some ways. In particular, our empirical achievements on Adult-Confounded can prevent discrimination against sensitive demographic subgroups in decision-making process. It shows that EDNIL has a potential to learn a fair predictor without prior knowledge of sensitive attributes, which is related to [13, 19, 32]. We expect that one can extend our work to more fairness benchmarks and criteria in the future.
Last but not least, it is worth mentioning some cautions, however. Since the invariant learning algorithm claims to find invariant relationships, one might cast more attention on feature importance of the invariant model and even incorporate the results into further research or applications. Nevertheless, the results are reliable only when the model is trained appropriately. Insufficient data collection or careless training process, for example, can certainly affect the identification of invariant features, and thus mislead experimental findings. As a result, we believe that adequate and careful preparations and analyses are essential before drawing conclusions from the inferred invariant relationships.
Acknowledgments and Disclosure of Funding
We would like to thank the anonymous reviewers for their helpful suggestions. This material is based upon work supported by National Science and Technology Council, ROC under grant number 111-2221-E-002 -146 -MY3 and 110-2634-F-002-050 -.
References
- Arjovsky et al. [2019] Martin Arjovsky, Léon Bottou, Ishaan Gulrajani, and David Lopez-Paz. Invariant risk minimization, 2019. URL https://arxiv.org/abs/1907.02893.
- Beery et al. [2018] Sara Beery, Grant Van Horn, and Pietro Perona. Recognition in terra incognita. In Proceedings of the European Conference on Computer Vision (ECCV), September 2018.
- Bowman et al. [2015] Samuel R. Bowman, Gabor Angeli, Christopher Potts, and Christopher D. Manning. A large annotated corpus for learning natural language inference. In Proceedings of the 2015 Conference on Empirical Methods in Natural Language Processing, pages 632–642, Lisbon, Portugal, September 2015. Association for Computational Linguistics. doi: 10.18653/v1/D15-1075. URL https://aclanthology.org/D15-1075.
- Chang et al. [2020] Shiyu Chang, Yang Zhang, Mo Yu, and Tommi Jaakkola. Invariant rationalization. In Proceedings of the 37th International Conference on Machine Learning, pages 1448–1458, 2020. URL https://proceedings.mlr.press/v119/chang20c.html.
- Choe et al. [2020] Yo Joong Choe, Jiyeon Ham, and Kyubyong Park. An empirical study of invariant risk minimization. In ICML 2020 Workshop on Uncertainty and Robustness in Deep Learning, 2020.
- Creager et al. [2021] Elliot Creager, Joern-Henrik Jacobsen, and Richard Zemel. Environment inference for invariant learning. In Proceedings of the 38th International Conference on Machine Learning, pages 2189–2200, 2021. URL https://proceedings.mlr.press/v139/creager21a.html.
- Devlin et al. [2019] Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. BERT: Pre-training of deep bidirectional transformers for language understanding. In Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers), pages 4171–4186, Minneapolis, Minnesota, June 2019. Association for Computational Linguistics. doi: 10.18653/v1/N19-1423. URL https://aclanthology.org/N19-1423.
- Ding et al. [2021] Frances Ding, Moritz Hardt, John Miller, and Ludwig Schmidt. Retiring adult: New datasets for fair machine learning. In Advances in Neural Information Processing Systems, pages 6478–6490, 2021. URL https://proceedings.neurips.cc/paper/2021/file/32e54441e6382a7fbacbbbaf3c450059-Paper.pdf.
- Dranker et al. [2021] Yana Dranker, He He, and Yonatan Belinkov. IRM—when it works and when it doesn’t: A test case of natural language inference. In Advances in Neural Information Processing Systems, 2021. URL https://openreview.net/forum?id=KtvHbjCF4v.
- Geirhos et al. [2020] Robert Geirhos, Jörn-Henrik Jacobsen, Claudio Michaelis, Richard Zemel, Wieland Brendel, Matthias Bethge, and Felix A. Wichmann. Shortcut learning in deep neural networks. Nature Machine Intelligence, 2(11):665–673, 2020. URL https://doi.org/10.1038/s42256-020-00257-z.
- Glockner et al. [2018] Max Glockner, Vered Shwartz, and Yoav Goldberg. Breaking NLI systems with sentences that require simple lexical inferences. In Proceedings of the 56th Annual Meeting of the Association for Computational Linguistics (Volume 2: Short Papers), pages 650–655, Melbourne, Australia, July 2018. Association for Computational Linguistics. doi: 10.18653/v1/P18-2103. URL https://aclanthology.org/P18-2103.
- Gururangan et al. [2018] Suchin Gururangan, Swabha Swayamdipta, Omer Levy, Roy Schwartz, Samuel Bowman, and Noah A. Smith. Annotation artifacts in natural language inference data. In Proceedings of the 2018 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 2 (Short Papers), pages 107–112, June 2018. doi: 10.18653/v1/N18-2017. URL https://aclanthology.org/N18-2017.
- Hashimoto et al. [2018] Tatsunori Hashimoto, Megha Srivastava, Hongseok Namkoong, and Percy Liang. Fairness without demographics in repeated loss minimization. In Proceedings of the 35th International Conference on Machine Learning, pages 1929–1938, 2018. URL https://proceedings.mlr.press/v80/hashimoto18a.html.
- He et al. [2016] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In 2016 IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pages 770–778, Los Alamitos, CA, USA, June 2016. IEEE Computer Society. doi: 10.1109/CVPR.2016.90. URL https://doi.ieeecomputersociety.org/10.1109/CVPR.2016.90.
- Jacot et al. [2018] Arthur Jacot, Clément Hongler, and Franck Gabriel. Neural tangent kernel: Convergence and generalization in neural networks. In Samy Bengio, Hanna M. Wallach, Hugo Larochelle, Kristen Grauman, Nicolò Cesa-Bianchi, and Roman Garnett, editors, Advances in Neural Information Processing Systems 31: Annual Conference on Neural Information Processing Systems 2018, NeurIPS 2018, 3-8 December 2018, Montréal, Canada, pages 8580–8589, 2018.
- Kohavi and Becker [1996] Ronny Kohavi and Barry Becker. Adult. UCI Machine Learning Repository, 1996. Accessed: 2021-09-06.
- Krueger et al. [2021] David Krueger, Ethan Caballero, Joern-Henrik Jacobsen, Amy Zhang, Jonathan Binas, Dinghuai Zhang, Remi Le Priol, and Aaron Courville. Out-of-distribution generalization via risk extrapolation (REx). In Proceedings of the 38th International Conference on Machine Learning, pages 5815–5826, 2021. URL https://proceedings.mlr.press/v139/krueger21a.html.
- Kuang et al. [2020] Kun Kuang, Ruoxuan Xiong, Peng Cui, Susan Athey, and Bo Li. Stable prediction with model misspecification and agnostic distribution shift. In Proceedings of the AAAI Conference on Artificial Intelligence, volume 34, pages 4485–4492, 2020. doi: 10.1609/aaai.v34i04.5876.
- Lahoti et al. [2020] Preethi Lahoti, Alex Beutel, Jilin Chen, Kang Lee, Flavien Prost, Nithum Thain, Xuezhi Wang, and Ed Chi. Fairness without demographics through adversarially reweighted learning. In Advances in Neural Information Processing Systems, pages 728–740, 2020. URL https://proceedings.neurips.cc/paper/2020/file/07fc15c9d169ee48573edd749d25945d-Paper.pdf.
- Lin et al. [2022] Yong Lin, Shengyu Zhu, and Peng Cui. ZIN: When and how to learn invariance by environment inference?, 2022. URL https://arxiv.org/abs/2203.05818.
- Liu et al. [2021a] Evan Z Liu, Behzad Haghgoo, Annie S Chen, Aditi Raghunathan, Pang Wei Koh, Shiori Sagawa, Percy Liang, and Chelsea Finn. Just train twice: Improving group robustness without training group information. In Proceedings of the 38th International Conference on Machine Learning, pages 6781–6792, 2021a. URL https://proceedings.mlr.press/v139/liu21f.html.
- Liu et al. [2021b] Jiashuo Liu, Zheyuan Hu, Peng Cui, Bo Li, and Zheyan Shen. Heterogeneous risk minimization. In Proceedings of the 38th International Conference on Machine Learning, pages 6804–6814, 2021b. URL https://proceedings.mlr.press/v139/liu21h.html.
- Liu et al. [2021c] Jiashuo Liu, Zheyuan Hu, Peng Cui, Bo Li, and Zheyan Shen. Kernelized heterogeneous risk minimization. In Advances in Neural Information Processing Systems, pages 21720–21731, 2021c. URL https://proceedings.neurips.cc/paper/2021/file/b59a51a3c0bf9c5228fde841714f523a-Paper.pdf.
- McCoy et al. [2019] Tom McCoy, Ellie Pavlick, and Tal Linzen. Right for the wrong reasons: Diagnosing syntactic heuristics in natural language inference. In Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics, pages 3428–3448, Florence, Italy, July 2019. Association for Computational Linguistics. doi: 10.18653/v1/P19-1334. URL https://aclanthology.org/P19-1334.
- Pearl [1988] Judea Pearl. Probabilistic Reasoning in Intelligent Systems. Morgan Kaufmann, 1988.
- Peters et al. [2015] Jonas Peters, Peter Bühlmann, and Nicolai Meinshausen. Causal inference by using invariant prediction: identification and confidence intervals. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 78, 2015.
- Poliak et al. [2018] Adam Poliak, Jason Naradowsky, Aparajita Haldar, Rachel Rudinger, and Benjamin Van Durme. Hypothesis only baselines in natural language inference. In Proceedings of the Seventh Joint Conference on Lexical and Computational Semantics, pages 180–191, New Orleans, Louisiana, June 2018. Association for Computational Linguistics. doi: 10.18653/v1/S18-2023. URL https://aclanthology.org/S18-2023.
- Rojas-Carulla et al. [2018] Mateo Rojas-Carulla, Bernhard Schölkopf, Richard E. Turner, and Jonas Peters. Invariant models for causal transfer learning. J. Mach. Learn. Res., 19:36:1–36:34, 2018.
- Sagawa* et al. [2020] Shiori Sagawa*, Pang Wei Koh*, Tatsunori B. Hashimoto, and Percy Liang. Distributionally robust neural networks. In International Conference on Learning Representations, 2020. URL https://openreview.net/forum?id=ryxGuJrFvS.
- Sanh et al. [2019] Victor Sanh, Lysandre Debut, Julien Chaumond, and Thomas Wolf. DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter. In NeurIPS 2019 Workshop on Energy Efficient Machine Learning and Cognitive Computing, 2019.
- Welinder et al. [2010] P. Welinder, S. Branson, T. Mita, C. Wah, F. Schroff, S. Belongie, and P. Perona. Caltech-UCSD Birds 200. Technical Report CNS-TR-2010-001, Caltech, 2010.
- Yan et al. [2020] Shen Yan, Hsien-te Kao, and Emilio Ferrara. Fair class balancing: Enhancing model fairness without observing sensitive attributes. In Proceedings of the 29th ACM International Conference on Information & Knowledge Management, CIKM ’20, page 1715–1724, New York, NY, USA, 2020. Association for Computing Machinery. ISBN 9781450368599. doi: 10.1145/3340531.3411980. URL https://doi.org/10.1145/3340531.3411980.
- Zhou et al. [2018] Bolei Zhou, Agata Lapedriza, Aditya Khosla, Aude Oliva, and Antonio Torralba. Places: A 10 million image database for scene recognition. IEEE Transactions on Pattern Analysis and Machine Intelligence, 40(6):1452–1464, 2018. doi: 10.1109/TPAMI.2017.2723009.
Appendix
Appendix A Proof of the Underlying Graphical Model for EDNIL
We assume data are generated by Equation 12 which is the process adopted by Arjovsky et al. [1] and Lin et al. [20], where , are deterministic functions, and , are random noises independent of and .
(12) |
Equation 12 establishes a chain that . With the additional relation proposed in our graphical model, the following two dependencies can be obtained via d-separation [25]:
(13) |
As can be seen, the conditional independence between and given leads to . On the other hand, the dependency between and given implies that there exists an environment variable satisfying .
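To make the two dependencies concrete, the following sketch simulates a linear-Gaussian instance in the spirit of the example of Arjovsky et al. [1]. It assumes (our reading of the chain above, not the paper's exact Equation 12) that an invariant feature x_inv drives the target y, the target drives a variant feature x_var, and the environment only changes the noise scale on x_var. The names `simulate` and `ols_slope` and the noise values are hypothetical. Regressing y on x_inv should give the same slope in every environment, whereas the slope on x_var shifts with the environment, which is the behaviour the two d-separation statements predict.

```python
import random


def simulate(env_noise_std, n=20000, seed=0):
    """Sample a hypothetical linear-Gaussian chain x_inv -> y -> x_var.

    The environment only enters through env_noise_std, the noise scale on x_var
    (an assumed stand-in for the environment -> variant-feature edge).
    """
    rng = random.Random(seed)
    xs_inv, ys, xs_var = [], [], []
    for _ in range(n):
        x_inv = rng.gauss(0.0, 1.0)                # invariant feature
        y = x_inv + rng.gauss(0.0, 1.0)            # target depends only on x_inv
        x_var = y + rng.gauss(0.0, env_noise_std)  # variant feature depends on y and the environment
        xs_inv.append(x_inv)
        ys.append(y)
        xs_var.append(x_var)
    return xs_inv, ys, xs_var


def ols_slope(x, y):
    """Least-squares slope of y regressed on x (with an intercept)."""
    mean_x = sum(x) / len(x)
    mean_y = sum(y) / len(y)
    cov = sum((a - mean_x) * (b - mean_y) for a, b in zip(x, y))
    var = sum((a - mean_x) ** 2 for a in x)
    return cov / var


if __name__ == "__main__":
    for seed, (env, std) in enumerate([("e1", 0.1), ("e2", 2.0)]):
        x_inv, y, x_var = simulate(std, seed=seed)
        print(env, "slope on x_inv:", round(ols_slope(x_inv, y), 3),
              "slope on x_var:", round(ols_slope(x_var, y), 3))
```

Analytically, Cov(y, x_var) = Var(y) = 2 and Var(x_var) = 2 + σ², so the slope on x_var should be near 2/2.01 ≈ 0.995 for σ = 0.1 and near 2/6 ≈ 0.33 for σ = 2.0, while the slope on x_inv stays near 1 in both.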
Appendix B Experimental Details
B.1 Implementation Resources
All experiments were run on a machine with an NVIDIA GeForce RTX 3090 GPU. The training time and GPU memory consumption of EDNIL are reported in Table 7. In total, EDNIL takes approximately 30 hours to complete all tasks, including the main experiments and analyses.
For the deep encoders, we use ResNet-34 from torchvision (https://pytorch.org/hub/pytorch_vision_resnet) on Waterbirds and DistilBERT from Hugging Face (https://huggingface.co/distilbert-base-uncased) on SNLI. Both network architectures and pre-trained weights are kept at their defaults.
Table 7: Training time and GPU memory consumption of EDNIL.
Resource | Adult-Confounded | CMNIST | Waterbirds | SNLI |
Time | 1 min | 2 min | 40 min | 60 min |
GPU memory | 1.3 GiB | 1.6 GiB | 7.0 GiB | 10.2 GiB |
B.2 Hyper-parameter Tuning
To avoid leaking out-of-distribution information, we hold out 10% of the training data as an in-distribution validation set. Notably, we infer environments on the validation data with the environment inference model and select hyper-parameters according to the worst-environment score.
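A minimal sketch of this selection rule, assuming each candidate configuration has already been scored (here, by accuracy) on every environment inferred on the held-out validation split. The configuration labels and numbers are hypothetical; only the "maximize the worst environment" rule comes from the text.

```python
from typing import Dict, List


def worst_env_score(env_scores: List[float]) -> float:
    """Score of one configuration: its accuracy on the worst inferred environment."""
    return min(env_scores)


def select_config(results: Dict[str, List[float]]) -> str:
    """Pick the configuration whose worst-environment validation score is highest."""
    return max(results, key=lambda cfg: worst_env_score(results[cfg]))


# Hypothetical validation accuracies on three inferred environments.
results = {
    "penalty=10": [0.95, 0.93, 0.62],
    "penalty=100": [0.85, 0.83, 0.79],
    "penalty=1000": [0.80, 0.78, 0.75],
}

if __name__ == "__main__":
    print(select_config(results))  # penalty=100
```

Averaging over environments would instead favor penalty=10 in this toy example (about 0.833 versus 0.823), which illustrates why a worst-environment criterion is the more conservative choice when some inferred environments are hard.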
For , we select the number of environments from 2 to 5, in the softmax function from 0.05 to 0.5, and in from 0.2 to 10. In , we clip overaggressive with an upper bound for training stability, chosen from 1.2 to 5. As for , the penalty strength in is selected from {2, 10, 100, 1000}. Following Arjovsky et al. [1], we apply an annealing period before switching to the configured penalty strength; the number of annealing iterations ranges from 20% to 80% of the total training steps. For the more complex datasets, we use a longer annealing period (more than 50%) to better learn basic representations.
The total number of training steps is chosen from 500 to 2000. On Adult-Confounded and CMNIST, we use full-batch training since the data fit in memory. On Waterbirds and SNLI, the batch size is chosen from 128, 256, and 512. The learning rate and optimizer depend on the dataset. For Adult-Confounded, CMNIST, and Waterbirds, a learning rate between 2e-4 and 2e-3 is considered; we use Adam for Adult-Confounded and CMNIST, and SGD for Waterbirds. For SNLI, a smaller learning rate from {2e-5, 3e-5, 5e-5, 1e-4} is selected when fine-tuning DistilBERT with AdamW.
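The annealing mechanism can be sketched as follows. This assumes EDNIL follows the schedule of the public IRM reference code of Arjovsky et al. [1] as we understand it: the penalty weight stays at 1.0 during the annealing iterations, then jumps to the configured strength, and the total loss is divided by the weight whenever it exceeds 1 to keep gradient magnitudes comparable. The function names are hypothetical, and losses are plain floats rather than tensors.

```python
def penalty_weight(step: int, total_steps: int, anneal_frac: float, target_weight: float) -> float:
    """Penalty strength at `step`: 1.0 while annealing, `target_weight` afterwards."""
    anneal_iters = int(anneal_frac * total_steps)  # e.g. 20%-80% of all steps
    return target_weight if step >= anneal_iters else 1.0


def total_loss(erm_loss: float, penalty: float, weight: float) -> float:
    """ERM loss plus weighted invariance penalty, rescaled when the weight is large."""
    loss = erm_loss + weight * penalty
    if weight > 1.0:
        loss /= weight  # keeps the overall loss (and gradient) scale stable
    return loss


if __name__ == "__main__":
    # Hypothetical run: 1000 steps, 50% annealing, configured penalty strength 100.
    for step in (0, 499, 500, 999):
        w = penalty_weight(step, total_steps=1000, anneal_frac=0.5, target_weight=100.0)
        print(step, w, total_loss(0.5, 0.01, w))
```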
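The per-dataset choices above can be collected into a search-space configuration. The sketch is hypothetical: the dictionary layout, the `"full"` marker for full-batch training, and the sampling strategy (log-uniform learning rates within the stated ranges, uniform integer step counts) are our assumptions, since the text gives only the ranges and candidate sets.

```python
import math
import random

# Ranges and candidate sets taken from the text; the structure itself is illustrative.
TOTAL_STEPS = (500, 2000)
SEARCH_SPACE = {
    "Adult-Confounded": {"optimizer": "Adam", "batch_size": ["full"], "lr": (2e-4, 2e-3)},
    "CMNIST": {"optimizer": "Adam", "batch_size": ["full"], "lr": (2e-4, 2e-3)},
    "Waterbirds": {"optimizer": "SGD", "batch_size": [128, 256, 512], "lr": (2e-4, 2e-3)},
    "SNLI": {"optimizer": "AdamW", "batch_size": [128, 256, 512], "lr": [2e-5, 3e-5, 5e-5, 1e-4]},
}


def sample_config(dataset, rng):
    """Draw one candidate configuration for `dataset`."""
    space = SEARCH_SPACE[dataset]
    lr = space["lr"]
    if isinstance(lr, tuple):  # continuous range: sample log-uniformly (assumption)
        low, high = lr
        lr = math.exp(rng.uniform(math.log(low), math.log(high)))
    else:  # discrete candidate set
        lr = rng.choice(lr)
    return {
        "optimizer": space["optimizer"],
        "lr": lr,
        "batch_size": rng.choice(space["batch_size"]),
        "total_steps": rng.randint(*TOTAL_STEPS),
    }


if __name__ == "__main__":
    rng = random.Random(0)
    print(sample_config("Waterbirds", rng))
```

Log-uniform sampling is a common default for learning rates spanning an order of magnitude; a fixed grid would work equally well with this layout.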
B.3 Oracle Settings on Adult-Confounded
Given a biased training set and two sensitive features, race and sex, the oracle environment partition for IRM is constructed according to Table 8. The correlations between the variant features and the target are maximized within each environment and diversified across environments. As implied by [5, 22], spurious correlations are expected to be eliminated when an invariant learning algorithm converges properly.
Race | Sex | Race | Sex | |
Non-black | Male | Black | Female | |
Non-black | Female | Black | Male | |
Black | Male | Non-black | Female | |
Black | Female | Non-black | Male |
B.4 Biased Model for SNLI
To define subsets for evaluation on SNLI, we follow the labeling procedure in [9]. Specifically, k-fold cross-validation () is applied to the training set: we fine-tune BERT [7] on k−1 folds with the hypothesis as its only input and score the left-out k-th fold. For the development and test sets, we score each example with the average prediction of the k models. The accuracy of the biased model is approximately 68%. Finally, we set the two thresholds , defined by [9], to (0.2, 0.5), where is used to identify unbiased data and is used to define the bias-aligned and bias-misaligned sets.
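The cross-fitting step can be sketched as follows. To stay self-contained, the sketch replaces hypothesis-only BERT fine-tuning with a hypothetical count-based stand-in (`fit_hypothesis_only`, which looks only at the first word of the hypothesis) and uses binary labels; k, the strided fold assignment, and all names are assumptions. What it illustrates is the structure described above: every training example is scored by a model that never saw it, and development/test examples receive the average prediction of the k fold models.

```python
from collections import Counter


def fit_hypothesis_only(train):
    """Hypothetical stand-in for a hypothesis-only classifier.

    Estimates P(label = 1 | first word of the hypothesis) with add-one smoothing.
    """
    positives, totals = Counter(), Counter()
    for hypothesis, label in train:
        key = hypothesis.split()[0]
        totals[key] += 1
        positives[key] += label

    def predict(hypothesis):
        key = hypothesis.split()[0]
        return (positives[key] + 1) / (totals[key] + 2)

    return predict


def out_of_fold_scores(examples, k, fit):
    """Score each training example with the model trained on the other k-1 folds."""
    folds = [list(range(i, len(examples), k)) for i in range(k)]  # strided split (assumed)
    scores = [0.0] * len(examples)
    models = []
    for held_out in folds:
        held = set(held_out)
        train = [ex for i, ex in enumerate(examples) if i not in held]
        model = fit(train)
        models.append(model)
        for i in held_out:
            scores[i] = model(examples[i][0])
    return scores, models


def ensemble_score(models, hypothesis):
    """Development/test examples: average the predictions of all k fold models."""
    return sum(model(hypothesis) for model in models) / len(models)


if __name__ == "__main__":
    data = [("not happy", 1), ("a man sleeps", 0)] * 5  # toy hypothesis-only data
    scores, models = out_of_fold_scores(data, k=5, fit=fit_hypothesis_only)
    print([round(s, 3) for s in scores])
    print(round(ensemble_score(models, "not outside"), 3))
```

In practice the folds would be shuffled before splitting; the strided split only keeps the toy example deterministic.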
Appendix C Additional Empirical Results
C.1 Synthetic Data for Regression Problem
We further validate our work on a regression problem with a synthetic dataset proposed in [23]. The features are , and the target is generated by , where is a random orthogonal matrix for scrambling features and is a non-linear function. are invariant features that are consistent across environments, while are variant features that can change arbitrarily according to the following data sampling mechanism:
(14) |
In Equation 14 where , controls the spurious correlation between the certain variable and the target . Specifically, larger represents stronger correlation, and the sign of indicates the direction of correlation. In the training set, there are 1000 examples generated from the environment with and 100 examples from that with . The environment labels are unavailable as in the previous experiments. For testing, two scenarios are considered. First, we define two environments, IID and OOD, to evaluate the generalization under dramatic distributional shifts. In IID, 1100 examples are sampled with the same procedures as training data. In OOD, 1000 examples are generated from the environment with . Secondly, following [23], we evaluate the stability over six testing environments where .
In our regression task, the evaluation metric is the mean squared error. For all methods, an MLP with one hidden layer of 1024 neurons is used. When calculating the label independence term for EDNIL, we discretize the target by quartiles. Empirically, this efficient estimate can improve the quality of environment inference to some degree. We leave more precise approximations of the mutual information between discrete and continuous variables for future work.
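The quartile discretization can be sketched as a plug-in estimate of the mutual information between the inferred environments and the binned target. We assume here that the label independence term is measured as I(E; Y) with soft environment assignments (for instance, softmax outputs of the multi-head network); the function names and the exact estimator are hypothetical, not EDNIL's implementation. A value near zero means the inferred environments carry no information about the target.

```python
import bisect
import math
import statistics
from typing import List, Sequence


def quartile_bins(y: Sequence[float]) -> List[int]:
    """Map each target value to its quartile bin (0-3)."""
    cuts = statistics.quantiles(y, n=4)  # three cut points
    return [bisect.bisect_right(cuts, v) for v in y]


def label_env_mutual_info(env_probs: Sequence[Sequence[float]], y: Sequence[float]) -> float:
    """Plug-in estimate of I(E; quartile(Y)) in nats.

    env_probs[i][e] is the (soft) probability that sample i belongs to environment e.
    """
    bins = quartile_bins(y)
    n, n_env = len(y), len(env_probs[0])
    joint = [[0.0] * 4 for _ in range(n_env)]  # joint[e][b] = P(E = e, bin = b)
    for probs, b in zip(env_probs, bins):
        for e in range(n_env):
            joint[e][b] += probs[e] / n
    p_env = [sum(row) for row in joint]
    p_bin = [sum(joint[e][b] for e in range(n_env)) for b in range(4)]
    mi = 0.0
    for e in range(n_env):
        for b in range(4):
            if joint[e][b] > 0:
                mi += joint[e][b] * math.log(joint[e][b] / (p_env[e] * p_bin[b]))
    return mi
```

For targets 1 to 8, an environment split at y ≤ 4 is fully determined by the quartile bin, so the estimate equals ln 2, whereas an odd/even split is independent of the bins and gives zero.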
Results The results of the first testing scenario are listed in Table 9. Among all methods, EDNIL obtains the most consistent scores across IID and OOD. The performance degradation when β = 0 suggests the importance of the label independence term.
Table 10 shows the results of the second scenario. As in [23], Mean Error and Std Error represent the mean and standard deviation of errors over the six testing environments, respectively; both values are averaged over 20 runs. Similar to the first scenario, EDNIL performs best and most robustly in out-of-distribution settings. The estimated label independence term also yields empirical improvements in this test.
Method | IID | OOD |
ERM | 0.772 ± 0.079 | 5.431 ± 0.461 |
EIIL | 1.629 ± 0.174 | 3.675 ± 0.756 |
KerHRM | 1.246 ± 0.339 | 3.612 ± 1.082 |
EDNIL | 1.971 ± 0.183 | 2.253 ± 0.422 |
EDNIL (β = 0) | 1.733 ± 0.340 | 2.933 ± 0.808 |
Method | Mean Error | Std Error |
ERM | 5.367 | 0.217 |
EIIL | 3.623 | 0.188 |
KerHRM | 3.526 | 0.151 |
EDNIL | 2.218 | 0.103 |
EDNIL (β = 0) | 2.879 | 0.151 |
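The two summary statistics of the second scenario can be computed as below, assuming each run yields one error per testing environment and that Std Error uses the population standard deviation across the six environments (the text does not say which variant). The function name and the toy numbers are hypothetical.

```python
import statistics


def mean_and_std_error(errors_per_run):
    """errors_per_run[r][e]: error of run r on testing environment e.

    Returns (Mean Error, Std Error): the per-run mean and standard deviation
    across environments, each averaged over runs.
    """
    means = [statistics.fmean(run) for run in errors_per_run]
    stds = [statistics.pstdev(run) for run in errors_per_run]  # population std (assumed)
    return statistics.fmean(means), statistics.fmean(stds)


if __name__ == "__main__":
    runs = [
        [2.0, 2.0, 2.0, 2.0, 2.0, 2.0],  # stable across environments
        [1.0, 3.0, 1.0, 3.0, 1.0, 3.0],  # same mean error, unstable
    ]
    print(mean_and_std_error(runs))  # (2.0, 0.5)
```

The two toy runs share the same Mean Error, so only Std Error reveals the instability of the second run, which is why both columns are reported.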
C.2 CMNIST with Different Color Noises
In this task, all methods are tested under different strengths of spurious correlation at training time. We use CMNIST with a fixed label noise of 0.2 and vary the overall color noise from 0.1 to 0.3 to generate five different training sets. For testing, and are considered. Given target and color , samples with outnumber those with when , in which case the direction of the spurious correlation is aligned with that in the training sets. On the other hand, when , samples with are in the majority. Because of this reversed correlation, models relying on the variant feature (i.e., color) are vulnerable in this setting.
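For reference, the labeling and coloring step of a CMNIST-style dataset can be sketched as follows. It assumes the standard recipe of Arjovsky et al. [1] (a binary label from digit < 5, flipped with the label noise; a color equal to the label, flipped with the color noise) and treats each set as a single environment with one color-noise level; the function names and the noise values in the example are hypothetical, and images are omitted. A color noise below 0.5 makes color agree with the label for most samples, while a color noise above 0.5 reverses the correlation, which is why a color-reliant model degrades on such test sets.

```python
import random


def make_cmnist_labels(digits, label_noise, color_noise, rng):
    """Return (label, color) pairs for a CMNIST-style dataset (images omitted)."""
    pairs = []
    for digit in digits:
        label = 0 if digit < 5 else 1   # binary label from the digit
        if rng.random() < label_noise:  # label noise caps the accuracy of invariant features
            label = 1 - label
        color = label                   # color starts out equal to the label...
        if rng.random() < color_noise:  # ...and is flipped with the color noise
            color = 1 - color
        pairs.append((label, color))
    return pairs


def color_label_agreement(pairs):
    """Fraction of samples whose color matches their label."""
    return sum(label == color for label, color in pairs) / len(pairs)


if __name__ == "__main__":
    rng = random.Random(0)
    digits = [rng.randrange(10) for _ in range(20000)]
    train = make_cmnist_labels(digits, label_noise=0.2, color_noise=0.1, rng=rng)
    test = make_cmnist_labels(digits, label_noise=0.2, color_noise=0.9, rng=rng)
    print(round(color_label_agreement(train), 2), round(color_label_agreement(test), 2))
```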
Results The results are plotted in Figure 4. As the training color noise increases, the generalization of ERM improves because the spurious correlation at training time weakens. Meanwhile, as indicated in [6], EIIL fails because its reference model, i.e., ERM, is no longer a purely variant predictor. For KerHRM, the instability of the inferred environments results in large standard deviations, especially when the training color noise is small. In comparison, across different strengths of spurious correlation in training, EDNIL can distinguish invariant features from variant ones and performs closer to IRM (with ).

