Domain Generalization via Optimal Transport with Metric Similarity Learning
Abstract
Generalizing knowledge to unseen domains, where data and labels are unavailable, is crucial for machine learning. In this chapter, we tackle the domain generalization problem of learning from multiple source domains and generalizing to a target domain with unknown statistics. The crucial idea is to extract the underlying invariant features across all the domains. Previous domain generalization approaches mainly focused on learning invariant features and stacking the learned features from each source domain to generalize to a new target domain while ignoring the label information, which generally leads to indistinguishable features with an ambiguous classification boundary. One possible solution is to constrain label similarity when extracting the invariant features and to take advantage of the label similarities for class-specific cohesion and separation of features across domains. We adopt optimal transport with the Wasserstein distance, which can constrain class-label similarity, for adversarial training, and we further deploy a metric learning objective to leverage the label information for achieving a distinguishable classification boundary. Our empirical results show that our proposed method outperforms most of the baselines. Furthermore, ablation studies demonstrate the effectiveness of each component of our method.
keywords: Domain Generalization, Adversarial Learning, Metric Learning
1 Introduction
Recent years have witnessed the rapid development of machine learning and its successful applications, such as computer vision [Ma et al., 2018a, Zhu et al., 2019, Ma et al., 2019b], natural language processing [Ma et al., 2013, 2018b] and cross-modality learning [Zhu et al., 2019, Xu et al., 2018], with many real-world applications [Xie et al., 2018, Ma et al., 2019a]. Traditional machine learning methods are typically based on the assumption that training and testing datasets come from the same distribution. However, in many real-world applications this assumption may not hold, and performance can degrade rapidly if trained models are deployed to domains different from the training dataset [Ganin et al., 2016]. Moreover, training a high-performance vision system requires a large amount of labelled data, and obtaining such labels may be expensive. Taking a pre-trained robotic vision system as an example, during each deployment task the robot itself (e.g. position and angle), the environment (e.g. weather and illumination) and the camera (e.g. resolution) may result in different image styles. The cost of annotating enough data for each deployment task could be very high.
This kind of problem has been widely addressed by transfer learning (TL) [Zhuang et al., 2019] and domain adaptation (DA) [Ganin et al., 2016]. In DA, a learner usually has access to the labelled source data and unlabelled target data, and it is typically trained to align the feature distributions between the source and target domains. However, sometimes we cannot expect the target data to be accessible to the learner. In the robot example, the distribution divergences (different image styles) from the training to the testing domain can only be identified after the model is trained and deployed. In this scenario, it is unrealistic to collect target samples before deployment. This requires a robot to handle domain divergences even though the target data is absent.
We tackle this kind of problem under the domain generalization (DG) paradigm, in which the learner has access to many source domains (data and corresponding labels) and aims at generalizing to a new (target) domain, where both data and labels are unknown. The goal of DG is to learn a prediction model on training data from the seen source domains so that it can generalize well on the unseen target domain. An underlying assumption behind domain generalization is that there exists a common feature space underlying the multiple known source domains and the unseen target domain. Specifically, we want to learn domain-invariant features across these source domains and then generalize to a new domain. An example of how domain generalization proceeds is illustrated in Fig. 1.

A critical problem in both DG and DA is aligning the domain distributions, which is typically achieved by extracting domain-invariant representations. Previous DA works usually tried to minimize domain discrepancies, such as the KL-divergence or the Maximum Mean Discrepancy (MMD), via adversarial training to achieve domain distribution alignment. Due to the similar problem settings of DA and DG, many previous approaches directly adopt the same adversarial training techniques for DG. For example, an MMD metric is adopted by Li et al. [2018b] as a cross-domain regularizer, and the KL divergence is adopted by Li et al. [2017a] to measure the domain shift for the domain generalization problem. However, the MMD metric is usually implemented in a kernel space, which does not scale well to large applications, and the KL divergence is unbounded, which makes it insufficient for successfully measuring domain shift [Zhao et al., 2019].
Besides, previous domain generalization approaches [Ilse et al., 2019, Ghifary et al., 2015, Li et al., 2018c, D’Innocente & Caputo, 2018, Volpi et al., 2018] mainly focused on applying similar DA techniques to extract the invariant features and on how to stack the learned features from each domain for generalizing to a new domain. These methods usually ignore the label information and can make the features indistinguishable, with ambiguous classification boundaries, known as the semantic misalignment problem [Deng et al., 2020]. A successful generalization should guide the learner not only to align the feature distributions between the domains but also to ensure that samples in the same class lie close to each other while samples from different classes stay apart, i.e., feature compactness [Kamnitsas et al., 2018].
Aiming to solve this, we adopt Optimal Transport (OT) with the Wasserstein distance to align the feature distributions for domain generalization, since it can constrain labelled source samples of the same class to remain close during the transportation process [Courty et al., 2016]. Moreover, information-theoretic metrics such as the KL divergence are not capable of measuring the inherent geometric relations among the different domains [Arjovsky et al., 2017]. In contrast, OT can exactly measure the corresponding geometric properties. Besides, compared with [Ben-David et al., 2010], OT benefits from the advantages of the Wasserstein distance through its gradient property [Arjovsky et al., 2017] and its promising generalization bound [Redko et al., 2017]. Empirical studies [Gulrajani et al., 2017, Shen et al., 2018] also demonstrated the effectiveness of OT for extracting invariant features to align the marginal distributions of different domains.
Furthermore, although the optimal transport process can constrain the labelled samples of the same class to stay close to each other, our preliminary results showed that implementing optimal transport alone for domain generalization is not sufficient for a cohesive and separable classification boundary. The model could still suffer from indistinguishable features (see Fig. 4(c)). In order to train the model to predict well on all the domains, this separable classification boundary should also be achieved in a domain-agnostic manner. That is, for a pair of instances, no matter which domains they come from, they should stay close to each other if they are in the same class and vice-versa. To this end, we further adopt metric learning as an auxiliary objective, leveraging the source-domain label information for a domain-independent, distinguishable classification boundary.
To summarize, we deploy the optimal transport technique with the Wasserstein distance to extract domain-invariant features for domain generalization. To avoid an ambiguous classification boundary, we propose a metric learning strategy to achieve a distinguishable feature space. Combining the two, we propose the Wasserstein Adversarial Domain Generalization (WADG) algorithm.
In order to check the effectiveness of the proposed approach, we test the algorithm on standard benchmarks and compare it with recent domain generalization baselines. The experimental results show that our proposed algorithm outperforms most of the baselines, which confirms its effectiveness. Furthermore, the ablation studies demonstrate the contribution of each component of our algorithm.
2 Related Works
2.1 Domain Generalization
The goal of DG is to learn a model that can extract common knowledge that is shared across source domains and generalize well on the target domain. Compared with DA, the main challenge of DG is that the target domain data is not available during the learning process.
A common framework for DG is to extract the most informative and transferable underlying common features from source instances generated from different distributions and to generalize to the unseen one. This kind of approach holds under the assumption that there exists an underlying invariant feature distribution among all domains, and that consequently such invariant features can generalize well to a target domain. Muandet et al. [2013] implemented MMD as a distribution regularizer and proposed the kernel-based Domain Invariant Component Analysis (DICA) algorithm. An autoencoder-based model was proposed by Ghifary et al. [2015] under a multi-task learning setting to learn domain-invariant features via adversarial training. Li et al. [2018c] proposed an end-to-end deep domain generalization approach by leveraging deep neural networks for domain-invariant representation learning. Motiian et al. [2017] proposed to minimize a semantic alignment loss as well as a separation loss based on deep learning models. Li et al. [2018b] proposed a low-rank Convolutional Neural Network model based on domain-shift-robust deep learning methods.
There are also approaches that tackle the domain generalization problem in a meta-learning manner. To the best of our knowledge, Li et al. [2018a] first proposed to adopt Model-Agnostic Meta-Learning (MAML) [Finn et al., 2017], which back-propagates the gradients of the ordinary loss function on meta-test tasks. As pointed out by Dou et al. [2019], such an approach might lead to a sub-optimal solution, as it is highly abstracted from the feature representations. Balaji et al. [2018] proposed the MetaReg algorithm, in which a regularization function (a weighted loss) is implemented for the classification layer of the model but not for the feature extractor layers. Li et al. [2019] then proposed an auxiliary meta loss computed from the feature extractor. Furthermore, the network architecture of Li et al. [2019] is the widely used feature-critic style model, based on a similar model from the domain-adversarial training technique [Ganin et al., 2016]. Dou et al. [2019] and Matsuura & Harada [2020] also started to implement clustering techniques on the invariant feature space for better classification and showed better performance on the target domain.
2.2 Metric Learning
Metric learning aims to learn a discriminative feature embedding where similar samples are closer while dissimilar samples are further apart [Deng et al., 2020]. Hadsell et al. [2006] proposed the siamese network together with the contrastive loss, which guides instances with the same label to stay close to each other in the feature space and pushes them apart otherwise. Schroff et al. [2015] proposed the triplet loss, aiming to learn a feature space where a positive pair has higher similarity than the negative pair sharing the same anchor, up to a given margin. Oh Song et al. [2016] showed that neither the contrastive loss nor the triplet loss can efficiently explore the full pair-wise relations between instances under the mini-batch training setting. They further proposed the lifted structure loss to fully utilize pair-wise relations across batches. However, it only randomly chooses an equal number of positive pairs as negative ones, and many informative pairs are discarded [Wang et al., 2019], which restricts its ability to find informative pairs. Yi et al. [2014] proposed the binomial deviance loss, which can measure hard pairs. One remarkable work by Wang et al. [2019] combines the advantages of both the lifted structure loss and the binomial deviance loss to leverage pair similarity. They proposed to leverage not only pair-similarities (positive or negative pairs with each other) but also self-similarity, which enables the learner to collect and weight informative pairs (positive or negative) in an iterative (mining and weighting) manner. For a pair of instances, the self-similarity is obtained from the pair itself. Such a multi-similarity measure has been shown to quantify similarity and to cluster the samples more efficiently and accurately.

In the context of domain generalization, Dou et al. [2019] proposed to guide the learner to leverage the local similarity in the semantic feature space, which the authors argued may contain essential domain-independent general knowledge, and adopted the contrastive loss and triplet loss to encourage such clustering. Leveraging the cross-domain class similarity information can encourage the learner to extract robust semantic features regardless of domain, which is useful auxiliary information for the learner. If the learner cannot separate the samples (from different source domains) with domain-independent class-specific cohesion and separation on the domain-invariant feature space, it will still suffer from ambiguous decision boundaries. These ambiguous decision boundaries might still be sensitive to the unseen target domain. Matsuura & Harada [2020] implemented unsupervised clustering on the source domains and showed better classification performance. Our work is orthogonal to these previous works: we propose to enforce a more distinguishable invariant feature space via Wasserstein adversarial training and encourage leveraging the label similarity information for a better classification boundary.
Table 1: Notations and symbols used in this work.

Symbol | Meaning | Symbol | Meaning
---|---|---|---
$F$ | The feature extraction function | $\theta_f$ | Parameter of the feature extraction network
$D$ | The critic function | $\theta_d$ | Parameter of the critic network
$C$ | The classification function | $\theta_c$ | Parameter of the classification network
$m$ | The number of source domains | $\mathbf{x}_j^i$ | The $j$-th instance from the $i$-th domain
$F(\mathbf{x}^i)$ | The extracted features from domain $i$ | $y_j^i$ | The label for the corresponding instance
$\mathbf{S}$ | The similarity matrix | $w_{ij}$ | The weight for similarity $S_{ij}$
$\epsilon$ | Fixed parameter (margin) for pair mining | |
3 Preliminaries and Problem Setup
We start by introducing some preliminaries. In order to better summarize the notation used in this work, we provide the list of notations and symbols in Table 1.
3.1 Notations and Definitions
Following Redko et al. [2017] and Li et al. [2017a], suppose we have $m$ known source domain distributions $\{\mathcal{D}_i\}_{i=1}^{m}$, where domain $\mathcal{D}_i$ contains $N_i$ labelled instances in total, denoted by $\{(\mathbf{x}_j^i, y_j^i)\}_{j=1}^{N_i}$, where $\mathbf{x}_j^i$ is the $j$-th instance from the $i$-th domain and $y_j^i$ is the corresponding label. For a hypothesis class $\mathcal{H}$, the expected risk of a hypothesis $h \in \mathcal{H}$ over a domain distribution $\mathcal{D}$ is the probability that $h$ wrongly predicts on the entire distribution: $\epsilon_{\mathcal{D}}(h) = \mathbb{E}_{(\mathbf{x}, y) \sim \mathcal{D}}\big[\ell(h(\mathbf{x}), y)\big]$, where $\ell$ is the loss function. The empirical risk over $N$ samples is also defined by $\hat{\epsilon}_{\mathcal{D}}(h) = \frac{1}{N}\sum_{j=1}^{N}\ell(h(\mathbf{x}_j), y_j)$.
In the setting of domain generalization, we only have access to the seen source domains but have no information about the target domain. The learner is expected to extract the underlying invariant feature space across the source domains and generalize to a new target domain.
3.2 Optimal Transport and Wasserstein Distance
We follow Redko et al. [2017] and define $c(\mathbf{x}, \mathbf{x}'): \Omega \times \Omega \to \mathbb{R}^{+}$ as the cost function for transporting one unit of mass from $\mathbf{x}$ to $\mathbf{x}'$. The primal form of the Wasserstein distance between two distributions $\mu$ and $\nu$ can then be computed by,
$$W_p^p(\mu, \nu) = \inf_{\gamma \in \Pi(\mu, \nu)} \int_{\Omega \times \Omega} c(\mathbf{x}, \mathbf{x}')^p \, d\gamma(\mathbf{x}, \mathbf{x}'), \qquad (1)$$
where $\gamma$ is a probability coupling on $\Omega \times \Omega$ with marginals $\mu$ and $\nu$, and $\Pi(\mu, \nu)$ denotes the set of all possible couplings. Throughout this paper, we adopt the Wasserstein-1 distance only ($p = 1$).
Computing the primal form of the Wasserstein distance (Eq. 1) is computationally inefficient. Assuming $N$ samples from each distribution, directly solving Eq. 1 as a linear program has a time complexity on the order of $\mathcal{O}(N^3 \log N)$. On the contrary, leveraging the Kantorovich-Rubinstein duality [Wainwright, 2019] of the Wasserstein distance yields a more efficient approximation. Assuming a $1$-Lipschitz-continuous cost function $c(\mathbf{x}, \mathbf{x}') = \|\mathbf{x} - \mathbf{x}'\|$, one can show that for any $1$-Lipschitz function $f$,
$$\mathbb{E}_{\mathbf{x} \sim \mu}[f(\mathbf{x})] - \mathbb{E}_{\mathbf{x} \sim \nu}[f(\mathbf{x})] \le W_1(\mu, \nu).$$
The equality is attained when $f$ reaches the maximum of the left-hand side,
$$W_1(\mu, \nu) = \sup_{\|f\|_L \le 1} \; \mathbb{E}_{\mathbf{x} \sim \mu}[f(\mathbf{x})] - \mathbb{E}_{\mathbf{x} \sim \nu}[f(\mathbf{x})]. \qquad (2)$$
In practice, such a function $f$ can be approximated by a neural network, which allows us to compute the Kantorovich-Rubinstein dual efficiently by estimating the two expectations, with a complexity of only $\mathcal{O}(N)$. Empirically, computing $W_1(\mu, \nu)$ amounts to finding the maximum of the right-hand side of Eq. 2 (a $\sup$ operation). A general neural-network optimizer (e.g., SGD or Adam) can efficiently solve this maximization problem to evaluate the dual value of the $W_1$ distance.
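To make the dual estimation concrete, below is a minimal PyTorch sketch (an illustration, not the authors' implementation) of computing the empirical Kantorovich-Rubinstein dual between two batches of features with a small critic network; the layer sizes mirror the critic architecture described later in Section 5.2, and the gradient penalty of Gulrajani et al. [2017] is assumed here as one common way to approximately enforce the 1-Lipschitz constraint.

```python
import torch
import torch.nn as nn

class Critic(nn.Module):
    """Small MLP approximating the 1-Lipschitz function f in Eq. 2."""
    def __init__(self, feat_dim=2048, hidden=1024):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(feat_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, 1))

    def forward(self, h):
        return self.net(h)

def w1_dual_estimate(critic, feat_p, feat_q):
    """Empirical Kantorovich-Rubinstein dual: E_p[f] - E_q[f] (Eq. 2)."""
    return critic(feat_p).mean() - critic(feat_q).mean()

def gradient_penalty(critic, feat_p, feat_q, device):
    """Soft 1-Lipschitz constraint on interpolated features (WGAN-GP style)."""
    alpha = torch.rand(feat_p.size(0), 1, device=device)
    inter = (alpha * feat_p.detach()
             + (1 - alpha) * feat_q.detach()).requires_grad_(True)
    out = critic(inter)
    grads = torch.autograd.grad(out.sum(), inter, create_graph=True)[0]
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()

# The critic is trained to maximize the dual value (minimize its negative),
# while the feature extractor is trained to minimize the estimated distance.
```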

Optimal transport theory and the Wasserstein distance have recently been investigated in the context of machine learning [Arjovsky et al., 2017], especially in the domain adaptation area [Courty et al., 2016, Zhou et al., 2020]. The general idea of implementing the optimal transport technique for domain generalization across domains is illustrated in Fig. 2. To learn domain-invariant features, the OT technique is implemented to achieve domain alignment. After the OT transition, the invariant features can be generalized to the unseen domain.
3.3 Metric Learning
For a pair of instances $\mathbf{x}_i$ and $\mathbf{x}_j$, a positive pair usually refers to the condition where the two instances have the same label ($y_i = y_j$), while a negative pair usually refers to the condition $y_i \neq y_j$. The central idea of metric learning is to encourage a pair of instances with the same label to be closer, and to push negative pairs apart from each other [Wu et al., 2017].

Following the framework of Wang et al. [2019], we show the general pair-weighting process of metric learning. Assume the feature extractor $F$, parameterized by $\theta_f$, projects an instance $\mathbf{x}_i$ onto a $d$-dimensional normalized space, $F(\mathbf{x}_i) \in \mathbb{R}^d$ with $\|F(\mathbf{x}_i)\|_2 = 1$. Then, for two samples $\mathbf{x}_i$ and $\mathbf{x}_j$, the similarity between them is defined as the inner product of the corresponding feature vectors:
$$S_{ij} = \langle F(\mathbf{x}_i), F(\mathbf{x}_j) \rangle. \qquad (3)$$
Leveraging the cross-domain class similarity information can encourage the learner to extract a classification boundary that holds regardless of domain, which is useful auxiliary information for the learner. We further elaborate on this in section 4.2.
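As a concrete illustration of Eq. 3, the snippet below (a sketch, not the chapter's code) computes the pairwise similarity matrix of a mini-batch by L2-normalizing the extracted features and taking inner products; the batch size and feature dimension in the usage example are arbitrary assumptions.

```python
import torch
import torch.nn.functional as F

def similarity_matrix(features):
    """Pairwise inner products of L2-normalized embeddings (Eq. 3).

    features: (batch_size, feat_dim) tensor from the feature extractor.
    Returns a (batch_size, batch_size) matrix S with S[i, j] = <z_i, z_j>.
    """
    z = F.normalize(features, p=2, dim=1)   # project onto the unit sphere
    return z @ z.t()

# Example: a batch of 8 random 256-dimensional features.
S = similarity_matrix(torch.randn(8, 256))
```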
4 Proposed Method
The high-level idea of the WADG algorithm is to learn a domain-invariant feature space and a domain-agnostic classification boundary. First, we align the marginal distributions of the different source domains via optimal transport, minimizing the Wasserstein distance to achieve the domain-invariant feature space. Then, we adopt a metric learning objective to guide the learner to leverage the class similarity information for a better classification boundary. A general workflow of our method is illustrated in Fig. 3(a). The model contains three major parts: a feature extractor, a classifier and a critic function.
The feature extractor $F$, parameterized by $\theta_f$, extracts features from the different source domains. For a set of instances from domain $\mathcal{D}_i$, we denote the extracted features from domain $i$ as $F(\mathbf{x}^i)$. The classifier $C$, parameterized by $\theta_c$, is expected to learn to predict the labels of instances from all the domains correctly. The critic $D$, parameterized by $\theta_d$, aims to measure the empirical Wasserstein distance between the features from a pair of source domains. For the target domain, all instances and labels are absent during training time.
WADG aims to learn domain-agnostic features with a distinguishable classification boundary. During each training round, the network receives the labelled data from all source domains and trains the classifier in a supervised mode with the classification loss $\mathcal{L}_{cls}$. For the classification process, we use the typical cross-entropy loss over all source domains:
$$\mathcal{L}_{cls} = -\frac{1}{m}\sum_{i=1}^{m}\frac{1}{N_i}\sum_{j=1}^{N_i}\sum_{k=1}^{K} \mathbb{1}[y_j^i = k]\,\log C\big(F(\mathbf{x}_j^i)\big)_k, \qquad (4)$$
where $K$ is the number of classes and $C(F(\mathbf{x}_j^i))_k$ is the predicted probability of class $k$.
Through this, the model learns the category information over all the source domains. The feature extractor is then trained to minimize the estimated Wasserstein distance in an adversarial manner against the critic with an objective $\mathcal{L}_{wd}$. We further adopt a metric learning objective (namely, $\mathcal{L}_{ms}$) to leverage the similarities for a better classification boundary. Our full method then solves the joint loss function,
$$\min_{\theta_f, \theta_c} \max_{\theta_d} \; \mathcal{L}_{cls} + \lambda_1 \mathcal{L}_{wd} + \lambda_2 \mathcal{L}_{ms},$$
where $\mathcal{L}_{wd}$ is the adversarial objective function and $\mathcal{L}_{ms}$ is the metric learning objective function. In the sequel, we elaborate these two objectives in section 4.1 and section 4.2, respectively.



4.1 Adversarial Domain Generalization via Optimal Transport
As optimal transport can constrain labelled source samples of the same class to remain close during the transportation process [Courty et al., 2016], we deploy optimal transport with the Wasserstein distance [Redko et al., 2017, Shen et al., 2018] to align the marginal feature distributions over all the source domains.
A brief workflow of the optimal transport for a pair of source domains is illustrated in Fig. 3(b). The critic function $D$ estimates the empirical Wasserstein distance between each pair of source domains through instances from the empirical sets of domains $\mathcal{D}_i$ and $\mathcal{D}_j$. In practice [Shen et al., 2018], the dual term of Eq. 2 of the Wasserstein distance can be computed by,
$$\mathcal{L}_{wd}^{i,j} = \frac{1}{N_i} \sum_{\mathbf{x} \in \mathcal{D}_i} D\big(F(\mathbf{x})\big) - \frac{1}{N_j} \sum_{\mathbf{x}' \in \mathcal{D}_j} D\big(F(\mathbf{x}')\big). \qquad (5)$$
Since there usually exist more than two source domains in the domain generalization setting, we sum the empirical Wasserstein distances between each pair of source domains,
$$\mathcal{L}_{wd} = \sum_{1 \le i < j \le m} \mathcal{L}_{wd}^{i,j}. \qquad (6)$$
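The sum in Eq. 6 can be implemented by looping over all pairs of source domains and accumulating the empirical dual estimates of Eq. 5. The short sketch below reuses the hypothetical `w1_dual_estimate` helper from the earlier snippet and is one straightforward way to do this, not necessarily the chapter's exact implementation.

```python
from itertools import combinations

def pairwise_wasserstein_loss(critic, domain_features):
    """Sum of empirical W1 dual estimates over all source-domain pairs (Eq. 6).

    domain_features: list of (batch_i, feat_dim) tensors, one per source domain.
    """
    total = 0.0
    for i, j in combinations(range(len(domain_features)), 2):
        total = total + w1_dual_estimate(critic, domain_features[i],
                                         domain_features[j])
    return total
```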
Through this pair-wise optimal transport process, the learner can extract a domain-invariant feature space. We then propose to apply metric learning approaches to leverage the class-label similarity for domain-independent clustering of the extracted features. We introduce the metric learning for domain-agnostic clustering in the next section.
4.2 Metric Learning for Domain Agnostic Classification Boundary
As aforementioned, only aligning the marginal features via adversarial training is not sufficient for DG, since there may still exist an ambiguous decision boundary [Dou et al., 2019]. When predicting on the target domain, the learner may still suffer from this ambiguous decision boundary. To this end, we adopt metric learning techniques [Wang et al., 2019] to help cluster the instances and promote a better prediction boundary for better generalization.
To solve this, in addition to the supervised source classification and the alignment of the marginal distributions across domains with the Wasserstein adversarial training defined above, we further encourage robust domain-independent local clustering by leveraging the label information through the metric learning objective. The brief workflow is illustrated in Fig. 3(c). Specifically, we adopt the metric learning objective to require that images, regardless of their domains, follow two properties: 1) images from the same class are semantically similar and should therefore be mapped nearby in the embedding space (semantic clustering), while 2) instances from different classes should be mapped apart from each other in the embedding space. Since the goal of domain generalization is to learn a hypothesis that predicts well on all the domains, the clustering should also be achieved in a domain-agnostic manner.

To this end, we mix the instances from all the source domains together and encourage the clustering of domain-agnostic features via metric learning techniques, so as to achieve a domain-independent clustering decision boundary. During each training iteration, given a mini-batch of size $B$ from each of the $m$ source domains, we mix all the instances from the domains into a batch $\mathcal{B}$ with total size $mB$. We first measure the relative similarity between the negative and positive pairs, which is introduced in the next sub-section.
4.2.1 Pair Similarity Mining
Assume $\mathbf{x}_i$ is an anchor. A negative pair $\{\mathbf{x}_i, \mathbf{x}_j\}$ and a positive pair $\{\mathbf{x}_i, \mathbf{x}_j\}$ are selected if their similarities satisfy the negative condition and the positive condition, respectively:
$$S_{ij}^{-} > \min_{y_k = y_i} S_{ik} - \epsilon, \qquad S_{ij}^{+} < \max_{y_k \neq y_i} S_{ik} + \epsilon, \qquad (7)$$
where $\epsilon$ is a given margin. Through Eq. 7 and the specified margin $\epsilon$, we obtain a set of negative pairs $\mathcal{N}_i$ and a set of positive pairs $\mathcal{P}_i$ for each anchor. This process (Eq. 7) roughly clusters the instances around each anchor by selecting informative pairs (inside of the margin) and discarding the less informative ones (outside of the margin).
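Since the chapter follows the multi-similarity mining of Wang et al. [2019], a minimal sketch of the mining step of Eq. 7 for a single anchor could look as follows; the default margin value and the handling of anchors without valid pairs are illustrative assumptions.

```python
import torch

def mine_pairs(S, labels, anchor, eps=0.1):
    """Select informative positive/negative pairs for one anchor (Eq. 7).

    Following the multi-similarity mining of Wang et al. [2019]:
    a negative j is kept if S[a, j] > min over positives of S[a, k] - eps,
    a positive j is kept if S[a, j] < max over negatives of S[a, k] + eps.
    """
    pos_mask = labels == labels[anchor]
    pos_mask[anchor] = False                      # exclude the anchor itself
    neg_mask = labels != labels[anchor]

    pos_sims, neg_sims = S[anchor][pos_mask], S[anchor][neg_mask]
    if pos_sims.numel() == 0 or neg_sims.numel() == 0:
        return torch.empty(0), torch.empty(0)     # nothing informative to keep

    kept_neg = neg_sims[neg_sims + eps > pos_sims.min()]
    kept_pos = pos_sims[pos_sims - eps < neg_sims.max()]
    return kept_pos, kept_neg
```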
With such roughly selected informative pairs $\mathcal{P}_i$ and $\mathcal{N}_i$, we then assign the instances different weights. Intuitively, if an instance has a higher similarity with an anchor, it should stay closer to the anchor, and vice-versa. We introduce the weighting process in the next section.
4.2.2 Pair Weighting
For the instances of positive pairs, the more similar they are to the anchor, the higher the weight they should receive, while negative pairs receive lower weights the more dissimilar they are, no matter which domains the instances come from. Through this process, we can push the instances into several groups by measuring their similarities.
For a mini-batch of $|\mathcal{B}|$ instances, computing the similarity between each pair results in a similarity matrix $\mathbf{S} \in \mathbb{R}^{|\mathcal{B}| \times |\mathcal{B}|}$. A loss function based on pair similarity can usually be written as $\mathcal{L}(\mathbf{S}, \mathbf{y})$. Let $S_{ij}$ be the element in the $i$-th row and $j$-th column of $\mathbf{S}$. The gradient with respect to the network parameters $\theta_f$ can be computed by,
$$\frac{\partial \mathcal{L}(\mathbf{S}, \mathbf{y})}{\partial \theta_f} = \sum_{i=1}^{|\mathcal{B}|}\sum_{j=1}^{|\mathcal{B}|} \frac{\partial \mathcal{L}(\mathbf{S}, \mathbf{y})}{\partial S_{ij}} \frac{\partial S_{ij}}{\partial \theta_f}. \qquad (8)$$
Eq. 8 can be reformulated into a new loss function as,
$$\mathcal{F}(\mathbf{S}, \mathbf{y}) = \sum_{i=1}^{|\mathcal{B}|}\sum_{j=1}^{|\mathcal{B}|} \frac{\partial \mathcal{L}(\mathbf{S}, \mathbf{y})}{\partial S_{ij}}\, S_{ij}. \qquad (9)$$
Usually, a metric loss defined on the similarity matrix $\mathbf{S}$ and the labels $\mathbf{y}$ can be reformulated by Eq. 9. The term $\frac{\partial \mathcal{L}(\mathbf{S}, \mathbf{y})}{\partial S_{ij}}$ in Eq. 9 can be treated as a constant scalar since it does not contain the gradient of $\theta_f$. Then, we just need to compute this gradient term for the positive and negative pairs. Since the goal is to encourage the positive pairs to be closer, we can assume the gradient $\frac{\partial \mathcal{L}}{\partial S_{ij}} \le 0$ for $y_i = y_j$, $i \neq j$. Conversely, for a negative pair, we can assume $\frac{\partial \mathcal{L}}{\partial S_{ij}} \ge 0$. Thus, Eq. 9 is transformed into a summation over all the positive pairs ($\mathcal{P}_i$) and negative pairs ($\mathcal{N}_i$),
$$\mathcal{F}(\mathbf{S}, \mathbf{y}) = \sum_{i=1}^{|\mathcal{B}|}\Big(\sum_{j \in \mathcal{P}_i} w_{ij} S_{ij} + \sum_{j \in \mathcal{N}_i} w_{ij} S_{ij}\Big), \qquad (10)$$
where $w_{ij} = \frac{\partial \mathcal{L}(\mathbf{S}, \mathbf{y})}{\partial S_{ij}}$ is regarded as the weight for similarity $S_{ij}$. Since our goal is to encourage the positive pairs to be closer, we can assume the weight for positive pairs is smaller than 0; conversely, for a negative pair, we can assume the weight is larger than 0. The intuition is that, for a negative pair of instances, letting the weight be positive gives it a higher loss value, so the learner can learn to distinguish the two instances. On the contrary, we assign negative weights to the positive pairs, which guides the learner not to pull them apart. For each pair of instances $\{\mathbf{x}_i, \mathbf{x}_j\}$, we assign a different weight according to their similarity $S_{ij}$. We denote $w_{ij}^{+}$ and $w_{ij}^{-}$ as the weight of a positive or negative pair's similarity, respectively.
Previously, Yi et al. [2014] and Wang et al. [2019] applied a soft function for measuring the similarity. We then consider the similarity of the pair itself (the self-similarity), the negative relative similarity and the positive relative similarity. The self-similarity is measured by a soft function of $S_{ij}$ with a small threshold $\lambda$. For a selected negative pair $\{\mathbf{x}_i, \mathbf{x}_j\}$ with $j \in \mathcal{N}_i$, the corresponding weight $w_{ij}^{-}$ (see Eq. 10) is defined by the soft function of the self-similarity together with the negative relative similarity:
$$w_{ij}^{-} = \frac{1}{e^{\beta(\lambda - S_{ij})} + \sum_{k \in \mathcal{N}_i} e^{\beta(S_{ik} - S_{ij})}}. \qquad (11)$$
Similarly, the weight of a positive pair $\{\mathbf{x}_i, \mathbf{x}_j\}$ with $j \in \mathcal{P}_i$ is defined by,
$$w_{ij}^{+} = \frac{1}{e^{\alpha(S_{ij} - \lambda)} + \sum_{k \in \mathcal{P}_i} e^{-\alpha(S_{ik} - S_{ij})}}. \qquad (12)$$
Then, taking Eq. 11 and Eq. 12 into Eq. 10 and integrating Eq. 10 with the similarity mining of Eq. 7, we obtain the objective function for clustering,
$$\mathcal{L}_{ms} = \frac{1}{|\mathcal{B}|}\sum_{i=1}^{|\mathcal{B}|}\bigg\{\frac{1}{\alpha}\log\Big[1 + \sum_{k \in \mathcal{P}_i} e^{-\alpha(S_{ik} - \lambda)}\Big] + \frac{1}{\beta}\log\Big[1 + \sum_{k \in \mathcal{N}_i} e^{\beta(S_{ik} - \lambda)}\Big]\bigg\}, \qquad (13)$$
where $\alpha$, $\beta$ and $\lambda$ are fixed hyper-parameters, which we elaborate in the empirical setting section 5.2. Then, the whole objective of our proposed method is,
$$\min_{\theta_f, \theta_c} \max_{\theta_d} \; \mathcal{L}_{cls} + \lambda_1 \mathcal{L}_{wd} + \lambda_2 \mathcal{L}_{ms}, \qquad (14)$$
where $\lambda_1$ and $\lambda_2$ are coefficients to regularize $\mathcal{L}_{wd}$ and $\mathcal{L}_{ms}$, respectively.
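Eq. 13 has the form of the multi-similarity loss of Wang et al. [2019]; the sketch below combines it with the `mine_pairs` helper from the previous snippet. The hyper-parameter names and default values (`alpha`, `beta`, `lam`, `eps`) are illustrative assumptions, not necessarily the values used in the chapter.

```python
import torch

def multi_similarity_loss(S, labels, eps=0.1, alpha=2.0, beta=50.0, lam=0.5):
    """Multi-similarity loss (Eq. 13): mined pairs (Eq. 7) weighted by soft
    functions of self-, positive- and negative-similarity (Eqs. 11-12)."""
    losses = []
    for a in range(S.size(0)):
        kept_pos, kept_neg = mine_pairs(S, labels, a, eps)   # Eq. 7
        if kept_pos.numel() == 0 or kept_neg.numel() == 0:
            continue
        pos_term = torch.log1p(torch.exp(-alpha * (kept_pos - lam)).sum()) / alpha
        neg_term = torch.log1p(torch.exp(beta * (kept_neg - lam)).sum()) / beta
        losses.append(pos_term + neg_term)
    if not losses:
        return S.new_zeros(())
    return torch.stack(losses).mean()
```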
Based on the above, we propose the WADG algorithm, summarized in Algorithm 1. We show the empirical results in the next section.
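As a rough end-to-end illustration of the joint objective (Eq. 14), the following sketch puts the pieces above together into one training iteration. The alternating critic updates, the optimizer split, the coefficient names and the use of the extractor output (rather than an intermediate classifier layer, as described in Section 5.2) for the similarity matrix are all simplifying assumptions about one reasonable implementation, not the authors' exact procedure.

```python
import torch
import torch.nn.functional as F

def wadg_training_step(feat_net, classifier, critic, batches,
                       opt_main, opt_critic, lambda_wd=1.0, lambda_ms=1.0,
                       critic_steps=5, gp_weight=10.0, device="cpu"):
    """One WADG iteration: supervised loss + pairwise W1 alignment + MS loss.

    batches: list of (x, y) tuples, one mini-batch per source domain.
    Reuses the Critic, w1_dual_estimate, gradient_penalty,
    pairwise_wasserstein_loss, similarity_matrix and multi_similarity_loss
    helpers sketched earlier.
    """
    feats = [feat_net(x.to(device)) for x, _ in batches]
    labels = torch.cat([y.to(device) for _, y in batches])

    # (1) Critic updates: maximize the dual estimate (Eq. 5) with a penalty.
    for _ in range(critic_steps):
        opt_critic.zero_grad()
        critic_loss = 0.0
        for i in range(len(feats)):
            for j in range(i + 1, len(feats)):
                fi, fj = feats[i].detach(), feats[j].detach()
                critic_loss = critic_loss \
                    - w1_dual_estimate(critic, fi, fj) \
                    + gp_weight * gradient_penalty(critic, fi, fj, device)
        critic_loss.backward()
        opt_critic.step()

    # (2) Feature extractor + classifier update: Eq. 14.
    opt_main.zero_grad()
    all_feats = torch.cat(feats)
    cls_loss = F.cross_entropy(classifier(all_feats), labels)              # Eq. 4
    wd_loss = pairwise_wasserstein_loss(critic, feats)                     # Eq. 6
    ms_loss = multi_similarity_loss(similarity_matrix(all_feats), labels)  # Eq. 13
    loss = cls_loss + lambda_wd * wd_loss + lambda_ms * ms_loss
    loss.backward()
    opt_main.step()
    return loss.item()
```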
5 Experiments and Results
5.1 Datasets
In order to evaluate our proposed approach, we run experiments on three commonly used datasets: VLCS [Torralba & Efros, 2011], PACS [Li et al., 2017a] and Office-Home [Venkateswara et al., 2017]. The VLCS dataset contains images from 4 different domains: PASCAL VOC2007 (V), LabelMe (L), Caltech (C), and SUN09 (S). Each domain includes five classes: bird, car, chair, dog and person. The PACS dataset is a recent benchmark dataset for domain generalization. It consists of four domains: Photo (P), Art painting (A), Cartoon (C), Sketch (S), with objects from seven classes: dog, elephant, giraffe, guitar, house, horse, person. Office-Home is a more challenging dataset, which contains four different domains: Art (Ar), Clipart (Cl), Product (Pr) and Real World (Rw), with 65 categories in each domain. Previous work showed that no matter whether an adversarial model is trained in a supervised [Long et al., 2017], semi-supervised [Zhou et al., 2020] or unsupervised [Long et al., 2018] way, the model will suffer from the diverse features of this dataset. Testing our domain generalization model on this dataset thus also helps to affirm the effectiveness of our approach.
5.2 Baselines and Implementation details
To show the effectiveness of our proposed approach, we compare our algorithm on the benchmark datasets with the following recent domain generalization methods.
1. Deep All: We follow the standard evaluation protocol of domain generalization and fine-tune a pre-trained AlexNet or ResNet-18 on the aggregation of all source domains with only the classification loss.
2. TF [Li et al., 2017b]: A low-rank parameterized Convolutional Neural Network model which aims to reduce the total number of model parameters for end-to-end domain generalization training.
3. CIDDG [Li et al., 2018c]: Matches the conditional distribution by changing the class prior.
4. MLDG [Li et al., 2018a]: The meta-learning approach for domain generalization. It runs the meta-optimization on simulated meta-train/meta-test sets with domain shift.
5. CCSA [Motiian et al., 2017]: The contrastive semantic alignment loss is adopted together with the source classification loss for both the domain adaptation and domain generalization problems.
6. MMD-AAE [Li et al., 2018b]: An adversarial autoencoder model trained together with the Maximum Mean Discrepancy to extract a domain-invariant feature space for generalization.
7. D-SAM [D’Innocente & Caputo, 2018]: It aggregates domain-specific modules and merges general and specific information together for generalization.
8. JiGen [Carlucci et al., 2019]: It achieves domain generalization by solving Jigsaw puzzles as an unsupervised auxiliary task.
9. MASF [Dou et al., 2019]: A meta-learning-style method based on MLDG, combined with a contrastive/triplet loss to encourage a domain-independent semantic feature space.
10. MMLD [Matsuura & Harada, 2020]: An approach that mixes all the source domains by assigning pseudo domain labels to extract a domain-independent cluster feature space.
Table 2: Summary of the hyper-parameter values used in this work (including the learning rates for the PACS/VLCS and Office-Home experiments).
Following the general evaluation protocol of domain generalization (e.g., Dou et al. [2019], Matsuura & Harada [2020]) on the PACS and VLCS datasets, we first test our algorithm using an AlexNet [Krizhevsky et al., 2012] backbone with the last layer removed as the feature extractor. For preparing the datasets, we follow the train/val./test split and the data pre-processing protocol of Matsuura & Harada [2020]. As for the classifier, we initialize a three-layer MLP whose input size matches the feature extractor's output and whose output size equals the number of object categories (2048-256-256-$K$), where $K$ is the number of classes. For the critic network, we also adopt a three-layer MLP (2048-1024-1024-1). For the metric learning objective, we use the output of the second layer of the classifier network (with size 256) for computing the similarity.
To better present the hyper-parameters used in this work, we first summarize their values in Table 2. The corresponding descriptions are provided in the following parts.
We adopt the ADAM [Kingma & Ba, 2014] optimizer for training the whole model, with the learning rates and mini-batch size reported in Table 2.
For stable training, we gradually ramp up the coefficient used to regularize the adversarial loss as a function of the training progress $p$. This regularization scheme has been widely used in adversarial-training-based domain adaptation and generalization settings (e.g., [Long et al., 2017, Wen et al., 2019, Matsuura & Harada, 2020]) and has been shown to help stabilize the training process. For the setting of the metric-learning coefficient, we follow the setting of Dou et al. [2019]. In our preliminary validation results, the performance is not sensitive to this coefficient; we also varied it over a range via reverse validation and did not observe obvious differences.
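A small sketch of such a ramp-up schedule is given below; the specific functional form and the constant `gamma` follow the commonly used schedule of Ganin et al. [2016] and are assumptions here rather than the chapter's exact choice.

```python
import math

def adversarial_coefficient(progress, gamma=10.0):
    """Ramp the adversarial-loss coefficient from 0 to 1 as training proceeds.

    progress: fraction of training completed, in [0, 1].
    """
    return 2.0 / (1.0 + math.exp(-gamma * progress)) - 1.0
```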
Table 4: Domain generalization results (accuracy %) on the VLCS dataset (target domains: Caltech, LabelMe, Pascal, Sun, and their average), comparing Deep All, D-MATE [Ghifary et al., 2015], CIDDG [Li et al., 2018c], CCSA [Motiian et al., 2017], SLRC [Ding & Fu, 2017], TF [Li et al., 2017b], MMD-AAE [Li et al., 2018b], D-SAM [D’Innocente & Caputo, 2018], MLDG [Li et al., 2018a], JiGen [Carlucci et al., 2019], MASF [Dou et al., 2019], MMLD [Matsuura & Harada, 2020] and our method.
We then examine our algorithm on the Office-Home benchmark, which is more challenging than the PACS and VLCS datasets. We follow the setting of Carlucci et al. [2019], the most recent work that also evaluates on the Office-Home dataset, to ensure a fair comparison. For the Office-Home dataset, we also use reverse validation to set the learning rate for the whole model (see Table 2). For the remaining hyper-parameters, we keep the same values as in the PACS and VLCS experiments. To avoid over-training, we also adopt early stopping. All the experiments are programmed with PyTorch [Paszke et al., 2019].
5.3 Experiments Results
We first report the empirical results on the PACS and VLCS datasets using AlexNet as the feature extractor in Table 3 and Table 4, respectively. For each generalization task, we train the model on all the source domains, test on the target domain, and report the average of the top-5 accuracy values. The empirical results refer to the average accuracy obtained when training on the source domains and testing on the target domain.
From the empirical results, we can observe that our method outperforms the baselines on both the PACS and VLCS datasets, indicating an improvement in benchmark performance and showing the effectiveness of our method. We then report the empirical results on the Office-Home dataset in Table 5. As stated before, Office-Home is a larger and more challenging dataset that contains more diverse features across classes, and evaluating on it requires a large amount of computational resources. Due to these limits, we follow the evaluation protocol of Carlucci et al. [2019] to report the empirical results. From those results, we can observe that our algorithm outperforms the previous domain generalization methods, which also confirms the effectiveness of our proposed method.
5.4 Further Analysis
To further show the effectiveness of our algorithm, especially with deeper models, following Dou et al. [2019] we also report the results of our algorithm using a ResNet-18 backbone on the PACS dataset in Table 6. With the ResNet-18 backbone, the output feature dimension is 512. From the results, we can observe that our method outperforms the baselines on most generalization tasks as well as in average accuracy.
Then, we conduct ablation studies on each component of our algorithm. We report the empirical results of the ablation studies in Table 7, using both the AlexNet backbone and the ResNet-18 backbone. We compare the following ablations: (1) Deep All: train the feature extractor on the source domain datasets with the classification loss only, i.e., neither optimal transport nor metric learning is adopted; (2) No $\mathcal{L}_{wd}$: train the model with the classification loss and the metric learning loss but without the adversarial training component; (3) w.o. $w^{+}$: omit the positive weighting scheme in $\mathcal{L}_{ms}$; (4) w.o. $w^{-}$: omit the negative weighting scheme in $\mathcal{L}_{ms}$; (5) No $\mathcal{L}_{ms}$: train the model with the classification loss and the adversarial loss but without the metric learning component; (6) WADG-All: train the model with the full objective of Eq. 14.




From the results, we can observe that once we omit the adversarial training, the accuracy drops rapidly with both the AlexNet and ResNet-18 backbones. The contribution of the metric learning loss is relatively small compared with the adversarial loss. Comparing the ablations w.o. $w^{+}$ and w.o. $w^{-}$, we observe almost identical accuracy, which indicates that the positive and negative weighting schemes of the metric learning objective may have equivalent contributions. Once we omit the metric learning loss, the performance also drops with both the AlexNet and ResNet-18 backbones.




Table 7: Ablation study results (accuracy %) on the PACS dataset with the AlexNet and ResNet-18 backbones, over the Art, Cartoon, Sketch and Photo target domains and their average, for the ablations Deep All, No $\mathcal{L}_{wd}$, No $\mathcal{L}_{ms}$, w.o. $w^{+}$, w.o. $w^{-}$ and WADG-All.
Then, to better understand the contribution of each component of our algorithm, the T-SNE visualizations of the ablation studies on the PACS and VLCS datasets are presented in Fig. 4 for the generalization task with target domain Photo and in Fig. 5 for the generalization task with target domain Caltech, respectively. Since our goal is not only to align the feature distributions but also to encourage a cohesive and separable boundary, we report the T-SNE features of all the source domains and the target domain to show the feature alignment and clustering across domains.
For the PACS dataset, as we can see, the T-SNE features of Deep All can neither project the instances from different domains to align with each other nor cluster the features into groups. The T-SNE features of No $\mathcal{L}_{wd}$ show that the metric learning loss can, to some extent, cluster the features, but without the adversarial training the features cannot be aligned well. The T-SNE features of No $\mathcal{L}_{ms}$ show that the adversarial training helps to align the features from different domains but does not yield good clustering performance. The T-SNE features of WADG-All show that the full objective helps not only to align the features from different domains but also to cluster them into several groups, which confirms the effectiveness of our algorithm.
As for the VLCS dataset, we observe similar behaviour in the T-SNE visualization, although the features somewhat overlap with each other. This is because the features in the Caltech domain are relatively easy to learn and predict. As also analyzed in Li et al. [2017a], a supervised model on the Caltech domain can achieve high accuracy, which confirms that the features in the Caltech domain are easy to learn and are therefore more likely to overlap with each other. As we can see from Fig. 5(d), the WADG method helps to separate the features from each other, which again confirms the effectiveness of our proposed method.
6 Conclusion
In this paper, we proposed the Wasserstein Adversarial Domain Generalization algorithm, which not only aligns the source domain features and transfers to an unseen target domain, but also leverages the label information across domains. We first adopt optimal transport with the Wasserstein distance to align the marginal distributions and then adopt a metric learning method to encourage a domain-independent, distinguishable feature space with a clear classification boundary. The experimental results showed that our proposed algorithm outperforms most of the baseline methods on the standard benchmark datasets. Furthermore, the ablation studies and the visualization of the T-SNE features also confirmed the effectiveness of our algorithm.
Acknowledgement
This work has been partially supported by Natural Sciences and Engineering Research Council of Canada (NSERC), The Fonds de recherche du Québec - Nature et technologies (FRQNT). Fan Zhou is supported by China Scholarship Council. Boyu Wang is supported by the Natural Sciences and Engineering Research Council of Canada (NSERC), Discovery Grants Program.
A full version of this preprint has been published as:
Fan Zhou, Zhuqing Jiang, Changjian Shui, Boyu Wang, Brahim Chaib-draa, Domain generalization via optimal transport with metric similarity learning, Neurocomputing, Volume 456, 2021, Pages 469-480, https://doi.org/10.1016/j.neucom.2020.09.091. (https://www.sciencedirect.com/science/article/pii/S0925231221002009)
References
- Arjovsky et al. [2017] Arjovsky, M., Chintala, S., & Bottou, L. (2017). Wasserstein gan. arXiv preprint arXiv:1701.07875.
- Balaji et al. [2018] Balaji, Y., Sankaranarayanan, S., & Chellappa, R. (2018). Metareg: Towards domain generalization using meta-regularization. In Advances in Neural Information Processing Systems (pp. 998–1008).
- Ben-David et al. [2010] Ben-David, S., Blitzer, J., Crammer, K., Kulesza, A., Pereira, F., & Vaughan, J. (2010). A theory of learning from different domains. Machine Learning, 79, 151–75. URL: http://www.springerlink.com/content/q6qk230685577n52/.
- Carlucci et al. [2019] Carlucci, F. M., D’Innocente, A., Bucci, S., Caputo, B., & Tommasi, T. (2019). Domain generalization by solving jigsaw puzzles. In CVPR.
- Courty et al. [2016] Courty, N., Flamary, R., Tuia, D., & Rakotomamonjy, A. (2016). Optimal transport for domain adaptation. IEEE transactions on pattern analysis and machine intelligence, 39, 1853–65.
- Deng et al. [2020] Deng, W., Zheng, L., Sun, Y., & Jiao, J. (2020). Rethinking triplet loss for domain adaptation. IEEE Transactions on Circuits and Systems for Video Technology, (pp. 1–).
- Ding & Fu [2017] Ding, Z., & Fu, Y. (2017). Deep domain generalization with structured low-rank constraint. IEEE Transactions on Image Processing, 27, 304–13.
- Dou et al. [2019] Dou, Q., de Castro, D. C., Kamnitsas, K., & Glocker, B. (2019). Domain generalization via model-agnostic learning of semantic features. In Advances in Neural Information Processing Systems (pp. 6447–58).
- D’Innocente & Caputo [2018] D’Innocente, A., & Caputo, B. (2018). Domain generalization with domain-specific aggregation modules. In German Conference on Pattern Recognition (pp. 187–98). Springer.
- Finn et al. [2017] Finn, C., Abbeel, P., & Levine, S. (2017). Model-agnostic meta-learning for fast adaptation of deep networks. In Proceedings of the 34th International Conference on Machine Learning-Volume 70 (pp. 1126–35). JMLR. org.
- Ganin et al. [2016] Ganin, Y., Ustinova, E., Ajakan, H., Germain, P., Larochelle, H., Laviolette, F., Marchand, M., & Lempitsky, V. (2016). Domain-adversarial training of neural networks. The Journal of Machine Learning Research, 17, 2096–30.
- Ghifary et al. [2015] Ghifary, M., Bastiaan Kleijn, W., Zhang, M., & Balduzzi, D. (2015). Domain generalization for object recognition with multi-task autoencoders. In Proceedings of the IEEE international conference on computer vision (pp. 2551–9).
- Goldberg et al. [2009] Goldberg, A., Zhu, X., Singh, A., Xu, Z., & Nowak, R. (2009). Multi-manifold semi-supervised learning. In Artificial Intelligence and Statistics (pp. 169–76).
- Gulrajani et al. [2017] Gulrajani, I., Ahmed, F., Arjovsky, M., Dumoulin, V., & Courville, A. C. (2017). Improved training of wasserstein gans. In Advances in neural information processing systems (pp. 5767–77).
- Hadsell et al. [2006] Hadsell, R., Chopra, S., & LeCun, Y. (2006). Dimensionality reduction by learning an invariant mapping. In 2006 IEEE Computer Society Conference on Computer Vision and Pattern Recognition (CVPR’06) (pp. 1735–42). IEEE volume 2.
- Ilse et al. [2019] Ilse, M., Tomczak, J. M., Louizos, C., & Welling, M. (2019). Diva: Domain invariant variational autoencoders. arXiv preprint arXiv:1905.10427.
- Kamnitsas et al. [2018] Kamnitsas, K., Castro, D. C., Folgoc, L. L., Walker, I., Tanno, R., Rueckert, D., Glocker, B., Criminisi, A., & Nori, A. (2018). Semi-supervised learning via compact latent space clustering. arXiv preprint arXiv:1806.02679.
- Kingma & Ba [2014] Kingma, D. P., & Ba, J. (2014). Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980.
- Krizhevsky et al. [2012] Krizhevsky, A., Sutskever, I., & Hinton, G. E. (2012). Imagenet classification with deep convolutional neural networks. In Advances in neural information processing systems (pp. 1097–105).
- Li et al. [2017a] Li, D., Yang, Y., Song, Y.-Z., & Hospedales, T. (2017a). Deeper, broader and artier domain generalization. In International Conference on Computer Vision.
- Li et al. [2017b] Li, D., Yang, Y., Song, Y.-Z., & Hospedales, T. M. (2017b). Deeper, broader and artier domain generalization. In Proceedings of the IEEE international conference on computer vision (pp. 5542–50).
- Li et al. [2018a] Li, D., Yang, Y., Song, Y.-Z., & Hospedales, T. M. (2018a). Learning to generalize: Meta-learning for domain generalization. In Thirty-Second AAAI Conference on Artificial Intelligence.
- Li et al. [2018b] Li, H., Jialin Pan, S., Wang, S., & Kot, A. C. (2018b). Domain generalization with adversarial feature learning. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (pp. 5400–9).
- Li et al. [2018c] Li, Y., Tian, X., Gong, M., Liu, Y., Liu, T., Zhang, K., & Tao, D. (2018c). Deep domain generalization via conditional invariant adversarial networks. In Proceedings of the European Conference on Computer Vision (ECCV) (pp. 624–39).
- Li et al. [2019] Li, Y., Yang, Y., Zhou, W., & Hospedales, T. M. (2019). Feature-critic networks for heterogeneous domain generalization. arXiv preprint arXiv:1901.11448.
- Long et al. [2018] Long, M., Cao, Z., Wang, J., & Jordan, M. I. (2018). Conditional adversarial domain adaptation. In Advances in Neural Information Processing Systems (pp. 1640–50).
- Long et al. [2017] Long, M., Cao, Z., Wang, J., & Philip, S. Y. (2017). Learning multiple tasks with multilinear relationship networks. In Advances in neural information processing systems (pp. 1594–603).
- Ma et al. [2019a] Ma, Z., Chang, D., Xie, J., Ding, Y., Wen, S., Li, X., Si, Z., & Guo, J. (2019a). Fine-grained vehicle classification with channel max pooling modified cnns. IEEE Transactions on Vehicular Technology, 68, 3224–33.
- Ma et al. [2019b] Ma, Z., Ding, Y., Wen, S., Xie, J., Jin, Y., Si, Z., & Wang, H. (2019b). Shoe-print image retrieval with multi-part weighted cnn. IEEE Access, 7, 59728–36.
- Ma et al. [2018a] Ma, Z., Lai, Y., Kleijn, W. B., Song, Y.-Z., Wang, L., & Guo, J. (2018a). Variational bayesian learning for dirichlet process mixture of inverted dirichlet distributions in non-gaussian image feature modeling. IEEE transactions on neural networks and learning systems, 30, 449–63.
- Ma et al. [2013] Ma, Z., Leijon, A., & Kleijn, W. B. (2013). Vector quantization of lsf parameters with a mixture of dirichlet distributions. IEEE Transactions on Audio, Speech, and Language Processing, 21, 1777–90.
- Ma et al. [2018b] Ma, Z., Yu, H., Chen, W., & Guo, J. (2018b). Short utterance based speech language identification in intelligent vehicles with time-scale modifications and deep bottleneck features. IEEE transactions on vehicular technology, 68, 121–8.
- Matsuura & Harada [2020] Matsuura, T., & Harada, T. (2020). Domain generalization using a mixture of multiple latent domains. In AAAI.
- Motiian et al. [2017] Motiian, S., Piccirilli, M., Adjeroh, D. A., & Doretto, G. (2017). Unified deep supervised domain adaptation and generalization. In The IEEE International Conference on Computer Vision (ICCV).
- Muandet et al. [2013] Muandet, K., Balduzzi, D., & Schölkopf, B. (2013). Domain generalization via invariant feature representation. In International Conference on Machine Learning (pp. 10–8).
- Oh Song et al. [2016] Oh Song, H., Xiang, Y., Jegelka, S., & Savarese, S. (2016). Deep metric learning via lifted structured feature embedding. In Proceedings of the IEEE conference on computer vision and pattern recognition (pp. 4004–12).
- Paszke et al. [2019] Paszke, A., Gross, S., Massa, F., Lerer, A., Bradbury, J., Chanan, G., Killeen, T., Lin, Z., Gimelshein, N., Antiga, L. et al. (2019). Pytorch: An imperative style, high-performance deep learning library. In Advances in Neural Information Processing Systems (pp. 8024–35).
- Redko et al. [2017] Redko, I., Habrard, A., & Sebban, M. (2017). Theoretical analysis of domain adaptation with optimal transport. In Joint European Conference on Machine Learning and Knowledge Discovery in Databases (pp. 737–53). Springer.
- Schroff et al. [2015] Schroff, F., Kalenichenko, D., & Philbin, J. (2015). Facenet: A unified embedding for face recognition and clustering. In Proceedings of the IEEE conference on computer vision and pattern recognition (pp. 815–23).
- Shen et al. [2018] Shen, J., Qu, Y., Zhang, W., & Yu, Y. (2018). Wasserstein distance guided representation learning for domain adaptation. In AAAI Conference on Artificial Intelligence.
- Torralba & Efros [2011] Torralba, A., & Efros, A. A. (2011). Unbiased look at dataset bias. In CVPR 2011 (pp. 1521–8). IEEE.
- Venkateswara et al. [2017] Venkateswara, H., Eusebio, J., Chakraborty, S., & Panchanathan, S. (2017). Deep hashing network for unsupervised domain adaptation. In (IEEE) Conference on Computer Vision and Pattern Recognition (CVPR).
- Volpi et al. [2018] Volpi, R., Namkoong, H., Sener, O., Duchi, J. C., Murino, V., & Savarese, S. (2018). Generalizing to unseen domains via adversarial data augmentation. In Advances in Neural Information Processing Systems (pp. 5334–44).
- Wainwright [2019] Wainwright, M. J. (2019). High-dimensional statistics: A non-asymptotic viewpoint volume 48. Cambridge University Press.
- Wang et al. [2019] Wang, X., Han, X., Huang, W., Dong, D., & Scott, M. R. (2019). Multi-similarity loss with general pair weighting for deep metric learning. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (pp. 5022–30).
- Wen et al. [2019] Wen, J., Zheng, N., Yuan, J., Gong, Z., & Chen, C. (2019). Bayesian uncertainty matching for unsupervised domain adaptation. arXiv preprint arXiv:1906.09693.
- Wu et al. [2017] Wu, C.-Y., Manmatha, R., Smola, A. J., & Krahenbuhl, P. (2017). Sampling matters in deep embedding learning. In Proceedings of the IEEE International Conference on Computer Vision (pp. 2840–8).
- Xie et al. [2018] Xie, J., Song, Z., Li, Y., & Ma, Z. (2018). Mobile big data analysis with machine learning. arXiv preprint arXiv:1808.00803.
- Xu et al. [2018] Xu, P., Yin, Q., Huang, Y., Song, Y.-Z., Ma, Z., Wang, L., Xiang, T., Kleijn, W. B., & Guo, J. (2018). Cross-modal subspace learning for fine-grained sketch-based image retrieval. Neurocomputing, 278, 75–86.
- Yi et al. [2014] Yi, D., Lei, Z., Liao, S., & Li, S. Z. (2014). Deep metric learning for person re-identification. In 2014 22nd International Conference on Pattern Recognition (pp. 34–9). IEEE.
- Zhao et al. [2019] Zhao, H., Combes, R. T. d., Zhang, K., & Gordon, G. J. (2019). On learning invariant representation for domain adaptation. arXiv preprint arXiv:1901.09453.
- Zhou et al. [2020] Zhou, F., Shui, C., Huang, B., Wang, B., & Chaib-draa, B. (2020). Discriminative active learning for domain adaptation. arXiv preprint arXiv:2005.11653.
- Zhu et al. [2019] Zhu, F., Ma, Z., Li, X., Chen, G., Chien, J.-T., Xue, J.-H., & Guo, J. (2019). Image-text dual neural network with decision strategy for small-sample image classification. Neurocomputing, 328, 182–8.
- Zhuang et al. [2019] Zhuang, F., Qi, Z., Duan, K., Xi, D., Zhu, Y., Zhu, H., Xiong, H., & He, Q. (2019). A comprehensive survey on transfer learning. arXiv preprint arXiv:1911.02685.