
Symplectic Adjoint Method for Exact Gradient of Neural ODE with Minimal Memory

Takashi Matsubara
Osaka University
Osaka, Japan 560–8531
[email protected]
Yuto Miyatake
Osaka University
Osaka, Japan 560–0043
[email protected]
Takaharu Yaguchi
Kobe University
Kobe, Japan 657–8501
[email protected]
Abstract

A neural network model of a differential equation, namely the neural ODE, has enabled the learning of continuous-time dynamical systems and probabilistic distributions with high accuracy. The neural ODE uses the same network repeatedly during a numerical integration. The memory consumption of the backpropagation algorithm is proportional to the number of uses times the network size. This is true even if a checkpointing scheme divides the computation graph into sub-graphs. In contrast, the adjoint method obtains a gradient by a numerical integration backward in time. Although this method consumes memory only for a single network use, it requires a high computational cost to suppress numerical errors. This study proposes the symplectic adjoint method, which is an adjoint method solved by a symplectic integrator. The symplectic adjoint method obtains the exact gradient (up to rounding error) with memory proportional to the number of uses plus the network size. The experimental results demonstrate that the symplectic adjoint method consumes much less memory than the naive backpropagation algorithm and checkpointing schemes, performs faster than the adjoint method, and is more robust to rounding errors.

1 Introduction

Deep neural networks offer remarkable methods for various tasks, such as image recognition [18] and natural language processing [4]. These methods employ a residual architecture [21, 34], in which the output x_{n+1} of the n-th operation is defined as the sum of a subroutine f_n and the input x_n as x_{n+1} = f_n(x_n) + x_n. The residual architecture can be regarded as a numerical integration applied to an ordinary differential equation (ODE) [30]. Accordingly, a neural network model of the differential equation \mathrm{d}x/\mathrm{d}t = f(x), namely the neural ODE, was proposed in [2]. Given an initial condition x(0) = x as an input, the neural ODE solves an initial value problem by numerical integration, obtaining the final value as an output y = x(T). The neural ODE can model continuous-time dynamics such as irregularly sampled time series [24], stable dynamical systems [37, 42], and physical phenomena associated with geometric structures [3, 13, 31]. Further, because the neural ODE approximates a diffeomorphism [43], it can model probabilistic distributions of real-world data by a change of variables [12, 23, 25, 45].
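As a minimal illustration of this formulation (a sketch under assumed settings, not the implementation used in this paper), the following PyTorch snippet integrates a small network f with the forward Euler method; the class name ODEFunc, the network width, and the step count are arbitrary choices.

```python
import torch
import torch.nn as nn

class ODEFunc(nn.Module):
    """A small network f(x, t) defining dx/dt = f(x, t)."""
    def __init__(self, dim=2, hidden=32):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, hidden), nn.Tanh(), nn.Linear(hidden, dim))

    def forward(self, x, t):
        return self.net(x)

def odeint_euler(f, x0, t0=0.0, t1=1.0, n_steps=10):
    """Forward Euler integration: x_{n+1} = x_n + h * f(x_n, t_n)."""
    h = (t1 - t0) / n_steps
    x, t = x0, t0
    for _ in range(n_steps):
        x = x + h * f(x, t)
        t = t + h
    return x  # output y = x(T)

f = ODEFunc()
x0 = torch.randn(8, 2)      # a batch of initial conditions x(0)
y = odeint_euler(f, x0)     # the same network f is used at every step
loss = y.pow(2).sum()
loss.backward()             # naive backpropagation retains all n_steps graphs
```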

For an accurate integration, the neural ODE must employ a small step size and a high-order numerical integrator composed of many internal stages. A neural network f is used at each stage of each time step. Thus, the backpropagation algorithm consumes exorbitant memory to retain the whole computation graph [39, 2, 10, 46]. The neural ODE employs the adjoint method to reduce memory consumption; this method obtains a gradient by a backward integration along with the state x, without consuming memory for retaining the computation graph over time [8, 17, 41, 44]. However, this method incurs high computational costs to suppress numerical errors. Several previous works employed a checkpointing scheme [10, 46, 47]. This scheme only sparsely retains the state x as checkpoints and recalculates a computation graph from each checkpoint to obtain the gradient. However, this scheme still consumes a significant amount of memory to retain the computation graph between checkpoints.

To address the above limitations, this study proposes the symplectic adjoint method. The main advantages of the proposed method are presented as follows.

Exact Gradient and Fast Computation: In discrete time, the adjoint method suffers from numerical errors or needs a smaller step size. The proposed method uses a specially designed integrator that obtains the exact gradient in discrete time. It works with the same step size as the forward integration and is thus faster than the adjoint method in practice.

Minimal Memory Consumption: Except for the adjoint method, existing methods apply the backpropagation algorithm to the computation graph of the whole numerical integration or of a subset of it [10, 46, 47]. Their memory consumption is proportional to the number of steps/stages in the graph times the neural network size. In contrast, the proposed method applies the algorithm only to each use of the neural network, and thus the memory consumption is proportional only to the number of steps/stages plus the network size.

Robust to Rounding Error: The backpropagation algorithm accumulates the gradient from each use of the neural network and tends to suffer from rounding errors. Conversely, the proposed method obtains the gradient from each step as a numerical integration and is thus more robust to rounding errors.

Table 1: Comparison of the proposed method with existing methods

| Methods | Gradient Calculation | Exact | Checkpoints | Memory (checkpoints) | Memory (backprop.) | Computational Cost |
|---|---|---|---|---|---|---|
| NODE [2] | adjoint method | no | x_N | M | L | M(N+2Ñ)sL |
| NODE [2] | backpropagation | yes | – | – | MNsL | 2MNsL |
| baseline scheme | backpropagation | yes | x_0 | M | NsL | 3MNsL |
| ACA [46] | backpropagation | yes | {x_n}_{n=0}^{N-1} | MN | sL | 3MNsL |
| MALI [47]∗ | backpropagation | yes | x_N | M | sL | 4MNsL |
| proposed∗∗ | symplectic adjoint method | yes | {x_n}_{n=0}^{N-1}, {X_{n,i}}_{i=1}^{s} | MN+s | L | 4MNsL |

∗Available only for the asynchronous leapfrog integrator. ∗∗Available for any Runge–Kutta methods.

2 Background and Related Work

2.1 Neural Ordinary Differential Equation and Adjoint Method

We use the following notation.

  • M: the number of stacked neural ODE components,

  • L: the number of layers in a neural network,

  • N, Ñ: the number of time steps in the forward and backward integrations, respectively, and

  • s: the number of uses of a neural network f per step.

s is typically equal to the number of internal stages of a numerical integrator [17]. A numerical integration forward in time requires a computational cost of O(MNsL). It also provides a computation graph over time steps, which is retained with a memory of O(MNsL); the backpropagation algorithm is then applied to obtain the gradient. The total computational cost is O(2MNsL), where we suppose the computational cost of the backpropagation algorithm is equal to that of the forward propagation. The memory consumption and computational cost are summarized in Table 1.

To reduce the memory consumption, the original study on the neural ODE introduced the adjoint method [2, 8, 17, 41, 44]. This method integrates the pair of the system state x and the adjoint variable λ backward in time. The adjoint variable λ represents the gradient ∂ℒ/∂x of some function ℒ, and the backward integration of the adjoint variable λ works as the backpropagation (or, more generally, the reverse-mode automatic differentiation) in continuous time. The memory consumption is O(M) to retain the final values x(T) of the M neural ODE components and O(L) to obtain the gradient of a neural network f for integrating the adjoint variable λ. The computational cost is at least doubled because of the re-integration of the system state x backward in time. The adjoint method suffers from numerical errors [41, 10]. To suppress the numerical errors, the backward integration often requires a smaller step size than the forward integration (i.e., Ñ > N), leading to an increase in computation time. In contrast, the proposed symplectic adjoint method uses a specially designed integrator, which provides the exact gradient with the same step size as the forward integration.
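For concreteness, the following sketch contrasts the two modes using the torchdiffeq package employed in Section 5: odeint retains the whole computation graph, whereas odeint_adjoint keeps only the final state and integrates the adjoint system backward. The network and tolerance values are placeholders, and keyword defaults may differ across package versions.

```python
import torch
import torch.nn as nn
from torchdiffeq import odeint, odeint_adjoint

class ODEFunc(nn.Module):
    def __init__(self, dim=2, hidden=32):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, hidden), nn.Tanh(), nn.Linear(hidden, dim))

    def forward(self, t, x):  # torchdiffeq passes (t, x)
        return self.net(x)

func, x0 = ODEFunc(), torch.randn(8, 2)
t = torch.tensor([0.0, 1.0])

# Naive backpropagation: the graph over all steps/stages is retained, O(MNsL) memory.
y = odeint(func, x0, t, rtol=1e-6, atol=1e-8, method='dopri5')[-1]

# Adjoint method: only x(T) is kept; the gradient comes from a backward integration
# of the pair (x, lambda), with small memory but possible numerical error.
y_adj = odeint_adjoint(func, x0, t, rtol=1e-6, atol=1e-8, method='dopri5')[-1]
y_adj.pow(2).sum().backward()
```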

2.2 Checkpointing Scheme

The checkpointing scheme has been investigated to reduce the memory consumption of neural networks [14, 15], where intermediate states are retained sparsely as checkpoints, and a computation graph is recomputed from each checkpoint. For example, Gruslys et al. applied this scheme to recurrent neural networks [15]. When the initial value x(0) of each neural ODE component is retained as a checkpoint, the initial value problem is solved again before applying the backpropagation algorithm to obtain the gradient of the component. Then, the memory consumption is O(M) for checkpoints and O(NsL) for the backpropagation, i.e., O(M+NsL) in total (see the baseline scheme). The ANODE scheme retains each step {x_n}_{n=0}^{N-1} as a checkpoint with a memory of O(MN) [10]. From each checkpoint x_n, this scheme recalculates the next step x_{n+1} and obtains the gradient using the backpropagation algorithm with a memory of O(sL); the memory consumption is O(MN+sL) in total. The ACA scheme improves on the ANODE scheme for integrators with adaptive time-stepping by discarding the computation graph used to find an optimal step size [46]. Even with checkpoints, the memory consumption is still proportional to the number of uses s of a neural network f per step, which is not negligible for a high-order integrator, e.g., s = 6 for the Dormand–Prince method [7]. In this context, the proposed method can be regarded as a checkpointing scheme inside a numerical integrator. Note that previous studies did not use the notation s.

Instead of a checkpointing scheme, MALI employs the asynchronous leapfrog (ALF) integrator after the state x is paired up with a velocity state v [47]. The ALF integrator is time-reversible, i.e., the backward integration recovers the state x of the forward integration without checkpoints [17]. However, the ALF integrator is a second-order integrator, implying that it requires a small step size and a high computational cost to suppress numerical errors. Higher-order Runge–Kutta methods cannot be used in place of the ALF integrator because they are implicit or non-time-reversible. Moreover, the ALF integrator is inapplicable to physical systems without velocity, such as partial differential equation (PDE) systems. A similar approach named RevNet was proposed earlier in [11]; when ResNet is regarded as the forward Euler method [2, 18], RevNet has an architecture that can be regarded as the leapfrog integrator, and it recalculates the intermediate activations in the reverse direction.

3 Adjoint Method

Consider a system

\frac{\mathrm{d}}{\mathrm{d}t}x=f(x,t,\theta), (1)

where x, t, and θ, respectively, denote the system state, an independent variable (e.g., time), and the parameters of the function f. Given an initial condition x(0) = x_0, the solution x(t) is given by

x(t)=x_{0}+\int_{0}^{t}f(x(\tau),\tau,\theta)\,{\mathrm{d}\tau}. (2)

The solution x(t) is evaluated at the terminal time t = T by a function ℒ as ℒ(x(T)). Our main interest is in obtaining the gradients of ℒ(x(T)) with respect to the initial condition x_0 and the parameters θ.

Now, we introduce the adjoint method [2, 8, 17, 41, 44]. We first focus on the initial condition x_0 and omit the parameters θ. The adjoint method is based on the variational variable δ(t) and the adjoint variable λ(t). The variational and adjoint variables respectively follow the variational system and the adjoint system:

\frac{\mathrm{d}}{\mathrm{d}t}\delta(t)=\frac{\partial f}{\partial x}(x(t),t)\delta(t)\mbox{ for }\delta(0)=I,\qquad\frac{\mathrm{d}}{\mathrm{d}t}\lambda(t)=-\frac{\partial f}{\partial x}(x,t)^{\top}\lambda(t)\mbox{ for }\lambda(T)=\lambda_{T}. (3)

The variational variable δ(t) represents the Jacobian ∂x(t)/∂x_0 of the state x(t) with respect to the initial condition x_0; the detailed derivation is summarized in Appendix A.

Remark 1.

The quantity λ^⊤δ is time-invariant, i.e., λ(t)^⊤δ(t) = λ(0)^⊤δ(0).

The proofs of most Remarks and Theorems in this paper are summarized in Appendix B.

Remark 2.

The adjoint variable λ(t) represents the gradient (∂ℒ(x(T))/∂x(t))^⊤ if the final condition λ_T of the adjoint variable λ is set to (∂ℒ(x(T))/∂x(T))^⊤.

This is because of the chain rule. Thus, the backward integration of the adjoint variable λ(t) works as reverse-mode automatic differentiation. The adjoint method has been used for data assimilation, where the initial condition x_0 is optimized by a gradient-based method. For system identification (i.e., parameter adjustment), one can consider the parameters θ as a part of the augmented state x̃ = [x θ]^⊤ of the system

\frac{\mathrm{d}}{\mathrm{d}t}\tilde{x}=\tilde{f}(\tilde{x},t),\quad\tilde{f}(\tilde{x},t)=\begin{bmatrix}f(x,t,\theta)\\ 0\end{bmatrix},\quad\tilde{x}(0)=\begin{bmatrix}x_{0}\\ \theta\end{bmatrix}. (4)

The variational and adjoint variables are augmented in the same way. Hereafter, we let x denote the state or the augmented state without loss of generality. See Appendix C for details.

According to the original implementation of the neural ODE [2], the final value x(T) of the system state x is retained after the forward integration, and the pair of the system state x and the adjoint variable λ is integrated backward in time to obtain the gradients. The right-hand sides of the main system in Eq. (1) and the adjoint system in Eq. (3) are obtained by the forward and backward propagations of the neural network f, respectively. Therefore, the computational cost of the adjoint method is twice that of the ordinary backpropagation algorithm.

After a numerical integrator discretizes the time, Remark 1 no longer holds, and thus the adjoint variable λ(t) is not equal to the exact gradient [10, 41]. Moreover, in general, the numerical integration backward in time is not consistent with that forward in time. Although a small step size (i.e., a small tolerance) suppresses numerical errors, it also leads to a longer computation time. These facts motivate the present study, which obtains the exact gradient with a small memory footprint.

4 Symplectic Adjoint Method

4.1 Runge–Kutta Method

We first discretize the main system in Eq. (1). Let t_n, h_n, and x_n denote the n-th time step, step size, and state, respectively, where h_n = t_{n+1} - t_n. Previous studies employed one of the Runge–Kutta methods, generally expressed as

\begin{split}x_{n+1}&=x_{n}+h_{n}\sum_{i=1}^{s}b_{i}k_{n,i},\\ k_{n,i}&\vcentcolon=f(X_{n,i},t_{n}+c_{i}h_{n}),\\ X_{n,i}&\vcentcolon=x_{n}+h_{n}\sum_{j=1}^{s}a_{i,j}k_{n,j}.\end{split} (5)

The coefficients a_{i,j}, b_i, and c_i are summarized as the Butcher tableau [16, 17, 41]. If a_{i,j} = 0 for j ≥ i, the intermediate state X_{n,i} is calculable sequentially from i = 1 to i = s; then, the Runge–Kutta method is considered explicit. Runge–Kutta methods are not time-reversible in general, i.e., the numerical integration backward in time is not consistent with that forward in time.
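As a reference implementation of Eq. (5) (a sketch only, not the solver used in the experiments), the following Python function performs one explicit Runge–Kutta step from a Butcher tableau (a, b, c); the tableau shown is the classical fourth-order method.

```python
import numpy as np

def explicit_rk_step(f, x, t, h, a, b, c):
    """One step of Eq. (5): k_i = f(X_i, t + c_i h), X_i = x + h sum_j a_ij k_j,
    x_next = x + h sum_i b_i k_i, with a_ij = 0 for j >= i (explicit)."""
    s = len(b)
    k = [None] * s
    for i in range(s):
        X_i = x + h * sum(a[i][j] * k[j] for j in range(i))
        k[i] = f(X_i, t + c[i] * h)
    return x + h * sum(b[i] * k[i] for i in range(s))

# Butcher tableau of the classical fourth-order Runge-Kutta method (s = 4).
a = [[0, 0, 0, 0], [0.5, 0, 0, 0], [0, 0.5, 0, 0], [0, 0, 1, 0]]
b = [1/6, 1/3, 1/3, 1/6]
c = [0, 0.5, 0.5, 1]

f = lambda x, t: -x                                      # toy ODE dx/dt = -x
x1 = explicit_rk_step(f, np.array([1.0]), 0.0, 0.1, a, b, c)
print(x1)                                                # close to exp(-0.1)
```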

Remark 3 (Bochev and Scovel, [1], Hairer et al., [16]).

When the system in Eq. (1) is discretized by the Runge–Kutta method in Eq. (5), the variational system in Eq. (3) is discretized by the same Runge–Kutta method.

Therefore, it is not necessary to solve the variational variable δ(t) separately.

4.2 Symplectic Runge–Kutta Method for Adjoint System

We assume b_i ≠ 0 for i = 1, …, s. We suppose the adjoint system to be solved by another Runge–Kutta method with the same step size as that used for the system state x, expressed as

\begin{split}\lambda_{n+1}&=\lambda_{n}+h_{n}\sum_{i=1}^{s}B_{i}l_{n,i},\\ l_{n,i}&\vcentcolon=-\frac{\partial f}{\partial x}(X_{n,i},t_{n}+C_{i}h_{n})^{\top}\Lambda_{n,i},\\ \Lambda_{n,i}&\vcentcolon=\lambda_{n}+h_{n}\sum_{j=1}^{s}A_{i,j}l_{n,j}.\end{split} (6)

The final condition λ_N is set to (∂ℒ(x_N)/∂x_N)^⊤. Because the time evolutions of the variational variable δ and the adjoint variable λ are expressible by two equations, the combined system is considered as a partitioned system. A combination of two Runge–Kutta methods for solving a partitioned system is called a partitioned Runge–Kutta method, where C_i = c_i for i = 1, …, s. We introduce the following condition for a partitioned Runge–Kutta method.

Condition 1.

b_i A_{i,j} + B_j a_{j,i} - b_i B_j = 0 for i, j = 1, …, s, and B_i = b_i ≠ 0 and C_i = c_i for i = 1, …, s.

Theorem 1 (Sanz-Serna, [41]).

The partitioned Runge–Kutta method in Eqs. (5) and (6) conserves a bilinear quantity S(δ, λ) if the continuous-time system conserves the quantity S(δ, λ) and Condition 1 holds.

Because the bilinear quantity S (including λ^⊤δ) is conserved, the adjoint system solved by the Runge–Kutta method in Eq. (6) under Condition 1 provides the exact gradient as the adjoint variable λ_n = (∂ℒ(x_N)/∂x_n)^⊤. The Dormand–Prince method, one of the most popular Runge–Kutta methods, has b_2 = 0 [7]. For such methods, the Runge–Kutta method in Eq. (6) under Condition 1 is generalized as

\begin{split}\lambda_{n+1}&=\lambda_{n}+h_{n}\sum_{i=1}^{s}\tilde{b}_{i}l_{n,i},\\ l_{n,i}&\vcentcolon=-\frac{\partial f}{\partial x}(X_{n,i},t_{n}+c_{i}h_{n})^{\top}\Lambda_{n,i},\\ \Lambda_{n,i}&\vcentcolon=\begin{cases}\lambda_{n}+h_{n}\sum_{j=1}^{s}\tilde{b}_{j}\left(1-\frac{a_{j,i}}{b_{i}}\right)l_{n,j}&\mbox{if}\ \ i\not\in I_{0}\\ -\sum_{j=1}^{s}\tilde{b}_{j}a_{j,i}l_{n,j}&\mbox{if}\ \ i\in I_{0},\end{cases}\end{split} (7)

where

\tilde{b}_{i}=\begin{cases}b_{i}&\mbox{if}\ \ i\not\in I_{0}\\ h_{n}&\mbox{if}\ \ i\in I_{0},\end{cases}\qquad I_{0}=\{i\ |\ i=1,\dots,s,\ b_{i}=0\}. (8)

Note that this numerical integrator is no longer a Runge–Kutta method and is an alternative expression for the “fancy” integrator proposed in [41].

Theorem 2.

The combination of the integrators in Eqs. (5) and (7) conserves a bilinear quantity S(δ, λ) if the continuous-time system conserves the quantity S(δ, λ).

Remark 4.

The Runge–Kutta method in Eq. (6) under Condition 1 and the numerical integrator in Eq. (7) are explicit backward in time if the Runge–Kutta method in Eq. (5) is explicit forward in time.

We emphasize that Theorems 1 and 2 hold for any ODE system, even if the system has discontinuity [19], stochasticity [29], or physics constraints [13]. This is because the theorems are properties not of the system but of the Runge–Kutta methods.

A partitioned Runge–Kutta method that satisfies Condition 1 is symplectic [17, 16]. It is known that, when a symplectic integrator is applied to a Hamiltonian system using a fixed step size, it conserves a modified Hamiltonian, which is an approximation to the system energy of the Hamiltonian system. The bilinear quantity S is associated with the symplectic structure but not with a Hamiltonian. Regardless of the step size, a symplectic integrator conserves the symplectic structure and thereby conserves the bilinear quantity S. Hence, we named this method the symplectic adjoint method. For integrators other than Runge–Kutta methods, one can design the integrator for the adjoint system so that the pair of integrators is symplectic (see [32] for an example).
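As a small numerical check of Condition 1 (illustrative only, not from the paper's code), the following snippet builds the adjoint-side coefficients A_{i,j} = b_j(1 - a_{j,i}/b_i) implied by B_i = b_i and verifies that the condition holds for the classical fourth-order tableau, whose weights b_i are all nonzero.

```python
import numpy as np

# Classical fourth-order Runge-Kutta tableau (all b_i nonzero, so Eq. (6) applies).
a = np.array([[0.0, 0, 0, 0], [0.5, 0, 0, 0], [0, 0.5, 0, 0], [0, 0, 1, 0]])
b = np.array([1/6, 1/3, 1/3, 1/6])
s = len(b)

# Condition 1 with B_i = b_i forces A_{i,j} = b_j * (1 - a_{j,i} / b_i).
B = b.copy()
A = np.array([[b[j] * (1 - a[j, i] / b[i]) for j in range(s)] for i in range(s)])

# residual[i, j] = b_i A_{i,j} + B_j a_{j,i} - b_i B_j, which should vanish.
residual = b[:, None] * A + B[None, :] * a.T - np.outer(b, B)
print(np.abs(residual).max())  # ~0 up to rounding error
```

In agreement with Remark 4, the resulting coefficients make the adjoint integration explicit when it is carried out backward in time, from i = s down to i = 1.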

Algorithm 1 Forward Integration
Require: x_0
Ensure: x_N, {x_n}_{n=0}^{N-1}
1:  for n = 0 to N-1 do
2:      Retain x_n as a checkpoint                       ▷ According to Eq. (5)
3:      for i = 1 to s do
4:          Get X_{n,i} using x_n and k_{n,j} for j < i
5:          Get k_{n,i} using X_{n,i}
6:      end for
7:      Get x_{n+1} using x_n and k_{n,i}
8:  end for

Algorithm 2 Backward Integration
Require: x_N, {x_n}_{n=0}^{N-1}
Ensure: λ_0
1:  for n = N-1 to 0 do
2:      Load checkpoint x_n                              ▷ According to Eq. (5)
3:      for i = 1 to s do
4:          Get X_{n,i} using x_n and k_{n,j} for j < i
5:          Get k_{n,i} using X_{n,i}
6:          Retain X_{n,i} as a checkpoint
7:      end for                                          ▷ According to Eq. (7)
8:      for i = s to 1 do
9:          Get Λ_{n,i} using λ_{n+1} and l_{n,j} for j > i
10:         Load checkpoint X_{n,i}
11:         Get l_{n,i} using Λ_{n,i} and X_{n,i}
12:         Discard checkpoint X_{n,i}
13:     end for
14:     Get λ_n using λ_{n+1} and l_{n,i}
15:     Discard checkpoint x_n
16: end for

4.3 Proposed Implementation

The theories given in the last section were mainly introduced for the numerical analysis in [41]. Because the original expression includes recalculations of intermediate variables, we propose the alternative expression in Eq. (7) to reduce the computational cost. The discretized adjoint system in Eq. (7) depends on the vector–Jacobian product (VJP) Λ^⊤∂f/∂x. To obtain it, the computation graph from the input X_{n,i} to the output f(X_{n,i}, t_n + c_i h_n) is required. When the computation graph in the forward integration is entirely retained, the memory consumption and computational cost are of the same orders as those for the naive backpropagation algorithm. To reduce the memory consumption, we propose the following strategy, summarized in Algorithms 1 and 2.

At the forward integration of a neural ODE component, the pairs of system states x_n and time points t_n at time steps n = 0, …, N-1 are retained with a memory of O(N) as checkpoints, and all computation graphs are discarded, as shown in Algorithm 1. For M neural ODE components, the memory for checkpoints is O(MN). The backward integration is summarized in Algorithm 2. The following steps are repeated from n = N-1 to n = 0. From the checkpoint x_n, the intermediate states X_{n,i} for the s stages are obtained following the Runge–Kutta method in Eq. (5) and retained as checkpoints with a memory of O(s), while all computation graphs are discarded. Then, the adjoint system is integrated from n+1 to n using Eq. (7). Because the computation graph of the neural network f in line 5 is discarded, it is recalculated and the VJP λ^⊤∂f/∂x is obtained using the backpropagation algorithm one-by-one in line 11, where only a single use of the neural network is recalculated at a time. This is why the memory consumption is proportional to the number of checkpoints MN+s plus the neural network size L. By contrast, existing methods apply the backpropagation algorithm to the computation graph of a single step composed of s stages or of multiple steps. Their memory consumption is proportional to the number of uses of the neural network between two checkpoints (at least s) times the neural network size L, in addition to the memory for checkpoints (see Table 1). Due to the recalculation, the computational cost of the proposed strategy is O(4MNsL), whereas those of the adjoint method [2] and ACA [46] are O(M(N+2Ñ)sL) and O(3MNsL), respectively. However, the increase in the computation time is much less than that expected theoretically because of other bottlenecks (as demonstrated later).
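To illustrate line 11 of Algorithm 2 (a sketch under assumed interfaces, not the released implementation), the following PyTorch snippet recomputes a single stage under torch.enable_grad() and extracts the VJP with torch.autograd.grad, so only one use of the network is kept in memory at a time.

```python
import torch

def stage_vjp(f, X_ni, t_ni, Lambda_ni, params):
    """Recompute f at one checkpointed stage X_{n,i} and return the VJPs
    Lambda^T (df/dX) and Lambda^T (df/dtheta), i.e., l_{n,i} for the state
    part and the parameter part of the augmented adjoint variable.
    Typically params = tuple(f.parameters())."""
    with torch.enable_grad():
        X = X_ni.detach().requires_grad_(True)   # load checkpoint (line 10)
        out = f(X, t_ni)                          # graph for a single network use
        vjps = torch.autograd.grad(out, (X, *params), grad_outputs=Lambda_ni)
    return vjps[0], vjps[1:]                      # graph is freed afterwards (line 12)
```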

5 Experiments

We evaluated the performance of the proposed symplectic adjoint method and existing methods using PyTorch 1.7.1 [35]. We implemented the proposed symplectic adjoint method by extending the adjoint method implemented in the package torchdiffeq 0.1.1 [2]. We re-implemented ACA [46] because the interface of the official implementation is incompatible with torchdiffeq. In practice, the number of checkpoints for an integration can be varied; we implemented a baseline scheme that retains only a single checkpoint per neural ODE component. The source code is available at https://github.com/tksmatsubara/symplectic-adjoint-method.

5.1 Continuous Normalizing Flow

Experimental Settings:

We evaluated the proposed symplectic adjoint method on training continuous normalizing flows [12]. A normalizing flow is a neural network that approximates a bijective map g and obtains the exact likelihood of a sample u by the change of variables \log p(u)=\log p(z)+\log|\det\frac{\partial g(u)}{\partial u}|, where z = g(u) and p(z) denote the corresponding latent variable and its prior, respectively [5, 6, 38]. A continuous normalizing flow is a normalizing flow whose map g is modeled by stacked neural ODE components; in particular, u = x(0) and z = x(T) for the case with M = 1. The log-determinant of the Jacobian is obtained by a numerical integration together with the system state x as \log|\det\frac{\partial g(u)}{\partial u}|=-\int_{0}^{T}\mathrm{Tr}\big(\frac{\partial f}{\partial x}(x(t),t)\big)\,{\mathrm{d}t}. The trace operation Tr is approximated by the Hutchinson estimator [22]. We adopted the experimental settings of the continuous normalizing flow FFJORD [12] (official code: https://github.com/rtqichen/ffjord, MIT License), unless stated otherwise.
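For reference, a minimal PyTorch sketch of the Hutchinson estimator used for the trace term (illustrative, not FFJORD's implementation): the trace of ∂f/∂x is estimated as ε^⊤(∂f/∂x)ε with a random probe vector ε, which requires only a single vector–Jacobian product.

```python
import torch

def hutchinson_trace(f, x, t):
    """Estimate Tr(df/dx) at (x, t) with one probe vector per batch element."""
    x = x.detach().requires_grad_(True)           # differentiate at the current state
    out = f(x, t)
    eps = torch.randn_like(out)                   # Gaussian (or Rademacher) probe
    # eps^T (df/dx) is a single vector-Jacobian product.
    vjp = torch.autograd.grad(out, x, grad_outputs=eps)[0]
    return (vjp * eps).flatten(1).sum(dim=1)      # one trace estimate per sample
```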

We examined five real tabular datasets, namely, the MiniBooNE, GAS, POWER, HEPMASS, and BSDS300 datasets [33]. The network architectures were the same as those that achieved the best results in the original experiments; the number of neural ODE components M varied across datasets. We employed the Dormand–Prince integrator, which is a fifth-order Runge–Kutta method with adaptive time-stepping, composed of seven stages [7]. Note that the number of function evaluations per step is s = 6 because the last stage is reused as the first stage of the next step. We set the absolute and relative tolerances to atol = 10^{-8} and rtol = 10^{-6}, respectively. The neural networks were trained using the Adam optimizer [27] with a learning rate of 10^{-3}. We used a batch-size of 1000 for all datasets to put a mini-batch into a single NVIDIA GeForce RTX 2080Ti GPU with 11 GB of memory, while the original experiments employed a batch-size of 10,000 for the latter three datasets on multiple GPUs. When using multiple GPUs, bottlenecks such as data transfer across GPUs may affect performance, and a fair comparison becomes difficult. Nonetheless, the naive backpropagation algorithm and the baseline scheme consumed the entire memory for the BSDS300 dataset.

We also examined the MNIST dataset [28] using a single NVIDIA RTX A6000 GPU with 48 GB of memory. Following the original study, we employed the multi-scale architecture and set the tolerances to atol = rtol = 10^{-5}. We set the learning rate to 10^{-3} and then reduced it to 10^{-4} at the 250th epoch. While the original experiments used a batch-size of 900, we set the batch-size to 200 following the official FFJORD code. The naive backpropagation algorithm and the baseline scheme consumed the entire memory.

Table 2: Results obtained for continuous normalizing flows.

| Method | MINIBOONE (M=1) NLL | mem. | time | GAS (M=5) NLL | mem. | time | POWER (M=5) NLL | mem. | time |
|---|---|---|---|---|---|---|---|---|---|
| adjoint method [2] | 10.59±0.17 | 170 | 0.74 | -10.53±0.25 | 24 | 4.82 | -0.31±0.01 | 8.1 | 6.33 |
| backpropagation [2] | 10.54±0.18 | 4436 | 0.91 | -9.53±0.42 | 4479 | 12.00 | -0.24±0.05 | 1710.9 | 10.64 |
| baseline scheme | 10.54±0.18 | 4457 | 1.10 | -9.53±0.42 | 1858 | 5.48 | -0.24±0.05 | 515.2 | 4.37 |
| ACA [46] | 10.57±0.30 | 306 | 0.77 | -10.65±0.45 | 73 | 3.98 | -0.31±0.02 | 29.5 | 5.08 |
| proposed | 10.49±0.11 | 95 | 0.84 | -10.89±0.11 | 20 | 4.39 | -0.31±0.02 | 9.2 | 5.73 |

| Method | HEPMASS (M=10) NLL | mem. | time | BSDS300 (M=2) NLL | mem. | time | MNIST (M=6) NLL | mem. | time |
|---|---|---|---|---|---|---|---|---|---|
| adjoint method [2] | 16.49±0.25 | 40 | 4.19 | -152.04±0.09 | 577 | 11.70 | 0.918±0.011 | 1086 | 10.12 |
| backpropagation [2] | 17.03±0.22 | 5254 | 11.82 | – | – | – | – | – | – |
| baseline scheme | 17.03±0.22 | 1102 | 4.40 | – | – | – | – | – | – |
| ACA [46] | 16.41±0.39 | 88 | 3.67 | -151.27±0.47 | 757 | 6.97 | 0.919±0.003 | 4332 | 7.94 |
| proposed | 16.48±0.20 | 35 | 4.15 | -151.17±0.15 | 283 | 8.07 | 0.917±0.002 | 1079 | 9.42 |

Negative log-likelihoods (NLL), peak memory consumption [MiB], and computation time per iteration [s/itr]. Dashes denote runs that exhausted the GPU memory. See Table A2 in the Appendix for standard deviations.

Performance:

The medians ± standard deviations of three runs are summarized in Table 2. In many cases, all methods achieved negative log-likelihoods (NLLs) with no significant difference because all but the adjoint method provide the exact gradients up to rounding error, and the adjoint method with a small tolerance provides a sufficiently accurate gradient. The naive backpropagation algorithm and the baseline scheme obtained slightly worse results on the GAS, POWER, and HEPMASS datasets. Due to adaptive time-stepping, the numerical integrator sometimes makes the step size much smaller, and the backpropagation algorithm over time steps suffered from rounding errors. Conversely, ACA and the proposed symplectic adjoint method applied the backpropagation algorithm separately to a subset of the integration, thereby becoming more robust to rounding errors (see Appendix D.1 for details).

After the training procedure, we obtained the peak memory consumption during additional training iterations (mem. [MiB]), from which we subtracted the memory consumption before training (i.e., that occupied by the model parameters, loaded data, etc.). The memory consumption still includes the optimizer's states and the intermediate results of the multiply–accumulate operation. The results roughly agree with the theoretical orders shown in Table 1 (see also Table A2 for standard deviations). The symplectic adjoint method consumed much less memory than the naive backpropagation algorithm and the checkpointing schemes. Owing to the optimized implementation, the symplectic adjoint method consumed less memory than the adjoint method in some cases (see Appendix D.2).

On the other hand, the computation time per iteration (time [s/itr]) during the additional training iterations does not agree with the theoretical orders. First, the adjoint method was slower in many cases, especially for the BSDS300 and MNIST datasets. For obtaining the gradients, the adjoint method integrates the adjoint variable λ, whose size is equal to the sum of the sizes of the parameters θ and the system state x. With more parameters, the probability that at least one parameter does not satisfy the tolerance increases. An accurate backward integration then requires a much smaller step size than the forward integration (i.e., Ñ much greater than N), leading to a longer computation time. Second, the naive backpropagation algorithm and the baseline scheme were slower than expected theoretically in many cases. A method with high memory consumption may have to wait for a retained computation graph to be loaded or for memory to be freed, leading to an additional bottleneck. The symplectic adjoint method is free from the above bottlenecks and performs faster in practice; it was faster than the adjoint method for all but the MiniBooNE dataset.

The symplectic adjoint method is superior (or at least competitive) to the adjoint method, naive backpropagation, and baseline scheme in terms of both memory consumption and computation time. Between the proposed symplectic adjoint method and ACA, a trade-off exists between memory consumption and computation time.

Figure 1: With different tolerances.

Robustness to Tolerance:

The adjoint method provides gradients with numerical errors. To evaluate the robustness against the tolerance, we employed the MiniBooNE dataset and varied the absolute tolerance atol while maintaining the relative tolerance at rtol = 10^{2} × atol. During the training, we obtained the computation time per iteration, as summarized in the upper panel of Fig. 1. The computation time decreased as the tolerance increased. After training, we obtained the NLLs with atol = 10^{-8}, as summarized in the bottom panel of Fig. 1. The adjoint method performed well only with atol < 10^{-4}. With atol = 10^{-4}, the numerical error in the backward integration was non-negligible, and the performance degraded. With atol > 10^{-4}, the adjoint method destabilized. The symplectic adjoint method performed well even with atol = 10^{-4}. Even with 10^{-4} < atol < 10^{-2}, it performed to a certain level, while the numerical error in the forward integration was non-negligible. Because of the exact gradient, the symplectic adjoint method is robust to a large tolerance compared with the adjoint method, and thus potentially works much faster with an appropriate tolerance.

Table 3: Results obtained for the GAS dataset with different Runge–Kutta methods.

| Method | p=2, s=2 mem. | time | p=3, s=3 mem. | time | p=5, s=6 mem. | time | p=8, s=12 mem. | time |
|---|---|---|---|---|---|---|---|---|
| adjoint method [2] | 21±0 | 247.47±7.52 | 22±0 | 18.32±0.88 | 24±0 | 5.34±0.31 | 28±0 | 9.77±0.81 |
| backpropagation [2] | – | – | – | – | 4433±255 | 11.85±1.10 | – | – |
| baseline scheme | – | – | – | – | 1858±228 | 5.82±0.28 | 4108±576 | 22.76±3.70 |
| ACA [46] | 607±30 | 232.90±13.81 | 69±2 | 17.72±1.38 | 73±0 | 4.15±0.21 | 138±0 | 9.36±0.55 |
| proposed | 589±14 | 262.99±5.19 | 43±2 | 18.59±0.75 | 20±0 | 4.78±0.32 | 21±0 | 11.41±0.23 |

Peak memory consumption [MiB] and computation time per iteration [s/itr]. Dashes denote runs that exhausted the GPU memory.

Different Runge–Kutta Methods:

The Runge–Kutta family includes various integrators characterized by the Butcher tableau [16, 17, 41], such as the Heun–Euler method (a.k.a. adaptive Heun), the Bogacki–Shampine method (a.k.a. bosh3), the fifth-order Dormand–Prince method (a.k.a. dopri5), and the eighth-order Dormand–Prince method (a.k.a. dopri8). These methods have the orders p = 2, 3, 5, and 8 using s = 2, 3, 6, and 12 function evaluations, respectively. We examined these methods using the GAS dataset, and the results are summarized in Table 3. The naive backpropagation algorithm and the baseline scheme consumed the entire memory in some cases, as denoted by dashes. We omit the NLLs because all methods used the same tolerance and achieved the same NLLs.

Compared to ACA, the symplectic adjoint method suppresses the memory consumption more significantly with a higher-order method (i.e., more function evaluations s), as the theory suggests in Table 1. With the Heun–Euler method, all methods were extremely slow, and all but the adjoint method consumed larger memory. A lower-order method has to use an extremely small step size to satisfy the tolerance, thereby increasing the number of steps N, the computation time, and the memory for checkpoints. This result indicates the limitations of methods that depend on lower-order integrators, such as MALI [47]. With the eighth-order Dormand–Prince method, the adjoint method performs relatively faster. This is because the backward integration easily satisfies the tolerance with a higher-order method (i.e., Ñ ≃ N). Nonetheless, in terms of computation time, the fifth-order Dormand–Prince method is the best choice, for which the symplectic adjoint method greatly reduces the memory consumption and performs faster than all but ACA.

Figure 2: With different numbers of steps.

Memory for Checkpoints:

To evaluate the memory consumption with varying numbers of checkpoints, we used the fifth-order Dormand–Prince method and varied the number of steps N for MNIST by manually varying the step size. We summarized the results in Fig. 2 on a log-log scale. Note that, with adaptive time-stepping, FFJORD needs approximately MN = 200 steps for MNIST and fewer steps for the other datasets. Because we set Ñ = N here, whereas Ñ > N in practice, the adjoint method is expected to require a longer computation time in practice.

The memory consumption roughly follows the theoretical orders summarized in Table 1. The adjoint method needs a memory of O(L) for the backpropagation, and the symplectic adjoint method needs an additional memory of O(MN+s) for checkpoints. Until the number of steps MN exceeds a thousand, the memory for checkpoints is negligible compared to that for the backpropagation. Compared to the symplectic adjoint method, ACA needs a memory of O(sL) for the backpropagation over s stages. The increase in memory is significant until the number of steps MN reaches ten thousand. For a stiff (non-smooth) ODE for which a numerical integrator needs thousands of steps, one can employ a higher-order integrator such as the eighth-order Dormand–Prince method and suppress the number of steps. For a stiffer ODE, implicit integrators are commonly used, which are out of the scope of this study and of the related works in Table 1. Therefore, we conclude that the symplectic adjoint method needs memory at the same level as the adjoint method and much less than the other methods in practical ranges.

A possible alternative to the proposed implementation in Algorithms 1 and 2 retains all intermediate states X_{n,i} during the forward integration. Its computational cost and memory consumption are O(3MNsL) and O(MNs+L), respectively. The memory for checkpoints can then be non-negligible with a practical number of steps.

Table 4: Results obtained for continuous-time physical systems.

| Method | KdV Equation MSE (×10^{-3}) | mem. | time | Cahn–Hilliard System MSE (×10^{-6}) | mem. | time |
|---|---|---|---|---|---|---|
| adjoint method [2] | 1.61±3.23 | 93.7±0.0 | 276±16 | 5.58±1.67 | 93.7±0.0 | 942±24 |
| backpropagation [2] | 1.61±3.40 | 693.9±0.0 | 105±4 | 4.68±1.89 | 3047.1±0.0 | 425±13 |
| ACA [46] | 1.61±3.40 | 647.8±0.0 | 137±5 | 5.82±2.33 | 648.0±0.0 | 484±13 |
| proposed | 1.61±4.00 | 79.8±0.0 | 162±6 | 5.47±1.46 | 80.3±0.0 | 568±22 |

Mean squared errors (MSEs) in long-term predictions, peak memory consumption [MiB], and computation time per iteration [ms/itr].

5.2 Continuous-Time Dynamical System

Experimental Settings:

We evaluated the symplectic adjoint method on learning continuous-time dynamical systems [13, 31, 40]. Many physical phenomena can be modeled using the gradient of the system energy H as \mathrm{d}x/\mathrm{d}t = G\nabla H(x), where G is a coefficient matrix that determines the behavior of the energy [9]. We followed the experimental settings of HNN++, provided in [31] (official code: https://github.com/tksmatsubara/discrete-autograd, MIT License). A neural network composed of one convolution layer and two fully connected layers approximated the energy function H and learned the time series by interpolating two successive samples. The deterministic convolution algorithm was enabled (see Appendix D.3 for discussion). We employed two physical systems described by PDEs, namely the Korteweg–De Vries (KdV) equation and the Cahn–Hilliard system. We used a batch-size of 100 to put a mini-batch into a single NVIDIA TITAN V GPU instead of the original batch-size of 200. Moreover, we used the eighth-order Dormand–Prince method [17], composed of 13 stages, to emphasize the efficiency of the proposed method. We omitted the baseline scheme because M = 1. We evaluated the performance using the mean squared errors (MSEs) in the system energy for long-term predictions.
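As a minimal illustration of this energy-gradient form (a sketch with arbitrary network sizes, not the HNN++ architecture), an energy H is modeled by a small network and the vector field G∇H(x) is obtained by automatic differentiation.

```python
import torch
import torch.nn as nn

class EnergyNet(nn.Module):
    """A toy scalar energy H(x)."""
    def __init__(self, dim=2, hidden=32):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, hidden), nn.Tanh(), nn.Linear(hidden, 1))

    def forward(self, x):
        return self.net(x).sum()                  # summed over the batch for autograd

def energy_vector_field(H, G, x):
    """dx/dt = G grad H(x); a skew-symmetric G keeps the energy conserved."""
    x = x.detach().requires_grad_(True)
    grad_H = torch.autograd.grad(H(x), x, create_graph=True)[0]
    return grad_H @ G.T

H, x = EnergyNet(), torch.randn(8, 2)
G = torch.tensor([[0.0, 1.0], [-1.0, 0.0]])       # canonical symplectic matrix
print(energy_vector_field(H, G, x).shape)         # torch.Size([8, 2])
```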

Performance:

The medians ± standard deviations of 15 runs are summarized in Table 4. Due to the accumulated error in the numerical integration, the MSEs had large variances, but all methods obtained similar MSEs. ACA consumed much more memory than the symplectic adjoint method because of the large number of stages; the symplectic adjoint method is more beneficial for physical simulations, which often require extremely high-order methods. Due to the severe nonlinearity, the adjoint method had to employ a small step size and thus performed slower than the others (i.e., Ñ > N).

6 Conclusion

We proposed the symplectic adjoint method, which solves the adjoint system by a symplectic integrator with appropriate checkpoints and thereby provides the exact gradient. It applies the backpropagation algorithm only to each use of the neural network, and thus consumes much less memory than the naive backpropagation algorithm and the checkpointing schemes. Its memory consumption is competitive with that of the adjoint method because the memory consumed by checkpoints is negligible in most cases. The symplectic adjoint method provides the exact gradient with the same step size as that used for the forward integration. Therefore, in practice, it performs faster than the adjoint method, which requires a small step size to suppress numerical errors.

As shown in the experiments, the best integrator and checkpointing scheme may depend on the target system and computational resources. For example, Kim et al. [26] have demonstrated that quadrature methods can reduce the computational cost of the adjoint system for a stiff equation in exchange for additional memory consumption. Practical packages provide many integrators and can choose the best ones [20, 36]. In the future, we will provide the proposed symplectic adjoint method as a part of such packages for appropriate systems.

Acknowledgments and Disclosure of Funding

This study was partially supported by JST CREST (JPMJCR1914), JST PRESTO (JPMJPR21C7), and JSPS KAKENHI (19K20344).

References

  • Bochev and Scovel, [1994] Bochev, P. B. and Scovel, C. (1994). On quadratic invariants and symplectic structure. BIT Numerical Mathematics, 34(3):337–345.
  • Chen et al., [2018] Chen, R. T. Q., Rubanova, Y., Bettencourt, J., and Duvenaud, D. (2018). Neural Ordinary Differential Equations. In Advances in Neural Information Processing Systems (NeurIPS).
  • Chen et al., [2020] Chen, Z., Zhang, J., Arjovsky, M., and Bottou, L. (2020). Symplectic Recurrent Neural Networks. In International Conference on Learning Representations (ICLR).
  • Devlin et al., [2018] Devlin, J., Chang, M.-W., Lee, K., and Toutanova, K. (2018). BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding. arXiv.
  • Dinh et al., [2014] Dinh, L., Krueger, D., and Bengio, Y. (2014). NICE: Non-linear Independent Components Estimation. In Workshop on International Conference on Learning Representations.
  • Dinh et al., [2017] Dinh, L., Sohl-Dickstein, J., and Bengio, S. (2017). Density estimation using Real NVP. In International Conference on Learning Representations (ICLR).
  • Dormand and Prince, [1986] Dormand, J. R. and Prince, P. J. (1986). A reconsideration of some embedded Runge-Kutta formulae. Journal of Computational and Applied Mathematics, 15(2):203–211.
  • Errico, [1997] Errico, R. M. (1997). What Is an Adjoint Model? Bulletin of the American Meteorological Society, 78(11):2577–2591.
  • Furihata and Matsuo, [2010] Furihata, D. and Matsuo, T. (2010). Discrete Variational Derivative Method: A Structure-Preserving Numerical Method for Partial Differential Equations. Chapman and Hall/CRC.
  • Gholami et al., [2019] Gholami, A., Keutzer, K., and Biros, G. (2019). ANODE: Unconditionally accurate memory-efficient gradients for neural ODEs. In International Joint Conference on Artificial Intelligence (IJCAI).
  • Gomez et al., [2017] Gomez, A. N., Ren, M., Urtasun, R., and Grosse, R. B. (2017). The Reversible Residual Network: Backpropagation Without Storing Activations. In Advances in Neural Information Processing Systems (NIPS).
  • Grathwohl et al., [2018] Grathwohl, W., Chen, R. T. Q., Bettencourt, J., Sutskever, I., and Duvenaud, D. (2018). FFJORD: Free-form Continuous Dynamics for Scalable Reversible Generative Models. In International Conference on Learning Representations (ICLR).
  • Greydanus et al., [2019] Greydanus, S., Dzamba, M., and Yosinski, J. (2019). Hamiltonian Neural Networks. In Advances in Neural Information Processing Systems (NeurIPS).
  • Griewank and Walther, [2000] Griewank, A. and Walther, A. (2000). Algorithm 799: Revolve: An implementation of checkpointing for the reverse or adjoint mode of computational differentiation. ACM Transactions on Mathematical Software, 26(1):19–45.
  • Gruslys et al., [2016] Gruslys, A., Munos, R., Danihelka, I., Lanctot, M., and Graves, A. (2016). Memory-efficient backpropagation through time. Advances in Neural Information Processing Systems (NIPS).
  • Hairer et al., [2006] Hairer, E., Lubich, C., and Wanner, G. (2006). Geometric Numerical Integration: Structure-Preserving Algorithms for Ordinary Differential Equations, volume 31 of Springer Series in Computational Mathematics. Springer-Verlag, Berlin/Heidelberg.
  • Hairer et al., [1993] Hairer, E., Nørsett, S. P., and Wanner, G. (1993). Solving Ordinary Differential Equations I: Nonstiff Problems, volume 8 of Springer Series in Computational Mathematics. Springer Berlin Heidelberg, Berlin, Heidelberg.
  • He et al., [2016] He, K., Zhang, X., Ren, S., and Sun, J. (2016). Deep Residual Learning for Image Recognition. In IEEE Conference on Computer Vision and Pattern Recognition (CVPR).
  • Herrera et al., [2020] Herrera, C., Krach, F., and Teichmann, J. (2020). Neural Jump Ordinary Differential Equations: Consistent Continuous-Time Prediction and Filtering. In International Conference on Learning Representations (ICLR).
  • Hindmarsh et al., [2005] Hindmarsh, A. C., Brown, P. N., Grant, K. E., Lee, S. L., Serban, R., Shumaker, D. E., and Woodward, C. S. (2005). SUNDIALS: Suite of nonlinear and differential/algebraic equation solvers. ACM Transactions on Mathematical Software, 31(3):363–396.
  • Hochreiter and Schmidhuber, [1997] Hochreiter, S. and Schmidhuber, J. (1997). Long Short-Term Memory. Neural Computation, 9(8):1735–1780.
  • Hutchinson, [1990] Hutchinson, M. (1990). A stochastic estimator of the trace of the influence matrix for laplacian smoothing splines. Communications in Statistics - Simulation and Computation, 19(2):433–450.
  • Jiang et al., [2020] Jiang, C. M., Huang, J., Tagliasacchi, A., and Guibas, L. (2020). ShapeFlow: Learnable Deformations Among 3D Shapes. In Advances in Neural Information Processing Systems (NeurIPS).
  • Kidger et al., [2020] Kidger, P., Morrill, J., Foster, J., and Lyons, T. (2020). Neural Controlled Differential Equations for Irregular Time Series. In Advances in Neural Information Processing Systems (NeurIPS).
  • Kim et al., [2020] Kim, H., Lee, H., Kang, W. H., Cheon, S. J., Choi, B. J., and Kim, N. S. (2020). WaveNODE: A Continuous Normalizing Flow for Speech Synthesis. In ICML2020 Workshop on Invertible Neural Networks, Normalizing Flows, and Explicit Likelihood Models.
  • Kim et al., [2021] Kim, S., Ji, W., Deng, S., Ma, Y., and Rackauckas, C. (2021). Stiff Neural Ordinary Differential Equations. arXiv.
  • Kingma and Ba, [2015] Kingma, D. P. and Ba, J. (2015). Adam: A Method for Stochastic Optimization. In International Conference on Learning Representations (ICLR).
  • LeCun et al., [1998] LeCun, Y., Bottou, L., Bengio, Y., and Haffner, P. (1998). Gradient-based learning applied to document recognition. In Proceedings of the IEEE, volume 86, pages 2278–2323.
  • Li et al., [2020] Li, X., Wong, T.-K. L., Chen, R. T. Q., and Duvenaud, D. (2020). Scalable Gradients for Stochastic Differential Equations. In Artificial Intelligence and Statistics (AISTATS).
  • Lu et al., [2018] Lu, Y., Zhong, A., Li, Q., and Dong, B. (2018). Beyond finite layer neural networks: Bridging deep architectures and numerical differential equations. In International Conference on Machine Learning (ICML).
  • Matsubara et al., [2020] Matsubara, T., Ishikawa, A., and Yaguchi, T. (2020). Deep Energy-Based Modeling of Discrete-Time Physics. In Advances in Neural Information Processing Systems (NeurIPS).
  • Matsuda and Miyatake, [2021] Matsuda, T. and Miyatake, Y. (2021). Generalization of partitioned Runge–Kutta methods for adjoint systems. Journal of Computational and Applied Mathematics, 388:113308.
  • Papamakarios et al., [2017] Papamakarios, G., Pavlakou, T., and Murray, I. (2017). Masked Autoregressive Flow for Density Estimation. In Advances in Neural Information Processing Systems (NIPS).
  • Pascanu et al., [2013] Pascanu, R., Mikolov, T., and Bengio, Y. (2013). On the difficulty of training Recurrent Neural Networks. In International Conference on Machine Learning (ICML).
  • Paszke et al., [2017] Paszke, A., Chanan, G., Lin, Z., Gross, S., Yang, E., Antiga, L., and Devito, Z. (2017). Automatic differentiation in PyTorch. In Autodiff Workshop on Advances in Neural Information Processing Systems.
  • Rackauckas et al., [2020] Rackauckas, C., Ma, Y., Martensen, J., Warner, C., Zubov, K., Supekar, R., Skinner, D., Ramadhan, A., and Edelman, A. (2020). Universal Differential Equations for Scientific Machine Learning. arXiv.
  • Rana et al., [2020] Rana, M. A., Li, A., Fox, D., Boots, B., Ramos, F., and Ratliff, N. (2020). Euclideanizing Flows: Diffeomorphic Reduction for Learning Stable Dynamical Systems. In Conference on Learning for Dynamics and Control (L4DC).
  • Rezende and Mohamed, [2015] Rezende, D. J. and Mohamed, S. (2015). Variational Inference with Normalizing Flows. In International Conference on Machine Learning (ICML).
  • Rumelhart et al., [1986] Rumelhart, D. E., Hinton, G. E., and Williams, R. J. (1986). Learning representations by back-propagating errors. Nature, 323(6088):533–536.
  • Saemundsson et al., [2020] Sæmundsson, S., Terenin, A., Hofmann, K., and Deisenroth, M. P. (2020). Variational Integrator Networks for Physically Meaningful Embeddings. In Artificial Intelligence and Statistics (AISTATS).
  • Sanz-Serna, [2016] Sanz-Serna, J. M. (2016). Symplectic Runge-Kutta schemes for adjoint equations, automatic differentiation, optimal control, and more. SIAM Review, 58(1):3–33.
  • Takeishi and Kawahara, [2020] Takeishi, N. and Kawahara, Y. (2020). Learning dynamics models with stable invariant sets. AAAI Conference on Artificial Intelligence (AAAI).
  • Teshima et al., [2020] Teshima, T., Tojo, K., Ikeda, M., Ishikawa, I., and Oono, K. (2020). Universal Approximation Property of Neural Ordinary Differential Equations. In NeurIPS Workshop on Differential Geometry meets Deep Learning (DiffGeo4DL).
  • Wang, [2013] Wang, Q. (2013). Forward and adjoint sensitivity computation of chaotic dynamical systems. Journal of Computational Physics, 235:1–13.
  • Yang et al., [2019] Yang, G., Huang, X., Hao, Z., Liu, M. Y., Belongie, S., and Hariharan, B. (2019). Pointflow: 3D point cloud generation with continuous normalizing flows. International Conference on Computer Vision (ICCV).
  • Zhuang et al., [2020] Zhuang, J., Dvornek, N., Li, X., Tatikonda, S., Papademetris, X., and Duncan, J. (2020). Adaptive Checkpoint Adjoint Method for Gradient Estimation in Neural ODE. In International Conference on Machine Learning (ICML).
  • Zhuang et al., [2021] Zhuang, J., Dvornek, N. C., Tatikonda, S., and Duncan, J. S. (2021). MALI: A memory efficient and reverse accurate integrator for Neural ODEs. In International Conference on Learning Representations (ICLR).

Supplementary Material: Appendices

Appendix A Derivation of Variational System

Let us consider a perturbed initial condition x̄_0 = x_0 + δ̄_0, from which the solution x̄(t) arises. Suppose that the solution x̄(t) satisfies x̄(t) = x(t) + δ̄(t). Then,

\begin{split}\frac{{\mathrm{d}}}{\mathrm{d}t}\bar{\delta}&=\frac{{\mathrm{d}}}{\mathrm{d}t}(\bar{x}-x)\\ &=f(\bar{x},t)-f(x,t)\\ &=\frac{\partial f}{\partial x}(x,t)(\bar{x}-x)+o(|\bar{x}-x|)\\ &=\frac{\partial f}{\partial x}(x,t)\bar{\delta}+o(|\bar{\delta}|),\\ \bar{\delta}(0)&=\bar{\delta}_{0}.\end{split} (9)

Dividing δ̄ by δ̄_0 and taking the limit as |δ̄_0| → +0, we define the variational variable as δ(t) = ∂x(t)/∂x_0 and the variational system as

\frac{{\mathrm{d}}}{\mathrm{d}t}\delta(t)=\frac{\partial f}{\partial x}(x(t),t)\delta(t)\mbox{ for }\delta(0)=I. (10)

Appendix B Complete Proofs

Proof of Remark 1:

\frac{{\mathrm{d}}}{\mathrm{d}t}\left(\lambda^{\top}\delta\right)=\left(\frac{{\mathrm{d}}}{\mathrm{d}t}\lambda\right)^{\!\!\top}\delta+\lambda^{\top}\left(\frac{{\mathrm{d}}}{\mathrm{d}t}\delta\right)=\left(-\frac{\partial f}{\partial x}(x,t)^{\top}\lambda\right)^{\!\!\top}\delta+\lambda^{\top}\left(\frac{\partial f}{\partial x}(x,t)\delta\right)=0. (11)

Proof of Remark 2:

Because δ(t) = ∂x(t)/∂x_0 and λ^⊤δ is time-invariant,

\frac{\partial\mathcal{L}(x(T))}{\partial x_{0}}=\frac{\partial\mathcal{L}(x(T))}{\partial x(T)}\frac{\partial x(T)}{\partial x_{0}}=\lambda(T)^{\top}\delta(T)=\lambda(t)^{\top}\delta(t)=\frac{\partial\mathcal{L}(x(T))}{\partial x(t)}\frac{\partial x(t)}{\partial x_{0}}. (12)

Proof of Remark 3:

Differentiating each term in the Runge–Kutta method in Eq. (5) by the initial condition x_0 gives the Runge–Kutta method applied to the variational variable δ, as follows.

\begin{split}\delta_{n+1}&=\delta_{n}+h_{n}\sum_{i=1}^{s}b_{i}d_{n,i},\\ d_{n,i}&\vcentcolon=\frac{\partial k_{n,i}}{\partial x_{0}}=\frac{\partial f(X_{n,i},t_{n}+c_{i}h_{n})}{\partial x_{0}}=\frac{\partial f(X_{n,i},t_{n}+c_{i}h_{n})}{\partial X_{n,i}}\Delta_{n,i},\\ \Delta_{n,i}&\vcentcolon=\frac{\partial X_{n,i}}{\partial x_{0}}=\delta_{n}+h_{n}\sum_{j=1}^{s}a_{i,j}d_{n,j}.\end{split} (13)

Proof of Theorem 1:

Because the quantity S is conserved in continuous time,

\frac{{\mathrm{d}}}{\mathrm{d}t}S(\delta,\lambda)=0. (14)

Because the quantity S is bilinear,

\frac{{\mathrm{d}}}{\mathrm{d}t}S(\delta,\lambda)=\frac{\partial S}{\partial\delta}\frac{{\mathrm{d}}\delta}{{\mathrm{d}t}}+\frac{\partial S}{\partial\lambda}\frac{{\mathrm{d}}\lambda}{{\mathrm{d}t}}=S\left(\frac{{\mathrm{d}}\delta}{{\mathrm{d}t}},\lambda\right)+S\left(\delta,\frac{{\mathrm{d}}\lambda}{{\mathrm{d}t}}\right), (15)

which implies

S(d_{n,i},\Lambda_{n,i})+S(\Delta_{n,i},l_{n,i})=0. (16)

The change in the bilinear quantity S(δ, λ) is

\begin{split}S(\delta_{n+1},\lambda_{n+1})-S(\delta_{n},\lambda_{n})&\textstyle=S(\delta_{n}+h_{n}\sum_{i}b_{i}d_{n,i},\lambda_{n}+h_{n}\sum_{i}B_{i}l_{n,i})-S(\delta_{n},\lambda_{n})\\ &\textstyle=\sum_{i}b_{i}h_{n}S(d_{n,i},\lambda_{n})+\sum_{i}B_{i}h_{n}S(\delta_{n},l_{n,i})\\ &\textstyle\ \ +\sum_{i}\sum_{j}b_{i}B_{j}h_{n}^{2}S(d_{n,i},l_{n,j})\\ &\textstyle=\sum_{i}b_{i}h_{n}S(d_{n,i},\Lambda_{n,i}-h_{n}\sum_{j}A_{i,j}l_{n,j})\\ &\textstyle\ \ +\sum_{i}B_{i}h_{n}S(\Delta_{n,i}-h_{n}\sum_{j}a_{i,j}d_{n,j},l_{n,i})\\ &\textstyle\ \ +\sum_{i}\sum_{j}b_{i}B_{j}h_{n}^{2}S(d_{n,i},l_{n,j})\\ &\textstyle=\sum_{i}h_{n}(b_{i}S(d_{n,i},\Lambda_{n,i})+B_{i}S(\Delta_{n,i},l_{n,i}))\\ &\textstyle\ \ +\sum_{i}\sum_{j}(-b_{i}A_{i,j}-B_{j}a_{j,i}+b_{i}B_{j})h_{n}^{2}S(d_{n,i},l_{n,j}).\end{split} (17)

If B_i = b_i, the first term vanishes by Eq. (16); if in addition b_i A_{i,j} + B_j a_{j,i} - b_i B_j = 0, the second term also vanishes, i.e., the partitioned Runge–Kutta method conserves the bilinear quantity S. Note that b_i must not vanish because A_{i,j} = B_j(1 - a_{j,i}/b_i). Therefore, the bilinear quantity λ_n^⊤δ_n is conserved as

\lambda_{N}^{\top}\delta_{N}=\lambda_{n}^{\top}\delta_{n}\mbox{ for }n=0,\dots,N. (18)

Remark 3 indicates δ_n = ∂x_n/∂x_0. When λ_N is set to (∂L(x_N)/∂x_N)^⊤,

\frac{\partial\mathcal{L}(x_{N})}{\partial x_{0}}=\frac{\partial\mathcal{L}(x_{N})}{\partial x_{N}}\frac{\partial x_{N}}{\partial x_{0}}=\lambda_{N}^{\top}\delta_{N}=\lambda_{n}^{\top}\delta_{n}=\frac{\partial\mathcal{L}(x_{N})}{\partial x_{n}}\frac{\partial x_{n}}{\partial x_{0}}. (19)

Therefore, λ_n = (∂L(x_N)/∂x_n)^⊤.
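The conditions B_i = b_i and A_{i,j} = B_j(1 - a_{j,i}/b_i) used above can also be checked numerically. The following minimal sketch (assuming NumPy and the classical RK4 tableau, whose weights b_i are all nonzero) builds the adjoint tableau and verifies that b_i A_{i,j} + B_j a_{j,i} - b_i B_j vanishes up to rounding error.

import numpy as np

# Classical RK4 tableau; all b_i are nonzero, so Theorem 1 applies directly.
a = np.array([[0.0, 0.0, 0.0, 0.0],
              [0.5, 0.0, 0.0, 0.0],
              [0.0, 0.5, 0.0, 0.0],
              [0.0, 0.0, 1.0, 0.0]])
b = np.array([1/6, 1/3, 1/3, 1/6])

# Adjoint tableau of Theorem 1: B_i = b_i and A_{i,j} = B_j (1 - a_{j,i} / b_i).
B = b.copy()
A = B[None, :] * (1.0 - a.T / b[:, None])

# The conservation condition b_i A_{i,j} + B_j a_{j,i} - b_i B_j should vanish.
residual = b[:, None] * A + B[None, :] * a.T - np.outer(b, B)
print(np.abs(residual).max())   # ~1e-17, i.e., zero up to rounding error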

Proof of Theorem 2:

When the integrators in Eqs. (5) and (7) are combined, the change in a bilinear quantity S(δ,λ) conserved by the continuous-time dynamics is

\begin{split}S(\delta_{n+1},\lambda_{n+1})-S(\delta_{n},\lambda_{n})&\textstyle=S(\delta_{n}+h_{n}\sum_{i}b_{i}d_{n,i},\lambda_{n}+h_{n}\sum_{i}\tilde{b}_{i}l_{n,i})-S(\delta_{n},\lambda_{n})\\ &\textstyle=\sum_{i}b_{i}h_{n}S(d_{n,i},\lambda_{n})+\sum_{i}\tilde{b}_{i}h_{n}S(\delta_{n},l_{n,i})\\ &\textstyle\ \ +\sum_{i}\sum_{j}b_{i}\tilde{b}_{j}h_{n}^{2}S(d_{n,i},l_{n,j})\\ &\textstyle=\sum_{i\not\in I_{0}}b_{i}h_{n}S(d_{n,i},\Lambda_{n,i}-h_{n}\sum_{j}\tilde{b}_{j}(1-a_{j,i}/b_{i})l_{n,j})\\ &\textstyle\ \ +\sum_{i}\tilde{b}_{i}h_{n}S(\Delta_{n,i}-h_{n}\sum_{j}a_{i,j}d_{n,j},l_{n,i})\\ &\textstyle\ \ +\sum_{i\not\in I_{0}}\sum_{j}b_{i}\tilde{b}_{j}h_{n}^{2}S(d_{n,i},l_{n,j})\\ &\textstyle=\sum_{i\not\in I_{0}}b_{i}h_{n}(S(d_{n,i},\Lambda_{n,i})+S(\Delta_{n,i},l_{n,i}))\\ &\textstyle\ \ +\sum_{i\not\in I_{0}}\sum_{j}(-b_{i}\tilde{b}_{j}(1-a_{j,i}/b_{i})-\tilde{b}_{j}a_{j,i}+b_{i}\tilde{b}_{j})h_{n}^{2}S(d_{n,i},l_{n,j})\\ &\textstyle\ \ +\sum_{i\in I_{0}}(\tilde{b}_{i}h_{n}S(\Delta_{n,i},l_{n,i})-\sum_{j}\tilde{b}_{j}a_{j,i}h_{n}^{2}S(d_{n,i},l_{n,j}))\\ &\textstyle=\sum_{i\not\in I_{0}}b_{i}h_{n}(S(d_{n,i},\Lambda_{n,i})+S(\Delta_{n,i},l_{n,i}))\\ &\textstyle\ \ +\sum_{i\in I_{0}}h_{n}^{2}(S(d_{n,i},\Lambda_{n,i})+S(\Delta_{n,i},l_{n,i}))\\ &=0.\end{split} (20)

Hence, the bilinear quantity S(δ,λ) is conserved.

Proof of Remark 4:

Eq. (6) can be rewritten as

\begin{split}\lambda_{n}&=\lambda_{n+1}-h_{n}\sum_{i=1}^{s}b_{i}l_{n,i},\\ l_{n,i}&=-\frac{\partial f}{\partial x}(X_{n,i},t_{n}+c_{i}h_{n})^{\top}\Lambda_{n,i},\\ \Lambda_{n,i}&=\lambda_{n+1}-h_{n}\sum_{j=1}^{s}b_{j}\frac{a_{j,i}}{b_{i}}l_{n,j}.\end{split} (21)

Eq. (7) can be rewritten as

\begin{split}\lambda_{n}&=\lambda_{n+1}-h_{n}\sum_{i=1}^{s}\tilde{b}_{i}l_{n,i},\\ l_{n,i}&=-\frac{\partial f}{\partial x}(X_{n,i},t_{n}+c_{i}h_{n})^{\top}\Lambda_{n,i},\\ \Lambda_{n,i}&=\begin{cases}\lambda_{n+1}-h_{n}\sum_{j=1}^{s}\tilde{b}_{j}\frac{a_{j,i}}{b_{i}}l_{n,j}&\mbox{if}\ \ i\not\in I_{0}\\ -\sum_{j=1}^{s}\tilde{b}_{j}a_{j,i}l_{n,j}&\mbox{if}\ \ i\in I_{0}.\\ \end{cases}\\ \end{split} (22)

Because a_{i,j} = 0 for j ≥ i, it holds that a_{j,i} = 0 for j ≤ i. Hence, Λ_{n,i} depends only on l_{n,j} with j > i, and the intermediate adjoint variables Λ_{n,i} can be computed sequentially from i = s down to i = 1; that is, the integration backward in time is explicit.
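A minimal sketch of one backward step of Eq. (21) is given below, assuming PyTorch, an explicit tableau (a, b) whose weights b_i are all nonzero, and forward stage values X_{n,i} that have been recomputed (e.g., from a retained x_n). The function name and calling convention are illustrative, not the authors' implementation.

import torch

def adjoint_step(f, X, t_stages, lam_next, h, a, b):
    # One backward step of Eq. (21).  Because a[j][i] = 0 for j <= i, the stage
    # value Lambda_{n,i} uses only the already computed l[j] with j > i, so the
    # loop from i = s-1 down to 0 is explicit.
    s = len(b)
    l = [None] * s
    lam = lam_next.clone()
    for i in reversed(range(s)):
        Lam = lam_next - h * sum(b[j] * (a[j][i] / b[i]) * l[j]
                                 for j in range(i + 1, s))
        Xi = X[i].detach().requires_grad_(True)
        fx = f(Xi, t_stages[i])
        # l_{n,i} = -(df/dx)^T Lambda_{n,i}, computed as a vector-Jacobian product.
        (vjp,) = torch.autograd.grad(fx, Xi, grad_outputs=Lam)
        l[i] = -vjp
        lam = lam - h * b[i] * l[i]      # lambda_n = lambda_{n+1} - h sum_i b_i l_{n,i}
    return lam

The case b_i = 0 (i.e., i in I_0) follows Eq. (22) instead and is omitted here for brevity.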

Appendix C Gradients in General Cases

C.1 Gradient w.r.t. Parameters

To adjust the parameters, one can regard the parameters θ as part of the augmented state x̃ = [x θ]^⊤ of the system

\frac{{\mathrm{d}}}{\mathrm{d}t}\tilde{x}=\tilde{f}(\tilde{x},t),\ \tilde{f}(\tilde{x},t)=\begin{bmatrix}f(x,t,\theta)\\ 0\end{bmatrix},\ \ \tilde{x}(0)=\begin{bmatrix}x_{0}\\ \theta\end{bmatrix}. (23)

The variational and adjoint variables are augmented in the same way. For the augmented adjoint variable λ̃ = [λ λ_θ]^⊤, the augmented adjoint system is

\frac{{\mathrm{d}}}{\mathrm{d}t}\tilde{\lambda}=-\frac{\partial\tilde{f}}{\partial\tilde{x}}(\tilde{x},t)^{\top}\tilde{\lambda}=-\begin{bmatrix}\frac{\partial f}{\partial x}^{\top}&0\\ \frac{\partial f}{\partial\theta}^{\top}&0\end{bmatrix}\begin{bmatrix}\lambda\\ \lambda_{\theta}\end{bmatrix}=\begin{bmatrix}-\frac{\partial f}{\partial x}^{\top}\lambda\\ -\frac{\partial f}{\partial\theta}^{\top}\lambda\end{bmatrix}. (24)

Hence, the adjoint variable λ for the system state x is unchanged from Eq. (3), and the adjoint variable λ_θ for the parameters θ evolves according to λ as

\frac{{\mathrm{d}}}{\mathrm{d}t}\lambda_{\theta}=-\frac{\partial f}{\partial\theta}(x,t,\theta)^{\top}\lambda, (25)

and λ_θ(T) = (∂L(x(T),θ)/∂θ)^⊤.
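In an implementation, the right-hand side of Eq. (25) can be evaluated without forming ∂f/∂θ explicitly, as a vector-Jacobian product. The following is a minimal sketch assuming PyTorch, where f is a torch.nn.Module whose forward computes f(x, t); the function name is illustrative.

import torch

def param_adjoint_rhs(f, x, t, lam):
    # d(lambda_theta)/dt = -(df/dtheta)^T lambda, via one reverse-mode pass.
    params = [p for p in f.parameters() if p.requires_grad]
    fx = f(x, t)
    vjps = torch.autograd.grad(fx, params, grad_outputs=lam, allow_unused=True)
    return [-v if v is not None else torch.zeros_like(p)
            for p, v in zip(params, vjps)]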

C.2 Gradient of Functional

When the solution x(t) is evaluated by a functional 𝒞 as

\mathcal{C}(x(t))=\int_{0}^{T}\mathcal{L}(x(t),t){\mathrm{d}t}, (26)

the adjoint variable λ_C(t) = (∂𝒞/∂x(t))^⊤, which represents the gradient of the functional 𝒞 with respect to the state x(t), is given by

\frac{{\mathrm{d}}}{\mathrm{d}t}\lambda_{C}=-\frac{\partial f}{\partial x}(x,t)^{\top}\lambda_{C}-\left(\frac{\partial\mathcal{L}(x(t),t)}{\partial x(t)}\right)^{\!\top},\ \ \lambda_{C}(T)=\mathbf{0}. (27)
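The right-hand side of Eq. (27) can be evaluated with vector-Jacobian products in the same way as Eq. (25); the following is a minimal sketch assuming PyTorch, where running_cost(x, t) returns the scalar L(x, t) and all names are illustrative.

import torch

def functional_adjoint_rhs(f, running_cost, x, t, lam_C):
    # d(lambda_C)/dt = -(df/dx)^T lambda_C - (dL/dx)^T, as in Eq. (27).
    x = x.detach().requires_grad_(True)
    (vjp_f,) = torch.autograd.grad(f(x, t), x, grad_outputs=lam_C)
    (grad_L,) = torch.autograd.grad(running_cost(x, t), x)
    return -vjp_f - grad_L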

Appendix D Implementation Details

D.1 Robustness to Rounding Error

By definition, the naive backpropagation algorithm, the baseline scheme, ACA, and the proposed symplectic adjoint method all provide the exact gradient up to rounding error. Nonetheless, the naive backpropagation algorithm and the baseline scheme obtained slightly worse results on the GAS, POWER, and HEPMASS datasets. Because the neural network is used repeatedly, each method accumulates the gradient with respect to the parameters θ over all uses. Let θ_{n,i} denote the parameters used at the i-th stage of the n-th step, even though their values are unchanged. The backpropagation algorithm obtains the gradient ∂L/∂θ with respect to the parameters θ by accumulating the per-stage gradients over all stages and steps one by one as

\begin{split}\frac{\partial\mathcal{L}}{\partial\theta}&=\sum_{\begin{subarray}{c}n=0,\dots,N-1,\\ i=1,\dots,s\end{subarray}}\frac{\partial\mathcal{L}}{\partial\theta_{n,i}}.\end{split} (28)

When the step size h_n at the n-th step is sufficiently small, the gradient ∂L/∂θ_{n,i} at the i-th stage may be insignificant compared with the accumulated gradient and be rounded off during the accumulation.

In contrast, ACA accumulates the gradient within each step and then over the time steps; this process can be expressed informally as

\begin{split}\frac{\partial\mathcal{L}}{\partial\theta}&=\sum_{n=0}^{N-1}\left(\sum_{i=1}^{s}\frac{\partial\mathcal{L}}{\partial\theta_{n,i}}\right).\end{split} (29)

Further, according to Eqs. (6) and (25), the (symplectic) adjoint method accumulates the adjoint variable λ_θ (i.e., the transpose of the parameter gradient) within each step and then over the time steps as

\lambda_{\theta,n}=\lambda_{\theta,n+1}-h_{n}\left(\sum_{i=1}^{s}B_{i}\left(-\frac{\partial f}{\partial\theta_{n,i}}(X_{n,i},t+C_{i}h_{n},\theta_{n,i})^{\top}\Lambda_{n,i}\right)\right). (30)

In these cases, even when the step size h_n at the n-th step is small, the gradient summed within a step (over the s stages) may still be significant and hence robust to rounding errors. This explains why the adjoint method, ACA, and the symplectic adjoint method performed better than the naive backpropagation algorithm and the baseline scheme on some datasets. Note that this approach requires additional memory to store the gradient summed within a step, and it is applicable to the backpropagation algorithm with a slight modification.
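The effect can be reproduced with a toy float32 accumulation (illustrative numbers only, not the actual gradient computation): when each per-stage contribution is smaller than half a unit in the last place of the accumulated value, stage-by-stage accumulation as in Eq. (28) discards it entirely, whereas summing within a step first, as in Eqs. (29) and (30), retains most of it.

import numpy as np

N, s = 10_000, 4                 # number of steps and stages (illustrative)
g_stage = np.float32(4e-8)       # tiny per-stage contribution
exact = 1.0 + N * s * float(g_stage)

flat = np.float32(1.0)           # stage-by-stage accumulation (Eq. (28))
for _ in range(N * s):
    flat += g_stage              # each addition is rounded away in float32

blocked = np.float32(1.0)        # sum within a step, then over steps (Eq. (29))
for _ in range(N):
    blocked += np.float32(s) * g_stage

print(flat - exact, blocked - exact)   # the blocked sum loses noticeably less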

D.2 Memory Consumption Optimization

Following Eqs. (21) and (22), a naive implementation of these integrators retains the adjoint variables Λ_{n,i} at all stages i = 1, ..., s to obtain their time-derivatives l_{n,i}, and then adds them up to obtain the adjoint variable λ_n at the n-th time step. However, as Eq. (25) shows, the adjoint variable λ_θ for the parameters θ is not needed for obtaining its own time-derivative dλ_θ/dt. One can therefore add up the adjoint variable {Λ_θ}_{n,i} for the parameters θ at each stage i one by one without retaining it, thereby reducing the memory consumption by an amount proportional to the number of parameters times the number of stages. A similar optimization is applicable to the adjoint method.
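A minimal sketch of this optimization, assuming PyTorch, is given below: the per-stage contributions to λ_θ are added into a single buffer per parameter as soon as they are computed, following Eq. (30), instead of being retained for all s stages. The names, the calling convention, and the assumption that the stage values X_{n,i} and Λ_{n,i} are already available are illustrative.

import torch

def param_adjoint_increment(f, X, t_stages, Lambdas, h, B):
    # Returns h * sum_i B_i (df/dtheta)^T Lambda_{n,i}, i.e., the per-step
    # increment of Eq. (30), accumulated in place without retaining the
    # per-stage gradients.
    params = [p for p in f.parameters() if p.requires_grad]
    incr = [torch.zeros_like(p) for p in params]
    for Xi, ti, Bi, Lam in zip(X, t_stages, B, Lambdas):
        fx = f(Xi, ti)
        vjps = torch.autograd.grad(fx, params, grad_outputs=Lam, allow_unused=True)
        for buf, v in zip(incr, vjps):
            if v is not None:
                buf.add_(h * Bi * v)   # accumulate immediately; v can then be freed
    return incr

The adjoint variable λ_{θ,n} is then obtained by adding this increment to λ_{θ,n+1}.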

Table A1: Results on learning physical systems without the deterministic option.

                    | KdV Equation                        | Cahn–Hilliard System
                    | MSE (×10⁻³)  mem.         time      | MSE (×10⁻⁶)  mem.          time
adjoint method [2]  | 1.61±3.23    181.4±0.0    240±16    | 5.58±2.12    181.4±0.0     805±25
backpropagation [2] | 1.61±3.24    733.9±15.6   94±4      | 5.45±1.55    3053.5±22.9   382±11
ACA [46]            | 1.61±3.24    734.5±20.3   120±4     | 6.00±3.27    780.4±22.9    422±16
proposed            | 1.61±3.58    182.1±0.0    141±7     | 5.48±1.90    182.1±0.0     480±19

Mean-squared errors (MSEs) in long-term predictions, peak memory consumption (mem.) [MiB], and computation time per iteration [ms/itr].

D.3 Parallelization

The memory consumption and computation time depend strongly on the implementation and device. When implemented on a GPU, the convolution operation is easily parallelized in space and can exhibit non-deterministic behavior. To avoid this non-deterministic behavior, PyTorch provides the option torch.backends.cudnn.deterministic, which was used to obtain the results in Section 5.2, following the original implementation [31]. Without this option, the memory consumption increased by a certain amount, and the computation time decreased owing to more aggressive parallelization, as shown by the results in Table A1. Even then, the proposed symplectic adjoint method occupied the smallest memory among the methods that provide the exact gradient. The increase in memory consumption is proportional to the width of the neural network; therefore, it is negligible when the neural network is sufficiently deep.
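For reference, the deterministic behavior mentioned above can be requested as follows (a minimal sketch assuming PyTorch with cuDNN; the exact memory and speed trade-off depends on the device and version).

import torch

torch.backends.cudnn.deterministic = True   # force deterministic convolution kernels
torch.backends.cudnn.benchmark = False      # disable cuDNN auto-tuning of convolution algorithms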

Note that the results in Section 5.1 were obtained without the deterministic option.

Table A2: Results obtained for continuous normalizing flows.

                    | MINIBOONE (M=1)                  | GAS (M=5)                          | POWER (M=5)
                    | NLL         mem.       time      | NLL          mem.        time      | NLL         mem.           time
adjoint method [2]  | 10.59±0.17  170±0      0.74±0.04 | -10.53±0.25  24±0        4.82±0.29 | -0.31±0.01  8.1±0.0        6.33±0.18
backpropagation [2] | 10.54±0.18  4,436±115  0.91±0.05 | -9.53±0.42   4,479±250  12.00±0.93 | -0.24±0.05  1710.9±193.1  10.64±2.73
baseline scheme     | 10.54±0.18  4,457±115  1.10±0.04 | -9.53±0.42   1,858±228   5.48±0.25 | -0.24±0.05  515.2±122.0    4.37±0.70
ACA [46]            | 10.57±0.30  306±0      0.77±0.02 | -10.65±0.45  73±0        3.98±0.14 | -0.31±0.02  29.5±0.5       5.08±0.88
proposed            | 10.49±0.11  95±0       0.84±0.03 | -10.89±0.11  20±0        4.39±0.23 | -0.31±0.02  9.2±0.0        5.73±0.43

                    | HEPMASS (M=10)                    | BSDS300 (M=2)                       | MNIST (M=6)
                    | NLL         mem.       time       | NLL           mem.       time       | NLL          mem.       time
adjoint method [2]  | 16.49±0.25  40±0        4.19±0.15 | -152.04±0.09  577±0      11.70±0.44 | 0.918±0.011  1,086±4    10.12±0.88
backpropagation [2] | 17.03±0.22  5,254±137  11.82±1.33 |                                     |
baseline scheme     | 17.03±0.22  1,102±174   4.40±0.40 |                                     |
ACA [46]            | 16.41±0.39  88±0        3.67±0.12 | -151.27±0.47  757±1       6.97±0.25 | 0.919±0.003  4,332±1     7.94±0.63
proposed            | 16.48±0.20  35±0        4.15±0.13 | -151.17±0.15  283±2       8.07±0.72 | 0.917±0.002  1,079±1     9.42±0.32

Negative log-likelihoods (NLL), peak memory consumption (mem.) [MiB], and computation time per iteration [s/itr]. The medians ± standard deviations of three runs.