
Generating Synthetic Datasets by
Interpolating along Generalized Geodesics

Jiaojiao Fan (work done partly during an internship at Microsoft Research), Georgia Tech
Atlanta, Georgia, USA
David Alvarez-Melis Microsoft Research & Harvard University
Cambridge, Massachusetts, USA
Abstract

Data for pretraining machine learning models often consists of collections of heterogeneous datasets. Although training on their union is reasonable in agnostic settings, it might be suboptimal when the target domain (where the model will ultimately be used) is known in advance. In that case, one would ideally pretrain only on the dataset(s) most similar to the target one. Instead of limiting this choice to those datasets already present in the pretraining collection, here we explore extending this search to all datasets that can be synthesized as ‘combinations’ of them. We define such combinations as multi-dataset interpolations, formalized through the notion of generalized geodesics from optimal transport (OT) theory. We compute these geodesics using a recent notion of distance between labeled datasets, and derive alternative interpolation schemes based on it: using either barycentric projections or optimal transport maps, the latter computed using recent neural OT methods. These methods are scalable and efficient, and, notably, they can be used to interpolate even between datasets with distinct and unrelated label sets. Through various experiments in transfer learning in computer vision, we demonstrate that this is a promising new approach for targeted on-demand dataset synthesis.

1 Introduction

Recent progress in machine learning has been characterized by the rapid adoption of large pretrained models as a fundamental building block [Brown et al., 2020]. These models are typically pretrained on large amounts of general-purpose data and then adapted (e.g., fine-tuned) to a specific task of interest. Such pretraining datasets usually draw from multiple heterogeneous sources, e.g., arising from different domains. Traditionally, all available datasets are used in their entirety during pretraining, either by pooling them into a single dataset (when they all share the same label set) or by training on them sequentially, one by one. Both strategies have important disadvantages. Training on the union of multiple datasets might be prohibitively expensive, and it might even be detrimental: a growing line of research shows that removing pretraining data sometimes improves transfer performance [Jain et al., 2022]. Sequential learning, on the other hand, is notoriously prone to catastrophic forgetting [McCloskey and Cohen, 1989, Kirkpatrick et al., 2017], in which information from earlier datasets gradually vanishes as the model is trained on new ones. The pitfalls of both approaches suggest training instead on a subset of the available pretraining datasets, but how to choose that subset is unclear in general. When the target dataset on which the model will be used is known in advance, however, the answer is much easier: intuitively, one would train only on the datasets relevant to the target, e.g., those most similar to it. Indeed, recent work has shown that selecting pretraining datasets based on their distance to the target is a successful strategy [Alvarez-Melis and Fusi, 2020, Gao and Chaudhari, 2021]. Such methods, however, are limited to selecting among the individual datasets already present in the collection.

In this work, we propose a novel approach to generate synthetic pretraining datasets as combinations of existing ones. In particular, this method searches among all possible continuous combinations of the available datasets and thus is not limited to selecting one of them. Given access to the target dataset of interest, we seek among all such combinations the one closest to the target in terms of a metric between datasets. By viewing each dataset as a sample from an underlying probability distribution, this problem can be understood as a generalization (from Euclidean to probability space) of the problem of finding the point in the convex hull of a set of reference points that is closest to a query point. While this problem has a simple solution in Euclidean space (via an orthogonal projection), solving it in probability space is, as we shall see, significantly more challenging.

We tackle this problem from the perspective of interpolation. Formally, we model the combination of datasets as an interpolation between their distributions, formalized through the notion of geodesics in probability space endowed with the Wasserstein metric [Ambrosio et al., 2008, Villani, 2008]. In particular, we rely on generalized geodesics [Craig, 2016, Ambrosio et al., 2008]: constant-speed curves connecting two (or more) distributions, parametrized with respect to a ‘base’ distribution, whose role is played by the target dataset in our setting. Computing such geodesics requires access to either an optimal transport coupling or an optimal transport map between the base distribution and every other reference distribution. Couplings can be computed very efficiently with off-the-shelf OT solvers, but they only provide correspondences for the samples on which the problem was originally solved. In contrast, OT maps allow on-demand, out-of-sample mapping and can be estimated using recent advances in neural OT methods [Fan et al., 2020, Korotin et al., 2022b, Makkuva et al., 2020]. However, most existing OT methods assume unlabeled (feature-only) distributions, whereas our goal here is to interpolate between classification (i.e., labeled) datasets. Therefore, we leverage a recent generalization of OT to labeled datasets [Alvarez-Melis and Fusi, 2020] to compute couplings, and we adapt neural OT methods to the labeled setting to estimate OT maps.

In summary, the contributions of this paper are: (i) a novel approach to generate new synthetic classification datasets from existing ones by using geodesic interpolations, applicable even if they have disjoint label sets, (ii) two efficient methods to solve OT between labeled datasets, which might be of independent interest, (iii) empirical validation of the method in various transfer learning settings.

2 Related work

Mixup and related In-Domain Interpolation

Generating training data through convex combinations was popularized by mixup [Zhang et al., 2018], a simple data augmentation technique that interpolates features and labels between pairs of points. This and subsequent works [Zhang et al., 2021, Chuang and Mroueh, 2021, Yao et al., 2021] use mixup to improve in-domain model robustness [Zhu et al., 2023] and generalization by increasing the in-distribution diversity of the training data. Although it shares some intuitive principles with mixup, our method interpolates entire datasets, rather than individual datapoints, with the goal of improving across-distribution diversity and therefore out-of-domain generalization.

Dataset synthesis in machine learning

Generating data beyond what is provided as a training dataset is a crucial component of machine learning in practice. Basic transformations such as rotations, cropping, and pixel transformations are used in most state-of-the-art computer vision models. More recently, Generative Adversarial Networks (GANs) have been used to generate synthetic data in various contexts [Bowles et al., 2018, Yoon et al., 2019], a technique that has proven particularly successful in the medical imaging domain [Sandfort et al., 2019]. Since GANs are trained to replicate the dataset on which they are trained, these approaches are typically confined to generating in-distribution diversity and usually operate on features only.

Discrete OT, Neural OT, Gradient Flows

Barycentric projection [Ambrosio et al., 2008, Perrot et al., 2016] is a simple and effective method to approximate an optimal transport map with discrete regularized OT. On the other hand, there has been remarkable recent progress in methods to estimate OT maps in Euclidean space using neural networks [Makkuva et al., 2020, Fan et al., 2021, Rout et al., 2022], which have been successfully used for image generation [Rout et al., 2022] and style transfer [Korotin et al., 2022b], among other applications. However, the estimation of an optimal map between (labeled) datasets has so far received much less attention. Some conditional Monge map solvers [Bunne et al., 2022a] use label information in a semi-supervised manner, assuming that the label-to-label correspondence between the two distributions is known. Our method differs from theirs in that we do not require a pre-specified label-to-label mapping, but instead estimate it from data. Geodesics and interpolation in general metric spaces have been studied extensively in the optimal transport and metric geometry literature [McCann, 1997, Agueh and Carlier, 2011, Ambrosio et al., 2008, Santambrogio, 2015, Villani, 2008, Craig, 2016], albeit mostly in a theoretical setting. Gradient flows [Santambrogio, 2015], increasingly popular in machine learning for modeling existing processes [Bunne et al., 2022b, Mokrov et al., 2021, Fan et al., 2022, Hua et al., 2023] or solving optimization problems over datasets [Alvarez-Melis and Fusi, 2021], provide an alternative approach to interpolation between distributions, but are computationally expensive.

3 Background

3.1 Distribution interpolation with OT

Consider $\mathcal{P}(\mathcal{X})$, the space of probability distributions with finite second moments over some Euclidean space $\mathcal{X}$. Given $\mu,\nu\in\mathcal{P}(\mathcal{X})$, the Monge formulation of the optimal transport problem seeks a map $T:\mathcal{X}\rightarrow\mathcal{X}$ that transforms $\mu$ into $\nu$ at minimal cost. Formally, the objective of this problem is $\min_{T:\,T\sharp\mu=\nu}\int_{\mathcal{X}}\|x-T(x)\|_2^2\,\mathrm{d}\mu(x)$, where the minimization is over all maps that push forward distribution $\mu$ to distribution $\nu$. While a solution to this problem might not exist, a relaxation due to Kantorovich is guaranteed to have one. This modified version yields the Wasserstein-2 distance: $W_2^2(\mu,\nu)=\min_{\pi\in\Pi(\mu,\nu)}\int_{\mathcal{X}\times\mathcal{X}}\|x-x'\|_2^2\,\mathrm{d}\pi(x,x')$, where now the constraint set $\Pi(\mu,\nu)=\{\pi\in\mathcal{P}(\mathcal{X}^2)\mid P_0\sharp\pi=\mu,\ P_1\sharp\pi=\nu\}$ contains all couplings with marginals $\mu$ and $\nu$ ($P_0$ and $P_1$ denote the projections onto the first and second coordinates). The optimal such coupling is known as the OT plan. A celebrated result by Brenier [1991] states that whenever $\mu$ has a density with respect to the Lebesgue measure, the optimal map $T^*$ exists and is unique. In that case, the Kantorovich and Monge formulations coincide, and their solutions are linked by $\pi^*=(\mathrm{Id},T^*)\sharp\mu$, where $\mathrm{Id}$ is the identity map. The Wasserstein-2 distance enjoys many desirable geometric properties compared to other distances between distributions [Ambrosio et al., 2008]. One such property is the characterization of geodesics in probability space [Agueh and Carlier, 2011, Santambrogio, 2015].
When $\mathcal{P}(\mathcal{X})$ is equipped with the metric $W_2$, the unique minimal geodesic between any two distributions $\mu_0$ and $\mu_1$ is fully determined by $\pi$, the optimal transport plan between them, through the relation:

$$\rho_t^D := ((1-t)x+ty)\sharp\pi(x,y),\quad t\in[0,1], \tag{1}$$

known as displacement interpolation. If the Monge map $T^*$ from $\mu_0$ to $\mu_1$ exists, the geodesic can also be written as

$$\rho^M_t := ((1-t)\,\mathrm{Id}+tT^*)\sharp\mu_0,\quad t\in[0,1], \tag{2}$$

and is known as McCann’s interpolation [McCann, 1997]. It is easy to see that $\rho^M_0=\mu_0$ and $\rho^M_1=\mu_1$.
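To make (1) and (2) concrete, the sketch below handles the simplest discrete case: two uniform empirical measures with the same number of points. There the Kantorovich problem reduces to an assignment problem, an optimal plan can be taken to be a permutation, and that permutation plays the role of $T^*$. It relies on SciPy's `linear_sum_assignment`; the helper names are ours and are not from the paper.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def empirical_monge_map(x0, x1):
    """Exact OT between two uniform empirical measures with n points each.

    With equal sizes and uniform weights, solving the assignment problem on the
    squared-Euclidean cost gives a discrete Monge map T*(x0[k]) = x1[cols[k]].
    Returns the mapped points and W_2^2(mu_0, mu_1).
    """
    cost = ((x0[:, None, :] - x1[None, :, :]) ** 2).sum(axis=-1)
    rows, cols = linear_sum_assignment(cost)  # rows == arange(n) for a square cost
    return x1[cols], cost[rows, cols].mean()


def mccann_interpolation(x0, x1, t):
    """Samples of rho_t^M = ((1 - t) Id + t T*) # mu_0, eq. (2)."""
    mapped, _ = empirical_monge_map(x0, x1)
    return (1.0 - t) * x0 + t * mapped
```

Because the curve has constant speed, $W_2^2(\mu_0,\rho^M_t)=t^2\,W_2^2(\mu_0,\mu_1)$, which the discrete version reproduces exactly.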

Such interpolations are only defined between two distributions. When there are $m\geq 2$ marginal distributions $\{\mu_1,\ldots,\mu_m\}$, the Wasserstein barycenter

$$\rho^B_a := \operatorname*{arg\,min}_{\rho}\sum_{i=1}^m a_iW_2^2(\rho,\mu_i),\quad a\in\Delta_{m-1}\subset\mathbb{R}^m_{\geq 0} \tag{3}$$

generalizes McCann’s interpolation [Agueh and Carlier, 2011], where $\Delta_{m-1}$ denotes the probability simplex. Intuitively, the interpolation parameters $a=[a_1,\dots,a_m]$ determine the ‘mixture proportions’ of each dataset in the combination, akin to the weights in a convex combination of points in Euclidean space. In particular, when $a$ is a one-hot vector with $a_i=1$, then $\rho^B_a=\mu_i$, i.e., the barycenter is simply the $i$-th distribution. Barycenters have attracted significant attention in machine learning recently [Srivastava et al., 2018, Korotin et al., 2021], but they remain challenging to compute in high dimensions [Fan et al., 2020, Korotin et al., 2022a].
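Although barycenters are hard to compute in high dimensions, in one dimension they have a well-known closed form: the quantile function of $\rho^B_a$ is the $a$-weighted average of the input quantile functions. The sketch below uses this to show how $a$ acts as mixture proportions. It assumes 1D samples of equal size and is only an illustration, not a method used in the paper.

```python
import numpy as np


def barycenter_1d(samples, a):
    """Wasserstein-2 barycenter, eq. (3), of 1D empirical measures of equal size.

    For n equally weighted points, the quantile function of each input is given
    by its sorted samples, so the barycenter's support is their a-weighted average.
    """
    a = np.asarray(a, dtype=float)
    if np.any(a < 0) or not np.isclose(a.sum(), 1.0):
        raise ValueError("a must lie in the probability simplex")
    quantiles = np.stack([np.sort(s) for s in samples])  # shape (m, n)
    return a @ quantiles  # sorted support points of rho_a^B
```

A one-hot $a$ returns (the sorted samples of) the corresponding $\mu_i$, matching the remark above.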

Another limitation of these interpolation notions is the non-convexity of $W_2^2$ along them. In Euclidean space, given three points $x_1,x_2,y\in\mathbb{R}^d$, the function $t\mapsto\|x_t-y\|_2^2$, where $x_t=(1-t)x_1+tx_2$, is convex. In contrast, in Wasserstein space, neither $t\mapsto W_2^2(\rho^M_t,\nu)$ nor $a\mapsto W_2^2(\rho^B_a,\nu)$ is guaranteed to be convex [Santambrogio, 2017, §4.4]. This complicates their theoretical analysis, e.g., of gradient flows. To circumvent this issue, Ambrosio et al. [2008] introduced the generalized geodesic of $\{\mu_1,\ldots,\mu_m\}$ with base $\nu\in\mathcal{P}(\mathcal{X})$:

$$\rho^G_a := \Big(\sum_{i=1}^m a_iT^*_i\Big)\sharp\nu,\quad a\in\Delta_{m-1}, \tag{4}$$

where $T^*_i$ is the optimal map from $\nu$ to $\mu_i$.

Lemma 1.

The functional $\mu\mapsto W_2^2(\mu,\nu)$ is convex along generalized geodesics with base $\nu$, and $W_2^2(\rho^G_a,\nu)\leq\sum_{i=1}^m a_iW_2^2(\mu_i,\nu)$.

Thus, unlike the barycenter, the generalized geodesic does yield a notion of convexity satisfied by the Wasserstein distance and is easier to compute. The proof of Lemma 1 is postponed to §A. For these reasons, we adopt this notion of interpolation for our approach. It remains to discuss how to use it on (labeled) datasets.
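The sketch below builds samples of the generalized geodesic (4) for uniform empirical measures of equal size, using exact discrete OT maps obtained with SciPy's `linear_sum_assignment`, and checks the inequality of Lemma 1 numerically. The discrete, equal-size setting is an assumption made for illustration. The inequality still holds there, because pairing each base point $x$ with $\sum_i a_iT_i^*(x)$ is a feasible coupling whose cost is at most $\sum_i a_iW_2^2(\mu_i,\nu)$ by convexity of the squared norm.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def _cost(x, y):
    """Squared-Euclidean cost matrix between two point clouds."""
    return ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)


def w2_squared(x, y):
    """Exact W_2^2 between uniform empirical measures with equal sample sizes."""
    cost = _cost(x, y)
    rows, cols = linear_sum_assignment(cost)
    return cost[rows, cols].mean()


def ot_map(base, target):
    """Discrete optimal map T*: base[k] -> target[cols[k]] (assignment problem)."""
    _, cols = linear_sum_assignment(_cost(base, target))
    return target[cols]


def generalized_geodesic(base, targets, a):
    """Samples of rho_a^G = (sum_i a_i T_i*) # nu, eq. (4), one per base point."""
    return sum(ai * ot_map(base, t) for ai, t in zip(a, targets))
```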

3.2 Dataset distance

Consider a dataset $\mathcal{D}_P=\{z^{(i)}\}_{i=1}^N=\{(x^{(i)},y^{(i)})\}_{i=1}^N\overset{\text{i.i.d.}}{\sim}P(x,y)$. The Optimal Transport Dataset Distance (OTDD) [Alvarez-Melis and Fusi, 2020] measures its distance to another dataset $\mathcal{D}_Q$ as:

$$d^2_{\mathrm{OT}}(\mathcal{D}_P,\mathcal{D}_Q)=\min_{\pi\in\Pi(P,Q)}\int\left(\|x-x'\|_2^2+W_2^2(\alpha_y,\alpha_{y'})\right)\mathrm{d}\pi(z,z'), \tag{5}$$

which defines a proper metric between datasets. Here, $\alpha_y$ and $\alpha_{y'}$ are the class-conditional measures corresponding to $P(x|y)$ and $Q(x|y')$. This distance is strongly correlated with transfer learning performance, i.e., the accuracy achieved when training a model on $\mathcal{D}_P$ and then fine-tuning and evaluating it on $\mathcal{D}_Q$. Therefore, it can be used to select pretraining datasets for a given target domain. Henceforth, for simplicity, we abuse notation and use $P$ to denote both a dataset and its underlying distribution. To avoid confusion, we use $\nu$ and $\mu$ to denote distributions in the feature space (typically $\mathbb{R}^d$) and $P,Q$ to denote distributions in the product space of features and labels.
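The ground cost inside (5) adds a feature term and a label term, where each label is represented by its class-conditional distribution. The sketch below builds the full pairwise cost matrix under one simplifying assumption: each $\alpha_y$ is approximated by a Gaussian, so that $W_2^2(\alpha_y,\alpha_{y'})$ has the closed-form Bures-Wasserstein expression. This approximation and all function names are assumptions of the sketch, not necessarily what the OTDD implementation does.

```python
import numpy as np
from scipy.linalg import sqrtm


def gaussian_w2_squared(m1, c1, m2, c2):
    """Closed-form W_2^2 between N(m1, c1) and N(m2, c2) (Bures-Wasserstein)."""
    s2 = np.real(sqrtm(c2))
    cross = np.real(sqrtm(s2 @ c1 @ s2))
    return float(np.sum((m1 - m2) ** 2) + np.trace(c1 + c2 - 2.0 * cross))


def class_gaussians(x, y, eps=1e-6):
    """Mean and (slightly regularized) covariance of each class-conditional alpha_y."""
    d = x.shape[1]
    classes = np.unique(y)
    stats = []
    for c in classes:
        xc = x[y == c]
        cov = np.atleast_2d(np.cov(xc, rowvar=False)) + eps * np.eye(d)
        stats.append((xc.mean(axis=0), cov))
    return classes, stats


def otdd_ground_cost(xp, yp, xq, yq):
    """Matrix of ||x - x'||^2 + W_2^2(alpha_y, alpha_y') for all pairs (z in P, z' in Q)."""
    cls_p, st_p = class_gaussians(xp, yp)
    cls_q, st_q = class_gaussians(xq, yq)
    # Class-to-class label distances: one small matrix instead of one per point pair.
    label_w2 = np.array([[gaussian_w2_squared(*sp, *sq) for sq in st_q] for sp in st_p])
    feat = ((xp[:, None, :] - xq[None, :, :]) ** 2).sum(axis=-1)
    ip = np.searchsorted(cls_p, yp)  # class index of every point in P
    iq = np.searchsorted(cls_q, yq)  # class index of every point in Q
    return feat + label_w2[ip][:, iq]
```

Computing the label term once per class pair and then indexing into it keeps the cost dominated by the feature term, which is why label sets of different sizes pose no problem.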

4 Dataset interpolation along generalized geodesic

Our method consists of two steps: estimating optimal transport maps between the target dataset and all training datasets (§4.1), and using them to generate convex combinations of these datasets by interpolating along generalized geodesics (§4.2). Depending on the dimension of the data, OT maps are estimated either in a feature space or in the original input space. For some downstream applications, we additionally project the target dataset onto the ‘convex hull’ of the training datasets (§4.3).

4.1 Estimating optimal maps between labeled datasets

The OTDD is a special case of the Wasserstein distance, so it is natural to consider the alternative Monge (map-based) formulation of (5). We propose two methods to approximate the OTDD map: one based on entropy-regularized OT and another based on neural OT.

OTDD barycentric projection.

Barycentric projections [Ambrosio et al., 2008, Pooladian and Niles-Weed, 2021] can be efficiently computed for entropy-regularized OT using the Sinkhorn algorithm [Sinkhorn, 1967]. Assume that we have empirical distributions $\nu=\sum_{i=1}^{N_\nu}\frac{1}{N_\nu}\delta_{x_\nu^{(i)}}$ and $\mu=\sum_{i=1}^{N_\mu}\frac{1}{N_\mu}\delta_{x_\mu^{(i)}}$, where $\delta_x$ is the Dirac measure at $x$. Denote all the samples compactly as matrices: $X_\nu=(x_\nu^{(1)},\ldots,x_\nu^{(N_\nu)})^\top\in\mathbb{R}^{N_\nu\times d}$ and $X_\mu=(x_\mu^{(1)},\ldots,x_\mu^{(N_\mu)})^\top\in\mathbb{R}^{N_\mu\times d}$. After solving for the optimal coupling $\pi^*:=\operatorname*{arg\,min}_{\pi\in\Pi(\nu,\mu)}\sum_{i,j}\|x_\nu^{(i)}-x_\mu^{(j)}\|^2\pi(i,j)$, the barycentric projection can be expressed as $T_B(X_\nu)=N_\nu\pi^*X_\mu$. We extend this method to two datasets $Z_Q=\{X_Q,Y_Q\}$ and $Z_P=\{X_P,Y_P\}$, where we additionally have one-hot label data $Y_Q=(y_Q^{(1)},\ldots,y_Q^{(N_Q)})^\top\in\{0,1\}^{N_Q\times C_Q}$ and $Y_P=(y_P^{(1)},\ldots,y_P^{(N_P)})^\top\in\{0,1\}^{N_P\times C_P}$, and $C_Q$ and $C_P$ are the numbers of classes in datasets $Q$ and $P$. We solve for the optimal coupling $\pi^*\in\mathbb{R}_{\geq 0}^{N_Q\times N_P}$ of the OTDD problem (5) following the regularized scheme in Alvarez-Melis and Fusi [2020]. The barycentric projection can then be written as:

$$\mathcal{T}_B(Z_Q)=[N_Q\pi^*X_P,\;N_Q\pi^*Y_P]. \tag{6}$$

The visualization of barycentric projected data appears in Figure 1.
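A minimal NumPy sketch of the barycentric projection (6) follows. It assumes uniform weights and a precomputed cost matrix of shape $(N_Q, N_P)$, for instance the OTDD ground cost. The Sinkhorn loop is a plain textbook version without log-domain stabilization (so very small `reg` values may underflow), not the solver used by Alvarez-Melis and Fusi [2020]; rescaling the cost by its maximum is a choice made for this sketch.

```python
import numpy as np


def sinkhorn(cost, reg=0.05, n_iter=1000):
    """Entropy-regularized OT between uniform marginals (plain Sinkhorn iterations)."""
    n, m = cost.shape
    r, c = np.full(n, 1.0 / n), np.full(m, 1.0 / m)
    K = np.exp(-cost / (reg * cost.max()))  # Gibbs kernel of the rescaled cost
    u, v = np.ones(n), np.ones(m)
    for _ in range(n_iter):
        u = r / (K @ v)    # match row marginals
        v = c / (K.T @ u)  # match column marginals
    return u[:, None] * K * v[None, :]  # coupling, shape (N_Q, N_P)


def otdd_barycentric_projection(coupling, xp, yp_onehot):
    """Eq. (6): each point of Q becomes a weighted average of P's features and labels."""
    weights = coupling / coupling.sum(axis=1, keepdims=True)  # rows sum to 1, as in Figure 1
    return weights @ xp, weights @ yp_onehot
```

Normalizing each row explicitly equals multiplying by $N_Q$ when the row marginals are exact, and stays well defined when Sinkhorn has not fully converged. The mapped labels are soft labels, i.e., convex combinations of $P$'s one-hot labels.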

Figure 1: Visualization of the OTDD barycentric projection on the binary PCAM dataset. We first solve for the optimal coupling $\pi^*\in[0,1]^{N_Q\times N_P}$ of problem (5) using entropy regularization. Next, we map the $i$-th datapoint in the source dataset to a pair consisting of a weighted image and a weighted soft label. The weight vector, extracted from the $i$-th row of the coupling $\pi^*$, is normalized to sum to 1. As a result, the mapped image (or soft label) is a convex combination of all the images (or one-hot labels) in the target dataset.

However, this approach has two important limitations: it cannot naturally map out-of-sample data, and it does not scale well to large datasets (due to its quadratic dependence on sample size). To alleviate the scaling issue, we use a batched version of the OTDD barycentric projection in this paper (see the complexity discussion in §6).

OTDD neural map.

Inspired by recent approaches to estimate Monge maps using neural networks [Rout et al., 2022, Fan et al., 2021], we design a similar framework for the OTDD setting. Fan et al. [2021] approach the Monge OT problem with general cost functions by solving its max-min dual problem

$$\sup_f\inf_T\int\left[c(x,T(x))-f(T(x))\right]\mathrm{d}\nu(x)+\int f(x')\,\mathrm{d}\mu(x').$$

We extend this method to distributions involving labels by introducing an additional classifier in the map. Given two datasets $P,Q$, we parameterize the map $\mathcal{T}_N:\mathbb{R}^d\times[0,1]^{C_Q}\rightarrow\mathbb{R}^d\times[0,1]^{C_P}$ as

$$\mathcal{T}_N(z)=\mathcal{T}_N(x,y)=[\bar{x};\bar{y}]=[G(z);\ell(G(z))], \tag{7}$$

where $G:\mathbb{R}^d\times[0,1]^{C_Q}\rightarrow\mathbb{R}^d$ is the pushforward feature map and $\ell:\mathbb{R}^d\rightarrow[0,1]^{C_P}$ is a frozen classifier pretrained on dataset $P$. Notice that, with the cost $c(z,\mathcal{T}(z))=\|x-G(z)\|_2^2+W_2^2(\alpha_y,\alpha_{\bar{y}})$, the Monge formulation of OTDD (5) reads $\inf_{\mathcal{T}\sharp Q=P}\int\left(\|x-G(z)\|_2^2+W_2^2(\alpha_y,\alpha_{\bar{y}})\right)\mathrm{d}Q(z)$. We therefore propose to solve the max-min dual problem

$$\sup_f\inf_G\int\left[\|x-G(z)\|_2^2+W_2^2(\alpha_y,\alpha_{\bar{y}})\right]\mathrm{d}Q(z)-\int f(\bar{x},\bar{y})\,\mathrm{d}Q(z)+\int f(x',y')\,\mathrm{d}P(z'). \tag{8}$$
Figure 2: Training paradigm for learning the OTDD neural map between two datasets (distributions), parametrized via a pushforward feature map $G$ and a labeling function $\ell$, using a dual potential $f$.

Implementation details are provided in §B. Compared to previous conditional Monge map solvers [Bunne et al., 2022a, Asadulaev et al., 2022], the two methods proposed here: (i) do not assume class overlap across datasets, allowing for maps between datasets with different label sets; (ii) are invariant to class permutation and re-labeling; (iii) do not force one-to-one class alignments, e.g., samples can be mapped across similar classes.
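To show how the pieces of (7) and (8) fit together, the sketch below evaluates the empirical value of the max-min objective for a fixed map $G$, frozen classifier $\ell$, and potential $f$, all passed in as plain callables. In training, $G$ and $f$ would be neural networks updated by alternating gradient steps (minimizing over $G$, maximizing over $f$), which is not shown. How the label term is evaluated for a soft label $\bar{y}$ is not specified here; the sketch assumes the expectation $\sum_k\bar{y}_k\,W_2^2(\alpha_y,\alpha_k)$ over a precomputed class-distance matrix, and the function name is hypothetical.

```python
import numpy as np


def neural_otdd_objective(G, classifier, f, xq, yq, xp, yp_onehot, label_w2):
    """Empirical value of the max-min objective (8) for fixed G, l and f.

    xq, yq        : samples of Q (features, integer labels)
    xp, yp_onehot : samples of P (features, one-hot labels)
    label_w2[c, k]: W_2^2 between class c of Q and class k of P (precomputed)
    ASSUMPTION: for a soft label y_bar, the label cost is the expectation
    sum_k y_bar[k] * label_w2[y, k].
    """
    x_bar = G(xq, yq)              # pushforward feature map G(z), z = (x, y)
    y_bar = classifier(x_bar)      # frozen classifier l: soft labels in the simplex
    transport = np.sum((xq - x_bar) ** 2, axis=1) + np.sum(y_bar * label_w2[yq], axis=1)
    return transport.mean() - f(x_bar, y_bar).mean() + f(xp, yp_onehot).mean()
```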

4.2 Convex combination in dataset space

Figure 3: In few-shot settings, we use pseudo-labels for the test dataset, generated e.g. via kNN using the few-shot examples. If more labeled data from the test dataset is available, we use it instead of the pseudo-labels. The projection dataset has the same number of samples as the test dataset.

Computing generalized geodesics requires constructing convex combinations of datapoints from different datasets. Given a weight vector $a\in\Delta_{m-1}$, features can be naturally combined as $x_a=\sum_{i=1}^m a_ix_i$. Combining labels is not as simple, because: (i) we allow datasets to have different numbers of labels, so adding them directly is not possible; and (ii) we do not assume that different datasets share the same label set, e.g., MNIST (digits) vs. CIFAR10 (objects). Our solution is to represent all labels in a common space by padding them with zeros in all entries corresponding to other datasets. As an example, consider three datasets with $2$, $3$, and $4$ classes, respectively. Given a label vector $y_1\in\mathbb{R}^2$ for the first one, we embed it into $\mathbb{R}^9$ as $\tilde{y}_1=[y_1;\mathbf{0}_3;\mathbf{0}_4]^\top$. Defining $\tilde{y}_2,\tilde{y}_3$ analogously, we compute their combination as $y_a=a_1\tilde{y}_1+a_2\tilde{y}_2+a_3\tilde{y}_3$. This representation is lossless and preserves the distinction between labels of different datasets. Our convex combination is visualized in Figure 3.
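The zero-padding scheme above can be sketched in a few lines. The function names are hypothetical, and the sketch assumes the samples from the different datasets are already row-aligned.

```python
import numpy as np


def pad_labels(y_onehot, dataset_idx, num_classes):
    """Embed labels of dataset `dataset_idx` into the concatenated label space (zeros elsewhere)."""
    offset = int(np.sum(num_classes[:dataset_idx]))
    padded = np.zeros((y_onehot.shape[0], int(np.sum(num_classes))))
    padded[:, offset:offset + num_classes[dataset_idx]] = y_onehot
    return padded


def convex_combination(features, labels, a, num_classes):
    """x_a = sum_i a_i x_i and y_a = sum_i a_i y~_i for row-aligned samples from each dataset."""
    x_a = sum(ai * xi for ai, xi in zip(a, features))
    y_a = sum(ai * pad_labels(yi, i, num_classes)
              for i, (ai, yi) in enumerate(zip(a, labels)))
    return x_a, y_a
```

With `num_classes = [2, 3, 4]`, a label of the first dataset occupies entries 0 and 1 of a length-9 vector, exactly as in the example. When sampling along a generalized geodesic, the $i$-th entry of `features` and `labels` would hold $\mathcal{T}_i^*(z)$ for the same batch of points $z\sim Q$.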

Figure 4: Visualization and comparison of dataset interpolation methods. (a) The reference dataset $Q$ (with color-coded classes) is projected onto the generalized geodesic of the training datasets $P_i$, resulting in $\widehat{P}_a$. (b) 2D visualizations of (left to right): dataset $Q$; the ‘optimal’ interpolated dataset $\widehat{P}_a:=\left(\sum_{i=1}^m\hat{a}_i\mathcal{T}^*_i\right)\sharp Q$ using the true OTDD maps $\mathcal{T}_i^*$; and a naively interpolated dataset $\left(\sum_{i=1}^m\hat{a}_i\mathcal{T}_i\right)\sharp Q$ using randomly generated maps $\mathcal{T}_i$.

4.3 Projection onto generalized geodesic of datasets

We now put together the components in §4.1 and §4.2 to construct generalized geodesics between datasets in two steps. First, we compute OTDD maps $\mathcal{T}_i^*$ between $Q$ and all other datasets $P_i$, $i=1,\ldots,m$, using either the discrete or the neural OT approach. Then, for any interpolation vector $a\in\Delta_{m-1}$, we identify a dataset along the generalized geodesic via

$$P_a:=\Big(\sum_{i=1}^m a_i\mathcal{T}^*_i\Big)\sharp Q. \tag{9}$$

By using the convex combination method in §4.2 for labeled data, we can efficiently sample from $P_a$.

Next, we find the dataset $P^*_a$ that minimizes the distance between $P_a$ and $Q$, i.e., the projection of $Q$ onto the generalized geodesic. We first approach this problem from a Euclidean viewpoint. Suppose there are several distributions $\{\mu_i\}_{i=1}^m$ and an additional distribution $\nu$ on Euclidean space $\mathbb{R}^d$. Lemma 1 guarantees that there exists a unique parameter $a^*$ that minimizes $W_2^2(\rho_a^G,\nu)$. However, finding $a^*$ is far from trivial, because there is no closed-form formula for the map $a\mapsto W_2^2(\rho_a^G,\nu)$, and it can be expensive to calculate $W_2^2(\rho_a^G,\nu)$ for all possible $a$. To solve this problem, we resort to another transport distance: the $(2,\nu)$-transport metric.

Definition 1 (Craig [2016]).

Given distributions $\mu_i,\mu_j$, the $(2,\nu)$-transport metric between them is given by

$$W_{2,\nu}(\mu_i,\mu_j):=\left(\int\|T_i^*(x)-T_j^*(x)\|_2^2\,\mathrm{d}\nu(x)\right)^{1/2},$$

where $T_i^*$ is the optimal map from $\nu$ to $\mu_i$.

When $\nu$ has a density with respect to the Lebesgue measure, $W_{2,\nu}$ is a valid metric [Craig, 2016, Prop. 1.15]. Moreover, we can derive a closed-form formula for the map $a\mapsto W_{2,\nu}^2(\rho_a^G,\nu)$.

Proposition 1.

$$W_{2,\nu}^2(\rho^G_a,\nu)=\sum_{i=1}^m a_iW_{2,\nu}^2(\mu_i,\nu)-\frac{1}{2}\sum_{i\neq j}a_ia_jW_{2,\nu}^2(\mu_i,\mu_j).$$

This equation implies that, given distributions $\{\mu_i\}$ and $\nu$ in Euclidean space, we can trivially find the optimal $a^*$ that minimizes $W_{2,\nu}^2(\rho^G_a,\nu)$ with a quadratic programming solver (we use the implementation at https://github.com/stephane-caron/qpsolvers). The proof (§A) relies on Brenier’s theorem. Inspired by this, we also define a transport metric for datasets:
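Pointwise, Proposition 1 is the Euclidean identity $\|\sum_i a_iu_i\|^2=\sum_i a_i\|u_i\|^2-\frac{1}{2}\sum_{i\neq j}a_ia_j\|u_i-u_j\|^2$, valid for $a\in\Delta_{m-1}$, applied to $u_i=T_i^*(x)-x$ and integrated against $\nu$. Brenier’s theorem is what makes $\sum_i a_iT_i^*$ (a gradient of a convex function) the optimal map from $\nu$ to $\rho^G_a$. The sketch below checks the identity numerically. Arbitrary aligned arrays stand in for the samples $T_i^*(x)$, an assumption made for illustration, since the identity itself does not depend on optimality.

```python
import numpy as np


def transport_metric_sq(ti, tj):
    """Empirical (2, nu)-transport metric squared from aligned samples T_i(x), T_j(x), x ~ nu."""
    return np.mean(np.sum((ti - tj) ** 2, axis=1))


def prop1_rhs(x_nu, mapped, a):
    """Right-hand side of Proposition 1; the optimal map from nu to itself is the identity."""
    m = len(mapped)
    first = sum(a[i] * transport_metric_sq(mapped[i], x_nu) for i in range(m))
    second = sum(a[i] * a[j] * transport_metric_sq(mapped[i], mapped[j])
                 for i in range(m) for j in range(m) if i != j)
    return first - 0.5 * second
```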

Definition 2.

The squared $(2,Q)$-dataset distance is

$$\mathcal{W}^2_{2,Q}(P_i,P_j):=\int\left(\|x_i-x_j\|_2^2+W_2^2(\alpha_{y_i},\alpha_{y_j})\right)\mathrm{d}Q(z),$$

where $[x_i;y_i]=\mathcal{T}_i^*(z)$ and $\mathcal{T}_i^*$ is the OTDD map from $Q$ to $P_i$.
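Given maps applied to a common batch $z\sim Q$, the definition above becomes a simple average over that batch. The sketch below computes all pairwise squared $(2,Q)$-dataset distances. It assumes hard labels expressed in a shared global class index and a precomputed matrix `label_w2` of class-to-class distances; both are assumptions (the barycentric projection, for example, yields soft labels), and the function name is hypothetical.

```python
import numpy as np


def dataset_transport_matrix(mapped_x, mapped_y, label_w2):
    """Pairwise squared (2, Q)-dataset distances (Definition 2) from aligned mapped samples.

    mapped_x[i] : features of T_i*(z) for a fixed batch z ~ Q, shape (N_Q, d)
    mapped_y[i] : integer labels of T_i*(z) in a shared (global) class index
    label_w2    : W_2^2 between the class-conditional distributions of any two classes
    """
    k = len(mapped_x)
    dist = np.zeros((k, k))
    for i in range(k):
        for j in range(i + 1, k):
            feat = np.sum((mapped_x[i] - mapped_x[j]) ** 2, axis=1)
            lab = label_w2[mapped_y[i], mapped_y[j]]  # per-sample label term
            dist[i, j] = dist[j, i] = np.mean(feat + lab)
    return dist
```

Placing $Q$ itself at index 0 (its map to itself is the identity) gives $\mathcal{W}^2_{2,Q}(P_i,Q)$ in the first row and the pairwise terms $\mathcal{W}^2_{2,Q}(P_i,P_j)$ in the remaining block, which are the inputs needed below.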

Denote by $\mathcal{P}_{2,Q}(\mathcal{X}\times\mathcal{P}(\mathcal{X}))$ the set of all probability measures $P$ that satisfy $d_{\mathrm{OT}}(P,Q)<\infty$ and for which the OTDD map from $Q$ to $P$ exists. The following result shows that the $(2,Q)$-dataset distance is a proper distance. The proof is again deferred to §A.

Proposition 2.

$\mathcal{W}_{2,Q}$ is a valid metric on $\mathcal{P}_{2,Q}(\mathcal{X}\times\mathcal{P}(\mathcal{X}))$.

Unfortunately, in this case $\mathcal{W}^2_{2,Q}(P_a,Q)$ does not have an analytic form like before, because Brenier’s theorem may not hold for a general transport cost. However, we still borrow this idea and define an approximate projection $\widehat{P}_a$ as the minimizer of the function

$$\mathcal{W}^2(P_a,Q):=\sum_{i=1}^m a_i\mathcal{W}^2_{2,Q}(P_i,Q)-\frac{1}{2}\sum_{i\neq j}a_ia_j\mathcal{W}^2_{2,Q}(P_i,P_j), \tag{10}$$

which is an analog of Proposition 1. Since $P_a$ is determined by its interpolation weight $a$, finding $\widehat{P}_a$ is equivalent to finding the weight

$$\hat{a}=\operatorname*{arg\,min}_{a\in\Delta_{m-1}}\mathcal{W}^2(P_a,Q), \tag{11}$$

which is a simple quadratic programming problem. Unlike the Wasserstein distance, $\mathcal{W}^2_{2,Q}(\cdot,\cdot)$ is easy to compute once the OTDD maps are available, because it does not involve solving an optimization problem, so it is relatively cheap to find the minimizer of $\mathcal{W}^2(P_a,Q)$. Experimentally, we observe that $\mathcal{W}^2_{2,Q}(P_a,Q)$ is predictive of model transferability across tasks. Figure 4(a) illustrates this projection on toy 3D datasets, color-coded by class.
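Problem (11) is a quadratic program over the simplex. The paper uses the qpsolvers package. As a dependency-free stand-in, the sketch below runs projected gradient descent with the standard sort-based Euclidean projection onto the simplex. Here `b` and `D` hold the terms of (10). The objective is convex on the simplex in the Euclidean setting of Proposition 1, but this is not guaranteed for the dataset analog, where the iteration only reaches a stationary point. The function names and the fixed step $1/\|D\|_2$ are choices of this sketch.

```python
import numpy as np


def project_to_simplex(v):
    """Euclidean projection onto the probability simplex (sort-based algorithm)."""
    u = np.sort(v)[::-1]
    css = np.cumsum(u)
    k = np.arange(1, len(v) + 1)
    rho = np.nonzero(u * k > css - 1.0)[0][-1]
    theta = (css[rho] - 1.0) / (rho + 1.0)
    return np.maximum(v - theta, 0.0)


def solve_interpolation_weights(b, D, n_iter=5000):
    """Minimize F(a) = a.b - 0.5 a^T D a over the simplex, cf. eqs. (10) and (11).

    b[i] = W^2_{2,Q}(P_i, Q); D[i, j] = W^2_{2,Q}(P_i, P_j), symmetric with zero diagonal.
    """
    b, D = np.asarray(b, dtype=float), np.asarray(D, dtype=float)
    step = 1.0 / max(np.linalg.norm(D, 2), 1e-12)  # 1 / Lipschitz constant of the gradient
    a = np.full(len(b), 1.0 / len(b))
    for _ in range(n_iter):
        a = project_to_simplex(a - step * (b - D @ a))  # gradient of F is b - D a
    return a
```

In the Euclidean case of Proposition 1 with point masses, $b_i=\|p_i-q\|^2$ and $D_{ij}=\|p_i-p_j\|^2$, so the solver recovers the barycentric coordinates of the projection of $q$ onto the convex hull of the $p_i$: the Euclidean problem this section generalizes.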

5 Experiments

5.1 Learning OTDD maps

In this section, we visualize the quality of the learnt OTDD maps on both synthetic and realistic datasets.

Synthetic datasets

Figure 4(b) illustrates the role of the optimal map in estimating the projection of a dataset onto the generalized geodesic hull of three others. Using maps $\mathcal{T}_i^*$ estimated via barycentric projection (6) better preserves the four-mode class structure, whereas using non-optimal maps $\mathcal{T}_i$ based on random couplings (as the usual mixup does) destroys the class structure.

*NIST datasets

Figure 5: Datasets generated by pushing forward $Q$ (the EMNIST dataset) towards Fashion-MNIST, MNIST, USPS, and KMNIST using OTDD maps $\mathcal{T}_i$ obtained with the neural OT method described in Section 4.1.

In Figure 5, we provide qualitative results of the OTDD map from the EMNIST (letters) [Cohen et al., 2017] dataset to all other *NIST datasets and the USPS dataset. These results confirm the three traits of the OTDD map mentioned at the end of §4.1.

Figure 6: Relationship between the function $\mathcal{W}^2(P_a,Q)$ and the accuracy of the fine-tuned model. The model trained on the projection dataset $\widehat{P}_a$, i.e., the minimizer of $\mathcal{W}^2(P_a,Q)$, tends to have better generalization accuracy. The training datasets are marked on the vertices of each ternary plot. Each ternary plot is an average of 5 runs with distinct random seeds.

1) We do not assume a known correspondence between source and target labels, so we can map between two unrelated datasets such as EMNIST and Fashion-MNIST. 2) The map is invariant to permutations of the label assignment. For example, we show two different labelings in Figure 7, and the final OTDD map is the same. 3) It does not enforce a label-to-label mapping, but instead follows feature similarity. In Figure 5, we observe many cross-class mappings. For example, when the target domain is the USPS [Hull, 1994] dataset, the lowercase letter "l" is always mapped to the digit 1, whereas the capital letter "L" is mapped to other digits such as 6 or 0, because the map follows feature similarity.

5.2 Transfer learning on *NIST datasets

Next, we use our framework to generate new pretraining datasets for transfer learning. Prior work shows that transfer learning performance can be quite sensitive to the type of test dataset when abundant training data from the test task is available [Zhai et al., 2019, Table 1]. Thus, we focus on the few-shot setting, where only a few labeled samples from the test task are available. We first show that the generalization ability of the trained models is strongly correlated with the distance $\mathcal{W}^2_{2,Q}(P_a,Q)$. Then we compare our framework with several baseline methods.

Setup

Given $m$ labeled pretraining datasets $\{P_i\}$, we consider a few-shot task in which only a limited amount of data from the target domain is labeled, e.g., 5 samples per class. The goal is to find a single dataset, of size comparable to any individual $P_i$, that yields the best generalization to the target domain when a model is pretrained on it and then fine-tuned on the target few-shot data. Here, we search for this training dataset among those generated by generalized geodesics $\{P_a\}$, which can be understood as weighted interpolations of the training datasets $\{P_i\}$. Note that this includes individual datasets as particular cases, when $a$ is a one-hot vector.

Table 1: Pretraining on synthetic data. Each *NIST dataset in turn is treated as the target domain, and a neural net is pretrained on a synthetic dataset generated as a combination of the remaining datasets with one of three interpolation methods. We report 5-shot transfer accuracy (mean ± s.d. over 5 runs). The first baseline (Mixup) creates a synthetic training dataset by mixing across datasets: we randomly sample data from each training dataset and take their convex combination with weights $\hat{a}$ (see Eq. (11)). Since this baseline uses the same convex combination method as §4.2, it is equivalent to our framework with suboptimal (random) pairings in place of the OTDD maps. The other two baselines (bottom two rows) skip transfer learning: they either train the model directly on the few-shot test dataset or apply 1-NN on it.
| Methods | MNIST-M | EMNIST | MNIST | FMNIST | USPS | KMNIST |
|---|---|---|---|---|---|---|
| OTDD barycentric projection | 42.10 ± 4.37 | 67.06 ± 2.55 | 93.74 ± 1.46 | 70.12 ± 3.02 | 86.01 ± 1.50 | 52.55 ± 2.73 |
| OTDD neural map | 40.06 ± 4.75 | 65.32 ± 1.80 | 88.78 ± 3.85 | 70.02 ± 2.59 | 83.80 ± 1.60 | 50.32 ± 3.10 |
| Mixup with weights $\hat{a}$ | 33.85 ± 2.22 | 60.95 ± 1.38 | 88.68 ± 1.57 | 66.74 ± 3.79 | 88.61 ± 2.00 | 48.16 ± 3.38 |
| Train on few-shot dataset | 19.10 ± 3.57 | 53.60 ± 1.18 | 72.80 ± 3.10 | 60.50 ± 3.07 | 80.73 ± 2.07 | 41.67 ± 2.11 |
| 1-NN on few-shot dataset | 20.95 ± 1.39 | 39.70 ± 0.57 | 64.50 ± 3.32 | 60.92 ± 2.42 | 73.64 ± 2.35 | 40.18 ± 3.09 |
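The Mixup baseline of Table 1 can be illustrated with a minimal sketch. This is not the code used for the experiments; it assumes that each dataset is a list of (feature vector, one-hot label) pairs, that one random sample is drawn from each dataset, and that labels are combined as in the convex combination method of §4.2 as we read it (cf. Eq. (26)): each dataset's one-hot label is scaled by its weight and the blocks are concatenated. The function name `mixup_across_datasets` is hypothetical.

```python
import random


def mixup_across_datasets(datasets, weights, rng):
    """Draw one synthetic sample by mixing one random sample per dataset.

    datasets: list of m datasets; each is a list of (features, one_hot_label),
              with features a list of floats of a common length and
              one_hot_label a 0/1 list of that dataset's class count.
    weights:  m nonnegative mixture weights summing to 1 (e.g. a_hat).
    rng:      a random.Random instance, for reproducible draws.
    """
    dim = len(datasets[0][0][0])
    x_mix = [0.0] * dim
    y_mix = []
    for data, w in zip(datasets, weights):
        x, y = rng.choice(data)            # random pairing, no OT map involved
        for k in range(dim):
            x_mix[k] += w * x[k]           # convex combination of features
        y_mix.extend(w * v for v in y)     # this dataset's label block, scaled by w
    return x_mix, y_mix                    # label length = total number of classes


# Toy usage: two tiny "datasets" with 2 and 3 classes.
rng = random.Random(0)
d1 = [([0.0, 0.0], [1, 0]), ([0.0, 0.0], [0, 1])]
d2 = [([1.0, 1.0], [1, 0, 0])]
x, y = mixup_across_datasets([d1, d2], [0.25, 0.75], rng)
```

Because the samples are paired at random, nothing in this construction aligns similar samples across datasets; replacing the random draws with the images of a shared target sample under each map is what distinguishes the map-based interpolation.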

Connection to generalization

The closed-form expression of $W_{2,\nu}^2(\rho_a^G, \nu)$ (Prop. 1) provides a distance between a base distribution $\nu$ and a distribution $\rho_a^G$ along the generalized geodesic in Euclidean space. We study its analog (4.3) for labeled datasets $Q$ and $\{P_i\}$ and visualize it in Figure 6 (first row). To investigate the generalization ability of models trained on different datasets, we discretize the simplex $\Delta_2$ to obtain 36 interpolation parameters $a$ and train a 5-layer LeNet classifier on each $P_a$. We then fine-tune all of these classifiers on the few-shot test dataset $Q$, which has only 20 samples per class. We use the same number of training iterations and fine-tuning iterations across all experiments. The second row of Figure 6 shows the fine-tuning accuracy. Comparing the two rows, we find that accuracy and $\mathcal{W}^2(P_a, Q)$ are strongly (inversely) correlated. This implies that the model trained on the dataset minimizing $\mathcal{W}^2(P_a, Q)$ tends to generalize better. We fix the same colorbar range for all heatmaps across datasets to highlight the impact of the choice of training dataset. A more concrete visualization of the correlation between $\mathcal{W}^2(P_a, Q)$ and accuracy is shown in Figure 11.
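The search over interpolation parameters can be sketched as follows, assuming that (4.3) has the same form as the closed-form expression of Prop. 1, with precomputed squared dataset distances $d_i = \mathcal{W}^2_{2,Q}(P_i, Q)$ and $D_{ij} = \mathcal{W}^2_{2,Q}(P_i, P_j)$. The distance values below are placeholders, not measurements. We also assume a uniform grid with step $1/7$, which is one way to obtain exactly 36 points on $\Delta_2$ (since $(7+1)(7+2)/2 = 36$); the grid actually used is not specified here. Both function names are hypothetical.

```python
from fractions import Fraction
from itertools import product


def simplex_grid(m, k):
    """All a in the (m-1)-simplex with entries in {0, 1/k, ..., 1}."""
    grid = []
    for head in product(range(k + 1), repeat=m - 1):
        rest = k - sum(head)
        if rest >= 0:  # keep only points whose weights sum to exactly 1
            grid.append([Fraction(c, k) for c in head] + [Fraction(rest, k)])
    return grid


def interpolation_objective(a, d, D):
    """sum_i a_i d_i - 1/2 sum_{i != j} a_i a_j D[i][j], the form of Prop. 1.

    d[i]:    squared distance from P_i to the target Q (placeholder here).
    D[i][j]: squared distance between P_i and P_j (placeholder here).
    """
    m = len(a)
    first = sum(a[i] * d[i] for i in range(m))
    cross = sum(a[i] * a[j] * D[i][j] for i in range(m) for j in range(m) if i != j)
    return first - cross / 2


grid = simplex_grid(3, 7)          # 36 points on Delta_2
d = [4.0, 1.0, 3.0]                # placeholder values, not measurements
D = [[0.0, 2.0, 5.0],
     [2.0, 0.0, 4.0],
     [5.0, 4.0, 0.0]]
best = min(grid, key=lambda a: interpolation_objective(a, d, D))
```

At a vertex of the simplex the cross term vanishes, so the objective reduces to the distance of that single dataset to $Q$, consistent with individual datasets being particular cases of $P_a$.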

For some test datasets, the choice of training dataset strongly affects the fine-tuning accuracy. For example, when $Q$ is EMNIST and the training dataset is FMNIST, the fine-tuning accuracy is only ~60%, but it can be improved to >70% by choosing an interpolated dataset closer to MNIST. This is reasonable because MNIST is more similar to EMNIST than FMNIST or USPS are. For other test datasets, such as FMNIST and KMNIST, the difference is less pronounced because all training datasets are far from the test dataset.

Comparison with baselines.

Next, we compare our method with several baseline methods on the *NIST datasets. In each set of experiments, we select one *NIST dataset as the target domain and use the rest for pretraining. We consider a 5-shot task: we randomly choose 5 samples per class as labeled data and treat the remaining samples as unlabeled. Our method first trains a model on $\widehat{P}_a$ and then fine-tunes it on the 5-shot target data. To obtain $\widehat{P}_a$, we use either the barycentric projection or the neural map to approximate the OTDD maps from the test dataset to the training datasets. Our results are shown in the first two rows of Table 1. Overall, transfer learning brings in knowledge from other domains and improves the test accuracy by up to 21%. Among the first three methods in Table 1, training on datasets generated by the OTDD barycentric projection outperforms the others on every dataset except USPS, where the gap is only about 2.6%.
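The 5-shot protocol (5 labeled samples per class, the rest treated as unlabeled) can be sketched as below. This is an illustrative helper under our own assumptions: `few_shot_split` is a hypothetical name, every class is assumed to have at least `shots` samples, and the draw is seeded with Python's `random.Random`.

```python
import random
from collections import defaultdict


def few_shot_split(labels, shots, seed):
    """Pick `shots` random indices per class as labeled; the rest are unlabeled.

    Assumes every class has at least `shots` samples.
    """
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, y in enumerate(labels):
        by_class[y].append(idx)
    labeled = []
    for y in sorted(by_class):                          # deterministic class order
        labeled.extend(rng.sample(by_class[y], shots))  # without replacement
    chosen = set(labeled)
    unlabeled = [i for i in range(len(labels)) if i not in chosen]
    return sorted(labeled), unlabeled


labels = [i % 3 for i in range(30)]   # toy labels: 3 classes, 10 samples each
labeled, unlabeled = few_shot_split(labels, shots=5, seed=0)
```

Changing `seed` yields the different random subsets over which results are averaged.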

5.3 Transfer learning on VTAB datasets

Finally, we apply our method to transfer learning on the large-scale VTAB datasets [Zhai et al., 2019]. In particular, we take the Oxford-IIIT Pet dataset as the target domain and use Caltech101, DTD, and Flowers102 for pretraining. To encode a richer geometry in our interpolation, we embed the datasets using a masked autoencoder (MAE) [He et al., 2022] and learn the OTDD map in this (~200K-dimensional) latent space. Since the OTDD barycentric projection consistently works better than the OTDD neural map (see Table 1), we only use the barycentric projection in this section. We use ResNet-18 as the model architecture and pretrain it on decoded MAE images (for the interpolated dataset) or on the original images (for single datasets). The Mixup baseline, in contrast, operates in pixel space and therefore does not use the embeddings at all.

Table 2: Transfer learning on VTAB datasets. The table shows the relative improvement (w.r.t. a no-transfer baseline) in test accuracy on Oxford-IIIT Pet (mean ± std over 5 runs) when only 1000 randomly selected samples of this dataset are available for fine-tuning. The first three rows show single-pretraining-dataset baselines, followed by the pooling and sub-pooling baselines; the last four rows show methods that pretrain on a synthetic interpolation of the three datasets, generated using Mixup or our proposed OTDD map, with uniform or $\hat{a}$ (see Eq. (11)) dataset interpolation weights. The pooling baseline pretrains on the union of all the pretraining datasets. To construct the sub-pooling pretraining dataset, for each training sample from the target dataset (Pet) we find its 10 nearest neighbors (in embedding space) across all pretraining datasets and label them with that sample's class from the target domain.
| Pre-training | Map | Weights | Rel. Improv. (%) |
|---|---|---|---|
| Caltech101 | - | - | 59.68 ± 41.44 |
| DTD | - | - | -1.17 ± 9.52 |
| Flowers102 | - | - | -2.45 ± 26.25 |
| Pooling | - | - | 28.96 ± 18.29 |
| Sub-pooling | - | - | 3.00 ± 19.10 |
| Interpolation | Mixup | uniform | 33.26 ± 21.30 |
| Interpolation | Mixup | $\hat{a}$ | 51.99 ± 34.10 |
| Interpolation | OTDD | uniform | 82.61 ± 25.93 |
| Interpolation | OTDD | $\hat{a}$ | 95.17 ± 20.57 |
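The sub-pooling baseline described in the Table 2 caption can be sketched as follows: each labeled target sample selects its k nearest pooled pretraining samples in embedding space, and those neighbors inherit its target-domain class. This is a minimal sketch under stated assumptions: Euclidean distance, embeddings as plain lists, and duplicates kept when a pooled sample is a neighbor of several target samples (the caption does not say how such collisions are handled). `sub_pool` is a hypothetical name.

```python
import heapq
import math


def sub_pool(target_embs, target_labels, pool_embs, k=10):
    """Relabel the k nearest pooled pretraining samples of each labeled target
    sample (Euclidean distance in embedding space) with that sample's class.

    Returns (pool_index, target_label) pairs; a pooled sample selected by
    several target samples appears once per selection.
    """
    selected = []
    for emb, label in zip(target_embs, target_labels):
        nearest = heapq.nsmallest(
            k, range(len(pool_embs)), key=lambda j: math.dist(emb, pool_embs[j])
        )
        selected.extend((j, label) for j in nearest)
    return selected


pool = [[0.0], [0.1], [5.0], [5.1]]   # toy embeddings pooled across datasets
pairs = sub_pool([[0.0], [5.0]], ["cat", "dog"], pool, k=2)
```

With 1000 target samples and k = 10, this construction yields at most 10,000 relabeled pretraining samples.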

The pretraining interpolation dataset generated by our method has ‘optimal’ mixture weights $a = (0.43, 0.24, 0.33)$ for (Caltech101, DTD, Flowers102), suggesting a stronger similarity between the first of these and the target domain (Pets). This is consistent with the single-dataset transfer results shown in Table 2. However, their interpolation yields better transfer than any single dataset, particularly when using our full method (interpolating with the OTDD map and optimal mixture weights).

In Table 2, we compute relative improvement per run, and then average these across runs; in other words, we compute the mean of ratios (MoR) rather than the ratio of means (RoM). Our reasoning for doing this was (i) controlling for the ‘hardness’ inherent to the randomly sampled subsets of Pet by relativizing before averaging and (ii) our observation that it is common practice to compute MoR when the denominator and numerator correspond to paired data (as is the case here), and the terms in the sum are sampled i.i.d. (again, satisfied in this case by the randomly sampled subsets of the target domain).
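The difference between the two aggregates can be illustrated with a short sketch. The accuracies below are invented for illustration only, and relative improvement is taken to be $100 \cdot (\text{method} - \text{baseline}) / \text{baseline}$, which we assume matches the definition used for Table 2.

```python
def relative_improvements(method_acc, baseline_acc):
    """Per-run relative improvement (%) of a method over its paired baseline."""
    return [100.0 * (m - b) / b for m, b in zip(method_acc, baseline_acc)]


def mean_of_ratios(method_acc, baseline_acc):
    """MoR: relativize each run first, then average (as in Table 2)."""
    rel = relative_improvements(method_acc, baseline_acc)
    return sum(rel) / len(rel)


def ratio_of_means(method_acc, baseline_acc):
    """RoM: average accuracies first, then relativize once."""
    m_bar = sum(method_acc) / len(method_acc)
    b_bar = sum(baseline_acc) / len(baseline_acc)
    return 100.0 * (m_bar - b_bar) / b_bar


method = [30.0, 30.0]     # hypothetical paired runs (two random subsets)
baseline = [10.0, 30.0]   # the second subset happens to be "easy" for the baseline
mor = mean_of_ratios(method, baseline)   # (200 + 0) / 2 = 100.0
rom = ratio_of_means(method, baseline)   # (30 - 20) / 20 * 100 = 50.0
```

In this toy case the two aggregates differ by a factor of two, and a single run with a strong baseline drags that run's ratio toward zero, which also inflates the spread of MoR across runs.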

The high standard deviations in Table 2 are due to one run (seed 2) in which the no-transfer baseline performed particularly well, while methods such as Caltech101 pretraining and Flowers102 pretraining performed particularly poorly with the same seed.

6 Conclusion and discussion

The method we introduce in this work provides, as shown by our experimental results, a promising new approach to generate synthetic datasets as combinations of existing ones. Crucially, our method allows one to combine datasets even if their label sets are different, and is grounded in principled and well-understood concepts from optimal transport theory. Two key applications of this approach that we envision are:

  • Pretraining data enrichment. Given a collection of classification datasets, generate additional interpolated datasets to increase diversity, with the aim of achieving better out-of-distribution generalization. This could be done even without knowledge of the specific target domain (as we do here) by selecting various datasets to play the role of the ‘reference’ distribution.

  • On-demand optimized synthetic data generation. Generate a synthetic dataset, by combining existing ones, that is ‘optimized’ for transferring a model to a new (data-limited) target domain.

Complexity

The complexity of computing the OTDD barycentric projection with the Sinkhorn algorithm is $\mathcal{O}(N^2)$ [Dvurechensky et al., 2018], where $N$ is the number of samples in each dataset. This can be expensive for large-scale datasets. In practice, we compute a batched barycentric projection, i.e., we take a batch from both datasets and solve the projection from the source batch to the target batch; we typically fix the batch size $B$ to $10^4$. This reduces the complexity from $\mathcal{O}(N^2)$ to $\mathcal{O}(BN)$. The complexity of computing the OTDD neural map is $\mathcal{O}(BKH)$, where $K$ is the number of iterations and $H$ is the size of the network. We always choose $K = \mathcal{O}(N)$ in the experiments. The complexity of computing all the $(2,Q)$-dataset distances in (4.3) is $\mathcal{O}(m^2 N)$, since we need the dataset distance between each pair of training datasets. Putting these pieces together, the complexity of approximating the interpolation parameter $\hat{a}$ for the minimizer of (4.3) is $\mathcal{O}(N(B + m^2))$.
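The batching argument can be made concrete with a small sketch of an entropic barycentric projection. It is not the OTDD implementation used in our experiments (see Appendix B): as simplifying assumptions, the cost is the squared Euclidean distance between features only (the OTDD cost additionally includes a label-to-label Wasserstein term), both batches carry uniform weights, and the values of `eps` and `iters` are illustrative. All function names are hypothetical.

```python
import math
import random


def sinkhorn_plan(xs, ys, eps=0.5, iters=200):
    """Entropic OT plan between uniform empirical measures on xs and ys.

    Cost is squared Euclidean distance only (no OTDD label term). Updates are
    done in the log domain; eps must be scaled to the cost, since a very small
    eps makes plain Sinkhorn converge slowly.
    """
    n, m = len(xs), len(ys)
    cost = [[sum((a - b) ** 2 for a, b in zip(x, y)) for y in ys] for x in xs]
    log_a, log_b = -math.log(n), -math.log(m)      # uniform marginals

    def lse(values):                                # stable log-sum-exp
        top = max(values)
        return top + math.log(sum(math.exp(v - top) for v in values))

    f, g = [0.0] * n, [0.0] * m
    for _ in range(iters):
        f = [eps * (log_a - lse([(g[j] - cost[i][j]) / eps for j in range(m)]))
             for i in range(n)]
        g = [eps * (log_b - lse([(f[i] - cost[i][j]) / eps for i in range(n)]))
             for j in range(m)]
    return [[math.exp((f[i] + g[j] - cost[i][j]) / eps) for j in range(m)]
            for i in range(n)]


def barycentric_projection(xs, ys, **kwargs):
    """Map each x_i to sum_j P_ij y_j / sum_j P_ij under the entropic plan P."""
    plan = sinkhorn_plan(xs, ys, **kwargs)
    dim = len(ys[0])
    return [[sum(p * y[k] for p, y in zip(row, ys)) / sum(row) for k in range(dim)]
            for row in plan]


def batched_projection(xs, ys, batch_size, seed=0, **kwargs):
    """Project xs one batch at a time, each against a random batch of ys.

    Each call costs O(batch_size**2) per Sinkhorn iteration, so covering all
    N source points costs O(batch_size * N) per iteration instead of O(N**2).
    """
    rng = random.Random(seed)
    out = []
    for start in range(0, len(xs), batch_size):
        src = xs[start:start + batch_size]
        tgt = rng.sample(ys, min(batch_size, len(ys)))
        out.extend(barycentric_projection(src, tgt, **kwargs))
    return out


xs = [[0.0], [1.0]]
ys = [[0.0], [1.0]]
proj = barycentric_projection(xs, ys)          # entropic blur pulls points inward
batched = batched_projection(xs, ys, batch_size=2)
```

With `eps = 0.5`, the two source points are projected to about 0.12 and 0.88 rather than exactly 0 and 1; a smaller `eps` sharpens the projection at the price of slower convergence.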

Memory

As the number of pretraining tasks ($m$) increases, our method, which generates an interpolated label by concatenating the labels from all tasks, creates increasingly sparse label vectors. Consequently, the memory demands of the classifier’s output layer, whose size grows with $m$ (it equals the total number of classes across all tasks), could rise significantly.

Barycentric projection vs Neural map

These two versions of our method offer complementary advantages. While estimating the OT map allows for easy out-of-sample mapping and continuous generation, the barycentric projection approach often yields better downstream performance (Table 1). We hypothesize this is due to the barycentric projection relying on (re-weighted) real data, while the neural map generates data which might be noisy or imperfect.

Pixel space vs feature space

We present results with OTDD mapping in both pixel space (§5.2) and feature space (§5.3). For the VTAB datasets with regular-sized images (e.g., $256 \times 256 \times 3$), we found that feature space is more appropriate for measuring data distances. For small-scale images such as those in the NIST datasets, feature space may be overkill because most foundation models are trained on larger images. In preliminary experiments with the NIST datasets, we attempted a feature-space approach using an off-the-shelf ResNet-18 model. However, we had difficulty achieving convergence when training OTDD neural maps on PyTorch ResNet-18 features.

High variance issue

Our method is not limited to the data-scarce regime, but this is the most interesting one from a transfer learning perspective, which is why we assume limited labeled data (but potentially much more unlabeled data) from the target domain distribution. This is a typical few-shot learning scenario. The quality of a learned OT map will likely depend on the number of samples used to fit it, and the map might suffer from high variance. To mitigate this in our setting, we augment our dataset with additional pseudo-labeled data generated via kNN (Fig. 3). Recall that we do have access to more unlabeled data from the target domain, which is a common situation in practice.
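A kNN pseudo-labeling step of this kind can be sketched as a majority vote among the nearest labeled target samples. This is a minimal sketch, not the exact procedure of Fig. 3: the value of k, the Euclidean distance, and the tie-breaking rule are assumptions for illustration, and `knn_pseudo_label` is a hypothetical name.

```python
import heapq
import math
from collections import Counter


def knn_pseudo_label(labeled_x, labeled_y, unlabeled_x, k=3):
    """Give each unlabeled point the majority label of its k nearest labeled
    points (Euclidean). Neighbors are visited nearest-first and
    Counter.most_common keeps insertion order among equal counts, so ties go
    to the class whose member is closest."""
    pseudo = []
    for x in unlabeled_x:
        nearest = heapq.nsmallest(
            k, range(len(labeled_x)), key=lambda j: math.dist(x, labeled_x[j])
        )
        votes = Counter(labeled_y[j] for j in nearest)
        pseudo.append(votes.most_common(1)[0][0])
    return pseudo


labeled_x = [[0.0], [0.2], [1.0], [1.2]]
labeled_y = ["a", "a", "b", "b"]
pseudo = knn_pseudo_label(labeled_x, labeled_y, [[0.1], [1.1]], k=3)
```

The pseudo-labeled points can then be added to the few labeled target samples before fitting the OT map, which is the variance-reduction role described above.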

Limitations

Our method for generating a synthetic dataset relies on solving OTDD maps from the test dataset to each training dataset. These OTDD maps are tailored to the given test dataset and cannot be reused for a new test dataset. Another limitation is that our framework relies on a training and fine-tuning pipeline, which can be resource-demanding for large-scale models such as GPT [Brown et al., 2020]. Finally, if at least one of the datasets is imbalanced, our OTDD map will struggle to match classes with similar marginal distributions.

Acknowledgements.
We thank Yongxin Chen and Nicolò Fusi for their invaluable comments, ideas, and feedback. We extend our gratitude to the anonymous reviewers for their useful feedback that significantly improved this work.

References

  • Agueh and Carlier [2011] M. Agueh and G. Carlier. Barycenters in the wasserstein space. SIAM Journal on Mathematical Analysis, 43(2):904–924, 2011.
  • Alvarez-Melis and Fusi [2020] D. Alvarez-Melis and N. Fusi. Geometric dataset distances via optimal transport. Advances in Neural Information Processing Systems, 33:21428–21439, 2020.
  • Alvarez-Melis and Fusi [2021] D. Alvarez-Melis and N. Fusi. Dataset dynamics via gradient flows in probability space. In International Conference on Machine Learning, pages 219–230. PMLR, 2021.
  • Ambrosio et al. [2008] L. Ambrosio, N. Gigli, and G. Savaré. Gradient flows: in metric spaces and in the space of probability measures. Springer Science & Business Media, 2008.
  • Asadulaev et al. [2022] A. Asadulaev, A. Korotin, V. Egiazarian, and E. Burnaev. Neural optimal transport with general cost functionals. arXiv preprint arXiv:2205.15403, 2022.
  • Bowles et al. [2018] C. Bowles, L. Chen, R. Guerrero, P. Bentley, R. Gunn, A. Hammers, D. A. Dickie, M. V. Hernández, J. Wardlaw, and D. Rueckert. GAN augmentation: Augmenting training data using generative adversarial networks, Oct. 2018.
  • Brenier [1991] Y. Brenier. Polar factorization and monotone rearrangement of vector-valued functions. Communications on pure and applied mathematics, 44(4):375–417, 1991.
  • Brown et al. [2020] T. Brown, B. Mann, N. Ryder, M. Subbiah, J. D. Kaplan, P. Dhariwal, A. Neelakantan, P. Shyam, G. Sastry, A. Askell, S. Agarwal, A. Herbert-Voss, G. Krueger, T. Henighan, R. Child, A. Ramesh, D. Ziegler, J. Wu, C. Winter, C. Hesse, M. Chen, E. Sigler, M. Litwin, S. Gray, B. Chess, J. Clark, C. Berner, S. McCandlish, A. Radford, I. Sutskever, and D. Amodei. Language models are Few-Shot learners. In H. Larochelle, M. Ranzato, R. Hadsell, M. F. Balcan, and H. Lin, editors, Advances in Neural Information Processing Systems, volume 33, pages 1877–1901. Curran Associates, Inc., 2020.
  • Bunne et al. [2022a] C. Bunne, A. Krause, and M. Cuturi. Supervised training of conditional monge maps. arXiv preprint arXiv:2206.14262, 2022a.
  • Bunne et al. [2022b] C. Bunne, L. Papaxanthos, A. Krause, and M. Cuturi. Proximal optimal transport modeling of population dynamics. In G. Camps-Valls, F. J. R. Ruiz, and I. Valera, editors, Proceedings of The 25th International Conference on Artificial Intelligence and Statistics, volume 151 of Proceedings of Machine Learning Research, pages 6511–6528. PMLR, 2022b.
  • Chuang and Mroueh [2021] C.-Y. Chuang and Y. Mroueh. Fair mixup: Fairness via interpolation. In International Conference on Learning Representations, 2021.
  • Cohen et al. [2017] G. Cohen, S. Afshar, J. Tapson, and A. Van Schaik. Emnist: Extending mnist to handwritten letters. In 2017 international joint conference on neural networks (IJCNN), pages 2921–2926. IEEE, 2017.
  • Craig [2016] K. Craig. The exponential formula for the wasserstein metric. ESAIM: Control, Optimisation and Calculus of Variations, 22(1):169–187, 2016.
  • Dvurechensky et al. [2018] P. Dvurechensky, A. Gasnikov, and A. Kroshnin. Computational optimal transport: Complexity by accelerated gradient descent is better than by sinkhorn’s algorithm. In International conference on machine learning, pages 1367–1376. PMLR, 2018.
  • Fan et al. [2020] J. Fan, A. Taghvaei, and Y. Chen. Scalable computations of wasserstein barycenter via input convex neural networks. arXiv preprint arXiv:2007.04462, 2020.
  • Fan et al. [2021] J. Fan, S. Liu, S. Ma, H. Zhou, and Y. Chen. Neural monge map estimation and its applications. arXiv preprint arXiv:2106.03812, 2021.
  • Fan et al. [2022] J. Fan, Q. Zhang, A. Taghvaei, and Y. Chen. Variational Wasserstein gradient flow. In K. Chaudhuri, S. Jegelka, L. Song, C. Szepesvari, G. Niu, and S. Sabato, editors, Proceedings of the 39th International Conference on Machine Learning, volume 162 of Proceedings of Machine Learning Research, pages 6185–6215. PMLR, 2022.
  • Gao and Chaudhari [2021] Y. Gao and P. Chaudhari. An Information-Geometric distance on the space of tasks. In Proceedings of the 38th International Conference on Machine Learning. PMLR, 2021.
  • He et al. [2022] K. He, X. Chen, S. Xie, Y. Li, P. Dollár, and R. Girshick. Masked autoencoders are scalable vision learners. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 16000–16009, 2022.
  • Hua et al. [2023] X. Hua, T. Nguyen, T. Le, J. Blanchet, and V. A. Nguyen. Dynamic flows on curved space generated by labeled data. arXiv preprint arXiv:2302.00061, 2023.
  • Hull [1994] J. J. Hull. A database for handwritten text recognition research. IEEE Transactions on pattern analysis and machine intelligence, 16(5):550–554, 1994.
  • Jain et al. [2022] S. Jain, H. Salman, A. Khaddaj, E. Wong, S. M. Park, and A. Madry. A data-based perspective on transfer learning. arXiv preprint arXiv:2207.05739, 2022.
  • Kabir et al. [2022] H. D. Kabir, M. Abdar, A. Khosravi, S. M. J. Jalali, A. F. Atiya, S. Nahavandi, and D. Srinivasan. Spinalnet: Deep neural network with gradual input. IEEE Transactions on Artificial Intelligence, 2022.
  • Kirkpatrick et al. [2017] J. Kirkpatrick, R. Pascanu, N. Rabinowitz, J. Veness, G. Desjardins, A. A. Rusu, K. Milan, J. Quan, T. Ramalho, A. Grabska-Barwinska, D. Hassabis, C. Clopath, D. Kumaran, and R. Hadsell. Overcoming catastrophic forgetting in neural networks. PNAS, 114(13):3521–3526, Mar. 2017.
  • Korotin et al. [2021] A. Korotin, L. Li, J. Solomon, and E. Burnaev. Continuous wasserstein-2 barycenter estimation without minimax optimization. arXiv preprint arXiv:2102.01752, 2021.
  • Korotin et al. [2022a] A. Korotin, V. Egiazarian, L. Li, and E. Burnaev. Wasserstein iterative networks for barycenter estimation. arXiv preprint arXiv:2201.12245, 2022a.
  • Korotin et al. [2022b] A. Korotin, D. Selikhanovych, and E. Burnaev. Neural optimal transport. arXiv preprint arXiv:2201.12220, 2022b.
  • Liu et al. [2019] H. Liu, X. Gu, and D. Samaras. Wasserstein gan with quadratic transport cost. In Proceedings of the IEEE/CVF international conference on computer vision, pages 4832–4841, 2019.
  • Makkuva et al. [2020] A. Makkuva, A. Taghvaei, S. Oh, and J. Lee. Optimal transport mapping via input convex neural networks. In International Conference on Machine Learning, volume 37, 2020.
  • McCann [1995] R. J. McCann. Existence and uniqueness of monotone measure-preserving maps. Duke Mathematical Journal, 80(2):309–323, 1995.
  • McCann [1997] R. J. McCann. A convexity principle for interacting gases. Advances in mathematics, 128(1):153–179, 1997.
  • McCloskey and Cohen [1989] M. McCloskey and N. J. Cohen. Catastrophic interference in connectionist networks: The sequential learning problem. In G. H. Bower, editor, Psychology of Learning and Motivation, volume 24, pages 109–165. Academic Press, Jan. 1989. 10.1016/S0079-7421(08)60536-8.
  • Mokrov et al. [2021] P. Mokrov, A. Korotin, L. Li, A. Genevay, J. M. Solomon, and E. Burnaev. Large-Scale wasserstein gradient flows. In M. Ranzato, A. Beygelzimer, Y. Dauphin, P. S. Liang, and J. W. Vaughan, editors, Advances in Neural Information Processing Systems, volume 34, pages 15243–15256. Curran Associates, Inc., 2021.
  • Perrot et al. [2016] M. Perrot, N. Courty, R. Flamary, and A. Habrard. Mapping estimation for discrete optimal transport. Advances in Neural Information Processing Systems, 29, 2016.
  • Pooladian and Niles-Weed [2021] A.-A. Pooladian and J. Niles-Weed. Entropic estimation of optimal transport maps. arXiv preprint arXiv:2109.12004, 2021.
  • Ronneberger et al. [2015] O. Ronneberger, P. Fischer, and T. Brox. U-net: Convolutional networks for biomedical image segmentation. In International Conference on Medical image computing and computer-assisted intervention, pages 234–241. Springer, 2015.
  • Rout et al. [2022] L. Rout, A. Korotin, and E. Burnaev. Generative modeling with optimal transport maps. In International Conference on Learning Representations, 2022.
  • Sandfort et al. [2019] V. Sandfort, K. Yan, P. J. Pickhardt, and R. M. Summers. Data augmentation using generative adversarial networks (CycleGAN) to improve generalizability in CT segmentation tasks. Sci. Rep., 9(1):16884, Nov. 2019. ISSN 2045-2322. 10.1038/s41598-019-52737-x.
  • Santambrogio [2015] F. Santambrogio. Optimal transport for applied mathematicians. Birkhäuser, NY, 55(58-63):94, 2015.
  • Santambrogio [2017] F. Santambrogio. {Euclidean, metric, and Wasserstein} gradient flows: an overview. Bulletin of Mathematical Sciences, 7(1):87–154, 2017.
  • Sinkhorn [1967] R. Sinkhorn. Diagonal equivalence to matrices with prescribed row and column sums. The American Mathematical Monthly, 74(4):402–405, 1967.
  • Srivastava et al. [2018] S. Srivastava, C. Li, and D. B. Dunson. Scalable bayes via barycenter in wasserstein space. The Journal of Machine Learning Research, 19(1):312–346, 2018.
  • Villani [2008] C. Villani. Optimal transport, Old and New, volume 338. Springer Science & Business Media, 2008. ISBN 9783540710493.
  • Yao et al. [2021] H. Yao, L. Zhang, and C. Finn. Meta-learning with fewer tasks through task interpolation. arXiv preprint arXiv:2106.02695, 2021.
  • Yeaton et al. [2022] A. Yeaton, R. G. Krishnan, R. Mieloszyk, D. Alvarez-Melis, and G. Huynh. Hierarchical optimal transport for comparing histopathology datasets. arXiv preprint arXiv:2204.08324, 2022.
  • Yoon et al. [2019] J. Yoon, J. Jordon, and M. van der Schaar. PATE-GAN: Generating synthetic data with differential privacy guarantees. In International Conference on Learning Representations, 2019.
  • Zhai et al. [2019] X. Zhai, J. Puigcerver, A. Kolesnikov, P. Ruyssen, C. Riquelme, M. Lucic, J. Djolonga, A. S. Pinto, M. Neumann, A. Dosovitskiy, et al. A large-scale study of representation learning with the visual task adaptation benchmark. arXiv preprint arXiv:1910.04867, 2019.
  • Zhang et al. [2018] H. Zhang, M. Cisse, Y. N. Dauphin, and D. Lopez-Paz. mixup: Beyond empirical risk minimization. In International Conference on Learning Representations, 2018.
  • Zhang et al. [2021] L. Zhang, Z. Deng, K. Kawaguchi, A. Ghorbani, and J. Zou. How does mixup help with robustness and generalization? In International Conference on Learning Representations, 2021.
  • Zhu et al. [2023] J. Zhu, J. Qiu, A. Guha, Z. Yang, X. Nguyen, B. Li, and D. Zhao. Interpolation for robust learning: Data augmentation on geodesics. arXiv preprint arXiv:2302.02092, 2023.

Appendix A Proofs

Proof of Lemma 1.

By Santambrogio [2017, §4.4], the result holds when $m = 2$. Proposition 7.5 in Agueh and Carlier [2011] then extends the result to the case $m > 2$. ∎

Proof of Proposition 1.

Since a convex combination of cyclically monotone maps is cyclically monotone, $\sum_{i=1}^m a_i T_i^*(x)$ is the optimal map from $\nu$ to $\rho_a^G$ [McCann, 1995]. Then, by the definition of $W_{2,\nu}(\cdot,\cdot)$, we can write

\begin{equation}
W_{2,\nu}^2(\rho_a^G, \nu) = \int \Big\| x - \sum_{i=1}^m a_i T_i^*(x) \Big\|^2 \,\mathrm{d}\nu(x). \tag{12}
\end{equation}

For scalars $p, q_1, \ldots, q_m$ and weights $a_1, \ldots, a_m$ with $\sum_{i=1}^m a_i = 1$, it holds that

\begin{align}
\Big(p - \sum_{i=1}^m a_i q_i\Big)^2 &= p^2 + \sum_{i=1}^m a_i^2 q_i^2 - 2\sum_{i=1}^m a_i p q_i + \sum_{i \neq j} a_i a_j q_i q_j \tag{13}\\
&= p^2 + \sum_{i=1}^m \Big(a_i - a_i \sum_{j \neq i} a_j\Big) q_i^2 - 2\sum_{i=1}^m a_i p q_i + \sum_{i \neq j} a_i a_j q_i q_j \tag{14}\\
&= \sum_{i=1}^m a_i (p - q_i)^2 - \frac{1}{2} \sum_{i \neq j} a_i a_j (q_i - q_j)^2. \tag{15}
\end{align}

Applying this identity to each coordinate of $x$ and $T_i^*(x)$, summing over coordinates, and plugging the result into (12) gives

\begin{align}
W_{2,\nu}^2(\rho_a^G, \nu) &= \int \Big( \sum_{i=1}^m a_i \|x - T_i^*(x)\|^2 - \frac{1}{2} \sum_{i \neq j} a_i a_j \|T_i^*(x) - T_j^*(x)\|^2 \Big) \,\mathrm{d}\nu(x) \tag{16}\\
&= \sum_{i=1}^m a_i \int \|x - T_i^*(x)\|^2 \,\mathrm{d}\nu(x) - \frac{1}{2} \sum_{i \neq j} a_i a_j \int \|T_i^*(x) - T_j^*(x)\|^2 \,\mathrm{d}\nu(x) \tag{17}\\
&= \sum_{i=1}^m a_i W_{2,\nu}^2(\mu_i, \nu) - \frac{1}{2} \sum_{i \neq j} a_i a_j W_{2,\nu}^2(\mu_i, \mu_j). \tag{18}
\end{align}
∎
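Identity (15), applied coordinatewise as in (16), can be sanity-checked numerically. The sketch below draws a random point $p$ (playing the role of $x$), random points $q_i$ (playing the role of $T_i^*(x)$) and random weights summing to 1, and compares both sides; it checks only the algebra, not the optimality of $\sum_i a_i T_i^*$. The helper names are ours.

```python
import random


def sq_dist(u, v):
    """Squared Euclidean distance between two equal-length lists."""
    return sum((x - y) ** 2 for x, y in zip(u, v))


def lhs(p, qs, a):
    """|| p - sum_i a_i q_i ||^2, the left-hand side of (15) for vectors."""
    mix = [sum(ai * q[k] for ai, q in zip(a, qs)) for k in range(len(p))]
    return sq_dist(p, mix)


def rhs(p, qs, a):
    """sum_i a_i ||p - q_i||^2 - 1/2 sum_{i != j} a_i a_j ||q_i - q_j||^2."""
    m = len(qs)
    first = sum(a[i] * sq_dist(p, qs[i]) for i in range(m))
    cross = sum(a[i] * a[j] * sq_dist(qs[i], qs[j])
                for i in range(m) for j in range(m) if i != j)
    return first - 0.5 * cross


rng = random.Random(0)
m, dim = 4, 3
raw = [rng.random() + 0.1 for _ in range(m)]
a = [w / sum(raw) for w in raw]                    # weights summing to 1
p = [rng.gauss(0.0, 1.0) for _ in range(dim)]      # plays the role of x
qs = [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(m)]  # T_i^*(x)
gap = abs(lhs(p, qs, a) - rhs(p, qs, a))
```

The constraint $\sum_i a_i = 1$ is what makes the identity hold: expanding both sides for general weights with sum $S$ shows that they differ by $(S - 1)\big(\sum_i a_i \|q_i\|^2 - \|p\|^2\big)$.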

Proof of Proposition 2.

First, $\mathcal{W}_{2,Q}$ is symmetric and nonnegative by definition. It is non-degenerate since $\mathcal{W}_{2,Q}(P_i, P_j) \geq d_{\mathrm{OT}}(P_i, P_j)$ and $d_{\mathrm{OT}}$ is a metric. Finally, we show that it satisfies the triangle inequality. Indeed,

\begin{align}
&\mathcal{W}_{2,Q}(P_1, P_3) \tag{19}\\
&= \Big( \int \big( \|x_1 - x_3\|^2 + W_2^2(\alpha_{y_1}, \alpha_{y_3}) \big) \,\mathrm{d}Q(z) \Big)^{1/2} \tag{20}\\
&\leq \Big( \int \big( (\|x_1 - x_2\| + \|x_2 - x_3\|)^2 + (W_2(\alpha_{y_1}, \alpha_{y_2}) + W_2(\alpha_{y_2}, \alpha_{y_3}))^2 \big) \,\mathrm{d}Q(z) \Big)^{1/2} \tag{21}\\
&\leq \Big( \int \big( \|x_1 - x_2\|^2 + W_2^2(\alpha_{y_1}, \alpha_{y_2}) \big) \,\mathrm{d}Q(z) \Big)^{1/2} + \Big( \int \big( \|x_2 - x_3\|^2 + W_2^2(\alpha_{y_2}, \alpha_{y_3}) \big) \,\mathrm{d}Q(z) \Big)^{1/2} \tag{22}\\
&= \mathcal{W}_{2,Q}(P_1, P_2) + \mathcal{W}_{2,Q}(P_2, P_3), \tag{23}
\end{align}

where the first inequality follows from the triangle inequality (for the Euclidean norm and for $W_2$, applied pointwise in $z$) and the second from the Minkowski inequality. ∎

Appendix B Implementation details of OTDD map

OTDD barycentric projection

We use the implementation at https://github.com/microsoft/otdd to compute the OTDD coupling; the remaining steps are straightforward.

OTDD neural map

To solve problem (4.1), we parameterize $f$, $G$, and $\ell$ as three neural networks. In the NIST dataset experiments, we parameterize $f$ as the ResNet (https://github.com/harryliew/WGAN-QC) from WGAN-QC [Liu et al., 2019], and take the feature map $G$ to be a UNet (https://github.com/milesial/Pytorch-UNet) [Ronneberger et al., 2015]. We generate the labels $\bar{y}$ with a pretrained classifier $\ell(\cdot)$, and use a LeNet or a VGG-5 with Spinal layers (https://github.com/dipuk0506/SpinalNet) [Kabir et al., 2022] to parameterize $\ell(\cdot)$. In the 2D Gaussian mixture experiments, we use residual MLPs to represent all three networks.

We remove the discriminator’s conditioning on the label, which simplifies the loss function to

\begin{equation}
\sup_f \inf_G \int \Big( \underbrace{\|x - G(z)\|_2^2}_{\text{feature loss}} + \underbrace{W_2^2(\alpha_y, \alpha_{\bar{y}})}_{\text{label loss}} \Big) \,\mathrm{d}Q(z) \underbrace{- \int f(\bar{x}) \,\mathrm{d}Q(z) + \int f(x') \,\mathrm{d}P(z')}_{\text{discriminator loss}}. \tag{24}
\end{equation}

In this formula, we assume both $y$ and $\bar{y}$ are hard labels, but in practice the output of $\ell(\cdot)$ is a soft label. Simply taking the argmax to get a hard label would break the computational graph, so we replace the label loss $W_2^2(\alpha_y, \alpha_{\bar{y}})$ by $y^\top M \bar{y}$, where $y$ is the one-hot label from dataset $Q$ and $M \in \mathbb{R}_{\geq 0}^{C_Q \times C_P}$ is the label-to-label matrix with $M(i,j) := W_2^2(\alpha_{y_i}, \alpha_{y_j})$. The matrix $M$ is precomputed before training and kept frozen during training.
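The relaxed label loss can be sketched in a few lines. We assume $M$ has already been computed (its entries below are placeholders rather than actual $W_2^2$ values) and that the soft label $\bar{y}$ is a softmax over the classifier logits, which is our assumption about how the output of $\ell(\cdot)$ is normalized. Plain lists stand in for tensors; in an autodiff framework the same operations are differentiable with respect to the logits.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    top = max(logits)
    exps = [math.exp(v - top) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def label_loss(y_onehot, y_bar, M):
    """y^T M y_bar: expected label-to-label cost under the soft label y_bar.

    M is C_Q x C_P (list of rows); y_onehot has length C_Q, y_bar length C_P.
    """
    return sum(y_onehot[i] * M[i][j] * y_bar[j]
               for i in range(len(M)) for j in range(len(M[0])))


M = [[0.0, 2.0, 5.0],    # placeholder costs; in practice W_2^2 between classes
     [3.0, 1.0, 0.5]]
y = [0.0, 1.0]                        # the Q-sample belongs to Q's class 1
y_bar = softmax([0.2, 1.5, -0.3])     # soft prediction over P's 3 classes
loss = label_loss(y, y_bar, M)
```

When $\bar{y}$ is one-hot, $y^\top M \bar{y}$ reduces to a single entry of $M$, recovering the hard-label cost $W_2^2(\alpha_y, \alpha_{\bar{y}})$; for a soft $\bar{y}$ it is a convex combination of the entries in row $y$.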

We pre-train the feature map $G$ to be an identity map before the main adversarial training. We use the exponential moving average (EMA; https://github.com/fadel/pytorch_ema) of the trained feature maps as the final feature map.
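The EMA of the feature-map weights can be sketched as follows. This illustrates the update rule only: the decay used in our runs is not stated here, so the default of 0.999 is an assumption, parameters are a flat list of floats rather than network tensors, and the class `ParamEMA` is hypothetical and not the API of the linked package.

```python
class ParamEMA:
    """Exponential moving average of a flat list of parameters:
    shadow <- decay * shadow + (1 - decay) * params after each update."""

    def __init__(self, params, decay=0.999):
        self.decay = decay
        self.shadow = list(params)          # start from the current weights

    def update(self, params):
        d = self.decay
        self.shadow = [d * s + (1.0 - d) * p for s, p in zip(self.shadow, params)]


ema = ParamEMA([0.0, 0.0], decay=0.5)
ema.update([1.0, 2.0])     # shadow becomes [0.5, 1.0]
ema.update([1.0, 2.0])     # shadow becomes [0.75, 1.5]
```

After training, the shadow weights, not the last iterate, would be used as the final feature map.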

Data processing

For all the *NIST datasets, we rescale the images to $32 \times 32$ and replicate their channel 3 times to obtain 3-channel images. We use the default train-test split from torchvision. For the VTAB datasets, we use a masked autoencoder based on ViT-Large with 196 patches and embedding dimension 1024, so the final embedding dimension is $197 \times 1024 = 201728$ (196 patch tokens plus one class token). We also use the default train-test split from torchvision.

Hyperparameters

For the experimental results in §5.2, we use the OTDD neural map and train it with the Adam optimizer, learning rate $10^{-3}$, and batch size 64. We train a LeNet for 2000 iterations and fine-tune it for 100 epochs. For the comparison with other baselines in §5.2, for the transfer learning methods, we train a SpinalNet for $10^4$ iterations and fine-tune it for 2000 iterations on the test dataset. Training from scratch on the test dataset also takes 2000 iterations. For the results in §5.3, we pretrain the ResNet-18 model for 5 epochs and then fine-tune it on the few-shot dataset for 10 epochs. During fine-tuning, the whole network remains trainable. The batch size is 128, and the learning rate is $10^{-3}$.

Appendix C Discussions over complexity-accuracy trade-off

We agree that our method is more computationally demanding than Mixup in general. Specifically, we consider Mixup and our methods to occupy different points of a compute-accuracy trade-off characterized by the expressivity of the geodesics between datasets they define. That being said, the trade-off is nevertheless not a prohibitive one, as shown by the fact that we can scale our method to VTAB-sized datasets with a very standard GPU setup.

‘Vanilla’ Mixup with uniform dataset weights is indeed quite cheap (but, as shown in Table 2, considerably worse than the alternatives). On the other hand, the version of Mixup that uses the ‘optimal’ mixture weights (the Mixup row with weights $\hat{a}$ in Table 2, and the only Mixup version in Table 1) requires solving Eq. (4.3), which involves non-trivial computation to obtain the OTDD maps. In terms of the trade-off spectrum described above, Mixup with optimal weights lies strictly between vanilla Mixup and OTDD interpolation.

Appendix D Additional results

Figure 7: The numbers above the images are the labels. In the first labelling, all MNIST images of the digit 0 are assigned to class "0"; in the bottom labelling, they are assigned to class "7".

D.1 OTDD neural map visualization

We show the OTDD neural map between 2D Gaussian mixture models with 16 components in Figure 8. This example is special in that the OTDD map has a closed-form solution: the feature map is the identity, and the pushforward label is the class whose conditional distribution $p(x|y)$ is the same as that of the source label. For example, a sample from the top-left cluster is still mapped to the top-left cluster, while its label changes from blue to orange. This map achieves zero transport cost; since the transport cost is always nonnegative, it is the optimal OTDD map. In contrast, Asadulaev et al. [2022] and Bunne et al. [2022a] enforce label-preserving maps, so with their methods the blue cluster would still be mapped to the blue cluster. Their feature map is therefore highly non-convex and more difficult to learn. We refer to Figure 5 in Asadulaev et al. [2022] for their performance on the same example. Compared with theirs, our pushforward dataset aligns better with the target dataset.

Figure 8: OTDD neural map for 2D Gaussian mixture distributions.
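The zero-cost argument for Figure 8 can be checked numerically. The sketch below assumes, for simplicity, that each class is a single Gaussian component with a shared isotropic covariance, so the 2-Wasserstein distance between class conditionals reduces to the distance between their means. It uses the OTDD ground cost with $p=2$ (squared feature distance plus squared Wasserstein distance between class conditionals). The 4x4 grid layout, random seed, noise level and all variable names are hypothetical. The code evaluates the closed-form map (identity on features, relabelling to the matching target class) and, for comparison, a label-preserving map that moves each point onto the target cluster carrying its label; it does not train any neural map.

```python
import numpy as np

# Hypothetical toy setup: 16 Gaussian components centred on a 4x4 grid.
# Source and target use the same components but label them differently.
rng = np.random.default_rng(0)
means = 4.0 * np.array([[i, j] for i in range(4) for j in range(4)], dtype=float)
src_label_of_comp = np.arange(16)        # component k has source label k
tgt_label_of_comp = rng.permutation(16)  # component k has target label tgt_label_of_comp[k]

# Class-conditional means: src_mean[y] (tgt_mean[y]) is the mean of source (target) class y.
src_mean = means[np.argsort(src_label_of_comp)]
tgt_mean = means[np.argsort(tgt_label_of_comp)]


def otdd_cost(x, y, x_new, y_new):
    """Pointwise OTDD ground cost with p = 2.

    Feature term ||x - x'||^2 plus W2^2 between the source conditional of y and
    the target conditional of y'. With one component per class and a shared
    isotropic covariance, that W2^2 reduces to ||mu_y - mu_y'||^2.
    """
    feat = np.sum((x - x_new) ** 2, axis=-1)
    lab = np.sum((src_mean[y] - tgt_mean[y_new]) ** 2, axis=-1)
    return feat + lab


# Draw labelled source samples.
n, sigma = 1000, 0.3
comp = rng.integers(0, 16, size=n)
x = means[comp] + sigma * rng.standard_normal((n, 2))
y = src_label_of_comp[comp]

# Closed-form OTDD map: identity on features, relabel each source class to the
# target class that has the same conditional distribution p(x|y).
relabel = tgt_label_of_comp[np.argsort(src_label_of_comp)]
cost_closed_form = otdd_cost(x, y, x, relabel[y])

# Label-preserving alternative: move each point onto the target cluster
# that carries the same label.
x_moved = x + (tgt_mean[y] - src_mean[y])
cost_label_preserving = otdd_cost(x, y, x_moved, y)

print("closed-form map cost:", cost_closed_form.mean())
print("label-preserving cost:", cost_label_preserving.mean())
```

Under these assumptions the closed-form map incurs no cost at all, while any label-preserving map pays both for moving the features and for the mismatch between class conditionals.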

D.2 McCann’s interpolation between datasets

Our OTDD map can be extended to generate McCann’s interpolation between datasets. We propose an analog of McCann’s interpolation (2) in the dataset space, defining McCann’s interpolation between datasets $P_{0}$ and $P_{1}$ as

$$P^{M}_{t} := \big((1-t)\,\mathrm{Id} + t\,\mathcal{T}^{*}\big)\sharp P_{0}, \quad t\in[0,1], \qquad (25)$$

where $\mathcal{T}^{*}$ is the optimal OTDD map from $P_{0}$ to $P_{1}$ and $t$ is the interpolation parameter; the superscript $M$ in $P_{t}^{M}$ stands for McCann. We use the same convex combination method as in §4.2 to obtain samples from $P^{M}_{t}$. Suppose $(x_{0},y_{0})\sim P_{0}$, $(x_{1},y_{1})=\mathcal{T}^{*}(x_{0},y_{0})$, and $P_{0}$ and $P_{1}$ contain 7 and 3 classes respectively, i.e., $y_{0}\in\{0,1\}^{7}$ and $y_{1}\in\{0,1\}^{3}$. Then the combination of features is $x_{t}=(1-t)x_{0}+tx_{1}$, and the combination of labels is

$$y_{t} = (1-t)\begin{bmatrix} y_{0} \\ \mathbf{0}_{3} \end{bmatrix} + t\begin{bmatrix} \mathbf{0}_{7} \\ y_{1} \end{bmatrix}. \qquad (26)$$

Thus $(x_{t},y_{t})$ is a sample from $((1-t)\mathrm{Id}+t\mathcal{T}^{*})\sharp P_{0}$. We visualize McCann’s interpolation between two Gaussian mixture distributions in Figure 9. This method maps labeled data from one dataset to another and interpolates between them, so it can be used to map abundant data from an external dataset onto a scarce dataset for data augmentation. For example, in Figure 10 the target dataset has only 30 samples, while the source dataset has 60000. We learn the OTDD neural map between them and compute their interpolation. We find that $P_{1}^{M}$ creates new data outside the domain of the original target distribution, which Mixup [Zhang et al., 2018] cannot achieve. Thus, samples from $P_{t}^{M}$ with $t$ close to 1.0 can enrich the target dataset and could potentially be used for data augmentation in classification tasks.
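The sampling step behind Eqs. (25) and (26) is a convex combination of a source sample and its image under the map. A minimal sketch follows, assuming the pair $(x_{1},y_{1})=\mathcal{T}^{*}(x_{0},y_{0})$ has already been computed (for instance by the learned OTDD neural map); the function name `mccann_sample` is hypothetical. Labels are zero-padded into the concatenated label space as in Eq. (26), and the code accepts either single samples or batches stacked along the leading axis.

```python
import numpy as np


def mccann_sample(x0, y0, x1, y1, t):
    """Sample (x_t, y_t) from P_t^M given (x0, y0) ~ P_0 and (x1, y1) = T*(x0, y0).

    y0 has one entry per source class and y1 one entry per target class; the
    returned label lives in the concatenated label space, as in Eq. (26).
    Works for single samples or for batches stacked along the leading axis.
    """
    if not 0.0 <= t <= 1.0:
        raise ValueError("t must lie in [0, 1]")
    x0, x1 = np.asarray(x0, dtype=float), np.asarray(x1, dtype=float)
    y0, y1 = np.asarray(y0, dtype=float), np.asarray(y1, dtype=float)

    # Features: straight-line (McCann) interpolation between x0 and its image.
    x_t = (1.0 - t) * x0 + t * x1

    # Labels: zero-pad each one-hot vector with the other dataset's classes,
    # then take the same convex combination.
    pad0 = np.zeros(y0.shape[:-1] + (y1.shape[-1],))
    pad1 = np.zeros(y1.shape[:-1] + (y0.shape[-1],))
    y_t = (1.0 - t) * np.concatenate([y0, pad0], axis=-1) + t * np.concatenate(
        [pad1, y1], axis=-1
    )
    return x_t, y_t


# Usage: 7 source classes, 3 target classes, t = 0.3.
x_t, y_t = mccann_sample(np.zeros(2), np.eye(7)[2], np.ones(2), np.eye(3)[0], t=0.3)
```

In this usage, a source sample of class 2 mapped to target class 0 receives weight 0.7 on label entry 2 and weight 0.3 on entry 7 (the first target class). At $t=1$ all weight on the source classes vanishes, which matches the disappearance of the source classes in Figure 9.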

Figure 9: McCann’s interpolation for 2D labelled datasets. Each color represents a class. As $t\rightarrow 1.0$, samples in the blue classes become increasingly rare, and they disappear entirely at $t=1.0$.
Figure 10: Data augmentation by mapping an external dataset to a few-shot dataset.

D.3 Correlation study of *NIST experiments

Figure 11 gives a more concrete visualization of the correlation between $\mathcal{W}^{2}(P_{a},Q)$ and *NIST transfer learning test accuracy. Among all datasets, USPS and KMNIST show no clear correlation. We believe this is caused by (i) a small variance in the distances from the pretraining datasets to the target dataset, implying limited relative diversity among the datasets to draw on, and (ii) in the case of USPS, a very simple task whose baseline accuracy is already very high and hard to improve upon via transfer.

Figure 11: Pearson correlation between the (averaged) function $\mathcal{W}^{2}(P_{a},Q)$ and the test accuracy of the fine-tuned model. Most datasets show a negative correlation between $\mathcal{W}^{2}(P_{a},Q)$ and accuracy. When the test dataset is USPS or KMNIST (rightmost two), all three training datasets are similarly distant from the test dataset, so the range of $\mathcal{W}^{2}(P_{a},Q)$ is not wide enough to show a clear negative correlation. This explains the nearly zero slope and relatively large $p$-value for those two datasets. A similar pattern has been observed in Yeaton et al. [2022, Figure 5(a)].
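The slope, correlation and $p$-value discussed for Figure 11 come from a linear fit of accuracy against distance. A sketch of that computation using `scipy.stats.linregress` follows; the function name `distance_accuracy_fit` is hypothetical, and we assume one (averaged) distance and one test accuracy per pretraining mixture $P_a$.

```python
import numpy as np
from scipy import stats


def distance_accuracy_fit(distances, accuracies):
    """Fit accuracy ~ slope * W^2(P_a, Q) + intercept over pretraining mixtures P_a.

    Returns the slope, the Pearson correlation r, the two-sided p-value for the
    null hypothesis of zero slope, and the range of the distances.
    """
    d = np.asarray(distances, dtype=float)
    acc = np.asarray(accuracies, dtype=float)
    res = stats.linregress(d, acc)
    return {
        "slope": res.slope,
        "r": res.rvalue,
        "p": res.pvalue,
        "range": float(d.max() - d.min()),
    }


# Usage with illustrative (made-up) values, not experimental results.
fit = distance_accuracy_fit([10, 20, 30, 40], [90, 85, 80, 75])
```

Since the standard error of a least-squares slope scales inversely with the spread of the regressor, a narrow range of $\mathcal{W}^{2}(P_{a},Q)$, as for USPS and KMNIST, tends to produce a large $p$-value even when accuracy still decreases with distance.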

D.4 Fine-grained analysis over $\mathcal{W}^{2}(P_{a},Q)$ in *NIST experiments

In Table 3, we provide a more fine-grained analysis of different aspects of $\mathcal{W}(P_{a},Q)$ and their effect on transfer accuracy. To do so, we report the min, median, range, and standard deviation of $\mathcal{W}(P_{a},Q)$. In addition, as a proxy for task hardness (i.e., the best possible gain from transfer learning), the last column shows the OTDD accuracy minus the few-shot accuracy, where the OTDD and few-shot accuracies are the mean accuracies in Rows 1 and 4 of Table 1, respectively.

Based on these statistics, we make the following observations on the relation between $\mathcal{W}(P_{a},Q)$ and transfer accuracy:

  • The accuracy improvement is strongly driven by $\min_{a}\mathcal{W}(P_{a},Q)$. EMNIST and MNIST have relatively small values of $\min_{a}\mathcal{W}(P_{a},Q)$ and show the largest improvements. On the other hand, FMNIST and KMNIST, as targets $Q$, have the largest distances $\mathcal{W}(P_{a},Q)$ to the pretraining datasets, and relatively small accuracy gains. In other words, the correlation between distance and accuracy is stronger in the part of the convex dataset polytope that is closest to the target dataset.

  • The strength of the correlation between $\mathcal{W}(P_{a},Q)$ and accuracy seems to depend on the range and standard deviation of the former. On the one hand, settings with a low dynamic range in $\mathcal{W}(P_{a},Q)$ (like USPS and KMNIST) make it harder to observe meaningful differences in accuracy. On the other hand, a low range indicates that these target datasets are roughly (or at least more) equidistant from all pretraining datasets, so any convex combination of the pretraining datasets will also be close to equidistant from the target, yielding no visible improvement.

  • Intrinsic task hardness matters. Consider USPS: all pretraining datasets, regardless of distance, yield very similar accuracy on it, and it has the lowest accuracy gain (only ~5%) among the 5 tasks. But since the no-transfer (i.e., 5-shot) accuracy is already almost 81%, the benefit from transfer learning is limited a priori, and therefore all pretraining datasets yield a similarly minor improvement.

Table 3: Statistics of $\mathcal{W}(P_{a},Q)$ and transfer accuracy in *NIST experiments (§5.2).

| Test dataset | Min of $\mathcal{W}(P_{a},Q)$ | Median of $\mathcal{W}(P_{a},Q)$ | Range of $\mathcal{W}(P_{a},Q)$ | Std. dev. of $\mathcal{W}(P_{a},Q)$ | Mean accuracy improvement (%) |
|---|---|---|---|---|---|
| EMNIST | 34.41 | 43.71 | 39.58 | 9.94 | 13.46 |
| MNIST | 39.13 | 49.04 | 44.17 | 11.35 | 20.94 |
| FMNIST | 44.19 | 54.75 | 39.11 | 10.64 | 10.62 |
| USPS | 42.04 | 48.32 | 23.49 | 6.13 | 5.28 |
| KMNIST | 47.65 | 53.92 | 24.83 | 6.19 | 10.88 |
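The statistics in Table 3 and the accuracy-gain proxy are straightforward to compute once $\mathcal{W}(P_{a},Q)$ has been evaluated for a set of mixture weights $a$. The sketch below is illustrative: the helper names are hypothetical, the uniform grid over the simplex is only one possible way to choose the weights $a$, and the standard deviation uses the population convention (`ddof=0`), which may differ from the convention used for the table.

```python
import itertools

import numpy as np


def simplex_grid(k, n):
    """All weight vectors a on the (k-1)-simplex with entries in {0, 1/n, ..., 1}."""
    combos = [c for c in itertools.product(range(n + 1), repeat=k) if sum(c) == n]
    return np.array(combos, dtype=float) / n


def distance_summary(distances):
    """Min, median, range and standard deviation of W(P_a, Q) over the weights a."""
    d = np.asarray(distances, dtype=float)
    return {
        "min": float(d.min()),
        "median": float(np.median(d)),
        "range": float(d.max() - d.min()),
        "std": float(d.std()),  # population std (ddof=0); an assumption
    }


def accuracy_gain(otdd_acc, few_shot_acc):
    """Hardness proxy: mean OTDD-pretrained accuracy minus mean few-shot accuracy."""
    return float(np.mean(otdd_acc) - np.mean(few_shot_acc))
```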

D.5 Full results of VTAB experiments

In Section 5.3, we only showed the relative improvement in test accuracy compared to no pretraining. Here we report the full test accuracy results. We keep the hyperparameters fixed across all pretraining datasets. Table 4 shows that pretraining on the interpolated dataset with the optimal weights assigned by our method performs better than with naïve uniform weights. Moreover, with the same weights, our OTDD map gives higher accuracy than Mixup, because Mixup does not use information from the reference dataset (see Figure 4).

Poor sub-pooling performance

We include the sub-pooling baseline as a non-trivial way to combine datasets. However, it performs poorly, and we believe there are two main reasons for this. First, this baseline wastes relevant label information by discarding the original labels of the pretraining datasets and replacing them with labels imputed from the nearest target examples. Second, it only uses the pretraining points that are neighbors of the Pet examples, leaving all other datapoints unused.
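For concreteness, here is one possible reading of the sub-pooling baseline as a sketch; the exact procedure (feature space, number of neighbors $k$, tie handling) is an assumption, and `sub_pool` is a hypothetical name. It keeps only pretraining points that are among the $k$ nearest neighbors of some target example and replaces their labels with that of their nearest target example, which makes both shortcomings above explicit: the original labels are discarded, and all non-neighboring points are dropped.

```python
import numpy as np


def sub_pool(x_pre, x_tgt, y_tgt, k=5):
    """Hypothetical sub-pooling baseline.

    Keeps only pretraining points that are among the k nearest neighbors of
    some target example, and replaces their labels with the label of their
    nearest target example. All other pretraining points are dropped.
    """
    x_pre = np.asarray(x_pre, dtype=float)
    x_tgt = np.asarray(x_tgt, dtype=float)

    # Pairwise squared Euclidean distances, shape (n_tgt, n_pre).
    d2 = ((x_tgt[:, None, :] - x_pre[None, :, :]) ** 2).sum(axis=-1)

    # Indices of the k nearest pretraining points for every target example.
    k = min(k, x_pre.shape[0])
    nn = np.argpartition(d2, k - 1, axis=1)[:, :k]
    keep = np.unique(nn)

    # Impute each kept point's label from its nearest target example;
    # the original pretraining labels are never used.
    y_imputed = np.asarray(y_tgt)[d2[:, keep].argmin(axis=0)]
    return x_pre[keep], y_imputed


# Usage on a tiny 1-D example: the far-away point at 50 is discarded.
xs, ys = sub_pool([[0.0], [1.0], [10.0], [11.0], [50.0]], [[0.2], [10.4]], [3, 7], k=2)
```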

Table 4: Test accuracy (mean ± std over 5 runs, in percent) of 1000-shot learning on the Oxford-IIIT Pet test dataset. Non-transfer learning skips the pretraining step.

| Setting | Pretraining data | Test accuracy (%) |
|---|---|---|
| Transfer learning | OTDD map (optimal weight) | 22.60 ± 1.01 |
| | OTDD map (uniform weight) | 21.06 ± 0.45 |
| | Mixup (optimal weight) | 17.45 ± 2.2 |
| | Mixup (uniform weight) | 15.4 ± 1.56 |
| | Caltech101 | 18.24 ± 3.42 |
| | DTD | 11.46 ± 0.68 |
| | Flowers102 | 11.11 ± 1.92 |
| | Pooling | 14.88 ± 0.57 |
| | Sub-pooling | 14.88 ± 0.57 |
| Non-transfer learning | none | 11.71 ± 1.65 |