
MALI: A memory efficient and reverse accurate integrator for Neural ODEs

Juntang Zhuang; Nicha C. Dvornek; Sekhar Tatikonda; James S. Duncan
{j.zhuang; nicha.dvornek; sekhar.tatikonda; james.duncan} @yale.edu
Yale University, New Haven, CT, USA
Abstract

Neural ordinary differential equations (Neural ODEs) are a new family of deep-learning models with continuous depth. However, the numerical estimation of the gradient in the continuous case is not well solved: existing implementations of the adjoint method suffer from inaccuracy in the reverse-time trajectory, while the naive method and the adaptive checkpoint adjoint method (ACA) have a memory cost that grows with integration time. In this project, based on the asynchronous leapfrog (ALF) solver, we propose the Memory-efficient ALF Integrator (MALI), which has a constant memory cost w.r.t the number of solver steps in integration, similar to the adjoint method, and guarantees accuracy in the reverse-time trajectory (hence accuracy in gradient estimation). We validate MALI in various tasks: on image recognition tasks, to our knowledge, MALI is the first to enable feasible training of a Neural ODE on ImageNet and outperform a well-tuned ResNet, while existing methods fail due to either heavy memory burden or inaccuracy; for time series modeling, MALI significantly outperforms the adjoint method; and for continuous generative models, MALI achieves new state-of-the-art performance. Code is available at https://github.com/juntang-zhuang/TorchDiffEqPack

1 Introduction

Recent research builds the connection between continuous models and neural networks. The theory of dynamical systems has been applied to analyze the properties of neural networks or guide the design of networks (Weinan, 2017; Ruthotto & Haber, 2019; Lu et al., 2018). In these works, a residual block (He et al., 2016) is typically viewed as a one-step Euler discretization of an ODE; instead of directly analyzing the discretized neural network, it might be easier to analyze the ODE.

Another direction is the neural ordinary differential equation (Neural ODE) (Chen et al., 2018), which takes a continuous depth instead of discretized depth. The dynamics of a Neural ODE is typically approximated by numerical integration with adaptive ODE solvers. Neural ODEs have been applied in irregularly sampled time-series (Rubanova et al., 2019), free-form continuous generative models (Grathwohl et al., 2018; Finlay et al., 2020), mean-field games (Ruthotto et al., 2020), stochastic differential equations (Li et al., 2020) and physically informed modeling (Sanchez-Gonzalez et al., 2019; Zhong et al., 2019).

Though the Neural ODE has been widely applied in practice, how to train it is not extensively studied. The naive method directly backpropagates through an ODE solver, but tracking a continuous trajectory requires a huge memory. Chen et al. (2018) proposed to use the adjoint method to determine the gradient in continuous cases, which achieves constant memory cost w.r.t integration time; however, as pointed out by Zhuang et al. (2020), the adjoint method suffers from numerical errors due to the inaccuracy in the reverse-time trajectory. Zhuang et al. (2020) proposed the adaptive checkpoint adjoint (ACA) method to achieve accuracy in gradient estimation at a much smaller memory cost compared to the naive method, yet the memory consumption of ACA still grows linearly with integration time. Due to the non-constant memory cost, neither ACA nor the naive method is suitable for large-scale datasets (e.g. ImageNet) or high-dimensional Neural ODEs (e.g. FFJORD (Grathwohl et al., 2018)).

In this project, we propose the Memory-efficient Asynchronous Leapfrog Integrator (MALI) to achieve the advantages of both the adjoint method and ACA: constant memory cost w.r.t integration time and accuracy in the reverse-time trajectory. MALI is based on the asynchronous leapfrog (ALF) integrator (Mutze, 2013). With the ALF integrator, each numerical step forward in time is reversible. Therefore, with MALI, we delete the trajectory and only keep the end-time state, hence achieving a constant memory cost w.r.t integration time; using the reversibility, we can accurately reconstruct the trajectory from the end-time value, hence achieving accuracy in the gradient. Our contributions are:

  1. We propose a new method (MALI) to solve Neural ODEs, which achieves constant memory cost w.r.t the number of solver steps in integration and accuracy in gradient estimation. We provide theoretical analysis.

  2. We validate our method with extensive experiments: (a) for image classification tasks, MALI enables a Neural ODE to achieve better accuracy than a well-tuned ResNet with the same number of parameters; to our knowledge, MALI is the first method to enable training of Neural ODEs on a large-scale dataset such as ImageNet, while existing methods fail due to either heavy memory burden or inaccuracy. (b) In time-series modeling, MALI achieves comparable or better results than other methods. (c) For generative modeling, a FFJORD model trained with MALI achieves new state-of-the-art results on MNIST and Cifar10.

2 Preliminaries

2.1 Numerical Integration Methods

An ordinary differential equation (ODE) typically takes the form

\frac{\mathrm{d}z(t)}{\mathrm{d}t}=f_{\theta}(t,z(t))\ \ \ \ s.t.\ \ \ \ z(t_{0})=x,\ t\in[t_{0},T],\ \ \ \ Loss=L(z(T),y)    (1)

where z(t) is the hidden state evolving with time, T is the end time, t_0 is the start time (typically 0), x is the initial state. The derivative of z(t) w.r.t t is defined by a function f, and f is defined as a sequence of layers parameterized by \theta. The loss function is L(z(T), y), where y is the target variable. Eq. 1 is called the initial value problem (IVP) because only z(t_0) is specified.

Input initial state x, start time t_0, end time T, error tolerance etol, initial stepsize h.
Initialize z(0) = x, t = t_0
While t < T
        error_est = ∞
        While error_est > etol
               h ← h × DecayFactor
               ẑ, error_est = ψ_h(t, z)
        If error_est < etol
               h ← h × IncreaseFactor
        t ← t + h,  z ← ẑ
Algorithm 1 Numerical Integration

Notations   We summarize the notations following Zhuang et al. (2020).

  • z_{i}(t_{i}) / \overline{z}(\tau_{i}): hidden state in the forward / reverse time trajectory at time t_{i} / \tau_{i}.

  • \psi_{h}(t_{i}, z_{i}): the numerical solution at time t_{i}+h, starting from (t_{i}, z_{i}) with a stepsize h.

  • N_{f}, N_{z}: N_{f} is the number of layers in f in Eq. 1, N_{z} is the dimension of z.

  • N_{t} / N_{r}: number of discretized points (outer iterations in Algo. 1) in forward / reverse integration.

  • m: average number of inner iterations in Algo. 1 to find an acceptable stepsize.

Numerical Integration   The algorithm for general adaptive-stepsize numerical ODE solvers is summarized in Algo. 1 (Wanner & Hairer, 1996). The solver repeatedly advances in time by a step, which is the outer loop in Algo. 1 (blue curve in Fig. 1). For each step, the solver decreases the stepsize until the estimate of the error is lower than the tolerance, which is the inner loop in Algo. 1 (green curve in Fig. 1). For fixed-stepsize solvers, the inner loop is replaced with a single evaluation of \psi_{h}(t, z) using a predefined stepsize h. Different methods typically use different \psi, for example different orders of the Runge-Kutta method (Runge, 1895).
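
A minimal Python sketch of this adaptive-stepsize loop is given below. Here `step_fn` stands for a hypothetical one-step method ψ_h that also returns an error estimate (e.g. from an embedded Runge-Kutta pair); all names are illustrative and not part of any released package, and the accept/increase order is arranged so that t is advanced by the same h that produced ẑ.

```python
import math

def adaptive_integrate(step_fn, z, t0, T, etol, h0, decay=0.5, growth=1.2):
    """Sketch of Algo. 1; step_fn(t, z, h) -> (z_hat, error_est) is assumed."""
    t, h = t0, h0
    while t < T:
        h = min(h, T - t)                  # do not step past the end time
        error_est = math.inf
        while error_est > etol:            # inner loop: shrink h until acceptable
            z_hat, error_est = step_fn(t, z, h)
            if error_est > etol:
                h *= decay
        t, z = t + h, z_hat                # accept the step taken with this h
        h *= growth                        # try a larger stepsize on the next step
    return z
```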

2.2 Analytical form of gradient in continuous case

We first briefly introduce the analytical form of the gradient in the continuous case, then we compare different numerical implementations in the literature to estimate the gradient. The analytical form of the gradient in the continuous case is

\frac{\mathrm{d}L}{\mathrm{d}\theta}=-\int_{T}^{0}a(t)^{\top}\frac{\partial f(z(t),t,\theta)}{\partial\theta}\,\mathrm{d}t    (2)

\frac{\mathrm{d}a(t)}{\mathrm{d}t}+\Big(\frac{\partial f(z(t),t,\theta)}{\partial z(t)}\Big)^{\top}a(t)=0\ \ \forall t\in(0,T),\ \ \ a(T)=\frac{\partial L}{\partial z(T)}    (3)

where a(t) is the “adjoint state”. Detailed proof is given in (Pontryagin, 1962). In the next section we compare different numerical implementations of this analytical form.
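
As a hedged illustration of how Eq. 2 and 3 can be discretized, the sketch below integrates the adjoint ODE with plain Euler steps for the scalar linear case dz/dt = θz with loss L = z(T)^2, where ∂f/∂z = θ and ∂f/∂θ = z; the numbers and step counts are illustrative only.

```python
import numpy as np

theta, z0, T, N = 0.7, 1.5, 2.0, 4000
h = T / N

# forward pass: Euler steps of dz/dt = theta * z to obtain the trajectory z(t)
z = np.empty(N + 1)
z[0] = z0
for i in range(N):
    z[i + 1] = z[i] + h * theta * z[i]

# backward pass: a(T) = dL/dz(T) = 2 z(T);  da/dt = -(df/dz) a = -theta * a;
# dL/dtheta = -int_T^0 a * (df/dtheta) dt = int_0^T a(t) z(t) dt
a = 2.0 * z[-1]
dL_dtheta = 0.0
for i in reversed(range(N)):
    dL_dtheta += h * a * z[i + 1]          # accumulate the integral in Eq. 2
    a = a + h * theta * a                  # reverse-time Euler step of da/dt = -theta*a

print(dL_dtheta, 2.0 * T * z0 ** 2 * np.exp(2.0 * theta * T))   # roughly equal
```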

Table 1: Comparison between different methods for gradient estimation in continuous case. MALI achieves reverse accuracy, constant memory w.r.t number of solver steps in integration, shallow computation graph and low computation cost.

                            Naive                     Adjoint                      ACA                        MALI
Computation                 N_z N_f × N_t × m × 2     N_z N_f × (N_t + N_r) × m    N_z N_f × N_t × (m + 1)    N_z N_f × N_t × (m + 2)
Memory                      N_z N_f × N_t × m         N_z N_f                      N_z (N_f + N_t)            N_z (N_f + 1)
Computation graph depth     N_f × N_t × m             N_f × N_r                    N_f × N_t                  N_f × N_t
Reverse accuracy            ✓                         ✗                            ✓                          ✓
Figure 1: Illustration of the numerical solver in the forward-pass. For adaptive solvers, for each step forward in time, the stepsize is recursively adjusted until the estimated error is below the predefined tolerance; the search process is represented by the green curve, and the accepted step (ignoring the search process) is represented by the blue curve.

Figure 2: In the backward-pass, the adjoint method reconstructs the trajectory as a separate IVP. Naive, ACA and MALI track the forward-time trajectory, hence are accurate. ACA and MALI only backpropagate through the accepted step, while the naive method backpropagates through the search process, hence has a deeper computation graph.

2.3 Numerical implementations in the literature for the analytical form

We compare different numerical implementations of the analytical form in this section. The forward-pass and backward-pass of different methods are demonstrated in Fig. 1 and Fig. 2 respectively. The forward-pass is similar for different methods. The comparison of the backward-pass among different methods is summarized in Table 1. We explain the methods in the literature below.

Naive method  The naive method saves the whole computation graph (including the search for the optimal stepsize, green curve in Fig. 2) in memory, and backpropagates through it. Hence the memory cost is N_z N_f × N_t × m, the depth of the computation graph is N_f × N_t × m, and the computation is doubled considering both forward and backward passes. Besides the large memory and computation, the deep computation graph might cause vanishing or exploding gradients (Pascanu et al., 2013).

Adjoint method  Note that we use “adjoint state equation” to refer to the analytical form in Eq. 2 and 3, while we use “adjoint method” to refer to the numerical implementation by Chen et al. (2018). As in Fig. 1 and 2, the adjoint method forgets the forward-time trajectory (blue curve) to achieve a memory cost of N_z N_f, which is constant w.r.t integration time; it takes the end-time state (derived from forward-time integration) as the initial state, and solves a separate IVP (red curve) in reverse time.

Theorem 2.1.

(Zhuang et al., 2020) For an ODE solver of order p, the error of the reconstructed initial value by the adjoint method is \sum_{k=0}^{N-1}\big[h_{k}^{p+1}D\Phi_{t_{k}}^{T}(z_{k})l(t_{k},z_{k})+(-h_{k})^{p+1}D\Phi_{T}^{t_{k}}(\overline{z_{k}})\overline{l(t_{k},\overline{z_{k}})}\big]+O(h^{p+1}), where \Phi is the ideal solution, D\Phi is the Jacobian of \Phi, and l(t,z) and \overline{l(t,z)} are the local errors in forward-time and reverse-time integration respectively.

Theorem 2.1 is stated as Theorem 3.2 in Zhuang et al. (2020); please see the reference paper for a detailed proof. To summarize, due to inevitable errors with numerical ODE solvers, the reverse-time trajectory (red curve, \overline{z}(\tau)) cannot match the forward-time trajectory (blue curve, z(t)) accurately. The error in \overline{z} propagates to \frac{\mathrm{d}L}{\mathrm{d}\theta} by Eq. 2, hence affects the accuracy in gradient estimation.

Adaptive checkpoint adjoint (ACA)  To solve the inaccuracy of the adjoint method, Zhuang et al. (2020) proposed ACA: ACA stores the forward-time trajectory in memory for the backward-pass, hence guarantees accuracy; ACA deletes the search process (green curve in Fig. 2), and only backpropagates through the accepted step (blue curve in Fig. 2), hence has a shallower computation graph (N_f × N_t for ACA vs N_f × N_t × m for the naive method). ACA only stores {z(t_i)}_{i=1}^{N_t}, and deletes the computation graph for {f(z(t_i), t_i)}_{i=1}^{N_t}, hence the memory cost is N_z(N_f + N_t). Though the memory cost is much smaller than that of the naive method, it grows linearly with N_t, and cannot handle very high-dimensional models. In the following sections, we propose a method to overcome all these disadvantages of existing methods.

3 Methods

3.1 Asynchronous Leapfrog Integrator

In this section we give a brief introduction to the asynchronous leapfrog (ALF) method (Mutze, 2013), and we provide theoretical analysis which is missing in Mutze (2013). For general first-order ODEs in the form of Eq. 1, the tuple (z, t) is sufficient for most ODE solvers to take a step numerically. For ALF, the required tuple is (z, v, t), where v is the “approximated derivative”. Most numerical ODE solvers such as the Runge-Kutta method (Runge, 1895) track the state z evolving with time, while ALF tracks the “augmented state” (z, v). We explain the details of ALF below.

Input (z_in, v_in, s_in, h) where s_in is the current time, z_in and v_in are the corresponding values at time s_in, and h is the stepsize.
Forward
        s_1 = s_in + h/2
        k_1 = z_in + v_in × h/2
        u_1 = f(k_1, s_1)
        v_out = v_in + 2(u_1 - v_in)
        z_out = k_1 + v_out × h/2
        s_out = s_1 + h/2
Output    (z_out, v_out, s_out, h)
Algorithm 2 Forward of ψ in ALF

Input (z_out, v_out, s_out, h) where s_out is the current time, z_out and v_out are the corresponding values at s_out, and h is the stepsize.
Inverse
        s_1 = s_out - h/2
        k_1 = z_out - v_out × h/2
        u_1 = f(k_1, s_1)
        v_in = 2u_1 - v_out
        z_in = k_1 - v_in × h/2
        s_in = s_1 - h/2
Output    (z_in, v_in, s_in, h)
Algorithm 3 ψ^{-1} (Inverse of ψ) in ALF
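
The following Python sketch mirrors Algo. 2 and Algo. 3 for an arbitrary vector field f(z, t), and checks that one step followed by its inverse recovers the input up to floating-point error; it is an illustration of the algorithms above, not the released implementation, and the example vector field is arbitrary.

```python
import numpy as np

def alf_step(f, z, v, s, h):
    """One ALF step (Algo. 2) on the augmented state (z, v) at time s."""
    s1 = s + 0.5 * h
    k1 = z + 0.5 * h * v
    u1 = f(k1, s1)
    v_out = v + 2.0 * (u1 - v)
    z_out = k1 + 0.5 * h * v_out
    return z_out, v_out, s1 + 0.5 * h

def alf_step_inverse(f, z_out, v_out, s_out, h):
    """Exact inverse of one ALF step (Algo. 3)."""
    s1 = s_out - 0.5 * h
    k1 = z_out - 0.5 * h * v_out
    u1 = f(k1, s1)
    v_in = 2.0 * u1 - v_out
    z_in = k1 - 0.5 * h * v_in
    return z_in, v_in, s1 - 0.5 * h

# round trip on an arbitrary vector field: the input is recovered exactly
f = lambda z, t: np.sin(t) - 0.5 * z
z0 = np.array([1.0, -2.0])
v0, t0, h = f(z0, 0.0), 0.0, 0.1
z1, v1, t1 = alf_step(f, z0, v0, t0, h)
z0r, v0r, t0r = alf_step_inverse(f, z1, v1, t1, h)
print(np.max(np.abs(z0 - z0r)), np.max(np.abs(v0 - v0r)))   # ~1e-16
```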

Figure 3: With the ALF method, given any tuple (z_j, v_j, t_j) and discretized time points {t_i}_{i=1}^{N_t}, we can reconstruct the entire trajectory accurately due to the reversibility of ALF.

Procedure of ALF  Different ODE solvers have different \psi in Algo. 1, hence we only summarize \psi for ALF in Algo. 2. Note that for a complete algorithm of integration with ALF, we need to plug Algo. 2 into Algo. 1. The forward-pass is summarized in Algo. 2. Given a stepsize h, with input (z_in, v_in, s_in), a single step of ALF outputs (z_out, v_out, s_out).

As in Fig. 3, given (z_0, v_0, t_0), the numerical forward-time integration calls Algo. 2 iteratively:

(z_{i}, v_{i}, t_{i}, h_{i}) = \psi(z_{i-1}, v_{i-1}, t_{i-1}, h_{i})\ \ s.t.\ \ h_{i} = t_{i} - t_{i-1},\ \ i = 1, 2, \dots, N_{t}    (4)

Invertibility of ALF  An interesting property of ALF is that \psi defines a bijective mapping; therefore, we can reconstruct (z_in, v_in, s_in, h) from (z_out, v_out, s_out, h), as demonstrated in Algo. 7. As in Fig. 3, we can reconstruct the entire trajectory given the state (z_j, v_j) at time t_j and the discretized time points {t_0, ..., t_{N_t}}. For example, given (z_{N_t}, v_{N_t}) and {t_i}_{i=0}^{N_t}, the trajectory for Eq. 4 is reconstructed:

(z_{i-1}, v_{i-1}, t_{i-1}, h_{i}) = \psi^{-1}(z_{i}, v_{i}, t_{i}, h_{i})\ \ s.t.\ \ h_{i} = t_{i} - t_{i-1},\ \ i = N_{t}, N_{t}-1, \dots, 1    (5)

In the following sections, we will show that the invertibility of ALF is the key to maintaining accuracy at a constant memory cost when training Neural ODEs. Note that “inverse” refers to reconstructing the input from the output without computing the gradient, hence is different from “back-propagation”.

Initial value   For an initial value problem (IVP) such as Eq. 1, typically z_0 = z(t_0) is given while v_0 is undetermined. We can construct v_0 = f(z(t_0), t_0), so the initial augmented state is (z_0, v_0).

Difference from the midpoint integrator   The midpoint integrator (Süli & Mayers, 2003) is similar to Algo. 2, except that it recomputes v_in = f(z_in, s_in) for every step, while ALF directly uses the input v_in. Therefore, the midpoint method does not have an explicit form of inverse.

Local truncation error    Theorem 3.1 indicates that the local truncation error of ALF is of order O(h^3); this implies the global error is O(h^2). A detailed proof is in Appendix A.3.

Theorem 3.1.

For a single step of ALF with stepsize h, the local truncation error of z is O(h^3), and the local truncation error of v is O(h^2).

A-Stability  The ALF solver has a limited stability region, but this can be addressed with damping. The damped ALF replaces the update of v_out in Algo. 2 with v_out = v_in + 2\eta(u_1 - v_in), where \eta is the “damping coefficient” between 0 and 1. We have the following theorem on its numerical stability.

Theorem 3.2.

For the damped ALF integrator with stepsize h, where \sigma_i is the i-th eigenvalue of the Jacobian \frac{\partial f}{\partial z}, the solver is A-stable if \Big|1+\eta(h\sigma_i-1)\pm\sqrt{\eta\big[2h\sigma_i+\eta(h\sigma_i-1)^{2}\big]}\Big| < 1,\ \forall i.

Proofs are in Appendix A.4 and A.5. Theorem 3.2 implies the following: when \eta = 1, the damped ALF reduces to ALF, and the stability region is empty; when 0 < \eta < 1, the stability region is non-empty. However, stability describes the behaviour as T goes to infinity; in practice we always use a bounded T, and ALF performs well. The inverse of damped ALF is given in Appendix A.5.
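
The sketch below shows how the damped update modifies a single ALF step; setting eta = 1 recovers Algo. 2, and the default value of eta used here is illustrative only.

```python
def damped_alf_step(f, z, v, s, h, eta=0.9):
    """One damped ALF step; eta = 1 recovers the plain ALF step of Algo. 2."""
    s1 = s + 0.5 * h
    k1 = z + 0.5 * h * v
    u1 = f(k1, s1)
    v_out = v + 2.0 * eta * (u1 - v)      # the only change: damped update of v
    z_out = k1 + 0.5 * h * v_out
    return z_out, v_out, s1 + 0.5 * h
```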

3.2 Memory-efficient ALF Integrator (MALI) for gradient estimation

An ideal solver for Neural ODEs should achieve two goals: accuracy in gradient estimation and constant memory cost w.r.t integration time. Yet none of the existing methods can achieve both goals. We propose a method based on the ALF solver, which to our knowledge is the first method to achieve the two goals simultaneously.

Input Initial state z_0, start time t_0, end time T
Forward
    Apply the numerical integration in Algo. 1, with the ψ function defined by Algo. 2.
    Delete the computation graph on the fly, only keep the end-time state (z_{N_t}, v_{N_t})
    Keep the accepted discretized time points {t_i}_{i=0}^{N_t} (ignore the process to search for the optimal stepsize)
Backward
     Initialize a(T) = ∂L/∂z(T) by Eq. 3, initialize dL/dθ = 0
     For i in {N_t, N_t - 1, ..., 2, 1}:
          Reconstruct (z_{i-1}, v_{i-1}) from (z_i, v_i) by Algo. 7
          Local forward (z_i, v_i, t_i, h_i) = ψ(z_{i-1}, v_{i-1}, t_{i-1}, h_i)
          Local backward, get ∂f(z_{i-1}, t_{i-1}, θ)/∂z_{i-1} and ∂f(z_{i-1}, t_{i-1}, θ)/∂θ
          Update a(t) and dL/dθ by Eq. 2 and Eq. 3 discretized at time points t_{i-1} and t_i
          Delete the local computation graph
     Output the adjoint state a(t_0) (gradient w.r.t the input z_0) and the parameter gradient dL/dθ
Algorithm 4 MALI to achieve accuracy at a constant memory cost w.r.t integration time

Procedure of MALI  Details of MALI are summarized in Algo. 4. For the forward-pass, we only keep the end-time state (z_{N_t}, v_{N_t}) and the accepted discretized time points (blue curves in Fig. 1 and 2). We ignore the search process for the optimal stepsize (green curves in Fig. 1 and 2), and delete other variables to save memory. During the backward pass, we reconstruct the forward-time trajectory as in Eq. 5, then calculate the gradient by numerical discretization of Eq. 2 and Eq. 3.
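
A minimal PyTorch-style sketch of the backward loop in Algo. 4 is shown below. It treats the adjoint update as a vector-Jacobian product through the reconstructed local step, which is one way to realize the discretized updates described in Algo. 4; the interface (f as an nn.Module taking (z, t), ts as the list of accepted time points) is an assumption for illustration and is not the interface of the released package.

```python
import torch

def alf_step(f, z, v, s, h):
    # one ALF step (Algo. 2), returning only the updated augmented state
    k1 = z + 0.5 * h * v
    u1 = f(k1, s + 0.5 * h)
    v_out = v + 2.0 * (u1 - v)
    return k1 + 0.5 * h * v_out, v_out

def alf_inverse(f, z_out, v_out, s_out, h):
    # exact inverse of one ALF step (Algo. 3)
    k1 = z_out - 0.5 * h * v_out
    u1 = f(k1, s_out - 0.5 * h)
    v_in = 2.0 * u1 - v_out
    return k1 - 0.5 * h * v_in, v_in

def mali_backward(f, zN, vN, ts, grad_zT):
    # adjoint of the augmented state (z, v) at the end time; dL/dv(T) is zero
    a_z, a_v = grad_zT, torch.zeros_like(vN)
    grads = [torch.zeros_like(p) for p in f.parameters()]
    z, v = zN, vN
    for i in range(len(ts) - 1, 0, -1):
        h = ts[i] - ts[i - 1]
        with torch.no_grad():                         # cheap, exact reconstruction
            z_prev, v_prev = alf_inverse(f, z, v, ts[i], h)
        z_req = z_prev.detach().requires_grad_(True)
        v_req = v_prev.detach().requires_grad_(True)
        z_loc, v_loc = alf_step(f, z_req, v_req, ts[i - 1], h)   # local forward
        vjp = torch.autograd.grad(                    # local backward through one step
            (z_loc, v_loc), [z_req, v_req] + list(f.parameters()),
            grad_outputs=(a_z, a_v), allow_unused=True)
        a_z, a_v = vjp[0], vjp[1]
        grads = [g + (d if d is not None else 0.0) for g, d in zip(grads, vjp[2:])]
        z, v = z_prev, v_prev                         # graph of this step is discarded
    return a_z, grads                                 # dL/dz(t0) and dL/dtheta
```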

Constant memory cost w.r.t number of solver steps in integration  We delete the computation graph and only keep the end-time state to save memory. The memory cost is N_z(N_f + 1), where N_z N_f is due to evaluating f(z, t) and is irreducible for all methods. Compared with the adjoint method, MALI only requires an extra N_z memory to record v_{N_t}, and also has a constant memory cost w.r.t the number of time steps N_t.

Accuracy   Our method guarantees the accuracy of the reverse-time trajectory (e.g. the blue curve in Fig. 2 matches the blue curve in Fig. 1), because ALF is explicitly invertible for free-form f (see Algo. 7). Therefore, the gradient estimation in MALI is more accurate compared to the adjoint method.

Computation cost  Recall that on average it takes m steps to find an acceptable stepsize, whose error estimate is below the tolerance. Therefore, the forward-pass with the search process has a computation burden of N_z × N_f × N_t × m. We only reconstruct and backpropagate through the accepted step and ignore the search process, hence the backward-pass takes another N_z × N_f × N_t × 2 computation. The overall computation burden is N_z N_f × N_t × (m + 2), as in Table 1.

Figure 4: Comparison of error in the gradient for Eq. 6. (a) error in dL/dz_0. (b) error in dL/dα. (c) memory cost.
Figure 5: Results on Cifar10. From left to right: (1) box plot of test accuracy (the first 4 columns are Neural ODEs, the last is ResNet); (2) test accuracy ± std v.s. training epoch for Neural ODE; (3) test accuracy ± std v.s. training time of 90 epochs for Neural ODE.

Shallow computation graph  Similar to ACA, MALI only backpropagates through the accepted step (blue curve in Fig. 2) and ignores the search process (green curve in Fig. 2), hence the depth of the computation graph is N_f × N_t. The computation graph of MALI is much shallower than that of the naive method, hence is more robust to vanishing and exploding gradients (Pascanu et al., 2013).

Summary  The adjoint method suffers from inaccuracy in reverse-time trajectory, the naive method suffers from exploding or vanishing gradient caused by deep computation graph, and ACA finds a balance but the memory grows linearly with integration time. MALI achieves accuracy in reverse-time trajectory, constant memory w.r.t integration time, and a shallow computation graph.

4 Experiments

4.1 Validation on a toy example

We compare the performance of different methods on a toy example, defined as

L(z(T)) = z(T)^2\ \ s.t.\ \ z(0) = z_0,\ \ \mathrm{d}z(t)/\mathrm{d}t = \alpha z(t)    (6)

The analytical solution is

z(t) = z_0 e^{\alpha t},\ \ L = z_0^2 e^{2\alpha T},\ \ \mathrm{d}L/\mathrm{d}z_0 = 2 z_0 e^{2\alpha T},\ \ \mathrm{d}L/\mathrm{d}\alpha = 2 T z_0^2 e^{2\alpha T}    (7)

We plot the amplitude of the error between the numerical solution and the analytical solution varying with T (integrated under the same error tolerance, rtol = 10^{-5}, atol = 10^{-6}) in Fig. 4. ACA and MALI have similar errors, both outperforming other methods. We also plot the memory consumption for different methods on a Neural ODE with the same input in Fig. 4. As the error tolerance decreases, the solver evaluates more steps, hence the naive method and ACA increase memory consumption, while MALI and the adjoint method have a constant memory cost. These results validate our analysis in Sec. 3.2 and Table 1, and show that MALI achieves accuracy at a constant memory cost.
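
The toy problem in Eq. 6-7 can be reproduced in a few lines: the hedged sketch below backpropagates through a simple fixed-step midpoint discretization of dz/dt = αz and compares the resulting gradients with the analytical expressions in Eq. 7 (the solver, values and step count are illustrative, not the setting used for Fig. 4).

```python
import torch

z0 = torch.tensor(1.2, requires_grad=True)
alpha = torch.tensor(0.8, requires_grad=True)
T, N = 1.0, 1000
h = T / N

z = z0
for _ in range(N):                       # explicit midpoint steps for dz/dt = alpha*z
    k = z + 0.5 * h * alpha * z
    z = z + h * alpha * k
L = z ** 2
L.backward()                             # backpropagate through the whole solver

exact_dz0 = 2 * z0.detach() * torch.exp(2 * alpha.detach() * T)
exact_dalpha = 2 * T * z0.detach() ** 2 * torch.exp(2 * alpha.detach() * T)
print(float(z0.grad), float(exact_dz0))          # should agree to several digits
print(float(alpha.grad), float(exact_dalpha))
```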

Table 2: Top-1 test accuracy of Neural ODE and ResNet on ImageNet. Neural ODE is trained with MALI, and ResNet is trained as the original model; Neural ODE is tested using different solvers without retraining.

Fixed-stepsize solvers of various stepsizes
Stepsize     1       0.5     0.25    0.15    0.1
MALI         42.33   66.4    69.59   70.17   69.94
Euler        21.94   61.25   67.38   68.69   70.02
RK2          42.33   69      69.72   70.14   69.92
RK4          12.6    69.99   69.91   70.21   69.96

Adaptive-stepsize solvers of various tolerances
Tolerance    1.00E+00   1.00E-01   1.00E-02
MALI         62.56      69.89      69.87
Heun-Euler   68.48      69.87      69.88
RK23         50.77      69.89      69.93
Dopri5       52.3       68.58      69.71

ResNet (trained and tested as the original discrete model): 70.09
Table 3: Top-1 accuracy under FGSM attack. \epsilon is the perturbation amplitude. For Neural ODE models, row names represent the solvers used to derive the gradient for the attack, and column names represent the solvers used for inference on the perturbed image.

\epsilon = 1/255
             MALI    Heun-Euler   RK23    Dopri5
MALI         14.69   14.72        14.77   15.71
Heun-Euler   14.77   14.75        14.80   15.74
RK23         14.82   14.77        14.79   15.69
Dopri5       14.82   14.78        14.79   15.15
ResNet       13.02

\epsilon = 2/255
             MALI    Heun-Euler   RK23    Dopri5
MALI         10.38   10.46        10.62   10.62
Heun-Euler   10.63   10.47        10.44   10.49
RK23         10.78   10.53        10.48   10.56
Dopri5       10.76   10.49        10.48   10.51
ResNet       9.57

4.2 Image recognition with Neural ODE

We validate MALI on image recognition tasks using the Cifar10 and ImageNet datasets. Similar to Zhuang et al. (2020), we modify a ResNet18 into its corresponding Neural ODE: the forward function is y = x + f_\theta(x) for the residual block and y = x + \int_0^T f_\theta(z) \mathrm{d}t for the Neural ODE, where the same f_\theta is shared. We compare MALI with the naive method, the adjoint method and ACA.
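
A hedged sketch of this ResNet-to-Neural-ODE conversion is given below: the residual connection y = x + f_\theta(x) is replaced by integrating dz/dt = f_\theta(z, t) from z(0) = x to time T, with f_\theta assumed to accept (z, t). A fixed-step ALF loop is inlined purely for illustration; in practice the integration would be delegated to a solver such as MALI, and the class and argument names are hypothetical.

```python
import torch.nn as nn

class ODEBlock(nn.Module):
    """Hypothetical wrapper: y = z(T) where dz/dt = f_theta(z, t), z(0) = x."""

    def __init__(self, f, T=1.0, n_steps=4):
        super().__init__()
        self.f, self.T, self.n_steps = f, T, n_steps

    def forward(self, x):
        h = self.T / self.n_steps
        t = 0.0
        z = x
        v = self.f(z, t)                  # v0 = f(z0, t0), as in Sec. 3.1
        for _ in range(self.n_steps):     # fixed-step ALF, for illustration only
            k1 = z + 0.5 * h * v
            u1 = self.f(k1, t + 0.5 * h)
            v = v + 2.0 * (u1 - v)
            z = k1 + 0.5 * h * v
            t += h
        return z                          # equals x + integral of f_theta over [0, T]
```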

Results on Cifar10  Results of 5 independent runs on Cifar10 are summarized in Fig. 5. MALI achieves comparable accuracy to ACA, and both significantly outperform the naive and the adjoint method. Furthermore, the training speed of MALI is similar to that of ACA, and both are almost two times faster than the adjoint method, and three times faster than the naive method. This validates our analysis of accuracy and computation burden in Table 1.

Figure 6: Top-1 accuracy on the ImageNet validation dataset.

Accuracy on ImageNet  Due to the heavy memory burden caused by large images, the naive method and ACA are unable to train a Neural ODE on ImageNet with 4 GPUs; only MALI and the adjoint method are feasible due to the constant memory. We also compare the Neural ODE to a standard ResNet. As shown in Fig. 6, the accuracy of the Neural ODE trained with MALI closely follows ResNet, and significantly outperforms the adjoint method (top-1 validation: 70% v.s. 63%).

Invariance to discretization scheme  A continuous model should be invariant to discretization schemes (e.g. different types of ODE solvers) as long as the discretization is sufficiently accurate. We test the Neural ODE using different solvers without re-training; since ResNet is often viewed as a one-step Euler discretization of an ODE (Haber & Ruthotto, 2017), we perform similar experiments. As shown in Table 2, the Neural ODE consistently achieves high accuracy (~70%), while ResNet drops to random guessing (~0.1%) because ResNet as a one-step Euler discretization fails to be a meaningful dynamical system (Queiruga et al., 2020).

Robustness to adversarial attack  Hanshu et al. (2019) demonstrated that Neural ODEs are more robust to adversarial attack than ResNet on small-scale datasets such as Cifar10. We validate this result on the large-scale ImageNet dataset. The top-1 accuracies of the Neural ODE and ResNet under the FGSM attack (Goodfellow et al., 2014) are summarized in Table 3. For the Neural ODE, due to its invariance to the discretization scheme, we derive the gradient for the attack using a certain solver (rows in Table 3), and run inference on the perturbed images using various solvers. For different combinations of solvers and perturbation amplitudes, the Neural ODE consistently outperforms ResNet.
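
For reference, a minimal sketch of an FGSM-style perturbation as used in this comparison: the input is moved by ε in the direction of the sign of the input gradient and clipped back to the valid range. Here `model` stands for either network, and preprocessing details are omitted; the function is illustrative, not the exact attack code used for Table 3.

```python
import torch.nn.functional as F

def fgsm(model, x, y, eps):
    """One-step FGSM: perturb x by eps in the sign of the input gradient."""
    x_adv = x.clone().detach().requires_grad_(True)
    loss = F.cross_entropy(model(x_adv), y)
    loss.backward()
    return (x_adv + eps * x_adv.grad.sign()).clamp(0.0, 1.0).detach()
```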

Summary   In image recognition tasks, we demonstrate that the Neural ODE is accurate, invariant to the discretization scheme, and more robust to adversarial attack than ResNet. A detailed explanation of the robustness of Neural ODEs is beyond the scope of this paper, but to our knowledge, MALI is the first method to enable training of Neural ODEs on large datasets due to its constant memory cost.

4.3 Time-series modeling

We apply MALI to latent-ODE (Rubanova et al., 2019) and the Neural Controlled Differential Equation (Neural CDE) (Kidger et al., 2020a; b). Our experiments are based on the official implementations from the literature. We report the mean squared error (MSE) on the Mujoco test set, generated from the “Hopper” model of the DeepMind control suite (Tassa et al., 2018), in Table 4; for all experiments with different ratios of training data, MALI achieves MSE similar to ACA, and both outperform the adjoint and naive methods. We report the test accuracy on the Speech Command dataset for Neural CDE in Table 5; MALI achieves a higher accuracy than competing methods.

4.4 Continuous generative models

We apply MALI to FFJORD (Grathwohl et al., 2018), a free-form continuous generative model, and compare with several variants in the literature (Finlay et al., 2020; Kidger et al., 2020a). Our experiment is based on the official implementation of Finlay et al. (2020); for a fair comparison, we train with MALI, and test with the same solver as in the literature (Grathwohl et al., 2018; Finlay et al., 2020), the Dopri5 solver with rtol = atol = 10^{-5} from the torchdiffeq package (Chen et al., 2018). Bits per dim (BPD, lower is better) on the validation set for various datasets are reported in Table 6. Among continuous models, MALI consistently achieves the lowest BPD, and outperforms the vanilla FFJORD (trained with the adjoint method), RNODE (regularized FFJORD) and the SemiNorm adjoint (Kidger et al., 2020a). Furthermore, FFJORD trained with MALI achieves BPD comparable to state-of-the-art discrete-layer flow models in the literature. Please see Sec. B.3 for generated samples.

5 Related works

Besides ALF, the symplectic integrator (Verlet, 1967; Yoshida, 1990) is also able to reconstruct the trajectory accurately, yet it is typically restricted to second-order Hamiltonian systems (De Almeida, 1990) and is unsuitable for general ODEs. Besides the aforementioned methods, there are other methods for gradient estimation, such as the interpolated adjoint (Daulbaev et al., 2020) and the spectral method (Quaglino et al., 2019), yet the implementations are involved and not publicly available. Other works focus on the theoretical properties of Neural ODEs (Dupont et al., 2019; Tabuada & Gharesifard, 2020; Massaroli et al., 2020). Neural ODEs have recently been applied to stochastic differential equations (Li et al., 2020), jump differential equations (Jia & Benson, 2019) and auto-regressive models (Wehenkel & Louppe, 2019).

Footnote: 1. Rubanova et al. (2019); 2. Zhuang et al. (2020); 3. Kidger et al. (2020a); 4. Chen et al. (2018); 5. Finlay et al. (2020); 6. Dinh et al. (2016); 7. Behrmann et al. (2019); 8. Kingma & Dhariwal (2018); 9. Ho et al. (2019); 10. Chen et al. (2019)

Table 4: Test MSE (×0.01) on the Mujoco dataset (lower is better). Results marked with superscript numbers correspond to the literature in the footnote.

Percentage of training data   RNN^1   RNN-GRU^1   Latent-ODE (Adjoint^1)   Latent-ODE (Naive^2)   Latent-ODE (ACA^2)   Latent-ODE (MALI)
10%                           2.45    1.97        0.47                     0.36                   0.31                 0.35
20%                           1.71    1.42        0.44                     0.30                   0.27                 0.27
50%                           0.79    0.75        0.40                     0.29                   0.26                 0.26
Table 5: Test accuracy on the Speech Command dataset.

Method       Accuracy (%)
Adjoint^3    92.8 ± 0.4
SemiNorm^3   92.9 ± 0.4
Naive        93.2 ± 0.2
ACA          93.2 ± 0.2
MALI         93.7 ± 0.3
Table 6: Bits per dim (BPD) of generative models, lower is better. Results marked with superscript numbers correspond to the literature in the footnote.

Continuous Flow (FFJORD)
Dataset      Vanilla^4   RNODE^5   SemiNorm^3   MALI
MNIST        0.99        0.97      0.96         0.87
CIFAR10      3.40        3.38      3.35         3.27
ImageNet64   -           3.83      -            3.71

Discrete Flow
Dataset      RealNVP^6   i-ResNet^7   Glow^8   Flow++^9   Residual Flow^10
MNIST        1.06        1.05         1.05     -          0.97
CIFAR10      3.49        3.45         3.35     3.28       3.28
ImageNet64   3.98        -            3.81     -          3.76

6 Conclusion

Based on the asynchronous leapfrog integrator, we propose MALI to estimate the gradient for Neural ODEs. To our knowledge, our method is the first to achieve accuracy, fast speed and a constant memory cost simultaneously. We provide comprehensive theoretical analysis of its properties. We validate MALI with extensive experiments and achieve new state-of-the-art results in various tasks, including image recognition, continuous generative modeling, and time-series modeling.

References

  • Behrmann et al. (2019) Jens Behrmann, Will Grathwohl, Ricky TQ Chen, David Duvenaud, and Jörn-Henrik Jacobsen. Invertible residual networks. In International Conference on Machine Learning, pp. 573–582, 2019.
  • Chen et al. (2018) Ricky TQ Chen, Yulia Rubanova, Jesse Bettencourt, and David K Duvenaud. Neural ordinary differential equations. In Advances in Neural Information Processing Systems, pp. 6571–6583, 2018.
  • Chen et al. (2019) Ricky TQ Chen, Jens Behrmann, David K Duvenaud, and Jörn-Henrik Jacobsen. Residual flows for invertible generative modeling. In Advances in Neural Information Processing Systems, pp. 9916–9926, 2019.
  • Daulbaev et al. (2020) Talgat Daulbaev, Alexandr Katrutsa, Larisa Markeeva, Julia Gusak, Andrzej Cichocki, and Ivan Oseledets. Interpolated adjoint method for neural odes. arXiv preprint arXiv:2003.05271, 2020.
  • De Almeida (1990) Alfredo M Ozorio De Almeida. Hamiltonian systems: chaos and quantization. Cambridge University Press, 1990.
  • Dinh et al. (2016) Laurent Dinh, Jascha Sohl-Dickstein, and Samy Bengio. Density estimation using real nvp. arXiv preprint arXiv:1605.08803, 2016.
  • Dupont et al. (2019) Emilien Dupont, Arnaud Doucet, and Yee Whye Teh. Augmented neural odes. In Advances in Neural Information Processing Systems, pp. 3140–3150, 2019.
  • Finlay et al. (2020) Chris Finlay, Jörn-Henrik Jacobsen, Levon Nurbekyan, and Adam M Oberman. How to train your neural ode: the world of jacobian and kinetic regularization. In International Conference on Machine Learning, 2020.
  • Goodfellow et al. (2014) Ian J Goodfellow, Jonathon Shlens, and Christian Szegedy. Explaining and harnessing adversarial examples. arXiv preprint arXiv:1412.6572, 2014.
  • Grathwohl et al. (2018) Will Grathwohl, Ricky TQ Chen, Jesse Bettencourt, Ilya Sutskever, and David Duvenaud. Ffjord: Free-form continuous dynamics for scalable reversible generative models. arXiv preprint arXiv:1810.01367, 2018.
  • Haber & Ruthotto (2017) Eldad Haber and Lars Ruthotto. Stable architectures for deep neural networks. Inverse Problems, 34(1):014004, 2017.
  • Hanshu et al. (2019) YAN Hanshu, DU Jiawei, TAN Vincent, and FENG Jiashi. On robustness of neural ordinary differential equations. In International Conference on Learning Representations, 2019.
  • He et al. (2016) Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition, pp.  770–778, 2016.
  • Ho et al. (2019) Jonathan Ho, Xi Chen, Aravind Srinivas, Yan Duan, and Pieter Abbeel. Flow++: Improving flow-based generative models with variational dequantization and architecture design. arXiv preprint arXiv:1902.00275, 2019.
  • Jia & Benson (2019) Junteng Jia and Austin R Benson. Neural jump stochastic differential equations. In Advances in Neural Information Processing Systems, pp. 9847–9858, 2019.
  • Kidger et al. (2020a) Patrick Kidger, Ricky T. Q. Chen, and Terry Lyons. “Hey, that’s not an ODE”: Faster ODE Adjoints with 12 Lines of Code. arXiv:2009.09457, 2020a.
  • Kidger et al. (2020b) Patrick Kidger, James Morrill, James Foster, and Terry Lyons. Neural controlled differential equations for irregular time series. arXiv preprint arXiv:2005.08926, 2020b.
  • Kingma & Dhariwal (2018) Durk P Kingma and Prafulla Dhariwal. Glow: Generative flow with invertible 1x1 convolutions. In Advances in Neural Information Processing Systems, pp. 10215–10224, 2018.
  • Li et al. (2020) Xuechen Li, Ting-Kam Leonard Wong, Ricky TQ Chen, and David Duvenaud. Scalable gradients for stochastic differential equations. arXiv preprint arXiv:2001.01328, 2020.
  • Liu (2017) Kuang Liu. Train cifar10 with pytorch. 2017. URL https://github.com/kuangliu/pytorch-cifar.
  • Lu et al. (2018) Yiping Lu, Aoxiao Zhong, Quanzheng Li, and Bin Dong. Beyond finite layer neural networks: Bridging deep architectures and numerical differential equations. In International Conference on Machine Learning, pp. 3276–3285. PMLR, 2018.
  • Massaroli et al. (2020) Stefano Massaroli, Michael Poli, Jinkyoo Park, Atsushi Yamashita, and Hajime Asama. Dissecting neural odes. arXiv preprint arXiv:2002.08071, 2020.
  • Mutze (2013) Ulrich Mutze. An asynchronous leapfrog method ii. arXiv preprint arXiv:1311.6602, 2013.
  • Pascanu et al. (2013) Razvan Pascanu, Tomas Mikolov, and Yoshua Bengio. On the difficulty of training recurrent neural networks. In International conference on machine learning, pp. 1310–1318, 2013.
  • Pontryagin (1962) Lev Semenovich Pontryagin. Mathematical theory of optimal processes. Routledge, 1962.
  • Quaglino et al. (2019) Alessio Quaglino, Marco Gallieri, Jonathan Masci, and Jan Koutník. Snode: Spectral discretization of neural odes for system identification. arXiv preprint arXiv:1906.07038, 2019.
  • Queiruga et al. (2020) Alejandro F Queiruga, N Benjamin Erichson, Dane Taylor, and Michael W Mahoney. Continuous-in-depth neural networks. arXiv preprint arXiv:2008.02389, 2020.
  • Rubanova et al. (2019) Yulia Rubanova, Ricky TQ Chen, and David K Duvenaud. Latent ordinary differential equations for irregularly-sampled time series. In Advances in Neural Information Processing Systems, pp. 5320–5330, 2019.
  • Runge (1895) Carl Runge. Über die numerische auflösung von differentialgleichungen. Mathematische Annalen, 46(2):167–178, 1895.
  • Ruthotto & Haber (2019) Lars Ruthotto and Eldad Haber. Deep neural networks motivated by partial differential equations. Journal of Mathematical Imaging and Vision, pp.  1–13, 2019.
  • Ruthotto et al. (2020) Lars Ruthotto, Stanley J Osher, Wuchen Li, Levon Nurbekyan, and Samy Wu Fung. A machine learning framework for solving high-dimensional mean field game and mean field control problems. Proceedings of the National Academy of Sciences, 117(17):9183–9193, 2020.
  • Sanchez-Gonzalez et al. (2019) Alvaro Sanchez-Gonzalez, Victor Bapst, Kyle Cranmer, and Peter Battaglia. Hamiltonian graph networks with ode integrators. arXiv preprint arXiv:1909.12790, 2019.
  • Silvester (2000) John R Silvester. Determinants of block matrices. The Mathematical Gazette, 84(501):460–467, 2000.
  • Süli & Mayers (2003) Endre Süli and David F Mayers. An introduction to numerical analysis. Cambridge university press, 2003.
  • Tabuada & Gharesifard (2020) Paulo Tabuada and Bahman Gharesifard. Universal approximation power of deep neural networks via nonlinear control theory. arXiv preprint arXiv:2007.06007, 2020.
  • Tassa et al. (2018) Yuval Tassa, Yotam Doron, Alistair Muldal, Tom Erez, Yazhe Li, Diego de Las Casas, David Budden, Abbas Abdolmaleki, Josh Merel, Andrew Lefrancq, et al. Deepmind control suite. arXiv preprint arXiv:1801.00690, 2018.
  • Verlet (1967) Loup Verlet. Computer "experiments" on classical fluids. I. Thermodynamical properties of Lennard-Jones molecules. Physical Review, 159(1):98, 1967.
  • Wanner & Hairer (1996) Gerhard Wanner and Ernst Hairer. Solving ordinary differential equations II. Springer Berlin Heidelberg, 1996.
  • Wehenkel & Louppe (2019) Antoine Wehenkel and Gilles Louppe. Unconstrained monotonic neural networks. In Advances in Neural Information Processing Systems, pp. 1545–1555, 2019.
  • Weinan (2017) E Weinan. A proposal on machine learning via dynamical systems. Communications in Mathematics and Statistics, 5(1):1–11, 2017.
  • Yoshida (1990) Haruo Yoshida. Construction of higher order symplectic integrators. Physics letters A, 150(5-7):262–268, 1990.
  • Zhong et al. (2019) Yaofeng Desmond Zhong, Biswadip Dey, and Amit Chakraborty. Symplectic ode-net: Learning hamiltonian dynamics with control. arXiv preprint arXiv:1909.12077, 2019.
  • Zhuang et al. (2020) Juntang Zhuang, Nicha Dvornek, Xiaoxiao Li, Sekhar Tatikonda, Xenophon Papademetris, and James Duncan. Adaptive checkpoint adjoint method for gradient estimation in neural ode. International Conference on Machine Learning, 2020.

Appendix A Theoretical properties of ALF integrator

A.1 Algorithm of ALF

For ease of reading, we restate the algorithm for \psi in ALF below; it is the same as Algo. 2 in the main paper, but uses slightly different notation for ease of analysis.

Input (\widehat{z_{in}}, \widehat{v_{in}}, s_{in}, h) = (\widehat{z_0}, \widehat{v_0}, s_0, h), where s_0 is the current time, \widehat{z_0} and \widehat{v_0} are the corresponding values at time s_0, and h is the stepsize.
Forward
        s_1 = s_0 + h/2    (1)
        \widehat{z_1} = \widehat{z_0} + \widehat{v_0} \times h/2    (2)
        \widehat{v_1} = f(\widehat{z_1}, s_1)    (3)
        \widehat{v_2} = \widehat{v_1} + (\widehat{v_1} - \widehat{v_0})    (4)
        \widehat{z_2} = \widehat{z_1} + \widehat{v_2} \times h/2    (5)
        s_2 = s_1 + h/2    (6)
Output    (\widehat{z_{out}}, \widehat{v_{out}}, s_{out}, h) = (\widehat{z_2}, \widehat{v_2}, s_2, h)
Algorithm 5 Forward of \psi in ALF

For simplicity, we can re-write the forward of ALF as

\begin{bmatrix}\widehat{z_{2}}\\ \widehat{v_{2}}\end{bmatrix}=\begin{bmatrix}\widehat{z_{0}}+hf(\widehat{z_{0}}+\frac{h}{2}\widehat{v_{0}},s_{0}+\frac{h}{2})\\ 2f(\widehat{z_{0}}+\frac{h}{2}\widehat{v_{0}},s_{0}+\frac{h}{2})-\widehat{v_{0}}\end{bmatrix}    (7)

Similarly, the inverse of ALF can be written as

\begin{bmatrix}\widehat{z_{0}}\\ \widehat{v_{0}}\end{bmatrix}=\begin{bmatrix}\widehat{z_{2}}-hf(\widehat{z_{2}}-\frac{h}{2}\widehat{v_{2}},s_{2}-\frac{h}{2})\\ 2f(\widehat{z_{2}}-\frac{h}{2}\widehat{v_{2}},s_{2}-\frac{h}{2})-\widehat{v_{2}}\end{bmatrix}    (8)

A.2 Preliminaries

For an ODE of the form

\frac{\mathrm{d}z(t)}{\mathrm{d}t}=f(z(t),t)    (9)

We have:

\frac{\mathrm{d}^{2}z(t)}{\mathrm{d}t^{2}}=\frac{\mathrm{d}}{\mathrm{d}t}f(z(t),t)=\frac{\partial f(z(t),t)}{\partial t}+\frac{\partial f(z(t),t)}{\partial z}\frac{\mathrm{d}z(t)}{\mathrm{d}t}    (10)

For the ease of notation, we re-write Eq. 10 as

\frac{\mathrm{d}^{2}z(t)}{\mathrm{d}t^{2}}=f_{t}+f_{z}f    (11)

where f_t and f_z represent the partial derivatives of f w.r.t t and z respectively.

A.3 Local truncation error of ALF

Theorem A.1 (Theorem 3.1 in the main paper).

For a single step of ALF with stepsize h, the local truncation error of z is O(h^3), and the local truncation error of v is O(h^2).

Proof.

Under the same notation as Algo. 5, denote the ground-truth states of z and v starting from (\widehat{z_0}, s_0) as \widetilde{z} and \widetilde{v} respectively. Then the local truncation errors are

L_{z}=\widetilde{z}(s_{0}+h)-\widehat{z_{2}},\ \ L_{v}=\widetilde{v}(s_{0}+h)-\widehat{v_{2}}    (12)

We estimate L_z and L_v as polynomials of h.

Under the mild assumption that f is smooth up to 2nd order almost everywhere (typically satisfied by neural networks with bounded weights), the Taylor expansion of f is meaningful. By Eq. 11, the Taylor expansion of \widetilde{z} around the point (\widehat{z_0}, \widehat{v_0}, s_0) is

\widetilde{z}(s_{0}+h) = \widehat{z_{0}}+h\frac{\mathrm{d}z}{\mathrm{d}t}+\frac{h^{2}}{2}\frac{\mathrm{d}^{2}z}{\mathrm{d}t^{2}}+O(h^{3})    (13)
             = \widehat{z_{0}}+hf(\widehat{z_{0}},s_{0})+\frac{h^{2}}{2}\Big(f_{t}(\widehat{z_{0}},s_{0})+f_{z}(\widehat{z_{0}},s_{0})f(\widehat{z_{0}},s_{0})\Big)+O(h^{3})    (14)

Next, we analyze the accuracy of the numerical approximation. For simplicity, we directly analyze Eq. 7 by performing a Taylor expansion of f.

f(\widehat{z_{0}}+\frac{h}{2}\widehat{v_{0}},s_{0}+\frac{h}{2})=f(\widehat{z_{0}},s_{0})+\frac{h}{2}f_{t}(\widehat{z_{0}},s_{0})+\frac{h\widehat{v_{0}}}{2}f_{z}(\widehat{z_{0}},s_{0})+O(h^{2})    (15)

\widehat{z_{2}}=\widehat{z_{0}}+hf(\widehat{z_{0}}+\frac{h}{2}\widehat{v_{0}},s_{0}+\frac{h}{2})    (16)

Plugging Eq. 14, Eq. 15 and Eq. 16 into the definition of L_z, we get

L_{z} = \widetilde{z}(s_{0}+h)-\widehat{z_{2}}    (17)
    = \Big[\widehat{z_{0}}+hf(\widehat{z_{0}},s_{0})+\frac{h^{2}}{2}\Big(f_{t}(\widehat{z_{0}},s_{0})+f_{z}(\widehat{z_{0}},s_{0})f(\widehat{z_{0}},s_{0})\Big)\Big]
      - \Big[\widehat{z_{0}}+h\Big(f(\widehat{z_{0}},s_{0})+\frac{h}{2}f_{t}(\widehat{z_{0}},s_{0})+\frac{h\widehat{v_{0}}}{2}f_{z}(\widehat{z_{0}},s_{0})\Big)\Big]+O(h^{3})    (18)
    = \frac{h^{2}}{2}f_{z}(\widehat{z_{0}},s_{0})\Big(f(\widehat{z_{0}},s_{0})-\widehat{v_{0}}\Big)+O(h^{3})    (19)

Therefore, if \big|f(\widehat{z_{0}},s_{0})-\widehat{v_{0}}\big| is of order O(1), L_z is of order O(h^2); if \big|f(\widehat{z_{0}},s_{0})-\widehat{v_{0}}\big| is of order O(h) or smaller, then L_z is of order O(h^3). Specifically, at the start time of integration we have \big|f(\widehat{z_{0}},s_{0})-\widehat{v_{0}}\big|=0; by induction, L_z at the end time is O(h^3).

Next we analyze the local truncation error in v, denoted as L_v. Denoting the ground truth as \widetilde{v}(s_0+h), we have

\widetilde{v}(s_{0}+h) = f\big(\widetilde{z}(s_{0}+h),s_{0}+h\big)    (20)
             = f(\widehat{z_{0}},s_{0})+hf_{t}(\widehat{z_{0}},s_{0})+\big(\widetilde{z}(s_{0}+h)-\widehat{z_{0}}\big)f_{z}(\widehat{z_{0}},s_{0})+O(h^{2})    (21)

Next we analyze the error in the numerical approximation. Plug Eq. 15 into Eq. 7,

\widehat{v_{2}} = 2f(\widehat{z_{0}}+\frac{h}{2}\widehat{v_{0}},s_{0}+\frac{h}{2})-\widehat{v_{0}}    (22)
       = f(\widehat{z_{0}},s_{0})+\big(f(\widehat{z_{0}},s_{0})-\widehat{v_{0}}\big)+hf_{t}(\widehat{z_{0}},s_{0})+h\widehat{v_{0}}f_{z}(\widehat{z_{0}},s_{0})+O(h^{2})    (23)

From Eq. 14, Eq. 21 and Eq. 23, we have

L_{v} = \widetilde{v}(s_{0}+h)-\widehat{v_{2}}    (24)
    = \Big(f(\widehat{z_{0}},s_{0})-\widehat{v_{0}}\Big)+\Big(\widetilde{z}(s_{0}+h)-\big(\widehat{z_{0}}+h\widehat{v_{0}}\big)\Big)f_{z}(\widehat{z_{0}},s_{0})+O(h^{2})    (25)
    = \Big(f(\widehat{z_{0}},s_{0})-\widehat{v_{0}}\Big)+h\Big(f(\widehat{z_{0}},s_{0})-\widehat{v_{0}}\Big)f_{z}(\widehat{z_{0}},s_{0})+O(h^{2})    (26)

The last equation is derived by plugging in Eq. 14. Note that Eq. 26 holds for every single step forward in time, and at the start time of integration we have \big|f(\widehat{z_{0}},s_{0})-\widehat{v_{0}}\big|=0 due to our initialization as in Sec. 3.1 of the main paper. Therefore, by induction, L_v is of order O(h^2) for consecutive steps. ∎

A.4 Stability analysis

Lemma A.1.1.

For a block matrix \begin{bmatrix}A&B\\ C&D\end{bmatrix}, if A, B, C, D are square matrices of the same shape and CD = DC, then  \mathrm{det}\begin{bmatrix}A&B\\ C&D\end{bmatrix}=\mathrm{det}(AD-BC).

Proof.

See (Silvester, 2000) for a detailed proof. ∎

Theorem A.2.

For the ALF integrator with stepsize h, if h\sigma_i is 0 or is imaginary with norm no larger than 1, where \sigma_i is the i-th eigenvalue of the Jacobian \frac{\partial f}{\partial z}, then the solver is on the critical boundary of A-stability; otherwise, the solver is not A-stable.

Proof.

A solver being A-stable is equivalent to the eigenvalues of the Jacobian of the numerical forward step having norm below 1. We calculate these eigenvalues for \psi below.

For the function defined by Eq. 7, the Jacobian is

J=\begin{bmatrix}\frac{\partial\widehat{z_{2}}}{\partial\widehat{z_{0}}}&\frac{\partial\widehat{z_{2}}}{\partial\widehat{v_{0}}}\\ \frac{\partial\widehat{v_{2}}}{\partial\widehat{z_{0}}}&\frac{\partial\widehat{v_{2}}}{\partial\widehat{v_{0}}}\end{bmatrix}=\begin{bmatrix}I+h\frac{\partial f}{\partial z}&\frac{h^{2}}{2}\frac{\partial f}{\partial z}\\ 2\frac{\partial f}{\partial z}&h\frac{\partial f}{\partial z}-I\end{bmatrix}    (27)

We determine the eigenvalues of J by solving the equation

\mathrm{det}(J-\lambda I)=\mathrm{det}\begin{bmatrix}h\frac{\partial f}{\partial z}+(1-\lambda)I&\frac{h^{2}}{2}\frac{\partial f}{\partial z}\\ 2\frac{\partial f}{\partial z}&h\frac{\partial f}{\partial z}-(1+\lambda)I\end{bmatrix}=0    (28)

It is trivial to check that J satisfies the conditions of Lemma A.1.1. Therefore, we have

\mathrm{det}(J-\lambda I) = \mathrm{det}\Big[\Big(h\frac{\partial f}{\partial z}+(1-\lambda)I\Big)\Big(h\frac{\partial f}{\partial z}-(1+\lambda)I\Big)-\Big(\frac{h^{2}}{2}\frac{\partial f}{\partial z}\Big)\Big(2\frac{\partial f}{\partial z}\Big)\Big]    (29)
              = \mathrm{det}\Big[-2\lambda h\frac{\partial f}{\partial z}+(\lambda^{2}-1)I\Big]    (30)

Suppose the eigendecomposition of \frac{\partial f}{\partial z} can be written as

\frac{\partial f}{\partial z}=\Lambda\,\mathrm{diag}(\sigma_{1},\sigma_{2},\dots,\sigma_{N})\,\Lambda^{-1}    (31)

Note that I=\Lambda I\Lambda^{-1}, hence we have

\mathrm{det}(J-\lambda I) = \mathrm{det}\ \Lambda\Big\{-2\lambda h\,\mathrm{diag}(\sigma_{1},\sigma_{2},\dots,\sigma_{N})+(\lambda^{2}-1)I\Big\}\Lambda^{-1}    (32)
              = \prod_{i=1}^{N}(\lambda^{2}-2h\sigma_{i}\lambda-1)    (33)

Hence the eigenvalues are

\lambda_{i\pm}=h\sigma_{i}\pm\sqrt{h^{2}\sigma_{i}^{2}+1}    (34)

A-stability requires |\lambda_{i\pm}|<1,\ \forall i, and has no solution.

The critical boundary is |\lambda_{i\pm}|=1; the solution is: h\sigma_{i} is 0 or on the imaginary axis with norm no larger than 1. ∎
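
A short numeric check of Eq. 34 with illustrative values: for real nonzero hσ, at least one of the two eigenvalues has magnitude above 1 (so the solver is not A-stable there), while for imaginary hσ with |hσ| ≤ 1 both magnitudes equal 1, matching the critical boundary described above.

```python
import numpy as np

for hs in (0.3, -0.5, 0.8j, 1.0j, 1.5j):      # sample values of h * sigma_i
    root = np.sqrt(hs * hs + 1 + 0j)          # complex square root
    lam_plus, lam_minus = hs + root, hs - root
    print(hs, abs(lam_plus), abs(lam_minus))  # |lambda| = 1 only on the imaginary segment
```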

A.5 Damped ALF

Input (\widehat{z_{in}}, \widehat{v_{in}}, s_{in}, h) = (\widehat{z_0}, \widehat{v_0}, s_0, h), where s_0 is the current time, \widehat{z_0} and \widehat{v_0} are the corresponding values at time s_0, and h is the stepsize.
Forward
        s_1 = s_0 + h/2    (35)
        \widehat{z_1} = \widehat{z_0} + \widehat{v_0} \times h/2    (36)
        \widehat{v_1} = f(\widehat{z_1}, s_1)    (37)
        \widehat{v_2} = \widehat{v_0} + 2\eta(\widehat{v_1} - \widehat{v_0})    (38)
        \widehat{z_2} = \widehat{z_1} + \widehat{v_2} \times h/2    (39)
        s_2 = s_1 + h/2    (40)
Output    (\widehat{z_{out}}, \widehat{v_{out}}, s_{out}, h) = (\widehat{z_2}, \widehat{v_2}, s_2, h)
Algorithm 6 Forward of \psi in Damped ALF ( \eta\in(0,1] )

Input (\widehat{z_{out}}, \widehat{v_{out}}, s_{out}, h), where s_{out} is the current time, \widehat{z_{out}} and \widehat{v_{out}} are the corresponding values at s_{out}, and h is the stepsize.
Inverse
        (\widehat{z_2}, \widehat{v_2}, s_2, h) = (\widehat{z_{out}}, \widehat{v_{out}}, s_{out}, h)    (41)
        s_1 = s_2 - h/2    (42)
        \widehat{z_1} = \widehat{z_2} - \widehat{v_2} \times h/2    (43)
        \widehat{v_1} = f(\widehat{z_1}, s_1)    (44)
        \widehat{v_0} = (\widehat{v_2} - 2\eta\widehat{v_1})/(1 - 2\eta)    (45)
        \widehat{z_0} = \widehat{z_1} - \widehat{v_0} \times h/2    (46)
        s_0 = s_1 - h/2    (47)
Output    (\widehat{z_{in}}, \widehat{v_{in}}, s_{in}, h) = (\widehat{z_0}, \widehat{v_0}, s_0, h)
Algorithm 7 \psi^{-1} (Inverse of \psi) in Damped ALF ( \eta\in(0,1] )

The main difference between ALF and damped ALF is the update of \widehat{v_2} in Algo. 6 (Eq. 38). In ALF, the update is \widehat{v_2} = (\widehat{v_1} - \widehat{v_0}) + \widehat{v_1} = 2(\widehat{v_1} - \widehat{v_0}) + \widehat{v_0}; in damped ALF, the update is scaled by a factor \eta between 0 and 1, so the update is \widehat{v_2} = 2\eta(\widehat{v_1} - \widehat{v_0}) + \widehat{v_0}. When \eta = 1, damped ALF reduces to ALF.

Similar to Sec. A.1, we can re-write the forward of damped ALF as

\begin{bmatrix}\widehat{z_{2}}\\ \widehat{v_{2}}\end{bmatrix}=\begin{bmatrix}\widehat{z_{0}}+\eta hf(\widehat{z_{0}}+\frac{h}{2}\widehat{v_{0}},s_{0}+\frac{h}{2})+(1-\eta)h\widehat{v_{0}}\\ 2\eta f(\widehat{z_{0}}+\frac{h}{2}\widehat{v_{0}},s_{0}+\frac{h}{2})+(1-2\eta)\widehat{v_{0}}\end{bmatrix}    (48)

Similarly, the inverse of damped ALF can be written as

\begin{bmatrix}\widehat{z_{0}}\\ \widehat{v_{0}}\end{bmatrix}=\begin{bmatrix}\widehat{z_{2}}-h\frac{1-\eta}{1-2\eta}\widehat{v_{2}}+h\frac{\eta}{1-2\eta}f(\widehat{z_{2}}-\frac{h}{2}\widehat{v_{2}},s_{2}-\frac{h}{2})\\ \frac{1}{1-2\eta}\widehat{v_{2}}-\frac{2\eta}{1-2\eta}f(\widehat{z_{2}}-\frac{h}{2}\widehat{v_{2}},s_{2}-\frac{h}{2})\end{bmatrix}    (49)
Theorem A.3.

For a single step of damped ALF with stepsize h, the local truncation error of z is O(h^2), and the local truncation error of v is O(h).

Proof.

The proof is similar to that of Thm. A.1. By similar calculations using the Taylor expansions in Eq. 15 and Eq. 14, we have

$\widehat{z_2} - \tilde{z}(s_0+h) = (1-\eta)h\widehat{v_0} + \eta h\big[f(\widehat{z_0},s_0) + \frac{h}{2}f_t(\widehat{z_0},s_0) + \frac{h\widehat{v_0}}{2}f_z(\widehat{z_0},s_0)\big] - h\big[f(\widehat{z_0},s_0) + \frac{h}{2}f_t(\widehat{z_0},s_0) + \frac{h}{2}f_z(\widehat{z_0},s_0)f(\widehat{z_0},s_0)\big] + O(h^2)$  (50)
$= (1-\eta)h\big(\widehat{v_0} - f(\widehat{z_0},s_0)\big) + \frac{\eta-1}{2}h^2 f_t(\widehat{z_0},s_0) + \frac{h^2}{2}\big(\eta\widehat{v_0} - f(\widehat{z_0},s_0)\big)f_z(\widehat{z_0},s_0) + O(h^2)$  (51)

Using Eq. 21, Eq. 15 and Eq. 14, we have

$\tilde{v_2} - \widehat{v_2} = (1-2\eta)\widehat{v_0} + (2\eta-1)f(\widehat{z_0},s_0) + (1-\eta)h f_t(\widehat{z_0},s_0) + \big(\tilde{z}(s_0+h) - \widehat{z_0} - \eta h\widehat{v_0}\big)f_z(\widehat{z_0},s_0) + O(h^2)$  (52)
$= (2\eta-1)\big[f(\widehat{z_0},s_0) - \widehat{v_0}\big] + (1-\eta)h f_t(\widehat{z_0},s_0) + \eta\big[h f(\widehat{z_0},s_0) - h\widehat{v_0}\big]f_z(\widehat{z_0},s_0) + O(h^2)$  (53)

Note that when $\eta = 1$, Eq. 51 reduces to Eq. 19, and Eq. 53 reduces to Eq. 26. By initialization, we have $|f(\widehat{z_0},s_0) - \widehat{v_0}| = 0$ at the initial time; hence, by induction, the local truncation error for $z$ is $O(h^2)$, and the local truncation error for $v$ is $O(h)$ when $\eta < 1$ and $O(h^2)$ when $\eta = 1$. ∎
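A simple numerical check of the claimed order is sketched below on the toy problem $dz/ds = \cos(s)$ (an arbitrary choice with a known exact solution); the measured single-step error of $z$ should scale roughly as $h^2$.

import numpy as np

f = lambda z, s: np.cos(s)          # toy ODE dz/ds = cos(s), independent of z
z0, s0, eta = 0.0, 0.3, 0.8

errors, steps = [], [0.1, 0.05, 0.025, 0.0125]
for h in steps:
    v0 = f(z0, s0)                                  # initialize v0 = f(z0, s0)
    v1 = f(z0 + h / 2 * v0, s0 + h / 2)
    z2 = z0 + eta * h * v1 + (1 - eta) * h * v0     # z-component of Eq. 48
    z_true = z0 + np.sin(s0 + h) - np.sin(s0)       # exact solution through (z0, s0)
    errors.append(abs(z2 - z_true))

orders = np.log2(np.array(errors[:-1]) / np.array(errors[1:]))
print(orders)   # the estimated orders should approach 2, i.e. local error O(h^2)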

Theorem A.4 (Theorem 3.2 in the main paper).

For the Damped ALF integrator with stepsize $h$, let $\sigma_i$ be the $i$-th eigenvalue of the Jacobian $\frac{\partial f}{\partial z}$; then the solver is A-stable if $\big|1 + \eta(h\sigma_i - 1) \pm \sqrt{\eta[2h\sigma_i + \eta(h\sigma_i - 1)^2]}\big| < 1,\ \forall i$.

Proof.

The Jacobian of the forward pass of a single step of damped ALF is

$J = \begin{bmatrix} I + \eta h\frac{\partial f}{\partial z} & (1-\eta)hI + \eta\frac{h^2}{2}\frac{\partial f}{\partial z} \\ 2\eta\frac{\partial f}{\partial z} & \eta h\frac{\partial f}{\partial z} + (1-2\eta)I \end{bmatrix}$  (54)

When $\eta = 1$, $J$ reduces to Eq. 27. We can determine the eigenvalues of $J$ using similar techniques. Assume the eigenvalues of $\frac{\partial f}{\partial z}$ are $\{\sigma_i\}$; then we have

$\det(J - \lambda I) = \det\begin{bmatrix} (1-\lambda)I + \eta h\frac{\partial f}{\partial z} & (1-\eta)hI + \eta\frac{h^2}{2}\frac{\partial f}{\partial z} \\ 2\eta\frac{\partial f}{\partial z} & \eta h\frac{\partial f}{\partial z} + (1-2\eta-\lambda)I \end{bmatrix}$  (55)
$= \det\Big[\Big((1-\lambda)I + \eta h\frac{\partial f}{\partial z}\Big)\Big(\eta h\frac{\partial f}{\partial z} + (1-2\eta-\lambda)I\Big) - \Big((1-\eta)hI + \eta\frac{h^2}{2}\frac{\partial f}{\partial z}\Big)\Big(2\eta\frac{\partial f}{\partial z}\Big)\Big]$  (56)
Setting $\det(J - \lambda I) = 0$ and solving the resulting quadratic for each $\sigma_i$, the eigenvalues of $J$ are $\lambda_i^{\pm} = 1 + \eta(h\sigma_i - 1) \pm \sqrt{\eta[2h\sigma_i + \eta(h\sigma_i - 1)^2]}$  (57)

When $\eta < 1$, it is easy to check that $\big|1 + \eta(h\sigma_i - 1) \pm \sqrt{\eta[2h\sigma_i + \eta(h\sigma_i - 1)^2]}\big| < 1$ has a non-empty solution set for $h\sigma_i$. ∎

For a quick validation, we plot the region of A-stability for a single eigenvalue in the complex plane in Fig. 1. As $\eta$ increases, the area of the stability region decreases. When $\eta = 1$, the system is nowhere A-stable, and the boundary of the A-stability region is the segment $[-i, i]$ on the imaginary axis, where $i$ is the imaginary unit.

Figure 1: Region of A-stability for a single eigenvalue in the complex plane for damped ALF. From left to right: the stability region for $\eta = 0.25$, $\eta = 0.7$, and $\eta = 0.8$, respectively. As $\eta$ increases to 1, the area of the stability region decreases.
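A figure like Fig. 1 can be reproduced with a short script that evaluates the eigenvalue condition of Thm. A.4 on a grid of complex $h\sigma$ values; the grid extent and the choice $\eta = 0.8$ below are arbitrary illustrative settings.

import numpy as np
import matplotlib.pyplot as plt

eta = 0.8
re, im = np.meshgrid(np.linspace(-3, 1, 400), np.linspace(-3, 3, 600))
hs = re + 1j * im                                        # complex h * sigma
disc = np.sqrt(eta * (2 * hs + eta * (hs - 1) ** 2))     # discriminant term in Thm. A.4
lam_plus = 1 + eta * (hs - 1) + disc
lam_minus = 1 + eta * (hs - 1) - disc
stable = (np.abs(lam_plus) < 1) & (np.abs(lam_minus) < 1)

plt.contourf(re, im, stable.astype(float), levels=[0.5, 1.5])   # shade the stable region
plt.xlabel('Re(h sigma)')
plt.ylabel('Im(h sigma)')
plt.title('A-stability region of damped ALF, eta = 0.8')
plt.show()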

Appendix B Experimental Details

B.1 Image Recognition

B.1.1 Experiment on Cifar10

We directly modify a ResNet18 into a Neural ODE, where the forward of a residual block ($y = x + f(x)$) and the forward of an ODE block ($y = x + \int_0^T f(z,t)\,dt$ with $T = 1$) share the same parameterization $f$; hence they have the same number of parameters. Our experiment is based on the official implementation by Zhuang et al. (2020) and an open-source repository (Liu, 2017).
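The construction can be sketched in PyTorch as follows; the block below is a simplified stand-in for the actual network (the real $f$ follows the ResNet18 basic block and may also take the time $t$ as input), and the generic odeint call from torchdiffeq stands in for whichever solver is used.

import torch
import torch.nn as nn
from torchdiffeq import odeint    # any solver package with a similar interface works

class ConvBlock(nn.Module):
    # Shared parameterization f: two 3x3 convolutions (simplified basic block).
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, t, x):      # t is ignored here, i.e. an autonomous f
        return self.conv2(torch.relu(self.conv1(x)))

class ResidualBlock(nn.Module):
    # y = x + f(x): a one-step Euler discretization with unit stepsize.
    def __init__(self, f):
        super().__init__()
        self.f = f

    def forward(self, x):
        return x + self.f(None, x)

class ODEBlock(nn.Module):
    # y = x + integral_0^1 f(z, t) dt, sharing exactly the same parameters f.
    def __init__(self, f):
        super().__init__()
        self.f = f
        self.register_buffer('t', torch.tensor([0.0, 1.0]))

    def forward(self, x):
        return odeint(self.f, x, self.t)[-1]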

All models are trained with the SGD optimizer for 90 epochs, with an initial learning rate of 0.01 decayed by a factor of 10 at the 30th and 60th epochs. The training scheme is the same for all models (ResNet, and Neural ODE trained with the adjoint, naive, ACA and MALI methods). For ACA, we follow the settings in (Zhuang et al., 2020) and use the official torch_ACA implementation (https://github.com/juntang-zhuang/torch_ACA), with a Heun-Euler solver with $rtol = 10^{-1}, atol = 10^{-2}$ during training. For MALI, we use an adaptive version and set $rtol = 10^{-1}, atol = 10^{-2}$. For the naive and adjoint methods, we use the default Dopri5 solver from the torchdiffeq package (https://github.com/rtqichen/torchdiffeq) with $rtol = atol = 10^{-5}$. We train all models for 5 independent runs, and report the mean and standard deviation across runs.
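The optimization schedule above is a standard milestone decay; a minimal sketch (with a placeholder model, and momentum / weight-decay settings left to the referenced implementations) is:

import torch

model = torch.nn.Linear(10, 10)   # placeholder for the ResNet / Neural ODE model
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[30, 60], gamma=0.1)

for epoch in range(90):
    # ... one epoch of training on Cifar10 ...
    scheduler.step()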

B.1.2 Experiments on ImageNet

Training scheme

We conduct experiments on ImageNet with ResNet18 and Neural-ODE18. All models are trained on 4 GTX-1080Ti GPUs with a batch size of 256, for 80 epochs, with an initial learning rate of 0.1 decayed by a factor of 10 at the 30th and 60th epochs. Note that due to the large $256 \times 256$ input size, the naive method and ACA require a huge amount of memory and are infeasible to train. MALI and the adjoint method require constant memory and hence are suitable for large-scale experiments. For both MALI and the adjoint method, we use a fixed stepsize of 0.25 and integrate from 0 to $T = 1$. As shown in Table. 2 in the main paper, a stepsize of 0.25 is sufficiently small to train a meaningful continuous model that is robust to the discretization scheme.

Figure 2: Results on ImageNet. (a) Training curve. (b) Validation curve.
Invariance to discretization scheme

To test the influence of the discretization scheme, we evaluate our Neural ODE with different solvers without re-training. For fixed-stepsize solvers, we test step sizes in $\{0.1, 0.15, 0.25, 0.5, 1.0\}$; for adaptive solvers, we set $rtol = 0.1, atol = 0.01$ for MALI and the Heun-Euler method, $rtol = 10^{-2}, atol = 10^{-3}$ for the RK23 solver, and $rtol = 10^{-4}, atol = 10^{-5}$ for the Dopri5 solver. As shown in Table. 2, the Neural ODE trained with MALI is robust to the discretization scheme, and MALI significantly outperforms the adjoint method in accuracy (70% vs. 63% top-1 accuracy on the validation set). An interesting finding is that when trained with MALI, a second-order solver, and tested with a higher-order solver (e.g. RK4), our Neural ODE achieves 70.21% top-1 accuracy, higher than both the accuracy with the training solver (MALI, 69.59%) and that of ResNet18 (70.09%).
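Swapping the solver at test time only changes the integration call; a sketch using torchdiffeq's generic interface is below. The trained derivative network is replaced by a toy function here, and the ALF/MALI solvers themselves live in the TorchDiffEqPack package referenced earlier, whose exact API is not reproduced in this sketch.

import torch
from torchdiffeq import odeint

f = lambda t, z: -z               # placeholder for the trained derivative network
x = torch.randn(4, 8)             # placeholder input features
t = torch.tensor([0.0, 1.0])

# Fixed-stepsize solvers with different step sizes, no re-training needed.
y_euler = odeint(f, x, t, method='euler', options={'step_size': 0.25})[-1]
y_rk4   = odeint(f, x, t, method='rk4',   options={'step_size': 0.25})[-1]

# Adaptive solver with the Dopri5 tolerances listed above.
y_dopri5 = odeint(f, x, t, method='dopri5', rtol=1e-4, atol=1e-5)[-1]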

Furthermore, many papers view a ResNet as an approximation of an ODE (Lu et al., 2018). However, Queiruga et al. (2020) argue that many numerical discretizations fail to be meaningful dynamical systems, while our experiments demonstrate that our model is continuous and hence invariant to the discretization scheme.

Adversarial robustness

Besides the high accuracy and robustness to the discretization scheme, another advantage of Neural ODEs is robustness to adversarial attack. The adversarial robustness of Neural ODEs is extensively studied in (Hanshu et al., 2019), but only validated on small-scale datasets such as Cifar10. To our knowledge, our method is the first to enable effective training of a Neural ODE on large-scale datasets such as ImageNet with high accuracy, and we are the first to validate the robustness of Neural ODEs on ImageNet. We use the advertorch toolbox (https://github.com/BorealisAI/advertorch) to perform adversarial attacks, and test the performance of ResNet and the Neural ODE under the FGSM attack. To be more convincing, we conduct the experiment with the pretrained ResNet18 provided by the official PyTorch website (https://pytorch.org/docs/stable/torchvision/models.html). Since the Neural ODE is invariant to the discretization scheme, it is possible to derive the gradient for the attack using one ODE solver, and run inference on the perturbed image using another solver. As summarized in Table. 3, the Neural ODE consistently achieves a higher accuracy than ResNet under the same attack.
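A minimal sketch of this cross-solver evaluation with advertorch's FGSM implementation is given below; model_attack and model_eval are assumed to share weights but use different internal ODE solvers, and the perturbation budget eps is an arbitrary illustrative value.

import torch
from advertorch.attacks import GradientSignAttack

def fgsm_cross_solver(model_attack, model_eval, images, labels, eps=2 / 255):
    # Attack gradients are computed through one discretization of the Neural ODE ...
    adversary = GradientSignAttack(model_attack, eps=eps, clip_min=0.0, clip_max=1.0)
    adv_images = adversary.perturb(images, labels)
    # ... and the perturbed images are classified with a different solver.
    with torch.no_grad():
        preds = model_eval(adv_images).argmax(dim=1)
    return (preds == labels).float().mean()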

B.2 Time series modeling

We conduct experiments on Latent-ODE models (Rubanova et al., 2019) and Neural CDE (controlled differential equation) models (Kidger et al., 2020a). For all experiments, we use the official implementation and only replace the solver with MALI. The Latent-ODE model is trained on the Mujoco dataset processed with the code provided by the official implementation, and we experiment with different ratios (10%, 20%, 50%) of training data as described in (Rubanova et al., 2019). All Latent-ODE models are trained for 300 epochs with the Adamax optimizer, with an initial learning rate of 0.01 decayed by a factor of 0.999 each epoch. For the Neural CDE model, we use MALI with the ALF solver with a fixed stepsize of 0.25, and train for 100 epochs with an initial learning rate of 0.004; we perform 5 independent runs for the naive method, ACA and MALI and report the mean and standard deviation, while results for the adjoint and seminorm adjoint are from (Kidger et al., 2020a).

B.3 Continuous generative models

B.3.1 Training details

Our experiment is based on the official implementation of (Finlay et al., 2020), with the only difference being the ODE solver. For a fair comparison, we use MALI only during training, and evaluate with the Dopri5 solver from the torchdiffeq package (Chen et al., 2018) with $rtol = atol = 10^{-5}$. For MALI, we use the adaptive ALF solver with $rtol = 10^{-2}, atol = 10^{-3}$ and an initial stepsize of 0.25. The integration time is from 0 to 1.

On the MNIST and Cifar10 datasets, we set the regularization coefficients for the kinetic energy and the Frobenius norm of the derivative function to 0.05. We train the model for 50 epochs with an initial learning rate of 0.001.
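The resulting objective is the FFJORD negative log-likelihood plus the two regularizers of Finlay et al. (2020); a schematic sketch of how the terms are combined (the regularizer values themselves are accumulated during integration by the official implementation) is:

# Schematic combination of the loss terms; both coefficients are 0.05 as stated above.
LAMBDA_KINETIC = 0.05      # weight on the kinetic energy   ~ integral of ||f(z, t)||^2 dt
LAMBDA_JACOBIAN = 0.05     # weight on the Frobenius norm   ~ integral of ||df/dz||_F^2 dt

def regularized_loss(neg_log_likelihood, kinetic_energy, jacobian_frobenius):
    return neg_log_likelihood + LAMBDA_KINETIC * kinetic_energy + LAMBDA_JACOBIAN * jacobian_frobenius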

B.3.2 Additional results

We show generated examples for the MNIST dataset in Fig. 3, for the Cifar10 dataset in Fig. 4, and for the ImageNet64 dataset in Fig. 5.

Figure 3: Results on the MNIST dataset. (a) Real samples from MNIST. (b) Generated samples from FFJORD.
Figure 4: Results on the Cifar10 dataset. (a) Real samples from Cifar10. (b) Generated samples from FFJORD.
Figure 5: Results on the ImageNet64 dataset. (a) Real samples from ImageNet64. (b) Generated samples from FFJORD.

B.4 Error in gradient estimation for toy examples when $t < 1$

We plot the error in gradient estimation for the toy example defined by Eq. 6 in the main paper in Fig. 6. Note that here the integration time $T$ is set to be smaller than 1, while in the main paper it is larger than 20. We observe the same trend: MALI and ACA generate smaller errors than the adjoint and naive methods.

Figure 6: Comparison of error in gradient estimation for the toy example defined by Eq. 6 of the main paper, when $t < 1$. (a) Error in the estimated gradient w.r.t. the initial condition. (b) Error in the estimated gradient w.r.t. the parameter $\alpha$.

B.5 Results of damped MALI

For all experiments in the main paper, we set $\eta = 1$ and did not use damping. For completeness, we experimented with damped MALI using different values of $\eta$. As shown in Table. 7, MALI is robust to different $\eta$ values.

Table 7: Results of damped MALI with different $\eta$ values. We report the test accuracy of Neural CDE on the Speech Commands dataset, and the test MSE of Latent-ODE on Mujoco data.

$\eta$                                                     1.0        0.95       0.9        0.85
Test accuracy on Speech Commands (higher is better)        93.7±0.3   93.7±0.1   93.5±0.2   93.7±0.3
Test MSE of Latent-ODE on Mujoco (lower is better)
  10% training data                                        0.35       0.36       0.33       0.33
  20% training data                                        0.27       0.25       0.26       0.27