
A self consistent theory of Gaussian Processes
captures feature learning effects in finite CNNs

Gadi Naveh$^{1,2}$ and Zohar Ringel$^{1}$
$^{1}$Racah Institute of Physics, Hebrew University, Jerusalem 91904, Israel
$^{2}$Edmond and Lily Safra Center for Brain Sciences, Hebrew University, Jerusalem 91904, Israel
(April 3, 2025)
Abstract

Deep neural networks (DNNs) in the infinite width/channel limit have received much attention recently, as they provide a clear analytical window to deep learning via mappings to Gaussian Processes (GPs). Despite its theoretical appeal, this viewpoint lacks a crucial ingredient of deep learning in finite DNNs, lying at the heart of their success — feature learning. Here we consider DNNs trained with noisy gradient descent on a large training set and derive a self consistent Gaussian Process theory accounting for strong finite-DNN and feature learning effects. Applying this to a toy model of a two-layer linear convolutional neural network (CNN) shows good agreement with experiments. We further identify, both analytically and numerically, a sharp transition between a feature learning regime and a lazy learning regime in this model. Strong finite-DNN effects are also derived for a non-linear two-layer fully connected network. Our self consistent theory provides a rich and versatile analytical framework for studying feature learning and other non-lazy effects in finite DNNs.


I Introduction

The correspondence between Gaussian Processes (GPs) and deep neural networks (DNNs) has been instrumental in advancing our understanding of these complex algorithms. Early results related randomly initialized strongly over-parameterized DNNs with GP priors [31, 21, 28]. More recent results considered training using gradient flow (or noisy gradients), where DNNs map to Bayesian inference on GPs governed by the neural tangent kernel [19, 22] (or the NNGP kernel [30]). These correspondences carry over to a wide variety of architectures, going beyond fully connected networks (FCNs) to convolutional neural networks (CNNs) [3, 34], recurrent neural networks (RNNs) [1] and even attention networks [18]. They provide us with closed analytical expressions for the outputs of strongly over-parameterized trained DNNs, which have been used to make accurate predictions for DNN learning curves [10, 8, 7].

Despite their theoretical appeal, GPs are unable to capture feature learning [42, 41], which is a well-observed key property of trained DNNs. Indeed, it was noticed [19] that as the width tends to infinity, the neural tangent kernel (NTK) tends to a constant kernel that does not evolve during training and the weights in hidden layers change infinitesimally from their initialization values. This regime of training was thus dubbed lazy training [9]. Other studies showed that for CNNs trained on image classification tasks, the feature learning regime generally tends to outperform the lazy regime [13, 14, 23]. Clearly, working in the feature learning regime is also crucial for performing transfer learning [40, 41].

It is therefore desirable to have a theoretical approach to deep learning which enjoys the generality and analytical power of GPs while capturing feature learning effects in finite DNNs. Here we make several contributions towards this goal:

  1.

    We show that the mean predictor of a finite DNN trained on a large data set with noisy gradients, weight decay and MSE loss, can be obtained from GP regression on a shifted target (§III). Central to our approach is a non-linear self consistent equation involving the higher cumulants of the finite DNN (at initialization) which predicts this target shift.

  2.

    Using this machinery on a toy model of a two-layer linear CNN in a teacher-student setting, we derive explicit analytical predictions which are in very good agreement with experiments even well away from the GP/lazy-learning regime of a large number of channels ($C$), thus accounting for strong finite-DNN corrections (§IV). Similarly strong corrections to GPs, yielding qualitative improvements in performance, are demonstrated for the quadratic two-layer fully connected model of Ref. [26].

  3.

    We show how our framework can be used to study statistical properties of weights in hidden layers. In particular, in the CNN toy model, we identify, both analytically and numerically, a sharp transition between a feature learning phase and a lazy learning phase (§IV.4). We define the feature learning phase as the regime where the features of the teacher network leave a clear signature in the spectrum of the student’s hidden weights posterior covariance matrix. In essence, this phase transition is analogous to the transition associated with the recovery of a low-rank signal matrix from a noisy matrix taken from the Wishart ensemble, when varying the strength of the low-rank component [5].

I.1 Additional related work

Several previous papers derived leading order finite-DNN corrections to the GP results [30, 39, 12]. While these results are in principle extendable to any order in perturbation theory, such high order expansions have not been studied much, perhaps due to their complexity. In contrast, we develop an analytically tractable non-perturbative approach which we find crucial for obtaining non-negligible feature learning and associated performance enhancement effects.

Previous works [13, 14] studied how the behavior of infinite DNNs depends on the scaling of the top layer weights with width. Ref. [40] shows that the standard and NTK parameterizations of a neural network do not admit an infinite-width limit that can learn features, and instead suggests an alternative parameterization which can learn features in this limit. While unifying various viewpoints on infinite DNNs, this approach does not immediately lend itself to analytical analysis of the kind proposed here.

Several works [36, 25, 15, 2] show that finite width models can generalize either better or worse than their infinite width counterparts, and provide examples where the relative performance depends on the optimization details, the DNN architecture and the statistics of the data. Here we demonstrate analytically that finite DNNs outperform their GP counterparts when the latter have a prior that lacks some constraint found in the data (e.g. positive-definiteness [26], weight sharing [34]).

Deep linear networks (FCNs and CNNs) similar to our CNN toy example have been studied in the literature [4, 37, 20, 16]. These studies use different approaches and assumptions and do not discuss the target shift mechanism which applies also for non-linear CNNs. In addition, their analytical results hinge strongly on linearity whereas our approach could be useful whenever several leading cumulants of the DNN output are known.

A concurrent work [43] derived exact expressions for the output priors of finite FCNs induced by Gaussian priors over their weights. However, these results only apply to the limited case of a prior over a single training point and only for a FCN. In contrast, our approach applies to the setting of a large training set, it is not restricted to FCNs and yields results for the posterior predictions, not the prior. Focusing on deep linear fully connected DNNs, recent work [24] derived analytical finite-width renormalization results for the GP kernel, by sequentially integrating out the weights of the DNN, starting from the output layer and working backwards towards the input. Our analytical approach, its scope, and the models studied here differ substantially from that work.

II Preliminaries

We consider a fixed set of $n$ training inputs $\left\{{\mathbf{x}}_{\mu}\right\}_{\mu=1}^{n}\subset\mathbb{R}^{d}$ and a single test point ${\mathbf{x}}_{*}$ over which we wish to model the distribution of the outputs of a DNN. We consider a generic DNN architecture where for simplicity we assume a scalar output $f({\mathbf{x}})\in\mathbb{R}$. The learnable parameters of the DNN that determine its output are collected into a single vector $\theta$. We pack the outputs evaluated on the training set and on the test point into a vector $\vec{f}\equiv\left(f\left({\mathbf{x}}_{1}\right),\dots,f\left({\mathbf{x}}_{n}\right),f\left({\mathbf{x}}_{n+1}\right)\right)\in\mathbb{R}^{n+1}$, where we denoted the test point as ${\mathbf{x}}_{*}={\mathbf{x}}_{n+1}$. We train the DNNs using full-batch gradient descent with weight decay and external white Gaussian noise. The discrete dynamics of the parameters are thus

\theta_{t+1}-\theta_{t}=-\left(\gamma\theta_{t}+\nabla_{\theta}\mathcal{L}\left(f_{\theta}\right)\right)\eta+2\sigma\sqrt{\eta}\,\xi_{t} (1)

where $\theta_{t}$ is the vector of all network parameters at time step $t$, $\gamma$ is the strength of the weight decay, $\mathcal{L}(f_{\theta})$ is the loss as a function of the DNN output $f_{\theta}$ (where we have emphasized the dependence on the parameters $\theta$), $\sigma$ is the magnitude of the noise, $\eta$ is the learning rate and $\xi_{t}\sim\mathcal{N}(0,I)$. As $\eta\to 0$ these discrete-time dynamics converge to the continuous-time Langevin equation $\dot{\theta}\left(t\right)=-\nabla_{\theta}\left(\frac{\gamma}{2}||\theta(t)||^{2}+\mathcal{L}\left(f_{\theta}\right)\right)+2\sigma\xi\left(t\right)$ with $\left\langle\xi_{i}(t)\xi_{j}(t^{\prime})\right\rangle=\delta_{ij}\delta\left(t-t^{\prime}\right)$, so that as $t\to\infty$ the DNN parameters $\theta$ are sampled from the equilibrium Gibbs distribution $P(\theta)$.

As shown in [30], the parameter distribution $P(\theta)$ induces a posterior distribution over the trained DNN outputs $P(\vec{f})$ with the following partition function:

Z\left(\vec{J}\right)=\int d\vec{f}\,P_{0}\left(\vec{f}\right)\exp\left(-\frac{1}{2\sigma^{2}}\mathcal{L}\left(\left\{f_{\mu}\right\}_{\mu=1}^{n},\left\{g_{\mu}\right\}_{\mu=1}^{n}\right)+\sum_{\mu=1}^{n+1}J_{\mu}f_{\mu}\right) (2)

Here $P_{0}(\vec{f})$ is the prior generated by the finite DNN with $\theta$ drawn from $\mathcal{N}(0,2\sigma^{2}/\gamma)$, where the weight decay $\gamma$ may be layer-dependent, $\left\{g_{\mu}\right\}_{\mu=1}^{n}$ are the training targets and $\vec{J}$ are source terms used to calculate the statistics of $\vec{f}$. We keep the loss function $\mathcal{L}$ arbitrary at this point, committing to a specific choice in the next section. As standard [17], to calculate the posterior mean at any of the training points or the test point ${\mathbf{x}}_{n+1}$ from this partition function one uses

\forall\mu\in\left\{1,\dots,n+1\right\}:\qquad\left\langle f_{\mu}\right\rangle=\left.\partial_{J_{\mu}}\log Z\left(\vec{J}\right)\right|_{\vec{J}=\vec{0}} (3)

III A self consistent theory for the posterior mean and covariance

In this section we show that for a large training set, the posterior mean predictor (Eq. 3) amounts to GP regression on a shifted target ($g_{\mu}\to g_{\mu}-\Delta g_{\mu}$). This shift to the target ($\Delta g_{\mu}$) is determined by solving certain self-consistent equations involving the cumulants of the prior $P_{0}(\vec{f})$. For concreteness, we focus here on the MSE loss $\mathcal{L}=\sum_{\mu=1}^{n}\left(f_{\mu}-g_{\mu}\right)^{2}$ and comment on extensions to other losses, e.g. the cross entropy, in App. C. To this end, consider first the prior of the output of a finite DNN. Using standard manipulations (see App. A), it can be expressed as follows

P_{0}\left(\vec{f}\right)\propto\int_{\mathbb{R}^{n+1}}d\vec{t}\,\exp\left(-\sum_{\mu=1}^{n+1}it_{\mu}f_{\mu}+\sum_{r=2}^{\infty}\frac{1}{r!}\sum_{\mu_{1},\dots,\mu_{r}=1}^{n+1}\kappa_{\mu_{1},\dots,\mu_{r}}it_{\mu_{1}}\cdots it_{\mu_{r}}\right) (4)

where $\kappa_{\mu_{1},\dots,\mu_{r}}$ is the $r$'th multivariate cumulant of $P_{0}(\vec{f})$ [29]. The second term in the exponent is the cumulant generating function ($\mathcal{C}$) corresponding to $P_{0}$. As discussed in App. B and Ref. [30], for standard initialization protocols the $r$'th cumulant will scale as $1/C^{(r/2-1)}$, where $C$ controls the over-parameterization. The second ($r=2$) cumulant, which is $C$-independent, describes the NNGP kernel of the finite DNN and is denoted by $K({\mathbf{x}}_{\mu_{1}},{\mathbf{x}}_{\mu_{2}})=\kappa_{\mu_{1},\mu_{2}}$.

Consider first the case of $C\rightarrow\infty$ [21, 28, 31] where all $r>2$ cumulants vanish. Here one can explicitly perform the integration in Eq. 4 to obtain the standard GP prior $P_{0}\left(\vec{f}\right)\propto\exp\left(-\frac{1}{2}\sum^{n+1}_{\mu_{1},\mu_{2}=1}\left[\kappa^{-1}\right]_{\mu_{1},\mu_{2}}f_{\mu_{1}}f_{\mu_{2}}\right)$. Plugging this prior into Eq. 2 with the MSE loss, one recovers the standard GP regression formulas [35]. In particular, the predictive mean at ${\mathbf{x}}_{*}$ is $\left\langle f\left({\mathbf{x}}_{*}\right)\right\rangle=\sum_{\mu,\nu=1}^{n}K_{\mu}^{*}\tilde{K}_{\mu\nu}^{-1}g_{\nu}$, where $K_{\mu}^{*}=K\left({\mathbf{x}}_{*},{\mathbf{x}}_{\mu}\right)$ and $\tilde{K}_{\mu\nu}=K\left({\mathbf{x}}_{\mu},{\mathbf{x}}_{\nu}\right)+\sigma^{2}\delta_{\mu\nu}$. Another set of quantities we shall find useful are the discrepancies in the GP prediction, which for the training set read

\boxed{\forall\mu\in\left\{1,\dots,n\right\}:\quad\langle\hat{\delta}g_{\mu}\rangle\equiv g_{\mu}-\left\langle f\left({\mathbf{x}}_{\mu}\right)\right\rangle=g_{\mu}-\sum_{\nu,\nu^{\prime}=1}^{n}K_{\mu\nu^{\prime}}\tilde{K}_{\nu^{\prime},\nu}^{-1}g_{\nu}} (5)
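For readers who prefer code, the GP-limit quantities above (the predictive mean and the training-set discrepancies of Eq. 5) can be computed in a few lines; the following numpy sketch uses a placeholder linear kernel and random data rather than any of the models studied below.

```python
import numpy as np

def gp_mean_and_discrepancy(K_train, K_star, g, sigma2):
    """GP regression with a Gaussian-noise (MSE) likelihood.

    K_train : (n, n) kernel matrix on the training inputs
    K_star  : (n,)   kernel vector between the test point and the training inputs
    g       : (n,)   training targets
    sigma2  : observation-noise variance (sigma^2 in the text)
    """
    n = len(g)
    K_tilde = K_train + sigma2 * np.eye(n)       # \tilde{K} = K + sigma^2 I
    coef = np.linalg.solve(K_tilde, g)           # \tilde{K}^{-1} g
    f_star = K_star @ coef                       # predictive mean at x_*
    delta_g = g - K_train @ coef                 # training discrepancies of Eq. (5)
    return f_star, delta_g

# toy usage with a placeholder linear kernel and random data
rng = np.random.default_rng(0)
X = rng.standard_normal((50, 10))
x_star = rng.standard_normal(10)
g = X @ rng.standard_normal(10)
K_train, K_star = X @ X.T / 10, X @ x_star / 10
print(gp_mean_and_discrepancy(K_train, K_star, g, sigma2=0.1)[0])
```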
Saddle point approximation for the mean predictor.

For a DNN with finite $C$, the prior $P_{0}(\vec{f})$ will no longer be Gaussian and cumulants with $r>2$ would contribute. This renders the partition function in Eq. 2 intractable and so some approximation is needed to make progress. To this end we note that $\vec{f}$ can be integrated out (see App. A.1) to yield a partition function of the form

Z\left(\vec{J}\right)\propto\int_{\mathbb{R}^{n}}dt_{1}\cdots dt_{n}\,e^{-\mathcal{S}\left(\vec{t},\vec{J}\right)} (6)

where $\mathcal{S}(\vec{t},\vec{J})$ is the action whose exact form is given in Eq. A.14. Interestingly, the $it_{\mu}$ variables appearing above are closely related to the discrepancies $\hat{\delta}g_{\mu}$, in particular $\langle it_{\mu}\rangle=\langle\hat{\delta}g_{\mu}\rangle/\sigma^{2}$.

To proceed analytically we adopt the saddle point (SP) approximation [11] which, as argued in App. A.4, relies on the fact that the non-linear terms in the action comprise a sum of many $it_{\mu}$'s. Given that this sum is dominated by collective effects coming from all data points, expanding $\mathcal{S}(\vec{t},\vec{J})$ around the saddle point yields terms with increasingly negative powers of $n$.

For the training points $\mu\in\left\{1,\dots,n\right\}$, taking the saddle point approximation amounts to setting $\left.\partial_{it_{\mu}}\mathcal{S}\left(\vec{t},\vec{J}\right)\right|_{\vec{J}=\vec{0}}=0$. This yields a set of equations that has precisely the form of Eq. 5, but where the target is shifted as $g_{\nu}\to g_{\nu}-\Delta g_{\nu}$ and the target shift is determined self consistently by

\boxed{\Delta g_{\nu}=\sum_{r=3}^{\infty}\frac{1}{\left(r-1\right)!}\sum_{\mu_{1},\dots,\mu_{r-1}=1}^{n}\kappa_{\nu,\mu_{1},\dots,\mu_{r-1}}\left\langle\sigma^{-2}\hat{\delta}g_{\mu_{1}}\right\rangle\cdots\left\langle\sigma^{-2}\hat{\delta}g_{\mu_{r-1}}\right\rangle} (7)

Equation 7 is thus an implicit equation for $\Delta g_{\nu}$ involving all training points, and it holds for the training set and the test point, $\nu\in\left\{1,\dots,n+1\right\}$. Once solved, either analytically or numerically, one calculates the predictions on the test point via

\boxed{\left\langle f_{*}\right\rangle=\Delta g_{*}+\sum_{\mu,\nu=1}^{n}K_{\mu}^{*}\tilde{K}_{\mu\nu}^{-1}\left(g_{\nu}-\Delta g_{\nu}\right)} (8)

Equation 5 with $g_{\nu}\to g_{\nu}-\Delta g_{\nu}$, along with Eqs. 7 and 8, is the first main result of this paper. Viewed as an algorithm, the procedure to predict the finite DNN's output on a test point ${\mathbf{x}}_{*}$ is as follows: we shift the target in Eq. 5 as $g\to g-\Delta g$ with $\Delta g$ as in Eq. 7, arriving at a closed equation for the average discrepancies $\langle\hat{\delta}g_{\mu}\rangle$ on the training set. For some models, the cumulants $\kappa_{\nu,\mu_{2},\dots,\mu_{r}}$ can be computed for any order $r$ and it may be possible to sum the entire series, while for other models several leading cumulants might already give a reasonable approximation due to their $1/C^{r/2-1}$ scaling. The resulting coupled non-linear equations can then be solved numerically to obtain $\Delta g_{\mu}$, from which predictions on the test point are calculated using Eq. 8.
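As a concrete (if naive) illustration of this procedure, the sketch below iterates Eqs. 5 and 7 to a fixed point with the cumulant series truncated at fourth order (odd cumulants vanish for the models considered here) and then applies Eq. 8. The fourth-cumulant arrays, the damping factor and the iteration count are placeholders; a real application would supply model-specific cumulants such as Eq. 14 below and handle convergence more carefully.

```python
import numpy as np

def self_consistent_predict(K, K_star, kappa4, kappa4_star, g, sigma2,
                            n_iter=200, damping=0.5):
    """Fixed-point iteration for the target shift (Eqs. 5, 7, 8), keeping only
    the fourth cumulant (odd cumulants vanish for the models considered here).
    Meant for small n; the full kappa4 tensor is O(n^4).

    K           : (n, n)       NNGP kernel on the training set
    K_star      : (n,)         kernel between the test point and the training set
    kappa4      : (n, n, n, n) fourth cumulant on the training set (model-specific)
    kappa4_star : (n, n, n)    fourth cumulant with one index at the test point
    """
    n = len(g)
    K_tilde = K + sigma2 * np.eye(n)
    dg = np.zeros(n)                                  # target shift Delta g
    for _ in range(n_iter):
        it = np.linalg.solve(K_tilde, g - dg)         # <it_mu> = <dhat g_mu>/sigma^2
        # Eq. (7) truncated at r = 4:
        # Delta g_nu = (1/3!) sum_{abc} kappa4[nu,a,b,c] <it_a><it_b><it_c>
        dg_new = np.einsum('nabc,a,b,c->n', kappa4, it, it, it) / 6.0
        dg = (1 - damping) * dg + damping * dg_new    # damped update
    it = np.linalg.solve(K_tilde, g - dg)
    dg_star = np.einsum('abc,a,b,c->', kappa4_star, it, it, it) / 6.0
    return dg_star + K_star @ np.linalg.solve(K_tilde, g - dg)   # Eq. (8)
```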

Notwithstanding, solving such equations analytically is challenging and one of our main goals here is to provide concrete analytical insights. Thus, in §IV.2 we propose an additional approximation wherein to leading order we replace all summations over data-points with integrals over the measure from which the data-set is drawn. This approximation, taken in some cases beyond leading order as in Ref. [10], will yield analytically tractable equations which we solve for two simple toy models, one of a linear CNN and the other of a non-linear FCN.

Saddle point plus Gaussian fluctuations for the posterior covariance.

The SP approximation can be extended to compute the predictor variance by expanding the action $\mathcal{S}$ to quadratic order in $it_{\mu}$ around the SP value (see App. A.3). Due to the saddle point being an extremum this leads to $\mathcal{S}\approx\mathcal{S}_{\rm{SP}}+\frac{1}{2}t_{\mu}A^{-1}_{\mu\nu}t_{\nu}$. This leaves the previous SP approximation for the posterior mean on the training set unaffected (since the mean and maximizer of a Gaussian coincide), but is necessary to get sensible results for the posterior covariance. Empirically, in the toy models we considered in §IV we find that the finite-DNN corrections to the variance are much less pronounced than those for the mean. Using the standard Gaussian integration formula, one finds that $A_{\mu\nu}$ is the covariance matrix of $it_{\mu}$. Performing such an expansion one finds

A^{-1}_{\mu\nu}=-\left(\sigma^{2}\delta_{\mu\nu}+K_{\mu\nu}+\Delta K_{\mu\nu}\right) (9)
\Delta K_{\mu\nu}=\partial_{it_{\mu}}\Delta g_{\nu}\left(it_{1},\dots,it_{n}\right)

where the $it_{\mu}$ on the r.h.s. are those of the saddle point. This gives an expression for the posterior covariance matrix on the training set:

\Sigma_{\mu\nu}=\left\langle f_{\mu}f_{\nu}\right\rangle-\left\langle f_{\mu}\right\rangle\left\langle f_{\nu}\right\rangle=-\sigma^{4}\left[\sigma^{2}I+K+\Delta K\right]_{\mu\nu}^{-1}+\sigma^{2}\delta_{\mu\nu} (10)

where the r.h.s. coincides with the posterior covariance of a GP with a kernel equal to $K+\Delta K$ [35]. The variance on the test point is given by (repeated indices are summed over the training set)

\Sigma_{**}=K_{**}-K_{\mu}^{*}A_{\mu\nu}^{-1}K_{\nu}^{*}+\left\langle\partial_{it_{*}}^{2}\tilde{\mathcal{C}}\big|_{it_{*}=0}\right\rangle+2\left(\left\langle\Delta g_{*}K_{\mu}^{*}it_{\mu}\right\rangle-\left\langle\Delta g_{*}\right\rangle\left\langle K_{\mu}^{*}it_{\mu}\right\rangle\right)+{\mathrm{Var}}\left(\Delta g_{*}\right) (11)

where here $\Delta g_{*}$ is as in Eq. 7 but with the $\left\langle\sigma^{-2}\hat{\delta}g_{\mu}\right\rangle$'s replaced by the $it_{\mu}$'s, which have Gaussian fluctuations, and $\tilde{\mathcal{C}}$ is $\mathcal{C}$ without the second cumulant (see App. A.1). The first two terms in Eq. 11 yield the standard result for the GP posterior covariance on a test point [35] in the case $\Delta K=0$ (see Eq. 9). The remaining terms can be evaluated within the SP plus Gaussian fluctuations approximation, where the details depend on the model at hand.
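For completeness, a one-function numpy sketch of the training-set covariance of Eq. 10, assuming the shift kernel $\Delta K$ of Eq. 9 has already been computed for the model at hand and is simply passed in as an array:

```python
import numpy as np

def posterior_covariance_train(K, delta_K, sigma2):
    """Training-set posterior covariance of Eq. (10):
    Sigma = sigma^2 I - sigma^4 [sigma^2 I + K + Delta K]^{-1}."""
    n = K.shape[0]
    A_inv = sigma2 * np.eye(n) + K + delta_K      # minus A^{-1} of Eq. (9)
    return sigma2 * np.eye(n) - sigma2**2 * np.linalg.inv(A_inv)
```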

IV Two toy models

IV.1 The two layer linear CNN and its properties

Here we define a teacher-student toy model showing several qualitative real-world aspects of feature learning and analyze it via our self-consistent shifted target approach. Concretely, we consider the simplest student CNN $f({\mathbf{x}})$, having one hidden layer with linear activation, and a corresponding teacher CNN, $g({\mathbf{x}})$:

f\left({\mathbf{x}}\right)=\sum_{i=1}^{N}\sum_{c=1}^{C}a_{i,c}\,{\mathbf{w}}_{c}\cdot\tilde{{\mathbf{x}}}_{i}\qquad g\left({\mathbf{x}}\right)=\sum_{i=1}^{N}\sum_{c=1}^{C^{*}}a^{*}_{i,c}\,{\mathbf{w}}^{*}_{c}\cdot\tilde{{\mathbf{x}}}_{i} (12)

This describes a CNN that performs a 1-dimensional convolution where the convolutional weights for each channel are ${\mathbf{w}}_{c}\in\mathbb{R}^{S}$. These are dotted with a convolutional window of the input $\tilde{{\mathbf{x}}}_{i}=\left(x_{S\left(i-1\right)+1},\dots,x_{S\cdot i}\right)^{\mathsf{T}}\in\mathbb{R}^{S}$, and there are no overlaps between the windows, so that ${\mathbf{x}}=\left(x_{1},\dots,x_{N\cdot S}\right)^{\mathsf{T}}=\left(\tilde{{\mathbf{x}}}_{1},\dots,\tilde{{\mathbf{x}}}_{N}\right)^{\mathsf{T}}\in\mathbb{R}^{N\cdot S}$. Namely, the input dimension is $d=NS$, where $N$ is the number of (non-overlapping) convolutional windows and $S$ is both the length and the stride of the conv-kernel, hence the windows do not overlap. The inputs ${\mathbf{x}}$ are sampled from $\mathcal{N}(0,I_{d})$.

Despite its simplicity, this model distils several key differences between feature learning models and lazy learning or GP models. Due to the lack of pooling layers, the GP associated with the student fails to take advantage of the weight sharing property of the underlying CNN [34]. In fact, here it coincides with the GP of a fully-connected DNN, which is quite inappropriate for the task. We thus expect that the finite network will have good performance already for $n=C^{*}(N+S)$ whereas the GP will need $n$ of the order of the input dimension ($NS$) to learn well [10]. Thus, for $N+S\ll NS$ there should be a broad regime in the value of $n$ where the finite network substantially outperforms the corresponding GP. We later show (§IV.4) that this performance boost over the GP is due to feature learning, as one may expect.

Conveniently, the cumulants of the student DNN can be worked out exactly to any order. Assuming $\gamma$ and $\sigma^{2}$ of the noisy GD training are chosen such that $a_{i,c}\sim\mathcal{N}\left(0,\sigma_{a}^{2}/CN\right)$ and ${\mathbf{w}}_{c}\sim\mathcal{N}\left(\bm{0},\frac{\sigma_{w}^{2}}{S}I_{S}\right)$ (generically this requires a $C$-dependent and layer-dependent weight decay), and similarly for the teacher DNN, the covariance function for the associated GP reads

K\left({\mathbf{x}},{\mathbf{x}}^{\prime}\right)=\frac{\sigma_{a}^{2}\sigma_{w}^{2}}{NS}\sum_{i=1}^{N}\tilde{{\mathbf{x}}}_{i}^{\mathsf{T}}\tilde{{\mathbf{x}}}_{i}^{\prime}=\frac{\sigma_{a}^{2}\sigma_{w}^{2}}{NS}\,{\mathbf{x}}^{\mathsf{T}}{\mathbf{x}}^{\prime} (13)

Denoting $\lambda:=\frac{\sigma_{a}^{2}}{N}\frac{\sigma_{w}^{2}}{S}$, the even cumulant of arbitrary order $2m$ is (see App. F):

\kappa_{2m}\left({\mathbf{x}}_{1},\dots,{\mathbf{x}}_{2m}\right)=\frac{\lambda^{m}}{C^{m-1}}\sum_{i_{1},\dots,i_{m}=1}^{N}\left(\bullet_{i_{1}},\bullet_{i_{2}}\right)\cdots\left(\bullet_{i_{m-2}},\bullet_{i_{m-1}}\right)\left(\bullet_{i_{m-1}},\bullet_{i_{m}}\right)\cdots\left[\left(2m-1\right)!\right] (14)

while all odd cumulants vanish due to the sign-flip symmetry of the last layer. In this notation, the $\bullet$'s stand for integers in $\left\{1,\dots,2m\right\}$, e.g. $\left(1_{i_{1}},2_{i_{2}}\right)\equiv\left(\tilde{{\mathbf{x}}}_{i_{1}}^{1}\cdot\tilde{{\mathbf{x}}}_{i_{2}}^{2}\right)$, and the bracket notation $\left[\left(2m-1\right)!\right]$ stands for the number of ways to pair the integers $\left\{1,\dots,2m\right\}$ into the above form. This result can then be plugged into Eq. 7 to obtain the self consistent (saddle point) equations on the training set. See App. A.4 for a convergence criterion for the saddle point, supporting its application here.
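Both Eq. 13 and the $1/C^{m-1}$ suppression of the higher cumulants in Eq. 14 can be checked by straightforward Monte Carlo sampling of the prior weights; a minimal sketch with arbitrary small sizes (the sample size is only large enough for a rough check):

```python
import numpy as np

rng = np.random.default_rng(0)
N, S, C, n_samples = 4, 5, 4, 200_000
sigma_a2 = sigma_w2 = 1.0

x1, x2 = rng.standard_normal(N * S), rng.standard_normal(N * S)

def cnn_prior_outputs(x, a, w):
    # f(x) = sum_{i,c} a_{i,c} w_c . x~_i for the student CNN of Eq. (12)
    windows = x.reshape(N, S)                 # non-overlapping conv windows x~_i
    return np.einsum('kic,kcs,is->k', a, w, windows)

a = rng.normal(0.0, np.sqrt(sigma_a2 / (C * N)), size=(n_samples, N, C))
w = rng.normal(0.0, np.sqrt(sigma_w2 / S), size=(n_samples, C, S))

f1, f2 = cnn_prior_outputs(x1, a, w), cnn_prior_outputs(x2, a, w)

# second cumulant vs. the NNGP kernel of Eq. (13)
print(np.mean(f1 * f2), sigma_a2 * sigma_w2 / (N * S) * (x1 @ x2))

# fourth cumulant of f(x1): suppressed as 1/C relative to the variance (Eq. 14, m = 2)
kappa4 = np.mean(f1**4) - 3 * np.mean(f1**2)**2
print(kappa4, np.mean(f1**2))
```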

IV.2 Self consistent equation in the limit of a large training set

In §III our description of the self consistent equations was for a finite and fixed training set. Further analytical insight can be gained if we consider the limit of a large training set, known in the GP literature as the Equivalent Kernel (EK) limit [35, 38]. For a short review of this topic, see App. D. In essence, in the EK limit we replace the discrete sums over a specific draw of the training set, as in Eqs. 5, 7, 8, with integrals over the entire input distribution $\mu({\mathbf{x}})$. Given a kernel that admits a spectral decomposition in terms of its eigenvalues and eigenfunctions, $K\left({\mathbf{x}},{\mathbf{x}}^{\prime}\right)=\sum_{s}\lambda_{s}\psi_{s}\left({\mathbf{x}}\right)\psi_{s}\left({\mathbf{x}}^{\prime}\right)$, the standard result for the GP posterior mean at a test point is approximated by [35]

\left\langle f\left({\mathbf{x}}_{*}\right)\right\rangle=\int d\mu\left(\mathbf{x}\right)h\left(\mathbf{x}_{*},{\mathbf{x}}\right)g\left({\mathbf{x}}\right);\qquad h\left(\mathbf{x}_{*},{\mathbf{x}}\right)=\sum_{s}\frac{\lambda_{s}}{\lambda_{s}+\sigma^{2}/n}\psi_{s}\left({\mathbf{x}}_{*}\right)\psi_{s}\left({\mathbf{x}}\right) (15)

This has several advantages, already at the level of the GP analysis. From a theoretical point of view, the integral expressions retain the symmetries of the kernel $K({\mathbf{x}},{\mathbf{x}}^{\prime})$, unlike the discrete sums which break them. Also, Eq. 15 does not involve computing the inverse matrix $\tilde{K}^{-1}$, which is costly for large matrices.
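A simple way to evaluate Eq. 15 numerically is to discretize the kernel operator on a large sample drawn from $\mu({\mathbf{x}})$ and use a Nystrom-style extension for the eigenfunctions at the test point. The sketch below does this for a placeholder linear kernel and linear target, which are illustrative choices rather than the models of this paper.

```python
import numpy as np

def ek_predictor(kernel, X_sample, g_sample, x_star, sigma2, n):
    """Equivalent-Kernel prediction (Eq. 15), discretized on a large sample from mu(x).
    Operator eigenvalues lambda_s ~ eig(G)/m; eigenfunctions psi_s(x_j) ~ sqrt(m) V[:, s];
    psi_s(x_*) is obtained by a Nystrom extension."""
    m = len(X_sample)
    G = kernel(X_sample, X_sample)                    # (m, m) Gram matrix
    evals, V = np.linalg.eigh(G)
    keep = evals / m > 1e-10                          # drop numerically null modes
    lam, V = evals[keep] / m, V[:, keep]
    psi = np.sqrt(m) * V                              # psi_s at the sample points
    k_star = kernel(x_star[None, :], X_sample)[0]     # K(x_*, x_j)
    psi_star = (k_star @ V) / (np.sqrt(m) * lam)      # Nystrom extension of psi_s(x_*)
    weights = lam / (lam + sigma2 / n)                # lambda_s / (lambda_s + sigma^2/n)
    h_star = psi @ (weights * psi_star)               # h(x_*, x_j)
    return np.mean(h_star * g_sample)                 # integral over mu(x) -> sample mean

# placeholder: linear kernel (cf. Eq. 13 with sigma_a = sigma_w = 1) and a linear target
d = 20
kernel = lambda A, B: (A @ B.T) / d
rng = np.random.default_rng(1)
X = rng.standard_normal((2000, d))
g = X @ (rng.standard_normal(d) / np.sqrt(d))
print(ek_predictor(kernel, X, g, X[0], sigma2=0.01, n=500), g[0])
```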

In the context of our theory, the EK limit allows for the derivation of a simple analytical form of the self consistent equations. As shown in App. E.1, in our toy CNN both $\Delta g$ and $\hat{\delta}g$ become linear in the target. Thus the self-consistent equations can be reduced to a single equation governing the proportionality factor $\alpha$ between $\hat{\delta}g$ and $g$ ($\hat{\delta}g=\alpha g$). Starting from the general self consistent equations (Eqs. 5, 7, 8), taking their EK limit, and plugging in the general cumulant of our toy model (Eq. 14), we arrive at the following equation for $\alpha$:

\alpha=\frac{\sigma^{2}/n}{\lambda+\sigma^{2}/n}+\frac{\left(1-q\right)\lambda}{\lambda+\sigma^{2}/n}+\left(q\frac{\lambda}{\lambda+\sigma^{2}/n}-1\right)\frac{\lambda^{2}}{C}\left(\frac{\alpha}{\sigma^{2}/n}\right)^{3}\left[1-\frac{\lambda}{C}\left(\frac{\alpha}{\sigma^{2}/n}\right)^{2}\right]^{-1} (16)

Setting for simplicity $\sigma_{a}^{2}=1=\sigma_{w}^{2}$ we have $\lambda=1/\left(NS\right)$, and we also introduced the constant $q\equiv\lambda^{-1}(1-\hat{\alpha}_{\rm{GP}})(\lambda+\sigma^{2}/n)$, where $\hat{\alpha}_{\mathrm{GP}}$ is computed either using the empirical GP predictions on the training set or test set, $\hat{\alpha}_{\rm{GP}}\equiv 1-\left(\sum_{\mu}f_{\mu}^{\rm{GP}}g_{\mu}\right)/\left(\sum_{\mu}g_{\mu}^{2}\right)$, or analytically in the perturbation theory approach developed in [10]. The quantity $q$ has an interpretation as $1/n$ corrections to the EK approximation [10], but here it can be considered as a fitting parameter. It is non-negative and typically $O(1)$; for more details and analytical estimates see App. E.2.

Equation 16 is the second main analytical result of this work. It reduces a highly non-linear inference problem to a single equation that embodies strong non-linear finite-DNN effects and feature learning (see also §IV.4). In practice, to compute $\alpha_{\mathrm{test}}$ we numerically solve Eq. 16 using $q_{\mathrm{train}}$ for the training set to get $\alpha_{\mathrm{train}}$, and then set $\alpha=\alpha_{\mathrm{train}}$ on the r.h.s. of Eq. 16 but use $q=q_{\mathrm{test}}$. Equation 16 can also be used to bound $\alpha$ analytically on both the training set and the test point, given the reasonable assumption that $\alpha$ changes continuously with $C$. Indeed, at large $C$ the pole in this equation lies at $\alpha_{\rm{pole}}=(\sigma^{2}/n)(C/\lambda)^{1/2}\gg 1$ whereas $\alpha\approx\alpha_{\rm{GP}}<\alpha_{\rm{pole}}$. As $C$ diminishes, continuity implies that $\alpha$ must remain smaller than $\alpha_{\rm{pole}}$. The latter decays as $\sigma^{2}\sqrt{CNS}/n$, implying that the amount of data required for good performance scales as $\sqrt{CNS}$ rather than as $NS$ in the GP case.
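The following sketch carries out this procedure for Eq. 16: it solves for $\alpha_{\rm train}$ by root finding below the pole and then evaluates the r.h.s. once more with $q=q_{\rm test}$. The values of $q$ and the sizes $N,S,C,n,\sigma^2$ are placeholders; in practice $q$ would be extracted from the empirical GP predictions as described above.

```python
import numpy as np
from scipy.optimize import brentq

def rhs(alpha, q, lam, C, s):
    """Right-hand side of the self-consistent equation (16); s = sigma^2/n."""
    u = alpha / s
    core = (lam**2 / C) * u**3 / (1.0 - (lam / C) * u**2)
    return (s / (lam + s)
            + (1.0 - q) * lam / (lam + s)
            + (q * lam / (lam + s) - 1.0) * core)

def solve_alpha(q_train, q_test, N, S, C, n, sigma2):
    lam, s = 1.0 / (N * S), sigma2 / n             # sigma_a^2 = sigma_w^2 = 1
    alpha_pole = s * np.sqrt(C / lam)              # pole of the r.h.s. of Eq. (16)
    # alpha_train solves alpha = rhs(alpha, q_train); the root lies below the pole
    alpha_train = brentq(lambda a: a - rhs(a, q_train, lam, C, s),
                         1e-12, 0.999 * alpha_pole)
    # alpha_test: evaluate the r.h.s. at alpha_train with q = q_test
    return alpha_train, rhs(alpha_train, q_test, lam, C, s)

# placeholder values of q (non-negative, order one) and of the model sizes
print(solve_alpha(q_train=1.0, q_test=1.0, N=30, S=30, C=100, n=200, sigma2=0.01))
```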

IV.3 Numerical verification

In this section we numerically verify the predictions of the self consistent theory of §IV.2, by training linear shallow student CNNs on a teacher with $C^{*}=1$ as in Eq. 12, using noisy gradients as in Eq. 1, and averaging their outputs across noise realizations and across the dynamics after reaching equilibrium.

For simplicity we used $N=S$ and $n\in\left\{62,200,650\right\}$, $S\in\left\{15,30,60\right\}$, so that $n\propto S^{1.7}$. The latter scaling places us in the poorly performing regime of the associated GP while allowing good performance of the CNN. Indeed, as mentioned above, the GP here requires $n$ on the scale of $\lambda^{-1}=NS=O(S^{2})$ for good performance [10], while the CNN requires $n$ on the scale of the number of parameters ($C(N+S)=O(S)$).

The results are shown in Fig. 1, where we compare the theoretical predictions given by the solutions of the self consistent equation (16) to the empirical values of $\alpha$ obtained by training actual CNNs and averaging their outputs across the ensemble.



Figure 1: (A) The CNNs' cosine distance $\alpha$, defined by $\left\langle f\right\rangle=(1-\alpha)g$, between the ensemble-averaged prediction $\left\langle f\right\rangle$ and the ground truth $g$, plotted vs. the number of channels $C$ for the test set (for the training set, see App. H.1). As $n$ increases, the solution of the self consistent equation 16 (solid line) yields an increasingly accurate prediction of these empirical values (dots). (B) Same data as in (A), presented as empirical $\alpha$ vs. predicted $\alpha$. As $n$ grows, the two converge to the identity line (dashed black line). Solid lines connecting the dots here are merely for visualization purposes.

IV.4 Feature learning phase transition in the CNN model

At this point there is evidence that our self-consistent shifted target approach works well within the feature learning regime of the toy model. Indeed, the GP is sub-optimal here, since it does not represent the CNN's weight sharing present in the teacher network. Weight sharing is intimately tied to feature learning in the first layer, since it aggregates the information coming from all convolutional windows to refine a single set of repeating convolution filters. Empirically, we observed a large performance gap of finite-$C$ CNNs over the infinite-$C$ (GP) limit, which was also observed previously in more realistic settings [23, 13, 34]. Taken together with the existence of a clear feature in the teacher, a natural explanation for this performance gap is that feature learning, which is completely absent in GPs, plays a major role in the behavior of finite-$C$ CNNs.

To analyze this we wish to track how the feature of the teacher, ${\mathbf{w}}^{*}$, is reflected in the student network's first layer weights ${\mathbf{w}}_{c}$ across training time (after reaching equilibrium) and across training realizations. However, as our formalism deals with ensembles of DNNs, computing averages of ${\mathbf{w}}_{c}$ with respect to these ensembles would simply give zero. Indeed, the chance of a DNN with specific parameters $\theta=\left\{a_{i,c},\,{\mathbf{w}}_{c}\right\}$ appearing is the same as that of $-\theta$. Consequently, to detect feature learning the first reasonable object to examine is the empirical covariance matrix $\Sigma_{W}=\frac{S}{C}WW^{\mathsf{T}}$, where the matrix $W\in\mathbb{R}^{S\times C}$ has ${\mathbf{w}}_{c}$ as its $c$'th column. This $\Sigma_{W}$ is invariant under such a change of signs and provides important information on the statistics of ${\mathbf{w}}_{c}$.

As shown in App. G, using our field-theory or function-space formulation, we find that to leading order in $1/C$ the ensemble average of the empirical covariance matrix, for a teacher with a single feature ${\mathbf{w}}^{*}$, is

\left\langle\left[\Sigma_{W}\right]_{ss^{\prime}}\right\rangle=\left(1+\left(\frac{1}{\lambda}+\frac{n}{\sigma^{2}}\right)^{-1}\right)\delta_{ss^{\prime}}+\frac{2}{C}\frac{\lambda}{\left(\lambda+\sigma^{2}/n\right)^{2}}w_{s}^{*}w_{s^{\prime}}^{*}+O(1/C^{2}) (17)

A first conclusion that can be drawn here is that, given access to an ensemble of such trained CNNs, feature learning happens for any finite $C$ as a statistical property. We turn to discuss the more common setting where one wishes to use the features learned by a specific, randomly chosen CNN from this ensemble.

To this end, we follow Ref. [27] and model $\Sigma_{W}$ as a Wishart matrix with a rank-one perturbation. The variance of the matrix and the details of the rank-one perturbation are then determined by the above equation. Consequently the eigenvalue distribution is expected to follow a spiked Marchenko-Pastur (MP) distribution, which was studied extensively in [6]. To test this modeling assumption, for each snapshot of training time (after reaching equilibrium) and noise realization we compute $\Sigma_{W}$'s eigenvalues and aggregate these across the ensemble. In Fig. 2 we plot the resulting empirical spectral distribution for varying values of $C$ while keeping $S$ fixed. Note that, differently from the usual spiked-MP model, varying $C$ here changes both the distribution of the MP bulk (which is determined by the ratio $S/C$) and the strength of the low-rank perturbation.
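As a sanity check of this modeling assumption at the lazy end (no feature in the weights), one can sample $W$ from the prior, form $\Sigma_W=\frac{S}{C}WW^{\mathsf T}$, and compare its spectrum and the surrogate $\mathcal{Q}$ with the MP edges $\lambda_\pm=(1\pm\sqrt{S/C})^2$ quoted in Fig. 2; a minimal sketch (the teacher direction here is a random placeholder):

```python
import numpy as np

rng = np.random.default_rng(2)
S, C, sigma_w2 = 30, 300, 1.0

# hidden-layer weights sampled from the prior: columns w_c ~ N(0, (sigma_w^2/S) I_S)
W = rng.normal(0.0, np.sqrt(sigma_w2 / S), size=(S, C))
Sigma_W = (S / C) * W @ W.T

evals = np.linalg.eigvalsh(Sigma_W)
lam_minus, lam_plus = (1 - np.sqrt(S / C))**2, (1 + np.sqrt(S / C))**2
print("spectrum range:", evals.min(), evals.max())
print("MP edges      :", lam_minus, lam_plus)

# surrogate Q = w*^T Sigma_W w* for a random unit 'teacher' direction w*
w_star = rng.standard_normal(S)
w_star /= np.linalg.norm(w_star)
print("Q =", w_star @ Sigma_W @ w_star)   # stays inside the MP bulk when no feature is present
```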

Our main finding is a phase transition between two regimes, which becomes sharp as one takes $n,S\rightarrow\infty$. In the regime of large $C$ the eigenvalue distribution of $\Sigma_{W}$ is indistinguishable from the MP distribution, whereas in the regime of small $C$ an outlier eigenvalue $\lambda_{m}$ departs from the support of the bulk MP distribution and the associated top eigenvector has a non-zero overlap with ${\mathbf{w}}^{*}$, see Fig. 2. We refer to the latter as the feature-learning regime, since the feature ${\mathbf{w}}^{*}$ is manifested in the spectrum of the student's weights, whereas the former is the non-feature-learning regime. We use the quantity $\mathcal{Q}\equiv{\mathbf{w}}^{*\mathsf{T}}\Sigma_{W}{\mathbf{w}}^{*}$ as a surrogate for $\lambda_{m}$, as it is valid on both sides of the transition. Having established the correspondence to the MP plus low-rank model, we can use the results of [6] to find the exact location of the phase transition, which occurs at the critical value $C_{\rm{crit}}$ given by

C_{\rm{crit}}=\frac{4}{S\left(S^{-1}+\left(\sigma^{2}/n\right)S\right)^{4}}\left(1+\left(S^{2}+\frac{n}{\sigma^{2}}\right)^{-1}\right)+O\left(1+\left(\frac{1}{\lambda}+\frac{n}{\sigma^{2}}\right)^{-1}\right) (18)

where we assumed for simplicity $N=S$ so that $\lambda=S^{-2}$.


Figure 2: (A) Aggregated histograms of $\Sigma_{W}$ eigenvalues, where $\Sigma_{W}=\frac{S}{C}WW^{\mathsf{T}}$ is the normalized empirical covariance matrix of the hidden layer weights during training. Different colors indicate varying numbers of channels, $C$. Solid smooth lines indicate the corresponding Marchenko-Pastur (MP) distributions with support on $\left[\lambda_{-},\lambda_{+}\right]$, where $\lambda_{\pm}=\left(1\pm\sqrt{S/C}\right)^{2}$. The quantity $\mathcal{Q}\equiv{\mathbf{w}}^{*\mathsf{T}}\Sigma_{W}{\mathbf{w}}^{*}$, which correlates with the SNR of the feature ${\mathbf{w}}^{*}$, is represented by thick short bars. For large $C$, $\mathcal{Q}$ remains within the MP bulk whereas for small $C$ it pops out. (B) The theoretical $\lambda_{+}$ curve and the interpolated curve of $\mathcal{Q}$ intersect very close to the theoretically predicted value given in Eq. 18, here given by $C_{\rm{crit}}=473$ (dashed vertical line).

IV.5 Two-layer FCN with average pooling and quadratic activations

Another setting where GPs are expected to under-perform finite DNNs is the case of quadratic fully connected teacher and student DNNs where the teacher is rank-1, also known as the phase retrieval problem [26]. Here we consider a positive target of the form $g({\mathbf{x}})=({\mathbf{w}}_{*}\cdot{\mathbf{x}})^{2}-\sigma_{w}^{2}||{\mathbf{x}}||^{2}$, where ${\mathbf{w}}_{*},{\mathbf{x}}\in{\mathbb{R}}^{d}$, and a student DNN given by $f({\mathbf{x}})=\sum_{m=1}^{M}({\mathbf{w}}_{m}\cdot{\mathbf{x}})^{2}-\sigma_{w}^{2}||{\mathbf{x}}||^{2}$ (the $||{\mathbf{x}}||^{2}$ shift is not part of the original model but has only a superficial shift effect, useful for book-keeping later on). We consider training this DNN on $n$ training points $\left\{{\mathbf{x}}_{\mu}\right\}_{\mu=1}^{n}$ using noisy GD training with weight decay $\gamma=2M\sigma^{2}/\sigma_{w}^{2}$.

Similarly to the previous toy model, here too the GP associated with the student at large $M$ (and finite $\sigma^{2}$) overlooks a qualitative feature of the finite DNN — the fact that the first term in $f({\mathbf{x}})$ is non-negative. Interestingly, this feature provides a strong performance boost [26] in the $\sigma^{2}\rightarrow 0$ limit compared to the associated GP. Namely the DNN, even at large $M$, performs well for $n>2d$ [26] whereas the associated GP is expected to work well only for $n=O(d^{2})$ [10].

We wish to solve for the predictions of this model with our self consistent GP based approach. As shown in App. I, the cumulants of this model can be obtained from the following cumulant generating function

\mathcal{C}(t_{1},\dots,t_{n+1})=-\frac{M}{2}\mathrm{Tr}\left(\log\left[I-2M^{-1}\sigma_{w}^{2}\sum_{\mu}it_{\mu}{\mathbf{x}}_{\mu}{\mathbf{x}}_{\mu}^{\mathsf{T}}\right]\right)-\sum_{\mu=1}^{n+1}it_{\mu}\sigma_{w}^{2}||{\mathbf{x}}_{\mu}||^{2} (19)

The associated GP kernel is given by $K({\mathbf{x}}_{\mu},{\mathbf{x}}_{\nu})=2M^{-1}\sigma_{w}^{4}({\mathbf{x}}_{\mu}\cdot{\mathbf{x}}_{\nu})^{2}$. Following this, the target shift equation, at the saddle point level, appears as

\Delta g_{\nu}=-\sum_{\mu}K({\mathbf{x}}_{\nu},{\mathbf{x}}_{\mu})\frac{\hat{\delta}g_{\mu}}{\sigma^{2}}+\sigma_{w}^{2}{\mathbf{x}}_{\nu}^{\mathsf{T}}\left[I-2M^{-1}\sigma_{w}^{2}\sum_{\mu}\frac{\hat{\delta}g_{\mu}}{\sigma^{2}}{\mathbf{x}}_{\mu}{\mathbf{x}}_{\mu}^{\mathsf{T}}\right]^{-1}{\mathbf{x}}_{\nu}-\sigma_{w}^{2}||{\mathbf{x}}_{\nu}||^{2} (20)

In App. I we solve these equations numerically for $\sigma^{2}=10^{-5}$ and show that our approach captures the correct $n=2d$ threshold value. An analytic solution of these equations at low $\sigma^{2}$ using the EK or other continuum approximations is left for future work (see Refs. [10, 7, 8] for potential approaches). As a first step towards this goal, in App. I we consider the simpler case of $\sigma^{2}=1$ and derive the asymptotics of the learning curves, which deviate strongly from those of the GP for $M\ll d$.
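A minimal sketch of how Eq. 20 can be iterated to a fixed point on synthetic data is given below; the damping, iteration count, data sizes and the scale of the teacher vector are arbitrary choices, and at very small $\sigma^2$ a more careful continuation (e.g. gradually lowering $\sigma^2$) may be needed.

```python
import numpy as np

def solve_quadratic_model(X, g, M, sigma_w2, sigma2, n_iter=500, damping=0.1):
    """Damped fixed-point iteration for the target shift of Eq. (20), combined with
    <dhat g>/sigma^2 = K_tilde^{-1}(g - Delta g) on the training set."""
    n, d = X.shape
    K = 2.0 / M * sigma_w2**2 * (X @ X.T)**2           # NNGP kernel of the quadratic student
    K_tilde = K + sigma2 * np.eye(n)
    dg = np.zeros(n)
    x_norm2 = np.einsum('nd,nd->n', X, X)
    for _ in range(n_iter):
        it = np.linalg.solve(K_tilde, g - dg)           # <it_mu> = <dhat g_mu>/sigma^2
        A = np.eye(d) - (2.0 * sigma_w2 / M) * (X.T * it) @ X
        quad = sigma_w2 * np.einsum('nd,nd->n', X @ np.linalg.inv(A), X)
        dg_new = -K @ it + quad - sigma_w2 * x_norm2    # Eq. (20)
        dg = (1 - damping) * dg + damping * dg_new
    return dg, np.linalg.solve(K_tilde, g - dg)

# synthetic phase-retrieval-like data (placeholder sizes and teacher scale)
rng = np.random.default_rng(3)
d, n, M, sigma_w2 = 20, 80, 1000, 1.0
w_star = rng.standard_normal(d) * np.sqrt(sigma_w2 / d)
X = rng.standard_normal((n, d))
g = (X @ w_star)**2 - sigma_w2 * np.einsum('nd,nd->n', X, X)
dg, it = solve_quadratic_model(X, g, M, sigma_w2, sigma2=1.0)
print(np.linalg.norm(dg))
```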

V Discussion

In this work we presented a correspondence between ensembles of finite DNNs trained with noisy gradients and GPs trained on a shifted target. The shift in the target can be found by solving a set of self consistent equations for which we give a general form. We found explicit expressions for these equations for the case of a two-layer linear CNN and a non-linear FCN, and solved them analytically and numerically. For the former model, we performed numerical experiments on CNNs that agree well with our theory both in the GP regime and well away from it, i.e. for a small number of channels $C$, thus accounting for strong finite-$C$ effects. For the latter model, the numerical solution of these equations captures a remarkable and subtle effect in these DNNs which the GP approach completely overlooks — the $n=2d$ threshold value.

Considering feature learning in the CNN model, we found that averaging over ensembles of such networks always leads to a form of feature learning. Namely, the teacher always leaves a signature on the statistics of the student's weights. However, feature learning is usually considered at the level of a single DNN instance rather than an ensemble of DNNs. Focusing on this case, we show numerically that the eigenvalues of $\Sigma_{W}$, the student's hidden weights covariance matrix, follow a Marchenko-Pastur distribution plus a rank-1 perturbation. We then use our approach to derive the critical number of channels $C_{\rm{crit}}$ below which the student is in the feature learning regime.

There are many directions for future research. Our toy models were chosen to be as simple as possible in order to demonstrate the essence of our theory on problems where lazy learning grossly under-performs finite DNNs. Even within this setting, various extensions are interesting to consider, such as adding more features to the teacher CNN (e.g. biases or a subset of linear functions which are more favorable), studying linear CNNs with overlapping convolutional windows, or deeper linear CNNs. As for non-linear CNNs, we believe it is possible to find the exact cumulants of any order for a variety of toy CNNs involving, for example, quadratic activation functions. For other cases it may be useful to develop methods for characterizing and approximating the cumulants.

More generally, we advocated here a physics-style methodology using approximations, self-consistency checks, and experimental tests. As DNNs are very complex experimental systems, we believe this mode of research is both appropriate and necessary. Nonetheless we hope the insights gained by our approach would help generate a richer and more relevant set of toy models on which mathematical proofs could be made.

References

  • Alemohammad et al. [2020] Alemohammad, S., Wang, Z., Balestriero, R., and Baraniuk, R. (2020). The recurrent neural tangent kernel. arXiv preprint arXiv:2006.10246.
  • Andreassen and Dyer [2020] Andreassen, A. and Dyer, E. (2020). Asymptotics of wide convolutional neural networks. arXiv preprint arXiv:2008.08675.
  • Arora et al. [2019] Arora, S., Du, S. S., Hu, W., Li, Z., Salakhutdinov, R., and Wang, R. (2019). On Exact Computation with an Infinitely Wide Neural Net. arXiv e-prints, page arXiv:1904.11955.
  • Baldi and Hornik [1989] Baldi, P. and Hornik, K. (1989). Neural networks and principal component analysis: Learning from examples without local minima. Neural networks, 2(1), 53–58.
  • Benaych-Georges and Nadakuditi [2011] Benaych-Georges, F. and Nadakuditi, R. R. (2011). The eigenvalues and eigenvectors of finite, low rank perturbations of large random matrices. Advances in Mathematics, 227(1), 494–521.
  • Benaych-Georges and Nadakuditi [2012] Benaych-Georges, F. and Nadakuditi, R. R. (2012). The singular values and vectors of low rank perturbations of large rectangular random matrices. Journal of Multivariate Analysis, 111, 120–135.
  • Bordelon et al. [2020] Bordelon, B., Canatar, A., and Pehlevan, C. (2020). Spectrum dependent learning curves in kernel regression and wide neural networks.
  • Canatar et al. [2021] Canatar, A., Bordelon, B., and Pehlevan, C. (2021). Spectral bias and task-model alignment explain generalization in kernel regression and infinitely wide neural networks. Nature Communications, 12(1).
  • Chizat et al. [2019] Chizat, L., Oyallon, E., and Bach, F. (2019). On lazy training in differentiable programming. In Advances in Neural Information Processing Systems, pages 2937–2947.
  • Cohen et al. [2019] Cohen, O., Malka, O., and Ringel, Z. (2019). Learning Curves for Deep Neural Networks: A Gaussian Field Theory Perspective. arXiv e-prints, page arXiv:1906.05301.
  • Daniels [1954] Daniels, H. E. (1954). Saddlepoint Approximations in Statistics. The Annals of Mathematical Statistics, 25(4), 631 – 650.
  • Dyer and Gur-Ari [2020] Dyer, E. and Gur-Ari, G. (2020). Asymptotics of wide networks from feynman diagrams. In International Conference on Learning Representations.
  • Geiger et al. [2020] Geiger, M., Spigler, S., Jacot, A., and Wyart, M. (2020). Disentangling feature and lazy training in deep neural networks. Journal of Statistical Mechanics: Theory and Experiment, 2020(11), 113301.
  • Geiger et al. [2021] Geiger, M., Petrini, L., and Wyart, M. (2021). Landscape and training regimes in deep learning. Physics Reports.
  • Ghorbani et al. [2020] Ghorbani, B., Mei, S., Misiakiewicz, T., and Montanari, A. (2020). When do neural networks outperform kernel methods? arXiv preprint arXiv:2006.13409.
  • Gunasekar et al. [2018] Gunasekar, S., Lee, J., Soudry, D., and Srebro, N. (2018). Implicit bias of gradient descent on linear convolutional networks. arXiv preprint arXiv:1806.00468.
  • Helias and Dahmen [2019] Helias, M. and Dahmen, D. (2019). Statistical field theory for neural networks. arXiv preprint arXiv:1901.10416.
  • Hron et al. [2020] Hron, J., Bahri, Y., Sohl-Dickstein, J., and Novak, R. (2020). Infinite attention: Nngp and ntk for deep attention networks. In International Conference on Machine Learning, pages 4376–4386. PMLR.
  • Jacot et al. [2018] Jacot, A., Gabriel, F., and Hongler, C. (2018). Neural Tangent Kernel: Convergence and Generalization in Neural Networks. arXiv e-prints, page arXiv:1806.07572.
  • Lampinen and Ganguli [2018] Lampinen, A. K. and Ganguli, S. (2018). An analytic theory of generalization dynamics and transfer learning in deep linear networks. arXiv preprint arXiv:1809.10374.
  • Lee et al. [2018] Lee, J., Sohl-dickstein, J., Pennington, J., Novak, R., Schoenholz, S., and Bahri, Y. (2018). Deep neural networks as gaussian processes. In International Conference on Learning Representations.
  • Lee et al. [2019] Lee, J., Xiao, L., Schoenholz, S. S., Bahri, Y., Novak, R., Sohl-Dickstein, J., and Pennington, J. (2019). Wide neural networks of any depth evolve as linear models under gradient descent. arXiv preprint arXiv:1902.06720.
  • Lee et al. [2020] Lee, J., Schoenholz, S. S., Pennington, J., Adlam, B., Xiao, L., Novak, R., and Sohl-Dickstein, J. (2020). Finite versus infinite neural networks: an empirical study. arXiv preprint arXiv:2007.15801.
  • Li and Sompolinsky [2020] Li, Q. and Sompolinsky, H. (2020). Statistical mechanics of deep linear neural networks: The back-propagating renormalization group. arXiv preprint arXiv:2012.04030.
  • Malach et al. [2021] Malach, E., Kamath, P., Abbe, E., and Srebro, N. (2021). Quantifying the benefit of using differentiable learning over tangent kernels. arXiv preprint arXiv:2103.01210.
  • Mannelli et al. [2020] Mannelli, S. S., Vanden-Eijnden, E., and Zdeborová, L. (2020). Optimization and generalization of shallow neural networks with quadratic activation functions. arXiv preprint arXiv:2006.15459.
  • Martin and Mahoney [2018] Martin, C. H. and Mahoney, M. W. (2018). Implicit Self-Regularization in Deep Neural Networks: Evidence from Random Matrix Theory and Implications for Learning. arXiv e-prints, page arXiv:1810.01075.
  • Matthews et al. [2018] Matthews, A. G. d. G., Rowland, M., Hron, J., Turner, R. E., and Ghahramani, Z. (2018). Gaussian process behaviour in wide deep neural networks. arXiv preprint arXiv:1804.11271.
  • Mccullagh [2017] Mccullagh, P. (2017). Tensor Methods in Statistics. Dover Books on Mathematics.
  • Naveh et al. [2020] Naveh, G., Ben-David, O., Sompolinsky, H., and Ringel, Z. (2020). Predicting the outputs of finite networks trained with noisy gradients. arXiv preprint arXiv:2004.01190.
  • Neal [1996] Neal, R. M. (1996). Priors for infinite networks. In Bayesian Learning for Neural Networks, pages 29–53. Springer.
  • Note1. Generically this requires $C$-dependent and layer-dependent weight decay.
  • Note2. The $||{\mathbf{x}}||^{2}$ shift is not part of the original model but has only a superficial shift effect useful for book-keeping later on.
  • Novak et al. [2018] Novak, R., Xiao, L., Lee, J., Bahri, Y., Yang, G., Abolafia, D. A., Pennington, J., and Sohl-Dickstein, J. (2018). Bayesian Deep Convolutional Networks with Many Channels are Gaussian Processes. arXiv e-prints, page arXiv:1810.05148.
  • Rasmussen and Williams [2005] Rasmussen, C. E. and Williams, C. K. I. (2005). Gaussian Processes for Machine Learning (Adaptive Computation and Machine Learning). The MIT Press.
  • Refinetti et al. [2021] Refinetti, M., Goldt, S., Krzakala, F., and Zdeborová, L. (2021). Classifying high-dimensional gaussian mixtures: Where kernel methods fail and neural networks succeed. arXiv preprint arXiv:2102.11742.
  • Saxe et al. [2013] Saxe, A. M., McClelland, J. L., and Ganguli, S. (2013). Exact solutions to the nonlinear dynamics of learning in deep linear neural networks. arXiv preprint arXiv:1312.6120.
  • Sollich and Williams [2004] Sollich, P. and Williams, C. K. (2004). Understanding gaussian process regression using the equivalent kernel. In International Workshop on Deterministic and Statistical Methods in Machine Learning, pages 211–228. Springer.
  • Yaida [2020] Yaida, S. (2020). Non-gaussian processes and neural networks at finite widths. In Mathematical and Scientific Machine Learning, pages 165–192. PMLR.
  • Yang and Hu [2020] Yang, G. and Hu, E. J. (2020). Feature learning in infinite-width neural networks. arXiv preprint arXiv:2011.14522.
  • Yosinski et al. [2014] Yosinski, J., Clune, J., Bengio, Y., and Lipson, H. (2014). How transferable are features in deep neural networks? In Z. Ghahramani, M. Welling, C. Cortes, N. Lawrence, and K. Q. Weinberger, editors, Advances in Neural Information Processing Systems, volume 27. Curran Associates, Inc.
  • Yu et al. [2013] Yu, D., Seltzer, M., Li, J., Huang, J.-T., and Seide, F. (2013). Feature learning in deep neural networks - studies on speech recognition. In International Conference on Learning Representations.
  • Zavatone-Veth and Pehlevan [2021] Zavatone-Veth, J. A. and Pehlevan, C. (2021). Exact priors of finite neural networks. arXiv preprint arXiv:2104.11734.

Appendix A Derivation of the target shift equations

A.1 The partition function in terms of dual variables

Consider the general setting of Bayesian inference with Gaussian measurement noise (or equivalently a DNN trained with MSE loss, weight decay, and white noise added to the gradients). Let $\left\{{\mathbf{x}}_{\mu},g_{\mu}\right\}_{\mu=1}^{n}$ denote the inputs and targets on the training set and let ${\mathbf{x}}_{*}\equiv{\mathbf{x}}_{n+1}$ be the test point. Denote the prior (or equivalently the equilibrium distribution of a DNN trained with no data) by $P_{0}(\vec{f})$, where $f_{\mu}=f({\mathbf{x}}_{\mu})\in\mathbb{R}$ is the output of the model and $\vec{f}\equiv\left(f_{1},\dots,f_{n},f_{n+1}\right)\in\mathbb{R}^{n+1}$. The model's predictions (or equivalently the ensemble-averaged DNN output) on the point ${\mathbf{x}}_{\mu}$ can be obtained by

\forall\mu\in\left\{1,\dots,n+1\right\}:\qquad\left\langle f_{\mu}\right\rangle=\left.\partial_{J_{\mu}}\log Z\left(\vec{J}\right)\right|_{\vec{J}=\vec{0}} (A.1)

with the following partition function

Z\left(\vec{J}\right)=\int d\vec{f}\,P_{0}\left(\vec{f}\right)\exp\left(-\frac{1}{2\sigma^{2}}\sum_{\mu=1}^{n}\left(f_{\mu}-g_{\mu}\right)^{2}+\sum_{\mu}J_{\mu}f_{\mu}\right) (A.2)

where, unless explicitly written otherwise, summations over $\mu$ run from $1$ to $n+1$ (i.e. include the test point). Here we commit to the MSE loss, which facilitates the derivation; in App. C we give an alternative derivation that may also be applied to other losses such as the cross-entropy. Our goal in this appendix is to establish that the target shift equations are in fact saddle point equations of the partition function A.2, following some transformations on the variables of integration. To this end, consider the cumulant generating function of $P_{0}\left(\vec{f}\right)$ given by

\mathcal{C}\left(\vec{t}\right)=\log\left(\int_{-\infty}^{\infty}d\vec{f}\,e^{i\sum_{\mu}t_{\mu}f_{\mu}}P_{0}\left(\vec{f}\right)\right) (A.3)

or expressed via the cumulant tensors:

\mathcal{C}\left(\vec{t}\right)=\sum_{r=2}^{\infty}\frac{1}{r!}\sum_{\mu_{1},\dots,\mu_{r}=1}^{n+1}\kappa_{\mu_{1},\dots,\mu_{r}}it_{\mu_{1}}\cdots it_{\mu_{r}} (A.4)

where the sum over the cumulant tensors $\kappa_{\mu_{1},\dots,\mu_{r}}$ does not include $r=1$ since our DNN priors are assumed to have zero mean. Notably, one can re-express $P_{0}\left(\vec{f}\right)$ as the inverse Fourier transform of $e^{\mathcal{C}(\vec{t})}$:

P_{0}\left(\vec{f}\right)\propto\int d\vec{t}\,\exp\left(-i\sum_{\mu}t_{\mu}f_{\mu}+\mathcal{C}\left(\vec{t}\right)\right) (A.5)

Plugging this in Eq. A.2 we obtain

Z\left(\vec{J}\right)\propto\iint d\vec{f}\,d\vec{t}\,\exp\left(-\frac{1}{2\sigma^{2}}\sum_{\mu=1}^{n}\left(f_{\mu}-g_{\mu}\right)^{2}+\sum_{\mu}\left(J_{\mu}-it_{\mu}\right)f_{\mu}+\mathcal{C}\left(\vec{t}\right)\right) (A.6)

where for clarity we do not keep track of multiplicative $\pi$ factors that have no effect on the moments of $f_{\mu}$. As the term in the exponent (the action) is quadratic in $\left\{f_{\mu}\right\}_{\mu=1}^{n}$ and linear in $f_{n+1}$, these can be integrated out to yield an equivalent partition function phrased solely in terms of $t_{1},\dots,t_{n+1}$:

Z\left(\vec{J}\right)\propto\int dt_{1}\cdots dt_{n}\,e^{-\mathcal{S}\left(\vec{t},\vec{J}\right)} (A.7)

where the action is now

\mathcal{S}=-\mathcal{C}\left(t_{1},\dots,t_{n},-iJ_{n+1}\right)+\sum_{\mu=1}^{n}\left[\frac{\sigma^{2}}{2}t_{\mu}^{2}+it_{\mu}g_{\mu}+J_{\mu}\left(i\sigma^{2}t_{\mu}-g_{\mu}\right)-\frac{\sigma^{2}}{2}J_{\mu}^{2}\right] (A.8)

The identification $t_{n+1}=-iJ_{n+1}$ arises from the delta function:

\frac{1}{2\pi}\int_{-\infty}^{\infty}df_{n+1}\,e^{-if_{n+1}\left(iJ_{n+1}+t_{n+1}\right)}=\delta\left(iJ_{n+1}+t_{n+1}\right) (A.9)

Recall that $\left\langle f_{\mu}\right\rangle=\left.\partial_{J_{\mu}}\log Z\left(\vec{J}\right)\right|_{\vec{J}=\vec{0}}$ and notice that the first term in Eq. A.8 (the cumulant generating function, $\mathcal{C}$) depends on $J_{n+1}$ and not on $\left\{J_{\mu}\right\}_{\mu=1}^{n}$, whereas the rest of the action depends on $\left\{J_{\mu}\right\}_{\mu=1}^{n}$ and not on $J_{n+1}$. Thus, for training points $\left\langle f_{\mu}\right\rangle$ amounts to the average of $g_{\mu}-i\sigma^{2}t_{\mu}$, and so we identify

\forall\mu\in\left\{1,\dots,n\right\}:\qquad\langle it_{\mu}\rangle=\frac{g_{\mu}-\langle f_{\mu}\rangle}{\sigma^{2}}\equiv\frac{\langle\hat{\delta}g_{\mu}\rangle}{\sigma^{2}} (A.10)

where $\langle\cdots\rangle$ denotes an expectation value using $Z(\vec{J}=\vec{0})$. We comment that the above relation holds also for any (non-mixed) cumulants of $it_{\mu}$ and $\hat{\delta}g_{\mu}/\sigma^{2}$ except the covariance, where a constant difference appears due to the $O(J^{2})$ term in the action, namely

itμitνitμitν=1σ4(δ^gμδ^gνδ^gμδ^gν)1σ2δμν\displaystyle\left\langle it_{\mu}it_{\nu}\right\rangle-\left\langle it_{\mu}\right\rangle\left\langle it_{\nu}\right\rangle=\frac{1}{\sigma^{4}}\left(\left\langle\hat{\delta}g_{\mu}\hat{\delta}g_{\nu}\right\rangle-\left\langle\hat{\delta}g_{\mu}\right\rangle\left\langle\hat{\delta}g_{\nu}\right\rangle\right)-\frac{1}{\sigma^{2}}\delta_{\mu\nu} (A.11)

In the GP case the r.h.s. of Eq. A.11 would equal simply K~μν1-\tilde{K}^{-1}_{\mu\nu}, since on the training set the posterior covariance of δ^g\hat{\delta}g is the same as that of ff and for a GP takes the form

Σ\displaystyle\Sigma =KKK~1K\displaystyle=K-K\tilde{K}^{-1}K (A.12)
=K(K+σ2Iσ2I)K~1K\displaystyle=K-\left(K+\sigma^{2}I-\sigma^{2}I\right)\tilde{K}^{-1}K
=σ2K~1K=σ2K~1(K+σ2Iσ2I)\displaystyle=\sigma^{2}\tilde{K}^{-1}K=\sigma^{2}\tilde{K}^{-1}\left(K+\sigma^{2}I-\sigma^{2}I\right)
=σ2Iσ4K~1\displaystyle=\sigma^{2}I-\sigma^{4}\tilde{K}^{-1}

namely Σμν=σ2δμνσ4K~μν1\Sigma_{\mu\nu}=\sigma^{2}\delta_{\mu\nu}-\sigma^{4}\tilde{K}_{\mu\nu}^{-1} and thus 1σ4Σμν1σ2δμν=K~μν1\frac{1}{\sigma^{4}}\Sigma_{\mu\nu}-\frac{1}{\sigma^{2}}\delta_{\mu\nu}=-\tilde{K}_{\mu\nu}^{-1}. The reader should not be alarmed by having a negative definite covariance matrix for itμit_{\mu}, since itμit_{\mu} cannot be understood as a standard real random variable as its partition function contains imaginary terms.
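
As a quick numerical sanity check of this identity, the sketch below verifies that K−KK̃⁻¹K and σ²I−σ⁴K̃⁻¹ coincide; the random positive semi-definite kernel and the noise level are arbitrary illustrative choices.

```python
import numpy as np

rng = np.random.default_rng(0)

n, sigma2 = 20, 0.1                            # arbitrary illustrative values
A = rng.standard_normal((n, n))
K = A @ A.T / n                                # a random PSD stand-in for the NNGP kernel
K_tilde_inv = np.linalg.inv(K + sigma2 * np.eye(n))

Sigma_direct = K - K @ K_tilde_inv @ K                            # first line of Eq. A.12
Sigma_identity = sigma2 * np.eye(n) - sigma2**2 * K_tilde_inv     # last line of Eq. A.12
print(np.allclose(Sigma_direct, Sigma_identity))                  # True
```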

To make contact with GPs it is beneficial to expand 𝒞(t1,,tn,iJn+1)\mathcal{C}(t_{1},\dots,t_{n},-iJ_{n+1}) in terms of its cumulants, and split the second cumulant, describing the DNNs’ NNGP kernel, from the rest. Namely, using Einstein summation

𝒞(t1,,tn,iJn+1)\displaystyle\mathcal{C}\left(t_{1},\dots,t_{n},-iJ_{n+1}\right) =12!κμ1,μ2itμ1itμ2+13!κμ1,μ2,μ3itμ1itμ2itμ3+O((it)4)\displaystyle=\frac{1}{2!}\kappa_{\mu_{1},\mu_{2}}it_{\mu_{1}}it_{\mu_{2}}+\frac{1}{3!}\kappa_{\mu_{1},\mu_{2},\mu_{3}}it_{\mu_{1}}it_{\mu_{2}}it_{\mu_{3}}+O((it)^{4}) (A.13)
κμ1,μ2\displaystyle\kappa_{\mu_{1},\mu_{2}} K(𝐱μ1,𝐱μ2)\displaystyle\equiv K({\mathbf{x}}_{\mu_{1}},{\mathbf{x}}_{\mu_{2}})
𝒞~(t1,,tn,iJn+1)\displaystyle\tilde{\mathcal{C}}\left(t_{1},\dots,t_{n},-iJ_{n+1}\right) 𝒞(t1,,tn,iJn+1)+12κμ1,μ2tμ1tμ2\displaystyle\equiv\mathcal{C}\left(t_{1},\dots,t_{n},-iJ_{n+1}\right)+\frac{1}{2}\kappa_{\mu_{1},\mu_{2}}t_{\mu_{1}}t_{\mu_{2}}

Writing Eq. A.8 in this fashion gives the action:

𝒮=𝒞~(t)12μ1,μ2κμ1,μ2itμ1itμ2+μ=1n[σ22(itμ)2+itμgμ+Jμ(iσ2tμgμ)σ22Jμ2]\displaystyle\mathcal{S}=-\tilde{\mathcal{C}}\left(\vec{t}\right)-\frac{1}{2}\sum_{\mu_{1},\mu_{2}}\kappa_{\mu_{1},\mu_{2}}it_{\mu_{1}}it_{\mu_{2}}+\sum_{\mu=1}^{n}\left[-\frac{\sigma^{2}}{2}\left(it_{\mu}\right)^{2}+it_{\mu}g_{\mu}+J_{\mu}\left(i\sigma^{2}t_{\mu}-g_{\mu}\right)-\frac{\sigma^{2}}{2}J_{\mu}^{2}\right] (A.14)

A.2 Saddle point equation for the mean predictor

Having arrived at the action A.14, we can readily derive the saddle point equations for the training points by setting:

ν{1,,n}:itν𝒮(t,J)|J=0=0\forall\nu\in\left\{1,\dots,n\right\}:\qquad\evaluated{\partial_{it_{\nu}}\mathcal{S}\left(\vec{t},\vec{J}\right)}_{\vec{J}=\vec{0}}=0 (A.15)

This corresponds to treating the variables {itμ}μ=1n\left\{it_{\mu}\right\}_{\mu=1}^{n} as non-fluctuating quantities, i.e. replacing them with their mean value: itμitμit_{\mu}\to\left\langle it_{\mu}\right\rangle. Performing this for the training set ν{1,,n}\nu\in\left\{1,\dots,n\right\} yields

μ=1n(κμ,ν+σ2δμν)itμ=gνΔgν\sum_{\mu=1}^{n}\left(\kappa_{\mu,\nu}+\sigma^{2}\delta_{\mu\nu}\right)\left\langle it_{\mu}\right\rangle=g_{\nu}-\Delta g_{\nu} (A.16)

where

Δgν=r=31(r1)!μ1,,μr1=1nκν,μ1,,μr1itμ1itμr1\Delta g_{\nu}=\sum_{r=3}^{\infty}\frac{1}{\left(r-1\right)!}\sum_{\mu_{1},\dots,\mu_{r-1}=1}^{n}\kappa_{\nu,\mu_{1},\dots,\mu_{r-1}}\left\langle it_{\mu_{1}}\right\rangle\cdots\left\langle it_{\mu_{r-1}}\right\rangle (A.17)

where this target shift is related to 𝒞~\tilde{\mathcal{C}} of Eq. A.13 by

Δgν=itν𝒞~(t1,,tn,tn+1)\Delta g_{\nu}=\partial_{it_{\nu}}\tilde{\mathcal{C}}\left(t_{1},\dots,t_{n},t_{n+1}\right) (A.18)

Finally, we get the expression for the mean predictor at the test point by setting f=JlogZ(J)|J=0\left\langle f_{*}\right\rangle=\evaluated{\partial_{J_{*}}\log Z\left(\vec{J}\right)}_{\vec{J}=\vec{0}} and plugging in the SP values for itμit_{\mu} on the training set from Eq. A.16 and the target shift Δg\Delta g from Eq. A.17. This gives

f=Δg+μ,νnKμK~μν1(gνΔgν)\left\langle f_{*}\right\rangle=\Delta g_{*}+\sum^{n}_{\mu,\nu}K_{\mu}^{*}\tilde{K}_{\mu\nu}^{-1}\left(g_{\nu}-\Delta g_{\nu}\right) (A.19)
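
To illustrate the structure of Eqs. A.16, A.17 and A.19, the sketch below iterates the saddle point equations to a fixed point. The kernel, targets and fourth cumulant are synthetic stand-ins, the series in Eq. A.17 is truncated at its first (r=4) term, and the helper delta_g is ours; this is only a structural illustration, not the cumulants of an actual DNN.

```python
import numpy as np

rng = np.random.default_rng(1)

# Synthetic stand-ins for the kernel, the targets and the fourth cumulant.
n, sigma2 = 8, 0.5
X = rng.standard_normal((n + 1, 3))                  # last row plays the role of the test point
K_full = np.exp(-0.5 * ((X[:, None] - X[None]) ** 2).sum(-1))
K = K_full[:n, :n]
K_star = K_full[n, :n]
K_tilde_inv = np.linalg.inv(K + sigma2 * np.eye(n))
g = np.sin(X[:n].sum(-1))                            # synthetic train targets

# Stand-in for kappa_{nu, mu_1, mu_2, mu_3}; kept small so the fixed point is well behaved.
kappa4 = 1e-3 * rng.standard_normal((n + 1, n, n, n))

def delta_g(it):
    """Target shift of Eq. A.17, truncated at the r = 4 term (prefactor 1/(4-1)! = 1/6)."""
    return np.einsum('vabc,a,b,c->v', kappa4, it, it, it) / 6.0

# Fixed-point iteration of Eq. A.16: <it> = K_tilde^{-1} (g - Delta g) on the training set.
it = K_tilde_inv @ g
for _ in range(100):
    it = K_tilde_inv @ (g - delta_g(it)[:n])

dg = delta_g(it)
f_star = dg[n] + K_star @ K_tilde_inv @ (g - dg[:n])  # Eq. A.19 for the test point
print(f_star)
```

For weak non-Gaussianity (small higher cumulants) the plain iteration converges in a few steps; for stronger finite-width effects a damped update or a root finder may be preferable.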

A.3 Posterior covariance

A.3.1 Posterior covariance on the test point

The posterior covariance on the test point is important for determining the average MSE loss on the test-set, as the latter involves the MSE of the mean-predictor plus the posterior covariance. Concretely, we wish to calculate J2log(Z)\partial^{2}_{J_{*}}\log(Z), express it as an expectation value w.r.t ZZ and calculate this using ZZ with the self-consistent target shift. To this end, note that generally if Z(J)=e𝒮(J)Z(J)=\int e^{-\mathcal{S}(J)} then

J2log(Z(J))=𝒮(0)2Z(J=0)𝒮(0)Z(J=0)2𝒮′′(0)Z(J=0)\partial^{2}_{J}\log(Z(J))=\langle\mathcal{S}^{\prime}(0)^{2}\rangle_{Z(J=0)}-\langle\mathcal{S}^{\prime}(0)\rangle_{Z(J=0)}^{2}-\langle\mathcal{S}^{\prime\prime}(0)\rangle_{Z(J=0)} (A.20)

For the action in Eq. A.14 we have

𝒮(0)=κνitνit𝒞~Δg𝒮′′(0)=κit2𝒞~\mathcal{S}^{\prime}(0)=-\kappa_{*\nu}it_{\nu}-\underbrace{\partial_{it_{*}}\tilde{\mathcal{C}}}_{\Delta g_{*}}\qquad\mathcal{S}^{\prime\prime}(0)=-\kappa_{**}-\partial_{it_{*}}^{2}\tilde{\mathcal{C}} (A.21)

where an Einstein summation over ν=1,,n\nu=1,\dots,n is implicit. Further recalling that on the training set we have

itμ\displaystyle\langle it_{\mu}\rangle =νK~μν1(gνΔgν)\displaystyle=\sum_{\nu}\tilde{K}^{-1}_{\mu\nu}(g_{\nu}-\Delta g_{\nu}) (A.22)
itμitνitμitν\displaystyle\langle it_{\mu}it_{\nu}\rangle-\langle it_{\mu}\rangle\langle it_{\nu}\rangle =K~μ,ν1\displaystyle=-\tilde{K}^{-1}_{\mu,\nu} (A.23)

where here Δg\Delta g is the full quantity without any SP approximations

Δgν=r=31(r1)!μ1,,μr1=1nκν,μ1,,μr1itμ1itμr1\Delta g_{\nu}=\sum_{r=3}^{\infty}\frac{1}{\left(r-1\right)!}\sum_{\mu_{1},\dots,\mu_{r-1}=1}^{n}\kappa_{\nu,\mu_{1},\dots,\mu_{r-1}}it_{\mu_{1}}\cdots it_{\mu_{r-1}} (A.24)

We obtain

Σ\displaystyle\Sigma_{**} =K+it2𝒞~(it1,,itn,it)|t=0+(κνitν+Δg)2κνitν+Δg2\displaystyle=K_{**}+\left\langle\partial_{it_{*}}^{2}\tilde{\mathcal{C}}\left(it_{1},\dots,it_{n},it_{*}\right)|_{t_{*}=0}\right\rangle+\left\langle\left(\kappa_{*\nu}it_{\nu}+\Delta g_{*}\right)^{2}\right\rangle-\left\langle\kappa_{*\nu}it_{\nu}+\Delta g_{*}\right\rangle^{2} (A.25)

where we can unpack the last two terms as

(κνitν+Δg)2κνitν+Δg2\displaystyle\left\langle\left(\kappa_{*\nu}it_{\nu}+\Delta g_{*}\right)^{2}\right\rangle-\left\langle\kappa_{*\nu}it_{\nu}+\Delta g_{*}\right\rangle^{2} (A.26)
=κμκν(itμitνitμitν)+2(ΔgκμitμΔgκμitμ)+(Δg2Δg2)\displaystyle=\kappa_{*\mu}\kappa_{*\nu}\left(\left\langle it_{\mu}it_{\nu}\right\rangle-\left\langle it_{\mu}\right\rangle\left\langle it_{\nu}\right\rangle\right)+2\left(\left\langle\Delta g_{*}\kappa_{*\mu}it_{\mu}\right\rangle-\left\langle\Delta g_{*}\right\rangle\left\langle\kappa_{*\mu}it_{\mu}\right\rangle\right)+\left(\left\langle\Delta g_{*}^{2}\right\rangle-\left\langle\Delta g_{*}\right\rangle^{2}\right)

One can verify that for the GP case where 𝒞~=0\tilde{\mathcal{C}}=0, Eq. A.25 simplifies to

Σ=KKμK~μ,ν1Kν\Sigma_{**}=K_{**}-K_{\mu}^{*}\tilde{K}_{\mu,\nu}^{-1}K_{\nu}^{*} (A.27)

which is the standard posterior covariance of a GP [35].

The expressions in Eqs. A.25, A.26 are exact, but to evaluate them in a more compact form we approximate the tμt_{\mu} distribution as a Gaussian centered around the SP value.

A.3.2 Posterior covariance on the training set

Our target shift approach at the saddle-point level allows a computation of the fluctuations of itμit_{\mu} using the standard procedure of expanding the action at the saddle point to quadratic order in tμt_{\mu}. Due to the saddle-point being an extremum this leads to 𝒮𝒮saddle+12tμAμν1tν\mathcal{S}\approx\mathcal{S}_{\rm{saddle}}+\frac{1}{2}t_{\mu}A^{-1}_{\mu\nu}t_{\nu} and thus using the standard Gaussian integration formula, one finds that AμνA_{\mu\nu} is the covariance matrix of itμit_{\mu}. Performing such an expansion on the action of Eq. A.14 one finds

Aμν1\displaystyle A^{-1}_{\mu\nu} =(σ2δμν+Kμν+ΔKμν)\displaystyle=-\left(\sigma^{2}\delta_{\mu\nu}+K_{\mu\nu}+\Delta K_{\mu\nu}\right) (A.28)
ΔKμν\displaystyle\Delta K_{\mu\nu} =itμitν𝒞~(it1,,itn)Δgν\displaystyle=\partial_{it_{\mu}}\underbrace{\partial_{it_{\nu}}\tilde{\mathcal{C}}\left(it_{1},\dots,it_{n}\right)}_{\Delta g_{\nu}}

where the itμit_{\mu} on the r.h.s. are those of the saddle-point. Recalling Eq. A.11 we have

Σμν=fμfνfμfν=σ4[σ2I+K+ΔK]μν1+σ2δμν\displaystyle\Sigma_{\mu\nu}=\left\langle f_{\mu}f_{\nu}\right\rangle-\left\langle f_{\mu}\right\rangle\left\langle f_{\nu}\right\rangle=-\sigma^{4}\left[\sigma^{2}I+K+\Delta K\right]_{\mu\nu}^{-1}+\sigma^{2}\delta_{\mu\nu} (A.29)

and the r.h.s. coincides with the posterior covariance of a GP with a kernel equal to K+ΔKK+\Delta K.

A.4 A criterion for the saddle-point regime

Saddle point approximations are commonly used in statistics [11] and physics and often rely on having partition functions of the form Z=𝑑ten𝒮(t)Z=\int dte^{-n\mathcal{S}(t)} where nn is a large number and 𝒮\mathcal{S} is order 11 (O(1)O(1)). In our setting we cannot simply extract such a large factor from the action and make it O(1)O(1). Nonetheless, we argue that expanding the action to quadratic order around the saddle point is still a good approximation at large nn, with nn being the training set size. Concretely, we give the following two consistency criteria, based on comparing the saddle point results with their leading-order beyond-saddle-point corrections. The first compares this leading correction to the mean predictor with the scale of the saddle point prediction

12[δ^gνσ2δ^gησ2Δgμ]K~μ0μ1K~νη1O(g)\displaystyle\frac{1}{2}\left[\partial_{\frac{\hat{\delta}g_{\nu}}{\sigma^{2}}}\partial_{\frac{\hat{\delta}g_{\eta}}{\sigma^{2}}}\Delta g_{\mu}\right]\tilde{K}^{-1}_{\mu_{0}\mu}\tilde{K}^{-1}_{\nu\eta}\ll O(g) (A.30)

where an Einstein summation over the training-set is implicit and the derivatives are evaluated at the saddle point value. This criterion can be calculated for any specific model to verify the appropriateness of the saddle point approach. We further provide a simpler criterion

n(δ^gσ2)21\displaystyle n\left(\frac{\hat{\delta}g}{\sigma^{2}}\right)^{2}\gg 1 (A.31)

which however relies on heuristic assumptions. The main purpose of this heuristic criterion is to provide a qualitative explanation for why we expect the first criterion to be small in many interesting large nn settings.

To this end we first obtain the leading (beyond quadratic) correction to the mean. Consider the partition function in terms of itit and its expansion around the saddle point. As P0[f]P_{0}[f] is effectively bounded (by the Gaussian tails of the finite set of weights), the corresponding characteristic function (e𝒞(t1,,tn)e^{\mathcal{C}(t_{1},\dots,t_{n})}) is well defined over the entire complex plane. Given this, one can deform the integration contour, along each dimension (𝑑tμ\int_{-\infty}^{\infty}dt_{\mu}), which originally lay on the real axis, to +tSP,μ+tSP,μ𝑑tμ\int_{-\infty+t_{\rm{SP},\mu}}^{\infty+t_{\rm{SP},\mu}}dt_{\mu} where tSPt_{\rm{SP}} is purely imaginary and equals iδ^g/σ2-i\hat{\delta}g/\sigma^{2} so that it crosses the saddle point (see also Ref. [11]). Next we expand the action in the deviation from the SP value: δtμ=tμtSP,μ\delta t_{\mu}=t_{\mu}-t_{\rm{SP},\mu} to obtain

Z=δt1δtnexp(𝒮012!δtμ𝒮μνδtν13!𝒮μνηδtμδtνδtη+O(δt4))\displaystyle Z=\int_{-\infty}^{\infty}\delta t_{1}\cdots\delta t_{n}\exp\left(-\mathcal{S}_{0}-\frac{1}{2!}\delta t_{\mu}\mathcal{S}_{\mu\nu}\delta t_{\nu}-\frac{1}{3!}\mathcal{S}_{\mu\nu\eta}\delta t_{\mu}\delta t_{\nu}\delta t_{\eta}+O\left(\delta t^{4}\right)\right) (A.32)

where an Einstein summation over the training set is implicit and where we denoted for m3m\geq 3

𝒮μ1μmtμ1tμm𝒮|t=tSP=tμ1tμm𝒞~|t=tSP=itμ1tμm1Δgμm|t=tSP\displaystyle\mathcal{S}_{\mu_{1}\dots\mu_{m}}\equiv\evaluated{\partial_{t_{\mu_{1}}}\cdots\partial_{t_{\mu_{m}}}\mathcal{S}}_{\vec{t}=\vec{t}_{\rm{SP}}}=-\evaluated{\partial_{t_{\mu_{1}}}\cdots\partial_{t_{\mu_{m}}}\tilde{\mathcal{C}}}_{\vec{t}=\vec{t}_{\rm{SP}}}=i\evaluated{\partial_{t_{\mu_{1}}}\cdots\partial_{t_{\mu_{m-1}}}\Delta g_{\mu_{m}}}_{\vec{t}=\vec{t}_{\rm{SP}}} (A.33)

Next we consider first order perturbation theory in the cubic term and calculate the correction to the mean of itμit_{\mu} or equivalently δ^gμ/σ2\hat{\delta}g_{\mu}/\sigma^{2}.

itμ0\displaystyle\langle it_{\mu_{0}}\rangle itμ0SP16𝒮μνηiδtμ0δtμδtνδtηSP,connected\displaystyle\approx\langle it_{\mu_{0}}\rangle_{\rm{SP}}-\frac{1}{6}\mathcal{S}_{\mu\nu\eta}\langle i\delta t_{\mu_{0}}\delta t_{\mu}\delta t_{\nu}\delta t_{\eta}\rangle_{\rm{SP,connected}} (A.34)
=δ^gμ0σ2i2𝒮μνηK~μ0μ1K~νη1\displaystyle=\frac{\hat{\delta}g_{\mu_{0}}}{\sigma^{2}}-\frac{i}{2}\mathcal{S}_{\mu\nu\eta}\tilde{K}^{-1}_{\mu_{0}\mu}\tilde{K}^{-1}_{\nu\eta}
=δ^gμ0σ2+12itνitηΔgμ|t=tSPK~μ0μ1K~νη1\displaystyle=\frac{\hat{\delta}g_{\mu_{0}}}{\sigma^{2}}+\frac{1}{2}\partial_{it_{\nu}}\partial_{it_{\eta}}\Delta g_{\mu}|_{\vec{t}=\vec{t}_{\rm{SP}}}\tilde{K}^{-1}_{\mu_{0}\mu}\tilde{K}^{-1}_{\nu\eta}

where here the kernel is shifted: K~=σ2I+K+ΔK\tilde{K}=\sigma^{2}I+K+\Delta K, as in Eq. A.28 and SP,connected\langle...\rangle_{\rm{SP,connected}} means keeping terms in Wick’s theorem which connect the operator being averaged (iδtμ0i\delta t_{\mu_{0}}) with the perturbation, as standard in perturbation theory. Comparing the last term on the right hand side (i.e. the correction) with the predictions which are O(g)O(g) gives the first criterion, Eq. A.30. Depending on context, it may be more appropriate to compare this term with the discrepancy rather than the prediction.

Next we turn to study the scaling of this correction with nn. To this end we first consider a single derivative of Δg\Delta g (itνΔgμ\partial_{it_{\nu}}\Delta g_{\mu}). Note that Δgμ\Delta g_{\mu}, by its definition, includes contributions from at least n3n^{3} different itit’s. In many cases, one expects that the value of this sum will be dominated by some finite fraction of the training set rather than by a vanishing fraction. This assumption is in fact implicit in our EK treatment where we replaced all μ\sum_{\mu} with nn\int. Given so, the derivative itνΔgμ\partial_{it_{\nu}}\Delta g_{\mu}, which can be viewed as the sensitivity to changing itνit_{\nu}, is expected to go as one over the size of that fraction of the training set, namely as 1/n1/n. Under this collectivity assumption we expect the scaling

itνΔgμ=O(it1Δgn1)\displaystyle\partial_{it_{\nu}}\Delta g_{\mu}=O(it^{-1}\Delta gn^{-1}) (A.35)

Making a similar collectivity assumption on higher derivatives yields

itηitνΔgμ=O(it2Δgn2)\displaystyle\partial_{it_{\eta}}\partial_{it_{\nu}}\Delta g_{\mu}=O(it^{-2}\Delta gn^{-2}) (A.36)

Following this we count powers of nn in Eq. A.34 and find an n2n^{-2} contribution from the second derivative of Δg\Delta g and a contribution from the summation over ην\sum_{\eta\nu}. Despite containing two summations, we argue that the latter is in fact order nn. To this end consider n2νηΔgμn^{2}\partial_{\nu}\partial_{\eta}\Delta g_{\mu} for fixed ν,η\nu,\eta as an effective target function (Gμ(ν,η)G_{\mu}(\nu,\eta)), where we multiplied by the scaling of the second derivative to make GG order 1. The above summation then appears as η,νK~ην1Gμ(ν,η)\sum_{\eta,\nu}\tilde{K}^{-1}_{\eta\nu}G_{\mu}(\nu,\eta). Next we recall that itμ0=K~μ0μ1gμ=δ^gμ0/σ2it_{\mu_{0}}=\tilde{K}^{-1}_{\mu_{0}\mu}g_{\mu}=\hat{\delta}g_{\mu_{0}}/\sigma^{2}, and so multiplication of a vector with K~1\tilde{K}^{-1} can be interpreted as the discrepancy w.r.t. the GμG_{\mu} target. Accordingly, the above summation over ν\nu can be viewed as performing GP regression on Gμ(ν,η)G_{\mu}(\nu,\eta), leading to a train discrepancy (itμ[G(ν,η)]it_{\mu}[G(\nu,\eta)]) which is of order Gμ(ν,η)G_{\mu}(\nu,\eta) and hence order 1. The remaining summation now has a summand of order 11 and hence is O(n)O(n) or smaller. We thus find that the correction to the saddle point scales as

itμ0itμ0SP\displaystyle\langle it_{\mu_{0}}\rangle-\langle it_{\mu_{0}}\rangle_{\rm{SP}} =O(Δgn(δ^g/σ2)2)\displaystyle=O\left(\frac{\Delta g}{n(\hat{\delta}g/\sigma^{2})^{2}}\right) (A.37)

Generally we expect Δg0\Delta g\approx 0 at strong over-parameterization (as non-linear effects are suppressed by C1C^{-1}) and ΔgO(g)\Delta g\sim O(g) at good performances (as this implies good performance on the training set). Thus we generally expect Δg=O(g)=O(1)\Delta g=O(g)=O(1) and hence large n(δ^g/σ2)2n(\hat{\delta}g/\sigma^{2})^{2} controls the magnitude of the corrections. Considering the σ20\sigma^{2}\rightarrow 0 limit, we note in passing that δ^g/σ2\hat{\delta}g/\sigma^{2} typically remains finite. For instance it is simply K1gK^{-1}g for a Gaussian Process.

Considering the linear CNN model of the main text, we estimate the above heuristic criterion for n=650n=650 and C=8C=8 where Δg=O(g)\Delta g=O(g) and δ^g0.1g\hat{\delta}g\approx 0.1g. This then gives (6.5O(g)2)1(6.5O(g)^{2})^{-1} as the small factor dominating the correction. As we choose O(g)=3O(g)=3 in that experiment, we find that the correction is roughly 1/601/60. As the discrepancy is 0.1g=O(0.3)0.1g=O(0.3) we expect roughly a 5%5\% relative error in predicting the discrepancy.
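
The arithmetic behind this estimate can be spelled out directly; the snippet below simply reproduces the quoted numbers, taking σ²=1 so that δ̂g/σ²≈0.1g.

```python
n = 650                             # C = 8 in that experiment (not needed for this estimate)
g_scale = 3.0                       # the O(g) = 3 chosen there
disc = 0.1 * g_scale                # discrepancy ~ 0.1 g, with sigma^2 = 1

small_factor = 1.0 / (n * (0.1 * g_scale) ** 2)   # = (6.5 O(g)^2)^{-1}
print(small_factor)                               # ~0.017, i.e. roughly 1/60
print(small_factor / disc)                        # ~0.057, the ~5% relative error quoted above
```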

Appendix B Review of the Edgeworth expansion

In this section we give a review of the Edgeworth expansion, starting from the simplest case of a scalar valued RV and then moving on to vector valued RVs, so that we can write down the expansion for the output of a generic neural network on a fixed set of inputs.

B.1 Edgeworth expansion for a scalar random variable

Consider scalar valued continuous iid RVs {Zi}\{Z_{i}\} and assume WLOG Zi=0,Zi2=1\left\langle Z_{i}\right\rangle=0,\hskip 5.0pt\left\langle Z_{i}^{2}\right\rangle=1, with higher cumulants κrZ\kappa_{r}^{Z} for r3r\geq 3. Now consider their normalized sum YN=1Ni=1NZiY_{N}=\frac{1}{\sqrt{N}}\sum_{i=1}^{N}Z_{i}. Recall that cumulants are additive, i.e. if Z1,Z2Z_{1},Z_{2} are independent RVs then κr(Z1+Z2)=κr(Z1)+κr(Z2)\kappa_{r}(Z_{1}+Z_{2})=\kappa_{r}(Z_{1})+\kappa_{r}(Z_{2}) and that the rr-th cumulant is homogeneous of degree rr, i.e. if cc is any constant, then κr(cZ)=crκr(Z)\kappa_{r}(cZ)=c^{r}\kappa_{r}(Z). Combining additivity and homogeneity of cumulants we have a relation between the cumulants of ZZ and YY

κr2:=κr2Y=NκrZ(N)r=κrZNr/21\kappa_{r\geq 2}:=\kappa^{Y}_{r\geq 2}=\frac{N\kappa_{r}^{Z}}{(\sqrt{N})^{r}}=\frac{\kappa_{r}^{Z}}{N^{r/2-1}} (B.1)

Now, let φ(y):=(2π)1/2ey2/2\varphi(y):=(2\pi)^{-1/2}e^{-y^{2}/2} be the PDF of the standard normal distribution. The characteristic function of YY is given by the Fourier transform of its PDF P(y)P(y) and is expressed via its cumulants

P^(t):=[P(y)]=exp(r=1κr(it)rr!)=exp(r=3κr(it)rr!)φ^(t)\hat{P}(t):=\mathcal{F}[P(y)]=\exp\left(\sum_{r=1}^{\infty}\kappa_{r}\frac{(it)^{r}}{r!}\right)=\exp\left(\sum_{r=3}^{\infty}\kappa_{r}\frac{(it)^{r}}{r!}\right)\hat{\varphi}(t) (B.2)

where the last equality holds since κ1=0,κ2=1\kappa_{1}=0,\quad\kappa_{2}=1 and φ^(t)=et22\hat{\varphi}(t)=e^{-\frac{t^{2}}{2}}. From the CLT, we know that P(y)φ(y)P(y)\to\varphi(y) as NN\to\infty. Taking the inverse Fourier transform 1\mathcal{F}^{-1} has the effect of mapping ityit\mapsto-\partial_{y} thus

P(y)=exp(r=3κr(y)rr!)φ(y)=φ(y)(1+r=3κrr!Hr(y))P(y)=\exp\left(\sum_{r=3}^{\infty}\kappa_{r}\frac{(-\partial_{y})^{r}}{r!}\right)\varphi(y)=\varphi(y)\left(1+\sum_{r=3}^{\infty}\frac{\kappa_{r}}{r!}H_{r}(y)\right) (B.3)

where Hr(y)H_{r}(y) is the rrth probabilist’s Hermite polynomial, defined by

Hr(y)=()rey2/2drdyrey2/2H_{r}(y)=(-)^{r}e^{y^{2}/2}\frac{d^{r}}{dy^{r}}e^{-y^{2}/2} (B.4)

e.g. H4(y)=y46y2+3H_{4}(y)=y^{4}-6y^{2}+3.
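
As a minimal numerical illustration of Eqs. B.1 and B.3, the sketch below compares the r=4 truncation of the Edgeworth series against a Monte Carlo histogram, assuming uniformly distributed Zi (for which κ3=0 and κ4=−6/5); the choice of distribution and of N is purely illustrative.

```python
import numpy as np
from numpy.polynomial.hermite_e import hermeval   # probabilist's Hermite polynomials He_r

rng = np.random.default_rng(0)

# Z_i uniform on [-sqrt(3), sqrt(3)]: zero mean, unit variance, kappa_3 = 0, kappa_4 = -6/5.
N, samples = 4, 500_000
Z = rng.uniform(-np.sqrt(3), np.sqrt(3), size=(samples, N))
Y = Z.sum(axis=1) / np.sqrt(N)

kappa4_Y = -1.2 / N                               # Eq. B.1 with r = 4

y = np.linspace(-2, 2, 5)
phi = np.exp(-y ** 2 / 2) / np.sqrt(2 * np.pi)
H4 = hermeval(y, [0, 0, 0, 0, 1])                 # He_4(y) = y^4 - 6 y^2 + 3
edgeworth = phi * (1 + kappa4_Y / 24 * H4)        # Eq. B.3 truncated at r = 4

hist, edges = np.histogram(Y, bins=200, range=(-3.5, 3.5), density=True)
centers = 0.5 * (edges[:-1] + edges[1:])
empirical = np.interp(y, centers, hist)

# The Edgeworth column tracks the empirical density more closely than the plain Gaussian.
print(np.c_[y, empirical, edgeworth, phi])
```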

B.2 Edgeworth expansion for a vector valued random variable

Consider now the analogous procedure for vector-valued RVs in n\mathbb{R}^{n} (see [29]). We perform an Edgeworth expansion around a centered multivariate Gaussian distribution with covariance matrix κi,j\kappa^{i,j}

φ(y)=1(2π)n/2det(κi,j)exp(12κi,jyiyj)\varphi(\vec{y})=\frac{1}{(2\pi)^{n/2}\sqrt{\det(\kappa^{i,j})}}\exp\left(-\frac{1}{2}\kappa_{i,j}y^{i}y^{j}\right) (B.5)

where κi,j\kappa_{i,j} is the matrix inverse of κi,j\kappa^{i,j} and Einstein summation is used. The rr’th order cumulant becomes a tensor with rr indices, e.g. the analogue of κ4\kappa_{4} is κi,j,k,l\kappa^{i,j,k,l}. The Hermite polynomials are now multi-variate polynomials, so that the first one is Hi=κi,jyjH_{i}=\kappa_{i,j}y^{j} and the fourth one is

Hijkl(y)=e12κi,jyiyjijkle12κi,jyiyj=HiHjHkHlHiHjκk,l[6]+κi,jκk,l[3]\begin{split}H_{ijkl}(\vec{y})&=e^{\frac{1}{2}\kappa_{i^{\prime},j^{\prime}}y^{i^{\prime}}y^{j^{\prime}}}\partial_{i}\partial_{j}\partial_{k}\partial_{l}e^{-\frac{1}{2}\kappa_{i^{\prime},j^{\prime}}y^{i^{\prime}}y^{j^{\prime}}}\\ &=H_{i}H_{j}H_{k}H_{l}-H_{i}H_{j}\kappa_{k,l}[6]+\kappa_{i,j}\kappa_{k,l}[3]\end{split} (B.6)

where the postscript bracket notation is simply a convenience to avoid listing explicitly all possible partitions of the indices, e.g. κi,jκk,l[3]=κi,jκk,l+κi,kκj,l+κi,lκj,k\kappa_{i,j}\kappa_{k,l}[3]=\kappa_{i,j}\kappa_{k,l}+\kappa_{i,k}\kappa_{j,l}+\kappa_{i,l}\kappa_{j,k}

In our context we are interested in even distributions where all odd cumulants vanish, so the Edgeworth expansion reads

P(y)=exp(κi,j,k,l4!ijkl+)φ(y)=φ(y)(1+κi,j,k,l4!Hijkl+)P(\vec{y})=\exp\left(\frac{\kappa^{i,j,k,l}}{4!}\partial_{i}\partial_{j}\partial_{k}\partial_{l}+\dots\right)\varphi(\vec{y})=\varphi(\vec{y})\left(1+\frac{\kappa^{i,j,k,l}}{4!}H_{ijkl}+\dots\right) (B.7)

B.3 Edgeworth expansion for the posterior of Bayesian neural network

Consider an on-data formulation, i.e. a distribution over a vector space - the NN output evaluated on the training set and on a single test point, rather than a distribution over the whole function space:

f(𝐱)f(f(𝐱1),,f(𝐱n),f(𝐱n+1))n+1𝐱n+1=𝐱f\left({\mathbf{x}}\right)\to\vec{f}\equiv\left(f\left({\mathbf{x}}_{1}\right),\dots,f\left({\mathbf{x}}_{n}\right),f\left({\mathbf{x}}_{n+1}\right)\right)\in\mathbb{R}^{n+1}\qquad{\mathbf{x}}_{n+1}={\mathbf{x}}_{*} (B.8)

where 𝐱{\mathbf{x}}_{*} is the test point. Let κr\kappa_{r} denote the rrth cumulant of the prior P0(f)P_{0}\left(\vec{f}\right) of the network over this space:

[κr]μ1,,μr=f(𝐱μ1)f(𝐱μr)"disconnected averages"μ{1,,n+1}\left[\kappa_{r}\right]_{\mu_{1},...,\mu_{r}}=\left\langle f\left({\mathbf{x}}_{\mu_{1}}\right)\cdots f\left({\mathbf{x}}_{\mu_{r}}\right)\right\rangle-"\mathrm{disconnected\,averages}"\qquad\mu\in\left\{1,\dots,n+1\right\} (B.9)

Take the baseline distribution to be Gaussian PG(f)exp(12f𝖳K1f)P_{G}\left(\vec{f}\right)\propto\exp\left(-\frac{1}{2}\vec{f}^{\mathsf{T}}K^{-1}\vec{f}\right), around which we perform the Edgeworth expansion, thus the characteristic function of the prior reads

P^0(t)=exp(r=4κr(it)rr!)P^G(t)\hat{P}_{0}\left(\vec{t}\right)=\exp\left(\sum_{r=4}^{\infty}\frac{\kappa_{r}\left(i\vec{t}\right)^{r}}{r!}\right)\hat{P}_{G}\left(\vec{t}\right) (B.10)

and thus

P0(f)=exp(r=4()rκrrr!)PG(f)P_{0}\left(\vec{f}\right)=\exp\left(\sum_{r=4}^{\infty}\frac{\left(-\right)^{r}\kappa_{r}\vec{\partial}^{r}}{r!}\right)P_{G}\left(\vec{f}\right) (B.11)

where we used the shorthand notation:

κrrμ1,,μr[κr]μ1,,μrfμ1fμr\kappa_{r}\vec{\partial}^{r}\equiv\sum_{\mu_{1},...,\mu_{r}}\left[\kappa_{r}\right]_{\mu_{1},...,\mu_{r}}\partial_{f_{\mu_{1}}}\cdots\partial_{f_{\mu_{r}}} (B.12)

and the indices range over both the train set and the test point μ{1,,n+1}\mu\in\left\{1,\dots,\underbrace{n+1}_{*}\right\}. In our case, all odd cumulants vanish, thus

exp(r=4()rκrrr!)=exp(r=4κrrr!)\exp\left(\sum_{r=4}^{\infty}\frac{\left(-\right)^{r}\kappa_{r}\vec{\partial}^{r}}{r!}\right)=\exp\left(\sum_{r=4}^{\infty}\frac{\kappa_{r}\vec{\partial}^{r}}{r!}\right) (B.13)

Introducing the data term and a source term, the partition function reads (denote f(𝐱μ)fμ,f(𝐱)f)f\left({\mathbf{x}}_{\mu}\right)\equiv f_{\mu},\quad f\left({\mathbf{x}}_{*}\right)\equiv f_{*})

Z(J)=𝑑f(exp(r=4κrrr!)PG(f))exp(12σ2μ=1n(gμfμ)2+μ=1n+1Jμfμ)Z\left(J\right)=\int d\vec{f}\left(\exp\left(\sum_{r=4}^{\infty}\frac{\kappa_{r}\vec{\partial}^{r}}{r!}\right)P_{G}\left(\vec{f}\right)\right)\exp\left(-\frac{1}{2\sigma^{2}}\sum_{\mu=1}^{n}\left(g_{\mu}-f_{\mu}\right)^{2}+\sum_{\mu=1}^{n+1}J_{\mu}f_{\mu}\right) (B.14)

Appendix C Target shift equations - alternative derivation

Here we derive our self-consistent target shift equations from a different approach which does not require the introduction of the itμit_{\mu} integration variables by transforming to Fourier space. While this approach requires an additional assumption (see below) it also has the benefit of being extendable to any smooth loss function comprised of a sum over training points. In particular, below we derive it for both MSE loss and cross entropy loss.

To this end, we examine the Edgeworth expansion for the partition function given by Eq. B.14. By using a series of integration by parts and noting the boundary terms vanish, one can shift the action of the higher cumulants from the prior to the data dependent term

Z(J)=𝑑fPG(f)[exp(r=31r!μ1,,μr=1n+1κμ1,,μrfμ1fμr)exp(12σ2μ=1n(gμfμ)2+μ=1n+1Jμfμ)]Z\left(\vec{J}\right)=\int d\vec{f}P_{G}\left(\vec{f}\right)\left[\exp\left(\sum_{r=3}^{\infty}\frac{1}{r!}\sum_{\mu_{1},...,\mu_{r}=1}^{n+1}\kappa_{\mu_{1},...,\mu_{r}}\partial_{f_{\mu_{1}}}\cdots\partial_{f_{\mu_{r}}}\right)\exp\left(-\frac{1}{2\sigma^{2}}\sum_{\mu=1}^{n}\left(g_{\mu}-f_{\mu}\right)^{2}+\sum_{\mu=1}^{n+1}J_{\mu}f_{\mu}\right)\right] (C.1)

Doing so yields an equivalent viewpoint on the problem, wherein the Gaussian data term and the non-Gaussian prior appearing in Eq. B.14 are replaced in Eq. C.1 by a Gaussian prior and a non-Gaussian data term.

Next we argue that in the large nn limit, the non-Gaussian data-term can be expressed as a Gaussian-data term but on a shifted target. To this end we note that when nn is large, most combinations of derivatives in the exponents act on different data points. In such cases derivatives could simply be replaced as μiσ2δ^gμi\partial_{\mu_{i}}\rightarrow\sigma^{-2}\hat{\delta}g_{\mu_{i}}, where δ^gμigμifμi\hat{\delta}g_{\mu_{i}}\equiv g_{\mu_{i}}-f_{\mu_{i}} denotes the discrepancy on the training point μi\mu_{i}.

Consider next how fνf_{\nu} on a particular training point (ν\nu) is affected by these derivative terms. Following the above observation, most terms in the exponent will not act on fνf_{\nu} and a 1/n1/n portion will contain a single derivative. The remaining, rarer cases, where two derivatives act on the same ν\nu, are neglected. For each fνf_{\nu} we thus replace r1r-1 derivatives in the order rr term in C.1 by discrepancies, leaving a single derivative operator that is multiplied by the following quantity

Δgνr=31(r1)!μ1,,μr1nκν,μ1μr1(σ2δ^gμ1)(σ2δ^gμr1)\Delta g_{\nu}\equiv\sum_{r=3}^{\infty}\frac{1}{\left(r-1\right)!}\sum_{\mu_{1},...,\mu_{r-1}}^{n}\kappa_{\nu,\mu_{1}\dots\mu_{r-1}}(\sigma^{-2}\hat{\delta}g_{\mu_{1}})\cdots(\sigma^{-2}\hat{\delta}g_{\mu_{r-1}}) (C.2)

Note that the summation indices span only the training set, not the test point: μ1,,μr1{1,,n}\mu_{1},...,\mu_{r-1}\in\left\{1,\dots,n\right\}, whereas the free index spans also the test point ν{1,,n+1}\nu\in\left\{1,\dots,n+1\right\}.

Recall that an exponentiated derivative operator acts as a shifting operator, e.g. for some constant aa\in\mathbb{R}, any smooth scalar function φ\varphi obeys eaxφ(x)=φ(x+a)e^{a\partial_{x}}\varphi\left(x\right)=\varphi\left(x+a\right). If this Δg\Delta g was a constant, the differential operator could now readily act on the data term. Next we make again our collectivity assumption: as Δg\Delta g involves a sum over many data-points, it will be a weakly fluctuating quantity in the large nn limit provided the contribution to Δg\Delta g comes from a collective effect rather than by a few data points. We thus perform our second approximation, of the mean-field type, and replace Δg\Delta g by its average Δg¯\overline{\Delta g}, leading to

Z(J;Δg¯)=𝑑fPG(f)exp(12σ2μ=1n(gμΔg¯μfμ)2+μ=1n+1Jμ(fμ+Δg¯μ))Z\left(\vec{J};\overline{\Delta g}\right)=\int d\vec{f}P_{G}\left(\vec{f}\right)\exp\left(-\frac{1}{2\sigma^{2}}\sum_{\mu=1}^{n}\left(g_{\mu}-\overline{\Delta g}_{\mu}-f_{\mu}\right)^{2}+\sum_{\mu=1}^{n+1}J_{\mu}\left(f_{\mu}+\overline{\Delta g}_{\mu}\right)\right) (C.3)

Given a fixed Δg¯\overline{\Delta g}, C.3 is the partition function corresponding to a GP with the train targets shifted by Δg¯μ\overline{\Delta g}_{\mu} and the test target shifted by Δg¯\overline{\Delta g}_{*}. Following this we find that Δg¯\overline{\Delta g} depends on the discrepancy of the GP prediction which in turn depends on Δg¯\overline{\Delta g}. In other words we obtain our self-consistent equation: Δg¯=ΔgμZ(J;Δg¯)\overline{\Delta g}=\langle\Delta g_{\mu}\rangle_{Z\left(\vec{J};\overline{\Delta g}\right)}.

The partition function C.3 reflects the correspondence between finite DNNs and a GP with its target shifted by Δg¯\overline{\Delta g}. To facilitate the analytic solution of this self-consistent equation, we focus on the case δ^gμδ^gνδ^gμδ^gν\left\langle\hat{\delta}g_{\mu}\hat{\delta}g_{\nu}\right\rangle\ll\left\langle\hat{\delta}g_{\mu}\right\rangle\left\langle\hat{\delta}g_{\nu}\right\rangle at least for μν\mu\neq\nu. We note that this was the case for the two toy models we studied.

Given this, the expectation value over Δg\Delta g using the GP defined by Z(J;Δg¯)Z\left(\vec{J};\overline{\Delta g}\right), which consists of products of expectation values of individual discrepancies and correlations between two discrepancies, can then be expressed using only the former. Omitting correlations within the GP expectation value, one obtains a simplified self-consistent equation involving only the average discrepancies:

μ{1,,n}:δ^gμZ(J;Δg¯)gμΔg¯μfμZ(J;Δg¯)=gμΔg¯μν,ρ=1nKμνK~νρ1(gρΔg¯ρ)\forall\mu\in\left\{1,\dots,n\right\}:\qquad\langle\hat{\delta}g_{\mu}\rangle_{Z\left(\vec{J};\overline{\Delta g}\right)}\equiv g_{\mu}-\overline{\Delta g}_{\mu}-\left\langle f_{\mu}\right\rangle_{Z\left(\vec{J};\overline{\Delta g}\right)}=g_{\mu}-\overline{\Delta g}_{\mu}-\sum_{\nu,\rho=1}^{n}K_{\mu\nu}\tilde{K}_{\nu\rho}^{-1}\left(g_{\rho}-\overline{\Delta g}_{\rho}\right) (C.4)

with δ^gμ\hat{\delta}g_{\mu} now understood as a number, also within C.2. Lastly, we plug the solution of these equations back into the shifted GP to find the prediction on the test point: f(𝐱)Z(J;Δg¯)\left\langle f({\mathbf{x}}_{*})\right\rangle_{Z\left(\vec{J};\overline{\Delta g}\right)}. These equations coincide with the self-consistent equations derived via the saddle point approximation in the main text.

Notably the above derivation did not hinge on having MSE loss. For any loss given as a sum over training points, =μnLμ(fμ)\mathcal{L}=\sum^{n}_{\mu}L_{\mu}(f_{\mu}), the above derivation should hold with σ2δ^gμ\sigma^{-2}\hat{\delta}g_{\mu} in Δgν\Delta g_{\nu} replaced by fμLμ\partial_{f_{\mu}}L_{\mu}. In particular for the cross entropy loss where fν,if_{\nu,i} is the pre-softmax output of the DNN for class ii we will have

fν,iLν=δiν,i+efν,ijefν,j\displaystyle\partial_{f_{\nu,i}}L_{\nu}=-\delta_{i_{\nu},i}+\frac{e^{f_{\nu,i}}}{\sum_{j}e^{f_{\nu,j}}} (C.5)

where ii and jj run over all classes and iνi_{\nu} is the class of 𝐱ν{\mathbf{x}}_{\nu}. Neatly, the above r.h.s. is again a form of discrepancy, but this time in probability space. Namely, it is pmodel(i|𝐱ν)pdata(i|𝐱ν)p_{\rm{model}}(i|{\mathbf{x}}_{\nu})-p_{\rm{data}}(i|{\mathbf{x}}_{\nu}), where pmodelp_{\rm{model}} is the distribution generated by the softmax layer, and pdatap_{\rm{data}} is the empirical distribution. Following this, one can readily derive self-consistent equations for the cross entropy loss and solve them numerically. Further analytical progress hinges on developing an analogue of the EK approximation for the cross entropy loss.
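
As a small illustration with toy numbers, the probability-space discrepancy of Eq. C.5, which would replace σ⁻²δ̂gμ in the target shift, can be computed directly from the pre-softmax outputs; the helper ce_discrepancy below is ours, not part of any library.

```python
import numpy as np

def ce_discrepancy(f_nu, class_idx):
    """Eq. C.5: gradient of the cross-entropy loss w.r.t. the pre-softmax outputs,
    i.e. p_model(i|x_nu) - p_data(i|x_nu)."""
    p_model = np.exp(f_nu - f_nu.max())      # numerically stabilized softmax
    p_model /= p_model.sum()
    p_data = np.zeros_like(f_nu)
    p_data[class_idx] = 1.0                  # empirical (one-hot) distribution
    return p_model - p_data

f_nu = np.array([2.0, -1.0, 0.5])            # toy pre-softmax outputs for three classes
print(ce_discrepancy(f_nu, class_idx=0))     # true-class entry is negative, the rest positive
```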

Appendix D Review of the Equivalent Kernel (EK)

In this appendix we generally follow [35], see also [38] for more details. The posterior mean for GP regression

f¯GP(𝐱)=μ,νKμK~μν1yν\bar{f}_{\mathrm{GP}}({\mathbf{x}}_{*})=\sum_{\mu,\nu}K^{*}_{\mu}\tilde{K}^{-1}_{\mu\nu}y_{\nu} (D.1)

can be obtained as the function which minimizes the functional

J[f]=12σ2α=1n(yαf(𝐱α))2+12f2J\left[f\right]=\frac{1}{2\sigma^{2}}\sum_{\alpha=1}^{n}\left(y_{\alpha}-f\left({\mathbf{x}}_{\alpha}\right)\right)^{2}+\frac{1}{2}||f||_{\mathcal{H}}^{2} (D.2)

where f||f||_{\mathcal{H}} is the RKHS norm corresponding to kernel KK. Our goal is now to understand the behaviour of the minimizer of J[f]J[f] as nn\to\infty. Let the data pairs (𝐱α,yα)\left({\mathbf{x}}_{\alpha},y_{\alpha}\right) be drawn from the probability measure μ(𝐱,y)\mu({\mathbf{x}},y). The expectation value of the MSE is

𝔼[α=1n(yαf(𝐱α))2]=n(yf(𝐱))2𝑑μ(𝐱,y)\mathbb{E}\left[\sum_{\alpha=1}^{n}\left(y_{\alpha}-f\left({\mathbf{x}}_{\alpha}\right)\right)^{2}\right]=n\int\left(y-f\left({\mathbf{x}}\right)\right)^{2}d\mu\left({\mathbf{x}},y\right) (D.3)

Let g(𝐱)𝔼[y|𝐱]g\left({\mathbf{x}}\right)\equiv\mathbb{E}\left[y|{\mathbf{x}}\right] be the ground truth regression function to be learned. The variance around g(𝐱)g\left({\mathbf{x}}\right) is denoted σ2(𝐱)=(yg(𝐱))2𝑑μ(y|𝐱)\sigma^{2}\left({\mathbf{x}}\right)=\int\left(y-g\left({\mathbf{x}}\right)\right)^{2}d\mu\left(y|{\mathbf{x}}\right). Then writing yf=(yg)+(gf)y-f=\left(y-g\right)+\left(g-f\right) we find that the MSE on the data target yy can be broken up into the MSE on the ground truth target gg plus variance due to the noise

(yf(𝐱))2𝑑μ(𝐱,y)=(g(𝐱)f(𝐱))2𝑑μ(𝐱)+σ2(𝐱)𝑑μ(𝐱)\int\left(y-f\left({\mathbf{x}}\right)\right)^{2}d\mu\left({\mathbf{x}},y\right)=\int\left(g\left({\mathbf{x}}\right)-f\left({\mathbf{x}}\right)\right)^{2}d\mu\left({\mathbf{x}}\right)+\int\sigma^{2}\left({\mathbf{x}}\right)d\mu\left({\mathbf{x}}\right) (D.4)

Since the right term on the RHS of D.4 does not depend on ff we can ignore it when looking for the minimizer of the functional which is now replaced by

Jμ[f]=n2σ2(g(𝐱)f(𝐱))2𝑑μ(𝐱)+12f2J_{\mu}\left[f\right]=\frac{n}{2\sigma^{2}}\int\left(g\left({\mathbf{x}}\right)-f\left({\mathbf{x}}\right)\right)^{2}d\mu\left({\mathbf{x}}\right)+\frac{1}{2}||f||_{\mathcal{H}}^{2} (D.5)

To proceed we project gg and ff onto the eigenfunctions of the kernel with respect to μ(𝐱)\mu({\mathbf{x}}), which obey 𝑑μ(𝐱)K(𝐱,𝐱)ψs(𝐱)=λsψs(𝐱)\int d\mu\left({\mathbf{x}}^{\prime}\right)K\left({\mathbf{x}},{\mathbf{x}}^{\prime}\right)\psi_{s}\left({\mathbf{x}}^{\prime}\right)=\lambda_{s}\psi_{s}\left({\mathbf{x}}\right). Assuming that the kernel is non-degenerate so that the ψ\psi's form a complete orthonormal basis, for a sufficiently well behaved target we may write g(𝐱)=sgsψs(𝐱)g\left({\mathbf{x}}\right)=\sum_{s}g_{s}\psi_{s}\left({\mathbf{x}}\right) where gs=g(𝐱)ψs(𝐱)𝑑μ(𝐱)g_{s}=\int g\left({\mathbf{x}}\right)\psi_{s}\left({\mathbf{x}}\right)d\mu\left({\mathbf{x}}\right), and similarly for ff. Thus the functional becomes

Jμ[f]=n2σ2s(gsfs)2+12sfs2λsJ_{\mu}\left[f\right]=\frac{n}{2\sigma^{2}}\sum_{s}\left(g_{s}-f_{s}\right)^{2}+\frac{1}{2}\sum_{s}\frac{f_{s}^{2}}{\lambda_{s}} (D.6)

This is easily minimized by taking the derivative w.r.t. each fsf_{s} to yield

fs=λsλs+σ2/ngsf_{s}=\frac{\lambda_{s}}{\lambda_{s}+\sigma^{2}/n}g_{s} (D.7)

In the limit nn\to\infty we have σ2/n0\sigma^{2}/n\to 0 thus we expect that ff would converge to gg. The rate of this convergence will depend on the smoothness of gg, the kernel KK and the measure μ(𝐱,y)\mu({\mathbf{x}},y). From D.7 we see that if nλsσ2n\lambda_{s}\ll\sigma^{2} then fsf_{s} is effectively zero. This means that we cannot obtain information about the coefficients of eigenfunctions with small eigenvalues until we get a sufficient amount of data. Plugging the result D.7 into f(𝐱)=sfsψs(𝐱)f\left({\mathbf{x}}\right)=\sum_{s}f_{s}\psi_{s}\left({\mathbf{x}}\right) and recalling gs=g(𝐱)ψs(𝐱)𝑑μ(𝐱)g_{s}=\int g\left({\mathbf{x}}^{\prime}\right)\psi_{s}\left({\mathbf{x}}^{\prime}\right)d\mu\left({\mathbf{x}}^{\prime}\right) we find

f(𝐱)EK=sλsgsλs+σ2/nψs(𝐱)=sλsψs(𝐱)ψs(𝐱)λs+σ2/nh(𝐱,𝐱)g(𝐱)𝑑μ(𝐱)\left\langle f\left({\mathbf{x}}\right)\right\rangle_{\rm{EK}}=\sum_{s}\frac{\lambda_{s}g_{s}}{\lambda_{s}+\sigma^{2}/n}\psi_{s}\left({\mathbf{x}}\right)=\int\underbrace{\sum_{s}\frac{\lambda_{s}\psi_{s}\left({\mathbf{x}}\right)\psi_{s}\left({\mathbf{x}}^{\prime}\right)}{\lambda_{s}+\sigma^{2}/n}}_{h\left({\mathbf{x}},{\mathbf{x}}^{\prime}\right)}g\left({\mathbf{x}}^{\prime}\right)d\mu\left({\mathbf{x}}^{\prime}\right) (D.8)

The term h(𝐱,𝐱)h({\mathbf{x}},{\mathbf{x}}^{\prime}) is the equivalent kernel. Notice the similarity to the vector-valued equivalent kernel weight function 𝐡(𝐱)=(𝐊+σ2I)1𝐤(𝐱)\mathbf{h}\left({\mathbf{x}}_{*}\right)=\left(\mathbf{K}+\sigma^{2}I\right)^{-1}\mathbf{k}\left({\mathbf{x}}_{*}\right) where 𝐊\mathbf{K} denotes the n×nn\times n matrix of covariances between the training points with entries K(𝐱μ,𝐱ν)K\left({\mathbf{x}}_{\mu},{\mathbf{x}}_{\nu}\right) and 𝐤(𝐱)\mathbf{k}\left({\mathbf{x}}_{\ast}\right) is the vector of covariances with elements K(𝐱μ,𝐱*)K\left({\mathbf{x}}_{\mu},{\mathbf{x}}_{*}\right). The difference is that in the usual discrete formulation the prediction is obtained as a linear combination of a finite number of observations yiy_{i} with weights given by hi(𝐱)h_{i}({\mathbf{x}}), while here we have instead a continuous integral.
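
A hedged numerical sketch of Eqs. D.7 and D.8: on a finite sample one can approximate the kernel eigenfunctions by diagonalizing the Gram matrix over points drawn from μ(x) (a Nyström-type discretization). The RBF kernel, uniform measure and smooth target below are illustrative choices only.

```python
import numpy as np

rng = np.random.default_rng(0)

# Discretize mu(x) with M sample points; n plays the role of the training-set size.
M, n, sigma2 = 1000, 200, 0.1
x = rng.uniform(-1, 1, size=M)
K = np.exp(-0.5 * (x[:, None] - x[None, :]) ** 2 / 0.1 ** 2)   # illustrative RBF kernel
g = np.sin(3 * np.pi * x)                                      # illustrative target g(x)

# Eigenfunctions w.r.t. the empirical measure: (1/M) sum_j K(x_i, x_j) psi_s(x_j) = lambda_s psi_s(x_i)
lam, v = np.linalg.eigh(K / M)
psi = np.sqrt(M) * v                       # psi_s(x_i), orthonormal under (1/M) sum_i

g_s = psi.T @ g / M                        # g_s = int g(x) psi_s(x) dmu(x)
f_s = lam / (lam + sigma2 / n) * g_s       # Eq. D.7: modes with n*lambda_s << sigma^2 are suppressed
f_EK = psi @ f_s                           # Eq. D.8 evaluated on the sample points

print(np.mean((f_EK - g) ** 2))            # small here, since the target is well resolved by the kernel
```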

Appendix E Additional technical details for solving the self consistent equations

E.1 EK limit for the CNN toy model

In this subsection we show how to arrive at Eq. 16 from the main text, which is a self consistent equation for the proportionality constant, α\alpha, defined by δ^g=αg\hat{\delta}g=\alpha g. We first show that both the shift and the discrepancy are linear in the target, and then derive the equation.

E.1.1 The shift and the discrepancy are linear in the target

Recall that we assume a linear target with a single channel:

g(𝐱)=k=1Nak(𝐰𝐱~k)\displaystyle g\left({\mathbf{x}}\right)=\sum_{k=1}^{N}a_{k}^{*}\left({\mathbf{w}}^{*}\cdot\tilde{{\mathbf{x}}}_{k}\right) (E.1)

A useful relation in our context is

𝑑μ(𝐱2)(𝐱~i1𝐱~j2)g(𝐱2)\displaystyle\int d\mu\left({\mathbf{x}}^{2}\right)\left(\tilde{{\mathbf{x}}}_{i}^{1}\cdot\tilde{{\mathbf{x}}}_{j}^{2}\right)g\left({\mathbf{x}}^{2}\right) =𝑑μ(𝐱2)(𝐱~i1𝐱~j2)k=1Nak(𝐰𝐱~k2)\displaystyle=\int d\mu\left({\mathbf{x}}^{2}\right)\left(\tilde{{\mathbf{x}}}_{i}^{1}\cdot\tilde{{\mathbf{x}}}_{j}^{2}\right)\sum_{k=1}^{N}a_{k}^{*}\left({\mathbf{w}}^{*}\cdot\tilde{{\mathbf{x}}}_{k}^{2}\right) (E.2)
=(𝐱~i1)𝖳(k=1Nak𝑑μ(𝐱2)𝐱~j2(𝐱~k2)𝖳δjkIS)𝐰\displaystyle=\left(\tilde{{\mathbf{x}}}_{i}^{1}\right)^{\mathsf{T}}\left(\sum_{k=1}^{N}a_{k}^{*}\underbrace{\int d\mu\left({\mathbf{x}}^{2}\right)\tilde{{\mathbf{x}}}_{j}^{2}\left(\tilde{{\mathbf{x}}}_{k}^{2}\right)^{\mathsf{T}}}_{\delta_{jk}I_{S}}\right){\mathbf{w}}^{*}
=aj(𝐱~i1𝐰)\displaystyle=a_{j}^{*}\left(\tilde{{\mathbf{x}}}_{i}^{1}\cdot{\mathbf{w}}^{*}\right)
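
The δjk IS identity used under the brace can be checked directly; the sketch below assumes, as one concrete realization of the input measure, i.i.d. standard-normal entries split into N non-overlapping windows of size S.

```python
import numpy as np

rng = np.random.default_rng(0)

# Monte Carlo check: for i.i.d. standard-normal inputs with non-overlapping windows,
# E[ x~_j (x~_k)^T ] = delta_{jk} I_S.
N, S, draws = 3, 4, 200_000
x = rng.standard_normal((draws, N, S))            # windows x~_1, ..., x~_N of each input

second_moment = np.einsum('djs,dkt->jkst', x, x) / draws
target = np.einsum('jk,st->jkst', np.eye(N), np.eye(S))
print(np.abs(second_moment - target).max())       # small (Monte Carlo error ~ 1/sqrt(draws))
```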

The fact that ff is always a linear function of the input (since the CNN is linear) and the fact that it is proportional to gg at CC\to\infty (since the GP predictor is linear in the target) motivate the ansatz:

δ^ggf=αg\displaystyle\hat{\delta}g\equiv g-f=\alpha g (E.3)

Indeed we will show that this ansatz provides a solution to the non linear self consistent equations.

Notice that the target shift has the form of a geometric series. In the linear CNN toy model we are able to sum this entire series, whose first term is related to (using the notation introduced in §F):

𝑑μ(𝐱2)𝑑μ(𝐱3)𝑑μ(𝐱4)κ4(𝐱1,𝐱2,𝐱3,𝐱4)g(𝐱2)g(𝐱3)g(𝐱4)\displaystyle\int d\mu\left({\mathbf{x}}^{2}\right)d\mu\left({\mathbf{x}}^{3}\right)d\mu\left({\mathbf{x}}^{4}\right)\kappa_{4}\left({\mathbf{x}}^{1},{\mathbf{x}}^{2},{\mathbf{x}}^{3},{\mathbf{x}}^{4}\right)g\left({\mathbf{x}}^{2}\right)g\left({\mathbf{x}}^{3}\right)g\left({\mathbf{x}}^{4}\right) (E.4)
=λ2C𝑑μ2:4i,j=1N{(1i3j)[(2i4j)+(4i2j)]+(1i4j)[(2i3j)+(3i2j)]+(1i2j)[(3i4j)+(4i3j)]}g(𝐱2)g(𝐱3)g(𝐱4)\displaystyle=\frac{\lambda^{2}}{C}\int d\mu_{2:4}\sum_{i,j=1}^{N}\left\{\left(1_{i}3_{j}\right)\left[\left(2_{i}4_{j}\right)+\left(4_{i}2_{j}\right)\right]+\left(1_{i}4_{j}\right)\left[\left(2_{i}3_{j}\right)+\left(3_{i}2_{j}\right)\right]+\left(1_{i}2_{j}\right)\left[\left(3_{i}4_{j}\right)+\left(4_{i}3_{j}\right)\right]\right\}g\left({\mathbf{x}}^{2}\right)g\left({\mathbf{x}}^{3}\right)g\left({\mathbf{x}}^{4}\right)
=λ2Ci,j=1N{2ai(aj)2𝐰2(𝐱~i1𝐰)+2ai(aj)2𝐰2(𝐱~i1𝐰)+2ai(aj)2𝐰2(𝐱~i1𝐰)}\displaystyle=\frac{\lambda^{2}}{C}\sum_{i,j=1}^{N}\left\{2a_{i}^{*}\left(a_{j}^{*}\right)^{2}\norm{{\mathbf{w}}^{*}}^{2}\left(\tilde{{\mathbf{x}}}_{i}^{1}\cdot{\mathbf{w}}^{*}\right)+2a_{i}^{*}\left(a_{j}^{*}\right)^{2}\norm{{\mathbf{w}}^{*}}^{2}\left(\tilde{{\mathbf{x}}}_{i}^{1}\cdot{\mathbf{w}}^{*}\right)+2a_{i}^{*}\left(a_{j}^{*}\right)^{2}\norm{{\mathbf{w}}^{*}}^{2}\left(\tilde{{\mathbf{x}}}_{i}^{1}\cdot{\mathbf{w}}^{*}\right)\right\}
=6λ2C𝐰2(j=1N(aj)2)i=1Nai(𝐱~i1𝐰)g(𝐱1)\displaystyle=\frac{6\lambda^{2}}{C}\norm{{\mathbf{w}}^{*}}^{2}\left(\sum_{j=1}^{N}\left(a_{j}^{*}\right)^{2}\right)\underbrace{\sum_{i=1}^{N}a_{i}^{*}\left(\tilde{{\mathbf{x}}}_{i}^{1}\cdot{\mathbf{w}}^{*}\right)}_{g\left({\mathbf{x}}^{1}\right)}
=6λ2C𝐰2(j=1N(aj)2)σa2g(𝐱1)\displaystyle=\frac{6\lambda^{2}}{C}\norm{{\mathbf{w}}^{*}}^{2}\underbrace{\left(\sum_{j=1}^{N}\left(a_{j}^{*}\right)^{2}\right)}_{\approx\sigma_{a}^{2}}g\left({\mathbf{x}}^{1}\right)

For simplicity we can assume 𝐰2=1\norm{{\mathbf{w}}^{*}}^{2}=1 and σa2=1\sigma_{a}^{2}=1, thus getting a simple proportionality constant of 6λ2C\frac{6\lambda^{2}}{C}. If we were to trade gg for δ^g\hat{\delta}g, as we have in Δg\Delta g, we would get a similar result, with an extra factor of (ασ2/n)3\left(\frac{\alpha}{\sigma^{2}/n}\right)^{3}. The factor of 66 will cancel out with the factor of 1/(41)!1/(4-1)! appearing in the definition of Δg\Delta g. Repeating this calculation for the sixth cumulant, one would arrive at the same result multiplied by a factor of λC(ασ2/n)2\frac{\lambda}{C}\left(\frac{\alpha}{\sigma^{2}/n}\right)^{2}, due to the general form of the even cumulants (Eq. F.29) and the fact that there are an extra two (σ2/n)1δ^g(\sigma^{2}/n)^{-1}\hat{\delta}g's.

E.1.2 Self consistent equation in the EK limit

Starting from the proportionality relations δ^g=αg\hat{\delta}g=\alpha g and Δg=αΔg\Delta g=\alpha_{\Delta}g, we can now write the self consistent equation for the discrepancy as

δ^g=(gΔg)qλλ+σn2(gΔg)\displaystyle\hat{\delta}g=\left(g-\Delta g\right)-q\frac{\lambda}{\lambda+\sigma_{n}^{2}}\left(g-\Delta g\right) (E.5)

Dividing both sides by gg we get a scalar equation

α\displaystyle\alpha =(1αΔ)qλλ+σn2(1αΔ)\displaystyle=\left(1-\alpha_{\Delta}\right)-q\frac{\lambda}{\lambda+\sigma_{n}^{2}}\left(1-\alpha_{\Delta}\right) (E.6)
=λ+σn2λ+σn2qλλ+σn2+(qλλ+σn21)αΔ\displaystyle=\frac{\lambda+\sigma_{n}^{2}}{\lambda+\sigma_{n}^{2}}-q\frac{\lambda}{\lambda+\sigma_{n}^{2}}+\left(q\frac{\lambda}{\lambda+\sigma_{n}^{2}}-1\right)\alpha_{\Delta}
=σn2λ+σn2+(1q)λλ+σn2+(qλλ+σn21)αΔ\displaystyle=\frac{\sigma_{n}^{2}}{\lambda+\sigma_{n}^{2}}+\left(1-q\right)\frac{\lambda}{\lambda+\sigma_{n}^{2}}+\left(q\frac{\lambda}{\lambda+\sigma_{n}^{2}}-1\right)\alpha_{\Delta}

The factor αΔ\alpha_{\Delta} can be calculated by noticing that Δg\Delta g has the form of a geometric series. To better understand what follows next, the reader should first go over §F. The first term in this series is related to contracting the fourth cumulant κ4\kappa_{4} with three δ^g\hat{\delta}g’s thus yielding a factor of λ2C(ασ2/n)3\frac{\lambda^{2}}{C}\left(\frac{\alpha}{\sigma^{2}/n}\right)^{3} (recall that in the EK approximation we trade σ2σ2/n\sigma^{2}\to\sigma^{2}/n). The ratio of two consecutive terms in this series is given by λC(ασ2/n)2\frac{\lambda}{C}\left(\frac{\alpha}{\sigma^{2}/n}\right)^{2}. Using the formula for the sum of a geometric series we have

α=σ2/nλ+σ2/n+(1q)λλ+σ2/n+(qλλ+σ2/n1)λ2C(ασ2/n)3[1λC(ασ2/n)2]1\displaystyle\alpha=\frac{\sigma^{2}/n}{\lambda+\sigma^{2}/n}+\frac{\left(1-q\right)\lambda}{\lambda+\sigma^{2}/n}+\left(q\frac{\lambda}{\lambda+\sigma^{2}/n}-1\right)\frac{\lambda^{2}}{C}\left(\frac{\alpha}{\sigma^{2}/n}\right)^{3}\left[1-\frac{\lambda}{C}\left(\frac{\alpha}{\sigma^{2}/n}\right)^{2}\right]^{-1} (E.7)
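
Eq. E.7 can be solved numerically by a simple fixed-point iteration, as sketched below; the values of λ, q, C, n and σ² are placeholders chosen so that the geometric series converges (lazy regime), not the values used in the experiments.

```python
import numpy as np

# Placeholder parameters, chosen only for illustration.
N = S = 30
lam = 1.0 / (N * S)          # lambda = (NS)^{-1}
sigma2_n = 1.0 / 200         # sigma^2 / n
C, q = 1000, 1.0

def rhs(alpha):
    """Right-hand side of Eq. E.7 as a function of alpha."""
    ek = sigma2_n / (lam + sigma2_n) + (1 - q) * lam / (lam + sigma2_n)
    ratio = (lam / C) * (alpha / sigma2_n) ** 2            # ratio of consecutive terms
    series = (lam ** 2 / C) * (alpha / sigma2_n) ** 3 / (1 - ratio)
    return ek + (q * lam / (lam + sigma2_n) - 1) * series

alpha = sigma2_n / (lam + sigma2_n)      # start from the GP (C -> infinity) value
for _ in range(200):                     # plain fixed-point iteration
    alpha = rhs(alpha)

print(alpha)   # slightly below the GP value (~0.818) for these placeholder numbers
```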

E.2 Corrections to EK and estimation of the qtrainq_{\rm{train}} factor in the main text

The EK approximation can be improved systematically using the field-theory approach of Ref. [10] where the EK result is interpreted as the leading order contribution, in the large nn limit, to the average of the GP predictor over many data-set draws from the dataset measure. However, that work focused on the test performance whereas for qtrainq_{\rm{train}} we require the performance on the training set. We briefly describe the main augmentations needed here and give the sub-leading and sub-sub-leading corrections to the EK result on the training set, enabling us to estimate qtrainq_{\rm{train}} analytically within a 16.3%16.3\% relative error compared with the empirical value. Further systematic improvements are possible but are left for future work.

We thus consider the quantity μφ(𝐱μ)f(𝐱μ)\sum_{\mu}\varphi({\mathbf{x}}_{\mu})f({\mathbf{x}}_{\mu}) where 𝐱μ{\mathbf{x}}_{\mu} is drawn from the training set, f(𝐱μ)f({\mathbf{x}}_{\mu}) is the predictive mean of the GP on that specific training set, and φ(𝐱μ)\varphi({\mathbf{x}}_{\mu}) is some function which we will later take to be the target function (φ(𝐱)=g(𝐱)\varphi({\mathbf{x}})=g({\mathbf{x}})). We wish to calculate the average of this quantity over all training set draws of size nn. We begin by adding a source term of the form Jμφ(𝐱μ)f(𝐱μ)J\sum_{\mu}\varphi({\mathbf{x}}_{\mu})f({\mathbf{x}}_{\mu}) to the action and notice a similar term appearing in the GP action (μ(f(𝐱μ)g(𝐱μ))2-\sum_{\mu}(f({\mathbf{x}}_{\mu})-g({\mathbf{x}}_{\mu}))^{2}) due to the MSE loss. Examining this extra term one notices that it can be absorbed as a JJ dependent shift to the target on training set (g(𝐱μ)g(𝐱μ)+Jσ22φ(𝐱μ)g({\mathbf{x}}_{\mu})\rightarrow g({\mathbf{x}}_{\mu})+\frac{J\sigma^{2}}{2}\varphi({\mathbf{x}}_{\mu})) following which the analysis of Ref. [10] carries through straightforwardly. Doing so, the general result for the leading EK term and sub-leading correction are

n𝑑μ(𝐱)φ(𝐱)f(𝐱)EKnσ2𝑑μ(𝐱)φ(𝐱)[Cov(𝐱,𝐱)(f(𝐱)EKg(𝐱))]\displaystyle n\int d\mu({\mathbf{x}})\varphi({\mathbf{x}})\langle f({\mathbf{x}})\rangle_{\rm{EK}}-\frac{n}{\sigma^{2}}\int d\mu({\mathbf{x}})\varphi({\mathbf{x}})\left[{\mathrm{Cov}}({\mathbf{x}},{\mathbf{x}})(\langle f({\mathbf{x}})\rangle_{\rm{EK}}-g({\mathbf{x}}))\right] (E.8)

where Cov(𝐱,𝐱)=f(𝐱)f(𝐱)EKf(𝐱)EKf(𝐱)EK{\mathrm{Cov}}({\mathbf{x}},{\mathbf{x}})=\langle f({\mathbf{x}})f({\mathbf{x}})\rangle_{\rm{EK}}-\langle f({\mathbf{x}})\rangle_{\rm{EK}}\langle f({\mathbf{x}})\rangle_{\rm{EK}}, EK\langle...\rangle_{\rm{EK}} means averaging with ZEKZ_{\rm{EK}} of Ref. [10], and f(𝐱)EK\langle f({\mathbf{x}})\rangle_{\rm{EK}} is the EK prediction of the previous section, Eq. D.8.

Turning to the specific linear CNN toy model and carrying the above expansion up to an additional term leads to

αtrain\displaystyle\alpha_{\rm{train}} αEK(1αEKσ2+34αEK2σ4)\displaystyle\approx\alpha_{\rm{EK}}\left(1-\frac{\alpha_{\rm{EK}}}{\sigma^{2}}+\frac{3}{4}\frac{\alpha_{\rm{EK}}^{2}}{\sigma^{4}}\right) (E.9)
αEK\displaystyle\alpha_{\rm{EK}} =σ2/nσ2/n+λ=σ2/nσ2/n+(NS)1\displaystyle=\frac{\sigma^{2}/n}{\sigma^{2}/n+\lambda}=\frac{\sigma^{2}/n}{\sigma^{2}/n+(NS)^{-1}}

Considering for instance n=200,σ2=1.0,N=30n=200,\sigma^{2}=1.0,N=30 and S=30S=30, we find αEK=0.818\alpha_{\rm{EK}}=0.818 and so

αtrain\displaystyle\alpha_{\rm{train}} 0.559\displaystyle\approx 0.559 (E.10)

recalling that qtrain=λ+σ2/nλ(1αtrain)q_{\rm{train}}=\frac{\lambda+\sigma^{2}/n}{\lambda}(1-\alpha_{\rm{train}}) we have

qtrain\displaystyle q_{\rm{train}} 2.4255\displaystyle\approx 2.4255 (E.11)

whereas the empirical value here is 2.89952.8995.
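
These numbers follow directly from Eq. E.9 and can be reproduced in a few lines; the quoted 2.4255 is recovered when αtrain is first rounded to 0.559.

```python
n, sigma2, N, S = 200, 1.0, 30, 30
lam = 1.0 / (N * S)

alpha_EK = (sigma2 / n) / (sigma2 / n + lam)                                        # ~0.818
alpha_train = alpha_EK * (1 - alpha_EK / sigma2 + 0.75 * alpha_EK**2 / sigma2**2)   # ~0.5596
q_train = (lam + sigma2 / n) / lam * (1 - alpha_train)                              # ~2.42

print(alpha_EK, alpha_train, q_train)
print((lam + sigma2 / n) / lam * (1 - 0.559))   # = 2.4255 with alpha_train rounded to 0.559
```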

Appendix F Cumulants for a two-layer linear CNN

In this section we explicitly derive the leading (fourth and sixth) cumulants of the toy model of §IV.1, and arrive at the general formula for the even cumulant of arbitrary order.

F.1 Fourth cumulant

F.1.1 Fourth cumulant for a CNN with general activation function (averaging over the readout layer)

For a general activation, we have in our setting for a 2-layer CNN

f(𝐱μ)=i=1Nc=1Cai,cϕ(𝐰c𝐱~iμ)=:i=1Nc=1Cai,cϕi,cμf\left({\mathbf{x}}^{\mu}\right)=\sum_{i=1}^{N}\sum_{c=1}^{C}a_{i,c}\phi\left({\mathbf{w}}_{c}\cdot\tilde{{\mathbf{x}}}_{i}^{\mu}\right)=:\sum_{i=1}^{N}\sum_{c=1}^{C}a_{i,c}\phi_{i,c}^{\mu} (F.1)

The kernel is

K(𝐱1,𝐱2)\displaystyle K\left({\mathbf{x}}^{1},{\mathbf{x}}^{2}\right) =f(𝐱1)f(𝐱)2\displaystyle=\left\langle f\left({\mathbf{x}}^{1}\right)f\left({\mathbf{x}}{}^{2}\right)\right\rangle (F.2)
=i,i=1Nc,c=1Cai,cϕi,c1ai,cϕi,c2\displaystyle=\left\langle\sum_{i,i^{\prime}=1}^{N}\sum_{c,c^{\prime}=1}^{C}a_{i,c}\phi_{i,c}^{1}a_{i^{\prime},c^{\prime}}\phi_{i^{\prime},c^{\prime}}^{2}\right\rangle
=i,i=1Nc,c=1Cai,cai,caδiiδccσa2/CNϕi,c1ϕi,c2𝐰\displaystyle=\sum_{i,i^{\prime}=1}^{N}\sum_{c,c^{\prime}=1}^{C}\underbrace{\left\langle a_{i,c}a_{i^{\prime},c^{\prime}}\right\rangle_{a}}_{\delta_{ii^{\prime}}\delta_{cc^{\prime}}\sigma_{a}^{2}/CN}\left\langle\phi_{i,c}^{1}\phi_{i^{\prime},c^{\prime}}^{2}\right\rangle_{{\mathbf{w}}}
=σa2CNi=1Nc=1Cϕi,c1ϕi,c2𝐰=σa2Ni=1Nϕi,c1ϕi,c2𝐰\displaystyle=\frac{\sigma_{a}^{2}}{CN}\sum_{i=1}^{N}\sum_{c=1}^{C}\left\langle\phi_{i,c}^{1}\phi_{i,c}^{2}\right\rangle_{{\mathbf{w}}}=\frac{\sigma_{a}^{2}}{N}\sum_{i=1}^{N}\left\langle\phi_{i,c}^{1}\phi_{i,c}^{2}\right\rangle_{{\mathbf{w}}}
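
A Monte Carlo sanity check of the last line, specializing to the linear activation of §F.1.2 and to non-overlapping windows of size S, for which the inner average reduces to (σw²/S) x̃i¹·x̃i²; the window layout and all numbers below are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)

N, S, C = 5, 4, 20
sigma_a2, sigma_w2 = 1.0, 1.0
draws = 20_000

x1, x2 = rng.standard_normal(N * S), rng.standard_normal(N * S)
win1, win2 = x1.reshape(N, S), x2.reshape(N, S)          # windows x~_i^1, x~_i^2

# Analytic kernel for the linear case: (sigma_a^2 sigma_w^2 / (N S)) sum_i x~_i^1 . x~_i^2
K_analytic = sigma_a2 * sigma_w2 / (N * S) * np.sum(win1 * win2)

# Empirical second moment of f(x^1) f(x^2) over random (a, w)
a = rng.standard_normal((draws, N, C)) * np.sqrt(sigma_a2 / (C * N))
w = rng.standard_normal((draws, C, S)) * np.sqrt(sigma_w2 / S)
phi1 = np.einsum('dcs,is->dic', w, win1)                 # phi_{i,c}^1 = w_c . x~_i^1
phi2 = np.einsum('dcs,is->dic', w, win2)
f1 = np.einsum('dic,dic->d', a, phi1)
f2 = np.einsum('dic,dic->d', a, phi2)

print(K_analytic, np.mean(f1 * f2))   # the two agree up to Monte Carlo error (order 0.01 here)
```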

The fourth moment is

f(𝐱1)f(𝐱2)f(𝐱3)f(𝐱4)𝐚,𝐰=i1:4c1:4ai1,c1ai2,c2ai3,c3ai4,c4𝐚ϕi1,c11ϕi2,c22ϕi3,c33ϕi4,c44𝐰\left\langle f\left({\mathbf{x}}^{1}\right)f\left({\mathbf{x}}^{2}\right)f\left({\mathbf{x}}^{3}\right)f\left({\mathbf{x}}^{4}\right)\right\rangle_{\bf{a},{\mathbf{w}}}=\sum_{i_{1:4}}\sum_{c_{1:4}}\left\langle a_{i_{1},c_{1}}a_{i_{2},c_{2}}a_{i_{3},c_{3}}a_{i_{4},c_{4}}\right\rangle_{\bf{a}}\left\langle\phi_{i_{1},c_{1}}^{1}\phi_{i_{2},c_{2}}^{2}\phi_{i_{3},c_{3}}^{3}\phi_{i_{4},c_{4}}^{4}\right\rangle_{{\mathbf{w}}} (F.3)

Averaging over the last layer weights gives

ai1,c1ai2,c2ai3,c3ai4,c4𝐚=(σa2CN)2(δi1i2δc1c2δi3i4δc3c4+{(13)(24)+(14)(23)})\displaystyle\left\langle a_{i_{1},c_{1}}a_{i_{2},c_{2}}a_{i_{3},c_{3}}a_{i_{4},c_{4}}\right\rangle_{\bf{a}}=\left(\frac{\sigma_{a}^{2}}{CN}\right)^{2}\left(\delta_{i_{1}i_{2}}\delta_{c_{1}c_{2}}\delta_{i_{3}i_{4}}\delta_{c_{3}c_{4}}+\left\{\left(13\right)\left(24\right)+\left(14\right)\left(23\right)\right\}\right) (F.4)

Thus, the average over the readout weights always pairs the four ϕ\phi's into two pairs, each pair sharing the same i,ci,c indices. Notice that, regardless of the input indices, for different channels ccc\neq c^{\prime} we have

ϕi,cμϕi,cνϕj,cμϕj,cν𝐰=ϕi,cμϕi,cν𝐰ϕj,cμϕj,cν𝐰\displaystyle\left\langle\phi_{i,c}^{\mu}\phi_{i,c}^{\nu}\phi_{j,c^{\prime}}^{\mu^{\prime}}\phi_{j,c^{\prime}}^{\nu^{\prime}}\right\rangle_{{\mathbf{w}}}=\left\langle\phi_{i,c}^{\mu}\phi_{i,c}^{\nu}\right\rangle_{{\mathbf{w}}}\left\langle\phi_{j,c^{\prime}}^{\mu^{\prime}}\phi_{j,c^{\prime}}^{\nu^{\prime}}\right\rangle_{{\mathbf{w}}} (F.5)

so, e.g. the first term out of three is

i1:4c1:4δi1i2δc1c2δi3i4δc3c4ϕi1,c11ϕi2,c22ϕi3,c33ϕi4,c44𝐰\displaystyle\sum_{i_{1:4}}\sum_{c_{1:4}}\delta_{i_{1}i_{2}}\delta_{c_{1}c_{2}}\delta_{i_{3}i_{4}}\delta_{c_{3}c_{4}}\left\langle\phi_{i_{1},c_{1}}^{1}\phi_{i_{2},c_{2}}^{2}\phi_{i_{3},c_{3}}^{3}\phi_{i_{4},c_{4}}^{4}\right\rangle_{{\mathbf{w}}} (F.6)
=i1,i3c1,c3ϕi1,c11ϕi1,c12ϕi3,c33ϕi3,c34𝐰\displaystyle=\sum_{i_{1},i_{3}}\sum_{c_{1},c_{3}}\left\langle\phi_{i_{1},c_{1}}^{1}\phi_{i_{1},c_{1}}^{2}\phi_{i_{3},c_{3}}^{3}\phi_{i_{3},c_{3}}^{4}\right\rangle_{{\mathbf{w}}}
=i1,i3{cϕi1,c1ϕi1,c2ϕi3,c3ϕi3,c4𝐰+c1,c3c1c3ϕi1,c11ϕi1,c12𝐰ϕi3,c33ϕi3,c34𝐰}\displaystyle=\sum_{i_{1},i_{3}}\left\{\sum_{c}\left\langle\phi_{i_{1},c}^{1}\phi_{i_{1},c}^{2}\phi_{i_{3},c}^{3}\phi_{i_{3},c}^{4}\right\rangle_{{\mathbf{w}}}+\sum_{\begin{array}[]{c}c_{1},c_{3}\\ c_{1}\neq c_{3}\end{array}}\left\langle\phi_{i_{1},c_{1}}^{1}\phi_{i_{1},c_{1}}^{2}\right\rangle_{{\mathbf{w}}}\left\langle\phi_{i_{3},c_{3}}^{3}\phi_{i_{3},c_{3}}^{4}\right\rangle_{{\mathbf{w}}}\right\} (F.9)

where in the last line we separated the diagonal and off-diagonal terms in the channel indices. So

(σa2CN)2f(𝐱1)f(𝐱2)f(𝐱3)f(𝐱4)𝐚,𝐰\displaystyle\left(\frac{\sigma_{a}^{2}}{CN}\right)^{-2}\left\langle f\left({\mathbf{x}}^{1}\right)f\left({\mathbf{x}}^{2}\right)f\left({\mathbf{x}}^{3}\right)f\left({\mathbf{x}}^{4}\right)\right\rangle_{\bf{a},{\mathbf{w}}} (F.10)
=i1,i2c{ϕi1,c1ϕi1,c2ϕi2,c3ϕi2,c4𝐰+ϕi1,c1ϕi1,c3ϕi2,c2ϕi2,c4𝐰+ϕi1,c1ϕi1,c4ϕi2,c2ϕi2,c3𝐰}\displaystyle=\sum_{i_{1},i_{2}}\sum_{c}\left\{\left\langle\phi_{i_{1},c}^{1}\phi_{i_{1},c}^{2}\phi_{i_{2},c}^{3}\phi_{i_{2},c}^{4}\right\rangle_{{\mathbf{w}}}+\left\langle\phi_{i_{1},c}^{1}\phi_{i_{1},c}^{3}\phi_{i_{2},c}^{2}\phi_{i_{2},c}^{4}\right\rangle_{{\mathbf{w}}}+\left\langle\phi_{i_{1},c}^{1}\phi_{i_{1},c}^{4}\phi_{i_{2},c}^{2}\phi_{i_{2},c}^{3}\right\rangle_{{\mathbf{w}}}\right\}
+i1,i2c1,c2c1c2{ϕi1,c11ϕi1,c12𝐰ϕi2,c23ϕi2,c24𝐰+ϕi1,c11ϕi1,c13𝐰ϕi2,c22ϕi2,c24𝐰+ϕi1,c11ϕi1,c14𝐰ϕi2,c22ϕi2,c23𝐰}\displaystyle+\sum_{i_{1},i_{2}}\sum_{\begin{array}[]{c}c_{1},c_{2}\\ c_{1}\neq c_{2}\end{array}}\left\{\left\langle\phi_{i_{1},c_{1}}^{1}\phi_{i_{1},c_{1}}^{2}\right\rangle_{{\mathbf{w}}}\left\langle\phi_{i_{2},c_{2}}^{3}\phi_{i_{2},c_{2}}^{4}\right\rangle_{{\mathbf{w}}}+\left\langle\phi_{i_{1},c_{1}}^{1}\phi_{i_{1},c_{1}}^{3}\right\rangle_{{\mathbf{w}}}\left\langle\phi_{i_{2},c_{2}}^{2}\phi_{i_{2},c_{2}}^{4}\right\rangle_{{\mathbf{w}}}+\left\langle\phi_{i_{1},c_{1}}^{1}\phi_{i_{1},c_{1}}^{4}\right\rangle_{{\mathbf{w}}}\left\langle\phi_{i_{2},c_{2}}^{2}\phi_{i_{2},c_{2}}^{3}\right\rangle_{{\mathbf{w}}}\right\} (F.13)

On the other hand

f1f2f3f4\displaystyle\left\langle f^{1}f^{2}\right\rangle\left\langle f^{3}f^{4}\right\rangle (F.14)
=(σa2CNi=1Nc=1Cϕi,c1ϕi,c2𝐰)(σa2CNi=1Nc=1Cϕi,c3ϕi,c4𝐰)\displaystyle=\left(\frac{\sigma_{a}^{2}}{CN}\sum_{i=1}^{N}\sum_{c=1}^{C}\left\langle\phi_{i,c}^{1}\phi_{i,c}^{2}\right\rangle_{{\mathbf{w}}}\right)\left(\frac{\sigma_{a}^{2}}{CN}\sum_{i^{\prime}=1}^{N}\sum_{c^{\prime}=1}^{C}\left\langle\phi_{i^{\prime},c^{\prime}}^{3}\phi_{i^{\prime},c^{\prime}}^{4}\right\rangle_{{\mathbf{w}}}\right)
=(σa2CN)2i,i=1Nc,c=1Cϕi,c1ϕi,c2𝐰ϕi,c3ϕi,c4𝐰\displaystyle=\left(\frac{\sigma_{a}^{2}}{CN}\right)^{2}\sum_{i,i^{\prime}=1}^{N}\sum_{c,c^{\prime}=1}^{C}\left\langle\phi_{i,c}^{1}\phi_{i,c}^{2}\right\rangle_{{\mathbf{w}}}\left\langle\phi_{i^{\prime},c^{\prime}}^{3}\phi_{i^{\prime},c^{\prime}}^{4}\right\rangle_{{\mathbf{w}}}
=(σa2CN)2i,i=1N{c,c=1ccCϕi,c1ϕi,c2𝐰ϕi,c3ϕi,c4𝐰+c=1Cϕi,c1ϕi,c2𝐰ϕi,c3ϕi,c4𝐰}\displaystyle=\left(\frac{\sigma_{a}^{2}}{CN}\right)^{2}\sum_{i,i^{\prime}=1}^{N}\left\{\sum_{\begin{array}[]{c}c,c^{\prime}=1\\ c\neq c^{\prime}\end{array}}^{C}\left\langle\phi_{i,c}^{1}\phi_{i,c}^{2}\right\rangle_{{\mathbf{w}}}\left\langle\phi_{i^{\prime},c^{\prime}}^{3}\phi_{i^{\prime},c^{\prime}}^{4}\right\rangle_{{\mathbf{w}}}+\sum_{c=1}^{C}\left\langle\phi_{i,c}^{1}\phi_{i,c}^{2}\right\rangle_{{\mathbf{w}}}\left\langle\phi_{i^{\prime},c}^{3}\phi_{i^{\prime},c}^{4}\right\rangle_{{\mathbf{w}}}\right\} (F.17)

Putting it all together, the off-diagonal terms in the channel indices cancel and we are left with

(σa2CN)2κ4(𝐱1,𝐱2,𝐱3,𝐱4)\displaystyle\left(\frac{\sigma_{a}^{2}}{CN}\right)^{-2}\kappa_{4}\left({\mathbf{x}}_{1},{\mathbf{x}}_{2},{\mathbf{x}}_{3},{\mathbf{x}}_{4}\right) (F.18)
=(σa2CN)2(f1f2f3f4(f1f2f3f4+f1f3f2f4+f1f4f2f3))\displaystyle=\left(\frac{\sigma_{a}^{2}}{CN}\right)^{-2}\left(\left\langle f^{1}f^{2}f^{3}f^{4}\right\rangle-\left(\left\langle f^{1}f^{2}\right\rangle\left\langle f^{3}f^{4}\right\rangle+\left\langle f^{1}f^{3}\right\rangle\left\langle f^{2}f^{4}\right\rangle+\left\langle f^{1}f^{4}\right\rangle\left\langle f^{2}f^{3}\right\rangle\right)\right)
=i1,i2c{ϕi1,c1ϕi1,c2ϕi2,c3ϕi2,c4𝐰+ϕi1,c1ϕi1,c3ϕi2,c2ϕi2,c4𝐰+ϕi1,c1ϕi1,c4ϕi2,c2ϕi2,c3𝐰}\displaystyle=\sum_{i_{1},i_{2}}\sum_{c}\left\{\left\langle\phi_{i_{1},c}^{1}\phi_{i_{1},c}^{2}\phi_{i_{2},c}^{3}\phi_{i_{2},c}^{4}\right\rangle_{{\mathbf{w}}}+\left\langle\phi_{i_{1},c}^{1}\phi_{i_{1},c}^{3}\phi_{i_{2},c}^{2}\phi_{i_{2},c}^{4}\right\rangle_{{\mathbf{w}}}+\left\langle\phi_{i_{1},c}^{1}\phi_{i_{1},c}^{4}\phi_{i_{2},c}^{2}\phi_{i_{2},c}^{3}\right\rangle_{{\mathbf{w}}}\right\}
i1,i2c{ϕi1,c1ϕi1,c2𝐰ϕi2,c3ϕi2,c4𝐰+ϕi1,c1ϕi1,c3𝐰ϕi2,c2ϕi2,c4𝐰+ϕi1,c1ϕi1,c4𝐰ϕi2,c2ϕi2,c3𝐰}\displaystyle-\sum_{i_{1},i_{2}}\sum_{c}\left\{\left\langle\phi_{i_{1},c}^{1}\phi_{i_{1},c}^{2}\right\rangle_{{\mathbf{w}}}\left\langle\phi_{i_{2},c}^{3}\phi_{i_{2},c}^{4}\right\rangle_{{\mathbf{w}}}+\left\langle\phi_{i_{1},c}^{1}\phi_{i_{1},c}^{3}\right\rangle_{{\mathbf{w}}}\left\langle\phi_{i_{2},c}^{2}\phi_{i_{2},c}^{4}\right\rangle_{{\mathbf{w}}}+\left\langle\phi_{i_{1},c}^{1}\phi_{i_{1},c}^{4}\right\rangle_{{\mathbf{w}}}\left\langle\phi_{i_{2},c}^{2}\phi_{i_{2},c}^{3}\right\rangle_{{\mathbf{w}}}\right\}
:=i,j=1Nc=1C{ϕi,c1ϕi,c2ϕj,c3ϕj,c4𝐰ϕi,c1ϕi,c2𝐰ϕj,c3ϕj,c4𝐰}+[(1i3i)(2j4j)+(1i4i)(2j3j)]\displaystyle:=\sum_{i,j=1}^{N}\sum_{c=1}^{C}\left\{\left\langle\phi_{i,c}^{1}\phi_{i,c}^{2}\phi_{j,c}^{3}\phi_{j,c}^{4}\right\rangle_{{\mathbf{w}}}-\left\langle\phi_{i,c}^{1}\phi_{i,c}^{2}\right\rangle_{{\mathbf{w}}}\left\langle\phi_{j,c}^{3}\phi_{j,c}^{4}\right\rangle_{{\mathbf{w}}}\right\}+\left[\left(1_{i}3_{i}\right)\left(2_{j}4_{j}\right)+\left(1_{i}4_{i}\right)\left(2_{j}3_{j}\right)\right]

where in the last line we introduced a short-hand notation to compactly keep track of the combinations of the indices.

F.1.2 Fourth cumulant for linear CNN

Here, ϕi,cμ:=𝐰c𝐱~iμ=s=1Sws(c)x~s(μ,i)\phi_{i,c}^{\mu}:={\mathbf{w}}_{c}\cdot\tilde{{\mathbf{x}}}_{i}^{\mu}=\sum_{s=1}^{S}w_{s}^{\left(c\right)}\tilde{x}_{s}^{\left(\mu,i\right)}. The fourth moment is

ϕi,c1ϕi,c2ϕj,c3ϕj,c4𝐰\displaystyle\left\langle\phi_{i,c}^{1}\phi_{i,c}^{2}\phi_{j,c}^{3}\phi_{j,c}^{4}\right\rangle_{{\mathbf{w}}} (F.19)
=s1:4=1S(ws1(c)x~s1(1,i))(ws2(c)x~s2(2,i))(ws3(c)x~s3(3,j))(ws4(c)x~s4(4,j))𝐰\displaystyle=\sum_{s_{1:4}=1}^{S}\left\langle\left(w_{s_{1}}^{\left(c\right)}\tilde{x}_{s_{1}}^{\left(1,i\right)}\right)\left(w_{s_{2}}^{\left(c\right)}\tilde{x}_{s_{2}}^{\left(2,i\right)}\right)\left(w_{s_{3}}^{\left(c\right)}\tilde{x}_{s_{3}}^{\left(3,j\right)}\right)\left(w_{s_{4}}^{\left(c\right)}\tilde{x}_{s_{4}}^{\left(4,j\right)}\right)\right\rangle_{{\mathbf{w}}}
=s1:4=1Sws1(c)ws2(c)ws3(c)ws4(c)𝐰(σw2/S)2δs1s2δs3s4[3]x~s1(1,i)x~s2(2,i)x~s3(3,j)x~s4(4,j)\displaystyle=\sum_{s_{1:4}=1}^{S}\underbrace{\left\langle w_{s_{1}}^{\left(c\right)}w_{s_{2}}^{\left(c\right)}w_{s_{3}}^{\left(c\right)}w_{s_{4}}^{\left(c\right)}\right\rangle_{{\mathbf{w}}}}_{\left(\sigma_{w}^{2}/S\right)^{2}\cdot\delta_{s_{1}s_{2}}\delta_{s_{3}s_{4}}\left[3\right]}\tilde{x}_{s_{1}}^{\left(1,i\right)}\tilde{x}_{s_{2}}^{\left(2,i\right)}\tilde{x}_{s_{3}}^{\left(3,j\right)}\tilde{x}_{s_{4}}^{\left(4,j\right)}
=(σw2S)2s1:4=1S(δs1s2δs3s4+δs1s3δs2s4+δs1s4δs2s3)x~s1(1,i)x~s2(2,i)x~s3(3,j)x~s4(4,j)\displaystyle=\left(\frac{\sigma_{w}^{2}}{S}\right)^{2}\sum_{s_{1:4}=1}^{S}\left(\delta_{s_{1}s_{2}}\delta_{s_{3}s_{4}}+\delta_{s_{1}s_{3}}\delta_{s_{2}s_{4}}+\delta_{s_{1}s_{4}}\delta_{s_{2}s_{3}}\right)\tilde{x}_{s_{1}}^{\left(1,i\right)}\tilde{x}_{s_{2}}^{\left(2,i\right)}\tilde{x}_{s_{3}}^{\left(3,j\right)}\tilde{x}_{s_{4}}^{\left(4,j\right)}
=(σw2S)2{(𝐱~i1𝐱~i2)(𝐱~j3𝐱~j4)+(𝐱~i1𝐱~j3)(𝐱~i2𝐱~j4)+(𝐱~i1𝐱~j4)(𝐱~i2𝐱~j3)}\displaystyle=\left(\frac{\sigma_{w}^{2}}{S}\right)^{2}\left\{\left(\tilde{{\mathbf{x}}}_{i}^{1}\cdot\tilde{{\mathbf{x}}}_{i}^{2}\right)\left(\tilde{{\mathbf{x}}}_{j}^{3}\cdot\tilde{{\mathbf{x}}}_{j}^{4}\right)+\left(\tilde{{\mathbf{x}}}_{i}^{1}\cdot\tilde{{\mathbf{x}}}_{j}^{3}\right)\left(\tilde{{\mathbf{x}}}_{i}^{2}\cdot\tilde{{\mathbf{x}}}_{j}^{4}\right)+\left(\tilde{{\mathbf{x}}}_{i}^{1}\cdot\tilde{{\mathbf{x}}}_{j}^{4}\right)\left(\tilde{{\mathbf{x}}}_{i}^{2}\cdot\tilde{{\mathbf{x}}}_{j}^{3}\right)\right\}
:=(σw2S)2{(1i2i)(3j4j)+(1i3j)(2i4j)+(1i4j)(2i3j)}\displaystyle:=\left(\frac{\sigma_{w}^{2}}{S}\right)^{2}\left\{\left(1_{i}2_{i}\right)\left(3_{j}4_{j}\right)+\left(1_{i}3_{j}\right)\left(2_{i}4_{j}\right)+\left(1_{i}4_{j}\right)\left(2_{i}3_{j}\right)\right\}

Similarly

(σw2S)2ϕi,c1ϕi,c3ϕj,c2ϕj,c4𝐰\displaystyle\left(\frac{\sigma_{w}^{2}}{S}\right)^{-2}\left\langle\phi_{i,c}^{1}\phi_{i,c}^{3}\phi_{j,c}^{2}\phi_{j,c}^{4}\right\rangle_{{\mathbf{w}}} =(1i3i)(2j4j)+(1i2j)(3i4j)+(1i4j)(3i2j)\displaystyle=\left(1_{i}3_{i}\right)\left(2_{j}4_{j}\right)+\left(1_{i}2_{j}\right)\left(3_{i}4_{j}\right)+\left(1_{i}4_{j}\right)\left(3_{i}2_{j}\right) (F.20)
(σw2S)2ϕi,c1ϕi,c4ϕj,c2ϕj,c3𝐰\displaystyle\left(\frac{\sigma_{w}^{2}}{S}\right)^{-2}\left\langle\phi_{i,c}^{1}\phi_{i,c}^{4}\phi_{j,c}^{2}\phi_{j,c}^{3}\right\rangle_{{\mathbf{w}}} =(1i4i)(3j2j)+(1i3j)(4i2j)+(1i2j)(4i3j)\displaystyle=\left(1_{i}4_{i}\right)\left(3_{j}2_{j}\right)+\left(1_{i}3_{j}\right)\left(4_{i}2_{j}\right)+\left(1_{i}2_{j}\right)\left(4_{i}3_{j}\right)

Notice that the 2nd and 3rd terms have (ij)(ij)\left(ij\right)\left(ij\right) while the first term has (ii)(jj)\left(ii\right)\left(jj\right). The latter will cancel out with the ϕi,cμϕi,cν𝐰ϕj,cμϕj,cν𝐰\left\langle\phi_{i,c}^{\mu}\phi_{i,c}^{\nu}\right\rangle_{{\mathbf{w}}}\left\langle\phi_{j,c}^{\mu^{\prime}}\phi_{j,c}^{\nu^{\prime}}\right\rangle_{{\mathbf{w}}} terms. Thus

[(1i2i)(3j4j)+(1i3j)(2i4j)+(1i4j)(2i3j)]\displaystyle\left[\cancel{\left(1_{i}2_{i}\right)\left(3_{j}4_{j}\right)}+\left(1_{i}3_{j}\right)\left(2_{i}4_{j}\right)+\left(1_{i}4_{j}\right)\left(2_{i}3_{j}\right)\right] (F.21)
+[(1i3i)(2j4j)+(1i2j)(3i4j)+(1i4j)(3i2j)]\displaystyle+\left[\cancel{\left(1_{i}3_{i}\right)\left(2_{j}4_{j}\right)}+\left(1_{i}2_{j}\right)\left(3_{i}4_{j}\right)+\left(1_{i}4_{j}\right)\left(3_{i}2_{j}\right)\right]
+[(1i4i)(3j2j)+(1i3j)(4i2j)+(1i2j)(4i3j)]\displaystyle+\left[\cancel{\left(1_{i}4_{i}\right)\left(3_{j}2_{j}\right)}+\left(1_{i}3_{j}\right)\left(4_{i}2_{j}\right)+\left(1_{i}2_{j}\right)\left(4_{i}3_{j}\right)\right]
[(1i2i)(3j4j)+(1i3i)(2j4j)+(1i4i)(3j2j)]\displaystyle-\left[\cancel{\left(1_{i}2_{i}\right)\left(3_{j}4_{j}\right)}+\cancel{\left(1_{i}3_{i}\right)\left(2_{j}4_{j}\right)}+\cancel{\left(1_{i}4_{i}\right)\left(3_{j}2_{j}\right)}\right]
=(1i3j)(2i4j)+(1i4j)(2i3j)+(1i2j)(3i4j)+(1i4j)(3i2j)+(1i3j)(4i2j)+(1i2j)(4i3j)\displaystyle=\left(1_{i}3_{j}\right)\left(2_{i}4_{j}\right)+\left(1_{i}4_{j}\right)\left(2_{i}3_{j}\right)+\left(1_{i}2_{j}\right)\left(3_{i}4_{j}\right)+\left(1_{i}4_{j}\right)\left(3_{i}2_{j}\right)+\left(1_{i}3_{j}\right)\left(4_{i}2_{j}\right)+\left(1_{i}2_{j}\right)\left(4_{i}3_{j}\right)
=(1i3j)[(2i4j)+(4i2j)]+(1i4j)[(2i3j)+(3i2j)]+(1i2j)[(3i4j)+(4i3j)]\displaystyle=\left(1_{i}3_{j}\right)\left[\left(2_{i}4_{j}\right)+\left(4_{i}2_{j}\right)\right]+\left(1_{i}4_{j}\right)\left[\left(2_{i}3_{j}\right)+\left(3_{i}2_{j}\right)\right]+\left(1_{i}2_{j}\right)\left[\left(3_{i}4_{j}\right)+\left(4_{i}3_{j}\right)\right]

Denote $\lambda:=\frac{\sigma_{a}^{2}}{N}\frac{\sigma_{w}^{2}}{S}$. The fourth cumulant is

κ4(𝐱1,𝐱2,𝐱3,𝐱4)\displaystyle\kappa_{4}\left({\mathbf{x}}_{1},{\mathbf{x}}_{2},{\mathbf{x}}_{3},{\mathbf{x}}_{4}\right) (F.22)
=λ2Ci,j=1N{(1i2j)[(3i4j)+(4i3j)]+(1i3j)[(2i4j)+(4i2j)]+(1i4j)[(2i3j)+(3i2j)]}\displaystyle=\frac{\lambda^{2}}{C}\sum_{i,j=1}^{N}\left\{\left(1_{i}2_{j}\right)\left[\left(3_{i}4_{j}\right)+\left(4_{i}3_{j}\right)\right]+\left(1_{i}3_{j}\right)\left[\left(2_{i}4_{j}\right)+\left(4_{i}2_{j}\right)\right]+\left(1_{i}4_{j}\right)\left[\left(2_{i}3_{j}\right)+\left(3_{i}2_{j}\right)\right]\right\}

Notice that all terms involve inner products between $\tilde{{\mathbf{x}}}$'s with different indices $i,j$, i.e. mixing different convolutional windows. This means that $\kappa_{4}$, and also all higher order cumulants, cannot be written in terms of the linear kernel, which does not mix different conv-window indices. This is in contrast to the kernel (second cumulant) of this linear CNN, which is identical to that of a corresponding linear fully connected network (FCN): $K\left({\mathbf{x}},{\mathbf{x}}^{\prime}\right)=\frac{\sigma_{a}^{2}\sigma_{w}^{2}}{NS}{\mathbf{x}}^{\mathsf{T}}{\mathbf{x}}^{\prime}$. It is also in contrast to the higher cumulants of the corresponding linear FCN, which can all be expressed in terms of products of the linear kernel.
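As a sanity check of Eq. F.22, the following short script (a minimal sketch of ours, not part of the original derivation) estimates the fourth cumulant by Monte Carlo over the prior weights, assuming $a_{i,c}\sim\mathcal{N}(0,\sigma_{a}^{2}/(CN))$ and $w_{s}^{(c)}\sim\mathcal{N}(0,\sigma_{w}^{2}/S)$, as implied by the second-moment normalization above, and compares it with the analytical expression:

# Monte Carlo check of the fourth cumulant of a two-layer linear CNN (Eq. F.22).
# The weight variances below are assumptions consistent with the kernel normalization above.
import numpy as np

rng = np.random.default_rng(0)
N, S, C = 3, 4, 2                        # conv windows, window size, channels (kept small on purpose)
sigma_a2, sigma_w2 = 1.0, 1.0
lam = sigma_a2 * sigma_w2 / (N * S)

X = rng.standard_normal((4, N, S))       # four fixed inputs, each split into N windows of size S

def kappa4_analytic(X):
    dot = lambda mu, i, nu, j: X[mu, i] @ X[nu, j]          # the shorthand (mu_i nu_j)
    total = 0.0
    for i in range(N):
        for j in range(N):
            total += dot(0, i, 1, j) * (dot(2, i, 3, j) + dot(3, i, 2, j))
            total += dot(0, i, 2, j) * (dot(1, i, 3, j) + dot(3, i, 1, j))
            total += dot(0, i, 3, j) * (dot(1, i, 2, j) + dot(2, i, 1, j))
    return lam ** 2 / C * total

def kappa4_monte_carlo(X, n_samples=200_000):
    # f^mu = sum_{i,c} a_{i,c} (w_c . x_i^mu), sampled from the prior
    w = rng.normal(0.0, np.sqrt(sigma_w2 / S), size=(n_samples, C, S))
    a = rng.normal(0.0, np.sqrt(sigma_a2 / (C * N)), size=(n_samples, N, C))
    phi = np.einsum('kcs,mis->kmic', w, X)                   # pre-activations per (sample, mu, i, c)
    f = np.einsum('kic,kmic->km', a, phi)                    # outputs per (sample, mu)
    m4 = np.mean(f[:, 0] * f[:, 1] * f[:, 2] * f[:, 3])
    cov = lambda p, q: np.mean(f[:, p] * f[:, q])
    return m4 - (cov(0, 1) * cov(2, 3) + cov(0, 2) * cov(1, 3) + cov(0, 3) * cov(1, 2))

print(kappa4_analytic(X), kappa4_monte_carlo(X))

The two numbers agree up to Monte Carlo noise, which shrinks as n_samples is increased.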

F.2 Sixth cumulant and above

The even moments in terms of cumulants for a vector valued RV with zero odd moments and cumulants are (see [29]):

κμ1μ2\displaystyle\kappa^{\mu_{1}\mu_{2}} =κμ1,μ2\displaystyle=\kappa^{\mu_{1},\mu_{2}} (F.23)
κμ1μ2μ3μ4\displaystyle\kappa^{\mu_{1}\mu_{2}\mu_{3}\mu_{4}} =κμ1,μ2,μ3,μ4+κμ1,μ2κμ3,μ4[3]\displaystyle=\kappa^{\mu_{1},\mu_{2},\mu_{3},\mu_{4}}+\kappa^{\mu_{1},\mu_{2}}\kappa^{\mu_{3},\mu_{4}}\left[3\right]
κμ1μ2μ3μ4μ5μ6\displaystyle\kappa^{\mu_{1}\mu_{2}\mu_{3}\mu_{4}\mu_{5}\mu_{6}} =κμ1,μ2,μ3,μ4,μ5,μ6+κμ1,μ2,μ3,μ4κμ5,μ6[15]+κμ1,μ2κμ3,μ4κμ5,μ6[15]\displaystyle=\kappa^{\mu_{1},\mu_{2},\mu_{3},\mu_{4},\mu_{5},\mu_{6}}+\kappa^{\mu_{1},\mu_{2},\mu_{3},\mu_{4}}\kappa^{\mu_{5},\mu_{6}}\left[15\right]+\kappa^{\mu_{1},\mu_{2}}\kappa^{\mu_{3},\mu_{4}}\kappa^{\mu_{5},\mu_{6}}\left[15\right]

where the moments are on the l.h.s. (indices with no commas) and the cumulants are on the r.h.s. (indices are separated with commas). Thus, the sixth cumulant is

κμ1,μ2,μ3,μ4,μ5,μ6=κμ1μ2μ3μ4μ5μ6κμ1,μ2,μ3,μ4κμ5,μ6[15]κμ1,μ2κμ3,μ4κμ5,μ6[15]\displaystyle\kappa^{\mu_{1},\mu_{2},\mu_{3},\mu_{4},\mu_{5},\mu_{6}}=\kappa^{\mu_{1}\mu_{2}\mu_{3}\mu_{4}\mu_{5}\mu_{6}}-\kappa^{\mu_{1},\mu_{2},\mu_{3},\mu_{4}}\kappa^{\mu_{5},\mu_{6}}\left[15\right]-\kappa^{\mu_{1},\mu_{2}}\kappa^{\mu_{3},\mu_{4}}\kappa^{\mu_{5},\mu_{6}}\left[15\right] (F.24)

In the linear case, the analogue of $\kappa^{\mu_{1},\mu_{2}}\kappa^{\mu_{3},\mu_{4}}\kappa^{\mu_{5},\mu_{6}}$ is ($15$ such pairings, where only the numbers "move", not the $i,j,k$)

1λ3K(𝐱1,𝐱2)K(𝐱3,𝐱4)K(𝐱5,𝐱6)=(1i2i)(3j4j)(5k6k)\displaystyle\frac{1}{\lambda^{3}}K\left({\mathbf{x}}_{1},{\mathbf{x}}_{2}\right)K\left({\mathbf{x}}_{3},{\mathbf{x}}_{4}\right)K\left({\mathbf{x}}_{5},{\mathbf{x}}_{6}\right)={\color[rgb]{0,0,1}\left(1_{i}2_{i}\right)\left(3_{j}4_{j}\right)\left(5_{k}6_{k}\right)} (F.25)

and the analogue of $\kappa^{\mu_{1},\mu_{2},\mu_{3},\mu_{4}}\kappa^{\mu_{5},\mu_{6}}$ is

Cλ3κ4(𝐱1,𝐱2,𝐱3,𝐱4)K(𝐱5,𝐱6)\displaystyle\frac{C}{\lambda^{3}}\kappa_{4}\left({\mathbf{x}}_{1},{\mathbf{x}}_{2},{\mathbf{x}}_{3},{\mathbf{x}}_{4}\right)K\left({\mathbf{x}}_{5},{\mathbf{x}}_{6}\right) (F.26)
={(1i2j)[(3i4j)+(4i3j)]+(1i3j)[(2i4j)+(4i2j)]+(1i4j)[(2i3j)+(3i2j)]}(5k6k)\displaystyle=\left\{\left(1_{i}2_{j}\right)\left[\left(3_{i}4_{j}\right)+\left(4_{i}3_{j}\right)\right]+\left(1_{i}3_{j}\right)\left[\left(2_{i}4_{j}\right)+\left(4_{i}2_{j}\right)\right]+\left(1_{i}4_{j}\right)\left[\left(2_{i}3_{j}\right)+\left(3_{i}2_{j}\right)\right]\right\}\left(5_{k}6_{k}\right)
={(1i2j)(3i4j)+(1i2j)(4i3j)+(1i3j)(2i4j)+(1i3j)(4i2j)+(1i4j)(2i3j)+(1i4j)(3i2j)}(5k6k)\displaystyle=\left\{\left(1_{i}2_{j}\right)\left(3_{i}4_{j}\right)+\left(1_{i}2_{j}\right)\left(4_{i}3_{j}\right)+{\color[rgb]{1,0,0}\left(1_{i}3_{j}\right)\left(2_{i}4_{j}\right)}+\left(1_{i}3_{j}\right)\left(4_{i}2_{j}\right)+{\color[rgb]{1,0,0}\left(1_{i}4_{j}\right)\left(2_{i}3_{j}\right)}+\left(1_{i}4_{j}\right)\left(3_{i}2_{j}\right)\right\}{\color[rgb]{1,0,0}\left(5_{k}6_{k}\right)}

The 6th moment of the linear CNN, obtained by Wick contractions of the $\phi$'s as above, is

ϕi,c1ϕi,c2ϕj,c3ϕj,c4ϕk,c5ϕk,c6\displaystyle\left\langle\phi_{i,c}^{1}\phi_{i,c}^{2}\phi_{j,c}^{3}\phi_{j,c}^{4}\phi_{k,c}^{5}\phi_{k,c}^{6}\right\rangle (F.27)
=(1i2i)(3j4j)(5k6k)+(1i3j)(2i4j)(5k6k)+(1i4j)(2i3j)(5k6k)\displaystyle={\color[rgb]{0,0,1}\left(1_{i}2_{i}\right)\left(3_{j}4_{j}\right)\left(5_{k}6_{k}\right)}+{\color[rgb]{1,0,0}\left(1_{i}3_{j}\right)\left(2_{i}4_{j}\right)\left(5_{k}6_{k}\right)}+{\color[rgb]{1,0,0}\left(1_{i}4_{j}\right)\left(2_{i}3_{j}\right)\left(5_{k}6_{k}\right)}
+(1i2i)(3j5k)(4j6k)+(1i3j)(2i5k)(4j6k)+(1i5k)(2i3j)(4j6k)\displaystyle+{\color[rgb]{1,0,0}\left(1_{i}2_{i}\right)\left(3_{j}5_{k}\right)\left(4_{j}6_{k}\right)}+\left(1_{i}3_{j}\right)\left(2_{i}5_{k}\right)\left(4_{j}6_{k}\right)+\left(1_{i}5_{k}\right)\left(2_{i}3_{j}\right)\left(4_{j}6_{k}\right)
+(1i2i)(4j5k)(3j6k)+(1i5k)(2i4j)(3j6k)+(1i4j)(2i5k)(3j6k)\displaystyle+{\color[rgb]{1,0,0}\left(1_{i}2_{i}\right)\left(4_{j}5_{k}\right)\left(3_{j}6_{k}\right)}+\left(1_{i}5_{k}\right)\left(2_{i}4_{j}\right)\left(3_{j}6_{k}\right)+\left(1_{i}4_{j}\right)\left(2_{i}5_{k}\right)\left(3_{j}6_{k}\right)
+(1i5k)(3j4j)(2i6k)+(1i3j)(4j5k)(2i6k)+(1i4j)(3j5k)(2i6k)\displaystyle+{\color[rgb]{1,0,0}\left(1_{i}5_{k}\right)\left(3_{j}4_{j}\right)\left(2_{i}6_{k}\right)}+\left(1_{i}3_{j}\right)\left(4_{j}5_{k}\right)\left(2_{i}6_{k}\right)+\left(1_{i}4_{j}\right)\left(3_{j}5_{k}\right)\left(2_{i}6_{k}\right)
+(2i5k)(3j4j)(1i6k)+(3j5k)(2i4j)(1i6k)+(4j5k)(2i3j)(1i6k)\displaystyle+{\color[rgb]{1,0,0}\left(2_{i}5_{k}\right)\left(3_{j}4_{j}\right)\left(1_{i}6_{k}\right)}+\left(3_{j}5_{k}\right)\left(2_{i}4_{j}\right)\left(1_{i}6_{k}\right)+\left(4_{j}5_{k}\right)\left(2_{i}3_{j}\right)\left(1_{i}6_{k}\right)

Notice that for every blue term we have exactly $6$ red terms, so all of the colored terms exactly cancel and only the uncolored terms survive. There are $8$ such uncolored terms for each of the $15$ pairings, so we ultimately have $120$ such terms; the sixth cumulant is thus

κ6(𝐱1,,𝐱6)=λ3C2i,j,k=1N(ij)(ik)(jk)[120]\kappa_{6}\left({\mathbf{x}}_{1},\dots,{\mathbf{x}}_{6}\right)=\frac{\lambda^{3}}{C^{2}}\sum_{i,j,k=1}^{N}\left(\bullet_{i}\bullet_{j}\right)\left(\bullet_{i}\bullet_{k}\right)\left(\bullet_{j}\bullet_{k}\right)\left[120\right] (F.28)

where the [120]\left[120\right] stands for the number of ways to pair the numbers {1,,6}\left\{1,...,6\right\} into the form (ij)(ik)(jk)\left(\bullet_{i}\bullet_{j}\right)\left(\bullet_{i}\bullet_{k}\right)\left(\bullet_{j}\bullet_{k}\right).

We can thus identify a pattern which we conjecture to hold for any even cumulant of arbitrary order 2m2m:

κ2m(𝐱1,,𝐱2m)=λmCm1i1,,im=1N(i1,i2)(,im2)im1(,im1)im[(2m1)!]\kappa_{2m}\left({\mathbf{x}}_{1},\dots,{\mathbf{x}}_{2m}\right)=\frac{\lambda^{m}}{C^{m-1}}\sum_{i_{1},\dots,i_{m}=1}^{N}\left(\bullet_{i_{1}},\bullet_{i_{2}}\right)\cdots\left(\bullet{}_{i_{m-2}},\bullet{}_{i_{m-1}}\right)\left(\bullet{}_{i_{m-1}},\bullet{}_{i_{m}}\right)\cdots\left[\left(2m-1\right)!\right] (F.29)

where the indices i1,,imi_{1},\dots,i_{m} obey the following:

  1. Each index appears exactly twice in each summand.

  2. Each index cannot be paired with itself, i.e. $\left(\bullet_{i_{1}},\bullet_{i_{1}}\right)$ is not allowed.

  3. The same pairing can appear more than once, e.g. $\left(1_{i}2_{j}\right)\left(3_{i}4_{j}\right)\left(5_{k}6_{\ell}\right)\left(7_{k}8_{\ell}\right)$ is OK, in that $i,j$ are paired together twice, and so are $k,\ell$.
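The $\left[120\right]$ counting above can be verified by a short enumeration (our own illustration, not part of the derivation): place the labels $1,\dots,6$ in the template $\left(\bullet_{i}\bullet_{j}\right)\left(\bullet_{i}\bullet_{k}\right)\left(\bullet_{j}\bullet_{k}\right)$ in all possible ways and count distinct summands modulo relabeling of the dummy window indices $i,j,k$:

# Count the distinct summands of the form (._i ._j)(._i ._k)(._j ._k) built from labels 1..6,
# identifying terms that differ only by a relabeling of the dummy indices i, j, k.
from itertools import permutations

template = [('i', 'j'), ('i', 'k'), ('j', 'k')]       # window indices of the three dot products

def term(labels, window_perm):
    # represent a summand as an order-free collection of dot products of (label, window) pairs
    brackets = []
    for (wa, wb), (la, lb) in zip(template, zip(labels[0::2], labels[1::2])):
        brackets.append(frozenset({(la, window_perm[wa]), (lb, window_perm[wb])}))
    return frozenset(brackets)

def canonical_term(labels):
    # minimize over the 3! relabelings of {i, j, k} to pick one representative per summand
    window_perms = [dict(zip('ijk', p)) for p in permutations('ijk')]
    return min((sorted(map(sorted, term(labels, wp))) for wp in window_perms), key=str)

distinct = {str(canonical_term(labels)) for labels in permutations(range(1, 7))}
print(len(distinct))                                  # 120, i.e. (2m-1)! = 5! for m = 3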

Appendix G Feature learning phase transition

G.1 Field theory derivation of the statistics of the hidden weights covariance

Although our main focus was on the statistics of the DNN outputs, our function-space formalism can also be used to characterize the statistics of the weights of the intermediate hidden layers. Here we focus on the linear CNN toy model given in the main text, where the learnable parameters of the student are given by θ={wc,s,ai,c}\theta=\left\{w_{c,s},a_{i,c}\right\}. Consider first a prior distribution in output space, where throughout this section we denote: f(f1,,fn)\vec{f}\equiv\left(f_{1},\dots,f_{n}\right), i.e. the vector of outputs on the training set alone (without the test point). Since we are interested in the statistics of the hidden weights, we will introduce an appropriate source term in weight space Jc,sJ_{c,s}

P0[f,{Jc,s}]𝑑w𝑑aexp(12σw2c,s(wc,sσw2Jc,s)2)P0(a)μ=1nδ(fμzθ,μ)P_{0}\left[\vec{f},\left\{J_{c,s}\right\}\right]\propto\int dw\int da\exp\left(-\frac{1}{2\sigma_{w}^{2}}\sum_{c,s}\left(w_{c,s}-\sigma_{w}^{2}J_{c,s}\right)^{2}\right)P_{0}\left(a\right)\prod_{\mu=1}^{n}\delta\left(f_{\mu}-z_{\theta,\mu}\right) (G.1)

where $z_{\theta,\mu}$ is the output of the CNN parameterized by $\theta$ on the $\mu$'th training point. Given some loss function $\mathcal{L}$, the posterior is given by

P[f,{Jc,s}]=P0[f,{Jc,s}]e/σ2P\left[\vec{f},\left\{J_{c,s}\right\}\right]=P_{0}\left[\vec{f},\left\{J_{c,s}\right\}\right]e^{-\mathcal{L}/\sigma^{2}} (G.2)

The posterior mean of the hidden weights is thus

Jc,slog(n𝑑fP[f,{Jc,s}])|J=0=wc,sP[f,{Jc,s}]σw2Jc,s|J=0=wc,sP[f,{Jc,s}]\evaluated{\partial_{J_{c,s}}\log\left(\int_{\mathbb{R}^{n}}d\vec{f}P\left[\vec{f},\left\{J_{c,s}\right\}\right]\right)}_{J=0}=\evaluated{\left\langle w_{c,s}\right\rangle_{P\left[\vec{f},\left\{J_{c,s}\right\}\right]}-\sigma_{w}^{2}J_{c,s}}_{J=0}=\left\langle w_{c,s}\right\rangle_{P\left[\vec{f},\left\{J_{c,s}\right\}\right]} (G.3)

and the posterior covariance can be extracted by taking the second derivative, namely

Jc1,s1Jc2,s2log(n𝑑fP[f,{Jc,s}])|J=0\displaystyle\evaluated{\partial_{J_{c_{1},s_{1}}}\partial_{J_{c_{2},s_{2}}}\log\left(\int_{\mathbb{R}^{n}}d\vec{f}P\left[\vec{f},\left\{J_{c,s}\right\}\right]\right)}_{J=0} (G.4)
=wc1,s1wc2,s2P[f,{Jc,s}]+σw4Jc1,s1Jc2,s2|J=0σw2δs1s2δc1c2\displaystyle=\evaluated{\left\langle w_{c_{1},s_{1}}w_{c_{2},s_{2}}\right\rangle_{P\left[\vec{f},\left\{J_{c,s}\right\}\right]}+\sigma_{w}^{4}J_{c_{1},s_{1}}J_{c_{2},s_{2}}}_{J=0}-\sigma_{w}^{2}\delta_{s_{1}s_{2}}\delta_{c_{1}c_{2}}
=wc1,s1wc2,s2σw2δs1s2δc1c2\displaystyle=\left\langle w_{c_{1},s_{1}}w_{c_{2},s_{2}}\right\rangle-\sigma_{w}^{2}\delta_{s_{1}s_{2}}\delta_{c_{1}c_{2}}

Our next task is to rewrite these expectation values over weights under the posterior as expectation values of DNN training outputs (f(𝐱μ)f({\mathbf{x}}_{\mu})) under the posterior. To this end we write down the kernel of this simple CNN such that it depends on the source terms:

KJ(𝐱,𝐱)\displaystyle K_{J}\left({\mathbf{x}},{\mathbf{x}}^{\prime}\right) =i,i,c,c,s,sai,cwc,sx~i,sai,cwc,sx~i,sP[f,Jc,s]\displaystyle=\sum_{i,i^{\prime},c,c^{\prime},s,s^{\prime}}\left\langle a_{i,c}w_{c,s}\tilde{x}_{i,s}a_{i^{\prime},c^{\prime}}w_{c^{\prime},s^{\prime}}\tilde{x}^{\prime}_{i^{\prime},s^{\prime}}\right\rangle_{P\left[\vec{f},J_{c,s}\right]} (G.5)
=i,i,c,cai,cai,caδiiδcc/CNs,swc,sx~i,swc,sx~i,sw,J\displaystyle=\underbrace{\sum_{i,i^{\prime},c,c^{\prime}}\left\langle a_{i,c}a_{i^{\prime},c^{\prime}}\right\rangle_{a}}_{\delta_{ii^{\prime}}\delta_{cc^{\prime}}/CN}\sum_{s,s^{\prime}}\left\langle w_{c,s}\tilde{x}_{i,s}w_{c^{\prime},s^{\prime}}\tilde{x}^{\prime}_{i^{\prime},s^{\prime}}\right\rangle_{w,J}
=1CNi,cs,swc,sx~i,swc,sx~i,sw,J\displaystyle=\frac{1}{CN}\sum_{i,c}\sum_{s,s^{\prime}}\left\langle w_{c,s}\tilde{x}_{i,s}w_{c,s^{\prime}}\tilde{x}^{\prime}_{i,s^{\prime}}\right\rangle_{w,J}
=1Nis,s1Ccwc,swc,sw,J(1/S)δss+(1/S2)JcsJcsx~i,sx~i,s\displaystyle=\frac{1}{N}\sum_{i}\sum_{s,s^{\prime}}\frac{1}{C}\sum_{c}\underbrace{\left\langle w_{c,s}w_{c,s^{\prime}}\right\rangle_{w,J}}_{\left(1/S\right)\delta_{ss^{\prime}}+\left(1/S^{2}\right)J_{cs}J_{cs^{\prime}}}\tilde{x}_{i,s}\tilde{x}^{\prime}_{i,s^{\prime}}
=1NS1Ciscx~i,sx~i,s+1NS2is,s(1CcJcsJcs)Bssx~i,sx~i,s\displaystyle=\frac{1}{NS}\frac{1}{C}\sum_{i}\sum_{s}\sum_{c}\tilde{x}_{i,s}\tilde{x}^{\prime}_{i,s}+\frac{1}{NS^{2}}\sum_{i}\sum_{s,s^{\prime}}\underbrace{\left(\frac{1}{C}\sum_{c}J_{cs}J_{cs^{\prime}}\right)}_{\equiv B_{ss^{\prime}}}\tilde{x}_{i,s}\tilde{x}^{\prime}_{i,s^{\prime}}
=1NS𝐱𝖳𝐱+1NS2i𝐱~i𝖳B𝐱~i\displaystyle=\frac{1}{NS}{\mathbf{x}}^{\mathsf{T}}{\mathbf{x}}^{\prime}+\frac{1}{NS^{2}}\sum_{i}\tilde{{\mathbf{x}}}_{i}^{\mathsf{T}}B\tilde{{\mathbf{x}}}_{i}^{\prime}

where BS×SB\in\mathbb{R}^{S\times S}. This can be written as (d=NSd=NS)

KJ(𝐱,𝐱)=1NS𝐱𝖳(Id+1S(BB))𝐱K_{J}\left({\mathbf{x}},{\mathbf{x}}^{\prime}\right)=\frac{1}{NS}{\mathbf{x}}^{\mathsf{T}}\left(I_{d}+\frac{1}{S}\left(\begin{matrix}B\\ &\ddots\\ &&B\end{matrix}\right)\right){\mathbf{x}}^{\prime} (G.6)

We can now write the second mixed derivatives of $K_{J}^{-1}$ to leading order in $J$ as

Jc1,s1Jc2,s2KJ1(𝐱,𝐱)\displaystyle-\partial_{J_{c_{1},s_{1}}}\partial_{J_{c_{2},s_{2}}}K_{J}^{-1}\left({\mathbf{x}},{\mathbf{x}}^{\prime}\right) =NS𝐱𝖳[Jc1,s1Jc2,s2(Id+1S(BB))1]𝐱\displaystyle=-NS{\mathbf{x}}^{\mathsf{T}}\left[\partial_{J_{c_{1},s_{1}}}\partial_{J_{c_{2},s_{2}}}\left(I_{d}+\frac{1}{S}\left(\begin{matrix}B\\ &\ddots\\ &&B\end{matrix}\right)\right)^{-1}\right]{\mathbf{x}}^{\prime} (G.7)
=NS𝐱𝖳[Jc1,s1Jc2,s2(Id1S(BB))]𝐱\displaystyle=-NS{\mathbf{x}}^{\mathsf{T}}\left[\partial_{J_{c_{1},s_{1}}}\partial_{J_{c_{2},s_{2}}}\left(I_{d}-\frac{1}{S}\left(\begin{matrix}B\\ &\ddots\\ &&B\end{matrix}\right)\right)\right]{\mathbf{x}}^{\prime}
=NCis,sx~i,sx~i,sJc1,s1Jc2,s2cJcsJcs\displaystyle=\frac{N}{C}\sum_{i}\sum_{s,s^{\prime}}\tilde{x}_{i,s}\tilde{x}^{\prime}_{i,s^{\prime}}\partial_{J_{c_{1},s_{1}}}\partial_{J_{c_{2},s_{2}}}\sum_{c}J_{cs}J_{cs^{\prime}}
=NCδc1c2is,sx~i,sx~i,s(δss2δs1s+δss2δs1s)\displaystyle=\frac{N}{C}\delta_{c_{1}c_{2}}\sum_{i}\sum_{s,s^{\prime}}\tilde{x}_{i,s}\tilde{x}^{\prime}_{i,s^{\prime}}\left(\delta_{ss_{2}}\delta_{s_{1}s^{\prime}}+\delta_{s^{\prime}s_{2}}\delta_{s_{1}s}\right)
=2NCδc1c2ix~i,s1x~i,s2\displaystyle=2\frac{N}{C}\delta_{c_{1}c_{2}}\sum_{i}\tilde{x}_{i,s_{1}}\tilde{x}^{\prime}_{i,s_{2}}

Next we take the large $C$ limit and thus have a posterior of the form $P[\vec{f},J]=P_{0}[\vec{f},J]e^{-\mathcal{L}/\sigma^{2}}$, where $P_{0}[\vec{f},J]$ contains only $K_{J}^{-1}$ and none of the higher cumulants. Having the derivatives of $K_{J}^{-1}$ w.r.t. $J$, we can proceed to analyze the derivatives of the log-partition function of the posterior w.r.t. $J$. In particular, the covariance matrix of the weights averaged over the different channels is

1Cc1,c2Jc1,s1Jc2,s2log(n𝑑fP[f,J])\displaystyle\frac{1}{C}\sum_{c_{1},c_{2}}\partial_{J_{c_{1},s_{1}}}\partial_{J_{c_{2},s_{2}}}\log\left(\int_{\mathbb{R}^{n}}d\vec{f}P\left[\vec{f},J\right]\right) (G.8)
=1Cμ,ν=1n{c1,c2[Jc1,s1Jc2,s2KJ1(𝐱μ,𝐱ν)]n𝑑fP[f]f(𝐱μ)f(𝐱ν)n𝑑fP[f]}\displaystyle=-\frac{1}{C}\sum_{\mu,\nu=1}^{n}\left\{\sum_{c_{1},c_{2}}\left[\partial_{J_{c_{1},s_{1}}}\partial_{J_{c_{2},s_{2}}}K_{J}^{-1}\left({\mathbf{x}}_{\mu},{\mathbf{x}}_{\nu}\right)\right]\frac{\int_{\mathbb{R}^{n}}d\vec{f}P[\vec{f}]f\left({\mathbf{x}}_{\mu}\right)f\left({\mathbf{x}}_{\nu}\right)}{\int_{\mathbb{R}^{n}}d\vec{f}P[\vec{f}]}\right\}
=2NC2iμ,ν=1nx~i,s1μx~n𝑑fP[f]f(𝐱μ)f(𝐱ν)n𝑑fP[f]i,s2νc1,c2δc1c2\displaystyle=2\frac{N}{C^{2}}\sum_{i}\sum_{\mu,\nu=1}^{n}\tilde{x}_{i,s_{1}}^{\mu}\tilde{x}{}_{i,s_{2}}^{\nu}\frac{\int_{\mathbb{R}^{n}}d\vec{f}P[\vec{f}]f\left({\mathbf{x}}_{\mu}\right)f\left({\mathbf{x}}_{\nu}\right)}{\int_{\mathbb{R}^{n}}d\vec{f}P[\vec{f}]}\sum_{c_{1},c_{2}}\delta_{c_{1}c_{2}}
=2NCiμ,ν=1nx~i,s1μx~n𝑑fP[f]f(𝐱μ)f(𝐱ν)n𝑑fP[f]i,s2ν\displaystyle=2\frac{N}{C}\sum_{i}\sum_{\mu,\nu=1}^{n}\tilde{x}_{i,s_{1}}^{\mu}\tilde{x}{}_{i,s_{2}}^{\nu}\frac{\int_{\mathbb{R}^{n}}d\vec{f}P[\vec{f}]f\left({\mathbf{x}}_{\mu}\right)f\left({\mathbf{x}}_{\nu}\right)}{\int_{\mathbb{R}^{n}}d\vec{f}P[\vec{f}]}

The above result is one of the two main points of this appendix: we established a mapping between expectation values over outputs and expectation values over hidden weights. Such a mapping can in principle be extended to any DNN. On the technical level, it requires the ability to calculate the cumulants as a function of the source terms, JJ. As we argue below, it may very well be that unlike in the main text, only a few cumulants are needed here.

To estimate the above expectation values we use the EK limit, where the sums over the training set are replaced by integrals over the measure μ(𝐱)\mu({\mathbf{x}}), the ff’s are replaced as f(𝐱μ)λλ+σ2/ng(𝐱)f\left({\mathbf{x}}_{\mu}\right)\to\frac{\lambda}{\lambda+\sigma^{2}/n}g\left({\mathbf{x}}\right) and we assume the input distribution is normalized as 𝑑μ(𝐱)xixj=δij\int d\mu\left({\mathbf{x}}\right)x_{i}x_{j}=\delta_{ij}. Following this we find

2NC(λλ+σ2/n)2𝑑μ(𝐱)𝑑μ(𝐱)g(𝐱)g(𝐱)ix~i,s1x~i,s2\displaystyle 2\frac{N}{C}\left(\frac{\lambda}{\lambda+\sigma^{2}/n}\right)^{2}\int d\mu\left({\mathbf{x}}\right)d\mu\left({\mathbf{x}}^{\prime}\right)g\left({\mathbf{x}}\right)g\left({\mathbf{x}}^{\prime}\right)\sum_{i}\tilde{x}_{i,s_{1}}\tilde{x}^{\prime}_{i,s_{2}} (G.9)
=2NC(λλ+σ2/n)2i𝑑μ(𝐱)g(𝐱)x~i,s1aiws1𝑑μ(𝐱)g(𝐱)x~i,s2aiws2\displaystyle=2\frac{N}{C}\left(\frac{\lambda}{\lambda+\sigma^{2}/n}\right)^{2}\sum_{i}\underbrace{\int d\mu\left({\mathbf{x}}\right)g\left({\mathbf{x}}\right)\tilde{x}_{i,s_{1}}}_{a_{i}^{*}w_{s_{1}}^{*}}\underbrace{\int d\mu\left({\mathbf{x}}^{\prime}\right)g\left({\mathbf{x}}^{\prime}\right)\tilde{x}^{\prime}_{i,s_{2}}}_{a_{i}^{*}w_{s_{2}}^{*}}
=2NC(λλ+σ2/n)2i(ai)21ws1ws2\displaystyle=2\frac{N}{C}\left(\frac{\lambda}{\lambda+\sigma^{2}/n}\right)^{2}\underbrace{\sum_{i}\left(a_{i}^{*}\right)^{2}}_{1}w_{s_{1}}^{*}w_{s_{2}}^{*}
=2NC(λλ+σ2/n)2ws1ws2\displaystyle=2\frac{N}{C}\left(\frac{\lambda}{\lambda+\sigma^{2}/n}\right)^{2}w_{s_{1}}^{*}w_{s_{2}}^{*}

Comparing this to our earlier result for the covariance Eq. G.4 we get

2NC(λλ+σ2/n)2ws1ws2\displaystyle 2\frac{N}{C}\left(\frac{\lambda}{\lambda+\sigma^{2}/n}\right)^{2}w_{s_{1}}^{*}w_{s_{2}}^{*} =1Cc1,c2(wc1,s1wc2,s2σw2δs1s2δc1c2)\displaystyle=\frac{1}{C}\sum_{c_{1},c_{2}}\left(\left\langle w_{c_{1},s_{1}}w_{c_{2},s_{2}}\right\rangle-\sigma_{w}^{2}\delta_{s_{1}s_{2}}\delta_{c_{1}c_{2}}\right) (G.10)
=1Cc1,c2wc1,s1wc2,s2σw2δs1s2\displaystyle=\frac{1}{C}\sum_{c_{1},c_{2}}\left\langle w_{c_{1},s_{1}}w_{c_{2},s_{2}}\right\rangle-\sigma_{w}^{2}\delta_{s_{1}s_{2}}

Multiplying by S=1/σw2S=1/\sigma_{w}^{2} and recalling that λ=1/NS\lambda=1/NS we get

[ΣW]s1s2=δs1s2+2Cλ(λ+σ2/n)2ws1ws2+O(1/C2)\left\langle\left[\Sigma_{W}\right]_{s_{1}s_{2}}\right\rangle=\delta_{s_{1}s_{2}}+\frac{2}{C}\frac{\lambda}{\left(\lambda+\sigma^{2}/n\right)^{2}}w_{s_{1}}^{*}w_{s_{2}}^{*}+O\left(1/C^{2}\right) (G.11)

Repeating similar steps while also taking into account diagonal fluctuations yields another factor of (1λ+nσ2)1\left(\frac{1}{\lambda}+\frac{n}{\sigma^{2}}\right)^{-1} on the diagonal, thus arriving at the result as it appears in the main text:

[ΣW]s1s2=(1+(1λ+nσ2)1)δs1s2+2Cλ(λ+σ2/n)2ws1ws2+O(1/C2)\left\langle\left[\Sigma_{W}\right]_{s_{1}s_{2}}\right\rangle=\left(1+\left(\frac{1}{\lambda}+\frac{n}{\sigma^{2}}\right)^{-1}\right)\delta_{s_{1}s_{2}}+\frac{2}{C}\frac{\lambda}{\left(\lambda+\sigma^{2}/n\right)^{2}}w_{s_{1}}^{*}w_{s_{2}}^{*}+O\left(1/C^{2}\right) (G.12)

The above results capture the leading order correction in $1/C$ to the weights covariance matrix. However, the careful reader may be wary of the fact that the results in the main text require $1/C$ corrections to all orders, and so it is potentially inadequate to use such a low order expansion deep in the feature learning regime, as we do in the main text. Here we note that not all DNN quantities need to have the same dependence on $C$. In particular, it was shown in Ref. [24] that the weights' low-order statistics are only weakly affected by finite-width corrections, whereas the output covariance matrix is strongly affected by these. We conjecture that this is the case here and that only the cumulative effect of many weights, as reflected in the output of the DNN, requires strong $1/C$ corrections.

This conjecture can be verified analytically by repeating the above procedure on the full prior (i.e. the one that contains all cumulants), obtaining the operator in terms of $f$'s corresponding to the weights' covariance matrix, and calculating its average with respect to the saddle-point theory. We leave this for future work.

G.2 A surrogate quantity for the outlier

Since we used moderate $S$ values in our simulations (to maintain a reasonable compute time), we aggregated the eigenvalues of many instances of $\Sigma_{W}$ across training time and across noise realizations. Although the empirical histogram of the spectrum of $\Sigma_{W}$ agrees very well with the theoretical MP distribution (solid smooth curves in Fig. 2A), there is a substantial difference between the two at the right edge of the support $\lambda_{+}$, where the empirical histogram has a tail due to finite size effects. Thus it is hard to characterize the phase transition using the largest eigenvalue $\lambda_{\max}$ averaged across realizations. Instead, we use the quantity $\mathcal{Q}\equiv{\mathbf{w}}^{*\mathsf{T}}\Sigma_{W}{\mathbf{w}}^{*}$ as a surrogate, which coincides with $\lambda_{\max}$ for $C\ll C_{\rm{crit}}$ but behaves sensibly on both sides of $C_{\rm{crit}}$, thus allowing us to characterize the phase transition.
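As a simple illustration (our own, with hypothetical parameter values) of why $\mathcal{Q}$ is a convenient surrogate, note that the predicted mean covariance of Eq. G.12 is an identity plus a rank-one spike along ${\mathbf{w}}^{*}$; assuming $\norm{{\mathbf{w}}^{*}}=1$, $\mathcal{Q}$ then coincides exactly with the largest eigenvalue of that predicted matrix, whereas for the empirical, finite-$S$ matrices it is insensitive to the tail at the MP edge:

# Predicted weight covariance of Eq. G.12 (illustrative parameter values, not those of the experiments)
import numpy as np

S, N, C, n, sigma2 = 30, 5, 8, 1000, 1.0
lam = 1.0 / (N * S)
w_star = np.random.default_rng(1).standard_normal(S)
w_star /= np.linalg.norm(w_star)                       # assume the teacher vector is normalized

diag = 1.0 + 1.0 / (1.0 / lam + n / sigma2)            # diagonal part of Eq. G.12
spike = (2.0 / C) * lam / (lam + sigma2 / n) ** 2      # rank-one part of Eq. G.12
Sigma_W = diag * np.eye(S) + spike * np.outer(w_star, w_star)

Q = w_star @ Sigma_W @ w_star
print(Q, np.linalg.eigvalsh(Sigma_W)[-1])              # identical for this identity-plus-spike prediction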

Appendix H Further details on the numerical experiments

H.1 Additional details of the numerical experiments

In our experiments, we used the following hyper-parameter values. Learning rates of $\eta=10^{-6}$ and $\eta=3\cdot10^{-7}$ yield results with no appreciable difference in almost all cases, provided we scale the amount of statistics collected (training epochs after reaching equilibrium) so that both $\eta$ values have the same amount of re-scaled training time: we used $10$ training seeds for $\eta=10^{-6}$ and $30$ for $\eta=3\cdot10^{-7}$. We used a gradient noise level of $\sigma^{2}=1.0$, but also checked $\sigma^{2}\in\{0.1,0.01\}$ and got qualitatively similar results to those reported in the main text.


Figure 3: (A) The CNNs’ cosine distance α\alpha, defined by f=(1α)g\left\langle f\right\rangle=(1-\alpha)g between the ensemble-averaged prediction f\left\langle f\right\rangle and ground truth gg plotted vs. number of channels CC for the training set (for the test set see Fig. 1 in the main text). As nn increases, the solution of the self consistent equation 16 (solid line) yields an increasingly accurate prediction of these empirical values (dots). (B) Same data as in (A), presented as empirical α\alpha vs. predicted α\alpha. As nn grows, the two converge to the identity line (dashed black line). Solid lines connecting the dots here are merely for visualization purposes.

In the main text and here we do not show error bars for α\alpha as these are too small to be appreciated visually. They are smaller than the mean values by approximately two orders of magnitude. The error bars were found by computing the empirical standard deviation of α\alpha across training dynamics and training seeds.

H.2 Convergence of the training protocol to GP

In Fig. 4 we plot the MSE between the outputs of the trained CNNs and the predictions of the corresponding GP. We see that as $C$ becomes large the slope of the MSE (on a log-log scale) tends to $-2.0$, indicating the $O(1/C)$ scaling of the leading correction to the GP predictions and hence the $O(1/C^{2})$ scaling of the MSE. This illustrates where we enter the perturbative regime around the GP, and we see that this happens at larger $C$ as we increase the conv-kernel size $S$, since this also increases the input dimension $d=NS$. Thus it takes a larger $C$ to enter the highly over-parameterized regime.
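The slopes quoted here can be read off from a least-squares fit of $\log(\mathrm{MSE})$ against $\log(C)$ over the large-$C$ points; a minimal sketch (with placeholder values standing in for the measured CNN-GP MSEs) is:

# Fit the log-log slope of the CNN-GP MSE vs. the number of channels C (placeholder data)
import numpy as np

C_vals = np.array([2 ** 4, 2 ** 5, 2 ** 6, 2 ** 7, 2 ** 8], dtype=float)
mse = 1e-3 / C_vals ** 2                 # hypothetical O(1/C^2) values
slope, _ = np.polyfit(np.log(C_vals), np.log(mse), 1)
print(slope)                             # approaches -2.0 in the perturbative regime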

Figure 4: CNN-GP MSEs for different SS, indicating where the perturbative regime starts (slope approaching 2.0-2.0). For S=15S=15 this happens around C=25C=2^{5} whereas for S=30S=30 this happens around C=27C=2^{7}.

Appendix I Quadratic fully connected network

One of the simplest settings where GPs are expected to strongly under-perform finite DNNs is the case of quadratic fully connected DNNs [26]. Here we consider a positive target of the form $g({\mathbf{x}})=({\mathbf{w}}_{*}\cdot{\mathbf{x}})^{2}-\sigma_{w}^{2}||{\mathbf{x}}||^{2}$, where ${\mathbf{w}}_{*},{\mathbf{x}}\in{\mathbb{R}}^{d}$, and a student DNN given by $f({\mathbf{x}})=\sum_{m=1}^{M}({\mathbf{w}}_{m}\cdot{\mathbf{x}})^{2}-\sigma_{w}^{2}||{\mathbf{x}}||^{2}$. (The $||{\mathbf{x}}||^{2}$ shift is not part of the original model and has only a superficial shift effect, useful for book-keeping later on.)

At large $M$ and for $w_{m,i}$ drawn from $\mathcal{N}(0,\sigma_{w}^{2}/M)$, the student generates a GP prior. It is shown below that the GP kernel is simply $K({\mathbf{x}},{\mathbf{x}}^{\prime})=\frac{2\sigma_{w}^{4}}{M}({\mathbf{x}}\cdot{\mathbf{x}}^{\prime})^{2}$. As such it is proportional to the kernel of the above DNN with an additional linear read-out layer. The above model can be written as $\sum_{ij}x_{i}[P_{ij}-\sigma_{w}^{2}\delta_{ij}]x_{j}$, where $P_{ij}$ is a positive semi-definite matrix. The eigenvalues of the matrix appearing within the brackets are therefore larger than $-\sigma_{w}^{2}$, whereas no similar restriction occurs for DNNs with a linear read-out layer. This extra restriction is completely missed by the GP approximation and, as discussed in Ref. [26], leads to strong performance improvements compared to what one expects from the GP or, equivalently, the DNN with the linear read-out layer. Here we demonstrate that our self-consistent approach, at the saddle-point level, captures this effect.

We consider training this DNN on $n$ training points $\left\{{\mathbf{x}}_{\mu}\right\}_{\mu=1}^{n}$ using noisy GD with weight decay $\gamma=M\sigma^{2}/\sigma_{w}^{2}$. We wish to solve for the predictions of this model with our shifted-target approach. To this end, we first derive the cumulants associated with the effective Bayesian prior ($P_{0}(\vec{f})$); equivalently stated, we obtain the cumulants of the equilibrium distribution of $\vec{f}$ following training with no data, only a weight decay term. This latter distribution is given by

P0(f)=𝑑𝐰eM2σw2m=1M𝐰m2μ=1n+1δ(fμm=1M(𝐰m𝐱μ)2+σw2𝐱μ2)\displaystyle P_{0}\left(\vec{f}\right)=\int d{\mathbf{w}}e^{-\frac{M}{2\sigma_{w}^{2}}\sum_{m=1}^{M}||{\mathbf{w}}_{m}||^{2}}\prod_{\mu=1}^{n+1}\delta\left(f_{\mu}-\sum_{m=1}^{M}\left({\mathbf{w}}_{m}\cdot{\mathbf{x}}_{\mu}\right)^{2}+\sigma_{w}^{2}||{\mathbf{x}}_{\mu}||^{2}\right) (I.1)

To obtain the cumulants, we calculate the cumulant generating function of this distribution given by

𝒞(t1,,tn+1)\displaystyle\mathcal{C}(t_{1},...,t_{n+1}) (I.2)
\displaystyle=\log\left(\int\prod_{m,i=1,1}^{M,d}\frac{dw_{m,i}}{\sqrt{2\pi M^{-1}\sigma_{w}^{2}}}e^{-\sum_{m,i=1,1}^{M,d}M\frac{w_{m,i}^{2}}{2\sigma_{w}^{2}}+\sum_{\mu=1}^{n+1}it_{\mu}\left[\sum_{m=1}^{M}\left({\mathbf{w}}_{m}\cdot{\mathbf{x}}_{\mu}\right)^{2}-\sigma_{w}^{2}||{\mathbf{x}}_{\mu}||^{2}\right]}\right)
=Mlog(i=1ddwi2πM1σw2eM𝐰22σw2+μ=1n+1itμ[(𝐰𝐱μ)2])μ=1n+1itμσw2𝐱μ2\displaystyle=M\log\left(\int\prod_{i=1}^{d}\frac{dw_{i}}{\sqrt{2\pi M^{-1}\sigma_{w}^{2}}}e^{-M\frac{||{\mathbf{w}}||^{2}}{2\sigma_{w}^{2}}+\sum_{\mu=1}^{n+1}it_{\mu}\left[\left({\mathbf{w}}\cdot{\mathbf{x}}_{\mu}\right)^{2}\right]}\right)-\sum_{\mu=1}^{n+1}it_{\mu}\sigma_{w}^{2}||{\mathbf{x}}_{\mu}||^{2}
=Mlog(i=1ddwi2πM1σw2e𝐰𝖳[I2M1σw2μitμ𝐱μ𝐱μ𝖳]𝐰2M1σw2)μ=1n+1itμσw2𝐱μ2\displaystyle=M\log\left(\int\prod_{i=1}^{d}\frac{dw_{i}}{\sqrt{2\pi M^{-1}\sigma_{w}^{2}}}e^{-\frac{{\mathbf{w}}^{\mathsf{T}}\left[I-2M^{-1}\sigma_{w}^{2}\sum_{\mu}it_{\mu}{\mathbf{x}}_{\mu}{\mathbf{x}}_{\mu}^{\mathsf{T}}\right]{\mathbf{w}}}{2M^{-1}\sigma_{w}^{2}}}\right)-\sum_{\mu=1}^{n+1}it_{\mu}\sigma_{w}^{2}||{\mathbf{x}}_{\mu}||^{2}
=M2log(det[I2M1σw2μitμ𝐱μ𝐱μ𝖳])μ=1n+1itμσw2𝐱μ2\displaystyle=-\frac{M}{2}\log\left(\det\left[I-2M^{-1}\sigma_{w}^{2}\sum_{\mu}it_{\mu}{\mathbf{x}}_{\mu}{\mathbf{x}}_{\mu}^{\mathsf{T}}\right]\right)-\sum_{\mu=1}^{n+1}it_{\mu}\sigma_{w}^{2}||{\mathbf{x}}_{\mu}||^{2}
=M2Tr(log[I2M1σw2μitμ𝐱μ𝐱μ𝖳])μ=1n+1itμσw2𝐱μ2\displaystyle=-\frac{M}{2}\Tr\left(\log\left[I-2M^{-1}\sigma_{w}^{2}\sum_{\mu}it_{\mu}{\mathbf{x}}_{\mu}{\mathbf{x}}_{\mu}^{\mathsf{T}}\right]\right)-\sum_{\mu=1}^{n+1}it_{\mu}\sigma_{w}^{2}||{\mathbf{x}}_{\mu}||^{2}

Taylor expanding this last expression is straightforward. For instance, up to third order it gives

𝒞(t1,,tn+1)\displaystyle\mathcal{C}(t_{1},...,t_{n+1}) =M2μ1,μ2(2M1σw2)2itμ1itμ22(𝐱μ1𝐱μ2)(𝐱μ2𝐱μ1)\displaystyle=\frac{M}{2}\sum_{\mu_{1},\mu_{2}}(2M^{-1}\sigma_{w}^{2})^{2}\frac{it_{\mu_{1}}it_{\mu_{2}}}{2}({\mathbf{x}}_{\mu_{1}}\cdot{\mathbf{x}}_{\mu_{2}})({\mathbf{x}}_{\mu_{2}}\cdot{\mathbf{x}}_{\mu_{1}}) (I.3)
+M2μ1,μ2,μ3(2M1σw2)3itμ1itμ2itμ33(𝐱μ1𝐱μ2)(𝐱μ2𝐱μ3)(𝐱μ3𝐱μ1)+\displaystyle+\frac{M}{2}\sum_{\mu_{1},\mu_{2},\mu_{3}}(2M^{-1}\sigma_{w}^{2})^{3}\frac{it_{\mu_{1}}it_{\mu_{2}}it_{\mu_{3}}}{3}({\mathbf{x}}_{\mu_{1}}\cdot{\mathbf{x}}_{\mu_{2}})({\mathbf{x}}_{\mu_{2}}\cdot{\mathbf{x}}_{\mu_{3}})({\mathbf{x}}_{\mu_{3}}\cdot{\mathbf{x}}_{\mu_{1}})+...

from which the cumulants can be directly inferred, in particular the associated GP kernel given by

K(𝐱μ,𝐱ν)=2M1σw4(𝐱μ𝐱ν)2\displaystyle K\left({\mathbf{x}}_{\mu},{\mathbf{x}}_{\nu}\right)=2M^{-1}\sigma_{w}^{4}\left({\mathbf{x}}_{\mu}\cdot{\mathbf{x}}_{\nu}\right)^{2} (I.4)
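This kernel is easily cross-checked by Monte Carlo (a small sketch of ours, with arbitrary sizes): sampling $w_{m,i}\sim\mathcal{N}(0,\sigma_{w}^{2}/M)$ and averaging $f({\mathbf{x}}_{1})f({\mathbf{x}}_{2})$ over many draws should reproduce $2M^{-1}\sigma_{w}^{4}({\mathbf{x}}_{1}\cdot{\mathbf{x}}_{2})^{2}$:

# Monte Carlo check of the GP kernel of the quadratic student, Eq. I.4
import numpy as np

rng = np.random.default_rng(0)
d, M, sigma_w2, n_samples = 5, 10, 1.0, 200_000
x1, x2 = rng.standard_normal(d), rng.standard_normal(d)

W = rng.normal(0.0, np.sqrt(sigma_w2 / M), size=(n_samples, M, d))
f = lambda x: np.sum((W @ x) ** 2, axis=1) - sigma_w2 * (x @ x)   # one output per weight draw
K_mc = np.mean(f(x1) * f(x2))                                     # f has zero prior mean by construction
K_analytic = 2 * sigma_w2 ** 2 / M * (x1 @ x2) ** 2
print(K_mc, K_analytic)                                           # agree up to Monte Carlo noise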

Following this, the target shift equation, at the saddle point level, becomes

Δgν\displaystyle\Delta g_{\nu} =tν(𝒞(t1..tn,tn+1=0)μ1,μ2K(𝐱μ1,𝐱μ2)2!itμ1itμ2)|t1..tn=δ^g1σ2..δ^gnσ2\displaystyle=\partial_{t_{\nu}}\left(\mathcal{C}(t_{1}..t_{n},t_{n+1}=0)-\sum_{\mu_{1},\mu_{2}}\frac{K({\mathbf{x}}_{\mu_{1}},{\mathbf{x}}_{\mu_{2}})}{2!}it_{\mu_{1}}it_{\mu_{2}}\right)|_{t_{1}..t_{n}=\frac{\hat{\delta}g_{1}}{\sigma^{2}}..\frac{\hat{\delta}g_{n}}{\sigma^{2}}} (I.5)
=μK(𝐱ν,𝐱μ)δ^gμσ2+σw2Tr(𝐱ν𝐱ν𝖳[I2M1σw2μδ^gμσ2𝐱μ𝐱μ]1)σw2𝐱ν2\displaystyle=-\sum_{\mu}K\left({\mathbf{x}}_{\nu},{\mathbf{x}}_{\mu}\right)\frac{\hat{\delta}g_{\mu}}{\sigma^{2}}+\sigma_{w}^{2}\Tr\left({\mathbf{x}}_{\nu}{\mathbf{x}}_{\nu}^{\mathsf{T}}\left[I-2M^{-1}\sigma_{w}^{2}\sum_{\mu}\frac{\hat{\delta}g_{\mu}}{\sigma^{2}}{\mathbf{x}}_{\mu}{\mathbf{x}}_{\mu}\right]^{-1}\right)-\sigma_{w}^{2}||{\mathbf{x}}_{\nu}||^{2}
=μ=1nK(𝐱ν,𝐱μ)δ^gμσ2+σw2𝐱ν𝖳[I2M1σw2μ=1nδ^gμσ2𝐱μ𝐱μ𝖳]1𝐱νσw2𝐱ν2\displaystyle=-\sum^{n}_{\mu=1}K\left({\mathbf{x}}_{\nu},{\mathbf{x}}_{\mu}\right)\frac{\hat{\delta}g_{\mu}}{\sigma^{2}}+\sigma_{w}^{2}{\mathbf{x}}_{\nu}^{\mathsf{T}}\left[I-2M^{-1}\sigma_{w}^{2}\sum^{n}_{\mu=1}\frac{\hat{\delta}g_{\mu}}{\sigma^{2}}{\mathbf{x}}_{\mu}{\mathbf{x}}_{\mu}^{\mathsf{T}}\right]^{-1}{\mathbf{x}}_{\nu}-\sigma_{w}^{2}||{\mathbf{x}}_{\nu}||^{2}
δ^gν\displaystyle\hat{\delta}g_{\nu} =(gνΔgν)μ,μ=1nK(𝐱ν,𝐱μ)K~μ,μ1(gμΔgμ)\displaystyle=\left(g_{\nu}-\Delta g_{\nu}\right)-\sum^{n}_{\mu,\mu^{\prime}=1}K\left({\mathbf{x}}_{\nu},{\mathbf{x}}_{\mu}\right)\tilde{K}_{\mu,\mu^{\prime}}^{-1}\left(g_{\mu^{\prime}}-\Delta g_{\mu^{\prime}}\right)

The above non-linear equation for the quantities δ^g1,,δ^gn\hat{\delta}g_{1},\dots,\hat{\delta}g_{n} could be solved numerically, with the most numerically demanding part being the inverse of K~μ,ν=K(𝐱μ,𝐱ν)+σ2δμ,ν\tilde{K}_{\mu,\nu}=K({\mathbf{x}}_{\mu},{\mathbf{x}}_{\nu})+\sigma^{2}\delta_{\mu,\nu} on the training set.

Figure 5: Test MSE as a function of n/dn/d for the phase retrieval model as predicted by our self-consistent equation at the saddle-point level (without any EK-type approximation). Train and test data are drawn uniformly from the d=20d=20 hypersphere with radius 11. The graph shows the median test MSE of 60 different data sets. Our approach captures the desired n=2dn=2d threshold value [26] whereas lazy-learning/GP will predict a cross over at n=O(d2)n=O(d^{2}).

Figure 5 shows the numerical results for the test MSE as obtained by solving the above equations for $\hat{\delta}g$ on the training set, taking $\nu=*$ in these equations together with the self-consistent $\hat{\delta}g_{\mu}$ to find the mean predictor, and taking the average MSE of the latter over the test set. Both test and train data sets were random points sampled uniformly from a $d$-dimensional hypersphere of radius one. The test dataset contained $100$ points and the figure shows the test MSE as a function of $n/d$ where $d=20$, $\sigma^{2}_{w}=1$, $\sigma^{2}=2.76\cdot10^{-6}$, $M=4d$, and $w^{*}_{i}$ drawn from $\mathcal{N}(0,1)$. The non-linear equations were solved using the Newton-Krylov algorithm together with gradual annealing from $\sigma^{2}=1$ down to the above value. The figure shows the median over $60$ data sets. Remarkably, our self-consistent approach yields the expected threshold value of $n/d=2$ [26] separating good and poor performance. Discerning whether this is a sharp threshold or a smooth cross-over in the large $d$ limit is left for future work.
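For concreteness, a minimal sketch of this numerical procedure (with our own, much smaller, choices of $d$, $n$ and $\sigma^{2}$; not the exact script used for Fig. 5) solves the saddle-point equations Eq. I.5 for $\hat{\delta}g$ on the training set with SciPy's Newton-Krylov solver:

# Solve the self-consistent target-shift equations (Eq. I.5) for the quadratic network
import numpy as np
from scipy.optimize import newton_krylov

rng = np.random.default_rng(0)
d, M, n, sigma_w2, sigma2 = 10, 40, 40, 1.0, 1.0      # start at sigma2 = 1.0; anneal downward as above
X = rng.standard_normal((n, d))
X /= np.linalg.norm(X, axis=1, keepdims=True)          # training inputs on the unit sphere
w_star = rng.standard_normal(d)
g = (X @ w_star) ** 2 - sigma_w2 * np.sum(X ** 2, axis=1)   # target g(x) = (w*.x)^2 - sigma_w^2 ||x||^2

K = 2 * sigma_w2 ** 2 / M * (X @ X.T) ** 2             # GP kernel of Eq. I.4 on the training set
K_tilde_inv = np.linalg.inv(K + sigma2 * np.eye(n))

def target_shift(dg):
    # Delta g_nu of Eq. I.5, evaluated at the current discrepancies dg = delta-hat g
    A = np.eye(d) - 2 * sigma_w2 / M * (X.T * (dg / sigma2)) @ X
    quad = np.einsum('ni,ij,nj->n', X, np.linalg.inv(A), X)
    return -K @ (dg / sigma2) + sigma_w2 * quad - sigma_w2 * np.sum(X ** 2, axis=1)

def residual(dg):
    shifted = g - target_shift(dg)
    return dg - (shifted - K @ (K_tilde_inv @ shifted))

dg_solution = newton_krylov(residual, np.zeros(n), f_tol=1e-8)
print(np.max(np.abs(residual(dg_solution))))           # residual of the self-consistent equations

The mean predictor on a test point then follows by taking $\nu=*$ in Eq. I.5 with the solved $\hat{\delta}g_{\mu}$, as described above.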

Turning to analytics, one can again employ the EK approximation as done for the CNN. However, taking $\sigma^{2}$ to zero invalidates the EK approximation and requires a more advanced treatment, as in Ref. [10]. We thus leave an EK-type analysis of the self-consistent equation at $\sigma^{2}=0$ for future work and instead focus on the simpler case of finite $\sigma^{2}$, where analytical predictions can again be derived in a similar fashion to our treatment of the CNN.

To simplify things further, we also commit to the distribution $[{\mathbf{x}}_{\mu}]_{i}\sim\mathcal{N}(0,1/d)$. In this setting $K({\mathbf{x}},{\mathbf{x}}^{\prime})$ has two distinct eigenvalues w.r.t. this measure: the larger one ($\lambda_{0}=2M^{-1}\sigma_{w}^{4}\left(\frac{2}{d^{2}}+\frac{1}{d}\right)$), associated with the eigenfunction $f({\mathbf{x}})=||{\mathbf{x}}||^{2}$, and a smaller one ($\lambda_{2}=2M^{-1}\sigma_{w}^{4}\frac{2}{d^{2}}$), associated with the $x_{i}x_{j}$ (with $i\neq j$) and $\sum_{i}a_{i}x_{i}^{2}$ (with $\sum_{i=1}^{d}a_{i}=0$) eigenfunctions.
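These two eigenvalues can be cross-checked numerically (a small sketch of ours, with arbitrary sizes): applying the kernel operator to $\norm{{\mathbf{x}}}^{2}$ and to an $x_{i}x_{j}$ eigenfunction by Monte Carlo integration over the measure should return $\lambda_{0}$ and $\lambda_{2}$ respectively:

# Monte Carlo check of the two kernel eigenvalues of K(x, x') = 2 sigma_w^4 / M (x.x')^2
# under [x]_i ~ N(0, 1/d); the sizes below are arbitrary illustrative choices.
import numpy as np

rng = np.random.default_rng(0)
d, M, sigma_w2, n_samples = 10, 40, 1.0, 400_000
lam0 = 2 * sigma_w2 ** 2 / M * (2 / d ** 2 + 1 / d)
lam2 = 2 * sigma_w2 ** 2 / M * 2 / d ** 2

x = np.full(d, d ** -0.5)                              # a fixed, well-conditioned test point
Xp = rng.normal(0.0, 1.0 / np.sqrt(d), size=(n_samples, d))
K_row = 2 * sigma_w2 ** 2 / M * (Xp @ x) ** 2          # K(x, x') over the sampled x'

# (K phi)(x) = E_{x'}[K(x, x') phi(x')] should equal lambda * phi(x)
print(np.mean(K_row * np.sum(Xp ** 2, axis=1)) / (x @ x), lam0)     # ~ lambda_0
print(np.mean(K_row * Xp[:, 0] * Xp[:, 1]) / (x[0] * x[1]), lam2)   # ~ lambda_2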

Next we argue that provided the discrepancy is of the following form

δ^gμ\displaystyle\hat{\delta}g_{\mu} =αg(𝐱μ)+βσw2𝐱μ2\displaystyle=\alpha g({\mathbf{x}}_{\mu})+\beta\sigma_{w}^{2}||{\mathbf{x}}||_{\mu}^{2} (I.6)

then, within the EK limit, the target shift is also of the form of the r.h.s. of Eq. I.6, with coefficients $\alpha_{\Delta}$ and $\beta_{\Delta}$, and the target shift equations reduce to two coupled non-linear equations for $\alpha$ and $\beta$. Following the EK approximation, we replace all $\sum_{\mu}$ in the target shift equation with $n\int d\mu_{x}$ and obtain

Δg(𝐱)=nσ2𝑑μxK(𝐱,𝐱)δ^g(𝐱)+σw2𝐱𝖳[I2M1σw2nσ2𝑑μxδ^g(𝐱)𝐱𝐱𝖳]1𝐱σw2𝐱2\displaystyle\Delta g({\mathbf{x}})=-\frac{n}{\sigma^{2}}\int d\mu_{x^{\prime}}K\left({\mathbf{x}},{\mathbf{x}}^{\prime}\right)\hat{\delta}g({\mathbf{x}}^{\prime})+\sigma_{w}^{2}{\mathbf{x}}^{\mathsf{T}}\left[I-2M^{-1}\sigma_{w}^{2}\frac{n}{\sigma^{2}}\int d\mu_{x^{\prime}}\hat{\delta}g({\mathbf{x}}^{\prime}){\mathbf{x}}^{\prime}{\mathbf{x}}^{\prime\mathsf{T}}\right]^{-1}{\mathbf{x}}-\sigma_{w}^{2}\norm{{\mathbf{x}}}^{2} (I.7)

Next we note that the iji\neq j element of the matrix 𝑑μxg(𝐱)𝐱𝐱𝖳\int d\mu_{x}g({\mathbf{x}}){\mathbf{x}}{\mathbf{x}}^{\mathsf{T}} is given by

𝑑μxg(𝐱)xixj=𝑑μx((𝐰𝐱)2σw2𝐱2)xixj=2d2wiwj\displaystyle\int d\mu_{x}g({\mathbf{x}})x_{i}x_{j}=\int d\mu_{x}\left(\left({\mathbf{w}}_{*}\cdot{\mathbf{x}}\right)^{2}-\sigma_{w}^{2}\norm{{\mathbf{x}}}^{2}\right)x_{i}x_{j}=2d^{-2}w_{i}^{*}w_{j}^{*} (I.8)

whereas for i=ji=j we obtain

𝑑μxg(𝐱)xixi\displaystyle\int d\mu_{x}g({\mathbf{x}})x_{i}x_{i} =3d2((wi)2σw2)+jid2((wj)2σw2)\displaystyle=3d^{-2}((w^{*}_{i})^{2}-\sigma_{w}^{2})+\sum_{j\neq i}d^{-2}((w^{*}_{j})^{2}-\sigma_{w}^{2}) (I.9)
=3d2((wi)2σw2)+jd2((wj)2σw2)d2((wi)2σw2)\displaystyle=3d^{-2}((w^{*}_{i})^{2}-\sigma_{w}^{2})+\sum_{j}d^{-2}((w^{*}_{j})^{2}-\sigma_{w}^{2})-d^{-2}((w^{*}_{i})^{2}-\sigma_{w}^{2})
=2d2((wi)2σw2)\displaystyle=2d^{-2}((w^{*}_{i})^{2}-\sigma_{w}^{2})

Taking this together with the simpler term $\int d\mu\left({\mathbf{x}}\right)\frac{\beta}{\alpha}\sum_{i}x_{i}x_{i}{\mathbf{x}}{\mathbf{x}}^{\mathsf{T}}=\frac{\beta(d+2)}{\alpha d^{2}}I$, we obtain

\Delta g({\mathbf{x}})=-\frac{n}{\sigma^{2}}\int d\mu_{x^{\prime}}K({\mathbf{x}},{\mathbf{x}}^{\prime})\left[\alpha g({\mathbf{x}}^{\prime})+\beta\sigma_{w}^{2}\norm{{\mathbf{x}}^{\prime}}^{2}\right]+\sigma_{w}^{2}{\mathbf{x}}^{\mathsf{T}}\left[I-2\lambda_{2}\frac{n}{\sigma_{w}^{2}\sigma^{2}}\sigma_{w}^{2}d^{-2}\left[{\mathbf{w}}_{*}{\mathbf{w}}_{*}^{\mathsf{T}}-\sigma_{w}^{2}\left(1-\frac{\beta(d+2)}{2\alpha}\right)I\right]\right]^{-1}{\mathbf{x}}-\sigma_{w}^{2}\norm{{\mathbf{x}}}^{2} (I.10)

Consider the matrix (𝐰𝐰𝖳+bI)({\mathbf{w}}_{*}{\mathbf{w}}_{*}^{\mathsf{T}}+bI), appearing in the above denominator with b=(βσw2(d+2)2ασw2)b=\left(\frac{\beta\sigma_{w}^{2}(d+2)}{2\alpha}-\sigma_{w}^{2}\right), and note that

𝐱𝖳(𝐰𝐰𝖳+bI)n𝐱\displaystyle{\mathbf{x}}^{\mathsf{T}}\cdot({\mathbf{w}}_{*}{\mathbf{w}}_{*}^{\mathsf{T}}+bI)^{n}{\mathbf{x}} =(𝐰𝐱)2(𝐰2+b)nbn𝐰2+bn𝐱2\displaystyle=({\mathbf{w}}_{*}\cdot{\mathbf{x}})^{2}\frac{(\norm{{\mathbf{w}}_{*}}^{2}+b)^{n}-b^{n}}{\norm{{\mathbf{w}}_{*}}^{2}}+b^{n}\norm{{\mathbf{x}}}^{2} (I.11)

Plugging this equation into a Taylor expansion of the denominator of Eq. I.10 one finds that all the resulting terms are of the desired form of a linear superposition of g(𝐱)g({\mathbf{x}}) and 𝐱2\norm{{\mathbf{x}}}^{2}. Considering the first term on the r.h.s. of Eq. I.10, βx2\beta\norm{x}^{2} is already an eigenfunction of the kernel whereas g(𝐱)g({\mathbf{x}}) can be re-written as

g(𝐱)\displaystyle g({\mathbf{x}}) ijwiwjxixj+i(wi)2xi2σw2𝐱2\displaystyle\equiv\sum_{i\neq j}w^{*}_{i}w^{*}_{j}x_{i}x_{j}+\sum_{i}(w^{*}_{i})^{2}x_{i}^{2}-\sigma_{w}^{2}\norm{{\mathbf{x}}}^{2} (I.12)
=ijwiwjxixj+i((wi)2𝐰2d)xi2+(𝐰2dσw2)𝐱2\displaystyle=\sum_{i\neq j}w_{i}^{*}w_{j}^{*}x_{i}x_{j}+\sum_{i}\left((w_{i}^{*})^{2}-\frac{\norm{{\mathbf{w}}_{*}}^{2}}{d}\right)x_{i}^{2}+\left(\frac{\norm{{\mathbf{w}}_{*}}^{2}}{d}-\sigma_{w}^{2}\right)\norm{{\mathbf{x}}}^{2}

so that the first two terms on the r.h.s. are $\lambda_{2}$ eigenfunctions and the last one is a $\lambda_{0}$ eigenfunction. Summing these different contributions along with the aforementioned Taylor expansion, one finds that $\Delta g({\mathbf{x}})$ is indeed a linear superposition of $g({\mathbf{x}})$ and $\norm{{\mathbf{x}}}^{2}$.

Next we wish to write down the saddle-point equations for α\alpha and β\beta. For simplicity we focus on the case where g(𝐱)g({\mathbf{x}}) is chosen orthogonal to 𝐱2\norm{{\mathbf{x}}}^{2} under dμ(𝐱)d\mu({\mathbf{x}}), namely 𝐰2=dσw2\norm{{\mathbf{w}}_{*}}^{2}=d\sigma_{w}^{2}. Under this choice the self-consistent equations become

α\displaystyle\alpha =σ2nλ2+σ2n[1αcσw21αbc(11α𝐰2c1αbc)+nσ2αλ2]\displaystyle=\frac{\frac{\sigma^{2}}{n}}{\lambda_{2}+\frac{\sigma^{2}}{n}}\left[1-\frac{\alpha c\sigma_{w}^{2}}{1-\alpha bc}\left(\frac{1}{1-\frac{\alpha\norm{{\mathbf{w}}_{*}}^{2}c}{1-\alpha bc}}\right)+\frac{n}{\sigma^{2}}\alpha\lambda_{2}\right] (I.13)
β\displaystyle\beta =σ2nλ0+σ2n[αc1αbc(b+σw211α𝐰2c1αbc)nσ2βλ0]\displaystyle=-\frac{\frac{\sigma^{2}}{n}}{\lambda_{0}+\frac{\sigma^{2}}{n}}\left[\frac{\alpha c}{1-\alpha bc}\left(b+\sigma_{w}^{2}\frac{1}{1-\frac{\alpha\norm{{\mathbf{w}}_{*}}^{2}c}{1-\alpha bc}}\right)-\frac{n}{\sigma^{2}}\beta\lambda_{0}\right]

where the constant bb was defined above and c=nλ2σw2σ2c=\frac{n\lambda_{2}}{\sigma_{w}^{2}\sigma^{2}}.

Next we perform several straightforward algebraic manipulations with the aim of extracting their asymptotic behavior at large nn. Noting that cσw2=nσ2λ2c\sigma_{w}^{2}=\frac{n}{\sigma^{2}}\lambda_{2}, c𝐰2=2dnσ2λ2c\norm{{\mathbf{w}}_{*}}^{2}=2d\frac{n}{\sigma^{2}}\lambda_{2}, and αcb=nσ2(αλ22βλ0)\alpha cb=-\frac{n}{\sigma^{2}}(\alpha\lambda_{2}-2\beta\lambda_{0}) we have

\displaystyle\alpha =\frac{1}{\lambda_{2}+\frac{\sigma^{2}}{n}}\left[\frac{\sigma^{2}}{n}-\frac{\alpha\lambda_{2}}{1+\frac{n}{\sigma^{2}}(\alpha\lambda_{2}-2\beta\lambda_{0})}\left(\frac{1}{1-\frac{\alpha d\frac{n}{\sigma^{2}}\lambda_{2}}{1+\frac{n}{\sigma^{2}}(\alpha\lambda_{2}-2\beta\lambda_{0})}}\right)+\alpha\lambda_{2}\right] (I.14)
β\displaystyle\beta =1λ0+σ2n[2αλ0d+21+nσ2(αλ22βλ0)(β(d+2)2α1+11αdnσ2λ21+nσ2(αλ22βλ0))βλ0]\displaystyle=\frac{-1}{\lambda_{0}+\frac{\sigma^{2}}{n}}\left[\frac{2\alpha\frac{\lambda_{0}}{d+2}}{1+\frac{n}{\sigma^{2}}(\alpha\lambda_{2}-2\beta\lambda_{0})}\left(\frac{\beta(d+2)}{2\alpha}-1+\frac{1}{1-\frac{\alpha d\frac{n}{\sigma^{2}}\lambda_{2}}{1+\frac{n}{\sigma^{2}}(\alpha\lambda_{2}-2\beta\lambda_{0})}}\right)-\beta\lambda_{0}\right]

Further simplifications yield

α\displaystyle\alpha =1λ2+σ2n[σ2nαλ21+nσ2(αλ22βλ0αdλ2)+αλ2]\displaystyle=\frac{1}{\lambda_{2}+\frac{\sigma^{2}}{n}}\left[\frac{\sigma^{2}}{n}-\frac{\alpha\lambda_{2}}{1+\frac{n}{\sigma^{2}}(\alpha\lambda_{2}-2\beta\lambda_{0}-\alpha d\lambda_{2})}+\alpha\lambda_{2}\right] (I.15)
β\displaystyle\beta =1λ0+σ2n[2αλ0d+21+nσ2(αλ22βλ0)(β(d+2)2α+αnσ2dλ21+nσ2(αλ22βλ0αdλ2))βλ0]\displaystyle=\frac{-1}{\lambda_{0}+\frac{\sigma^{2}}{n}}\left[\frac{2\alpha\frac{\lambda_{0}}{d+2}}{1+\frac{n}{\sigma^{2}}(\alpha\lambda_{2}-2\beta\lambda_{0})}\left(\frac{\beta(d+2)}{2\alpha}+\frac{\alpha\frac{n}{\sigma^{2}}d\lambda_{2}}{1+\frac{n}{\sigma^{2}}(\alpha\lambda_{2}-2\beta\lambda_{0}-\alpha d\lambda_{2})}\right)-\beta\lambda_{0}\right]

noting that dλ2=2(λ0λ2)d\lambda_{2}=2(\lambda_{0}-\lambda_{2}) we find

α\displaystyle\alpha =1λ2+σ2n[σ2nαλ21nσ2(αλ2+2(β+α)λ0)+αλ2]\displaystyle=\frac{1}{\lambda_{2}+\frac{\sigma^{2}}{n}}\left[\frac{\sigma^{2}}{n}-\frac{\alpha\lambda_{2}}{1-\frac{n}{\sigma^{2}}(\alpha\lambda_{2}+2(\beta+\alpha)\lambda_{0})}+\alpha\lambda_{2}\right] (I.16)
β\displaystyle\beta =1λ0+σ2n[2αλ0d+21+nσ2(αλ22βλ0)(β(d+2)2α+αnσ2dλ21nσ2(αλ2+2(β+α)λ0))βλ0]\displaystyle=\frac{-1}{\lambda_{0}+\frac{\sigma^{2}}{n}}\left[\frac{2\alpha\frac{\lambda_{0}}{d+2}}{1+\frac{n}{\sigma^{2}}(\alpha\lambda_{2}-2\beta\lambda_{0})}\left(\frac{\beta(d+2)}{2\alpha}+\frac{\alpha\frac{n}{\sigma^{2}}d\lambda_{2}}{1-\frac{n}{\sigma^{2}}(\alpha\lambda_{2}+2(\beta+\alpha)\lambda_{0})}\right)-\beta\lambda_{0}\right]

The first equation above is linear in β\beta and yields in the large dd limit

β\displaystyle\beta =ααd(1α)+σ22λ0n\displaystyle=-\alpha-\frac{\alpha}{d(1-\alpha)}+\frac{\sigma^{2}}{2\lambda_{0}n} (I.17)

It can also be used to show that

αλ21nσ2(αλ2+2(β+α)λ0)\displaystyle\frac{\alpha\lambda_{2}}{1-\frac{n}{\sigma^{2}}(\alpha\lambda_{2}+2(\beta+\alpha)\lambda_{0})} =(1α)σ2n\displaystyle=(1-\alpha)\frac{\sigma^{2}}{n} (I.18)

which when placed in the second equation yields

β\displaystyle\beta =λ0λ0+σ2nnσ2β(αλ22βλ0)+2(1α)α1+nσ2(αλ22βλ0)\displaystyle=\frac{\lambda_{0}}{\lambda_{0}+\frac{\sigma^{2}}{n}}\frac{\frac{n}{\sigma^{2}}\beta(\alpha\lambda_{2}-2\beta\lambda_{0})+2(1-\alpha)\alpha}{1+\frac{n}{\sigma^{2}}(\alpha\lambda_{2}-2\beta\lambda_{0})} (I.19)

At large $n$, we expect $\alpha$ and $\beta$ to go to zero. Accordingly, to find the asymptotic decay to zero, one can approximate $\alpha(1-\alpha)\approx\alpha$, and similarly $\alpha/(1-\alpha)\approx\alpha$. This, along with the large $d$ limit, simplifies the equations to a quadratic equation in $\beta$:

\displaystyle\beta^{2}\frac{2\lambda_{0}n}{\sigma^{2}}\left(1-\frac{\lambda_{0}}{\lambda_{0}+\sigma^{2}/n}\right)+\beta\left(-1-2\frac{\lambda_{0}}{\lambda_{0}+\sigma^{2}/n}\right)+\frac{\lambda_{0}}{\lambda_{0}+\sigma^{2}/n}\frac{\sigma^{2}}{\lambda_{0}n}=0 (I.20)

which for σ2/nλ0\sigma^{2}/n\ll\lambda_{0} simplifies further into

2β23β+σ22λ0n\displaystyle 2\beta^{2}-3\beta+\frac{\sigma^{2}}{2\lambda_{0}n} =0\displaystyle=0 (I.21)

yielding

β\displaystyle\beta =418σ2λ0n\displaystyle=\frac{4}{18}\frac{\sigma^{2}}{\lambda_{0}n} (I.22)
α\displaystyle\alpha =518σ2λ0n\displaystyle=\frac{5}{18}\frac{\sigma^{2}}{\lambda_{0}n} (I.23)

We thus find that both $\alpha$ and $\beta$ are of the order of $\frac{\sigma^{2}/n}{2\lambda_{0}}=n^{-1}\frac{Md\sigma^{2}}{4\sigma_{w}^{4}}$. Hence $n$ scaling as $Md$ ensures good performance. This could have been anticipated since, for small yet finite $\sigma^{2}$, each data point can be seen as a soft constraint on the parameters of the DNN, and since the DNN contains $Md$ parameters, $n=O(Md)$ should provide enough data to fix the student's parameters close to the teacher's.