
Learning differential equations that are easy to solve

Jacob Kelly
University of Toronto, Vector Institute

Jesse Bettencourt*
University of Toronto, Vector Institute

Matthew James Johnson
Google Brain

David Duvenaud
University of Toronto, Vector Institute

*Equal contribution. Code available at: github.com/jacobjinkelly/easy-neural-ode
Abstract

Differential equations parameterized by neural networks become expensive to solve numerically as training progresses. We propose a remedy that encourages learned dynamics to be easier to solve. Specifically, we introduce a differentiable surrogate for the time cost of standard numerical solvers, using higher-order derivatives of solution trajectories. These derivatives are efficient to compute with Taylor-mode automatic differentiation. Optimizing this additional objective trades model performance against the time cost of solving the learned dynamics. We demonstrate our approach by training models that are substantially faster to solve, while nearly as accurate, in supervised classification, density estimation, and time-series modelling tasks.

1 Introduction

Figure 1: Top: Trajectories of an ODE fit to the map $\mathbf{z}(t_1)=\mathbf{z}(t_0)+\mathbf{z}(t_0)^3$. The learned dynamics are unnecessarily complex and require many evaluations (black dots) to solve.
Bottom: Regularizing the third total derivative $\frac{\mathrm{d}^3\mathbf{z}(t)}{\mathrm{d}t^3}$ (shown by colour) gives dynamics that fit the same map, but require fewer evaluations to solve.

Differential equations describe a system’s behavior by specifying its instantaneous dynamics. Historically, differential equations have been derived from theory, such as Newtonian mechanics, Maxwell’s equations, or epidemiological models of infectious disease, with parameters inferred from observations. Solutions to these equations usually cannot be expressed in closed form, requiring numerical approximation.

Recently, ordinary differential equations parameterized by millions of learned parameters, called neural ODEs, have been used as latent time-series models, as density models, and as replacements for very deep neural networks (Rubanova et al., 2019; Grathwohl et al., 2019; Chen et al., 2018). These learned models are not constrained to match a theoretical model, only to optimize an objective on observed data. Learned models with nearly indistinguishable predictions can have substantially different dynamics. This raises the possibility that we can find equivalent models that are easier and faster to solve. Yet standard training methods have no way to penalize the complexity of the dynamics being learned.

How can we learn dynamics that are faster to solve numerically without substantially changing their predictions? Many of the computational advantages of a continuous-time formulation come from using adaptive solvers, and most of the time cost of these solvers comes from repeatedly evaluating the dynamics function, which in our settings is a moderately-sized neural network. So, we’d like to reduce the number of function evaluations (NFE) required for these solvers to reach a given error tolerance. Ideally, we would add a term penalizing the NFE to the training objective, and let a gradient-based optimizer trade off between solver cost and predictive performance. But because NFE is integer-valued, we need to find a differentiable surrogate.

The NFE taken by an adaptive solver depends on how far it can extrapolate the trajectory forward without introducing too much error. For example, for a standard adaptive-step Runge-Kutta solver with order $m$, the step size is approximately inversely proportional to the norm of the local $m$th total derivative of the solution trajectory with respect to time. That is, a larger $m$th derivative leads to a smaller step size and thus more function evaluations. Thus, we propose to minimize the norm of this total derivative during training, as a way to control the time required to solve the learned dynamics.

In this paper, we investigate the effect of this speed regularization in various models and solvers. We examine the relationship between the solver order and the regularization order, and characterize the tradeoff between speed and performance. In most instances, we find that solver speed can be approximately doubled without a substantial increase in training loss. We also provide an extension to the JAX program transformation framework that implements Taylor-mode automatic differentiation, which is asymptotically more efficient for computing the required total derivatives than standard nested gradients.

Our work compares against and generalizes that of Finlay et al. (2020), who proposed regularizing dynamics in the FFJORD density estimation model, and showed that it stabilized dynamics enough in that setting to allow the use of fixed-step solvers during training.

2 Background

An ordinary differential equation (ODE) specifies the instantaneous change of a vector-valued state $\mathbf{z}(t)$: $\frac{\mathrm{d}\mathbf{z}(t)}{\mathrm{d}t}=f(\mathbf{z}(t),t,\theta)$. Given an initial condition $\mathbf{z}(t_0)$, computing the state at a later time:

$$\mathbf{z}(t_1)=\mathbf{z}(t_0)+\int_{t_0}^{t_1}f(\mathbf{z}(t),t,\theta)\,\mathrm{d}t$$

is called an initial value problem (IVP). For example, $f$ could describe the equations of motion for a particle, or the transmission and recovery rates for a virus across a population. Usually, the required integral has no analytic solution, and must be approximated numerically.
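To make the IVP concrete, here is a minimal sketch in plain Python (not code from our repository) that approximates $\mathbf{z}(t_1)$ with the classical fixed-step fourth-order Runge-Kutta method. The scalar dynamics $f(z,t,\theta)=-\theta z$ and the helper name `rk4_solve` are illustrative assumptions; this choice of $f$ has the closed-form solution $z(t_1)=z(t_0)e^{-\theta(t_1-t_0)}$, which lets us check the numerical answer.

```python
import math


def rk4_solve(f, z0, t0, t1, theta, num_steps=100):
    """Approximate z(t1) for dz/dt = f(z, t, theta) with classical fixed-step RK4."""
    h = (t1 - t0) / num_steps
    z, t = z0, t0
    for _ in range(num_steps):
        k1 = f(z, t, theta)
        k2 = f(z + 0.5 * h * k1, t + 0.5 * h, theta)
        k3 = f(z + 0.5 * h * k2, t + 0.5 * h, theta)
        k4 = f(z + h * k3, t + h, theta)
        z = z + (h / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4)
        t = t + h
    return z


def decay(z, t, theta):
    # Hypothetical linear decay dynamics: dz/dt = -theta * z.
    return -theta * z


approx = rk4_solve(decay, z0=1.0, t0=0.0, t1=1.0, theta=2.0)
exact = math.exp(-2.0)
print(approx, exact)
```

A fixed-step solver like this always spends `4 * num_steps` evaluations of $f$, regardless of how easy the dynamics are; adaptive solvers instead let the dynamics determine the cost.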

Adaptive-step Runge-Kutta ODE Solvers

Runge-Kutta methods (Runge, 1895; Kutta, 1901) approximate the solution trajectories of ODEs through a series of small steps, starting at time $t_0$. At each step, they choose a step size $h$, and fit a local approximation to the solution, $\hat{\mathbf{z}}(t)$, using several evaluations of $f$. When $h$ is sufficiently small, the numerical error of an $m$th-order method is bounded by $\|\hat{\mathbf{z}}(t+h)-\mathbf{z}(t+h)\|\leq ch^{m+1}$ for some constant $c$ (Hairer et al., 1993). So, for an $m$th-order method, the local error grows approximately in proportion to the size of the $m$th coefficient in the Taylor expansion of the true solution. All else being equal, controlling this coefficient for all dimensions of $\mathbf{z}(t)$ will allow larger steps to be taken without surpassing the error tolerance.
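The sketch below illustrates how an adaptive solver's cost depends on the dynamics. For brevity it uses the simple Heun-Euler 2(1) embedded pair rather than dopri5, with a conventional step-size controller (safety factor 0.9, rescaling clipped to $[0.2h, 5h]$); these constants and the helper name `heun_euler_solve` are assumptions made for illustration. The same solver needs only a handful of evaluations for dynamics whose solution is a straight line (all derivatives above the first vanish), but many more for rapidly oscillating dynamics with large higher derivatives.

```python
import math


def heun_euler_solve(f, z0, t0, t1, tol=1e-6, h0=0.1):
    """Adaptive Heun-Euler 2(1) solver for scalar dz/dt = f(z, t).

    Returns (approximate z(t1), number of function evaluations).
    """
    z, t, h = z0, t0, h0
    nfe = 0
    while t < t1:
        last = t + h >= t1
        if last:
            h = t1 - t                      # shorten the final step to land on t1
        k1 = f(z, t)
        k2 = f(z + h * k1, t + h)
        nfe += 2
        z_high = z + 0.5 * h * (k1 + k2)    # 2nd-order (Heun) estimate
        z_low = z + h * k1                  # 1st-order (Euler) estimate
        err = abs(z_high - z_low)           # embedded local error estimate
        if err <= tol:                      # accept the step
            z = z_high
            t = t1 if last else t + h
        # Rescale h by 0.9 * (tol / err)^(1/2), clipped to [0.2, 5].
        factor = 5.0 if err == 0.0 else min(5.0, max(0.2, 0.9 * math.sqrt(tol / err)))
        h *= factor
    return z, nfe


def straight_line(z, t):
    return 1.0                              # z(t) = z0 + t: no curvature at all


def oscillating(z, t):
    return math.cos(20.0 * t)               # large higher-order derivatives


_, nfe_line = heun_euler_solve(straight_line, 0.0, 0.0, 1.0)
_, nfe_osc = heun_euler_solve(oscillating, 0.0, 0.0, 1.0)
print(nfe_line, nfe_osc)
```

The exact counts depend on the tolerance and controller constants; only the ordering matters here.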

Neural Ordinary Differential Equations

The dynamics function $f$ can be a moderately-sized neural network, and its parameters $\theta$ trained by gradient descent. Solving the resulting IVP is analogous to evaluating a very deep residual network in which the number of layers corresponds to the number of function evaluations of the solver (Chang et al., 2017; Ruthotto & Haber, 2018; Chen et al., 2018). Solving such continuous-depth models using adaptive numerical solvers has several computational advantages over standard discrete-depth network architectures. However, this approach is often slower than using a fixed-depth network, due to an inability to control the number of steps required by an adaptive-step solver.

3 Regularizing Higher-Order Derivatives for Speed

The ability of Runge-Kutta methods to take large and accurate steps is limited by the $K$th-order Taylor coefficients of the solution trajectory. We would like these coefficients to be small. Specifically, we propose to regularize the squared norm of the $K$th-order total derivatives of the state with respect to time, integrated along the entire solution trajectory:

$$\mathcal{R}_K(\theta)=\int_{t_0}^{t_1}\left\|\frac{\mathrm{d}^K\mathbf{z}(t)}{\mathrm{d}t^K}\right\|_2^2\,\mathrm{d}t \tag{1}$$

where $\|\cdot\|_2^2$ is the squared $\ell_2$ norm, and the dependence on the dynamics parameters $\theta$ is implicit through the solution $\mathbf{z}(t)$ integrating $\frac{\mathrm{d}\mathbf{z}(t)}{\mathrm{d}t}=f(\mathbf{z}(t),t,\theta)$. During training, we weight this regularization term by a hyperparameter $\lambda$ and add it to our original loss to get our regularized objective:

$$L_{\text{reg}}(\theta)=L(\theta)+\lambda\mathcal{R}_K(\theta) \tag{2}$$

Figure 2: $m$th-order Runge-Kutta solvers need small steps when the dynamics have non-zero total derivatives of order $K\geq m$ (lower triangle). Color denotes the increase in number of steps from $K$ to $K-1$, normalized for each solver order.

What kind of solutions are allowed when $\mathcal{R}_K=0$? For $K=0$, we have $\|\mathbf{z}(t)\|_2^2=0$, so the only possible solution is $\mathbf{z}(t)=0$. For $K=1$, we have $\|f(\mathbf{z}(t),t)\|_2^2=0$, so all solutions are constant, flat trajectories. For $K=2$, solutions are straight-line trajectories. Higher values of $K$ shrink higher derivatives, but don’t penalize lower-order dynamics. For instance, a quadratic trajectory will have $\mathcal{R}_3=0$. Setting the $K$th-order dynamics to exactly zero everywhere automatically makes all higher orders zero as well. Figure 1 shows that regularizing $\mathcal{R}_3$ on a toy 1D neural ODE reduces NFE.
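As a worked example of the quadratic case, consider a hypothetical trajectory $\mathbf{z}(t)=\mathbf{a}+\mathbf{b}t+\mathbf{c}t^2$ (the vectors $\mathbf{a},\mathbf{b},\mathbf{c}$ are illustrative, not taken from our experiments). Its second derivative is constant and its third vanishes, so eq. 1 gives:

```latex
\frac{\mathrm{d}^2\mathbf{z}(t)}{\mathrm{d}t^2} = 2\mathbf{c}, \qquad
\frac{\mathrm{d}^3\mathbf{z}(t)}{\mathrm{d}t^3} = \mathbf{0}
\quad\Longrightarrow\quad
\mathcal{R}_2 = \int_{t_0}^{t_1} \|2\mathbf{c}\|_2^2\,\mathrm{d}t = 4\|\mathbf{c}\|_2^2\,(t_1 - t_0),
\qquad
\mathcal{R}_3 = 0.
```

So $\mathcal{R}_3$ leaves this trajectory unpenalized, while (for $t_1 > t_0$) $\mathcal{R}_2$ vanishes only when $\mathbf{c}=\mathbf{0}$, that is, when the path is a straight line.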

Which orders should we regularize? We propose matching the order of the regularizer to that of the solver being used. We conjecture that regularizing dynamics of lower orders than that of the solver restricts the model unnecessarily, and that letting the lower orders remain unregularized should not increase NFE very much. Figure 2 shows empirically which orders of Runge-Kutta solvers can efficiently solve which orders of toy polynomial trajectories. We partially confirm these conjectures on real models and datasets in section 6.2.

The solution trajectory and our regularization term can be computed in a single call to an ODE solver by augmenting the system with the integrand in eq. 1.
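A minimal sketch of this augmentation, assuming hypothetical scalar, autonomous dynamics $f(z)=-z$ and $K=2$. For autonomous $f$ the chain rule gives $\frac{\mathrm{d}^2 z}{\mathrm{d}t^2}=\frac{\partial f}{\partial z}f(z)$, which is hand-coded here; section 4 describes how Taylor mode computes such terms automatically for any $K$. The augmented state $(z, r)$ evolves with $\frac{\mathrm{d}r}{\mathrm{d}t}=\left\|\frac{\mathrm{d}^2 z}{\mathrm{d}t^2}\right\|_2^2$, so one solve (fixed-step RK4 here, an arbitrary choice) returns both $z(t_1)$ and $\mathcal{R}_2$. For this $f$ the exact answers are $z(1)=e^{-1}$ and $\mathcal{R}_2=\int_0^1 e^{-2t}\,\mathrm{d}t=(1-e^{-2})/2$.

```python
import math


def f(z):
    # Hypothetical scalar dynamics dz/dt = f(z) = -z.
    return -z


def df_dz(z):
    # Hand-coded derivative of f; Taylor-mode AD would automate this step.
    return -1.0


def augmented_dynamics(state):
    """Time derivative of the augmented state (z, r), with dr/dt = ||d^2 z/dt^2||^2."""
    z, _ = state
    dz = f(z)
    d2z = df_dz(z) * dz          # chain rule for autonomous f: z'' = f'(z) f(z)
    return (dz, d2z ** 2)


def rk4_augmented(state, t0, t1, num_steps=200):
    """Classical fixed-step RK4 applied to the augmented state tuple."""
    h = (t1 - t0) / num_steps

    def shift(s, k, scale):
        return tuple(si + scale * ki for si, ki in zip(s, k))

    for _ in range(num_steps):
        k1 = augmented_dynamics(state)
        k2 = augmented_dynamics(shift(state, k1, 0.5 * h))
        k3 = augmented_dynamics(shift(state, k2, 0.5 * h))
        k4 = augmented_dynamics(shift(state, k3, h))
        state = tuple(s + (h / 6.0) * (a + 2 * b + 2 * c + d)
                      for s, a, b, c, d in zip(state, k1, k2, k3, k4))
    return state


# Start with r = 0 so that r(t1) accumulates the integral in eq. 1.
z1, r2 = rk4_augmented((1.0, 0.0), 0.0, 1.0)
print(z1, math.exp(-1.0))
print(r2, (1.0 - math.exp(-2.0)) / 2.0)
```

During training, $\lambda\,\mathcal{R}_2$ would then be added to the task loss as in eq. 2.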

4 Efficient Higher Order Differentiation with Taylor Mode

The number of terms in higher-order forward derivatives grows exponentially in $K$, becoming prohibitively expensive for $K=5$, and causing substantial slowdowns even for $K=2$ and $K=3$. Luckily, there exists a generalization of forward-mode automatic differentiation (AD), known as Taylor mode, which can compute the total derivative exactly for a cost of only $\mathcal{O}(K^2)$. We found that this asymptotic improvement reduced wall-clock time by an order of magnitude, even for $K$ as low as 3.

First-order forward-mode AD

Standard forward-mode AD computes, for a function $f(x)$ and an input perturbation vector $v$, the product $\frac{\partial f}{\partial x}v$. This Jacobian-vector product, or JVP, can be computed efficiently without explicitly instantiating the Jacobian. This implicit computation of JVPs is straightforward whenever $f$ is a composition of operations for which implicit JVP rules are known.
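As a minimal illustration of first-order forward mode, the dual-number sketch below carries a tangent alongside each value and applies the product rule and the rule $\mathrm{d}\sin(x)=\cos(x)\,\mathrm{d}x$, so the Jacobian is never formed. The `Dual` class and the `jvp` and `dual_sin` helpers are hypothetical names in plain Python; in JAX the corresponding transformation is `jax.jvp`.

```python
import math


class Dual:
    """A value paired with a tangent; arithmetic propagates the JVP alongside it."""

    def __init__(self, primal, tangent=0.0):
        self.primal = primal
        self.tangent = tangent

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.primal + other.primal, self.tangent + other.tangent)

    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # Product rule: d(uv) = du * v + u * dv.
        return Dual(self.primal * other.primal,
                    self.tangent * other.primal + self.primal * other.tangent)

    __rmul__ = __mul__


def dual_sin(x):
    # JVP rule for sin: d sin(x) = cos(x) dx.
    return Dual(math.sin(x.primal), math.cos(x.primal) * x.tangent)


def jvp(f, primals, tangents):
    """Return (f(primals), (df/dx) v) for scalar-valued f, without forming a Jacobian."""
    out = f(*[Dual(x, v) for x, v in zip(primals, tangents)])
    return out.primal, out.tangent


def g(x, y):
    return x * y + dual_sin(x)


value, directional = jvp(g, [1.0, 2.0], [1.0, 0.5])
print(value, directional)
```

Here `jvp(g, [1.0, 2.0], [1.0, 0.5])` yields $g(1,2)=2+\sin 1$ together with $(2+\cos 1)\cdot 1 + 1\cdot 0.5 = 2.5+\cos 1$.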

Higher-order Jacobian-vector products

Forward-mode AD can be generalized to higher orders to compute $K$th-order Jacobians contracted $K$ times against the perturbation vector: $\frac{\partial^K f}{\partial x^K}v^{\otimes K}$. Similarly, this can also be computed without representing any Jacobian matrices explicitly.

A naïve approach to higher-order forward mode is to recursively apply first-order forward mode. Specifically, nesting JVPs $K$ times gives the right answer,

$$\frac{\partial^K f}{\partial x^K}v^{\otimes K}=\frac{\partial}{\partial x}\left(\cdots\left(\frac{\partial}{\partial x}\left(\frac{\partial f}{\partial x}v\right)v\right)\cdots v\right),$$

but causes an unnecessary exponential slowdown, costing $\mathcal{O}(\exp(K))$. This is because expressions that appear in lower derivatives also appear in higher derivatives, but the work to compute them is not shared across orders.
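The sketch below nests dual numbers to obtain $\frac{\partial^K f}{\partial x^K}v^{\otimes K}$ for a scalar $x$ (plain Python; all helper names are hypothetical). It also shows where the waste comes from: each level of nesting doubles the number of scalars carried, so $K$ nested JVPs propagate $2^K$ numbers, many of them duplicates (for instance, all $K$ first-order components equal $f'(x)v$), whereas a truncated Taylor series needs only $K+1$.

```python
class Dual:
    """Dual number whose components may themselves be Dual numbers (nesting)."""

    def __init__(self, primal, tangent=0.0):
        self.primal, self.tangent = primal, tangent

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.primal + other.primal, self.tangent + other.tangent)

    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.primal * other.primal,
                    self.tangent * other.primal + self.primal * other.tangent)

    __rmul__ = __mul__


def constant(c, depth):
    """Embed scalar c at nesting depth `depth`, with every tangent part zero."""
    return c if depth == 0 else Dual(constant(c, depth - 1), constant(0.0, depth - 1))


def seed(x, v, depth):
    """Build x + v*(eps_1 + ... + eps_depth) as `depth` nested Dual numbers."""
    return x if depth == 0 else Dual(seed(x, v, depth - 1), constant(v, depth - 1))


def nested_jvp(f, x, v, order):
    """Directional derivative f^(K)(x) v^K computed with K nested first-order JVPs."""
    y = f(seed(x, v, order))
    for _ in range(order):
        y = y.tangent            # peel off one infinitesimal per nesting level
    return y


def num_components(y):
    """Count the scalars carried by a (possibly nested) Dual number."""
    if not isinstance(y, Dual):
        return 1
    return num_components(y.primal) + num_components(y.tangent)


def cube(x):
    return x * x * x


print(nested_jvp(cube, 2.0, 1.5, 3), num_components(seed(2.0, 1.5, 3)))
```

For $f(x)=x^3$ at $x=2$, $v=1.5$, the second- and third-order results are $f''(x)v^2=27$ and $f'''(x)v^3=20.25$, while the third-order computation already carries $2^3=8$ scalars.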

| Function | Taylor propagation rule |
|---|---|
| $y=z+cw$ | $y_{[k]}=z_{[k]}+cw_{[k]}$ |
| $y=z*w$ | $y_{[k]}=\sum_{j=0}^{k}z_{[j]}w_{[k-j]}$ |
| $y=z/w$ | $y_{[k]}=\frac{1}{w_{[0]}}\left[z_{[k]}-\sum_{j=0}^{k-1}y_{[j]}w_{[k-j]}\right]$ |
| $y=\exp(z)$ | $\tilde{y}_{[k]}=\sum_{j=1}^{k}y_{[k-j]}\tilde{z}_{[j]}$ |
| $s=\sin(z)$ | $\tilde{s}_{[k]}=\sum_{j=1}^{k}\tilde{z}_{[j]}c_{[k-j]}$ |
| $c=\cos(z)$ | $\tilde{c}_{[k]}=-\sum_{j=1}^{k}\tilde{z}_{[j]}s_{[k-j]}$ |

Table 1: Rules for propagating Taylor polynomial coefficients through standard functions. These rules generalize standard first-order derivatives. Notation: $z_{[i]}=\frac{1}{i!}z_i$ and $\tilde{z}_{[i]}=\frac{i}{i!}z_i$.
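A minimal plain-Python sketch of these propagation rules, acting on lists of normalized coefficients $[z_{[0]},\dots,z_{[K]}]$ (all function names are hypothetical). The exp, sin, and cos rules are written with $j\,z_{[j]}$ in place of $\tilde{z}_{[j]}$, and the division rule solves $z = y\cdot w$ for the coefficients of $y$ one at a time. Feeding in the series of the identity function $t$ around $0$ recovers familiar Taylor expansions.

```python
import math


def t_add(z, w, c=1.0):
    """y = z + c*w, coefficient-wise."""
    return [zk + c * wk for zk, wk in zip(z, w)]


def t_mul(z, w):
    """y = z*w: Cauchy product of the coefficient sequences."""
    return [sum(z[j] * w[k - j] for j in range(k + 1)) for k in range(len(z))]


def t_div(z, w):
    """y = z/w: solve z = y*w for the coefficients of y, lowest order first."""
    y = []
    for k in range(len(z)):
        y.append((z[k] - sum(y[j] * w[k - j] for j in range(k))) / w[0])
    return y


def t_exp(z):
    """y = exp(z), using k*y_k = sum_{j=1..k} j*z_j*y_{k-j} (from y' = y z')."""
    y = [math.exp(z[0])]
    for k in range(1, len(z)):
        y.append(sum(j * z[j] * y[k - j] for j in range(1, k + 1)) / k)
    return y


def t_sin_cos(z):
    """s = sin(z), c = cos(z), propagated together since each rule needs the other."""
    s, c = [math.sin(z[0])], [math.cos(z[0])]
    for k in range(1, len(z)):
        s.append(sum(j * z[j] * c[k - j] for j in range(1, k + 1)) / k)
        c.append(-sum(j * z[j] * s[k - j] for j in range(1, k + 1)) / k)
    return s, c


t = [0.0, 1.0, 0.0, 0.0, 0.0]                  # series of f(t) = t, truncated at degree 4
print(t_exp(t))                                 # coefficients of e^t, i.e. 1/k!
print(t_sin_cos(t))                             # coefficients of sin t and cos t
print(t_div([1.0, 0.0, 0.0, 0.0, 0.0], [1.0, -1.0, 0.0, 0.0, 0.0]))  # 1/(1 - t)
```

Each rule computes coefficient $k$ from lower-order coefficients with a sum of length $k$, which is where the $\mathcal{O}(K^2)$ total cost comes from.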
Taylor Mode

Taylor-mode AD generalizes first-order forward mode to compute the first $K$ derivatives exactly with a time cost of only $\mathcal{O}(K^2)$ or $\mathcal{O}(K\log K)$, depending on the operations involved. Instead of providing rules for propagating perturbation vectors, one provides rules for propagating truncated Taylor series. Some example rules are shown in table 1. For more details see the Appendix and Griewank & Walther (2008, Chapter 13). We provide an open source implementation of Taylor-mode AD in the JAX Python library (Bradbury et al., 2018).
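Applied to an ODE, Taylor mode yields the total derivatives that appear in eq. 1. If $\frac{\mathrm{d}\mathbf{z}}{\mathrm{d}t}=f(\mathbf{z})$ and $f(\mathbf{z}(t))$ has normalized coefficients $y_{[k]}$, then $z_{[k+1]}=y_{[k]}/(k+1)$, so the solution's coefficients can be generated one order at a time. The sketch below assumes hypothetical scalar dynamics $f(z)=z^2$, built only from the multiplication rule of table 1; its exact solution $z(t)=z_0/(1-z_0 t)$ gives $\frac{\mathrm{d}^K z}{\mathrm{d}t^K}\big|_{t=0}=K!\,z_0^{K+1}$ as a check. For simplicity it re-evaluates $f$ on the whole truncated series at every order, which costs more than the $\mathcal{O}(K^2)$ schemes mentioned above; a production version would instead use a Taylor-mode transformation such as JAX's `jax.experimental.jet`.

```python
import math


def t_mul(z, w):
    """Taylor rule for y = z*w on normalized coefficients."""
    return [sum(z[j] * w[k - j] for j in range(k + 1)) for k in range(len(z))]


def dynamics_series(z):
    # Hypothetical dynamics f(z) = z^2, applied to a truncated Taylor series.
    return t_mul(z, z)


def solution_derivatives(f_series, z0, K):
    """Return [z(t0), dz/dt, ..., d^K z/dt^K] at t0 for the autonomous ODE dz/dt = f(z).

    If f(z(t)) has normalized coefficients y_[k], then z_[k+1] = y_[k] / (k+1);
    each pass extends the solution's series by one coefficient.
    """
    coeffs = [z0]
    for k in range(K):
        y = f_series(coeffs)               # coefficients y_[0..k] of f(z(t))
        coeffs.append(y[k] / (k + 1))
    # Convert normalized coefficients z_[k] back to derivatives k! * z_[k].
    return [math.factorial(k) * c for k, c in enumerate(coeffs)]


derivs = solution_derivatives(dynamics_series, 0.5, 4)
print(derivs)
```

The squared norm of the $K$th entry, evaluated along the trajectory, is exactly the integrand of $\mathcal{R}_K$.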

5 Experiments

Figure 3: Number of function evaluations (NFE) and training error during training. Speed regularization (solid) decreases the NFE throughout training without substantially changing the training error.

We consider three different tasks in which continuous-depth or continuous-time models might have computational advantages over standard discrete-depth models: supervised learning, continuous generative modeling of time-series (Rubanova et al., 2019), and density estimation using continuous normalizing flows (Grathwohl et al., 2019). Unless specified otherwise, we use the standard dopri5 Runge-Kutta 4(5) solver (Dormand & Prince, 1980; Shampine, 1986).

5.1 Supervised Learning

We construct a model for MNIST classification: it takes as input a flattened MNIST image, integrates it through dynamics given by a simple MLP, then applies a linear classification layer. In fig. 3 we compare the NFE and training error of a model with and without regularizing $\mathcal{R}_3$.

Figure 4: Regularizing dynamics in a latent ODE modeling PhysioNet clinical data: (a) unregularized, (b) regularized. Shown is a representative 2-dimensional slice of 20-dimensional dynamics. We reduce average NFE from 281 to 90 while only incurring an 8% increase in loss.

5.2 Continuous Generative Time Series Models

As in Rubanova et al. (2019), we use the Latent ODE architecture for modelling trajectories of ICU patients using the PhysioNet Challenge 2012 dataset (Silva et al., 2012). This variational autoencoder architecture uses an RNN recognition network, and models the state dynamics using an ODE in a latent space.

In the supervised learning setting described in the previous section, only the final state affects model predictions. In contrast, time-series models’ predictions also depend on the value of the trajectory at all intermediate times when observations were made. So, we might expect speed regularization to be ineffective due to these extra constraints on the dynamics. However, fig. 4 shows that, without changing their overall shape, the latent dynamics can be adjusted to reduce their NFE by a factor of 3.

5.3 Density Estimation with Continuous Normalizing Flows

Our third task is unsupervised density estimation, using a scalable variant of continuous normalizing flows called FFJORD (Grathwohl et al., 2019). We fit the MINIBOONE tabular dataset from Papamakarios et al. (2017) and the MNIST image dataset (LeCun et al., 2010). We use the respective single-flow architectures from Grathwohl et al. (2019).

Grathwohl et al. (2019) noted that the NFE required to numerically integrate their dynamics could become prohibitively large over the course of training. Table 2 shows that we can reduce NFE by 38% for only a 0.6% increase in negative log-likelihood, measured in bits/dim.
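These percentages follow from the adaptive-solver ($\infty$ steps) rows of Table 2, comparing the unregularized model with ours:

```latex
\frac{149 - 92}{149} \approx 0.38 \;\;\text{(NFE reduction)}, \qquad
\frac{1.039 - 1.033}{1.033} \approx 0.006 \;\;\text{(bits/dim increase)}
```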

How to train your Neural ODE

We compare against the approach of Finlay et al. (2020), who design two regularization terms specifically for stabilizing the dynamics of FFJORD models:

$$\mathcal{K}(\theta)=\int_{t_0}^{t_1}\|f(\mathbf{z}(t),t,\theta)\|_2^2\,\mathrm{d}t \tag{3}$$
$$\mathcal{B}(\theta)=\int_{t_0}^{t_1}\|\epsilon^\intercal\nabla_{\mathbf{z}}f(\mathbf{z}(t),t,\theta)\|_2^2\,\mathrm{d}t,\qquad\epsilon\sim\mathcal{N}(0,I) \tag{4}$$

The first term is designed to encourage straight-line paths, and the second, stochastic, term is designed to reduce overfitting. Finlay et al. (2020) used fixed-step solvers during training for some datasets. In section 6.3, we compare these two regularizers with ours, training with either adaptive or fixed-step solvers and evaluating with an adaptive solver.
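To make eqs. 3 and 4 concrete, the sketch below evaluates both integrands at a single point for hypothetical linear dynamics $f(\mathbf{z})=A\mathbf{z}$, whose Jacobian $\nabla_{\mathbf{z}}f$ is simply $A$. The integrand of $\mathcal{B}$ is estimated with one Gaussian probe $\epsilon$ per sample; since $\mathbb{E}\|\epsilon^\intercal A\|_2^2=\|A\|_F^2$, averaging many samples should approach the squared Frobenius norm. This is an illustration only (plain Python, hypothetical names), not how Finlay et al. (2020) implement the terms, which reuse quantities already computed in the FFJORD training objective.

```python
import random

A = [[1.0, 2.0],
     [3.0, 4.0]]        # hypothetical Jacobian of linear dynamics f(z) = A z


def matvec(M, v):
    """Matrix-vector product M v."""
    return [sum(M[i][j] * v[j] for j in range(len(v))) for i in range(len(M))]


def vecmat(v, M):
    """Row-vector product v^T M, i.e. a vector-Jacobian product."""
    return [sum(v[i] * M[i][j] for i in range(len(M))) for j in range(len(M[0]))]


def kinetic_integrand(z):
    """Integrand of K(theta): ||f(z)||^2."""
    return sum(x * x for x in matvec(A, z))


def jacobian_integrand_estimate(rng):
    """One-sample estimate of the integrand of B(theta): ||eps^T grad_z f||^2, eps ~ N(0, I)."""
    eps = [rng.gauss(0.0, 1.0) for _ in range(len(A))]
    return sum(x * x for x in vecmat(eps, A))


rng = random.Random(0)
samples = [jacobian_integrand_estimate(rng) for _ in range(20000)]
mc_mean = sum(samples) / len(samples)
frobenius_sq = sum(a * a for row in A for a in row)   # exact expectation ||A||_F^2
print(kinetic_integrand([1.0, -1.0]), mc_mean, frobenius_sq)
```

A single probe per point keeps the cost to one vector-Jacobian product, at the price of a noisy but unbiased estimate.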

6 Analysis and Discussion

6.1 Trading off function evaluations for loss

What does the trade-off between accuracy and speed look like? Ideally, we could greatly reduce solver time without substantially reducing model performance. Indeed, this is demonstrated in all three settings we explored. Figure 5 shows that, generally, model performance starts getting substantially worse only after a 50% reduction in NFE when controlling $\mathcal{R}_2$.

Figure 5: Tuning the regularization of $\mathcal{R}_2$ trades off between training loss and solver speed in three different applications of neural ODEs: (a) MNIST classification, (b) PhysioNet time-series, (c) MINIBOONE density estimation. Horizontal axes show average number of function evaluations, and vertical axes show unregularized training loss, both at the end of training.

6.2 Order of regularization vs. order of solver

Which order of total derivatives should we regularize for a particular solver? As mentioned earlier, we conjecture that the best choice would be to match the order of the solver being used. Regularizing too low an order might needlessly constrain the dynamics and make it harder to fit the data, while regularizing too high an order might leave the dynamics difficult to solve for a lower-order solver. However, we also expect that optimizing higher-order derivatives might be challenging, since these higher derivatives can change quickly even for small changes to the dynamics parameters.

Figures 6 and 7 investigate this question on the task of MNIST classification. Figure 6 compares the effectiveness of regularizing different orders when using a solver of a particular order. For a 2nd order solver, regularizing $K=2$ produces a strictly better trade-off between performance and speed, as expected. For higher-order solvers, including ones with adaptive order, we found that regularizing orders above $K=3$ gave little benefit.

Figure 6: Comparing the tradeoff between speed and performance when regularizing different orders. (a) Order 2 solver: regularizing the 2nd total derivative gives the best tradeoff. (b) Order 3 solver: regularizing the 3rd total derivative gives the best tradeoff, but the difference is small. (c) Order 5 solver: results are mixed. (d) Adaptive-order solver: the difference is again small, but regularizing higher orders works slightly better.

Figure 7 investigates the relationship between $\mathcal{R}_K$ and the quantity it is meant to be a surrogate for: NFE. We observe a clear monotonic relationship between the two, for all orders of solver and regularization.

Figure 7: For all orders, $\mathcal{R}_K$ varies monotonically with NFE. Panels: (a) order 2 solver, (b) order 3 solver, (c) order 5 solver. For each order of solver, the model with the lowest NFE was achieved by regularizing the same order.

6.3 Do we reduce training time?

Our approach produces models that are fastest to evaluate at test time. However, when we train with adaptive solvers we do not improve overall training time, due to the additional expense of computing our regularizer. Training with a fixed-grid solver is faster, but can be unstable if dynamics are unregularized. Both the regularization of Finlay et al. (2020) and ours allow us to use fixed-grid solvers and reduce training time. However, ours is 2.4× slower than that of Finlay et al. (2020) for FFJORD, because their regularization reuses terms already computed in the FFJORD training objective. For objectives where these cannot be reused, like MNIST classification, our method is 1.7× slower, but achieves better test-time NFE.

Table 2: Density estimation on MNIST using FFJORD. Steps and Hours describe training; the remaining columns are evaluated using adaptive solvers. Adaptive training solvers are indicated by ∞ steps: with them, our approach is the slowest to train, but requires the fewest NFE once trained. Among fixed-step solvers using the same number of steps, our approach achieves lower bits/dim and lower or comparable NFE. Fixed-step solvers that diverged due to instability are indicated by NaN bits/dim.

| Method | Steps | Hours | Bits/Dim | NFE | $\mathcal{R}_2$ | $\mathcal{B}$ | $\mathcal{K}$ |
|---|---|---|---|---|---|---|---|
| Unregularized | 8 | - | NaN | - | - | - | - |
| Unregularized | ∞ | 35.8 | 1.033 | 149 | 3596 | 4.76 | 73.6 |
| RNODE (Finlay et al., 2020) | 5 | - | NaN | - | - | - | - |
| RNODE (Finlay et al., 2020) | 6 | 8.4 | 1.069 | 122 | 157.8 | 1.82 | 35.0 |
| RNODE (Finlay et al., 2020) | 8 | 11.1 | 1.048 | 97 | 39.3 | 1.85 | 34.8 |
| RNODE (Finlay et al., 2020) | ∞ | 22.9 | 1.049 | 104 | 46.6 | 1.85 | 34.7 |
| TayNODE (ours) | 5 | 20.3 | 1.077 | 98 | 31.3 | 2.89 | 36.5 |
| TayNODE (ours) | 6 | 20.4 | 1.057 | 105 | 31.1 | 2.91 | 36.5 |
| TayNODE (ours) | 8 | 27.1 | 1.046 | 98 | 26.0 | 2.53 | 36.3 |
| TayNODE (ours) | ∞ | 54.7 | 1.039 | 92 | 22.9 | 2.41 | 36.2 |

6.4 Are we making the solver overconfident?

Because we optimize dynamics in a way specifically designed to make the solver take longer steps, we might fear that we are “adversarially attacking” our solver, making it overconfident in its ability to extrapolate. Figure 8(a) shows that this is not the case for MNIST classification.

Figure 8: (a) Solver tolerance vs. error: the actual solver error is about equally well-calibrated for regularized dynamics as for random dynamics, indicating that regularization does not make the solver overconfident. (b) NFE overfitting: there is negligible overfitting of solver speed. (c) Statistical regularization: speed regularization does not usefully improve generalization. For large $\lambda$, our method reduces overfitting, but increases overall test error due to under-fitting.

6.5 Does speed regularization overfit?

Finlay et al. (2020) motivated one of their regularization terms by the possibility of overfitting: having fast dynamics only for the examples in the training set, but still slow dynamics on the test set. However, they did not check whether overfitting was occurring. In fig. 8(b) we confirm, on MNIST classification, that our regularized dynamics have nearly identical average solve times on the training set and on a held-out test set.

7 Related Work

Grathwohl et al. (2019) mention attempting to use weight decay and spectral normalization to reduce NFE. Finlay et al. (2020), among other contributions, regularized trajectories of continuous normalizing flows and introduced the use of fixed-step solvers for stable and faster training. The use of fixed-step solvers is also explored in Onken & Ruthotto (2020). Onken et al. (2020) also regularized the trajectories of continuous normalizing flows, among other contributions. Massaroli et al. (2020b) introduce new formulations of neural differential equations and investigate regularizing these models as applied to continuous normalizing flows in Massaroli et al. (2020a).

Poli et al. (2020) introduce solvers parameterized by neural networks for faster solving of neural differential equations. Kidger et al. (2020a) exploit the structure of the adjoint equations to construct a solver requiring fewer NFE, speeding up backpropagation through neural differential equations.

Morrill et al. (2020) introduce a method to improve the speed of neural controlled differential equations (Kidger et al., 2020b), which are designed for irregularly sampled time series.

Simard et al. (1991) regularized the dynamics of discrete-time recurrent neural networks to improve their stability, by constraining the norm of the Jacobian of the dynamics function in the direction of its largest eigenvalue. However, this approach has an $\mathcal{O}(D^3)$ time cost. De Brouwer et al. (2019) introduced a parameterization of neural ODEs analogous to instantaneous Gated Recurrent Unit (GRU) recurrent neural network architectures in order to stabilize training dynamics. Dupont et al. (2019) provided theoretical arguments that adding extra dimensions to the state of a neural ODE should make training easier, and showed that this helped reduce NFE during training.

Chang et al. (2017) noted the connection between residual networks and ODEs, and took advantage of this connection to gradually make resnets deeper during training, in order to save time. One can view the increase in NFE with neural ODEs as an automatic, but uncontrolled, version of their method. Their results suggest we might benefit from introducing a speed regularization schedule that gradually tapers off during training.

Novak et al. (2018); Drucker & LeCun (1992) regularized the gradients of neural networks to improve generalization.

We speculate on the application of our regularization in eq. 1 for other purposes, including adversarial robustness (Yang et al., 2020; Hanshu et al., 2019) and function approximation with Gaussian processes (Dutra et al., 2014; van der Vaart et al., 2008).

8 Scope

The initial speedups obtained in this paper are not yet enough to make neural ODEs competitive with standard fixed-depth architectures in terms of speed for standard supervised learning. However, there are many applications where continuous-depth architectures provide a unique advantage. Besides density models such as FFJORD and time series models, continuous-depth architectures have been applied in solving mean-field games (Ruthotto et al., 2019), image segmentation (Pinckaers & Litjens, 2019), image super-resolution (Scao, 2020), and molecular simulations (Wang et al., 2020). These applications, which already use continuous-time models, could benefit from the speed regularization proposed in this paper.

While we investigated only ODEs in this paper, this approach could presumably be extended straightforwardly to neural stochastic differential equations fit by adaptive solvers (Li et al., 2020) and other flavors of parametric differential equations fit by gradient descent (Rackauckas et al., 2019).

9 Limitations

Hyperparameters

The hyperparameter $\lambda$ needs to be chosen to balance speed and training loss. On the other hand, neural ODEs don’t require choosing the outer number of layers, which needs to be chosen separately for each stack of layers in standard architectures.

One also needs to choose solver order and tolerances, and these can substantially affect solver speed. We did not investigate loosening tolerances, or modifying other parameters of the solver. The default tolerance of 1.4e-8 for both atol and rtol behaved well in all our experiments.

One also needs to choose $K$. Higher $K$ seems to generally work better, but is slower per step at training time. In principle, if one can express their utility explicitly in terms of training loss and NFE, it may be possible to tune $\lambda$ automatically during training using the predictable relationship between $\mathcal{R}_K$ and NFE shown in fig. 7.

Slower overall training

Although speed regularization reduces the overall NFE during training, it makes each step more expensive. In our density estimation experiments (table 2), the overall effect was about 70% slower training, compared to no regularization, when using adaptive solvers. However, test-time evaluation is much faster, since there is no slowdown per step.

10 Conclusions

This paper is an initial attempt at controlling the integration time of differential equations by regularizing their dynamics. This is an almost unexplored problem, and there are almost certainly better quantities to optimize than the ones examined in this paper.

Based on these initial experiments, we propose three practical takeaways:

  1. Across all tasks, tuning the regularization usually gave at least a 2x speedup without substantially hurting model performance.

  2. Overall training time with speed regularization is in general about 30% to 50% slower with adaptive solvers.

  3. For standard solvers, regularizing orders higher than $\mathcal{R}_2$ or $\mathcal{R}_3$ provided little additional benefit.

Future work

It may be possible to adapt solver architectures to take advantage of flexibility in choosing the dynamics. Standard solver design has focused on robustly and accurately solving a given set of differential equations. However, in a learning setting, we could consider simply rejecting some kinds of dynamics as being too difficult to solve, analogous to other kinds of constraints we put on models to encourage statistical regularization.

Acknowledgements

We thank Dougal Maclaurin, Andreas Griewank, Barak Pearlmutter, Ken Jackson, Chris Finlay, James Saunderson, James Bradbury, Ricky T.Q. Chen, Will Grathwohl, Chris Rackauckas, David Sanders, and Lyndon White for feedback and helpful discussions. Resources used in preparing this research were provided, in part, by the Province of Ontario, the Government of Canada through CIFAR, NSERC, and companies sponsoring the Vector Institute.

Broader Impact

We expect the main impact from this work, if any, would be through a potential improvement of the fundamental modeling tools of regression, classification, time series models, and density estimation. Thus the impact of this work is not distinct from that of improved machine learning tools in general. While machine learning tools present both benefits and unintended consequences, we avoid speculating further.

References

  • Bradbury et al. (2018) Bradbury, J., Frostig, R., Hawkins, P., Johnson, M. J., Leary, C., Maclaurin, D., and Wanderman-Milne, S. JAX: composable transformations of Python+NumPy programs, 2018. URL http://github.com/google/jax.
  • Chang et al. (2017) Chang, B., Meng, L., Haber, E., Tung, F., and Begert, D. Multi-level residual networks from dynamical systems view. arXiv preprint arXiv:1710.10348, 2017.
  • Chen et al. (2018) Chen, T. Q., Rubanova, Y., Bettencourt, J., and Duvenaud, D. K. Neural ordinary differential equations. In Advances in neural information processing systems, pp. 6571–6583, 2018.
  • De Brouwer et al. (2019) De Brouwer, E., Simm, J., Arany, A., and Moreau, Y. GRU-ODE-Bayes: Continuous modeling of sporadically-observed time series. In Advances in Neural Information Processing Systems, pp. 7377–7388, 2019.
  • Dormand & Prince (1980) Dormand, J. R. and Prince, P. J. A family of embedded Runge-Kutta formulae. Journal of computational and applied mathematics, 6(1):19–26, 1980.
  • Drucker & LeCun (1992) Drucker, H. and LeCun, Y. Improving generalization performance using double backpropagation. IEEE Trans. Neural Networks, 3(6):991–997, 1992. doi: 10.1109/72.165600. URL https://doi.org/10.1109/72.165600.
  • Dupont et al. (2019) Dupont, E., Doucet, A., and Teh, Y. W. Augmented neural ODEs. In Advances in Neural Information Processing Systems, pp. 3134–3144, 2019.
  • Dutra et al. (2014) Dutra, D. A., Teixeira, B. O. S., and Aguirre, L. A. Maximum a posteriori state path estimation: Discretization limits and their interpretation. Automatica, 50(5):1360–1368, 2014.
  • Finlay et al. (2020) Finlay, C., Jacobsen, J.-H., Nurbekyan, L., and Oberman, A. M. How to train your neural ODE. arXiv preprint arXiv:2002.02798, 2020.
  • Grathwohl et al. (2019) Grathwohl, W., Chen, R. T. Q., Bettencourt, J., Sutskever, I., and Duvenaud, D. FFJORD: Free-form continuous dynamics for scalable reversible generative models. International Conference on Learning Representations, 2019.
  • Griewank & Walther (2008) Griewank, A. and Walther, A. Evaluating derivatives. 2008.
  • Hairer et al. (1993) Hairer, E., Nørsett, S. P., and Wanner, G. Solving Ordinary Differential Equations I: Nonstiff Problems, volume 8 of Springer Series in Computational Mathematics. Springer, 1993. doi: 10.1007/978-3-540-78862-1.
  • Hanshu et al. (2019) Hanshu, Y., Jiawei, D., Vincent, T., and Jiashi, F. On robustness of neural ordinary differential equations. In International Conference on Learning Representations, 2019.
  • Kidger et al. (2020a) Kidger, P., Chen, R. T., and Lyons, T. “Hey, that’s not an ODE”: Faster ODE adjoints with 12 lines of code. arXiv preprint arXiv:2009.09457, 2020a.
  • Kidger et al. (2020b) Kidger, P., Morrill, J., Foster, J., and Lyons, T. Neural controlled differential equations for irregular time series. arXiv preprint arXiv:2005.08926, 2020b.
  • Kutta (1901) Kutta, W. Beitrag zur näherungsweisen Integration totaler Differentialgleichungen. Zeitschrift für Mathematik und Physik, 46:435–453, 1901.
  • LeCun et al. (2010) LeCun, Y., Cortes, C., and Burges, C. MNIST handwritten digit database. AT&T Labs [Online]. Available: http://yann.lecun.com/exdb/mnist, 2, 2010.
  • Li et al. (2020) Li, X., Chen, R. T. Q., Wong, T.-K. L., and Duvenaud, D. Scalable gradients for stochastic differential equations. In Artificial Intelligence and Statistics, 2020.
  • Massaroli et al. (2020a) Massaroli, S., Poli, M., Bin, M., Park, J., Yamashita, A., and Asama, H. Stable neural flows. arXiv preprint arXiv:2003.08063, 2020a.
  • Massaroli et al. (2020b) Massaroli, S., Poli, M., Park, J., Yamashita, A., and Asama, H. Dissecting neural ODEs. arXiv preprint arXiv:2002.08071, 2020b.
  • Morrill et al. (2020) Morrill, J., Kidger, P., Salvi, C., Foster, J., and Lyons, T. Neural CDEs for long time series via the log-ODE method. arXiv preprint arXiv:2009.08295, 2020.
  • Novak et al. (2018) Novak, R., Bahri, Y., Abolafia, D. A., Pennington, J., and Sohl-Dickstein, J. Sensitivity and generalization in neural networks: an empirical study. In 6th International Conference on Learning Representations, ICLR 2018, Vancouver, BC, Canada, April 30 - May 3, 2018, Conference Track Proceedings. OpenReview.net, 2018. URL https://openreview.net/forum?id=HJC2SzZCW.
  • Onken & Ruthotto (2020) Onken, D. and Ruthotto, L. Discretize-optimize vs. optimize-discretize for time-series regression and continuous normalizing flows. arXiv preprint arXiv:2005.13420, 2020.
  • Onken et al. (2020) Onken, D., Fung, S. W., Li, X., and Ruthotto, L. OT-Flow: Fast and accurate continuous normalizing flows via optimal transport. arXiv preprint arXiv:2006.00104, 2020.
  • Papamakarios et al. (2017) Papamakarios, G., Pavlakou, T., and Murray, I. Masked autoregressive flow for density estimation. Advances in Neural Information Processing Systems, 2017.
  • Pinckaers & Litjens (2019) Pinckaers, H. and Litjens, G. Neural ordinary differential equations for semantic segmentation of individual colon glands. arXiv preprint arXiv:1910.10470, 2019.
  • Poli et al. (2020) Poli, M., Massaroli, S., Yamashita, A., Asama, H., and Park, J. Hypersolvers: Toward fast continuous-depth models. arXiv preprint arXiv:2007.09601, 2020.
  • Rackauckas et al. (2019) Rackauckas, C., Innes, M., Ma, Y., Bettencourt, J., White, L., and Dixit, V. DiffEqFlux.jl: A Julia library for neural differential equations. arXiv preprint arXiv:1902.02376, 2019.
  • Rubanova et al. (2019) Rubanova, Y., Chen, T. Q., and Duvenaud, D. K. Latent ordinary differential equations for irregularly-sampled time series. In Advances in Neural Information Processing Systems, pp. 5321–5331, 2019.
  • Runge (1895) Runge, C. Über die numerische Auflösung von Differentialgleichungen. Mathematische Annalen, 46:167–178, 1895.
  • Ruthotto & Haber (2018) Ruthotto, L. and Haber, E. Deep neural networks motivated by partial differential equations. Journal of Mathematical Imaging and Vision, pp. 1–13, 2018.
  • Ruthotto et al. (2019) Ruthotto, L., Osher, S. J., Li, W., Nurbekyan, L., and Fung, S. W. A machine learning framework for solving high-dimensional mean field game and mean field control problems. CoRR, abs/1912.01825, 2019. URL http://arxiv.org/abs/1912.01825.
  • Scao (2020) Scao, T. L. Neural differential equations for single image super-resolution. arXiv preprint arXiv:2005.00865, 2020.
  • Shampine (1986) Shampine, L. F. Some practical Runge-Kutta formulas. Mathematics of Computation, 46(173):135–150, 1986. ISSN 00255718, 10886842. URL http://www.jstor.org/stable/2008219.
  • Silva et al. (2012) Silva, I., Moody, G., Scott, D. J., Celi, L. A., and Mark, R. G. Predicting in-hospital mortality of ICU patients: The PhysioNet/Computing in Cardiology Challenge 2012. In 2012 Computing in Cardiology, pp. 245–248, 2012.
  • Simard et al. (1991) Simard, P., Raysz, J. P., and Victorri, B. Shaping the state space landscape in recurrent networks. In Advances in Neural Information Processing Systems, pp. 105–112, 1991.
  • van der Vaart et al. (2008) van der Vaart, A. W., van Zanten, J. H., et al. Reproducing kernel Hilbert spaces of Gaussian priors. In Pushing the limits of contemporary statistics: contributions in honor of Jayanta K. Ghosh, pp. 200–222. Institute of Mathematical Statistics, 2008.
  • Wang et al. (2020) Wang, W., Axelrod, S., and Gómez-Bombarelli, R. Differentiable molecular simulations for control and learning. arXiv preprint arXiv:2003.00868, 2020.
  • Yang et al. (2020) Yang, Z., Liu, Y., Bao, C., and Shi, Z. Interpolation between residual and non-residual networks. arXiv preprint arXiv:2006.05749, 2020.

Appendix A Taylor-mode Automatic Differentiation

A.1 Taylor Polynomials

To clarify the relationship between the presentation in Chapter 13 of Griewank & Walther (2008) and our results, we distinguish between Taylor coefficients and derivative coefficients (the latter also known, unhelpfully, as tensor coefficients).

For a sufficiently smooth vector-valued function $f:\mathbb{R}^{n}\rightarrow\mathbb{R}^{m}$ and the polynomial

$$x(t)=x_{[0]}+x_{[1]}t+x_{[2]}t^{2}+x_{[3]}t^{3}+\cdots+x_{[d]}t^{d}\in\mathbb{R}^{n} \qquad (5)$$

we are interested in the $d$-truncated Taylor expansion

$$\begin{aligned}
y(t) &= f(x(t))+O(t^{d+1}) &&\text{(6)}\\
&\equiv y_{[0]}+y_{[1]}t+y_{[2]}t^{2}+y_{[3]}t^{3}+\cdots+y_{[d]}t^{d}\in\mathbb{R}^{m} &&\text{(7)}
\end{aligned}$$

where $y_{[i]}=\frac{1}{i!}y_{i}$ denotes the Taylor coefficient, i.e. the normalized derivative coefficient $y_{i}$.

The Taylor coefficients $y_{[j]}$ of the expansion are smooth functions of the coefficients $x_{[i]}$ with $i\leq j$:

$$\begin{aligned}
y_{[0]} &= y_{[0]}(x_{[0]}) &&= f(x_{[0]}) &&\text{(8)}\\
y_{[1]} &= y_{[1]}(x_{[0]},x_{[1]}) &&= f'(x_{[0]})x_{[1]} &&\text{(9)}\\
y_{[2]} &= y_{[2]}(x_{[0]},x_{[1]},x_{[2]}) &&= f'(x_{[0]})x_{[2]}+\tfrac{1}{2}f''(x_{[0]})x_{[1]}x_{[1]} &&\text{(10)}\\
y_{[3]} &= y_{[3]}(x_{[0]},x_{[1]},x_{[2]},x_{[3]}) &&= f'(x_{[0]})x_{[3]}+f''(x_{[0]})x_{[1]}x_{[2]}+\tfrac{1}{6}f'''(x_{[0]})x_{[1]}x_{[1]}x_{[1]} &&\text{(11)}\\
&\;\;\vdots
\end{aligned}$$
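A quick way to check eqs. 8 to 11 is to pick an $f$ for which truncated polynomial arithmetic gives the Taylor coefficients of $f(x(t))$ exactly. The pure-Python sketch below (the helper names are hypothetical) uses $f(x)=x^3$, so $f'(x)=3x^2$, $f''(x)=6x$ and $f'''(x)=6$, and compares direct composition of truncated polynomials with the closed-form expressions.

```python
def poly_mul(a, b, d):
    """Multiply two coefficient lists and drop every power above t**d."""
    out = [0.0] * (d + 1)
    for i, ai in enumerate(a[: d + 1]):
        for j, bj in enumerate(b[: d + 1 - i]):
            out[i + j] += ai * bj
    return out


def cube_taylor(x):
    """Exact Taylor coefficients of f(x(t)) = x(t)**3 via truncated polynomial products."""
    d = len(x) - 1
    return poly_mul(poly_mul(x, x, d), x, d)


def cube_formula(x):
    """Eqs. 8 to 11 specialised to f(x) = x**3, so f' = 3x**2, f'' = 6x, f''' = 6."""
    x0, x1, x2, x3 = x
    f, df, ddf, dddf = x0**3, 3 * x0**2, 6 * x0, 6.0
    return [
        f,                                                  # eq. 8
        df * x1,                                            # eq. 9
        df * x2 + 0.5 * ddf * x1 * x1,                      # eq. 10
        df * x3 + ddf * x1 * x2 + dddf / 6 * x1 * x1 * x1,  # eq. 11
    ]


x = [0.5, 2.0, -1.0, 3.0]  # Taylor coefficients x_[0], ..., x_[3]
print(cube_taylor(x))   # [0.125, 1.5, 5.25, 4.25]
print(cube_formula(x))  # [0.125, 1.5, 5.25, 4.25]
```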

These expressions, as given in Griewank & Walther (2008), are written in terms of the normalized Taylor coefficients, which obscures their direct relationship with the derivatives. We now make this relationship explicit.

Consider the polynomial of eq. 5 with its Taylor coefficients expanded so that their normalization is clear. Further, we use suggestive notation in which these coefficients correspond to the higher derivatives of $x$ with respect to $t$, making $x(t)$ a Taylor polynomial; that is, $x_{[i]}=\frac{1}{i!}x_{i}=\frac{1}{i!}\frac{\mathrm{d}^{i}x}{\mathrm{d}t^{i}}$:

$$\begin{aligned}
x(t) &= x_{0}+x_{1}t+\frac{1}{2!}x_{2}t^{2}+\frac{1}{3!}x_{3}t^{3}+\cdots+\frac{1}{d!}x_{d}t^{d}\in\mathbb{R}^{n} &&\text{(12)}\\
&= x_{0}+\frac{\mathrm{d}x}{\mathrm{d}t}t+\frac{1}{2!}\frac{\mathrm{d}^{2}x}{\mathrm{d}t^{2}}t^{2}+\frac{1}{3!}\frac{\mathrm{d}^{3}x}{\mathrm{d}t^{3}}t^{3}+\cdots+\frac{1}{d!}\frac{\mathrm{d}^{d}x}{\mathrm{d}t^{d}}t^{d}\in\mathbb{R}^{n} &&\text{(13)}
\end{aligned}$$

Again, we are interested in the polynomial of eq. 7, but with the normalization made explicit:

$$y(t)\equiv y_{0}+y_{1}t+\frac{1}{2!}y_{2}t^{2}+\frac{1}{3!}y_{3}t^{3}+\cdots+\frac{1}{d!}y_{d}t^{d}\in\mathbb{R}^{m} \qquad (15)$$

Now we can expand the expressions for the Taylor coefficients $y_{[i]}$ into expressions for the derivative coefficients $y_{i}=i!\,y_{[i]}$. The derivative coefficients $y_{j}$ are smooth functions of the coefficients $x_{i}$ with $i\leq j$:

$$\begin{aligned}
y_{0} &= y_{0}(x_{0}) = y_{[0]}(x_{0})\\
&= f(x_{0}) &&\text{(16)}\\
y_{1} &= y_{1}(x_{0},x_{1}) = y_{[1]}(x_{0},x_{1})\\
&= f'(x_{0})x_{1}\\
&= f'(x_{0})\frac{\mathrm{d}x}{\mathrm{d}t} &&\text{(17)}\\
y_{2} &= y_{2}(x_{0},x_{1},x_{2}) = 2!\left(y_{[2]}\left(x_{0},x_{1},\tfrac{1}{2!}x_{2}\right)\right)\\
&= 2!\left(f'(x_{0})\tfrac{1}{2!}x_{2}+\tfrac{1}{2}f''(x_{0})x_{1}x_{1}\right)\\
&= f'(x_{0})x_{2}+f''(x_{0})x_{1}x_{1}\\
&= f'(x_{0})\frac{\mathrm{d}^{2}x}{\mathrm{d}t^{2}}+f''(x_{0})\left(\frac{\mathrm{d}x}{\mathrm{d}t}\right)^{2} &&\text{(18)}\\
&= \frac{\mathrm{d}^{2}}{\mathrm{d}t^{2}}f(x(t)) &&\text{(19)}\\
y_{3} &= y_{3}(x_{0},x_{1},x_{2},x_{3}) = 3!\left(y_{[3]}\left(x_{0},x_{1},\tfrac{1}{2!}x_{2},\tfrac{1}{3!}x_{3}\right)\right)\\
&= 3!\left(f'(x_{0})\tfrac{1}{3!}x_{3}+f''(x_{0})x_{1}\tfrac{1}{2!}x_{2}+\tfrac{1}{6}f'''(x_{0})x_{1}x_{1}x_{1}\right)\\
&= f'(x_{0})x_{3}+3f''(x_{0})x_{1}x_{2}+f'''(x_{0})x_{1}x_{1}x_{1}\\
&= f'(x_{0})\frac{\mathrm{d}^{3}x}{\mathrm{d}t^{3}}+3f''(x_{0})\frac{\mathrm{d}x}{\mathrm{d}t}\frac{\mathrm{d}^{2}x}{\mathrm{d}t^{2}}+f'''(x_{0})\left(\frac{\mathrm{d}x}{\mathrm{d}t}\right)^{3} &&\text{(20)}\\
&= \frac{\mathrm{d}^{3}}{\mathrm{d}t^{3}}f(x(t)) &&\text{(21)}\\
&\;\;\vdots
\end{aligned}$$

Therefore, eqs. 16, 17, 19 and 21 show that the derivative coefficients $y_{i}$ are exactly the $i$th-order derivatives of the composition $f(x(t))$ with respect to $t$. The key insight of this exercise is that writing the derivative coefficients explicitly reveals that the expressions in eqs. 16, 17, 18 and 20 reuse quantities already computed for the lower-order terms.
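To illustrate this reuse concretely, here is a small pure-Python sketch (ours, not code from the paper or from JAX). For $y=\exp(x)$, differentiating $\dot y = y\,\dot x$ gives the recurrence $y_{[k]}=\frac{1}{k}\sum_{j=1}^{k} j\,x_{[j]}\,y_{[k-j]}$ on Taylor coefficients, so each new output coefficient is assembled from outputs already computed. The helpers convert between Taylor and derivative coefficients with $y_k=k!\,y_{[k]}$ and check the third derivative coefficient against eq. 20 with $f=f'=f''=f'''=\exp$; the input values are arbitrary.

```python
import math


def exp_taylor(x):
    """Taylor coefficients of exp(x(t)) from Taylor coefficients x = [x_[0], ..., x_[d]].

    Uses y_[k] = (1/k) * sum_{j=1..k} j * x_[j] * y_[k-j], which follows from y' = y x'.
    Each y_[k] reuses the lower-order outputs y_[0], ..., y_[k-1].
    """
    y = [math.exp(x[0])]
    for k in range(1, len(x)):
        y.append(sum(j * x[j] * y[k - j] for j in range(1, k + 1)) / k)
    return y


def to_derivative(taylor):
    """Taylor coefficients to derivative coefficients: y_k = k! * y_[k]."""
    return [math.factorial(k) * c for k, c in enumerate(taylor)]


def to_taylor(derivs):
    """Derivative coefficients to Taylor coefficients: y_[k] = y_k / k!."""
    return [c / math.factorial(k) for k, c in enumerate(derivs)]


# Derivative coefficients x_0, ..., x_3 of some input path x(t).
x0, x1, x2, x3 = 0.3, 1.2, -0.7, 2.0
y = to_derivative(exp_taylor(to_taylor([x0, x1, x2, x3])))

# Eq. 20 with f = f' = f'' = f''' = exp, all evaluated at x0.
e = math.exp(x0)
y3_eq20 = e * x3 + 3 * e * x1 * x2 + e * x1**3
print(y[3], y3_eq20)  # equal up to floating-point rounding
```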

In general, it is useful to view the derivative coefficient $y_{k}$ as a function of all input derivatives up to order $k$:

$$y_{k}=y_{k}(x_{0},\dots,x_{k}). \qquad (22)$$

We provide an API to compute this in JAX; schematically, $y_{k}$ is the $k$-th output of jet:

$$y_{k}=\texttt{jet}(f,x_{0},(x_{1},\dots,x_{k}))[k].$$

A.2 Relationship with Differential Equations

A.2.1 Autonomous Form

We can transform the initial value problem

$$\frac{\mathrm{d}x}{\mathrm{d}t}=f(x(t),t)\quad\text{where}\quad x(t_{0})=x_{0}\in\mathbb{R}^{n} \qquad (23)$$

into an autonomous dynamical system by augmenting the state with the independent variable, which has trivial dynamics (Hairer et al., 1993):

$$\frac{\mathrm{d}}{\mathrm{d}t}\begin{pmatrix}x\\ t\end{pmatrix}=\begin{pmatrix}f(x(t),t)\\ 1\end{pmatrix}\quad\text{where}\quad\begin{pmatrix}x(0)\\ t(0)\end{pmatrix}=\begin{pmatrix}x_{0}\\ t_{0}\end{pmatrix}\in\mathbb{R}^{n+1} \qquad (24)$$
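The augmentation of eq. 24 amounts to a small wrapper around the dynamics. The sketch below is a pure-Python illustration that stores the augmented state as a flat list whose last entry is the time; the wrapper name and the example dynamics are hypothetical.

```python
import math


def make_autonomous(f):
    """Wrap dx/dt = f(x, t) as dy/dt = g(y), where y = [x_1, ..., x_n, t] (cf. eq. 24)."""
    def g(y):
        *x, t = y                      # split the augmented state into x and t
        return list(f(x, t)) + [1.0]   # the time coordinate has dynamics dt/dt = 1
    return g


def f(x, t):
    """Hypothetical non-autonomous dynamics in R^2: dx/dt = -x + sin(t)."""
    return [-xi + math.sin(t) for xi in x]


g = make_autonomous(f)
print(g([1.0, 2.0, 0.0]))  # [-1.0, -2.0, 1.0]
```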

We do this for notational convenience; it also makes clear that derivatives with respect to $t$ are meant in the “total” sense. This alleviates the potential ambiguity of $\frac{\partial}{\partial t}f(x(t),t)$, which could mean either the derivative with respect to the second argument or the derivative through $x(t)$ given by the chain rule, $\frac{\partial f}{\partial x}\frac{\partial x}{\partial t}$.
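For reference, the total derivative meant here combines both contributions through the chain rule, written in the notation of eq. 23:

```latex
\frac{\mathrm{d}}{\mathrm{d}t}f(x(t),t)
  = \frac{\partial f}{\partial x}(x(t),t)\,\frac{\mathrm{d}x}{\mathrm{d}t}
  + \frac{\partial f}{\partial t}(x(t),t)
  = \frac{\partial f}{\partial x}(x(t),t)\,f(x(t),t)
  + \frac{\partial f}{\partial t}(x(t),t)
```

In the autonomous form of eq. 24, both terms arise from a single Jacobian-vector product: the Jacobian of the augmented dynamics with respect to the augmented state $(x,t)$, applied to the augmented velocity $(f,1)$.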

A.2.2 Taylor Coefficients for ODE Solution with jet

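Recall that jet gives us the derivative coefficients $y_{i}$ as a function of $f$ and the coefficients $x_{j}$ with $j\leq i$. Because $x$ solves $\mathrm{d}x/\mathrm{d}t=f(x)$, its $(k+1)$-th derivative equals the $k$-th derivative of $f(x(t))$, that is, $x_{k+1}=y_{k}$. We can therefore use jet and this relationship to recursively compute the coefficients of the solution polynomial.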

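Algorithm 1 Taylor Coefficients for ODE Solution by Recursive Jet
from jax.experimental.jet import jet

def sol_recursive(f, x_0, K):
    # Have: x_0 (an array) and autonomous dynamics f, with dx/dt = f(x)
    # Want: derivative coefficients x_1, ..., x_K of the solution at x_0 (K >= 1)
    xs = [f(x_0)]  # x_1 = y_0 = f(x_0)
    for k in range(1, K):
        # (y_0, (y_1, ..., y_k)) = jet(f, x_0, (x_1, ..., x_k))
        _, ys = jet(f, (x_0,), (tuple(xs),))
        xs.append(ys[-1])  # x_{k+1} = y_k
    return x_0, tuple(xs)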
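Since Algorithm 1 relies on JAX, here is a dependency-free sketch of the same recursion for one hand-picked scalar ODE, $\mathrm{d}x/\mathrm{d}t=x^2$, whose solution $x_0/(1-x_0t)$ has derivative coefficients $x_k=k!\,x_0^{k+1}$ that we can check against. The helper `square_taylor` (a hypothetical name) stands in for jet for this particular $f$; working in Taylor coefficients, the relation $x_{k+1}=y_k$ becomes $x_{[k+1]}=y_{[k]}/(k+1)$.

```python
import math


def square_taylor(x):
    """Taylor coefficients of f(x(t)) = x(t)**2, truncated to len(x) terms."""
    n = len(x)
    return [sum(x[i] * x[k - i] for i in range(k + 1)) for k in range(n)]


def ode_taylor_coefficients(x0, K):
    """Derivative coefficients x_0, ..., x_K of the solution of dx/dt = x**2.

    Mirrors Algorithm 1: the known coefficients x_[0..k] determine y_[k], the k-th
    Taylor coefficient of f(x(t)), and then x_[k+1] = y_[k] / (k + 1).
    """
    taylor = [x0]  # Taylor coefficients x_[0], x_[1], ...
    for k in range(K):
        y = square_taylor(taylor)  # y_[0..k]; only y_[k] is new at this step
        taylor.append(y[k] / (k + 1))
    return [math.factorial(k) * c for k, c in enumerate(taylor)]


x0 = 0.5
print(ode_taylor_coefficients(x0, 4))                      # [0.5, 0.25, 0.25, 0.375, 0.75]
print([math.factorial(k) * x0 ** (k + 1) for k in range(5)])  # [0.5, 0.25, 0.25, 0.375, 0.75]
```

Recomputing the whole square at every step keeps the sketch short; a real Taylor-mode implementation would compute only the new coefficient.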

A.3 Regularizing Taylor Terms

Computing the Taylor coefficients of the ODE solution as in algorithm 1 gives a local approximation to the ODE solution. If infinitely many Taylor coefficients could be computed, the resulting series would give the exact solution (for an analytic solution, within its radius of convergence). The order of the final Taylor coefficient, which determines where the polynomial is truncated, gives the order of the approximation.

If the higher-order Taylor coefficients of the solution are large, then truncation results in a local approximation that quickly diverges from the solution. If instead the higher-order Taylor coefficients are small, the local approximation remains close to the solution. This motivates our regularization method. The effect of our regularizer on the Taylor expansion of a solution to a neural ODE can be seen in fig. 9.
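To make this concrete, the sketch below uses the toy ODE $\mathrm{d}x/\mathrm{d}t=x^2$ (our choice, for illustration only), whose solution $x_0/(1-x_0t)$ has Taylor coefficients $x_{[k]}=x_0^{k+1}$. It compares a 3rd-order truncated Taylor polynomial with the exact solution over a fixed step $h$: a larger $x_0$ gives larger higher-order coefficients and a much larger local error.

```python
def truncated_taylor(x0, order, h):
    """Order-truncated Taylor polynomial of x(t) = x0 / (1 - x0 t), evaluated at t = h.

    For dx/dt = x**2 the Taylor coefficients are x_[k] = x0**(k + 1).
    """
    return sum(x0 ** (k + 1) * h**k for k in range(order + 1))


def local_error(x0, order, h):
    """Absolute error of the truncated polynomial against the exact solution at t = h."""
    exact = x0 / (1 - x0 * h)
    return abs(exact - truncated_taylor(x0, order, h))


h, order = 0.2, 3
for x0 in (0.5, 2.0):
    first_dropped = x0 ** (order + 2)  # x_[order + 1], the first omitted coefficient
    print(f"x0={x0}: first dropped coefficient {first_dropped:.3g}, "
          f"local error {local_error(x0, order, h):.2e}")
```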

Figure 9: Left: The dynamics and a trajectory of a neural ODE trained on a toy supervised learning problem. The dynamics are poorly approximated by a 6th-order local Taylor series and require 92 NFE to solve with a 5th-order Runge-Kutta solver. Right: Regularizing the 6th-order derivatives of trajectories gives dynamics that are easier to solve numerically, requiring only 68 NFE.

Appendix B Experimental Details

Experiments were conducted using GPU-based ODE solvers. Training gradients were computed with the adjoint method, which reconstructs the trajectory backwards in time during backpropagation to save memory. As in Finlay et al. (2020), we normalize our regularization term in eq. 1 by the dimension of the vector-valued trajectory $\mathbf{z}(t)$, so that $\lambda$ can be chosen independently of the dimension of the problem.
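A minimal sketch of this normalization, assuming that eq. 1 integrates the squared 2-norm of the $K$-th total derivative $\mathrm{d}^K\mathbf{z}(t)/\mathrm{d}t^K$ over the integration interval. For illustration only, the integral is approximated with the trapezoidal rule from derivatives sampled at given times; the function name and this quadrature are our assumptions, not the implementation used for the experiments.

```python
def normalized_regularizer(times, derivs):
    """Trapezoidal estimate of (1/D) * integral of ||d^K z / dt^K||_2^2 dt.

    times:  increasing sample times t_0 < ... < t_N
    derivs: K-th total derivatives of z at those times, each a length-D list
    """
    dim = len(derivs[0])
    sq = [sum(v * v for v in d) / dim for d in derivs]  # ||.||^2 / D at each time
    return sum(
        0.5 * (sq[i] + sq[i + 1]) * (times[i + 1] - times[i])
        for i in range(len(times) - 1)
    )


# The same constant derivative of all ones gives the same value whatever D is.
print(normalized_regularizer([0.0, 0.5, 1.0], [[1.0] * 4] * 3))    # 1.0
print(normalized_regularizer([0.0, 0.5, 1.0], [[1.0] * 100] * 3))  # 1.0
```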

B.1 Efficient computation of the gradient of regularization term

To optimize our regularized objective, we must compute its gradient. We use the adjoint method as described in Chen et al. (2018) to differentiate through the solution to the ODE. In particular, beyond the standard adjoint computation, the only additional quantity we need in order to optimize our model is the gradient of the regularization term. The adjoint method gives the gradient of the ODE solution as the solution to an augmented ODE.

B.2 Supervised Learning

The dynamics function $f:\mathbb{R}^{d}\times\mathbb{R}\to\mathbb{R}^{d}$ is given by an MLP as follows:

$$\begin{aligned}
z_{1} &= \sigma(x)\\
h_{1} &= W_{1}[z_{1};t]+b_{1}\\
z_{2} &= \sigma(h_{1})\\
y &= W_{2}[z_{2};t]+b_{2}
\end{aligned}$$

where $[\cdot;\cdot]$ denotes concatenation of a scalar onto a column vector. Since $[z_{1};t]\in\mathbb{R}^{d+1}$ and $[z_{2};t]\in\mathbb{R}^{h+1}$, the parameters are $W_{1}\in\mathbb{R}^{h\times(d+1)}$, $b_{1}\in\mathbb{R}^{h}$, $W_{2}\in\mathbb{R}^{d\times(h+1)}$ and $b_{2}\in\mathbb{R}^{d}$. Here we use 100 hidden units, i.e. $h=100$, and $d=784$, the dimension of an MNIST image.
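A minimal pure-Python sketch of this MLP, with toy sizes instead of $d=784$ and $h=100$. The activation $\sigma$ is not specified in the text, so the logistic sigmoid used here is an assumption, as is the Gaussian initialization. Because $t$ is appended before each affine map, the two weight matrices act on $d+1$ and $h+1$ inputs respectively.

```python
import math
import random


def sigma(v):
    # Assumed activation: the text does not specify sigma, so we use the logistic sigmoid.
    return [1.0 / (1.0 + math.exp(-vi)) for vi in v]


def affine(W, b, v):
    """Compute W v + b for W given as a list of rows."""
    return [sum(wij * vj for wij, vj in zip(row, v)) + bi for row, bi in zip(W, b)]


def dynamics(params, x, t):
    """f(x, t) from the equations above; `v + [t]` implements [v; t]."""
    W1, b1, W2, b2 = params
    z1 = sigma(x)
    h1 = affine(W1, b1, z1 + [t])
    z2 = sigma(h1)
    return affine(W2, b2, z2 + [t])


def init_params(d, h, seed=0):
    """Hypothetical Gaussian initialization; W1 is h x (d+1) and W2 is d x (h+1)."""
    rng = random.Random(seed)
    W1 = [[rng.gauss(0.0, 0.1) for _ in range(d + 1)] for _ in range(h)]
    W2 = [[rng.gauss(0.0, 0.1) for _ in range(h + 1)] for _ in range(d)]
    return W1, [0.0] * h, W2, [0.0] * d


params = init_params(d=4, h=3)  # toy sizes for illustration
print(dynamics(params, [0.1, -0.2, 0.3, 0.0], t=0.5))
```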

We train with a batch size of 100 for 160 epochs. We use the standard training set of 60,000 images, and the standard test set of 10,000 images as a validation/test set. We optimize our model using SGD with momentum, with $\beta=0.9$. Our learning rate schedule is 1e-1 for the first 60 epochs, 1e-2 until epoch 100, 1e-3 until epoch 140, and 1e-4 for the final 20 epochs.
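The optimization setup above, written out as a small sketch. The learning rate schedule follows the text exactly; the momentum convention ($v\leftarrow\beta v+g$, $w\leftarrow w-\eta v$) is one common choice and an assumption, since the text only fixes $\beta=0.9$.

```python
def learning_rate(epoch):
    """Step schedule from the text, with epochs counted from 0 to 159."""
    if epoch < 60:
        return 1e-1
    if epoch < 100:
        return 1e-2
    if epoch < 140:
        return 1e-3
    return 1e-4


def momentum_step(w, v, grad, lr, beta=0.9):
    """One SGD-with-momentum update in the v <- beta*v + g, w <- w - lr*v convention.

    The text fixes beta = 0.9 but not the convention; this choice is an assumption.
    """
    v = [beta * vi + gi for vi, gi in zip(v, grad)]
    w = [wi - lr * vi for wi, vi in zip(w, v)]
    return w, v


EPOCHS, BATCH_SIZE, TRAIN_SIZE = 160, 100, 60_000
steps_per_epoch = TRAIN_SIZE // BATCH_SIZE
schedule = [learning_rate(e) for e in range(EPOCHS)]
print(steps_per_epoch, schedule.count(1e-4))  # 600 20
```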

B.3 Continuous Generative Modelling of Time-Series

The PhysioNet dataset consists of observations of 41 distinct traits over a time period of 48 hours. We remove the parameters “Age”, “Gender”, “Height”, and “ICUType” as these attributes do not vary in time. We also quantize the measurements for each attribute by the hour, averaging multiple measurements within the same hour. This leaves 49 unique time stamps (the extra time stamp accounts for observations made exactly at the endpoint of the 48-hour observation period). We report all our losses on this quantized data. We performed this rather coarse quantization for computational reasons specific to our particular implementation of this model. The validation split was obtained by randomly selecting 20% of the trajectories from the full dataset, which contains 8000 trajectories in total. Code is included for processing the dataset, and links for downloading the data may be found in the code for Rubanova et al. (2019). All other experimental details may be found in the main body and appendices of Rubanova et al. (2019).
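The hourly quantization described above can be sketched as follows. We assume each observation is assigned to the hour containing it (rounding times down), which is consistent with the remark that hour 48 holds only observations made exactly at the endpoint; the record layout and the example attribute names and values are hypothetical.

```python
import math
from collections import defaultdict

# Time-invariant attributes that are dropped before quantization.
STATIC_ATTRIBUTES = {"Age", "Gender", "Height", "ICUType"}


def quantize_hourly(records):
    """Average the measurements of each attribute that fall within the same hour.

    records: iterable of (time_in_hours, attribute, value), times in [0, 48].
    Returns {(hour, attribute): mean}. Hours range over 0, ..., 48 (49 values);
    hour 48 only receives observations made exactly at the endpoint.
    """
    sums, counts = defaultdict(float), defaultdict(int)
    for t, attribute, value in records:
        if attribute in STATIC_ATTRIBUTES:
            continue
        key = (math.floor(t), attribute)  # assumed: round times down to the hour
        sums[key] += value
        counts[key] += 1
    return {key: sums[key] / counts[key] for key in sums}


records = [(0.2, "HR", 80.0), (0.7, "HR", 90.0), (48.0, "HR", 70.0),
           (0.1, "Age", 54.0), (3.5, "Temp", 37.0)]  # hypothetical observations
print(quantize_hourly(records))  # {(0, 'HR'): 85.0, (48, 'HR'): 70.0, (3, 'Temp'): 37.0}
```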

B.4 Continuous Normalizing Flows

For the model trained on the MINIBOONE tabular dataset from Papamakarios et al. (2017), we used the same architecture as in Table 4 in the appendix of Grathwohl et al. (2019). We chose the number of epochs and a learning rate schedule based on manual tuning on the validation set, in contrast to Grathwohl et al. (2019), who tuned these automatically using early stopping and an automatic heuristic for the learning rate decay based on evaluation on a validation set. In particular, we trained for 500 epochs with a learning rate of 1e-3 for the first 300 epochs, 1e-4 until epoch 425, and 1e-5 for the remaining 75 epochs. The number of epochs and the learning rate schedule were determined by evaluating the model on the validation set every 10 epochs, and decaying the learning rate by a factor of 10 once the loss on the validation set stopped improving for several evaluations, with the goal of matching or improving upon the log-likelihood reported in Grathwohl et al. (2019). The data was obtained as made available by Papamakarios et al. (2017), already processed and split into train/validation/test sets. In particular, the training set has 29556 examples, the validation set has 3284 examples, and the test set has 3648 examples, each with 43 features.
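The manual procedure for choosing the schedule (evaluate every 10 epochs, divide the learning rate by 10 once validation loss stops improving for several evaluations) can be replayed as a simple rule. The `patience` value standing in for "several" is a hypothetical choice, and this is a sketch of the described procedure rather than the code that was used.

```python
def plateau_decay(val_losses, lr0=1e-3, factor=10.0, patience=3):
    """Replay a decay-on-plateau rule over a sequence of validation losses.

    val_losses: losses from evaluations made every 10 epochs.
    patience: hypothetical stand-in for "several evaluations" in the text.
    Returns the learning rate in effect after each evaluation.
    """
    lr, best, stale, lrs = lr0, float("inf"), 0, []
    for loss in val_losses:
        if loss < best:
            best, stale = loss, 0    # improvement: reset the counter
        else:
            stale += 1
            if stale >= patience:    # no improvement for `patience` evaluations
                lr, stale = lr / factor, 0
        lrs.append(lr)
    return lrs


print(plateau_decay([5.0, 4.0, 3.0, 3.1, 3.2, 3.3, 2.9]))
```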

It is important to note that we implemented a single-flow model for the MNIST dataset, while the original comparison in Finlay et al. (2020) was on a multi-flow model. This accounts for the discrepancy between the bits/dim and NFE we report and those reported in Finlay et al. (2020).

All other experimental details are as in Grathwohl et al. (2019).

B.5 Hardware

MNIST Supervised learning, Physionet Time-series, and MNIST FFJORD experiments were trained and evaluated on NVIDIA Tesla P100 GPU. Tabular data FFJORD experiments were evaluated on NVIDIA Tesla P100 GPU but trained on NVIDIA Tesla T4 GPU. All experiments except for MNIST FFJORD were trained with double precision for purposes of reproducibility.

Appendix C Additional Results

C.1 Overfitting of NFE

Figure 10: The difference in NFE is tracked by the variance of NFE.

Figure 10 shows a striking correspondence between the variance of NFE across individual examples (in both the training set (dark red) and the test set (light red)) and the absolute difference in average NFE between the training and test sets. This suggests that any difference in the average NFE between training examples and test examples is explained by noise in the estimate of the true average NFE. It is also interesting that speed regularization does not have a monotonic relationship with the variance of NFE; we speculate that this may relate to a correspondence between the NFE required for a particular example and how difficult that example is for the model to classify correctly.

C.2 Trading off function evaluations with a surrogate loss

In fig. 11 and fig. 12 we confirm that our method provides a suitable trade-off not only on the loss being optimized, but also on the potentially non-differentiable loss that we truly care about. On MNIST, we obtain a similar Pareto curve when plotting classification error instead of cross-entropy loss, and on the time-series modelling task we obtain a similar Pareto curve for MSE loss as for IWAE loss. The Pareto curves are plotted for $\mathcal{R}_{3}$ and $\mathcal{R}_{2}$, respectively.

Figure 11: MNIST Classification
Figure 12: Physionet Time-Series

C.3 Wall-clock Time

We include additional results, with wall-clock training time and training with fixed-grid solvers, in table 3 and table 4.

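Table 3: Classification on MNIST. Steps and Hours refer to training; Loss, NFE, $\mathcal{R}_2$, $\mathcal{B}$ and $\mathcal{K}$ are evaluated using adaptive solvers.
Method | Steps | Hours | Loss | NFE | $\mathcal{R}_2$ | $\mathcal{B}$ | $\mathcal{K}$
No Regularization | 2 | 0.08 | .0239 | 116 | 25.9 | .231 | 7.91
No Regularization | 4 | 0.13 | .0235 | 110 | 21.9 | .234 | 7.66
No Regularization | 8 | 0.23 | .0236 | 110 | 21.3 | .233 | 7.62
No Regularization | ∞ | 1.71 | .0235 | 110 | - | .233 | 7.63
RNODE (Finlay et al., 2020) | 2 | 0.12 | .0238 | 110 | 18.4 | .229 | 7.07
RNODE (Finlay et al., 2020) | 4 | 0.20 | .0238 | 110 | 14.6 | .230 | 6.85
RNODE (Finlay et al., 2020) | 8 | 0.37 | .0238 | 110 | 14.1 | .229 | 6.82
TayNODE (ours) | 2 | 0.19 | .0234 | 104 | 3.2 | .217 | 7.12
TayNODE (ours) | 4 | 0.33 | .0234 | 104 | 2.4 | .218 | 7.06
TayNODE (ours) | 8 | 0.61 | .0234 | 104 | 2.4 | .219 | 7.06
TayNODE (ours) | ∞ | 2.56 | .0234 | 104 | - | .233 | 7.63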
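Table 4: Density Estimation on Tabular Data (MINIBOONE). Steps and Hours refer to training; Loss, NFE, $\mathcal{R}_2$, $\mathcal{B}$ and $\mathcal{K}$ are evaluated using adaptive solvers.
Method | Steps | Hours | Loss | NFE | $\mathcal{R}_2$ | $\mathcal{B}$ | $\mathcal{K}$
No Regularization | 4 | 0.19 | 9.78 | 185 | 17.1 | 4.10 | 1.72
No Regularization | 8 | 0.37 | 9.77 | 184 | 19.0 | 4.10 | 1.77
No Regularization | ∞ | 1.64 | 9.74 | 182 | - | 4.10 | 1.77
RNODE (Finlay et al., 2020) | 4 | 0.19 | 9.77 | 182 | 15.9 | 4.02 | 1.65
RNODE (Finlay et al., 2020) | 8 | 0.38 | 9.76 | 181 | 17.3 | 4.01 | 1.69
RNODE (Finlay et al., 2020) | 16 | 0.73 | 9.77 | 189 | 17.5 | 4.03 | 1.70
TayNODE (ours) | 4 | 0.49 | 9.84 | 177 | 13.1 | 4.00 | 1.57
TayNODE (ours) | 8 | 0.96 | 9.79 | 181 | 13.6 | 3.99 | 1.58
TayNODE (ours) | 16 | 1.90 | 9.77 | 181 | 13.7 | 3.99 | 1.59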

Appendix D Comparison to How to Train Your Neural ODE

The terms from Finlay et al. (2020) are

$$\left\|f(\mathbf{z}(t),t,\theta)\right\|^{2}_{2}$$

and an estimate of

$$\left\|\nabla_{\mathbf{z}}f(\mathbf{z}(t),t,\theta)\right\|^{2}_{F}$$

These are combined with a weighted average and integrated along the solution trajectory.
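The text does not spell out how the Frobenius-norm term is estimated; a standard choice is a Hutchinson-style estimator, which we assume here for illustration. For any random $\boldsymbol{\epsilon}$ with $\mathbb{E}[\boldsymbol{\epsilon}\boldsymbol{\epsilon}^\top]=I$, we have $\mathbb{E}\|J^\top\boldsymbol{\epsilon}\|_2^2=\operatorname{tr}(JJ^\top)=\|J\|_F^2$, so averaging $\|J^\top\boldsymbol{\epsilon}\|_2^2$ over samples estimates $\|\nabla_{\mathbf{z}}f\|_F^2$. In an autodiff framework $\boldsymbol{\epsilon}^\top J$ would be a single vector-Jacobian product; this sketch forms a small explicit matrix instead, and its values are hypothetical.

```python
import random


def matvec_T(J, eps):
    """Compute J^T eps for a matrix J given as a list of rows."""
    return [sum(J[i][j] * eps[i] for i in range(len(J))) for j in range(len(J[0]))]


def frobenius_sq(J):
    """Exact squared Frobenius norm, for comparison."""
    return sum(v * v for row in J for v in row)


def hutchinson_frobenius_sq(J, n_samples=20_000, seed=0):
    """Monte Carlo estimate of ||J||_F^2 as the mean of ||J^T eps||^2, eps ~ N(0, I).

    E[||J^T eps||^2] = E[eps^T J J^T eps] = tr(J J^T) = ||J||_F^2.
    """
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_samples):
        eps = [rng.gauss(0.0, 1.0) for _ in range(len(J))]
        total += sum(v * v for v in matvec_T(J, eps))
    return total / n_samples


J = [[1.0, 2.0], [3.0, 4.0]]
print(frobenius_sq(J), hutchinson_frobenius_sq(J))  # 30.0 and an estimate close to it
```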

These terms are motivated by the expansion

$$\frac{\mathrm{d}^{2}\mathbf{z}(t)}{\mathrm{d}t^{2}}=\nabla_{\mathbf{z}}f(\mathbf{z}(t),t)\,f(\mathbf{z}(t),t)+\frac{\partial f}{\partial t}(\mathbf{z}(t),t)$$

Namely, eq. 3 regularizes the first total derivative of the solution, $f(\mathbf{z}(t),t)$, along the trajectory, and eq. 4 regularizes a stochastic estimate of the Frobenius norm of the spatial derivative, $\nabla_{\mathbf{z}}f(\mathbf{z}(t),t)$, along the solution trajectory.

In contrast, $\mathcal{R}_{2}$ regularizes the norm of the second total derivative directly. In particular, it takes into account the $\frac{\partial f}{\partial t}$ term, that is, the explicit dependence of $f$ on time, while eq. 3 and eq. 4 capture only the implicit dependence on time through $\mathbf{z}(t)$.

Even for an autonomous system, that is, one where $\frac{\partial f}{\partial t}$ is identically 0 and the dynamics $f$ depend on time only implicitly, these terms still differ. Namely, $\mathcal{R}_{2}$ integrates the following along the solution trajectory:

$$\left\|\nabla_{\mathbf{z}}f(\mathbf{z}(t),t,\theta)\,f(\mathbf{z}(t),t,\theta)\right\|^{2}_{2}$$

while Finlay et al. (2020) penalize the norms of the matrix $\nabla_{\mathbf{z}}f(\mathbf{z}(t),t)$ and the vector $f(\mathbf{z}(t),t)$ separately.
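A small numerical illustration of this difference in the autonomous case, with $J=\nabla_{\mathbf{z}}f$ evaluated at a single point. When $f$ lies in the null space of $J$, the $\mathcal{R}_2$ integrand $\|Jf\|_2^2$ vanishes even though both separately penalized quantities are nonzero; in the other direction, $\|Jf\|_2^2\le\|J\|_F^2\,\|f\|_2^2$ always holds. The matrix and vector are hypothetical values chosen for the example.

```python
def matvec(J, v):
    """Matrix-vector product J v for J given as a list of rows."""
    return [sum(a * b for a, b in zip(row, v)) for row in J]


def sq_norm(v):
    return sum(x * x for x in v)


def r2_integrand(J, f):
    """||J f||_2^2, the autonomous-case integrand of R_2 with J = grad_z f."""
    return sq_norm(matvec(J, f))


def rnode_terms(J, f):
    """The two quantities penalized separately: ||f||_2^2 and ||J||_F^2."""
    return sq_norm(f), sum(x * x for row in J for x in row)


J = [[1.0, 0.0], [0.0, 0.0]]  # hypothetical Jacobian at one point
f = [0.0, 1.0]               # hypothetical dynamics value, in the null space of J
print(r2_integrand(J, f), rnode_terms(J, f))  # 0.0 (1.0, 1.0)
```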