
Scalable Set Encoding with Universal Mini-Batch Consistency and
Unbiased Full Set Gradient Approximation

Jeffrey Willette    Seanie Lee    Bruno Andreis    Kenji Kawaguchi    Juho Lee    Sung Ju Hwang
Abstract

Recent work on mini-batch consistency (MBC) for set functions has brought attention to the need for sequentially processing and aggregating chunks of a partitioned set while guaranteeing the same output for all partitions. However, existing constraints on MBC architectures lead to models with limited expressive power. Additionally, prior work has not addressed how to deal with large sets during training when the full set gradient is required. To address these issues, we propose a Universally MBC (UMBC) class of set functions which can be used in conjunction with arbitrary non-MBC components while still satisfying MBC, enabling a wider range of function classes to be used in MBC settings. Furthermore, we propose an efficient MBC training algorithm which gives an unbiased approximation of the full set gradient and has a constant memory overhead for any set size for both train- and test-time. We conduct extensive experiments including image completion, text classification, unsupervised clustering, and cancer detection on high-resolution images to verify the efficiency and efficacy of our scalable set encoding framework. Our code is available at github.com/jeffwillette/umbc


1 Introduction

For a variety of problems to which deep models can be applied, unordered sets naturally arise as an input, for example, the set of words in a document (Jurafsky & Martin, 2008) or the set of patches within an image for multiple instance learning (Quellec et al., 2017). Functions which encode sets are commonly known as set encoders, and most previously proposed set encoding functions (Zaheer et al., 2017; Lee et al., 2019) have implicitly assumed that the whole set can fit into memory and be accessed in a single chunk. However, this is not a realistic assumption when it is necessary to process large sets or streaming data. As shown in Figure 2(a), the Set Transformer (Lee et al., 2019) cannot properly handle streaming data and suffers performance degradation. Please see Figure 8 for more qualitative examples. Bruno et al. (2021) identified this problem and introduced the mini-batch consistency (MBC) property, which dictates that an MBC set encoding model must be able to sequentially process subsets from a partition of a set while guaranteeing the same output over any partitioning scheme, as illustrated in Figure 1. In order to satisfy the MBC property, they devised an attention-based MBC model, the Slot Set Encoder (SSE).

Figure 1: Non-MBC models (1, 2) produce inconsistent outputs when given different set partitions. MBC models (3) produce consistent outputs for any random partition with a specific architecture. UMBC composes both MBC/non-MBC components, expanding the possible set of MBC functions, and allowing for more expressive models.
(a) Set Transformer  (b) Deepsets  (c) Slot Set Encoder  (d) UMBC+ST (Ours)
Figure 2: Streaming inputs: A non-MBC model (2(a)) suffers performance degradation in streaming settings. MBC models (2(b), 2(c), 2(d)) can handle streaming inputs consistently. Creating an MBC composition of both MBC/non-MBC components (2(d)) creates the strongest MBC model. A (✓) indicates MBC models while (✗) indicates non-MBC models.

Although SSE satisfies MBC, there are several limitations. First, it has limited expressive power due to the constraints imposed on its architecture. Instead of the conventional softmax attention (Vaswani et al., 2017), SSE is restricted to a sigmoid attention activation without normalization over the rows of the attention matrix, which may be undesirable for applications requiring convex combinations of inputs. Moreover, the Hierarchical SSE is a composition of pure MBC functions and thus cannot utilize more expressive non-MBC models, such as those utilizing self-attention. Another crucial limitation of the SSE is its limited scalability during training. Training models with large sets requires computing gradients over the full set, which can be computationally prohibitive. SSE proposes to randomly sample a small subset for gradient computation, which is a biased estimator of the full set gradient as we show in Section A.8.

To tackle these limitations of SSE, we propose Universal MBC (UMBC) set functions which enable utilizing a broader range of functions while still satisfying the MBC property. Firstly, we relax the restriction of the attention activation to the sigmoid and show that cross-attention with a wider class of activation functions, including the softmax, is MBC. Moreover, we re-interpret UMBC's output as a set, which, as we show in Figures 1 and 3, universally allows for the application of non-MBC set encoders to UMBC's output sets, resulting in more expressive functions while maintaining the MBC property. For a concrete example, UMBC used in conjunction with the (non-MBC) Set Transformer (ST) produces consistent output for any partition of a set as shown in Figure 3, and outperforms all other MBC models for clustering streaming data as illustrated in Figure 2.

Lastly, for training MBC models, we propose a novel and scalable algorithm to approximate the full set gradient. Specifically, we obtain the full set representation by partitioning the set into subsets and aggregating the subset representations, while only considering a portion of the subsets for gradient computation. We find this leads to a constant memory overhead for computing the gradient with a fixed size subset, and is an unbiased estimator of the full set gradient.

To verify the efficacy and efficiency of our proposed UMBC framework and full set gradient approximation algorithm, we perform extensive experiments on a variety of tasks including image completion, text classification, unsupervised clustering, and cancer detection on high-resolution images. Furthermore, we theoretically show that UMBC is a universal approximator of continuous permutation invariant functions under some mild assumptions and the proposed training algorithm minimizes the total loss of the full set version by making progress toward its stationary points. We summarize our contributions as follows:

  • We propose a UMBC framework which allows for a broad class of activation functions, including softmax, for attention and also enables utilizing non-MBC functions in conjunction with UMBC while satisfying MBC, resulting in more expressive and less restrictive architectures.

  • We propose an efficient training algorithm with a constant memory overhead for any set size by deriving an unbiased estimator of the full set gradient which empirically performs comparably to using the full set gradient.

  • We theoretically show that UMBC is a universal approximator to continuous permutation invariant functions under mild assumptions and our algorithm minimizes the full set total loss by making progress toward its stationary points.

Figure 3: Variance between encodings of 100 different partitions of the same set. UMBC+ST satisfies MBC and thus has no variance.

2 Related Work

Set Encoding. Deep learning for set structured data has been an active research topic since the introduction of DeepSets (Zaheer et al., 2017), which solidified the requirements of deep set functions, namely permutation equivariant feature extraction and permutation invariant set pooling. Zaheer et al. (2017) have shown that under certain conditions, functions which satisfy the aforementioned requirements act as universal approximators for functions of sets. Subsequently, the Set Transformer (Lee et al., 2019) applied attention (Vaswani et al., 2017) to sets, which has proven to be a powerful tool for set functions. Self-attentive set functions excel on tasks where independently processing set elements may fail to capture pairwise interactions between elements. Subsequent works which utilize pairwise set element interactions include optimal transport (Mialon et al., 2021) and expectation maximization (Kim, 2022). Other notable approaches to permutation invariant set pooling include featurewise sorting (Zhang et al., 2020), and canonical orderings of set elements (Murphy et al., 2019).

Mini-Batch Consistency (MBC). Every method mentioned in the preceding paragraph suffers from an architectural bias which limits them to seeing and processing the whole set in a single chunk. Bruno et al. (2021) identified this problem and highlighted the necessity of MBC, which guarantees that processing and aggregating each subset from a set partition results in the same representation as encoding the entire set at once (Definition 3.2). This is important in settings where the data may not fit into memory due to either large data or limited on-device resources. In addition to identifying the MBC property, Bruno et al. (2021) also proposed the Slot Set Encoder (SSE), which utilizes cross attention between learnable 'slots' and set elements in conjunction with simple activation functions in order to achieve an MBC model. As shown in Table 1, however, SSE cannot utilize self-attention to model pairwise interactions of set elements due to the constraints imposed on its architecture, which makes it less expressive than the Set Transformer.

Table 1: Properties of set functions. UMBC models can use arbitrary component set functions and are therefore unconstrained.
Model                                MBC   Cross-Attn.   Self-Attn.
DeepSets (Zaheer et al., 2017)       ✓     ✗             ✗
SSE (Bruno et al., 2021)             ✓     ✓             ✗
FSPool (Zhang et al., 2020)          ✗     ✗             ✗
Diff EM (Kim, 2022)                  ✗     ✓             ✓
Set Transformer (Lee et al., 2019)   ✗     ✓             ✓
(Ours) UMBC + Any Set Function       ✓     ✓             ✓

3 Method

In this section, we describe the problem we target and provide a formulation for UMBC models along with a derivation of our unbiased full set gradient approximation algorithm. All proofs of theorems are deferred to Appendix A.

3.1 Preliminaries

Let $\mathfrak{X}$ be a $d_{x}$-dimensional vector space over $\mathbb{R}$ and let $2^{\mathfrak{X}}$ be the power set of $\mathfrak{X}$. We focus on a collection of finite sets $\mathcal{X}$, which is a subset of $2^{\mathfrak{X}}$ such that $\sup_{X\in\mathcal{X}}\lvert X\rvert\in\mathbb{N}$. We want to construct a parametric function $f_{\theta}:\mathcal{X}\to\mathcal{Z}$ satisfying permutation invariance. Specifically, given a set $X_{i}=\{\mathbf{x}_{i,j}\}_{j=1}^{N_{i}}\in\mathcal{X}$, the output of the function $Z_{i}=f_{\theta}(X_{i})$ is a fixed sized representation which is invariant to all permutations of the indices $\{1,\ldots,N_{i}\}$. For supervised learning, we define a task specific decoder $g_{\lambda}:\mathcal{Z}\to\mathbb{R}^{d_{y}}$ and optimize parameters $\theta$ and $\lambda$ to minimize the loss

$L(\theta,\lambda)=\frac{1}{n}\sum_{i=1}^{n}\ell((g_{\lambda}\circ f_{\theta})(X_{i}),y_{i})$    (1)

on training data $((X_{i},y_{i}))_{i=1}^{n}$, where $y_{i}$ is a label for the input set $X_{i}$ and $\ell$ denotes a loss function.

Definition 3.1 (Permutation Invariance).

Let $\mathfrak{S}_{N}$ be the set of all permutations of $\{1,\ldots,N\}$, i.e., $\mathfrak{S}_{N}=\{\pi:[N]\to[N]\mid\pi\text{ is bijective}\}$ where $[N]\coloneqq\{1,\ldots,N\}$. A function $f_{\theta}:\mathcal{X}\to\mathcal{Z}$ is permutation invariant iff $f_{\theta}(\{\mathbf{x}_{\pi(1)},\ldots,\mathbf{x}_{\pi(N)}\})=f_{\theta}(\{\mathbf{x}_{1},\ldots,\mathbf{x}_{N}\})$ for all $X\in\mathcal{X}$ and for all permutations $\pi\in\mathfrak{S}_{N}$.

We further assume that the cardinality of a set $X$ is sufficiently large, such that loading and processing the whole set at once is computationally prohibitive. For non-MBC models, a naïve approach to this problem would be to encode a small subset of the full set as an approximation, leading to a possibly suboptimal representation of the full set. Instead, Bruno et al. (2021) propose a mini-batch consistent (MBC) set encoder, the Slot Set Encoder (SSE), to piecewise process disjoint subsets of the full set and aggregate them to obtain a consistent full set representation.

Definition 3.2 (Mini-Batch Consistency).

We say a function $f_{\theta}$ is mini-batch consistent iff for any $X\in\mathcal{X}$, there is a function $h$ such that for any partition $\zeta(X)$ of the set $X$,

$f_{\theta}(X)=h\left(\{f_{\theta}(S)\in\mathcal{Z}\mid S\in\zeta(X)\}\right).$    (2)

Models which satisfy the MBC property can partition a set into subsets, encode them, and then aggregate the subset representations to achieve the exact same output as encoding the full set. Due to constraints on the architecture of the SSE, however, on certain tasks the SSE shows weaker performance than non-MBC set encoders such as the Set Transformer (Lee et al., 2019), which utilizes self-attention. To tackle this limitation, we propose Universal MBC (UMBC) set encoders which are both MBC and also allow for the use of arbitrary non-MBC set functions while still satisfying the MBC property.
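To make Definitions 3.1 and 3.2 concrete, the following minimal sketch (our own illustration with arbitrary weights, not code from any of the cited works) checks both properties for a sum-pooled DeepSets-style encoder, for which the aggregation function $h$ in equation 2 is simply another sum:

```python
import torch

torch.manual_seed(0)
W = torch.randn(8, 16)

def deepsets_sum(X):
    # f_theta with an elementwise feature extractor and sum pooling: trivially MBC
    return torch.relu(X @ W).sum(dim=0)

X = torch.randn(100, 8)                                    # a set of 100 elements
perm = torch.randperm(X.shape[0])
full = deepsets_sum(X)
chunked = sum(deepsets_sum(S) for S in X.split(13))        # h = sum over subset encodings
print(torch.allclose(full, deepsets_sum(X[perm]), atol=1e-5))  # permutation invariant
print(torch.allclose(full, chunked, atol=1e-5))                # mini-batch consistent
```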

3.2 Universal Mini-Batch Consistent Set Encoder

In this section, we provide a formulation of our UMBC set encoder $f_{\theta}$. Given an input set $X\in\mathcal{X}$, we represent it as a matrix $X=[\mathbf{x}_{1}\cdots\mathbf{x}_{N}]^{\top}\in\mathbb{R}^{N\times d_{x}}$ whose rows are elements of the set, and independently process each element with $\phi:\mathbb{R}^{d_{x}}\to\mathbb{R}^{d_{h}}$ as $\Phi(X)=[\phi(\mathbf{x}_{1})\cdots\phi(\mathbf{x}_{N})]^{\top}$, where $\phi$ is a deep feature extractor. We then compute the un-normalized attention score between a set of learnable slots $\Sigma=[\mathbf{s}_{1}\cdots\mathbf{s}_{k}]^{\top}\in\mathbb{R}^{k\times d_{s}}$ and $\Phi(X)$ as:

$Q=\texttt{LN}(\Sigma W^{Q}),\quad K(X)=\Phi(X)W^{K},\quad V(X)=\Phi(X)W^{V}$
$\hat{A}=\sigma\left(\sqrt{d^{-1}}\cdot QK(X)^{\top}\right)\in\mathbb{R}^{k\times N},$    (3)

where $\sigma$ is an element-wise activation function with $\sigma(x)>0$ for all $x\in\mathbb{R}$, $\texttt{LN}$ denotes layer normalization (Ba et al., 2016), and $W^{Q}\in\mathbb{R}^{d_{s}\times d},W^{K}\in\mathbb{R}^{d_{h}\times d},W^{V}\in\mathbb{R}^{d_{h}\times d}$ are parameters which are part of $\theta$. For simplicity, we omit biases for $Q$, $K$, and $V$. With the un-normalized attention score $\hat{A}$, we can define a map

$\hat{f}_{\theta}:X\in\mathbb{R}^{N\times d_{x}}\mapsto\nu_{p}(\hat{A})V(X)\in\mathbb{R}^{k\times d}$    (4)

for $p=1,2$, where $\nu_{p}:\mathbb{R}^{k\times N}\to\mathbb{R}^{k\times N}$ is defined by either $\nu_{1}(\hat{A})_{i,j}=\hat{A}_{i,j}/\sum_{i=1}^{k}\hat{A}_{i,j}$, which normalizes the columns, or the identity mapping $\nu_{2}(\hat{A})_{i,j}=\hat{A}_{i,j}$. The choice of $\nu_{p}$ depends on the desired activation function $\sigma$. Alternatively, similar to slot attention (Locatello et al., 2020), we can make the function stochastic by sampling $\mathbf{s}_{i}\sim\mathcal{N}(\mu_{i},\mathrm{diag}(\texttt{softplus}(\mathbf{v}_{i})))$ with reparameterization (Kingma & Welling, 2013) for $i=1,\ldots,k$, where $\mu_{i}\in\mathbb{R}^{d_{s}},\mathbf{v}_{i}\in\mathbb{R}^{d_{s}}$ are part of the parameters $\theta$. If we sample $\mathbf{s}_{1},\ldots,\mathbf{s}_{k}\stackrel{\text{iid}}{\sim}\mathcal{N}(\mu_{1},\mathrm{diag}(\texttt{softplus}(\mathbf{v}_{1})))$ with a sigmoid for $\sigma$ and $\nu_{1}$ for normalization, and then apply a pooling function (sum, mean, min, or max) to the columns of $[\hat{f}_{\theta}(X)_{1}^{\top}\cdots\hat{f}_{\theta}(X)_{k}^{\top}]^{\top}\in\mathbb{R}^{k\times d}$, we recover a function equivalent to the SSE, where $\hat{f}_{\theta}(X)_{i}$ is the $i$-th row of $\hat{f}_{\theta}(X)$.

However, SSE has some drawbacks. First, since the attention score $\nu_{p}(\hat{A})_{i,j}$ is independent of the other $N-1$ attention scores $\nu_{p}(\hat{A})_{i,l}$ for $l\neq j$, the rows of $\nu_{p}(\hat{A})$ cannot form convex coefficients as the softmax outputs do in conventional attention (Vaswani et al., 2017). Notably, in some of our experiments, the constrained attention activation originally used in the SSE, which we call slot-sigmoid, significantly degrades generalization performance. Furthermore, stacking hierarchical SSE layers has been shown to harm performance (Bruno et al., 2021), which limits the power of the overall model.

To overcome these limitations of the SSE, we propose a Universal Mini-Batch Consistent (UMBC) set encoder $f_{\theta}$ by allowing the set function $f_{\theta}$ to also use arbitrary non-MBC functions. Firstly, we propose normalizing the attention matrix $\nu_{p}(\hat{A})$ over rows to consider dependency among different elements of the set in the attention operation:

$\bar{f}_{\theta}:X\in\mathbb{R}^{N\times d_{x}}\mapsto\nu_{p}(\hat{A})\bm{1}_{N}\in\mathbb{R}^{k}$    (5)
$f_{\theta}:X\in\mathbb{R}^{N\times d_{x}}\mapsto\mathrm{diag}\left(\bar{f}_{\theta}(X)\right)^{-1}\hat{f}_{\theta}(X)\in\mathbb{R}^{k\times d}$    (6)

where $\bm{1}_{N}=(1,\ldots,1)\in\mathbb{R}^{N}$. We prove that a UMBC set encoder $f_{\theta}$ is permutation invariant, equivariant, and MBC.

Theorem 3.3.

A UMBC function is permutation invariant.

Any strictly positive elementwise function is a valid $\sigma$. For instance, if we use the identity mapping $\nu_{2}$ with $\sigma(\cdot)\coloneqq\exp(\cdot)$, the attention matrix $\mathrm{diag}(\bar{f}_{\theta}(X))^{-1}\nu_{p}(\hat{A})$ is equivalent to applying the softmax to each row of $\hat{A}$, which was hypothesized to break the MBC property by Bruno et al. (2021). However, we show that this does not break the MBC property in Appendix A.2. Intuitively, since

$\hat{f}_{\theta}(X)=\sum_{S\in\zeta(X)}\hat{f}_{\theta}(S), \quad\text{and}\quad \bar{f}_{\theta}(X)=\sum_{S\in\zeta(X)}\bar{f}_{\theta}(S)$    (7)

holds for any partition $\zeta(X)$ of the set $X$, we can iteratively process each subset $S\in\zeta(X)$ and aggregate the results without losing any information of $f_{\theta}(X)$, i.e., $f_{\theta}$ is MBC even when normalizing over the $N$ elements of the set. Note that the operation outlined above is mathematically equivalent to the softmax, but uses a non-standard implementation. We discuss the implementation and list 5 such valid attention activation functions which satisfy the MBC property in Appendix I.
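As a concrete illustration of equations 3–7, the following sketch (our own, with assumed shapes, random weights, and LayerNorm/biases omitted) accumulates the numerator $\hat{f}_{\theta}$ and the normalizer $\bar{f}_{\theta}$ chunk by chunk and divides once at the end; with $\sigma=\exp$ and the identity $\nu_{2}$, the result matches full-set softmax cross-attention for any partition:

```python
import torch

def umbc_softmax_stream(slots, feats, Wq, Wk, Wv, chunk):
    # slots: (k, ds); feats: (N, dh) playing the role of Phi(X)
    d = Wq.shape[1]
    Q = slots @ Wq
    num = torch.zeros(slots.shape[0], d)          # accumulates \hat{f}_theta(X), eq. (4)
    den = torch.zeros(slots.shape[0], 1)          # accumulates \bar{f}_theta(X), eq. (5)
    for S in feats.split(chunk):                  # any partition zeta(X) of the set
        A = torch.exp(Q @ (S @ Wk).T / d ** 0.5)  # sigma = exp, nu_2 = identity
        num += A @ (S @ Wv)
        den += A.sum(dim=1, keepdim=True)
    return num / den                              # diag(f_bar)^{-1} f_hat, eq. (6)

torch.manual_seed(0)
k, ds, dh, d, N = 4, 16, 32, 64, 100
slots, feats = torch.randn(k, ds), torch.randn(N, dh)
Wq, Wk, Wv = 0.1 * torch.randn(ds, d), 0.1 * torch.randn(dh, d), 0.1 * torch.randn(dh, d)
full = umbc_softmax_stream(slots, feats, Wq, Wk, Wv, chunk=N)  # one chunk = full set
piecewise = umbc_softmax_stream(slots, feats, Wq, Wk, Wv, chunk=7)
print(torch.allclose(full, piecewise, atol=1e-4))              # True: MBC holds
```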

Theorem 3.4.

Given the slots $\Sigma=[\mathbf{s}_{1}\cdots\mathbf{s}_{k}]^{\top}\in\mathbb{R}^{k\times d_{s}}$, a UMBC set encoder is mini-batch consistent.

Lastly, we may consider the output of a UMBC set encoder $f_{\theta}(X)$ as either a fixed vector or a set of $k$ elements. Under the set interpretation, we may therefore apply subsequent set functions to this set of cardinality $k$. To provide a valid input to subsequent set encoders, it is sufficient to view UMBC as a set-to-set function $\varphi(\Sigma;X,\theta):\Sigma\mapsto f_{\theta}(X)$ for each set $X\in\mathcal{X}$, which is permutation equivariant w.r.t. the slots $\Sigma$.

Definition 3.5.

A function $\varphi:\mathbb{R}^{k\times d_{s}}\to\mathbb{R}^{k\times d}$ is said to be permutation equivariant iff $\varphi([\Sigma_{\pi(1)}^{\top}\cdots\Sigma_{\pi(k)}^{\top}]^{\top})=[\varphi(\Sigma)_{\pi(1)}^{\top}\cdots\varphi(\Sigma)_{\pi(k)}^{\top}]^{\top}$ for all $\Sigma\in\mathbb{R}^{k\times d_{s}}$ and for all $\pi\in\mathfrak{S}_{k}$, where $\mathfrak{S}_{k}=\{\pi:[k]\to[k]\mid\pi\text{ is bijective}\}$ contains all permutations of $\{1,\ldots,k\}$, and $\varphi(\Sigma)_{i},\Sigma_{i}$ denote the $i$-th rows of $\varphi(\Sigma)$ and $\Sigma$, respectively.

Theorem 3.6.

For each input $X\in\mathcal{X}$, $\varphi(\Sigma;X,\theta):\Sigma\mapsto f_{\theta}(X)$ is equivariant w.r.t. permutations of the slots $\Sigma$.

A key insight is that we can leverage non-MBC set encoders such as the Set Transformer after a UMBC layer to improve the expressive power of an MBC model while still satisfying MBC (Definition 3.2). As a result, under some assumptions, a UMBC set encoder used in combination with any continuously sum-decomposable (Zaheer et al., 2017) permutation invariant deep neural network is a universal approximator of the class of continuously sum-decomposable functions.

Theorem 3.7.

Let $d_{x}=1$ and restrict the domain $\mathcal{X}$ to $[0,1]^{M}$. Suppose that the nonlinear activation function of $\phi$ has nonzero Taylor coefficients up to degree $M$. Then, UMBC used in conjunction with any continuously sum-decomposable permutation-invariant deep neural network with nonlinear activation functions that are not polynomials of finite degree is a universal approximator of the class of functions $\mathcal{F}=\{f:[0,1]^{M}\to\mathbb{R}\mid f\text{ is continuous and permutation invariant}\}$.

Although we use a non-MBC set encoder on top of UMBC, this does not violate the MBC property. Since we may obtain $f_{\theta}(X)$ by sequentially processing each subset of $X$, and the resulting set with cardinality $k$ is assumed small enough to load $f_{\theta}(X)$ in memory, we can directly provide the MBC output of UMBC to the non-MBC set encoder.

Corollary 3.8.

Let $\hat{g}_{\omega}:\mathbb{R}^{k\times d}\to\mathcal{Z}$ be a (non-MBC) set encoder and let $f_{\theta}:\mathcal{X}\to\mathbb{R}^{k\times d}$ be a UMBC set encoder. Then $\hat{g}_{\omega}\circ f_{\theta}$ is mini-batch consistent.

For notational convenience, we write $g_{\lambda}$ to indicate the composition $g_{\lambda}\circ\hat{g}_{\omega}$ of a set encoder $\hat{g}_{\omega}:\mathbb{R}^{k\times d}\to\mathcal{Z}$ and a decoder $g_{\lambda}:\mathcal{Z}\to\mathbb{R}^{d_{y}}$ throughout the paper. Similarly, the parameter $\lambda$ denotes $(\omega,\lambda)$.

3.3 Stochastic Approximation of the Full Set Gradient

Although we can leverage SSE or UMBC at test-time by sequentially processing subsets to obtain the full set representation $f_{\theta}(X)$, at train-time it is infeasible to utilize the gradient of the loss (equation 1) w.r.t. the full set. Computing the full set gradient with automatic differentiation requires storing the entire computation graph for the forward passes of every subset $S$ in a partition $\zeta(X)$ of a set $X$, which incurs a prohibitive computational cost for large sets. As a simple approximation, Bruno et al. (2021) propose randomly sampling a single subset $S_{i,j}\in\zeta(X_{i})$ and computing the gradient of the loss $\ell((g_{\lambda}\circ f_{\theta})(S_{i,j}),y_{i})$ based on that single subset at each iteration.

Remark 3.9.

Let $\zeta(X_{i})$ be a partition of a set $X_{i}\in\mathcal{X}$ and $S_{i,j}\in\zeta(X_{i})$ be a subset of $X_{i}$. Then the gradient of $\frac{1}{n}\sum_{i=1}^{n}\ell((g_{\lambda}\circ f_{\theta})(S_{i,j}),y_{i})$ is a biased estimate of the full set gradient and leads to a suboptimal solution in our experiments. Please see Appendix A.8 for further details.

In order to tackle this issue, we propose an unbiased estimate of the full set gradient which incurs a constant memory overhead. Firstly, we uniformly and independently sample a mini-batch $((\bar{X}_{i},\bar{y}_{i}))_{i=1}^{m}$ from the training dataset $((X_{i},y_{i}))_{i=1}^{n}$ at every iteration $t\in\mathbb{N}_{+}$. We denote this process by $((\bar{X}_{i},\bar{y}_{i}))_{i=1}^{m}\sim D[((X_{i},y_{i}))_{i=1}^{n}]$. Then, for each $\bar{X}_{i}$, we sample a mini-batch $\bar{\zeta}_{t}(\bar{X}_{i})=\{\bar{S}_{1},\ldots,\bar{S}_{\lvert\bar{\zeta}_{t}(\bar{X}_{i})\rvert}\}$ from the partition $\zeta_{t}(\bar{X}_{i})=\{S_{1},\ldots,S_{\lvert\zeta_{t}(\bar{X}_{i})\rvert}\}$ of $\bar{X}_{i}$, i.e., all $\bar{S}_{j}$ are drawn independently and uniformly from $\zeta_{t}(\bar{X}_{i})$. Denote this process by $\bar{\zeta}_{t}(\bar{X}_{i})\sim D[\zeta_{t}(\bar{X}_{i})]$. Instead of storing the computation graph of all forward passes of subsets in the partition $\zeta_{t}(\bar{X}_{i})$ of a set $\bar{X}_{i}$, we apply StopGrad to all subsets $S\notin\bar{\zeta}_{t}(\bar{X}_{i})$ as follows:

$\hat{f}_{\theta}^{\bar{\zeta}_{t},\zeta_{t}}(\bar{X}_{i})=\sum_{S\in\bar{\zeta}_{t}(\bar{X}_{i})}\hat{f}_{\theta}(S)+\texttt{StopGrad}\Big(\sum_{S\in\zeta_{t}(\bar{X}_{i})\setminus\bar{\zeta}_{t}(\bar{X}_{i})}\hat{f}_{\theta}(S)\Big)$    (8)
$\bar{f}_{\theta}^{\bar{\zeta}_{t},\zeta_{t}}(\bar{X}_{i})=\sum_{S\in\bar{\zeta}_{t}(\bar{X}_{i})}\bar{f}_{\theta}(S)+\texttt{StopGrad}\Big(\sum_{S\in\zeta_{t}(\bar{X}_{i})\setminus\bar{\zeta}_{t}(\bar{X}_{i})}\bar{f}_{\theta}(S)\Big)$    (9)

where, for any function $q:\theta\mapsto q(\theta)$, the symbol $\texttt{StopGrad}(q(\theta))$ denotes a constant whose value is $q(\theta)$, i.e., $\partial\texttt{StopGrad}(q(\theta))/\partial\theta=0$. For simplicity, we omit the superscript $\zeta_{t}$ if there is no ambiguity. Finally, we update the parameters $\theta$ and $\lambda$ of the encoder and decoder using the gradients of the following functions, respectively, at each iteration $t\in\mathbb{N}_{+}$:

$L_{t,1}(\theta,\lambda)=\frac{1}{m}\sum_{i=1}^{m}\frac{\lvert\zeta_{t}(\bar{X}_{i})\rvert}{\lvert\bar{\zeta}_{t}(\bar{X}_{i})\rvert}\,\ell(g_{\lambda}(f_{\theta}^{\bar{\zeta}_{t}}(\bar{X}_{i})),y_{i})$    (10)
$L_{t,2}(\theta,\lambda)=\frac{1}{m}\sum_{i=1}^{m}\frac{1}{\lvert\bar{\zeta}_{t}(\bar{X}_{i})\rvert}\,\ell(g_{\lambda}(f_{\theta}^{\bar{\zeta}_{t}}(\bar{X}_{i})),y_{i}).$    (11)

We outline our proposed training method in Algorithm 1. Note that we can apply our algorithm to any set encoder whose full set representation decomposes into a summation of subset representations as in equation 7, such as DeepSets with sum or mean pooling, or SSE, which are in fact special cases of UMBC. Furthermore, we can apply the algorithm to any differentiable non-MBC set encoder by simply placing a UMBC layer before the non-MBC function. As a consequence of the $\texttt{StopGrad}()$ operation, if we set $\lvert\bar{\zeta}_{t}(\bar{X}_{i})\rvert=1$, our method incurs the same computation-graph storage cost as randomly sampling a single subset. Moreover, $\partial L_{t,1}(\theta,\lambda)/\partial\theta$ and $\partial L_{t,2}(\theta,\lambda)/\partial\lambda$ are unbiased estimators of $\partial L(\theta,\lambda)/\partial\theta$ and $\partial L(\theta,\lambda)/\partial\lambda$, respectively.
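A minimal PyTorch sketch of equations 8 and 9 is given below; `encode_chunk` is a hypothetical callable returning the pair $(\hat{f}_{\theta}(S),\bar{f}_{\theta}(S))$, and Algorithm 1 remains the authoritative description:

```python
import torch

def umbc_stopgrad_forward(encode_chunk, chunks, grad_idx):
    """Aggregate subset representations while keeping the computation graph
    only for the sampled subsets indexed by grad_idx (equations 8-9)."""
    num, den = 0.0, 0.0
    for j, S in enumerate(chunks):
        if j in grad_idx:
            f_hat, f_bar = encode_chunk(S)   # graph kept: these chunks receive gradients
        else:
            with torch.no_grad():            # StopGrad: no graph stored, constant memory
                f_hat, f_bar = encode_chunk(S)
        num, den = num + f_hat, den + f_bar
    return num / den                         # full set representation, same value as eq. (6)

# The encoder loss (eq. 10) rescales by |zeta_t| / |bar-zeta_t| and the decoder loss
# (eq. 11) by 1 / |bar-zeta_t| so that the respective gradients are unbiased.
```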

Theorem 3.10.

For any $t\in\mathbb{N}_{+}$, $\frac{\partial L_{t,1}(\theta,\lambda)}{\partial\theta}$ and $\frac{\partial L_{t,2}(\theta,\lambda)}{\partial\lambda}$ are unbiased estimators of $\frac{\partial L(\theta,\lambda)}{\partial\theta}$ and $\frac{\partial L(\theta,\lambda)}{\partial\lambda}$, i.e.:

$\mathbb{E}_{((\bar{X}_{i},\bar{y}_{i}))_{i=1}^{m}}\mathbb{E}_{(\bar{\zeta}_{t}(\bar{X}_{i}))_{i=1}^{m}}\left[\frac{\partial L_{t,1}(\theta,\lambda)}{\partial\theta}\right]=\frac{\partial L(\theta,\lambda)}{\partial\theta}$    (12)
$\mathbb{E}_{((\bar{X}_{i},\bar{y}_{i}))_{i=1}^{m}}\mathbb{E}_{(\bar{\zeta}_{t}(\bar{X}_{i}))_{i=1}^{m}}\left[\frac{\partial L_{t,2}(\theta,\lambda)}{\partial\lambda}\right]=\frac{\partial L(\theta,\lambda)}{\partial\lambda},$    (13)

where the first expectation is taken over $((\bar{X}_{i},\bar{y}_{i}))_{i=1}^{m}\sim D[((X_{i},y_{i}))_{i=1}^{n}]$, and the second expectation is taken over $\bar{\zeta}_{t}(\bar{X}_{i})\sim D[\zeta_{t}(\bar{X}_{i})]$ for all $i\in\{1,\ldots,m\}$.

Under mild conditions, Theorem 3.10 implies that updating $\theta$ with our method makes progress toward minimizing $L(\theta,\lambda)$. We formalize this statement in Appendix B.

4 Experiments

(a) DeepSets  (b) SSE  (c) UMBC + ST (Ours)
Figure 4: Performance of (a) DeepSets (b) SSE, and (c) UMBC with varying set sizes for amortized clustering on Mixtures of Gaussians. The unbiased estimate of the full set gradient outperforms the biased estimate, and is usually indistinguishable from the full set gradient.

4.1 Amortized Clustering

We consider amortized clustering on a dataset generated from a mixture of $K$ Gaussians (MoG) (see Appendix C for dataset construction details). Given a set $X_{i}=\{\mathbf{x}_{i,j}\}_{j=1}^{N_{i}}$ sampled from a MoG, the goal is to output the mixing coefficients and the Gaussian means and variances which minimize the negative log-likelihood of the set as follows:

$\{\pi_{j}(X_{i}),\mu_{j}(X_{i}),\Sigma_{j}(X_{i})\}_{j=1}^{K}=f_{\theta}(X_{i})$    (14)
$\ell(f_{\theta}(X_{i}))=-\sum_{l=1}^{N_{i}}\log\sum_{j=1}^{K}\pi_{j}(X_{i})\,\mathcal{N}(\mathbf{x}_{i,l};\Theta_{j}(X_{i}))$    (15)

where $\Theta_{j}(X_{i})=(\mu_{j}(X_{i}),\Sigma_{j}(X_{i}))$ denotes the mean vector and diagonal covariance matrix of the $j$-th Gaussian, and $\pi_{j}(X_{i})$ is the $j$-th mixing coefficient. Note that there is no label $y_{i}$ since this is an unsupervised clustering problem. We optimize the parameters of the set encoder $\theta$ to minimize the loss over a batch, $L(\theta)=\frac{1}{n}\sum_{i=1}^{n}\ell(f_{\theta}(X_{i}))$.
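For reference, a short sketch of the per-set objective in equation 15 for a diagonal-covariance mixture (tensor shapes are our assumption):

```python
import torch
from torch.distributions import Normal

def mog_nll(x, pi, mu, sigma):
    # x: (N, dx) set elements; pi: (K,) mixing weights; mu, sigma: (K, dx)
    log_comp = Normal(mu, sigma).log_prob(x[:, None, :]).sum(-1)      # (N, K) diagonal-Gaussian log-densities
    return -torch.logsumexp(torch.log(pi) + log_comp, dim=-1).sum()   # eq. (15)
```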

Setup. We evaluate training with the full set gradient vs. the unbiased estimate of the full set gradient. In this setting, for gradient computation, MBC models use a subset of 8 elements from a full set of 1024 elements. Non-MBC models such as Set Transformer (ST), FSPool (Zhang et al., 2020), and Diff EM (Kim, 2022) are also trained with sets of 8 elements. We compare our UMBC model against DeepSets, SSE, Set Transformer, FSPool, and Diff EM. Note that at test-time all non-MBC models process every 8-element subset from the full set independently and aggregate the representations with mean pooling.

Table 2: Clustering: NLL, memory usage, and wall-clock time.
Model                  MBC   NLL (↓)         Memory (KB)   Time (ms)
DeepSets               ✓     2.43 ± 0.004    16            0.46 ± 0.07
SSE                    ✓     2.13 ± 0.067    61            0.83 ± 0.08
SSE (Hierarchical)     ✓     2.38 ± 0.057    125           0.84 ± 0.06
FSPool                 ✗     3.52 ± 0.192    43            0.79 ± 0.08
Diff EM                ✗     5.58 ± 0.966    476           7.11 ± 0.31
Set Transformer (ST)   ✗     11.6 ± 2.180    225           2.26 ± 0.135
UMBC + FSPool          ✓     2.01 ± 0.027    70            1.18 ± 0.10
UMBC + Diff EM         ✓     2.13 ± 0.084    502           8.57 ± 0.33
UMBC + ST              ✓     1.84 ± 0.008    100           1.63 ± 0.11

Results. In Figure 4, interestingly, the unbiased estimate of the full set gradient (red) is almost indistinguishable from the full set gradient (blue) for DeepSets and UMBC, while there is a significant gap for SSE. In all cases, the unbiased estimate of the full set gradient outperforms training with the biased gradient approximation using only a randomly sampled subset of 8 elements (green), as proposed by Bruno et al. (2021). Lastly, as shown in Table 2, we compare all models in terms of generalization performance (NLL), memory usage, and wall-clock time for processing a single subset. All non-MBC models underperform due to their violation of the MBC property. However, 'UMBC+' compositions are MBC and significantly improve performance with little added overhead in memory and time complexity. In contrast, the Hierarchical SSE, a composition of pure MBC functions, degrades the performance of SSE. Notably, UMBC with Set Transformer outperforms all other models, whereas the Set Transformer alone achieves the worst NLL. These results verify the expressive power of UMBC in conjunction with non-MBC models.

Figure 5: (a) Comparison of different models with varying sizes of sets. (b) Memory usage for models to process a single set with different cardinalities denoted as the size of the circles. MBC models with our unbiased full set gradient approximation can process any set size with a constant memory overhead. (c) The effect of our algorithm compared to training with a small random subset and the full set.

4.2 Image Completion

In this task, we are given a set of $M$ RGB pixel values $y_{i}\in\mathbb{R}^{3}\times\cdots\times\mathbb{R}^{3}$ of an image as well as the corresponding 2-dimensional coordinates $(\mathbf{x}_{i,1},\ldots,\mathbf{x}_{i,M})$ normalized to lie in $[0,1]^{2}\times\cdots\times[0,1]^{2}$. The goal of the task is to predict the RGB pixel values of all $M$ coordinates of the image. Specifically, given a context set $X_{i}=\{(\mathbf{x}_{i,c_{j}},y_{i,c_{j}})\}_{j=1}^{N_{i}}$ processed by the set encoder, we obtain the set representation $f_{\theta}(X_{i})\in\mathbb{R}^{k\times d}$. Then, a decoder $g_{\lambda}:\mathbb{R}^{k\times d}\times[0,1]^{2}\times\cdots\times[0,1]^{2}\to\mathbb{R}^{6}\times\cdots\times\mathbb{R}^{6}$, which utilizes both the set representation and the target coordinates, learns to predict a mean and variance for each coordinate of the image as $((\hat{\mu}_{i,j,l})_{l=1}^{3},(\hat{\sigma}_{i,j,l})_{l=1}^{3})_{j=1}^{M}=g_{\lambda}(\mathbf{z}_{i})$, where $\mathbf{z}_{i}=(f_{\theta}(X_{i}),\mathbf{x}_{i,1},\ldots,\mathbf{x}_{i,M})$. Then we compute the negative log-likelihood of the label set $y_{i}=((y_{i,j,l})_{l=1}^{3})_{j=1}^{M}$:

$\ell(g_{\lambda}(\mathbf{z}_{i}),y_{i})=-\sum_{j=1}^{M}\sum_{l=1}^{3}\log\mathcal{N}(y_{i,j,l};\hat{\mu}_{i,j,l},\hat{\sigma}_{i,j,l}),$    (16)

where $\mathcal{N}(\cdot;\mu,\sigma)$ is a univariate Gaussian probability density function. Finally, we optimize $\theta$ and $\lambda$ to minimize the loss $L(\theta,\lambda)=\frac{1}{n}\sum_{i=1}^{n}\ell(g_{\lambda}(\mathbf{z}_{i}),y_{i})$.
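A corresponding sketch of equation 16 for the predicted per-pixel Gaussians (shapes are assumptions):

```python
import torch
from torch.distributions import Normal

def completion_nll(mu, sigma, y):
    # mu, sigma, y: (M, 3) predicted means, standard deviations, and target RGB values
    return -Normal(mu, sigma).log_prob(y).sum()   # eq. (16)
```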

Setup. In our experiments, we impose the restriction that a set encoder is only allowed to compute the gradient with 100 elements of a context set $X_{i}$ during training, and the model can only process 100 elements of the context set at once at test time. We train the set encoders in a Conditional Neural Process (Garnelo et al., 2018) framework, using $32\times 32$ images from the CelebA dataset (Liu et al., 2015). We vary the cardinality of the context set size $N_{i}\in\{100,\ldots,500\}$ and compare the negative log-likelihood (NLL) of each model. For baselines, we compare our UMBC set encoder against: DeepSets, SSE, Hierarchical SSE, Set Transformer (ST), FSPool, and Diff EM. For our UMBC, we use the softmax for $\sigma$ in equation 3 and place the ST after the UMBC layer.

Results. First, as shown in Figure 5(a), our UMBC + ST model outperforms all baselines, empirically verifying the expressive power of UMBC. SSE underperforms in terms of NLL due to its constrained architecture. Moreover, stacking hierarchical SSE layers degrades the performance of SSE for larger sets. Note that all MBC set encoders (Deepsets, SSE, Hierarchical SSE and UMBC) in Figure 5(a) are trained with our proposed unbiased gradient approximation in Algorithm 1. On the other hand, we train non-MBC models such as Set Transformer (ST), FSPool, and Diff EM with a randomly sampled subset of 100 elements, and perform mean pooling over all subset representations at test-time to approximate an MBC model.

Additionally, Figure 5(b) shows GPU memory usage for each model while processing sets of varying cardinalities without a memory constraint. The marker size is proportional to the set cardinality. Notably, all four MBC models incur a constant memory overhead to process any set size, as we can apply StopGrad to most of the subsets and compute an unbiased gradient estimate with a fixed sized subset (100). However, the memory overhead of all non-MBC models is a function of set size. Thus, Set Transformer uses more than twice the memory of UMBC to achieve a similar log-likelihood on a set of 500 elements.

Lastly, in Figure 5(c), we show how our proposed unbiased training algorithm (red) improves the generalization performance of UMBC models compared to training with a limited subset of 100 elements (green). Notably, the performance of our algorithm is indistinguishable from that of training models with the full set gradient (blue). We present similar plots for DeepSets and SSE in Figures 11(a) and 11(b). Across all models, our training algorithm significantly and consistently improves performance compared to training with random subsets of 100 elements, while requiring the same amount of memory. These empirical results verify both the efficiency and effectiveness of our proposed method.

Table 3: Micro F1 score and memory usage on a Quadro RTX 8000 for long document classification with the inverted EURLEX dataset.
Model                 F1             MBC   Memory (MB)
Longformer            56.47 ± 0.43   ✗     25,185
ToBERT                67.11 ± 0.87   ✗     38,563
DeepSets w/ 100       59.97 ± 0.59   ✓     1,295
DeepSets w/ full      60.82 ± 0.58   ✓     7,317
SSE w/ 100            67.60 ± 0.17   ✓     1,319
SSE w/ full           67.91 ± 0.33   ✓     6,799
UMBC + BERT w/ 100    70.48 ± 0.23   ✓     4,909
UMBC + BERT w/ full   70.23 ± 0.84   ✓     11,497

4.3 Long Document Classification

In this task, we are given a long document $X_{i}=(\mathbf{x}_{i,1},\ldots,\mathbf{x}_{i,N_{i}})$ consisting of an average of 707.99 words. The goal of this task is to predict a binary multi-label $y_{i}=(y_{i,1},\ldots,y_{i,c})\in\{0,1\}^{c}$ for the document, where $c$ is the number of classes. We ignore the order of words and consider the document as a multiset of words, i.e., a set allowing duplicate elements. Specifically, given a training dataset $((X_{i},y_{i}))_{i=1}^{n}$, we process each set $X_{i}$ with the set encoder to obtain the set representation $f_{\theta}(X_{i})\in\mathbb{R}^{k\times d}$. We then use a decoder $g_{\lambda}:\mathbb{R}^{k\times d}\to\mathbb{R}^{c}$ to output the probability of each class and compute the cross entropy loss:

$\ell(\mathbf{z}_{i},y_{i})=-\sum_{j=1}^{c}\left(y_{i,j}\log\tilde{\sigma}(z_{i,j})+(1-y_{i,j})\log(1-\tilde{\sigma}(z_{i,j}))\right)$    (17)

where $\mathbf{z}_{i}=(z_{i,1},\ldots,z_{i,c})=(g_{\lambda}\circ f_{\theta})(X_{i})$ and $\tilde{\sigma}$ denotes the sigmoid function. Finally, we optimize $\theta$ and $\lambda$ to minimize the loss $L(\theta,\lambda)=\frac{1}{n}\sum_{i=1}^{n}\ell(\mathbf{z}_{i},y_{i})$.

Setup. All models are trained on the inverted EURLEX dataset (Chalkidis et al., 2019), which consists of long legal documents divided into sections. The order of the sections is inverted following prior work (Park et al., 2022). To predict a label, we give the whole document to the model without any truncation. We compare the micro F1 of each model.

We compare UMBC to DeepSets, SSE, ToBERT (Pappagari et al., 2019), and Longformer (Beltagy et al., 2020). For DeepSets and SSE, we use the pre-trained word embedding from BERT (Devlin et al., 2019) without positional encoding and a 2-layer fully connected (FC) network for the feature extractor $\phi$. We use another 3-layer FC network for the decoder. For UMBC, we use the same feature extractor as SSE but instead use the pre-trained BERT as a decoder, with a randomly initialized linear classifier. We remove the positional encoding of BERT for UMBC to ignore word order. We train all MBC models both with the full set, denoted "w/ full", and with our gradient approximation method on a subset of 100 elements, denoted "w/ 100".

Figure 6: F1 score as a function of the subset size used for gradient computation, for our method and random sampling.

Results. As shown in Table 3, our proposed UMBC outperforms all baselines, including the non-MBC models Longformer and ToBERT, which require excessive amounts of GPU memory for training on long sequences. This result again verifies the expressive power of UMBC with BERT (a non-MBC model) for long document classification. Moreover, with significantly less GPU memory, all MBC models (DeepSets, SSE, and UMBC) trained with our unbiased gradient approximation using a subset of 100 elements achieve similar performance to the models trained with the full set. Lastly, in Figure 6, we plot the micro F1 score as a function of the cardinality of the subset used for gradient computation when training the UMBC model. Our proposed unbiased gradient approximation (red) shows consistent performance for all subset cardinalities. In contrast, training the model with a small random subset (green) is unstable, resulting in underperformance and higher F1 variance.

Table 4: Accuracy and AUC of each model on the Camelyon16 dataset.
                           Pretrain                        MBC Finetune
Model         MBC   Accuracy        AUROC          Accuracy        AUROC
DS-MIL        ✗     86.36 ± 0.88    0.866 ± 0.00   -               -
AB-MIL        ✗     86.82 ± 0.00    0.884 ± 0.01   -               -
DeepSets      ✓     82.02 ± 0.65    0.848 ± 0.01   82.02 ± 0.65    0.848 ± 0.01
SSE           ✓     74.73 ± 1.04    0.748 ± 0.03   74.57 ± 1.27    0.755 ± 0.04
UMBC + ST     ✓     87.91 ± 1.41    0.874 ± 0.01   88.84 ± 0.88    0.892 ± 0.01

4.4 Multiple Instance Learning (MIL)

In MIL, we are given a 'bag' of instances with a corresponding bag label, but no labels for the individual instances within the bag. Labels should not depend on the order of the instances, i.e., MIL can be recast as a set classification problem. Specifically, given a set $X_{i}=\{\mathbf{x}_{i,j}\}_{j=1}^{N_{i}}$, the goal is to predict its binary label $y_{i}\in\{0,1\}$. For this task, we obtain two streams of set representations and compute the cross entropy loss from the output of the decoder $g_{\lambda}:\mathbb{R}^{k\times d}\to\mathbb{R}$ as:

$\mathbf{z}_{i,1}=\max\{w^{\top}\phi(\mathbf{x}_{i,j})+b\}_{j=1}^{N_{i}},\quad \mathbf{z}_{i,2}=f_{\theta}(X_{i})$    (18)
$\mathcal{L}_{i}=\frac{1}{2}\left(\ell(\mathbf{z}_{i,1},y_{i})+\ell(g_{\lambda}(\mathbf{z}_{i,2}),y_{i})\right),$    (19)

where $w\in\mathbb{R}^{d_{h}}$ and $b\in\mathbb{R}$ are parameters and $\ell$ is the cross entropy loss described in equation 17 with $c=1$. We optimize all parameters $\theta,\lambda,w,$ and $b$ to minimize the loss $\frac{1}{n}\sum_{i=1}^{n}\mathcal{L}_{i}$. At test time, we predict a label $y_{*}$ for a set $X_{*}$ as:

$p_{*}=\frac{1}{2}\left(\tilde{\sigma}(\mathbf{z}_{*,1})+\tilde{\sigma}(g_{\lambda}(\mathbf{z}_{*,2}))\right),\quad y_{*}=\mathbbm{1}\{p_{*}\geq\tau\},$    (20)

where $\mathbf{z}_{*,1}=\max\{w^{\top}\phi(\mathbf{x}):\mathbf{x}\in X_{*}\}$, $\mathbf{z}_{*,2}=f_{\theta}(X_{*})$, $\tilde{\sigma}$ is the sigmoid function, $\mathbbm{1}$ is the indicator function, and $0<\tau<1$ is a threshold tuned on the validation set.
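A small sketch of the two-stream prediction in equations 18–20; the decoder `g_lambda`, the patch feature matrix, and the default threshold are placeholders we introduce for illustration (in practice $\tau$ is tuned on the validation set):

```python
import torch

def mil_predict(phi_feats, set_repr, w, b, g_lambda, tau=0.5):
    # phi_feats: (N, dh) patch features phi(x); set_repr: f_theta(X) from the MBC encoder
    z1 = (phi_feats @ w + b).max()                      # instance stream, eq. (18)
    z2 = g_lambda(set_repr)                             # bag stream through the decoder
    p = 0.5 * (torch.sigmoid(z1) + torch.sigmoid(z2))   # eq. (20)
    return int(p >= tau)                                # predicted bag label y_*
```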

Setup. We evaluate all models on the Camelyon16 Whole Slide Image cancer detection dataset (Bejnordi et al., 2017). Each instance consists of a high resolution image of tissue from a medical scan which is pre-processed into $256\times 256$ patches of RGB pixels. After pre-processing, the average number of patches in a single set is over 9,300 (7.3 GB), making each input roughly equivalent to processing 1% of ImageNet1k (Deng et al., 2009). The largest input in the training set contains 32,382 patches (25.4 GB). We utilize a ResNet18 (He et al., 2016) which is pretrained on Camelyon16 (Li et al., 2021) via SimCLR (Chen et al., 2020) as a backbone feature extractor, whose weights can be downloaded from the repository at https://github.com/binli123/dsmil-wsi. Our goal is to first pretrain MBC set encoders on the extracted features, and then use the unbiased estimate of the full set gradient to fine-tune the feature extractor on the full input sets. We evaluate the performance of UMBC against non-MBC MIL baselines, DS-MIL (Li et al., 2021) and AB-MIL (Ilse et al., 2018), as well as MBC baselines, DeepSets and SSE.

Results. As shown in Table 4, our UMBC model achieves the best accuracy and a competitive AUROC score. Note that SSE shows the worst performance due to its constrained architecture, even underperforming DeepSets in this task. These empirical results again verify the expressive power of our UMBC model. Moreover, we can further improve the performance of UMBC by fine-tuning the backbone network, ResNet18, which is only feasible as a consequence of our unbiased full set gradient approximation with its constant memory overhead. However, it is not possible for the non-MBC models to fine-tune the ResNet18, since it is computationally prohibitive to compute the gradient of the ResNet18 over sets consisting of tens of thousands of patches at $256\times 256$ resolution.

4.5 Ablation Study

To validate the effectiveness of the activation function $\sigma$ for attention in equation 3, we train UMBC + Set Transformer with the different activation functions listed in Table 15 for the amortized clustering and MIL pretraining tasks. As shown in Figure 7 and Table 5, softmax attention outperforms all the other activation functions, whereas the slot-sigmoid used for attention in SSE underperforms. This experiment highlights the importance of choosing the proper activation function for attention, which is enabled by our UMBC framework.

Figure 7: $\sigma$ for amortized clustering
Table 5: $\sigma$ for MIL
Activation     Acc.           AUC
slot-sigmoid   83.72 ± 1.45   0.846 ± 0.01
slot-exp       83.26 ± 0.42   0.850 ± 0.01
sigmoid        81.71 ± 4.09   0.847 ± 0.03
slot-softmax   84.81 ± 0.04   0.848 ± 0.01
softmax        87.91 ± 1.41   0.874 ± 0.01

5 Limitations and Future Work

One potential limitation of our method is the higher time complexity needed for our proposed unbiased gradient estimation during mini-batch training. If $n$ represents the size of a single subset and $N$ represents the size of the whole set, then the naive sampling strategy of Bruno et al. (2021) has a time complexity of $\mathcal{O}(nk)$, while our full set gradient approximation has a complexity of $\mathcal{O}(Nk)$ during training since we must process the full set. However, we note some things we gain in exchange for this higher time complexity below.

Our unbiased gradient approximation can achieve higher performance than the biased sampling of a single subset, as shown in our experiments (specifically, see Figures 4, 5(c) and 6). Additionally, due to the stop gradient operation in Equation 9, UMBC achieves a constant memory overhead for any training set size. Our experiment in Figure 5(b) shows a constant memory overhead for SSE and DeepSets only because we apply our unbiased gradient approximation to those models in that experiment. The original models as presented in their original works do not have a constant memory overhead. As a result, our method can process huge sets during training, and practically any GPU size can be accommodated by adjusting the size of the gradient set. For example, the average set in the experiment on Camelyon16 contains 9,329 patches (7.3 GB), while the largest input in the training set contains 32,382 patches (25.46 GB). Even though the set sizes are large, models can be trained on all inputs using a single 12GB GPU due to the constant memory overhead.

Another potential limitation is that UMBC (and SSE) use a cross attention layer with parameterized slots in order to achieve an MBC model. However, the fixed parameters, which are independent of the input set, can be seen as a type of bottleneck in the attention layer which is not present in traditional self-attention. We therefore look forward to future work which may find ways to make the slot parameters depend on the input set, which may increase overall model expressivity.

6 Conclusion

In order to overcome the limited expressive power and training scalability of existing MBC set functions, such as DeepSets and SSE, we have proposed Universal MBC set functions that allow mixing both MBC and non-MBC components to leverage a broader range of architectures which increases model expressivity while universally maintaining the MBC property. Additionally, we generalized MBC attention activation functions, showing that many functions, including the softmax, are MBC. Furthermore, for training scalability, we have proposed an unbiased approximation to the full set gradient with a constant memory overhead for processing a set of any size. Lastly, we have performed extensive experiments to verify the efficiency and efficacy of our scalable set encoding framework, and theoretically shown that UMBC is a universal approximator of continuous permutation invariant functions and converges to stationary points of the total loss with the full set.

Acknowledgements

This work was supported by Institute of Information & communications Technology Planning & Evaluation (IITP) grant funded by the Korea government(MSIT) (No.2019-0-00075, Artificial Intelligence Graduate School Program(KAIST)), the Engineering Research Center Program through the National Research Foundation of Korea (NRF) funded by the Korean Government MSIT (NRF-2018R1A5A1059921), Institute of Information & communications Technology Planning & Evaluation (IITP) grant funded by the Korea government(MSIT) (No. 2021-0-02068, Artificial Intelligence Innovation Hub), Institute of Information & communications Technology Planning & Evaluation (IITP) grant funded by the Korea government(MSIT) (No. 2022-0-00184, Development and Study of AI Technologies to Inexpensively Conform to Evolving Policy on Ethics), Institute of Information & communications Technology Planning & Evaluation (IITP) grant funded by the Korea government(MSIT) (No.2022-0-00713), and Samsung Electronics (IO201214-08145-01)

References

  • Ba et al. (2016) Ba, J. L., Kiros, J. R., and Hinton, G. E. Layer normalization. arXiv preprint arXiv:1607.06450, 2016.
  • Bejnordi et al. (2017) Bejnordi, B. E., Veta, M., Van Diest, P. J., Van Ginneken, B., Karssemeijer, N., Litjens, G., Van Der Laak, J. A., Hermsen, M., Manson, Q. F., Balkenhol, M., et al. Diagnostic assessment of deep learning algorithms for detection of lymph node metastases in women with breast cancer. Jama, 318(22):2199–2210, 2017.
  • Beltagy et al. (2020) Beltagy, I., Peters, M. E., and Cohan, A. Longformer: The long-document transformer. arXiv preprint arXiv:2004.05150, 2020.
  • Bruno et al. (2021) Bruno, A., Willette, J., Lee, J., and Hwang, S. J. Mini-batch consistent slot set encoder for scalable set encoding. Advances in Neural Information Processing Systems, 34:21365–21374, 2021.
  • Chalkidis et al. (2019) Chalkidis, I., Fergadiotis, E., Malakasiotis, P., and Androutsopoulos, I. Large-scale multi-label text classification on EU legislation. In Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics, pp.  6314–6322. Association for Computational Linguistics, 2019.
  • Chen et al. (2020) Chen, T., Kornblith, S., Norouzi, M., and Hinton, G. A simple framework for contrastive learning of visual representations. In International conference on machine learning, pp. 1597–1607. PMLR, 2020.
  • Deng et al. (2009) Deng, J., Dong, W., Socher, R., Li, L.-J., Li, K., and Fei-Fei, L. Imagenet: A large-scale hierarchical image database. In 2009 IEEE conference on computer vision and pattern recognition, pp.  248–255. Ieee, 2009.
  • Devlin et al. (2019) Devlin, J., Chang, M.-W., Lee, K., and Toutanova, K. BERT: Pre-training of deep bidirectional transformers for language understanding. In Proceedings of NAACL-HLT, pp.  4171–4186, 2019.
  • Fehrman et al. (2020) Fehrman, B., Gess, B., and Jentzen, A. Convergence rates for the stochastic gradient descent method for non-convex objective functions. Journal of Machine Learning Research, 21:136, 2020.
  • Garnelo et al. (2018) Garnelo, M., Rosenbaum, D., Maddison, C., Ramalho, T., Saxton, D., Shanahan, M., Teh, Y. W., Rezende, D., and Eslami, S. A. Conditional neural processes. In International Conference on Machine Learning, pp. 1704–1713. PMLR, 2018.
  • He et al. (2016) He, K., Zhang, X., Ren, S., and Sun, J. Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition, pp.  770–778, 2016.
  • Ilse et al. (2018) Ilse, M., Tomczak, J., and Welling, M. Attention-based deep multiple instance learning. In International conference on machine learning, pp. 2127–2136. PMLR, 2018.
  • Jurafsky & Martin (2008) Jurafsky, D. and Martin, J. H. Speech and language processing: An introduction to speech recognition, computational linguistics and natural language processing. Upper Saddle River, NJ: Prentice Hall, 2008.
  • Kawaguchi et al. (2022) Kawaguchi, K., Zhang, L., and Deng, Z. Understanding dynamics of nonlinear representation learning and its application. Neural Computation, 34(4):991–1018, 2022.
  • Kim (2022) Kim, M. Differentiable expectation-maximization for set representation learning. In International Conference on Learning Representations, 2022.
  • Kingma & Ba (2015) Kingma, D. P. and Ba, J. Adam: A method for stochastic optimization. In International Conference on Learning Representations, 2015.
  • Kingma & Welling (2013) Kingma, D. P. and Welling, M. Auto-encoding variational bayes. arXiv preprint arXiv:1312.6114, 2013.
  • Lee et al. (2019) Lee, J., Lee, Y., Kim, J., Kosiorek, A., Choi, S., and Teh, Y. W. Set transformer: A framework for attention-based permutation-invariant neural networks. In International conference on machine learning, pp. 3744–3753. PMLR, 2019.
  • Lee et al. (2016) Lee, J. D., Simchowitz, M., Jordan, M. I., and Recht, B. Gradient descent only converges to minimizers. In Conference on learning theory, pp.  1246–1257. PMLR, 2016.
  • Li et al. (2021) Li, B., Li, Y., and Eliceiri, K. W. Dual-stream multiple instance learning network for whole slide image classification with self-supervised contrastive learning. In Proceedings of the IEEE/CVF conference on computer vision and pattern recognition, pp.  14318–14328, 2021.
  • Liu et al. (2015) Liu, Z., Luo, P., Wang, X., and Tang, X. Deep learning face attributes in the wild. In Proceedings of International Conference on Computer Vision (ICCV), December 2015.
  • Locatello et al. (2020) Locatello, F., Weissenborn, D., Unterthiner, T., Mahendran, A., Heigold, G., Uszkoreit, J., Dosovitskiy, A., and Kipf, T. Object-centric learning with slot attention. Advances in Neural Information Processing Systems, 33:11525–11538, 2020.
  • Loshchilov & Hutter (2019) Loshchilov, I. and Hutter, F. Decoupled weight decay regularization. In International Conference on Learning Representations, 2019.
  • Mertikopoulos et al. (2020) Mertikopoulos, P., Hallak, N., Kavis, A., and Cevher, V. On the almost sure convergence of stochastic gradient descent in non-convex problems. Advances in Neural Information Processing Systems, 33:1117–1128, 2020.
  • Mialon et al. (2021) Mialon, G., Chen, D., d’Aspremont, A., and Mairal, J. A trainable optimal transport embedding for feature aggregation and its relationship to attention. In International Conference on Learning Representations, 2021.
  • Murphy et al. (2019) Murphy, R. L., Srinivasan, B., Rao, V., and Ribeiro, B. Janossy pooling: Learning deep permutation-invariant functions for variable-size inputs. In International Conference on Learning Representations, 2019.
  • Pappagari et al. (2019) Pappagari, R., Zelasko, P., Villalba, J., Carmiel, Y., and Dehak, N. Hierarchical transformers for long document classification. In 2019 IEEE Automatic Speech Recognition and Understanding Workshop (ASRU), pp.  838–844. IEEE, 2019.
  • Park et al. (2022) Park, H., Vyas, Y., and Shah, K. Efficient classification of long documents using transformers. In Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 2: Short Papers), pp.  702–709, 2022.
  • Pinkus (1999) Pinkus, A. Approximation theory of the MLP model in neural networks. Acta numerica, 8:143–195, 1999.
  • Quellec et al. (2017) Quellec, G., Cazuguel, G., Cochener, B., and Lamard, M. Multiple-instance learning for medical image and video analysis. IEEE reviews in biomedical engineering, 10:213–234, 2017.
  • Rolnick & Tegmark (2018) Rolnick, D. and Tegmark, M. The power of deeper networks for expressing natural functions. In International Conference on Learning Representations, 2018.
  • Vaswani et al. (2017) Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, Ł., and Polosukhin, I. Attention is all you need. Advances in neural information processing systems, 30, 2017.
  • Wagstaff et al. (2022) Wagstaff, E., Fuchs, F. B., Engelcke, M., Osborne, M. A., and Posner, I. Universal approximation of functions on sets. Journal of Machine Learning Research, 23(151):1–56, 2022.
  • Zaheer et al. (2017) Zaheer, M., Kottur, S., Ravanbakhsh, S., Poczos, B., Salakhutdinov, R. R., and Smola, A. J. Deep sets. Advances in neural information processing systems, 30, 2017.
  • Zhang et al. (2020) Zhang, Y., Hare, J., and Prügel-Bennett, A. FSPool: Learning set representations with featurewise sort pooling. In International Conference on Learning Representations, 2020.

Appendix A Proofs

A.1 Proof of Theorem 3.3

Proof.

Let $\mathfrak{S}_{N}$ be the set of all permutations on $\{1,\ldots,N\}$ and let $\pi\in\mathfrak{S}_{N}$ be given. For a given set $X=[{\mathbf{x}}_{1}\cdots{\mathbf{x}}_{N}]^{\top}\in\mathbb{R}^{N\times d_{x}}$, we can construct a permutation matrix $P\in\mathbb{R}^{N\times N}$ such that

PX=[𝐱π(1)𝐱π(N)].PX=\begin{bmatrix}-{\mathbf{x}}_{\pi(1)}^{\top}-\\ \vdots\\ -{\mathbf{x}}_{\pi(N)}^{\top}-\end{bmatrix}.

Since we apply the feature extractor ϕ\phi independently to each element of the set XX,

Φ(PX)=[ϕ(𝐱π(1))ϕ(𝐱π(N))]=PΦ(X).\Phi(PX)=\begin{bmatrix}-\phi({\mathbf{x}}_{\pi(1)})^{\top}-\\ \vdots\\ -\phi({\mathbf{x}}_{\pi(N)})^{\top}-\end{bmatrix}=P\Phi(X).

With an elementwise, strictly positive activation function σ\sigma,

σ(d1Q(Φ(PX)WK))\displaystyle\sigma(\sqrt{d^{-1}}\cdot Q(\Phi(PX)W^{K})^{\top}) =σ(d1Q(PΦ(X)WK))\displaystyle=\sigma(\sqrt{d^{-1}}\cdot Q(P\Phi(X)W^{K})^{\top})
=σ(d1Q(Φ(X)WK)P)\displaystyle=\sigma(\sqrt{d^{-1}}\cdot Q(\Phi(X)W^{K})^{\top}P^{\top})
=σ(d1QK(X))P\displaystyle=\sigma(\sqrt{d^{-1}}\cdot QK(X)^{\top})P^{\top}
=A^P.\displaystyle=\hat{A}P^{\top}.

For $p=1$, applying $\nu_{1}$ to the permuted un-normalized attention score $\hat{A}P^{\top}$ gives

ν1(A^P)=[A^1,π(1)/i=1kA^i,π(1)A^1,π(N)/i=1kA^i,π(N)A^k,π(1)/i=1kA^i,π(1)A^k,π(N)/i=1kA^i,π(N)]=ν1(A^)P.\nu_{1}(\hat{A}P^{\top})=\begin{bmatrix}\hat{A}_{1,\pi(1)}/\sum_{i=1}^{k}\hat{A}_{i,\pi(1)}&\cdots&\hat{A}_{1,\pi(N)}/\sum_{i=1}^{k}\hat{A}_{i,\pi(N)}\\ \vdots&\ddots&\vdots\\ \hat{A}_{k,\pi(1)}/\sum_{i=1}^{k}\hat{A}_{i,\pi(1)}&\cdots&\hat{A}_{k,\pi(N)}/\sum_{i=1}^{k}\hat{A}_{i,\pi(N)}\end{bmatrix}=\nu_{1}(\hat{A})P^{\top}.

Since ν2\nu_{2} is the identity mapping, νp(A^P)=νp(A^)P\nu_{p}(\hat{A}P^{\top})=\nu_{p}(\hat{A})P^{\top} for p=1,2p=1,2.

Now, we consider the matrix multiplication of

νp(A^)P=[νp(A^)1,π(1)νp(A^)1,π(N)νp(A^)k,π(1)νp(A^)k,π(N)] and PΦ(X)WV=[ϕ(𝐱π(1))WVϕ(𝐱π(N))WV].\nu_{p}(\hat{A})P^{\top}=\begin{bmatrix}\nu_{p}(\hat{A})_{1,\pi(1)}&\cdots&\nu_{p}(\hat{A})_{1,\pi(N)}\\ \vdots&\ddots&\vdots\\ \nu_{p}(\hat{A})_{k,\pi(1)}&\cdots&\nu_{p}(\hat{A})_{k,\pi(N)}\end{bmatrix}\text{ and }P\Phi(X)W^{V}=\begin{bmatrix}\phi({\mathbf{x}}_{\pi(1)})^{\top}W^{V}\\ \vdots\\ \phi({\mathbf{x}}_{\pi(N)})^{\top}W^{V}\end{bmatrix}.

Since

\displaystyle\begin{aligned}
\begin{bmatrix}\nu_{p}(\hat{A})_{1,\pi(1)}&\cdots&\nu_{p}(\hat{A})_{1,\pi(N)}\\ \vdots&\ddots&\vdots\\ \nu_{p}(\hat{A})_{k,\pi(1)}&\cdots&\nu_{p}(\hat{A})_{k,\pi(N)}\end{bmatrix}\begin{bmatrix}\phi({\mathbf{x}}_{\pi(1)})^{\top}W^{V}\\ \vdots\\ \phi({\mathbf{x}}_{\pi(N)})^{\top}W^{V}\end{bmatrix}
&=\begin{bmatrix}\sum_{j=1}^{N}\nu_{p}(\hat{A})_{1,\pi(j)}\phi({\mathbf{x}}_{\pi(j)})^{\top}W^{V}\\ \vdots\\ \sum_{j=1}^{N}\nu_{p}(\hat{A})_{k,\pi(j)}\phi({\mathbf{x}}_{\pi(j)})^{\top}W^{V}\end{bmatrix}\\
&=\begin{bmatrix}\sum_{j=1}^{N}\nu_{p}(\hat{A})_{1,j}\phi({\mathbf{x}}_{j})^{\top}W^{V}\\ \vdots\\ \sum_{j=1}^{N}\nu_{p}(\hat{A})_{k,j}\phi({\mathbf{x}}_{j})^{\top}W^{V}\end{bmatrix}\\
&=\nu_{p}(\hat{A})V(X),
\end{aligned}

f^θ\hat{f}_{\theta} is permutation invariant. Since f¯θ(X)i=j=1Nνp(A^)i,j=j=1Nνp(A^)i,π(j){\bar{f}}_{\theta}(X)_{i}=\sum_{j=1}^{N}\nu_{p}(\hat{A})_{i,j}=\sum_{j=1}^{N}\nu_{p}(\hat{A})_{i,\pi(j)}, diag(f¯θ(X))1\text{diag}\left({\bar{f}}_{\theta}(X)\right)^{-1} is also invariant with respect to the permutation of input XX, which leads to the conclusion that fθ(PX)=diag(f¯θ(PX))1f^θ(PX)=diag(f¯θ(X))1f^θ(X).f_{\theta}(PX)=\text{diag}\left({\bar{f}}_{\theta}(PX)\right)^{-1}\hat{f}_{\theta}(PX)=\text{diag}({\bar{f}}_{\theta}(X))^{-1}\hat{f}_{\theta}(X).
fθ\therefore f_{\theta} is permutation invariant. ∎
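For concreteness, the permutation invariance argued above can also be checked numerically. The following minimal NumPy sketch implements the pooling layer $f_{\theta}$ with an elementwise tanh standing in for the feature extractor $\phi$, a sigmoid for $\sigma$, and $\nu_{1}$ normalization; all shapes and weights are toy assumptions rather than the released implementation.

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def umbc_pool(X, slots, Wk, Wv, Wq, phi=np.tanh):
    # f_theta(X) = diag(f_bar(X))^{-1} nu_1(A_hat) V(X)
    H = phi(X)                                        # Phi(X), applied elementwise
    K, V = H @ Wk, H @ Wv                             # keys and values, N x d
    Q = slots @ Wq                                    # queries from the slots, k x d
    A_hat = sigmoid(Q @ K.T / np.sqrt(Q.shape[1]))    # strictly positive scores, k x N
    A = A_hat / A_hat.sum(axis=0, keepdims=True)      # nu_1: normalize over the k slots
    f_hat = A @ V                                     # k x d
    f_bar = A.sum(axis=1, keepdims=True)              # k x 1 normalizer
    return f_hat / f_bar

rng = np.random.default_rng(0)
N, dx, d, k = 12, 5, 4, 3
X = rng.normal(size=(N, dx))
slots = rng.normal(size=(k, d))
Wk, Wv, Wq = rng.normal(size=(dx, d)), rng.normal(size=(dx, d)), rng.normal(size=(d, d))

P = rng.permutation(N)                                # a random permutation pi
assert np.allclose(umbc_pool(X, slots, Wk, Wv, Wq),
                   umbc_pool(X[P], slots, Wk, Wv, Wq))  # f_theta(PX) = f_theta(X)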

A.2 Proof of Theorem 3.4

Proof.

Let an input set $X\in\mathbb{R}^{N\times d_{x}}$ be given and let $\zeta(X)=\{S_{1},\ldots,S_{l}\}$ be a partition of $X$ with $|S_{i}|=N_{i}$, i.e., $X=\bigcup_{i=1}^{l}S_{i}$ and $S_{i}\cap S_{j}=\emptyset$ for $i\neq j$. Since a UMBC set encoder is permutation invariant, we can assume without loss of generality that

K(X)=[K(X)1K(X)l],V(X)=[V(X)1V(X)l]\displaystyle K(X)=\begin{bmatrix}K(X)_{1}\\ \vdots\\ K(X)_{l}\end{bmatrix},\quad V(X)=\begin{bmatrix}V(X)_{1}\\ \vdots\\ V(X)_{l}\end{bmatrix} (21)

where K(X)i=Φ(Si)WKNi×dK(X)_{i}=\Phi(S_{i})W^{K}\in\mathbb{R}^{N_{i}\times d} and V(X)i=Φ(Si)WVNi×dV(X)_{i}=\Phi(S_{i})W^{V}\in\mathbb{R}^{N_{i}\times d} for i=1,li=1\ldots,l. Then we can express the matrix νp(A^)\nu_{p}(\hat{A}) as follows:

νp(A^)=[νp(A^(1))νp(A^(l))],\displaystyle\nu_{p}(\hat{A})=\begin{bmatrix}\nu_{p}(\hat{A}^{(1)})\cdots\nu_{p}(\hat{A}^{(l)})\end{bmatrix}, (22)

where $\hat{A}^{(i)}=\sigma(\sqrt{d^{-1}}\cdot QK(X)^{\top}_{i})\in\mathbb{R}^{k\times N_{i}}$ for $i=1,\ldots,l$, since $\nu_{p}(\hat{A})_{i,j}$ is independent of $\nu_{p}(\hat{A})_{i,q}$ for all $q\neq j$.

Since

f^θ(X)=[νp(A^(1))νp(A^(l))][V(X)1V(X)l],\hat{f}_{\theta}(X)=\begin{bmatrix}\nu_{p}(\hat{A}^{(1)})\cdots\nu_{p}(\hat{A}^{(l)})\end{bmatrix}\begin{bmatrix}V(X)_{1}\\ \vdots\\ V(X)_{l}\end{bmatrix},

the following equality holds

f^θ(X)=[νp(A^(1))νp(A^(l))][V(X)1V(X)l]=i=1lνp(A^(i))V(X)i=i=1lf^θ(Si).\displaystyle\begin{split}\hat{f}_{\theta}(X)&=\begin{bmatrix}\nu_{p}(\hat{A}^{(1)})\cdots\nu_{p}(\hat{A}^{(l)})\end{bmatrix}\begin{bmatrix}V(X)_{1}\\ \vdots\\ V(X)_{l}\end{bmatrix}\\ &=\sum_{i=1}^{l}\nu_{p}(\hat{A}^{(i)})V(X)_{i}\\ &=\sum_{i=1}^{l}\hat{f}_{\theta}(S_{i}).\end{split} (23)

Thus, f^θ(X){\hat{f}}_{\theta}(X) is mini-batch consistent.

Since

{\bar{f}}_{\theta}(X)_{i}=\sum_{q=1}^{l}\sum_{j=1}^{N_{q}}\nu_{p}(\hat{A}^{(q)})_{i,j},

we can decompose f¯θ(X){\bar{f}}_{\theta}(X) into a summation of f¯θ(Si){\bar{f}}_{\theta}(S_{i}) as

f¯θ(X)=i=1l(j=1Niνp(A^(i))1,j,,j=1Niνp(A^(i))k,j)=i=1lνp(A^(i))𝟏Ni=i=1lf¯θ(Si),\displaystyle\begin{split}{\bar{f}}_{\theta}(X)&=\sum_{i=1}^{l}\left(\sum_{j=1}^{N_{i}}\nu_{p}(\hat{A}^{(i)})_{1,j},\ldots,\sum_{j=1}^{N_{i}}\nu_{p}(\hat{A}^{(i)})_{k,j}\right)^{\top}\\ &=\sum_{i=1}^{l}\nu_{p}(\hat{A}^{(i)})\bm{1}_{N_{i}}\\ &=\sum_{i=1}^{l}{\bar{f}}_{\theta}(S_{i}),\end{split} (24)

where $\bm{1}_{N_{i}}=(1,\ldots,1)\in\mathbb{R}^{N_{i}}$. This implies that ${\bar{f}}_{\theta}(X)$ is mini-batch consistent (MBC). Combining equation 23 and equation 24, we obtain

fθ(X)=diag(i=1lf¯θ(Si))1(i=1lf^θ(Si)).f_{\theta}(X)=\text{diag}\left(\sum_{i=1}^{l}{\bar{f}}_{\theta}(S_{i})\right)^{-1}\left(\sum_{i=1}^{l}{\hat{f}}_{\theta}(S_{i})\right).

Now, we define a function,

h({fθ(S1),,fθ(Sl)})h1({f¯θ(S1),,f¯θ(Sl)})h2({f^θ(S1),,f^θ(Sl)}),\displaystyle h(\{f_{\theta}(S_{1}),\ldots,f_{\theta}(S_{l})\})\coloneqq h_{1}(\{{\bar{f}}_{\theta}(S_{1}),\ldots,{\bar{f}}_{\theta}(S_{l})\})\cdot h_{2}(\{{\hat{f}}_{\theta}(S_{1}),\ldots,{\hat{f}}_{\theta}(S_{l})\}),

where

\displaystyle h_{1}(\{{\bar{f}}_{\theta}(S_{1}),\ldots,{\bar{f}}_{\theta}(S_{l})\})\coloneqq\text{diag}\left(\sum_{i=1}^{l}{\bar{f}}_{\theta}(S_{i})\right)^{-1}
h2({f^θ(S1),,f^θ(Sl)})i=1lf^θ(Si).\displaystyle h_{2}(\{{\hat{f}}_{\theta}(S_{1}),\ldots,{\hat{f}}_{\theta}(S_{l})\})\coloneqq\sum_{i=1}^{l}{\hat{f}}_{\theta}(S_{i}).

Then fθ(X)=h({fθ(S1),,fθ(Sl)})f_{\theta}(X)=h(\{f_{\theta}(S_{1}),\ldots,f_{\theta}(S_{l})\}). Since ζ(X)\zeta(X) is arbitrary, fθf_{\theta} is mini-batch consistent. ∎
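The decomposition in equations 23 and 24 is what enables streaming in practice: one accumulates $\sum_{i}{\hat{f}}_{\theta}(S_{i})$ and $\sum_{i}{\bar{f}}_{\theta}(S_{i})$ over chunks and normalizes once at the end. Below is a self-contained NumPy sketch of this accumulation using the same toy pooling as in the previous sketch; the feature map, activation, and shapes are illustrative assumptions.

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def chunk_stats(S, slots, Wk, Wv, Wq, phi=np.tanh):
    # returns (f_hat(S), f_bar(S)) for one chunk S of the partition
    H = phi(S)
    K, V = H @ Wk, H @ Wv
    Q = slots @ Wq
    A_hat = sigmoid(Q @ K.T / np.sqrt(Q.shape[1]))
    A = A_hat / A_hat.sum(axis=0, keepdims=True)       # nu_1, columnwise
    return A @ V, A.sum(axis=1, keepdims=True)

def umbc_stream(chunks, slots, Wk, Wv, Wq):
    f_hat, f_bar = 0.0, 0.0
    for S in chunks:                                    # any partition zeta(X)
        fh, fb = chunk_stats(S, slots, Wk, Wv, Wq)
        f_hat, f_bar = f_hat + fh, f_bar + fb           # equations 23 and 24
    return f_hat / f_bar                                # diag(f_bar)^{-1} f_hat

rng = np.random.default_rng(0)
N, dx, d, k = 12, 5, 4, 3
X = rng.normal(size=(N, dx))
slots = rng.normal(size=(k, d))
Wk, Wv, Wq = rng.normal(size=(dx, d)), rng.normal(size=(dx, d)), rng.normal(size=(d, d))

full = umbc_stream([X], slots, Wk, Wv, Wq)              # single-chunk "partition"
for split in ([4, 8], [1, 5, 6], [3, 3, 3, 3]):
    chunks = np.split(X, np.cumsum(split)[:-1])
    assert np.allclose(full, umbc_stream(chunks, slots, Wk, Wv, Wq))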

A.3 Proof of Corollary 3.8

Proof.

Let g^ω:k×d𝒵\hat{g}_{\omega}:\mathbb{R}^{k\times d}\to\mathcal{Z} be an arbitrary set encoder and let fθ:𝒳k×df_{\theta}:\mathcal{X}\to\mathbb{R}^{k\times d} be a UMBC set encoder. Given a set X𝒳X\in\mathcal{X} and a partition ζ(X)\zeta(X), we get

fθ(X)=diag(Sζ(X)f¯θ(S))1(Sζ(X)f^θ(S))k×d,f_{\theta}(X)=\text{diag}\left(\sum_{S\in\zeta(X)}{\bar{f}}_{\theta}(S)\right)^{-1}\left(\sum_{S\in\zeta(X)}{\hat{f}}_{\theta}(S)\right)\in\mathbb{R}^{k\times d},

as shown in section A.2. We assume that $k$ is small enough that $f_{\theta}(X)$ fits in memory once it has been computed. As a consequence, we can directly evaluate $\hat{g}_{\omega}(f_{\theta}(X))$ without partitioning $f_{\theta}(X)$ into smaller subsets $\{S_{1},\ldots,S_{l}\}$ and aggregating $\hat{g}_{\omega}(S_{i})$.
g^ωfθ\therefore\hat{g}_{\omega}\circ f_{\theta} is mini-batch consistent. ∎
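In other words, once the streamed UMBC stage has produced the fixed-size matrix $Z=f_{\theta}(X)\in\mathbb{R}^{k\times d}$, any non-MBC set encoder can be applied to $Z$ in a single pass. The following small sketch illustrates the composition; the softmax self-attention head is only a stand-in for an arbitrary non-MBC module such as a Set Transformer block, and $Z$ is simulated by a random matrix here rather than an actual streamed output.

import numpy as np

rng = np.random.default_rng(0)
k, d = 3, 4
Z = rng.normal(size=(k, d))         # assume Z = f_theta(X), accumulated over any partition

def g_non_mbc(Z, Wq, Wk, Wv):
    # softmax self-attention over the k slot outputs followed by mean pooling;
    # applied to raw set elements this would violate MBC, but it only ever sees the k x d matrix Z
    q, key, v = Z @ Wq, Z @ Wk, Z @ Wv
    scores = q @ key.T / np.sqrt(d)
    att = np.exp(scores - scores.max(axis=1, keepdims=True))
    att = att / att.sum(axis=1, keepdims=True)
    return (att @ v).mean(axis=0)

Wq, Wk, Wv = (rng.normal(size=(d, d)) for _ in range(3))
out = g_non_mbc(Z, Wq, Wk, Wv)      # g_omega(f_theta(X)): identical for every partition of X
print(out.shape)                    # (4,)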

A.4 Proof of Theorem 3.6

Proof.

Let $\mathfrak{S}_{k}$ be the set of all permutations on $\{1,\ldots,k\}$ and let $\pi\in\mathfrak{S}_{k}$ be given. Define a matrix

Z=[𝐳1𝐳k]φ(Σ;X,θ)k×dZ=\begin{bmatrix}-{\mathbf{z}}^{\top}_{1}-\\ \vdots\\ -{\mathbf{z}}^{\top}_{k}-\end{bmatrix}\coloneqq\varphi(\Sigma;X,\theta)\in\mathbb{R}^{k\times d}

with the input set XN×dxX\in\mathbb{R}^{N\times d_{x}} and the given slots Σ=[𝐬1𝐬k]k×ds\Sigma=[{\mathbf{s}}_{1}\cdots{\mathbf{s}}_{k}]^{\top}\in\mathbb{R}^{k\times d_{s}}. Then we can identify a permutation matrix Pk×kP\in\mathbb{R}^{k\times k} such that,

PΣ=[𝐬π(1)𝐬π(k)].P\Sigma=\begin{bmatrix}-{\mathbf{s}}^{\top}_{\pi(1)}-\\ \vdots\\ -{\mathbf{s}}^{\top}_{\pi(k)}-\end{bmatrix}.

Since the query QQ with the permutation PP is PΣWQ=PQP\Sigma W^{Q}=PQ, the un-normalized attention score with the permutation matrix PP is

σ(d1PQK(X))=Pσ(d1QK(X))=PA^.\sigma(\sqrt{d^{-1}}\cdot\ PQK(X)^{\top})=P\sigma(\sqrt{d^{-1}}\cdot\ QK(X)^{\top})=P\hat{A}.

Since the normalization matrix f¯θ(X){\bar{f}}_{\theta}(X) is a function of the slots Σ\Sigma, we define a new normalization matrix by permuting the slots Σ\Sigma with the given permutation matrix PP as

diag(f¯θ(X);PΣ)1=[1c11c21ck]\text{diag}\left({\bar{f}}_{\theta}(X);P\Sigma\right)^{-1}=\begin{bmatrix}\frac{1}{{c}_{1}}&&&\\ &\frac{1}{{c}_{2}}&&\\ &&\ddots&\\ &&&\frac{1}{{c}_{k}}\end{bmatrix}

where ci=j=1Nνp(PA^)i,j{c}_{i}=\sum_{j=1}^{N}\nu_{p}(P\hat{A})_{i,j} for i=1,,ki=1,\ldots,k. Note that

ν1(PA^)\displaystyle\nu_{1}(P\hat{A}) =[A^π(1),1/i=1kA^π(i),1A^π(1),N/i=1kA^π(i),NA^π(k),1/i=1kA^π(i),1A^π(k),N/i=1kA^π(i),N]\displaystyle=\begin{bmatrix}\hat{A}_{\pi(1),1}/\sum_{i=1}^{k}\hat{A}_{\pi(i),1}&\cdots&\hat{A}_{\pi(1),N}/\sum_{i=1}^{k}\hat{A}_{\pi(i),N}\\ \vdots&\ddots&\vdots\\ \hat{A}_{\pi(k),1}/\sum_{i=1}^{k}\hat{A}_{\pi(i),1}&\cdots&\hat{A}_{\pi(k),N}/\sum_{i=1}^{k}\hat{A}_{\pi(i),N}\end{bmatrix}
=PA^diag(i=1kA^i,1,,i=1kA^i,N)1\displaystyle=P\hat{A}\cdot\text{diag}\left(\sum_{i=1}^{k}\hat{A}_{i,1},\ldots,\sum_{i=1}^{k}\hat{A}_{i,N}\right)^{-1}
=Pν1(A^)\displaystyle=P\nu_{1}(\hat{A})

and ν2(PA^)=Pν2(A^)\nu_{2}(P\hat{A})=P\nu_{2}(\hat{A}).

Then we get ci=f¯θ(X)π(i){c}_{i}={\bar{f}}_{\theta}(X)_{\pi(i)} since

ci\displaystyle{c}_{i} =j=1Nνp(PA^)i,j\displaystyle=\sum_{j=1}^{N}\nu_{p}(P\hat{A})_{i,j}
=j=1N(Pνp(A^))i,j\displaystyle=\sum_{j=1}^{N}(P\nu_{p}(\hat{A}))_{i,j}
=j=1Nl=1kPi,lνp(A^)l,j\displaystyle=\sum_{j=1}^{N}\sum_{l=1}^{k}P_{i,l}\nu_{p}(\hat{A})_{l,j}
=l=1kPi,l(j=1Nνp(A^)l,j)\displaystyle=\sum_{l=1}^{k}P_{i,l}\left(\sum_{j=1}^{N}\nu_{p}(\hat{A})_{l,j}\right)
=l=1kPi,lf¯θ(X)l\displaystyle=\sum_{l=1}^{k}P_{i,l}{\bar{f}}_{\theta}(X)_{l}
=f¯θ(X)π(i).\displaystyle={\bar{f}}_{\theta}(X)_{\pi(i)}.

The last equality holds since the $i$-th row of the permutation matrix $P$ has a single non-zero entry, which is 1. Thus,

\begin{bmatrix}{c}_{1}&&&\\ &{c}_{2}&&\\ &&\ddots&\\ &&&{c}_{k}\end{bmatrix}=\begin{bmatrix}{{\bar{f}}_{\theta}(X)}_{\pi(1)}&&&\\ &{{\bar{f}}_{\theta}(X)}_{\pi(2)}&&\\ &&\ddots&\\ &&&{{\bar{f}}_{\theta}(X)}_{\pi(k)}\end{bmatrix}=P\,\text{diag}({\bar{f}}_{\theta}(X))\,P^{\top},

which implies that

diag(f¯θ(X);PΣ)1\displaystyle\text{diag}\left({\bar{f}}_{\theta}(X);P\Sigma\right)^{-1} =[1c11c21ck]\displaystyle=\begin{bmatrix}{\frac{1}{{c}_{1}}}&&&\\ &\frac{1}{{{c}_{2}}}&&\\ &&\ddots&\\ &&&\frac{1}{{{c}_{k}}}\end{bmatrix}
=[f¯θ(X)π(1)f¯θ(X)π(2)f¯θ(X)π(k)]1.\displaystyle=\begin{bmatrix}{{\bar{f}}_{\theta}(X)}_{\pi(1)}&&&\\ &{{\bar{f}}_{\theta}(X)}_{\pi(2)}&&\\ &&\ddots&\\ &&&{{\bar{f}}_{\theta}(X)}_{\pi(k)}\end{bmatrix}^{-1}.

Finally, combining all the pieces, we get

diag(f¯θ(X);PΣ)1νp(PA^)V(X)\displaystyle\text{diag}\left({\bar{f}}_{\theta}(X);P\Sigma\right)^{-1}\nu_{p}(P\hat{A})V(X)
=diag(f¯θ(X);PΣ)1Pνp(A^)V(X)\displaystyle=\text{diag}\left({\bar{f}}_{\theta}(X);P\Sigma\right)^{-1}P\nu_{p}(\hat{A})V(X)
=[f¯θ(X)π(1)f¯θ(X)π(2)f¯θ(X)π(k)]1[(νp(A^)V(X))π(1),1(νp(A^)V(X))π(1),d(νp(A^)V(X))π(k),1(νp(A^)V(X))π(k),d]\displaystyle=\begin{bmatrix}{{\bar{f}}_{\theta}(X)}_{\pi(1)}&&&\\ &{{\bar{f}}_{\theta}(X)}_{\pi(2)}&&\\ &&\ddots&\\ &&&{{\bar{f}}_{\theta}(X)}_{\pi(k)}\end{bmatrix}^{-1}\begin{bmatrix}(\nu_{p}(\hat{A})V(X))_{\pi(1),1}&\cdots&(\nu_{p}(\hat{A})V(X))_{\pi(1),d}\\ \vdots&\ddots&\vdots\\ (\nu_{p}(\hat{A})V(X))_{\pi(k),1}&\cdots&(\nu_{p}(\hat{A})V(X))_{\pi(k),d}\end{bmatrix}
=[(νp(A^)V(X))π(1),1/f¯θ(X)π(1)(νp(A^)V(X))π(1),d/f¯θ(X)π(1)(νp(A^)V(X))π(k),1/f¯θ(X)π(k)(νp(A^)V(X))π(k),d/f¯θ(X)π(k)]\displaystyle=\begin{bmatrix}(\nu_{p}(\hat{A})V(X))_{\pi(1),1}/{\bar{f}}_{\theta}(X)_{\pi(1)}&\cdots&(\nu_{p}(\hat{A})V(X))_{\pi(1),d}/{\bar{f}}_{\theta}(X)_{\pi(1)}\\ \vdots&\ddots&\vdots\\ (\nu_{p}(\hat{A})V(X))_{\pi(k),1}/{\bar{f}}_{\theta}(X)_{\pi(k)}&\cdots&(\nu_{p}(\hat{A})V(X))_{\pi(k),d}/{\bar{f}}_{\theta}(X)_{\pi(k)}\end{bmatrix}
=P[f¯θ(X)1f¯θ(X)2f¯θ(X)k]1νp(A^)V(X)\displaystyle=P\begin{bmatrix}{{\bar{f}}_{\theta}(X)}_{1}&&&\\ &{{\bar{f}}_{\theta}(X)}_{2}&&\\ &&\ddots&\\ &&&{{\bar{f}}_{\theta}(X)}_{k}\end{bmatrix}^{-1}\nu_{p}(\hat{A})V(X)
=Pdiag(f¯θ(X))1f^θ(X)\displaystyle=P\text{diag}\left({\bar{f}}_{\theta}(X)\right)^{-1}{\hat{f}}_{\theta}(X)
=Pφ(Σ;X,θ)\displaystyle=P\varphi(\Sigma;X,\theta)
=PZ\displaystyle=PZ
=[𝐳π(1)𝐳π(k)].\displaystyle=\begin{bmatrix}-{\mathbf{z}}^{\top}_{\pi(1)}-\\ \vdots\\ -{\mathbf{z}}^{\top}_{\pi(k)}-\end{bmatrix}.

φ(Σ;X,θ)\therefore\varphi(\Sigma;X,\theta) is permutation equivariant. ∎
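The equivariance can again be checked numerically: permuting the slot matrix $\Sigma$ simply permutes the rows of the output. A minimal NumPy sketch follows, using the same toy pooling as before; the feature map, activation, and weight shapes are illustrative assumptions.

import numpy as np

rng = np.random.default_rng(0)
N, dx, d, k = 10, 5, 4, 3
X = rng.normal(size=(N, dx))
Sigma = rng.normal(size=(k, d))                            # slots
Wk, Wv, Wq = rng.normal(size=(dx, d)), rng.normal(size=(dx, d)), rng.normal(size=(d, d))

def pool(X, slots):
    H = np.tanh(X)                                         # stand-in for Phi(X)
    A = 1.0 / (1.0 + np.exp(-(slots @ Wq) @ (H @ Wk).T / np.sqrt(d)))  # sigmoid scores, k x N
    A = A / A.sum(axis=0, keepdims=True)                   # nu_1 over the slots
    return (A @ (H @ Wv)) / A.sum(axis=1, keepdims=True)   # diag(f_bar)^{-1} f_hat

perm = rng.permutation(k)                                  # permutation pi of the slots
assert np.allclose(pool(X, Sigma[perm]), pool(X, Sigma)[perm])  # varphi(P Sigma; X) = P varphi(Sigma; X)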

A.5 Proof of Theorem 3.7

Proof.

Following the previous proofs (Zaheer et al., 2017; Wagstaff et al., 2022) for uncountable set 𝔛\mathfrak{X}, we assume a set size is fixed. In other words, we restrict the domain 𝒳2𝔛\mathcal{X}\subset 2^{\mathfrak{X}} to [0,1]M[0,1]^{M}. Let ff\in\mathcal{F} be given. By using the proof of Theorem 13 from Wagstaff et al. (2022) (with a more detailed proof in Zaheer et al., 2017), the function ff is continuously sum-decomposable via M+1\mathbb{R}^{M+1} as:

f(X)=ρ(Ψ(X))f(X)=\rho\left(\Psi(X)\right)\\

for all X={x1,,xM}[0,1]MX=\{x_{1},\ldots,x_{M}\}\in[0,1]^{M}, where Ψ:[0,1]M𝒬M+1\Psi:[0,1]^{M}\to\mathcal{Q}\subseteq\mathbb{R}^{M+1} is invertible and defined by

Ψ(X)=(xXψ0(x),,xXψM(x)),ψq(x)=xq,for q=0,,M,\displaystyle\Psi(X)=\left(\sum_{x\in X}\psi_{0}(x),\ldots,\sum_{x\in X}\psi_{M}(x)\right),\quad\psi_{q}(x)=x^{q},\quad\text{for $q=0,\ldots,M$},

and ρ:M+1\rho:\mathbb{R}^{M+1}\to\mathbb{R} is continuous and defined by

ρ=fΨ1.\rho=f\circ\Psi^{-1}.

We want to show that the composition of UMBC with a continuously sum-decomposable permutation-invariant deep neural network can approximate the function $f$, by showing that $\rho$ and $\Psi$ are approximated by components of the UMBC model. Let $h:\mathbb{R}^{k\times(M+1)}\to\mathbb{R}$ be a continuously sum-decomposable permutation-invariant deep neural network defined by

h(Z)=κ(𝐳Zξ(𝐳)),h(Z)=\kappa\left(\sum_{{\mathbf{z}}\in Z}\xi({\mathbf{z}})\right),

where Z={𝐳iM+1}i=1kZ=\{{\mathbf{z}}_{i}\in\mathbb{R}^{M+1}\}_{i=1}^{k}, κ:M+1\kappa:\mathbb{R}^{M+1}\to\mathbb{R} is a deep neural network, and ξ:M+1M+1\xi:\mathbb{R}^{M+1}\to\mathbb{R}^{M+1} is defined by

ξ(𝐳)=1kξ^(M𝐳)\xi({\mathbf{z}})=\frac{1}{k}\cdot\hat{\xi}(M\cdot{\mathbf{z}})

with some deep neural network ξ^:M+1M+1\hat{\xi}:\mathbb{R}^{M+1}\to\mathbb{R}^{M+1}. First, we want to show that Deepsets with average pooling is a special case of UMBC. Set the slots Σ\Sigma as the zero matrix 𝟎k×ds,d=dh=M+1\mathbf{0}\in\mathbb{R}^{k\times d_{s}},d=d_{h}=M+1, and WV=IM+1W^{V}=I_{M+1}. Then by using Lemma 1 from Lee et al. (2019), fθf_{\theta} becomes average pooling, i.e.,

fθ(X)=[1Mi=1Mϕ(xi)1Mi=1Mϕ(xi)]k×(M+1).f_{\theta}(X)=\begin{bmatrix}\frac{1}{M}\sum_{i=1}^{M}\phi(x_{i})\\ \vdots\\ \frac{1}{M}\sum_{i=1}^{M}\phi(x_{i})\end{bmatrix}\in\mathbb{R}^{k\times(M+1)}.

Then the composition of UMBC and $h$ becomes a continuously sum-decomposable function as follows:

(hfθ)(X)\displaystyle(h\circ f_{\theta})(X) =κ(j=1kξ(fθ(X)j))\displaystyle=\kappa\left(\sum_{j=1}^{k}\xi(f_{\theta}(X)_{j})\right)
=κ(j=1k1kξ^(MMi=1Mϕ(xi)))\displaystyle=\kappa\left(\sum_{j=1}^{k}\frac{1}{k}\hat{\xi}\left(\frac{M}{M}\sum_{i=1}^{M}\phi(x_{i})\right)\right)
=(κξ^)(i=1Mϕ(xi))\displaystyle=(\kappa\circ\hat{\xi})\left(\sum_{i=1}^{M}\phi(x_{i})\right)

where fθ(X)jf_{\theta}(X)_{j} is jj-th row of fθ(X)f_{\theta}(X). By defining ρ^κξ^\hat{\rho}\coloneqq\kappa\circ\hat{\xi} and Ψ^(X)xXϕ(x)𝒬^M+1\hat{\Psi}(X)\coloneqq\sum_{x\in X}\phi(x)\in\hat{\mathcal{Q}}\subseteq\mathbb{R}^{M+1},

supXf(X)(hfθ)(X)2\displaystyle\sup_{X}\left\lVert f(X)-(h\circ f_{\theta})(X)\right\rVert_{2} =supX(ρΨ)(X)(ρ^Ψ^)(X)+(ρΨ^)(X)(ρΨ^)(X)2\displaystyle=\sup_{X}\left\lVert(\rho\circ\Psi)(X)-(\hat{\rho}\circ\hat{\Psi})(X)+(\rho\circ\hat{\Psi})(X)-(\rho\circ\hat{\Psi})(X)\right\rVert_{2}
supX(ρΨ)(X)(ρΨ^)(X)2+supX(ρ^Ψ^)(X)(ρΨ^)(X)2\displaystyle\leq\sup_{X}\left\lVert(\rho\circ\Psi)(X)-(\rho\circ\hat{\Psi})(X)\right\rVert_{2}+\sup_{X}\left\lVert(\hat{\rho}\circ\hat{\Psi})(X)-(\rho\circ\hat{\Psi})(X)\right\rVert_{2}
supXρ(Ψ^(X))ρ(Ψ(X))2+supz𝒬^ρ(z)ρ^(z)2.\displaystyle\leq\sup_{X}\left\lVert\rho(\hat{\Psi}(X))-\rho(\Psi(X))\right\rVert_{2}+\sup_{z\in\hat{\mathcal{Q}}}\left\lVert\rho(z)-\hat{\rho}(z)\right\rVert_{2}.

Since Ψ^:[0,1]M𝒬^M+1\hat{\Psi}:[0,1]^{M}\to\hat{\mathcal{Q}}\subseteq\mathbb{R}^{M+1} is continuous and [0,1]M[0,1]^{M} is compact, 𝒬^\hat{\mathcal{Q}} is compact. Since 𝒬^\hat{\mathcal{Q}} is compact and ρ\rho is continuous, and the nonlinearity of ρ^\hat{\rho} is not a polynomial of finite degree, Theorem 3.1 of (Pinkus, 1999) implies the following (as a network with one hidden layer can be approximated by a network of greater depth by using the same construction for the first layer and approximating the identity function with later layers): for any ϵ>0\epsilon^{\prime}>0, there exists τ1\tau^{\prime}\geq 1 and parameters of ρ^\hat{\rho} such that if the width of ρ^\hat{\rho} is at least τ\tau^{\prime}, then supz𝒬^ρ(z)ρ^(z)<ϵ\sup_{z\in\hat{\mathcal{Q}}}\|\rho(z)-\hat{\rho}(z)\|<\epsilon^{\prime}. Combining these, we have that

supXf(X)(hfθ)(X)2<supXρ(Ψ^(X))ρ(Ψ(X))2+ϵ(τ)\sup_{X}\left\lVert f(X)-(h\circ f_{\theta})(X)\right\rVert_{2}<\sup_{X}\left\lVert\rho(\hat{\Psi}(X))-\rho(\Psi(X))\right\rVert_{2}+\epsilon^{\prime}_{(\tau^{\prime})}

where ϵ(τ)\epsilon^{\prime}_{(\tau^{\prime})} depends on the width τ\tau^{\prime} of ρ^\hat{\rho}. Since the nonlinearity of ϕ\phi has nonzero Taylor coefficients up to degree MM, the proof of Theorem 3.4 of (Rolnick & Tegmark, 2018) implies the following: there exists τ1\tau\geq 1 such that if the width of ϕ=(ϕ0,ϕ1,,ϕM)\phi=(\phi_{0},\phi_{1},\dots,\phi_{M}) is at least τ\tau, for every δ>0\delta>0, there exists parameters of ϕ\phi such that supx|ϕq(x)ψq(x)|<δ2M(M+1)\sup_{x}\lvert\phi_{q}(x)-\psi_{q}(x)\rvert<\frac{\delta}{2M(M+1)} for q{0,1,,M}q\in\{0,1,\dots,M\}. Let us fix the width of ϕ\phi to be at least τ\tau. By the triangle inequality,

Ψ^(X)Ψ(X)2\displaystyle\left\lVert\hat{\Psi}(X)-\Psi(X)\right\rVert_{2} =xX(ϕ0(x)ψ0(x),,ϕM(x)ψM(x))2\displaystyle=\left\lVert\sum_{x\in X}\left(\phi_{0}(x)-\psi_{0}(x),\ldots,\phi_{M}(x)-\psi_{M}(x)\right)\right\rVert_{2}
xX(ϕ0(x)ψ0(x),,ϕM(x)ψM(x))2\displaystyle\leq\sum_{x\in X}\left\lVert\left(\phi_{0}(x)-\psi_{0}(x),\ldots,\phi_{M}(x)-\psi_{M}(x)\right)\right\rVert_{2}
xXq=0M|ϕq(x)ψq(x)|\displaystyle\leq\sum_{x\in X}\sum_{q=0}^{M}\lvert\phi_{q}(x)-\psi_{q}(x)\rvert
|X|q=0Msupx|ϕq(x)ψq(x)|\displaystyle\leq\lvert X\rvert\sum_{q=0}^{M}\sup_{x}\lvert\phi_{q}(x)-\psi_{q}(x)\rvert
<M(M+1)δ2M(M+1)\displaystyle<M(M+1)\frac{\delta}{2M(M+1)}
=δ2\displaystyle=\frac{\delta}{2}

for all $X\in[0,1]^{M}$. This implies that for every $\delta>0$ there exist parameters of $\phi$ such that

supXΨ^(X)Ψ(X)2δ2<δ.\sup_{X}\left\lVert\hat{\Psi}(X)-\Psi(X)\right\rVert_{2}\leq\frac{\delta}{2}<\delta.

Since Ψ:[0,1]M𝒬M+1\Psi:[0,1]^{M}\to\mathcal{Q}\subseteq\mathbb{R}^{M+1} is continuous and [0,1]M[0,1]^{M} is compact, 𝒬\mathcal{Q} is compact. Since 𝒬\mathcal{Q} and 𝒬^\hat{\mathcal{Q}} are compact, 𝒬~𝒬𝒬^\tilde{\mathcal{Q}}\coloneqq\mathcal{Q}\cup\hat{\mathcal{Q}} is compact. Define ρ~:𝒬~\tilde{\rho}:\tilde{\mathcal{Q}}\to\mathbb{R} by ρ~(z)=ρ(z)\tilde{\rho}(z)=\rho(z) for all z𝒬~z\in\tilde{\mathcal{Q}}. Replacing ρ\rho with ρ~\tilde{\rho},

supXf(X)(hfθ)(X)<supXρ~(Ψ^(X))ρ~(Ψ(X))+ϵ(τ).\sup_{X}\|f(X)-(h\circ f_{\theta})(X)\|<\sup_{X}\|\tilde{\rho}(\hat{\Psi}(X))-\tilde{\rho}(\Psi(X))\|+\epsilon^{\prime}_{(\tau^{\prime})}.

Since 𝒬~\tilde{\mathcal{Q}} is compact and ρ~\tilde{\rho} is continuous on 𝒬~\tilde{\mathcal{Q}}, ρ~\tilde{\rho} is uniformly continuous. Thus, for any ϵ>0\epsilon>0 there is a δ0\delta_{0} such that

for any z1,z2𝒬~ with z1z22<δ0ρ~(z1)ρ~(z2)2<ϵ.\text{for any }z_{1},z_{2}\in\tilde{\mathcal{Q}}\text{ with }\left\lVert z_{1}-z_{2}\right\rVert_{2}<\delta_{0}\Rightarrow\left\lVert\tilde{\rho}(z_{1})-\tilde{\rho}(z_{2})\right\rVert_{2}<\epsilon.

Since $\sup_{X}\|\hat{\Psi}(X)-\Psi(X)\|_{2}<\delta$ for an arbitrarily small $\delta>0$, we can take $\delta$ small enough that $\delta<\delta_{0}$, i.e., $\|\hat{\Psi}(X)-\Psi(X)\|_{2}<\delta_{0}$ for all $X\in[0,1]^{M}$. Then, for all $X\in[0,1]^{M}$,

ρ~(Ψ^(X))ρ~(Ψ(X))2<ϵ.\left\lVert\tilde{\rho}(\hat{\Psi}(X))-\tilde{\rho}(\Psi(X))\right\rVert_{2}<\epsilon.

It implies that

ρ~(Ψ^(X))ρ~(Ψ(X))2supXρ~(Ψ^(X))ρ~(Ψ(X))2ϵ.\displaystyle\left\lVert\tilde{\rho}(\hat{\Psi}(X))-\tilde{\rho}(\Psi(X))\right\rVert_{2}\leq\sup_{X}\left\lVert\tilde{\rho}(\hat{\Psi}(X))-\tilde{\rho}(\Psi(X))\right\rVert_{2}\leq\epsilon.

Thus, we get

supXf(X)(hfθ)(X)2ϵ+ϵ(τ),\sup_{X}\left\lVert f(X)-(h\circ f_{\theta})(X)\right\rVert_{2}\leq\epsilon+\epsilon^{\prime}_{(\tau^{\prime})},

where ϵ(τ)\epsilon^{\prime}_{(\tau^{\prime})} depends on the width τ\tau^{\prime} of ρ^\hat{\rho}, while ϵ>0\epsilon>0 is arbitrarily small with a fixed width of ϕ\phi (due to the universal approximation result with a bounded width of Rolnick & Tegmark, 2018). Let ϵ0>0\epsilon_{0}>0 be given. Then, we set τ\tau^{\prime} to be sufficiently large to ensure that ϵ(τ)<ϵ0/2\epsilon^{\prime}_{(\tau^{\prime})}<\epsilon_{0}/2 and set ϵ<ϵ0/2\epsilon<\epsilon_{0}/2, obtaining

supXf(X)(hfθ)(X)2\displaystyle\sup_{X}\left\lVert f(X)-(h\circ f_{\theta})(X)\right\rVert_{2} <ϵ0.\displaystyle<\epsilon_{0}.

Since $\epsilon_{0}>0$ was arbitrary, this proves the desired result, which we restate formally as follows. Suppose that the nonlinear activation function of $\phi$ has nonzero Taylor coefficients up to degree $M$, and let $h$ be a continuously sum-decomposable permutation-invariant deep neural network whose nonlinear activation functions are not polynomials of finite degree. Then there exists $\tau\geq 1$ such that if the width of $\phi$ is at least $\tau$, then for any $\epsilon_{0}>0$ there exists $\tau^{\prime}\geq 1$ for which the following holds: if the width of $h$ is at least $\tau^{\prime}$, then there exist trainable parameters of $f_{\theta}$ and $h$ satisfying

supXf(X)(hfθ)(X)2<ϵ0.\sup_{X}\left\lVert f(X)-(h\circ f_{\theta})(X)\right\rVert_{2}<\epsilon_{0}.
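The key construction step above, that zero slots and $W^{V}=I$ collapse $f_{\theta}$ to average pooling, can also be verified directly: with $\Sigma=\mathbf{0}$ every attention score equals $\sigma(0)$, so after normalizing by ${\bar{f}}_{\theta}$ each slot output is exactly the mean of $\phi(x_{i})$. A small NumPy sketch of this check follows; the elementwise tanh is only a stand-in for $\phi$ and the shapes are arbitrary.

import numpy as np

rng = np.random.default_rng(0)
M, dx, k = 16, 6, 4
X = rng.normal(size=(M, dx))
Wk = rng.normal(size=(dx, dx))

H = np.tanh(X)                                        # Phi(X); W^V = I, so V(X) = Phi(X)
Q = np.zeros((k, dx))                                 # zero slots => zero queries
A_hat = 1.0 / (1.0 + np.exp(-Q @ (H @ Wk).T / np.sqrt(dx)))   # constant sigma(0) = 0.5
A = A_hat / A_hat.sum(axis=0, keepdims=True)          # nu_1: every entry becomes 1/k
f = (A @ H) / A.sum(axis=1, keepdims=True)            # diag(f_bar)^{-1} f_hat

# every row of f_theta(X) collapses to the mean of phi(x_i): Deepsets with mean pooling
assert np.allclose(f, np.tile(H.mean(axis=0), (k, 1)))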

A.6 Proof of Theorem 3.10

fθ(X)k×df_{\theta}(X)\in\mathbb{R}^{k\times d} is defined by

fθ(X)i=1f¯θ(X)if^θ(X)idf_{\theta}(X)_{i}=\frac{1}{{\bar{f}}_{\theta}(X)_{i}}{\hat{f}}_{\theta}(X)_{i}\in\mathbb{R}^{d}

for all $i=1,\dots,k$, where ${\hat{f}}_{\theta}(X)_{i}\in\mathbb{R}^{d}$ is the $i$-th row of ${\hat{f}}_{\theta}(X)$ defined in equation 4 and ${\bar{f}}_{\theta}(X)_{i}\in\mathbb{R}$ is the $i$-th component of ${\bar{f}}_{\theta}(X)$ defined in equation 5.

Proof.

From the mini-batch consistency and definition of f^θ{\hat{f}}_{\theta} and f¯θ{\bar{f}}_{\theta}, we have that for any partition procedure ζ\zeta,

f^θ(X)=Sζ(X)f^θ(S) and f¯θ(X)=Sζ(X)f¯θ(S).\displaystyle{\hat{f}}_{\theta}(X)=\sum_{S\in\zeta(X)}{\hat{f}}_{\theta}(S)\ \text{ and }\ {\bar{f}}_{\theta}(X)=\sum_{S\in\zeta(X)}{\bar{f}}_{\theta}(S).

By using this and defining i(q)(q,yi)\ell_{i}(q)\coloneqq\ell(q,y_{i})\in\mathbb{R} and 𝐳j(i)fθ(Xi)jd{\mathbf{z}}^{(i)}_{j}\coloneqq f_{\theta}(X_{i})_{j}\in\mathbb{R}^{d}, where fθ(Xi)jf_{\theta}(X_{i})_{j} is jj-th row of fθ(Xi)f_{\theta}(X_{i}), the chain rule along with the linearity of the derivative operator yields that for any partition procedure ζ\zeta,

L(θ,λ)θ=1ni=1nj=1k(igλ)(𝐳j(i))𝐳j(i)fθ(Xi)jθ=1ni=1nj=1k(igλ)(𝐳j(i))𝐳j(i)(1f¯θ(Xi)jf^θ(Xi)jθf^θ(X)j1f¯θ(Xi)j2f¯θ(Xi)jθ)=1ni=1nj=1k(igλ)(𝐳j(i))𝐳j(i)(1f¯θ(Xi)jSζ(Xi)f^θ(S)jθ1f¯θ(Xi)j2f^θ(Xi)jSζ(Xi)f¯θ(S)jθ).\displaystyle\begin{split}\frac{\partial L(\theta,\lambda)}{\partial\theta}&=\frac{1}{n}\sum_{i=1}^{n}\sum_{j=1}^{k}\frac{\partial(\ell_{i}\circ g_{\lambda})({\mathbf{z}}^{(i)}_{j})}{\partial{\mathbf{z}}^{(i)}_{j}}\frac{\partial f_{\theta}(X_{i})_{j}}{\partial\theta}\\ &=\frac{1}{n}\sum_{i=1}^{n}\sum_{j=1}^{k}\frac{\partial(\ell_{i}\circ g_{\lambda})({\mathbf{z}}^{(i)}_{j})}{\partial{\mathbf{z}}^{(i)}_{j}}\left(\frac{1}{{\bar{f}}_{\theta}(X_{i})_{j}}\frac{\partial{\hat{f}}_{\theta}(X_{i})_{j}}{\partial\theta}-{\hat{f}}_{\theta}(X)_{j}\frac{1}{{\bar{f}}_{\theta}(X_{i})_{j}^{2}}\frac{\partial{\bar{f}}_{\theta}(X_{i})_{j}}{\partial\theta}\right)\\ &=\frac{1}{n}\sum_{i=1}^{n}\sum_{j=1}^{k}\frac{\partial(\ell_{i}\circ g_{\lambda})({\mathbf{z}}^{(i)}_{j})}{\partial{\mathbf{z}}^{(i)}_{j}}\left(\frac{1}{{\bar{f}}_{\theta}(X_{i})_{j}}\sum_{S\in\zeta(X_{i})}\frac{\partial{\hat{f}}_{\theta}(S)_{j}}{\partial\theta}-\frac{1}{{\bar{f}}_{\theta}(X_{i})_{j}^{2}}{\hat{f}}_{\theta}(X_{i})_{j}\sum_{S\in\zeta(X_{i})}\frac{\partial{\bar{f}}_{\theta}(S)_{j}}{\partial\theta}\right).\end{split} (25)

Similarly, by defining ¯i(q)(q,y¯i){\bar{\ell}}_{i}(q)\coloneqq\ell(q,{\bar{y}}_{i})\in\mathbb{R} and 𝐳¯j(i)f^θ(X¯i)jd{\bar{\mathbf{z}}}^{(i)}_{j}\coloneqq{\hat{f}}_{\theta}(\bar{X}_{i})_{j}\in\mathbb{R}^{d},

Lt,1(θ,λ)θ\displaystyle\frac{\partial L_{t,1}(\theta,\lambda)}{\partial\theta} =1mi=1m|ζt(X¯i)||ζ¯t(X¯i)|j=1k(¯igλ)(𝐳¯j(i))𝐳¯j(i)fθζ¯t(X¯i)jθ\displaystyle=\frac{1}{m}\sum_{i=1}^{m}\frac{|\zeta_{t}(\bar{X}_{i})|}{|{\bar{\zeta}}_{t}(\bar{X}_{i})|}\sum_{j=1}^{k}\frac{\partial({\bar{\ell}}_{i}\circ g_{\lambda})({\bar{\mathbf{z}}}^{(i)}_{j})}{\partial{\bar{\mathbf{z}}}^{(i)}_{j}}\frac{\partial f_{\theta}^{{\bar{\zeta}}_{t}}(\bar{X}_{i})_{j}}{\partial\theta} (26)
=1mi=1m|ζt(X¯i)||ζ¯t(X¯i)|j=1k(¯igλ)(𝐳¯j(i))𝐳¯j(i)(1f¯θ(X¯i)jS¯ζ¯t(X¯i)f^θ(S¯)jθ1f¯θ(X¯i)j2f^θ(X¯i)jS¯ζ¯t(X¯i)f¯θ(S¯)jθ).\displaystyle=\frac{1}{m}\sum_{i=1}^{m}\frac{|\zeta_{t}(\bar{X}_{i})|}{|{\bar{\zeta}}_{t}(\bar{X}_{i})|}\sum_{j=1}^{k}\frac{\partial({\bar{\ell}}_{i}\circ g_{\lambda})({\bar{\mathbf{z}}}^{(i)}_{j})}{\partial{\bar{\mathbf{z}}}^{(i)}_{j}}\left(\frac{1}{{\bar{f}}_{\theta}(\bar{X}_{i})_{j}}\sum_{{\bar{S}}\in{\bar{\zeta}}_{t}(\bar{X}_{i})}\frac{\partial{\hat{f}}_{\theta}({\bar{S}})_{j}}{\partial\theta}-\frac{1}{{\bar{f}}_{\theta}(\bar{X}_{i})_{j}^{2}}{\hat{f}}_{\theta}(\bar{X}_{i})_{j}\sum_{{\bar{S}}\in{\bar{\zeta}}_{t}(\bar{X}_{i})}\frac{\partial{\bar{f}}_{\theta}({\bar{S}})_{j}}{\partial\theta}\right).

Let t+t\in\mathbb{N}_{+} be fixed. By the linearity of expectation, we have that

𝔼((X¯i,y¯i))i=1m𝔼(ζ¯t(X¯i))i=1m[Lt,1(θ,λ)θ]\displaystyle\mathbb{E}_{((\bar{X}_{i},{\bar{y}}_{i}))^{m}_{i=1}}\mathbb{E}_{({\bar{\zeta}}_{t}(\bar{X}_{i}))^{m}_{i=1}}\left[\frac{\partial L_{t,1}(\theta,\lambda)}{\partial\theta}\right] =𝔼((X¯i,y¯i))i=1m[1mi=1mj=1k(¯igλ)(𝐳¯j(i))𝐳¯j(i)Qij]\displaystyle=\mathbb{E}_{((\bar{X}_{i},{\bar{y}}_{i}))^{m}_{i=1}}\left[\frac{1}{m}\sum_{i=1}^{m}\sum_{j=1}^{k}\frac{\partial({\bar{\ell}}_{i}\circ g_{\lambda})({\bar{\mathbf{z}}}^{(i)}_{j})}{\partial{\bar{\mathbf{z}}}^{(i)}_{j}}\mathrm{Q_{ij}}\right] (27)

where

Qij=1f¯θ(X¯i)j𝔼ζ¯t(X¯i)[|ζt(X¯i)||ζ¯t(X¯i)|S¯ζ¯t(X¯i)f^θ(S¯)jθ]1f¯θ(X¯i)j2f^θ(X¯i)j𝔼ζ¯t(X¯i)[|ζt(X¯i)||ζ¯t(X¯i)|S¯ζ¯t(X¯i)f¯θ(S¯)jθ].\mathrm{Q}_{ij}=\frac{1}{{\bar{f}}_{\theta}(\bar{X}_{i})_{j}}\mathbb{E}_{{\bar{\zeta}}_{t}(\bar{X}_{i})}\left[\frac{|\zeta_{t}(\bar{X}_{i})|}{|{\bar{\zeta}}_{t}(\bar{X}_{i})|}\sum_{{\bar{S}}\in{\bar{\zeta}}_{t}(\bar{X}_{i})}\frac{\partial{\hat{f}}_{\theta}({\bar{S}})_{j}}{\partial\theta}\right]-\frac{1}{{\bar{f}}_{\theta}(\bar{X}_{i})_{j}^{2}}{\hat{f}}_{\theta}(\bar{X}_{i})_{j}\mathbb{E}_{{\bar{\zeta}}_{t}(\bar{X}_{i})}\left[\frac{|\zeta_{t}(\bar{X}_{i})|}{|{\bar{\zeta}}_{t}(\bar{X}_{i})|}\sum_{{\bar{S}}\in{\bar{\zeta}}_{t}(\bar{X}_{i})}\frac{\partial{\bar{f}}_{\theta}({\bar{S}})_{j}}{\partial\theta}\right].

Below, we further analyze the following factors in the right-hand side of this equation:

𝔼ζ¯t(X¯i)[|ζt(X¯i)||ζ¯t(X¯i)|S¯ζ¯t(X¯i)f^θ(S¯)jθ] and 𝔼ζ¯t(X¯i)[|ζt(X¯i)||ζ¯t(X¯i)|S¯ζ¯t(X¯i)f¯θ(S¯)jθ].\displaystyle\mathbb{E}_{{\bar{\zeta}}_{t}(\bar{X}_{i})}\left[\frac{|\zeta_{t}(\bar{X}_{i})|}{|{\bar{\zeta}}_{t}(\bar{X}_{i})|}\sum_{{\bar{S}}\in{\bar{\zeta}}_{t}(\bar{X}_{i})}\frac{\partial{\hat{f}}_{\theta}({\bar{S}})_{j}}{\partial\theta}\right]\text{ and }\mathbb{E}_{{\bar{\zeta}}_{t}(\bar{X}_{i})}\left[\frac{|\zeta_{t}(\bar{X}_{i})|}{|{\bar{\zeta}}_{t}(\bar{X}_{i})|}\sum_{{\bar{S}}\in{\bar{\zeta}}_{t}(\bar{X}_{i})}\frac{\partial{\bar{f}}_{\theta}({\bar{S}})_{j}}{\partial\theta}\right].

Denote the elements of ζ¯t(X¯i){\bar{\zeta}}_{t}(\bar{X}_{i}) as {S¯1,S¯2,,S¯|ζ¯t(X¯i)|}=ζ¯t(X¯i)\{{\bar{S}}_{1},{\bar{S}}_{2},\dots,{\bar{S}}_{|{\bar{\zeta}}_{t}(\bar{X}_{i})|}\}={\bar{\zeta}}_{t}(\bar{X}_{i}). Then,

𝔼ζ¯t(X¯i)[|ζt(X¯i)||ζ¯t(X¯i)|S¯ζ¯t(X¯i)f^θ(S¯)jθ]\displaystyle\mathbb{E}_{{\bar{\zeta}}_{t}(\bar{X}_{i})}\left[\frac{|\zeta_{t}(\bar{X}_{i})|}{|{\bar{\zeta}}_{t}(\bar{X}_{i})|}\sum_{{\bar{S}}\in{\bar{\zeta}}_{t}(\bar{X}_{i})}\frac{\partial{\hat{f}}_{\theta}({\bar{S}})_{j}}{\partial\theta}\right] =𝔼S¯1,S¯2,,S¯|ζ¯t(X¯i)|[|ζt(X¯i)||ζ¯t(X¯i)|l=1|ζ¯t(X¯i)|f^θ(S¯l)jθ]\displaystyle=\mathbb{E}_{{\bar{S}}_{1},{\bar{S}}_{2},\dots,{\bar{S}}_{|{\bar{\zeta}}_{t}(\bar{X}_{i})|}}\left[\frac{|\zeta_{t}(\bar{X}_{i})|}{|{\bar{\zeta}}_{t}(\bar{X}_{i})|}\sum_{l=1}^{|{\bar{\zeta}}_{t}(\bar{X}_{i})|}\frac{\partial{\hat{f}}_{\theta}({\bar{S}}_{l})_{j}}{\partial\theta}\right]
=|ζt(X¯i)||ζ¯t(X¯i)|l=1|ζ¯t(X¯i)|𝔼S¯l[f^θ(S¯l)jθ].\displaystyle=\frac{|\zeta_{t}(\bar{X}_{i})|}{|{\bar{\zeta}}_{t}(\bar{X}_{i})|}\sum_{l=1}^{|{\bar{\zeta}}_{t}(\bar{X}_{i})|}\mathbb{E}_{{\bar{S}}_{l}}\left[\frac{\partial{\hat{f}}_{\theta}({\bar{S}}_{l})_{j}}{\partial\theta}\right].

Since S¯l{\bar{S}}_{l} is drawn independently and uniformly from the elements of ζt(X¯i)\zeta_{t}(\bar{X}_{i}), we have that

𝔼S¯l[f^θ(S¯l)jθ]=1|ζt(X¯i)|Sζt(X¯i)f^θ(S)jθ.\mathbb{E}_{{\bar{S}}_{l}}\left[\frac{\partial{\hat{f}}_{\theta}({\bar{S}}_{l})_{j}}{\partial\theta}\right]=\frac{1}{|\zeta_{t}(\bar{X}_{i})|}\sum_{S\in\zeta_{t}(\bar{X}_{i})}\frac{\partial{\hat{f}}_{\theta}(S)_{j}}{\partial\theta}.

Substituting this into the right-hand side of the preceding equation, we have that

𝔼ζ¯t(X¯i)[|ζt(X¯i)||ζ¯t(X¯i)|S¯ζ¯t(X¯i)f^θ(S¯)jθ]\displaystyle\mathbb{E}_{{\bar{\zeta}}_{t}(\bar{X}_{i})}\left[\frac{|\zeta_{t}(\bar{X}_{i})|}{|{\bar{\zeta}}_{t}(\bar{X}_{i})|}\sum_{{\bar{S}}\in{\bar{\zeta}}_{t}(\bar{X}_{i})}\frac{\partial{\hat{f}}_{\theta}({\bar{S}})_{j}}{\partial\theta}\right] =|ζt(X¯i)||ζ¯t(X¯i)|l=1|ζ¯t(X¯i)|1|ζt(X¯i)|Sζt(X¯i)f^θ(S)jθ=Sζt(X¯i)f^θ(S)jθ.\displaystyle=\frac{|\zeta_{t}(\bar{X}_{i})|}{|{\bar{\zeta}}_{t}(\bar{X}_{i})|}\sum_{l=1}^{|{\bar{\zeta}}_{t}(\bar{X}_{i})|}\frac{1}{|\zeta_{t}(\bar{X}_{i})|}\sum_{S\in\zeta_{t}(\bar{X}_{i})}\frac{\partial{\hat{f}}_{\theta}(S)_{j}}{\partial\theta}=\sum_{S\in\zeta_{t}(\bar{X}_{i})}\frac{\partial{\hat{f}}_{\theta}(S)_{j}}{\partial\theta}.

Similarly,

𝔼ζ¯t(X¯i)[|ζt(X¯i)||ζ¯t(X¯i)|S¯ζ¯t(X¯i)f¯θ(S¯)jθ]=Sζt(X¯i)f¯θ(S)jθ.\mathbb{E}_{{\bar{\zeta}}_{t}(\bar{X}_{i})}\left[\frac{|\zeta_{t}(\bar{X}_{i})|}{|{\bar{\zeta}}_{t}(\bar{X}_{i})|}\sum_{{\bar{S}}\in{\bar{\zeta}}_{t}(\bar{X}_{i})}\frac{\partial{\bar{f}}_{\theta}({\bar{S}})_{j}}{\partial\theta}\right]=\sum_{S\in\zeta_{t}(\bar{X}_{i})}\frac{\partial{\bar{f}}_{\theta}(S)_{j}}{\partial\theta}.

Substituting these into Qij\mathrm{Q}_{ij},

\mathrm{Q}_{ij}=\frac{1}{{\bar{f}}_{\theta}(\bar{X}_{i})_{j}}\sum_{S\in\zeta_{t}(\bar{X}_{i})}\frac{\partial{\hat{f}}_{\theta}(S)_{j}}{\partial\theta}-\frac{1}{{\bar{f}}_{\theta}(\bar{X}_{i})_{j}^{2}}{\hat{f}}_{\theta}(\bar{X}_{i})_{j}\sum_{S\in\zeta_{t}(\bar{X}_{i})}\frac{\partial{\bar{f}}_{\theta}(S)_{j}}{\partial\theta}.

By using this in equation 27 and defining B(X¯i,y¯i)j=(¯ig)(𝐳¯j(i))𝐳¯j(i)Qij,B(\bar{X}_{i},{\bar{y}}_{i})_{j}=\frac{\partial({\bar{\ell}}_{i}\circ g)({\bar{\mathbf{z}}}^{(i)}_{j})}{\partial{\bar{\mathbf{z}}}^{(i)}_{j}}\mathrm{Q}_{ij}, we have that

𝔼((X¯i,y¯i))i=1m𝔼(ζ¯t(X¯i))i=1m[Lt,1(θ,λ)θ]=𝔼((X¯i,y¯i))i=1m[1mi=1mj=1kB(X¯i,y¯i)j]\displaystyle\mathbb{E}_{((\bar{X}_{i},{\bar{y}}_{i}))^{m}_{i=1}}\mathbb{E}_{({\bar{\zeta}}_{t}(\bar{X}_{i}))^{m}_{i=1}}\left[\frac{\partial L_{t,1}(\theta,\lambda)}{\partial\theta}\right]=\mathbb{E}_{((\bar{X}_{i},{\bar{y}}_{i}))^{m}_{i=1}}\left[\frac{1}{m}\sum_{i=1}^{m}\sum_{j=1}^{k}B(\bar{X}_{i},{\bar{y}}_{i})_{j}\right] =1mi=1mj=1k𝔼(X¯i,y¯i)[B(X¯i,y¯i)j]\displaystyle=\frac{1}{m}\sum_{i=1}^{m}\sum_{j=1}^{k}\mathbb{E}_{(\bar{X}_{i},{\bar{y}}_{i})}[B(\bar{X}_{i},{\bar{y}}_{i})_{j}]
=j=1k1mi=1m1nl=1nB(Xl,yl)j\displaystyle=\sum_{j=1}^{k}\frac{1}{m}\sum_{i=1}^{m}\frac{1}{n}\sum_{l=1}^{n}B(X_{l},y_{l})_{j}
=1ni=1nj=1kB(Xi,yi)j.\displaystyle=\frac{1}{n}\sum_{i=1}^{n}\sum_{j=1}^{k}B(X_{i},y_{i})_{j}.

Thus, expanding the definition of B(Xi,yi)jB(X_{i},y_{i})_{j},

𝔼((X¯i,y¯i))i=1m𝔼(ζ¯t(X¯i))i=1m[Lt,1(θ,λ)θ]\displaystyle\mathbb{E}_{((\bar{X}_{i},{\bar{y}}_{i}))^{m}_{i=1}}\mathbb{E}_{({\bar{\zeta}}_{t}(\bar{X}_{i}))^{m}_{i=1}}\left[\frac{\partial L_{t,1}(\theta,\lambda)}{\partial\theta}\right]
=1ni=1nj=1k(igλ)(𝐳j(i))𝐳j(i)(1f¯θ(Xi)jSζt(Xi)f^θ(S)jθ1f¯θ(Xi)j2f^θ(X)jSζt(Xi)f¯θ(S)jθ).\displaystyle=\frac{1}{n}\sum_{i=1}^{n}\sum_{j=1}^{k}\frac{\partial(\ell_{i}\circ g_{\lambda})({\mathbf{z}}^{(i)}_{j})}{\partial{\mathbf{z}}^{(i)}_{j}}\left(\frac{1}{{\bar{f}}_{\theta}(X_{i})_{j}}\sum_{S\in\zeta_{t}(X_{i})}\frac{\partial{\hat{f}}_{\theta}(S)_{j}}{\partial\theta}-\frac{1}{{\bar{f}}_{\theta}(X_{i})_{j}^{2}}{\hat{f}}_{\theta}(X)_{j}\sum_{S\in\zeta_{t}(X_{i})}\frac{\partial{\bar{f}}_{\theta}(S)_{j}}{\partial\theta}\right).

Here, since Sζt(Xi)f^θ(S)θ=Sζ(Xi)f^θ(S)θ\sum_{S\in\zeta_{t}(X_{i})}\frac{\partial{\hat{f}}_{\theta}(S)}{\partial\theta}=\sum_{S\in\zeta(X_{i})}\frac{\partial{\hat{f}}_{\theta}(S)}{\partial\theta} and Sζt(Xi)f¯θ(S)jθ=Sζ(Xi)f¯θ(S)jθ\sum_{S\in\zeta_{t}(X_{i})}\frac{\partial{\bar{f}}_{\theta}(S)_{j}}{\partial\theta}=\sum_{S\in\zeta(X_{i})}\frac{\partial{\bar{f}}_{\theta}(S)_{j}}{\partial\theta} for any partition procedure ζ\zeta from the mini-batch consistency, we have that

𝔼((X¯i,y¯i))i=1m𝔼(ζ¯t(X¯i))i=1m[Lt,1(θ,λ)θ]=1ni=1nj=1k(igλ)(𝐳j(i))𝐳j(i)(1f¯θ(Xi)jSζ(Xi)f^θ(S)jθ1f¯θ(Xi)j2f^θ(X)jSζ(Xi)f¯θ(S)jθ).\displaystyle\begin{split}&\mathbb{E}_{((\bar{X}_{i},{\bar{y}}_{i}))^{m}_{i=1}}\mathbb{E}_{({\bar{\zeta}}_{t}(\bar{X}_{i}))^{m}_{i=1}}\left[\frac{\partial L_{t,1}(\theta,\lambda)}{\partial\theta}\right]\\ &=\frac{1}{n}\sum_{i=1}^{n}\sum_{j=1}^{k}\frac{\partial(\ell_{i}\circ g_{\lambda})({\mathbf{z}}^{(i)}_{j})}{\partial{\mathbf{z}}^{(i)}_{j}}\left(\frac{1}{{\bar{f}}_{\theta}(X_{i})_{j}}\sum_{S\in\zeta(X_{i})}\frac{\partial{\hat{f}}_{\theta}(S)_{j}}{\partial\theta}-\frac{1}{{\bar{f}}_{\theta}(X_{i})_{j}^{2}}{\hat{f}}_{\theta}(X)_{j}\sum_{S\in\zeta(X_{i})}\frac{\partial{\bar{f}}_{\theta}(S)_{j}}{\partial\theta}\right).\end{split} (28)

By comparing equation equation 25 and equation equation 28, we conclude that

𝔼((X¯i,y¯i))i=1m𝔼(ζ¯t(X¯i))i=1m[Lt,1(θ,λ)θ]=L(θ,λ)θ.\mathbb{E}_{((\bar{X}_{i},{\bar{y}}_{i}))^{m}_{i=1}}\mathbb{E}_{({\bar{\zeta}}_{t}(\bar{X}_{i}))^{m}_{i=1}}\left[\frac{\partial L_{t,1}(\theta,\lambda)}{\partial\theta}\right]=\frac{\partial L(\theta,\lambda)}{\partial\theta}.

Since tt was arbitrary, this holds for any t+t\in\mathbb{N}_{+}.

Now we want to show equation 13 holds for any t+t\in\mathbb{N}_{+}. Since f¯θζ¯t(X¯i)=f¯θ(X¯i){\bar{f}}^{{\bar{\zeta}}_{t}}_{\theta}(\bar{X}_{i})={\bar{f}}_{\theta}(\bar{X}_{i}) and f^θζ¯t(X¯i)=f^θ(X¯i){\hat{f}}^{{\bar{\zeta}}_{t}}_{\theta}(\bar{X}_{i})={\hat{f}}_{\theta}(\bar{X}_{i}) for all i[m]i\in[m],

𝔼((X¯i,y¯i))i=1m𝔼(ζ¯t(X¯i))i=1m[Lt,2(θ,λ)λ]\displaystyle\mathbb{E}_{((\bar{X}_{i},{\bar{y}}_{i}))^{m}_{i=1}}\mathbb{E}_{({\bar{\zeta}}_{t}(\bar{X}_{i}))^{m}_{i=1}}\left[\frac{\partial L_{t,2}(\theta,\lambda)}{\partial\lambda}\right] =𝔼((X¯i,y¯i))i=1m𝔼(ζ¯t(X¯i))i=1m[1mi=1m1|ζ¯t(X¯i)|(¯igλ)(fθζ¯t(X¯i))λ]\displaystyle=\mathbb{E}_{((\bar{X}_{i},{\bar{y}}_{i}))^{m}_{i=1}}\mathbb{E}_{({\bar{\zeta}}_{t}(\bar{X}_{i}))^{m}_{i=1}}\left[\frac{1}{m}\sum_{i=1}^{m}\frac{1}{\lvert{\bar{\zeta}}_{t}(\bar{X}_{i})\rvert}\frac{\partial({\bar{\ell}}_{i}\circ g_{\lambda})(f^{{\bar{\zeta}}_{t}}_{\theta}(\bar{X}_{i}))}{\partial\lambda}\right]
=𝔼((X¯i,y¯i))i=1m𝔼(ζ¯t(X¯i))i=1m[1mi=1m1|ζ¯t(X¯i)|(¯igλ)(fθ(X¯i))λ].\displaystyle=\mathbb{E}_{((\bar{X}_{i},{\bar{y}}_{i}))^{m}_{i=1}}\mathbb{E}_{({\bar{\zeta}}_{t}(\bar{X}_{i}))^{m}_{i=1}}\left[\frac{1}{m}\sum_{i=1}^{m}\frac{1}{\lvert{\bar{\zeta}}_{t}(\bar{X}_{i})\rvert}\frac{\partial({\bar{\ell}}_{i}\circ g_{\lambda})(f_{\theta}(\bar{X}_{i}))}{\partial\lambda}\right]. (29)

Since we independently and uniformly sample ${\bar{S}}_{1},{\bar{S}}_{2},\ldots,{\bar{S}}_{\lvert{\bar{\zeta}}_{t}(\bar{X}_{i})\rvert}$ from $\zeta_{t}(\bar{X}_{i})$, and $\frac{\partial({\bar{\ell}}_{i}\circ g_{\lambda})(f_{\theta}(\bar{X}_{i}))}{\partial\lambda}$ is constant with respect to this sampling,

𝔼ζ¯t(X¯i)[1|ζ¯t(X¯i)|(¯igλ)(fθ(X¯i))λ]\displaystyle\mathbb{E}_{{\bar{\zeta}}_{t}(\bar{X}_{i})}\left[\frac{1}{\lvert{\bar{\zeta}}_{t}(\bar{X}_{i})\rvert}\frac{\partial({\bar{\ell}}_{i}\circ g_{\lambda})(f_{\theta}(\bar{X}_{i}))}{\partial\lambda}\right] =𝔼S¯1,S¯2,,S¯|ζ¯t(X¯i)|[1|ζ¯t(X¯i)|(¯igλ)(fθ(X¯i))λ]\displaystyle=\mathbb{E}_{{\bar{S}}_{1},{\bar{S}}_{2},\ldots,{\bar{S}}_{\lvert{\bar{\zeta}}_{t}(\bar{X}_{i})\rvert}}\left[\frac{1}{\lvert{\bar{\zeta}}_{t}(\bar{X}_{i})\rvert}\frac{\partial({\bar{\ell}}_{i}\circ g_{\lambda})(f_{\theta}(\bar{X}_{i}))}{\partial\lambda}\right]
=1|ζ¯t(X¯i)|j=1|ζ¯t(X¯i)|𝔼S¯j[(¯igλ)(fθ(X¯i))λ]\displaystyle=\frac{1}{\lvert{\bar{\zeta}}_{t}(\bar{X}_{i})\rvert}\sum_{j=1}^{\lvert{\bar{\zeta}}_{t}(\bar{X}_{i})\rvert}\mathbb{E}_{{\bar{S}}_{j}}\left[\frac{\partial({\bar{\ell}}_{i}\circ g_{\lambda})(f_{\theta}(\bar{X}_{i}))}{\partial\lambda}\right]
=1|ζ¯t(X¯i)|j=1|ζ¯t(X¯i)|(¯igλ)(fθ(X¯i))λ\displaystyle=\frac{1}{\lvert{\bar{\zeta}}_{t}(\bar{X}_{i})\rvert}\sum_{j=1}^{\lvert{\bar{\zeta}}_{t}(\bar{X}_{i})\rvert}\frac{\partial({\bar{\ell}}_{i}\circ g_{\lambda})(f_{\theta}(\bar{X}_{i}))}{\partial\lambda}
=(¯igλ)(fθ(X¯i))λ.\displaystyle=\frac{\partial({\bar{\ell}}_{i}\circ g_{\lambda})(f_{\theta}(\bar{X}_{i}))}{\partial\lambda}. (30)

Since we sample a mini-batch ((X¯i,y¯i))i=1m((\bar{X}_{i},{\bar{y}}_{i}))_{i=1}^{m} independently and uniformly from the whole training set ((Xi,yi))i=1n((X_{i},y_{i}))_{i=1}^{n}, if we apply equation 30 to equation 29, we get

𝔼((X¯i,y¯i))i=1m[1mi=1m(¯igλ)(fθ(X¯i))λ]=1mi=1m𝔼((X¯i,y¯i))i=1m[(¯igλ)(fθ(X¯i))λ]=1mi=1m(1nj=1n(jgλ)(fθ(Xj))λ)=1nj=1n(jgλ)(fθ(Xj))λ=L(θ,λ)λ\displaystyle\begin{split}\mathbb{E}_{((\bar{X}_{i},{\bar{y}}_{i}))^{m}_{i=1}}\left[\frac{1}{m}\sum_{i=1}^{m}\frac{\partial({\bar{\ell}}_{i}\circ g_{\lambda})(f_{\theta}(\bar{X}_{i}))}{\partial\lambda}\right]&=\frac{1}{m}\sum_{i=1}^{m}\mathbb{E}_{((\bar{X}_{i},{\bar{y}}_{i}))^{m}_{i=1}}\left[\frac{\partial({\bar{\ell}}_{i}\circ g_{\lambda})(f_{\theta}(\bar{X}_{i}))}{\partial\lambda}\right]\\ &=\frac{1}{m}\sum_{i=1}^{m}\left(\frac{1}{n}\sum_{j=1}^{n}\frac{\partial(\ell_{j}\circ g_{\lambda})(f_{\theta}(X_{j}))}{\partial\lambda}\right)\\ &=\frac{1}{n}\sum_{j=1}^{n}\frac{\partial(\ell_{j}\circ g_{\lambda})(f_{\theta}(X_{j}))}{\partial\lambda}\\ &=\frac{\partial L(\theta,\lambda)}{\partial\lambda}\end{split} (31)

Therefore, we conclude that

𝔼((X¯i,y¯i))i=1m𝔼(ζ¯t(X¯i))i=1m[Lt,2(θ,λ)λ]=L(θ,λ)λ.\mathbb{E}_{((\bar{X}_{i},{\bar{y}}_{i}))^{m}_{i=1}}\mathbb{E}_{({\bar{\zeta}}_{t}(\bar{X}_{i}))^{m}_{i=1}}\left[\frac{\partial L_{t,2}(\theta,\lambda)}{\partial\lambda}\right]=\frac{\partial L(\theta,\lambda)}{\partial\lambda}.

Since tt was arbitrary, this holds for any t+t\in\mathbb{N}_{+}. ∎
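The step that drives this proof is the sampling identity $\mathbb{E}_{{\bar{\zeta}}_{t}}\left[\frac{|\zeta_{t}(\bar{X})|}{|{\bar{\zeta}}_{t}(\bar{X})|}\sum_{{\bar{S}}\in{\bar{\zeta}}_{t}(\bar{X})}g({\bar{S}})\right]=\sum_{S\in\zeta_{t}(\bar{X})}g(S)$ for subsets drawn i.i.d. uniformly from the partition, with $g$ standing in for the per-subset gradient terms $\partial{\hat{f}}_{\theta}(S)_{j}/\partial\theta$ and $\partial{\bar{f}}_{\theta}(S)_{j}/\partial\theta$. The Monte-Carlo sketch below illustrates this identity on a toy scalar $g$; it is only an illustration of the scaling argument and does not reproduce the StopGrad bookkeeping of the actual training algorithm.

import numpy as np

rng = np.random.default_rng(0)
partition = [rng.normal(size=3) for _ in range(8)]     # zeta_t(X): 8 subsets of a toy set
g = lambda S: float(np.sin(S).sum())                   # stand-in for a per-subset gradient term
full_sum = sum(g(S) for S in partition)                # sum over the full partition

n_sampled, trials = 2, 100_000                         # |zeta_bar_t| = 2 sampled subsets
estimate = 0.0
for _ in range(trials):
    idx = rng.integers(len(partition), size=n_sampled) # i.i.d. uniform subset indices
    estimate += (len(partition) / n_sampled) * sum(g(partition[i]) for i in idx)
estimate /= trials

print(full_sum, estimate)   # the scaled subset sum matches the full sum in expectation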

A.7 Unbiased Estimation of Full Set Gradient for MIL

Corollary A.1.

For multiple instance learning, given a training set ((Xi,yi))i=1n((X_{i},y_{i}))_{i=1}^{n}, we define the loss as follows:

𝐳i,1=max{wϕ(𝐱i,j)+b}j=1Ni,𝐳i,2=fθ(Xi),\displaystyle{\mathbf{z}}_{i,1}=\max\{w^{\top}\phi({\mathbf{x}}_{i,j})+b\}_{j=1}^{N_{i}},\quad{\mathbf{z}}_{i,2}=f_{\theta}(X_{i}),
i=12((𝐳i,1,yi)+(gλ(𝐳i,2),yi)),\displaystyle\mathcal{L}_{i}=\frac{1}{2}\left(\ell({\mathbf{z}}_{i,1},y_{i})+\ell(g_{\lambda}({\mathbf{z}}_{i,2}),y_{i})\right),
L(θ,λ,w,b)=1ni=1ni.\displaystyle L(\theta,\lambda,w,b)=\frac{1}{n}\sum_{i=1}^{n}\mathcal{L}_{i}.

For every iteration $t\in\mathbb{N}_{+}$ we sample a mini-batch of training data $((\bar{X}_{i},{\bar{y}}_{i}))_{i=1}^{m}\sim D[((X_{i},y_{i}))_{i=1}^{n}]$ and a random partition $(\zeta_{t}(\bar{X}_{i}))_{i=1}^{m}$. Then we sample a mini-batch of subsets ${\bar{\zeta}}_{t}(\bar{X}_{i})\sim D[\zeta_{t}(\bar{X}_{i})]$. Let $\psi\coloneqq(w,b)\in\mathbb{R}^{d_{h}+1}$ and define a function $h_{\psi}:\mathcal{X}\to\mathbb{R}$ by $h_{\psi}(X)\coloneqq\max\{w^{\top}\phi({\mathbf{x}})+b:{\mathbf{x}}\in X\}$. Similarly to equation 9, we define $h^{{\bar{\zeta}}_{t}}_{\psi}$ as

hψζ¯t(X¯i)\displaystyle h_{\psi}^{{\bar{\zeta}}_{t}}(\bar{X}_{i}) hψζ¯t,ζt(X¯i)\displaystyle\coloneqq h^{{\bar{\zeta}}_{t},\zeta_{t}}_{\psi}(\bar{X}_{i})
max{hψ(S),StopGrad(hψ(S))Sζ¯t(X¯i),Sζt(X¯i)ζ¯t(X¯i)}.\displaystyle\coloneqq\max\{h_{\psi}(S),\texttt{StopGrad}(h_{\psi}(S^{\prime}))\mid S\in{\bar{\zeta}}_{t}(\bar{X}_{i}),S^{\prime}\in\zeta_{t}(\bar{X}_{i})\setminus{\bar{\zeta}}_{t}(\bar{X}_{i})\}.

Then we update the parameters θ,λ\theta,\lambda and ψ\psi using the gradient of the following functions as

\displaystyle L_{t,1}(\theta,\lambda,\psi)=\frac{1}{2m}\sum_{i=1}^{m}\frac{\lvert\zeta_{t}(\bar{X}_{i})\rvert}{\lvert{\bar{\zeta}}_{t}(\bar{X}_{i})\rvert}\left(\ell(h^{{\bar{\zeta}}_{t}}_{\psi}(\bar{X}_{i}),{\bar{y}}_{i})+\ell(g_{\lambda}(f^{{\bar{\zeta}}_{t}}_{\theta}(\bar{X}_{i})),{\bar{y}}_{i})\right)
\displaystyle L_{t,2}(\theta,\lambda)=\frac{1}{2m}\sum_{i=1}^{m}\ell(g_{\lambda}(f^{{\bar{\zeta}}_{t}}_{\theta}(\bar{X}_{i})),{\bar{y}}_{i})
θt+1\displaystyle\theta_{t+1} =θtηtLt,1(θt,λt,ψt)θt\displaystyle=\theta_{t}-\eta_{t}\frac{\partial L_{t,1}(\theta_{t},\lambda_{t},\psi_{t})}{\partial\theta_{t}}
ψt+1\displaystyle\psi_{t+1} =ψtηtLt,1(θt,λt,ψt)ψt\displaystyle=\psi_{t}-\eta_{t}\frac{\partial L_{t,1}(\theta_{t},\lambda_{t},\psi_{t})}{\partial\psi_{t}}
λt+1\displaystyle\lambda_{t+1} =λtηtLt,2(θt,λt)λt,\displaystyle=\lambda_{t}-\eta_{t}\frac{\partial L_{t,2}(\theta_{t},\lambda_{t})}{\partial\lambda_{t}},

where $\eta_{t}>0$ is a learning rate. If we assume that the maximum value $\max\{w^{\top}\phi({\mathbf{x}})+b:{\mathbf{x}}\in X_{i}\}\in\mathbb{R}$ is attained by a unique element of $X_{i}$ for each $i\in\{1,\ldots,n\}$, and we sample a single subset from $\zeta_{t}(\bar{X}_{i})$ for each $i\in\{1,\ldots,m\}$, i.e., $\lvert{\bar{\zeta}}_{t}(\bar{X}_{i})\rvert=1$, then the following holds:

𝔼((X¯i,y¯i))i=1m𝔼(ζ¯t(X¯i))i=1m[Lt,1(θ,λ,ψ)θ]\displaystyle\mathbb{E}_{((\bar{X}_{i},{\bar{y}}_{i}))_{i=1}^{m}}\mathbb{E}_{({\bar{\zeta}}_{t}(\bar{X}_{i}))_{i=1}^{m}}\left[\frac{\partial L_{t,1}(\theta,\lambda,\psi)}{\partial\theta}\right] =L(θ,λ,ψ)θ\displaystyle=\frac{\partial L(\theta,\lambda,\psi)}{\partial\theta} (32)
𝔼((X¯i,y¯i))i=1m𝔼(ζ¯t(X¯i))i=1m[Lt,1(θ,λ,ψ)ψ]\displaystyle\mathbb{E}_{((\bar{X}_{i},{\bar{y}}_{i}))_{i=1}^{m}}\mathbb{E}_{({\bar{\zeta}}_{t}(\bar{X}_{i}))_{i=1}^{m}}\left[\frac{\partial L_{t,1}(\theta,\lambda,\psi)}{\partial\psi}\right] =L(θ,λ,ψ)ψ\displaystyle=\frac{\partial L(\theta,\lambda,\psi)}{\partial\psi} (33)
𝔼((X¯i,y¯i))i=1m𝔼(ζ¯t(X¯i))i=1m[Lt,2(θ,λ)λ]\displaystyle\mathbb{E}_{((\bar{X}_{i},{\bar{y}}_{i}))_{i=1}^{m}}\mathbb{E}_{({\bar{\zeta}}_{t}(\bar{X}_{i}))_{i=1}^{m}}\left[\frac{\partial L_{t,2}(\theta,\lambda)}{\partial\lambda}\right] =L(θ,λ,ψ)λ\displaystyle=\frac{\partial L(\theta,\lambda,\psi)}{\partial\lambda} (34)
Proof.

It is enough to show that equations 32 and 33 hold, since equation 34, which does not depend on $\psi$, was already proved in Theorem 3.10. By defining $\ell_{i}(q)\coloneqq\ell(q,y_{i})\in\mathbb{R}$, $u_{i,j}=w^{\top}\phi({\mathbf{x}}_{i,j})+b$ where ${\mathbf{x}}_{i,j}\in X_{i}$ for $j\in\{1,\ldots,N_{i}\}$, and $u_{i,M}=\max\{u_{i,j}\}_{j=1}^{N_{i}}$,

L(θ,λ,ψ)ψ\displaystyle\frac{\partial L(\theta,\lambda,\psi)}{\partial\psi} =1ni=1n12i(hψ(Xi))hψ(Xi)hψ(Xi)ψ\displaystyle=\frac{1}{n}\sum_{i=1}^{n}\frac{1}{2}\frac{\partial\ell_{i}(h_{\psi}(X_{i}))}{\partial h_{\psi}(X_{i})}\frac{\partial h_{\psi}(X_{i})}{\partial\psi}
=12ni=1ni(hψ(Xi))hψ(Xi)max{hψ(S)Sζ(Xi)}ψ\displaystyle=\frac{1}{2n}\sum_{i=1}^{n}\frac{\partial\ell_{i}(h_{\psi}(X_{i}))}{\partial h_{\psi}(X_{i})}\frac{\partial\max\{h_{\psi}(S)\mid S\in\zeta(X_{i})\}}{\partial\psi}
=12ni=1ni(hψ(Xi))hψ(Xi)ui,Mψ\displaystyle=\frac{1}{2n}\sum_{i=1}^{n}\frac{\partial\ell_{i}(h_{\psi}(X_{i}))}{\partial h_{\psi}(X_{i})}\frac{\partial u_{i,M}}{\partial\psi}

for any partition $\zeta(X_{i})$, for all $i\in\{1,\ldots,n\}$. Similarly, we define ${\bar{\ell}}_{i}(q)\coloneqq\ell(q,{\bar{y}}_{i})\in\mathbb{R}$, $\bar{u}_{i,j}\coloneqq w^{\top}\phi(\bar{\mathbf{x}}_{i,j})+b$ where $\bar{\mathbf{x}}_{i,j}\in\bar{X}_{i}$ for $j\in\{1,\ldots,N_{i}\}$, and $\bar{u}_{i,M}=\max\{\bar{u}_{i,j}\}_{j=1}^{N_{i}}$. Let $t\in\mathbb{N}_{+}$ be fixed and define

B¯t,i\displaystyle\bar{B}_{t,i} {hψ(S):Sζ¯t(X¯i)}.\displaystyle\coloneqq\{h_{\psi}(S):S\in{\bar{\zeta}}_{t}(\bar{X}_{i})\}.

With linearity of expectation and properties of the max operation, we have that

𝔼((X¯i,y¯i))i=1m𝔼(ζ¯t(X¯i))i=1m[Lt,1(θ,λ,ψ)ψ]\displaystyle\mathbb{E}_{((\bar{X}_{i},{\bar{y}}_{i}))_{i=1}^{m}}\mathbb{E}_{({\bar{\zeta}}_{t}(\bar{X}_{i}))_{i=1}^{m}}\left[\frac{\partial L_{t,1}(\theta,\lambda,\psi)}{\partial\psi}\right]
=12mi=1m𝔼((X¯i,y¯i))i=1m𝔼(ζ¯t(X¯i))i=1m[|ζt(X¯i)||ζ¯t(X¯i)|¯i(hψ(X¯i))hψ(X¯i)𝟙{u¯i,MB¯t,i}u¯i,Mψ]\displaystyle=\frac{1}{2m}\sum_{i=1}^{m}\mathbb{E}_{((\bar{X}_{i},{\bar{y}}_{i}))_{i=1}^{m}}\mathbb{E}_{({\bar{\zeta}}_{t}(\bar{X}_{i}))_{i=1}^{m}}\left[\frac{\lvert\zeta_{t}(\bar{X}_{i})\rvert}{\lvert{\bar{\zeta}}_{t}(\bar{X}_{i})\rvert}\frac{\partial{\bar{\ell}}_{i}(h_{\psi}(\bar{X}_{i}))}{\partial h_{\psi}(\bar{X}_{i})}\mathbbm{1}\{\bar{u}_{i,M}\in\bar{B}_{t,i}\}\frac{\partial\bar{u}_{i,M}}{\partial\psi}\right] (35)

Note that we partition the set $\bar{X}_{i}$ and the maximum is attained by a unique element. Thus, exactly one subset $S\in\zeta_{t}(\bar{X}_{i})$ contains the element attaining the maximum value $\bar{u}_{i,M}$; if that subset is not chosen, the gradient in equation 35 becomes zero. Since we uniformly sample a single subset ${\bar{S}}$ from $\zeta_{t}(\bar{X}_{i})$, i.e., ${\bar{\zeta}}_{t}(\bar{X}_{i})=\{{\bar{S}}\}$, we get

𝔼ζ¯t(X¯i)[|ζt(X¯i)||ζ¯t(X¯i)|¯i(hψ(X¯i))hψ(X¯i)𝟙{u¯i,MB¯t,i}u¯i,Mψ]\displaystyle\mathbb{E}_{{\bar{\zeta}}_{t}(\bar{X}_{i})}\left[\frac{\lvert\zeta_{t}(\bar{X}_{i})\rvert}{\lvert{\bar{\zeta}}_{t}(\bar{X}_{i})\rvert}\frac{\partial{\bar{\ell}}_{i}(h_{\psi}(\bar{X}_{i}))}{\partial h_{\psi}(\bar{X}_{i})}\mathbbm{1}\{\bar{u}_{i,M}\in\bar{B}_{t,i}\}\frac{\partial\bar{u}_{i,M}}{\partial\psi}\right] =𝔼S¯[|ζt(X¯i)||ζ¯t(X¯i)|¯i(hψ(X¯i))hψ(X¯i)𝟙{u¯i,M=hψ(S¯)}u¯i,Mψ]\displaystyle=\mathbb{E}_{{\bar{S}}}\left[\frac{\lvert\zeta_{t}(\bar{X}_{i})\rvert}{\lvert{\bar{\zeta}}_{t}(\bar{X}_{i})\rvert}\frac{\partial{\bar{\ell}}_{i}(h_{\psi}(\bar{X}_{i}))}{\partial h_{\psi}(\bar{X}_{i})}\mathbbm{1}\{\bar{u}_{i,M}=h_{\psi}({\bar{S}})\}\frac{\partial\bar{u}_{i,M}}{\partial\psi}\right]
=|ζt(X¯i)|11|ζt(X¯i)|¯i(hψ(X¯i))hψ(X¯i)u¯i,Mψ\displaystyle=\frac{\lvert\zeta_{t}(\bar{X}_{i})\rvert}{1}\frac{1}{\lvert\zeta_{t}(\bar{X}_{i})\rvert}\frac{\partial{\bar{\ell}}_{i}(h_{\psi}(\bar{X}_{i}))}{\partial h_{\psi}(\bar{X}_{i})}\frac{\partial\bar{u}_{i,M}}{\partial\psi}
=¯i(hψ(X¯i))hψ(X¯i)u¯i,Mψ\displaystyle=\frac{\partial{\bar{\ell}}_{i}(h_{\psi}(\bar{X}_{i}))}{\partial h_{\psi}(\bar{X}_{i})}\frac{\partial\bar{u}_{i,M}}{\partial\psi} (36)

If we apply the right hand side of equation 36 to equation 35, we obtain

𝔼((X¯i,y¯i))i=1m𝔼(ζ¯t(X¯i))i=1m[Lt,1(θ,λ,ψ)ψ]\displaystyle\mathbb{E}_{((\bar{X}_{i},{\bar{y}}_{i}))_{i=1}^{m}}\mathbb{E}_{({\bar{\zeta}}_{t}(\bar{X}_{i}))_{i=1}^{m}}\left[\frac{\partial L_{t,1}(\theta,\lambda,\psi)}{\partial\psi}\right] =12mi=1m𝔼((X¯i,y¯i))i=1m[¯i(hψ(X¯i))hψ(X¯i)u¯i,Mψ]\displaystyle=\frac{1}{2m}\sum_{i=1}^{m}\mathbb{E}_{((\bar{X}_{i},{\bar{y}}_{i}))_{i=1}^{m}}\left[\frac{\partial{\bar{\ell}}_{i}(h_{\psi}(\bar{X}_{i}))}{\partial h_{\psi}(\bar{X}_{i})}\frac{\partial\bar{u}_{i,M}}{\partial\psi}\right]
=12mi=1m1nj=1nj(hψ(Xj))hψ(Xj)uj,Mψ\displaystyle=\frac{1}{2m}\sum_{i=1}^{m}\frac{1}{n}\sum_{j=1}^{n}\frac{\partial\ell_{j}(h_{\psi}(X_{j}))}{\partial h_{\psi}(X_{j})}\frac{\partial u_{j,M}}{\partial\psi}
=12ni=1ni(hψ(Xi))hψ(Xi)ui,Mψ\displaystyle=\frac{1}{2n}\sum_{i=1}^{n}\frac{\partial\ell_{i}(h_{\psi}(X_{i}))}{\partial h_{\psi}(X_{i})}\frac{\partial u_{i,M}}{\partial\psi}
=L(θ,λ,ψ)ψ.\displaystyle=\frac{\partial L(\theta,\lambda,\psi)}{\partial\psi}.

With the chain rule, we get,

L(θ,λ,ψ)θ=12ni=1n(i(hψ(Xi))θ+(igλ)(fθ(Xi))θ).\displaystyle\frac{\partial L(\theta,\lambda,\psi)}{\partial\theta}=\frac{1}{2n}\sum_{i=1}^{n}\left(\frac{\partial\ell_{i}(h_{\psi}(X_{i}))}{\partial\theta}+\frac{\partial(\ell_{i}\circ g_{\lambda})(f_{\theta}(X_{i}))}{\partial\theta}\right).

Since we have already shown that

\displaystyle\frac{1}{2n}\sum_{i=1}^{n}\frac{\partial(\ell_{i}\circ g_{\lambda})(f_{\theta}(X_{i}))}{\partial\theta}=\mathbb{E}_{((\bar{X}_{i},{\bar{y}}_{i}))_{i=1}^{m}}\mathbb{E}_{({\bar{\zeta}}_{t}(\bar{X}_{i}))_{i=1}^{m}}\left[\frac{1}{2m}\sum_{i=1}^{m}\frac{\lvert\zeta_{t}(\bar{X}_{i})\rvert}{\lvert{\bar{\zeta}}_{t}(\bar{X}_{i})\rvert}\frac{\partial({\bar{\ell}}_{i}\circ g_{\lambda})(f^{{\bar{\zeta}}_{t}}_{\theta}(\bar{X}_{i}))}{\partial\theta}\right] (37)

in Theorem 3.10, it suffices to show that

12ni=1ni(hψ(Xi))θ=𝔼((X¯i,y¯i))i=1m𝔼(ζ¯t(X¯i))i=1m[12mi=1m|ζt(X¯i)||ζ¯t(X¯i)|¯i(hψζ¯t(X¯i))θ].\displaystyle\frac{1}{2n}\sum_{i=1}^{n}\frac{\partial\ell_{i}(h_{\psi}(X_{i}))}{\partial\theta}=\mathbb{E}_{((\bar{X}_{i},{\bar{y}}_{i}))_{i=1}^{m}}\mathbb{E}_{({\bar{\zeta}}_{t}(\bar{X}_{i}))_{i=1}^{m}}\left[\frac{1}{2m}\sum_{i=1}^{m}\frac{\lvert\zeta_{t}(\bar{X}_{i})\rvert}{\lvert{\bar{\zeta}}_{t}(\bar{X}_{i})\rvert}\frac{\partial{\bar{\ell}}_{i}(h^{{\bar{\zeta}}_{t}}_{\psi}(\bar{X}_{i}))}{\partial\theta}\right]. (38)

For the left hand side of equation 38, we have that

12ni=1ni(hψ(Xi))θ=12ni=1ni(hψ(Xi))hψ(Xi)max{hψ(S)Sζ(Xi)}θ=12ni=1ni(hψ(Xi))hψ(Xi)ui,Mθ=12ni=1ni(hψ(Xi))hψ(Xi)(wϕ(𝐱i,M)+b)θ=12ni=1ni(hψ(Xi))hψ(Xi)wϕ(𝐱i,M)θ.\displaystyle\begin{split}\frac{1}{2n}\sum_{i=1}^{n}\frac{\partial\ell_{i}(h_{\psi}(X_{i}))}{\partial\theta}&=\frac{1}{2n}\sum_{i=1}^{n}\frac{\partial\ell_{i}(h_{\psi}(X_{i}))}{\partial h_{\psi}(X_{i})}\frac{\partial\max\{h_{\psi}(S)\mid S\in\zeta(X_{i})\}}{\partial\theta}\\ &=\frac{1}{2n}\sum_{i=1}^{n}\frac{\partial\ell_{i}(h_{\psi}(X_{i}))}{\partial h_{\psi}(X_{i})}\frac{\partial u_{i,M}}{\partial\theta}\\ &=\frac{1}{2n}\sum_{i=1}^{n}\frac{\partial\ell_{i}(h_{\psi}(X_{i}))}{\partial h_{\psi}(X_{i})}\frac{\partial(w^{\top}\phi({\mathbf{x}}_{i,M})+b)}{\partial\theta}\\ &=\frac{1}{2n}\sum_{i=1}^{n}\frac{\partial\ell_{i}(h_{\psi}(X_{i}))}{\partial h_{\psi}(X_{i})}w^{\top}\frac{\partial\phi({\mathbf{x}}_{i,M})}{\partial\theta}.\end{split} (39)

For the right hand side of equation 38, with linearity of expectation, we obtain

𝔼((X¯i,y¯i))i=1m𝔼(ζ¯t(X¯i))i=1m[12mi=1m|ζt(X¯i)||ζ¯t(X¯i)|¯i(hψζ¯t(X¯i))θ]\displaystyle\mathbb{E}_{((\bar{X}_{i},{\bar{y}}_{i}))_{i=1}^{m}}\mathbb{E}_{({\bar{\zeta}}_{t}(\bar{X}_{i}))_{i=1}^{m}}\left[\frac{1}{2m}\sum_{i=1}^{m}\frac{\lvert\zeta_{t}(\bar{X}_{i})\rvert}{\lvert{\bar{\zeta}}_{t}(\bar{X}_{i})\rvert}\frac{\partial{\bar{\ell}}_{i}(h^{{\bar{\zeta}}_{t}}_{\psi}(\bar{X}_{i}))}{\partial\theta}\right] =12mi=1m𝔼((X¯i,y¯i))i=1m𝔼(ζ¯t(X¯i))i=1m[|ζt(X¯i)||ζ¯t(X¯i)|¯i(hψζ¯t(X¯i))θ].\displaystyle=\frac{1}{2m}\sum_{i=1}^{m}\mathbb{E}_{((\bar{X}_{i},{\bar{y}}_{i}))_{i=1}^{m}}\mathbb{E}_{({\bar{\zeta}}_{t}(\bar{X}_{i}))_{i=1}^{m}}\left[\frac{\lvert\zeta_{t}(\bar{X}_{i})\rvert}{\lvert{\bar{\zeta}}_{t}(\bar{X}_{i})\rvert}\frac{\partial{\bar{\ell}}_{i}(h^{{\bar{\zeta}}_{t}}_{\psi}(\bar{X}_{i}))}{\partial\theta}\right]. (40)

Now we expand the summand in the right hand side,

𝔼ζ¯t(X¯i)[|ζt(X¯i)||ζ¯t(X¯i)|¯i(hψζ¯t(X¯i))θ]\displaystyle\mathbb{E}_{{\bar{\zeta}}_{t}(\bar{X}_{i})}\left[\frac{\lvert\zeta_{t}(\bar{X}_{i})\rvert}{\lvert{\bar{\zeta}}_{t}(\bar{X}_{i})\rvert}\frac{\partial{\bar{\ell}}_{i}(h^{{\bar{\zeta}}_{t}}_{\psi}(\bar{X}_{i}))}{\partial\theta}\right] =𝔼ζ¯t(X¯i)[|ζt(X¯i)||ζ¯t(X¯i)|¯i(hψ(X¯i))hψ(X¯i)𝟙{u¯i,MB¯t,i}u¯i,Mθ]\displaystyle=\mathbb{E}_{{\bar{\zeta}}_{t}(\bar{X}_{i})}\left[\frac{\lvert\zeta_{t}(\bar{X}_{i})\rvert}{\lvert{\bar{\zeta}}_{t}(\bar{X}_{i})\rvert}\frac{\partial{\bar{\ell}}_{i}(h_{\psi}(\bar{X}_{i}))}{\partial h_{\psi}(\bar{X}_{i})}\mathbbm{1}\{\bar{u}_{i,M}\in\bar{B}_{t,i}\}\frac{\partial\bar{u}_{i,M}}{\partial\theta}\right]
=𝔼S¯[|ζt(X¯i)||ζ¯t(X¯i)|¯i(hψ(X¯i))hψ(X¯i)𝟙{u¯i,M=hψ(S¯)}u¯i,Mθ]\displaystyle=\mathbb{E}_{{\bar{S}}}\left[\frac{\lvert\zeta_{t}(\bar{X}_{i})\rvert}{\lvert{\bar{\zeta}}_{t}(\bar{X}_{i})\rvert}\frac{\partial{\bar{\ell}}_{i}(h_{\psi}(\bar{X}_{i}))}{\partial h_{\psi}(\bar{X}_{i})}\mathbbm{1}\{\bar{u}_{i,M}=h_{\psi}({\bar{S}})\}\frac{\partial\bar{u}_{i,M}}{\partial\theta}\right]
=|ζt(X¯i)|11|ζt(X¯i)|¯i(hψ(X¯i))hψ(X¯i)u¯i,Mθ\displaystyle=\frac{\lvert\zeta_{t}(\bar{X}_{i})\rvert}{1}\frac{1}{\lvert\zeta_{t}(\bar{X}_{i})\rvert}\frac{\partial{\bar{\ell}}_{i}(h_{\psi}(\bar{X}_{i}))}{\partial h_{\psi}(\bar{X}_{i})}\frac{\partial\bar{u}_{i,M}}{\partial\theta}
=¯i(hψ(X¯i))hψ(X¯i)wϕ(𝐱¯i,M)θ.\displaystyle=\frac{\partial{\bar{\ell}}_{i}(h_{\psi}(\bar{X}_{i}))}{\partial h_{\psi}(\bar{X}_{i})}w^{\top}\frac{\partial\phi(\bar{\mathbf{x}}_{i,M})}{\partial\theta}. (41)

Now we apply the right hand side of equation 41 to equation 40. Then we get,

𝔼((X¯i,y¯i))i=1m𝔼(ζ¯t(X¯i))i=1m[12mi=1m|ζt(X¯i)||ζ¯t(X¯i)|¯i(hψζ¯t(X¯i))θ]=12mi=1m𝔼((X¯i,y¯i))i=1m[¯i(hψ(X¯i))hψ(X¯i)wϕ(𝐱¯i,M)θ]=12ml=1m1ni=1ni(hψ(Xi))hψ(Xi)wϕ(𝐱i,M)θ=12ni=1ni(hψ(Xi))hψ(Xi)wϕ(𝐱i,M)θ=12ni=1ni(hψ(Xi))θ.\displaystyle\begin{split}\mathbb{E}_{((\bar{X}_{i},{\bar{y}}_{i}))_{i=1}^{m}}\mathbb{E}_{({\bar{\zeta}}_{t}(\bar{X}_{i}))_{i=1}^{m}}\left[\frac{1}{2m}\sum_{i=1}^{m}\frac{\lvert\zeta_{t}(\bar{X}_{i})\rvert}{\lvert{\bar{\zeta}}_{t}(\bar{X}_{i})\rvert}\frac{\partial{\bar{\ell}}_{i}(h^{{\bar{\zeta}}_{t}}_{\psi}(\bar{X}_{i}))}{\partial\theta}\right]&=\frac{1}{2m}\sum_{i=1}^{m}\mathbb{E}_{((\bar{X}_{i},{\bar{y}}_{i}))_{i=1}^{m}}\left[\frac{\partial{\bar{\ell}}_{i}(h_{\psi}(\bar{X}_{i}))}{\partial h_{\psi}(\bar{X}_{i})}w^{\top}\frac{\partial\phi(\bar{\mathbf{x}}_{i,M})}{\partial\theta}\right]\\ &=\frac{1}{2m}\sum_{l=1}^{m}\frac{1}{n}\sum_{i=1}^{n}\frac{\partial\ell_{i}(h_{\psi}(X_{i}))}{\partial h_{\psi}(X_{i})}w^{\top}\frac{\partial\phi({\mathbf{x}}_{i,M})}{\partial\theta}\\ &=\frac{1}{2n}\sum_{i=1}^{n}\frac{\partial\ell_{i}(h_{\psi}(X_{i}))}{\partial h_{\psi}(X_{i})}w^{\top}\frac{\partial\phi({\mathbf{x}}_{i,M})}{\partial\theta}\\ &=\frac{1}{2n}\sum_{i=1}^{n}\frac{\partial\ell_{i}(h_{\psi}(X_{i}))}{\partial\theta}.\end{split} (42)

Finally, combining equation 37 and equation 42, we arrive at the conclusion:

𝔼((X¯i,y¯i))i=1m𝔼(ζ¯t(X¯i))i=1m[Lt,1(θ,λ,ψ)θ]\displaystyle\mathbb{E}_{((\bar{X}_{i},{\bar{y}}_{i}))_{i=1}^{m}}\mathbb{E}_{({\bar{\zeta}}_{t}(\bar{X}_{i}))_{i=1}^{m}}\left[\frac{\partial L_{t,1}(\theta,\lambda,\psi)}{\partial\theta}\right]
=𝔼((X¯i,y¯i))i=1m𝔼(ζ¯t(X¯i))i=1m[12mi=1m|ζt(X¯i)||ζ¯t(X¯i)|(¯i(hψζ¯t(X¯i))θ+(¯igλ)(fθζ¯t(X¯i))θ)]\displaystyle=\mathbb{E}_{((\bar{X}_{i},{\bar{y}}_{i}))_{i=1}^{m}}\mathbb{E}_{({\bar{\zeta}}_{t}(\bar{X}_{i}))_{i=1}^{m}}\left[\frac{1}{2m}\sum_{i=1}^{m}\frac{\lvert\zeta_{t}(\bar{X}_{i})\rvert}{\lvert{\bar{\zeta}}_{t}(\bar{X}_{i})\rvert}\left(\frac{\partial{\bar{\ell}}_{i}(h^{{\bar{\zeta}}_{t}}_{\psi}(\bar{X}_{i}))}{\partial\theta}+\frac{\partial({\bar{\ell}}_{i}\circ g_{\lambda})(f^{{\bar{\zeta}}_{t}}_{\theta}(\bar{X}_{i}))}{\partial\theta}\right)\right]
=12ni=1ni(hψ(Xi))θ+12ni=1n(igλ)(fθ(Xi))θ\displaystyle=\frac{1}{2n}\sum_{i=1}^{n}\frac{\partial\ell_{i}(h_{\psi}(X_{i}))}{\partial\theta}+\frac{1}{2n}\sum_{i=1}^{n}\frac{\partial(\ell_{i}\circ g_{\lambda})(f_{\theta}(X_{i}))}{\partial\theta}
=L(θ,λ,ψ)θ.\displaystyle=\frac{\partial L(\theta,\lambda,\psi)}{\partial\theta}.

A.8 SSE’s Training Method Is a Biased Approximation to the Full Set Gradient

In this section, we show that sampling a single subset and computing its gradient as an approximation to the gradient of L(θ,λ)L(\theta,\lambda), as proposed by Bruno et al. (2021), yields a biased estimate of the full set gradient. Since f^θ{\hat{f}}_{\theta} with an attention activation function comprised of ν1\nu_{1} and a sigmoid for σ\sigma is equivalent to a Slot Set Encoder, and is a special case of UMBC, we focus on the gradient of f^θ{\hat{f}}_{\theta}. Specifically, at every iteration t+t\in\mathbb{N}_{+}, we sample a mini-batch ((X¯i,y¯i))i=1m((\bar{X}_{i},{\bar{y}}_{i}))_{i=1}^{m} from the training dataset ((Xi,yi))i=1n((X_{i},y_{i}))_{i=1}^{n}. We choose a partition ζt(X¯i)\zeta_{t}(\bar{X}_{i}) for each X¯i\bar{X}_{i} and sample a single subset S¯i{\bar{S}}_{i} from the partition ζt(X¯i)\zeta_{t}(\bar{X}_{i}). If we compute the gradient of the loss as

λ(1mi=1m(¯igλ)(f^θ(S¯i))),\displaystyle\frac{\partial}{\partial\lambda}\left(\frac{1}{m}\sum_{i=1}^{m}({\bar{\ell}}_{i}\circ g_{\lambda})({\hat{f}}_{\theta}({\bar{S}}_{i}))\right), (43)

then it is a biased estimation of L(θ,λ)λ\frac{\partial L(\theta,\lambda)}{\partial\lambda}, where ¯i(){\bar{\ell}}_{i}(\cdot) is defined by ¯i(q)(q,y¯i){\bar{\ell}}_{i}(q)\coloneqq\ell(q,{\bar{y}}_{i}).

Proof.

The gradient of L(θ,λ)L(\theta,\lambda) with respect to the parameter λ\lambda is

L(θ,λ)λ=1ni=1n(igλ)(f^θ(Xi))λ=1ni=1n(igλ)(Sζ(Xi)f^θ(S))λ\displaystyle\begin{split}\frac{\partial L(\theta,\lambda)}{\partial\lambda}&=\frac{1}{n}\sum_{i=1}^{n}\frac{\partial(\ell_{i}\circ g_{\lambda})({\hat{f}}_{\theta}(X_{i}))}{\partial\lambda}\\ &=\frac{1}{n}\sum_{i=1}^{n}\frac{\partial(\ell_{i}\circ g_{\lambda})\left(\sum_{S\in\zeta(X_{i})}{\hat{f}}_{\theta}(S)\right)}{\partial\lambda}\end{split} (44)

for a partition ζt(Xi)\zeta_{t}(X_{i}) of the set XiX_{i}, where i()\ell_{i}(\cdot) is defined by i(q)(q,yi)\ell_{i}(q)\coloneqq\ell(q,y_{i}). However, the expectation of equation 43 is not equal to the full set gradient in equation 44:

𝔼((X¯i,y¯i))i=1m𝔼S¯i[1mi=1m(¯igλ)(f^θ(S¯i))λ]=1mi=1m𝔼((X¯i,y¯i))i=1m𝔼S¯i[(¯igλ)(f^θ(S¯i))λ]=1ni=1n1|ζt(Xi)|Sζt(Xi)(igλ)(f^θ(S))λ1ni=1n(igλ)(Sζt(Xi)f^θ(S))λ=1ni=1n(igλ)(f^θ(Xi))λ\displaystyle\begin{split}\mathbb{E}_{((\bar{X}_{i},{\bar{y}}_{i}))^{m}_{i=1}}\mathbb{E}_{{\bar{S}}_{i}}\left[\frac{1}{m}\sum_{i=1}^{m}\frac{\partial({\bar{\ell}}_{i}\circ g_{\lambda})({\hat{f}}_{\theta}({\bar{S}}_{i}))}{\partial\lambda}\right]&=\frac{1}{m}\sum_{i=1}^{m}\mathbb{E}_{((\bar{X}_{i},{\bar{y}}_{i}))^{m}_{i=1}}\mathbb{E}_{{\bar{S}}_{i}}\left[\frac{\partial({\bar{\ell}}_{i}\circ g_{\lambda})({\hat{f}}_{\theta}({\bar{S}}_{i}))}{\partial\lambda}\right]\\ &=\frac{1}{n}\sum_{i=1}^{n}\frac{1}{|\zeta_{t}(X_{i})|}\sum_{S\in\zeta_{t}(X_{i})}\frac{\partial(\ell_{i}\circ g_{\lambda})({\hat{f}}_{\theta}(S))}{\partial\lambda}\\ &{\color[rgb]{1,0,0}\definecolor[named]{pgfstrokecolor}{rgb}{1,0,0}\neq}\frac{1}{n}\sum_{i=1}^{n}\frac{\partial(\ell_{i}\circ g_{\lambda})\left(\sum_{S\in\zeta_{t}(X_{i})}{\hat{f}}_{\theta}(S)\right)}{\partial\lambda}\\ &=\frac{1}{n}\sum_{i=1}^{n}\frac{\partial(\ell_{i}\circ g_{\lambda})(\hat{f}_{\theta}(X_{i}))}{\partial\lambda}\end{split} (45)

To see why this is the case, we analyze the case of a real-valued function gλ:dg_{\lambda}:\mathbb{R}^{d}\to\mathbb{R} with λd\lambda\in\mathbb{R}^{d} and a squared loss function

(zi,yi)12(ziyi)2\displaystyle\ell(z_{i},y_{i})\coloneqq\frac{1}{2}(z_{i}-y_{i})^{2}
ziλf^θ(Xi)gλ(f^θ(Xi)).\displaystyle z_{i}\coloneqq\lambda^{\top}{\hat{f}}_{\theta}(X_{i})\coloneqq g_{\lambda}({\hat{f}}_{\theta}(X_{i}))\in\mathbb{R}.

Since f^θ{\hat{f}}_{\theta} is sum decomposable, i.e. f^θ(Xi)=j=1Nif^θ(𝐱i,j){\hat{f}}_{\theta}(X_{i})=\sum_{j=1}^{N_{i}}{\hat{f}}_{\theta}({\mathbf{x}}_{i,j}) where Xi={𝐱i,1,,𝐱i,Ni}X_{i}=\{{\mathbf{x}}_{i,1},\ldots,{\mathbf{x}}_{i,N_{i}}\}, the full set gradient from equation 44 becomes,

1ni=1nλ(12(ziyi)2)=1ni=1n(ziyi)ziλ=1ni=1n(ziyi)f^θ(Xi)=1ni=1n(ziyi)(l=1Nif^θ(𝐱i,l))=1ni=1n(λ(j=1Nif^θ(𝐱i,j))yi)(l=1Nif^θ(𝐱i,l))=1ni=1n((j=1Niλf^θ(𝐱i,j))yi)(l=1Nif^θ(𝐱i,l)).=1ni=1n(j=1Niλf^θ(𝐱i,j)yiNi)(l=1Nif^θ(𝐱i,l)).\displaystyle\begin{split}\frac{1}{n}\sum_{i=1}^{n}\frac{\partial}{\partial\lambda}\left(\frac{1}{2}(z_{i}-y_{i})^{2}\right)&=\frac{1}{n}\sum_{i=1}^{n}(z_{i}-y_{i})\frac{\partial z_{i}}{\partial\lambda}\\ &=\frac{1}{n}\sum_{i=1}^{n}(z_{i}-y_{i}){\hat{f}}_{\theta}(X_{i})\\ &=\frac{1}{n}\sum_{i=1}^{n}(z_{i}-y_{i})\left(\sum_{l=1}^{N_{i}}{\hat{f}}_{\theta}({\mathbf{x}}_{i,l})\right)\\ &=\frac{1}{n}\sum_{i=1}^{n}\left(\lambda^{\top}\left(\sum_{j=1}^{N_{i}}{\hat{f}}_{\theta}({\mathbf{x}}_{i,j})\right)-y_{i}\right)\left(\sum_{l=1}^{N_{i}}{\hat{f}}_{\theta}({\mathbf{x}}_{i,l})\right)\\ &=\frac{1}{n}\sum_{i=1}^{n}\left(\left(\sum_{j=1}^{N_{i}}\lambda^{\top}{\hat{f}}_{\theta}({\mathbf{x}}_{i,j})\right)-y_{i}\right)\left(\sum_{l=1}^{N_{i}}{\hat{f}}_{\theta}({\mathbf{x}}_{i,l})\right).\\ &=\frac{1}{n}\sum_{i=1}^{n}\left(\sum_{j=1}^{N_{i}}\lambda^{\top}{\hat{f}}_{\theta}({\mathbf{x}}_{i,j})-\frac{y_{i}}{N_{i}}\right)\left(\sum_{l=1}^{N_{i}}{\hat{f}}_{\theta}({\mathbf{x}}_{i,l})\right).\end{split} (46)

Assume that ζt(Xi)={{𝐱i,1},,{𝐱i,Ni}}\zeta_{t}(X_{i})=\{\{{\mathbf{x}}_{i,1}\},\ldots,\{{\mathbf{x}}_{i,N_{i}}\}\} and we sample a single subset S¯i{\bar{S}}_{i} from the partition ζt(X¯i)\zeta_{t}(\bar{X}_{i}) for all i{1,,m}i\in\{1,\ldots,m\} and t+t\in\mathbb{N}_{+}. Then the gradient obtained by subsampling a single subset, as in equation 45, becomes

1ni=1n1Nij=1Ni12(λf^θ(𝐱i,j)yi)2λ=1ni=1nj=1Ni(λf^θ(𝐱i,j)yiNi)f^θ(𝐱i,j)1ni=1n(j=1Niλf^θ(𝐱i,j)yiNi)(l=1Nif^θ(𝐱i,l)).\displaystyle\begin{split}\frac{1}{n}\sum_{i=1}^{n}\frac{1}{N_{i}}\sum_{j=1}^{N_{i}}\frac{1}{2}\frac{\partial\left(\lambda^{\top}{\hat{f}}_{\theta}({\mathbf{x}}_{i,j})-y_{i}\right)^{2}}{\partial\lambda}&=\frac{1}{n}\sum_{i=1}^{n}\sum_{j=1}^{N_{i}}\left(\frac{\lambda^{\top}{\hat{f}}_{\theta}({\mathbf{x}}_{i,j})-y_{i}}{N_{i}}\right){\hat{f}}_{\theta}({\mathbf{x}}_{i,j})\\ &{\color[rgb]{1,0,0}\definecolor[named]{pgfstrokecolor}{rgb}{1,0,0}\neq}\frac{1}{n}\sum_{i=1}^{n}\left(\sum_{j=1}^{N_{i}}\lambda^{\top}{\hat{f}}_{\theta}({\mathbf{x}}_{i,j})-\frac{y_{i}}{N_{i}}\right)\left(\sum_{l=1}^{N_{i}}{\hat{f}}_{\theta}({\mathbf{x}}_{i,l})\right).\end{split} (47)

Therefore, the random subsampling of a single subset in the method proposed by Bruno et al. (2021) is not an unbiased estimate of the full set gradient. ∎
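To make the gap between equation 46 and equation 47 concrete, the following small NumPy sketch compares the two quantities for a single set (purely illustrative: the random data, the linear head, and the singleton partition are assumptions made for this example and are not part of our experiments).

import numpy as np

rng = np.random.default_rng(0)
d, N = 4, 8                         # feature dimension and set size N_i
lam = rng.normal(size=d)            # parameter lambda of the linear head g_lambda
feats = rng.normal(size=(N, d))     # rows play the role of f_hat_theta(x_{i,j})
y = 1.5                             # target y_i for this set

# full set gradient (equation 46): (lambda^T sum_j f(x_j) - y) * sum_l f(x_l)
grad_full = (lam @ feats.sum(axis=0) - y) * feats.sum(axis=0)

# expected gradient when a single singleton subset is sampled uniformly (equation 47)
grad_single = np.zeros(d)
for j in range(N):
    grad_single += (lam @ feats[j] - y) * feats[j] / N

print(np.allclose(grad_full, grad_single))  # False: the single-subset gradient is biased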

Appendix B Optimization

Define θ¯=(θ,λ)\bar{\theta}=(\theta,\lambda) and gt(θ¯)=(Lt,1(θ,λ)θ,Lt,2(θ,λ)λ)g_{t}(\bar{\theta})=(\frac{\partial L_{t,1}(\theta,\lambda)}{\partial\theta},\frac{\partial L_{t,2}(\theta,\lambda)}{\partial\lambda})^{\top}. We assume that (q,y)0\ell(q,y)\geq 0 for all (q,y)(q,y). Let (θ¯t)t+(\bar{\theta}_{t})_{t\in\mathbb{N}_{+}} be a sequence generated by θ¯t+1=θ¯tηtgt(θ¯)\bar{\theta}_{t+1}=\bar{\theta}_{t}-\eta_{t}g_{t}(\bar{\theta}) with an initial point θ¯1\bar{\theta}_{1} and a step size sequence (ηt)t+(\eta_{t})_{t\in\mathbb{N}_{+}}, where θ¯tD\bar{\theta}_{t}\in\mathcal{R}\subseteq\mathbb{R}^{D} for t+t\in\mathbb{N}_{+} with an open convex set \mathcal{R}. Here, D\mathbb{R}^{D} is an open set and thus it is allowed to choose =D\mathcal{R}=\mathbb{R}^{D} (or any other open convex set). We do not assume that the loss function or model is convex. We also do not make any assumption on the initial point θ¯1\bar{\theta}_{1}. To analyze the optimization behavior formally, we consider the following standard assumption in the literature (Lee et al., 2016; Mertikopoulos et al., 2020; Fehrman et al., 2020):

Assumption B.1.

There exists ς>0\varsigma>0 such that for any θ¯,θ¯\bar{\theta},\bar{\theta}^{\prime}\in\mathcal{R}, t[T]t\in[T], and k{1,2}k\in\{1,2\},

Lt,k(θ¯)Lt,k(θ¯)2ςθ¯θ¯2.\displaystyle\|\nabla L_{t,k}(\bar{\theta})-\nabla L_{t,k}(\bar{\theta}^{\prime})\|_{2}\leq\varsigma\|\bar{\theta}-\bar{\theta}^{\prime}\|_{2}.

We use the following lemma on a general function from a previous work (Kawaguchi et al., 2022, Lemma 2):

Lemma B.2.

For any differentiable function φ:dom(φ)\varphi:\operatorname{dom}(\varphi)\rightarrow\mathbb{R} with an open convex set dom(φ)nφ\operatorname{dom}(\varphi)\subseteq\mathbb{R}^{n_{\varphi}}, if φ(z)φ(z)2ςφzz2\|\nabla\varphi(z^{\prime})-\nabla\varphi(z)\|_{2}\leq\varsigma_{\varphi}\|z^{\prime}-z\|_{2} for all z,zdom(φ)z,z^{\prime}\in\operatorname{dom}(\varphi), then

φ(z)φ(z)+φ(z)(zz)+ςφ2zz22for all z,zdom(φ).\displaystyle\varphi(z^{\prime})\leq\varphi(z)+\nabla\varphi(z)^{\top}(z^{\prime}-z)+\frac{\varsigma_{\varphi}}{2}\|z^{\prime}-z\|^{2}_{2}\quad\text{for all $z,z^{\prime}\in\operatorname{dom}(\varphi)$}. (48)

In turn, Lemma B.2 implies the following lemma:

Lemma B.3.

For any differentiable function φ:dom(φ)0\varphi:\operatorname{dom}(\varphi)\rightarrow\mathbb{R}_{\geq 0} with an open convex set dom(φ)nφ\operatorname{dom}(\varphi)\subseteq\mathbb{R}^{n_{\varphi}} such that φ(z)φ(z)ςφzz\|\nabla\varphi(z^{\prime})-\nabla\varphi(z)\|\leq\varsigma_{\varphi}\|z^{\prime}-z\| for all z,zdom(φ)z,z^{\prime}\in\operatorname{dom}(\varphi), the following holds: for all zdom(φ)z\in\operatorname{dom}(\varphi) such that z1ςφφ(z)dom(φ)z-\frac{1}{\varsigma_{\varphi}}\nabla\varphi(z)\in\operatorname{dom}(\varphi),

φ(z)222ςφφ(z)for all zdom(φ).\displaystyle\|\nabla\varphi(z)\|_{2}^{2}\leq 2\varsigma_{\varphi}\varphi(z)\quad\text{for all $z\in\operatorname{dom}(\varphi)$}. (49)
Proof.

Since φ:dom(φ)0\varphi:\operatorname{dom}(\varphi)\rightarrow\mathbb{R}_{\geq 0} (nonnegative), if φ(z)=0\nabla\varphi(z)=0, the desired statement holds. Thus, we consider the remaining case of φ(z)0\nabla\varphi(z)\neq 0 in the rest of this proof. We invoke Lemma B.2 with z=z1ςφφ(z)z^{\prime}=z-\frac{1}{\varsigma_{\varphi}}\nabla\varphi(z), yielding

0φ(z)\displaystyle 0\leq\varphi(z^{\prime}) φ(z)+φ(z)(zz)+ςφ2zz22\displaystyle\leq\varphi(z)+\nabla\varphi(z)^{\top}(z^{\prime}-z)+\frac{\varsigma_{\varphi}}{2}\|z^{\prime}-z\|^{2}_{2}
=φ(z)1ςφφ(z)22+12ςφφ(z)22\displaystyle=\varphi(z)-\frac{1}{\varsigma_{\varphi}}\|\nabla\varphi(z)\|_{2}^{2}+\frac{1}{2\varsigma_{\varphi}}\|\nabla\varphi(z)\|^{2}_{2}
=φ(z)12ςφφ(z)22\displaystyle=\varphi(z)-\frac{1}{2\varsigma_{\varphi}}\|\nabla\varphi(z)\|_{2}^{2}

By rearranging, this implies that φ(z)222ςφφ(z)\|\nabla\varphi(z)\|_{2}^{2}\leq 2\varsigma_{\varphi}\varphi(z). ∎

Since we are dealing with a general non-convex and non-invex function (as the choice of architecture and loss is very flexible), for which gradient-based optimization can in general only be expected to converge to a stationary point without incurring the curse of dimensionality, we consider convergence in terms of stationary points of LL:

Theorem B.4.

Suppose that Assumption B.1 holds and the step size sequence (ηt)t+(\eta_{t})_{t\in\mathbb{N}_{+}} satisfies t=1ηt2<\sum_{t=1}^{\infty}\eta_{t}^{2}<\infty. Then, there exists a constant cc independent of (t,T)(t,T) such that

mint[T]𝔼[L(θ¯t)22]c𝔼[L(θ¯1)]t=1Tηt.\min_{t\in[T]}\mathbb{E}[\|\nabla L(\bar{\theta}_{t})\|_{2}^{2}]\leq\frac{c\mathbb{E}[L(\bar{\theta}_{1})]}{\sum_{t=1}^{T}\eta_{t}}.
Proof.

Assumption B.1 implies that

Lt,k(θ¯)Lt,k(θ¯)22ς2θ¯θ¯22\|\nabla L_{t,k}(\bar{\theta})-\nabla L_{t,k}(\bar{\theta}^{\prime})\|_{2}^{2}\leq\varsigma^{2}\|\bar{\theta}-\bar{\theta}^{\prime}\|_{2}^{2}

which implies that

gt(θ¯)gt(θ¯)22=Lt,1(θ¯)Lt,1(θ¯)22+Lt,2(θ¯)Lt,2(θ¯)222ς2θ¯θ¯22.\|g_{t}(\bar{\theta})-g_{t}(\bar{\theta}^{\prime})\|_{2}^{2}=\|\nabla L_{t,1}(\bar{\theta})-\nabla L_{t,1}(\bar{\theta}^{\prime})\|_{2}^{2}+\|\nabla L_{t,2}(\bar{\theta})-\nabla L_{t,2}(\bar{\theta}^{\prime})\|_{2}^{2}\leq 2\varsigma^{2}\|\bar{\theta}-\bar{\theta}^{\prime}\|_{2}^{2}.

This implies that

gt(θ¯)gt(θ¯)22ςθ¯θ¯2\|g_{t}(\bar{\theta})-g_{t}(\bar{\theta}^{\prime})\|_{2}\leq\sqrt{2}\varsigma\|\bar{\theta}-\bar{\theta}^{\prime}\|_{2}

Using this and Theorem 3.10 along with Jensen’s inequality, we have that for any θ¯,θ¯\bar{\theta},\bar{\theta}^{\prime}\in\mathcal{R},

L(θ¯)L(θ¯)2=𝔼[gt(θ¯)]𝔼[gt(θ¯)]2𝔼[gt(θ¯)gt(θ¯)2]2ςθ¯θ¯2.\displaystyle\|\nabla L(\bar{\theta})-\nabla L(\bar{\theta}^{\prime})\|_{2}=\|\mathbb{E}[g_{t}(\bar{\theta})]-\mathbb{E}[g_{t}(\bar{\theta}^{\prime})]\|_{2}\leq\mathbb{E}[\|g_{t}(\bar{\theta})-g_{t}(\bar{\theta}^{\prime})\|_{2}]\leq\sqrt{2}\varsigma\|\bar{\theta}-\bar{\theta}^{\prime}\|_{2}.

Thus, LL satisfies the conditions of Lemma B.2 and Lemma B.3. Since θ¯tD\bar{\theta}_{t}\in\mathcal{R}\subseteq\mathbb{R}^{D} for t+t\in\mathbb{N}_{+}, using Lemma B.2 for the function LL, we have that

L(θ¯t+1)L(θ¯t)+L(θ¯t)(θ¯t+1θ¯t)+2ςθ¯t+1θ¯t222.\displaystyle L(\bar{\theta}_{t+1})\leq L(\bar{\theta}_{t})+\nabla L(\bar{\theta}_{t})^{\top}(\bar{\theta}_{t+1}-\bar{\theta}_{t})+\frac{\sqrt{2}\varsigma\|\bar{\theta}_{t+1}-\bar{\theta}_{t}\|_{2}^{2}}{2}.

Using θ¯t+1θ¯t=ηtgt(θ¯t)\bar{\theta}_{t+1}-\bar{\theta}_{t}=-\eta_{t}g_{t}(\bar{\theta}_{t}),

L(θ¯t+1)L(θ¯t)ηtL(θ¯t)gt(θ¯t)+2ςηt2gt(θ¯t)222.L(\bar{\theta}_{t+1})\leq L(\bar{\theta}_{t})-\eta_{t}\nabla L(\bar{\theta}_{t})^{\top}g_{t}(\bar{\theta}_{t})+\frac{\sqrt{2}\varsigma\eta_{t}^{2}\|g_{t}(\bar{\theta}_{t})\|_{2}^{2}}{2}.

Using Lemma B.3 for gt(θ¯t)22=Lt,1(θt,λt)θt22+Lt,2(θt,λt)λt22\|g_{t}(\bar{\theta}_{t})\|_{2}^{2}=\|\frac{\partial L_{t,1}(\theta_{t},\lambda_{t})}{\partial\theta_{t}}\|_{2}^{2}+\|\frac{\partial L_{t,2}(\theta_{t},\lambda_{t})}{\partial\lambda_{t}}\|_{2}^{2}, we have that

L(θ¯t+1)L(θ¯t)ηtL(θ¯t)gt(θ¯)+2ς2ηt2(Lt,1(θ¯t)+Lt,2(θ¯t)).L(\bar{\theta}_{t+1})\leq L(\bar{\theta}_{t})-\eta_{t}\nabla L(\bar{\theta}_{t})^{\top}g_{t}(\bar{\theta})+\sqrt{2}\varsigma^{2}\eta_{t}^{2}(L_{t,1}(\bar{\theta}_{t})+L_{t,2}(\bar{\theta}_{t})).

Define Lt(θ¯t)=Lt,1(θ¯t)+Lt,2(θ¯t)L_{t}(\bar{\theta}_{t})=L_{t,1}(\bar{\theta}_{t})+L_{t,2}(\bar{\theta}_{t}). Using the linearity and monotonicity of expectation,

𝔼((X¯i,y¯i))i=1m𝔼(ζ¯t(X¯i))i=1m[L(θ¯t+1)|θ¯t]\displaystyle\mathbb{E}_{((\bar{X}_{i},{\bar{y}}_{i}))^{m}_{i=1}}\mathbb{E}_{({\bar{\zeta}}_{t}(\bar{X}_{i}))^{m}_{i=1}}[L(\bar{\theta}_{t+1})|\bar{\theta}_{t}]
L(θ¯t)ηtL(θ¯t)𝔼((X¯i,y¯i))i=1m𝔼(ζ¯t(X¯i))i=1m[gt(θ¯)]+2ς2ηt2𝔼((X¯i,y¯i))i=1m𝔼(ζ¯t(X¯i))i=1m[Lt(θ¯t)]\displaystyle\leq L(\bar{\theta}_{t})-\eta_{t}\nabla L(\bar{\theta}_{t})^{\top}\mathbb{E}_{((\bar{X}_{i},{\bar{y}}_{i}))^{m}_{i=1}}\mathbb{E}_{({\bar{\zeta}}_{t}(\bar{X}_{i}))^{m}_{i=1}}[g_{t}(\bar{\theta})]+\sqrt{2}\varsigma^{2}\eta_{t}^{2}\mathbb{E}_{((\bar{X}_{i},{\bar{y}}_{i}))^{m}_{i=1}}\mathbb{E}_{({\bar{\zeta}}_{t}(\bar{X}_{i}))^{m}_{i=1}}[L_{t}(\bar{\theta}_{t})]
L(θ¯t)ηtL(θ¯t)22+2ς2ηt2(1+a)L(θ¯t)\displaystyle\leq L(\bar{\theta}_{t})-\eta_{t}\|\nabla L(\bar{\theta}_{t})\|_{2}^{2}+\sqrt{2}\varsigma^{2}\eta_{t}^{2}(1+a)L(\bar{\theta}_{t})

where the second inequality follows from Theorem 3.10 and 𝔼((X¯i,y¯i))i=1m𝔼(ζ¯t(X¯i))i=1m[Lt(θ¯t)]=𝔼((X¯i,y¯i))i=1m𝔼(ζ¯t(X¯i))i=1m[Lt,1(θ¯t)]+𝔼((X¯i,y¯i))i=1m𝔼(ζ¯t(X¯i))i=1m[Lt,2(θ¯t)](1+a)L(θ¯t)\mathbb{E}_{((\bar{X}_{i},{\bar{y}}_{i}))^{m}_{i=1}}\mathbb{E}_{({\bar{\zeta}}_{t}(\bar{X}_{i}))^{m}_{i=1}}[L_{t}(\bar{\theta}_{t})]=\mathbb{E}_{((\bar{X}_{i},{\bar{y}}_{i}))^{m}_{i=1}}\mathbb{E}_{({\bar{\zeta}}_{t}(\bar{X}_{i}))^{m}_{i=1}}[L_{t,1}(\bar{\theta}_{t})]+\mathbb{E}_{((\bar{X}_{i},{\bar{y}}_{i}))^{m}_{i=1}}\mathbb{E}_{({\bar{\zeta}}_{t}(\bar{X}_{i}))^{m}_{i=1}}[L_{t,2}(\bar{\theta}_{t})]\leq(1+a)L(\bar{\theta}_{t}) where aa is the expectation of the maximum ratio |ζt(X¯i)||ζ¯t(X¯i)|\frac{|\zeta_{t}(\bar{X}_{i})|}{|{\bar{\zeta}}_{t}(\bar{X}_{i})|}.

Taking expectation over θ¯t\bar{\theta}_{t} with the law of total expectation 𝔼[L(θ¯t+1)]=𝔼θ¯t𝔼((X¯i,y¯i))i=1m𝔼(ζ¯t(X¯i))i=1m[L(θ¯t+1)|θ¯t]\mathbb{E}[L(\bar{\theta}_{t+1})]=\mathbb{E}_{\bar{\theta}_{t}}\mathbb{E}_{((\bar{X}_{i},{\bar{y}}_{i}))^{m}_{i=1}}\mathbb{E}_{({\bar{\zeta}}_{t}(\bar{X}_{i}))^{m}_{i=1}}[L(\bar{\theta}_{t+1})|\bar{\theta}_{t}],

𝔼[L(θ¯t+1)]𝔼[L(θ¯t)]ηt𝔼[L(θ¯t)22]+2ς2ηt2(1+a)𝔼[L(θ¯t)].\displaystyle\mathbb{E}[L(\bar{\theta}_{t+1})]\leq\mathbb{E}[L(\bar{\theta}_{t})]-\eta_{t}\mathbb{E}[\|\nabla L(\bar{\theta}_{t})\|_{2}^{2}]+\sqrt{2}\varsigma^{2}\eta_{t}^{2}(1+a)\mathbb{E}[L(\bar{\theta}_{t})]. (50)

Since 𝔼[L(θ¯t)22]0\mathbb{E}[\|\nabla L(\bar{\theta}_{t})\|_{2}^{2}]\geq 0, this implies that

𝔼[L(θ¯t+1)]\displaystyle\mathbb{E}[L(\bar{\theta}_{t+1})] 𝔼[L(θ¯t)]+2ς2ηt2(1+a)𝔼[L(θ¯t)]\displaystyle\leq\mathbb{E}[L(\bar{\theta}_{t})]+\sqrt{2}\varsigma^{2}\eta_{t}^{2}(1+a)\mathbb{E}[L(\bar{\theta}_{t})]
=(1+2ς2ηt2(1+a))𝔼[L(θ¯t)]\displaystyle=(1+\sqrt{2}\varsigma^{2}\eta_{t}^{2}(1+a))\mathbb{E}[L(\bar{\theta}_{t})]
exp(2ς2ηt2(1+a))𝔼[L(θ¯t)],\displaystyle\leq\exp(\sqrt{2}\varsigma^{2}\eta_{t}^{2}(1+a))\mathbb{E}[L(\bar{\theta}_{t})],

where the last inequality follows from 1+qexp(q)1+q\leq\exp(q) for all qq\in\mathbb{R}. Applying this inequality recursively over tt, it holds that for any t+t\in\mathbb{N}_{+},

𝔼[L(θ¯t)]exp(2ς2(1+a)j=1t1ηj2)𝔼[L(θ¯1)].\mathbb{E}[L(\bar{\theta}_{t})]\leq\exp\left(\sqrt{2}\varsigma^{2}(1+a)\sum_{j=1}^{t-1}\eta_{j}^{2}\right)\mathbb{E}[L(\bar{\theta}_{1})].

Using this inequality in equation 50,

𝔼[L(θ¯t+1)]𝔼[L(θ¯t)]ηt𝔼[L(θ¯t)22]+2ς2ηt2(1+a)exp(2ς2(1+a)j=1t1ηj2)𝔼[L(θ¯1)].\mathbb{E}[L(\bar{\theta}_{t+1})]\leq\mathbb{E}[L(\bar{\theta}_{t})]-\eta_{t}\mathbb{E}[\|\nabla L(\bar{\theta}_{t})\|_{2}^{2}]+\sqrt{2}\varsigma^{2}\eta_{t}^{2}(1+a)\exp\left(\sqrt{2}\varsigma^{2}(1+a)\sum_{j=1}^{t-1}\eta_{j}^{2}\right)\mathbb{E}[L(\bar{\theta}_{1})].

Rearranging and summing over tt,

t=1Tηt𝔼[L(θ¯t)22]\displaystyle\sum_{t=1}^{T}\eta_{t}\mathbb{E}[\|\nabla L(\bar{\theta}_{t})\|_{2}^{2}] t=1T(𝔼[L(θ¯t)]𝔼[L(θ¯t+1)])+2ς2(1+a)𝔼[L(θ¯1)]t=1Tηt2exp(2ς2(1+a)j=1t1ηj2)\displaystyle\leq\sum_{t=1}^{T}(\mathbb{E}[L(\bar{\theta}_{t})]-\mathbb{E}[L(\bar{\theta}_{t+1})])+\sqrt{2}\varsigma^{2}(1+a)\mathbb{E}[L(\bar{\theta}_{1})]\sum_{t=1}^{T}\eta_{t}^{2}\exp\left(\sqrt{2}\varsigma^{2}(1+a)\sum_{j=1}^{t-1}\eta_{j}^{2}\right)
𝔼[L(θ¯1)]𝔼[L(θ¯T+1)]+2ς2(1+a)𝔼[L(θ¯1)]exp(2ς2(1+a)j=1T1ηj2)(t=1Tηt2)\displaystyle\leq\mathbb{E}[L(\bar{\theta}_{1})]-\mathbb{E}[L(\bar{\theta}_{T+1})]+\sqrt{2}\varsigma^{2}(1+a)\mathbb{E}[L(\bar{\theta}_{1})]\exp\left(\sqrt{2}\varsigma^{2}(1+a)\sum_{j=1}^{T-1}\eta_{j}^{2}\right)\left(\sum_{t=1}^{T}\eta_{t}^{2}\right)
=(1+2ς2(1+a)RTexp(2ς2(1+a)RT1))𝔼[L(θ¯1)]𝔼[L(θ¯T+1)]\displaystyle=(1+\sqrt{2}\varsigma^{2}(1+a)R_{T}\exp(\sqrt{2}\varsigma^{2}(1+a)R_{T-1}))\mathbb{E}[L(\bar{\theta}_{1})]-\mathbb{E}[L(\bar{\theta}_{T+1})]

where we define RT=t=1Tηt2R_{T}=\sum_{t=1}^{T}\eta_{t}^{2}. Since ηt>0\eta_{t}>0 and 𝔼[L(θ¯t)22]0\mathbb{E}[\left\lVert\nabla L(\bar{\theta}_{t})\right\rVert^{2}_{2}]\geq 0 for all t[T]{1,,T}t\in[T]\coloneqq\{1,\ldots,T\},

mint[T]𝔼[L(θ¯t)22](t=1Tηt)\displaystyle\min_{t\in[T]}\mathbb{E}[\left\lVert\nabla L(\bar{\theta}_{t})\right\rVert^{2}_{2}]\left(\sum_{t=1}^{T}\eta_{t}\right) =t=1Tηtmint[T]𝔼[L(θ¯t)22]\displaystyle=\sum_{t=1}^{T}\eta_{t}\min_{t^{\prime}\in[T]}\mathbb{E}[\left\lVert\nabla L(\bar{\theta}_{t^{\prime}})\right\rVert^{2}_{2}]
t=1Tηt𝔼[L(θ¯t)22].\displaystyle\leq\sum_{t=1}^{T}\eta_{t}\mathbb{E}[\left\lVert\nabla L(\bar{\theta}_{t})\right\rVert_{2}^{2}].

This implies that

mint[T]𝔼[L(θ¯t)22](t=1Tηt)1((1+2ς2(1+a)RTexp(2ς2(1+a)RT1))𝔼[L(θ¯1)]𝔼[L(θ¯T+1)]).\min_{t\in[T]}\mathbb{E}[\|\nabla L(\bar{\theta}_{t})\|_{2}^{2}]\leq\left(\sum_{t=1}^{T}\eta_{t}\right)^{-1}\left((1+\sqrt{2}\varsigma^{2}(1+a)R_{T}\exp(\sqrt{2}\varsigma^{2}(1+a)R_{T-1}))\mathbb{E}[L(\bar{\theta}_{1})]-\mathbb{E}[L(\bar{\theta}_{T+1})]\right).

Since RT1RTt=1ηt2<R_{T-1}\leq R_{T}\leq\sum_{t=1}^{\infty}\eta_{t}^{2}<\infty and 𝔼[L(θ¯T+1)]0\mathbb{E}[L(\bar{\theta}_{T+1})]\geq 0, this implies that there exists a constant cc independent of (t,T)(t,T) such that

mint[T]𝔼[L(θ¯t)22]c𝔼[L(θ¯1)](t=1Tηt)1.\min_{t\in[T]}\mathbb{E}[\|\nabla L(\bar{\theta}_{t})\|_{2}^{2}]\leq c\mathbb{E}[L(\bar{\theta}_{1})]\left(\sum_{t=1}^{T}\eta_{t}\right)^{-1}.

For example, if ηt=η1tq\eta_{t}=\eta_{1}t^{-q} with q(0.5,1)q\in(0.5,1) and η1>0\eta_{1}>0, then we have mint[T]𝔼[L(θ¯t)22]=𝒪(1T1q)\min_{t\in[T]}\mathbb{E}[\|\nabla L(\bar{\theta}_{t})\|_{2}^{2}]=\mathcal{O}(\frac{1}{T^{1-q}}).
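As a brief sanity check of this rate (an illustrative calculation under the stated step size choice, not an additional result), the condition of Theorem B.4 holds since 2q>1, while the denominator grows polynomially:

\sum_{t=1}^{\infty}\eta_{t}^{2}=\eta_{1}^{2}\sum_{t=1}^{\infty}t^{-2q}<\infty,\qquad\sum_{t=1}^{T}\eta_{t}=\eta_{1}\sum_{t=1}^{T}t^{-q}\geq\eta_{1}\int_{1}^{T+1}s^{-q}\,ds=\frac{\eta_{1}\left((T+1)^{1-q}-1\right)}{1-q},

so that \min_{t\in[T]}\mathbb{E}[\|\nabla L(\bar{\theta}_{t})\|_{2}^{2}]\leq c\,\mathbb{E}[L(\bar{\theta}_{1})]\left(\sum_{t=1}^{T}\eta_{t}\right)^{-1}=\mathcal{O}\left(\frac{1}{T^{1-q}}\right).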

Appendix C Details on the Mixture of Gaussians Amortized Clustering Experiment

We used a modified version of the MoG amortized clustering dataset of Lee et al. (2019), adding separate, randomly sampled covariance parameters to the generative procedure in order to create a more difficult dataset. Specifically, to sample a single task for a problem with KK classes (a short NumPy sketch of this procedure is given after the list):

  1. Sample the set size for the batch NU(train set size/2,train set size)N\sim U(\text{train set size}/2,\text{train set size}).

  2. Sample class priors πDirichlet([α1,,αK])\pi\sim\text{Dirichlet}([\alpha_{1},\ldots,\alpha_{K}]) with α1==αK=1\alpha_{1}=\cdots=\alpha_{K}=1.

  3. Sample class labels ziCategorical(π)z_{i}\sim\text{Categorical}(\pi) for i=1,,Ni=1,...,N.

  4. Generate cluster centers 𝝁i=(μi,1,μi,2)2\bm{\mu}_{i}=({\mu}_{i,1},{\mu}_{i,2})\in\mathbb{R}^{2}, where μi,jU(4,4){\mu}_{i,j}\sim U(-4,4) for i=1,,Ki=1,...,K and j=1,2j=1,2.

  5. Generate cluster covariance matrices 𝚺i=diag(σi,1,σi,2)2×2\bm{\Sigma}_{i}=\text{diag}({\sigma}_{i,1},{\sigma}_{i,2})\in\mathbb{R}^{2\times 2}, where σi,jU(0.3,0.6){\sigma}_{i,j}\sim U(0.3,0.6) for i=1,,Ki=1,...,K and j=1,2j=1,2.

  6. For all znz_{n}, if zn=iz_{n}=i, sample data 𝐱n𝒩(𝝁i,𝚺i)\mathbf{x}_{n}\sim\mathcal{N}(\bm{\mu}_{i},\bm{\Sigma}_{i}).
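As referenced above, a minimal NumPy sketch of this sampling procedure is given below (the function name, the default train set size, and the use of an integer-valued set size are illustrative assumptions rather than our exact data pipeline).

import numpy as np

def sample_mog_task(K=4, train_set_size=1024, rng=np.random.default_rng(0)):
    # 1. sample the set size for the batch
    N = int(rng.integers(train_set_size // 2, train_set_size + 1))
    # 2. sample class priors from a flat Dirichlet
    pi = rng.dirichlet(np.ones(K))
    # 3. sample class labels
    z = rng.choice(K, size=N, p=pi)
    # 4. sample cluster centers uniformly in [-4, 4]^2
    mu = rng.uniform(-4.0, 4.0, size=(K, 2))
    # 5. sample diagonal covariance entries uniformly in [0.3, 0.6]
    sigma = rng.uniform(0.3, 0.6, size=(K, 2))
    # 6. sample each point from the Gaussian of its assigned cluster
    x = mu[z] + np.sqrt(sigma[z]) * rng.normal(size=(N, 2))
    return x, z, mu, sigma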

In our MoG experiments, we set K=4. The motivational example in Figure 2 also uses the MoG dataset, and performs MBC testing of the Set Transformer following the procedure outlined in Appendix E.

Figure 8: Panels (a)-(c) Set Transformer, panels (d)-(f) UMBC. Top Row: Set Transformer provides inconsistent predictions on streaming sets when inputs cannot be stored directly. Bottom Row: UMBC+Set Transformer gives consistent predictions in all streaming settings.

C.1 Streaming Settings

The four streaming settings used in Figures 2 and 8 are described below (a small partitioning sketch follows the list):

  • single point stream \rightarrow streams each point in the set one by one. This causes the most severe under-performance by non-MBC models.

  • class stream \rightarrow streams an entire class at once. Models which make complex pairwise comparisons cannot compare the input class with any other clusters, thereby degrading performance of models such as the Set Transformer.

  • chunk stream \rightarrow streams 8 random points at a time from the dataset, providing random and limited information to non-MBC models.

  • one each stream \rightarrow streams a set consisting of a single instance from each class. Non-MBC models can see an example of each class, but with such a limited sample size, non-MBC models such as the Set Transformer fail to make accurate predictions.
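As mentioned above, the sketch below illustrates how these settings partition a labeled set into streamed chunks (the array shapes, the fixed chunk size of 8, and the function names are illustrative assumptions; they are not taken from our codebase).

import numpy as np

def single_point_stream(x):
    # one element at a time
    return [x[i:i + 1] for i in range(len(x))]

def class_stream(x, z):
    # one entire class at a time
    return [x[z == c] for c in np.unique(z)]

def chunk_stream(x, size=8, rng=np.random.default_rng(0)):
    # 8 random points at a time
    idx = rng.permutation(len(x))
    return [x[idx[i:i + size]] for i in range(0, len(x), size)]

def one_each_stream(x, z, rng=np.random.default_rng(0)):
    # one instance from each class at a time
    per_class = {c: rng.permutation(np.flatnonzero(z == c)) for c in np.unique(z)}
    rounds = min(len(v) for v in per_class.values())
    return [x[[per_class[c][r] for c in sorted(per_class)]] for r in range(rounds)]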

C.2 Experimental Setup

We train each model for 50 epochs, with each epoch containing 1000 iterations. We use the Adam optimizer with a learning rate of 1\cdot 10^{-3} and no weight decay. We do not perform early stopping. We make a single learning rate adjustment at epoch 35, which reduces the learning rate to 10^{-4}. When reporting results, we measure the NLL of the full set of 1024 points. Unless otherwise specified, UMBC models use the softmax activation function. We list the architectures in Tables 6, 7 and 8. All models have an additional linear output layer which outputs K\times 5 parameters for the Gaussian mixture outlined in Equation 14.

Table 6: MVN generic model (Used by all encoders)
Output Size Layers Amount
Ni×2 Input Set ×1
Ni×128 Linear(2, 128), ReLU ×1
Ni×128 Set Encoder ×1
K×5 Decoder ×3
Table 7: Set Encoder specific settings for baseline models.
Name Set Encoder Output Size Set Decoder Output Size
DeepSets (Zaheer et al., 2017) Mean Pooling 128 Linear, ReLU 128
SSE (Bruno et al., 2021) Slot Set Encoder K×128 Linear, ReLU 128
FSPool (Zhang et al., 2020) Featurewise Sort Pooling 128 Linear, ReLU 128
Diff. EM. (Kim, 2022) Expectation Maximization Layer 1286 Linear, ReLU 128
Set Transformer (Lee et al., 2019) Pooling by Multihead Attention K×128 Set Attention Block K×128
Table 8: Set Encoder specific settings for UMBC models. UMBC models account for the extra encoder by using fewer layers in the decoder (2 layers instead of 3).
Name MBC Set Encoder Output Size non-MBC Set Encoder Output Size Set Decoder Output Size
(Ours) UMBC+FSPool UMBC Layer K×128 Featurewise Sort Pooling 128 Linear, ReLU 128
(Ours) UMBC+Diff EM UMBC Layer K×128 Expectation Maximization Layer 1286 Linear, ReLU 128
(Ours) UMBC+Set Transformer UMBC Layer K×128 Set Attention Block K×128 Linear, ReLU K×128

Appendix D Measuring the Variance of Pooled Features

In Figure 3, we show the quantitative effect of random set partitioning on the pooled representation for the plain Set Transformer, UMBC+Set Transformer, FSPool, and DiffEM. The UMBC model always shows 0 variance, while the non-MBC models produce variance between aggregated encodings of random partitions. For a single chunk, however, non-MBC models show no variance, as random partitions of a single chunk are equivalent to permuting the elements within the chunk (i.e. non-MBC models still produce an encoding which is permutation invariant). The variance increases drastically when the set is partitioned into two chunks, and the behavior then differs between the non-MBC models. The Set Transformer happens to show decreasing variance as the number of chunks increases. Note that as the number of chunks increases, the cardinality of each chunk decreases. Therefore, the variance decreases as the chunk cardinality decreases, but this does not indicate that the model is performing better. For example, in Figure 2, when a singleton set is input to the Set Transformer, the predictions become almost meaningless even though they may have lower variance. The procedure for aggregating the encodings of set partitions for the non-MBC models is outlined in Appendix E.

Figure 9: Distributions used in sampling random inputs for the encoding variance experiment in Figure 3
Distribution Dimension Number of Points
Normal(0, 1) 128 256
Uniform(-3, 3) 128 256
Exponential(1) 128 256
Cauchy(0, 1) 128 256
Figure 10: The number of chunks and elements per chunk.
Number of Chunks 1 2 4 8 16 32
Elements per Chunk 1024 512 256 128 64 32

To perform this experiment, we used a randomly initialized model with 128 hidden units, and sampled 256 set elements from each of four different distributions in order to make a total set size of 1024. We then created 100 random partitions for various chunk sizes. The distributions and chunk sizes are shown in Figures 9 and 10. We then encode the whole set in chunks and report the observed variance over the 100 different random partitions at each of the various chunk sizes (Figure 3). Note that the encoded set representation is a vector while Figure 3 shows a scalar value. To obtain this scalar, we take the feature-wise variance over the 100 encodings and report the mean and standard deviation over the feature dimension. Specifically, given 𝐙=[𝐳1𝐳100]100×128\mathbf{Z}=[\mathbf{z}_{1}\cdots\mathbf{z}_{100}]^{\top}\in\mathbb{R}^{100\times 128} representing all 100 encodings with 𝐳i=(zi,1,,zi,128)\mathbf{z}_{i}=(z_{i,1},\ldots,z_{i,128}), we compute the feature-wise variance as

z^j=1(1001)i=1100(zi,jμj)2,μj=1100i=1100zi,j\displaystyle\hat{z}_{j}=\frac{1}{(100-1)}\sum_{i=1}^{100}(z_{i,j}-\mu_{j})^{2},\quad\mu_{j}=\frac{1}{100}\sum_{i=1}^{100}z_{i,j}

for j=1,,128j=1,\ldots,128. We then obtain the y-axis values and error bars in Figure 3 by taking the mean and standard deviation over the feature dimension,

y=1128iz^i,yσ=1(1281)i(z^iy)2.\displaystyle y=\frac{1}{128}\sum_{i}\hat{z}_{i},\quad\quad y_{\sigma}=\sqrt{\frac{1}{(128-1)}\sum_{i}(\hat{z}_{i}-y)^{2}}. (51)
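The measurement above can be sketched as follows (a NumPy sketch under assumed shapes; encode_chunks stands in for any of the pooled set encoders combined with the chunk aggregation of Appendix E and is not a function from our codebase).

import numpy as np

def partition_variance(encode_chunks, X, n_chunks, n_partitions=100,
                       rng=np.random.default_rng(0)):
    # encode the full set under many random partitions into n_chunks chunks
    encodings = []
    for _ in range(n_partitions):
        idx = rng.permutation(X.shape[0])
        chunks = np.array_split(X[idx], n_chunks)
        encodings.append(encode_chunks(chunks))  # aggregated encoding of shape (d,)
    Z = np.stack(encodings)                      # (n_partitions, d)
    z_hat = Z.var(axis=0, ddof=1)                # feature-wise variance (equation above)
    return z_hat.mean(), z_hat.std(ddof=1)       # y and y_sigma of equation 51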

Appendix E A Note on MBC Testing of non-MBC models

In the qualitative experiments of Figures 2 and 3, we apply MBC testing to non-MBC models in order to study the effects of using non-MBC models in MBC settings. The original works on non-MBC models do not prescribe a way to accomplish this, so we take the approach of processing each chunk up to the pooled representation and then mean pooling over the encoded chunks in the following way. Let XX be an input set and let ζ(X)={X1,,XN}\zeta(X)=\{X_{1},\ldots,X_{N}\} be a partition of the set XX, i.e. X=j=1NXjX=\bigcup_{j=1}^{N}X_{j} with XiXj=X_{i}\cap X_{j}=\emptyset for iji\neq j. Denote by f~θ\tilde{f}_{\theta} a non-MBC set encoding function; then our pseudo-MBC testing procedure is as follows,

Z=1Nj=1Nf~θ(Xj)\displaystyle Z=\frac{1}{N}\sum_{j=1}^{N}\tilde{f}_{\theta}(X_{j}) (52)
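A minimal sketch of equation 52 is given below, assuming f_tilde is any non-MBC set encoder that maps a chunk to its pooled representation (the no-grad context reflects that this aggregation is only used at test time).

import torch

@torch.no_grad()
def pseudo_mbc_encode(f_tilde, chunks):
    # encode each chunk separately, then mean-pool the chunk-level encodings
    return torch.stack([f_tilde(c) for c in chunks]).mean(dim=0)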
Figure 11: Performance of (a) DeepSets, (b) SSE, and (c) UMBC+ST with varying set sizes for image completion on the CelebA dataset.

Appendix F Details on the Image Completion Experiments

F.1 Additional Experimental Results

In Figure 11, we evaluate our proposed unbiased full set gradient approximation algorithm (red) with DeepSets, the Slot Set Encoder (SSE), and UMBC + Set Transformer (ST), and compare it against training with a randomly sampled subset of 100 elements, which is a biased estimator (green), and against computing the full set gradient (blue). Across all models, our unbiased estimator significantly outperforms the models trained with a randomly sampled subset. Notably, the model trained with our proposed algorithm is indistinguishable from the model trained with the full set gradient, while our method only incurs a constant memory overhead for any set size. These empirical results again verify the efficiency of our unbiased full set gradient approximation.

F.2 Experimental Setup

We train all models on the CelebA dataset for 200,000 steps with the Adam optimizer (Kingma & Ba, 2015), a batch size of 256, and no weight decay. We set the learning rate to 5\cdot 10^{-4} and use a cosine annealing learning rate schedule. In Tables 9 and 10, we specify the architecture of the Conditional Neural Process with UMBC + Set Transformer. We use k=128 slots and set the dimension of each slot to d_{s}=128. For the attention layer, we use the softmax for the activation function σ\sigma and set the dimension of the attention output to d=128. As an input to the set encoder, we concatenate the coordinates of each 𝐱i,cj2{\mathbf{x}}_{i,c_{j}}\in\mathbb{R}^{2} and the corresponding pixel value yi,cj3y_{i,c_{j}}\in\mathbb{R}^{3} from the context XiX_{i} for j=1,,Nij=1,\ldots,N_{i}, resulting in a Ni×5N_{i}\times 5 matrix.

Table 9: UMBC Set Encoder of Conditional Neural Process.
Output Size Layers
Ni×5 Input Context Set
Ni×128 Linear(5, 128), ReLU
Ni×128 Linear(128, 128), ReLU
Ni×128 Linear(128, 128), ReLU
Ni×128 Linear(128, 128), ReLU
128×128 UMBC Layer
128×128 Layer Normalization
128×128 Set Attention Block (Lee et al., 2019)
128×128 Set Attention Block
128 Pooling by Multihead Attention (Lee et al., 2019)
Table 10: Decoder of Conditional Neural Process.
Output Size Layers
128, 1024×2 Input Set Representation and Coordinates
1024×130 Tile & Concatenate
1024×128 Linear(130, 128), ReLU
1024×128 Linear(128, 128), ReLU
1024×128 Linear(128, 128), ReLU
1024×128 Linear(128, 128), ReLU
1024×6 Linear(128, 6)

Appendix G Details on the Long Document Classification Experiments

We train all models for 30 epochs with the AdamW optimizer (Loshchilov & Hutter, 2019) and a batch size of 8. We use a constant learning rate of 5\cdot 10^{-5}. For our UMBC model, we pretrain the model while freezing BERT for 30 epochs and then finetune the whole model for another 30 epochs. In Table 11, we specify the architecture of UMBC + BERT (Devlin et al., 2019) without positional encoding. We use k=256 slots and set the dimension of each slot to d_{s}=128. We use slot-sigmoid for the activation function σ\sigma and set the dimension of the attention output to d=768.

Table 11: UMBC Set Encoder and BERT decoder for long document classification.
Output Size Layers
Ni Input Document
Ni×768 Word Embedding
Ni×768 Layer Normalization
Ni×768 Linear(768,768), ReLU
Ni×768 Linear(768,768), ReLU
256×768 UMBC Layer
256×768 Layer Normalization
256×768 BERT w/o Positional Encoding (Devlin et al., 2019)
768 [CLS] token Pooler
4271 Dropout(0.1), Linear(768, 4271)

Appendix H Details on the Camelyon16 Experiments

Figure 12: Examples of 4 single patches from the Camelyon16 dataset. On average, each set in the dataset contains over 9,300 patches like the ones pictured above.

The Camelyon16 Whole Slide Image (WSI) dataset consists of 270 training instances and 129 validation instances. The dataset was created for a competition, and therefore the test set is hidden. We therefore follow the example set by previous works (Li et al., 2021) and report performance achieved on the validation set. For preprocessing, we consider the 20\times slide magnification setting, and use Otsu's thresholding method to detect regions containing tissue within the WSI. We then split the activated regions into non-overlapping patches of size 256\times 256. Examples of single input patches can be seen in Figure 12. The largest input set contains 37,345 image patches, each in \mathbb{R}^{256\times 256\times 3}. All patch extraction code can be found in the supplementary file. Table 14 contains statistics on the number of patches per input for the training and test sets, as well as the distribution of positive and negative labels.

H.1 Experimental Setup

We use a ResNet18 (He et al., 2016) which was pretrained with self-supervised contrastive learning (Chen et al., 2020) by Li et al. (2021). The pretrained ResNet18 weights can be downloaded from this repository. Following the classification experiments done by Lee et al. (2019), we place dropout layers before and after the PMA layer of the Set Transformer in our UMBC model. We will describe our pretraining and finetuning steps below in detail.

Table 12: Camelyon16 generic model (Used by all encoders)
Output Size Layers Name Amount
Ni×256×256×3 Input Set Bag of Instances ×1
Ni×512 ResNet18(InstanceNorm) Feature Extractor ×1
Ni×256 Linear, ReLU, Linear Projection ×1
1 Linear, Max Pooling Instance Classifier ×1
Ni×128 Set Encoding Function Bag Classifier ×1
1 Set Decoder Bag Classifier ×1
Table 13: Camelyon16 Bag Classifier Models.
Name Set Encoder Output Size
AB-MIL Gated Attention (Ilse et al., 2018) 1
DSMIL DS-MIL Aggregator (Li et al., 2021) 1
Deepsets Max Pooling \rightarrow Linear \rightarrow ReLU \rightarrow Linear 1
Slot Set Encoder SSE (Bruno et al., 2021) \rightarrow Linear \rightarrow ReLU \rightarrow Linear 1
UMBC+SetTransformer UMBC(K=64) \rightarrow SAB \rightarrow PMA(K=1, p=0.5) \rightarrow Linear 1

Pretraining.

For pretraining, we extract the features from the pretrained ResNet18 and only train the respective MIL models (Table 13) on the extracted features. We pretrain for 200 epochs with the Adam optimizer, using a learning rate of 5\cdot 10^{-4} and a cosine annealing learning rate decay which reaches its minimum at 5\cdot 10^{-6}. We use β1=0.5\beta_{1}=0.5 and β2=0.9\beta_{2}=0.9 for Adam. We train with a batch size of 1 on a single GPU, and save the model which shows the best performance on the validation set, where the performance metric is (\text{Accuracy}+\text{AUC})/2. Other details can be found in Section 4.4. These results can be seen in the left column of Table 4.

Finetuning.

For finetuning, we use our unbiased gradient approximation algorithm with a chunk size of 256. We freeze the pretrained MIL head and only finetune the backbone ResNet model, sequentially processing each chunk of 256 patches for each input set until the entire set has been processed. We train for 10 total epochs, and use the AdamW optimizer with a learning rate of 5\cdot 10^{-5} and a weight decay of 1\cdot 10^{-2}, which is not applied to bias or layer norm parameters. We use a one-epoch linear warmup, and then a cosine annealing learning rate decay at every iteration which reaches its minimum at 5\cdot 10^{-6}. We train on a single GPU with a batch size of 1, i.e. a single instance per GPU.

Table 14: Statistics for the Camelyon16 training and test sets we used. Left: Number of patches (set size) per instance. Right: The distribution of positive and negative samples.
Metric Train Test
Mean 9,329 9,376
Min 154 1558
Max 32,382 37,345
Metric Train Test
Positive (+) 110 49
Negative (-) 160 80

Appendix I Generalizing Attention Activations

As shown in Equation 7, any function which can be expressed as a strictly positive elementwise function combined with sum-decomposable normalization constants νp\nu_{p} and f¯θ{\bar{f}}_{\theta} is a valid attention activation function. Table 15 shows 5 such functions with their respective normalization constants, although there are infinitely many possible functions which can be used.

The softmax operation we propose, h1:d(0,1)dh_{1}:\mathbb{R}^{d}\mapsto(0,1)^{d}, which is outlined immediately before Theorem 3.4, is mathematically equivalent to the standard softmax h2:d(0,1)dh_{2}:\mathbb{R}^{d}\mapsto(0,1)^{d} which is commonly implemented in deep learning libraries, because h1h_{1} and h2h_{2} have the same domain, the same codomain, and h1(𝐱)=h2(𝐱)h_{1}(\mathbf{x})=h_{2}(\mathbf{x}) for all 𝐱d\mathbf{x}\in\mathbb{R}^{d}. Therefore the functions are mathematically equivalent, even though the implementations are not: our proposed function h1h_{1} requires separately applying the exponential and storing and updating the normalization constant, while h2h_{2} is generally implemented in such a way that everything is done in a single operation.
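The equivalence can be illustrated with the following simplified, unbatched NumPy sketch (the random scores and the chunk size of 5 are assumptions for this example; the full attention layer also involves the additional normalization terms listed in Table 15 and a max-subtraction for numerical stability, which we omit here).

import numpy as np

rng = np.random.default_rng(0)
k, N, d = 4, 12, 8                       # slots, set elements, value dimension
A = rng.normal(size=(k, N))              # unnormalized slot-to-element attention scores
V = rng.normal(size=(N, d))              # per-element values

# h2: standard softmax over the set elements, computed on the full matrix at once
W = np.exp(A) / np.exp(A).sum(axis=1, keepdims=True)
out_full = W @ V

# h1: stream the set in chunks, separately accumulating the exp-weighted values
# and the normalization constant, dividing only once at the end
num = np.zeros((k, d))
denom = np.zeros((k, 1))
for s in range(0, N, 5):
    A_c, V_c = A[:, s:s + 5], V[s:s + 5]
    num += np.exp(A_c) @ V_c
    denom += np.exp(A_c).sum(axis=1, keepdims=True)
out_stream = num / denom

print(np.allclose(out_full, out_stream))  # True: h1 and h2 agree on the full set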

Table 15: Valid UMBC attention activation functions with slot normalization ν1\nu_{1} and normalization over the set elements f¯θ{\bar{f}}_{\theta}. Slot-exp uses A^i,jmax1ikA^i,j\hat{A}_{i,j}-\max_{1\leq i\leq k}\hat{A}_{i,j} instead of νp\nu_{p}.
function σ\sigma νp\nu_{p} f¯θ{\bar{f}}_{\theta} name reference
𝚜𝚒𝚐𝚖𝚘𝚒𝚍\mathtt{sigmoid} p=1p=1 - slot-sigmoid (Bruno et al., 2021)
exp()\exp() p=1p=1 slot-softmax (Locatello et al., 2020)
exp()\exp() p=2p=2 softmax (Lee et al., 2019)
exp()\exp() - slot-exp -
𝚜𝚒𝚐𝚖𝚘𝚒𝚍()\mathtt{sigmoid}() p=2p=2 sigmoid -

Appendix J Algorithm

We outline our unbiased full set gradient approximation here.

Algorithm 1 Unbiased Full Set Gradient Estimation
1:  Input: Dataset ((Xi,yi))i=1n((X_{i},y_{i}))_{i=1}^{n}, batch size mm, the number of subsets mm^{\prime}, learning rate (ηt)t=1T(\eta_{t})_{t=1}^{T}, total steps TT, and functions fθf_{\theta} and gλg_{\lambda}.
2:  Randomly initialize θ\theta and λ\lambda
3:  for all t=1,,Tt=1,\ldots,T do
4:     Sample ((X¯i,y¯i))i=1mD[((Xi,yi))i=1n]((\bar{X}_{i},{\bar{y}}_{i}))_{i=1}^{m}\sim D[((X_{i},y_{i}))_{i=1}^{n}]
5:     Lt,1(θ,λ)0,Lt,2(θ,λ)0L_{t,1}(\theta,\lambda)\leftarrow 0,L_{t,2}(\theta,\lambda)\leftarrow 0
6:     for all i=1,mi=1\ldots,m do
7:        Partition a set X¯i\bar{X}_{i} to get ζt(X¯i)\zeta_{t}(\bar{X}_{i})
8:        Sample ζ¯t(X¯i)D[ζt(X¯i)]{\bar{\zeta}}_{t}(\bar{X}_{i})\sim D[\zeta_{t}(\bar{X}_{i})] with |ζ¯t(X¯i)|=m|{\bar{\zeta}}_{t}(\bar{X}_{i})|=m^{\prime}
9:        fθ(X¯i)=S¯ζ¯t(X¯i)fθ(S¯)+Sζt(X¯i)\ζ¯t(X¯i)StopGrad(fθ(S))f_{\theta}(\bar{X}_{i})=\sum_{{\bar{S}}\in{\bar{\zeta}}_{t}(\bar{X}_{i})}f_{\theta}({\bar{S}})+\sum_{S\in\zeta_{t}(\bar{X}_{i})\backslash{\bar{\zeta}}_{t}(\bar{X}_{i})}\texttt{StopGrad}(f_{\theta}(S))
10:        Lt,1(θ,λ)Lt,1(θ,λ)+1m|ζt(X¯i)||ζ¯t(X¯i)|(gλ(fθ(X¯i)),y¯i)L_{t,1}(\theta,\lambda)\leftarrow L_{t,1}(\theta,\lambda)+\frac{1}{m}\frac{|\zeta_{t}(\bar{X}_{i})|}{|{\bar{\zeta}}_{t}(\bar{X}_{i})|}\ell(g_{\lambda}(f_{\theta}(\bar{X}_{i})),{\bar{y}}_{i})
11:        Lt,2(θ,λ)Lt,2(θ,λ)+1m1|ζ¯t(X¯i)|(gλ(fθ(X¯i)),y¯i)L_{t,2}(\theta,\lambda)\leftarrow L_{t,2}(\theta,\lambda)+\frac{1}{m}\frac{1}{\lvert{\bar{\zeta}}_{t}(\bar{X}_{i})\rvert}\ell(g_{\lambda}(f_{\theta}(\bar{X}_{i})),{\bar{y}}_{i})
12:     end for
13:     θθηtLt,1(θ,λ)θ\theta\leftarrow\theta-\eta_{t}\frac{\partial L_{t,1}(\theta,\lambda)}{\partial\theta}
14:     λληtLt,2(θt,λ)λ\lambda\leftarrow\lambda-\eta_{t}\frac{\partial L_{t,2}(\theta_{t},\lambda)}{\partial\lambda}
15:  end for
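For reference, a simplified PyTorch-style sketch of the per-set computation in lines 7-11 of Algorithm 1 is given below (the encoder, decoder, and loss objects are placeholders, uniform chunking stands in for the partition \zeta_{t}, and torch.no_grad() plays the role of StopGrad; this is an illustrative sketch rather than our exact implementation).

import torch

def one_set_losses(f_theta, g_lambda, loss_fn, X, y, chunk_size, m_prime):
    # partition zeta_t(X) by uniform chunking and sample m_prime subsets for gradients
    chunks = list(torch.split(X, chunk_size))
    picked = set(torch.randperm(len(chunks))[:m_prime].tolist())

    enc = 0
    for j, S in enumerate(chunks):
        if j in picked:
            enc = enc + f_theta(S)          # gradient flows through the sampled chunks
        else:
            with torch.no_grad():           # StopGrad on the remaining chunks
                S_enc = f_theta(S)
            enc = enc + S_enc

    loss = loss_fn(g_lambda(enc), y)
    L1 = (len(chunks) / len(picked)) * loss  # reweighted term for theta (line 10, without 1/m)
    L2 = (1.0 / len(picked)) * loss          # reweighted term for lambda (line 11, without 1/m)
    return L1, L2

In the full algorithm these per-set terms are averaged over the mini-batch (the 1/m factor on lines 10 and 11), and \theta and \lambda are then updated from the gradients of L_{t,1} and L_{t,2} respectively (lines 13 and 14).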