
Domain Generalisation via Risk Distribution Matching

Toan Nguyen, Kien Do, Bao Duong, Thin Nguyen
Applied Artificial Intelligence Institute, Deakin University, Australia
{s222165627, k.do, duongng, thin.nguyen}@deakin.edu.au
Abstract

We propose a novel approach for domain generalisation (DG) that leverages risk distributions to characterise domains and thereby achieve domain invariance. Our findings show that risk distributions effectively highlight differences between training domains and reveal their inherent complexities. At test time, similar divergences between risk distributions may be observed, potentially with even greater magnitude. Hence, we propose a compelling proposition: minimising the divergences between risk distributions across training domains leads to robust invariance for DG. The key rationale behind this concept is that a model trained on domain-invariant or stable features should consistently produce similar risk distributions across various domains. Building upon this idea, we propose Risk Distribution Matching (RDM). Using the maximum mean discrepancy (MMD) distance, RDM aims to minimise the variance of risk distributions across training domains. However, as the number of domains increases, direct optimisation of this variance incurs a linear growth in MMD computations, resulting in inefficiency. Instead, we propose an approximation that requires only one MMD computation, aligning just two distributions: that of the worst-case domain and the aggregated distribution from all domains. Notably, this method empirically outperforms optimising the distributional variance while being computationally more efficient. Unlike conventional DG matching algorithms, RDM stands out for its enhanced efficacy by concentrating on scalar risk distributions, sidestepping the pitfalls of the high-dimensional challenges seen in feature or gradient matching. Extensive experiments on standard benchmark datasets demonstrate that RDM shows superior generalisation capability over state-of-the-art DG methods.

1 Introduction

In recent years, deep learning (DL) models have witnessed remarkable achievements and demonstrated super-human performance on training distributions [27]. Nonetheless, this success comes with a caveat: deep models are vulnerable to distributional shifts and exhibit catastrophic failures on unseen out-of-domain data [34, 12]. Such limitations hinder the widespread deployment of DL systems in real-world applications, where domain differences can be induced by several factors, such as spurious correlations [2] or variations in location or time [49].

Figure 1: Risk distributions derived from training with ERM for the “Art” and “Photo” domains on the validation set of PACS dataset. Beyond low-risk samples, which may resemble training data, the “Photo” domain generally exhibits a larger distribution of risk values compared to “Art”, hinting at an inherent complexity in learning “Photo” samples. The figure indicates our motivation that risk distributions can effectively highlight differences between domains.

In light of these challenges, domain generalisation (DG) aims to produce models capable of generalising to unseen target domains by leveraging data from diverse sets of training domains or environments [38]. An effective approach involves exploring and establishing domain invariance [32], with the expectation that these invariances will similarly apply to related, yet distinct, test domains. To this end, prevailing research focuses on characterising domains through sample representations [31, 38]. The objective is to seek domain-invariant features by aligning the distributions of hidden representations across various domains. CORAL [54] trains a non-linear transformation that aligns the second-order statistics of representations across different layers within deep networks. Moreover, CausIRL [11] aims to match representation distributions that have been intervened upon with respect to spurious factors. While these methods show promise, they face multiple challenges arising from the curse of dimensionality [6, 22]. The sparsity of high-dimensional representation spaces can lead to unreliable estimates of statistical properties, which in turn degrades the quality of distribution matching. Moreover, high-dimensional representations may contain many irrelevant or redundant dimensions, which can obscure the true underlying similarities or differences between distributions. As dimensionality rises, computational complexity intensifies, reducing the efficacy of these methods [39]. Similar challenges arise in DG methods that utilise gradients for domain alignment [50, 46].

In this paper, we propose to utilise scalar risk distributions as a means to characterise domains, enabling us to explore and enforce domain invariance. Our research reveals that risk distributions can be a reliable indicator of domain variation, as they effectively highlight differences between training domains. In Figure 1, we present visual evidence through histograms, contrasting the risk distributions between the “Art” and “Photo” domains on the validation set of the PACS dataset [30], derived from training with Empirical Risk Minimisation (ERM) [57]. The “Photo” domain generally exhibits a larger distribution of scalar risks than “Art”. This suggests an inherent complexity in learning “Photo” samples, or may reflect a more limited training dataset compared to “Art”. During the testing phase, similar divergences between risk distributions may emerge, potentially intensifying in magnitude. Hence, we propose a compelling proposition: by minimising the divergences between risk distributions across training domains, we can achieve robust invariance for DG. The underlying rationale for this concept is that a model, when learning domain-invariant and stable features, tends to produce consistent risk distributions across domains.

Building upon this idea, we propose a novel matching approach for DG, namely Risk Distribution Matching (RDM). RDM’s objective is to minimise the variance of risk distributions across all training domains. Inspired by [38], we redefine the distributional variance metric to focus specifically on risk distributions and propose to compute it via the maximum mean discrepancy (MMD) distance [19]. However, when the number of training domains increases, directly optimising the variance induces a linear growth in MMD computations, reducing efficiency. Instead, we propose an approximation that requires only one MMD computation via aligning just two distributions: that of the worst-case (or worst-performing) domain and the aggregated distribution from all domains. Empirically, this approach outperforms optimising distributional variance while significantly reducing computational complexity. Unlike prevailing matching algorithms, RDM can address the high-dimensional challenges and further improve efficacy by exclusively focusing on scalar risk distributions. Notably, our empirical studies show that RDM even exhibits enhanced generalisation while being more convenient to optimise. We summarise our contributions below:

  • We propose RDM, a novel and efficient matching method for DG, based on our two hypotheses: i) risk distribution disparities offer insightful cues into domain variation; ii) reducing these divergences fosters a generalisable and invariant feature-learning predictor.

  • We re-conceptualise the distributional variance metric to exclusively focus on risk distributions, with an objective to minimise it. We further provide an approximate version that aligns only the risk distribution of the worst-case domain with the aggregate from all domains, improving both performance and efficiency.

  • Through extensive experiments on standard benchmark datasets, we empirically show that RDM consistently outperforms state-of-the-art DG methods, showcasing its remarkable generalisation capability.

2 Related Work

Domain Generalisation (DG)

DG aims to develop models that can generalise well to unseen target domains by leveraging knowledge from multiple source domains. Typical DG methods include domain alignment [38, 7, 32], meta learning [29, 4], data augmentation [63, 61], disentangled representation learning [55, 44], robust optimisation [48, 8], and causality-based methods [25, 13, 41]. Our proposed method RDM is related to domain alignment, striving for domain invariance to enhance OOD generalisation. Existing research focuses on characterising domains through sample representations and aligning their distributions across domains to achieve domain-invariant features [1, 38]. CORAL [54] matches the mean and variance of representation distributions, while MMD-AAE [31] and FedKA [56] consider matching all moments via the maximum mean discrepancy (MMD) distance [19]. Other methods promote domain invariance by minimising a contrastive loss [10] between representations sharing the same labels [35, 37]. Many studies bypass the representation focus, instead characterising domains via gradients and achieving invariance by reducing inter-domain gradient variance [50, 60, 46].

Despite their potential, aligning these high-dimensional distributions may be affected by data sparsity, diversity, and high computational demands [6, 22]. Unlike these methods, RDM offers enhanced efficacy by focusing on scalar risk distributions, overcoming the high-dimensional challenges. Further, RDM adopts a novel strategy by efficiently aligning only two distributions: that of the worst-case domain and the aggregate from all domains. In our experiments, RDM generally exhibits better generalisation performance while being more convenient to optimise than competing matching techniques. To the best of our knowledge, the use of risk distributions for domain matching in RDM is novel.

Distribution matching

Distribution matching has been an important topic with a wide range of applications in machine learning, such as DG [31, 54], domain adaptation [59, 9], and generative modelling [28, 33]. Early methods, like the MMD distance [19], leverage kernel-based approaches to quantify the distance between distributions, laying the foundation for many subsequent DG techniques [31, 38]. Further advancements have explored optimal transport methods, like the Wasserstein distance [36, 3], which provides a geometrically intuitive means to compare distributions. Other metrics, such as the Kullback-Leibler [26] or Jensen-Shannon [16, 14] divergences, can serve to measure the divergence between distributions but may require additional parameters for estimating the density ratio between the two distributions [53]. In this paper, we utilise the MMD distance to align risk distributions. Its inherent advantages include an analytical measure of the divergence between distributions without relying on distribution densities, and its non-parametric nature [19]. Alternative DG methods augment data by utilising distribution matching and style transfer to generate semantic-preserving samples [62, 63]. Our method differs as we emphasise domain invariance via aligning risk distributions, rather than augmenting representation distributions.

Invariance and Causality in DG

Causal methods in DG assume that the causal mechanism of the target given causal input features is invariant, while non-causal features may change across domains [2, 45, 25]. Based on this assumption, these methods establish domain invariance to recover the causal mechanism, thereby improving generalisation. ICP [45] has shown that the causal predictor has an invariant distribution of residuals in regression models; however, it is not suitable for deep learning. EQRM [13] and REx [25] leverage the invariance of the average risk over samples across domains. In contrast to the above methods, we consider matching entire risk distributions over samples across domains, which, as our experiments demonstrate, is more powerful and enhances generalisation capability.

3 Preliminaries

Domain generalisation (DG) involves training a classifier $f$ on data composed of multiple training domains (also called environments) so that $f$ can perform well on unseen domains at test time. Mathematically, let $\mathcal{D}=\{\mathcal{D}_1,\ldots,\mathcal{D}_m\}$ denote the training set consisting of $m$ different domains/environments, and let $\mathcal{D}_e := \{(x_e^i, y_e^i)\}_{i=1}^{n_e}$ denote the training data belonging to domain $e$ ($1 \leq e \leq m$). Given a loss function $\ell$, the risk of a particular domain sample $(x_e^i, y_e^i)$ is denoted by $R_e^i := \ell(f(x_e^i), y_e^i)$, and the expected risk $\overline{R}_e$ of domain $e$ is defined as:

$$\overline{R}_e := \mathbb{E}_{(x_e, y_e)\sim\mathcal{D}_e}\left[\ell\left(f(x_e), y_e\right)\right] = \mathbb{E}_{\mathcal{D}_e}\left[R_e\right] \qquad (1)$$

A common approach to train $f$ is Empirical Risk Minimisation (ERM) [57], which minimises the expected risks across all training domains. Its loss function, denoted by $\mathcal{L}_{\text{ERM}}$, is computed as follows:

$$\mathcal{L}_{\text{ERM}} = \mathbb{E}_{e\sim\mathcal{E}}\,\mathbb{E}_{(x_e, y_e)\sim\mathcal{D}_e}\left[\ell\left(f(x_e), y_e\right)\right] \qquad (2)$$
$$= \mathbb{E}_{e\sim\mathcal{E}}\left[\overline{R}_e\right] \qquad (3)$$

where $\mathcal{E} := \{1,\ldots,m\}$ denotes the set of all domains.
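To make the notation concrete, below is a minimal PyTorch sketch of the ERM objective in Eqs. 2-3. The names `model` and `domain_batches`, and the use of cross-entropy as the loss $\ell$, are illustrative assumptions rather than prescriptions of our setup.

```python
import torch
import torch.nn.functional as F

def erm_loss(model, domain_batches):
    """L_ERM: the expected risk averaged over all training domains.

    `domain_batches` is a list of (x_e, y_e) mini-batches, one per domain;
    cross-entropy plays the role of the loss function l.
    """
    domain_risks = []
    for x_e, y_e in domain_batches:
        # Per-sample risks R_e^i = l(f(x_e^i), y_e^i)
        risks = F.cross_entropy(model(x_e), y_e, reduction="none")
        domain_risks.append(risks.mean())    # expected risk of domain e
    return torch.stack(domain_risks).mean()  # average over domains
```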

4 Risk Distribution Matching

A model $f$ trained via ERM often struggles to generalise to new test domains. This is because it tends to capture domain-specific features [2, 41], such as domain styles, to achieve low risks in training domains, rather than focusing on domain-invariant or semantic features. To overcome this issue, we present a novel training objective that bolsters generalisation through domain invariance. Our goal requires a domain representative that both characterises each domain and provides valuable insights into domain variation. Specifically, we propose to leverage the distribution of risks over all samples within a domain (or, shortly, the risk distribution) as this representative. Unlike other domain representatives, such as latent representation or gradient distributions [31, 50], the risk distribution sidesteps high-dimensional challenges like data sparsity and high computational demands [6, 39]. In essence, a model capturing stable, domain-invariant features should consistently yield similar risk distributions across all domains. In pursuit of invariant models, we propose Risk Distribution Matching (RDM), a novel approach for DG that reduces the divergences between training risk distributions by minimising the distributional variance across them.

Let $\mathcal{T}_e$ be the probability distribution over the risks of all samples in domain $e$ (i.e., $\{R_e^i\}_{i=1}^{n_e}$). We refer to $\mathcal{T}_e$ as the risk distribution of domain $e$: the representative that captures the core characteristics of the domain. We denote by $\mathbb{V}_{\mathbb{R}}(\{\mathcal{T}_1,\ldots,\mathcal{T}_m\})$ the distributional variance across the risk distributions $\{\mathcal{T}_1,\ldots,\mathcal{T}_m\}$ in the real number space. We achieve our objective by minimising the following loss function:

$$\mathcal{L}_{\text{final}} := \mathcal{L}_{\text{ERM}} + \lambda\,\mathbb{V}_{\mathbb{R}}\left(\{\mathcal{T}_1,\ldots,\mathcal{T}_m\}\right) \qquad (4)$$

where $\lambda \geq 0$ is a coefficient balancing the reduction of the total training risk against the enforcement of invariance across domains. $\lambda$ is set to 1 unless specified otherwise.

To compute $\mathbb{V}_{\mathbb{R}}(\{\mathcal{T}_1,\ldots,\mathcal{T}_m\})$, we require a suitable representation for the implicit risk distribution $\mathcal{T}_e$ of domain $e$. Leveraging kernel mean embedding [51], we express $\mathcal{T}_e$ as its embedding, $\mu_{\mathcal{T}_e}$, within a reproducing kernel Hilbert space (RKHS) $\mathcal{H}$ using a feature map $\phi: \mathbb{R} \rightarrow \mathcal{H}$ as follows:

$$\mu_{\mathcal{T}_e} := \mathbb{E}_{R_e\sim\mathcal{T}_e}\left[\phi(R_e)\right] \qquad (5)$$
$$= \mathbb{E}_{R_e\sim\mathcal{T}_e}\left[k(R_e, \cdot)\right] \qquad (6)$$

where a kernel function $k(\cdot,\cdot): \mathbb{R}\times\mathbb{R}\rightarrow\mathbb{R}$ is introduced to bypass the explicit specification of $\phi$. Provided that $\mathbb{E}_{R_e\sim\mathcal{T}_e}\left[k(R_e, R_e)\right] < \infty$, the mean map $\mu_{\mathcal{T}_e}$ is an element of $\mathcal{H}$ [19, 31]. It is noteworthy that for a characteristic kernel $k$, the representation $\mu_{\mathcal{T}_e}$ within $\mathcal{H}$ is unique [38, 19]. Consequently, two distinct risk distributions $\mathcal{T}_u$ and $\mathcal{T}_v$ for any domains $u, v$ have different kernel mean embeddings in $\mathcal{H}$. In this work, we use the RBF kernel, a well-known characteristic kernel defined as $k(x, x') := \exp\left(-\frac{1}{2\sigma}\left\|x - x'\right\|^2\right)$, where $\sigma > 0$ is the bandwidth parameter.
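Since risks are scalars, the RBF kernel admits a simple vectorised implementation. The sketch below (the default bandwidth is an illustrative assumption) computes the Gram matrix that the MMD estimator in Section 4.1 relies on.

```python
import torch

def rbf_kernel(a: torch.Tensor, b: torch.Tensor, sigma: float = 1.0) -> torch.Tensor:
    """Gram matrix K[i, j] = exp(-(a_i - b_j)^2 / (2 * sigma)) for 1-D risks."""
    sq_dists = (a.view(-1, 1) - b.view(1, -1)) ** 2
    return torch.exp(-sq_dists / (2.0 * sigma))
```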

With the unique representation of $\mathcal{T}_e$ established, our objective becomes computing the distributional variance between risk distributions within $\mathcal{H}$, denoted $\mathbb{V}_{\mathcal{H}}(\{\mathcal{T}_1,\ldots,\mathcal{T}_m\})$. Inspired by [38], we redefine the variance metric to focus specifically on risk distributions across multiple domains:

$$\mathbb{V}_{\mathcal{H}}\left(\{\mathcal{T}_1,\ldots,\mathcal{T}_m\}\right) := \frac{1}{m}\sum_{e=1}^{m}\left\|\mu_{\mathcal{T}_e} - \mu_{\mathcal{T}}\right\|_{\mathcal{H}}^{2} \qquad (7)$$

where $\mathcal{T} = \frac{1}{m}\sum_{e=1}^{m}\mathcal{T}_e$ denotes the probability distribution over the risks of all samples in the entire training set, or equivalently, the mixture over all $m$ domains. Meanwhile, $\mu_{\mathcal{T}_e}$ and $\mu_{\mathcal{T}}$ represent the mean embeddings of $\mathcal{T}_e$ and $\mathcal{T}$, respectively, and are computed as in Eq. 5. Incorporating $\mathbb{V}_{\mathcal{H}}(\{\mathcal{T}_1,\ldots,\mathcal{T}_m\})$ into our loss function from Eq. 4, we get:

$$\mathcal{L}_{\text{final}} := \mathcal{L}_{\text{ERM}} + \lambda\,\mathbb{V}_{\mathcal{H}}\left(\{\mathcal{T}_1,\ldots,\mathcal{T}_m\}\right) \qquad (8)$$

Minimising $\mathbb{V}_{\mathcal{H}}(\{\mathcal{T}_1,\ldots,\mathcal{T}_m\})$ in Eq. 8 facilitates our objective of equalising risk distributions across all domains, as proven by the theorem below.

Theorem 1.

[38] Given that the distributional variance $\mathbb{V}_{\mathcal{H}}(\{\mathcal{T}_1,\ldots,\mathcal{T}_m\})$ is calculated with a characteristic kernel $k$, $\mathbb{V}_{\mathcal{H}}(\{\mathcal{T}_1,\ldots,\mathcal{T}_m\}) = 0$ if and only if $\mathcal{T}_1 = \ldots = \mathcal{T}_m\ (= \mathcal{T})$.

Proof.

Please refer to our appendix for the proof. ∎

In the next part, we present how to compute the distributional variance using the Maximum Mean Discrepancy (MMD) distance [19], relying only on risk samples. Then, we propose an efficient approximation of optimising the distributional variance, yielding improved empirical performance.

4.1 Maximum Mean Discrepancy

For domain $e$, the squared norm $\left\|\mu_{\mathcal{T}_e} - \mu_{\mathcal{T}}\right\|_{\mathcal{H}}^{2}$ defined in Eq. 7 is identified as the squared MMD distance [18] between the distributions $\mathcal{T}_e$ and $\mathcal{T}$. It is expressed as follows:

$$\text{MMD}^2\left(\mathcal{T}_e, \mathcal{T}\right) = \left\|\mu_{\mathcal{T}_e} - \mu_{\mathcal{T}}\right\|_{\mathcal{H}}^{2} \qquad (9)$$
$$= \left\|\mathbb{E}_{R_e\sim\mathcal{T}_e}\left[\phi(R_e)\right] - \mathbb{E}_{R_f\sim\mathcal{T}}\left[\phi(R_f)\right]\right\|_{\mathcal{H}}^{2} \qquad (10)$$
$$= \mathbb{E}_{R_e, R_e'\sim\mathcal{T}_e}\left\langle\phi(R_e), \phi(R_e')\right\rangle - 2\,\mathbb{E}_{R_e\sim\mathcal{T}_e;\,R_f\sim\mathcal{T}}\left\langle\phi(R_e), \phi(R_f)\right\rangle + \mathbb{E}_{R_f, R_f'\sim\mathcal{T}}\left\langle\phi(R_f), \phi(R_f')\right\rangle \qquad (11)$$

where $\langle\cdot,\cdot\rangle$ denotes the inner product in $\mathcal{H}$. Through the kernel trick, we can compute these inner products via the kernel function $k$ without an explicit form of $\phi$:

$$\text{MMD}^2\left(\mathcal{T}_e, \mathcal{T}\right) = \mathbb{E}_{R_e, R_e'\sim\mathcal{T}_e}\,k\left(R_e, R_e'\right) - 2\,\mathbb{E}_{R_e\sim\mathcal{T}_e;\,R_f\sim\mathcal{T}}\,k\left(R_e, R_f\right) + \mathbb{E}_{R_f, R_f'\sim\mathcal{T}}\,k\left(R_f, R_f'\right) \qquad (12)$$
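In practice, Eq. 12 is estimated from finite sets of risk samples. A minimal sketch of the standard biased (V-statistic) estimator, reusing the `rbf_kernel` helper from the earlier sketch:

```python
import torch

def mmd_squared(risks_e: torch.Tensor, risks_all: torch.Tensor,
                sigma: float = 1.0) -> torch.Tensor:
    """Empirical MMD^2(T_e, T) between two sets of scalar risks (Eq. 12)."""
    k_ee = rbf_kernel(risks_e, risks_e, sigma).mean()
    k_ef = rbf_kernel(risks_e, risks_all, sigma).mean()
    k_ff = rbf_kernel(risks_all, risks_all, sigma).mean()
    return k_ee - 2.0 * k_ef + k_ff
```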

We reformulate our loss function in Eq. 8 to incorporate MMD as follows:

$$\mathcal{L}_{\text{final}} := \mathcal{L}_{\text{ERM}} + \frac{\lambda}{m}\sum_{e=1}^{m}\text{MMD}^2\left(\mathcal{T}_e, \mathcal{T}\right) \qquad (13)$$
$$= \mathcal{L}_{\text{ERM}} + \lambda\,\mathcal{L}_{\text{RDM}} \qquad (14)$$

The loss term $\mathcal{L}_{\text{RDM}}$ involves minimising $\text{MMD}^2(\mathcal{T}_e, \mathcal{T})$ for every domain $e$. Ideally, the distributional variance reaches its lowest value of 0 when $\text{MMD}(\mathcal{T}_e, \mathcal{T}) = 0$, equivalent to $\mathcal{T}_e = \mathcal{T}$ [19, 18], for all $m$ domains. The objective thus entails aligning each individual risk distribution $\mathcal{T}_e$ with the aggregated distribution spanning all domains, $\mathcal{T}$. With the characteristic RBF kernel, this can be viewed as matching an infinite number of moments across all risk distributions.
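For reference, a direct sketch of $\mathcal{L}_{\text{RDM}}$ in Eq. 13 averages one MMD$^2$ term per domain; this is the $O(m)$ variant that Section 4.2 later approximates with a single term. It reuses `mmd_squared` from the sketch above.

```python
def rdm_loss_full(per_domain_risks, sigma: float = 1.0):
    """L_RDM = (1/m) * sum_e MMD^2(T_e, T), with T the pooled risks."""
    pooled = torch.cat(per_domain_risks)
    terms = [mmd_squared(r, pooled, sigma) for r in per_domain_risks]
    return torch.stack(terms).mean()
```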

We emphasise our choice of MMD owing to its benefits for effective risk distribution matching: i) MMD is an important member of the Integral Probability Metric family [40] that offers an analytical solution facilitated through RKHS, and ii) MMD enjoys the property of quantifying the dissimilarity between two implicit distributions via their finite samples in a non-parametric manner.

4.2 Further improvement of RDM

We find that effective alignment of risk distributions across $m$ domains can be achieved by matching the risk distribution of the worst-case domain, denoted $w$, with the combined risk distribution of all domains, offering an approximation to the optimisation of the risk distributional variance in Eq. 13. This approximate version reduces the number of MMD computations in $\mathcal{L}_{\text{RDM}}$ from $O(m)$ to $O(1)$, and further improves generalisation, as we demonstrate with empirical evidence in Section 5.

Denote by $w = \operatorname{argmax}_{e\in\mathcal{E}} \overline{R}_e$ the worst-case domain, i.e., the domain with the largest expected risk in $\mathcal{E}$. The approximate RDM loss, $\hat{\mathcal{L}}_{\text{RDM}}$, is computed as follows:

$$\hat{\mathcal{L}}_{\text{RDM}} = \text{MMD}^2\left(\mathcal{T}_w, \mathcal{T}\right) \qquad (15)$$
$$\approx \mathcal{L}_{\text{RDM}} \qquad (16)$$

In our experiments, we observed only a small gap between $\hat{\mathcal{L}}_{\text{RDM}}$ and $\mathcal{L}_{\text{RDM}}$, while optimising $\hat{\mathcal{L}}_{\text{RDM}}$ proved to be more computationally efficient. The key insight comes from $\overline{R}_e$, the first moment (or mean) of $\mathcal{T}_e$. The average risk can often serve as a measure of domain uniqueness or divergence [25, 46]. Specifically, a domain with a notably distinct mean risk is more likely to diverge greatly from the other risk distributions. Under such circumstances, $\hat{\mathcal{L}}_{\text{RDM}}$ is an upper bound of $\mathcal{L}_{\text{RDM}}$, as shown by: $\mathcal{L}_{\text{RDM}} = \frac{1}{m}\sum_{e=1}^{m}\text{MMD}^2(\mathcal{T}_e, \mathcal{T}) \leq \frac{1}{m}\sum_{e=1}^{m}\text{MMD}^2(\mathcal{T}_w, \mathcal{T}) = \text{MMD}^2(\mathcal{T}_w, \mathcal{T}) = \hat{\mathcal{L}}_{\text{RDM}}$. By optimising $\hat{\mathcal{L}}_{\text{RDM}}$, we can therefore also decrease $\mathcal{L}_{\text{RDM}}$, aligning risk distributions across domains effectively. Moreover, $\hat{\mathcal{L}}_{\text{RDM}}$ drives the model to prioritise the optimisation of the worst-case domain. This enhances the model's robustness to extreme training scenarios, which further improves generalisation, as shown in [48, 25]. These insights shed light on the superior performance of optimising $\hat{\mathcal{L}}_{\text{RDM}}$ over $\mathcal{L}_{\text{RDM}}$. We therefore opt to use $\hat{\mathcal{L}}_{\text{RDM}}$, simplifying our model's training and further bolstering its OOD performance.
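Putting the pieces together, below is a minimal sketch of one training step with $\mathcal{L}_{\text{final}} = \mathcal{L}_{\text{ERM}} + \lambda\hat{\mathcal{L}}_{\text{RDM}}$, reusing the helpers above. Details such as how gradients flow through the pooled risks are simplified assumptions here; our released code is authoritative.

```python
import torch
import torch.nn.functional as F

def rdm_step(model, optimizer, domain_batches, lam=1.0, sigma=1.0):
    """One optimisation step of L_ERM + lam * L-hat_RDM (Eqs. 14-15), sketched."""
    per_domain_risks = [
        F.cross_entropy(model(x_e), y_e, reduction="none")
        for x_e, y_e in domain_batches
    ]
    mean_risks = torch.stack([r.mean() for r in per_domain_risks])
    erm = mean_risks.mean()                  # L_ERM
    w = int(torch.argmax(mean_risks))        # worst-case domain w
    pooled = torch.cat(per_domain_risks)     # samples from T
    rdm_hat = mmd_squared(per_domain_risks[w], pooled, sigma)
    loss = erm + lam * rdm_hat
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return float(loss)
```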

5 Experiments

We evaluate and analyse RDM using a synthetic ColoredMNIST dataset [2] and multiple benchmarks from the DomainBed suite [20]. Each of our claims is backed by empirical evidence in this section. Our source code to reproduce results is available at: https://github.com/nktoan/risk-distribution-matching

5.1 Synthetic Dataset: ColoredMNIST

We evaluate all baselines on a synthetic binary classification task, namely ColoredMNIST [2]. This dataset involves categorising digits (0-9) into two labels: “zero” for the 0-4 range and “one” for the 5-9 range, with each digit colored either red or green. The dataset is designed to assess the generalisation and robustness of baseline models against the influence of spurious color features. The dataset contains two training domains, where the chance of red digits being classified as “zero” is 80% and 90%, respectively, while this probability decreases to only 10% during testing. The goal is to train a predictor invariant to “digit color” features, capturing only “digit shape” features.
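To illustrate the construction, the following is a simplified NumPy sketch of a single ColoredMNIST environment as described above. It is an assumption-laden illustration, not a faithful reimplementation: the original recipe in [2] includes further steps (e.g., label noise) that we omit here.

```python
import numpy as np

def make_environment(digits: np.ndarray, p_agree: float, seed: int = 0):
    """Binary label from the digit value; color (0 = red, 1 = green) agrees
    with the label with probability p_agree (0.9 / 0.8 for the two training
    domains, 0.1 for the test domain)."""
    rng = np.random.default_rng(seed)
    labels = (digits >= 5).astype(np.int64)    # "zero": 0-4, "one": 5-9
    agree = rng.random(labels.shape) < p_agree
    colors = np.where(agree, labels, 1 - labels)
    return labels, colors
```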

Following [13], we employ a two-hidden-layer MLP with 390 hidden units for all baselines. Each algorithm is optimised with the Adam optimiser [23] at a learning rate of 0.0001 and a dropout rate of 0.2, and trained for 600 iterations with a batch size of 25,000. We repeat the experiment ten times over different values of the penalty weight λ. Since the magnitude of our matching penalty is quite small, RDM performs optimally with λ in the range [1000, 10000]. We provide more details about the experimental settings in the supplementary material.

Algorithm | Rand. init | ERM init
ERM | 27.9 ± 1.5 | 27.9 ± 1.5
GroupDRO | 27.3 ± 0.9 | 29.0 ± 1.1
IGA | 50.7 ± 1.4 | 57.7 ± 3.3
IRM | 52.5 ± 2.4 | 69.7 ± 0.9
VREx | 55.2 ± 4.0 | 71.6 ± 0.5
EQRM | 53.4 ± 1.7 | 71.4 ± 0.4
CORAL | 55.3 ± 2.8 | 65.6 ± 1.1
MMD | 54.6 ± 3.2 | 66.4 ± 1.7
RDM (ours) | 56.3 ± 1.5 | 72.4 ± 1.0
Oracle | 72.1 ± 0.7 |
Optimum | 75.0 |

Table 1: ColoredMNIST test accuracy (%) under random and ERM initialisation; the best results are marked in bold. Results of other methods are referenced from [13].
Figure 2: Histograms with their KDE curves depicting the risk distributions of ERM and RDM across four domains on ColoredMNIST. Vertical ticks denote the mean values of all distributions.

We compare RDM with ERM and three different types of algorithms: robust optimisation (GroupDRO [48], IGA [24]), causal methods learning invariance (IRM [2], VREx [25], EQRM [13]), and representation distribution matching (MMD [31], CORAL [54]). All algorithms are run using two distinct network configurations: (i) initialising the network randomly via the Xavier method [17]; (ii) pre-training the network with ERM for 400 iterations prior to performing the algorithms. Table 1 shows that our proposed method RDM surpasses all algorithms, irrespective of the network configuration. RDM exhibits improvements of 1.0% and 6.8% over CORAL, without and with pre-trained ERM respectively, underlining the effectiveness of aligning risk distributions instead of high-dimensional representations. VREx and EQRM, which pursue invariant predictors by equalising average training risks across domains, demonstrate suboptimal performance compared to our approach. This improvement arises from our consideration of the entire risk distributions and the matching of all moments across them, which inherently fosters stronger invariance for DG. Notably, all methods experience enhanced performance with ERM initialisation. RDM even exceeds oracle performance (ERM trained on grayscale digits with 50% red and 50% green) and converges towards optimality.

Figure 2 presents histograms with their KDE curves [42] depicting the risk distributions of ERM and RDM across four domains. The figure confirms our hypothesis that the disparities among risk distributions can serve as a valuable signal of domain variation. ERM's histogram shows a clear difference between the environments with 90% and 80% chance of red digits labelled “zero” and those with only 50% or 10%. Moreover, ERM tends to overfit to the training domains, which negatively impacts its generalisation to test domains. Remarkably, RDM effectively minimises the divergences between risk distributions across all domains, including the test domains, while keeping risks low. This also aligns with our motivation: an invariant or stable feature-learning predictor, by displaying similar risk distributions across domains, inherently boosts generalisation.

5.2 DomainBed

Algorithm | VLCS | PACS | OfficeHome | TerraIncognita | DomainNet | Avg
ERM | 77.5 ± 0.4 | 85.5 ± 0.2 | 66.5 ± 0.3 | 46.1 ± 1.8 | 40.9 ± 0.1 | 63.3
Mixup | 77.4 ± 0.6 | 84.6 ± 0.6 | 68.1 ± 0.3 | 47.9 ± 0.8 | 39.2 ± 0.1 | 63.4
MLDG | 77.2 ± 0.4 | 84.9 ± 1.0 | 66.8 ± 0.6 | 47.7 ± 0.9 | 41.2 ± 0.1 | 63.6
GroupDRO | 76.7 ± 0.6 | 84.4 ± 0.8 | 66.0 ± 0.7 | 43.2 ± 1.1 | 33.3 ± 0.2 | 60.9
IRM | 78.5 ± 0.5 | 83.5 ± 0.8 | 64.3 ± 2.2 | 47.6 ± 0.8 | 33.9 ± 2.8 | 61.6
VREx | 78.3 ± 0.2 | 84.9 ± 0.6 | 66.4 ± 0.6 | 46.4 ± 0.6 | 33.6 ± 2.9 | 61.9
EQRM | 77.8 ± 0.6 | 86.5 ± 0.2 | 67.5 ± 0.1 | 47.8 ± 0.6 | 41.0 ± 0.3 | 64.1
Fish | 77.8 ± 0.3 | 85.5 ± 0.3 | 68.6 ± 0.4 | 45.1 ± 1.3 | 42.7 ± 0.2 | 64.0
Fishr | 77.8 ± 0.1 | 85.5 ± 0.4 | 67.8 ± 0.1 | 47.4 ± 1.6 | 41.7 ± 0.0 | 64.0
CORAL | 78.8 ± 0.6 | 86.2 ± 0.3 | 68.7 ± 0.3 | 47.6 ± 1.0 | 41.5 ± 0.1 | 64.6
MMD | 77.5 ± 0.9 | 84.6 ± 0.5 | 66.3 ± 0.1 | 42.2 ± 1.6 | 23.4 ± 9.5 | 63.3
RDM (ours) | 78.4 ± 0.4 | 87.2 ± 0.7 | 67.3 ± 0.4 | 47.5 ± 1.0 | 43.4 ± 0.3 | 64.8

Table 2: DomainBed test accuracy (%); the best results are marked in bold. Results of other methods are referenced from [13, 50]. Model selection: training-domain validation set.
Dataset and Protocol

Following previous works [20, 13], we extensively evaluate all methods on five well-known DG benchmarks: VLCS [15], PACS [30], OfficeHome [58], TerraIncognita [5], and DomainNet [43]. For a fair comparison, we reuse the training and evaluation protocol in DomainBed [20], including the dataset splits, training iterations, and model selection criteria. Our evaluation employs the leave-one-domain-out approach: each model is trained on all domains except one and then tested on the excluded domain. The final model is chosen based on its combined accuracy across all training-domain validation sets.

Implementation Details

We use ResNet-50 [21] pre-trained on ImageNet [47] as the default backbone. The model is optimised via the Adam optimiser for 5,000 iterations on every dataset. Following [13, 25], we pre-train baselines with ERM for a certain number of iterations before performing the algorithms. Importantly, we find that accurate risk distribution matching from distribution samples requires larger batch sizes; details are examined in our ablation studies. For most datasets, the optimal batch size lies between [70, 100]. However, for huge datasets like TerraIncognita and DomainNet, it is between [30, 60]. Although computational resources limit us from testing larger batch sizes, these ranges consistently achieve strong performance on the benchmarks. The matching coefficient λ in our method is set within [0.1, 10.0]. Additional hyper-parameters, such as learning rate, dropout rate, and weight decay, adhere to the preset ranges detailed in [13]. We provide more implementation details in the supplementary material. We repeat our experiments ten times with varied seed values and hyper-parameters and report the average results.

Experimental Results

In Table 2, we show the average out-of-domain (OOD) accuracies of state-of-the-art DG methods on five benchmarks. Due to space constraints, domain-specific accuracies are detailed in the supplementary material. We compare RDM with ERM and various types of algorithms: distributional robustness (GroupDRO), causal methods learning invariance (IRM, VREx, EQRM), gradient matching (Fish [50], Fishr [46]), representation distribution matching (MMD, CORAL) and other variants (Mixup [61], MLDG [29]). To ensure fairness in our evaluations, we have used the same training data volume across all baselines, although further employing augmentations can enhance models’ performance.

On average, RDM surpasses the other baselines across all benchmarks, notably achieving a 1.5% average improvement over ERM. The significant improvement of RDM on DomainNet, a large-scale dataset with 586,575 images across 6 domains, is worth mentioning. This suggests that characterising domains with risk distributions to achieve invariance effectively enhances OOD performance. Compared to distributional robustness methods, RDM notably outperforms GroupDRO, with an improvement of 2.8% on PACS and a substantial 10.1% on DomainNet. RDM consistently improves over causality-based methods that rely on the average risk for domain invariance. This superiority is attributable to our novel adoption of risk distributions, which achieves enhanced invariance for DG. Our remarkable improvement over MMD suggests that aligning risk distributions via the MMD distance is more effective and easier to optimise than aligning representation distributions. While RDM typically outperforms CORAL and Fish in OOD scenarios, it only remains competitive or sometimes underperforms on certain datasets like OfficeHome. This decrease in performance may stem from the dataset's inherent tendency to induce overfitting within our risk distribution alignment objective. OfficeHome has on average only about 240 samples per class, significantly fewer than the other datasets with at least 1,400. This reduced sample size may not provide sufficiently diverse risk distributions to capture stable class features, resulting in overfitting on the training set. Despite these limitations, our OfficeHome results still outperform several well-known baselines such as MLDG, VREx, and ERM. For a detailed discussion of this challenge, please refer to our supplementary material.

5.3 Analysis

Figure 3: Figure 3a supports our claim about the approximation of the distributional variance, while Figure 3b compares the OOD performance learning curves of RDM with other methods. These insights are visualised every 15 iterations during PACS training, excluding the OOD Sketch domain. Our visual analysis commences after RDM has been pre-trained with ERM for 100 (×15) iterations, ensuring a fair comparison.

In this section, we provide empirical evidence backing our claims in Section 4. In Figure 3a, we highlight the small gap between aligning the risk distribution of the worst-case domain with that of all domains combined (RDM with $\hat{\mathcal{L}}_{\text{RDM}}$) and directly optimising the distributional variance (RDM with $\mathcal{L}_{\text{RDM}}$). Notably, $\hat{\mathcal{L}}_{\text{RDM}}$ consistently represents an upper bound of $\mathcal{L}_{\text{RDM}}$, which is sensible since the worst-case domain often exhibits the most distinct risk distribution. This suggests that optimising $\hat{\mathcal{L}}_{\text{RDM}}$ also helps reduce the distributional variance $\mathcal{L}_{\text{RDM}}$, bringing the risk distributions across domains closer together.

Algorithm | Training (s) | Mem (GiB) | Acc (%)
Fish | 11,502 | 5.26 | 42.7
CORAL | 11,504 | 17.00 | 41.5
RDM with $\mathcal{L}_{\text{RDM}}$ | 9,854 | 16.94 | 43.1
RDM with $\hat{\mathcal{L}}_{\text{RDM}}$ | 7,749 | 16.23 | 43.4

Table 3: Comparison between Fish, CORAL, and two variants of our method in terms of training time (seconds), memory usage per iteration (GiB), and accuracy (%) on DomainNet.

When the number of training domains grows, especially with large-scale datasets like DomainNet, emphasising the risk distribution of the worst-case domain not only proves to be a more efficient approach but also significantly enhances OOD performance. In our exploration of training resources for DomainNet, we study Fish, CORAL, and two variants of our RDM method. For a fair evaluation, all experiments were conducted with identical GPU resources, settings, and hyper-parameters, such as batch size and training iterations. Results are shown in Table 3. Due to space constraints, full details on training resources for these methods on the other datasets are available in the supplementary material.

Our RDM with the $\hat{\mathcal{L}}_{\text{RDM}}$ objective proves the fastest in training and achieves the highest accuracy of 43.4% on DomainNet. While RDM demands more memory than Fish, due to the storage of MMD distance values, it can be trained in less time (saving roughly an hour) and still delivers a 0.7% performance boost. This gain over Fish, a leading gradient matching method on DomainNet, is significant. Of the two RDM variants, the one using $\hat{\mathcal{L}}_{\text{RDM}}$ is both the fastest and the most accurate, justifying our claims about the benefits of aligning the risk distribution of the worst-case domain.

Moreover, to further highlight the efficacy of risk distribution alignment for DG, we compare the OOD performance learning curves of RDM with competing baselines using representation (CORAL) and gradient (Fish) alignment, as depicted in Figure 3b. Impressively, RDM consistently outperforms both, demonstrating enhanced generalisation throughout the training process.

5.4 Ablation studies

We explore the impact of the matching coefficient λ and the training batch size on risk distribution matching, primarily using the PACS dataset for brevity. While other datasets exhibit similar trends, their detailed results are provided in the supplementary material.

Matching coefficient λ

Figure 4a illustrates the performance of RDM on the PACS dataset for varying values of the matching coefficient λ, spanning {0.1, 1.0, 2.5, 5.0, 7.5, 10.0}. Notably, as λ increases, RDM's accuracy consistently improves, justifying the significance of our risk distribution matching module in fostering generalisation. In particular, when λ = 5.0, RDM demonstrates a notable 1.6% average accuracy boost across all domains, in contrast to using only λ = 0.1. Across most datasets, a λ value within [0.1, 10.0] appears sufficient to produce good results.

Batch size

We study the impact of batch size on RDM's performance. Our assumption is that achieving accurate risk distribution matching through data samples requires larger batch sizes. Figure 4b validates this, revealing enhanced generalisation results on PACS with increased batch sizes. For PACS, sizes between [70, 100] yield promising, potentially optimal outcomes, although computational limitations restrict our exploration of larger sizes.

Figure 4: Ablation studies on the effects of (a) the matching coefficient λ and (b) the training batch size on the PACS dataset.

6 Conclusion

We have demonstrated that RDM, a novel matching method for domain generalisation (DG), provides enhanced generalisation capability by aligning risk distributions across domains. RDM efficiently overcomes high-dimensional challenges of conventional DG matching methods. RDM is built on our observation that risk distributions can effectively represent the differences between training domains. By minimising these divergences, we can achieve an invariant and generalisable predictor. We further improve RDM by matching only the risk distribution of the worst-case domain with the aggregate from all domains, bypassing the need to directly compute the distributional variance. This approximate version not only offers computational efficiency but also delivers improved out-of-domain results. Our extensive experiments on several benchmarks reveal that RDM surpasses leading DG techniques. We hope our work can inspire further investigations into the benefits of risk distributions for DG.

References

  • [1] Isabela Albuquerque, João Monteiro, Mohammad Darvishi, Tiago H Falk, and Ioannis Mitliagkas. Generalizing to unseen domains via distribution matching. arXiv preprint arXiv:1911.00804, 2019.
  • [2] Martin Arjovsky, Léon Bottou, Ishaan Gulrajani, and David Lopez-Paz. Invariant risk minimization. arXiv preprint arXiv:1907.02893, 2019.
  • [3] Martin Arjovsky, Soumith Chintala, and Léon Bottou. Wasserstein generative adversarial networks. In ICML, pages 214–223, 2017.
  • [4] Yogesh Balaji, Swami Sankaranarayanan, and Rama Chellappa. Metareg: Towards domain generalization using meta-regularization. NIPS, 31, 2018.
  • [5] Sara Beery, Grant Van Horn, and Pietro Perona. Recognition in terra incognita. In ECCV, pages 456–473, 2018.
  • [6] Richard Bellman. Dynamic programming. Science, 153(3731):34–37, 1966.
  • [7] Shai Ben-David, John Blitzer, Koby Crammer, Alex Kulesza, Fernando Pereira, and Jennifer Wortman Vaughan. A theory of learning from different domains. Machine learning, 79:151–175, 2010.
  • [8] Peter Bühlmann. Invariance, causality and robustness. 2020.
  • [9] Chaoqi Chen, Weiping Xie, Wenbing Huang, Yu Rong, Xinghao Ding, Yue Huang, Tingyang Xu, and Junzhou Huang. Progressive feature alignment for unsupervised domain adaptation. In CVPR, pages 627–636, 2019.
  • [10] Ting Chen, Simon Kornblith, Mohammad Norouzi, and Geoffrey Hinton. A simple framework for contrastive learning of visual representations. In ICML, pages 1597–1607, 2020.
  • [11] Mathieu Chevalley, Charlotte Bunne, Andreas Krause, and Stefan Bauer. Invariant causal mechanisms through distribution matching. arXiv preprint arXiv:2206.11646, 2022.
  • [12] Andrea Dittadi, Frederik Träuble, Francesco Locatello, Manuel Wüthrich, Vaibhav Agrawal, Ole Winther, Stefan Bauer, and Bernhard Schölkopf. On the transfer of disentangled representations in realistic settings. ICLR, 2021.
  • [13] Cian Eastwood, Alexander Robey, Shashank Singh, Julius Von Kügelgen, Hamed Hassani, George J Pappas, and Bernhard Schölkopf. Probable domain generalization via quantile risk minimization. NeurIPS, 35:17340–17358, 2022.
  • [14] Dominik Maria Endres and Johannes E Schindelin. A new metric for probability distributions. IEEE Transactions on Information theory, pages 1858–1860, 2003.
  • [15] Chen Fang, Ye Xu, and Daniel N Rockmore. Unbiased metric learning: On the utilization of multiple datasets and web images for softening bias. In ICCV, pages 1657–1664, 2013.
  • [16] Bent Fuglede and Flemming Topsoe. Jensen-shannon divergence and hilbert space embedding. In ISIT, page 31. IEEE, 2004.
  • [17] Xavier Glorot and Yoshua Bengio. Understanding the difficulty of training deep feedforward neural networks. In AISTATS, pages 249–256, 2010.
  • [18] Arthur Gretton, Karsten Borgwardt, Malte Rasch, Bernhard Schölkopf, and Alex Smola. A kernel method for the two-sample-problem. NIPS, 19, 2006.
  • [19] Arthur Gretton, Karsten M Borgwardt, Malte J Rasch, Bernhard Schölkopf, and Alexander Smola. A kernel two-sample test. JMLR, pages 723–773, 2012.
  • [20] Ishaan Gulrajani and David Lopez-Paz. In search of lost domain generalization. In ICLR, 2021.
  • [21] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In CVPR, pages 770–778, 2016.
  • [22] Piotr Indyk and Rajeev Motwani. Approximate nearest neighbors: towards removing the curse of dimensionality. In Proc. Annu. ACM Symp. Theory Comput., pages 604–613, 1998.
  • [23] Diederik P Kingma and Jimmy Ba. Adam: A method for stochastic optimization. ICLR, 2015.
  • [24] Masanori Koyama and Shoichiro Yamaguchi. Out-of-distribution generalization with maximal invariant predictor. 2020.
  • [25] David Krueger, Ethan Caballero, Joern-Henrik Jacobsen, Amy Zhang, Jonathan Binas, Dinghuai Zhang, Remi Le Priol, and Aaron Courville. Out-of-distribution generalization via risk extrapolation (rex). In ICML, pages 5815–5826, 2021.
  • [26] Solomon Kullback and Richard A Leibler. On information and sufficiency. The annals of mathematical statistics, 22(1):79–86, 1951.
  • [27] Yann LeCun, Yoshua Bengio, and Geoffrey Hinton. Deep learning. Nature, 521(7553):436–444, 2015.
  • [28] Chun-Liang Li, Wei-Cheng Chang, Yu Cheng, Yiming Yang, and Barnabás Póczos. Mmd-gan: Towards deeper understanding of moment matching network. NIPS, 30, 2017.
  • [29] Da Li, Yongxin Yang, Yi-Zhe Song, and Timothy Hospedales. Learning to generalize: Meta-learning for domain generalization. In AAAI, volume 32, 2018.
  • [30] Da Li, Yongxin Yang, Yi-Zhe Song, and Timothy M Hospedales. Deeper, broader and artier domain generalization. In ICCV, pages 5542–5550, 2017.
  • [31] Haoliang Li, Sinno Jialin Pan, Shiqi Wang, and Alex C Kot. Domain generalization with adversarial feature learning. In CVPR, pages 5400–5409, 2018.
  • [32] Ya Li, Mingming Gong, Xinmei Tian, Tongliang Liu, and Dacheng Tao. Domain generalization via conditional invariant representations. In AAAI, volume 32, 2018.
  • [33] Yujia Li, Kevin Swersky, and Rich Zemel. Generative moment matching networks. In ICML, pages 1718–1727, 2015.
  • [34] Chaochao Lu, Yuhuai Wu, José Miguel Hernández-Lobato, and Bernhard Schölkopf. Invariant causal representation learning for out-of-distribution generalization. In ICLR, 2021.
  • [35] Divyat Mahajan, Shruti Tople, and Amit Sharma. Domain generalization using causal matching. In ICML, pages 7313–7324, 2021.
  • [36] Facundo Mémoli. Gromov–wasserstein distances and the metric approach to object matching. Foundations of computational mathematics, 11:417–487, 2011.
  • [37] Saeid Motiian, Marco Piccirilli, Donald A Adjeroh, and Gianfranco Doretto. Unified deep supervised domain adaptation and generalization. In ICCV, pages 5715–5725, 2017.
  • [38] Krikamol Muandet, David Balduzzi, and Bernhard Schölkopf. Domain generalization via invariant feature representation. In ICML, pages 10–18, 2013.
  • [39] Marius Muja and David G Lowe. Scalable nearest neighbor algorithms for high dimensional data. TPAMI, 36(11):2227–2240, 2014.
  • [40] Alfred Müller. Integral probability metrics and their generating classes of functions. Advances in applied probability, 29(2):429–443, 1997.
  • [41] Toan Nguyen, Kien Do, Duc Thanh Nguyen, Bao Duong, and Thin Nguyen. Causal inference via style transfer for out-of-distribution generalisation. In KDD, pages 1746–1757, 2023.
  • [42] Emanuel Parzen. On estimation of a probability density function and mode. The annals of mathematical statistics, 33(3):1065–1076, 1962.
  • [43] Xingchao Peng, Qinxun Bai, Xide Xia, Zijun Huang, Kate Saenko, and Bo Wang. Moment matching for multi-source domain adaptation. In ICCV, pages 1406–1415, 2019.
  • [44] Xingchao Peng, Zijun Huang, Ximeng Sun, and Kate Saenko. Domain agnostic learning with disentangled representations. In ICML, pages 5102–5112, 2019.
  • [45] Jonas Peters, Peter Bühlmann, and Nicolai Meinshausen. Causal inference by using invariant prediction: identification and confidence intervals. Journal of the Royal Statistical Society Series B: Statistical Methodology, 78(5):947–1012, 2016.
  • [46] Alexandre Rame, Corentin Dancette, and Matthieu Cord. Fishr: Invariant gradient variances for out-of-distribution generalization. In ICML, pages 18347–18377, 2022.
  • [47] Olga Russakovsky, Jia Deng, Hao Su, Jonathan Krause, Sanjeev Satheesh, Sean Ma, Zhiheng Huang, Andrej Karpathy, Aditya Khosla, Michael Bernstein, Alexander C. Berg, and Li Fei-Fei. ImageNet Large Scale Visual Recognition Challenge. IJCV, 115(3):211–252, 2015.
  • [48] Shiori Sagawa, Pang Wei Koh, Tatsunori B Hashimoto, and Percy Liang. Distributionally robust neural networks for group shifts: On the importance of regularization for worst-case generalization. In ICLR, 2020.
  • [49] Vaishaal Shankar, Achal Dave, Rebecca Roelofs, Deva Ramanan, Benjamin Recht, and Ludwig Schmidt. Do image classifiers generalize across time? In ICCV, pages 9661–9669, 2021.
  • [50] Yuge Shi, Jeffrey Seely, Philip Torr, Siddharth N, Awni Hannun, Nicolas Usunier, and Gabriel Synnaeve. Gradient matching for domain generalization. In ICLR, 2022.
  • [51] Alex Smola, Arthur Gretton, Le Song, and Bernhard Schölkopf. A hilbert space embedding for distributions. In International conference on algorithmic learning theory, pages 13–31, 2007.
  • [52] Bharath K Sriperumbudur, Arthur Gretton, Kenji Fukumizu, Bernhard Schölkopf, and Gert RG Lanckriet. Hilbert space embeddings and metrics on probability measures. JMLR, 11:1517–1561, 2010.
  • [53] Masashi Sugiyama, Taiji Suzuki, and Takafumi Kanamori. Density ratio estimation in machine learning. Cambridge University Press, 2012.
  • [54] Baochen Sun and Kate Saenko. Deep coral: Correlation alignment for deep domain adaptation. In ECCV Workshops, pages 443–450. Springer, 2016.
  • [55] Xinwei Sun, Botong Wu, Xiangyu Zheng, Chang Liu, Wei Chen, Tao Qin, and Tie-Yan Liu. Recovering latent causal factor for generalization to distributional shifts. NeurIPS, 34:16846–16859, 2021.
  • [56] Yuwei Sun, Ng Chong, and Hideya Ochiai. Feature distribution matching for federated domain generalization. In ACML, pages 942–957, 2023.
  • [57] Vladimir N Vapnik. An overview of statistical learning theory. IEEE transactions on neural networks, 10(5):988–999, 1999.
  • [58] Hemanth Venkateswara, Jose Eusebio, Shayok Chakraborty, and Sethuraman Panchanathan. Deep hashing network for unsupervised domain adaptation. In CVPR, pages 5018–5027, 2017.
  • [59] Jindong Wang, Wenjie Feng, Yiqiang Chen, Han Yu, Meiyu Huang, and Philip S Yu. Visual domain adaptation with manifold embedded distribution alignment. In ACMMM, pages 402–410, 2018.
  • [60] Pengfei Wang, Zhaoxiang Zhang, Zhen Lei, and Lei Zhang. Sharpness-aware gradient matching for domain generalization. In CVPR, pages 3769–3778, 2023.
  • [61] Hongyi Zhang, Moustapha Cisse, Yann N Dauphin, and David Lopez-Paz. mixup: Beyond empirical risk minimization. ICLR, 2018.
  • [62] Yabin Zhang, Minghan Li, Ruihuang Li, Kui Jia, and Lei Zhang. Exact feature distribution matching for arbitrary style transfer and domain generalization. In CVPR, pages 8035–8045, 2022.
  • [63] Kaiyang Zhou, Yongxin Yang, Yu Qiao, and Tao Xiang. Domain generalization with mixstyle. In ICLR, 2021.

7 Supplementary Material

In this supplementary material, we first provide a detailed proof for our theorem on distributional variance, as outlined in Section 8. Next, in Section 9, we detail more about our experimental settings, covering both the ColoredMNIST synthetic dataset [2] and the extensive benchmarks from the DomainBed suite [20] in the main text. Additional ablation studies and discussions on our proposed method are given in Section 10. Finally, Section 11 provides domain-specific out-of-domain accuracies for each dataset within the DomainBed suite.

8 Theoretical Results

We provide the proof for the theorem on distributional variance discussed in the main paper. We revisit the concept of kernel mean embedding [51] to express the risk distribution $\mathcal{T}_e$ of domain $e$. Particularly, we represent $\mathcal{T}_e$ through its embedding, $\mu_{\mathcal{T}_e}$, in a reproducing kernel Hilbert space (RKHS) denoted $\mathcal{H}$. This is achieved by using a feature map $\phi: \mathbb{R}\rightarrow\mathcal{H}$ as follows:

$$\mu_{\mathcal{T}_e} := \mathbb{E}_{R_e\sim\mathcal{T}_e}\left[\phi(R_e)\right] \qquad (17)$$
$$= \mathbb{E}_{R_e\sim\mathcal{T}_e}\left[k(R_e, \cdot)\right] \qquad (18)$$

where a kernel function $k(\cdot,\cdot): \mathbb{R}\times\mathbb{R}\rightarrow\mathbb{R}$ is introduced to bypass the explicit specification of $\phi$.

Theorem.

[38] Denote by $\mathcal{T} = \frac{1}{m}\sum_{e=1}^{m}\mathcal{T}_e$ the probability distribution over the risks of all samples in the entire training set, or equivalently, the set of all $m$ domains. Given that the distributional variance $\mathbb{V}_{\mathcal{H}}(\{\mathcal{T}_1,\ldots,\mathcal{T}_m\})$ is calculated with a characteristic kernel $k$, $\mathbb{V}_{\mathcal{H}}(\{\mathcal{T}_1,\ldots,\mathcal{T}_m\}) = 0$ if and only if $\mathcal{T}_1 = \ldots = \mathcal{T}_m\ (= \mathcal{T})$.

Proof.

In our methodology, we employ the RBF kernel, which is characteristic. As a result, the term $\left\|\mu_{\mathcal{T}_e} - \mu_{\mathcal{T}}\right\|_{\mathcal{H}}^{2}$ acts as a metric within the Hilbert space $\mathcal{H}$ [38]. Importantly, this metric reaches zero if and only if $\mathcal{T}_e = \mathcal{T}$ [52]. Consider the distributional variance $\mathbb{V}_{\mathcal{H}}(\{\mathcal{T}_1,\ldots,\mathcal{T}_m\})$, defined as:

$$\mathbb{V}_{\mathcal{H}}\left(\{\mathcal{T}_1,\ldots,\mathcal{T}_m\}\right) = \frac{1}{m}\sum_{e=1}^{m}\left\|\mu_{\mathcal{T}_e} - \mu_{\mathcal{T}}\right\|_{\mathcal{H}}^{2} \qquad (19)$$

This variance becomes zero if and only if $\left\|\mu_{\mathcal{T}_e} - \mu_{\mathcal{T}}\right\|_{\mathcal{H}}^{2} = 0$ for each $e$. This implies that $\mathcal{T}_e = \mathcal{T}$ for all $e$, leading to $\mathcal{T}_1 = \mathcal{T}_2 = \ldots = \mathcal{T}_m$.

Conversely, assume that $\mathcal{T}_1 = \mathcal{T}_2 = \ldots = \mathcal{T}_m$. Given this condition, for any $e$, it follows that:

$$\mu_{\mathcal{T}} = \frac{1}{m}\sum_{e=1}^{m}\mu_{\mathcal{T}_e} = \mu_{\mathcal{T}_e} \qquad (20)$$

which implies

$$\left\|\mu_{\mathcal{T}_e} - \mu_{\mathcal{T}}\right\|_{\mathcal{H}}^{2} = 0. \qquad (21)$$

Consequently, by the definition of the distributional variance, we have $\mathbb{V}_{\mathcal{H}}(\{\mathcal{T}_1,\ldots,\mathcal{T}_m\}) = \frac{1}{m}\sum_{e=1}^{m}\left\|\mu_{\mathcal{T}_e} - \mu_{\mathcal{T}}\right\|_{\mathcal{H}}^{2} = 0$. This completes the proof. ∎

9 More implementation details

For our experiments, we leveraged the PyTorch DomainBed toolbox [20, 13] and utilised an Ubuntu 20.04 server outfitted with a 36-core CPU, 767GB RAM, and NVIDIA V100 32GB GPUs. The software stack included Python 3.11.2, PyTorch 1.7.1, Torchvision 0.8.2, and CUDA 12.0. Additional implementation details, beyond the hyper-parameters discussed in the main text, are elaborated below.

9.1 ColoredMNIST

In alignment with [13], we performed experiments on the ColoredMNIST dataset, the results of which are detailed in Table 1 of the main paper. We partitioned the original MNIST training set into distinct training and validation sets of 25,000 and 5,000 samples for each of the two training domains, respectively. The original MNIST test set was adapted to serve as our test set. In particular, we synthesised this test set to introduce a distribution shift: red digits have only a 10% probability of being classified as “zero”, compared to 80% and 90% in the training sets of the two domains. Besides the hyper-parameters highlighted in the main paper, we also leveraged a cosine annealing scheduler to further optimise the training process, as in other baselines.
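To make the environment construction concrete, below is a minimal PyTorch sketch following the widely used IRM/DomainBed recipe for ColoredMNIST. The function name, the 25% label-noise level, and the exact channel handling are our illustrative assumptions rather than a transcript of our pipeline; only the 80%/90%/10% colour-label agreement levels come from the description above.

```python
import torch

def make_environment(images, labels, color_flip_prob):
    """Colour an MNIST split so that colour agrees with the (noisy)
    binary label except with probability `color_flip_prob` (a sketch)."""
    # Binarise digit labels (<5 vs >=5) and inject 25% label noise,
    # as in the common IRM-style recipe (an assumption on our part).
    labels = (labels < 5).float()
    noise = (torch.rand(len(labels)) < 0.25).float()
    labels = (labels - noise).abs()
    # Colour agrees with the label except with probability color_flip_prob.
    flip = (torch.rand(len(labels)) < color_flip_prob).float()
    colors = (labels - flip).abs()
    # Build a two-channel (red/green) image and zero out the unused channel.
    images = torch.stack([images, images], dim=1).float()
    images[torch.arange(len(images)), (1 - colors).long(), :, :] = 0.0
    return images, labels.long()

# Two training environments with 80% / 90% label-colour agreement and a
# test environment with only 10%, matching the shift described above:
# train_envs = [make_environment(x1, y1, 0.2), make_environment(x2, y2, 0.1)]
# test_env   = make_environment(x_test, y_test, 0.9)
```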

For our RDM method, we constrained the alignment to focus only on the first two empirical moments (mean and variance) of $\mathcal{T}_{w}$ and $\mathcal{T}$. We experimented with five penalty weight values for $\lambda$ in the range $\left\{500, 1000, 2500, 5000, 10000\right\}$, running each experiment ten times while varying $\lambda$. The reported results are the average accuracies and their standard deviations over these 10 runs, all measured on a test-domain test set. We adhered to test-domain validation for model selection across all methods, as recommended by [20]. We reference results for other methods from [13].
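As a sketch of this simplified alignment, the penalty below matches only the first two empirical moments of the worst-case domain's risks against the pooled risks. The name `moment_match_penalty` and its inputs (per-sample losses, e.g. from `F.cross_entropy(..., reduction="none")`) are hypothetical, introduced purely for illustration.

```python
import torch

def moment_match_penalty(risks_worst, risks_all):
    """Penalise gaps in the first two empirical moments (mean, variance)
    between the worst-case domain's risks and the pooled risks."""
    mean_gap = (risks_worst.mean() - risks_all.mean()).pow(2)
    var_gap = (risks_worst.var() - risks_all.var()).pow(2)
    return mean_gap + var_gap
```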

9.2 DomainBed

9.2.1 Description of benchmarks

Parameter | Dataset | Default value | Random distribution
steps | All | 5,000 | 5,000
learning rate | All | 5e-5 | $10^{\text{Uniform}(-4.5,\,-4)}$
dropout | All | 0 | RandomChoice([0, 0.003, 0.03])
weight decay | All | 0 | $10^{\text{Uniform}(-8,\,-5)}$
batch size | PACS / VLCS / OfficeHome | 88 | Uniform(70, 100)
batch size | TerraIncognita / DomainNet | 40 | Uniform(30, 60)
matching coefficient $\lambda$ | All except DomainNet | 5.0 | Uniform(0.1, 10.0)
matching coefficient $\lambda$ | DomainNet | 0.5 | Uniform(0.1, 1.0)
pre-trained iterations | All except DomainNet | 1500 | Uniform(800, 2700)
pre-trained iterations | DomainNet | 2400 | Uniform(1500, 3000)
learning rate after pre-training | All | 1.5e-5 | Uniform(8e-6, 2e-5)
variance regularisation coefficient | PACS / VLCS | 0.004 | Uniform(0.001, 0.007)
variance regularisation coefficient | OfficeHome / TerraIncognita / DomainNet | 0 | 0
Table 4: Hyper-parameters, along with their default values and the distributions used in the random search, across the five benchmark datasets.

For our evaluations, we leveraged five large-scale benchmark datasets from the DomainBed suite [20], comprising:

  • VLCS [15]: The dataset encompasses four photographic domains: Caltech101, LabelMe, SUN09, and VOC2007. It contains 10,729 examples, each with dimensions (3, 224, 224), and spans five distinct classes.

  • PACS [30]: The dataset includes 9,991 images from four different domains: Photo (P), Art-painting (A), Cartoon (C), and Sketch (S). Each domain has its own distinctive style, making this dataset particularly challenging for out-of-distribution (OOD) generalisation. Each domain has seven classes.

  • OfficeHome [58]: The dataset features 15,500 images of objects commonly found in office and home settings, categorised into 65 classes. These images are sourced from four distinct domains: Art (A), Clipart (C), Product (P), and Real-world (R).

  • TerraIncognita [5]: The dataset includes 24,788 camera-trap photographs of wild animals captured at locations {L100, L38, L43, L46}. Each image has dimensions (3, 224, 224) and falls into one of 10 distinct classes.

  • DomainNet [43]: The largest dataset in DomainBed, DomainNet contains 586,575 examples with dimensions (3, 224, 224), spread across six domains {clipart, infograph, painting, quickdraw, real, sketch} and encompassing 345 classes.

9.2.2 Our implementation details

(a) PACS
Algorithm | Training (s) | Mem (GiB) | Acc (%)
Fish | 7,566 | 7.97 | 85.5
CORAL | 4,485 | 21.81 | 86.2
RDM with $\mathcal{L}_{\text{RDM}}$ | 4,783 | 21.87 | 86.6
RDM with $\hat{\mathcal{L}}_{\text{RDM}}$ | 4,214 | 21.71 | 87.2

(b) VLCS
Algorithm | Training (s) | Mem (GiB) | Acc (%)
Fish | 13,493 | 7.97 | 77.8
CORAL | 6,329 | 21.81 | 78.8
RDM with $\mathcal{L}_{\text{RDM}}$ | 9,441 | 21.87 | 77.8
RDM with $\hat{\mathcal{L}}_{\text{RDM}}$ | 6,151 | 21.71 | 78.4

(c) OfficeHome
Algorithm | Training (s) | Mem (GiB) | Acc (%)
Fish | 9,035 | 7.97 | 68.6
CORAL | 4,762 | 21.81 | 68.7
RDM with $\mathcal{L}_{\text{RDM}}$ | 5,467 | 21.87 | 67.0
RDM with $\hat{\mathcal{L}}_{\text{RDM}}$ | 4,588 | 21.71 | 67.3

(d) TerraIncognita
Algorithm | Training (s) | Mem (GiB) | Acc (%)
Fish | 6,019 | 4.08 | 45.1
CORAL | 2,973 | 10.21 | 47.6
RDM with $\mathcal{L}_{\text{RDM}}$ | 4,040 | 10.17 | 47.1
RDM with $\hat{\mathcal{L}}_{\text{RDM}}$ | 2,697 | 10.11 | 47.5

Table 5: Comparison between Fish, CORAL, and two variants of our method in terms of training time (seconds), memory usage per iteration (GiB) and accuracy (%) on (a) PACS, (b) VLCS, (c) OfficeHome and (d) TerraIncognita.

To ensure rigorous evaluation and a fair comparison with existing baselines [13, 46], we conducted experiments on five datasets from the DomainBed suite, the results of which are elaborated in Table 2 of the main text. In alignment with standard practices, we optimised hyper-parameters for each domain through a randomised search of 20 trials on the validation set, using the joint distribution specified in Table 4. The data from each domain were partitioned into an 80% split for training and testing, and a 20% split for hyper-parameter validation. A comprehensive discussion of the hyper-parameters used in our experiments is provided below. For each domain, we performed our experiments ten times, employing varied seed values and hyper-parameters within the specified ranges, and report the averaged results with their standard deviations. We reference results for other methods from [13, 46], and we kindly refer readers to our source code for further details.

In our methodology, we employed the MMD distance for aligning the risk distributions $\mathcal{T}_{e}$ and $\mathcal{T}$, as described in Section 4. Using the RBF kernel, we compute the average MMD distance across an expansive bandwidth spectrum $\left\{0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000\right\}$, bypassing the need to tune this parameter.
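A minimal sketch of this multi-bandwidth estimator on scalar risks is given below. The biased V-statistic form and the $\exp(-d^{2}/h)$ parameterisation of the RBF kernel are our assumptions, and `rbf_mmd2` is a hypothetical helper name; the paper's exact estimator may differ in these details.

```python
import torch

BANDWIDTHS = (0.0001, 0.001, 0.01, 0.1, 1.0, 10.0, 100.0, 1000.0)

def rbf_mmd2(x, y, bandwidths=BANDWIDTHS):
    """Squared MMD between two 1-D samples of per-example risks,
    averaged over a fixed set of RBF bandwidths (a sketch)."""
    x, y = x.view(-1, 1), y.view(-1, 1)

    def k(a, b):
        # Pairwise squared distances between scalar risks.
        d2 = (a - b.t()).pow(2)
        # Average the RBF kernel over all bandwidths, so no single
        # bandwidth needs to be tuned.
        return sum(torch.exp(-d2 / h) for h in bandwidths) / len(bandwidths)

    # Biased V-statistic estimate of the squared MMD.
    return k(x, x).mean() - 2.0 * k(x, y).mean() + k(y, y).mean()
```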

Inspired by recent insights [13], we incorporated an initial pre-training-with-ERM phase to further improve OOD performance. DomainNet, given its scale, requires longer ERM pre-training; specific parameters for all datasets are provided in Table 4. Our initial learning rate lies within $\left[1\text{e-}4.5, 1\text{e-}4\right]$ and adapts to $\left[8\text{e-}6, 2\text{e-}5\right]$ after ERM pre-training. Incorporating additional variance regularisation on $\mathcal{T}_{e}$ and $\mathcal{T}$ proves beneficial for the PACS and VLCS datasets: it constrains the induced risks to fall within narrower, more optimal value ranges, facilitating more effective risk distribution alignment. Optimal regularisation coefficients for this strategy are detailed in Table 4.

We maintain minimal dropout and weight decay, reserving our focus for risk distribution alignment. Optimal batch sizes differ: $\left[70, 100\right]$ for VLCS and OfficeHome, and $\left[30, 60\right]$ for TerraIncognita and DomainNet. Despite computational constraints limiting our ability to test larger batch sizes, the selected ranges yield robust performance across datasets.

Regarding the matching coefficient $\lambda$ in our objective, most datasets work well within $\left[0.1, 10.0\right]$, but DomainNet prefers a narrower $\left[0.1, 1.0\right]$ range. This fine-tuning is key, especially for large-scale datasets, to balance risk reduction and cross-domain alignment in the early training stages.
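Putting the pieces together, the sketch below shows how $\lambda$ (and the optional variance coefficient) could enter a single loss computation for the approximate objective $\hat{\mathcal{L}}_{\text{RDM}}$. The structure, the choice of the highest-mean-risk domain as the worst case, and all names are our illustrative assumptions; `rbf_mmd2` is the helper sketched earlier.

```python
import torch
import torch.nn.functional as F

def rdm_loss(model, minibatches, lam, var_coef):
    """One hypothetical loss computation for the approximate objective:
    ERM risk + lam * MMD(worst-case risks, pooled risks) + optional
    variance regularisation (used on PACS and VLCS)."""
    # minibatches: list of (x, y) pairs, one per training domain.
    domain_risks = [
        F.cross_entropy(model(x), y, reduction="none")
        for x, y in minibatches
    ]
    all_risks = torch.cat(domain_risks)

    erm_loss = all_risks.mean()
    # Worst-case domain taken as the one with the highest mean risk
    # (our reading of "worst-case"; the paper defines the exact choice).
    worst_risks = max(domain_risks, key=lambda r: r.mean().item())
    matching = rbf_mmd2(worst_risks, all_risks)
    var_reg = worst_risks.var() + all_risks.var()
    return erm_loss + lam * matching + var_coef * var_reg
```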

10 More ablation studies and analyses

10.1 Efficacy of DG matching methods

Table 5 compares the efficiency and effectiveness of Fish, CORAL, RDM with $\mathcal{L}_{\text{RDM}}$, and RDM with $\hat{\mathcal{L}}_{\text{RDM}}$ on the PACS, VLCS, OfficeHome, and TerraIncognita benchmarks. Notably, the approximate variant, RDM with $\hat{\mathcal{L}}_{\text{RDM}}$, stands out for its strong performance. This version aligns the risk distribution of the worst-case domain with the aggregated one, and exhibits both faster training and improved accuracy over the counterpart that optimises distributional variance, RDM with $\mathcal{L}_{\text{RDM}}$. For instance, on the VLCS dataset, this variant trains almost an hour faster (6,151s versus 9,441s) while achieving a 0.6% accuracy boost.

Compared to the gradient-matching Fish method, our approach shows similar advantages but requires additional memory for computing and storing the MMD distances. This memory constraint is not unique to our method; CORAL encounters the same limitation. However, RDM outperforms CORAL in both training time and memory usage, most evidently on large-scale datasets like DomainNet. This efficiency gain is noteworthy, given that CORAL's increased computational requirements arise from its handling of high-dimensional representation vectors.

In terms of accuracy, as reported in the main text, RDM outperforms CORAL substantially on both PACS and DomainNet, while remaining competitive on TerraIncognita and VLCS. On OfficeHome, although RDM lags behind CORAL, we provide an in-depth explanation for this behaviour both in the main text and in the following section.

10.2 Decreased Performance on OfficeHome

Figure 5: The notable gap between training and validation matching loss on the OfficeHome dataset, with the OOD Art domain excluded. Analysis begins after RDM completes 1,500 pre-training iterations via ERM; metrics are recorded every 15 iterations.

In our evaluation, RDM generally surpasses competing matching methods in OOD settings but faces challenges on specific datasets such as OfficeHome. The dataset's limitations are noteworthy: with an average of only 240 samples per class, OfficeHome has significantly fewer instances per class than the other datasets, which typically have at least 1,400. This limited sample size may prevent the model from learning sufficiently rich class-semantic features or diverse risk distributions, leading to overfitting on the training set. To shed light on this issue, we present a visual analysis in Figure 5. Starting from the 100th iteration, when we begin matching risk distributions, the training matching loss is already minimal, diverging clearly from the validation matching loss. While the training loss continues to converge towards minimal values, the validation loss remains inconsistent throughout the training phase. This inconsistency indicates that the limited diversity of OfficeHome's risk distributions may induce overfitting on training samples, reducing the model's generalisation capability. Despite these constraints, our method still outperforms other well-known baselines, such as MLDG, VREx, and ERM, on OfficeHome.

10.3 Impact of batch size and matching coefficient

Figure 6: The influence of batch size in our method on VLCS, OfficeHome, TerraIncognita and DomainNet (panels (a)–(d)).

Figure 7: The influence of matching coefficient $\lambda$ in our method on VLCS, OfficeHome, TerraIncognita and DomainNet (panels (a)–(d)).

In our analysis, we closely examine how batch size and the matching coefficient $\lambda$ affect RDM's performance on four benchmark datasets: VLCS, OfficeHome, TerraIncognita, and DomainNet. Consistent with our findings on PACS in the main text, Figure 6 shows that larger batch sizes enhance the model's generalisation by facilitating accurate risk distribution matching. Similarly, Figure 7 highlights the importance of $\lambda$ for OOD performance: as $\lambda$ increases, OOD performance generally improves.

We find optimal batch size ranges for each dataset: VLCS and OfficeHome perform best with sizes within $\left[70, 100\right]$, while the larger datasets of TerraIncognita and DomainNet benefit from a more limited range of $\left[30, 60\right]$. Even with computational limitations, these batch sizes lead to strong performance. For most datasets, a $\lambda$ value within $\left[0.1, 10.0\right]$ is effective. For DomainNet, a smaller $\lambda$ range of $\left[0.1, 1.0\right]$ works well, balancing the reduction of training risks and the alignment of risk distributions across domains. This is particularly important for large-scale datasets, where reducing training risks is crucial for learning predictive features, especially during the initial phases of training.

10.4 Risk distributions

Figure 8: Histograms with their KDE curves depicting the risk distributions of ERM and our RDM method across four domains on PACS (panels (a)–(b)). Vertical ticks denote the mean values of all distributions.

Figure 9: Histograms with their KDE curves depicting the risk distributions of ERM and our RDM method across six domains on DomainNet (panels (a)–(b)). Vertical ticks denote the mean values of all distributions.

We present visualisations of risk distribution histograms accompanied by their KDE curves for two datasets, PACS and DomainNet, in Figures 8 and 9, respectively. These visualisations compare the risk distributions of ERM and our proposed RDM method on the validation sets. Both figures confirm our hypothesis that variations in training domains lead to distinct risk distributions, making them valuable indicators of domain differences.

On PACS, we observe that ERM tends to capture domain-specific features, resulting in low risks within the training domains. However, under ERM, the average risk for the test domain deviates substantially from that of the training domains, suggesting sub-optimal OOD generalisation. In contrast, our RDM approach prioritises stable, domain-invariant features, yielding more consistent risk distributions and enhanced generalisation. This trend holds across both datasets, as our approach consistently aligns risk distributions across domains better than ERM. This alignment effectively narrows the gap between test and training domains, especially reducing risks for test domains.

These findings underscore the efficacy of our RDM method in mitigating domain variations by aligning risk distributions, ultimately leading to enhanced generalisation.

11 More experimental results

We provide domain-specific out-of-domain accuracies for each dataset within the DomainBed suite in Tables 6, 7, 8, 9, and 10. In each table, the accuracy listed in each column represents the out-of-domain performance when that specific domain is excluded from the training set and used solely for testing within the respective dataset. We note that per-domain results for Fish [50] are not available.

Algorithm | clip | info | paint | quick | real | sketch | Avg
ERM | 58.1 ± 0.3 | 18.8 ± 0.3 | 46.7 ± 0.3 | 12.2 ± 0.4 | 59.6 ± 0.1 | 49.8 ± 0.4 | 40.9
Mixup | 55.7 ± 0.3 | 18.5 ± 0.5 | 44.3 ± 0.5 | 12.5 ± 0.4 | 55.8 ± 0.3 | 48.2 ± 0.5 | 39.2
MLDG | 59.1 ± 0.2 | 19.1 ± 0.3 | 45.8 ± 0.7 | 13.4 ± 0.3 | 59.6 ± 0.2 | 50.2 ± 0.4 | 41.2
GroupDRO | 47.2 ± 0.5 | 17.5 ± 0.4 | 33.8 ± 0.5 | 9.3 ± 0.3 | 51.6 ± 0.4 | 40.1 ± 0.6 | 33.3
IRM | 48.5 ± 2.8 | 15.0 ± 1.5 | 38.3 ± 4.3 | 10.9 ± 0.5 | 48.2 ± 5.2 | 42.3 ± 3.1 | 33.9
VREx | 47.3 ± 3.5 | 16.0 ± 1.5 | 35.8 ± 4.6 | 10.9 ± 0.3 | 49.6 ± 4.9 | 42.0 ± 3.0 | 33.6
EQRM | 56.1 ± 1.3 | 19.6 ± 0.1 | 46.3 ± 1.5 | 12.9 ± 0.3 | 61.1 ± 0.0 | 50.3 ± 0.1 | 41.0
Fish | - | - | - | - | - | - | 42.7
Fishr | 58.2 ± 0.5 | 20.2 ± 0.2 | 47.7 ± 0.3 | 12.7 ± 0.2 | 60.3 ± 0.2 | 50.8 ± 0.1 | 41.7
CORAL | 59.2 ± 0.1 | 19.7 ± 0.2 | 46.6 ± 0.3 | 13.4 ± 0.4 | 59.8 ± 0.2 | 50.1 ± 0.6 | 41.5
MMD | 32.1 ± 13.3 | 11.0 ± 4.6 | 26.8 ± 11.3 | 8.7 ± 2.1 | 32.7 ± 13.8 | 28.9 ± 11.9 | 23.4
RDM (ours) | 62.1 ± 0.2 | 20.7 ± 0.1 | 49.2 ± 0.4 | 14.1 ± 0.4 | 63.0 ± 1.3 | 51.4 ± 0.1 | 43.4
Table 6: Domain-specific out-of-domain accuracy on DomainNet, with the best results marked in bold. Results of other methods are referenced from [13, 50].
Algorithm | A | C | P | S | Avg
ERM | 84.7 ± 0.4 | 80.8 ± 0.6 | 97.2 ± 0.3 | 79.3 ± 1.0 | 85.5
Mixup | 86.1 ± 0.5 | 78.9 ± 0.8 | 97.6 ± 0.1 | 75.8 ± 1.8 | 84.6
MLDG | 85.5 ± 1.4 | 80.1 ± 1.7 | 97.4 ± 0.3 | 76.6 ± 1.1 | 84.9
GroupDRO | 83.5 ± 0.9 | 79.1 ± 0.6 | 96.7 ± 0.3 | 78.3 ± 2.0 | 84.4
IRM | 84.8 ± 1.3 | 76.4 ± 1.1 | 96.7 ± 0.6 | 76.1 ± 1.0 | 83.5
VREx | 86.0 ± 1.6 | 79.1 ± 0.6 | 96.9 ± 0.5 | 77.7 ± 1.7 | 84.9
EQRM | 86.5 ± 0.4 | 82.1 ± 0.7 | 96.6 ± 0.2 | 80.8 ± 0.2 | 86.5
Fish | - | - | - | - | 85.5
Fishr | 88.4 ± 0.2 | 78.7 ± 0.7 | 97.0 ± 0.1 | 77.8 ± 2.0 | 85.5
CORAL | 88.3 ± 0.2 | 80.0 ± 0.5 | 97.5 ± 0.3 | 78.8 ± 1.3 | 86.2
MMD | 86.1 ± 1.4 | 79.4 ± 0.9 | 96.6 ± 0.2 | 76.5 ± 0.5 | 84.6
RDM (ours) | 88.4 ± 0.2 | 81.3 ± 1.6 | 97.1 ± 0.1 | 81.8 ± 1.1 | 87.2
Table 7: Domain-specific out-of-domain accuracy on PACS, with the best results marked in bold. Results of other methods are referenced from [13, 50].
Algorithm | C | L | S | V | Avg
ERM | 97.7 ± 0.4 | 64.3 ± 0.9 | 73.4 ± 0.5 | 74.6 ± 1.3 | 77.5
Mixup | 98.3 ± 0.6 | 64.8 ± 1.0 | 72.1 ± 0.5 | 74.3 ± 0.8 | 77.4
MLDG | 97.4 ± 0.2 | 65.2 ± 0.7 | 71.0 ± 1.4 | 75.3 ± 1.0 | 77.2
GroupDRO | 97.3 ± 0.3 | 63.4 ± 0.9 | 69.5 ± 0.8 | 76.7 ± 0.7 | 76.7
IRM | 98.6 ± 0.1 | 64.9 ± 0.9 | 73.4 ± 0.6 | 77.3 ± 0.9 | 78.5
VREx | 98.4 ± 0.3 | 64.4 ± 1.4 | 74.1 ± 0.4 | 76.2 ± 1.3 | 78.3
EQRM | 98.3 ± 0.0 | 63.7 ± 0.8 | 72.6 ± 1.0 | 76.7 ± 1.1 | 77.8
Fish | - | - | - | - | 77.8
Fishr | 98.9 ± 0.3 | 64.0 ± 0.5 | 71.5 ± 0.2 | 76.8 ± 0.7 | 77.8
CORAL | 98.3 ± 0.1 | 66.1 ± 1.2 | 73.4 ± 0.3 | 77.5 ± 1.2 | 78.8
MMD | 97.7 ± 0.1 | 64.0 ± 1.1 | 72.8 ± 0.2 | 75.3 ± 3.3 | 77.5
RDM (ours) | 98.1 ± 0.2 | 64.9 ± 0.7 | 72.6 ± 0.5 | 77.9 ± 1.2 | 78.4
Table 8: Domain-specific out-of-domain accuracy on VLCS, with the best results marked in bold. Results of other methods are referenced from [13, 50].
Algorithm | A | C | P | R | Avg
ERM | 61.3 ± 0.7 | 52.4 ± 0.3 | 75.8 ± 0.1 | 76.6 ± 0.3 | 66.5
Mixup | 62.4 ± 0.8 | 54.8 ± 0.6 | 76.9 ± 0.3 | 78.3 ± 0.2 | 68.1
MLDG | 61.5 ± 0.9 | 53.2 ± 0.6 | 75.0 ± 1.2 | 77.5 ± 0.4 | 66.8
GroupDRO | 60.4 ± 0.7 | 52.7 ± 1.0 | 75.0 ± 0.7 | 76.0 ± 0.7 | 66.0
IRM | 58.9 ± 2.3 | 52.2 ± 1.6 | 72.1 ± 2.9 | 74.0 ± 2.5 | 64.3
VREx | 60.7 ± 0.9 | 53.0 ± 0.9 | 75.3 ± 0.1 | 76.6 ± 0.5 | 66.4
EQRM | 60.5 ± 0.1 | 56.0 ± 0.2 | 76.1 ± 0.4 | 77.4 ± 0.3 | 67.5
Fish | - | - | - | - | 68.6
Fishr | 62.4 ± 0.5 | 54.4 ± 0.4 | 76.2 ± 0.5 | 78.3 ± 0.1 | 67.8
CORAL | 65.3 ± 0.4 | 54.4 ± 0.5 | 76.5 ± 0.1 | 78.4 ± 0.5 | 68.7
MMD | 60.4 ± 0.2 | 53.3 ± 0.3 | 74.3 ± 0.1 | 77.4 ± 0.6 | 66.3
RDM (ours) | 61.1 ± 0.4 | 55.1 ± 0.3 | 75.7 ± 0.5 | 77.3 ± 0.3 | 67.3
Table 9: Domain-specific out-of-domain accuracy on OfficeHome, with the best results marked in bold. Results of other methods are referenced from [13, 50].
Algorithm | L100 | L38 | L43 | L46 | Avg
ERM | 49.8 ± 4.4 | 42.1 ± 1.4 | 56.9 ± 1.8 | 35.7 ± 3.9 | 46.1
Mixup | 59.6 ± 2.0 | 42.2 ± 1.4 | 55.9 ± 0.8 | 33.9 ± 1.4 | 47.9
MLDG | 54.2 ± 3.0 | 44.3 ± 1.1 | 55.6 ± 0.3 | 36.9 ± 2.2 | 47.7
GroupDRO | 41.2 ± 0.7 | 38.6 ± 2.1 | 56.7 ± 0.9 | 36.4 ± 2.1 | 43.2
IRM | 54.6 ± 1.3 | 39.8 ± 1.9 | 56.2 ± 1.8 | 39.6 ± 0.8 | 47.6
VREx | 48.2 ± 4.3 | 41.7 ± 1.3 | 56.8 ± 0.8 | 38.7 ± 3.1 | 46.4
EQRM | 47.9 ± 1.9 | 45.2 ± 0.3 | 59.1 ± 0.3 | 38.8 ± 0.6 | 47.8
Fish | - | - | - | - | 45.1
Fishr | 50.2 ± 3.9 | 43.9 ± 0.8 | 55.7 ± 2.2 | 39.8 ± 1.0 | 47.4
CORAL | 51.6 ± 2.4 | 42.2 ± 1.0 | 57.0 ± 1.0 | 39.8 ± 2.9 | 47.6
MMD | 41.9 ± 3.0 | 34.8 ± 1.0 | 57.0 ± 1.9 | 35.2 ± 1.8 | 42.2
RDM (ours) | 52.9 ± 1.2 | 43.1 ± 1.0 | 58.1 ± 1.3 | 36.1 ± 2.9 | 47.5
Table 10: Domain-specific out-of-domain accuracy on TerraIncognita, with the best results marked in bold. Results of other methods are referenced from [13, 50].