
Regularized Optimal Transport Layers for Generalized Global Pooling Operations

Hongteng Xu and Minjie Cheng

Hongteng Xu is with the Gaoling School of Artificial Intelligence, Renmin University of China, and the Beijing Key Laboratory of Big Data Management and Analysis Methods. E-mail: [email protected]
Minjie Cheng is with the Gaoling School of Artificial Intelligence, Renmin University of China. E-mail: [email protected]
The two authors contributed equally to this work. Manuscript received XX XX, 20XX; revised XX XX, 20XX.
Abstract

Global pooling is one of the most significant operations in many machine learning models and tasks, serving information fusion and the representation of structured data (like sets and graphs). However, without a solid mathematical foundation, its practical implementations often rely on empirical mechanisms and thus lead to sub-optimal, or even unsatisfactory, performance. In this work, we develop a novel and generalized global pooling framework through the lens of optimal transport. The proposed framework is interpretable from the perspective of expectation-maximization. Essentially, it aims at learning an optimal transport across sample indices and feature dimensions, making the corresponding pooling operation maximize the conditional expectation of the input data. We demonstrate that most existing pooling methods are equivalent to solving a regularized optimal transport (ROT) problem with different specializations, and more sophisticated pooling operations can be implemented by hierarchically solving multiple ROT problems. Making the parameters of the ROT problem learnable, we develop a family of regularized optimal transport pooling (ROTP) layers. We implement the ROTP layers as a new kind of deep implicit layer, whose model architectures correspond to different optimization algorithms. We test our ROTP layers in several representative set-level machine learning scenarios, including multi-instance learning (MIL), graph classification, graph set representation, and image classification. Experimental results show that applying our ROTP layers can reduce the difficulty of designing and selecting global pooling operations — our ROTP layers may either imitate some existing global pooling methods or lead to new pooling layers that fit the data better. The code is available at https://github.com/SDS-Lab/ROT-Pooling.

Index Terms:
Global pooling, regularized optimal transport, Bregman ADMM, Sinkhorn scaling, set representation, graph embedding.

1 Introduction

As a fundamental operation of information fusion and structured data representation, global pooling produces a global representation for a set of inputs and makes the representation invariant to the permutation of the inputs. This operation has been widely used in many set-level machine learning tasks. In multi-instance learning (MIL) tasks [1, 2], we often leverage a global pooling operation to aggregate multiple instances into a bag-level representation. In graph representation tasks, after passing a graph through a graph neural network, we often apply a global pooling operation (also called a “readout”) to merge its node embeddings into a global graph embedding [3, 4]. Besides these two representative cases, pooling layers are also necessary for convolutional neural networks (CNNs) when extracting visual features [5, 6]. For data with multi-scale clustering structures [7, 8], we can stack multiple pooling layers and derive a hierarchical pooling operation accordingly.

Figure 1: (a) An illustration of the proposed regularized optimal transport pooling (ROTP) layer. The green arrow indicates the optional usage of side information like the feature-level and sample-level similarities. (b) An illustration of the hierarchical ROTP (HROTP) module constructed by integrating multiple ROTP layers. The ROTP layers with different parameters are labeled by different colors.

Currently, many different kinds of global pooling operations have been proposed. Simple operations, like add-pooling, mean-pooling (also called average-pooling), and max-pooling [9], are commonly used because of their computational efficiency. The mixture [10] and the concatenation [11, 12] of these simple pooling operations have also been considered to improve their performance. More recently, some pooling methods, e.g., Network-in-Network (NIN) [13], Set2Set [14], DeepSet [15], dynamic pooling [1], and attention-based pooling layers [2, 16], have been developed with learnable parameters and more sophisticated mechanisms. Although the above global pooling methods have been widely used in many machine learning models and tasks, their theoretical study lags far behind. In particular, the principles of these methods are not well-interpreted in statistics, so their rationality and effectiveness lack theoretical support. Additionally, among so many different pooling methods, the differences and connections between them have not been thoroughly investigated. Without insightful theoretical guidance, the design and selection of pooling methods are empirical or depend on time-consuming enumeration, which often leads to poorly-generalizable models and sub-optimal performance in practice.

To simplify the design of global pooling in an interpretable way and boost its performance in practice, we propose in this study a novel and solid algorithmic pooling framework that unifies and generalizes most existing global pooling operations through the lens of optimal transport. As illustrated in Fig. 1(a), given a batch of data, the proposed pooling operation first optimizes the joint distribution of sample indices and feature dimensions and then weights and averages the representative “sample-feature” pairs. The optimization problem in the above pooling process corresponds to a regularized optimal transport (ROT) problem, in which the target joint distribution corresponds to an optimal transport (OT) plan derived under three regularizations: i) the smoothness of the OT plan, ii) the uncertainty of the OT plan’s marginal distributions, and, optionally, iii) the Gromov-Wasserstein (GW) discrepancy [17] between the feature-level and sample-level similarity matrices. The above pooling operation provides a generalized pooling framework with theoretical guarantees. Specifically, we demonstrate that most existing pooling operations are specializations of the ROT problem under different parameter configurations. Moreover, sophisticated pooling mechanisms like the mixed pooling methods in [10] and the hierarchical pooling in [3] can be generalized by hierarchically solving multiple ROT problems, as shown in Fig. 1(b).

Besides proposing the above unified global pooling framework, we further make the parameters of the ROT problem learnable and develop a family of regularized optimal transport pooling (ROTP) layers. The ROTP layers can be treated as new members of deep implicit layers [18, 19, 20, 21]. Their model architectures correspond to different optimization methods for various ROT problems, and their backward computations can have closed-form solutions under some conditions [22]. In particular, we implement the ROTP layers by solving the ROT problems under different settings, including using the entropic or the quadratic smoothness term, with or without the GW discrepancy term, and so on. These ROT problems are solved in a proximal gradient algorithmic framework, in which each subproblem can be optimized by the Sinkhorn scaling algorithm [23, 24, 25] or the Bregman alternating direction method of multipliers (Bregman ADMM, or BADMM for short) [26, 27]. As a result, each ROTP layer unrolls the iterative optimization of the corresponding ROT problem into feed-forward computations, whose backpropagation adjusts the parameters controlling the optimization process. We analyze each ROTP layer’s representation power, computational complexity, and numerical stability in depth and test its performance in various learning tasks.

The contributions of our work can be summarized as follows.

  • i) A generalized global pooling framework with theoretical support. We propose a generalized global pooling framework, which unifies many existing pooling methods with theoretical guarantees and offers a new perspective based on regularized optimal transport. From the viewpoint of statistical signal processing, the proposed pooling framework is interpretable, as it follows an expectation-maximization principle. Furthermore, stacking multiple ROTP layers together leads to the hierarchical ROTP (HROTP) module shown in Fig. 1(b). We demonstrate that such a hierarchical pooling module can generalize the mixed pooling mechanism in [10] and represent complex data with hierarchical clustering structures, e.g., sets of graphs.

  • ii) Effective and flexible global pooling layers. Based on the proposed global pooling framework, we develop a family of ROTP layers and analyze them quantitatively. The proposed ROTP layers can be implemented with different regularizers and optimization algorithms, offering high flexibility and applicability in various learning scenarios. Additionally, based on the ROTP layers, we can build the HROTP module and provide new solutions to complicated operations like mixed pooling and set fusion.

  • iii) Universal effectiveness in various tasks. We test our ROTP layers in several representative set-level machine learning tasks, including multi-instance learning (MIL), graph classification, graph set representation, and image classification. These tasks correspond to real-world applications, such as medical image analysis, molecule classification, drug-drug interaction analysis, and the ImageNet challenge. In each task, our ROTP layers can imitate or outperform state-of-the-art global pooling methods, simplifying the design and selection of global pooling operations in practice.

The remainder of this paper is organized as follows: Section 2 provides a detailed literature review and explains the connections and differences between our work and existing methods. Section 3 introduces the proposed ROTP framework and demonstrates its rationality and generalization power; additionally, it introduces the hierarchical ROTP module for mixed pooling and set fusion. Section 4 introduces the implementations of the ROTP layers under different settings. Section 5 analyzes the generalization power, computational complexity, and numerical stability of different ROTP layers and compares them with other optimal transport-based pooling methods. Section 6 provides the implementation details of our methods and experimental results on multiple datasets. Finally, Section 7 concludes the paper and discusses our future work.

2 Related Work

2.1 Pooling operations

Most existing models empirically leverage simple pooling operations like add-pooling, mean-pooling, and max-pooling for convenience, though their practical performance is often sub-optimal. Many efforts have been made to achieve better pooling performance, and they can be coarsely categorized into two strategies. The first strategy applies multiple simple pooling operations jointly, e.g., concatenating the output of mean-pooling with that of max-pooling [11, 12]. In [10], the mixed mean-max pooling and its structured variants leverage mixture models of mean-pooling and max-pooling to improve pooling results. Recently, a generalized norm-based pooling (GNP) was proposed in [28, 29], which can imitate max-pooling and mean-pooling under different settings. By learning its parameters, the GNP achieves a mixture of max-pooling and mean-pooling in a nonlinear way.

The second strategy is designing pooling operations with cutting-edge neural network architectures and empowering them with additional feature transformation and extraction abilities. The early work following this strategy includes the Network-in-Network in [13] and the Set2Set in [14], which integrate neural networks like classic multi-layer perceptrons (MLPs) and recurrent neural networks (RNNs) into pooling operations. Recently, attention-pooling and its gated version merge multiple instances with the help of different self-attentive mechanisms [2]. Based on a similar idea, the dynamic-pooling in [1] applies an iterative adjustment step to improve the self-attentive mechanism. More recently, the DeepSet in [15], the SetTransformer in [16], and the prototype-oriented set representer in [30] leverage and modify advanced transformer modules [31] to achieve sophisticated pooling operations. Besides the above global pooling methods, some attempts have been made to leverage graph structures to achieve pooling operations, e.g., the DiffPooling in [3], the ASAP in [32], and the self-attentive graph pooling (SAGP) in [33].

Different from the above methods, we study the design of pooling operations through the lens of computational optimal transport [34] and propose a novel and solid algorithmic framework to unify many representative global pooling methods. The neural network-based implementations of our framework can be interpreted as solving a regularized optimal transport problem with learnable parameters. As a result, instead of designing and selecting global pooling operations empirically, we can apply our ROTP layers to automatically approximate suitable global pooling layers according to the observed data, or to achieve new pooling mechanisms with better performance.

2.2 Optimal transport-based machine learning

Optimal transport (OT) theory [35] has proven to be useful in machine learning tasks, e.g., distribution matching [36, 37, 38], data clustering [39, 40], and generative modeling [41, 42]. Given the samples of two distributions, the discrete OT problem aims at learning a joint distribution of the samples (a.k.a., the optimal transport plan) and indicating the correspondence between them accordingly. This discrete OT problem is a linear programming problem [43]. By adding an entropic regularizer [23], the problem becomes strictly convex and can be solved efficiently by the Sinkhorn scaling in [44, 45]. Along this direction, the logarithmic stabilized Sinkhorn scaling algorithm [46, 47] and the proximal point method [48] make efforts to suppress the numerical instability of the classic Sinkhorn scaling algorithm and solve the entropic OT problem robustly. The Greenkhorn algorithm [49] provides a stochastic Sinkhorn scaling algorithm for batch-based optimization. When the marginal distributions of the optimal transport plan are unreliable or unavailable, the variants of the original OT problem, e.g., the partial OT [45, 50] and the unbalanced OT [46, 24] are considered, and the algorithms focusing on these variants are developed accordingly. Besides the Sinkhorn scaling algorithm, some other algorithms are developed, e.g., the Bregman ADMM [26, 51, 27], the smoothed semi-dual algorithm [52], and the conditional gradient algorithm [53, 54].

Recently, some attempts have been made to design neural networks that imitate the Sinkhorn-based algorithms of OT problems, such as the Gumbel-Sinkhorn network [55], the sparse Sinkhorn attention model [56], the Sinkhorn autoencoder [57], and the Sinkhorn-based transformer [58]. Focusing on pooling layers, some OT-based solutions have been proposed as well. In [59], an OT-based feature aggregation method called OTK is proposed. Following a similar idea, a differentiable expectation-maximization pooling method is proposed in [60], whose implementation is based on solving an entropic OT problem. More recently, pooling methods based on the sliced Wasserstein distance [61] have been proposed, e.g., the sliced-Wasserstein embedding (SWE) in [62] and the WEGL in [63]. However, these methods only consider solving the pooling problems in specific tasks rather than developing a generalized pooling framework. Moreover, they focus on optimal transport in the sample space and are highly dependent on Sinkhorn-based algorithms and random projections, ignoring the potential of other algorithms.

2.3 Implicit layers and optimization-driven models

Essentially, our ROTP layers can be viewed as new members of deep implicit layers [19, 64, 21] (or equivalently, declarative neural networks [22]), which achieve global pooling by solving optimization problems. Many implicit layers have been proposed, e.g., the input convex neural networks (ICNNs) [65] and the deep equilibrium (DEQ) network [21]. From the viewpoint of ordinary differential equations (ODEs), the DEQ network reformulates the ResNet [6] as a single implicit layer whose feed-forward computation corresponds to fixed-point iterations. The OptNet in [66] provides a toolbox to implement convex optimization problems as neural network layers. At the same time, some researchers have contributed to this field from the viewpoint of signal processing, implementing the iterative optimization steps of compressive sensing with neural networks [67, 68]. In general, the learning of such optimization-driven neural networks is based on auto-differentiation (AD) [69], which often incurs high time and space complexity. Fortunately, for the layers corresponding to convex optimization problems, the gradients often have closed-form solutions, and the backpropagation steps are efficient [22].

Note that the study of implicit layers has interacted with computational optimal transport methods. For example, an optimal transport model based on the input convex neural network has been proposed in [70]. More recently, the work in [71, 72] implements the Sinkhorn-scaling algorithm as implicit layers, whose backward computation is achieved in closed form. Focusing on the design and learning of global pooling operations, our generalized pooling framework provides a new optimization-driven solution.

3 Proposed Global Pooling Framework

3.1 A generalized formulation of global pooling

Denote $\mathcal{X}_{D}=\{\bm{X}\in\mathbb{R}^{D\times N}\,|\,N\in\mathbb{N}\}$ as the space of sample sets, where each set $\bm{X}=[\bm{x}_{1},...,\bm{x}_{N}]\in\mathbb{R}^{D\times N}$ contains $N$ $D$-dimensional samples. In practice, $\bm{X}$ can be used to represent the instances in a bag, the node embeddings of a graph, or the local visual features in an image. Following the work in [29, 28, 73], we assume the input data $\bm{X}$ to be nonnegative in this study. This assumption is generally reasonable because the input data are often processed by non-negative activations, like ReLU, Sigmoid, Softplus, and so on. For some pooling methods, e.g., the max-pooling and the generalized norm pooling [28], the non-negativity of the input data is even necessary.

A global pooling operation, denoted as $f:\mathcal{X}_{D}\mapsto\mathbb{R}^{D}$, maps each set to a single vector and ensures that the output is permutation-invariant, i.e., $f(\bm{X})=f(\bm{X}_{\pi})$ for $\bm{X},\bm{X}_{\pi}\in\mathcal{X}_{D}$, where $\bm{X}_{\pi}=[\bm{x}_{\pi(1)},...,\bm{x}_{\pi(N)}]$ and $\pi$ is an arbitrary permutation. As aforementioned, many pooling methods have been proposed to achieve this aim. Typically, the mean-pooling takes the average of the input vectors as its output, i.e., $f(\bm{X})=\frac{1}{N}\sum_{n=1}^{N}\bm{x}_{n}$. Another popular pooling operation, max-pooling, concatenates the maximum of each dimension as its output, i.e., $f(\bm{X})=\|_{d=1}^{D}\max_{n}\{x_{dn}\}_{n=1}^{N}$, where $x_{dn}$ is the $d$-th element of $\bm{x}_{n}$ and “$\|$” represents the concatenation operator. The attention-pooling in [2] outputs the weighted summation of the input vectors, i.e., $f(\bm{X})=\bm{X}\bm{a}_{\bm{X}}$, where $\bm{a}_{\bm{X}}\in\Delta^{N-1}$ is a vector on the $(N-1)$-simplex. The attention-pooling leverages a self-attention mechanism to derive $\bm{a}_{\bm{X}}$ from the input $\bm{X}$, i.e., $\bm{a}_{\bm{X}}=\text{softmax}(\bm{w}^{T}\tanh(\bm{V}\bm{X}))^{T}$.

The element $x_{dn}$ of $\bm{X}$ represents the $d$-th feature of the $n$-th sample. From the perspective of statistical signal processing, it can be treated as the signal corresponding to a pair of a sample index and a feature dimension. Given all sample indices and feature dimensions, we can define a joint distribution for them, denoted as $\bm{P}=[p_{dn}]\in[0,1]^{D\times N}$, where the element $p_{dn}$ indicates the significance of the signal $x_{dn}$. Accordingly, the above global pooling methods yield the following generalized formulation that calculates and concatenates the conditional expectations of the $x_{dn}$'s for $d=1,...,D$:

f(\bm{X})=(\bm{X}\odot\underbrace{\text{diag}^{-1}(\overbrace{\bm{P}\bm{1}_{N}}^{\bm{p}=[p_{d}]})\bm{P}}_{\tilde{\bm{P}}=[p_{n|d}]})\bm{1}_{N}=\big\|_{d=1}^{D}\mathbb{E}_{n\sim p_{n|d}}[x_{dn}],    (1)

where $\odot$ is the Hadamard product, $\text{diag}(\cdot)$ converts a vector into a diagonal matrix, and $\bm{1}_{N}$ represents the $N$-dimensional all-one vector. $\bm{P}\bm{1}_{N}=\bm{p}$ is the distribution of feature dimensions. $\text{diag}^{-1}(\bm{p})\bm{P}=\tilde{\bm{P}}=[p_{n|d}]$ normalizes the rows of $\bm{P}$, and its $d$-th row gives the distribution of sample indices conditioned on the $d$-th feature dimension.

Based on the generalized formulation in (1), we can find that the above global pooling methods apply different mechanisms to derive $\bm{P}$. Given $\bm{X}$, the mean-pooling treats each element evenly, i.e., $\bm{P}=[\frac{1}{DN}]$. The max-pooling sets $\bm{P}\in\{0,\frac{1}{D}\}^{D\times N}$ with $p_{dn}=\frac{1}{D}$ if and only if $n=\arg\max_{m}\{x_{dm}\}_{m=1}^{N}$. The attention-pooling derives $\bm{P}$ as a learnable rank-one matrix parameterized by the input $\bm{X}$, i.e., $\bm{P}=\frac{1}{D}\bm{1}_{D}\bm{a}_{\bm{X}}^{T}$. All these operations set the marginal distribution of feature dimensions to be uniform, i.e., $\bm{p}=\bm{P}\bm{1}_{N}=[\frac{1}{D}]$. For the other marginal distribution, $\bm{q}=\bm{P}^{T}\bm{1}_{D}$, some pooling methods impose specific constraints, e.g., $\bm{q}=\frac{1}{N}\bm{1}_{N}$ for the mean-pooling and $\bm{q}=\bm{a}_{\bm{X}}$ for the attention-pooling, while the max-pooling leaves $\bm{q}$ unconstrained. In the following content, we will show that in general scenarios, we can relax the constraints to regularization terms when the marginal distributions have some uncertainty.
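To make the above correspondence concrete, the following minimal NumPy sketch (our own illustrative code with hypothetical helper names, not part of the released implementation) evaluates the readout in (1) for the three choices of $\bm{P}$ discussed above.

```python
import numpy as np

def readout(X, P):
    """Eq. (1): row-normalize P and take conditional expectations per feature."""
    p = P.sum(axis=1, keepdims=True)       # p = P 1_N, marginal over features
    P_tilde = P / p                        # P~ = diag^{-1}(p) P, each row sums to 1
    return (X * P_tilde).sum(axis=1)       # E_{n ~ p(n|d)}[x_{dn}] for each d

D, N = 4, 6
X = np.abs(np.random.randn(D, N))          # nonnegative inputs, as assumed above

# Mean-pooling: P = [1/(DN)]
P_mean = np.full((D, N), 1.0 / (D * N))
assert np.allclose(readout(X, P_mean), X.mean(axis=1))

# Max-pooling: p_{dn} = 1/D iff n = argmax_m x_{dm}
P_max = np.zeros((D, N))
P_max[np.arange(D), X.argmax(axis=1)] = 1.0 / D
assert np.allclose(readout(X, P_max), X.max(axis=1))

# Attention-pooling: P = (1/D) 1_D a_X^T for attention weights a_X on the simplex
a = np.random.rand(N); a /= a.sum()
P_att = np.full((D, 1), 1.0 / D) @ a[None, :]
assert np.allclose(readout(X, P_att), X @ a)
```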

3.2 A regularized optimal transport pooling operation

The above analysis indicates that we can unify typical pooling operations in an interpretable algorithmic framework based on the expectation-maximization principle. In particular, the pooling operation in (1) is determined by the joint distribution $\bm{P}$. To keep the fused data as informative as possible, we would like to obtain a joint distribution that maximizes the expectations in (1), which leads to the following optimization problem:

\bm{P}^{*}=\arg\max_{\bm{P}\in\Pi(\bm{p},\bm{q})}\sum_{d=1}^{D}p_{d}\,\mathbb{E}_{n\sim p_{n|d}}[x_{dn}]=\arg\max_{\bm{P}\in\Pi(\bm{p},\bm{q})}\underbrace{\mathbb{E}_{(d,n)\sim\bm{P}}[x_{dn}]}_{\langle\bm{X},\bm{P}\rangle},    (2)

where $\langle\cdot,\cdot\rangle$ represents the inner product of matrices. $\bm{p}=[p_{d}]\in\Delta^{D-1}$ and $\bm{q}\in\Delta^{N-1}$ are predefined distributions for feature dimensions and sample indices, respectively. Accordingly, the marginal distributions of $\bm{P}$ are restricted to be $\bm{p}$ and $\bm{q}$, i.e., $\bm{P}\in\Pi(\bm{p},\bm{q})=\{\bm{P}\geq\bm{0}\,|\,\bm{P}\bm{1}_{N}=\bm{p},\,\bm{P}^{T}\bm{1}_{D}=\bm{q}\}$.

As shown in (2), the objective function is the weighted summation of the expectations conditioned on different feature dimensions, which equals the expectation of all the $x_{dn}$'s. Here, we have connected the global pooling operation to the theory of optimal transport — (2) is a classic optimal transport problem [35], which learns the optimal joint distribution $\bm{P}^{*}$ to maximize the expectation of $x_{dn}$. Plugging $\bm{P}^{*}$ into (1) leads to a global pooling result of $\bm{X}$.

However, solving (2) directly often leads to undesired pooling results for the following three reasons:

i) The nature of linear programming. Solving (2) is time-consuming and always leads to a sparse $\bm{P}^{*}$ because it is a constrained linear programming problem. A sparse $\bm{P}^{*}$ tends to filter out some weak but possibly-informative signals in $\bm{X}$, which may have negative influences on downstream tasks.

ii) The uncertainty of marginal distributions. Solving (2) requires us to set the marginal distributions $\bm{p}$ and $\bm{q}$ in advance. However, such exact marginal distributions are often unavailable or unreliable. (That is why most existing pooling methods either assume the marginal distributions to be uniform or leave them unconstrained.)

iii) The lack of structural information. The optimal transport problem in (2) does not consider the structural relations among samples or those among features. However, real-world samples, like the node embeddings of a graph and the instances in a bag, are non-i.i.d. in general, and their features can be correlated. The state-of-the-art pooling methods, especially those for graph representation [74, 75, 3, 32], often take such structural information into account.

According to the above analysis, to make optimal transport-based pooling applicable, we need to i) improve the smoothness of the optimization problem, ii) take the uncertainties of the marginal distributions into account, and iii) leverage the sample-level and feature-level structural information hidden in the input data. Therefore, we extend (2) to a regularized optimal transport (ROT) problem:

\bm{P}_{\text{rot}}^{*}(\bm{X};\bm{\theta})=\arg\min_{\bm{P}\in\Omega}\overbrace{\underbrace{\langle-\bm{X},\bm{P}\rangle}_{\text{OT term}}+\underbrace{\alpha_{0}\langle C(\bm{X},\bm{P}),\bm{P}\rangle}_{\text{Structural Reg.}}}^{\text{Fused Gromov-Wasserstein discrepancy}}+\underbrace{\alpha_{1}\text{R}(\bm{P})}_{\text{Smoothness Reg.}}+\underbrace{\alpha_{2}\text{KL}(\bm{P}\bm{1}_{N}|\bm{p}_{0})+\alpha_{3}\text{KL}(\bm{P}^{T}\bm{1}_{D}|\bm{q}_{0})}_{\text{Marginal Reg.}},    (3)

where $\Omega=\{\bm{P}>\bm{0}\,|\,\bm{1}_{D}^{T}\bm{P}\bm{1}_{N}=1\}$. In (3), we introduce the following three regularizers, each of which addresses one of the above three challenges.

i) Smoothness regularization. $\text{R}(\bm{P})$ is a regularizer of $\bm{P}$, which is used to improve the smoothness of the objective function. Typically, we can set $\text{R}(\bm{P})$ to the negative entropy of $\bm{P}$, i.e., $\text{R}(\bm{P})=\langle\bm{P},\log\bm{P}-\bm{1}\rangle=\sum_{d,n}p_{dn}(\log p_{dn}-1)$ [23], or the quadratic regularizer of $\bm{P}$, i.e., $\text{R}(\bm{P})=\|\bm{P}\|_{F}^{2}$ [52]. The parameter $\alpha_{1}$ controls the significance of $\text{R}(\bm{P})$.

ii) Marginal prior regularization. Instead of imposing strict constraints, we leverage two KL divergence terms in (3) to penalize the differences between the marginals of $\bm{P}$ and the predefined prior distributions (denoted as $\bm{p}_{0}$ and $\bm{q}_{0}$, respectively). Here, $\text{KL}(\bm{a}|\bm{b})=\langle\bm{a},\log\bm{a}-\log\bm{b}\rangle-\langle\bm{a}-\bm{b},\bm{1}\rangle$ represents the KL-divergence between $\bm{a}$ and $\bm{b}$. The strengths of these two terms are controlled by the weights $\alpha_{2}$ and $\alpha_{3}$, respectively. This regularization helps us achieve a trade-off between the utilization of prior information and the robustness to its uncertainty.

iii) Gromov-Wasserstein discrepancy-based structural regularization. Given $\bm{X}$, we construct its feature-level and sample-level covariance matrices, respectively, i.e., $\bm{\Sigma}_{1}=\frac{1}{N}(\bm{X}-\bm{\mu}_{1}\bm{1}_{N}^{T})(\bm{X}-\bm{\mu}_{1}\bm{1}_{N}^{T})^{T}$ and $\bm{\Sigma}_{2}=\frac{1}{D}(\bm{X}-\bm{1}_{D}\bm{\mu}_{2}^{T})^{T}(\bm{X}-\bm{1}_{D}\bm{\mu}_{2}^{T})$, where $\bm{\mu}_{1}=\frac{1}{N}\bm{X}\bm{1}_{N}$ and $\bm{\mu}_{2}=\frac{1}{D}\bm{X}^{T}\bm{1}_{D}$. (In practice, besides the covariance matrices, we can also leverage other methods to define the feature-level and sample-level similarities, e.g., the cosine similarity and other kernel matrices.) Following the work in [75, 74], we would like to make the feature-level covariance highly correlated with the sample-level covariance such that the features can preserve the structural relations among the samples. Therefore, we construct a structural cost as follows:

C(\bm{X},\bm{P})=-\bm{\Sigma}_{1}\bm{P}\bm{\Sigma}_{2}^{T},    (4)

and $\langle C(\bm{X},\bm{P}),\bm{P}\rangle=-\text{tr}(\bm{\Sigma}_{1}\bm{P}\bm{\Sigma}_{2}^{T}\bm{P}^{T})$, whose significance is controlled by $\alpha_{0}$. It is easy to find that this structural regularization is the same as the objective of the Gromov-Wasserstein discrepancy problem in [76, 17], penalizing the difference between the sample-level covariance and the feature-level covariance. As shown in (3), combining the original optimal transport term with the structural regularizer leads to the well-known fused Gromov-Wasserstein (FGW) discrepancy [53], which is an optimal transport-based metric for structured data (like graphs and sets) [54, 77].
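For reference, the full objective in (3) with the entropic smoothness regularizer can be assembled as in the following NumPy sketch (illustrative only; `kl` and `rot_objective` are hypothetical helper names, and $\bm{P}$ is assumed to be a valid joint distribution in $\Omega$).

```python
import numpy as np

def kl(a, b):
    # KL(a|b) = <a, log a - log b> - <a - b, 1>
    return np.sum(a * (np.log(a) - np.log(b))) - np.sum(a - b)

def rot_objective(P, X, p0, q0, alphas):
    """Value of the ROT objective in Eq. (3), entropic R(P)."""
    a0, a1, a2, a3 = alphas
    D, N = X.shape
    # Feature-level (D x D) and sample-level (N x N) covariance matrices
    mu1 = X.mean(axis=1, keepdims=True)
    mu2 = X.mean(axis=0, keepdims=True)
    Sigma1 = (X - mu1) @ (X - mu1).T / N
    Sigma2 = (X - mu2).T @ (X - mu2) / D
    ot_term = -np.sum(X * P)                             # <-X, P>
    gw_term = -np.trace(Sigma1 @ P @ Sigma2.T @ P.T)     # <C(X,P), P>, Eq. (4)
    ent = np.sum(P * (np.log(P) - 1.0))                  # entropic R(P)
    marg = a2 * kl(P.sum(axis=1), p0) + a3 * kl(P.sum(axis=0), q0)
    return ot_term + a0 * gw_term + a1 * ent + marg
```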

Compared with (2), (3) is an unconstrained optimization problem, so we can apply differentiable algorithms to optimize it. (When $\alpha_{0}=0$ and $\text{R}(\cdot)$ is a strictly-convex function, the objective function in (3) is strictly convex.) The optimal transport matrix $\bm{P}_{\text{rot}}^{*}$ can be viewed as a function of $\bm{X}$, whose parameters are the weights of the regularizers and the prior distributions, i.e., $\bm{P}_{\text{rot}}^{*}(\bm{X};\bm{\theta})$, where $\bm{\theta}=\{\alpha_{0},\alpha_{1},\alpha_{2},\alpha_{3},\bm{p}_{0},\bm{q}_{0}\}$ denotes the model parameters for convenience. Plugging it into (1), we obtain the proposed regularized optimal transport pooling (ROTP) operation:

f_{\text{rot}}(\bm{X};\bm{\theta})=(\bm{X}\odot\text{diag}^{-1}(\bm{P}_{\text{rot}}^{*}(\bm{X};\bm{\theta})\bm{1}_{N})\bm{P}_{\text{rot}}^{*}(\bm{X};\bm{\theta}))\bm{1}_{N}.    (5)

3.3 The rationality and generalizability of ROTP

Our ROTP is a feasible global pooling operation. In particular, it satisfies the requirement of permutation-invariance under mild conditions.

Theorem 1.

The ROTP in (5) is permutation-invariant, i.e., $f_{\text{rot}}(\bm{X})=f_{\text{rot}}(\bm{X}_{\pi})$ for an arbitrary permutation $\pi$, when the following two conditions are satisfied:

  • i) The $\text{R}(\cdot)$ in (3) is a permutation-invariant function of $\bm{P}$.

  • ii) The $\bm{q}_{0}$ in (3) is either a permutation-equivariant function of $\bm{X}$ or a uniform distribution, i.e., $\bm{q}_{0}=\frac{1}{N}\bm{1}_{N}$.

Proof.

For convenience, we ignore the notation $\bm{\theta}$ in the following derivation. Let $\bm{P}^{*}$ be the optimal solution of (3) given $\bm{X}$. Denote $\pi:\{1,...,N\}\mapsto\{1,...,N\}$ as an arbitrary permutation and $\bm{X}_{\pi}$ as the column-wisely permuted data. We have the following six equations:

i)   $\langle-\bm{X},\bm{P}^{*}\rangle=\langle-\bm{X}_{\pi},\bm{P}_{\pi}^{*}\rangle$;
ii)  $\langle C(\bm{X},\bm{P}^{*}),\bm{P}^{*}\rangle=-\text{tr}(\bm{\Sigma}_{1}\bm{P}^{*}\bm{\Sigma}_{2}^{T}(\bm{P}^{*})^{T})=-\text{tr}(\bm{\Sigma}_{1}\bm{P}_{\pi}^{*}\bm{\Sigma}_{2,\pi,\pi}^{T}(\bm{P}_{\pi}^{*})^{T})=\langle C(\bm{X}_{\pi},\bm{P}_{\pi}^{*}),\bm{P}_{\pi}^{*}\rangle$;
iii) $\text{R}(\bm{P}^{*})=\text{R}(\bm{P}_{\pi}^{*})$;
iv)  $\text{KL}(\bm{P}^{*}\bm{1}_{N}|\bm{p}_{0})=\text{KL}(\bm{P}_{\pi}^{*}\bm{1}_{N}|\bm{p}_{0})$;
v)   when $\bm{q}_{0}$ is a permutation-equivariant function of $\bm{X}$: $\text{KL}((\bm{P}^{*})^{T}\bm{1}_{D}|\bm{q}_{0}(\bm{X}))=\text{KL}((\bm{P}_{\pi}^{*})^{T}\bm{1}_{D}|\bm{q}_{0}(\bm{X}_{\pi}))=\text{KL}((\bm{P}_{\pi}^{*})^{T}\bm{1}_{D}|\bm{q}_{0,\pi}(\bm{X}))$;
vi)  when $\bm{q}_{0}$ is uniform: $\text{KL}((\bm{P}^{*})^{T}\bm{1}_{D}|\bm{q}_{0})=\text{KL}((\bm{P}_{\pi}^{*})^{T}\bm{1}_{D}|\bm{q}_{0})$,    (6)

where $\bm{P}_{\pi}^{*}$ is the column-wise permutation of $\bm{P}^{*}$. In the second equation, $\bm{\Sigma}_{2,\pi,\pi}$ means permuting $\bm{\Sigma}_{2}$ row-wisely and column-wisely based on $\pi$. The third equation is based on Condition 1, and the fifth and sixth equations are based on Condition 2. According to (6), $\bm{P}_{\pi}^{*}$ must be the optimal solution of (3) given $\bm{X}_{\pi}$. Therefore, $\bm{P}^{*}$ is a permutation-equivariant function of $\bm{X}$, i.e., $\bm{P}_{\pi}^{*}(\bm{X})=\bm{P}^{*}(\bm{X}_{\pi})$.

Plugging $\bm{P}_{\pi}^{*}(\bm{X})=\bm{P}^{*}(\bm{X}_{\pi})$ into (5), we have

f_{\text{rot}}(\bm{X}_{\pi}) = (\bm{X}_{\pi}\odot(\text{diag}^{-1}(\bm{P}^{*}(\bm{X}_{\pi})\bm{1}_{N})\bm{P}^{*}(\bm{X}_{\pi})))\bm{1}_{N}
                             = (\bm{X}_{\pi}\odot(\text{diag}^{-1}(\bm{P}_{\pi}^{*}(\bm{X})\bm{1}_{N})\bm{P}_{\pi}^{*}(\bm{X})))\bm{1}_{N}
                             = (\bm{X}_{\pi}\odot(\text{diag}^{-1}(\bm{P}^{*}(\bm{X})\bm{1}_{N})\bm{P}_{\pi}^{*}(\bm{X})))\bm{1}_{N}
                             = (\bm{X}\odot(\text{diag}^{-1}(\bm{P}^{*}(\bm{X})\bm{1}_{N})\bm{P}^{*}(\bm{X})))\bm{1}_{N}
                             = f_{\text{rot}}(\bm{X}),    (7)

which completes the proof. ∎

Theorem 1 provides sufficient conditions to ensure the permutation-invariance of the proposed ROTP framework. Note that the two conditions in Theorem 1 are common in practice. When $\text{R}(\bm{P})$ is an entropic or quadratic function of $\bm{P}$, Condition 1 always holds. The prior $\bm{q}_{0}$ is a permutation-equivariant function of $\bm{X}$ for the max-pooling and the attention-pooling, and it is uniform for the mean-pooling.

Moreover, our ROTP operation provides a generalized global pooling framework that unifies many representative pooling operations. In particular, the mean-pooling, the max-pooling, and the attention-pooling can be formulated as specializations of (5) under different parameter configurations.

Proposition 1.

Given an arbitrary $\bm{X}\in\mathbb{R}^{D\times N}$, the mean-pooling, the max-pooling, and the attention-pooling with attention weights $\bm{a}_{\bm{X}}$ can be equivalently achieved by the $f_{\text{rot}}(\bm{X};\bm{\theta})$ in (5) under the following configurations:

$f_{\text{rot}}(\bm{X};\bm{\theta})$ | Mean-pooling | Max-pooling | Attention-pooling
$\alpha_{0}$ | $0$ | $0$ | $0$
$\alpha_{1}$ | $\rightarrow\infty$ | $0$ | $\rightarrow\infty$
$\alpha_{2}$ | $\rightarrow\infty$ | $\rightarrow\infty$ | $\rightarrow\infty$
$\alpha_{3}$ | $\rightarrow\infty$ | $0$ | $\rightarrow\infty$
$\bm{p}_{0}$ | $\frac{1}{D}\bm{1}_{D}$ | $\frac{1}{D}\bm{1}_{D}$ | $\frac{1}{D}\bm{1}_{D}$
$\bm{q}_{0}$ | $\frac{1}{N}\bm{1}_{N}$ | $-$ | $\bm{a}_{\bm{X}}$

Here, “$\bm{q}_{0}=-$” means that $\bm{q}_{0}$ can be an arbitrary vector. $\alpha_{2},\alpha_{3}\rightarrow\infty$ means that the marginal prior regularizers become strict marginal constraints. $\alpha_{1}\rightarrow\infty$ means that the smoothness regularizer is dominant and thus the OT term becomes ignorable.

Proof.

Equivalence to the mean-pooling operation: In (3), $\alpha_{2},\alpha_{3}\rightarrow\infty$, $\bm{p}_{0}=\frac{1}{D}\bm{1}_{D}$, and $\bm{q}_{0}=\frac{1}{N}\bm{1}_{N}$ require the marginals of $\bm{P}^{*}$ to match the uniform distributions strictly. Additionally, $\alpha_{0}=0$ and $\alpha_{1}\rightarrow\infty$ mean that the smoothness regularizer is dominant while both the OT term and the GW-based structural regularizer are ignored. Therefore, the optimization problem in (3) degrades to

\bm{P}^{*}=\arg\min_{\bm{P}\in\Pi(\frac{1}{D}\bm{1}_{D},\frac{1}{N}\bm{1}_{N})}\text{R}(\bm{P}).    (8)

When $\text{R}(\bm{P})=\langle\bm{P},\log\bm{P}-\bm{1}\rangle$ or $\|\bm{P}\|_{F}^{2}$, the optimal solution of (8) is $\bm{P}^{*}=[\frac{1}{DN}]$, and thus the corresponding $f_{\text{rot}}$ becomes the mean-pooling operation.

Equivalence to the max-pooling operation: In (3), when $\alpha_{0}=\alpha_{1}=0$, both the structural and smoothness regularizers are ignored. $\alpha_{2}\rightarrow\infty$ and $\bm{p}_{0}=\frac{1}{D}\bm{1}_{D}$ mean that $\bm{P}\bm{1}_{N}=\frac{1}{D}\bm{1}_{D}$ holds strictly, while $\alpha_{3}=0$ and $\bm{q}_{0}=-$ mean that $\bm{P}^{T}\bm{1}_{D}$ is unconstrained. In this case, the problem in (3) becomes

\bm{P}^{*}=\arg\max_{\bm{P}\in\Pi(\frac{1}{D}\bm{1}_{D},\cdot)}\langle\bm{X},\bm{P}\rangle,    (9)

whose optimal solution obviously corresponds to setting $p_{dn}^{*}=\frac{1}{D}$ if and only if $n=\arg\max_{m}\{x_{dm}\}_{m=1}^{N}$. Therefore, the corresponding $f_{\text{rot}}$ becomes the max-pooling operation.

Equivalence to attention-pooling operation: Similar to the case of mean-pooling, given the configuration shown in the above table, the problem in (3) becomes

\bm{P}^{*}=\arg\min_{\bm{P}\in\Pi(\frac{1}{D}\bm{1}_{D},\bm{a}_{\bm{X}})}\text{R}(\bm{P}),    (10)

whose optimal solution is $\bm{P}^{*}=\frac{1}{D}\bm{1}_{D}\bm{a}_{\bm{X}}^{T}$. Accordingly, the corresponding $f_{\text{rot}}$ becomes the attention-pooling operation. ∎

3.4 Hierarchical ROTP operations

Proposition 1 demonstrates that a single ROTP operation can imitate various global pooling methods. Moreover, the hierarchical combination of multiple ROTP operations can reproduce more complicated pooling mechanisms, including mixed pooling and set fusion.

3.4.1 Hierarchical ROTP for mixed pooling

Typically, a mixed pooling operation first applies several different pooling operations to the same input data and then aggregates their outputs. It often has stronger representation power than a single pooling operation because it considers different pooling mechanisms. For example, the mixed mean-max pooling operation in [10] is defined as follows:

f_{\text{mix}}(\bm{X})=\omega\,\text{MeanPool}(\bm{X})+(1-\omega)\,\text{MaxPool}(\bm{X}).    (11)

When $\omega\in(0,1)$ is a single learnable scalar, (11) is called “Mixed mean-max pooling”. When $\omega$ is parameterized as a sigmoid function of $\bm{X}$, (11) is called “Gated mean-max pooling”.

It is easy to demonstrate that such a mixed pooling operation can be equivalently achieved by hierarchically integrating three ROTP operations.

Proposition 2 (Hierarchical ROTP for mixed pooling).

Given an arbitrary $\bm{X}\in\mathbb{R}^{D\times N}$, the $f_{\text{mix}}(\bm{X})$ in (11) can be equivalently implemented by $f_{\text{rot}}([f_{\text{rot}}(\bm{X};\bm{\theta}_{1}),f_{\text{rot}}(\bm{X};\bm{\theta}_{2})];\bm{\theta}_{3})$, where $\bm{\theta}_{1}=\{0,\infty,\infty,\infty,\frac{1}{D}\bm{1}_{D},\frac{1}{N}\bm{1}_{N}\}$, $\bm{\theta}_{2}=\{0,0,\infty,0,\frac{1}{D}\bm{1}_{D},-\}$, and $\bm{\theta}_{3}=\{0,\infty,\infty,\infty,\frac{1}{D}\bm{1}_{D},[\omega,1-\omega]^{T}\}$.

Proof.

In particular, given $\bm{X}\in\mathbb{R}^{D\times N}$, we have

f_{\text{mix}}(\bm{X}) = \omega f_{\text{rot}}(\bm{X};\bm{\theta}_{1})+(1-\omega)f_{\text{rot}}(\bm{X};\bm{\theta}_{2})
                       = \underbrace{[f_{\text{rot}}(\bm{X};\bm{\theta}_{1}),f_{\text{rot}}(\bm{X};\bm{\theta}_{2})]}_{\bm{Y}\in\mathbb{R}^{D\times 2}}[\omega,1-\omega]^{T}
                       = \Bigl(\bm{Y}\odot\text{diag}^{-1}\bigl(\underbrace{\overbrace{\tfrac{1}{D}\bm{1}_{D}}^{\bm{p}_{0}}\overbrace{[\omega,1-\omega]}^{\bm{q}_{0}^{T}}}_{\bm{P}^{*}}\bm{1}_{2}\bigr)\bigl(\tfrac{1}{D}\bm{1}_{D}[\omega,1-\omega]\bigr)\Bigr)\bm{1}_{2}
                       = f_{\text{rot}}(\bm{Y};\bm{\theta}_{3}).    (12)

Here, the first equation is based on Proposition 1 — we can replace $\text{MeanPool}(\bm{X})$ and $\text{MaxPool}(\bm{X})$ with $f_{\text{rot}}(\bm{X};\bm{\theta}_{1})$ and $f_{\text{rot}}(\bm{X};\bm{\theta}_{2})$, respectively, where $\bm{\theta}_{1}=\{0,\infty,\infty,\infty,\frac{1}{D}\bm{1}_{D},\frac{1}{N}\bm{1}_{N}\}$ and $\bm{\theta}_{2}=\{0,0,\infty,0,\frac{1}{D}\bm{1}_{D},-\}$. The concatenation of $f_{\text{rot}}(\bm{X};\bm{\theta}_{1})$ and $f_{\text{rot}}(\bm{X};\bm{\theta}_{2})$ is a matrix of size $D\times 2$, denoted as $\bm{Y}=[f_{\text{rot}}(\bm{X};\bm{\theta}_{1}),f_{\text{rot}}(\bm{X};\bm{\theta}_{2})]$. As shown in the third equation of (12), the $f_{\text{mix}}(\bm{X})$ in (11) can be rewritten based on $\bm{p}_{0}=\frac{1}{D}\bm{1}_{D}$, $\bm{q}_{0}=[\omega,1-\omega]^{T}$, and the rank-one matrix $\bm{P}^{*}=\bm{p}_{0}\bm{q}_{0}^{T}$. This formulation corresponds to passing $\bm{Y}$ through a third ROTP operation, i.e., $f_{\text{rot}}(\bm{Y};\bm{\theta}_{3})$, where $\bm{\theta}_{3}=\{0,\infty,\infty,\infty,\frac{1}{D}\bm{1}_{D},[\omega,1-\omega]^{T}\}$. ∎
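The identity in (12) is easy to verify numerically. The sketch below (our own illustrative check, reusing the eq.-(1) readout from Section 3.1) fuses the mean- and max-pooling outputs through the rank-one plan $\bm{P}^{*}=\bm{p}_{0}\bm{q}_{0}^{T}$ and compares the result against the direct mixture in (11).

```python
import numpy as np

def readout(X, P):
    # Eq. (1)/(5): conditional expectation under the row-normalized plan
    return (X * (P / P.sum(axis=1, keepdims=True))).sum(axis=1)

D, N, omega = 5, 8, 0.3
X = np.abs(np.random.randn(D, N))

# Inner ROTP operations (Proposition 1): mean-pooling and max-pooling
Y = np.stack([X.mean(axis=1), X.max(axis=1)], axis=1)            # D x 2

# Outer ROTP: P* = p0 q0^T with p0 = (1/D) 1_D and q0 = [omega, 1 - omega]
P3 = np.full((D, 1), 1.0 / D) @ np.array([[omega, 1.0 - omega]])
f_hier = readout(Y, P3)

f_mix = omega * X.mean(axis=1) + (1.0 - omega) * X.max(axis=1)   # Eq. (11)
assert np.allclose(f_hier, f_mix)
```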

Proposition 2 means that the mixed pooling in (11) corresponds to a simple hierarchical ROTP (HROTP) operation, in which the outputs of two ROTP operations are fused through a third ROTP operation. In more general scenarios, given $\bm{X}\in\mathbb{R}^{D\times N}$, we can achieve an $M$-head mixed pooling operation by integrating $M+1$ ROTP operations as follows:

f_{\text{mrot}}(\bm{X};\bm{\Theta})=f_{\text{rot}}(\|_{m=1}^{M}f_{\text{rot}}(\bm{X};\bm{\theta}_{m});\bm{\theta}_{M+1}),    (13)

where $\bm{\Theta}=\{\bm{\theta}_{m}\}_{m=1}^{M+1}$, and the $\bm{\theta}_{m}$'s are different from each other in general.

3.4.2 Hierarchical ROTP for set fusion

Many real-world data have hierarchical set-level structures. A typical example is combinatorial drug analysis, in which each sample is a set of drugs, and each drug can be modeled as a graph and represented as a set of node embeddings. Therefore, when predicting the property of a drug set, we need to first obtain a graph embedding by pooling the node embeddings of each drug and then represent the drug set by pooling the graph embeddings of different drugs.

Such a set fusion operation can also be achieved by our HROTP operation, as illustrated in Fig. 1(b). Denote a set of $M$ sets as $\mathcal{X}=\{\bm{X}_{m}\}_{m=1}^{M}$. The proposed HROTP module for fusing the $M$ sets is defined as follows:

f_{\text{hrot}}(\mathcal{X};\bm{\Theta},g)=f_{\text{rot}}(\|_{m=1}^{M}g(f_{\text{rot}}(\bm{X}_{m};\bm{\theta}_{1}));\bm{\theta}_{2}),    (14)

where $\bm{\Theta}=\{\bm{\theta}_{1},\bm{\theta}_{2}\}$. As shown in (14), $f_{\text{rot}}(\cdot;\bm{\theta}_{1})$ pools the elements within each set and is reused for all the sets, while $f_{\text{rot}}(\cdot;\bm{\theta}_{2})$ pools the representations of the different sets. Additionally, to enhance the representation power of the HROTP module, we can plug more neural network layers between the above two pooling steps. In (14), the neural network $g:\mathbb{R}^{D}\mapsto\mathbb{R}^{D^{\prime}}$ serves as a feature extractor, which can be a multi-layer perceptron (MLP) or a more complicated module.
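A minimal sketch of (14) is given below; `rotp_inner`, `rotp_outer`, and `g` are hypothetical placeholders for the two ROTP layers and the feature extractor, so that any concrete solver for (3) can be plugged in.

```python
import numpy as np

def hrotp(X_list, rotp_inner, g, rotp_outer):
    """Eq. (14): pool within each set, transform, then pool across sets."""
    # Inner pooling (shared parameters theta_1) followed by feature extraction g
    embeddings = [g(rotp_inner(X_m)) for X_m in X_list]      # each in R^{D'}
    Y = np.stack(embeddings, axis=1)                         # D' x M
    return rotp_outer(Y)                                     # fused set-of-sets vector

# Placeholder instantiation with mean readouts and an identity extractor:
f = hrotp([np.abs(np.random.randn(4, n)) for n in (3, 5, 7)],
          rotp_inner=lambda X: X.mean(axis=1),
          g=lambda v: v,
          rotp_outer=lambda Y: Y.mean(axis=1))
```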

4 ROTP Layers and Their Implementations

Besides imitating existing global pooling operations with manually-selected parameters, we further implement the proposed ROTP operation $f_{\text{rot}}(\bm{X};\bm{\theta})$ as a learnable neural network layer, whose parameters $\bm{\theta}$ include the prior distributions $\{\bm{p}_{0}\in\Delta^{D-1},\bm{q}_{0}\in\Delta^{N-1}\}$ and the regularization weights $\bm{\alpha}=[\alpha_{i}]_{i=0}^{3}\in(0,\infty)^{4}$. These parameters are constrained: $\{\alpha_{i}\geq 0\}_{i=0}^{3}$, $\bm{p}_{0}\in\Delta^{D-1}$, and $\bm{q}_{0}\in\Delta^{N-1}$. We apply the following parametrization strategy to implement the proposed ROTP layer with unconstrained parameters. In particular, we set $\{\alpha_{i}=\text{softplus}(\beta_{i})\}_{i=0}^{3}$, where $\{\beta_{i}\}_{i=0}^{3}$ are unconstrained parameters. For the prior distributions, we can either fix them as uniform distributions, i.e., $\bm{p}_{0}=\frac{1}{D}\bm{1}_{D}$ and $\bm{q}_{0}=\frac{1}{N}\bm{1}_{N}$, or implement them as learnable attention modules, i.e., $\bm{p}_{0}=\text{softmax}(\bm{U}\bm{X}\bm{1}_{N})$ and $\bm{q}_{0}=\text{softmax}(\bm{w}^{T}\tanh(\bm{V}\bm{X}))$ [2], where $\bm{U},\bm{V}\in\mathbb{R}^{D\times D}$ and $\bm{w}\in\mathbb{R}^{D}$ are unconstrained. As a result, our ROTP layers can be learned by stochastic gradient descent.
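The parametrization strategy above can be sketched as a small PyTorch module as follows. This is our own illustrative code (the class and attribute names are hypothetical): `softplus` keeps the regularization weights nonnegative, and `softmax` keeps the learnable priors on the corresponding simplices.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ROTPParams(nn.Module):
    """Unconstrained parameters for one ROTP layer (illustrative sketch)."""
    def __init__(self, dim, learnable_priors=True):
        super().__init__()
        self.beta = nn.Parameter(torch.zeros(4))   # beta_i -> alpha_i = softplus(beta_i)
        self.learnable_priors = learnable_priors
        if learnable_priors:
            self.U = nn.Parameter(torch.randn(dim, dim) / dim ** 0.5)
            self.V = nn.Parameter(torch.randn(dim, dim) / dim ** 0.5)
            self.w = nn.Parameter(torch.randn(dim))

    def forward(self, X):                          # X: D x N
        alpha = F.softplus(self.beta)              # nonnegative regularization weights
        D, N = X.shape
        if self.learnable_priors:
            p0 = F.softmax(self.U @ X.sum(dim=1), dim=0)             # softmax(U X 1_N)
            q0 = F.softmax(self.w @ torch.tanh(self.V @ X), dim=0)   # softmax(w^T tanh(V X))
        else:
            p0 = X.new_full((D,), 1.0 / D)         # uniform priors
            q0 = X.new_full((N,), 1.0 / N)
        return alpha, p0, q0
```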

The feed-forward step of the ROTP layer corresponds to solving the ROT problem in (3) and obtaining $f_{\text{rot}}(\bm{X};\bm{\theta})$ via (5). Its backward step corresponds to updating the model parameters, which adjusts the objective function in (3) and changes the optimum $\bm{P}_{\text{rot}}^{*}$ accordingly. In this work, we implement the backward step by auto-differentiation. By learning the model parameters from observed data, our ROTP layer may fit the data better and outperform the global pooling methods designed empirically.

The architecture of our ROTP layer is determined by the optimization algorithm of (3). In the following two subsections, we introduce two algorithms to solve (3), which lead to two ROTP layers with different architectures.

4.1 Sinkhorn-based ROTP layer

We first consider leveraging the proximal point method [17, 78, 48] to solve (3) iteratively. In particular, in the $t$-th iteration, given the current OT plan $\bm{P}^{(t)}$, we solve the following sub-problem:

\bm{P}^{(t+1)}=\arg\min_{\bm{P}\in\Omega}\langle-\bm{X},\bm{P}\rangle+\alpha_{0}\langle C(\bm{X},\bm{P}^{(t)}),\bm{P}\rangle+\alpha_{1}\text{R}(\bm{P})+\alpha_{2}\text{KL}(\bm{P}\bm{1}_{N}|\bm{p}_{0})+\alpha_{3}\text{KL}(\bm{P}^{T}\bm{1}_{D}|\bm{q}_{0})+\underbrace{\tau\text{KL}(\bm{P}|\bm{P}^{(t)})}_{\text{Proximal term}}.    (15)

Here, $\text{KL}(\bm{P}|\bm{P}^{(t)})$ is the proximal term based on the current variable, which is implemented as a KL-divergence, and the weight $\tau$ controls its significance. As a result, our ROTP layer is built by stacking $T$ feed-forward modules, each of which corresponds to the optimization of (15).

When the smoothness regularizer is entropic, i.e., $\text{R}(\bm{P})=\langle\bm{P},\log\bm{P}-\bm{1}\rangle$, (15) becomes the following entropic unbalanced optimal transport (EUOT) problem:

\min_{\bm{P}\in\Omega}\langle\bm{C}^{(t)},\bm{P}\rangle+(\alpha_{1}+\tau)\langle\log\bm{P},\bm{P}\rangle+\alpha_{2}\text{KL}(\bm{P}\bm{1}_{N}|\bm{p}_{0})+\alpha_{3}\text{KL}(\bm{P}^{T}\bm{1}_{D}|\bm{q}_{0}),    (16)

where the matrix $\bm{C}^{(t)}=[c_{dn}^{(t)}]=-\bm{X}-\alpha_{0}\bm{\Sigma}_{1}\bm{P}^{(t)}\bm{\Sigma}_{2}^{T}-\tau\log\bm{P}^{(t)}$ is determined by the input data and the current variable $\bm{P}^{(t)}$. According to [46, 24], we consider Fenchel's dual form of the EUOT problem:

\min_{\bm{a}\in\mathbb{R}^{D},\bm{b}\in\mathbb{R}^{N}}(\alpha_{1}+\tau)\sum_{d,n=1}^{D,N}\exp\Bigl(\frac{a_{d}+b_{n}+c_{dn}^{(t)}}{\alpha_{1}+\tau}\Bigr)+\alpha_{2}\Bigl\langle\exp\Bigl(-\frac{1}{\alpha_{2}}\bm{a}\Bigr),\bm{p}_{0}\Bigr\rangle+\alpha_{3}\Bigl\langle\exp\Bigl(-\frac{1}{\alpha_{3}}\bm{b}\Bigr),\bm{q}_{0}\Bigr\rangle.    (17)
Figure 2: An illustration of the Sinkhorn-based ROTP layer. The red, blue, and black arrows correspond to the computational flows of data, model parameters, and intermediate results, respectively.
Algorithm 1 $f_{\text{rot}}(\bm{X};\bm{\theta})$ based on Sinkhorn scaling
0:  Data $\bm{X}$, model parameters $\bm{\theta}$.
1:  $\bm{P}^{(0)}=\bm{p}_{0}\bm{q}_{0}^{T}$.
2:  for $t=0,...,T-1$ do
3:     $\bm{C}^{(t)}=-\bm{X}-\alpha_{0}\bm{\Sigma}_{1}\bm{P}^{(t)}\bm{\Sigma}_{2}^{T}-\tau\log\bm{P}^{(t)}$.
4:     Set $\bm{a}^{(0)}=\bm{0}_{D}$, $\bm{b}^{(0)}=\bm{0}_{N}$, $\bm{Y}^{(0)}=-\frac{1}{\alpha_{1}^{\prime}}\bm{C}^{(t)}$.
5:     for $k=0,...,K-1$ do
6:        $\log\bm{p}=\text{LogSumExp}_{\text{row}}(\bm{Y}^{(k)})$,
7:        $\log\bm{q}=\text{LogSumExp}_{\text{col}}(\bm{Y}^{(k)})$.
8:        Dual variable update: $\bm{a}^{(k+1)}=\frac{\alpha_{2}}{\alpha_{1}^{\prime}+\alpha_{2}}\bigl(\frac{1}{\alpha_{1}^{\prime}}\bm{a}^{(k)}+\log\bm{p}_{0}-\log\bm{p}\bigr)$, $\bm{b}^{(k+1)}=\frac{\alpha_{3}}{\alpha_{1}^{\prime}+\alpha_{3}}\bigl(\frac{1}{\alpha_{1}^{\prime}}\bm{b}^{(k)}+\log\bm{q}_{0}-\log\bm{q}\bigr)$.
9:        Logarithmic scaling: $\bm{Y}^{(k+1)}=-\frac{1}{\alpha_{1}^{\prime}}\bm{C}^{(t)}+\bm{a}^{(k+1)}\bm{1}_{N}^{T}+\bm{1}_{D}(\bm{b}^{(k+1)})^{T}$.
10:    end for
11:    $\bm{P}^{(t+1)}:=\exp(\bm{Y}^{(K)})$.
12: end for
13: $\bm{P}^{*}:=\bm{P}^{(T)}$, and apply (5) to obtain $f_{\text{rot}}(\bm{X};\bm{\theta})$.

This problem can be solved by the Sinkhorn-scaling algorithm [44, 23]. In particular, the Sinkhorn-scaling algorithm solves the dual problem by the following iterative steps: i) Initialize the dual variables as $\bm{a}^{(0)}=\bm{0}_{D}$ and $\bm{b}^{(0)}=\bm{0}_{N}$. ii) In the $k$-th iteration, the current dual variables $\bm{a}^{(k)}$ and $\bm{b}^{(k)}$ are updated by

\bm{T}(\bm{a}^{(k)},\bm{b}^{(k)})=\exp\Bigl(\bm{a}^{(k)}\bm{1}_{N}^{T}+\bm{1}_{D}(\bm{b}^{(k)})^{T}-\frac{\bm{C}^{(t)}}{\alpha_{1}^{\prime}}\Bigr),
\bm{p}^{(k)}=\bm{T}(\bm{a}^{(k)},\bm{b}^{(k)})\bm{1}_{N},\quad\bm{q}^{(k)}=\bm{T}(\bm{a}^{(k)},\bm{b}^{(k)})^{T}\bm{1}_{D},
\bm{a}^{(k+1)}=\frac{\alpha_{2}}{\alpha_{1}^{\prime}+\alpha_{2}}\Bigl(\frac{1}{\alpha_{1}^{\prime}}\bm{a}^{(k)}+\log\bm{p}_{0}-\log\bm{p}^{(k)}\Bigr),
\bm{b}^{(k+1)}=\frac{\alpha_{3}}{\alpha_{1}^{\prime}+\alpha_{3}}\Bigl(\frac{1}{\alpha_{1}^{\prime}}\bm{b}^{(k)}+\log\bm{q}_{0}-\log\bm{q}^{(k)}\Bigr),    (18)

where $\alpha_{1}^{\prime}=\alpha_{1}+\tau$. iii) After $K$ steps, the dual variables converge, and the optimal transport plan is updated as

\bm{P}^{(t+1)}=\bm{T}(\bm{a}^{(K)},\bm{b}^{(K)}).    (19)

The convergence of the algorithm has been proven in [79] — with the increase of $t$, $\bm{P}^{(t)}$ converges to a stationary point. Therefore, after repeating the above process $T$ times, we set $\bm{P}^{*}:=\bm{P}^{(T)}$.

The algorithm above leads to a Sinkhorn-based ROTP layer $f_{\text{rot}}(\bm{X};\bm{\theta})$. As illustrated in Fig. 2, this layer is implemented by unrolling the above iterative scheme into $T$ stacked proximal point modules, each of which is implemented by $K$ Sinkhorn-scaling steps. Furthermore, we can leverage the logarithmic stabilization strategy [46, 47] to improve the numerical stability of the Sinkhorn-scaling algorithm — instead of updating $\bm{T}(\bm{a}^{(k)},\bm{b}^{(k)})$ directly, we update $\log\bm{T}(\bm{a}^{(k)},\bm{b}^{(k)})$. Accordingly, the exponential and scaling operations in (18) are integrated as “LogSumExp” operations, which helps avoid numerical instability issues. Algorithm 1 shows the Sinkhorn-based ROTP layer implemented with the stabilized Sinkhorn-scaling algorithm, in which $\text{LogSumExp}_{\text{col}}$ and $\text{LogSumExp}_{\text{row}}$ apply column-wise and row-wise summation, respectively.
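Algorithm 1 translates almost line-for-line into code. The following NumPy sketch (our own transcription, with hypothetical function names and placeholder values for $T$, $K$, and $\tau$) implements one forward pass of the stabilized solver and returns the transport plan, which is then plugged into (5).

```python
import numpy as np
from scipy.special import logsumexp

def sinkhorn_rotp_plan(X, p0, q0, alphas, tau=1.0, T=4, K=20):
    """Stabilized Sinkhorn solver for the ROT plan, following Algorithm 1."""
    a0, a1, a2, a3 = alphas
    a1p = a1 + tau                                   # alpha_1' = alpha_1 + tau
    D, N = X.shape
    mu1 = X.mean(axis=1, keepdims=True)
    mu2 = X.mean(axis=0, keepdims=True)
    S1 = (X - mu1) @ (X - mu1).T / N                 # feature-level covariance
    S2 = (X - mu2).T @ (X - mu2) / D                 # sample-level covariance
    P = np.outer(p0, q0)                             # P^(0) = p0 q0^T
    for _ in range(T):                               # proximal point modules
        C = -X - a0 * (S1 @ P @ S2.T) - tau * np.log(P)
        a, b = np.zeros(D), np.zeros(N)
        Y = -C / a1p
        for _ in range(K):                           # stabilized Sinkhorn steps
            log_p = logsumexp(Y, axis=1)             # row-wise LogSumExp
            log_q = logsumexp(Y, axis=0)             # column-wise LogSumExp
            a = a2 / (a1p + a2) * (a / a1p + np.log(p0) - log_p)
            b = a3 / (a1p + a3) * (b / a1p + np.log(q0) - log_q)
            Y = -C / a1p + a[:, None] + b[None, :]   # logarithmic scaling
        P = np.exp(Y)
    return P

# The pooled vector then follows Eq. (5):
# f = (X * (P / P.sum(axis=1, keepdims=True))).sum(axis=1)
```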

4.2 Bregman ADMM-based ROTP layer

The Sinkhorn-based ROTP layer requires the smoothness regularizer $\text{R}(\bm{P})$ to be entropic, which limits its generalizability. Additionally, even when applying the logarithmic stabilization strategy, the Sinkhorn-scaling algorithm still has a high risk of numerical instability because the parameters $\alpha_{0}$ and $\alpha_{1}$ may vary over a wide range during training. To overcome these challenges, we apply the Bregman alternating direction method of multipliers (Bregman ADMM or BADMM) to build another ROTP layer. In particular, we first rewrite (3) in an equivalent form by introducing three auxiliary variables $\bm{S}$, $\bm{\mu}$, and $\bm{\eta}$:

\min_{\bm{P},\bm{S},\bm{\mu},\bm{\eta}}\langle-\bm{X},\bm{P}\rangle+\alpha_{0}\langle C(\bm{X},\bm{S}),\bm{P}\rangle+\alpha_{1}\text{R}(\bm{S})+\alpha_{2}\text{KL}(\bm{\mu}|\bm{p}_{0})+\alpha_{3}\text{KL}(\bm{\eta}|\bm{q}_{0})
\quad s.t.~\bm{P}=\bm{S},~\bm{P}\bm{1}_{N}=\bm{\mu},~\bm{S}^{T}\bm{1}_{D}=\bm{\eta}.    (20)

These three auxiliary variables correspond to the joint distribution $\bm{P}$ and its marginals. The problem can be further rewritten in a Bregman augmented Lagrangian form by introducing three dual variables $\bm{Z}$, $\bm{z}_{1}$, and $\bm{z}_{2}$ for the three constraints in (20), respectively. For the ROT problem with auxiliary variables, the Bregman augmented Lagrangian form is

\min_{\bm{P},\bm{S},\bm{\mu},\bm{\eta},\bm{Z},\bm{z}_{1},\bm{z}_{2}}\langle-\bm{X}+\alpha_{0}C(\bm{X},\bm{S}),\bm{P}\rangle+\underbrace{\alpha_{1}\text{R}(\bm{S},\bm{P})}_{\text{Regularizer 1}}+\underbrace{\langle\bm{Z},\bm{P}-\bm{S}\rangle+\rho\text{Div}(\bm{P},\bm{S})}_{\text{Constraint 1, for $\bm{P}$ and $\bm{S}$}}
+\underbrace{\alpha_{2}\text{KL}(\bm{\mu}|\bm{p}_{0})}_{\text{Regularizer 2}}+\underbrace{\langle\bm{z}_{1},\bm{\mu}-\bm{P}\bm{1}_{N}\rangle+\rho\text{Div}(\bm{\mu},\bm{P}\bm{1}_{N})}_{\text{Constraint 2, for $\bm{\mu}$ and $\bm{P}$}}
+\underbrace{\alpha_{3}\text{KL}(\bm{\eta}|\bm{q}_{0})}_{\text{Regularizer 3}}+\underbrace{\langle\bm{z}_{2},\bm{\eta}-\bm{S}^{T}\bm{1}_{D}\rangle+\rho\text{Div}(\bm{\eta},\bm{S}^{T}\bm{1}_{D})}_{\text{Constraint 3, for $\bm{\eta}$ and $\bm{S}$}}.    (21)

Here, $\text{Div}(\cdot,\cdot)$ denotes a Bregman divergence term, which is implemented as the KL-divergence following [26, 27], and its significance is controlled by $\rho>0$. The last three lines of (21) contain the Bregman augmented Lagrangian terms, which correspond to the three constraints in (20). $\text{R}(\bm{S},\bm{P})$ corresponds to the smoothness regularizer: when applying the entropic regularizer, we set $\text{R}(\bm{S},\bm{P})=\langle\bm{S},\log\bm{S}-\bm{1}\rangle$; when applying the quadratic regularizer, we set $\text{R}(\bm{S},\bm{P})=\langle\bm{S},\bm{P}\rangle$.

In particular, we solve the ROT problem by alternating optimization. At the $t$-th iteration, we first update $\bm{P}$ while fixing the other variables. We can ignore Constraint 3 and Regularizers 2 and 3 (which are irrelevant to $\bm{P}$) and write Constraint 2 explicitly. The problem becomes:

$$
\begin{aligned}
\min_{\bm{P}\in\Pi(\bm{\mu}^{(t)},\cdot)}~&\langle-\bm{X}-\alpha_{0}\bm{\Sigma}_{1}\bm{S}^{(t)}\bm{\Sigma}_{2}^{T},\bm{P}\rangle+\alpha_{1}\text{R}(\bm{S}^{(t)},\bm{P})\\
&+\langle\bm{Z}^{(t)},\bm{P}-\bm{S}^{(t)}\rangle+\rho\text{KL}(\bm{P}|\bm{S}^{(t)}),
\end{aligned}
\tag{22}
$$

where $\Pi(\bm{\mu}^{(t)},\cdot)=\{\bm{P}>\bm{0}~|~\bm{P}\bm{1}_{N}=\bm{\mu}^{(t)}\}$ is the one-side marginal constraint on $\bm{P}$. We can derive the closed-form solution of this problem from the first-order optimality condition:

$$
\log\bm{P}^{(t+1)}=(\log\bm{\mu}^{(t)}-\text{LogSumExp}_{\text{row}}(\bm{Y}))\bm{1}_{N}^{T}+\bm{Y},
\tag{23}
$$

where
$$
\bm{Y}=\begin{cases}
\frac{\bm{X}+\alpha_{0}\bm{\Sigma}_{1}\bm{S}^{(t)}\bm{\Sigma}_{2}^{T}-\bm{Z}^{(t)}+\rho\log\bm{S}^{(t)}}{\rho},&\text{entropic }\text{R}(\bm{S},\bm{P}),\\
\frac{\bm{X}+\alpha_{0}\bm{\Sigma}_{1}\bm{S}^{(t)}\bm{\Sigma}_{2}^{T}-\alpha_{1}\bm{S}^{(t)}-\bm{Z}^{(t)}+\rho\log\bm{S}^{(t)}}{\rho},&\text{quadratic }\text{R}(\bm{S},\bm{P}).
\end{cases}
$$
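As a sanity check of (23), a sketch of this update in PyTorch might look as follows; the names (`X`, `log_S`, `Z`, `log_mu`, `Sigma1`, `Sigma2`) mirror the symbols above and are illustrative.

```python
import torch

def update_log_P(X, log_S, Z, log_mu, Sigma1, Sigma2,
                 alpha0, alpha1, rho, regularizer="entropic"):
    """One closed-form update of log P following (23); a sketch."""
    S = log_S.exp()
    Y = X + alpha0 * Sigma1 @ S @ Sigma2.T - Z + rho * log_S
    if regularizer == "quadratic":
        Y = Y - alpha1 * S          # extra term of the quadratic case
    Y = Y / rho
    # Enforce the row-marginal constraint P 1_N = mu in the log domain
    return (log_mu - torch.logsumexp(Y, dim=1))[:, None] + Y
```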

Given $\bm{P}^{(t+1)}$, we can update the auxiliary variable $\bm{S}$ in a similar manner: we ignore Constraint 2 and Regularizers 2 and 3, and write Constraint 3 explicitly. The optimization problem of $\bm{S}$ then becomes

$$
\begin{aligned}
\min_{\bm{S}\in\Pi(\cdot,\bm{\eta}^{(t)})}~&\langle-\alpha_{0}\bm{\Sigma}_{1}^{T}\bm{P}^{(t+1)}\bm{\Sigma}_{2},\bm{S}\rangle+\alpha_{1}\text{R}(\bm{S},\bm{P}^{(t+1)})\\
&+\langle\bm{Z}^{(t)},\bm{P}^{(t+1)}-\bm{S}\rangle+\rho\text{KL}(\bm{S}|\bm{P}^{(t+1)}),
\end{aligned}
\tag{24}
$$

where $\Pi(\cdot,\bm{\eta}^{(t)})=\{\bm{S}>\bm{0}~|~\bm{S}^{T}\bm{1}_{D}=\bm{\eta}^{(t)}\}$. Similarly, we have

$$
\log\bm{S}^{(t+1)}=\bm{1}_{D}(\log\bm{\eta}^{(t)}-\text{LogSumExp}_{\text{col}}(\bm{Y}))^{T}+\bm{Y},
\tag{25}
$$

where
$$
\bm{Y}=\begin{cases}
\frac{\bm{Z}^{(t)}+\alpha_{0}\bm{\Sigma}_{1}^{T}\bm{P}^{(t+1)}\bm{\Sigma}_{2}+\rho\log\bm{P}^{(t+1)}}{\alpha_{1}+\rho},&\text{entropic }\text{R}(\bm{S},\bm{P}),\\
\frac{\alpha_{0}\bm{\Sigma}_{1}^{T}\bm{P}^{(t+1)}\bm{\Sigma}_{2}-\alpha_{1}\bm{P}^{(t+1)}+\bm{Z}^{(t)}+\rho\log\bm{P}^{(t+1)}}{\rho},&\text{quadratic }\text{R}(\bm{S},\bm{P}).
\end{cases}
$$

Given $\bm{P}^{(t+1)}$ and $\bm{S}^{(t+1)}$, we can update $\bm{\mu}$ and $\bm{\eta}$ by solving their corresponding Bregman augmented Lagrangian optimization problems, whose solutions also have closed forms based on the first-order optimality condition. In particular, when updating $\bm{\mu}$, we ignore the terms irrelevant to $\bm{\mu}$ in (21) and leverage the constraint $\bm{S}^{(t+1)}\bm{1}_{N}=\bm{\mu}^{(t)}$ explicitly. Accordingly, the problem becomes

$$
\begin{aligned}
&\min_{\bm{\mu}\in\Delta^{D-1}}~\alpha_{2}\text{KL}(\bm{\mu}|\bm{p}_{0})+\langle\bm{z}_{1}^{(t)},\bm{\mu}-\bm{\mu}^{(t)}\rangle+\rho\text{KL}(\bm{\mu}|\bm{\mu}^{(t)})\\
&\Rightarrow\bm{\mu}^{(t+1)}=\sigma\Bigl(\underbrace{\frac{\rho\log\bm{\mu}^{(t)}+\alpha_{2}\log\bm{p}_{0}-\bm{z}_{1}^{(t)}}{\rho+\alpha_{2}}}_{\text{denoted as }\bm{y}}\Bigr)\\
&\xRightarrow{\text{logarithmic update}}~\log\bm{\mu}^{(t+1)}=\bm{y}-\text{LogSumExp}(\bm{y}).
\end{aligned}
\tag{26}
$$

Here, the softmax $\sigma(\cdot)$ and $\text{LogSumExp}(\cdot)$ operate on vectors.

Similarly, when updating $\bm{\eta}$, we ignore the terms irrelevant to $\bm{\eta}$ in (21) and leverage the constraint $(\bm{P}^{(t+1)})^{T}\bm{1}_{D}=\bm{\eta}^{(t)}$ explicitly. The problem becomes

$$
\begin{aligned}
&\min_{\bm{\eta}\in\Delta^{N-1}}~\alpha_{3}\text{KL}(\bm{\eta}|\bm{q}_{0})+\langle\bm{z}_{2}^{(t)},\bm{\eta}-\bm{\eta}^{(t)}\rangle+\rho\text{KL}(\bm{\eta}|\bm{\eta}^{(t)})\\
&\Rightarrow\bm{\eta}^{(t+1)}=\sigma\Bigl(\underbrace{\frac{\rho\log\bm{\eta}^{(t)}+\alpha_{3}\log\bm{q}_{0}-\bm{z}_{2}^{(t)}}{\rho+\alpha_{3}}}_{\text{denoted as }\bm{y}}\Bigr)\\
&\xRightarrow{\text{logarithmic update}}~\log\bm{\eta}^{(t+1)}=\bm{y}-\text{LogSumExp}(\bm{y}).
\end{aligned}
\tag{27}
$$
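Both marginal updates share the same closed form (a log-domain softmax), so a single sketch covers (26) and (27); the names are illustrative.

```python
import torch

def update_log_marginal(log_m, log_prior, z, alpha, rho):
    """Closed-form update shared by (26) and (27); a sketch.
    For mu: (log_m, log_prior, z, alpha) = (log mu^(t), log p0, z1, alpha2).
    For eta: (log eta^(t), log q0, z2, alpha3)."""
    y = (rho * log_m + alpha * log_prior - z) / (rho + alpha)
    return y - torch.logsumexp(y, dim=0)   # log softmax(y), stays on the simplex
```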

Finally, we update the dual variables by

$$
\begin{aligned}
&\bm{Z}^{(t+1)}=\bm{Z}^{(t)}+\rho(\bm{P}^{(t+1)}-\bm{S}^{(t+1)}),\\
&\bm{z}_{1}^{(t+1)}=\bm{z}_{1}^{(t)}+\rho(\bm{\mu}^{(t+1)}-\bm{P}^{(t+1)}\bm{1}_{N}),\\
&\bm{z}_{2}^{(t+1)}=\bm{z}_{2}^{(t)}+\rho(\bm{\eta}^{(t+1)}-(\bm{S}^{(t+1)})^{T}\bm{1}_{D}).
\end{aligned}
\tag{28}
$$
Figure 3: An illustration of the BADMM-based ROTP layer. The red, blue, and black arrows correspond to the computational flows of data, model parameters, and intermediate results, respectively.
Algorithm 2 $f_{\text{rot}}(\bm{X};\bm{\theta})$ based on Bregman ADMM
0: Data $\bm{X}$ and parameters $\bm{\theta}$.
1: Initialize the primal and auxiliary variables: $\log\bm{P}^{(0)}=\log\bm{S}^{(0)}=\log(\bm{p}_{0}\bm{q}_{0}^{T})$, $\log\bm{\mu}^{(0)}=\log\bm{p}_{0}$, $\log\bm{\eta}^{(0)}=\log\bm{q}_{0}$.
2: Initialize the dual variables: $\bm{Z}^{(0)}=\bm{0}_{D\times N}$, $\bm{z}_{1}^{(0)}=\bm{0}_{D}$, $\bm{z}_{2}^{(0)}=\bm{0}_{N}$.
3: for $t=0,\dots,T-1$ do
4:   Derive $\log\bm{P}^{(t+1)}$ by (23).
5:   Derive $\log\bm{S}^{(t+1)}$ by (25).
6:   Derive $\log\bm{\mu}^{(t+1)}$ by (26).
7:   Derive $\log\bm{\eta}^{(t+1)}$ by (27).
8:   Update the dual variables $\bm{Z},\bm{z}_{1},\bm{z}_{2}$ by (28).
9: end for
10: $\bm{P}^{*}:=\bm{P}^{(T)}$; apply (5) to obtain $f_{\text{rot}}(\bm{X};\bm{\theta})$.

The above Bregman ADMM algorithm can be viewed as a variant of the proximal point method, in which the proximal term is implemented as a set of Bregman divergence terms. When deriving $\bm{P}^{(t+1)}$, our Bregman ADMM applies the auxiliary variable $\bm{S}^{(t)}$, rather than the previous estimate $\bm{P}^{(t)}$, to regularize it. Treating one iteration of the above steps as a module, we can implement our ROTP layer by stacking such modules, as shown in Fig. 3. Additionally, as shown in (23)-(27), both the primal and auxiliary variables can be updated in their logarithmic forms, which improves the numerical stability of our algorithm. In summary, Algorithm 2 shows the scheme of our BADMM-based ROTP layer.
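To make Algorithm 2 concrete, the following is a self-contained PyTorch sketch of the full loop for the simplified case $\alpha_{0}=0$ with the entropic smoothness regularizer; the map from $\bm{P}^{*}$ to the pooled output via (5) is left abstract, and all variable names are illustrative.

```python
import torch

def badmm_rotp_plan(X, log_p0, log_q0, alpha1=0.1, alpha2=1.0, alpha3=1.0,
                    rho=1.0, T=8):
    """Sketch of Algorithm 2 with alpha0 = 0 and the entropic regularizer."""
    D, N = X.shape
    log_P = log_p0[:, None] + log_q0[None, :]        # log(p0 q0^T)
    log_S = log_P.clone()
    log_mu, log_eta = log_p0.clone(), log_q0.clone()
    Z, z1, z2 = torch.zeros(D, N), torch.zeros(D), torch.zeros(N)
    for _ in range(T):
        # (23): update log P under the row-marginal constraint P 1_N = mu
        Y = (X - Z + rho * log_S) / rho
        log_P = (log_mu - torch.logsumexp(Y, dim=1))[:, None] + Y
        # (25): update log S under the column-marginal constraint S^T 1_D = eta
        Y = (Z + rho * log_P) / (alpha1 + rho)
        log_S = (log_eta - torch.logsumexp(Y, dim=0))[None, :] + Y
        # (26)-(27): update the marginals in the log domain
        y = (rho * log_mu + alpha2 * log_p0 - z1) / (rho + alpha2)
        log_mu = y - torch.logsumexp(y, dim=0)
        y = (rho * log_eta + alpha3 * log_q0 - z2) / (rho + alpha3)
        log_eta = y - torch.logsumexp(y, dim=0)
        # (28): dual ascent on the three constraints
        P, S = log_P.exp(), log_S.exp()
        Z = Z + rho * (P - S)
        z1 = z1 + rho * (log_mu.exp() - P.sum(dim=1))
        z2 = z2 + rho * (log_eta.exp() - S.sum(dim=0))
    return log_P.exp()                               # P*, fed into (5)

X = torch.randn(100, 500)                            # 100-dim features x 500 samples
P_star = badmm_rotp_plan(X, torch.log(torch.ones(100) / 100),
                         torch.log(torch.ones(500) / 500))
```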

5 Further Analysis and Comparisons

5.1 Comparisons for various ROTP layers

Applying different smoothness regularizers and algorithms, we consider three ROTP layers: the Sinkhorn-based ROTP layer, denoted as ROTP$_{\text{S}}$; the BADMM-based ROTP layer with the entropic smoothness regularizer, denoted as ROTP$_{\text{B-E}}$; and the BADMM-based ROTP layer with the quadratic smoothness regularizer, denoted as ROTP$_{\text{B-Q}}$. In this subsection, we analyze their convergence, computational complexity, approximation power, and numerical stability in depth.

5.1.1 Convergence

As illustrated in Figs. 2 and 3, each ROTP layer is implemented by stacking $T$ feed-forward modules. Given the same data matrix and fixed model parameters, we solve the ROT problem in (3) by the different ROTP layers and record the change of the expectation term $\langle-\bm{X},\bm{P}\rangle$ as $T$ increases. The comparison is shown in Fig. 4. All three layers make their objective functions converge when using more than 16 feed-forward modules. However, applying different smoothness regularizers and algorithms leads to different convergence rates and optimization trajectories: given the same number of modules, the three layers often converge to different optima.

Figure 4: Given a batch of 50 sample sets, in which each sample set contains five hundred 100-dimensional samples, we compare the ROTP layers on their convergence.

5.1.2 Computational complexity

Each Sinkhorn-scaling module contains $K$ Sinkhorn iterations, while each Bregman ADMM module corresponds to a one-step update of the primal, auxiliary, and dual variables. As a result, each Sinkhorn-scaling module involves $2K$ LogSumExp operations (the most time-consuming part), while each BADMM module merely requires two. The computational complexity of the BADMM-based layer is thus lower than that of the Sinkhorn-based layer. In particular, given a set of $N$ $D$-dimensional samples, the similarity matrices ($\bm{\Sigma}_{1}$ and $\bm{\Sigma}_{2}$) are computed with complexity $\mathcal{O}(N^{2}D+D^{2}N)$. Taking the samples and the similarity matrices as input, the computational complexity of the Sinkhorn-based ROTP layer is $\mathcal{O}(T(N^{2}D+D^{2}N+KND))$, where $N^{2}D+D^{2}N$ corresponds to the computation of $\bm{\Sigma}_{1}\bm{P}^{(t)}\bm{\Sigma}_{2}^{T}$ per step and $KND$ corresponds to the $K$ Sinkhorn iterations within each Sinkhorn-scaling module. The computational complexity of the BADMM-based ROTP layer is $\mathcal{O}(T(N^{2}D+D^{2}N))$, where $N^{2}D+D^{2}N$ corresponds to the computation of $\bm{\Sigma}_{1}\bm{S}^{(t)}\bm{\Sigma}_{2}^{T}$ (and $\bm{\Sigma}_{1}^{T}\bm{P}^{(t+1)}\bm{\Sigma}_{2}$). Figs. 5(a) and 5(b) verify this analysis further: the runtime of the two ROTP layers under different $T$'s and $N$'s indicates that the Sinkhorn-based ROTP layer is slower than the BADMM-based one.

Figure 5: Comparisons of the two ROTP layers on their feed-forward runtime on a single CPU: (a) runtime w.r.t. the number of iterations ($D=5$, $N=50$, $K=5$, $\alpha_{0}>0$); (b) runtime w.r.t. the number of samples ($D=5$, $T=50$, $\alpha_{0}>0$); (c) runtime w.r.t. the number of iterations ($D=5$, $N=50$, $\alpha_{0}=0$); (d) runtime w.r.t. the number of samples ($D=5$, $T$ or $K=50$, $\alpha_{0}=0$). We plot the average runtime and its standard deviation over 20 trials. For the BADMM-based ROTP layer, we only consider the entropic smoothness regularizer because its runtime is insensitive to the type of regularizer.

When setting $\alpha_{0}=0$, the ROT problem in (3) degrades to a classic unbalanced optimal transport (UOT) problem. In this situation, the computation of $\bm{\Sigma}_{1}\bm{P}\bm{\Sigma}_{2}^{T}$ is avoided, which lowers the complexity of both layers. In particular, when $\alpha_{0}=0$, the Sinkhorn-based ROTP layer only requires a single Sinkhorn-scaling module with $K$ iterations to obtain the optimal transport matrix, which avoids the nested iterative optimization; its complexity becomes $\mathcal{O}(KND)$. The BADMM-based ROTP layer, however, still requires $T$ BADMM modules, so its complexity is $\mathcal{O}(TND)$ when $\alpha_{0}=0$. Figs. 5(c) and 5(d) show that when $\alpha_{0}=0$, the Sinkhorn-based ROTP layer is faster than the BADMM-based ROTP layer.

5.1.3 Precision on approximating existing pooling layers

Proposition 1 demonstrates that, in theory, our ROTP layer can be equivalent to some existing pooling operations under specific settings. In practice, however, implementing the ROTP layer with different algorithms leads to different approximation precision. As shown in Fig. 6, both the Sinkhorn-based and the BADMM-based ROTP layers reproduce the functionality of mean-pooling perfectly. However, the Sinkhorn-based ROTP layer approximates max-pooling with higher accuracy, while the BADMM-based ROTP layer works better at approximating attention-pooling [2]. In other words, when approximating a specific global pooling operation with an ROTP layer, we should take the influence of the model architecture into account.

Figure 6: In each column, the matrices from top to bottom are the ground truth and the $\bm{P}^{*}$'s obtained by different ROTP layers. Here, we set $\bm{X}\in\mathbb{R}^{5\times 10}$ with entries drawn from $\mathcal{N}(0,1)$, $\alpha_{1}=\alpha_{2}=\alpha_{3}=10^{4}$ in (a, c), and $\alpha_{1}=\alpha_{3}=0.01$, $\alpha_{2}=10^{4}$ in (b).
Figure 7: Given $\bm{X}\in\mathbb{R}^{5\times 10}$, we learn $\bm{P}^{*}$'s under different configurations and calculate $\|\bm{P}^{*}\|_{1}$'s: (a) ROTP$_{\text{S}}$ ($\alpha_{0}=0$); (b) ROTP$_{\text{S}}$ ($\alpha_{0}=0.1$); (c) ROTP$_{\text{B-E/Q}}$ ($\alpha_{0}=0$); (d) ROTP$_{\text{B-E/Q}}$ ($\alpha_{0}=0.1$). Each subfigure shows the $\|\bm{P}^{*}\|_{1}$'s, and the white regions correspond to NaNs.

5.1.4 Numerical stability

As aforementioned, one primary motivation for designing the BADMM-based ROTP layer is to overcome the numerical instability of the Sinkhorn-based ROTP layer. For the two layers, we set $\alpha_{2}=\alpha_{3}$ and select $\alpha_{1},\alpha_{2},\alpha_{3}$ from $\{10^{-5},...,10^{4}\}$. Under these configurations, we derive 100 $\bm{P}^{*}$'s accordingly. For each layer, we verify its numerical stability by checking whether $\|\bm{P}^{*}\|_{1}=\sum_{d,n}|p_{dn}|\approx 1$ and whether $\bm{P}^{*}$ contains NaN elements. Figs. 7(a) and 7(b) show that the Sinkhorn-based ROTP layer only works under some configurations and yields NaNs in many others. When $\alpha_{0}>0$, i.e., when solving the full ROT problem, the numerical stability of the Sinkhorn-based method becomes even worse. In the following experiments, we have to set $\alpha_{0}=0$, $\alpha_{1}>0.1$, and $\alpha_{2},\alpha_{3}\in(10^{-5},10)$ for the Sinkhorn-based ROTP layer to ensure a stable training process. In contrast, our BADMM-based ROTP layer has much better numerical stability. As shown in Figs. 7(c)-7(d), no matter whether $\alpha_{0}=0$ and whichever smoothness regularizer is applied, the BADMM-based ROTP layer not only keeps $\|\bm{P}^{*}\|_{1}\approx 1$ under all configurations but also successfully avoids NaN elements.
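The per-configuration check behind Fig. 7 is simple; a minimal sketch (with illustrative names) is:

```python
import torch

def plan_is_stable(P, tol=1e-3):
    """Diagnostics behind Fig. 7 (a sketch): the optimal transport plan
    should be a joint distribution, so we flag NaNs and ||P||_1 far from 1."""
    if torch.isnan(P).any():
        return False
    return abs(P.abs().sum().item() - 1.0) < tol
```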

5.2 Comparisons for various OT-based methods

Compared with existing optimal transport-based pooling methods, e.g., OTK [59], WEGL [63], and SWE [62], our ROTP layers have better flexibility and generalization power. The main differences between our ROTP layers and other OT-based pooling methods can be categorized into the following three points.

Firstly, our ROTP layers are grounded in an expectation-maximization framework, which defines optimal transport plans across sample indices and feature dimensions. These optimal transport plans can be interpreted as joint distributions of the sample indices and the feature dimensions, which indicate significant "sample-feature" pairs. In contrast, existing OT-based pooling methods define optimal transport plans in the sample space; their transport plans push observed samples forward to some learnable sample clusters rather than weighting "sample-feature" pairs.

Secondly, on the implementation side, our ROTP layers correspond to a generalized ROT problem, which jointly considers the smoothness of the objective function, the uncertainty of the prior distributions, and the structural relations between samples. This ROT problem leads to a generalized pooling framework that unifies typical pooling operations, and the ROTP layers can be stacked to build hierarchical ROTP modules. Moreover, the design of the ROTP layers is flexible and can be based on various algorithms, and the weights of the ROT problem's regularizers and the prior distributions are learnable parameters. In contrast, existing OT-based pooling methods only consider the classic entropic OT problem (whose entropic term has a predefined weight) [59, 60] or the sliced Wasserstein problem [62]. Accordingly, these methods rely solely on the Sinkhorn-scaling algorithm or the random projection method, and they cannot provide a generalized framework as our ROTP does.

Finally, existing OT-based pooling methods are designed for specific applications or data formats, such as graph classification [63, 62] and set fusion [59, 60]. In the following experiments, we will test our ROTP layers and demonstrate their feasibility in various applications.

6 Experiments

We demonstrate the effectiveness and superiority of our ROTP layers (ROTP$_{\text{S}}$, ROTP$_{\text{B-E}}$, and ROTP$_{\text{B-Q}}$) in various machine learning tasks, including multi-instance learning, graph classification, and image classification. Additionally, we build HROTP modules based on the ROTP layers and apply them to a typical graph set prediction task: drug-drug interaction classification. For each learning task, we consider several datasets. For ROTP$_{\text{S}}$, we set $\alpha_{0}=0$ to avoid numerical instability. For ROTP$_{\text{B-E}}$ and ROTP$_{\text{B-Q}}$, we make $\alpha_{0}$ learnable.

The baselines we consider include i) classic pooling operations, i.e., Add-Pooling, Mean-Pooling, and Max-Pooling; ii) mixed pooling operations, i.e., the Mixed Mean-Max and the Gated Mean-Max in [10]; iii) learnable global pooling layers, i.e., DeepSet [15], Set2Set [14], DynamicPooling [1], GNP [28], and the Attention-Pooling and Gated Attention in [2]; iv) attention-pooling methods for graphs, i.e., SAGPooling [33] and ASAPooling [32]; and v) OT-based pooling methods, i.e., OTK [59], WEGL [63], and SWE [62]. All the pooling methods are trained and tested on a server with two Nvidia RTX 3090 GPUs, and their key hyperparameters are set by grid search. For our ROTP layers, the number of feed-forward modules $T$ is the key hyperparameter. According to the convergence analysis in Fig. 4, we set the number of modules in the range $[4,16]$ in the following experiments.

6.1 Evaluation of ROTP layers

6.1.1 Multi-instance learning

We consider three MIL tasks, which correspond to a disease diagnosis dataset (Messidor [80]) and two gene ontology categorization datasets (Component and Function [81]). For each dataset, we learn a bag-level classifier, which takes a bag of instances as input, embeds the instances, merges their embeddings via pooling, and finally predicts the bag's label. We use AttentionDeepMIL [2], a representative bag-level classifier, as the backbone model and plug different pooling layers into it. When training the model, we apply the Adam optimizer [82] with a weight decay regularizer. The hyperparameters of the optimizer are set as follows: the learning rate is 0.0005, the weight decay is 0.005, the number of epochs is 50, and the batch size is 128. In this experiment, we apply four feed-forward modules to build each ROTP layer. A minimal sketch of this training setup is shown below.
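The following sketch mirrors the stated optimizer settings on a toy bag; `BagClassifier` is a hypothetical stand-in for AttentionDeepMIL with an ROTP layer, and mean-pooling stands in for the pooling under test.

```python
import torch
from torch import nn

class BagClassifier(nn.Module):
    """Hypothetical stand-in for the bag-level backbone; the pooling step
    (here a mean) is where an ROTP layer would be plugged in."""
    def __init__(self, dim=200):
        super().__init__()
        self.embed = nn.Linear(dim, 64)
        self.head = nn.Linear(64, 1)
    def forward(self, bag):                      # bag: (num_instances, dim)
        h = torch.relu(self.embed(bag))
        z = h.mean(dim=0)                        # stand-in for the pooling layer
        return torch.sigmoid(self.head(z))

model = BagClassifier()
opt = torch.optim.Adam(model.parameters(), lr=5e-4, weight_decay=5e-3)
bag, label = torch.randn(12, 200), torch.ones(1)  # one toy bag and its label
for epoch in range(50):
    opt.zero_grad()
    loss = nn.functional.binary_cross_entropy(model(bag), label)
    loss.backward()
    opt.step()
```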

TABLE I: Comparison of MIL accuracy±Std. (%) for different pooling layers.
Dataset | Messidor | Component | Function
$D$ | 687 | 200 | 200
#Positive bags | 654 | 423 | 443
#Negative bags | 546 | 2,707 | 4,799
#Instances | 12,352 | 36,894 | 55,536
Min. bag size | 8 | 1 | 1
Max. bag size | 12 | 53 | 51
Add | 74.33±2.56 | 93.35±0.98 | 96.26±0.48
Mean | 74.42±2.47 | 93.32±0.99 | 96.28±0.66
Max | 73.92±3.00 | 93.23±0.76 | 95.94±0.48
DeepSet | 74.42±2.87 | 93.29±0.95 | 96.45±0.51
Mixed | 73.42±2.29 | 93.45±0.61 | 96.41±0.53
GatedMixed | 73.25±2.38 | 93.03±1.02 | 96.22±0.65
Set2Set | 73.58±3.74 | 93.19±0.95 | 96.43±0.56
Attention | 74.25±3.67 | 93.22±1.02 | 96.31±0.66
GatedAtt | 73.67±2.23 | 93.42±0.91 | 96.51±0.77
DynamicP | 73.16±2.12 | 93.26±1.30 | 96.47±0.58
GNP | 73.54±3.68 | 92.86±1.96 | 96.10±1.03
OTK | 74.78±2.89 | 93.19±0.93 | 96.31±1.02
SWE | 74.46±3.72 | 93.32±1.26 | 96.42±0.88
ROTP$_{\text{S}}$ | 75.42±2.96 | 93.29±0.83 | 96.62±0.48
ROTP$_{\text{B-E}}$ ($\alpha_{0}=0$) | 74.83±2.07 | 93.16±1.02 | 96.17±0.43
ROTP$_{\text{B-Q}}$ ($\alpha_{0}=0$) | 75.08±2.06 | 93.13±0.94 | 96.09±0.46
ROTP$_{\text{B-E}}$ (learn $\alpha_{0}$) | 75.33±1.96 | 93.16±1.08 | 96.22±0.44
ROTP$_{\text{B-Q}}$ (learn $\alpha_{0}$) | 75.17±2.45 | 93.45±0.96 | 96.22±0.48
* The top-3 results are bolded and the best result is in red.

TABLE II: Comparison of graph classification accuracy±Std. (%) for different pooling layers.
Dataset | NCI1 | PROTEINS | MUTAG | COLLAB | RDT-B | RDT-M5K | IMDB-B | IMDB-M
#Graphs | 4,110 | 1,113 | 188 | 5,000 | 2,000 | 4,999 | 1,000 | 1,500
Average #Nodes | 29.87 | 39.06 | 17.93 | 74.49 | 429.63 | 508.52 | 19.77 | 13.00
Average #Edges | 32.30 | 72.82 | 19.79 | 2,457.78 | 497.75 | 594.87 | 96.53 | 65.94
#Classes | 2 | 2 | 2 | 3 | 2 | 5 | 2 | 3
Add | 67.96±0.43 | 72.97±0.54 | 89.05±0.86 | 71.06±0.43 | 80.00±1.49 | 50.16±0.97 | 70.18±0.87 | 47.56±0.56
Mean | 64.82±0.52 | 66.09±0.64 | 86.53±1.62 | 72.35±0.44 | 83.62±1.18 | 52.44±1.24 | 70.34±0.38 | 48.65±0.91
Max | 65.95±0.76 | 72.27±0.33 | 85.90±1.68 | 73.07±0.57 | 82.62±1.25 | 44.34±1.93 | 70.24±0.54 | 47.80±0.54
DeepSet | 66.28±0.72 | 73.76±0.47 | 87.84±0.71 | 69.74±0.66 | 82.91±1.37 | 47.45±0.54 | 70.84±0.71 | 48.05±0.71
Mixed | 66.46±0.74 | 72.25±0.45 | 87.30±0.87 | 73.22±0.35 | 84.36±2.62 | 46.67±1.63 | 71.28±0.26 | 48.07±0.25
GatedMixed | 63.86±0.76 | 69.40±1.93 | 87.94±1.28 | 71.94±0.40 | 80.60±3.89 | 44.78±4.53 | 70.96±0.60 | 48.09±0.44
Set2Set | 65.10±1.12 | 68.61±1.44 | 87.77±0.86 | 72.31±0.73 | 80.08±5.72 | 49.85±2.77 | 70.36±0.85 | 48.30±0.54
Attention | 64.35±0.61 | 67.70±0.95 | 88.08±1.22 | 72.57±0.41 | 81.55±4.39 | 51.85±0.66 | 70.60±0.38 | 47.83±0.78
GatedAtt | 64.66±0.52 | 68.16±0.90 | 86.91±1.79 | 72.31±0.37 | 82.55±1.96 | 51.47±0.82 | 70.52±0.31 | 48.67±0.35
DynamicP | 62.11±0.27 | 65.86±0.85 | 85.40±2.81 | 70.78±0.88 | 67.51±1.82 | 32.11±3.85 | 69.84±0.73 | 47.59±0.48
GNP | 68.20±0.48 | 73.44±0.61 | 88.37±1.25 | 72.80±0.58 | 81.93±2.23 | 51.80±0.61 | 70.34±0.83 | 48.85±0.81
ASAP | 68.09±0.42 | 70.42±1.45 | 87.68±1.42 | 68.20±2.37 | 73.91±1.50 | 44.58±0.44 | 68.33±2.50 | 43.92±1.13
SAGP | 67.48±0.65 | 72.63±0.44 | 87.88±2.22 | 70.19±0.55 | 74.12±2.86 | 46.00±1.74 | 70.34±0.74 | 47.04±1.22
OTK | 67.96±0.55 | 69.52±0.76 | 86.90±1.83 | 71.35±0.91 | 74.28±1.39 | 50.57±1.20 | 70.94±0.79 | 48.41±0.89
SWE | 68.06±0.98 | 70.09±1.22 | 85.68±2.07 | 72.17±1.29 | 79.30±3.94 | 51.11±1.55 | 70.34±1.05 | 48.93±1.34
WEGL | 68.16±0.62 | 71.58±0.94 | 88.68±1.66 | 72.55±0.69 | 82.80±1.73 | 52.03±0.60 | 71.94±0.75 | 49.20±0.87
ROTP$_{\text{S}}$ | 68.27±1.06 | 73.10±0.22 | 88.84±1.21 | 71.20±0.55 | 81.54±1.38 | 51.00±0.61 | 70.74±0.80 | 47.87±0.43
ROTP$_{\text{B-E}}$ ($\alpha_{0}=0$) | 66.23±0.50 | 67.71±1.70 | 86.82±2.02 | 73.86±0.44 | 86.80±1.19 | 52.25±0.75 | 71.72±0.88 | 50.48±0.14
ROTP$_{\text{B-Q}}$ ($\alpha_{0}=0$) | 66.18±0.76 | 69.88±0.87 | 85.42±1.10 | 74.14±0.24 | 87.72±1.03 | 52.79±0.60 | 72.34±0.50 | 49.36±0.52
ROTP$_{\text{B-E}}$ (learn $\alpha_{0}$) | 65.90±0.94 | 70.19±0.66 | 88.01±1.51 | 74.05±0.34 | 86.78±1.14 | 52.77±0.69 | 71.76±0.62 | 50.28±0.86
ROTP$_{\text{B-Q}}$ (learn $\alpha_{0}$) | 65.96±0.32 | 70.12±1.17 | 86.79±1.81 | 74.27±0.47 | 88.67±0.99 | 52.84±0.60 | 71.78±1.00 | 49.44±0.46
* For each dataset, the top-3 results are bolded and the best result is in red.

For each model with a specific pooling operation, we train and test it via 5-fold cross-validation and evaluate the global pooling methods by the average testing classification accuracy of the corresponding models. Table I presents the statistics of the MIL datasets and the learning results. None of the baselines performs consistently well across all the datasets, while our ROTP layers outperform their competitors in most situations. In particular, the ROTP$_{\text{S}}$ layer achieves the best performance on two of the three datasets, and the performance of the ROTP$_{\text{B-E/Q}}$ layers is comparable to that of their competitors. Additionally, making $\alpha_{0}$ learnable improves the performance of the ROTP$_{\text{B-E/Q}}$ layers: although the corresponding GW discrepancy term increases the complexity of the layer, it takes the structural information of the samples into account and indeed helps to improve the learning results.

6.1.2 Graph embedding and classification

We further evaluate our ROTP layers in graph embedding and classification tasks. In this experiment, we consider eight representative graph classification datasets from the TUDataset [83], including three biochemical molecule datasets (NCI1, MUTAG, and PROTEINS) and five social network datasets (COLLAB, RDT-B, RDT-M5K, IMDB-B, and IMDB-M). For each dataset, we implement the adversarial graph contrastive learning method (ADGCL) [84], learning a five-layer graph isomorphism network (GIN) [4] to represent graphs. At the end of the GIN, we apply different global pooling methods to aggregate node embeddings into graph embeddings. After learning the graph embeddings, we train an SVM classifier to classify the graphs.
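A minimal sketch of this evaluation stage is shown below; mean-pooling stands in for the readout under test (e.g., an ROTP layer), and the random node embeddings stand in for the output of the trained GIN. All names are illustrative.

```python
import torch
from sklearn.svm import SVC

# Pool node embeddings per graph, then fit an SVM on the graph embeddings.
pool = lambda H: H.mean(dim=0)                     # H: (num_nodes, dim)
graphs = [torch.randn(int(torch.randint(10, 40, (1,))), 64) for _ in range(100)]
labels = torch.randint(0, 2, (100,)).numpy()
graph_embs = torch.stack([pool(H) for H in graphs]).numpy()
clf = SVC().fit(graph_embs, labels)                # downstream classifier
```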

Following the setting used in the ADGCL work [84], we apply learnable edge-drop operations to augment the observed graphs. For each model with a specific global pooling layer, we use the Adam optimizer with a learning rate of 0.001 and a batch size of 32. We train each model for 100 epochs on the COLLAB dataset and 150 epochs on the RDT-B dataset; for the remaining datasets, we set the number of epochs to 20. Following the setting used in the MIL experiment, we use four feed-forward modules to build each ROTP layer. We train and test each model in five trials and record the average classification accuracy and its standard deviation. Table II shows the statistics of the datasets and the learning results achieved by the different global pooling methods. Similar to the MIL experiment, our ROTP layers perform well in most situations. In particular, our BADMM-based ROTP layers achieve the best performance on the five social network datasets. Note that the graph structure plays a central role in graph classification tasks, so we should take the GW discrepancy term into account when applying our ROTP layers. The results in Table II support this claim: making $\alpha_{0}$ learnable improves the learning results in most situations.

The experiments on MIL and graph classification indicate that our ROTP layers can simplify the design and selection of global pooling to some degree. In particular, none of the baselines performs consistently well across all the datasets, while our ROTP layers are comparable to the best baselines in most situations, and their performance is more stable and consistent. With our ROTP layers, the design and selection of pooling operations reduces to the selection of an optimization algorithm: instead of empirically testing various global pooling methods, we just need to select an algorithm (i.e., Sinkhorn scaling or Bregman ADMM) to implement the ROTP layer, which achieves encouraging performance.

TABLE III: Comparisons of graph set classification accuracy±Std. (%) for different pooling layers.
Dataset | DECAGON DiBr-APND | DECAGON Anae-Fati | DECAGON PleuP-Diar | FEARS
#Graph sets | 6,309 | 2,922 | 2,842 | 6,338
#Positive sets | 3,189 | 1,526 | 1,422 | 3,169
Positive label | Difficulty breathing | Anaemia | Pleural pain | Non-myopathy
#Negative sets | 3,120 | 1,396 | 1,420 | 3,169
Negative label | Pressure decreased | Fatigue | Diarrhea | Myopathy
Set size | 2 | 2 | 2 | 2$\sim$52
Add | 50.86±0.97 | 63.15±1.79 | 62.32±1.08 | 75.89±1.33
Mean | 51.10±1.09 | 61.95±2.60 | 61.30±2.68 | 72.42±1.51
Max | 50.59±0.77 | 61.88±2.03 | 60.11±2.03 | 82.02±0.72
DeepSet | 49.83±1.07 | 56.24±5.20 | 51.78±3.10 | 82.40±1.56
Mixed | 51.13±0.99 | 63.83±1.19 | 60.91±2.12 | 81.54±1.13
GatedMixed | 51.39±0.63 | 61.50±1.61 | 59.12±2.12 | 81.88±1.14
Set2Set | 50.72±1.71 | 59.35±2.04 | 55.01±3.59 | 79.29±0.84
Attention | 50.52±1.10 | 61.40±2.03 | 61.33±2.40 | 75.98±0.74
GatedAtt | 50.74±0.61 | 62.15±0.77 | 58.80±1.18 | 75.84±1.29
DynamicP | 51.01±1.88 | 55.93±1.56 | 52.58±2.91 | 74.00±1.61
GNP | 50.00±1.88 | 53.98±6.34 | 52.58±4.68 | 62.71±15.55
ASAP | 50.89±0.82 | 63.66±1.81 | 60.67±2.69 | 77.15±1.13
SAGP | 49.87±0.77 | 63.62±1.28 | 59.86±2.43 | 77.29±1.04
OTK | 50.96±1.11 | 63.68±1.59 | 61.66±2.39 | 79.40±1.08
SWE | 51.05±2.15 | 63.21±2.02 | 61.37±3.13 | 80.64±1.86
WEGL | 51.67±0.85 | 63.79±2.54 | 61.36±2.30 | 81.98±0.77
ROTP$_{\text{S}}$ | 51.96±0.71 | 62.91±1.13 | 59.40±0.90 | 79.75±0.71
ROTP$_{\text{B-E}}$ | 51.26±0.84 | 63.86±2.41 | 62.57±1.34 | 82.55±0.42
ROTP$_{\text{B-Q}}$ | 52.72±0.66 | 63.15±1.27 | 60.88±1.65 | 81.43±1.12
* The top-3 results are bolded and the best result is in red.

6.2 Evaluation of HROTP modules

To demonstrate the usefulness of the proposed HROTP modules, we apply them to the drug-drug interaction (DDI) classification task [85, 86, 87, 88], which aims to predict the side effects of different drug combinations. Each drug is represented as a molecular graph, and a drug combination corresponds to a set of graphs. Therefore, DDI classification is a typical graph set prediction problem whose data have hierarchical clustering structures.

In this experiment, we consider four drug side-effect datasets: three drug pair datasets (DiBr-APND, Anae-Fati, and PleuP-Diar) sampled from DECAGON [89] and a drug combination dataset, FEARS [88]. Each dataset contains thousands of drug sets that may cause two different side effects. For each dataset, we implement a three-layer GIN model to learn graph embeddings and consider a hierarchical pooling mechanism with two pooling layers: the first pooling layer aggregates the node embeddings of each graph into a graph embedding, and the second pooling layer further aggregates the graph embeddings within a drug set into a set embedding. Finally, a classifier is trained on the set embeddings. We implement the pooling layers of this hierarchical module with different pooling methods; when using our ROTP layers, we obtain the proposed HROTP modules, as sketched below.
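The two-level structure can be sketched as follows; `rotp_pool` is a hypothetical stand-in for one ROTP layer mapping a set of embeddings to a single vector, and mean-pooling substitutes for it in the runnable example.

```python
import torch

def hrotp_set_embedding(graph_node_embs, rotp_pool):
    """Two-level HROTP sketch: nodes -> graph embeddings -> set embedding."""
    # Level 1: node embeddings of each molecular graph -> one graph embedding
    graph_embs = torch.stack([rotp_pool(H) for H in graph_node_embs])
    # Level 2: graph embeddings of the drug set -> one set embedding
    return rotp_pool(graph_embs)

mean_pool = lambda X: X.mean(dim=0)                    # stand-in for an ROTP layer
drug_set = [torch.randn(17, 64), torch.randn(23, 64)]  # a toy drug pair
set_emb = hrotp_set_embedding(drug_set, mean_pool)     # (64,) set embedding
```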

We apply the Adam optimizer to train the models and set the learning rate to 0.001. For DiBr-APND and FEARS, we train each model for 40 epochs; for the remaining two datasets, we train each model for 100 epochs. The batch size is 128 for the FEARS dataset and 32 for the remaining three datasets. In this experiment, we set $\alpha_{0}=0$ for all three ROTP layers to achieve a trade-off between effectiveness and efficiency. The number of feed-forward modules used in the ROTP layers is selected from $\{4,6,8,10\}$. We train and test each model in five trials per dataset to obtain the average classification accuracy. Table III shows the statistics of the datasets and the learning results achieved by the different global pooling methods. Our HROTP modules achieve the best performance on all four datasets.

Figure 8: Given a batch of 50 sample sets, in which each sample set contains five hundred 100-dimensional samples, we plot the average feed-forward runtime of various pooling methods over 10 trials on a single GPU (RTX 3090). The proposed ROTP layers and existing OT-based methods are labeled in orange and green, respectively; the remaining global pooling methods are labeled in blue.

6.3 More analytic experiments

6.3.1 Runtime comparison

The runtime of our ROTP layers is comparable to that of the learning-based global pooling methods (including existing OT-based methods). Fig. 8 ranks various global pooling methods by their runtime per batch. When applying eight feed-forward modules, the runtime of the BADMM-based ROTP layer is almost the same as that of Set2Set [14]; when reducing the number of feed-forward modules to four, its runtime is less than that of DeepSet [15]. Because $\alpha_{0}=0$, the Sinkhorn-based ROTP layer in this experiment is faster than the BADMM-based ROTP layers, which verifies the analysis in Section 5. When applying four Sinkhorn-scaling modules, the runtime of the Sinkhorn-based ROTP layer is comparable to that of SAGP [33]. Note that, in terms of runtime, the BADMM-based ROTP layers are comparable to OTK [59] and the Sinkhorn-based ROTP layer is comparable to SWE [62]. All the ROTP layers are faster than WEGL [63].

6.3.2 Robustness to hyperparameter settings

Our ROTP layers have one key hyperparameter: the number of feed-forward modules. Applying many modules leads to high-precision solutions of (3) but takes more time in both the feed-forward computation and the backpropagation. As aforementioned, we search the number of modules in the range $[4,16]$ in the above experiments, which achieves a good trade-off between effectiveness and efficiency. To further demonstrate the robustness of our ROTP layers to the number of feed-forward modules, we consider ROTP layers with 4-16 modules and train the corresponding models on each of the twelve (MIL and graph classification) datasets. Fig. 9 shows the average classification accuracy over the twelve datasets with respect to the number of feed-forward modules. The performance of our ROTP layers is stable: the variation of the average classification accuracy is smaller than 0.4%.

Figure 9: The average classification accuracy over the twelve datasets (MIL and graph classification) achieved by our ROTP layers with different numbers of feed-forward modules.
TABLE IV: The impacts of $\bm{p}_{0}$ and $\bm{q}_{0}$ on the classification accuracy (%) of NCI1.
$\bm{p}_{0}$ | $\bm{q}_{0}$ | ROTP$_{\text{S}}$ | ROTP$_{\text{B-E}}$ | ROTP$_{\text{B-Q}}$
Fixed | Fixed | 68.27±1.06 | 65.90±0.94 | 65.96±0.32
Learned | Fixed | 67.97±0.48 | 66.57±0.54 | 66.45±0.82
Fixed | Learned | 69.86±0.45 | 66.21±0.76 | 66.40±0.57
Learned | Learned | 68.60±0.15 | 66.45±0.23 | 66.67±0.63
* Each layer has four feed-forward modules.

Besides the number of feed-forward modules, we also consider the settings of the prior distributions (i.e., $\bm{p}_{0}$ and $\bm{q}_{0}$). As mentioned in Section 4, we can fix them as uniform distributions or learn them by a self-attention model. Taking the NCI1 dataset as an example, Table IV presents the learning results of our methods under different settings of $\bm{p}_{0}$ and $\bm{q}_{0}$. Our ROTP layers are robust to these settings: the learning results change little across them. Therefore, we fix $\bm{p}_{0}$ and $\bm{q}_{0}$ as uniform distributions in the above experiments; under this simple setting, our ROTP layers already achieve encouraging results.
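For the "Learned" setting, one simple attention-based parameterization might look as follows. This is an assumption-laden sketch of one possible variant (the exact attention model is specified in Section 4); it scores each sample and projects the scores onto the simplex.

```python
import torch

class LearnedPrior(torch.nn.Module):
    """Hypothetical sketch of learning the prior q0 over sample indices
    via a simple attention score (the paper's exact parameterization may
    differ). For p0, one would score the rows of X instead."""
    def __init__(self, dim):
        super().__init__()
        self.w = torch.nn.Linear(dim, 1)
    def forward(self, X):                     # X: (D, N), columns are samples
        scores = self.w(X.T).squeeze(-1)      # one attention score per sample
        return torch.softmax(scores, dim=0)   # q0 on the simplex Delta^{N-1}
```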

Figure 10: Illustrations of the $\bm{P}^{*}$'s of a MUTAG graph's node embeddings during training.
Figure 11: (a) Visualizations of two MUTAG graphs and their $\bm{P}^{*}$'s; for the "V-shape" subgraphs, their submatrices in the $\bm{P}^{*}$'s are marked by color frames. (b) Visualizations of two IMDB-B graphs and their $\bm{P}^{*}$'s; for each graph, the key node connecting two communities and the corresponding column in the $\bm{P}^{*}$'s are marked by color frames.
TABLE V: Comparisons between ResNets and our ResNets + ROTP$_{\text{B-E}}$ on validation accuracy (%).
Metric | Learning Strategy | ResNet18 | ResNet34 | ResNet50 | ResNet101 | ResNet152
Top-5 | 100 Epochs (A2DP) | 89.084 | 91.433 | 92.880 | 93.552 | 94.048
Top-5 | 90 Epochs (A2DP) + 10 Epochs (ROTP$_{\text{B-E}}$) | 89.174 | 91.458 | 93.006 | 93.622 | 94.060
Top-1 | 100 Epochs (A2DP) | 69.762 | 73.320 | 76.142 | 77.386 | 78.324
Top-1 | 90 Epochs (A2DP) + 10 Epochs (ROTP$_{\text{B-E}}$) | 69.906 | 73.426 | 76.446 | 77.522 | 78.446

6.3.3 Visualization and rationality

Take the ROTP$_{\text{B-E}}$ layer used for the MUTAG dataset as an example. For a graph in the validation set, we visualize the dynamics of the corresponding $\bm{P}^{*}$'s in different epochs in Fig. 10. In the beginning, the $\bm{P}^{*}$ is relatively dense because the node embeddings are not fully trained and may not be distinguishable. As the number of epochs increases, the $\bm{P}^{*}$ becomes sparse and focuses more on significant "sample-feature" pairs.

Additionally, to verify the rationality of the learned $\bm{P}^{*}$, we take the ROTP$_{\text{S}}$ layer as an example and visualize some graphs and their $\bm{P}^{*}$'s in Fig. 11. For the "V-shape" subgraphs in the two MUTAG graphs, the corresponding submatrices in their $\bm{P}^{*}$'s obey the same pattern, which means that for subgraphs shared by different samples, the weights of their node embeddings are similar. For the key nodes in the two IMDB-B graphs, the corresponding columns in the $\bm{P}^{*}$'s are distinguished from the other columns, and the columns of nodes belonging to different communities exhibit significant clustering structures.

6.4 Limitations and discussions

Besides MIL and graph classification, we further test our ROTP layers in image classification tasks. Given a ResNet [6], we replace its adaptive 2D mean-pooling layer (A2DP) with our ROTP$_{\text{B-E}}$ layer and finetune the modified model on ImageNet [90]. In particular, given the output of the last convolution layer of the ResNet, i.e., $\bm{X}_{\text{in}}\in\mathbb{R}^{B\times C\times H\times W}$, our ROTP$_{\text{B-E}}$ layer fuses the data and outputs $\bm{X}_{\text{out}}\in\mathbb{R}^{B\times C\times 1\times 1}$. In this experiment, we apply a two-stage learning strategy: we first train a ResNet for 90 epochs; then we replace its A2DP layer with our ROTP$_{\text{B-E}}$ layer, fix the other layers, and train the ROTP$_{\text{B-E}}$ layer for 10 epochs. The learning rate is 0.001, and the batch size is 256. Because training on ImageNet is time-consuming, we set $\alpha_{0}=0$ for the ROTP$_{\text{B-E}}$ layer in this experiment to reduce the computational complexity. Table V shows that using our ROTP$_{\text{B-E}}$ layer improves the classification accuracy, and the improvement is consistent across different ResNet architectures.
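Structurally, the replacement amounts to swapping the `avgpool` attribute of a torchvision ResNet for a module with the same input/output shapes. The sketch below uses mean fusion as a stand-in for the ROTP$_{\text{B-E}}$ layer (which, unlike the mean, carries learnable parameters); names other than `avgpool` are illustrative.

```python
import torch
from torchvision.models import resnet18

class SetPool2d(torch.nn.Module):
    """Stand-in for ROTP-based 2D pooling (a sketch): each feature map is
    treated as a set of HW C-dimensional samples and fused to (B, C, 1, 1)."""
    def forward(self, x):                        # x: (B, C, H, W)
        B, C, H, W = x.shape
        return x.view(B, C, H * W).mean(dim=2).view(B, C, 1, 1)

model = resnet18()
model.avgpool = SetPool2d()                      # swap the A2DP layer
out = model(torch.randn(2, 3, 224, 224))         # (2, 1000); the head is unchanged
```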

The improvements shown in Table V are incremental because we only replace a single global pooling layer with our ROTP layer. When training the ResNets with ROTP layers from scratch, the improvements are also modest: after training "ResNet18+ROTP" for 100 epochs, the top-1 accuracy is 69.920% and the top-5 accuracy is 89.198%. Replacing more local pooling layers with our ROTP layers may bring better performance. However, given a tensor $\bm{X}_{\text{in}}\in\mathbb{R}^{B\times C\times H\times W}$, a local pooling merges each patch of size $B\times C\times 2\times 2$ and outputs $\bm{X}_{\text{out}}\in\mathbb{R}^{B\times C\times\frac{H}{2}\times\frac{W}{2}}$, which involves $\frac{BHW}{4}$ pooling operations. In other words, the current bottleneck of our ROTP layer is its computational efficiency. Developing an efficient CUDA version of the ROTP layers and extending them to local pooling operations will be our future work.

7 Conclusion and Future Work

This study proposes a generalized pooling framework driven by the regularized optimal transport (ROT) problem. We demonstrate that many existing pooling operations correspond to solving the ROT problem with different configurations. By learning the parameters of the ROT problem, we obtain an ROTP layer and propose three implementations based on different settings; for each implementation, we analyze its stability and complexity in depth. Stacking the ROTP layers leads to a hierarchical pooling module for set fusion. Our work provides a solid and effective pooling framework with theoretical support and statistical interpretability, and experiments on practical learning tasks and real-world datasets demonstrate the usefulness of our ROTP layers. In the future, we plan to apply our regularized optimal transport modules to reformulate other machine learning models, e.g., the local pooling layers in convolutional neural networks and the message passing layers in graph neural networks. Additionally, as aforementioned, we plan to develop a CUDA version of the ROTP layer to further improve its computational efficiency.

References

  • [1] Y. Yan, X. Wang, X. Guo, J. Fang, W. Liu, and J. Huang, “Deep multi-instance learning with dynamic pooling,” in Asian Conference on Machine Learning.   PMLR, 2018, pp. 662–677.
  • [2] M. Ilse, J. Tomczak, and M. Welling, “Attention-based deep multiple instance learning,” in International conference on machine learning.   PMLR, 2018, pp. 2127–2136.
  • [3] R. Ying, J. You, C. Morris, X. Ren, W. L. Hamilton, and J. Leskovec, “Hierarchical graph representation learning with differentiable pooling,” in Proceedings of the 32nd International Conference on Neural Information Processing Systems, 2018, pp. 4805–4815.
  • [4] K. Xu, W. Hu, J. Leskovec, and S. Jegelka, “How powerful are graph neural networks?” in International Conference on Learning Representations, 2018.
  • [5] A. Krizhevsky, I. Sutskever, and G. E. Hinton, “Imagenet classification with deep convolutional neural networks,” Advances in neural information processing systems, vol. 25, pp. 1097–1105, 2012.
  • [6] K. He, X. Zhang, S. Ren, and J. Sun, “Deep residual learning for image recognition,” in Proceedings of the IEEE conference on computer vision and pattern recognition, 2016, pp. 770–778.
  • [7] C. Yu, X. Zhao, Q. Zheng, P. Zhang, and X. You, “Hierarchical bilinear pooling for fine-grained visual recognition,” in Proceedings of the European conference on computer vision (ECCV), 2018, pp. 574–589.
  • [8] Z. Pan, B. Zhuang, J. Liu, H. He, and J. Cai, “Scalable vision transformers with hierarchical pooling,” in Proceedings of the IEEE/CVF International Conference on Computer Vision, 2021, pp. 377–386.
  • [9] Y.-L. Boureau, J. Ponce, and Y. LeCun, “A theoretical analysis of feature pooling in visual recognition,” in International conference on machine learning, 2010, pp. 111–118.
  • [10] C.-Y. Lee, P. W. Gallagher, and Z. Tu, “Generalizing pooling functions in convolutional neural networks: Mixed, gated, and tree,” in Artificial intelligence and statistics.   PMLR, 2016, pp. 464–472.
  • [11] L. Liu, C. Shen, and A. van den Hengel, “Cross-convolutional-layer pooling for image recognition,” IEEE transactions on pattern analysis and machine intelligence, vol. 39, no. 11, pp. 2305–2313, 2016.
  • [12] H. You, L. Yu, S. Tian, X. Ma, Y. Xing, N. Xin, and W. Cai, “Mc-net: Multiple max-pooling integration module and cross multi-scale deconvolution network,” Knowledge-Based Systems, vol. 231, p. 107456, 2021.
  • [13] M. Lin, Q. Chen, and S. Yan, “Network in network,” arXiv preprint arXiv:1312.4400, 2013.
  • [14] O. Vinyals, S. Bengio, and M. Kudlur, “Order matters: Sequence to sequence for sets,” arXiv preprint arXiv:1511.06391, 2015.
  • [15] M. Zaheer, S. Kottur, S. Ravanbhakhsh, B. Póczos, R. Salakhutdinov, and A. J. Smola, “Deep sets,” in Proceedings of the 31st International Conference on Neural Information Processing Systems, 2017, pp. 3394–3404.
  • [16] J. Lee, Y. Lee, J. Kim, A. Kosiorek, S. Choi, and Y. W. Teh, “Set transformer: A framework for attention-based permutation-invariant neural networks,” in International Conference on Machine Learning.   PMLR, 2019, pp. 3744–3753.
  • [17] G. Peyré, M. Cuturi, and J. Solomon, “Gromov-wasserstein averaging of kernel and distance matrices,” in International Conference on Machine Learning.   PMLR, 2016, pp. 2664–2672.
  • [18] A. Agrawal, B. Amos, S. Barratt, S. Boyd, S. Diamond, and J. Z. Kolter, “Differentiable convex optimization layers,” Advances in neural information processing systems, vol. 32, 2019.
  • [19] L. El Ghaoui, F. Gu, B. Travacca, A. Askari, and A. Tsai, “Implicit deep learning,” SIAM Journal on Mathematics of Data Science, vol. 3, no. 3, pp. 930–958, 2021.
  • [20] Z. Huang, S. Bai, and J. Z. Kolter, “$\text{Implicit}^{2}$: Implicit layers for implicit representations,” Advances in Neural Information Processing Systems, vol. 34, 2021.
  • [21] S. Bai, J. Z. Kolter, and V. Koltun, “Deep equilibrium models,” Advances in Neural Information Processing Systems, vol. 32, 2019.
  • [22] S. Gould, R. Hartley, and D. J. Campbell, “Deep declarative networks,” IEEE Transactions on Pattern Analysis and Machine Intelligence, 2021.
  • [23] M. Cuturi, “Sinkhorn distances: Lightspeed computation of optimal transport,” in Advances in neural information processing systems, 2013, pp. 2292–2300.
  • [24] K. Pham, K. Le, N. Ho, T. Pham, and H. Bui, “On unbalanced optimal transport: An analysis of sinkhorn algorithm,” in International Conference on Machine Learning.   PMLR, 2020, pp. 7673–7682.
  • [25] T. Séjourné, F.-X. Vialard, and G. Peyré, “The unbalanced gromov wasserstein distance: Conic formulation and relaxation,” Advances in Neural Information Processing Systems, vol. 34, 2021.
  • [26] H. Wang and A. Banerjee, “Bregman alternating direction method of multipliers,” in Proceedings of the 27th International Conference on Neural Information Processing Systems-Volume 2, 2014, pp. 2816–2824.
  • [27] H. Xu, “Gromov-wasserstein factorization models for graph clustering,” Proceedings of the AAAI Conference on Artificial Intelligence, vol. 34, no. 04, pp. 6478–6485, 2020.
  • [28] J. Ko, T. Kwon, K. Shin, and J. Lee, “Learning to pool in graph neural networks for extrapolation,” arXiv preprint arXiv:2106.06210, 2021.
  • [29] C. Gulcehre, K. Cho, R. Pascanu, and Y. Bengio, “Learned-norm pooling for deep feedforward and recurrent neural networks,” in Joint European Conference on Machine Learning and Knowledge Discovery in Databases.   Springer, 2014, pp. 530–546.
  • [30] D. dan Guo, L. Tian, M. Zhang, M. Zhou, and H. Zha, “Learning prototype-oriented set representations for meta-learning,” in International Conference on Learning Representations, 2021.
  • [31] A. Vaswani, N. Shazeer, N. Parmar, J. Uszkoreit, L. Jones, A. N. Gomez, Ł. Kaiser, and I. Polosukhin, “Attention is all you need,” in Advances in neural information processing systems, 2017, pp. 5998–6008.
  • [32] E. Ranjan, S. Sanyal, and P. Talukdar, “Asap: Adaptive structure aware pooling for learning hierarchical graph representations,” in Proceedings of the AAAI Conference on Artificial Intelligence, vol. 34, no. 04, 2020, pp. 5470–5477.
  • [33] J. Lee, I. Lee, and J. Kang, “Self-attention graph pooling,” in International Conference on Machine Learning.   PMLR, 2019, pp. 3734–3743.
Hongteng Xu is an Associate Professor (Tenure-Track) in the Gaoling School of Artificial Intelligence, Renmin University of China. From 2018 to 2020, he was a senior research scientist at Infinia ML Inc.; during the same period, he was a visiting faculty member in the Department of Electrical and Computer Engineering at Duke University. He received his Ph.D. from the School of Electrical and Computer Engineering at Georgia Institute of Technology (Georgia Tech) in 2017. His research interests include machine learning and its applications, especially optimal transport theory, sequential data modeling and analysis, deep learning techniques, and their applications in computer vision and data mining.
Minjie Cheng received her B.E. degree in computer science and technology from Zhengzhou University, China, in 2016, and her M.E. degree in software engineering from Beijing University of Chemical Technology, China, in 2021. She is currently a Ph.D. student in the Gaoling School of Artificial Intelligence, Renmin University of China. Her current research interests include machine learning and its applications to biochemical data analysis and modeling.