
On the Convergence Analysis of Over-Parameterized Variational Autoencoders: A Neural Tangent Kernel Perspective

Li Wang^1 ([email protected]) and Wei Huang^2 ([email protected])

^1 CSIRO Space and Astronomy, 26 Dick Perry Ave, Kensington, WA 6151, Australia
^2 RIKEN AIP, Tokyo, Japan
Abstract

Variational Auto-Encoders (VAEs) have emerged as powerful probabilistic models for generative tasks. However, their convergence properties have not been rigorously established. Proving convergence is inherently difficult due to the highly non-convex nature of the training objective and the use of a Stochastic Neural Network (SNN) within VAE architectures. This paper addresses these challenges by characterizing the optimization trajectory of the SNNs utilized in VAEs through the lens of Neural Tangent Kernel (NTK) techniques, which govern the optimization and generalization behaviors of ultra-wide neural networks. We provide a mathematical proof of VAE convergence under mild assumptions, thus advancing the theoretical understanding of VAE optimization dynamics. Furthermore, we establish a novel connection between the optimization problem faced by over-parameterized SNNs and the Kernel Ridge Regression (KRR) problem. Our findings not only contribute to the theoretical foundation of VAEs but also open new avenues for investigating the optimization of generative models using advanced kernel methods. Our theoretical claims are verified by experimental simulations.

keywords:
Variational Auto-encoder, Stochastic Neural Network, Neural Tangent Kernel

1 Introduction

Variational Autoencoders (VAEs) [1] have garnered significant interest and have been applied across a diverse array of applications, ranging from image generation and style transfer [2, 3, 4] to natural language processing [5]. VAEs aim to learn a compressed yet structured latent representation of input data by maximizing the Evidence Lower BOund (ELBO), thereby facilitating the reconstruction of the original data. Unlike traditional autoencoders [6, 7], VAEs focus on learning the distribution of latent codes, enabling the generation of new samples from this distribution. The dimensionality of the latent space is dictated by data complexity, model objectives, and task-specific needs, ranging from a few to several thousand dimensions. Larger latent spaces can encode more information and provide better disentanglement learning [8, 9, 10], a finding that our experiments also support (see Figures 3 and 4). Concurrently, there is an intuitive belief that a larger latent space may pose challenges to training, such as issues with non-convergence or slow convergence rates.

On the other hand, despite the widespread application of VAEs, our theoretical understanding of the training dynamics remains limited. Investigating the optimization of Deep VAEs theoretically is notoriously challenging, as training deep neural networks involves non-convex optimization of a high-dimensional objective function. The complexity of this optimization problem is further exacerbated by the incorporation of stochastic neural networks (SNNs) in VAEs, which introduces additional stochasticity into the training process. Several studies have attempted to shed light on this problem from different perspectives. For instance, He et al. [11] conducted an empirical investigation of the learning dynamics of deep VAEs to study the posterior collapse. Lucas et al. [12] presented a simple and intuitive analysis of linear VAEs to explain the same collapse. Moreover, Koehler et al. [13] analyzed the training dynamics, offering insights into implicit bias convergence for linear VAEs. However, much of the existing research either leans heavily on empirical simulations or centers around linear VAEs, leaving the broader success of VAEs insufficiently explained.

To address concerns about the convergence in high-dimensional latent spaces in VAEs, in this work, we introduce a novel convergence analysis for VAE training dynamics, specifically when an over-parameterized stochastic neural network serves as its model. While the convergence properties of deterministic neural networks have been extensively explored [14, 15, 16, 17, 18, 19, 20, 21, 22], the convergence behavior of SNNs in VAE remains less understood. Our approach leverages non-asymptotic analysis of dynamical systems, allowing us to examine the behavior of over-parameterized VAEs during training. We demonstrate that the convergence outcome aligns with solving a kernel ridge regression under certain mild assumptions. To our knowledge, this is the first rigorous analysis of the convergence behavior of over-parameterized VAEs. We further validate our theoretical insights through experiments on various image generation tasks. In summary, our key contributions are as follows:

  • We establish a non-asymptotic convergence analysis for over-parameterized SNNs. Specifically, we investigate the convergence rate of the optimization algorithm used to train the VAE.

  • We link the optimization of over-parameterized SNNs with kernel ridge regression, shedding light on the regularization effects of the KL penalty in VAEs.

  • Theoretically, we prove that VAEs with high-dimensional latent spaces can converge, providing a theoretical foundation for employing large latent spaces in VAEs to capture more information.

2 Related Work

Convergence Analysis of Over-parameterized Neural Networks

The convergence analysis of over-parameterized neural networks (NNs) has become an important topic in deep learning research. In a seminal paper, Jacot et al. [14] showed that the optimization behavior of infinitely-wide NNs can be described using a kernel function called the neural tangent kernel (NTK). This kernel simplifies the optimization dynamics into a linear system that is more tractable. The NTK provides a way to explicitly characterize the dynamics of the neural network during training and to analyze its convergence behavior [23, 24]. Additionally, a series of studies [16, 17, 25, 26, 15, 20] have presented convergence results for over-parameterized networks through a non-asymptotic lens. Furthermore, Rademacher complexity analysis has characterized the generalization ability of trained over-parameterized NNs on unseen data [27, 26]. In addition, the NTK has been widely applied to different deep network structures, aiding in understanding their optimization dynamics. This includes convolutional networks [25], orthogonally initialized NNs [18], graph neural networks [28], active learning [29], transformers [30], neural architecture search [31], and GANs [32].

Among existing studies of the training dynamics of over-parameterized networks, the works of [33, 34, 35, 36] are the most aligned with our research. Nguyen et al. [33] explored the gradient dynamics of over-parameterized auto-encoders (AEs) and provided a rigorous proof for the linear convergence of gradient descent in the context of AEs. However, their techniques cannot be directly applied to variational auto-encoders (VAEs) because of the additional randomness introduced by stochastic neural networks. In a separate study, Liu et al. [34] examined the predictive variance of stochastic neural networks. They demonstrated that as the width of an optimized stochastic neural network approaches infinity, its predictive variance on the training set diminishes to zero. While their work sheds light on the behavior of stochastic neural networks in the infinite-width limit, it does not establish the convergence of infinitely-wide neural networks, which is one of the most desirable properties when studying an NN. Two other notable studies [35, 36] approached SNNs within the PAC-Bayes framework, leveraging the NTK. However, the SNN structure in our VAE research differs from the PAC-Bayes framework, particularly in how stochasticity is introduced in the latent layer.

Theoretical study of VAEs

While VAEs have been successfully applied in various domains, their theoretical properties are still not fully understood. Several recent works have attempted to provide a theoretical understanding of VAEs. For instance, recent works by [37, 38, 39] appealed to information theory, deriving variational bounds on the mutual information between the input and the latent variable and relating them to the objective function. One work by Lucas et al. [12] provided an intuitive explanation for the posterior collapse phenomenon in VAEs. They analyzed linear VAEs and showed that posterior collapse can be attributed to the low-rank structure of the encoder. In addition, Kumar et al. [40] presented an approximation of the VAE objective function consisting of a deterministic auto-encoding objective plus analytic regularizers that depend on the Hessian or Jacobian of the decoding model. Nakagawa et al. [41] provided a quantitative understanding of VAE properties through differential-geometric and information-theoretic interpretations of the VAE. Moreover, [42, 43, 44] do not address the optimization dynamics but instead study the optimization landscape. In contrast, our work studies the training dynamics of over-parameterized VAEs with non-linear activations, emphasizing the challenges posed by the non-linearity and the complicated optimization behavior.

3 Problem Setup and Preliminary

3.1 Notation

In this work, we adopt a standard notation to represent vectors, matrices, and scalars. Specifically, we use bold-faced letters for vectors and matrices and non-bold letters for scalars. To denote the Euclidean norm of a vector or the spectral norm of a matrix, we use the notation $\|\cdot\|_{2}$. The Frobenius norm of a matrix is represented by $\|\cdot\|_{F}$. We use the notation $[n]=\{1,2,\ldots,n\}$ to represent the set of integers from 1 to $n$. Besides, we represent a matrix as a set of row vectors, i.e., ${\bf W}=[{\bf w}^{\top}_{1},{\bf w}^{\top}_{2},\dots,{\bf w}^{\top}_{m}]^{\top}$, where ${\bf w}_{r}$ with $r\in[m]$ is a column vector of the matrix. Finally, we denote the least eigenvalue of a matrix by $\lambda_{0}(\boldsymbol{\Theta})$, which is equivalent to $\lambda_{\min}(\boldsymbol{\Theta})$.

3.2 Variational Auto-encoder

A variational auto-encoder (VAE) [1], as a directed probabilistic graphical model (DPGM), is designed to learn a latent variable model. Its primary objective is to maximize the log-likelihood of the training data $\{{\bf x}_{i}\}^{n}_{i=1}$ via variational inference, where $n$ is the number of training samples. The VAE introduces a distribution $q_{\phi}({\bf z}|{\bf x})$ to approximate the intractable true posterior $p({\bf z}|{\bf x})$, where $\phi$ are the neural network parameters learned in the encoder. Then, the decoder takes ${\bf z}$ as input to generate ${\bf x}^{\prime}$ as a reconstruction of ${\bf x}$.

The common training objective of the VAE is to maximize the Evidence Lower Bound (ELBO), given by:

L_{elbo}=\frac{1}{n}\sum_{i=1}^{n}\mathbb{E}_{\mathbf{z}}[\log p_{\theta}({\bf x}^{\prime}_{i}|{\bf z})]-\mathrm{KL}(q_{\phi}({\bf z}|{\bf x}_{i})\,\|\,p({\bf z})), (1)

where ${\bf z}\sim q_{\phi}({\bf z}|{\bf x}_{i})$, and $\phi$ and $\theta$ represent the parameters in the encoder and decoder, respectively. The first term in the ELBO measures the reconstruction loss between the generated ${\bf x}^{\prime}$ and the original ${\bf x}$. The second term represents the Kullback-Leibler (KL) divergence between the approximate posterior $q_{\phi}({\bf z}|{\bf x})$ and the prior $p({\bf z})$, where $p({\bf z})$ is often chosen to be an isotropic multivariate Gaussian distribution.
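
To make the two terms of (1) concrete, the following is a minimal NumPy sketch (our own illustration, not part of the paper) of a single-sample negative-ELBO estimate, assuming a diagonal Gaussian posterior, a standard Gaussian prior, and a Gaussian decoder so that the reconstruction term reduces to a squared error; all function names are hypothetical.

```python
import numpy as np

def kl_to_standard_normal(mu, log_var):
    # Closed-form KL( N(mu, diag(exp(log_var))) || N(0, I) )
    return 0.5 * np.sum(np.exp(log_var) + mu ** 2 - 1.0 - log_var)

def negative_elbo(x, x_recon, mu, log_var):
    # Single-sample Monte Carlo estimate of -L_elbo for one data point:
    # squared-error reconstruction (Gaussian decoder) plus the KL term above.
    recon = 0.5 * np.sum((x_recon - x) ** 2)
    return recon + kl_to_standard_normal(mu, log_var)
```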

3.3 Stochastic Neural Network and Objective Function

Figure 1: Architecture of Variational Auto-Encoder.

Consider a stochastic neural network (SNN) $\mathbf{f}\in\mathbb{R}^{d}$, where $d$ is the input dimension. In the context of this work, our SNN is defined as follows:

\mathbf{f}({\bf x})=\frac{1}{\sqrt{m}}({\bf W}^{(d)})^{\top}\psi(\sigma({\bf z})),\quad{\bf z}\sim\mathcal{N}({\bf W}^{(\mu)}{\bf x}^{(e)},\mathrm{diag}({\bf W}^{(\sigma)}{\bf x}^{(e)})), (2)

where ${\bf x}^{(e)}\in\mathbb{R}^{d}$ is the encoded representation derived from the input $\mathbf{x}$, and ${\bf W}^{(\mu)},{\bf W}^{(\sigma)}\in\mathbb{R}^{m\times d}$ are weight matrices employed to construct the latent Gaussian representation. Here $m$ represents the width of the network, indicating the number of neurons, $\sigma(\cdot)$ is the non-linear activation function, $\psi(\cdot)$ is the decoder representation function, and ${\bf W}^{(d)}\in\mathbb{R}^{m\times d}$ is the linear weight matrix utilized in the final layer. A visual representation of the SNN under study is depicted in Figure 1.

In the construction of the latent representation, we employ the re-parametrization trick, a technique that allows for the backpropagation of gradients through random nodes. In particular, the latent variable can be expressed as:

{\bf z}={\bf W}^{(\mu)}{\bf x}^{(e)}+({\bf W}^{(\sigma)}\odot\boldsymbol{\zeta}){\bf x}^{(e)},\quad\boldsymbol{\zeta}\sim\mathcal{N}({\bf 0},{\bf I}), (3)

where $\mathbf{W}^{(\mu)}$ and $\mathbf{W}^{(\sigma)}$ represent the mean and variance weights, respectively. Besides, $\boldsymbol{\zeta}$ is a random variable drawn from a standard normal distribution.
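
For concreteness, here is a minimal NumPy sketch (our own illustration, not from the paper) of one stochastic forward pass through the SNN in (2), using the re-parameterized latent of (3); the callables sigma and psi stand for the activation and decoder representation functions, and the specific choices below are arbitrary.

```python
import numpy as np

def snn_forward(x_e, W_mu, W_sigma, W_d, sigma, psi, rng):
    """One stochastic forward pass of Eq. (2) with the re-parameterization of Eq. (3)."""
    m = W_mu.shape[0]
    zeta = rng.standard_normal(W_sigma.shape)        # zeta ~ N(0, I), same shape as W^(sigma)
    z = W_mu @ x_e + (W_sigma * zeta) @ x_e          # re-parameterized latent, shape (m,)
    return W_d.T @ psi(sigma(z)) / np.sqrt(m)        # decoder output f(x), shape (d,)

# Example usage (tanh activation, identity decoder representation):
rng = np.random.default_rng(0)
m, d = 512, 16
x_e = rng.standard_normal(d); x_e /= np.linalg.norm(x_e)
W_mu, W_d = rng.standard_normal((m, d)), rng.standard_normal((m, d))
W_sigma = 0.1 * np.ones((m, d))
out = snn_forward(x_e, W_mu, W_sigma, W_d, np.tanh, lambda h: h, rng)
```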

Given the structure of the SNN, our objective function considered in this work is defined as:

L=\frac{1}{n}\sum_{i=1}^{n}\big[\ell(\hat{\mathbf{f}}(\mathbf{x}_{i}),\mathbf{x}_{i})+\beta\,{\rm KL}\big(P(\mathbf{z}_{i}(t))\,\|\,P(\mathbf{z}_{i}(0))\big)\big], (4)

where $\hat{\mathbf{f}}(\mathbf{x}_{i})\triangleq\mathbb{E}_{\boldsymbol{\zeta}}[\mathbf{f}(\mathbf{x}_{i},\boldsymbol{\zeta})]$, and $\mathbf{z}_{i}(t)$ is the latent representation for input $\mathbf{x}_{i}$ at time $t$. Besides, $\beta$ is an adjustable hyperparameter that balances latent channel capacity and independence constraints with reconstruction accuracy [45]. The first term $\sum_{i=1}^{n}\ell(\hat{\mathbf{f}}(\mathbf{x}_{i}),\mathbf{x}_{i})$ is called the reconstruction loss. In this study, we utilize the mean squared error as our reconstruction loss, following seminal theoretical works [35, 17, 27, 26]. The second term ${\rm KL}(\cdot)$ is a Kullback–Leibler (KL) divergence, where the prior is the Gaussian distribution of the latent variable at initialization and the posterior is the distribution of the latent variable after training, $\mathbf{z}_{i}\sim\mathcal{N}({\bf W}^{(\mu)}{\bf x}^{(e)}_{i},\mathrm{diag}({\bf W}^{(\sigma)}{\bf x}^{(e)}_{i}))$. It is worth noting that our KL term is tailored to align with our theoretical analysis for constructing kernel ridge regression.

To optimize the objective function given by (4), we adopt a gradient descent rule:

\mathbf{W}^{(s)}(t+1)=\mathbf{W}^{(s)}(t)-\eta\frac{\partial{L}(t)}{\partial\mathbf{W}^{(s)}(t)},\quad\text{where }s\in\{\mu,\sigma,d\}, (5)

where $\eta$ is the learning rate. Note that while the weights in the encoder and decoder remain fixed, we specifically optimize the mean weights $\mathbf{W}^{(\mu)}$, the variance weights $\mathbf{W}^{(\sigma)}$, and the weights in the final layer $\mathbf{W}^{(d)}$. This optimization strategy is primarily adopted for the sake of theoretical simplicity. It is worth noting that this choice does not compromise or alter our final conclusions.
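
As a hedged illustration (not the paper's implementation), the PyTorch sketch below applies the gradient-descent rule (5) to objective (4), approximating the expectation over $\boldsymbol{\zeta}$ with Monte Carlo samples and writing the KL term in the quadratic form that the paper later derives in Eq. (12); the encoded inputs X_e, the targets X, and torch-compatible callables sigma and psi are assumed given.

```python
import torch

def train_snn(X, X_e, W_mu, W_sigma, W_d, sigma, psi, beta=0.1, eta=1.0, steps=1000, n_mc=32):
    """Full-batch gradient descent on W_mu, W_sigma, W_d only (Eq. (5)); encoder/decoder fixed."""
    n, d = X.shape
    m = W_mu.shape[0]
    W_mu0 = W_mu.detach().clone()
    params = [W_mu, W_sigma, W_d]
    for W in params:
        W.requires_grad_(True)
    for _ in range(steps):
        zeta = torch.randn(n_mc, m, d)                                   # MC samples of zeta
        Z = torch.einsum('nd,srd->snr', X_e, W_mu + W_sigma * zeta)      # latents z_{i,r}
        F_hat = (psi(sigma(Z)) @ W_d).mean(0) / m ** 0.5                 # E_zeta[f(x_i)], shape (n, d)
        loss = ((F_hat - X) ** 2).sum() / (2 * n) \
               + 0.5 * beta * ((W_mu - W_mu0) ** 2).sum()                # reconstruction + KL (as in (12))
        grads = torch.autograd.grad(loss, params)
        with torch.no_grad():
            for W, g in zip(params, grads):
                W -= eta * g                                             # plain gradient descent
    return W_mu, W_sigma, W_d
```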

4 Theoretical Results

In this section, we present our primary theoretical findings related to the optimization of the VAE's objective function. We start with the essential definitions and assumptions; the convergence result is then established. Finally, we prove the kernel ridge regression result through over-parameterization.

4.1 Definition and Assumptions

For the purpose of our optimization analysis, we introduce the concept of the neural tangent kernel for a stochastic neural network:

Definition 1 (Stochastic Neural Tangent Kernel).

The tangent kernels associated with the output function at the weights are defined as

\boldsymbol{\Theta}_{ik,jk^{\prime}}^{(s)}=\nabla_{{\bf W}^{(s)}}\hat{f}_{k}({\bf x}_{i};t)^{\top}\nabla_{{\bf W}^{(s)}}\hat{f}_{k^{\prime}}({\bf x}_{j};t)\in\mathbb{R},\quad\text{where }s\in\{\mu,\sigma,d\}, (6)

and $i,j\in[1,n]$ denote the indices of input samples while $k,k^{\prime}\in[1,d]$ represent the indices of output functions. Furthermore, the NTK for the entire network is defined as $\boldsymbol{\Theta}=\boldsymbol{\Theta}^{(\mu)}+\boldsymbol{\Theta}^{(\sigma)}+\boldsymbol{\Theta}^{(d)}$.

A few remarks on Definition 1 are in order. First, unlike standard (deterministic) neural networks, the VAE comprises two sets of parameters in the latent layer, namely ${\bf W}^{(\mu)}$ and ${\bf W}^{(\sigma)}$. Due to the reparameterization trick, gradient descent is executed on each of these parameters independently. Consequently, we observe two distinct tangent kernels corresponding to each parameter set. Second, the scenario with multiple outputs in variational autoencoder networks presents added complexity compared to networks with a single output [16, 26]. Given that the output dimension of the stochastic neural network is $d$, the neural tangent kernel is a matrix of size $nd\times nd$. As we delve deeper in the subsequent sections, it will become evident that the non-diagonal NTK entries across the output index are zero, and the diagonal entries remain consistent across the output index. This uniformity allows us to employ Kronecker products, facilitating the derivation of NTKs.
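
The following PyTorch sketch (our own illustration, with hypothetical helper names) estimates the three tangent kernels of Definition 1 for a finite-width SNN by differentiating the Monte Carlo average $\hat{\mathbf{f}}=\mathbb{E}_{\boldsymbol{\zeta}}[\mathbf{f}]$ with autograd and taking inner products of the flattened gradients; it returns each $\boldsymbol{\Theta}^{(s)}$ as an $nd\times nd$ matrix.

```python
import torch

def stochastic_ntk(X_e, W_mu, W_sigma, W_d, sigma, psi, n_mc=128):
    """Empirical stochastic NTK of Definition 1, one (nd x nd) kernel per weight matrix."""
    n, d = X_e.shape
    m = W_mu.shape[0]
    params = {'mu': W_mu, 'sigma': W_sigma, 'd': W_d}
    for W in params.values():
        W.requires_grad_(True)

    def f_hat(x):                                         # averaged output E_zeta[f(x)], shape (d,)
        zeta = torch.randn(n_mc, m, d)
        Z = (W_mu + W_sigma * zeta) @ x                   # (n_mc, m) latent pre-activations
        return (psi(sigma(Z)) @ W_d).mean(0) / m ** 0.5

    grads = {s: [] for s in params}
    for i in range(n):
        for k in range(d):
            gs = torch.autograd.grad(f_hat(X_e[i])[k], list(params.values()))
            for s, g in zip(params, gs):
                grads[s].append(g.reshape(-1))            # gradient of f_hat_k(x_i) w.r.t. W^(s)
    return {s: torch.stack(G) @ torch.stack(G).T for s, G in grads.items()}
```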

Next, we impose some technical conditions on the activation function, which are stated as follows:

Assumption 4.1 (Continuous and Partial Derivative Continuous).

The activation function $\sigma(x)$ and its partial derivative $\frac{\partial\sigma(x)}{\partial x}$ are continuous in $x$.

This assumption ensures that we can interchange the operations of integration and differentiation over the activation function. Subsequently, we present technical conditions on both the activation function and the decoder representation function:

Assumption 4.2 ($L$-Lipschitz and $\beta$-Smooth).

There exist constants $\beta$ and $L$ such that for any $x,x^{\prime}\in\mathbb{R}$:

\left|\sigma(x)-\sigma(x^{\prime})\right|\leq L\left|x-x^{\prime}\right|,\quad\left|\sigma^{\prime}(x)-\sigma^{\prime}(x^{\prime})\right|\leq\beta\left|x-x^{\prime}\right|,
\left|\psi(x)-\psi(x^{\prime})\right|\leq L\left|x-x^{\prime}\right|,\quad\left|\psi^{\prime}(x)-\psi^{\prime}(x^{\prime})\right|\leq\beta\left|x-x^{\prime}\right|.

These conditions are important in demonstrating the stability of the training process within the framework of the NTK.
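
As a concrete sanity check (our own illustration, not a claim made in the paper), the common choice $\sigma=\psi=\tanh$ satisfies Assumption 4.2 with $L=1$ and $\beta=1$, since $|\tanh^{\prime}(x)|=1-\tanh^{2}(x)\leq 1$ and $|\tanh^{\prime\prime}(x)|=|2\tanh(x)(1-\tanh^{2}(x))|\leq\frac{4}{3\sqrt{3}}<1$, so that

|\tanh(x)-\tanh(x^{\prime})|\leq|x-x^{\prime}|,\qquad|\tanh^{\prime}(x)-\tanh^{\prime}(x^{\prime})|\leq\tfrac{4}{3\sqrt{3}}\,|x-x^{\prime}|\leq|x-x^{\prime}|.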

4.2 Optimization analysis

For the sake of simplification, we focus on the optimization of the stochastic neural network as described in (2), concentrating solely on the reconstruction loss. This means we set aside the KL divergence term for the time being. Additionally, given that we adopt a squared loss without KL divergence, the objective function (4) reduces to:

L_{mse}=\frac{1}{2n}\sum_{i=1}^{n}\left\|\hat{\mathbf{f}}(\mathbf{x}_{i})-\mathbf{x}_{i}\right\|^{2}_{2}. (7)

Then the gradient flow dynamics of the output function $\hat{f}_{k}$ are governed by:

\frac{d\hat{f}_{k}({\bf x}_{i};t)}{dt}=\frac{1}{n}\sum_{j=1}^{n}\sum_{k^{\prime}=1}^{d}\left({x}_{j,k^{\prime}}-\hat{f}_{k^{\prime}}({\bf x}_{j};t)\right)\boldsymbol{\Theta}_{ik,jk^{\prime}}(t). (8)

Equation (8) implies that the dynamics of the output function are governed by the neural tangent kernels. Furthermore, as we will show later, the neural tangent kernels stay constant during the training process in the infinite-width limit. In this way, Equation (8) reduces to an ordinary differential equation (ODE):

\frac{d\hat{f}_{k}({\bf x}_{i};t)}{dt}=\frac{1}{n}\sum_{j=1}^{n}\sum_{k^{\prime}=1}^{d}\left({x}_{j,k^{\prime}}-\hat{f}_{k^{\prime}}({\bf x}_{j};t)\right)\boldsymbol{\Theta}^{(\infty)}_{ik,jk^{\prime}}, (9)

where we define the neural tangent kernel of an infinitely-wide SNN by:

\boldsymbol{\Theta}^{(\infty)}\triangleq\lim_{m\rightarrow\infty}\boldsymbol{\Theta}=\lim_{m\rightarrow\infty}\left(\boldsymbol{\Theta}^{(\mu)}+\boldsymbol{\Theta}^{(\sigma)}+\boldsymbol{\Theta}^{(d)}\right). (10)

To demonstrate the convergence result induced by Equation (9), we perform an in-depth concentration analysis. This analysis focuses on the convergence of stochastic neural networks in a non-asymptotic manner, i.e., with a large but finite width. We present our main result in the following theorem:

Theorem 1.

Assume the lowest eigenvalue of the limiting NTK is greater than zero, i.e., $\lambda_{0}(\boldsymbol{\Theta}^{\infty})>0$, and $\|\mathbf{x}^{(e)}_{i}\|_{2}=1$ for $i\in[n]$. Suppose the network's width satisfies $m=\Omega\left(\max\left\{\frac{n^{5}d^{3}}{\lambda_{0}^{4}\delta^{2}},\frac{n^{2}d^{2}}{\lambda_{0}}\log\frac{nd}{\delta}\right\}\right)$. Then, with probability at least $1-\delta$ over the random initialization, we have

L_{mse}(t)\leq\exp\left(-(\lambda_{0}/n)t\right)L_{mse}(0). (11)

The proof sketch of Theorem 1 will be given in Section 5. Theorem 1 establishes that if $m$ is large enough, the expected training error converges to zero at a linear rate. In particular, the least eigenvalue of the NTK governs the convergence rate.
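
The linear rate of Theorem 1 can be visualized with a small simulation (our own sketch, not from the paper): under a constant kernel, the ODE (9) has a closed-form solution in the eigenbasis of $\boldsymbol{\Theta}^{(\infty)}$, and the resulting loss curve can be compared against the bound $\exp(-(\lambda_{0}/n)t)L_{mse}(0)$.

```python
import numpy as np

def loss_curve_constant_ntk(Theta, X_vec, n, t_grid, f0=None):
    """Closed-form solution of the constant-kernel ODE (9) and the Theorem 1 bound.

    Theta : (nd, nd) symmetric PSD limiting kernel
    X_vec : flattened training targets of length nd
    n     : number of training samples
    """
    f0 = np.zeros_like(X_vec) if f0 is None else f0
    lam, U = np.linalg.eigh(Theta)
    r0 = U.T @ (f0 - X_vec)                              # residual f(0) - x in the eigenbasis
    L0 = 0.5 / n * np.sum(r0 ** 2)
    losses, bounds = [], []
    for t in t_grid:
        r_t = np.exp(-lam * t / n) * r0                  # each eigenmode decays at its own rate
        losses.append(0.5 / n * np.sum(r_t ** 2))
        bounds.append(np.exp(-lam.min() * t / n) * L0)   # exp(-(lambda_0/n) t) L(0)
    return np.array(losses), np.array(bounds)
```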

4.3 Regularization effect of KL divergence

By Theorem 1, we establish the global convergence of stochastic neural networks with a large width in the VAE. Building on this foundation, we further consider the full objective function (4), which incorporates an additional KL divergence term.

After a detailed calculation of the KL divergence for two Gaussian distributions, we simplify our analysis by making certain assumptions. Specifically, we assume that $\mathbf{W}^{(\sigma)}$ remains constant and select a prior $\mathbf{x}^{(e)}_{i}$ such that the objective function (4) is transformed to:

{L}(t)=\frac{1}{2n}\left\|\hat{\mathbf{f}}(\mathbf{X};t)-\mathbf{X}\right\|^{2}_{F}+\frac{\beta}{2}\left\|\mathbf{W}^{(\mu)}(t)-\mathbf{W}^{(\mu)}(0)\right\|^{2}_{F}. (12)
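
To make the reduction from (4) to (12) explicit, we sketch the Gaussian KL calculation under our reading of the stated assumptions (the variance weights are frozen, so both latent distributions share the covariance $\boldsymbol{\Sigma}_i=\mathrm{diag}(\mathbf{W}^{(\sigma)}\mathbf{x}^{(e)}_{i})$):

\mathrm{KL}\big(P(\mathbf{z}_{i}(t))\,\|\,P(\mathbf{z}_{i}(0))\big)=\tfrac{1}{2}\big(\Delta\mathbf{W}^{(\mu)}\mathbf{x}^{(e)}_{i}\big)^{\top}\boldsymbol{\Sigma}_i^{-1}\big(\Delta\mathbf{W}^{(\mu)}\mathbf{x}^{(e)}_{i}\big),\qquad\Delta\mathbf{W}^{(\mu)}\triangleq\mathbf{W}^{(\mu)}(t)-\mathbf{W}^{(\mu)}(0).

If, in addition, $\boldsymbol{\Sigma}_i=\mathbf{I}$ and the encoded inputs are chosen so that $\sum_i \mathbf{x}^{(e)}_{i}(\mathbf{x}^{(e)}_{i})^{\top}\propto\mathbf{I}$, then $\sum_i\|\Delta\mathbf{W}^{(\mu)}\mathbf{x}^{(e)}_{i}\|_2^2\propto\|\Delta\mathbf{W}^{(\mu)}\|_F^2$, and absorbing the remaining constants into $\beta$ yields the penalty term of (12).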

Building on this, we further analyze the regularization effect of the KL term when training VAEs and present our findings in the subsequent theorem:

Theorem 2.

Suppose $m\geq{\rm poly}({n},1/\lambda_{0},1/\delta,1/\mathcal{E})$ and the objective function follows the form (12). When we only optimize the mean weight $\mathbf{W}^{(\mu)}$, for any test input $\mathbf{x}_{te}\in\mathbb{R}^{d}$, with probability at least $(1-\delta)$ over the random initialization, we have

\hat{\mathbf{f}}({\bf x}_{te},\infty)=\boldsymbol{\Theta}^{(\mu)}({\bf x}_{te},{\bf X})(\boldsymbol{\Theta}^{(\mu)}({\bf X},{\bf X})+\beta{\bf I})^{-1}{\bf X}\pm\mathcal{E}, (13)

where $\mathcal{E}$ is the residual error term, which is upper bounded by $\mathcal{E}_{\rm init}+\mathcal{E}_{\Theta}\frac{\sqrt{n}}{\lambda_{0}+\beta}$ with $\|\hat{\mathbf{f}}\left(\boldsymbol{\theta}(0),\mathbf{x}_{te}\right)\|_{2}\leq\mathcal{E}_{\rm init}$ and $\|\boldsymbol{\Theta}^{\infty}-\boldsymbol{\Theta}(t)\|_{2}\leq\mathcal{E}_{\Theta}$.

The proof of Theorem 2 will be given in the Appendix. Note that the error term is bounded by the difference between the output function of the finite network and that of the infinitely-wide network. This difference is further decomposed into the initial difference and the difference accumulated during training. The latter can be bounded by $\frac{\sqrt{n}}{\lambda_{0}+\beta}\mathcal{E}_{\Theta}$, where $\sqrt{n}$ comes from the input and $\frac{1}{\lambda_{0}+\beta}$ results from the integration over the training time. Besides, the necessity of fixing the variance weight in Theorem 2 arises because we are seeking a closed-form solution under the NTK regime. Theorem 2 reveals the regularization effect of the KL divergence on the convergence of over-parameterized VAEs and makes a connection between the solution of training a VAE and kernel ridge regression.
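
The kernel ridge regression predictor on the right-hand side of (13) is straightforward to compute once the kernels are available; the NumPy sketch below (an illustrative helper, not from the paper) solves the regularized linear system directly.

```python
import numpy as np

def krr_predict(Theta_test_train, Theta_train_train, X_targets, beta):
    """Kernel ridge regression predictor of Eq. (13):
       f(x_te) = Theta(x_te, X) (Theta(X, X) + beta I)^{-1} X."""
    nd = Theta_train_train.shape[0]
    alpha = np.linalg.solve(Theta_train_train + beta * np.eye(nd), X_targets)
    return Theta_test_train @ alpha
```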

5 Proof Sketch

In this section, we outline the approach used to establish the convergence results for VAEs and provide proofs for Theorem 1 and Theorem 2. Our first step involves demonstrating that the NTKs, in the infinite-width limit, converge to deterministic kernels:

Lemma 1.

Consider a stochastic network of the form (2), with the initialization ${w}^{(\mu)}_{ij}\sim\mathcal{N}(0,1)$, ${w}^{(\sigma)}_{ij}=\sigma_{0}$, and ${w}^{(d)}_{ij}\sim\mathcal{N}(0,1)$. Then the tangent kernels at initialization (before training) follow the expressions in the infinite-width limit:

\lim_{m\rightarrow\infty}\boldsymbol{\Theta}^{(\mu)}_{ij}(0)=\mathbb{E}_{{\bf w}}\big[({\bf x}_{i}^{(e)})^{\top}{\bf x}^{(e)}_{j}[\hat{\psi}^{\prime}\hat{\sigma}^{\prime}({\bf w}^{\top}{\bf x}^{(e)}_{i})][\hat{\psi}^{\prime}\hat{\sigma}^{\prime}(\mathbf{w}^{\top}{\bf x}^{(e)}_{j})]\big]\otimes{\bf I}_{d\times d}, (14)
\lim_{m\rightarrow\infty}\boldsymbol{\Theta}^{(\sigma)}_{ij}(0)=\mathbb{E}_{{\bf w}}\big[[\hat{\psi}^{\prime}\hat{\mathbf{x}}^{\top}_{i}\hat{\sigma}^{\prime}({\bf w}^{\top}{\bf x}^{(e)}_{i})][\hat{\psi}^{\prime}\hat{\mathbf{x}}^{(e)}_{j}\hat{\sigma}^{\prime}(\mathbf{w}^{\top}{\bf x}^{(e)}_{j})]\big]\otimes{\bf I}_{d\times d},
\lim_{m\rightarrow\infty}\boldsymbol{\Theta}^{(d)}_{ij}(0)=\mathbb{E}_{\bf w}\big[[{\psi}(\hat{\sigma}({\bf w}^{\top}{\bf x}^{(e)}_{i}))][{\psi}(\hat{\sigma}({\bf w}^{\top}{\bf x}^{(e)}_{j}))]\big]\otimes{\bf I}_{d\times d},

where $\mathbf{w}\sim\mathcal{N}(\mathbf{0},\mathbf{I})$ and we define:

[\hat{\psi}^{\prime}\hat{\sigma}^{\prime}(\mathbf{w}^{\top}\mathbf{x}^{(e)}_{i})]\triangleq\mathbb{E}_{\boldsymbol{\zeta}}[{\psi}^{\prime}{\sigma}^{\prime}((\mathbf{w}+\sigma_{0}\boldsymbol{\zeta})^{\top}\mathbf{x}^{(e)}_{i})],
[\hat{\psi}^{\prime}\hat{\mathbf{x}}^{\top}_{i}\hat{\sigma}^{\prime}(\mathbf{w}^{\top}\mathbf{x}^{(e)}_{i})]\triangleq\mathbb{E}_{\boldsymbol{\zeta}}[{\psi}^{\prime}(\mathbf{x}_{i}\odot\boldsymbol{\zeta})^{\top}{\sigma}^{\prime}((\mathbf{w}+\sigma_{0}\boldsymbol{\zeta})^{\top}\mathbf{x}^{(e)}_{i})],
[{\psi}(\hat{\sigma}(\mathbf{w}^{\top}\mathbf{x}^{(e)}_{i}))]\triangleq\mathbb{E}_{\boldsymbol{\zeta}}[{\psi}({\sigma}((\mathbf{w}+\sigma_{0}\boldsymbol{\zeta})^{\top}\mathbf{x}^{(e)}_{i}))].
Proof of Lemma 1.

We first rewrite the expression for the stochastic neural network as follows:

\hat{\mathbf{f}}({\bf x})=\mathbb{E}_{\boldsymbol{\zeta}}\left[\frac{1}{\sqrt{m}}\sum_{r=1}^{m}{\bf w}_{r}^{(d)}\,\psi(\sigma(({\bf w}_{r}^{(\mu)}+{\bf w}_{r}^{(\sigma)}\odot\boldsymbol{\zeta}_{r})^{\top}{\bf x}^{(e)}))\right].

Then the derivative of the output function $\hat{f}_{k}(\mathbf{x}_{i})$ for $k\in[1,d]$ with respect to the parameters ${\bf w}_{r}^{(\mu)}$, ${\bf w}_{r}^{(\sigma)}$, and ${\bf w}_{r}^{(d)}$ for $r\in[1,m]$ can be expressed as:

\frac{\partial\hat{f}_{k}({\bf x}_{i})}{\partial{\bf w}^{(\mu)}_{r}}=\mathbb{E}_{\boldsymbol{\zeta}_{r}}\left[\frac{1}{\sqrt{m}}w^{(d)}_{r,k}\psi^{\prime}\sigma^{\prime}(z_{i,r}){\bf x}_{i}\right],
\frac{\partial\hat{f}_{k}({\bf x}_{i})}{\partial{\bf w}^{(\sigma)}_{r}}=\mathbb{E}_{\boldsymbol{\zeta}_{r}}\left[\frac{1}{\sqrt{m}}w^{(d)}_{r,k}\psi^{\prime}\sigma^{\prime}(z_{i,r}){\bf x}_{i}\odot\boldsymbol{\zeta}_{r}\right],
\frac{\partial\hat{f}_{k}({\bf x}_{i})}{\partial{\bf w}^{(d)}_{r}}=\mathbb{E}_{\boldsymbol{\zeta}_{r}}\left[\frac{1}{\sqrt{m}}\psi(\sigma(z_{i,r}))\boldsymbol{\delta}_{k}\right],

where we have interchanged integration and differentiation over the activation $\sigma(\cdot)$ by Assumption 4.2, and $\boldsymbol{\delta}_{k}\triangleq[\delta_{1,k},\delta_{2,k},\cdots,\delta_{d,k}]^{\top}\in\mathbb{R}^{d}$. We then calculate each NTK at initialization, i.e., $t=0$:

(1) The neural tangent kernel $\boldsymbol{\Theta}^{(\mu)}(0)$.

\boldsymbol{\Theta}^{(\mu)}_{ik,jk^{\prime}}(0)=\frac{({\bf x}^{(e)}_{i})^{\top}{\bf x}^{(e)}_{j}}{m}\sum_{r=1}^{m}\hat{\psi}^{\prime}\hat{\sigma}^{\prime}(z_{i,r})\hat{\psi}^{\prime}\hat{\sigma}^{\prime}({z}_{j,r})\left(w^{(d)}_{r,k}w^{(d)}_{r,k^{\prime}}\right),

where we define $\hat{\psi}^{\prime}\hat{\sigma}^{\prime}(z_{i,r})\triangleq\mathbb{E}_{\boldsymbol{\zeta}_{r}}\left[\psi^{\prime}\sigma^{\prime}(z_{i,r})\right]$ and $z_{i,r}=\langle\mathbf{w}^{(\mu)}_{r}+\mathbf{w}^{(\sigma)}_{r}\odot\boldsymbol{\zeta}_{r},\mathbf{x}^{(e)}_{i}\rangle$. For all pairs of $i,j,k,k^{\prime}$, $\boldsymbol{\Theta}^{(\mu)}_{ik,jk^{\prime}}(0)$ is the average of $m$ i.i.d. random variables. Because the entries $w^{(d)}_{r,k}$ are i.i.d. with zero mean, we know that $\mathbb{E}\left[({w}^{(d)}_{r,k})({w}^{(d)}_{r,k^{\prime}})\right]=0$ for $k\neq k^{\prime}$. Therefore, we have

\lim_{m\rightarrow\infty}\boldsymbol{\Theta}^{(\mu)}(0)=\lim_{m\rightarrow\infty}\boldsymbol{\Theta}^{(\mu)}_{ij}(0)\otimes{\bf I}_{d\times d}.

As a result, we conclude the proof:

\lim_{m\rightarrow\infty}\boldsymbol{\Theta}^{(\mu)}_{ij}(0)=\mathbb{E}_{{\bf w}}\big[({\bf x}_{i}^{(e)})^{\top}{\bf x}^{(e)}_{j}[\hat{\psi}^{\prime}\hat{\sigma}^{\prime}({\bf w}^{\top}{\bf x}^{(e)}_{i})][\hat{\psi}^{\prime}\hat{\sigma}^{\prime}(\mathbf{w}^{\top}{\bf x}^{(e)}_{j})]\big]\otimes{\bf I}_{d\times d}.

(2) Similarly, for the neural tangent kernel $\boldsymbol{\Theta}^{(\sigma)}(0)$:

\lim_{m\rightarrow\infty}\boldsymbol{\Theta}^{(\sigma)}(0)=\lim_{m\rightarrow\infty}\boldsymbol{\Theta}^{(\sigma)}_{ij}(0)\otimes{\bf I}_{d\times d}.

(3) The neural tangent kernel $\boldsymbol{\Theta}^{(d)}(0)$.

\boldsymbol{\Theta}^{(d)}_{ij,kk^{\prime}}=\frac{1}{m}\sum_{r=1}^{m}\psi(\sigma(z_{i,r}))\psi(\sigma(z_{j,r}))\delta_{kk^{\prime}}.

Again, this neural tangent kernel is the average of $m$ i.i.d. random variables. Therefore, we have $\lim_{m\rightarrow\infty}\boldsymbol{\Theta}^{(d)}_{ij}(0)=\mathbb{E}_{\bf w}\big[[{\psi}(\hat{\sigma}({\bf w}^{\top}{\bf x}^{(e)}_{i}))][{\psi}(\hat{\sigma}({\bf w}^{\top}{\bf x}^{(e)}_{j}))]\big]$. ∎

Lemma 1 establishes that the NTKs converge to deterministic kernels in the infinite-width limit. We then study the behavior of the tangent kernels at initialization under the ultra-wide condition, namely $m={\rm poly}(n,1/\lambda_{0},1/\delta)$. The following lemma demonstrates that if $m$ is large, then $\boldsymbol{\Theta}^{(\mu)}(0)$, $\boldsymbol{\Theta}^{(\sigma)}(0)$, and $\boldsymbol{\Theta}^{(d)}(0)$ have a lower bound on the smallest eigenvalue with high probability.

Lemma 2 (NTK at initialization).

If $m=\Omega\left(\frac{n^{2}d^{2}}{\lambda_{0}}\log\frac{nd}{\delta}\right)$, while ${w}^{(\mu)}_{ij}$, ${w}^{(\sigma)}_{ij}$, and ${w}^{(d)}_{ij}$ are initialized as in Lemma 1, then with probability at least $1-\delta$ over the initialization of weights, we have

\left\|\boldsymbol{\Theta}^{(\mu)}(0)+\boldsymbol{\Theta}^{(\sigma)}(0)+\boldsymbol{\Theta}^{(d)}(0)-\boldsymbol{\Theta}^{\infty}\right\|_{2}\leq\lambda_{0}/4, (15)
\left\|\boldsymbol{\Theta}^{(\mu)}(0)+\boldsymbol{\Theta}^{(\sigma)}(0)+\boldsymbol{\Theta}^{(d)}(0)\right\|_{2}\geq 3\lambda_{0}/{4}.
Proof of Lemma 2.

The proof follows from a standard concentration bound. By Lemma 1 we have shown that each entry of the neural tangent kernel is an average of $m$ i.i.d. random variables. Then, by Hoeffding's inequality for sub-Gaussian variables, we know that

\left|\boldsymbol{\Theta}^{(\mu)}_{ik,jk^{\prime}}(0)-\lim_{m\rightarrow\infty}\boldsymbol{\Theta}^{(\mu)}_{ik,jk^{\prime}}\right|\leq\sqrt{\frac{\log(2/\delta^{\prime})}{2m}}

holds with probability at least $(1-\delta^{\prime})$. Because the NTK matrix is of size $nd\times nd$, we then apply a union bound over all $i,j\in[n]$ and $k,k^{\prime}\in[d]$. By setting $\delta^{\prime}=\delta/(n^{2}d^{2})$, we obtain that

\left|\boldsymbol{\Theta}^{(\mu)}_{ik,jk^{\prime}}(0)-\lim_{m\rightarrow\infty}\boldsymbol{\Theta}^{(\mu)}_{ik,jk^{\prime}}\right|\leq\sqrt{\frac{\log(2n^{2}d^{2}/\delta)}{2m}}.

Then, by matrix perturbation theory, we have

\left\|\boldsymbol{\Theta}^{(\mu)}(0)-\lim_{m\rightarrow\infty}\boldsymbol{\Theta}^{(\mu)}\right\|^{2}_{2}\leq\left\|\boldsymbol{\Theta}^{(\mu)}(0)-\lim_{m\rightarrow\infty}\boldsymbol{\Theta}^{(\mu)}\right\|^{2}_{F}\leq\sum_{i,j,k,k^{\prime}}\left|\boldsymbol{\Theta}^{(\mu)}_{ik,jk^{\prime}}(0)-\lim_{m\rightarrow\infty}\boldsymbol{\Theta}^{(\mu)}_{ik,jk^{\prime}}\right|^{2}
=O\left(\frac{n^{2}d^{2}\log(nd/\delta)}{m}\right).

Similarly, applying the above argument to $\boldsymbol{\Theta}^{(\sigma)}$ and $\boldsymbol{\Theta}^{(d)}$ yields the same result with little modification. Thus, by Hoeffding's inequality and a union bound over the matrix size, we know that the following inequalities hold with probability at least $(1-\delta)$:

\left\|\boldsymbol{\Theta}^{(\sigma)}(0)-\lim_{m\rightarrow\infty}\boldsymbol{\Theta}^{(\sigma)}\right\|^{2}_{2}\leq O\left(\frac{n^{2}d^{2}\log(nd/\delta)}{m}\right),
\left\|\boldsymbol{\Theta}^{(d)}(0)-\lim_{m\rightarrow\infty}\boldsymbol{\Theta}^{(d)}\right\|^{2}_{2}\leq O\left(\frac{n^{2}d^{2}\log(nd/\delta)}{m}\right).

Finally, by the triangle inequality, we arrive at:

\left\|\boldsymbol{\Theta}^{(\mu)}(0)+\boldsymbol{\Theta}^{(\sigma)}(0)+\boldsymbol{\Theta}^{(d)}(0)-\boldsymbol{\Theta}^{(\infty)}\right\|_{2}\leq\frac{\lambda_{0}}{4}.

On the other hand, we can achieve the lower bound by triangle inequality:

\left\|\boldsymbol{\Theta}^{(\mu)}(0)+\boldsymbol{\Theta}^{(\sigma)}(0)+\boldsymbol{\Theta}^{(d)}(0)\right\|_{2}\geq\left\|\boldsymbol{\Theta}^{(\infty)}\right\|_{2}-\left\|\boldsymbol{\Theta}^{(\mu)}(0)+\boldsymbol{\Theta}^{(\sigma)}(0)+\boldsymbol{\Theta}^{(d)}(0)-\boldsymbol{\Theta}^{(\infty)}\right\|_{2}\geq\frac{3\lambda_{0}}{4}.

We finalize the proof by setting $m=\Omega\left(\frac{n^{2}d^{2}}{\lambda_{0}}\log\frac{nd}{\delta}\right)$. ∎

Lemma 2 completes the first step of our proof strategy, which states that if the width mm is large enough, then the neural tangent kernel of SNN at initialization before training is close to the limiting kernel and is positive definite.
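
The concentration behind Lemma 2 is easy to probe numerically; the NumPy sketch below (our own illustration, with tanh and the identity chosen arbitrarily for $\sigma$ and $\psi$) estimates one entry $\boldsymbol{\Theta}^{(\mu)}_{ij}(0)$ over repeated random initializations and reports its standard deviation across widths, which should shrink roughly like $1/\sqrt{m}$.

```python
import numpy as np

def theta_mu_entry(x_i, x_j, m, sigma0, sigma, dsigma, psi_prime, rng, n_mc=256):
    """Monte Carlo estimate of Theta^(mu)_{ij}(0) for one output index (cf. Lemma 1)."""
    d = x_i.size
    w = rng.standard_normal((m, d))              # rows w_r^(mu) ~ N(0, I)
    wd = rng.standard_normal(m)                  # one column of W^(d)
    def avg_factor(x):                           # E_zeta[ psi'(sigma(z)) sigma'(z) ] per neuron
        zeta = rng.standard_normal((n_mc, m, d))
        z = np.einsum('smd,d->sm', w[None] + sigma0 * zeta, x)
        return (psi_prime(sigma(z)) * dsigma(z)).mean(0)
    return (x_i @ x_j) * np.mean(avg_factor(x_i) * avg_factor(x_j) * wd ** 2)

rng = np.random.default_rng(0)
d = 8
x_i, x_j = rng.standard_normal(d), rng.standard_normal(d)
x_i, x_j = x_i / np.linalg.norm(x_i), x_j / np.linalg.norm(x_j)
for m in (64, 256, 1024):
    vals = [theta_mu_entry(x_i, x_j, m, 0.1, np.tanh,
                           lambda z: 1 - np.tanh(z) ** 2,   # sigma'
                           lambda h: np.ones_like(h),       # psi' for psi = identity
                           rng) for _ in range(20)]
    print(m, np.std(vals))                                  # expected to decay ~ 1/sqrt(m)
```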

However, a challenge arises due to the time-dependent nature of the NTKs: these matrices evolve during gradient descent training. To address this problem, we establish a lemma stating that if the weights during training stay close to their initialization, then the NTKs during training remain close to the deterministic kernel $\boldsymbol{\Theta}^{(\infty)}$. Moreover, these NTKs maintain a lower bound on their smallest eigenvalue throughout gradient descent training:

Lemma 3.

Suppose that $\|\mathbf{x}^{(e)}_{i}\|_{2}=1$, and at initialization $\left\|{\bf W}^{(\mu)}(0)\right\|_{F}\leq c_{\mu,2}\sqrt{m}$, $\left\|{\bf W}^{(\sigma)}(0)\right\|_{F}\leq c_{\sigma,2}\sqrt{m}$, $\left\|{\bf w}_{k}^{(d)}(0)\right\|_{2}\leq c_{d,2}\sqrt{m}$, and $\left\|{\bf w}_{k}^{(d)}(0)\right\|_{4}\leq c_{d,4}m^{1/4}$ for $k\in[d]$. If the weights at a training step $t$ satisfy $\left\|{\bf w}^{(\mu)}_{r}(t)-{\bf w}^{(\mu)}_{r}(0)\right\|_{2}\triangleq R_{\mu}\leq\frac{c_{1}\lambda_{0}}{n\sqrt{d}}$, $\left\|{\bf w}^{(\sigma)}_{r}(t)-{\bf w}^{(\sigma)}_{r}(0)\right\|_{2}\triangleq R_{\sigma}\leq\frac{c_{2}\lambda_{0}}{n\sqrt{d}}$, and $\left\|{\bf w}^{(d)}_{r}(t)-{\bf w}^{(d)}_{r}(0)\right\|_{2}\triangleq R_{d}\leq\frac{c_{3}\lambda_{0}}{n\sqrt{d}}$, where $c_{1}$, $c_{2}$, and $c_{3}$ are constants, then with probability at least $1-\delta$ over the random initialization, we have

\left\|\boldsymbol{\Theta}^{(\mu)}(t)+\boldsymbol{\Theta}^{(\sigma)}(t)+\boldsymbol{\Theta}^{(d)}(t)-\boldsymbol{\Theta}^{\infty}\right\|_{2}\leq{\lambda_{0}}/{2},\qquad\left\|\boldsymbol{\Theta}^{(\mu)}(t)+\boldsymbol{\Theta}^{(\sigma)}(t)+\boldsymbol{\Theta}^{(d)}(t)\right\|_{2}\geq{\lambda_{0}}/{2}. (16)
Proof.

(1) We first analyze $\boldsymbol{\Theta}^{(\mu)}(t)$:

\boldsymbol{\Theta}^{(\mu)}_{ik,jk^{\prime}}(t)=\frac{({\bf x}^{(e)}_{i})^{\top}{\bf x}^{(e)}_{j}}{m}\sum_{r=1}^{m}\hat{\psi}^{\prime}\hat{\sigma}^{\prime}(z_{i,r})\hat{\psi}^{\prime}\hat{\sigma}^{\prime}(z_{j,r})w^{(d)}_{r,k}(t)w^{(d)}_{r,k^{\prime}}(t).

Now we bound the distance between $\boldsymbol{\Theta}^{(\mu)}_{ik,jk}(t)$ and $\boldsymbol{\Theta}^{(\mu)}_{ik,jk}(0)$ through the following inequality:

\left|\boldsymbol{\Theta}_{ik,jk}^{(\mu)}(t)-\boldsymbol{\Theta}_{ik,jk}^{(\mu)}(0)\right|
\overset{(a)}{\leq}\frac{1}{m}\left|(\mathbf{x}^{(e)}_{i})^{\top}\mathbf{x}^{(e)}_{j}\right|\left|\sum_{r=1}^{m}(w^{(d)}_{r,k}(0))^{2}\left[\hat{\psi}^{\prime}\hat{\sigma}^{\prime}\left(z_{i,r}(t)\right)\hat{\psi}^{\prime}\hat{\sigma}^{\prime}\left(z_{j,r}(t)\right)-\hat{\psi}^{\prime}\hat{\sigma}^{\prime}\left(z_{i,r}(0)\right)\hat{\psi}^{\prime}\hat{\sigma}^{\prime}\left(z_{j,r}(0)\right)\right]\right|
+\frac{1}{m}\left|(\mathbf{x}^{(e)}_{i})^{\top}\mathbf{x}^{(e)}_{j}\right|\left|\sum_{r=1}^{m}\left(w^{(d)}_{r,k}(t)^{2}-w^{(d)}_{r,k}(0)^{2}\right)\hat{\psi}^{\prime}\hat{\sigma}^{\prime}\left(z_{i,r}(t)\right)\hat{\psi}^{\prime}\hat{\sigma}^{\prime}\left(z_{j,r}(t)\right)\right|
\overset{(b)}{\leq}\frac{2\beta L^{3}}{m}\sum_{r=1}^{m}w^{(d)}_{r,k}(0)^{2}\left\|\mathbf{w}^{(\mu)}_{r}(t)-\mathbf{w}^{(\mu)}_{r}(0)\right\|_{2}+\frac{2\beta L^{4}}{m}\sum_{r=1}^{m}w^{(d)}_{r,k}(0)^{2}\left\|\mathbf{w}^{(\mu)}_{r}(t)-\mathbf{w}^{(\mu)}_{r}(0)\right\|_{2}
+\frac{L^{4}}{m}\sum_{r=1}^{m}\left|w^{(d)}_{r,k}(t)^{2}-w^{(d)}_{r,k}(0)^{2}\right|\leq\left(\frac{4\beta L^{3}}{m}+\frac{4\beta L^{4}}{m}\right)c^{2}_{d,4}\sqrt{m}R_{\mu}\sqrt{m}+\frac{3L^{4}}{m}c_{d,2}R_{d}m,

where $(a)$ follows from the triangle inequality, and $(b)$ follows from the assumption that $\|\mathbf{x}^{(e)}\|_{2}=1$ together with the $L$-Lipschitzness and $\beta$-smoothness of the activations $\sigma(\cdot)$ and $\psi(\cdot)$. In particular, we have used the following inequalities:

\hat{\sigma}^{\prime}(z_{i,r}(t))-\hat{\sigma}^{\prime}(z_{i,r}(0))=\mathbb{E}_{\boldsymbol{\zeta}_{r}}\left[\sigma^{\prime}(z_{i,r}(t))-\sigma^{\prime}(z_{i,r}(0))\right]\leq\beta\left\|\mathbf{w}^{(\mu)}_{r}(t)-\mathbf{w}^{(\mu)}_{r}(0)\right\|_{2},
\hat{\psi}^{\prime}(\sigma(z_{i,r}(t)))-\hat{\psi}^{\prime}(\sigma(z_{i,r}(0)))\leq\beta\mathbb{E}_{\boldsymbol{\zeta}_{r}}\left[\sigma(z_{i,r}(t))-\sigma(z_{i,r}(0))\right]\leq\beta L\left\|\mathbf{w}^{(\mu)}_{r}(t)-\mathbf{w}^{(\mu)}_{r}(0)\right\|_{2}.

Summing over all entries of the matrix, we can bound the perturbation:

\left\|\boldsymbol{\Theta}^{(\mu)}(t)-\boldsymbol{\Theta}^{(\mu)}(0)\right\|_{2}\leq\sqrt{\sum_{i,j,k}\left|\boldsymbol{\Theta}_{ik,jk}^{(\mu)}(t)-\boldsymbol{\Theta}_{ik,jk}^{(\mu)}(0)\right|^{2}}
\leq\left(4\beta(L^{3}+L^{4})c^{2}_{d,4}R_{\mu}+3L^{4}c_{d,2}R_{d}\right)n\sqrt{d}.

Finally, due to the conditions $R_{\mu}\leq\frac{c_{1}\lambda_{0}}{n\sqrt{d}}$ and $R_{d}\leq\frac{c_{3}\lambda_{0}}{n\sqrt{d}}$, we have

\left\|\boldsymbol{\Theta}^{(\mu)}(t)-\boldsymbol{\Theta}^{(\mu)}(0)\right\|_{2}\leq{\lambda_{0}}/{12}.

(2) Similarly, we have:

\left\|\boldsymbol{\Theta}^{(\sigma)}(t)-\boldsymbol{\Theta}^{(\sigma)}(0)\right\|_{2}\leq{\lambda_{0}}/{12}.

(3) Finally,

\left|\boldsymbol{\Theta}_{ik,jk}^{(d)}(t)-\boldsymbol{\Theta}_{ik,jk}^{(d)}(0)\right|\leq\frac{\beta L}{m}\sum_{r=1}^{m}\left\|\mathbf{w}^{(\mu)}_{r}(t)-\mathbf{w}^{(\mu)}_{r}(0)\right\|_{2}\left\|\mathbf{w}^{(\mu)}_{r}(t)+\mathbf{w}^{(\mu)}_{r}(0)\right\|_{2}
\leq\frac{2\beta L}{m}(c_{\mu,2}+R_{\mu})\sqrt{m}R_{\mu}\sqrt{m}.

With all the inequalities at hand, we conclude the proof:

\left\|\boldsymbol{\Theta}^{(\mu)}(t)+\boldsymbol{\Theta}^{(\sigma)}(t)+\boldsymbol{\Theta}^{(d)}(t)-\boldsymbol{\Theta}^{\infty}\right\|_{2}\leq\frac{3\lambda_{0}}{12}+\frac{\lambda_{0}}{4}=\frac{\lambda_{0}}{2}.

Lemma 3 demonstrates that if the change of the weights is bounded, then the tangent kernel matrix stays close to its expectation. The next lemma shows that the changes of the weights during training are indeed bounded when the NTK is close to the limiting NTK:

Lemma 4.

Suppose $\lambda_{0}(t)\geq\frac{\lambda_{0}}{2}$ for $0<t<T$. Then,

\left\|{\bf w}^{(s)}_{r}(t)-{\bf w}^{(s)}_{r}(0)\right\|_{2}\leq\frac{\left\|{\bf X}-\hat{\mathbf{f}}({\bf X};0)\right\|_{F}\sqrt{n}d}{\sqrt{m}\lambda_{0}}=R^{\prime}_{s},\quad\text{where }s\in\{\mu,\sigma,d\}. (17)
Proof of Lemma 4.

The dynamics of the loss can be calculated as

\frac{d}{dt}\mathcal{L}(t)=-\frac{1}{n}\left\|\left({\bf X}-\hat{\mathbf{f}}({\bf X};t)\right)^{\top}\boldsymbol{\Theta}(t)\left({\bf X}-\hat{\mathbf{f}}({\bf X};t)\right)\right\|_{F}\leq-\frac{\lambda_{0}}{n}\left\|{\bf X}-\hat{\mathbf{f}}({\bf X};t)\right\|^{2}_{F}.

Integrating this differential inequality, the loss can be bounded as follows:

\mathcal{L}(t)\leq\exp\left(-(\lambda_{0}/n)t\right)\mathcal{L}(0),

which implies the linear convergence rate of the stochastic neural network. Then the gradient flow for $\mathbf{w}_{r}^{(\mu)}$ is as follows,

\left\|\frac{d}{dt}\mathbf{w}^{(\mu)}_{r}(t)\right\|_{2}=\frac{1}{n}\left\|\sum_{i=1}^{n}(\mathbf{x}_{i}-\hat{\mathbf{f}}_{i}(t))^{\top}\frac{1}{\sqrt{m}}\mathbf{w}^{(d)}_{r}\hat{\sigma}^{\prime}(z_{i,r})\mathbf{x}_{i}\right\|_{2}
\leq\frac{\beta}{n\sqrt{m}}\sum_{i=1}^{n}\left\|\mathbf{x}_{i}-\hat{\mathbf{f}}_{i}(t)\right\|_{2}\left\|\mathbf{w}^{(d)}_{r}(t)\right\|_{2}\left\|\mathbf{w}^{(\mu)}_{r}(t)\right\|_{2}
\leq\frac{\beta}{\sqrt{mn}}\left\|\mathbf{X}-\hat{\mathbf{f}}(\mathbf{X};0)\right\|_{F}\exp(-(\lambda_{0}/n)t)\left(R^{\prime}_{d}+\sqrt{d}c_{d,2}\right)\left(R^{\prime}_{\mu}+\sqrt{d}c_{\mu,2}\right).

Integrating the gradient, we have:

\left\|\mathbf{w}_{r}^{(\mu)}(T)-\mathbf{w}_{r}^{(\mu)}(0)\right\|_{2}\leq\int_{0}^{T}\left\|\frac{d}{dt}\mathbf{w}_{r}^{(\mu)}(t)\right\|_{2}dt\leq\frac{\beta\sqrt{n}\left\|\mathbf{X}-\hat{\mathbf{f}}({\bf X};0)\right\|_{F}c_{d,2}c_{\mu,2}d}{\sqrt{m}\lambda_{0}}.

Similarly, we have:

\left\|\mathbf{w}_{r}^{(d)}(T)-\mathbf{w}_{r}^{(d)}(0)\right\|_{2}\leq\int_{0}^{T}\left\|\frac{d}{dt}\mathbf{w}_{r}^{(d)}(t)\right\|_{2}dt\leq\frac{L\sqrt{n}\left\|\mathbf{X}-\hat{\mathbf{f}}({\bf X};0)\right\|_{F}c_{\mu,2}\sqrt{d}}{\sqrt{m}\lambda_{0}}.

Lemma 4 states that once the least eigenvalue of the NTK during training is bounded from below, the change of the weights is bounded (as evidenced by the empirical simulation shown in Figure 2). By employing a proof by contradiction, combined with the results from Lemmas 3 and 4, we can deduce that during training the NTKs of the SNN remain close to the deterministic kernel, provided the neural network is sufficiently wide. In conclusion, with all the lemmas at hand, we arrive at Theorem 1 via the following lemma:

Lemma 5.

If $R^{\prime}_{\mu}<R_{\mu}$, $R^{\prime}_{\sigma}<R_{\sigma}$, and $R^{\prime}_{d}<R_{d}$, then for all $t\geq 0$, $\lambda_{0}\left(\boldsymbol{\Theta}(t)\right)\geq\frac{\lambda_{0}}{2}$; besides, the loss satisfies:

{L}(t)\leq\exp(-(\lambda_{0}/n)t){L}(0).
Proof of Lemma 5.

The proof is a standard contradiction argument. Suppose the conclusion does not hold at time $t$, which implies that there exists $r\in[m]$ such that $\left\|\mathbf{w}^{(\mu)}_{r}(t)-\mathbf{w}^{(\mu)}_{r}(0)\right\|_{2}>R^{\prime}_{\mu}$. Then by Lemma 3 we know there exists $s\leq t$ such that $\lambda_{0}(\boldsymbol{\Theta}(s))\leq\lambda_{0}/2$. However, this contradicts Lemma 4.

To finalize the proof, we bound $\mathcal{L}(0)$:

\left\|{\bf X}-\hat{\mathbf{f}}({\bf X};0)\right\|^{2}_{F}\leq\sum_{i=1}^{n}\left\|\mathbf{x}^{(e)}_{i}\right\|^{2}_{2}+2\left\|\mathbf{x}^{(e)}_{i}\right\|_{2}\|\hat{\mathbf{f}}_{i}(0)\|_{2}+\|\hat{\mathbf{f}}_{i}(0)\|_{2}^{2}=\Theta(n).

Finally, $R^{\prime}_{\mu}<R_{\mu}$, $R^{\prime}_{\sigma}<R_{\sigma}$, and $R^{\prime}_{d}<R_{d}$ result in $m=\Omega\left(\frac{n^{5}d^{3}}{\lambda_{0}^{4}\delta^{2}}\right)$, which completes the proof. ∎

Finally, we give the detailed proof of Theorem 2, which is based on the linearization of the output function with respect to the weight space.

Proof of Theorem 2.

Our proof first establishes the result of kernel ridge regression in the infinite-width limit, then bounds the perturbation on the network’s prediction. The output function can be expressed as,

\hat{\mathbf{f}}^{\infty}({\bf x};t)=\hat{\mathbf{f}}^{\infty}({\bf x};0)+\boldsymbol{\Phi}_{{\mu}}({\bf x})^{\top}\left(\boldsymbol{\theta}^{(\mu)}(t)-\boldsymbol{\theta}^{(\mu)}(0)\right),

where $\boldsymbol{\theta}^{(\mu)}\triangleq\vec{\mathbf{W}}^{(\mu)}\in\mathbb{R}^{md}$ and $\boldsymbol{\Phi}_{{\mu}}({\bf x})=\nabla_{\boldsymbol{\theta}^{(\mu)}}\hat{\mathbf{f}}({\bf x},0)\in\mathbb{R}^{md\times d}$. It is known that the objective function with the KL divergence follows:

\mathcal{L}(t)=\frac{1}{2n}\left\|\hat{\mathbf{f}}(\mathbf{X})-\mathbf{X}\right\|^{2}_{F}+\beta\left\|\boldsymbol{\theta}^{(\mu)}(t)-\boldsymbol{\theta}^{(\mu)}(0)\right\|^{2}_{2}.

We then calculate the gradient flow dynamics for the mean weight:

\frac{d\boldsymbol{\theta}^{(\mu)}(t)}{dt}=\frac{\partial\mathcal{L}(t)}{\partial\boldsymbol{\theta}^{(\mu)}}=\boldsymbol{\Phi}_{{\mu}}({\bf X})\left(\hat{\mathbf{f}}^{\infty}({\bf X};t)-{\bf X}\right)+\beta\left(\boldsymbol{\theta}^{(\mu)}(t)-\boldsymbol{\theta}^{(\mu)}(0)\right)
=\boldsymbol{\Phi}_{{\mu}}({\bf X})\boldsymbol{\Phi}_{{\mu}}({\bf X})^{\top}\left(\boldsymbol{\theta}^{(\mu)}(t)-\boldsymbol{\theta}^{(\mu)}(0)\right)+\boldsymbol{\Phi}_{{\mu}}({\bf X})(\hat{\mathbf{f}}^{\infty}({\bf X};0)-\mathbf{X})+\beta\left(\boldsymbol{\theta}^{(\mu)}(t)-\boldsymbol{\theta}^{(\mu)}(0)\right)
=\left(\boldsymbol{\Theta}^{{(\mu)}}+\beta\mathbf{I}\right)\left(\boldsymbol{\theta}^{(\mu)}(t)-\boldsymbol{\theta}^{(\mu)}(0)\right)+\boldsymbol{\Phi}_{{\mu}}({\bf X})\left(\hat{\mathbf{f}}^{\infty}({\bf X};0)-\mathbf{X}\right),

which is an ordinary differential equation. It is easy to see that the solution is,

\overline{\boldsymbol{\theta}}^{(\mu)}(t)=\boldsymbol{\Phi}^{\top}_{\mu}(\mathbf{X})\left(\boldsymbol{\Theta}^{(\mu)}+\beta{\bf I}\right)^{-1}\left({\bf I}-e^{-(\boldsymbol{\Theta}^{(\mu)}+\beta{\bf I})t}\right)\left(\hat{\mathbf{f}}^{\infty}({\bf X};0)-\mathbf{X}\right),

where $\overline{\boldsymbol{\theta}}^{(\mu)}(t)\triangleq\boldsymbol{\theta}^{(\mu)}(t)-\boldsymbol{\theta}^{(\mu)}(0)$. Plugging the result into the linearized output function, we have

\hat{\mathbf{f}}^{\infty}({\bf X};t)={\bf X}-e^{-(\boldsymbol{\Theta}^{(\mu)}({\bf X},{\bf X})+\beta{\bf I})t}\left(\hat{\mathbf{f}}^{\infty}({\bf X};0)-\mathbf{X}\right).

For an arbitrary test point $\mathbf{x}_{te}$, we have

\hat{\mathbf{f}}^{\infty}({\bf x}_{te};t)=\boldsymbol{\Theta}^{(\mu)}({\bf x}_{te},{\bf X})\left(\boldsymbol{\Theta}^{(\mu)}+\beta{\bf I}\right)^{-1}\left({\bf I}-e^{-(\boldsymbol{\Theta}^{(\mu)}({\bf X},{\bf X})+\beta{\bf I})t}\right)\mathbf{X}.

When we take the time to infinity,

\hat{\mathbf{f}}^{\infty}({\bf x}_{te};\infty)=\boldsymbol{\Theta}^{(\mu)}({\bf x}_{te},{\bf X})\left(\boldsymbol{\Theta}^{(\mu)}+\beta{\bf I}\right)^{-1}\mathbf{X}. (18)

The next step is to bound the difference between the finite-width neural network and the infinitely-wide network:

\left|\hat{\mathbf{f}}({\bf x}_{te})-\hat{\mathbf{f}}^{\infty}({\bf x}_{te})\right|\leq O(\mathcal{E}),

where $\mathcal{E}=\mathcal{E}_{\rm init}+\frac{\sqrt{n}\mathcal{E}_{\Theta}}{\lambda_{0}+\beta}$ with $\left\|\hat{\mathbf{f}}\left(\boldsymbol{\theta}(0),\mathbf{x}_{te}\right)\right\|_{2}\leq\mathcal{E}_{\rm init}$ and $\|\boldsymbol{\Theta}^{\infty}-\boldsymbol{\Theta}(t)\|_{2}\leq\mathcal{E}_{\Theta}$. Note that the expression in Equation (18) can be rewritten as $\hat{\mathbf{f}}^{\infty}(\mathbf{x}_{te})=\boldsymbol{\Phi}(\mathbf{x}_{te})^{\top}\boldsymbol{\beta}$, and the solution to this equation can further be written as the result of applying gradient flow to the following kernel ridge regression problem

\min_{\boldsymbol{\beta}}\sum_{i=1}^{n}\frac{1}{2n}\left\|\boldsymbol{\Phi}(\mathbf{x}_{i})^{\top}\boldsymbol{\beta}-\mathbf{x}_{i}\right\|_{2}^{2}+\beta\left\|\boldsymbol{\beta}\right\|^{2}_{2},

with initialization $\boldsymbol{\beta}(0)=0$. We use $\boldsymbol{\beta}(t)$ to denote this parameter at time $t$ trained by gradient flow and let $\hat{\mathbf{f}}^{\infty}\left(\mathbf{x}_{te},\boldsymbol{\beta}(t)\right)$ be the predictor for $\mathbf{x}_{te}$ at time $t$. With these notations, we rewrite

\hat{\mathbf{f}}^{\infty}(\mathbf{x}_{te})=\int_{t=0}^{\infty}\frac{d\hat{\mathbf{f}}(\boldsymbol{\beta}(t),\mathbf{x}_{te})}{dt}dt,

where we have used the fact that the initial prediction is 0.

We thus can analyze the difference between the SNN predictor and infinite-width SNN predictor via this integral form as follows:

\left\|\hat{\mathbf{f}}^{\infty}(\mathbf{x}_{te})-\hat{\mathbf{f}}\left(\mathbf{x}_{te}\right)\right\|_{2}\leq\left\|\hat{\mathbf{f}}(\boldsymbol{\theta}(0),\mathbf{x}_{te})\right\|_{2}+\left\|\int_{t=0}^{\infty}\left(\frac{d\hat{\mathbf{f}}(\boldsymbol{\theta}(t),\mathbf{x}_{te})}{dt}-\frac{d\hat{\mathbf{f}}^{\infty}(\boldsymbol{\beta}(t),\mathbf{x}_{te})}{dt}\right)dt\right\|_{2}
\leq\mathcal{E}_{init}+\bigg\|\frac{1}{n}\int_{t=0}^{\infty}\left(\boldsymbol{\Theta}(\mathbf{x}_{te},\mathbf{X};t)-\boldsymbol{\Theta}^{\infty}(\mathbf{x}_{te},\mathbf{X})\right)^{\top}(\hat{\mathbf{f}}(t)-\mathbf{X})dt
+\beta\int_{t=0}^{\infty}\left(\boldsymbol{\Phi}(\mathbf{x}_{te},t)-\boldsymbol{\Phi}^{\infty}(\mathbf{x}_{te})\right)^{\top}\boldsymbol{\beta}(t)dt\bigg\|_{2}
+\bigg\|\frac{1}{n}\int_{t=0}^{\infty}\boldsymbol{\Theta}^{\infty}(\mathbf{x}_{te},\mathbf{X})^{\top}(\hat{\mathbf{f}}^{\infty}(t)-\hat{\mathbf{f}}(t))dt+\beta\int_{t=0}^{\infty}\left(\boldsymbol{\Phi}^{\infty}(\mathbf{x}_{te})\right)^{\top}(\boldsymbol{\beta}(t)-\overline{\boldsymbol{\theta}}(t))dt\bigg\|_{2}
\leq\mathcal{E}_{init}+\bigg(\max_{0\leq t\leq\infty}\left\|\boldsymbol{\Theta}(\mathbf{x}_{te},\mathbf{X};t)-\boldsymbol{\Theta}^{\infty}(\mathbf{x}_{te},\mathbf{X})\right\|_{2}\int_{t=0}^{\infty}\left\|\hat{\mathbf{f}}(t)-\mathbf{X}\right\|_{2}dt
+\beta\max_{0\leq t\leq\infty}\left\|\boldsymbol{\Phi}(\mathbf{x}_{te};t)-\boldsymbol{\Phi}^{\infty}(\mathbf{x}_{te})\right\|_{2}\int_{t=0}^{\infty}\|\boldsymbol{\beta}(t)\|_{2}dt\bigg)
+\bigg(\max_{0\leq t\leq\infty}\|\boldsymbol{\Theta}^{\infty}(\mathbf{x}_{te},\mathbf{X})\|_{2}\int_{t=0}^{\infty}\|\hat{\mathbf{f}}(t)-\hat{\mathbf{f}}^{\infty}(t)\|_{2}dt+\beta\max_{0\leq t\leq\infty}\left\|\boldsymbol{\Phi}^{\infty}(\mathbf{x}_{te})\right\|_{2}\int_{t=0}^{\infty}\|\boldsymbol{\beta}(t)-\overline{\boldsymbol{\theta}}(t)\|_{2}dt\bigg)
\triangleq\mathcal{E}_{init}+I_{2}+I_{3}.

For the second term $I_{2}$, recall that $\|\boldsymbol{\Theta}^{\infty}(\mathbf{x}_{te},\mathbf{X})-\boldsymbol{\Theta}(\mathbf{x}_{te},\mathbf{X};t)\|_{2}\leq\frac{\lambda_{0}}{2}$ by Lemma 3. Besides, we know that $\|\hat{\mathbf{f}}(t)-\mathbf{X}\|^{2}_{2}+\beta\|\overline{\boldsymbol{\theta}}\|^{2}_{2}\leq\exp(-(\frac{\lambda_{0}}{2}+\beta)t)\|\hat{\mathbf{f}}(0)-\mathbf{X}\|^{2}_{2}$. Therefore, we can bound:

\int_{0}^{\infty}\|\hat{\mathbf{f}}(t)-\mathbf{X}\|_{2}+\beta\|\overline{\boldsymbol{\theta}}(t)\|_{2}\,dt\leq\int_{t=0}^{\infty}\exp(-(\tfrac{\lambda_{0}}{2}+\beta)t)\|\hat{\mathbf{f}}(0)-\mathbf{X}\|_{2}\,dt=O\left(\frac{\sqrt{n}}{\lambda_{0}+\beta}\right).

As a result, we have $I_{2}=O\!\left(\frac{\sqrt{n}\,\mathcal{E}_{\Theta}}{\lambda_{0}+\beta}\right)$. To bound $I_{3}$, we have

\begin{align*}
&\int_{0}^{\infty}\Big(\|\hat{\mathbf{f}}(t)-\hat{\mathbf{f}}^{\infty}(t)\|_{2}+\beta\|\boldsymbol{\beta}(t)-\overline{\boldsymbol{\theta}}(t)\|_{2}\Big)dt\\
&\quad\leq\int_{0}^{\infty}\Big(\|\hat{\mathbf{f}}(t)-\mathbf{X}\|_{2}+\beta\|\overline{\boldsymbol{\theta}}(t)\|_{2}\Big)dt+\int_{0}^{\infty}\Big(\|\hat{\mathbf{f}}^{\infty}(t)-\mathbf{X}\|_{2}+\beta\|\boldsymbol{\beta}(t)\|_{2}\Big)dt=O\!\left(\frac{\sqrt{n}}{\lambda_{0}+\beta}\right).
\end{align*}

As a result, we have $I_{3}=O\!\left(\frac{\sqrt{n}\,\mathcal{E}_{\Theta}}{\lambda_{0}+\beta}\right)$. Lastly, we put things together and get

\[
\left\|\hat{\mathbf{f}}(\mathbf{x}_{te})-\hat{\mathbf{f}}^{\infty}(\mathbf{x}_{te})\right\|_{2}=O\!\left(\mathcal{E}_{init}+\frac{\sqrt{n}\,\mathcal{E}_{\Theta}}{\lambda_{0}+\beta}\right).
\]
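To give a feel for the exponential rate that drives these bounds, the following toy simulation (a sketch, not the experimental code used in the next section) integrates the linearized residual dynamics $\dot{\mathbf{r}}(t)=-(\boldsymbol{\Theta}^{\infty}+\beta\mathbf{I})\,\mathbf{r}(t)$ with a random positive-definite stand-in for the limiting kernel. The kernel matrix, its normalization, and all sizes are assumptions made purely for illustration, so only the qualitative decay at rate $\lambda_{0}+\beta$ carries over.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 64

# Stand-in for the limiting kernel matrix: any PSD matrix with smallest eigenvalue lam0 > 0.
G = rng.standard_normal((n, n))
Theta = G @ G.T / n + 0.5 * np.eye(n)
lam0 = float(np.linalg.eigvalsh(Theta).min())
beta = 0.1

# Linearized residual dynamics dr/dt = -(Theta + beta*I) r, discretized with a small Euler step.
r = rng.standard_normal(n)          # initial residual f_hat(0) - X (one output coordinate)
r0 = np.linalg.norm(r)
dt, steps = 1e-3, 5000
A = Theta + beta * np.eye(n)
for _ in range(steps):
    r = r - dt * A @ r

t = dt * steps
print("empirical   ||r(t)|| :", np.linalg.norm(r))
print("theoretical envelope :", np.exp(-(lam0 + beta) * t) * r0)  # ||r(t)|| <= exp(-(lam0+beta)t)||r(0)||
```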

6 Experiments

In this section, we provide empirical evidence to support our theoretical analysis concerning the training dynamics of over-parameterized stochastic neural networks, which are optimized using VAE training objectives. Our experimental results, derived from training on the MNIST dataset, corroborate our theoretical predictions. In addition, we report our observation that VAEs with larger latent spaces are capable of learning more information, which substantiates the rationale behind our theoretical examination of the convergence properties of over-parameterized VAEs.

6.1 Theoretical verification

Refer to caption
Figure 2: Relative Frobenius norm change in weights after training, where mm is the width of the network. Solid lines correspond to empirical simulations and dotted lines are theoretical predictions.

To empirically validate our lemmas, we employ a three-hidden-layer fully connected network with Tanh activations, trained with the objective function presented in Equation (4). The network parameters are initialized using the Neural Tangent Kernel (NTK) parameterization, in line with Equation (2). For training, we adopt the ordinary mean-squared error (MSE) as the reconstruction loss and employ full-batch gradient descent with a constant learning rate of 1 on a subset of the MNIST dataset containing 128 samples across 10 classes. We measure the change in the weights of each layer, denoted by $\|\mathbf{W}(t)-\mathbf{W}(0)\|_{F}/\|\mathbf{W}(0)\|_{F}$, after performing $t=2^{17}$ steps of gradient descent from random initialization. Figure 2 displays the results for each layer; for the latent layer, we only measure the change in the weight $\mathbf{W}^{(\mu)}$. We observe that the relative Frobenius norm change of the Encoder and Decoder weights scales as $1/\sqrt{m}$, while that of the hidden layers scales as $1/m$. This confirms that the weights of the SNN do not move far during training, supporting our theoretical claim (Lemma 4). Notably, a similar convergence rate for weight changes in deterministic neural networks was observed in [23].
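For readers who wish to reproduce this kind of measurement, the sketch below trains a small NTK-parameterized stochastic encoder–decoder and reports the relative Frobenius norm change of each weight matrix. The layer widths, the random stand-in data, the unit-variance reparameterization, the KL weight, and the shortened step count are illustrative assumptions and do not reproduce the exact configuration behind Figure 2.

```python
import copy
import torch

torch.manual_seed(0)

class TinySNN(torch.nn.Module):
    """Encoder -> stochastic latent -> decoder with NTK parameterization (illustrative sizes)."""
    def __init__(self, d_in=784, m=512, d_latent=64):
        super().__init__()
        # NTK parameterization: standard-normal weights; the 1/sqrt(fan_in) scaling is applied
        # in the forward pass rather than folded into the initialization.
        self.W_enc = torch.nn.Parameter(torch.randn(m, d_in))
        self.W_mu = torch.nn.Parameter(torch.randn(d_latent, m))    # latent-mean weights W^(mu)
        self.W_dec = torch.nn.Parameter(torch.randn(d_in, d_latent))

    def forward(self, x):
        h = torch.tanh(x @ self.W_enc.t() / x.shape[1] ** 0.5)
        mu = h @ self.W_mu.t() / h.shape[1] ** 0.5
        z = mu + torch.randn_like(mu)      # reparameterization with unit variance (simplification)
        return z @ self.W_dec.t() / z.shape[1] ** 0.5, mu

x = torch.randn(128, 784)                  # random stand-in for the 128-sample MNIST subset
model = TinySNN()
init = copy.deepcopy(model.state_dict())
opt = torch.optim.SGD(model.parameters(), lr=1.0)
beta = 1e-2                                # illustrative KL weight

for step in range(1000):                   # the experiment uses 2**17 steps; shortened here
    opt.zero_grad()
    x_hat, mu = model(x)
    loss = ((x_hat - x) ** 2).mean() + beta * mu.pow(2).mean()
    loss.backward()
    opt.step()

for name, w in model.state_dict().items():
    rel = ((w - init[name]).norm() / init[name].norm()).item()
    print(f"{name}: relative Frobenius change = {rel:.5f}")
```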

6.2 Large latent space can learn more

Refer to caption

Figure 3: Disentanglement scores for networks with latent dimensions m=10,20,50,100,200m=10,20,50,100,200 on dSprites and Cars3D. Observation: the larger the latent space, the better the disentanglement learning.

In this subsection, we report our experimental observations, aligning with numerous prior studies [8, 9]. We observed that larger latent spaces are capable of capturing more information, as evidenced by higher disentanglement scores and the emergence of additional features not discernible in models with narrower VAE configurations.

Adopting the experimental setup of Beta-VAE [45], we explored the effect of varying the latent space dimension. Our experiments were conducted on the dSprites [45] and Cars3D [46] datasets. As shown in Figure 3, the width of the latent space, denoted by $m$, is varied over $[10,20,50,100,200]$. We assessed performance using a suite of disentanglement metrics: the $\beta$-VAE (BetaVAE) metric [45], Mutual Information Gap (MIG) [47], Separated Attribute Predictability (SAP) score [48], and FactorVAE metric [49]. Our findings indicate that larger latent spaces lead to higher disentanglement scores, with the exception of a less pronounced improvement under the BetaVAE metric on the Cars3D dataset. These results corroborate the hypothesis that a larger latent space is capable of capturing more information.
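As a concrete illustration of one of these metrics, the sketch below computes the Mutual Information Gap on discretized latent codes. The binning scheme and the toy factor/latent arrays are assumptions made for illustration; this is not the evaluation pipeline used to produce Figure 3.

```python
import numpy as np
from sklearn.metrics import mutual_info_score

def mig(latents, factors, n_bins=20):
    """Mutual Information Gap (MIG): for each ground-truth factor, take the gap between
    the two most informative latent dimensions, normalized by the factor's entropy."""
    # Discretize each continuous latent dimension into equal-width bins.
    binned = np.stack(
        [np.digitize(z, np.histogram(z, bins=n_bins)[1][:-1]) for z in latents.T], axis=1
    )
    gaps = []
    for k in range(factors.shape[1]):
        v = factors[:, k]
        mi = np.array([mutual_info_score(v, binned[:, j]) for j in range(binned.shape[1])])
        h_v = mutual_info_score(v, v)  # entropy of a discrete factor equals its self-MI (nats)
        top2 = np.sort(mi)[-2:]
        gaps.append((top2[1] - top2[0]) / max(h_v, 1e-12))
    return float(np.mean(gaps))

# Toy usage: 1000 samples, 10-dimensional latents, 3 discrete ground-truth factors.
rng = np.random.default_rng(0)
factors = rng.integers(0, 5, size=(1000, 3))
latents = np.concatenate([factors + 0.1 * rng.standard_normal((1000, 3)),
                          rng.standard_normal((1000, 7))], axis=1)
print("MIG:", mig(latents, factors))
```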

Refer to caption
Figure 4: New image attributes discovered by a large latent space VAE (m=256m=256) but not by a small latent space VAE (m=10m=10) on the CelebA dataset.

Furthermore, in our experiments on the CelebA [50] dataset, we observed that a larger latent space can reveal additional features not detected by smaller latent space VAEs. As illustrated in Figure 4, a VAE with a 256-dimensional latent space uncovered new image attributes, such as emotion, eye style, and hairstyle, which were not identified by a VAE with a latent space of just 10 dimensions. These findings confirm that VAEs with larger latent spaces are capable of detecting additional features not observable in narrower configurations.
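The attribute-discovery procedure behind Figure 4 can be summarized as a latent traversal: fix a base latent code, sweep a single coordinate over a range, and decode the resulting codes. A minimal sketch of this procedure is given below; the decoder handle, latent size, coordinate index, and sweep range are hypothetical placeholders rather than the trained model used in our experiments.

```python
import torch

@torch.no_grad()
def latent_traversal(decoder, z, dim, values):
    """Decode copies of a base code z with coordinate `dim` swept over `values`.
    Returns one decoded image per traversal value."""
    codes = z.repeat(len(values), 1)
    codes[:, dim] = torch.as_tensor(values, dtype=z.dtype)
    return decoder(codes)

# Hypothetical usage with a trained decoder mapping a 256-dimensional latent code to images:
#   base = torch.zeros(1, 256)
#   images = latent_traversal(vae.decoder, base, dim=17, values=torch.linspace(-3, 3, 8))
# Inspecting the decoded images for different `dim` values reveals which attribute
# (e.g., emotion, eye style, hairstyle) a given latent coordinate controls.
```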

These observations validate the intuitive notion that VAEs with larger latent spaces exhibit superior disentanglement performance. This underlines our initial motivation for investigating over-parameterized VAEs, as opposed to conventional VAEs, to leverage the benefits of increased latent dimensionality.

7 Conclusion

In this work, we have established the convergence of over-parameterized VAEs using neural tangent kernel techniques. Additionally, we have demonstrated that the expected output function trained with the full objective, including the KL divergence, converges to the kernel ridge regression solution, confirming the regularization effect of the additional KL term. The theoretical insights presented in this paper pave the way for analyzing stochastic neural networks within other paradigms, such as deep Bayesian networks. Our empirical evaluations corroborate that the theoretical predictions are consistent with real-world training dynamics. Furthermore, through experimental investigations on real datasets, we have highlighted the training efficiency of over-parameterized VAEs, as suggested by our theoretical findings.

References

  • Kingma and Welling [2013] Kingma, D.P., Welling, M.: Auto-encoding variational bayes. arXiv preprint arXiv:1312.6114 (2013)
  • Radford et al. [2015] Radford, A., Metz, L., Chintala, S.: Unsupervised representation learning with deep convolutional generative adversarial networks. arXiv preprint arXiv:1511.06434 (2015)
  • Van Den Oord and Vinyals [2017] Van Den Oord, A., Vinyals, O.: Neural discrete representation learning. In: Advances in Neural Information Processing Systems, pp. 6306–6315 (2017)
  • Wang et al. [2022] Wang, L., Huang, W., Zhang, M., Pan, S., Chang, X., Su, S.W.: Pruning graph neural networks by evaluating edge properties. Knowledge-Based Systems 256, 109847 (2022)
  • Bowman et al. [2015] Bowman, S.R., Vilnis, L., Vinyals, O., Dai, A.M., Jozefowicz, R., Bengio, S.: Generating sentences from a continuous space. arXiv preprint arXiv:1511.06349 (2015)
  • Ng et al. [2011] Ng, A., et al.: Sparse autoencoder. CS294A Lecture notes 72(2011), 1–19 (2011)
  • Tschannen et al. [2018] Tschannen, M., Bachem, O., Lucic, M.: Recent advances in autoencoder-based representation learning. arXiv preprint arXiv:1812.05069 (2018)
  • Song et al. [2019] Song, T., Sun, J., Chen, B., Peng, W., Song, J.: Latent space expanded variational autoencoder for sentence generation. IEEE Access 7, 144618–144627 (2019)
  • Lim et al. [2020] Lim, K.-L., Jiang, X., Yi, C.: Deep clustering with variational autoencoder. IEEE Signal Processing Letters 27, 231–235 (2020)
  • Zhang et al. [2022] Zhang, M., Wang, L., Campos, D., Huang, W., Guo, C., Yang, B.: Weighted mutual learning with diversity-driven model compression. Advances in Neural Information Processing Systems 35, 11520–11533 (2022)
  • He et al. [2019] He, J., Spokoyny, D., Neubig, G., Berg-Kirkpatrick, T.: Lagging inference networks and posterior collapse in variational autoencoders. arXiv preprint arXiv:1901.05534 (2019)
  • Lucas et al. [2019] Lucas, J., Tucker, G., Grosse, R.B., Norouzi, M.: Don’t blame the ELBO! a linear VAE perspective on posterior collapse. Advances in Neural Information Processing Systems 32 (2019)
  • Koehler et al. [2021] Koehler, F., Mehta, V., Risteski, A., Zhou, C.: Variational autoencoders in the presence of low-dimensional data: landscape and implicit bias. arXiv preprint arXiv:2112.06868 (2021)
  • Jacot et al. [2018] Jacot, A., Gabriel, F., Hongler, C.: Neural tangent kernel: Convergence and generalization in neural networks. arXiv preprint arXiv:1806.07572 (2018)
  • Allen-Zhu et al. [2019] Allen-Zhu, Z., Li, Y., Song, Z.: A convergence theory for deep learning via over-parameterization. In: International Conference on Machine Learning, pp. 242–252 (2019). PMLR
  • Du et al. [2018] Du, S.S., Zhai, X., Poczos, B., Singh, A.: Gradient descent provably optimizes over-parameterized neural networks. arXiv preprint arXiv:1810.02054 (2018)
  • Du et al. [2019] Du, S., Lee, J., Li, H., Wang, L., Zhai, X.: Gradient descent finds global minima of deep neural networks. In: International Conference on Machine Learning, pp. 1675–1685 (2019). PMLR
  • Huang et al. [2020] Huang, W., Du, W., Da Xu, R.Y.: On the neural tangent kernel of deep networks with orthogonal initialization. arXiv preprint arXiv:2004.05867 (2020)
  • Huang et al. [2021] Huang, W., Li, Y., Du, W., Da Xu, R.Y., Yin, J., Chen, L., Zhang, M.: Towards deepening graph neural networks: A gntk-based optimization perspective. arXiv preprint arXiv:2103.03113 (2021)
  • Zou et al. [2020] Zou, D., Cao, Y., Zhou, D., Gu, Q.: Gradient descent optimizes over-parameterized deep relu networks. Machine Learning 109(3), 467–492 (2020)
  • Chen et al. [2021] Chen, Y., Huang, W., Nguyen, L., Weng, T.-W.: On the equivalence between neural network and support vector machine. Advances in Neural Information Processing Systems 34 (2021)
  • Chen et al. [2019] Chen, Z., Cao, Y., Zou, D., Gu, Q.: How much over-parameterization is sufficient to learn deep relu networks? arXiv preprint arXiv:1911.12360 (2019)
  • Lee et al. [2019] Lee, J., Xiao, L., Schoenholz, S., Bahri, Y., Novak, R., Sohl-Dickstein, J., Pennington, J.: Wide neural networks of any depth evolve as linear models under gradient descent. Advances in neural information processing systems 32 (2019)
  • Yang [2019] Yang, G.: Scaling limits of wide neural networks with weight sharing: Gaussian process behavior, gradient independence, and neural tangent kernel derivation. arXiv preprint arXiv:1902.04760 (2019)
  • Arora et al. [2019a] Arora, S., Du, S.S., Hu, W., Li, Z., Salakhutdinov, R., Wang, R.: On exact computation with an infinitely wide neural net. arXiv preprint arXiv:1904.11955 (2019)
  • Arora et al. [2019b] Arora, S., Du, S., Hu, W., Li, Z., Wang, R.: Fine-grained analysis of optimization and generalization for overparameterized two-layer neural networks. In: International Conference on Machine Learning, pp. 322–332 (2019). PMLR
  • Cao and Gu [2019] Cao, Y., Gu, Q.: Generalization bounds of stochastic gradient descent for wide and deep neural networks. Advances in Neural Information Processing Systems 32, 10836–10846 (2019)
  • Du et al. [2019] Du, S.S., Hou, K., Salakhutdinov, R.R., Poczos, B., Wang, R., Xu, K.: Graph neural tangent kernel: Fusing graph neural networks with graph kernels. Advances in Neural Information Processing Systems 32, 5723–5733 (2019)
  • Wang et al. [2022] Wang, H., Huang, W., Wu, Z., Tong, H., Margenot, A.J., He, J.: Deep active learning by leveraging training dynamics. Advances in Neural Information Processing Systems 35, 25171–25184 (2022)
  • Hron et al. [2020] Hron, J., Bahri, Y., Sohl-Dickstein, J., Novak, R.: Infinite attention: Nngp and ntk for deep attention networks. In: International Conference on Machine Learning, pp. 4376–4386 (2020). PMLR
  • Chen et al. [2022] Chen, W., Huang, W., Gong, X., Hanin, B., Wang, Z.: Deep architecture connectivity matters for its convergence: A fine-grained analysis. Advances in neural information processing systems 35, 35298–35312 (2022)
  • Franceschi et al. [2022] Franceschi, J.-Y., De Bézenac, E., Ayed, I., Chen, M., Lamprier, S., Gallinari, P.: A neural tangent kernel perspective of gans. In: International Conference on Machine Learning, pp. 6660–6704 (2022). PMLR
  • Nguyen et al. [2021] Nguyen, T.V., Wong, R.K., Hegde, C.: Benefits of jointly training autoencoders: An improved neural tangent kernel analysis. IEEE Transactions on Information Theory 67(7), 4669–4692 (2021)
  • Ziyin et al. [2022] Ziyin, L., Zhang, H., Meng, X., Lu, Y., Xing, E., Ueda, M.: Stochastic neural networks with infinite width are deterministic. arXiv preprint arXiv:2201.12724 (2022)
  • Huang et al. [2023] Huang, W., Liu, C., Chen, Y., Da Xu, R.Y., Zhang, M., Weng, T.-W.: Analyzing deep pac-bayesian learning with neural tangent kernel: Convergence, analytic generalization bound, and efficient hyperparameter selection. Transactions on Machine Learning Research (2023)
  • Clerico et al. [2023] Clerico, E., Deligiannidis, G., Doucet, A.: Wide stochastic networks: Gaussian limit and pac-bayesian training. In: International Conference on Algorithmic Learning Theory, pp. 447–470 (2023). PMLR
  • Alemi et al. [2018] Alemi, A., Poole, B., Fischer, I., Dillon, J., Saurous, R.A., Murphy, K.: Fixing a broken elbo. In: International Conference on Machine Learning, pp. 159–168 (2018). PMLR
  • Dai and Wipf [2019] Dai, B., Wipf, D.: Diagnosing and enhancing vae models. arXiv preprint arXiv:1903.05789 (2019)
  • Rolinek et al. [2019] Rolinek, M., Zietlow, D., Martius, G.: Variational autoencoders pursue pca directions (by accident). In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 12406–12415 (2019)
  • Kumar and Poole [2020] Kumar, A., Poole, B.: On implicit regularization in beta-vae. In: International Conference on Machine Learning, pp. 5480–5490 (2020). PMLR
  • Nakagawa et al. [2021] Nakagawa, A., Kato, K., Suzuki, T.: Quantitative understanding of vae as a non-linearly scaled isometric embedding. In: International Conference on Machine Learning, pp. 7916–7926 (2021). PMLR
  • Wipf [2023] Wipf, D.: Marginalization is not marginal: No bad vae local minima when learning optimal sparse representations (2023)
  • Dai et al. [2021] Dai, B., Wenliang, L., Wipf, D.: On the value of infinite gradients in variational autoencoder models. Advances in Neural Information Processing Systems 34, 7180–7192 (2021)
  • Dai et al. [2020] Dai, B., Wang, Z., Wipf, D.: The usual suspects? reassessing blame for vae posterior collapse. In: International Conference on Machine Learning, pp. 2313–2322 (2020). PMLR
  • Higgins et al. [2016] Higgins, I., Matthey, L., Pal, A., Burgess, C., Glorot, X., Botvinick, M., Mohamed, S., Lerchner, A.: beta-vae: Learning basic visual concepts with a constrained variational framework (2016)
  • Reed et al. [2015] Reed, S.E., Zhang, Y., Zhang, Y., Lee, H.: Deep visual analogy-making. Advances in neural information processing systems 28 (2015)
  • Chen et al. [2018] Chen, R.T.Q., Li, X., Grosse, R., Duvenaud, D.: Isolating sources of disentanglement in vaes. In: Proceedings of the 32nd International Conference on Neural Information Processing Systems, pp. 2615–2625 (2018)
  • Kumar et al. [2017] Kumar, A., Sattigeri, P., Balakrishnan, A.: Variational inference of disentangled latent concepts from unlabeled observations. arXiv preprint arXiv:1711.00848 (2017)
  • Kim and Mnih [2018] Kim, H., Mnih, A.: Disentangling by factorising. In: International Conference on Machine Learning, pp. 2649–2658 (2018). PMLR
  • Liu et al. [2015] Liu, Z., Luo, P., Wang, X., Tang, X.: Deep learning face attributes in the wild. In: Proceedings of the IEEE International Conference on Computer Vision, pp. 3730–3738 (2015)