
Law of Balance and Stationary Distribution of Stochastic Gradient Descent

Liu Ziyin Hongchao Li Masahito Ueda
Abstract

The stochastic gradient descent (SGD) algorithm is the workhorse algorithm for training neural networks. However, it remains poorly understood how SGD navigates the highly nonlinear and degenerate loss landscape of a neural network. In this work, we prove that the minibatch noise of SGD regularizes the solution towards a balanced solution whenever the loss function contains a rescaling symmetry. Because the difference between a simple diffusion process and SGD dynamics is most significant when symmetries are present, our theory implies that the symmetries of the loss function constitute an essential probe of how SGD works. We then apply this result to derive the stationary distribution of stochastic gradient flow for a diagonal linear network with arbitrary depth and width. The stationary distribution exhibits complicated nonlinear phenomena such as phase transitions, broken ergodicity, and fluctuation inversion. These phenomena are shown to exist uniquely in deep networks, implying a fundamental difference between deep and shallow models.

1 Introduction

The stochastic gradient descent (SGD) algorithm is defined as

$\Delta\theta_{t}=-\frac{\eta}{S}\sum_{x\in B}\nabla_{\theta}\ell(\theta,x)$, (1)

where $\theta$ is the model parameter and $\ell(\theta,x)$ is a per-sample loss whose expectation over $x$ gives the training loss. $B$ is a randomly sampled minibatch of data points, each independently sampled from the training set, and $S$ is the minibatch size. The training-set average of $\ell$ is the training objective $L(\theta)=\mathbb{E}_{x}\ell(\theta,x)$, where $\mathbb{E}_{x}$ denotes averaging over the training set. Two aspects make the algorithm difficult to understand: (1) its dynamics is discrete in time, and (2) the randomness is highly nonlinear and parameter-dependent. This work relies on the continuous-time approximation and deals with the second aspect.
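To make the notation concrete, here is a minimal NumPy sketch of the update (1) (our own illustration, not code from the paper; the quadratic per-sample loss and all names are placeholder choices):

```python
import numpy as np

def sgd_step(theta, xs, ys, eta, S, rng):
    """One update of Eq. (1): average the per-sample gradient of the
    (assumed) loss ell(theta, x) = (theta*x - y)^2 over a minibatch B
    of size S drawn i.i.d. (with replacement) from the training set."""
    idx = rng.integers(0, len(xs), size=S)                      # minibatch B
    grad = np.mean(2 * (theta * xs[idx] - ys[idx]) * xs[idx])   # (1/S) sum of grads
    return theta - eta * grad

rng = np.random.default_rng(0)
xs = rng.normal(size=1000)
ys = 0.5 * xs + 0.1 * rng.normal(size=1000)
theta = 0.0
for _ in range(2000):
    theta = sgd_step(theta, xs, ys, eta=0.05, S=8, rng=rng)
print(theta)  # approaches the minimizer E[xy]/E[x^2], here about 0.5
```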

In natural and social sciences, the most important object of study of a stochastic system is its stationary distribution, which often offers fundamental insight into the underlying stochastic process [37, 31]. Arguably, a great deal of insight into SGD could be obtained from an analytical understanding of its stationary distribution, which has remained unknown to date. Most existing works study the dynamics and stationary properties of SGD for a strongly convex loss function [43, 44, 20, 46, 27, 52, 21, 42]. The works that touch on the nonlinear aspects of the loss function rely heavily on local approximations of the stationary distribution of SGD close to a local minimum, often with additional unrealistic assumptions about the noise. For example, using a saddle point expansion and assuming that the noise is parameter-independent, Refs. [23, 44, 20] showed that the stationary distribution of SGD is exponential. Taking partial parameter-dependence into account and working near an interpolation minimum, Ref. [27] showed that the stationary distribution is power-law like and proportional to $L(\theta)^{-c_{0}}$ for some constant $c_{0}$. However, the stationary distribution of SGD is unknown when the loss function is beyond quadratic and high-dimensional.

Since the stationary distribution of SGD is unknown, we will compare our results with the most naive theory one can construct for SGD, a continuous-time Langevin equation with a constant noise level:

$\dot{\theta}(t)=-\eta\nabla_{\theta}L(\theta)+\sqrt{2T_{0}}\epsilon(t)$, (2)

where $\epsilon$ is a random time-dependent noise with zero mean and $\mathbb{E}[\epsilon(t)\epsilon(t^{\prime})^{T}]=\eta\delta(t-t^{\prime})I$, with $I$ the identity operator. Here, the naive theory relies on the assumption that one can find a constant scalar $T_{0}$ such that Eq. (2) closely models Eq. (1), at least after some level of coarse-graining. Let us examine some of the predictions of this model to understand when and why it goes wrong.

There are two important predictions of this model. The first is that the stationary distribution of SGD is a Gibbs distribution with temperature $T_{0}$: $p(\theta)\propto\exp[-L(\theta)/T_{0}]$. This implies that the maximum likelihood estimator of $\theta$ under SGD is the same as the global minimizer of $L(\theta)$: $\arg\max p(\theta)=\arg\min L(\theta)$. This relation holds for the local minima as well: every local minimum of $L$ corresponds to a local maximum of $p$. These properties are often required in the popular argument that SGD approximates Bayesian inference [23, 25]. Another implication is ergodicity [39]: any state with the same energy will have an equal probability of being accessed. The second prediction is dynamical: SGD will diffuse. If there is a degenerate direction in the loss function, SGD will diffuse along that direction. (This can also be seen as a dynamical interpretation of ergodicity.)

Figure 1: SGD converges to a balanced solution. Left: the quantity $u^{2}-w^{2}$ is conserved for GD without noise, diverges for GD with isotropic Gaussian noise (which simulates the simple Langevin model), and decays to zero for SGD, a sharp and dramatic contrast. Right: illustration of the three types of dynamics. Gradient descent (GD) moves along the conservation line $u^{2}(t)-w^{2}(t)=u^{2}(0)-w^{2}(0)$. GD with isotropic Gaussian noise expands and diverges along the flat direction of the minimum valley. The actual SGD oscillates around a balanced solution.

However, these predictions of the Langevin model are not difficult to reject. Let us consider a simple two-layer network with the loss function $\ell(u,w,x)=(uwx-y(x))^{2}$. Because of the rescaling symmetry, a valley of degenerate solutions exists at $uw=c_{0}$. Under the simple Langevin model, SGD diverges to infinity due to diffusion. One can also see this from a static perspective: all points on the line $uw=c_{0}$ must have the same probability at stationarity, but such a distribution does not exist because it is not normalizable. This means that the Langevin model of SGD diverges for this loss function.

Does this agree with the empirical observation? Certainly not. (In fact, had it been the case, no linear network or ReLU network could be trained with SGD.) See Fig. 1. Contrary to the prediction of the Langevin model, $|u^{2}-w^{2}|$ converges to zero under SGD. Under GD, this quantity is conserved during training [7]. Only GD with Gaussian noise obeys the prediction of the Langevin model, as expected. This sharp contrast shows that the SGD dynamics is quite special, and a naive theoretical model can be very far from the truth in understanding its behavior. There is one more lesson to be learned: the fact that the Langevin model disagrees the most with the experiments when symmetries are present suggests that symmetry conditions are crucial tools for probing the nature of the SGD noise, which is the main topic of our theory. A minimal simulation reproducing this contrast is sketched below.
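The sketch is our own toy experiment (the data distribution, learning rate, and noise scale are arbitrary choices): the same loss $\ell(u,w,x)=(uwx-y(x))^{2}$ is trained with full-batch GD, GD plus isotropic Gaussian noise, and minibatch SGD, while tracking $u^{2}-w^{2}$.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=1000)
y = x + rng.normal(size=1000)            # y(x) with label noise, so the SGD noise is nonzero

def grad(u, w, xb, yb):
    m = np.mean((u * w * xb - yb) * xb)  # batch average of (uwx - y)x
    return 2 * w * m, 2 * u * m          # d ell/du, d ell/dw

eta, steps = 0.02, 20000
for mode in ("GD", "GD + Gaussian noise", "SGD"):
    u, w = 1.5, 0.5                      # u^2 - w^2 = 2 initially
    for _ in range(steps):
        if mode == "SGD":
            i = rng.integers(0, x.size, size=1)   # minibatch of size 1
            gu, gw = grad(u, w, x[i], y[i])
        else:
            gu, gw = grad(u, w, x, y)             # full batch
        u, w = u - eta * gu, w - eta * gw
        if mode == "GD + Gaussian noise":         # the naive Langevin model
            u += 0.03 * rng.normal()
            w += 0.03 * rng.normal()
    print(f"{mode:22s} u^2 - w^2 = {u*u - w*w:+.4f}")
# GD (approximately) conserves u^2 - w^2; isotropic Gaussian noise lets it wander;
# SGD drives it to zero, in line with the law of balance.
```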

2 Law of Balance

Now, we consider the actual continuous-time limit of SGD [16, 17, 19, 33, 9, 13]:

$d\theta=-\nabla_{\theta}L\,dt+\sqrt{TC(\theta)}\,dW_{t}$, (3)

where $dW_{t}$ is a stochastic process satisfying $dW_{t}\sim N(0,I\,dt)$ and $\mathbb{E}[dW_{t}dW_{t^{\prime}}^{T}]=\delta(t-t^{\prime})I$, and $T=\eta/S$. Evidently, $T$ gives the average noise level in the dynamics. Previous works have suggested that the ratio $\eta/S:=T$ is the main factor determining the behavior of SGD, and a higher $T$ often leads to better generalization performance [32, 20, 50]. The crucial difference between Eq. (3) and Eq. (2) is that in Eq. (3), the noise covariance $C(\theta)$ is parameter-dependent and, in general, low-rank when symmetries exist.

Due to standard architecture designs, a type of invariance, the rescaling symmetry, often appears in the loss function and exists for every sampled minibatch. The per-sample loss $\ell$ is said to have the rescaling symmetry for all $x$ if $\ell(u,w,x)=\ell(\lambda u,w/\lambda,x)$ for an arbitrary scalar $\lambda$. Note that this implies that the expected loss $L$ also has the same symmetry. This type of symmetry appears in many scenarios in deep learning. For example, it appears in any neural network with the ReLU activation. It also appears in the self-attention of transformers, often in the form of the key and query matrices [38]. When this symmetry exists between $u$ and $w$, one can prove the following result, which we refer to as the law of balance.

Theorem 1.

(Law of balance.) Let $u$ and $w$ be vectors of arbitrary dimensions. Let $\ell(u,w,x)$ satisfy $\ell(u,w,x)=\ell(\lambda u,w/\lambda,x)$ for arbitrary $x$ and $\lambda\neq 0$. Then,

$\frac{d}{dt}(||u||^{2}-||w||^{2})=-T(u^{T}C_{1}u-w^{T}C_{2}w)$, (4)

where $C_{1}=\mathbb{E}[A^{T}A]-\mathbb{E}[A^{T}]\mathbb{E}[A]$, $C_{2}=\mathbb{E}[AA^{T}]-\mathbb{E}[A]\mathbb{E}[A^{T}]$, and $A_{ki}=\frac{\partial\tilde{\ell}}{\partial(u_{i}w_{k})}$ with $\tilde{\ell}(u_{i}w_{k},x)\equiv\ell(u_{i},w_{k},x)$.

Our result holds in a stronger form if we consider the effect of a finite step size by using the modified loss function (see Appendix A.7) [2, 34]. For common problems, $C_{1}$ and $C_{2}$ are positive definite, and the theorem then implies that the norms of $u$ and $w$ will be approximately balanced. To see this, we can bound the dynamics as

$-T(\lambda_{1M}||u||^{2}-\lambda_{2m}||w||^{2})\leq\frac{d}{dt}(||u||^{2}-||w||^{2})\leq-T(\lambda_{1m}||u||^{2}-\lambda_{2M}||w||^{2})$, (5)

where $\lambda_{1m(2m)}$ and $\lambda_{1M(2M)}$ denote the minimal and maximal eigenvalues of the matrix $C_{1(2)}$, respectively. In the long-time limit, the value of $||u||^{2}/||w||^{2}$ is restricted by

$\frac{\lambda_{2m}}{\lambda_{1M}}\leq\frac{||u||^{2}}{||w||^{2}}\leq\frac{\lambda_{2M}}{\lambda_{1m}}$, (6)

which implies that the stationary dynamics of the parameters $u,w$ is constrained to a bounded subspace of the unbounded degenerate minimum valley. Conventional analysis shows that the difference between SGD and GD is of order $T^{2}$ per unit time, and it is thus often believed that SGD can be understood perturbatively through GD [13]. However, the law of balance implies that the difference between GD and SGD is not perturbative: as long as there is any level of noise, the difference between GD and SGD at stationarity is $O(1)$. This theorem has an important implication: the noise in SGD creates a qualitative difference between SGD and GD, and we must study SGD noise in its own right. The theorem also implies a loss of ergodicity, an important phenomenon in nonequilibrium physics [29, 35, 24, 36], because not all solutions with the same training loss will be accessed by SGD with equal probability.

The theorem simplifies greatly when both $u$ and $w$ are one-dimensional.

Corollary 1.

If $u,w\in\mathbb{R}$, then $\frac{d}{dt}|u^{2}-w^{2}|=-TC_{0}|u^{2}-w^{2}|$, where $C_{0}={\rm Var}[\frac{\partial\ell}{\partial(uw)}]$.

Before we apply the theorem to study stationary distributions, we stress the importance of this balance condition. The relation is closely related to Noether's theorem [26, 1, 22]. If there is no weight decay or stochasticity in training, the quantity $||u||^{2}-||w||^{2}$ is conserved under gradient flow [7, 15], as is evident from taking the infinite-$S$ limit. The fact that it monotonically decays to zero at a finite $T$ may be a manifestation of some underlying fundamental mechanism. A more recent result in Ref. [40] showed that for a two-layer linear network, the norms of the two layers are within a distance of order $O(\eta^{-1})$, suggesting that the norms of the two layers are balanced. Our result agrees with Ref. [40] in this case, but it is far stronger: it is nonperturbative, relies only on the rescaling symmetry, and is independent of the loss function or the architecture of the model.

Example: two-layer linear network.

It is instructive to illustrate the application of the law to a two-layer linear network, the simplest model that obeys the law. Let $\theta=(w,u)$ denote the set of trainable parameters; the per-sample loss is $\ell(\theta,x)=\left(\sum_{i}^{d}u_{i}w_{i}x-y\right)^{2}+\gamma||\theta||^{2}$. Here, $d$ is the width of the model, $||\theta||^{2}$ is the common $L_{2}$ regularization term that encourages the learned model to have a small norm, and $\gamma\geq 0$ is the strength of regularization; as before, $\mathbb{E}_{x}$ denotes averaging over the training set, which could be a continuous distribution or a discrete sum of delta distributions. It will be convenient to define the shorthand $v:=\sum_{i}^{d}u_{i}w_{i}$. The distribution of $v$ is said to be the distribution of the "model."

Applying the law of balance, we obtain

$\frac{d}{dt}(u_{i}^{2}-w_{i}^{2})=-4[T(\alpha_{1}v^{2}-2\alpha_{2}v+\alpha_{3})+\gamma](u_{i}^{2}-w_{i}^{2})$, (7)

where we have introduced the parameters

$\begin{cases}\alpha_{1}:={\rm Var}[x^{2}],\\ \alpha_{2}:=\mathbb{E}[x^{3}y]-\mathbb{E}[x^{2}]\mathbb{E}[xy],\\ \alpha_{3}:={\rm Var}[xy].\end{cases}$ (8)

When $\alpha_{2}^{2}-\alpha_{1}\alpha_{3}<0$ or $\gamma>0$, the time evolution of $|u_{i}^{2}-w_{i}^{2}|$ can be upper-bounded by an exponentially decreasing function of time: $|u_{i}^{2}-w_{i}^{2}|(t)<|u_{i}^{2}-w_{i}^{2}|(0)\exp\left(-4T(\alpha_{1}\alpha_{3}-\alpha_{2}^{2})t/\alpha_{1}-4\gamma t\right)\to 0$. Namely, the quantity $(u_{i}^{2}-w_{i}^{2})$ decays to $0$ with probability $1$. We thus have $u_{i}^{2}=w_{i}^{2}$ for all $i\in\{1,\cdots,d\}$ at stationarity, in agreement with Figure 1. This prediction can be checked directly against the exact discrete-time update, as sketched below.
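For a width-1 network, a short computation gives the discrete-time analogue of Eq. (7): one SGD step yields $u'^{2}-w'^{2}=(u^{2}-w^{2})(1-4\eta^{2}m^{2})$, where $m$ is the minibatch average of $(vx-y)x$, and $\mathbb{E}[m^{2}]=(v\beta_{1}-\beta_{2})^{2}+(\alpha_{1}v^{2}-2\alpha_{2}v+\alpha_{3})/S$. The following sketch (our own consistency check, not from the paper; all numbers are placeholders) verifies this identity by Monte Carlo:

```python
import numpy as np

rng = np.random.default_rng(1)
N, S, eta = 100_000, 4, 0.05
x = rng.normal(size=N)
y = 0.7 * x + 0.2 * rng.normal(size=N)

u, w = 1.2, 0.4
v = u * w
# data moments entering Eq. (8) and the landscape parameters
a1, a3 = np.var(x**2), np.var(x*y)
b1, b2 = np.mean(x**2), np.mean(x*y)
a2 = np.mean(x**3 * y) - b1 * b2                     # cov(x^2, xy)

# Monte Carlo: average u'^2 - w'^2 over many independent minibatches
vals = []
for _ in range(20000):
    i = rng.integers(0, N, size=S)
    m = np.mean((v * x[i] - y[i]) * x[i])            # minibatch average of (vx - y)x
    gu, gw = 2 * w * m, 2 * u * m                    # minibatch gradients
    vals.append((u - eta * gu)**2 - (w - eta * gw)**2)
mc = np.mean(vals)

# exact one-step expectation: (u^2 - w^2) * (1 - 4 eta^2 (mean^2 + var/S))
pred = (u**2 - w**2) * (1 - 4 * eta**2 * ((v*b1 - b2)**2 + (a1*v**2 - 2*a2*v + a3) / S))
print(f"{mc:.6f} {pred:.6f}")   # the two agree up to Monte-Carlo error
```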

3 Stationary Distribution of SGD

As an important application of the law of balance, we solve the stationary distribution of SGD for a deep diagonal linear network. While linear networks are limited in expressivity, their loss landscape and dynamics are highly nonlinear, and they are regarded as minimal models of nonlinear neural networks [14, 18, 48, 41].

3.1 Depth-0 Case

Let us first derive the stationary distribution of a one-dimensional linear regressor, which will serve as a baseline to help us understand what is unique about having "depth" in deep learning. The per-sample loss is $\ell(x,v)=(vx-y)^{2}+\gamma v^{2}$, for which the SGD dynamics is $dv=-2(\beta_{1}v-\beta_{2}+\gamma v)dt+\sqrt{TC(v)}\,dW(t)$, where we have defined

$\begin{cases}\beta_{1}:=\mathbb{E}[x^{2}],\\ \beta_{2}:=\mathbb{E}[xy].\end{cases}$ (9)

Note that the closed-form solution of linear regression gives the global minimizer of the loss function: $v^{*}=\beta_{2}/\beta_{1}$. The gradient variance is also nontrivial: $C(v):={\rm Var}[\partial\ell(v,x)/\partial v]=4(\alpha_{1}v^{2}-2\alpha_{2}v+\alpha_{3})$. Note that the loss landscape $L$ depends only on $\beta_{1}$ and $\beta_{2}$, while the gradient noise depends only on $\alpha_{1}$, $\alpha_{2}$, and $\alpha_{3}$. These relations imply that $C$ can be quite independent of $L$, contrary to popular beliefs in the literature [27, 23]. It is thus reasonable to call $\beta$ the landscape parameters and $\alpha$ the noise parameters. We will see that both $\beta$ and $\alpha$ appear in all the stationary distributions we derive, implying that the stationary distributions of SGD depend strongly on the data.

Another important quantity is $\Delta:=\alpha_{1}\alpha_{3}-\alpha_{2}^{2}\geq 0$, which sets the minimal level of noise on the landscape: $\min_{v}C(v)=4\Delta/\alpha_{1}$. For all the examples in this work,

$\Delta={\rm Var}[x^{2}]{\rm Var}[xy]-{\rm cov}^{2}(x^{2},xy)=\alpha_{1}\alpha_{3}-\alpha_{2}^{2}$. (10)

When is $\Delta$ zero? It vanishes when, for all samples $(x,y)$, $xy+c=kx^{2}$ for some constants $k$ and $c$. We focus on the case $\Delta>0$ in the main text, which is the most likely case in practice. The other cases are dealt with in Appendix A.

For $\Delta>0$, the stationary distribution for linear regression is found to be

$p(v)\propto(\alpha_{1}v^{2}-2\alpha_{2}v+\alpha_{3})^{-1-\frac{\beta_{1}^{\prime}}{2T\alpha_{1}}}\exp\left[-\frac{1}{T}\frac{\alpha_{2}\beta_{1}^{\prime}-\alpha_{1}\beta_{2}}{\alpha_{1}\sqrt{\Delta}}\arctan\left(\frac{\alpha_{1}v-\alpha_{2}}{\sqrt{\Delta}}\right)\right]$, (11)

where $\beta_{1}^{\prime}:=\beta_{1}+\gamma$, roughly in agreement with the result in Ref. [27]. Two notable features of this distribution stand out: (1) the power exponent of the tail depends on the learning rate and batch size, and (2) the integral of $p(v)$ converges for an arbitrary learning rate. On the one hand, this implies that increasing the learning rate alone cannot introduce new phases of learning to linear regression; on the other hand, the expected error diverges as one increases the learning rate (or the feature variation), which happens at $T=\beta_{1}^{\prime}/\alpha_{1}$. We will see that deeper models differ from the single-layer model in these two crucial aspects. As a sanity check, Eq. (11) can be compared with a direct simulation, as sketched below.
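The sketch below is our own (all hyperparameters are arbitrary choices): it runs SGD on $\ell=(vx-y)^{2}$ with $\gamma=0$ and compares the histogram of the trajectory with the numerically normalized density of Eq. (11). Agreement is only approximate, since the finite-step-size chain merely approximates stochastic gradient flow.

```python
import numpy as np

rng = np.random.default_rng(1)
x = rng.normal(size=20_000)
y = 0.5 * x + rng.normal(size=20_000)
eta, S, steps = 0.02, 1, 400_000
T = eta / S

b1, b2 = np.mean(x**2), np.mean(x*y)
a1, a3 = np.var(x**2), np.var(x*y)
a2 = np.mean(x**3 * y) - b1 * b2
Delta = a1 * a3 - a2**2

v, traj = 0.0, np.empty(steps)           # SGD on ell = (vx - y)^2
for t in range(steps):
    i = rng.integers(x.size)
    v -= eta * 2 * (v * x[i] - y[i]) * x[i]
    traj[t] = v

# unnormalized Eq. (11) with gamma = 0 (so beta1' = beta1), normalized on a grid
grid = np.linspace(traj.min(), traj.max(), 500)
logp = (-1 - b1 / (2*T*a1)) * np.log(a1*grid**2 - 2*a2*grid + a3) \
       - (a2*b1 - a1*b2) / (T*a1*np.sqrt(Delta)) * np.arctan((a1*grid - a2)/np.sqrt(Delta))
p = np.exp(logp - logp.max())
p /= np.trapz(p, grid)

hist, edges = np.histogram(traj[steps//10:], bins=50, density=True)
centers = 0.5 * (edges[:-1] + edges[1:])
print(np.max(np.abs(hist - np.interp(centers, grid, p))))  # small if the theory holds
```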

3.2 Deep Diagonal Networks

Now, we consider a diagonal deep linear network, whose loss function can be written as

$\ell=\left[\sum_{i}^{d_{0}}\left(\prod_{k=0}^{D}u_{i}^{(k)}\right)x-y\right]^{2}$, (12)

where $D$ is the depth and $d_{0}$ is the width. When the width $d_{0}=1$, the law of balance is sufficient to solve the model. When $d_{0}>1$, we need to eliminate additional degrees of freedom. Many recent works have studied the properties of diagonal linear networks, which have been found to approximate the dynamics of real networks well [30, 28, 3, 8].

We introduce $v_{i}:=\prod_{k=0}^{D}u_{i}^{(k)}$, so that $v=\sum_{i}v_{i}$, and call $v_{i}$ a "subnetwork" and $v$ the "model." The following theorem shows that the dynamics of this model can be reduced to a one-dimensional form.

Theorem 2.

For all $i\neq j$, one (or more) of the following conditions holds for all trajectories at stationarity:

  1. $v_{i}=0$, or $v_{j}=0$, or $L(\theta)=0$;

  2. ${\rm sgn}(v_{i})={\rm sgn}(v_{j})$. In addition,

     (a) if $D=1$, $\log|v_{i}|-\log|v_{j}|=c_{0}$ for some constant $c_{0}$;

     (b) if $D>1$, $|v_{i}|^{2}-|v_{j}|^{2}=0$.

This theorem has many interesting aspects. First, the three situations in item 1 directly tell us the distribution of $v$, which is the quantity we ultimately care about. ($L\to 0$ is only possible when $\Delta=0$ and $v=\beta_{2}/\beta_{1}$.) This means that to understand the stationary distribution of SGD, we only need to solve the case of item 2. Once the parameters enter the condition of item 2, item 2 continues to hold with probability $1$ for the rest of the trajectory. Second, item 2 implies that all the $v_{i}$ of the model must share the same sign for any network with $D\geq 1$: no subnetwork of the original network can learn an incorrect sign. This is dramatically different from the case of $D=0$; we discuss this point in more detail below. Third, the theorem implies that the dynamics of SGD is qualitatively different for different depths of the model. In particular, $D=1$ and $D>1$ have entirely different dynamics. For $D=1$, the ratio between every pair $v_{i}$ and $v_{j}$ is a conserved quantity. In sharp contrast, for $D>1$, the distance between different $v_{i}$ is no longer conserved but decays to zero: a new balancing condition emerges as we increase the depth. This qualitative distinction corroborates the discoveries in Refs. [48] and [51], where $D=1$ models are found to be qualitatively different from models with $D>1$.

With this theorem, we are now ready to solve for the stationary distribution. It suffices to condition on the event that $v_{i}$ does not converge to zero. Suppose that there are $d$ nonzero $v_{i}$ that obey item 2 of Theorem 2; $d$ can be seen as an effective width of the model. We stress that the effective width $d\leq d_{0}$ depends on the initialization and can be arbitrary. (One can systematically initialize the parameters so that $d$ takes any desired value between $1$ and $d_{0}$; for example, one can initialize on the stationary conditions at the desired value of $d$.) We therefore condition on a fixed value of $d$ to solve for the stationary distribution of $v$ (Appendix A):

$p_{\pm}(|v|)\propto\frac{1}{|v|^{3(1-1/(D+1))}(\alpha_{1}|v|^{2}\mp 2\alpha_{2}|v|+\alpha_{3})}\exp\left(-\frac{1}{T}\int_{0}^{|v|}d|v|\,\frac{d^{1-2/(D+1)}(\beta_{1}|v|\mp\beta_{2})}{(D+1)|v|^{2D/(D+1)}(\alpha_{1}|v|^{2}\mp 2\alpha_{2}|v|+\alpha_{3})}\right)$, (13)

where $p_{-}$ is the distribution of $v$ on $(-\infty,0)$ and $p_{+}$ is that on $(0,\infty)$. Next, we analyze this distribution in detail. Since the result is symmetric in the sign of $\beta_{2}=\mathbb{E}[xy]$, we assume $\mathbb{E}[xy]>0$ from now on.

3.2.1 Depth-1 Nets

Figure 2: Stationary distributions of SGD for linear regression ($D=0$) and a two-layer network ($D=1$) across different $T=\eta/S$: $T=0.05$ (left) and $T=0.5$ (middle). For $D=1$, the stationary distribution is strongly affected by the choice of the learning rate. In contrast, for $D=0$, the stationary distribution is always centered at the global minimizer of the loss function, and the choice of the learning rate only affects the thickness of the tail. Right: the stationary distribution of a one-layer tanh model, $f(x)=\tanh(vx)$ ($D=0$), and a two-layer tanh model, $f(x)=w\tanh(ux)$ ($D=1$). For $D=1$, we define $v:=wu$. The vertical line shows the ground truth. The deeper model never learns the wrong sign of $wu$, whereas the shallow model can.

We focus on the case $\gamma=0$. (When weight decay is present, the stationary distribution is the same, except that one needs to replace $\beta_{2}$ with $\beta_{2}-\gamma$.) Other cases are studied in detail in Appendix A and listed in Table 1. The distribution of $v$ is

$p_{\pm}(|v|)\propto\frac{|v|^{\pm\beta_{2}/2\alpha_{3}T-3/2}}{(\alpha_{1}|v|^{2}\mp 2\alpha_{2}|v|+\alpha_{3})^{1\pm\beta_{2}/4T\alpha_{3}}}\exp\left(-\frac{1}{2T}\frac{\alpha_{3}\beta_{1}-\alpha_{2}\beta_{2}}{\alpha_{3}\sqrt{\Delta}}\arctan\frac{\alpha_{1}|v|\mp\alpha_{2}}{\sqrt{\Delta}}\right)$. (14)

This measure is worth a close examination. First, the exponential term is bounded above and below and well-behaved in all situations. In contrast, the polynomial term dominates both at infinity and close to zero. For $v<0$, the distribution is a delta function at zero: $p(v)=\delta(v)$. To see this, note that the term $|v|^{-\beta_{2}/2\alpha_{3}T-3/2}$ integrates to $|v|^{-\beta_{2}/2\alpha_{3}T-1/2}$ close to the origin, which is infinite, while the integral away from the origin is finite. This signals that the only possible stationary distribution has zero measure for $v\neq 0$. The stationary distribution is thus a delta distribution, meaning that if $x$ and $y$ are positively correlated, the learned subnets $v_{i}$ can never be negative, no matter the initial configuration.

For $v>0$, the distribution is nontrivial. Close to $v=0$, the distribution is dominated by $v^{\beta_{2}/2\alpha_{3}T-3/2}$, which integrates to $v^{\beta_{2}/2\alpha_{3}T-1/2}$. This is finite only below a critical temperature $T_{c}=\beta_{2}/\alpha_{3}$: a phase-transition-like behavior. As $T\to T_{c}^{-}$, the normalization diverges, and the distribution tends to a delta distribution at zero. Namely, if $T>T_{c}$, we have $u_{i}=w_{i}=0$ for all $i$ with probability $1$, and no learning can happen. If $T<T_{c}$, the stationary distribution has a finite variance, and learning may happen. In the more general setting, where weight decay is present, this critical $T$ shifts to

$T_{c}=\frac{\beta_{2}-\gamma}{\alpha_{3}}$. (15)

When $T=0$, the phase transition occurs at $\beta_{2}=\gamma$, in agreement with the critical point identified in Ref. [51]. This critical learning rate also agrees with the discrete-time analysis of Refs. [49, 47] and the approximate continuous-time analysis of Ref. [4]. See Figure 2 for illustrations of the distribution across different values of $T$, where we also compare with the stationary distribution of a depth-0 model. Two characteristics of the two-layer model are striking: (1) the solution becomes a delta distribution at the sparse solution $u=w=0$ at a large learning rate; (2) the two-layer model never learns the incorrect sign ($v$ is always non-negative). See Figure 2. The transition at $T_{c}$ can also be probed directly in simulation, as sketched below.
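The following sketch is our own toy experiment (a width-1, depth-1 network on synthetic, weakly correlated data; all numbers are placeholders): it estimates $T_{c}$ from the data moments via Eq. (15) and runs SGD on both sides of it.

```python
import numpy as np

rng = np.random.default_rng(3)
x = rng.normal(size=10_000)
y = 0.05 * x + rng.normal(size=10_000)     # weak correlation: small beta2, large alpha3
b1, b2, a3 = np.mean(x**2), np.mean(x*y), np.var(x*y)
Tc = b2 / a3                               # Eq. (15) with gamma = 0
print("Tc =", Tc, " global minimizer beta2/beta1 =", b2 / b1)

def run(T, steps=200_000, S=1):
    eta, (u, w) = T * S, (0.3, 0.3)
    for _ in range(steps):
        i = rng.integers(x.size)
        r = 2 * (u * w * x[i] - y[i]) * x[i]       # per-sample d ell / d v
        u, w = u - eta * r * w, w - eta * r * u
    return u * w

print("T = 0.1 Tc -> v =", run(0.1 * Tc))  # learning phase: v expected near beta2/beta1
print("T = 2.0 Tc -> v =", run(2.0 * Tc))  # collapsed phase: v expected near 0
```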

Therefore, training deeper models with SGD has two simultaneous advantages: (1) a generalization advantage, in that a sparse solution is favored when the underlying data correlation is weak; (2) an optimization advantage, in that the training loss interpolates between that of the global minimizer and that of the sparse saddle and is well-bounded (whereas a depth-0 model can reach an arbitrarily bad objective value at a large learning rate).

Another exotic phenomenon implied by the result is what we call "fluctuation inversion." Naively, the variance of the model parameters should increase as we increase $T$, the noise level of SGD. However, for the distribution we derived, the variances of $v$ and $u$ both decrease to zero as we increase $T$: injecting noise makes the model fluctuation vanish. We discuss this "fluctuation inversion" further in the next section.

Also, while there is no other phase transition below $T_{c}$, there is still an interesting and practically relevant crossover in the parameter distribution as we change the learning rate. When we train a model, we often run SGD only once or a few times. In that case, the most likely parameter we obtain is the maximum likelihood estimator of the distribution, $\hat{v}:=\arg\max p(v)$. Understanding how $\hat{v}(T)$ changes as a function of $T$ is thus crucial. This quantity exhibits nontrivial crossover behavior at critical values of $T$.

When $T<T_{c}$, a nonzero maximizer of $p(v)$ must satisfy

$v^{*}=-\frac{\beta_{1}-10\alpha_{2}T-\sqrt{(\beta_{1}-10\alpha_{2}T)^{2}+28\alpha_{1}T(\beta_{2}-3\alpha_{3}T)}}{14\alpha_{1}T}$. (16)

The existence of this solution is nontrivial, and we analyze it in Appendix A.5. When $T\to 0$, a solution always exists and is given by $v^{*}=\beta_{2}/\beta_{1}$, which does not depend on the learning rate or the noise $C$. Note that $\beta_{2}/\beta_{1}$ is also the minimum of $L(u_{i},w_{i})$. This means that in deep learning, SGD is a consistent estimator of the local minima only in the vanishing learning rate limit. How biased is SGD at a finite learning rate? Two limits can be computed. For a small learning rate, the leading-order correction to the solution is $v^{*}=\frac{\beta_{2}}{\beta_{1}}+\left(\frac{10\alpha_{2}\beta_{2}}{\beta_{1}^{2}}-\frac{7\alpha_{1}\beta_{2}^{2}}{\beta_{1}^{3}}-\frac{3\alpha_{3}}{\beta_{1}}\right)T$. This implies that the common Bayesian analysis relying on a Laplace expansion of the loss fluctuation around a local minimum is improper. The fact that the stationary distribution of SGD is very far from the Bayesian posterior also implies that SGD is a good Bayesian sampler only at a small learning rate. Both limits are easy to check numerically, as sketched below.
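In the sketch below (our own check; the moment values are arbitrary placeholders satisfying $\Delta>0$), the exact maximizer of Eq. (16) and the first-order expansion converge to $\beta_{2}/\beta_{1}$ as $T\to 0$:

```python
import numpy as np

# arbitrary data-moment parameters with Delta = a1*a3 - a2^2 > 0
a1, a2, a3 = 2.0, 0.3, 1.5
b1, b2 = 1.0, 0.5

def v_hat(T):
    """Nonzero maximizer of p(v), Eq. (16)."""
    A = b1 - 10 * a2 * T
    return -(A - np.sqrt(A**2 + 28 * a1 * T * (b2 - 3 * a3 * T))) / (14 * a1 * T)

for T in (1e-1, 1e-2, 1e-3, 1e-4):
    first_order = b2/b1 + (10*a2*b2/b1**2 - 7*a1*b2**2/b1**3 - 3*a3/b1) * T
    print(T, v_hat(T), first_order)   # both tend to beta2/beta1 = 0.5 as T -> 0
```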

Figure 3: Regimes of learning for SGD as a function of $T=\eta/S$ and the noise $\sigma$ in the dataset. According to (1) whether the sparse transition has happened, (2) whether a nontrivial maximum-probability estimator exists, and (3) whether the sparse solution is a maximum-probability estimator, the learning of SGD can be characterized into 5 regimes. In regime I, SGD converges to a sparse solution with zero variance. In regime II, the stationary distribution has a finite spread, and the probability density of the sparse solution diverges; hence, the probability of being close to the sparse solution is very high. In regime III, the probability density of the sparse solution is zero, and therefore the model will learn without much problem. In regime b, a local nontrivial probability maximum exists, and hence SGD has some probability of successful learning. In regime a, the only maximum-probability estimator is the sparse solution.

It is instructive to consider an example of a structured dataset: $y=kx+\epsilon$, where $x\sim\mathcal{N}(0,1)$ and the noise obeys $\epsilon\sim\mathcal{N}(0,\sigma^{2})$. We let $\gamma=0$ for simplicity. If $\sigma^{2}>\frac{8}{21}k^{2}$, there always exists a transitional learning rate $T^{*}=\frac{4k+\sqrt{42}\sigma}{4(21\sigma^{2}-8k^{2})}$. (We say "transitional" to distinguish it from the critical learning rate.) Obviously, $T_{c}/3<T^{*}$. One can characterize the learning of SGD by comparing $T$ with $T_{c}$ and $T^{*}$; for this simple example, SGD falls into roughly 5 different regimes. See Figure 3. A short computation of $T_{c}$ and $T^{*}$ for this dataset is sketched below.
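For this dataset, the relevant moments are available in closed form ($\beta_{2}=k$ and $\alpha_{3}=2k^{2}+\sigma^{2}$ for standard Gaussian $x$) or can be estimated from samples. A short sketch of our own:

```python
import numpy as np

k, sigma = 0.5, 1.0              # y = k x + eps, x ~ N(0,1), eps ~ N(0, sigma^2)
rng = np.random.default_rng(4)
x = rng.normal(size=1_000_000)
y = k * x + sigma * rng.normal(size=x.size)

b2, a3 = np.mean(x*y), np.var(x*y)
Tc = b2 / a3                     # Eq. (15) with gamma = 0; analytically k/(2k^2 + sigma^2)
Tstar = (4*k + np.sqrt(42)*sigma) / (4*(21*sigma**2 - 8*k**2))  # needs sigma^2 > 8k^2/21
print("Tc =", Tc, "(analytic:", k / (2*k**2 + sigma**2), ")")
print("T* =", Tstar, "; Tc/3 < T* :", Tc / 3 < Tstar)
```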

Figure 4: SGD on deep networks leads to a well-controlled distribution and training loss. Left: power law of the tail of the parameter distribution of deep linear nets. The dashed lines show the upper ($-7/2$) and lower ($-5$) bounds of the tail exponent. The predicted power-law scaling agrees with the experiment, and the exponent decreases as the theory predicts. Middle: training loss of a tanh network. $D=0$ is the case where only the input weight is trained, and $D=1$ is the case where both input and output layers are trained. For $D=0$, the model norm increases as the model loses stability. For $D=1$, a "fluctuation inversion" effect appears: the fluctuation of the model vanishes before it loses stability. Right: performance of fully connected tanh nets on MNIST. Scaling the learning rate as $1/D$ keeps the model performance relatively unchanged.

3.3 Power-Law Tail of Deeper Models

An interesting aspect of the depth-1 model is that its distribution is independent of the width $d$ of the model. This is not true for deeper models, as seen from Eq. (13): the $d$-dependent term vanishes only if $D=1$. Another intriguing aspect of the depth-1 distribution is that its tail is independent of any hyperparameter of the problem, in dramatic contrast to the linear regression case. This holds for deeper models as well.

Since $d$ only affects the non-polynomial part of the distribution, the stationary distribution scales as $p(v)\propto\frac{1}{v^{3(1-1/(D+1))}(\alpha_{1}v^{2}-2\alpha_{2}v+\alpha_{3})}$. Hence, as $v\to\infty$, the distribution scales as $v^{-5+3/(D+1)}$. The tail becomes monotonically thinner as one increases the depth: for $D=1$, the exponent is $-7/2$, while an infinite-depth network has an exponent of $-5$. Therefore, the tail of the model distribution depends only on the depth and is independent of the data or the details of training, unlike for the depth-0 model. In addition, due to the scaling $v^{-5+3/(D+1)}$ for $v\to\infty$, $\mathbb{E}[v^{2}]$ never diverges no matter how large $T$ is. See Figure 4, middle. A quick numerical check of the tail exponent is sketched below.
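The tail exponent can be read off directly from the polynomial part of Eq. (13). A quick log-log fit (our own check; the noise parameters are placeholders) confirms the $v^{-5+3/(D+1)}$ scaling:

```python
import numpy as np

a1, a2, a3 = 2.0, 0.3, 1.5           # arbitrary noise parameters
v = np.logspace(2, 4, 200)            # large-v regime
for D in (1, 2, 5, 50):
    p = v ** (-3 * (1 - 1/(D+1))) / (a1*v**2 - 2*a2*v + a3)  # polynomial part of Eq. (13)
    slope = np.polyfit(np.log(v), np.log(p), 1)[0]
    print(D, round(slope, 3), -5 + 3/(D+1))   # fitted vs predicted tail exponent
```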

One implication is that neural networks with at least one hidden layer never have a divergent training loss. This directly explains the puzzling observation of the edge-of-stability phenomenon in deep learning: SGD training often brings a neural network to a solution where a slight increment of the learning rate causes discrete-time instability and divergence [43, 5]. These solutions, quite surprisingly, exhibit low training and testing loss even when the learning rate is right at the critical learning rate of instability. This observation contradicts naive theoretical expectations. Let $\eta_{\rm sta}$ denote the largest stable learning rate. Close to a local minimum, one can expand the loss function to second order to show that the value of the loss function $L$ is proportional to ${\rm Tr}[\Sigma]$, where $\Sigma$ is the parameter covariance. However, $\Sigma\propto 1/(\eta_{\rm sta}-\eta)$ should be very large [45, 50, 20], and therefore $L$ should diverge. Thus, the edge-of-stability phenomenon is incompatible with this naive second-order expectation, as pointed out in Ref. [6]. Our theory offers a direct explanation of why the loss does not diverge: for deeper models, SGD always has a finite loss because of the power-law tail and fluctuation inversion. See Figure 4, right.

Figure 5: Loss landscape and noise covariance of a two-layer linear network with one hidden neuron and $\gamma=0.005$. The orange dashed curve shows the noise covariance $C(w,u)$ along $w=u$. The shape of the gradient noise is, in general, a more complicated function than the landscape itself.

3.4 Role of Width

As discussed, for $D>1$, the model width $d$ directly affects the stationary distribution of SGD. However, the integral in the exponent of Eq. (13) cannot be carried out analytically for generic $D$. Only two cases admit an analytical solution: $D=1$ and $D\to\infty$. We thus consider the case $D\to\infty$ to study the effect of $d$.

As $D$ tends to infinity, the distribution becomes

$p(v)\propto\frac{1}{v^{3+k_{1}}(\alpha_{1}v^{2}-2\alpha_{2}v+\alpha_{3})^{1-k_{1}/2}}\exp\left(-\frac{d}{DT}\left(\frac{\beta_{2}}{\alpha_{3}v}+\frac{\alpha_{2}\alpha_{3}\beta_{1}-2\alpha_{2}^{2}\beta_{2}+\alpha_{1}\alpha_{3}\beta_{2}}{\alpha_{3}^{2}\sqrt{\Delta}}\arctan\left(\frac{\alpha_{1}v-\alpha_{2}}{\sqrt{\Delta}}\right)\right)\right)$, (17)

where $k_{1}=d(\alpha_{3}\beta_{1}-2\alpha_{2}\beta_{2})/(TD\alpha_{3}^{2})$. The first striking feature is that the architecture ratio $d/D$ always appears together with $1/T$. This implies that for a sufficiently deep neural network, the ratio $D/d$ also becomes proportional to the strength of the noise. Since we know that $T=\eta/S$ determines the performance of SGD (scaling $\eta$ with $1/S$ is known as the learning-rate/batch-size scaling law [12]), our result implies an extended scaling law of training: $\frac{d}{D}\frac{S}{\eta}={\rm const}$. For example, if we want to scale up the depth without changing the width, we can increase the learning rate proportionally or decrease the batch size. This scaling law thus links the learning rate, the batch size, and the model width and depth. The architectural aspect of the scaling law also agrees with the suggestions in Refs. [10, 11], where the optimal architecture is found to have a constant ratio $d/D$. See Figure 4.

Now, let us fix $T$ and understand the different limits of the stationary distribution, which are decided by how $d$ scales as we scale up $D$. There are three situations: (1) $d=o(D)$, (2) $d=c_{0}D$ for a constant $c_{0}$, and (3) $d=\omega(D)$. If $d=o(D)$, then $k_{1}\to 0$ and the distribution converges to $p(v)\propto v^{-3}(\alpha_{1}v^{2}-2\alpha_{2}v+\alpha_{3})^{-1}$, which is a delta distribution at $0$. Namely, if the width is far smaller than the depth, the model collapses, and no learning happens under SGD; we should therefore increase the model width as we increase the depth. In the second case, $d/D$ is a constant that can be absorbed into the definition of $T$; this is the only limit in which we obtain a nontrivial distribution with a finite spread. If $d=\omega(D)$, a saddle point approximation shows that the distribution becomes a delta distribution at the global minimum of the loss landscape, $p(v)=\delta(v-\beta_{2}/\beta_{1})$: the learned model sits deterministically at the global minimum.

4 Discussion

The most important implication of our theory is that the behavior of SGD cannot be understood through gradient flow or a simple Langevin approximation. Even a perturbative amount of noise in SGD leads to an order-1 change in the stationary distribution of the solution. Our result suggests that one promising way to understand SGD is to study its behavior on a landscape from the viewpoint of symmetries. We showed that SGD systematically moves towards a balanced solution when the rescaling symmetry exists. Likewise, it is not difficult to imagine that for other types of symmetries, SGD will also exhibit interesting systematic deviations from gradient flow and Brownian motion. An important future direction is thus to understand and characterize the dynamics of SGD on loss functions with different types of symmetries.

By utilizing the symmetry conditions in the loss landscape, we are able to characterize the stationary distribution of SGD analytically. To the best of our knowledge, this is the first analytical expression for the stationary distribution of SGD obtained for a globally nonconvex and highly nonlinear loss function without the need for any approximation. With this solution, we have demonstrated many phenomena of deep learning that were previously unknown. For example, we showed the qualitative difference between networks of different depths, the finiteness of the training loss, the fluctuation inversion effect, the loss of ergodicity, and the incapability of learning a wrong sign for a deep model.

Lastly, let us return to the original question raised in the Introduction: why is the Gibbs measure a bad model of SGD? When the number of data points $N\gg S$, a standard computation shows that the noise covariance of SGD takes the form $C(\theta)=T(\mathbb{E}_{x}[(\nabla_{\theta}\ell)(\nabla_{\theta}\ell)^{T}]-(\nabla_{\theta}L)(\nabla_{\theta}L)^{T})$, which is nothing but the covariance of the per-sample gradients. A key feature of this noise is that it depends on the dynamical variable $\theta$ in a highly nontrivial manner. See Figure 5 for an illustration of the landscape against $C$: the shape of $C(\theta)$ generally changes faster than the loss landscape. For the Gibbs distribution to hold (at least locally), we need $C(\theta)$ to change much more slowly than $L(\theta)$. A good criterion is thus to compare the relative magnitudes of $||\nabla L||$ and $||\nabla{\rm Tr}[C]||$, which tells us which term changes faster. When $||\nabla{\rm Tr}[C]||$ is larger, unexpected phenomena will occur, and one must take the parameter dependence of $C$ into account to understand SGD. This comparison is sketched below for the two-layer example.
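The sketch is our own (the data and the evaluation point are arbitrary choices): it estimates $C(\theta)$ from per-sample gradients and compares $||\nabla L||$ with a finite-difference estimate of $||\nabla{\rm Tr}[C]||$ at a point on the minimum valley.

```python
import numpy as np

rng = np.random.default_rng(5)
x = rng.normal(size=50_000)
y = 0.5 * x + 0.5 * rng.normal(size=50_000)
T = 0.1                                   # eta / S

def grads(u, w):
    r = 2 * (u * w * x - y) * x           # per-sample d ell / d v
    return np.stack([r * w, r * u], 1)    # per-sample [d ell/du, d ell/dw]

def trace_C(u, w):
    G = grads(u, w)
    cov = G.T @ G / G.shape[0] - np.outer(G.mean(0), G.mean(0))
    return T * np.trace(cov)              # Tr C(theta), with C = T * gradient covariance

u, w = 1.0, 0.5                           # a point on the valley uw = beta2/beta1
gL = grads(u, w).mean(0)                  # gradient of the training loss L
eps = 1e-4                                # finite-difference gradient of Tr C
gTrC = np.array([(trace_C(u+eps, w) - trace_C(u-eps, w)) / (2*eps),
                 (trace_C(u, w+eps) - trace_C(u, w-eps)) / (2*eps)])
print("||grad L||     =", np.linalg.norm(gL))    # ~0 on the valley (up to sampling error)
print("||grad Tr[C]|| =", np.linalg.norm(gTrC))  # stays finite along the valley
```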

Acknowledgement

We thank Prof. Tsunetsugu for the discussion on ergodicity. We also thank Shi Chen for valuable discussions about symmetry. This work is financially supported by a research grant from JSPS (Grant No. JP22H01152).

References

  • [1] John C Baez and Brendan Fong. A noether theorem for markov processes. Journal of Mathematical Physics, 54(1):013301, 2013.
  • [2] David GT Barrett and Benoit Dherin. Implicit gradient regularization. arXiv preprint arXiv:2009.11162, 2020.
  • [3] Raphaël Berthier. Incremental learning in diagonal linear networks. Journal of Machine Learning Research, 24(171):1–26, 2023.
  • [4] Feng Chen, Daniel Kunin, Atsushi Yamamura, and Surya Ganguli. Stochastic collapse: How gradient noise attracts sgd dynamics towards simpler subnetworks. arXiv preprint arXiv:2306.04251, 2023.
  • [5] Jeremy M Cohen, Simran Kaur, Yuanzhi Li, J Zico Kolter, and Ameet Talwalkar. Gradient descent on neural networks typically occurs at the edge of stability. arXiv preprint arXiv:2103.00065, 2021.
  • [6] Alex Damian, Eshaan Nichani, and Jason D Lee. Self-stabilization: The implicit bias of gradient descent at the edge of stability. arXiv preprint arXiv:2209.15594, 2022.
  • [7] Simon S Du, Wei Hu, and Jason D Lee. Algorithmic regularization in learning deep homogeneous models: Layers are automatically balanced. Advances in neural information processing systems, 31, 2018.
  • [8] Mathieu Even, Scott Pesme, Suriya Gunasekar, and Nicolas Flammarion. (S)GD over diagonal linear networks: Implicit regularisation, large stepsizes and edge of stability. arXiv preprint arXiv:2302.08982, 2023.
  • [9] Xavier Fontaine, Valentin De Bortoli, and Alain Durmus. Convergence rates and approximation results for sgd and its continuous-time counterpart. In Mikhail Belkin and Samory Kpotufe, editors, Proceedings of Thirty Fourth Conference on Learning Theory, volume 134 of Proceedings of Machine Learning Research, pages 1965–2058. PMLR, 15–19 Aug 2021.
  • [10] Boris Hanin. Which neural net architectures give rise to exploding and vanishing gradients? Advances in neural information processing systems, 31, 2018.
  • [11] Boris Hanin and David Rolnick. How to start training: The effect of initialization and architecture. Advances in Neural Information Processing Systems, 31, 2018.
  • [12] Elad Hoffer, Itay Hubara, and Daniel Soudry. Train longer, generalize better: closing the generalization gap in large batch training of neural networks. In Advances in Neural Information Processing Systems, pages 1731–1741, 2017.
  • [13] Wenqing Hu, Chris Junchi Li, Lei Li, and Jian-Guo Liu. On the diffusion approximation of nonconvex stochastic gradient descent. arXiv preprint arXiv:1705.07562, 2017.
  • [14] Kenji Kawaguchi. Deep learning without poor local minima. Advances in Neural Information Processing Systems, 29:586–594, 2016.
  • [15] Daniel Kunin, Javier Sagastuy-Brena, Surya Ganguli, Daniel LK Yamins, and Hidenori Tanaka. Neural mechanics: Symmetry and broken conservation laws in deep learning dynamics. arXiv preprint arXiv:2012.04728, 2020.
  • [16] Jonas Latz. Analysis of stochastic gradient descent in continuous time. Statistics and Computing, 31(4):39, 2021.
  • [17] Qianxiao Li, Cheng Tai, and Weinan E. Stochastic modified equations and dynamics of stochastic gradient algorithms i: Mathematical foundations. Journal of Machine Learning Research, 20(40):1–47, 2019.
  • [18] Qianyi Li and Haim Sompolinsky. Statistical mechanics of deep linear neural networks: The backpropagating kernel renormalization. Physical Review X, 11(3):031059, 2021.
  • [19] Zhiyuan Li, Sadhika Malladi, and Sanjeev Arora. On the validity of modeling sgd with stochastic differential equations (sdes), 2021.
  • [20] Kangqiao Liu, Liu Ziyin, and Masahito Ueda. Noise and fluctuation of finite learning rate stochastic gradient descent, 2021.
  • [21] Siyuan Ma, Raef Bassily, and Mikhail Belkin. The power of interpolation: Understanding the effectiveness of sgd in modern over-parametrized learning. In International Conference on Machine Learning, pages 3325–3334. PMLR, 2018.
  • [22] Agnieszka B Malinowska and Moulay Rchid Sidi Ammi. Noether’s theorem for control problems on time scales. arXiv preprint arXiv:1406.0705, 2014.
  • [23] Stephan Mandt, Matthew D Hoffman, and David M Blei. Stochastic gradient descent as approximate bayesian inference. Journal of Machine Learning Research, 18:1–35, 2017.
  • [24] John C Mauro, Prabhat K Gupta, and Roger J Loucks. Continuously broken ergodicity. The Journal of chemical physics, 126(18), 2007.
  • [25] Chris Mingard, Guillermo Valle-Pérez, Joar Skalse, and Ard A Louis. Is sgd a bayesian sampler? well, almost. The Journal of Machine Learning Research, 22(1):3579–3642, 2021.
  • [26] Tetsuya Misawa. Noether’s theorem in symmetric stochastic calculus of variations. Journal of mathematical physics, 29(10):2178–2180, 1988.
  • [27] Takashi Mori, Liu Ziyin, Kangqiao Liu, and Masahito Ueda. Power-law escape rate of sgd. In International Conference on Machine Learning, pages 15959–15975. PMLR, 2022.
  • [28] Mor Shpigel Nacson, Kavya Ravichandran, Nathan Srebro, and Daniel Soudry. Implicit bias of the step size in linear diagonal neural networks. In International Conference on Machine Learning, pages 16270–16295. PMLR, 2022.
  • [29] Richard G Palmer. Broken ergodicity. Advances in Physics, 31(6):669–735, 1982.
  • [30] Scott Pesme, Loucas Pillaud-Vivien, and Nicolas Flammarion. Implicit bias of sgd for diagonal linear networks: a provable benefit of stochasticity. Advances in Neural Information Processing Systems, 34:29218–29230, 2021.
  • [31] Tomasz Rolski, Hanspeter Schmidli, Volker Schmidt, and Jozef L Teugels. Stochastic processes for insurance and finance. John Wiley & Sons, 2009.
  • [32] N. Shirish Keskar, D. Mudigere, J. Nocedal, M. Smelyanskiy, and P. T. P. Tang. On Large-Batch Training for Deep Learning: Generalization Gap and Sharp Minima. ArXiv e-prints, September 2016.
  • [33] Justin Sirignano and Konstantinos Spiliopoulos. Stochastic gradient descent in continuous time: A central limit theorem. Stochastic Systems, 10(2):124–151, 2020.
  • [34] Samuel L Smith, Benoit Dherin, David GT Barrett, and Soham De. On the origin of implicit regularization in stochastic gradient descent. arXiv preprint arXiv:2101.12176, 2021.
  • [35] D Thirumalai and Raymond D Mountain. Activated dynamics, loss of ergodicity, and transport in supercooled liquids. Physical Review E, 47(1):479, 1993.
  • [36] Christopher J Turner, Alexios A Michailidis, Dmitry A Abanin, Maksym Serbyn, and Zlatko Papić. Weak ergodicity breaking from quantum many-body scars. Nature Physics, 14(7):745–749, 2018.
  • [37] Nicolaas Godfried Van Kampen. Stochastic processes in physics and chemistry, volume 1. Elsevier, 1992.
  • [38] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Łukasz Kaiser, and Illia Polosukhin. Attention is all you need. Advances in neural information processing systems, 30, 2017.
  • [39] Peter Walters. An introduction to ergodic theory, volume 79. Springer Science & Business Media, 2000.
  • [40] Yuqing Wang, Minshuo Chen, Tuo Zhao, and Molei Tao. Large learning rate tames homogeneity: Convergence and balancing effect, 2022.
  • [41] Zihao Wang and Liu Ziyin. Posterior collapse of a linear latent variable model. arXiv preprint arXiv:2205.04009, 2022.
  • [42] Blake Woodworth, Kumar Kshitij Patel, Sebastian Stich, Zhen Dai, Brian Bullins, Brendan Mcmahan, Ohad Shamir, and Nathan Srebro. Is local sgd better than minibatch sgd? In International Conference on Machine Learning, pages 10334–10343. PMLR, 2020.
  • [43] Lei Wu, Chao Ma, et al. How sgd selects the global minima in over-parameterized learning: A dynamical stability perspective. Advances in Neural Information Processing Systems, 31, 2018.
  • [44] Zeke Xie, Issei Sato, and Masashi Sugiyama. A diffusion theory for deep learning dynamics: Stochastic gradient descent exponentially favors flat minima. arXiv preprint arXiv:2002.03495, 2020.
  • [45] Sho Yaida. Fluctuation-dissipation relations for stochastic gradient descent. arXiv preprint arXiv:1810.00004, 2018.
  • [46] Zhanxing Zhu, Jingfeng Wu, Bing Yu, Lei Wu, and Jinwen Ma. The anisotropic noise in stochastic gradient descent: Its behavior of escaping from sharp minima and regularization effects. arXiv preprint arXiv:1803.00195, 2018.
  • [47] Liu Ziyin, Botao Li, Tomer Galanti, and Masahito Ueda. The probabilistic stability of stochastic gradient descent. arXiv preprint arXiv:2303.13093, 2023.
  • [48] Liu Ziyin, Botao Li, and Xiangming Meng. Exact solutions of a deep linear network. In Advances in Neural Information Processing Systems, 2022.
  • [49] Liu Ziyin, Botao Li, James B. Simon, and Masahito Ueda. Sgd may never escape saddle points, 2021.
  • [50] Liu Ziyin, Kangqiao Liu, Takashi Mori, and Masahito Ueda. Strength of minibatch noise in SGD. In International Conference on Learning Representations, 2022.
  • [51] Liu Ziyin and Masahito Ueda. Exact phase transitions in deep learning. arXiv preprint arXiv:2205.12510, 2022.
  • [52] Difan Zou, Jingfeng Wu, Vladimir Braverman, Quanquan Gu, and Sham Kakade. Benign overfitting of constant-stepsize sgd for linear regression. In Conference on Learning Theory, pages 4633–4635. PMLR, 2021.

Appendix A Theoretical Considerations

A.1 Background

A.1.1 Ito’s Lemma

Let us consider the following stochastic differential equation (SDE) for a Wiener process $W(t)$:

$dX_{t}=\mu_{t}dt+\sigma_{t}dW(t)$. (18)

We are interested in the dynamics of a generic function of $X_{t}$. Let $Y_{t}=f(t,X_{t})$; Ito's lemma states that the SDE for the new variable is

$df(t,X_{t})=\left(\frac{\partial f}{\partial t}+\mu_{t}\frac{\partial f}{\partial X_{t}}+\frac{\sigma_{t}^{2}}{2}\frac{\partial^{2}f}{\partial X_{t}^{2}}\right)dt+\sigma_{t}\frac{\partial f}{\partial X_{t}}dW(t)$. (19)

Let us take the variable $Y_{t}=X_{t}^{2}$ as an example. Then the SDE is

$dY_{t}=\left(2\mu_{t}X_{t}+\sigma_{t}^{2}\right)dt+2\sigma_{t}X_{t}dW(t)$. (20)
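A quick way to convince oneself of Eq. (20) is to integrate Eqs. (18) and (20) with the same noise realization and check that $Y_{t}$ tracks $X_{t}^{2}$. A minimal Euler-Maruyama sketch of our own (constant $\mu_{t}$ and $\sigma_{t}$ for simplicity):

```python
import numpy as np

rng = np.random.default_rng(6)
mu, sigma, dt, steps, n = 0.3, 0.5, 1e-3, 2000, 10_000
X = np.ones(n)                 # n independent paths, X_0 = 1
Y = X**2                       # Y_0 = X_0^2
for _ in range(steps):
    dW = rng.normal(scale=np.sqrt(dt), size=n)
    Y += (2*mu*X + sigma**2) * dt + 2*sigma*X*dW   # Eq. (20), same noise as X
    X += mu*dt + sigma*dW                          # Eq. (18)
print(np.abs(Y - X**2).mean())  # small: Y tracks X^2 up to discretization error
```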

Let us consider another example. Let two variables $X_{t}$ and $Y_{t}$ follow

$dX_{t}=\mu_{t}dt+\sigma_{t}dW(t)$,
$dY_{t}=\lambda_{t}dt+\phi_{t}dW(t)$. (21)

The SDE of $X_{t}Y_{t}$ is given by

$d(X_{t}Y_{t})=(\mu_{t}Y_{t}+\lambda_{t}X_{t}+\sigma_{t}\phi_{t})dt+(\sigma_{t}Y_{t}+\phi_{t}X_{t})dW(t)$. (22)

A.1.2 Fokker Planck Equation

The general SDE of a 1d variable $X$ is given by

$dX=-\mu(X)dt+B(X)dW(t)$. (23)

The time evolution of the probability density $P(X,t)$ is given by the Fokker-Planck equation:

$\frac{\partial P(X,t)}{\partial t}=-\frac{\partial}{\partial X}J(X,t)$, (24)

where $J(X,t)=-\mu(X)P(X,t)-\frac{1}{2}\frac{\partial}{\partial X}[B^{2}(X)P(X,t)]$ is the probability current. The stationary distribution, satisfying $\partial P(X,t)/\partial t=0$, is

$P(X)\propto\frac{1}{B^{2}(X)}\exp\left[-\int dX\,\frac{2\mu(X)}{B^{2}(X)}\right]:=\tilde{P}(X)$, (25)

which reduces to a Boltzmann-type distribution when $B$ is a constant. We will apply Eq. (25) to determine the stationary distributions in the following sections. A numerical sanity check of Eq. (25) is sketched below.
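Eq. (25) can be verified on a toy SDE with state-dependent diffusion. Taking $\mu(X)=X$ and $B(X)=\sqrt{1+X^{2}}$ (our own choice), Eq. (25) predicts $P(X)\propto(1+X^{2})^{-2}$, which the following Euler-Maruyama sketch reproduces:

```python
import numpy as np

rng = np.random.default_rng(7)
dt, steps, n = 1e-3, 50_000, 5_000
mu = lambda x: x                    # drift -mu(X) in Eq. (23)
B = lambda x: np.sqrt(1 + x**2)     # state-dependent diffusion

X = np.zeros(n)                     # n independent paths
for _ in range(steps):
    X += -mu(X) * dt + B(X) * rng.normal(scale=np.sqrt(dt), size=n)

# Eq. (25): P(X) propto B^{-2} exp[-int 2 mu/B^2] = (1+X^2)^{-2}; normalization pi/2
grid = np.linspace(-4, 4, 200)
P = (1 + grid**2) ** -2 / (np.pi / 2)
hist, edges = np.histogram(X, bins=40, range=(-4, 4), density=True)
centers = 0.5 * (edges[:-1] + edges[1:])
print(np.abs(hist - np.interp(centers, grid, P)).max())  # small sampling error
```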

A.2 Proof of Theorem 1

Proof.

By the definition of the symmetry, $\ell(\mathbf{u},\mathbf{w},x)=\ell(\lambda\mathbf{u},\mathbf{w}/\lambda,x)$; its infinitesimal version is $\ell(\mathbf{u},\mathbf{w},x)=\ell((1+\epsilon)\mathbf{u},(1-\epsilon)\mathbf{w},x)$. Expanding this to first order in $\epsilon$, we obtain

$\sum_{i}u_{i}\frac{\partial\ell}{\partial u_{i}}=\sum_{j}w_{j}\frac{\partial\ell}{\partial w_{j}}$. (26)

The equations of motion are

$\frac{du_{i}}{dt}=-\frac{\partial\ell}{\partial u_{i}}$, (27)
$\frac{dw_{j}}{dt}=-\frac{\partial\ell}{\partial w_{j}}$. (28)

Using Ito's lemma, we can find the equations governing the evolutions of $u_{i}^{2}$ and $w_{j}^{2}$:

$\frac{du_{i}^{2}}{dt}=2u_{i}\frac{du_{i}}{dt}+\frac{(du_{i})^{2}}{dt}=-2u_{i}\frac{\partial\ell}{\partial u_{i}}+TC_{i}^{u}$,
$\frac{dw_{j}^{2}}{dt}=2w_{j}\frac{dw_{j}}{dt}+\frac{(dw_{j})^{2}}{dt}=-2w_{j}\frac{\partial\ell}{\partial w_{j}}+TC_{j}^{w}$, (29)

where $C_{i}^{u}={\rm Var}[\frac{\partial\ell}{\partial u_{i}}]$ and $C_{j}^{w}={\rm Var}[\frac{\partial\ell}{\partial w_{j}}]$. With Eq. (26), we obtain

$\frac{d}{dt}(||u||^{2}-||w||^{2})=-T\left(\sum_{j}C_{j}^{w}-\sum_{i}C_{i}^{u}\right)=-T\left(\sum_{j}{\rm Var}\left[\frac{\partial\ell}{\partial w_{j}}\right]-\sum_{i}{\rm Var}\left[\frac{\partial\ell}{\partial u_{i}}\right]\right)$. (30)

Due to the rescaling symmetry, the loss function can be regarded as a function of the matrix $uw^{T}$. We define a new loss function $\tilde{\ell}(u_{i}w_{j})=\ell(u_{i},w_{j})$. Hence, we have

$\frac{\partial\ell}{\partial w_{j}}=\sum_{i}u_{i}\frac{\partial\tilde{\ell}}{\partial(u_{i}w_{j})},\qquad\frac{\partial\ell}{\partial u_{i}}=\sum_{j}w_{j}\frac{\partial\tilde{\ell}}{\partial(u_{i}w_{j})}$. (31)

We can rewrite Eq. (30) as

$\frac{d}{dt}(||u||^{2}-||w||^{2})=-T(u^{T}C_{1}u-w^{T}C_{2}w)$, (32)

where

$(C_{1})_{ij}=\mathbb{E}\left[\sum_{k}\frac{\partial\tilde{\ell}}{\partial(u_{i}w_{k})}\frac{\partial\tilde{\ell}}{\partial(u_{j}w_{k})}\right]-\sum_{k}\mathbb{E}\left[\frac{\partial\tilde{\ell}}{\partial(u_{i}w_{k})}\right]\mathbb{E}\left[\frac{\partial\tilde{\ell}}{\partial(u_{j}w_{k})}\right]\equiv\mathbb{E}[A^{T}A]-\mathbb{E}[A^{T}]\mathbb{E}[A]$, (33)

$(C_{2})_{kl}=\mathbb{E}\left[\sum_{i}\frac{\partial\tilde{\ell}}{\partial(u_{i}w_{k})}\frac{\partial\tilde{\ell}}{\partial(u_{i}w_{l})}\right]-\sum_{i}\mathbb{E}\left[\frac{\partial\tilde{\ell}}{\partial(u_{i}w_{k})}\right]\mathbb{E}\left[\frac{\partial\tilde{\ell}}{\partial(u_{i}w_{l})}\right]\equiv\mathbb{E}[AA^{T}]-\mathbb{E}[A]\mathbb{E}[A^{T}]$, (34)

where

$(A)_{ki}\equiv\frac{\partial\tilde{\ell}}{\partial(u_{i}w_{k})}$. (35)

The proof is thus complete. ∎

A.3 Proof of Theorem 2

Proof.

This proof is based on the fact that if a certain condition is satisfied for all trajectories with probability 1, this condition is satisfied by the stationary distribution of the dynamics with probability 1.

Let us first consider the case $D>1$. We first show that any trajectory satisfies at least one of the following three conditions: for any $i$, (i) $v_{i}\to 0$, (ii) $L(\theta)\to 0$, or (iii) for any $k\neq l$, $(u_{i}^{(k)})^{2}-(u_{i}^{(l)})^{2}\to 0$.

The SDE for $u_{i}^{(k)}$ is

$\frac{du_{i}^{(k)}}{dt}=-2\frac{v_{i}}{u_{i}^{(k)}}(\beta_{1}v-\beta_{2})+2\frac{v_{i}}{u_{i}^{(k)}}\sqrt{\eta(\alpha_{1}v^{2}-2\alpha_{2}v+\alpha_{3})}\frac{dW}{dt}$, (36)

where $v_{i}:=\prod_{k=0}^{D}u_{i}^{(k)}$ and $v=\sum_{i}v_{i}$, as in the main text. There exists a rescaling symmetry between $u_{i}^{(k)}$ and $u_{i}^{(l)}$ for $k\neq l$. By the law of balance, we have

$\frac{d}{dt}[(u_{i}^{(k)})^{2}-(u_{i}^{(l)})^{2}]=-T[(u_{i}^{(k)})^{2}-(u_{i}^{(l)})^{2}]{\rm Var}\left[\frac{\partial\ell}{\partial(u_{i}^{(k)}u_{i}^{(l)})}\right]$, (37)

where

${\rm Var}\left[\frac{\partial\ell}{\partial(u_{i}^{(k)}u_{i}^{(l)})}\right]=\left(\frac{v_{i}}{u_{i}^{(k)}u_{i}^{(l)}}\right)^{2}(\alpha_{1}v^{2}-2\alpha_{2}v+\alpha_{3})$ (38)

with $v_{i}/(u_{i}^{(k)}u_{i}^{(l)})=\prod_{s\neq k,l}u_{i}^{(s)}$. In the long-time limit, $(u_{i}^{(k)})^{2}$ converges to $(u_{i}^{(l)})^{2}$ unless ${\rm Var}[\partial\ell/\partial(u_{i}^{(k)}u_{i}^{(l)})]=0$, which is equivalent to $v_{i}/(u_{i}^{(k)}u_{i}^{(l)})=0$ or $\alpha_{1}v^{2}-2\alpha_{2}v+\alpha_{3}=0$. These two cases correspond to conditions (i) and (ii), respectively; the latter because $\alpha_{1}v^{2}-2\alpha_{2}v+\alpha_{3}=0$ holds if and only if $v=\alpha_{2}/\alpha_{1}$ and $\alpha_{2}^{2}-\alpha_{1}\alpha_{3}=0$, which together imply $L(\theta)=0$. Therefore, at stationarity, we must have condition (i), (ii), or (iii).

Now, we prove that when (iii) holds, condition 2-(b) in the theorem statement must hold. When (iii) holds, there are two situations. First, if $v_{i}=0$, we have $u_{i}^{(k)}=0$ for all $k$, and $v_{i}$ will stay $0$ for the rest of the trajectory, which corresponds to condition (i).

If vi0v_{i}\neq 0, we have ui(k)0u_{i}^{(k)}\neq 0 for all kk. Therefore, the dynamics of viv_{i} is

dvidt=2k(viui(k))2(β1vβ2)+2k(viui(k))2η(α1v22α2v+α3)dWdt+4k,l(vi3(ui(k)ui(l))2)η(α1v22α2v+α3).\frac{dv_{i}}{dt}=-2\sum_{k}\left(\frac{v_{i}}{u_{i}^{(k)}}\right)^{2}(\beta_{1}v-\beta_{2})+2\sum_{k}\left(\frac{v_{i}}{u_{i}^{(k)}}\right)^{2}\sqrt{\eta(\alpha_{1}v^{2}-2\alpha_{2}v+\alpha_{3})}\frac{dW}{dt}+4\sum_{k,l}\left(\frac{v_{i}^{3}}{(u_{i}^{(k)}u_{i}^{(l)})^{2}}\right)\eta(\alpha_{1}v^{2}-2\alpha_{2}v+\alpha_{3}). (39)

Comparing the dynamics of viv_{i} and vjv_{j} for iji\neq j, we obtain

dvi/dtk(vi/ui(k))2dvj/dtk(vj/uj(k))2\displaystyle\frac{dv_{i}/dt}{\sum_{k}(v_{i}/u_{i}^{(k)})^{2}}-\frac{dv_{j}/dt}{\sum_{k}(v_{j}/u_{j}^{(k)})^{2}} =4(m,lvi3/(ui(m)ui(l))2k(vi/ui(k))2m,lvj3/(uj(m)uj(l))2k(vj/uj(k))2)η(α1v22α2v+α3)\displaystyle=4\left(\frac{\sum_{m,l}v_{i}^{3}/(u_{i}^{(m)}u_{i}^{(l)})^{2}}{\sum_{k}(v_{i}/u_{i}^{(k)})^{2}}-\frac{\sum_{m,l}v_{j}^{3}/(u_{j}^{(m)}u_{j}^{(l)})^{2}}{\sum_{k}(v_{j}/u_{j}^{(k)})^{2}}\right)\eta(\alpha_{1}v^{2}-2\alpha_{2}v+\alpha_{3})
=4(vim,lvi2/(ui(m)ui(l))2k(vi/ui(k))2vjm,lvj2/(uj(m)uj(l))2k(vj/uj(k))2)η(α1v22α2v+α3).\displaystyle=4\left(v_{i}\frac{\sum_{m,l}v_{i}^{2}/(u_{i}^{(m)}u_{i}^{(l)})^{2}}{\sum_{k}(v_{i}/u_{i}^{(k)})^{2}}-v_{j}\frac{\sum_{m,l}v_{j}^{2}/(u_{j}^{(m)}u_{j}^{(l)})^{2}}{\sum_{k}(v_{j}/u_{j}^{(k)})^{2}}\right)\eta(\alpha_{1}v^{2}-2\alpha_{2}v+\alpha_{3}). (40)

By condition (iii), we have $|u_{i}^{(0)}|=\cdots=|u_{i}^{(D)}|$, i.e., $(v_{i}/u_{i}^{(k)})^{2}=(v_{i}^{2})^{D/(D+1)}$ and $(v_{i}/(u_{i}^{(m)}u_{i}^{(l)}))^{2}=(v_{i}^{2})^{(D-1)/(D+1)}$, where we only take the root on the positive real axis. Therefore, we obtain

\frac{dv_{i}/dt}{(D+1)(v_{i}^{2})^{D/(D+1)}}-\frac{dv_{j}/dt}{(D+1)(v_{j}^{2})^{D/(D+1)}}=4\left(v_{i}\frac{D(v_{i}^{2})^{(D-1)/(D+1)}}{2(v_{i}^{2})^{D/(D+1)}}-v_{j}\frac{D(v_{j}^{2})^{(D-1)/(D+1)}}{2(v_{j}^{2})^{D/(D+1)}}\right)\eta(\alpha_{1}v^{2}-2\alpha_{2}v+\alpha_{3}). (41)

We first consider the case where viv_{i} and vjv_{j} initially share the same sign (both positive or both negative). When D>1D>1, the left-hand side of Eq. (41) can be written as

11Ddvi2/(D+1)1dt+4Dvi12/(D+1)η(α1v22α2v+α3)11Ddvj2/(D+1)1dt4Dvj12/(D+1)η(α1v22α2v+α3),\frac{1}{1-D}\frac{dv_{i}^{2/(D+1)-1}}{dt}+4Dv_{i}^{1-2/(D+1)}\eta(\alpha_{1}v^{2}-2\alpha_{2}v+\alpha_{3})-\frac{1}{1-D}\frac{dv_{j}^{2/(D+1)-1}}{dt}-4Dv_{j}^{1-2/(D+1)}\eta(\alpha_{1}v^{2}-2\alpha_{2}v+\alpha_{3}), (42)

which follows from Ito’s lemma:

dvi2/(D+1)1dt\displaystyle\frac{dv_{i}^{2/(D+1)-1}}{dt} =(2D+11)vi2/(D+1)2dvidt+2(2D+11)(2D+12)vi2/(D+1)3(k(viui(k))2η(α1v22α2v+α3))2\displaystyle=\left(\frac{2}{D+1}-1\right)v_{i}^{2/(D+1)-2}\frac{dv_{i}}{dt}+2(\frac{2}{D+1}-1)(\frac{2}{D+1}-2)v_{i}^{2/(D+1)-3}\left(\sum_{k}(\frac{v_{i}}{u_{i}^{(k)}})^{2}\sqrt{\eta(\alpha_{1}v^{2}-2\alpha_{2}v+\alpha_{3})}\right)^{2}
=(2D+11)vi2/(D+1)2dvidt+4D(D1)vi12/(D+1)η(α1v22α2v+α3).\displaystyle=(\frac{2}{D+1}-1)v_{i}^{2/(D+1)-2}\frac{dv_{i}}{dt}+4D(D-1)v_{i}^{1-2/(D+1)}\eta(\alpha_{1}v^{2}-2\alpha_{2}v+\alpha_{3}). (43)

Substituting this into Eq. (41), we obtain Eq. (42).
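The coefficient $4D(D-1)$ in Eq. (43) is easy to get wrong; the following symbolic sketch (our own verification, not from the paper) confirms that the second-order Itô term has exactly this form.

```python
import sympy as sp

# Symbolic check of the Ito correction in Eq. (43): the second-order term
# should equal 4 D (D - 1) v^{1 - 2/(D+1)} (times the eta(...) factor).
D, v = sp.symbols('D v', positive=True)
noise_coef = (D + 1) * v**(2*D/(D + 1))          # sum_k (v/u^(k))^2 under (iii)
second_order = (2 * (2/(D + 1) - 1) * (2/(D + 1) - 2)
                * v**(2/(D + 1) - 3) * noise_coef**2)
target = 4 * D * (D - 1) * v**(1 - 2/(D + 1))
print(sp.simplify(second_order - target))        # expected output: 0
```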

Now, we consider the right-hand side of Eq. (41), which is given by

2Dvi12/(D+1)η(α1v22α2v+α3)2Dvj12/(D+1)η(α1v22α2v+α3).2Dv_{i}^{1-2/(D+1)}\eta(\alpha_{1}v^{2}-2\alpha_{2}v+\alpha_{3})-2Dv_{j}^{1-2/(D+1)}\eta(\alpha_{1}v^{2}-2\alpha_{2}v+\alpha_{3}). (44)

Combining Eq. (42) and Eq. (44), we obtain

11Ddvi2/(D+1)1dt11Ddvj2/(D+1)1dt=2D(vi12/(D+1)vj12/(D+1))η(α1v22α2v+α3).\frac{1}{1-D}\frac{dv_{i}^{2/(D+1)-1}}{dt}-\frac{1}{1-D}\frac{dv_{j}^{2/(D+1)-1}}{dt}=-2D(v_{i}^{1-2/(D+1)}-v_{j}^{1-2/(D+1)})\eta(\alpha_{1}v^{2}-2\alpha_{2}v+\alpha_{3}). (45)

By defining zi=vi2/(D+1)1z_{i}=v_{i}^{2/(D+1)-1}, we can further simplify the dynamics:

d(zizj)dt\displaystyle\frac{d(z_{i}-z_{j})}{dt} =2D(D1)(1zi1zj)η(α1v22α2v+α3)\displaystyle=2D(D-1)\left(\frac{1}{z_{i}}-\frac{1}{z_{j}}\right)\eta(\alpha_{1}v^{2}-2\alpha_{2}v+\alpha_{3})
=2D(D1)zizjzizjη(α1v22α2v+α3).\displaystyle=-2D(D-1)\frac{z_{i}-z_{j}}{z_{i}z_{j}}\eta(\alpha_{1}v^{2}-2\alpha_{2}v+\alpha_{3}). (46)

Hence,

z_{i}(t)-z_{j}(t)=[z_{i}(0)-z_{j}(0)]\exp\left[-\int_{0}^{t}dt^{\prime}\frac{2D(D-1)}{z_{i}z_{j}}\eta(\alpha_{1}v^{2}-2\alpha_{2}v+\alpha_{3})\right]. (47)

Therefore, if viv_{i} and vjv_{j} initially have the same sign, they will decay to the same value in the long-time limit tt\to\infty, which gives condition 2-(b). When viv_{i} and vjv_{j} initially have different signs, we can write Eq. (41) as

d|vi|/dt(D+1)(|vi|2)D/(D+1)+d|vj|/dt(D+1)(|vj|2)D/(D+1)=\displaystyle\frac{d|v_{i}|/dt}{(D+1)(|v_{i}|^{2})^{D/(D+1)}}+\frac{d|v_{j}|/dt}{(D+1)(|v_{j}|^{2})^{D/(D+1)}}= (|vi|D(|vi|2)(D1)/(D+1)2(|vi|2)D/(D+1)+|vj|D(|vj|2)(D1)/(D+1)2(|vj|2)D/(D+1))\displaystyle\left(|v_{i}|\frac{D(|v_{i}|^{2})^{(D-1)/(D+1)}}{2(|v_{i}|^{2})^{D/(D+1)}}+|v_{j}|\frac{D(|v_{j}|^{2})^{(D-1)/(D+1)}}{2(|v_{j}|^{2})^{D/(D+1)}}\right)
×η(α1v22α2v+α3).\displaystyle\times\eta(\alpha_{1}v^{2}-2\alpha_{2}v+\alpha_{3}). (48)

Hence, when $D>1$, a similar procedure simplifies the equation to

11Dd|vi|2/(D+1)1dt+11Dd|vj|2/(D+1)1dt=2D(|vi|12/(D+1)+|vj|12/(D+1))η(α1v22α2v+α3).\frac{1}{1-D}\frac{d|v_{i}|^{2/(D+1)-1}}{dt}+\frac{1}{1-D}\frac{d|v_{j}|^{2/(D+1)-1}}{dt}=-2D(|v_{i}|^{1-2/(D+1)}+|v_{j}|^{1-2/(D+1)})\eta(\alpha_{1}v^{2}-2\alpha_{2}v+\alpha_{3}). (49)

Defining zi=|vi|2/(D+1)1z_{i}=|v_{i}|^{2/(D+1)-1}, we obtain

d(zi+zj)dt\displaystyle\frac{d(z_{i}+z_{j})}{dt} =2D(D1)(1zi+1zj)η(α1v22α2v+α3)\displaystyle=2D(D-1)\left(\frac{1}{z_{i}}+\frac{1}{z_{j}}\right)\eta(\alpha_{1}v^{2}-2\alpha_{2}v+\alpha_{3})
=2D(D1)zi+zjzizjη(α1v22α2v+α3),\displaystyle=2D(D-1)\frac{z_{i}+z_{j}}{z_{i}z_{j}}\eta(\alpha_{1}v^{2}-2\alpha_{2}v+\alpha_{3}), (50)

which implies

z_{i}(t)+z_{j}(t)=[z_{i}(0)+z_{j}(0)]\exp\left[\int_{0}^{t}dt^{\prime}\frac{2D(D-1)}{z_{i}z_{j}}\eta(\alpha_{1}v^{2}-2\alpha_{2}v+\alpha_{3})\right]. (51)

From this equation, we conclude that if $v_{i}$ and $v_{j}$ initially have different signs, then $z_{i}+z_{j}$ grows without bound; since $z_{i}=|v_{i}|^{-(D-1)/(D+1)}$, one of $v_{i}$ and $v_{j}$ must converge to $0$ in the long-time limit $t\to\infty$, corresponding to condition 1 in the theorem statement. Hence, for $D>1$, at least one of the conditions is always satisfied as $t\to\infty$.

Now, we prove the theorem for D=1D=1, which is similar to the proof above. The law of balance gives

ddt[(ui(1))2(ui(2))2]=T[(ui(1))2(ui(2))2]Var[(ui(1)ui(2))].\frac{d}{dt}[(u_{i}^{(1)})^{2}-(u_{i}^{(2)})^{2}]=-T[(u_{i}^{(1)})^{2}-(u_{i}^{(2)})^{2}]{\rm Var}\left[\frac{\partial\ell}{\partial(u_{i}^{(1)}u_{i}^{(2)})}\right]. (52)

We can see that $|u_{i}^{(1)}|\to|u_{i}^{(2)}|$ takes place unless ${\rm Var}\left[\frac{\partial\ell}{\partial(u_{i}^{(1)}u_{i}^{(2)})}\right]=0$, which is equivalent to $L(\theta)=0$ and corresponds to condition (ii). Hence, if condition (ii) is violated, condition (iii) holds: $|u_{i}^{(1)}|\to|u_{i}^{(2)}|$, and Eq. (41) can be rewritten as

dvi/dt|vi|dvj/dt|vj|=(sign(vi)sign(vj))η(α1v22α2v+α3).\frac{dv_{i}/dt}{|v_{i}|}-\frac{dv_{j}/dt}{|v_{j}|}=(\text{sign}(v_{i})-\text{sign}(v_{j}))\eta(\alpha_{1}v^{2}-2\alpha_{2}v+\alpha_{3}). (53)

When viv_{i} and vjv_{j} are both positive, we have

dvi/dtvidvj/dtvj=0.\frac{dv_{i}/dt}{v_{i}}-\frac{dv_{j}/dt}{v_{j}}=0. (54)

With Ito’s lemma, we have

dlog(vi)dt=dvividt2η(α1v22α2v+α3).\frac{d\log(v_{i})}{dt}=\frac{dv_{i}}{v_{i}dt}-2\eta(\alpha_{1}v^{2}-2\alpha_{2}v+\alpha_{3}). (55)

Therefore, Eq. (54) can be simplified to

d(log(vi)log(vj))dt=0,\frac{d(\log(v_{i})-\log(v_{j}))}{dt}=0, (56)

which indicates that all $v_{i}$ with the same sign decay at the same rate. This differs from the case of $D>1$, where all $v_{i}$ decay to the same value. Similarly, we can treat the case where $v_{i}$ and $v_{j}$ are both negative.

Now, we consider the case where viv_{i} is positive while vjv_{j} is negative and rewrite Eq. (53) as

dvi/dtvi+d(|vj|)/dt|vj|=2η(α1v22α2v+α3).\frac{dv_{i}/dt}{v_{i}}+\frac{d(|v_{j}|)/dt}{|v_{j}|}=2\eta(\alpha_{1}v^{2}-2\alpha_{2}v+\alpha_{3}). (57)

Furthermore, Ito’s lemma gives the dynamics of $\log|v_{j}|$:

\frac{d\log(|v_{j}|)}{dt}=\frac{d|v_{j}|}{|v_{j}|dt}-2\eta(\alpha_{1}v^{2}-2\alpha_{2}v+\alpha_{3}). (58)

Therefore, Eq. (57) takes the form of

d(log(vi)+log(|vj|))dt=2η(α1v22α2v+α3).\frac{d(\log(v_{i})+\log(|v_{j}|))}{dt}=-2\eta(\alpha_{1}v^{2}-2\alpha_{2}v+\alpha_{3}). (59)

In the long-time limit, $\log(v_{i}|v_{j}|)$ diverges to $-\infty$, indicating that either $v_{i}$ or $v_{j}$ decays to $0$. This corresponds to condition 1 in the theorem statement. Combining Eq. (56) and Eq. (59), we conclude that all surviving $v_{i}$ have the same sign as $t\to\infty$, which indicates condition 2-(a) if the conditions in item 1 are all violated. The proof is thus complete. ∎
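As an informal illustration of the theorem for $D=1$ (our own toy experiment; the data and hyperparameters are arbitrary), one can run minibatch SGD on a width-2 diagonal linear network initialized so that $v_{1}$ and $v_{2}$ have opposite signs; the component whose sign disagrees with the target slope collapses toward zero, as condition 1 predicts.

```python
import numpy as np

# Toy illustration of Theorem 2 for D = 1: with v_i = u_i * w_i of mixed
# signs at initialization, the component whose sign disagrees with the
# target slope should collapse toward zero under minibatch SGD.
# Data and hyperparameters are arbitrary choices of ours.
rng = np.random.default_rng(1)
n = 1024
X = rng.normal(size=n)
Y = X + 0.2 * rng.normal(size=n)      # target slope k = 1 > 0

u = np.array([0.9, 0.8])
w = np.array([0.9, -0.8])             # v = u * w = (0.81, -0.64): mixed signs
eta, S = 1e-2, 4
for t in range(100000):
    idx = rng.integers(0, n, size=S)
    g = np.mean((np.sum(u * w) * X[idx] - Y[idx]) * X[idx])
    u, w = u - eta * g * w, w - eta * g * u
print(u * w)                          # the negative component should be near 0
```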

A.4 Stationary distribution in Eq. (13)

Following Eq. (39), we substitute $u_{i}^{(k)}=v_{i}^{1/(D+1)}$ for arbitrary $k$ and obtain

dvidt=\displaystyle\frac{dv_{i}}{dt}= 2(D+1)|vi|2D/(D+1)(β1vβ2)+2(D+1)|vi|2D/(D+1)η(α1v22α2v+α3)dWdt\displaystyle-2(D+1)|v_{i}|^{2D/(D+1)}(\beta_{1}v-\beta_{2})+2(D+1)|v_{i}|^{2D/(D+1)}\sqrt{\eta(\alpha_{1}v^{2}-2\alpha_{2}v+\alpha_{3})}\frac{dW}{dt}
+2(D+1)Dvi3|vi|4/(D+1)η(α1v22α2v+α3).\displaystyle+2(D+1)Dv_{i}^{3}|v_{i}|^{-4/(D+1)}\eta(\alpha_{1}v^{2}-2\alpha_{2}v+\alpha_{3}). (60)

With Eq. (47), we see that for arbitrary $i$ and $j$, $v_{i}$ converges to $v_{j}$ in the long-time limit. In this case, we have $v=d\,v_{i}$ for each $i$, where $d$ is the number of components $v_{i}$ (the width). Then, the SDE for $v$ can be written as

dvdt=\displaystyle\frac{dv}{dt}= 2(D+1)d2/(D+1)1|v|2D/(D+1)(β1vβ2)+2(D+1)d2/(D+1)1|v|2D/(D+1)η(α1v22α2v+α3)dWdt\displaystyle-2(D+1)d^{2/(D+1)-1}|v|^{2D/(D+1)}(\beta_{1}v-\beta_{2})+2(D+1)d^{2/(D+1)-1}|v|^{2D/(D+1)}\sqrt{\eta(\alpha_{1}v^{2}-2\alpha_{2}v+\alpha_{3})}\frac{dW}{dt}
+2(D+1)Dd4/(D+1)2v3|v|4/(D+1)η(α1v22α2v+α3).\displaystyle+2(D+1)Dd^{4/(D+1)-2}v^{3}|v|^{-4/(D+1)}\eta(\alpha_{1}v^{2}-2\alpha_{2}v+\alpha_{3}). (61)

If v>0v>0, Eq. (61) becomes

dvdt=\displaystyle\frac{dv}{dt}= 2(D+1)d2/(D+1)1v2D/(D+1)(β1vβ2)+2(D+1)d2/(D+1)1v2D/(D+1)η(α1v22α2v+α3)dWdt\displaystyle-2(D+1)d^{2/(D+1)-1}v^{2D/(D+1)}(\beta_{1}v-\beta_{2})+2(D+1)d^{2/(D+1)-1}v^{2D/(D+1)}\sqrt{\eta(\alpha_{1}v^{2}-2\alpha_{2}v+\alpha_{3})}\frac{dW}{dt}
+2(D+1)Dd4/(D+1)2v34/(D+1)η(α1v22α2v+α3).\displaystyle+2(D+1)Dd^{4/(D+1)-2}v^{3-4/(D+1)}\eta(\alpha_{1}v^{2}-2\alpha_{2}v+\alpha_{3}). (62)

Therefore, the stationary distribution of a general deep diagonal network is given by

p(v)\displaystyle p(v) 1v3(11/(D+1))(α1v22α2v+α3)exp(1T𝑑vd12/(D+1)(β1vβ2)(D+1)v2D/(D+1)(α1v22α2v+α3)).\displaystyle\propto\frac{1}{v^{3(1-1/(D+1))}(\alpha_{1}v^{2}-2\alpha_{2}v+\alpha_{3})}\exp\left(-\frac{1}{T}\int dv\frac{d^{1-2/(D+1)}(\beta_{1}v-\beta_{2})}{(D+1)v^{2D/(D+1)}(\alpha_{1}v^{2}-2\alpha_{2}v+\alpha_{3})}\right). (63)

If v<0v<0, Eq. (61) becomes

d|v|dt=\displaystyle\frac{d|v|}{dt}= 2(D+1)d2/(D+1)1|v|2D/(D+1)(β1|v|+β2)2(D+1)d2/(D+1)1|v|2D/(D+1)η(α1|v|2+2α2|v|+α3)dWdt\displaystyle-2(D+1)d^{2/(D+1)-1}|v|^{2D/(D+1)}(\beta_{1}|v|+\beta_{2})-2(D+1)d^{2/(D+1)-1}|v|^{2D/(D+1)}\sqrt{\eta(\alpha_{1}|v|^{2}+2\alpha_{2}|v|+\alpha_{3})}\frac{dW}{dt}
+2(D+1)Dd4/(D+1)2|v|34/(D+1)η(α1|v|2+2α2|v|+α3).\displaystyle+2(D+1)Dd^{4/(D+1)-2}|v|^{3-4/(D+1)}\eta(\alpha_{1}|v|^{2}+2\alpha_{2}|v|+\alpha_{3}). (64)

The stationary distribution of |v||v| is given by

p(|v|)1|v|3(11/(D+1))(α1|v|2+2α2|v|+α3)exp(1Td|v|d12/(D+1)(β1|v|+β2)(D+1)|v|2D/(D+1)(α1|v|2+2α2|v|+α3)).p(|v|)\propto\frac{1}{|v|^{3(1-1/(D+1))}(\alpha_{1}|v|^{2}+2\alpha_{2}|v|+\alpha_{3})}\exp\left(-\frac{1}{T}\int d|v|\frac{d^{1-2/(D+1)}(\beta_{1}|v|+\beta_{2})}{(D+1)|v|^{2D/(D+1)}(\alpha_{1}|v|^{2}+2\alpha_{2}|v|+\alpha_{3})}\right). (65)

Thus, we have obtained

p±(|v|)1|v|3(11/(D+1))(α1|v|22α2|v|+α3)exp(1Td|v|d12/(D+1)(β1|v|β2)(D+1)|v|2D/(D+1)(α1|v|22α2|v|+α3)).p_{\pm}(|v|)\propto\frac{1}{|v|^{3(1-1/(D+1))}(\alpha_{1}|v|^{2}\mp 2\alpha_{2}|v|+\alpha_{3})}\exp\left(-\frac{1}{T}\int d|v|\frac{d^{1-2/(D+1)}(\beta_{1}|v|\mp\beta_{2})}{(D+1)|v|^{2D/(D+1)}(\alpha_{1}|v|^{2}\mp 2\alpha_{2}|v|+\alpha_{3})}\right). (66)

In particular, when $D=1$, the distribution function simplifies to

p±(|v|)\displaystyle p_{\pm}(|v|) |v|±β2/2α3T3/2(α1|v|22α2|v|+α3)1±β2/4Tα3exp(12Tα3β1α2β2α3Δarctanα1|v|α2Δ),\displaystyle\propto\frac{|v|^{\pm\beta_{2}/2\alpha_{3}T-3/2}}{(\alpha_{1}|v|^{2}\mp 2\alpha_{2}|v|+\alpha_{3})^{1\pm\beta_{2}/4T\alpha_{3}}}\exp\left(-\frac{1}{2T}\frac{\alpha_{3}\beta_{1}-\alpha_{2}\beta_{2}}{\alpha_{3}\sqrt{\Delta}}\arctan\frac{\alpha_{1}|v|\mp\alpha_{2}}{\sqrt{\Delta}}\right), (67)

where we have used the integral

\int dv\frac{\beta_{1}v\mp\beta_{2}}{v(\alpha_{1}v^{2}\mp 2\alpha_{2}v+\alpha_{3})}=\frac{\alpha_{3}\beta_{1}-\alpha_{2}\beta_{2}}{\alpha_{3}\sqrt{\Delta}}\arctan\frac{\alpha_{1}v\mp\alpha_{2}}{\sqrt{\Delta}}\mp\frac{\beta_{2}}{\alpha_{3}}\log(v)\pm\frac{\beta_{2}}{2\alpha_{3}}\log(\alpha_{1}v^{2}\mp 2\alpha_{2}v+\alpha_{3}). (68)
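Since the signs in Eq. (68) are easy to misread, here is a quick symbolic verification (ours; upper-sign branch) that differentiating the right-hand side recovers the integrand.

```python
import sympy as sp

# Verify the upper-sign branch of Eq. (68) by differentiating the
# right-hand side and comparing with the integrand.
v, a1, a2, a3, b1, b2 = sp.symbols('v alpha_1 alpha_2 alpha_3 beta_1 beta_2',
                                   positive=True)
Delta = a1*a3 - a2**2
q = a1*v**2 - 2*a2*v + a3
F = ((a3*b1 - a2*b2) / (a3*sp.sqrt(Delta)) * sp.atan((a1*v - a2)/sp.sqrt(Delta))
     - (b2/a3)*sp.log(v) + (b2/(2*a3))*sp.log(q))
print(sp.simplify(sp.diff(F, v) - (b1*v - b2)/(v*q)))   # expected output: 0
```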

A.5 Analysis of the maximum probability point

To investigate the existence of the maximum point given in Eq. (16), we treat $T$ as a variable and ask whether $A:=(\beta_{1}-10\alpha_{2}T)^{2}+28\alpha_{1}T(\beta_{2}-3\alpha_{3}T)$, the expression under the square root, is always positive. When $T<\frac{\beta_{2}}{3\alpha_{3}}=T_{c}/3$, $A$ is positive for arbitrary data. When $T>\frac{\beta_{2}}{3\alpha_{3}}$, we divide the discussion into several cases. First, when $\alpha_{1}\alpha_{3}>\frac{25}{21}\alpha_{2}^{2}$, the quadratic $A(T)$ is concave in $T$ and always has a root. Hence, we find that

T=5α2β1+7α1β2+73α1α3β1210α1α2β1β2+7α12β222(21α1α325α22):=TT=\frac{-5\alpha_{2}\beta_{1}+7\alpha_{1}\beta_{2}+\sqrt{7}\sqrt{3\alpha_{1}\alpha_{3}\beta_{1}^{2}-10\alpha_{1}\alpha_{2}\beta_{1}\beta_{2}+7\alpha_{1}^{2}\beta_{2}^{2}}}{2(21\alpha_{1}\alpha_{3}-25\alpha_{2}^{2})}:=T^{*} (69)

is a critical point. When Tc/3<T<TT_{c}/3<T<T^{*}, there exists a solution to the maximum condition. When T>TT>T^{*}, there is no solution to the maximum condition.

The second case is $\alpha_{2}^{2}<\alpha_{1}\alpha_{3}<\frac{25}{21}\alpha_{2}^{2}$. In this case, we further compare $5\alpha_{2}\beta_{1}$ with $7\alpha_{1}\beta_{2}$. If $5\alpha_{2}\beta_{1}<7\alpha_{1}\beta_{2}$, we have $A>0$, which indicates that the maximum point exists. If $5\alpha_{2}\beta_{1}>7\alpha_{1}\beta_{2}$, we need to further check the minimum of $A$, which takes the form

minTA(T)=(25α2221α1α3)β12(7α1β25α2β1)225α2221α1α3.{\rm min}_{T}A(T)=\frac{(25\alpha_{2}^{2}-21\alpha_{1}\alpha_{3})\beta_{1}^{2}-(7\alpha_{1}\beta_{2}-5\alpha_{2}\beta_{1})^{2}}{25\alpha_{2}^{2}-21\alpha_{1}\alpha_{3}}. (70)

If $\frac{7\alpha_{1}}{5\alpha_{2}}<\frac{\beta_{1}}{\beta_{2}}<\frac{5\alpha_{2}+\sqrt{25\alpha_{2}^{2}-21\alpha_{1}\alpha_{3}}}{3\alpha_{3}}$, the minimum of $A$ is always positive and the maximum exists. However, if $\frac{\beta_{1}}{\beta_{2}}\geq\frac{5\alpha_{2}+\sqrt{25\alpha_{2}^{2}-21\alpha_{1}\alpha_{3}}}{3\alpha_{3}}$, a critical learning rate always exists. If $\frac{\beta_{1}}{\beta_{2}}=\frac{5\alpha_{2}+\sqrt{25\alpha_{2}^{2}-21\alpha_{1}\alpha_{3}}}{3\alpha_{3}}$, there is exactly one critical learning rate, $T^{*}=\frac{5\alpha_{2}\beta_{1}-7\alpha_{1}\beta_{2}}{2(25\alpha_{2}^{2}-21\alpha_{1}\alpha_{3})}$; when $T_{c}/3<T<T^{*}$, there is a solution to the maximum condition, while there is no solution when $T>T^{*}$. If $\frac{\beta_{1}}{\beta_{2}}>\frac{5\alpha_{2}+\sqrt{25\alpha_{2}^{2}-21\alpha_{1}\alpha_{3}}}{3\alpha_{3}}$, there are two critical points:

T1,2=5α2β1+7α1β273α1α3β1210α1α2β1β2+7α12β222(21α1α325α22).T_{1,2}=\frac{-5\alpha_{2}\beta_{1}+7\alpha_{1}\beta_{2}\mp\sqrt{7}\sqrt{3\alpha_{1}\alpha_{3}\beta_{1}^{2}-10\alpha_{1}\alpha_{2}\beta_{1}\beta_{2}+7\alpha_{1}^{2}\beta_{2}^{2}}}{2(21\alpha_{1}\alpha_{3}-25\alpha_{2}^{2})}. (71)

For $T<T_{1}$ and $T>T_{2}$, there exists a solution to the maximum condition; for $T_{1}<T<T_{2}$, there is none. The last case is $\alpha_{1}\alpha_{3}=\frac{25}{21}\alpha_{2}^{2}$, for which the quadratic term of $A$ vanishes and $A$ simplifies to $\beta_{1}^{2}+(28\alpha_{1}\beta_{2}-20\alpha_{2}\beta_{1})T$. Hence, when $\frac{\beta_{1}}{\beta_{2}}<\frac{7\alpha_{1}}{5\alpha_{2}}$, there is no critical learning rate and the maximum always exists. Nevertheless, when $\frac{\beta_{1}}{\beta_{2}}>\frac{7\alpha_{1}}{5\alpha_{2}}$, there is always a critical learning rate $T^{*}=\frac{\beta_{1}^{2}}{20\alpha_{2}\beta_{1}-28\alpha_{1}\beta_{2}}$. When $T<T^{*}$, there is a solution to the maximum condition, while there is no solution when $T>T^{*}$.
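The case analysis above can be spot-checked numerically. The following sketch uses arbitrary placeholder values of the moments satisfying case 1 ($\alpha_{1}\alpha_{3}>\frac{25}{21}\alpha_{2}^{2}$) and verifies that $A$ vanishes at the $T^{*}$ of Eq. (69).

```python
import numpy as np

# Numerical spot check of Eq. (69): A(T*) = 0 in case 1,
# alpha_1 * alpha_3 > (25/21) * alpha_2^2. Moments are placeholders.
a1, a2, a3, b1, b2 = 1.0, 0.5, 2.0, 1.0, 0.8
A = lambda T: (b1 - 10*a2*T)**2 + 28*a1*T*(b2 - 3*a3*T)
T_star = (-5*a2*b1 + 7*a1*b2
          + np.sqrt(7*(3*a1*a3*b1**2 - 10*a1*a2*b1*b2 + 7*a1**2*b2**2))
          ) / (2*(21*a1*a3 - 25*a2**2))
print(T_star, A(T_star))   # A(T*) should vanish up to rounding error
```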

Table 1: Summary of distributions $p(v)$ in a depth-$1$ neural network. We show the distribution in the nontrivial subspace when the data $x$ and $y$ are positively correlated; $\Theta(1)$ factors are neglected for concision.

single layer. Without weight decay: $(\alpha_{1}v^{2}-2\alpha_{2}v+\alpha_{3})^{-1-\beta_{1}/2T\alpha_{1}}$. With weight decay: $\alpha_{1}(v-k)^{-2-(\beta_{1}+\gamma)/T\alpha_{1}}$.

non-interpolation. Without weight decay: $v^{\beta_{2}/2\alpha_{3}T-3/2}(\alpha_{1}v^{2}-2\alpha_{2}v+\alpha_{3})^{-1-\beta_{2}/4T\alpha_{3}}$. With weight decay: $v^{(\beta_{2}-\gamma)/2\alpha_{3}T-3/2}(\alpha_{1}v^{2}-2\alpha_{2}v+\alpha_{3})^{-1-(\beta_{2}-\gamma)/4T\alpha_{3}}$.

interpolation ($y=kx$). Without weight decay: $v^{-3/2+\beta_{1}/2T\alpha_{1}k}(v-k)^{-2-\beta_{1}/2T\alpha_{1}k}$. With weight decay: $v^{-3/2+\frac{1}{2T\alpha_{1}k}(\beta_{1}-\gamma/k)}(v-k)^{-2-\frac{1}{2T\alpha_{1}k}(\beta_{1}-\gamma/k)}\exp\left(-\frac{\beta_{1}\gamma}{2T\alpha_{1}}\frac{1}{k(k-v)}\right)$.

A.6 Other Cases for D=1D=1

The other cases are also worth studying. For the interpolation case where the data is linear ($y=kx$ for some $k$), the stationary distribution is different and simpler. There exists a nontrivial fixed point of $\sum_{i}(u_{i}^{2}-w_{i}^{2})$ at $\sum_{j}u_{j}w_{j}=\frac{\alpha_{2}}{\alpha_{1}}$, which is the global minimizer of $L$ and also has vanishing noise. It is helpful to note the following relationships for the data distribution when it is linear:

{α1=Var[x2],α2=kVar[x2]=kα1,α3=k2α1,β1=𝔼[x2],β2=k𝔼[x2]=kβ1.\begin{cases}\alpha_{1}={\rm Var}[x^{2}],\\ \alpha_{2}=k{\rm Var}[x^{2}]=k\alpha_{1},\\ \alpha_{3}=k^{2}\alpha_{1},\\ \beta_{1}=\mathbb{E}[x^{2}],\\ \beta_{2}=k\mathbb{E}[x^{2}]=k\beta_{1}.\end{cases} (72)
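These relations are straightforward to confirm by Monte Carlo, assuming the moment definitions $\alpha_{1}={\rm Var}[x^{2}]$, $\alpha_{2}={\rm Cov}[x^{2},xy]$, $\alpha_{3}={\rm Var}[xy]$, $\beta_{1}=\mathbb{E}[x^{2}]$, $\beta_{2}=\mathbb{E}[xy]$, which is how we read Eq. (72):

```python
import numpy as np

# Monte Carlo check of Eq. (72) for exactly linear data y = k x,
# assuming alpha_1 = Var[x^2], alpha_2 = Cov[x^2, xy], alpha_3 = Var[xy],
# beta_1 = E[x^2], beta_2 = E[xy].
rng = np.random.default_rng(2)
x = rng.normal(size=200000)
k = 1.3
y = k * x
a1, a3 = np.var(x**2), np.var(x * y)
a2 = np.cov(x**2, x * y)[0, 1]
b1, b2 = np.mean(x**2), np.mean(x * y)
print(a2 / a1, a3 / a1, b2 / b1)   # should be close to k, k**2, k
```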

Since the analysis of the Fokker-Planck equation is the same, we directly begin with the distribution function in Eq. (14) for $u_{i}=-w_{i}$, which is given by $P(|v|)\propto\delta(|v|)$. Namely, the only possible weights are $u_{i}=w_{i}=0$, the same as in the non-interpolation case. This is because the corresponding stationary distribution is

P(|v|)\propto\frac{1}{|v|^{2}(|v|+k)^{2}}\exp\left(-\frac{1}{2T}\int d|v|\frac{\beta_{1}(|v|+k)-\alpha_{1}T(|v|+k)^{2}}{\alpha_{1}|v|(|v|+k)^{2}}\right)\propto|v|^{-\frac{3}{2}-\frac{\beta_{1}}{2T\alpha_{1}k}}(|v|+k)^{-2+\frac{\beta_{1}}{2T\alpha_{1}k}}. (73)

The integral of Eq. (73) with respect to $|v|$ diverges at the origin due to the factor $|v|^{-\frac{3}{2}-\frac{\beta_{1}}{2T\alpha_{1}k}}$, and hence $P(|v|)$ reduces to $\delta(|v|)$.

For the case ui=wiu_{i}=w_{i}, the stationary distribution is given from Eq. (14) as

P(v)\propto\frac{1}{v^{2}(v-k)^{2}}\exp\left(-\frac{1}{2T}\int dv\frac{\beta_{1}(v-k)-\alpha_{1}T(v-k)^{2}}{\alpha_{1}v(v-k)^{2}}\right)\propto v^{-\frac{3}{2}+\frac{\beta_{1}}{2T\alpha_{1}k}}(v-k)^{-2-\frac{\beta_{1}}{2T\alpha_{1}k}}. (74)

Now, we consider the case of $\gamma\neq 0$. In the non-interpolation regime, when $u_{i}=-w_{i}$, the stationary distribution is still $p(v)=\delta(v)$. For the case of $u_{i}=w_{i}$, the stationary distribution is the same as in Eq. (14) after replacing $\beta_{2}$ with $\beta_{2}^{\prime}=\beta_{2}-\gamma$. It still exhibits a phase transition; the weight decay simply shifts $\beta_{2}$ by $-\gamma$. In the interpolation regime, the stationary distribution is still $p(v)=\delta(v)$ when $u_{i}=-w_{i}$. However, when $u_{i}=w_{i}$, the phase transition still exists since the stationary distribution is

p(v)\displaystyle p(v) v32+θ2(vk)2+θ2exp(β1γ2Tα11k(kv)),\displaystyle\propto\frac{v^{-\frac{3}{2}+\theta_{2}}}{(v-k)^{2+\theta_{2}}}\exp\left(-\frac{\beta_{1}\gamma}{2T\alpha_{1}}\frac{1}{k(k-v)}\right), (75)

where θ2=12Tα1k(β1γk)\theta_{2}=\frac{1}{2T\alpha_{1}k}(\beta_{1}-\frac{\gamma}{k}). The phase transition point is θ2=1/2\theta_{2}=1/2, which is the same as the non-interpolation one.

The last situation is rather special: $\Delta=0$ but $y\neq kx$, namely $y=kx-c/x$ for some $c\neq 0$. In this case, the parameters $\alpha$ and $\beta$ are the same as those given in Eq. (72) except for $\beta_{2}$:

β2=k𝔼[x2]kc=kβ1kc.\beta_{2}=k\mathbb{E}[x^{2}]-kc=k\beta_{1}-kc. (76)

The corresponding stationary distribution is

P(|v|)\displaystyle P(|v|) |v|32ϕ2(|v|+k)2ϕ2exp(c2Tα11k(k+|v|)),\displaystyle\propto\frac{|v|^{-\frac{3}{2}-\phi_{2}}}{(|v|+k)^{2-\phi_{2}}}\exp\left(\frac{c}{2T\alpha_{1}}\frac{1}{k(k+|v|)}\right), (77)

where $\phi_{2}=\frac{1}{2T\alpha_{1}k}(\beta_{1}-c)$. Here, we see that the behavior of the stationary distribution $P(|v|)$ is influenced by the sign of $c$. When $c<0$, the integral of $P(|v|)$ diverges at the origin because $|v|^{-\frac{3}{2}-\phi_{2}}<|v|^{-3/2}$ there, and Eq. (77) reduces to $\delta(|v|)$ again. However, when $c>0$, the integral need not diverge. The critical point is $\frac{3}{2}+\phi_{2}=1$, or equivalently $c=\beta_{1}+T\alpha_{1}k$. Intuitively, when $c<0$ the data points all lie above the line $y=kx$, so $u_{i}=-w_{i}$ can only give a trivial solution, whereas for $c>0$ there is the possibility of learning the negative slope. When $0<c<\beta_{1}+T\alpha_{1}k$, the integral of $P(|v|)$ still diverges and the distribution is equivalent to $\delta(|v|)$. Now, we consider the case of $u_{i}=w_{i}$. The stationary distribution is

P(|v|)\displaystyle P(|v|) |v|32+ϕ2(|v|k)2+ϕ2exp(c2Tα11k|v|).\displaystyle\propto\frac{|v|^{-\frac{3}{2}+\phi_{2}}}{(|v|-k)^{2+\phi_{2}}}\exp\left(-\frac{c}{2T\alpha_{1}}\frac{1}{k-|v|}\right). (78)

It also contains a critical point, $-\frac{3}{2}+\phi_{2}=-1$, or equivalently $c=\beta_{1}-\alpha_{1}kT$. There are two cases. When $c<0$, the probability density has support only on $|v|>k$, since the gradient always pulls the parameter $|v|$ into the region $|v|>k$; hence, the divergence at $|v|=0$ has no effect. When $c>0$, the probability density has support on $0<|v|<k$ for the same reason. Therefore, if $\beta_{1}>\alpha_{1}kT$, there exists a critical value $c^{*}=\beta_{1}-\alpha_{1}kT$. When $c>c^{*}$, the distribution function $P(|v|)$ becomes $\delta(|v|)$; when $c<c^{*}$, the integral of the distribution function is finite on $0<|v|<k$, indicating that the network learns. If $\beta_{1}\leq\alpha_{1}kT$, there is no criticality and $P(|v|)$ is always equivalent to $\delta(|v|)$. The effect of weight decay can be analyzed similarly; the result is obtained by replacing $\beta_{1}$ with $\beta_{1}+\gamma/k$ for the case $u_{i}=-w_{i}$, or with $\beta_{1}-\gamma/k$ for the case $u_{i}=w_{i}$.
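The criticality in $c$ amounts to a one-line comparison of the small-$|v|$ exponent of Eq. (78) against $-1$; the values below are hypothetical.

```python
# Spot check of the critical value c* = beta_1 - alpha_1 * k * T of Eq. (78),
# comparing the small-|v| exponent of P(|v|) with -1. Values are hypothetical;
# here c* = 0.4.
T, a1, k, b1 = 0.1, 1.0, 1.0, 0.5
for c in (0.2, 0.45):
    phi2 = (b1 - c) / (2 * T * a1 * k)
    power = -1.5 + phi2               # small-|v| exponent of P(|v|)
    verdict = "normalizable" if power > -1 else "collapses to delta(|v|)"
    print(c, power, verdict)
```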

A.7 Second-order Law of Balance

Consider the modified loss function:

tot=+14TL2.\ell_{\text{tot}}=\ell+\frac{1}{4}T||\nabla L||^{2}. (79)

In this case, the Langevin equations become

dw_{j}=-\frac{\partial\ell}{\partial w_{j}}dt-\frac{1}{4}T\frac{\partial||\nabla L||^{2}}{\partial w_{j}}dt, (80)
du_{i}=-\frac{\partial\ell}{\partial u_{i}}dt-\frac{1}{4}T\frac{\partial||\nabla L||^{2}}{\partial u_{i}}dt. (81)

Hence, the modified SDEs of ui2u_{i}^{2} and wj2w_{j}^{2} can be rewritten as

\frac{du_{i}^{2}}{dt}=2u_{i}\frac{du_{i}}{dt}+\frac{(du_{i})^{2}}{dt}=-2u_{i}\frac{\partial\ell}{\partial u_{i}}+TC_{i}^{u}-\frac{1}{2}Tu_{i}\nabla_{u_{i}}||\nabla L||^{2}, (82)
\frac{dw_{j}^{2}}{dt}=2w_{j}\frac{dw_{j}}{dt}+\frac{(dw_{j})^{2}}{dt}=-2w_{j}\frac{\partial\ell}{\partial w_{j}}+TC_{j}^{w}-\frac{1}{2}Tw_{j}\nabla_{w_{j}}||\nabla L||^{2}. (83)

In this section, we consider the effect of the last terms in Eqs. (82) and (83). From the infinitesimal transformation of the rescaling symmetry,

jwjwj=iuiui,\sum_{j}w_{j}\frac{\partial\ell}{\partial w_{j}}=\sum_{i}u_{i}\frac{\partial\ell}{\partial u_{i}}, (84)

we take the derivative of both sides of the equation and obtain

Lui+juj2Luiuj=jwj2Luiwj,\displaystyle\frac{\partial L}{\partial u_{i}}+\sum_{j}u_{j}\frac{\partial^{2}L}{\partial u_{i}\partial u_{j}}=\sum_{j}w_{j}\frac{\partial^{2}L}{\partial u_{i}\partial w_{j}}, (85)
juj2Lwiuj=Lwi+jwj2Lwiwj,\displaystyle\sum_{j}u_{j}\frac{\partial^{2}L}{\partial w_{i}\partial u_{j}}=\frac{\partial L}{\partial w_{i}}+\sum_{j}w_{j}\frac{\partial^{2}L}{\partial w_{i}\partial w_{j}}, (86)

where we have also taken the expectation of $\ell$. By substituting these equations into Eqs. (82) and (83), we obtain

\frac{d||u||^{2}}{dt}-\frac{d||w||^{2}}{dt}=T\sum_{i}(C_{i}^{u}+(\nabla_{u_{i}}L)^{2})-T\sum_{j}(C_{j}^{w}+(\nabla_{w_{j}}L)^{2}). (87)

Then, following the procedure in Appendix A.2, we can rewrite Eq. (87) as

du2dtdw2dt\displaystyle\frac{d||u||^{2}}{dt}-\frac{d||w||^{2}}{dt} =T(uTC1u+uTD1uwTC2wwTD2w)\displaystyle=-T(u^{T}C_{1}u+u^{T}D_{1}u-w^{T}C_{2}w-w^{T}D_{2}w)
=T(uTE1uwTE2w),\displaystyle=-T(u^{T}E_{1}u-w^{T}E_{2}w), (88)

where

(D1)ij\displaystyle(D_{1})_{ij} =k𝔼[(uiwk)]𝔼[(ujwk)],\displaystyle=\sum_{k}\mathbb{E}\left[\frac{\partial\ell}{\partial(u_{i}w_{k})}\right]\mathbb{E}\left[\frac{\partial\ell}{\partial(u_{j}w_{k})}\right], (89)
(D2)kl\displaystyle(D_{2})_{kl} =i𝔼[(uiwk)]𝔼[(uiwl)],\displaystyle=\sum_{i}\mathbb{E}\left[\frac{\partial\ell}{\partial(u_{i}w_{k})}\right]\mathbb{E}\left[\frac{\partial\ell}{\partial(u_{i}w_{l})}\right], (90)
(E1)ij\displaystyle(E_{1})_{ij} =𝔼[k(uiwk)(ujwk)],\displaystyle=\mathbb{E}\left[\sum_{k}\frac{\partial\ell}{\partial(u_{i}w_{k})}\frac{\partial\ell}{\partial(u_{j}w_{k})}\right], (91)
(E2)kl\displaystyle(E_{2})_{kl} =𝔼[i(uiwk)(uiwl)].\displaystyle=\mathbb{E}\left[\sum_{i}\frac{\partial\ell}{\partial(u_{i}w_{k})}\frac{\partial\ell}{\partial(u_{i}w_{l})}\right]. (92)

For one-dimensional parameters $u,w$, Eq. (88) reduces to

\frac{d}{dt}(u^{2}-w^{2})=-T\mathbb{E}\left[\left(\frac{\partial\ell}{\partial(uw)}\right)^{2}\right](u^{2}-w^{2}). (93)
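For completeness, Eq. (93) integrates to an explicit exponential decay (a one-line consequence, added here for clarity):

$(u^{2}-w^{2})(t)=(u^{2}-w^{2})(0)\exp\left(-T\int_{0}^{t}\mathbb{E}\left[\left(\frac{\partial\ell}{\partial(uw)}\right)^{2}\right]dt^{\prime}\right),$

so the imbalance decays at a rate set by the full second moment of the gradient rather than its variance alone.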

Therefore, we see that this loss modification increases the speed of convergence toward balance. Now, we move on to the stationary distribution of the parameter $v$. At stationarity, if $u_{i}=-w_{i}$, we again have the distribution $P(v)=\delta(v)$ as before. However, when $u_{i}=w_{i}$, we have

dvdt=4v(β1vβ2)+4Tv(α1v22α2v+α3)4β12Tv(β1vβ2)(3β1vβ2)+4vT(α1v22α2v+α3)dWdt.\frac{dv}{dt}=-4v(\beta_{1}v-\beta_{2})+4Tv(\alpha_{1}v^{2}-2\alpha_{2}v+\alpha_{3})-4\beta_{1}^{2}Tv(\beta_{1}v-\beta_{2})(3\beta_{1}v-\beta_{2})+4v\sqrt{T(\alpha_{1}v^{2}-2\alpha_{2}v+\alpha_{3})}\frac{dW}{dt}. (94)

Hence, the stationary distribution becomes

P(v)vβ2/2α3T3/2β22/2α3(α1v22α2v+α3)1+β2/4Tα3+K1exp((12Tα3β1α2β2α3Δ+K2)arctanα1vα2Δ),P(v)\propto\frac{v^{\beta_{2}/2\alpha_{3}T-3/2-\beta_{2}^{2}/2\alpha_{3}}}{(\alpha_{1}v^{2}-2\alpha_{2}v+\alpha_{3})^{1+\beta_{2}/4T\alpha_{3}+K_{1}}}\exp\left(-\left(\frac{1}{2T}\frac{\alpha_{3}\beta_{1}-\alpha_{2}\beta_{2}}{\alpha_{3}\sqrt{\Delta}}+K_{2}\right)\arctan\frac{\alpha_{1}v-\alpha_{2}}{\sqrt{\Delta}}\right), (95)

where

K1\displaystyle K_{1} =3α3β12α1β224α1α3,\displaystyle=\frac{3\alpha_{3}\beta_{1}^{2}-\alpha_{1}\beta_{2}^{2}}{4\alpha_{1}\alpha_{3}},
K2\displaystyle K_{2} =3α2α3β124α1α3β1β2+α1α2β222α1α3Δ.\displaystyle=\frac{3\alpha_{2}\alpha_{3}\beta_{1}^{2}-4\alpha_{1}\alpha_{3}\beta_{1}\beta_{2}+\alpha_{1}\alpha_{2}\beta_{2}^{2}}{2\alpha_{1}\alpha_{3}\sqrt{\Delta}}. (96)

From the expressions above, we can see that $K_{1}\ll 1+\beta_{2}/4T\alpha_{3}$ and $K_{2}\ll(\alpha_{3}\beta_{1}-\alpha_{2}\beta_{2})/2T\alpha_{3}\sqrt{\Delta}$ when $T$ is small, since $K_{1}$ and $K_{2}$ are independent of $T$. Hence, the effect of the modification is visible only in the term proportional to $v$. The phase transition point is modified as

Tc=β2α3+β22.T_{c}=\frac{\beta_{2}}{\alpha_{3}+\beta_{2}^{2}}. (97)

Compared with the previous result $T_{c}=\frac{\beta_{2}}{\alpha_{3}}$, we see that the effect of the loss modification is $\alpha_{3}\to\alpha_{3}+\beta_{2}^{2}$, or equivalently, ${\rm Var}[xy]\to\mathbb{E}[x^{2}y^{2}]$. This effect can also be seen from $E_{1}$ and $E_{2}$.
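The replacement ${\rm Var}[xy]\to\mathbb{E}[x^{2}y^{2}]$ can be made concrete with a quick Monte Carlo (hypothetical data; our own sketch):

```python
import numpy as np

# Compare T_c = beta_2 / alpha_3 with the modified T_c = beta_2 / (alpha_3 + beta_2^2).
# Note that alpha_3 + beta_2^2 = Var[xy] + E[xy]^2 = E[(xy)^2]. Data are hypothetical.
rng = np.random.default_rng(3)
x = rng.normal(size=100000)
y = 1.2 * x + 0.5 * rng.normal(size=100000)
b2, a3 = np.mean(x * y), np.var(x * y)
print(b2 / a3)                    # T_c without the modification
print(b2 / (a3 + b2**2))          # modified T_c
print(b2 / np.mean((x * y)**2))   # identical to the modified T_c
```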