
Estimating the Rate-Distortion Function by Wasserstein Gradient Descent

Yibo Yang1  Stephan Eckstein2  Marcel Nutz3  Stephan Mandt1
1University of California, Irvine  2ETH Zurich  3Columbia University
{yibo.yang, mandt}@uci.edu
[email protected]
[email protected]
Abstract

In the theory of lossy compression, the rate-distortion (R-D) function $R(D)$ describes how much a data source can be compressed (in bit-rate) at any given level of fidelity (distortion). Obtaining $R(D)$ for a given data source establishes the fundamental performance limit for all compression algorithms. We propose a new method to estimate $R(D)$ from the perspective of optimal transport. Unlike the classic Blahut–Arimoto algorithm which fixes the support of the reproduction distribution in advance, our Wasserstein gradient descent algorithm learns the support of the optimal reproduction distribution by moving particles. We prove its local convergence and analyze the sample complexity of our R-D estimator based on a connection to entropic optimal transport. Experimentally, we obtain comparable or tighter bounds than state-of-the-art neural network methods on low-rate sources while requiring considerably less tuning and computation effort. We also highlight a connection to maximum-likelihood deconvolution and introduce a new class of sources that can be used as test cases with known solutions to the R-D problem.

1 Introduction

The rate-distortion (R-D) function $R(D)$ occupies a central place in the theory of lossy compression. For a given data source and a fidelity (or distortion) criterion, $R(D)$ characterizes the minimum possible communication cost needed to reproduce a source sample within an error threshold of $D$, by any compression algorithm (Shannon, 1959). A basic scientific and practical question is therefore establishing $R(D)$ for any given data source of interest, which helps assess the (sub)optimality of the compression algorithms and guide their development. The classic algorithm by Blahut (1972) and Arimoto (1972) assumes a known discrete source and computes its $R(D)$ by an exhaustive optimization procedure. This often has limited applicability in practice, and a line of research has sought to instead estimate $R(D)$ from data samples (Harrison and Kontoyiannis, 2008; Gibson, 2017), with recent methods (Yang and Mandt, 2022; Lei et al., 2023a) inspired by deep generative models.

In this work, we propose a new approach to R-D estimation from the perspective of optimal transport. Our starting point is the formulation of the R-D problem as the minimization of a certain rate functional (Harrison and Kontoyiannis, 2008) over the space of probability measures on the reproduction alphabet. Optimization over such an infinite-dimensional space has long been studied under gradient flows (Ambrosio et al., 2008), and we consider a concrete algorithmic implementation based on moving particles in space. This formulation of the R-D problem also suggests connections to entropic optimal transport and non-parametric statistics, each offering us new insight into the solution of the R-D problem under a quadratic distortion. More specifically, our contributions are three-fold:

First, we introduce a neural-network-free $R(D)$ upper bound estimator for continuous alphabets. We implement the estimator by Wasserstein gradient descent (WGD) over the space of reproduction distributions. Experimentally, we found the method to converge much more quickly than state-of-the-art neural methods with hand-tuned architectures, while offering comparable or tighter bounds.

Second, we theoretically characterize convergence of our WGD algorithm and the sample complexity of our estimator. The latter draws on a connection between the R-D problem and that of minimizing an entropic optimal transport (EOT) cost relative to the source measure, allowing us to turn statistical bounds for EOT (Mena and Niles-Weed, 2019) into finite-sample bounds for R-D estimation.

Finally, we introduce a new, rich class of sources with known ground truth, including Gaussian mixtures, as a benchmark for algorithms. While the literature relies on the Gaussian or Bernoulli for this purpose, we use the connection with maximum likelihood deconvolution to show that a Gaussian convolution of any distribution can serve as a source with a known solution to the R-D problem.

2 Lossy compression, entropic optimal transport, and MLE

This section introduces the R-D problem and its rich connections to entropic optimal transport and statistics, along with new insights into its solution. Sec. 2.1 sets the stage for our method (Sec. 4) by a known formulation of the standard R-D problem as an optimization problem over a space of probability measures. Sec. 2.2 discusses the equivalence between the R-D problem and a projection of the source distribution under entropic optimal transport; this is key to our sample complexity results in Sec. 4.3. Lastly, Sec. 2.3 gives a statistical interpretation of R-D as maximum-likelihood deconvolution and uses it to analytically derive a segment of the R-D curve for a new class of sources under quadratic distortion; this allows us to assess the optimality of algorithms in the experiments of Sec. 5.1.

2.1 Setup

For a memoryless data source $X$ with distribution $P_X$, its rate-distortion (R-D) function describes the minimum possible number of bits per sample needed to reproduce the source within a prescribed distortion threshold $D$. Let the source and reproduction take values in two sets $\mathcal{X}$ and $\mathcal{Y}$, known as the source and reproduction alphabets, and let $\rho:\mathcal{X}\times\mathcal{Y}\to[0,\infty)$ be a given distortion function. The R-D function is defined by the following optimization problem (Polyanskiy and Wu, 2022),

R(D)=\inf_{Q_{Y|X}:\,\mathbb{E}_{P_{X}Q_{Y|X}}[\rho(X,Y)]\leq D}I(X;Y), \qquad (1)

where $Q_{Y|X}$ is any Markov kernel from $\mathcal{X}$ to $\mathcal{Y}$ conceptually associated with a (possibly) stochastic compression algorithm, and $I(X;Y)$ is the mutual information of the joint distribution $P_X Q_{Y|X}$.
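
As a point of reference (a standard fact, not an example from this paper), the Gaussian source under squared-error distortion admits a closed form that is convenient for sanity-checking estimators:

R(D)=\max\!\Big(0,\ \tfrac{1}{2}\log\tfrac{\sigma^{2}}{D}\Big)\ \text{nats}\qquad\text{for }X\sim\mathcal{N}(0,\sigma^{2}),\ \rho(x,y)=(x-y)^{2},

attained for $D\leq\sigma^{2}$ by the coupling $X=Y+N$ with $Y\sim\mathcal{N}(0,\sigma^{2}-D)$ independent of $N\sim\mathcal{N}(0,D)$.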

For ease of presentation, we now switch to a more abstract notation without reference to random variables. We provide the precise definitions in the Supplementary Material. Let $\mathcal{X}$ and $\mathcal{Y}$ be standard Borel spaces; let $\mu\in\mathcal{P}(\mathcal{X})$ be a fixed probability measure on $\mathcal{X}$, which should be thought of as the source distribution $P_X$. For a measure $\pi$ on the product space $\mathcal{X}\times\mathcal{Y}$, the notation $\pi_1$ (or $\pi_2$) denotes the first (or second) marginal of $\pi$. For any $\nu\in\mathcal{P}(\mathcal{Y})$, we denote by $\Pi(\mu,\nu)$ the set of couplings between $\mu$ and $\nu$ (i.e., measures $\pi$ with $\pi_1=\mu$ and $\pi_2=\nu$). Similarly, $\Pi(\mu,\cdot)$ denotes the set of measures $\pi$ with $\pi_1=\mu$. Throughout the paper, $K$ denotes a transition kernel (conditional distribution) from $\mathcal{X}$ to $\mathcal{Y}$, and $\mu\otimes K$ denotes the product measure formed by $\mu$ and $K$. Then $R(D)$ is equivalent to

R(D)=\inf_{K:\int\rho\,d(\mu\otimes K)\leq D}H(\mu\otimes K\,|\,\mu\otimes(\mu\otimes K)_{2})=\inf_{\pi\in\Pi(\mu,\cdot):\int\rho\,d\pi\leq D}H(\pi\,|\,\pi_{1}\otimes\pi_{2}), \qquad (2)

where $H$ denotes relative entropy, i.e., for two measures $\alpha,\beta$ defined on a common measurable space, $H(\alpha|\beta):=\int\log(\frac{d\alpha}{d\beta})\,d\alpha$ when $\alpha$ is absolutely continuous w.r.t. $\beta$, and $+\infty$ otherwise.

To make the problem more tractable, we follow the approach of the classic Blahut–Arimoto algorithm (Blahut, 1972; Arimoto, 1972) (to be discussed in Sec. 3.1) and work with an equivalent unconstrained Lagrangian problem as follows. Instead of parameterizing the R-D function via a distortion threshold $D$, we parameterize it via a Lagrange multiplier $\lambda\geq 0$. For each fixed $\lambda$ (usually selected from a predefined grid), we aim to solve the following optimization problem,

F_{\lambda}(\mu):=\inf_{\nu\in\mathcal{P}(\mathcal{Y})}\inf_{\pi\in\Pi(\mu,\cdot)}\lambda\int\rho\,d\pi+H(\pi\,|\,\mu\otimes\nu). \qquad (3)

Geometrically, $F_{\lambda}(\mu)\in\mathbb{R}$ is the y-axis intercept of a tangent line to $R(D)$ with slope $-\lambda$, and $R(D)$ is determined by the convex envelope of all such tangent lines (Gray, 2011). To simplify notation, we often drop the dependence on $\lambda$ (e.g., we write $F(\mu)=F_{\lambda}(\mu)$) whenever it is harmless.

To set the stage for our later developments, we write the unconstrained R-D problem as

F_{\lambda}(\mu)=\inf_{\nu\in\mathcal{P}(\mathcal{Y})}\mathcal{L}_{BA}(\mu,\nu), \qquad (4)
\mathcal{L}_{BA}(\mu,\nu):=\inf_{\pi\in\Pi(\mu,\cdot)}\lambda\int\rho\,d\pi+H(\pi\,|\,\mu\otimes\nu)=\inf_{K}\lambda\int\rho\,d(\mu\otimes K)+H(\mu\otimes K\,|\,\mu\otimes\nu), \qquad (5)

where we refer to the optimization objective $\mathcal{L}_{BA}$ as the rate function (Harrison and Kontoyiannis, 2008). We abuse notation and write $\mathcal{L}_{BA}(\nu):=\mathcal{L}_{BA}(\mu,\nu)$ when it is viewed as a function of $\nu$ only, and refer to it as the rate functional. The rate function characterizes a generalized Asymptotic Equipartition Property, where $\mathcal{L}_{BA}(\mu,\nu)$ is the asymptotically optimal cost of lossy compression of data $X\sim\mu$ using a random codebook constructed from samples of $\nu$ (Dembo and Kontoyiannis, 2002). Notably, the optimization in (5) can be solved analytically (Csiszár, 1974a, Lemma 1.3), and $\mathcal{L}_{BA}$ simplifies to

\mathcal{L}_{BA}(\mu,\nu)=\int_{\mathcal{X}}-\log\left(\int_{\mathcal{Y}}e^{-\lambda\rho(x,y)}\nu(dy)\right)\mu(dx). \qquad (6)

In practice, the source $\mu$ is only accessible via independent samples, on the basis of which we propose to estimate its $R(D)$, or equivalently $F(\mu)$. Let $\mu^m$ denote an $m$-sample empirical measure of $\mu$, i.e., $\mu^m=\frac{1}{m}\sum_{i=1}^{m}\delta_{x_i}$ with $x_1,\dots,x_m$ being independent samples from $\mu$, which should be thought of as the “training data”. Following Harrison and Kontoyiannis (2008), we consider two kinds of (plug-in) estimators for $F(\mu)$: (1) the non-parametric estimator $F(\mu^m)$, and (2) the parametric estimator $F^{\mathcal{H}}(\mu^m):=\inf_{\nu\in\mathcal{H}}\mathcal{L}_{BA}(\mu^m,\nu)$, where $\mathcal{H}$ is a family of probability measures on $\mathcal{Y}$. Harrison and Kontoyiannis (2008) showed that under rather broad conditions, both kinds of estimators are strongly consistent, i.e., $F(\mu^m)$ converges to $F(\mu)$ (and respectively, $F^{\mathcal{H}}(\mu^m)$ to $F^{\mathcal{H}}(\mu)$) with probability one as $m\to\infty$. Our algorithm will implement the parametric estimator $F^{\mathcal{H}}(\mu^m)$ with $\mathcal{H}$ chosen to be the set of probability measures with finite support, and we will develop finite-sample convergence results for both kinds of estimators in the continuous setting (Proposition 4.3).
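
To make the parametric estimator concrete, here is a minimal sketch (our own illustrative code and naming, not the paper's implementation) of the plug-in objective $\mathcal{L}_{BA}(\mu^m,\nu)$ in (6) for a finitely supported $\nu$, under the quadratic distortion $\rho(x,y)=\frac{1}{2}\|x-y\|^2$ used later in the paper:

```python
import numpy as np
from scipy.special import logsumexp

def rate_functional(X, Y, w, lam):
    """Plug-in objective L_BA(mu^m, nu) of Eq. (6), in nats per sample, for an
    empirical source (rows of X) and nu = sum_j w[j] * delta_{Y[j]},
    assuming rho(x, y) = ||x - y||^2 / 2."""
    C = 0.5 * ((X[:, None, :] - Y[None, :, :]) ** 2).sum(-1)   # (m, n) distortion matrix
    return -logsumexp(np.log(w)[None, :] - lam * C, axis=1).mean()

# toy usage: 500 samples of a 2-D Gaussian source, nu supported on 10 points
rng = np.random.default_rng(0)
X, Y = rng.standard_normal((500, 2)), rng.standard_normal((10, 2))
print(rate_functional(X, Y, w=np.full(10, 0.1), lam=2.0))
```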

2.2 Connection to entropic optimal transport

The R-D problem turns out to have a close connection to entropic optimal transport (EOT) (Peyré and Cuturi, 2019), which we will exploit in Sec. 4.3 to obtain sample complexity results under our approach. For $\epsilon>0$, the entropy-regularized optimal transport problem is given by

\mathcal{L}_{EOT}(\mu,\nu):=\inf_{\pi\in\Pi(\mu,\nu)}\int\rho\,d\pi+\epsilon H(\pi\,|\,\mu\otimes\nu). \qquad (7)

We now consider the problem of projecting $\mu$ onto $\mathcal{P}(\mathcal{Y})$ under the cost $\mathcal{L}_{EOT}$:

\inf_{\nu\in\mathcal{P}(\mathcal{Y})}\mathcal{L}_{EOT}(\mu,\nu). \qquad (8)

In the OT literature this is known as the (regularized) Kantorovich estimator (Bassetti et al., 2006) for $\mu$, and can also be viewed as a Wasserstein barycenter problem (Agueh and Carlier, 2011).

With the identification $\epsilon=\lambda^{-1}$, problem (8) is in fact equivalent to the R-D problem (4): compared to $\mathcal{L}_{BA}$ (5), the extra constraint on the second marginal of $\pi$ in $\mathcal{L}_{EOT}$ (7) is redundant at the optimal $\nu$. More precisely, Lemma 7.1 shows that (we omit the notational dependence on $\mu$ when it is fixed):

\inf_{\nu\in\mathcal{P}(\mathcal{Y})}\mathcal{L}_{EOT}(\nu)=\inf_{\nu\in\mathcal{P}(\mathcal{Y})}\lambda^{-1}\mathcal{L}_{BA}(\nu)\qquad\text{and}\qquad\operatorname*{arg\,min}_{\nu\in\mathcal{P}(\mathcal{Y})}\mathcal{L}_{EOT}(\nu)=\operatorname*{arg\,min}_{\nu\in\mathcal{P}(\mathcal{Y})}\mathcal{L}_{BA}(\nu). \qquad (9)

Existence of a minimizer holds under mild conditions, for instance if $\mathcal{X}=\mathcal{Y}=\mathbb{R}^d$ and $\rho(x,y)$ is a coercive lower semicontinuous function of $y-x$ (Csiszár, 1974a, p. 66).
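
For finitely supported measures, $\mathcal{L}_{EOT}$ in (7) can be evaluated with Sinkhorn's algorithm. Below is a minimal log-domain sketch (our own illustrative code, assuming the quadratic cost $\rho(x,y)=\frac{1}{2}\|x-y\|^2$ and a fixed iteration budget rather than a convergence test):

```python
import numpy as np
from scipy.special import logsumexp

def eot_cost(X, Y, a, b, eps, n_iter=500):
    """Sinkhorn estimate of L_EOT(mu, nu) in (7) for mu = sum_i a[i] delta_{X[i]},
    nu = sum_j b[j] delta_{Y[j]}, with rho(x, y) = ||x - y||^2 / 2."""
    C = 0.5 * ((X[:, None, :] - Y[None, :, :]) ** 2).sum(-1)
    log_a, log_b = np.log(a), np.log(b)
    f, g = np.zeros(len(a)), np.zeros(len(b))
    for _ in range(n_iter):  # log-domain Sinkhorn updates of the dual potentials
        f = -eps * logsumexp((g[None, :] - C) / eps + log_b[None, :], axis=1)
        g = -eps * logsumexp((f[:, None] - C) / eps + log_a[:, None], axis=0)
    log_pi = (f[:, None] + g[None, :] - C) / eps + log_a[:, None] + log_b[None, :]
    pi = np.exp(log_pi)
    transport = (pi * C).sum()
    kl = (pi * (log_pi - log_a[:, None] - log_b[None, :])).sum()   # H(pi | mu x nu)
    return transport + eps * kl

# toy usage: project 200 Gaussian samples onto a 4-point nu
rng = np.random.default_rng(0)
X, Y = rng.standard_normal((200, 2)), rng.standard_normal((4, 2))
print(eot_cost(X, Y, a=np.full(200, 1 / 200), b=np.full(4, 1 / 4), eps=0.5))
```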

2.3 Connection to maximum-likelihood deconvolution

The connection between R-D and maximum-likelihood estimation has been observed in the information theory, machine learning and compression literature (Harrison and Kontoyiannis, 2008; Alemi et al., 2018; Ballé et al., 2017; Theis et al., 2017; Yang et al., 2020; Yang and Mandt, 2022). Here, we bring attention to a basic equivalence between the R-D problem and maximum-likelihood deconvolution, where the connection is particularly natural under a quadratic distortion function. Also see Rigollet and Weed (2018) for a related discussion that inspired ours, and its extension to a non-quadratic distortion. We provide further insight from the view of variational learning and inference in Section 10.

Figure 1: The $R(D)$ of a Gaussian mixture source, and the estimated optimal reproduction distributions $\nu^{*}$ (in bar plots) at varying R-D trade-offs. For any $\lambda\in[\sigma^{-2},\infty)$, the corresponding $R(D)$ (yellow segment) is known analytically, as is the optimal reproduction distribution $\nu^{*}$ (whose density is plotted in gray). For $\lambda\in(0,\sigma^{-2}]$, $\nu^{*}$ becomes singular and concentrated on two points, collapsing to the source mean as $\lambda\to 0$.

Maximum-likelihood deconvolution is a classical problem of non-parametric statistics and mixture models (Carroll and Hall, 1988; Lindsay and Roeder, 1993). The deconvolution problem is concerned with estimating an unknown distribution $\alpha$ from noise-corrupted observations $X_1,X_2,\dots$, where for each $i\in\mathbb{N}$ we have $X_i=Y_i+N_i$, $Y_i\stackrel{i.i.d.}{\sim}\alpha$, and the $N_i$ are i.i.d. noise variables, independent of the $Y_i$, with a known distribution. For concreteness, suppose all variables are $\mathbb{R}^d$-valued and the noise distribution is $\mathcal{N}(0,\sigma^2\mathrm{I}_d)$ with Lebesgue density $\phi_{\sigma^2}$. Denote the distribution of the observations $X_i$ by $\mu$. Then $\mu$ has a Lebesgue density given by the convolution $\alpha*\phi_{\sigma^2}(x):=\int\phi_{\sigma^2}(x-y)\,\alpha(dy)$. Here, we consider the population-level (instead of the usual sample-based) maximum-likelihood estimator (MLE) for $\alpha$:

\nu^{*}=\operatorname*{arg\,max}_{\nu\in\mathcal{P}(\mathbb{R}^{d})}\int\log\left(\nu*\phi_{\sigma^{2}}(x)\right)\mu(dx), \qquad (10)

and observe that $\nu^{*}=\alpha$. Plugging in the density $\phi_{\sigma^2}(x)\propto e^{-\frac{1}{2\sigma^2}\|x\|^2}$, we see that the MLE problem (10) is equivalent to the R-D problem (4) with $\rho(x,y)=\frac{1}{2}\|x-y\|^2$, $\lambda=\frac{1}{\sigma^2}$, and $\mathcal{L}_{BA}$ given by (6) in the form of a marginal log-likelihood. Thus the R-D problem has the interpretation of estimating a distribution from its noisy observations given through $\mu$, assuming Gaussian noise with variance $\frac{1}{\lambda}$.
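
To make the equivalence explicit (a one-line calculation under the stated choices of $\rho$ and $\lambda$), writing out the Gaussian density gives

\int\log\big(\nu*\phi_{\sigma^{2}}(x)\big)\,\mu(dx)=\int\log\Big(\int e^{-\frac{1}{2\sigma^{2}}\|x-y\|^{2}}\,\nu(dy)\Big)\mu(dx)-\frac{d}{2}\log(2\pi\sigma^{2})=-\mathcal{L}_{BA}(\mu,\nu)-\frac{d}{2}\log(2\pi\sigma^{2}),

so maximizing the population log-likelihood (10) over $\nu$ is the same as minimizing the rate functional (6) with $\lambda=\sigma^{-2}$.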

This connection suggests analytical solutions to the R-D problem for a variety of sources that arise from convolving an underlying distribution with Gaussian noise. Consider an R-D problem (4) with $\mathcal{X}=\mathcal{Y}=\mathbb{R}^d$, $\rho(x,y)=\frac{1}{2}\|x-y\|^2$, and let the source $\mu$ be the convolution between an arbitrary measure $\alpha\in\mathcal{P}(\mathcal{Y})$ and Gaussian noise with known variance $\sigma^2$. E.g., using a discrete measure for $\alpha$ results in a Gaussian mixture source with equal covariance among its components. When $\lambda=\frac{1}{\sigma^2}$, we recover exactly the population-MLE problem (10) discussed earlier, which has the solution $\nu^{*}=\alpha$. While this allows us to obtain one point of $R(D)$, we can in fact extend this idea to any $\lambda\geq\frac{1}{\sigma^2}$ and obtain the analytical form for the corresponding segment of the R-D curve. Specifically, for any $\lambda\geq\frac{1}{\sigma^2}$, applying the summation rule for independent Gaussians reveals the source distribution $\mu$ as

\mu=\alpha*\mathcal{N}(0,\sigma^{2})=\alpha*\mathcal{N}(0,\sigma^{2}-\tfrac{1}{\lambda})*\mathcal{N}(0,\tfrac{1}{\lambda})=\alpha_{\lambda}*\mathcal{N}(0,\tfrac{1}{\lambda}),\quad\alpha_{\lambda}:=\alpha*\mathcal{N}(0,\sigma^{2}-\tfrac{1}{\lambda}),

i.e., as the convolution between another underlying distribution $\alpha_\lambda$ and independent noise with variance $\frac{1}{\lambda}$. A solution to the R-D problem (3) is then analogously given by $\nu^{*}=\alpha_\lambda$, with the corresponding optimal coupling given by $\nu^{*}\otimes\tilde{K}$, $\tilde{K}(y,dx)=\mathcal{N}(y,\frac{1}{\lambda})$. (Here $\tilde{K}$ maps from the reproduction to the source alphabet, opposite to the kernel $K$ elsewhere in the text.) Evaluating the distortion and mutual information of the coupling then yields the $R(D)$ point associated with $\lambda$. Fig. 1 illustrates the $R(D)$ of a toy Gaussian mixture source, along with the $\nu^{*}$ estimated by our proposed WGD algorithm (Sec. 4); note that $\nu^{*}$ transitions from continuous (a Gaussian mixture with smaller component variances) to singular (a mixture of two Diracs) at $\lambda=\sigma^{-2}$. See caption for more details.
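
Concretely, the following sketch (our own illustrative code and parameter choices) computes such a ground-truth $R(D)$ point for a Gaussian mixture source: the distortion of the optimal coupling is $\mathbb{E}[\frac{1}{2}\|N\|^{2}]=\frac{d}{2\lambda}$ and its rate is $I(X;Y)=h(\mu)-h(N)$, where $N\sim\mathcal{N}(0,\frac{1}{\lambda}\mathrm{I}_d)$ and the differential entropy $h(\mu)$ is estimated by Monte Carlo.

```python
import numpy as np
from scipy.special import logsumexp

rng = np.random.default_rng(0)
d, sigma2, lam = 2, 0.1, 1.0 / 0.05              # lam >= 1/sigma2, so Sec. 2.3 applies
means = np.array([[-1.0, 0.0], [1.0, 0.0]])      # support of the discrete measure alpha
weights = np.array([0.5, 0.5])

def log_mixture_density(x, means, weights, var):
    # log density of sum_k weights[k] * N(means[k], var * I_d) at the rows of x
    sq = ((x[:, None, :] - means[None, :, :]) ** 2).sum(-1)
    logp = -0.5 * sq / var - 0.5 * d * np.log(2 * np.pi * var) + np.log(weights)[None, :]
    return logsumexp(logp, axis=1)

# Monte Carlo estimate of h(mu) in nats, mu = alpha * N(0, sigma2 * I_d)
n_mc = 200_000
comp = rng.choice(len(weights), size=n_mc, p=weights)
x = means[comp] + np.sqrt(sigma2) * rng.standard_normal((n_mc, d))
h_mu = -log_mixture_density(x, means, weights, sigma2).mean()

D = d / (2 * lam)                                    # E[rho(X, Y)] under the optimal coupling
R = h_mu - 0.5 * d * np.log(2 * np.pi * np.e / lam)  # I(X; Y) = h(mu) - h(N), in nats
print(f"(D, R) = ({D:.4f}, {R:.4f} nats)")
```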

3 Related Work

3.1 Blahut–Arimoto

The Blahut–Arimoto (BA) algorithm (Blahut, 1972; Arimoto, 1972) is the default method for computing $R(D)$ when the source is known and discrete. For a fixed $\lambda$, BA carries out the optimization problem (3) via alternating minimization. Starting from an initial measure $\nu^{(0)}\in\mathcal{P}(\mathcal{Y})$, the BA algorithm at step $t$ computes an updated pair $(\nu^{(t+1)},K^{(t+1)})$ as follows:

\frac{dK^{(t+1)}(x,\cdot)}{d\nu^{(t)}}(y)=\frac{e^{-\lambda\rho(x,y)}}{\int e^{-\lambda\rho(x,y^{\prime})}\,\nu^{(t)}(dy^{\prime})},\quad\forall x\in\mathcal{X}, \qquad (11)
\nu^{(t+1)}=(\mu\otimes K^{(t+1)})_{2}. \qquad (12)

When the alphabets are finite, the above computation can be carried out in matrix and vector operations, and the resulting sequence $\{(\nu^{(t)},K^{(t)})\}_{t=1}^{\infty}$ can be shown to converge to an optimum of (3); cf. (Csiszár, 1974b, 1984). When the alphabets are not finite, e.g., $\mathcal{X}=\mathcal{Y}=\mathbb{R}^d$, the BA algorithm no longer applies, as it is unclear how to digitally represent the measure $\nu$ and kernel $K$ and to tractably perform the integrals required by the algorithm. The common workaround is to perform a discretization step and then apply BA on the resulting discrete problem.
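
On finite alphabets, one iteration of (11)–(12) amounts to row-normalizing a kernel matrix and averaging it under the source probabilities. A minimal sketch (our own illustrative code, with the distortion supplied as a precomputed matrix):

```python
import numpy as np
from scipy.special import logsumexp

def blahut_arimoto(C, p_x, lam, n_iter=200):
    """Sketch of the BA iteration (11)-(12) on finite alphabets.
    C[i, j] = rho(x_i, y_j); p_x[i] is the source probability of x_i."""
    w = np.full(C.shape[1], 1.0 / C.shape[1])    # nu^(0): uniform over the reproduction points
    for _ in range(n_iter):
        # (11): K(x_i, y_j) proportional to w[j] * exp(-lam * C[i, j]), normalized over j
        log_K = np.log(w)[None, :] - lam * C
        log_K -= logsumexp(log_K, axis=1, keepdims=True)
        # (12): nu^(t+1) is the second marginal of mu (x) K^(t+1)
        w = p_x @ np.exp(log_K)
    # value of the rate functional (6) at the returned nu, in nats
    loss = -(p_x * logsumexp(np.log(w)[None, :] - lam * C, axis=1)).sum()
    return w, np.exp(log_K), loss
```

Note that the support of $\nu$ stays fixed to the columns of the distortion matrix, which is exactly the limitation our particle-based method addresses.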

One standard discretization method is to tile up the alphabets with small bins (Gray and Neuhoff, 1998). This quickly becomes infeasible as the number of dimensions increases. We therefore discretize the data space $\mathcal{X}$ to the support of the training data distribution $\mu^m$, i.e., the discretized alphabet is the set of training samples; this can be justified by the consistency of the parametric R-D estimator $F^{\mathcal{H}}(\mu^m)$ (Harrison and Kontoyiannis, 2008). It is less clear how to discretize the reproduction space $\mathcal{Y}$, especially in high dimensions. Since we work with $\mathcal{X}=\mathcal{Y}$, we will discretize $\mathcal{Y}$ similarly and use an $n$-element random subset of the training samples, as also considered by Lei et al. (2023a). As we will show, this rather arbitrary placement of the support of $\nu$ results in poor performance, and can be significantly improved from our perspective of evolving particles.

3.2 Neural network-based methods for estimating $R(D)$

RD-VAE (Yang and Mandt, 2022): To overcome the limitations of the BA algorithm, Yang and Mandt (2022) proposed to parameterize the transition kernel $K$ and reproduction distribution $\nu$ of the BA algorithm by neural density estimators (Papamakarios et al., 2021), and optimize the same objective (3) by (stochastic) gradient descent. They estimate (3) by Monte Carlo using joint samples $(X_i,Y_i)\sim\mu\otimes K$; in particular, the relative entropy can be written as $H(\mu\otimes K|\mu\otimes\nu)=\int\int\log\left(\frac{dK(x,\cdot)}{d\nu}(y)\right)K(x,dy)\mu(dx)$, where the integrand is computed exactly via a density ratio. In practice, an alternative parameterization is often used where the neural density estimators are defined on a lower-dimensional latent space than the reproduction alphabet, and the resulting approach is closely related to VAEs (Kingma and Welling, 2013). Yang and Mandt (2022) additionally propose a neural estimator for a lower bound on $R(D)$, based on a dual representation due to Csiszár (1974a).

NERD (Lei et al., 2023a): Instead of working with the transition kernel $K$ as in the RD-VAE, Lei et al. (2023a) considered optimizing the form of the rate functional in (6), via gradient descent on the parameters of $\nu$ parameterized by a neural network. Let $\nu^{\mathcal{Z}}$ be a base distribution over $\mathcal{Z}=\mathbb{R}^K$, such as the standard Gaussian, and $\omega:\mathcal{Z}\to\mathcal{Y}$ be a decoder network. The variational measure $\nu$ is then modeled as the image measure of $\nu^{\mathcal{Z}}$ under $\omega$. To evaluate and optimize the objective (6), the intractable inner integral w.r.t. $\nu$ is replaced with a plug-in estimator, so that for a given $x\in\mathcal{X}$,

-\log\left(\int_{\mathcal{Y}}e^{-\lambda\rho(x,y)}\nu(dy)\right)\approx-\log\left(\frac{1}{n}\sum_{j=1}^{n}e^{-\lambda\rho(x,Y_{j})}\right),\quad Y_{j}\sim\nu,\ j=1,2,\dots,n. \qquad (13)

After training, we estimate an R-D upper bound using $n$ samples from $\nu$ (to be discussed in Sec. 4.4).

3.3 Other related work

Within information theory: Recent work by Wu et al. (2022) and Lei et al. (2023b) also notes the connection between the R-D function and entropic optimal transport. Wu et al. (2022) compute the R-D function in the finite and known alphabet setting by solving a version of the EOT problem (8), whereas Lei et al. (2023b) numerically verify the equivalence (9) on a discrete problem and discuss the connection to scalar quantization. We also experimented with estimating $R(D)$ by solving the EOT problem (8), but found it computationally much more efficient to work with the rate functional (6); we see the primary benefit of the EOT connection as bringing in tools from statistical OT (Genevay et al., 2019; Mena and Niles-Weed, 2019; Rigollet and Stromme, 2022) for R-D estimation.

Outside of information theory: Rigollet and Weed (2018) note a connection between the EOT projection problem (8) and maximum-likelihood deconvolution (10); our work complements their perspective by re-interpreting both problems through the equivalent R-D problem. Unbeknownst to us at the time, Yan et al. (2023) proposed similar algorithms to ours in the context of Gaussian mixture estimation, which we recognize as R-D estimation under quadratic distortion (see Sec. 2.3). Their work is based on gradient flow in the Fisher–Rao–Wasserstein (FRW) geometry (Chizat et al., 2018), which our hybrid algorithm can be seen as implementing. Yan et al. (2023) prove that, in an idealized setting with infinite particles, FRW gradient descent does not get stuck at local minima; by contrast, our convergence and sample-complexity results (Prop. 4.2, 4.3) hold for any finite number of particles. We additionally consider larger-scale problems and the stochastic optimization setting.

4 Proposed method

For our algorithm, we require $\mathcal{X}=\mathcal{Y}=\mathbb{R}^d$ and that $\rho$ be continuously differentiable. We now introduce the gradient descent algorithm in Wasserstein space to solve the problems (4) and (8). We defer all proofs to the Supplementary Material. To minimize a functional $\mathcal{L}:\mathcal{P}(\mathcal{Y})\to\mathbb{R}$ over the space of probability measures, our algorithm essentially simulates the gradient flow (Ambrosio et al., 2008) of $\mathcal{L}$ and follows the trajectory of steepest descent in the Wasserstein geometry. In practice, we represent a measure $\nu^{(t)}\in\mathcal{P}(\mathcal{Y})$ by a collection of particles and at each time step update $\nu^{(t)}$ in a direction of steepest descent of $\mathcal{L}$, as given by its (negative) Wasserstein gradient. Denote by $\mathcal{P}_n(\mathbb{R}^d)$ the set of probability measures on $\mathbb{R}^d$ that are supported on at most $n$ points. Our algorithm implements the parametric R-D estimator with the choice $\mathcal{H}=\mathcal{P}_n(\mathbb{R}^d)$ (see discussion at the end of Sec. 2.1).

4.1 Wasserstein gradient descent (WGD)

Abstractly, Wasserstein gradient descent updates the variational measure $\nu$ to its pushforward $\tilde{\nu}$ under the map $(\mathrm{id}-\gamma\Psi)$, for a function $\Psi:\mathbb{R}^d\to\mathbb{R}^d$ called the Wasserstein gradient of $\mathcal{L}$ at $\nu$ (see below) and a step size $\gamma$. To implement this scheme, we represent $\nu$ as a convex combination of Dirac measures, $\nu=\sum_{i=1}^{n}w_i\delta_{x_i}$ with locations $\{x_i\}_{i=1}^{n}\subset\mathbb{R}^d$ and weights $\{w_i\}_{i=1}^{n}$. The algorithm moves each particle $x_i$ in the direction of $-\Psi(x_i)$; more precisely, $\tilde{\nu}=\sum_{i=1}^{n}w_i\delta_{x_i-\gamma\Psi(x_i)}$.

Algorithm 1 Wasserstein gradient descent
  Inputs: Loss function $\mathcal{L}\in\{\mathcal{L}_{BA},\mathcal{L}_{EOT}\}$; data distribution $\mu\in\mathcal{P}(\mathbb{R}^{d})$; number of particles $n\in\mathbb{N}$; total number of iterations $N\in\mathbb{N}$; step sizes $\gamma_{1},\dots,\gamma_{N}$; batch size $m\in\mathbb{N}$.
  Pick an initial measure $\nu^{(0)}\in\mathcal{P}_{n}(\mathbb{R}^{d})$, e.g., setting the particles to $n$ random samples from $\mu$.
  for $t=1,\dots,N$ do
     if support of $\mu$ contains more than $m$ points then
        $\mu^{m}\leftarrow\frac{1}{m}\sum_{i=1}^{m}\delta_{x_{i}}$ for $x_{1},\dots,x_{m}$ independent samples from $\mu$
        $\Psi^{(t)}\leftarrow$ Wasserstein gradient of $\mathcal{L}(\mu^{m},\cdot)$ at $\nu^{(t-1)}$ {see Definition 4.1}
     else
        $\Psi^{(t)}\leftarrow$ Wasserstein gradient of $\mathcal{L}(\mu,\cdot)$ at $\nu^{(t-1)}$ {see Definition 4.1}
     end if
     $\nu^{(t)}\leftarrow\left(\mathrm{id}-\gamma_{t}\Psi^{(t)}\right)_{\#}\nu^{(t-1)}$ {“$\#$” denotes pushforward}
  end for
  Return: $\nu^{(N)}$

Since the optimization objectives (4) and (8) appear as integrals w.r.t. the data distribution $\mu$, we can also apply stochastic optimization and perform stochastic gradient descent on mini-batches of size $m$. This allows us to handle a very large or infinite number of data samples, e.g., when the source is continuous. We formalize the procedure in Algorithm 1.

The following gives a constructive definition of a Wasserstein gradient which forms the computational basis of our algorithm. In the literature, the Wasserstein gradient is instead usually defined as a Fréchet differential (cf. (Ambrosio et al., 2008, Definition 10.1.1)), but we emphasize that in smooth settings, the given definition recovers the one from the literature (cf. (Chizat, 2022, Lemma A.2)).

Definition 4.1.

For a functional $\mathcal{L}:\mathcal{P}(\mathcal{Y})\to\mathbb{R}$ and $\nu\in\mathcal{P}(\mathcal{Y})$, we say that $V_{\mathcal{L}}(\nu):\mathbb{R}^d\to\mathbb{R}$ is a first variation of $\mathcal{L}$ at $\nu$ if

\lim_{\varepsilon\rightarrow 0}\frac{\mathcal{L}((1-\varepsilon)\nu+\varepsilon\tilde{\nu})-\mathcal{L}(\nu)}{\varepsilon}=\int V_{\mathcal{L}}(\nu)\,d(\tilde{\nu}-\nu)\quad\text{for all }\tilde{\nu}\in\mathcal{P}(\mathcal{Y}).

We call its (Euclidean) gradient $\nabla V_{\mathcal{L}}(\nu):\mathbb{R}^d\to\mathbb{R}^d$, if it exists, the Wasserstein gradient of $\mathcal{L}$ at $\nu$.

For $\mathcal{L}=\mathcal{L}_{EOT}$, the first variation is given by the Kantorovich potential, which is the solution of the convex dual of $\mathcal{L}_{EOT}$ and commonly computed by Sinkhorn’s algorithm (Peyré and Cuturi, 2019; Nutz, 2021). Specifically, let $(\varphi^{\nu},\psi^{\nu})$ be potentials for $\mathcal{L}_{EOT}(\mu,\nu)$. Then $V_{\mathcal{L}}(\nu)=\psi^{\nu}$ is the first variation w.r.t. $\nu$ (cf. (Carlier et al., 2022, equation (20))), and hence $\nabla\psi^{\nu}$ is the Wasserstein gradient. This gradient exists whenever $\rho$ is differentiable and the marginals are sufficiently light-tailed; we give details in Sec. 9.1 of the Supplementary Material. For $\mathcal{L}=\mathcal{L}_{BA}$, the first variation can be computed explicitly. As derived in Sec. 9.1 of the Supplementary Material, the first variation at $\nu$ is

\psi^{\nu}(y)=-\int\frac{\exp(-\lambda\rho(x,y))}{\int\exp(-\lambda\rho(x,\tilde{y}))\,\nu(d\tilde{y})}\,\mu(dx)

and the Wasserstein gradient is then $\nabla V_{\mathcal{L}_{BA}}(\nu)=\nabla\psi^{\nu}$. We observe that $\psi^{\nu}$ is computationally cheap; it corresponds to running a single iteration of Sinkhorn’s algorithm. By contrast, finding the potential for $\mathcal{L}_{EOT}$ requires running Sinkhorn’s algorithm to convergence.
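
Under the quadratic distortion $\rho(x,y)=\frac{1}{2}\|x-y\|^{2}$, the gradient $\nabla\psi^{\nu}$ has a simple closed form, and one WGD step on $\mathcal{L}_{BA}$ can be written in a few lines. The following is a minimal sketch (our own illustrative code; the step sizes and $\lambda$ are only for demonstration), operating on a mini-batch as in Algorithm 1:

```python
import numpy as np
from scipy.special import logsumexp

def wgd_step(X_batch, Y, w, lam, step_size):
    """One Wasserstein gradient descent step on L_BA, moving the particles Y of
    nu = sum_j w[j] * delta_{Y[j]} along -grad psi^nu (rho(x, y) = ||x - y||^2 / 2)."""
    C = 0.5 * ((X_batch[:, None, :] - Y[None, :, :]) ** 2).sum(-1)           # (m, n)
    log_denom = logsumexp(np.log(w)[None, :] - lam * C, axis=1, keepdims=True)
    # r[i, j] = exp(-lam rho(x_i, y_j)) / sum_k w[k] exp(-lam rho(x_i, y_k))
    r = np.exp(-lam * C - log_denom)
    m = X_batch.shape[0]
    # grad psi^nu(y_j) = (lam / m) * sum_i r[i, j] * (y_j - x_i)
    grad = (lam / m) * (r.sum(0)[:, None] * Y - r.T @ X_batch)
    loss = -log_denom.mean()                      # mini-batch rate functional (6), nats
    return Y - step_size * grad, loss

# illustrative usage on a toy 2-D Gaussian source with n = 20 particles
rng = np.random.default_rng(0)
data = rng.standard_normal((2000, 2))
Y = data[rng.choice(2000, size=20, replace=False)].copy()   # init particles at data samples
w = np.full(20, 1 / 20)
for t in range(500):
    batch = data[rng.choice(2000, size=256, replace=False)]
    Y, loss = wgd_step(batch, Y, w, lam=5.0, step_size=0.1 / (1 + 0.01 * t))
```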

Like the usual Euclidean gradient, the Wasserstein gradient can be shown to possess a linearization property, whereby the loss functional is reduced by taking a small enough step along the negative Wasserstein gradient. Following (Carlier et al., 2022), we state it as follows: for any $\tilde{\nu}\in\mathcal{P}(\mathcal{Y})$ and $\pi\in\Pi(\nu,\tilde{\nu})$,

\begin{split}\mathcal{L}(\tilde{\nu})-\mathcal{L}(\nu)=\int(y-x)^{\top}\nabla V_{\mathcal{L}}(\nu)(x)\,\pi(dx,dy)+o\left(\int\|y-x\|^{2}\,\pi(dx,dy)\right),\\ \left|\int\|\nabla V_{\mathcal{L}}(\nu)\|^{2}\,d\nu-\int\|\nabla V_{\mathcal{L}}(\tilde{\nu})\|^{2}\,d\tilde{\nu}\right|\leq CW_{2}(\nu,\tilde{\nu}).\end{split} \qquad (14)

The first line of (14) is proved for $\mathcal{L}_{EOT}$ in (Carlier et al., 2022, Proposition 4.2) in the case that the marginals are compactly supported and $\rho$ is twice continuously differentiable. In this setting, the second line of (14) follows using $a^2-b^2=(a+b)(a-b)$ and a combination of boundedness and Lipschitz continuity of $\nabla V_{\mathcal{L}}$; see (Carlier et al., 2022, Proposition 2.2 and Corollary 2.4).

The linearization property given by (14) enables us to show that Wasserstein gradient descent for $\mathcal{L}_{EOT}$ and $\mathcal{L}_{BA}$ converges to a stationary point under mild conditions:

Proposition 4.2 (Convergence of Wasserstein gradient descent).

Let $\gamma_1\geq\gamma_2\geq\dots\geq 0$ satisfy $\sum_{k=1}^{\infty}\gamma_k=\infty$ and $\sum_{k=1}^{\infty}\gamma_k^2<\infty$. Let $\mathcal{L}:\mathcal{P}(\mathbb{R}^d)\to\mathbb{R}$ be Wasserstein differentiable in the sense that (14) holds. Denote by $\nu^{(t)}$ the steps in Algorithm 1, and suppose that $\mathcal{L}(\nu^{(0)})$ is finite and $\int\|\nabla V_{\mathcal{L}}(\nu^{(t)})\|^2\,d\nu^{(t)}$ is bounded. Then

\lim_{t\rightarrow\infty}\int\|\nabla V_{\mathcal{L}}(\nu^{(t)})\|^{2}\,d\nu^{(t)}=0.

4.2 Hybrid algorithm

A main limitation of the BA algorithm is that the support of $\nu^{(t)}$ is restricted to that of the (possibly bad) initialization $\nu^{(0)}$. On the other hand, Wasserstein gradient descent (Algorithm 1) only evolves the particle locations of $\nu^{(t)}$, but not the weights, which are fixed to be uniform by default. We therefore consider a hybrid algorithm where we alternate between WGD and BA update steps, allowing us to optimize the particle weights as well. Experimentally, this translates to faster convergence than the base WGD algorithm (Sec. 5.1). Note, however, that unlike WGD, the hybrid algorithm does not directly lend itself to the stochastic optimization setting, as BA updates on mini-batches no longer guarantee monotonic improvement in the objective and can lead to divergence. We treat the convergence of the hybrid algorithm in the Supplementary Material, Sec. 9.4.
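
A minimal sketch of one such alternating update (our own illustrative code; the ordering of the two steps and the use of the full data matrix, rather than mini-batches, are our assumptions, consistent with the caveat above):

```python
import numpy as np
from scipy.special import logsumexp

def hybrid_step(X, Y, w, lam, step_size):
    """One hybrid update: a BA reweighting of the particle weights as in (11)-(12),
    followed by a WGD move of the particle locations (rho(x, y) = ||x - y||^2 / 2)."""
    C = 0.5 * ((X[:, None, :] - Y[None, :, :]) ** 2).sum(-1)
    # BA step: posterior K(x_i, .) over particles, then w <- second marginal of mu (x) K
    log_K = np.log(w)[None, :] - lam * C
    log_K -= logsumexp(log_K, axis=1, keepdims=True)
    w = np.exp(log_K).mean(0)
    # WGD step with the updated weights
    log_denom = logsumexp(np.log(w)[None, :] - lam * C, axis=1, keepdims=True)
    r = np.exp(-lam * C - log_denom)
    grad = (lam / X.shape[0]) * (r.sum(0)[:, None] * Y - r.T @ X)
    return Y - step_size * grad, w
```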

4.3 Sample complexity

Let $\mathcal{X}=\mathcal{Y}=\mathbb{R}^d$ and $\rho(x,y)=\|x-y\|^2$. Leveraging work on the statistical complexity of EOT (Mena and Niles-Weed, 2019), we obtain finite-sample bounds for the theoretical estimators implemented by WGD in terms of the number of particles and source samples. The bounds hold for both the R-D problem (4) and the EOT projection problem (8) as they share the same optimizers (see Sec. 2.2), and strengthen existing asymptotic results for empirical R-D estimators (Harrison and Kontoyiannis, 2008). We note that a recent result by Rigollet and Stromme (2022) might be useful for deriving alternative bounds under distortion functions other than the quadratic.

Proposition 4.3.

Let $\mu$ be $\sigma^2$-subgaussian. Then every optimizer $\nu^{*}$ of (4) and (8) is also $\sigma^2$-subgaussian. Consider $\mathcal{L}:=\mathcal{L}_{EOT}$. For a constant $C_d$ depending only on $d$, we have

\left|\min_{\nu\in\mathcal{P}(\mathbb{R}^{d})}\mathcal{L}(\mu,\nu)-\min_{\nu_{n}\in\mathcal{P}_{n}(\mathbb{R}^{d})}\mathcal{L}(\mu,\nu_{n})\right| \leq C_{d}\,\epsilon\left(1+\frac{\sigma^{\lceil 5d/2\rceil+6}}{\epsilon^{\lceil 5d/4\rceil+3}}\right)\frac{1}{\sqrt{n}},
\mathbb{E}\left[\left|\min_{\nu\in\mathcal{P}(\mathbb{R}^{d})}\mathcal{L}(\mu,\nu)-\min_{\nu\in\mathcal{P}(\mathbb{R}^{d})}\mathcal{L}(\mu^{m},\nu)\right|\right] \leq C_{d}\,\epsilon\left(1+\frac{\sigma^{\lceil 5d/2\rceil+6}}{\epsilon^{\lceil 5d/4\rceil+3}}\right)\frac{1}{\sqrt{m}},
\mathbb{E}\left[\left|\min_{\nu\in\mathcal{P}(\mathbb{R}^{d})}\mathcal{L}(\mu,\nu)-\min_{\nu_{n}\in\mathcal{P}_{n}(\mathbb{R}^{d})}\mathcal{L}(\mu^{m},\nu_{n})\right|\right] \leq C_{d}\,\epsilon\left(1+\frac{\sigma^{\lceil 5d/2\rceil+6}}{\epsilon^{\lceil 5d/4\rceil+3}}\right)\left(\frac{1}{\sqrt{m}}+\frac{1}{\sqrt{n}}\right),

for all $n,m\in\mathbb{N}$, where $\mathcal{P}_n(\mathbb{R}^d)$ is the set of probability measures over $\mathbb{R}^d$ supported on at most $n$ points, $\mu^m$ is the empirical measure of $\mu$ with $m$ independent samples, and the expectation $\mathbb{E}[\cdot]$ is over these samples. The same inequalities hold for $\mathcal{L}:=\lambda^{-1}\mathcal{L}_{BA}$, with the identification $\epsilon=\lambda^{-1}$.

4.4 Estimation of rate and distortion

Here, we describe our estimator for an upper bound $(\mathcal{D},\mathcal{R})$ of $R(D)$ after solving the unconstrained problem (3). We provide more details in Sec. 8 of the Supplementary Material.

For any given pair of $\nu$ and $K$, we always have that $\mathcal{D}:=\int\rho\,d(\mu\otimes K)$ and $\mathcal{R}:=H(\mu\otimes K|\mu\otimes\nu)$ lie on an upper bound of $R(D)$ (Berger, 1971). The two quantities can be estimated by standard Monte Carlo provided we can sample from $\mu\otimes K$ and evaluate the density $\frac{d\mu\otimes K}{d\mu\otimes\nu}(x,y)=\frac{dK(x,\cdot)}{d\nu}(y)$.

When only $\nu$ is given, e.g., obtained from optimizing (6) with WGD or NERD, we estimate an R-D upper bound as follows. As in the BA algorithm, we construct a kernel $K_{\nu}$ similarly to (11), i.e., $\frac{dK_{\nu}(x,\cdot)}{d\nu}(y)=\frac{e^{-\lambda\rho(x,y)}}{\int e^{-\lambda\rho(x,\tilde{y})}\nu(d\tilde{y})}$; then we estimate $(\mathcal{D},\mathcal{R})$ using the pair $(\nu,K_{\nu})$ as described earlier.
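
For a finitely supported $\nu$ and a set of source samples, this estimate takes the following concrete form (a minimal sketch with our own naming, under $\rho(x,y)=\frac{1}{2}\|x-y\|^{2}$; $\mathcal{R}$ is reported in nats):

```python
import numpy as np
from scipy.special import logsumexp

def rd_upper_bound(X, Y, w, lam):
    """Estimate (D, R) of Sec. 4.4 from nu = sum_j w[j] * delta_{Y[j]}: build K_nu
    as in (11) and average distortion and log density ratio over the source samples X."""
    C = 0.5 * ((X[:, None, :] - Y[None, :, :]) ** 2).sum(-1)
    log_K = np.log(w)[None, :] - lam * C                   # unnormalized log K_nu(x_i, y_j)
    log_K -= logsumexp(log_K, axis=1, keepdims=True)
    K = np.exp(log_K)
    D = (K * C).sum(1).mean()                              # E[rho(X, Y)] under mu (x) K_nu
    R = (K * (log_K - np.log(w)[None, :])).sum(1).mean()   # E[log dK_nu(x, .)/dnu (Y)], nats
    return D, R
```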

As NERD uses a continuous $\nu$, we follow Lei et al. (2023a) and approximate it with its $n$-sample empirical measure to estimate $(\mathcal{D},\mathcal{R})$. A limitation of NERD, BA, and our method is that they tend to converge to a rate estimate of at most $\log(n)$, where $n$ is the support size of $\nu$. This is because as the algorithms approach an $n$-point minimizer $\nu_n^{*}$ of the R-D problem, the rate estimate $\mathcal{R}$ approaches the mutual information of $\mu\otimes K_{\nu_n^{*}}$, which is upper-bounded by $\log(n)$ (Eckstein and Nutz, 2022). In practice, this means if a target point of $R(D)$ has rate $r$, then we need $n\geq e^{r}$ to estimate it accurately.

4.5 Computational considerations

Common to all the aforementioned methods is the evaluation of a pairwise distortion matrix between $m$ points in $\mathcal{X}$ and $n$ points in $\mathcal{Y}$, which usually has a cost of $\mathcal{O}(mnd)$ for a $d$-dimensional source. While RD-VAE uses $n=1$ (in the reparameterization trick), the other methods (BA, WGD, NERD) typically use a much larger $n$ and thus have the distortion computation as their main computational bottleneck. Compared to BA and WGD, the neural methods (RD-VAE, NERD) incur additional computation from neural network operations, which can be significant for large networks.

For NERD and WGD (and BA), the rate estimate upper bound of $\log(n)$ nats/sample (see Sec. 4.4) can present computational challenges. To target a high-rate setting, a large number of $\nu$ particles (high $n$) is required, and care needs to be taken to avoid running out of memory during the distortion matrix computation (one possibility is to use a small batch size $m$ with stochastic optimization).
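
One generic way to keep this computation memory-friendly (a sketch of a standard trick, not the paper's implementation) is to expand the squared norm so that no $(m,n,d)$ intermediate array is materialized, optionally processing the source samples in chunks:

```python
import numpy as np

def pairwise_sq_dists(X, Y, chunk=4096):
    """||x_i - y_j||^2 for all pairs, via ||x||^2 + ||y||^2 - 2 x.y, computed in
    chunks over the source samples to bound peak memory (values may be slightly
    negative due to floating point; clip at 0 if exact non-negativity is needed)."""
    y_sq = (Y ** 2).sum(1)[None, :]
    out = np.empty((X.shape[0], Y.shape[0]))
    for s in range(0, X.shape[0], chunk):
        x = X[s:s + chunk]
        out[s:s + chunk] = (x ** 2).sum(1)[:, None] + y_sq - 2.0 * x @ Y.T
    return out
```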

5 Experiments

We compare the empirical performance of our proposed method (WGD) and its hybrid variant with Blahut–Arimoto (BA) (Blahut, 1972; Arimoto, 1972), RD-VAE (Yang and Mandt, 2022), and NERD (Lei et al., 2022) on the tasks of maximum-likelihood deconvolution and estimation of R-D upper bounds. While we experimented with WGD for both $\mathcal{L}_{BA}$ and $\mathcal{L}_{EOT}$, we found the former to be 10 to 100 times faster computationally while giving similar or better results; we therefore focus on WGD for $\mathcal{L}_{BA}$ in the discussions below. For the neural-network baselines, we use the same (or as similar as possible) network architectures as in the original work (Yang and Mandt, 2022; Lei et al., 2023a). We use the Adam optimizer for all gradient-based methods, except that we use simple gradient descent with a decaying step size in Sec. 5.1 to better compare the convergence speed of WGD and its hybrid variant. Further experiment details and results are given in the Supplementary Material, Sec. 11.

5.1 Deconvolution

To better understand the behavior of the various algorithms, we apply them to a deconvolution problem with known ground truth (see Sec. 2.3). We adopt Gaussian noise as before, letting $\alpha$ be the uniform measure on the unit circle in $\mathbb{R}^2$ and the source $\mu=\alpha*\mathcal{N}(0,\sigma^2)$ with $\sigma^2=0.1$.

Figure 2: Losses over iterations. Shading corresponds to one standard deviation over random initializations.
Figure 3: Visualizing $\mu$ samples (top left), as well as the $\nu$ returned by various algorithms compared to the ground truth $\nu^{*}$ (cyan).
Figure 4: Optimality gap vs. the number of particles $n$ used, on deconvolution problems with different dimension $d$ and distortion multiplier $\lambda$. The first three panels fix $\lambda=\sigma^{-2}$ and increase $d$, and the right-most panel corresponds to the 2-D problem with a higher $\lambda$ (denoising with a narrower Gaussian kernel). Overall the WGD methods attain higher accuracy for a given budget of $n$.

We use $n=20$ particles for BA, NERD, WGD and its hybrid variant. We use a two-layer network for NERD and RD-VAE with some hand-tuning (we replace the softplus activation in the original RD-VAE network by ReLU, as the former led to difficulty in optimization). Fig. 2 plots the resulting loss curves and shows that the proposed algorithms converge the fastest to the ground truth value $OPT:=\mathcal{L}(\alpha)$. In Fig. 3, we visualize the final $\nu^{(t)}$ at the end of training, compared to the ground truth $\nu^{*}=\alpha$ supported on the circle (colored in cyan). Note that we initialize $\nu^{(0)}$ for BA, WGD, and its hybrid variant to the same $n$ random data samples. While BA is stuck with the randomly initialized particles and assigns large weights to those closer to the circle, WGD learns to move the particles to uniformly cover the circle. The hybrid algorithm, being able to reweight particles to reduce their transportation cost, learns a different solution where a cluster of particles covers the top-left portion of the circle with small weights while the remaining particles evenly cover the rest. Unlike our particle-based methods, the neural methods generally struggle to place the support of their $\nu$ exactly on the circle.

We additionally compare how the performance of BA, NERD, and the proposed algorithms scales to higher dimensions and a higher $\lambda$ (corresponding to lower entropic regularization in $\mathcal{L}_{EOT}$). Fig. 4 plots the gap between the converged and the optimal losses for the algorithms, and demonstrates the proposed algorithms to be more particle-efficient and to scale more favorably than the alternatives, which also use $n$ particles in the reproduction space. We additionally visualize how the converged particles for our methods vary across the R-D trade-off in Fig. 6 of the Supplementary Material.

5.2 Higher-dimensional data

We perform $R(D)$ estimation on higher-dimensional data, including the physics and speech datasets from Yang and Mandt (2022) and MNIST (LeCun et al., 1998). As the memory cost of operating on the full datasets becomes prohibitive, we focus on NERD, RD-VAE, and WGD using mini-batch stochastic gradient descent. BA and hybrid WGD do not directly apply in the stochastic setting, as BA updates on random mini-batches can lead to divergence (as discussed in Sec. 4.2).

Fig. 5 plots the estimated R-D bounds on the datasets, and compares the convergence speed of WGD and the neural methods in both iteration count and compute time. Overall, we find that WGD requires minimal tuning, obtains the tightest R-D upper bounds within the $\log(n)$ rate limit (see Sec. 4.4), and consistently obtains tighter bounds than NERD given the same computation budget.

Figure 5: Left, Middle: R-D bound estimates on the physics and speech datasets (Yang and Mandt, 2022) and the MNIST training set. Right: Example speed comparisons of WGD and neural upper bound methods on MNIST, with WGD converging at least an order of magnitude faster. On each dataset we also include an R-D lower bound estimated using the method of Yang and Mandt (2022).

6 Discussions

In this work, we leverage tools from optimal transport to develop a new approach for estimating the rate-distortion function in the continuous setting. Compared to state-of-the-art neural approaches (Yang and Mandt, 2022; Lei et al., 2022), our Wasserstein gradient descent algorithm offers complementary strengths: 1) it requires a single main hyperparameter $n$ (the number of particles) and no network architecture tuning; and 2) empirically we found it to converge significantly faster and rarely end up in bad local optima; increasing $n$ almost always yielded an improvement (unless the bound was already close to being tight). From a modeling perspective, a particle representation may be inherently more efficient when the optimal reproduction distribution is singular or has many disconnected modes (shown, e.g., in Figs. 1 and 3). However, like NERD (Lei et al., 2022), our method has a fundamental limitation: it requires an $n$ that is exponential in the rate of the targeted $R(D)$ point to estimate it accurately (see Sec. 4.4). Thus, a neural method like the RD-VAE (Yang and Mandt, 2022) may still be preferable on high-rate sources, while our particle-based method stands as a state-of-the-art solution on lower-rate sources with substantially less tuning or computation requirements.

Besides R-D estimation, our algorithm also applies to the mathematically equivalent problems of maximum-likelihood deconvolution and projection under an entropic optimal transport (EOT) cost, and may find other connections and applications. Indeed, the EOT projection view of our algorithm is further related to optimization-based approaches to sampling (Wibisono, 2018), variational inference (Liu and Wang, 2016), and distribution compression (Shetty et al., 2021). Our particle-based algorithm also generalizes optimal quantization (corresponding to $\epsilon=0$ in $\mathcal{L}_{EOT}$, i.e., projection of the source under the Wasserstein distance (Graf and Luschgy, 2007; Gray, 2013)) to incorporate a rate constraint ($\epsilon>0$), and it would be interesting to explore the use of the resulting rate-distortion optimal quantizer for practical data compression and communication.

References

  • Shannon [1959] CE Shannon. Coding theorems for a discrete source with a fidelity criterion. IRE Nat. Conv. Rec., March 1959, 4:142–163, 1959.
  • Blahut [1972] R. Blahut. Computation of channel capacity and rate-distortion functions. IEEE Transactions on Information Theory, 18(4):460–473, 1972. doi: 10.1109/TIT.1972.1054855.
  • Arimoto [1972] Suguru Arimoto. An algorithm for computing the capacity of arbitrary discrete memoryless channels. IEEE Transactions on Information Theory, 18(1):14–20, 1972.
  • Harrison and Kontoyiannis [2008] Matthew T. Harrison and Ioannis Kontoyiannis. Estimation of the rate–distortion function. IEEE Transactions on Information Theory, 54(8):3757–3762, 2008. doi: 10.1109/tit.2008.926387.
  • Gibson [2017] Jerry Gibson. Rate distortion functions and rate distortion function lower bounds for real-world sources. Entropy, 19(11):604, 2017.
  • Yang and Mandt [2022] Yibo Yang and Stephan Mandt. Towards empirical sandwich bounds on the rate-distortion function. In International Conference on Learning Representations, 2022.
  • Lei et al. [2023a] Eric Lei, Hamed Hassani, and Shirin Saeedi Bidokhti. Neural estimation of the rate-distortion function with applications to operational source coding. IEEE Journal on Selected Areas in Information Theory, 2023a.
  • Ambrosio et al. [2008] Luigi Ambrosio, Nicola Gigli, and Giuseppe Savaré. Gradient flows in metric spaces and in the space of probability measures. Lectures in Mathematics ETH Zürich. Birkhäuser Verlag, Basel, second edition, 2008.
  • Mena and Niles-Weed [2019] Gonzalo Mena and Jonathan Niles-Weed. Statistical bounds for entropic optimal transport: sample complexity and the central limit theorem. Advances in Neural Information Processing Systems, 32, 2019.
  • Polyanskiy and Wu [2022] Yury Polyanskiy and Yihong Wu. Information theory: From coding to learning. Book draft, 2022.
  • Gray [2011] Robert M Gray. Entropy and information theory. Springer Science & Business Media, 2011.
  • Dembo and Kontoyiannis [2002] Amir Dembo and L Kontoyiannis. Source coding, large deviations, and approximate pattern matching. IEEE Transactions on Information Theory, 48(6):1590–1615, 2002.
  • Csiszár [1974a] Imre Csiszár. On an extremum problem of information theory. Studia Scientiarum Mathematicarum Hungarica, 9, 01 1974a.
  • Peyré and Cuturi [2019] Gabriel Peyré and Marco Cuturi. Computational optimal transport: With applications to data science. Foundations and Trends in Machine Learning, 11(5-6):355–607, 2019.
  • Bassetti et al. [2006] Federico Bassetti, Antonella Bodini, and Eugenio Regazzini. On minimum kantorovich distance estimators. Statistics & probability letters, 76(12):1298–1302, 2006.
  • Agueh and Carlier [2011] Martial Agueh and Guillaume Carlier. Barycenters in the wasserstein space. SIAM Journal on Mathematical Analysis, 43(2):904–924, 2011.
  • Alemi et al. [2018] Alexander Alemi, Ben Poole, Ian Fischer, Joshua Dillon, Rif A Saurous, and Kevin Murphy. Fixing a broken ELBO. In International Conference on Machine Learning, pages 159–168. PMLR, 2018.
  • Ballé et al. [2017] Johannes Ballé, Valero Laparra, and Eero P Simoncelli. End-to-end optimized image compression. International Conference on Learning Representations, 2017.
  • Theis et al. [2017] Lucas Theis, Wenzhe Shi, Andrew Cunningham, and Ferenc Huszár. Lossy image compression with compressive autoencoders. International Conference on Learning Representations, 2017.
  • Yang et al. [2020] Yibo Yang, Robert Bamler, and Stephan Mandt. Improving inference for neural image compression. In Neural Information Processing Systems (NeurIPS), 2020, 2020.
  • Rigollet and Weed [2018] Philippe Rigollet and Jonathan Weed. Entropic optimal transport is maximum-likelihood deconvolution. Comptes Rendus Mathematique, 356(11-12):1228–1235, 2018.
  • Carroll and Hall [1988] Raymond J Carroll and Peter Hall. Optimal rates of convergence for deconvolving a density. Journal of the American Statistical Association, 83(404):1184–1186, 1988.
  • Lindsay and Roeder [1993] Bruce G Lindsay and Kathryn Roeder. Uniqueness of estimation and identifiability in mixture models. Canadian Journal of Statistics, 21(2):139–147, 1993.
  • Csiszár [1974b] Imre Csiszár. On the computation of rate-distortion functions (corresp.). IEEE Transactions on Information Theory, 20(1):122–124, 1974b. doi: 10.1109/TIT.1974.1055146.
  • Csiszár [1984] Imre Csiszár. Information geometry and alternating minimization procedures. Statistics and Decisions, Dedewicz, 1:205–237, 1984.
  • Gray and Neuhoff [1998] Robert M. Gray and David L. Neuhoff. Quantization. IEEE transactions on information theory, 44(6):2325–2383, 1998.
  • Papamakarios et al. [2021] George Papamakarios, Eric Nalisnick, Danilo Jimenez Rezende, Shakir Mohamed, and Balaji Lakshminarayanan. Normalizing flows for probabilistic modeling and inference. The Journal of Machine Learning Research, 22(1):2617–2680, 2021.
  • Kingma and Welling [2013] Diederik P Kingma and Max Welling. Auto-encoding variational bayes. arXiv preprint arXiv:1312.6114, 2013.
  • Wu et al. [2022] Shitong Wu, Wenhao Ye, Hao Wu, Huihui Wu, Wenyi Zhang, and Bo Bai. A communication optimal transport approach to the computation of rate distortion functions. arXiv preprint arXiv:2212.10098, 2022.
  • Lei et al. [2023b] Eric Lei, Hamed Hassani, and Shirin Saeedi Bidokhti. On a relation between the rate-distortion function and optimal transport, 2023b.
  • Genevay et al. [2019] Aude Genevay, Lénaic Chizat, Francis Bach, Marco Cuturi, and Gabriel Peyré. Sample complexity of Sinkhorn divergences. In The 22nd International Conference on Artificial Intelligence and Statistics, pages 1574–1583. PMLR, 2019.
  • Rigollet and Stromme [2022] Philippe Rigollet and Austin J Stromme. On the sample complexity of entropic optimal transport. arXiv preprint arXiv:2206.13472, 2022.
  • Yan et al. [2023] Yuling Yan, Kaizheng Wang, and Philippe Rigollet. Learning gaussian mixtures using the wasserstein-fisher-rao gradient flow. arXiv preprint arXiv:2301.01766, 2023.
  • Chizat et al. [2018] Lenaic Chizat, Gabriel Peyré, Bernhard Schmitzer, and François-Xavier Vialard. An interpolating distance between optimal transport and fisher–rao metrics. Foundations of Computational Mathematics, 18:1–44, 2018.
  • Chizat [2022] Lénaïc Chizat. Mean-field langevin dynamics: Exponential convergence and annealing. arXiv preprint arXiv:2202.01009, 2022.
  • Nutz [2021] Marcel Nutz. Introduction to entropic optimal transport. Lecture notes, Columbia University, 2021. https://www.math.columbia.edu/~mnutz/docs/EOT_lecture_notes.pdf.
  • Carlier et al. [2022] Guillaume Carlier, Lénaïc Chizat, and Maxime Laborde. Lipschitz continuity of the Schrödinger map in entropic optimal transport. arXiv preprint arXiv:2210.00225, 2022.
  • Berger [1971] Toby Berger. Rate distortion theory, a mathematical basis for data compression. Prentice Hall, 1971.
  • Eckstein and Nutz [2022] Stephan Eckstein and Marcel Nutz. Convergence rates for regularized optimal transport via quantization. arXiv preprint arXiv:2208.14391, 2022.
  • Lei et al. [2022] Eric Lei, Hamed Hassani, and Shirin Saeedi Bidokhti. Neural estimation of the rate-distortion function for massive datasets. In 2022 IEEE International Symposium on Information Theory (ISIT), pages 608–613. IEEE, 2022.
  • LeCun et al. [1998] Yann LeCun, Léon Bottou, Yoshua Bengio, and Patrick Haffner. Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86(11):2278–2324, 1998.
  • Wibisono [2018] Andre Wibisono. Sampling as optimization in the space of measures: The langevin dynamics as a composite optimization problem. In Sébastien Bubeck, Vianney Perchet, and Philippe Rigollet, editors, Proceedings of the 31st Conference On Learning Theory, volume 75 of Proceedings of Machine Learning Research, pages 2093–3027. PMLR, 06–09 Jul 2018. URL https://proceedings.mlr.press/v75/wibisono18a.html.
  • Liu and Wang [2016] Qiang Liu and Dilin Wang. Stein variational gradient descent: A general purpose bayesian inference algorithm. Advances in neural information processing systems, 29, 2016.
  • Shetty et al. [2021] Abhishek Shetty, Raaz Dwivedi, and Lester Mackey. Distribution compression in near-linear time. arXiv preprint arXiv:2111.07941, 2021.
  • Graf and Luschgy [2007] Siegfried Graf and Harald Luschgy. Foundations of quantization for probability distributions. Springer, 2007.
  • Gray [2013] Robert M. Gray. Transportation distance, shannon information, and source coding. GRETSI 2013 Symposium on Signal and Image Processing, 2013. URL https://ee.stanford.edu/~gray/gretsi.pdf.
  • Çinlar [2011] Erhan Çinlar. Probability and stochastics, volume 261. Springer, 2011.
  • Folland [1999] Gerald B Folland. Real analysis: modern techniques and their applications, volume 40. John Wiley & Sons, 1999.
  • Polyanskiy and Wu [2014] Yury Polyanskiy and Yihong Wu. Lecture notes on information theory. Lecture Notes for ECE563 (UIUC) and, 6(2012-2016):7, 2014.
  • Yang et al. [2023] Yibo Yang, Stephan Mandt, and Lucas Theis. An introduction to neural data compression. Foundations and Trends® in Computer Graphics and Vision, 15(2):113–200, 2023. ISSN 1572-2740. doi: 10.1561/0600000107. URL http://dx.doi.org/10.1561/0600000107.
  • Blei et al. [2017] David M Blei, Alp Kucukelbir, and Jon D McAuliffe. Variational inference: A review for statisticians. Journal of the American Statistical Association, 112(518):859–877, 2017.
  • Beal and Ghahramani [2003] MJ Beal and Z Ghahramani. The variational Bayesian EM algorithm for incomplete data: with application to scoring graphical model structures. Bayesian Statistics, 7(453-464):210, 2003.
  • Dempster et al. [1977] Arthur P Dempster, Nan M Laird, and Donald B Rubin. Maximum likelihood from incomplete data via the EM algorithm. Journal of the Royal Statistical Society: Series B (Methodological), 39(1):1–22, 1977.
  • Zhang et al. [2020] Mingtian Zhang, Peter Hayes, Thomas Bird, Raza Habib, and David Barber. Spread divergence. In International Conference on Machine Learning, pages 11106–11116. PMLR, 2020.
  • Vincent [2011] Pascal Vincent. A connection between score matching and denoising autoencoders. Neural computation, 23(7):1661–1674, 2011.
  • Jordan et al. [1999] Michael I Jordan, Zoubin Ghahramani, Tommi S Jaakkola, and Lawrence K Saul. An introduction to variational methods for graphical models. Machine learning, 37:183–233, 1999.
  • Wainwright et al. [2008] Martin J Wainwright, Michael I Jordan, et al. Graphical models, exponential families, and variational inference. Foundations and Trends® in Machine Learning, 1(1–2):1–305, 2008.
  • Jordan [1999] Michael Irwin Jordan. Learning in graphical models. MIT press, 1999.
  • Kingma and Ba [2015] Diederik P Kingma and Jimmy Lei Ba. Adam: A method for stochastic optimization. In International Conference on Learning Representations, 2015.
  • Platt and Barr [1987] John Platt and Alan Barr. Constrained differential optimization. In Neural Information Processing Systems, 1987.
  • Papamakarios et al. [2017] George Papamakarios, Theo Pavlakou, and Iain Murray. Masked autoregressive flow for density estimation. In Advances in Neural Information Processing Systems, pages 2338–2347, 2017.
  • Grosse et al. [2015] Roger B Grosse, Zoubin Ghahramani, and Ryan P Adams. Sandwiching the marginal likelihood using bidirectional Monte Carlo. arXiv preprint arXiv:1511.02543, 2015.
  • Bradbury et al. [2018] James Bradbury, Roy Frostig, Peter Hawkins, Matthew James Johnson, Chris Leary, Dougal Maclaurin, George Necula, Adam Paszke, Jake VanderPlas, Skye Wanderman-Milne, and Qiao Zhang. JAX: composable transformations of Python+NumPy programs, 2018. URL http://github.com/google/jax.

Broader Impacts

Improved estimates of the fundamental cost of data compression aid the development and analysis of compression algorithms. This helps researchers and engineers make better decisions about where to allocate resources to improve compression algorithms, and can translate to economic gains for the broader society. However, like most machine learning algorithms trained on data, the output of our estimator is only accurate insofar as the training data is representative of the population distribution of interest, and practitioners need to ensure this in the data collection process.

Acknowledgements

We thank anonymous reviewers for feedback on the manuscript. Yibo Yang acknowledges support from the Hasso Plattner Foundation. Marcel Nutz acknowledges support from NSF Grants DMS-1812661, DMS-2106056. Stephan Mandt acknowledges support from the National Science Foundation (NSF) under the NSF CAREER Award 2047418; NSF Grants 2003237 and 2007719, the Department of Energy, Office of Science under grant DE-SC0022331, the IARPA WRIVA program, as well as gifts from Intel, Disney, and Qualcomm.

Supplementary Material for
Estimating the Rate-Distortion Function by Wasserstein Gradient Descent

We review probability theory background and explain our notation from the main text in Section 7, give the formulas we used for numerically estimating an $R(D)$ upper bound in Section 8, provide additional discussions and proofs regarding Wasserstein gradient descent in Section 9, elaborate on the connections between the R-D estimation problem and variational inference/learning in Section 10, provide additional experimental results and details in Section 11, and list an example implementation of WGD in Section 12. Our code can be found at https://github.com/yiboyang/wgd.

7 Notions from probability theory

In this section we collect notions of probability theory used in the main text. See, e.g., [Çinlar, 2011] or [Folland, 1999] for more background.

Marginal and conditional distributions.

The source and reproduction spaces 𝒳,𝒴{\mathcal{X}},{\mathcal{Y}} are equipped with sigma-algebras 𝒜𝒳\mathcal{A}_{\mathcal{X}} and 𝒜𝒴\mathcal{A}_{{\mathcal{Y}}}, respectively. Let 𝒳×𝒴{\mathcal{X}}\times{\mathcal{Y}} denote the product space equipped with the product sigma algebra 𝒜𝒳𝒜𝒴\mathcal{A}_{\mathcal{X}}\otimes\mathcal{A}_{{\mathcal{Y}}}. For any probability measure π\pi on 𝒳×𝒴{\mathcal{X}}\times{\mathcal{Y}}, its first marginal is

π1(A):=π(A×𝒴),A𝒜𝒳,\displaystyle\pi_{1}(A):=\pi(A\times{\mathcal{Y}}),\quad A\in\mathcal{A}_{\mathcal{X}},

which is a probability measure on 𝒳{\mathcal{X}}. When π\pi is the distribution of a random vector (X,Y)(X,Y), then π1\pi_{1} is the distribution of XX. The second marginal of π\pi is defined analogously as

π2(B):=π(𝒳×B),B𝒜𝒴.\displaystyle\pi_{2}(B):=\pi({\mathcal{X}}\times B),\quad B\in\mathcal{A}_{\mathcal{Y}}.

A Markov kernel or conditional distribution K(x,dy)K(x,dy) is a map 𝒳×𝒜𝒴[0,1]{\mathcal{X}}\times\mathcal{A}_{{\mathcal{Y}}}\to[0,1] such that

  1. 1.

    K(x,)K(x,\cdot) is a probability measure on 𝒴{\mathcal{Y}} for each x𝒳x\in{\mathcal{X}};

  2. 2.

    the function xK(x,B)x\mapsto K(x,B) is measurable for each set B𝒜𝒴B\in\mathcal{A}_{{\mathcal{Y}}}.

When speaking of the conditional distribution of a random variable YY given another random variable XX, we occasionally also use the notation QY|XQ_{Y|X} from information theory [Polyanskiy and Wu, 2014]. Then, QY|X=x(B)=K(x,B)Q_{Y|X=x}(B)=K(x,B) is the conditional probability of the event {YB}\{Y\in B\} given X=xX=x.

Suppose that a probability measure μ\mu on 𝒳{\mathcal{X}} is given, in addition to a kernel K(x,dy)K(x,dy). Together they define a unique measure μK\mu\otimes K on the product space 𝒳×𝒴{\mathcal{X}}\times{\mathcal{Y}}. For a rectangle set A×B𝒜𝒳𝒜𝒴A\times B\in\mathcal{A}_{\mathcal{X}}\otimes\mathcal{A}_{{\mathcal{Y}}},

μK(A×B)=Aμ(dx)K(x,B),A𝒜𝒳,B𝒜𝒴.\displaystyle\mu\otimes K(A\times B)=\int_{A}\mu(dx)K(x,B),\quad A\in\mathcal{A}_{\mathcal{X}},B\in\mathcal{A}_{\mathcal{Y}}.

The measure π:=μK\pi:=\mu\otimes K has first marginal π1=μ\pi_{1}=\mu.

The classic product measure is a special case of this construction. Namely, when a measure ν\nu on 𝒴{\mathcal{Y}} is given, using the constant kernel K(x,dy):=ν(dy)K(x,dy):=\nu(dy) (which does not depend on xx) gives rise to the product measure μν\mu\otimes\nu,

μν(A×B)=μ(A)ν(B),A𝒜𝒳,B𝒜𝒴.\displaystyle\mu\otimes\nu(A\times B)=\mu(A)\nu(B),\quad A\in\mathcal{A}_{\mathcal{X}},B\in\mathcal{A}_{\mathcal{Y}}.

Under mild conditions (for instance when 𝒳,𝒴{\mathcal{X}},{\mathcal{Y}} are Polish spaces equipped with their Borel sigma algebras, as in the main text), any probability measure π\pi on 𝒳×𝒴{\mathcal{X}}\times{\mathcal{Y}} is of the above form. Namely, the disintegration theorem asserts that π\pi can be written as π=π1K\pi=\pi_{1}\otimes K for some kernel KK. When π\pi is the joint distribution of a random vector (X,Y)(X,Y), this says that there is a measurable version of the conditional distribution QY|XQ_{Y|X}.

Optimal transport.

Given a measure μ\mu on 𝒳{\mathcal{X}} and a measurable function T:𝒳𝒴T:{\mathcal{X}}\to{\mathcal{Y}}, the pushforward (or image measure) of μ\mu under TT is a measure on 𝒴{\mathcal{Y}}, given by

T#μ(B)=μ(T1(B)),B𝒜𝒴.\displaystyle T_{\#}\mu(B)=\mu(T^{-1}(B)),\quad B\in\mathcal{A}_{\mathcal{Y}}.

If TT is seen as a random variable and μ\mu as the baseline probability measure, then T#μT_{\#}\mu is simply the distribution of TT.

Suppose that μ\mu and ν\nu are probability measures on 𝒳=𝒴=d{\mathcal{X}}={\mathcal{Y}}=\mathbb{R}^{d} with finite second moment. As introduced in the main text, Π(μ,ν)\Pi(\mu,\nu) denotes the set of couplings, i.e., measures π\pi on 𝒳×𝒴{\mathcal{X}}\times{\mathcal{Y}} with π1=μ\pi_{1}=\mu and π2=ν\pi_{2}=\nu. The 2-Wasserstein distance W2(μ,ν)W_{2}(\mu,\nu) between μ\mu and ν\nu is defined as

W2(μ,ν)=(infπΠ(μ,ν)yx2π(dx,dy))1/2.W_{2}(\mu,\nu)=\left(\inf_{\pi\in\Pi(\mu,\nu)}\int\|y-x\|^{2}\pi(dx,dy)\right)^{1/2}.

This indeed defines a metric on the space of probability measures with finite second moment.
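For two empirical measures with the same number of equally weighted atoms, the infimum above is attained by a permutation coupling, so $W_2$ can be computed exactly with the Hungarian algorithm. The following is a minimal sketch of this special case (not part of our method; the function name w2_empirical and the use of scipy are our own choices for illustration).

# Minimal sketch: exact W_2 between two empirical measures with n equally weighted atoms each.
import numpy as np
from scipy.optimize import linear_sum_assignment

def w2_empirical(X, Y):
    """X, Y: arrays of shape [n, d] holding the atoms of two uniform empirical measures."""
    # Pairwise squared Euclidean costs C[i, j] = ||X[i] - Y[j]||^2.
    C = ((X[:, None, :] - Y[None, :, :]) ** 2).sum(-1)
    row, col = linear_sum_assignment(C)  # optimal permutation coupling
    return np.sqrt(C[row, col].mean())   # W_2 = (minimal average squared cost)^(1/2)

X = np.random.default_rng(0).normal(size=(100, 2))
Y = np.random.default_rng(1).normal(size=(100, 2)) + 1.0
print(w2_empirical(X, Y))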

We finish this section by giving a proof of the basic equivalence between optimization of EOT\mathcal{L}_{EOT} and BA\mathcal{L}_{BA}, which goes back to [Csiszár, 1974a, Lemma 1.3 and subsequent discussion]. Recall that Π(μ,)\Pi(\mu,\cdot) denotes the set of measures on the product space 𝒳×𝒴{\mathcal{X}}\times{\mathcal{Y}} with π1=μ\pi_{1}=\mu.

Lemma 7.1.

Set ϵ=λ1\epsilon=\lambda^{-1}. It holds that

infν𝒫(𝒴)EOT(ν)=infν𝒫(𝒴)λ1BA(ν)andargminν𝒫(𝒴)EOT(ν)=argminν𝒫(𝒴)BA(ν).\inf_{\nu\in\mathcal{P}({\mathcal{Y}})}\mathcal{L}_{EOT}(\nu)=\inf_{\nu\in\mathcal{P}({\mathcal{Y}})}\lambda^{-1}\mathcal{L}_{BA}(\nu)\qquad\mbox{and}\qquad\operatorname*{arg\,min}_{\nu\in\mathcal{P}({\mathcal{Y}})}\mathcal{L}_{EOT}(\nu)=\operatorname*{arg\,min}_{\nu\in\mathcal{P}({\mathcal{Y}})}\mathcal{L}_{BA}(\nu).
Proof.

Both statements will follow from a simple property of relative entropy [Polyanskiy and Wu, 2022, Theorem 4.1, “golden formula”]: For πΠ(μ,)\pi\in\Pi(\mu,\cdot) with H(π2ν)<H(\pi_{2}\mid\nu)<\infty, the properties of the logarithm reveal

H(πμν)=H(πμπ2)+H(π2ν),\displaystyle H(\pi\mid\mu\otimes\nu)=H(\pi\mid\mu\otimes\pi_{2})+H(\pi_{2}\mid\nu), (15)

and hence H(πμν)H(πμπ2)H(\pi\mid\mu\otimes\nu)\geq H(\pi\mid\mu\otimes\pi_{2}). This implies

infν𝒫(𝒴)EOT(ν)\displaystyle\inf_{\nu\in\mathcal{P}({\mathcal{Y}})}\mathcal{L}_{EOT}(\nu) =infπΠ(μ,)ρ𝑑π+ϵH(πμπ2)\displaystyle=\inf_{\pi\in\Pi(\mu,\cdot)}\int\rho d\pi+\epsilon H(\pi\mid\mu\otimes\pi_{2})
=infπΠ(μ,)infν𝒫(𝒴)ρ𝑑π+ϵH(πμν)\displaystyle=\inf_{\pi\in\Pi(\mu,\cdot)}\inf_{\nu\in\mathcal{P}({\mathcal{Y}})}\int\rho d\pi+\epsilon H(\pi\mid\mu\otimes\nu)
=1λinfν𝒫(𝒴)BA(ν).\displaystyle=\frac{1}{\lambda}\inf_{\nu\in\mathcal{P}({\mathcal{Y}})}\mathcal{L}_{BA}(\nu).

Further, any optimizer $\nu^{*}$ of $\mathcal{L}_{BA}$ with corresponding optimal “coupling” $\pi^{*}\in\Pi(\mu,\cdot)$ must satisfy $H(\pi_{2}^{*}\mid\nu^{*})=0$ (otherwise, replacing $\nu^{*}$ by $\pi^{*}_{2}$ yields a strictly better objective), and thus $\pi_{2}^{*}=\nu^{*}$ and $\pi^{*}\in\Pi(\mu,\nu^{*})$; therefore $\nu^{*}$ is also an optimizer of $\mathcal{L}_{EOT}$ by the above equality. Conversely, any optimizer $\nu^{*}$ of $\mathcal{L}_{EOT}$ with coupling $\pi^{*}\in\Pi(\mu,\nu^{*})\subset\Pi(\mu,\cdot)$ clearly yields a feasible solution for $\mathcal{L}_{BA}$ as well, and hence is also an optimizer of $\mathcal{L}_{BA}$ by the above equality. ∎

8 Numerical estimation of rate and distortion from a reproduction distribution

Given a reproduction distribution $\nu$ and a kernel $K:{\mathcal{X}}\times\mathcal{A}_{{\mathcal{Y}}}\to[0,1]$, the tuple $(\mathcal{D},\mathcal{R})\in\mathbb{R}^{2}$ defined by

𝒟:=ρd(μK)\mathcal{D}:=\int\rho d(\mu\otimes K)

and

:=H(μK|μν)\mathcal{R}:=H(\mu\otimes K|\mu\otimes\nu)

lies above the $R(D)$ curve. This again follows from the variational inequality (15): writing $\pi:=\mu\otimes K$, we have $\mathcal{R}=H(\pi\mid\mu\otimes\nu)\geq H(\pi\mid\mu\otimes\pi_{2})=I(\pi)$, and since $\int\rho\,d\pi=\mathcal{D}$, it follows that $\mathcal{R}\geq R(\mathcal{D}):=\inf_{\pi\in\Pi(\mu,\cdot):\pi(\rho)\leq\mathcal{D}}I(\pi)$.

Given only a reproduction distribution ν\nu, we will construct a kernel KK from ν\nu and use (μ,K)(\mu,K) to compute (𝒟,)(\mathcal{D},\mathcal{R}), letting

dK(x,)dν(y)=eλρ(x,y)eλρ(x,y~)ν(dy~)\frac{dK(x,\cdot)}{d\nu}(y)=\frac{e^{-\lambda\rho(x,y)}}{\int e^{-\lambda\rho(x,\tilde{y})}\nu(d\tilde{y})}

as in the first step of the BA algorithm. This choice of KK is in fact optimal as it achieves the minimum in the definition of BA{\mathcal{L}}_{BA} (5) for the given ν\nu. Plugging KK into the formulas for 𝒟\mathcal{D} and \mathcal{R} gives

𝒟=ρ(x,y)eφ(x)λρ(x,y)μν(dx,dy)\displaystyle\mathcal{D}=\int\rho(x,y)e^{\varphi(x)-\lambda\rho(x,y)}\mu\otimes\nu(dx,dy) (16)

and

\displaystyle\mathcal{R} =log(K(x,dy)ν(dy))K(x,dy)μ(dx)\displaystyle=\int\int\log\left(\frac{K(x,dy)}{\nu(dy)}\right)K(x,dy)\mu(dx) (17)
=[φ(x)λρ(x,y)]eφ(x)λρ(x,y)μν(dx,dy),\displaystyle=\int[\varphi(x)-\lambda\rho(x,y)]e^{\varphi(x)-\lambda\rho(x,y)}\mu\otimes\nu(dx,dy),

where we introduced the shorthand

φ(x):=log𝒴eλρ(x,y)ν(dy).\varphi(x):=-\log\int_{\mathcal{Y}}e^{-\lambda\rho(x,y)}\nu(dy). (18)

Note that we have the following relation (which explains (6))

BA(ν)=+λ𝒟=φ(x)μ(dx).{\mathcal{L}}_{BA}(\nu)=\mathcal{R}+\lambda\mathcal{D}=\int\varphi(x)\mu(dx).

Let $\nu$ be an $n$-point measure, $\nu=\sum_{i=1}^{n}w_{i}\delta_{y_{i}}$, e.g., the output of WGD or the discrete measure used in the inner step of NERD. Then $\varphi(x)$ can be evaluated exactly as a finite sum over the $\nu$ particles, and the expressions above for $\mathcal{D}$ and $\mathcal{R}$ (which are integrals w.r.t. the product distribution $\mu\otimes\nu$) can be estimated as sample averages. That is, given $m$ independent samples $\{x_{i}\}_{i=1}^{m}$ from $\mu$, we compute unbiased estimates

𝒟^\displaystyle\hat{\mathcal{D}} =i=1mj=1n1mwjρ(xi,yj)eφ(xi)λρ(xi,yj),\displaystyle=\sum_{i=1}^{m}\sum_{j=1}^{n}\frac{1}{m}w_{j}\rho(x_{i},y_{j})e^{\varphi(x_{i})-\lambda\rho(x_{i},y_{j})},
^\displaystyle\hat{\mathcal{R}} =i=1mj=1n1mwj[φ(xi)λρ(xi,yj)]eφ(xi)λρ(xi,yj),\displaystyle=\sum_{i=1}^{m}\sum_{j=1}^{n}\frac{1}{m}w_{j}[\varphi(x_{i})-\lambda\rho(x_{i},y_{j})]e^{\varphi(x_{i})-\lambda\rho(x_{i},y_{j})},

where $\varphi(x_{i})=-\log\sum_{j=1}^{n}w_{j}e^{-\lambda\rho(x_{i},y_{j})}$. Similarly, a sample mean estimate for ${\mathcal{L}}_{BA}$ is given by

\hat{\mathcal{L}}_{BA}(\nu)=\frac{1}{m}\sum_{i=1}^{m}\varphi(x_{i})=-\frac{1}{m}\sum_{i=1}^{m}\log\left(\sum_{j=1}^{n}w_{j}e^{-\lambda\rho(x_{i},y_{j})}\right). (19)

In practice, we found it simpler and numerically more stable to instead compute ^\hat{\mathcal{R}} as

^=^BA(ν)λ𝒟^.\hat{\mathcal{R}}=\hat{\mathcal{L}}_{BA}(\nu)-\lambda\hat{\mathcal{D}}.

Whenever possible, we avoid exponentiation and instead use logsumexp to prevent numerical issues.
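To make the estimators above concrete, here is a minimal sketch (variable names are ours; it assumes the squared-error distortion used in our experiments) that evaluates $\varphi$ with logsumexp and then forms $\hat{\mathcal{D}}$, $\hat{\mathcal{L}}_{BA}$ and $\hat{\mathcal{R}}$ as described.

# Minimal sketch of the sample-based estimates in this section: mu is given by m samples
# xs [m, d], and nu is an n-point measure with locations ys [n, d] and weights w [n].
import numpy as np
from scipy.special import logsumexp

def estimate_rate_distortion(xs, ys, w, rd_lambda):
    C = ((xs[:, None, :] - ys[None, :, :]) ** 2).sum(-1)           # rho(x_i, y_j) (squared error)
    # phi(x_i) = -log sum_j w_j exp(-lambda * rho(x_i, y_j)), computed stably via logsumexp.
    phi = -logsumexp(np.log(w)[None, :] - rd_lambda * C, axis=1)   # [m]
    loss = phi.mean()                                              # estimate of L_BA(nu), eq. (19)
    # K[i, j] = w_j exp(phi(x_i) - lambda rho(x_i, y_j)); each row sums to one.
    K = w[None, :] * np.exp(phi[:, None] - rd_lambda * C)
    distortion = (K * C).sum(axis=1).mean()                        # estimate of D, eq. (16)
    rate = loss - rd_lambda * distortion                           # R = L_BA - lambda * D
    return rate, distortion, loss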

9 Wasserstein gradient descent

9.1 On Wasserstein gradients of the EOT and rate functionals

First, we elaborate on the Wasserstein gradient of the EOT functional EOT(ν)\mathcal{L}_{EOT}(\nu). That the dual potential from Sinkhorn’s algorithm is differentiable follows from the fact that optimal dual potentials satisfy the Schrödinger equations (cf. [Nutz, 2021, Corollary 2.5]). Differentiability was shown in [Genevay et al., 2019, Theorem 2] in the compact setting, and in [Mena and Niles-Weed, 2019, Proposition 1] in unbounded settings. While Mena and Niles-Weed [2019] only states the result for quadratic cost, the approach of Proposition 1 therein applies more generally.

Below, we compute the Wasserstein gradient of the rate functional BA(ν)\mathcal{L}_{BA}(\nu). Recall from (6),

BA(ν)=logexp(λρ(x,y))ν(dy)μ(dx).\mathcal{L}_{BA}(\nu)=\int-\log\int\exp(-\lambda\rho(x,y))\nu(dy)\mu(dx).

Under sufficient integrability on μ\mu and ν\nu to exchange the order of limit and integral, we can calculate the first variation as

limε0((1ε)ν+εν~)(ν)ε\displaystyle\lim_{\varepsilon\rightarrow 0}\frac{\mathcal{L}((1-\varepsilon)\nu+\varepsilon\tilde{\nu})-\mathcal{L}(\nu)}{\varepsilon} =limε01εlog[exp(λρ(x,y))(ν+ε(ν~ν))(dy)exp(λρ(x,y))ν(dy)]μ(dx)\displaystyle=-\int\lim_{\varepsilon\rightarrow 0}\frac{1}{\varepsilon}\log\left[\frac{\int\exp(-\lambda\rho(x,y))(\nu+\varepsilon(\tilde{\nu}-\nu))(dy)}{\int\exp(-\lambda\rho(x,y))\nu(dy)}\right]\mu(dx)
=limε01εlog[1+exp(λρ(x,y))ε(ν~ν)(dy)exp(λρ(x,y))ν(dy)]μ(dx)\displaystyle=-\int\lim_{\varepsilon\rightarrow 0}\frac{1}{\varepsilon}\log\left[1+\frac{\int\exp(-\lambda\rho(x,y))\varepsilon(\tilde{\nu}-\nu)(dy)}{\int\exp(-\lambda\rho(x,y))\nu(dy)}\right]\mu(dx)
=exp(λρ(x,y))exp(λρ(x,y~))ν(dy~)μ(dx)(ν~ν)(dy),\displaystyle=\iint-\frac{\exp(-\lambda\rho(x,y))}{\int\exp(-\lambda\rho(x,\tilde{y}))\nu(d\tilde{y})}\mu(dx)\,(\tilde{\nu}-\nu)(dy),

where the last equality uses limε01εlog(1+εx)=x\lim_{\varepsilon\rightarrow 0}\frac{1}{\varepsilon}\log(1+\varepsilon x)=x and Fubini’s theorem. Thus the first variation ψν\psi^{\nu} of BA\mathcal{L}_{BA} at ν\nu is

ψν(y)=exp(λρ(x,y))exp(λρ(x,y~))ν(dy~)μ(dx).\displaystyle\psi^{\nu}(y)=\int-\frac{\exp(-\lambda\rho(x,y))}{\int\exp(-\lambda\rho(x,\tilde{y}))\nu(d\tilde{y})}\mu(dx). (20)

To obtain the Wasserstein gradient of $\mathcal{L}_{BA}$, it remains to take the Euclidean gradient of the first variation $\psi^{\nu}$, i.e., $\nabla\mathcal{L}_{BA}(\nu)=\nabla\psi^{\nu}$:

VBA(ν)[y]\displaystyle\nabla V_{\mathcal{L}_{BA}}(\nu)[y] =yψν(y)\displaystyle=\nabla_{y}\psi^{\nu}(y)
=yexp(λρ(x,y))exp(λρ(x,y~))ν(dy~)μ(dx)\displaystyle=\frac{\partial}{\partial y}\int-\frac{\exp(-\lambda\rho(x,y))}{\int\exp(-\lambda\rho(x,\tilde{y}))\nu(d\tilde{y})}\mu(dx)
=exp(λρ(x,y))λyρ(x,y)exp(λρ(x,y~))ν(dy~)μ(dx),\displaystyle=\int\frac{\exp(-\lambda\rho(x,y))\lambda\frac{\partial}{\partial y}\rho(x,y)}{\int\exp(-\lambda\rho(x,\tilde{y}))\nu(d\tilde{y})}\mu(dx), (21)

again assuming suitable regularity conditions on ρ\rho and μ\mu to exchange the order of integral and differentiation.
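For the squared-error distortion $\rho(x,y)=\|x-y\|^{2}$ we have $\frac{\partial}{\partial y}\rho(x,y)=2(y-x)$, so (21) can be evaluated in closed form when $\mu$ is an empirical measure. The following minimal sketch (names ours) computes the Wasserstein gradient at the atoms of an $n$-point $\nu$ and matches, up to autodiff details, the implementation listed in Section 12.

# Minimal sketch: the Wasserstein gradient (21) at the atoms of nu, for rho(x, y) = ||x - y||^2.
import numpy as np
from scipy.special import logsumexp

def wasserstein_gradient(xs, mu_w, ys, nu_w, rd_lambda):
    """xs [m, d], mu_w [m]: source samples and weights; ys [n, d], nu_w [n]: nu atoms and weights."""
    C = ((xs[:, None, :] - ys[None, :, :]) ** 2).sum(-1)               # [m, n]
    # log Z(x_i) = log int exp(-lambda rho(x_i, y~)) nu(dy~), the denominator in (21).
    log_Z = logsumexp(np.log(nu_w)[None, :] - rd_lambda * C, axis=1)   # [m]
    A = mu_w[:, None] * np.exp(-rd_lambda * C - log_Z[:, None])        # A[i, j] = mu_i e^{-lambda C_ij} / Z(x_i)
    # (21) with d/dy rho(x, y) = 2 (y - x): grad_j = 2 lambda sum_i A[i, j] (y_j - x_i).
    return 2.0 * rd_lambda * (A.sum(0)[:, None] * ys - A.T @ xs)       # [n, d]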

9.2 Proof of Proposition 4.2 (convergence of Wasserstein gradient descent)

We first provide an auxiliary result.

Lemma 9.1.

Let γ1γ20\gamma_{1}\geq\gamma_{2}\geq\dots\geq 0 and at0a_{t}\geq 0, tt\in\mathbb{N}, C>0C>0 satisfy t=1γt=\sum_{t=1}^{\infty}\gamma_{t}=\infty, t=1γt2<\sum_{t=1}^{\infty}\gamma_{t}^{2}<\infty, t=1atγt<\sum_{t=1}^{\infty}a_{t}\gamma_{t}<\infty and |atat+1|Cγt|a_{t}-a_{t+1}|\leq C\gamma_{t} for all tt\in\mathbb{N}. Then limtat=0\lim_{t\rightarrow\infty}a_{t}=0.

Proof.

The conclusion remains unchanged when rescaling ata_{t} by the constant CC, and thus without loss of generality C=1C=1.

Clearly γt0\gamma_{t}\to 0 as t=1γt2<\sum_{t=1}^{\infty}\gamma_{t}^{2}<\infty. Moreover, there exists a subsequence of (at)t(a_{t})_{t\in\mathbb{N}} which converges to zero (otherwise there exists δ>0\delta>0 such that atδ>0a_{t}\geq\delta>0 for all but finitely many tt, contradicting t=1γtat<\sum_{t=1}^{\infty}\gamma_{t}a_{t}<\infty).

Arguing by contradiction, suppose that the conclusion fails, i.e., that there exists a subsequence of (at)t(a_{t})_{t\in\mathbb{N}} which is uniformly bounded away from zero, say atδ>0a_{t}\geq\delta>0 along that subsequence. Using this subsequence and the convergent subsequence mentioned above, we can construct a subsequence ai1,ai2,ai3,a_{i_{1}},a_{i_{2}},a_{i_{3}},\dots where ain0a_{i_{n}}\approx 0 for nn odd and ainδa_{i_{n}}\geq\delta for nn even. We will show that

t=i2n1i2natγtδ2/2for all n,\sum_{t=i_{2n-1}}^{i_{2n}}a_{t}\gamma_{t}\gtrsim\delta^{2}/2\qquad\mbox{for all $n\in\mathbb{N}$},

contradicting the finiteness of tγtat\sum_{t}\gamma_{t}a_{t}. (The notation \approx (\gtrsim) indicates (in)equality up to additive terms converging to zero for nn\rightarrow\infty.)

To ease notation, fix nn and set m=i2n1m=i_{2n-1} and M=i2nM=i_{2n}. We show that t=mMatγtδ2/2\sum_{t=m}^{M}a_{t}\gamma_{t}\gtrsim\delta^{2}/2. To this end, using |atat+1|γt|a_{t}-a_{t+1}|\leq\gamma_{t} we find

a_{t}\geq a_{M}-\sum_{j=t}^{M-1}\gamma_{j}\geq\delta-\sum_{j=t}^{M-1}\gamma_{j}.

Since $a_{m}\approx 0$, there exists a largest $n_{0}\in\mathbb{N}$, $n_{0}\geq m$, such that $\sum_{j=n_{0}}^{M-1}\gamma_{j}\gtrsim\delta$ (and thus $\sum_{j=n_{0}}^{M-1}\gamma_{j}\lesssim\delta+\gamma_{n_{0}}\approx\delta$ as well). We conclude

\sum_{t=m}^{M}\gamma_{t}a_{t}\geq\sum_{t=n_{0}}^{M}\gamma_{t}a_{t}\geq\sum_{t=n_{0}}^{M}\gamma_{t}\left(\delta-\sum_{j=t}^{M-1}\gamma_{j}\right)\gtrsim\delta^{2}-\sum_{t=n_{0}}^{M}\sum_{j=n_{0}}^{M}\gamma_{t}\gamma_{j}\mathbf{1}_{\{j\geq t\}}
=δ212(t=n0Mγt)212t=n0Mγt2\displaystyle=\delta^{2}-\frac{1}{2}\left(\sum_{t=n_{0}}^{M}\gamma_{t}\right)^{2}-\frac{1}{2}\sum_{t=n_{0}}^{M}\gamma_{t}^{2} δ2/2,\displaystyle\approx\delta^{2}/2,

where we used that t=n0Mγt20\sum_{t=n_{0}}^{M}\gamma_{t}^{2}\approx 0. This completes the proof. ∎

Proof of Proposition 4.2.

Using the linear approximation property in (14), we calculate

(ν(n))(ν(0))\displaystyle\mathcal{L}(\nu^{(n)})-\mathcal{L}(\nu^{(0)}) =t=0n1(ν(t+1))(ν(t))\displaystyle=\sum_{t=0}^{n-1}\mathcal{L}(\nu^{(t+1)})-\mathcal{L}(\nu^{(t)})
=t=0n1γtV(ν(t))2𝑑ν(t)+γt2o(V(ν(t))2𝑑ν(t)).\displaystyle=\sum_{t=0}^{n-1}-\gamma_{t}\int\|\nabla V_{\mathcal{L}}(\nu^{(t)})\|^{2}\,d\nu^{(t)}+\gamma_{t}^{2}\,o\left(\int\|\nabla V_{\mathcal{L}}(\nu^{(t)})\|^{2}\,d\nu^{(t)}\right).

As (ν(0))\mathcal{L}(\nu^{(0)}) is finite and (ν(n))\mathcal{L}(\nu^{(n)}) is bounded from below, it follows that

t=0γtV(ν(t))2𝑑ν(t)<.\sum_{t=0}^{\infty}\gamma_{t}\int\|\nabla V_{\mathcal{L}}(\nu^{(t)})\|^{2}\,d\nu^{(t)}<\infty.

The claim now follows by applying Lemma 9.1 with $a_{t}=\int\|\nabla\psi^{\nu^{(t)}}\|^{2}\,d\nu^{(t)}$; note that the assumption in the lemma, $|a_{t}-a_{t+1}|\leq C\gamma_{t}$, is satisfied due to the second inequality in (14) and the short calculation below:

|ψν(t)2ψν(t+1)2|\displaystyle\left|\int\|\nabla\psi^{\nu^{(t)}}\|^{2}-\int\|\nabla\psi^{\nu^{(t+1)}}\|^{2}\right| CW2(ν(t),ν(t+1))\displaystyle\leq CW_{2}(\nu^{(t)},\nu^{(t+1)})
C(xy2δxγtV(ν(t))[x](dy)ν(t)(dx))12\displaystyle\leq C\left(\int\int\|x-y\|^{2}\delta_{x-\gamma_{t}\nabla V_{\mathcal{L}}(\nu^{(t)})[x]}(dy)\nu^{(t)}(dx)\right)^{\frac{1}{2}}
Cγt(V(ν(t))2ν(t)(dx))12\displaystyle\leq C\gamma_{t}\left(\int\|\nabla V_{\mathcal{L}}(\nu^{(t)})\|^{2}\nu^{(t)}(dx)\right)^{\frac{1}{2}}
Cγtsupt(V(ν(t))2ν(t)(dx))12\displaystyle\leq C\gamma_{t}\sup_{t^{\prime}}\left(\int\|\nabla V_{\mathcal{L}}(\nu^{(t^{\prime})})\|^{2}\nu^{(t^{\prime})}(dx)\right)^{\frac{1}{2}}
Cγt,\displaystyle\leq C^{\prime}\gamma_{t},

where we use suptV(ν(t))2ν(t)(dx)<\sup_{t}\int\|\nabla V_{\mathcal{L}}(\nu^{(t)})\|^{2}\nu^{(t)}(dx)<\infty in the last step. ∎

9.3 Proof of Proposition 4.3 (sample complexity)

Recall that 𝒳=𝒴=d{\mathcal{X}}={\mathcal{Y}}=\mathbb{R}^{d} and ρ(x,y)=xy2\rho(x,y)=\|x-y\|^{2} in this proposition. For the proof, we will need the following lemma which is of independent interest. We write νcμ\nu\leq_{c}\mu if ν\nu is dominated by μ\mu in convex order, i.e., f𝑑νf𝑑μ\int f\,d\nu\leq\int f\,d\mu for all convex functions f:df:\mathbb{R}^{d}\rightarrow\mathbb{R}.

Lemma 9.2.

Let μ\mu have finite second moment. Given ν𝒫(d)\nu\in\mathcal{P}(\mathbb{R}^{d}), there exists ν~𝒫(d)\tilde{\nu}\in\mathcal{P}(\mathbb{R}^{d}) with ν~cμ\tilde{\nu}\leq_{c}\mu and

EOT(μ,ν~)EOT(μ,ν).\mathcal{L}_{EOT}(\mu,\tilde{\nu})\leq\mathcal{L}_{EOT}(\mu,\nu).

This inequality is strict if νcμ\nu{\not\leq}_{c}\mu. In particular, any optimizer ν\nu^{*} of (8) satisfies νcμ\nu^{*}\leq_{c}\mu.

Proof.

Because this proof uses disintegration over 𝒴{\mathcal{Y}}, it is convenient to reverse the order of the spaces in the notation and write a generic point as (x,y)𝒴×𝒳(x,y)\in{\mathcal{Y}}\times{\mathcal{X}}. Consider πΠ(ν,μ)\pi\in\Pi(\nu,\mu) and its disintegration π=ν(dx)K(x,dy)\pi=\nu(dx)\otimes K(x,dy) over x𝒴x\in{\mathcal{Y}}. Define T:ddT:\mathbb{R}^{d}\rightarrow\mathbb{R}^{d} by

T(x):=yK(x,dy).T(x):=\int y\,K(x,dy).

Define also π~:=(T,id)#π\tilde{\pi}:=(T,\text{id})_{\#}\pi and ν~:=π~1\tilde{\nu}:=\tilde{\pi}_{1}. From the definition of TT, we see that π~\tilde{\pi} is a martingale, thus ν~cμ\tilde{\nu}\leq_{c}\mu. Moreover, ν~μ=(T,id)#νμ\tilde{\nu}\otimes\mu=(T,\text{id})_{\#}\nu\otimes\mu. The data-processing inequality now shows that

H(π~|ν~μ)H(π|νμ).H(\tilde{\pi}|\tilde{\nu}\otimes\mu)\leq H(\pi|\nu\otimes\mu).

On the other hand, y~K(x,dy~)y2K(x,dy)xy2K(x,dy)\int\|\int\tilde{y}\,K(x,d\tilde{y})-y\|^{2}\,K(x,dy)\leq\int\|x-y\|^{2}K(x,dy) since the barycenter minimizes the squared distance, and this inequality is strict whenever xy~K(x,dy~)x\neq\int\tilde{y}K(x,d\tilde{y}). Thus

xy2π~(dx,dy)xy2π(dx,dy),\int\|x-y\|^{2}\,\tilde{\pi}(dx,dy)\leq\int\|x-y\|^{2}\,\pi(dx,dy),

and the inequality is strict unless T(x)=xT(x)=x for ν\nu-a.e. xx, which in turn is equivalent to π\pi being a martingale. The claims follow. ∎
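The barycenter step in this proof is also easy to check numerically for discrete measures: replacing each atom by its conditional mean under a given coupling never increases the quadratic transport cost. A minimal sketch (all names ours):

# Minimal numerical check of the barycenter step in Lemma 9.2.
import numpy as np

rng = np.random.default_rng(0)
n, m, d = 5, 8, 2
xs = rng.normal(size=(n, d))      # atoms on the nu side (the "x" side in the proof's notation)
ys = rng.normal(size=(m, d))      # atoms on the mu side
P = rng.random((n, m))
P /= P.sum()                      # an arbitrary coupling matrix with total mass one

cost = lambda X, Y: (P * ((X[:, None, :] - Y[None, :, :]) ** 2).sum(-1)).sum()
T = (P @ ys) / P.sum(axis=1, keepdims=True)    # barycentric projection T(x_i) = E[Y | X = x_i]
print(cost(xs, ys), '>=', cost(T, ys))         # projecting never increases the cost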

Proof of Proposition 4.3.

Subgaussianity of the optimizer follows directly from Lemma 9.2.

Recalling that infνEOT(ν)\inf_{\nu}\mathcal{L}_{EOT}(\nu) and infνλ1BA(ν)\inf_{\nu}\lambda^{-1}\mathcal{L}_{BA}(\nu) have the same values and minimizers (given by (9) in Sec. 2.2), it suffices to show the claim for =EOT\mathcal{L}=\mathcal{L}_{EOT}. Let ν\nu^{*} be an optimizer of (8) (i.e., an optimal reproduction distribution) and νn\nu^{n} its empirical measure from nn samples, then clearly

|minνn𝒫n(d)EOT(μ,νn)minν𝒫(d)EOT(μ,ν)|\displaystyle\left|\min_{\nu_{n}\in\mathcal{P}_{n}(\mathbb{R}^{d})}\mathcal{L}_{EOT}(\mu,\nu_{n})-\min_{\nu\in\mathcal{P}(\mathbb{R}^{d})}\mathcal{L}_{EOT}(\mu,\nu)\right| =minνn𝒫n(d)EOT(μ,νn)minν𝒫(d)EOT(μ,ν)\displaystyle=\!\!\min_{\nu_{n}\in\mathcal{P}_{n}(\mathbb{R}^{d})}\mathcal{L}_{EOT}(\mu,\nu_{n})-\!\!\min_{\nu\in\mathcal{P}(\mathbb{R}^{d})}\mathcal{L}_{EOT}(\mu,\nu)
𝔼[|EOT(μ,νn)EOT(μ,ν)|]\displaystyle\leq\mathbb{E}\left[|\mathcal{L}_{EOT}(\mu,\nu^{n})-\mathcal{L}_{EOT}(\mu,\nu^{*})|\right]

where the expectation is taken over samples for νn\nu^{n}. The first inequality of Proposition 4.3 now follows from the sample complexity result for entropic optimal transport in [Mena and Niles-Weed, 2019, Theorem 2].

Denote by νm\nu_{m}^{*} the optimizer for the problem (8) with μ\mu replaced by μm\mu^{m}. Similarly to the above, we obtain

𝔼[|minν𝒫(d)EOT(μ,ν)minν𝒫(d)EOT(μm,ν)|]𝔼[maxν{ν,νm}|EOT(μ,ν)EOT(μm,ν)|],\mathbb{E}\left[\left|\min_{\nu\in\mathcal{P}(\mathbb{R}^{d})}\mathcal{L}_{EOT}(\mu,\nu)-\min_{\nu\in\mathcal{P}(\mathbb{R}^{d})}\mathcal{L}_{EOT}(\mu^{m},\nu)\right|\right]\\ \leq\mathbb{E}\left[\max_{\nu\in\{\nu^{*},\nu_{m}^{*}\}}\left|\mathcal{L}_{EOT}(\mu,\nu)-\mathcal{L}_{EOT}(\mu^{m},\nu)\right|\right],

where the expectation is taken over samples from μm\mu^{m}. In this situation, we cannot directly apply [Mena and Niles-Weed, 2019, Theorem 2]. However, the bound given by [Mena and Niles-Weed, 2019, Proposition 2] still applies, and the only dependence on ν{ν,νm}\nu\in\{\nu^{*},\nu_{m}^{*}\} is through their subgaussianity constants. By Lemma 9.2, these constants are bounded by the corresponding constants of μ\mu and μm\mu^{m}. Thus, the arguments in the proof of [Mena and Niles-Weed, 2019, Theorem 2] can be applied, yielding the second inequality of Proposition 4.3.

The final inequality of Proposition 4.3 follows from the first two inequalities (the first one being applied with μm\mu^{m}) and the triangle inequality, where we again use the arguments in the proof of [Mena and Niles-Weed, 2019, Theorem 2] to bound the expectation over the subgaussianity constants of μm\mu^{m}. ∎

9.4 Convergence of the proposed hybrid algorithm

In the present work, our hybrid algorithm targets the non-stochastic setting and is motivated by the fact that the BA update is in some sense orthogonal to the Wasserstein gradient update, and can only monotonically improve the objective. While empirically we observe that the additional BA steps do not hurt – but rather accelerate – the convergence of WGD (see Sec. 5.1), additional effort is required to theoretically guarantee convergence of the hybrid algorithm.

There are two key properties we use for the convergence of the base WGD algorithm: 1) a certain monotonicity of the update steps (up to higher-order terms, gradient descent improves the objective) and 2) stability of gradients across iterations. If we include the BA step, we find that 1) still holds, but 2) may a priori be lost. Indeed, 1) holds since BA updates monotonically improve the objective. Using just 1), we can still obtain summability of the weighted gradient norms for the hybrid algorithm, $\sum_{t=0}^{\infty}\gamma_{t}\int\|\nabla V_{\mathcal{L}}(\nu^{(t)})\|^{2}\,d\nu^{(t)}<\infty$ (here $\nu^{(t)}$ are the outputs from the respective BA steps and $\gamma_{t}$ is the step size of the gradient steps). Without property 2), we cannot conclude $\int\|\nabla V_{\mathcal{L}}(\nu^{(t)})\|^{2}\,d\nu^{(t)}\rightarrow 0$ for $t\rightarrow\infty$. We emphasize that in practice, 2) still appears to hold even after including the BA step. Motivated by this analysis, an adjusted hybrid algorithm where, e.g., the BA update is rejected if it causes a drastic change in the Wasserstein gradient, could guarantee that 2) holds with little change in practice. From a different perspective, we also believe the hybrid algorithm may be tractable to study in relation to gradient flows in the Wasserstein–Fisher–Rao geometry (cf. [Yan et al., 2023]), in which the BA step corresponds to a gradient update in the Fisher–Rao geometry with a unit step size.

In the stochastic setting, the BA (and therefore our hybrid) algorithm does not directly apply, as performing BA updates on mini-batches no longer guarantees monotonic improvement of the overall objective. Extending the BA and hybrid algorithm to this setting would be interesting future work.

Table 1: Guide to notation and their interpretations in various problem domains. “LVM” stands for latent variable modeling, “NPMLE” stands for non-parametric maximum-likelihood estimation. The R-D problem (3) is equivalent to a projection problem under an entropic optimal transport cost (discussed in Sec. 2.2) and statistical problems involving maximum-likelihood estimation (see discussion in Sec. 2.3 and below).
Context | μ = P_X | ρ(x, y) | K = Q_{Y|X} | ν = Q_Y
OT | source distribution | transport cost | “transport plan” | target distribution
R-D | data distribution | distortion criterion | compression algorithm | codebook distribution
LVM/NPMLE | data distribution | −log p(x|y) | variational posterior | prior distribution
deconvolution | noisy measurements | “noise kernel” | — | noiseless model

10 R-D estimation and variational inference/learning

In this section, we elaborate on the connection between the R-D problem (3) and variational inference and learning in latent variable models, of which maximum likelihood deconvolution (discussed in Sec. 2.3) can be seen as a special case. Also see Section 3 of [Yang et al., 2023] for a related discussion in the context of lossy data compression.

To make the connections clearer to a general machine learning audience, we adopt notation more common in statistics and information theory. Table 1 summarizes the notation and the correspondence to the measure-theoretic notation used in the main text.

In statistics, a common goal is to model an (unknown) data generating distribution PXP_{X} by some density model p^(x)\hat{p}(x). In particular, here we will choose p^(x)\hat{p}(x) to be a latent variable model, where 𝒴{\mathcal{Y}} takes on the role of a latent space, and QY=νQ_{Y}=\nu is the distribution of a latent variable YY (which may encapsulate the model parameters). As we shall see, the optimization objective defining the rate functional (5) corresponds to an aggregate Evidence LOwer Bound (ELBO) [Blei et al., 2017]. Thus, computing the rate functional corresponds to variational inference [Blei et al., 2017] in a given model (see Sec. 10.2), and the parametric R-D estimation problem, i.e.,

infνBA(ν),\displaystyle\inf_{\nu\in{\mathcal{H}}}{\mathcal{L}}_{BA}(\nu),

is equivalent to estimating a latent variable model using the variational EM algorithm [Beal and Ghahramani, 2003] (see Sec. 10.3). The variational EM algorithm can be seen as a restricted version of the BA algorithm (see Sec. 10.3), whereas the EM algorithm [Dempster et al., 1977] shares its E-step with the BA algorithm but can differ in its M-step (see Sec. 10.4).

10.1 Setup

For concreteness, fix a reference measure ζ\zeta on 𝒴{\mathcal{Y}}, and suppose QYQ_{Y} has density q(y)q(y) w.r.t. ζ\zeta. Often the latent space 𝒴{\mathcal{Y}} is a Euclidean space, and q(y)q(y) is the usual probability density function w.r.t. the Lebesgue measure ζ\zeta; or when the latent space is discrete/countable, ζ\zeta is the counting measure and q(y)q(y) is a probability mass function. We will consider the typical parametric estimation problem and choose a particular parametric form for QYQ_{Y} indexed by a parameter vector θ\theta. This defines a parametric family ={QYθ:θΘ}{\mathcal{H}}=\{Q_{Y}^{\theta}:\theta\in\Theta\} for some parameter space Θ\Theta. Finally, suppose the distortion function ρ\rho induces a conditional likelihood density via p(x|y)eλρ(x,y)p(x|y)\propto e^{-\lambda\rho(x,y)}, with a normalization constant that is constant in yy.

These choices then result in a latent variable model specified by the joint density $q(y)p(x|y)$, and we model the data distribution with the marginal density: (Footnote: to be more precise, we always fix a reference measure $\eta$ on ${\mathcal{X}}$, so that densities such as $\hat{p}(x)$ and $p(x|y)$ are with respect to $\eta$. In the applications we considered in this work, $\eta$ is the Lebesgue measure on ${\mathcal{X}}=\mathbb{R}^{d}$.)

p^(x)=𝒴p(x|y)𝑑QY(y)=𝒴p(x|y)q(y)ζ(dy).\displaystyle\hat{p}(x)=\int_{{\mathcal{Y}}}p(x|y)dQ_{Y}(y)=\int_{{\mathcal{Y}}}p(x|y)q(y)\zeta(dy). (22)

As an example, a Gaussian mixture model with isotropic component variances can be specified as follows. Let $Q_{Y}$ be a mixing distribution on ${\mathcal{X}}={\mathcal{Y}}=\mathbb{R}^{d}$ parameterized by component weights $w_{1,\dots,K}$ and locations $\mu_{1,\dots,K}$, such that $Q_{Y}=\sum_{k=1}^{K}w_{k}\delta_{\mu_{k}}$. Let $p(x|y)=\mathcal{N}(y,\sigma^{2})$ be a conditional Gaussian density with mean $y$ and variance $\sigma^{2}$. Now formula (22) gives the usual Gaussian mixture density on $\mathbb{R}^{d}$.
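As a concrete illustration of (22) in this example, the following minimal sketch (names ours) evaluates the resulting Gaussian mixture log-density $\log\hat{p}(x)$:

# Minimal sketch: the marginal density (22) for the isotropic Gaussian mixture example,
# with Q_Y = sum_k w_k delta_{mu_k} and p(x|y) = N(x; y, sigma2 * I).
import numpy as np
from scipy.special import logsumexp

def log_marginal_density(x, weights, locs, sigma2):
    """x: [d]; weights: [K]; locs: [K, d]; sigma2: scalar component variance."""
    d = x.shape[0]
    sq_dist = ((x[None, :] - locs) ** 2).sum(-1)                               # [K]
    log_p_x_given_y = -0.5 * sq_dist / sigma2 - 0.5 * d * np.log(2 * np.pi * sigma2)
    return logsumexp(np.log(weights) + log_p_x_given_y)                        # log p_hat(x)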

Maximum-likelihood estimation then ideally maximizes the population log (marginal) likelihood,

𝔼xPX[logp^(x)]=logp^(x)PX(dx)=log(𝒴p(x|y)𝑑QY(y))PX(dx).\displaystyle{\mathbb{E}}_{x\sim P_{X}}[\log\hat{p}(x)]=\int\log\hat{p}(x)P_{X}(dx)=\int\log\left(\int_{{\mathcal{Y}}}p(x|y)dQ_{Y}(y)\right)P_{X}(dx). (23)

The maximum-likelihood deconvolution setup can be seen as a special case where the form of the marginal density (22) derives from knowledge of the true data generation process, with PX=α𝒩(0,1λ)P_{X}=\alpha*\mathcal{N}(0,\frac{1}{\lambda}) for some unknown α\alpha and known noise 𝒩(0,1λ)\mathcal{N}(0,\frac{1}{\lambda}) (i.e., the model is well-specified). We note in passing that the idea of estimating an unknown data distribution by adding artificial noise to it also underlies work on spread divergence [Zhang et al., 2020] and denoising score matching [Vincent, 2011].

To deal with the often intractable marginal likelihood in the inner integral, we turn to variational inference and learning [Jordan et al., 1999, Wainwright et al., 2008].

10.2 Connection to variational inference

Given a latent variable model and any data observation xx, a central task in Bayesian statistics is to infer the Bayesian posterior [Jordan, 1999], which we formally view as a conditional distribution QY|X=xQ^{*}_{Y|X=x}. It is given by

dQY|X=x(y)dQY(y)=p(x|y)p^(x),\displaystyle\frac{dQ^{*}_{Y|X=x}(y)}{dQ_{Y}(y)}=\frac{p(x|y)}{\hat{p}(x)},

or, in terms of the density q(y)q(y) of QYQ_{Y}, given by the following conditional density via the familiar Bayes’ rule,

q(y|x)=p(x|y)q(y)p^(x)=p(x|y)q(y)𝒴p(x|y)q(y)ζ(dy).\displaystyle q^{*}(y|x)=\frac{p(x|y)q(y)}{\hat{p}(x)}=\frac{p(x|y)q(y)}{\int_{{\mathcal{Y}}}p(x|y)q(y)\zeta(dy)}.

Unfortunately, the true Bayesian posterior is typically intractable, as the (marginal) data likelihood in the denominator involves an often high-dimensional integral. Variational inference [Jordan et al., 1999, Wainwright et al., 2008] therefore aims to approximate the true posterior with a variational distribution $Q_{Y|X=x}\in{\mathcal{P}}({\mathcal{Y}})$, chosen to minimize the relative entropy $H(Q_{Y|X=x}|Q^{*}_{Y|X=x})$. This problem is equivalent to maximizing the following lower bound on the marginal log-likelihood, known as the Evidence Lower BOund (ELBO) [Blei et al., 2017]:

argminQY|X=xH(QY|X=x|QY|X=x)\displaystyle\operatorname*{arg\,min}_{Q_{Y|X=x}}H(Q_{Y|X=x}|Q^{*}_{Y|X=x}) =argmaxQY|X=xELBO(QY,x,QY|X=x),\displaystyle=\operatorname*{arg\,max}_{Q_{Y|X=x}}\operatorname{ELBO}(Q_{Y},x,Q_{Y|X=x}),
ELBO(QY,x,QY|X=x)\displaystyle\operatorname{ELBO}(Q_{Y},x,Q_{Y|X=x}) =𝔼yQY|X=x[logp(x|y)]H(QY|X=x|QY)\displaystyle={\mathbb{E}}_{y\sim Q_{Y|X=x}}[\log p(x|y)]-H(Q_{Y|X=x}|Q_{Y})
=logp^(x)H(QY|X=x|QY|X=x).\displaystyle=\log\hat{p}(x)-H(Q_{Y|X=x}|Q^{*}_{Y|X=x}). (24)

Recall the rate functional (5) arises through an optimization problem over a transition kernel KK,

BA(ν)=infKλρd(μK)+H(μK|μν).{\mathcal{L}}_{BA}(\nu)=\inf_{K}\lambda\int\rho d(\mu\otimes K)+H(\mu\otimes K|\mu\otimes\nu).

Translating the above into the present notation (μPX,KQY|X,νQY\mu\to P_{X},K\to Q_{Y|X},\nu\to Q_{Y}; see Table 1), we obtain

BA(QY)\displaystyle{\mathcal{L}}_{BA}(Q_{Y}) =infQY|X𝔼xPX,yQY|X=x[logp(x|y)]+𝔼xPX[H(QY|X=x|QY)]+const\displaystyle=\inf_{Q_{Y|X}}{\mathbb{E}}_{x\sim P_{X},y\sim Q_{Y|X=x}}[-\log p(x|y)]+{\mathbb{E}}_{x\sim P_{X}}[H(Q_{Y|X=x}|Q_{Y})]+\text{const}
=infQY|X𝔼xPX[ELBO(QY,x,QY|X=x)]+const.\displaystyle=\inf_{Q_{Y|X}}{\mathbb{E}}_{x\sim P_{X}}[-\operatorname{ELBO}(Q_{Y},x,Q_{Y|X=x})]+\text{const}. (25)

This allows us to interpret the rate functional as the result of performing variational inference through ELBO optimization. At optimality, $Q_{Y|X}=Q^{*}_{Y|X}$, the ELBO (24) is tight and recovers $\log\hat{p}(x)$, and the rate functional takes the form of a (negated) population marginal log-likelihood (23), as given earlier by (6) in Sec. 2.1.

10.3 Connection to variational EM

The discussion so far concerns probabilistic inference, where a latent variable model (QY,p(x|y))(Q_{Y},p(x|y)) has been given and we saw that computing the rate functional amounts to variational inference. Suppose now we wish to learn a model from data. The R-D problem (4) then corresponds to model estimation using the variational EM algorithm [Beal and Ghahramani, 2003].

To estimate a latent variable model by (approximate) maximum-likelihood, the variational EM algorithm maximizes the population ELBO

𝔼xPX[ELBO(QY,x,QY|X=x)]=𝔼xPX,yQY|X=x[logp(x|y)]𝔼xPX[H(QY|X=x|QY)],\displaystyle{\mathbb{E}}_{x\sim P_{X}}[\operatorname{ELBO}(Q_{Y},x,Q_{Y|X=x})]={\mathbb{E}}_{x\sim P_{X},y\sim Q_{Y|X=x}}[\log p(x|y)]-{\mathbb{E}}_{x\sim P_{X}}[H(Q_{Y|X=x}|Q_{Y})], (26)

w.r.t. QYQ_{Y} and QY|XQ_{Y|X}. This precisely corresponds to the R-D problem infQYBA(QY)\inf_{Q_{Y}\in{\mathcal{H}}}{\mathcal{L}}_{BA}(Q_{Y}), using the form of BA(QY){\mathcal{L}}_{BA}(Q_{Y}) from (25).

In popular implementations of variational EM such as the VAE [Kingma and Welling, 2013], QYQ_{Y} and QY|XQ_{Y|X} are restricted to parametric families. When they are allowed to range over all of 𝒫(𝒴){\mathcal{P}}({\mathcal{Y}}) and all conditional distributions, variational EM then becomes equivalent to the BA algorithm.

10.4 The Blahut–Arimoto and (exact) EM algorithms

Here we focus on the distinction between the BA algorithm and the (exact) EM algorithm [Dempster et al., 1977], rather than the variational EM algorithm. Both BA and (exact) EM perform coordinate descent on the same objective function (namely the negative of the population ELBO from (26)), but they define the coordinates slightly differently: the BA algorithm uses $(Q_{Y|X},Q_{Y})$ with $Q_{Y}\in{\mathcal{P}}({\mathcal{Y}})$, whereas the EM algorithm uses $(Q_{Y|X},\theta)$ with $\theta$ indexing a parametric family ${\mathcal{H}}=\{Q_{Y}^{\theta}:\theta\in\Theta\}$. Thus the coordinate update w.r.t. $Q_{Y|X}$ (the “E-step”) is the same in both algorithms, but the subsequent “M-step” potentially differs depending on the role of $\theta$.

Given the optimization objective, which is simply the negative of (26),

𝔼xPX,yQY|X=x[logp(x|y)]+H(PXQY|X|PXQY),\displaystyle{\mathbb{E}}_{x\sim P_{X},y\sim Q_{Y|X=x}}[-\log p(x|y)]+H(P_{X}Q_{Y|X}|P_{X}\otimes Q_{Y}), (27)

both BA and EM optimize the transition kernel QY|XQ_{Y|X} the same way in the E-step, as

dQY|X=xdQY(y)=p(x|y)p^(x).\displaystyle\frac{dQ^{*}_{Y|X=x}}{dQ_{Y}}(y)=\frac{p(x|y)}{\hat{p}(x)}. (28)

For the M-step, the BA algorithm only minimizes the relative entropy term of the objective (27),

minQY𝒫(𝒴)H(PXQY|X;PXQY),\displaystyle\min_{Q_{Y}\in{\mathcal{P}}({\mathcal{Y}})}H(P_{X}Q^{*}_{Y|X};P_{X}\otimes Q_{Y}),

(with the optimal QYQ_{Y} given by the second marginal of PXQY|XP_{X}Q^{*}_{Y|X}) whereas the EM algorithm minimizes the full objective w.r.t. the parameters θ\theta of QYQ_{Y},

minθΘ𝔼(x,y)PXQY|X[logp(x|y)]+H(PXQY|X;PXQY).\displaystyle\min_{\theta\in\Theta}{\mathbb{E}}_{(x,y)\sim P_{X}Q^{*}_{Y|X}}[-\log p(x|y)]+H(P_{X}Q^{*}_{Y|X};P_{X}\otimes Q_{Y}). (29)

The difference comes from the fact that when we parameterize QYQ_{Y} by θ\theta in the parameter estimation problem, QY|XQ^{*}_{Y|X} — and consequently both terms in the objective of (29) — will have functional dependence on θ\theta through the E-step optimality condition (28).

In the Gaussian mixture example, $Q_{Y}=\sum_{k=1}^{K}w_{k}\delta_{\mu_{k}}$, and its parameters $\theta$ consist of the component weights $(w_{1},...,w_{K})\in\Delta^{K-1}$ and location vectors $\{\mu_{1},...,\mu_{K}\}\subset\mathbb{R}^{d}$. The E-step computes $Q^{*}_{Y|X=x}=\sum_{k}w_{k}\frac{p(x|\mu_{k})}{\hat{p}(x)}\delta_{\mu_{k}}$. For the M-step, if we regard the locations as known so that $\theta=(w_{1},...,w_{K})$ only consists of the weights, then the two algorithms perform the same update; however, if $\theta$ also includes the locations, then the M-step of the EM algorithm will not only update the weights as in the BA algorithm, but also the locations, due to the distortion term ${\mathbb{E}}_{(x,y)\sim P_{X}Q^{*}_{Y|X}}[-\log p(x|y)]=-\int\sum_{k}w_{k}\frac{p(x|\mu_{k})}{\hat{p}(x)}\log p(x|\mu_{k})P_{X}(dx)$.
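To make this concrete, the following minimal sketch (names ours, assuming the isotropic Gaussian likelihood above with a known variance sigma2) implements the shared E-step together with the two different M-steps: the BA update only re-weights the components, while the EM update also moves their locations.

# Minimal sketch contrasting the BA and EM updates for the Gaussian mixture example
# (fixed, known component variance sigma2). xs: data [N, d]; weights: [K]; locs: [K, d].
import numpy as np

def e_step(xs, weights, locs, sigma2):
    # Responsibilities r[i, k] proportional to w_k * p(x_i | mu_k); this is the update (28).
    log_lik = -0.5 * ((xs[:, None, :] - locs[None, :, :]) ** 2).sum(-1) / sigma2
    log_r = np.log(weights)[None, :] + log_lik
    log_r -= log_r.max(axis=1, keepdims=True)            # stabilize before normalizing
    r = np.exp(log_r)
    return r / r.sum(axis=1, keepdims=True)               # [N, K]

def ba_m_step(r):
    # BA: only the mixture weights are updated (second marginal of P_X Q*_{Y|X}).
    return r.mean(axis=0)

def em_m_step(r, xs):
    # EM: weights AND locations are updated, due to the distortion term in (29).
    return r.mean(axis=0), (r.T @ xs) / r.sum(axis=0)[:, None]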

11 Further experimental results

Our deconvolution experiments were run on Intel(R) Xeon(R) CPUs, and the rest of the experiments were run on Titan RTX GPUs.

In most experiments, we use the Adam [Kingma and Ba, 2015] optimizer for updating the $\nu$ particle locations in WGD and for updating the variational parameters in other methods. For our hybrid WGD algorithm, which adjusts the particle weights in addition to their locations, we found that applying momentum to the particle locations can in fact slow down convergence, and we therefore use plain gradient descent with a decaying step size.

To trace an R-D upper bound for a given source, we solve the unconstrained R-D problem (3) for a heuristically chosen grid of $\lambda$ values similar to those in [Yang and Mandt, 2022], e.g., a log-linear grid $\{1e4,3e3,1e3,3e2,1e2,...\}$. Alternatively, a constrained optimization method like the Modified Differential Multiplier Method (MDMM) [Platt and Barr, 1987] can be adopted to directly target $R(D)$ for various values of $D$. The latter approach also helps us identify when we run into the $\log(n)$ rate limit of particle-based methods (Sec. 4.4): suppose we perform constrained optimization with a distortion threshold of $D$; if the chosen $n$ is not large enough, i.e., $\log(n)<R(D)$, then no $\nu$ supported on (at most) $n$ points is feasible for the given constraint (otherwise we would have a contradiction). When this happens, the Lagrange multiplier associated with the infeasible constraint ($\lambda$ in our case) will be observed to increase indefinitely rather than stabilize to a finite value under a method like MDMM.

11.1 Deconvolution

Architectures for the neural baselines

For the RD-VAE, we used a similar architecture to the one used on the banana-shaped source in [Yang and Mandt, 2022], consisting of two-layer MLPs for the encoder and decoder networks, and a Masked Autoregressive Flow [Papamakarios et al., 2017] for the variational prior. For NERD, we follow similar architecture settings as [Lei et al., 2023a], using a two-layer MLP for the decoder network. Following Yang and Mandt [2022], we initially used a softplus activation for the MLPs, but found that it made the optimization difficult; we therefore switched to the ReLU activation, which gave much better results. We conjecture that the softplus activation led to overly smooth mappings, which had difficulty matching the optimal measure $\nu^{*}=\alpha$, which is concentrated on the unit circle and does not have a Lebesgue density.

Experiment setup

As discussed in Sec. 4.2, the BA and our hybrid algorithms do not apply to the stochastic setting; to be able to include them in our comparison, the input to all the algorithms is an empirical measure $\mu^{m}$ (training distribution) with $m=100000$ samples, given the same fixed seed. We found the number of training samples to be sufficiently large that the losses do not differ significantly between the training distribution and random held-out samples.

Recall from Sec. 2.3 that, given data observations following $\mu=\alpha*\mathcal{N}(0,\sigma^{2})$, if we perform deconvolution with an assumed noise variance $\frac{1}{\lambda}$ for some $\frac{1}{\lambda}\leq\sigma^{2}$, then the optimal solution to the problem is given by $\nu^{*}=\alpha_{\lambda}=\alpha*\mathcal{N}(0,\frac{1}{\lambda})$. We compute the optimal loss $OPT={\mathcal{L}}_{BA}(\nu^{*})$ by a Monte Carlo estimate, using formula (19) with $m=10^{4}$ samples from $\mu$ and $n=10^{6}$ samples from $\nu^{*}$. The resulting $\hat{\mathcal{L}}_{BA}$ is positively biased (it overestimates $OPT$ in expectation) due to the bias of the plug-in estimator for $\varphi(x)$ (18), so we use a large $n$ to reduce this bias. (Footnote: another, more sophisticated solution would be annealed importance sampling [Grosse et al., 2015] or a related method developed for estimating marginal likelihoods and partition functions; note the same issue occurs in the NERD algorithm (13).)

Figure 6: Visualizing the converged n=20n=20 particles of WGD (top row) and hybrid WGD (bottom row) estimated on m=10000m=10000 source samples in the deconvolution problem (Sec. 5.1), for decreasing distortion penalty λ\lambda. The case λ=10.0=σ2\lambda=10.0=\sigma^{-2} corresponds to “complete noise removal” of the source and recovers the ground truth α\alpha (unit circle).

Optimized reproduction distributions

To visualize the (continuous) ν\nu learned by RD-VAE and NERD, we fit a kernel density estimator on 10000 ν\nu samples drawn from each, using seaborn.kdeplot.

In Figure 6, we provide additional visualization for the particles estimated from the training samples by WGD and its hybrid variant, for a decreasing distortion penalty λ\lambda (equivalently, increasing entropic regularization strength ϵ\epsilon).

11.2 Higher-dimensional datasets

For NERD, we set nn to the default 4000040000 in the code provided by [Lei et al., 2023a], on all three datasets.

For WGD, we used n=4000n=4000 on physics, 200000 on speech, and 40000 on MNIST (see also smaller nn for comparison in Fig. 5).

On speech, both NERD and WGD encountered the issue of a $\log(n)$ upper bound on the rate estimate, as described in Sec. 4.4. Therefore, we increased $n$ to 200000 for WGD and obtained a tighter R-D upper bound than NERD, particularly in the low-distortion regime. Such a large $n$ incurred an out-of-memory error for NERD.

For physics and speech, we borrow the R-D upper and lower bound results reported in [Yang and Mandt, 2022]. We obtain sandwich bounds on MNIST using a similar neural network architecture and hyperparameter choices as in the GAN-generated image experiments of [Yang and Mandt, 2022] (using a ResNet-VAE for the upper bound and a CNN for the lower bound), except that we set a larger $k=10000$ in the lower bound experiment; we suspect the resulting lower bound is still far from tight.

12 Example implementation of WGD

We provide a self-contained minimal implementation of Wasserstein gradient descent on ${\mathcal{L}}_{BA}$, using the Jax library [Bradbury et al., 2018]. To compute the Wasserstein gradient, we evaluate the first variation of the rate functional in compute_psi_sum according to (20), yielding $\sum_{j=1}^{n}\psi^{\nu}(y_{j})$ over the $\nu$ atoms $y_{1},\dots,y_{n}$ (nu_x in the code), and then simply take its gradient w.r.t. these particle locations using Jax's autodiff tool (the jax.grad call in wgrad).

The implementation of WGD on EOT{\mathcal{L}}_{EOT} is similar, except the first variation is computed using Sinkhorn’s algorithm. Both versions can be easily just-in-time compiled and enjoy GPU acceleration.

# Wasserstein GD on the rate functional L_{BA}.
import jax.numpy as jnp
import jax
from jax.scipy.special import logsumexp
# Define the distortion function \rho.
squared_diff = lambda x, y: jnp.sum((x - y) ** 2)
pairwise_distortion_fn = jax.vmap(jax.vmap(squared_diff, (None, 0)), (0, None))
def wgrad(mu_x, mu_w, nu_x, nu_w, rd_lambda):
"""
Compute the Wasserstein gradient of the rate functional, which we will use
to move the \nu particles.
:param mu_x: locations of \mu atoms.
:param mu_w: weights of \mu atoms.
:param nu_x: locations of \nu atoms.
:param nu_w: weights of \nu atoms.
:param rd_lambda: R-D tradeoff hyperparameter.
:return:
"""
def compute_psi_sum(nu_x):
"""
Here we compute a surrogate loss based on the first variation \psi, which
allows jax autodiff to compute the desired Wasserstein gradient.
:param nu_x:
:return: psi_sum = \sum_i \psi(nu_x[i])
"""
C = pairwise_distortion_fn(mu_x, nu_x)
scaled_C = rd_lambda * C # [m, n]
log_nu_w = jnp.log(nu_w) # [1, n]
# Solve BA inner problem with a fixed nu.
phi = - logsumexp(-scaled_C + log_nu_w, axis=1, keepdims=True) # [m, 1]
loss = jnp.sum(mu_w * phi) # Evaluate the rate functional. Eq (6) in paper.
# Let's also report rate and distortion estimates (discussed in Sec. 4.4 of the paper).
# Find \pi^* via \phi
pi = jnp.exp(phi - scaled_C) * jnp.outer(mu_w, nu_w) # [m, n]
distortion = jnp.sum(pi * C)
rate = loss - rd_lambda * distortion
# Now evaluate \psi on the atoms of \nu.
phi = jax.lax.stop_gradient(phi)
psi = - jnp.sum(jnp.exp(jax.lax.stop_gradient(phi) - scaled_C) * mu_w, axis=0)
psi_sum = jnp.sum(psi) # For computing gradient w.r.t. each nu_x atom.
return psi_sum, (loss, rate, distortion)
# Evaluate the Wasserstein gradient, i.e., \nabla \psi, on nu_x.
psi_prime, loss = jax.grad(compute_psi_sum, has_aux=True)(nu_x)
return psi_prime, loss
def wgd(X, n, rd_lambda, num_steps, lr, rng):
"""
A basic demo of Wasserstein gradient descent on a discrete distribution.
:param X: a 2D array [N, d] of data points defining the source \mu.
:param n: the number of particles to use for \nu.
:param rd_lambda: R-D tradeoff hyperparameter.
:param num_steps: total number of gradient updates.
:param lr: step size.
:param rng: jax random key.
:return: (nu_x, nu_w), the locations and weights of the final \nu.
"""
# Set up the source measure \mu.
m = jnp.size(X, 0)
mu_x = X
mu_w = 1 / m * jnp.ones((m, 1))
# Initialize \nu atoms using random training samples.
rand_idx = jax.random.permutation(rng, m)[:n]
nu_x = X[rand_idx] # Locations of \nu atoms.
nu_w = 1 / n * jnp.ones((1, n)) # Uniform weights.
for step in range(num_steps):
psi_prime, (loss, rate, distortion) = wgrad(mu_x, mu_w, nu_x, nu_w, rd_lambda)
nu_x -= lr * psi_prime
print(f'step={step}, loss={loss:.4g}, rate={rate:.4g}, distortion={distortion:.4g}')
return nu_x, nu_w
if __name__ == '__main__':
# Run a toy example on 2D Gaussian samples.
rng = jax.random.PRNGKey(0)
X = jax.random.normal(rng, [10, 2])
nu_x, nu_w = wgd(X, n=4, rd_lambda=2., num_steps=100, lr=0.1, rng=rng)