Faster Diffusion Sampling with Randomized Midpoints: Sequential and Parallel
Abstract
Sampling algorithms play an important role in controlling the quality and runtime of diffusion model inference. In recent years, a number of works [CCL+23b, CCL+23a, BDD23, LLT22] have proposed schemes for diffusion sampling with provable guarantees; these works show that for essentially any data distribution, one can approximately sample in polynomial time given a sufficiently accurate estimate of its score functions at different noise levels. In this work, we propose a new scheme inspired by Shen and Lee’s randomized midpoint method for log-concave sampling [SL19]. We prove that this approach achieves the best known dimension dependence for sampling from arbitrary smooth distributions in total variation distance ($\widetilde{O}(d^{5/12})$ compared to $\widetilde{O}(\sqrt{d})$ from prior work). We also show that our algorithm can be parallelized to run in only polylogarithmically many parallel rounds, constituting the first provable guarantees for parallel sampling with diffusion models.
As a byproduct of our methods, for the well-studied problem of log-concave sampling in total variation distance, we give an algorithm and simple analysis achieving $\widetilde{O}(d^{5/12})$ dimension dependence, compared to $\widetilde{O}(\sqrt{d})$ from prior work.
1 Introduction
Diffusion models [SDWMG15, SE19, HJA20, DN21, SDME21, SSDK+21, VKK21] have emerged as the de facto approach to generative modeling across a range of data modalities like images [BGJ+23, EKB+24], audio [KPH+20], video [BPH+24], and molecules [WYvdB+24]. In recent years a slew of theoretical works have established surprisingly general convergence guarantees for this method [CCL+23b, LLT23, CLL23, CCL+23a, CDD23, BDBDD24, GPPX23, LWCC23, LHE+24]. They show that for essentially any data distribution, assuming one has a sufficiently accurate estimate for its score function, one can approximately sample from it in polynomial time.
While these results offer some theoretical justification for the empirical successes of diffusion models, the upper bounds they furnish for the number of iterations needed to generate a single sample are quite loose relative to what is done in practice. The best known provable bounds scale as $\widetilde{O}(\sqrt{d}/\varepsilon)$, where $d$ is the dimension of the space in which the diffusion is taking place (e.g., $d = 4\times 64\times 64$ for Stable Diffusion's latent space) [CCL+23a], and $\varepsilon$ is the target error. Even ignoring the dependence on $\varepsilon$ and the hidden constant factor, this is considerably larger than the default value of 50 inference steps in Stable Diffusion.
In this work we consider a new approach for driving down the amount of compute that is provably needed to sample with diffusion models. Our approach is rooted in the randomized midpoint method, originally introduced by Shen and Lee [SL19] in the context of Langevin Monte Carlo for log-concave sampling. At a high level, this is a method for numerically solving differential equations where within every discrete window of time, one forms an unbiased estimate for the drift by evaluating it at a random “midpoint” (see Section 2.2 for a formal treatment). For sampling from log-concave densities, the dimension dependence of the number of iterations needed by their method scales with $d^{1/3}$, and this remains the best known bound in the “low-accuracy” regime.
While this method is well-studied in the log-concave setting [HBE20, YKD23, YD24, SL19], its applicability to diffusion models has been unexplored both theoretically and empirically. Our first result uses the randomized midpoint method to obtain an improvement over the prior best known bound of $\widetilde{O}(\sqrt{d}/\varepsilon)$ iterations for sampling arbitrary smooth distributions with diffusion models:
Theorem 1.1 (Informal, see Theorem A.10).
Suppose that the data distribution $q$ has bounded second moment, its score functions along the forward process are $L$-Lipschitz, and we are given score estimates $s_t$ which are $L$-Lipschitz and sufficiently close to the true scores $\nabla\ln q_t$ for all $t$ (see Assumption 2.4). Then there is a diffusion-based sampler using these score estimates (see Algorithm 1) which outputs a sample whose law is $\varepsilon$-close in total variation distance to $q$ using a number of iterations scaling with the dimension as $\widetilde{O}(d^{5/12})$, where $\widetilde{O}(\cdot)$ hides polylogarithmic factors in $d$ and $1/\varepsilon$.
Our algorithm is based on the ODE-based predictor-corrector algorithm introduced in [CCL+23a], but in place of the standard exponential integrator discretization in the predictor step, we employ randomized midpoint discretization. We note that in the domain of log-concave sampling, the result of Shen and Lee only achieves recovery in Wasserstein distance. Prior to our work, it was actually open whether one can achieve the same $d^{1/3}$ dimension dependence in total variation or KL divergence, for which the best known bound was $\widetilde{O}(\sqrt{d})$ [MCC+21, ZCL+23, AC23]. In contrast, our result circumvents this barrier by carefully trading off time spent in the corrector phase of the algorithm for time spent in the predictor phase. We defer the details of this, as well as other important technical hurdles, to Section 3.1.
Next, we turn to a different computational model: instead of quantifying the cost of an algorithm in terms of the total number of iterations, we consider the parallel setting where one has access to multiple processors and wishes to minimize the total number of parallel rounds needed to generate a single sample. This perspective has been explored in a recent empirical work [SBE+24], but to our knowledge, no provable guarantees were known for parallel sampling with diffusion models (see Section 1.1 for discussion of concurrent and independent work). Our second result provides the first such guarantee:
Theorem 1.2 (Informal, see Theorem B.13).
Under the same assumptions on $q$ as in Theorem 1.1, and assuming that we are given score estimates which are sufficiently close to the true scores $\nabla\ln q_t$ for all $t$, there is a diffusion-based sampler using these score estimates (see Algorithm 9) which outputs a sample whose law is $\varepsilon$-close in total variation distance to $q$ using $\operatorname{polylog}(d/\varepsilon)$ parallel rounds.
This result follows in the wake of several recent theoretical works on parallel sampling of log-concave densities using Langevin Monte Carlo [AHL+23, ACV24, SL19]. A common thread among these works is the observation that differential equations can be numerically solved via fixed point iteration (see Section 2.3 for details), and we adopt a similar perspective in the context of diffusions. To our knowledge this is the first provable guarantee for parallel sampling beyond the log-concave setting.
Finally, we show that, as a byproduct of our methods, we can actually obtain a similar dimension dependence of $\widetilde{O}(d^{5/12})$ as in Theorem 1.1 for log-concave sampling in TV, superseding the previously best known bound of $\widetilde{O}(\sqrt{d})$ mentioned above.
Theorem 1.3 (Informal, see Theorem C.2).
Suppose the distribution $\pi$ is strongly log-concave, and its score function $\nabla\ln\pi$ is Lipschitz. Then, there is an underdamped-Langevin-based sampler that uses this score (Algorithm 11) and outputs a sample whose law is $\varepsilon$-close in total variation to $\pi$ using a number of iterations scaling with the dimension as $\widetilde{O}(d^{5/12})$.
1.1 Related work
Our discretization scheme is based on the randomized midpoint method of [SL19], which has been studied at length in the domain of log-concave sampling [HBE20, YKD23, YD24].
The proof of our parallel sampling result builds on the ideas of [SL19, AHL+23, ACV24] on parallelizing the collocation method. These prior results were focused on Langevin Monte Carlo, rather than diffusion-based sampling. We review these ideas in Section 2.3.
In [CCL+23a], the authors proposed the predictor-corrector framework that we also use to analyze convergence of the probability flow ODE, and which achieved iteration complexity scaling with $\sqrt{d}$. In addition to this, there have been many works in recent years giving general convergence guarantees for diffusion models [DBTHD21, BMR22, DB22, LLT22, LWYL22, Pid22, WY22, CCL+23b, CDD23, LLT23, LWCC23, BDD23, CCL+23a, BDBDD24, CLL23, GPPX23]. Of these, one line of work [CCL+23b, LLT23, CLL23, BDBDD24] analyzed DDPM, the stochastic analogue of the probability flow ODE, and showed iteration complexity bounds scaling linearly in the dimension. Another set of works [CCL+23a, CDD23, LWCC23, LHE+24] studied the probability flow ODE; our work provides a new discretization scheme for the probability flow ODE that achieves a state-of-the-art dimension dependence for sampling from a diffusion model.
Concurrent work.
Here we discuss the independent works of [CRYR24] and [KN24]. [CRYR24] gave an analysis for parallel sampling with diffusion models that also achieves a polylogarithmic number of parallel rounds, as in the present work. [KN24] showed an improved dimension dependence for log-concave sampling in total variation, similar to our analogous result, but via a different proof technique. In addition to this, they show a similar result when the distribution only satisfies a log-Sobolev inequality. They also present empirical results for diffusion models, showing that an algorithm inspired by the randomized midpoint method outperforms ODE-based methods with similar compute. While their work builds on the randomized midpoint method, they do not theoretically analyze the diffusion setting and do not study parallel sampling.
2 Preliminaries
2.1 Probability flow ODE
In this section we review basics about deterministic diffusion-based samplers; we refer the reader to [CCL+23a] for a more thorough exposition.
Let $q$ denote the data distribution over $\mathbb{R}^d$. We consider the standard Ornstein-Uhlenbeck (OU) forward process, i.e. the “VP SDE,” given by

$\mathrm{d}x_t = -x_t\,\mathrm{d}t + \sqrt{2}\,\mathrm{d}B_t, \qquad x_0 \sim q,$ (1)

where $B_t$ denotes a standard Brownian motion in $\mathbb{R}^d$. This process converges exponentially quickly to its stationary distribution, the standard Gaussian distribution $\mathcal{N}(0, I_d)$.
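Because the OU transition kernel is Gaussian in closed form, the forward process can be simulated exactly. The following minimal sketch (our own illustration, assuming the normalization $\mathrm{d}x_t = -x_t\,\mathrm{d}t + \sqrt{2}\,\mathrm{d}B_t$ above) checks the exponentially fast convergence to $\mathcal{N}(0, I_d)$ numerically.

```python
import numpy as np

rng = np.random.default_rng(0)

def ou_forward(x0: np.ndarray, t: float, rng) -> np.ndarray:
    """Sample x_t | x_0 for the OU process dx_t = -x_t dt + sqrt(2) dB_t.

    The transition kernel is Gaussian with mean e^{-t} x_0 and
    covariance (1 - e^{-2t}) I, so we can sample it exactly.
    """
    z = rng.standard_normal(x0.shape)
    return np.exp(-t) * x0 + np.sqrt(1.0 - np.exp(-2.0 * t)) * z

# Start from a far-from-Gaussian "data distribution" (a point mass at 5).
x0 = np.full((100_000, 1), 5.0)
xT = ou_forward(x0, t=10.0, rng=rng)

# After time T = 10, the law is within roughly e^{-T} of N(0, I).
print(xT.mean(), xT.std())  # approximately 0 and 1
```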
Suppose the OU process is run until terminal time $T$, and for any $t \in [0, T]$, let $q_t = \mathrm{law}(x_t)$, i.e. the law of the forward process at time $t$. We will consider the reverse process given by the probability flow ODE

$\mathrm{d}x_t^{\leftarrow} = \bigl(x_t^{\leftarrow} + \nabla\ln q_{T-t}(x_t^{\leftarrow})\bigr)\,\mathrm{d}t.$ (2)

This is a time-reversal of the forward process, so that if $x_0^{\leftarrow} \sim q_T$, then $x_t^{\leftarrow} \sim q_{T-t}$. In practice, one initializes at $\mathcal{N}(0, I_d)$, and instead of using the exact score function $\nabla\ln q_t$, one uses estimates $s_t$ which are learned from data. Additionally, the ODE is solved numerically using any of a number of discretization schemes. The theoretical literature on diffusion models has focused primarily on exponential integration, which we review next before turning to the discretization scheme used in the present work, the randomized midpoint method.
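As a sanity check on these conventions, when $q$ is itself Gaussian the score of $q_t$ is available in closed form, so the probability flow ODE can be integrated directly. The sketch below (our own illustration, not code from the paper) starts from $\mathcal{N}(0,1)\approx q_T$ and recovers the data distribution $\mathcal{N}(0, \sigma_0^2)$.

```python
import numpy as np

rng = np.random.default_rng(1)
sigma0_sq, T, n_steps = 4.0, 8.0, 4000

def var_t(t):
    # Variance of q_t when q = N(0, sigma0^2) under the OU forward process.
    e = np.exp(-2.0 * t)
    return sigma0_sq * e + (1.0 - e)

def score(x, t):
    # For Gaussian q_t = N(0, var_t(t)), the score is exactly -x / var_t(t).
    return -x / var_t(t)

# Initialize the reverse process at N(0, 1) ≈ q_T and Euler-integrate
# dx/dt = x + score(x, T - t) from t = 0 to t = T.
x = rng.standard_normal(100_000)
h = T / n_steps
for n in range(n_steps):
    t = n * h
    x = x + h * (x + score(x, T - t))

print(x.std())  # should be close to sigma0 = 2
```

For non-Gaussian data the score is not available in closed form, which is exactly where learned estimates $s_t$ and the discretization analysis below come in.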
2.2 Discretization schemes
Suppose we wish to discretize the following semilinear ODE:

$\mathrm{d}x_t = \bigl(x_t + f_t(x_t)\bigr)\,\mathrm{d}t.$ (3)

For our application we will eventually take $f_t = \nabla\ln q_{T-t}$, but we use $f_t$ in this section to condense notation.
Suppose we want to discretize Equation 3 over a time window $[t_0, t_0+h]$. The starting point is the integral formulation for this ODE:

$x_{t_0+h} = e^{h}\,x_{t_0} + \int_0^h e^{h-s}\, f_{t_0+s}(x_{t_0+s})\,\mathrm{d}s.$ (4)
Under the standard exponential integrator discretization, one would approximate the integrand $f_{t_0+s}(x_{t_0+s})$ by $f_{t_0}(x_{t_0})$ and obtain the approximation

$x_{t_0+h} \approx e^{h}\,x_{t_0} + (e^{h}-1)\, f_{t_0}(x_{t_0}).$ (5)
The drawback of this discretization is that it uses an inherently biased estimate for the integral in Eq. (4). The key insight of [SL19] was to replace this with the following unbiased estimate:

$h\, e^{(1-\alpha)h}\, f_{t_0+\alpha h}(x_{t_0+\alpha h}),$ (6)
where $\alpha$ is a uniformly random sample from $[0,1]$. While this alone does not suffice as the estimate depends on the unknown $x_{t_0+\alpha h}$, naturally we could iterate the above procedure again to obtain an approximation to $x_{t_0+\alpha h}$. It turns out though that even if we simply approximate $x_{t_0+\alpha h}$ using exponential integrator discretization, we can obtain nontrivial improvements in discretization error (e.g. our Theorem 1.1). In this case, the above sequence of approximations takes the following form:
$\hat{x}_{t_0+\alpha h} = e^{\alpha h}\,\hat{x}_{t_0} + (e^{\alpha h}-1)\, f_{t_0}(\hat{x}_{t_0}),$ (7)

$\hat{x}_{t_0+h} = e^{h}\,\hat{x}_{t_0} + h\, e^{(1-\alpha)h}\, f_{t_0+\alpha h}(\hat{x}_{t_0+\alpha h}).$ (8)
Note that a similar idea can be used to discretize stochastic differential equations, but in this work we only use it to discretize the probability flow ODE.
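To make the comparison concrete, here is a small numerical illustration (our own, on a toy scalar ODE with $f = \cos$) of Eqs. (5) and (7)-(8): averaging the randomized-midpoint estimate over $\alpha$ exhibits a much smaller bias than the exponential integrator over the same window.

```python
import numpy as np

def F(x):
    # Full drift of the semilinear ODE dx/dt = x + f(x), with f = cos.
    return x + np.cos(x)

def rk4_reference(x0, h, n=1000):
    # High-accuracy reference solution over one window of length h.
    x, dt = x0, h / n
    for _ in range(n):
        k1 = F(x); k2 = F(x + dt * k1 / 2)
        k3 = F(x + dt * k2 / 2); k4 = F(x + dt * k3)
        x = x + dt * (k1 + 2 * k2 + 2 * k3 + k4) / 6
    return x

x0, h = 0.3, 0.1
f = np.cos

# Exponential integrator, Eq. (5): freeze f at the left endpoint.
x_ei = np.exp(h) * x0 + (np.exp(h) - 1) * f(x0)

# Randomized midpoint, Eqs. (7)-(8), averaged over the random alpha
# (a dense grid stands in for the expectation over alpha ~ Unif[0,1]).
alpha = np.linspace(0.0, 1.0, 10_001)
x_mid = np.exp(alpha * h) * x0 + (np.exp(alpha * h) - 1) * f(x0)  # Eq. (7)
x_rm = np.exp(h) * x0 + h * np.exp((1 - alpha) * h) * f(x_mid)    # Eq. (8)

x_true = rk4_reference(x0, h)
print(abs(x_ei - x_true), abs(np.mean(x_rm) - x_true))
# The averaged randomized-midpoint estimate has a much smaller error.
```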
Predictor-Corrector.
For important technical reasons, in our analysis we actually consider a slightly different algorithm than simply running the probability flow ODE with approximate score, Gaussian initialization, and randomized midpoint discretization. Specifically, we interleave the ODE with corrector steps that periodically inject noise into the sampling trajectory. We refer to the phases in which we are running the probability flow ODE as predictor steps.
2.3 Parallel sampling
The scheme outlined in the previous section is a simple special case of the collocation method. In the context of the semilinear ODE from Eq. (3), the idea behind the collocation method is to solve the integral formulation of the ODE in Eq. (4) via fixed point iteration. For our parallel sampling guarantees, instead of choosing a single randomized midpoint, we break up the window $[t_0, t_0+h]$ into $R$ sub-windows of length $h/R$, select randomized midpoints $u_i = t_0 + (i-1+\alpha_i)h/R$ with $\alpha_i \sim \mathrm{Unif}[0,1]$ for these sub-windows, and approximate the trajectory of the ODE at any of the times $u_i$, where $i \in [R]$, by

$x_{u_i} \approx e^{u_i - t_0}\, x_{t_0} + \frac{h}{R}\sum_{j=1}^{i-1} e^{u_i - u_j}\, f_{u_j}(x_{u_j}).$ (9)
One can show that as $R \to \infty$, this approximation tends to an equality. For sufficiently large $R$, Eq. (9) naturally suggests a fixed point iteration that can be used to approximate each $x_{u_i}$, i.e. we can maintain a sequence of estimates $\hat{x}^{(m)}_{u_i}$ defined by the iteration

$\hat{x}^{(m+1)}_{u_i} = e^{u_i - t_0}\,\hat{x}_{t_0} + \frac{h}{R}\sum_{j=1}^{i-1} e^{u_i - u_j}\, f_{u_j}\bigl(\hat{x}^{(m)}_{u_j}\bigr)$ (10)
for $m$ ranging from $0$ up to some sufficiently large $M$. Finally, analogously to Eq. (8), we can estimate $x_{t_0+h}$ via

$\hat{x}_{t_0+h} = e^{h}\,\hat{x}_{t_0} + \frac{h}{R}\sum_{j=1}^{R} e^{t_0+h-u_j}\, f_{u_j}\bigl(\hat{x}^{(M)}_{u_j}\bigr).$ (11)
The key observation, made in [SL19] and also in related works of [ACV24, SBE+24, AHL+23], is that for any fixed round $m$, all of the updates in Eq. (10) for different choices of $i$ can be computed in parallel. With $R$ parallel processors, one can thus compute the estimate for $x_{t_0+h}$ in $M$ parallel rounds, with $O(RM)$ total work.
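The following self-contained sketch illustrates this parallel structure on the same kind of toy semilinear ODE. The quadrature weights and initialization are simplified choices of ours, not the paper's Algorithm 9, and the numpy vectorization stands in for the $R$ parallel processors.

```python
import numpy as np

rng = np.random.default_rng(2)

def F(x):
    # The drift f in the semilinear ODE dx/dt = x + f(x); here f = cos.
    return np.cos(x)

def rk4_solve(x0, t, n=2000):
    # High-accuracy reference solution of dx/dt = x + F(x).
    g = lambda y: y + F(y)
    x, dt = x0, t / n
    for _ in range(n):
        k1 = g(x); k2 = g(x + dt*k1/2); k3 = g(x + dt*k2/2); k4 = g(x + dt*k3)
        x = x + dt * (k1 + 2*k2 + 2*k3 + k4) / 6
    return x

x0, h, R, M = 0.3, 0.5, 200, 25
alpha = rng.uniform(size=R)
u = (np.arange(R) + alpha) * h / R      # one randomized node per sub-window

# Quadrature weights w[i, j]: node u_j accounts for sub-window j when it lies
# fully before u_i; the partial sub-window ending at u_i uses u_i itself.
w = np.where(np.arange(R)[None, :] < np.arange(R)[:, None],
             (h / R) * np.exp(u[:, None] - u[None, :]), 0.0)
w[np.arange(R), np.arange(R)] = alpha * h / R

x_est = np.full(R, x0)                  # crude initialization of the trajectory
for m in range(M):
    # One fixed-point round: the R updates depend only on the previous
    # round's values, so they could run on R parallel processors.
    x_est = np.exp(u) * x0 + w @ F(x_est)

x_end = np.exp(h) * x0 + (h / R) * np.sum(np.exp(h - u) * F(x_est))
print(abs(x_end - rk4_solve(x0, h)))    # small discretization error
```

Each round costs one batch of $R$ drift evaluations, mirroring the $O(RM)$ total-work accounting above.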
2.4 Assumptions
Throughout the paper, for our diffusion results, we will make the following standard assumptions on the data distribution and score estimates.
Assumption 2.1 (Bounded Second Moment).
The data distribution $q$ has finite second moment $\mathbb{E}_{x \sim q}\bigl[\|x\|^2\bigr] < \infty$.
Assumption 2.2 (Lipschitz Score).
For all $t$, the score function $\nabla\ln q_t$ is $L$-Lipschitz.
Assumption 2.3 (Lipschitz Score estimates).
For all times $t$ for which we need to estimate the score function in our algorithms, the score estimate $s_t$ is $L$-Lipschitz.
Assumption 2.4 (Score Estimation Error).
For all times $t$ for which we need to estimate the score function in our algorithms, $\mathbb{E}_{x \sim q_t}\bigl[\|s_t(x) - \nabla\ln q_t(x)\|^2\bigr] \le \varepsilon_{\mathrm{sc}}^2$.
3 Technical overview
Here we provide an overview of our sequential and parallel algorithms, along with the analysis of our iteration complexity bounds. We begin with a description of the sequential algorithm.
3.1 Sequential algorithm
Following the framework of [CCL+23a], our algorithm consists of “predictor” steps interspersed with “corrector” steps, with the time spent on each carefully tuned to obtain our final dimension dependence. We first describe our predictor step – this is the piece of our algorithm that makes use of Shen and Lee’s randomized midpoint method [SL19].
Input parameters:
- Starting sample $\hat{x}_0$, starting time $t_0$, number of steps $N$, step sizes $h_1, \dots, h_N$, score estimates $\{s_t\}_t$
1. For $n = 0, \dots, N-1$:
   (a) Let $t_n = t_0 + \sum_{k=1}^{n} h_k$, and write $h = h_{n+1}$.
   (b) Randomly sample $\alpha_n$ uniformly from $[0, 1]$.
   (c) Let $\hat{x}_{n+\alpha_n} = e^{\alpha_n h}\,\hat{x}_n + (e^{\alpha_n h} - 1)\, s_{T - t_n}(\hat{x}_n)$.
   (d) Let $\hat{x}_{n+1} = e^{h}\,\hat{x}_n + h\, e^{(1-\alpha_n) h}\, s_{T - (t_n + \alpha_n h)}(\hat{x}_{n+\alpha_n})$.
2. Let $t_N = t_0 + \sum_{k=1}^{N} h_k$.
3. Return $\hat{x}_N$.
The main differences between the above and the predictor step of [CCL+23a] are Steps 1(b)–1(d): Steps 1(b) and 1(c) together compute a randomized midpoint, and Step 1(d) uses this midpoint to obtain an approximate solution to the integral formulation of the ODE. We describe these steps in more detail in Section 3.3.
Next, we describe the “corrector” step, introduced in [CCL+23a]. First, recall the underdamped Langevin dynamics:

$\mathrm{d}x_t = v_t\,\mathrm{d}t, \qquad \mathrm{d}v_t = \bigl(s(x_t) - \gamma v_t\bigr)\,\mathrm{d}t + \sqrt{2\gamma}\,\mathrm{d}B_t.$ (12)

Here $s$ is our accurate score estimate $s_t$ for a fixed time $t$. Then, the corrector step is described below.
Input parameters:
- Starting sample $\hat{x}$, total time $T_{\mathrm{corr}}$, step size $h$, score estimate $s$
1. Run underdamped Langevin Monte Carlo, i.e. a discretization of (12), for total time $T_{\mathrm{corr}}$ using step size $h$, starting from $\hat{x}$, and let the result be $\hat{x}'$.
2. Return $\hat{x}'$.
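As a rough illustration of the corrector dynamics (12), the sketch below runs a plain Euler-Maruyama discretization of underdamped Langevin toward a standard Gaussian target. Note that this is only a simplified stand-in of our own (the friction value and discretization are illustrative choices, and the analyzed corrector uses an exponential-integrator discretization).

```python
import numpy as np

rng = np.random.default_rng(3)

def corrector_ulmc(x, score, T, h, gamma=2.0, rng=rng):
    """Crude Euler-Maruyama sketch of the underdamped Langevin corrector:

        dx_t = v_t dt,  dv_t = (score(x_t) - gamma*v_t) dt + sqrt(2*gamma) dB_t.
    """
    v = rng.standard_normal(x.shape)  # refresh the velocity at the start
    for _ in range(int(T / h)):
        noise = rng.standard_normal(x.shape)
        x = x + h * v
        v = v + h * (score(x) - gamma * v) + np.sqrt(2.0 * gamma * h) * noise
    return x

# Target pi = N(0, 1), whose score is exactly -x; start far from stationarity.
x = np.full(20_000, 3.0)
x = corrector_ulmc(x, score=lambda y: -y, T=10.0, h=0.005)
print(x.mean(), x.std())  # approximately 0 and 1
```

The injected Brownian noise is what supplies the stochasticity that the predictor-corrector analysis relies on.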
Finally, Algorithm 3 below puts the predictor and corrector steps together to give our final sequential algorithm.
Input parameters:
- Start time $t_{\mathrm{start}}$, end time $t_{\mathrm{end}}$, corrector steps time $T_{\mathrm{corr}}$, number of predictor-corrector steps $N_{\mathrm{pc}}$, predictor step size $h_{\mathrm{pred}}$, corrector step size $h_{\mathrm{corr}}$, score estimates $\{s_t\}_t$
3.2 Predictor-corrector framework
The general framework of our algorithm closely follows that of [CCL+23a], which proposed to run the (discretized) reverse ODE but interspersed with “corrector” steps given by running underdamped Langevin dynamics. The idea is that the “predictor” steps where the discretized reverse ODE is being run keep the sampler close to the true reverse process in Wasserstein distance, but they cannot be run for too long before potentially incurring exponential blowups. The main purpose of the corrector steps is then to inject stochasticity into the trajectory of the sampler in order to convert closeness in Wasserstein to closeness in KL divergence. This effectively allows one to “restart the coupling” used to control the predictor steps. For technical reasons that are inherited from [CCL+23a], for most of the reverse process the predictor steps (Step 2(a)) are run with a fixed step size, but at the end of the reverse process (Step 3), they are run with exponentially decaying step sizes.
We follow the same framework, and the core of our result lies in refining the algorithm and analysis for the predictor steps by using the randomized midpoint method. Below, we highlight our key technical steps.
3.3 Predictor step – improved discretization error with randomized midpoints
Here, we explain the main idea behind why randomized midpoint allows us to achieve improved dimension dependence. We first focus on the analysis of the predictor (Algorithm 1) and restrict our attention to running the reverse process for a small amount of time $h$.
We begin by recalling the dimension dependence achieved by the standard exponential integrator scheme. One can show (see e.g. Lemma 4 in [CCL+23a]) that if the true reverse process and the discretized reverse process are both run for a small time $h$ starting from the same initialization, the two processes drift apart by a distance of roughly $h^2\sqrt{d}$. By iterating this coupling $1/h$ times, we conclude that in an $O(1)$ window of time, the processes drift by a distance of roughly $h\sqrt{d}$. To ensure this is not too large, one would take the step size $h$ to be of order $\varepsilon/\sqrt{d}$, thus obtaining an iteration complexity of $\widetilde{O}(\sqrt{d}/\varepsilon)$ as in [CCL+23a].
The starting point in the analysis of randomized midpoint is to instead track the squared displacement between the two processes. Given two neighboring time steps $t$ and $t+h$ in the algorithm, let $x_t$ denote the true reverse process at time $t$, and let $\hat{x}_t$ denote the algorithm at time $t$ (in the notation of Algorithm 2, this is $\hat{x}_n$ for some $n$, but we index by time in the discussion here to make the comparison to the true reverse process clearer). Note that $\hat{x}_{t+h}$ depends on the choice of randomized midpoint $\alpha$ (see Step 1(b)). One can bound the squared displacement as follows. Let $x'_{t+h}$ be the result of running the true reverse process for time $h$ starting from $\hat{x}_t$. Then by writing $\hat{x}_{t+h} - x_{t+h}$ as $(x'_{t+h} - x_{t+h}) + (\mathbb{E}_\alpha[\hat{x}_{t+h}] - x'_{t+h}) + (\hat{x}_{t+h} - \mathbb{E}_\alpha[\hat{x}_{t+h}])$ and applying Young’s inequality, we obtain

$\mathbb{E}\|\hat{x}_{t+h} - x_{t+h}\|^2 \le (1+O(h))\,\mathbb{E}\|x'_{t+h} - x_{t+h}\|^2 + O\bigl(\tfrac{1}{h}\bigr)\,\mathbb{E}\bigl\|\mathbb{E}_\alpha[\hat{x}_{t+h}] - x'_{t+h}\bigr\|^2 + O(1)\,\mathbb{E}\bigl\|\hat{x}_{t+h} - \mathbb{E}_\alpha[\hat{x}_{t+h}]\bigr\|^2.$ (13)

For the first term, because $x'_{t+h}$ and $x_{t+h}$ are the result of running the same ODE on initializations $\hat{x}_t$ and $x_t$, the first term is close to $\mathbb{E}\|\hat{x}_t - x_t\|^2$ provided $h \lesssim 1/L$. The upshot is that the squared displacement at time $t+h$ is at most the squared displacement at time $t$ plus the remaining two terms on the right of Equation 13.
The main part of the proof lies in bounding these two terms, which can be thought of as “bias” and “variance” terms respectively. The variance term can be shown to scale with the square of the aforementioned displacement bound that arises in the exponential integrator analysis, giving roughly $h^4 d$:
Lemma 3.1 (Informal, see Lemma A.4 for formal statement).
If $h \lesssim 1/L$ and the score estimation error satisfies Assumption 2.4, then the variance term in Equation 13 is at most $\widetilde{O}(h^4 d)$, plus a term proportional to the squared score estimation error and a term which is a small multiple of the squared displacement $\mathbb{E}\|\hat{x}_t - x_t\|^2$.

Note that in this bound, in addition to the $h^4 d$ term and a term for the score estimation error, there is an additional term which depends on the squared displacement from the previous time step. Because its prefactor is sufficiently small, this will ultimately be negligible.
The upshot of the above lemma is that if the bias term is of lower order, then this means that the squared displacement essentially increases by $\widetilde{O}(h^4 d)$ with every time step of length $h$. Over $1/h$ such steps, the total squared displacement is $\widetilde{O}(h^3 d)$, so if we take the step size $h$ to be of order $\varepsilon^{2/3}/d^{1/3}$, this suggests an improved iteration complexity of $\widetilde{O}(d^{1/3}/\varepsilon^{2/3})$.
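The step-size arithmetic behind this can be recorded explicitly; assuming the per-step variance contribution of order $h^4 d$ dominates, balancing against a target squared displacement of $\varepsilon^2$ gives

```latex
\underbrace{\tfrac{1}{h}}_{\#\text{ steps}}
\;\times\;
\underbrace{O(h^4 d)}_{\text{variance per step}}
\;=\; O(h^3 d) \;\le\; \varepsilon^2
\quad\Longleftrightarrow\quad
h \;\lesssim\; \frac{\varepsilon^{2/3}}{d^{1/3}},
\qquad\text{so}\qquad
\frac{1}{h} \;\gtrsim\; \frac{d^{1/3}}{\varepsilon^{2/3}},
```

compared with step size of order $\varepsilon/\sqrt{d}$, and hence $\sqrt{d}/\varepsilon$ iterations, for the exponential integrator; the final dimension dependence is then obtained after rebalancing against the corrector cost (Section 3.4).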
Arguing that the bias term is dominated by the variance term is where it is crucial that we use randomized midpoint instead of exponential integrator. But recall that the randomized midpoint method was engineered so that it would give an unbiased estimate for the true solution to the reverse ODE if the estimate of the trajectory at the randomized midpoint were exact. In reality we only have an approximation to the latter, but as we show, the error incurred by this is indeed of lower order (see Lemma A.3). One technical complication that arises here is that the relevant quantity to bound is the distance between the true process at the randomized midpoint versus the algorithm’s estimate of it, when both are initialized at an intermediate point in the algorithm’s trajectory. Bounding such quantities in expectation over the randomness of the algorithm’s trajectory can be difficult, but our proof identifies a way of “offloading” some of this difficulty by absorbing some excess terms into a term of the form $\mathbb{E}\|\hat{x}_t - x_t\|^2$, i.e. the squared displacement from the previous time step. Concretely, we obtain the following bound on the bias term:
Lemma 3.2 (Informal, see Lemma A.2 for formal statement).
If $h \lesssim 1/L$, then the bias term in Equation 13 is of lower order than the variance term bounded in Lemma 3.1, up to a term which is a small multiple of the squared displacement $\mathbb{E}\|\hat{x}_t - x_t\|^2$.
3.4 Shortening the corrector steps
While we have outlined how to improve the predictor step in the framework of [CCL+23a], it is quite unclear whether the same can be achieved for the corrector step. Whereas the former is geared towards closeness in Wasserstein distance, the latter is geared towards closeness in KL divergence, and it is a well-known open question in the log-concave sampling literature to obtain analogous discretization bounds in KL for the randomized midpoint method [Che23].
We will sidestep this issue and argue that even using exponential integrator discretization of the underdamped Langevin dynamics will suffice for our purposes, by simply shortening the amount of time for which each corrector step is run.
First, let us briefly recall what was shown in [CCL+23a] for the corrector step. If one runs underdamped Langevin dynamics with stationary distribution $\pi$ for time $T$ and exponential integrator discretization with step size $h$, starting from two distributions $p$ and $q$, then the resulting distributions $p'$ and $q'$ satisfy

$TV(p', q') \lesssim \frac{W_2(p, q)}{\sqrt{T}} + L\,h\,\sqrt{dT},$ (14)

where $L$ is the Lipschitzness of $\nabla\ln\pi$ (see Theorem A.6). At first glance this appears insufficient for our purposes: because of the $h\sqrt{d}$-type term coming from the discretization error, we would need to take step size $h \lesssim \varepsilon/\sqrt{d}$, which would suggest that the number of iterations must scale with $\sqrt{d}/\varepsilon$.
To improve the dimension dependence for our overall predictor-corrector algorithm, we observe that if we take $T$ itself to be smaller, then we can take $h$ to be larger while keeping the discretization error in Equation 14 sufficiently small. Of course, this comes at a cost, as $T$ also appears in the denominator of the $W_2(p,q)$ term in Equation 14. But in our overall proof, the $W_2(p,q)$ term is bounded by the predictor analysis. There, we had quite a bit of slack: even with a predictor step size considerably larger than $\varepsilon/\sqrt{d}$, we could achieve sufficiently small Wasserstein error. By balancing $T$ and the step sizes appropriately, we get our improved dimension dependence.
3.5 Parallel algorithm
Now, we summarize the main proof ideas for our parallel sampling result. In Section 2.3, we described how to approximately solve the reverse ODE over a window of time of length $h$ by running $M$ rounds of the fixed point iteration in Equation 10. In our final algorithm, we will take the window length to be dimension-independent, namely $h = \Theta(1/L)$, so that the main part of the proof is to bound the discretization error incurred over each of these time windows of length $h$. As in the sequential analysis, we will interleave these “predictor” steps with corrector steps given by (parallelized) underdamped Langevin dynamics.
We begin by describing the parallel predictor step. Suppose we have produced an estimate $\hat{x}_{t_0}$ for the reverse process at time $t_0$ and now wish to solve the ODE from time $t_0$ to $t_0+h$. We initialize the estimates $\hat{x}^{(0)}_{u_i}$ at the randomized midpoints via exponential integrator steps starting from the beginning of the window – see Line 1(c) in Algorithm 9 (this can be thought of as the analogue of Equation 7 used in the sequential algorithm). The key difference relative to the sequential algorithm is that here, because the length of the window is dimension-free, the discretization error incurred by this initialization is too large and must be refined using the fixed point iteration in Equation 10. The main step is then to show that with each iteration of Equation 10, the distance to the true reverse process contracts:
Lemma 3.3 (Informal, see Lemma B.2 for formal statement).
Suppose $h \lesssim 1/L$. If $x^*_t$ denotes the solution of the true ODE starting at $x^*_{t_0}$ and running until time $t$, then for all $i \in [R]$ and all rounds $m$,

$\mathbb{E}\bigl\|\hat{x}^{(m)}_{u_i} - x^*_{u_i}\bigr\|^2 \lesssim (Lh)^{m}\,\max_{j \in [R]}\mathbb{E}\bigl\|\hat{x}^{(0)}_{u_j} - x^*_{u_j}\bigr\|^2 + \frac{h^4 d}{R^2} + h^2\varepsilon_{\mathrm{sc}}^2 + \mathbb{E}\bigl\|\hat{x}_{t_0} - x^*_{t_0}\bigr\|^2,$ (15)

where $\hat{x}_{t_0}$ is the iterate of the algorithm from the previous time window, and $x^*_{t_0}$ is the corresponding iterate in the true ODE.
In particular, because $h$ is at most a small multiple of $1/L$, the prefactor $(Lh)^m$ is exponentially decaying in $m$, so that the error incurred by the initial estimate is contracting with each fixed point iteration. Because the initialization is at distance $\operatorname{poly}(d)$ from the true process, $O(\log(d/\varepsilon))$ rounds of contraction thus suffice, which translates to polylogarithmically many parallel rounds for the sampler. The rest of the analysis of the predictor step is quite similar to the analogous proofs for the sequential algorithm (i.e. Lemma B.4 and Lemma B.5 give the corresponding bias and variance bounds).
One shortcoming of the predictor analysis is that the contraction achieved by fixed point iteration ultimately bottoms out at an error which scales with $\sqrt{d}/R$ (see the second term in Equation 15). In order for the discretization error to be sufficiently small, we thus have to take $R$, and thus the total work of the algorithm, to scale with $\sqrt{d}$. So in this case we do not improve over the $\sqrt{d}$ dimension dependence of [CCL+23a], and instead the improvement is in obtaining a parallel algorithm.
For the corrector analysis, we mostly draw upon the recent work of [ACV24] which analyzed a parallel implementation of the underdamped Langevin dynamics. While their guarantee focuses on sampling from log-concave distributions, implicit in their analysis is a bound for general smooth distributions on how much the law of the algorithm and the law of the true process drift apart in a bounded time window (see Lemma B.8). This bound suffices for our analysis of the corrector step, and we can conclude the following:
Theorem 3.4 (Informal, see Theorem B.12).
Let $T$ be an adjustable parameter. Let $p'$ denote the law of the output of running the parallel corrector (see Algorithm 8) for total time $T$ and step size $h$, using an approximate estimate for the score of a target distribution $\pi$ and starting from a sample from another distribution $p$. Then the total variation distance from $p'$ to $\pi$ is bounded in terms of $W_2(p, \pi)$, the discretization error, and the score estimation error, analogously to Equation 14.
Furthermore, this algorithm uses a number of score evaluations scaling with $\sqrt{d}$, over polylogarithmically many parallel rounds.
Overall, our parallel algorithm is somewhat different from the parallel sampler developed in the empirical work of [SBE+24], even apart from the fact that we use randomized midpoint discretization and corrector steps. The reason is that our algorithm applies collocation to fixed windows of time, whereas the algorithm of [SBE+24] utilizes a sliding window approach that proactively shifts the window forward as soon as the iterates at the start of the previous window begin to converge. We leave rigorously analyzing the benefits of this approach as an interesting future direction.
3.6 Log-concave sampling in total variation
Finally, we briefly summarize the simple proof for our result on log-concave sampling in TV, which achieves the best known dimension dependence of $\widetilde{O}(d^{5/12})$. Our main observation is that Shen and Lee’s randomized midpoint method [SL19] applied to the underdamped Langevin process gives a Wasserstein guarantee for log-concave sampling, while the corrector step of [CCL+23a] can convert a Wasserstein guarantee to closeness in TV. Thus, we can simply run the randomized midpoint method, followed by the corrector step, to achieve closeness in TV. Carefully tuning the amount of time spent and the step sizes for each phase of this algorithm yields our improved dimension dependence – see Appendix C for the full proof.
4 Discussion and Future Work
In this work, we showed that it is possible to leverage Shen and Lee’s randomized midpoint method [SL19] to achieve the best known dimension dependence for sampling from arbitrary smooth distributions in TV using diffusion. We also showed how to parallelize our algorithm, and showed that polylogarithmically many parallel rounds suffice for sampling. These constitute the first provable guarantees for parallel sampling with diffusion models. Finally, we showed that our techniques can be used to obtain an improved dimension dependence for log-concave sampling in TV.
We note that relative to [CCL+23a], our result requires a somewhat stronger guarantee on the score estimation error; we believe this is an artifact of our analysis, and it would be interesting to remove this dependence in future work. Importantly, however, prior to the present work it was not known how to achieve an improvement over the $\sqrt{d}$ dependence shown in that paper even in the case that the scores are known exactly. Moreover, another line of work [LWCC23, LHE+24, DCWY24] analyzing diffusion sampling makes an even stronger assumption on the score estimation error; this does not detract from the importance of these works, and we feel the same is true in our case.
We also note that our diffusion results require smoothness assumptions – we assume that the true scores, as well as our score estimates, are $L$-Lipschitz. Although this assumption is standard in the literature, recent work [CLL23, BDBDD24] has analyzed DDPM in the absence of these assumptions, culminating in a nearly linear-in-$d$ dependence for sampling using a discretization of the reverse SDE. However, unlike in the smooth case, it is not known whether even a sublinear-in-$d$ dependence is possible without smoothness assumptions via any algorithm. We leave this as an interesting open question for future work.
Acknowledgements
S.C. thanks Sinho Chewi for a helpful discussion on log-concave sampling. S.G. was funded by NSF Award CCF-1751040 (CAREER) and the NSF AI Institute for Foundations of Machine Learning (IFML).
References
- [AC23] Jason M Altschuler and Sinho Chewi. Faster high-accuracy log-concave sampling via algorithmic warm starts. In 2023 IEEE 64th Annual Symposium on Foundations of Computer Science (FOCS), pages 2169–2176. IEEE, 2023.
- [ACV24] Nima Anari, Sinho Chewi, and Thuy-Duong Vuong. Fast parallel sampling under isoperimetry. CoRR, abs/2401.09016, 2024.
- [AHL+23] Nima Anari, Yizhi Huang, Tianyu Liu, Thuy-Duong Vuong, Brian Xu, and Katherine Yu. Parallel discrete sampling via continuous walks. In Proceedings of the 55th Annual ACM Symposium on Theory of Computing, pages 103–116, 2023.
- [BDBDD24] Joe Benton, Valentin De Bortoli, Arnaud Doucet, and George Deligiannidis. Nearly d-linear convergence bounds for diffusion models via stochastic localization. In The Twelfth International Conference on Learning Representations, 2024.
- [BDD23] Joe Benton, George Deligiannidis, and Arnaud Doucet. Error bounds for flow matching methods. arXiv preprint arXiv:2305.16860, 2023.
- [BGJ+23] James Betker, Gabriel Goh, Li Jing, Tim Brooks, Jianfeng Wang, Linjie Li, Long Ouyang, Juntang Zhuang, Joyce Lee, Yufei Guo, et al. Improving image generation with better captions. Computer Science. https://cdn.openai.com/papers/dall-e-3.pdf, 2(3):8, 2023.
- [BMR22] Adam Block, Youssef Mroueh, and Alexander Rakhlin. Generative modeling with denoising auto-encoders and Langevin sampling. arXiv preprint 2002.00107, 2022.
- [BPH+24] Tim Brooks, Bill Peebles, Connor Holmes, Will DePue, Yufei Guo, Li Jing, David Schnurr, Joe Taylor, Troy Luhman, Eric Luhman, Clarence Ng, Ricky Wang, and Aditya Ramesh. Video generation models as world simulators. 2024.
- [CCL+23a] Sitan Chen, Sinho Chewi, Holden Lee, Yuanzhi Li, Jianfeng Lu, and Adil Salim. The probability flow ODE is provably fast. In Thirty-seventh Conference on Neural Information Processing Systems, 2023.
- [CCL+23b] Sitan Chen, Sinho Chewi, Jerry Li, Yuanzhi Li, Adil Salim, and Anru R Zhang. Sampling is as easy as learning the score: theory for diffusion models with minimal data assumptions. In International Conference on Learning Representations, 2023.
- [CDD23] Sitan Chen, Giannis Daras, and Alex Dimakis. Restoration-degradation beyond linear diffusions: A non-asymptotic analysis for ddim-type samplers. In International Conference on Machine Learning, pages 4462–4484. PMLR, 2023.
- [Che23] Sinho Chewi. Log-concave sampling. Book draft available at https://chewisinho.github.io, 2023.
- [CLL23] Hongrui Chen, Holden Lee, and Jianfeng Lu. Improved analysis of score-based generative modeling: User-friendly bounds under minimal smoothness assumptions. In International Conference on Machine Learning, pages 4735–4763. PMLR, 2023.
- [CRYR24] Haoxuan Chen, Yinuo Ren, Lexing Ying, and Grant M. Rotskoff. Accelerating diffusion models with parallel sampling: Inference at sub-linear time complexity, 2024.
- [DB22] Valentin De Bortoli. Convergence of denoising diffusion models under the manifold hypothesis. Transactions on Machine Learning Research, 2022.
- [DBTHD21] Valentin De Bortoli, James Thornton, Jeremy Heng, and Arnaud Doucet. Diffusion Schrödinger bridge with applications to score-based generative modeling. In M. Ranzato, A. Beygelzimer, Y. Dauphin, P.S. Liang, and J. Wortman Vaughan, editors, Advances in Neural Information Processing Systems, volume 34, pages 17695–17709. Curran Associates, Inc., 2021.
- [DCWY24] Zehao Dou, Minshuo Chen, Mengdi Wang, and Zhuoran Yang. Theory of consistency diffusion models: Distribution estimation meets fast sampling. In Ruslan Salakhutdinov, Zico Kolter, Katherine Heller, Adrian Weller, Nuria Oliver, Jonathan Scarlett, and Felix Berkenkamp, editors, Proceedings of the 41st International Conference on Machine Learning, volume 235 of Proceedings of Machine Learning Research, pages 11592–11612. PMLR, 21–27 Jul 2024.
- [DN21] Prafulla Dhariwal and Alexander Nichol. Diffusion models beat GANs on image synthesis. In M. Ranzato, A. Beygelzimer, Y. Dauphin, P.S. Liang, and J. Wortman Vaughan, editors, Advances in Neural Information Processing Systems, volume 34, pages 8780–8794. Curran Associates, Inc., 2021.
- [EKB+24] Patrick Esser, Sumith Kulal, Andreas Blattmann, Rahim Entezari, Jonas Müller, Harry Saini, Yam Levi, Dominik Lorenz, Axel Sauer, Frederic Boesel, et al. Scaling rectified flow transformers for high-resolution image synthesis. arXiv preprint arXiv:2403.03206, 2024.
- [GLP23] Shivam Gupta, Jasper C.H. Lee, and Eric Price. High-dimensional location estimation via norm concentration for subgamma vectors. In Proceedings of the 40th International Conference on Machine Learning, ICML’23. JMLR.org, 2023.
- [GPPX23] Shivam Gupta, Aditya Parulekar, Eric Price, and Zhiyang Xun. Sample-efficient training for diffusion. arXiv preprint arXiv:2311.13745, 2023.
- [HBE20] Ye He, Krishnakumar Balasubramanian, and Murat A Erdogdu. On the ergodicity, bias and asymptotic normality of randomized midpoint sampling method. Advances in Neural Information Processing Systems, 33:7366–7376, 2020.
- [HJA20] Jonathan Ho, Ajay Jain, and Pieter Abbeel. Denoising diffusion probabilistic models. Advances in Neural Information Processing Systems, 33:6840–6851, 2020.
- [KN24] Saravanan Kandasamy and Dheeraj Nagaraj. The Poisson midpoint method for Langevin dynamics: Provably efficient discretization for diffusion models, 2024.
- [KPH+20] Zhifeng Kong, Wei Ping, Jiaji Huang, Kexin Zhao, and Bryan Catanzaro. DiffWave: A versatile diffusion model for audio synthesis. arXiv preprint arXiv:2009.09761, 2020.
- [LHE+24] Gen Li, Yu Huang, Timofey Efimov, Yuting Wei, Yuejie Chi, and Yuxin Chen. Accelerating convergence of score-based diffusion models, provably. arXiv preprint arXiv:2403.03852, 2024.
- [LLT22] Holden Lee, Jianfeng Lu, and Yixin Tan. Convergence for score-based generative modeling with polynomial complexity. Advances in Neural Information Processing Systems, 35:22870–22882, 2022.
- [LLT23] Holden Lee, Jianfeng Lu, and Yixin Tan. Convergence of score-based generative modeling for general data distributions. In International Conference on Algorithmic Learning Theory, pages 946–985. PMLR, 2023.
- [LWCC23] Gen Li, Yuting Wei, Yuxin Chen, and Yuejie Chi. Towards faster non-asymptotic convergence for diffusion-based generative models. arXiv preprint arXiv:2306.09251, 2023.
- [LWYL22] Xingchao Liu, Lemeng Wu, Mao Ye, and Qiang Liu. Let us build bridges: Understanding and extending diffusion generative models. arXiv preprint arXiv:2208.14699, 2022.
- [MCC+21] Yi-An Ma, Niladri S Chatterji, Xiang Cheng, Nicolas Flammarion, Peter L Bartlett, and Michael I Jordan. Is there an analog of Nesterov acceleration for gradient-based MCMC? 2021.
- [Pid22] Jakiw Pidstrigach. Score-based generative models detect manifolds. In S. Koyejo, S. Mohamed, A. Agarwal, D. Belgrave, K. Cho, and A. Oh, editors, Advances in Neural Information Processing Systems, volume 35, pages 35852–35865. Curran Associates, Inc., 2022.
- [SBE+24] Andy Shih, Suneel Belkhale, Stefano Ermon, Dorsa Sadigh, and Nima Anari. Parallel sampling of diffusion models. Advances in Neural Information Processing Systems, 36, 2024.
- [SDME21] Yang Song, Conor Durkan, Iain Murray, and Stefano Ermon. Maximum likelihood training of score-based diffusion models. In M. Ranzato, A. Beygelzimer, Y. Dauphin, P.S. Liang, and J. Wortman Vaughan, editors, Advances in Neural Information Processing Systems, volume 34, pages 1415–1428. Curran Associates, Inc., 2021.
- [SDWMG15] Jascha Sohl-Dickstein, Eric Weiss, Niru Maheswaranathan, and Surya Ganguli. Deep unsupervised learning using nonequilibrium thermodynamics. In Francis Bach and David Blei, editors, Proceedings of the 32nd International Conference on Machine Learning, volume 37 of Proceedings of Machine Learning Research, pages 2256–2265, Lille, France, July 2015. PMLR.
- [SE19] Yang Song and Stefano Ermon. Generative modeling by estimating gradients of the data distribution. In H. Wallach, H. Larochelle, A. Beygelzimer, F. d'Alché-Buc, E. Fox, and R. Garnett, editors, Advances in Neural Information Processing Systems, volume 32. Curran Associates, Inc., 2019.
- [SL19] Ruoqi Shen and Yin Tat Lee. The randomized midpoint method for log-concave sampling. In Hanna M. Wallach, Hugo Larochelle, Alina Beygelzimer, Florence d’Alché-Buc, Emily B. Fox, and Roman Garnett, editors, Advances in Neural Information Processing Systems 32: Annual Conference on Neural Information Processing Systems 2019, NeurIPS 2019, December 8-14, 2019, Vancouver, BC, Canada, pages 2098–2109, 2019.
- [SSDK+21] Yang Song, Jascha Sohl-Dickstein, Diederik P. Kingma, Abhishek Kumar, Stefano Ermon, and Ben Poole. Score-based generative modeling through stochastic differential equations. In International Conference on Learning Representations, 2021.
- [VKK21] Arash Vahdat, Karsten Kreis, and Jan Kautz. Score-based generative modeling in latent space. In M. Ranzato, A. Beygelzimer, Y. Dauphin, P.S. Liang, and J. Wortman Vaughan, editors, Advances in Neural Information Processing Systems, volume 34, pages 11287–11302. Curran Associates, Inc., 2021.
- [WY22] Andre Wibisono and Kaylee Y. Yang. Convergence in KL divergence of the inexact Langevin algorithm with application to score-based generative models. arXiv preprint arXiv:2211.01512, 2022.
- [WYvdB+24] Kevin E Wu, Kevin K Yang, Rianne van den Berg, Sarah Alamdari, James Y Zou, Alex X Lu, and Ava P Amini. Protein structure generation via folding diffusion. Nature Communications, 15(1):1059, 2024.
- [YD24] Lu Yu and Arnak Dalalyan. Parallelized midpoint randomization for Langevin Monte Carlo. arXiv preprint arXiv:2402.14434, 2024.
- [YKD23] Lu Yu, Avetik Karagulyan, and Arnak Dalalyan. Langevin Monte Carlo for strongly log-concave distributions: Randomized midpoint revisited. arXiv preprint arXiv:2306.08494, 2023.
- [ZCL+23] Shunshi Zhang, Sinho Chewi, Mufan Li, Krishna Balasubramanian, and Murat A Erdogdu. Improved discretization analysis for underdamped Langevin Monte Carlo. In The Thirty Sixth Annual Conference on Learning Theory, pages 36–71. PMLR, 2023.
Roadmap.
In Section A, we give the proof of Theorem 1.1, our main result on sequential sampling with diffusions. In Section B, we give the proof of Theorem 1.2, our main result on parallel sampling with diffusions. In Section C, we give the proof of Theorem 1.3 on log-concave sampling.
As a notational remark, in the proofs to follow we will sometimes use the notation , , and for random variables and to denote the distance between their associated probability distributions. Also, throughout the Appendix, we use to denote time in the forward process.
Appendix A Sequential algorithm
In this section, we describe our sequential randomized-midpoint-based algorithm in detail. Following the framework of [CCL+23a], we begin by describing the predictor step and show in Lemma A.5 that in steps (ignoring other dependencies), when run for time at most starting from , it produces a sample that is close to the true distribution at time . Then, we show that the corrector step can be used to convert our error to error in TV distance by running the underdamped Langevin Monte Carlo algorithm, as described in [CCL+23a]. We show in Lemma A.8 that if we run our predictor and corrector steps in succession for a careful choice of times, we obtain a sample that is close to the true distribution in TV using just steps, but covering a time . Finally, in Theorem A.10, we iterate this bound times to obtain our final iteration complexity of .
A.1 Predictor step
To show the dependence on dimension for the predictor step, we will, roughly speaking, show that its bias after one step is bounded by in Lemma A.3, and that its variance is bounded by in Lemma A.4. Then, iterating these bounds times, as shown in Lemma A.5, gives error in squared Wasserstein distance.
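Schematically, one-step bias and variance bounds of this kind combine into a Wasserstein bound as follows (this is a standard pattern; the symbols $B_n$ and $V_n$ are our shorthand, and the precise constants, discretization terms, and contraction factors are handled in the lemmas themselves): if step $n$ contributes bias at most $B_n$ and variance at most $V_n$, and the exact one-step map does not expand distances by more than a constant factor over the time horizon, then

```latex
W_2^2\bigl(\mathrm{law}(\hat{x}_N),\,\mathrm{law}(x_N)\bigr)
\;\lesssim\; \Bigl(\sum_{n=1}^{N} B_n\Bigr)^{2} \;+\; \sum_{n=1}^{N} V_n .
```

That is, biases accumulate linearly (and are then squared), while the independent randomness introduced by the midpoints accumulates only in quadrature; this is what lets the randomized midpoint's small bias pay off in the final dimension dependence.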
Input parameters:
- Starting sample , Starting time , Number of steps , Step sizes , Score estimates

1. For :
(a) Let
(b) Randomly sample uniformly from .
(c) Let
(d) Let
2. Let
3. Return .
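As a point of reference, a minimal sketch of the generic randomized midpoint step that the predictor instantiates is given below, written for an abstract ODE x' = f(t, x); in the algorithm itself, f is the probability flow drift built from the score estimates. The function names and the toy linear drift are ours, a schematic rather than the paper's exact update.

```python
import math
import random

def randomized_midpoint_step(f, t, x, h, rng):
    """One randomized midpoint step for x' = f(t, x).

    Sample alpha ~ Uniform[0, 1], form a cheap (Euler) estimate of the state
    at the random midpoint t + alpha*h, then take the full step using the
    drift evaluated at that midpoint.  In expectation over alpha, this
    approximates the integral of f over the step with very small bias.
    """
    alpha = rng.random()
    x_mid = x + alpha * h * f(t, x)           # Euler estimate at the midpoint
    return x + h * f(t + alpha * h, x_mid)    # full step using the midpoint drift

def integrate(f, x0, t0, t1, n_steps, seed=0):
    """Iterate the randomized midpoint step over [t0, t1]."""
    rng = random.Random(seed)
    h = (t1 - t0) / n_steps
    x, t = x0, t0
    for _ in range(n_steps):
        x = randomized_midpoint_step(f, t, x, h, rng)
        t += h
    return x

# Toy check on the linear ODE x' = -x, whose exact solution is exp(-t).
x = integrate(lambda t, x: -x, 1.0, 0.0, 1.0, 100)
```

On this linear test problem the output lands close to exp(-1); the analysis above is, of course, about the reverse diffusion ODE rather than this toy.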
Lemma A.1 (Naive ODE Coupling).
Consider two variables starting at time ; run the true ODE for time and let the results be . For , , we have
Proof.
Recall that the true ODE is given by
So,
So,
Lemma A.2.
Suppose . In Algorithm 4, for all , let be the solution of the true ODE starting at at time , running until time . If and , we have
where refers to the expectation over the initial choice .
Proof.
For the proof, we will let . It suffices to show that
(16)
Now,
(17)
Now, note is the solution to the following ODE run for time , starting at at time :
Similarly, is the solution to the following ODE run for time , starting at at time :
So, we have
where the last line is by Young’s inequality. So, by Grönwall’s inequality,
Now, we have
By Corollary D.1,
By Lipschitzness of and Lemma A.1, for ,
and similarly,
So, we have shown that
so that
Combining this with the bound in (17) and recalling that yields the desired inequality in (16). ∎
Lemma A.3 (Sequential Predictor Bias).
Suppose . In Algorithm 4, for all , let be the solution of the true ODE starting at at time and running until time , and let be the solution of the true ODE, starting at . If , and , we have
where is the expectation with respect to the chosen in the step, and is the expectation with respect to the choice of the initial .
Proof.
For the proof, we will fix , and let . By the integral formulation of the true ODE,
Thus, we have
The second term is since
For the first term, we have, by Lemma A.2
The claimed bound follows. ∎
Lemma A.4 (Sequential Predictor Variance).
Suppose . In Algorithm 4, for all , let be the solution of the true ODE starting at at time and running until time , and let be the solution of the true ODE starting at . If and , we have
where refers to the expectation wrt the random in the step, along with the initial choice .
Proof.
Fix and let . We have
The first term was bounded in Lemma A.2:
For the second term,
Now,
The first of these terms is bounded in Corollary D.1:
For the remaining two terms, note that by the Lipschitzness of and Lemma A.1,
and similarly, for ,
(18)
Thus, we have shown that the second term in our bound on is bounded as follows:
For the third term,
Now, we have
where the last step follows by Lemma D.4 and (18). So,
Thus, noting that , we obtain the claimed bound on . ∎
Finally, we put together the bias and variance bounds above to obtain a bound on the Wasserstein error at the end of the Predictor Step.
Lemma A.5 (Sequential Predictor Wasserstein Guarantee).
Suppose that , and that for our sequence of step sizes , . Let . Then, at the end of Algorithm 4,
1. If ,
2. If and for each ,
Here, is the solution of the true ODE beginning at .
Proof.
For all , let be the solution of the exact one step ODE starting from . Let the operator be the expectation over the random choice of in the iteration. Note that only depends on . We have
where the third line is by Young's inequality, and the fourth line is by Lemma A.1. Taking the expectation with respect to , by Lemmas A.3 and A.4,
By induction, noting that , we have
By assumption, . In the first case, for all , so
In the second case,
A.2 Corrector step
For the sequential algorithm, we make use of the underdamped Langevin corrector step and analysis from [CCL+23a]. We reproduce the same here for convenience.
The underdamped Langevin Monte Carlo process with step size is given by:
(19)
where is Brownian motion, and satisfies
(20)
for some target measure . Here, we set the friction parameter .
Then, our corrector step is as follows.
Input parameters:
- Starting sample , Total time , Step size , Score estimate

1. Run underdamped Langevin Monte Carlo in (19) for total time using step size , and let the result be .
2. Return .
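For intuition, a crude discretization of the underdamped Langevin dynamics (19)-(20) can be sketched as follows. The corrector above uses a more careful exponential-integrator discretization with a specific friction setting, so the step rule, friction value, and Gaussian target here are illustrative assumptions only.

```python
import numpy as np

def underdamped_lmc(score, x0, gamma=2.0, h=0.02, n_steps=4000, seed=0):
    """Crude (semi-implicit Euler) discretization of underdamped Langevin:
        dx = v dt,   dv = -gamma * v dt + score(x) dt + sqrt(2*gamma) dB.
    `x0` holds many chains run in parallel; returns the final positions."""
    rng = np.random.default_rng(seed)
    x = np.array(x0, dtype=float)
    v = np.zeros_like(x)
    for _ in range(n_steps):
        x = x + h * v                          # position update
        v = (v - h * gamma * v + h * score(x)  # friction + drift
             + np.sqrt(2.0 * gamma * h) * rng.standard_normal(x.shape))
    return x

# Target: standard Gaussian, whose score is -x.  Run many chains in parallel
# and inspect the empirical moments of the final positions.
samples = underdamped_lmc(lambda x: -x, np.zeros(5000))
```

On this Gaussian target the empirical mean and variance of the chains settle near 0 and 1, the stationary moments of the target.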
Theorem A.6 (Theorem 5 of [CCL+23a], restated).
Corollary A.7 (Underdamped Corrector).
For
A.3 End-to-end analysis
Finally, we put together the analysis of the predictor and corrector steps to obtain our final dependence on sampling time. We first show, in Lemma A.8, that carefully choosing the amount of time to run the corrector results in small TV error after successive rounds of predictor and corrector steps. We then iterate this bound to obtain our final guarantee, Theorem A.10.
Input parameters:
- Start time , End time , Corrector steps time , Number of predictor-corrector steps , Predictor step size , Corrector step size , Score estimates
Lemma A.8 (TV error after one round of predictor and corrector).
Let be a sample from the true distribution at time . Let for . If we set , we have,
1. For , if ,
2. If ,
Proof.
For , let be the result of a single predictor-corrector sequence as described in step 2 of Algorithm 6, but starting from instead of . Additionally, let be the result of running steps and starting from instead of . Similarly, let be the result of only applying the predictor step starting from , analogous to defined in step 2a.
We recall the following lemma on the convergence of the OU process from [CCL+23a].
Lemma A.9 (Lemma 13 of [CCL+23a]).
Let denote the marginal law of the OU process started at . Then, for all ,
Finally, we prove our main theorem on the convergence of our sequential algorithm.
Theorem A.10 (Convergence bound for sequential algorithm).
Proof.
We will let . First, note that by Lemma A.9
We divide our analysis into two steps. For the first steps, we iterate the first part of Lemma A.8 to obtain
Applying the second part of Lemma A.8 for the second stage of the algorithm, we have
Setting , , and , if the score estimation error satisfies , with iteration complexity , we obtain . ∎
Appendix B Parallel algorithm
B.1 Predictor step
In this section, we will apply a parallel version of the randomized midpoint method for the predictor step, where only iteration complexity is required to attain our desired error bound for one predictor step.
In each iteration , we first sample randomized midpoints that are, in expectation, evenly spaced with time intervals between consecutive midpoints. Next, in step (c), we provide an initial estimate of the value at the midpoints using our estimate of the position at time provided by iteration . This step is analogous to step (c) in Algorithm 4 for the sequential predictor step. Then, in step (d), we refine our initial estimates using a discrete version of Picard iteration, where in round we compute a new estimate of based on the round- estimates of for . Note that a trajectory that starts from time and follows the true ODE is a fixed point of the operator mapping continuous functions to continuous functions, where
By smoothness of the true ODE, the continuous Picard iteration converges exponentially fast to the true trajectory, and we will show that the discretization error of the Picard iteration is controlled. After the refinements have sufficiently reduced the estimation error at the randomized midpoints, we make a final calculation, estimating the value of from the estimated values at all the randomized midpoints.
Input parameters:
- Starting sample , Starting time , Number of steps , Step size , Number of midpoint estimates , Number of parallel iterations , Score estimates
- For all : let

1. For :
(a) Let
(b) Randomly sample uniformly from for all
(c) For in parallel: Let
(d) For :
For in parallel:
(e)
2. Let
3. Return .
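The Picard refinement of step (d) can be illustrated on an abstract ODE x' = f(t, x): initialize every midpoint estimate from the left endpoint, then repeatedly re-integrate; each round updates all midpoints in parallel, and the error contracts geometrically once the window length times the Lipschitz constant is below one. The quadrature rule and names below are ours, a schematic rather than the exact update of Algorithm 7.

```python
import math

def picard_refine(f, x0, h, n_mid, n_rounds):
    """Discrete Picard iteration for x(t) solving x' = f(t, x), x(0) = x0,
    on the grid t_j = (j+1) * h / n_mid.  Round r+1 sets
        x_j <- x0 + sum_{i <= j} (h / n_mid) * f(t_i, x_i^{(r)}),
    and every j can be updated in parallel within a round (the drifts
    depend only on the previous round's estimates)."""
    dt = h / n_mid
    ts = [(j + 1) * dt for j in range(n_mid)]
    xs = [x0] * n_mid                        # round-0 init from the left endpoint
    for _ in range(n_rounds):
        drifts = [f(ts[j], xs[j]) for j in range(n_mid)]   # parallel across j
        acc, new_xs = 0.0, []
        for j in range(n_mid):
            acc += dt * drifts[j]            # prefix sums recover all j at once
            new_xs.append(x0 + acc)
        xs = new_xs
    return xs[-1]                            # estimate of x(h)

# Toy check: x' = -x from x(0) = 1 over a window of length h = 0.5.
x_end = picard_refine(lambda t, x: -x, 1.0, 0.5, n_mid=50, n_rounds=12)
```

Here the contraction factor per round is at most h times the Lipschitz constant (0.5 on this toy), so a dozen rounds suffice for the Picard error to be negligible next to the quadrature error.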
In our analysis, we follow the same notation as in Section A.1. We first establish a bound on the initial estimation error incurred in step (c) of each iteration in Algorithm 7.
Claim B.1.
Suppose . Assume . For any , suppose we draw from an arbitrary distribution and then run steps (a)-(e) in Algorithm 7. Then for any ,
where is the solution of the true ODE starting at at time and running until time .
Proof.
Notice that in step (c) of Algorithm 7, the initial estimate of the randomized midpoint is computed with the exact same formula as in step (c) of Algorithm 4, except that we calculate this initial estimate for different randomized midpoints. Notice also that the bound on the discretization error in Lemma A.2 does not depend on the specific value of , as long as the random value is at most . Hence an identical calculation yields the claim. ∎
Next, we show how to drive the initialization error from Claim B.1 down (exponentially) using the Picard iterations in step (d) of Algorithm 7.
Lemma B.2.
Suppose . Assume . For all iterations , suppose we draw from an arbitrary distribution and then run steps (a)-(e) in Algorithm 7. Then for all and ,
(21)
where is the solution of the true ODE starting at at time and running until time .
Proof.
Fixing iteration , we will let , and . The formulas for and have the same coefficient for , so we can bound the difference as follows:
(22)
(23)
The first to second line is by definition and the second to third line is by Young’s inequality. Now, we will bound Equation 22 and Equation 23 separately. By 2.2 and 2.4,
The term in Equation 22 can now be bounded as follows
The first to second line is by inequality , the second to third line is by the fact that , the fifth to sixth line is by and that is at most when .
Similarly, the term in Equation 23 can be bounded as follows
Since ,
The second to third line is by Young’s inequality, and the fourth to fifth line is by 2.2. By Lemma 3 in [CCL+23a],
Now, by Lemma A.1 and the fact that ,
We conclude that the term in Equation 23 can be bounded by
Combining the bounds for Equation 22 and Equation 23, we get
Given sufficiently small constant , . Moreover, by the definition of , . By unrolling the recursion, we get
As a consequence of Lemma B.2, if we take the number of Picard iterations sufficiently large, the error incurred in our estimate for is dominated by the terms in the second line of Eq. (21).
Corollary B.3.
Assume . For all , suppose we draw from an arbitrary distribution and then run steps (a)-(e) in Algorithm 7. In addition, suppose and . Then for any ,
where is the solution of the true ODE starting at at time and running until time .
Proof.
We can now prove the parallel analogues of Lemma A.3 and Lemma A.4. Note that the bounds in Lemma B.4 and Lemma B.5 are identical to those in Lemma A.3 and Lemma A.4, except for an additional factor in the middle term. This additional factor stems from using midpoints in each iteration (compared to one midpoint per iteration in Algorithm 4).
Lemma B.4 (Parallel Predictor Bias).
Assume . For all , suppose we draw from an arbitrary distribution and then run steps (a)-(e) in Algorithm 7. In addition, suppose and . Then we have
where is the solution of the true ODE starting at at time and running until time .
Proof.
Fixing iteration , we will let , , and . We have
(24)
(25)
Since is drawn uniformly from ,
and the term in Equation 25 is equal to . Now we need to bound Equation 24.
The first to second line is by the inequality , the second to third line is by 2.4 and 2.2, and the third to fourth line is by Corollary B.3 (, which satisfies the condition in Corollary B.3). Hence
The first to second step is by the inequality and Young's inequality, the second to third line is by plugging in our previous calculation, and the third to fourth line is by and the fact that . ∎
Lemma B.5 (Parallel Predictor Variance).
Assume . For all , suppose we draw from an arbitrary distribution and then run steps (a)-(e) in Algorithm 7. In addition, suppose and . Then we have
where is the solution of the true ODE starting at at time and running until time .
Proof.
Fixing iteration , we will let , , and . We will separate into several terms and bound each term separately.
(26)
(27)
(28)
Equation 26 is identical to Equation 24 in Lemma B.4, and can be bounded by
Next we will bound Equation 27. By Lemma D.1 and Lemma A.1,
hence Equation 27 can be bounded, by calculations similar to those for Equation 23 in Lemma B.2, by the following term:
Finally, we will bound Equation 28. Since both and belong to the range ,
Moreover, by the fact that (by integration by parts), Lemma D.1, Lemma A.1 and the fact that , we have
Hence Equation 28 can be bounded by
By adding together Equation 26, Equation 27 and Equation 28, and combining terms, we conclude that
We can now prove our main guarantee for the parallel predictor step, which states that with logarithmically many parallel rounds and score estimate queries over a short time interval (of length ) of the reverse process, the algorithm does not drift too far from the true ODE. Our proof follows a similar flow as Lemma A.5.
Note that due to the existence of midpoints in each step, we can set . We will set to be , unless is close to the end time (see Algorithm 9 for the global algorithm and timeline). If is close to , we will repeatedly halve as increases, until we reach the end time.
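A hypothetical rendering of this schedule is given below; the base step size and the final time gap are placeholder parameters, since the exact constants are fixed in the theorem statements.

```python
def predictor_schedule(total_time, base_h, end_gap):
    """Hypothetical step-size schedule: constant steps of length base_h far
    from the end time, then repeatedly halved steps as the trajectory nears
    the end, stopping a distance end_gap before it.  The halving phase adds
    only O(log(base_h / end_gap)) extra steps beyond the constant phase."""
    steps, remaining = [], total_time
    while remaining > end_gap + 1e-12:
        h = min(base_h, remaining / 2.0)   # start halving once near the end
        if remaining - h < end_gap:
            h = remaining - end_gap        # final step lands exactly at end_gap
        steps.append(h)
        remaining -= h
    return steps

# Example: total reverse time 10, base step 0.5, stop 0.01 before the end.
steps = predictor_schedule(10.0, 0.5, 0.01)
```

The schedule covers all of the interval except the final end_gap, and its length is the constant-phase count plus a logarithmic number of halved steps.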
Theorem B.6.
Assume . Let be an adjustable parameter. When we set , and , the Wasserstein distance between the true ODE process and the process in Algorithm 7, both starting from and run for total time is bounded by
Proof.
To avoid confusion, we will reserve as the result of running Algorithm 7 for steps, starting at . We will use to denote the result from running the true ODE process, starting at . Then by an identical calculation as in Lemma A.5,
Now we do a similar calculation as in Lemma A.5, but utilizing the bias and variance of one step of the parallel algorithm instead of the sequential one. Taking the expectation with respect to , by Lemmas B.4 and B.5,
Since , the term is dominated by the term and the term is dominated by . By induction, noting that , we have
Since , is a constant. Moreover, by our choice of , it is always the case that and that . Therefore
and thus
We conclude that
B.2 Corrector step
In this step we will use the parallel algorithm of [ACV24] to estimate the underdamped Langevin diffusion process. Since we will be fixing the score function in time, we will use to denote the true score function for the diffusion process, and to denote the estimated score function. We will choose the friction parameter to be .
Input parameters:
- Starting sample , Number of steps , Step size , Score estimates , Number of midpoint estimates

1. For :
(a) Let
(b) Let represent the algorithmic estimate of at iteration .
(c) Let be a correlated Gaussian vector corresponding to the change caused by the Brownian motion term in time (see [ACV24] for more detail)
(d) For in parallel: Let
(e) For :
For in parallel:
(f)
2. Let
3. Return .
Let denote the total time the parallel corrector step is run (namely, ). Consider two continuous underdamped Langevin diffusion processes and with coupled Brownian motions. The first starts from position and the second starts from position . Both processes start with velocity . We will bound both the distance between and the true sample , and the distance between and the output of Algorithm 8. First, [CCL+23a] gives the following bound on the total variation error between and .
Lemma B.7 ([CCL+23a], Lemma 9).
If , then
Next, [ACV24] bounds the discretization error in Algorithm 8 in terms of quantities that relate to the suprema of and , where .
Lemma B.8 ([ACV24], Theorem 20, Implicit).
To reason about the value of and , we will use the following lemma in [CCL+23a].
Lemma B.9 ([CCL+23a], Lemma 10).
For any ,
Lemma B.10.
Assume . For any ,
and
Proof.
Note that is a stationary distribution of the underdamped Langevin diffusion process, so and . Then by integration by parts. Similarly, . Since , for any , we can now bound and by Lemma B.9 as follows:
and
Theorem B.11.
Let be an adjustable parameter. Algorithm 8 with parameters , , and has discretization error
Proof.
Since and , . Plugging Lemma B.10 into Lemma B.8, we get that
The first to second line is by combining terms and setting , , and the second to third line is by setting . Taking the square root of yields the claim. ∎
Theorem B.12.
Let be an adjustable parameter. When Algorithm 8 is initialized at , there exist parameters , , and such that the total variation distance between the final output of Algorithm 8 and the true distribution can be bounded as
Proof.
By triangle inequality, . Combining Lemma B.7 and Theorem B.11 yields the claim. ∎
B.3 End-to-end analysis
Input parameters:
- Start time , End time , Corrector steps time , Number of predictor-corrector steps , Score estimates

1. Draw .
2.
3. Return .
Theorem B.13 (Parallel End to End Error).
By setting , , , and in Algorithm 7 and Algorithm 8, when , the total variation distance between the output of Algorithm 9 and the target distribution is
with iteration complexity .
Proof.
Let be the result of running the true ODE for time , starting from . Let be the result of running the predictor step in step of Algorithm 9, starting from and start time . In addition, let be the result of the corrector step in step of Algorithm 9, starting from .
We will first bound the error in one predictor + corrector step that starts at . By the triangle inequality for TV distance and the data processing inequality (applied to and ),
(29)
By Theorem B.6 parametrized by and Theorem B.12 parametrized by ,
The first line is by Theorem B.6, and the first to second line is by Theorem B.12. Next, note that at the beginning of the process, , and at the end of the process, . By induction on Equation 29,
By Lemma A.9, . By [LLT23, Lemma 6.4], . Therefore by setting , and in Algorithm 7 and Algorithm 8, when , we obtain .
The iteration complexity of Algorithm 9 given the above parameters is roughly the number of predictor-corrector steps times the iteration complexity of one predictor-corrector step. Note that in any corrector step, and in any predictor step except the last one, only sub-steps are taken, so the iteration complexity of one predictor step (except the last) is and the iteration complexity of one corrector step is . In the last predictor step, the number of steps taken is , and thus the iteration complexity is . We conclude that the total iteration complexity of Algorithm 9 is
∎
Appendix C Log-concave sampling in total variation
In this section, we give a simple proof, using our observation about trading off the time spent on the predictor and corrector steps, of an improved bound for sampling from a log-concave distribution in total variation. Note that for this section, we assume that is the true score of the distribution and is known, as is standard in the log-concave sampling literature.
We begin by recalling, in Algorithm 10, Shen and Lee's randomized midpoint method applied to approximate the underdamped Langevin process, for log-concave sampling in the Wasserstein metric [SL19].
Input parameters:
- Starting sample , Starting , Number of steps , Step size , Score function , .

1. For :
(a) Randomly sample uniformly from .
(b) Generate Gaussian random variable as in Appendix A of [SL19].
(c) Let .
(d) Let .
(e) Let .
2. Return .
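As a simplified illustration only: Algorithm 10's exact update integrates the linear part of the dynamics in closed form and uses correlated Gaussians (Appendix A of [SL19]), but the core idea, evaluating the score at a uniformly random point of the step, can be grafted onto a crude Euler discretization with independent noise, as sketched below on a standard Gaussian target. All names and parameter values here are ours.

```python
import numpy as np

def rm_underdamped_step(score, x, v, h, gamma, rng):
    """One step: sample alpha ~ U[0, 1], take a cheap Euler step of length
    alpha*h to estimate the midpoint position, then advance (x, v) by a full
    step using the score at the midpoint.  Noise is drawn independently
    here, unlike the variance-reduced correlated construction in [SL19]."""
    alpha = rng.random()
    x_mid = x + alpha * h * v                  # midpoint estimate over [0, alpha*h]
    x_new = x + h * v                          # full position update
    v_new = (v - h * gamma * v + h * score(x_mid)
             + np.sqrt(2.0 * gamma * h) * rng.standard_normal(x.shape))
    return x_new, v_new

# Toy run on a standard Gaussian target (score(y) = -y), many chains at once.
rng = np.random.default_rng(0)
x = np.zeros(2000)
v = np.zeros(2000)
for _ in range(3000):
    x, v = rm_underdamped_step(lambda y: -y, x, v, h=0.05, gamma=2.0, rng=rng)
```

Because of the independent-noise simplification, this sketch does not enjoy the variance reduction of the true algorithm, but the chains still equilibrate near the target's moments.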
Theorem C.1 (Theorem 3 of [SL19], restated).
Let , the score function of a log-concave distribution, be such that for the Jacobian of . Let be the root of , and . Let be the condition number. For any , if we set the step size of Algorithm 10 as for some small constant and run the algorithm for iterations, then Algorithm 10 generates such that
where .
Now, we make the following simple observation: if we run the corrector step from Section A.2 for a short time, we can convert the above Wasserstein guarantee into a TV guarantee. We carefully trade off the time spent on the randomized midpoint step above against the corrector step to obtain the improved dimension dependence. Our final algorithm is given in Algorithm 11.
Input parameters:
- Number of Randomized Midpoint steps , Corrector steps time , Randomized Midpoint step size , Corrector step size , Score function .
We obtain the following guarantee with our improved dimension dependence of .
Theorem C.2 (Log-Concave Sampling in Total Variation).
Let be the score function of a log-concave distribution such that for the Jacobian of . Let be the condition number. For any , if we set for a small constant , , and , we have that Algorithm 11 returns with
for . Furthermore, the total iteration complexity is .
Appendix D Helper lemmas
Lemma D.1 (Corollary of [CCL+23a]).
For the ODE
if and , we have, for and ,
Lemma D.2 (Implicit in Lemma of [CCL+23a]).
Suppose , and . For ODEs starting at , where
we have
Lemma D.3 (Lemma of [GLP23], restated).
Let be a distribution over . For , let for independent of . Then,
and
Lemma D.4.
For , for , we have
Furthermore,
Proof.
The first claim is an immediate consequence of the definition of and Lemma D.3. For the second claim, note that