
Domain Adaptation with Conditional Distribution Matching and Generalized Label Shift

Remi Tachet des Combes
Microsoft Research Montreal
Montreal, QC, Canada
[email protected]
Han Zhao
D. E. Shaw & Co.
New York, NY, USA
[email protected]
Yu-Xiang Wang
UC Santa Barbara
Santa Barbara, CA, USA
[email protected]
Geoff Gordon
Microsoft Research Montreal
Montreal, QC, Canada
[email protected]
The first two authors contributed equally to this work. Work done while HZ was at Carnegie Mellon University.
Abstract

Adversarial learning has demonstrated good performance in the unsupervised domain adaptation setting, by learning domain-invariant representations. However, recent work has shown limitations of this approach when label distributions differ between the source and target domains. In this paper, we propose a new assumption, generalized label shift (GLS), to improve robustness against mismatched label distributions. GLS states that, conditioned on the label, there exists a representation of the input that is invariant between the source and target domains. Under GLS, we provide theoretical guarantees on the transfer performance of any classifier. We also devise necessary and sufficient conditions for GLS to hold, by using an estimation of the relative class weights between domains and an appropriate reweighting of samples. Our weight estimation method could be straightforwardly and generically applied in existing domain adaptation (DA) algorithms that learn domain-invariant representations, with small computational overhead. In particular, we modify three DA algorithms, JAN, DANN and CDAN, and evaluate their performance on standard and artificial DA tasks. Our algorithms outperform the base versions, with vast improvements for large label distribution mismatches. Our code is available at https://tinyurl.com/y585xt6j.

1 Introduction

In spite of impressive successes, most deep learning models [24] rely on huge amounts of labelled data and their features have proven brittle to distribution shifts [59, 43]. Building more robust models, that learn from fewer samples and/or generalize better out-of-distribution, is the focus of many recent works [5, 2, 57]. The research direction of interest to this paper is that of domain adaptation, which aims at learning features that transfer well between domains. We focus in particular on unsupervised domain adaptation (UDA), where the algorithm has access to labelled samples from a source domain and unlabelled data from a target domain. Its objective is to train a model that generalizes well to the target domain. Building on advances in adversarial learning [25], adversarial domain adaptation (ADA) leverages the use of a discriminator to learn an intermediate representation that is invariant between the source and target domains. Simultaneously, the representation is paired with a classifier, trained to perform well on the source domain [22, 53, 64, 36]. ADA is rather successful on a variety of tasks; however, recent work has proven an upper bound on the performance of existing algorithms when source and target domains have mismatched label distributions [66]. Label shift is a property of two domains for which the marginal label distributions differ, but the conditional distributions of input given label stay the same across domains [52, 62].

In this paper, we study domain adaptation under mismatched label distributions and design methods that are robust in that setting. Our contributions are the following. First, we extend the upper bound by Zhao et al. [66] to k-class classification and to conditional domain adversarial networks, a recently introduced domain adaptation algorithm [41]. Second, we introduce generalized label shift (GLS), a broader version of the standard label shift where conditional invariance between source and target domains is placed in representation rather than input space. Third, we derive performance guarantees for algorithms that seek to enforce GLS via learnt feature transformations, in the form of upper bounds on the error gap and the joint error of the classifier on the source and target domains. Those guarantees suggest principled modifications to ADA to improve its robustness to mismatched label distributions. The modifications rely on estimating the class ratios between source and target domains and using those as importance weights in the adversarial and classification objectives. The importance weights are estimated with a method of moments, by solving a quadratic program inspired from Lipton et al. [35]. Following these theoretical insights, we devise three new algorithms based on learning importance-weighted representations, built on DANN [22], JAN [40] and CDAN [41]. We apply our variants to artificial UDA tasks with large divergences between label distributions, and demonstrate significant performance gains compared to the algorithms' base versions. Finally, we evaluate them on standard domain adaptation tasks and also show improved performance.

2 Preliminaries

Notation  We focus on the general k-class classification problem. \mathcal{X} and \mathcal{Y} denote the input and output space, respectively. \mathcal{Z} stands for the representation space induced from \mathcal{X} by a feature transformation g:\mathcal{X}\mapsto\mathcal{Z}. Accordingly, we use X, Y, Z to denote random variables which take values in \mathcal{X}, \mathcal{Y}, \mathcal{Z}. A domain corresponds to a joint distribution on the input space \mathcal{X} and output space \mathcal{Y}, and we use \mathcal{D}_{S} (resp. \mathcal{D}_{T}) to denote the source (resp. target) domain. Noticeably, this corresponds to a stochastic setting, which is stronger than the deterministic one previously studied [6, 7, 66]. A hypothesis is a function h:\mathcal{X}\to[k]. The error of a hypothesis h under distribution \mathcal{D}_{S} is defined as \varepsilon_{S}(h)\vcentcolon=\Pr_{\mathcal{D}_{S}}(h(X)\neq Y), i.e., the probability that h disagrees with Y under \mathcal{D}_{S}.

Domain Adaptation via Invariant Representations  For source (\mathcal{D}_{S}) and target (\mathcal{D}_{T}) domains, we use \mathcal{D}_{S}^{X}, \mathcal{D}_{T}^{X}, \mathcal{D}_{S}^{Y} and \mathcal{D}_{T}^{Y} to denote the marginal data and label distributions. In UDA, the algorithm has access to n labeled points \{(\mathbf{x}_{i},y_{i})\}_{i=1}^{n}\in(\mathcal{X}\times\mathcal{Y})^{n} and m unlabeled points \{\mathbf{x}_{j}\}_{j=1}^{m}\in\mathcal{X}^{m} sampled i.i.d. from the source and target domains. Inspired by Ben-David et al. [7], a common approach is to learn representations invariant to the domain shift. With g:\mathcal{X}\mapsto\mathcal{Z} a feature transformation and h:\mathcal{Z}\mapsto\mathcal{Y} a hypothesis on the feature space, a domain-invariant representation [22, 53, 63] is a function g that induces similar distributions on \mathcal{D}_{S} and \mathcal{D}_{T}. g is also required to preserve rich information about the target task so that \varepsilon_{S}(h\circ g) is small. The above process results in the following Markov chain (assumed to hold throughout the paper):

X\overset{g}{\longrightarrow}Z\overset{h}{\longrightarrow}\widehat{Y}, (1)

where \widehat{Y}=h(g(X)). We let \mathcal{D}_{S}^{Z}, \mathcal{D}_{T}^{Z}, \mathcal{D}_{S}^{\widehat{Y}} and \mathcal{D}_{T}^{\widehat{Y}} denote the pushforwards (induced distributions) of \mathcal{D}_{S}^{X} and \mathcal{D}_{T}^{X} by g and h\circ g. Invariance in feature space is defined as minimizing a distance or divergence between \mathcal{D}_{S}^{Z} and \mathcal{D}_{T}^{Z}.

Adversarial Domain Adaptation  Invariance is often attained by training a discriminator d:\mathcal{Z}\mapsto[0,1] to predict whether z comes from the source or the target domain. g is trained both to maximize the discriminator loss and to minimize the classification loss of h\circ g on the source domain (h is also trained with the latter objective). This leads to domain-adversarial neural networks [22, DANN], where g, h and d are parameterized with neural networks: g_{\theta}, h_{\phi} and d_{\psi} (see Algo. 1 and App. B.5). Building on DANN, conditional domain adversarial networks [41, CDAN] use the same adversarial paradigm. However, the discriminator now takes as input the outer product, for a given x, between the predictions of the network h(g(x)) and its representation g(x). In other words, d acts on the outer product h\otimes g(x)\vcentcolon=(h_{1}(g(x))\cdot g(x),\dots,h_{k}(g(x))\cdot g(x)) rather than on g(x), where h_{i} denotes the i-th element of vector h. We now highlight a limitation of DANNs and CDANs.
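Before turning to that limitation, here is a minimal sketch (assuming PyTorch; sizes and random tensors are purely illustrative, not the authors' exact architecture) of the two discriminator inputs described above:

import torch

batch, feat_dim, k = 32, 256, 10                       # illustrative sizes
z = torch.randn(batch, feat_dim)                       # features g_theta(x)
probs = torch.softmax(torch.randn(batch, k), dim=1)    # predictions h_phi(g_theta(x))

dann_input = z                                               # DANN: d sees g(x), shape (batch, feat_dim)
cdan_input = torch.bmm(probs.unsqueeze(2), z.unsqueeze(1))   # outer product of h(g(x)) and g(x)
cdan_input = cdan_input.flatten(start_dim=1)                 # CDAN: d sees shape (batch, k * feat_dim)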

An Information-Theoretic Lower Bound  We let D_{\text{JS}} denote the Jensen-Shannon divergence between two distributions (App. A.1), and \widetilde{Z} correspond to Z (for DANN) or to \widehat{Y}\otimes Z (for CDAN). The following theorem lower-bounds the joint error of the classifier on the source and target domains:

Theorem 2.1.

Assuming that D_{\text{JS}}(\mathcal{D}_{S}^{Y}\,\|\,\mathcal{D}_{T}^{Y})\geq D_{\text{JS}}(\mathcal{D}_{S}^{\widetilde{Z}}\,\|\,\mathcal{D}_{T}^{\widetilde{Z}}), then:

\varepsilon_{S}(h\circ g)+\varepsilon_{T}(h\circ g)\geq\frac{1}{2}\left(\sqrt{D_{\text{JS}}(\mathcal{D}_{S}^{Y}\,\|\,\mathcal{D}_{T}^{Y})}-\sqrt{D_{\text{JS}}(\mathcal{D}_{S}^{\widetilde{Z}}\,\|\,\mathcal{D}_{T}^{\widetilde{Z}})}\right)^{2}.

Remark  The lower bound is algorithm-independent. It is also a population-level result and holds asymptotically with increasing data. Zhao et al. [66] prove the theorem for k=2 and \widetilde{Z}=Z. We extend it to CDAN and arbitrary k (it actually holds for any \widetilde{Z} s.t. \widehat{Y}=\widetilde{h}(\widetilde{Z}) for some \widetilde{h}, see App. A.3). Assuming that label distributions differ between source and target domains, the lower bound shows that the better the alignment of feature distributions, the worse the joint error. For an invariant representation (D_{\text{JS}}(\mathcal{D}_{S}^{\widetilde{Z}},\mathcal{D}_{T}^{\widetilde{Z}})=0) with no source error, the target error will be larger than D_{\text{JS}}(\mathcal{D}_{S}^{Y},\mathcal{D}_{T}^{Y})/2. Hence algorithms learning invariant representations and minimizing the source error are fundamentally flawed when label distributions differ between source and target domains.
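As a small numerical illustration of the bound (our own example, with hypothetical label marginals and natural logarithms; the paper's exact convention for D_{\text{JS}} is in App. A.1), note that for a perfectly invariant representation the right-hand side reduces to D_{\text{JS}}(\mathcal{D}_{S}^{Y}\,\|\,\mathcal{D}_{T}^{Y})/2:

import numpy as np

def js_divergence(p, q):
    # Jensen-Shannon divergence between two discrete distributions (natural log).
    p, q = np.asarray(p, dtype=float), np.asarray(q, dtype=float)
    m = 0.5 * (p + q)
    kl = lambda a, b: np.sum(np.where(a > 0, a * np.log(a / b), 0.0))
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

p_source = [0.7, 0.3]    # hypothetical source label marginal
p_target = [0.3, 0.7]    # hypothetical target label marginal
lower_bound = 0.5 * js_divergence(p_source, p_target)
print(f"eps_S + eps_T >= {lower_bound:.3f}")    # about 0.041 for this pair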

Table 1: Common assumptions in the domain adaptation literature.
Covariate Shift: \mathcal{D}_{S}^{X}\neq\mathcal{D}_{T}^{X} and \forall\mathbf{x}\in\mathcal{X},\ \mathcal{D}_{S}(Y\mid X=\mathbf{x})=\mathcal{D}_{T}(Y\mid X=\mathbf{x})
Label Shift: \mathcal{D}_{S}^{Y}\neq\mathcal{D}_{T}^{Y} and \forall y\in\mathcal{Y},\ \mathcal{D}_{S}(X\mid Y=y)=\mathcal{D}_{T}(X\mid Y=y)

Common Assumptions to Tackle Domain Adaptation  Two common assumptions about the data made in DA are covariate shift and label shift. They correspond to different ways of decomposing the joint distribution over X\times Y, as detailed in Table 1. From a representation learning perspective, covariate shift is not robust to feature transformation and can lead to an effect called negative transfer [66]. At the same time, label shift clearly fails in most practical applications, e.g. transferring knowledge from synthetic to real images [55]. In that case, the input distributions are actually disjoint.

3 Main Results

In light of the limitations of existing assumptions (e.g. covariate shift and label shift), we propose generalized label shift (GLS), a relaxation of label shift that substantially improves its applicability. We first discuss some of its properties and explain why the assumption is favorable in domain adaptation based on representation learning. Motivated by GLS, we then present a novel error decomposition theorem that directly suggests a bound minimization framework for domain adaptation. The framework is naturally compatible with \mathcal{F}-integral probability metrics [44, \mathcal{F}-IPM] and generates a family of domain adaptation algorithms by choosing various function classes \mathcal{F}. In a nutshell, the proposed framework applies a method of moments [35] to estimate the importance weights \mathbf{w} of the marginal label distributions by solving a quadratic program (QP), and then uses \mathbf{w} to align the weighted source feature distribution with the target feature distribution.

3.1 Generalized Label Shift

Definition 3.1 (Generalized Label Shift, GLS).

A representation Z=g(X) satisfies GLS if

\mathcal{D}_{S}(Z\mid Y=y)=\mathcal{D}_{T}(Z\mid Y=y),\quad\forall y\in\mathcal{Y}. (2)

First, when g is the identity map, the above definition of GLS reduces to the original label shift assumption. Next, GLS is always achievable for any distribution pair (\mathcal{D}_{S},\mathcal{D}_{T}): any constant function g\equiv c\in\mathbb{R} satisfies the above definition. The most important property is arguably that, unlike label shift, GLS is compatible with a perfect classifier (in the noiseless case). Precisely, if there exists a ground-truth labeling function h^{*} such that Y=h^{*}(X), then h^{*} satisfies GLS. As a comparison, without conditioning on Y=y, h^{*} does not satisfy \mathcal{D}_{S}(h^{*}(X))=\mathcal{D}_{T}(h^{*}(X)) if the marginal label distributions are different across domains. This observation is consistent with the lower bound in Theorem 2.1, which holds for arbitrary marginal label distributions.

GLS imposes label shift in the feature space \mathcal{Z} instead of the original input space \mathcal{X}. Conceptually, although samples from the same classes in the source and target domain can be dramatically different, the hope is to find an intermediate representation for both domains in which samples from a given class look similar to one another. Taking digit classification as an example and assuming the feature variable Z corresponds to the contour of a digit, it is possible that by using different contour extractors for e.g. MNIST and USPS, those contours look roughly the same in both domains. Technically, GLS can be facilitated by having separate representation extractors g_{S} and g_{T} for source and target [9, 53].

3.2 An Error Decomposition Theorem based on GLS

We now provide performance guarantees for models that satisfy GLS, in the form of upper bounds on the error gap and on the joint error between source and target domains. Stating them requires two concepts:

Definition 3.2 (Balanced Error Rate).

The balanced error rate (BER) of predictor \widehat{Y} on \mathcal{D}_{S} is:

\mathrm{BER}_{\mathcal{D}_{S}}(\widehat{Y}\,\|\,Y)\vcentcolon=\max_{j\in[k]}\mathcal{D}_{S}(\widehat{Y}\neq Y\mid Y=j). (3)
Definition 3.3 (Conditional Error Gap).

Given the joint distributions \mathcal{D}_{S} and \mathcal{D}_{T}, the conditional error gap of a classifier \widehat{Y} is \Delta_{\mathrm{CE}}(\widehat{Y})\vcentcolon=\max_{y\neq y^{\prime}\in\mathcal{Y}^{2}}|\mathcal{D}_{S}(\widehat{Y}=y^{\prime}\mid Y=y)-\mathcal{D}_{T}(\widehat{Y}=y^{\prime}\mid Y=y)|.

When GLS holds, \Delta_{\mathrm{CE}}(\widehat{Y}) is equal to 0. We now give an upper bound on the error gap between source and target, which can also be used to obtain a generalization upper bound on the target risk.

Theorem 3.1.

(Error Decomposition Theorem) For any classifier \widehat{Y}=(h\circ g)(X),

|\varepsilon_{S}(h\circ g)-\varepsilon_{T}(h\circ g)|\leq\|\mathcal{D}_{S}^{Y}-\mathcal{D}_{T}^{Y}\|_{1}\cdot\mathrm{BER}_{\mathcal{D}_{S}}(\widehat{Y}\,\|\,Y)+2(k-1)\Delta_{\mathrm{CE}}(\widehat{Y}),

where \|\mathcal{D}_{S}^{Y}-\mathcal{D}_{T}^{Y}\|_{1}\vcentcolon=\sum_{i=1}^{k}|\mathcal{D}_{S}(Y=i)-\mathcal{D}_{T}(Y=i)| is the L_{1} distance between \mathcal{D}_{S}^{Y} and \mathcal{D}_{T}^{Y}.

Remark  The upper bound in Theorem 3.1 provides a way to decompose the error gap between source and target domains. It also immediately gives a generalization bound on the target risk \varepsilon_{T}(h\circ g). The bound contains two terms. The first contains \|\mathcal{D}_{S}^{Y}-\mathcal{D}_{T}^{Y}\|_{1}, which measures the distance between the marginal label distributions across domains and is a constant that only depends on the adaptation problem itself, and \mathrm{BER}, a reweighted classification performance on the source domain. The second term, \Delta_{\mathrm{CE}}(\widehat{Y}), measures the distance between the families of conditional distributions \widehat{Y}\mid Y. In other words, the bound is oblivious to the optimal labeling functions in feature space. This is in sharp contrast with upper bounds from previous work [7, Theorem 2], [66, Theorem 4.1], which essentially decompose the error gap in terms of the distance between the marginal feature distributions (\mathcal{D}_{S}^{Z}, \mathcal{D}_{T}^{Z}) and the optimal labeling functions (f_{S}^{Z}, f_{T}^{Z}). Because the optimal labeling function in feature space depends on Z and is unknown in practice, such a decomposition is not very informative. As a comparison, Theorem 3.1 provides a decomposition orthogonal to previous results and does not require knowledge about unknown optimal labeling functions in feature space.

Notably, the balanced error rate, \mathrm{BER}_{\mathcal{D}_{S}}(\widehat{Y}\,\|\,Y), only depends on samples from the source domain and can be minimized. Furthermore, using a data-processing argument, the conditional error gap \Delta_{\mathrm{CE}}(\widehat{Y}) can be minimized by aligning the conditional feature distributions across domains. Putting everything together, the result suggests that, to minimize the error gap, it suffices to align the conditional distributions Z\mid Y=y while simultaneously minimizing the balanced error rate. In fact, under the assumption that the conditional distributions are perfectly aligned (i.e., under GLS), we can prove a stronger result, guaranteeing that the joint error is small:

Theorem 3.2.

If Z=g(X) satisfies GLS, then for any h:\mathcal{Z}\to\mathcal{Y} and letting \widehat{Y}=h(Z) be the predictor, we have \varepsilon_{S}(\widehat{Y})+\varepsilon_{T}(\widehat{Y})\leq 2\,\mathrm{BER}_{\mathcal{D}_{S}}(\widehat{Y}\,\|\,Y).

3.3 Conditions for Generalized Label Shift

The main difficulty in applying a bound minimization algorithm inspired by Theorem 3.1 is that we do not have labels from the target domain in UDA, so we cannot directly align the conditional label distributions. By using relative class weights between domains, we can provide a necessary condition for GLS that bypasses an explicit alignment of the conditional feature distributions.

Definition 3.4.

Assuming \mathcal{D}_{S}(Y=y)>0,\ \forall y\in\mathcal{Y}, we let \mathbf{w}\in\mathbb{R}^{k} denote the importance weights of the target and source label distributions:

\mathbf{w}_{y}\vcentcolon=\frac{\mathcal{D}_{T}(Y=y)}{\mathcal{D}_{S}(Y=y)},\quad\forall y\in\mathcal{Y}. (4)
Lemma 3.1.

(Necessary condition for GLS) If Z=g(X) satisfies GLS, then \mathcal{D}_{T}(\widetilde{Z})=\sum_{y\in\mathcal{Y}}\mathbf{w}_{y}\cdot\mathcal{D}_{S}(\widetilde{Z},Y=y)=\vcentcolon\mathcal{D}_{S}^{\mathbf{w}}(\widetilde{Z}), where \widetilde{Z} verifies either \widetilde{Z}=Z or \widetilde{Z}=\widehat{Y}\otimes Z.

Compared to previous work that attempts to align \mathcal{D}_{T}(Z) with \mathcal{D}_{S}(Z) (using adversarial discriminators [22] or maximum mean discrepancy (MMD) [38]) or \mathcal{D}_{T}(\widehat{Y}\otimes Z) with \mathcal{D}_{S}(\widehat{Y}\otimes Z) [41], Lemma 3.1 suggests to instead align \mathcal{D}_{T}(\widetilde{Z}) with the reweighted marginal distribution \mathcal{D}^{\mathbf{w}}_{S}(\widetilde{Z}). Reciprocally, the following two theorems give sufficient conditions under which a perfect alignment of the target feature distribution with the reweighted source feature distribution implies GLS:

Theorem 3.3.

(Clustering structure implies sufficiency) Let Z=g(X) be such that \mathcal{D}_{T}(Z)=\mathcal{D}_{S}^{\mathbf{w}}(Z). Assume \mathcal{D}_{T}(Y=y)>0,\ \forall y\in\mathcal{Y}. If there exists a partition of \mathcal{Z}=\cup_{y\in\mathcal{Y}}\mathcal{Z}_{y} such that \forall y\in\mathcal{Y}, \mathcal{D}_{S}(Z\in\mathcal{Z}_{y}\mid Y=y)=\mathcal{D}_{T}(Z\in\mathcal{Z}_{y}\mid Y=y)=1, then Z=g(X) satisfies GLS.

Remark

Theorem 3.3 shows that if there exists a partition of the feature space such that instances with the same label are within the same component, then aligning the target feature distribution with the reweighted source feature distribution implies GLS. While this clustering assumption may seem strong, it is consistent with the goal of reducing classification error: if such a clustering exists, then there also exists a perfect predictor based on the feature Z=g(X), i.e., the cluster index.

Theorem 3.4.

Let \widehat{Y}=h(Z), \gamma\vcentcolon=\min_{y\in\mathcal{Y}}\mathcal{D}_{T}(Y=y) and \mathbf{w}_{M}\vcentcolon=\max_{y\in\mathcal{Y}}\mathbf{w}_{y}. For \widetilde{Z}=Z or \widetilde{Z}=\widehat{Y}\otimes Z, we have:

\max_{y\in\mathcal{Y}}\,d_{\text{TV}}(\mathcal{D}_{S}(Z\mid Y=y),\mathcal{D}_{T}(Z\mid Y=y))\leq\frac{\mathbf{w}_{M}\varepsilon_{S}(\widehat{Y})+\varepsilon_{T}(\widehat{Y})+\sqrt{8D_{\text{JS}}(\mathcal{D}_{S}^{\mathbf{w}}(\widetilde{Z})\,\|\,\mathcal{D}_{T}(\widetilde{Z}))}}{\gamma}.

Theorem 3.4 confirms that matching \mathcal{D}_{T}(\widetilde{Z}) with \mathcal{D}^{\mathbf{w}}_{S}(\widetilde{Z}) is the proper objective in the context of mismatched label distributions. It shows that, for matched feature distributions and a source error equal to zero, successful domain adaptation (i.e. a target error equal to zero) implies that GLS holds. Combined with Theorem 3.2, we even get equivalence between the two.

Remark  Thm. 3.4 extends Thm. 3.3 by incorporating the clustering assumption into the joint error achievable by a classifier \widehat{Y} based on a fixed Z. In particular, if the clustering structure holds, the joint error is 0 for an appropriate h, and aligning the reweighted feature distributions implies GLS.

3.4 Estimating the Importance Weights \mathbf{w}

Inspired by the moment matching technique to estimate \mathbf{w} under label shift from Lipton et al. [35], we propose a method to estimate \mathbf{w} under GLS by solving a quadratic program (QP).

Definition 3.5.

We let \textbf{C}\in\mathbb{R}^{|\mathcal{Y}|\times|\mathcal{Y}|} denote the confusion matrix of the classifier on the source domain and \boldsymbol{\mu}\in\mathbb{R}^{|\mathcal{Y}|} the distribution of predictions on the target one, \forall y,y^{\prime}\in\mathcal{Y}:

\textbf{C}_{y,y^{\prime}}\vcentcolon=\mathcal{D}_{S}(\widehat{Y}=y,Y=y^{\prime}),\qquad\boldsymbol{\mu}_{y}\vcentcolon=\mathcal{D}_{T}(\widehat{Y}=y).
Lemma 3.2.

If GLS is verified, and if the confusion matrix \textbf{C} is invertible, then \mathbf{w}=\textbf{C}^{-1}\boldsymbol{\mu}.

The key insight from Lemma 3.2 is that, to estimate the importance vector \mathbf{w} under GLS, we do not need access to labels from the target domain. However, matrix inversion is notoriously numerically unstable, especially with finite-sample estimates \hat{\textbf{C}} and \hat{\boldsymbol{\mu}} of \textbf{C} and \boldsymbol{\mu}. We propose to solve instead the following QP (written as QP(\hat{\textbf{C}},\hat{\boldsymbol{\mu}})), whose solution will be consistent if \hat{\textbf{C}}\to\textbf{C} and \hat{\boldsymbol{\mu}}\to\boldsymbol{\mu}:

\underset{\mathbf{w}}{\text{minimize}}\quad\frac{1}{2}\,\|\hat{\boldsymbol{\mu}}-\hat{\textbf{C}}\mathbf{w}\|_{2}^{2},\qquad\text{subject to}\quad\mathbf{w}\geq 0,\ \mathbf{w}^{T}\mathcal{D}_{S}(Y)=1. (5)

The above program (5) can be efficiently solved in time O(|\mathcal{Y}|^{3}), with |\mathcal{Y}| small and constant; and by construction, its solution is element-wise non-negative, even with limited amounts of data to estimate \textbf{C} and \boldsymbol{\mu}.
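As an illustration, here is a minimal sketch of program (5), assuming the cvxpy solver and hypothetical finite-sample estimates of the confusion matrix, prediction marginal and source label marginal (not the authors' exact implementation):

import cvxpy as cp
import numpy as np

C_hat = np.array([[0.30, 0.02, 0.01],
                  [0.03, 0.25, 0.02],
                  [0.02, 0.03, 0.32]])   # C_hat[y, y'] ~ D_S(Yhat = y, Y = y')
mu_hat = np.array([0.25, 0.35, 0.40])    # mu_hat[y] ~ D_T(Yhat = y)
p_source = C_hat.sum(axis=0)             # D_S(Y = y), consistent with C_hat

w = cp.Variable(3, nonneg=True)          # importance weights, constrained to be >= 0
objective = cp.Minimize(0.5 * cp.sum_squares(mu_hat - C_hat @ w))
constraints = [w @ p_source == 1]        # w^T D_S(Y) = 1, i.e. a valid reweighting
cp.Problem(objective, constraints).solve()
print(w.value)                           # estimated importance weights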

Lemma 3.3.

If the source error \varepsilon_{S}(h\circ g) is zero and the source and target marginals verify D_{\text{JS}}(\mathcal{D}_{S}^{\tilde{\mathbf{w}}}(Z),\mathcal{D}_{T}(Z))=0, then the estimated weight vector \mathbf{w} is equal to \tilde{\mathbf{w}}.

Lemma 3.3 shows that the weight estimation is stable once the DA losses have converged, but it does not imply convergence to the true weights (see Sec. 4.2 and App. B.8 for more details).

3.5 \mathcal{F}-IPM for Distributional Alignment

To align the target feature distribution and the reweighted source feature distribution as suggested by Lemma 3.1, we now provide a general framework using the integral probability metric [44, IPM].

Definition 3.6.

With \mathcal{F} a set of real-valued functions, the \mathcal{F}-IPM between distributions \mathcal{D} and \mathcal{D}^{\prime} is

d_{\mathcal{F}}(\mathcal{D},\mathcal{D}^{\prime})\vcentcolon=\sup_{f\in\mathcal{F}}|\mathbb{E}_{X\sim\mathcal{D}}[f(X)]-\mathbb{E}_{X\sim\mathcal{D}^{\prime}}[f(X)]|. (6)

By approximating any function class \mathcal{F} using parametrized models, e.g., neural networks, we obtain a general framework for domain adaptation by aligning the reweighted source feature distribution and the target feature distribution, i.e. by minimizing d_{\mathcal{F}}(\mathcal{D}_{T}(\widetilde{Z}),\mathcal{D}_{S}^{\mathbf{w}}(\widetilde{Z})). Below, by instantiating \mathcal{F} to be the set of bounded-norm functions in a RKHS \mathcal{H} [27], we obtain maximum mean discrepancy methods, leading to IWJAN (cf. Section 4.1), a variant of JAN [40] for UDA.
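As a concrete instance of this IPM alignment, here is a minimal sketch (assuming PyTorch) of a biased, importance-weighted MMD^2 estimate with a plain RBF kernel; this is the generic RKHS instantiation described above, not the exact JAN/IWJAN loss, which uses joint, learnt kernels (App. B.5):

import torch

def rbf_kernel(a, b, sigma=1.0):
    # Gaussian kernel matrix between the rows of a and b.
    return torch.exp(-torch.cdist(a, b) ** 2 / (2 * sigma ** 2))

def weighted_mmd2(z_s, z_t, w_per_sample, sigma=1.0):
    # Biased estimate of MMD^2 between the w-reweighted source feature distribution
    # and the target feature distribution; w_per_sample[i] corresponds to w[y_S^i].
    w = w_per_sample / w_per_sample.sum()
    v = torch.full((z_t.size(0),), 1.0 / z_t.size(0))
    k_ss = rbf_kernel(z_s, z_s, sigma)
    k_tt = rbf_kernel(z_t, z_t, sigma)
    k_st = rbf_kernel(z_s, z_t, sigma)
    return w @ k_ss @ w + v @ k_tt @ v - 2 * w @ k_st @ v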

4 Practical Implementation

4.1 Algorithms

The sections above suggest simple algorithms based on representation learning: (i) estimate \mathbf{w} on the fly during training, (ii) align the feature distributions \widetilde{Z} of the target domain with the reweighted feature distribution of the source domain and (iii) minimize the balanced error rate. Overall, we present the pseudocode of our algorithm in Alg. 1.

To compute \mathbf{w}, we build estimators \hat{\textbf{C}} and \hat{\boldsymbol{\mu}} of \textbf{C} and \boldsymbol{\mu} by averaging, during each epoch, the predictions of the classifier on the source (per true class) and target (overall) samples. This corresponds to the inner-most loop of Algorithm 1 (lines 8-9). At the end of each epoch, \mathbf{w} is updated (line 10) and the estimators are reset to 0 (see the sketch after Alg. 1). We have found empirically that using an exponential moving average of \mathbf{w} performs better. Our results all use a factor \lambda=0.5. We also note that Alg. 1 implies a minimal computational overhead (see App. B.1 for details): in practice our algorithms run as fast as their base versions.

Algorithm 1 Importance-Weighted Domain Adaptation
1:  Input: source and target data (x_{S},y_{S}), x_{T}; g_{\theta}, h_{\phi} and d_{\psi}; epochs E, batches per epoch B
2:  Initialize \mathbf{w}_{1}=1
3:  for t=1 to E do
4:     Initialize \hat{\textbf{C}}=0, \hat{\boldsymbol{\mu}}=0
5:     for b=1 to B do
6:        Sample batches (x^{i}_{S},y^{i}_{S}) and (x^{i}_{T}) of size s
7:        Maximize \mathcal{L}_{DA}^{\mathbf{w}_{t}} w.r.t. \theta, minimize \mathcal{L}_{DA}^{\mathbf{w}_{t}} w.r.t. \psi and minimize \mathcal{L}_{C}^{\mathbf{w}_{t}} w.r.t. \theta and \phi
8:        for i=1 to s do
9:           \hat{\textbf{C}}_{\cdot y_{S}^{i}}\leftarrow\hat{\textbf{C}}_{\cdot y_{S}^{i}}+h_{\phi}(g_{\theta}(x_{S}^{i})) (y_{S}^{i}-th column)  and  \hat{\boldsymbol{\mu}}\leftarrow\hat{\boldsymbol{\mu}}+h_{\phi}(g_{\theta}(x_{T}^{i}))
10:    \hat{\textbf{C}}\leftarrow\hat{\textbf{C}}/sB and \hat{\boldsymbol{\mu}}\leftarrow\hat{\boldsymbol{\mu}}/sB;  then  \mathbf{w}_{t+1}=\lambda\cdot QP(\hat{\textbf{C}},\hat{\boldsymbol{\mu}})+(1-\lambda)\mathbf{w}_{t}
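Below is a minimal sketch (assuming PyTorch) of the weight-estimation part of Alg. 1 (lines 4 and 8-10); the model objects g and h, the data loaders and the solve_qp helper are hypothetical placeholders (solve_qp stands for the QP of Eq. (5), e.g. via the cvxpy sketch in Section 3.4), and the adversarial updates of line 7 are omitted:

import torch

def update_weights(g, h, source_loader, target_loader, w, p_source, lam=0.5):
    # One epoch of estimation: accumulate soft predictions into C_hat (one column per
    # source class) and mu_hat (target marginal), then EMA-update w with the QP solution.
    k = w.numel()
    C_hat, mu_hat = torch.zeros(k, k), torch.zeros(k)
    n_s = n_t = 0
    with torch.no_grad():
        for (x_s, y_s), x_t in zip(source_loader, target_loader):
            p_s = torch.softmax(h(g(x_s)), dim=1)   # classifier outputs on the source batch
            p_t = torch.softmax(h(g(x_t)), dim=1)   # classifier outputs on the target batch
            C_hat.index_add_(1, y_s, p_s.T)         # add p_s[i] to column y_s[i] of C_hat
            mu_hat += p_t.sum(dim=0)
            n_s, n_t = n_s + x_s.size(0), n_t + x_t.size(0)
    C_hat, mu_hat = C_hat / n_s, mu_hat / n_t
    w_qp = solve_qp(C_hat.numpy(), mu_hat.numpy(), p_source)   # hypothetical helper solving Eq. (5)
    return lam * torch.as_tensor(w_qp, dtype=w.dtype) + (1 - lam) * w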

Using \mathbf{w}, we can define our first algorithm, Importance-Weighted Domain Adversarial Network (IWDAN), which aligns \mathcal{D}^{\mathbf{w}}_{S}(Z) and \mathcal{D}_{T}(Z) using a discriminator. To that end, we modify the DANN losses \mathcal{L}_{DA} and \mathcal{L}_{C} as follows. For batches (x^{i}_{S},y^{i}_{S}) and (x^{i}_{T}) of size s, the weighted DA loss is:

\mathcal{L}_{DA}^{\mathbf{w}}(x^{i}_{S},y^{i}_{S},x_{T}^{i};\theta,\psi)=-\frac{1}{s}\sum_{i=1}^{s}\mathbf{w}_{y^{i}_{S}}\log(d_{\psi}(g_{\theta}(x_{S}^{i})))+\log(1-d_{\psi}(g_{\theta}(x_{T}^{i}))). (7)

We verify in App. A.1 that the standard ADA framework applied to \mathcal{L}_{DA}^{\mathbf{w}} indeed minimizes D_{\text{JS}}(\mathcal{D}^{\mathbf{w}}_{S}(Z)\,\|\,\mathcal{D}_{T}(Z)). Our second algorithm, Importance-Weighted Joint Adaptation Networks (IWJAN), is based on JAN [40] and follows the reweighting principle described in Section 3.5 with \mathcal{F} a learnt RKHS (the exact JAN and IWJAN losses are specified in App. B.5). Finally, our third algorithm is Importance-Weighted Conditional Domain Adversarial Network (IWCDAN). It matches \mathcal{D}_{S}^{\mathbf{w}}(\widehat{Y}\otimes Z) with \mathcal{D}_{T}(\widehat{Y}\otimes Z) by replacing the standard adversarial loss in CDAN with Eq. 7, where d_{\psi} takes as input (h_{\phi}\circ g_{\theta})\otimes g_{\theta} instead of g_{\theta}. The classifier loss for our three variants is:

\mathcal{L}_{C}^{\mathbf{w}}(x^{i}_{S},y^{i}_{S};\theta,\phi)=-\frac{1}{s}\sum_{i=1}^{s}\frac{1}{k\cdot\mathcal{D}_{S}(Y=y^{i}_{S})}\log(h_{\phi}(g_{\theta}(x_{S}^{i}))_{y^{i}_{S}}). (8)

This reweighting is suggested by our theoretical analysis of Section 3, where we seek to minimize the balanced error rate \mathrm{BER}_{\mathcal{D}_{S}}(\widehat{Y}\,\|\,Y). We also define oracle versions, IWDAN-O, IWJAN-O and IWCDAN-O, where the weights \mathbf{w} are the true weights. They give an idealistic version of the reweighting method and allow us to assess the soundness of GLS. IWDAN, IWJAN and IWCDAN are Alg. 1 with their respective loss functions; the oracle versions use the true weights instead of \mathbf{w}_{t}.
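For concreteness, here is a minimal sketch (assuming PyTorch) of the two weighted losses of Eqs. (7) and (8); tensor shapes, the probability-valued discriminator output and the per-class source marginal p_source are our own assumptions, not the authors' exact implementation:

import torch
import torch.nn.functional as F

def weighted_da_loss(d_s, d_t, w, y_s):
    # Eq. (7): importance-weighted domain-adversarial loss for source/target batches
    # of equal size s; d_s and d_t are discriminator outputs in (0, 1).
    return -(w[y_s] * torch.log(d_s) + torch.log(1 - d_t)).mean()

def weighted_clf_loss(logits_s, y_s, p_source, k):
    # Eq. (8): source classification loss reweighted by 1 / (k * D_S(Y = y)),
    # a surrogate for the balanced error rate BER.
    log_probs = F.log_softmax(logits_s, dim=1)
    picked = log_probs.gather(1, y_s.unsqueeze(1)).squeeze(1)
    return -(picked / (k * p_source[y_s])).mean()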

4.2 Experiments

Figure 1: Gains of our algorithms vs. their base versions (the horizontal grey line) for 100 tasks. The x-axis is D_{\text{JS}}(\mathcal{D}_{S}^{Y},\mathcal{D}_{T}^{Y}). The mean improvements for IWDAN and IWDAN-O (resp. IWCDAN and IWCDAN-O) are 6.55% and 8.14% (resp. 2.25% and 2.81%).

We apply our three base algorithms, their importance-weighted versions, and the oracles to 4 standard DA datasets, generating 21 tasks: Digits (MNIST \leftrightarrow USPS [32, 19]), Visda [55], Office-31 [49] and Office-Home [54]. All values are averages over 5 runs of the best test accuracy throughout training (evaluated at the end of each epoch). We used that value for fairness with respect to the baselines (as shown in the left panel of Figure 2, the performance of DANN decreases as training progresses, due to the inappropriate matching of representations showcased in Theorem 2.1). For full details, see App. B.2 and B.7.

Performance vs D_{\text{JS}}  We artificially generate 100 tasks from MNIST and USPS by considering various random subsets of the classes in either the source or target domain (see Appendix B.6 for details). These 100 DA tasks have a D_{\text{JS}}(\mathcal{D}_{S}^{Y},\mathcal{D}_{T}^{Y}) varying between 0 and 0.1. The results of applying IWDAN and IWCDAN are shown in Fig. 1. We see a clear correlation between the improvements provided by our algorithms and D_{\text{JS}}(\mathcal{D}_{S}^{Y},\mathcal{D}_{T}^{Y}), which is well aligned with Theorem 2.1. Moreover, IWDAN outperforms DANN on all 100 tasks and IWCDAN bests CDAN on 94. Even on small divergences, our algorithms do not suffer compared to their base versions.

Original Datasets  The average results on each dataset are shown in Table 2 (see App. B.3 for the per-task breakdown). IWDAN outperforms the basic algorithm DANN by 1.75%, 1.64%, 1.16% and 2.65% on the Digits, Visda, Office-31 and Office-Home tasks respectively. Gains for IWCDAN are more limited, but still present: 0.18%, 0.89%, 0.07% and 1.07% respectively. This might be explained by the fact that CDAN enforces a weak form of GLS (App. B.5.2). Gains for JAN are 0.58%, 0.19% and 0.19%. We also show the fraction of times (over all seeds and tasks) our variants outperform the original algorithms. Even for small gains, the variants provide consistent improvements. Additionally, the oracle versions show larger improvements, which strongly supports enforcing GLS.

Table 2: Average results on the various domains (Digits has 2 tasks, Visda 1, Office-31 6 and Office-Home 12). The prefix s denotes the experiment where the source domain is subsampled to increase D_{\text{JS}}(\mathcal{D}_{S}^{Y},\mathcal{D}_{T}^{Y}). Each number is a mean over 5 seeds; the value in parentheses denotes the fraction of times (out of 5 seeds \times #tasks) our algorithms outperform their base versions. JAN is not available on Digits.
Method      Digits        sDigits       Visda         sVisda        O-31          sO-31         O-H           sO-H
No DA       77.17         75.67         48.39         49.02         77.81         75.72         56.39         51.34
DANN        93.15         83.24         61.88         52.85         82.74         76.17         59.62         51.83
IWDAN       94.90 (100%)  92.54 (100%)  63.52 (100%)  60.18 (100%)  83.90 (87%)   82.60 (100%)  62.27 (97%)   57.61 (100%)
IWDAN-O     95.27 (100%)  94.46 (100%)  64.19 (100%)  62.10 (100%)  85.33 (97%)   84.41 (100%)  64.68 (100%)  60.87 (100%)
CDAN        95.72         88.23         65.60         60.19         87.23         81.62         64.59         56.25
IWCDAN      95.90 (80%)   93.22 (100%)  66.49 (60%)   65.83 (100%)  87.30 (73%)   83.88 (100%)  65.66 (70%)   61.24 (100%)
IWCDAN-O    95.85 (90%)   94.81 (100%)  68.15 (100%)  66.85 (100%)  88.14 (90%)   85.47 (100%)  67.64 (98%)   63.73 (100%)
JAN         N/A           N/A           56.98         50.64         85.13         78.21         59.59         53.94
IWJAN       N/A           N/A           57.56 (100%)  57.12 (100%)  85.32 (60%)   82.61 (97%)   59.78 (63%)   55.89 (100%)
IWJAN-O     N/A           N/A           61.48 (100%)  61.30 (100%)  87.14 (100%)  86.24 (100%)  60.73 (92%)   57.36 (100%)
Figure 2: Left: accuracy on sDigits. Right: Euclidean distance between estimated and true weights.

Subsampled datasets  The original datasets have fairly balanced classes, making the JSD between source and target label distributions D_{\text{JS}}(\mathcal{D}_{S}^{Y}\,\|\,\mathcal{D}_{T}^{Y}) rather small (Tables 11a, 12a and 13a in App. B.4). To evaluate our algorithms on larger divergences, we arbitrarily modify the source domains above by considering only 30% of the samples coming from the first half of their classes. This results in larger divergences (Tables 11b, 12b and 13b). Performance is shown in Table 2 (datasets prefixed by s). For IWDAN, we see gains of 9.3%, 7.33%, 6.43% and 5.58% on the Digits, Visda, Office-31 and Office-Home datasets respectively. For IWCDAN, improvements are 4.99%, 5.64%, 2.26% and 4.99%, and IWJAN shows gains of 6.48%, 4.40% and 1.95%. Moreover, on all seeds and tasks but one, our variants outperform their base versions.
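A small sketch of this subsampling (our reading of the protocol, with hypothetical label arrays; the exact setup is in App. B.4):

import numpy as np

def subsample_source(y_source, k, keep_frac=0.3, seed=0):
    # Keep only `keep_frac` of the source samples whose label lies in the first half
    # of the k classes; samples from the remaining classes are all kept.
    rng = np.random.default_rng(seed)
    first_half = np.asarray(y_source) < k // 2
    keep = ~first_half | (rng.random(len(y_source)) < keep_frac)
    return np.where(keep)[0]    # indices of the retained source samples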

Importance weights  While our method demonstrates gains empirically, Lemma 3.2 does not guarantee convergence of \mathbf{w} to the true weights. In Fig. 2, we show the test accuracy and the distance between estimated and true weights during training on sDigits. We see that DANN's performance gets worse after a few epochs, as predicted by Theorem 2.1. The representation matching objective collapses the classes that are over-represented in the target domain onto the under-represented ones (see App. B.9). This phenomenon does not occur for IWDAN and IWDAN-O. Both monotonously improve in accuracy and weight estimation (see Lemma 3.3 and App. B.8 for more details). We also observe that IWDAN's weights do not converge perfectly. This suggests that fine-tuning \lambda (we used \lambda=0.5 in all our experiments for simplicity) or updating \mathbf{w} more or less often could lead to better performance.

Ablation Study  Our algorithms have two components, a weighted adversarial loss \mathcal{L}_{DA}^{\mathbf{w}} and a weighted classification loss \mathcal{L}_{C}^{\mathbf{w}}. In Table 3, we augment DANN and CDAN with those losses separately (using the true weights). We observe that DANN benefits essentially from the reweighting of its adversarial loss \mathcal{L}_{DA}^{\mathbf{w}}; the classification loss has little effect. For CDAN, gains are essentially seen on the subsampled datasets. Both losses help, with an extra +2% gain for \mathcal{L}_{DA}^{\mathbf{w}}.

Table 3: Ablation study on the Digits tasks.
Method                                 Digits   sDigits      Method                                 Digits   sDigits
DANN                                   93.15    83.24        CDAN                                   95.72    88.23
DANN + \mathcal{L}_{C}^{\mathbf{w}}    93.27    84.52        CDAN + \mathcal{L}_{C}^{\mathbf{w}}    95.65    91.01
DANN + \mathcal{L}_{DA}^{\mathbf{w}}   95.31    94.41        CDAN + \mathcal{L}_{DA}^{\mathbf{w}}   95.42    93.18
IWDAN-O                                95.27    94.46        IWCDAN-O                               95.85    94.81

5 Related Work

Covariate shift has been studied and used in many adaptation algorithms [30, 26, 3, 1, 53, 65, 48]. While less known, label shift has also been tackled from various angles over the years: applying EM to learn \mathcal{D}^{Y}_{T} [14], placing a prior on the label distribution [52], using kernel mean matching [61, 20, 46], etc. Schölkopf et al. [50] cast the problem in a causal/anti-causal perspective corresponding to covariate/label shift. That perspective was then further developed [61, 23, 35, 4]. Numerous domain adaptation methods rely on learning invariant representations, and minimize various metrics on the marginal feature distributions: total variation or equivalently D_{\text{JS}} [22, 53, 64, 36], maximum mean discrepancy [27, 37, 38, 39, 40], Wasserstein distance [17, 16, 51, 33, 15], etc. Other noteworthy DA methods use reconstruction losses and cycle-consistency to learn transferable classifiers [67, 29, 56]. Recently, Liu et al. [36] have introduced Transferable Adversarial Training (TAT), where transferable examples are generated to fill the gap in feature space between source and target domains; the dataset is then augmented with those samples. Applying our method to TAT is a future research direction.

Other relevant settings include partial ADA, i.e. UDA when the target labels are a strict subset of the source labels / some components of \mathbf{w} are 0 [11, 12, 13]. Multi-domain adaptation, where multiple source or target domains are given, is also widely studied [42, 18, 45, 63, 28, 47]. Recently, Binkowski et al. [8] studied sample reweighting in the domain transfer to handle mass shifts between distributions.

Prior work on combining importance weights with domain-invariant representation learning also exists in the setting of partial DA [60]. However, the importance ratio in these works is defined over the features Z, rather than the class label Y. Compared to our method, this is both statistically inefficient and computationally expensive, since the feature space \mathcal{Z} is often a high-dimensional continuous space, whereas the label space \mathcal{Y} only contains a finite number (k) of distinct labels. In a separate work, Yan et al. [58] proposed a weighted MMD distance to handle target shift in UDA. However, their weights are estimated based on pseudo-labels obtained from the learned classifier, hence it is not clear whether the pseudo-labels provide an accurate estimation of the importance weights even in simple settings. As a comparison, under GLS, we show that our weight estimation obtained by solving a quadratic program converges asymptotically.

6 Conclusion and Future Work

We have introduced the generalized label shift assumption, GLS, and theoretically-grounded variations of existing algorithms to handle mismatched label distributions. On tasks from classic benchmarks as well as artificial ones, our algorithms consistently outperform their base versions. The gains, as expected theoretically, correlate well with the JSD between label distributions across domains. In real-world applications, the JSD is unknown, and might be larger than in ML datasets where classes are often purposely balanced. Because our method is simple to implement, adds barely any computational cost and is robust to mismatched label distributions, it is very relevant to such applications.

Extensions  The framework we define in this paper relies on appropriately reweighting the domain adversarial losses. It can be straightforwardly applied to settings where multiple source and/or target domains are used, by simply maintaining one importance weight vector \mathbf{w} for each source/target pair [63, 47]. In particular, label shift could explain the observation from Zhao et al. [63] that too many source domains hurt performance, and our framework might alleviate the issue. One can also think of settings (e.g. semi-supervised domain adaptation) where estimations of \mathcal{D}_{T}^{Y} can be obtained via other means. A more challenging but also more interesting future direction is to extend our framework to domain generalization, where the learner has access to multiple labeled source domains but no access to (even unlabelled) data from the target domain.

Acknowledgements

The authors thank Romain Laroche and Alessandro Sordoni for useful feedback and helpful discussions. HZ and GG would like to acknowledge support from the DARPA XAI project, contract #FA87501720152, and an Nvidia GPU grant. YW would like to acknowledge partial support from NSF Award #2029626, a start-up grant from the UCSB Department of Computer Science, as well as generous gifts from Amazon, Adobe, Google and NEC Labs.

Broader Impact

Our work focuses on domain adaptation and attempts to properly handle mismatches in the label distributions between the source and target domains. Domain adaptation as a whole aims at transferring knowledge gained from a certain domain (or data distribution) to another one. It can potentially be used in a variety of decision-making systems, such as spam filters, machine translation, etc. One can also think of much more sensitive applications such as recidivism prediction or loan approvals.

While it is unclear to us to what extent DA is currently applied, or how it will be applied in the future, the bias formalized in Th. 2.1 and verified in Table 17 demonstrates that imbalances between classes will result in poor transfer performance of standard ADA methods on a subset of them, which is without a doubt a source of potential inequalities. Our method is actually aimed at counter-balancing the effect of such imbalances. As shown in our empirical results (for instance Table 18) it is rather successful at it, especially on significant shifts. This makes us rather confident in the algorithm’s ability to mitigate potential effects of biases in the datasets. On the downside, failure in the weight estimation of some classes might result in poor performance on those. However, we have not observed, in any of our experiments, our method performing significantly worse than its base version. Finally, our method is a variation over existing deep learning algorithms. As such, it carries with it the uncertainties associated to deep learning models, in particular a lack of interpretability and of formal convergence guarantees.

References

  • Adel et al. [2017] Tameem Adel, Han Zhao, and Alexander Wong. Unsupervised domain adaptation with a relaxed covariate shift assumption. In Thirty-First AAAI Conference on Artificial Intelligence, 2017.
  • Arjovsky et al. [2019] Martin Arjovsky, Léon Bottou, Ishaan Gulrajani, and David Lopez-Paz. Invariant risk minimization, 2019. URL http://arxiv.org/abs/1907.02893. cite arxiv:1907.02893.
  • Ash et al. [2016] Jordan T Ash, Robert E Schapire, and Barbara E Engelhardt. Unsupervised domain adaptation using approximate label matching. arXiv preprint arXiv:1602.04889, 2016.
  • Azizzadenesheli et al. [2019] Kamyar Azizzadenesheli, Anqi Liu, Fanny Yang, and Animashree Anandkumar. Regularized learning for domain adaptation under label shifts. In ICLR (Poster). OpenReview.net, 2019. URL http://dblp.uni-trier.de/db/conf/iclr/iclr2019.html#Azizzadenesheli19.
  • Bachman et al. [2019] Philip Bachman, R. Devon Hjelm, and William Buchwalter. Learning representations by maximizing mutual information across views. CoRR, abs/1906.00910, 2019. URL http://dblp.uni-trier.de/db/journals/corr/corr1906.html#abs-1906-00910.
  • Ben-David et al. [2007] Shai Ben-David, John Blitzer, Koby Crammer, Fernando Pereira, et al. Analysis of representations for domain adaptation. Advances in neural information processing systems, 19:137, 2007.
  • Ben-David et al. [2010] Shai Ben-David, John Blitzer, Koby Crammer, Alex Kulesza, Fernando Pereira, and Jennifer Wortman Vaughan. A theory of learning from different domains. Machine learning, 79(1-2):151–175, 2010.
  • Binkowski et al. [2019] Mikolaj Binkowski, R. Devon Hjelm, and Aaron C. Courville. Batch weight for domain adaptation with mass shift. CoRR, abs/1905.12760, 2019. URL http://dblp.uni-trier.de/db/journals/corr/corr1905.html#abs-1905-12760.
  • Bousmalis et al. [2016] Konstantinos Bousmalis, George Trigeorgis, Nathan Silberman, Dilip Krishnan, and Dumitru Erhan. Domain separation networks. In Advances in Neural Information Processing Systems, pages 343–351, 2016.
  • Briët and Harremoës [2009] Jop Briët and Peter Harremoës. Properties of classical and quantum jensen-shannon divergence. Phys. Rev. A, 79:052311, May 2009. doi: 10.1103/PhysRevA.79.052311. URL https://link.aps.org/doi/10.1103/PhysRevA.79.052311.
  • Cao et al. [2018a] Zhangjie Cao, Mingsheng Long, Jianmin Wang, and Michael I. Jordan. Partial transfer learning with selective adversarial networks. In CVPR, pages 2724–2732. IEEE Computer Society, 2018a. URL http://dblp.uni-trier.de/db/conf/cvpr/cvpr2018.html#CaoL0J18.
  • Cao et al. [2018b] Zhangjie Cao, Lijia Ma, Mingsheng Long, and Jianmin Wang. Partial adversarial domain adaptation. In Vittorio Ferrari, Martial Hebert, Cristian Sminchisescu, and Yair Weiss, editors, ECCV (8), volume 11212 of Lecture Notes in Computer Science, pages 139–155. Springer, 2018b. ISBN 978-3-030-01237-3. URL http://dblp.uni-trier.de/db/conf/eccv/eccv2018-8.html#CaoMLW18.
  • Cao et al. [2019] Zhangjie Cao, Kaichao You, Mingsheng Long, Jianmin Wang, and Qiang Yang. Learning to transfer examples for partial domain adaptation. In CVPR, pages 2985–2994. Computer Vision Foundation / IEEE, 2019. URL http://dblp.uni-trier.de/db/conf/cvpr/cvpr2019.html#CaoYLW019.
  • Chan and Ng [2005] Yee Seng Chan and Hwee Tou Ng. Word sense disambiguation with distribution estimation. In Leslie Pack Kaelbling and Alessandro Saffiotti, editors, IJCAI, pages 1010–1015. Professional Book Center, 2005. ISBN 0938075934. URL http://dblp.uni-trier.de/db/conf/ijcai/ijcai2005.html#ChanN05.
  • Chen et al. [2018] Qingchao Chen, Yang Liu, Zhaowen Wang, Ian Wassell, and Kevin Chetty. Re-weighted adversarial adaptation network for unsupervised domain adaptation. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 7976–7985, 2018.
  • Courty et al. [2017a] Nicolas Courty, Rémi Flamary, Amaury Habrard, and Alain Rakotomamonjy. Joint distribution optimal transportation for domain adaptation. In Advances in Neural Information Processing Systems, pages 3730–3739, 2017a.
  • Courty et al. [2017b] Nicolas Courty, Rémi Flamary, Devis Tuia, and Alain Rakotomamonjy. Optimal transport for domain adaptation. IEEE transactions on pattern analysis and machine intelligence, 39(9):1853–1865, 2017b.
  • Daumé III [2009] Hal Daumé III. Frustratingly easy domain adaptation. arXiv preprint arXiv:0907.1815, 2009.
  • Dheeru and Karra [2017] Dua Dheeru and Efi Karra. UCI machine learning repository, 2017. URL http://archive.ics.uci.edu/ml.
  • du Plessis and Sugiyama [2014] Marthinus Christoffel du Plessis and Masashi Sugiyama. Semi-supervised learning of class balance under class-prior change by distribution matching. Neural Networks, 50:110–119, 2014. URL http://dblp.uni-trier.de/db/journals/nn/nn50.html#PlessisS14.
  • Endres and Schindelin [2003] Dominik Maria Endres and Johannes E Schindelin. A new metric for probability distributions. IEEE Transactions on Information theory, 2003.
  • Ganin et al. [2016] Yaroslav Ganin, Evgeniya Ustinova, Hana Ajakan, Pascal Germain, Hugo Larochelle, François Laviolette, Mario Marchand, and Victor Lempitsky. Domain-adversarial training of neural networks. Journal of Machine Learning Research, 17(59):1–35, 2016.
  • Gong et al. [2016] Mingming Gong, Kun Zhang, Tongliang Liu, Dacheng Tao, Clark Glymour, and Bernhard Schölkopf. Domain adaptation with conditional transferable components. In International conference on machine learning, pages 2839–2848, 2016.
  • Goodfellow et al. [2017] Ian Goodfellow, Yoshua Bengio, and Aaron Courville. Deep learning. 2017. ISBN 9780262035613 0262035618. URL https://www.worldcat.org/title/deep-learning/oclc/985397543&referer=brief_results.
  • Goodfellow et al. [2014] Ian J. Goodfellow, Jean Pouget-Abadie, Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil Ozair, Aaron Courville, and Yoshua Bengio. Generative adversarial networks, 2014. URL http://arxiv.org/abs/1406.2661. cite arxiv:1406.2661.
  • Gretton et al. [2009] Arthur Gretton, Alex Smola, Jiayuan Huang, Marcel Schmittfull, Karsten Borgwardt, and Bernhard Schölkopf. Covariate shift by kernel mean matching. Dataset shift in machine learning, 3(4):5, 2009.
  • Gretton et al. [2012] Arthur Gretton, Karsten M Borgwardt, Malte J Rasch, Bernhard Schölkopf, and Alexander Smola. A kernel two-sample test. Journal of Machine Learning Research, 13(Mar):723–773, 2012.
  • Guo et al. [2018] Jiang Guo, Darsh J Shah, and Regina Barzilay. Multi-source domain adaptation with mixture of experts. arXiv preprint arXiv:1809.02256, 2018.
  • Hoffman et al. [2017] Judy Hoffman, Eric Tzeng, Taesung Park, Jun-Yan Zhu, Phillip Isola, Kate Saenko, Alexei A Efros, and Trevor Darrell. Cycada: Cycle-consistent adversarial domain adaptation. arXiv preprint arXiv:1711.03213, 2017.
  • Huang et al. [2006] Jiayuan Huang, Arthur Gretton, Karsten M Borgwardt, Bernhard Schölkopf, and Alex J Smola. Correcting sample selection bias by unlabeled data. In Advances in neural information processing systems, pages 601–608, 2006.
  • LeCun et al. [1998] Y. LeCun, L. Bottou, Y. Bengio, and P. Haffner. Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86(11):2278–2324, 1998. ISSN 0018-9219. doi: 10.1109/5.726791.
  • LeCun and Cortes [2010] Yann LeCun and Corinna Cortes. MNIST handwritten digit database. http://yann.lecun.com/exdb/mnist/, 2010. URL http://yann.lecun.com/exdb/mnist/.
  • Lee and Raginsky [2018] Jaeho Lee and Maxim Raginsky. Minimax statistical learning with wasserstein distances. In Advances in Neural Information Processing Systems, pages 2692–2701, 2018.
  • Lin [1991] Jianhua Lin. Divergence measures based on the Shannon entropy. IEEE Transactions on Information Theory, 37(1):145–151, 1991.
  • Lipton et al. [2018] Zachary Lipton, Yu-Xiang Wang, and Alexander Smola. Detecting and correcting for label shift with black box predictors. In International Conference on Machine Learning, pages 3128–3136, 2018.
  • Liu et al. [2019] Hong Liu, Mingsheng Long, Jianmin Wang, and Michael I. Jordan. Transferable adversarial training: A general approach to adapting deep classifiers. In Kamalika Chaudhuri and Ruslan Salakhutdinov, editors, ICML, volume 97 of Proceedings of Machine Learning Research, pages 4013–4022. PMLR, 2019. URL http://dblp.uni-trier.de/db/conf/icml/icml2019.html#LiuLWJ19.
  • Long et al. [2014] Mingsheng Long, Jianmin Wang, Guiguang Ding, Jiaguang Sun, and Philip S Yu. Transfer joint matching for unsupervised domain adaptation. In Proceedings of the IEEE conference on computer vision and pattern recognition, pages 1410–1417, 2014.
  • Long et al. [2015] Mingsheng Long, Yue Cao, Jianmin Wang, and Michael Jordan. Learning transferable features with deep adaptation networks. In International Conference on Machine Learning, pages 97–105, 2015.
  • Long et al. [2016] Mingsheng Long, Han Zhu, Jianmin Wang, and Michael I Jordan. Unsupervised domain adaptation with residual transfer networks. In Advances in Neural Information Processing Systems, pages 136–144, 2016.
  • Long et al. [2017] Mingsheng Long, Han Zhu, Jianmin Wang, and Michael I Jordan. Deep transfer learning with joint adaptation networks. In Proceedings of the 34th International Conference on Machine Learning-Volume 70, pages 2208–2217. JMLR, 2017.
  • Long et al. [2018] Mingsheng Long, Zhangjie Cao, Jianmin Wang, and Michael I. Jordan. Conditional adversarial domain adaptation. In Samy Bengio, Hanna M. Wallach, Hugo Larochelle, Kristen Grauman, Nicolò Cesa-Bianchi, and Roman Garnett, editors, NeurIPS, pages 1647–1657, 2018. URL http://dblp.uni-trier.de/db/conf/nips/nips2018.html#LongC0J18.
  • Mansour et al. [2009] Yishay Mansour, Mehryar Mohri, and Afshin Rostamizadeh. Domain adaptation with multiple sources. In Advances in neural information processing systems, pages 1041–1048, 2009.
  • McCoy et al. [2019] R. Thomas McCoy, Ellie Pavlick, and Tal Linzen. Right for the wrong reasons: Diagnosing syntactic heuristics in natural language inference. Proceedings of the ACL, 2019.
  • Müller [1997] Alfred Müller. Integral probability metrics and their generating classes of functions. Advances in Applied Probability, 29(2):429–443, 1997.
  • Nam and Han [2016] Hyeonseob Nam and Bohyung Han. Learning multi-domain convolutional neural networks for visual tracking. In The IEEE Conference on Computer Vision and Pattern Recognition (CVPR), June 2016.
  • Nguyen et al. [2015] Tuan Duong Nguyen, Marthinus Christoffel du Plessis, and Masashi Sugiyama. Continuous target shift adaptation in supervised learning. In ACML, volume 45 of JMLR Workshop and Conference Proceedings, pages 285–300. JMLR.org, 2015. URL http://dblp.uni-trier.de/db/conf/acml/acml2015.html#NguyenPS15.
  • Peng et al. [2019] Xingchao Peng, Zijun Huang, Ximeng Sun, and Kate Saenko. Domain agnostic learning with disentangled representations. In Kamalika Chaudhuri and Ruslan Salakhutdinov, editors, ICML, volume 97 of Proceedings of Machine Learning Research, pages 5102–5112. PMLR, 2019. URL http://dblp.uni-trier.de/db/conf/icml/icml2019.html#PengHSS19.
  • Redko et al. [2019] Ievgen Redko, Nicolas Courty, Rémi Flamary, and Devis Tuia. Optimal transport for multi-source domain adaptation under target shift. In 22nd International Conference on Artificial Intelligence and Statistics (AISTATS) 2019, volume 89, 2019.
  • Saenko et al. [2010] Kate Saenko, Brian Kulis, Mario Fritz, and Trevor Darrell. Adapting visual category models to new domains. In Kostas Daniilidis, Petros Maragos, and Nikos Paragios, editors, ECCV (4), volume 6314 of Lecture Notes in Computer Science, pages 213–226. Springer, 2010. ISBN 978-3-642-15560-4. URL http://dblp.uni-trier.de/db/conf/eccv/eccv2010-4.html#SaenkoKFD10.
  • Schölkopf et al. [2012] Bernhard Schölkopf, Dominik Janzing, Jonas Peters, Eleni Sgouritsa, Kun Zhang, and Joris M. Mooij. On causal and anticausal learning. In ICML. icml.cc / Omnipress, 2012. URL http://dblp.uni-trier.de/db/conf/icml/icml2012.html#ScholkopfJPSZM12.
  • Shen et al. [2018] Jian Shen, Yanru Qu, Weinan Zhang, and Yong Yu. Wasserstein distance guided representation learning for domain adaptation. In Thirty-Second AAAI Conference on Artificial Intelligence, 2018.
  • Storkey [2009] Amos Storkey. When training and test sets are different: Characterising learning transfer. Dataset shift in machine learning., 2009.
  • Tzeng et al. [2017] Eric Tzeng, Judy Hoffman, Kate Saenko, and Trevor Darrell. Adversarial discriminative domain adaptation. arXiv preprint arXiv:1702.05464, 2017.
  • Venkateswara et al. [2017] Hemanth Venkateswara, Jose Eusebio, Shayok Chakraborty, and Sethuraman Panchanathan. Deep hashing network for unsupervised domain adaptation. In (IEEE) Conference on Computer Vision and Pattern Recognition (CVPR), 2017.
  • Visda [2017] Visda. Visual domain adaptation challenge, 2017. URL http://ai.bu.edu/visda-2017/.
  • Xie et al. [2018] Shaoan Xie, Zibin Zheng, Liang Chen, and Chuan Chen. Learning semantic representations for unsupervised domain adaptation. In Jennifer G. Dy and Andreas Krause, editors, ICML, volume 80 of Proceedings of Machine Learning Research, pages 5419–5428. PMLR, 2018. URL http://dblp.uni-trier.de/db/conf/icml/icml2018.html#XieZCC18.
  • Yaghoobzadeh et al. [2019] Yadollah Yaghoobzadeh, Remi Tachet des Combes, Timothy J. Hazen, and Alessandro Sordoni. Robust natural language inference models with example forgetting. CoRR, abs/1911.03861, 2019. URL http://dblp.uni-trier.de/db/journals/corr/corr1911.html#abs-1911-03861.
  • Yan et al. [2017] Hongliang Yan, Yukang Ding, Peihua Li, Qilong Wang, Yong Xu, and Wangmeng Zuo. Mind the class weight bias: Weighted maximum mean discrepancy for unsupervised domain adaptation. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 2272–2281, 2017.
  • Yosinski et al. [2014] Jason Yosinski, Jeff Clune, Yoshua Bengio, and Hod Lipson. How transferable are features in deep neural networks? In Advances in neural information processing systems, pages 3320–3328, 2014.
  • Zhang et al. [2018] Jing Zhang, Zewei Ding, Wanqing Li, and Philip Ogunbona. Importance weighted adversarial nets for partial domain adaptation. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 8156–8164, 2018.
  • Zhang et al. [2013] Kun Zhang, Bernhard Schölkopf, Krikamol Muandet, and Zhikun Wang. Domain adaptation under target and conditional shift. In International Conference on Machine Learning, pages 819–827, 2013.
  • Zhang et al. [2015] Xu Zhang, Felix X. Yu, Shih-Fu Chang, and Shengjin Wang. Deep transfer network: Unsupervised domain adaptation. CoRR, abs/1503.00591, 2015. URL http://dblp.uni-trier.de/db/journals/corr/corr1503.html#ZhangYCW15.
  • Zhao et al. [2018a] Han Zhao, Shanghang Zhang, Guanhang Wu, Geoffrey J Gordon, et al. Multiple source domain adaptation with adversarial learning. In International Conference on Learning Representations, 2018a.
  • Zhao et al. [2018b] Han Zhao, Shanghang Zhang, Guanhang Wu, José MF Moura, Joao P Costeira, and Geoffrey J Gordon. Adversarial multiple source domain adaptation. In Advances in Neural Information Processing Systems, pages 8568–8579, 2018b.
  • Zhao et al. [2019a] Han Zhao, Junjie Hu, Zhenyao Zhu, Adam Coates, and Geoff Gordon. Deep generative and discriminative domain adaptation. In Proceedings of the 18th International Conference on Autonomous Agents and MultiAgent Systems, pages 2315–2317. International Foundation for Autonomous Agents and Multiagent Systems, 2019a.
  • Zhao et al. [2019b] Han Zhao, Remi Tachet des Combes, Kun Zhang, and Geoffrey J. Gordon. On learning invariant representations for domain adaptation. In Kamalika Chaudhuri and Ruslan Salakhutdinov, editors, ICML, volume 97 of Proceedings of Machine Learning Research, pages 7523–7532. PMLR, 2019b. URL http://dblp.uni-trier.de/db/conf/icml/icml2019.html#0002CZG19.
  • Zhu et al. [2017] Jun-Yan Zhu, Taesung Park, Phillip Isola, and Alexei A. Efros. Unpaired image-to-image translation using cycle-consistent adversarial networks. In ICCV, pages 2242–2251. IEEE Computer Society, 2017. ISBN 978-1-5386-1032-9. URL http://dblp.uni-trier.de/db/conf/iccv/iccv2017.html#ZhuPIE17.

Appendix A Omitted Proofs

In this section, we provide the theoretical material that completes the main text.

A.1 Definition

Definition A.1.

Let us recall that for two distributions 𝒟\mathcal{D} and 𝒟\mathcal{D}^{\prime}, the Jensen-Shannon (JSD) divergence DJS(𝒟𝒟)D_{\text{JS}}(\mathcal{D}~{}\|~{}\mathcal{D}^{\prime}) is defined as:

DJS(𝒟𝒟):=12DKL(𝒟𝒟M)+12DKL(𝒟𝒟M),D_{\text{JS}}(\mathcal{D}~{}\|~{}\mathcal{D}^{\prime})\vcentcolon=\frac{1}{2}D_{\text{KL}}(\mathcal{D}~{}\|~{}\mathcal{D}_{M})+\frac{1}{2}D_{\text{KL}}(\mathcal{D}^{\prime}~{}\|~{}\mathcal{D}_{M}),

where DKL()D_{\text{KL}}(\cdot~{}\|~{}\cdot) is the Kullback–Leibler (KL) divergence and 𝒟M:=(𝒟+𝒟)/2\mathcal{D}_{M}\vcentcolon=(\mathcal{D}+\mathcal{D}^{\prime})/2.
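For concreteness, DJS can be computed for discrete distributions in a few lines; the sketch below (an illustration of ours, not part of the derivation) follows the definition above directly, using natural logarithms.

```python
import numpy as np

def kl(p, q):
    """Kullback-Leibler divergence between two discrete distributions."""
    mask = p > 0  # terms with p = 0 contribute nothing to the sum
    return np.sum(p[mask] * np.log(p[mask] / q[mask]))

def jsd(p, q):
    """Jensen-Shannon divergence, following Definition A.1."""
    m = (p + q) / 2
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

p = np.array([0.7, 0.2, 0.1])
q = np.array([0.3, 0.3, 0.4])
print(jsd(p, q))                          # symmetric and bounded by log(2)
print(abs(jsd(p, q) - jsd(q, p)) < 1e-12)  # True
```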

A.2 Consistency of the Weighted Domain Adaptation Loss (7)

For the sake of conciseness, we verify here that the domain adaptation training objective does lead to minimizing the Jensen-Shannon divergence between the weighted feature distribution of the source domain and the feature distribution of the target domain.

Lemma A.1.

Let p(x,y)p(x,y) and q(x)q(x) be two density distributions, and w(y)w(y) be a positive function such that p(y)w(y)𝑑y=1\int p(y)w(y)dy=1. Let pw(x)=p(x,y)w(y)𝑑yp^{w}(x)=\int p(x,y)w(y)dy denote the ww-reweighted marginal distribution of xx under pp. The minimum value of

I(d):=𝔼(x,y)p,xq[w(y)log(d(x))log(1d(x))]I(d)\vcentcolon=\mathbb{E}_{(x,y)\sim p,x^{\prime}\sim q}[-w(y)\log(d(x))-\log(1-d(x^{\prime}))]

is log(4)2DJS(pw(x)q(x))\log(4)-2D_{\text{JS}}(p^{w}(x)~{}\|~{}q(x)), and is attained for d(x)=pw(x)pw(x)+q(x)d^{*}(x)=\frac{p^{w}(x)}{p^{w}(x)+q(x)}.

Proof.

We see that:

I(d)\displaystyle I(d) =[w(y)log(d(x))+log(1d(x))]p(x,y)q(x)𝑑x𝑑x𝑑y\displaystyle=-\iiint[w(y)\log(d(x))+\log(1-d(x^{\prime}))]p(x,y)q(x^{\prime})dxdx^{\prime}dy (9)
=[w(y)p(x,y)𝑑y]log(d(x))+q(x)log(1d(x))dx\displaystyle=-\int[\int w(y)p(x,y)dy]\log(d(x))+q(x)\log(1-d(x))dx (10)
=pw(x)log(d(x))+q(x)log(1d(x))dx.\displaystyle=-\int p^{w}(x)\log(d(x))+q(x)\log(1-d(x))dx. (11)

From the last line, we follow the exact method from Goodfellow et al. [25] to see that point-wise in xx the minimum is attained for d(x)=pw(x)pw(x)+q(x)d^{*}(x)=\frac{p^{w}(x)}{p^{w}(x)+q(x)} and that I(d)=log(4)2DJS(pw(x)q(x))I(d^{*})=\log(4)-2D_{\text{JS}}(p^{w}(x)~{}\|~{}q(x)). ∎

Applying Lemma A.1 to 𝒟S(Z,Y)\mathcal{D}_{S}(Z,Y) and 𝒟T(Z)\mathcal{D}_{T}(Z) proves that the domain adaptation objective leads to minimizing DJS(𝒟Sw(Z)𝒟T(Z))D_{\text{JS}}(\mathcal{D}^{w}_{S}(Z)~{}\|~{}\mathcal{D}_{T}(Z)).
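The lemma can also be checked numerically on discrete toy distributions: plugging the optimal discriminator $d^{*}(x)=p^{w}(x)/(p^{w}(x)+q(x))$ into $I(d)$ should return $\log(4)-2D_{\text{JS}}(p^{w}\|q)$. The minimal sketch below (with arbitrary random distributions, purely for illustration) does exactly that.

```python
import numpy as np

rng = np.random.default_rng(0)

# Discrete toy setup: x and y each take a few values.
k_x, k_y = 5, 3
p_xy = rng.random((k_x, k_y)); p_xy /= p_xy.sum()   # joint p(x, y)
q_x = rng.random(k_x); q_x /= q_x.sum()             # marginal q(x)
p_y = p_xy.sum(axis=0)
w = rng.random(k_y); w /= (p_y * w).sum()           # ensures sum_y p(y) w(y) = 1

p_w = p_xy @ w                                      # p^w(x) = sum_y p(x, y) w(y)
d_star = p_w / (p_w + q_x)                          # optimal discriminator

# I(d*) = -E_{(x,y)~p}[w(y) log d*(x)] - E_{x'~q}[log(1 - d*(x'))]
I_star = -(p_xy * w[None, :] * np.log(d_star)[:, None]).sum() \
         - (q_x * np.log(1 - d_star)).sum()

m = (p_w + q_x) / 2
jsd = 0.5 * (p_w * np.log(p_w / m)).sum() + 0.5 * (q_x * np.log(q_x / m)).sum()
print(np.isclose(I_star, np.log(4) - 2 * jsd))       # True
```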

A.3 kk-class information-theoretic lower bound

In this section, we prove Theorem 2.1, which extends the previous result to the general kk-class classification problem. See 2.1

Proof.

We essentially follow the proof from Zhao et al. [66], except for Lemma 4.6, which needs to be adapted to the CDAN framework, and Lemma 4.7, which needs to be extended to kk-class classification.

Lemma 4.6 from Zhao et al. [66] states that DJS(𝒟SY^,𝒟TY^)DJS(𝒟SZ,𝒟TZ)D_{\text{JS}}(\mathcal{D}^{\widehat{Y}}_{S},\mathcal{D}^{\widehat{Y}}_{T})\leq D_{\text{JS}}(\mathcal{D}^{Z}_{S},\mathcal{D}^{Z}_{T}), which covers the case Z~=Z\widetilde{Z}=Z.

When Z~=Y^Z\widetilde{Z}=\widehat{Y}\otimes Z, let us first recall that we assume hh or equivalently Y^\widehat{Y} to be a one-hot prediction of the class. We have the following Markov chain:

X𝑔Zh~Z~𝑙Y^,X\overset{g}{\longrightarrow}Z\overset{\tilde{h}}{\longrightarrow}\widetilde{Z}\overset{l}{\longrightarrow}\widehat{Y},

where h~(z)=h(z)z\tilde{h}(z)=h(z)\otimes z and l:𝒴𝒵𝒴l:\mathcal{Y}\otimes\mathcal{Z}\to\mathcal{Y} returns the index of the non-zero block in h~(z)\tilde{h}(z). There is only one such block since hh outputs one-hot predictions, and its index corresponds to the class predicted by hh. Given the definition of ll, we clearly see that Y^\widehat{Y} is independent of XX given Z~\widetilde{Z}. We can now apply the same proof as in Zhao et al. [66] to conclude that:

DJS(𝒟SY^,𝒟TY^)DJS(𝒟SZ~,𝒟TZ~).D_{\text{JS}}(\mathcal{D}^{\widehat{Y}}_{S},\mathcal{D}^{\widehat{Y}}_{T})\leq D_{\text{JS}}(\mathcal{D}^{\widetilde{Z}}_{S},\mathcal{D}^{\widetilde{Z}}_{T}). (12)

It essentially boils down to a data-processing argument: the discrimination distance between two distributions cannot increase after the same (possibly stochastic) channel (kernel) is applied to both. Here, the channel corresponds to the (potentially randomized) function ll.

Remark

Additionally, we note that the above inequality holds for any Z~\tilde{Z} such that Y^=l(Z~)\widehat{Y}=l(\widetilde{Z}) for a (potentially randomized) function l. This covers any and all potential combinations of representations at various layers of the deep net, including the last layer (which corresponds to its predictions Y^\widehat{Y}).

Let us move to the second part of the proof. We wish to show that DJS(𝒟Y,𝒟Y^)ε(hg)D_{\text{JS}}(\mathcal{D}^{Y},\mathcal{D}^{\widehat{Y}})\leq\varepsilon(h\circ g), where 𝒟\mathcal{D} can be either 𝒟S\mathcal{D}_{S} or 𝒟T\mathcal{D}_{T}:

2DJS(𝒟Y,𝒟Y^)\displaystyle 2D_{\text{JS}}(\mathcal{D}^{Y},\mathcal{D}^{\widehat{Y}}) 𝒟Y𝒟Y^1\displaystyle\leq\|\mathcal{D}^{Y}-\mathcal{D}^{\widehat{Y}}\|_{1} [34]
=i=1k|𝒟(Y^=i)𝒟(Y=i)|\displaystyle=\displaystyle{\sum_{i=1}^{k}}|\mathcal{D}(\widehat{Y}=i)-\mathcal{D}(Y=i)|
=i=1k|j=1k𝒟(Y^=i|Y=j)𝒟(Y=j)𝒟(Y=i)|\displaystyle=\displaystyle{\sum_{i=1}^{k}}|\displaystyle{\sum_{j=1}^{k}}\mathcal{D}(\widehat{Y}=i|Y=j)\mathcal{D}(Y=j)-\mathcal{D}(Y=i)|
=i=1k|𝒟(Y^=i|Y=i)𝒟(Y=i)𝒟(Y=i)+ji𝒟(Y^=i|Y=j)𝒟(Y=j)|\displaystyle=\displaystyle{\sum_{i=1}^{k}}|\mathcal{D}(\widehat{Y}=i|Y=i)\mathcal{D}(Y=i)-\mathcal{D}(Y=i)+\displaystyle{\sum_{j\neq i}}\mathcal{D}(\widehat{Y}=i|Y=j)\mathcal{D}(Y=j)|
i=1k|𝒟(Y^=i|Y=i)1|𝒟(Y=i)+i=1kji𝒟(Y^=i|Y=j)𝒟(Y=j)\displaystyle\leq\displaystyle{\sum_{i=1}^{k}}|\mathcal{D}(\widehat{Y}=i|Y=i)-1|\mathcal{D}(Y=i)+\displaystyle{\sum_{i=1}^{k}}\displaystyle{\sum_{j\neq i}}\mathcal{D}(\widehat{Y}=i|Y=j)\mathcal{D}(Y=j)
=i=1k𝒟(Y^Y|Y=i)𝒟(Y=i)+j=1kij𝒟(Y^=i|Y=j)𝒟(Y=j)\displaystyle=\displaystyle{\sum_{i=1}^{k}}\mathcal{D}(\widehat{Y}\neq Y|Y=i)\mathcal{D}(Y=i)+\displaystyle{\sum_{j=1}^{k}}\displaystyle{\sum_{i\neq j}}\mathcal{D}(\widehat{Y}=i|Y=j)\mathcal{D}(Y=j)
=2i=1k𝒟(Y^Y|Y=i)𝒟(Y=i)=2𝒟(Y^Y)=2ε(hg).\displaystyle=2\displaystyle{\sum_{i=1}^{k}}\mathcal{D}(\widehat{Y}\neq Y|Y=i)\mathcal{D}(Y=i)=2\mathcal{D}(\widehat{Y}\neq Y)=2\varepsilon(h\circ g). (13)

We can now apply the triangle inequality to DJS\sqrt{D_{\text{JS}}}, which is a distance metric [21], called the Jensen-Shannon distance. This gives us:

DJS(𝒟SY,𝒟TY)\displaystyle\sqrt{D_{\text{JS}}(\mathcal{D}_{S}^{Y},\mathcal{D}_{T}^{Y})} DJS(𝒟SY,𝒟SY^)+DJS(𝒟SY^,𝒟TY^)+DJS(𝒟TY^,𝒟TY)\displaystyle\leq\sqrt{D_{\text{JS}}(\mathcal{D}_{S}^{Y},\mathcal{D}_{S}^{\widehat{Y}})}+\sqrt{D_{\text{JS}}(\mathcal{D}_{S}^{\widehat{Y}},\mathcal{D}_{T}^{\widehat{Y}})}+\sqrt{D_{\text{JS}}(\mathcal{D}_{T}^{\widehat{Y}},\mathcal{D}_{T}^{Y})}
DJS(𝒟SY,𝒟SY^)+DJS(𝒟SZ~,𝒟TZ~)+DJS(𝒟TY^,𝒟TY)\displaystyle\leq\sqrt{D_{\text{JS}}(\mathcal{D}_{S}^{Y},\mathcal{D}_{S}^{\widehat{Y}})}+\sqrt{D_{\text{JS}}(\mathcal{D}^{\widetilde{Z}}_{S},\mathcal{D}^{\widetilde{Z}}_{T})}+\sqrt{D_{\text{JS}}(\mathcal{D}_{T}^{\widehat{Y}},\mathcal{D}_{T}^{Y})}
εS(hg)+DJS(𝒟SZ~,𝒟TZ~)+εT(hg).\displaystyle\leq\sqrt{\varepsilon_{S}(h\circ g)}+\sqrt{D_{\text{JS}}(\mathcal{D}^{\widetilde{Z}}_{S},\mathcal{D}^{\widetilde{Z}}_{T})}+\sqrt{\varepsilon_{T}(h\circ g)}.

where we used Equation (12) for the second inequality and (13) for the third.

Finally, assuming that DJS(𝒟SY,𝒟TY)DJS(𝒟SZ~,𝒟TZ~)D_{\text{JS}}(\mathcal{D}_{S}^{Y},\mathcal{D}_{T}^{Y})\geq D_{\text{JS}}(\mathcal{D}_{S}^{\widetilde{Z}},\mathcal{D}_{T}^{\widetilde{Z}}), we get:

(DJS(𝒟SY,𝒟TY)DJS(𝒟SZ~,𝒟TZ~))2(εS(hg)+εT(hg))22(εS(hg)+εT(hg)).\displaystyle\left(\sqrt{D_{\text{JS}}(\mathcal{D}_{S}^{Y},\mathcal{D}_{T}^{Y})}-\sqrt{D_{\text{JS}}(\mathcal{D}^{\widetilde{Z}}_{S},\mathcal{D}^{\widetilde{Z}}_{T})}\right)^{2}\leq\left(\sqrt{\varepsilon_{S}(h\circ g)}+\sqrt{\varepsilon_{T}(h\circ g)}\right)^{2}\leq 2\left(\varepsilon_{S}(h\circ g)+\varepsilon_{T}(h\circ g)\right).

This concludes the proof. ∎

A.4 Proof of Theorem 3.1

To simplify the notation, we define the error gap Δε(Y^)\Delta_{\varepsilon}(\widehat{Y}) as follows:

Δε(Y^):=|εS(Y^)εT(Y^)|.\Delta_{\varepsilon}(\widehat{Y})\vcentcolon=|\varepsilon_{S}(\widehat{Y})-\varepsilon_{T}(\widehat{Y})|.

Also, in this case we use 𝒟a,a{S,T}\mathcal{D}_{a},~{}a\in\{S,T\} to mean the source and target distributions respectively. Before we give the proof of Theorem 3.1, we first prove the following two lemmas that will be used in the proof.

Lemma A.2.

Define γa,j:=𝒟a(Y=j),a{S,T},j[k]\gamma_{a,j}\vcentcolon=\mathcal{D}_{a}(Y=j),\forall a\in\{S,T\},\forall j\in[k], then αj,βj0\forall\alpha_{j},\beta_{j}\geq 0 such that αj+βj=1\alpha_{j}+\beta_{j}=1, and ij\forall i\neq j, the following upper bound holds:

|γS,j𝒟S(Y^=iY=j)γT,j𝒟T(Y^=iY=j)|\displaystyle|\gamma_{S,j}\mathcal{D}_{S}(\widehat{Y}=i\mid Y=j)-\gamma_{T,j}\mathcal{D}_{T}(\widehat{Y}=i\mid Y=j)|\leq
|γS,jγT,j|(αj𝒟S(Y^=iY=j)+βj𝒟T(Y^=iY=j))+γS,jβjΔCE(Y^)+γT,jαjΔCE(Y^).\displaystyle\hskip 22.76228pt|\gamma_{S,j}-\gamma_{T,j}|\cdot\left(\alpha_{j}\mathcal{D}_{S}(\widehat{Y}=i\mid Y=j)+\beta_{j}\mathcal{D}_{T}(\widehat{Y}=i\mid Y=j)\right)+\gamma_{S,j}\beta_{j}\Delta_{\mathrm{CE}}(\widehat{Y})+\gamma_{T,j}\alpha_{j}\Delta_{\mathrm{CE}}(\widehat{Y}).
Proof.

To make the derivation uncluttered, define 𝒟j(Y^=i):=αj𝒟S(Y^=iY=j)+βj𝒟T(Y^=iY=j)\mathcal{D}_{j}(\widehat{Y}=i)\vcentcolon=\alpha_{j}\mathcal{D}_{S}(\widehat{Y}=i\mid Y=j)+\beta_{j}\mathcal{D}_{T}(\widehat{Y}=i\mid Y=j) to be the mixture conditional probability of Y^=i\widehat{Y}=i given Y=jY=j, where the mixture weights are given by αj\alpha_{j} and βj\beta_{j}. Then, in order to prove the upper bound in the lemma, it suffices to bound the following term

||γS,j𝒟S(Y^=iY=j)γT,j𝒟T(Y^=iY=j)||(γS,jγT,j)𝒟j(Y^=i)||\displaystyle\left||\gamma_{S,j}\mathcal{D}_{S}(\widehat{Y}=i\mid Y=j)-\gamma_{T,j}\mathcal{D}_{T}(\widehat{Y}=i\mid Y=j)|-|(\gamma_{S,j}-\gamma_{T,j})\mathcal{D}_{j}(\widehat{Y}=i)|\right|
|(γS,j𝒟S(Y^=iY=j)γT,j𝒟T(Y^=iY=j))(γS,jγT,j)𝒟j(Y^=i)|\displaystyle\leq\left|\left(\gamma_{S,j}\mathcal{D}_{S}(\widehat{Y}=i\mid Y=j)-\gamma_{T,j}\mathcal{D}_{T}(\widehat{Y}=i\mid Y=j)\right)-(\gamma_{S,j}-\gamma_{T,j})\mathcal{D}_{j}(\widehat{Y}=i)\right|
=|γS,j(𝒟S(Y^=iY=j)𝒟j(Y^=i))γT,j(𝒟T(Y^=iY=j)𝒟j(Y^=i))|,\displaystyle=\left|\gamma_{S,j}(\mathcal{D}_{S}(\widehat{Y}=i\mid Y=j)-\mathcal{D}_{j}(\widehat{Y}=i))-\gamma_{T,j}(\mathcal{D}_{T}(\widehat{Y}=i\mid Y=j)-\mathcal{D}_{j}(\widehat{Y}=i))\right|,

following which we will have:

|γS,j𝒟S(Y^=iY=j)γT,j𝒟T(Y^=iY=j)||(γS,jγT,j)𝒟j(Y^=i)|\displaystyle~{}|\gamma_{S,j}\mathcal{D}_{S}(\widehat{Y}=i\mid Y=j)-\gamma_{T,j}\mathcal{D}_{T}(\widehat{Y}=i\mid Y=j)|\leq|(\gamma_{S,j}-\gamma_{T,j})\mathcal{D}_{j}(\widehat{Y}=i)|
+|γS,j(𝒟S(Y^=iY=j)𝒟j(Y^=i))γT,j(𝒟T(Y^=iY=j)𝒟j(Y^=i))|\displaystyle+\left|\gamma_{S,j}(\mathcal{D}_{S}(\widehat{Y}=i\mid Y=j)-\mathcal{D}_{j}(\widehat{Y}=i))-\gamma_{T,j}(\mathcal{D}_{T}(\widehat{Y}=i\mid Y=j)-\mathcal{D}_{j}(\widehat{Y}=i))\right|
|γS,jγT,j|(αj𝒟S(Y^=iY=j)+βj𝒟T(Y^=iY=j))\displaystyle\leq|\gamma_{S,j}-\gamma_{T,j}|\left(\alpha_{j}\mathcal{D}_{S}(\widehat{Y}=i\mid Y=j)+\beta_{j}\mathcal{D}_{T}(\widehat{Y}=i\mid Y=j)\right)
+γS,j|𝒟S(Y^=iY=j)𝒟j(Y^=i)|+γT,j|𝒟T(Y^=iY=j)𝒟j(Y^=i)|.\displaystyle+\gamma_{S,j}\left|\mathcal{D}_{S}(\widehat{Y}=i\mid Y=j)-\mathcal{D}_{j}(\widehat{Y}=i)\right|+\gamma_{T,j}\left|\mathcal{D}_{T}(\widehat{Y}=i\mid Y=j)-\mathcal{D}_{j}(\widehat{Y}=i)\right|.

To proceed, let us first simplify 𝒟S(Y^=iY=j)𝒟j(Y^=i)\mathcal{D}_{S}(\widehat{Y}=i\mid Y=j)-\mathcal{D}_{j}(\widehat{Y}=i). By definition of 𝒟j(Y^=i)=αj𝒟S(Y^=iY=j)+βj𝒟T(Y^=iY=j)\mathcal{D}_{j}(\widehat{Y}=i)=\alpha_{j}\mathcal{D}_{S}(\widehat{Y}=i\mid Y=j)+\beta_{j}\mathcal{D}_{T}(\widehat{Y}=i\mid Y=j), we know that:

𝒟S(Y^=iY=j)𝒟j(Y^=i)\displaystyle~{}\mathcal{D}_{S}(\widehat{Y}=i\mid Y=j)-\mathcal{D}_{j}(\widehat{Y}=i)
=\displaystyle= 𝒟S(Y^=iY=j)(αj𝒟S(Y^=iY=j)+βj𝒟T(Y^=iY=j))\displaystyle~{}\mathcal{D}_{S}(\widehat{Y}=i\mid Y=j)-\big{(}\alpha_{j}\mathcal{D}_{S}(\widehat{Y}=i\mid Y=j)+\beta_{j}\mathcal{D}_{T}(\widehat{Y}=i\mid Y=j)\big{)}
=\displaystyle= (𝒟S(Y^=iY=j)αj𝒟S(Y^=iY=j))βj𝒟T(Y^=iY=j)\displaystyle~{}\big{(}\mathcal{D}_{S}(\widehat{Y}=i\mid Y=j)-\alpha_{j}\mathcal{D}_{S}(\widehat{Y}=i\mid Y=j)\big{)}-\beta_{j}\mathcal{D}_{T}(\widehat{Y}=i\mid Y=j)
=\displaystyle= βj(𝒟S(Y^=iY=j)𝒟T(Y^=iY=j)).\displaystyle~{}\beta_{j}\big{(}\mathcal{D}_{S}(\widehat{Y}=i\mid Y=j)-\mathcal{D}_{T}(\widehat{Y}=i\mid Y=j)\big{)}.

Similarly, for the second term 𝒟T(Y^=iY=j)𝒟j(Y^=i)\mathcal{D}_{T}(\widehat{Y}=i\mid Y=j)-\mathcal{D}_{j}(\widehat{Y}=i), we can show that:

𝒟T(Y^=iY=j)𝒟j(Y^=i)=αj(𝒟T(Y^=iY=j)𝒟S(Y^=iY=j)).\mathcal{D}_{T}(\widehat{Y}=i\mid Y=j)-\mathcal{D}_{j}(\widehat{Y}=i)=\alpha_{j}\big{(}\mathcal{D}_{T}(\widehat{Y}=i\mid Y=j)-\mathcal{D}_{S}(\widehat{Y}=i\mid Y=j)\big{)}.

Plugging these two identities into the above, we can continue the analysis with

|γS,j(𝒟S(Y^=iY=j)𝒟j(Y^=i))γT,j(𝒟T(Y^=iY=j)𝒟j(Y^=i))|\displaystyle\left|\gamma_{S,j}(\mathcal{D}_{S}(\widehat{Y}=i\mid Y=j)-\mathcal{D}_{j}(\widehat{Y}=i))-\gamma_{T,j}(\mathcal{D}_{T}(\widehat{Y}=i\mid Y=j)-\mathcal{D}_{j}(\widehat{Y}=i))\right|
=|γS,jβj(𝒟S(Y^=iY=j)𝒟T(Y^=iY=j))γT,jαj(𝒟T(Y^=iY=j)𝒟S(Y^=iY=j))|\displaystyle=\left|\gamma_{S,j}\beta_{j}(\mathcal{D}_{S}(\widehat{Y}=i\mid Y=j)-\mathcal{D}_{T}(\widehat{Y}=i\mid Y=j))-\gamma_{T,j}\alpha_{j}(\mathcal{D}_{T}(\widehat{Y}=i\mid Y=j)-\mathcal{D}_{S}(\widehat{Y}=i\mid Y=j))\right|
|γS,jβj(𝒟S(Y^=iY=j)𝒟T(Y^=iY=j))|+|γT,jαj(𝒟T(Y^=iY=j)𝒟S(Y^=iY=j))|\displaystyle\leq\left|\gamma_{S,j}\beta_{j}(\mathcal{D}_{S}(\widehat{Y}=i\mid Y=j)-\mathcal{D}_{T}(\widehat{Y}=i\mid Y=j))\right|+\left|\gamma_{T,j}\alpha_{j}(\mathcal{D}_{T}(\widehat{Y}=i\mid Y=j)-\mathcal{D}_{S}(\widehat{Y}=i\mid Y=j))\right|
γS,jβjΔCE(Y^)+γT,jαjΔCE(Y^).\displaystyle\leq\gamma_{S,j}\beta_{j}\Delta_{\mathrm{CE}}(\widehat{Y})+\gamma_{T,j}\alpha_{j}\Delta_{\mathrm{CE}}(\widehat{Y}).

The first inequality holds by the triangle inequality and the second by the definition of the conditional error gap. Combining all the inequalities above completes the proof. ∎

We are now ready to prove the theorem: See 3.1

Proof of Theorem 3.1.

First, by the law of total probability, it is easy to verify that the following identity holds for a{S,T}a\in\{S,T\}:

𝒟a(Y^Y)\displaystyle\mathcal{D}_{a}(\widehat{Y}\neq Y) =ij𝒟a(Y^=i,Y=j)=ijγa,j𝒟a(Y^=iY=j).\displaystyle=\sum_{i\neq j}\mathcal{D}_{a}(\widehat{Y}=i,Y=j)=\sum_{i\neq j}\gamma_{a,j}\mathcal{D}_{a}(\widehat{Y}=i\mid Y=j).

Using this identity, to bound the error gap, we have:

|𝒟S(YY^)𝒟T(YY^)|\displaystyle~{}|\mathcal{D}_{S}(Y\neq\widehat{Y})-\mathcal{D}_{T}(Y\neq\widehat{Y})|
=\displaystyle= |ijγS,j𝒟S(Y^=iY=j)ijγT,j𝒟T(Y^=iY=j)|\displaystyle~{}\big{|}\sum_{i\neq j}\gamma_{S,j}\mathcal{D}_{S}(\widehat{Y}=i\mid Y=j)-\sum_{i\neq j}\gamma_{T,j}\mathcal{D}_{T}(\widehat{Y}=i\mid Y=j)\big{|}
\displaystyle\leq ij|γS,j𝒟S(Y^=iY=j)γT,j𝒟T(Y^=iY=j)|.\displaystyle~{}\sum_{i\neq j}\big{|}\gamma_{S,j}\mathcal{D}_{S}(\widehat{Y}=i\mid Y=j)-\gamma_{T,j}\mathcal{D}_{T}(\widehat{Y}=i\mid Y=j)\big{|}.

Invoking Lemma A.2 to bound the above terms, and since j[k],γS,j,γT,j[0,1]\forall j\in[k],\gamma_{S,j},\gamma_{T,j}\in[0,1], αj+βj=1\alpha_{j}+\beta_{j}=1, we get:

|𝒟S(YY^)𝒟T(YY^)|\displaystyle~{}|\mathcal{D}_{S}(Y\neq\widehat{Y})-\mathcal{D}_{T}(Y\neq\widehat{Y})|
\displaystyle\leq ij|γS,j𝒟S(Y^=iY=j)γT,j𝒟T(Y^=iY=j)|\displaystyle~{}\sum_{i\neq j}\big{|}\gamma_{S,j}\mathcal{D}_{S}(\widehat{Y}=i\mid Y=j)-\gamma_{T,j}\mathcal{D}_{T}(\widehat{Y}=i\mid Y=j)\big{|}
\displaystyle\leq ij|γS,jγT,j|(αj𝒟S(Y^=iY=j)+βj𝒟T(Y^=iY=j))+γS,jβjΔCE(Y^)+γT,jαjΔCE(Y^)\displaystyle~{}\sum_{i\neq j}|\gamma_{S,j}-\gamma_{T,j}|\cdot\left(\alpha_{j}\mathcal{D}_{S}(\widehat{Y}=i\mid Y=j)+\beta_{j}\mathcal{D}_{T}(\widehat{Y}=i\mid Y=j)\right)+\gamma_{S,j}\beta_{j}\Delta_{\mathrm{CE}}(\widehat{Y})+\gamma_{T,j}\alpha_{j}\Delta_{\mathrm{CE}}(\widehat{Y})
\displaystyle\leq ij|γS,jγT,j|(αj𝒟S(Y^=iY=j)+βj𝒟T(Y^=iY=j))+γS,jΔCE(Y^)+γT,jΔCE(Y^)\displaystyle~{}\sum_{i\neq j}|\gamma_{S,j}-\gamma_{T,j}|\cdot\left(\alpha_{j}\mathcal{D}_{S}(\widehat{Y}=i\mid Y=j)+\beta_{j}\mathcal{D}_{T}(\widehat{Y}=i\mid Y=j)\right)+\gamma_{S,j}\Delta_{\mathrm{CE}}(\widehat{Y})+\gamma_{T,j}\Delta_{\mathrm{CE}}(\widehat{Y})
=\displaystyle= ij|γS,jγT,j|(αj𝒟S(Y^=iY=j)+βj𝒟T(Y^=iY=j))+i=1kjiγS,jΔCE(Y^)+γT,jΔCE(Y^)\displaystyle~{}\sum_{i\neq j}|\gamma_{S,j}-\gamma_{T,j}|\cdot\left(\alpha_{j}\mathcal{D}_{S}(\widehat{Y}=i\mid Y=j)+\beta_{j}\mathcal{D}_{T}(\widehat{Y}=i\mid Y=j)\right)+\sum_{i=1}^{k}\sum_{j\neq i}\gamma_{S,j}\Delta_{\mathrm{CE}}(\widehat{Y})+\gamma_{T,j}\Delta_{\mathrm{CE}}(\widehat{Y})
=\displaystyle= ij|γS,jγT,j|(αj𝒟S(Y^=iY=j)+βj𝒟T(Y^=iY=j))+2(k1)ΔCE(Y^).\displaystyle~{}\sum_{i\neq j}|\gamma_{S,j}-\gamma_{T,j}|\cdot\left(\alpha_{j}\mathcal{D}_{S}(\widehat{Y}=i\mid Y=j)+\beta_{j}\mathcal{D}_{T}(\widehat{Y}=i\mid Y=j)\right)+2(k-1)\Delta_{\mathrm{CE}}(\widehat{Y}).
Note that the above holds αj,βj0\forall\alpha_{j},\beta_{j}\geq 0 such that αj+βj=1\alpha_{j}+\beta_{j}=1. By choosing αj=1,j[k]\alpha_{j}=1,\forall j\in[k] and βj=0,j[k]\beta_{j}=0,\forall j\in[k], we have:
=\displaystyle= ij|γS,jγT,j|𝒟S(Y^=iY=j)+2(k1)ΔCE(Y^)\displaystyle~{}\sum_{i\neq j}|\gamma_{S,j}-\gamma_{T,j}|\cdot\mathcal{D}_{S}(\widehat{Y}=i\mid Y=j)+2(k-1)\Delta_{\mathrm{CE}}(\widehat{Y})
=\displaystyle= j=1k|γS,jγT,j|(i=1,ijk𝒟S(Y^=iY=j))+2(k1)ΔCE(Y^)\displaystyle~{}\sum_{j=1}^{k}|\gamma_{S,j}-\gamma_{T,j}|\cdot\left(\sum_{i=1,i\neq j}^{k}\mathcal{D}_{S}(\widehat{Y}=i\mid Y=j)\right)+2(k-1)\Delta_{\mathrm{CE}}(\widehat{Y})
=\displaystyle= j=1k|γS,jγT,j|𝒟S(Y^YY=j)+2(k1)ΔCE(Y^)\displaystyle~{}\sum_{j=1}^{k}|\gamma_{S,j}-\gamma_{T,j}|\cdot\mathcal{D}_{S}(\widehat{Y}\neq Y\mid Y=j)+2(k-1)\Delta_{\mathrm{CE}}(\widehat{Y})
\displaystyle\leq 𝒟SY𝒟TY1BER𝒟S(Y^Y)+2(k1)ΔCE(Y^),\displaystyle~{}\|\mathcal{D}_{S}^{Y}-\mathcal{D}_{T}^{Y}\|_{1}\cdot\mathrm{BER}_{\mathcal{D}_{S}}(\widehat{Y}~{}\|~{}Y)+2(k-1)\Delta_{\mathrm{CE}}(\widehat{Y}),

where the last line is due to Hölder’s inequality, completing the proof. ∎

A.5 Proof of Theorem 3.2

See 3.2

Proof.

First, by the law of total probability, we have:

εS(Y^)+εT(Y^)\displaystyle\varepsilon_{S}(\widehat{Y})+\varepsilon_{T}(\widehat{Y}) =𝒟S(YY^)+𝒟T(YY^)\displaystyle=\mathcal{D}_{S}(Y\neq\widehat{Y})+\mathcal{D}_{T}(Y\neq\widehat{Y})
=j=1kij𝒟S(Y^=i|Y=j)𝒟S(Y=j)+𝒟T(Y^=i|Y=j)𝒟T(Y=j).\displaystyle=\displaystyle{\sum_{j=1}^{k}}\displaystyle{\sum_{i\neq j}}\mathcal{D}_{S}(\widehat{Y}=i|Y=j)\mathcal{D}_{S}(Y=j)+\mathcal{D}_{T}(\widehat{Y}=i|Y=j)\mathcal{D}_{T}(Y=j).

Now, since Y^=(hg)(X)=h(Z)\widehat{Y}=(h\circ g)(X)=h(Z), Y^\widehat{Y} is a function of ZZ. Given the generalized label shift assumption, this guarantees that:

y,y𝒴,𝒟S(Y^=yY=y)=𝒟T(Y^=yY=y).\displaystyle\forall y,y^{\prime}\in\mathcal{Y},\quad\mathcal{D}_{S}(\widehat{Y}=y^{\prime}\mid Y=y)=\mathcal{D}_{T}(\widehat{Y}=y^{\prime}\mid Y=y).
Thus:
εS(Y^)+εT(Y^)\displaystyle\varepsilon_{S}(\widehat{Y})+\varepsilon_{T}(\widehat{Y}) =j=1kij𝒟S(Y^=i|Y=j)(𝒟S(Y=j)+𝒟T(Y=j))\displaystyle=\displaystyle{\sum_{j=1}^{k}}\displaystyle{\sum_{i\neq j}}\mathcal{D}_{S}(\widehat{Y}=i|Y=j)(\mathcal{D}_{S}(Y=j)+\mathcal{D}_{T}(Y=j))
=j[k]𝒟S(Y^YY=j)(𝒟S(Y=j)+𝒟T(Y=j))\displaystyle=\sum_{j\in[k]}\mathcal{D}_{S}(\widehat{Y}\neq Y\mid Y=j)\cdot(\mathcal{D}_{S}(Y=j)+\mathcal{D}_{T}(Y=j))
maxj[k]𝒟S(Y^YY=j)j[k]𝒟S(Y=j)+𝒟T(Y=j)\displaystyle\leq\max_{j\in[k]}\mathcal{D}_{S}(\widehat{Y}\neq Y\mid Y=j)\cdot\sum_{j\in[k]}\mathcal{D}_{S}(Y=j)+\mathcal{D}_{T}(Y=j)
=2BER𝒟S(Y^Y).\displaystyle=2\mathrm{BER}_{\mathcal{D}_{S}}(\widehat{Y}~{}\|~{}Y).\qed

A.6 Proof of Lemma 3.1

See 3.1

Proof.

From GLSGLS, we know that 𝒟S(ZY=y)=𝒟T(ZY=y)\mathcal{D}_{S}(Z\mid Y=y)=\mathcal{D}_{T}(Z\mid Y=y). Applying any function h~\tilde{h} to ZZ will maintain that equality (in particular h~(Z)=Y~Z\tilde{h}(Z)=\tilde{Y}\otimes Z). Using that fact and Eq. (4) on the second line gives:

𝒟T(Z~)=\displaystyle\mathcal{D}_{T}(\tilde{Z})= y𝒴𝒟T(Y=y)𝒟T(Z~Y=y)\displaystyle~{}\sum_{y\in\mathcal{Y}}\mathcal{D}_{T}(Y=y)\cdot\mathcal{D}_{T}(\tilde{Z}\mid Y=y)
=\displaystyle= y𝒴𝐰y𝒟S(Y=y)𝒟S(Z~Y=y)\displaystyle~{}\sum_{y\in\mathcal{Y}}\mathbf{w}_{y}\cdot\mathcal{D}_{S}(Y=y)\cdot\mathcal{D}_{S}(\tilde{Z}\mid Y=y)
=\displaystyle= y𝒴𝐰y𝒟S(Z~,Y=y).\displaystyle~{}\sum_{y\in\mathcal{Y}}\mathbf{w}_{y}\cdot\mathcal{D}_{S}(\tilde{Z},Y=y).\qed (14)

A.7 Proof of Theorem 3.3

See 3.3

Proof.

Starting from the condition that 𝒟T(Z)=𝒟S𝐰(Z)\mathcal{D}_{T}(Z)=\mathcal{D}_{S}^{\mathbf{w}}(Z), and by definition of 𝒟S𝐰(Z)\mathcal{D}_{S}^{\mathbf{w}}(Z), we have:

𝒟T(Z)=y𝒴𝒟T(Y=y)𝒟S(Y=y)𝒟S(Z,Y=y)\displaystyle\mathcal{D}_{T}(Z)=\sum_{y\in\mathcal{Y}}\frac{\mathcal{D}_{T}(Y=y)}{\mathcal{D}_{S}(Y=y)}\mathcal{D}_{S}(Z,Y=y)
\displaystyle\iff 𝒟T(Z)=y𝒴𝒟T(Y=y)𝒟S(ZY=y)\displaystyle\mathcal{D}_{T}(Z)=\sum_{y\in\mathcal{Y}}\mathcal{D}_{T}(Y=y)\mathcal{D}_{S}(Z\mid Y=y)
\displaystyle\iff y𝒴𝒟T(Y=y)𝒟T(ZY=y)=y𝒴𝒟T(Y=y)𝒟S(ZY=y).\displaystyle\sum_{y\in\mathcal{Y}}\mathcal{D}_{T}(Y=y)\mathcal{D}_{T}(Z\mid Y=y)=\sum_{y\in\mathcal{Y}}\mathcal{D}_{T}(Y=y)\mathcal{D}_{S}(Z\mid Y=y).

Note that the above equation holds for all measurable subsets of 𝒵\mathcal{Z}. Now by the assumption that 𝒵=y𝒴𝒵y\mathcal{Z}=\cup_{y\in\mathcal{Y}}\mathcal{Z}_{y} is a partition of 𝒵\mathcal{Z}, consider 𝒵y\mathcal{Z}_{y^{\prime}}:

y𝒴𝒟T(Y=y)𝒟T(Z𝒵yY=y)=y𝒴𝒟T(Y=y)𝒟S(Z𝒵yY=y).\sum_{y\in\mathcal{Y}}\mathcal{D}_{T}(Y=y)\mathcal{D}_{T}(Z\in\mathcal{Z}_{y^{\prime}}\mid Y=y)=\sum_{y\in\mathcal{Y}}\mathcal{D}_{T}(Y=y)\mathcal{D}_{S}(Z\in\mathcal{Z}_{y^{\prime}}\mid Y=y).

Due to the assumption 𝒟S(Z𝒵yY=y)=𝒟T(Z𝒵yY=y)=1\mathcal{D}_{S}(Z\in\mathcal{Z}_{y}\mid Y=y)=\mathcal{D}_{T}(Z\in\mathcal{Z}_{y}\mid Y=y)=1, we know that yy\forall y^{\prime}\neq y, 𝒟T(Z𝒵yY=y)=𝒟S(Z𝒵yY=y)=0\mathcal{D}_{T}(Z\in\mathcal{Z}_{y^{\prime}}\mid Y=y)=\mathcal{D}_{S}(Z\in\mathcal{Z}_{y^{\prime}}\mid Y=y)=0. This shows that both the supports of 𝒟S(ZY=y)\mathcal{D}_{S}(Z\mid Y=y) and 𝒟T(ZY=y)\mathcal{D}_{T}(Z\mid Y=y) are contained in 𝒵y\mathcal{Z}_{y}. Now consider an arbitrary measurable set E𝒵yE\subseteq\mathcal{Z}_{y}, since y𝒴𝒵y\cup_{y\in\mathcal{Y}}\mathcal{Z}_{y} is a partition of 𝒵\mathcal{Z}, we know that

𝒟S(ZEY=y)=𝒟T(ZEY=y)=0,yy.\mathcal{D}_{S}(Z\in E\mid Y=y^{\prime})=\mathcal{D}_{T}(Z\in E\mid Y=y^{\prime})=0,\quad\forall y^{\prime}\neq y.

Plug ZEZ\in E into the following identity:

y𝒴𝒟T(Y=y)𝒟T(ZEY=y)=y𝒴𝒟T(Y=y)𝒟S(ZEY=y)\displaystyle\sum_{y\in\mathcal{Y}}\mathcal{D}_{T}(Y=y)\mathcal{D}_{T}(Z\in E\mid Y=y)=\sum_{y\in\mathcal{Y}}\mathcal{D}_{T}(Y=y)\mathcal{D}_{S}(Z\in E\mid Y=y)
\displaystyle\implies 𝒟T(Y=y)𝒟T(ZEY=y)=𝒟T(Y=y)𝒟S(ZEY=y)\displaystyle~{}\mathcal{D}_{T}(Y=y)\mathcal{D}_{T}(Z\in E\mid Y=y)=\mathcal{D}_{T}(Y=y)\mathcal{D}_{S}(Z\in E\mid Y=y)
\displaystyle\implies 𝒟T(ZEY=y)=𝒟S(ZEY=y),\displaystyle~{}\mathcal{D}_{T}(Z\in E\mid Y=y)=\mathcal{D}_{S}(Z\in E\mid Y=y),

where the last line holds because 𝒟T(Y=y)0\mathcal{D}_{T}(Y=y)\neq 0. Since the choice of EE is arbitrary, this shows that 𝒟S(ZY=y)=𝒟T(ZY=y)\mathcal{D}_{S}(Z\mid Y=y)=\mathcal{D}_{T}(Z\mid Y=y), which completes the proof. ∎

A.8 Sufficient Conditions for GLSGLS

See 3.4

Proof.

To prove the above upper bound, let us first fix a y𝒴y\in\mathcal{Y} and fix a classifier Y^=h(Z)\widehat{Y}=h(Z) for some h:𝒵𝒴h:\mathcal{Z}\to\mathcal{Y}. Now, for any measurable subset E𝒵E\subseteq\mathcal{Z}, we would like to upper bound the following quantity:

|𝒟S(ZEY=y)\displaystyle|\mathcal{D}_{S}(Z\in E\mid Y=y) 𝒟T(ZEY=y)|\displaystyle-\mathcal{D}_{T}(Z\in E\mid Y=y)|
=1𝒟T(Y=y)|𝒟S(ZE,Y=y)𝐰y𝒟T(ZE,Y=y)|\displaystyle=\frac{1}{\mathcal{D}_{T}(Y=y)}\cdot|\mathcal{D}_{S}(Z\in E,Y=y)\mathbf{w}_{y}-\mathcal{D}_{T}(Z\in E,Y=y)|
1γ|𝒟S(ZE,Y=y)𝐰y𝒟T(ZE,Y=y)|.\displaystyle\leq\frac{1}{\gamma}\cdot|\mathcal{D}_{S}(Z\in E,Y=y)\mathbf{w}_{y}-\mathcal{D}_{T}(Z\in E,Y=y)|.

Hence it suffices to upper bound |𝒟S(ZE,Y=y)𝐰y𝒟T(ZE,Y=y)||\mathcal{D}_{S}(Z\in E,Y=y)\mathbf{w}_{y}-\mathcal{D}_{T}(Z\in E,Y=y)|. To do so, consider the following decomposition:

|𝒟T(ZE,Y=y)𝒟S(ZE,Y=y)𝐰y|=\displaystyle|\mathcal{D}_{T}(Z\in E,Y=y)-\mathcal{D}_{S}(Z\in E,Y=y)\mathbf{w}_{y}|= |𝒟T(ZE,Y=y)𝒟T(ZE,Y^=y)\displaystyle~{}|\mathcal{D}_{T}(Z\in E,Y=y)-\mathcal{D}_{T}(Z\in E,\widehat{Y}=y)
+𝒟T(ZE,Y^=y)𝒟S𝐰(ZE,Y^=y)\displaystyle+\mathcal{D}_{T}(Z\in E,\widehat{Y}=y)-\mathcal{D}_{S}^{\mathbf{w}}(Z\in E,\widehat{Y}=y)
+𝒟S𝐰(ZE,Y^=y)𝒟S(ZE,Y=y)𝐰y|\displaystyle+\mathcal{D}_{S}^{\mathbf{w}}(Z\in E,\widehat{Y}=y)-\mathcal{D}_{S}(Z\in E,Y=y)\mathbf{w}_{y}|
\displaystyle\leq |𝒟T(ZE,Y=y)𝒟T(ZE,Y^=y)|\displaystyle~{}|\mathcal{D}_{T}(Z\in E,Y=y)-\mathcal{D}_{T}(Z\in E,\widehat{Y}=y)|
+|𝒟T(ZE,Y^=y)𝒟S𝐰(ZE,Y^=y)|\displaystyle+|\mathcal{D}_{T}(Z\in E,\widehat{Y}=y)-\mathcal{D}_{S}^{\mathbf{w}}(Z\in E,\widehat{Y}=y)|
+|𝒟S𝐰(ZE,Y^=y)𝒟S(ZE,Y=y)𝐰y|.\displaystyle+|\mathcal{D}_{S}^{\mathbf{w}}(Z\in E,\widehat{Y}=y)-\mathcal{D}_{S}(Z\in E,Y=y)\mathbf{w}_{y}|.

We bound the above three terms in turn. First, consider |𝒟T(ZE,Y=y)𝒟T(ZE,Y^=y)||\mathcal{D}_{T}(Z\in E,Y=y)-\mathcal{D}_{T}(Z\in E,\widehat{Y}=y)|:

|𝒟T(ZE,Y=y)\displaystyle|\mathcal{D}_{T}(Z\in E,Y=y) 𝒟T(ZE,Y^=y)|\displaystyle-\mathcal{D}_{T}(Z\in E,\widehat{Y}=y)|
=\displaystyle= |y𝒟T(ZE,Y=y,Y^=y)y𝒟T(ZE,Y^=y,Y=y)|\displaystyle~{}|\sum_{y^{\prime}}\mathcal{D}_{T}(Z\in E,Y=y,\widehat{Y}=y^{\prime})-\sum_{y^{\prime}}\mathcal{D}_{T}(Z\in E,\widehat{Y}=y,Y=y^{\prime})|
\displaystyle\leq yy|𝒟T(ZE,Y=y,Y^=y)𝒟T(ZE,Y^=y,Y=y)|\displaystyle~{}\sum_{y^{\prime}\neq y}|\mathcal{D}_{T}(Z\in E,Y=y,\widehat{Y}=y^{\prime})-\mathcal{D}_{T}(Z\in E,\widehat{Y}=y,Y=y^{\prime})|
\displaystyle\leq yy𝒟T(ZE,Y=y,Y^=y)+𝒟T(ZE,Y^=y,Y=y)\displaystyle~{}\sum_{y^{\prime}\neq y}\mathcal{D}_{T}(Z\in E,Y=y,\widehat{Y}=y^{\prime})+\mathcal{D}_{T}(Z\in E,\widehat{Y}=y,Y=y^{\prime})
\displaystyle\leq yy𝒟T(Y=y,Y^=y)+𝒟T(Y^=y,Y=y)\displaystyle~{}\sum_{y^{\prime}\neq y}\mathcal{D}_{T}(Y=y,\widehat{Y}=y^{\prime})+\mathcal{D}_{T}(\widehat{Y}=y,Y=y^{\prime})
\displaystyle\leq 𝒟T(YY^)\displaystyle~{}\mathcal{D}_{T}(Y\neq\widehat{Y})
=\displaystyle= εT(Y^),\displaystyle~{}\varepsilon_{T}(\widehat{Y}),

where the last inequality holds because the error rate equals the sum of all the off-diagonal elements of the confusion matrix, while the sum above only covers the off-diagonal elements of two of its slices (the y-th row and the y-th column). Similarly, we can bound the third term as follows:

|𝒟S𝐰(ZE,Y^=y)𝒟S(ZE,Y=y)𝐰y|\displaystyle|\mathcal{D}_{S}^{\mathbf{w}}(Z\in E,\widehat{Y}=y)-\mathcal{D}_{S}(Z\in E,Y=y)\mathbf{w}_{y}|
=|y𝒟S(ZE,Y^=y,Y=y)𝐰yy𝒟S(ZE,Y^=y,Y=y)𝐰y|\displaystyle\hskip 85.35826pt=|\displaystyle{\sum_{y^{\prime}}}\mathcal{D}_{S}(Z\in E,\widehat{Y}=y,Y=y^{\prime})\mathbf{w}_{y^{\prime}}-\displaystyle{\sum_{y^{\prime}}}\mathcal{D}_{S}(Z\in E,\widehat{Y}=y^{\prime},Y=y)\mathbf{w}_{y}|
|yy𝒟S(ZE,Y^=y,Y=y)𝐰y𝒟S(ZE,Y^=y,Y=y)𝐰y|\displaystyle\hskip 85.35826pt\leq|\displaystyle{\sum_{y^{\prime}\neq y}}\mathcal{D}_{S}(Z\in E,\widehat{Y}=y,Y=y^{\prime})\mathbf{w}_{y^{\prime}}-\mathcal{D}_{S}(Z\in E,\widehat{Y}=y^{\prime},Y=y)\mathbf{w}_{y}|
𝐰Myy𝒟S(ZE,Y^=y,Y=y)+𝒟S(ZE,Y^=y,Y=y)\displaystyle\hskip 85.35826pt\leq\mathbf{w}_{M}\displaystyle{\sum_{y^{\prime}\neq y}}\mathcal{D}_{S}(Z\in E,\widehat{Y}=y,Y=y^{\prime})+\mathcal{D}_{S}(Z\in E,\widehat{Y}=y^{\prime},Y=y)
𝐰M𝒟S(ZE,Y^Y)\displaystyle\hskip 85.35826pt\leq\mathbf{w}_{M}\mathcal{D}_{S}(Z\in E,\widehat{Y}\neq Y)
𝐰MεS(Y^).\displaystyle\hskip 85.35826pt\leq\mathbf{w}_{M}\varepsilon_{S}(\widehat{Y}).

Now we bound the remaining (second) term. Recalling the definition of total variation, we have:

|𝒟T(ZE,Y^=y)\displaystyle|\mathcal{D}_{T}(Z\in E,\widehat{Y}=y) 𝒟S𝐰(ZE,Y^=y)|\displaystyle-\mathcal{D}_{S}^{\mathbf{w}}(Z\in E,\widehat{Y}=y)|
=|𝒟T(ZEZY^1(y))𝒟S𝐰(ZEZY^1(y))|\displaystyle=|\mathcal{D}_{T}(Z\in E\land Z\in\widehat{Y}^{-1}(y))-\mathcal{D}_{S}^{\mathbf{w}}(Z\in E\land Z\in\widehat{Y}^{-1}(y))|
supE is measurable|𝒟T(ZE)𝒟S𝐰(ZE)|\displaystyle\leq\sup_{E^{\prime}\text{ is measurable}}|\mathcal{D}_{T}(Z\in E^{\prime})-\mathcal{D}_{S}^{\mathbf{w}}(Z\in E^{\prime})|
=dTV(𝒟T(Z),𝒟S𝐰(Z)).\displaystyle=d_{\text{TV}}(\mathcal{D}_{T}(Z),\mathcal{D}_{S}^{\mathbf{w}}(Z)).

Combining the above three parts yields

|𝒟S(ZEY=y)𝒟T(ZEY=y)|1γ(𝐰MεS(Y^)+εT(Y^)+dTV(𝒟S𝐰(Z),𝒟T(Z))).\displaystyle|\mathcal{D}_{S}(Z\in E\mid Y=y)-\mathcal{D}_{T}(Z\in E\mid Y=y)|\leq\frac{1}{\gamma}\cdot\left(\mathbf{w}_{M}\varepsilon_{S}(\widehat{Y})+\varepsilon_{T}(\widehat{Y})+d_{\text{TV}}(\mathcal{D}_{S}^{\mathbf{w}}(Z),\mathcal{D}_{T}(Z))\right).

Now, since the choice of y𝒴y\in\mathcal{Y} and of the measurable subset EE on the LHS is arbitrary, this leads to

maxy𝒴supE|𝒟S(ZEY=y)\displaystyle\max_{y\in\mathcal{Y}}\sup_{E}|\mathcal{D}_{S}(Z\in E\mid Y=y) 𝒟T(ZEY=y)|\displaystyle-\mathcal{D}_{T}(Z\in E\mid Y=y)|
1γ(𝐰MεS(Y^)+εT(Y^)+dTV(𝒟S𝐰(Z),𝒟T(Z))).\displaystyle\leq\frac{1}{\gamma}\cdot\left(\mathbf{w}_{M}\varepsilon_{S}(\widehat{Y})+\varepsilon_{T}(\widehat{Y})+d_{\text{TV}}(\mathcal{D}_{S}^{\mathbf{w}}(Z),\mathcal{D}_{T}(Z))\right).

From Briët and Harremoës [10], we have:

dTV(𝒟S𝐰(Z),𝒟T(Z))8DJS(𝒟S𝐰(Z)||𝒟T(Z))d_{\text{TV}}(\mathcal{D}_{S}^{\mathbf{w}}(Z),\mathcal{D}_{T}(Z))\leq\sqrt{8D_{\text{JS}}(\mathcal{D}_{S}^{\mathbf{w}}(Z)~{}||~{}\mathcal{D}_{T}(Z))}

(the total variation and Jensen-Shannon distance are equivalent), which gives the results for Z~=Z\tilde{Z}=Z. Finally, noticing that zh(z)zz\to h(z)\otimes z is a bijection (h(z)h(z) sums to 1), we have:

DJS(𝒟S𝐰(Z)||𝒟T(Z))=DJS(𝒟S𝐰(Y^Z)||𝒟T(Y^Z)),D_{\text{JS}}(\mathcal{D}_{S}^{\mathbf{w}}(Z)~{}||~{}\mathcal{D}_{T}(Z))=D_{\text{JS}}(\mathcal{D}_{S}^{\mathbf{w}}(\widehat{Y}\otimes Z)~{}||~{}\mathcal{D}_{T}(\widehat{Y}\otimes Z)),

which completes the proof. ∎

Furthermore, since the above upper bound holds for any classifier Y^=h(Z)\widehat{Y}=h(Z), we even have:

maxy𝒴dTV(𝒟S(ZEY=y)\displaystyle\max_{y\in\mathcal{Y}}d_{\text{TV}}(\mathcal{D}_{S}(Z\in E\mid Y=y) ,𝒟T(ZEY=y))\displaystyle,\mathcal{D}_{T}(Z\in E\mid Y=y))
1γinfY^(𝐰MεS(Y^)+εT(Y^)+dTV(𝒟S𝐰(Z),𝒟T(Z))).\displaystyle\leq\frac{1}{\gamma}\cdot\inf_{\widehat{Y}}\left(\mathbf{w}_{M}\varepsilon_{S}(\widehat{Y})+\varepsilon_{T}(\widehat{Y})+d_{\text{TV}}(\mathcal{D}_{S}^{\mathbf{w}}(Z),\mathcal{D}_{T}(Z))\right).

A.9 Proof of Lemma 3.2

See 3.2

Proof.

Given (2), and with the joint hypothesis Y^=h(Z)\widehat{Y}=h(Z) over both source and target domains, it is straightforward to see that the induced conditional distributions over predicted labels match between the source and target domains, i.e.:

𝒟S(Y^=h(Z)Y=y)=𝒟T(Y^=h(Z)Y=y),y𝒴.\displaystyle\mathcal{D}_{S}(\widehat{Y}=h(Z)\mid Y=y)=\mathcal{D}_{T}(\widehat{Y}=h(Z)\mid Y=y),~{}\forall y\in\mathcal{Y}. (15)

This allows us to compute 𝝁y,y𝒴\boldsymbol{\mu}_{y},~{}\forall y\in\mathcal{Y} as

𝒟T(Y^=y)=\displaystyle\mathcal{D}_{T}(\widehat{Y}=y)= y𝒴𝒟T(Y^=yY=y)𝒟T(Y=y)\displaystyle~{}\sum_{y^{\prime}\in\mathcal{Y}}\mathcal{D}_{T}(\widehat{Y}=y\mid Y=y^{\prime})\cdot\mathcal{D}_{T}(Y=y^{\prime})
=\displaystyle= y𝒴𝒟S(Y^=yY=y)𝒟T(Y=y)\displaystyle~{}\sum_{y^{\prime}\in\mathcal{Y}}\mathcal{D}_{S}(\widehat{Y}=y\mid Y=y^{\prime})\cdot\mathcal{D}_{T}(Y=y^{\prime})
=\displaystyle= y𝒴𝒟S(Y^=y,Y=y)𝒟T(Y=y)𝒟S(Y=y)\displaystyle~{}\sum_{y^{\prime}\in\mathcal{Y}}\mathcal{D}_{S}(\widehat{Y}=y,Y=y^{\prime})\cdot\frac{\mathcal{D}_{T}(Y=y^{\prime})}{\mathcal{D}_{S}(Y=y^{\prime})}
=\displaystyle= y𝒴Cy,y𝐰y.\displaystyle~{}\sum_{y^{\prime}\in\mathcal{Y}}\textbf{C}_{y,y^{\prime}}\cdot\mathbf{w}_{y^{\prime}}.

where we used (15) for the second line. We thus have 𝝁=C𝐰\boldsymbol{\mu}=\textbf{C}\mathbf{w}, which concludes the proof. ∎
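For intuition, the identity $\boldsymbol{\mu}=\textbf{C}\mathbf{w}$ can be checked on a small synthetic example in which generalized label shift holds by construction: the sketch below (ours, with arbitrary random marginals and a shared conditional $\mathcal{D}(\widehat{Y}\mid Y)$) builds the confusion matrix and verifies the identity numerically.

```python
import numpy as np

rng = np.random.default_rng(1)
k = 4

# Shared conditional D(Yhat = i | Y = j), identical across domains under GLS.
cond = rng.random((k, k)); cond /= cond.sum(axis=0, keepdims=True)

p_S = rng.random(k); p_S /= p_S.sum()    # source label marginal D_S(Y)
p_T = rng.random(k); p_T /= p_T.sum()    # target label marginal D_T(Y)
w_true = p_T / p_S                        # true importance weights

C = cond * p_S[None, :]                   # C[i, j] = D_S(Yhat=i, Y=j)
mu = cond @ p_T                           # mu[i]   = D_T(Yhat=i)

print(np.allclose(C @ w_true, mu))                  # True: mu = C w
print(np.allclose(np.linalg.solve(C, mu), w_true))  # recovering w from C and mu
```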

A.10 \mathcal{F}-IPM for Distributional Alignment

In Table 4, we list different instances of IPMs obtained with different choices of the function class \mathcal{F} in the above definition, including the total variation distance, the Wasserstein-1 distance and the maximum mean discrepancy [27].

Table 4: List of IPMs with different \mathcal{F}. Lip\|\cdot\|_{\text{Lip}} denotes the Lipschitz seminorm and \mathcal{H} is a reproducing kernel Hilbert space (RKHS).
\mathcal{F} dd_{\mathcal{F}}
{f:f1}\{f:\|f\|_{\infty}\leq 1\} Total Variation
{f:fLip1}\{f:\|f\|_{\text{Lip}}\leq 1\} Wasserstein-1 distance
{f:f1}\{f:\|f\|_{\mathcal{H}}\leq 1\} Maximum mean discrepancy
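The sketch below (an illustration of ours, on simple discrete or one-dimensional data with arbitrary sample sizes and kernel bandwidth) shows how instances of these IPMs can be estimated in practice.

```python
import numpy as np
from scipy.stats import wasserstein_distance

rng = np.random.default_rng(2)

# --- IPM with F = {f : ||f||_inf <= 1} on discrete distributions.
# The supremum is attained at f = sign(p - q), giving the L1 distance between
# the pmfs (the total variation distance, up to the usual factor-of-2 convention).
p = np.array([0.7, 0.2, 0.1])
q = np.array([0.3, 0.3, 0.4])
d_tv_like = np.abs(p - q).sum()

# --- IPM with F = {f : ||f||_Lip <= 1} on 1-D samples: Wasserstein-1 distance.
xs = rng.normal(0.0, 1.0, size=1000)
xt = rng.normal(0.5, 1.2, size=1000)
d_w1 = wasserstein_distance(xs, xt)

# --- IPM with F = unit ball of an RKHS: (squared) MMD with a Gaussian kernel.
def mmd2(x, y, sigma=1.0):
    def k(a, b):
        d2 = (a[:, None] - b[None, :]) ** 2
        return np.exp(-d2 / (2 * sigma ** 2))
    return k(x, x).mean() + k(y, y).mean() - 2 * k(x, y).mean()

print(d_tv_like, d_w1, mmd2(xs, xt))
```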

Appendix B Experimentation Details

B.1 Computational Complexity

Our algorithms add negligible time and memory overhead compared to their base versions; in practice, they are indistinguishable from the underlying baselines:

  • Weight estimation requires storing the confusion matrix CC and the predictions μ\mu. This has a memory cost of O(k2)O(k^{2}), small compared to the size of a neural network that performs well on k classes.

  • The extra computational cost comes from solving the quadratic program (5), which only depends on the number of classes kk and is solved once per epoch (not per gradient step). For Office-Home, it is a 65×6565\times 65 QP, solved 100\approx 100 times. Its runtime is negligible compared to tens of thousands of gradient steps.

B.2 Description of the domain adaptation tasks

Digits We follow a widely used evaluation protocol [29, 41]. For the digits datasets MNIST (M, LeCun and Cortes [32]) and USPS (U, Dheeru and Karra [19]), we consider the DA tasks: M \rightarrow U and U \rightarrow M. Performance is evaluated on the 10,000/2,007 examples of the MNIST/USPS test sets.

Visda [55] is a sim-to-real domain adaptation task. The synthetic domain contains 2D rendering of 3D models captured at different angles and lighting conditions. The real domain is made of natural images. Overall, the training, validation and test domains contain 152,397, 55,388 and 5,534 images, from 12 different classes.

Office-31 [49] is one of the most popular datasets for domain adaptation. It contains 4,652 images from 31 classes. The samples come from three domains: Amazon (A), DSLR (D) and Webcam (W), which generate six possible transfer tasks, A \rightarrow D, A \rightarrow W, D \rightarrow A, D \rightarrow W, W \rightarrow A and W \rightarrow D, all of which we evaluate.

Office-Home [54] is a more complex dataset than Office-31. It consists of 15,500 images from 65 classes depicting objects in office and home environments. The images form four different domains: Artistic (A), Clipart (C), Product (P), and Real-World images (R). We evaluate the 12 possible domain adaptation tasks.

B.3 Full results on the domain adaptation tasks

Tables 5, 6, 7, 8, 9 and 10 show the detailed results of all the algorithms on each task of the domains described above. The performance we report is the best test accuracy obtained during training over a fixed number of epochs. We used that value for fairness with respect to the baselines (as shown in Figure 2 Left, the performance of DANN decreases as training progresses, due to the inappropriate matching of representations showcased in Theorem 2.1).

The subscript denotes the fraction of seeds for which our variant outperforms the base algorithm. More precisely, by outperform, we mean that for a given seed (which fixes the network initialization as well as the data being fed to the model) the variant achieves a higher accuracy on the test set than its base version. Doing so allows us to specifically assess the effect of the algorithm, all else being kept constant.

Table 5: Results on the Digits tasks. M and U stand for MNIST and USPS, the prefix ss denotes the experiment where the source domain is subsampled to increase DJS(𝒟SY,𝒟TY)D_{\text{JS}}(\mathcal{D}_{S}^{Y},\mathcal{D}_{T}^{Y}).

Method M \rightarrow U U \rightarrow M Avg. sM \rightarrow U sU \rightarrow M Avg.
No Ad. 79.04 75.30 77.17 76.02 75.32 75.67
DANN 90.65 95.66 93.15 79.03 87.46 83.24
IWDAN 93.28100% 96.52100% 94.90100% 91.77100% 93.32100% 92.54100%
IWDAN-O 93.73100% 96.81100% 95.27100% 92.50100% 96.42100% 94.46100%
CDAN 94.16 97.29 95.72 84.91 91.55 88.23
IWCDAN 94.3660% 97.45100% 95.9080% 93.42100% 93.03100% 93.22100%
IWCDAN-O 94.3480% 97.35100% 95.8590% 93.37100% 96.26100% 94.81100%
Table 6: Results on the Visda domain. The prefix ss denotes the experiment where the source domain is subsampled to increase DJS(𝒟SY,𝒟TY)D_{\text{JS}}(\mathcal{D}_{S}^{Y},\mathcal{D}_{T}^{Y}).

Method Visda sVisda
No Ad. 48.39 49.02
DANN 61.88 52.85
IWDAN 63.52100% 60.18100%
IWDAN-O 64.19100% 62.10100%
CDAN 65.60 60.19
IWCDAN 66.4960% 65.83100%
IWCDAN-O 68.15100% 66.85100%
JAN 56.98100% 50.64100%
IWJAN 57.56100% 57.12100%
IWJAN-O 61.48100% 61.30100%
Table 7: Results on the Office dataset.

Method A \rightarrow D A \rightarrow W D \rightarrow A D \rightarrow W W \rightarrow A W \rightarrow D Avg.
No DA 79.60 73.18 59.33 96.30 58.75 99.68 77.81
DANN 84.06 85.41 64.67 96.08 66.77 99.44 82.74
IWDAN 84.3060% 86.42100% 68.38100% 97.13100% 67.1660% 100.0100% 83.9087%
IWDAN-O 87.23100% 88.88100% 69.92100% 98.09100% 67.9680% 99.92100% 85.3397%
CDAN 89.56 93.01 71.25 99.24 70.32 100.0 87.23
IWCDAN 88.9160% 93.2360% 71.9080% 99.3080% 70.4360% 100.0100% 87.3073%
IWCDAN-O 90.0860% 94.52100% 73.11100% 99.3080% 71.83100% 100.0100% 88.1490%
JAN 85.94 85.66 70.50 97.48 71.5 99.72 85.13
IWJAN 87.68100% 84.860% 70.3660% 98.98100% 70.060% 100.0100% 85.3260%
IWJAN-O 89.68100% 89.18100% 71.96100% 99.02100% 73.0100% 100.0100% 87.14100%
Table 8: Results on the Subsampled Office dataset.

Method sA \rightarrow D sA \rightarrow W sD \rightarrow A sD \rightarrow W sW \rightarrow A sW \rightarrow D Avg.
No DA 75.82 70.69 56.82 95.32 58.35 97.31 75.72
DANN 75.46 77.66 56.58 93.76 57.51 96.02 76.17
IWDAN 81.61100% 88.43100% 65.00100% 96.98100% 64.86100% 98.72100% 82.60100%
IWDAN-O 84.94100% 91.17100% 68.44100% 97.74100% 64.57100% 99.60100% 84.41100%
CDAN 82.45 84.60 62.54 96.83 65.01 98.31 81.62
IWCDAN 86.59100% 87.30100% 66.45100% 97.69100% 66.34100% 98.92100% 83.88100%
IWCDAN-O 87.39100% 91.47100% 69.69100% 97.91100% 67.50100% 98.88100% 85.47100%
JAN 77.74 77.64 64.48 91.68 92.60 65.10 78.21
IWJAN 84.62100% 83.28100% 65.3080% 96.30100% 98.80100% 67.38100% 82.6197%
IWJAN-O 88.42100% 89.44100% 72.06100% 97.26100% 98.96100% 71.30100% 86.24100%
Table 9: Results on the Office-Home dataset.

Method A \rightarrow C A \rightarrow P A \rightarrow R C \rightarrow A C \rightarrow P C \rightarrow R
No DA 41.02 62.97 71.26 48.66 58.86 60.91
DANN 46.03 62.23 70.57 49.06 63.05 64.14
IWDAN 48.65100% 69.19100% 73.60100% 53.59100% 66.25100% 66.09100%
IWDAN-O 50.19100% 70.53100% 75.44100% 56.69100% 67.40100% 67.98100%
CDAN 49.00 69.23 74.55 54.46 68.23 68.9
IWCDAN 49.81100% 73.41100% 77.56100% 56.5100% 69.6480% 70.33100%
IWCDAN-O 52.31100% 74.54100% 78.46100% 60.33100% 70.78100% 71.47100%
JAN 41.64 67.20 73.12 51.02 62.52 64.46
IWJAN 41.120% 67.5680% 73.1460% 51.70100% 63.42100% 65.22100%
IWJAN-O 41.8880% 68.72100% 73.62100% 53.04100% 63.88100% 66.48100%
Method P \rightarrow A P \rightarrow C P \rightarrow R R \rightarrow A R \rightarrow C R \rightarrow P Avg.
No DA 47.1 35.94 68.27 61.79 44.42 75.5 56.39
DANN 48.29 44.06 72.62 63.81 53.93 77.64 59.62
IWDAN 52.81100% 46.2480% 73.97100% 64.90100% 54.0280% 77.96100% 62.2797%
IWDAN-O 59.33100% 48.28100% 76.37100% 69.42100% 56.09100% 78.45100% 64.68100%
CDAN 56.77 48.8 76.83 71.27 55.72 81.27 64.59
IWCDAN 58.99100% 48.410% 77.94100% 69.480% 54.730% 81.0760% 65.6670%
IWCDAN-O 62.60100% 50.73100% 78.88100% 72.44100% 57.79100% 81.3180% 67.6498%
JAN 54.5 40.36 73.10 64.54 45.98 76.58 59.59
IWJAN 55.2680% 40.3860% 73.0880% 64.4060% 45.680% 76.3640% 59.7863%
IWJAN-O 57.78100% 41.32100% 73.66100% 65.40100% 46.68100% 76.3620% 60.7392%
Table 10: Results on the subsampled Office-Home dataset.

Method A \rightarrow C A \rightarrow P A \rightarrow R C \rightarrow A C \rightarrow P C \rightarrow R
No DA 35.70 54.72 62.61 43.71 52.54 56.62
DANN 36.14 54.16 61.72 44.33 52.56 56.37
IWDAN 39.81100% 63.01100% 68.67100% 47.39100% 61.05100% 60.44100%
IWDAN-O 42.79100% 66.22100% 71.40100% 53.39100% 61.47100% 64.97100%
CDAN 38.90 56.80 64.77 48.02 60.07 61.17
IWCDAN 42.96100% 65.01100% 71.34100% 52.89100% 64.65100% 66.48100%
IWCDAN-O 45.76100% 68.61100% 73.18100% 56.88100% 66.61100% 68.48100%
JAN 34.52 56.86 64.54 46.18 56.84 59.06
IWJAN 36.24100% 61.00100% 66.34100% 48.66100% 59.92100% 61.88100%
IWJAN-O 37.46100% 62.68100% 66.88100% 49.82100% 60.22100% 62.54100%
Method P \rightarrow A P \rightarrow C P \rightarrow R R \rightarrow A R \rightarrow C R \rightarrow P Avg.
No DA 44.29 33.05 65.20 57.12 40.46 70.0
DANN 44.58 37.14 65.21 56.70 43.16 69.86 51.83
IWDAN 50.44100% 41.63100% 72.46100% 61.00100% 49.40100% 76.07100% 57.61100%
IWDAN-O 56.05100% 43.39100% 74.87100% 66.73100% 51.72100% 77.46100% 60.87100%
CDAN 49.65 41.36 70.24 62.35 46.98 74.69 56.25
IWCDAN 54.87100% 44.80100% 75.91100% 67.02100% 50.45100% 78.55100% 61.24100%
IWCDAN-O 59.63100% 46.98100% 77.54100% 69.24100% 53.77100% 78.11100% 63.73100%
JAN 50.64 37.24 69.98 58.72 40.64 72.00 53.94
IWJAN 52.92100% 37.68100% 70.88100% 60.32100% 41.54100% 73.26100% 55.89100%
IWJAN-O 56.54100% 39.66100% 71.78100% 62.36100% 44.56100% 73.76100% 57.36100%

B.4 Jensen-Shannon divergence of the original and subsampled domain adaptation datasets

Tables 11, 12 and 13 show the Jensen-Shannon divergence between label distributions, D_{\text{JS}}(\mathcal{D}_{S}^{Y}~{}\|~{}\mathcal{D}_{T}^{Y}), for our four datasets and their subsampled versions; rows correspond to the source domain, and columns to the target one. We recall that subsampling simply consists in keeping only 30%30\% of the samples from the first half of the classes of the source domain (which explains why the divergence is not symmetric for the subsampled datasets).

Table 11: Jensen-Shannon divergence between the label distributions of the Digits and Visda tasks.
MNIST USPS Real
MNIST 0 6.64e36.64\mathrm{e}{-3} -
USPS 6.64e36.64\mathrm{e}{-3} 0 -
Synth. - - 2.61e22.61\mathrm{e}{-2}
(a) Full Dataset
MNIST USPS Real
MNIST 0 6.52e26.52\mathrm{e}{-2} -
USPS 2.75e22.75\mathrm{e}{-2} 0 -
Synth. - - 6.81e26.81\mathrm{e}{-2}
(b) Subsampled
Table 12: Jensen-Shannon divergence between the label distributions of the Office-31 tasks.
Amazon DSLR Webcam
Amazon 0 1.76e21.76\mathrm{e}{-2} 9.52e39.52\mathrm{e}{-3}
DSLR 1.76e21.76\mathrm{e}{-2} 0 2.11e22.11\mathrm{e}{-2}
Webcam 9.52e39.52\mathrm{e}{-3} 2.11e22.11\mathrm{e}{-2} 0
(a) Full Dataset
Amazon DSLR Webcam
Amazon 0 6.25e26.25\mathrm{e}{-2} 4.61e24.61\mathrm{e}{-2}
DSLR 5.44e25.44\mathrm{e}{-2} 0 5.67e25.67\mathrm{e}{-2}
Webcam 5.15e25.15\mathrm{e}{-2} 7.05e27.05\mathrm{e}{-2} 0
(b) Subsampled
Table 13: Jensen-Shannon divergence between the label distributions of the Office-Home tasks.
Art Clipart Product Real World
Art 0 3.85e23.85\mathrm{e}{-2} 4.49e24.49\mathrm{e}{-2} 2.40e22.40\mathrm{e}{-2}
Clipart 3.85e23.85\mathrm{e}{-2} 0 2.33e22.33\mathrm{e}{-2} 2.14e22.14\mathrm{e}{-2}
Product 4.49e24.49\mathrm{e}{-2} 2.33e22.33\mathrm{e}{-2} 0 1.61e21.61\mathrm{e}{-2}
Real World 2.40e22.40\mathrm{e}{-2} 2.14e22.14\mathrm{e}{-2} 1.61e21.61\mathrm{e}{-2} 0
(a) Full Dataset
Art Clipart Product Real World
Art 0 8.41e28.41\mathrm{e}{-2} 8.86e28.86\mathrm{e}{-2} 6.69e26.69\mathrm{e}{-2}
Clipart 7.07e27.07\mathrm{e}{-2} 0 5.86e25.86\mathrm{e}{-2} 5.68e25.68\mathrm{e}{-2}
Product 7.85e27.85\mathrm{e}{-2} 6.24e26.24\mathrm{e}{-2} 0 5.33e25.33\mathrm{e}{-2}
Real World 6.09e26.09\mathrm{e}{-2} 6.52e26.52\mathrm{e}{-2} 5.77e25.77\mathrm{e}{-2} 0
(b) Subsampled

B.5 Losses

B.5.1 DANN

For batches of data (xSi,ySi)(x^{i}_{S},y^{i}_{S}) and (xTi)(x^{i}_{T}) of size ss, the DANN losses are:

DA(xSi,ySi,xTi;θ,ψ)\displaystyle\mathcal{L}_{DA}(x^{i}_{S},y^{i}_{S},x_{T}^{i};\theta,\psi) =1si=1slog(dψ(gθ(xSi)))+log(1dψ(gθ(xTi))),\displaystyle=\hskip 2.84544pt-\frac{1}{s}\displaystyle{\sum_{i=1}^{s}}\log(d_{\psi}(g_{\theta}(x_{S}^{i})))+\log(1-d_{\psi}(g_{\theta}(x_{T}^{i}))), (16)
C(xSi,ySi;θ,ϕ)\displaystyle\mathcal{L}_{C}(x^{i}_{S},y^{i}_{S};\theta,\phi) =1si=1slog(hϕ(gθ(xSi)ySi)).\displaystyle=-\frac{1}{s}\displaystyle{\sum_{i=1}^{s}}\log(h_{\phi}(g_{\theta}(x_{S}^{i})_{y^{i}_{S}})). (17)
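A minimal PyTorch sketch of these two losses for one mini-batch is given below; the modules g, h and d are placeholders for the feature extractor, classifier and discriminator (their architectures are not specified here), and the optional per-class weights implement the reweighted source term analyzed in Lemma A.1, with weights equal to 1 recovering standard DANN.

```python
import torch
import torch.nn.functional as F

def dann_losses(g, h, d, x_s, y_s, x_t, w=None):
    """Mini-batch DANN losses, Eqs. (16)-(17).

    g: feature extractor, h: classifier (returns logits),
    d: domain discriminator (returns the probability of being a source sample).
    If a tensor w of k per-class weights is given, the source discriminator term
    is reweighted by w[y_s], as in the weighted objective of Lemma A.1.
    """
    z_s, z_t = g(x_s), g(x_t)

    # Domain-adversarial loss, Eq. (16).
    d_s = d(z_s).view(-1).clamp(1e-6, 1 - 1e-6)
    d_t = d(z_t).view(-1).clamp(1e-6, 1 - 1e-6)
    w_s = w[y_s] if w is not None else torch.ones_like(d_s)
    loss_da = -(w_s * torch.log(d_s)).mean() - torch.log(1 - d_t).mean()

    # Classification loss on the source domain, Eq. (17).
    loss_c = F.cross_entropy(h(z_s), y_s)
    return loss_da, loss_c
```

In DANN, the domain loss is minimized over the discriminator parameters and maximized over the feature extractor, typically through a gradient reversal layer.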

B.5.2 CDAN

Similarly, the CDAN losses are:

DA(xSi,ySi,xTi;θ,ψ)\displaystyle\mathcal{L}_{DA}(x^{i}_{S},y^{i}_{S},x_{T}^{i};\theta,\psi) =1si=1slog(dψ(hϕ(gθ(xSi))gθ(xSi)))\displaystyle=\hskip 2.84544pt-\frac{1}{s}\displaystyle{\sum_{i=1}^{s}}\log(d_{\psi}(h_{\phi}(g_{\theta}(x_{S}^{i}))\otimes g_{\theta}(x_{S}^{i}))) (18)
+log(1dψ(hϕ(gθ(xTi))gθ(xTi))),\displaystyle\hskip 71.13188pt+\log(1-d_{\psi}(h_{\phi}(g_{\theta}(x_{T}^{i}))\otimes g_{\theta}(x_{T}^{i}))), (19)
C(xSi,ySi;θ,ϕ)\displaystyle\mathcal{L}_{C}(x^{i}_{S},y^{i}_{S};\theta,\phi) =1si=1slog(hϕ(gθ(xSi)ySi)),\displaystyle=-\frac{1}{s}\displaystyle{\sum_{i=1}^{s}}\log(h_{\phi}(g_{\theta}(x_{S}^{i})_{y^{i}_{S}})), (20)

where hϕ(gθ(xSi))gθ(xSi):=(h1(g(xSi))g(xSi),,hk(g(xSi))g(xSi))h_{\phi}(g_{\theta}(x_{S}^{i}))\otimes g_{\theta}(x_{S}^{i})\vcentcolon=(h_{1}(g(x_{S}^{i}))g(x_{S}^{i}),\dots,h_{k}(g(x_{S}^{i}))g(x_{S}^{i})) and hj(g(xSi))h_{j}(g(x_{S}^{i})) denotes the jj-th element of the vector h(g(xSi))h(g(x_{S}^{i})).

CDAN is particularly well-suited for conditional alignment. As described in Section 2, the CDAN discriminator seeks to match 𝒟S(Y^Z)\mathcal{D}_{S}(\widehat{Y}\otimes Z) with 𝒟T(Y^Z)\mathcal{D}_{T}(\widehat{Y}\otimes Z). This objective is closely aligned with GLSGLS: let us first assume for argument’s sake that Y^\widehat{Y} is a perfect classifier on both domains. For any sample (x,y)(x,y), y^z\hat{y}\otimes z is thus a matrix of 0s except on the yy-th row, which contains zz. When label distributions match, fooling the discriminator results in representations such that the matrices Y^Z\widehat{Y}\otimes Z have matching distributions on the source and target domains. In other words, the model is such that the conditionals ZYZ\mid Y match across domains: it verifies GLSGLS (see Th. 3.4 below with 𝐰=1\mathbf{w}=1). On the other hand, if the label distributions differ, fooling the discriminator actually requires mislabelling certain samples (a fact quantified in Th. 2.1).
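The multilinear map $\widehat{Y}\otimes Z$ fed to the CDAN discriminator can be formed with a per-sample outer product; the sketch below is a simplified illustration of ours, not the exact released implementation.

```python
import torch

def multilinear_map(h_probs, z):
    """Form h(z) (x) z for a batch: output of shape (batch, k * feat_dim).

    h_probs: softmax predictions of shape (batch, k)
    z:       features of shape (batch, feat_dim)
    Each row is the flattened outer product (h_1(z) z, ..., h_k(z) z),
    matching the definition used in the CDAN losses (18)-(19).
    """
    outer = torch.bmm(h_probs.unsqueeze(2), z.unsqueeze(1))  # (batch, k, feat_dim)
    return outer.flatten(start_dim=1)

# Toy usage: batch of 8 samples, 10 classes, 64-dimensional features.
h_probs = torch.softmax(torch.randn(8, 10), dim=1)
z = torch.randn(8, 64)
print(multilinear_map(h_probs, z).shape)  # torch.Size([8, 640])
```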

B.5.3 JAN

The JAN losses [40] are :

DA(xSi,ySi,xTi;θ,ψ)\displaystyle\mathcal{L}_{DA}(x^{i}_{S},y^{i}_{S},x_{T}^{i};\theta,\psi) =1s2i,j=1sk(xSi,xSj)1s2i,j=1sk(xTi,xTj)+2s2i,j=1sk(xSi,xTj)\displaystyle=\hskip 2.84544pt-\frac{1}{s^{2}}\displaystyle{\sum_{i,j=1}^{s}}k(x^{i}_{S},x^{j}_{S})-\frac{1}{s^{2}}\displaystyle{\sum_{i,j=1}^{s}}k(x^{i}_{T},x^{j}_{T})+\frac{2}{s^{2}}\displaystyle{\sum_{i,j=1}^{s}}k(x^{i}_{S},x^{j}_{T}) (21)
C(xSi,ySi;θ,ϕ)\displaystyle\mathcal{L}_{C}(x^{i}_{S},y^{i}_{S};\theta,\phi) =1si=1slog(hϕ(gθ(xSi)ySi)),\displaystyle=-\frac{1}{s}\displaystyle{\sum_{i=1}^{s}}\log(h_{\phi}(g_{\theta}(x_{S}^{i})_{y^{i}_{S}})), (22)

where kk corresponds to the kernel of the RKHS \mathcal{H} used to measure the discrepancy between distributions. Exactly as in Long et al. [40], it is the product of kernels on various layers of the network k(xSi,xSj)=lkl(xSi,xSj)k(x^{i}_{S},x^{j}_{S})=\prod_{l\in\mathcal{L}}k^{l}(x^{i}_{S},x^{j}_{S}). Each individual kernel klk^{l} is computed as the dot-product between two transformations of the representation: kl(xSi,xSj)=dψl(gθl(xSi)),dψl(gθl(xSj))k^{l}(x^{i}_{S},x^{j}_{S})=\langle d^{l}_{\psi}(g^{l}_{\theta}(x_{S}^{i})),d^{l}_{\psi}(g^{l}_{\theta}(x_{S}^{j}))\rangle (in this case, dψld^{l}_{\psi} outputs vectors in a high-dimensional space). See Section B.7 for more details.

The IWJAN losses are:

DA𝐰(xSi,ySi,xTi;θ,ψ)\displaystyle\mathcal{L}^{\mathbf{w}}_{DA}(x^{i}_{S},y^{i}_{S},x_{T}^{i};\theta,\psi) =1s2i,j=1s𝐰ySi𝐰ySjk(xSi,xSj)1s2i,j=1sk(xTi,xTj)+2s2i,j=1s𝐰ySik(xSi,xTj)\displaystyle=\hskip 2.84544pt-\frac{1}{s^{2}}\displaystyle{\sum_{i,j=1}^{s}}\mathbf{w}_{y^{i}_{S}}\mathbf{w}_{y^{j}_{S}}k(x^{i}_{S},x^{j}_{S})-\frac{1}{s^{2}}\displaystyle{\sum_{i,j=1}^{s}}k(x^{i}_{T},x^{j}_{T})+\frac{2}{s^{2}}\displaystyle{\sum_{i,j=1}^{s}}\mathbf{w}_{y^{i}_{S}}k(x^{i}_{S},x^{j}_{T}) (23)
C𝐰(xSi,ySi;θ,ϕ)\displaystyle\mathcal{L}_{C}^{\mathbf{w}}(x^{i}_{S},y^{i}_{S};\theta,\phi) =1si=1s𝐰ySik𝒟S(Y=y)log(hϕ(gθ(xSi))ySi).\displaystyle=-\frac{1}{s}\displaystyle{\sum_{i=1}^{s}}\frac{\mathbf{w}_{y^{i}_{S}}}{k\mathcal{D}_{S}(Y=y)}\log(h_{\phi}(g_{\theta}(x_{S}^{i}))_{y^{i}_{S}}). (24)
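As an illustration of the domain-adaptation term in Eq. (23), here is a simplified sketch of ours using a single Gaussian kernel on the features, rather than the product of kernels over layers used by JAN; Eq. (23) corresponds to the negation of the quantity returned below.

```python
import torch

def gaussian_kernel(a, b, sigma=1.0):
    """k(a_i, b_j) = exp(-||a_i - b_j||^2 / (2 sigma^2))."""
    d2 = torch.cdist(a, b) ** 2
    return torch.exp(-d2 / (2 * sigma ** 2))

def weighted_mmd(z_s, z_t, w_s, sigma=1.0):
    """Importance-weighted MMD estimator in the spirit of Eq. (23).

    z_s, z_t: source / target features, each of shape (s, dim)
    w_s:      importance weight of each source sample, i.e. w[y_s], shape (s,)
    """
    s = z_s.shape[0]
    k_ss = gaussian_kernel(z_s, z_s, sigma)
    k_tt = gaussian_kernel(z_t, z_t, sigma)
    k_st = gaussian_kernel(z_s, z_t, sigma)
    term_ss = (w_s[:, None] * w_s[None, :] * k_ss).sum() / s ** 2
    term_tt = k_tt.sum() / s ** 2
    term_st = (w_s[:, None] * k_st).sum() / s ** 2
    return term_ss + term_tt - 2 * term_st
```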

B.6 Generation of domain adaptation tasks with varying D_{\text{JS}}(\mathcal{D}_{S}^{Y}~{}\|~{}\mathcal{D}_{T}^{Y})

We consider the MNIST \rightarrow USPS task and generate a set 𝒱\mathcal{V} of 5050 vectors in [0.1,1]10[0.1,1]^{10}. Each vector corresponds to the fraction of each class to be trained on, either in the source or the target domain (to assess the impact of both). The left bound is chosen as 0.10.1 to ensure that all classes contain some samples.

This methodology creates 100100 domain adaptation tasks, 5050 for subsampled-MNIST \rightarrow USPS and 5050 for MNIST \rightarrow subsampled-USPS, with Jensen-Shannon divergences varying from 6.1e36.1\mathrm{e}{-3} to 9.53e29.53\mathrm{e}{-2}111We manually rejected some samples to guarantee a rather uniform set of divergences.. Those tasks are then used to evaluate our algorithms; see Section 4 and Figures 1 and 3, which show the performance of the 6 algorithms we consider. The figures display a sharp decrease in the performance of the base versions DANN and CDAN, while our importance-weighted algorithms maintain good performance even for large divergences between the marginal label distributions.
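A minimal sketch of this generation procedure is given below (ours, with random stand-ins for the actual MNIST and USPS label arrays): each class of the source labels is subsampled according to one fraction vector, and the Jensen-Shannon divergence between the resulting label marginals is computed.

```python
import numpy as np

def subsample_by_class(labels, fractions, rng):
    """Return indices keeping a per-class fraction of the samples.

    labels:    integer array of labels in {0, ..., 9}
    fractions: array of 10 values in [0.1, 1], one per class
    """
    keep = []
    for c, frac in enumerate(fractions):
        idx = np.flatnonzero(labels == c)
        n_keep = max(1, int(frac * len(idx)))
        keep.append(rng.choice(idx, size=n_keep, replace=False))
    return np.concatenate(keep)

def label_jsd(labels_a, labels_b, k=10):
    pa = np.bincount(labels_a, minlength=k) / len(labels_a)
    pb = np.bincount(labels_b, minlength=k) / len(labels_b)
    m = (pa + pb) / 2
    kl = lambda p, q: np.sum(p[p > 0] * np.log(p[p > 0] / q[p > 0]))
    return 0.5 * kl(pa, m) + 0.5 * kl(pb, m)

rng = np.random.default_rng(0)
y_source = rng.integers(0, 10, size=60000)   # stand-in for MNIST labels
y_target = rng.integers(0, 10, size=10000)   # stand-in for USPS labels
fractions = rng.uniform(0.1, 1.0, size=10)   # one vector from the set V
idx = subsample_by_class(y_source, fractions, rng)
print(label_jsd(y_source[idx], y_target))
```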

B.7 Implementation details

All the values reported below are the default ones in the implementations of DANN, CDAN and JAN released with the respective papers (see links to the github repos in the footnotes). We did not perform any search on them, assuming they had already been optimized by the authors of those papers. To ensure a fair comparison and showcase the simplicity of our approach, we simply plugged the weight estimation on top of those baselines and used their original hyperparameters.

For MNIST and USPS, the architecture is akin to LeNet [31], with two convolutional layers, ReLU and MaxPooling, followed by two fully connected layers. The representation is also taken as the last hidden layer, and has 500 neurons. The optimizer for those tasks is SGD with a learning rate of 0.020.02, annealed by 0.50.5 every five training epochs for M \rightarrow U and 66 for U \rightarrow M. The weight decay is also 5e45\mathrm{e}{-4} and the momentum 0.90.9.
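A PyTorch sketch consistent with this description is shown below; the kernel sizes and channel counts are assumptions on our part for illustration, not necessarily the released configuration.

```python
import torch.nn as nn

class DigitsNet(nn.Module):
    """LeNet-like feature extractor and classifier with a 500-d representation."""
    def __init__(self, n_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 20, kernel_size=5), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(20, 50, kernel_size=5), nn.ReLU(), nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(50 * 4 * 4, 500), nn.ReLU(),   # 500-d representation
        )
        self.classifier = nn.Linear(500, n_classes)

    def forward(self, x):
        z = self.features(x)   # representation fed to the domain discriminator
        return self.classifier(z), z
```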

For the Office and Visda experiments with IWDAN and IWCDAN, we train a ResNet-50, optimized using SGD with momentum. The weight decay is also 5e45\mathrm{e}{-4} and the momentum 0.90.9. The learning rate is 3e43\mathrm{e}{-4} for the Office-31 tasks A \rightarrow D and D \rightarrow W, 1e31\mathrm{e}{-3} otherwise (default learning rates from the CDAN implementation222https://github.com/thuml/CDAN/tree/master/pytorch).

For the IWJAN experiments, we use the default implementation of Xlearn codebase333https://github.com/thuml/Xlearn/tree/master/pytorch and simply add the weigths estimation and reweighted objectives to it, as described in Section B.5. Parameters, configuration and networks remain the same.

Finally, for the Office experiments, we update the importance weights 𝐰\mathbf{w} every 15 passes on the dataset (in order to improve their estimation on small datasets). On Digits and Visda, the importance weights are updated every pass on the source dataset. Here too, fine-tuning that value might lead to a better estimation of 𝐰\mathbf{w} and help bridge the gap with the oracle versions of the algorithms.

We use the cvxopt package444http://cvxopt.org/ to solve the quadratic program (5).
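A minimal sketch of such a weight-estimation QP with cvxopt is given below, assuming a least-squares objective $\|\textbf{C}\mathbf{w}-\boldsymbol{\mu}\|^{2}$ under non-negativity and a normalization constraint; see (5) in the main text for the exact formulation we use.

```python
import numpy as np
from cvxopt import matrix, solvers

def estimate_weights(C, mu, p_source):
    """Estimate importance weights w from the confusion matrix C and the
    target prediction distribution mu by solving a small QP:

        minimize    ||C w - mu||^2
        subject to  w >= 0,  sum_y p_source[y] * w[y] = 1

    (a sketch of the kind of program solved per epoch; constraints may differ
    from the exact program (5)).
    """
    k = C.shape[0]
    P = matrix(2.0 * C.T @ C)                 # quadratic term of ||Cw - mu||^2
    q = matrix(-2.0 * C.T @ mu)               # linear term
    G = matrix(-np.eye(k))                    # -w <= 0, i.e. w >= 0
    h = matrix(np.zeros(k))
    A = matrix(p_source.reshape(1, k))        # normalization constraint
    b = matrix(np.ones(1))
    solvers.options["show_progress"] = False
    sol = solvers.qp(P, q, G, h, A, b)
    return np.array(sol["x"]).ravel()
```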

We trained our models on single-GPU machines (P40s and P100s). The runtime of our algorithms is indistinguishable from the runtime of their base versions.

(a) Performance of DANN, IWDAN and IWDAN-O.
(b) Performance of CDAN, IWCDAN and IWCDAN-O.
Figure 3: Performance in % of our algorithms and their base versions. The xx-axis represents DJS(𝒟SY,𝒟TY)D_{\text{JS}}(\mathcal{D}_{S}^{Y},\mathcal{D}_{T}^{Y}), the Jensen-Shannon distance between label distributions. Lines represent linear fits to the data. For both sets of algorithms, the larger the JSD, the larger the improvement.

B.8 Weight Estimation

We estimate importance weights using Lemma 3.2, which relies on the GLSGLS assumption. However, there is no guarantee that GLSGLS is verified at any point during training, so the exact dynamics of 𝐰\mathbf{w} are unclear. Below we discuss those dynamics and provide some intuition about them.

In Fig. 4b, we plot the Euclidean distance between the moving average of the weights estimated using the equation 𝐰=C1𝝁\mathbf{w}=\textbf{C}^{-1}\boldsymbol{\mu} and the true weights (note that this can be done for any algorithm). As can be seen in the figure, the distance between the estimated and true weights is highly correlated with the performance of the algorithm (see Fig. 4). In particular, the estimation for IWDAN is more accurate than for DANN. The estimation for DANN exhibits an interesting shape, improving at first and then getting worse. At the same time, the estimation for IWDAN improves monotonically. The weights for IWDAN-O get very close to the true weights, which is in line with our theoretical results: IWDAN-O gets close to zero error on the target domain, so Th. 3.4 guarantees that GLSGLS is verified, which in turn implies that the weight estimation is accurate (Lemma 3.2). Finally, without domain adaptation, the estimation is very poor. The following two lemmas shed some light on the phenomena observed for DANN and IWDAN:

Lemma 3.3 (restated). If $\varepsilon_S(h\circ g)=0$ and $D_{\text{JS}}(\mathcal{D}_S^{\tilde{\mathbf{w}}}(Z),\mathcal{D}_T(Z))=0$, then the estimated weight vector $\mathbf{w}=\mathbf{C}^{-1}\boldsymbol{\mu}$ satisfies $\mathbf{w}=\tilde{\mathbf{w}}$.

Proof.

If $\varepsilon_S(h\circ g)=0$, then the confusion matrix $\mathbf{C}$ is diagonal and its $y$-th diagonal entry is $\mathcal{D}_S(Y=y)$. Additionally, if $D_{\text{JS}}(\mathcal{D}_S^{\tilde{\mathbf{w}}}(Z),\mathcal{D}_T(Z))=0$, then from a straightforward extension of Eq. 12, we have $D_{\text{JS}}(\mathcal{D}_S^{\tilde{\mathbf{w}}}(\hat{Y}),\mathcal{D}_T(\hat{Y}))=0$. In other words, the distributions of predictions on the source and target domains match, i.e. for all $y$,
$$\boldsymbol{\mu}_y=\mathcal{D}_T(\hat{Y}=y)=\sum_{y'}\tilde{\mathbf{w}}_{y'}\mathcal{D}_S(\hat{Y}=y,Y=y')=\tilde{\mathbf{w}}_y\mathcal{D}_S(Y=y),$$
where the last equality comes from $\varepsilon_S(h\circ g)=0$. Finally, we get that $\mathbf{w}=\mathbf{C}^{-1}\boldsymbol{\mu}=\tilde{\mathbf{w}}$. ∎

In particular, applying this lemma to DANN (i.e. with $\tilde{\mathbf{w}}=\mathbf{1}$) suggests that at convergence, the estimated weights should tend to $\mathbf{1}$. Empirically, Fig. 4b shows that as the marginals get matched, the estimation for DANN does get closer to $\mathbf{1}$ ($\mathbf{1}$ corresponds to a distance of 2.16 from the true weights), though it does not reach it as the learning rate is decayed to 0. We now attempt to provide some intuition on the behavior of IWDAN, with the following lemma:

Lemma B.1.

If $\varepsilon_S(h\circ g)=0$ and if, for a given $y$:

$$\min(\tilde{\mathbf{w}}_y\mathcal{D}_S(Y=y),\mathcal{D}_T(Y=y))\leq\boldsymbol{\mu}_y\leq\max(\tilde{\mathbf{w}}_y\mathcal{D}_S(Y=y),\mathcal{D}_T(Y=y)), \qquad (25)$$

then, letting $\mathbf{w}=\mathbf{C}^{-1}\boldsymbol{\mu}$ be the estimated weight vector:

$$|\mathbf{w}_y-\mathbf{w}_y^*|\leq|\tilde{\mathbf{w}}_y-\mathbf{w}_y^*|.$$

Applying this lemma to $\tilde{\mathbf{w}}=\mathbf{w}_t$, and assuming that (25) holds for all classes $y$ (we discuss below what this assumption implies), we get that:

$$\|\mathbf{w}_{t+1}-\mathbf{w}^*\|\leq\|\mathbf{w}_t-\mathbf{w}^*\|, \qquad (26)$$

or in other words, the estimation improves monotonically. Combining this with Lemma 3.3 suggests an explanation for the shape of the IWDAN estimated weights in Fig. 4b: the monotonic improvement of the estimation is counter-balanced by the matching of weighted marginals which, when reached, makes $\mathbf{w}_t$ constant (Lemma 3.3 applied to $\tilde{\mathbf{w}}=\mathbf{w}_t$). However, we wish to emphasize that the exact dynamics of $\mathbf{w}$ are complex, and we do not claim to understand them fully. In all likelihood, they are the by-product of regularities in the data, properties of deep neural networks and their interaction with stochastic gradient descent. Additionally, the dynamics are inherently linked to the success of domain adaptation, which to this day remains an open problem.

As a matter of fact, assumption (25) itself relates to successful domain adaptation. Setting aside $\tilde{\mathbf{w}}$, which simply corresponds to a class reweighting of the source domain, (25) states that predictions on the target domain fall between a successful prediction (corresponding to $\mathcal{D}_T(Y=y)$) and the prediction of a model with matched marginals (corresponding to $\mathcal{D}_S(Y=y)$). In other words, we assume that the model is naturally in between successful domain adaptation and successful marginal matching. Empirically, we observed that it holds true for most classes (with $\tilde{\mathbf{w}}=\mathbf{w}_t$ for IWDAN and with $\tilde{\mathbf{w}}=\mathbf{1}$ for DANN), though not for all of them early in training; in particular, at initialization, one class usually dominates the others.

To conclude this section, we prove Lemma B.1.

Proof.

From $\varepsilon_S(h\circ g)=0$, we know that $\mathbf{C}$ is diagonal and that its $y$-th diagonal entry is $\mathcal{D}_S(Y=y)$. This gives us $\mathbf{w}_y=(\mathbf{C}^{-1}\boldsymbol{\mu})_y=\frac{\boldsymbol{\mu}_y}{\mathcal{D}_S(Y=y)}$. Hence:

$$\begin{aligned}
&\min(\tilde{\mathbf{w}}_y\mathcal{D}_S(Y=y),\mathcal{D}_T(Y=y))\leq\boldsymbol{\mu}_y\leq\max(\tilde{\mathbf{w}}_y\mathcal{D}_S(Y=y),\mathcal{D}_T(Y=y))\\
\Longleftrightarrow\quad&\frac{\min(\tilde{\mathbf{w}}_y\mathcal{D}_S(Y=y),\mathcal{D}_T(Y=y))}{\mathcal{D}_S(Y=y)}\leq\frac{\boldsymbol{\mu}_y}{\mathcal{D}_S(Y=y)}\leq\frac{\max(\tilde{\mathbf{w}}_y\mathcal{D}_S(Y=y),\mathcal{D}_T(Y=y))}{\mathcal{D}_S(Y=y)}\\
\Longleftrightarrow\quad&\min(\tilde{\mathbf{w}}_y,\mathbf{w}_y^*)\leq\mathbf{w}_y\leq\max(\tilde{\mathbf{w}}_y,\mathbf{w}_y^*)\\
\Longleftrightarrow\quad&\min(\tilde{\mathbf{w}}_y,\mathbf{w}_y^*)-\mathbf{w}_y^*\leq\mathbf{w}_y-\mathbf{w}_y^*\leq\max(\tilde{\mathbf{w}}_y,\mathbf{w}_y^*)-\mathbf{w}_y^*\\
\Longleftrightarrow\quad&|\mathbf{w}_y-\mathbf{w}_y^*|\leq|\tilde{\mathbf{w}}_y-\mathbf{w}_y^*|,
\end{aligned}$$

which concludes the proof. ∎
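As a quick numerical sanity check of Lemma B.1 (our own illustration, under the zero source error assumption where the estimate reduces to $\mathbf{w}_y=\boldsymbol{\mu}_y/\mathcal{D}_S(Y=y)$), one can draw $\boldsymbol{\mu}_y$ anywhere between $\tilde{\mathbf{w}}_y\mathcal{D}_S(Y=y)$ and $\mathcal{D}_T(Y=y)$ and verify that the estimate is never farther from $\mathbf{w}_y^*$ than $\tilde{\mathbf{w}}_y$ is:

```python
import numpy as np

# Numerical check of Lemma B.1 under the zero source error assumption,
# where the estimate reduces to w_y = mu_y / D_S(Y=y).
rng = np.random.default_rng(0)
for _ in range(1000):
    p_s = rng.dirichlet(np.ones(10))            # source label marginals D_S(Y=y)
    p_t = rng.dirichlet(np.ones(10))            # target label marginals D_T(Y=y)
    w_star = p_t / p_s                          # true weights w*_y
    w_tilde = rng.uniform(0.1, 3.0, size=10)    # arbitrary current reweighting
    lo = np.minimum(w_tilde * p_s, p_t)
    hi = np.maximum(w_tilde * p_s, p_t)
    mu = rng.uniform(lo, hi)                    # assumption (25): mu_y between the two bounds
    w_hat = mu / p_s                            # estimated weights
    assert np.all(np.abs(w_hat - w_star) <= np.abs(w_tilde - w_star) + 1e-12)
```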

Figure 4: (a) Transfer accuracy of various algorithms during training. (b) Euclidean distance between the weights estimated using Lemma 3.2 and the true weights. Both plots correspond to averages over 5 seeds.

B.9 Per-class predictions and estimated weights

In this section, we display the per-class predictions of various algorithms on the sU $\rightarrow$ M task. In Table 16, we see that without domain adaptation, per-class performance is rather erratic: the digit 9, for instance, is very poorly predicted and mostly confused with 4 and 7.

Table 17 shows an interesting pattern for DANN. In line with the limitations described by Theorem 2.1, the model performs very poorly on the subsampled classes (we recall that the subsampling is done in the source domain): the neural network tries to match the unweighted marginals. To do so, it projects representations of classes that are over-represented in the target domain (digits 0 to 4) onto representations of the under-represented classes (digits 5 to 9). In doing so, it heavily degrades its performance on the former (it is worth noting that digit 0 has an importance weight close to 1, which probably explains why DANN still performs well on it; see Table 14).

As far as IWDAN is concerned, Table 18 shows that the model performs rather well on all classes, with the exception of digit 7, which is partly confused with 9. IWDAN-O is shown in Table 19 and, as expected, outperforms the other algorithms on all classes.

Finally, Table 14 shows the estimated weights of all the algorithms, taken at the training epoch displayed in Tables 16, 17, 18 and 19. We see a rather strong correlation between errors in the estimated weight for a given class and errors in the predictions for that class (see for instance digit 3 for DANN or digit 7 for IWDAN).

Table 14: Estimated weights, taken at the training epoch corresponding to the confusion matrices in Tables 16, 17, 18 and 19. The first row contains the true weights. The last column gives the Euclidean distance to the true weights.
Class       0     1     2     3     4     5     6     7     8     9   Distance
TRUE     1.19  1.61  1.96  2.24  2.16  0.70  0.64  0.70  0.78  0.66   0
DANN     1.06  1.15  1.66  1.33  1.95  0.86  0.72  0.70  1.02  0.92   1.15
IWDAN    1.19  1.61  1.92  1.96  2.31  0.70  0.63  0.55  0.78  0.78   0.38
IWDAN-O  1.19  1.60  2.01  2.14  2.10  0.73  0.64  0.65  0.78  0.66   0.12
No DA    1.14  1.40  2.42  1.49  4.21  0.94  0.38  0.82  0.62  0.29   2.31
Table 15: Ablation study on the Digits tasks, with weights learnt during training.
Method Digits sDigits Method Digits sDigits
DANN 93.15 83.24 CDAN 95.72 88.23
DANN + $\mathcal{L}_C^{\mathbf{w}}$ 93.18 84.20 CDAN + $\mathcal{L}_C^{\mathbf{w}}$ 95.30 91.14
DANN + $\mathcal{L}_{DA}^{\mathbf{w}}$ 94.35 92.48 CDAN + $\mathcal{L}_{DA}^{\mathbf{w}}$ 95.42 92.35
IWDAN 94.90 92.54 IWCDAN 95.90 93.22
Table 16: Per-class predictions without domain adaptation on the sU $\rightarrow$ M task. Average accuracy: 74.49%. The table $M$ below satisfies $M_{ij}=\mathcal{D}_T(\hat{Y}=j\,|\,Y=i)$ (rows: true class $i$, columns: predicted class $j$).
92.89   0.13   3.24   0.00   2.20   0.01   0.45   0.88   0.18   0.02
 0.00  72.54  12.38   0.00   3.40   0.37   7.54   1.50   2.28   0.00
 0.31   0.23  93.28   0.09   0.72   0.03   0.34   4.78   0.17   0.05
 0.06   0.77   4.81  68.53   1.50  19.91   0.02   2.48   1.61   0.31
 0.02   0.62   0.28   0.00  97.19   0.51   0.04   0.17   0.79   0.37
 0.75   3.03   0.69   1.01   1.20  88.96   0.39   0.31   2.69   0.96
 0.73   1.98   0.42   0.03  23.86   2.74  69.08   0.29   0.12   0.75
 1.02   2.01   4.16   0.13   9.32   6.36   0.01  73.48   1.01   2.50
 6.01   8.27   2.55   1.35   1.62   3.62   4.98   6.96  64.40   0.24
 1.49   3.35   0.55   1.28  38.30  15.36   0.05  20.68   1.34  17.60
Table 17: Per-class predictions for DANN on the sU $\rightarrow$ M task. Average accuracy: 86.71%. The table $M$ below satisfies $M_{ij}=\mathcal{D}_T(\hat{Y}=j\,|\,Y=i)$. The first 5 classes are under-represented in the source domain compared to the target domain. On those (except 0), DANN does not perform as well as on the over-represented classes (the last 5). In line with Th. 2.1, matching the representation distributions on source and target forced the classifier to confuse the digits "1", "3" and "4" in the target domain with "8", "5" and "9".
95.79   0.01   0.08   0.01   0.12   0.38   2.34   0.36   0.36   0.57
 0.14  70.77   0.80   0.01   1.03   1.29   9.46   0.06  16.39   0.06
 1.61   0.14  89.82   0.20   0.42   0.48   0.83   3.73   1.37   1.39
 0.46   0.08   1.10  63.33   0.04  26.28   0.02   1.78   3.76   3.14
 0.11   0.13   0.05   0.00  78.17   0.85   0.16   0.15   1.97  18.41
 0.19   0.04   0.02   0.04   0.01  91.30   0.25   0.27   5.97   1.91
 0.62   0.12   0.01   0.00   1.98   4.61  91.73   0.05   0.36   0.51
 0.14   0.23   1.39   0.13   0.10   0.32   0.02  94.10   1.46   2.10
 0.69   0.13   0.11   0.05   0.21   2.12   0.50   0.36  95.19   0.66
 0.36   0.31   0.03   0.08   0.46   3.67   0.01   1.64   1.03  92.40
Table 18: Per-class predictions for IWDAN on the sU $\rightarrow$ M task. Average accuracy: 94.38%. The table $M$ below satisfies $M_{ij}=\mathcal{D}_T(\hat{Y}=j\,|\,Y=i)$.
97.33   0.06   0.23   0.01   0.20   0.43   1.29   0.19   0.18   0.09
 0.00  97.71   0.41   0.05   0.56   0.69   0.14   0.03   0.34   0.07
 0.70   0.16  96.32   0.08   0.34   0.23   0.23   1.49   0.43   0.01
 0.23   0.01   0.97  87.67   0.01   9.25   0.02   0.63   0.87   0.35
 0.11   0.25   0.05   0.00  96.93   0.16   0.22   0.02   0.40   1.85
 0.15   0.11   0.01   0.16   0.05  95.81   0.69   0.11   2.82   0.08
 0.26   0.25   0.00   0.00   2.07   1.49  95.84   0.01   0.07   0.00
 0.16   0.42   2.12   0.91   1.15   0.60   0.03  82.07   1.35  11.19
 0.44   0.50   0.36   0.18   0.43   0.90   0.91   0.15  95.74   0.40
 0.34   0.42   0.06   0.30   1.67   2.46   0.11   0.31   0.85  93.50
Table 19: Per-class predictions for IWDAN-O on the sU $\rightarrow$ M task. Average accuracy: 96.8%. The table $M$ below satisfies $M_{ij}=\mathcal{D}_T(\hat{Y}=j\,|\,Y=i)$.
98.04   0.01   0.20   0.00   0.27   0.03   1.17   0.11   0.15   0.02
 0.00  98.35   0.33   0.15   0.17   0.27   0.19   0.05   0.47   0.02
 0.22   0.04  97.48   0.07   0.29   0.08   0.52   1.09   0.19   0.02
 0.10   0.00   0.66  95.72   0.01   2.32   0.00   0.35   0.56   0.27
 0.01   0.25   0.05   0.00  96.80   0.03   0.18   0.01   0.60   2.06
 0.23   0.11   0.00   0.72   0.00  96.09   0.68   0.13   2.01   0.03
 0.27   0.31   0.00   0.00   2.07   0.63  96.54   0.00   0.17   0.01
 0.26   0.45   2.13   0.29   0.90   0.32   0.01  92.66   1.06   1.92
 0.55   0.22   0.30   0.06   0.18   0.22   0.41   0.33  97.11   0.62
 0.46   0.37   0.16   0.86   0.82   1.45   0.01   0.77   0.98  94.13