
Faster Diffusion Sampling with Randomized Midpoints: Sequential and Parallel

Shivam Gupta
UT Austin
[email protected]
   Linda Cai
Princeton University
[email protected]
   Sitan Chen
Harvard SEAS
[email protected]
Abstract

Sampling algorithms play an important role in controlling the quality and runtime of diffusion model inference. In recent years, a number of works [CCL+23b, CCL+23a, BDD23, LLT22] have proposed schemes for diffusion sampling with provable guarantees; these works show that for essentially any data distribution, one can approximately sample in polynomial time given a sufficiently accurate estimate of its score functions at different noise levels. In this work, we propose a new scheme inspired by Shen and Lee's randomized midpoint method for log-concave sampling [SL19]. We prove that this approach achieves the best known dimension dependence for sampling from arbitrary smooth distributions in total variation distance ($\widetilde{O}(d^{5/12})$ compared to $\widetilde{O}(\sqrt{d})$ from prior work). We also show that our algorithm can be parallelized to run in only $\widetilde{O}(\log^{2}d)$ parallel rounds, constituting the first provable guarantees for parallel sampling with diffusion models.

As a byproduct of our methods, for the well-studied problem of log-concave sampling in total variation distance, we give an algorithm and simple analysis achieving dimension dependence $\widetilde{O}(d^{5/12})$ compared to $\widetilde{O}(\sqrt{d})$ from prior work.

1 Introduction

Diffusion models [SDWMG15, SE19, HJA20, DN21, SDME21, SSDK+21, VKK21] have emerged as the de facto approach to generative modeling across a range of data modalities like images [BGJ+23, EKB+24], audio [KPH+20], video [BPH+24], and molecules [WYvdB+24]. In recent years a slew of theoretical works have established surprisingly general convergence guarantees for this method [CCL+23b, LLT23, CLL23, CCL+23a, CDD23, BDBDD24, GPPX23, LWCC23, LHE+24]. They show that for essentially any data distribution, assuming one has a sufficiently accurate estimate for its score function, one can approximately sample from it in polynomial time.

While these results offer some theoretical justification for the empirical successes of diffusion models, the upper bounds they furnish for the number of iterations needed to generate a single sample are quite loose relative to what is done in practice. The best known provable bounds scale as $O(\sqrt{d}/\varepsilon)$, where $d$ is the dimension of the space in which the diffusion is taking place (e.g. $d=16384$ for Stable Diffusion) [CCL+23a], and $\varepsilon$ is the target error. Even ignoring the dependence on $\varepsilon$ and the hidden constant factor, this is at least $2$-$3\times$ larger than the default value of $50$ inference steps in Stable Diffusion.

In this work we consider a new approach for driving down the amount of compute that is provably needed to sample with diffusion models. Our approach is rooted in the randomized midpoint method, originally introduced by Shen and Lee [SL19] in the context of Langevin Monte Carlo for log-concave sampling. At a high level, this is a method for numerically solving differential equations where, within every discrete window of time, one forms an unbiased estimate for the drift by evaluating it at a random "midpoint" (see Section 2.2 for a formal treatment). For sampling from log-concave densities, the number of iterations needed by their method scales with $d^{1/3}$, and this remains the best known bound in the "low-accuracy" regime.

While this method is well-studied in the log-concave setting [HBE20, YKD23, YD24, SL19], its applicability to diffusion models has been unexplored both theoretically and empirically. Our first result uses the randomized midpoint method to obtain an improvement over the prior best known bound of $O(\sqrt{d}/\varepsilon)$ for sampling arbitrary smooth distributions with diffusion models:

Theorem 1.1 (Informal, see Theorem A.10).

Suppose that the data distribution $q$ has bounded second moment, its score functions $\nabla\ln q_{t}$ along the forward process are $L$-Lipschitz, and we are given score estimates which are $L$-Lipschitz and $\widetilde{O}(\frac{\varepsilon}{d^{1/12}\sqrt{L}})$-close to $\nabla\ln q_{t}$ for all $t$, where $\widetilde{O}(\cdot)$ hides polylogarithmic factors in $d$, $L$, $\varepsilon$, and $\mathbb{E}_{x\sim q}[\|x\|^{2}]$. Then there is a diffusion-based sampler using these score estimates (see Algorithm 1) which outputs a sample whose law is $\varepsilon$-close in total variation distance to $q$ using $\widetilde{O}(L^{5/3}d^{5/12}/\varepsilon)$ iterations.

Our algorithm is based on the ODE-based predictor-corrector algorithm introduced in [CCL+23a], but in place of the standard exponential integrator discretization in the predictor step, we employ randomized midpoint discretization. We note that in the domain of log-concave sampling, the result of Shen and Lee only achieves recovery in Wasserstein distance. Prior to our work, it was actually open whether one can achieve the same dimension dependence in total variation or KL divergence, for which the best known bound was $\widetilde{O}(\sqrt{d})$ [MCC+21, ZCL+23, AC23]. In contrast, our result circumvents this barrier by carefully trading off time spent in the corrector phase of the algorithm for time spent in the predictor phase. We defer the details of this, as well as other important technical hurdles, to Section 3.1.

Next, we turn to a different computational model: instead of quantifying the cost of an algorithm in terms of the total number of iterations, we consider the parallel setting where one has access to multiple processors and wishes to minimize the total number of parallel rounds needed to generate a single sample. This perspective has been explored in a recent empirical work [SBE+24], but to our knowledge, no provable guarantees were known for parallel sampling with diffusion models (see Section 1.1 for discussion of concurrent and independent work). Our second result provides the first such guarantee:

Theorem 1.2 (Informal, see Theorem B.13).

Under the same assumptions on $q$ as in Theorem 1.1, and assuming that we are given score estimates which are $\widetilde{O}(\frac{\varepsilon}{\sqrt{L}})$-close to $\nabla\ln q_{t}$ for all $t$, there is a diffusion-based sampler using these score estimates (see Algorithm 9) which outputs a sample whose law is $\varepsilon$-close in total variation distance to $q$ using $\widetilde{O}(L\cdot\mathrm{polylog}(Ld/\varepsilon))$ parallel rounds.

This result follows in the wake of several recent theoretical works on parallel sampling of log-concave densities using Langevin Monte Carlo [AHL+23, ACV24, SL19]. A common thread among these works is the observation that differential equations can be numerically solved via fixed point iteration (see Section 2.3 for details), and we adopt a similar perspective in the context of diffusions. To our knowledge this is the first provable guarantee for parallel sampling beyond the log-concave setting.

Finally, we show that, as a byproduct of our methods, we can actually obtain a similar dimension dependence of $\widetilde{O}(d^{5/12})$ as in Theorem 1.1 for log-concave sampling in TV, superseding the previously best known bound of $\widetilde{O}(\sqrt{d})$ mentioned above.

Theorem 1.3 (Informal, see Theorem C.2).

Suppose the distribution $q$ is $m$-strongly-log-concave and its score function $\nabla\ln q$ is $L$-Lipschitz. Then there is an underdamped-Langevin-based sampler that uses this score (Algorithm 11) and outputs a sample whose law is $\varepsilon$-close in total variation to $q$ using $\widetilde{O}\left(d^{5/12}\left(\frac{L^{4/3}}{\varepsilon^{2/3}m^{4/3}}+\frac{1}{\varepsilon}\right)\right)$ iterations.

1.1 Related work

Our discretization scheme is based on the randomized midpoint method of [SL19], which has been studied at length in the domain of log-concave sampling [HBE20, YKD23, YD24].

The proof of our parallel sampling result builds on the ideas of [SL19, AHL+23, ACV24] on parallelizing the collocation method. These prior results were focused on Langevin Monte Carlo, rather than diffusion-based sampling. We review these ideas in Section 2.3.

In [CCL+23a], the authors proposed the predictor-corrector framework that we also use for analyzing convergence guarantees for the probability flow ODE, and which achieved an iteration complexity scaling as $\widetilde{O}(\sqrt{d})$. In addition, there have been many works in recent years giving general convergence guarantees for diffusion models [DBTHD21, BMR22, DB22, LLT22, LWYL22, Pid22, WY22, CCL+23b, CDD23, LLT23, LWCC23, BDD23, CCL+23a, BDBDD24, CLL23, GPPX23]. Of these, one line of work [CCL+23b, LLT23, CLL23, BDBDD24] analyzed DDPM, the stochastic analogue of the probability flow ODE, and showed $\widetilde{O}(d)$ iteration complexity bounds. Another set of works [CCL+23a, CDD23, LWCC23, LHE+24] studied the probability flow ODE, for which our work provides a new discretization scheme that achieves a state-of-the-art $\widetilde{O}(d^{5/12})$ dimension dependence for sampling from a diffusion model.

Concurrent work.

Here we discuss the independent works of [CRYR24] and [KN24]. [CRYR24] gave an analysis for parallel sampling with diffusion models that also achieves a $\mathrm{polylog}(d)$ number of parallel rounds, like the present work. [KN24] showed an improved dimension dependence of $\widetilde{O}(d^{5/12})$ for log-concave sampling in total variation, similar to our analogous result, but via a different proof technique. In addition, they show a similar result when the distribution only satisfies a log-Sobolev inequality. They also present empirical results for diffusion models, showing that an algorithm inspired by the randomized midpoint method outperforms ODE-based methods with similar compute. While their work builds on the randomized midpoint method, they do not theoretically analyze the diffusion setting and do not study parallel sampling.

2 Preliminaries

2.1 Probability flow ODE

In this section we review basics about deterministic diffusion-based samplers; we refer the reader to [CCL+23a] for a more thorough exposition.

Let $q^{*}$ denote the data distribution over $\mathbb{R}^{d}$. We consider the standard Ornstein-Uhlenbeck (OU) forward process, i.e. the "VP SDE," given by

\mathrm{d}x^{\rightarrow}_{t}=-x^{\rightarrow}_{t}\,\mathrm{d}t+\sqrt{2}\,\mathrm{d}B_{t}\,,\qquad x_{0}^{\rightarrow}\sim q^{*}\,, (1)

where $(B_{t})_{t\geq 0}$ denotes a standard Brownian motion in $\mathbb{R}^{d}$. This process converges exponentially quickly to its stationary distribution, the Gaussian distribution $\mathcal{N}(0,\mathrm{Id})$.

Suppose the OU process is run until terminal time $T>0$, and for any $t\in[0,T]$, let $q^{*}_{t}\triangleq\mathrm{law}(x^{\rightarrow}_{t})$, i.e. the law of the forward process at time $t$. We will consider the reverse process given by the probability flow ODE

\mathrm{d}x_{t}=(x_{t}+\nabla\ln q_{T-t}(x_{t}))\,\mathrm{d}t\,. (2)

This is a time-reversal of the forward process, so that if $x_{0}\sim q^{*}_{T}$, then $\mathrm{law}(x_{t})=q^{*}_{T-t}$. In practice, one initializes at $x_{0}\sim\mathcal{N}(0,\mathrm{Id})$, and instead of using the exact score function $\nabla\ln q_{T-t}$, one uses estimates $\widehat{s}_{T-t}\approx\nabla\ln q_{T-t}$ which are learned from data. Additionally, the ODE is solved numerically using any of a number of discretization schemes. The theoretical literature on diffusion models has focused primarily on exponential integration, which we review next before turning to the discretization scheme used in the present work, the randomized midpoint method.
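As a toy illustration (not drawn from the paper), when the data distribution is $q^{*}=\mathcal{N}(0,\sigma^{2}I)$, the scores of the forward process (1) are available in closed form, so the probability flow ODE (2) can be integrated directly. The following is a minimal sketch using a plain Euler discretization; all parameter choices are placeholders.

```python
import numpy as np

# Toy example: q* = N(0, sigma^2 I). For the OU forward process (1), the marginal
# q_t is N(0, (sigma^2 e^{-2t} + 1 - e^{-2t}) I), so its score is known exactly.
sigma2, d, T, n_steps = 4.0, 2, 5.0, 500
var_t = lambda t: sigma2 * np.exp(-2.0 * t) + 1.0 - np.exp(-2.0 * t)   # Var of q_t
score = lambda t, x: -x / var_t(t)                                      # grad log q_t(x)

rng = np.random.default_rng(0)
x = rng.standard_normal(d)              # initialize at N(0, I), approximately q_T for large T
h = T / n_steps
for n in range(n_steps):
    t = n * h                           # reverse-process time; forward time is T - t
    x = x + h * (x + score(T - t, x))   # Euler step of the probability flow ODE (2)
# x is now an approximate sample from q* = N(0, sigma^2 I)
```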

2.2 Discretization schemes

Suppose we wish to discretize the following semilinear ODE:

\mathrm{d}x_{t}=(x_{t}+f_{t}(x_{t}))\,\mathrm{d}t\,. (3)

For our application we will eventually take $f_{t}\triangleq\widehat{s}_{T-t}$, but we use $f_{t}$ in this section to condense notation.

Suppose we want to discretize Equation 3 over a time window $[t_{0},t_{0}+h]$. The starting point is the integral formulation for this ODE:

x_{t_{0}+h}=e^{h}x_{t_{0}}+\int^{t_{0}+h}_{t_{0}}e^{t_{0}+h-t}f_{t}(x_{t})\,\mathrm{d}t\,. (4)

Under the standard exponential integrator discretization, one would approximate the integrand by $e^{t_{0}+h-t}f_{t_{0}}(x_{t_{0}})$ and obtain the approximation

x_{t_{0}+h}\approx e^{h}x_{t_{0}}+(e^{h}-1)f_{t_{0}}(x_{t_{0}})\,. (5)

The drawback of this discretization is that it uses an inherently biased estimate for the integral in Eq. (4). The key insight of [SL19] was to replace this with the following unbiased estimate

\int^{t_{0}+h}_{t_{0}}e^{t_{0}+h-t}f_{t}(x_{t})\,\mathrm{d}t\approx he^{(1-\alpha)h}f_{t_{0}+\alpha h}(x_{t_{0}+\alpha h})\,, (6)

where $\alpha$ is a uniformly random sample from $[0,1]$. While this alone does not suffice, as the estimate depends on $x_{t_{0}+\alpha h}$, naturally we could iterate the above procedure again to obtain an approximation to $x_{t_{0}+\alpha h}$. It turns out, though, that even if we simply approximate $x_{t_{0}+\alpha h}$ using exponential integrator discretization, we can obtain nontrivial improvements in discretization error (e.g. our Theorem 1.1). In this case, the above sequence of approximations takes the following form:

x_{t_{0}+\alpha h}\approx e^{\alpha h}x_{t_{0}}+(e^{\alpha h}-1)f_{t_{0}}(x_{t_{0}}) (7)
x_{t_{0}+h}\approx e^{h}x_{t_{0}}+he^{(1-\alpha)h}f_{t_{0}+\alpha h}(x_{t_{0}+\alpha h})\,. (8)

Note that a similar idea can be used to discretize stochastic differential equations, but in this work we only use it to discretize the probability flow ODE.
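To make the updates in Eqs. (7) and (8) concrete, here is a minimal sketch of a single randomized-midpoint step in Python; this is not taken from the paper's code, and the score estimate s_hat, the step size, and the toy usage at the bottom are placeholder assumptions.

```python
import numpy as np

def randomized_midpoint_step(x, t0, h, s_hat, rng):
    """One randomized-midpoint step for dx_t = (x_t + f_t(x_t)) dt with f_t = s_hat(t, .)."""
    alpha = rng.uniform(0.0, 1.0)
    # Eq. (7): exponential-integrator estimate of the trajectory at the random midpoint.
    x_mid = np.exp(alpha * h) * x + (np.exp(alpha * h) - 1.0) * s_hat(t0, x)
    # Eq. (8): randomized-midpoint estimate of the integral in Eq. (4).
    x_next = np.exp(h) * x + h * np.exp((1.0 - alpha) * h) * s_hat(t0 + alpha * h, x_mid)
    return x_next

# Usage sketch with a toy stand-in for the learned score estimate.
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    s_hat = lambda t, x: -x          # placeholder score function
    x = rng.standard_normal(4)
    x = randomized_midpoint_step(x, t0=0.0, h=0.01, s_hat=s_hat, rng=rng)
```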

Predictor-Corrector.

For important technical reasons, in our analysis we actually consider a slightly different algorithm than simply running the probability flow ODE with approximate score, Gaussian initialization, and randomized midpoint discretization. Specifically, we interleave the ODE with corrector steps that periodically inject noise into the sampling trajectory. We refer to the phases in which we are running the probability flow ODE as predictor steps.

The corrector step will be given by running underdamped Langevin dynamics. As our analysis of this step borrows, in a black-box manner, bounds proven in [CCL+23a], we refer to Section A.2 for the details.

2.3 Parallel sampling

The scheme outlined in the previous section is a simple special case of the collocation method. In the context of the semilinear ODE from Eq. (3), the idea behind the collocation method is to solve the integral formulation of the ODE in Eq. (4) via fixed point iteration. For our parallel sampling guarantees, instead of choosing a single randomized midpoint $\alpha$, we break up the window $[t_{0},t_{0}+h]$ into $R$ sub-windows, select randomized midpoints $\alpha_{1},\ldots,\alpha_{R}$ for these sub-windows, and approximate the trajectory of the ODE at any time $t_{0}+i\delta$, where $\delta\triangleq h/R$, by

x_{t_{0}+\alpha_{i}h}\approx e^{\alpha_{i}h}x_{t_{0}}+\sum^{i}_{j=1}\Bigl(e^{\alpha_{i}h-(j-1)\delta}-\max(e^{\alpha_{i}h-j\delta},1)\Bigr)\cdot f_{t_{0}+\alpha_{j}h}(x_{t_{0}+\alpha_{j}h})\,. (9)

One can show that as $R\to\infty$, this approximation tends to an equality. For sufficiently large $R$, Eq. (9) naturally suggests a fixed point iteration that can be used to approximate each $x_{t_{0}+\alpha_{i}h}$, i.e. we can maintain a sequence of estimates $\widehat{x}^{(k)}_{t_{0}+\alpha_{i}h}$ defined by the iteration

\widehat{x}^{(k)}_{t_{0}+\alpha_{i}h}\leftarrow e^{\alpha_{i}h}\widehat{x}^{(k-1)}_{t_{0}}+\sum^{i}_{j=1}\Bigl(e^{\alpha_{i}h-(j-1)\delta}-\max(e^{\alpha_{i}h-j\delta},1)\Bigr)\cdot f_{t_{0}+\alpha_{j}h}(\widehat{x}^{(k-1)}_{t_{0}+\alpha_{j}h})\,, (10)

for $k$ ranging from $1$ up to some sufficiently large $K$. Finally, analogously to Eq. (8), we can estimate $x_{t_{0}+h}$ via

x_{t_{0}+h}\approx e^{h}\widehat{x}^{(K)}_{t_{0}}+\delta\sum^{R}_{i=1}e^{(1-\alpha_{i})h}f_{t_{0}+\alpha_{i}h}(\widehat{x}^{(K)}_{t_{0}+\alpha_{i}h})\,. (11)

The key observation, made in [SL19] and also in related works of [ACV24, SBE+24, AHL+23], is that for any fixed round $k$, all of the iterations in Eq. (10) for different choices of $i=1,\ldots,R$ can be computed in parallel. With $R$ parallel processors, one can thus compute the estimate for $x_{t_{0}+h}$ in $K$ parallel rounds, with $O(KR)$ total work.
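For illustration, the following is a minimal sketch of the collocation iteration in Eqs. (10) and (11), written sequentially; in an actual implementation the loop over $i$ within each round $k$ would be distributed across the $R$ processors, since round $k$ reads only the iterates from round $k-1$. The score estimate s_hat is a placeholder, and we assume (as suggested by the description above) that $\alpha_{i}$ is drawn uniformly from the $i$-th sub-window.

```python
import numpy as np

def parallel_collocation_window(x0, t0, h, R, K, s_hat, rng):
    """Approximate dx_t = (x_t + f_t(x_t)) dt over [t0, t0+h] via the randomized
    collocation iteration (10)-(11). The loops over i are parallelizable across R processors."""
    delta = h / R
    # One randomized midpoint per sub-window: alpha_i in [(i-1)/R, i/R].
    alphas = (np.arange(R) + rng.uniform(0.0, 1.0, size=R)) / R
    # Initialization x_hat^{(0)} via exponential-integrator steps from t0.
    x_hat = [np.exp(a * h) * x0 + (np.exp(a * h) - 1.0) * s_hat(t0, x0) for a in alphas]
    for _ in range(K):
        f_prev = [s_hat(t0 + a * h, xi) for a, xi in zip(alphas, x_hat)]   # parallelizable
        x_new = []
        for i, a in enumerate(alphas):                                      # Eq. (10)
            acc = np.exp(a * h) * x0
            for j in range(i + 1):
                w = np.exp(a * h - j * delta) - max(np.exp(a * h - (j + 1) * delta), 1.0)
                acc = acc + w * f_prev[j]
            x_new.append(acc)
        x_hat = x_new
    # Eq. (11): final estimate at t0 + h.
    f_final = [s_hat(t0 + a * h, xi) for a, xi in zip(alphas, x_hat)]
    return np.exp(h) * x0 + delta * sum(np.exp((1.0 - a) * h) * fi
                                        for a, fi in zip(alphas, f_final))
```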

2.4 Assumptions

Throughout the paper, for our diffusion results, we will make the following standard assumptions on the data distribution and score estimates.

Assumption 2.1 (Bounded Second Moment).

$\mathfrak{m}_{2}^{2}:=\mathbb{E}_{x\sim q_{0}}\left[\|x\|^{2}\right]<\infty$.

Assumption 2.2 (Lipschitz Score).

For all $t$, the score $\nabla\ln q_{t}$ is $L$-Lipschitz.

Assumption 2.3 (Lipschitz Score Estimates).

For all $t$ for which we need to estimate the score function in our algorithms, the score estimate $\widehat{s}_{t}$ is $L$-Lipschitz.

Assumption 2.4 (Score Estimation Error).

For all $t$ for which we need to estimate the score function in our algorithms,

\mathbb{E}_{x_{t}\sim q_{t}}\left[\|\widehat{s}_{t}(x_{t})-\nabla\ln q_{t}(x_{t})\|^{2}\right]\leq\varepsilon_{\mathrm{sc}}^{2}.

3 Technical overview

Here we provide an overview of our sequential and parallel algorithms, along with the analysis of our iteration complexity bounds. We begin with a description of the sequential algorithm.

3.1 Sequential algorithm

Following the framework of [CCL+23a], our algorithm consists of "predictor" steps interspersed with "corrector" steps, with the time spent on each carefully tuned to obtain our final $\widetilde{O}(d^{5/12})$ dimension dependence. We first describe our predictor step; this is the piece of our algorithm that makes use of Shen and Lee's randomized midpoint method [SL19].

Algorithm 1 PredictorStep (Sequential)

Input parameters:

  • Starting sample $\widehat{x}_{0}$, Starting time $t_{0}$, Number of steps $N$, Step sizes $h_{n}$ for $n\in[0,\dots,N-1]$, Score estimates $\widehat{s}_{t}$

  1. For $n=0,\dots,N-1$:
     (a) Let $t_{n}=t_{0}-\sum_{i=0}^{n-1}h_{i}$.
     (b) Randomly sample $\alpha$ uniformly from $[0,1]$.
     (c) Let $\widehat{x}_{n+\frac{1}{2}}=e^{\alpha h_{n}}\widehat{x}_{n}+\left(e^{\alpha h_{n}}-1\right)\widehat{s}_{t_{n}}(\widehat{x}_{n})$.
     (d) Let $\widehat{x}_{n+1}=e^{h_{n}}\widehat{x}_{n}+h_{n}\,e^{(1-\alpha)h_{n}}\,\widehat{s}_{t_{n}-\alpha h_{n}}(\widehat{x}_{n+\frac{1}{2}})$.
  2. Let $t_{N}=t_{0}-\sum_{i=0}^{N-1}h_{i}$.
  3. Return $\widehat{x}_{N},t_{N}$.

The main differences between the above and the predictor step of [CCL+23a] are steps 1(b)-1(d): steps 1(b) and 1(c) together compute a randomized midpoint, and step 1(d) uses this midpoint to obtain an approximate solution to the integral formulation of the ODE. We describe these steps in more detail in Section 3.3.

Next, we describe the "corrector" step, introduced in [CCL+23a]. First, recall the discretized underdamped Langevin dynamics:

\begin{aligned}\mathrm{d}\widehat{x}_{t}&=\widehat{v}_{t}\,\mathrm{d}t\\ \mathrm{d}\widehat{v}_{t}&=(\widehat{s}(\widehat{x}_{\lfloor t/h\rfloor h})-\gamma\widehat{v}_{t})\,\mathrm{d}t+\sqrt{2\gamma}\,\mathrm{d}B_{t}\end{aligned} (12)

Here $\widehat{s}$ is our $L^{2}$-accurate score estimate for a fixed time (say $t$). Then, the corrector step is described below.

Algorithm 2 CorrectorStep (Sequential)

Input parameters:

  • Starting sample $\widehat{x}_{0}$, Total time $T_{\mathrm{corr}}$, Step size $h_{\mathrm{corr}}$, Score estimate $\widehat{s}$

  1. Run underdamped Langevin Monte Carlo in (12) for total time $T_{\mathrm{corr}}$ using step size $h_{\mathrm{corr}}$, and let the result be $\widehat{x}_{N}$.
  2. Return $\widehat{x}_{N}$.
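For intuition, here is a highly simplified corrector sketch that uses a plain Euler-Maruyama discretization of the dynamics in (12); the analysis in the paper relies on the exponential-integrator (ULMC) discretization of [CCL+23a], so this is only an illustrative stand-in, and the friction parameter, velocity initialization, and score estimate are placeholder assumptions.

```python
import numpy as np

def corrector_phase(x, T_corr, h_corr, s_hat, gamma=2.0, rng=None):
    """Illustrative Euler-Maruyama discretization of the underdamped Langevin
    dynamics (12), with the score frozen at the start of each step.
    (The paper's analysis uses the exponential-integrator ULMC discretization.)"""
    rng = np.random.default_rng() if rng is None else rng
    v = np.zeros_like(x)                       # placeholder velocity initialization
    n_steps = int(np.ceil(T_corr / h_corr))
    for _ in range(n_steps):
        grad = s_hat(x)                        # score frozen over the step
        noise = rng.standard_normal(x.shape)
        x_new = x + h_corr * v
        v_new = v + h_corr * (grad - gamma * v) + np.sqrt(2.0 * gamma * h_corr) * noise
        x, v = x_new, v_new
    return x
```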

Finally, Algorithm 3 below puts the predictor and corrector steps together to give our final sequential algorithm.

Algorithm 3 SequentialAlgorithm

Input parameters:

  • Start time $T$, End time $\delta$, Corrector steps time $T_{\mathrm{corr}}\lesssim 1/\sqrt{L}$, Number of predictor-corrector steps $N_{0}$, Predictor step size $h_{\mathrm{pred}}$, Corrector step size $h_{\mathrm{corr}}$, Score estimates $\widehat{s}_{t}$

  1. Draw $\widehat{x}_{0}\sim\mathcal{N}(0,I_{d})$.
  2. For $n=0,\dots,N_{0}-1$:
     (a) Starting from $\widehat{x}_{n}$, run Algorithm 4 with starting time $T-n/L$ using step sizes $h_{\mathrm{pred}}$ for all $N$ steps, with $N=\frac{1}{Lh_{\mathrm{pred}}}$, so that the total time is $1/L$. Let the result be $\widehat{x}_{n+1}^{\prime}$.
     (b) Starting from $\widehat{x}_{n+1}^{\prime}$, run Algorithm 2 for total time $T_{\mathrm{corr}}$ with step size $h_{\mathrm{corr}}$ and score estimate $\widehat{s}_{T-(n+1)/L}$ to obtain $\widehat{x}_{n+1}$.
  3. Starting from $\widehat{x}_{N_{0}}$, run Algorithm 4 with starting time $T-N_{0}/L$ using step sizes $h_{\mathrm{pred}}/2,h_{\mathrm{pred}}/4,h_{\mathrm{pred}}/8,\dots,\delta$ to obtain $\widehat{x}_{N_{0}+1}^{\prime}$.
  4. Starting from $\widehat{x}^{\prime}_{N_{0}+1}$, run Algorithm 2 for total time $T_{\mathrm{corr}}$ with step size $h_{\mathrm{corr}}$ and score estimate $\widehat{s}_{\delta}$ to obtain $\widehat{x}_{N_{0}+1}$.
  5. Return $\widehat{x}_{N_{0}+1}$.

For the final setting of parameters in Algorithm 3, see Theorem A.10. Now, we describe the analysis of the above algorithm in detail.

3.2 Predictor-corrector framework

The general framework of our algorithm closely follows that of [CCL+23a], which proposed to run the (discretized) reverse ODE but interspersed with “corrector” steps given by running underdamped Langevin dynamics. The idea is that the “predictor” steps where the discretized reverse ODE is being run keep the sampler close to the true reverse process in Wasserstein distance, but they cannot be run for too long before potentially incurring exponential blowups. The main purpose of the corrector steps is then to inject stochasticity into the trajectory of the sampler in order to convert closeness in Wasserstein to closeness in KL divergence. This effectively allows one to “restart the coupling” used to control the predictor steps. For technical reasons that are inherited from [CCL+23a], for most of the reverse process the predictor steps (Step 2(a)) are run with a fixed step size, but at the end of the reverse process (Step 3), they are run with exponentially decaying step sizes.

We follow the same framework, and the core of our result lies in refining the algorithm and analysis for the predictor steps by using the randomized midpoint method. Below, we highlight our key technical steps.

3.3 Predictor step – improved discretization error with randomized midpoints

Here, we explain the main idea behind why the randomized midpoint method allows us to achieve improved dimension dependence. We first focus on the analysis of the predictor (Algorithm 1) and restrict our attention to running the reverse process for a small amount of time $h\ll 1/L$.

We begin by recalling the dimension dependence achieved by the standard exponential integrator scheme. One can show (see e.g. Lemma 4 in [CCL+23a]) that if the true reverse process and the discretized reverse process are both run for small time $h$ starting from the same initialization, the two processes drift apart by a distance of $O(d^{1/2}h^{2})$. By iterating this coupling $O(1/h)$ times, we conclude that in an $O(1)$ window of time, the processes drift by a distance of $O(d^{1/2}h)$. To ensure this is not too large, one would take the step size $h$ to be $O(1/\sqrt{d})$, thus obtaining an iteration complexity of $O(\sqrt{d})$ as in [CCL+23a].

The starting point in the analysis of the randomized midpoint method is to instead track the squared displacement between the two processes. Given two neighboring time steps $t-h$ and $t$ in the algorithm, let $x_{t}$ denote the true reverse process at time $t$, and let $\widehat{x}_{t}$ denote the algorithm at time $t$ (in the notation of Algorithm 1, this is $\widehat{x}_{n}$ for some $n$, but we use $t$ in the discussion here to make the comparison to the true reverse process clearer). Note that $\widehat{x}_{t}$ depends on the choice of randomized midpoint $\alpha$ (see Step 1(b)). One can bound the squared displacement $\mathbb{E}\,\|x_{t}-\widehat{x}_{t}\|^{2}$ as follows. Let $y_{t}$ be the result of running the reverse process for time $h$ starting from $\widehat{x}_{t-h}$. Then by writing $x_{t}-\widehat{x}_{t}$ as $(x_{t}-y_{t})-(\widehat{x}_{t}-y_{t})$ and applying Young's inequality, we obtain

\mathbb{E}_{\widehat{x}_{t-h},\alpha}\|x_{t}-\widehat{x}_{t}\|^{2}\leq\Bigl(1+\frac{Lh}{2}\Bigr)\mathbb{E}_{\widehat{x}_{t-h}}\|x_{t}-y_{t}\|^{2}+\frac{2}{Lh}\mathbb{E}_{\widehat{x}_{t-h}}\|\mathbb{E}_{\alpha}\widehat{x}_{t}-y_{t}\|^{2}+\mathbb{E}_{\widehat{x}_{t-h}}\mathbb{E}_{\alpha}\|\widehat{x}_{t}-y_{t}\|^{2}\,. (13)

For the first term: because $x_{t}$ and $y_{t}$ are the result of running the same ODE on initializations $x_{t-h}$ and $\widehat{x}_{t-h}$, it is close to $\mathbb{E}\|x_{t-h}-\widehat{x}_{t-h}\|^{2}$ provided $h\ll 1/L$. The upshot is that the squared displacement at time $t$ is at most the squared displacement at time $t-h$ plus the remaining two terms on the right of Equation 13.

The main part of the proof lies in bounding these two terms, which can be thought of as "bias" and "variance" terms respectively. The variance term can be shown to scale with the square of the aforementioned $O(d^{1/2}h^{2})$ displacement bound that arises in the exponential integrator analysis, giving $O(dh^{4})$:

Lemma 3.1 (Informal, see Lemma A.4 for formal statement).

If $h\lesssim\frac{1}{L}$ and $T-t\geq(T-(t-h))/2$, then

\mathbb{E}_{\widehat{x}_{t-h}}\mathbb{E}_{\alpha}\left\|\widehat{x}_{t}-y_{t}\right\|^{2}\lesssim L^{2}dh^{4}\left(L\lor\frac{1}{T-(t-h)}\right)+h^{2}\varepsilon_{\mathrm{sc}}^{2}+L^{2}h^{2}\mathbb{E}_{\widehat{x}_{t-h}}\|x_{t-h}-\widehat{x}_{t-h}\|^{2}\,.

Note that in this bound, in addition to the $O(dh^{4})$ term and a term for the score estimation error, there is an additional term which depends on the squared displacement from the previous time step. Because the prefactor $L^{2}h^{2}$ is sufficiently small, this will ultimately be negligible.

The upshot of the above lemma is that if the bias term is of lower order, then the squared displacement essentially increases by $O(dh^{4})$ with every time step of length $h$. Over $O(1/h)$ such steps, the total squared displacement is $O(dh^{3})$, so if we take the step size $h$ to be $O(1/d^{1/3})$, this suggests an improved iteration complexity of $O(d^{1/3})$.
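Summarizing the two step-size heuristics above (suppressing $L$, $\varepsilon$, and logarithmic factors):

\[
\text{exponential integrator:}\quad \underbrace{d^{1/2}h^{2}}_{\text{drift per step}}\cdot\frac{1}{h}=d^{1/2}h\lesssim 1\;\Longrightarrow\;h\asymp d^{-1/2},\quad\text{giving }O(\sqrt{d})\text{ steps},
\]
\[
\text{randomized midpoint:}\quad \underbrace{dh^{4}}_{\text{squared displacement per step}}\cdot\frac{1}{h}=dh^{3}\lesssim 1\;\Longrightarrow\;h\asymp d^{-1/3},\quad\text{giving }O(d^{1/3})\text{ steps}.
\]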

Arguing that the bias term $\frac{2}{Lh}\mathbb{E}_{\widehat{x}_{t-h}}\left\|\mathbb{E}_{\alpha}\widehat{x}_{t}-y_{t}\right\|^{2}$ is dominated by the variance term is where it is crucial that we use the randomized midpoint rather than the exponential integrator. Recall that the randomized midpoint method was engineered so that it would give an unbiased estimate for the true solution to the reverse ODE if the estimate of the trajectory at the randomized midpoint were exact. In reality we only have an approximation to the latter, but as we show, the error incurred by this is indeed of lower order (see Lemma A.3). One technical complication that arises here is that the relevant quantity to bound is the distance between the true process at the randomized midpoint versus the algorithm, when both are initialized at an intermediate point in the algorithm's trajectory. Bounding such quantities in expectation over the randomness of the algorithm's trajectory can be difficult, but our proof identifies a way of "offloading" some of this difficulty by absorbing some excess terms into a term of the form $\|x_{t-h}-\widehat{x}_{t-h}\|^{2}$, i.e. the squared displacement from the previous time step. Concretely, we obtain the following bound on the bias term:

Lemma 3.2 (Informal, see Lemma A.2 for formal statement).
\mathbb{E}_{\widehat{x}_{t-h}}\|\mathbb{E}_{\alpha}\widehat{x}_{t}-y_{t}\|^{2}\lesssim L^{4}dh^{6}\left(L\lor\frac{1}{T-t+h}\right)+h^{2}\varepsilon_{\mathrm{sc}}^{2}+L^{4}h^{4}\mathbb{E}_{\widehat{x}_{t-h}}\|x_{t-h}-\widehat{x}_{t-h}\|^{2}

3.4 Shortening the corrector steps

While we have outlined how to improve the predictor step in the framework of [CCL+23a], it is quite unclear whether the same can be achieved for the corrector step. Whereas the former is geared towards closeness in Wasserstein distance, the latter is geared towards closeness in KL divergence, and it is a well-known open question in the log-concave sampling literature to obtain analogous discretization bounds in KL for the randomized midpoint method [Che23].

We will sidestep this issue and argue that even using exponential integrator discretization of the underdamped Langevin dynamics will suffice for our purposes, by simply shortening the amount of time for which each corrector step is run.

First, let us briefly recall what was shown in [CCL+23a] for the corrector step. If one runs underdamped Langevin dynamics with stationary distribution $q$ for time $T$ and exponential integrator discretization with step size $h$, starting from two distributions $p$ and $q$, then the resulting distributions $p^{\prime}$ and $q$ satisfy

\mathsf{TV}(p^{\prime},q)\lesssim\frac{W_{2}(p,q)}{L^{1/4}T^{3/2}}+L^{3/4}T^{1/2}d^{1/2}h\,, (14)

where $L$ is the Lipschitzness of $\nabla\ln q$ (see Theorem A.6). At first glance this appears insufficient for our purposes: because of the $d^{1/2}h$ term coming from the discretization error, we would need to take step size $h=1/\sqrt{d}$, which would suggest that the number of iterations must scale with $\sqrt{d}$.

To improve the dimension dependence for our overall predictor-corrector algorithm, we observe that if we take $T$ itself to be smaller, then we can take $h$ to be larger while keeping the discretization error in Equation 14 sufficiently small. Of course, this comes at a cost, as $T$ also appears in the term $\frac{W_{2}(p,q)}{L^{1/4}T^{3/2}}$ in Equation 14. But in our overall proof, the $W_{2}(p,q)$ term is bounded by the predictor analysis. There, we had quite a bit of slack: even with step size as large as $1/d^{1/3}$, we could achieve small Wasserstein error. By balancing appropriately, we get our improved dimension dependence.
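To make the tradeoff explicit, keeping the discretization term in Equation 14 below $\varepsilon$ forces

\[
h\;\lesssim\;\frac{\varepsilon}{L^{3/4}T^{1/2}d^{1/2}}
\qquad\Longrightarrow\qquad
\frac{T}{h}\;\gtrsim\;\frac{L^{3/4}T^{3/2}d^{1/2}}{\varepsilon}
\]

corrector iterations per corrector phase, so shrinking $T$ directly reduces the corrector's cost; the price is a larger $\frac{W_{2}(p,q)}{L^{1/4}T^{3/2}}$ term, which must be absorbed by the predictor's Wasserstein guarantee.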

3.5 Parallel algorithm

Now, we summarize the main proof ideas for our parallel sampling result. In Section 2.3, we described how to approximately solve the reverse ODE over time $h$ by running $K$ rounds of the iteration in Equation 10. In our final algorithm, we will take $h$ to be dimension-independent, namely $h=\Theta(1/\sqrt{L})$, so that the main part of the proof is to bound the discretization error incurred over each of these time windows of length $h$. As in the sequential analysis, we will interleave these "predictor" steps with corrector steps given by (parallelized) underdamped Langevin dynamics.

We begin by describing the parallel predictor step. Suppose we have produced an estimate for the reverse process at time $t_{0}$ and now wish to solve the ODE from time $t_{0}$ to $t_{0}+h$. We initialize at $\{\widehat{x}^{(0)}_{t_{0}+\alpha_{i}h}\}_{i\in[R]}$ via exponential integrator steps starting from the beginning of the window; see Line 1(c) in Algorithm 9 (this can be thought of as the analogue of Equation 7 used in the sequential algorithm). The key difference relative to the sequential algorithm is that here, because the length of the window is dimension-free, the discretization error incurred by this initialization is too large and must be refined using the fixed point iteration in Equation 10. The main step is then to show that with each iteration of Equation 10, the distance to the true reverse process contracts:

Lemma 3.3 (Informal, see Lemma B.2 for formal statement).

Suppose $h\lesssim 1/L$. If $y_{t}$ denotes the solution of the true ODE starting at $\widehat{x}_{t_{0}}$ and running until time $t_{0}+\alpha_{i}h$, then for all $k\in\{1,\cdots,K\}$ and $i\in\{1,\cdots,R\}$,

\mathbb{E}_{\widehat{x}_{t_{0}},\alpha_{1},\cdots,\alpha_{R}}\left\|\widehat{x}^{(k)}_{t_{0}+\alpha_{i}h}-y_{t_{0}+\alpha_{i}h}\right\|^{2}\lesssim\left(8h^{2}L^{2}\right)^{k}\cdot\left(\frac{1}{R}\sum_{j=1}^{R}\mathbb{E}_{\widehat{x}_{t_{0}},\alpha_{j}}\left\|\widehat{x}^{(0)}_{t_{0}+\alpha_{j}h}-y_{t_{0}+\alpha_{j}h}\right\|^{2}\right)
+h^{2}\left(\varepsilon^{2}_{\mathrm{sc}}+\frac{L^{2}dh^{2}}{R^{2}}\Bigl(L\lor\frac{1}{T-t_{0}+h}\Bigr)+L^{2}\cdot\mathbb{E}_{\widehat{x}_{t_{0}}}\left\|\widehat{x}_{t_{0}}-x_{t_{0}}\right\|^{2}\right)\,, (15)

where $\widehat{x}_{t_{0}}$ is the iterate of the algorithm from the previous time window, and $x_{t_{0}}$ is the corresponding iterate in the true ODE.

In particular, because $h$ is at most a small multiple of $1/L$, the prefactor $(8h^{2}L^{2})^{k}$ is exponentially decaying in $k$, so that the error incurred by the estimate $\widehat{x}^{(k)}_{t_{0}+\alpha_{i}h}$ is contracting with each fixed point iteration. Because the initialization is at distance $\mathrm{poly}(d)$ from the true process, $O(\log d)$ rounds of contraction thus suffice, which translates to $O(\log d)$ parallel rounds for the sampler. The rest of the analysis of the predictor step is quite similar to the analogous proofs for the sequential algorithm (i.e. Lemma B.4 and Lemma B.5 give the corresponding bias and variance bounds).
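Concretely, since $8h^{2}L^{2}$ is bounded by a constant less than $1$ when $h\lesssim 1/L$ and the initialization error is $\mathrm{poly}(d)$, driving the contracted term in Equation 15 below the target accuracy requires only

\[
(8h^{2}L^{2})^{K}\cdot\mathrm{poly}(d)\;\leq\;\varepsilon^{2}
\qquad\Longleftarrow\qquad
K\;\gtrsim\;\frac{\log\bigl(\mathrm{poly}(d)/\varepsilon^{2}\bigr)}{\log\bigl(1/(8h^{2}L^{2})\bigr)}\;=\;O(\log(d/\varepsilon))
\]

fixed point iterations per window, matching the $O(\log d)$ parallel rounds claimed above.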

One shortcoming of the predictor analysis is that the contraction achieved by fixed point iteration ultimately bottoms out at an error which scales with $d/R^{2}$ (see the second term in Equation 15). In order for the discretization error to be sufficiently small, we thus have to take $R$, and thus the total work of the algorithm, to scale with $O(\sqrt{d})$. So in this case we do not improve over the dimension dependence of [CCL+23a]; instead the improvement is in obtaining a parallel algorithm.

For the corrector analysis, we mostly draw upon the recent work of [ACV24] which analyzed a parallel implementation of the underdamped Langevin dynamics. While their guarantee focuses on sampling from log-concave distributions, implicit in their analysis is a bound for general smooth distributions on how much the law of the algorithm and the law of the true process drift apart in a bounded time window (see Lemma B.8). This bound suffices for our analysis of the corrector step, and we can conclude the following:

Theorem 3.4 (Informal, see Theorem B.12).

Let $\beta\geq 1$ be an adjustable parameter. Let $p^{\prime}$ denote the law of the output of running the parallel corrector (see Algorithm 8) for total time $1/\sqrt{L}$ and step size $h$, using an $\varepsilon_{\mathrm{sc}}$-approximate estimate for $\nabla\ln q$ and starting from a sample from another distribution $p$. Then

\textup{TV}(p^{\prime},q)\lesssim\sqrt{\textup{KL}(p^{\prime},q)}\lesssim\frac{\varepsilon_{\mathrm{sc}}}{\sqrt{L}}+\frac{\varepsilon}{\beta}+\frac{\varepsilon}{\beta\sqrt{d}}\cdot W_{2}(p,q)\,.

Furthermore, this algorithm uses $\widetilde{\Theta}(\beta\sqrt{d}/\varepsilon)$ score evaluations over $\Theta(\log(\beta^{2}d/\varepsilon^{2}))$ parallel rounds.

Overall, our parallel algorithm is somewhat different from the parallel sampler developed in the empirical work of [SBE+24], even apart from the fact that we use randomized midpoint discretization and corrector steps. The reason is that our algorithm applies collocation to fixed windows of time, whereas the algorithm of [SBE+24] utilizes a sliding window approach that proactively shifts the window forward as soon as the iterates at the start of the previous window begin to converge. We leave rigorously analyzing the benefits of this approach as an interesting future direction.

3.6 Log-concave sampling in total variation

Finally, we briefly summarize the simple proof of our result on log-concave sampling in TV, which achieves the best known dimension dependence of $\widetilde{O}(d^{5/12})$. Our main observation is that Shen and Lee's randomized midpoint method [SL19] applied to the underdamped Langevin process gives a Wasserstein guarantee for log-concave sampling, while the corrector step of [CCL+23a] can convert a Wasserstein guarantee to closeness in TV. Thus, we can simply run the randomized midpoint method, followed by the corrector step, to achieve closeness in TV. Carefully tuning the amount of time spent and the step sizes for each phase of this algorithm yields our improved dimension dependence; see Appendix C for the full proof.

4 Discussion and Future Work

In this work, we showed that it is possible to leverage Shen and Lee's randomized midpoint method [SL19] to achieve the best known dimension dependence for sampling from arbitrary smooth distributions in TV using diffusion. We also showed how to parallelize our algorithm, and showed that $\widetilde{O}(\log^{2}d)$ parallel rounds suffice for sampling. These constitute the first provable guarantees for parallel sampling with diffusion models. Finally, we showed that our techniques can be used to obtain an improved dimension dependence for log-concave sampling in TV.

We note that relative to [CCL+23a], our result requires a slightly stronger guarantee on the score estimation error, by a $d^{1/12}$ factor; we believe this is an artifact of our analysis, and it would be interesting to remove this dependence in future work. Importantly, however, prior to the present work it was not known how to improve over the $O(\sqrt{d})$ dependence shown in that paper even in the case that the scores are known exactly. Moreover, another line of work [LWCC23, LHE+24, DCWY24] analyzing diffusion sampling assumes that the score estimation error is $\widetilde{O}\left(\frac{\varepsilon}{\sqrt{d}}\right)$, an assumption stronger than ours by a $d^{5/12}$ factor; this does not detract from the importance of those works, and we feel the same is true in our case.

We also note that our diffusion results require smoothness assumptions: we assume that the true score, as well as our score estimates, are $L$-Lipschitz. Although this assumption is standard in the literature, recent work [CLL23, BDBDD24] has analyzed DDPM in the absence of these assumptions, culminating in an $\widetilde{O}(d)$ dependence for sampling using a discretization of the reverse SDE. However, unlike in the smooth case, it is not known whether even a sublinear-in-$d$ dependence is possible without smoothness assumptions via any algorithm. We leave this as an interesting open question for future work.

Acknowledgements

S.C. thanks Sinho Chewi for a helpful discussion on log-concave sampling. S.G. was funded by NSF Award CCF-1751040 (CAREER) and the NSF AI Institute for Foundations of Machine Learning (IFML).

References

  • [AC23] Jason M Altschuler and Sinho Chewi. Faster high-accuracy log-concave sampling via algorithmic warm starts. In 2023 IEEE 64th Annual Symposium on Foundations of Computer Science (FOCS), pages 2169–2176. IEEE, 2023.
  • [ACV24] Nima Anari, Sinho Chewi, and Thuy-Duong Vuong. Fast parallel sampling under isoperimetry. CoRR, abs/2401.09016, 2024.
  • [AHL+23] Nima Anari, Yizhi Huang, Tianyu Liu, Thuy-Duong Vuong, Brian Xu, and Katherine Yu. Parallel discrete sampling via continuous walks. In Proceedings of the 55th Annual ACM Symposium on Theory of Computing, pages 103–116, 2023.
  • [BDBDD24] Joe Benton, Valentin De Bortoli, Arnaud Doucet, and George Deligiannidis. Nearly d-linear convergence bounds for diffusion models via stochastic localization. In The Twelfth International Conference on Learning Representations, 2024.
  • [BDD23] Joe Benton, George Deligiannidis, and Arnaud Doucet. Error bounds for flow matching methods. arXiv preprint arXiv:2305.16860, 2023.
  • [BGJ+23] James Betker, Gabriel Goh, Li Jing, Tim Brooks, Jianfeng Wang, Linjie Li, Long Ouyang, Juntang Zhuang, Joyce Lee, Yufei Guo, et al. Improving image generation with better captions. Computer Science. https://cdn.openai.com/papers/dall-e-3.pdf, 2(3):8, 2023.
  • [BMR22] Adam Block, Youssef Mroueh, and Alexander Rakhlin. Generative modeling with denoising auto-encoders and Langevin sampling. arXiv preprint 2002.00107, 2022.
  • [BPH+24] Tim Brooks, Bill Peebles, Connor Holmes, Will DePue, Yufei Guo, Li Jing, David Schnurr, Joe Taylor, Troy Luhman, Eric Luhman, Clarence Ng, Ricky Wang, and Aditya Ramesh. Video generation models as world simulators. 2024.
  • [CCL+23a] Sitan Chen, Sinho Chewi, Holden Lee, Yuanzhi Li, Jianfeng Lu, and Adil Salim. The probability flow ODE is provably fast. In Thirty-seventh Conference on Neural Information Processing Systems, 2023.
  • [CCL+23b] Sitan Chen, Sinho Chewi, Jerry Li, Yuanzhi Li, Adil Salim, and Anru R Zhang. Sampling is as easy as learning the score: theory for diffusion models with minimal data assumptions. In International Conference on Learning Representations, 2023.
  • [CDD23] Sitan Chen, Giannis Daras, and Alex Dimakis. Restoration-degradation beyond linear diffusions: A non-asymptotic analysis for ddim-type samplers. In International Conference on Machine Learning, pages 4462–4484. PMLR, 2023.
  • [Che23] Sinho Chewi. Log-concave sampling. Book draft available at https://chewisinho.github.io, 2023.
  • [CLL23] Hongrui Chen, Holden Lee, and Jianfeng Lu. Improved analysis of score-based generative modeling: User-friendly bounds under minimal smoothness assumptions. In International Conference on Machine Learning, pages 4735–4763. PMLR, 2023.
  • [CRYR24] Haoxuan Chen, Yinuo Ren, Lexing Ying, and Grant M. Rotskoff. Accelerating diffusion models with parallel sampling: Inference at sub-linear time complexity, 2024.
  • [DB22] Valentin De Bortoli. Convergence of denoising diffusion models under the manifold hypothesis. Transactions on Machine Learning Research, 2022.
  • [DBTHD21] Valentin De Bortoli, James Thornton, Jeremy Heng, and Arnaud Doucet. Diffusion Schrödinger bridge with applications to score-based generative modeling. In M. Ranzato, A. Beygelzimer, Y. Dauphin, P.S. Liang, and J. Wortman Vaughan, editors, Advances in Neural Information Processing Systems, volume 34, pages 17695–17709. Curran Associates, Inc., 2021.
  • [DCWY24] Zehao Dou, Minshuo Chen, Mengdi Wang, and Zhuoran Yang. Theory of consistency diffusion models: Distribution estimation meets fast sampling. In Ruslan Salakhutdinov, Zico Kolter, Katherine Heller, Adrian Weller, Nuria Oliver, Jonathan Scarlett, and Felix Berkenkamp, editors, Proceedings of the 41st International Conference on Machine Learning, volume 235 of Proceedings of Machine Learning Research, pages 11592–11612. PMLR, 21–27 Jul 2024.
  • [DN21] Prafulla Dhariwal and Alexander Nichol. Diffusion models beat GANs on image synthesis. In M. Ranzato, A. Beygelzimer, Y. Dauphin, P.S. Liang, and J. Wortman Vaughan, editors, Advances in Neural Information Processing Systems, volume 34, pages 8780–8794. Curran Associates, Inc., 2021.
  • [EKB+24] Patrick Esser, Sumith Kulal, Andreas Blattmann, Rahim Entezari, Jonas Müller, Harry Saini, Yam Levi, Dominik Lorenz, Axel Sauer, Frederic Boesel, et al. Scaling rectified flow transformers for high-resolution image synthesis. arXiv preprint arXiv:2403.03206, 2024.
  • [GLP23] Shivam Gupta, Jasper C.H. Lee, and Eric Price. High-dimensional location estimation via norm concentration for subgamma vectors. In Proceedings of the 40th International Conference on Machine Learning, ICML’23. JMLR.org, 2023.
  • [GPPX23] Shivam Gupta, Aditya Parulekar, Eric Price, and Zhiyang Xun. Sample-efficient training for diffusion. arXiv preprint arXiv:2311.13745, 2023.
  • [HBE20] Ye He, Krishnakumar Balasubramanian, and Murat A Erdogdu. On the ergodicity, bias and asymptotic normality of randomized midpoint sampling method. Advances in Neural Information Processing Systems, 33:7366–7376, 2020.
  • [HJA20] Jonathan Ho, Ajay Jain, and Pieter Abbeel. Denoising diffusion probabilistic models. Advances in Neural Information Processing Systems, 33:6840–6851, 2020.
  • [KN24] Saravanan Kandasamy and Dheeraj Nagaraj. The poisson midpoint method for langevin dynamics: Provably efficient discretization for diffusion models, 2024.
  • [KPH+20] Zhifeng Kong, Wei Ping, Jiaji Huang, Kexin Zhao, and Bryan Catanzaro. Diffwave: A versatile diffusion model for audio synthesis. arXiv preprint arXiv:2009.09761, 2020.
  • [LHE+24] Gen Li, Yu Huang, Timofey Efimov, Yuting Wei, Yuejie Chi, and Yuxin Chen. Accelerating convergence of score-based diffusion models, provably. arXiv preprint arXiv:2403.03852, 2024.
  • [LLT22] Holden Lee, Jianfeng Lu, and Yixin Tan. Convergence for score-based generative modeling with polynomial complexity. Advances in Neural Information Processing Systems, 35:22870–22882, 2022.
  • [LLT23] Holden Lee, Jianfeng Lu, and Yixin Tan. Convergence of score-based generative modeling for general data distributions. In International Conference on Algorithmic Learning Theory, pages 946–985. PMLR, 2023.
  • [LWCC23] Gen Li, Yuting Wei, Yuxin Chen, and Yuejie Chi. Towards faster non-asymptotic convergence for diffusion-based generative models. arXiv preprint arXiv:2306.09251, 2023.
  • [LWYL22] Xingchao Liu, Lemeng Wu, Mao Ye, and Qiang Liu. Let us build bridges: Understanding and extending diffusion generative models. arXiv preprint arXiv:2208.14699, 2022.
  • [MCC+21] Yi-An Ma, Niladri S Chatterji, Xiang Cheng, Nicolas Flammarion, Peter L Bartlett, and Michael I Jordan. Is there an analog of nesterov acceleration for gradient-based mcmc? 2021.
  • [Pid22] Jakiw Pidstrigach. Score-based generative models detect manifolds. In S. Koyejo, S. Mohamed, A. Agarwal, D. Belgrave, K. Cho, and A. Oh, editors, Advances in Neural Information Processing Systems, volume 35, pages 35852–35865. Curran Associates, Inc., 2022.
  • [SBE+24] Andy Shih, Suneel Belkhale, Stefano Ermon, Dorsa Sadigh, and Nima Anari. Parallel sampling of diffusion models. Advances in Neural Information Processing Systems, 36, 2024.
  • [SDME21] Yang Song, Conor Durkan, Iain Murray, and Stefano Ermon. Maximum likelihood training of score-based diffusion models. In M. Ranzato, A. Beygelzimer, Y. Dauphin, P.S. Liang, and J. Wortman Vaughan, editors, Advances in Neural Information Processing Systems, volume 34, pages 1415–1428. Curran Associates, Inc., 2021.
  • [SDWMG15] Jascha Sohl-Dickstein, Eric Weiss, Niru Maheswaranathan, and Surya Ganguli. Deep unsupervised learning using nonequilibrium thermodynamics. In Francis Bach and David Blei, editors, Proceedings of the 32nd International Conference on Machine Learning, volume 37 of Proceedings of Machine Learning Research, pages 2256–2265, Lille, France, 7 2015. PMLR.
  • [SE19] Yang Song and Stefano Ermon. Generative modeling by estimating gradients of the data distribution. In H. Wallach, H. Larochelle, A. Beygelzimer, F. d'Alché-Buc, E. Fox, and R. Garnett, editors, Advances in Neural Information Processing Systems, volume 32. Curran Associates, Inc., 2019.
  • [SL19] Ruoqi Shen and Yin Tat Lee. The randomized midpoint method for log-concave sampling. In Hanna M. Wallach, Hugo Larochelle, Alina Beygelzimer, Florence d’Alché-Buc, Emily B. Fox, and Roman Garnett, editors, Advances in Neural Information Processing Systems 32: Annual Conference on Neural Information Processing Systems 2019, NeurIPS 2019, December 8-14, 2019, Vancouver, BC, Canada, pages 2098–2109, 2019.
  • [SSDK+21] Yang Song, Jascha Sohl-Dickstein, Diederik P. Kingma, Abhishek Kumar, Stefano Ermon, and Ben Poole. Score-based generative modeling through stochastic differential equations. In International Conference on Learning Representations, 2021.
  • [VKK21] Arash Vahdat, Karsten Kreis, and Jan Kautz. Score-based generative modeling in latent space. In M. Ranzato, A. Beygelzimer, Y. Dauphin, P.S. Liang, and J. Wortman Vaughan, editors, Advances in Neural Information Processing Systems, volume 34, pages 11287–11302. Curran Associates, Inc., 2021.
  • [WY22] Andre Wibisono and Kaylee Y. Yang. Convergence in KL divergence of the inexact Langevin algorithm with application to score-based generative models. arXiv preprint 2211.01512, 2022.
  • [WYvdB+24] Kevin E Wu, Kevin K Yang, Rianne van den Berg, Sarah Alamdari, James Y Zou, Alex X Lu, and Ava P Amini. Protein structure generation via folding diffusion. Nature Communications, 15(1):1059, 2024.
  • [YD24] Lu Yu and Arnak Dalalyan. Parallelized midpoint randomization for Langevin Monte Carlo. arXiv preprint arXiv:2402.14434, 2024.
  • [YKD23] Lu Yu, Avetik Karagulyan, and Arnak Dalalyan. Langevin monte carlo for strongly log-concave distributions: Randomized midpoint revisited. arXiv preprint arXiv:2306.08494, 2023.
  • [ZCL+23] Shunshi Zhang, Sinho Chewi, Mufan Li, Krishna Balasubramanian, and Murat A Erdogdu. Improved discretization analysis for underdamped langevin monte carlo. In The Thirty Sixth Annual Conference on Learning Theory, pages 36–71. PMLR, 2023.

Roadmap.

In Section A, we give the proof of Theorem 1.1, our main result on sequential sampling with diffusions. In Section B, we give the proof of Theorem 1.2, our main result on parallel sampling with diffusions. In Section C, we give the proof of Theorem 1.3 on log-concave sampling.

As a notational remark, in the proofs to follow we will sometimes use the notation $\textup{KL}(x\parallel y)$, $W_{2}(x,y)$, and $\textup{TV}(x,y)$ for random variables $x$ and $y$ to denote the distance between their associated probability distributions. Also, throughout the Appendix, we use $t$ to denote time in the forward process.

Appendix A Sequential algorithm

In this section, we describe our sequential randomized-midpoint-based algorithm in detail. Following the framework of [CCL+23a], we begin by describing the predictor step, and show in Lemma A.5 that in $\widetilde{O}(d^{1/3})$ steps (ignoring other dependencies), when run for time $t$ at most $O(\frac{1}{L})$ starting from $t_{n}$, it produces a sample that is close to the true distribution at time $t_{n}-t$. Then, we show that the corrector step can be used to convert our $W_{2}$ error to error in TV distance by running the underdamped Langevin Monte Carlo algorithm, as described in [CCL+23a]. We show in A.8 that if we run our predictor and corrector steps in succession for a careful choice of times, we obtain a sample that is close to the true distribution in TV using just $\widetilde{O}(d^{5/12})$ steps, but covering a time of only $O\left(\frac{1}{L}\right)$. Finally, in Theorem A.10, we iterate this bound $\widetilde{O}(\log^{2}d)$ times to obtain our final iteration complexity of $\widetilde{O}(d^{5/12})$.

A.1 Predictor step

To show the $\widetilde{O}(d^{1/3})$ dependence on dimension for the predictor step, we will, roughly speaking, show that its bias after one step is bounded by $\approx O_{L}\left(dh^{6}\right)$ in Lemma A.3, and that the variance is bounded by $O_{L}\left(dh^{4}\right)$ in Lemma A.4. Then, iterating these bounds $\approx\frac{1}{h}$ times as shown in Lemma A.5 will give error $O_{L}\left(dh^{4}+dh^{3}\right)$ in squared Wasserstein distance.

Algorithm 4 PredictorStep (Sequential)

Input parameters:

  • Starting sample $\widehat{x}_{0}$, Starting time $t_{0}$, Number of steps $N$, Step sizes $h_{n}$ for $n\in[0,\dots,N-1]$, Score estimates $\widehat{s}_{t}$

  1. For $n=0,\dots,N-1$:
     (a) Let $t_{n}=t_{0}-\sum_{i=0}^{n-1}h_{i}$.
     (b) Randomly sample $\alpha$ uniformly from $[0,1]$.
     (c) Let $\widehat{x}_{n+\frac{1}{2}}=e^{\alpha h_{n}}\widehat{x}_{n}+\left(e^{\alpha h_{n}}-1\right)\widehat{s}_{t_{n}}(\widehat{x}_{n})$.
     (d) Let $\widehat{x}_{n+1}=e^{h_{n}}\widehat{x}_{n}+h_{n}\,e^{(1-\alpha)h_{n}}\,\widehat{s}_{t_{n}-\alpha h_{n}}(\widehat{x}_{n+\frac{1}{2}})$.
  2. Let $t_{N}=t_{0}-\sum_{i=0}^{N-1}h_{i}$.
  3. Return $\widehat{x}_{N},t_{N}$.

Lemma A.1 (Naive ODE Coupling).

Consider two points $x_{0},x_{0}^{\prime}$ at time $t_{0}$, run the true ODE for time $h$ from each, and let the results be $x_{1},x_{1}^{\prime}$. For $L\geq 1$, $h\leq 1/L$, we have

\|x_{1}-x_{1}^{\prime}\|^{2}\leq\exp(O(Lh))\|x_{0}-x_{0}^{\prime}\|^{2}

Proof.

Recall that the true ODE is given by

\mathrm{d}x_{t}=(x_{t}+\nabla\ln q_{T-t}(x_{t}))\,\mathrm{d}t

So,

\partial_{t}\|x_{t}-x_{t}^{\prime}\|^{2}=2\langle x_{t}-x_{t}^{\prime},\partial_{t}x_{t}-\partial_{t}x_{t}^{\prime}\rangle
=2\langle x_{t}-x_{t}^{\prime},\,x_{t}-x_{t}^{\prime}+\nabla\ln q_{T-t}(x_{t})-\nabla\ln q_{T-t}(x_{t}^{\prime})\rangle
\lesssim L\|x_{t}-x_{t}^{\prime}\|^{2}

So,

\|x_{1}-x_{1}^{\prime}\|^{2}\leq\exp\left(O(Lh)\right)\|x_{0}-x_{0}^{\prime}\|^{2}\,. \qed
Lemma A.2.

Suppose $L\geq 1$. In Algorithm 4, for all $n\in\{0,\dots,N-1\}$, let $x^{*}_{n}(t)$ be the solution of the true ODE starting at $\widehat{x}_{n}$ at time $t_{n}$, running until time $t_{n}-t$. If $h_{n}\lesssim\frac{1}{L}$ and $t_{n}-h_{n}\geq t_{n}/2$, we have

\mathbb{E}\|h_{n}e^{(1-\alpha)h_{n}}\widehat{s}_{t_{n}-\alpha h_{n}}(\widehat{x}_{n+\frac{1}{2}})-h_{n}e^{(1-\alpha)h_{n}}\nabla\ln q_{t_{n}-\alpha h_{n}}(x_{n}^{*}(\alpha h_{n}))\|^{2}
\lesssim h_{n}^{2}\varepsilon_{\mathrm{sc}}^{2}+L^{4}dh_{n}^{6}\left(L\lor\frac{1}{t_{n}}\right)+L^{4}h_{n}^{4}\mathbb{E}\|\widehat{x}_{n}-x_{t_{n}}\|^{2}\,,

where $\mathbb{E}$ refers to the expectation over the initial choice $\widehat{x}_{0}\sim q_{t_{0}}$.

Proof.

For the proof, we will let $h:=h_{n}$. It suffices to show that

\|\widehat{s}_{t_{n}-\alpha h}(\widehat{x}_{n+\frac{1}{2}})-\nabla\ln q_{t_{n}-\alpha h}(x_{n}^{*}(\alpha h))\|^{2}\lesssim\varepsilon_{\mathrm{sc}}^{2}+L^{4}dh^{4}\left(L\lor\frac{1}{t_{n}}\right)+L^{4}h^{2}\mathbb{E}\|\widehat{x}_{n}-x_{t_{n}}\|^{2}\,. (16)

Now,

\|\widehat{s}_{t_{n}-\alpha h}(\widehat{x}_{n+\frac{1}{2}})-\nabla\ln q_{t_{n}-\alpha h}(x_{n}^{*}(\alpha h))\|^{2}
\lesssim\|\widehat{s}_{t_{n}-\alpha h}(\widehat{x}_{n+\frac{1}{2}})-\nabla\ln q_{t_{n}-\alpha h}(\widehat{x}_{n+\frac{1}{2}})\|^{2}+\|\nabla\ln q_{t_{n}-\alpha h}(\widehat{x}_{n+\frac{1}{2}})-\nabla\ln q_{t_{n}-\alpha h}(x_{n}^{*}(\alpha h))\|^{2}
\lesssim\varepsilon_{\mathrm{sc}}^{2}+L^{2}\|\widehat{x}_{n+\frac{1}{2}}-x_{n}^{*}(\alpha h)\|^{2}\,. (17)

Now, note $x_{n}^{*}(\alpha h)$ is the solution to the following ODE run for time $\alpha h$, starting at $\widehat{x}_{n}$ at time $t_{n}$:

\mathrm{d}x_{t}=\left(x_{t}+\nabla\ln q_{t}(x_{t})\right)\mathrm{d}t

Similarly, $\widehat{x}_{n+\frac{1}{2}}$ is the solution to the following ODE run for time $\alpha h$, starting at $\widehat{x}_{n}$ at time $t_{n}$:

\mathrm{d}\widehat{x}_{t}=\left(\widehat{x}_{t}+\widehat{s}_{t_{n}}(\widehat{x}_{n})\right)\mathrm{d}t

So, we have

\partial_{t}\|x_{t}-\widehat{x}_{t}\|^{2}=2\langle x_{t}-\widehat{x}_{t},\partial_{t}x_{t}-\partial_{t}\widehat{x}_{t}\rangle
=2\left(\|x_{t}-\widehat{x}_{t}\|^{2}+\langle x_{t}-\widehat{x}_{t},\nabla\ln q_{t}(x_{t})-\widehat{s}_{t_{n}}(\widehat{x}_{n})\rangle\right)
\leq\left(2+\frac{1}{h}\right)\|x_{t}-\widehat{x}_{t}\|^{2}+h\|\nabla\ln q_{t}(x_{t})-\widehat{s}_{t_{n}}(\widehat{x}_{n})\|^{2}

where the last line is by Young's inequality. So, by Grönwall's inequality,

\|x_{n}^{*}(\alpha h)-\widehat{x}_{n+\frac{1}{2}}\|^{2}\leq\exp\left(\left(2+\frac{1}{h}\right)\cdot\alpha h\right)\int_{0}^{h}h\|\nabla\ln q_{t_{n}-s}(x_{n}^{*}(s))-\widehat{s}_{t_{n}}(\widehat{x}_{n})\|^{2}\,\mathrm{d}s
\lesssim h\int_{0}^{h}\|\nabla\ln q_{t_{n}-s}\left(x_{n}^{*}(s)\right)-\widehat{s}_{t_{n}}(\widehat{x}_{n})\|^{2}\,\mathrm{d}s
\lesssim h^{2}\varepsilon_{\mathrm{sc}}^{2}+h\int_{0}^{h}\|\nabla\ln q_{t_{n}-s}(x_{n}^{*}(s))-\nabla\ln q_{t_{n}}(\widehat{x}_{n})\|^{2}\,\mathrm{d}s

Now, we have

\|\nabla\ln q_{t_{n}-s}(x_{n}^{*}(s))-\nabla\ln q_{t_{n}}(\widehat{x}_{n})\|^{2}
\lesssim\|\nabla\ln q_{t_{n}-s}(x_{t_{n}-s})-\nabla\ln q_{t_{n}}(x_{t_{n}})\|^{2}
\qquad+\|\nabla\ln q_{t_{n}-s}(x_{n}^{*}(s))-\nabla\ln q_{t_{n}-s}(x_{t_{n}-s})\|^{2}+\|\nabla\ln q_{t_{n}}(x_{t_{n}})-\nabla\ln q_{t_{n}}(\widehat{x}_{n})\|^{2}\,.

By Corollary D.1,

\mathbb{E}\|\nabla\ln q_{t_{n}-s}(x_{t_{n}-s})-\nabla\ln q_{t_{n}}(x_{t_{n}})\|^{2}\lesssim L^{2}dh^{2}\left(L\lor\frac{1}{t_{n}}\right)

By Lipschitzness of $\nabla\ln q_{t}$ and Lemma A.1, for $s\leq h_{n}\leq 1/L$,

\|\nabla\ln q_{t_{n}-s}(x_{n}^{*}(s))-\nabla\ln q_{t_{n}-s}(x_{t_{n}-s})\|^{2}\leq L^{2}\|x_{n}^{*}(s)-x_{t_{n}-s}\|^{2}
\lesssim L^{2}\exp\left(O(Lh_{n})\right)\|\widehat{x}_{n}-x_{t_{n}}\|^{2}
\lesssimL2x^nxtn2\displaystyle\lesssim L^{2}\|\widehat{x}_{n}-x_{t_{n}}\|^{2}

and similarly,

lnqtn(xtn)lnqtn(x^n)2L2x^nxtn2\displaystyle\|\nabla\ln q_{t_{n}}(x_{t_{n}})-\nabla\ln q_{t_{n}}(\widehat{x}_{n})\|^{2}\leq L^{2}\|\widehat{x}_{n}-x_{t_{n}}\|^{2}

So, we have shown that

\mathbbElnqtns(xn(s))lnqtn(x^n)2\lesssimL2dh2(L1tn)+L2\mathbbEx^nxtn2\displaystyle\operatorname*{\mathbb{E}}\|\nabla\ln q_{t_{n}-s}(x_{n}^{*}(s))-\nabla\ln q_{t_{n}}(\widehat{x}_{n})\|^{2}\lesssim L^{2}dh^{2}\left(L\lor\frac{1}{t_{n}}\right)+L^{2}\operatorname*{\mathbb{E}}\|\widehat{x}_{n}-x_{t_{n}}\|^{2}

so that

\mathbbExn(αh)x^n+122\lesssimh2εsc2+L2dh4(L1tn)+L2h2\mathbbEx^nxtn2.\displaystyle\operatorname*{\mathbb{E}}\|x_{n}^{*}(\alpha h)-\widehat{x}_{n+\frac{1}{2}}\|^{2}\lesssim h^{2}\varepsilon_{\mathrm{sc}}^{2}+L^{2}dh^{4}\left(L\lor\frac{1}{t_{n}}\right)+L^{2}h^{2}\operatorname*{\mathbb{E}}\|\widehat{x}_{n}-x_{t_{n}}\|^{2}\,.

Combining this with the bound in (17) and recalling that h\lesssim1/Lh\lesssim 1/L yields the desired inequality in (16). ∎

Lemma A.3 (Sequential Predictor Bias).

Suppose L1L\geq 1. In Algorithm 4, for all n{0,,N1}n\in\{0,\dots,N-1\}, let xn(t)x_{n}^{*}(t) be the solution of the true ODE starting at x^n\widehat{x}_{n} at time tnt_{n} and running until time tntt_{n}-t, and let xtqtx_{t}\sim q_{t} be the solution of the true ODE, starting at x^0qt0\widehat{x}_{0}\sim q_{t_{0}}. If hn\lesssim1Lh_{n}\lesssim\frac{1}{L}, and tnhntn/2t_{n}-h_{n}\geq t_{n}/2, we have

\operatorname*{\mathbb{E}}\|\operatorname*{\mathbb{E}}_{\alpha}\widehat{x}_{n+1}-x_{n}^{*}(h_{n})\|^{2}\lesssim h_{n}^{2}\varepsilon_{\mathrm{sc}}^{2}+L^{4}dh_{n}^{6}\left(L\lor\frac{1}{t_{n}}\right)+L^{4}h_{n}^{4}\operatorname*{\mathbb{E}}\|\widehat{x}_{n}-x_{t_{n}}\|^{2}

where \mathbbEα\operatorname*{\mathbb{E}}_{\alpha} is the expectation with respect to the α\alpha chosen in the nthn^{th} step, and \mathbbE\operatorname*{\mathbb{E}} is the expectation with respect to the choice of the initial x^0qt0\widehat{x}_{0}\sim q_{t_{0}}.

Proof.

For the proof, we will fix n and let h:=h_{n}. By the integral formulation of the true ODE,

xn(h)=ehx^n+tnhtnes(tnh)lnqs(xn(tns))ds.x_{n}^{*}(h)=e^{h}\widehat{x}_{n}+\int_{t_{n}-h}^{t_{n}}e^{s-(t_{n}-h)}\nabla\ln q_{s}(x_{n}^{*}(t_{n}-s))\,\mathrm{d}s\,.

Thus, we have

\mathbbEαx^n+1xn(h)2\displaystyle\|\operatorname*{\mathbb{E}}_{\alpha}\widehat{x}_{n+1}-x_{n}^{*}(h)\|^{2} =h\mathbbEαe(1α)hs^tnαh(x^n+12)tnhtnes(tnh)lnqs(xn(tns))ds2\displaystyle=\|h\operatorname*{\mathbb{E}}_{\alpha}e^{(1-\alpha)h}\widehat{s}_{t_{n}-\alpha h}(\widehat{x}_{n+\frac{1}{2}})-\int_{t_{n}-h}^{t_{n}}e^{s-(t_{n}-h)}\nabla\ln q_{s}(x_{n}^{*}(t_{n}-s))\,\mathrm{d}s\|^{2}
\lesssim\mathbbEαhe(1α)hs^tnαh(x^n+12)he(1α)hlnqtnαh(xn(αh))2\displaystyle\lesssim\operatorname*{\mathbb{E}}_{\alpha}\|he^{(1-\alpha)h}\widehat{s}_{t_{n}-\alpha h}(\widehat{x}_{n+\frac{1}{2}})-he^{(1-\alpha)h}\nabla\ln q_{t_{n}-\alpha h}(x_{n}^{*}(\alpha h))\|^{2}
+h\mathbbEαe(1α)hlnqtnαh(xn(αh))tnhtnes(tnh)lnqs(xn(tns))ds2.\displaystyle\qquad+\|h\cdot\operatorname*{\mathbb{E}}_{\alpha}e^{(1-\alpha)h}\nabla\ln q_{t_{n}-\alpha h}(x_{n}^{*}(\alpha h))-\int_{t_{n}-h}^{t_{n}}e^{s-(t_{n}-h)}\nabla\ln q_{s}(x_{n}^{*}(t_{n}-s))\,\mathrm{d}s\|^{2}\,.

The second term is 0 since

h\mathbbEαe(1α)hlnqtnαh(xn(αh))\displaystyle h\operatorname*{\mathbb{E}}_{\alpha}e^{(1-\alpha)h}\nabla\ln q_{t_{n}-\alpha h}(x_{n}^{*}(\alpha h)) =h01e(1α)hlnqtnαh(xn(αh))dα\displaystyle=h\int_{0}^{1}e^{(1-\alpha)h}\nabla\ln q_{t_{n}-\alpha h}(x_{n}^{*}(\alpha h))\,\mathrm{d}\alpha
=tnhtnes(tnh)lnqs(xn(tns))ds.\displaystyle=\int_{t_{n}-h}^{t_{n}}e^{s-(t_{n}-h)}\nabla\ln q_{s}(x_{n}^{*}(t_{n}-s))\,\mathrm{d}s\,.

For the first term, by Lemma A.2, we have

\mathbbEhe(1α)hs^tnαh(x^n+12)he(1α)hlnqtnαh(xn(αh))2\lesssimh2εsc2+L4dh6(L1tn)+L4h4\mathbbEx^nxtn2.\operatorname*{\mathbb{E}}\|he^{(1-\alpha)h}\widehat{s}_{t_{n}-\alpha h}(\widehat{x}_{n+\frac{1}{2}})-he^{(1-\alpha)h}\nabla\ln q_{t_{n}-\alpha h}(x_{n}^{*}(\alpha h))\|^{2}\\ \lesssim h^{2}\varepsilon_{\mathrm{sc}}^{2}+L^{4}dh^{6}\left(L\lor\frac{1}{t_{n}}\right)+L^{4}h^{4}\operatorname*{\mathbb{E}}\|\widehat{x}_{n}-x_{t_{n}}\|^{2}\,.

The claimed bound follows. ∎

Lemma A.4 (Sequential Predictor Variance).

Suppose L1L\geq 1. In Algorithm 4, for all n{0,,N1}n\in\{0,\dots,N-1\}, let xn(t)x_{n}^{*}(t) be the solution of the true ODE starting at x^n\widehat{x}_{n} at time tnt_{n} and running until time tntt_{n}-t, and let xtqtx_{t}\sim q_{t} be the solution of the true ODE starting at x^0qt0\widehat{x}_{0}\sim q_{t_{0}}. If hn\lesssim1Lh_{n}\lesssim\frac{1}{L} and tnhntn/2t_{n}-h_{n}\geq t_{n}/2, we have

\mathbbEx^n+1xn(hn)2\lesssimhn2εsc2+L2dhn4(L1tn)+L2hn2\mathbbExtnx^n2\displaystyle\operatorname*{\mathbb{E}}\|\widehat{x}_{n+1}-x_{n}^{*}(h_{n})\|^{2}\lesssim h_{n}^{2}\varepsilon_{\mathrm{sc}}^{2}+L^{2}dh_{n}^{4}\left(L\lor\frac{1}{t_{n}}\right)+L^{2}h_{n}^{2}\operatorname*{\mathbb{E}}\|x_{t_{n}}-\widehat{x}_{n}\|^{2}

where \mathbbE\operatorname*{\mathbb{E}} refers to the expectation wrt the random α\alpha in the nthn^{th} step, along with the initial choice x^0qt0\widehat{x}_{0}\sim q_{t_{0}}.

Proof.

Fix nn and let h:=hnh:=h_{n}. We have

\mathbbEx^n+1xn(h)2\displaystyle\operatorname*{\mathbb{E}}\|\widehat{x}_{n+1}-x_{n}^{*}(h)\|^{2}
=\mathbbEhe(1α)hs^tnαh(x^n+12)tnhtnes(tnh)lnqs(xn(tns))ds2\displaystyle=\operatorname*{\mathbb{E}}\|h\cdot e^{(1-\alpha)h}\widehat{s}_{t_{n}-\alpha h}(\widehat{x}_{n+\frac{1}{2}})-\int_{t_{n}-h}^{t_{n}}e^{s-(t_{n}-h)}\nabla\ln q_{s}(x_{n}^{*}(t_{n}-s))\,\mathrm{d}s\|^{2}
\lesssim\mathbbEhe(1α)hs^tnαh(x^n+12)he(1α)hlnqtnαh(xn(αh))2\displaystyle\lesssim\operatorname*{\mathbb{E}}\|h\cdot e^{(1-\alpha)h}\widehat{s}_{t_{n}-\alpha h}(\widehat{x}_{n+\frac{1}{2}})-he^{(1-\alpha)h}\nabla\ln q_{t_{n}-\alpha h}(x_{n}^{*}(\alpha h))\|^{2}
+\mathbbEhe(1α)hlnqtnαh(xn(αh))tnhtne(1α)hlnqs(xn(tns))ds2\displaystyle\qquad\qquad+\operatorname*{\mathbb{E}}\|h\cdot e^{(1-\alpha)h}\nabla\ln q_{t_{n}-\alpha h}(x_{n}^{*}(\alpha h))-\int_{t_{n}-h}^{t_{n}}e^{(1-\alpha)h}\nabla\ln q_{s}(x_{n}^{*}(t_{n}-s))\,\mathrm{d}s\|^{2}
+\mathbbEtnhtne(1α)hlnqs(xn(tns))dstnhtnes(tnh)lnqs(xn(tns))ds2\displaystyle\qquad\qquad+\operatorname*{\mathbb{E}}\|\int_{t_{n}-h}^{t_{n}}e^{(1-\alpha)h}\nabla\ln q_{s}(x_{n}^{*}(t_{n}-s))\,\mathrm{d}s-\int_{t_{n}-h}^{t_{n}}e^{s-(t_{n}-h)}\nabla\ln q_{s}(x_{n}^{*}(t_{n}-s))\,\mathrm{d}s\|^{2}

The first term was bounded in Lemma A.2:

\operatorname*{\mathbb{E}}\|h\cdot e^{(1-\alpha)h}\widehat{s}_{t_{n}-\alpha h}(\widehat{x}_{n+\frac{1}{2}})-he^{(1-\alpha)h}\nabla\ln q_{t_{n}-\alpha h}(x_{n}^{*}(\alpha h))\|^{2}\\ \lesssim h^{2}\varepsilon_{\mathrm{sc}}^{2}+L^{4}dh^{6}\left(L\lor\frac{1}{t_{n}}\right)+L^{4}h^{4}\operatorname*{\mathbb{E}}\|\widehat{x}_{n}-x_{t_{n}}\|^{2}\,.

For the second term,

\mathbbEhe(1α)hlnqtnαh(xn(αh))tnhtne(1α)hlnqs(xn(tns))ds2\displaystyle\operatorname*{\mathbb{E}}\|h\cdot e^{(1-\alpha)h}\nabla\ln q_{t_{n}-\alpha h}(x_{n}^{*}(\alpha h))-\int_{t_{n}-h}^{t_{n}}e^{(1-\alpha)h}\nabla\ln q_{s}(x_{n}^{*}(t_{n}-s))\,\mathrm{d}s\|^{2}
=\mathbbEtnhtne(1α)h(lnqtnαh(xn(αh))lnqs(xn(tns)))ds2\displaystyle=\operatorname*{\mathbb{E}}\|\int_{t_{n}-h}^{t_{n}}e^{(1-\alpha)h}\cdot\left(\nabla\ln q_{t_{n}-\alpha h}(x_{n}^{*}(\alpha h))-\nabla\ln q_{s}(x_{n}^{*}(t_{n}-s))\right)\,\mathrm{d}s\|^{2}
\displaystyle\lesssim h\int_{t_{n}-h}^{t_{n}}\operatorname*{\mathbb{E}}\|\nabla\ln q_{t_{n}-\alpha h}(x_{n}^{*}(\alpha h))-\nabla\ln q_{s}(x_{n}^{*}(t_{n}-s))\|^{2}\,\mathrm{d}s\,.

Now,

\mathbbElnqtnαh(xn(αh))lnqs(xn(tns))2\displaystyle\operatorname*{\mathbb{E}}\|\nabla\ln q_{t_{n}-\alpha h}(x_{n}^{*}(\alpha h))-\nabla\ln q_{s}(x_{n}^{*}(t_{n}-s))\|^{2} \lesssim\mathbbElnqtnαh(xtnαh)lnqs(xs)2\displaystyle\lesssim\operatorname*{\mathbb{E}}\|\nabla\ln q_{t_{n}-\alpha h}(x_{t_{n}-\alpha h})-\nabla\ln q_{s}(x_{s})\|^{2}
+\mathbbElnqtnαh(xtnαh)lnqtnαh(xn(αh))2\displaystyle\qquad+\operatorname*{\mathbb{E}}\|\nabla\ln q_{t_{n}-\alpha h}(x_{t_{n}-\alpha h})-\nabla\ln q_{t_{n}-\alpha h}(x_{n}^{*}(\alpha h))\|^{2}
+\mathbbElnqs(xs)lnqs(xn(tns))2\displaystyle\qquad+\operatorname*{\mathbb{E}}\|\nabla\ln q_{s}(x_{s})-\nabla\ln q_{s}(x_{n}^{*}(t_{n}-s))\|^{2}

The first of these terms is bounded in Corollary D.1:

\mathbbElnqtnαh(xtnαh)lnqs(xs)2\lesssimL2dh2(L1tn)\operatorname*{\mathbb{E}}\|\nabla\ln q_{t_{n}-\alpha h}(x_{t_{n}-\alpha h})-\nabla\ln q_{s}(x_{s})\|^{2}\lesssim L^{2}dh^{2}\left(L\lor\frac{1}{t_{n}}\right)

For the remaining two terms, note that by the Lipschitzness of lnqt\nabla\ln q_{t} and Lemma A.1,

\mathbbElnqtnαh(xtnαh)lnqtnαh(xn(αh))2\displaystyle\operatorname*{\mathbb{E}}\|\nabla\ln q_{t_{n}-\alpha h}(x_{t_{n}-\alpha h})-\nabla\ln q_{t_{n}-\alpha h}(x_{n}^{*}(\alpha h))\|^{2} L2\mathbbExtnαhxn(αh)2\displaystyle\leq L^{2}\operatorname*{\mathbb{E}}\|x_{t_{n}-\alpha h}-x_{n}^{*}(\alpha h)\|^{2}
\lesssimL2exp(O(Lh))\mathbbExtnx^n2\displaystyle\lesssim L^{2}\exp(O(Lh))\operatorname*{\mathbb{E}}\|x_{t_{n}}-\widehat{x}_{n}\|^{2}
\lesssimL2\mathbbExtnx^n2\displaystyle\lesssim L^{2}\operatorname*{\mathbb{E}}\|x_{t_{n}}-\widehat{x}_{n}\|^{2}

and similarly, for tnhstnt_{n}-h\leq s\leq t_{n},

\mathbbElnqs(xs)lnqs(xn(tns))2\lesssimL2\mathbbExtnx^n2.\operatorname*{\mathbb{E}}\|\nabla\ln q_{s}(x_{s})-\nabla\ln q_{s}(x_{n}^{*}(t_{n}-s))\|^{2}\lesssim L^{2}\operatorname*{\mathbb{E}}\|x_{t_{n}}-\widehat{x}_{n}\|^{2}\,. (18)

Thus, we have shown that the second term in our bound on \mathbbEx^n+1xn(h)2\operatorname*{\mathbb{E}}\|\widehat{x}_{n+1}-x^{*}_{n}(h)\|^{2} is bounded as follows:

\operatorname*{\mathbb{E}}\|h\cdot e^{(1-\alpha)h}\nabla\ln q_{t_{n}-\alpha h}(x_{n}^{*}(\alpha h))-\int_{t_{n}-h}^{t_{n}}e^{(1-\alpha)h}\nabla\ln q_{s}(x_{n}^{*}(t_{n}-s))\,\mathrm{d}s\|^{2}\\ \lesssim L^{2}dh^{4}\left(L\lor\frac{1}{t_{n}}\right)+L^{2}h^{2}\operatorname*{\mathbb{E}}\|x_{t_{n}}-\widehat{x}_{n}\|^{2}\,.

For the third term,

\displaystyle\operatorname*{\mathbb{E}}\|\int_{t_{n}-h}^{t_{n}}e^{(1-\alpha)h}\nabla\ln q_{s}(x_{n}^{*}(t_{n}-s))\,\mathrm{d}s-\int_{t_{n}-h}^{t_{n}}e^{s-(t_{n}-h)}\nabla\ln q_{s}(x_{n}^{*}(t_{n}-s))\,\mathrm{d}s\|^{2}
\displaystyle=\operatorname*{\mathbb{E}}\|\int_{t_{n}-h}^{t_{n}}\left(e^{(1-\alpha)h}-e^{s-(t_{n}-h)}\right)\nabla\ln q_{s}(x_{n}^{*}(t_{n}-s))\,\mathrm{d}s\|^{2}
\lesssimhtnhtn\mathbbEα(e(1α)hes(tnh))2\mathbbEx^0qt0lnqs(xn(tns))2ds.\displaystyle\lesssim h\int_{t_{n}-h}^{t_{n}}\operatorname*{\mathbb{E}}_{\alpha}\left(e^{(1-\alpha)h}-e^{s-(t_{n}-h)}\right)^{2}\operatorname*{\mathbb{E}}_{\widehat{x}_{0}\sim q_{t_{0}}}\|\nabla\ln q_{s}(x_{n}^{*}(t_{n}-s))\|^{2}\,\mathrm{d}s\,.

Now, we have

\mathbbElnqs(xn(tns))2\displaystyle\operatorname*{\mathbb{E}}\|\nabla\ln q_{s}(x_{n}^{*}(t_{n}-s))\|^{2} \lesssim\mathbbElnqs(xs)2+\mathbbElnqs(xn(tns))lnqs(xs)2\displaystyle\lesssim\operatorname*{\mathbb{E}}\|\nabla\ln q_{s}(x_{s})\|^{2}+\operatorname*{\mathbb{E}}\|\nabla\ln q_{s}(x_{n}^{*}(t_{n}-s))-\nabla\ln q_{s}(x_{s})\|^{2}
\lesssimds+L2\mathbbExtnx^n2.\displaystyle\lesssim\frac{d}{s}+L^{2}\operatorname*{\mathbb{E}}\|x_{t_{n}}-\widehat{x}_{n}\|^{2}\,.

where the last step follows by Lemma D.4 and (18). So,

\mathbbEtnhtne(1α)hlnqs(xn(tns))dstnhtnes(tnh)lnqs(xn(tns))ds2\displaystyle\operatorname*{\mathbb{E}}\|\int_{t_{n}-h}^{t_{n}}e^{(1-\alpha)h}\nabla\ln q_{s}(x_{n}^{*}(t_{n}-s))\,\mathrm{d}s-\int_{t_{n}-h}^{t_{n}}e^{s-(t_{n}-h)}\nabla\ln q_{s}(x_{n}^{*}(t_{n}-s))\,\mathrm{d}s\|^{2}
\lesssimhtnhtn\mathbbEα(e(1α)hes(tnh))2(ds+L2\mathbbExtnx^n2)ds\displaystyle\lesssim h\int_{t_{n}-h}^{t_{n}}\operatorname*{\mathbb{E}}_{\alpha}\left(e^{(1-\alpha)h}-e^{s-(t_{n}-h)}\right)^{2}\cdot\left(\frac{d}{s}+L^{2}\operatorname*{\mathbb{E}}\|x_{t_{n}}-\widehat{x}_{n}\|^{2}\right)\,\mathrm{d}s
\lesssimh4(dtn+L2\mathbbExtnx^n2)\displaystyle\lesssim h^{4}\cdot\left(\frac{d}{t_{n}}+L^{2}\operatorname*{\mathbb{E}}\|x_{t_{n}}-\widehat{x}_{n}\|^{2}\right)

Thus, noting that h1Lh\leq\frac{1}{L}, we obtain the claimed bound on \mathbbEx^n+1xn(hn)2\operatorname*{\mathbb{E}}\|\widehat{x}_{n+1}-x^{*}_{n}(h_{n})\|^{2}. ∎

Finally, we put together the bias and variance bounds above to obtain a bound on the Wasserstein error at the end of the Predictor Step.

Lemma A.5 (Sequential Predictor Wasserstein Guarantee).

Suppose that L1L\geq 1, and that for our sequence of step sizes h0,,hN1h_{0},\dots,h_{N-1}, ihi1/L\sum_{i}h_{i}\leq 1/L. Let hmax=maxihih_{\text{max}}=\max_{i}h_{i}. Then, at the end of Algorithm 4,

  1. 1.

    If tN\gtrsim1/Lt_{N}\gtrsim 1/L,

    W22(x^N,xtN)\lesssimεsc2L2+L3dhmax4+L2dhmax3\displaystyle W_{2}^{2}(\widehat{x}_{N},x_{t_{N}})\lesssim\frac{\varepsilon_{\mathrm{sc}}^{2}}{L^{2}}+L^{3}dh_{\max}^{4}+L^{2}dh_{\max}^{3}
  2. 2.

    If tN\lesssim1/Lt_{N}\lesssim 1/L and hn\lesssimtn2h_{n}\lesssim\frac{t_{n}}{2} for each nn,

    W22(x^N,xtN)\lesssimεsc2L2+(L3dhmax4+L2dhmax3)N\displaystyle W_{2}^{2}(\widehat{x}_{N},x_{t_{N}})\lesssim\frac{\varepsilon_{\mathrm{sc}}^{2}}{L^{2}}+\left(L^{3}dh_{\max}^{4}+L^{2}dh_{\max}^{3}\right)\cdot N

Here, xtNqtNx_{t_{N}}\sim q_{t_{N}} is the solution of the true ODE beginning at xt0=x^0qt0x_{t_{0}}=\widehat{x}_{0}\sim q_{t_{0}}.

Proof.

For all n\in[1,\dots,N], let y_{n} be the solution of the exact one-step ODE starting from \widehat{x}_{n-1}. Let \operatorname*{\mathbb{E}}_{\alpha} denote the expectation over the random choice of \alpha in the final iteration, which produces \widehat{x}_{N} from \widehat{x}_{N-1}; conditioned on \widehat{x}_{N-1}, only \widehat{x}_{N} depends on this \alpha. We have

\mathbbEα[xtNx^N2]\displaystyle\operatorname*{\mathbb{E}}_{\alpha}\left[\|x_{t_{N}}-\widehat{x}_{N}\|^{2}\right] =\mathbbEα[(xtNyN)(x^NyN)2]\displaystyle=\operatorname*{\mathbb{E}}_{\alpha}\left[\|(x_{t_{N}}-y_{N})-(\widehat{x}_{N}-y_{N})\|^{2}\right]
=xtNyN22xtNyN,\mathbbEαx^NyN+\mathbbEαx^NyN2\displaystyle=\|x_{t_{N}}-y_{N}\|^{2}-2\langle x_{t_{N}}-y_{N},\operatorname*{\mathbb{E}}_{\alpha}\widehat{x}_{N}-y_{N}\rangle+\operatorname*{\mathbb{E}}_{\alpha}\|\widehat{x}_{N}-y_{N}\|^{2}
(1+LhN12)xtNyN2+2LhN1\mathbbEαx^NyN2+\mathbbEαx^NyN2\displaystyle\leq\left(1+\frac{Lh_{N-1}}{2}\right)\|x_{t_{N}}-y_{N}\|^{2}+\frac{2}{Lh_{N-1}}\|\operatorname*{\mathbb{E}}_{\alpha}\widehat{x}_{N}-y_{N}\|^{2}+\operatorname*{\mathbb{E}}_{\alpha}\|\widehat{x}_{N}-y_{N}\|^{2}
exp(O(LhN1))xtN1x^N12+2LhN1\mathbbEαx^NyN2+\mathbbEαx^NyN2,\displaystyle\leq\exp\left(O(Lh_{N-1})\right)\|x_{t_{N-1}}-\widehat{x}_{N-1}\|^{2}+\frac{2}{Lh_{N-1}}\|\operatorname*{\mathbb{E}}_{\alpha}\widehat{x}_{N}-y_{N}\|^{2}+\operatorname*{\mathbb{E}}_{\alpha}\|\widehat{x}_{N}-y_{N}\|^{2}\,,

where the third line is by Young’s inequality, and the fourth line is by Lemma A.1. Taking the expectation wrt x^0qt0\widehat{x}_{0}\sim q_{t_{0}}, by Lemmas A.3 and A.4,

\mathbbExtNx^N2\displaystyle\operatorname*{\mathbb{E}}\|x_{t_{N}}-\widehat{x}_{N}\|^{2}
exp(O(LhN1))\mathbbExtN1x^N12+2LhN1\mathbbE\mathbbEαx^NyN2+\mathbbEx^NyN2\displaystyle\leq\exp\left(O(Lh_{N-1})\right)\operatorname*{\mathbb{E}}\|x_{t_{N-1}}-\widehat{x}_{N-1}\|^{2}+\frac{2}{Lh_{N-1}}\operatorname*{\mathbb{E}}\|\operatorname*{\mathbb{E}}_{\alpha}\widehat{x}_{N}-y_{N}\|^{2}+\operatorname*{\mathbb{E}}\|\widehat{x}_{N}-y_{N}\|^{2}
exp(O(LhN1))\mathbbExtN1x^N12\displaystyle\leq\exp\left(O(Lh_{N-1})\right)\operatorname*{\mathbb{E}}\|x_{t_{N-1}}-\widehat{x}_{N-1}\|^{2}
+O(1LhN1(hN12εsc2+L4dhN16(L1tN1)+L4hN14\mathbbExtN1x^N12))\displaystyle\qquad+O\left(\frac{1}{Lh_{N-1}}\left(h_{N-1}^{2}\varepsilon_{\mathrm{sc}}^{2}+L^{4}dh_{N-1}^{6}\left(L\lor\frac{1}{t_{N-1}}\right)+L^{4}h_{N-1}^{4}\operatorname*{\mathbb{E}}\|x_{t_{N-1}}-\widehat{x}_{N-1}\|^{2}\right)\right)
+O(hN12εsc2+L2dhN14(L1tN1)+L2hN12\mathbbExtN1x^N12)\displaystyle\qquad+O\left(h_{N-1}^{2}\varepsilon_{\mathrm{sc}}^{2}+L^{2}dh_{N-1}^{4}\left(L\lor\frac{1}{t_{N-1}}\right)+L^{2}h_{N-1}^{2}\operatorname*{\mathbb{E}}\|x_{t_{N-1}}-\widehat{x}_{N-1}\|^{2}\right)
exp(O(LhN1))\mathbbExtN1x^N12\displaystyle\leq\exp\left(O(Lh_{N-1})\right)\operatorname*{\mathbb{E}}\|x_{t_{N-1}}-\widehat{x}_{N-1}\|^{2}
+O(hN1εsc2L+hN12εsc2+(L3dhN15+L2dhN14)(L1tN1))\displaystyle\qquad+O\left(\frac{h_{N-1}\varepsilon_{\mathrm{sc}}^{2}}{L}+h_{N-1}^{2}\varepsilon_{\mathrm{sc}}^{2}+\left(L^{3}dh_{N-1}^{5}+L^{2}dh_{N-1}^{4}\right)\left(L\lor\frac{1}{t_{N-1}}\right)\right)

By induction, noting that xt0=x^0x_{t_{0}}=\widehat{x}_{0}, we have

\mathbbExtNx^N2\lesssimn=0N1(hnεsc2L+hn2εsc2+(L3dhn5+L2dhn4)(L1tn))exp(O(Li=n+1N1hi)).\operatorname*{\mathbb{E}}\|x_{t_{N}}-\widehat{x}_{N}\|^{2}\lesssim\sum_{n=0}^{N-1}\left(\frac{h_{n}\varepsilon_{\mathrm{sc}}^{2}}{L}+h_{n}^{2}\varepsilon_{\mathrm{sc}}^{2}+\left(L^{3}dh_{n}^{5}+L^{2}dh_{n}^{4}\right)\cdot\left(L\lor\frac{1}{t_{n}}\right)\right)\cdot\exp\left(O\left(L\sum_{i=n+1}^{N-1}h_{i}\right)\right)\,.

By assumption, ihi1L\sum_{i}h_{i}\leq\frac{1}{L}. In the first case, L1tn\lesssimLL\vee\frac{1}{t_{n}}\lesssim L for all nn, so

W22(x^N,xtN)\lesssimεsc2L2+L3dhmax4+L2dhmax3.W_{2}^{2}(\widehat{x}_{N},x_{t_{N}})\lesssim\frac{\varepsilon_{\mathrm{sc}}^{2}}{L^{2}}+L^{3}dh_{\max}^{4}+L^{2}dh_{\max}^{3}\,.

In the second case,

W22(x^N,xtN)\displaystyle W_{2}^{2}(\widehat{x}_{N},x_{t_{N}}) \lesssimεsc2L2+(L3dhmax4+L2dhmax3)n=0N1hntn\displaystyle\lesssim\frac{\varepsilon_{\mathrm{sc}}^{2}}{L^{2}}+\left(L^{3}dh_{\max}^{4}+L^{2}dh_{\max}^{3}\right)\cdot\sum_{n=0}^{N-1}\frac{h_{n}}{t_{n}}
\lesssimεsc2L2+(L3dhmax4+L2dhmax3)N.\displaystyle\lesssim\frac{\varepsilon_{\mathrm{sc}}^{2}}{L^{2}}+\left(L^{3}dh_{\max}^{4}+L^{2}dh_{\max}^{3}\right)\cdot N\,.\qed

A.2 Corrector step

For the sequential algorithm, we make use of the underdamped Langevin corrector step and analysis from [CCL+23a]; we reproduce them here for convenience.

The underdamped Langevin Monte Carlo process with step size hh is given by:

dx^t=v^tdtdv^t=(s^(x^t/hh)γv^t)dt+2γdBt\displaystyle\begin{split}\mathrm{d}\widehat{x}_{t}&=\widehat{v}_{t}\,\mathrm{d}t\\ \mathrm{d}\widehat{v}_{t}&=(\widehat{s}(\widehat{x}_{\lfloor t/h\rfloor h})-\gamma\widehat{v}_{t})\,\mathrm{d}t+\sqrt{2\gamma}\,\mathrm{d}B_{t}\end{split} (19)

where BtB_{t} is Brownian motion, and s^\widehat{s} satisfies

\mathbbExqs^(x)lnq(x)2εsc2.\operatorname*{\mathbb{E}}_{x\sim q}\left\|\widehat{s}(x)-\nabla\ln q(x)\right\|^{2}\leq\varepsilon_{\mathrm{sc}}^{2}\,. (20)

for some target measure qq. Here, we set the friction parameter γ=Θ(L)\gamma=\Theta(\sqrt{L}).

Then, our corrector step is as follows.

Algorithm 5 CorrectorStep (Sequential)

Input parameters:

  • Starting sample x^0\widehat{x}_{0}, Total time TcorrT_{\mathrm{corr}}, Step size hcorrh_{\mathrm{corr}}, Score estimate s^\widehat{s}

  1. 1.

    Run underdamped Langevin Monte Carlo in (19) for total time TcorrT_{\mathrm{corr}} using step size hcorrh_{\mathrm{corr}}, and let the result be x^N\widehat{x}_{N}.

  2. 2.

    Return x^N\widehat{x}_{N}.
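
For illustration, here is a small Python sketch of the corrector. It uses a plain Euler--Maruyama discretization of the dynamics in (19) rather than the exact integration analyzed in [CCL+23a], and it takes the constant in \gamma=\Theta(\sqrt{L}) to be 1; the callable score(x), the velocity initialization, and these simplifications are assumptions of the sketch.

```python
import numpy as np

def corrector_step_sequential(x0, score, T_corr, h_corr, L, rng=None):
    """Sketch of Algorithm 5 via a simple Euler-Maruyama discretization of (19).

    score(x) is an assumed stand-in for s_hat; gamma = sqrt(L) takes the constant
    in gamma = Theta(sqrt(L)) to be 1. The paper's analysis is for the exact
    integration of (19), so this is only an illustration of the structure.
    """
    rng = np.random.default_rng() if rng is None else rng
    gamma = np.sqrt(L)
    x = np.asarray(x0, dtype=float)
    v = np.zeros_like(x)  # assumed velocity initialization
    for _ in range(int(np.ceil(T_corr / h_corr))):
        g = score(x)  # score frozen at the current grid point, as in (19)
        x = x + h_corr * v
        v = v + h_corr * (g - gamma * v) \
            + np.sqrt(2.0 * gamma * h_corr) * rng.standard_normal(np.shape(x))
    return x
```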

Theorem A.6 (Theorem 5 of [CCL+23a], restated).

Suppose Eq. (20) holds. For any distribution pp over \mathbbRd\mathbb{R}^{d}, and total time Tcorr\lesssim1/LT_{\mathrm{corr}}\lesssim 1/\sqrt{L}, if we let pNp_{N} be the distribution of x^N\widehat{x}_{N} resulting from running Algorithm 5 initialized at x^0p\widehat{x}_{0}\sim p, then we have

TV(pN,q)\lesssimW2(p,q)L1/4Tcorr3/2+εscTcorr1/2L1/4+L3/4Tcorr1/2d1/2hcorr.\displaystyle\textup{{TV}}(p_{N},q)\lesssim\frac{W_{2}(p,q)}{L^{1/4}T_{\mathrm{corr}}^{3/2}}+\frac{\varepsilon_{\mathrm{sc}}T_{\mathrm{corr}}^{1/2}}{L^{1/4}}+L^{3/4}T_{\mathrm{corr}}^{1/2}d^{1/2}h_{\mathrm{corr}}\,.
Corollary A.7 (Underdamped Corrector).

For T_{\mathrm{corr}}=\Theta\left(\frac{1}{\sqrt{L}d^{1/18}}\right), we have

TV(pN,q)\lesssimW2(p,q)d1/12L+εscLd1/36+Ld17/36hcorr.\displaystyle\textup{{TV}}(p_{N},q)\lesssim W_{2}(p,q)\cdot d^{1/12}\cdot\sqrt{L}+\frac{\varepsilon_{\mathrm{sc}}}{\sqrt{L}d^{1/36}}+\sqrt{L}d^{17/36}h_{\mathrm{corr}}\,.

A.3 End-to-end analysis

Finally, we put together the analysis of the predictor and corrector steps to obtain our final \widetilde{O}(d^{5/12}) dimension dependence for the iteration complexity. We first show, in Lemma A.8, that carefully choosing the amount of time to run the corrector keeps the TV error small after each round of predictor and corrector steps. We then iterate this bound to obtain our final guarantee, Theorem A.10.

Algorithm 6 SequentialAlgorithm

Input parameters:

  • Start time TT, End time δ\delta, Corrector steps time Tcorr\lesssim1/LT_{\mathrm{corr}}\lesssim 1/\sqrt{L}, Number of predictor-corrector steps N0N_{0}, Predictor step size hpredh_{\mathrm{pred}}, Corrector step size hcorrh_{\mathrm{corr}}, Score estimates s^t\widehat{s}_{t}

  1. 1.

    Draw x^0𝒩(0,Id)\widehat{x}_{0}\sim\mathcal{N}(0,I_{d}).

  2. 2.

    For n=0,,N01n=0,\dots,N_{0}-1:

    1. (a)

      Starting from x^n\widehat{x}_{n}, run Algorithm 4 with starting time Tn/LT-n/L using step sizes hpredh_{\mathrm{pred}} for all NN steps, with N=1LhpredN=\frac{1}{Lh_{\mathrm{pred}}}, so that the total time is 1/L1/L. Let the result be x^n+1\widehat{x}_{n+1}^{\prime}.

    2. (b)

      Starting from x^n+1\widehat{x}_{n+1}^{\prime}, run Algorithm 5 for total time TcorrT_{\mathrm{corr}} with step size hcorrh_{\mathrm{corr}} and score estimate s^T(n+1)/L\widehat{s}_{T-(n+1)/L} to obtain x^n+1\widehat{x}_{n+1}.

  3. 3.

    Starting from x^N0\widehat{x}_{N_{0}}, run Algorithm 4 with starting time TN0/LT-N_{0}/L using step sizes hpred/2,hpred/4,hpred/8,,δh_{\mathrm{pred}}/2,h_{\mathrm{pred}}/4,h_{\mathrm{pred}}/8,\dots,\delta to obtain x^N0+1\widehat{x}_{N_{0}+1}^{\prime}.

  4. 4.

    Starting from x^N0+1\widehat{x}^{\prime}_{N_{0}+1}, run Algorithm 5 for total time TcorrT_{\mathrm{corr}} with step size hcorrh_{\mathrm{corr}} and score estimate s^δ\widehat{s}_{\delta} to obtain x^N0+1\widehat{x}_{N_{0}+1}.

  5. 5.

    Return x^N0+1\widehat{x}_{N_{0}+1}.
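
The end-to-end structure can be sketched by wiring together the two Python sketches above. The callable score(t, x) again stands in for the score estimates \widehat{s}_{t}; the final geometric schedule below simply halves the remaining time until it reaches \delta, which is a simplified reading of step 3, and all constants are illustrative assumptions.

```python
import numpy as np

def sequential_sampler(dim, T, delta, T_corr, N0, h_pred, h_corr, L, score, rng=None):
    """Sketch of Algorithm 6, built from the predictor/corrector sketches above."""
    rng = np.random.default_rng() if rng is None else rng
    x = rng.standard_normal(dim)                # step 1: x_hat_0 ~ N(0, I_d)
    t = float(T)
    N = max(1, round(1.0 / (L * h_pred)))       # predictor steps per round (total time ~1/L)
    for _ in range(N0):                         # step 2: predictor-corrector rounds
        x, t = predictor_step_sequential(x, t, [h_pred] * N, score, rng)
        x = corrector_step_sequential(x, lambda y: score(t, y), T_corr, h_corr, L, rng)
    # Steps 3-4: geometrically decreasing predictor steps down to time delta, then a final corrector.
    steps, remaining = [], max(t - delta, 0.0)
    while remaining > delta:
        steps.append(remaining / 2.0)           # halve the remaining time at each step
        remaining /= 2.0
    steps.append(remaining)                     # last step lands exactly at time delta
    x, t = predictor_step_sequential(x, t, steps, score, rng)
    x = corrector_step_sequential(x, lambda y: score(t, y), T_corr, h_corr, L, rng)
    return x
```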

Lemma A.8 (TV error after one round of predictor and corrector).

Let xtqtx_{t}\sim q_{t} be a sample from the true distribution at time tt. Let tn=Tn/Lt_{n}=T-n/L for n[0,,N0]n\in[0,\dots,N_{0}]. If we set Tcorr=Θ(1Ld1/18)T_{\mathrm{corr}}=\Theta\left(\frac{1}{\sqrt{L}d^{1/18}}\right), we have,

  1. 1.

    For n[0,,N01]n\in[0,\dots,N_{0}-1], if tn\gtrsim1/Lt_{n}\gtrsim 1/L,

    TV(x^n+1,xtn+1)TV(x^n,xtn)+O(L2d7/12hpred2+L3/2d7/12hpred3/2+Ld17/36hcorr+εscd1/12L)\textup{{TV}}(\widehat{x}_{n+1},x_{t_{n+1}})\leq\textup{{TV}}(\widehat{x}_{n},x_{t_{n}})+O\left(L^{2}d^{7/12}h_{\mathrm{pred}}^{2}+L^{3/2}d^{7/12}h_{\mathrm{pred}}^{3/2}+\sqrt{L}d^{17/36}h_{\mathrm{corr}}+\frac{\varepsilon_{\mathrm{sc}}d^{1/12}}{\sqrt{L}}\right)
  2. 2.

    If tN0\lesssim1/Lt_{N_{0}}\lesssim 1/L,

    TV(x^N0+1,xδ)\displaystyle\textup{{TV}}(\widehat{x}_{N_{0}+1},x_{\delta}) TV(x^N0,xtN0)\displaystyle\leq\textup{{TV}}(\widehat{x}_{N_{0}},x_{t_{N_{0}}})
    +O((L2d7/12hpred2+L3/2d7/12hpred3/2)loghpredδ+Ld17/36hcorr+εscd1/12L)\displaystyle\quad+O\left(\left(L^{2}d^{7/12}h_{\mathrm{pred}}^{2}+L^{3/2}d^{7/12}h_{\mathrm{pred}}^{3/2}\right)\cdot\sqrt{\log\frac{h_{\mathrm{pred}}}{\delta}}+\sqrt{L}d^{17/36}h_{\mathrm{corr}}+\frac{\varepsilon_{\mathrm{sc}}d^{1/12}}{\sqrt{L}}\right)
Proof.

For n\in[0,\ldots,N_{0}], let \widehat{y}_{n+1} be the result of a single predictor-corrector sequence as described in step 2 of Algorithm 6, but starting from x_{t_{n}}\sim q_{t_{n}} instead of \widehat{x}_{n}. Additionally, let \widehat{y}_{N_{0}+1} be the result of running steps 3 and 4 starting from x_{t_{N_{0}}}\sim q_{t_{N_{0}}} instead of \widehat{x}_{N_{0}}. Similarly, let \widehat{y}_{n+1}^{\prime} be the result of only applying the predictor step starting from x_{t_{n}}\sim q_{t_{n}}, analogous to \widehat{x}_{n+1}^{\prime} defined in step 2a.

We have, by the triangle inequality and the data-processing inequality, for n[0,,N01]n\in[0,\dots,N_{0}-1],

TV(x^n+1,xtn+1)\displaystyle\textup{{TV}}(\widehat{x}_{n+1},x_{t_{n+1}}) TV(x^n+1,y^n+1)+TV(y^n+1,xtn+1)\displaystyle\leq\textup{{TV}}(\widehat{x}_{n+1},\widehat{y}_{n+1})+\textup{{TV}}(\widehat{y}_{n+1},x_{t_{n+1}})
TV(x^n,xtn)+TV(y^n+1,xtn+1)\displaystyle\leq\textup{{TV}}(\widehat{x}_{n},x_{t_{n}})+\textup{{TV}}(\widehat{y}_{n+1},x_{t_{n+1}})

By Corollary A.7,

TV(y^n+1,xtn+1)\lesssimW2(y^n+1,xtn+1)d1/12L+εscLd1/36+Ld17/36hcorr\textup{{TV}}(\widehat{y}_{n+1},x_{t_{n+1}})\lesssim W_{2}(\widehat{y}_{n+1}^{\prime},x_{t_{n+1}})\cdot d^{1/12}\cdot\sqrt{L}+\frac{\varepsilon_{\mathrm{sc}}}{\sqrt{L}d^{1/36}}+\sqrt{L}d^{17/36}h_{\mathrm{corr}}

Now, for tn\gtrsim1/Lt_{n}\gtrsim 1/L, by Lemma A.5,

W2(y^n+1,xtn+1)\lesssimεscL+L3/2dhpred2+Ldhpred3/2.W_{2}(\widehat{y}_{n+1}^{\prime},x_{t_{n+1}})\lesssim\frac{\varepsilon_{\mathrm{sc}}}{L}+L^{3/2}\sqrt{d}h_{\mathrm{pred}}^{2}+L\sqrt{d}h_{\mathrm{pred}}^{3/2}\,.

Combining the above gives the first claim.

For the second claim, similar to above, we have

TV(x^N0+1,xδ)\displaystyle\textup{{TV}}(\widehat{x}_{N_{0}+1},x_{\delta}) TV(x^N0+1,y^N0+1)+TV(y^N0+1,xδ)\displaystyle\leq\textup{{TV}}(\widehat{x}_{N_{0}+1},\widehat{y}_{N_{0}+1})+\textup{{TV}}(\widehat{y}_{N_{0}+1},x_{\delta})
TV(x^N0,xtN0)+TV(y^N0+1,xδ)\displaystyle\leq\textup{{TV}}(\widehat{x}_{N_{0}},x_{t_{N_{0}}})+\textup{{TV}}(\widehat{y}_{N_{0}+1},x_{\delta})

By Corollary A.7,

TV(y^N0+1,xδ)W2(y^N0+1,xδ)d1/12L+εscLd1/36+Ld17/36hcorr\textup{{TV}}(\widehat{y}_{N_{0}+1},x_{\delta})\leq W_{2}(\widehat{y}^{\prime}_{N_{0}+1},x_{\delta})\cdot d^{1/12}\cdot\sqrt{L}+\frac{\varepsilon_{\mathrm{sc}}}{\sqrt{L}d^{1/36}}+\sqrt{L}d^{17/36}h_{\mathrm{corr}}

For tN0\lesssim1/Lt_{N_{0}}\lesssim 1/L, by Lemma A.5, noting that the number of predictor steps in this case is O(loghpredδ)O\left(\log\frac{h_{\mathrm{pred}}}{\delta}\right),

W2(y^N0+1,xδ)\lesssimεscL+(L3/2dhpred2+Ldhpred3/2)loghpredδW_{2}(\widehat{y}^{\prime}_{N_{0}+1},x_{\delta})\lesssim\frac{\varepsilon_{\mathrm{sc}}}{L}+\left(L^{3/2}\sqrt{d}h_{\mathrm{pred}}^{2}+L\sqrt{d}h_{\mathrm{pred}}^{3/2}\right)\cdot\sqrt{\log\frac{h_{\mathrm{pred}}}{\delta}}

The second claim follows by combining the above. ∎

We recall the following lemma on the convergence of the OU process from [CCL+23a].

Lemma A.9 (Lemma 13 of [CCL+23a]).

Let qtq_{t} denote the marginal law of the OU process started at q0=qq_{0}=q. Then, for all T\gtrsim1T\gtrsim 1,

TV(qT,𝒩(0,Id))\lesssim(d+\mathfrakm2)exp(T)\displaystyle\textup{{TV}}(q_{T},\mathcal{N}(0,I_{d}))\lesssim(\sqrt{d}+\mathfrak{m}_{2})\exp(-T)

Finally, we prove our main theorem on the convergence of our sequential algorithm.

Theorem A.10 (Convergence bound for sequential algorithm).

Suppose Assumptions 2.1-2.4 hold, and let \widehat{x} denote the output of Algorithm 6. Set T=\Theta\left(\log\left(\frac{d\lor\mathfrak{m}_{2}^{2}}{\varepsilon^{2}}\right)\right), T_{\mathrm{corr}}=\Theta\left(\frac{1}{\sqrt{L}d^{1/18}}\right), \delta=\Theta\left(\frac{\varepsilon^{2}}{L^{2}(d\lor\mathfrak{m}_{2}^{2})}\right), h_{\mathrm{pred}}=\widetilde{\Theta}\left(\min\left(\frac{\varepsilon^{1/2}}{d^{1/3}L^{3/2}},\frac{\varepsilon^{2/3}}{d^{5/12}L^{5/3}}\right)\cdot\frac{1}{\log(\mathfrak{m}_{2})}\right), and h_{\mathrm{corr}}=\widetilde{\Theta}\left(\frac{\varepsilon}{d^{17/36}L^{3/2}\log(\mathfrak{m}_{2})}\right). If the score estimation error satisfies \varepsilon_{\mathrm{sc}}\leq\widetilde{O}\left(\frac{\varepsilon}{\sqrt{L}d^{1/12}\log(\mathfrak{m}_{2})}\right), then we have that

TV(x^,x0)\lesssimε\displaystyle\textup{{TV}}(\widehat{x},x_{0})\lesssim\varepsilon

with iteration complexity \widetilde{\Theta}\left(\frac{L^{5/3}d^{5/12}}{\varepsilon}\cdot\log^{2}(\mathfrak{m}_{2})\right).

Proof.

We will let tn=Tn/Lt_{n}=T-n/L. First, note that by Lemma A.9

TV(x^0,xt0)\lesssim(d+\mathfrakm2)exp(T)\displaystyle\textup{{TV}}(\widehat{x}_{0},x_{t_{0}})\lesssim(\sqrt{d}+\mathfrak{m}_{2})\exp(-T)

We divide our analysis into two steps. For the first N0=O(LT)N_{0}=O(LT) steps, we iterate the first part of Lemma A.8 to obtain

TV(x^N0,xtN0)\displaystyle\textup{{TV}}(\widehat{x}_{N_{0}},x_{t_{N_{0}}}) TV(x^0,xt0)+O(L2d7/12hpred2+L3/2d7/12hpred3/2+Ld17/36hcorr+εscd1/12L)N0\displaystyle\leq\textup{{TV}}(\widehat{x}_{0},x_{t_{0}})+O\left(L^{2}d^{7/12}h_{\mathrm{pred}}^{2}+L^{3/2}d^{7/12}h_{\mathrm{pred}}^{3/2}+\sqrt{L}d^{17/36}h_{\mathrm{corr}}+\frac{\varepsilon_{\mathrm{sc}}d^{1/12}}{\sqrt{L}}\right)\cdot N_{0}
\lesssim(d+\mathfrakm2)exp(T)+L3d7/12hpred2T+L5/2d7/12hpred3/2T+L3/2d17/36hcorrT+εscd1/12TL\displaystyle\lesssim\left(\sqrt{d}+\mathfrak{m}_{2}\right)\exp(-T)+L^{3}d^{7/12}h_{\mathrm{pred}}^{2}T+L^{5/2}d^{7/12}h_{\mathrm{pred}}^{3/2}T+L^{3/2}d^{17/36}h_{\mathrm{corr}}T+\varepsilon_{\mathrm{sc}}d^{1/12}T\sqrt{L}

Applying the second part of Lemma A.8 for the second stage of the algorithm, we have

TV(x^N0+1,xδ)\displaystyle\textup{{TV}}(\widehat{x}_{N_{0}+1},x_{\delta})
\lesssim(d+\mathfrakm2)exp(T)+(L3d7/12hpred2+L5/2d7/12hpred3/2)(T+loghpredδ)\displaystyle\lesssim\left(\sqrt{d}+\mathfrak{m}_{2}\right)\exp(-T)+\left(L^{3}d^{7/12}h_{\mathrm{pred}}^{2}+L^{5/2}d^{7/12}h_{\mathrm{pred}}^{3/2}\right)\left(T+\sqrt{\log\frac{h_{\mathrm{pred}}}{\delta}}\right)
+L3/2d17/36hcorrT+εscd1/12TL\displaystyle+L^{3/2}d^{17/36}h_{\mathrm{corr}}T+\varepsilon_{\mathrm{sc}}d^{1/12}T\sqrt{L}

Setting T=\Theta\left(\log\left(\frac{d\lor\mathfrak{m}_{2}^{2}}{\varepsilon^{2}}\right)\right), h_{\mathrm{pred}}=\widetilde{\Theta}\left(\min\left(\frac{\varepsilon^{1/2}}{d^{1/3}L^{3/2}},\frac{\varepsilon^{2/3}}{d^{5/12}L^{5/3}}\right)\frac{1}{\log(\mathfrak{m}_{2})}\right), and h_{\mathrm{corr}}=\widetilde{\Theta}\left(\frac{\varepsilon}{d^{17/36}L^{3/2}}\cdot\frac{1}{\log(\mathfrak{m}_{2})}\right), and requiring the score estimation error to satisfy \varepsilon_{\mathrm{sc}}\leq\widetilde{O}\left(\frac{\varepsilon}{\sqrt{L}d^{1/12}\log(\mathfrak{m}_{2})}\right), we obtain \textup{{TV}}(\widehat{x}_{N_{0}+1},x_{\delta})\leq\varepsilon with iteration complexity \widetilde{\Theta}\left(\frac{L^{5/3}d^{5/12}\log^{2}\mathfrak{m}_{2}}{\varepsilon}\right). ∎
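
As a rough numerical companion to Theorem A.10, the helper below evaluates the parameter choices, taking every constant hidden by \Theta and \widetilde{\Theta} to be 1 and keeping only the explicit \log(\mathfrak{m}_{2}) factors; it is purely a back-of-the-envelope illustration and not part of the algorithm.

```python
import numpy as np

def theorem_a10_parameters(d, L, eps, m2):
    """Illustrative evaluation of the settings in Theorem A.10 (all hidden
    constants taken to be 1, polylog factors other than log(m2) dropped).
    Assumes m2 > 1 so that log(m2) > 0."""
    T = np.log(max(d, m2 ** 2) / eps ** 2)
    delta = eps ** 2 / (L ** 2 * max(d, m2 ** 2))
    h_pred = min(eps ** 0.5 / (d ** (1 / 3) * L ** 1.5),
                 eps ** (2 / 3) / (d ** (5 / 12) * L ** (5 / 3))) / np.log(m2)
    h_corr = eps / (d ** (17 / 36) * L ** 1.5 * np.log(m2))
    iters = L ** (5 / 3) * d ** (5 / 12) / eps * np.log(m2) ** 2
    return {"T": T, "delta": delta, "h_pred": h_pred,
            "h_corr": h_corr, "iteration_complexity": iters}

# e.g. theorem_a10_parameters(d=10_000, L=1.0, eps=0.1, m2=100.0)
```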

Appendix B Parallel algorithm

B.1 Predictor step

In this section, we will apply a parallel version of the randomized midpoint method for the predictor step, so that only \widetilde{\Theta}(\log^{2}(\frac{Ld}{\varepsilon})) parallel rounds are required to attain our desired error bound for one predictor step.

In each iteration n, we first sample R_{n} randomized midpoints that are, in expectation, evenly spaced, with time intervals of \delta_{n}=\frac{h_{n}}{R_{n}} between consecutive midpoints. Next, in step (c), we form an initial estimate of the value of x at each midpoint using our estimate \widehat{x}_{n} of the position at time t_{n} produced by iteration n-1; this step is analogous to step (c) in Algorithm 4 for the sequential predictor step. Then, in step (d), we refine these initial estimates using a discrete version of Picard iteration: in round k, we compute a new estimate of x_{t_{n}-\alpha_{i}h_{n}} based on the round-(k-1) estimates of x_{t_{n}-\alpha_{j}h_{n}} for j\leq i. Note that a trajectory x(t) that starts from time t_{0} and follows the true ODE is a fixed point of the operator \tau mapping continuous functions to continuous functions, where

\tau(x)(t)=e^{t-t_{0}}x(t_{0})+\int_{t_{0}}^{t}e^{s-t_{0}}\cdot\nabla\ln q_{s}(x(s))\,\mathrm{d}s.

By smoothness of the true ODE, the continuous Picard iteration converges exponentially fast to the true trajectory, and we will show that the discretization error of the Picard iteration is controlled. After the refinements have sufficiently reduced the estimation error at our randomized midpoints, we make a final calculation, estimating the value of x_{t_{n}-h_{n}} based on the estimated values at all the randomized midpoints.
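
As a heuristic sketch of the exponential convergence (with constants suppressed; this calculation is not used in the formal proofs below): for two candidate trajectories x,y on a window of length at most h with the same starting point x(t_{0})=y(t_{0}), the L-Lipschitzness of the scores gives

\|\tau(x)(t)-\tau(y)(t)\|=\left\|\int_{t_{0}}^{t}e^{s-t_{0}}\left(\nabla\ln q_{s}(x(s))-\nabla\ln q_{s}(y(s))\right)\mathrm{d}s\right\|\leq e^{h}Lh\cdot\sup_{s}\|x(s)-y(s)\|\,,

so for h\lesssim 1/L each application of \tau (one Picard round) contracts the uniform distance to the true trajectory by a factor of O(Lh); this is the continuous-time analogue of the (8h_{n}^{2}L^{2})^{k} factor appearing in Lemma B.2 below.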

Algorithm 7 PredictorStep (Parallel)

Input parameters:

  • Starting sample x^0\widehat{x}_{0}, Starting time t0t_{0}, Number of steps NN, Step size {hn}n=0N1\{h_{n}\}_{n=0}^{N-1}, Number of midpoint estimates {Rn}n=0N1\{R_{n}\}_{n=0}^{N-1}, Number of parallel iteration {Kn}n=0N1\{K_{n}\}_{n=0}^{N-1}, Score estimates s^t\widehat{s}_{t}

  • For all n=0,,N1n=0,\cdots,N-1: let δn=hnRn\delta_{n}=\frac{h_{n}}{R_{n}}

  1. 1.

    For n=0,,N1n=0,\dots,N-1:

    1. (a)

      Let tn=t0w=0n1hwt_{n}=t_{0}-\sum_{w=0}^{n-1}h_{w}

    2. (b)

      Randomly sample αi\alpha_{i} uniformly from [(i1)/Rn,i/Rn][(i-1)/R_{n},i/R_{n}] for all i{1,,Rn}i\in\{1,\cdots,R_{n}\}

    3. (c)

      For i=1,,Rni=1,\cdots,R_{n} in parallel: Let x^n,i(0)=eαihnx^n+(eαihn1)s^tn(x^n)\widehat{x}^{(0)}_{n,i}=e^{\alpha_{i}h_{n}}\widehat{x}_{n}+\left(e^{\alpha_{i}h_{n}}-1\right)\cdot\widehat{s}_{t_{n}}(\widehat{x}_{n})

    4. (d)

      For k=1,,Knk=1,\cdots,K_{n}:

      For i=1,,Rni=1,\cdots,R_{n} in parallel:

      x^n,i(k):=eαihnx^n+j=1i(eαihn(j1)δnmax(eαihnjδn,1))s^tnαjhn(x^n,j(k1))\widehat{x}^{(k)}_{n,i}:=e^{\alpha_{i}h_{n}}\widehat{x}_{n}+\sum_{j=1}^{i}\left(e^{\alpha_{i}h_{n}-(j-1)\delta_{n}}-\max(e^{\alpha_{i}h_{n}-j\delta_{n}},1)\right)\cdot\widehat{s}_{t_{n}-\alpha_{j}h_{n}}(\widehat{x}^{(k-1)}_{n,j})

    5. (e)

      \widehat{x}_{n+1}=e^{h_{n}}\widehat{x}_{n}+\delta_{n}\cdot\sum_{i=1}^{R_{n}}e^{h_{n}-\alpha_{i}h_{n}}\widehat{s}_{t_{n}-\alpha_{i}h_{n}}(\widehat{x}^{(K_{n})}_{n,i})

  2. 2.

    Let tN=t0n=0N1hnt_{N}=t_{0}-\sum_{n=0}^{N-1}h_{n}

  3. 3.

    Return x^N,tN\widehat{x}_{N},t_{N}.
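
Below is a minimal Python sketch of Algorithm 7, in the same style as the sequential sketch earlier. The callable score(t, x) again stands in for \widehat{s}_{t}; the R_{n} score evaluations within each round are independent and would be issued in parallel, but are written here as Python loops purely for readability.

```python
import numpy as np

def predictor_step_parallel(x0, t0, step_sizes, num_midpoints, num_picard, score, rng=None):
    """Sketch of Algorithm 7: parallel randomized-midpoint predictor."""
    rng = np.random.default_rng() if rng is None else rng
    x, t = np.asarray(x0, dtype=float), float(t0)
    for h, R, K in zip(step_sizes, num_midpoints, num_picard):
        delta = h / R
        # Step (b): one random midpoint fraction per sub-interval [(i-1)/R, i/R].
        alphas = (np.arange(R) + rng.uniform(size=R)) / R
        # Step (c): initial estimates at all midpoints, score frozen at (t, x).
        s0 = score(t, x)
        xs = np.array([np.exp(a * h) * x + (np.exp(a * h) - 1.0) * s0 for a in alphas])
        # Step (d): K rounds of discretized Picard iteration.
        for _ in range(K):
            scores = np.array([score(t - alphas[j] * h, xs[j]) for j in range(R)])
            new_xs = np.empty_like(xs)
            for i in range(R):
                coeffs = np.array([np.exp(alphas[i] * h - (j - 1) * delta)
                                   - max(np.exp(alphas[i] * h - j * delta), 1.0)
                                   for j in range(1, i + 2)])
                new_xs[i] = np.exp(alphas[i] * h) * x + coeffs @ scores[: i + 1]
            xs = new_xs
        # Step (e): combine the scores at all refined midpoints into the full step.
        scores = np.array([score(t - alphas[i] * h, xs[i]) for i in range(R)])
        x = np.exp(h) * x + (delta * np.exp(h - alphas * h)) @ scores
        t -= h
    return x, t
```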

In our analysis, we follow the same notation as in Section A.1. We first establish a poly(L,d)\operatorname{poly}(L,d) bound on the initial estimation error incurred in step (c) of each iteration in Algorithm 7.

Claim B.1.

Suppose L1L\geq 1. Assume hn\lesssim1/Lh_{n}\lesssim 1/L. For any n=0,,N1n=0,\cdots,N-1, suppose we draw x^n\widehat{x}_{n} from an arbitrary distribution pnp_{n}, then run step (a) - (e) in Algorithm 7. Then for any i=1,,Rni=1,\cdots,R_{n},

\operatorname*{\mathbb{E}}\left\|\widehat{x}^{(0)}_{n,i}-x_{n}^{*}(\alpha_{i}h_{n})\right\|^{2}\lesssim h_{n}^{2}\varepsilon_{\mathrm{sc}}^{2}+L^{2}dh_{n}^{4}\left(L\lor\frac{1}{t_{n}-h_{n}}\right)+L^{2}h_{n}^{2}\operatorname*{\mathbb{E}}\left\|\widehat{x}_{n}-x_{t_{n}}\right\|^{2},

where xn(t)x_{n}^{*}(t) is solution of the true ODE starting at x^n\widehat{x}_{n} at time tnt_{n} and running until time tntt_{n}-t.

Proof.

Notice that in step (c) of Algorithm 7, the initial estimate of each randomized midpoint is computed with exactly the same formula as in step (c) of Algorithm 4, except that we compute this initial estimate for R_{n} different randomized midpoints. Notice also that the bound on the discretization error in Lemma A.2 does not depend on the specific value of \alpha, as long as this random value is at most 1. Hence an identical calculation yields the claim. ∎

Next, we show how to drive the initialization error from Claim B.1 down exponentially using the Picard iterations in step (d) of Algorithm 7.

Lemma B.2.

Suppose L\geq 1. Assume h_{n}\lesssim 1/L. For all iterations n\in\{0,\cdots,N-1\}, suppose we draw \widehat{x}_{n} from an arbitrary distribution p_{n}, then run steps (a)-(e) in Algorithm 7. Then for all k\in\{1,\cdots,K_{n}\} and i\in\{1,\cdots,R_{n}\},

\displaystyle\operatorname*{\mathbb{E}}\left\|\widehat{x}^{(k)}_{n,i}-x_{n}^{*}(\alpha_{i}h_{n})\right\|^{2} \lesssim\left(8h_{n}^{2}L^{2}\right)^{k}\cdot\left(\frac{1}{R_{n}}\sum_{j=1}^{R_{n}}\left\|\widehat{x}^{(0)}_{n,j}-x_{n}^{*}(\alpha_{j}h_{n})\right\|^{2}\right)
+hn2(εsc2+L2dhn2Rn2(L1tnhn)+L2\mathbbEx^nxtn2),\displaystyle\quad+h_{n}^{2}\left(\varepsilon_{\mathrm{sc}}^{2}+\frac{L^{2}dh_{n}^{2}}{R_{n}^{2}}(L\lor\frac{1}{t_{n}-h_{n}})+L^{2}\cdot\operatorname*{\mathbb{E}}\left\|\widehat{x}_{n}-x_{t_{n}}\right\|^{2}\right), (21)

where xn(t)x_{n}^{*}(t) is solution of the true ODE starting at x^n\widehat{x}_{n} at time tnt_{n} and running until time tntt_{n}-t.

Proof.

Fixing the iteration n, we will let h:=h_{n}, R:=R_{n}, and \delta:=\delta_{n}. The formulas for \widehat{x}^{(k)}_{n,i} and x_{n}^{*}(\alpha_{i}h) have the same coefficient on \widehat{x}_{n}, so we can bound the difference as follows:

\mathbbEx^n,i(k)xn(αih)2\displaystyle\operatorname*{\mathbb{E}}\left\|\widehat{x}^{(k)}_{n,i}-x_{n}^{*}(\alpha_{i}h)\right\|^{2}
\displaystyle\leq \mathbbEj=1i(tnmin(jδ,αih)tn(j1)δes(tnαih)dss^tnαjh(x^n,j(k1))tnαihtnes(tnαih)lnqs(xn(tns))ds)2\displaystyle\operatorname*{\mathbb{E}}\left\|\sum_{j=1}^{i}\left(\int_{t_{n}-\min(j\delta,\alpha_{i}h)}^{t_{n}-(j-1)\delta}e^{s-(t_{n}-\alpha_{i}h)}\,\mathrm{d}s\cdot\widehat{s}_{t_{n}-\alpha_{j}h}(\widehat{x}^{(k-1)}_{n,j})-\int_{t_{n}-\alpha_{i}h}^{t_{n}}e^{s-(t_{n}-\alpha_{i}h)}\nabla\ln q_{s}(x_{n}^{*}(t_{n}-s))\,\mathrm{d}s\right)\right\|^{2}
\displaystyle\leq 2\mathbbEj=1itnmin(jδ,αih)tn(j1)δes(tnαih)ds(s^tnαjh(x^n,j(k1))lnqtnαjh(xn(αjh)))2\displaystyle 2\operatorname*{\mathbb{E}}\left\|\sum_{j=1}^{i}\int_{t_{n}-\min(j\delta,\alpha_{i}h)}^{t_{n}-(j-1)\delta}e^{s-(t_{n}-\alpha_{i}h)}\,\mathrm{d}s\cdot\left(\widehat{s}_{t_{n}-\alpha_{j}h}(\widehat{x}^{(k-1)}_{n,j})-\nabla\ln q_{t_{n}-\alpha_{j}h}(x_{n}^{*}(\alpha_{j}h))\right)\right\|^{2} (22)
\displaystyle\quad+2\operatorname*{\mathbb{E}}\left\|\sum_{j=1}^{i}\int_{t_{n}-\min(j\delta,\alpha_{i}h)}^{t_{n}-(j-1)\delta}e^{s-(t_{n}-\alpha_{i}h)}\left(\nabla\ln q_{t_{n}-\alpha_{j}h}(x_{n}^{*}(\alpha_{j}h))-\nabla\ln q_{s}(x_{n}^{*}(t_{n}-s))\right)\,\mathrm{d}s\right\|^{2}. (23)

The first inequality is by definition, and the second is by Young's inequality. Now, we will bound Equation 22 and Equation 23 separately. By Assumptions 2.2 and 2.4,

\mathbbEs^tnαjh(x^n,j(k1))lnqtnαjh(xn(αjh))2\displaystyle\operatorname*{\mathbb{E}}\left\|\widehat{s}_{t_{n}-\alpha_{j}h}(\widehat{x}^{(k-1)}_{n,j})-\nabla\ln q_{t_{n}-\alpha_{j}h}(x_{n}^{*}(\alpha_{j}h))\right\|^{2}
\displaystyle\leq 2\mathbbEs^tnαjh(x^n,j(k1))lnqtnαjh(x^n,j(k1))2+2\mathbbElnqtnαjh(x^n,j(k1))lnqtnαjh(xn(αjh))2\displaystyle 2\operatorname*{\mathbb{E}}\left\|\widehat{s}_{t_{n}-\alpha_{j}h}(\widehat{x}^{(k-1)}_{n,j})-\nabla\ln q_{t_{n}-\alpha_{j}h}(\widehat{x}^{(k-1)}_{n,j})\right\|^{2}+2\operatorname*{\mathbb{E}}\left\|\nabla\ln q_{t_{n}-\alpha_{j}h}(\widehat{x}^{(k-1)}_{n,j})-\nabla\ln q_{t_{n}-\alpha_{j}h}(x_{n}^{*}(\alpha_{j}h))\right\|^{2}
\displaystyle\leq 2εsc2+2L2x^n,j(k1)xn(αjh)2.\displaystyle 2\varepsilon_{\mathrm{sc}}^{2}+2L^{2}\cdot\left\|\widehat{x}^{(k-1)}_{n,j}-x_{n}^{*}(\alpha_{j}h)\right\|^{2}.

The term in Equation 22 can now be bounded as follows

\mathbbEj=1itnmin(jδ,αih)tn(j1)δes(tnαih)ds(s^tnαjh(x^n,j(k1))lnqtnαjh(xn(αjh)))2\displaystyle\operatorname*{\mathbb{E}}\left\|\sum_{j=1}^{i}\int_{t_{n}-\min(j\delta,\alpha_{i}h)}^{t_{n}-(j-1)\delta}e^{s-(t_{n}-\alpha_{i}h)}\,\mathrm{d}s\cdot\left(\widehat{s}_{t_{n}-\alpha_{j}h}(\widehat{x}^{(k-1)}_{n,j})-\nabla\ln q_{t_{n}-\alpha_{j}h}(x_{n}^{*}(\alpha_{j}h))\right)\right\|^{2}
\displaystyle\leq Rj=1i\mathbbEtnmin(jδ,αih)tn(j1)δes(tnαih)ds(s^tnαjh(x^n,j(k1))lnqtnαjh(xn(αjh)))2\displaystyle R\cdot\sum_{j=1}^{i}\operatorname*{\mathbb{E}}\left\|\int_{t_{n}-\min(j\delta,\alpha_{i}h)}^{t_{n}-(j-1)\delta}e^{s-(t_{n}-\alpha_{i}h)}\,\mathrm{d}s\cdot\left(\widehat{s}_{t_{n}-\alpha_{j}h}(\widehat{x}^{(k-1)}_{n,j})-\nabla\ln q_{t_{n}-\alpha_{j}h}(x_{n}^{*}(\alpha_{j}h))\right)\right\|^{2}
\displaystyle\leq Rδ2e2αihj=1i\mathbbEs^tnαjh(x^n,j(k1))lnqtnαjh(xn(αjh))2\displaystyle R\cdot\delta^{2}\cdot e^{2\alpha_{i}h}\sum_{j=1}^{i}\operatorname*{\mathbb{E}}\left\|\widehat{s}_{t_{n}-\alpha_{j}h}(\widehat{x}^{(k-1)}_{n,j})-\nabla\ln q_{t_{n}-\alpha_{j}h}(x_{n}^{*}(\alpha_{j}h))\right\|^{2}
\displaystyle\leq 2Rδ2e2αihj=1i(εsc2+L2x^n,j(k1)xn(αjh)2)\displaystyle 2R\cdot\delta^{2}\cdot e^{2\alpha_{i}h}\sum_{j=1}^{i}\left(\varepsilon_{\mathrm{sc}}^{2}+L^{2}\cdot\left\|\widehat{x}^{(k-1)}_{n,j}-x_{n}^{*}(\alpha_{j}h)\right\|^{2}\right)
\displaystyle\leq 2R2δ2e2αih1Rj=1R(εsc2+L2x^n,j(k1)xn(αjh)2)\displaystyle 2R^{2}\cdot\delta^{2}\cdot e^{2\alpha_{i}h}\frac{1}{R}\sum_{j=1}^{R}\left(\varepsilon_{\mathrm{sc}}^{2}+L^{2}\cdot\left\|\widehat{x}^{(k-1)}_{n,j}-x_{n}^{*}(\alpha_{j}h)\right\|^{2}\right)
\displaystyle\leq 4h2εsc2+4h2L21Rj=1Rx^n,j(k1)xn(αjh)2.\displaystyle 4h^{2}\varepsilon_{\mathrm{sc}}^{2}+4h^{2}L^{2}\cdot\frac{1}{R}\sum_{j=1}^{R}\cdot\left\|\widehat{x}^{(k-1)}_{n,j}-x_{n}^{*}(\alpha_{j}h)\right\|^{2}.

Here the first inequality is by (\sum_{i=1}^{n}a_{i})^{2}\leq n\sum_{i=1}^{n}a_{i}^{2}, the second is by the fact that \int_{t_{n}-\min(j\delta,\alpha_{i}h)}^{t_{n}-(j-1)\delta}e^{s-(t_{n}-\alpha_{i}h)}\,\mathrm{d}s\leq\delta\cdot e^{\alpha_{i}h}, and the final inequality is by R\delta=h together with the fact that e^{2\alpha_{i}h} is at most 2 when h<1/4.

Similarly, the term in Equation 23 can be bounded as follows

\mathbbEj=1itnmin(jδ,αih)tn(j1)δes(tnαih)(lnqtnαjh(xn(αjh))lnqs(xn(tns)))ds2\displaystyle\operatorname*{\mathbb{E}}\left\|\sum_{j=1}^{i}\int_{t_{n}-\min(j\delta,\alpha_{i}h)}^{t_{n}-(j-1)\delta}e^{s-(t_{n}-\alpha_{i}h)}\left(\nabla\ln q_{t_{n}-\alpha_{j}h}(x_{n}^{*}(\alpha_{j}h))-\nabla\ln q_{s}(x_{n}^{*}(t_{n}-s))\right)\,\mathrm{d}s\right\|^{2}
\displaystyle\leq Rj=1i\mathbbEtnmin(jδ,αih)tn(j1)δes(tnαih)(lnqtnαjh(xn(αjh))lnqs(xn(tns)))ds2\displaystyle R\cdot\sum_{j=1}^{i}\operatorname*{\mathbb{E}}\left\|\int_{t_{n}-\min(j\delta,\alpha_{i}h)}^{t_{n}-(j-1)\delta}e^{s-(t_{n}-\alpha_{i}h)}\left(\nabla\ln q_{t_{n}-\alpha_{j}h}(x_{n}^{*}(\alpha_{j}h))-\nabla\ln q_{s}(x_{n}^{*}(t_{n}-s))\right)\,\mathrm{d}s\right\|^{2}
\displaystyle\leq Rδe2αihj=1itnmin(jδ,αih)tn(j1)δ\mathbbElnqtnαjh(xn(αjh))lnqs(xn(tns))2ds\displaystyle R\cdot\delta\cdot e^{2\alpha_{i}h}\cdot\sum_{j=1}^{i}\int_{t_{n}-\min(j\delta,\alpha_{i}h)}^{t_{n}-(j-1)\delta}\operatorname*{\mathbb{E}}\left\|\nabla\ln q_{t_{n}-\alpha_{j}h}(x_{n}^{*}(\alpha_{j}h))-\nabla\ln q_{s}(x_{n}^{*}(t_{n}-s))\right\|^{2}\,\mathrm{d}s

Since |(t_{n}-\alpha_{j}h)-s|\leq\delta,

\displaystyle\operatorname*{\mathbb{E}}\left\|\nabla\ln q_{t_{n}-\alpha_{j}h}(x_{n}^{*}(\alpha_{j}h))-\nabla\ln q_{s}(x_{n}^{*}(t_{n}-s))\right\|^{2}
\displaystyle\leq 3\operatorname*{\mathbb{E}}\left\|\nabla\ln q_{t_{n}-\alpha_{j}h}(x_{n}^{*}(\alpha_{j}h))-\nabla\ln q_{t_{n}-\alpha_{j}h}(x_{t_{n}-\alpha_{j}h})\right\|^{2}+3\operatorname*{\mathbb{E}}\left\|\nabla\ln q_{s}(x_{s})-\nabla\ln q_{s}(x_{n}^{*}(t_{n}-s))\right\|^{2}
\displaystyle\quad+3\operatorname*{\mathbb{E}}\left\|\nabla\ln q_{t_{n}-\alpha_{j}h}(x_{t_{n}-\alpha_{j}h})-\nabla\ln q_{s}(x_{s})\right\|^{2}
\displaystyle= 3\operatorname*{\mathbb{E}}\left\|\nabla\ln q_{t_{n}-\alpha_{j}h}(x_{n}^{*}(\alpha_{j}h))-\nabla\ln q_{t_{n}-\alpha_{j}h}(x_{t_{n}-\alpha_{j}h})\right\|^{2}+3\operatorname*{\mathbb{E}}\left\|\nabla\ln q_{s}(x_{s})-\nabla\ln q_{s}(x_{n}^{*}(t_{n}-s))\right\|^{2}
\displaystyle\quad+3\operatorname*{\mathbb{E}}\left\|\int_{s}^{t_{n}-\alpha_{j}h}\partial_{u}\nabla\ln q_{u}(x_{u})\,\mathrm{d}u\right\|^{2}
\displaystyle\leq 3L^{2}\cdot\operatorname*{\mathbb{E}}\left\|x_{n}^{*}(\alpha_{j}h)-x_{t_{n}-\alpha_{j}h}\right\|^{2}+3L^{2}\cdot\operatorname*{\mathbb{E}}\left\|x_{n}^{*}(t_{n}-s)-x_{s}\right\|^{2}+3\operatorname*{\mathbb{E}}\left\|\int_{s}^{t_{n}-\alpha_{j}h}\partial_{u}\nabla\ln q_{u}(x_{u})\,\mathrm{d}u\right\|^{2}.

The first inequality is by Young's inequality, and the last is by Assumption 2.2. By Lemma 3 in [CCL+23a],

\mathbbEstnαjhulnqu(xu)du2\displaystyle\operatorname*{\mathbb{E}}\left\|\int_{s}^{t_{n}-\alpha_{j}h}\partial_{u}\nabla\ln q_{u}(x_{u})du\right\|^{2} δstnαjh\mathbbEulnqu(xu)2du\displaystyle\leq\delta\cdot\int_{s}^{t_{n}-\alpha_{j}h}\operatorname*{\mathbb{E}}\left\|\partial_{u}\nabla\ln q_{u}(x_{u})\right\|^{2}du
δstnαjhL2dmax(L,1u)𝑑uL2dδ2(L1tnh).\displaystyle\leq\delta\cdot\int_{s}^{t_{n}-\alpha_{j}h}L^{2}d\max(L,\frac{1}{u})du\leq L^{2}d\delta^{2}(L\lor\frac{1}{t_{n}-h}).

Now, by Lemma A.1 and the fact that h\lesssim1Lh\lesssim\frac{1}{L},

\displaystyle\operatorname*{\mathbb{E}}\left\|\nabla\ln q_{t_{n}-\alpha_{j}h}(x_{n}^{*}(\alpha_{j}h))-\nabla\ln q_{s}(x_{n}^{*}(t_{n}-s))\right\|^{2} \leq 12L^{2}\exp(Lh)\operatorname*{\mathbb{E}}\left\|\widehat{x}_{n}-x_{t_{n}}\right\|^{2}+3L^{2}d\delta^{2}(L\lor\frac{1}{t_{n}-h})
36L2\mathbbEx^nxtn2+3L2dδ2(L1tnh).\displaystyle\leq 36L^{2}\operatorname*{\mathbb{E}}\left\|\widehat{x}_{n}-x_{t_{n}}\right\|^{2}+3L^{2}d\delta^{2}(L\lor\frac{1}{t_{n}-h}).

We conclude that the term in Equation 23 can be bounded by

Rδe2αihj=1itnmin(jδ,αih)tn(j1)δ36L2\mathbbEx^nxtn2+3L2dδ2(L1tnh)ds\displaystyle R\cdot\delta\cdot e^{2\alpha_{i}h}\cdot\sum_{j=1}^{i}\int_{t_{n}-\min(j\delta,\alpha_{i}h)}^{t_{n}-(j-1)\delta}36L^{2}\operatorname*{\mathbb{E}}\left\|\widehat{x}_{n}-x_{t_{n}}\right\|^{2}+3L^{2}d\delta^{2}(L\lor\frac{1}{t_{n}-h})\,\mathrm{d}s
\displaystyle\leq 2R2δ2(36L2\mathbbEx^nxtn2+3L2dδ2(L1tnh))\displaystyle 2R^{2}\delta^{2}\cdot\left(36L^{2}\operatorname*{\mathbb{E}}\left\|\widehat{x}_{n}-x_{t_{n}}\right\|^{2}+3L^{2}d\delta^{2}(L\lor\frac{1}{t_{n}-h})\right)
=\displaystyle= 4h2L2(18\mathbbEx^nxtn2+32dδ2(L1tnh)).\displaystyle 4h^{2}L^{2}\left(18\operatorname*{\mathbb{E}}\left\|\widehat{x}_{n}-x_{t_{n}}\right\|^{2}+\frac{3}{2}d\delta^{2}(L\lor\frac{1}{t_{n}-h})\right).

Combining the bounds for Equation 22 and Equation 23, we get

\mathbbEx^n,i(k)xn(αih)28h2L2(1Rj=1Rx^n,j(k1)xn(αjh)2+εsc2L2+32dδ2(L1tnh)+18\mathbbEx^nxtn2).\operatorname*{\mathbb{E}}\left\|\widehat{x}^{(k)}_{n,i}-x_{n}(\alpha_{i}h)\right\|^{2}\\ \leq 8h^{2}L^{2}\cdot\left(\frac{1}{R}\sum_{j=1}^{R}\left\|\widehat{x}^{(k-1)}_{n,j}-x_{n}^{*}(\alpha_{j}h)\right\|^{2}+\frac{\varepsilon_{\mathrm{sc}}^{2}}{L^{2}}+\frac{3}{2}d\delta^{2}(L\lor\frac{1}{t_{n}-h})+18\operatorname*{\mathbb{E}}\left\|\widehat{x}_{n}-x_{t_{n}}\right\|^{2}\right).

For a sufficiently small constant h, e^{\alpha_{i}h}\leq 2. Moreover, by the definition of \delta, R\delta=h. Unrolling the recursion, we get

\displaystyle\operatorname*{\mathbb{E}}\left\|\widehat{x}^{(k)}_{n,i}-x_{n}^{*}(\alpha_{i}h)\right\|^{2} \lesssim\left(8h^{2}L^{2}\right)^{k}\cdot\left(\frac{1}{R}\sum_{j=1}^{R}\left\|\widehat{x}^{(0)}_{n,j}-x_{n}^{*}(\alpha_{j}h)\right\|^{2}\right)
+8h2L218h2L2(εsc2L2+dδ2(L1tnh)+\mathbbEx^nxtn2)\displaystyle\quad+\frac{8h^{2}L^{2}}{1-8h^{2}L^{2}}\left(\frac{\varepsilon_{\mathrm{sc}}^{2}}{L^{2}}+d\delta^{2}(L\lor\frac{1}{t_{n}-h})+\operatorname*{\mathbb{E}}\left\|\widehat{x}_{n}-x_{t_{n}}\right\|^{2}\right)
\lesssim(8h2L2)k(1Rj=1Rx^n,j(0)xn(αjh)2)\displaystyle\lesssim\left(8h^{2}L^{2}\right)^{k}\cdot\left(\frac{1}{R}\sum_{j=1}^{R}\left\|\widehat{x}^{(0)}_{n,j}-x_{n}^{*}(\alpha_{j}h)\right\|^{2}\right)
+h2(εsc2+L2dδ2(L1tnh)+L2\mathbbEx^nxtn2).\displaystyle\quad+h^{2}\left(\varepsilon_{\mathrm{sc}}^{2}+L^{2}d\delta^{2}(L\lor\frac{1}{t_{n}-h})+L^{2}\cdot\operatorname*{\mathbb{E}}\left\|\widehat{x}_{n}-x_{t_{n}}\right\|^{2}\right)\,.\qed

As a consequence of Lemma B.2, if we take the number K_{n} of Picard iterations sufficiently large, the error incurred in our estimate for the midpoint value x_{n}^{*}(\alpha_{i}h_{n}) is dominated by the terms in the second line of Eq. (21).

Corollary B.3.

Assume L1L\geq 1. For all n{0,,N1}n\in\{0,\cdots,N-1\}, suppose we draw x^n\widehat{x}_{n} from an arbitrary distribution pnp_{n}, then run step (a) - (e) in Algorithm 7. In addition, suppose hn<13Lh_{n}<\frac{1}{3L} and Kn\gtrsimlog(Rn)K_{n}\gtrsim\log(R_{n}). Then for any i{1,,Rn}i\in\{1,\cdots,R_{n}\},

\mathbbEx^n,i(Kn)xn(αihn)2\displaystyle\operatorname*{\mathbb{E}}\left\|\widehat{x}^{(K_{n})}_{n,i}-x_{n}^{*}(\alpha_{i}h_{n})\right\|^{2} \lesssimhn2εsc2+L2dhn4Rn2(L1tnhn)+L2hn2\mathbbEx^nxtn2,\displaystyle\lesssim h_{n}^{2}\varepsilon_{\mathrm{sc}}^{2}+\frac{L^{2}dh_{n}^{4}}{R_{n}^{2}}\left(L\lor\frac{1}{t_{n}-h_{n}}\right)+L^{2}h_{n}^{2}\cdot\operatorname*{\mathbb{E}}\left\|\widehat{x}_{n}-x_{t_{n}}\right\|^{2},

where xn(t)x_{n}^{*}(t) is solution of the true ODE starting at x^n\widehat{x}_{n} at time tnt_{n} and running until time tntt_{n}-t.

Proof.

Fixing the iteration n, we will let h:=h_{n}, R:=R_{n}, \delta:=\delta_{n}, and K:=K_{n}. Notice that when K\geq\frac{2}{\log\frac{1}{8h^{2}L^{2}}}\cdot\log(R), \left(8h^{2}L^{2}\right)^{K} is at most \frac{1}{R^{2}}. Now, plugging Claim B.1 into Lemma B.2, we get

\mathbbEx^n,i(K)xn(αih)2\displaystyle\operatorname*{\mathbb{E}}\left\|\widehat{x}^{(K)}_{n,i}-x_{n}^{*}(\alpha_{i}h)\right\|^{2} \lesssim(8h2L2)Kh2(εsc2+L2dh2(L1tnh)+L2x^nxtn2)\displaystyle\lesssim\left(8h^{2}L^{2}\right)^{K}\cdot h^{2}\cdot\left(\varepsilon_{\mathrm{sc}}^{2}+L^{2}dh^{2}(L\lor\frac{1}{t_{n}-h})+L^{2}\left\|\widehat{x}_{n}-x_{t_{n}}\right\|^{2}\right)
+h2(εsc2+L2dδ2(L1tnh)+L2\mathbbEx^nxtn2)\displaystyle\quad+h^{2}\left(\varepsilon_{\mathrm{sc}}^{2}+L^{2}d\delta^{2}(L\lor\frac{1}{t_{n}-h})+L^{2}\cdot\operatorname*{\mathbb{E}}\left\|\widehat{x}_{n}-x_{t_{n}}\right\|^{2}\right)
\lesssim((8h2L2)K+1)h2(εsc2+L2\mathbbEx^nxtn2)\displaystyle\lesssim\left((8h^{2}L^{2})^{K}+1\right)h^{2}\cdot\left(\varepsilon_{\mathrm{sc}}^{2}+L^{2}\cdot\operatorname*{\mathbb{E}}\left\|\widehat{x}_{n}-x_{t_{n}}\right\|^{2}\right)
+((8h2L2)K+1R2)L2dh4(L1tnh)\displaystyle\qquad+((8h^{2}L^{2})^{K}+\frac{1}{R^{2}})\cdot L^{2}dh^{4}\left(L\lor\frac{1}{t_{n}-h}\right)
\lesssimh2εsc2+L2h2\mathbbEx^nxtn2+L2dh4R2(L1tnh)\displaystyle\lesssim h^{2}\varepsilon_{\mathrm{sc}}^{2}+L^{2}h^{2}\cdot\operatorname*{\mathbb{E}}\left\|\widehat{x}_{n}-x_{t_{n}}\right\|^{2}+\frac{L^{2}dh^{4}}{R^{2}}\left(L\lor\frac{1}{t_{n}-h}\right)

The first inequality is by rearranging terms, while the second is by the fact that the term \frac{1}{R^{2}} dominates \left(8h^{2}L^{2}\right)^{K}. ∎

We can now prove the parallel analogues of Lemma A.3 and Lemma A.4. Note that the bounds in Lemma B.4 and Lemma B.5 are identical to the bounds in Lemma A.3 and Lemma A.4, except for an additional \frac{1}{R_{n}^{2}} factor in the middle term. This additional factor stems from using R_{n} midpoints in each iteration n (compared to the single midpoint per iteration in Algorithm 4).

Lemma B.4 (Parallel Predictor Bias).

Assume L1L\geq 1. For all n{0,,N1}n\in\{0,\cdots,N-1\}, suppose we draw x^n\widehat{x}_{n} from an arbitrary distribution pnp_{n}, then run step (a) - (e) in Algorithm 7. In addition, suppose hn<13Lh_{n}<\frac{1}{3L} and Kn\gtrsimlog(Rn)K_{n}\gtrsim\log(R_{n}). Then we have

\displaystyle\operatorname*{\mathbb{E}}\|\operatorname*{\mathbb{E}}_{\alpha}\widehat{x}_{n+1}-x_{n}^{*}(h_{n})\|^{2}\lesssim h_{n}^{2}\varepsilon_{\mathrm{sc}}^{2}+\frac{L^{4}dh_{n}^{6}}{R_{n}^{2}}\left(L\lor\frac{1}{t_{n}-h_{n}}\right)+L^{4}h_{n}^{4}\operatorname*{\mathbb{E}}\left\|\widehat{x}_{n}-x_{t_{n}}\right\|^{2},

where xn(t)x_{n}^{*}(t) is solution of the true ODE starting at x^n\widehat{x}_{n} at time tnt_{n} and running until time tntt_{n}-t.

Proof.

Fixing iteration nn, we will let h:=hnh:=h_{n}, R:=RnR:=R_{n}, δ:=δn\delta:=\delta_{n} and K:=KnK:=K_{n}. We have

\mathbbE\mathbbEαx^n+1xn(h)2\displaystyle\operatorname*{\mathbb{E}}\|\operatorname*{\mathbb{E}}_{\alpha}\widehat{x}_{n+1}-x_{n}^{*}(h)\|^{2}
\displaystyle\leq \mathbbE\mathbbEα[δi=1Rehαihs^tnαih(x^n,i(K))]tnhtnes(tnh)lnqs(xn(tns))ds2\displaystyle\operatorname*{\mathbb{E}}\left\|\operatorname*{\mathbb{E}}_{\alpha}\left[\delta\cdot\sum_{i=1}^{R}e^{h-\alpha_{i}h}\cdot\widehat{s}_{t_{n}-\alpha_{i}h}(\widehat{x}^{(K)}_{n,i})\right]-\int_{t_{n}-h}^{t_{n}}e^{s-(t_{n}-h)}\nabla\ln q_{s}(x_{n}^{*}(t_{n}-s))\,\mathrm{d}s\right\|^{2}
\displaystyle\leq 2\mathbbE\mathbbEα[i=1Rδehαih(s^tnαih(x^n,i(K))lnqtnαih(xtnαih))]2\displaystyle 2\operatorname*{\mathbb{E}}\left\|\operatorname*{\mathbb{E}}_{\alpha}\left[\sum_{i=1}^{R}\delta e^{h-\alpha_{i}h}\cdot\left(\widehat{s}_{t_{n}-\alpha_{i}h}(\widehat{x}^{(K)}_{n,i})-\nabla\ln q_{t_{n}-\alpha_{i}h}(x_{t_{n}-\alpha_{i}h})\right)\right]\right\|^{2} (24)
+2\mathbbE\mathbbEαi=1Rδehαihlnqtnαih(xtnαih)tnhtnes(tnh)lnqs(xn(tns))ds2.\displaystyle+2\operatorname*{\mathbb{E}}\left\|\operatorname*{\mathbb{E}}_{\alpha}\sum_{i=1}^{R}\delta e^{h-\alpha_{i}h}\cdot\nabla\ln q_{t_{n}-\alpha_{i}h}(x_{t_{n}-\alpha_{i}h})-\int_{t_{n}-h}^{t_{n}}e^{s-(t_{n}-h)}\nabla\ln q_{s}(x_{n}^{*}(t_{n}-s))\,\mathrm{d}s\right\|^{2}\,. (25)

Since αi\alpha_{i} is drawn uniformly from [(i1)δ,iδ][(i-1)\delta,i\delta],

\mathbbEα[δehαihlnqtnαih(xtnαih)]=tniδtn(i1)δes(tnh)lnqs(xn(tns))ds,\displaystyle\operatorname*{\mathbb{E}}_{\alpha}\left[\delta e^{h-\alpha_{i}h}\cdot\nabla\ln q_{t_{n}-\alpha_{i}h}(x_{t_{n}-\alpha_{i}h})\right]=\int_{t_{n}-i\delta}^{t_{n}-(i-1)\delta}e^{s-(t_{n}-h)}\nabla\ln q_{s}(x_{n}^{*}(t_{n}-s))\,\mathrm{d}s,

and the term in Equation 25 is equal to 0. Now we need to bound Equation 24.

s^tnαih(x^n,i(K))lnqtnαih(xtnαih)2\displaystyle\left\|\widehat{s}_{t_{n}-\alpha_{i}h}(\widehat{x}^{(K)}_{n,i})-\nabla\ln q_{t_{n}-\alpha_{i}h}(x_{t_{n}-\alpha_{i}h})\right\|^{2}
\displaystyle\leq 2s^tnαih(x^n,i(K))lnqtnαih(x^n,i(K))2+2lnqtnαih(x^n,i(K))lnqtnαih(xtnαih)2\displaystyle 2\left\|\widehat{s}_{t_{n}-\alpha_{i}h}(\widehat{x}^{(K)}_{n,i})-\nabla\ln q_{t_{n}-\alpha_{i}h}(\widehat{x}^{(K)}_{n,i})\right\|^{2}+2\left\|\nabla\ln q_{t_{n}-\alpha_{i}h}(\widehat{x}^{(K)}_{n,i})-\nabla\ln q_{t_{n}-\alpha_{i}h}(x_{t_{n}-\alpha_{i}h})\right\|^{2}
\displaystyle\lesssim\varepsilon_{\mathrm{sc}}^{2}+L^{2}\cdot\left\|\widehat{x}^{(K)}_{n,i}-x_{n}^{*}(\alpha_{i}h)\right\|^{2}
\lesssim\displaystyle\lesssim εsc2+L2(h2εsc2+L2dh4R2(L1tnh)+L2h2\mathbbEx^nxtn2).\displaystyle\varepsilon_{\mathrm{sc}}^{2}+L^{2}\cdot\left(h^{2}\varepsilon_{\mathrm{sc}}^{2}+\frac{L^{2}dh^{4}}{R^{2}}\left(L\lor\frac{1}{t_{n}-h}\right)+L^{2}h^{2}\cdot\operatorname*{\mathbb{E}}\left\|\widehat{x}_{n}-x_{t_{n}}\right\|^{2}\right).

The second line follows from the inequality (i=1nai)2ni=1nai2(\sum_{i=1}^{n}a_{i})^{2}\leq n\sum_{i=1}^{n}a_{i}^{2}, the third line follows from 2.4 and 2.2, and the fourth line follows from Corollary B.3 (the assumption K\gtrsimlog(R)K\gtrsim\log(R) satisfies the condition in Corollary B.3). Hence

2\mathbbE\mathbbEα[i=1Rδehαih(s^tnαih(x^n,i(K))lnqtnαih(xtnαih))]2\displaystyle 2\operatorname*{\mathbb{E}}\left\|\operatorname*{\mathbb{E}}_{\alpha}\left[\sum_{i=1}^{R}\delta e^{h-\alpha_{i}h}\cdot\left(\widehat{s}_{t_{n}-\alpha_{i}h}(\widehat{x}^{(K)}_{n,i})-\nabla\ln q_{t_{n}-\alpha_{i}h}(x_{t_{n}-\alpha_{i}h})\right)\right]\right\|^{2}
\displaystyle\leq 2R(2δ)2i=1R\mathbbEαs^tnαih(x^n,i(K))lnqtnαih(xtnαih)2\displaystyle 2R(2\delta)^{2}\sum_{i=1}^{R}\operatorname*{\mathbb{E}}_{\alpha}\left\|\widehat{s}_{t_{n}-\alpha_{i}h}(\widehat{x}^{(K)}_{n,i})-\nabla\ln q_{t_{n}-\alpha_{i}h}(x_{t_{n}-\alpha_{i}h})\right\|^{2}
\lesssim\displaystyle\lesssim 8δ2R2(εsc2+L2(h2εsc2+L2dh4R2(L1tnh)+L2h2\mathbbEx^nxtn2))\displaystyle 8\delta^{2}R^{2}\cdot\left(\varepsilon_{\mathrm{sc}}^{2}+L^{2}\cdot\left(h^{2}\varepsilon_{\mathrm{sc}}^{2}+\frac{L^{2}dh^{4}}{R^{2}}\left(L\lor\frac{1}{t_{n}-h}\right)+L^{2}h^{2}\cdot\operatorname*{\mathbb{E}}\left\|\widehat{x}_{n}-x_{t_{n}}\right\|^{2}\right)\right)
\lesssim\displaystyle\lesssim h2εsc2+L4dh6R2(L1tnh)+L4h4\mathbbEx^nxtn2.\displaystyle h^{2}\cdot\varepsilon_{\mathrm{sc}}^{2}+\frac{L^{4}dh^{6}}{R^{2}}(L\lor\frac{1}{t_{n}-h})+L^{4}h^{4}\cdot\operatorname*{\mathbb{E}}\left\|\widehat{x}_{n}-x_{t_{n}}\right\|^{2}.

The second line follows from the inequality (i=1nai)2ni=1nai2(\sum_{i=1}^{n}a_{i})^{2}\leq n\sum_{i=1}^{n}a_{i}^{2} together with Young’s inequality, the third line follows by plugging in our previous calculation, and the fourth line follows from h=δRh=\delta R and h\lesssim1/Lh\lesssim 1/L. ∎

Lemma B.5 (Parallel Predictor Variance).

Assume L1L\geq 1. For all n{0,,N1}n\in\{0,\cdots,N-1\}, suppose we draw x^n\widehat{x}_{n} from an arbitrary distribution pnp_{n} and then run steps (a)-(e) in Algorithm 7. In addition, suppose hn<13Lh_{n}<\frac{1}{3L} and Kn\gtrsimlog(Rn)K_{n}\gtrsim\log(R_{n}). Then we have

\mathbbEαx^n+1xn(hn)2\lesssimhn2εsc2+L2dhn4Rn2(L1tnhn)+L2hn2\mathbbEx^nxtn2,\displaystyle\operatorname*{\mathbb{E}}_{\alpha}\|\widehat{x}_{n+1}-x_{n}^{*}(h_{n})\|^{2}\lesssim h_{n}^{2}\cdot\varepsilon_{\mathrm{sc}}^{2}+\frac{L^{2}dh_{n}^{4}}{R_{n}^{2}}\left(L\lor\frac{1}{t_{n}-h_{n}}\right)+L^{2}h_{n}^{2}\cdot\operatorname*{\mathbb{E}}\left\|\widehat{x}_{n}-x_{t_{n}}\right\|^{2},

where xn(t)x_{n}^{*}(t) is the solution of the true ODE starting at x^n\widehat{x}_{n} at time tnt_{n} and running until time tntt_{n}-t.

Proof.

Fixing iteration nn, we will let h:=hnh:=h_{n}, R:=RnR:=R_{n}, δ:=δn\delta:=\delta_{n} and K:=KnK:=K_{n}. We will separate \mathbbEαx^n+1xn(h)2\operatorname*{\mathbb{E}}_{\alpha}\|\widehat{x}_{n+1}-x_{n}^{*}(h)\|^{2} into several terms and bound each term separately.

\mathbbEαx^n+1xn(h)2\displaystyle\operatorname*{\mathbb{E}}_{\alpha}\|\widehat{x}_{n+1}-x_{n}^{*}(h)\|^{2}
\displaystyle\leq \mathbbEαδi=1Rehαihs^tnαih(x^n,i(K))tnhtnes(tnh)lnqs(xn(tns))ds2\displaystyle\operatorname*{\mathbb{E}}_{\alpha}\left\|\delta\cdot\sum_{i=1}^{R}e^{h-\alpha_{i}h}\cdot\widehat{s}_{t_{n}-\alpha_{i}h}(\widehat{x}^{(K)}_{n,i})-\int_{t_{n}-h}^{t_{n}}e^{s-(t_{n}-h)}\nabla\ln q_{s}(x_{n}^{*}(t_{n}-s))\,\mathrm{d}s\right\|^{2}
\displaystyle\leq 3\mathbbEαi=1Rδehαih(s^tnαih(x^n,i(K))lnqtnαih(xtnαih))2\displaystyle 3\operatorname*{\mathbb{E}}_{\alpha}\left\|\sum_{i=1}^{R}\delta e^{h-\alpha_{i}h}\cdot\left(\widehat{s}_{t_{n}-\alpha_{i}h}(\widehat{x}^{(K)}_{n,i})-\nabla\ln q_{t_{n}-\alpha_{i}h}(x_{t_{n}-\alpha_{i}h})\right)\right\|^{2} (26)
+3\mathbbEi=1Rtniδtn(i1)δehαih(lnqtnαih(xtnαih)lnqs(xn(tns)))ds2\displaystyle+3\operatorname*{\mathbb{E}}\left\|\sum_{i=1}^{R}\int_{t_{n}-i\delta}^{t_{n}-(i-1)\delta}e^{h-\alpha_{i}h}\cdot\left(\nabla\ln q_{t_{n}-\alpha_{i}h}(x_{t_{n}-\alpha_{i}h})-\nabla\ln q_{s}(x_{n}^{*}(t_{n}-s))\right)\,\mathrm{d}s\right\|^{2} (27)
+3\mathbbEi=1Rtniδtn(i1)δ(ehαihes(tnh))lnqs(xn(tns))ds2.\displaystyle+3\operatorname*{\mathbb{E}}\left\|\sum_{i=1}^{R}\int_{t_{n}-i\delta}^{t_{n}-(i-1)\delta}\left(e^{h-\alpha_{i}h}-e^{s-(t_{n}-h)}\right)\cdot\nabla\ln q_{s}(x_{n}^{*}(t_{n}-s))\,\mathrm{d}s\right\|^{2}. (28)

Equation 26 can be bounded in exactly the same way as Equation 24 in Lemma B.4, giving

\text{Equation 26}\lesssim h^{2}\cdot\varepsilon_{\mathrm{sc}}^{2}+L^{4}h^{4}\cdot\operatorname*{\mathbb{E}}\left\|\widehat{x}_{n}-x_{t_{n}}\right\|^{2}+\frac{L^{4}dh^{6}}{R^{2}}\left(L\lor\frac{1}{t_{n}-h}\right).

Next we will bound Equation 27. By Lemma D.1 and Lemma A.1,

\operatorname*{\mathbb{E}}\left\|\nabla\ln q_{t_{n}-\alpha_{j}h}(x_{n}^{*}(\alpha_{j}h))-\nabla\ln q_{s}(x_{n}^{*}(t_{n}-s))\right\|^{2}\lesssim L^{2}d\delta^{2}\left(L\lor\frac{1}{t_{n}-h}\right)+L^{2}\cdot\operatorname*{\mathbb{E}}\left\|\widehat{x}_{n}-x_{t_{n}}\right\|^{2},

hence Equation 27 can be bounded, by a calculation similar to that for Equation 23 in Lemma B.2, by the following term:

12R2δ2(L2dδ2(L1tnh)+O(L2)\mathbbEx^nxtn2)\lesssimL2dh4R2(L1tnh)+L2h2\mathbbEx^nxtn2.12R^{2}\delta^{2}\cdot\left(L^{2}d\delta^{2}(L\lor\frac{1}{t_{n}-h})+O(L^{2})\cdot\operatorname*{\mathbb{E}}\left\|\widehat{x}_{n}-x_{t_{n}}\right\|^{2}\right)\lesssim\frac{L^{2}dh^{4}}{R^{2}}(L\lor\frac{1}{t_{n}-h})+L^{2}h^{2}\operatorname*{\mathbb{E}}\left\|\widehat{x}_{n}-x_{t_{n}}\right\|^{2}.

Finally, we bound Equation 28. Since both αi\alpha_{i} and ss belong to the range [(i1)δ,iδ][(i-1)\delta,i\delta],

ehαihes(tnh)eh(e(i1)δeiδ)ehδ2δ.e^{h-\alpha_{i}h}-e^{s-(t_{n}-h)}\leq e^{h}(e^{-(i-1)\delta}-e^{-i\delta})\leq e^{h}\cdot\delta\leq 2\delta.

Moreover, by the fact that \mathbbElnqs(xtn)2Ld\operatorname*{\mathbb{E}}\left\|\nabla\ln q_{s}(x_{t_{n}})\right\|^{2}\leq Ld (by integration by parts), Lemma D.1, Lemma A.1 and the fact that Lδ=o(1)L\delta=o(1), we have

\mathbbElnqs(xn(tns))2\displaystyle\operatorname*{\mathbb{E}}\left\|\nabla\ln q_{s}(x_{n}^{*}(t_{n}-s))\right\|^{2} \lesssim\mathbbElnqs(xtn)2+\mathbbElnqs(xn(tns))lnqs(xtn)2\displaystyle\lesssim\operatorname*{\mathbb{E}}\left\|\nabla\ln q_{s}(x_{t_{n}})\right\|^{2}+\operatorname*{\mathbb{E}}\left\|\nabla\ln q_{s}(x_{n}^{*}(t_{n}-s))-\nabla\ln q_{s}(x_{t_{n}})\right\|^{2}
\lesssimLd+L2exp(Lδ)x^nxtn2\lesssimLd+L2x^nxtn2.\displaystyle\lesssim Ld+L^{2}\exp(L\delta)\left\|\widehat{x}_{n}-x_{t_{n}}\right\|^{2}\lesssim Ld+L^{2}\left\|\widehat{x}_{n}-x_{t_{n}}\right\|^{2}.

Hence Equation 28 can be bounded by

3\mathbbEi=1Rtniδtn(i1)δ(ehαihes(tnh))lnqs(xn(tns))ds2\displaystyle 3\operatorname*{\mathbb{E}}\left\|\sum_{i=1}^{R}\int_{t_{n}-i\delta}^{t_{n}-(i-1)\delta}\left(e^{h-\alpha_{i}h}-e^{s-(t_{n}-h)}\right)\cdot\nabla\ln q_{s}(x_{n}^{*}(t_{n}-s))\,\mathrm{d}s\right\|^{2}
\displaystyle\leq 3R\delta\sum_{i=1}^{R}\int_{t_{n}-i\delta}^{t_{n}-(i-1)\delta}\operatorname*{\mathbb{E}}\left\|2\delta\nabla\ln q_{s}(x_{n}^{*}(t_{n}-s))\right\|^{2}\,\mathrm{d}s
\displaystyle\lesssim 3R\delta\cdot 4\delta^{2}\sum_{i=1}^{R}\int_{t_{n}-i\delta}^{t_{n}-(i-1)\delta}\left(Ld+L^{2}\left\|\widehat{x}_{n}-x_{t_{n}}\right\|^{2}\right)\,\mathrm{d}s
\lesssim\displaystyle\lesssim R2δ4(Ld+L2x^nxtn2)=Ldh4R2+L2h4R2x^nxtn2.\displaystyle R^{2}\delta^{4}\left(Ld+L^{2}\left\|\widehat{x}_{n}-x_{t_{n}}\right\|^{2}\right)=\frac{Ldh^{4}}{R^{2}}+\frac{L^{2}h^{4}}{R^{2}}\left\|\widehat{x}_{n}-x_{t_{n}}\right\|^{2}.

By adding together Equation 26, Equation 27 and Equation 28, and combining terms, we conclude that

\displaystyle\operatorname*{\mathbb{E}}_{\alpha}\|\widehat{x}_{n+1}-x_{n}^{*}(h)\|^{2}\lesssim h^{2}\cdot\varepsilon_{\mathrm{sc}}^{2}+L^{4}h^{4}\cdot\operatorname*{\mathbb{E}}\left\|\widehat{x}_{n}-x_{t_{n}}\right\|^{2}+\frac{L^{4}dh^{6}}{R^{2}}\left(L\lor\frac{1}{t_{n}-h}\right)
+L2dh4R2(L1tnh)+L2h2\mathbbEx^nxtn2+Ldh4R2+L2h4R2x^nxtn2\displaystyle\quad+\frac{L^{2}dh^{4}}{R^{2}}(L\lor\frac{1}{t_{n}-h})+L^{2}h^{2}\operatorname*{\mathbb{E}}\left\|\widehat{x}_{n}-x_{t_{n}}\right\|^{2}+\frac{Ldh^{4}}{R^{2}}+\frac{L^{2}h^{4}}{R^{2}}\left\|\widehat{x}_{n}-x_{t_{n}}\right\|^{2}
\lesssimh2εsc2+L2dh4R2(L1tnh)+L2h2\mathbbEx^nxtn2.\displaystyle\lesssim h^{2}\cdot\varepsilon_{\mathrm{sc}}^{2}+\frac{L^{2}dh^{4}}{R^{2}}(L\lor\frac{1}{t_{n}-h})+L^{2}h^{2}\operatorname*{\mathbb{E}}\left\|\widehat{x}_{n}-x_{t_{n}}\right\|^{2}\,.\qed

We can now prove our main guarantee for the parallel predictor step, which states that with logarithmically many parallel rounds and O~(d)\widetilde{O}(\sqrt{d}) score estimate queries over a short time interval (of length O(1/L)O(1/L)) of the reverse process, the algorithm does not drift too far from the true ODE. Our proof follows the same outline as that of Lemma A.5.

Note that because each step uses RnR_{n} midpoints, the step size hnh_{n} can be taken to be of constant order, independent of dd and ε\varepsilon. We will set hnh_{n} to be Θ(1L)\Theta(\frac{1}{L}), unless tnt_{n} is close to the end time δ\delta (see Algorithm 9 for the global algorithm and timeline). If tnt_{n} is close to δ\delta, we repeatedly halve hnh_{n} as nn increases, until we reach the end time.

Theorem B.6.

Assume L1L\geq 1. Let β1\beta\geq 1 be an adjustable parameter. When we set hn=min{14L,tn/2,tnδ}h_{n}=\min\{\frac{1}{4L},t_{n}/2,t_{n}-\delta\}, K\gtrsimlog(βdε)K\gtrsim\log(\frac{\beta\sqrt{d}}{\varepsilon}) and Rn=hnβLdε=O(βdε)R_{n}=h_{n}\cdot\beta\cdot\frac{L\sqrt{d}}{\varepsilon}=O\left(\frac{\beta\sqrt{d}}{\varepsilon}\right), the Wasserstein distance between the true ODE process and the process in Algorithm 7, both starting from x^0qt0\hat{x}_{0}\sim q_{t_{0}} and run for total time TN\lesssim1LT_{N}\lesssim\frac{1}{L}, is bounded by

W2(x^N,xTN)\lesssimεscL+εβL.W_{2}(\widehat{x}_{N},x_{T_{N}})\lesssim\frac{\varepsilon_{\mathrm{sc}}}{L}+\frac{\varepsilon}{\beta\sqrt{L}}.
Proof.

To avoid confusion, we will reserve x^n\widehat{x}_{n} for the result of running Algorithm 7 for nn steps, starting at x^0qt0\widehat{x}_{0}\sim q_{t_{0}}. We will use yny_{n} to denote the result of running the true ODE process for one step, starting at x^n1\widehat{x}_{n-1}. Then, by a calculation identical to that in Lemma A.5,

\mathbbEαxtNx^N2exp(O(LhN1))xtN1x^N12+2LhN1\mathbbEαx^NyN2+\mathbbEαx^NyN2.\displaystyle\operatorname*{\mathbb{E}}_{\alpha}\left\|x_{t_{N}}-\widehat{x}_{N}\right\|^{2}\leq\exp\left(O(Lh_{N-1})\right)\left\|x_{t_{N-1}}-\widehat{x}_{N-1}\right\|^{2}+\frac{2}{Lh_{N-1}}\left\|\operatorname*{\mathbb{E}}_{\alpha}\widehat{x}_{N}-y_{N}\right\|^{2}+\operatorname*{\mathbb{E}}_{\alpha}\left\|\widehat{x}_{N}-y_{N}\right\|^{2}.

Now we carry out a calculation similar to that in Lemma A.5, but using the bias and variance of one step of the parallel algorithm instead of the sequential one. Taking the expectation with respect to x^0qt0\widehat{x}_{0}\sim q_{t_{0}}, by Lemmas B.4 and B.5,

\mathbbExtNx^N2\displaystyle\operatorname*{\mathbb{E}}\|x_{t_{N}}-\widehat{x}_{N}\|^{2}
exp(O(LhN1))\mathbbExtN1x^N12+2LhN1\mathbbE\mathbbEαx^NyN2+\mathbbEx^NyN2\displaystyle\leq\exp\left(O(Lh_{N-1})\right)\operatorname*{\mathbb{E}}\|x_{t_{N-1}}-\widehat{x}_{N-1}\|^{2}+\frac{2}{Lh_{N-1}}\operatorname*{\mathbb{E}}\|\operatorname*{\mathbb{E}}_{\alpha}\widehat{x}_{N}-y_{N}\|^{2}+\operatorname*{\mathbb{E}}\|\widehat{x}_{N}-y_{N}\|^{2}
exp(O(LhN1))\mathbbExtN1x^N12\displaystyle\leq\exp\left(O(Lh_{N-1})\right)\operatorname*{\mathbb{E}}\|x_{t_{N-1}}-\widehat{x}_{N-1}\|^{2}
\displaystyle\qquad+O\left(\frac{1}{Lh_{N-1}}\left(h_{N-1}^{2}\varepsilon_{\mathrm{sc}}^{2}+\frac{L^{4}dh_{N-1}^{6}}{R_{N-1}^{2}}\left(L\lor\frac{1}{t_{N-1}}\right)+L^{4}h_{N-1}^{4}\operatorname*{\mathbb{E}}\|x_{t_{N-1}}-\widehat{x}_{N-1}\|^{2}\right)\right)
\displaystyle\qquad+O\left(h_{N-1}^{2}\varepsilon_{\mathrm{sc}}^{2}+\frac{L^{2}dh_{N-1}^{4}}{R_{N-1}^{2}}\left(L\lor\frac{1}{t_{N-1}}\right)+L^{2}h_{N-1}^{2}\operatorname*{\mathbb{E}}\|x_{t_{N-1}}-\widehat{x}_{N-1}\|^{2}\right)
exp(O(LhN1))\mathbbExtN1x^N12\displaystyle\leq\exp\left(O(Lh_{N-1})\right)\operatorname*{\mathbb{E}}\|x_{t_{N-1}}-\widehat{x}_{N-1}\|^{2}
\displaystyle\qquad+O\left(\frac{h_{N-1}\varepsilon_{\mathrm{sc}}^{2}}{L}+h_{N-1}^{2}\varepsilon_{\mathrm{sc}}^{2}+\frac{L^{3}dh_{N-1}^{5}+L^{2}dh_{N-1}^{4}}{R_{N-1}^{2}}\left(L\lor\frac{1}{t_{N-1}}\right)\right)

Since hN1<1Lh_{N-1}<\frac{1}{L}, the term L3dhN15L^{3}dh_{N-1}^{5} is dominated by the term L2dhN14L^{2}dh_{N-1}^{4}, and the term hN12εsc2h_{N-1}^{2}\varepsilon_{\mathrm{sc}}^{2} is dominated by hN1εsc2L\frac{h_{N-1}\varepsilon_{\mathrm{sc}}^{2}}{L}. By induction, noting that xt0=x^0x_{t_{0}}=\widehat{x}_{0}, we have

\mathbbExtNx^N2\lesssimn=0N1(hnLεsc2+L2dhn4Rn2(L1tn))exp(O(Li=n+1N1hi)).\operatorname*{\mathbb{E}}\|x_{t_{N}}-\widehat{x}_{N}\|^{2}\lesssim\sum_{n=0}^{N-1}\left(\frac{h_{n}}{L}\varepsilon_{\mathrm{sc}}^{2}+\frac{L^{2}dh_{n}^{4}}{R_{n}^{2}}\cdot\left(L\lor\frac{1}{t_{n}}\right)\right)\cdot\exp\left(O\left(L\sum_{i=n+1}^{N-1}h_{i}\right)\right)\,.

Since i=n+1N1hi1L\sum_{i=n+1}^{N-1}h_{i}\leq\frac{1}{L}, the factor exp(O(Li=n+1N1hi))\exp\left(O\left(L\sum_{i=n+1}^{N-1}h_{i}\right)\right) is a constant. Moreover, by our choice of hnh_{n}, we always have hntn2tnhnh_{n}\leq\frac{t_{n}}{2}\leq t_{n}-h_{n} and hn1Lh_{n}\leq\frac{1}{L}. Therefore

L2dhn4Rn2(L1tnhn)L2dhn3Rn2L2dhnβ2L2dε2=hnε2β2,\displaystyle\frac{L^{2}dh_{n}^{4}}{R_{n}^{2}}\left(L\lor\frac{1}{t_{n}-h_{n}}\right)\leq\frac{L^{2}dh_{n}^{3}}{R_{n}^{2}}\leq\frac{L^{2}dh_{n}}{\frac{\beta^{2}L^{2}d}{\varepsilon^{2}}}=\frac{h_{n}\varepsilon^{2}}{\beta^{2}},

and thus

\displaystyle\operatorname*{\mathbb{E}}\|x_{t_{N}}-\widehat{x}_{N}\|^{2}\lesssim\sum_{n=0}^{N-1}\left(\frac{h_{n}}{L}\varepsilon_{\mathrm{sc}}^{2}+\frac{h_{n}\varepsilon^{2}}{\beta^{2}}\right)\lesssim\frac{\varepsilon_{\mathrm{sc}}^{2}}{L^{2}}+\frac{\varepsilon^{2}}{L\beta^{2}}.

We conclude that

W2(xtN,x^N)\displaystyle W_{2}(x_{t_{N}},\widehat{x}_{N}) =\mathbbExtNx^N2\lesssimεscL+εβL.\displaystyle=\sqrt{\operatorname*{\mathbb{E}}\|x_{t_{N}}-\widehat{x}_{N}\|^{2}}\lesssim\frac{\varepsilon_{\mathrm{sc}}}{L}+\frac{\varepsilon}{\beta\sqrt{L}}.\qed

B.2 Corrector step

In this step we use the parallel algorithm of [ACV24] to approximate the underdamped Langevin diffusion process. Since the score function is fixed in time throughout this step, we will use lnq\nabla\ln q to denote the true score function for the diffusion process, and s^\widehat{s} to denote the estimated score function. We will choose the friction parameter to be γL\gamma\asymp\sqrt{L}.

Algorithm 8 Corrector Step (Parallel) [ACV24]

Input parameters:

  • Starting sample (x^0,v^0)p𝒩(0,Id)(\widehat{x}_{0},\widehat{v}_{0})\sim p\otimes\mathcal{N}(0,I_{d}), Number of steps NN, Step size hh, Score estimates s^lnq\widehat{s}\approx\nabla\ln q, Number of midpoint estimates RR, δ:=hR\delta:=\frac{h}{R}

  1. 1.

    For n=0,,N1n=0,\dots,N-1:

    1. (a)

      Let tn=nht_{n}=nh

    2. (b)

      Let (x^n,i(k),v^n,i(k))(\widehat{x}^{(k)}_{n,i},\widehat{v}^{(k)}_{n,i}) represent the algorithmic estimate of (xtn+ih/R,vtn+ih/R)(x_{t_{n}+ih/R},v_{t_{n}+ih/R}) at iteration kk.

    3. (c)

      Let (ζx,ζv)(\zeta^{x},\zeta^{v}) be a correlated Gaussian vector corresponding to the change caused by the Brownian motion term over h/Rh/R time (see [ACV24] for more detail)

    4. (d)

      For i=0,,Ri=0,\cdots,R in parallel: Let (x^n,i(0),v^n,i(0))=(x^n,v^n)(\widehat{x}^{(0)}_{n,i},\widehat{v}^{(0)}_{n,i})=(\widehat{x}_{n},\widehat{v}_{n})

    5. (e)

      For k=1,,Kk=1,\cdots,K:

      For i=1,,Ri=1,\cdots,R in parallel:

      \widehat{x}^{(k)}_{n,i}:=\widehat{x}^{(k-1)}_{n,i-1}+\frac{1-\exp(-\gamma h/R)}{\gamma}\cdot\widehat{v}^{(k-1)}_{n,i-1}-\frac{h/R-(1-\exp(-\gamma h/R))/\gamma}{\gamma}\cdot\widehat{s}(\widehat{x}^{(k-1)}_{n,i-1})+\zeta^{x}

      \widehat{v}^{(k)}_{n,i}=\exp(-\gamma h/R)\cdot\widehat{v}^{(k-1)}_{n,i-1}-\frac{1-\exp(-\gamma h/R)}{\gamma}\cdot\widehat{s}(\widehat{x}^{(k-1)}_{n,i-1})+\zeta^{v}

    6. (f)

      (x^n+1,v^n+1)=(x^n,RK,v^n,RK)(\widehat{x}_{n+1},\widehat{v}_{n+1})=(\widehat{x}^{K}_{n,R},\widehat{v}^{K}_{n,R})

  2. 2.

    Let tN=Nht_{N}=Nh

  3. 3.

    Return x^N,tN\widehat{x}_{N},t_{N}.
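For intuition, the following sketch implements the basic building block that Algorithm 8 parallelizes: one exponential-integrator substep of the underdamped Langevin diffusion $\mathrm{d}x=v\,\mathrm{d}t$, $\mathrm{d}v=-\gamma v\,\mathrm{d}t+\widehat{s}(x)\,\mathrm{d}t+\sqrt{2\gamma}\,\mathrm{d}B_t$ over a window of length $h/R$, with the score frozen at the current point. We write the score with the positive sign corresponding to this SDE (whose stationary distribution is $q\otimes\mathcal{N}(0,I_d)$ when $\widehat{s}=\nabla\ln q$) and draw $(\zeta^{x},\zeta^{v})$ with the standard per-substep covariance of this integrator; the exact correlated construction across the $R$ midpoints and the $K$ Picard rounds is the one specified in [ACV24], and all names below are ours.

import numpy as np

def uld_substep(x, v, score_val, gamma, tau, rng):
    """One exponential-integrator substep of underdamped Langevin dynamics
        dx = v dt,   dv = -gamma * v dt + s(x) dt + sqrt(2 * gamma) dB,
    over a window of length tau, with the score frozen at score_val."""
    d = x.shape[0]
    eta = np.exp(-gamma * tau)

    # Deterministic part (the score enters with a positive sign in this convention).
    x_det = x + (1 - eta) / gamma * v + (tau - (1 - eta) / gamma) / gamma * score_val
    v_det = eta * v + (1 - eta) / gamma * score_val

    # Joint Gaussian law of the Brownian contributions to (x, v) over the window.
    var_v = 1 - eta ** 2
    var_x = (2 / gamma) * (tau - 2 * (1 - eta) / gamma + (1 - eta ** 2) / (2 * gamma))
    cov_xv = (1 - eta) ** 2 / gamma
    chol = np.linalg.cholesky(np.array([[var_x, cov_xv], [cov_xv, var_v]]))
    zeta_x, zeta_v = chol @ rng.standard_normal((2, d))

    return x_det + zeta_x, v_det + zeta_v

A sequential corrector would simply apply this substep RR times per outer step; Algorithm 8 instead reconstructs all RR intermediate states with logarithmically many rounds of parallel Picard iteration, reusing the same Gaussian draws across rounds as specified in [ACV24].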

Let TNT_{N} denote the total time the parallel corrector step is run (namely, TN=NhT_{N}=Nh). Consider two continuous underdamped Langevin diffusion processes u(t)=(x(t),v(t))u^{*}(t)=(x^{*}(t),v^{*}(t)) and ut0+t=(xt0+t,vt0+t)u_{t_{0}+t}=(x_{t_{0}+t},v_{t_{0}+t}) with coupled Brownian motions. The first starts from position x(0)=x^0px^{*}(0)=\widehat{x}_{0}\sim p and the second starts from position xt0qx_{t_{0}}\sim q. Both processes start with velocity v(0)=vt0𝒩(0,Id)v^{*}(0)=v_{t_{0}}\sim\mathcal{N}(0,I_{d}). We will bound both the distance between x(t)x^{*}(t) and the true sample xt0+tx_{t_{0}+t}, and the distance between x(TN)x^{*}(T_{N}) and the output of Algorithm 8. First, [CCL+23a] gives the following bound on the total variation error between x(TN)x^{*}(T_{N}) and xtNx_{t_{N}}.

Lemma B.7 ([CCL+23a], Lemma 9).

If h\lesssim1Lh\lesssim\frac{1}{\sqrt{L}}, then

TV(x(TN),xtN)\lesssimW2(p,q)L1/4TN3/2.\textup{{TV}}(x^{*}(T_{N}),x_{t_{N}})\lesssim\frac{W_{2}(p,q)}{L^{1/4}T_{N}^{3/2}}\,.

Next, [ACV24] bounds the discretization error in Algorithm 8 in terms of quantities that relate to the supremum of \mathbbElnq(x(t))2\operatorname*{\mathbb{E}}\left\|\nabla\ln q(x^{*}(t))\right\|^{2} and \mathbbEv(t)2\operatorname*{\mathbb{E}}\left\|v^{*}(t)\right\|^{2} over t[0,TN]t\in[0,T_{N}].

Lemma B.8 ([ACV24], Theorem 20, Implicit).

Assume L1L\geq 1. In Algorithm 8, assume K\gtrsimlog(d)K\gtrsim\log(d) (for a sufficiently large constant), K\lesssim4logRK\lesssim 4\log R and h\lesssim1Lh\lesssim\frac{1}{\sqrt{L}}. Then

KL(x^N,x(TN))\lesssimTNL(εsc2+L2(γdh3R4+h2R2𝒫+h4R4𝒬)),\displaystyle\textup{{KL}}(\widehat{x}_{N},x^{*}(T_{N}))\lesssim\frac{T_{N}}{\sqrt{L}}\cdot\left(\varepsilon_{\mathrm{sc}}^{2}+L^{2}(\frac{\gamma dh^{3}}{R^{4}}+\frac{h^{2}}{R^{2}}{\cal P}+\frac{h^{4}}{R^{4}}{\cal Q})\right),

where 𝒫=supt[0,TN]\mathbbE[v(t)2]{\cal P}=\sup_{t\in[0,T_{N}]}\operatorname*{\mathbb{E}}[\left\|v^{*}(t)\right\|^{2}] and 𝒬=supt[0,TN]\mathbbE[lnq(x(t))2]{\cal Q}=\sup_{t\in[0,T_{N}]}\operatorname*{\mathbb{E}}[\left\|\nabla\ln q(x^{*}(t))\right\|^{2}].

To reason about the values of 𝒫{\cal P} and 𝒬{\cal Q}, we will use the following lemma from [CCL+23a].

Lemma B.9 ([CCL+23a], Lemma 10).

For any t\lesssim1Lt\lesssim\frac{1}{\sqrt{L}},

\mathbbEu(t)ut0+t2\lesssimW22(p,q).\operatorname*{\mathbb{E}}\left\|u^{*}(t)-u_{t_{0}+t}\right\|^{2}\lesssim W_{2}^{2}(p,q).
Lemma B.10.

Assume L1L\geq 1. For any TN\lesssim1LT_{N}\lesssim\frac{1}{\sqrt{L}},

supt[0,TN]\mathbbE[lnq(x(t))2]\lesssimL2W22(p,q)+Ld\sup_{t\in[0,T_{N}]}\operatorname*{\mathbb{E}}[\left\|\nabla\ln q(x^{*}(t))\right\|^{2}]\lesssim L^{2}W_{2}^{2}(p,q)+Ld

and

supt[0,TN]\mathbbEv(t)2\lesssimW22(p,q)+d.\sup_{t\in[0,T_{N}]}\operatorname*{\mathbb{E}}\left\|v^{*}(t)\right\|^{2}\lesssim W_{2}^{2}(p,q)+d.
Proof.

Note that (q,𝒩(0,Id))(q,\mathcal{N}(0,I_{d})) is a stationary distribution of the underdamped Langevin diffusion process, so xtqx_{t}\sim q and vt𝒩(0,Id)v_{t}\sim\mathcal{N}(0,I_{d}) for all tt. Hence \mathbbElnq(xt)2Ld\operatorname*{\mathbb{E}}\left\|\nabla\ln q(x_{t})\right\|^{2}\leq Ld by integration by parts. Similarly, \mathbbEvt2=\mathbbE[𝒩(0,Id)2]\lesssimd\operatorname*{\mathbb{E}}\left\|v_{t}\right\|^{2}=\operatorname*{\mathbb{E}}[\left\|\mathcal{N}(0,I_{d})\right\|^{2}]\lesssim d. Since TN\lesssim1LT_{N}\lesssim\frac{1}{\sqrt{L}}, for any t[0,TN]t\in[0,T_{N}], we can now bound \mathbbElnq(x(t))2\operatorname*{\mathbb{E}}\left\|\nabla\ln q(x^{*}(t))\right\|^{2} and \mathbbEv(t)2\operatorname*{\mathbb{E}}\left\|v^{*}(t)\right\|^{2} by Lemma B.9 as follows:

\mathbbElnq(x(t))2\displaystyle\operatorname*{\mathbb{E}}\left\|\nabla\ln q(x^{*}(t))\right\|^{2} 2\mathbbElnq(xt)2+2\mathbbElnq(x(t))lnq(xt)2\displaystyle\leq 2\cdot\operatorname*{\mathbb{E}}\left\|\nabla\ln q(x_{t})\right\|^{2}+2\cdot\operatorname*{\mathbb{E}}\left\|\nabla\ln q(x^{*}(t))-\nabla\ln q(x_{t})\right\|^{2}
\lesssimLd+L2\mathbbEx(t)xt2\lesssimLd+L2\mathbbEu(t)ut2\lesssimLd+L2W22(p,q),\displaystyle\lesssim Ld+L^{2}\operatorname*{\mathbb{E}}\left\|x^{*}(t)-x_{t}\right\|^{2}\lesssim Ld+L^{2}\operatorname*{\mathbb{E}}\left\|u^{*}(t)-u_{t}\right\|^{2}\lesssim Ld+L^{2}W_{2}^{2}(p,q),

and

\mathbbEv(t)2\displaystyle\operatorname*{\mathbb{E}}\left\|v^{*}(t)\right\|^{2} 2\mathbbEvt2+2\mathbbEv(t)vt2\displaystyle\leq 2\cdot\operatorname*{\mathbb{E}}\left\|v_{t}\right\|^{2}+2\cdot\operatorname*{\mathbb{E}}\left\|v^{*}(t)-v_{t}\right\|^{2}
\lesssim\mathbbEvt2+\mathbbEu(t)ut2\lesssimd+W22(p,q).\displaystyle\lesssim\operatorname*{\mathbb{E}}\left\|v_{t}\right\|^{2}+\operatorname*{\mathbb{E}}\left\|u^{*}(t)-u_{t}\right\|^{2}\lesssim d+W_{2}^{2}(p,q).\qed
Theorem B.11.

Let β1\beta\geq 1 be an adjustable parameter. Algorithm 8 with parameters h=18Lh=\frac{1}{\sqrt{8L}}, R=βΘ(dε)R=\beta\cdot\Theta(\frac{\sqrt{d}}{\varepsilon}), K=4log(R)K=4\cdot\log(R) and TN\lesssim1LT_{N}\lesssim\frac{1}{\sqrt{L}} has discretization error

TV(x^N,x(TN))\lesssimKL(x^N,x(TN))\lesssimεscL+εβ+εβdW2(p,q).\textup{{TV}}(\widehat{x}_{N},x^{*}(T_{N}))\lesssim\sqrt{\textup{{KL}}(\widehat{x}_{N},x^{*}(T_{N}))}\lesssim\frac{\varepsilon_{\mathrm{sc}}}{\sqrt{L}}+\frac{\varepsilon}{\beta}+\frac{\varepsilon}{\beta\sqrt{d}}\cdot W_{2}(p,q).
Proof.

Since TN\lesssim1LT_{N}\lesssim\frac{1}{\sqrt{L}} and h=Θ(1L)h=\Theta(\frac{1}{\sqrt{L}}), N=O(1)N=O(1). Plugging Lemma B.10 into Lemma B.8, we get that

KL(x^N,x(TN))\displaystyle\textup{{KL}}(\widehat{x}_{N},x^{*}(T_{N})) \lesssimTNL(εsc2+L2(γdh3R4+h2R2(d+W22(p,q))+h4R4(L2W22(p,q)+Ld)))\displaystyle\lesssim\frac{T_{N}}{\sqrt{L}}\cdot\left(\varepsilon_{\mathrm{sc}}^{2}+L^{2}\left(\frac{\gamma dh^{3}}{R^{4}}+\frac{h^{2}}{R^{2}}\cdot(d+W_{2}^{2}(p,q))+\frac{h^{4}}{R^{4}}\cdot(L^{2}W_{2}^{2}(p,q)+Ld)\right)\right)
\lesssim1L(εsc2+dR4+LdR2+LdR4+(LR2+L2R4)W22(p,q))\displaystyle\lesssim\frac{1}{L}\cdot\left(\varepsilon_{\mathrm{sc}}^{2}+\frac{d}{R^{4}}+\frac{Ld}{R^{2}}+\frac{Ld}{R^{4}}+\left(\frac{L}{R^{2}}+\frac{L^{2}}{R^{4}}\right)\cdot W_{2}^{2}(p,q)\right)
=εsc2L+ε2β2+ε2β2dW22(p,q).\displaystyle=\frac{\varepsilon_{\mathrm{sc}}^{2}}{L}+\frac{\varepsilon^{2}}{\beta^{2}}+\frac{\varepsilon^{2}}{\beta^{2}d}W_{2}^{2}(p,q).

The first to second line is by combining terms and setting h=Θ(1L)h=\Theta(\frac{1}{\sqrt{L}}), γ=Θ(L)\gamma=\Theta(\sqrt{L}), and the second to third line is by setting R=βΘ(dε)R=\beta\cdot\Theta(\frac{\sqrt{d}}{\varepsilon}). Taking the square root of KL(x^N,x(TN))\textup{{KL}}(\widehat{x}_{N},x^{*}(T_{N})) yields the claim. ∎

Theorem B.12.

Let β1\beta\geq 1 be an adjustable parameter. When Algorithm 8 is initialized at (x^0,v^0)p𝒩(0,Id)(\widehat{x}_{0},\widehat{v}_{0})\sim p\otimes\mathcal{N}(0,I_{d}), there exist parameters h=18Lh=\frac{1}{\sqrt{8L}}, R=βΘ(dε)R=\beta\cdot\Theta(\frac{\sqrt{d}}{\varepsilon}), K=Θ(log(β2dε2))K=\Theta(\log(\frac{\beta^{2}d}{\varepsilon^{2}})) and TN\lesssim1LT_{N}\lesssim\frac{1}{\sqrt{L}} such that the total variation distance between the final output of Algorithm 8 and the true distribution can be bounded as

TV(x^N,xtN)\lesssimεscL+εβ+LW2(p,q).\displaystyle\textup{{TV}}(\widehat{x}_{N},x_{t_{N}})\lesssim\frac{\varepsilon_{\mathrm{sc}}}{\sqrt{L}}+\frac{\varepsilon}{\beta}+\sqrt{L}\cdot W_{2}(p,q).
Proof.

By the triangle inequality, TV(x^N,xtN)TV(x^N,x(TN))+TV(x(TN),xtN)\textup{{TV}}(\widehat{x}_{N},x_{t_{N}})\leq\textup{{TV}}(\widehat{x}_{N},x^{*}(T_{N}))+\textup{{TV}}(x^{*}(T_{N}),x_{t_{N}}). Combining Lemma B.7 and Theorem B.11 yields the claim. ∎

B.3 End-to-end analysis

Algorithm 9 ParallelAlgorithm

Input parameters:

  • Start time TT, End time δ\delta, Corrector step time Tcorr\lesssim1/LT_{\mathrm{corr}}\lesssim 1/\sqrt{L}, Number of predictor-corrector steps N0N_{0}, Score estimates s^t\widehat{s}_{t}

  1. 1.

    Draw x^0𝒩(0,Id)\widehat{x}_{0}\sim\mathcal{N}(0,I_{d}).

  2. 2.

    For n=0,,N0n=0,\dots,N_{0}:

    1. (a)

      Starting from x^n\widehat{x}_{n}, run Algorithm 7 with starting time Tn/LT-n/L and total time min(1/L,Tn/Lδ)\min(1/L,T-n/L-\delta). Let the result be x^n+1\widehat{x}_{n+1}^{\prime}.

    2. (b)

      Starting from x^n+1\widehat{x}_{n+1}^{\prime}, run Algorithm 8 for total time TcorrT_{\mathrm{corr}} with score estimate s^T(n+1)/L\widehat{s}_{T-(n+1)/L} to obtain x^n+1\widehat{x}_{n+1}.

  3. 3.

    Return x^N0+1\widehat{x}_{N_{0}+1}.
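For readability, the outer loop of Algorithm 9 can be organized as in the following sketch. The callables predictor and corrector stand in for Algorithm 7 and Algorithm 8, respectively (their signatures here are our own), and the start time and duration of each predictor call follow step (a) of the loop above.

import numpy as np

def parallel_algorithm(predictor, corrector, d, T, delta, L, N0, rng):
    """Outer predictor-corrector loop of Algorithm 9.

    predictor(x, t_start, t_total) : one call to Algorithm 7 (parallel ODE predictor)
    corrector(x, t_score)          : one call to Algorithm 8 with the score frozen at time t_score
    T, delta                       : start and end times of the reverse process
    N0                             : number of predictor-corrector rounds
    """
    x = rng.standard_normal(d)                    # x_hat_0 ~ N(0, I_d)
    for n in range(N0 + 1):
        t_start = T - n / L                        # predictor start time, as in step (a)
        t_total = min(1.0 / L, t_start - delta)    # predictor duration, as in step (a)
        x = predictor(x, t_start, t_total)         # Algorithm 7
        x = corrector(x, t_start - t_total)        # Algorithm 8 with the score at the new time
    return x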

Theorem B.13 (Parallel End to End Error).

By setting T=Θ(log(d\mathfrakm22ε2))T=\Theta\left(\log\left(\frac{d\lor\mathfrak{m}_{2}^{2}}{\varepsilon^{2}}\right)\right), Tcorr=1LT_{corr}=\frac{1}{\sqrt{L}}, δ=Θ(ε2L2(d\mathfrakm22))\delta=\Theta\left(\frac{\varepsilon^{2}}{L^{2}(d\lor\mathfrak{m}_{2}^{2})}\right), and β1=β2=Θ(Llog(d\mathfrakm22ε2))\beta_{1}=\beta_{2}=\Theta\left(L\log\left(\frac{d\lor\mathfrak{m}_{2}^{2}}{\varepsilon^{2}}\right)\right) in Algorithm 7 and Algorithm 8, when εsc\lesssimΘ~(εL)\varepsilon_{\mathrm{sc}}\lesssim\widetilde{\Theta}(\frac{\varepsilon}{\sqrt{L}}), the total variation distance between the output of Algorithm 9 and the target distribution x0qx_{0}\sim q^{*} is

TV(x^N0+1,x0)\lesssimε,\textup{{TV}}(\widehat{x}_{N_{0}+1},x_{0})\lesssim\varepsilon,

with iteration complexity Θ~(Llog2(Ld\mathfrakm22ε))\widetilde{\Theta}(L\cdot\log^{2}\left(\frac{Ld\lor\mathfrak{m}_{2}^{2}}{\varepsilon}\right)).

Proof.

Let xtnx_{t_{n}} be the result of running the true ODE for time TtnT-t_{n}, starting from xTqx_{T}\sim q^{*}. Let yny^{\prime}_{n} be the result of running the predictor step in step n1n-1 of Algorithm 9, starting from xtn1qtn1x_{t_{n-1}}\sim q_{t_{n-1}} with start time tn1t_{n-1}. In addition, let y^n\widehat{y}_{n} be the result of the corrector step in step n1n-1 of Algorithm 9, starting from yny^{\prime}_{n}.

We will first bound the error in one predictor + corrector step that starts at tn1=T(n1)/Lt_{n-1}=T-(n-1)/L. By the triangle inequality for TV distance and the data processing inequality (applied to x^n\widehat{x}_{n} and y^n\widehat{y}_{n}),

TV(x^n,xtn)\displaystyle\textup{{TV}}(\widehat{x}_{n},x_{t_{n}}) TV(x^n,y^n)+TV(y^n,xtn)\displaystyle\leq\textup{{TV}}(\widehat{x}_{n},\widehat{y}_{n})+\textup{{TV}}(\widehat{y}_{n},x_{t_{n}})
TV(x^n1,xtn1)+TV(y^n,xtn)\displaystyle\leq\textup{{TV}}(\widehat{x}_{n-1},x_{t_{n-1}})+\textup{{TV}}(\widehat{y}_{n},x_{t_{n}}) (29)

By Theorem B.6 parametrized by β1\beta_{1} and Theorem B.12 parametrized by β2\beta_{2},

TV(y^n,xtn)\displaystyle\textup{{TV}}(\widehat{y}_{n},x_{t_{n}}) \lesssimεscL+εβ2+LW2(yn,xtn)\displaystyle\lesssim\frac{\varepsilon_{\mathrm{sc}}}{\sqrt{L}}+\frac{\varepsilon}{\beta_{2}}+\sqrt{L}\cdot W_{2}(y^{\prime}_{n},x_{t_{n}})
\lesssimεscL+εβ2+L(εscL+εβ1L)\displaystyle\lesssim\frac{\varepsilon_{\mathrm{sc}}}{\sqrt{L}}+\frac{\varepsilon}{\beta_{2}}+\sqrt{L}\left(\frac{\varepsilon_{\mathrm{sc}}}{L}+\frac{\varepsilon}{\beta_{1}\sqrt{L}}\right)
\lesssimεscL+εmin(β1,β2).\displaystyle\lesssim\frac{\varepsilon_{\mathrm{sc}}}{\sqrt{L}}+\frac{\varepsilon}{\min(\beta_{1},\beta_{2})}.

The first line is by Theorem B.12, and the first to second line is by Theorem B.6. Next, note that at the beginning of the process, t0=Tt_{0}=T, and at the end of the process, tN0+1=δt_{N_{0}+1}=\delta. By induction on Equation 29,

TV(x^N0+1,x0)\displaystyle\textup{{TV}}(\widehat{x}_{N_{0}+1},x_{0}) TV(x0,xtN0+1)+TV(x^N0+1,xtN0+1)\displaystyle\leq\textup{{TV}}(x_{0},x_{t_{N_{0}+1}})+\textup{{TV}}(\widehat{x}_{N_{0}+1},x_{t_{N_{0}+1}})
TV(x0,xδ)+TV(xT,𝒩(0,Id))+n=1N0+1TV(y^n,xtn)\displaystyle\leq\textup{{TV}}(x_{0},x_{\delta})+\textup{{TV}}(x_{T},\mathcal{N}(0,I_{d}))+\sum_{n=1}^{N_{0}+1}\textup{{TV}}(\widehat{y}_{n},x_{t_{n}})
TV(x0,xδ)+TV(xT,𝒩(0,Id))+N0(εscL+εmin(β1,β2)).\displaystyle\leq\textup{{TV}}(x_{0},x_{\delta})+\textup{{TV}}(x_{T},\mathcal{N}(0,I_{d}))+N_{0}\cdot\left(\frac{\varepsilon_{\mathrm{sc}}}{\sqrt{L}}+\frac{\varepsilon}{\min(\beta_{1},\beta_{2})}\right).

By Lemma A.9, TV(xT,𝒩(0,Id))\lesssim(d+\mathfrakm2)exp(T)\textup{{TV}}(x_{T},\mathcal{N}(0,I_{d}))\lesssim(\sqrt{d}+\mathfrak{m}_{2})\exp(-T). By [LLT23, Lemma 6.4], TV(x0,xδ)ε\textup{{TV}}(x_{0},x_{\delta})\leq\varepsilon. Therefore by setting T=Θ(log(d\mathfrakm22ε2))T=\Theta\left(\log\left(\frac{d\lor\mathfrak{m}_{2}^{2}}{\varepsilon^{2}}\right)\right), N0=Θ(Llog(d\mathfrakm22ε2))N_{0}=\Theta\left(L\log\left(\frac{d\lor\mathfrak{m}_{2}^{2}}{\varepsilon^{2}}\right)\right) and β1=β2=Θ(Llog(d\mathfrakm22ε2))\beta_{1}=\beta_{2}=\Theta\left(L\log\left(\frac{d\lor\mathfrak{m}_{2}^{2}}{\varepsilon^{2}}\right)\right) in Algorithm 7 and Algorithm 8, when εsc\lesssimΘ~(εL)\varepsilon_{\mathrm{sc}}\lesssim\widetilde{\Theta}(\frac{\varepsilon}{\sqrt{L}}), we obtain TV(x^N0+1,x0)\lesssimε\textup{{TV}}(\widehat{x}_{N_{0}+1},x_{0})\lesssim\varepsilon.

The iteration complexity of Algorithm 9 with the above parameters is roughly the number of predictor-corrector steps times the iteration complexity of one predictor-corrector step. Note that in any corrector step and any predictor step except the last one, only N=O(1)N=O(1) sub-steps are taken, so the iteration complexity of one predictor step (except the last step) is Θ(log(β1dε))\Theta(\log(\frac{\beta_{1}\sqrt{d}}{\varepsilon})) and the iteration complexity of one corrector step is Θ(log(β22dε2))\Theta(\log(\frac{\beta_{2}^{2}d}{\varepsilon^{2}})). In the last predictor step, the number of steps taken is O(log(1δL))=O(log(L)+T)O\left(\log\left(\frac{1}{\delta L}\right)\right)=O(\log(L)+T), and thus the iteration complexity is Θ((log(L)+T)log(β1dε))\Theta((\log(L)+T)\cdot\log(\frac{\beta_{1}\sqrt{d}}{\varepsilon})). We conclude that the total iteration complexity of Algorithm 9 is

LT(Θ(log(β1dε))+Θ(log(β22dε2)))+Θ((log(L)+T)log(β1dε))=Θ~(Llog2(Ld\mathfrakm22ε)).LT\cdot\left(\Theta(\log(\frac{\beta_{1}\sqrt{d}}{\varepsilon}))+\Theta(\log(\frac{\beta_{2}^{2}d}{\varepsilon^{2}}))\right)+\Theta\left((\log(L)+T)\cdot\log\left(\frac{\beta_{1}\sqrt{d}}{\varepsilon}\right)\right)=\widetilde{\Theta}\left(L\log^{2}\left(\frac{Ld\lor\mathfrak{m}_{2}^{2}}{\varepsilon}\right)\right)\,.

Appendix C Log-concave sampling in total variation

In this section, we give a simple proof, using our observation about trading off the time spent on the predictor and corrector steps, of an improved bound for sampling from a log-concave distribution in total variation. Note that for this section, we assume that s^\widehat{s} is the true score of the distribution and is known, as is standard in the log-concave sampling literature.

We begin by recalling, in Algorithm 10, Shen and Lee’s randomized midpoint method applied to the underdamped Langevin process, which gives log-concave sampling guarantees in the Wasserstein metric [SL19].

Algorithm 10 RandomizedMidpointMethod [SL19]

Input parameters:

  • Starting sample x^0\widehat{x}_{0}, Starting velocity v0v_{0}, Number of steps NN, Step size hh, Score function s^\widehat{s}, u=1Lu=\frac{1}{L}.

  1. 1.

    For n=0,,N1n=0,\dots,N-1:

    1. (a)

      Randomly sample α\alpha uniformly from [0,1][0,1].

    2. (b)

      Generate Gaussian random variable (W1(n),W2(n),W3(n))\mathbbR3d(W_{1}^{(n)},W_{2}^{(n)},W_{3}^{(n)})\in\mathbb{R}^{3d} as in Appendix A of [SL19].

    3. (c)

      Let x^n+12=x^n+12(1e2αh)vn12u(αh12(1e2(hαh)))s^(xn)+uW1(n)\widehat{x}_{n+\frac{1}{2}}=\widehat{x}_{n}+\frac{1}{2}\left(1-e^{-2\alpha h}\right)v_{n}-\frac{1}{2}u\left(\alpha h-\frac{1}{2}\left(1-e^{-2(h-\alpha h)}\right)\right)\widehat{s}(x_{n})+\sqrt{u}W_{1}^{(n)}.

    4. (d)

      Let x^n+1=x^n+12(1e2h)vn12uh(1e2(hαh))s^(xn+12)+uW2(n)\widehat{x}_{n+1}=\widehat{x}_{n}+\frac{1}{2}\left(1-e^{-2h}\right)v_{n}-\frac{1}{2}uh\left(1-e^{-2(h-\alpha h)}\right)\widehat{s}(x_{n+\frac{1}{2}})+\sqrt{u}W_{2}^{(n)}.

    5. (e)

      Let vn+1=vne2huhe2(hαh)s^(xn+12)+2uW3(n)v_{n+1}=v_{n}e^{-2h}-uhe^{-2(h-\alpha h)}\widehat{s}(x_{n+\frac{1}{2}})+2\sqrt{u}W_{3}^{(n)}.

  2. 2.

    Return x^N\widehat{x}_{N}.
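To illustrate one iteration of Algorithm 10 concretely, the sketch below implements a single randomized midpoint update for the underdamped Langevin dynamics $\mathrm{d}x=v\,\mathrm{d}t$, $\mathrm{d}v=-2v\,\mathrm{d}t-u\nabla f(x)\,\mathrm{d}t+2\sqrt{u}\,\mathrm{d}B_t$ used in [SL19], written in terms of the potential gradient $\nabla f=-\nabla\ln p$ and with the exponential-integrator coefficients rederived from the continuous dynamics for self-containedness. For brevity, the correlated Gaussians $W_1,W_2,W_3$ are formed by simulating the underlying Brownian path on a fine grid rather than by the exact joint-covariance construction of Appendix A of [SL19]; this inner discretization, and all names below, are ours.

import numpy as np

def randomized_midpoint_step(x, v, grad_f, u, h, rng, n_grid=256):
    """One step of the randomized midpoint method for
        dx = v dt,  dv = -2 v dt - u * grad_f(x) dt + 2 * sqrt(u) dB.
    W1, W2, W3 are formed from a shared Brownian path simulated on a fine grid."""
    d = x.shape[0]
    alpha = rng.uniform()

    # Shared Brownian increments on a fine grid over [0, h].
    dt = h / n_grid
    s_grid = (np.arange(n_grid) + 0.5) * dt
    dB = rng.standard_normal((n_grid, d)) * np.sqrt(dt)

    # W1 = int_0^{alpha h} (1 - e^{-2(alpha h - s)}) dB_s
    mask = s_grid < alpha * h
    W1 = np.sum((1 - np.exp(-2 * (alpha * h - s_grid[mask])))[:, None] * dB[mask], axis=0)
    # W2 = int_0^h (1 - e^{-2(h - s)}) dB_s,   W3 = int_0^h e^{-2(h - s)} dB_s
    W2 = np.sum((1 - np.exp(-2 * (h - s_grid)))[:, None] * dB, axis=0)
    W3 = np.sum(np.exp(-2 * (h - s_grid))[:, None] * dB, axis=0)

    # Midpoint, endpoint, and velocity updates in exponential-integrator form.
    x_mid = (x + 0.5 * (1 - np.exp(-2 * alpha * h)) * v
             - 0.5 * u * (alpha * h - 0.5 * (1 - np.exp(-2 * alpha * h))) * grad_f(x)
             + np.sqrt(u) * W1)
    x_next = (x + 0.5 * (1 - np.exp(-2 * h)) * v
              - 0.5 * u * h * (1 - np.exp(-2 * (h - alpha * h))) * grad_f(x_mid)
              + np.sqrt(u) * W2)
    v_next = (v * np.exp(-2 * h)
              - u * h * np.exp(-2 * (h - alpha * h)) * grad_f(x_mid)
              + 2 * np.sqrt(u) * W3)
    return x_next, v_next

The key feature, as in the parallel predictor, is that the single randomly placed evaluation of grad_f at the midpoint gives an unbiased estimate of the score integral over the step.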

Theorem C.1 (Theorem 3 of [SL19], restated).

Let s^=lnp\widehat{s}=\nabla\ln p be the score function of a log-concave distribution pp such that 0\preccurlyeq m\cdot I_{d}\preccurlyeq-J_{\widehat{s}}(x)\preccurlyeq L\cdot I_{d}, where Js^J_{\widehat{s}} is the Jacobian of s^\widehat{s}. Let x^0\widehat{x}_{0} be the root of s^\widehat{s}, and v0=0v_{0}=0. Let κ=Lm\kappa=\frac{L}{m} be the condition number. For any 0<ε<10<\varepsilon<1, if we set the step size of Algorithm 10 as h=Cmin(ε1/3m1/6d1/6κ1/6log1/6(dεm),ε2/3m1/3d1/3log1/3(dεm))h=C\min\left(\frac{\varepsilon^{1/3}m^{1/6}}{d^{1/6}\kappa^{1/6}}\log^{-1/6}\left(\frac{d}{\varepsilon m}\right),\frac{\varepsilon^{2/3}m^{1/3}}{d^{1/3}}\log^{-1/3}\left(\frac{d}{\varepsilon m}\right)\right) for some small constant CC and run the algorithm for N=4κhlog20dε2mO~(κ7/6d1/6ε1/3m1/6+κd1/3ε2/3m1/3)N=\frac{4\kappa}{h}\log\frac{20d}{\varepsilon^{2}m}\leq\widetilde{O}\left(\frac{\kappa^{7/6}d^{1/6}}{\varepsilon^{1/3}m^{1/6}}+\frac{\kappa d^{1/3}}{\varepsilon^{2/3}m^{1/3}}\right) iterations, then after NN iterations Algorithm 10 generates x^N\widehat{x}_{N} such that

W2(x^N,x)ε\displaystyle W_{2}(\widehat{x}_{N},x)\leq\varepsilon

where xpx\sim p.

Now, we make the following simple observation: if we run the corrector step from Section A.2 for a short time, we can convert the above Wasserstein guarantee to a TV guarantee. We carefully trade off the time spent on the randomized midpoint step above and the corrector step to obtain the improved dimension dependence. Our final algorithm is given in Algorithm 11.

Algorithm 11 LogConcaveSampling [SL19]

Input parameters:

  • Number of randomized midpoint steps NrandN_{\text{rand}}, Corrector step time Tcorr\lesssim1LT_{\mathrm{corr}}\lesssim\frac{1}{\sqrt{L}}, Randomized midpoint step size hrandh_{\text{rand}}, Corrector step size hcorrh_{\mathrm{corr}}, Score function s^\widehat{s}.

  1. 1.

    Let x^0\widehat{x}_{0} be the root of s^\widehat{s}, and let v0=0v_{0}=0.

  2. 2.

    Run Algorithm 10 with NrandN_{\text{rand}} steps and step size hrandh_{\text{rand}}, using x^0,v0\widehat{x}_{0},v_{0}, and let the result be x^Nrand\widehat{x}_{N_{\text{rand}}}^{\prime}.

  3. 3.

    Run Algorithm 5 starting from x^Nrand\widehat{x}_{N_{\text{rand}}}^{\prime} for time TcorrT_{\mathrm{corr}}, using step size hcorrh_{\mathrm{corr}}. Let the result be x^Nrand\widehat{x}_{N_{\text{rand}}}.

  4. 4.

    Return x^Nrand\widehat{x}_{N_{\text{rand}}}.

We obtain the following guarantee with our improved dimension dependence of O~(d5/12)\widetilde{O}(d^{5/12}).

Theorem C.2 (Log-Concave Sampling in Total Variation).

Let s^=lnp\widehat{s}=\nabla\ln p be the score function of a log-concave distribution pp such that 0\preccurlyeq m\cdot I_{d}\preccurlyeq-J_{\widehat{s}}(x)\preccurlyeq L\cdot I_{d} for the Jacobian Js^J_{\widehat{s}} of s^\widehat{s}. Let κ=Lm\kappa=\frac{L}{m} be the condition number. For any ε<1\varepsilon<1, if we set hrand=C(ε2/3d5/12κ1/3log1/3(dκε))h_{\text{rand}}=C\left(\frac{\varepsilon^{2/3}}{d^{5/12}\kappa^{1/3}}\log^{-1/3}\left(\frac{d\kappa}{\varepsilon}\right)\right) for a small constant CC, Nrand=4κhrandlog20dκε2O~(κ4/3d5/12ε2/3)N_{\text{rand}}=\frac{4\kappa}{h_{\text{rand}}}\log\frac{20d\kappa}{\varepsilon^{2}}\leq\widetilde{O}\left(\frac{\kappa^{4/3}d^{5/12}}{\varepsilon^{2/3}}\right), hcorr=O~(εd17/36L)h_{\mathrm{corr}}=\widetilde{O}\left(\frac{\varepsilon}{d^{17/36}\sqrt{L}}\right) and Tcorr=O(1Ld1/18)T_{\mathrm{corr}}=O\left(\frac{1}{\sqrt{L}d^{1/18}}\right), we have that Algorithm 11 returns x^Nrand\widehat{x}_{N_{\text{rand}}} with

TV(x^Nrand,x)\lesssimε\displaystyle\textup{{TV}}(\widehat{x}_{N_{\text{rand}}},x)\lesssim\varepsilon

for xpx\sim p. Furthermore, the total iteration complexity is O~(d5/12(κ4/3ε2/3+1ε))\widetilde{O}\left(d^{5/12}\left(\frac{\kappa^{4/3}}{\varepsilon^{2/3}}+\frac{1}{\varepsilon}\right)\right).

Proof.

By Theorem C.1, for our setting of NrandN_{\text{rand}} and hrandh_{\text{rand}}, we have that at the end of step 2 of Algorithm 11,

W2(x^Nrand,x)εd1/12L.W_{2}(\widehat{x}_{N_{\text{rand}}}^{\prime},x)\leq\frac{\varepsilon}{d^{1/12}\sqrt{L}}\,.

Then, by the first part of Corollary A.7,

TV(x^Nrand,x)\lesssimε+Ld17/36(εd17/36L)\lesssimε.\textup{{TV}}(\widehat{x}_{N_{\text{rand}}},x)\lesssim\varepsilon+\sqrt{L}d^{17/36}\cdot\left(\frac{\varepsilon}{d^{17/36}\sqrt{L}}\right)\lesssim\varepsilon\,.

Our iteration complexity is bounded by Nrand+Tcorrhcorr=O~(κ4/3d5/12ε2/3+d5/12ε)N_{\text{rand}}+\frac{T_{\mathrm{corr}}}{h_{\mathrm{corr}}}=\widetilde{O}\left(\frac{\kappa^{4/3}d^{5/12}}{\varepsilon^{2/3}}+\frac{d^{5/12}}{\varepsilon}\right) as claimed. ∎
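As a rough numerical illustration of how the two phases of Algorithm 11 trade off (ignoring all constants and logarithmic factors, which is an assumption of this snippet), one can tabulate the two iteration counts from Theorem C.2 for a few dimensions:

# Illustrative only: iteration counts from Theorem C.2, up to constants and log factors.
kappa, eps = 10.0, 0.1
for d in (10**2, 10**4, 10**6):
    n_rand = kappa ** (4 / 3) * d ** (5 / 12) / eps ** (2 / 3)  # randomized midpoint phase
    n_corr = d ** (5 / 12) / eps                                # corrector phase, T_corr / h_corr
    print(f"d = {d:>7}:  N_rand ~ {n_rand:12.0f}   N_corr ~ {n_corr:12.0f}")

Both phases scale as $d^{5/12}$, which is what drives the overall improvement over the $d^{1/2}$ dependence from prior work.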

Appendix D Helper lemmas

Lemma D.1 (Corollary 11 of [CCL+23a]).

For the ODE

dxt=(xt+lnqt(xt))dt,\mathrm{d}x_{t}=\left(x_{t}+\nabla\ln q_{t}(x_{t})\right)\mathrm{d}t\,,

if L1L\geq 1 and \mathbbE[2logqt(x)2]L\operatorname*{\mathbb{E}}\left[\|\nabla^{2}\log q_{t}(x)\|^{2}\right]\leq L, we have, for 0<s<t0<s<t and h=tsh=t-s,

\mathbbE[lnqt(xt)lnqs(xs)2]\lesssimL2dh2(L1t).\operatorname*{\mathbb{E}}\left[\|\nabla\ln q_{t}(x_{t})-\nabla\ln q_{s}(x_{s})\|^{2}\right]\lesssim L^{2}dh^{2}\left(L\lor\frac{1}{t}\right)\,.
Lemma D.2 (Implicit in Lemma 44 of [CCL+23a]).

Suppose L1L\geq 1, h\lesssim1Lh\lesssim\frac{1}{L} and t0ht0/2t_{0}-h\geq t_{0}/2. For ODEs starting at xt0=x^t0x_{t_{0}}=\widehat{x}_{t_{0}}, where

dxt\displaystyle\mathrm{d}x_{t} =(xt+lnqt(xt))dt\displaystyle=\left(x_{t}+\nabla\ln q_{t}(x_{t})\right)\mathrm{d}t
\displaystyle\mathrm{d}\widehat{x}_{t}=\left(\widehat{x}_{t}+\widehat{s}_{t_{0}}(\widehat{x}_{t_{0}})\right)\mathrm{d}t,

we have

\mathbbExt0hx^t0h2\lesssimh2(L2dh2(L1t0)+εsc2).\displaystyle\operatorname*{\mathbb{E}}\|x_{t_{0}-h}-\widehat{x}_{t_{0}-h}\|^{2}\lesssim h^{2}\left(L^{2}dh^{2}\left(L\lor\frac{1}{t_{0}}\right)+\varepsilon_{\mathrm{sc}}^{2}\right).
Lemma D.3 (Lemma B.1 of [GLP23], restated).

Let p0p_{0} be a distribution over \mathbbRd\mathbb R^{d}. For x0p0x_{0}\sim p_{0}, let xt=x0+ztptx_{t}=x_{0}+z_{t}\sim p_{t} for zt𝒩(0,tId)z_{t}\sim\mathcal{N}(0,tI_{d}) independent of x0x_{0}. Then,

pt(xt+ε)pt(xt)=\mathbbEzt|xt[eεTzttε22t]\displaystyle\frac{p_{t}(x_{t}+\varepsilon)}{p_{t}(x_{t})}=\operatorname*{\mathbb{E}}_{z_{t}|x_{t}}[e^{\frac{\varepsilon^{T}z_{t}}{t}-\frac{\|\varepsilon\|^{2}}{2t}}]

and

lnpt(xt)=\mathbbEzt|xt[ztt]\displaystyle\nabla\ln p_{t}(x_{t})=\operatorname*{\mathbb{E}}_{z_{t}|x_{t}}\left[-\frac{z_{t}}{t}\right]
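As a quick sanity check of the second identity (an illustration of ours, not taken from [GLP23]), the snippet below compares the posterior-mean expression $\operatorname{\mathbb{E}}_{z_t|x_t}[-z_t/t]$ with a finite-difference gradient of $\log p_t$ for a two-point mixture $p_0$, for which both quantities are available in closed form.

import numpy as np

# p_0 = (1/2) delta_{a} + (1/2) delta_{b} in one dimension, so that
# p_t(x) = (1/2) N(x; a, t) + (1/2) N(x; b, t) and the posterior over x_0 is explicit.
a, b, t = -1.0, 2.0, 0.5

def log_p_t(x):
    return (np.logaddexp(-(x - a) ** 2 / (2 * t), -(x - b) ** 2 / (2 * t))
            - 0.5 * np.log(2 * np.pi * t) - np.log(2.0))

def score_via_posterior_mean(x):
    # E[-z_t / t | x_t = x] = (E[x_0 | x_t = x] - x) / t, as in Lemma D.3
    w_a = 1.0 / (1.0 + np.exp(((x - a) ** 2 - (x - b) ** 2) / (2 * t)))
    x0_mean = w_a * a + (1 - w_a) * b
    return (x0_mean - x) / t

for x in (-2.0, 0.3, 1.5):
    eps = 1e-5
    fd = (log_p_t(x + eps) - log_p_t(x - eps)) / (2 * eps)   # finite-difference score
    print(f"x = {x:+.1f}:  finite-diff {fd:+.6f}   posterior-mean {score_via_posterior_mean(x):+.6f}")

The two columns agree up to the finite-difference error, matching the identity above.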
Lemma D.4.

For qt(yt)pe2t1(etyt)q_{t}(y_{t})\propto p_{e^{2t}-1}(e^{t}y_{t}), for zt𝒩(0,(e2t1)Id)z_{t}\sim\mathcal{N}(0,(e^{2t}-1)I_{d}), we have

\displaystyle\nabla\ln q_{t}(y_{t})=e^{t}\nabla\ln p_{e^{2t}-1}(e^{t}y_{t})=e^{t}\operatorname*{\mathbb{E}}_{z_{t}|e^{t}y_{t}}\left[\frac{-z_{t}}{e^{2t}-1}\right]

Furthermore,

\mathbbEytqt[lnqt(yt)2]\lesssimdt\displaystyle\operatorname*{\mathbb{E}}_{y_{t}\sim q_{t}}\left[\|\nabla\ln q_{t}(y_{t})\|^{2}\right]\lesssim\frac{d}{t}
Proof.

The first claim is an immediate consequence of the definition of qtq_{t} and Lemma D.3. For the second claim, note that

\mathbbEytqt[lnqt(yt)2]\displaystyle\operatorname*{\mathbb{E}}_{y_{t}\sim q_{t}}\left[\|\nabla\ln q_{t}(y_{t})\|^{2}\right] =e2t\mathbbEytqt[\mathbbEzt|etyt[zte2t1]2]\displaystyle=e^{2t}\operatorname*{\mathbb{E}}_{y_{t}\sim q_{t}}\left[\left\|\operatorname*{\mathbb{E}}_{z_{t}|e^{t}y_{t}}\left[\frac{-z_{t}}{e^{2t}-1}\right]\right\|^{2}\right]
e2t\mathbbEytqt[\mathbbEzt|etyt[zt2(e2t1)2]]\displaystyle\leq e^{2t}\operatorname*{\mathbb{E}}_{y_{t}\sim q_{t}}\left[\operatorname*{\mathbb{E}}_{z_{t}|e^{t}y_{t}}\left[\frac{\|z_{t}\|^{2}}{(e^{2t}-1)^{2}}\right]\right]
=e2t\mathbbEzt[zt2(e2t1)2]\displaystyle=e^{2t}\operatorname*{\mathbb{E}}_{z_{t}}\left[\frac{\|z_{t}\|^{2}}{(e^{2t}-1)^{2}}\right]
=e2tde2t1since zt𝒩(0,(e2t1)Id)\displaystyle=\frac{e^{2t}\cdot d}{e^{2t}-1}\quad\text{since $z_{t}\sim\mathcal{N}(0,(e^{2t}-1)I_{d})$}
\lesssimdt.\displaystyle\lesssim\frac{d}{t}\,.\qed