
Domain Generalization via Optimal Transport with Metric Similarity Learning

Fan Zhou [email protected] Zhuqing Jiang Changjian Shui Boyu Wang Brahim Chaib-draa Department of Computer Science and Software Engineering, Laval University, QC, Canada School of Information and Communication Engineering, Beijing University of Posts and Telecommunications, Beijing, China Department of Electrical and Computer Engineering, Laval University, QC, Canada Department of Computer Science, University of Western Ontario, ON, Canada Vector Institute, ON, Canada
Abstract

Generalizing knowledge to unseen domains, where data and labels are unavailable, is crucial for machine learning. In this paper, we tackle the domain generalization problem of learning from multiple source domains and generalizing to a target domain with unknown statistics. The crucial idea is to extract the underlying invariant features across all the domains. Previous domain generalization approaches mainly focused on learning invariant features and stacking the learned features from each source domain to generalize to a new target domain while ignoring the label information, which generally leads to indistinguishable features with an ambiguous classification boundary. One possible solution is to constrain label similarity when extracting the invariant features and to take advantage of the label similarities for class-specific cohesion and separation of features across domains. We adopt optimal transport with the Wasserstein distance, which can constrain class-label similarity, for adversarial training. We also deploy a metric learning objective to leverage the label information for achieving a distinguishable classification boundary. Our empirical results show that the proposed method outperforms most of the baselines. Furthermore, ablation studies demonstrate the effectiveness of each component of our method.

keywords:
Domain Generalization , Adversarial Learning , Metric Learning
journal: Journal of Neurocomputing

1 Introduction

Recent years have witnessed rapid development of machine learning and its successful applications, such as computer vision [Ma et al., 2018a, Zhu et al., 2019, Ma et al., 2019b], natural language processing [Ma et al., 2013, 2018b] and cross-modality learning [Zhu et al., 2019, Xu et al., 2018], with many real-world applications [Xie et al., 2018, Ma et al., 2019a]. Traditional machine learning methods are typically based on the assumption that training and testing datasets are drawn from the same distribution. However, in many real-world applications this assumption may not hold, and the performance can degrade rapidly if the trained models are deployed to domains different from the training dataset [Ganin et al., 2016]. Moreover, training a high-performance vision system requires a large amount of labelled data, and obtaining such labels may be expensive. Taking a pre-trained robotic vision system as an example, during each deployment task, the robot itself (e.g. position and angle), the environment (e.g. weather and illumination) and the camera (e.g. resolution) may produce different image styles. The cost of annotating enough data for each deployment task could be very high.

This kind of problem has been widely addressed by transfer learning (TL) [Zhuang et al., 2019] and domain adaptation (DA) [Ganin et al., 2016]. In DA, a learner usually has access to labelled source data and unlabelled target data, and it is typically trained to align the feature distributions between the source and target domains. However, we cannot always expect the target data to be accessible to the learner. In the robot example, the distribution divergences (different image styles) between the training and testing domains can only be identified after the model is trained and deployed. In this scenario, it is unrealistic to collect target samples before deployment. A robot would thus need the ability to handle domain divergences even though the target data is absent.

We tackle this kind of problem under the domain generalization (DG) paradigm, in which the learner has access to several source domains (data and corresponding labels) and aims at generalizing to a new (target) domain, where both data and labels are unknown. The goal of DG is to learn a prediction model on training data from the seen source domains so that it generalizes well on the unseen target domain. An underlying assumption behind domain generalization is that there exists a common feature space underlying the multiple known source domains and the unseen target domain. Specifically, we want to learn domain-invariant features across these source domains and then generalize to a new domain. An example of how domain generalization proceeds is illustrated in Fig. 1.

Figure 1: Domain Generalization: A learner faces a set of labelled data from several source domains, and it aims at extracting invariant features across the seen source domains and learning to generalize to an unseen domain. Based on the manifold assumption [Goldberg et al., 2009], each domain $i$ is supported by a distribution $\mathcal{D}_i$. The learner can measure the source domain distributions via the source datasets but has no information on the unseen target distribution. After training on the source domains, the model is then deployed to a new domain $\mathcal{D}_t$ for prediction.

A critical problem in DG and DA involves aligning the domain distributions, which is typically achieved by extracting domain-invariant representations. Previous DA works usually minimized domain discrepancies, such as the KL divergence or the Maximum Mean Discrepancy (MMD), via adversarial training to achieve domain distribution alignment. Due to the similar problem setting between DA and DG, many previous approaches directly adopt the same adversarial training techniques for DG. For example, an MMD metric is adopted by Li et al. [2018b] as a cross-domain regularizer, and the KL divergence is adopted by Li et al. [2017a] to measure the domain shift for the domain generalization problem. However, the MMD metric is usually implemented in a kernel space, which does not scale well to large applications, and the KL divergence is unbounded, which is also insufficient for successfully measuring domain shift [Zhao et al., 2019].

Besides, previous domain generalization approaches [Ilse et al., 2019, Ghifary et al., 2015, Li et al., 2018c, D’Innocente & Caputo, 2018, Volpi et al., 2018] mainly focused on applying similar DA techniques to extract the invariant features and on how to stack the learned features from each domain for generalizing to a new domain. These methods usually ignore the label information and can make the features indistinguishable with ambiguous classification boundaries, a.k.a. the semantic misalignment problem [Deng et al., 2020]. A successful generalization should guide the learner not only to align the feature distributions between domains but also to ensure that samples in the same class lie close to each other while samples from different classes stay apart, a.k.a. feature compactness [Kamnitsas et al., 2018].

Aiming to solve this, we adopt Optimal Transport (OT) with the Wasserstein distance to align the feature distributions for domain generalization, since it can constrain labelled source samples of the same class to remain close during the transportation process [Courty et al., 2016]. Moreover, information-theoretic metrics such as the KL divergence are not capable of measuring the inherent geometric relations among the different domains [Arjovsky et al., 2017]. In contrast, OT can exactly measure the corresponding geometric properties. Besides, compared with [Ben-David et al., 2010], OT benefits from the advantages of the Wasserstein distance through its gradient property [Arjovsky et al., 2017] and its promising generalization bound [Redko et al., 2017]. Empirical studies [Gulrajani et al., 2017, Shen et al., 2018] also demonstrated the effectiveness of OT for extracting invariant features to align the marginal distributions of different domains.

Furthermore, although the optimal transport process can constrain the labelled samples of the same class to stay close to each other, our preliminary results showed that implementing optimal transport alone for domain generalization is not sufficient for a cohesive and separable classification boundary. The model could still suffer from indistinguishable features (see Fig. 4(c)). In order to train the model to predict well on all the domains, this separable classification boundary should also be achieved in a domain-agnostic manner. That is, for a pair of instances, no matter which domains they come from, they should stay close to each other if they are in the same class and apart otherwise. To this end, we further promote metric learning as an auxiliary objective for leveraging the source domain label information towards a domain-independent, distinguishable classification boundary.

To summarize, we deploy the optimal transport technique with the Wasserstein distance to extract domain-invariant features for domain generalization. To avoid an ambiguous classification boundary, we propose to implement metric learning strategies to achieve a distinguishable feature space. Combining these, we propose the Wasserstein Adversarial Domain Generalization (WADG) algorithm.

To check the effectiveness of the proposed approach, we tested the algorithm on standard benchmarks, comparing it with recent domain generalization baselines. The experimental results showed that our proposed algorithm outperforms most of the baselines, which confirms its effectiveness. Furthermore, the ablation studies also demonstrated the contribution of each component of our algorithm.

2 Related Works

2.1 Domain Generalization

The goal of DG is to learn a model that can extract common knowledge shared across source domains and generalize well on the target domain. Compared with DA, the main challenge of DG is that the target domain data is not available during the learning process.

A common framework for DG is to extract the most informative and transferable underlying common features from source instances generated from different distributions and to generalize to an unseen one. This kind of approach relies on the assumption that there exists an underlying invariant feature distribution among all domains and that, consequently, such invariant features can generalize well to a target domain. Muandet et al. [2013] implemented MMD as a distribution regularizer and proposed the kernel-based Domain Invariant Component Analysis (DICA) algorithm. An autoencoder-based model was proposed by Ghifary et al. [2015] under a multi-task learning setting to learn domain-invariant features via adversarial training. Li et al. [2018c] proposed an end-to-end deep domain generalization approach by leveraging deep neural networks for domain-invariant representation learning. Motiian et al. [2017] proposed to minimize a semantic alignment loss as well as a separation loss based on deep learning models. Li et al. [2018b] proposed a low-rank Convolutional Neural Network model based on domain-shift-robust deep learning methods.

There are also approaches that tackle the domain generalization problem in a meta-learning manner. To the best of our knowledge, Li et al. [2018a] first proposed to adopt Model-Agnostic Meta-Learning (MAML) [Finn et al., 2017], which back-propagates the gradients of the ordinary loss function of meta-test tasks. As pointed out by Dou et al. [2019], such an approach might lead to a sub-optimal solution, as it is highly abstracted from the feature representations. Balaji et al. [2018] proposed the MetaReg algorithm, in which a regularization function (e.g. a weighted $L_1$ loss) is applied to the classification layer of the model but not to the feature extractor layers. Then, Li et al. [2019] proposed an auxiliary meta loss which is computed based on the feature extractor. Furthermore, the network architecture of Li et al. [2019] is the widely used feature-critic style model, based on a similar model from the domain adversarial training technique [Ganin et al., 2016]. Dou et al. [2019] and Matsuura & Harada [2020] also started to implement clustering techniques on the invariant feature space for better classification and showed better performance on the target domain.

2.2 Metric Learning

Metric learning aims to learn a discriminative feature embedding where similar samples are closer while dissimilar samples are further apart [Deng et al., 2020]. Hadsell et al. [2006] proposed the siamese network together with the contrastive loss to guide instances with the same labels to stay close to each other in the feature space and to push them apart otherwise. Schroff et al. [2015] proposed the triplet loss, which aims to learn a feature space where a positive pair has higher similarity than the negative pair sharing the same anchor, by a given margin. Oh Song et al. [2016] showed that neither the contrastive loss nor the triplet loss can efficiently explore the full pair-wise relations between instances under the mini-batch training setting. They further proposed the lifted structure loss to fully utilize pair-wise relations across batches. However, it only randomly chooses an equal number of positive and negative pairs, and many informative pairs are discarded [Wang et al., 2019], which restricts its ability to find the informative pairs. Yi et al. [2014] proposed the binomial deviance loss, which can measure hard pairs. One remarkable work by Wang et al. [2019] combines the advantages of both the lifted structure loss and the binomial deviance loss to leverage pair similarity. They proposed to leverage not only pair-similarities (similarity between positive or negative pairs) but also self-similarity, which enables the learner to collect and weight informative (positive or negative) pairs in an iterative (mining and weighting) manner. For a pair of instances, the self-similarity is computed from the pair itself. Such a multi-similarity has been shown to measure similarity and cluster the samples more efficiently and accurately. In the context of domain generalization, Dou et al. [2019] proposed to guide the learner to leverage the local similarity in the semantic feature space, which the authors argued may contain essential domain-independent general knowledge, and adopted the contrastive loss and triplet loss to encourage clustering. Leveraging the across-domain class similarity information can encourage the learner to extract robust semantic features regardless of domain, which is useful auxiliary information for the learner. If the learner cannot separate the samples (from different source domains) with domain-independent class-specific cohesion and separation in the domain-invariant feature space, it would still suffer from ambiguous decision boundaries, which might remain sensitive to the unseen target domain. Matsuura & Harada [2020] implemented unsupervised clustering on source domains and showed better classification performance. Our work is orthogonal to previous works: we propose to enforce a more distinguishable invariant feature space via Wasserstein adversarial training and to leverage label similarity information for a better classification boundary.

Table 1: List of notations
Symbol: Meaning
$F$: The feature extraction function
$\boldsymbol{\theta}_f$: Parameters of the feature extraction network
$C$: The classification function
$\boldsymbol{\theta}_c$: Parameters of the classification network
$D$: The critic function
$\boldsymbol{\theta}_d$: Parameters of the critic network
$m$: The number of source domains
$\mathbf{x}_j^{(i)}$: The $j$-th instance from the $i$-th domain
$N_i$: The number of instances in the $i$-th domain
$\mathbf{X}^{(i)}$: The set of instances in the $i$-th domain, $\mathbf{X}^{(i)}=\{\mathbf{x}_j^{(i)}\}_{j=1}^{N_i}$
$\mathcal{D}$: The data distribution; $\mathcal{D}_i$ are the source domain distributions
$Z^{(i)}$: The extracted features from domain $i$
$W_1(\mathcal{D}_i,\mathcal{D}_j)$: Wasserstein-1 distance between the two distributions $\mathcal{D}_i$ and $\mathcal{D}_j$
$y_j$: The label of the corresponding instance $\mathbf{x}_j$
$\mathbf{S}$: The similarity matrix
$S_{i,j}$: The entry in the $i$-th row and $j$-th column of the similarity matrix $\mathbf{S}$
$w_{i,j}$: The weight for similarity $S_{i,j}$
$\epsilon$: Small margin for roughly selecting the positive and negative pairs
$\alpha$: Fixed parameter for positive mining
$\beta$: Fixed parameter for negative mining
$\lambda$: Parameter for self-similarity mining
$\lambda_d$: Coefficient regularizing the adversarial objective
$\lambda_s$: Coefficient regularizing the metric learning objective
$\mathcal{L}$: The objective functions: $\mathcal{L}_C$ is the classification loss, $\mathcal{L}_D$ the adversarial loss, $\mathcal{L}_{MS}$ the metric similarity loss

3 Preliminaries and Problem Setup

We start by introducing some preliminaries. To better summarize the notation and symbols used in this work, we provide a list of them in Table 1.

3.1 Notations and Definitions

Following Redko et al. [2017] and Li et al. [2017a], suppose we have $m$ known source domain distributions $\{\mathcal{D}_i\}_{i=1}^{m}$, and the $i$-th domain contains $N_i$ labelled instances in total, denoted by $\{(\mathbf{x}_j^{(i)},y_j^{(i)})\}_{j=1}^{N_i}$, where $\mathbf{x}_j^{(i)}\in\mathbb{R}^n$ is the $j$-th instance feature from the $i$-th domain and $y_j^{(i)}\in\{1,\dots,K\}$ is the corresponding label. For a hypothesis class $\mathcal{H}$, the expected risk of a hypothesis $h\in\mathcal{H}$ over a domain distribution $\mathcal{D}_i$ is the probability that $h$ predicts wrongly on the entire distribution $\mathcal{D}_i$: $\epsilon_i(h)=\mathbb{E}_{(\mathbf{x},y)\sim\mathcal{D}_i}\ell(h(\mathbf{x}),y)$, where $\ell(\cdot)$ is the loss function. The empirical risk is defined by $\hat{\epsilon}_i(h)=\frac{1}{N_i}\sum_{j=1}^{N_i}\ell(h(\mathbf{x}_j),y_j)$.

In the setting of domain generalization, we only have access to the seen source domains $\mathcal{D}_i$ but have no information about the target domain. The learner is expected to extract the underlying invariant feature space across the source domains and generalize to a new target domain.

3.2 Optimal Transport and Wasserstein Distance

We follow Redko et al. [2017] and define $c:\mathbb{R}^n\times\mathbb{R}^n\to\mathbb{R}^+$ as the cost function for transporting one unit of mass $\mathbf{x}$ to $\mathbf{x}'$; then the primal form of the Wasserstein distance between $\mathcal{D}_i$ and $\mathcal{D}_j$ can be computed by,

$W_p^p(\mathcal{D}_i,\mathcal{D}_j)=\inf_{\gamma\in\Pi(\mathcal{D}_i,\mathcal{D}_j)}\int_{\mathbb{R}^n\times\mathbb{R}^n}c(\mathbf{x},\mathbf{x}')^p\,d\gamma(\mathbf{x},\mathbf{x}')$ (1)

where $\Pi(\mathcal{D}_i,\mathcal{D}_j)$ is the set of probability couplings on $\mathbb{R}^n\times\mathbb{R}^n$ with marginals $\mathcal{D}_i$ and $\mathcal{D}_j$, i.e. all the possible coupling functions. Throughout this paper, we adopt the Wasserstein-1 distance only ($p=1$).

Computing the primal form of the Wasserstein distance (Eq. 1) is computationally inefficient. Assuming $|\mathcal{D}_i|=n$ and $|\mathcal{D}_j|=m$, the time complexity for directly computing Eq. 1 is $\mathcal{O}(n^3+m^3)$. On the contrary, leveraging the Kantorovich-Rubinstein duality [Wainwright, 2019] of the Wasserstein distance provides a more efficient approximation. Assume $f$ is 1-Lipschitz-continuous w.r.t. the cost function, i.e. $\|f(x)-f(x')\|\leq c(x,x')$; then for any such function $f$,

$W_1(\mathcal{D}_i,\mathcal{D}_j)\geq\mathbb{E}_{x\sim\mathcal{D}_i}f(x)-\mathbb{E}_{x'\sim\mathcal{D}_j}f(x')$

Equality is attained when $f$ maximizes the right-hand side,

$W_1(\mathcal{D}_i,\mathcal{D}_j)=\sup_{\|f\|_L\leq 1}\mathbb{E}_{x\sim\mathcal{D}_i}f(x)-\mathbb{E}_{x'\sim\mathcal{D}_j}f(x')$ (2)

In practice, such a function $f$ can be approximated by a neural network, which allows us to compute the Kantorovich-Rubinstein dual efficiently by evaluating the expectations, and the complexity w.r.t. $f(x)$ is only $\mathcal{O}(n+m)$. Empirically, computing the $\sup$ amounts to finding the maximum of the right-hand side of Eq. 2. A general neural network optimizer (e.g. SGD or Adam) can efficiently solve this maximization problem to evaluate the dual value of the $W_1$ distance.
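For illustration, the following is a minimal PyTorch sketch, not the authors' released code, of estimating the dual form in Eq. 2 with a critic network; the critic architecture and the weight-clipping surrogate for the 1-Lipschitz constraint are illustrative assumptions (a gradient penalty could be used instead).

```python
import torch
import torch.nn as nn

class Critic(nn.Module):
    """Approximates the 1-Lipschitz function f in Eq. 2 (architecture assumed)."""
    def __init__(self, feat_dim=2048, hidden=1024):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(feat_dim, hidden), nn.ReLU(),
                                 nn.Linear(hidden, 1))

    def forward(self, z):
        return self.net(z)

def w1_dual_estimate(critic, z_i, z_j):
    """Empirical estimate of E_{D_i}[f] - E_{D_j}[f] for feature batches z_i, z_j."""
    return critic(z_i).mean() - critic(z_j).mean()

def enforce_lipschitz(critic, c=0.01):
    """Crude surrogate for the Lipschitz constraint via weight clipping."""
    with torch.no_grad():
        for p in critic.parameters():
            p.clamp_(-c, c)
```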

Figure 2: Using optimal transport (OT) for domain generalization: directly predicting on the unseen domain (the white dashed arrow) is typically difficult. In order to learn domain-invariant features, as shown by the green arrow, we adopt the OT technique to achieve domain alignment for extracting invariant features. After the OT transition, the invariant features can be generalized to the unseen domain.

Optimal transport theory and the Wasserstein distance have recently been investigated in the context of machine learning [Arjovsky et al., 2017], especially in the domain adaptation area [Courty et al., 2016, Zhou et al., 2020]. The general idea of implementing the optimal transport technique for generalization across domains is illustrated in Fig. 2. To learn domain-invariant features, the OT technique is implemented to achieve domain alignment; after the OT transition, the invariant features can be generalized to an unseen domain.

3.3 Metric Learning

For a pair of instances $(\mathbf{x}_i,y_i)$ and $(\mathbf{x}_j,y_j)$, the notion of a positive pair usually refers to the condition where the pair $i,j$ has the same label ($y_i=y_j$), while a negative pair usually refers to the condition $y_i\neq y_j$. The central idea of metric learning is to encourage pairs of instances that have the same label to be closer and to push negative pairs apart from each other [Wu et al., 2017].

Following the framework of Wang et al. [2019], we show the general pair-weighting process of metric learning. Assume the feature extractor $f$, parameterized by $\boldsymbol{\theta}_f$, projects an instance $\mathbf{x}\in\mathbb{R}^n$ to a $d$-dimensional normalized space: $f(\mathbf{x};\boldsymbol{\theta}_f):\mathbb{R}^n\to[0,1]^d$. Then, for two samples $\mathbf{x}_i$ and $\mathbf{x}_j$, the similarity between them can be defined as the inner product of the corresponding feature vectors:

$S_{i,j}:=\langle f(\mathbf{x}_i;\boldsymbol{\theta}_f),f(\mathbf{x}_j;\boldsymbol{\theta}_f)\rangle$ (3)
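As a small sketch (the shapes and the L2-normalization choice are our assumptions), the similarity matrix of Eq. 3 can be computed from a batch of embeddings as follows.

```python
import torch.nn.functional as F

def similarity_matrix(features):
    """features: (N, d) embeddings f(x; theta_f). Returns S with S[i, j] = <z_i, z_j>."""
    z = F.normalize(features, p=2, dim=1)  # normalize so similarities lie in [-1, 1]
    return z @ z.t()
```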

Leveraging the across-domain class similarity information encourages the learner to extract a classification boundary that is independent of the domains, which provides useful auxiliary information for the learner. We further elaborate on this in Section 4.2.

4 Proposed Method

The high-level idea of the WADG algorithm is to learn a domain-invariant feature space and a domain-agnostic classification boundary. Firstly, we align the marginal distributions of the different source domains via optimal transport by minimizing the Wasserstein distance, to achieve a domain-invariant feature space. Then, we adopt a metric learning objective to guide the learner to leverage the class similarity information for a better classification boundary. A general workflow of our method is illustrated in Fig. 3(a). The model contains three major parts: a feature extractor, a classifier and a critic function.

The feature extractor $F$, parameterized by $\boldsymbol{\theta}_f$, extracts the features from the different source domains. For a set of instances $\mathbf{X}^{(i)}=\{\mathbf{x}_j^{(i)}\}_{j=1}^{N_i}$ from domain $\mathcal{D}_i$, we denote the extracted features from domain $i$ by $\mathbf{Z}^{(i)}=F(\mathbf{X}^{(i)})$. The classifier $C$, parameterized by $\boldsymbol{\theta}_c$, is expected to learn to predict the labels of instances from all the domains correctly. The critic function $D$, parameterized by $\boldsymbol{\theta}_d$, aims to measure the empirical Wasserstein distance between features from a pair of source domains. For the target domain, all the instances and labels are absent during training time.

WADG aims to learn domain-agnostic features with a distinguishable classification boundary. During each training round, the network receives labelled data from all domains and trains the classifier in a supervised manner with the classification loss $\mathcal{L}_C$. For the classification process, we use the typical cross-entropy loss over all $m$ source domains:

$\mathcal{L}_C=-\sum_{i=1}^{m}\sum_{j=1}^{N_i}y_j\log(\mathbb{P}(C(F(\mathbf{x}_j^{(i)}))))$ (4)
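A minimal sketch of Eq. 4, summing the cross-entropy loss over all source domains (the module and batch names are assumptions), is shown below.

```python
import torch.nn.functional as F

def classification_loss(feature_extractor, classifier, source_batches):
    """source_batches: one (x, y) mini-batch per source domain."""
    loss = 0.0
    for x, y in source_batches:
        logits = classifier(feature_extractor(x))
        loss = loss + F.cross_entropy(logits, y)  # cross-entropy per domain, summed
    return loss
```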

Through this, the model learns the category information over all the domains. The feature extractor $F$ is then trained to minimize the estimated Wasserstein distance in an adversarial manner against the critic $D$ with an objective $\mathcal{L}_D$. We then adopt a metric learning objective (namely $\mathcal{L}_{MS}$) for leveraging the similarities towards a better classification boundary. Our full method then solves the joint loss function,

$\mathcal{L}=\arg\min_{\theta_f,\theta_c}\max_{\theta_d}\mathcal{L}_C+\mathcal{L}_D+\mathcal{L}_{MS},$

where $\mathcal{L}_D$ is the adversarial objective function and $\mathcal{L}_{MS}$ is the metric learning objective function. In the sequel, we elaborate on these two objectives in Section 4.1 and Section 4.2, respectively.

(a) The whole workflow of the proposed WADG model.
(b) Optimal Transport for Feature Alignment.
(c) Metric Learning for the Clustering Process.
Figure 3: The proposed WADG method. (a): the general workflow of the WADG method. The model mainly consists of three parts: the feature extractor, the classifier and the critic function. During training, the model receives all the source domains. The feature extractor is trained to learn invariant features together with the critic function in an adversarial manner. (b): For each pair of source domains $\mathcal{D}_i$ and $\mathcal{D}_j$, the optimal transport process aligns the features from the different domains. (c): The metric learning process. For a batch of instances from all source domains, we first roughly mine the positive and negative pairs via Eq. 7. Then, we compute the corresponding weights via Eq. 11 and Eq. 12 and compute $\mathcal{L}_{MS}$ to guide the clustering process.

4.1 Adversarial Domain Generalization via Optimal Transport

Since optimal transport can constrain labelled source samples of the same class to remain close during the transportation process [Courty et al., 2016], we deploy optimal transport with the Wasserstein distance [Redko et al., 2017, Shen et al., 2018] to align the marginal feature distributions over all the source domains.

A brief workflow of the optimal transport for a pair of source domains is illustrated in Fig. 3(b). The critic function $D$ estimates the empirical Wasserstein distance between each pair of source domains through instances from the empirical sets $\mathbf{x}^{(i)}\in\mathbf{X}^{(i)}$ and $\mathbf{x}^{(j)}\in\mathbf{X}^{(j)}$. In practice [Shen et al., 2018], the dual term of Eq. 2 of the Wasserstein distance can be computed by,

$W_1(\mathbf{X}^{(i)},\mathbf{X}^{(j)})=\max\big(\frac{1}{N_i}\sum_{\mathbf{x}^{(i)}\in\mathbf{X}^{(i)}}D(F(\mathbf{x}^{(i)}))-\frac{1}{N_j}\sum_{\mathbf{x}^{(j)}\in\mathbf{X}^{(j)}}D(F(\mathbf{x}^{(j)}))\big)$ (5)

Since in the domain generalization setting there usually exist more than two source domains, we sum the empirical Wasserstein distances between every pair of source domains,

$\mathcal{L}_D=\sum_{i=1}^{m}\sum_{j=i+1}^{m}\big[\frac{1}{N_i}\sum_{\mathbf{x}^{(i)}\in\mathbf{X}^{(i)}}D(F(\mathbf{x}^{(i)}))-\frac{1}{N_j}\sum_{\mathbf{x}^{(j)}\in\mathbf{X}^{(j)}}D(F(\mathbf{x}^{(j)}))\big]$ (6)
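A minimal sketch of Eq. 6, summing the empirical dual Wasserstein term of Eq. 5 over every pair of source domains (the function and variable names are assumptions), is given below.

```python
def adversarial_loss(feature_extractor, critic, source_inputs):
    """source_inputs: list of per-domain input tensors; returns the pairwise W1 sum."""
    feats = [feature_extractor(x) for x in source_inputs]
    loss_d = 0.0
    for i in range(len(feats)):
        for j in range(i + 1, len(feats)):
            # empirical dual estimate between domains i and j (Eq. 5)
            loss_d = loss_d + (critic(feats[i]).mean() - critic(feats[j]).mean())
    return loss_d
```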

Through this pair-wise optimal transport process, the learner can extract a domain-invariant feature space. We then propose to apply metric learning approaches to leverage the class label similarity for domain-independent, cluster-friendly feature extraction. We introduce the metric learning objective for domain-agnostic clustering in the next section.

4.2 Metric Learning for Domain Agnostic Classification Boundary

As mentioned above, only aligning the marginal features via adversarial training is not sufficient for DG, since there may still exist an ambiguous decision boundary [Dou et al., 2019]. When predicting on the target domain, the learner may still suffer from this ambiguous decision boundary. To this end, we adopt metric learning techniques [Wang et al., 2019] to help cluster the instances and promote a better prediction boundary for better generalization.

To solve this, in addition to the supervised source classification and the alignment of the marginal distributions across domains with the Wasserstein adversarial training defined above, we further encourage robust domain-independent local clustering by leveraging the label information with the metric learning objective. A brief workflow is illustrated in Fig. 3(c). Specifically, we adopt the metric learning objective to require the images, regardless of their domains, to follow two aspects: 1) images from the same class are semantically similar and should therefore be mapped nearby in the embedding space (semantic clustering), while 2) instances from different classes should be mapped apart from each other in the embedding space. Since the goal of domain generalization is to learn a hypothesis that predicts well on all the domains, the clustering should also be achieved in a domain-agnostic manner.

To this end, we mix the instances from all the source domains together and encourage the clustering of domain-agnostic features via metric learning techniques to achieve a domain-independent clustering decision boundary. For this, during each training iteration, for a batch $\{\mathbf{x}_1^{(i)},y_1^{(i)},\dots,\mathbf{x}_b^{(i)},y_b^{(i)}\}_{i=1}^{m}$ from the $m$ source domains with batch size $b$, we mix all the instances from each domain and denote them by $\{(\mathbf{x}_i^{B},y_i^{B})\}_{i=1}^{m'}$ with total size $m'$. We first measure the relative similarity between the negative and positive pairs, which is introduced in the next sub-section.

4.2.1 Pair Similarity Mining

Assume $\mathbf{x}_i^{B}$ is an anchor; a negative pair $\{\mathbf{x}_i^{B},\mathbf{x}_j^{B}\}$ and a positive pair $\{\mathbf{x}_i^{B},\mathbf{x}_{j'}^{B}\}$ are selected if $S_{i,j}$ and $S_{i,j'}$ satisfy the negative condition and the positive condition, respectively:

$S_{i,j}^{-}\geq\min_{y_k=y_i}S_{i,k}-\epsilon,\;\;\;\;S_{i,j'}^{+}\leq\max_{y_k\neq y_i}S_{i,k}+\epsilon$ (7)

where $\epsilon$ is a given margin. Through Eq. 7 and the margin $\epsilon$, we obtain a set of negative pairs $\mathcal{N}$ and a set of positive pairs $\mathcal{P}$. This process (Eq. 7) roughly clusters the instances around each anchor by selecting informative pairs (inside the margin) and discarding the less informative ones (outside the margin).
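The mining step of Eq. 7 can be sketched as follows on a mixed multi-domain batch; here sim is the similarity matrix of Eq. 3, labels are the mixed-batch labels, and $\epsilon$ follows Table 2 (the masking implementation is our assumption).

```python
import torch

def mine_pairs(sim, labels, epsilon=0.1):
    """Returns boolean masks of the selected positive and negative pairs per anchor row."""
    N = sim.size(0)
    same = labels.unsqueeze(0) == labels.unsqueeze(1)
    eye = torch.eye(N, dtype=torch.bool, device=sim.device)
    pos_mask, neg_mask = same & ~eye, ~same

    inf = torch.full_like(sim, float('inf'))
    # hardest positive: lowest same-class similarity; hardest negative: highest other-class similarity
    hardest_pos = torch.where(pos_mask, sim, inf).min(dim=1, keepdim=True).values
    hardest_neg = torch.where(neg_mask, sim, -inf).max(dim=1, keepdim=True).values

    selected_neg = neg_mask & (sim > hardest_pos - epsilon)   # negative condition of Eq. 7
    selected_pos = pos_mask & (sim < hardest_neg + epsilon)   # positive condition of Eq. 7
    return selected_pos, selected_neg
```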

With such roughly selected informative pairs $\mathcal{N}$ and $\mathcal{P}$, we then assign the instances different weights. Intuitively, if an instance has higher similarity with an anchor, then it should stay closer to the anchor, and vice-versa. We introduce the weighting process in the next section.

4.2.2 Pair Weighting

For instances of positive pairs, if they are more similar to the anchor, they should have higher weights, while negative pairs are given lower weights if they are more dissimilar, no matter which domain they come from. Through this process, we can push the instances into several groups by measuring their similarities.

For $N$ instances, computing the similarity between each pair results in a similarity matrix $\mathbf{S}\in\mathbb{R}^{N\times N}$. A loss function based on pair similarity can usually be defined as $\mathcal{F}(\mathbf{S},y)$. Let $S_{i,j}$ be the element in the $i$-th row and $j$-th column of matrix $\mathbf{S}$. The gradient w.r.t. the network parameters can be computed by,

$\frac{\partial\mathcal{F}(\mathbf{S},y)}{\partial\boldsymbol{\theta}_f}=\frac{\partial\mathcal{F}(\mathbf{S},y)}{\partial\mathbf{S}}\frac{\partial\mathbf{S}}{\partial\boldsymbol{\theta}_f}=\sum_{i=1}^{N}\sum_{j=1}^{N}\frac{\partial\mathcal{F}(\mathbf{S},y)}{\partial S_{i,j}}\frac{\partial S_{i,j}}{\partial\boldsymbol{\theta}_f}$ (8)

Eq. 8 can be reformulated into a new loss function $\mathcal{L}_{MS}$ as,

$\mathcal{L}_{MS}=\sum_{i=1}^{N}\sum_{j=1}^{N}\frac{\partial\mathcal{F}(\mathbf{S},y)}{\partial S_{i,j}}S_{i,j}$ (9)

That is, a metric loss defined w.r.t. the similarity matrix $\mathbf{S}$ and labels $y$ can usually be reformulated as Eq. 9. The term $\frac{\partial\mathcal{F}(\mathbf{S},y)}{\partial S_{i,j}}$ in Eq. 9 can be treated as a constant scalar since it does not contain the gradient of $\mathcal{L}_{MS}$ w.r.t. $\boldsymbol{\theta}_f$. We then only need to consider its sign for the positive and negative pairs. Since the goal is to encourage the positive pairs to be closer, we can assume the gradient is non-positive for a positive pair, i.e. $\frac{\partial\mathcal{F}(\mathbf{S},y)}{\partial S_{i,j}}\leq 0$. Conversely, for a negative pair, we can assume $\frac{\partial\mathcal{F}(\mathbf{S},y)}{\partial S_{i,j}}\geq 0$. Thus, Eq. 9 is transformed into a summation over all the positive pairs ($y_i=y_j$) and negative pairs ($y_i\neq y_j$),

$\mathcal{L}_{MS}=\sum_{i=1}^{N}\sum_{j=1}^{N}\frac{\partial\mathcal{F}(\mathbf{S},y)}{\partial S_{i,j}}S_{i,j}=\sum_{i=1}^{N}\left(\sum_{j=1,y_j\neq y_i}^{N}\frac{\partial\mathcal{F}(\mathbf{S},y)}{\partial S_{i,j}}S_{i,j}+\sum_{j=1,y_j=y_i}^{N}\frac{\partial\mathcal{F}(\mathbf{S},y)}{\partial S_{i,j}}S_{i,j}\right)=\sum_{i=1}^{N}\left(\sum_{j=1,y_j\neq y_i}^{N}w_{i,j}S_{i,j}-\sum_{j=1,y_j=y_i}^{N}w_{i,j}S_{i,j}\right)$ (10)

where $w_{i,j}=\big|\frac{\partial\mathcal{F}(\mathbf{S},y)}{\partial S_{i,j}}\big|$ is regarded as the weight for similarity $S_{i,j}$. Since our goal is to encourage the positive pairs to be closer, the gradient term is non-positive for positive pairs; conversely, for a negative pair it is non-negative. The intuition is that for a negative pair of instances, assigning a positive weight gives it a higher loss value, so that the learner can learn to separate them. On the contrary, the negative sign in front of the positive pairs guides the learner not to push them apart. For each pair of instances $i,j$, we assign different weights according to their similarities $S_{i,j}$. We then denote by $w_{i,j}^{+}$ and $w_{i,j}^{-}$ the weights of a positive and a negative pair's similarity, respectively.

Previously, Yi et al. [2014] and Wang et al. [2019] applied soft functions for measuring the similarity. We consider the similarity of the pair itself (i.e. the self-similarity), the negative similarity and the positive similarity. The weight of the self-similarity can be measured by $\exp(S_{i,j}-\lambda)$ with a small threshold $\lambda$. For a selected negative pair $\{\mathbf{x}_i^{B},\mathbf{x}_j^{B}\}\in\mathcal{N}$, the corresponding weight (see Eq. 10) is defined by the soft function of the self-similarity together with the negative similarity:

$w_{i,j}^{-}=\frac{1}{\exp(\beta(\lambda-S_{i,j}))+\sum_{k\in\mathcal{N}_i}\exp(\beta(S_{i,k}-S_{i,j}))}=\frac{\exp(\beta(S_{i,j}-\lambda))}{1+\sum_{k\in\mathcal{N}_i}\exp(\beta(S_{i,k}-\lambda))}$ (11)

Similarly, the weight of a positive pair $\{\mathbf{x}_i^{B},\mathbf{x}_j^{B}\}\in\mathcal{P}$ is defined by,

$w_{i,j}^{+}=\frac{1}{\exp(-\alpha(\lambda-S_{i,j}))+\sum_{k\in\mathcal{P}_i}\exp(-\alpha(S_{i,k}-S_{i,j}))}$ (12)

Then, substituting Eq. 11 and Eq. 12 into Eq. 10 and integrating over the similarities $S_{i,j}$, we obtain the objective function for clustering,

$\mathcal{L}_{MS}=\frac{1}{m}\sum_{i=1}^{m}\big\{\frac{1}{\alpha}\log[1+\sum_{k\in\mathcal{P}_i}\exp(-\alpha(S_{i,k}-\lambda))]+\frac{1}{\beta}\log[1+\sum_{k\in\mathcal{N}_i}\exp(\beta(S_{i,k}-\lambda))]\big\}$ (13)

where $\lambda$, $\alpha$ and $\beta$ are fixed hyper-parameters, which we elaborate on in the empirical setting (Section 5.2). The whole objective of our proposed method is then,

$\mathcal{L}=\arg\min_{\theta_f,\theta_c}\max_{\theta_d}\mathcal{L}_C+\lambda_d\mathcal{L}_D+\lambda_s\mathcal{L}_{MS}$ (14)

where $\lambda_d$ and $\lambda_s$ are coefficients regularizing $\mathcal{L}_D$ and $\mathcal{L}_{MS}$, respectively.
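As an illustration, a minimal sketch of Eq. 13 using the mined pair masks and the hyper-parameter values of Table 2 ($\alpha=2.0$, $\beta=40.0$, $\lambda=1.0$) is given below; the masking trick is our assumption.

```python
import torch

def ms_loss(sim, pos_mask, neg_mask, alpha=2.0, beta=40.0, lam=1.0):
    """Multi-similarity objective of Eq. 13, averaged over anchors."""
    pos_exp = torch.where(pos_mask, torch.exp(-alpha * (sim - lam)), torch.zeros_like(sim))
    neg_exp = torch.where(neg_mask, torch.exp(beta * (sim - lam)), torch.zeros_like(sim))
    pos_term = torch.log1p(pos_exp.sum(dim=1)) / alpha   # (1/alpha) log[1 + sum_k exp(-alpha(S_ik - lam))]
    neg_term = torch.log1p(neg_exp.sum(dim=1)) / beta    # (1/beta)  log[1 + sum_k exp( beta(S_ik - lam))]
    return (pos_term + neg_term).mean()
```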

Based on the above, we propose the WADG algorithm, summarized in Algorithm 1. We show the empirical results in the next section.

Algorithm 1 The proposed WADG algorithm (one round)
0:  Samples from the $m$ source domains $\{\mathcal{D}_i\}_{i=1}^{m}$
0:  Neural network parameters $\boldsymbol{\theta}_f$, $\boldsymbol{\theta}_c$, $\boldsymbol{\theta}_d$
1:  for each mini-batch of samples $\{(\mathbf{x}_s^{(i)},y_s^{(i)})\}$ from the source domains do
2:     Compute the classification loss $\mathcal{L}_C$ over all the domains according to Eq. 4
3:     Compute the Wasserstein distance loss $\mathcal{L}_D$ between each pair of source domains according to Eq. 6
4:     Mix the instances from the different domains and compute the pairwise similarities by Eq. 3
5:     Roughly select the positive and negative pairs by solving Eq. 7
6:     Compute the similarity loss $\mathcal{L}_{MS}$ on all the source instances by Eq. 13
7:     Update $\boldsymbol{\theta}_f$, $\boldsymbol{\theta}_c$ and $\boldsymbol{\theta}_d$ by solving Eq. 14 with learning rate $\eta$:
$\boldsymbol{\theta}_f\leftarrow\boldsymbol{\theta}_f-\eta\frac{\partial(\mathcal{L}_C+\lambda_d\mathcal{L}_D+\lambda_s\mathcal{L}_{MS})}{\partial\boldsymbol{\theta}_f},\quad\boldsymbol{\theta}_c\leftarrow\boldsymbol{\theta}_c-\eta\frac{\partial(\mathcal{L}_C+\lambda_d\mathcal{L}_D+\lambda_s\mathcal{L}_{MS})}{\partial\boldsymbol{\theta}_c},\quad\boldsymbol{\theta}_d\leftarrow\boldsymbol{\theta}_d+\eta\frac{\partial\mathcal{L}_D}{\partial\boldsymbol{\theta}_d}$
8:  end for
9:  Return the optimal parameters $\boldsymbol{\theta}_f^{\star}$, $\boldsymbol{\theta}_c^{\star}$ and $\boldsymbol{\theta}_d^{\star}$
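For illustration, one update of Algorithm 1 could be sketched as below, reusing the helper functions sketched in the previous sections. The two-optimizer alternation (opt_fc over $\boldsymbol{\theta}_f,\boldsymbol{\theta}_c$, opt_d over $\boldsymbol{\theta}_d$) and the use of the feature-extractor output, rather than the classifier's 256-dimensional hidden layer, for the similarity are simplifying assumptions.

```python
import torch

def wadg_step(F_net, C_net, D_net, opt_fc, opt_d, source_batches,
              lambda_d=1.0, lambda_s=1e-4):
    xs = [x for x, _ in source_batches]

    # (i) critic ascent: maximise the dual Wasserstein estimate w.r.t. theta_d
    opt_d.zero_grad()
    (-adversarial_loss(F_net, D_net, xs)).backward()
    opt_d.step()

    # (ii) feature extractor + classifier descent on Eq. 14
    # (in the paper lambda_d follows the ramp-up schedule of Table 2)
    opt_fc.zero_grad()
    loss_c = classification_loss(F_net, C_net, source_batches)
    loss_adv = adversarial_loss(F_net, D_net, xs)
    feats = torch.cat([F_net(x) for x in xs], dim=0)
    labels = torch.cat([y for _, y in source_batches], dim=0)
    sim = similarity_matrix(feats)
    pos_mask, neg_mask = mine_pairs(sim, labels)
    loss_ms = ms_loss(sim, pos_mask, neg_mask)
    total = loss_c + lambda_d * loss_adv + lambda_s * loss_ms
    total.backward()
    opt_fc.step()
    return total.item()
```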

5 Experiments and Results

5.1 Datasets

To evaluate our proposed approach, we ran experiments on three commonly used datasets: VLCS [Torralba & Efros, 2011], PACS [Li et al., 2017a] and Office-Home [Venkateswara et al., 2017]. The VLCS dataset contains images from 4 different domains: PASCAL VOC2007 (V), LabelMe (L), Caltech (C), and SUN09 (S). Each domain includes five classes: bird, car, chair, dog and person. PACS is a recent benchmark dataset for domain generalization. It consists of four domains: Photo (P), Art painting (A), Cartoon (C), Sketch (S), with objects from seven classes: dog, elephant, giraffe, guitar, house, horse, person. Office-Home is a more challenging dataset, which contains four different domains: Art (Ar), Clipart (Cl), Product (Pr) and Real World (Rw), with 65 categories in each domain. Previous work showed that no matter whether the adversarial model is trained in a supervised [Long et al., 2017], semi-supervised [Zhou et al., 2020] or unsupervised [Long et al., 2018] way, it will struggle with learning the diverse features of this dataset. Testing our domain generalization model on this dataset therefore also helps to confirm the effectiveness of our approach.

5.2 Baselines and Implementation details

To show the effectiveness of our proposed approach, we compare our algorithm on the benchmark datasets with the following recent domain generalization methods.

  • 1.

    Deep All: Following the standard evaluation protocol of domain generalization, a pre-trained AlexNet or ResNet-18 is fine-tuned on the aggregation of all source domains with only the classification loss.

  • 2.

    TF [Li et al., 2017b]: A low-rank parameterized Convolution Neural Network model which aims to reduce the total number of model parameters for an end-to-end Domain Generalization training.

  • 3.

    CIDDG [Li et al., 2018c]: Matches the conditional distributions by changing the class priors.

  • 4.

    MLDG [Li et al., 2018a]: A meta-learning approach for domain generalization. It runs meta-optimization on simulated meta-train/meta-test sets with domain shift.

  • 5.

    CCSA [Motiian et al., 2017]: The contrastive semantic alignment loss is adopted together with the source classification loss for both the domain adaptation and domain generalization problems.

  • 6.

    MMD-AAE [Li et al., 2018b]: An Adversarial Autoencoder model adopted together with the Maximum Mean Discrepancy to extract a domain-invariant feature space for generalization.

  • 7.

    D-SAM [D’Innocente & Caputo, 2018]: It aggregates domain-specific modules and merges general and specific information together for generalization.

  • 8.

    JiGen [Carlucci et al., 2019]: It achieves domain generalization by solving Jigsaw puzzles as an unsupervised auxiliary task.

  • 9.

    MASF [Dou et al., 2019]: A meta-learning style method based on MLDG, combined with a contrastive/triplet loss to encourage a domain-independent semantic feature space.

  • 10.

    MMLD [Matsuura & Harada, 2020]: An approach that mixes all the source domains by assigning pseudo domain labels to extract a domain-independent clustered feature space.

Table 2: The hyper-parameter values for experiments
Hyper-parameter: Value
learning rate: PACS $5\times 10^{-4}$; Office-Home $2\times 10^{-4}$
$\lambda_d$: $\frac{2}{1+\exp(-10p)}-1$
$\lambda_s$: $[10^{-4}, 10^{-5}]$
$\lambda$: $1.0$
$\alpha$: $2.0$
$\beta$: $40.0$
$\epsilon$: $0.1$

Following the general evaluation protocol of domain generalization (e.g. Dou et al. [2019], Matsuura & Harada [2020]) on the PACS and VLCS datasets, we first test our algorithm using an AlexNet [Krizhevsky et al., 2012] backbone with its last layer removed as the feature extractor. For preparing the datasets, we follow the train/val./test split and the data pre-processing protocol of Matsuura & Harada [2020]. As the classifier, we initialize a three-layer MLP whose input dimension matches the feature extractor's output and whose number of outputs equals the number of object categories (2048-256-256-$K$), where $K$ is the number of classes. For the critic network, we also adopt a three-layer MLP (2048-1024-1024-1). For the metric learning objective, we use the output of the second layer of the classifier network (of size 256) for computing the similarity.
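A minimal sketch of the two MLP heads described above, the 2048-256-256-$K$ classifier (whose 256-dimensional second layer feeds the similarity computation) and the 2048-1024-1024-1 critic, is given below; the activation functions are assumptions, and the backbone $F$ (AlexNet without its last layer) is omitted.

```python
import torch.nn as nn

class Classifier(nn.Module):
    """Three-layer MLP head C (2048-256-256-K); ReLU activations are an assumption."""
    def __init__(self, num_classes, feat_dim=2048):
        super().__init__()
        self.embed = nn.Sequential(nn.Linear(feat_dim, 256), nn.ReLU(),
                                   nn.Linear(256, 256), nn.ReLU())
        self.head = nn.Linear(256, num_classes)

    def forward(self, z, return_embedding=False):
        e = self.embed(z)   # 256-dim second-layer output used for the similarity in Eq. 3
        out = self.head(e)
        return (out, e) if return_embedding else out

def build_critic(feat_dim=2048):
    """Three-layer MLP critic D (2048-1024-1024-1)."""
    return nn.Sequential(nn.Linear(feat_dim, 1024), nn.ReLU(),
                         nn.Linear(1024, 1024), nn.ReLU(),
                         nn.Linear(1024, 1))
```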

To better present the hyper-parameters used in this work, we first summarize their values in Table 2; the corresponding descriptions are provided in the following parts. We adopt the ADAM [Kingma & Ba, 2014] optimizer for training, with a learning rate ranging from $5\times 10^{-4}$ to $5\times 10^{-5}$ for the whole model and a mini-batch size of 64.

Table 3: Empirical results (accuracy %) on the PACS dataset with a pre-trained AlexNet as feature extractor. Each column is named after the target domain of the corresponding generalization task. For example, the column 'Cartoon' refers to the task where Cartoon is the target domain while the model is trained on the remaining three domains.
Method Art Cartoon Sketch Photo Avg.
Deep All 63.30 63.13 54.07 87.70 67.05
TF [Li et al., 2017b] 62.86 66.97 57.51 89.50 69.21
CIDDG [Li et al., 2018c] 62.70 69.73 64.45 78.65 68.88
MLDG [Li et al., 2018a] 66.23 66.88 58.96 88.00 70.01
D-SAM [D’Innocente & Caputo, 2018] 63.87 70.70 64.66 85.55 71.20
JiGen [Carlucci et al., 2019] 67.63 71.71 65.18 89.00 73.38
MASF [Dou et al., 2019] $\mathbf{70.35}$ 72.46 67.33 90.68 75.21
MMLD [Matsuura & Harada, 2020] 69.27 $\mathbf{72.83}$ 66.44 88.98 74.38
Ours 70.21 72.51 $\mathbf{70.32}$ $\mathbf{89.81}$ $\mathbf{75.71}$

For stable training, we set the coefficient $\lambda_d=\frac{2}{1+\exp(-10p)}-1$ to regularize the adversarial loss, where $p$ is the training progress. This regularization scheme for $\lambda_d$ has been widely used in adversarial-training-based domain adaptation and generalization settings (e.g. [Long et al., 2017, Wen et al., 2019, Matsuura & Harada, 2020]) and has been shown to help stabilize the training process. For the setting of $\lambda_s$, we follow Dou et al. [2019] and set its value to $10^{-4}$. In our preliminary validation results, the performance is not sensitive to $\lambda_d\in[0,1]$. We also tried to range $\lambda_s$ from $10^{-3}$ to $10^{-6}$ via reverse validation and did not observe obvious differences.
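For reference, the ramp-up schedule for $\lambda_d$ from Table 2 can be written as a one-line helper, where $p$ is the training progress in $[0,1]$ (e.g. the current iteration divided by the total number of iterations; the exact definition of $p$ is our assumption).

```python
import math

def lambda_d_schedule(p):
    """lambda_d = 2 / (1 + exp(-10 p)) - 1, increasing from 0 towards 1 as training progresses."""
    return 2.0 / (1.0 + math.exp(-10.0 * p)) - 1.0
```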

Method Caltech LabelMe Pascal Sun Avg.
Deep All 92.86 63.10 68.67 64.11 72.19
D-MATE [Ghifary et al., 2015] 89.05 60.13 63.90 61.33 68.60
CIDDG [Li et al., 2018c] 88.83 63.06 64.38 62.10 69.59
CCSA [Motiian et al., 2017] 92.30 62.10 67.10 59.10 70.15
SLRC [Ding & Fu, 2017] 92.76 62.34 65.25 63.54 70.97
TF [Li et al., 2017b] 93.63 63.49 69.99 61.32 72.11
MMD-AAE [Li et al., 2018b] 94.40 62.60 67.70 64.40 72.28
D-SAM [D’Innocente & Caputo, 2018] 91.75 56.95 58.95 60.84 67.03
MLDG [Li et al., 2018a] 94.4 61.3 67.7 65.9 73.30
JiGen [Carlucci et al., 2019] 96.93 60.90 70.62 64.30 73.19
MASF [Dou et al., 2019] 94.78 $\mathbf{64.90}$ 69.14 67.64 74.11
MMLD [Matsuura & Harada, 2020] 96.66 58.77 $\mathbf{71.96}$ $\mathbf{68.13}$ 73.88
Ours $\mathbf{96.68}$ 64.26 71.47 66.62 $\mathbf{74.76}$
Table 4: Empirical results (accuracy %) on the VLCS dataset with a pre-trained AlexNet as feature extractor.

Then, we examine our algorithm on the Office-Home benchmark, which is more challenging than the PACS and VLCS datasets. We follow the setting of Carlucci et al. [2019], the most recent work that also evaluated on Office-Home, to have a fair comparison. For this dataset, we again used reverse validation and set the learning rate to $2\times 10^{-4}$ for the whole model. For the remaining hyper-parameters, we keep the same values as in the PACS and VLCS experiments. To avoid over-training, we also adopt early stopping. All the experiments are implemented with PyTorch [Paszke et al., 2019].

5.3 Experimental Results

We first report the empirical results on the PACS and VLCS datasets using AlexNet as the feature extractor in Table 3 and Table 4, respectively. For each generalization task, we train the model on all the source domains, test on the target domain, and report the average of the top-5 accuracy values. The empirical results thus refer to the average accuracy when training on the source domains and testing on the target domain.

From the empirical results, we observe that our method outperforms the baselines on both the PACS and VLCS datasets, indicating an improvement on benchmark performance and showing the effectiveness of our method. We then report the empirical results on the Office-Home dataset in Table 5. As stated before, Office-Home is a larger and more challenging dataset containing more diverse features from 65 different classes. Evaluating on this dataset requires a large amount of computational resources; due to this limit, we follow the evaluation protocol of Carlucci et al. [2019] to report the empirical results. From those results, we observe that our algorithm outperforms the previous domain generalization methods, which again confirms the effectiveness of our proposed method.

Method Art Clipart Product Real-World Avg.
Deep All 52.15 45.86 70.86 73.15 60.51
D-SAM [D’Innocente & Caputo, 2018] $\mathbf{58.03}$ 44.37 69.22 71.45 60.77
JiGen [Carlucci et al., 2019] 53.04 $\mathbf{47.51}$ 71.47 72.79 61.20
Ours 55.34 44.82 $\mathbf{72.03}$ $\mathbf{73.55}$ $\mathbf{61.44}$
Table 5: Empirical results (accuracy %) on the Office-Home dataset with a pre-trained ResNet-18 as feature extractor.

5.4 Further Analysis

To further show the effectiveness of our algorithm, especially with deeper models, following Dou et al. [2019] we also report the results of our algorithm using a ResNet-18 backbone on the PACS dataset in Table 6. With the ResNet-18 backbone, the output feature dimension is 512. From the results, we observe that our method outperforms the baselines on most generalization tasks, with a $+1.6\%$ average accuracy improvement.

Then, we conduct ablation studies on each component of our algorithm. We report the empirical results in Table 7, covering both the AlexNet and the ResNet-18 backbone. We compare the following ablations: (1) Deep All: train the feature extractor on the source domain datasets with the classification loss only, that is, neither the optimal transport nor the metric learning technique is adopted. (2) No $\mathcal{L}_D$: train the model with the classification loss and the metric learning loss but without the adversarial training component. (3) $\mathcal{L}_{MS}$ w.o. $w^{+}$: omit the positive weighting scheme in $\mathcal{L}_{MS}$. (4) $\mathcal{L}_{MS}$ w.o. $w^{-}$: omit the negative weighting scheme in $\mathcal{L}_{MS}$. (5) No $\mathcal{L}_{MS}$: train the model with the classification loss and the adversarial loss but without the metric learning component. (6) WADG-All: train the model with the full objective of Eq. 14.

Table 6: Empirical results (accuracy %) on the PACS dataset with a pre-trained ResNet-18 as feature extractor.
Method Art Cartoon Sketch Photo Avg.
Deep All 77.87 75.89 69.27 95.19 79.55
D-SAM [D’Innocente & Caputo, 2018] 77.33 72.43 77.83 95.30 80.72
JiGen [Carlucci et al., 2019] 79.42 75.25 71.35 96.03 80.51
MASF [Dou et al., 2019] 80.29 77.17 71.69 94.99 81.04
MMLD [Matsuura & Harada, 2020] 81.28 77.16 72.29 $\mathbf{96.09}$ 81.83
Ours $\mathbf{81.56}$ $\mathbf{78.02}$ $\mathbf{78.43}$ 95.82 $\mathbf{83.45}$
(a) Deep All
(b) No $\mathcal{L}_D$
(c) No $\mathcal{L}_{MS}$
(d) WADG-All
Figure 4: T-SNE visualization of the ablation studies on the PACS dataset with Photo as the target domain. A detailed analysis is presented in Section 5.4.

From the results, we observe that once we omit the adversarial training, the accuracy drops rapidly ($\sim 3.5\%$ with the AlexNet backbone and $\sim 5.8\%$ with the ResNet-18 backbone). The contribution of the metric learning loss is relatively small compared with the adversarial loss. Comparing the ablations $\mathcal{L}_{MS}$ w.o. $w^{+}$ and $\mathcal{L}_{MS}$ w.o. $w^{-}$, we observe almost identical accuracy, which indicates that the positive and negative weighting schemes of the metric learning objective may contribute equally. Once we omit the metric learning loss, the performance drops by $\sim 2.1\%$ and $\sim 2.5\%$ with the AlexNet and ResNet-18 backbones, respectively.

(a) Deep All
(b) No $\mathcal{L}_D$
(c) No $\mathcal{L}_{MS}$
(d) WADG-All
Figure 5: T-SNE visualization of the ablation studies on the VLCS dataset with Caltech as the target domain. A detailed analysis is presented in Section 5.4.
Table 7: Ablation studies on the PACS dataset for all components of our proposed method, using AlexNet and ResNet-18 backbones.
Backbone: AlexNet (left five columns), ResNet-18 (right five columns)
Ablation Art Cartoon Sketch Photo Avg. Art Cartoon Sketch Photo Avg.
Deep All 63.30 63.13 54.07 87.70 67.05 77.87 75.89 69.27 95.19 79.55
No $\mathcal{L}_D$ 65.80 69.64 63.91 89.53 72.22 74.62 73.02 68.67 94.86 77.79
No $\mathcal{L}_{MS}$ 66.78 71.47 68.12 88.87 73.65 78.25 76.27 73.42 95.68 80.91
$\mathcal{L}_{MS}$ w.o. $w^{+}$ 66.31 70.86 67.11 88.97 73.31 80.58 77.95 75.13 95.63 82.32
$\mathcal{L}_{MS}$ w.o. $w^{-}$ 66.41 70.95 68.73 87.38 73.37 79.98 77.65 77.89 95.21 82.68
WADG-All $\mathbf{70.21}$ $\mathbf{72.51}$ $\mathbf{70.32}$ $\mathbf{89.81}$ $\mathbf{75.71}$ $\mathbf{81.56}$ $\mathbf{78.02}$ $\mathbf{78.43}$ $\mathbf{95.82}$ $\mathbf{83.45}$

Then, to better understand the contribution of each component of our algorithm, the T-SNE visualizations of the ablation studies on the PACS and VLCS datasets are presented in Fig. 4 for the task with Photo as the target domain and in Fig. 5 for the task with Caltech as the target domain, respectively. Since our goal is not only to align the feature distributions but also to encourage a cohesive and separable boundary, we plot the T-SNE features of all the source domains and the target domain to show the feature alignment and clustering across domains.
For the PACS dataset, the T-SNE features of Deep All neither project the instances from different domains to align with each other nor cluster the features into groups. The T-SNE features of No $\mathcal{L}_D$ show that the metric learning loss can cluster the features to some extent, but without the adversarial training the features are not aligned well. The T-SNE features of No $\mathcal{L}_{MS}$ show that the adversarial training helps to align the features from different domains but does not yield a good clustering. The T-SNE features of WADG-All show that the full objective helps both to align the features from different domains and to cluster them into several groups, which confirms the effectiveness of our algorithm.
As for the VLCS dataset, we observe similar behaviour in the T-SNE plots, although the features overlap with each other to some degree. This is because the features in the Caltech domain are relatively easy to learn and predict. As also analyzed in [Li et al., 2017a], a supervised model on the Caltech domain can achieve $\sim 100\%$ accuracy, which confirms that the features in the Caltech domain are easy to learn and hence more likely to overlap with each other. As we can see from Fig. 5(d), the WADG method helps to separate the features from each other, which again confirms the effectiveness of our proposed method.

6 Conclusion

In this paper, we proposed the Wasserstein Adversarial Domain Generalization (WADG) algorithm, which not only aligns the source domain features for transfer to an unseen target domain but also leverages the label information across domains. We first adopt optimal transport with the Wasserstein distance to align the marginal distributions, and then adopt a metric learning method to encourage a domain-independent, distinguishable feature space with a clear classification boundary. The experimental results showed that our proposed algorithm outperforms most of the baseline methods on two standard benchmark datasets. Furthermore, the ablation studies and the visualization of the T-SNE features also confirmed the effectiveness of our algorithm.
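As a complementary illustration of the adversarial alignment summarised above, the sketch below shows a generic Wasserstein critic with a gradient penalty in the spirit of Arjovsky et al. [2017] and Gulrajani et al. [2017]; the network sizes, penalty weight, and function names are our own assumptions and this is not the implementation used in our experiments.

```python
# Generic Wasserstein-critic sketch (assumed details, not the paper's exact code).
import torch
import torch.nn as nn

class Critic(nn.Module):
    """Scores feature vectors; the gap between its mean scores on two domains
    estimates the Wasserstein-1 distance under a Lipschitz constraint."""
    def __init__(self, feat_dim=256, hidden=128):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(feat_dim, hidden), nn.ReLU(),
                                 nn.Linear(hidden, 1))

    def forward(self, x):
        return self.net(x)

def critic_loss(critic, feat_a, feat_b, gp_weight=10.0):
    """Minimised w.r.t. the critic; the feature extractor is then updated to
    minimise the (positive) gap, which aligns the two marginal distributions.
    Assumes feat_a and feat_b have the same batch size."""
    gap = critic(feat_a).mean() - critic(feat_b).mean()

    # Gradient penalty on random interpolations between the two feature batches.
    eps = torch.rand(feat_a.size(0), 1, device=feat_a.device)
    inter = (eps * feat_a.detach() + (1 - eps) * feat_b.detach()).requires_grad_(True)
    grads = torch.autograd.grad(critic(inter).sum(), inter, create_graph=True)[0]
    gp = ((grads.norm(2, dim=1) - 1.0) ** 2).mean()

    return -gap + gp_weight * gp
```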

Acknowledgement

This work has been partially supported by the Natural Sciences and Engineering Research Council of Canada (NSERC) and the Fonds de recherche du Québec - Nature et technologies (FRQNT). Fan Zhou is supported by the China Scholarship Council. Boyu Wang is supported by the Natural Sciences and Engineering Research Council of Canada (NSERC), Discovery Grants Program.

A full version of this preprint has been published as:

Fan Zhou, Zhuqing Jiang, Changjian Shui, Boyu Wang, Brahim Chaib-draa, Domain generalization via optimal transport with metric similarity learning, Neurocomputing, Volume 456, 2021, Pages 469-480, https://doi.org/10.1016/j.neucom.2020.09.091. (https://www.sciencedirect.com/science/article/pii/S0925231221002009)

References

  • Arjovsky et al. [2017] Arjovsky, M., Chintala, S., & Bottou, L. (2017). Wasserstein gan. arXiv preprint arXiv:1701.07875.
  • Balaji et al. [2018] Balaji, Y., Sankaranarayanan, S., & Chellappa, R. (2018). Metareg: Towards domain generalization using meta-regularization. In Advances in Neural Information Processing Systems (pp. 998–1008).
  • Ben-David et al. [2010] Ben-David, S., Blitzer, J., Crammer, K., Kulesza, A., Pereira, F., & Vaughan, J. (2010). A theory of learning from different domains. Machine Learning, 79, 151–75. URL: http://www.springerlink.com/content/q6qk230685577n52/.
  • Carlucci et al. [2019] Carlucci, F. M., D’Innocente, A., Bucci, S., Caputo, B., & Tommasi, T. (2019). Domain generalization by solving jigsaw puzzles. In CVPR.
  • Courty et al. [2016] Courty, N., Flamary, R., Tuia, D., & Rakotomamonjy, A. (2016). Optimal transport for domain adaptation. IEEE transactions on pattern analysis and machine intelligence, 39, 1853–65.
  • Deng et al. [2020] Deng, W., Zheng, L., Sun, Y., & Jiao, J. (2020). Rethinking triplet loss for domain adaptation. IEEE Transactions on Circuits and Systems for Video Technology.
  • Ding & Fu [2017] Ding, Z., & Fu, Y. (2017). Deep domain generalization with structured low-rank constraint. IEEE Transactions on Image Processing, 27, 304–13.
  • Dou et al. [2019] Dou, Q., de Castro, D. C., Kamnitsas, K., & Glocker, B. (2019). Domain generalization via model-agnostic learning of semantic features. In Advances in Neural Information Processing Systems (pp. 6447–58).
  • D’Innocente & Caputo [2018] D’Innocente, A., & Caputo, B. (2018). Domain generalization with domain-specific aggregation modules. In German Conference on Pattern Recognition (pp. 187–98). Springer.
  • Finn et al. [2017] Finn, C., Abbeel, P., & Levine, S. (2017). Model-agnostic meta-learning for fast adaptation of deep networks. In Proceedings of the 34th International Conference on Machine Learning-Volume 70 (pp. 1126–35). JMLR.org.
  • Ganin et al. [2016] Ganin, Y., Ustinova, E., Ajakan, H., Germain, P., Larochelle, H., Laviolette, F., Marchand, M., & Lempitsky, V. (2016). Domain-adversarial training of neural networks. The Journal of Machine Learning Research, 17, 2096–30.
  • Ghifary et al. [2015] Ghifary, M., Bastiaan Kleijn, W., Zhang, M., & Balduzzi, D. (2015). Domain generalization for object recognition with multi-task autoencoders. In Proceedings of the IEEE international conference on computer vision (pp. 2551–9).
  • Goldberg et al. [2009] Goldberg, A., Zhu, X., Singh, A., Xu, Z., & Nowak, R. (2009). Multi-manifold semi-supervised learning. In Artificial Intelligence and Statistics (pp. 169–76).
  • Gulrajani et al. [2017] Gulrajani, I., Ahmed, F., Arjovsky, M., Dumoulin, V., & Courville, A. C. (2017). Improved training of wasserstein gans. In Advances in neural information processing systems (pp. 5767–77).
  • Hadsell et al. [2006] Hadsell, R., Chopra, S., & LeCun, Y. (2006). Dimensionality reduction by learning an invariant mapping. In 2006 IEEE Computer Society Conference on Computer Vision and Pattern Recognition (CVPR’06) (pp. 1735–42). IEEE volume 2.
  • Ilse et al. [2019] Ilse, M., Tomczak, J. M., Louizos, C., & Welling, M. (2019). Diva: Domain invariant variational autoencoders. arXiv preprint arXiv:1905.10427.
  • Kamnitsas et al. [2018] Kamnitsas, K., Castro, D. C., Folgoc, L. L., Walker, I., Tanno, R., Rueckert, D., Glocker, B., Criminisi, A., & Nori, A. (2018). Semi-supervised learning via compact latent space clustering. arXiv preprint arXiv:1806.02679.
  • Kingma & Ba [2014] Kingma, D. P., & Ba, J. (2014). Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980.
  • Krizhevsky et al. [2012] Krizhevsky, A., Sutskever, I., & Hinton, G. E. (2012). Imagenet classification with deep convolutional neural networks. In Advances in neural information processing systems (pp. 1097–105).
  • Li et al. [2017a] Li, D., Yang, Y., Song, Y.-Z., & Hospedales, T. (2017a). Deeper, broader and artier domain generalization. In International Conference on Computer Vision.
  • Li et al. [2017b] Li, D., Yang, Y., Song, Y.-Z., & Hospedales, T. M. (2017b). Deeper, broader and artier domain generalization. In Proceedings of the IEEE international conference on computer vision (pp. 5542–50).
  • Li et al. [2018a] Li, D., Yang, Y., Song, Y.-Z., & Hospedales, T. M. (2018a). Learning to generalize: Meta-learning for domain generalization. In Thirty-Second AAAI Conference on Artificial Intelligence.
  • Li et al. [2018b] Li, H., Jialin Pan, S., Wang, S., & Kot, A. C. (2018b). Domain generalization with adversarial feature learning. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (pp. 5400–9).
  • Li et al. [2018c] Li, Y., Tian, X., Gong, M., Liu, Y., Liu, T., Zhang, K., & Tao, D. (2018c). Deep domain generalization via conditional invariant adversarial networks. In Proceedings of the European Conference on Computer Vision (ECCV) (pp. 624–39).
  • Li et al. [2019] Li, Y., Yang, Y., Zhou, W., & Hospedales, T. M. (2019). Feature-critic networks for heterogeneous domain generalization. arXiv preprint arXiv:1901.11448.
  • Long et al. [2018] Long, M., Cao, Z., Wang, J., & Jordan, M. I. (2018). Conditional adversarial domain adaptation. In Advances in Neural Information Processing Systems (pp. 1640–50).
  • Long et al. [2017] Long, M., Cao, Z., Wang, J., & Philip, S. Y. (2017). Learning multiple tasks with multilinear relationship networks. In Advances in neural information processing systems (pp. 1594–603).
  • Ma et al. [2019a] Ma, Z., Chang, D., Xie, J., Ding, Y., Wen, S., Li, X., Si, Z., & Guo, J. (2019a). Fine-grained vehicle classification with channel max pooling modified cnns. IEEE Transactions on Vehicular Technology, 68, 3224–33.
  • Ma et al. [2019b] Ma, Z., Ding, Y., Wen, S., Xie, J., Jin, Y., Si, Z., & Wang, H. (2019b). Shoe-print image retrieval with multi-part weighted cnn. IEEE Access, 7, 59728–36.
  • Ma et al. [2018a] Ma, Z., Lai, Y., Kleijn, W. B., Song, Y.-Z., Wang, L., & Guo, J. (2018a). Variational bayesian learning for dirichlet process mixture of inverted dirichlet distributions in non-gaussian image feature modeling. IEEE transactions on neural networks and learning systems, 30, 449–63.
  • Ma et al. [2013] Ma, Z., Leijon, A., & Kleijn, W. B. (2013). Vector quantization of lsf parameters with a mixture of dirichlet distributions. IEEE Transactions on Audio, Speech, and Language Processing, 21, 1777–90.
  • Ma et al. [2018b] Ma, Z., Yu, H., Chen, W., & Guo, J. (2018b). Short utterance based speech language identification in intelligent vehicles with time-scale modifications and deep bottleneck features. IEEE transactions on vehicular technology, 68, 121–8.
  • Matsuura & Harada [2020] Matsuura, T., & Harada, T. (2020). Domain generalization using a mixture of multiple latent domains. In AAAI.
  • Motiian et al. [2017] Motiian, S., Piccirilli, M., Adjeroh, D. A., & Doretto, G. (2017). Unified deep supervised domain adaptation and generalization. In The IEEE International Conference on Computer Vision (ICCV).
  • Muandet et al. [2013] Muandet, K., Balduzzi, D., & Schölkopf, B. (2013). Domain generalization via invariant feature representation. In International Conference on Machine Learning (pp. 10–8).
  • Oh Song et al. [2016] Oh Song, H., Xiang, Y., Jegelka, S., & Savarese, S. (2016). Deep metric learning via lifted structured feature embedding. In Proceedings of the IEEE conference on computer vision and pattern recognition (pp. 4004–12).
  • Paszke et al. [2019] Paszke, A., Gross, S., Massa, F., Lerer, A., Bradbury, J., Chanan, G., Killeen, T., Lin, Z., Gimelshein, N., Antiga, L. et al. (2019). Pytorch: An imperative style, high-performance deep learning library. In Advances in Neural Information Processing Systems (pp. 8024–35).
  • Redko et al. [2017] Redko, I., Habrard, A., & Sebban, M. (2017). Theoretical analysis of domain adaptation with optimal transport. In Joint European Conference on Machine Learning and Knowledge Discovery in Databases (pp. 737–53). Springer.
  • Schroff et al. [2015] Schroff, F., Kalenichenko, D., & Philbin, J. (2015). Facenet: A unified embedding for face recognition and clustering. In Proceedings of the IEEE conference on computer vision and pattern recognition (pp. 815–23).
  • Shen et al. [2018] Shen, J., Qu, Y., Zhang, W., & Yu, Y. (2018). Wasserstein distance guided representation learning for domain adaptation. In AAAI Conference on Artificial Intelligence.
  • Torralba & Efros [2011] Torralba, A., & Efros, A. A. (2011). Unbiased look at dataset bias. In CVPR 2011 (pp. 1521–8). IEEE.
  • Venkateswara et al. [2017] Venkateswara, H., Eusebio, J., Chakraborty, S., & Panchanathan, S. (2017). Deep hashing network for unsupervised domain adaptation. In (IEEE) Conference on Computer Vision and Pattern Recognition (CVPR).
  • Volpi et al. [2018] Volpi, R., Namkoong, H., Sener, O., Duchi, J. C., Murino, V., & Savarese, S. (2018). Generalizing to unseen domains via adversarial data augmentation. In Advances in Neural Information Processing Systems (pp. 5334–44).
  • Wainwright [2019] Wainwright, M. J. (2019). High-dimensional statistics: A non-asymptotic viewpoint, volume 48. Cambridge University Press.
  • Wang et al. [2019] Wang, X., Han, X., Huang, W., Dong, D., & Scott, M. R. (2019). Multi-similarity loss with general pair weighting for deep metric learning. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (pp. 5022–30).
  • Wen et al. [2019] Wen, J., Zheng, N., Yuan, J., Gong, Z., & Chen, C. (2019). Bayesian uncertainty matching for unsupervised domain adaptation. arXiv preprint arXiv:1906.09693.
  • Wu et al. [2017] Wu, C.-Y., Manmatha, R., Smola, A. J., & Krahenbuhl, P. (2017). Sampling matters in deep embedding learning. In Proceedings of the IEEE International Conference on Computer Vision (pp. 2840–8).
  • Xie et al. [2018] Xie, J., Song, Z., Li, Y., & Ma, Z. (2018). Mobile big data analysis with machine learning. arXiv preprint arXiv:1808.00803.
  • Xu et al. [2018] Xu, P., Yin, Q., Huang, Y., Song, Y.-Z., Ma, Z., Wang, L., Xiang, T., Kleijn, W. B., & Guo, J. (2018). Cross-modal subspace learning for fine-grained sketch-based image retrieval. Neurocomputing, 278, 75–86.
  • Yi et al. [2014] Yi, D., Lei, Z., Liao, S., & Li, S. Z. (2014). Deep metric learning for person re-identification. In 2014 22nd International Conference on Pattern Recognition (pp. 34–9). IEEE.
  • Zhao et al. [2019] Zhao, H., Combes, R. T. d., Zhang, K., & Gordon, G. J. (2019). On learning invariant representation for domain adaptation. arXiv preprint arXiv:1901.09453.
  • Zhou et al. [2020] Zhou, F., Shui, C., Huang, B., Wang, B., & Chaib-draa, B. (2020). Discriminative active learning for domain adaptation. arXiv preprint arXiv:2005.11653.
  • Zhu et al. [2019] Zhu, F., Ma, Z., Li, X., Chen, G., Chien, J.-T., Xue, J.-H., & Guo, J. (2019). Image-text dual neural network with decision strategy for small-sample image classification. Neurocomputing, 328, 182–8.
  • Zhuang et al. [2019] Zhuang, F., Qi, Z., Duan, K., Xi, D., Zhu, Y., Zhu, H., Xiong, H., & He, Q. (2019). A comprehensive survey on transfer learning. arXiv preprint arXiv:1911.02685.